[
  {
    "path": ".devcontainer/README.md",
    "content": "# Dev container configurations\n\nThis directory contains the configuration for dev containers, which is used to\ninitialize the development environment in **Codespaces**, **Visual Studio\nCode**, and **JetBrains IDEs**. The environment comes with all the\ndependencies required for development and is ready for linting, formatting, and\nrunning tests.\n\n* **GitHub Codespaces**. Create a codespace for the repo by clicking\n    the \"Code\" button on the main page of the repo, selecting the \"Codespaces\"\n    tab, and clicking the \"+\". The configuration will automatically be used.\n    Follow\n    [this guide](https://docs.github.com/en/codespaces/developing-in-a-codespace/creating-a-codespace-for-a-repository)\n    for more details.\n\n* **Visual Studio Code**. Open the root folder of the repo in VS Code. A\n    notification will pop up asking to reopen it in a dev container with the\n    configuration. Follow\n    [this guide](https://code.visualstudio.com/docs/devcontainers/tutorial)\n    for more details.\n\n* **JetBrains IDEs**. Open the `.devcontainer/devcontainer.json` file in your\n    JetBrains IDE. Click the Docker icon to create a dev container.\n    Follow\n    [this guide](https://www.jetbrains.com/help/idea/connect-to-devcontainer.html)\n    for more details."
  },
  {
    "path": ".devcontainer/devcontainer.json",
    "content": "{\n    \"image\": \"mcr.microsoft.com/vscode/devcontainers/python:3.10\",\n    \"postCreateCommand\": \"sh ./.devcontainer/setup.sh && pip install -r requirements.txt\",\n    \"customizations\": {\n        \"vscode\": {\n            \"settings\": {\n                \"python.testing.pytestEnabled\": true,\n                \"editor.formatOnSave\": true,\n                \"editor.codeActionsOnSave\": {\n                    \"source.organizeImports\": true\n                },\n                \"[python]\": {\n                    \"editor.defaultFormatter\": \"charliermarsh.ruff\"\n                },\n                \"editor.rulers\": [\n                    80\n                ]\n            },\n            \"extensions\": [\n                \"charliermarsh.ruff\",\n                \"ms-python.python\"\n            ]\n        }\n    },\n    \"features\": {\n        \"ghcr.io/devcontainers/features/github-cli:1\": {}\n    }\n}\n"
  },
  {
    "path": ".devcontainer/setup.sh",
    "content": "#!/bin/sh\n# Install Python dependencies and set up a pre-commit hook that runs the linter.\nset -e\n\nsudo pip install --upgrade pip\nsudo pip install -r requirements.txt\necho \"bash shell/lint.sh\" > .git/hooks/pre-commit\nchmod a+x .git/hooks/pre-commit\n"
  },
  {
    "path": ".gemini/config.yaml",
    "content": "have_fun: false\nmemory_config:\n  disabled: false\ncode_review:\n  disable: false\n  comment_severity_threshold: MEDIUM\n  max_review_comments: -1\npull_request_opened:\n  help: true\n  summary: true\n  code_review: true\ninclude_drafts: true\nignore_patterns: []\n"
  },
  {
    "path": ".gemini/styleguide.md",
    "content": "# Keras API design guidelines\n\nThese guidelines are meant to help focus design discussions and help us create delightful developer experiences.\n\nThese are meant as guidelines, not rules: each decision should be debated in its own unique context.\n\nSome text remixed from external references:\n\n- [User experience design for APIs](https://blog.keras.io/user-experience-design-for-apis.html)\n- [Notes to Myself on Software Engineering](https://medium.com/s/story/notes-to-myself-on-software-engineering-c890f16f4e4d)\n\n---\n\n## Design end-to-end workflows, not individual functions and classes.\n\nWhen developing APIs, start by designing end-to-end workflows, and only sketch out specific function/class signatures at the end.\n\n- The goal is to arrive at workflows that feel like they are purposefully designed and well-optimized, rather than cobbled together to route around the features provided by the API. The workflows should come first, before atomic features. **Features only exist to support a workflow.** No feature should exist to provide a capability \"just in case\", \"because we can\".\n- **Every design review document should prominently feature a code example of one or two end-to-end workflows showing the canonical use-case for the new API.**\n- Every time we discuss choices surrounding a specific API feature, we should start by asking: **in what workflows will this be used?** Then we should make the choice that makes the most sense with respect to these workflows. We should not make API design decisions about features in isolation.\n- This implies that we will often ask the question: **do users really need to configure this parameter?**, and in many cases, the answer will be \"no\", rather than being \"yes\" by default.\n\n---\n\n## Carefully weigh whether a new feature should be included.\n\nIt's okay to say no: just because someone asks for a feature doesn't mean we should do it. 
Every feature has a cost that goes beyond the initial CL: maintenance cost, documentation cost, and cognitive cost for our users (a sprawling API surface is a major usability issue).\n\nIn particular, in the Keras API, every new feature has to be maintained in perpetuity.\n\nAs such, our criteria for adding a new feature to the API are the following:\n\n- **It should be broadly useful to our users**, rather than a niche feature that is only relevant to a specific vertical of researchers. Niche features should be maintained independently by those who need them (e.g. by extending the API via subclassing), as third-party add-on packages.\n- **It should be widely recognized as a machine learning best practice.** We will not add new layers/etc that were recently published to ArXiv.org, even in the case of claims of increased accuracy. We only add new objects that are already commonly used in the machine learning community. Presumably, a new technique that does result in meaningful gains would be broadly adopted after a few months anyway (like ResNet), and that's when we would add it to the core API. SIG-addons maintains a repository of significantly more volatile and independently maintained code to which the barriers to entry are lower.\n- **It should have an owner committed to maintaining it in the long term.** In particular, the code should be maintainable by multiple people on the team, not just by one technical guru.\n\nIn addition, when saying yes to a request for supporting a new use case, remember that **literally adding what the user/team requested is often not the optimal choice**. Users are focused on their own specific use case, and we must counter this with a holistic and principled vision of the whole project (see: designing end-to-end workflows, not atomic functions/classes). Often, the right answer is to extend an existing feature. 
**Find the natural place to integrate the new feature in existing APIs.**\n\n### Examples:\n\n- We should not have added the self-normalizing activation function to the API. It was added before passing the test of time, and that technique has since been shown not to reach broad adoption. **Note that citation count is not a good metric of adoption**; that paper has a high citation count.\n- We should not move to core an API that has debuted somewhere on GitHub or TF-Addons but has failed to gain more than a few users after a few months.\n\n---\n\n## Seek to minimize cognitive load for our users.\n\nAlways seek to minimize the cognitive load imposed on our users in the course of using our APIs.\n\nAt a high level:\n\n- **Automate everything that can be automated.**\n- **Minimize the actions & choices required from the user.** Make sure default values for arguments are sensible and reflect best practices (so that users usually wouldn't have to manually configure these). Don't expose options that are not important or do not match real use cases, \"just in case\".\n- **Design simple and consistent workflows that reflect simple and consistent mental models.**\n\nHere are a few practical rules:\n\n- **No API should deal with internal implementation details.** An API is a language for our users to talk about the problem they care about -- and they don't care about our internal hacks. For instance, an option like `use_locking` in an optimizer should be avoided. If an argument requires users to understand the implementation (not just what the code is supposed to implement, like SGD in this case), then the argument should not be included in the public API. 
**An API is all about the problem it solves, not about how the code works in the background.**\n- **Introduce as few new concepts as possible.** It's not just that additional data structures require more effort in order to learn about their methods and properties, it's that they multiply the number of **mental models** that are necessary to grok your API. Ideally, you should only need **a single universal mental model around which everything is organized** (in Keras, that's the `Layer`). Definitely avoid having more than 2 or 3 mental models underlying the workflows you design. Likewise, avoid having concepts that are mostly overlapping but subtly different, since the difference will be difficult to convey clearly and will confuse our users (like, say, `Network` and `Model` -- this is why we don't export `Network` as a public API).\n- **Objects that do interchangeable things should have identical or very close APIs.** In particular they should have the same positional arguments. For example, it should be possible to swap one optimizer for another in user code (when leaving all arguments to their default value) without editing the arguments.\n- **If you find yourself proposing a signature with more than 6-7 arguments, consider whether all of these arguments are useful.** How many people and use cases would be affected if you removed one argument? How much would they be affected -- would they be able to easily extend the API (e.g. via subclassing) to support their use case without that built-in argument? Could this API be broken up into smaller, modular objects?\n- **Best practices should come baked into your API.** The simplest way to use your API (leaving all arguments to their default value, using the most obvious tool for the task, etc) should be as close as possible to the best way of solving the problem. 
In particular, all arguments that can be given a default value should be given a default value, and that default should match the most common use case.\n- **Plain Python types are preferable to custom types.** Use tuples, strings, ints... A custom type requires more knowledge and effort on the part of the user (e.g. `TensorShape`, which also breaks established conventions of scientific Python). **When using enums, make sure that their values are strings**, so as to make it possible for users to pass plain strings (example: `data_format=\"channels_last\"`, `padding=\"valid\"`).\n- **Explicit, single-level configuration arguments are preferable to nested, hidden configuration arguments.** Avoid something like: `MyLayer(hyperparameter_dict)`, instead use `MyLayer(units, activation=None, ...)`.\n\nIn particular, naming is important and difficult:\n\n- **The meaning of an argument should be clear from its name and should not require knowledge that only the implementers have.** In particular, argument names should only involve recognized terms of art (\"L1 norm\" is a term of art), and should not involve implementation-related vocabulary (e.g. \"fused batchnorm\").\n- **Avoid `OverlyLongAndSpecificNamingPatterns`.** If you find yourself with argument names that involve more than 3 subparts (e.g. \"squared_operator_norm\"), reconsider. Argument names should be intuitive and easy to remember.\n- Avoid overly generic names (`x`, `variable`, `parameter`).\n- **Make sure you are consistent in your naming choices.** Naming consistency means both **internal naming consistency** (don't call `dim` what is called `axis` in other places, don't call `ndims` what is called `ndim` elsewhere) and **consistency with established conventions for the problem domain (terms of art)**. Before settling on a name, make sure to look up existing names used by domain experts (or other APIs). 
In our case, argument names should be consistent with the broader scientific Python conventions, in particular NumPy.\n\nNote that Keras uses the following naming rules:\n\n- We use the convention `num_*` for counters, though omitting an explicit counter is nicer when there is no ambiguity (e.g. `units`, `epochs`, `filters`).\n- The rank of a tensor is its `ndim`. A specific dimension index is an `axis`. The number of dimensions in a linear projection (or similar) is `units`.\n- By convention Keras layers are named with nouns rather than verbs (e.g. `Normalization` and not `Normalize`, `Convolution` and not `Convolve`).\n- Following Python conventions, classes use CamelCase (e.g. `ClassName`) and functions and methods use snake case (e.g. `function_name`).\n- If an argument name has a numerical suffix (e.g. `alpha_1`), we put an underscore before the suffix in snake case. The capitalized equivalent would be e.g. `Alpha1`.\n- We use fully spelled-out names, e.g. `attention_scores` and not `attn_scores`. There are a couple of standardized exceptions to this rule, in particular `dim` for \"dimension\" and `num` for \"number\". 
These are sufficiently common that they are not ambiguous to a first-time reader.\n\n### Example:\n\n```python\nMyConstructor(\n   per_variable_sparsity_config=[\n      'layer_1/kernel:0.8', 'layer_2/kernel:1.5'])\n```\n\nWhat's wrong with this?\n\n- Overly long argument name\n- Too much cognitive load involved in preparing an appropriate argument value\n- Preparing an argument value requires internal implementation knowledge\n- Reliance on TF variable names (subject to changes at any time, thus breaking this code)\n- Nested config adding indirection\n- Incorrect typing (float values being passed as strings)\n\nPossible alternative:\n\n```python\nobj = MyConstructor()\nobj.configure_sparsity(some_layer.kernel, value=0.8)\nobj.configure_sparsity(some_other_layer.kernel, value=1.5)\n```\n\nWhat's nice about this?\n\n- Object-based variable references.\n- Modular, simple action, with a clear name.\n- Plain Python types.\n\n---\n\n## Balance expressivity vs. user-friendliness.\n\n### Simple use cases should be simple, advanced use cases should be possible:\n\n**Don't increase the cognitive load of common use cases for the sake of niche use cases**, even minimally.\n**Make sure that advanced users have a path to support their use case**, even if this path requires users to roll their own plugins or other API extensions (in particular via subclassing). **It is ok for advanced use cases not to be directly supported in the built-in API options.**\n\n### Keep our APIs modular.\n\n**Complex objects should be achievable by composing simple objects with few arguments that do one thing reliably.** There is a balance to strike between having complex signatures on fewer objects, and having more objects with simpler signatures. A good API has a reasonable number of objects, with reasonably simple signatures (see also: avoiding signatures with more than 6-7 arguments).\n\n**Things that create state or side-effects should be classes. 
Functions should be stateless.**\nFor instance, layers that create weights should not be cast as functions, since it makes the weights (and other elements of state) hard to access, impossible to update, and forces reliance on a global state capturing the side effects of layer-functions.\n\n### APIs should be strictly compartmentalized.\n\nFor instance, the optimizer API or the layers API should not contain arguments for configuring distributed training. That should go into the distribution API.\n\n---\n\n## Don't neglect error messages, docstrings, and documentation.\n\nDocumentation and error messages are an integral part of the API. Good docs and helpful error messages are key to a delightful user experience.\n\n- **Catch user errors early and anticipate common mistakes.** Do user input validation as soon as possible. Actively keep track of common mistakes that people make (by screening GitHub and StackOverflow), and either solve them by simplifying our API, adding targeted error messages for these mistakes, or having a \"solutions to common issues\" page in our docs. Consider adding automated fallback behaviors (e.g. casting a wrongly-typed input) instead of raising errors, when applicable. Be nice to our users.\n- **Provide detailed feedback messages upon user error.** Error messages should be contextual, informative, and actionable. Every error message that transparently provides the user with the solution to their problem means one less support ticket, multiplied by how many times users run into the same issue. 
A good error message should answer:\n    - What happened, in what context?\n    - What did the software expect?\n    - How can the user fix it?\n- **A docstring should answer the question: what is this about, and why & how should I use it?** It should assume as little context as possible, and it shouldn't mention specialized terms without first introducing them (for example, \"num_blocks: Number of blocks in the kernel\" is not a good argument description if this is the first time you mention \"blocks\" in your docstring).\n- **Show, don't tell: your documentation should not talk about how the software works, it should show how to use it.** Show code examples for end-to-end workflows; show code examples for each and every common use case and key feature of your API. **All docstrings should include code examples.**\n- **Deliberately design the user onboarding process for your feature.** How are complete newcomers going to find out the best way to solve their use case with your tool? Have an answer ready. Make sure your onboarding material closely maps to what your users care about: don't teach newcomers how your framework is implemented, teach them how they can use it to solve their own problems. After shipping a CL and writing good docstrings, make sure to create a Colab guide / tutorial showcasing the target workflow, and post it on the docs website.\n- The feature is not ready until:\n    - 1) Users know about it\n    - 2) They know how to use it\n    - 3) They're actually using it to solve the corresponding problem.\n\nNote that Keras uses the following rules for writing docstrings:\n\n- For class docstrings, document arguments in an `Arguments:` section in the class docstring, not in `__init__`.\n    - When a user instantiates a class, they are not calling the `MyLayer.__init__()` method as if it were a regular method, they are calling `MyLayer`. 
We don't want to generate documentation for the `__init__()` method as a standalone method that needs to be called directly; that would be confusing. We also don't need `__init__()` docstrings that always start with \"Initializes a MyLayer class.\", which is useless information. Leaving `__init__()` without a docstring is the best practice.\n    - If constructor arguments are documented in `__init__`, it forces us to programmatically copy the `__init__` docstring when generating docs and concatenate it to the class docstring. This means that the Arguments section becomes the last thing in the docstring, which is bad.\n- The order of information in a class docstring should be:\n    - One-line description of the class that gives initial context to the user. e.g. `Applies Dropout to the input.` Make sure the one-line description is useful. No `Instantiates an ObscureName class instance.`\n    - Paragraph(s) of more detailed information that tells the user what the object is for and when they need to use it. e.g. `The Dropout layer randomly sets input units to 0 with a frequency of \"rate\" at each step during training time, which helps prevent overfitting. Inputs not set to 0 are scaled up by \"1/(1 - rate)\" such that the sum over all inputs is unchanged. [...]`\n    - If there is a reference paper, cite it here.\n    - `Arguments` section.\n    - If it's a layer that has arguments in `call`, the `Call arguments` section.\n    - If it's a `Layer`, `Input shape` and `Output shape` sections.\n    - Example(s).\n    - Lastly, addendum. Information that isn't very important and that most users don't need, but that should be documented somewhere.\n        - e.g. the section \"About the layer's `dtype` attribute\" in the base Layer class.\n        - e.g. warnings about edge cases or compatibility issues.\n        - e.g. 
pointers to further guides and tutorials.\n\n### Error messages: a case study\n\nThe following would be a very poor error message:\n\n```\nAssertionError: '1 != 3'\n```\n\nIn general, to validate user input, always use `ValueError` and avoid `assert`.\n\nAlso bad:\n\n```\nValueError: 'Invalid target shape (600, 1).'\n```\n\nThe following is better, but still not sufficient, because it does not tell the user what they passed, and does not quite say how to fix it:\n\n```\nValueError: 'categorical_crossentropy requires target.shape[1] == classes'\n```\n\nNow, here's a good example that says **what was passed**, **what was expected**, and **how to fix the issue**:\n\n```\nValueError: '''You are passing a target array of shape (600, 1) while using as loss `categorical_crossentropy`.\n`categorical_crossentropy` expects targets to be binary matrices (1s and 0s) of shape (samples, classes).\nIf your targets are integer classes, you can convert them to the expected format via:\n\n---\nfrom keras.utils import to_categorical\ny_binary = to_categorical(y_int)\n---\n\nAlternatively, you can use the loss function `sparse_categorical_crossentropy` instead, which does expect integer targets.'''\n```\n\n---\n\nWhen performing code reviews on pull requests, you must strictly adhere to the following principles in addition to the API design guidelines above:\n\n1. **Question the Necessity of Changes**: Do not assume that the pull request changes are strictly necessary. Critically review the proposed changes to ensure they add real value. Point out any code that is solving a non-existent problem or adding unnecessary complexity.\n\n2. **Call out \"AI Slop\"**: Actively look for and identify \"AI slop\": generic, overly verbose, or hallucinated code that lacks context or violates best practices. If you suspect the code is AI slop, explicitly call it out.\n\n3. **Poke Holes in the Implementation**: Your goal is to critically test the logic. 
Actively search for and point out failing edge cases, race conditions, or unhandled exceptions in the implementation.\n\n4. **Demand Robustness**: Do not accept fragile code. If the proposed code is not robust enough or lacks proper error handling, explicitly tell the author why the current approach is brittle and what must be done to reinforce it.\n\n5. **Respect Existing Repo Patterns**: Before suggesting review comments (like asking users to add boilerplate or specific patterns), actively check for existing design patterns across the repository. Do not suggest adding useless code or structures that contradict or fall outside the established Keras repo coding style.\n"
  },
  {
    "path": ".github/dependabot.yml",
    "content": "# To get started with Dependabot version updates, you'll need to specify which\n# package ecosystems to update and where the package manifests are located.\n# Please see the documentation for all configuration options:\n# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates\n\nversion: 2\nupdates:\n  - package-ecosystem: \"github-actions\"\n    directory: \"/\"\n    schedule:\n      interval: \"monthly\"\n    groups:\n      github-actions:\n        patterns:\n          - \"*\"\n  - package-ecosystem: \"pip\"\n    directory: \"/\"\n    schedule:\n      interval: \"monthly\"\n    groups:\n      python:\n        patterns:\n          - \"*\"\n    ignore:\n      # 2.19.1 is the last version of TensorFlow that supports TPUs.\n      - dependency-name: \"tensorflow-tpu\"\n      # TODO: ignore all updates for JAX GPU due to CUDA version issue\n      - dependency-name: \"jax[cuda12_pip]\"\n      # TODO(#21914): Update this version when TF is updated\n      - dependency-name: \"ai-edge-litert\"\n"
  },
  {
    "path": ".github/workflows/actions.yml",
    "content": "name: Tests\n\n# TODO: Consider enabling all tests (pytest, applications, etc.) with NNX in the future\n# Currently only basic flow tests run with NNX enabled\n\non:\n  push:\n    branches: [ master ]\n  pull_request:\n  release:\n    types: [created]\n\npermissions:\n  contents: read\n\njobs:\n  build:\n    strategy:\n      fail-fast: false\n      matrix:\n        python-version: ['3.11']\n        backend: [tensorflow, jax, torch, numpy, openvino]\n        nnx_enabled: [false]\n        include:\n          - python-version: '3.11'\n            backend: jax\n            nnx_enabled: true\n    name: ${{ matrix.backend == 'jax' && format('Run tests ({0}, {1}, nnx_enabled = {2})', matrix.python-version, matrix.backend, matrix.nnx_enabled) || format('Run tests ({0}, {1})', matrix.python-version, matrix.backend) }}\n    runs-on: ubuntu-latest\n    env:\n      PYTHON: ${{ matrix.python-version }}\n      KERAS_HOME: .github/workflows/config/${{ matrix.backend }}\n    steps:\n      - uses: actions/checkout@v6.0.2\n      - name: Check for changes in keras/src/applications\n        uses: dorny/paths-filter@v3\n        id: filter\n        with:\n          filters: |\n            applications:\n              - 'keras/src/applications/**'\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: ${{ matrix.python-version }}\n      - name: Get pip cache dir\n        id: pip-cache\n        run: |\n          python -m pip install --upgrade pip setuptools\n          echo \"dir=$(pip cache dir)\" >> $GITHUB_OUTPUT\n      - name: pip cache\n        uses: actions/cache@v5\n        with:\n          path: ${{ steps.pip-cache.outputs.dir }}\n          key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}\n      - name: Install dependencies\n        run: |\n          pip install -r requirements.txt --progress-bar off --upgrade\n          if [ \"${{ matrix.nnx_enabled }}\" == 
\"true\" ]; then\n            pip install --upgrade \"flax>=0.11.1\"\n          fi\n          pip install --no-deps tf_keras==2.20.0\n          pip uninstall -y keras keras-nightly\n          pip install -e \".\" --progress-bar off --upgrade\n      - name: Test applications with pytest\n        if: ${{ steps.filter.outputs.applications == 'true' && matrix.nnx_enabled == false }}\n        run: |\n          pytest keras/src/applications --cov=keras/src/applications --cov-config=pyproject.toml\n          coverage xml --include='keras/src/applications/*' -o apps-coverage.xml\n      - name: Codecov keras.applications\n        if: ${{ steps.filter.outputs.applications == 'true' && matrix.nnx_enabled == false }}\n        uses: codecov/codecov-action@v5\n        with:\n          env_vars: PYTHON,KERAS_HOME\n          flags: keras.applications,keras.applications-${{ matrix.backend }}\n          files: apps-coverage.xml\n          token: ${{ secrets.CODECOV_TOKEN }}\n          fail_ci_if_error: false\n      - name: Test integrations\n        if: ${{ matrix.backend != 'numpy' && matrix.nnx_enabled == false }}\n        run: |\n          python integration_tests/import_test.py\n          python integration_tests/numerical_test.py\n      - name: Test JAX-specific integrations\n        if: ${{ matrix.backend == 'jax' && matrix.nnx_enabled == false }}\n        run: |\n          python integration_tests/jax_custom_fit_test.py\n      - name: Test basic flow with NNX\n        if: ${{ matrix.nnx_enabled == true }}\n        env:\n          KERAS_NNX_ENABLED: true\n        run: |\n          python integration_tests/import_test.py\n          python integration_tests/basic_full_flow.py\n      - name: Test TF-specific integrations\n        if: ${{ matrix.backend == 'tensorflow'}}\n        run: |\n          python integration_tests/tf_distribute_training_test.py\n          python integration_tests/tf_custom_fit_test.py\n      - name: Test Torch-specific integrations\n        if: ${{ 
matrix.backend == 'torch'}}\n        run: |\n          pytest integration_tests/torch_workflow_test.py\n          python integration_tests/torch_custom_fit_test.py\n      - name: Test with pytest\n        if: ${{ matrix.nnx_enabled == false }}\n        run: |\n          if [ \"${{ matrix.backend }}\" == \"openvino\" ]; then\n            IGNORE_FILE=\"keras/src/backend/openvino/excluded_tests.txt\"\n            IGNORE_ARGS=$(awk '{print \"--ignore=\" $0}' \"$IGNORE_FILE\")\n          else\n            IGNORE_ARGS=\"\"\n          fi\n          pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml $IGNORE_ARGS\n          coverage xml --omit='keras/src/applications/*,keras/api' -o core-coverage.xml\n      - name: Codecov keras\n        if: ${{ matrix.nnx_enabled == false }}\n        uses: codecov/codecov-action@v5\n        with:\n          env_vars: PYTHON,KERAS_HOME,KERAS_NNX_ENABLED\n          flags: keras,keras-${{ matrix.backend }}${{ matrix.nnx_enabled == 'true' && '-nnx' || '' }}\n          files: core-coverage.xml\n          token: ${{ secrets.CODECOV_TOKEN }}\n          fail_ci_if_error: false\n\n  format:\n    name: Check the code format\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v6.0.2\n      - name: Set up Python 3.11\n        uses: actions/setup-python@v6\n        with:\n          python-version: '3.11'\n      - name: Get pip cache dir\n        id: pip-cache\n        run: |\n          python -m pip install --upgrade pip setuptools\n          echo \"dir=$(pip cache dir)\" >> $GITHUB_OUTPUT\n      - name: pip cache\n        uses: actions/cache@v5\n        with:\n          path: ${{ steps.pip-cache.outputs.dir }}\n          key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}\n      - name: Install dependencies\n        run: |\n          pip install -r requirements.txt --progress-bar off --upgrade\n          pip uninstall -y keras keras-nightly\n          
pip install -e \".\" --progress-bar off --upgrade\n      - name: Run pre-commit\n        run: pre-commit run --all-files --hook-stage manual\n"
  },
  {
    "path": ".github/workflows/auto-assignment.yaml",
    "content": "name: auto-assignment\non:\n  issues:\n    types:\n      - opened\n\npermissions:\n  contents: read\n  issues: write\n  pull-requests: write\n\njobs:\n  welcome:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v6.0.2\n      - uses: actions/github-script@v8\n        with:\n          script: |\n            const script = require('./.github/workflows/scripts/auto-assignment.js')\n            script({github, context})"
  },
  {
    "path": ".github/workflows/config/jax/keras.json",
    "content": "{\n    \"floatx\": \"float32\",\n    \"epsilon\": 1e-07,\n    \"backend\": \"jax\",\n    \"image_data_format\": \"channels_last\",\n    \"nnx_enabled\": false\n}\n"
  },
  {
    "path": ".github/workflows/config/numpy/keras.json",
    "content": "{\n    \"floatx\": \"float32\",\n    \"epsilon\": 1e-07,\n    \"backend\": \"numpy\",\n    \"image_data_format\": \"channels_last\"\n}\n"
  },
  {
    "path": ".github/workflows/config/openvino/keras.json",
    "content": "{\n    \"floatx\": \"float32\",\n    \"epsilon\": 1e-07,\n    \"backend\": \"openvino\",\n    \"image_data_format\": \"channels_last\"\n}\n"
  },
  {
    "path": ".github/workflows/config/tensorflow/keras.json",
    "content": "{\n    \"floatx\": \"float32\",\n    \"epsilon\": 1e-07,\n    \"backend\": \"tensorflow\",\n    \"image_data_format\": \"channels_last\"\n}\n"
  },
  {
    "path": ".github/workflows/config/torch/keras.json",
    "content": "{\n    \"floatx\": \"float32\",\n    \"epsilon\": 1e-07,\n    \"backend\": \"torch\",\n    \"image_data_format\": \"channels_first\"\n}\n"
  },
  {
    "path": ".github/workflows/gpu_tests.yml",
    "content": "name: Keras GPU Tests\n\non:\n  push:\n    branches: [master]\n  pull_request:\n    types: [unlabeled]\n  release:\n    types: [created]\n\npermissions:\n  contents: read\n\njobs:\n\n  test-in-container:\n    name: Run tests on GPU\n    runs-on: linux-x86-g2-16-l4-1gpu\n    # Only run on pushes to master, releases or \"kokoro:force-run\" unlabel\n    if: |\n      github.event_name == 'push' ||\n      github.event_name == 'release' ||\n      (github.event_name == 'pull_request' && github.event.action == 'unlabeled' && github.event.label.name == 'kokoro:force-run')\n\n    strategy:\n      fail-fast: false\n      matrix:\n        backend: [torch]\n\n    container:\n      image: python:3.11-slim\n      options: --privileged --network host\n\n    steps:\n      - name: Checkout Repository\n        uses: actions/checkout@v4\n\n      - name: Check CUDA Version\n        run: nvidia-smi\n\n      - name: Install Dependencies\n        run: pip install --no-cache-dir -r requirements-${{ matrix.backend }}-cuda.txt\n\n      - name: Set Keras Backend\n        run: echo \"KERAS_BACKEND=jax\" >> $GITHUB_ENV\n\n      - name: Verify TF Installation\n        if: ${{ matrix.backend == 'tensorflow'}}\n        run: python3 -c \"import tensorflow as tf; print('Tensorflow devices:', tf.config.list_logical_devices()); assert len(tf.config.list_physical_devices('GPU')) > 0\"\n\n      - name: Verify JAX Installation\n        if: ${{ matrix.backend == 'jax'}}\n        run: python3 -c \"import jax; print('JAX devices:', jax.devices()); assert jax.default_backend() == 'gpu'\"\n\n      - name: Verify Torch Installation\n        if: ${{ matrix.backend == 'torch'}}\n        run: python3 -c \"import torch; print('Torch devices:', [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]); assert torch.cuda.device_count() > 0\"\n\n      - name: Run Tests\n        run: pytest -s keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml\n\n      - 
name: Run Distribution Tests\n        if: ${{ matrix.backend == 'jax'}}\n        run: pytest keras/src/distribution/distribution_lib_test.py --cov=keras --cov-config=pyproject.toml\n"
  },
  {
    "path": ".github/workflows/labeler.yaml",
    "content": "# Copyright 2024 Google LLC. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# you may not use this file except in compliance with the License.\n# You may obtain a copy of the License at\n#\n#     http://www.apache.org/licenses/LICENSE-2.0\n#\n# Unless required by applicable law or agreed to in writing, software\n# distributed under the License is distributed on an \"AS IS\" BASIS,\n# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n# See the License for the specific language governing permissions and\n# limitations under the License.\n\n\n# This workflow automatically identifies issues and pull requests (PRs) and add the \n# appropriate label as per defined rules.\n# First Labeler workflow: It searches for the keyword \"Gemma\" (case-insensitive) in both the title \n# and description of the issue/PR. If a match is found, the workflow adds the label 'Gemma' to the issue/PR.\n\nname: 'Labeler'\non:\n  issues:\n    types: [edited,opened]\n  pull_request_target:\n    types: [opened, edited]\n\npermissions:\n  contents: read\n  issues: write\n  pull-requests: write\n\njobs:\n  add_labels:\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v6.0.2\n      - uses: actions/github-script@v8\n        with:\n          script: |\n            const script = require('./\\.github/workflows/scripts/labeler.js')\n            script({github, context})"
  },
  {
    "path": ".github/workflows/nightly.yml",
    "content": "name: Nightly\n\non:\n  workflow_dispatch: # To Generate wheels on demand outside of schedule.\n  schedule:\n    - cron: \"0 3 * * *\" # run at 3 AM UTC / 8 PM PDT\n\npermissions:\n  contents: read\n\njobs:\n  build:\n    strategy:\n      fail-fast: false\n      matrix:\n        python-version: [\"3.11\"]\n        backend: [tensorflow, jax, torch, numpy]\n    name: Run tests (Python ${{ matrix.python-version }})\n    runs-on: ubuntu-latest\n    env:\n      PYTHON: ${{ matrix.python-version }}\n      KERAS_BACKEND: ${{ matrix.backend }}\n    steps:\n      - uses: actions/checkout@v6.0.2\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: ${{ matrix.python-version }}\n      - name: Get pip cache dir\n        id: pip-cache\n        run: |\n          python -m pip install --upgrade pip setuptools\n          echo \"dir=$(pip cache dir)\" >> $GITHUB_OUTPUT\n      - name: pip cache\n        uses: actions/cache@v5\n        with:\n          path: ${{ steps.pip-cache.outputs.dir }}\n          key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}\n      - name: Install dependencies\n        run: |\n          pip install -r requirements.txt --progress-bar off --upgrade\n          pip uninstall -y keras keras-nightly\n          pip install -e \".\" --progress-bar off --upgrade\n      - name: Test integrations\n        if: ${{ matrix.backend != 'numpy'}}\n        run: |\n          python integration_tests/import_test.py\n      - name: Test TF-specific integrations\n        if: ${{ matrix.backend == 'tensorflow'}}\n        run: |\n          python integration_tests/tf_distribute_training_test.py\n      - name: Test Torch-specific integrations\n        if: ${{ matrix.backend == 'torch'}}\n        run: |\n          pytest integration_tests/torch_workflow_test.py\n      - name: Test with pytest\n        run: |\n          pytest keras --ignore keras/src/applications 
--cov=keras --cov-config=pyproject.toml\n\n  format:\n    name: Check the code format\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v6.0.2\n      - name: Set up Python 3.11\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.11\"\n      - name: Get pip cache dir\n        id: pip-cache\n        run: |\n          python -m pip install --upgrade pip setuptools\n          echo \"dir=$(pip cache dir)\" >> $GITHUB_OUTPUT\n      - name: pip cache\n        uses: actions/cache@v5\n        with:\n          path: ${{ steps.pip-cache.outputs.dir }}\n          key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}\n      - name: Install dependencies\n        run: |\n          pip install -r requirements.txt --progress-bar off --upgrade\n          pip uninstall -y keras keras-nightly\n          pip install -e \".\" --progress-bar off --upgrade\n      - name: Run pre-commit\n        run: pre-commit run --all-files --hook-stage manual\n\n  nightly:\n    name: Build Wheel file and upload\n    needs: [build, format]\n    runs-on: ubuntu-latest\n    steps:\n      - uses: actions/checkout@v6.0.2\n      - name: Set up Python\n        uses: actions/setup-python@v6\n        with:\n          python-version: \"3.11\"\n      - name: Install dependencies\n        run: |\n          python -m pip install --upgrade pip setuptools\n          pip install twine\n          pip install -r requirements.txt --progress-bar off --upgrade\n          pip uninstall -y keras keras-nightly\n      - name: Build wheel file\n        run: |\n          python pip_build.py --nightly\n      - name: Publish to PyPI\n        uses: pypa/gh-action-pypi-publish@release/v1\n        with:\n          password: ${{ secrets.PYPI_NIGHTLY_API_TOKEN }}\n          packages-dir: dist/\n          verbose: true\n"
  },
  {
    "path": ".github/workflows/scorecard.yml",
    "content": "name: Scorecard supply-chain security\non:\n  # For Branch-Protection check. Only the default branch is supported. See\n  # https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection\n  branch_protection_rule:\n  # To guarantee Maintained check is occasionally updated. See\n  # https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained\n  schedule:\n    - cron: '42 8 * * 2'\n  push:\n    branches: [ \"master\" ]\n\n# Declare default permissions as read only.\npermissions: read-all\n\njobs:\n  analysis:\n    name: Scorecard analysis\n    runs-on: ubuntu-latest\n    permissions:\n      # Needed to upload the results to code-scanning dashboard.\n      security-events: write\n      # Needed to publish results and get a badge (see publish_results below).\n      id-token: write\n\n    steps:\n      - name: \"Checkout code\"\n        uses: actions/checkout@0c366fd6a839edf440554fa01a7085ccba70ac98 # v4.1.1\n        with:\n          persist-credentials: false\n\n      - name: \"Run analysis\"\n        uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3\n        with:\n          results_file: results.sarif\n          results_format: sarif\n          # (Optional) \"write\" PAT token. Uncomment the `repo_token` line below if:\n          # - you want to enable the Branch-Protection check on a *public* repository, or\n          # - you are installing Scorecard on a *private* repository\n          # To create the PAT, follow the steps in https://github.com/ossf/scorecard-action#authentication-with-pat.\n          # repo_token: ${{ secrets.SCORECARD_TOKEN }}\n\n          # Publish results to OpenSSF REST API for easy access by consumers\n          # Allows the repository to include the Scorecard badge.\n          # See https://github.com/ossf/scorecard-action#publishing-results.\n          publish_results: true\n\n      # Upload the results as artifacts (optional). 
Commenting out will disable uploads of run results in SARIF\n      # format to the repository Actions tab.\n      - name: \"Upload artifact\"\n        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0\n        with:\n          name: SARIF file\n          path: results.sarif\n          retention-days: 5\n\n      # Upload the results to GitHub's code scanning dashboard.\n      - name: \"Upload to code-scanning\"\n        uses: github/codeql-action/upload-sarif@89a39a4e59826350b863aa6b6252a07ad50cf83e # v3.29.5\n        with:\n          sarif_file: results.sarif\n"
  },
  {
    "path": ".github/workflows/scripts/auto-assignment.js",
    "content": "/**\n * @license\n * Copyright 2023 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (the \"License\");\n * you may not use this file except in compliance with the License.\n * You may obtain a copy of the License at\n *\n * http://www.apache.org/licenses/LICENSE-2.0\n *\n * Unless required by applicable law or agreed to in writing, software\n * distributed under the License is distributed on an \"AS IS\" BASIS,\n * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n * See the License for the specific language governing permissions and\n * limitations under the License.\n * =============================================================================\n */\n\n/** Automatically assign issues and PRs to users in the `assigneesList` \n *  on a rotating basis.\n\n  @param {!object}\n    GitHub objects can call GitHub APIs using their built-in library functions.\n    The context object contains issue and PR details.\n*/\n\nmodule.exports = async ({ github, context }) => {\n  let issueNumber;\n  let assigneesList;\n  // Is this an issue? If so, assign the issue number. Otherwise, assign the PR number.\n  if (context.payload.issue) {\n    //assignee List for issues. \n    assigneesList = [\"mehtamansi29\", \"sachinprasadhs\"];\n    issueNumber = context.payload.issue.number;\n  } else {\n    //assignee List for PRs. 
\n    assigneesList = [];\n    issueNumber = context.payload.number;\n  }\n  console.log(\"assignee list\", assigneesList);\n  console.log(\"entered auto assignment for this issue:  \", issueNumber);\n  if (!assigneesList.length) {\n    console.log(\"No assignees found for this repo.\");\n    return;\n  }\n  let noOfAssignees = assigneesList.length;\n  let selection = issueNumber % noOfAssignees;\n  let assigneeForIssue = assigneesList[selection];\n\n  console.log(\n    \"issue Number = \",\n    issueNumber + \" , assigning to: \",\n    assigneeForIssue\n  );\n  return github.rest.issues.addAssignees({\n    issue_number: context.issue.number,\n    owner: context.repo.owner,\n    repo: context.repo.repo,\n    assignees: [assigneeForIssue],\n  });\n};\n"
  },
  {
    "path": ".github/workflows/scripts/labeler.js",
    "content": "/*\nCopyright 2024 Google LLC. All Rights Reserved.\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou may not use this file except in compliance with the License.\nYou may obtain a copy of the License at\n\n     http://www.apache.org/licenses/LICENSE-2.0\n\n Unless required by applicable law or agreed to in writing, software\n distributed under the License is distributed on an \"AS IS\" BASIS,\n WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n See the License for the specific language governing permissions and\n limitations under the License.\n*/\n\n\n/**\n * Invoked from labeler.yaml file to add\n * label 'Gemma' to the issue and PR for which have gemma keyword present.\n * @param {!Object.<string,!Object>} github contains pre defined functions.\n *  context Information about the workflow run.\n */\n\nmodule.exports = async ({ github, context }) => {\n    const issue_title = context.payload.issue ?  context.payload.issue.title : context.payload.pull_request.title\n    const issue_description = context.payload.issue ? context.payload.issue.body : context.payload.pull_request.body\n    const issue_number = context.payload.issue ? context.payload.issue.number : context.payload.pull_request.number\n    const keyword_label =  {\n         gemma:'Gemma'\n    }\n    const labelsToAdd = []\n    console.log(issue_title,issue_description,issue_number)\n    \n    for(const [keyword, label] of Object.entries(keyword_label)){\n     if(issue_title.toLowerCase().indexOf(keyword) !=-1 || issue_description.toLowerCase().indexOf(keyword) !=-1 ){\n        console.log(`'${keyword}'keyword is present inside the title or description. 
Pushing label '${label}' to row.`)\n        labelsToAdd.push(label)\n    }\n   }\n   if(labelsToAdd.length > 0){\n    console.log(`Adding labels ${labelsToAdd} to the issue '#${issue_number}'.`)\n     github.rest.issues.addLabels({\n        owner: context.repo.owner,\n        repo: context.repo.repo,\n        issue_number: context.issue.number,\n        labels: labelsToAdd\n     })\n   }\n};"
  },
  {
    "path": ".github/workflows/stale-issue-pr.yaml",
    "content": "name: Close inactive issues\non:\n  schedule:\n    - cron: \"30 1 * * *\"\njobs:\n  close-issues:\n    # Don't do this in forks\n    if: github.repository == 'keras-team/keras'\n    runs-on: ubuntu-latest\n    permissions:\n      issues: write\n      pull-requests: write\n      actions: write\n    steps:\n      - name: Awaiting response issues\n        uses: actions/stale@v10\n        with:\n          operations-per-run: 500\n          days-before-issue-stale: 14\n          days-before-issue-close: 14\n          stale-issue-label: \"stale\"\n          # reason for closed the issue default value is not_planned\n          close-issue-reason: completed\n          only-labels: \"stat:awaiting response from contributor\"\n          stale-issue-message: > \n            This issue is stale because it has been open for 14 days with no activity.\n            It will be closed if no further activity occurs. Thank you.\n          # List of labels to remove when issues/PRs unstale. \n          labels-to-remove-when-unstale: \"stat:awaiting response from contributor\"\n          close-issue-message: >\n            This issue was closed because it has been inactive for 28 days.\n            Please reopen if you'd like to work on this further.\n          days-before-pr-stale: 14\n          days-before-pr-close: 14\n          stale-pr-message: \"This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.\"\n          close-pr-message: \"This PR was closed because it has been inactive for 28 days. 
Please reopen if you'd like to work on this further.\"\n          repo-token: ${{ secrets.GITHUB_TOKEN }}\n      - name: Contribution issues\n        uses: actions/stale@v10\n        with:\n          operations-per-run: 500\n          days-before-issue-stale: 180\n          days-before-issue-close: 365\n          stale-issue-label: \"stale\"\n          # Reason for closing the issue; the default value is not_planned.\n          close-issue-reason: not_planned\n          any-of-labels: \"stat:contributions welcome,good first issue\"\n          # List of labels to remove when issues/PRs become unstale.\n          labels-to-remove-when-unstale: \"stat:contributions welcome,good first issue\"\n          stale-issue-message: >\n            This issue is stale because it has been open for 180 days with no activity.\n            It will be closed if no further activity occurs. Thank you.\n          close-issue-message: >\n            This issue was closed because it has been inactive for more than 1 year.\n          repo-token: ${{ secrets.GITHUB_TOKEN }}\n"
  },
  {
    "path": ".github/workflows/tpu_tests.yml",
    "content": "name: Keras TPU Tests\n\non:\n  push:\n    branches: [master]\n  pull_request:\n    types: [unlabeled]\n  release:\n    types: [created]\n\npermissions:\n  contents: read\n\njobs:\n\n  test-in-container:\n    name: Run tests on TPU\n    runs-on: linux-x86-ct6e-44-1tpu\n    # Only run on pushes to master, releases or \"kokoro:force-run\" unlabel\n    if: |\n      github.event_name == 'push' ||\n      github.event_name == 'release' ||\n      (github.event_name == 'pull_request' && github.event.action == 'unlabeled' && github.event.label.name == 'kokoro:force-run')\n\n    strategy:\n      fail-fast: false\n      matrix:\n        backend: [jax]\n\n    container:\n      image: python:3.11-slim\n      options: --privileged --network host\n\n    steps:\n      - name: Checkout Repository\n        uses: actions/checkout@v6.0.2\n      \n      - name: Install Dependencies\n        run: pip install --no-cache-dir -r requirements-${{ matrix.backend }}-tpu.txt\n      \n      - name: Set Keras Backend\n        run: echo \"KERAS_BACKEND=jax\" >> $GITHUB_ENV\n\n      - name: Verify JAX Installation\n        run: python3 -c \"import jax; print('JAX devices:', jax.devices()); assert jax.default_backend() == 'tpu'\"\n\n      - name: Run Tests\n        run: pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml\n"
  },
  {
    "path": ".gitignore",
    "content": ".DS_Store\n*.pyc\n.vscode-test\n__pycache__\n**/.vscode-test/**\n**/.vscode test/**\n**/.vscode-smoke/**\n**/.venv*/\nvenv\nbin/**\nbuild/**\nobj/**\n.pytest_cache\ntmp/**\n.vs/\ndist/**\n**/*.egg-info/*\n.vscode\nexamples/**/*.jpg\n.python-version\n.coverage\n*coverage.xml\n.ruff_cache\n\npytest.ini\nvenv/"
  },
  {
    "path": ".kokoro/README.md",
    "content": "CI to run on PR and merge to Master."
  },
  {
    "path": ".kokoro/github/ubuntu/gpu/build.sh",
    "content": "set -e\nset -x\n\ncd \"${KOKORO_ROOT}/\"\n\nsudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1\n\nPYTHON_BINARY=\"/usr/bin/python3.10\"\n\n\"${PYTHON_BINARY}\" -m venv venv\nsource venv/bin/activate\n# Check the python version\npython --version\npython3 --version\n\n# setting the LD_LIBRARY_PATH manually is causing segmentation fault\n#export LD_LIBRARY_PATH=\"$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:\"\n# Check cuda\nnvidia-smi\nnvcc --version\n\ncd \"src/github/keras\"\npip install -U pip setuptools\n# psutil is used by background log reader\npip install -U psutil\n\nif [ \"$KERAS_BACKEND\" == \"tensorflow\" ]\nthen\n   echo \"TensorFlow backend detected.\"\n   pip install -r requirements-tensorflow-cuda.txt --progress-bar off --timeout 1000\n   pip uninstall -y keras keras-nightly\n   echo \"Check that TensorFlow uses GPU\"\n   python3 -c 'import tensorflow as tf;print(tf.__version__);print(tf.config.list_physical_devices(\"GPU\"))'\n   # Raise error if GPU is not detected.\n   python3 -c 'import tensorflow as tf;assert len(tf.config.list_physical_devices(\"GPU\")) > 0'\n\n   # TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted\n   pytest keras --ignore keras/src/applications \\\n               --ignore keras/src/layers/merging/merging_test.py \\\n               --cov=keras \\\n               --cov-config=pyproject.toml\nfi\n\nif [ \"$KERAS_BACKEND\" == \"jax\" ]\nthen\n   echo \"JAX backend detected.\"\n   pip install -r requirements-jax-cuda.txt --progress-bar off --timeout 1000\n   pip uninstall -y keras keras-nightly\n   python3 -c 'import jax;print(jax.__version__);print(jax.default_backend())'\n   # Raise error if GPU is not detected.\n   python3 -c 'import jax;assert jax.default_backend().lower() == \"gpu\"'\n\n   # TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: 
Aborted\n   # TODO: keras/trainers/data_adapters/py_dataset_adapter_test.py::PyDatasetAdapterTest::test_basic_flow0 Fatal Python error: Aborted\n   # keras/backend/jax/distribution_lib_test.py is configured for CPU test for now.\n   pytest keras --ignore keras/src/applications \\\n               --ignore keras/src/layers/merging/merging_test.py \\\n               --ignore keras/src/trainers/data_adapters/py_dataset_adapter_test.py \\\n               --ignore keras/src/backend/jax/distribution_lib_test.py \\\n               --ignore keras/src/distribution/distribution_lib_test.py \\\n               --cov=keras \\\n               --cov-config=pyproject.toml\n\n   pytest keras/src/distribution/distribution_lib_test.py --cov=keras --cov-config=pyproject.toml\nfi\n\nif [ \"$KERAS_BACKEND\" == \"torch\" ]\nthen\n   echo \"PyTorch backend detected.\"\n   pip install -r requirements-torch-cuda.txt --progress-bar off --timeout 1000\n   pip uninstall -y keras keras-nightly\n   python3 -c 'import torch;print(torch.__version__);print(torch.cuda.is_available())'\n   # Raise error if GPU is not detected.\n   python3 -c 'import torch;assert torch.cuda.is_available()'\n\n   pytest keras --ignore keras/src/applications \\\n               --cov=keras \\\n               --cov-config=pyproject.toml\n\nfi\n"
  },
  {
    "path": ".kokoro/github/ubuntu/gpu/jax/continuous.cfg",
    "content": "build_file: \"keras/.kokoro/github/ubuntu/gpu/build.sh\"\n\naction {\n  define_artifacts {\n    regex: \"**/sponge_log.log\"\n    regex: \"**/sponge_log.xml\"\n  }\n}\n\nenv_vars: {\n   key: \"KERAS_BACKEND\"\n   value: \"jax\"\n}\n\n# Set timeout to 120 mins from default 180 mins\ntimeout_mins: 120"
  },
  {
    "path": ".kokoro/github/ubuntu/gpu/jax/presubmit.cfg",
    "content": "build_file: \"keras/.kokoro/github/ubuntu/gpu/build.sh\"\n\naction {\n  define_artifacts {\n    regex: \"**/sponge_log.log\"\n    regex: \"**/sponge_log.xml\"\n  }\n}\n\nenv_vars: {\n   key: \"KERAS_BACKEND\"\n   value: \"jax\"\n}\n\n# Set timeout to 120 mins from default 180 mins\ntimeout_mins: 120"
  },
  {
    "path": ".kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg",
    "content": "build_file: \"keras/.kokoro/github/ubuntu/gpu/build.sh\"\n\naction {\n  define_artifacts {\n    regex: \"**/sponge_log.log\"\n    regex: \"**/sponge_log.xml\"\n  }\n}\n\nenv_vars: {\n   key: \"KERAS_BACKEND\"\n   value: \"tensorflow\"\n}\n\n# Set timeout to 60 mins from default 180 mins\ntimeout_mins: 60"
  },
  {
    "path": ".kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg",
    "content": "build_file: \"keras/.kokoro/github/ubuntu/gpu/build.sh\"\n\naction {\n  define_artifacts {\n    regex: \"**/sponge_log.log\"\n    regex: \"**/sponge_log.xml\"\n  }\n}\n\nenv_vars: {\n   key: \"KERAS_BACKEND\"\n   value: \"tensorflow\"\n}\n\n# Set timeout to 60 mins from default 180 mins\ntimeout_mins: 60"
  },
  {
    "path": ".pre-commit-config.yaml",
    "content": "repos:\n  - repo: local\n    hooks:\n      - id: api-gen\n        name: api_gen\n        entry: |\n          bash shell/api_gen.sh\n          git status\n          clean=$(git status | grep \"nothing to commit\")\n          if [ -z \"$clean\" ]; then\n            echo \"Please run shell/api_gen.sh to generate API.\"\n            exit 1\n          fi\n        language: system\n        stages: [pre-commit, manual]\n        require_serial: true\n  - repo: https://github.com/astral-sh/ruff-pre-commit\n    rev: v0.9.2\n    hooks:\n      - id: ruff\n        args: [--config, pyproject.toml, --fix, .]\n        stages: [pre-commit]\n      - id: ruff-format\n        args: [--config, pyproject.toml, .]\n        stages: [pre-commit]\n      - id: ruff\n        args: [--config, pyproject.toml, .]\n        stages: [manual]\n      - id: ruff-format\n        args: [\"--check\", --config, pyproject.toml, .]\n        stages: [manual]"
  },
  {
    "path": "CONTRIBUTING.md",
    "content": "Keras 3 is a high-velocity open-source project. We welcome contributions!\n\nContributions can be made in a variety of ways, including coding, enriching documentation, refining docstrings, and providing code examples.\n\n\n## Current items open for contributions\nAt [this link](https://github.com/keras-team/keras/issues/18442), you'll find a list of items where your help is needed!\n\n\n## How to contribute code\n\nFollow these steps to submit your code contribution.\n\n### Step 1. Open an issue\n\nBefore making any changes, we recommend opening an issue (if one doesn't already\nexist) and discussing your proposed changes. This way, we can give you feedback\nand validate the proposed changes.\n\nIf the changes are minor (simple bug fix or documentation fix), then feel free\nto open a Pull Request (PR) without discussion.\n\n### Step 2. Make code changes\n\nTo make code changes, you need to fork the repository. You will need to setup a\ndevelopment environment and run the unit tests. This is covered in the section\n\"Setup environment\".\n\n### Step 3. Create a pull request\n\nOnce the change is ready, open a pull request from your branch in your fork to\nthe master branch in [keras-team/keras](https://github.com/keras-team/keras).\n\n### Step 4. Sign the Contributor License Agreement\n\nAfter creating the pull request, the `cla/google` check will be performed and,\nif you haven't signed the Contributor License Agreement (CLA), it will fail with\ninstructions on how to do so. Please follow the instructions to sign the CLA and\nthe check will pass.\n\n![CLA signed](https://github.com/keras-team/keras/assets/1091026/71c26353-e3b5-4135-8bae-64693c717775)\n\n\n### Step 5. Code review\n\nIf the tests fail, look into the error messages and try to fix them.\n\n![CI tests](https://github.com/keras-team/keras/assets/1091026/6f6c17ef-6bd7-4e95-9fbc-1906cde37380)\n\nA reviewer will review the pull request and provide comments. 
There may be\nseveral rounds of comments and code changes before the pull request gets\napproved by the reviewer.\n\n![Approval from reviewer](https://github.com/keras-team/keras/assets/1091026/8d28f74c-21e9-4146-b0ff-62d649a552a8)\n\n### Step 6. Merging\n\nOnce the pull request is approved, a `ready to pull` tag will be added to the\npull request. A team member will take care of the merging.\n\n![Ready to pull and merged](https://github.com/keras-team/keras/assets/1091026/c3908345-d7ae-44ee-a428-01f3b448b46b)\n\nHere is an [example pull request](https://github.com/keras-team/keras/pull/18848)\nfor your reference.\n\n## Setup environment\n\nWe provide two ways of setting up a development environment. One is to use a\ndev container, and the other is to set up a local environment by installing\nthe dev tools needed.\n\n### Option 1: GitHub Codespace or dev container\n\nWe support GitHub Codespaces, Visual Studio Code dev containers and JetBrains dev\ncontainers. Please see the\n[Dev container documentation](https://github.com/keras-team/keras/tree/master/.devcontainer).\n\n### Option 2: Set up a local environment\n\nTo set up your local dev environment, you will need the following tools.\n\n1.  [git](https://git-scm.com/) for code repository management.\n2.  [python](https://www.python.org/) to build and code in Keras.\n\nThe following commands check that the tools above are successfully installed. Note\nthat Keras requires at least Python 3.10 to run.\n\n```shell\ngit --version\npython --version\n```\n\nClone your forked repo to your local machine. 
Go to the cloned directory to\ninstall the dependencies.\n\n```shell\ngit clone https://github.com/YOUR_GITHUB_USERNAME/keras.git\ncd keras\npip install -r requirements.txt\n```\n\nYou then need to configure the backend to use; see the\n[Configuring your backend](https://github.com/keras-team/keras/blob/master/README.md#configuring-your-backend)\nsection of the README.\n\nYou can also add GPU support to your environment; see the\n[Adding GPU support](https://github.com/keras-team/keras/blob/master/README.md#adding-gpu-support)\nsection of the README.\n\n## Generating public API and formatting the code\n\nWhen setting up the repo for the first time, please run `pre-commit install`.\nThis only needs to be done once.\n\nNow, whenever you run `git commit -m \"<message>\"`, three things are\nautomatically done:\n\n- Public API generation\n- Code formatting\n- Code linting\n\nIf there's any error, the commit will not go through. Please fix the error\n(most of the time, the error is fixed automatically by the formatter/linter) and\nre-run the following:\n\n```\ngit add .\ngit commit -m \"<message>\" # This will not get logged as a duplicate commit.\n```\n\nIn case you want to run the above manually on all files, you can do the\nfollowing:\n\n```\npre-commit run --all-files\n```\n\nKeras uses [Ruff](https://docs.astral.sh/ruff/) to format the code.\n\n### Docstrings\n\nWe do not have an automated way to check docstring style, so if you write\nor edit any docstring, please make sure to check them manually.\nKeras docstrings follow the conventions below:\n\nA **class docstring** may contain the following items:\n\n* A one-line description of the class.\n* Paragraph(s) of more detailed information.\n* Optional `Examples` section.\n* `Args` section for arguments in `__init__()`.\n* If it's a layer:\n    * `Call arguments` section for arguments in `Layer.call()`.\n    * `Returns` section for the return values of `Layer.call()`.\n    * Optional 
`Raises` section for possible errors.\n\nYou can check out `MultiHeadAttention` as an example\n[(link)](https://github.com/keras-team/keras/blob/v3.0.0/keras/layers/attention/multi_head_attention.py#L20).\n\nA **function docstring** may contain the following items:\n\n* One-line description of the function.\n* Paragraph(s) of more detailed information.\n* Optional `Examples` section.\n* `Args` section for the function arguments.\n* `Returns` section for the return values.\n* Optional `Raises` section for possible errors.\n\nYou can check out `text_dataset_from_directory` as an example\n[(link)](https://github.com/keras-team/keras/blob/v3.0.0/keras/utils/text_dataset_utils.py#L27).\n\n## Run tests\n\nWe use [pytest](https://pytest.org/) to run the tests.\n\n### Run a test file\n\nTo run the tests in `keras/src/losses/losses_test.py`, use the following command\nat the root directory of the repo.\n\n```shell\npytest keras/src/losses/losses_test.py\n```\n\n### Run a single test case\n\nYou can specify a single test class to run within a file.\n\n```shell\npytest keras/src/losses/losses_test.py::MeanSquaredErrorTest\n```\n\nYou can also specify a single test method to run within a class.\n\n```shell\npytest keras/src/losses/losses_test.py::MeanSquaredErrorTest::test_sample_weighted\n```\n\n### Run all tests\n\nYou can run all the tests locally by running the following command in the repo\nroot directory.\n\n```shell\npytest keras\n```\n\nNote that you can skip the Keras applications tests using the\n`SKIP_APPLICATIONS_TESTS` environment variable. This will cut down the testing\ntime significantly.\n\n```shell\nSKIP_APPLICATIONS_TESTS=True pytest keras\n```\n\nTo run all tests using a different backend, you can simply specify it on the\ncommand line.\n\n```shell\nKERAS_BACKEND=jax SKIP_APPLICATIONS_TESTS=True pytest keras\n```\n"
  },
  {
    "path": "LICENSE",
    "content": "                                 Apache License\n                           Version 2.0, January 2004\n                        http://www.apache.org/licenses/\n\n   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION\n\n   1. Definitions.\n\n      \"License\" shall mean the terms and conditions for use, reproduction,\n      and distribution as defined by Sections 1 through 9 of this document.\n\n      \"Licensor\" shall mean the copyright owner or entity authorized by\n      the copyright owner that is granting the License.\n\n      \"Legal Entity\" shall mean the union of the acting entity and all\n      other entities that control, are controlled by, or are under common\n      control with that entity. For the purposes of this definition,\n      \"control\" means (i) the power, direct or indirect, to cause the\n      direction or management of such entity, whether by contract or\n      otherwise, or (ii) ownership of fifty percent (50%) or more of the\n      outstanding shares, or (iii) beneficial ownership of such entity.\n\n      \"You\" (or \"Your\") shall mean an individual or Legal Entity\n      exercising permissions granted by this License.\n\n      \"Source\" form shall mean the preferred form for making modifications,\n      including but not limited to software source code, documentation\n      source, and configuration files.\n\n      \"Object\" form shall mean any form resulting from mechanical\n      transformation or translation of a Source form, including but\n      not limited to compiled object code, generated documentation,\n      and conversions to other media types.\n\n      \"Work\" shall mean the work of authorship, whether in Source or\n      Object form, made available under the License, as indicated by a\n      copyright notice that is included in or attached to the work\n      (an example is provided in the Appendix below).\n\n      \"Derivative Works\" shall mean any work, whether in Source or Object\n      
form, that is based on (or derived from) the Work and for which the\n      editorial revisions, annotations, elaborations, or other modifications\n      represent, as a whole, an original work of authorship. For the purposes\n      of this License, Derivative Works shall not include works that remain\n      separable from, or merely link (or bind by name) to the interfaces of,\n      the Work and Derivative Works thereof.\n\n      \"Contribution\" shall mean any work of authorship, including\n      the original version of the Work and any modifications or additions\n      to that Work or Derivative Works thereof, that is intentionally\n      submitted to Licensor for inclusion in the Work by the copyright owner\n      or by an individual or Legal Entity authorized to submit on behalf of\n      the copyright owner. For the purposes of this definition, \"submitted\"\n      means any form of electronic, verbal, or written communication sent\n      to the Licensor or its representatives, including but not limited to\n      communication on electronic mailing lists, source code control systems,\n      and issue tracking systems that are managed by, or on behalf of, the\n      Licensor for the purpose of discussing and improving the Work, but\n      excluding communication that is conspicuously marked or otherwise\n      designated in writing by the copyright owner as \"Not a Contribution.\"\n\n      \"Contributor\" shall mean Licensor and any individual or Legal Entity\n      on behalf of whom a Contribution has been received by Licensor and\n      subsequently incorporated within the Work.\n\n   2. Grant of Copyright License. 
Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      copyright license to reproduce, prepare Derivative Works of,\n      publicly display, publicly perform, sublicense, and distribute the\n      Work and such Derivative Works in Source or Object form.\n\n   3. Grant of Patent License. Subject to the terms and conditions of\n      this License, each Contributor hereby grants to You a perpetual,\n      worldwide, non-exclusive, no-charge, royalty-free, irrevocable\n      (except as stated in this section) patent license to make, have made,\n      use, offer to sell, sell, import, and otherwise transfer the Work,\n      where such license applies only to those patent claims licensable\n      by such Contributor that are necessarily infringed by their\n      Contribution(s) alone or by combination of their Contribution(s)\n      with the Work to which such Contribution(s) was submitted. If You\n      institute patent litigation against any entity (including a\n      cross-claim or counterclaim in a lawsuit) alleging that the Work\n      or a Contribution incorporated within the Work constitutes direct\n      or contributory patent infringement, then any patent licenses\n      granted to You under this License for that Work shall terminate\n      as of the date such litigation is filed.\n\n   4. Redistribution. 
You may reproduce and distribute copies of the\n      Work or Derivative Works thereof in any medium, with or without\n      modifications, and in Source or Object form, provided that You\n      meet the following conditions:\n\n      (a) You must give any other recipients of the Work or\n          Derivative Works a copy of this License; and\n\n      (b) You must cause any modified files to carry prominent notices\n          stating that You changed the files; and\n\n      (c) You must retain, in the Source form of any Derivative Works\n          that You distribute, all copyright, patent, trademark, and\n          attribution notices from the Source form of the Work,\n          excluding those notices that do not pertain to any part of\n          the Derivative Works; and\n\n      (d) If the Work includes a \"NOTICE\" text file as part of its\n          distribution, then any Derivative Works that You distribute must\n          include a readable copy of the attribution notices contained\n          within such NOTICE file, excluding those notices that do not\n          pertain to any part of the Derivative Works, in at least one\n          of the following places: within a NOTICE text file distributed\n          as part of the Derivative Works; within the Source form or\n          documentation, if provided along with the Derivative Works; or,\n          within a display generated by the Derivative Works, if and\n          wherever such third-party notices normally appear. The contents\n          of the NOTICE file are for informational purposes only and\n          do not modify the License. 
You may add Your own attribution\n          notices within Derivative Works that You distribute, alongside\n          or as an addendum to the NOTICE text from the Work, provided\n          that such additional attribution notices cannot be construed\n          as modifying the License.\n\n      You may add Your own copyright statement to Your modifications and\n      may provide additional or different license terms and conditions\n      for use, reproduction, or distribution of Your modifications, or\n      for any such Derivative Works as a whole, provided Your use,\n      reproduction, and distribution of the Work otherwise complies with\n      the conditions stated in this License.\n\n   5. Submission of Contributions. Unless You explicitly state otherwise,\n      any Contribution intentionally submitted for inclusion in the Work\n      by You to the Licensor shall be under the terms and conditions of\n      this License, without any additional terms or conditions.\n      Notwithstanding the above, nothing herein shall supersede or modify\n      the terms of any separate license agreement you may have executed\n      with Licensor regarding such Contributions.\n\n   6. Trademarks. This License does not grant permission to use the trade\n      names, trademarks, service marks, or product names of the Licensor,\n      except as required for reasonable and customary use in describing the\n      origin of the Work and reproducing the content of the NOTICE file.\n\n   7. Disclaimer of Warranty. Unless required by applicable law or\n      agreed to in writing, Licensor provides the Work (and each\n      Contributor provides its Contributions) on an \"AS IS\" BASIS,\n      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or\n      implied, including, without limitation, any warranties or conditions\n      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A\n      PARTICULAR PURPOSE. 
You are solely responsible for determining the\n      appropriateness of using or redistributing the Work and assume any\n      risks associated with Your exercise of permissions under this License.\n\n   8. Limitation of Liability. In no event and under no legal theory,\n      whether in tort (including negligence), contract, or otherwise,\n      unless required by applicable law (such as deliberate and grossly\n      negligent acts) or agreed to in writing, shall any Contributor be\n      liable to You for damages, including any direct, indirect, special,\n      incidental, or consequential damages of any character arising as a\n      result of this License or out of the use or inability to use the\n      Work (including but not limited to damages for loss of goodwill,\n      work stoppage, computer failure or malfunction, or any and all\n      other commercial damages or losses), even if such Contributor\n      has been advised of the possibility of such damages.\n\n   9. Accepting Warranty or Additional Liability. While redistributing\n      the Work or Derivative Works thereof, You may choose to offer,\n      and charge a fee for, acceptance of support, warranty, indemnity,\n      or other liability obligations and/or rights consistent with this\n      License. However, in accepting such obligations, You may act only\n      on Your own behalf and on Your sole responsibility, not on behalf\n      of any other Contributor, and only if You agree to indemnify,\n      defend, and hold each Contributor harmless for any liability\n      incurred by, or claims asserted against, such Contributor by reason\n      of your accepting any such warranty or additional liability.\n\n   END OF TERMS AND CONDITIONS\n\n   APPENDIX: How to apply the Apache License to your work.\n\n      To apply the Apache License to your work, attach the following\n      boilerplate notice, with the fields enclosed by brackets \"[]\"\n      replaced with your own identifying information. 
(Don't include\n      the brackets!)  The text should be enclosed in the appropriate\n      comment syntax for the file format. We also recommend that a\n      file or class name and description of purpose be included on the\n      same \"printed page\" as the copyright notice for easier\n      identification within third-party archives.\n\n   Copyright [yyyy] [name of copyright owner]\n\n   Licensed under the Apache License, Version 2.0 (the \"License\");\n   you may not use this file except in compliance with the License.\n   You may obtain a copy of the License at\n\n       http://www.apache.org/licenses/LICENSE-2.0\n\n   Unless required by applicable law or agreed to in writing, software\n   distributed under the License is distributed on an \"AS IS\" BASIS,\n   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n   See the License for the specific language governing permissions and\n   limitations under the License."
  },
  {
    "path": "README.md",
    "content": "# Keras 3: Deep Learning for Humans\n\nKeras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, PyTorch, and OpenVINO (for inference-only).\nEffortlessly build and train models for computer vision, natural language processing, audio processing,\ntimeseries forecasting, recommender systems, etc.\n\n- **Accelerated model development**: Ship deep learning solutions faster thanks to the high-level UX of Keras\nand the availability of easy-to-debug runtimes like PyTorch or JAX eager execution.\n- **State-of-the-art performance**: By picking the backend that is the fastest for your model architecture (often JAX!),\nleverage speedups ranging from 20% to 350% compared to other frameworks. [Benchmark here](https://keras.io/getting_started/benchmarks/).\n- **Datacenter-scale training**: Scale confidently from your laptop to large clusters of GPUs or TPUs.\n\nJoin nearly three million developers, from burgeoning startups to global enterprises, in harnessing the power of Keras 3.\n\n\n## Installation\n\n### Install with pip\n\nKeras 3 is available on PyPI as `keras`. Note that Keras 2 remains available as the `tf-keras` package.\n\n1. Install `keras`:\n\n```\npip install keras --upgrade\n```\n\n2. Install backend package(s).\n\nTo use `keras`, you should also install the backend of choice: `tensorflow`, `jax`, or `torch`. Additionally,\nThe `openvino` backend is available with support for model inference only.\n\n### Local installation\n\n#### Minimal installation\n\nKeras 3 is compatible with Linux and macOS systems. For Windows users, we recommend using WSL2 to run Keras.\nTo install a local development version:\n\n1. Install dependencies:\n\n```\npip install -r requirements.txt\n```\n\n2. Run installation command from the root directory.\n\n```\npython pip_build.py --install\n```\n\n3. 
Run the API generation script when creating PRs that update `keras_export` public APIs:\n\n```\n./shell/api_gen.sh\n```\n\n## Backend Compatibility Table\n\nThe following table lists the minimum supported versions of each backend for the latest stable release of Keras (v3.x):\n\n| Backend    | Minimum Supported Version |\n|------------|---------------------------|\n| TensorFlow | 2.16.1                    |\n| JAX        | 0.4.20                    |\n| PyTorch    | 2.1.0                     |\n| OpenVINO   | 2025.3.0                  |\n\n#### Adding GPU support\n\nThe `requirements.txt` file will install a CPU-only version of TensorFlow, JAX, and PyTorch. For GPU support, we also\nprovide a separate `requirements-{backend}-cuda.txt` for TensorFlow, JAX, and PyTorch. These install all CUDA\ndependencies via `pip` and expect an NVIDIA driver to be pre-installed. We recommend a clean Python environment for each\nbackend to avoid CUDA version mismatches. As an example, here is how to create a JAX GPU environment with `conda`:\n\n```shell\nconda create -y -n keras-jax python=3.10\nconda activate keras-jax\npip install -r requirements-jax-cuda.txt\npython pip_build.py --install\n```\n\n## Configuring your backend\n\nYou can export the environment variable `KERAS_BACKEND` or you can edit your local config file at `~/.keras/keras.json`\nto configure your backend. Available backend options are: `\"tensorflow\"`, `\"jax\"`, `\"torch\"`, `\"openvino\"`. 
Example:\n\n```\nexport KERAS_BACKEND=\"jax\"\n```\n\nIn Colab, you can do:\n\n```python\nimport os\nos.environ[\"KERAS_BACKEND\"] = \"jax\"\n\nimport keras\n```\n\n**Note:** The backend must be configured before importing `keras`, and the backend cannot be changed after\nthe package has been imported.\n\n**Note:** The OpenVINO backend is an inference-only backend, meaning it is designed only for running model\npredictions using the `model.predict()` method.\n\n## Backwards compatibility\n\nKeras 3 is intended to work as a drop-in replacement for `tf.keras` (when using the TensorFlow backend). Just take your\nexisting `tf.keras` code, make sure that your calls to `model.save()` are using the up-to-date `.keras` format, and you're\ndone.\n\nIf your `tf.keras` model does not include custom components, you can start running it on top of JAX or PyTorch immediately.\n\nIf it does include custom components (e.g. custom layers or a custom `train_step()`), it is usually possible to convert it\nto a backend-agnostic implementation in just a few minutes.\n\nIn addition, Keras models can consume datasets in any format, regardless of the backend you're using:\nyou can train your models with your existing `tf.data.Dataset` pipelines or PyTorch `DataLoaders`.\n\n## Why use Keras 3?\n\n- Run your high-level Keras workflows on top of any framework -- benefiting at will from the advantages of each framework,\ne.g. the scalability and performance of JAX or the production ecosystem options of TensorFlow.\n- Write custom components (e.g. 
layers, models, metrics) that you can use in low-level workflows in any framework.\n    - You can take a Keras model and train it in a training loop written from scratch in native TF, JAX, or PyTorch.\n    - You can take a Keras model and use it as part of a PyTorch-native `Module` or as part of a JAX-native model function.\n- Make your ML code future-proof by avoiding framework lock-in.\n- As a PyTorch user: get access to the power and usability of Keras, at last!\n- As a JAX user: get access to a fully-featured, battle-tested, well-documented modeling and training library.\n\n\nRead more in the [Keras 3 release announcement](https://keras.io/keras_3/).\n
  },
  {
    "path": "SECURITY.md",
    "content": "# Security Policy\n\n - [**Using Keras Securely**](#using-keras-securely)\n   - [Untrusted inputs](#untrusted-inputs)\n   - [Data privacy](#data-privacy)\n   - [Untrusted environments or networks](#untrusted-environments-or-networks)\n   - [Multi-Tenant environments](#multi-tenant-environments)\n - [**Reporting a Vulnerability**](#reporting-a-vulnerability)\n\n## Using Keras Securely\n\n### Untrusted inputs\n\nSome models accept various input formats (text, images, audio, etc.). The libraries converting these inputs have varying security levels, so it's crucial to isolate the model and carefully pre-process inputs to mitigate script injection risks.\n\nFor maximum security when handling untrusted inputs, you may need to employ the following:\n\n* Sandboxing: Isolate the model process.\n* Pre-analysis: check how the model performs by default when exposed to prompt injection (e.g. using [fuzzing for prompt injection](https://github.com/FonduAI/awesome-prompt-injection?tab=readme-ov-file#tools)). This will give you leads on how hard you will have to work on the next topics.\n* Updates: Keep your model and libraries updated with the latest security patches.\n* Input Sanitation: Before feeding data to the model, sanitize inputs rigorously. This involves techniques such as:\n    * Validation: Enforce strict rules on allowed characters and data types.\n    * Filtering: Remove potentially malicious scripts or code fragments.\n    * Encoding: Convert special characters into safe representations.\n    * Verification: Run tooling that identifies potential script injections (e.g. [models that detect prompt injection attempts](https://python.langchain.com/docs/guides/safety/hugging_face_prompt_injection)). \n\n### Data privacy\nTo protect sensitive data from potential leaks or unauthorized access, it is essential to sandbox the model execution. 
This means running the model in a secure, isolated environment, which helps mitigate many attack vectors.\n\nWhen training the model with sensitive data, expose your newly-trained model to tests to identify potential sensitive data leaks.\n\n### Untrusted environments or networks\n\nIf you can't run your models in a secure and isolated environment, or if they must be exposed to an untrusted network, make sure to take the following security precautions:\n* Confirm that the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value.\n* Encrypt your data while sending it over the network.\n\n### Multi-Tenant environments\n\nIf you intend to run multiple models in parallel with shared memory, it is your responsibility to ensure the models do not interact or access each other's data. The primary areas of concern are tenant isolation, resource allocation, model sharing, and hardware attacks.\n\n#### Tenant Isolation\n\nYou must make sure that models run separately. Since models can run code, it's important to use strong isolation methods to prevent unwanted access to the data from other tenants.\n\nSeparating networks is also a big part of isolation. If you keep model network traffic separate, you not only prevent unauthorized access to data or models, but also prevent malicious users or tenants from sending graphs to execute under another tenant’s identity.\n\n#### Resource Allocation\n\nA denial of service caused by one model can impact the overall system health. Implement safeguards like rate limits, access controls, and health monitoring.\n\n#### Model Sharing\n\nIn a multitenant design that allows sharing models, make sure that tenants and users fully understand the potential security risks involved. They must be aware that they will essentially be running code provided by other users. Unfortunately, there are no reliable methods available to detect malicious models, graphs, or checkpoints. 
To mitigate this risk, the recommended approach is to sandbox the model execution, effectively isolating it from the rest of the system.\n\n#### Hardware Attacks\n\nBesides the virtual environment, the hardware (GPUs or TPUs) can also be attacked. [Research](https://scholar.google.com/scholar?q=gpu+side+channel) has shown that side channel attacks on GPUs are possible, which can make data leak from other models or processes running on the same system at the same time.\n\n## Reporting a Vulnerability\n\nBeware that none of the topics under [Using Keras Securely](#using-keras-securely) are considered vulnerabilities of Keras.\n\nIf you have discovered a security vulnerability in this project, please report it\nprivately. **Do not disclose it as a public issue.** This gives us time to work with you\nto fix the issue before public exposure, reducing the chance that the exploit will be\nused before a patch is released.\n\nYou may submit the report in the following ways:\n\n- send an email to francois.chollet@gmail.com and/or\n- send a [private vulnerability report](https://github.com/keras-team/keras/security/advisories/new)\n\nPlease provide the following information in your report:\n\n- A description of the vulnerability and its impact\n- How to reproduce the issue\n\nThis project is maintained by volunteers on a reasonable-effort basis. As such,\nplease give us 90 days to work on a fix before public exposure.\n"
  },
  {
    "path": "api_gen.py",
    "content": "\"\"\"Script to generate keras public API in `keras/api` directory.\n\nUsage:\n\nRun via `./shell/api_gen.sh`.\nIt generates API and formats user and generated APIs.\n\"\"\"\n\nimport os\nimport re\nimport shutil\n\nimport namex\n\nPACKAGE = \"keras\"\nBUILD_DIR_NAME = \"tmp_build_dir\"\n\n\ndef ignore_files(_, filenames):\n    return [f for f in filenames if f.endswith(\"_test.py\")]\n\n\ndef copy_source_to_build_directory(root_path):\n    # Copy sources (`keras/` directory and setup files) to build dir\n    build_dir = os.path.join(root_path, BUILD_DIR_NAME)\n    build_package_dir = os.path.join(build_dir, PACKAGE)\n    build_src_dir = os.path.join(build_package_dir, \"src\")\n    root_src_dir = os.path.join(root_path, PACKAGE, \"src\")\n    if os.path.exists(build_dir):\n        shutil.rmtree(build_dir)\n    os.makedirs(build_package_dir)\n    shutil.copytree(root_src_dir, build_src_dir)\n    return build_dir\n\n\ndef create_legacy_directory(package_dir):\n    src_dir = os.path.join(package_dir, \"src\")\n    # Make keras/_tf_keras/ by copying keras/\n    tf_keras_dirpath_parent = os.path.join(package_dir, \"_tf_keras\")\n    tf_keras_dirpath = os.path.join(tf_keras_dirpath_parent, \"keras\")\n    os.makedirs(tf_keras_dirpath, exist_ok=True)\n    with open(os.path.join(tf_keras_dirpath_parent, \"__init__.py\"), \"w\") as f:\n        f.write(\"from keras._tf_keras import keras\\n\")\n    with open(os.path.join(package_dir, \"__init__.py\")) as f:\n        init_file = f.read()\n        init_file = init_file.replace(\n            \"from keras import _legacy as _legacy\",\n            \"from keras import _tf_keras as _tf_keras\",\n        )\n    with open(os.path.join(package_dir, \"__init__.py\"), \"w\") as f:\n        f.write(init_file)\n    # Remove the import of `_tf_keras` in `keras/_tf_keras/keras/__init__.py`\n    init_file = init_file.replace(\"from keras import _tf_keras\\n\", \"\\n\")\n    with open(os.path.join(tf_keras_dirpath, 
\"__init__.py\"), \"w\") as f:\n        f.write(init_file)\n    for dirname in os.listdir(package_dir):\n        dirpath = os.path.join(package_dir, dirname)\n        if os.path.isdir(dirpath) and dirname not in (\n            \"_legacy\",\n            \"_tf_keras\",\n            \"src\",\n        ):\n            destpath = os.path.join(tf_keras_dirpath, dirname)\n            if os.path.exists(destpath):\n                shutil.rmtree(destpath)\n            shutil.copytree(\n                dirpath,\n                destpath,\n                ignore=ignore_files,\n            )\n\n    # Copy keras/_legacy/ file contents to keras/_tf_keras/keras\n    legacy_submodules = [\n        path[:-3]\n        for path in os.listdir(os.path.join(src_dir, \"legacy\"))\n        if path.endswith(\".py\")\n    ]\n    legacy_submodules += [\n        path\n        for path in os.listdir(os.path.join(src_dir, \"legacy\"))\n        if os.path.isdir(os.path.join(src_dir, \"legacy\", path))\n    ]\n    for root, _, fnames in os.walk(os.path.join(package_dir, \"_legacy\")):\n        for fname in fnames:\n            if fname.endswith(\".py\"):\n                legacy_fpath = os.path.join(root, fname)\n                tf_keras_root = root.replace(\n                    os.path.join(os.path.sep, \"_legacy\"),\n                    os.path.join(os.path.sep, \"_tf_keras\", \"keras\"),\n                )\n                core_api_fpath = os.path.join(\n                    root.replace(os.path.join(os.path.sep, \"_legacy\"), \"\"),\n                    fname,\n                )\n                if not os.path.exists(tf_keras_root):\n                    os.makedirs(tf_keras_root)\n                tf_keras_fpath = os.path.join(tf_keras_root, fname)\n                with open(legacy_fpath) as f:\n                    legacy_contents = f.read()\n                    legacy_contents = legacy_contents.replace(\n                        \"keras._legacy\", \"keras._tf_keras.keras\"\n                    )\n 
               if os.path.exists(core_api_fpath):\n                    with open(core_api_fpath) as f:\n                        core_api_contents = f.read()\n                    core_api_contents = core_api_contents.replace(\n                        \"from keras import _tf_keras as _tf_keras\\n\", \"\"\n                    )\n                    for legacy_submodule in legacy_submodules:\n                        core_api_contents = core_api_contents.replace(\n                            f\"from keras import {legacy_submodule} as {legacy_submodule}\\n\",  # noqa: E501\n                            \"\",\n                        )\n                        core_api_contents = core_api_contents.replace(\n                            f\"keras.{legacy_submodule}\",\n                            f\"keras._tf_keras.keras.{legacy_submodule}\",\n                        )\n                    # Remove duplicate generated comments string.\n                    legacy_contents = re.sub(r\"\\n\", r\"\\\\n\", legacy_contents)\n                    legacy_contents = re.sub('\"\"\".*\"\"\"', \"\", legacy_contents)\n                    legacy_contents = re.sub(r\"\\\\n\", r\"\\n\", legacy_contents)\n                    # If the same module is in legacy and core_api, use legacy\n                    legacy_imports = re.findall(\n                        r\"import (\\w+)\", legacy_contents\n                    )\n                    for import_name in legacy_imports:\n                        core_api_contents = re.sub(\n                            f\"\\n.* import {import_name} as {import_name}\\n\",\n                            r\"\\n\",\n                            core_api_contents,\n                        )\n                    legacy_contents = f\"{core_api_contents}\\n{legacy_contents}\"\n                with open(tf_keras_fpath, \"w\") as f:\n                    f.write(legacy_contents)\n\n    # Delete keras/api/_legacy/\n    shutil.rmtree(os.path.join(package_dir, 
\"_legacy\"))\n\n\ndef export_version_string(api_init_fname):\n    with open(api_init_fname) as f:\n        contents = f.read()\n    with open(api_init_fname, \"w\") as f:\n        contents += \"from keras.src.version import __version__ as __version__\\n\"\n        f.write(contents)\n\n\ndef build():\n    root_path = os.path.dirname(os.path.abspath(__file__))\n    code_api_dir = os.path.join(root_path, PACKAGE, \"api\")\n    # Create temp build dir\n    build_dir = copy_source_to_build_directory(root_path)\n    build_api_dir = os.path.join(build_dir, PACKAGE)\n    build_src_dir = os.path.join(build_api_dir, \"src\")\n    build_api_init_fname = os.path.join(build_api_dir, \"__init__.py\")\n    try:\n        os.chdir(build_dir)\n        open(build_api_init_fname, \"w\").close()\n        namex.generate_api_files(\n            \"keras\",\n            code_directory=\"src\",\n            exclude_directories=[\n                os.path.join(\"src\", \"backend\", \"jax\"),\n                os.path.join(\"src\", \"backend\", \"openvino\"),\n                os.path.join(\"src\", \"backend\", \"tensorflow\"),\n                os.path.join(\"src\", \"backend\", \"torch\"),\n            ],\n        )\n        # Add __version__ to `api/`.\n        export_version_string(build_api_init_fname)\n        # Creates `_tf_keras` with full keras API\n        create_legacy_directory(package_dir=os.path.join(build_dir, PACKAGE))\n        # Copy back the keras/api and keras/__init__.py from build directory\n        if os.path.exists(build_src_dir):\n            shutil.rmtree(build_src_dir)\n        if os.path.exists(code_api_dir):\n            shutil.rmtree(code_api_dir)\n        shutil.copytree(\n            build_api_dir, code_api_dir, ignore=shutil.ignore_patterns(\"src/\")\n        )\n    finally:\n        # Clean up: remove the build directory (no longer needed)\n        shutil.rmtree(build_dir)\n\n\nif __name__ == \"__main__\":\n    build()\n"
  },
  {
    "path": "benchmarks/__init__.py",
    "content": ""
  },
  {
    "path": "benchmarks/layer_benchmark/README.md",
    "content": "# Benchmark the layer performance\n\nThis directory contains benchmarks to compare the performance of\n`keras.layers.XXX` and `tf.keras.layers.XXX`. We compare the performance of\nboth the forward pass and train step (forward & backward pass). \n\nTo run the benchmark, use the command below and change the flags according to\nyour target:\n\n```shell\npython3 -m benchmarks.layer_benchmark.conv_benchmark \\\n    --benchmark_name=benchmark_conv2D \\\n    --num_samples=2048 \\\n    --batch_size=256 \\\n    --jit_compile=True\n```"
  },
  {
    "path": "benchmarks/layer_benchmark/__init__.py",
    "content": ""
  },
  {
    "path": "benchmarks/layer_benchmark/activation_benchmark.py",
    "content": "\"\"\"Benchmark activation layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your custom value:\n\n```\npython3 -m benchmarks.layer_benchmark.activation_benchmark \\\n    --benchmark_name=benchmark_elu \\\n    --num_samples=2048 \\\n    --batch_size=256 \\\n    --jit_compile=True\n```\n\"\"\"\n\nfrom absl import app\nfrom absl import flags\n\nfrom benchmarks.layer_benchmark.base_benchmark import LayerBenchmark\n\nFLAGS = flags.FLAGS\n\n\ndef benchmark_elu(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"ELU\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_prelu(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"PReLU\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_relu(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"ReLU\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef 
benchmark_leaky_relu(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"LeakyReLU\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_softmax(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Softmax\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\nBENCHMARK_NAMES = {\n    \"benchmark_elu\": benchmark_elu,\n    \"benchmark_relu\": benchmark_relu,\n    \"benchmark_leaky_relu\": benchmark_leaky_relu,\n    \"benchmark_prelu\": benchmark_prelu,\n    \"benchmark_softmax\": benchmark_softmax,\n}\n\n\ndef main(_):\n    benchmark_name = FLAGS.benchmark_name\n    num_samples = FLAGS.num_samples\n    batch_size = FLAGS.batch_size\n    jit_compile = FLAGS.jit_compile\n\n    if benchmark_name is None:\n        for name, benchmark_fn in BENCHMARK_NAMES.items():\n            benchmark_fn(num_samples, batch_size, jit_compile)\n        return\n\n    if benchmark_name not in BENCHMARK_NAMES:\n        raise ValueError(\n            f\"Invalid benchmark name: {benchmark_name}, `benchmark_name` must \"\n            f\"be one of {BENCHMARK_NAMES.keys()}\"\n        )\n    benchmark_fn = BENCHMARK_NAMES[benchmark_name]\n    benchmark_fn(num_samples, batch_size, jit_compile)\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/layer_benchmark/attention_benchmark.py",
    "content": "\"\"\"Benchmark attention layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your custom value:\n\n```\npython3 -m benchmarks.layer_benchmark.attention_benchmark \\\n    --benchmark_name=benchmark_attention \\\n    --num_samples=2048 \\\n    --batch_size=256 \\\n    --jit_compile=True\n```\n\"\"\"\n\nfrom absl import app\nfrom absl import flags\n\nfrom benchmarks.layer_benchmark.base_benchmark import LayerBenchmark\n\nFLAGS = flags.FLAGS\n\n\ndef benchmark_attention(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Attention\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[[256, 64], [256, 64]],\n        flat_call_inputs=False,\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_multi_head_attention(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"MultiHeadAttention\"\n    init_args = {\n        \"num_heads\": 4,\n        \"key_dim\": 16,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[[256, 64], [256, 64], [256, 64]],\n        flat_call_inputs=True,\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_additive_attention(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"AdditiveAttention\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[[256, 64], [256, 64], [256, 64]],\n       
 flat_call_inputs=False,\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\nBENCHMARK_NAMES = {\n    \"benchmark_attention\": benchmark_attention,\n    \"benchmark_multi_head_attention\": benchmark_multi_head_attention,\n    \"benchmark_additive_attention\": benchmark_additive_attention,\n}\n\n\ndef main(_):\n    benchmark_name = FLAGS.benchmark_name\n    num_samples = FLAGS.num_samples\n    batch_size = FLAGS.batch_size\n    jit_compile = FLAGS.jit_compile\n\n    if benchmark_name is None:\n        for name, benchmark_fn in BENCHMARK_NAMES.items():\n            benchmark_fn(num_samples, batch_size, jit_compile)\n        return\n\n    if benchmark_name not in BENCHMARK_NAMES:\n        raise ValueError(\n            f\"Invalid benchmark name: {benchmark_name}, `benchmark_name` must \"\n            f\"be one of {BENCHMARK_NAMES.keys()}\"\n        )\n    benchmark_fn = BENCHMARK_NAMES[benchmark_name]\n    benchmark_fn(num_samples, batch_size, jit_compile)\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/layer_benchmark/base_benchmark.py",
    "content": "import time\n\nimport numpy as np\nimport tensorflow as tf\nfrom absl import flags\n\nimport keras\n\nFLAGS = flags.FLAGS\n\nflags.DEFINE_string(\n    \"benchmark_name\",\n    None,\n    \"The name of benchmark to run. If None, all benchmarks in the file will be \"\n    \"run.\",\n)\n\nflags.DEFINE_integer(\n    \"num_samples\",\n    1000,\n    \"Number of input data samples.\",\n)\n\nflags.DEFINE_integer(\n    \"batch_size\",\n    20,\n    \"Batch size of data.\",\n)\n\nflags.DEFINE_bool(\n    \"jit_compile\",\n    True,\n    \"If True, the benchmark will run with XLA compilation.\",\n)\n\n\nclass BenchmarkMetricsCallback:\n    def __init__(self, start_batch=1, stop_batch=None):\n        self.start_batch = start_batch\n        self.stop_batch = stop_batch\n\n        self.state = {}\n\n    def on_train_batch_begin(self, batch, logs=None):\n        if batch == self.start_batch:\n            self.state[\"benchmark_begin\"] = time.time()\n\n    def on_train_batch_end(self, batch, logs=None):\n        if batch == self.stop_batch:\n            self.state[\"benchmark_end\"] = time.time()\n            throughput = (self.stop_batch - self.start_batch + 1) / (\n                self.state[\"benchmark_end\"] - self.state[\"benchmark_begin\"]\n            )\n            self.state[\"throughput\"] = throughput\n\n    def on_predict_batch_begin(self, batch, logs=None):\n        if batch == self.start_batch:\n            self.state[\"benchmark_begin\"] = time.time()\n\n    def on_predict_batch_end(self, batch, logs=None):\n        if batch == self.stop_batch:\n            self.state[\"benchmark_end\"] = time.time()\n            throughput = (self.stop_batch - self.start_batch + 1) / (\n                self.state[\"benchmark_end\"] - self.state[\"benchmark_begin\"]\n            )\n            self.state[\"throughput\"] = throughput\n\n\nclass KerasCoreBenchmarkMetricsCallback(keras.callbacks.Callback):\n    def __init__(self, start_batch=1, stop_batch=None):\n      
  self._callback = BenchmarkMetricsCallback(start_batch, stop_batch)\n\n    def on_train_batch_begin(self, batch, logs=None):\n        self._callback.on_train_batch_begin(batch, logs)\n\n    def on_train_batch_end(self, batch, logs=None):\n        self._callback.on_train_batch_end(batch, logs)\n\n    def on_predict_batch_begin(self, batch, logs=None):\n        self._callback.on_predict_batch_begin(batch, logs)\n\n    def on_predict_batch_end(self, batch, logs=None):\n        self._callback.on_predict_batch_end(batch, logs)\n\n\nclass TFKerasBenchmarkMetricsCallback(tf.keras.callbacks.Callback):\n    def __init__(self, start_batch=1, stop_batch=None):\n        self._callback = BenchmarkMetricsCallback(start_batch, stop_batch)\n\n    def on_train_batch_begin(self, batch, logs=None):\n        self._callback.on_train_batch_begin(batch, logs)\n\n    def on_train_batch_end(self, batch, logs=None):\n        self._callback.on_train_batch_end(batch, logs)\n\n    def on_predict_batch_begin(self, batch, logs=None):\n        self._callback.on_predict_batch_begin(batch, logs)\n\n    def on_predict_batch_end(self, batch, logs=None):\n        self._callback.on_predict_batch_end(batch, logs)\n\n\nclass LayerBenchmark:\n    def __init__(\n        self,\n        layer_name,\n        init_args,\n        input_shape,\n        flat_call_inputs=True,\n        jit_compile=True,\n        keras_layer=None,\n        tf_keras_layer=None,\n    ):\n        self.layer_name = layer_name\n        _keras_layer_class = getattr(keras.layers, layer_name)\n        _tf_keras_layer_class = getattr(tf.keras.layers, layer_name)\n\n        if keras_layer is None:\n            # Sometimes you want to initialize the keras layer and tf_keras\n            # layer in a different way. 
For example, the `Bidirectional` layer\n            # takes in `keras.layers.Layer` and\n            # `tf.keras.layers.Layer` separately.\n            self._keras_layer = _keras_layer_class(**init_args)\n        else:\n            self._keras_layer = keras_layer\n\n        if tf_keras_layer is None:\n            self._tf_keras_layer = _tf_keras_layer_class(**init_args)\n        else:\n            self._tf_keras_layer = tf_keras_layer\n\n        self.input_shape = input_shape\n        self._keras_model = self._build_keras_model(\n            input_shape, flat_call_inputs\n        )\n        self._tf_keras_model = self._build_tf_keras_model(\n            input_shape, flat_call_inputs\n        )\n\n        self._keras_model.compile(\n            loss=\"mse\", optimizer=\"sgd\", jit_compile=jit_compile\n        )\n        self._tf_keras_model.compile(\n            loss=\"mse\", optimizer=\"sgd\", jit_compile=jit_compile\n        )\n\n        self.flat_call_inputs = flat_call_inputs\n        self.jit_compile = jit_compile\n        self.input_shape = input_shape\n\n    def _build_keras_model(self, input_shape, flat_call_inputs=True):\n        inputs = []\n        if not isinstance(input_shape[0], (tuple, list)):\n            input_shape = [input_shape]\n\n        for shape in input_shape:\n            inputs.append(keras.Input(shape=shape))\n\n        if flat_call_inputs:\n            outputs = self._keras_layer(*inputs)\n        else:\n            outputs = self._keras_layer(inputs)\n        return keras.Model(inputs=inputs, outputs=outputs)\n\n    def _build_tf_keras_model(self, input_shape, flat_call_inputs=True):\n        inputs = []\n        if not isinstance(input_shape[0], (tuple, list)):\n            input_shape = [input_shape]\n\n        for shape in input_shape:\n            inputs.append(tf.keras.Input(shape=shape))\n\n        if flat_call_inputs:\n            outputs = self._tf_keras_layer(*inputs)\n        else:\n            outputs = self._tf_keras_layer(inputs)\n        return tf.keras.Model(inputs=inputs, outputs=outputs)\n\n    def benchmark_predict(self, num_samples, batch_size, data=None):\n        if data is None:\n            # Generate default data if not provided.\n            if isinstance(self.input_shape[0], (tuple, list)):\n                # The layer has multiple inputs.\n                data = []\n                for data_shape in self.input_shape:\n                    data_shape = [num_samples] + list(data_shape)\n                    data.append(np.random.normal(size=data_shape))\n            else:\n                data_shape = [num_samples] + list(self.input_shape)\n                data = np.random.normal(size=data_shape)\n\n        num_iterations = num_samples // batch_size - 1\n        callback = KerasCoreBenchmarkMetricsCallback(stop_batch=num_iterations)\n        tf_keras_callback = TFKerasBenchmarkMetricsCallback(\n            stop_batch=num_iterations\n        )\n\n        self._keras_model.predict(\n            data,\n            batch_size=batch_size,\n            callbacks=[callback],\n        )\n\n        self._tf_keras_model.predict(\n            data,\n            batch_size=batch_size,\n            callbacks=[tf_keras_callback],\n        )\n\n        keras_throughput = callback._callback.state[\"throughput\"] * batch_size\n        tf_keras_throughput = (\n            tf_keras_callback._callback.state[\"throughput\"] * batch_size\n        )\n        print(\n            f\"Keras 3 throughput of forward pass of {self.layer_name}: \"\n            f\"{keras_throughput:.2f} samples/sec.\"\n        )\n        print(\n            f\"TF Keras throughput of forward pass of {self.layer_name}: \"\n            f\"{tf_keras_throughput:.2f} samples/sec.\"\n        )\n\n    def benchmark_train(self, num_samples, batch_size, data=None, label=None):\n        if data is None:\n            # Generate default data if not provided.\n            if isinstance(self.input_shape[0], (tuple, list)):\n                # The layer has multiple inputs.\n                data = []\n                for data_shape in self.input_shape:\n                    data_shape = [num_samples] + list(data_shape)\n                    data.append(np.random.normal(size=data_shape))\n            else:\n                data_shape = [num_samples] + list(self.input_shape)\n                data = [np.random.normal(size=data_shape)]\n\n        if label is None:\n            # Generate default label if not provided.\n            if self.flat_call_inputs:\n                # Scale by a small factor to avoid zero gradients.\n                label = (\n                    keras.backend.convert_to_numpy(self._keras_layer(*data))\n                    * 1.001\n                )\n            else:\n                label = (\n                    keras.backend.convert_to_numpy(self._keras_layer(data))\n                    * 1.001\n                )\n\n        num_iterations = num_samples // batch_size - 1\n        callback = KerasCoreBenchmarkMetricsCallback(stop_batch=num_iterations)\n        tf_keras_callback = TFKerasBenchmarkMetricsCallback(\n            stop_batch=num_iterations\n        )\n\n        self._keras_model.fit(\n            data,\n            label,\n            batch_size=batch_size,\n            callbacks=[callback],\n        )\n        self._tf_keras_model.fit(\n            data,\n            label,\n            batch_size=batch_size,\n            callbacks=[tf_keras_callback],\n        )\n\n        keras_throughput = callback._callback.state[\"throughput\"] * batch_size\n        tf_keras_throughput = (\n            tf_keras_callback._callback.state[\"throughput\"] * batch_size\n        )\n        print(\n            f\"Keras 3 throughput of forward & backward pass of \"\n            f\"{self.layer_name}: {keras_throughput:.2f} samples/sec.\"\n        )\n        print(\n            f\"TF Keras throughput of forward & backward pass of \"\n            f\"{self.layer_name}: {tf_keras_throughput:.2f} samples/sec.\"\n        )\n"
  },
  {
    "path": "benchmarks/layer_benchmark/conv_benchmark.py",
    "content": "\"\"\"Benchmark conv layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your custom value:\n\n```\npython3 -m benchmarks.layer_benchmark.conv_benchmark \\\n    --benchmark_name=benchmark_conv2D \\\n    --num_samples=2046 \\\n    --batch_size=256 \\\n    --jit_compile=True\n```\n\"\"\"\n\nfrom absl import app\nfrom absl import flags\n\nfrom benchmarks.layer_benchmark.base_benchmark import LayerBenchmark\n\nFLAGS = flags.FLAGS\n\n\ndef benchmark_conv1D(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Conv1D\"\n    init_args = {\n        \"filters\": 64,\n        \"kernel_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[1024, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_conv2D(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Conv2D\"\n    init_args = {\n        \"filters\": 16,\n        \"kernel_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[128, 128, 4],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_conv3D(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Conv3D\"\n    init_args = {\n        \"filters\": 16,\n        \"kernel_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[32, 32, 32, 4],\n        jit_compile=jit_compile,\n    )\n\n    
benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_depthwise_conv1D(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"DepthwiseConv1D\"\n    init_args = {\n        \"kernel_size\": 16,\n        \"depth_multiplier\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 64],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_depthwise_conv2D(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"DepthwiseConv2D\"\n    init_args = {\n        \"kernel_size\": 16,\n        \"depth_multiplier\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[128, 128, 4],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_separable_conv1D(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"SeparableConv1D\"\n    init_args = {\n        \"kernel_size\": 16,\n        \"depth_multiplier\": 2,\n        \"filters\": 3,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 64],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        
batch_size=batch_size,\n    )\n\n\ndef benchmark_separable_conv2D(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"SeparableConv2D\"\n    init_args = {\n        \"kernel_size\": 16,\n        \"depth_multiplier\": 2,\n        \"filters\": 3,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[128, 128, 4],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_conv1D_transpose(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Conv1DTranspose\"\n    init_args = {\n        \"filters\": 32,\n        \"kernel_size\": 4,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_conv2D_transpose(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Conv2DTranspose\"\n    init_args = {\n        \"filters\": 16,\n        \"kernel_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[128, 128, 4],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_conv3D_transpose(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Conv3DTranspose\"\n    init_args = {\n        \"filters\": 16,\n 
       \"kernel_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[32, 32, 32, 4],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\nBENCHMARK_NAMES = {\n    \"benchmark_conv1D\": benchmark_conv1D,\n    \"benchmark_conv2D\": benchmark_conv2D,\n    \"benchmark_conv3D\": benchmark_conv3D,\n    \"benchmark_depthwise_conv1D\": benchmark_depthwise_conv1D,\n    \"benchmark_depthwise_conv2D\": benchmark_depthwise_conv2D,\n    \"benchmark_separable_conv1D\": benchmark_separable_conv1D,\n    \"benchmark_separable_conv2D\": benchmark_separable_conv2D,\n    \"benchmark_conv1D_transpose\": benchmark_conv1D_transpose,\n    \"benchmark_conv2D_transpose\": benchmark_conv2D_transpose,\n    \"benchmark_conv3D_transpose\": benchmark_conv3D_transpose,\n}\n\n\ndef main(_):\n    benchmark_name = FLAGS.benchmark_name\n    num_samples = FLAGS.num_samples\n    batch_size = FLAGS.batch_size\n    jit_compile = FLAGS.jit_compile\n\n    if benchmark_name is None:\n        for name, benchmark_fn in BENCHMARK_NAMES:\n            benchmark_fn(num_samples, batch_size, jit_compile)\n        return\n\n    if benchmark_name not in BENCHMARK_NAMES:\n        raise ValueError(\n            f\"Invalid benchmark name: {benchmark_name}, `benchmark_name` must \"\n            f\"be one of {BENCHMARK_NAMES.keys()}\"\n        )\n    benchmark_fn = BENCHMARK_NAMES[benchmark_name]\n    benchmark_fn(num_samples, batch_size, jit_compile)\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/layer_benchmark/core_benchmark.py",
    "content": "\"\"\"Benchmark core layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your custom value:\n\n```\npython3 -m benchmarks.layer_benchmark.core_benchmark \\\n    --benchmark_name=benchmark_dense \\\n    --num_samples=2048 \\\n    --batch_size=256 \\\n    --jit_compile=True\n```\n\"\"\"\n\nimport numpy as np\nfrom absl import app\nfrom absl import flags\n\nfrom benchmarks.layer_benchmark.base_benchmark import LayerBenchmark\n\nFLAGS = flags.FLAGS\n\n\ndef benchmark_dense(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Dense\"\n    init_args = {\"units\": 256}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_einsum_dense(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"EinsumDense\"\n    init_args = {\n        \"equation\": \"abc,cd->abd\",\n        \"output_shape\": (None, 256),\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_embedding(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Embedding\"\n    init_args = {\n        \"input_dim\": 128,\n        \"output_dim\": 256,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[\n            256,\n        ],\n        jit_compile=jit_compile,\n    
)\n\n    data = [np.random.randint(30, size=(num_samples, 256))]\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n        data=data,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n        data=data,\n    )\n\n\nBENCHMARK_NAMES = {\n    \"benchmark_dense\": benchmark_dense,\n    \"benchmark_einsum_dense\": benchmark_einsum_dense,\n    \"benchmark_embedding\": benchmark_embedding,\n}\n\n\ndef main(_):\n    benchmark_name = FLAGS.benchmark_name\n    num_samples = FLAGS.num_samples\n    batch_size = FLAGS.batch_size\n    jit_compile = FLAGS.jit_compile\n\n    if benchmark_name is None:\n        for name, benchmark_fn in BENCHMARK_NAMES.items():\n            benchmark_fn(num_samples, batch_size, jit_compile)\n        return\n\n    if benchmark_name not in BENCHMARK_NAMES:\n        raise ValueError(\n            f\"Invalid benchmark name: {benchmark_name}, `benchmark_name` must \"\n            f\"be one of {BENCHMARK_NAMES.keys()}\"\n        )\n    benchmark_fn = BENCHMARK_NAMES[benchmark_name]\n    benchmark_fn(num_samples, batch_size, jit_compile)\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/layer_benchmark/merge_benchmark.py",
    "content": "\"\"\"Benchmark merge layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your custom value:\n\n```\npython3 -m benchmarks.layer_benchmark.merge_benchmark \\\n    --benchmark_name=benchmark_add \\\n    --num_samples=2048 \\\n    --batch_size=256 \\\n    --jit_compile=True\n```\n\"\"\"\n\nfrom absl import app\nfrom absl import flags\n\nfrom benchmarks.layer_benchmark.base_benchmark import LayerBenchmark\n\nFLAGS = flags.FLAGS\n\n\ndef benchmark_add(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Add\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[[256, 256], [256, 256]],\n        flat_call_inputs=False,\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_average(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Average\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[[256, 256], [256, 256]],\n        flat_call_inputs=False,\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_concatenate(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Concatenate\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[[256, 256], [256, 256]],\n        flat_call_inputs=False,\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        
num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_dot(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Dot\"\n    init_args = {\"axes\": [2, 1]}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[[256, 32], [32, 64]],\n        flat_call_inputs=False,\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_maximum(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Maximum\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[[256, 256], [256, 256]],\n        flat_call_inputs=False,\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_minimum(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Minimum\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[[256, 256], [256, 256]],\n        flat_call_inputs=False,\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_multiply(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Multiply\"\n    init_args = {}\n    benchmark = 
LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[[256, 64], [256, 64]],\n        flat_call_inputs=False,\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_subtract(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Subtract\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[[256, 256], [256, 256]],\n        flat_call_inputs=False,\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\nBENCHMARK_NAMES = {\n    \"benchmark_add\": benchmark_add,\n    \"benchmark_average\": benchmark_average,\n    \"benchmark_concatenate\": benchmark_concatenate,\n    \"benchmark_dot\": benchmark_dot,\n    \"benchmark_maximum\": benchmark_maximum,\n    \"benchmark_minimum\": benchmark_minimum,\n    \"benchmark_multiply\": benchmark_multiply,\n    \"benchmark_subtract\": benchmark_subtract,\n}\n\n\ndef main(_):\n    benchmark_name = FLAGS.benchmark_name\n    num_samples = FLAGS.num_samples\n    batch_size = FLAGS.batch_size\n    jit_compile = FLAGS.jit_compile\n\n    if benchmark_name is None:\n        for name, benchmark_fn in BENCHMARK_NAMES.items():\n            benchmark_fn(num_samples, batch_size, jit_compile)\n        return\n\n    if benchmark_name not in BENCHMARK_NAMES:\n        raise ValueError(\n            f\"Invalid benchmark name: {benchmark_name}, `benchmark_name` must \"\n            f\"be one of {BENCHMARK_NAMES.keys()}\"\n        )\n    benchmark_fn = BENCHMARK_NAMES[benchmark_name]\n    benchmark_fn(num_samples, 
batch_size, jit_compile)\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/layer_benchmark/normalization_benchmark.py",
    "content": "\"\"\"Benchmark normalization layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your custom value:\n\n```\npython3 -m benchmarks.layer_benchmark.normalization_benchmark \\\n    --benchmark_name=benchmark_batch_normalization \\\n    --num_samples=2048 \\\n    --batch_size=256 \\\n    --jit_compile=True\n```\n\"\"\"\n\nfrom absl import app\nfrom absl import flags\n\nfrom benchmarks.layer_benchmark.base_benchmark import LayerBenchmark\n\nFLAGS = flags.FLAGS\n\n\ndef benchmark_batch_normalization(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"BatchNormalization\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 4],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_group_normalization(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"GroupNormalization\"\n    init_args = {\n        \"groups\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 4],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_layer_normalization(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"LayerNormalization\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 128, 4],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n 
       batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_unit_normalization(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"UnitNormalization\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 128, 4],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\nBENCHMARK_NAMES = {\n    \"benchmark_batch_normalization\": benchmark_batch_normalization,\n    \"benchmark_group_normalization\": benchmark_group_normalization,\n    \"benchmark_layer_normalization\": benchmark_layer_normalization,\n    \"benchmark_unit_normalization\": benchmark_unit_normalization,\n}\n\n\ndef main(_):\n    benchmark_name = FLAGS.benchmark_name\n    num_samples = FLAGS.num_samples\n    batch_size = FLAGS.batch_size\n    jit_compile = FLAGS.jit_compile\n\n    if benchmark_name is None:\n        for name, benchmark_fn in BENCHMARK_NAMES.items():\n            benchmark_fn(num_samples, batch_size, jit_compile)\n        return\n\n    if benchmark_name not in BENCHMARK_NAMES:\n        raise ValueError(\n            f\"Invalid benchmark name: {benchmark_name}, `benchmark_name` must \"\n            f\"be one of {BENCHMARK_NAMES.keys()}\"\n        )\n    benchmark_fn = BENCHMARK_NAMES[benchmark_name]\n    benchmark_fn(num_samples, batch_size, jit_compile)\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/layer_benchmark/pooling_benchmark.py",
    "content": "\"\"\"Benchmark pooling layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your custom value:\n\n```\npython3 -m benchmarks.layer_benchmark.pooling_benchmark \\\n    --benchmark_name=benchmark_max_pooling1d \\\n    --num_samples=2048 \\\n    --batch_size=256 \\\n    --jit_compile=True\n```\n\"\"\"\n\nfrom absl import app\nfrom absl import flags\n\nfrom benchmarks.layer_benchmark.base_benchmark import LayerBenchmark\n\nFLAGS = flags.FLAGS\n\n\ndef benchmark_average_pooling1d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"AveragePooling1D\"\n    init_args = {\n        \"pool_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[1024, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_average_pooling2d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"AveragePooling2D\"\n    init_args = {\n        \"pool_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_average_pooling3d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"AveragePooling3D\"\n    init_args = {\n        \"pool_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[64, 64, 32, 3],\n        jit_compile=jit_compile,\n    )\n\n    
benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_max_pooling1d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"MaxPooling1D\"\n    init_args = {\n        \"pool_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[1024, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_max_pooling2d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"MaxPooling2D\"\n    init_args = {\n        \"pool_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_max_pooling3d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"MaxPooling3D\"\n    init_args = {\n        \"pool_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[64, 64, 32, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_global_average_pooling1d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = 
\"GlobalAveragePooling1D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[1024, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_global_average_pooling2d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"GlobalAveragePooling2D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_global_average_pooling3d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"GlobalAveragePooling3D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[64, 64, 32, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_global_max_pooling1d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"GlobalMaxPooling1D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[1024, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        
batch_size=batch_size,\n    )\n\n\ndef benchmark_global_max_pooling2d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"GlobalMaxPooling2D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_global_max_pooling3d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"GlobalMaxPooling3D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[64, 64, 32, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\nBENCHMARK_NAMES = {\n    \"benchmark_average_pooling1d\": benchmark_average_pooling1d,\n    \"benchmark_average_pooling2d\": benchmark_average_pooling2d,\n    \"benchmark_average_pooling3d\": benchmark_average_pooling3d,\n    \"benchmark_max_pooling1d\": benchmark_max_pooling1d,\n    \"benchmark_max_pooling2d\": benchmark_max_pooling2d,\n    \"benchmark_max_pooling3d\": benchmark_max_pooling3d,\n    \"benchmark_global_average_pooling1d\": benchmark_global_average_pooling1d,\n    \"benchmark_global_average_pooling2d\": benchmark_global_average_pooling2d,\n    \"benchmark_global_average_pooling3d\": benchmark_global_average_pooling3d,\n    \"benchmark_global_max_pooling1d\": benchmark_global_max_pooling1d,\n    \"benchmark_global_max_pooling2d\": benchmark_global_max_pooling2d,\n    \"benchmark_global_max_pooling3d\": benchmark_global_max_pooling3d,\n}\n\n\ndef main(_):\n    benchmark_name 
= FLAGS.benchmark_name\n    num_samples = FLAGS.num_samples\n    batch_size = FLAGS.batch_size\n    jit_compile = FLAGS.jit_compile\n\n    if benchmark_name is None:\n        for name, benchmark_fn in BENCHMARK_NAMES.items():\n            benchmark_fn(num_samples, batch_size, jit_compile)\n        return\n\n    if benchmark_name not in BENCHMARK_NAMES:\n        raise ValueError(\n            f\"Invalid benchmark name: {benchmark_name}, `benchmark_name` must \"\n            f\"be one of {BENCHMARK_NAMES.keys()}\"\n        )\n    benchmark_fn = BENCHMARK_NAMES[benchmark_name]\n    benchmark_fn(num_samples, batch_size, jit_compile)\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/layer_benchmark/random_rotation_benchmark.py",
    "content": "\"\"\"Benchmark RandomRotation layer.\"\"\"\n\nfrom absl import app\nfrom absl import flags\n\nfrom benchmarks.layer_benchmark.base_benchmark import LayerBenchmark\n\nFLAGS = flags.FLAGS\n\n\ndef benchmark_random_rotation(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"RandomRotation\"\n    init_args = {\"factor\": 0.1}\n\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[224, 224, 3],\n        jit_compile=jit_compile,\n    )\n\n    # Predict is effectively a no-op for preprocessing layers,\n    # but we still call it to follow the standard benchmark structure.\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\nBENCHMARK_NAMES = {\n    \"benchmark_random_rotation\": benchmark_random_rotation,\n}\n\n\ndef main(_):\n    benchmark_name = FLAGS.benchmark_name\n    num_samples = FLAGS.num_samples\n    batch_size = FLAGS.batch_size\n    jit_compile = FLAGS.jit_compile\n\n    if benchmark_name is None:\n        for benchmark_fn in BENCHMARK_NAMES.values():\n            benchmark_fn(num_samples, batch_size, jit_compile)\n        return\n\n    if benchmark_name not in BENCHMARK_NAMES:\n        raise ValueError(\n            f\"Invalid benchmark name: {benchmark_name}, \"\n            f\"`benchmark_name` must be one of {BENCHMARK_NAMES.keys()}\"\n        )\n\n    BENCHMARK_NAMES[benchmark_name](num_samples, batch_size, jit_compile)\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/layer_benchmark/regularization_benchmark.py",
    "content": "\"\"\"Benchmark regularization layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your custom value:\n\n```\npython3 -m benchmarks.layer_benchmark.regularization_benchmark \\\n    --benchmark_name=benchmark_dropout\\\n    --num_samples=2048 \\\n    --batch_size=256 \\\n    --jit_compile=True\n```\n\"\"\"\n\nfrom absl import app\nfrom absl import flags\n\nfrom benchmarks.layer_benchmark.base_benchmark import LayerBenchmark\n\nFLAGS = flags.FLAGS\n\n\ndef benchmark_dropout(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Dropout\"\n    init_args = {\n        \"rate\": 0.5,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 4],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_gaussian_dropout(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"GaussianDropout\"\n    init_args = {\n        \"rate\": 0.5,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 4],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_gaussian_noise(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"GaussianNoise\"\n    init_args = {\n        \"stddev\": 0.5,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 4],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        
num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_spatial_dropout1D(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"SpatialDropout1D\"\n    init_args = {\n        \"rate\": 0.5,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_spatial_dropout2D(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"SpatialDropout2D\"\n    init_args = {\n        \"rate\": 0.5,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_spatial_dropout3D(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"SpatialDropout3D\"\n    init_args = {\n        \"rate\": 0.5,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[32, 32, 32, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\nBENCHMARK_NAMES = {\n    \"benchmark_dropout\": benchmark_dropout,\n    \"benchmark_gaussian_dropout\": benchmark_gaussian_dropout,\n    
\"benchmark_gaussian_noise\": benchmark_gaussian_noise,\n    \"benchmark_spatial_dropout1D\": benchmark_spatial_dropout1D,\n    \"benchmark_spatial_dropout2D\": benchmark_spatial_dropout2D,\n    \"benchmark_spatial_dropout3D\": benchmark_spatial_dropout3D,\n}\n\n\ndef main(_):\n    benchmark_name = FLAGS.benchmark_name\n    num_samples = FLAGS.num_samples\n    batch_size = FLAGS.batch_size\n    jit_compile = FLAGS.jit_compile\n\n    if benchmark_name is None:\n        for name, benchmark_fn in BENCHMARK_NAMES.items():\n            benchmark_fn(num_samples, batch_size, jit_compile)\n        return\n\n    if benchmark_name not in BENCHMARK_NAMES:\n        raise ValueError(\n            f\"Invalid benchmark name: {benchmark_name}, `benchmark_name` must \"\n            f\"be one of {BENCHMARK_NAMES.keys()}\"\n        )\n    benchmark_fn = BENCHMARK_NAMES[benchmark_name]\n    benchmark_fn(num_samples, batch_size, jit_compile)\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/layer_benchmark/reshaping_benchmark.py",
    "content": "\"\"\"Benchmark reshaping layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your custom value:\n\n```\npython3 -m benchmarks.layer_benchmark.reshaping_benchmark \\\n    --benchmark_name=benchmark_cropping2d \\\n    --num_samples=2048 \\\n    --batch_size=256 \\\n    --jit_compile=True\n```\n\"\"\"\n\nfrom absl import app\nfrom absl import flags\n\nfrom benchmarks.layer_benchmark.base_benchmark import LayerBenchmark\n\nFLAGS = flags.FLAGS\n\n\ndef benchmark_cropping1d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Cropping1D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[1024, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_cropping2d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Cropping2D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_cropping3d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Cropping3D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[32, 32, 32, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        
num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_flatten(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Flatten\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_permute(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Permute\"\n    init_args = {\n        \"dims\": (2, 1),\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_up_sampling1d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"UpSampling1D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_up_sampling2d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"UpSampling2D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[128, 128, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n      
  batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_up_sampling3d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"UpSampling3D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[32, 16, 16, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_zero_padding1d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"ZeroPadding1D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_zero_padding2d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"ZeroPadding2D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_zero_padding3d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"ZeroPadding3D\"\n    init_args = {}\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[32, 32, 32, 3],\n        jit_compile=jit_compile,\n    
)\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\nBENCHMARK_NAMES = {\n    \"benchmark_cropping1d\": benchmark_cropping1d,\n    \"benchmark_cropping2d\": benchmark_cropping2d,\n    \"benchmark_cropping3d\": benchmark_cropping3d,\n    \"benchmark_flatten\": benchmark_flatten,\n    \"benchmark_permute\": benchmark_permute,\n    \"benchmark_up_sampling1d\": benchmark_up_sampling1d,\n    \"benchmark_up_sampling2d\": benchmark_up_sampling2d,\n    \"benchmark_up_sampling3d\": benchmark_up_sampling3d,\n    \"benchmark_zero_padding1d\": benchmark_zero_padding1d,\n    \"benchmark_zero_padding2d\": benchmark_zero_padding2d,\n    \"benchmark_zero_padding3d\": benchmark_zero_padding3d,\n}\n\n\ndef main(_):\n    benchmark_name = FLAGS.benchmark_name\n    num_samples = FLAGS.num_samples\n    batch_size = FLAGS.batch_size\n    jit_compile = FLAGS.jit_compile\n\n    if benchmark_name is None:\n        for name, benchmark_fn in BENCHMARK_NAMES.items():\n            benchmark_fn(num_samples, batch_size, jit_compile)\n        return\n\n    if benchmark_name not in BENCHMARK_NAMES:\n        raise ValueError(\n            f\"Invalid benchmark name: {benchmark_name}, `benchmark_name` must \"\n            f\"be one of {BENCHMARK_NAMES.keys()}\"\n        )\n    benchmark_fn = BENCHMARK_NAMES[benchmark_name]\n    benchmark_fn(num_samples, batch_size, jit_compile)\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/layer_benchmark/rnn_benchmark.py",
    "content": "\"\"\"Benchmark rnn layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your custom value:\n\n```\npython3 -m benchmarks.layer_benchmark.rnn_benchmark \\\n    --benchmark_name=benchmark_lstm \\\n    --num_samples=2048 \\\n    --batch_size=256 \\\n    --jit_compile=True\n```\n\"\"\"\n\nimport tensorflow as tf\nfrom absl import app\nfrom absl import flags\n\nimport keras\nfrom benchmarks.layer_benchmark.base_benchmark import LayerBenchmark\n\nFLAGS = flags.FLAGS\n\n\ndef benchmark_conv_lstm1d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"ConvLSTM1D\"\n    init_args = {\n        \"filters\": 16,\n        \"kernel_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[32, 256, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_conv_lstm2d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"ConvLSTM2D\"\n    init_args = {\n        \"filters\": 16,\n        \"kernel_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[32, 32, 32, 3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_conv_lstm3d(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"ConvLSTM3D\"\n    init_args = {\n        \"filters\": 8,\n        \"kernel_size\": 2,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[8, 16, 16, 16, 
3],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_gru(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"GRU\"\n    init_args = {\n        \"units\": 32,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_lstm(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"LSTM\"\n    init_args = {\n        \"units\": 32,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_simple_rnn(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"SimpleRNN\"\n    init_args = {\n        \"units\": 32,\n    }\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_bidirectional(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"Bidirectional\"\n    
init_args = {}\n    keras_layer = keras.layers.Bidirectional(keras.layers.LSTM(32))\n    tf_keras_layer = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32))\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[256, 256],\n        jit_compile=jit_compile,\n        keras_layer=keras_layer,\n        tf_keras_layer=tf_keras_layer,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\ndef benchmark_time_distributed(\n    num_samples,\n    batch_size,\n    jit_compile=True,\n):\n    layer_name = \"TimeDistributed\"\n    init_args = {}\n    keras_layer = keras.layers.TimeDistributed(keras.layers.Conv2D(16, (3, 3)))\n    tf_keras_layer = tf.keras.layers.TimeDistributed(\n        tf.keras.layers.Conv2D(16, (3, 3))\n    )\n    benchmark = LayerBenchmark(\n        layer_name,\n        init_args,\n        input_shape=[10, 32, 32, 3],\n        jit_compile=jit_compile,\n        keras_layer=keras_layer,\n        tf_keras_layer=tf_keras_layer,\n    )\n\n    benchmark.benchmark_predict(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n    benchmark.benchmark_train(\n        num_samples=num_samples,\n        batch_size=batch_size,\n    )\n\n\nBENCHMARK_NAMES = {\n    \"benchmark_conv_lstm1d\": benchmark_conv_lstm1d,\n    \"benchmark_conv_lstm2d\": benchmark_conv_lstm2d,\n    \"benchmark_conv_lstm3d\": benchmark_conv_lstm3d,\n    \"benchmark_gru\": benchmark_gru,\n    \"benchmark_lstm\": benchmark_lstm,\n    \"benchmark_simple_rnn\": benchmark_simple_rnn,\n    \"benchmark_bidirectional\": benchmark_bidirectional,\n    \"benchmark_time_distributed\": benchmark_time_distributed,\n}\n\n\ndef main(_):\n    benchmark_name = FLAGS.benchmark_name\n    num_samples = FLAGS.num_samples\n    batch_size = FLAGS.batch_size\n    jit_compile = 
FLAGS.jit_compile\n\n    if benchmark_name is None:\n        for name, benchmark_fn in BENCHMARK_NAMES.items():\n            benchmark_fn(num_samples, batch_size, jit_compile)\n        return\n\n    if benchmark_name not in BENCHMARK_NAMES:\n        raise ValueError(\n            f\"Invalid benchmark name: {benchmark_name}, `benchmark_name` must \"\n            f\"be one of {BENCHMARK_NAMES.keys()}\"\n        )\n    benchmark_fn = BENCHMARK_NAMES[benchmark_name]\n    benchmark_fn(num_samples, batch_size, jit_compile)\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/model_benchmark/__init__.py",
    "content": ""
  },
  {
    "path": "benchmarks/model_benchmark/benchmark_utils.py",
    "content": "import time\n\nimport keras\n\n\nclass BenchmarkMetricsCallback(keras.callbacks.Callback):\n    def __init__(self, start_batch=1, stop_batch=None):\n        super().__init__()\n        self.start_batch = start_batch\n        self.stop_batch = stop_batch\n\n        # Store the throughput of each epoch.\n        self.state = {\"throughput\": []}\n\n    def on_train_batch_begin(self, batch, logs=None):\n        if batch == self.start_batch:\n            self.state[\"epoch_begin_time\"] = time.time()\n\n    def on_train_batch_end(self, batch, logs=None):\n        if batch == self.stop_batch:\n            epoch_end_time = time.time()\n            throughput = (self.stop_batch - self.start_batch + 1) / (\n                epoch_end_time - self.state[\"epoch_begin_time\"]\n            )\n            self.state[\"throughput\"].append(throughput)\n"
  },
  {
    "path": "benchmarks/model_benchmark/bert_benchmark.py",
    "content": "\"\"\"Benchmark BERT model on GLUE/MRPC task.\n\nTo run the script, make sure you are in the benchmarks/ directory, and run\nthe command below:\n```\npython3 -m model_benchmark.bert_benchmark \\\n    --epochs 2 \\\n    --batch_size 32\n```\n\n\"\"\"\n\nimport time\n\nimport keras_nlp\nimport numpy as np\nimport tensorflow as tf\nimport tensorflow_datasets as tfds\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\nfrom model_benchmark.benchmark_utils import BenchmarkMetricsCallback\n\nimport keras\n\nflags.DEFINE_string(\"model_size\", \"small\", \"The size of model to benchmark.\")\nflags.DEFINE_string(\n    \"mixed_precision_policy\",\n    \"mixed_float16\",\n    \"The global precision policy to use, e.g., 'mixed_float16' or 'float32'.\",\n)\nflags.DEFINE_integer(\"epochs\", 2, \"The number of epochs.\")\nflags.DEFINE_integer(\"batch_size\", 8, \"Batch Size.\")\n\n\nFLAGS = flags.FLAGS\n\n\nMODEL_SIZE_MAP = {\n    \"tiny\": \"bert_tiny_en_uncased\",\n    \"small\": \"bert_small_en_uncased\",\n    \"base\": \"bert_base_en_uncased\",\n    \"large\": \"bert_large_en_uncased\",\n}\n\n\ndef load_data():\n    \"\"\"Load data.\n\n    Load the GLUE/MRPC dataset, and convert the dictionary format to\n    (features, label), where `features` is a tuple of all input sentences.\n    \"\"\"\n    feature_names = (\"sentence1\", \"sentence2\")\n\n    def split_features(x):\n        # GLUE comes as dictionary data; we convert it to a uniform format\n        # (features, label), where features is a tuple consisting of all\n        # features. 
This format is necessary for using KerasNLP preprocessors.\n        features = tuple([x[name] for name in feature_names])\n        label = x[\"label\"]\n        return (features, label)\n\n    train_ds, test_ds, validation_ds = tfds.load(\n        \"glue/mrpc\",\n        split=[\"train\", \"test\", \"validation\"],\n    )\n\n    train_ds = (\n        train_ds.map(split_features, num_parallel_calls=tf.data.AUTOTUNE)\n        .batch(FLAGS.batch_size)\n        .prefetch(tf.data.AUTOTUNE)\n    )\n    test_ds = (\n        test_ds.map(split_features, num_parallel_calls=tf.data.AUTOTUNE)\n        .batch(FLAGS.batch_size)\n        .prefetch(tf.data.AUTOTUNE)\n    )\n    validation_ds = (\n        validation_ds.map(split_features, num_parallel_calls=tf.data.AUTOTUNE)\n        .batch(FLAGS.batch_size)\n        .prefetch(tf.data.AUTOTUNE)\n    )\n    return train_ds, test_ds, validation_ds\n\n\ndef load_model():\n    if FLAGS.model_size not in MODEL_SIZE_MAP.keys():\n        raise KeyError(\n            f\"`model_size` must be one of {MODEL_SIZE_MAP.keys()}, but \"\n            f\"received {FLAGS.model_size}.\"\n        )\n    return keras_nlp.models.BertClassifier.from_preset(\n        MODEL_SIZE_MAP[FLAGS.model_size], num_classes=2\n    )\n\n\ndef main(_):\n    keras.mixed_precision.set_dtype_policy(FLAGS.mixed_precision_policy)\n\n    logging.info(\n        \"Benchmarking configs...\\n\"\n        \"=========================\\n\"\n        f\"MODEL: BERT {FLAGS.model_size}\\n\"\n        f\"TASK: glue/mrpc \\n\"\n        f\"BATCH_SIZE: {FLAGS.batch_size}\\n\"\n        f\"EPOCHS: {FLAGS.epochs}\\n\"\n        \"=========================\\n\"\n    )\n\n    # Load datasets.\n    train_ds, test_ds, validation_ds = load_data()\n\n    # Load the model.\n    model = load_model()\n    # Set loss and metrics.\n    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n    metrics = [keras.metrics.SparseCategoricalAccuracy()]\n    # Configure optimizer.\n    lr = 
keras.optimizers.schedules.PolynomialDecay(\n        5e-4,\n        decay_steps=train_ds.cardinality() * FLAGS.epochs,\n        end_learning_rate=0.0,\n    )\n    optimizer = keras.optimizers.AdamW(lr, weight_decay=0.01)\n    optimizer.exclude_from_weight_decay(\n        var_names=[\"LayerNorm\", \"layer_norm\", \"bias\"]\n    )\n\n    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)\n\n    benchmark_metrics_callback = BenchmarkMetricsCallback(\n        start_batch=1,\n        stop_batch=train_ds.cardinality().numpy() - 1,\n    )\n\n    # Start training.\n    logging.info(\"Starting Training...\")\n\n    st = time.time()\n    history = model.fit(\n        train_ds,\n        validation_data=validation_ds,\n        epochs=FLAGS.epochs,\n        callbacks=[benchmark_metrics_callback],\n    )\n\n    wall_time = time.time() - st\n    validation_accuracy = history.history[\"val_sparse_categorical_accuracy\"][-1]\n    examples_per_second = (\n        np.mean(np.array(benchmark_metrics_callback.state[\"throughput\"]))\n        * FLAGS.batch_size\n    )\n\n    logging.info(\"Training Finished!\")\n    logging.info(f\"Wall Time: {wall_time:.4f} seconds.\")\n    logging.info(f\"Validation Accuracy: {validation_accuracy:.4f}\")\n    logging.info(f\"examples_per_second: {examples_per_second:.4f}\")\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/model_benchmark/image_classification_benchmark.py",
    "content": "\"\"\"Image classification benchmark.\n\nThis script runs an image classification benchmark on the \"cats vs dogs\"\ndataset. It supports the following 3 models:\n\n- EfficientNetV2B0\n- Xception\n- ResNet50V2\n\nTo run the benchmark, make sure you are in the benchmarks/ directory, and run\nthe command below:\n\npython3 -m model_benchmark.image_classification_benchmark \\\n    --model=\"EfficientNetV2B0\" \\\n    --epochs=2 \\\n    --batch_size=32 \\\n    --mixed_precision_policy=\"mixed_float16\"\n\"\"\"\n\nimport time\n\nimport numpy as np\nimport tensorflow as tf\nimport tensorflow_datasets as tfds\nfrom absl import app\nfrom absl import flags\nfrom absl import logging\nfrom model_benchmark.benchmark_utils import BenchmarkMetricsCallback\n\nimport keras\n\nflags.DEFINE_string(\"model\", \"EfficientNetV2B0\", \"The model to benchmark.\")\nflags.DEFINE_integer(\"epochs\", 1, \"The number of epochs.\")\nflags.DEFINE_integer(\"batch_size\", 4, \"Batch Size.\")\nflags.DEFINE_string(\n    \"mixed_precision_policy\",\n    \"mixed_float16\",\n    \"The global precision policy to use, e.g., 'mixed_float16' or 'float32'.\",\n)\n\nFLAGS = flags.FLAGS\n\nBATCH_SIZE = 32\nIMAGE_SIZE = (224, 224)\nCHANNELS = 3\n\nMODEL_MAP = {\n    \"EfficientNetV2B0\": keras.applications.EfficientNetV2B0,\n    \"Xception\": keras.applications.Xception,\n    \"ResNet50V2\": keras.applications.ResNet50V2,\n}\n\n\ndef load_data():\n    # Load cats vs dogs dataset, and split into train and validation sets.\n    train_dataset, val_dataset = tfds.load(\n        \"cats_vs_dogs\", split=[\"train[:90%]\", \"train[90%:]\"], as_supervised=True\n    )\n\n    resizing = keras.layers.Resizing(\n        IMAGE_SIZE[0], IMAGE_SIZE[1], crop_to_aspect_ratio=True\n    )\n\n    def preprocess_inputs(image, label):\n        image = tf.cast(image, \"float32\")\n        return resizing(image), label\n\n    train_dataset = (\n        train_dataset.map(\n            preprocess_inputs, 
num_parallel_calls=tf.data.AUTOTUNE\n        )\n        .batch(FLAGS.batch_size)\n        .prefetch(tf.data.AUTOTUNE)\n    )\n    val_dataset = (\n        val_dataset.map(preprocess_inputs, num_parallel_calls=tf.data.AUTOTUNE)\n        .batch(FLAGS.batch_size)\n        .cache()\n        .prefetch(tf.data.AUTOTUNE)\n    )\n    return train_dataset, val_dataset\n\n\ndef load_model():\n    model_class = MODEL_MAP[FLAGS.model]\n    # Load the EfficientNetV2B0 model and add a classification head.\n    model = model_class(include_top=False, weights=\"imagenet\")\n    classifier = keras.models.Sequential(\n        [\n            keras.Input([IMAGE_SIZE[0], IMAGE_SIZE[1], CHANNELS]),\n            model,\n            keras.layers.GlobalAveragePooling2D(),\n            keras.layers.Dense(2),\n        ]\n    )\n    return classifier\n\n\ndef main(_):\n    keras.mixed_precision.set_dtype_policy(FLAGS.mixed_precision_policy)\n\n    logging.info(\n        \"Benchmarking configs...\\n\"\n        \"=========================\\n\"\n        f\"MODEL: {FLAGS.model}\\n\"\n        f\"TASK: image classification/dogs-vs-cats \\n\"\n        f\"BATCH_SIZE: {FLAGS.batch_size}\\n\"\n        f\"EPOCHS: {FLAGS.epochs}\\n\"\n        \"=========================\\n\"\n    )\n\n    # Load datasets.\n    train_ds, validation_ds = load_data()\n\n    # Load the model.\n    classifier = load_model()\n\n    lr = keras.optimizers.schedules.PolynomialDecay(\n        5e-4,\n        decay_steps=train_ds.cardinality() * FLAGS.epochs,\n        end_learning_rate=0.0,\n    )\n    optimizer = keras.optimizers.Adam(lr)\n    loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n\n    benchmark_metrics_callback = BenchmarkMetricsCallback(\n        start_batch=1,\n        stop_batch=train_ds.cardinality().numpy() - 1,\n    )\n\n    classifier.compile(\n        optimizer=optimizer,\n        loss=loss,\n        metrics=[\"sparse_categorical_accuracy\"],\n    )\n    # Start training.\n    
logging.info(\"Starting Training...\")\n\n    st = time.time()\n\n    history = classifier.fit(\n        train_ds,\n        validation_data=validation_ds,\n        epochs=FLAGS.epochs,\n        callbacks=[benchmark_metrics_callback],\n    )\n\n    wall_time = time.time() - st\n    validation_accuracy = history.history[\"val_sparse_categorical_accuracy\"][-1]\n\n    examples_per_second = (\n        np.mean(np.array(benchmark_metrics_callback.state[\"throughput\"]))\n        * FLAGS.batch_size\n    )\n\n    logging.info(\"Training Finished!\")\n    logging.info(f\"Wall Time: {wall_time:.4f} seconds.\")\n    logging.info(f\"Validation Accuracy: {validation_accuracy:.4f}\")\n    logging.info(f\"examples_per_second: {examples_per_second:.4f}\")\n\n\nif __name__ == \"__main__\":\n    app.run(main)\n"
  },
  {
    "path": "benchmarks/torch_ctl_benchmark/README.md",
    "content": "# Benchmark the performance of torch custom training loop\n\nThis directory contains benchmarks that compare the performance of a Keras\nmodel and an equivalent Torch model while using the same Torch custom training\nloop.\n\nThe purpose of these benchmarks is to understand the performance difference\nresulting from the modeling API choice (Keras or Torch).\n\nTo run a benchmark, use the command below, replacing the module with your\ntarget benchmark:\n\n```shell\npython3 -m benchmarks.torch_ctl_benchmark.conv_model_benchmark\n```"
  },
  {
    "path": "benchmarks/torch_ctl_benchmark/__init__.py",
    "content": ""
  },
  {
    "path": "benchmarks/torch_ctl_benchmark/benchmark_utils.py",
    "content": "import time\n\nimport numpy as np\nimport torch\n\n\ndef train_loop(model, train_loader, num_epochs, optimizer, loss_fn, framework):\n    device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n    model.to(device)\n    start = None\n    average_batch_time_per_epoch = []\n    for _ in range(num_epochs):\n        running_loss = 0.0\n        for batch_idx, (inputs, targets) in enumerate(train_loader):\n            if batch_idx == 1:\n                start = time.time()\n            inputs = inputs.to(device)\n            targets = targets.to(device)\n            # Forward pass\n            outputs = model(inputs)\n            loss = loss_fn(outputs, targets)\n\n            # Backward and optimize\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            running_loss += loss.item()\n\n        end = time.time()\n        average_batch_time_per_epoch.append(\n            (end - start) / (len(train_loader) - 1)\n        )\n    average_time = np.mean(average_batch_time_per_epoch)\n\n    print(f\"Time per batch in {framework}: {average_time:.2f}\")\n"
  },
  {
    "path": "benchmarks/torch_ctl_benchmark/conv_model_benchmark.py",
    "content": "\"\"\"Benchmark Keras performance with torch custom training loop.\n\nIn this file we use a convolutional model. The training loop is written in\nthe vanilla torch way, and we compare the performance of building the model\nwith Keras versus torch.\n\"\"\"\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nimport keras\nfrom benchmarks.torch_ctl_benchmark.benchmark_utils import train_loop\nfrom keras import layers\n\nnum_classes = 2\ninput_shape = (3, 256, 256)\nbatch_size = 128\nnum_batches = 20\nnum_epochs = 1\n\nx_train = np.random.normal(\n    size=(num_batches * batch_size, *input_shape)\n).astype(np.float32)\ny_train = np.random.randint(0, num_classes, size=(num_batches * batch_size,))\n\n# Create a TensorDataset\ndataset = torch.utils.data.TensorDataset(\n    torch.from_numpy(x_train), torch.from_numpy(y_train)\n)\n# Create a DataLoader\ntrain_loader = torch.utils.data.DataLoader(\n    dataset, batch_size=batch_size, shuffle=False\n)\n\n\nclass TorchModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        self.conv = torch.nn.Conv2d(3, 32, kernel_size=(3, 3))\n        self.activation = torch.nn.ReLU()\n        self.max_pool = torch.nn.MaxPool2d((2, 2))\n        self.flatten = torch.nn.Flatten()\n        self.dense = torch.nn.LazyLinear(num_classes)\n        self.softmax = torch.nn.Softmax(dim=1)\n\n    def forward(self, x):\n        x = self.conv(x)\n        x = self.activation(x)\n        x = self.max_pool(x)\n        x = self.flatten(x)\n        x = self.dense(x)\n        x = self.softmax(x)\n        return x\n\n\ndef run_keras_custom_training_loop():\n    keras_model = keras.Sequential(\n        [\n            layers.Input(shape=input_shape),\n            layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n            layers.MaxPooling2D(pool_size=(2, 2)),\n            layers.Flatten(),\n            layers.Dense(num_classes),\n            layers.Softmax(),\n        ]\n  
  )\n    optimizer = optim.Adam(keras_model.parameters(), lr=0.001)\n    loss_fn = nn.CrossEntropyLoss()\n    train_loop(\n        keras_model,\n        train_loader,\n        num_epochs=num_epochs,\n        optimizer=optimizer,\n        loss_fn=loss_fn,\n        framework=\"keras\",\n    )\n\n\ndef run_torch_custom_training_loop():\n    torch_model = TorchModel()\n    optimizer = optim.Adam(torch_model.parameters(), lr=0.001)\n    loss_fn = nn.CrossEntropyLoss()\n    train_loop(\n        torch_model,\n        train_loader,\n        num_epochs=num_epochs,\n        optimizer=optimizer,\n        loss_fn=loss_fn,\n        framework=\"torch\",\n    )\n\n\nif __name__ == \"__main__\":\n    run_keras_custom_training_loop()\n    run_torch_custom_training_loop()\n"
  },
  {
    "path": "benchmarks/torch_ctl_benchmark/dense_model_benchmark.py",
    "content": "\"\"\"Benchmark Keras performance with torch custom training loop.\n\nIn this file we use a model with 3 dense layers. The training loop is written\nin the vanilla torch way, and we compare the performance of building the model\nwith Keras versus torch.\n\"\"\"\n\nimport numpy as np\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\n\nimport keras\nfrom benchmarks.torch_ctl_benchmark.benchmark_utils import train_loop\nfrom keras import layers\n\nnum_classes = 2\ninput_shape = (8192,)\nbatch_size = 4096\nnum_batches = 20\nnum_epochs = 1\n\nx_train = np.random.normal(\n    size=(num_batches * batch_size, *input_shape)\n).astype(np.float32)\ny_train = np.random.randint(0, num_classes, size=(num_batches * batch_size,))\n\n# Create a TensorDataset\ndataset = torch.utils.data.TensorDataset(\n    torch.from_numpy(x_train), torch.from_numpy(y_train)\n)\n# Create a DataLoader\ntrain_loader = torch.utils.data.DataLoader(\n    dataset, batch_size=batch_size, shuffle=False\n)\n\n\nclass TorchModel(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n\n        self.dense1 = torch.nn.Linear(8192, 64)\n        self.activation1 = torch.nn.ReLU()\n        self.dense2 = torch.nn.Linear(64, 8)\n        self.activation2 = torch.nn.ReLU()\n        self.dense3 = torch.nn.Linear(8, num_classes)\n        self.softmax = torch.nn.Softmax(dim=1)\n\n    def forward(self, x):\n        x = self.dense1(x)\n        x = self.activation1(x)\n        x = self.dense2(x)\n        x = self.activation2(x)\n        x = self.dense3(x)\n        x = self.softmax(x)\n        return x\n\n\ndef run_keras_custom_training_loop():\n    keras_model = keras.Sequential(\n        [\n            layers.Input(shape=input_shape),\n            layers.Dense(64, activation=\"relu\"),\n            layers.Dense(8, activation=\"relu\"),\n            layers.Dense(num_classes),\n            layers.Softmax(),\n        ]\n    )\n    optimizer = optim.Adam(keras_model.parameters(), 
lr=0.001)\n    loss_fn = nn.CrossEntropyLoss()\n    train_loop(\n        keras_model,\n        train_loader,\n        num_epochs=num_epochs,\n        optimizer=optimizer,\n        loss_fn=loss_fn,\n        framework=\"keras\",\n    )\n\n\ndef run_torch_custom_training_loop():\n    torch_model = TorchModel()\n    optimizer = optim.Adam(torch_model.parameters(), lr=0.001)\n    loss_fn = nn.CrossEntropyLoss()\n    train_loop(\n        torch_model,\n        train_loader,\n        num_epochs=num_epochs,\n        optimizer=optimizer,\n        loss_fn=loss_fn,\n        framework=\"torch\",\n    )\n\n\nif __name__ == \"__main__\":\n    run_keras_custom_training_loop()\n    run_torch_custom_training_loop()\n"
  },
  {
    "path": "codecov.yml",
    "content": "coverage:\n  status:\n    project:\n      default:\n        # `auto` compares coverage with the base-commit\n        target: auto\n\n    patch:\n      default:\n        target: auto\n\ncomment:\n  layout: \"header, reach, diff, flags, files\"\n  behavior: default\n  require_changes: no\n  require_base: no\n  require_head: yes\n  show_carryforward_flags: yes\n\nflag_management:\n  default_rules:\n    carryforward: false\n    statuses:\n      - type: project\n        target: auto\n      - type: patch\n        target: auto\n  individual_flags:\n    - name: keras\n      paths:\n        - keras\n    - name: keras.applications\n      paths:\n        - keras/applications\n      carryforward: true\n"
  },
  {
    "path": "conftest.py",
    "content": "try:\n    # When using torch and tensorflow, torch needs to be imported first,\n    # otherwise it will segfault upon import. This should force the torch\n    # import to happen first for all tests.\n    import torch  # noqa: F401\nexcept ImportError:\n    torch = None\n\nimport pytest  # noqa: E402\n\nfrom keras.src.backend import backend  # noqa: E402\n\n\ndef pytest_configure(config):\n    config.addinivalue_line(\n        \"markers\",\n        \"requires_trainable_backend: mark test for trainable backend only\",\n    )\n\n\ndef pytest_collection_modifyitems(config, items):\n    openvino_skipped_tests = []\n    if backend() == \"openvino\":\n        with open(\n            \"keras/src/backend/openvino/excluded_concrete_tests.txt\", \"r\"\n        ) as file:\n            openvino_skipped_tests = file.readlines()\n            # it is necessary to check if stripped line is not empty\n            # and exclude such lines\n            openvino_skipped_tests = [\n                line.strip() for line in openvino_skipped_tests if line.strip()\n            ]\n\n    tpu_skipped_tests = []\n    if backend() == \"jax\":\n        import jax\n\n        if jax.default_backend() == \"tpu\":\n            with open(\n                \"keras/src/backend/jax/excluded_tpu_tests.txt\", \"r\"\n            ) as file:\n                tpu_skipped_tests = file.readlines()\n                # it is necessary to check if stripped line is not empty\n                # and exclude such lines\n                tpu_skipped_tests = [\n                    line.strip() for line in tpu_skipped_tests if line.strip()\n                ]\n\n    requires_trainable_backend = pytest.mark.skipif(\n        backend() in [\"numpy\", \"openvino\"],\n        reason=\"Trainer not implemented for NumPy and OpenVINO backend.\",\n    )\n    for item in items:\n        if \"requires_trainable_backend\" in item.keywords:\n            item.add_marker(requires_trainable_backend)\n        # also, skip 
concrete tests for openvino that are listed in the special file.\n        # This is a more granular mechanism to exclude tests than\n        # using the --ignore option.\n        for skipped_test in openvino_skipped_tests:\n            if skipped_test in item.nodeid:\n                item.add_marker(\n                    skip_if_backend(\n                        \"openvino\",\n                        \"Operation not supported by the openvino backend\",\n                    )\n                )\n        # also, skip concrete tests for TPU when using JAX backend\n        for skipped_test in tpu_skipped_tests:\n            if skipped_test in item.nodeid:\n                item.add_marker(\n                    pytest.mark.skip(\n                        reason=\"Known TPU test failure\",\n                    )\n                )\n\n\ndef skip_if_backend(given_backend, reason):\n    return pytest.mark.skipif(backend() == given_backend, reason=reason)\n"
  },
  {
    "path": "examples/demo_custom_jax_workflow.py",
    "content": "# flake8: noqa\nimport os\n\n# Set backend env to JAX\nos.environ[\"KERAS_BACKEND\"] = \"jax\"\n\nimport jax\nimport numpy as np\n\nfrom keras import Model\nfrom keras import backend\nfrom keras import initializers\nfrom keras import layers\nfrom keras import ops\nfrom keras import optimizers\n\n\nclass MyDense(layers.Layer):\n    def __init__(self, units, name=None):\n        super().__init__(name=name)\n        self.units = units\n\n    def build(self, input_shape):\n        input_dim = input_shape[-1]\n        w_shape = (input_dim, self.units)\n        w_value = initializers.GlorotUniform()(w_shape)\n        self.w = backend.Variable(w_value, name=\"kernel\")\n\n        b_shape = (self.units,)\n        b_value = initializers.Zeros()(b_shape)\n        self.b = backend.Variable(b_value, name=\"bias\")\n\n    def call(self, inputs):\n        return ops.matmul(inputs, self.w) + self.b\n\n\nclass MyModel(Model):\n    def __init__(self, hidden_dim, output_dim):\n        super().__init__()\n        self.dense1 = MyDense(hidden_dim)\n        self.dense2 = MyDense(hidden_dim)\n        self.dense3 = MyDense(output_dim)\n\n    def call(self, x):\n        x = jax.nn.relu(self.dense1(x))\n        x = jax.nn.relu(self.dense2(x))\n        return self.dense3(x)\n\n\ndef Dataset():\n    for _ in range(20):\n        yield (np.random.random((32, 128)), np.random.random((32, 4)))\n\n\ndef loss_fn(y_true, y_pred):\n    return ops.sum((y_true - y_pred) ** 2)\n\n\nmodel = MyModel(hidden_dim=256, output_dim=4)\n\noptimizer = optimizers.SGD(learning_rate=0.001)\ndataset = Dataset()\n\n# Build model\nx = np.random.random((1, 128))\nmodel(x)\n# Build optimizer\noptimizer.build(model.trainable_variables)\n\n\n######### Custom JAX workflow ###############\n\n\ndef compute_loss_and_updates(\n    trainable_variables, non_trainable_variables, x, y\n):\n    y_pred, non_trainable_variables = model.stateless_call(\n        trainable_variables, non_trainable_variables, x\n    )\n   
 loss = loss_fn(y, y_pred)\n    return loss, non_trainable_variables\n\n\ngrad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)\n\n\n@jax.jit\ndef train_step(state, data):\n    trainable_variables, non_trainable_variables, optimizer_variables = state\n    x, y = data\n    (loss, non_trainable_variables), grads = grad_fn(\n        trainable_variables, non_trainable_variables, x, y\n    )\n    trainable_variables, optimizer_variables = optimizer.stateless_apply(\n        optimizer_variables, grads, trainable_variables\n    )\n    # Return updated state\n    return loss, (\n        trainable_variables,\n        non_trainable_variables,\n        optimizer_variables,\n    )\n\n\ntrainable_variables = model.trainable_variables\nnon_trainable_variables = model.non_trainable_variables\noptimizer_variables = optimizer.variables\nstate = trainable_variables, non_trainable_variables, optimizer_variables\n# Training loop\nfor data in dataset:\n    loss, state = train_step(state, data)\n    print(\"Loss:\", loss)\n\n# Post-processing model state update\ntrainable_variables, non_trainable_variables, optimizer_variables = state\nfor variable, value in zip(model.trainable_variables, trainable_variables):\n    variable.assign(value)\nfor variable, value in zip(\n    model.non_trainable_variables, non_trainable_variables\n):\n    variable.assign(value)\n"
  },
  {
    "path": "examples/demo_custom_layer_backend_agnostic.py",
    "content": "import numpy as np\n\nimport keras\nfrom keras import Model\nfrom keras import initializers\nfrom keras import layers\nfrom keras import losses\nfrom keras import metrics\nfrom keras import ops\nfrom keras import optimizers\n\n\nclass MyDense(layers.Layer):\n    def __init__(self, units, name=None):\n        super().__init__(name=name)\n        self.units = units\n\n    def build(self, input_shape):\n        input_dim = input_shape[-1]\n        self.w = self.add_weight(\n            shape=(input_dim, self.units),\n            initializer=initializers.GlorotNormal(),\n            name=\"kernel\",\n            trainable=True,\n        )\n\n        self.b = self.add_weight(\n            shape=(self.units,),\n            initializer=initializers.Zeros(),\n            name=\"bias\",\n            trainable=True,\n        )\n\n    def call(self, inputs):\n        # Use Keras ops to create backend-agnostic layers/metrics/etc.\n        return ops.matmul(inputs, self.w) + self.b\n\n\nclass MyDropout(layers.Layer):\n    def __init__(self, rate, name=None):\n        super().__init__(name=name)\n        self.rate = rate\n        # Use seed_generator for managing RNG state.\n        # It is a state element and its seed variable is\n        # tracked as part of `layer.variables`.\n        self.seed_generator = keras.random.SeedGenerator(1337)\n\n    def call(self, inputs):\n        # Use `keras.random` for random ops.\n        return keras.random.dropout(inputs, self.rate, seed=self.seed_generator)\n\n\nclass MyModel(Model):\n    def __init__(self, hidden_dim, output_dim):\n        super().__init__()\n        self.dense1 = MyDense(hidden_dim)\n        self.dense2 = MyDense(hidden_dim)\n        self.dense3 = MyDense(output_dim)\n        self.dp = MyDropout(0.5)\n\n    def call(self, x):\n        x1 = self.dense1(x)\n        x2 = self.dense2(x)\n        # Why not use some ops here as well\n        x = ops.concatenate([x1, x2], axis=-1)\n        x = self.dp(x)\n      
  return self.dense3(x)\n\n\nmodel = MyModel(hidden_dim=256, output_dim=16)\n\nx = np.random.random((50000, 128))\ny = np.random.random((50000, 16))\nbatch_size = 32\nepochs = 5\n\nmodel.compile(\n    optimizer=optimizers.SGD(learning_rate=0.001),\n    loss=losses.MeanSquaredError(),\n    metrics=[metrics.MeanSquaredError()],\n)\nhistory = model.fit(x, y, batch_size=batch_size, epochs=epochs)\n\nmodel.summary()\n\nprint(\"History:\")\nprint(history.history)\n"
  },
  {
    "path": "examples/demo_custom_tf_workflow.py",
    "content": "# flake8: noqa\nimport os\n\n# Set backend env to tensorflow\nos.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n\nimport numpy as np\nimport tensorflow as tf\n\nfrom keras import Model\nfrom keras import backend\nfrom keras import initializers\nfrom keras import layers\nfrom keras import ops\nfrom keras import optimizers\n\n\nclass MyDense(layers.Layer):\n    def __init__(self, units, name=None):\n        super().__init__(name=name)\n        self.units = units\n\n    def build(self, input_shape):\n        input_dim = input_shape[-1]\n        w_shape = (input_dim, self.units)\n        w_value = initializers.GlorotUniform()(w_shape)\n        self.w = backend.Variable(w_value, name=\"kernel\")\n\n        b_shape = (self.units,)\n        b_value = initializers.Zeros()(b_shape)\n        self.b = backend.Variable(b_value, name=\"bias\")\n\n    def call(self, inputs):\n        return ops.matmul(inputs, self.w) + self.b\n\n\nclass MyModel(Model):\n    def __init__(self, hidden_dim, output_dim):\n        super().__init__()\n        self.dense1 = MyDense(hidden_dim)\n        self.dense2 = MyDense(hidden_dim)\n        self.dense3 = MyDense(output_dim)\n\n    def call(self, x):\n        x = tf.nn.relu(self.dense1(x))\n        x = tf.nn.relu(self.dense2(x))\n        return self.dense3(x)\n\n\ndef Dataset():\n    for _ in range(20):\n        yield (\n            np.random.random((32, 128)).astype(\"float32\"),\n            np.random.random((32, 4)).astype(\"float32\"),\n        )\n\n\ndef loss_fn(y_true, y_pred):\n    return ops.sum((y_true - y_pred) ** 2)\n\n\nmodel = MyModel(hidden_dim=256, output_dim=4)\n\noptimizer = optimizers.SGD(learning_rate=0.001)\ndataset = Dataset()\n\n\n######### Custom TF workflow ###############\n\n\n@tf.function(jit_compile=True)\ndef train_step(data):\n    x, y = data\n    with tf.GradientTape() as tape:\n        y_pred = model(x)\n        loss = loss_fn(y, y_pred)\n    gradients = tape.gradient(loss, model.trainable_variables)\n    
optimizer.apply_gradients(zip(gradients, model.trainable_variables))\n    return loss\n\n\nfor data in dataset:\n    loss = train_step(data)\n    print(\"Loss:\", float(loss))\n"
  },
  {
    "path": "examples/demo_custom_torch_workflow.py",
    "content": "# flake8: noqa\nimport os\n\n# Set backend env to torch\nos.environ[\"KERAS_BACKEND\"] = \"torch\"\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom keras import layers\nimport keras\nimport numpy as np\n\n# Model / data parameters\nnum_classes = 10\ninput_shape = (28, 28, 1)\nlearning_rate = 0.01\nbatch_size = 64\nnum_epochs = 1\n\n# Load the data and split it between train and test sets\n(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n\n# Scale images to the [0, 1] range\nx_train = x_train.astype(\"float32\") / 255\nx_test = x_test.astype(\"float32\") / 255\n# Make sure images have shape (28, 28, 1)\nx_train = np.expand_dims(x_train, -1)\nx_test = np.expand_dims(x_test, -1)\nprint(\"x_train shape:\", x_train.shape)\nprint(x_train.shape[0], \"train samples\")\nprint(x_test.shape[0], \"test samples\")\n\n# Create the Keras model\nmodel = keras.Sequential(\n    [\n        layers.Input(shape=(28, 28, 1)),\n        layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n        layers.MaxPooling2D(pool_size=(2, 2)),\n        layers.Conv2D(64, kernel_size=(3, 3), activation=\"relu\"),\n        layers.MaxPooling2D(pool_size=(2, 2)),\n        layers.Flatten(),\n        layers.Dropout(0.5),\n        layers.Dense(num_classes),\n    ]\n)\n\n#################################################################\n######## Writing a torch training loop for a Keras model ########\n#################################################################\n\n# Instantiate the torch optimizer\noptimizer = optim.Adam(model.parameters(), lr=learning_rate)\n\n# Instantiate the torch loss function\nloss_fn = nn.CrossEntropyLoss()\n\n\ndef train(model, train_loader, num_epochs, optimizer, loss_fn):\n    for epoch in range(num_epochs):\n        running_loss = 0.0\n        for batch_idx, (inputs, targets) in enumerate(train_loader):\n            # Forward pass\n            outputs = model(inputs)\n            loss = 
loss_fn(outputs, targets)\n\n            # Backward and optimize\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            running_loss += loss.item()\n\n            # Print loss statistics\n            if (batch_idx + 1) % 10 == 0:\n                print(\n                    f\"Epoch [{epoch + 1}/{num_epochs}], \"\n                    f\"Batch [{batch_idx + 1}/{len(train_loader)}], \"\n                    f\"Loss: {running_loss / 10}\"\n                )\n                running_loss = 0.0\n\n\n# Create a TensorDataset. Cast the labels to int64: nn.CrossEntropyLoss\n# expects long-typed class-index targets, and MNIST labels load as uint8.\ndataset = torch.utils.data.TensorDataset(\n    torch.from_numpy(x_train), torch.from_numpy(y_train).long()\n)\n\n# Create a DataLoader\ntrain_loader = torch.utils.data.DataLoader(\n    dataset, batch_size=batch_size, shuffle=False\n)\n\ntrain(model, train_loader, num_epochs, optimizer, loss_fn)\n\n\n################################################################\n######## Using a Keras model or layer in a torch Module ########\n################################################################\n\n\nclass MyModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.model = keras.Sequential(\n            [\n                layers.Input(shape=(28, 28, 1)),\n                layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n                layers.MaxPooling2D(pool_size=(2, 2)),\n                layers.Conv2D(64, kernel_size=(3, 3), activation=\"relu\"),\n                layers.MaxPooling2D(pool_size=(2, 2)),\n                layers.Flatten(),\n                layers.Dropout(0.5),\n                layers.Dense(num_classes),\n            ]\n        )\n\n    def forward(self, x):\n        return self.model(x)\n\n\ntorch_module = MyModel()\n\n# Instantiate the torch optimizer\noptimizer = optim.Adam(torch_module.parameters(), lr=learning_rate)\n\n# Instantiate the torch loss function\nloss_fn = nn.CrossEntropyLoss()\n\ntrain(torch_module, train_loader, num_epochs, optimizer, 
loss_fn)\n"
  },
  {
    "path": "examples/demo_functional.py",
    "content": "import numpy as np\n\nfrom keras import Model\nfrom keras import layers\nfrom keras import losses\nfrom keras import metrics\nfrom keras import optimizers\nimport keras\n\nkeras.config.disable_traceback_filtering()\n\ninputs = layers.Input((100,))\nx = layers.Dense(512, activation=\"relu\")(inputs)\nresidual = x\nx = layers.Dense(512, activation=\"relu\")(x)\nx = layers.Dense(512, activation=\"relu\")(x)\nx += residual\nx = layers.Dense(512, activation=\"relu\")(x)\nresidual = x\nx = layers.Dense(512, activation=\"relu\")(x)\nx = layers.Dense(512, activation=\"relu\")(x)\nx += residual\nresidual = x\nx = layers.Dense(512, activation=\"relu\")(x)\nx = layers.Dense(512, activation=\"relu\")(x)\nx += residual\noutputs = layers.Dense(16)(x)\nmodel = Model(inputs, outputs)\n\nmodel.summary()\n\nx = np.random.random((50000, 100))\ny = np.random.random((50000, 16))\nbatch_size = 32\nepochs = 5\n\nmodel.compile(\n    optimizer=optimizers.Adam(learning_rate=0.001),\n    loss=losses.MeanSquaredError(),\n    metrics=[\n        metrics.CategoricalAccuracy(name=\"acc\"),\n        metrics.MeanSquaredError(name=\"mse\"),\n    ],\n)\n\nprint(\"\\nTrain model\")\nhistory = model.fit(\n    x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2\n)\nprint(\"\\nHistory:\")\nprint(history.history)\n\nprint(\"\\nEvaluate model\")\nscores = model.evaluate(x, y, return_dict=True)\nprint(scores)\n\nprint(\"\\nRun inference\")\npred = model.predict(x)\nprint(f\"Inferred output shape {pred.shape}\")\n"
  },
  {
    "path": "examples/demo_jax_distributed.py",
    "content": "# To run this demo, you will need to spin up a \"TPU VM\" on Google Cloud.\n# Please follow instructions here: https://cloud.google.com/tpu/docs/run-calculation-jax\n\n# Force a JAX backend\nimport os, pprint, collections\n\nos.environ[\"KERAS_BACKEND\"] = \"jax\"\n\npp = pprint.PrettyPrinter()\n\nimport jax\nimport jax.numpy as jnp\nimport tensorflow as tf  # just for tf.data\nimport keras  # Keras multi-backend\n\nimport numpy as np\nfrom tqdm import tqdm\n\nfrom jax.experimental import mesh_utils\nfrom jax.sharding import Mesh\nfrom jax.sharding import NamedSharding\nfrom jax.sharding import PartitionSpec as P\n\n\"\"\" Dataset\nClassic MNIST, loaded using tf.data\n\"\"\"\n\nBATCH_SIZE = 192\n\n(\n    (x_train, train_labels),\n    (x_eval, eval_labels),\n) = keras.datasets.mnist.load_data()\nx_train = np.expand_dims(x_train, axis=-1).astype(\n    np.float32\n)  # from 28x28 to 28x28 x 1 color channel (B&W)\nx_eval = np.expand_dims(x_eval, axis=-1).astype(np.float32)\n\ntrain_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))\ntrain_data = train_data.shuffle(5000, reshuffle_each_iteration=True)\ntrain_data = train_data.batch(BATCH_SIZE, drop_remainder=True)\ntrain_data = train_data.repeat()\n\neval_data = tf.data.Dataset.from_tensor_slices((x_eval, eval_labels))\neval_data = eval_data.batch(10000)  # everything as one batch\n\nSTEPS_PER_EPOCH = len(train_labels) // BATCH_SIZE\n\n\"\"\" Keras model\nSimple but non-trivial model with:\n* Batch Normalization (non-trainable state updated during training, different training-time and inference behavior)\n* Dropout (randomness, different training time and inference behavior)\n\"\"\"\n\n\n# Keras \"sequential\" model building style\ndef make_backbone():\n    return keras.Sequential(\n        [\n            keras.layers.Rescaling(\n                1.0 / 255.0\n            ),  # input images are in the range [0, 255]\n            keras.layers.Conv2D(\n                filters=12, 
kernel_size=3, padding=\"same\", use_bias=False\n            ),\n            keras.layers.BatchNormalization(scale=False, center=True),\n            keras.layers.Activation(\"relu\"),\n            keras.layers.Conv2D(\n                filters=24,\n                kernel_size=6,\n                padding=\"same\",\n                use_bias=False,\n                strides=2,\n            ),\n            keras.layers.BatchNormalization(scale=False, center=True),\n            keras.layers.Activation(\"relu\"),\n            keras.layers.Conv2D(\n                filters=32,\n                kernel_size=6,\n                padding=\"same\",\n                use_bias=False,\n                strides=2,\n                name=\"large_k\",\n            ),\n            keras.layers.BatchNormalization(scale=False, center=True),\n            keras.layers.Activation(\"relu\"),\n        ],\n        name=\"backbone\",\n    )\n\n\ndef make_model():\n    input = keras.Input(shape=[28, 28, 1])\n    y = make_backbone()(input)\n    y = keras.layers.Flatten()(y)\n    y = keras.layers.Dense(200, activation=\"relu\")(y)\n    y = keras.layers.Dropout(0.4)(y)\n    y = keras.layers.Dense(10, activation=\"softmax\")(y)\n    model = keras.Model(inputs=input, outputs=y)\n    return model\n\n\n\"\"\" JAX-native distribution with a Keras model\nFor now, you have to write a custom training loop for this\nNote: The features required by jax.sharding are not supported by the Colab TPU\nruntime at this time, but are available on Cloud TPU VMs and Kaggle TPU VMs.\n\"\"\"\n\nif len(jax.local_devices()) < 8:\n    raise Exception(\"This part requires 8 devices to run\")\nelse:\n    print(\"\\nIdentified local devices:\")\n    pp.pprint(jax.local_devices())\n\n# ----------------- Keras ---------------------\n\n# instantiate the model\nmodel = make_model()\n\n# learning rate\nlr = keras.optimizers.schedules.ExponentialDecay(0.01, STEPS_PER_EPOCH, 0.6)\n\n# optimizer\noptimizer = keras.optimizers.Adam(lr)\n\n# 
initialize all state with .build()\n(one_batch, one_batch_labels) = next(iter(train_data))\nmodel.build(one_batch)\noptimizer.build(model.trainable_variables)\n\n\"\"\" Distribution settings\n\n* Sharding the data on the batch axis\n* Replicating all model variables\n\nNote: this implements standard \"data parallel\" distributed training\n\n* Just for show, sharding the largest convolutional kernel along the\n  \"channels\" axis 4-ways and replicating 2-ways\n\nNote: this does not reflect a best practice but is intended to show\n      that you can split a very large kernel across multiple devices\n      if you have to\n\"\"\"\n\nprint(\n    \"\\nMostly data-parallel distribution. \"\n    \"Data is sharded across devices while the model is replicated. \"\n    \"For demo purposes, we split the largest kernel 4-ways \"\n    \"(and replicate 2-ways since we have 8 devices).\"\n)\n\n# ------------------ Jax ----------------------\n\ndevices = mesh_utils.create_device_mesh((8,))\n\n# data will be split along the batch axis\ndata_mesh = Mesh(devices, axis_names=(\"batch\",))  # naming axes of the mesh\n# naming axes of the sharded partition\ndata_sharding = NamedSharding(\n    data_mesh,\n    P(\n        \"batch\",\n    ),\n)\n# all variables will be replicated on all devices\n# (note the trailing comma: axis_names should be a tuple, not a bare string)\nvar_mesh = Mesh(devices, axis_names=(\"_\",))\n# in NamedSharding, axes that are not mentioned are replicated (all axes here)\nvar_replication = NamedSharding(var_mesh, P())\n\n# for the demo, we will split the largest kernel 4-ways (and replicate 2-ways since we have 8 devices)\nlarge_kernel_mesh = Mesh(\n    devices.reshape((-1, 4)), axis_names=(None, \"out_chan\")\n)  # naming axes of the mesh\nlarge_kernel_sharding = NamedSharding(\n    large_kernel_mesh, P(None, None, None, \"out_chan\")\n)  # naming axes of the sharded partition\n\n# ----------------- Keras ---------------------\n\n# Use Keras APIs to find the variable of a specific layer (we will be sharding this one in a special way)\n# In a 
Conv2D or Dense layer, the variables are 'kernel' and 'bias'\nspecial_layer_var = model.get_layer(\"backbone\").get_layer(\"large_k\").kernel\n\n# ------------------ Jax ----------------------\n# - accessing variables through the Keras lists model.trainable_variables,\n#   model.non_trainable_variables, and optimizer.variables\n\n# Apply the distribution settings to the model variables\nnon_trainable_variables = jax.device_put(\n    model.non_trainable_variables, var_replication\n)\noptimizer_variables = jax.device_put(optimizer.variables, var_replication)\n# this is what you would do to replicate all trainable variables:\n# trainable_variables = jax.device_put(model.trainable_variables, var_replication)\n\n# For the demo, we split the largest kernel 4-ways instead of replicating it.\n# We still replicate all other trainable variables as in standard \"data-parallel\"\n# distributed training.\nprint_once = True\ntrainable_variables = model.trainable_variables\nfor i, v in enumerate(trainable_variables):\n    if v is special_layer_var:\n        # Apply distribution settings: sharding\n        sharded_v = jax.device_put(v, large_kernel_sharding)\n        trainable_variables[i] = sharded_v\n\n        print(\"Sharding of convolutional\", v.name, v.shape)\n        jax.debug.visualize_array_sharding(\n            jnp.reshape(sharded_v, [-1, v.shape[-1]])\n        )\n    else:\n        # Apply distribution settings: replication\n        replicated_v = jax.device_put(v, var_replication)\n        trainable_variables[i] = replicated_v\n\n        if print_once:\n            print_once = False\n            print(\n                \"\\nSharding of all other model variables (they are replicated)\"\n            )\n            jax.debug.visualize_array_sharding(\n                jnp.reshape(replicated_v, [-1, v.shape[-1]])\n            )\n\n# collect state in a handy named tuple\nTrainingState = collections.namedtuple(\n    \"TrainingState\",\n    [\"trainable_variables\", 
\"non_trainable_variables\", \"optimizer_variables\"],\n)\ndevice_train_state = TrainingState(\n    trainable_variables=trainable_variables,\n    non_trainable_variables=non_trainable_variables,\n    optimizer_variables=optimizer_variables,\n)\n# display data sharding\nx, y = next(iter(train_data))\nsharded_x = jax.device_put(x.numpy(), data_sharding)\nprint(\"Data sharding\")\njax.debug.visualize_array_sharding(jnp.reshape(sharded_x, [-1, 28 * 28]))\n\n# ------------------ Jax ----------------------\n# - Using Keras-provided stateless APIs\n# - model.stateless_call\n# - optimizer.stateless_apply\n# These functions also work on other backends.\n\n# define loss\nloss = keras.losses.SparseCategoricalCrossentropy()\n\n\n# This is the loss function that will be differentiated.\n# Keras provides a pure functional forward pass: model.stateless_call\ndef compute_loss(trainable_variables, non_trainable_variables, x, y):\n    y_pred, updated_non_trainable_variables = model.stateless_call(\n        trainable_variables, non_trainable_variables, x\n    )\n    loss_value = loss(y, y_pred)\n    return loss_value, updated_non_trainable_variables\n\n\n# function to compute gradients\ncompute_gradients = jax.value_and_grad(compute_loss, has_aux=True)\n\n\n# Training step: Keras provides a pure functional optimizer.stateless_apply\n@jax.jit\ndef train_step(train_state, x, y):\n    (loss_value, non_trainable_variables), grads = compute_gradients(\n        train_state.trainable_variables,\n        train_state.non_trainable_variables,\n        x,\n        y,\n    )\n\n    trainable_variables, optimizer_variables = optimizer.stateless_apply(\n        train_state.optimizer_variables, grads, train_state.trainable_variables\n    )\n\n    return loss_value, TrainingState(\n        trainable_variables, non_trainable_variables, optimizer_variables\n    )\n\n\n# training loop\nEPOCHS = 5\nprint(\"\\nTraining:\")\ndata_iter = iter(train_data)\nfor epoch in range(EPOCHS):\n    loss_value = None  
# default\n    for i in tqdm(range(STEPS_PER_EPOCH)):\n        x, y = next(data_iter)\n        sharded_x = jax.device_put(x.numpy(), data_sharding)\n        loss_value, device_train_state = train_step(\n            device_train_state, sharded_x, y.numpy()\n        )\n    print(\"Epoch\", epoch, \"loss:\", loss_value)\n\n# The output of the model is still sharded. Sharding follows the data.\n\ndata, labels = next(iter(eval_data))\nsharded_data = jax.device_put(data.numpy(), data_sharding)\n\n\n@jax.jit\ndef predict(data):\n    predictions, updated_non_trainable_variables = model.stateless_call(\n        device_train_state.trainable_variables,\n        device_train_state.non_trainable_variables,\n        data,\n    )\n    return predictions\n\n\npredictions = predict(sharded_data)\nprint(\"\\nModel output sharding follows data sharding:\")\njax.debug.visualize_array_sharding(predictions)\n\n# Post-processing: write the updated state back into the model variables\n# (use jax.tree_util.tree_map; the jax.tree_map alias is deprecated and\n# removed in newer JAX releases)\nupdate = lambda variable, value: variable.assign(value)\n\njax.tree_util.tree_map(\n    update, model.trainable_variables, device_train_state.trainable_variables\n)\njax.tree_util.tree_map(\n    update,\n    model.non_trainable_variables,\n    device_train_state.non_trainable_variables,\n)\njax.tree_util.tree_map(\n    update, optimizer.variables, device_train_state.optimizer_variables\n)\n\n# check that the model has the new state by running an eval\n# known issue: the optimizer should not be required here\nmodel.compile(\n    loss=keras.losses.SparseCategoricalCrossentropy(),\n    metrics=[keras.metrics.SparseCategoricalAccuracy()],\n)\nprint(\"\\nUpdating model and running an eval:\")\nloss, accuracy = model.evaluate(eval_data)\nprint(\"The model achieved an evaluation accuracy of:\", accuracy)\n"
  },
  {
    "path": "examples/demo_mnist_convnet.py",
    "content": "import numpy as np\nimport keras\nfrom keras import layers\nfrom keras.utils import to_categorical\n\n# Model / data parameters\nnum_classes = 10\ninput_shape = (28, 28, 1)\n\n# Load the data and split it between train and test sets\n(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n\n# Scale images to the [0, 1] range\nx_train = x_train.astype(\"float32\") / 255\nx_test = x_test.astype(\"float32\") / 255\n# Make sure images have shape (28, 28, 1)\nx_train = np.expand_dims(x_train, -1)\nx_test = np.expand_dims(x_test, -1)\nprint(\"x_train shape:\", x_train.shape)\nprint(x_train.shape[0], \"train samples\")\nprint(x_test.shape[0], \"test samples\")\n\n\n# convert class vectors to binary class matrices\ny_train = to_categorical(y_train, num_classes)\ny_test = to_categorical(y_test, num_classes)\n\nbatch_size = 128\nepochs = 3\n\nmodel = keras.Sequential(\n    [\n        layers.Input(shape=input_shape),\n        layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n        layers.MaxPooling2D(pool_size=(2, 2)),\n        layers.Conv2D(64, kernel_size=(3, 3), activation=\"relu\"),\n        layers.MaxPooling2D(pool_size=(2, 2)),\n        layers.Flatten(),\n        layers.Dropout(0.5),\n        layers.Dense(num_classes, activation=\"softmax\"),\n    ]\n)\n\nmodel.summary()\n\nmodel.compile(\n    loss=\"categorical_crossentropy\", optimizer=\"adam\", metrics=[\"accuracy\"]\n)\n\nmodel.fit(\n    x_train, y_train, batch_size=batch_size, epochs=epochs, validation_split=0.1\n)\n\nscore = model.evaluate(x_test, y_test, verbose=0)\nprint(\"Test loss:\", score[0])\nprint(\"Test accuracy:\", score[1])\n"
  },
  {
    "path": "examples/demo_subclass.py",
    "content": "import numpy as np\n\nfrom keras import Model\nfrom keras import layers\nfrom keras import losses\nfrom keras import metrics\nfrom keras import optimizers\n\n\nclass MyModel(Model):\n    def __init__(self, hidden_dim, output_dim):\n        super().__init__()\n        self.dense1 = layers.Dense(hidden_dim, activation=\"relu\")\n        self.dense2 = layers.Dense(hidden_dim, activation=\"relu\")\n        self.dense3 = layers.Dense(output_dim)\n\n    def call(self, x):\n        x = self.dense1(x)\n        x = self.dense2(x)\n        return self.dense3(x)\n\n\nmodel = MyModel(hidden_dim=256, output_dim=16)\n\nx = np.random.random((50000, 128))\ny = np.random.random((50000, 16))\nbatch_size = 32\nepochs = 6\n\nmodel.compile(\n    optimizer=optimizers.SGD(learning_rate=0.001),\n    loss=losses.MeanSquaredError(),\n    metrics=[metrics.MeanSquaredError()],\n)\nhistory = model.fit(\n    x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2\n)\n\nprint(\"History:\")\nprint(history.history)\n\nmodel.summary()\n"
  },
  {
    "path": "examples/demo_torch_multi_gpu.py",
    "content": "# flake8: noqa\nimport os\n\n# Set backend env to torch\nos.environ[\"KERAS_BACKEND\"] = \"torch\"\n\nimport torch\nimport torch.nn as nn\nimport torch.optim as optim\nfrom keras import layers\nimport keras\nimport numpy as np\n\nimport torch.multiprocessing as mp\nimport torch.distributed as dist\nfrom torch.nn.parallel import DistributedDataParallel as DDP\nfrom torch.utils.data import DataLoader, Dataset\nfrom torch.utils.data.distributed import DistributedSampler\n\n# Model / data parameters\nnum_classes = 10\ninput_shape = (28, 28, 1)\nlearning_rate = 0.01\nbatch_size = 128\nnum_epochs = 1\n\n\ndef get_data():\n    # Load the data and split it between train and test sets\n    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n\n    # Scale images to the [0, 1] range\n    x_train = x_train.astype(\"float32\") / 255\n    x_test = x_test.astype(\"float32\") / 255\n    # Make sure images have shape (28, 28, 1)\n    x_train = np.expand_dims(x_train, -1)\n    x_test = np.expand_dims(x_test, -1)\n    print(\"x_train shape:\", x_train.shape)\n    print(x_train.shape[0], \"train samples\")\n    print(x_test.shape[0], \"test samples\")\n\n    # Create a TensorDataset. Cast the labels to int64: nn.CrossEntropyLoss\n    # expects long-typed class-index targets, and MNIST labels load as uint8.\n    dataset = torch.utils.data.TensorDataset(\n        torch.from_numpy(x_train), torch.from_numpy(y_train).long()\n    )\n    return dataset\n\n\ndef get_model():\n    # Create the Keras model\n    model = keras.Sequential(\n        [\n            layers.Input(shape=(28, 28, 1)),\n            layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n            layers.MaxPooling2D(pool_size=(2, 2)),\n            layers.Conv2D(64, kernel_size=(3, 3), activation=\"relu\"),\n            layers.MaxPooling2D(pool_size=(2, 2)),\n            layers.Flatten(),\n            layers.Dropout(0.5),\n            layers.Dense(num_classes),\n        ]\n    )\n    return model\n\n\nclass MyModel(nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.model = 
keras.Sequential(\n            [\n                layers.Input(shape=(28, 28, 1)),\n                layers.Conv2D(32, kernel_size=(3, 3), activation=\"relu\"),\n                layers.MaxPooling2D(pool_size=(2, 2)),\n                layers.Conv2D(64, kernel_size=(3, 3), activation=\"relu\"),\n                layers.MaxPooling2D(pool_size=(2, 2)),\n                layers.Flatten(),\n                layers.Dropout(0.5),\n                layers.Dense(num_classes),\n            ]\n        )\n\n    def forward(self, x):\n        return self.model(x)\n\n\ndef train(model, train_loader, num_epochs, optimizer, loss_fn):\n    for epoch in range(num_epochs):\n        running_loss = 0.0\n        for batch_idx, (inputs, targets) in enumerate(train_loader):\n            inputs = inputs.cuda(non_blocking=True)\n            targets = targets.cuda(non_blocking=True)\n\n            # Forward pass\n            outputs = model(inputs)\n            loss = loss_fn(outputs, targets)\n\n            # Backward and optimize\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            running_loss += loss.item()\n\n            # Print loss statistics\n            if (batch_idx + 1) % 10 == 0:\n                print(\n                    f\"Epoch [{epoch + 1}/{num_epochs}], \"\n                    f\"Batch [{batch_idx + 1}/{len(train_loader)}], \"\n                    f\"Loss: {running_loss / 10}\"\n                )\n                running_loss = 0.0\n\n\ndef setup(current_gpu_index, num_gpu):\n    # Device setup\n    os.environ[\"MASTER_ADDR\"] = \"localhost\"\n    os.environ[\"MASTER_PORT\"] = \"56492\"\n    device = torch.device(\"cuda:{}\".format(current_gpu_index))\n    dist.init_process_group(\n        backend=\"nccl\",\n        init_method=\"env://\",\n        world_size=num_gpu,\n        rank=current_gpu_index,\n    )\n    torch.cuda.set_device(device)\n\n\ndef prepare(dataset, current_gpu_index, num_gpu, batch_size):\n    sampler = 
DistributedSampler(\n        dataset,\n        num_replicas=num_gpu,\n        rank=current_gpu_index,\n        shuffle=False,\n    )\n\n    # Create a DataLoader\n    train_loader = DataLoader(\n        dataset,\n        sampler=sampler,\n        batch_size=batch_size,\n        shuffle=False,\n    )\n\n    return train_loader\n\n\ndef cleanup():\n    # Cleanup\n    dist.destroy_process_group()\n\n\ndef main(current_gpu_index, num_gpu):\n    # setup the process groups\n    setup(current_gpu_index, num_gpu)\n\n    #################################################################\n    ######## Writing a torch training loop for a Keras model ########\n    #################################################################\n\n    dataset = get_data()\n    model = get_model()\n\n    # prepare the dataloader\n    dataloader = prepare(dataset, current_gpu_index, num_gpu, batch_size)\n\n    # Instantiate the torch optimizer\n    optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n\n    # Instantiate the torch loss function\n    loss_fn = nn.CrossEntropyLoss()\n\n    # Put model on device\n    model = model.to(current_gpu_index)\n    ddp_model = DDP(\n        model, device_ids=[current_gpu_index], output_device=current_gpu_index\n    )\n\n    train(ddp_model, dataloader, num_epochs, optimizer, loss_fn)\n\n    ################################################################\n    ######## Using a Keras model or layer in a torch Module ########\n    ################################################################\n\n    torch_module = MyModel().to(current_gpu_index)\n    ddp_torch_module = DDP(\n        torch_module,\n        device_ids=[current_gpu_index],\n        output_device=current_gpu_index,\n    )\n\n    # Instantiate the torch optimizer\n    optimizer = optim.Adam(torch_module.parameters(), lr=learning_rate)\n\n    # Instantiate the torch loss function\n    loss_fn = nn.CrossEntropyLoss()\n\n    train(ddp_torch_module, dataloader, num_epochs, optimizer, 
loss_fn)\n\n    cleanup()\n\n\nif __name__ == \"__main__\":\n    # GPU parameters\n    num_gpu = torch.cuda.device_count()\n\n    print(f\"Running on {num_gpu} GPUs\")\n\n    torch.multiprocessing.spawn(\n        main,\n        args=(num_gpu,),\n        nprocs=num_gpu,\n        join=True,\n    )\n"
  },
  {
    "path": "guides/custom_train_step_in_jax.py",
    "content": "\"\"\"\nTitle: Customizing what happens in `fit()` with JAX\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2023/06/27\nLast modified: 2023/06/27\nDescription: Overriding the training step of the Model class with JAX.\nAccelerator: GPU\n\"\"\"\n\n\"\"\"\n## Introduction\n\nWhen you're doing supervised learning, you can use `fit()` and everything works\nsmoothly.\n\nWhen you need to take control of every little detail, you can write your own training\nloop entirely from scratch.\n\nBut what if you need a custom training algorithm, but you still want to benefit from\nthe convenient features of `fit()`, such as callbacks, built-in distribution support,\nor step fusing?\n\nA core principle of Keras is **progressive disclosure of complexity**. You should\nalways be able to get into lower-level workflows in a gradual way. You shouldn't fall\noff a cliff if the high-level functionality doesn't exactly match your use case. You\nshould be able to gain more control over the small details while retaining a\ncommensurate amount of high-level convenience.\n\nWhen you need to customize what `fit()` does, you should **override the training step\nfunction of the `Model` class**. This is the function that is called by `fit()` for\nevery batch of data. You will then be able to call `fit()` as usual -- and it will be\nrunning your own learning algorithm.\n\nNote that this pattern does not prevent you from building models with the Functional\nAPI. 
You can do this whether you're building `Sequential` models, Functional API\nmodels, or subclassed models.\n\nLet's see how that works.\n\"\"\"\n\n\"\"\"\n## Setup\n\"\"\"\n\nimport os\n\n# This guide can only be run with the JAX backend.\nos.environ[\"KERAS_BACKEND\"] = \"jax\"\n\nimport jax\nimport keras\nimport numpy as np\n\n\"\"\"\n## A first simple example\n\nLet's start from a simple example:\n\n- We create a new class that subclasses `keras.Model`.\n- We implement a fully-stateless `compute_loss_and_updates()` method\nto compute the loss as well as the updated values for the non-trainable\nvariables of the model. Internally, it calls `stateless_call()` and\nthe built-in `compute_loss()`.\n- We implement a fully-stateless `train_step()` method to compute current\nmetric values (including the loss) as well as updated values for the \ntrainable variables, the optimizer variables, and the metric variables.\n\nNote that you can also take into account the `sample_weight` argument by:\n\n- Unpacking the data as `x, y, sample_weight = data`\n- Passing `sample_weight` to `compute_loss()`\n- Passing `sample_weight` alongside `y` and `y_pred`\nto metrics in `stateless_update_state()`\n\"\"\"\n\n\nclass CustomModel(keras.Model):\n    def compute_loss_and_updates(\n        self,\n        trainable_variables,\n        non_trainable_variables,\n        x,\n        y,\n        training=False,\n    ):\n        y_pred, non_trainable_variables = self.stateless_call(\n            trainable_variables,\n            non_trainable_variables,\n            x,\n            training=training,\n        )\n        loss = self.compute_loss(x, y, y_pred)\n        return loss, (y_pred, non_trainable_variables)\n\n    def train_step(self, state, data):\n        (\n            trainable_variables,\n            non_trainable_variables,\n            optimizer_variables,\n            metrics_variables,\n        ) = state\n        x, y = data\n\n        # Get the gradient function.\n        
grad_fn = jax.value_and_grad(\n            self.compute_loss_and_updates, has_aux=True\n        )\n\n        # Compute the gradients.\n        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(\n            trainable_variables,\n            non_trainable_variables,\n            x,\n            y,\n            training=True,\n        )\n\n        # Update trainable variables and optimizer variables.\n        (\n            trainable_variables,\n            optimizer_variables,\n        ) = self.optimizer.stateless_apply(\n            optimizer_variables, grads, trainable_variables\n        )\n\n        # Update metrics.\n        new_metrics_vars, logs = [], []\n        for metric in self.metrics:\n            this_metric_vars = metrics_variables[\n                len(new_metrics_vars) : len(new_metrics_vars)\n                + len(metric.variables)\n            ]\n            if metric.name == \"loss\":\n                this_metric_vars = metric.stateless_update_state(\n                    this_metric_vars, loss\n                )\n            else:\n                this_metric_vars = metric.stateless_update_state(\n                    this_metric_vars, y, y_pred\n                )\n            logs = metric.stateless_result(this_metric_vars)\n            new_metrics_vars += this_metric_vars\n\n        # Return metric logs and updated state variables.\n        state = (\n            trainable_variables,\n            non_trainable_variables,\n            optimizer_variables,\n            new_metrics_vars,\n        )\n        return logs, state\n\n\n\"\"\"\nLet's try this out:\n\"\"\"\n\n# Construct and compile an instance of CustomModel\ninputs = keras.Input(shape=(32,))\noutputs = keras.layers.Dense(1)(inputs)\nmodel = CustomModel(inputs, outputs)\nmodel.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n\n# Just use `fit` as usual\nx = np.random.random((1000, 32))\ny = np.random.random((1000, 1))\nmodel.fit(x, y, epochs=3)\n\n\n\"\"\"\n## Going 
lower-level\n\nNaturally, you could just skip passing a loss function in `compile()`, and instead do\neverything *manually* in `train_step`. Likewise for metrics.\n\nHere's a lower-level example that only uses `compile()` to configure the optimizer:\n\"\"\"\n\n\nclass CustomModel(keras.Model):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.loss_tracker = keras.metrics.Mean(name=\"loss\")\n        self.mae_metric = keras.metrics.MeanAbsoluteError(name=\"mae\")\n        self.loss_fn = keras.losses.MeanSquaredError()\n\n    def compute_loss_and_updates(\n        self,\n        trainable_variables,\n        non_trainable_variables,\n        x,\n        y,\n        training=False,\n    ):\n        y_pred, non_trainable_variables = self.stateless_call(\n            trainable_variables,\n            non_trainable_variables,\n            x,\n            training=training,\n        )\n        loss = self.loss_fn(y, y_pred)\n        return loss, (y_pred, non_trainable_variables)\n\n    def train_step(self, state, data):\n        (\n            trainable_variables,\n            non_trainable_variables,\n            optimizer_variables,\n            metrics_variables,\n        ) = state\n        x, y = data\n\n        # Get the gradient function.\n        grad_fn = jax.value_and_grad(\n            self.compute_loss_and_updates, has_aux=True\n        )\n\n        # Compute the gradients.\n        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(\n            trainable_variables,\n            non_trainable_variables,\n            x,\n            y,\n            training=True,\n        )\n\n        # Update trainable variables and optimizer variables.\n        (\n            trainable_variables,\n            optimizer_variables,\n        ) = self.optimizer.stateless_apply(\n            optimizer_variables, grads, trainable_variables\n        )\n\n        # Update metrics.\n        loss_tracker_vars = 
metrics_variables[\n            : len(self.loss_tracker.variables)\n        ]\n        mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]\n\n        loss_tracker_vars = self.loss_tracker.stateless_update_state(\n            loss_tracker_vars, loss\n        )\n        mae_metric_vars = self.mae_metric.stateless_update_state(\n            mae_metric_vars, y, y_pred\n        )\n\n        logs = {}\n        logs[self.loss_tracker.name] = self.loss_tracker.stateless_result(\n            loss_tracker_vars\n        )\n        logs[self.mae_metric.name] = self.mae_metric.stateless_result(\n            mae_metric_vars\n        )\n\n        new_metrics_vars = loss_tracker_vars + mae_metric_vars\n\n        # Return metric logs and updated state variables.\n        state = (\n            trainable_variables,\n            non_trainable_variables,\n            optimizer_variables,\n            new_metrics_vars,\n        )\n        return logs, state\n\n    @property\n    def metrics(self):\n        # We list our `Metric` objects here so that `reset_states()` can be\n        # called automatically at the start of each epoch\n        # or at the start of `evaluate()`.\n        return [self.loss_tracker, self.mae_metric]\n\n\n# Construct an instance of CustomModel\ninputs = keras.Input(shape=(32,))\noutputs = keras.layers.Dense(1)(inputs)\nmodel = CustomModel(inputs, outputs)\n\n# We don't pass a loss or metrics here.\nmodel.compile(optimizer=\"adam\")\n\n# Just use `fit` as usual -- you can use callbacks, etc.\nx = np.random.random((1000, 32))\ny = np.random.random((1000, 1))\nmodel.fit(x, y, epochs=5)\n\n\n\"\"\"\n## Providing your own evaluation step\n\nWhat if you want to do the same for calls to `model.evaluate()`? Then you would\noverride `test_step` in exactly the same way. 
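In `evaluate()` no weights are updated, so the `state` tuple passed to `test_step` is a 3-tuple, without the optimizer variables. A tiny sketch of the unpacking, with hypothetical placeholder values standing in for the real variable lists:

```python
# Hypothetical placeholders standing in for lists of backend tensors.
state = (['w'], ['bn_mean'], ['metric_count'])
trainable_variables, non_trainable_variables, metrics_variables = state
assert len(state) == 3  # no optimizer_variables at evaluation time
```
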
Here's what it looks like:\n\"\"\"\n\n\nclass CustomModel(keras.Model):\n    def test_step(self, state, data):\n        # Unpack the data.\n        x, y = data\n        (\n            trainable_variables,\n            non_trainable_variables,\n            metrics_variables,\n        ) = state\n\n        # Compute predictions and loss.\n        y_pred, non_trainable_variables = self.stateless_call(\n            trainable_variables,\n            non_trainable_variables,\n            x,\n            training=False,\n        )\n        loss = self.compute_loss(x, y, y_pred)\n\n        # Update metrics.\n        new_metrics_vars, logs = [], {}\n        for metric in self.metrics:\n            this_metric_vars = metrics_variables[\n                len(new_metrics_vars) : len(new_metrics_vars)\n                + len(metric.variables)\n            ]\n            if metric.name == \"loss\":\n                this_metric_vars = metric.stateless_update_state(\n                    this_metric_vars, loss\n                )\n            else:\n                this_metric_vars = metric.stateless_update_state(\n                    this_metric_vars, y, y_pred\n                )\n            logs[metric.name] = metric.stateless_result(this_metric_vars)\n            new_metrics_vars += this_metric_vars\n\n        # Return metric logs and updated state variables.\n        state = (\n            trainable_variables,\n            non_trainable_variables,\n            new_metrics_vars,\n        )\n        return logs, state\n\n\n# Construct an instance of CustomModel\ninputs = keras.Input(shape=(32,))\noutputs = keras.layers.Dense(1)(inputs)\nmodel = CustomModel(inputs, outputs)\nmodel.compile(loss=\"mse\", metrics=[\"mae\"])\n\n# Evaluate with our custom test_step\nx = np.random.random((1000, 32))\ny = np.random.random((1000, 1))\nmodel.evaluate(x, y)\n\n\n\"\"\"\nThat's it!\n\"\"\"\n"
  },
  {
    "path": "guides/custom_train_step_in_tensorflow.py",
    "content": "\"\"\"\nTitle: Customizing what happens in `fit()` with TensorFlow\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2020/04/15\nLast modified: 2023/06/27\nDescription: Overriding the training step of the Model class with TensorFlow.\nAccelerator: GPU\n\"\"\"\n\n\"\"\"\n## Introduction\n\nWhen you're doing supervised learning, you can use `fit()` and everything works\nsmoothly.\n\nWhen you need to take control of every little detail, you can write your own training\nloop entirely from scratch.\n\nBut what if you need a custom training algorithm, but you still want to benefit from\nthe convenient features of `fit()`, such as callbacks, built-in distribution support,\nor step fusing?\n\nA core principle of Keras is **progressive disclosure of complexity**. You should\nalways be able to get into lower-level workflows in a gradual way. You shouldn't fall\noff a cliff if the high-level functionality doesn't exactly match your use case. You\nshould be able to gain more control over the small details while retaining a\ncommensurate amount of high-level convenience.\n\nWhen you need to customize what `fit()` does, you should **override the training step\nfunction of the `Model` class**. This is the function that is called by `fit()` for\nevery batch of data. You will then be able to call `fit()` as usual -- and it will be\nrunning your own learning algorithm.\n\nNote that this pattern does not prevent you from building models with the Functional\nAPI. 
You can do this whether you're building `Sequential` models, Functional API\nmodels, or subclassed models.\n\nLet's see how that works.\n\"\"\"\n\n\"\"\"\n## Setup\n\"\"\"\n\nimport os\n\n# This guide can only be run with the TF backend.\nos.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n\nimport tensorflow as tf\nimport keras\nfrom keras import layers\nimport numpy as np\n\n\"\"\"\n## A first simple example\n\nLet's start from a simple example:\n\n- We create a new class that subclasses `keras.Model`.\n- We just override the method `train_step(self, data)`.\n- We return a dictionary mapping metric names (including the loss) to their current\nvalue.\n\nThe input argument `data` is what gets passed to fit as training data:\n\n- If you pass NumPy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple\n`(x, y)`\n- If you pass a `tf.data.Dataset`, by calling `fit(dataset, ...)`, then `data` will be\nwhat gets yielded by `dataset` at each batch.\n\nIn the body of the `train_step()` method, we implement a regular training update,\nsimilar to what you are already familiar with. Importantly, **we compute the loss via\n`self.compute_loss()`**, which wraps the loss(es) function(s) that were passed to\n`compile()`.\n\nSimilarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`,\nto update the state of the metrics that were passed in `compile()`,\nand we query results from `self.metrics` at the end to retrieve their current value.\n\"\"\"\n\n\nclass CustomModel(keras.Model):\n    def train_step(self, data):\n        # Unpack the data. 
Its structure depends on your model and\n        # on what you pass to `fit()`.\n        x, y = data\n\n        with tf.GradientTape() as tape:\n            y_pred = self(x, training=True)  # Forward pass\n            # Compute the loss value\n            # (the loss function is configured in `compile()`)\n            loss = self.compute_loss(y=y, y_pred=y_pred)\n\n        # Compute gradients\n        trainable_vars = self.trainable_variables\n        gradients = tape.gradient(loss, trainable_vars)\n\n        # Update weights\n        self.optimizer.apply(gradients, trainable_vars)\n\n        # Update metrics (includes the metric that tracks the loss)\n        for metric in self.metrics:\n            if metric.name == \"loss\":\n                metric.update_state(loss)\n            else:\n                metric.update_state(y, y_pred)\n\n        # Return a dict mapping metric names to current value\n        return {m.name: m.result() for m in self.metrics}\n\n\n\"\"\"\nLet's try this out:\n\"\"\"\n\n# Construct and compile an instance of CustomModel\ninputs = keras.Input(shape=(32,))\noutputs = keras.layers.Dense(1)(inputs)\nmodel = CustomModel(inputs, outputs)\nmodel.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n\n# Just use `fit` as usual\nx = np.random.random((1000, 32))\ny = np.random.random((1000, 1))\nmodel.fit(x, y, epochs=3)\n\n\"\"\"\n## Going lower-level\n\nNaturally, you could just skip passing a loss function in `compile()`, and instead do\neverything *manually* in `train_step`. 
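Computing the loss manually just means calling a loss object directly instead of `self.compute_loss()`. What `MeanSquaredError` computes can be sketched framework-free (an illustrative stand-in, not the Keras implementation):

```python
import numpy as np

# Framework-free sketch of mean squared error (illustrative only):
def mse(y_true, y_pred):
    return float(np.mean(np.square(y_true - y_pred)))

assert mse(np.array([1.0, 0.0]), np.array([0.5, 0.5])) == 0.25
```
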
Likewise for metrics.\n\nHere's a lower-level example that only uses `compile()` to configure the optimizer:\n\n- We start by creating `Metric` instances to track our loss and a MAE score (in `__init__()`).\n- We implement a custom `train_step()` that updates the state of these metrics\n(by calling `update_state()` on them), then query them (via `result()`) to return their current average value,\nto be displayed by the progress bar and to be passed to any callback.\n- Note that we would need to call `reset_states()` on our metrics between each epoch! Otherwise\ncalling `result()` would return an average since the start of training, whereas we usually work\nwith per-epoch averages. Thankfully, the framework can do that for us: just list any metric\nyou want to reset in the `metrics` property of the model. The model will call `reset_states()`\non any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to\n`evaluate()`.\n\"\"\"\n\n\nclass CustomModel(keras.Model):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.loss_tracker = keras.metrics.Mean(name=\"loss\")\n        self.mae_metric = keras.metrics.MeanAbsoluteError(name=\"mae\")\n        self.loss_fn = keras.losses.MeanSquaredError()\n\n    def train_step(self, data):\n        x, y = data\n\n        with tf.GradientTape() as tape:\n            y_pred = self(x, training=True)  # Forward pass\n            # Compute our own loss\n            loss = self.loss_fn(y, y_pred)\n\n        # Compute gradients\n        trainable_vars = self.trainable_variables\n        gradients = tape.gradient(loss, trainable_vars)\n\n        # Update weights\n        self.optimizer.apply(gradients, trainable_vars)\n\n        # Compute our own metrics\n        self.loss_tracker.update_state(loss)\n        self.mae_metric.update_state(y, y_pred)\n        return {\n            \"loss\": self.loss_tracker.result(),\n            \"mae\": 
self.mae_metric.result(),\n        }\n\n    @property\n    def metrics(self):\n        # We list our `Metric` objects here so that `reset_states()` can be\n        # called automatically at the start of each epoch\n        # or at the start of `evaluate()`.\n        return [self.loss_tracker, self.mae_metric]\n\n\n# Construct an instance of CustomModel\ninputs = keras.Input(shape=(32,))\noutputs = keras.layers.Dense(1)(inputs)\nmodel = CustomModel(inputs, outputs)\n\n# We don't pass a loss or metrics here.\nmodel.compile(optimizer=\"adam\")\n\n# Just use `fit` as usual -- you can use callbacks, etc.\nx = np.random.random((1000, 32))\ny = np.random.random((1000, 1))\nmodel.fit(x, y, epochs=5)\n\n\n\"\"\"\n## Supporting `sample_weight` & `class_weight`\n\nYou may have noticed that our first basic example didn't make any mention of sample\nweighting. If you want to support the `fit()` arguments `sample_weight` and\n`class_weight`, you'd simply do the following:\n\n- Unpack `sample_weight` from the `data` argument\n- Pass it to `compute_loss` & `update_state` (of course, you could also just apply\nit manually if you don't rely on `compile()` for losses & metrics)\n- That's it.\n\"\"\"\n\n\nclass CustomModel(keras.Model):\n    def train_step(self, data):\n        # Unpack the data. 
Its structure depends on your model and\n        # on what you pass to `fit()`.\n        if len(data) == 3:\n            x, y, sample_weight = data\n        else:\n            sample_weight = None\n            x, y = data\n\n        with tf.GradientTape() as tape:\n            y_pred = self(x, training=True)  # Forward pass\n            # Compute the loss value.\n            # The loss function is configured in `compile()`.\n            loss = self.compute_loss(\n                y=y,\n                y_pred=y_pred,\n                sample_weight=sample_weight,\n            )\n\n        # Compute gradients\n        trainable_vars = self.trainable_variables\n        gradients = tape.gradient(loss, trainable_vars)\n\n        # Update weights\n        self.optimizer.apply(gradients, trainable_vars)\n\n        # Update the metrics.\n        # Metrics are configured in `compile()`.\n        for metric in self.metrics:\n            if metric.name == \"loss\":\n                metric.update_state(loss)\n            else:\n                metric.update_state(y, y_pred, sample_weight=sample_weight)\n\n        # Return a dict mapping metric names to current value.\n        # Note that it will include the loss (tracked in self.metrics).\n        return {m.name: m.result() for m in self.metrics}\n\n\n# Construct and compile an instance of CustomModel\ninputs = keras.Input(shape=(32,))\noutputs = keras.layers.Dense(1)(inputs)\nmodel = CustomModel(inputs, outputs)\nmodel.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n\n# You can now use sample_weight argument\nx = np.random.random((1000, 32))\ny = np.random.random((1000, 1))\nsw = np.random.random((1000, 1))\nmodel.fit(x, y, sample_weight=sw, epochs=3)\n\n\"\"\"\n## Providing your own evaluation step\n\nWhat if you want to do the same for calls to `model.evaluate()`? Then you would\noverride `test_step` in exactly the same way. 
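The contract is the same as for `train_step`: return a dict mapping metric names to their current values, where each value is a running average over the batches seen so far. A framework-free sketch of what `evaluate()` ends up reporting (hypothetical numbers):

```python
# Hypothetical per-batch logs, as successive test_step calls would return them.
# Each dict already holds running averages over the batches seen so far.
batch_logs = [{'loss': 0.5, 'mae': 0.7}, {'loss': 0.3, 'mae': 0.5}]
final_logs = batch_logs[-1]  # what evaluate() ends up reporting
assert final_logs == {'loss': 0.3, 'mae': 0.5}
```
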
Here's what it looks like:\n\"\"\"\n\n\nclass CustomModel(keras.Model):\n    def test_step(self, data):\n        # Unpack the data\n        x, y = data\n        # Compute predictions\n        y_pred = self(x, training=False)\n        # Updates the metrics tracking the loss\n        loss = self.compute_loss(y=y, y_pred=y_pred)\n        # Update the metrics.\n        for metric in self.metrics:\n            if metric.name == \"loss\":\n                metric.update_state(loss)\n            else:\n                metric.update_state(y, y_pred)\n        # Return a dict mapping metric names to current value.\n        # Note that it will include the loss (tracked in self.metrics).\n        return {m.name: m.result() for m in self.metrics}\n\n\n# Construct an instance of CustomModel\ninputs = keras.Input(shape=(32,))\noutputs = keras.layers.Dense(1)(inputs)\nmodel = CustomModel(inputs, outputs)\nmodel.compile(loss=\"mse\", metrics=[\"mae\"])\n\n# Evaluate with our custom test_step\nx = np.random.random((1000, 32))\ny = np.random.random((1000, 1))\nmodel.evaluate(x, y)\n\n\"\"\"\n## Wrapping up: an end-to-end GAN example\n\nLet's walk through an end-to-end example that leverages everything you just learned.\n\nLet's consider:\n\n- A generator network meant to generate 28x28x1 images.\n- A discriminator network meant to classify 28x28x1 images into two classes (\"fake\" and\n\"real\").\n- One optimizer for each.\n- A loss function to train the discriminator.\n\"\"\"\n\n# Create the discriminator\ndiscriminator = keras.Sequential(\n    [\n        keras.Input(shape=(28, 28, 1)),\n        layers.Conv2D(64, (3, 3), strides=(2, 2), padding=\"same\"),\n        layers.LeakyReLU(negative_slope=0.2),\n        layers.Conv2D(128, (3, 3), strides=(2, 2), padding=\"same\"),\n        layers.LeakyReLU(negative_slope=0.2),\n        layers.GlobalMaxPooling2D(),\n        layers.Dense(1),\n    ],\n    name=\"discriminator\",\n)\n\n# Create the generator\nlatent_dim = 128\ngenerator = 
keras.Sequential(\n    [\n        keras.Input(shape=(latent_dim,)),\n        # We want to generate 128 coefficients to reshape into a 7x7x128 map\n        layers.Dense(7 * 7 * 128),\n        layers.LeakyReLU(negative_slope=0.2),\n        layers.Reshape((7, 7, 128)),\n        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n        layers.LeakyReLU(negative_slope=0.2),\n        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n        layers.LeakyReLU(negative_slope=0.2),\n        layers.Conv2D(1, (7, 7), padding=\"same\", activation=\"sigmoid\"),\n    ],\n    name=\"generator\",\n)\n\n\"\"\"\nHere's a feature-complete GAN class, overriding `compile()` to use its own signature,\nand implementing the entire GAN algorithm in 17 lines in `train_step`:\n\"\"\"\n\n\nclass GAN(keras.Model):\n    def __init__(self, discriminator, generator, latent_dim):\n        super().__init__()\n        self.discriminator = discriminator\n        self.generator = generator\n        self.latent_dim = latent_dim\n        self.d_loss_tracker = keras.metrics.Mean(name=\"d_loss\")\n        self.g_loss_tracker = keras.metrics.Mean(name=\"g_loss\")\n        self.seed_generator = keras.random.SeedGenerator(1337)\n\n    @property\n    def metrics(self):\n        return [self.d_loss_tracker, self.g_loss_tracker]\n\n    def compile(self, d_optimizer, g_optimizer, loss_fn):\n        super().compile()\n        self.d_optimizer = d_optimizer\n        self.g_optimizer = g_optimizer\n        self.loss_fn = loss_fn\n\n    def train_step(self, real_images):\n        if isinstance(real_images, tuple):\n            real_images = real_images[0]\n        # Sample random points in the latent space\n        batch_size = tf.shape(real_images)[0]\n        random_latent_vectors = keras.random.normal(\n            shape=(batch_size, self.latent_dim), seed=self.seed_generator\n        )\n\n        # Decode them to fake images\n        generated_images = 
self.generator(random_latent_vectors)\n\n        # Combine them with real images\n        combined_images = tf.concat([generated_images, real_images], axis=0)\n\n        # Assemble labels discriminating real from fake images\n        labels = tf.concat(\n            [tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0\n        )\n        # Add random noise to the labels - important trick!\n        labels += 0.05 * keras.random.uniform(\n            tf.shape(labels), seed=self.seed_generator\n        )\n\n        # Train the discriminator\n        with tf.GradientTape() as tape:\n            predictions = self.discriminator(combined_images)\n            d_loss = self.loss_fn(labels, predictions)\n        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)\n        self.d_optimizer.apply(grads, self.discriminator.trainable_weights)\n\n        # Sample random points in the latent space\n        random_latent_vectors = keras.random.normal(\n            shape=(batch_size, self.latent_dim), seed=self.seed_generator\n        )\n\n        # Assemble labels that say \"all real images\"\n        misleading_labels = tf.zeros((batch_size, 1))\n\n        # Train the generator (note that we should *not* update the weights\n        # of the discriminator)!\n        with tf.GradientTape() as tape:\n            predictions = self.discriminator(\n                self.generator(random_latent_vectors)\n            )\n            g_loss = self.loss_fn(misleading_labels, predictions)\n        grads = tape.gradient(g_loss, self.generator.trainable_weights)\n        self.g_optimizer.apply(grads, self.generator.trainable_weights)\n\n        # Update metrics and return their value.\n        self.d_loss_tracker.update_state(d_loss)\n        self.g_loss_tracker.update_state(g_loss)\n        return {\n            \"d_loss\": self.d_loss_tracker.result(),\n            \"g_loss\": self.g_loss_tracker.result(),\n        }\n\n\n\"\"\"\nLet's test-drive it:\n\"\"\"\n\n# 
Prepare the dataset. We use both the training & test MNIST digits.\nbatch_size = 64\n(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\nall_digits = np.concatenate([x_train, x_test])\nall_digits = all_digits.astype(\"float32\") / 255.0\nall_digits = np.reshape(all_digits, (-1, 28, 28, 1))\ndataset = tf.data.Dataset.from_tensor_slices(all_digits)\ndataset = dataset.shuffle(buffer_size=1024).batch(batch_size)\n\ngan = GAN(\n    discriminator=discriminator, generator=generator, latent_dim=latent_dim\n)\ngan.compile(\n    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),\n    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),\n    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),\n)\n\n# To limit the execution time, we only train on 100 batches. You can train on\n# the entire dataset. You will need about 20 epochs to get nice results.\ngan.fit(dataset.take(100), epochs=1)\n\n\"\"\"\nThe ideas behind deep learning are simple, so why should their implementation be painful?\n\"\"\"\n"
  },
  {
    "path": "guides/custom_train_step_in_torch.py",
    "content": "\"\"\"\nTitle: Customizing what happens in `fit()` with PyTorch\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2023/06/27\nLast modified: 2023/06/27\nDescription: Overriding the training step of the Model class with PyTorch.\nAccelerator: GPU\n\"\"\"\n\n\"\"\"\n## Introduction\n\nWhen you're doing supervised learning, you can use `fit()` and everything works\nsmoothly.\n\nWhen you need to take control of every little detail, you can write your own training\nloop entirely from scratch.\n\nBut what if you need a custom training algorithm, but you still want to benefit from\nthe convenient features of `fit()`, such as callbacks, built-in distribution support,\nor step fusing?\n\nA core principle of Keras is **progressive disclosure of complexity**. You should\nalways be able to get into lower-level workflows in a gradual way. You shouldn't fall\noff a cliff if the high-level functionality doesn't exactly match your use case. You\nshould be able to gain more control over the small details while retaining a\ncommensurate amount of high-level convenience.\n\nWhen you need to customize what `fit()` does, you should **override the training step\nfunction of the `Model` class**. This is the function that is called by `fit()` for\nevery batch of data. You will then be able to call `fit()` as usual -- and it will be\nrunning your own learning algorithm.\n\nNote that this pattern does not prevent you from building models with the Functional\nAPI. 
You can do this whether you're building `Sequential` models, Functional API\nmodels, or subclassed models.\n\nLet's see how that works.\n\"\"\"\n\n\"\"\"\n## Setup\n\"\"\"\n\nimport os\n\n# This guide can only be run with the torch backend.\nos.environ[\"KERAS_BACKEND\"] = \"torch\"\n\nimport torch\nimport keras\nfrom keras import layers\nimport numpy as np\n\n\"\"\"\n## A first simple example\n\nLet's start from a simple example:\n\n- We create a new class that subclasses `keras.Model`.\n- We just override the method `train_step(self, data)`.\n- We return a dictionary mapping metric names (including the loss) to their current\nvalue.\n\nThe input argument `data` is what gets passed to fit as training data:\n\n- If you pass NumPy arrays, by calling `fit(x, y, ...)`, then `data` will be the tuple\n`(x, y)`\n- If you pass a `torch.utils.data.DataLoader` or a `tf.data.Dataset`,\nby calling `fit(dataset, ...)`, then `data` will be what gets yielded\nby `dataset` at each batch.\n\nIn the body of the `train_step()` method, we implement a regular training update,\nsimilar to what you are already familiar with. Importantly, **we compute the loss via\n`self.compute_loss()`**, which wraps the loss(es) function(s) that were passed to\n`compile()`.\n\nSimilarly, we call `metric.update_state(y, y_pred)` on metrics from `self.metrics`,\nto update the state of the metrics that were passed in `compile()`,\nand we query results from `self.metrics` at the end to retrieve their current value.\n\"\"\"\n\n\nclass CustomModel(keras.Model):\n    def train_step(self, data):\n        # Unpack the data. 
Its structure depends on your model and\n        # on what you pass to `fit()`.\n        x, y = data\n\n        # Call torch.nn.Module.zero_grad() to clear the leftover gradients\n        # for the weights from the previous train step.\n        self.zero_grad()\n\n        # Compute loss\n        y_pred = self(x, training=True)  # Forward pass\n        loss = self.compute_loss(y=y, y_pred=y_pred)\n\n        # Call torch.Tensor.backward() on the loss to compute gradients\n        # for the weights.\n        loss.backward()\n\n        trainable_weights = [v for v in self.trainable_weights]\n        gradients = [v.value.grad for v in trainable_weights]\n\n        # Update weights\n        with torch.no_grad():\n            self.optimizer.apply(gradients, trainable_weights)\n\n        # Update metrics (includes the metric that tracks the loss)\n        for metric in self.metrics:\n            if metric.name == \"loss\":\n                metric.update_state(loss)\n            else:\n                metric.update_state(y, y_pred)\n\n        # Return a dict mapping metric names to current value\n        # Note that it will include the loss (tracked in self.metrics).\n        return {m.name: m.result() for m in self.metrics}\n\n\n\"\"\"\nLet's try this out:\n\"\"\"\n\n# Construct and compile an instance of CustomModel\ninputs = keras.Input(shape=(32,))\noutputs = keras.layers.Dense(1)(inputs)\nmodel = CustomModel(inputs, outputs)\nmodel.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n\n# Just use `fit` as usual\nx = np.random.random((1000, 32))\ny = np.random.random((1000, 1))\nmodel.fit(x, y, epochs=3)\n\n\"\"\"\n## Going lower-level\n\nNaturally, you could just skip passing a loss function in `compile()`, and instead do\neverything *manually* in `train_step`. 
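Tracking the loss manually relies on a running mean over batches. Conceptually, a minimal framework-free stand-in for `keras.metrics.Mean` looks like this (illustrative only, not the real implementation):

```python
# Minimal framework-free stand-in for keras.metrics.Mean (illustrative only).
class MeanTracker:
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        self.total += float(value)
        self.count += 1

    def result(self):
        return self.total / max(self.count, 1)

tracker = MeanTracker()
for batch_loss in [0.9, 0.6, 0.3]:
    tracker.update_state(batch_loss)
assert abs(tracker.result() - 0.6) < 1e-9
```

This is why the metric must be reset between epochs: without a reset, `result()` keeps averaging over every batch since the start of training.
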
Likewise for metrics.\n\nHere's a lower-level example that only uses `compile()` to configure the optimizer:\n\n- We start by creating `Metric` instances to track our loss and a MAE score (in `__init__()`).\n- We implement a custom `train_step()` that updates the state of these metrics\n(by calling `update_state()` on them), then query them (via `result()`) to return their current average value,\nto be displayed by the progress bar and to be passed to any callback.\n- Note that we would need to call `reset_states()` on our metrics between each epoch! Otherwise\ncalling `result()` would return an average since the start of training, whereas we usually work\nwith per-epoch averages. Thankfully, the framework can do that for us: just list any metric\nyou want to reset in the `metrics` property of the model. The model will call `reset_states()`\non any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to\n`evaluate()`.\n\"\"\"\n\n\nclass CustomModel(keras.Model):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.loss_tracker = keras.metrics.Mean(name=\"loss\")\n        self.mae_metric = keras.metrics.MeanAbsoluteError(name=\"mae\")\n        self.loss_fn = keras.losses.MeanSquaredError()\n\n    def train_step(self, data):\n        x, y = data\n\n        # Call torch.nn.Module.zero_grad() to clear the leftover gradients\n        # for the weights from the previous train step.\n        self.zero_grad()\n\n        # Compute loss\n        y_pred = self(x, training=True)  # Forward pass\n        loss = self.loss_fn(y, y_pred)\n\n        # Call torch.Tensor.backward() on the loss to compute gradients\n        # for the weights.\n        loss.backward()\n\n        trainable_weights = [v for v in self.trainable_weights]\n        gradients = [v.value.grad for v in trainable_weights]\n\n        # Update weights\n        with torch.no_grad():\n            self.optimizer.apply(gradients, 
trainable_weights)\n\n        # Compute our own metrics\n        self.loss_tracker.update_state(loss)\n        self.mae_metric.update_state(y, y_pred)\n        return {\n            \"loss\": self.loss_tracker.result(),\n            \"mae\": self.mae_metric.result(),\n        }\n\n    @property\n    def metrics(self):\n        # We list our `Metric` objects here so that `reset_states()` can be\n        # called automatically at the start of each epoch\n        # or at the start of `evaluate()`.\n        return [self.loss_tracker, self.mae_metric]\n\n\n# Construct an instance of CustomModel\ninputs = keras.Input(shape=(32,))\noutputs = keras.layers.Dense(1)(inputs)\nmodel = CustomModel(inputs, outputs)\n\n# We don't pass a loss or metrics here.\nmodel.compile(optimizer=\"adam\")\n\n# Just use `fit` as usual -- you can use callbacks, etc.\nx = np.random.random((1000, 32))\ny = np.random.random((1000, 1))\nmodel.fit(x, y, epochs=5)\n\n\n\"\"\"\n## Supporting `sample_weight` & `class_weight`\n\nYou may have noticed that our first basic example didn't make any mention of sample\nweighting. If you want to support the `fit()` arguments `sample_weight` and\n`class_weight`, you'd simply do the following:\n\n- Unpack `sample_weight` from the `data` argument\n- Pass it to `compute_loss` & `update_state` (of course, you could also just apply\nit manually if you don't rely on `compile()` for losses & metrics)\n- That's it.\n\"\"\"\n\n\nclass CustomModel(keras.Model):\n    def train_step(self, data):\n        # Unpack the data. 
Its structure depends on your model and\n        # on what you pass to `fit()`.\n        if len(data) == 3:\n            x, y, sample_weight = data\n        else:\n            sample_weight = None\n            x, y = data\n\n        # Call torch.nn.Module.zero_grad() to clear the leftover gradients\n        # for the weights from the previous train step.\n        self.zero_grad()\n\n        # Compute loss\n        y_pred = self(x, training=True)  # Forward pass\n        loss = self.compute_loss(\n            y=y,\n            y_pred=y_pred,\n            sample_weight=sample_weight,\n        )\n\n        # Call torch.Tensor.backward() on the loss to compute gradients\n        # for the weights.\n        loss.backward()\n\n        trainable_weights = [v for v in self.trainable_weights]\n        gradients = [v.value.grad for v in trainable_weights]\n\n        # Update weights\n        with torch.no_grad():\n            self.optimizer.apply(gradients, trainable_weights)\n\n        # Update metrics (includes the metric that tracks the loss)\n        for metric in self.metrics:\n            if metric.name == \"loss\":\n                metric.update_state(loss)\n            else:\n                metric.update_state(y, y_pred, sample_weight=sample_weight)\n\n        # Return a dict mapping metric names to current value\n        # Note that it will include the loss (tracked in self.metrics).\n        return {m.name: m.result() for m in self.metrics}\n\n\n# Construct and compile an instance of CustomModel\ninputs = keras.Input(shape=(32,))\noutputs = keras.layers.Dense(1)(inputs)\nmodel = CustomModel(inputs, outputs)\nmodel.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n\n# You can now use sample_weight argument\nx = np.random.random((1000, 32))\ny = np.random.random((1000, 1))\nsw = np.random.random((1000, 1))\nmodel.fit(x, y, sample_weight=sw, epochs=3)\n\n\"\"\"\n## Providing your own evaluation step\n\nWhat if you want to do the same for calls to 
`model.evaluate()`? Then you would\noverride `test_step` in exactly the same way. Here's what it looks like:\n\"\"\"\n\n\nclass CustomModel(keras.Model):\n    def test_step(self, data):\n        # Unpack the data\n        x, y = data\n        # Compute predictions\n        y_pred = self(x, training=False)\n        # Updates the metrics tracking the loss\n        loss = self.compute_loss(y=y, y_pred=y_pred)\n        # Update the metrics.\n        for metric in self.metrics:\n            if metric.name == \"loss\":\n                metric.update_state(loss)\n            else:\n                metric.update_state(y, y_pred)\n        # Return a dict mapping metric names to current value.\n        # Note that it will include the loss (tracked in self.metrics).\n        return {m.name: m.result() for m in self.metrics}\n\n\n# Construct an instance of CustomModel\ninputs = keras.Input(shape=(32,))\noutputs = keras.layers.Dense(1)(inputs)\nmodel = CustomModel(inputs, outputs)\nmodel.compile(loss=\"mse\", metrics=[\"mae\"])\n\n# Evaluate with our custom test_step\nx = np.random.random((1000, 32))\ny = np.random.random((1000, 1))\nmodel.evaluate(x, y)\n\n\"\"\"\n## Wrapping up: an end-to-end GAN example\n\nLet's walk through an end-to-end example that leverages everything you just learned.\n\nLet's consider:\n\n- A generator network meant to generate 28x28x1 images.\n- A discriminator network meant to classify 28x28x1 images into two classes (\"fake\" and\n\"real\").\n- One optimizer for each.\n- A loss function to train the discriminator.\n\"\"\"\n\n# Create the discriminator\ndiscriminator = keras.Sequential(\n    [\n        keras.Input(shape=(28, 28, 1)),\n        layers.Conv2D(64, (3, 3), strides=(2, 2), padding=\"same\"),\n        layers.LeakyReLU(negative_slope=0.2),\n        layers.Conv2D(128, (3, 3), strides=(2, 2), padding=\"same\"),\n        layers.LeakyReLU(negative_slope=0.2),\n        layers.GlobalMaxPooling2D(),\n        layers.Dense(1),\n    ],\n    
name=\"discriminator\",\n)\n\n# Create the generator\nlatent_dim = 128\ngenerator = keras.Sequential(\n    [\n        keras.Input(shape=(latent_dim,)),\n        # We want to generate 128 coefficients to reshape into a 7x7x128 map\n        layers.Dense(7 * 7 * 128),\n        layers.LeakyReLU(negative_slope=0.2),\n        layers.Reshape((7, 7, 128)),\n        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n        layers.LeakyReLU(negative_slope=0.2),\n        layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding=\"same\"),\n        layers.LeakyReLU(negative_slope=0.2),\n        layers.Conv2D(1, (7, 7), padding=\"same\", activation=\"sigmoid\"),\n    ],\n    name=\"generator\",\n)\n\n\"\"\"\nHere's a feature-complete GAN class, overriding `compile()` to use its own signature,\nand implementing the entire GAN algorithm in 17 lines in `train_step`:\n\"\"\"\n\n\nclass GAN(keras.Model):\n    def __init__(self, discriminator, generator, latent_dim):\n        super().__init__()\n        self.discriminator = discriminator\n        self.generator = generator\n        self.latent_dim = latent_dim\n        self.d_loss_tracker = keras.metrics.Mean(name=\"d_loss\")\n        self.g_loss_tracker = keras.metrics.Mean(name=\"g_loss\")\n        self.seed_generator = keras.random.SeedGenerator(1337)\n        self.built = True\n\n    @property\n    def metrics(self):\n        return [self.d_loss_tracker, self.g_loss_tracker]\n\n    def compile(self, d_optimizer, g_optimizer, loss_fn):\n        super().compile()\n        self.d_optimizer = d_optimizer\n        self.g_optimizer = g_optimizer\n        self.loss_fn = loss_fn\n\n    def train_step(self, real_images):\n        if isinstance(real_images, tuple):\n            real_images = real_images[0]\n        # Sample random points in the latent space\n        batch_size = real_images.shape[0]\n        random_latent_vectors = keras.random.normal(\n            shape=(batch_size, self.latent_dim), 
seed=self.seed_generator\n        )\n\n        # Decode them to fake images\n        generated_images = self.generator(random_latent_vectors)\n\n        # Combine them with real images\n        real_images = torch.tensor(real_images)\n        combined_images = torch.concat([generated_images, real_images], axis=0)\n\n        # Assemble labels discriminating real from fake images\n        labels = torch.concat(\n            [torch.ones((batch_size, 1)), torch.zeros((batch_size, 1))], axis=0\n        )\n        # Add random noise to the labels - important trick!\n        labels += 0.05 * keras.random.uniform(\n            labels.shape, seed=self.seed_generator\n        )\n\n        # Train the discriminator\n        self.zero_grad()\n        predictions = self.discriminator(combined_images)\n        d_loss = self.loss_fn(labels, predictions)\n        d_loss.backward()\n        grads = [v.value.grad for v in self.discriminator.trainable_weights]\n        with torch.no_grad():\n            self.d_optimizer.apply(grads, self.discriminator.trainable_weights)\n\n        # Sample random points in the latent space\n        random_latent_vectors = keras.random.normal(\n            shape=(batch_size, self.latent_dim), seed=self.seed_generator\n        )\n\n        # Assemble labels that say \"all real images\"\n        misleading_labels = torch.zeros((batch_size, 1))\n\n        # Train the generator (note that we should *not* update the weights\n        # of the discriminator)!\n        self.zero_grad()\n        predictions = self.discriminator(self.generator(random_latent_vectors))\n        g_loss = self.loss_fn(misleading_labels, predictions)\n        # `backward()` returns None; the gradients are read off the weights below.\n        g_loss.backward()\n        grads = [v.value.grad for v in self.generator.trainable_weights]\n        with torch.no_grad():\n            self.g_optimizer.apply(grads, self.generator.trainable_weights)\n\n        # Update metrics and return their value.\n        self.d_loss_tracker.update_state(d_loss)\n        
self.g_loss_tracker.update_state(g_loss)\n        return {\n            \"d_loss\": self.d_loss_tracker.result(),\n            \"g_loss\": self.g_loss_tracker.result(),\n        }\n\n\n\"\"\"\nLet's test-drive it:\n\"\"\"\n\n# Prepare the dataset. We use both the training & test MNIST digits.\nbatch_size = 64\n(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\nall_digits = np.concatenate([x_train, x_test])\nall_digits = all_digits.astype(\"float32\") / 255.0\nall_digits = np.reshape(all_digits, (-1, 28, 28, 1))\n\n# Create a TensorDataset\ndataset = torch.utils.data.TensorDataset(\n    torch.from_numpy(all_digits), torch.from_numpy(all_digits)\n)\n# Create a DataLoader\ndataloader = torch.utils.data.DataLoader(\n    dataset, batch_size=batch_size, shuffle=True\n)\n\ngan = GAN(\n    discriminator=discriminator, generator=generator, latent_dim=latent_dim\n)\ngan.compile(\n    d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),\n    g_optimizer=keras.optimizers.Adam(learning_rate=0.0003),\n    loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),\n)\n\ngan.fit(dataloader, epochs=1)\n\n\"\"\"\nThe ideas behind deep learning are simple, so why should their implementation be painful?\n\"\"\"\n"
  },
  {
    "path": "guides/distributed_training_with_jax.py",
"content": "\"\"\"\nTitle: Multi-GPU distributed training with JAX\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2023/07/11\nLast modified: 2023/07/11\nDescription: Guide to multi-GPU/TPU training for Keras models with JAX.\nAccelerator: GPU\n\"\"\"\n\n\"\"\"\n## Introduction\n\nThere are generally two ways to distribute computation across multiple devices:\n\n**Data parallelism**, where a single model gets replicated on multiple devices or\nmultiple machines. Each of them processes different batches of data, then they merge\ntheir results. There exist many variants of this setup, which differ in how the different\nmodel replicas merge results, in whether they stay in sync at every batch or whether they\nare more loosely coupled, etc.\n\n**Model parallelism**, where different parts of a single model run on different devices,\nprocessing a single batch of data together. This works best with models that have a\nnaturally-parallel architecture, such as models that feature multiple branches.\n\nThis guide focuses on data parallelism, in particular **synchronous data parallelism**,\nwhere the different replicas of the model stay in sync after each batch they process.\nSynchronicity keeps the model convergence behavior identical to what you would see for\nsingle-device training.\n\nSpecifically, this guide teaches you how to use `jax.sharding` APIs to train Keras\nmodels, with minimal changes to your code, on multiple GPUs or TPUs (typically 2 to 16)\ninstalled on a single machine (single host, multi-device training). 
This is the\nmost common setup for researchers and small-scale industry workflows.\n\"\"\"\n\n\"\"\"\n## Setup\n\nLet's start by defining the function that creates the model that we will train,\nand the function that creates the dataset we will train on (MNIST in this case).\n\"\"\"\n\nimport os\n\nos.environ[\"KERAS_BACKEND\"] = \"jax\"\n\nimport jax\nimport numpy as np\nimport tensorflow as tf\nimport keras\n\nfrom jax.experimental import mesh_utils\nfrom jax.sharding import Mesh\nfrom jax.sharding import NamedSharding\nfrom jax.sharding import PartitionSpec as P\n\n\ndef get_model():\n    # Make a simple convnet with batch normalization and dropout.\n    inputs = keras.Input(shape=(28, 28, 1))\n    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)\n    x = keras.layers.Conv2D(\n        filters=12, kernel_size=3, padding=\"same\", use_bias=False\n    )(x)\n    x = keras.layers.BatchNormalization(scale=False, center=True)(x)\n    x = keras.layers.ReLU()(x)\n    x = keras.layers.Conv2D(\n        filters=24,\n        kernel_size=6,\n        use_bias=False,\n        strides=2,\n    )(x)\n    x = keras.layers.BatchNormalization(scale=False, center=True)(x)\n    x = keras.layers.ReLU()(x)\n    x = keras.layers.Conv2D(\n        filters=32,\n        kernel_size=6,\n        padding=\"same\",\n        strides=2,\n        name=\"large_k\",\n    )(x)\n    x = keras.layers.BatchNormalization(scale=False, center=True)(x)\n    x = keras.layers.ReLU()(x)\n    x = keras.layers.GlobalAveragePooling2D()(x)\n    x = keras.layers.Dense(256, activation=\"relu\")(x)\n    x = keras.layers.Dropout(0.5)(x)\n    outputs = keras.layers.Dense(10)(x)\n    model = keras.Model(inputs, outputs)\n    return model\n\n\ndef get_datasets():\n    # Load the data and split it between train and test sets\n    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n\n    # Scale images to the [0, 1] range\n    x_train = x_train.astype(\"float32\")\n    x_test = 
x_test.astype(\"float32\")\n    # Make sure images have shape (28, 28, 1)\n    x_train = np.expand_dims(x_train, -1)\n    x_test = np.expand_dims(x_test, -1)\n    print(\"x_train shape:\", x_train.shape)\n    print(x_train.shape[0], \"train samples\")\n    print(x_test.shape[0], \"test samples\")\n\n    # Create TF Datasets\n    train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n    eval_data = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n    return train_data, eval_data\n\n\n\"\"\"\n## Single-host, multi-device synchronous training\n\nIn this setup, you have one machine with several GPUs or TPUs on it (typically 2 to 16).\nEach device will run a copy of your model (called a **replica**). For simplicity, in\nwhat follows, we'll assume we're dealing with 8 GPUs, at no loss of generality.\n\n**How it works**\n\nAt each step of training:\n\n- The current batch of data (called **global batch**) is split into 8 different\n  sub-batches (called **local batches**). For instance, if the global batch has 512\n  samples, each of the 8 local batches will have 64 samples.\n- Each of the 8 replicas independently processes a local batch: they run a forward pass,\n  then a backward pass, outputting the gradient of the weights with respect to the loss of\n  the model on the local batch.\n- The weight updates originating from local gradients are efficiently merged across the 8\n  replicas. Because this is done at the end of every step, the replicas always stay in\n  sync.\n\nIn practice, the process of synchronously updating the weights of the model replicas is\nhandled at the level of each individual weight variable. This is done using\na `jax.sharding.NamedSharding` that is configured to replicate the variables.\n\n**How to use it**\n\nTo do single-host, multi-device synchronous training with a Keras model, you\nwould use the `jax.sharding` features. 
Here's how it works:\n\n- We first create a device mesh using `mesh_utils.create_device_mesh`.\n- We use `jax.sharding.Mesh`, `jax.sharding.NamedSharding` and\n  `jax.sharding.PartitionSpec` to define how to partition JAX arrays.\n    - We specify that we want to replicate the model and optimizer variables\n      across all devices by using a spec with no axis.\n    - We specify that we want to shard the data across devices by using a spec\n      that splits along the batch dimension.\n- We use `jax.device_put` to replicate the model and optimizer variables across\n  devices. This happens once at the beginning.\n- In the training loop, for each batch that we process, we use `jax.device_put`\n  to split the batch across devices before invoking the train step.\n\nHere's the flow, where each step is split into its own utility function:\n\"\"\"\n\n# Config\nnum_epochs = 2\nbatch_size = 64\n\ntrain_data, eval_data = get_datasets()\ntrain_data = train_data.batch(batch_size, drop_remainder=True)\n\nmodel = get_model()\noptimizer = keras.optimizers.Adam(1e-3)\nloss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n\n# Initialize all state with .build()\n(one_batch, one_batch_labels) = next(iter(train_data))\nmodel.build(one_batch)\noptimizer.build(model.trainable_variables)\n\n\n# This is the loss function that will be differentiated.\n# Keras provides a pure functional forward pass: model.stateless_call\ndef compute_loss(trainable_variables, non_trainable_variables, x, y):\n    y_pred, updated_non_trainable_variables = model.stateless_call(\n        trainable_variables, non_trainable_variables, x\n    )\n    loss_value = loss(y, y_pred)\n    return loss_value, updated_non_trainable_variables\n\n\n# Function to compute gradients\ncompute_gradients = jax.value_and_grad(compute_loss, has_aux=True)\n\n\n# Training step, Keras provides a pure functional optimizer.stateless_apply\n@jax.jit\ndef train_step(train_state, x, y):\n    (\n        trainable_variables,\n  
      non_trainable_variables,\n        optimizer_variables,\n    ) = train_state\n    (loss_value, non_trainable_variables), grads = compute_gradients(\n        trainable_variables, non_trainable_variables, x, y\n    )\n\n    trainable_variables, optimizer_variables = optimizer.stateless_apply(\n        optimizer_variables, grads, trainable_variables\n    )\n\n    return loss_value, (\n        trainable_variables,\n        non_trainable_variables,\n        optimizer_variables,\n    )\n\n\n# Replicate the model and optimizer variable on all devices\ndef get_replicated_train_state(devices):\n    # All variables will be replicated on all devices\n    var_mesh = Mesh(devices, axis_names=(\"_\"))\n    # In NamedSharding, axes not mentioned are replicated (all axes here)\n    var_replication = NamedSharding(var_mesh, P())\n\n    # Apply the distribution settings to the model variables\n    trainable_variables = jax.device_put(\n        model.trainable_variables, var_replication\n    )\n    non_trainable_variables = jax.device_put(\n        model.non_trainable_variables, var_replication\n    )\n    optimizer_variables = jax.device_put(optimizer.variables, var_replication)\n\n    # Combine all state in a tuple\n    return (trainable_variables, non_trainable_variables, optimizer_variables)\n\n\nnum_devices = len(jax.local_devices())\nprint(f\"Running on {num_devices} devices: {jax.local_devices()}\")\ndevices = mesh_utils.create_device_mesh((num_devices,))\n\n# Data will be split along the batch axis\ndata_mesh = Mesh(devices, axis_names=(\"batch\",))  # naming axes of the mesh\ndata_sharding = NamedSharding(\n    data_mesh,\n    P(\n        \"batch\",\n    ),\n)  # naming axes of the sharded partition\n\n# Display data sharding\nx, y = next(iter(train_data))\nsharded_x = jax.device_put(x.numpy(), data_sharding)\nprint(\"Data sharding\")\njax.debug.visualize_array_sharding(jax.numpy.reshape(sharded_x, [-1, 28 * 28]))\n\ntrain_state = 
get_replicated_train_state(devices)\n\n# Custom training loop\nfor epoch in range(num_epochs):\n    data_iter = iter(train_data)\n    loss_value = None  # default\n    for data in data_iter:\n        x, y = data\n        sharded_x = jax.device_put(x.numpy(), data_sharding)\n        loss_value, train_state = train_step(train_state, sharded_x, y.numpy())\n    print(\"Epoch\", epoch, \"loss:\", loss_value)\n\n# Post-processing: write the updated state back into the model variables\ntrainable_variables, non_trainable_variables, optimizer_variables = train_state\nfor variable, value in zip(model.trainable_variables, trainable_variables):\n    variable.assign(value)\nfor variable, value in zip(\n    model.non_trainable_variables, non_trainable_variables\n):\n    variable.assign(value)\n\n\"\"\"\nThat's it!\n\"\"\"\n"
  },
  {
    "path": "guides/distributed_training_with_tensorflow.py",
"content": "\"\"\"\nTitle: Multi-GPU distributed training with TensorFlow\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2020/04/28\nLast modified: 2023/06/29\nDescription: Guide to multi-GPU training for Keras models with TensorFlow.\nAccelerator: GPU\n\"\"\"\n\n\"\"\"\n## Introduction\n\nThere are generally two ways to distribute computation across multiple devices:\n\n**Data parallelism**, where a single model gets replicated on multiple devices or\nmultiple machines. Each of them processes different batches of data, then they merge\ntheir results. There exist many variants of this setup, which differ in how the different\nmodel replicas merge results, in whether they stay in sync at every batch or whether they\nare more loosely coupled, etc.\n\n**Model parallelism**, where different parts of a single model run on different devices,\nprocessing a single batch of data together. This works best with models that have a\nnaturally-parallel architecture, such as models that feature multiple branches.\n\nThis guide focuses on data parallelism, in particular **synchronous data parallelism**,\nwhere the different replicas of the model stay in sync after each batch they process.\nSynchronicity keeps the model convergence behavior identical to what you would see for\nsingle-device training.\n\nSpecifically, this guide teaches you how to use the `tf.distribute` API to train Keras\nmodels, with minimal changes to your code,\non multiple GPUs (typically 2 to 16) installed on a single machine (single host,\nmulti-device training). This is the most common setup for researchers and small-scale\nindustry workflows.\n\"\"\"\n\n\"\"\"\n## Setup\n\"\"\"\n\nimport os\n\nos.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n\nimport tensorflow as tf\nimport keras\n\n\"\"\"\n## Single-host, multi-device synchronous training\n\nIn this setup, you have one machine with several GPUs on it (typically 2 to 16). 
Each\ndevice will run a copy of your model (called a **replica**). For simplicity, in what\nfollows, we'll assume we're dealing with 8 GPUs, at no loss of generality.\n\n**How it works**\n\nAt each step of training:\n\n- The current batch of data (called **global batch**) is split into 8 different\nsub-batches (called **local batches**). For instance, if the global batch has 512\nsamples, each of the 8 local batches will have 64 samples.\n- Each of the 8 replicas independently processes a local batch: they run a forward pass,\nthen a backward pass, outputting the gradient of the weights with respect to the loss of\nthe model on the local batch.\n- The weight updates originating from local gradients are efficiently merged across the 8\nreplicas. Because this is done at the end of every step, the replicas always stay in\nsync.\n\nIn practice, the process of synchronously updating the weights of the model replicas is\nhandled at the level of each individual weight variable. This is done through a **mirrored\nvariable** object.\n\n**How to use it**\n\nTo do single-host, multi-device synchronous training with a Keras model, you would use\nthe [`tf.distribute.MirroredStrategy` API](\n    https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy).\nHere's how it works:\n\n- Instantiate a `MirroredStrategy`, optionally configuring which specific devices you\nwant to use (by default the strategy will use all GPUs available).\n- Use the strategy object to open a scope, and within this scope, create all the Keras\nobjects you need that contain variables. Typically, that means **creating & compiling the\nmodel** inside the distribution scope. 
In some cases, the first call to `fit()` may also\ncreate variables, so it's a good idea to put your `fit()` call in the scope as well.\n- Train the model via `fit()` as usual.\n\nImportantly, we recommend that you use `tf.data.Dataset` objects to load data\nin a multi-device or distributed workflow.\n\nSchematically, it looks like this:\n\n```python\n# Create a MirroredStrategy.\nstrategy = tf.distribute.MirroredStrategy()\nprint('Number of devices: {}'.format(strategy.num_replicas_in_sync))\n\n# Open a strategy scope.\nwith strategy.scope():\n    # Everything that creates variables should be under the strategy scope.\n    # In general this is only model construction & `compile()`.\n    model = Model(...)\n    model.compile(...)\n\n    # Train the model on all available devices.\n    model.fit(train_dataset, validation_data=val_dataset, ...)\n\n    # Test the model on all available devices.\n    model.evaluate(test_dataset)\n```\n\nHere's a simple end-to-end runnable example:\n\"\"\"\n\n\ndef get_compiled_model():\n    # Make a simple 2-layer densely-connected neural network.\n    inputs = keras.Input(shape=(784,))\n    x = keras.layers.Dense(256, activation=\"relu\")(inputs)\n    x = keras.layers.Dense(256, activation=\"relu\")(x)\n    outputs = keras.layers.Dense(10)(x)\n    model = keras.Model(inputs, outputs)\n    model.compile(\n        optimizer=keras.optimizers.Adam(),\n        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n        metrics=[keras.metrics.SparseCategoricalAccuracy()],\n    )\n    return model\n\n\ndef get_dataset():\n    batch_size = 32\n    num_val_samples = 10000\n\n    # Return the MNIST dataset in the form of a `tf.data.Dataset`.\n    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n\n    # Preprocess the data (these are Numpy arrays)\n    x_train = x_train.reshape(-1, 784).astype(\"float32\") / 255\n    x_test = x_test.reshape(-1, 784).astype(\"float32\") / 255\n    y_train = 
y_train.astype(\"float32\")\n    y_test = y_test.astype(\"float32\")\n\n    # Reserve num_val_samples samples for validation\n    x_val = x_train[-num_val_samples:]\n    y_val = y_train[-num_val_samples:]\n    x_train = x_train[:-num_val_samples]\n    y_train = y_train[:-num_val_samples]\n    return (\n        tf.data.Dataset.from_tensor_slices((x_train, y_train)).batch(\n            batch_size\n        ),\n        tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size),\n        tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size),\n    )\n\n\n# Create a MirroredStrategy.\nstrategy = tf.distribute.MirroredStrategy()\nprint(\"Number of devices: {}\".format(strategy.num_replicas_in_sync))\n\n# Open a strategy scope.\nwith strategy.scope():\n    # Everything that creates variables should be under the strategy scope.\n    # In general this is only model construction & `compile()`.\n    model = get_compiled_model()\n\n    # Train the model on all available devices.\n    train_dataset, val_dataset, test_dataset = get_dataset()\n    model.fit(train_dataset, epochs=2, validation_data=val_dataset)\n\n    # Test the model on all available devices.\n    model.evaluate(test_dataset)\n\n\"\"\"\n## Using callbacks to ensure fault tolerance\n\nWhen using distributed training, you should always make sure you have a strategy to\nrecover from failure (fault tolerance). The simplest way to handle this is to pass\n`ModelCheckpoint` callback to `fit()`, to save your model\nat regular intervals (e.g. every 100 batches or every epoch). 
You can then restart\ntraining from your saved model.\n\nHere's a simple example:\n\"\"\"\n\n# Prepare a directory to store all the checkpoints.\ncheckpoint_dir = \"./ckpt\"\nif not os.path.exists(checkpoint_dir):\n    os.makedirs(checkpoint_dir)\n\n\ndef make_or_restore_model():\n    # Either restore the latest model, or create a fresh one\n    # if there is no checkpoint available.\n    checkpoints = [\n        os.path.join(checkpoint_dir, name)\n        for name in os.listdir(checkpoint_dir)\n    ]\n    if checkpoints:\n        latest_checkpoint = max(checkpoints, key=os.path.getctime)\n        print(\"Restoring from\", latest_checkpoint)\n        return keras.models.load_model(latest_checkpoint)\n    print(\"Creating a new model\")\n    return get_compiled_model()\n\n\ndef run_training(epochs=1):\n    # Create a MirroredStrategy.\n    strategy = tf.distribute.MirroredStrategy()\n\n    # Open a strategy scope and create/restore the model\n    with strategy.scope():\n        model = make_or_restore_model()\n\n        callbacks = [\n            # This callback saves the model every epoch.\n            # We include the current epoch in the file name.\n            keras.callbacks.ModelCheckpoint(\n                filepath=os.path.join(checkpoint_dir, \"ckpt-{epoch}.keras\"),\n                save_freq=\"epoch\",\n            )\n        ]\n        model.fit(\n            train_dataset,\n            epochs=epochs,\n            callbacks=callbacks,\n            validation_data=val_dataset,\n            verbose=2,\n        )\n\n\n# Running the first time creates the model\nrun_training(epochs=1)\n\n# Calling the same function again will resume from where we left off\nrun_training(epochs=1)\n\n\"\"\"\n## `tf.data` performance tips\n\nWhen doing distributed training, the efficiency with which you load data can often become\ncritical. 
Here are a few tips to make sure your `tf.data` pipelines\nrun as fast as possible.\n\n**Note about dataset batching**\n\nWhen creating your dataset, make sure it is batched with the global batch size.\nFor instance, if each of your 8 GPUs is capable of running a batch of 64 samples, you\ncan use a global batch size of 512.\n\n**Calling `dataset.cache()`**\n\nIf you call `.cache()` on a dataset, its data will be cached after running through the\nfirst iteration over the data. Every subsequent iteration will use the cached data. The\ncache can be in memory (default) or in a local file you specify.\n\nThis can improve performance when:\n\n- Your data is not expected to change from iteration to iteration\n- You are reading data from a remote distributed filesystem\n- You are reading data from local disk, but your data would fit in memory and your\nworkflow is significantly IO-bound (e.g. reading & decoding image files).\n\n**Calling `dataset.prefetch(buffer_size)`**\n\nYou should almost always call `.prefetch(buffer_size)` after creating a dataset. It means\nyour data pipeline will run asynchronously from your model,\nwith new samples being preprocessed and stored in a buffer while the current batch\nsamples are used to train the model. The next batch will be prefetched in GPU memory by\nthe time the current batch is over.\n\"\"\"\n\n\"\"\"\nThat's it!\n\"\"\"\n"
  },
  {
    "path": "guides/distributed_training_with_torch.py",
"content": "\"\"\"\nTitle: Multi-GPU distributed training with PyTorch\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2023/06/29\nLast modified: 2023/06/29\nDescription: Guide to multi-GPU training for Keras models with PyTorch.\nAccelerator: GPU\n\"\"\"\n\n\"\"\"\n## Introduction\n\nThere are generally two ways to distribute computation across multiple devices:\n\n**Data parallelism**, where a single model gets replicated on multiple devices or\nmultiple machines. Each of them processes different batches of data, then they merge\ntheir results. There exist many variants of this setup, which differ in how the different\nmodel replicas merge results, in whether they stay in sync at every batch or whether they\nare more loosely coupled, etc.\n\n**Model parallelism**, where different parts of a single model run on different devices,\nprocessing a single batch of data together. This works best with models that have a\nnaturally-parallel architecture, such as models that feature multiple branches.\n\nThis guide focuses on data parallelism, in particular **synchronous data parallelism**,\nwhere the different replicas of the model stay in sync after each batch they process.\nSynchronicity keeps the model convergence behavior identical to what you would see for\nsingle-device training.\n\nSpecifically, this guide teaches you how to use PyTorch's `DistributedDataParallel`\nmodule wrapper to train Keras models, with minimal changes to your code,\non multiple GPUs (typically 2 to 16) installed on a single machine (single host,\nmulti-device training). 
This is the most common setup for researchers and small-scale\nindustry workflows.\n\"\"\"\n\n\"\"\"\n## Setup\n\nLet's start by defining the function that creates the model that we will train,\nand the function that creates the dataset we will train on (MNIST in this case).\n\"\"\"\n\nimport os\n\nos.environ[\"KERAS_BACKEND\"] = \"torch\"\n\nimport torch\nimport numpy as np\nimport keras\n\n\ndef get_model():\n    # Make a simple convnet with batch normalization and dropout.\n    inputs = keras.Input(shape=(28, 28, 1))\n    x = keras.layers.Rescaling(1.0 / 255.0)(inputs)\n    x = keras.layers.Conv2D(\n        filters=12, kernel_size=3, padding=\"same\", use_bias=False\n    )(x)\n    x = keras.layers.BatchNormalization(scale=False, center=True)(x)\n    x = keras.layers.ReLU()(x)\n    x = keras.layers.Conv2D(\n        filters=24,\n        kernel_size=6,\n        use_bias=False,\n        strides=2,\n    )(x)\n    x = keras.layers.BatchNormalization(scale=False, center=True)(x)\n    x = keras.layers.ReLU()(x)\n    x = keras.layers.Conv2D(\n        filters=32,\n        kernel_size=6,\n        padding=\"same\",\n        strides=2,\n        name=\"large_k\",\n    )(x)\n    x = keras.layers.BatchNormalization(scale=False, center=True)(x)\n    x = keras.layers.ReLU()(x)\n    x = keras.layers.GlobalAveragePooling2D()(x)\n    x = keras.layers.Dense(256, activation=\"relu\")(x)\n    x = keras.layers.Dropout(0.5)(x)\n    outputs = keras.layers.Dense(10)(x)\n    model = keras.Model(inputs, outputs)\n    return model\n\n\ndef get_dataset():\n    # Load the data and split it between train and test sets\n    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n\n    # Scale images to the [0, 1] range\n    x_train = x_train.astype(\"float32\")\n    x_test = x_test.astype(\"float32\")\n    # Make sure images have shape (28, 28, 1)\n    x_train = np.expand_dims(x_train, -1)\n    x_test = np.expand_dims(x_test, -1)\n    print(\"x_train shape:\", x_train.shape)\n\n  
  # Create a TensorDataset\n    dataset = torch.utils.data.TensorDataset(\n        torch.from_numpy(x_train), torch.from_numpy(y_train)\n    )\n    return dataset\n\n\n\"\"\"\nNext, let's define a simple PyTorch training loop that targets\na GPU (note the calls to `.cuda()`).\n\"\"\"\n\n\ndef train_model(model, dataloader, num_epochs, optimizer, loss_fn):\n    for epoch in range(num_epochs):\n        running_loss = 0.0\n        running_loss_count = 0\n        for batch_idx, (inputs, targets) in enumerate(dataloader):\n            inputs = inputs.cuda(non_blocking=True)\n            targets = targets.cuda(non_blocking=True)\n\n            # Forward pass\n            outputs = model(inputs)\n            loss = loss_fn(outputs, targets)\n\n            # Backward and optimize\n            optimizer.zero_grad()\n            loss.backward()\n            optimizer.step()\n\n            running_loss += loss.item()\n            running_loss_count += 1\n\n        # Print loss statistics\n        print(\n            f\"Epoch {epoch + 1}/{num_epochs}, \"\n            f\"Loss: {running_loss / running_loss_count}\"\n        )\n\n\n\"\"\"\n## Single-host, multi-device synchronous training\n\nIn this setup, you have one machine with several GPUs on it (typically 2 to 16). Each\ndevice will run a copy of your model (called a **replica**). For simplicity, in what\nfollows, we'll assume we're dealing with 8 GPUs, at no loss of generality.\n\n**How it works**\n\nAt each step of training:\n\n- The current batch of data (called **global batch**) is split into 8 different\nsub-batches (called **local batches**). 
For instance, if the global batch has 512\nsamples, each of the 8 local batches will have 64 samples.\n- Each of the 8 replicas independently processes a local batch: they run a forward pass,\nthen a backward pass, outputting the gradient of the weights with respect to the loss of\nthe model on the local batch.\n- The weight updates originating from local gradients are efficiently merged across the 8\nreplicas. Because this is done at the end of every step, the replicas always stay in\nsync.\n\nIn practice, the process of synchronously updating the weights of the model replicas is\nhandled by the `DistributedDataParallel` module wrapper, which synchronizes the\ngradients across replicas during the backward pass.\n\n**How to use it**\n\nTo do single-host, multi-device synchronous training with a Keras model, you would use\nthe `torch.nn.parallel.DistributedDataParallel` module wrapper.\nHere's how it works:\n\n- We use `torch.multiprocessing.start_processes` to start multiple Python processes, one\nper device. 
Each process will run the `per_device_launch_fn` function.\n- The `per_device_launch_fn` function does the following:\n    - It uses `torch.distributed.init_process_group` and `torch.cuda.set_device`\n    to configure the device to be used for that process.\n    - It uses `torch.utils.data.distributed.DistributedSampler`\n    and `torch.utils.data.DataLoader` to turn our data into a distributed data loader.\n    - It also uses `torch.nn.parallel.DistributedDataParallel` to turn our model into\n    a distributed PyTorch module.\n    - It then calls the `train_model` function.\n- The `train_model` function will then run in each process, with the model using\na separate device in each process.\n\nHere's the flow, where each step is split into its own utility function:\n\"\"\"\n\n# Config\nnum_gpu = torch.cuda.device_count()\nnum_epochs = 2\nbatch_size = 64\nprint(f\"Running on {num_gpu} GPUs\")\n\n\ndef setup_device(current_gpu_index, num_gpus):\n    # Device setup\n    os.environ[\"MASTER_ADDR\"] = \"localhost\"\n    os.environ[\"MASTER_PORT\"] = \"56492\"\n    device = torch.device(\"cuda:{}\".format(current_gpu_index))\n    torch.distributed.init_process_group(\n        backend=\"nccl\",\n        init_method=\"env://\",\n        world_size=num_gpus,\n        rank=current_gpu_index,\n    )\n    torch.cuda.set_device(device)\n\n\ndef cleanup():\n    torch.distributed.destroy_process_group()\n\n\ndef prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):\n    sampler = torch.utils.data.distributed.DistributedSampler(\n        dataset,\n        num_replicas=num_gpus,\n        rank=current_gpu_index,\n        shuffle=False,\n    )\n    dataloader = torch.utils.data.DataLoader(\n        dataset,\n        sampler=sampler,\n        batch_size=batch_size,\n        shuffle=False,\n    )\n    return dataloader\n\n\ndef per_device_launch_fn(current_gpu_index, num_gpu):\n    # Setup the process groups\n    setup_device(current_gpu_index, num_gpu)\n\n    dataset = 
get_dataset()\n    model = get_model()\n\n    # prepare the dataloader\n    dataloader = prepare_dataloader(\n        dataset, current_gpu_index, num_gpu, batch_size\n    )\n\n    # Instantiate the torch optimizer\n    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n\n    # Instantiate the torch loss function\n    loss_fn = torch.nn.CrossEntropyLoss()\n\n    # Put model on device\n    model = model.to(current_gpu_index)\n    ddp_model = torch.nn.parallel.DistributedDataParallel(\n        model, device_ids=[current_gpu_index], output_device=current_gpu_index\n    )\n\n    train_model(ddp_model, dataloader, num_epochs, optimizer, loss_fn)\n\n    cleanup()\n\n\n\"\"\"\nTime to start multiple processes:\n\"\"\"\n\nif __name__ == \"__main__\":\n    # We use the \"fork\" method rather than \"spawn\" to support notebooks\n    torch.multiprocessing.start_processes(\n        per_device_launch_fn,\n        args=(num_gpu,),\n        nprocs=num_gpu,\n        join=True,\n        start_method=\"fork\",\n    )\n\n\"\"\"\nThat's it!\n\"\"\"\n"
  },
  {
    "path": "guides/functional_api.py",
    "content": "\"\"\"\nTitle: The Functional API\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2019/03/01\nLast modified: 2020/04/12\nDescription: Complete guide to the functional API.\nAccelerator: GPU\n\"\"\"\n\n\"\"\"\n## Setup\n\"\"\"\n\nimport numpy as np\nimport keras\nfrom keras import layers\nfrom keras import ops\n\n\"\"\"\n## Introduction\n\nThe Keras *functional API* is a way to create models that are more flexible\nthan the `keras.Sequential` API. The functional API can handle models\nwith non-linear topology, shared layers, and even multiple inputs or outputs.\n\nThe main idea is that a deep learning model is usually\na directed acyclic graph (DAG) of layers.\nSo the functional API is a way to build *graphs of layers*.\n\nConsider the following model:\n\n<div class=\"k-default-codeblock\">\n```\n(input: 784-dimensional vectors)\n       ↧\n[Dense (64 units, relu activation)]\n       ↧\n[Dense (64 units, relu activation)]\n       ↧\n[Dense (10 units, softmax activation)]\n       ↧\n(output: logits of a probability distribution over 10 classes)\n```\n</div>\n\nThis is a basic graph with three layers.\nTo build this model using the functional API, start by creating an input node:\n\"\"\"\n\ninputs = keras.Input(shape=(784,))\n\n\"\"\"\nThe shape of the data is set as a 784-dimensional vector.\nThe batch size is always omitted since only the shape of each sample is specified.\n\nIf, for example, you have an image input with a shape of `(32, 32, 3)`,\nyou would use:\n\"\"\"\n\n# Just for demonstration purposes.\nimg_inputs = keras.Input(shape=(32, 32, 3))\n\n\"\"\"\nThe `inputs` that is returned contains information about the shape and `dtype`\nof the input data that you feed to your model.\nHere's the shape:\n\"\"\"\n\ninputs.shape\n\n\"\"\"\nHere's the dtype:\n\"\"\"\n\ninputs.dtype\n\n\"\"\"\nYou create a new node in the graph of layers by calling a layer on this `inputs`\nobject:\n\"\"\"\n\ndense = layers.Dense(64, 
activation=\"relu\")\nx = dense(inputs)\n\n\"\"\"\nThe \"layer call\" action is like drawing an arrow from \"inputs\" to this layer\nyou created.\nYou're \"passing\" the inputs to the `dense` layer, and you get `x` as the output.\n\nLet's add a few more layers to the graph of layers:\n\"\"\"\n\nx = layers.Dense(64, activation=\"relu\")(x)\noutputs = layers.Dense(10)(x)\n\n\"\"\"\nAt this point, you can create a `Model` by specifying its inputs and outputs\nin the graph of layers:\n\"\"\"\n\nmodel = keras.Model(inputs=inputs, outputs=outputs, name=\"mnist_model\")\n\n\"\"\"\nLet's check out what the model summary looks like:\n\"\"\"\n\nmodel.summary()\n\n\"\"\"\nYou can also plot the model as a graph:\n\"\"\"\n\nkeras.utils.plot_model(model, \"my_first_model.png\")\n\n\"\"\"\nAnd, optionally, display the input and output shapes of each layer\nin the plotted graph:\n\"\"\"\n\nkeras.utils.plot_model(\n    model, \"my_first_model_with_shape_info.png\", show_shapes=True\n)\n\n\"\"\"\nThis figure and the code are almost identical. In the code version,\nthe connection arrows are replaced by the call operation.\n\nA \"graph of layers\" is an intuitive mental image for a deep learning model,\nand the functional API is a way to create models that closely mirrors this.\n\"\"\"\n\n\"\"\"\n## Training, evaluation, and inference\n\nTraining, evaluation, and inference work exactly in the same way for models\nbuilt using the functional API as for `Sequential` models.\n\nThe `Model` class offers a built-in training loop (the `fit()` method)\nand a built-in evaluation loop (the `evaluate()` method). Note\nthat you can easily [customize these loops](/guides/customizing_what_happens_in_fit/)\nto implement training routines beyond supervised learning\n(e.g. 
[GANs](https://keras.io/examples/generative/dcgan_overriding_train_step/)).\n\nHere, load the MNIST image data, reshape it into vectors,\nfit the model on the data (while monitoring performance on a validation split),\nthen evaluate the model on the test data:\n\"\"\"\n\n(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n\nx_train = x_train.reshape(60000, 784).astype(\"float32\") / 255\nx_test = x_test.reshape(10000, 784).astype(\"float32\") / 255\n\nmodel.compile(\n    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n    optimizer=keras.optimizers.RMSprop(),\n    metrics=[\"accuracy\"],\n)\n\nhistory = model.fit(\n    x_train, y_train, batch_size=64, epochs=2, validation_split=0.2\n)\n\ntest_scores = model.evaluate(x_test, y_test, verbose=2)\nprint(\"Test loss:\", test_scores[0])\nprint(\"Test accuracy:\", test_scores[1])\n\n\"\"\"\nFor further reading, see the [training and evaluation](/guides/training_with_built_in_methods/) guide.\n\"\"\"\n\n\"\"\"\n## Save and serialize\n\nSaving the model and serialization work the same way for models built using\nthe functional API as they do for `Sequential` models. The standard way\nto save a functional model is to call `model.save()`\nto save the entire model as a single file. 
You can later recreate the same model\nfrom this file, even if the code that built the model is no longer available.\n\nThis saved file includes the:\n\n- model architecture\n- model weight values (that were learned during training)\n- model training config, if any (as passed to `compile()`)\n- optimizer and its state, if any (to restart training where you left off)\n\"\"\"\n\nmodel.save(\"my_model.keras\")\ndel model\n# Recreate the exact same model purely from the file:\nmodel = keras.models.load_model(\"my_model.keras\")\n\n\"\"\"\nFor details, read the model [serialization & saving](\n    /guides/serialization_and_saving/) guide.\n\"\"\"\n\n\"\"\"\n## Use the same graph of layers to define multiple models\n\nIn the functional API, models are created by specifying their inputs\nand outputs in a graph of layers. That means that a single\ngraph of layers can be used to generate multiple models.\n\nIn the example below, you use the same stack of layers to instantiate two models:\nan `encoder` model that turns image inputs into 16-dimensional vectors,\nand an end-to-end `autoencoder` model for training.\n\"\"\"\n\nencoder_input = keras.Input(shape=(28, 28, 1), name=\"img\")\nx = layers.Conv2D(16, 3, activation=\"relu\")(encoder_input)\nx = layers.Conv2D(32, 3, activation=\"relu\")(x)\nx = layers.MaxPooling2D(3)(x)\nx = layers.Conv2D(32, 3, activation=\"relu\")(x)\nx = layers.Conv2D(16, 3, activation=\"relu\")(x)\nencoder_output = layers.GlobalMaxPooling2D()(x)\n\nencoder = keras.Model(encoder_input, encoder_output, name=\"encoder\")\nencoder.summary()\n\nx = layers.Reshape((4, 4, 1))(encoder_output)\nx = layers.Conv2DTranspose(16, 3, activation=\"relu\")(x)\nx = layers.Conv2DTranspose(32, 3, activation=\"relu\")(x)\nx = layers.UpSampling2D(3)(x)\nx = layers.Conv2DTranspose(16, 3, activation=\"relu\")(x)\ndecoder_output = layers.Conv2DTranspose(1, 3, activation=\"relu\")(x)\n\nautoencoder = keras.Model(encoder_input, decoder_output, 
name=\"autoencoder\")\nautoencoder.summary()\n\n\"\"\"\nHere, the decoding architecture is strictly symmetrical\nto the encoding architecture, so the output shape is the same as\nthe input shape `(28, 28, 1)`.\n\nThe reverse of a `Conv2D` layer is a `Conv2DTranspose` layer,\nand the reverse of a `MaxPooling2D` layer is an `UpSampling2D` layer.\n\"\"\"\n\n\"\"\"\n## All models are callable, just like layers\n\nYou can treat any model as if it were a layer by invoking it on an `Input` or\non the output of another layer. By calling a model you aren't just reusing\nthe architecture of the model, you're also reusing its weights.\n\nTo see this in action, here's a different take on the autoencoder example that\ncreates an encoder model, a decoder model, and chains them in two calls\nto obtain the autoencoder model:\n\"\"\"\n\nencoder_input = keras.Input(shape=(28, 28, 1), name=\"original_img\")\nx = layers.Conv2D(16, 3, activation=\"relu\")(encoder_input)\nx = layers.Conv2D(32, 3, activation=\"relu\")(x)\nx = layers.MaxPooling2D(3)(x)\nx = layers.Conv2D(32, 3, activation=\"relu\")(x)\nx = layers.Conv2D(16, 3, activation=\"relu\")(x)\nencoder_output = layers.GlobalMaxPooling2D()(x)\n\nencoder = keras.Model(encoder_input, encoder_output, name=\"encoder\")\nencoder.summary()\n\ndecoder_input = keras.Input(shape=(16,), name=\"encoded_img\")\nx = layers.Reshape((4, 4, 1))(decoder_input)\nx = layers.Conv2DTranspose(16, 3, activation=\"relu\")(x)\nx = layers.Conv2DTranspose(32, 3, activation=\"relu\")(x)\nx = layers.UpSampling2D(3)(x)\nx = layers.Conv2DTranspose(16, 3, activation=\"relu\")(x)\ndecoder_output = layers.Conv2DTranspose(1, 3, activation=\"relu\")(x)\n\ndecoder = keras.Model(decoder_input, decoder_output, name=\"decoder\")\ndecoder.summary()\n\nautoencoder_input = keras.Input(shape=(28, 28, 1), name=\"img\")\nencoded_img = encoder(autoencoder_input)\ndecoded_img = decoder(encoded_img)\nautoencoder = keras.Model(autoencoder_input, decoded_img, 
name=\"autoencoder\")\nautoencoder.summary()\n\n\"\"\"\nAs you can see, the model can be nested: a model can contain sub-models\n(since a model is just like a layer).\nA common use case for model nesting is *ensembling*.\nFor example, here's how to ensemble a set of models into a single model\nthat averages their predictions:\n\"\"\"\n\n\ndef get_model():\n    inputs = keras.Input(shape=(128,))\n    outputs = layers.Dense(1)(inputs)\n    return keras.Model(inputs, outputs)\n\n\nmodel1 = get_model()\nmodel2 = get_model()\nmodel3 = get_model()\n\ninputs = keras.Input(shape=(128,))\ny1 = model1(inputs)\ny2 = model2(inputs)\ny3 = model3(inputs)\noutputs = layers.average([y1, y2, y3])\nensemble_model = keras.Model(inputs=inputs, outputs=outputs)\n\n\"\"\"\n## Manipulate complex graph topologies\n\n### Models with multiple inputs and outputs\n\nThe functional API makes it easy to manipulate multiple inputs and outputs.\nThis cannot be handled with the `Sequential` API.\n\nFor example, if you're building a system for ranking customer issue tickets by\npriority and routing them to the correct department,\nthen the model will have three inputs:\n\n- the title of the ticket (text input),\n- the text body of the ticket (text input), and\n- any tags added by the user (categorical input)\n\nThis model will have two outputs:\n\n- the priority score between 0 and 1 (scalar sigmoid output), and\n- the department that should handle the ticket (softmax output\nover the set of departments).\n\nYou can build this model in a few lines with the functional API:\n\"\"\"\n\nnum_tags = 12  # Number of unique issue tags\nnum_words = 10000  # Size of vocabulary obtained when preprocessing text data\nnum_departments = 4  # Number of departments for predictions\n\ntitle_input = keras.Input(\n    shape=(None,), name=\"title\"\n)  # Variable-length sequence of ints\nbody_input = keras.Input(\n    shape=(None,), name=\"body\"\n)  # Variable-length sequence of ints\ntags_input = keras.Input(\n    
shape=(num_tags,), name=\"tags\"\n)  # Binary vectors of size `num_tags`\n\n# Embed each word in the title into a 64-dimensional vector\ntitle_features = layers.Embedding(num_words, 64)(title_input)\n# Embed each word in the text into a 64-dimensional vector\nbody_features = layers.Embedding(num_words, 64)(body_input)\n\n# Reduce sequence of embedded words in the title into a single 128-dimensional vector\ntitle_features = layers.LSTM(128)(title_features)\n# Reduce sequence of embedded words in the body into a single 32-dimensional vector\nbody_features = layers.LSTM(32)(body_features)\n\n# Merge all available features into a single large vector via concatenation\nx = layers.concatenate([title_features, body_features, tags_input])\n\n# Stick a logistic regression for priority prediction on top of the features\npriority_pred = layers.Dense(1, name=\"priority\")(x)\n# Stick a department classifier on top of the features\ndepartment_pred = layers.Dense(num_departments, name=\"department\")(x)\n\n# Instantiate an end-to-end model predicting both priority and department\nmodel = keras.Model(\n    inputs=[title_input, body_input, tags_input],\n    outputs={\"priority\": priority_pred, \"department\": department_pred},\n)\n\n\"\"\"\nNow plot the model:\n\"\"\"\n\nkeras.utils.plot_model(\n    model, \"multi_input_and_output_model.png\", show_shapes=True\n)\n\n\"\"\"\nWhen compiling this model, you can assign different losses to each output.\nYou can even assign different weights to each loss -- to modulate\ntheir contribution to the total training loss.\n\"\"\"\n\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(1e-3),\n    loss=[\n        keras.losses.BinaryCrossentropy(from_logits=True),\n        keras.losses.CategoricalCrossentropy(from_logits=True),\n    ],\n    loss_weights=[1.0, 0.2],\n)\n\n\"\"\"\nSince the output layers have different names, you could also specify\nthe losses and loss weights with the corresponding layer names:\n\"\"\"\n\nmodel.compile(\n    
optimizer=keras.optimizers.RMSprop(1e-3),\n    loss={\n        \"priority\": keras.losses.BinaryCrossentropy(from_logits=True),\n        \"department\": keras.losses.CategoricalCrossentropy(from_logits=True),\n    },\n    loss_weights={\"priority\": 1.0, \"department\": 0.2},\n)\n\n\"\"\"\nTrain the model by passing dictionaries of NumPy arrays of inputs and targets:\n\"\"\"\n\n# Dummy input data\ntitle_data = np.random.randint(num_words, size=(1280, 10))\nbody_data = np.random.randint(num_words, size=(1280, 100))\ntags_data = np.random.randint(2, size=(1280, num_tags)).astype(\"float32\")\n\n# Dummy target data\npriority_targets = np.random.random(size=(1280, 1))\ndept_targets = np.random.randint(2, size=(1280, num_departments))\n\nmodel.fit(\n    {\"title\": title_data, \"body\": body_data, \"tags\": tags_data},\n    {\"priority\": priority_targets, \"department\": dept_targets},\n    epochs=2,\n    batch_size=32,\n)\n\n\"\"\"\nWhen calling `fit()` with a `Dataset` object, it should yield either a\ntuple of lists like `([title_data, body_data, tags_data], [priority_targets, dept_targets])`\nor a tuple of dictionaries like\n`({'title': title_data, 'body': body_data, 'tags': tags_data}, {'priority': priority_targets, 'department': dept_targets})`.\n\nFor a more detailed explanation, refer to the [training and evaluation](/guides/training_with_built_in_methods/) guide.\n\"\"\"\n\n\"\"\"\n### A toy ResNet model\n\nIn addition to models with multiple inputs and outputs,\nthe functional API makes it easy to manipulate non-linear connectivity\ntopologies -- these are models with layers that are not connected sequentially,\nwhich the `Sequential` API cannot handle.\n\nA common use case for this is residual connections.\nLet's build a toy ResNet model for CIFAR10 to demonstrate this:\n\"\"\"\n\ninputs = keras.Input(shape=(32, 32, 3), name=\"img\")\nx = layers.Conv2D(32, 3, activation=\"relu\")(inputs)\nx = layers.Conv2D(64, 3, activation=\"relu\")(x)\nblock_1_output = 
layers.MaxPooling2D(3)(x)\n\nx = layers.Conv2D(64, 3, activation=\"relu\", padding=\"same\")(block_1_output)\nx = layers.Conv2D(64, 3, activation=\"relu\", padding=\"same\")(x)\nblock_2_output = layers.add([x, block_1_output])\n\nx = layers.Conv2D(64, 3, activation=\"relu\", padding=\"same\")(block_2_output)\nx = layers.Conv2D(64, 3, activation=\"relu\", padding=\"same\")(x)\nblock_3_output = layers.add([x, block_2_output])\n\nx = layers.Conv2D(64, 3, activation=\"relu\")(block_3_output)\nx = layers.GlobalAveragePooling2D()(x)\nx = layers.Dense(256, activation=\"relu\")(x)\nx = layers.Dropout(0.5)(x)\noutputs = layers.Dense(10)(x)\n\nmodel = keras.Model(inputs, outputs, name=\"toy_resnet\")\nmodel.summary()\n\n\"\"\"\nPlot the model:\n\"\"\"\n\nkeras.utils.plot_model(model, \"mini_resnet.png\", show_shapes=True)\n\n\"\"\"\nNow train the model:\n\"\"\"\n\n(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n\nx_train = x_train.astype(\"float32\") / 255.0\nx_test = x_test.astype(\"float32\") / 255.0\ny_train = keras.utils.to_categorical(y_train, 10)\ny_test = keras.utils.to_categorical(y_test, 10)\n\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(1e-3),\n    loss=keras.losses.CategoricalCrossentropy(from_logits=True),\n    metrics=[\"acc\"],\n)\n# We restrict the data to the first 1000 samples so as to limit execution time\n# on Colab. 
Try to train on the entire dataset until convergence!\nmodel.fit(\n    x_train[:1000],\n    y_train[:1000],\n    batch_size=64,\n    epochs=1,\n    validation_split=0.2,\n)\n\n\"\"\"\n## Shared layers\n\nAnother good use for the functional API is models that use *shared layers*.\nShared layers are layer instances that are reused multiple times in the same model --\nthey learn features that correspond to multiple paths in the graph of layers.\n\nShared layers are often used to encode inputs from similar spaces\n(say, two different pieces of text that feature similar vocabulary).\nThey enable sharing of information across these different inputs,\nand they make it possible to train such a model on less data.\nIf a given word is seen in one of the inputs,\nthat will benefit the processing of all inputs that pass through the shared layer.\n\nTo share a layer in the functional API, call the same layer instance multiple times.\nFor instance, here's an `Embedding` layer shared across two different text inputs:\n\"\"\"\n\n# Embedding for 1000 unique words mapped to 128-dimensional vectors\nshared_embedding = layers.Embedding(1000, 128)\n\n# Variable-length sequence of integers\ntext_input_a = keras.Input(shape=(None,), dtype=\"int32\")\n\n# Variable-length sequence of integers\ntext_input_b = keras.Input(shape=(None,), dtype=\"int32\")\n\n# Reuse the same layer to encode both inputs\nencoded_input_a = shared_embedding(text_input_a)\nencoded_input_b = shared_embedding(text_input_b)\n\n\"\"\"\n## Extract and reuse nodes in the graph of layers\n\nBecause the graph of layers you are manipulating is a static data structure,\nit can be accessed and inspected. And this is how you are able to plot\nfunctional models as images.\n\nThis also means that you can access the activations of intermediate layers\n(\"nodes\" in the graph) and reuse them elsewhere --\nwhich is very useful for something like feature extraction.\n\nLet's look at an example. 
This is a VGG19 model with weights pretrained on ImageNet:\n\"\"\"\n\nvgg19 = keras.applications.VGG19()\n\n\"\"\"\nAnd these are the intermediate activations of the model,\nobtained by querying the graph data structure:\n\"\"\"\n\nfeatures_list = [layer.output for layer in vgg19.layers]\n\n\"\"\"\nUse these features to create a new feature-extraction model that returns\nthe values of the intermediate layer activations:\n\"\"\"\n\nfeat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)\n\nimg = np.random.random((1, 224, 224, 3)).astype(\"float32\")\nextracted_features = feat_extraction_model(img)\n\n\"\"\"\nThis comes in handy for tasks like\n[neural style transfer](https://keras.io/examples/generative/neural_style_transfer/),\namong other things.\n\"\"\"\n\n\"\"\"\n## Extend the API using custom layers\n\n`keras` includes a wide range of built-in layers, for example:\n\n- Convolutional layers: `Conv1D`, `Conv2D`, `Conv3D`, `Conv2DTranspose`\n- Pooling layers: `MaxPooling1D`, `MaxPooling2D`, `MaxPooling3D`, `AveragePooling1D`\n- RNN layers: `GRU`, `LSTM`, `ConvLSTM2D`\n- `BatchNormalization`, `Dropout`, `Embedding`, etc.\n\nBut if you don't find what you need, it's easy to extend the API by creating\nyour own layers. 
All layers subclass the `Layer` class and implement:\n\n- a `call` method, which specifies the computation done by the layer.\n- a `build` method, which creates the weights of the layer (this is just a style\nconvention, since you can also create weights in `__init__`).\n\nTo learn more about creating layers from scratch, read the\n[custom layers and models](/guides/making_new_layers_and_models_via_subclassing) guide.\n\nThe following is a basic implementation of `keras.layers.Dense`:\n\"\"\"\n\n\nclass CustomDense(layers.Layer):\n    def __init__(self, units=32):\n        super().__init__()\n        self.units = units\n\n    def build(self, input_shape):\n        self.w = self.add_weight(\n            shape=(input_shape[-1], self.units),\n            initializer=\"random_normal\",\n            trainable=True,\n        )\n        self.b = self.add_weight(\n            shape=(self.units,), initializer=\"random_normal\", trainable=True\n        )\n\n    def call(self, inputs):\n        return ops.matmul(inputs, self.w) + self.b\n\n\ninputs = keras.Input((4,))\noutputs = CustomDense(10)(inputs)\n\nmodel = keras.Model(inputs, outputs)\n\n\"\"\"\nFor serialization support in your custom layer, define a `get_config()`\nmethod that returns the constructor arguments of the layer instance:\n\"\"\"\n\n\nclass CustomDense(layers.Layer):\n    def __init__(self, units=32):\n        super().__init__()\n        self.units = units\n\n    def build(self, input_shape):\n        self.w = self.add_weight(\n            shape=(input_shape[-1], self.units),\n            initializer=\"random_normal\",\n            trainable=True,\n        )\n        self.b = self.add_weight(\n            shape=(self.units,), initializer=\"random_normal\", trainable=True\n        )\n\n    def call(self, inputs):\n        return ops.matmul(inputs, self.w) + self.b\n\n    def get_config(self):\n        return {\"units\": self.units}\n\n\ninputs = keras.Input((4,))\noutputs = CustomDense(10)(inputs)\n\nmodel = 
keras.Model(inputs, outputs)\nconfig = model.get_config()\n\nnew_model = keras.Model.from_config(\n    config, custom_objects={\"CustomDense\": CustomDense}\n)\n\n\"\"\"\nOptionally, implement the class method `from_config(cls, config)` which is used\nwhen recreating a layer instance given its config dictionary.\nThe default implementation of `from_config` is:\n\n```python\ndef from_config(cls, config):\n  return cls(**config)\n```\n\"\"\"\n\n\"\"\"\n## When to use the functional API\n\nShould you use the Keras functional API to create a new model,\nor just subclass the `Model` class directly? In general, the functional API\nis higher-level, easier and safer, and has a number of\nfeatures that subclassed models do not support.\n\nHowever, model subclassing provides greater flexibility when building models\nthat are not easily expressible as directed acyclic graphs of layers.\nFor example, you could not implement a Tree-RNN with the functional API\nand would have to subclass `Model` directly.\n\nFor an in-depth look at the differences between the functional API and\nmodel subclassing, read\n[What are Symbolic and Imperative APIs in TensorFlow 2.0?](https://blog.tensorflow.org/2019/01/what-are-symbolic-and-imperative-apis.html).\n\n### Functional API strengths:\n\nThe following properties are also true for Sequential models\n(which are also data structures), but are not true for subclassed models\n(which are Python bytecode, not data structures).\n\n#### Less verbose\n\nThere is no `super().__init__(...)`, no `def call(self, ...):`, etc.\n\nCompare:\n\n```python\ninputs = keras.Input(shape=(32,))\nx = layers.Dense(64, activation='relu')(inputs)\noutputs = layers.Dense(10)(x)\nmlp = keras.Model(inputs, outputs)\n```\n\nWith the subclassed version:\n\n```python\nclass MLP(keras.Model):\n\n  def __init__(self, **kwargs):\n    super().__init__(**kwargs)\n    self.dense_1 = layers.Dense(64, activation='relu')\n    self.dense_2 = layers.Dense(10)\n\n  def call(self, 
inputs):\n    x = self.dense_1(inputs)\n    return self.dense_2(x)\n\n# Instantiate the model.\nmlp = MLP()\n# Necessary to create the model's state.\n# The model doesn't have a state until it's called at least once.\n_ = mlp(ops.zeros((1, 32)))\n```\n\n#### Model validation while defining its connectivity graph\n\nIn the functional API, the input specification (shape and dtype) is created\nin advance (using `Input`). Every time you call a layer,\nthe layer checks that the specification passed to it matches its assumptions,\nand it will raise a helpful error message if not.\n\nThis guarantees that any model you can build with the functional API will run.\nAll debugging -- other than convergence-related debugging --\nhappens statically during the model construction and not at execution time.\nThis is similar to type checking in a compiler.\n\n#### A functional model is plottable and inspectable\n\nYou can plot the model as a graph, and you can easily access intermediate nodes\nin this graph. 
For example, to extract and reuse the activations of intermediate\nlayers (as seen in a previous example):\n\n```python\nfeatures_list = [layer.output for layer in vgg19.layers]\nfeat_extraction_model = keras.Model(inputs=vgg19.input, outputs=features_list)\n```\n\n#### A functional model can be serialized or cloned\n\nBecause a functional model is a data structure rather than a piece of code,\nit is safely serializable and can be saved as a single file\nthat allows you to recreate the exact same model\nwithout having access to any of the original code.\nSee the [serialization & saving guide](/guides/serialization_and_saving/).\n\nTo serialize a subclassed model, it is necessary for the implementer\nto specify a `get_config()`\nand `from_config()` method at the model level.\n\n\n### Functional API weakness:\n\n#### It does not support dynamic architectures\n\nThe functional API treats models as DAGs of layers.\nThis is true for most deep learning architectures, but not all -- for example,\nrecursive networks or Tree RNNs do not follow this assumption and cannot\nbe implemented in the functional API.\n\"\"\"\n\n\"\"\"\n## Mix-and-match API styles\n\nChoosing between the functional API or Model subclassing isn't a\nbinary decision that restricts you into one category of models.\nAll models in the `keras` API can interact with each other, whether they're\n`Sequential` models, functional models, or subclassed models that are written\nfrom scratch.\n\nYou can always use a functional model or `Sequential` model\nas part of a subclassed model or layer:\n\"\"\"\n\nunits = 32\ntimesteps = 10\ninput_dim = 5\n\n# Define a Functional model\ninputs = keras.Input((None, units))\nx = layers.GlobalAveragePooling1D()(inputs)\noutputs = layers.Dense(1)(x)\nmodel = keras.Model(inputs, outputs)\n\n\nclass CustomRNN(layers.Layer):\n    def __init__(self):\n        super().__init__()\n        self.units = units\n        self.projection_1 = layers.Dense(units=units, 
activation=\"tanh\")\n        self.projection_2 = layers.Dense(units=units, activation=\"tanh\")\n        # Our previously-defined Functional model\n        self.classifier = model\n\n    def call(self, inputs):\n        outputs = []\n        state = ops.zeros(shape=(inputs.shape[0], self.units))\n        for t in range(inputs.shape[1]):\n            x = inputs[:, t, :]\n            h = self.projection_1(x)\n            y = h + self.projection_2(state)\n            state = y\n            outputs.append(y)\n        features = ops.stack(outputs, axis=1)\n        print(features.shape)\n        return self.classifier(features)\n\n\nrnn_model = CustomRNN()\n_ = rnn_model(ops.zeros((1, timesteps, input_dim)))\n\n\"\"\"\nYou can use any subclassed layer or model in the functional API\nas long as it implements a `call` method that follows one of the following patterns:\n\n- `call(self, inputs, **kwargs)` --\nWhere `inputs` is a tensor or a nested structure of tensors (e.g. a list of tensors),\nand where `**kwargs` are non-tensor arguments (non-inputs).\n- `call(self, inputs, training=None, **kwargs)` --\nWhere `training` is a boolean indicating whether the layer should behave\nin training mode or in inference mode.\n- `call(self, inputs, mask=None, **kwargs)` --\nWhere `mask` is a boolean mask tensor (useful for RNNs, for instance).\n- `call(self, inputs, training=None, mask=None, **kwargs)` --\nOf course, you can have both masking and training-specific behavior at the same time.\n\nAdditionally, if you implement the `get_config` method on your custom Layer or model,\nthe functional models you create will still be serializable and cloneable.\n\nHere's a quick example of a custom RNN, written from scratch,\nbeing used in a functional model:\n\"\"\"\n\nunits = 32\ntimesteps = 10\ninput_dim = 5\nbatch_size = 16\n\n\nclass CustomRNN(layers.Layer):\n    def __init__(self):\n        super().__init__()\n        self.units = units\n        self.projection_1 = 
layers.Dense(units=units, activation=\"tanh\")\n        self.projection_2 = layers.Dense(units=units, activation=\"tanh\")\n        self.classifier = layers.Dense(1)\n\n    def call(self, inputs):\n        outputs = []\n        state = ops.zeros(shape=(inputs.shape[0], self.units))\n        for t in range(inputs.shape[1]):\n            x = inputs[:, t, :]\n            h = self.projection_1(x)\n            y = h + self.projection_2(state)\n            state = y\n            outputs.append(y)\n        features = ops.stack(outputs, axis=1)\n        return self.classifier(features)\n\n\n# Note that you specify a static batch size for the inputs with the `batch_shape`\n# arg, because the inner computation of `CustomRNN` requires a static batch size\n# (when you create the `state` zeros tensor).\ninputs = keras.Input(batch_shape=(batch_size, timesteps, input_dim))\nx = layers.Conv1D(32, 3)(inputs)\noutputs = CustomRNN()(x)\n\nmodel = keras.Model(inputs, outputs)\n\nrnn_model = CustomRNN()\n_ = rnn_model(ops.zeros((1, 10, 5)))\n"
  },
  {
    "path": "guides/making_new_layers_and_models_via_subclassing.py",
    "content": "\"\"\"\nTitle: Making new layers and models via subclassing\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2019/03/01\nLast modified: 2023/06/25\nDescription: Complete guide to writing `Layer` and `Model` objects from scratch.\nAccelerator: None\n\"\"\"\n\n\"\"\"\n## Introduction\n\nThis guide will cover everything you need to know to build your own\nsubclassed layers and models. In particular, you'll learn about the following features:\n\n- The `Layer` class\n- The `add_weight()` method\n- Trainable and non-trainable weights\n- The `build()` method\n- Making sure your layers can be used with any backend\n- The `add_loss()` method\n- The `training` argument in `call()`\n- The `mask` argument in `call()`\n- Making sure your layers can be serialized\n\nLet's dive in.\n\"\"\"\n\"\"\"\n## Setup\n\"\"\"\n\nimport numpy as np\nimport keras\nfrom keras import ops\nfrom keras import layers\n\n\"\"\"\n## The `Layer` class: the combination of state (weights) and some computation\n\nOne of the central abstractions in Keras is the `Layer` class. A layer\nencapsulates both a state (the layer's \"weights\") and a transformation from\ninputs to outputs (a \"call\", the layer's forward pass).\n\nHere's a densely-connected layer. 
It has a state: the variables `w` and `b`.\n\"\"\"\n\n\nclass Linear(keras.layers.Layer):\n    def __init__(self, units=32, input_dim=32):\n        super().__init__()\n        self.w = self.add_weight(\n            shape=(input_dim, units),\n            initializer=\"random_normal\",\n            trainable=True,\n        )\n        self.b = self.add_weight(\n            shape=(units,), initializer=\"zeros\", trainable=True\n        )\n\n    def call(self, inputs):\n        return ops.matmul(inputs, self.w) + self.b\n\n\n\"\"\"\nYou would use a layer by calling it on some tensor input(s), much like a Python\nfunction.\n\"\"\"\n\nx = ops.ones((2, 2))\nlinear_layer = Linear(4, 2)\ny = linear_layer(x)\nprint(y)\n\n\"\"\"\nNote that the weights `w` and `b` are automatically tracked by the layer upon\nbeing set as layer attributes:\n\"\"\"\n\nassert linear_layer.weights == [linear_layer.w, linear_layer.b]\n\n\"\"\"\n## Layers can have non-trainable weights\n\nBesides trainable weights, you can add non-trainable weights to a layer as\nwell. 
Such weights are meant not to be taken into account during\nbackpropagation, when you are training the layer.\n\nHere's how to add and use a non-trainable weight:\n\"\"\"\n\n\nclass ComputeSum(keras.layers.Layer):\n    def __init__(self, input_dim):\n        super().__init__()\n        self.total = self.add_weight(\n            initializer=\"zeros\", shape=(input_dim,), trainable=False\n        )\n\n    def call(self, inputs):\n        self.total.assign_add(ops.sum(inputs, axis=0))\n        return self.total\n\n\nx = ops.ones((2, 2))\nmy_sum = ComputeSum(2)\ny = my_sum(x)\nprint(y.numpy())\ny = my_sum(x)\nprint(y.numpy())\n\n\"\"\"\nIt's part of `layer.weights`, but it gets categorized as a non-trainable weight:\n\"\"\"\n\nprint(\"weights:\", len(my_sum.weights))\nprint(\"non-trainable weights:\", len(my_sum.non_trainable_weights))\n\n# It's not included in the trainable weights:\nprint(\"trainable_weights:\", my_sum.trainable_weights)\n\n\"\"\"\n## Best practice: deferring weight creation until the shape of the inputs is known\n\nOur `Linear` layer above took an `input_dim` argument that was used to compute\nthe shape of the weights `w` and `b` in `__init__()`:\n\"\"\"\n\n\nclass Linear(keras.layers.Layer):\n    def __init__(self, units=32, input_dim=32):\n        super().__init__()\n        self.w = self.add_weight(\n            shape=(input_dim, units),\n            initializer=\"random_normal\",\n            trainable=True,\n        )\n        self.b = self.add_weight(\n            shape=(units,), initializer=\"zeros\", trainable=True\n        )\n\n    def call(self, inputs):\n        return ops.matmul(inputs, self.w) + self.b\n\n\n\"\"\"\nIn many cases, you may not know in advance the size of your inputs, and you\nwould like to lazily create weights when that value becomes known, some time\nafter instantiating the layer.\n\nIn the Keras API, we recommend creating layer weights in the\n`build(self, inputs_shape)` method of your layer. 
Like this:\n\"\"\"\n\n\nclass Linear(keras.layers.Layer):\n    def __init__(self, units=32):\n        super().__init__()\n        self.units = units\n\n    def build(self, input_shape):\n        self.w = self.add_weight(\n            shape=(input_shape[-1], self.units),\n            initializer=\"random_normal\",\n            trainable=True,\n        )\n        self.b = self.add_weight(\n            shape=(self.units,), initializer=\"random_normal\", trainable=True\n        )\n\n    def call(self, inputs):\n        return ops.matmul(inputs, self.w) + self.b\n\n\n\"\"\"\nThe `__call__()` method of your layer will automatically run build the first time\nit is called. You now have a layer that's lazy and thus easier to use:\n\"\"\"\n\n# At instantiation, we don't know on what inputs this is going to get called\nlinear_layer = Linear(32)\n\n# The layer's weights are created dynamically the first time the layer is called\ny = linear_layer(x)\n\n\"\"\"\nImplementing `build()` separately as shown above nicely separates creating weights\nonly once from using weights in every call.\n\"\"\"\n\n\"\"\"\n## Layers are recursively composable\n\nIf you assign a Layer instance as an attribute of another Layer, the outer layer\nwill start tracking the weights created by the inner layer.\n\nWe recommend creating such sublayers in the `__init__()` method and leave it to\nthe first `__call__()` to trigger building their weights.\n\"\"\"\n\n\nclass MLPBlock(keras.layers.Layer):\n    def __init__(self):\n        super().__init__()\n        self.linear_1 = Linear(32)\n        self.linear_2 = Linear(32)\n        self.linear_3 = Linear(1)\n\n    def call(self, inputs):\n        x = self.linear_1(inputs)\n        x = keras.activations.relu(x)\n        x = self.linear_2(x)\n        x = keras.activations.relu(x)\n        return self.linear_3(x)\n\n\nmlp = MLPBlock()\ny = mlp(\n    ops.ones(shape=(3, 64))\n)  # The first call to the `mlp` will create the weights\nprint(\"weights:\", 
len(mlp.weights))\nprint(\"trainable weights:\", len(mlp.trainable_weights))\n\n\"\"\"\n## Backend-agnostic layers and backend-specific layers\n\nAs long as a layer only uses APIs from the `keras.ops` namespace\n(or other Keras namespaces such as `keras.activations`, `keras.random`, or `keras.layers`),\nthen it can be used with any backend -- TensorFlow, JAX, or PyTorch.\n\nAll layers you've seen so far in this guide work with all Keras backends.\n\nThe `keras.ops` namespace gives you access to:\n\n- The NumPy API, e.g. `ops.matmul`, `ops.sum`, `ops.reshape`, `ops.stack`, etc.\n- Neural networks-specific APIs such as `ops.softmax`, `ops.conv`, `ops.binary_crossentropy`, `ops.relu`, etc.\n\nYou can also use backend-native APIs in your layers (such as `tf.nn` functions),\nbut if you do this, then your layer will only be usable with the backend in question.\nFor instance, you could write the following JAX-specific layer using `jax.numpy`:\n\n```python\nimport jax\n\nclass Linear(keras.layers.Layer):\n    ...\n\n    def call(self, inputs):\n        return jax.numpy.matmul(inputs, self.w) + self.b\n```\n\nThis would be the equivalent TensorFlow-specific layer:\n\n```python\nimport tensorflow as tf\n\nclass Linear(keras.layers.Layer):\n    ...\n\n    def call(self, inputs):\n        return tf.matmul(inputs, self.w) + self.b\n```\n\nAnd this would be the equivalent PyTorch-specific layer:\n\n```python\nimport torch\n\nclass Linear(keras.layers.Layer):\n    ...\n\n    def call(self, inputs):\n        return torch.matmul(inputs, self.w) + self.b\n```\n\nBecause cross-backend compatibility is a tremendously useful property, we strongly\nrecommend that you seek to always make your layers backend-agnostic by leveraging\nonly Keras APIs.\n\"\"\"\n\n\"\"\"\n## The `add_loss()` method\n\nWhen writing the `call()` method of a layer, you can create loss tensors that\nyou will want to use later, when writing your training loop. 
This is doable by\ncalling `self.add_loss(value)`:\n\"\"\"\n\n\n# A layer that creates an activity regularization loss\nclass ActivityRegularizationLayer(keras.layers.Layer):\n    def __init__(self, rate=1e-2):\n        super().__init__()\n        self.rate = rate\n\n    def call(self, inputs):\n        self.add_loss(self.rate * ops.mean(inputs))\n        return inputs\n\n\n\"\"\"\nThese losses (including those created by any inner layer) can be retrieved via\n`layer.losses`. This property is reset at the start of every `__call__()` to\nthe top-level layer, so that `layer.losses` always contains the loss values\ncreated during the last forward pass.\n\"\"\"\n\n\nclass OuterLayer(keras.layers.Layer):\n    def __init__(self):\n        super().__init__()\n        self.activity_reg = ActivityRegularizationLayer(1e-2)\n\n    def call(self, inputs):\n        return self.activity_reg(inputs)\n\n\nlayer = OuterLayer()\nassert (\n    len(layer.losses) == 0\n)  # No losses yet since the layer has never been called\n\n_ = layer(ops.zeros((1, 1)))\nassert len(layer.losses) == 1  # We created one loss value\n\n# `layer.losses` gets reset at the start of each __call__\n_ = layer(ops.zeros((1, 1)))\nassert len(layer.losses) == 1  # This is the loss created during the call above\n\n\"\"\"\nIn addition, the `losses` property also contains regularization losses created\nfor the weights of any inner layer:\n\"\"\"\n\n\nclass OuterLayerWithKernelRegularizer(keras.layers.Layer):\n    def __init__(self):\n        super().__init__()\n        self.dense = keras.layers.Dense(\n            32, kernel_regularizer=keras.regularizers.l2(1e-3)\n        )\n\n    def call(self, inputs):\n        return self.dense(inputs)\n\n\nlayer = OuterLayerWithKernelRegularizer()\n_ = layer(ops.zeros((1, 1)))\n\n# This is `1e-3 * sum(layer.dense.kernel ** 2)`,\n# created by the `kernel_regularizer` above.\nprint(layer.losses)\n\n\"\"\"\nThese losses are meant to be taken into account when writing custom 
training loops.\n\nThey also work seamlessly with `fit()` (they get automatically summed and added to the main loss, if any):\n\"\"\"\n\ninputs = keras.Input(shape=(3,))\noutputs = ActivityRegularizationLayer()(inputs)\nmodel = keras.Model(inputs, outputs)\n\n# If there is a loss passed in `compile`, the regularization\n# losses get added to it\nmodel.compile(optimizer=\"adam\", loss=\"mse\")\nmodel.fit(np.random.random((2, 3)), np.random.random((2, 3)))\n\n# It's also possible not to pass any loss in `compile`,\n# since the model already has a loss to minimize, via the `add_loss`\n# call during the forward pass!\nmodel.compile(optimizer=\"adam\")\nmodel.fit(np.random.random((2, 3)), np.random.random((2, 3)))\n\n\"\"\"\n## You can optionally enable serialization on your layers\n\nIf you need your custom layers to be serializable as part of a\n[Functional model](/guides/functional_api/), you can optionally implement a `get_config()`\nmethod:\n\"\"\"\n\n\nclass Linear(keras.layers.Layer):\n    def __init__(self, units=32):\n        super().__init__()\n        self.units = units\n\n    def build(self, input_shape):\n        self.w = self.add_weight(\n            shape=(input_shape[-1], self.units),\n            initializer=\"random_normal\",\n            trainable=True,\n        )\n        self.b = self.add_weight(\n            shape=(self.units,), initializer=\"random_normal\", trainable=True\n        )\n\n    def call(self, inputs):\n        return ops.matmul(inputs, self.w) + self.b\n\n    def get_config(self):\n        return {\"units\": self.units}\n\n\n# Now you can recreate the layer from its config:\nlayer = Linear(64)\nconfig = layer.get_config()\nprint(config)\nnew_layer = Linear.from_config(config)\n\n\"\"\"\nNote that the `__init__()` method of the base `Layer` class takes some keyword\narguments, in particular a `name` and a `dtype`. 
It's good practice to pass\nthese arguments to the parent class in `__init__()` and to include them in the\nlayer config:\n\"\"\"\n\n\nclass Linear(keras.layers.Layer):\n    def __init__(self, units=32, **kwargs):\n        super().__init__(**kwargs)\n        self.units = units\n\n    def build(self, input_shape):\n        self.w = self.add_weight(\n            shape=(input_shape[-1], self.units),\n            initializer=\"random_normal\",\n            trainable=True,\n        )\n        self.b = self.add_weight(\n            shape=(self.units,), initializer=\"random_normal\", trainable=True\n        )\n\n    def call(self, inputs):\n        return ops.matmul(inputs, self.w) + self.b\n\n    def get_config(self):\n        config = super().get_config()\n        config.update({\"units\": self.units})\n        return config\n\n\nlayer = Linear(64)\nconfig = layer.get_config()\nprint(config)\nnew_layer = Linear.from_config(config)\n\n\"\"\"\nIf you need more flexibility when deserializing the layer from its config, you\ncan also override the `from_config()` class method. This is the base\nimplementation of `from_config()`:\n\n```python\ndef from_config(cls, config):\n    return cls(**config)\n```\n\nTo learn more about serialization and saving, see the complete\n[guide to saving and serializing models](/guides/serialization_and_saving/).\n\"\"\"\n\n\"\"\"\n## Privileged `training` argument in the `call()` method\n\nSome layers, in particular the `BatchNormalization` layer and the `Dropout`\nlayer, have different behaviors during training and inference. For such\nlayers, it is standard practice to expose a `training` (boolean) argument in\nthe `call()` method.\n\nBy exposing this argument in `call()`, you enable the built-in training and\nevaluation loops (e.g. 
`fit()`) to correctly use the layer in training and\ninference.\n\"\"\"\n\n\nclass CustomDropout(keras.layers.Layer):\n    def __init__(self, rate, **kwargs):\n        super().__init__(**kwargs)\n        self.rate = rate\n\n    def call(self, inputs, training=None):\n        if training:\n            return keras.random.dropout(inputs, rate=self.rate)\n        return inputs\n\n\n\"\"\"\n## Privileged `mask` argument in the `call()` method\n\nThe other privileged argument supported by `call()` is the `mask` argument.\n\nYou will find it in all Keras RNN layers. A mask is a boolean tensor (one\nboolean value per timestep in the input) used to skip certain input timesteps\nwhen processing timeseries data.\n\nKeras will automatically pass the correct `mask` argument to `__call__()` for\nlayers that support it, when a mask is generated by a prior layer.\nMask-generating layers are the `Embedding`\nlayer configured with `mask_zero=True`, and the `Masking` layer.\n\"\"\"\n\n\"\"\"\n## The `Model` class\n\nIn general, you will use the `Layer` class to define inner computation blocks,\nand will use the `Model` class to define the outer model -- the object you\nwill train.\n\nFor instance, in a ResNet50 model, you would have several ResNet blocks\nsubclassing `Layer`, and a single `Model` encompassing the entire ResNet50\nnetwork.\n\nThe `Model` class has the same API as `Layer`, with the following differences:\n\n- It exposes built-in training, evaluation, and prediction loops\n(`model.fit()`, `model.evaluate()`, `model.predict()`).\n- It exposes the list of its inner layers, via the `model.layers` property.\n- It exposes saving and serialization APIs (`save()`, `save_weights()`...)\n\nEffectively, the `Layer` class corresponds to what we refer to in the\nliterature as a \"layer\" (as in \"convolution layer\" or \"recurrent layer\") or as\na \"block\" (as in \"ResNet block\" or \"Inception block\").\n\nMeanwhile, the `Model` class corresponds to what is referred to in 
the\nliterature as a \"model\" (as in \"deep learning model\") or as a \"network\" (as in\n\"deep neural network\").\n\nSo if you're wondering, \"should I use the `Layer` class or the `Model` class?\",\nask yourself: will I need to call `fit()` on it? Will I need to call `save()`\non it? If so, go with `Model`. If not (either because your class is just a block\nin a bigger system, or because you are writing training & saving code yourself),\nuse `Layer`.\n\nFor instance, we could take our mini-resnet example above, and use it to build\na `Model` that we could train with `fit()`, and that we could save with\n`save_weights()`:\n\"\"\"\n\n\"\"\"\n```python\nclass ResNet(keras.Model):\n\n    def __init__(self, num_classes=1000):\n        super().__init__()\n        self.block_1 = ResNetBlock()\n        self.block_2 = ResNetBlock()\n        self.global_pool = layers.GlobalAveragePooling2D()\n        self.classifier = layers.Dense(num_classes)\n\n    def call(self, inputs):\n        x = self.block_1(inputs)\n        x = self.block_2(x)\n        x = self.global_pool(x)\n        return self.classifier(x)\n\n\nresnet = ResNet()\ndataset = ...\nresnet.fit(dataset, epochs=10)\nresnet.save(\"filepath.keras\")\n```\n\"\"\"\n\n\"\"\"\n## Putting it all together: an end-to-end example\n\nHere's what you've learned so far:\n\n- A `Layer` encapsulates a state (created in `__init__()` or `build()`) and some\ncomputation (defined in `call()`).\n- Layers can be recursively nested to create new, bigger computation blocks.\n- Layers are backend-agnostic as long as they only use Keras APIs. You can use\nbackend-native APIs (such as `jax.numpy`, `torch.nn` or `tf.nn`), but then\nyour layer will only be usable with that specific backend.\n- Layers can create and track losses (typically regularization losses)\nvia `add_loss()`.\n- The outer container, the thing you want to train, is a `Model`. 
A `Model` is\njust like a `Layer`, but with added training and serialization utilities.\n\nLet's put all of these things together into an end-to-end example: we're going\nto implement a Variational AutoEncoder (VAE) in a backend-agnostic fashion\n-- so that it runs the same with TensorFlow, JAX, and PyTorch.\nWe'll train it on MNIST digits.\n\nOur VAE will be a subclass of `Model`, built as a nested composition of layers\nthat subclass `Layer`. It will feature a regularization loss (KL divergence).\n\"\"\"\n\n\nclass Sampling(layers.Layer):\n    \"\"\"Uses (z_mean, z_log_var) to sample z, the vector encoding a digit.\"\"\"\n\n    def call(self, inputs):\n        z_mean, z_log_var = inputs\n        batch = ops.shape(z_mean)[0]\n        dim = ops.shape(z_mean)[1]\n        epsilon = keras.random.normal(shape=(batch, dim))\n        return z_mean + ops.exp(0.5 * z_log_var) * epsilon\n\n\nclass Encoder(layers.Layer):\n    \"\"\"Maps MNIST digits to a triplet (z_mean, z_log_var, z).\"\"\"\n\n    def __init__(\n        self, latent_dim=32, intermediate_dim=64, name=\"encoder\", **kwargs\n    ):\n        super().__init__(name=name, **kwargs)\n        self.dense_proj = layers.Dense(intermediate_dim, activation=\"relu\")\n        self.dense_mean = layers.Dense(latent_dim)\n        self.dense_log_var = layers.Dense(latent_dim)\n        self.sampling = Sampling()\n\n    def call(self, inputs):\n        x = self.dense_proj(inputs)\n        z_mean = self.dense_mean(x)\n        z_log_var = self.dense_log_var(x)\n        z = self.sampling((z_mean, z_log_var))\n        return z_mean, z_log_var, z\n\n\nclass Decoder(layers.Layer):\n    \"\"\"Converts z, the encoded digit vector, back into a readable digit.\"\"\"\n\n    def __init__(\n        self, original_dim, intermediate_dim=64, name=\"decoder\", **kwargs\n    ):\n        super().__init__(name=name, **kwargs)\n        self.dense_proj = layers.Dense(intermediate_dim, activation=\"relu\")\n        self.dense_output = 
layers.Dense(original_dim, activation=\"sigmoid\")\n\n    def call(self, inputs):\n        x = self.dense_proj(inputs)\n        return self.dense_output(x)\n\n\nclass VariationalAutoEncoder(keras.Model):\n    \"\"\"Combines the encoder and decoder into an end-to-end model for training.\"\"\"\n\n    def __init__(\n        self,\n        original_dim,\n        intermediate_dim=64,\n        latent_dim=32,\n        name=\"autoencoder\",\n        **kwargs,\n    ):\n        super().__init__(name=name, **kwargs)\n        self.original_dim = original_dim\n        self.encoder = Encoder(\n            latent_dim=latent_dim, intermediate_dim=intermediate_dim\n        )\n        self.decoder = Decoder(original_dim, intermediate_dim=intermediate_dim)\n\n    def call(self, inputs):\n        z_mean, z_log_var, z = self.encoder(inputs)\n        reconstructed = self.decoder(z)\n        # Add KL divergence regularization loss.\n        kl_loss = -0.5 * ops.mean(\n            z_log_var - ops.square(z_mean) - ops.exp(z_log_var) + 1\n        )\n        self.add_loss(kl_loss)\n        return reconstructed\n\n\n\"\"\"\nLet's train it on MNIST using the `fit()` API:\n\"\"\"\n\n(x_train, _), _ = keras.datasets.mnist.load_data()\nx_train = x_train.reshape(60000, 784).astype(\"float32\") / 255\n\noriginal_dim = 784\nvae = VariationalAutoEncoder(784, 64, 32)\n\noptimizer = keras.optimizers.Adam(learning_rate=1e-3)\nvae.compile(optimizer, loss=keras.losses.MeanSquaredError())\n\nvae.fit(x_train, x_train, epochs=2, batch_size=64)\n"
  },
  {
    "path": "guides/sequential_model.py",
    "content": "\"\"\"\nTitle: The Sequential model\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2020/04/12\nLast modified: 2023/06/25\nDescription: Complete guide to the Sequential model.\nAccelerator: GPU\n\"\"\"\n\n\"\"\"\n## Setup\n\n\"\"\"\n\nimport keras\nfrom keras import layers\nfrom keras import ops\n\n\"\"\"\n## When to use a Sequential model\n\nA `Sequential` model is appropriate for **a plain stack of layers**\nwhere each layer has **exactly one input tensor and one output tensor**.\n\nSchematically, the following `Sequential` model:\n\"\"\"\n\n# Define Sequential model with 3 layers\nmodel = keras.Sequential(\n    [\n        layers.Dense(2, activation=\"relu\", name=\"layer1\"),\n        layers.Dense(3, activation=\"relu\", name=\"layer2\"),\n        layers.Dense(4, name=\"layer3\"),\n    ]\n)\n# Call model on a test input\nx = ops.ones((3, 3))\ny = model(x)\n\n\"\"\"\nis equivalent to this function:\n\"\"\"\n\n# Create 3 layers\nlayer1 = layers.Dense(2, activation=\"relu\", name=\"layer1\")\nlayer2 = layers.Dense(3, activation=\"relu\", name=\"layer2\")\nlayer3 = layers.Dense(4, name=\"layer3\")\n\n# Call layers on a test input\nx = ops.ones((3, 3))\ny = layer3(layer2(layer1(x)))\n\n\"\"\"\nA Sequential model is **not appropriate** when:\n\n- Your model has multiple inputs or multiple outputs\n- Any of your layers has multiple inputs or multiple outputs\n- You need to do layer sharing\n- You want non-linear topology (e.g. 
a residual connection, a multi-branch\nmodel)\n\"\"\"\n\n\"\"\"\n## Creating a Sequential model\n\nYou can create a Sequential model by passing a list of layers to the Sequential\nconstructor:\n\"\"\"\n\nmodel = keras.Sequential(\n    [\n        layers.Dense(2, activation=\"relu\"),\n        layers.Dense(3, activation=\"relu\"),\n        layers.Dense(4),\n    ]\n)\n\n\"\"\"\nIts layers are accessible via the `layers` attribute:\n\"\"\"\n\nmodel.layers\n\n\"\"\"\nYou can also create a Sequential model incrementally via the `add()` method:\n\"\"\"\n\nmodel = keras.Sequential()\nmodel.add(layers.Dense(2, activation=\"relu\"))\nmodel.add(layers.Dense(3, activation=\"relu\"))\nmodel.add(layers.Dense(4))\n\n\"\"\"\nNote that there's also a corresponding `pop()` method to remove layers:\na Sequential model behaves very much like a list of layers.\n\"\"\"\n\nmodel.pop()\nprint(len(model.layers))  # 2\n\n\"\"\"\nAlso note that the Sequential constructor accepts a `name` argument, just like\nany layer or model in Keras. This is useful to annotate TensorBoard graphs\nwith semantically meaningful names.\n\"\"\"\n\nmodel = keras.Sequential(name=\"my_sequential\")\nmodel.add(layers.Dense(2, activation=\"relu\", name=\"layer1\"))\nmodel.add(layers.Dense(3, activation=\"relu\", name=\"layer2\"))\nmodel.add(layers.Dense(4, name=\"layer3\"))\n\n\"\"\"\n## Specifying the input shape in advance\n\nGenerally, all layers in Keras need to know the shape of their inputs\nin order to be able to create their weights. So when you create a layer like\nthis, initially, it has no weights:\n\"\"\"\n\nlayer = layers.Dense(3)\nlayer.weights  # Empty\n\n\"\"\"\nIt creates its weights the first time it is called on an input, since the shape\nof the weights depends on the shape of the inputs:\n\"\"\"\n\n# Call layer on a test input\nx = ops.ones((1, 4))\ny = layer(x)\nlayer.weights  # Now it has weights, of shape (4, 3) and (3,)\n\n\"\"\"\nNaturally, this also applies to Sequential models. 
When you instantiate a\nSequential model without an input shape, it isn't \"built\": it has no weights\n(and calling\n`model.weights` results in an error stating just this). The weights are created\nwhen the model first sees some input data:\n\"\"\"\n\nmodel = keras.Sequential(\n    [\n        layers.Dense(2, activation=\"relu\"),\n        layers.Dense(3, activation=\"relu\"),\n        layers.Dense(4),\n    ]\n)  # No weights at this stage!\n\n# At this point, you can't do this:\n# model.weights\n\n# You also can't do this:\n# model.summary()\n\n# Call the model on a test input\nx = ops.ones((1, 4))\ny = model(x)\nprint(\"Number of weights after calling the model:\", len(model.weights))  # 6\n\n\"\"\"\nOnce a model is \"built\", you can call its `summary()` method to display its\ncontents:\n\"\"\"\n\nmodel.summary()\n\n\"\"\"\nHowever, it can be very useful when building a Sequential model incrementally\nto be able to display the summary of the model so far, including the current\noutput shape. In this case, you should start your model by passing an `Input`\nobject to your model, so that it knows its input shape from the start:\n\"\"\"\n\nmodel = keras.Sequential()\nmodel.add(keras.Input(shape=(4,)))\nmodel.add(layers.Dense(2, activation=\"relu\"))\n\nmodel.summary()\n\n\"\"\"\nNote that the `Input` object is not displayed as part of `model.layers`, since\nit isn't a layer:\n\"\"\"\n\nmodel.layers\n\n\"\"\"\nModels built with a predefined input shape like this always have weights (even\nbefore seeing any data) and always have a defined output shape.\n\nIn general, it's a recommended best practice to always specify the input shape\nof a Sequential model in advance if you know what it is.\n\"\"\"\n\n\"\"\"\n## A common debugging workflow: `add()` + `summary()`\n\nWhen building a new Sequential architecture, it's useful to incrementally stack\nlayers with `add()` and frequently print model summaries. 
For instance, this\nenables you to monitor how a stack of `Conv2D` and `MaxPooling2D` layers is\ndownsampling image feature maps:\n\"\"\"\n\nmodel = keras.Sequential()\nmodel.add(keras.Input(shape=(250, 250, 3)))  # 250x250 RGB images\nmodel.add(layers.Conv2D(32, 5, strides=2, activation=\"relu\"))\nmodel.add(layers.Conv2D(32, 3, activation=\"relu\"))\nmodel.add(layers.MaxPooling2D(3))\n\n# Can you guess what the current output shape is at this point? Probably not.\n# Let's just print it:\nmodel.summary()\n\n# The answer was: (40, 40, 32), so we can keep downsampling...\n\nmodel.add(layers.Conv2D(32, 3, activation=\"relu\"))\nmodel.add(layers.Conv2D(32, 3, activation=\"relu\"))\nmodel.add(layers.MaxPooling2D(3))\nmodel.add(layers.Conv2D(32, 3, activation=\"relu\"))\nmodel.add(layers.Conv2D(32, 3, activation=\"relu\"))\nmodel.add(layers.MaxPooling2D(2))\n\n# And now?\nmodel.summary()\n\n# Now that we have 4x4 feature maps, time to apply global max pooling.\nmodel.add(layers.GlobalMaxPooling2D())\n\n# Finally, we add a classification layer.\nmodel.add(layers.Dense(10))\n\n\"\"\"\nVery practical, right?\n\n\n\"\"\"\n\n\"\"\"\n## What to do once you have a model\n\nOnce your model architecture is ready, you will want to:\n\n- Train your model, evaluate it, and run inference. See our\n[guide to training & evaluation with the built-in loops](\n    /guides/training_with_built_in_methods/)\n- Save your model to disk and restore it. See our\n[guide to serialization & saving](/guides/serialization_and_saving/).\n- Speed up model training by leveraging multiple GPUs. See our\n[guide to multi-GPU and distributed training](https://keras.io/guides/distributed_training/).\n\"\"\"\n\n\"\"\"\n## Feature extraction with a Sequential model\n\nOnce a Sequential model has been built, it behaves like a [Functional API\nmodel](/guides/functional_api/). This means that every layer has an `input`\nand `output` attribute. 
These attributes can be used to do neat things, like quickly\ncreating a model that extracts the outputs of all intermediate layers in a\nSequential model:\n\"\"\"\n\ninitial_model = keras.Sequential(\n    [\n        keras.Input(shape=(250, 250, 3)),\n        layers.Conv2D(32, 5, strides=2, activation=\"relu\"),\n        layers.Conv2D(32, 3, activation=\"relu\"),\n        layers.Conv2D(32, 3, activation=\"relu\"),\n    ]\n)\nfeature_extractor = keras.Model(\n    inputs=initial_model.inputs,\n    outputs=[layer.output for layer in initial_model.layers],\n)\n\n# Call feature extractor on test input.\nx = ops.ones((1, 250, 250, 3))\nfeatures = feature_extractor(x)\n\n\"\"\"\nHere's a similar example that only extracts features from one layer:\n\"\"\"\n\ninitial_model = keras.Sequential(\n    [\n        keras.Input(shape=(250, 250, 3)),\n        layers.Conv2D(32, 5, strides=2, activation=\"relu\"),\n        layers.Conv2D(32, 3, activation=\"relu\", name=\"my_intermediate_layer\"),\n        layers.Conv2D(32, 3, activation=\"relu\"),\n    ]\n)\nfeature_extractor = keras.Model(\n    inputs=initial_model.inputs,\n    outputs=initial_model.get_layer(name=\"my_intermediate_layer\").output,\n)\n# Call feature extractor on test input.\nx = ops.ones((1, 250, 250, 3))\nfeatures = feature_extractor(x)\n\n\"\"\"\n## Transfer learning with a Sequential model\n\nTransfer learning consists of freezing the bottom layers in a model and only training\nthe top layers. If you aren't familiar with it, make sure to read our [guide\nto transfer learning](/guides/transfer_learning/).\n\nHere are two common transfer learning blueprints involving Sequential models.\n\nFirst, let's say that you have a Sequential model, and you want to freeze all\nlayers except the last one. In this case, you would simply iterate over\n`model.layers` and set `layer.trainable = False` on each layer, except the\nlast one. 
Like this:\n\n```python\nmodel = keras.Sequential([\n    keras.Input(shape=(784,)),\n    layers.Dense(32, activation='relu'),\n    layers.Dense(32, activation='relu'),\n    layers.Dense(32, activation='relu'),\n    layers.Dense(10),\n])\n\n# Presumably you would want to first load pre-trained weights.\nmodel.load_weights(...)\n\n# Freeze all layers except the last one.\nfor layer in model.layers[:-1]:\n  layer.trainable = False\n\n# Recompile and train (this will only update the weights of the last layer).\nmodel.compile(...)\nmodel.fit(...)\n```\n\nAnother common blueprint is to use a Sequential model to stack a pre-trained\nmodel and some freshly initialized classification layers. Like this:\n\n```python\n# Load a convolutional base with pre-trained weights\nbase_model = keras.applications.Xception(\n    weights='imagenet',\n    include_top=False,\n    pooling='avg')\n\n# Freeze the base model\nbase_model.trainable = False\n\n# Use a Sequential model to add a trainable classifier on top\nmodel = keras.Sequential([\n    base_model,\n    layers.Dense(1000),\n])\n\n# Compile & train\nmodel.compile(...)\nmodel.fit(...)\n```\n\nIf you do transfer learning, you will probably find yourself frequently using\nthese two patterns.\n\"\"\"\n\n\"\"\"\nThat's about all you need to know about Sequential models!\n\nTo find out more about building models in Keras, see:\n\n- [Guide to the Functional API](/guides/functional_api/)\n- [Guide to making new Layers & Models via subclassing](\n    /guides/making_new_layers_and_models_via_subclassing/)\n\"\"\"\n"
  },
  {
    "path": "guides/training_with_built_in_methods.py",
    "content": "\"\"\"\nTitle: Training & evaluation with the built-in methods\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2019/03/01\nLast modified: 2023/03/25\nDescription: Complete guide to training & evaluation with `fit()` and `evaluate()`.\nAccelerator: GPU\n\"\"\"\n\n\"\"\"\n## Setup\n\"\"\"\n\n# We import torch & TF so as to use torch Dataloaders & tf.data.Datasets.\nimport torch\nimport tensorflow as tf\n\nimport os\nimport numpy as np\nimport keras\nfrom keras import layers\nfrom keras import ops\n\n\"\"\"\n## Introduction\n\nThis guide covers training, evaluation, and prediction (inference) of models\nwhen using built-in APIs for training & validation (such as `Model.fit()`,\n`Model.evaluate()` and `Model.predict()`).\n\nIf you are interested in leveraging `fit()` while specifying your\nown training step function, see the\n[Customizing what happens in `fit()` guide](/guides/customizing_what_happens_in_fit/).\n\nIf you are interested in writing your own training & evaluation loops from\nscratch, see the guide\n[\"writing a training loop from scratch\"](/guides/writing_a_training_loop_from_scratch/).\n\nIn general, whether you are using built-in loops or writing your own, model training &\nevaluation works in exactly the same way across every kind of Keras model --\nSequential models, models built with the Functional API, and models written from\nscratch via model subclassing.\n\nThis guide doesn't cover distributed training, which is covered in our\n[guide to multi-GPU & distributed training](https://keras.io/guides/distributed_training/).\n\"\"\"\n\n\"\"\"\n## API overview: a first end-to-end example\n\nWhen passing data to the built-in training loops of a model, you should either use:\n\n- NumPy arrays (if your data is small and fits in memory)\n- Subclasses of `keras.utils.PyDataset`\n- `tf.data.Dataset` objects\n- PyTorch `DataLoader` instances\n\nIn the next few paragraphs, we'll use the MNIST dataset as NumPy arrays, in\norder to 
demonstrate how to use optimizers, losses, and metrics. Afterwards, we'll\ntake a close look at each of the other options.\n\nLet's consider the following model (here, we build it with the Functional API, but it\ncould be a Sequential model or a subclassed model as well):\n\"\"\"\n\ninputs = keras.Input(shape=(784,), name=\"digits\")\nx = layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\nx = layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\noutputs = layers.Dense(10, activation=\"softmax\", name=\"predictions\")(x)\n\nmodel = keras.Model(inputs=inputs, outputs=outputs)\n\n\"\"\"\nHere's what the typical end-to-end workflow looks like, consisting of:\n\n- Training\n- Validation on a holdout set generated from the original training data\n- Evaluation on the test data\n\nWe'll use MNIST data for this example.\n\"\"\"\n\n(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n\n# Preprocess the data (these are NumPy arrays)\nx_train = x_train.reshape(60000, 784).astype(\"float32\") / 255\nx_test = x_test.reshape(10000, 784).astype(\"float32\") / 255\n\ny_train = y_train.astype(\"float32\")\ny_test = y_test.astype(\"float32\")\n\n# Reserve 10,000 samples for validation\nx_val = x_train[-10000:]\ny_val = y_train[-10000:]\nx_train = x_train[:-10000]\ny_train = y_train[:-10000]\n\n\"\"\"\nWe specify the training configuration (optimizer, loss, metrics):\n\"\"\"\n\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(),  # Optimizer\n    # Loss function to minimize\n    loss=keras.losses.SparseCategoricalCrossentropy(),\n    # List of metrics to monitor\n    metrics=[keras.metrics.SparseCategoricalAccuracy()],\n)\n\n\"\"\"\nWe call `fit()`, which will train the model by slicing the data into \"batches\" of size\n`batch_size`, and repeatedly iterating over the entire dataset for a given number of\n`epochs`.\n\"\"\"\n\nprint(\"Fit model on training data\")\nhistory = model.fit(\n    x_train,\n    y_train,\n    batch_size=64,\n    
epochs=2,\n    # We pass some validation for\n    # monitoring validation loss and metrics\n    # at the end of each epoch\n    validation_data=(x_val, y_val),\n)\n\n\"\"\"\nThe returned `history` object holds a record of the loss values and metric values\nduring training:\n\"\"\"\n\nhistory.history\n\n\"\"\"\nWe evaluate the model on the test data via `evaluate()`:\n\"\"\"\n\n# Evaluate the model on the test data using `evaluate`\nprint(\"Evaluate on test data\")\nresults = model.evaluate(x_test, y_test, batch_size=128)\nprint(\"test loss, test acc:\", results)\n\n# Generate predictions (probabilities -- the output of the last layer)\n# on new data using `predict`\nprint(\"Generate predictions for 3 samples\")\npredictions = model.predict(x_test[:3])\nprint(\"predictions shape:\", predictions.shape)\n\n\"\"\"\nNow, let's review each piece of this workflow in detail.\n\"\"\"\n\n\"\"\"\n## The `compile()` method: specifying a loss, metrics, and an optimizer\n\nTo train a model with `fit()`, you need to specify a loss function, an optimizer, and\noptionally, some metrics to monitor.\n\nYou pass these to the model as arguments to the `compile()` method:\n\"\"\"\n\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),\n    loss=keras.losses.SparseCategoricalCrossentropy(),\n    metrics=[keras.metrics.SparseCategoricalAccuracy()],\n)\n\n\"\"\"\nThe `metrics` argument should be a list -- your model can have any number of metrics.\n\nIf your model has multiple outputs, you can specify different losses and metrics for\neach output, and you can modulate the contribution of each output to the total loss of\nthe model. 
You will find more details about this in the **Passing data to multi-input,\nmulti-output models** section.\n\nNote that if you're satisfied with the default settings, in many cases the optimizer,\nloss, and metrics can be specified via string identifiers as a shortcut:\n\"\"\"\n\nmodel.compile(\n    optimizer=\"rmsprop\",\n    loss=\"sparse_categorical_crossentropy\",\n    metrics=[\"sparse_categorical_accuracy\"],\n)\n\n\"\"\"\nFor later reuse, let's put our model definition and compile step in functions; we will\ncall them several times across different examples in this guide.\n\"\"\"\n\n\ndef get_uncompiled_model():\n    inputs = keras.Input(shape=(784,), name=\"digits\")\n    x = layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n    x = layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\n    outputs = layers.Dense(10, activation=\"softmax\", name=\"predictions\")(x)\n    model = keras.Model(inputs=inputs, outputs=outputs)\n    return model\n\n\ndef get_compiled_model():\n    model = get_uncompiled_model()\n    model.compile(\n        optimizer=\"rmsprop\",\n        loss=\"sparse_categorical_crossentropy\",\n        metrics=[\"sparse_categorical_accuracy\"],\n    )\n    return model\n\n\n\"\"\"\n### Many built-in optimizers, losses, and metrics are available\n\nIn general, you won't have to create your own losses, metrics, or optimizers\nfrom scratch, because what you need is likely to be already part of the Keras API:\n\nOptimizers:\n\n- `SGD()` (with or without momentum)\n- `RMSprop()`\n- `Adam()`\n- etc.\n\nLosses:\n\n- `MeanSquaredError()`\n- `KLDivergence()`\n- `CosineSimilarity()`\n- etc.\n\nMetrics:\n\n- `AUC()`\n- `Precision()`\n- `Recall()`\n- etc.\n\"\"\"\n\n\"\"\"\n### Custom losses\n\nIf you need to create a custom loss, Keras provides three ways to do so.\n\nThe first method involves creating a function that accepts inputs `y_true` and\n`y_pred`. 
The following example shows a loss function that computes the mean squared\nerror between the real data and the predictions:\n\"\"\"\n\n\ndef custom_mean_squared_error(y_true, y_pred):\n    return ops.mean(ops.square(y_true - y_pred), axis=-1)\n\n\nmodel = get_uncompiled_model()\nmodel.compile(optimizer=keras.optimizers.Adam(), loss=custom_mean_squared_error)\n\n# We need to one-hot encode the labels to use MSE\ny_train_one_hot = ops.one_hot(y_train, num_classes=10)\nmodel.fit(x_train, y_train_one_hot, batch_size=64, epochs=1)\n\n\"\"\"\nIf you need a loss function that takes in parameters besides `y_true` and `y_pred`, you\ncan subclass the `keras.losses.Loss` class and implement the following two methods:\n\n- `__init__(self)`: accept parameters to pass during the call of your loss function\n- `call(self, y_true, y_pred)`: use the targets (y_true) and the model predictions\n(y_pred) to compute the model's loss\n\nLet's say you want to use mean squared error, but with an added term that\nwill de-incentivize prediction values far from 0.5 (we assume that the categorical\ntargets are one-hot encoded and take values between 0 and 1). 
This\ncreates an incentive for the model not to be too confident, which may help\nreduce overfitting (we won't know if it works until we try!).\n\nHere's how you would do it:\n\"\"\"\n\n\nclass CustomMSE(keras.losses.Loss):\n    def __init__(self, regularization_factor=0.1, name=\"custom_mse\"):\n        super().__init__(name=name)\n        self.regularization_factor = regularization_factor\n\n    def call(self, y_true, y_pred):\n        mse = ops.mean(ops.square(y_true - y_pred), axis=-1)\n        reg = ops.mean(ops.square(0.5 - y_pred), axis=-1)\n        return mse + reg * self.regularization_factor\n\n\nmodel = get_uncompiled_model()\nmodel.compile(optimizer=keras.optimizers.Adam(), loss=CustomMSE())\n\ny_train_one_hot = ops.one_hot(y_train, num_classes=10)\nmodel.fit(x_train, y_train_one_hot, batch_size=64, epochs=1)\n\n\n\"\"\"\n### Custom metrics\n\nIf you need a metric that isn't part of the API, you can easily create custom metrics\nby subclassing the `keras.metrics.Metric` class. 
You will need to implement 4\nmethods:\n\n- `__init__(self)`, in which you will create state variables for your metric.\n- `update_state(self, y_true, y_pred, sample_weight=None)`, which uses the targets\ny_true and the model predictions y_pred to update the state variables.\n- `result(self)`, which uses the state variables to compute the final results.\n- `reset_state(self)`, which reinitializes the state of the metric.\n\nState update and results computation are kept separate (in `update_state()` and\n`result()`, respectively) because in some cases, the results computation might be very\nexpensive and would only be done periodically.\n\nHere's a simple example showing how to implement a `CategoricalTruePositives` metric\nthat counts how many samples were correctly classified as belonging to a given class:\n\"\"\"\n\n\nclass CategoricalTruePositives(keras.metrics.Metric):\n    def __init__(self, name=\"categorical_true_positives\", **kwargs):\n        super().__init__(name=name, **kwargs)\n        self.true_positives = self.add_variable(\n            shape=(), name=\"ctp\", initializer=\"zeros\"\n        )\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        y_pred = ops.reshape(ops.argmax(y_pred, axis=1), (-1, 1))\n        values = ops.cast(y_true, \"int32\") == ops.cast(y_pred, \"int32\")\n        values = ops.cast(values, \"float32\")\n        if sample_weight is not None:\n            sample_weight = ops.cast(sample_weight, \"float32\")\n            values = ops.multiply(values, sample_weight)\n        self.true_positives.assign_add(ops.sum(values))\n\n    def result(self):\n        return self.true_positives\n\n    def reset_state(self):\n        # The state of the metric will be reset at the start of each epoch.\n        self.true_positives.assign(0)\n\n\nmodel = get_uncompiled_model()\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),\n    loss=keras.losses.SparseCategoricalCrossentropy(),\n    
metrics=[CategoricalTruePositives()],\n)\nmodel.fit(x_train, y_train, batch_size=64, epochs=3)\n\n\"\"\"\n### Handling losses and metrics that don't fit the standard signature\n\nThe overwhelming majority of losses and metrics can be computed from `y_true` and\n`y_pred`, where `y_pred` is an output of your model -- but not all of them. For\ninstance, a regularization loss may only require the activation of a layer (there are\nno targets in this case), and this activation may not be a model output.\n\nIn such cases, you can call `self.add_loss(loss_value)` from inside the call method of\na custom layer. Losses added in this way get added to the \"main\" loss during training\n(the one passed to `compile()`). Here's a simple example that adds activity\nregularization (note that activity regularization is built-in in all Keras layers --\nthis layer is just for the sake of providing a concrete example):\n\"\"\"\n\n\nclass ActivityRegularizationLayer(layers.Layer):\n    def call(self, inputs):\n        self.add_loss(ops.sum(inputs) * 0.1)\n        return inputs  # Pass-through layer.\n\n\ninputs = keras.Input(shape=(784,), name=\"digits\")\nx = layers.Dense(64, activation=\"relu\", name=\"dense_1\")(inputs)\n\n# Insert activity regularization as a layer\nx = ActivityRegularizationLayer()(x)\n\nx = layers.Dense(64, activation=\"relu\", name=\"dense_2\")(x)\noutputs = layers.Dense(10, name=\"predictions\")(x)\n\nmodel = keras.Model(inputs=inputs, outputs=outputs)\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),\n    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n)\n\n# The displayed loss will be much higher than before\n# due to the regularization component.\nmodel.fit(x_train, y_train, batch_size=64, epochs=1)\n\n\"\"\"\nNote that when you pass losses via `add_loss()`, it becomes possible to call\n`compile()` without a loss function, since the model already has a loss to minimize.\n\nConsider the following 
`LogisticEndpoint` layer: it takes as inputs\ntargets & logits, and it tracks a crossentropy loss via `add_loss()`.\n\"\"\"\n\n\nclass LogisticEndpoint(keras.layers.Layer):\n    def __init__(self, name=None):\n        super().__init__(name=name)\n        self.loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n\n    def call(self, targets, logits, sample_weights=None):\n        # Compute the training-time loss value and add it\n        # to the layer using `self.add_loss()`.\n        loss = self.loss_fn(targets, logits, sample_weights)\n        self.add_loss(loss)\n\n        # Return the inference-time prediction tensor (for `.predict()`).\n        return ops.softmax(logits)\n\n\n\"\"\"\nYou can use it in a model with two inputs (input data & targets), compiled without a\n`loss` argument, like this:\n\"\"\"\n\ninputs = keras.Input(shape=(3,), name=\"inputs\")\ntargets = keras.Input(shape=(10,), name=\"targets\")\nlogits = keras.layers.Dense(10)(inputs)\npredictions = LogisticEndpoint(name=\"predictions\")(targets, logits)\n\nmodel = keras.Model(inputs=[inputs, targets], outputs=predictions)\nmodel.compile(optimizer=\"adam\")  # No loss argument!\n\ndata = {\n    \"inputs\": np.random.random((3, 3)),\n    \"targets\": np.random.random((3, 10)),\n}\nmodel.fit(data)\n\n\"\"\"\nFor more information about training multi-input models, see the section **Passing data\nto multi-input, multi-output models**.\n\"\"\"\n\n\"\"\"\n### Automatically setting apart a validation holdout set\n\nIn the first end-to-end example you saw, we used the `validation_data` argument to pass\na tuple of NumPy arrays `(x_val, y_val)` to the model for evaluating a validation loss\nand validation metrics at the end of each epoch.\n\nHere's another option: the argument `validation_split` allows you to automatically\nreserve part of your training data for validation. 
The argument value represents the\nfraction of the data to be reserved for validation, so it should be set to a number\nhigher than 0 and lower than 1. For instance, `validation_split=0.2` means \"use 20% of\nthe data for validation\", and `validation_split=0.6` means \"use 60% of the data for\nvalidation\".\n\nThe way the validation is computed is by taking the last x% samples of the arrays\nreceived by the `fit()` call, before any shuffling.\n\nNote that you can only use `validation_split` when training with NumPy data.\n\"\"\"\n\nmodel = get_compiled_model()\nmodel.fit(x_train, y_train, batch_size=64, validation_split=0.2, epochs=1)\n\n\"\"\"\n## Training & evaluation using `tf.data` Datasets\n\nIn the past few paragraphs, you've seen how to handle losses, metrics, and optimizers,\nand you've seen how to use the `validation_data` and `validation_split` arguments in\n`fit()`, when your data is passed as NumPy arrays.\n\nAnother option is to use an iterator-like object, such as a `tf.data.Dataset`, a\nPyTorch `DataLoader`, or a Keras `PyDataset`. Let's take a look at the former.\n\nThe `tf.data` API is a set of utilities in TensorFlow 2.0 for loading and preprocessing\ndata in a way that's fast and scalable. 
For a complete guide about creating `Datasets`,\nsee the [tf.data documentation](https://www.tensorflow.org/guide/data).\n\n**You can use `tf.data` to train your Keras\nmodels regardless of the backend you're using --\nwhether it's JAX, PyTorch, or TensorFlow.**\nYou can pass a `Dataset` instance directly to the methods `fit()`, `evaluate()`, and\n`predict()`:\n\"\"\"\n\nmodel = get_compiled_model()\n\n# First, let's create a training Dataset instance.\n# For the sake of our example, we'll use the same MNIST data as before.\ntrain_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n# Shuffle and slice the dataset.\ntrain_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n\n# Now we get a test dataset.\ntest_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))\ntest_dataset = test_dataset.batch(64)\n\n# Since the dataset already takes care of batching,\n# we don't pass a `batch_size` argument.\nmodel.fit(train_dataset, epochs=3)\n\n# You can also evaluate or predict on a dataset.\nprint(\"Evaluate\")\nresult = model.evaluate(test_dataset)\ndict(zip(model.metrics_names, result))\n\n\"\"\"\nNote that the Dataset is reset at the end of each epoch, so it can be reused for the\nnext epoch.\n\nIf you want to run training only on a specific number of batches from this Dataset, you\ncan pass the `steps_per_epoch` argument, which specifies how many training steps the\nmodel should run using this Dataset before moving on to the next epoch.\n\"\"\"\n\nmodel = get_compiled_model()\n\n# Prepare the training dataset\ntrain_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\ntrain_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n\n# Only use 100 batches per epoch (that's 64 * 100 samples)\nmodel.fit(train_dataset, epochs=3, steps_per_epoch=100)\n\n\"\"\"\nYou can also pass a `Dataset` instance as the `validation_data` argument in `fit()`:\n\"\"\"\n\nmodel = get_compiled_model()\n\n# Prepare the training 
dataset\ntrain_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\ntrain_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n\n# Prepare the validation dataset\nval_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))\nval_dataset = val_dataset.batch(64)\n\nmodel.fit(train_dataset, epochs=1, validation_data=val_dataset)\n\n\"\"\"\nAt the end of each epoch, the model will iterate over the validation dataset and\ncompute the validation loss and validation metrics.\n\nIf you want to run validation only on a specific number of batches from this dataset,\nyou can pass the `validation_steps` argument, which specifies how many validation\nsteps the model should run with the validation dataset before interrupting validation\nand moving on to the next epoch:\n\"\"\"\n\nmodel = get_compiled_model()\n\n# Prepare the training dataset\ntrain_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\ntrain_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n\n# Prepare the validation dataset\nval_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))\nval_dataset = val_dataset.batch(64)\n\nmodel.fit(\n    train_dataset,\n    epochs=1,\n    # Only run validation using the first 10 batches of the dataset\n    # using the `validation_steps` argument\n    validation_data=val_dataset,\n    validation_steps=10,\n)\n\n\"\"\"\nNote that the validation dataset will be reset after each use (so that you will always\nbe evaluating on the same samples from epoch to epoch).\n\nThe argument `validation_split` (generating a holdout set from the training data) is\nnot supported when training from `Dataset` objects, since this feature requires the\nability to index the samples of the datasets, which is not possible in general with\nthe `Dataset` API.\n\"\"\"\n\n\"\"\"\n## Training & evaluation using `PyDataset` instances\n\n`keras.utils.PyDataset` is a utility that you can subclass to obtain\na Python generator with two important 
properties:\n\n- It works well with multiprocessing.\n- It can be shuffled (e.g. when passing `shuffle=True` in `fit()`).\n\nA `PyDataset` must implement two methods:\n\n- `__getitem__`\n- `__len__`\n\nThe method `__getitem__` should return a complete batch.\nIf you want to modify your dataset between epochs, you may implement `on_epoch_end`.\nYou may also implement `on_epoch_begin` to be called at the start of each epoch.\n\nHere's a quick example:\n\"\"\"\n\n\nclass ExamplePyDataset(keras.utils.PyDataset):\n    def __init__(self, x, y, batch_size, **kwargs):\n        super().__init__(**kwargs)\n        self.x = x\n        self.y = y\n        self.batch_size = batch_size\n\n    def __len__(self):\n        return int(np.ceil(len(self.x) / float(self.batch_size)))\n\n    def __getitem__(self, idx):\n        batch_x = self.x[idx * self.batch_size : (idx + 1) * self.batch_size]\n        batch_y = self.y[idx * self.batch_size : (idx + 1) * self.batch_size]\n        return batch_x, batch_y\n\n\ntrain_py_dataset = ExamplePyDataset(x_train, y_train, batch_size=32)\nval_py_dataset = ExamplePyDataset(x_val, y_val, batch_size=32)\n\n\"\"\"\nTo fit the model, pass the dataset instead as the `x` argument (no need for a `y`\nargument since the dataset includes the targets), and pass the validation dataset\nas the `validation_data` argument. And no need for the `validation_batch_size`\nargument, since the dataset is already batched!\n\"\"\"\n\nmodel = get_compiled_model()\nmodel.fit(\n    train_py_dataset, batch_size=64, validation_data=val_py_dataset, epochs=1\n)\n\n\"\"\"\nEvaluating the model is just as easy:\n\"\"\"\n\nmodel.evaluate(val_py_dataset)\n\n\"\"\"\nImportantly, `PyDataset` objects support three common constructor arguments\nthat handle the parallel processing configuration:\n\n- `workers`: Number of workers to use in multithreading or\n    multiprocessing. 
Typically, you'd set it to the number of\n    cores on your CPU.\n- `use_multiprocessing`: Whether to use Python multiprocessing for\n    parallelism. Setting this to `True` means that your\n    dataset will be replicated in multiple forked processes.\n    This is necessary to gain compute-level (rather than I/O level)\n    benefits from parallelism. However it can only be set to\n    `True` if your dataset can be safely pickled.\n- `max_queue_size`: Maximum number of batches to keep in the queue\n    when iterating over the dataset in a multithreaded or\n    multiprocessed setting.\n    You can reduce this value to reduce the CPU memory consumption of\n    your dataset. It defaults to 10.\n\nBy default, multiprocessing is disabled (`use_multiprocessing=False`) and only\none thread is used. You should make sure to only turn on `use_multiprocessing` if\nyour code is running inside a Python `if __name__ == \"__main__\":` block in order\nto avoid issues.\n\nHere's a 4-thread, non-multiprocessed example:\n\"\"\"\n\ntrain_py_dataset = ExamplePyDataset(x_train, y_train, batch_size=32, workers=4)\nval_py_dataset = ExamplePyDataset(x_val, y_val, batch_size=32, workers=4)\n\nmodel = get_compiled_model()\nmodel.fit(\n    train_py_dataset, batch_size=64, validation_data=val_py_dataset, epochs=1\n)\n\n\"\"\"\n## Training & evaluation using PyTorch `DataLoader` objects\n\nAll built-in training and evaluation APIs are also compatible with `torch.utils.data.Dataset` and\n`torch.utils.data.DataLoader` objects -- regardless of whether you're using the PyTorch backend,\nor the JAX or TensorFlow backends. 
Let's take a look at a simple example.\n\nUnlike `PyDataset`, which is batch-centric, PyTorch `Dataset` objects are sample-centric:\nthe `__len__` method returns the number of samples,\nand the `__getitem__` method returns a specific sample.\n\"\"\"\n\n\nclass ExampleTorchDataset(torch.utils.data.Dataset):\n    def __init__(self, x, y):\n        self.x = x\n        self.y = y\n\n    def __len__(self):\n        return len(self.x)\n\n    def __getitem__(self, idx):\n        return self.x[idx], self.y[idx]\n\n\ntrain_torch_dataset = ExampleTorchDataset(x_train, y_train)\nval_torch_dataset = ExampleTorchDataset(x_val, y_val)\n\n\"\"\"\nTo use a PyTorch Dataset, you need to wrap it in a `DataLoader`, which takes care\nof batching and shuffling:\n\"\"\"\n\ntrain_dataloader = torch.utils.data.DataLoader(\n    train_torch_dataset, batch_size=32, shuffle=True\n)\nval_dataloader = torch.utils.data.DataLoader(\n    val_torch_dataset, batch_size=32, shuffle=True\n)\n\n\"\"\"\nNow you can use them in the Keras API just like any other iterator:\n\"\"\"\n\nmodel = get_compiled_model()\nmodel.fit(\n    train_dataloader, batch_size=64, validation_data=val_dataloader, epochs=1\n)\nmodel.evaluate(val_dataloader)\n\n\"\"\"\n## Using sample weighting and class weighting\n\nWith the default settings, the weight of a sample is decided by its frequency\nin the dataset. There are two methods to weight the data, independent of\nsample frequency:\n\n* Class weights\n* Sample weights\n\"\"\"\n\n\"\"\"\n### Class weights\n\nThis is set by passing a dictionary to the `class_weight` argument to\n`Model.fit()`. 
This dictionary maps class indices to the weight that should\nbe used for samples belonging to this class.\n\nThis can be used to balance classes without resampling, or to train a\nmodel that gives more importance to a particular class.\n\nFor instance, if class \"0\" is half as represented as class \"1\" in your data,\nyou could use `Model.fit(..., class_weight={0: 1., 1: 0.5})`.\n\"\"\"\n\n\"\"\"\nHere's a NumPy example where we use class weights or sample weights to\ngive more importance to the correct classification of class #5 (which\nis the digit \"5\" in the MNIST dataset).\n\"\"\"\n\nclass_weight = {\n    0: 1.0,\n    1: 1.0,\n    2: 1.0,\n    3: 1.0,\n    4: 1.0,\n    # Set weight \"2\" for class \"5\",\n    # making this class 2x more important\n    5: 2.0,\n    6: 1.0,\n    7: 1.0,\n    8: 1.0,\n    9: 1.0,\n}\n\nprint(\"Fit with class weight\")\nmodel = get_compiled_model()\nmodel.fit(x_train, y_train, class_weight=class_weight, batch_size=64, epochs=1)\n\n\"\"\"\n### Sample weights\n\nFor fine grained control, or if you are not building a classifier,\nyou can use \"sample weights\".\n\n- When training from NumPy data: Pass the `sample_weight`\n  argument to `Model.fit()`.\n- When training from `tf.data` or any other sort of iterator:\n  Yield `(input_batch, label_batch, sample_weight_batch)` tuples.\n\nA \"sample weights\" array is an array of numbers that specify how much weight\neach sample in a batch should have in computing the total loss. 
It is commonly\nused in imbalanced classification problems (the idea being to give more weight\nto rarely-seen classes).\n\nWhen the weights used are ones and zeros, the array can be used as a *mask* for\nthe loss function (entirely discarding the contribution of certain samples to\nthe total loss).\n\"\"\"\n\nsample_weight = np.ones(shape=(len(y_train),))\nsample_weight[y_train == 5] = 2.0\n\nprint(\"Fit with sample weight\")\nmodel = get_compiled_model()\nmodel.fit(\n    x_train, y_train, sample_weight=sample_weight, batch_size=64, epochs=1\n)\n\n\"\"\"\nHere's a matching `Dataset` example:\n\"\"\"\n\nsample_weight = np.ones(shape=(len(y_train),))\nsample_weight[y_train == 5] = 2.0\n\n# Create a Dataset that includes sample weights\n# (3rd element in the return tuple).\ntrain_dataset = tf.data.Dataset.from_tensor_slices(\n    (x_train, y_train, sample_weight)\n)\n\n# Shuffle and slice the dataset.\ntrain_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n\nmodel = get_compiled_model()\nmodel.fit(train_dataset, epochs=1)\n\n\"\"\"\n## Passing data to multi-input, multi-output models\n\nIn the previous examples, we were considering a model with a single input (a tensor of\nshape `(784,)`) and a single output (a prediction tensor of shape `(10,)`). But what\nabout models that have multiple inputs or outputs?\n\nConsider the following model, which has an image input of shape `(32, 32, 3)` (that's\n`(height, width, channels)`) and a time series input of shape `(None, 10)` (that's\n`(timesteps, features)`). 
Our model will have two outputs computed from the\ncombination of these inputs: a \"score\" (of shape `(1,)`) and a probability\ndistribution over five classes (of shape `(5,)`).\n\"\"\"\n\nimage_input = keras.Input(shape=(32, 32, 3), name=\"img_input\")\ntimeseries_input = keras.Input(shape=(None, 10), name=\"ts_input\")\n\nx1 = layers.Conv2D(3, 3)(image_input)\nx1 = layers.GlobalMaxPooling2D()(x1)\n\nx2 = layers.Conv1D(3, 3)(timeseries_input)\nx2 = layers.GlobalMaxPooling1D()(x2)\n\nx = layers.concatenate([x1, x2])\n\nscore_output = layers.Dense(1, name=\"score_output\")(x)\nclass_output = layers.Dense(5, name=\"class_output\")(x)\n\nmodel = keras.Model(\n    inputs=[image_input, timeseries_input], outputs=[score_output, class_output]\n)\n\n\"\"\"\nLet's plot this model, so you can clearly see what we're doing here (note that the\nshapes shown in the plot are batch shapes, rather than per-sample shapes).\n\"\"\"\n\nkeras.utils.plot_model(\n    model, \"multi_input_and_output_model.png\", show_shapes=True\n)\n\n\"\"\"\nAt compilation time, we can specify different losses to different outputs, by passing\nthe loss functions as a list:\n\"\"\"\n\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(1e-3),\n    loss=[\n        keras.losses.MeanSquaredError(),\n        keras.losses.CategoricalCrossentropy(),\n    ],\n)\n\n\"\"\"\nIf we only passed a single loss function to the model, the same loss function would be\napplied to every output (which is not appropriate here).\n\nLikewise for metrics:\n\"\"\"\n\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(1e-3),\n    loss=[\n        keras.losses.MeanSquaredError(),\n        keras.losses.CategoricalCrossentropy(),\n    ],\n    metrics=[\n        [\n            keras.metrics.MeanAbsolutePercentageError(),\n            keras.metrics.MeanAbsoluteError(),\n        ],\n        [keras.metrics.CategoricalAccuracy()],\n    ],\n)\n\n\"\"\"\nSince we gave names to our output layers, we could also specify per-output 
losses and\nmetrics via a dict:\n\"\"\"\n\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(1e-3),\n    loss={\n        \"score_output\": keras.losses.MeanSquaredError(),\n        \"class_output\": keras.losses.CategoricalCrossentropy(),\n    },\n    metrics={\n        \"score_output\": [\n            keras.metrics.MeanAbsolutePercentageError(),\n            keras.metrics.MeanAbsoluteError(),\n        ],\n        \"class_output\": [keras.metrics.CategoricalAccuracy()],\n    },\n)\n\n\"\"\"\nWe recommend the use of explicit names and dicts if you have more than 2 outputs.\n\nIt's possible to give different weights to different output-specific losses (for\ninstance, one might wish to privilege the \"score\" loss in our example, by giving it 2x\nthe importance of the class loss), using the `loss_weights` argument:\n\"\"\"\n\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(1e-3),\n    loss={\n        \"score_output\": keras.losses.MeanSquaredError(),\n        \"class_output\": keras.losses.CategoricalCrossentropy(),\n    },\n    metrics={\n        \"score_output\": [\n            keras.metrics.MeanAbsolutePercentageError(),\n            keras.metrics.MeanAbsoluteError(),\n        ],\n        \"class_output\": [keras.metrics.CategoricalAccuracy()],\n    },\n    loss_weights={\"score_output\": 2.0, \"class_output\": 1.0},\n)\n\n\"\"\"\nYou could also choose not to compute a loss for certain outputs, if these outputs are\nmeant for prediction but not for training:\n\"\"\"\n\n# List loss version\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(1e-3),\n    loss=[None, keras.losses.CategoricalCrossentropy()],\n)\n\n# Or dict loss version\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(1e-3),\n    loss={\"class_output\": keras.losses.CategoricalCrossentropy()},\n)\n\n\"\"\"\nPassing data to a multi-input or multi-output model in `fit()` works in a similar way as\nspecifying a loss function in compile: you can pass **lists of NumPy arrays** 
(with\n1:1 mapping to the outputs that received a loss function) or **dicts mapping output\nnames to NumPy arrays**.\n\"\"\"\n\nmodel.compile(\n    optimizer=keras.optimizers.RMSprop(1e-3),\n    loss=[\n        keras.losses.MeanSquaredError(),\n        keras.losses.CategoricalCrossentropy(),\n    ],\n)\n\n# Generate dummy NumPy data\nimg_data = np.random.random_sample(size=(100, 32, 32, 3))\nts_data = np.random.random_sample(size=(100, 20, 10))\nscore_targets = np.random.random_sample(size=(100, 1))\nclass_targets = np.random.random_sample(size=(100, 5))\n\n# Fit on lists\nmodel.fit(\n    [img_data, ts_data], [score_targets, class_targets], batch_size=32, epochs=1\n)\n\n# Alternatively, fit on dicts\nmodel.fit(\n    {\"img_input\": img_data, \"ts_input\": ts_data},\n    {\"score_output\": score_targets, \"class_output\": class_targets},\n    batch_size=32,\n    epochs=1,\n)\n\n\"\"\"\nHere's the `Dataset` use case: similarly to what we did for NumPy arrays, the `Dataset`\nshould return a tuple of dicts.\n\"\"\"\n\ntrain_dataset = tf.data.Dataset.from_tensor_slices(\n    (\n        {\"img_input\": img_data, \"ts_input\": ts_data},\n        {\"score_output\": score_targets, \"class_output\": class_targets},\n    )\n)\ntrain_dataset = train_dataset.shuffle(buffer_size=1024).batch(64)\n\nmodel.fit(train_dataset, epochs=1)\n\n\"\"\"\n## Using callbacks\n\nCallbacks in Keras are objects that are called at different points during training (at\nthe start of an epoch, at the end of a batch, at the end of an epoch, etc.). 
They\ncan be used to implement certain behaviors, such as:\n\n- Doing validation at different points during training (beyond the built-in per-epoch\nvalidation)\n- Checkpointing the model at regular intervals or when it exceeds a certain accuracy\nthreshold\n- Changing the learning rate of the model when training seems to be plateauing\n- Doing fine-tuning of the top layers when training seems to be plateauing\n- Sending email or instant message notifications when training ends or when a certain\nperformance threshold is exceeded\n- Etc.\n\nCallbacks can be passed as a list to your call to `fit()`:\n\"\"\"\n\nmodel = get_compiled_model()\n\ncallbacks = [\n    keras.callbacks.EarlyStopping(\n        # Stop training when `val_loss` is no longer improving\n        monitor=\"val_loss\",\n        # \"no longer improving\" being defined as \"no better than 1e-2 less\"\n        min_delta=1e-2,\n        # \"no longer improving\" being further defined as \"for at least 2 epochs\"\n        patience=2,\n        verbose=1,\n    )\n]\nmodel.fit(\n    x_train,\n    y_train,\n    epochs=20,\n    batch_size=64,\n    callbacks=callbacks,\n    validation_split=0.2,\n)\n\n\"\"\"\n### Many built-in callbacks are available\n\nThere are many built-in callbacks already available in Keras, such as:\n\n- `ModelCheckpoint`: Periodically save the model.\n- `EarlyStopping`: Stop training when training is no longer improving the validation\nmetrics.\n- `TensorBoard`: periodically write model logs that can be visualized in\n[TensorBoard](https://www.tensorflow.org/tensorboard) (more details in the section\n\"Visualization\").\n- `CSVLogger`: streams loss and metrics data to a CSV file.\n- etc.\n\nSee the [callbacks documentation](/api/callbacks/) for the complete list.\n\n### Writing your own callback\n\nYou can create a custom callback by extending the base class\n`keras.callbacks.Callback`. 
A callback has access to its associated model through the\nclass property `self.model`.\n\nMake sure to read the\n[complete guide to writing custom callbacks](/guides/writing_your_own_callbacks/).\n\nHere's a simple example saving a list of per-batch loss values during training:\n\"\"\"\n\n\nclass LossHistory(keras.callbacks.Callback):\n    def on_train_begin(self, logs):\n        self.per_batch_losses = []\n\n    def on_batch_end(self, batch, logs):\n        self.per_batch_losses.append(logs.get(\"loss\"))\n\n\n\"\"\"\n## Checkpointing models\n\nWhen you're training a model on relatively large datasets, it's crucial to save\ncheckpoints of your model at frequent intervals.\n\nThe easiest way to achieve this is with the `ModelCheckpoint` callback:\n\"\"\"\n\nmodel = get_compiled_model()\n\ncallbacks = [\n    keras.callbacks.ModelCheckpoint(\n        # Path where to save the model\n        # The two parameters below mean that we will overwrite\n        # the current checkpoint if and only if\n        # the `val_loss` score has improved.\n        # The saved model name will include the current epoch.\n        filepath=\"mymodel_{epoch}.keras\",\n        save_best_only=True,  # Only save a model if `val_loss` has improved.\n        monitor=\"val_loss\",\n        verbose=1,\n    )\n]\nmodel.fit(\n    x_train,\n    y_train,\n    epochs=2,\n    batch_size=64,\n    callbacks=callbacks,\n    validation_split=0.2,\n)\n\n\"\"\"\nThe `ModelCheckpoint` callback can be used to implement fault-tolerance:\nthe ability to restart training from the last saved state of the model in case training\ngets randomly interrupted. 
Here's a basic example:\n\"\"\"\n\n# Prepare a directory to store all the checkpoints.\ncheckpoint_dir = \"./ckpt\"\nif not os.path.exists(checkpoint_dir):\n    os.makedirs(checkpoint_dir)\n\n\ndef make_or_restore_model():\n    # Either restore the latest model, or create a fresh one\n    # if there is no checkpoint available.\n    checkpoints = [\n        os.path.join(checkpoint_dir, name)\n        for name in os.listdir(checkpoint_dir)\n    ]\n    if checkpoints:\n        latest_checkpoint = max(checkpoints, key=os.path.getctime)\n        print(\"Restoring from\", latest_checkpoint)\n        return keras.models.load_model(latest_checkpoint)\n    print(\"Creating a new model\")\n    return get_compiled_model()\n\n\nmodel = make_or_restore_model()\ncallbacks = [\n    # This callback saves the model every 100 batches.\n    # We include the training loss in the saved model name.\n    keras.callbacks.ModelCheckpoint(\n        filepath=os.path.join(checkpoint_dir, \"model-loss={loss:.2f}.keras\"),\n        save_freq=100,\n    )\n]\nmodel.fit(x_train, y_train, epochs=1, callbacks=callbacks)\n\n\"\"\"\nYou can also write your own callback for saving and restoring models.\n\nFor a complete guide on serialization and saving, see the\n[guide to saving and serializing Models](/guides/serialization_and_saving/).\n\"\"\"\n\n\"\"\"\n## Using learning rate schedules\n\nA common pattern when training deep learning models is to gradually reduce the learning\nrate as training progresses. 
This is generally known as \"learning rate decay\".\n\nThe learning rate decay schedule could be static (fixed in advance, as a function of the\ncurrent epoch or the current batch index), or dynamic (responding to the current\nbehavior of the model, in particular the validation loss).\n\n### Passing a schedule to an optimizer\n\nYou can easily use a static learning rate decay schedule by passing a schedule object\nas the `learning_rate` argument in your optimizer:\n\"\"\"\n\ninitial_learning_rate = 0.1\nlr_schedule = keras.optimizers.schedules.ExponentialDecay(\n    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True\n)\n\noptimizer = keras.optimizers.RMSprop(learning_rate=lr_schedule)\n\n\"\"\"\nSeveral built-in schedules are available: `ExponentialDecay`, `PiecewiseConstantDecay`,\n`PolynomialDecay`, and `InverseTimeDecay`.\n\n### Using callbacks to implement a dynamic learning rate schedule\n\nA dynamic learning rate schedule (for instance, decreasing the learning rate when the\nvalidation loss is no longer improving) cannot be achieved with these schedule objects,\nsince the optimizer does not have access to validation metrics.\n\nHowever, callbacks do have access to all metrics, including validation metrics! You can\nthus achieve this pattern by using a callback that modifies the current learning rate\non the optimizer. 
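The plateau-detection logic such a callback would run at the end of each epoch can be sketched framework-free; the class name, hook name, and default thresholds below are illustrative, not part of the Keras API:

```python
class PlateauLRSchedule:
    """Sketch of plateau-based learning rate reduction.

    A callback implementing this pattern would call `step()` from its
    `on_epoch_end()` hook with the epoch's `val_loss`, then assign the
    returned value to the optimizer's learning rate.
    """

    def __init__(self, lr=1e-3, factor=0.5, patience=2, min_delta=1e-4):
        self.lr = lr
        self.factor = factor  # Multiply the learning rate by this on plateau.
        self.patience = patience  # Epochs without improvement to tolerate.
        self.min_delta = min_delta  # Smallest decrease counting as improvement.
        self.best = float("inf")
        self.wait = 0

    def step(self, val_loss):
        if val_loss < self.best - self.min_delta:
            # The monitored metric improved: reset the patience counter.
            self.best = val_loss
            self.wait = 0
        else:
            # No improvement this epoch; shrink the LR once patience runs out.
            self.wait += 1
            if self.wait >= self.patience:
                self.lr *= self.factor
                self.wait = 0
        return self.lr
```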
In fact, this is even built-in as the `ReduceLROnPlateau` callback.\n\"\"\"\n\n\"\"\"\n## Visualizing loss and metrics during training with TensorBoard\n\nThe best way to keep an eye on your model during training is to use\n[TensorBoard](https://www.tensorflow.org/tensorboard) -- a browser-based application\nthat you can run locally that provides you with:\n\n- Live plots of the loss and metrics for training and evaluation\n- (optionally) Visualizations of the histograms of your layer activations\n- (optionally) 3D visualizations of the embedding spaces learned by your `Embedding`\nlayers\n\nIf you have installed TensorFlow with pip, you should be able to launch TensorBoard\nfrom the command line:\n\n```\ntensorboard --logdir=/full_path_to_your_logs\n```\n\"\"\"\n\n\"\"\"\n### Using the TensorBoard callback\n\nThe easiest way to use TensorBoard with a Keras model and the `fit()` method is the\n`TensorBoard` callback.\n\nIn the simplest case, just specify where you want the callback to write logs, and\nyou're good to go:\n\"\"\"\n\nkeras.callbacks.TensorBoard(\n    log_dir=\"/full_path_to_your_logs\",\n    histogram_freq=0,  # How often to log histogram visualizations\n    embeddings_freq=0,  # How often to log embedding visualizations\n    update_freq=\"epoch\",\n)  # How often to write logs (default: once per epoch)\n\n\"\"\"\nFor more information, see the\n[documentation for the `TensorBoard` callback](https://keras.io/api/callbacks/tensorboard/).\n\"\"\"\n"
  },
  {
    "path": "guides/transfer_learning.py",
    "content": "\"\"\"\nTitle: Transfer learning & fine-tuning\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2020/04/15\nLast modified: 2023/06/25\nDescription: Complete guide to transfer learning & fine-tuning in Keras.\nAccelerator: GPU\n\"\"\"\n\n\"\"\"\n## Setup\n\"\"\"\n\nimport numpy as np\nimport keras\nfrom keras import layers\nimport tensorflow_datasets as tfds\nimport matplotlib.pyplot as plt\n\n\"\"\"\n## Introduction\n\n**Transfer learning** consists of taking features learned on one problem, and\nleveraging them on a new, similar problem. For instance, features from a model that has\nlearned to identify raccoons may be useful to kick-start a model meant to identify\n tanukis.\n\nTransfer learning is usually done for tasks where your dataset has too little data to\n train a full-scale model from scratch.\n\nThe most common incarnation of transfer learning in the context of deep learning is the\n following workflow:\n\n1. Take layers from a previously trained model.\n2. Freeze them, so as to avoid destroying any of the information they contain during\n future training rounds.\n3. Add some new, trainable layers on top of the frozen layers. They will learn to turn\n the old features into predictions on a  new dataset.\n4. Train the new layers on your dataset.\n\nA last, optional step, is **fine-tuning**, which consists of unfreezing the entire\nmodel you obtained above (or part of it), and re-training it on the new data with a\nvery low learning rate. 
This can potentially achieve meaningful improvements, by\n incrementally adapting the pretrained features to the new data.\n\nFirst, we will go over the Keras `trainable` API in detail, which underlies most\n transfer learning & fine-tuning workflows.\n\nThen, we'll demonstrate the typical workflow by taking a model pretrained on the\nImageNet dataset, and retraining it on the Kaggle \"cats vs dogs\" classification\n dataset.\n\nThis is adapted from\n[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python)\nand the 2016 blog post\n[\"building powerful image classification models using very little data\"](https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html).\n\"\"\"\n\n\"\"\"\n## Freezing layers: understanding the `trainable` attribute\n\nLayers & models have three weight attributes:\n\n- `weights` is the list of all weights variables of the layer.\n- `trainable_weights` is the list of those that are meant to be updated (via gradient\n descent) to minimize the loss during training.\n- `non_trainable_weights` is the list of those that aren't meant to be trained.\n Typically they are updated by the model during the forward pass.\n\n**Example: the `Dense` layer has 2 trainable weights (kernel & bias)**\n\"\"\"\n\nlayer = keras.layers.Dense(3)\nlayer.build((None, 4))  # Create the weights\n\nprint(\"weights:\", len(layer.weights))\nprint(\"trainable_weights:\", len(layer.trainable_weights))\nprint(\"non_trainable_weights:\", len(layer.non_trainable_weights))\n\n\"\"\"\nIn general, all weights are trainable weights. The only built-in layer that has\nnon-trainable weights is the `BatchNormalization` layer. 
It uses non-trainable weights\n to keep track of the mean and variance of its inputs during training.\nTo learn how to use non-trainable weights in your own custom layers, see the\n[guide to writing new layers from scratch](https://keras.io/guides/making_new_layers_and_models_via_subclassing/).\n\n**Example: the `BatchNormalization` layer has 2 trainable weights and 2 non-trainable\n weights**\n\"\"\"\n\nlayer = keras.layers.BatchNormalization()\nlayer.build((None, 4))  # Create the weights\n\nprint(\"weights:\", len(layer.weights))\nprint(\"trainable_weights:\", len(layer.trainable_weights))\nprint(\"non_trainable_weights:\", len(layer.non_trainable_weights))\n\n\"\"\"\nLayers & models also feature a boolean attribute `trainable`. Its value can be changed.\nSetting `layer.trainable` to `False` moves all the layer's weights from trainable to\nnon-trainable.  This is called \"freezing\" the layer: the state of a frozen layer won't\nbe updated during training (either when training with `fit()` or when training with\n any custom loop that relies on `trainable_weights` to apply gradient updates).\n\n**Example: setting `trainable` to `False`**\n\"\"\"\n\nlayer = keras.layers.Dense(3)\nlayer.build((None, 4))  # Create the weights\nlayer.trainable = False  # Freeze the layer\n\nprint(\"weights:\", len(layer.weights))\nprint(\"trainable_weights:\", len(layer.trainable_weights))\nprint(\"non_trainable_weights:\", len(layer.non_trainable_weights))\n\n\"\"\"\nWhen a trainable weight becomes non-trainable, its value is no longer updated during\n training.\n\"\"\"\n\n# Make a model with 2 layers\nlayer1 = keras.layers.Dense(3, activation=\"relu\")\nlayer2 = keras.layers.Dense(3, activation=\"sigmoid\")\nmodel = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])\n\n# Freeze the first layer\nlayer1.trainable = False\n\n# Keep a copy of the weights of layer1 for later reference\ninitial_layer1_weights_values = layer1.get_weights()\n\n# Train the 
model\nmodel.compile(optimizer=\"adam\", loss=\"mse\")\nmodel.fit(np.random.random((2, 3)), np.random.random((2, 3)))\n\n# Check that the weights of layer1 have not changed during training\nfinal_layer1_weights_values = layer1.get_weights()\nnp.testing.assert_allclose(\n    initial_layer1_weights_values[0], final_layer1_weights_values[0]\n)\nnp.testing.assert_allclose(\n    initial_layer1_weights_values[1], final_layer1_weights_values[1]\n)\n\n\"\"\"\nDo not confuse the `layer.trainable` attribute with the argument `training` in\n`layer.__call__()` (which controls whether the layer should run its forward pass in\n inference mode or training mode). For more information, see the\n[Keras FAQ](\n  https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute).\n\"\"\"\n\n\"\"\"\n## Recursive setting of the `trainable` attribute\n\nIf you set `trainable = False` on a model or on any layer that has sublayers,\nall children layers become non-trainable as well.\n\n**Example:**\n\"\"\"\n\ninner_model = keras.Sequential(\n    [\n        keras.Input(shape=(3,)),\n        keras.layers.Dense(3, activation=\"relu\"),\n        keras.layers.Dense(3, activation=\"relu\"),\n    ]\n)\n\nmodel = keras.Sequential(\n    [\n        keras.Input(shape=(3,)),\n        inner_model,\n        keras.layers.Dense(3, activation=\"sigmoid\"),\n    ]\n)\n\nmodel.trainable = False  # Freeze the outer model\n\nassert inner_model.trainable == False  # All layers in `model` are now frozen\nassert (\n    inner_model.layers[0].trainable == False\n)  # `trainable` is propagated recursively\n\n\"\"\"\n## The typical transfer-learning workflow\n\nThis leads us to how a typical transfer learning workflow can be implemented in Keras:\n\n1. Instantiate a base model and load pre-trained weights into it.\n2. Freeze all layers in the base model by setting `trainable = False`.\n3. 
Create a new model on top of the output of one (or several) layers from the base\n model.\n4. Train your new model on your new dataset.\n\nNote that an alternative, more lightweight workflow could also be:\n\n1. Instantiate a base model and load pre-trained weights into it.\n2. Run your new dataset through it and record the output of one (or several) layers\n from the base model. This is called **feature extraction**.\n3. Use that output as input data for a new, smaller model.\n\nA key advantage of that second workflow is that you only run the base model once on\n your data, rather than once per epoch of training. So it's a lot faster & cheaper.\n\nAn issue with that second workflow, though, is that it doesn't allow you to dynamically\nmodify the input data of your new model during training, which is required when doing\ndata augmentation, for instance. Transfer learning is typically used for tasks where\nyour new dataset has too little data to train a full-scale model from scratch, and in\nsuch scenarios data augmentation is very important. So in what follows, we will focus\n on the first workflow.\n\nHere's what the first workflow looks like in Keras:\n\nFirst, instantiate a base model with pre-trained weights.\n\n```python\nbase_model = keras.applications.Xception(\n    weights='imagenet',  # Load weights pre-trained on ImageNet.\n    input_shape=(150, 150, 3),\n    include_top=False)  # Do not include the ImageNet classifier at the top.\n```\n\nThen, freeze the base model.\n\n```python\nbase_model.trainable = False\n```\n\nCreate a new model on top.\n\n```python\ninputs = keras.Input(shape=(150, 150, 3))\n# We make sure that the base_model is running in inference mode here,\n# by passing `training=False`. 
This is important for fine-tuning, as you will\n# learn in a few paragraphs.\nx = base_model(inputs, training=False)\n# Convert features of shape `base_model.output_shape[1:]` to vectors\nx = keras.layers.GlobalAveragePooling2D()(x)\n# A Dense classifier with a single unit (binary classification)\noutputs = keras.layers.Dense(1)(x)\nmodel = keras.Model(inputs, outputs)\n```\n\nTrain the model on new data.\n\n```python\nmodel.compile(optimizer=keras.optimizers.Adam(),\n              loss=keras.losses.BinaryCrossentropy(from_logits=True),\n              metrics=[keras.metrics.BinaryAccuracy()])\nmodel.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)\n```\n\n\"\"\"\n\n\"\"\"\n## Fine-tuning\n\nOnce your model has converged on the new data, you can try to unfreeze all or part of\n the base model and retrain the whole model end-to-end with a very low learning rate.\n\nThis is an optional last step that can potentially give you incremental improvements.\n It could also potentially lead to quick overfitting -- keep that in mind.\n\nIt is critical to only do this step *after* the model with frozen layers has been\ntrained to convergence. If you mix randomly-initialized trainable layers with\ntrainable layers that hold pre-trained features, the randomly-initialized layers will\ncause very large gradient updates during training, which will destroy your pre-trained\n features.\n\nIt's also critical to use a very low learning rate at this stage, because\nyou are training a much larger model than in the first round of training, on a dataset\n that is typically very small.\nAs a result, you are at risk of overfitting very quickly if you apply large weight\n updates. 
Here, you only want to readapt the pretrained weights in an incremental way.\n\nThis is how to implement fine-tuning of the whole base model:\n\n```python\n# Unfreeze the base model\nbase_model.trainable = True\n\n# It's important to recompile your model after you make any changes\n# to the `trainable` attribute of any inner layer, so that your changes\n# are taken into account\nmodel.compile(optimizer=keras.optimizers.Adam(1e-5),  # Very low learning rate\n              loss=keras.losses.BinaryCrossentropy(from_logits=True),\n              metrics=[keras.metrics.BinaryAccuracy()])\n\n# Train end-to-end. Be careful to stop before you overfit!\nmodel.fit(new_dataset, epochs=10, callbacks=..., validation_data=...)\n```\n\n**Important note about `compile()` and `trainable`**\n\nCalling `compile()` on a model is meant to \"freeze\" the behavior of that model. This\n implies that the `trainable`\nattribute values at the time the model is compiled should be preserved throughout the\n lifetime of that model,\nuntil `compile` is called again. Hence, if you change any `trainable` value, make sure\n to call `compile()` again on your\nmodel for your changes to be taken into account.\n\n**Important notes about `BatchNormalization` layer**\n\nMany image models contain `BatchNormalization` layers. That layer is a special case on\n every imaginable count. Here are a few things to keep in mind.\n\n- `BatchNormalization` contains 2 non-trainable weights that get updated during\ntraining. These are the variables tracking the mean and variance of the inputs.\n- When you set `bn_layer.trainable = False`, the `BatchNormalization` layer will\nrun in inference mode, and will not update its mean & variance statistics. 
This is not\nthe case for other layers in general, as\n[weight trainability & inference/training modes are two orthogonal concepts](\n  https://keras.io/getting_started/faq/#whats-the-difference-between-the-training-argument-in-call-and-the-trainable-attribute).\nBut the two are tied in the case of the `BatchNormalization` layer.\n- When you unfreeze a model that contains `BatchNormalization` layers in order to\nfine-tune it, by setting `base_model.trainable = True`, all of its layers become\ntrainable, including the `BatchNormalization` layers. It's a good idea to keep\n`BatchNormalization` either frozen during fine-tuning, or running in inference mode,\nso remember to set `layer.trainable = False`\non those layers specifically after unfreezing the outer model, or otherwise\ncall the model with `training=False` to keep it in inference mode.\n\nYou'll see this pattern in action in the end-to-end example at the end of this guide.\n\"\"\"\n\n\"\"\"\n## An end-to-end example: fine-tuning an image classification model on a cats vs. dogs dataset\n\nTo solidify these concepts, let's walk you through a concrete end-to-end transfer\nlearning & fine-tuning example. We will load the Xception model, pre-trained on\n ImageNet, and use it on the Kaggle \"cats vs. dogs\" classification dataset.\n\"\"\"\n\n\"\"\"\n### Getting the data\n\nFirst, let's fetch the cats vs. dogs dataset using TFDS. If you have your own dataset,\nyou'll probably want to use the utility\n`keras.utils.image_dataset_from_directory` to generate similar labeled\n dataset objects from a set of images on disk filed into class-specific folders.\n\nTransfer learning is most useful when working with very small datasets. 
To keep our\ndataset small, we will use 40% of the original training data (25,000 images) for\n training, 10% for validation, and 10% for testing.\n\"\"\"\n\ntfds.disable_progress_bar()\n\ntrain_ds, validation_ds, test_ds = tfds.load(\n    \"cats_vs_dogs\",\n    # Reserve 10% for validation and 10% for test\n    split=[\"train[:40%]\", \"train[40%:50%]\", \"train[50%:60%]\"],\n    as_supervised=True,  # Include labels\n)\n\nprint(f\"Number of training samples: {train_ds.cardinality()}\")\nprint(f\"Number of validation samples: {validation_ds.cardinality()}\")\nprint(f\"Number of test samples: {test_ds.cardinality()}\")\n\n\"\"\"\nThese are the first 9 images in the training dataset -- as you can see, they're all\ndifferent sizes.\n\"\"\"\n\nplt.figure(figsize=(10, 10))\nfor i, (image, label) in enumerate(train_ds.take(9)):\n    ax = plt.subplot(3, 3, i + 1)\n    plt.imshow(image)\n    plt.title(int(label))\n    plt.axis(\"off\")\n\n\"\"\"\nWe can also see that label 1 is \"dog\" and label 0 is \"cat\".\n\"\"\"\n\n\"\"\"\n### Standardizing the data\n\nOur raw images have a variety of sizes. In addition, each pixel consists of 3 integer\nvalues between 0 and 255 (RGB level values). This isn't a great fit for feeding a\nneural network. We need to do 2 things:\n\n- Standardize to a fixed image size. We pick 150x150.\n- Normalize pixel values between -1 and 1. We'll do this using a `Rescaling` layer as\npart of the model itself.\n\nIn general, it's a good practice to develop models that take raw data as input, as\nopposed to models that take already-preprocessed data. The reason being that, if your\nmodel expects preprocessed data, any time you export your model to use it elsewhere\n(in a web browser, in a mobile app), you'll need to reimplement the exact same\npreprocessing pipeline. This gets very tricky very quickly. 
So we should do the least\n possible amount of preprocessing before hitting the model.\n\nHere, we'll do image resizing in the data pipeline (because a deep neural network can\nonly process contiguous batches of data), and we'll do the input value scaling as part\nof the model, when we create it.\n\nLet's resize images to 150x150:\n\"\"\"\n\nresize_fn = keras.layers.Resizing(150, 150)\n\ntrain_ds = train_ds.map(lambda x, y: (resize_fn(x), y))\nvalidation_ds = validation_ds.map(lambda x, y: (resize_fn(x), y))\ntest_ds = test_ds.map(lambda x, y: (resize_fn(x), y))\n\n\"\"\"\n### Using random data augmentation\n\nWhen you don't have a large image dataset, it's a good practice to artificially\nintroduce sample diversity by applying random yet realistic transformations to\nthe training images, such as random horizontal flipping or small random rotations. This\nhelps expose the model to different aspects of the training data while slowing down\noverfitting.\n\"\"\"\n\naugmentation_layers = [\n    layers.RandomFlip(\"horizontal\"),\n    layers.RandomRotation(0.1),\n]\n\n\ndef data_augmentation(x):\n    for layer in augmentation_layers:\n        x = layer(x)\n    return x\n\n\ntrain_ds = train_ds.map(lambda x, y: (data_augmentation(x), y))\n\n\"\"\"\nLet's batch the data and use prefetching to optimize loading speed.\n\"\"\"\n\nfrom tensorflow import data as tf_data\n\nbatch_size = 64\n\ntrain_ds = train_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()\nvalidation_ds = (\n    validation_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()\n)\ntest_ds = test_ds.batch(batch_size).prefetch(tf_data.AUTOTUNE).cache()\n\n\"\"\"\nLet's visualize what the first image of the first batch looks like after various random\n transformations:\n\"\"\"\n\nfor images, labels in train_ds.take(1):\n    plt.figure(figsize=(10, 10))\n    first_image = images[0]\n    for i in range(9):\n        ax = plt.subplot(3, 3, i + 1)\n        augmented_image = 
data_augmentation(np.expand_dims(first_image, 0))\n        plt.imshow(np.array(augmented_image[0]).astype(\"int32\"))\n        plt.title(int(labels[0]))\n        plt.axis(\"off\")\n\n\"\"\"\n## Build a model\n\nNow let's build a model that follows the blueprint we've explained earlier.\n\nNote that:\n\n- We add a `Rescaling` layer to scale input values (initially in the `[0, 255]`\n range) to the `[-1, 1]` range.\n- We add a `Dropout` layer before the classification layer, for regularization.\n- We make sure to pass `training=False` when calling the base model, so that\nit runs in inference mode, so that batchnorm statistics don't get updated\neven after we unfreeze the base model for fine-tuning.\n\"\"\"\n\nbase_model = keras.applications.Xception(\n    weights=\"imagenet\",  # Load weights pre-trained on ImageNet.\n    input_shape=(150, 150, 3),\n    include_top=False,\n)  # Do not include the ImageNet classifier at the top.\n\n# Freeze the base_model\nbase_model.trainable = False\n\n# Create new model on top\ninputs = keras.Input(shape=(150, 150, 3))\n\n# Pre-trained Xception weights require that the input be scaled\n# from (0, 255) to a range of (-1., +1.), the rescaling layer\n# outputs: `(inputs * scale) + offset`\nscale_layer = keras.layers.Rescaling(scale=1 / 127.5, offset=-1)\nx = scale_layer(inputs)\n\n# The base model contains batchnorm layers. 
We want to keep them in inference mode\n# when we unfreeze the base model for fine-tuning, so we make sure that the\n# base_model is running in inference mode here.\nx = base_model(x, training=False)\nx = keras.layers.GlobalAveragePooling2D()(x)\nx = keras.layers.Dropout(0.2)(x)  # Regularize with dropout\noutputs = keras.layers.Dense(1)(x)\nmodel = keras.Model(inputs, outputs)\n\nmodel.summary(show_trainable=True)\n\n\"\"\"\n## Train the top layer\n\"\"\"\n\nmodel.compile(\n    optimizer=keras.optimizers.Adam(),\n    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n    metrics=[keras.metrics.BinaryAccuracy()],\n)\n\nepochs = 2\nprint(\"Fitting the top layer of the model\")\nmodel.fit(train_ds, epochs=epochs, validation_data=validation_ds)\n\n\"\"\"\n## Do a round of fine-tuning of the entire model\n\nFinally, let's unfreeze the base model and train the entire model end-to-end with a low\n learning rate.\n\nImportantly, although the base model becomes trainable, it is still running in\ninference mode since we passed `training=False` when calling it when we built the\nmodel. This means that the batch normalization layers inside won't update their batch\nstatistics. If they did, they would wreak havoc on the representations learned by the\n model so far.\n\"\"\"\n\n# Unfreeze the base_model. Note that it keeps running in inference mode\n# since we passed `training=False` when calling it. 
This means that\n# the batchnorm layers will not update their batch statistics.\n# This prevents the batchnorm layers from undoing all the training\n# we've done so far.\nbase_model.trainable = True\nmodel.summary(show_trainable=True)\n\nmodel.compile(\n    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate\n    loss=keras.losses.BinaryCrossentropy(from_logits=True),\n    metrics=[keras.metrics.BinaryAccuracy()],\n)\n\nepochs = 1\nprint(\"Fitting the end-to-end model\")\nmodel.fit(train_ds, epochs=epochs, validation_data=validation_ds)\n\n\"\"\"\nEven after a single epoch, fine-tuning gains us a nice improvement here.\nLet's evaluate the model on the test dataset:\n\"\"\"\n\nprint(\"Test dataset evaluation\")\nmodel.evaluate(test_ds)\n"
  },
  {
    "path": "guides/understanding_masking_and_padding.py",
    "content": "\"\"\"\nTitle: Understanding masking & padding\nAuthors: Scott Zhu, Francois Chollet\nDate created: 2019/07/16\nLast modified: 2023/06/25\nDescription: Complete guide to using mask-aware sequence layers in Keras.\nAccelerator: None\n\"\"\"\n\n\"\"\"\n## Setup\n\"\"\"\nimport numpy as np\nimport keras\nfrom keras import ops\nfrom keras import layers\n\n\"\"\"\n## Introduction\n\n**Masking** is a way to tell sequence-processing layers that certain timesteps\nin an input are missing, and thus should be skipped when processing the data.\n\n**Padding** is a special form of masking where the masked steps are at the start or\nthe end of a sequence. Padding comes from the need to encode sequence data into\ncontiguous batches: in order to make all sequences in a batch fit a given standard\nlength, it is necessary to pad or truncate some sequences.\n\nLet's take a close look.\n\"\"\"\n\n\"\"\"\n## Padding sequence data\n\nWhen processing sequence data, it is very common for individual samples to have\ndifferent lengths. Consider the following example (text tokenized as words):\n\n```\n[\n  [\"Hello\", \"world\", \"!\"],\n  [\"How\", \"are\", \"you\", \"doing\", \"today\"],\n  [\"The\", \"weather\", \"will\", \"be\", \"nice\", \"tomorrow\"],\n]\n```\n\nAfter vocabulary lookup, the data might be vectorized as integers, e.g.:\n\n```\n[\n  [71, 1331, 4231]\n  [73, 8, 3215, 55, 927],\n  [83, 91, 1, 645, 1253, 927],\n]\n```\n\nThe data is a nested list where individual samples have length 3, 5, and 6,\nrespectively. Since the input data for a deep learning model must be a single tensor\n(of shape e.g. 
`(batch_size, 6, vocab_size)` in this case), samples that are shorter\nthan the longest item need to be padded with some placeholder value (alternatively,\none might also truncate long samples before padding short samples).\n\nKeras provides a utility function to truncate and pad Python lists to a common length:\n`keras.utils.pad_sequences`.\n\"\"\"\n\nraw_inputs = [\n    [711, 632, 71],\n    [73, 8, 3215, 55, 927],\n    [83, 91, 1, 645, 1253, 927],\n]\n\n# By default, this will pad using 0s; it is configurable via the\n# \"value\" parameter.\n# Note that you could use \"pre\" padding (at the beginning) or\n# \"post\" padding (at the end).\n# We recommend using \"post\" padding when working with RNN layers\n# (in order to be able to use the\n# CuDNN implementation of the layers).\npadded_inputs = keras.utils.pad_sequences(raw_inputs, padding=\"post\")\nprint(padded_inputs)\n\n\n\"\"\"\n## Masking\n\nNow that all samples have a uniform length, the model must be informed that some part\nof the data is actually padding and should be ignored. 
That mechanism is **masking**.\n\nThere are three ways to introduce input masks in Keras models:\n\n- Add a `keras.layers.Masking` layer.\n- Configure a `keras.layers.Embedding` layer with `mask_zero=True`.\n- Pass a `mask` argument manually when calling layers that support this argument (e.g.\nRNN layers).\n\"\"\"\n\n\"\"\"\n## Mask-generating layers: `Embedding` and `Masking`\n\nUnder the hood, these layers will create a mask tensor (2D tensor with shape `(batch,\nsequence_length)`), and attach it to the tensor output returned by the `Masking` or\n`Embedding` layer.\n\"\"\"\n\nembedding = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)\nmasked_output = embedding(padded_inputs)\n\nprint(masked_output._keras_mask)\n\nmasking_layer = layers.Masking()\n# Simulate the embedding lookup by expanding the 2D input to 3D,\n# with embedding dimension of 10.\nunmasked_embedding = ops.cast(\n    ops.tile(ops.expand_dims(padded_inputs, axis=-1), [1, 1, 10]),\n    dtype=\"float32\",\n)\n\nmasked_embedding = masking_layer(unmasked_embedding)\nprint(masked_embedding._keras_mask)\n\n\"\"\"\nAs you can see from the printed result, the mask is a 2D boolean tensor with shape\n`(batch_size, sequence_length)`, where each individual `False` entry indicates that\nthe corresponding timestep should be ignored during processing.\n\"\"\"\n\n\"\"\"\n## Mask propagation in the Functional API and Sequential API\n\nWhen using the Functional API or the Sequential API, a mask generated by an `Embedding`\nor `Masking` layer will be propagated through the network for any layer that is\ncapable of using them (for example, RNN layers). 
Keras will automatically fetch the\nmask corresponding to an input and pass it to any layer that knows how to use it.\n\nFor instance, in the following Sequential model, the `LSTM` layer will automatically\nreceive a mask, which means it will ignore padded values:\n\"\"\"\n\nmodel = keras.Sequential(\n    [\n        layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True),\n        layers.LSTM(32),\n    ]\n)\n\n\"\"\"\nThis is also the case for the following Functional API model:\n\"\"\"\n\ninputs = keras.Input(shape=(None,), dtype=\"int32\")\nx = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)(inputs)\noutputs = layers.LSTM(32)(x)\n\nmodel = keras.Model(inputs, outputs)\n\n\"\"\"\n## Passing mask tensors directly to layers\n\"\"\"\n\n\"\"\"\nLayers that can handle masks (such as the `LSTM` layer) have a `mask` argument in their\n`__call__` method.\n\nMeanwhile, layers that produce a mask (e.g. `Embedding`) expose a `compute_mask(input,\nprevious_mask)` method which you can call.\n\nThus, you can pass the output of the `compute_mask()` method of a mask-producing layer\nto the `__call__` method of a mask-consuming layer, like this:\n\n\"\"\"\n\n\nclass MyLayer(layers.Layer):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.embedding = layers.Embedding(\n            input_dim=5000, output_dim=16, mask_zero=True\n        )\n        self.lstm = layers.LSTM(32)\n\n    def call(self, inputs):\n        x = self.embedding(inputs)\n        # Note that you could also prepare a `mask` tensor manually.\n        # It only needs to be a boolean tensor\n        # with the right shape, i.e. 
(batch_size, timesteps).\n        mask = self.embedding.compute_mask(inputs)\n        output = self.lstm(\n            x, mask=mask\n        )  # The layer will ignore the masked values\n        return output\n\n\nlayer = MyLayer()\nx = np.random.random((32, 10)) * 100\nx = x.astype(\"int32\")\nlayer(x)\n\n\"\"\"\n## Supporting masking in your custom layers\n\"\"\"\n\n\"\"\"\nSometimes, you may need to write layers that generate a mask (like `Embedding`), or\nlayers that need to modify the current mask.\n\nFor instance, any layer that produces a tensor with a different time dimension than its\ninput, such as a `Concatenate` layer that concatenates on the time dimension, will\nneed to modify the current mask so that downstream layers will be able to properly\ntake masked timesteps into account.\n\nTo do this, your layer should implement the `layer.compute_mask()` method, which\nproduces a new mask given the input and the current mask.\n\nHere is an example of a `TemporalSplit` layer that needs to modify the current mask.\n\"\"\"\n\n\nclass TemporalSplit(keras.layers.Layer):\n    \"\"\"Split the input tensor into 2 tensors along the time dimension.\"\"\"\n\n    def call(self, inputs):\n        # Expect the input to be 3D and mask to be 2D, split the input tensor into 2\n        # subtensors along the time axis (axis 1).\n        return ops.split(inputs, 2, axis=1)\n\n    def compute_mask(self, inputs, mask=None):\n        # Also split the mask into 2 if it presents.\n        if mask is None:\n            return None\n        return ops.split(mask, 2, axis=1)\n\n\nfirst_half, second_half = TemporalSplit()(masked_embedding)\nprint(first_half._keras_mask)\nprint(second_half._keras_mask)\n\n\"\"\"\nHere is another example of a `CustomEmbedding` layer that is capable of generating a\nmask from input values:\n\"\"\"\n\n\nclass CustomEmbedding(keras.layers.Layer):\n    def __init__(self, input_dim, output_dim, mask_zero=False, **kwargs):\n        
super().__init__(**kwargs)\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n        self.mask_zero = mask_zero\n\n    def build(self, input_shape):\n        self.embeddings = self.add_weight(\n            shape=(self.input_dim, self.output_dim),\n            initializer=\"random_normal\",\n            dtype=\"float32\",\n        )\n\n    def call(self, inputs):\n        inputs = ops.cast(inputs, \"int32\")\n        # Look up rows of the embedding matrix by index (axis 0).\n        return ops.take(self.embeddings, inputs, axis=0)\n\n    def compute_mask(self, inputs, mask=None):\n        if not self.mask_zero:\n            return None\n        return ops.not_equal(inputs, 0)\n\n\nlayer = CustomEmbedding(10, 32, mask_zero=True)\nx = np.random.random((3, 10)) * 9\nx = x.astype(\"int32\")\n\ny = layer(x)\nmask = layer.compute_mask(x)\n\nprint(mask)\n\n\"\"\"\nNote: For more details about format limitations related to masking, see the\n[serialization guide](/guides/serialization_and_saving).\n\"\"\"\n\n\"\"\"\n## Opting-in to mask propagation on compatible layers\n\nMost layers don't modify the time dimension, so they don't need to modify the current\nmask. However, they may still want to be able to **propagate** the current mask,\nunchanged, to the next layer. **This is an opt-in behavior.** By default, a custom\nlayer will destroy the current mask (since the framework has no way to tell whether\npropagating the mask is safe to do).\n\nIf you have a custom layer that does not modify the time dimension, and if you want it\nto be able to propagate the current input mask, you should set `self.supports_masking\n= True` in the layer constructor. 
In this case, the default behavior of\n`compute_mask()` is to just pass the current mask through.\n\nHere's an example of a layer that is whitelisted for mask propagation:\n\n\"\"\"\n\n\nclass MyActivation(keras.layers.Layer):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        # Signal that the layer is safe for mask propagation\n        self.supports_masking = True\n\n    def call(self, inputs):\n        return ops.relu(inputs)\n\n\n\"\"\"\nYou can now use this custom layer in-between a mask-generating layer (like `Embedding`)\nand a mask-consuming layer (like `LSTM`), and it will pass the mask along so that it\nreaches the mask-consuming layer.\n\"\"\"\n\ninputs = keras.Input(shape=(None,), dtype=\"int32\")\nx = layers.Embedding(input_dim=5000, output_dim=16, mask_zero=True)(inputs)\nx = MyActivation()(x)  # Will pass the mask along\nprint(\"Mask found:\", x._keras_mask)\noutputs = layers.LSTM(32)(x)  # Will receive the mask\n\nmodel = keras.Model(inputs, outputs)\ny = model(np.random.randint(0, 5000, size=(32, 100)))\n\n\"\"\"\n## Writing layers that need mask information\n\nSome layers are mask *consumers*: they accept a `mask` argument in `call` and use it to\ndetermine whether to skip certain time steps.\n\nTo write such a layer, you can simply add a `mask=None` argument in your `call`\nsignature. 
The mask associated with the inputs will be passed to your layer whenever\nit is available.\n\nHere's a simple example: a layer that computes a softmax over the time dimension\n(axis 1) of an input sequence, while discarding masked timesteps.\n\"\"\"\n\n\nclass TemporalSoftmax(keras.layers.Layer):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.supports_masking = True\n\n    def call(self, inputs, mask=None):\n        assert mask is not None\n        broadcast_float_mask = ops.expand_dims(ops.cast(mask, \"float32\"), -1)\n        # Zero out masked timesteps, then normalize over the time\n        # dimension (axis 1) so that masked steps get zero weight.\n        inputs_exp = ops.exp(inputs) * broadcast_float_mask\n        inputs_sum = ops.sum(inputs_exp, axis=1, keepdims=True)\n        return inputs_exp / inputs_sum\n\n\ninputs = keras.Input(shape=(None,), dtype=\"int32\")\nx = layers.Embedding(input_dim=10, output_dim=32, mask_zero=True)(inputs)\nx = layers.Dense(1)(x)\noutputs = TemporalSoftmax()(x)\n\nmodel = keras.Model(inputs, outputs)\ny = model(np.random.randint(0, 10, size=(32, 100)))\n\n\"\"\"\n## Summary\n\nThat is all you need to know about padding & masking in Keras. To recap:\n\n- \"Masking\" is how layers are able to know when to skip / ignore certain timesteps in\nsequence inputs.\n- Some layers are mask-generators: `Embedding` can generate a mask from input values\n(if `mask_zero=True`), and so can the `Masking` layer.\n- Some layers are mask-consumers: they expose a `mask` argument in their `__call__`\nmethod. This is the case for RNN layers.\n- In the Functional API and Sequential API, mask information is propagated\nautomatically.\n- When using layers in a standalone way, you can pass the `mask` arguments to layers\nmanually.\n- You can easily write layers that modify the current mask, that generate a new mask,\nor that consume the mask associated with the inputs.\n\"\"\"\n"
  },
  {
    "path": "guides/writing_a_custom_training_loop_in_jax.py",
    "content": "\"\"\"\nTitle: Writing a training loop from scratch in JAX\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2023/06/25\nLast modified: 2023/06/25\nDescription: Writing low-level training & evaluation loops in JAX.\nAccelerator: None\n\"\"\"\n\n\"\"\"\n## Setup\n\"\"\"\n\nimport os\n\n# This guide can only be run with the jax backend.\nos.environ[\"KERAS_BACKEND\"] = \"jax\"\n\nimport jax\n\n# We import TF so we can use tf.data.\nimport tensorflow as tf\nimport keras\nimport numpy as np\n\n\"\"\"\n## Introduction\n\nKeras provides default training and evaluation loops, `fit()` and `evaluate()`.\nTheir usage is covered in the guide\n[Training & evaluation with the built-in methods](https://keras.io/guides/training_with_built_in_methods/).\n\nIf you want to customize the learning algorithm of your model while still leveraging\nthe convenience of `fit()`\n(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and\nimplement your own `train_step()` method, which\nis called repeatedly during `fit()`.\n\nNow, if you want very low-level control over training & evaluation, you should write\nyour own training & evaluation loops from scratch. This is what this guide is about.\n\"\"\"\n\n\"\"\"\n## A first end-to-end example\n\nTo write a custom training loop, we need the following ingredients:\n\n- A model to train, of course.\n- An optimizer. You could either use an optimizer from `keras.optimizers`, or\none from the `optax` package.\n- A loss function.\n- A dataset. 
The standard in the JAX ecosystem is to load data via `tf.data`,\nso that's what we'll use.\n\nLet's line them up.\n\nFirst, let's get the model and the MNIST dataset:\n\"\"\"\n\n\ndef get_model():\n    inputs = keras.Input(shape=(784,), name=\"digits\")\n    x1 = keras.layers.Dense(64, activation=\"relu\")(inputs)\n    x2 = keras.layers.Dense(64, activation=\"relu\")(x1)\n    outputs = keras.layers.Dense(10, name=\"predictions\")(x2)\n    model = keras.Model(inputs=inputs, outputs=outputs)\n    return model\n\n\nmodel = get_model()\n\n# Prepare the training dataset.\nbatch_size = 32\n(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\nx_train = np.reshape(x_train, (-1, 784)).astype(\"float32\")\nx_test = np.reshape(x_test, (-1, 784)).astype(\"float32\")\ny_train = keras.utils.to_categorical(y_train)\ny_test = keras.utils.to_categorical(y_test)\n\n# Reserve 10,000 samples for validation.\nx_val = x_train[-10000:]\ny_val = y_train[-10000:]\nx_train = x_train[:-10000]\ny_train = y_train[:-10000]\n\n# Prepare the training dataset.\ntrain_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\ntrain_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)\n\n# Prepare the validation dataset.\nval_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))\nval_dataset = val_dataset.batch(batch_size)\n\n\"\"\"\nNext, here's the loss function and the optimizer.\nWe'll use a Keras optimizer in this case.\n\"\"\"\n\n# Instantiate a loss function.\nloss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)\n\n# Instantiate an optimizer.\noptimizer = keras.optimizers.Adam(learning_rate=1e-3)\n\n\"\"\"\n### Getting gradients in JAX\n\nLet's train our model using mini-batch gradient descent with a custom training loop.\n\nIn JAX, gradients are computed via *metaprogramming*: you call `jax.grad` (or\n`jax.value_and_grad`) on a function in order to create a gradient-computing function\nfor that first function.\n\nSo the first thing we 
need is a function that returns the loss value.\nThat's the function we'll use to generate the gradient function. Something like this:\n\n```python\ndef compute_loss(x, y):\n    ...\n    return loss\n```\n\nOnce you have such a function, you can compute gradients via metaprogramming as follows:\n\n```python\ngrad_fn = jax.grad(compute_loss)\ngrads = grad_fn(x, y)\n```\n\nTypically, you don't just want to get the gradient values, you also want to get\nthe loss value. You can do this by using `jax.value_and_grad` instead of `jax.grad`:\n\n```python\ngrad_fn = jax.value_and_grad(compute_loss)\nloss, grads = grad_fn(x, y)\n```\n\n### JAX computation is purely stateless\n\nIn JAX, everything must be a stateless function -- so our loss computation function\nmust be stateless as well. That means that all Keras variables (e.g. weight tensors)\nmust be passed as function inputs, and any variable that has been updated during the\nforward pass must be returned as a function output. The function must have no side\neffects.\n\nDuring the forward pass, the non-trainable variables of a Keras model might get\nupdated. These variables could be, for instance, RNG seed state variables or\nBatchNormalization statistics. We're going to need to return those. So we need\nsomething like this:\n\n```python\ndef compute_loss_and_updates(trainable_variables, non_trainable_variables, x, y):\n    ...\n    return loss, non_trainable_variables\n```\n\nOnce you have such a function, you can get the gradient function by\nspecifying `has_aux` in `value_and_grad`: it tells JAX that the loss\ncomputation function returns more outputs than just the loss. 
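As a minimal standalone illustration of the `has_aux` mechanism (a toy function and toy data, not the guide's actual model):

```python
import jax
import jax.numpy as jnp


# Returns a (loss, aux) pair; with `has_aux=True`, JAX differentiates
# only the first output and passes the auxiliary output through.
def toy_loss_and_aux(w, x, y):
    y_pred = w * x
    loss = jnp.mean((y_pred - y) ** 2)
    return loss, {"y_pred": y_pred}


grad_fn = jax.value_and_grad(toy_loss_and_aux, has_aux=True)
(loss, aux), grad = grad_fn(
    2.0, jnp.array([1.0, 2.0]), jnp.array([2.0, 4.0])
)
print(float(loss), float(grad))  # 0.0 0.0 -- w=2.0 already fits the data
```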
Note that the loss\nshould always be the first output.\n\n```python\ngrad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)\n(loss, non_trainable_variables), grads = grad_fn(\n    trainable_variables, non_trainable_variables, x, y\n)\n```\n\nNow that we have established the basics,\nlet's implement this `compute_loss_and_updates` function.\nKeras models have a `stateless_call` method which will come in handy here.\nIt works just like `model.__call__`, but it requires you to explicitly\npass the value of all the variables in the model, and it returns not just\nthe `__call__` outputs but also the (potentially updated) non-trainable\nvariables.\n\"\"\"\n\n\ndef compute_loss_and_updates(\n    trainable_variables, non_trainable_variables, x, y\n):\n    y_pred, non_trainable_variables = model.stateless_call(\n        trainable_variables, non_trainable_variables, x\n    )\n    loss = loss_fn(y, y_pred)\n    return loss, non_trainable_variables\n\n\n\"\"\"\nLet's get the gradient function:\n\"\"\"\n\ngrad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)\n\n\"\"\"\n### The training step function\n\nNext, let's implement the end-to-end training step: the function\nthat runs the forward pass, computes the loss, computes the gradients,\nand uses the optimizer to update the trainable variables. This function\nalso needs to be stateless, so it will get as input a `state` tuple that\nincludes every state element we're going to use:\n\n- `trainable_variables` and `non_trainable_variables`: the model's variables.\n- `optimizer_variables`: the optimizer's state variables,\nsuch as momentum accumulators.\n\nTo update the trainable variables, we use the optimizer's stateless method\n`stateless_apply`. It's equivalent to `optimizer.apply()`, but it requires\nalways passing `trainable_variables` and `optimizer_variables`. 
It returns\nboth the updated trainable variables and the updated optimizer variables.\n\"\"\"\n\n\ndef train_step(state, data):\n    trainable_variables, non_trainable_variables, optimizer_variables = state\n    x, y = data\n    (loss, non_trainable_variables), grads = grad_fn(\n        trainable_variables, non_trainable_variables, x, y\n    )\n    trainable_variables, optimizer_variables = optimizer.stateless_apply(\n        optimizer_variables, grads, trainable_variables\n    )\n    # Return updated state\n    return loss, (\n        trainable_variables,\n        non_trainable_variables,\n        optimizer_variables,\n    )\n\n\n\"\"\"\n### Make it fast with `jax.jit`\n\nBy default, JAX operations run eagerly,\njust like in TensorFlow eager mode and PyTorch eager mode.\nAnd just like TensorFlow eager mode and PyTorch eager mode, it's pretty slow\n-- eager mode is better used as a debugging environment, not as a way to do\nany actual work. So let's make our `train_step` fast by compiling it.\n\nWhen you have a stateless JAX function, you can compile it to XLA via the\n`@jax.jit` decorator. It will get traced during its first execution, and in\nsubsequent executions you will be executing the traced graph (this is just\nlike `@tf.function(jit_compile=True)`). Let's try it:\n\"\"\"\n\n\n@jax.jit\ndef train_step(state, data):\n    trainable_variables, non_trainable_variables, optimizer_variables = state\n    x, y = data\n    (loss, non_trainable_variables), grads = grad_fn(\n        trainable_variables, non_trainable_variables, x, y\n    )\n    trainable_variables, optimizer_variables = optimizer.stateless_apply(\n        optimizer_variables, grads, trainable_variables\n    )\n    # Return updated state\n    return loss, (\n        trainable_variables,\n        non_trainable_variables,\n        optimizer_variables,\n    )\n\n\n\"\"\"\nWe're now ready to train our model. 
The training loop itself\nis trivial: we just repeatedly call `loss, state = train_step(state, data)`.\n\nNote:\n\n- We convert the TF tensors yielded by the `tf.data.Dataset` to NumPy\nbefore passing them to our JAX function.\n- All variables must be built beforehand:\nthe model must be built and the optimizer must be built. Since we're using a\nFunctional API model, it's already built, but if it were a subclassed model\nyou'd need to call it on a batch of data to build it.\n\"\"\"\n\n# Build optimizer variables.\noptimizer.build(model.trainable_variables)\n\ntrainable_variables = model.trainable_variables\nnon_trainable_variables = model.non_trainable_variables\noptimizer_variables = optimizer.variables\nstate = trainable_variables, non_trainable_variables, optimizer_variables\n\n# Training loop\nfor step, data in enumerate(train_dataset):\n    data = (data[0].numpy(), data[1].numpy())\n    loss, state = train_step(state, data)\n    # Log every 100 batches.\n    if step % 100 == 0:\n        print(f\"Training loss (for 1 batch) at step {step}: {float(loss):.4f}\")\n        print(f\"Seen so far: {(step + 1) * batch_size} samples\")\n\n\"\"\"\nA key thing to notice here is that the loop is entirely stateless -- the variables\nattached to the model (`model.weights`) are never getting updated during the loop.\nTheir new values are only stored in the `state` tuple. 
That means that at some point,\nbefore saving the model, you should attach the new variable values back to the model.\n\nJust call `variable.assign(new_value)` on each model variable you want to update:\n\"\"\"\n\ntrainable_variables, non_trainable_variables, optimizer_variables = state\nfor variable, value in zip(model.trainable_variables, trainable_variables):\n    variable.assign(value)\nfor variable, value in zip(\n    model.non_trainable_variables, non_trainable_variables\n):\n    variable.assign(value)\n\n\"\"\"\n## Low-level handling of metrics\n\nLet's add metrics monitoring to this basic training loop.\n\nYou can readily reuse built-in Keras metrics (or custom ones you wrote) in such training\nloops written from scratch. Here's the flow:\n\n- Instantiate the metric at the start of the loop\n- Include `metric_variables` in the `train_step` arguments\nand `compute_loss_and_updates` arguments.\n- Call `metric.stateless_update_state()` in the `compute_loss_and_updates` function.\nIt's equivalent to `update_state()` -- only stateless.\n- When you need to display the current value of the metric, outside the `train_step`\n(in the eager scope), attach the new metric variable values to the metric object\nand call `metric.result()`.\n- Call `metric.reset_state()` when you need to clear the state of the metric\n(typically at the end of an epoch)\n\nLet's use this knowledge to compute `CategoricalAccuracy` on training and\nvalidation data at the end of training:\n\"\"\"\n\n# Get a fresh model\nmodel = get_model()\n\n# Instantiate an optimizer to train the model.\noptimizer = keras.optimizers.Adam(learning_rate=1e-3)\n# Instantiate a loss function.\nloss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)\n\n# Prepare the metrics.\ntrain_acc_metric = keras.metrics.CategoricalAccuracy()\nval_acc_metric = keras.metrics.CategoricalAccuracy()\n\n\ndef compute_loss_and_updates(\n    trainable_variables, non_trainable_variables, metric_variables, x, y\n):\n    
y_pred, non_trainable_variables = model.stateless_call(\n        trainable_variables, non_trainable_variables, x\n    )\n    loss = loss_fn(y, y_pred)\n    metric_variables = train_acc_metric.stateless_update_state(\n        metric_variables, y, y_pred\n    )\n    return loss, (non_trainable_variables, metric_variables)\n\n\ngrad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)\n\n\n@jax.jit\ndef train_step(state, data):\n    (\n        trainable_variables,\n        non_trainable_variables,\n        optimizer_variables,\n        metric_variables,\n    ) = state\n    x, y = data\n    (loss, (non_trainable_variables, metric_variables)), grads = grad_fn(\n        trainable_variables, non_trainable_variables, metric_variables, x, y\n    )\n    trainable_variables, optimizer_variables = optimizer.stateless_apply(\n        optimizer_variables, grads, trainable_variables\n    )\n    # Return updated state\n    return loss, (\n        trainable_variables,\n        non_trainable_variables,\n        optimizer_variables,\n        metric_variables,\n    )\n\n\n\"\"\"\nWe'll also prepare an evaluation step function:\n\"\"\"\n\n\n@jax.jit\ndef eval_step(state, data):\n    trainable_variables, non_trainable_variables, metric_variables = state\n    x, y = data\n    y_pred, non_trainable_variables = model.stateless_call(\n        trainable_variables, non_trainable_variables, x\n    )\n    loss = loss_fn(y, y_pred)\n    metric_variables = val_acc_metric.stateless_update_state(\n        metric_variables, y, y_pred\n    )\n    return loss, (\n        trainable_variables,\n        non_trainable_variables,\n        metric_variables,\n    )\n\n\n\"\"\"\nHere are our loops:\n\"\"\"\n\n# Build optimizer variables.\noptimizer.build(model.trainable_variables)\n\ntrainable_variables = model.trainable_variables\nnon_trainable_variables = model.non_trainable_variables\noptimizer_variables = optimizer.variables\nmetric_variables = train_acc_metric.variables\nstate = (\n    
trainable_variables,\n    non_trainable_variables,\n    optimizer_variables,\n    metric_variables,\n)\n\n# Training loop\nfor step, data in enumerate(train_dataset):\n    data = (data[0].numpy(), data[1].numpy())\n    loss, state = train_step(state, data)\n    # Log every 100 batches.\n    if step % 100 == 0:\n        print(f\"Training loss (for 1 batch) at step {step}: {float(loss):.4f}\")\n        _, _, _, metric_variables = state\n        for variable, value in zip(\n            train_acc_metric.variables, metric_variables\n        ):\n            variable.assign(value)\n        print(f\"Training accuracy: {train_acc_metric.result()}\")\n        print(f\"Seen so far: {(step + 1) * batch_size} samples\")\n\nmetric_variables = val_acc_metric.variables\n(\n    trainable_variables,\n    non_trainable_variables,\n    optimizer_variables,\n    metric_variables,\n) = state\nstate = trainable_variables, non_trainable_variables, metric_variables\n\n# Eval loop\nfor step, data in enumerate(val_dataset):\n    data = (data[0].numpy(), data[1].numpy())\n    loss, state = eval_step(state, data)\n    # Log every 100 batches.\n    if step % 100 == 0:\n        print(\n            f\"Validation loss (for 1 batch) at step {step}: {float(loss):.4f}\"\n        )\n        _, _, metric_variables = state\n        for variable, value in zip(val_acc_metric.variables, metric_variables):\n            variable.assign(value)\n        print(f\"Validation accuracy: {val_acc_metric.result()}\")\n        print(f\"Seen so far: {(step + 1) * batch_size} samples\")\n\n\"\"\"\n## Low-level handling of losses tracked by the model\n\nLayers & models recursively track any losses created during the forward pass\nby layers that call `self.add_loss(value)`. 
The resulting list of scalar loss\nvalues is available via the property `model.losses`\nat the end of the forward pass.\n\nIf you want to use these loss components, you should sum them\nand add them to the main loss in your training step.\n\nConsider this layer, which creates an activity regularization loss:\n\"\"\"\n\n\nclass ActivityRegularizationLayer(keras.layers.Layer):\n    def call(self, inputs):\n        self.add_loss(1e-2 * jax.numpy.sum(inputs))\n        return inputs\n\n\n\"\"\"\nLet's build a really simple model that uses it:\n\"\"\"\n\ninputs = keras.Input(shape=(784,), name=\"digits\")\nx = keras.layers.Dense(64, activation=\"relu\")(inputs)\n# Insert activity regularization as a layer\nx = ActivityRegularizationLayer()(x)\nx = keras.layers.Dense(64, activation=\"relu\")(x)\noutputs = keras.layers.Dense(10, name=\"predictions\")(x)\n\nmodel = keras.Model(inputs=inputs, outputs=outputs)\n\n\"\"\"\nHere's what our `compute_loss_and_updates` function should look like now:\n\n- Pass `return_losses=True` to `model.stateless_call()`.\n- Sum the resulting `losses` and add them to the main loss.\n\"\"\"\n\n\ndef compute_loss_and_updates(\n    trainable_variables, non_trainable_variables, metric_variables, x, y\n):\n    y_pred, non_trainable_variables, losses = model.stateless_call(\n        trainable_variables, non_trainable_variables, x, return_losses=True\n    )\n    loss = loss_fn(y, y_pred)\n    if losses:\n        loss += jax.numpy.sum(losses)\n    metric_variables = train_acc_metric.stateless_update_state(\n        metric_variables, y, y_pred\n    )\n    # Keep the (loss, aux) structure expected by `has_aux=True`.\n    return loss, (non_trainable_variables, metric_variables)\n\n\n\"\"\"\nThat's it!\n\"\"\"\n"
  },
  {
    "path": "guides/writing_a_custom_training_loop_in_tensorflow.py",
    "content": "\"\"\"\nTitle: Writing a training loop from scratch in TensorFlow\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2019/03/01\nLast modified: 2023/06/25\nDescription: Writing low-level training & evaluation loops in TensorFlow.\nAccelerator: None\n\"\"\"\n\n\"\"\"\n## Setup\n\"\"\"\n\nimport time\nimport os\n\n# This guide can only be run with the TensorFlow backend.\nos.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n\nimport tensorflow as tf\nimport keras\nimport numpy as np\n\n\"\"\"\n## Introduction\n\nKeras provides default training and evaluation loops, `fit()` and `evaluate()`.\nTheir usage is covered in the guide\n[Training & evaluation with the built-in methods](https://keras.io/guides/training_with_built_in_methods/).\n\nIf you want to customize the learning algorithm of your model while still leveraging\nthe convenience of `fit()`\n(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and\nimplement your own `train_step()` method, which\nis called repeatedly during `fit()`.\n\nNow, if you want very low-level control over training & evaluation, you should write\nyour own training & evaluation loops from scratch. 
This is what this guide is about.\n\"\"\"\n\n\"\"\"\n## A first end-to-end example\n\nLet's consider a simple MNIST model:\n\"\"\"\n\n\ndef get_model():\n    inputs = keras.Input(shape=(784,), name=\"digits\")\n    x1 = keras.layers.Dense(64, activation=\"relu\")(inputs)\n    x2 = keras.layers.Dense(64, activation=\"relu\")(x1)\n    outputs = keras.layers.Dense(10, name=\"predictions\")(x2)\n    model = keras.Model(inputs=inputs, outputs=outputs)\n    return model\n\n\nmodel = get_model()\n\n\"\"\"\nLet's train it using mini-batch gradient with a custom training loop.\n\nFirst, we're going to need an optimizer, a loss function, and a dataset:\n\"\"\"\n\n# Instantiate an optimizer.\noptimizer = keras.optimizers.Adam(learning_rate=1e-3)\n# Instantiate a loss function.\nloss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n\n# Prepare the training dataset.\nbatch_size = 32\n(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\nx_train = np.reshape(x_train, (-1, 784))\nx_test = np.reshape(x_test, (-1, 784))\n\n# Reserve 10,000 samples for validation.\nx_val = x_train[-10000:]\ny_val = y_train[-10000:]\nx_train = x_train[:-10000]\ny_train = y_train[:-10000]\n\n# Prepare the training dataset.\ntrain_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\ntrain_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)\n\n# Prepare the validation dataset.\nval_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))\nval_dataset = val_dataset.batch(batch_size)\n\n\"\"\"\nCalling a model inside a `GradientTape` scope enables you to retrieve the gradients of\nthe trainable weights of the layer with respect to a loss value. 
Using an optimizer\ninstance, you can use these gradients to update these variables (which you can\nretrieve using `model.trainable_weights`).\n\nHere's our training loop, step by step:\n\n- We open a `for` loop that iterates over epochs\n- For each epoch, we open a `for` loop that iterates over the dataset, in batches\n- For each batch, we open a `GradientTape()` scope\n- Inside this scope, we call the model (forward pass) and compute the loss\n- Outside the scope, we retrieve the gradients of the weights\nof the model with regard to the loss\n- Finally, we use the optimizer to update the weights of the model based on the\ngradients\n\"\"\"\n\nepochs = 3\nfor epoch in range(epochs):\n    print(f\"\\nStart of epoch {epoch}\")\n\n    # Iterate over the batches of the dataset.\n    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n        # Open a GradientTape to record the operations run\n        # during the forward pass, which enables auto-differentiation.\n        with tf.GradientTape() as tape:\n            # Run the forward pass of the layer.\n            # The operations that the layer applies\n            # to its inputs are going to be recorded\n            # on the GradientTape.\n            logits = model(\n                x_batch_train, training=True\n            )  # Logits for this minibatch\n\n            # Compute the loss value for this minibatch.\n            loss_value = loss_fn(y_batch_train, logits)\n\n        # Use the gradient tape to automatically retrieve\n        # the gradients of the trainable variables with respect to the loss.\n        grads = tape.gradient(loss_value, model.trainable_weights)\n\n        # Run one step of gradient descent by updating\n        # the value of the variables to minimize the loss.\n        optimizer.apply(grads, model.trainable_weights)\n\n        # Log every 100 batches.\n        if step % 100 == 0:\n            print(\n                f\"Training loss (for 1 batch) at step {step}: 
{float(loss_value):.4f}\"\n            )\n            print(f\"Seen so far: {(step + 1) * batch_size} samples\")\n\n\"\"\"\n## Low-level handling of metrics\n\nLet's add metrics monitoring to this basic loop.\n\nYou can readily reuse the built-in metrics (or custom ones you wrote) in such training\nloops written from scratch. Here's the flow:\n\n- Instantiate the metric at the start of the loop\n- Call `metric.update_state()` after each batch\n- Call `metric.result()` when you need to display the current value of the metric\n- Call `metric.reset_state()` when you need to clear the state of the metric\n(typically at the end of an epoch)\n\nLet's use this knowledge to compute `SparseCategoricalAccuracy` on training and\nvalidation data at the end of each epoch:\n\"\"\"\n\n# Get a fresh model\nmodel = get_model()\n\n# Instantiate an optimizer to train the model.\noptimizer = keras.optimizers.Adam(learning_rate=1e-3)\n# Instantiate a loss function.\nloss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n\n# Prepare the metrics.\ntrain_acc_metric = keras.metrics.SparseCategoricalAccuracy()\nval_acc_metric = keras.metrics.SparseCategoricalAccuracy()\n\n\"\"\"\nHere's our training & evaluation loop:\n\"\"\"\n\nepochs = 2\nfor epoch in range(epochs):\n    print(f\"\\nStart of epoch {epoch}\")\n    start_time = time.time()\n\n    # Iterate over the batches of the dataset.\n    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n        with tf.GradientTape() as tape:\n            logits = model(x_batch_train, training=True)\n            loss_value = loss_fn(y_batch_train, logits)\n        grads = tape.gradient(loss_value, model.trainable_weights)\n        optimizer.apply(grads, model.trainable_weights)\n\n        # Update training metric.\n        train_acc_metric.update_state(y_batch_train, logits)\n\n        # Log every 100 batches.\n        if step % 100 == 0:\n            print(\n                f\"Training loss (for 1 batch) at step 
{step}: {float(loss_value):.4f}\"\n            )\n            print(f\"Seen so far: {(step + 1) * batch_size} samples\")\n\n    # Display metrics at the end of each epoch.\n    train_acc = train_acc_metric.result()\n    print(f\"Training acc over epoch: {float(train_acc):.4f}\")\n\n    # Reset training metrics at the end of each epoch\n    train_acc_metric.reset_state()\n\n    # Run a validation loop at the end of each epoch.\n    for x_batch_val, y_batch_val in val_dataset:\n        val_logits = model(x_batch_val, training=False)\n        # Update val metrics\n        val_acc_metric.update_state(y_batch_val, val_logits)\n    val_acc = val_acc_metric.result()\n    val_acc_metric.reset_state()\n    print(f\"Validation acc: {float(val_acc):.4f}\")\n    print(f\"Time taken: {time.time() - start_time:.2f}s\")\n\n\"\"\"\n## Speeding-up your training step with `tf.function`\n\nThe default runtime in TensorFlow is eager execution.\nAs such, our training loop above executes eagerly.\n\nThis is great for debugging, but graph compilation has a definite performance\nadvantage. Describing your computation as a static graph enables the framework\nto apply global performance optimizations. 
This is impossible when\nthe framework is constrained to greedily execute one operation after another,\nwith no knowledge of what comes next.\n\nYou can compile into a static graph any function that takes tensors as input.\nJust add a `@tf.function` decorator on it, like this:\n\"\"\"\n\n\n@tf.function\ndef train_step(x, y):\n    with tf.GradientTape() as tape:\n        logits = model(x, training=True)\n        loss_value = loss_fn(y, logits)\n    grads = tape.gradient(loss_value, model.trainable_weights)\n    optimizer.apply(grads, model.trainable_weights)\n    train_acc_metric.update_state(y, logits)\n    return loss_value\n\n\n\"\"\"\nLet's do the same with the evaluation step:\n\"\"\"\n\n\n@tf.function\ndef test_step(x, y):\n    val_logits = model(x, training=False)\n    val_acc_metric.update_state(y, val_logits)\n\n\n\"\"\"\nNow, let's re-run our training loop with this compiled training step:\n\"\"\"\n\nepochs = 2\nfor epoch in range(epochs):\n    print(f\"\\nStart of epoch {epoch}\")\n    start_time = time.time()\n\n    # Iterate over the batches of the dataset.\n    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):\n        loss_value = train_step(x_batch_train, y_batch_train)\n\n        # Log every 100 batches.\n        if step % 100 == 0:\n            print(\n                f\"Training loss (for 1 batch) at step {step}: {float(loss_value):.4f}\"\n            )\n            print(f\"Seen so far: {(step + 1) * batch_size} samples\")\n\n    # Display metrics at the end of each epoch.\n    train_acc = train_acc_metric.result()\n    print(f\"Training acc over epoch: {float(train_acc):.4f}\")\n\n    # Reset training metrics at the end of each epoch\n    train_acc_metric.reset_state()\n\n    # Run a validation loop at the end of each epoch.\n    for x_batch_val, y_batch_val in val_dataset:\n        test_step(x_batch_val, y_batch_val)\n\n    val_acc = val_acc_metric.result()\n    val_acc_metric.reset_state()\n    print(f\"Validation acc: 
{float(val_acc):.4f}\")\n    print(f\"Time taken: {time.time() - start_time:.2f}s\")\n\n\"\"\"\nMuch faster, isn't it?\n\"\"\"\n\n\"\"\"\n## Low-level handling of losses tracked by the model\n\nLayers & models recursively track any losses created during the forward pass\nby layers that call `self.add_loss(value)`. The resulting list of scalar loss\nvalues is available via the property `model.losses`\nat the end of the forward pass.\n\nIf you want to use these loss components, you should sum them\nand add them to the main loss in your training step.\n\nConsider this layer, which creates an activity regularization loss:\n\n\"\"\"\n\n\nclass ActivityRegularizationLayer(keras.layers.Layer):\n    def call(self, inputs):\n        self.add_loss(1e-2 * tf.reduce_sum(inputs))\n        return inputs\n\n\n\"\"\"\nLet's build a really simple model that uses it:\n\"\"\"\n\ninputs = keras.Input(shape=(784,), name=\"digits\")\nx = keras.layers.Dense(64, activation=\"relu\")(inputs)\n# Insert activity regularization as a layer\nx = ActivityRegularizationLayer()(x)\nx = keras.layers.Dense(64, activation=\"relu\")(x)\noutputs = keras.layers.Dense(10, name=\"predictions\")(x)\n\nmodel = keras.Model(inputs=inputs, outputs=outputs)\n\n\"\"\"\nHere's what our training step should look like now:\n\"\"\"\n\n\n@tf.function\ndef train_step(x, y):\n    with tf.GradientTape() as tape:\n        logits = model(x, training=True)\n        loss_value = loss_fn(y, logits)\n        # Add any extra losses created during the forward pass.\n        loss_value += sum(model.losses)\n    grads = tape.gradient(loss_value, model.trainable_weights)\n    optimizer.apply(grads, model.trainable_weights)\n    train_acc_metric.update_state(y, logits)\n    return loss_value\n\n\n\"\"\"\n## Summary\n\nNow you know everything there is to know about using built-in training loops and\nwriting your own from scratch.\n\nTo conclude, here's a simple end-to-end example that ties together everything\nyou've learned in 
this guide: a DCGAN trained on MNIST digits.\n\"\"\"\n\n\"\"\"\n## End-to-end example: a GAN training loop from scratch\n\nYou may be familiar with Generative Adversarial Networks (GANs). GANs can generate new\nimages that look almost real, by learning the latent distribution of a training\ndataset of images (the \"latent space\" of the images).\n\nA GAN is made of two parts: a \"generator\" model that maps points in the latent\nspace to points in image space, and a \"discriminator\" model, a classifier\nthat can tell the difference between real images (from the training dataset)\nand fake images (the output of the generator network).\n\nA GAN training loop looks like this:\n\n1) Train the discriminator.\n- Sample a batch of random points in the latent space.\n- Turn the points into fake images via the \"generator\" model.\n- Get a batch of real images and combine them with the generated images.\n- Train the \"discriminator\" model to classify generated vs. real images.\n\n2) Train the generator.\n- Sample random points in the latent space.\n- Turn the points into fake images via the \"generator\" network.\n- Train the \"generator\" model to \"fool\" the discriminator, i.e. to make it classify\nthe fake images as real.\n\nFor a much more detailed overview of how GANs work, see\n[Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python).\n\nLet's implement this training loop. 
First, create the discriminator meant to classify\nfake vs real digits:\n\"\"\"\n\ndiscriminator = keras.Sequential(\n    [\n        keras.Input(shape=(28, 28, 1)),\n        keras.layers.Conv2D(64, (3, 3), strides=(2, 2), padding=\"same\"),\n        keras.layers.LeakyReLU(negative_slope=0.2),\n        keras.layers.Conv2D(128, (3, 3), strides=(2, 2), padding=\"same\"),\n        keras.layers.LeakyReLU(negative_slope=0.2),\n        keras.layers.GlobalMaxPooling2D(),\n        keras.layers.Dense(1),\n    ],\n    name=\"discriminator\",\n)\ndiscriminator.summary()\n\n\"\"\"\nThen let's create a generator network\nthat turns latent vectors into outputs of shape `(28, 28, 1)` (representing\nMNIST digits):\n\"\"\"\n\nlatent_dim = 128\n\ngenerator = keras.Sequential(\n    [\n        keras.Input(shape=(latent_dim,)),\n        # We want to generate 128 coefficients to reshape into a 7x7x128 map\n        keras.layers.Dense(7 * 7 * 128),\n        keras.layers.LeakyReLU(negative_slope=0.2),\n        keras.layers.Reshape((7, 7, 128)),\n        keras.layers.Conv2DTranspose(\n            128, (4, 4), strides=(2, 2), padding=\"same\"\n        ),\n        keras.layers.LeakyReLU(negative_slope=0.2),\n        keras.layers.Conv2DTranspose(\n            128, (4, 4), strides=(2, 2), padding=\"same\"\n        ),\n        keras.layers.LeakyReLU(negative_slope=0.2),\n        keras.layers.Conv2D(1, (7, 7), padding=\"same\", activation=\"sigmoid\"),\n    ],\n    name=\"generator\",\n)\n\n\"\"\"\nHere's the key bit: the training loop. As you can see, it is quite straightforward. 
The\ntraining step function only takes 17 lines.\n\"\"\"\n\n# Instantiate one optimizer for the discriminator and another for the generator.\nd_optimizer = keras.optimizers.Adam(learning_rate=0.0003)\ng_optimizer = keras.optimizers.Adam(learning_rate=0.0004)\n\n# Instantiate a loss function.\nloss_fn = keras.losses.BinaryCrossentropy(from_logits=True)\n\n\n@tf.function\ndef train_step(real_images):\n    # Sample random points in the latent space\n    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))\n    # Decode them to fake images\n    generated_images = generator(random_latent_vectors)\n    # Combine them with real images\n    combined_images = tf.concat([generated_images, real_images], axis=0)\n\n    # Assemble labels discriminating real from fake images\n    labels = tf.concat(\n        [tf.ones((batch_size, 1)), tf.zeros((real_images.shape[0], 1))], axis=0\n    )\n    # Add random noise to the labels - important trick!\n    labels += 0.05 * tf.random.uniform(labels.shape)\n\n    # Train the discriminator\n    with tf.GradientTape() as tape:\n        predictions = discriminator(combined_images)\n        d_loss = loss_fn(labels, predictions)\n    grads = tape.gradient(d_loss, discriminator.trainable_weights)\n    d_optimizer.apply(grads, discriminator.trainable_weights)\n\n    # Sample random points in the latent space\n    random_latent_vectors = tf.random.normal(shape=(batch_size, latent_dim))\n    # Assemble labels that say \"all real images\"\n    misleading_labels = tf.zeros((batch_size, 1))\n\n    # Train the generator (note that we should *not* update the weights\n    # of the discriminator)!\n    with tf.GradientTape() as tape:\n        predictions = discriminator(generator(random_latent_vectors))\n        g_loss = loss_fn(misleading_labels, predictions)\n    grads = tape.gradient(g_loss, generator.trainable_weights)\n    g_optimizer.apply(grads, generator.trainable_weights)\n    return d_loss, g_loss, 
generated_images\n\n\n\"\"\"\nLet's train our GAN, by repeatedly calling `train_step` on batches of images.\n\nSince our discriminator and generator are convnets, you're going to want to\nrun this code on a GPU.\n\"\"\"\n\n# Prepare the dataset. We use both the training & test MNIST digits.\nbatch_size = 64\n(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()\nall_digits = np.concatenate([x_train, x_test])\nall_digits = all_digits.astype(\"float32\") / 255.0\nall_digits = np.reshape(all_digits, (-1, 28, 28, 1))\ndataset = tf.data.Dataset.from_tensor_slices(all_digits)\ndataset = dataset.shuffle(buffer_size=1024).batch(batch_size)\n\nepochs = 1  # In practice you need at least 20 epochs to generate nice digits.\nsave_dir = \"./\"\n\nfor epoch in range(epochs):\n    print(f\"\\nStart epoch {epoch}\")\n\n    for step, real_images in enumerate(dataset):\n        # Train the discriminator & generator on one batch of real images.\n        d_loss, g_loss, generated_images = train_step(real_images)\n\n        # Logging.\n        if step % 100 == 0:\n            # Print metrics\n            print(f\"discriminator loss at step {step}: {d_loss:.2f}\")\n            print(f\"adversarial loss at step {step}: {g_loss:.2f}\")\n\n            # Save one generated image\n            img = keras.utils.array_to_img(\n                generated_images[0] * 255.0, scale=False\n            )\n            img.save(os.path.join(save_dir, f\"generated_img_{step}.png\"))\n\n        # To limit execution time we stop after 10 steps.\n        # Remove the lines below to actually train the model!\n        if step > 10:\n            break\n\n\"\"\"\nThat's it! You'll get nice-looking fake MNIST digits after just ~30s of training on the\nColab GPU.\n\"\"\"\n"
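The metric bookkeeping used throughout this guide follows a simple contract: `update_state()` after each batch, `result()` to read the current value, `reset_state()` at the end of an epoch. As a framework-free sketch of that contract, here is a hypothetical `StreamingAccuracy` class in plain Python; it is an illustration of the pattern, not a Keras class.

```python
# A minimal sketch of the stateful-metric contract used in the loops above:
# update_state() accumulates per-batch statistics, result() reads the current
# value, and reset_state() clears everything at the end of an epoch.
# StreamingAccuracy is a hypothetical illustration, not part of Keras.


class StreamingAccuracy:
    def __init__(self):
        self.correct = 0
        self.total = 0

    def update_state(self, y_true, y_pred_labels):
        # Accumulate counts over one batch.
        for t, p in zip(y_true, y_pred_labels):
            self.correct += int(t == p)
            self.total += 1

    def result(self):
        # Current value of the metric; safe to call at any time.
        return self.correct / self.total if self.total else 0.0

    def reset_state(self):
        # Typically called at the end of an epoch.
        self.correct = 0
        self.total = 0


metric = StreamingAccuracy()
metric.update_state([1, 0, 2], [1, 0, 0])  # batch 1: 2/3 correct
metric.update_state([1, 1], [1, 1])        # batch 2: 2/2 correct
print(metric.result())  # 0.8
metric.reset_state()
```

Because the state accumulates across batches, `result()` reflects the whole epoch so far rather than just the last batch, which is exactly why the loops above only reset the metric once per epoch.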
  },
  {
    "path": "guides/writing_a_custom_training_loop_in_torch.py",
    "content": "\"\"\"\nTitle: Writing a training loop from scratch in PyTorch\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2023/06/25\nLast modified: 2023/06/25\nDescription: Writing low-level training & evaluation loops in PyTorch.\nAccelerator: None\n\"\"\"\n\n\"\"\"\n## Setup\n\"\"\"\n\nimport os\n\n# This guide can only be run with the torch backend.\nos.environ[\"KERAS_BACKEND\"] = \"torch\"\n\nimport torch\nimport keras\nimport numpy as np\n\n\"\"\"\n## Introduction\n\nKeras provides default training and evaluation loops, `fit()` and `evaluate()`.\nTheir usage is covered in the guide\n[Training & evaluation with the built-in methods](https://keras.io/guides/training_with_built_in_methods/).\n\nIf you want to customize the learning algorithm of your model while still leveraging\nthe convenience of `fit()`\n(for instance, to train a GAN using `fit()`), you can subclass the `Model` class and\nimplement your own `train_step()` method, which\nis called repeatedly during `fit()`.\n\nNow, if you want very low-level control over training & evaluation, you should write\nyour own training & evaluation loops from scratch. This is what this guide is about.\n\"\"\"\n\n\"\"\"\n## A first end-to-end example\n\nTo write a custom training loop, we need the following ingredients:\n\n- A model to train, of course.\n- An optimizer. You could either use a `keras.optimizers` optimizer,\nor a native PyTorch optimizer from `torch.optim`.\n- A loss function. You could either use a `keras.losses` loss,\nor a native PyTorch loss from `torch.nn`.\n- A dataset. You could use any format: a `tf.data.Dataset`,\na PyTorch `DataLoader`, a Python generator, etc.\n\nLet's line them up. 
We'll use torch-native objects in each case --\nexcept, of course, for the Keras model.\n\nFirst, let's get the model and the MNIST dataset:\n\"\"\"\n\n\n# Let's consider a simple MNIST model\ndef get_model():\n    inputs = keras.Input(shape=(784,), name=\"digits\")\n    x1 = keras.layers.Dense(64, activation=\"relu\")(inputs)\n    x2 = keras.layers.Dense(64, activation=\"relu\")(x1)\n    outputs = keras.layers.Dense(10, name=\"predictions\")(x2)\n    model = keras.Model(inputs=inputs, outputs=outputs)\n    return model\n\n\n# Load up the MNIST dataset and put it in a torch DataLoader\n# Prepare the training dataset.\nbatch_size = 32\n(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\nx_train = np.reshape(x_train, (-1, 784)).astype(\"float32\")\nx_test = np.reshape(x_test, (-1, 784)).astype(\"float32\")\ny_train = keras.utils.to_categorical(y_train)\ny_test = keras.utils.to_categorical(y_test)\n\n# Reserve 10,000 samples for validation.\nx_val = x_train[-10000:]\ny_val = y_train[-10000:]\nx_train = x_train[:-10000]\ny_train = y_train[:-10000]\n\n# Create torch Datasets\ntrain_dataset = torch.utils.data.TensorDataset(\n    torch.from_numpy(x_train), torch.from_numpy(y_train)\n)\nval_dataset = torch.utils.data.TensorDataset(\n    torch.from_numpy(x_val), torch.from_numpy(y_val)\n)\n\n# Create DataLoaders for the Datasets\ntrain_dataloader = torch.utils.data.DataLoader(\n    train_dataset, batch_size=batch_size, shuffle=True\n)\nval_dataloader = torch.utils.data.DataLoader(\n    val_dataset, batch_size=batch_size, shuffle=False\n)\n\n\"\"\"\nNext, here's our PyTorch optimizer and our PyTorch loss function:\n\"\"\"\n\n# Instantiate a torch optimizer\nmodel = get_model()\noptimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n\n# Instantiate a torch loss function\nloss_fn = torch.nn.CrossEntropyLoss()\n\n\"\"\"\nLet's train our model using mini-batch gradient descent with a custom training loop.\n\nCalling `loss.backward()` on a loss tensor 
triggers backpropagation.\nOnce that's done, your optimizer is magically aware of the gradients for each variable\nand can update its variables, which is done via `optimizer.step()`.\nTensors, variables, and optimizers are all interconnected to one another via hidden global state.\nAlso, don't forget to call `model.zero_grad()` before `loss.backward()`, or you won't\nget the right gradients for your variables.\n\nHere's our training loop, step by step:\n\n- We open a `for` loop that iterates over epochs\n- For each epoch, we open a `for` loop that iterates over the dataset, in batches\n- For each batch, we call the model on the input data to retrieve the predictions,\nthen we use them to compute a loss value\n- We call `loss.backward()` to compute the gradients of the weights of the model\nwith regard to the loss\n- Finally, we call `optimizer.step()` to update the weights of the model based on\nthe gradients\n\"\"\"\n\nepochs = 3\nfor epoch in range(epochs):\n    for step, (inputs, targets) in enumerate(train_dataloader):\n        # Forward pass\n        logits = model(inputs)\n        loss = loss_fn(logits, targets)\n\n        # Backward pass\n        model.zero_grad()\n        loss.backward()\n\n        # Optimizer variable updates\n        optimizer.step()\n\n        # Log every 100 batches.\n        if step % 100 == 0:\n            print(\n                f\"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}\"\n            )\n            print(f\"Seen so far: {(step + 1) * batch_size} samples\")\n\n\"\"\"\nAs an alternative, let's look at what the loop looks like when using a Keras optimizer\nand a Keras loss function.\n\nImportant differences:\n\n- You retrieve the gradients for the variables via `v.value.grad`,\ncalled on each trainable variable.\n- You update your variables via `optimizer.apply()`, which must be\ncalled in a `torch.no_grad()` scope.\n\n**Also, a big gotcha:** while all NumPy/TensorFlow/JAX/Keras 
APIs\nas well as Python `unittest` APIs use the argument order convention\n`fn(y_true, y_pred)` (reference values first, predicted values second),\nPyTorch actually uses `fn(y_pred, y_true)` for its losses.\nSo make sure to invert the order of `logits` and `targets`.\n\"\"\"\n\nmodel = get_model()\noptimizer = keras.optimizers.Adam(learning_rate=1e-3)\nloss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)\n\nfor epoch in range(epochs):\n    print(f\"\\nStart of epoch {epoch}\")\n    for step, (inputs, targets) in enumerate(train_dataloader):\n        # Forward pass\n        logits = model(inputs)\n        loss = loss_fn(targets, logits)\n\n        # Backward pass\n        model.zero_grad()\n        trainable_weights = [v for v in model.trainable_weights]\n\n        # Call torch.Tensor.backward() on the loss to compute gradients\n        # for the weights.\n        loss.backward()\n        gradients = [v.value.grad for v in trainable_weights]\n\n        # Update weights\n        with torch.no_grad():\n            optimizer.apply(gradients, trainable_weights)\n\n        # Log every 100 batches.\n        if step % 100 == 0:\n            print(\n                f\"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}\"\n            )\n            print(f\"Seen so far: {(step + 1) * batch_size} samples\")\n\n\"\"\"\n## Low-level handling of metrics\n\nLet's add metrics monitoring to this basic training loop.\n\nYou can readily reuse built-in Keras metrics (or custom ones you wrote) in such training\nloops written from scratch. 
Here's the flow:\n\n- Instantiate the metric at the start of the loop\n- Call `metric.update_state()` after each batch\n- Call `metric.result()` when you need to display the current value of the metric\n- Call `metric.reset_state()` when you need to clear the state of the metric\n(typically at the end of an epoch)\n\nLet's use this knowledge to compute `CategoricalAccuracy` on training and\nvalidation data at the end of each epoch:\n\"\"\"\n\n# Get a fresh model\nmodel = get_model()\n\n# Instantiate an optimizer to train the model.\noptimizer = keras.optimizers.Adam(learning_rate=1e-3)\n# Instantiate a loss function.\nloss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)\n\n# Prepare the metrics.\ntrain_acc_metric = keras.metrics.CategoricalAccuracy()\nval_acc_metric = keras.metrics.CategoricalAccuracy()\n\n\"\"\"\nHere's our training & evaluation loop:\n\"\"\"\n\nfor epoch in range(epochs):\n    print(f\"\\nStart of epoch {epoch}\")\n    for step, (inputs, targets) in enumerate(train_dataloader):\n        # Forward pass\n        logits = model(inputs)\n        loss = loss_fn(targets, logits)\n\n        # Backward pass\n        model.zero_grad()\n        trainable_weights = [v for v in model.trainable_weights]\n\n        # Call torch.Tensor.backward() on the loss to compute gradients\n        # for the weights.\n        loss.backward()\n        gradients = [v.value.grad for v in trainable_weights]\n\n        # Update weights\n        with torch.no_grad():\n            optimizer.apply(gradients, trainable_weights)\n\n        # Update training metric.\n        train_acc_metric.update_state(targets, logits)\n\n        # Log every 100 batches.\n        if step % 100 == 0:\n            print(\n                f\"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}\"\n            )\n            print(f\"Seen so far: {(step + 1) * batch_size} samples\")\n\n    # Display metrics at the end of each epoch.\n    train_acc = 
train_acc_metric.result()\n    print(f\"Training acc over epoch: {float(train_acc):.4f}\")\n\n    # Reset training metrics at the end of each epoch\n    train_acc_metric.reset_state()\n\n    # Run a validation loop at the end of each epoch.\n    for x_batch_val, y_batch_val in val_dataloader:\n        val_logits = model(x_batch_val, training=False)\n        # Update val metrics\n        val_acc_metric.update_state(y_batch_val, val_logits)\n    val_acc = val_acc_metric.result()\n    val_acc_metric.reset_state()\n    print(f\"Validation acc: {float(val_acc):.4f}\")\n\n\n\"\"\"\n## Low-level handling of losses tracked by the model\n\nLayers & models recursively track any losses created during the forward pass\nby layers that call `self.add_loss(value)`. The resulting list of scalar loss\nvalues is available via the property `model.losses`\nat the end of the forward pass.\n\nIf you want to use these loss components, you should sum them\nand add them to the main loss in your training step.\n\nConsider this layer, which creates an activity regularization loss:\n\"\"\"\n\n\nclass ActivityRegularizationLayer(keras.layers.Layer):\n    def call(self, inputs):\n        self.add_loss(1e-2 * torch.sum(inputs))\n        return inputs\n\n\n\"\"\"\nLet's build a really simple model that uses it:\n\"\"\"\n\ninputs = keras.Input(shape=(784,), name=\"digits\")\nx = keras.layers.Dense(64, activation=\"relu\")(inputs)\n# Insert activity regularization as a layer\nx = ActivityRegularizationLayer()(x)\nx = keras.layers.Dense(64, activation=\"relu\")(x)\noutputs = keras.layers.Dense(10, name=\"predictions\")(x)\n\nmodel = keras.Model(inputs=inputs, outputs=outputs)\n\n\"\"\"\nHere's what our training loop should look like now:\n\"\"\"\n\n# Get a fresh model\nmodel = get_model()\n\n# Instantiate an optimizer to train the model.\noptimizer = keras.optimizers.Adam(learning_rate=1e-3)\n# Instantiate a loss function.\nloss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)\n\n# 
Prepare the metrics.\ntrain_acc_metric = keras.metrics.CategoricalAccuracy()\nval_acc_metric = keras.metrics.CategoricalAccuracy()\n\nfor epoch in range(epochs):\n    print(f\"\\nStart of epoch {epoch}\")\n    for step, (inputs, targets) in enumerate(train_dataloader):\n        # Forward pass\n        logits = model(inputs)\n        loss = loss_fn(targets, logits)\n        if model.losses:\n            # Sum the extra losses tracked during the forward pass.\n            loss = loss + sum(model.losses)\n\n        # Backward pass\n        model.zero_grad()\n        trainable_weights = [v for v in model.trainable_weights]\n\n        # Call torch.Tensor.backward() on the loss to compute gradients\n        # for the weights.\n        loss.backward()\n        gradients = [v.value.grad for v in trainable_weights]\n\n        # Update weights\n        with torch.no_grad():\n            optimizer.apply(gradients, trainable_weights)\n\n        # Update training metric.\n        train_acc_metric.update_state(targets, logits)\n\n        # Log every 100 batches.\n        if step % 100 == 0:\n            print(\n                f\"Training loss (for 1 batch) at step {step}: {loss.detach().numpy():.4f}\"\n            )\n            print(f\"Seen so far: {(step + 1) * batch_size} samples\")\n\n    # Display metrics at the end of each epoch.\n    train_acc = train_acc_metric.result()\n    print(f\"Training acc over epoch: {float(train_acc):.4f}\")\n\n    # Reset training metrics at the end of each epoch\n    train_acc_metric.reset_state()\n\n    # Run a validation loop at the end of each epoch.\n    for x_batch_val, y_batch_val in val_dataloader:\n        val_logits = model(x_batch_val, training=False)\n        # Update val metrics\n        val_acc_metric.update_state(y_batch_val, val_logits)\n    val_acc = val_acc_metric.result()\n    val_acc_metric.reset_state()\n    print(f\"Validation acc: {float(val_acc):.4f}\")\n\n\"\"\"\nThat's it!\n\"\"\"\n"
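The three-phase shape of the loop above (forward pass, backward pass, parameter update) can be sketched without any framework at all. The NumPy snippet below runs mini-batch gradient descent on a linear-regression problem, with hand-derived gradients standing in for `loss.backward()`; every name in it is illustrative, not part of Keras or PyTorch.

```python
import numpy as np

# A framework-free sketch of the loop structure above, using linear regression
# with hand-derived gradients in place of autograd. The phases mirror the
# torch loop: forward pass, gradient computation, then a plain SGD update
# (the role optimizer.step() plays above).

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 3))
true_w = np.array([1.5, -2.0, 0.5])
y = x @ true_w + 0.01 * rng.normal(size=256)

w = np.zeros(3)  # the trainable weights
lr = 0.1
batch_size = 32

for epoch in range(20):
    for start in range(0, len(x), batch_size):
        xb = x[start : start + batch_size]
        yb = y[start : start + batch_size]
        # Forward pass: predictions and mean-squared-error loss.
        preds = xb @ w
        loss = np.mean((preds - yb) ** 2)
        # "Backward pass": gradient of the MSE loss w.r.t. w, derived by hand.
        grad_w = 2.0 * xb.T @ (preds - yb) / len(xb)
        # Optimizer update: plain SGD.
        w -= lr * grad_w

print(np.round(w, 2))  # close to [1.5, -2.0, 0.5]
```

The point of the sketch is the bookkeeping, not the model: autograd engines such as `loss.backward()` simply automate the `grad_w` line, and optimizers such as `optimizer.step()` generalize the final update.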
  },
  {
    "path": "guides/writing_your_own_callbacks.py",
    "content": "\"\"\"\nTitle: Writing your own callbacks\nAuthors: Rick Chao, Francois Chollet\nDate created: 2019/03/20\nLast modified: 2023/06/25\nDescription: Complete guide to writing new Keras callbacks.\nAccelerator: GPU\n\"\"\"\n\n\"\"\"\n## Introduction\n\nA callback is a powerful tool to customize the behavior of a Keras model during\ntraining, evaluation, or inference. Examples include `keras.callbacks.TensorBoard`\nto visualize training progress and results with TensorBoard, or\n`keras.callbacks.ModelCheckpoint` to periodically save your model during training.\n\nIn this guide, you will learn what a Keras callback is, what it can do, and how you can\nbuild your own. We provide a few demos of simple callback applications to get you\nstarted.\n\"\"\"\n\n\"\"\"\n## Setup\n\"\"\"\n\nimport numpy as np\nimport keras\n\n\"\"\"\n## Keras callbacks overview\n\nAll callbacks subclass the `keras.callbacks.Callback` class, and\noverride a set of methods called at various stages of training, testing, and\npredicting. Callbacks are useful to get a view on internal states and statistics of\nthe model during training.\n\nYou can pass a list of callbacks (as the keyword argument `callbacks`) to the following\nmodel methods:\n\n- `keras.Model.fit()`\n- `keras.Model.evaluate()`\n- `keras.Model.predict()`\n\"\"\"\n\n\"\"\"\n## An overview of callback methods\n\n### Global methods\n\n#### `on_(train|test|predict)_begin(self, logs=None)`\n\nCalled at the beginning of `fit`/`evaluate`/`predict`.\n\n#### `on_(train|test|predict)_end(self, logs=None)`\n\nCalled at the end of `fit`/`evaluate`/`predict`.\n\n### Batch-level methods for training/testing/predicting\n\n#### `on_(train|test|predict)_batch_begin(self, batch, logs=None)`\n\nCalled right before processing a batch during training/testing/predicting.\n\n#### `on_(train|test|predict)_batch_end(self, batch, logs=None)`\n\nCalled at the end of training/testing/predicting a batch. 
Within this method, `logs` is\na dict containing the metrics results.\n\n### Epoch-level methods (training only)\n\n#### `on_epoch_begin(self, epoch, logs=None)`\n\nCalled at the beginning of an epoch during training.\n\n#### `on_epoch_end(self, epoch, logs=None)`\n\nCalled at the end of an epoch during training.\n\"\"\"\n\n\"\"\"\n## A basic example\n\nLet's take a look at a concrete example. To get started, let's\ndefine a simple Sequential Keras model:\n\"\"\"\n\n\n# Define the Keras model to add callbacks to\ndef get_model():\n    model = keras.Sequential()\n    model.add(keras.layers.Dense(1))\n    model.compile(\n        optimizer=keras.optimizers.RMSprop(learning_rate=0.1),\n        loss=\"mean_squared_error\",\n        metrics=[\"mean_absolute_error\"],\n    )\n    return model\n\n\n\"\"\"\nThen, load the MNIST data for training and testing from the Keras datasets API:\n\"\"\"\n\n# Load example MNIST data and pre-process it\n(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\nx_train = x_train.reshape(-1, 784).astype(\"float32\") / 255.0\nx_test = x_test.reshape(-1, 784).astype(\"float32\") / 255.0\n\n# Limit the data to 1000 samples\nx_train = x_train[:1000]\ny_train = y_train[:1000]\nx_test = x_test[:1000]\ny_test = y_test[:1000]\n\n\"\"\"\nNow, define a simple custom callback that logs:\n\n- When `fit`/`evaluate`/`predict` starts & ends\n- When each epoch starts & ends\n- When each training batch starts & ends\n- When each evaluation (test) batch starts & ends\n- When each inference (prediction) batch starts & ends\n\"\"\"\n\n\nclass CustomCallback(keras.callbacks.Callback):\n    def on_train_begin(self, logs=None):\n        keys = list(logs.keys())\n        print(\"Starting training; got log keys: {}\".format(keys))\n\n    def on_train_end(self, logs=None):\n        keys = list(logs.keys())\n        print(\"Stop training; got log keys: {}\".format(keys))\n\n    def on_epoch_begin(self, epoch, logs=None):\n        
keys = list(logs.keys())\n        print(\n            \"Start epoch {} of training; got log keys: {}\".format(epoch, keys)\n        )\n\n    def on_epoch_end(self, epoch, logs=None):\n        keys = list(logs.keys())\n        print(\"End epoch {} of training; got log keys: {}\".format(epoch, keys))\n\n    def on_test_begin(self, logs=None):\n        keys = list(logs.keys())\n        print(\"Start testing; got log keys: {}\".format(keys))\n\n    def on_test_end(self, logs=None):\n        keys = list(logs.keys())\n        print(\"Stop testing; got log keys: {}\".format(keys))\n\n    def on_predict_begin(self, logs=None):\n        keys = list(logs.keys())\n        print(\"Start predicting; got log keys: {}\".format(keys))\n\n    def on_predict_end(self, logs=None):\n        keys = list(logs.keys())\n        print(\"Stop predicting; got log keys: {}\".format(keys))\n\n    def on_train_batch_begin(self, batch, logs=None):\n        keys = list(logs.keys())\n        print(\n            \"...Training: start of batch {}; got log keys: {}\".format(\n                batch, keys\n            )\n        )\n\n    def on_train_batch_end(self, batch, logs=None):\n        keys = list(logs.keys())\n        print(\n            \"...Training: end of batch {}; got log keys: {}\".format(batch, keys)\n        )\n\n    def on_test_batch_begin(self, batch, logs=None):\n        keys = list(logs.keys())\n        print(\n            \"...Evaluating: start of batch {}; got log keys: {}\".format(\n                batch, keys\n            )\n        )\n\n    def on_test_batch_end(self, batch, logs=None):\n        keys = list(logs.keys())\n        print(\n            \"...Evaluating: end of batch {}; got log keys: {}\".format(\n                batch, keys\n            )\n        )\n\n    def on_predict_batch_begin(self, batch, logs=None):\n        keys = list(logs.keys())\n        print(\n            \"...Predicting: start of batch {}; got log keys: {}\".format(\n                batch, keys\n     
       )\n        )\n\n    def on_predict_batch_end(self, batch, logs=None):\n        keys = list(logs.keys())\n        print(\n            \"...Predicting: end of batch {}; got log keys: {}\".format(\n                batch, keys\n            )\n        )\n\n\n\"\"\"\nLet's try it out:\n\"\"\"\n\nmodel = get_model()\nmodel.fit(\n    x_train,\n    y_train,\n    batch_size=128,\n    epochs=1,\n    verbose=0,\n    validation_split=0.5,\n    callbacks=[CustomCallback()],\n)\n\nres = model.evaluate(\n    x_test, y_test, batch_size=128, verbose=0, callbacks=[CustomCallback()]\n)\n\nres = model.predict(x_test, batch_size=128, callbacks=[CustomCallback()])\n\n\"\"\"\n### Usage of `logs` dict\n\nThe `logs` dict contains the loss value and all the metrics at the end of a batch or\nepoch. Examples include the loss and mean absolute error.\n\"\"\"\n\n\nclass LossAndErrorPrintingCallback(keras.callbacks.Callback):\n    def on_train_batch_end(self, batch, logs=None):\n        print(\n            \"Up to batch {}, the average loss is {:7.2f}.\".format(\n                batch, logs[\"loss\"]\n            )\n        )\n\n    def on_test_batch_end(self, batch, logs=None):\n        print(\n            \"Up to batch {}, the average loss is {:7.2f}.\".format(\n                batch, logs[\"loss\"]\n            )\n        )\n\n    def on_epoch_end(self, epoch, logs=None):\n        print(\n            \"The average loss for epoch {} is {:7.2f} \"\n            \"and mean absolute error is {:7.2f}.\".format(\n                epoch, logs[\"loss\"], logs[\"mean_absolute_error\"]\n            )\n        )\n\n\nmodel = get_model()\nmodel.fit(\n    x_train,\n    y_train,\n    batch_size=128,\n    epochs=2,\n    verbose=0,\n    callbacks=[LossAndErrorPrintingCallback()],\n)\n\nres = model.evaluate(\n    x_test,\n    y_test,\n    batch_size=128,\n    verbose=0,\n    callbacks=[LossAndErrorPrintingCallback()],\n)\n\n\"\"\"\n## Usage of `self.model` attribute\n\nIn addition to receiving log 
information when one of their methods is called,\ncallbacks have access to the model associated with the current round of\ntraining/evaluation/inference: `self.model`.\n\nHere are a few of the things you can do with `self.model` in a callback:\n\n- Set `self.model.stop_training = True` to immediately interrupt training.\n- Mutate hyperparameters of the optimizer (available as `self.model.optimizer`),\nsuch as `self.model.optimizer.learning_rate`.\n- Save the model at periodic intervals.\n- Record the output of `model.predict()` on a few test samples at the end of each\nepoch, to use as a sanity check during training.\n- Extract visualizations of intermediate features at the end of each epoch, to monitor\nwhat the model is learning over time.\n- etc.\n\nLet's see this in action in a couple of examples.\n\"\"\"\n\n\"\"\"\n## Examples of Keras callback applications\n\n### Early stopping at minimum loss\n\nThis first example shows the creation of a `Callback` that stops training when the\nminimum of loss has been reached, by setting the attribute `self.model.stop_training`\n(boolean). Optionally, you can provide an argument `patience` to specify how many\nepochs to wait before stopping after having reached a local minimum.\n\n`keras.callbacks.EarlyStopping` provides a more complete and general implementation.\n\"\"\"\n\n\nclass EarlyStoppingAtMinLoss(keras.callbacks.Callback):\n    \"\"\"Stop training when the loss is at its min, i.e. the loss stops decreasing.\n\n    Arguments:\n        patience: Number of epochs to wait after min has been hit. 
After this\n        number of epochs with no improvement, training stops.\n    \"\"\"\n\n    def __init__(self, patience=0):\n        super().__init__()\n        self.patience = patience\n        # best_weights to store the weights at which the minimum loss occurs.\n        self.best_weights = None\n\n    def on_train_begin(self, logs=None):\n        # The number of epochs it has waited when loss is no longer minimum.\n        self.wait = 0\n        # The epoch the training stops at.\n        self.stopped_epoch = 0\n        # Initialize the best as infinity.\n        self.best = np.inf\n\n    def on_epoch_end(self, epoch, logs=None):\n        current = logs.get(\"loss\")\n        if np.less(current, self.best):\n            self.best = current\n            self.wait = 0\n            # Record the best weights if the current result is better (lower).\n            self.best_weights = self.model.get_weights()\n        else:\n            self.wait += 1\n            if self.wait >= self.patience:\n                self.stopped_epoch = epoch\n                self.model.stop_training = True\n                print(\"Restoring model weights from the end of the best epoch.\")\n                self.model.set_weights(self.best_weights)\n\n    def on_train_end(self, logs=None):\n        if self.stopped_epoch > 0:\n            print(f\"Epoch {self.stopped_epoch + 1}: early stopping\")\n\n\nmodel = get_model()\nmodel.fit(\n    x_train,\n    y_train,\n    batch_size=64,\n    epochs=30,\n    verbose=0,\n    callbacks=[LossAndErrorPrintingCallback(), EarlyStoppingAtMinLoss()],\n)\n\n\"\"\"\n### Learning rate scheduling\n\nIn this example, we show how a custom Callback can be used to dynamically change the\nlearning rate of the optimizer during the course of training.\n\nSee `keras.callbacks.LearningRateScheduler` for a more general implementation.\n\"\"\"\n\n\nclass CustomLearningRateScheduler(keras.callbacks.Callback):\n    \"\"\"Learning rate scheduler which sets the learning rate according to 
schedule.\n\n    Arguments:\n        schedule: a function that takes an epoch index\n            (integer, indexed from 0) and current learning rate\n            as inputs and returns a new learning rate as output (float).\n    \"\"\"\n\n    def __init__(self, schedule):\n        super().__init__()\n        self.schedule = schedule\n\n    def on_epoch_begin(self, epoch, logs=None):\n        if not hasattr(self.model.optimizer, \"learning_rate\"):\n            raise ValueError('Optimizer must have a \"learning_rate\" attribute.')\n        # Get the current learning rate from model's optimizer.\n        lr = self.model.optimizer.learning_rate\n        # Call schedule function to get the scheduled learning rate.\n        scheduled_lr = self.schedule(epoch, lr)\n        # Set the value back to the optimizer before this epoch starts\n        self.model.optimizer.learning_rate = scheduled_lr\n        print(\n            f\"\\nEpoch {epoch}: Learning rate is {float(np.array(scheduled_lr))}.\"\n        )\n\n\nLR_SCHEDULE = [\n    # (epoch to start, learning rate) tuples\n    (3, 0.05),\n    (6, 0.01),\n    (9, 0.005),\n    (12, 0.001),\n]\n\n\ndef lr_schedule(epoch, lr):\n    \"\"\"Helper function to retrieve the scheduled learning rate based on epoch.\"\"\"\n    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:\n        return lr\n    for i in range(len(LR_SCHEDULE)):\n        if epoch == LR_SCHEDULE[i][0]:\n            return LR_SCHEDULE[i][1]\n    return lr\n\n\nmodel = get_model()\nmodel.fit(\n    x_train,\n    y_train,\n    batch_size=64,\n    epochs=15,\n    verbose=0,\n    callbacks=[\n        LossAndErrorPrintingCallback(),\n        CustomLearningRateScheduler(lr_schedule),\n    ],\n)\n\n\"\"\"\n### Built-in Keras callbacks\n\nBe sure to check out the existing Keras callbacks by\nreading the [API docs](https://keras.io/api/callbacks/).\nApplications include logging to CSV, saving\nthe model, visualizing metrics in TensorBoard, and a lot more!\n\"\"\"\n"
  },
  {
    "path": "integration_tests/basic_full_flow.py",
    "content": "import numpy as np\nimport pytest\n\nimport keras\nfrom keras.src import layers\nfrom keras.src import losses\nfrom keras.src import metrics\nfrom keras.src import optimizers\nfrom keras.src import testing\n\n\nclass MyModel(keras.Model):\n    def __init__(self, hidden_dim, output_dim, **kwargs):\n        super().__init__(**kwargs)\n        self.hidden_dim = hidden_dim\n        self.output_dim = output_dim\n        self.dense1 = layers.Dense(hidden_dim, activation=\"relu\")\n        self.dense2 = layers.Dense(hidden_dim, activation=\"relu\")\n        self.dense3 = layers.Dense(output_dim)\n\n    def call(self, x):\n        x = self.dense1(x)\n        x = self.dense2(x)\n        return self.dense3(x)\n\n\nclass BasicFlowTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basic_fit(self):\n        model = MyModel(hidden_dim=2, output_dim=1)\n\n        x = np.random.random((128, 4))\n        y = np.random.random((128, 4))\n        batch_size = 32\n        epochs = 3\n\n        model.compile(\n            optimizer=optimizers.SGD(learning_rate=0.001),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n        )\n        output_before_fit = model(x)\n        model.fit(\n            x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2\n        )\n        output_after_fit = model(x)\n\n        self.assertNotAllClose(output_before_fit, output_after_fit)\n\n    def test_basic_fit_no_training(self):\n        model = MyModel(hidden_dim=2, output_dim=1)\n        x = np.random.random((128, 4))\n        model.predict(x)\n        model(x)\n"
  },
  {
    "path": "integration_tests/dataset_tests/boston_housing_test.py",
    "content": "from keras.src import testing\nfrom keras.src.datasets import boston_housing\n\n\nclass BostonHousingTest(testing.TestCase):\n    def test_load_data(self):\n        (x_train, y_train), (x_test, y_test) = boston_housing.load_data()\n        self.assertEqual(x_train.shape[1], 13)\n        self.assertEqual(x_train.shape[0] + x_test.shape[0], 506)\n\n    def test_seed_reproducibility(self):\n        seed = 123\n        first_load = boston_housing.load_data(seed=seed)\n        second_load = boston_housing.load_data(seed=seed)\n        self.assertAllClose(first_load[0][0], second_load[0][0])\n        self.assertAllClose(first_load[1][0], second_load[1][0])\n\n    def test_invalid_test_split(self):\n        with self.assertRaises(AssertionError):\n            boston_housing.load_data(test_split=-0.1)\n        with self.assertRaises(AssertionError):\n            boston_housing.load_data(test_split=1.0)\n"
  },
  {
    "path": "integration_tests/dataset_tests/california_housing_test.py",
    "content": "from keras.src import testing\nfrom keras.src.datasets import california_housing\n\n\nclass CaliforniaHousingTest(testing.TestCase):\n    def test_load_data_large(self):\n        (x_train, y_train), (x_test, y_test) = california_housing.load_data(\n            version=\"large\"\n        )\n        self.assertEqual(x_train.shape[1], 8)\n        # Ensure the dataset contains 20,640 samples as documented\n        self.assertEqual(x_train.shape[0] + x_test.shape[0], 20640)\n\n    def test_load_data_small(self):\n        (x_train, y_train), (x_test, y_test) = california_housing.load_data(\n            version=\"small\"\n        )\n        self.assertEqual(x_train.shape[1], 8)\n        # Ensure the small dataset contains 600 samples as documented\n        self.assertEqual(x_train.shape[0] + x_test.shape[0], 600)\n\n    def test_invalid_version(self):\n        with self.assertRaises(ValueError):\n            california_housing.load_data(version=\"invalid_version\")\n\n    def test_seed_reproducibility(self):\n        # Ensure the data is reproducible with the same seed\n        seed = 123\n        first_load = california_housing.load_data(version=\"large\", seed=seed)\n        second_load = california_housing.load_data(version=\"large\", seed=seed)\n        self.assertAllClose(first_load[0][0], second_load[0][0])\n        self.assertAllClose(first_load[1][0], second_load[1][0])\n"
  },
  {
    "path": "integration_tests/dataset_tests/cifar100_test.py",
    "content": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.datasets import cifar100\n\n\nclass Cifar100LoadDataTest(testing.TestCase):\n    def test_shapes_fine_label_mode(self):\n        (x_train, y_train), (x_test, y_test) = cifar100.load_data(\n            label_mode=\"fine\"\n        )\n        self.assertEqual(x_train.shape, (50000, 32, 32, 3))\n        self.assertEqual(y_train.shape, (50000, 1))\n        self.assertEqual(x_test.shape, (10000, 32, 32, 3))\n        self.assertEqual(y_test.shape, (10000, 1))\n\n    def test_shapes_coarse_label_mode(self):\n        (x_train, y_train), (x_test, y_test) = cifar100.load_data(\n            label_mode=\"coarse\"\n        )\n        self.assertEqual(x_train.shape, (50000, 32, 32, 3))\n        self.assertEqual(y_train.shape, (50000, 1))\n        self.assertEqual(x_test.shape, (10000, 32, 32, 3))\n        self.assertEqual(y_test.shape, (10000, 1))\n\n    def test_dtypes(self):\n        (x_train, y_train), (x_test, y_test) = cifar100.load_data()\n        self.assertEqual(x_train.dtype, np.uint8)\n        self.assertEqual(y_train.dtype, np.int64)\n        self.assertEqual(x_test.dtype, np.uint8)\n        self.assertEqual(y_test.dtype, np.int64)\n\n    def test_invalid_label_mode(self):\n        with self.assertRaises(ValueError):\n            cifar100.load_data(label_mode=\"invalid\")\n"
  },
  {
    "path": "integration_tests/dataset_tests/cifar10_test.py",
    "content": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.datasets import cifar10\n\n\nclass Cifar10LoadDataTest(testing.TestCase):\n    def test_x_train_shape(self):\n        (x_train, _), _ = cifar10.load_data()\n        self.assertEqual(x_train.shape, (50000, 32, 32, 3))\n\n    def test_y_train_shape(self):\n        (_, y_train), _ = cifar10.load_data()\n        self.assertEqual(y_train.shape, (50000, 1))\n\n    def test_x_test_shape(self):\n        _, (x_test, _) = cifar10.load_data()\n        self.assertEqual(x_test.shape, (10000, 32, 32, 3))\n\n    def test_y_test_shape(self):\n        _, (_, y_test) = cifar10.load_data()\n        self.assertEqual(y_test.shape, (10000, 1))\n\n    def test_x_train_dtype(self):\n        (x_train, _), _ = cifar10.load_data()\n        self.assertEqual(x_train.dtype, np.uint8)\n\n    def test_y_train_dtype(self):\n        (_, y_train), _ = cifar10.load_data()\n        self.assertEqual(y_train.dtype, np.uint8)\n\n    def test_x_test_dtype(self):\n        _, (x_test, _) = cifar10.load_data()\n        self.assertEqual(x_test.dtype, np.uint8)\n\n    def test_y_test_dtype(self):\n        _, (_, y_test) = cifar10.load_data()\n        self.assertEqual(y_test.dtype, np.uint8)\n"
  },
  {
    "path": "integration_tests/dataset_tests/fashion_mnist_test.py",
    "content": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.datasets import fashion_mnist\n\n\nclass FashionMnistLoadDataTest(testing.TestCase):\n    def test_x_train_shape(self):\n        (x_train, _), _ = fashion_mnist.load_data()\n        self.assertEqual(x_train.shape, (60000, 28, 28))\n\n    def test_y_train_shape(self):\n        (_, y_train), _ = fashion_mnist.load_data()\n        self.assertEqual(y_train.shape, (60000,))\n\n    def test_x_test_shape(self):\n        _, (x_test, _) = fashion_mnist.load_data()\n        self.assertEqual(x_test.shape, (10000, 28, 28))\n\n    def test_y_test_shape(self):\n        _, (_, y_test) = fashion_mnist.load_data()\n        self.assertEqual(y_test.shape, (10000,))\n\n    def test_x_train_dtype(self):\n        (x_train, _), _ = fashion_mnist.load_data()\n        self.assertEqual(x_train.dtype, np.uint8)\n\n    def test_y_train_dtype(self):\n        (_, y_train), _ = fashion_mnist.load_data()\n        self.assertEqual(y_train.dtype, np.uint8)\n\n    def test_x_test_dtype(self):\n        _, (x_test, _) = fashion_mnist.load_data()\n        self.assertEqual(x_test.dtype, np.uint8)\n\n    def test_y_test_dtype(self):\n        _, (_, y_test) = fashion_mnist.load_data()\n        self.assertEqual(y_test.dtype, np.uint8)\n"
  },
  {
    "path": "integration_tests/dataset_tests/imdb_test.py",
    "content": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.datasets import imdb\n\n\nclass ImdbLoadDataTest(testing.TestCase):\n    def test_load_data_default(self):\n        (x_train, y_train), (x_test, y_test) = imdb.load_data()\n        self.assertIsInstance(x_train, np.ndarray)\n        self.assertIsInstance(y_train, np.ndarray)\n        self.assertIsInstance(x_test, np.ndarray)\n        self.assertIsInstance(y_test, np.ndarray)\n\n        # Check lengths\n        self.assertEqual(len(x_train), 25000)\n        self.assertEqual(len(y_train), 25000)\n        self.assertEqual(len(x_test), 25000)\n        self.assertEqual(len(y_test), 25000)\n\n        # Check types within lists for x\n        self.assertIsInstance(x_train[0], list)\n        self.assertIsInstance(x_test[0], list)\n\n    def test_num_words(self):\n        # Only consider the top 1000 words\n        (x_train, _), _ = imdb.load_data(num_words=1000)\n        # Ensure that no word index exceeds 999 (0-based indexing)\n        max_index = max(max(sequence) for sequence in x_train if sequence)\n        self.assertLessEqual(max_index, 999)\n\n    def test_skip_top(self):\n        # Skip the top 10 most frequent words\n        (x_train, _), _ = imdb.load_data(skip_top=10, num_words=1000)\n        # Check if top 10 words are skipped properly\n        self.assertNotIn(1, x_train[0])  # Assuming 1 is among top 10\n\n    def test_maxlen(self):\n        # Only consider sequences shorter than 100\n        (x_train, _), _ = imdb.load_data(maxlen=100)\n        self.assertTrue(all(len(seq) <= 100 for seq in x_train))\n\n    def test_get_word_index(self):\n        word_index = imdb.get_word_index()\n        self.assertIsInstance(word_index, dict)\n        # Check if word_index contains specific known words\n        self.assertIn(\"the\", word_index)\n        self.assertIn(\"and\", word_index)\n"
  },
  {
    "path": "integration_tests/dataset_tests/mnist_test.py",
    "content": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.datasets import mnist\n\n\nclass MnistLoadDataTest(testing.TestCase):\n    def test_x_train_shape(self):\n        (x_train, _), _ = mnist.load_data()\n        self.assertEqual(x_train.shape, (60000, 28, 28))\n\n    def test_y_train_shape(self):\n        (_, y_train), _ = mnist.load_data()\n        self.assertEqual(y_train.shape, (60000,))\n\n    def test_x_test_shape(self):\n        _, (x_test, _) = mnist.load_data()\n        self.assertEqual(x_test.shape, (10000, 28, 28))\n\n    def test_y_test_shape(self):\n        _, (_, y_test) = mnist.load_data()\n        self.assertEqual(y_test.shape, (10000,))\n\n    def test_x_train_dtype(self):\n        (x_train, _), _ = mnist.load_data()\n        self.assertEqual(x_train.dtype, np.uint8)\n\n    def test_y_train_dtype(self):\n        (_, y_train), _ = mnist.load_data()\n        self.assertEqual(y_train.dtype, np.uint8)\n\n    def test_x_test_dtype(self):\n        _, (x_test, _) = mnist.load_data()\n        self.assertEqual(x_test.dtype, np.uint8)\n\n    def test_y_test_dtype(self):\n        _, (_, y_test) = mnist.load_data()\n        self.assertEqual(y_test.dtype, np.uint8)\n"
  },
  {
    "path": "integration_tests/dataset_tests/reuters_test.py",
    "content": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.datasets import reuters\n\n\nclass ReutersLoadDataTest(testing.TestCase):\n    def test_load_data_default(self):\n        (x_train, y_train), (x_test, y_test) = reuters.load_data()\n        # Check types\n        self.assertIsInstance(x_train, np.ndarray)\n        self.assertIsInstance(y_train, np.ndarray)\n        self.assertIsInstance(x_test, np.ndarray)\n        self.assertIsInstance(y_test, np.ndarray)\n\n        # Check shapes\n        self.assertGreater(len(x_train), 0)\n        self.assertEqual(len(x_train), len(y_train))\n        self.assertGreater(len(x_test), 0)\n        self.assertEqual(len(x_test), len(y_test))\n\n    def test_num_words(self):\n        # Only consider the top 1000 words\n        (x_train, _), _ = reuters.load_data(num_words=1000)\n        # Ensure no word index exceeds 999 (0-based indexing)\n        max_index = max(max(sequence) for sequence in x_train if sequence)\n        self.assertLessEqual(max_index, 999)\n\n    def test_skip_top(self):\n        # Skip the top 10 most frequent words\n        (x_train, _), _ = reuters.load_data(skip_top=10, num_words=1000)\n        # Assuming 1 is among top 10, check if it's skipped\n        self.assertNotIn(1, x_train[0])\n\n    def test_maxlen(self):\n        # Only consider sequences shorter than 50\n        (x_train, _), _ = reuters.load_data(maxlen=50)\n        self.assertTrue(all(len(seq) <= 50 for seq in x_train))\n\n    def test_get_word_index(self):\n        word_index = reuters.get_word_index()\n        self.assertIsInstance(word_index, dict)\n        # Check if word_index contains specific known words\n        self.assertIn(\"the\", word_index)\n\n    def test_get_label_names(self):\n        label_names = reuters.get_label_names()\n        self.assertIsInstance(label_names, tuple)\n        # Check if the tuple contains specific known labels\n        self.assertIn(\"earn\", label_names)\n        
self.assertIn(\"acq\", label_names)\n"
  },
  {
    "path": "integration_tests/import_test.py",
    "content": "import os\nimport re\nimport subprocess\n\nfrom keras.src import backend\nfrom keras.src.backend import config\n\n# For torch, use index url to avoid installing nvidia drivers for the test.\nBACKEND_REQ = {\n    \"tensorflow\": (\"tensorflow-cpu\", \"\"),\n    \"torch\": (\n        \"torch\",\n        \"--extra-index-url https://download.pytorch.org/whl/cpu \",\n    ),\n    \"jax\": (\"jax[cpu]\", \"\"),\n    \"openvino\": (\"openvino\", \"\"),\n}\n\n\ndef setup_package():\n    subprocess.run(\"rm -rf tmp_build_dir\", shell=True)\n    build_process = subprocess.run(\n        \"python3 pip_build.py\",\n        capture_output=True,\n        text=True,\n        shell=True,\n    )\n    print(build_process.stdout)\n    whl_path = re.findall(\n        r\"[^\\s]*\\.whl\",\n        build_process.stdout,\n    )\n    if not whl_path:\n        print(build_process.stdout)\n        print(build_process.stderr)\n        raise ValueError(\"Installing Keras package unsuccessful. \")\n    return whl_path[-1]\n\n\ndef create_virtualenv():\n    env_setup = [\n        # Create virtual environment\n        \"python3 -m venv test_env\",\n    ]\n    os.environ[\"PATH\"] = os.pathsep.join(\n        (\n            os.path.join(os.getcwd(), \"test_env\", \"bin\"),\n            os.environ.get(\"PATH\", \"\"),\n        )\n    )\n    if os.name == \"nt\":\n        os.environ[\"PATH\"] = os.pathsep.join(\n            (\n                os.path.join(os.getcwd(), \"test_env\", \"Scripts\"),\n                os.environ[\"PATH\"],\n            )\n        )\n    run_commands_local(env_setup)\n\n\ndef manage_venv_installs(whl_path):\n    other_backends = list(set(BACKEND_REQ.keys()) - {backend.backend()})\n    backend_pkg, backend_extra_url = BACKEND_REQ[backend.backend()]\n    install_setup = [\n        # Installs the backend's package and common requirements\n        f\"pip install {backend_extra_url}{backend_pkg}\",\n        \"pip install -r requirements-common.txt\",\n        \"pip 
install pytest\",\n        # Ensure other backends are uninstalled\n        \"pip uninstall -y {0} {1} {2}\".format(\n            BACKEND_REQ[other_backends[0]][0],\n            BACKEND_REQ[other_backends[1]][0],\n            BACKEND_REQ[other_backends[2]][0],\n        ),\n        # Install `.whl` package\n        f\"pip install {whl_path}\",\n    ]\n    # Install flax for JAX when NNX is enabled\n    if backend.backend() == \"jax\" and config.is_nnx_enabled():\n        install_setup.append(\"pip install flax>=0.10.1\")\n    run_commands_venv(install_setup)\n\n\ndef run_keras_flow():\n    test_script = [\n        # Runs the example script\n        \"python -m pytest integration_tests/basic_full_flow.py\",\n    ]\n    run_commands_venv(test_script)\n\n\ndef cleanup():\n    cleanup_script = [\n        # Exits virtual environment, deletes files, and any\n        # miscellaneous install logs\n        \"exit\",\n        \"rm -rf test_env\",\n        \"rm -rf tmp_build_dir\",\n        \"rm -f *+cpu\",\n    ]\n    run_commands_local(cleanup_script)\n\n\ndef run_commands_local(commands):\n    for command in commands:\n        print(f\"Running command: {command}\")\n        subprocess.run(command, shell=True)\n\n\ndef run_commands_venv(commands):\n    for command in commands:\n        print(f\"Running command: {command}\")\n        cmd_with_args = command.split(\" \")\n        cmd_with_args[0] = os.path.join(\n            \"test_env\",\n            \"Scripts\" if os.name == \"nt\" else \"bin\",\n            cmd_with_args[0],\n        )\n        p = subprocess.Popen(cmd_with_args)\n        assert p.wait() == 0\n\n\ndef test_keras_imports():\n    try:\n        # Ensures packages from all backends are installed.\n        # Builds Keras core package and returns package file path.\n        whl_path = setup_package()\n\n        # Creates and activates a virtual environment.\n        create_virtualenv()\n\n        # Ensures the backend's package is installed\n        # and the 
other backends are uninstalled.\n        manage_venv_installs(whl_path)\n\n        # Runs test of basic flow in Keras Core.\n        # Tests for backend-specific imports and `model.fit()`.\n        run_keras_flow()\n\n        # Removes virtual environment and associated files\n    finally:\n        cleanup()\n\n\nif __name__ == \"__main__\":\n    test_keras_imports()\n"
  },
  {
    "path": "integration_tests/jax_custom_fit_test.py",
    "content": "import jax\nimport numpy as np\n\nimport keras\n\n\ndef test_custom_fit():\n    class CustomModel(keras.Model):\n        def __init__(self, *args, **kwargs):\n            super().__init__(*args, **kwargs)\n            self.loss_tracker = keras.metrics.Mean(name=\"loss\")\n            self.mae_metric = keras.metrics.MeanAbsoluteError(name=\"mae\")\n            self.loss_fn = keras.losses.MeanSquaredError()\n\n        def compute_loss_and_updates(\n            self,\n            trainable_variables,\n            non_trainable_variables,\n            x,\n            y,\n            training=False,\n        ):\n            y_pred, non_trainable_variables = self.stateless_call(\n                trainable_variables,\n                non_trainable_variables,\n                x,\n                training=training,\n            )\n            loss = self.loss_fn(y, y_pred)\n            return loss, (y_pred, non_trainable_variables)\n\n        def train_step(self, state, data):\n            (\n                trainable_variables,\n                non_trainable_variables,\n                optimizer_variables,\n                metrics_variables,\n            ) = state\n            x, y = data\n            grad_fn = jax.value_and_grad(\n                self.compute_loss_and_updates, has_aux=True\n            )\n            (loss, (y_pred, non_trainable_variables)), grads = grad_fn(\n                trainable_variables,\n                non_trainable_variables,\n                x,\n                y,\n                training=True,\n            )\n            (\n                trainable_variables,\n                optimizer_variables,\n            ) = self.optimizer.stateless_apply(\n                optimizer_variables, grads, trainable_variables\n            )\n            loss_tracker_vars = metrics_variables[\n                : len(self.loss_tracker.variables)\n            ]\n            mae_metric_vars = metrics_variables[\n                
len(self.loss_tracker.variables) :\n            ]\n            loss_tracker_vars = self.loss_tracker.stateless_update_state(\n                loss_tracker_vars, loss\n            )\n            mae_metric_vars = self.mae_metric.stateless_update_state(\n                mae_metric_vars, y, y_pred\n            )\n            logs = {}\n            logs[self.loss_tracker.name] = self.loss_tracker.stateless_result(\n                loss_tracker_vars\n            )\n            logs[self.mae_metric.name] = self.mae_metric.stateless_result(\n                mae_metric_vars\n            )\n            new_metrics_vars = loss_tracker_vars + mae_metric_vars\n            state = (\n                trainable_variables,\n                non_trainable_variables,\n                optimizer_variables,\n                new_metrics_vars,\n            )\n            return logs, state\n\n        @property\n        def metrics(self):\n            return [self.loss_tracker, self.mae_metric]\n\n    inputs = keras.Input(shape=(32,))\n    outputs = keras.layers.Dense(1)(inputs)\n    model = CustomModel(inputs, outputs)\n    model.compile(optimizer=\"adam\")\n    x = np.random.random((64, 32))\n    y = np.random.random((64, 1))\n    history = model.fit(x, y, epochs=1)\n\n    assert \"loss\" in history.history\n    assert \"mae\" in history.history\n\n    print(\"History:\")\n    print(history.history)\n\n\nif __name__ == \"__main__\":\n    test_custom_fit()\n"
  },
  {
    "path": "integration_tests/model_visualization_test.py",
    "content": "import re\n\nimport keras\nfrom keras.src import testing\nfrom keras.src.utils import model_to_dot\nfrom keras.src.utils import plot_model\n\n\nclass SubclassModel(keras.models.Model):\n    def __init__(self, name):\n        super().__init__(name=name)\n\n    def call(self, x):\n        return x\n\n\ndef parse_text_from_html(html):\n    pattern = r\"<font[^>]*>(.*?)</font>\"\n    matches = re.findall(pattern, html)\n\n    for match in matches:\n        clean_text = re.sub(r\"<[^>]*>\", \"\", match)\n        return clean_text\n    return \"\"\n\n\ndef get_node_text(node):\n    attributes = node.get_attributes()\n\n    if \"label\" in attributes:\n        html = node.get_attributes()[\"label\"]\n        return parse_text_from_html(html)\n    else:\n        return None\n\n\ndef get_edge_dict(dot):\n    def get_node_dict(graph, path=\"\"):\n        nodes = {\n            node.get_name(): path + get_node_text(node)\n            for node in graph.get_nodes()\n            if node.get_name() != \"node\"  # Dummy node inserted by pydot?\n        }\n\n        for subgraph in graph.get_subgraphs():\n            sub_nodes = get_node_dict(\n                subgraph, path=f\"{path}{subgraph.get_label()} > \"\n            )\n            nodes.update(sub_nodes)\n\n        return nodes\n\n    node_dict = get_node_dict(dot)\n\n    def get_edges(graph):\n        edges = list(graph.get_edges())\n        for subgraph in graph.get_subgraphs():\n            edges.extend(get_edges(subgraph))\n        return edges\n\n    edge_dict = dict()\n    dangling_edges = []\n    for edge in get_edges(dot):\n        source_node = node_dict.get(edge.get_source(), None)\n        destination_node = node_dict.get(edge.get_destination(), None)\n        if source_node is None or destination_node is None:\n            dangling_edges.append(\n                f\"from '{source_node}'/'{edge.get_source()}' \"\n                f\"to '{destination_node}'/'{edge.get_destination()}'\"\n            
)\n        if source_node in edge_dict:\n            destination_nodes = edge_dict[source_node]\n            if not isinstance(destination_nodes, set):\n                destination_nodes = set([destination_nodes])\n                edge_dict[source_node] = destination_nodes\n            destination_nodes.add(destination_node)\n        else:\n            edge_dict[source_node] = destination_node\n\n    if dangling_edges:\n        raise ValueError(f\"Dangling edges found: {dangling_edges}\")\n    return edge_dict\n\n\nclass ModelVisualizationTest(testing.TestCase):\n    def multi_plot_model(self, model, name, expand_nested=False):\n        if expand_nested:\n            name = f\"{name}-expand_nested\"\n\n        TEST_CASES = [\n            {},\n            {\n                \"show_shapes\": True,\n            },\n            {\n                \"show_shapes\": True,\n                \"show_dtype\": True,\n            },\n            {\n                \"show_shapes\": True,\n                \"show_dtype\": True,\n                \"show_layer_names\": True,\n            },\n            {\n                \"show_shapes\": True,\n                \"show_dtype\": True,\n                \"show_layer_names\": True,\n                \"show_layer_activations\": True,\n            },\n            {\n                \"show_shapes\": True,\n                \"show_dtype\": True,\n                \"show_layer_names\": True,\n                \"show_layer_activations\": True,\n                \"show_trainable\": True,\n            },\n            {\n                \"show_shapes\": True,\n                \"show_dtype\": True,\n                \"show_layer_names\": True,\n                \"show_layer_activations\": True,\n                \"show_trainable\": True,\n                \"rankdir\": \"LR\",\n            },\n            {\n                \"show_layer_activations\": True,\n                \"show_trainable\": True,\n            },\n        ]\n\n        for test_case in 
TEST_CASES:\n            tags = [v if k == \"rankdir\" else k for k, v in test_case.items()]\n            file_name = f\"{'-'.join([name] + tags)}.png\"\n            plot_model(\n                model, file_name, expand_nested=expand_nested, **test_case\n            )\n            self.assertFileExists(file_name)\n\n    def test_plot_sequential_model(self):\n        model = keras.Sequential(\n            [\n                keras.Input((3,), name=\"input\"),\n                keras.layers.Dense(4, activation=\"relu\", name=\"dense\"),\n                keras.layers.Dense(1, activation=\"sigmoid\", name=\"dense_1\"),\n            ]\n        )\n\n        edge_dict = get_edge_dict(model_to_dot(model))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"dense (Dense)\": \"dense_1 (Dense)\",\n            },\n        )\n        self.multi_plot_model(model, \"sequential\")\n\n    def test_plot_functional_model(self):\n        inputs = keras.Input((3,), name=\"input\")\n        x = keras.layers.Dense(\n            4, activation=\"relu\", trainable=False, name=\"dense\"\n        )(inputs)\n        residual = x\n        x = keras.layers.Dense(4, activation=\"relu\", name=\"dense_1\")(x)\n        x = keras.layers.Dense(4, activation=\"relu\", name=\"dense_2\")(x)\n        x = keras.layers.Dense(4, activation=\"relu\", name=\"dense_3\")(x)\n        x += residual\n        residual = x\n        x = keras.layers.Dense(4, activation=\"relu\", name=\"dense_4\")(x)\n        x = keras.layers.Dense(4, activation=\"relu\", name=\"dense_5\")(x)\n        x = keras.layers.Dense(4, activation=\"relu\", name=\"dense_6\")(x)\n        x += residual\n        x = keras.layers.Dropout(0.5, name=\"dropout\")(x)\n        outputs = keras.layers.Dense(1, activation=\"sigmoid\", name=\"dense_7\")(x)\n\n        model = keras.Model(inputs, outputs)\n\n        edge_dict = get_edge_dict(model_to_dot(model))\n        self.assertEqual(\n            edge_dict,\n            {\n  
              \"input (InputLayer)\": \"dense (Dense)\",\n                \"dense (Dense)\": {\"dense_1 (Dense)\", \"add (Add)\"},\n                \"dense_1 (Dense)\": \"dense_2 (Dense)\",\n                \"dense_2 (Dense)\": \"dense_3 (Dense)\",\n                \"dense_3 (Dense)\": \"add (Add)\",\n                \"add (Add)\": {\"dense_4 (Dense)\", \"add_1 (Add)\"},\n                \"dense_4 (Dense)\": \"dense_5 (Dense)\",\n                \"dense_5 (Dense)\": \"dense_6 (Dense)\",\n                \"dense_6 (Dense)\": \"add_1 (Add)\",\n                \"add_1 (Add)\": \"dropout (Dropout)\",\n                \"dropout (Dropout)\": \"dense_7 (Dense)\",\n            },\n        )\n        self.multi_plot_model(model, \"functional\")\n\n    def test_plot_subclassed_model(self):\n        model = SubclassModel(name=\"subclass\")\n        model.build((None, 3))\n\n        self.multi_plot_model(model, \"subclassed\")\n\n    def test_plot_nested_functional_model(self):\n        inputs = keras.Input((3,), name=\"input\")\n        x = keras.layers.Dense(4, activation=\"relu\", name=\"dense\")(inputs)\n        x = keras.layers.Dense(4, activation=\"relu\", name=\"dense_1\")(x)\n        outputs = keras.layers.Dense(3, activation=\"relu\", name=\"dense_2\")(x)\n        inner_model = keras.Model(inputs, outputs, name=\"inner_model\")\n\n        inputs = keras.Input((3,), name=\"input_1\")\n        x = keras.layers.Dense(\n            3, activation=\"relu\", trainable=False, name=\"dense_3\"\n        )(inputs)\n        residual = x\n        x = inner_model(x)\n        x = keras.layers.Add(name=\"add\")([x, residual])\n        residual = x\n        x = keras.layers.Dense(4, activation=\"relu\", name=\"dense_4\")(x)\n        x = keras.layers.Dense(4, activation=\"relu\", name=\"dense_5\")(x)\n        x = keras.layers.Dense(3, activation=\"relu\", name=\"dense_6\")(x)\n        x = keras.layers.Add(name=\"add_1\")([x, residual])\n        x = keras.layers.Dropout(0.5, 
name=\"dropout\")(x)\n        outputs = keras.layers.Dense(1, activation=\"sigmoid\", name=\"dense_7\")(x)\n        model = keras.Model(inputs, outputs)\n\n        edge_dict = get_edge_dict(model_to_dot(model))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"input_1 (InputLayer)\": \"dense_3 (Dense)\",\n                \"dense_3 (Dense)\": {\"inner_model (Functional)\", \"add (Add)\"},\n                \"inner_model (Functional)\": \"add (Add)\",\n                \"add (Add)\": {\"dense_4 (Dense)\", \"add_1 (Add)\"},\n                \"dense_4 (Dense)\": \"dense_5 (Dense)\",\n                \"dense_5 (Dense)\": \"dense_6 (Dense)\",\n                \"dense_6 (Dense)\": \"add_1 (Add)\",\n                \"add_1 (Add)\": \"dropout (Dropout)\",\n                \"dropout (Dropout)\": \"dense_7 (Dense)\",\n            },\n        )\n        self.multi_plot_model(model, \"nested-functional\")\n\n        edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"input_1 (InputLayer)\": \"dense_3 (Dense)\",\n                \"dense_3 (Dense)\": {\n                    \"inner_model > input (InputLayer)\",\n                    \"add (Add)\",\n                },\n                \"inner_model > input (InputLayer)\": \"inner_model > dense (Dense)\",  # noqa: E501\n                \"inner_model > dense (Dense)\": \"inner_model > dense_1 (Dense)\",  # noqa: E501\n                \"inner_model > dense_1 (Dense)\": \"inner_model > dense_2 (Dense)\",  # noqa: E501\n                \"inner_model > dense_2 (Dense)\": \"add (Add)\",\n                \"add (Add)\": {\"dense_4 (Dense)\", \"add_1 (Add)\"},\n                \"dense_4 (Dense)\": \"dense_5 (Dense)\",\n                \"dense_5 (Dense)\": \"dense_6 (Dense)\",\n                \"dense_6 (Dense)\": \"add_1 (Add)\",\n                \"add_1 (Add)\": \"dropout (Dropout)\",\n                \"dropout 
(Dropout)\": \"dense_7 (Dense)\",\n            },\n        )\n        self.multi_plot_model(model, \"nested-functional\", expand_nested=True)\n\n    def test_plot_functional_model_with_splits_and_merges(self):\n        class SplitLayer(keras.Layer):\n            def call(self, x):\n                return list(keras.ops.split(x, 2, axis=1))\n\n        class ConcatLayer(keras.Layer):\n            def call(self, xs):\n                return keras.ops.concatenate(xs, axis=1)\n\n        inputs = keras.Input((2,), name=\"input\")\n        a, b = SplitLayer()(inputs)\n\n        a = keras.layers.Dense(2, name=\"dense\")(a)\n        b = keras.layers.Dense(2, name=\"dense_1\")(b)\n\n        outputs = ConcatLayer(name=\"concat_layer\")([a, b])\n        model = keras.Model(inputs, outputs)\n\n        edge_dict = get_edge_dict(model_to_dot(model))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"input (InputLayer)\": \"split_layer (SplitLayer)\",\n                \"split_layer (SplitLayer)\": {\n                    \"dense (Dense)\",\n                    \"dense_1 (Dense)\",\n                },\n                \"dense (Dense)\": \"concat_layer (ConcatLayer)\",\n                \"dense_1 (Dense)\": \"concat_layer (ConcatLayer)\",\n            },\n        )\n        self.multi_plot_model(model, \"split-functional\")\n\n    def test_plot_sequential_in_sequential(self):\n        inner_model = keras.models.Sequential(\n            [\n                keras.layers.Dense(10, name=\"dense2\"),\n                keras.layers.Dense(10, name=\"dense3\"),\n            ],\n            name=\"sub\",\n        )\n        model = keras.models.Sequential(\n            [\n                keras.layers.Dense(10, name=\"dense1\"),\n                inner_model,\n            ],\n        )\n        model.build((1, 10))\n\n        #\n        #  +-------------------------+\n        #  |     dense1 (Dense)      |\n        #  +-------------------------+\n        #         
      |\n        #               v\n        #  +-------------------------+\n        #  |    sub (Sequential)     |\n        #  +-------------------------+\n        #\n        edge_dict = get_edge_dict(model_to_dot(model))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"dense1 (Dense)\": \"sub (Sequential)\",\n            },\n        )\n        self.multi_plot_model(model, \"sequential_in_sequential\")\n\n        #\n        #    +-------------------------+\n        #    |     dense1 (Dense)      |\n        #    +-------------------------+\n        #                 |\n        #  +--------------|--------------+\n        #  | sub          v              |\n        #  | +-------------------------+ |\n        #  | |     dense2 (Dense)      | |\n        #  | +-------------------------+ |\n        #  |              |              |\n        #  |              v              |\n        #  | +-------------------------+ |\n        #  | |     dense3 (Dense)      | |\n        #  | +-------------------------+ |\n        #  +-----------------------------+\n        #\n        edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"dense1 (Dense)\": \"sub > dense2 (Dense)\",\n                \"sub > dense2 (Dense)\": \"sub > dense3 (Dense)\",\n            },\n        )\n        self.multi_plot_model(\n            model, \"sequential_in_sequential\", expand_nested=True\n        )\n\n    def test_plot_functional_in_functional(self):\n        inner_input = keras.layers.Input((10,), name=\"inner_input\")\n        x = keras.layers.Dense(10, name=\"dense1\")(inner_input)\n        x = keras.layers.Dense(10, name=\"dense2\")(x)\n        inner_model = keras.models.Model(inner_input, x, name=\"inner\")\n\n        outer_input = keras.layers.Input((10,), name=\"outer_input\")\n        model = keras.models.Model(outer_input, inner_model(outer_input))\n\n        
#\n        #  +-------------------------+\n        #  |outer_input (InputLayer) |\n        #  +-------------------------+\n        #               |\n        #               v\n        #  +-------------------------+\n        #  |   inner (Functional)    |\n        #  +-------------------------+\n        #\n        edge_dict = get_edge_dict(model_to_dot(model))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"outer_input (InputLayer)\": \"inner (Functional)\",\n            },\n        )\n        self.multi_plot_model(model, \"functional_in_functional\")\n\n        #\n        #    +-------------------------+\n        #    |outer_input (InputLayer) |\n        #    +-------------------------+\n        #                 |\n        #  +--------------|--------------+\n        #  | inner        v              |\n        #  | +-------------------------+ |\n        #  | |inner_input (InputLayer) | |\n        #  | +-------------------------+ |\n        #  |              |              |\n        #  |              v              |\n        #  | +-------------------------+ |\n        #  | |     dense1 (Dense)      | |\n        #  | +-------------------------+ |\n        #  |              |              |\n        #  |              v              |\n        #  | +-------------------------+ |\n        #  | |     dense2 (Dense)      | |\n        #  | +-------------------------+ |\n        #  +-----------------------------+\n        #\n        edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"outer_input (InputLayer)\": \"inner > inner_input (InputLayer)\",\n                \"inner > inner_input (InputLayer)\": \"inner > dense1 (Dense)\",\n                \"inner > dense1 (Dense)\": \"inner > dense2 (Dense)\",\n            },\n        )\n        self.multi_plot_model(\n            model, \"functional_in_functional\", expand_nested=True\n        
)\n\n    def test_plot_sequential_in_sequential_in_sequential(self):\n        inner_model = keras.models.Sequential(\n            [\n                keras.layers.Dense(10, name=\"dense2\"),\n                keras.layers.Dense(10, name=\"dense3\"),\n            ],\n            name=\"inner\",\n        )\n        mid_model = keras.models.Sequential(\n            [\n                inner_model,\n            ],\n            name=\"mid\",\n        )\n        model = keras.models.Sequential(\n            [\n                keras.layers.Dense(10, name=\"dense1\"),\n                mid_model,\n                keras.layers.Dense(10, name=\"dense4\"),\n            ],\n        )\n        model.build((1, 10))\n\n        #\n        #  +-------------------------+\n        #  |     dense1 (Dense)      |\n        #  +-------------------------+\n        #               |\n        #               v\n        #  +-------------------------+\n        #  |    mid (Sequential)     |\n        #  +-------------------------+\n        #               |\n        #               v\n        #  +-------------------------+\n        #  |     dense4 (Dense)      |\n        #  +-------------------------+\n        #\n        edge_dict = get_edge_dict(model_to_dot(model))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"dense1 (Dense)\": \"mid (Sequential)\",\n                \"mid (Sequential)\": \"dense4 (Dense)\",\n            },\n        )\n        self.multi_plot_model(model, \"sequential_in_sequential_in_sequential\")\n\n        #\n        #      +-------------------------+\n        #      |     dense1 (Dense)      |\n        #      +-------------------------+\n        #                   |\n        #  +----------------|----------------+\n        #  | mid            |                |\n        #  | +--------------|--------------+ |\n        #  | | inner        v              | |\n        #  | | +-------------------------+ | |\n        #  | | |     dense2 
(Dense)      | | |\n        #  | | +-------------------------+ | |\n        #  | |              |              | |\n        #  | |              v              | |\n        #  | | +-------------------------+ | |\n        #  | | |     dense3 (Dense)      | | |\n        #  | | +-------------------------+ | |\n        #  | +--------------|--------------+ |\n        #  +----------------|----------------+\n        #                   v\n        #      +-------------------------+\n        #      |     dense4 (Dense)      |\n        #      +-------------------------+\n        #\n        edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"dense1 (Dense)\": \"mid > inner > dense2 (Dense)\",\n                \"mid > inner > dense2 (Dense)\": \"mid > inner > dense3 (Dense)\",\n                \"mid > inner > dense3 (Dense)\": \"dense4 (Dense)\",\n            },\n        )\n        self.multi_plot_model(\n            model, \"sequential_in_sequential_in_sequential\", expand_nested=True\n        )\n\n    def test_plot_functional_in_sequential_in_sequential(self):\n        input1 = keras.layers.Input((10,), name=\"input1\")\n        x = keras.layers.Dense(10, name=\"dense2\")(input1)\n        inner_model = keras.models.Model(input1, x, name=\"inner\")\n\n        mid_model = keras.models.Sequential(\n            [\n                inner_model,\n            ],\n            name=\"mid\",\n        )\n        model = keras.models.Sequential(\n            [\n                keras.layers.Dense(10, name=\"dense1\"),\n                mid_model,\n                keras.layers.Dense(10, name=\"dense3\"),\n            ],\n        )\n        model.build((1, 10))\n\n        #\n        #  +-------------------------+\n        #  |     dense1 (Dense)      |\n        #  +-------------------------+\n        #               |\n        #               v\n        #  +-------------------------+\n        # 
 |    mid (Sequential)     |\n        #  +-------------------------+\n        #               |\n        #               v\n        #  +-------------------------+\n        #  |     dense3 (Dense)      |\n        #  +-------------------------+\n        #\n        edge_dict = get_edge_dict(model_to_dot(model))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"dense1 (Dense)\": \"mid (Sequential)\",\n                \"mid (Sequential)\": \"dense3 (Dense)\",\n            },\n        )\n        self.multi_plot_model(model, \"functional_in_sequential_in_sequential\")\n\n        #\n        #      +-------------------------+\n        #      |     dense1 (Dense)      |\n        #      +-------------------------+\n        #                   |\n        #  +----------------|----------------+\n        #  | mid            |                |\n        #  | +--------------|--------------+ |\n        #  | | inner        v              | |\n        #  | | +-------------------------+ | |\n        #  | | |   input1 (InputLayer)   | | |\n        #  | | +-------------------------+ | |\n        #  | |              |              | |\n        #  | |              v              | |\n        #  | | +-------------------------+ | |\n        #  | | |     dense2 (Dense)      | | |\n        #  | | +-------------------------+ | |\n        #  | +--------------|--------------+ |\n        #  +----------------|----------------+\n        #                   v\n        #      +-------------------------+\n        #      |     dense3 (Dense)      |\n        #      +-------------------------+\n        #\n        edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"dense1 (Dense)\": \"mid > inner > input1 (InputLayer)\",\n                \"mid > inner > input1 (InputLayer)\": \"mid > inner > dense2 (Dense)\",  # noqa: E501\n                \"mid > inner > dense2 (Dense)\": 
\"dense3 (Dense)\",\n            },\n        )\n        self.multi_plot_model(\n            model, \"functional_in_sequential_in_sequential\", expand_nested=True\n        )\n\n    def test_plot_functional_in_functional_in_functional(self):\n        # From https://github.com/keras-team/keras/issues/21119\n        inner_input = keras.layers.Input((10,), name=\"inner_input\")\n        x = keras.layers.Dense(10, name=\"dense1\")(inner_input)\n        inner_model = keras.models.Model(inner_input, x, name=\"inner\")\n\n        mid_input = keras.layers.Input((10,), name=\"mid_input\")\n        mid_output = inner_model(mid_input)\n        mid_model = keras.models.Model(mid_input, mid_output, name=\"mid\")\n\n        outer_input = keras.layers.Input((10,), name=\"outer_input\")\n        x = mid_model(outer_input)\n        x = keras.layers.Dense(10, name=\"dense2\")(x)\n        model = keras.models.Model(outer_input, x)\n\n        #\n        #  +-------------------------+\n        #  |outer_input (InputLayer) |\n        #  +-------------------------+\n        #               |\n        #               v\n        #  +-------------------------+\n        #  |    mid (Functional)     |\n        #  +-------------------------+\n        #               |\n        #               v\n        #  +-------------------------+\n        #  |     dense2 (Dense)      |\n        #  +-------------------------+\n        #\n        edge_dict = get_edge_dict(model_to_dot(model))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"outer_input (InputLayer)\": \"mid (Functional)\",\n                \"mid (Functional)\": \"dense2 (Dense)\",\n            },\n        )\n        self.multi_plot_model(model, \"functional_in_functional_in_functional\")\n\n        #\n        #      +-------------------------+\n        #      |outer_input (InputLayer) |\n        #      +-------------------------+\n        #                   |\n        #  
+----------------|----------------+\n        #  | mid            |                |\n        #  |   +-------------------------+   |\n        #  |   | mid_input (InputLayer)  |   |\n        #  |   +-------------------------+   |\n        #  | +--------------|--------------+ |\n        #  | | inner        v              | |\n        #  | | +-------------------------+ | |\n        #  | | |inner_input (InputLayer) | | |\n        #  | | +-------------------------+ | |\n        #  | |              |              | |\n        #  | |              v              | |\n        #  | | +-------------------------+ | |\n        #  | | |     dense1 (Dense)      | | |\n        #  | | +-------------------------+ | |\n        #  | +--------------|--------------+ |\n        #  +----------------|----------------+\n        #                   v\n        #      +-------------------------+\n        #      |     dense2 (Dense)      |\n        #      +-------------------------+\n        #\n        edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"outer_input (InputLayer)\": \"mid > mid_input (InputLayer)\",\n                \"mid > mid_input (InputLayer)\": \"mid > inner > inner_input (InputLayer)\",  # noqa: E501\n                \"mid > inner > inner_input (InputLayer)\": \"mid > inner > dense1 (Dense)\",  # noqa: E501\n                \"mid > inner > dense1 (Dense)\": \"dense2 (Dense)\",\n            },\n        )\n        self.multi_plot_model(\n            model, \"functional_in_functional_in_functional\", expand_nested=True\n        )\n\n    def test_plot_complex(self):\n        # Note: this test exercises the case when `output_index` is not 0 and\n        # changes while descending into nested models to resolve the\n        # destination of an edge.\n        inner_inpt1 = keras.layers.Input(shape=(10,), name=\"inner_inpt1\")\n        inner_inpt2 = keras.layers.Input(shape=(10,), 
name=\"inner_inpt2\")\n        inner_model = keras.models.Model(\n            [inner_inpt1, inner_inpt2],\n            [\n                keras.layers.Dense(10, name=\"dense1\")(inner_inpt1),\n                keras.layers.Dense(10, name=\"dense2\")(inner_inpt2),\n            ],\n            name=\"inner\",\n        )\n\n        input0 = keras.layers.Input(shape=(10,), name=\"input0\")\n        input1 = keras.layers.Input(shape=(10,), name=\"input1\")\n        input2 = keras.layers.Input(shape=(10,), name=\"input2\")\n        input3 = keras.layers.Input(shape=(10,), name=\"input3\")\n\n        mid_sequential = keras.models.Sequential(\n            [\n                keras.layers.Dense(10, name=\"dense0\"),\n                SubclassModel(name=\"subclass0\"),\n            ],\n            name=\"seq\",\n        )\n        mid_subclass = SubclassModel(name=\"subclass3\")\n        mid_model = keras.models.Model(\n            [input0, input1, input2, input3],\n            [\n                mid_sequential(input0),\n                *inner_model([input1, input2]),\n                mid_subclass(input3),\n            ],\n            name=\"mid\",\n        )\n\n        outer_input = keras.layers.Input((10,), name=\"outer_input\")\n        mid_outputs = mid_model(\n            [outer_input, outer_input, outer_input, outer_input]\n        )\n        model = keras.models.Model(\n            outer_input,\n            [\n                keras.layers.Add(name=\"add1\")([mid_outputs[0], mid_outputs[1]]),\n                keras.layers.Add(name=\"add2\")([mid_outputs[2], mid_outputs[3]]),\n            ],\n        )\n\n        #\n        #                 +-------------------------+\n        #                 |outer_input (InputLayer) |\n        #                 +-------------------------+\n        #                              |\n        #                              v\n        #                 +-------------------------+\n        #                 |    mid (Functional)     |\n     
   #                 +-------------------------+\n        #                          |      |\n        #                          v      v\n        #  +-------------------------+  +-------------------------+\n        #  |       add1 (Add)        |  |       add2 (Add)        |\n        #  +-------------------------+  +-------------------------+\n        #\n        edge_dict = get_edge_dict(model_to_dot(model))\n        self.assertEqual(\n            edge_dict,\n            {\n                \"outer_input (InputLayer)\": \"mid (Functional)\",\n                \"mid (Functional)\": {\"add1 (Add)\", \"add2 (Add)\"},\n            },\n        )\n        self.multi_plot_model(model, \"complex\")\n\n        #\n        #                               +-----------+\n        #            +------------------|outer_input|-----------------+\n        #            |                  +-----------+                 |\n        #            |                   |         |                  |\n        #  +---------|-------------------|---------|------------------|-------+\n        #  | mid     v                   v         v                  v       |\n        #  |   +-----------+     +-----------+ +-----------+    +-----------+ |\n        #  |   |  input0   |     |  input1   | |  input2   |    |  input3   | |\n        #  |   +-----------+     +-----------+ +-----------+    +-----------+ |\n        #  | +-------|-------+ +-------|-------------|-------+        |       |\n        #  | | seq   v       | | inner v             v       |        |       |\n        #  | | +-----------+ | | +-----------+ +-----------+ |  +-----------+ |\n        #  | | |  dense0   | | | |inner_inpt1| |inner_inpt2| |  | subclass3 | |\n        #  | | +-----------+ | | +-----------+ +-----------+ |  +-----------+ |\n        #  | |       |       | |       |             |       |    |           |\n        #  | |       v       | |       v             v       |    |           |\n        #  | | +-----------+ | | 
+-----------+ +-----------+ |    |           |\n        #  | | | subclass0 | | | |  dense1   | |  dense2   | |    |           |\n        #  | | +-----------+ | | +-----------+ +-----------+ |    |           |\n        #  | +-----------|---+ +---|---------------------|---+    |           |\n        #  +-------------|---------|---------------------|--------|-----------+\n        #                v         v                     v        v\n        #               +-----------+                   +-----------+\n        #               |    add1   |                   |   add2    |\n        #               +-----------+                   +-----------+\n        #\n        edge_dict = get_edge_dict(model_to_dot(model, expand_nested=True))\n        self.assertEqual(\n            edge_dict,\n            {\n                # 1st row\n                \"outer_input (InputLayer)\": {\n                    \"mid > input0 (InputLayer)\",\n                    \"mid > input1 (InputLayer)\",\n                    \"mid > input2 (InputLayer)\",\n                    \"mid > input3 (InputLayer)\",\n                },\n                # 2nd row\n                \"mid > input0 (InputLayer)\": \"mid > seq > dense0 (Dense)\",\n                \"mid > input1 (InputLayer)\": \"mid > inner > inner_inpt1 (InputLayer)\",  # noqa: E501\n                \"mid > input2 (InputLayer)\": \"mid > inner > inner_inpt2 (InputLayer)\",  # noqa: E501\n                \"mid > input3 (InputLayer)\": \"mid > subclass3 (SubclassModel)\",\n                # 3rd row\n                \"mid > seq > dense0 (Dense)\": \"mid > seq > subclass0 (SubclassModel)\",  # noqa: E501\n                \"mid > inner > inner_inpt1 (InputLayer)\": \"mid > inner > dense1 (Dense)\",  # noqa: E501\n                \"mid > inner > inner_inpt2 (InputLayer)\": \"mid > inner > dense2 (Dense)\",  # noqa: E501\n                # 4th row\n                \"mid > seq > subclass0 (SubclassModel)\": \"add1 (Add)\",\n                \"mid > inner 
> dense1 (Dense)\": \"add1 (Add)\",\n                \"mid > inner > dense2 (Dense)\": \"add2 (Add)\",\n                \"mid > subclass3 (SubclassModel)\": \"add2 (Add)\",\n            },\n        )\n        self.multi_plot_model(model, \"complex\", expand_nested=True)\n"
  },
  {
    "path": "integration_tests/numerical_test.py",
    "content": "import keras  # isort: skip, keep it on top for torch test\n\nimport sys\n\nimport numpy as np\nimport tf_keras\n\nkeras.backend.set_image_data_format(\"channels_last\")\ntf_keras.backend.set_image_data_format(\"channels_last\")\n\nNUM_CLASSES = 10\nBATCH_SIZE = 32\nEPOCHS = 1\n\n\ndef build_mnist_data(num_classes):\n    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n\n    # Scale images to the [0, 1] range\n    x_train = x_train.astype(\"float32\") / 255\n    x_test = x_test.astype(\"float32\") / 255\n    # Make sure images have shape (28, 28, 1)\n    x_train = np.expand_dims(x_train, -1)\n    x_test = np.expand_dims(x_test, -1)\n\n    # convert class vectors to binary class matrices\n    y_train = keras.utils.to_categorical(y_train, num_classes)\n    y_test = keras.utils.to_categorical(y_test, num_classes)\n\n    return x_train[:100], y_train[:100]\n\n\ndef build_keras_model(keras_module, num_classes):\n    input_shape = (28, 28, 1)\n\n    model = keras_module.Sequential(\n        [\n            keras_module.Input(shape=input_shape),\n            keras_module.layers.Conv2D(\n                32, kernel_size=(3, 3), activation=\"relu\"\n            ),\n            keras_module.layers.BatchNormalization(),\n            keras_module.layers.MaxPooling2D(pool_size=(2, 2)),\n            keras_module.layers.Conv2D(\n                64, kernel_size=(3, 3), activation=\"relu\"\n            ),\n            keras_module.layers.BatchNormalization(scale=False, center=True),\n            keras_module.layers.MaxPooling2D(pool_size=(2, 2)),\n            keras_module.layers.Flatten(),\n            keras_module.layers.Dense(num_classes, activation=\"softmax\"),\n        ]\n    )\n    return model\n\n\ndef compile_model(model):\n    model.compile(\n        loss=\"categorical_crossentropy\",\n        optimizer=\"adam\",\n        metrics=[\"mae\", \"accuracy\"],\n        jit_compile=False,\n        run_eagerly=True,\n    )\n\n\ndef 
train_model(model, x, y):\n    return model.fit(\n        x,\n        y,\n        batch_size=BATCH_SIZE,\n        epochs=EPOCHS,\n        shuffle=False,\n        verbose=0,\n    )\n\n\ndef eval_model(model, x, y):\n    score = model.evaluate(x, y, verbose=0, batch_size=BATCH_SIZE)\n    print(score)\n    return score\n\n\ndef check_history(h1, h2):\n    for key in h1.history.keys():\n        print(f\"{key}:\")\n        print(h1.history[key])\n        print(h2.history[key])\n        np.testing.assert_allclose(\n            h1.history[key],\n            h2.history[key],\n            atol=1e-3,\n        )\n\n\ndef predict_model(model, x):\n    return model.predict(x, batch_size=BATCH_SIZE, verbose=0)\n\n\ndef numerical_test():\n    x_train, y_train = build_mnist_data(NUM_CLASSES)\n    keras_model = build_keras_model(keras, NUM_CLASSES)\n    tf_keras_model = build_keras_model(tf_keras, NUM_CLASSES)\n\n    # Make sure both model have same weights before training\n    weights = [weight.numpy() for weight in keras_model.weights]\n    tf_keras_model.set_weights(weights)\n\n    for kw, kcw in zip(keras_model.weights, tf_keras_model.weights):\n        np.testing.assert_allclose(kw.numpy(), kcw.numpy())\n\n    compile_model(keras_model)\n    compile_model(tf_keras_model)\n\n    print(\"Checking training histories:\")\n    keras_history = train_model(keras_model, x_train, y_train)\n    tf_keras_history = train_model(tf_keras_model, x_train, y_train)\n    check_history(keras_history, tf_keras_history)\n    print(\"Training histories match.\")\n    print()\n\n    print(\"Checking trained weights:\")\n    for kw, kcw in zip(keras_model.weights, tf_keras_model.weights):\n        np.testing.assert_allclose(kw.numpy(), kcw.numpy(), atol=1e-3)\n    print(\"Trained weights match.\")\n    print()\n\n    print(\"Checking predict:\")\n    outputs1 = predict_model(keras_model, x_train)\n    outputs2 = predict_model(tf_keras_model, x_train)\n    np.testing.assert_allclose(outputs1, 
outputs2, atol=1e-3)\n    print(\"Predict results match.\")\n    print()\n\n    print(\"Checking evaluate:\")\n    score1 = eval_model(keras_model, x_train, y_train)\n    score2 = eval_model(tf_keras_model, x_train, y_train)\n    np.testing.assert_allclose(score1, score2, atol=1e-3)\n    print(\"Evaluate results match.\")\n\n\nif __name__ == \"__main__\":\n    if keras.backend.backend() == \"openvino\":\n        # this test requires trainable backend\n        sys.exit(0)\n    keras.utils.set_random_seed(1337)\n    tf_keras.utils.set_random_seed(1337)\n    numerical_test()\n"
  },
  {
    "path": "integration_tests/pytorch_export_test.py",
    "content": "\"\"\"\nIntegration tests for PyTorch model export with dynamic shapes.\n\nTests the complete fix for GitHub issue #22102 where models with\nAveragePooling2D → Conv2D → Reshape failed to export with dynamic shapes.\n\nThe fixes enable:\n1. torch.export with dynamic shapes\n2. ONNX export with dynamic shapes\n3. TorchScript tracing with dynamic shapes\n\"\"\"\n\nimport os\n\nos.environ[\"KERAS_BACKEND\"] = \"torch\"\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"torch\",\n    reason=\"Export tests require PyTorch backend\",\n)\nclass TestPyTorchExportWithDynamicShapes(testing.TestCase):\n    \"\"\"Test PyTorch export methods with dynamic shapes (GitHub issue #22102).\"\"\"\n\n    @parameterized.named_parameters(\n        (\"shape_3x3\", (1, 3, 3, 1016), (1, 1, 512)),\n        (\"shape_5x5\", (1, 5, 5, 1016), (1, 4, 512)),\n        (\"shape_7x7_batch2\", (2, 7, 7, 1016), (2, 9, 512)),\n    )\n    def test_issue_22102_model_inference(self, input_shape, expected_shape):\n        \"\"\"Test the exact model from issue #22102 with varying shapes.\"\"\"\n        import torch\n\n        # Create the exact model from issue #22102\n        inputs = layers.Input(shape=(None, None, 1016))\n        x = layers.AveragePooling2D(pool_size=(3, 2), strides=2)(inputs)\n        x = layers.Conv2D(512, kernel_size=1, activation=\"relu\")(x)\n        x = layers.Reshape((-1, 512))(x)\n        model = models.Model(inputs=inputs, outputs=x)\n\n        # Test inference with varying shapes\n        x_test = torch.randn(*input_shape)\n        output = model(x_test)\n\n        self.assertEqual(tuple(output.shape), expected_shape)\n\n    @parameterized.named_parameters(\n        (\"torch_export\", \"torch_export\"),\n        (\"onnx_export\", \"onnx_export\"),\n        
(\"torchscript_trace\", \"torchscript_trace\"),\n    )\n    def test_issue_22102_export_methods(self, export_method):\n        \"\"\"Test issue #22102 model with different export methods.\n\n        Validates that all export methods work with dynamic shapes\n        after the fix.\n        \"\"\"\n        import tempfile\n\n        import torch\n\n        # Create the exact model from issue #22102\n        inputs = layers.Input(shape=(None, None, 1016))\n        x = layers.AveragePooling2D(pool_size=(3, 2), strides=2)(inputs)\n        x = layers.Conv2D(512, kernel_size=1, activation=\"relu\")(x)\n        x = layers.Reshape((-1, 512))(x)\n        model = models.Model(inputs=inputs, outputs=x)\n\n        sample_input = torch.randn(1, 3, 3, 1016)\n\n        if export_method == \"torch_export\":\n            # Test torch.export with dynamic shapes\n            # Note: torch.export has stricter constraints than ONNX export\n            # Skip if constraints cannot be satisfied\n            try:\n                batch_dim = torch.export.Dim(\"batch\", min=1, max=1024)\n                h_dim = torch.export.Dim(\"height\", min=1, max=1024)\n                w_dim = torch.export.Dim(\"width\", min=1, max=1024)\n\n                exported = torch.export.export(\n                    model,\n                    (sample_input,),\n                    dynamic_shapes=(({0: batch_dim, 1: h_dim, 2: w_dim},),),\n                    strict=False,\n                )\n\n                # Test with different shapes\n                for shape in [(1, 3, 3, 1016), (1, 5, 5, 1016)]:\n                    x_test = torch.randn(*shape)\n                    output = exported.module()(x_test)\n                    self.assertIsNotNone(output)\n\n            except Exception as e:\n                # torch.export has known limitations with certain\n                # layer combinations. 
The important thing is that\n                # ONNX export works (tested separately)\n                if \"Constraints violated\" in str(e):\n                    pytest.skip(\n                        f\"torch.export constraints not satisfiable: {e}\"\n                    )\n                pytest.skip(f\"torch.export not available: {e}\")\n\n        elif export_method == \"onnx_export\":\n            # Test ONNX export with dynamic shapes\n            try:\n                import onnxruntime as ort\n\n                with tempfile.NamedTemporaryFile(\n                    suffix=\".onnx\", delete=False\n                ) as f:\n                    onnx_path = f.name\n\n                torch.onnx.export(\n                    model,\n                    (sample_input,),\n                    onnx_path,\n                    input_names=[\"input\"],\n                    output_names=[\"output\"],\n                    dynamic_shapes=(\n                        (\n                            (\n                                torch.export.Dim.DYNAMIC,\n                                torch.export.Dim.DYNAMIC,\n                                torch.export.Dim.DYNAMIC,\n                                torch.export.Dim.STATIC,\n                            ),\n                        ),\n                    ),\n                )\n\n                # Test with ONNX Runtime\n                ort_session = ort.InferenceSession(onnx_path)\n                input_name = ort_session.get_inputs()[0].name\n\n                for shape in [\n                    (1, 3, 3, 1016),\n                    (1, 5, 5, 1016),\n                    (2, 7, 7, 1016),\n                ]:\n                    x_test = np.random.randn(*shape).astype(np.float32)\n                    keras_output = (\n                        model(torch.from_numpy(x_test)).detach().numpy()\n                    )\n                    onnx_output = ort_session.run(None, {input_name: x_test})[0]\n\n                    
self.assertEqual(keras_output.shape, onnx_output.shape)\n                    max_diff = np.abs(keras_output - onnx_output).max()\n                    self.assertLess(max_diff, 1e-4)\n\n                os.unlink(onnx_path)\n\n            except ImportError:\n                pytest.skip(\"onnxruntime not available\")\n            except Exception as e:\n                if \"Constraints violated\" in str(e):\n                    self.fail(f\"ONNX export failed: {e}\")\n                pytest.skip(f\"ONNX export not available: {e}\")\n\n        elif export_method == \"torchscript_trace\":\n            # Test TorchScript tracing\n            try:\n                traced = torch.jit.trace(model, sample_input)\n\n                # Test with different shapes\n                for shape in [(1, 3, 3, 1016), (1, 5, 5, 1016)]:\n                    x_test = torch.randn(*shape)\n                    output = traced(x_test)\n                    self.assertIsNotNone(output)\n\n            except Exception as e:\n                pytest.skip(f\"TorchScript trace not available: {e}\")\n\n    @parameterized.named_parameters(\n        (\"global_avg_pool\", \"global_avg_pool\"),\n        (\"reshape_flatten\", \"reshape_flatten\"),\n        (\"combined\", \"combined\"),\n    )\n    def test_fixed_layers_export(self, layer_type):\n        \"\"\"Test that fixed layers work with PyTorch export methods.\n\n        Tests the three main fixes:\n        1. GlobalAveragePooling2D (mean() dtype fix)\n        2. Reshape with -1 (dynamic reshape fix)\n        3. 
Combined scenario (variables.py SymInt fix)\n        \"\"\"\n        import tempfile\n\n        import torch\n\n        if layer_type == \"global_avg_pool\":\n            # Test GlobalAveragePooling2D (mean() fix)\n            inputs = layers.Input(shape=(None, None, 64))\n            x = layers.Conv2D(64, 3, padding=\"same\")(inputs)\n            x = layers.GlobalAveragePooling2D()(x)\n            x = layers.Dense(10)(x)\n            model = models.Model(inputs=inputs, outputs=x)\n            sample_input = torch.randn(1, 8, 8, 64)\n            test_shapes = [(1, 8, 8, 64), (2, 16, 16, 64)]\n\n        elif layer_type == \"reshape_flatten\":\n            # Test Reshape with -1 (reshape fix)\n            inputs = layers.Input(shape=(None, None, 64))\n            x = layers.Conv2D(32, 3, padding=\"same\")(inputs)\n            x = layers.Reshape((-1, 32))(x)\n            model = models.Model(inputs=inputs, outputs=x)\n            sample_input = torch.randn(1, 8, 8, 64)\n            test_shapes = [(1, 8, 8, 64), (1, 16, 16, 64)]\n\n        else:  # combined\n            # Test combined scenario (all fixes)\n            inputs = layers.Input(shape=(None, None, 64))\n            x = layers.AveragePooling2D(pool_size=2)(inputs)\n            x = layers.Conv2D(128, 3, padding=\"same\")(x)\n            x = layers.GlobalAveragePooling2D()(x)\n            x = layers.Dense(256)(x)\n            x = layers.Dropout(0.5)(x)\n            x = layers.Dense(10)(x)\n            model = models.Model(inputs=inputs, outputs=x)\n            sample_input = torch.randn(1, 8, 8, 64)\n            test_shapes = [(1, 8, 8, 64), (2, 16, 16, 64)]\n\n        # Test torch.export\n        # Note: torch.export has stricter constraints than ONNX export\n        # Skip if constraints cannot be satisfied\n        try:\n            batch_dim = torch.export.Dim(\"batch\", min=1, max=1024)\n            h_dim = torch.export.Dim(\"height\", min=1, max=1024)\n            w_dim = torch.export.Dim(\"width\", 
min=1, max=1024)\n\n            exported = torch.export.export(\n                model,\n                (sample_input,),\n                dynamic_shapes=(({0: batch_dim, 1: h_dim, 2: w_dim},),),\n                strict=False,\n            )\n            self.assertIsNotNone(exported)\n        except Exception as e:\n            # torch.export has known limitations with certain layers\n            # The important thing is that ONNX export works\n            if \"Constraints violated\" in str(e):\n                pytest.skip(f\"torch.export constraints not satisfiable: {e}\")\n            pytest.skip(f\"torch.export not available: {e}\")\n\n        # Test ONNX export\n        try:\n            import onnxruntime as ort\n\n            with tempfile.NamedTemporaryFile(suffix=\".onnx\", delete=False) as f:\n                onnx_path = f.name\n\n            torch.onnx.export(\n                model,\n                (sample_input,),\n                onnx_path,\n                input_names=[\"input\"],\n                output_names=[\"output\"],\n                dynamic_shapes=(\n                    (\n                        (\n                            torch.export.Dim.DYNAMIC,\n                            torch.export.Dim.DYNAMIC,\n                            torch.export.Dim.DYNAMIC,\n                            torch.export.Dim.STATIC,\n                        ),\n                    ),\n                ),\n            )\n\n            # Verify ONNX model works with varying shapes\n            ort_session = ort.InferenceSession(onnx_path)\n            input_name = ort_session.get_inputs()[0].name\n\n            for shape in test_shapes:\n                x_test = np.random.randn(*shape).astype(np.float32)\n                onnx_output = ort_session.run(None, {input_name: x_test})[0]\n                self.assertIsNotNone(onnx_output)\n\n            os.unlink(onnx_path)\n\n        except ImportError:\n            pytest.skip(\"onnxruntime not available\")\n        
except TypeError as e:\n            if \"dtype\" in str(e):\n                self.fail(\n                    f\"ONNX export failed with dtype error for {layer_type}: {e}\"\n                )\n            pytest.skip(f\"ONNX export not available: {e}\")\n        except Exception as e:\n            if \"Constraints violated\" in str(e):\n                self.fail(f\"ONNX export failed for {layer_type}: {e}\")\n            pytest.skip(f\"ONNX export not available: {e}\")\n"
  },
  {
    "path": "integration_tests/tf_custom_fit_test.py",
    "content": "import numpy as np\nimport tensorflow as tf\n\nimport keras\n\n\ndef test_custom_fit():\n    class CustomModel(keras.Model):\n        def __init__(self, *args, **kwargs):\n            super().__init__(*args, **kwargs)\n            self.loss_tracker = keras.metrics.Mean(name=\"loss\")\n            self.mae_metric = keras.metrics.MeanAbsoluteError(name=\"mae\")\n            self.loss_fn = keras.losses.MeanSquaredError()\n\n        def train_step(self, data):\n            x, y = data\n            with tf.GradientTape() as tape:\n                y_pred = self(x, training=True)\n                loss = self.loss_fn(y, y_pred)\n            trainable_vars = self.trainable_variables\n            gradients = tape.gradient(loss, trainable_vars)\n            self.optimizer.apply(gradients, trainable_vars)\n            self.loss_tracker.update_state(loss)\n            self.mae_metric.update_state(y, y_pred)\n            return {\n                \"loss\": self.loss_tracker.result(),\n                \"mae\": self.mae_metric.result(),\n            }\n\n        @property\n        def metrics(self):\n            return [self.loss_tracker, self.mae_metric]\n\n    inputs = keras.Input(shape=(32,))\n    outputs = keras.layers.Dense(1)(inputs)\n    model = CustomModel(inputs, outputs)\n    model.compile(optimizer=\"adam\")\n    x = np.random.random((64, 32))\n    y = np.random.random((64, 1))\n    history = model.fit(x, y, epochs=1)\n\n    assert \"loss\" in history.history\n    assert \"mae\" in history.history\n\n    print(\"History:\")\n    print(history.history)\n\n\nif __name__ == \"__main__\":\n    test_custom_fit()\n"
  },
  {
    "path": "integration_tests/tf_distribute_training_test.py",
    "content": "import numpy as np\nimport tensorflow as tf\n\nimport keras\nfrom keras.src import layers\nfrom keras.src import losses\nfrom keras.src import metrics\nfrom keras.src import models\nfrom keras.src import optimizers\nfrom keras.src.callbacks import LearningRateScheduler\n\n\ndef test_model_fit():\n    cpus = tf.config.list_physical_devices(\"CPU\")\n    tf.config.set_logical_device_configuration(\n        cpus[0],\n        [\n            tf.config.LogicalDeviceConfiguration(),\n            tf.config.LogicalDeviceConfiguration(),\n        ],\n    )\n\n    keras.utils.set_random_seed(1337)\n\n    strategy = tf.distribute.MirroredStrategy([\"CPU:0\", \"CPU:1\"])\n    with strategy.scope():\n        inputs = layers.Input((100,), batch_size=32)\n        x = layers.Dense(256, activation=\"relu\")(inputs)\n        x = layers.Dense(256, activation=\"relu\")(x)\n        x = layers.Dense(256, activation=\"relu\")(x)\n        x = layers.BatchNormalization()(x)\n        outputs = layers.Dense(16)(x)\n        model = models.Model(inputs, outputs)\n\n    callbacks = [LearningRateScheduler(lambda _: 0.1)]\n\n    model.summary()\n\n    x = np.random.random((5000, 100))\n    y = np.random.random((5000, 16))\n    batch_size = 32\n    epochs = 2\n\n    # Fit from numpy arrays:\n    with strategy.scope():\n        model.compile(\n            optimizer=optimizers.LossScaleOptimizer(\n                optimizers.SGD(learning_rate=0.001, momentum=0.01)\n            ),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n        )\n        history = model.fit(\n            x,\n            y,\n            batch_size=batch_size,\n            epochs=epochs,\n            validation_split=0.2,\n            callbacks=callbacks,\n        )\n\n    print(\"History:\")\n    print(history.history)\n\n    # Fit again from distributed dataset:\n    with strategy.scope():\n        dataset = tf.data.Dataset.from_tensor_slices((x, 
y)).batch(batch_size)\n        dataset = strategy.experimental_distribute_dataset(dataset)\n        history = model.fit(dataset, epochs=epochs, callbacks=callbacks)\n\n    print(\"History:\")\n    print(history.history)\n\n\nif __name__ == \"__main__\":\n    test_model_fit()\n"
  },
  {
    "path": "integration_tests/torch_custom_fit_test.py",
    "content": "import numpy as np\nimport torch\n\nimport keras\n\n\ndef test_custom_fit():\n    class CustomModel(keras.Model):\n        def __init__(self, *args, **kwargs):\n            super().__init__(*args, **kwargs)\n            self.loss_tracker = keras.metrics.Mean(name=\"loss\")\n            self.mae_metric = keras.metrics.MeanAbsoluteError(name=\"mae\")\n            self.loss_fn = keras.losses.MeanSquaredError()\n\n        def train_step(self, data):\n            x, y = data\n            self.zero_grad()\n            y_pred = self(x, training=True)\n            loss = self.loss_fn(y, y_pred)\n            loss.backward()\n            trainable_weights = [v for v in self.trainable_weights]\n            gradients = [v.value.grad for v in trainable_weights]\n            with torch.no_grad():\n                self.optimizer.apply(gradients, trainable_weights)\n            self.loss_tracker.update_state(loss)\n            self.mae_metric.update_state(y, y_pred)\n            return {\n                \"loss\": self.loss_tracker.result(),\n                \"mae\": self.mae_metric.result(),\n            }\n\n        @property\n        def metrics(self):\n            return [self.loss_tracker, self.mae_metric]\n\n    inputs = keras.Input(shape=(32,))\n    outputs = keras.layers.Dense(1)(inputs)\n    model = CustomModel(inputs, outputs)\n    model.compile(optimizer=\"adam\")\n    x = np.random.random((64, 32))\n    y = np.random.random((64, 1))\n    history = model.fit(x, y, epochs=1)\n\n    assert \"loss\" in history.history\n    assert \"mae\" in history.history\n\n    print(\"History:\")\n    print(history.history)\n\n\nif __name__ == \"__main__\":\n    test_custom_fit()\n"
  },
  {
    "path": "integration_tests/torch_workflow_test.py",
    "content": "import torch\n\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.backend.common import KerasVariable\n\n\nclass Net(torch.nn.Module):\n    def __init__(self):\n        super().__init__()\n        self.fc1 = layers.Dense(1)\n\n    def forward(self, x):\n        x = self.fc1(x)\n        return x\n\n\nclass TorchWorkflowTest(testing.TestCase):\n    def test_keras_layer_in_nn_module(self):\n        net = Net()\n\n        # Test using Keras layer in a nn.Module.\n        # Test forward pass\n        self.assertAllEqual(list(net(torch.empty(100, 10)).shape), [100, 1])\n        # Test KerasVariables are added as nn.Parameter.\n        self.assertLen(list(net.parameters()), 2)\n\n        # Test using KerasVariable as a torch tensor for torch ops.\n        kernel = net.fc1.kernel\n        transposed_kernel = torch.transpose(kernel, 0, 1)\n        self.assertIsInstance(kernel, KerasVariable)\n        self.assertIsInstance(\n            torch.mul(kernel, transposed_kernel), torch.Tensor\n        )\n"
  },
  {
    "path": "keras/__init__.py",
    "content": "# This file should NEVER be packaged! This is a hack to make \"import keras\" from\n# the base of the repo just import the source files. We'll keep it for compat.\n\nimport os  # isort: skip\n\n# Add everything in /api/ to the module search path.\n__path__.append(os.path.join(os.path.dirname(__file__), \"api\"))  # noqa: F405\n\nfrom keras.api import *  # noqa: F403, E402\nfrom keras.api import __version__  # noqa: E402\n\n# Don't pollute namespace.\ndel os\n"
  },
  {
    "path": "keras/api/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras import _tf_keras as _tf_keras\nfrom keras import activations as activations\nfrom keras import applications as applications\nfrom keras import backend as backend\nfrom keras import callbacks as callbacks\nfrom keras import config as config\nfrom keras import constraints as constraints\nfrom keras import datasets as datasets\nfrom keras import distillation as distillation\nfrom keras import distribution as distribution\nfrom keras import dtype_policies as dtype_policies\nfrom keras import export as export\nfrom keras import initializers as initializers\nfrom keras import layers as layers\nfrom keras import legacy as legacy\nfrom keras import losses as losses\nfrom keras import metrics as metrics\nfrom keras import mixed_precision as mixed_precision\nfrom keras import models as models\nfrom keras import ops as ops\nfrom keras import optimizers as optimizers\nfrom keras import preprocessing as preprocessing\nfrom keras import quantizers as quantizers\nfrom keras import random as random\nfrom keras import regularizers as regularizers\nfrom keras import saving as saving\nfrom keras import tree as tree\nfrom keras import utils as utils\nfrom keras import visualization as visualization\nfrom keras import wrappers as wrappers\nfrom keras.src.backend import Variable as Variable\nfrom keras.src.backend import device as device\nfrom keras.src.backend import name_scope as name_scope\nfrom keras.src.backend.common.keras_tensor import KerasTensor as KerasTensor\nfrom keras.src.backend.common.remat import RematScope as RematScope\nfrom keras.src.backend.common.remat import remat as remat\nfrom keras.src.backend.common.stateless_scope import (\n    StatelessScope as StatelessScope,\n)\nfrom keras.src.backend.common.symbolic_scope import (\n    SymbolicScope as SymbolicScope,\n)\nfrom keras.src.dtype_policies.dtype_policy 
import DTypePolicy as DTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import (\n    FloatDTypePolicy as FloatDTypePolicy,\n)\nfrom keras.src.initializers.initializer import Initializer as Initializer\nfrom keras.src.layers.core.input_layer import Input as Input\nfrom keras.src.layers.input_spec import InputSpec as InputSpec\nfrom keras.src.layers.layer import Layer as Layer\nfrom keras.src.losses.loss import Loss as Loss\nfrom keras.src.metrics.metric import Metric as Metric\nfrom keras.src.models.model import Model as Model\nfrom keras.src.models.sequential import Sequential as Sequential\nfrom keras.src.ops.function import Function as Function\nfrom keras.src.ops.operation import Operation as Operation\nfrom keras.src.optimizers.optimizer import Optimizer as Optimizer\nfrom keras.src.quantizers.quantizers import Quantizer as Quantizer\nfrom keras.src.regularizers.regularizers import Regularizer as Regularizer\nfrom keras.src.version import __version__ as __version__\nfrom keras.src.version import version as version\n"
  },
  {
    "path": "keras/api/_tf_keras/__init__.py",
    "content": "from keras._tf_keras import keras\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras import activations as activations\nfrom keras import applications as applications\nfrom keras import callbacks as callbacks\nfrom keras import config as config\nfrom keras import constraints as constraints\nfrom keras import datasets as datasets\nfrom keras import distillation as distillation\nfrom keras import distribution as distribution\nfrom keras import dtype_policies as dtype_policies\nfrom keras import export as export\nfrom keras import initializers as initializers\nfrom keras import legacy as legacy\nfrom keras import mixed_precision as mixed_precision\nfrom keras import models as models\nfrom keras import ops as ops\nfrom keras import optimizers as optimizers\nfrom keras import quantizers as quantizers\nfrom keras import random as random\nfrom keras import regularizers as regularizers\nfrom keras import tree as tree\nfrom keras import utils as utils\nfrom keras import visualization as visualization\nfrom keras import wrappers as wrappers\nfrom keras._tf_keras.keras import backend as backend\nfrom keras._tf_keras.keras import layers as layers\nfrom keras._tf_keras.keras import losses as losses\nfrom keras._tf_keras.keras import metrics as metrics\nfrom keras._tf_keras.keras import preprocessing as preprocessing\nfrom keras.src.backend import Variable as Variable\nfrom keras.src.backend import device as device\nfrom keras.src.backend import name_scope as name_scope\nfrom keras.src.backend.common.keras_tensor import KerasTensor as KerasTensor\nfrom keras.src.backend.common.remat import RematScope as RematScope\nfrom keras.src.backend.common.remat import remat as remat\nfrom keras.src.backend.common.stateless_scope import (\n    StatelessScope as StatelessScope,\n)\nfrom keras.src.backend.common.symbolic_scope import (\n    SymbolicScope as SymbolicScope,\n)\nfrom keras.src.dtype_policies.dtype_policy 
import DTypePolicy as DTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import (\n    FloatDTypePolicy as FloatDTypePolicy,\n)\nfrom keras.src.initializers.initializer import Initializer as Initializer\nfrom keras.src.layers.core.input_layer import Input as Input\nfrom keras.src.layers.input_spec import InputSpec as InputSpec\nfrom keras.src.layers.layer import Layer as Layer\nfrom keras.src.losses.loss import Loss as Loss\nfrom keras.src.metrics.metric import Metric as Metric\nfrom keras.src.models.model import Model as Model\nfrom keras.src.models.sequential import Sequential as Sequential\nfrom keras.src.ops.function import Function as Function\nfrom keras.src.ops.operation import Operation as Operation\nfrom keras.src.optimizers.optimizer import Optimizer as Optimizer\nfrom keras.src.quantizers.quantizers import Quantizer as Quantizer\nfrom keras.src.regularizers.regularizers import Regularizer as Regularizer\nfrom keras.src.version import __version__ as __version__\nfrom keras.src.version import version as version\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/activations/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.activations import deserialize as deserialize\nfrom keras.src.activations import get as get\nfrom keras.src.activations import serialize as serialize\nfrom keras.src.activations.activations import celu as celu\nfrom keras.src.activations.activations import elu as elu\nfrom keras.src.activations.activations import exponential as exponential\nfrom keras.src.activations.activations import gelu as gelu\nfrom keras.src.activations.activations import glu as glu\nfrom keras.src.activations.activations import hard_shrink as hard_shrink\nfrom keras.src.activations.activations import hard_sigmoid as hard_sigmoid\nfrom keras.src.activations.activations import hard_silu as hard_silu\nfrom keras.src.activations.activations import hard_silu as hard_swish\nfrom keras.src.activations.activations import hard_tanh as hard_tanh\nfrom keras.src.activations.activations import leaky_relu as leaky_relu\nfrom keras.src.activations.activations import linear as linear\nfrom keras.src.activations.activations import log_sigmoid as log_sigmoid\nfrom keras.src.activations.activations import log_softmax as log_softmax\nfrom keras.src.activations.activations import mish as mish\nfrom keras.src.activations.activations import relu as relu\nfrom keras.src.activations.activations import relu6 as relu6\nfrom keras.src.activations.activations import selu as selu\nfrom keras.src.activations.activations import sigmoid as sigmoid\nfrom keras.src.activations.activations import silu as silu\nfrom keras.src.activations.activations import silu as swish\nfrom keras.src.activations.activations import soft_shrink as soft_shrink\nfrom keras.src.activations.activations import softmax as softmax\nfrom keras.src.activations.activations import softplus as softplus\nfrom keras.src.activations.activations import softsign as softsign\nfrom 
keras.src.activations.activations import sparse_plus as sparse_plus\nfrom keras.src.activations.activations import sparse_sigmoid as sparse_sigmoid\nfrom keras.src.activations.activations import sparsemax as sparsemax\nfrom keras.src.activations.activations import squareplus as squareplus\nfrom keras.src.activations.activations import tanh as tanh\nfrom keras.src.activations.activations import tanh_shrink as tanh_shrink\nfrom keras.src.activations.activations import threshold as threshold\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.applications import convnext as convnext\nfrom keras.applications import densenet as densenet\nfrom keras.applications import efficientnet as efficientnet\nfrom keras.applications import efficientnet_v2 as efficientnet_v2\nfrom keras.applications import imagenet_utils as imagenet_utils\nfrom keras.applications import inception_resnet_v2 as inception_resnet_v2\nfrom keras.applications import inception_v3 as inception_v3\nfrom keras.applications import mobilenet as mobilenet\nfrom keras.applications import mobilenet_v2 as mobilenet_v2\nfrom keras.applications import mobilenet_v3 as mobilenet_v3\nfrom keras.applications import nasnet as nasnet\nfrom keras.applications import resnet as resnet\nfrom keras.applications import resnet50 as resnet50\nfrom keras.applications import resnet_v2 as resnet_v2\nfrom keras.applications import vgg16 as vgg16\nfrom keras.applications import vgg19 as vgg19\nfrom keras.applications import xception as xception\nfrom keras.src.applications.convnext import ConvNeXtBase as ConvNeXtBase\nfrom keras.src.applications.convnext import ConvNeXtLarge as ConvNeXtLarge\nfrom keras.src.applications.convnext import ConvNeXtSmall as ConvNeXtSmall\nfrom keras.src.applications.convnext import ConvNeXtTiny as ConvNeXtTiny\nfrom keras.src.applications.convnext import ConvNeXtXLarge as ConvNeXtXLarge\nfrom keras.src.applications.densenet import DenseNet121 as DenseNet121\nfrom keras.src.applications.densenet import DenseNet169 as DenseNet169\nfrom keras.src.applications.densenet import DenseNet201 as DenseNet201\nfrom keras.src.applications.efficientnet import EfficientNetB0 as EfficientNetB0\nfrom keras.src.applications.efficientnet import EfficientNetB1 as EfficientNetB1\nfrom keras.src.applications.efficientnet import EfficientNetB2 as EfficientNetB2\nfrom keras.src.applications.efficientnet 
import EfficientNetB3 as EfficientNetB3\nfrom keras.src.applications.efficientnet import EfficientNetB4 as EfficientNetB4\nfrom keras.src.applications.efficientnet import EfficientNetB5 as EfficientNetB5\nfrom keras.src.applications.efficientnet import EfficientNetB6 as EfficientNetB6\nfrom keras.src.applications.efficientnet import EfficientNetB7 as EfficientNetB7\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B0 as EfficientNetV2B0,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B1 as EfficientNetV2B1,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B2 as EfficientNetV2B2,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B3 as EfficientNetV2B3,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2L as EfficientNetV2L,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2M as EfficientNetV2M,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2S as EfficientNetV2S,\n)\nfrom keras.src.applications.inception_resnet_v2 import (\n    InceptionResNetV2 as InceptionResNetV2,\n)\nfrom keras.src.applications.inception_v3 import InceptionV3 as InceptionV3\nfrom keras.src.applications.mobilenet import MobileNet as MobileNet\nfrom keras.src.applications.mobilenet_v2 import MobileNetV2 as MobileNetV2\nfrom keras.src.applications.mobilenet_v3 import (\n    MobileNetV3Large as MobileNetV3Large,\n)\nfrom keras.src.applications.mobilenet_v3 import (\n    MobileNetV3Small as MobileNetV3Small,\n)\nfrom keras.src.applications.nasnet import NASNetLarge as NASNetLarge\nfrom keras.src.applications.nasnet import NASNetMobile as NASNetMobile\nfrom keras.src.applications.resnet import ResNet50 as ResNet50\nfrom keras.src.applications.resnet import ResNet101 as ResNet101\nfrom keras.src.applications.resnet import ResNet152 as ResNet152\nfrom keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2\nfrom 
keras.src.applications.resnet_v2 import ResNet101V2 as ResNet101V2\nfrom keras.src.applications.resnet_v2 import ResNet152V2 as ResNet152V2\nfrom keras.src.applications.vgg16 import VGG16 as VGG16\nfrom keras.src.applications.vgg19 import VGG19 as VGG19\nfrom keras.src.applications.xception import Xception as Xception\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/convnext/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.convnext import ConvNeXtBase as ConvNeXtBase\nfrom keras.src.applications.convnext import ConvNeXtLarge as ConvNeXtLarge\nfrom keras.src.applications.convnext import ConvNeXtSmall as ConvNeXtSmall\nfrom keras.src.applications.convnext import ConvNeXtTiny as ConvNeXtTiny\nfrom keras.src.applications.convnext import ConvNeXtXLarge as ConvNeXtXLarge\nfrom keras.src.applications.convnext import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.convnext import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/densenet/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.densenet import DenseNet121 as DenseNet121\nfrom keras.src.applications.densenet import DenseNet169 as DenseNet169\nfrom keras.src.applications.densenet import DenseNet201 as DenseNet201\nfrom keras.src.applications.densenet import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.densenet import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/efficientnet/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.efficientnet import EfficientNetB0 as EfficientNetB0\nfrom keras.src.applications.efficientnet import EfficientNetB1 as EfficientNetB1\nfrom keras.src.applications.efficientnet import EfficientNetB2 as EfficientNetB2\nfrom keras.src.applications.efficientnet import EfficientNetB3 as EfficientNetB3\nfrom keras.src.applications.efficientnet import EfficientNetB4 as EfficientNetB4\nfrom keras.src.applications.efficientnet import EfficientNetB5 as EfficientNetB5\nfrom keras.src.applications.efficientnet import EfficientNetB6 as EfficientNetB6\nfrom keras.src.applications.efficientnet import EfficientNetB7 as EfficientNetB7\nfrom keras.src.applications.efficientnet import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.efficientnet import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/efficientnet_v2/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B0 as EfficientNetV2B0,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B1 as EfficientNetV2B1,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B2 as EfficientNetV2B2,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B3 as EfficientNetV2B3,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2L as EfficientNetV2L,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2M as EfficientNetV2M,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2S as EfficientNetV2S,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/imagenet_utils/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.imagenet_utils import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.imagenet_utils import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/inception_resnet_v2/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.inception_resnet_v2 import (\n    InceptionResNetV2 as InceptionResNetV2,\n)\nfrom keras.src.applications.inception_resnet_v2 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.inception_resnet_v2 import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/inception_v3/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.inception_v3 import InceptionV3 as InceptionV3\nfrom keras.src.applications.inception_v3 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.inception_v3 import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/mobilenet/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.mobilenet import MobileNet as MobileNet\nfrom keras.src.applications.mobilenet import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.mobilenet import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/mobilenet_v2/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.mobilenet_v2 import MobileNetV2 as MobileNetV2\nfrom keras.src.applications.mobilenet_v2 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.mobilenet_v2 import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/mobilenet_v3/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.mobilenet_v3 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.mobilenet_v3 import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/nasnet/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.nasnet import NASNetLarge as NASNetLarge\nfrom keras.src.applications.nasnet import NASNetMobile as NASNetMobile\nfrom keras.src.applications.nasnet import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.nasnet import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/resnet/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.resnet import ResNet50 as ResNet50\nfrom keras.src.applications.resnet import ResNet101 as ResNet101\nfrom keras.src.applications.resnet import ResNet152 as ResNet152\nfrom keras.src.applications.resnet import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.resnet import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/resnet50/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.resnet import ResNet50 as ResNet50\nfrom keras.src.applications.resnet import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.resnet import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/resnet_v2/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2\nfrom keras.src.applications.resnet_v2 import ResNet101V2 as ResNet101V2\nfrom keras.src.applications.resnet_v2 import ResNet152V2 as ResNet152V2\nfrom keras.src.applications.resnet_v2 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.resnet_v2 import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/vgg16/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.vgg16 import VGG16 as VGG16\nfrom keras.src.applications.vgg16 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.vgg16 import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/vgg19/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.vgg19 import VGG19 as VGG19\nfrom keras.src.applications.vgg19 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.vgg19 import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/applications/xception/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.xception import Xception as Xception\nfrom keras.src.applications.xception import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.xception import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/backend/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.backend.common.dtypes import result_type as result_type\nfrom keras.src.backend.common.global_state import clear_session as clear_session\nfrom keras.src.backend.common.keras_tensor import (\n    is_keras_tensor as is_keras_tensor,\n)\nfrom keras.src.backend.common.variables import is_float_dtype as is_float_dtype\nfrom keras.src.backend.common.variables import is_int_dtype as is_int_dtype\nfrom keras.src.backend.common.variables import (\n    standardize_dtype as standardize_dtype,\n)\nfrom keras.src.backend.config import backend as backend\nfrom keras.src.backend.config import epsilon as epsilon\nfrom keras.src.backend.config import floatx as floatx\nfrom keras.src.backend.config import image_data_format as image_data_format\nfrom keras.src.backend.config import set_epsilon as set_epsilon\nfrom keras.src.backend.config import set_floatx as set_floatx\nfrom keras.src.backend.config import (\n    set_image_data_format as set_image_data_format,\n)\nfrom keras.src.legacy.backend import abs as abs\nfrom keras.src.legacy.backend import all as all\nfrom keras.src.legacy.backend import any as any\nfrom keras.src.legacy.backend import arange as arange\nfrom keras.src.legacy.backend import argmax as argmax\nfrom keras.src.legacy.backend import argmin as argmin\nfrom keras.src.legacy.backend import batch_dot as batch_dot\nfrom keras.src.legacy.backend import batch_flatten as batch_flatten\nfrom keras.src.legacy.backend import batch_get_value as batch_get_value\nfrom keras.src.legacy.backend import batch_normalization as batch_normalization\nfrom keras.src.legacy.backend import batch_set_value as batch_set_value\nfrom keras.src.legacy.backend import bias_add as bias_add\nfrom keras.src.legacy.backend import binary_crossentropy as binary_crossentropy\nfrom keras.src.legacy.backend import (\n    
binary_focal_crossentropy as binary_focal_crossentropy,\n)\nfrom keras.src.legacy.backend import cast as cast\nfrom keras.src.legacy.backend import cast_to_floatx as cast_to_floatx\nfrom keras.src.legacy.backend import (\n    categorical_crossentropy as categorical_crossentropy,\n)\nfrom keras.src.legacy.backend import (\n    categorical_focal_crossentropy as categorical_focal_crossentropy,\n)\nfrom keras.src.legacy.backend import clip as clip\nfrom keras.src.legacy.backend import concatenate as concatenate\nfrom keras.src.legacy.backend import constant as constant\nfrom keras.src.legacy.backend import conv1d as conv1d\nfrom keras.src.legacy.backend import conv2d as conv2d\nfrom keras.src.legacy.backend import conv2d_transpose as conv2d_transpose\nfrom keras.src.legacy.backend import conv3d as conv3d\nfrom keras.src.legacy.backend import cos as cos\nfrom keras.src.legacy.backend import count_params as count_params\nfrom keras.src.legacy.backend import ctc_batch_cost as ctc_batch_cost\nfrom keras.src.legacy.backend import ctc_decode as ctc_decode\nfrom keras.src.legacy.backend import (\n    ctc_label_dense_to_sparse as ctc_label_dense_to_sparse,\n)\nfrom keras.src.legacy.backend import cumprod as cumprod\nfrom keras.src.legacy.backend import cumsum as cumsum\nfrom keras.src.legacy.backend import depthwise_conv2d as depthwise_conv2d\nfrom keras.src.legacy.backend import dot as dot\nfrom keras.src.legacy.backend import dropout as dropout\nfrom keras.src.legacy.backend import dtype as dtype\nfrom keras.src.legacy.backend import elu as elu\nfrom keras.src.legacy.backend import equal as equal\nfrom keras.src.legacy.backend import eval as eval\nfrom keras.src.legacy.backend import exp as exp\nfrom keras.src.legacy.backend import expand_dims as expand_dims\nfrom keras.src.legacy.backend import eye as eye\nfrom keras.src.legacy.backend import flatten as flatten\nfrom keras.src.legacy.backend import foldl as foldl\nfrom keras.src.legacy.backend import foldr as foldr\nfrom 
keras.src.legacy.backend import gather as gather\nfrom keras.src.legacy.backend import get_value as get_value\nfrom keras.src.legacy.backend import gradients as gradients\nfrom keras.src.legacy.backend import greater as greater\nfrom keras.src.legacy.backend import greater_equal as greater_equal\nfrom keras.src.legacy.backend import hard_sigmoid as hard_sigmoid\nfrom keras.src.legacy.backend import in_top_k as in_top_k\nfrom keras.src.legacy.backend import int_shape as int_shape\nfrom keras.src.legacy.backend import is_sparse as is_sparse\nfrom keras.src.legacy.backend import l2_normalize as l2_normalize\nfrom keras.src.legacy.backend import less as less\nfrom keras.src.legacy.backend import less_equal as less_equal\nfrom keras.src.legacy.backend import log as log\nfrom keras.src.legacy.backend import map_fn as map_fn\nfrom keras.src.legacy.backend import max as max\nfrom keras.src.legacy.backend import maximum as maximum\nfrom keras.src.legacy.backend import mean as mean\nfrom keras.src.legacy.backend import min as min\nfrom keras.src.legacy.backend import minimum as minimum\nfrom keras.src.legacy.backend import (\n    moving_average_update as moving_average_update,\n)\nfrom keras.src.legacy.backend import name_scope as name_scope\nfrom keras.src.legacy.backend import ndim as ndim\nfrom keras.src.legacy.backend import not_equal as not_equal\nfrom keras.src.legacy.backend import one_hot as one_hot\nfrom keras.src.legacy.backend import ones as ones\nfrom keras.src.legacy.backend import ones_like as ones_like\nfrom keras.src.legacy.backend import permute_dimensions as permute_dimensions\nfrom keras.src.legacy.backend import pool2d as pool2d\nfrom keras.src.legacy.backend import pool3d as pool3d\nfrom keras.src.legacy.backend import pow as pow\nfrom keras.src.legacy.backend import prod as prod\nfrom keras.src.legacy.backend import random_bernoulli as random_bernoulli\nfrom keras.src.legacy.backend import random_normal as random_normal\nfrom keras.src.legacy.backend 
import (\n    random_normal_variable as random_normal_variable,\n)\nfrom keras.src.legacy.backend import random_uniform as random_uniform\nfrom keras.src.legacy.backend import (\n    random_uniform_variable as random_uniform_variable,\n)\nfrom keras.src.legacy.backend import relu as relu\nfrom keras.src.legacy.backend import repeat as repeat\nfrom keras.src.legacy.backend import repeat_elements as repeat_elements\nfrom keras.src.legacy.backend import reshape as reshape\nfrom keras.src.legacy.backend import resize_images as resize_images\nfrom keras.src.legacy.backend import resize_volumes as resize_volumes\nfrom keras.src.legacy.backend import reverse as reverse\nfrom keras.src.legacy.backend import rnn as rnn\nfrom keras.src.legacy.backend import round as round\nfrom keras.src.legacy.backend import separable_conv2d as separable_conv2d\nfrom keras.src.legacy.backend import set_value as set_value\nfrom keras.src.legacy.backend import shape as shape\nfrom keras.src.legacy.backend import sigmoid as sigmoid\nfrom keras.src.legacy.backend import sign as sign\nfrom keras.src.legacy.backend import sin as sin\nfrom keras.src.legacy.backend import softmax as softmax\nfrom keras.src.legacy.backend import softplus as softplus\nfrom keras.src.legacy.backend import softsign as softsign\nfrom keras.src.legacy.backend import (\n    sparse_categorical_crossentropy as sparse_categorical_crossentropy,\n)\nfrom keras.src.legacy.backend import spatial_2d_padding as spatial_2d_padding\nfrom keras.src.legacy.backend import spatial_3d_padding as spatial_3d_padding\nfrom keras.src.legacy.backend import sqrt as sqrt\nfrom keras.src.legacy.backend import square as square\nfrom keras.src.legacy.backend import squeeze as squeeze\nfrom keras.src.legacy.backend import stack as stack\nfrom keras.src.legacy.backend import std as std\nfrom keras.src.legacy.backend import stop_gradient as stop_gradient\nfrom keras.src.legacy.backend import sum as sum\nfrom keras.src.legacy.backend import switch as 
switch\nfrom keras.src.legacy.backend import tanh as tanh\nfrom keras.src.legacy.backend import temporal_padding as temporal_padding\nfrom keras.src.legacy.backend import tile as tile\nfrom keras.src.legacy.backend import to_dense as to_dense\nfrom keras.src.legacy.backend import transpose as transpose\nfrom keras.src.legacy.backend import truncated_normal as truncated_normal\nfrom keras.src.legacy.backend import update as update\nfrom keras.src.legacy.backend import update_add as update_add\nfrom keras.src.legacy.backend import update_sub as update_sub\nfrom keras.src.legacy.backend import var as var\nfrom keras.src.legacy.backend import variable as variable\nfrom keras.src.legacy.backend import zeros as zeros\nfrom keras.src.legacy.backend import zeros_like as zeros_like\nfrom keras.src.utils.naming import get_uid as get_uid\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/callbacks/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.callbacks.backup_and_restore import (\n    BackupAndRestore as BackupAndRestore,\n)\nfrom keras.src.callbacks.callback import Callback as Callback\nfrom keras.src.callbacks.callback_list import CallbackList as CallbackList\nfrom keras.src.callbacks.csv_logger import CSVLogger as CSVLogger\nfrom keras.src.callbacks.early_stopping import EarlyStopping as EarlyStopping\nfrom keras.src.callbacks.history import History as History\nfrom keras.src.callbacks.lambda_callback import LambdaCallback as LambdaCallback\nfrom keras.src.callbacks.learning_rate_scheduler import (\n    LearningRateScheduler as LearningRateScheduler,\n)\nfrom keras.src.callbacks.model_checkpoint import (\n    ModelCheckpoint as ModelCheckpoint,\n)\nfrom keras.src.callbacks.orbax_checkpoint import (\n    OrbaxCheckpoint as OrbaxCheckpoint,\n)\nfrom keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger\nfrom keras.src.callbacks.reduce_lr_on_plateau import (\n    ReduceLROnPlateau as ReduceLROnPlateau,\n)\nfrom keras.src.callbacks.remote_monitor import RemoteMonitor as RemoteMonitor\nfrom keras.src.callbacks.swap_ema_weights import (\n    SwapEMAWeights as SwapEMAWeights,\n)\nfrom keras.src.callbacks.tensorboard import TensorBoard as TensorBoard\nfrom keras.src.callbacks.terminate_on_nan import (\n    TerminateOnNaN as TerminateOnNaN,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/config/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.backend.config import backend as backend\nfrom keras.src.backend.config import (\n    disable_flash_attention as disable_flash_attention,\n)\nfrom keras.src.backend.config import (\n    enable_flash_attention as enable_flash_attention,\n)\nfrom keras.src.backend.config import epsilon as epsilon\nfrom keras.src.backend.config import floatx as floatx\nfrom keras.src.backend.config import image_data_format as image_data_format\nfrom keras.src.backend.config import (\n    is_flash_attention_enabled as is_flash_attention_enabled,\n)\nfrom keras.src.backend.config import is_nnx_enabled as is_nnx_enabled\nfrom keras.src.backend.config import max_epochs as max_epochs\nfrom keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch\nfrom keras.src.backend.config import set_epsilon as set_epsilon\nfrom keras.src.backend.config import set_floatx as set_floatx\nfrom keras.src.backend.config import (\n    set_image_data_format as set_image_data_format,\n)\nfrom keras.src.backend.config import set_max_epochs as set_max_epochs\nfrom keras.src.backend.config import (\n    set_max_steps_per_epoch as set_max_steps_per_epoch,\n)\nfrom keras.src.dtype_policies.dtype_policy import dtype_policy as dtype_policy\nfrom keras.src.dtype_policies.dtype_policy import (\n    set_dtype_policy as set_dtype_policy,\n)\nfrom keras.src.saving.serialization_lib import (\n    enable_unsafe_deserialization as enable_unsafe_deserialization,\n)\nfrom keras.src.utils.backend_utils import set_backend as set_backend\nfrom keras.src.utils.io_utils import (\n    disable_interactive_logging as disable_interactive_logging,\n)\nfrom keras.src.utils.io_utils import (\n    enable_interactive_logging as enable_interactive_logging,\n)\nfrom keras.src.utils.io_utils import (\n    is_interactive_logging_enabled as 
is_interactive_logging_enabled,\n)\nfrom keras.src.utils.traceback_utils import (\n    disable_traceback_filtering as disable_traceback_filtering,\n)\nfrom keras.src.utils.traceback_utils import (\n    enable_traceback_filtering as enable_traceback_filtering,\n)\nfrom keras.src.utils.traceback_utils import (\n    is_traceback_filtering_enabled as is_traceback_filtering_enabled,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/constraints/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.constraints import deserialize as deserialize\nfrom keras.src.constraints import get as get\nfrom keras.src.constraints import serialize as serialize\nfrom keras.src.constraints.constraints import Constraint as Constraint\nfrom keras.src.constraints.constraints import MaxNorm as MaxNorm\nfrom keras.src.constraints.constraints import MaxNorm as max_norm\nfrom keras.src.constraints.constraints import MinMaxNorm as MinMaxNorm\nfrom keras.src.constraints.constraints import MinMaxNorm as min_max_norm\nfrom keras.src.constraints.constraints import NonNeg as NonNeg\nfrom keras.src.constraints.constraints import NonNeg as non_neg\nfrom keras.src.constraints.constraints import UnitNorm as UnitNorm\nfrom keras.src.constraints.constraints import UnitNorm as unit_norm\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/datasets/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.datasets import boston_housing as boston_housing\nfrom keras.datasets import california_housing as california_housing\nfrom keras.datasets import cifar10 as cifar10\nfrom keras.datasets import cifar100 as cifar100\nfrom keras.datasets import fashion_mnist as fashion_mnist\nfrom keras.datasets import imdb as imdb\nfrom keras.datasets import mnist as mnist\nfrom keras.datasets import reuters as reuters\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/datasets/boston_housing/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.boston_housing import load_data as load_data\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/datasets/california_housing/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.california_housing import load_data as load_data\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/datasets/cifar10/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.cifar10 import load_data as load_data\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/datasets/cifar100/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.cifar100 import load_data as load_data\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/datasets/fashion_mnist/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.fashion_mnist import load_data as load_data\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/datasets/imdb/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.imdb import get_word_index as get_word_index\nfrom keras.src.datasets.imdb import load_data as load_data\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/datasets/mnist/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.mnist import load_data as load_data\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/datasets/reuters/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.reuters import get_label_names as get_label_names\nfrom keras.src.datasets.reuters import get_word_index as get_word_index\nfrom keras.src.datasets.reuters import load_data as load_data\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/distillation/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.distillation.distillation_loss import (\n    DistillationLoss as DistillationLoss,\n)\nfrom keras.src.distillation.distillation_loss import (\n    FeatureDistillation as FeatureDistillation,\n)\nfrom keras.src.distillation.distillation_loss import (\n    LogitsDistillation as LogitsDistillation,\n)\nfrom keras.src.distillation.distiller import Distiller as Distiller\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/distribution/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.distribution.distribution_lib import DataParallel as DataParallel\nfrom keras.src.distribution.distribution_lib import DeviceMesh as DeviceMesh\nfrom keras.src.distribution.distribution_lib import LayoutMap as LayoutMap\nfrom keras.src.distribution.distribution_lib import (\n    ModelParallel as ModelParallel,\n)\nfrom keras.src.distribution.distribution_lib import TensorLayout as TensorLayout\nfrom keras.src.distribution.distribution_lib import (\n    distribute_tensor as distribute_tensor,\n)\nfrom keras.src.distribution.distribution_lib import distribution as distribution\nfrom keras.src.distribution.distribution_lib import (\n    get_device_count as get_device_count,\n)\nfrom keras.src.distribution.distribution_lib import initialize as initialize\nfrom keras.src.distribution.distribution_lib import list_devices as list_devices\nfrom keras.src.distribution.distribution_lib import (\n    set_distribution as set_distribution,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/dtype_policies/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.dtype_policies import deserialize as deserialize\nfrom keras.src.dtype_policies import get as get\nfrom keras.src.dtype_policies import serialize as serialize\nfrom keras.src.dtype_policies.dtype_policy import (\n    AWQDTypePolicy as AWQDTypePolicy,\n)\nfrom keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import (\n    FloatDTypePolicy as FloatDTypePolicy,\n)\nfrom keras.src.dtype_policies.dtype_policy import (\n    GPTQDTypePolicy as GPTQDTypePolicy,\n)\nfrom keras.src.dtype_policies.dtype_policy import (\n    Int4DTypePolicy as Int4DTypePolicy,\n)\nfrom keras.src.dtype_policies.dtype_policy import (\n    QuantizedDTypePolicy as QuantizedDTypePolicy,\n)\nfrom keras.src.dtype_policies.dtype_policy import (\n    QuantizedFloat8DTypePolicy as QuantizedFloat8DTypePolicy,\n)\nfrom keras.src.dtype_policies.dtype_policy_map import (\n    DTypePolicyMap as DTypePolicyMap,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/export/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.export.saved_model import ExportArchive as ExportArchive\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/initializers/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.initializers import deserialize as deserialize\nfrom keras.src.initializers import get as get\nfrom keras.src.initializers import serialize as serialize\nfrom keras.src.initializers.constant_initializers import STFT as STFT\nfrom keras.src.initializers.constant_initializers import STFT as STFTInitializer\nfrom keras.src.initializers.constant_initializers import STFT as stft\nfrom keras.src.initializers.constant_initializers import Constant as Constant\nfrom keras.src.initializers.constant_initializers import Constant as constant\nfrom keras.src.initializers.constant_initializers import Identity as Identity\nfrom keras.src.initializers.constant_initializers import (\n    Identity as IdentityInitializer,\n)\nfrom keras.src.initializers.constant_initializers import Identity as identity\nfrom keras.src.initializers.constant_initializers import Ones as Ones\nfrom keras.src.initializers.constant_initializers import Ones as ones\nfrom keras.src.initializers.constant_initializers import Zeros as Zeros\nfrom keras.src.initializers.constant_initializers import Zeros as zeros\nfrom keras.src.initializers.initializer import Initializer as Initializer\nfrom keras.src.initializers.random_initializers import (\n    GlorotNormal as GlorotNormal,\n)\nfrom keras.src.initializers.random_initializers import (\n    GlorotNormal as glorot_normal,\n)\nfrom keras.src.initializers.random_initializers import (\n    GlorotUniform as GlorotUniform,\n)\nfrom keras.src.initializers.random_initializers import (\n    GlorotUniform as glorot_uniform,\n)\nfrom keras.src.initializers.random_initializers import HeNormal as HeNormal\nfrom keras.src.initializers.random_initializers import HeNormal as he_normal\nfrom keras.src.initializers.random_initializers import HeUniform as HeUniform\nfrom keras.src.initializers.random_initializers 
import HeUniform as he_uniform\nfrom keras.src.initializers.random_initializers import (\n    LecunNormal as LecunNormal,\n)\nfrom keras.src.initializers.random_initializers import (\n    LecunNormal as lecun_normal,\n)\nfrom keras.src.initializers.random_initializers import (\n    LecunUniform as LecunUniform,\n)\nfrom keras.src.initializers.random_initializers import (\n    LecunUniform as lecun_uniform,\n)\nfrom keras.src.initializers.random_initializers import Orthogonal as Orthogonal\nfrom keras.src.initializers.random_initializers import (\n    Orthogonal as OrthogonalInitializer,\n)\nfrom keras.src.initializers.random_initializers import Orthogonal as orthogonal\nfrom keras.src.initializers.random_initializers import (\n    RandomNormal as RandomNormal,\n)\nfrom keras.src.initializers.random_initializers import (\n    RandomNormal as random_normal,\n)\nfrom keras.src.initializers.random_initializers import (\n    RandomUniform as RandomUniform,\n)\nfrom keras.src.initializers.random_initializers import (\n    RandomUniform as random_uniform,\n)\nfrom keras.src.initializers.random_initializers import (\n    TruncatedNormal as TruncatedNormal,\n)\nfrom keras.src.initializers.random_initializers import (\n    TruncatedNormal as truncated_normal,\n)\nfrom keras.src.initializers.random_initializers import (\n    VarianceScaling as VarianceScaling,\n)\nfrom keras.src.initializers.random_initializers import (\n    VarianceScaling as variance_scaling,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/layers/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.export.tfsm_layer import TFSMLayer as TFSMLayer\nfrom keras.src.layers import deserialize as deserialize\nfrom keras.src.layers import serialize as serialize\nfrom keras.src.layers.activations.activation import Activation as Activation\nfrom keras.src.layers.activations.elu import ELU as ELU\nfrom keras.src.layers.activations.leaky_relu import LeakyReLU as LeakyReLU\nfrom keras.src.layers.activations.prelu import PReLU as PReLU\nfrom keras.src.layers.activations.relu import ReLU as ReLU\nfrom keras.src.layers.activations.softmax import Softmax as Softmax\nfrom keras.src.layers.attention.additive_attention import (\n    AdditiveAttention as AdditiveAttention,\n)\nfrom keras.src.layers.attention.attention import Attention as Attention\nfrom keras.src.layers.attention.grouped_query_attention import (\n    GroupedQueryAttention as GroupQueryAttention,\n)\nfrom keras.src.layers.attention.multi_head_attention import (\n    MultiHeadAttention as MultiHeadAttention,\n)\nfrom keras.src.layers.convolutional.conv1d import Conv1D as Conv1D\nfrom keras.src.layers.convolutional.conv1d import Conv1D as Convolution1D\nfrom keras.src.layers.convolutional.conv1d_transpose import (\n    Conv1DTranspose as Conv1DTranspose,\n)\nfrom keras.src.layers.convolutional.conv1d_transpose import (\n    Conv1DTranspose as Convolution1DTranspose,\n)\nfrom keras.src.layers.convolutional.conv2d import Conv2D as Conv2D\nfrom keras.src.layers.convolutional.conv2d import Conv2D as Convolution2D\nfrom keras.src.layers.convolutional.conv2d_transpose import (\n    Conv2DTranspose as Conv2DTranspose,\n)\nfrom keras.src.layers.convolutional.conv2d_transpose import (\n    Conv2DTranspose as Convolution2DTranspose,\n)\nfrom keras.src.layers.convolutional.conv3d import Conv3D as Conv3D\nfrom keras.src.layers.convolutional.conv3d import Conv3D as 
Convolution3D\nfrom keras.src.layers.convolutional.conv3d_transpose import (\n    Conv3DTranspose as Conv3DTranspose,\n)\nfrom keras.src.layers.convolutional.conv3d_transpose import (\n    Conv3DTranspose as Convolution3DTranspose,\n)\nfrom keras.src.layers.convolutional.depthwise_conv1d import (\n    DepthwiseConv1D as DepthwiseConv1D,\n)\nfrom keras.src.layers.convolutional.depthwise_conv2d import (\n    DepthwiseConv2D as DepthwiseConv2D,\n)\nfrom keras.src.layers.convolutional.separable_conv1d import (\n    SeparableConv1D as SeparableConv1D,\n)\nfrom keras.src.layers.convolutional.separable_conv1d import (\n    SeparableConv1D as SeparableConvolution1D,\n)\nfrom keras.src.layers.convolutional.separable_conv2d import (\n    SeparableConv2D as SeparableConv2D,\n)\nfrom keras.src.layers.convolutional.separable_conv2d import (\n    SeparableConv2D as SeparableConvolution2D,\n)\nfrom keras.src.layers.core.dense import Dense as Dense\nfrom keras.src.layers.core.einsum_dense import EinsumDense as EinsumDense\nfrom keras.src.layers.core.embedding import Embedding as Embedding\nfrom keras.src.layers.core.identity import Identity as Identity\nfrom keras.src.layers.core.input_layer import Input as Input\nfrom keras.src.layers.core.input_layer import InputLayer as InputLayer\nfrom keras.src.layers.core.lambda_layer import Lambda as Lambda\nfrom keras.src.layers.core.masking import Masking as Masking\nfrom keras.src.layers.core.reversible_embedding import (\n    ReversibleEmbedding as ReversibleEmbedding,\n)\nfrom keras.src.layers.core.wrapper import Wrapper as Wrapper\nfrom keras.src.layers.input_spec import InputSpec as InputSpec\nfrom keras.src.layers.layer import Layer as Layer\nfrom keras.src.layers.merging.add import Add as Add\nfrom keras.src.layers.merging.add import add as add\nfrom keras.src.layers.merging.average import Average as Average\nfrom keras.src.layers.merging.average import average as average\nfrom keras.src.layers.merging.concatenate import 
Concatenate as Concatenate\nfrom keras.src.layers.merging.concatenate import concatenate as concatenate\nfrom keras.src.layers.merging.dot import Dot as Dot\nfrom keras.src.layers.merging.dot import dot as dot\nfrom keras.src.layers.merging.maximum import Maximum as Maximum\nfrom keras.src.layers.merging.maximum import maximum as maximum\nfrom keras.src.layers.merging.minimum import Minimum as Minimum\nfrom keras.src.layers.merging.minimum import minimum as minimum\nfrom keras.src.layers.merging.multiply import Multiply as Multiply\nfrom keras.src.layers.merging.multiply import multiply as multiply\nfrom keras.src.layers.merging.subtract import Subtract as Subtract\nfrom keras.src.layers.merging.subtract import subtract as subtract\nfrom keras.src.layers.normalization.batch_normalization import (\n    BatchNormalization as BatchNormalization,\n)\nfrom keras.src.layers.normalization.group_normalization import (\n    GroupNormalization as GroupNormalization,\n)\nfrom keras.src.layers.normalization.layer_normalization import (\n    LayerNormalization as LayerNormalization,\n)\nfrom keras.src.layers.normalization.rms_normalization import (\n    RMSNormalization as RMSNormalization,\n)\nfrom keras.src.layers.normalization.spectral_normalization import (\n    SpectralNormalization as SpectralNormalization,\n)\nfrom keras.src.layers.normalization.unit_normalization import (\n    UnitNormalization as UnitNormalization,\n)\nfrom keras.src.layers.pooling.adaptive_average_pooling1d import (\n    AdaptiveAveragePooling1D as AdaptiveAveragePooling1D,\n)\nfrom keras.src.layers.pooling.adaptive_average_pooling2d import (\n    AdaptiveAveragePooling2D as AdaptiveAveragePooling2D,\n)\nfrom keras.src.layers.pooling.adaptive_average_pooling3d import (\n    AdaptiveAveragePooling3D as AdaptiveAveragePooling3D,\n)\nfrom keras.src.layers.pooling.adaptive_max_pooling1d import (\n    AdaptiveMaxPooling1D as AdaptiveMaxPooling1D,\n)\nfrom keras.src.layers.pooling.adaptive_max_pooling2d 
import (\n    AdaptiveMaxPooling2D as AdaptiveMaxPooling2D,\n)\nfrom keras.src.layers.pooling.adaptive_max_pooling3d import (\n    AdaptiveMaxPooling3D as AdaptiveMaxPooling3D,\n)\nfrom keras.src.layers.pooling.average_pooling1d import (\n    AveragePooling1D as AveragePooling1D,\n)\nfrom keras.src.layers.pooling.average_pooling1d import (\n    AveragePooling1D as AvgPool1D,\n)\nfrom keras.src.layers.pooling.average_pooling2d import (\n    AveragePooling2D as AveragePooling2D,\n)\nfrom keras.src.layers.pooling.average_pooling2d import (\n    AveragePooling2D as AvgPool2D,\n)\nfrom keras.src.layers.pooling.average_pooling3d import (\n    AveragePooling3D as AveragePooling3D,\n)\nfrom keras.src.layers.pooling.average_pooling3d import (\n    AveragePooling3D as AvgPool3D,\n)\nfrom keras.src.layers.pooling.global_average_pooling1d import (\n    GlobalAveragePooling1D as GlobalAveragePooling1D,\n)\nfrom keras.src.layers.pooling.global_average_pooling1d import (\n    GlobalAveragePooling1D as GlobalAvgPool1D,\n)\nfrom keras.src.layers.pooling.global_average_pooling2d import (\n    GlobalAveragePooling2D as GlobalAveragePooling2D,\n)\nfrom keras.src.layers.pooling.global_average_pooling2d import (\n    GlobalAveragePooling2D as GlobalAvgPool2D,\n)\nfrom keras.src.layers.pooling.global_average_pooling3d import (\n    GlobalAveragePooling3D as GlobalAveragePooling3D,\n)\nfrom keras.src.layers.pooling.global_average_pooling3d import (\n    GlobalAveragePooling3D as GlobalAvgPool3D,\n)\nfrom keras.src.layers.pooling.global_max_pooling1d import (\n    GlobalMaxPooling1D as GlobalMaxPool1D,\n)\nfrom keras.src.layers.pooling.global_max_pooling1d import (\n    GlobalMaxPooling1D as GlobalMaxPooling1D,\n)\nfrom keras.src.layers.pooling.global_max_pooling2d import (\n    GlobalMaxPooling2D as GlobalMaxPool2D,\n)\nfrom keras.src.layers.pooling.global_max_pooling2d import (\n    GlobalMaxPooling2D as GlobalMaxPooling2D,\n)\nfrom keras.src.layers.pooling.global_max_pooling3d import 
(\n    GlobalMaxPooling3D as GlobalMaxPool3D,\n)\nfrom keras.src.layers.pooling.global_max_pooling3d import (\n    GlobalMaxPooling3D as GlobalMaxPooling3D,\n)\nfrom keras.src.layers.pooling.max_pooling1d import MaxPooling1D as MaxPool1D\nfrom keras.src.layers.pooling.max_pooling1d import MaxPooling1D as MaxPooling1D\nfrom keras.src.layers.pooling.max_pooling2d import MaxPooling2D as MaxPool2D\nfrom keras.src.layers.pooling.max_pooling2d import MaxPooling2D as MaxPooling2D\nfrom keras.src.layers.pooling.max_pooling3d import MaxPooling3D as MaxPool3D\nfrom keras.src.layers.pooling.max_pooling3d import MaxPooling3D as MaxPooling3D\nfrom keras.src.layers.preprocessing.category_encoding import (\n    CategoryEncoding as CategoryEncoding,\n)\nfrom keras.src.layers.preprocessing.discretization import (\n    Discretization as Discretization,\n)\nfrom keras.src.layers.preprocessing.hashed_crossing import (\n    HashedCrossing as HashedCrossing,\n)\nfrom keras.src.layers.preprocessing.hashing import Hashing as Hashing\nfrom keras.src.layers.preprocessing.image_preprocessing.aug_mix import (\n    AugMix as AugMix,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.auto_contrast import (\n    AutoContrast as AutoContrast,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.center_crop import (\n    CenterCrop as CenterCrop,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.clahe import (\n    ContrastLimitedAdaptiveHistogramEqualization as ContrastLimitedAdaptiveHistogramEqualization,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.cut_mix import (\n    CutMix as CutMix,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.equalization import (\n    Equalization as Equalization,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.max_num_bounding_box import (\n    MaxNumBoundingBoxes as MaxNumBoundingBoxes,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.mix_up import (\n    MixUp as MixUp,\n)\nfrom 
keras.src.layers.preprocessing.image_preprocessing.rand_augment import (\n    RandAugment as RandAugment,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_brightness import (\n    RandomBrightness as RandomBrightness,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (\n    RandomColorDegeneration as RandomColorDegeneration,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (\n    RandomColorJitter as RandomColorJitter,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_contrast import (\n    RandomContrast as RandomContrast,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_crop import (\n    RandomCrop as RandomCrop,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_elastic_transform import (\n    RandomElasticTransform as RandomElasticTransform,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_erasing import (\n    RandomErasing as RandomErasing,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_flip import (\n    RandomFlip as RandomFlip,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_gaussian_blur import (\n    RandomGaussianBlur as RandomGaussianBlur,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (\n    RandomGrayscale as RandomGrayscale,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_hue import (\n    RandomHue as RandomHue,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_invert import (\n    RandomInvert as RandomInvert,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_perspective import (\n    RandomPerspective as RandomPerspective,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_posterization import (\n    RandomPosterization as RandomPosterization,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_rotation import (\n    
RandomRotation as RandomRotation,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_saturation import (\n    RandomSaturation as RandomSaturation,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_sharpness import (\n    RandomSharpness as RandomSharpness,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_shear import (\n    RandomShear as RandomShear,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_translation import (\n    RandomTranslation as RandomTranslation,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_zoom import (\n    RandomZoom as RandomZoom,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.resizing import (\n    Resizing as Resizing,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.solarization import (\n    Solarization as Solarization,\n)\nfrom keras.src.layers.preprocessing.integer_lookup import (\n    IntegerLookup as IntegerLookup,\n)\nfrom keras.src.layers.preprocessing.mel_spectrogram import (\n    MelSpectrogram as MelSpectrogram,\n)\nfrom keras.src.layers.preprocessing.normalization import (\n    Normalization as Normalization,\n)\nfrom keras.src.layers.preprocessing.pipeline import Pipeline as Pipeline\nfrom keras.src.layers.preprocessing.rescaling import Rescaling as Rescaling\nfrom keras.src.layers.preprocessing.stft_spectrogram import (\n    STFTSpectrogram as STFTSpectrogram,\n)\nfrom keras.src.layers.preprocessing.string_lookup import (\n    StringLookup as StringLookup,\n)\nfrom keras.src.layers.preprocessing.text_vectorization import (\n    TextVectorization as TextVectorization,\n)\nfrom keras.src.layers.regularization.activity_regularization import (\n    ActivityRegularization as ActivityRegularization,\n)\nfrom keras.src.layers.regularization.dropout import Dropout as Dropout\nfrom keras.src.layers.regularization.gaussian_dropout import (\n    GaussianDropout as GaussianDropout,\n)\nfrom 
keras.src.layers.regularization.gaussian_noise import (\n    GaussianNoise as GaussianNoise,\n)\nfrom keras.src.layers.regularization.spatial_dropout import (\n    SpatialDropout1D as SpatialDropout1D,\n)\nfrom keras.src.layers.regularization.spatial_dropout import (\n    SpatialDropout2D as SpatialDropout2D,\n)\nfrom keras.src.layers.regularization.spatial_dropout import (\n    SpatialDropout3D as SpatialDropout3D,\n)\nfrom keras.src.layers.reshaping.cropping1d import Cropping1D as Cropping1D\nfrom keras.src.layers.reshaping.cropping2d import Cropping2D as Cropping2D\nfrom keras.src.layers.reshaping.cropping3d import Cropping3D as Cropping3D\nfrom keras.src.layers.reshaping.flatten import Flatten as Flatten\nfrom keras.src.layers.reshaping.permute import Permute as Permute\nfrom keras.src.layers.reshaping.repeat_vector import (\n    RepeatVector as RepeatVector,\n)\nfrom keras.src.layers.reshaping.reshape import Reshape as Reshape\nfrom keras.src.layers.reshaping.up_sampling1d import (\n    UpSampling1D as UpSampling1D,\n)\nfrom keras.src.layers.reshaping.up_sampling2d import (\n    UpSampling2D as UpSampling2D,\n)\nfrom keras.src.layers.reshaping.up_sampling3d import (\n    UpSampling3D as UpSampling3D,\n)\nfrom keras.src.layers.reshaping.zero_padding1d import (\n    ZeroPadding1D as ZeroPadding1D,\n)\nfrom keras.src.layers.reshaping.zero_padding2d import (\n    ZeroPadding2D as ZeroPadding2D,\n)\nfrom keras.src.layers.reshaping.zero_padding3d import (\n    ZeroPadding3D as ZeroPadding3D,\n)\nfrom keras.src.layers.rnn.bidirectional import Bidirectional as Bidirectional\nfrom keras.src.layers.rnn.conv_lstm1d import ConvLSTM1D as ConvLSTM1D\nfrom keras.src.layers.rnn.conv_lstm2d import ConvLSTM2D as ConvLSTM2D\nfrom keras.src.layers.rnn.conv_lstm3d import ConvLSTM3D as ConvLSTM3D\nfrom keras.src.layers.rnn.gru import GRU as GRU\nfrom keras.src.layers.rnn.gru import GRUCell as GRUCell\nfrom keras.src.layers.rnn.lstm import LSTM as LSTM\nfrom 
keras.src.layers.rnn.lstm import LSTMCell as LSTMCell\nfrom keras.src.layers.rnn.rnn import RNN as RNN\nfrom keras.src.layers.rnn.simple_rnn import SimpleRNN as SimpleRNN\nfrom keras.src.layers.rnn.simple_rnn import SimpleRNNCell as SimpleRNNCell\nfrom keras.src.layers.rnn.stacked_rnn_cells import (\n    StackedRNNCells as StackedRNNCells,\n)\nfrom keras.src.layers.rnn.time_distributed import (\n    TimeDistributed as TimeDistributed,\n)\nfrom keras.src.legacy.layers import AlphaDropout as AlphaDropout\nfrom keras.src.legacy.layers import RandomHeight as RandomHeight\nfrom keras.src.legacy.layers import RandomWidth as RandomWidth\nfrom keras.src.legacy.layers import ThresholdedReLU as ThresholdedReLU\nfrom keras.src.utils.jax_layer import FlaxLayer as FlaxLayer\nfrom keras.src.utils.jax_layer import JaxLayer as JaxLayer\nfrom keras.src.utils.torch_utils import TorchModuleWrapper as TorchModuleWrapper\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/legacy/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.legacy import saving as saving\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/legacy/saving/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.legacy.saving.serialization import (\n    deserialize_keras_object as deserialize_keras_object,\n)\nfrom keras.src.legacy.saving.serialization import (\n    serialize_keras_object as serialize_keras_object,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/losses/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.legacy.losses import Reduction as Reduction\nfrom keras.src.losses import deserialize as deserialize\nfrom keras.src.losses import get as get\nfrom keras.src.losses import serialize as serialize\nfrom keras.src.losses.loss import Loss as Loss\nfrom keras.src.losses.losses import CTC as CTC\nfrom keras.src.losses.losses import BinaryCrossentropy as BinaryCrossentropy\nfrom keras.src.losses.losses import (\n    BinaryFocalCrossentropy as BinaryFocalCrossentropy,\n)\nfrom keras.src.losses.losses import (\n    CategoricalCrossentropy as CategoricalCrossentropy,\n)\nfrom keras.src.losses.losses import (\n    CategoricalFocalCrossentropy as CategoricalFocalCrossentropy,\n)\nfrom keras.src.losses.losses import (\n    CategoricalGeneralizedCrossEntropy as CategoricalGeneralizedCrossEntropy,\n)\nfrom keras.src.losses.losses import CategoricalHinge as CategoricalHinge\nfrom keras.src.losses.losses import Circle as Circle\nfrom keras.src.losses.losses import CosineSimilarity as CosineSimilarity\nfrom keras.src.losses.losses import Dice as Dice\nfrom keras.src.losses.losses import Hinge as Hinge\nfrom keras.src.losses.losses import Huber as Huber\nfrom keras.src.losses.losses import KLDivergence as KLDivergence\nfrom keras.src.losses.losses import LogCosh as LogCosh\nfrom keras.src.losses.losses import MeanAbsoluteError as MeanAbsoluteError\nfrom keras.src.losses.losses import (\n    MeanAbsolutePercentageError as MeanAbsolutePercentageError,\n)\nfrom keras.src.losses.losses import MeanSquaredError as MeanSquaredError\nfrom keras.src.losses.losses import (\n    MeanSquaredLogarithmicError as MeanSquaredLogarithmicError,\n)\nfrom keras.src.losses.losses import Poisson as Poisson\nfrom keras.src.losses.losses import (\n    SparseCategoricalCrossentropy as SparseCategoricalCrossentropy,\n)\nfrom 
keras.src.losses.losses import SquaredHinge as SquaredHinge\nfrom keras.src.losses.losses import Tversky as Tversky\nfrom keras.src.losses.losses import binary_crossentropy as binary_crossentropy\nfrom keras.src.losses.losses import (\n    binary_focal_crossentropy as binary_focal_crossentropy,\n)\nfrom keras.src.losses.losses import (\n    categorical_crossentropy as categorical_crossentropy,\n)\nfrom keras.src.losses.losses import (\n    categorical_focal_crossentropy as categorical_focal_crossentropy,\n)\nfrom keras.src.losses.losses import (\n    categorical_generalized_cross_entropy as categorical_generalized_cross_entropy,\n)\nfrom keras.src.losses.losses import categorical_hinge as categorical_hinge\nfrom keras.src.losses.losses import circle as circle\nfrom keras.src.losses.losses import cosine_similarity as cosine_similarity\nfrom keras.src.losses.losses import ctc as ctc\nfrom keras.src.losses.losses import dice as dice\nfrom keras.src.losses.losses import hinge as hinge\nfrom keras.src.losses.losses import huber as huber\nfrom keras.src.losses.losses import kl_divergence as KLD\nfrom keras.src.losses.losses import kl_divergence as kld\nfrom keras.src.losses.losses import kl_divergence as kullback_leibler_divergence\nfrom keras.src.losses.losses import log_cosh as logcosh\nfrom keras.src.losses.losses import mean_absolute_error as MAE\nfrom keras.src.losses.losses import mean_absolute_error as mae\nfrom keras.src.losses.losses import mean_absolute_percentage_error as MAPE\nfrom keras.src.losses.losses import mean_absolute_percentage_error as mape\nfrom keras.src.losses.losses import mean_squared_error as MSE\nfrom keras.src.losses.losses import mean_squared_error as mse\nfrom keras.src.losses.losses import mean_squared_logarithmic_error as MSLE\nfrom keras.src.losses.losses import mean_squared_logarithmic_error as msle\nfrom keras.src.losses.losses import poisson as poisson\nfrom keras.src.losses.losses import (\n    sparse_categorical_crossentropy as 
sparse_categorical_crossentropy,\n)\nfrom keras.src.losses.losses import squared_hinge as squared_hinge\nfrom keras.src.losses.losses import tversky as tversky\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/metrics/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.losses.losses import binary_crossentropy as binary_crossentropy\nfrom keras.src.losses.losses import (\n    binary_focal_crossentropy as binary_focal_crossentropy,\n)\nfrom keras.src.losses.losses import (\n    categorical_crossentropy as categorical_crossentropy,\n)\nfrom keras.src.losses.losses import (\n    categorical_focal_crossentropy as categorical_focal_crossentropy,\n)\nfrom keras.src.losses.losses import categorical_hinge as categorical_hinge\nfrom keras.src.losses.losses import hinge as hinge\nfrom keras.src.losses.losses import huber as huber\nfrom keras.src.losses.losses import kl_divergence as KLD\nfrom keras.src.losses.losses import kl_divergence as kld\nfrom keras.src.losses.losses import kl_divergence as kullback_leibler_divergence\nfrom keras.src.losses.losses import log_cosh as logcosh\nfrom keras.src.losses.losses import mean_absolute_error as MAE\nfrom keras.src.losses.losses import mean_absolute_error as mae\nfrom keras.src.losses.losses import mean_absolute_percentage_error as MAPE\nfrom keras.src.losses.losses import mean_absolute_percentage_error as mape\nfrom keras.src.losses.losses import mean_squared_error as MSE\nfrom keras.src.losses.losses import mean_squared_error as mse\nfrom keras.src.losses.losses import mean_squared_logarithmic_error as MSLE\nfrom keras.src.losses.losses import mean_squared_logarithmic_error as msle\nfrom keras.src.losses.losses import poisson as poisson\nfrom keras.src.losses.losses import (\n    sparse_categorical_crossentropy as sparse_categorical_crossentropy,\n)\nfrom keras.src.losses.losses import squared_hinge as squared_hinge\nfrom keras.src.metrics import deserialize as deserialize\nfrom keras.src.metrics import get as get\nfrom keras.src.metrics import serialize as serialize\nfrom keras.src.metrics.accuracy_metrics import Accuracy as 
Accuracy\nfrom keras.src.metrics.accuracy_metrics import BinaryAccuracy as BinaryAccuracy\nfrom keras.src.metrics.accuracy_metrics import (\n    CategoricalAccuracy as CategoricalAccuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    SparseCategoricalAccuracy as SparseCategoricalAccuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    SparseTopKCategoricalAccuracy as SparseTopKCategoricalAccuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    TopKCategoricalAccuracy as TopKCategoricalAccuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    binary_accuracy as binary_accuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    categorical_accuracy as categorical_accuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    sparse_categorical_accuracy as sparse_categorical_accuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    sparse_top_k_categorical_accuracy as sparse_top_k_categorical_accuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    top_k_categorical_accuracy as top_k_categorical_accuracy,\n)\nfrom keras.src.metrics.confusion_metrics import AUC as AUC\nfrom keras.src.metrics.confusion_metrics import FalseNegatives as FalseNegatives\nfrom keras.src.metrics.confusion_metrics import FalsePositives as FalsePositives\nfrom keras.src.metrics.confusion_metrics import Precision as Precision\nfrom keras.src.metrics.confusion_metrics import (\n    PrecisionAtRecall as PrecisionAtRecall,\n)\nfrom keras.src.metrics.confusion_metrics import Recall as Recall\nfrom keras.src.metrics.confusion_metrics import (\n    RecallAtPrecision as RecallAtPrecision,\n)\nfrom keras.src.metrics.confusion_metrics import (\n    SensitivityAtSpecificity as SensitivityAtSpecificity,\n)\nfrom keras.src.metrics.confusion_metrics import (\n    SpecificityAtSensitivity as SpecificityAtSensitivity,\n)\nfrom keras.src.metrics.confusion_metrics import TrueNegatives as TrueNegatives\nfrom 
keras.src.metrics.confusion_metrics import TruePositives as TruePositives\nfrom keras.src.metrics.correlation_metrics import (\n    ConcordanceCorrelation as ConcordanceCorrelation,\n)\nfrom keras.src.metrics.correlation_metrics import (\n    PearsonCorrelation as PearsonCorrelation,\n)\nfrom keras.src.metrics.correlation_metrics import (\n    concordance_correlation as concordance_correlation,\n)\nfrom keras.src.metrics.correlation_metrics import (\n    pearson_correlation as pearson_correlation,\n)\nfrom keras.src.metrics.f_score_metrics import F1Score as F1Score\nfrom keras.src.metrics.f_score_metrics import FBetaScore as FBetaScore\nfrom keras.src.metrics.hinge_metrics import CategoricalHinge as CategoricalHinge\nfrom keras.src.metrics.hinge_metrics import Hinge as Hinge\nfrom keras.src.metrics.hinge_metrics import SquaredHinge as SquaredHinge\nfrom keras.src.metrics.iou_metrics import BinaryIoU as BinaryIoU\nfrom keras.src.metrics.iou_metrics import IoU as IoU\nfrom keras.src.metrics.iou_metrics import MeanIoU as MeanIoU\nfrom keras.src.metrics.iou_metrics import OneHotIoU as OneHotIoU\nfrom keras.src.metrics.iou_metrics import OneHotMeanIoU as OneHotMeanIoU\nfrom keras.src.metrics.metric import Metric as Metric\nfrom keras.src.metrics.probabilistic_metrics import (\n    BinaryCrossentropy as BinaryCrossentropy,\n)\nfrom keras.src.metrics.probabilistic_metrics import (\n    CategoricalCrossentropy as CategoricalCrossentropy,\n)\nfrom keras.src.metrics.probabilistic_metrics import KLDivergence as KLDivergence\nfrom keras.src.metrics.probabilistic_metrics import Poisson as Poisson\nfrom keras.src.metrics.probabilistic_metrics import (\n    SparseCategoricalCrossentropy as SparseCategoricalCrossentropy,\n)\nfrom keras.src.metrics.reduction_metrics import Mean as Mean\nfrom keras.src.metrics.reduction_metrics import (\n    MeanMetricWrapper as MeanMetricWrapper,\n)\nfrom keras.src.metrics.reduction_metrics import Sum as Sum\nfrom 
keras.src.metrics.regression_metrics import (\n    CosineSimilarity as CosineSimilarity,\n)\nfrom keras.src.metrics.regression_metrics import LogCoshError as LogCoshError\nfrom keras.src.metrics.regression_metrics import (\n    MeanAbsoluteError as MeanAbsoluteError,\n)\nfrom keras.src.metrics.regression_metrics import (\n    MeanAbsolutePercentageError as MeanAbsolutePercentageError,\n)\nfrom keras.src.metrics.regression_metrics import (\n    MeanSquaredError as MeanSquaredError,\n)\nfrom keras.src.metrics.regression_metrics import (\n    MeanSquaredLogarithmicError as MeanSquaredLogarithmicError,\n)\nfrom keras.src.metrics.regression_metrics import R2Score as R2Score\nfrom keras.src.metrics.regression_metrics import (\n    RootMeanSquaredError as RootMeanSquaredError,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/mixed_precision/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import DTypePolicy as Policy\nfrom keras.src.dtype_policies.dtype_policy import dtype_policy as dtype_policy\nfrom keras.src.dtype_policies.dtype_policy import dtype_policy as global_policy\nfrom keras.src.dtype_policies.dtype_policy import (\n    set_dtype_policy as set_dtype_policy,\n)\nfrom keras.src.dtype_policies.dtype_policy import (\n    set_dtype_policy as set_global_policy,\n)\nfrom keras.src.optimizers.loss_scale_optimizer import (\n    LossScaleOptimizer as LossScaleOptimizer,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/models/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.models.cloning import clone_model as clone_model\nfrom keras.src.models.model import Model as Model\nfrom keras.src.models.model import model_from_json as model_from_json\nfrom keras.src.models.sequential import Sequential as Sequential\nfrom keras.src.saving.saving_api import load_model as load_model\nfrom keras.src.saving.saving_api import save_model as save_model\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/ops/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.ops import image as image\nfrom keras.ops import linalg as linalg\nfrom keras.ops import nn as nn\nfrom keras.ops import numpy as numpy\nfrom keras.src.ops.core import associative_scan as associative_scan\nfrom keras.src.ops.core import cast as cast\nfrom keras.src.ops.core import cond as cond\nfrom keras.src.ops.core import convert_to_numpy as convert_to_numpy\nfrom keras.src.ops.core import convert_to_tensor as convert_to_tensor\nfrom keras.src.ops.core import custom_gradient as custom_gradient\nfrom keras.src.ops.core import dtype as dtype\nfrom keras.src.ops.core import fori_loop as fori_loop\nfrom keras.src.ops.core import is_tensor as is_tensor\nfrom keras.src.ops.core import map as map\nfrom keras.src.ops.core import saturate_cast as saturate_cast\nfrom keras.src.ops.core import scan as scan\nfrom keras.src.ops.core import scatter as scatter\nfrom keras.src.ops.core import scatter_update as scatter_update\nfrom keras.src.ops.core import shape as shape\nfrom keras.src.ops.core import slice as slice\nfrom keras.src.ops.core import slice_update as slice_update\nfrom keras.src.ops.core import stop_gradient as stop_gradient\nfrom keras.src.ops.core import switch as switch\nfrom keras.src.ops.core import unstack as unstack\nfrom keras.src.ops.core import vectorized_map as vectorized_map\nfrom keras.src.ops.core import while_loop as while_loop\nfrom keras.src.ops.einops import rearrange as rearrange\nfrom keras.src.ops.linalg import cholesky as cholesky\nfrom keras.src.ops.linalg import cholesky_inverse as cholesky_inverse\nfrom keras.src.ops.linalg import det as det\nfrom keras.src.ops.linalg import eig as eig\nfrom keras.src.ops.linalg import eigh as eigh\nfrom keras.src.ops.linalg import inv as inv\nfrom keras.src.ops.linalg import jvp as jvp\nfrom keras.src.ops.linalg import lstsq as lstsq\nfrom 
keras.src.ops.linalg import lu_factor as lu_factor\nfrom keras.src.ops.linalg import norm as norm\nfrom keras.src.ops.linalg import qr as qr\nfrom keras.src.ops.linalg import solve as solve\nfrom keras.src.ops.linalg import solve_triangular as solve_triangular\nfrom keras.src.ops.linalg import svd as svd\nfrom keras.src.ops.math import erf as erf\nfrom keras.src.ops.math import erfinv as erfinv\nfrom keras.src.ops.math import extract_sequences as extract_sequences\nfrom keras.src.ops.math import fft as fft\nfrom keras.src.ops.math import fft2 as fft2\nfrom keras.src.ops.math import ifft2 as ifft2\nfrom keras.src.ops.math import in_top_k as in_top_k\nfrom keras.src.ops.math import irfft as irfft\nfrom keras.src.ops.math import istft as istft\nfrom keras.src.ops.math import logdet as logdet\nfrom keras.src.ops.math import logsumexp as logsumexp\nfrom keras.src.ops.math import rfft as rfft\nfrom keras.src.ops.math import rsqrt as rsqrt\nfrom keras.src.ops.math import segment_max as segment_max\nfrom keras.src.ops.math import segment_sum as segment_sum\nfrom keras.src.ops.math import stft as stft\nfrom keras.src.ops.math import top_k as top_k\nfrom keras.src.ops.math import view_as_complex as view_as_complex\nfrom keras.src.ops.math import view_as_real as view_as_real\nfrom keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool\nfrom keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool\nfrom keras.src.ops.nn import average_pool as average_pool\nfrom keras.src.ops.nn import batch_normalization as batch_normalization\nfrom keras.src.ops.nn import binary_crossentropy as binary_crossentropy\nfrom keras.src.ops.nn import (\n    categorical_crossentropy as categorical_crossentropy,\n)\nfrom keras.src.ops.nn import celu as celu\nfrom keras.src.ops.nn import conv as conv\nfrom keras.src.ops.nn import conv_transpose as conv_transpose\nfrom keras.src.ops.nn import ctc_decode as ctc_decode\nfrom keras.src.ops.nn import ctc_loss as ctc_loss\nfrom 
keras.src.ops.nn import depth_to_space as depth_to_space\nfrom keras.src.ops.nn import depthwise_conv as depthwise_conv\nfrom keras.src.ops.nn import dot_product_attention as dot_product_attention\nfrom keras.src.ops.nn import elu as elu\nfrom keras.src.ops.nn import fold as fold\nfrom keras.src.ops.nn import gelu as gelu\nfrom keras.src.ops.nn import glu as glu\nfrom keras.src.ops.nn import hard_shrink as hard_shrink\nfrom keras.src.ops.nn import hard_sigmoid as hard_sigmoid\nfrom keras.src.ops.nn import hard_silu as hard_silu\nfrom keras.src.ops.nn import hard_silu as hard_swish\nfrom keras.src.ops.nn import hard_tanh as hard_tanh\nfrom keras.src.ops.nn import layer_normalization as layer_normalization\nfrom keras.src.ops.nn import leaky_relu as leaky_relu\nfrom keras.src.ops.nn import log_sigmoid as log_sigmoid\nfrom keras.src.ops.nn import log_softmax as log_softmax\nfrom keras.src.ops.nn import max_pool as max_pool\nfrom keras.src.ops.nn import moments as moments\nfrom keras.src.ops.nn import multi_hot as multi_hot\nfrom keras.src.ops.nn import normalize as normalize\nfrom keras.src.ops.nn import one_hot as one_hot\nfrom keras.src.ops.nn import polar as polar\nfrom keras.src.ops.nn import psnr as psnr\nfrom keras.src.ops.nn import relu as relu\nfrom keras.src.ops.nn import relu6 as relu6\nfrom keras.src.ops.nn import rms_normalization as rms_normalization\nfrom keras.src.ops.nn import selu as selu\nfrom keras.src.ops.nn import separable_conv as separable_conv\nfrom keras.src.ops.nn import sigmoid as sigmoid\nfrom keras.src.ops.nn import silu as silu\nfrom keras.src.ops.nn import silu as swish\nfrom keras.src.ops.nn import soft_shrink as soft_shrink\nfrom keras.src.ops.nn import softmax as softmax\nfrom keras.src.ops.nn import softplus as softplus\nfrom keras.src.ops.nn import softsign as softsign\nfrom keras.src.ops.nn import space_to_depth as space_to_depth\nfrom keras.src.ops.nn import (\n    sparse_categorical_crossentropy as 
sparse_categorical_crossentropy,\n)\nfrom keras.src.ops.nn import sparse_plus as sparse_plus\nfrom keras.src.ops.nn import sparse_sigmoid as sparse_sigmoid\nfrom keras.src.ops.nn import sparsemax as sparsemax\nfrom keras.src.ops.nn import squareplus as squareplus\nfrom keras.src.ops.nn import tanh_shrink as tanh_shrink\nfrom keras.src.ops.nn import threshold as threshold\nfrom keras.src.ops.nn import unfold as unfold\nfrom keras.src.ops.numpy import abs as abs\nfrom keras.src.ops.numpy import absolute as absolute\nfrom keras.src.ops.numpy import add as add\nfrom keras.src.ops.numpy import all as all\nfrom keras.src.ops.numpy import allclose as allclose\nfrom keras.src.ops.numpy import amax as amax\nfrom keras.src.ops.numpy import amin as amin\nfrom keras.src.ops.numpy import angle as angle\nfrom keras.src.ops.numpy import any as any\nfrom keras.src.ops.numpy import append as append\nfrom keras.src.ops.numpy import arange as arange\nfrom keras.src.ops.numpy import arccos as arccos\nfrom keras.src.ops.numpy import arccosh as arccosh\nfrom keras.src.ops.numpy import arcsin as arcsin\nfrom keras.src.ops.numpy import arcsinh as arcsinh\nfrom keras.src.ops.numpy import arctan as arctan\nfrom keras.src.ops.numpy import arctan2 as arctan2\nfrom keras.src.ops.numpy import arctanh as arctanh\nfrom keras.src.ops.numpy import argmax as argmax\nfrom keras.src.ops.numpy import argmin as argmin\nfrom keras.src.ops.numpy import argpartition as argpartition\nfrom keras.src.ops.numpy import argsort as argsort\nfrom keras.src.ops.numpy import array as array\nfrom keras.src.ops.numpy import array_split as array_split\nfrom keras.src.ops.numpy import average as average\nfrom keras.src.ops.numpy import bartlett as bartlett\nfrom keras.src.ops.numpy import bincount as bincount\nfrom keras.src.ops.numpy import bitwise_and as bitwise_and\nfrom keras.src.ops.numpy import bitwise_invert as bitwise_invert\nfrom keras.src.ops.numpy import bitwise_left_shift as bitwise_left_shift\nfrom 
keras.src.ops.numpy import bitwise_not as bitwise_not\nfrom keras.src.ops.numpy import bitwise_or as bitwise_or\nfrom keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift\nfrom keras.src.ops.numpy import bitwise_xor as bitwise_xor\nfrom keras.src.ops.numpy import blackman as blackman\nfrom keras.src.ops.numpy import broadcast_to as broadcast_to\nfrom keras.src.ops.numpy import cbrt as cbrt\nfrom keras.src.ops.numpy import ceil as ceil\nfrom keras.src.ops.numpy import clip as clip\nfrom keras.src.ops.numpy import concatenate as concatenate\nfrom keras.src.ops.numpy import conj as conj\nfrom keras.src.ops.numpy import conjugate as conjugate\nfrom keras.src.ops.numpy import copy as copy\nfrom keras.src.ops.numpy import corrcoef as corrcoef\nfrom keras.src.ops.numpy import correlate as correlate\nfrom keras.src.ops.numpy import cos as cos\nfrom keras.src.ops.numpy import cosh as cosh\nfrom keras.src.ops.numpy import count_nonzero as count_nonzero\nfrom keras.src.ops.numpy import cross as cross\nfrom keras.src.ops.numpy import cumprod as cumprod\nfrom keras.src.ops.numpy import cumsum as cumsum\nfrom keras.src.ops.numpy import deg2rad as deg2rad\nfrom keras.src.ops.numpy import diag as diag\nfrom keras.src.ops.numpy import diagflat as diagflat\nfrom keras.src.ops.numpy import diagonal as diagonal\nfrom keras.src.ops.numpy import diff as diff\nfrom keras.src.ops.numpy import digitize as digitize\nfrom keras.src.ops.numpy import divide as divide\nfrom keras.src.ops.numpy import divide_no_nan as divide_no_nan\nfrom keras.src.ops.numpy import dot as dot\nfrom keras.src.ops.numpy import dstack as dstack\nfrom keras.src.ops.numpy import einsum as einsum\nfrom keras.src.ops.numpy import empty as empty\nfrom keras.src.ops.numpy import empty_like as empty_like\nfrom keras.src.ops.numpy import equal as equal\nfrom keras.src.ops.numpy import exp as exp\nfrom keras.src.ops.numpy import exp2 as exp2\nfrom keras.src.ops.numpy import expand_dims as expand_dims\nfrom 
keras.src.ops.numpy import expm1 as expm1\nfrom keras.src.ops.numpy import eye as eye\nfrom keras.src.ops.numpy import flip as flip\nfrom keras.src.ops.numpy import floor as floor\nfrom keras.src.ops.numpy import floor_divide as floor_divide\nfrom keras.src.ops.numpy import fmod as fmod\nfrom keras.src.ops.numpy import full as full\nfrom keras.src.ops.numpy import full_like as full_like\nfrom keras.src.ops.numpy import gcd as gcd\nfrom keras.src.ops.numpy import geomspace as geomspace\nfrom keras.src.ops.numpy import get_item as get_item\nfrom keras.src.ops.numpy import greater as greater\nfrom keras.src.ops.numpy import greater_equal as greater_equal\nfrom keras.src.ops.numpy import hamming as hamming\nfrom keras.src.ops.numpy import hanning as hanning\nfrom keras.src.ops.numpy import heaviside as heaviside\nfrom keras.src.ops.numpy import histogram as histogram\nfrom keras.src.ops.numpy import hsplit as hsplit\nfrom keras.src.ops.numpy import hstack as hstack\nfrom keras.src.ops.numpy import hypot as hypot\nfrom keras.src.ops.numpy import identity as identity\nfrom keras.src.ops.numpy import imag as imag\nfrom keras.src.ops.numpy import inner as inner\nfrom keras.src.ops.numpy import isclose as isclose\nfrom keras.src.ops.numpy import isfinite as isfinite\nfrom keras.src.ops.numpy import isin as isin\nfrom keras.src.ops.numpy import isinf as isinf\nfrom keras.src.ops.numpy import isnan as isnan\nfrom keras.src.ops.numpy import isneginf as isneginf\nfrom keras.src.ops.numpy import isposinf as isposinf\nfrom keras.src.ops.numpy import isreal as isreal\nfrom keras.src.ops.numpy import kaiser as kaiser\nfrom keras.src.ops.numpy import kron as kron\nfrom keras.src.ops.numpy import lcm as lcm\nfrom keras.src.ops.numpy import ldexp as ldexp\nfrom keras.src.ops.numpy import left_shift as left_shift\nfrom keras.src.ops.numpy import less as less\nfrom keras.src.ops.numpy import less_equal as less_equal\nfrom keras.src.ops.numpy import linspace as linspace\nfrom 
keras.src.ops.numpy import log as log\nfrom keras.src.ops.numpy import log1p as log1p\nfrom keras.src.ops.numpy import log2 as log2\nfrom keras.src.ops.numpy import log10 as log10\nfrom keras.src.ops.numpy import logaddexp as logaddexp\nfrom keras.src.ops.numpy import logaddexp2 as logaddexp2\nfrom keras.src.ops.numpy import logical_and as logical_and\nfrom keras.src.ops.numpy import logical_not as logical_not\nfrom keras.src.ops.numpy import logical_or as logical_or\nfrom keras.src.ops.numpy import logical_xor as logical_xor\nfrom keras.src.ops.numpy import logspace as logspace\nfrom keras.src.ops.numpy import matmul as matmul\nfrom keras.src.ops.numpy import max as max\nfrom keras.src.ops.numpy import maximum as maximum\nfrom keras.src.ops.numpy import mean as mean\nfrom keras.src.ops.numpy import median as median\nfrom keras.src.ops.numpy import meshgrid as meshgrid\nfrom keras.src.ops.numpy import min as min\nfrom keras.src.ops.numpy import minimum as minimum\nfrom keras.src.ops.numpy import mod as mod\nfrom keras.src.ops.numpy import moveaxis as moveaxis\nfrom keras.src.ops.numpy import multiply as multiply\nfrom keras.src.ops.numpy import nan_to_num as nan_to_num\nfrom keras.src.ops.numpy import nanargmax as nanargmax\nfrom keras.src.ops.numpy import nanargmin as nanargmin\nfrom keras.src.ops.numpy import nancumprod as nancumprod\nfrom keras.src.ops.numpy import nancumsum as nancumsum\nfrom keras.src.ops.numpy import nanmax as nanmax\nfrom keras.src.ops.numpy import nanmean as nanmean\nfrom keras.src.ops.numpy import nanmin as nanmin\nfrom keras.src.ops.numpy import nanprod as nanprod\nfrom keras.src.ops.numpy import nanstd as nanstd\nfrom keras.src.ops.numpy import nansum as nansum\nfrom keras.src.ops.numpy import nanvar as nanvar\nfrom keras.src.ops.numpy import ndim as ndim\nfrom keras.src.ops.numpy import negative as negative\nfrom keras.src.ops.numpy import nextafter as nextafter\nfrom keras.src.ops.numpy import nonzero as nonzero\nfrom 
keras.src.ops.numpy import not_equal as not_equal\nfrom keras.src.ops.numpy import ones as ones\nfrom keras.src.ops.numpy import ones_like as ones_like\nfrom keras.src.ops.numpy import outer as outer\nfrom keras.src.ops.numpy import pad as pad\nfrom keras.src.ops.numpy import power as power\nfrom keras.src.ops.numpy import prod as prod\nfrom keras.src.ops.numpy import ptp as ptp\nfrom keras.src.ops.numpy import quantile as quantile\nfrom keras.src.ops.numpy import ravel as ravel\nfrom keras.src.ops.numpy import real as real\nfrom keras.src.ops.numpy import reciprocal as reciprocal\nfrom keras.src.ops.numpy import repeat as repeat\nfrom keras.src.ops.numpy import reshape as reshape\nfrom keras.src.ops.numpy import right_shift as right_shift\nfrom keras.src.ops.numpy import roll as roll\nfrom keras.src.ops.numpy import rot90 as rot90\nfrom keras.src.ops.numpy import round as round\nfrom keras.src.ops.numpy import searchsorted as searchsorted\nfrom keras.src.ops.numpy import select as select\nfrom keras.src.ops.numpy import sign as sign\nfrom keras.src.ops.numpy import signbit as signbit\nfrom keras.src.ops.numpy import sin as sin\nfrom keras.src.ops.numpy import sinc as sinc\nfrom keras.src.ops.numpy import sinh as sinh\nfrom keras.src.ops.numpy import size as size\nfrom keras.src.ops.numpy import slogdet as slogdet\nfrom keras.src.ops.numpy import sort as sort\nfrom keras.src.ops.numpy import split as split\nfrom keras.src.ops.numpy import sqrt as sqrt\nfrom keras.src.ops.numpy import square as square\nfrom keras.src.ops.numpy import squeeze as squeeze\nfrom keras.src.ops.numpy import stack as stack\nfrom keras.src.ops.numpy import std as std\nfrom keras.src.ops.numpy import subtract as subtract\nfrom keras.src.ops.numpy import sum as sum\nfrom keras.src.ops.numpy import swapaxes as swapaxes\nfrom keras.src.ops.numpy import take as take\nfrom keras.src.ops.numpy import take_along_axis as take_along_axis\nfrom keras.src.ops.numpy import tan as tan\nfrom 
keras.src.ops.numpy import tanh as tanh\nfrom keras.src.ops.numpy import tensordot as tensordot\nfrom keras.src.ops.numpy import tile as tile\nfrom keras.src.ops.numpy import trace as trace\nfrom keras.src.ops.numpy import transpose as transpose\nfrom keras.src.ops.numpy import trapezoid as trapezoid\nfrom keras.src.ops.numpy import tri as tri\nfrom keras.src.ops.numpy import tril as tril\nfrom keras.src.ops.numpy import triu as triu\nfrom keras.src.ops.numpy import true_divide as true_divide\nfrom keras.src.ops.numpy import trunc as trunc\nfrom keras.src.ops.numpy import unravel_index as unravel_index\nfrom keras.src.ops.numpy import vander as vander\nfrom keras.src.ops.numpy import var as var\nfrom keras.src.ops.numpy import vdot as vdot\nfrom keras.src.ops.numpy import vectorize as vectorize\nfrom keras.src.ops.numpy import view as view\nfrom keras.src.ops.numpy import vsplit as vsplit\nfrom keras.src.ops.numpy import vstack as vstack\nfrom keras.src.ops.numpy import where as where\nfrom keras.src.ops.numpy import zeros as zeros\nfrom keras.src.ops.numpy import zeros_like as zeros_like\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/ops/image/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.ops.image import affine_transform as affine_transform\nfrom keras.src.ops.image import crop_images as crop_images\nfrom keras.src.ops.image import elastic_transform as elastic_transform\nfrom keras.src.ops.image import extract_patches as extract_patches\nfrom keras.src.ops.image import extract_patches_3d as extract_patches_3d\nfrom keras.src.ops.image import gaussian_blur as gaussian_blur\nfrom keras.src.ops.image import hsv_to_rgb as hsv_to_rgb\nfrom keras.src.ops.image import map_coordinates as map_coordinates\nfrom keras.src.ops.image import pad_images as pad_images\nfrom keras.src.ops.image import perspective_transform as perspective_transform\nfrom keras.src.ops.image import resize as resize\nfrom keras.src.ops.image import rgb_to_grayscale as rgb_to_grayscale\nfrom keras.src.ops.image import rgb_to_hsv as rgb_to_hsv\nfrom keras.src.ops.image import scale_and_translate as scale_and_translate\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/ops/linalg/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.ops.linalg import cholesky as cholesky\nfrom keras.src.ops.linalg import cholesky_inverse as cholesky_inverse\nfrom keras.src.ops.linalg import det as det\nfrom keras.src.ops.linalg import eig as eig\nfrom keras.src.ops.linalg import eigh as eigh\nfrom keras.src.ops.linalg import inv as inv\nfrom keras.src.ops.linalg import jvp as jvp\nfrom keras.src.ops.linalg import lstsq as lstsq\nfrom keras.src.ops.linalg import lu_factor as lu_factor\nfrom keras.src.ops.linalg import norm as norm\nfrom keras.src.ops.linalg import qr as qr\nfrom keras.src.ops.linalg import solve as solve\nfrom keras.src.ops.linalg import solve_triangular as solve_triangular\nfrom keras.src.ops.linalg import svd as svd\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/ops/nn/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool\nfrom keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool\nfrom keras.src.ops.nn import average_pool as average_pool\nfrom keras.src.ops.nn import batch_normalization as batch_normalization\nfrom keras.src.ops.nn import binary_crossentropy as binary_crossentropy\nfrom keras.src.ops.nn import (\n    categorical_crossentropy as categorical_crossentropy,\n)\nfrom keras.src.ops.nn import celu as celu\nfrom keras.src.ops.nn import conv as conv\nfrom keras.src.ops.nn import conv_transpose as conv_transpose\nfrom keras.src.ops.nn import ctc_decode as ctc_decode\nfrom keras.src.ops.nn import ctc_loss as ctc_loss\nfrom keras.src.ops.nn import depth_to_space as depth_to_space\nfrom keras.src.ops.nn import depthwise_conv as depthwise_conv\nfrom keras.src.ops.nn import dot_product_attention as dot_product_attention\nfrom keras.src.ops.nn import elu as elu\nfrom keras.src.ops.nn import fold as fold\nfrom keras.src.ops.nn import gelu as gelu\nfrom keras.src.ops.nn import glu as glu\nfrom keras.src.ops.nn import hard_shrink as hard_shrink\nfrom keras.src.ops.nn import hard_sigmoid as hard_sigmoid\nfrom keras.src.ops.nn import hard_silu as hard_silu\nfrom keras.src.ops.nn import hard_silu as hard_swish\nfrom keras.src.ops.nn import hard_tanh as hard_tanh\nfrom keras.src.ops.nn import layer_normalization as layer_normalization\nfrom keras.src.ops.nn import leaky_relu as leaky_relu\nfrom keras.src.ops.nn import log_sigmoid as log_sigmoid\nfrom keras.src.ops.nn import log_softmax as log_softmax\nfrom keras.src.ops.nn import max_pool as max_pool\nfrom keras.src.ops.nn import moments as moments\nfrom keras.src.ops.nn import multi_hot as multi_hot\nfrom keras.src.ops.nn import normalize as normalize\nfrom keras.src.ops.nn import one_hot as one_hot\nfrom 
keras.src.ops.nn import polar as polar\nfrom keras.src.ops.nn import psnr as psnr\nfrom keras.src.ops.nn import relu as relu\nfrom keras.src.ops.nn import relu6 as relu6\nfrom keras.src.ops.nn import rms_normalization as rms_normalization\nfrom keras.src.ops.nn import selu as selu\nfrom keras.src.ops.nn import separable_conv as separable_conv\nfrom keras.src.ops.nn import sigmoid as sigmoid\nfrom keras.src.ops.nn import silu as silu\nfrom keras.src.ops.nn import silu as swish\nfrom keras.src.ops.nn import soft_shrink as soft_shrink\nfrom keras.src.ops.nn import softmax as softmax\nfrom keras.src.ops.nn import softplus as softplus\nfrom keras.src.ops.nn import softsign as softsign\nfrom keras.src.ops.nn import space_to_depth as space_to_depth\nfrom keras.src.ops.nn import (\n    sparse_categorical_crossentropy as sparse_categorical_crossentropy,\n)\nfrom keras.src.ops.nn import sparse_plus as sparse_plus\nfrom keras.src.ops.nn import sparse_sigmoid as sparse_sigmoid\nfrom keras.src.ops.nn import sparsemax as sparsemax\nfrom keras.src.ops.nn import squareplus as squareplus\nfrom keras.src.ops.nn import tanh_shrink as tanh_shrink\nfrom keras.src.ops.nn import threshold as threshold\nfrom keras.src.ops.nn import unfold as unfold\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/ops/numpy/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.ops.numpy import abs as abs\nfrom keras.src.ops.numpy import absolute as absolute\nfrom keras.src.ops.numpy import add as add\nfrom keras.src.ops.numpy import all as all\nfrom keras.src.ops.numpy import allclose as allclose\nfrom keras.src.ops.numpy import amax as amax\nfrom keras.src.ops.numpy import amin as amin\nfrom keras.src.ops.numpy import angle as angle\nfrom keras.src.ops.numpy import any as any\nfrom keras.src.ops.numpy import append as append\nfrom keras.src.ops.numpy import arange as arange\nfrom keras.src.ops.numpy import arccos as arccos\nfrom keras.src.ops.numpy import arccosh as arccosh\nfrom keras.src.ops.numpy import arcsin as arcsin\nfrom keras.src.ops.numpy import arcsinh as arcsinh\nfrom keras.src.ops.numpy import arctan as arctan\nfrom keras.src.ops.numpy import arctan2 as arctan2\nfrom keras.src.ops.numpy import arctanh as arctanh\nfrom keras.src.ops.numpy import argmax as argmax\nfrom keras.src.ops.numpy import argmin as argmin\nfrom keras.src.ops.numpy import argpartition as argpartition\nfrom keras.src.ops.numpy import argsort as argsort\nfrom keras.src.ops.numpy import array as array\nfrom keras.src.ops.numpy import array_split as array_split\nfrom keras.src.ops.numpy import average as average\nfrom keras.src.ops.numpy import bartlett as bartlett\nfrom keras.src.ops.numpy import bincount as bincount\nfrom keras.src.ops.numpy import bitwise_and as bitwise_and\nfrom keras.src.ops.numpy import bitwise_invert as bitwise_invert\nfrom keras.src.ops.numpy import bitwise_left_shift as bitwise_left_shift\nfrom keras.src.ops.numpy import bitwise_not as bitwise_not\nfrom keras.src.ops.numpy import bitwise_or as bitwise_or\nfrom keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift\nfrom keras.src.ops.numpy import bitwise_xor as bitwise_xor\nfrom keras.src.ops.numpy import 
blackman as blackman\nfrom keras.src.ops.numpy import broadcast_to as broadcast_to\nfrom keras.src.ops.numpy import cbrt as cbrt\nfrom keras.src.ops.numpy import ceil as ceil\nfrom keras.src.ops.numpy import clip as clip\nfrom keras.src.ops.numpy import concatenate as concatenate\nfrom keras.src.ops.numpy import conj as conj\nfrom keras.src.ops.numpy import conjugate as conjugate\nfrom keras.src.ops.numpy import copy as copy\nfrom keras.src.ops.numpy import corrcoef as corrcoef\nfrom keras.src.ops.numpy import correlate as correlate\nfrom keras.src.ops.numpy import cos as cos\nfrom keras.src.ops.numpy import cosh as cosh\nfrom keras.src.ops.numpy import count_nonzero as count_nonzero\nfrom keras.src.ops.numpy import cross as cross\nfrom keras.src.ops.numpy import cumprod as cumprod\nfrom keras.src.ops.numpy import cumsum as cumsum\nfrom keras.src.ops.numpy import deg2rad as deg2rad\nfrom keras.src.ops.numpy import diag as diag\nfrom keras.src.ops.numpy import diagflat as diagflat\nfrom keras.src.ops.numpy import diagonal as diagonal\nfrom keras.src.ops.numpy import diff as diff\nfrom keras.src.ops.numpy import digitize as digitize\nfrom keras.src.ops.numpy import divide as divide\nfrom keras.src.ops.numpy import divide_no_nan as divide_no_nan\nfrom keras.src.ops.numpy import dot as dot\nfrom keras.src.ops.numpy import dstack as dstack\nfrom keras.src.ops.numpy import einsum as einsum\nfrom keras.src.ops.numpy import empty as empty\nfrom keras.src.ops.numpy import empty_like as empty_like\nfrom keras.src.ops.numpy import equal as equal\nfrom keras.src.ops.numpy import exp as exp\nfrom keras.src.ops.numpy import exp2 as exp2\nfrom keras.src.ops.numpy import expand_dims as expand_dims\nfrom keras.src.ops.numpy import expm1 as expm1\nfrom keras.src.ops.numpy import eye as eye\nfrom keras.src.ops.numpy import flip as flip\nfrom keras.src.ops.numpy import floor as floor\nfrom keras.src.ops.numpy import floor_divide as floor_divide\nfrom keras.src.ops.numpy import fmod as 
fmod\nfrom keras.src.ops.numpy import full as full\nfrom keras.src.ops.numpy import full_like as full_like\nfrom keras.src.ops.numpy import gcd as gcd\nfrom keras.src.ops.numpy import geomspace as geomspace\nfrom keras.src.ops.numpy import get_item as get_item\nfrom keras.src.ops.numpy import greater as greater\nfrom keras.src.ops.numpy import greater_equal as greater_equal\nfrom keras.src.ops.numpy import hamming as hamming\nfrom keras.src.ops.numpy import hanning as hanning\nfrom keras.src.ops.numpy import heaviside as heaviside\nfrom keras.src.ops.numpy import histogram as histogram\nfrom keras.src.ops.numpy import hsplit as hsplit\nfrom keras.src.ops.numpy import hstack as hstack\nfrom keras.src.ops.numpy import hypot as hypot\nfrom keras.src.ops.numpy import identity as identity\nfrom keras.src.ops.numpy import imag as imag\nfrom keras.src.ops.numpy import inner as inner\nfrom keras.src.ops.numpy import isclose as isclose\nfrom keras.src.ops.numpy import isfinite as isfinite\nfrom keras.src.ops.numpy import isin as isin\nfrom keras.src.ops.numpy import isinf as isinf\nfrom keras.src.ops.numpy import isnan as isnan\nfrom keras.src.ops.numpy import isneginf as isneginf\nfrom keras.src.ops.numpy import isposinf as isposinf\nfrom keras.src.ops.numpy import isreal as isreal\nfrom keras.src.ops.numpy import kaiser as kaiser\nfrom keras.src.ops.numpy import kron as kron\nfrom keras.src.ops.numpy import lcm as lcm\nfrom keras.src.ops.numpy import ldexp as ldexp\nfrom keras.src.ops.numpy import left_shift as left_shift\nfrom keras.src.ops.numpy import less as less\nfrom keras.src.ops.numpy import less_equal as less_equal\nfrom keras.src.ops.numpy import linspace as linspace\nfrom keras.src.ops.numpy import log as log\nfrom keras.src.ops.numpy import log1p as log1p\nfrom keras.src.ops.numpy import log2 as log2\nfrom keras.src.ops.numpy import log10 as log10\nfrom keras.src.ops.numpy import logaddexp as logaddexp\nfrom keras.src.ops.numpy import logaddexp2 as 
logaddexp2\nfrom keras.src.ops.numpy import logical_and as logical_and\nfrom keras.src.ops.numpy import logical_not as logical_not\nfrom keras.src.ops.numpy import logical_or as logical_or\nfrom keras.src.ops.numpy import logical_xor as logical_xor\nfrom keras.src.ops.numpy import logspace as logspace\nfrom keras.src.ops.numpy import matmul as matmul\nfrom keras.src.ops.numpy import max as max\nfrom keras.src.ops.numpy import maximum as maximum\nfrom keras.src.ops.numpy import mean as mean\nfrom keras.src.ops.numpy import median as median\nfrom keras.src.ops.numpy import meshgrid as meshgrid\nfrom keras.src.ops.numpy import min as min\nfrom keras.src.ops.numpy import minimum as minimum\nfrom keras.src.ops.numpy import mod as mod\nfrom keras.src.ops.numpy import moveaxis as moveaxis\nfrom keras.src.ops.numpy import multiply as multiply\nfrom keras.src.ops.numpy import nan_to_num as nan_to_num\nfrom keras.src.ops.numpy import nanargmax as nanargmax\nfrom keras.src.ops.numpy import nanargmin as nanargmin\nfrom keras.src.ops.numpy import nancumprod as nancumprod\nfrom keras.src.ops.numpy import nancumsum as nancumsum\nfrom keras.src.ops.numpy import nanmax as nanmax\nfrom keras.src.ops.numpy import nanmean as nanmean\nfrom keras.src.ops.numpy import nanmin as nanmin\nfrom keras.src.ops.numpy import nanprod as nanprod\nfrom keras.src.ops.numpy import nanstd as nanstd\nfrom keras.src.ops.numpy import nansum as nansum\nfrom keras.src.ops.numpy import nanvar as nanvar\nfrom keras.src.ops.numpy import ndim as ndim\nfrom keras.src.ops.numpy import negative as negative\nfrom keras.src.ops.numpy import nextafter as nextafter\nfrom keras.src.ops.numpy import nonzero as nonzero\nfrom keras.src.ops.numpy import not_equal as not_equal\nfrom keras.src.ops.numpy import ones as ones\nfrom keras.src.ops.numpy import ones_like as ones_like\nfrom keras.src.ops.numpy import outer as outer\nfrom keras.src.ops.numpy import pad as pad\nfrom keras.src.ops.numpy import power as power\nfrom 
keras.src.ops.numpy import prod as prod\nfrom keras.src.ops.numpy import ptp as ptp\nfrom keras.src.ops.numpy import quantile as quantile\nfrom keras.src.ops.numpy import ravel as ravel\nfrom keras.src.ops.numpy import real as real\nfrom keras.src.ops.numpy import reciprocal as reciprocal\nfrom keras.src.ops.numpy import repeat as repeat\nfrom keras.src.ops.numpy import reshape as reshape\nfrom keras.src.ops.numpy import right_shift as right_shift\nfrom keras.src.ops.numpy import roll as roll\nfrom keras.src.ops.numpy import rot90 as rot90\nfrom keras.src.ops.numpy import round as round\nfrom keras.src.ops.numpy import searchsorted as searchsorted\nfrom keras.src.ops.numpy import select as select\nfrom keras.src.ops.numpy import sign as sign\nfrom keras.src.ops.numpy import signbit as signbit\nfrom keras.src.ops.numpy import sin as sin\nfrom keras.src.ops.numpy import sinc as sinc\nfrom keras.src.ops.numpy import sinh as sinh\nfrom keras.src.ops.numpy import size as size\nfrom keras.src.ops.numpy import slogdet as slogdet\nfrom keras.src.ops.numpy import sort as sort\nfrom keras.src.ops.numpy import split as split\nfrom keras.src.ops.numpy import sqrt as sqrt\nfrom keras.src.ops.numpy import square as square\nfrom keras.src.ops.numpy import squeeze as squeeze\nfrom keras.src.ops.numpy import stack as stack\nfrom keras.src.ops.numpy import std as std\nfrom keras.src.ops.numpy import subtract as subtract\nfrom keras.src.ops.numpy import sum as sum\nfrom keras.src.ops.numpy import swapaxes as swapaxes\nfrom keras.src.ops.numpy import take as take\nfrom keras.src.ops.numpy import take_along_axis as take_along_axis\nfrom keras.src.ops.numpy import tan as tan\nfrom keras.src.ops.numpy import tanh as tanh\nfrom keras.src.ops.numpy import tensordot as tensordot\nfrom keras.src.ops.numpy import tile as tile\nfrom keras.src.ops.numpy import trace as trace\nfrom keras.src.ops.numpy import transpose as transpose\nfrom keras.src.ops.numpy import trapezoid as trapezoid\nfrom 
keras.src.ops.numpy import tri as tri\nfrom keras.src.ops.numpy import tril as tril\nfrom keras.src.ops.numpy import triu as triu\nfrom keras.src.ops.numpy import true_divide as true_divide\nfrom keras.src.ops.numpy import trunc as trunc\nfrom keras.src.ops.numpy import unravel_index as unravel_index\nfrom keras.src.ops.numpy import vander as vander\nfrom keras.src.ops.numpy import var as var\nfrom keras.src.ops.numpy import vdot as vdot\nfrom keras.src.ops.numpy import vectorize as vectorize\nfrom keras.src.ops.numpy import view as view\nfrom keras.src.ops.numpy import vsplit as vsplit\nfrom keras.src.ops.numpy import vstack as vstack\nfrom keras.src.ops.numpy import where as where\nfrom keras.src.ops.numpy import zeros as zeros\nfrom keras.src.ops.numpy import zeros_like as zeros_like\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/optimizers/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.optimizers import legacy as legacy\nfrom keras.optimizers import schedules as schedules\nfrom keras.src.optimizers import deserialize as deserialize\nfrom keras.src.optimizers import get as get\nfrom keras.src.optimizers import serialize as serialize\nfrom keras.src.optimizers.adadelta import Adadelta as Adadelta\nfrom keras.src.optimizers.adafactor import Adafactor as Adafactor\nfrom keras.src.optimizers.adagrad import Adagrad as Adagrad\nfrom keras.src.optimizers.adam import Adam as Adam\nfrom keras.src.optimizers.adamax import Adamax as Adamax\nfrom keras.src.optimizers.adamw import AdamW as AdamW\nfrom keras.src.optimizers.ftrl import Ftrl as Ftrl\nfrom keras.src.optimizers.lamb import Lamb as Lamb\nfrom keras.src.optimizers.lion import Lion as Lion\nfrom keras.src.optimizers.loss_scale_optimizer import (\n    LossScaleOptimizer as LossScaleOptimizer,\n)\nfrom keras.src.optimizers.muon import Muon as Muon\nfrom keras.src.optimizers.nadam import Nadam as Nadam\nfrom keras.src.optimizers.optimizer import Optimizer as Optimizer\nfrom keras.src.optimizers.rmsprop import RMSprop as RMSprop\nfrom keras.src.optimizers.schedule_free_adamw import (\n    ScheduleFreeAdamW as ScheduleFreeAdamW,\n)\nfrom keras.src.optimizers.sgd import SGD as SGD\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/optimizers/legacy/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.optimizers import LegacyOptimizerWarning as Adagrad\nfrom keras.src.optimizers import LegacyOptimizerWarning as Adam\nfrom keras.src.optimizers import LegacyOptimizerWarning as Ftrl\nfrom keras.src.optimizers import LegacyOptimizerWarning as Optimizer\nfrom keras.src.optimizers import LegacyOptimizerWarning as RMSprop\nfrom keras.src.optimizers import LegacyOptimizerWarning as SGD\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/optimizers/schedules/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    CosineDecay as CosineDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    CosineDecayRestarts as CosineDecayRestarts,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    ExponentialDecay as ExponentialDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    InverseTimeDecay as InverseTimeDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    LearningRateSchedule as LearningRateSchedule,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    PiecewiseConstantDecay as PiecewiseConstantDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    PolynomialDecay as PolynomialDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    deserialize as deserialize,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    serialize as serialize,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/preprocessing/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras._tf_keras.keras.preprocessing import image as image\nfrom keras._tf_keras.keras.preprocessing import sequence as sequence\nfrom keras._tf_keras.keras.preprocessing import text as text\nfrom keras.src.utils.image_dataset_utils import (\n    image_dataset_from_directory as image_dataset_from_directory,\n)\nfrom keras.src.utils.text_dataset_utils import (\n    text_dataset_from_directory as text_dataset_from_directory,\n)\nfrom keras.src.utils.timeseries_dataset_utils import (\n    timeseries_dataset_from_array as timeseries_dataset_from_array,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/preprocessing/image/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.legacy.preprocessing.image import (\n    DirectoryIterator as DirectoryIterator,\n)\nfrom keras.src.legacy.preprocessing.image import (\n    ImageDataGenerator as ImageDataGenerator,\n)\nfrom keras.src.legacy.preprocessing.image import Iterator as Iterator\nfrom keras.src.legacy.preprocessing.image import (\n    NumpyArrayIterator as NumpyArrayIterator,\n)\nfrom keras.src.legacy.preprocessing.image import (\n    apply_affine_transform as apply_affine_transform,\n)\nfrom keras.src.legacy.preprocessing.image import (\n    apply_brightness_shift as apply_brightness_shift,\n)\nfrom keras.src.legacy.preprocessing.image import (\n    apply_channel_shift as apply_channel_shift,\n)\nfrom keras.src.legacy.preprocessing.image import (\n    random_brightness as random_brightness,\n)\nfrom keras.src.legacy.preprocessing.image import (\n    random_channel_shift as random_channel_shift,\n)\nfrom keras.src.legacy.preprocessing.image import (\n    random_rotation as random_rotation,\n)\nfrom keras.src.legacy.preprocessing.image import random_shear as random_shear\nfrom keras.src.legacy.preprocessing.image import random_shift as random_shift\nfrom keras.src.legacy.preprocessing.image import random_zoom as random_zoom\nfrom keras.src.utils.image_utils import array_to_img as array_to_img\nfrom keras.src.utils.image_utils import img_to_array as img_to_array\nfrom keras.src.utils.image_utils import load_img as load_img\nfrom keras.src.utils.image_utils import save_img as save_img\nfrom keras.src.utils.image_utils import smart_resize as smart_resize\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.legacy.preprocessing.sequence import (\n    TimeseriesGenerator as TimeseriesGenerator,\n)\nfrom keras.src.legacy.preprocessing.sequence import (\n    make_sampling_table as make_sampling_table,\n)\nfrom keras.src.legacy.preprocessing.sequence import skipgrams as skipgrams\nfrom keras.src.utils.sequence_utils import pad_sequences as pad_sequences\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/preprocessing/text/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.legacy.preprocessing.text import Tokenizer as Tokenizer\nfrom keras.src.legacy.preprocessing.text import hashing_trick as hashing_trick\nfrom keras.src.legacy.preprocessing.text import one_hot as one_hot\nfrom keras.src.legacy.preprocessing.text import (\n    text_to_word_sequence as text_to_word_sequence,\n)\nfrom keras.src.legacy.preprocessing.text import (\n    tokenizer_from_json as tokenizer_from_json,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/quantizers/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.quantizers import deserialize as deserialize\nfrom keras.src.quantizers import get as get\nfrom keras.src.quantizers import serialize as serialize\nfrom keras.src.quantizers.awq_config import AWQConfig as AWQConfig\nfrom keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig\nfrom keras.src.quantizers.quantization_config import (\n    Float8QuantizationConfig as Float8QuantizationConfig,\n)\nfrom keras.src.quantizers.quantization_config import (\n    Int4QuantizationConfig as Int4QuantizationConfig,\n)\nfrom keras.src.quantizers.quantization_config import (\n    Int8QuantizationConfig as Int8QuantizationConfig,\n)\nfrom keras.src.quantizers.quantization_config import (\n    QuantizationConfig as QuantizationConfig,\n)\nfrom keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer\nfrom keras.src.quantizers.quantizers import Quantizer as Quantizer\nfrom keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize\nfrom keras.src.quantizers.quantizers import (\n    abs_max_quantize_grouped_with_zero_point as abs_max_quantize_grouped_with_zero_point,\n)\nfrom keras.src.quantizers.quantizers import (\n    compute_float8_amax_history as compute_float8_amax_history,\n)\nfrom keras.src.quantizers.quantizers import (\n    compute_float8_scale as compute_float8_scale,\n)\nfrom keras.src.quantizers.quantizers import (\n    fake_quant_with_min_max_vars as fake_quant_with_min_max_vars,\n)\nfrom keras.src.quantizers.quantizers import pack_int4 as pack_int4\nfrom keras.src.quantizers.quantizers import (\n    quantize_and_dequantize as quantize_and_dequantize,\n)\nfrom keras.src.quantizers.quantizers import unpack_int4 as unpack_int4\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/random/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.random.random import beta as beta\nfrom keras.src.random.random import binomial as binomial\nfrom keras.src.random.random import categorical as categorical\nfrom keras.src.random.random import dropout as dropout\nfrom keras.src.random.random import gamma as gamma\nfrom keras.src.random.random import normal as normal\nfrom keras.src.random.random import randint as randint\nfrom keras.src.random.random import shuffle as shuffle\nfrom keras.src.random.random import truncated_normal as truncated_normal\nfrom keras.src.random.random import uniform as uniform\nfrom keras.src.random.seed_generator import SeedGenerator as SeedGenerator\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/regularizers/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.regularizers import deserialize as deserialize\nfrom keras.src.regularizers import get as get\nfrom keras.src.regularizers import serialize as serialize\nfrom keras.src.regularizers.regularizers import L1 as L1\nfrom keras.src.regularizers.regularizers import L1 as l1\nfrom keras.src.regularizers.regularizers import L1L2 as L1L2\nfrom keras.src.regularizers.regularizers import L1L2 as l1_l2\nfrom keras.src.regularizers.regularizers import L2 as L2\nfrom keras.src.regularizers.regularizers import L2 as l2\nfrom keras.src.regularizers.regularizers import (\n    OrthogonalRegularizer as OrthogonalRegularizer,\n)\nfrom keras.src.regularizers.regularizers import (\n    OrthogonalRegularizer as orthogonal_regularizer,\n)\nfrom keras.src.regularizers.regularizers import Regularizer as Regularizer\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/saving/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.saving.file_editor import KerasFileEditor as KerasFileEditor\nfrom keras.src.saving.object_registration import (\n    CustomObjectScope as CustomObjectScope,\n)\nfrom keras.src.saving.object_registration import (\n    CustomObjectScope as custom_object_scope,\n)\nfrom keras.src.saving.object_registration import (\n    get_custom_objects as get_custom_objects,\n)\nfrom keras.src.saving.object_registration import (\n    get_registered_name as get_registered_name,\n)\nfrom keras.src.saving.object_registration import (\n    get_registered_object as get_registered_object,\n)\nfrom keras.src.saving.object_registration import (\n    register_keras_serializable as register_keras_serializable,\n)\nfrom keras.src.saving.saving_api import load_model as load_model\nfrom keras.src.saving.saving_api import load_weights as load_weights\nfrom keras.src.saving.saving_api import save_model as save_model\nfrom keras.src.saving.saving_api import save_weights as save_weights\nfrom keras.src.saving.serialization_lib import (\n    deserialize_keras_object as deserialize_keras_object,\n)\nfrom keras.src.saving.serialization_lib import (\n    serialize_keras_object as serialize_keras_object,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/tree/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.tree.tree_api import MAP_TO_NONE as MAP_TO_NONE\nfrom keras.src.tree.tree_api import assert_same_paths as assert_same_paths\nfrom keras.src.tree.tree_api import (\n    assert_same_structure as assert_same_structure,\n)\nfrom keras.src.tree.tree_api import flatten as flatten\nfrom keras.src.tree.tree_api import flatten_with_path as flatten_with_path\nfrom keras.src.tree.tree_api import is_nested as is_nested\nfrom keras.src.tree.tree_api import lists_to_tuples as lists_to_tuples\nfrom keras.src.tree.tree_api import map_shape_structure as map_shape_structure\nfrom keras.src.tree.tree_api import map_structure as map_structure\nfrom keras.src.tree.tree_api import map_structure_up_to as map_structure_up_to\nfrom keras.src.tree.tree_api import pack_sequence_as as pack_sequence_as\nfrom keras.src.tree.tree_api import traverse as traverse\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/utils/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.backend.common.global_state import clear_session as clear_session\nfrom keras.src.backend.common.keras_tensor import (\n    is_keras_tensor as is_keras_tensor,\n)\nfrom keras.src.backend.common.variables import (\n    standardize_dtype as standardize_dtype,\n)\nfrom keras.src.layers.preprocessing.feature_space import (\n    FeatureSpace as FeatureSpace,\n)\nfrom keras.src.ops.operation_utils import get_source_inputs as get_source_inputs\nfrom keras.src.saving.object_registration import (\n    CustomObjectScope as CustomObjectScope,\n)\nfrom keras.src.saving.object_registration import (\n    CustomObjectScope as custom_object_scope,\n)\nfrom keras.src.saving.object_registration import (\n    get_custom_objects as get_custom_objects,\n)\nfrom keras.src.saving.object_registration import (\n    get_registered_name as get_registered_name,\n)\nfrom keras.src.saving.object_registration import (\n    get_registered_object as get_registered_object,\n)\nfrom keras.src.saving.object_registration import (\n    register_keras_serializable as register_keras_serializable,\n)\nfrom keras.src.saving.serialization_lib import (\n    deserialize_keras_object as deserialize_keras_object,\n)\nfrom keras.src.saving.serialization_lib import (\n    serialize_keras_object as serialize_keras_object,\n)\nfrom keras.src.trainers.data_adapters.data_adapter_utils import (\n    pack_x_y_sample_weight as pack_x_y_sample_weight,\n)\nfrom keras.src.trainers.data_adapters.data_adapter_utils import (\n    unpack_x_y_sample_weight as unpack_x_y_sample_weight,\n)\nfrom keras.src.trainers.data_adapters.py_dataset_adapter import (\n    PyDataset as PyDataset,\n)\nfrom keras.src.trainers.data_adapters.py_dataset_adapter import (\n    PyDataset as Sequence,\n)\nfrom keras.src.utils.audio_dataset_utils import (\n    audio_dataset_from_directory 
as audio_dataset_from_directory,\n)\nfrom keras.src.utils.config import Config as Config\nfrom keras.src.utils.dataset_utils import split_dataset as split_dataset\nfrom keras.src.utils.file_utils import get_file as get_file\nfrom keras.src.utils.image_dataset_utils import (\n    image_dataset_from_directory as image_dataset_from_directory,\n)\nfrom keras.src.utils.image_utils import array_to_img as array_to_img\nfrom keras.src.utils.image_utils import img_to_array as img_to_array\nfrom keras.src.utils.image_utils import load_img as load_img\nfrom keras.src.utils.image_utils import save_img as save_img\nfrom keras.src.utils.io_utils import (\n    disable_interactive_logging as disable_interactive_logging,\n)\nfrom keras.src.utils.io_utils import (\n    enable_interactive_logging as enable_interactive_logging,\n)\nfrom keras.src.utils.io_utils import (\n    is_interactive_logging_enabled as is_interactive_logging_enabled,\n)\nfrom keras.src.utils.model_visualization import model_to_dot as model_to_dot\nfrom keras.src.utils.model_visualization import plot_model as plot_model\nfrom keras.src.utils.numerical_utils import normalize as normalize\nfrom keras.src.utils.numerical_utils import to_categorical as to_categorical\nfrom keras.src.utils.progbar import Progbar as Progbar\nfrom keras.src.utils.rng_utils import set_random_seed as set_random_seed\nfrom keras.src.utils.sequence_utils import pad_sequences as pad_sequences\nfrom keras.src.utils.text_dataset_utils import (\n    text_dataset_from_directory as text_dataset_from_directory,\n)\nfrom keras.src.utils.timeseries_dataset_utils import (\n    timeseries_dataset_from_array as timeseries_dataset_from_array,\n)\nfrom keras.utils import bounding_boxes as bounding_boxes\nfrom keras.utils import legacy as legacy\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/utils/bounding_boxes/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    affine_transform as affine_transform,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    clip_to_image_size as clip_to_image_size,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    convert_format as convert_format,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    crop as crop,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    decode_deltas_to_boxes as decode_deltas_to_boxes,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    encode_box_to_deltas as encode_box_to_deltas,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    pad as pad,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.iou import (\n    compute_ciou as compute_ciou,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.iou import (\n    compute_iou as compute_iou,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/utils/legacy/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.legacy.saving.serialization import (\n    deserialize_keras_object as deserialize_keras_object,\n)\nfrom keras.src.legacy.saving.serialization import (\n    serialize_keras_object as serialize_keras_object,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/visualization/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.visualization.draw_bounding_boxes import (\n    draw_bounding_boxes as draw_bounding_boxes,\n)\nfrom keras.src.visualization.draw_segmentation_masks import (\n    draw_segmentation_masks as draw_segmentation_masks,\n)\nfrom keras.src.visualization.plot_bounding_box_gallery import (\n    plot_bounding_box_gallery as plot_bounding_box_gallery,\n)\nfrom keras.src.visualization.plot_image_gallery import (\n    plot_image_gallery as plot_image_gallery,\n)\nfrom keras.src.visualization.plot_segmentation_mask_gallery import (\n    plot_segmentation_mask_gallery as plot_segmentation_mask_gallery,\n)\n"
  },
  {
    "path": "keras/api/_tf_keras/keras/wrappers/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.wrappers.sklearn_wrapper import (\n    SKLearnClassifier as SKLearnClassifier,\n)\nfrom keras.src.wrappers.sklearn_wrapper import (\n    SKLearnRegressor as SKLearnRegressor,\n)\nfrom keras.src.wrappers.sklearn_wrapper import (\n    SKLearnTransformer as SKLearnTransformer,\n)\n"
  },
  {
    "path": "keras/api/activations/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.activations import deserialize as deserialize\nfrom keras.src.activations import get as get\nfrom keras.src.activations import serialize as serialize\nfrom keras.src.activations.activations import celu as celu\nfrom keras.src.activations.activations import elu as elu\nfrom keras.src.activations.activations import exponential as exponential\nfrom keras.src.activations.activations import gelu as gelu\nfrom keras.src.activations.activations import glu as glu\nfrom keras.src.activations.activations import hard_shrink as hard_shrink\nfrom keras.src.activations.activations import hard_sigmoid as hard_sigmoid\nfrom keras.src.activations.activations import hard_silu as hard_silu\nfrom keras.src.activations.activations import hard_silu as hard_swish\nfrom keras.src.activations.activations import hard_tanh as hard_tanh\nfrom keras.src.activations.activations import leaky_relu as leaky_relu\nfrom keras.src.activations.activations import linear as linear\nfrom keras.src.activations.activations import log_sigmoid as log_sigmoid\nfrom keras.src.activations.activations import log_softmax as log_softmax\nfrom keras.src.activations.activations import mish as mish\nfrom keras.src.activations.activations import relu as relu\nfrom keras.src.activations.activations import relu6 as relu6\nfrom keras.src.activations.activations import selu as selu\nfrom keras.src.activations.activations import sigmoid as sigmoid\nfrom keras.src.activations.activations import silu as silu\nfrom keras.src.activations.activations import silu as swish\nfrom keras.src.activations.activations import soft_shrink as soft_shrink\nfrom keras.src.activations.activations import softmax as softmax\nfrom keras.src.activations.activations import softplus as softplus\nfrom keras.src.activations.activations import softsign as softsign\nfrom 
keras.src.activations.activations import sparse_plus as sparse_plus\nfrom keras.src.activations.activations import sparse_sigmoid as sparse_sigmoid\nfrom keras.src.activations.activations import sparsemax as sparsemax\nfrom keras.src.activations.activations import squareplus as squareplus\nfrom keras.src.activations.activations import tanh as tanh\nfrom keras.src.activations.activations import tanh_shrink as tanh_shrink\nfrom keras.src.activations.activations import threshold as threshold\n"
  },
  {
    "path": "keras/api/applications/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.applications import convnext as convnext\nfrom keras.applications import densenet as densenet\nfrom keras.applications import efficientnet as efficientnet\nfrom keras.applications import efficientnet_v2 as efficientnet_v2\nfrom keras.applications import imagenet_utils as imagenet_utils\nfrom keras.applications import inception_resnet_v2 as inception_resnet_v2\nfrom keras.applications import inception_v3 as inception_v3\nfrom keras.applications import mobilenet as mobilenet\nfrom keras.applications import mobilenet_v2 as mobilenet_v2\nfrom keras.applications import mobilenet_v3 as mobilenet_v3\nfrom keras.applications import nasnet as nasnet\nfrom keras.applications import resnet as resnet\nfrom keras.applications import resnet50 as resnet50\nfrom keras.applications import resnet_v2 as resnet_v2\nfrom keras.applications import vgg16 as vgg16\nfrom keras.applications import vgg19 as vgg19\nfrom keras.applications import xception as xception\nfrom keras.src.applications.convnext import ConvNeXtBase as ConvNeXtBase\nfrom keras.src.applications.convnext import ConvNeXtLarge as ConvNeXtLarge\nfrom keras.src.applications.convnext import ConvNeXtSmall as ConvNeXtSmall\nfrom keras.src.applications.convnext import ConvNeXtTiny as ConvNeXtTiny\nfrom keras.src.applications.convnext import ConvNeXtXLarge as ConvNeXtXLarge\nfrom keras.src.applications.densenet import DenseNet121 as DenseNet121\nfrom keras.src.applications.densenet import DenseNet169 as DenseNet169\nfrom keras.src.applications.densenet import DenseNet201 as DenseNet201\nfrom keras.src.applications.efficientnet import EfficientNetB0 as EfficientNetB0\nfrom keras.src.applications.efficientnet import EfficientNetB1 as EfficientNetB1\nfrom keras.src.applications.efficientnet import EfficientNetB2 as EfficientNetB2\nfrom keras.src.applications.efficientnet 
import EfficientNetB3 as EfficientNetB3\nfrom keras.src.applications.efficientnet import EfficientNetB4 as EfficientNetB4\nfrom keras.src.applications.efficientnet import EfficientNetB5 as EfficientNetB5\nfrom keras.src.applications.efficientnet import EfficientNetB6 as EfficientNetB6\nfrom keras.src.applications.efficientnet import EfficientNetB7 as EfficientNetB7\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B0 as EfficientNetV2B0,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B1 as EfficientNetV2B1,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B2 as EfficientNetV2B2,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B3 as EfficientNetV2B3,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2L as EfficientNetV2L,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2M as EfficientNetV2M,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2S as EfficientNetV2S,\n)\nfrom keras.src.applications.inception_resnet_v2 import (\n    InceptionResNetV2 as InceptionResNetV2,\n)\nfrom keras.src.applications.inception_v3 import InceptionV3 as InceptionV3\nfrom keras.src.applications.mobilenet import MobileNet as MobileNet\nfrom keras.src.applications.mobilenet_v2 import MobileNetV2 as MobileNetV2\nfrom keras.src.applications.mobilenet_v3 import (\n    MobileNetV3Large as MobileNetV3Large,\n)\nfrom keras.src.applications.mobilenet_v3 import (\n    MobileNetV3Small as MobileNetV3Small,\n)\nfrom keras.src.applications.nasnet import NASNetLarge as NASNetLarge\nfrom keras.src.applications.nasnet import NASNetMobile as NASNetMobile\nfrom keras.src.applications.resnet import ResNet50 as ResNet50\nfrom keras.src.applications.resnet import ResNet101 as ResNet101\nfrom keras.src.applications.resnet import ResNet152 as ResNet152\nfrom keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2\nfrom 
keras.src.applications.resnet_v2 import ResNet101V2 as ResNet101V2\nfrom keras.src.applications.resnet_v2 import ResNet152V2 as ResNet152V2\nfrom keras.src.applications.vgg16 import VGG16 as VGG16\nfrom keras.src.applications.vgg19 import VGG19 as VGG19\nfrom keras.src.applications.xception import Xception as Xception\n"
  },
  {
    "path": "keras/api/applications/convnext/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.convnext import ConvNeXtBase as ConvNeXtBase\nfrom keras.src.applications.convnext import ConvNeXtLarge as ConvNeXtLarge\nfrom keras.src.applications.convnext import ConvNeXtSmall as ConvNeXtSmall\nfrom keras.src.applications.convnext import ConvNeXtTiny as ConvNeXtTiny\nfrom keras.src.applications.convnext import ConvNeXtXLarge as ConvNeXtXLarge\nfrom keras.src.applications.convnext import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.convnext import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/applications/densenet/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.densenet import DenseNet121 as DenseNet121\nfrom keras.src.applications.densenet import DenseNet169 as DenseNet169\nfrom keras.src.applications.densenet import DenseNet201 as DenseNet201\nfrom keras.src.applications.densenet import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.densenet import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/applications/efficientnet/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.efficientnet import EfficientNetB0 as EfficientNetB0\nfrom keras.src.applications.efficientnet import EfficientNetB1 as EfficientNetB1\nfrom keras.src.applications.efficientnet import EfficientNetB2 as EfficientNetB2\nfrom keras.src.applications.efficientnet import EfficientNetB3 as EfficientNetB3\nfrom keras.src.applications.efficientnet import EfficientNetB4 as EfficientNetB4\nfrom keras.src.applications.efficientnet import EfficientNetB5 as EfficientNetB5\nfrom keras.src.applications.efficientnet import EfficientNetB6 as EfficientNetB6\nfrom keras.src.applications.efficientnet import EfficientNetB7 as EfficientNetB7\nfrom keras.src.applications.efficientnet import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.efficientnet import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/applications/efficientnet_v2/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B0 as EfficientNetV2B0,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B1 as EfficientNetV2B1,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B2 as EfficientNetV2B2,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2B3 as EfficientNetV2B3,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2L as EfficientNetV2L,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2M as EfficientNetV2M,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    EfficientNetV2S as EfficientNetV2S,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.efficientnet_v2 import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/applications/imagenet_utils/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.imagenet_utils import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.imagenet_utils import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/applications/inception_resnet_v2/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.inception_resnet_v2 import (\n    InceptionResNetV2 as InceptionResNetV2,\n)\nfrom keras.src.applications.inception_resnet_v2 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.inception_resnet_v2 import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/applications/inception_v3/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.inception_v3 import InceptionV3 as InceptionV3\nfrom keras.src.applications.inception_v3 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.inception_v3 import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/applications/mobilenet/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.mobilenet import MobileNet as MobileNet\nfrom keras.src.applications.mobilenet import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.mobilenet import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/applications/mobilenet_v2/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.mobilenet_v2 import MobileNetV2 as MobileNetV2\nfrom keras.src.applications.mobilenet_v2 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.mobilenet_v2 import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/applications/mobilenet_v3/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.mobilenet_v3 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.mobilenet_v3 import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/applications/nasnet/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.nasnet import NASNetLarge as NASNetLarge\nfrom keras.src.applications.nasnet import NASNetMobile as NASNetMobile\nfrom keras.src.applications.nasnet import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.nasnet import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/applications/resnet/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.resnet import ResNet50 as ResNet50\nfrom keras.src.applications.resnet import ResNet101 as ResNet101\nfrom keras.src.applications.resnet import ResNet152 as ResNet152\nfrom keras.src.applications.resnet import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.resnet import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/applications/resnet50/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.resnet import ResNet50 as ResNet50\nfrom keras.src.applications.resnet import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.resnet import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/applications/resnet_v2/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.resnet_v2 import ResNet50V2 as ResNet50V2\nfrom keras.src.applications.resnet_v2 import ResNet101V2 as ResNet101V2\nfrom keras.src.applications.resnet_v2 import ResNet152V2 as ResNet152V2\nfrom keras.src.applications.resnet_v2 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.resnet_v2 import (\n    preprocess_input as preprocess_input,\n)\n"
  },
  {
    "path": "keras/api/applications/vgg16/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.vgg16 import VGG16 as VGG16\nfrom keras.src.applications.vgg16 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.vgg16 import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/applications/vgg19/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.vgg19 import VGG19 as VGG19\nfrom keras.src.applications.vgg19 import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.vgg19 import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/applications/xception/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.applications.xception import Xception as Xception\nfrom keras.src.applications.xception import (\n    decode_predictions as decode_predictions,\n)\nfrom keras.src.applications.xception import preprocess_input as preprocess_input\n"
  },
  {
    "path": "keras/api/backend/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.backend.common.dtypes import result_type as result_type\nfrom keras.src.backend.common.global_state import clear_session as clear_session\nfrom keras.src.backend.common.keras_tensor import (\n    is_keras_tensor as is_keras_tensor,\n)\nfrom keras.src.backend.common.variables import is_float_dtype as is_float_dtype\nfrom keras.src.backend.common.variables import is_int_dtype as is_int_dtype\nfrom keras.src.backend.common.variables import (\n    standardize_dtype as standardize_dtype,\n)\nfrom keras.src.backend.config import backend as backend\nfrom keras.src.backend.config import epsilon as epsilon\nfrom keras.src.backend.config import floatx as floatx\nfrom keras.src.backend.config import image_data_format as image_data_format\nfrom keras.src.backend.config import set_epsilon as set_epsilon\nfrom keras.src.backend.config import set_floatx as set_floatx\nfrom keras.src.backend.config import (\n    set_image_data_format as set_image_data_format,\n)\nfrom keras.src.utils.naming import get_uid as get_uid\n"
  },
  {
    "path": "keras/api/callbacks/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.callbacks.backup_and_restore import (\n    BackupAndRestore as BackupAndRestore,\n)\nfrom keras.src.callbacks.callback import Callback as Callback\nfrom keras.src.callbacks.callback_list import CallbackList as CallbackList\nfrom keras.src.callbacks.csv_logger import CSVLogger as CSVLogger\nfrom keras.src.callbacks.early_stopping import EarlyStopping as EarlyStopping\nfrom keras.src.callbacks.history import History as History\nfrom keras.src.callbacks.lambda_callback import LambdaCallback as LambdaCallback\nfrom keras.src.callbacks.learning_rate_scheduler import (\n    LearningRateScheduler as LearningRateScheduler,\n)\nfrom keras.src.callbacks.model_checkpoint import (\n    ModelCheckpoint as ModelCheckpoint,\n)\nfrom keras.src.callbacks.orbax_checkpoint import (\n    OrbaxCheckpoint as OrbaxCheckpoint,\n)\nfrom keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger\nfrom keras.src.callbacks.reduce_lr_on_plateau import (\n    ReduceLROnPlateau as ReduceLROnPlateau,\n)\nfrom keras.src.callbacks.remote_monitor import RemoteMonitor as RemoteMonitor\nfrom keras.src.callbacks.swap_ema_weights import (\n    SwapEMAWeights as SwapEMAWeights,\n)\nfrom keras.src.callbacks.tensorboard import TensorBoard as TensorBoard\nfrom keras.src.callbacks.terminate_on_nan import (\n    TerminateOnNaN as TerminateOnNaN,\n)\n"
  },
  {
    "path": "keras/api/config/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.backend.config import backend as backend\nfrom keras.src.backend.config import (\n    disable_flash_attention as disable_flash_attention,\n)\nfrom keras.src.backend.config import (\n    enable_flash_attention as enable_flash_attention,\n)\nfrom keras.src.backend.config import epsilon as epsilon\nfrom keras.src.backend.config import floatx as floatx\nfrom keras.src.backend.config import image_data_format as image_data_format\nfrom keras.src.backend.config import (\n    is_flash_attention_enabled as is_flash_attention_enabled,\n)\nfrom keras.src.backend.config import is_nnx_enabled as is_nnx_enabled\nfrom keras.src.backend.config import max_epochs as max_epochs\nfrom keras.src.backend.config import max_steps_per_epoch as max_steps_per_epoch\nfrom keras.src.backend.config import set_epsilon as set_epsilon\nfrom keras.src.backend.config import set_floatx as set_floatx\nfrom keras.src.backend.config import (\n    set_image_data_format as set_image_data_format,\n)\nfrom keras.src.backend.config import set_max_epochs as set_max_epochs\nfrom keras.src.backend.config import (\n    set_max_steps_per_epoch as set_max_steps_per_epoch,\n)\nfrom keras.src.dtype_policies.dtype_policy import dtype_policy as dtype_policy\nfrom keras.src.dtype_policies.dtype_policy import (\n    set_dtype_policy as set_dtype_policy,\n)\nfrom keras.src.saving.serialization_lib import (\n    enable_unsafe_deserialization as enable_unsafe_deserialization,\n)\nfrom keras.src.utils.backend_utils import set_backend as set_backend\nfrom keras.src.utils.io_utils import (\n    disable_interactive_logging as disable_interactive_logging,\n)\nfrom keras.src.utils.io_utils import (\n    enable_interactive_logging as enable_interactive_logging,\n)\nfrom keras.src.utils.io_utils import (\n    is_interactive_logging_enabled as is_interactive_logging_enabled,\n)\nfrom keras.src.utils.traceback_utils import (\n    disable_traceback_filtering as disable_traceback_filtering,\n)\nfrom keras.src.utils.traceback_utils import (\n    enable_traceback_filtering as enable_traceback_filtering,\n)\nfrom keras.src.utils.traceback_utils import (\n    is_traceback_filtering_enabled as is_traceback_filtering_enabled,\n)\n"
  },
  {
    "path": "keras/api/constraints/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.constraints import deserialize as deserialize\nfrom keras.src.constraints import get as get\nfrom keras.src.constraints import serialize as serialize\nfrom keras.src.constraints.constraints import Constraint as Constraint\nfrom keras.src.constraints.constraints import MaxNorm as MaxNorm\nfrom keras.src.constraints.constraints import MaxNorm as max_norm\nfrom keras.src.constraints.constraints import MinMaxNorm as MinMaxNorm\nfrom keras.src.constraints.constraints import MinMaxNorm as min_max_norm\nfrom keras.src.constraints.constraints import NonNeg as NonNeg\nfrom keras.src.constraints.constraints import NonNeg as non_neg\nfrom keras.src.constraints.constraints import UnitNorm as UnitNorm\nfrom keras.src.constraints.constraints import UnitNorm as unit_norm\n"
  },
  {
    "path": "keras/api/datasets/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.datasets import boston_housing as boston_housing\nfrom keras.datasets import california_housing as california_housing\nfrom keras.datasets import cifar10 as cifar10\nfrom keras.datasets import cifar100 as cifar100\nfrom keras.datasets import fashion_mnist as fashion_mnist\nfrom keras.datasets import imdb as imdb\nfrom keras.datasets import mnist as mnist\nfrom keras.datasets import reuters as reuters\n"
  },
  {
    "path": "keras/api/datasets/boston_housing/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.boston_housing import load_data as load_data\n"
  },
  {
    "path": "keras/api/datasets/california_housing/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.california_housing import load_data as load_data\n"
  },
  {
    "path": "keras/api/datasets/cifar10/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.cifar10 import load_data as load_data\n"
  },
  {
    "path": "keras/api/datasets/cifar100/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.cifar100 import load_data as load_data\n"
  },
  {
    "path": "keras/api/datasets/fashion_mnist/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.fashion_mnist import load_data as load_data\n"
  },
  {
    "path": "keras/api/datasets/imdb/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.imdb import get_word_index as get_word_index\nfrom keras.src.datasets.imdb import load_data as load_data\n"
  },
  {
    "path": "keras/api/datasets/mnist/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.mnist import load_data as load_data\n"
  },
  {
    "path": "keras/api/datasets/reuters/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.datasets.reuters import get_label_names as get_label_names\nfrom keras.src.datasets.reuters import get_word_index as get_word_index\nfrom keras.src.datasets.reuters import load_data as load_data\n"
  },
  {
    "path": "keras/api/distillation/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.distillation.distillation_loss import (\n    DistillationLoss as DistillationLoss,\n)\nfrom keras.src.distillation.distillation_loss import (\n    FeatureDistillation as FeatureDistillation,\n)\nfrom keras.src.distillation.distillation_loss import (\n    LogitsDistillation as LogitsDistillation,\n)\nfrom keras.src.distillation.distiller import Distiller as Distiller\n"
  },
  {
    "path": "keras/api/distribution/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.distribution.distribution_lib import DataParallel as DataParallel\nfrom keras.src.distribution.distribution_lib import DeviceMesh as DeviceMesh\nfrom keras.src.distribution.distribution_lib import LayoutMap as LayoutMap\nfrom keras.src.distribution.distribution_lib import (\n    ModelParallel as ModelParallel,\n)\nfrom keras.src.distribution.distribution_lib import TensorLayout as TensorLayout\nfrom keras.src.distribution.distribution_lib import (\n    distribute_tensor as distribute_tensor,\n)\nfrom keras.src.distribution.distribution_lib import distribution as distribution\nfrom keras.src.distribution.distribution_lib import (\n    get_device_count as get_device_count,\n)\nfrom keras.src.distribution.distribution_lib import initialize as initialize\nfrom keras.src.distribution.distribution_lib import list_devices as list_devices\nfrom keras.src.distribution.distribution_lib import (\n    set_distribution as set_distribution,\n)\n"
  },
  {
    "path": "keras/api/dtype_policies/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.dtype_policies import deserialize as deserialize\nfrom keras.src.dtype_policies import get as get\nfrom keras.src.dtype_policies import serialize as serialize\nfrom keras.src.dtype_policies.dtype_policy import (\n    AWQDTypePolicy as AWQDTypePolicy,\n)\nfrom keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import (\n    FloatDTypePolicy as FloatDTypePolicy,\n)\nfrom keras.src.dtype_policies.dtype_policy import (\n    GPTQDTypePolicy as GPTQDTypePolicy,\n)\nfrom keras.src.dtype_policies.dtype_policy import (\n    Int4DTypePolicy as Int4DTypePolicy,\n)\nfrom keras.src.dtype_policies.dtype_policy import (\n    QuantizedDTypePolicy as QuantizedDTypePolicy,\n)\nfrom keras.src.dtype_policies.dtype_policy import (\n    QuantizedFloat8DTypePolicy as QuantizedFloat8DTypePolicy,\n)\nfrom keras.src.dtype_policies.dtype_policy_map import (\n    DTypePolicyMap as DTypePolicyMap,\n)\n"
  },
  {
    "path": "keras/api/export/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.export.saved_model import ExportArchive as ExportArchive\n"
  },
  {
    "path": "keras/api/initializers/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.initializers import deserialize as deserialize\nfrom keras.src.initializers import get as get\nfrom keras.src.initializers import serialize as serialize\nfrom keras.src.initializers.constant_initializers import STFT as STFT\nfrom keras.src.initializers.constant_initializers import STFT as STFTInitializer\nfrom keras.src.initializers.constant_initializers import STFT as stft\nfrom keras.src.initializers.constant_initializers import Constant as Constant\nfrom keras.src.initializers.constant_initializers import Constant as constant\nfrom keras.src.initializers.constant_initializers import Identity as Identity\nfrom keras.src.initializers.constant_initializers import (\n    Identity as IdentityInitializer,\n)\nfrom keras.src.initializers.constant_initializers import Identity as identity\nfrom keras.src.initializers.constant_initializers import Ones as Ones\nfrom keras.src.initializers.constant_initializers import Ones as ones\nfrom keras.src.initializers.constant_initializers import Zeros as Zeros\nfrom keras.src.initializers.constant_initializers import Zeros as zeros\nfrom keras.src.initializers.initializer import Initializer as Initializer\nfrom keras.src.initializers.random_initializers import (\n    GlorotNormal as GlorotNormal,\n)\nfrom keras.src.initializers.random_initializers import (\n    GlorotNormal as glorot_normal,\n)\nfrom keras.src.initializers.random_initializers import (\n    GlorotUniform as GlorotUniform,\n)\nfrom keras.src.initializers.random_initializers import (\n    GlorotUniform as glorot_uniform,\n)\nfrom keras.src.initializers.random_initializers import HeNormal as HeNormal\nfrom keras.src.initializers.random_initializers import HeNormal as he_normal\nfrom keras.src.initializers.random_initializers import HeUniform as HeUniform\nfrom keras.src.initializers.random_initializers import HeUniform as he_uniform\nfrom keras.src.initializers.random_initializers import (\n    LecunNormal as LecunNormal,\n)\nfrom keras.src.initializers.random_initializers import (\n    LecunNormal as lecun_normal,\n)\nfrom keras.src.initializers.random_initializers import (\n    LecunUniform as LecunUniform,\n)\nfrom keras.src.initializers.random_initializers import (\n    LecunUniform as lecun_uniform,\n)\nfrom keras.src.initializers.random_initializers import Orthogonal as Orthogonal\nfrom keras.src.initializers.random_initializers import (\n    Orthogonal as OrthogonalInitializer,\n)\nfrom keras.src.initializers.random_initializers import Orthogonal as orthogonal\nfrom keras.src.initializers.random_initializers import (\n    RandomNormal as RandomNormal,\n)\nfrom keras.src.initializers.random_initializers import (\n    RandomNormal as random_normal,\n)\nfrom keras.src.initializers.random_initializers import (\n    RandomUniform as RandomUniform,\n)\nfrom keras.src.initializers.random_initializers import (\n    RandomUniform as random_uniform,\n)\nfrom keras.src.initializers.random_initializers import (\n    TruncatedNormal as TruncatedNormal,\n)\nfrom keras.src.initializers.random_initializers import (\n    TruncatedNormal as truncated_normal,\n)\nfrom keras.src.initializers.random_initializers import (\n    VarianceScaling as VarianceScaling,\n)\nfrom keras.src.initializers.random_initializers import (\n    VarianceScaling as variance_scaling,\n)\n"
  },
  {
    "path": "keras/api/layers/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.export.tfsm_layer import TFSMLayer as TFSMLayer\nfrom keras.src.layers import deserialize as deserialize\nfrom keras.src.layers import serialize as serialize\nfrom keras.src.layers.activations.activation import Activation as Activation\nfrom keras.src.layers.activations.elu import ELU as ELU\nfrom keras.src.layers.activations.leaky_relu import LeakyReLU as LeakyReLU\nfrom keras.src.layers.activations.prelu import PReLU as PReLU\nfrom keras.src.layers.activations.relu import ReLU as ReLU\nfrom keras.src.layers.activations.softmax import Softmax as Softmax\nfrom keras.src.layers.attention.additive_attention import (\n    AdditiveAttention as AdditiveAttention,\n)\nfrom keras.src.layers.attention.attention import Attention as Attention\nfrom keras.src.layers.attention.grouped_query_attention import (\n    GroupedQueryAttention as GroupQueryAttention,\n)\nfrom keras.src.layers.attention.multi_head_attention import (\n    MultiHeadAttention as MultiHeadAttention,\n)\nfrom keras.src.layers.convolutional.conv1d import Conv1D as Conv1D\nfrom keras.src.layers.convolutional.conv1d import Conv1D as Convolution1D\nfrom keras.src.layers.convolutional.conv1d_transpose import (\n    Conv1DTranspose as Conv1DTranspose,\n)\nfrom keras.src.layers.convolutional.conv1d_transpose import (\n    Conv1DTranspose as Convolution1DTranspose,\n)\nfrom keras.src.layers.convolutional.conv2d import Conv2D as Conv2D\nfrom keras.src.layers.convolutional.conv2d import Conv2D as Convolution2D\nfrom keras.src.layers.convolutional.conv2d_transpose import (\n    Conv2DTranspose as Conv2DTranspose,\n)\nfrom keras.src.layers.convolutional.conv2d_transpose import (\n    Conv2DTranspose as Convolution2DTranspose,\n)\nfrom keras.src.layers.convolutional.conv3d import Conv3D as Conv3D\nfrom keras.src.layers.convolutional.conv3d import Conv3D as Convolution3D\nfrom keras.src.layers.convolutional.conv3d_transpose import (\n    Conv3DTranspose as Conv3DTranspose,\n)\nfrom keras.src.layers.convolutional.conv3d_transpose import (\n    Conv3DTranspose as Convolution3DTranspose,\n)\nfrom keras.src.layers.convolutional.depthwise_conv1d import (\n    DepthwiseConv1D as DepthwiseConv1D,\n)\nfrom keras.src.layers.convolutional.depthwise_conv2d import (\n    DepthwiseConv2D as DepthwiseConv2D,\n)\nfrom keras.src.layers.convolutional.separable_conv1d import (\n    SeparableConv1D as SeparableConv1D,\n)\nfrom keras.src.layers.convolutional.separable_conv1d import (\n    SeparableConv1D as SeparableConvolution1D,\n)\nfrom keras.src.layers.convolutional.separable_conv2d import (\n    SeparableConv2D as SeparableConv2D,\n)\nfrom keras.src.layers.convolutional.separable_conv2d import (\n    SeparableConv2D as SeparableConvolution2D,\n)\nfrom keras.src.layers.core.dense import Dense as Dense\nfrom keras.src.layers.core.einsum_dense import EinsumDense as EinsumDense\nfrom keras.src.layers.core.embedding import Embedding as Embedding\nfrom keras.src.layers.core.identity import Identity as Identity\nfrom keras.src.layers.core.input_layer import Input as Input\nfrom keras.src.layers.core.input_layer import InputLayer as InputLayer\nfrom keras.src.layers.core.lambda_layer import Lambda as Lambda\nfrom keras.src.layers.core.masking import Masking as Masking\nfrom keras.src.layers.core.reversible_embedding import (\n    ReversibleEmbedding as ReversibleEmbedding,\n)\nfrom keras.src.layers.core.wrapper import Wrapper as Wrapper\nfrom keras.src.layers.input_spec import InputSpec as InputSpec\nfrom keras.src.layers.layer import Layer as Layer\nfrom keras.src.layers.merging.add import Add as Add\nfrom keras.src.layers.merging.add import add as add\nfrom keras.src.layers.merging.average import Average as Average\nfrom keras.src.layers.merging.average import average as average\nfrom keras.src.layers.merging.concatenate import Concatenate as Concatenate\nfrom keras.src.layers.merging.concatenate import concatenate as concatenate\nfrom keras.src.layers.merging.dot import Dot as Dot\nfrom keras.src.layers.merging.dot import dot as dot\nfrom keras.src.layers.merging.maximum import Maximum as Maximum\nfrom keras.src.layers.merging.maximum import maximum as maximum\nfrom keras.src.layers.merging.minimum import Minimum as Minimum\nfrom keras.src.layers.merging.minimum import minimum as minimum\nfrom keras.src.layers.merging.multiply import Multiply as Multiply\nfrom keras.src.layers.merging.multiply import multiply as multiply\nfrom keras.src.layers.merging.subtract import Subtract as Subtract\nfrom keras.src.layers.merging.subtract import subtract as subtract\nfrom keras.src.layers.normalization.batch_normalization import (\n    BatchNormalization as BatchNormalization,\n)\nfrom keras.src.layers.normalization.group_normalization import (\n    GroupNormalization as GroupNormalization,\n)\nfrom keras.src.layers.normalization.layer_normalization import (\n    LayerNormalization as LayerNormalization,\n)\nfrom keras.src.layers.normalization.rms_normalization import (\n    RMSNormalization as RMSNormalization,\n)\nfrom keras.src.layers.normalization.spectral_normalization import (\n    SpectralNormalization as SpectralNormalization,\n)\nfrom keras.src.layers.normalization.unit_normalization import (\n    UnitNormalization as UnitNormalization,\n)\nfrom keras.src.layers.pooling.adaptive_average_pooling1d import (\n    AdaptiveAveragePooling1D as AdaptiveAveragePooling1D,\n)\nfrom keras.src.layers.pooling.adaptive_average_pooling2d import (\n    AdaptiveAveragePooling2D as AdaptiveAveragePooling2D,\n)\nfrom keras.src.layers.pooling.adaptive_average_pooling3d import (\n    AdaptiveAveragePooling3D as AdaptiveAveragePooling3D,\n)\nfrom keras.src.layers.pooling.adaptive_max_pooling1d import (\n    AdaptiveMaxPooling1D as AdaptiveMaxPooling1D,\n)\nfrom keras.src.layers.pooling.adaptive_max_pooling2d import (\n    AdaptiveMaxPooling2D as AdaptiveMaxPooling2D,\n)\nfrom keras.src.layers.pooling.adaptive_max_pooling3d import (\n    AdaptiveMaxPooling3D as AdaptiveMaxPooling3D,\n)\nfrom keras.src.layers.pooling.average_pooling1d import (\n    AveragePooling1D as AveragePooling1D,\n)\nfrom keras.src.layers.pooling.average_pooling1d import (\n    AveragePooling1D as AvgPool1D,\n)\nfrom keras.src.layers.pooling.average_pooling2d import (\n    AveragePooling2D as AveragePooling2D,\n)\nfrom keras.src.layers.pooling.average_pooling2d import (\n    AveragePooling2D as AvgPool2D,\n)\nfrom keras.src.layers.pooling.average_pooling3d import (\n    AveragePooling3D as AveragePooling3D,\n)\nfrom keras.src.layers.pooling.average_pooling3d import (\n    AveragePooling3D as AvgPool3D,\n)\nfrom keras.src.layers.pooling.global_average_pooling1d import (\n    GlobalAveragePooling1D as GlobalAveragePooling1D,\n)\nfrom keras.src.layers.pooling.global_average_pooling1d import (\n    GlobalAveragePooling1D as GlobalAvgPool1D,\n)\nfrom keras.src.layers.pooling.global_average_pooling2d import (\n    GlobalAveragePooling2D as GlobalAveragePooling2D,\n)\nfrom keras.src.layers.pooling.global_average_pooling2d import (\n    GlobalAveragePooling2D as GlobalAvgPool2D,\n)\nfrom keras.src.layers.pooling.global_average_pooling3d import (\n    GlobalAveragePooling3D as GlobalAveragePooling3D,\n)\nfrom keras.src.layers.pooling.global_average_pooling3d import (\n    GlobalAveragePooling3D as GlobalAvgPool3D,\n)\nfrom keras.src.layers.pooling.global_max_pooling1d import (\n    GlobalMaxPooling1D as GlobalMaxPool1D,\n)\nfrom keras.src.layers.pooling.global_max_pooling1d import (\n    GlobalMaxPooling1D as GlobalMaxPooling1D,\n)\nfrom keras.src.layers.pooling.global_max_pooling2d import (\n    GlobalMaxPooling2D as GlobalMaxPool2D,\n)\nfrom keras.src.layers.pooling.global_max_pooling2d import (\n    GlobalMaxPooling2D as GlobalMaxPooling2D,\n)\nfrom keras.src.layers.pooling.global_max_pooling3d import (\n    GlobalMaxPooling3D as GlobalMaxPool3D,\n)\nfrom keras.src.layers.pooling.global_max_pooling3d import (\n    GlobalMaxPooling3D as GlobalMaxPooling3D,\n)\nfrom keras.src.layers.pooling.max_pooling1d import MaxPooling1D as MaxPool1D\nfrom keras.src.layers.pooling.max_pooling1d import MaxPooling1D as MaxPooling1D\nfrom keras.src.layers.pooling.max_pooling2d import MaxPooling2D as MaxPool2D\nfrom keras.src.layers.pooling.max_pooling2d import MaxPooling2D as MaxPooling2D\nfrom keras.src.layers.pooling.max_pooling3d import MaxPooling3D as MaxPool3D\nfrom keras.src.layers.pooling.max_pooling3d import MaxPooling3D as MaxPooling3D\nfrom keras.src.layers.preprocessing.category_encoding import (\n    CategoryEncoding as CategoryEncoding,\n)\nfrom keras.src.layers.preprocessing.discretization import (\n    Discretization as Discretization,\n)\nfrom keras.src.layers.preprocessing.hashed_crossing import (\n    HashedCrossing as HashedCrossing,\n)\nfrom keras.src.layers.preprocessing.hashing import Hashing as Hashing\nfrom keras.src.layers.preprocessing.image_preprocessing.aug_mix import (\n    AugMix as AugMix,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.auto_contrast import (\n    AutoContrast as AutoContrast,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.center_crop import (\n    CenterCrop as CenterCrop,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.clahe import (\n    ContrastLimitedAdaptiveHistogramEqualization as ContrastLimitedAdaptiveHistogramEqualization,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.cut_mix import (\n    CutMix as CutMix,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.equalization import (\n    Equalization as Equalization,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.max_num_bounding_box import (\n    MaxNumBoundingBoxes as MaxNumBoundingBoxes,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.mix_up import (\n    MixUp as MixUp,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.rand_augment import (\n    RandAugment as RandAugment,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_brightness import (\n    RandomBrightness as RandomBrightness,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (\n    RandomColorDegeneration as RandomColorDegeneration,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (\n    RandomColorJitter as RandomColorJitter,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_contrast import (\n    RandomContrast as RandomContrast,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_crop import (\n    RandomCrop as RandomCrop,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_elastic_transform import (\n    RandomElasticTransform as RandomElasticTransform,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_erasing import (\n    RandomErasing as RandomErasing,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_flip import (\n    RandomFlip as RandomFlip,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_gaussian_blur import (\n    RandomGaussianBlur as RandomGaussianBlur,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (\n    RandomGrayscale as RandomGrayscale,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_hue import (\n    RandomHue as RandomHue,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_invert import (\n    RandomInvert as RandomInvert,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_perspective import (\n    RandomPerspective as RandomPerspective,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_posterization import (\n    RandomPosterization as RandomPosterization,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_rotation import (\n    RandomRotation as RandomRotation,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_saturation import (\n    RandomSaturation as RandomSaturation,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_sharpness import (\n    RandomSharpness as RandomSharpness,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_shear import (\n    RandomShear as RandomShear,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_translation import (\n    RandomTranslation as RandomTranslation,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_zoom import (\n    RandomZoom as RandomZoom,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.resizing import (\n    Resizing as Resizing,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.solarization import (\n    Solarization as Solarization,\n)\nfrom keras.src.layers.preprocessing.integer_lookup import (\n    IntegerLookup as IntegerLookup,\n)\nfrom keras.src.layers.preprocessing.mel_spectrogram import (\n    MelSpectrogram as MelSpectrogram,\n)\nfrom keras.src.layers.preprocessing.normalization import (\n    Normalization as Normalization,\n)\nfrom keras.src.layers.preprocessing.pipeline import Pipeline as Pipeline\nfrom keras.src.layers.preprocessing.rescaling import Rescaling as Rescaling\nfrom keras.src.layers.preprocessing.stft_spectrogram import (\n    STFTSpectrogram as STFTSpectrogram,\n)\nfrom keras.src.layers.preprocessing.string_lookup import (\n    StringLookup as StringLookup,\n)\nfrom keras.src.layers.preprocessing.text_vectorization import (\n    TextVectorization as TextVectorization,\n)\nfrom keras.src.layers.regularization.activity_regularization import (\n    ActivityRegularization as ActivityRegularization,\n)\nfrom keras.src.layers.regularization.alpha_dropout import (\n    AlphaDropout as AlphaDropout,\n)\nfrom keras.src.layers.regularization.dropout import Dropout as Dropout\nfrom 
keras.src.layers.regularization.gaussian_dropout import (\n    GaussianDropout as GaussianDropout,\n)\nfrom keras.src.layers.regularization.gaussian_noise import (\n    GaussianNoise as GaussianNoise,\n)\nfrom keras.src.layers.regularization.spatial_dropout import (\n    SpatialDropout1D as SpatialDropout1D,\n)\nfrom keras.src.layers.regularization.spatial_dropout import (\n    SpatialDropout2D as SpatialDropout2D,\n)\nfrom keras.src.layers.regularization.spatial_dropout import (\n    SpatialDropout3D as SpatialDropout3D,\n)\nfrom keras.src.layers.reshaping.cropping1d import Cropping1D as Cropping1D\nfrom keras.src.layers.reshaping.cropping2d import Cropping2D as Cropping2D\nfrom keras.src.layers.reshaping.cropping3d import Cropping3D as Cropping3D\nfrom keras.src.layers.reshaping.flatten import Flatten as Flatten\nfrom keras.src.layers.reshaping.permute import Permute as Permute\nfrom keras.src.layers.reshaping.repeat_vector import (\n    RepeatVector as RepeatVector,\n)\nfrom keras.src.layers.reshaping.reshape import Reshape as Reshape\nfrom keras.src.layers.reshaping.up_sampling1d import (\n    UpSampling1D as UpSampling1D,\n)\nfrom keras.src.layers.reshaping.up_sampling2d import (\n    UpSampling2D as UpSampling2D,\n)\nfrom keras.src.layers.reshaping.up_sampling3d import (\n    UpSampling3D as UpSampling3D,\n)\nfrom keras.src.layers.reshaping.zero_padding1d import (\n    ZeroPadding1D as ZeroPadding1D,\n)\nfrom keras.src.layers.reshaping.zero_padding2d import (\n    ZeroPadding2D as ZeroPadding2D,\n)\nfrom keras.src.layers.reshaping.zero_padding3d import (\n    ZeroPadding3D as ZeroPadding3D,\n)\nfrom keras.src.layers.rnn.bidirectional import Bidirectional as Bidirectional\nfrom keras.src.layers.rnn.conv_lstm1d import ConvLSTM1D as ConvLSTM1D\nfrom keras.src.layers.rnn.conv_lstm2d import ConvLSTM2D as ConvLSTM2D\nfrom keras.src.layers.rnn.conv_lstm3d import ConvLSTM3D as ConvLSTM3D\nfrom keras.src.layers.rnn.gru import GRU as GRU\nfrom keras.src.layers.rnn.gru 
import GRUCell as GRUCell\nfrom keras.src.layers.rnn.lstm import LSTM as LSTM\nfrom keras.src.layers.rnn.lstm import LSTMCell as LSTMCell\nfrom keras.src.layers.rnn.rnn import RNN as RNN\nfrom keras.src.layers.rnn.simple_rnn import SimpleRNN as SimpleRNN\nfrom keras.src.layers.rnn.simple_rnn import SimpleRNNCell as SimpleRNNCell\nfrom keras.src.layers.rnn.stacked_rnn_cells import (\n    StackedRNNCells as StackedRNNCells,\n)\nfrom keras.src.layers.rnn.time_distributed import (\n    TimeDistributed as TimeDistributed,\n)\nfrom keras.src.utils.jax_layer import FlaxLayer as FlaxLayer\nfrom keras.src.utils.jax_layer import JaxLayer as JaxLayer\nfrom keras.src.utils.torch_utils import TorchModuleWrapper as TorchModuleWrapper\n"
  },
  {
    "path": "keras/api/legacy/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.legacy import saving as saving\n"
  },
  {
    "path": "keras/api/legacy/saving/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.legacy.saving.serialization import (\n    deserialize_keras_object as deserialize_keras_object,\n)\nfrom keras.src.legacy.saving.serialization import (\n    serialize_keras_object as serialize_keras_object,\n)\n"
  },
  {
    "path": "keras/api/losses/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.losses import deserialize as deserialize\nfrom keras.src.losses import get as get\nfrom keras.src.losses import serialize as serialize\nfrom keras.src.losses.loss import Loss as Loss\nfrom keras.src.losses.losses import CTC as CTC\nfrom keras.src.losses.losses import BinaryCrossentropy as BinaryCrossentropy\nfrom keras.src.losses.losses import (\n    BinaryFocalCrossentropy as BinaryFocalCrossentropy,\n)\nfrom keras.src.losses.losses import (\n    CategoricalCrossentropy as CategoricalCrossentropy,\n)\nfrom keras.src.losses.losses import (\n    CategoricalFocalCrossentropy as CategoricalFocalCrossentropy,\n)\nfrom keras.src.losses.losses import (\n    CategoricalGeneralizedCrossEntropy as CategoricalGeneralizedCrossEntropy,\n)\nfrom keras.src.losses.losses import CategoricalHinge as CategoricalHinge\nfrom keras.src.losses.losses import Circle as Circle\nfrom keras.src.losses.losses import CosineSimilarity as CosineSimilarity\nfrom keras.src.losses.losses import Dice as Dice\nfrom keras.src.losses.losses import Hinge as Hinge\nfrom keras.src.losses.losses import Huber as Huber\nfrom keras.src.losses.losses import KLDivergence as KLDivergence\nfrom keras.src.losses.losses import LogCosh as LogCosh\nfrom keras.src.losses.losses import MeanAbsoluteError as MeanAbsoluteError\nfrom keras.src.losses.losses import (\n    MeanAbsolutePercentageError as MeanAbsolutePercentageError,\n)\nfrom keras.src.losses.losses import MeanSquaredError as MeanSquaredError\nfrom keras.src.losses.losses import (\n    MeanSquaredLogarithmicError as MeanSquaredLogarithmicError,\n)\nfrom keras.src.losses.losses import Poisson as Poisson\nfrom keras.src.losses.losses import (\n    SparseCategoricalCrossentropy as SparseCategoricalCrossentropy,\n)\nfrom keras.src.losses.losses import SquaredHinge as SquaredHinge\nfrom 
keras.src.losses.losses import Tversky as Tversky\nfrom keras.src.losses.losses import binary_crossentropy as binary_crossentropy\nfrom keras.src.losses.losses import (\n    binary_focal_crossentropy as binary_focal_crossentropy,\n)\nfrom keras.src.losses.losses import (\n    categorical_crossentropy as categorical_crossentropy,\n)\nfrom keras.src.losses.losses import (\n    categorical_focal_crossentropy as categorical_focal_crossentropy,\n)\nfrom keras.src.losses.losses import (\n    categorical_generalized_cross_entropy as categorical_generalized_cross_entropy,\n)\nfrom keras.src.losses.losses import categorical_hinge as categorical_hinge\nfrom keras.src.losses.losses import circle as circle\nfrom keras.src.losses.losses import cosine_similarity as cosine_similarity\nfrom keras.src.losses.losses import ctc as ctc\nfrom keras.src.losses.losses import dice as dice\nfrom keras.src.losses.losses import hinge as hinge\nfrom keras.src.losses.losses import huber as huber\nfrom keras.src.losses.losses import kl_divergence as kl_divergence\nfrom keras.src.losses.losses import log_cosh as log_cosh\nfrom keras.src.losses.losses import mean_absolute_error as mean_absolute_error\nfrom keras.src.losses.losses import (\n    mean_absolute_percentage_error as mean_absolute_percentage_error,\n)\nfrom keras.src.losses.losses import mean_squared_error as mean_squared_error\nfrom keras.src.losses.losses import (\n    mean_squared_logarithmic_error as mean_squared_logarithmic_error,\n)\nfrom keras.src.losses.losses import poisson as poisson\nfrom keras.src.losses.losses import (\n    sparse_categorical_crossentropy as sparse_categorical_crossentropy,\n)\nfrom keras.src.losses.losses import squared_hinge as squared_hinge\nfrom keras.src.losses.losses import tversky as tversky\n"
  },
  {
    "path": "keras/api/metrics/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.losses.losses import binary_crossentropy as binary_crossentropy\nfrom keras.src.losses.losses import (\n    binary_focal_crossentropy as binary_focal_crossentropy,\n)\nfrom keras.src.losses.losses import (\n    categorical_crossentropy as categorical_crossentropy,\n)\nfrom keras.src.losses.losses import (\n    categorical_focal_crossentropy as categorical_focal_crossentropy,\n)\nfrom keras.src.losses.losses import categorical_hinge as categorical_hinge\nfrom keras.src.losses.losses import hinge as hinge\nfrom keras.src.losses.losses import huber as huber\nfrom keras.src.losses.losses import kl_divergence as kl_divergence\nfrom keras.src.losses.losses import log_cosh as log_cosh\nfrom keras.src.losses.losses import mean_absolute_error as mean_absolute_error\nfrom keras.src.losses.losses import (\n    mean_absolute_percentage_error as mean_absolute_percentage_error,\n)\nfrom keras.src.losses.losses import mean_squared_error as mean_squared_error\nfrom keras.src.losses.losses import (\n    mean_squared_logarithmic_error as mean_squared_logarithmic_error,\n)\nfrom keras.src.losses.losses import poisson as poisson\nfrom keras.src.losses.losses import (\n    sparse_categorical_crossentropy as sparse_categorical_crossentropy,\n)\nfrom keras.src.losses.losses import squared_hinge as squared_hinge\nfrom keras.src.metrics import deserialize as deserialize\nfrom keras.src.metrics import get as get\nfrom keras.src.metrics import serialize as serialize\nfrom keras.src.metrics.accuracy_metrics import Accuracy as Accuracy\nfrom keras.src.metrics.accuracy_metrics import BinaryAccuracy as BinaryAccuracy\nfrom keras.src.metrics.accuracy_metrics import (\n    CategoricalAccuracy as CategoricalAccuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    SparseCategoricalAccuracy as 
SparseCategoricalAccuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    SparseTopKCategoricalAccuracy as SparseTopKCategoricalAccuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    TopKCategoricalAccuracy as TopKCategoricalAccuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    binary_accuracy as binary_accuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    categorical_accuracy as categorical_accuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    sparse_categorical_accuracy as sparse_categorical_accuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    sparse_top_k_categorical_accuracy as sparse_top_k_categorical_accuracy,\n)\nfrom keras.src.metrics.accuracy_metrics import (\n    top_k_categorical_accuracy as top_k_categorical_accuracy,\n)\nfrom keras.src.metrics.confusion_metrics import AUC as AUC\nfrom keras.src.metrics.confusion_metrics import FalseNegatives as FalseNegatives\nfrom keras.src.metrics.confusion_metrics import FalsePositives as FalsePositives\nfrom keras.src.metrics.confusion_metrics import Precision as Precision\nfrom keras.src.metrics.confusion_metrics import (\n    PrecisionAtRecall as PrecisionAtRecall,\n)\nfrom keras.src.metrics.confusion_metrics import Recall as Recall\nfrom keras.src.metrics.confusion_metrics import (\n    RecallAtPrecision as RecallAtPrecision,\n)\nfrom keras.src.metrics.confusion_metrics import (\n    SensitivityAtSpecificity as SensitivityAtSpecificity,\n)\nfrom keras.src.metrics.confusion_metrics import (\n    SpecificityAtSensitivity as SpecificityAtSensitivity,\n)\nfrom keras.src.metrics.confusion_metrics import TrueNegatives as TrueNegatives\nfrom keras.src.metrics.confusion_metrics import TruePositives as TruePositives\nfrom keras.src.metrics.correlation_metrics import (\n    ConcordanceCorrelation as ConcordanceCorrelation,\n)\nfrom keras.src.metrics.correlation_metrics import (\n    PearsonCorrelation as PearsonCorrelation,\n)\nfrom 
keras.src.metrics.correlation_metrics import (\n    concordance_correlation as concordance_correlation,\n)\nfrom keras.src.metrics.correlation_metrics import (\n    pearson_correlation as pearson_correlation,\n)\nfrom keras.src.metrics.f_score_metrics import F1Score as F1Score\nfrom keras.src.metrics.f_score_metrics import FBetaScore as FBetaScore\nfrom keras.src.metrics.hinge_metrics import CategoricalHinge as CategoricalHinge\nfrom keras.src.metrics.hinge_metrics import Hinge as Hinge\nfrom keras.src.metrics.hinge_metrics import SquaredHinge as SquaredHinge\nfrom keras.src.metrics.iou_metrics import BinaryIoU as BinaryIoU\nfrom keras.src.metrics.iou_metrics import IoU as IoU\nfrom keras.src.metrics.iou_metrics import MeanIoU as MeanIoU\nfrom keras.src.metrics.iou_metrics import OneHotIoU as OneHotIoU\nfrom keras.src.metrics.iou_metrics import OneHotMeanIoU as OneHotMeanIoU\nfrom keras.src.metrics.metric import Metric as Metric\nfrom keras.src.metrics.probabilistic_metrics import (\n    BinaryCrossentropy as BinaryCrossentropy,\n)\nfrom keras.src.metrics.probabilistic_metrics import (\n    CategoricalCrossentropy as CategoricalCrossentropy,\n)\nfrom keras.src.metrics.probabilistic_metrics import KLDivergence as KLDivergence\nfrom keras.src.metrics.probabilistic_metrics import Poisson as Poisson\nfrom keras.src.metrics.probabilistic_metrics import (\n    SparseCategoricalCrossentropy as SparseCategoricalCrossentropy,\n)\nfrom keras.src.metrics.reduction_metrics import Mean as Mean\nfrom keras.src.metrics.reduction_metrics import (\n    MeanMetricWrapper as MeanMetricWrapper,\n)\nfrom keras.src.metrics.reduction_metrics import Sum as Sum\nfrom keras.src.metrics.regression_metrics import (\n    CosineSimilarity as CosineSimilarity,\n)\nfrom keras.src.metrics.regression_metrics import LogCoshError as LogCoshError\nfrom keras.src.metrics.regression_metrics import (\n    MeanAbsoluteError as MeanAbsoluteError,\n)\nfrom keras.src.metrics.regression_metrics import (\n    
MeanAbsolutePercentageError as MeanAbsolutePercentageError,\n)\nfrom keras.src.metrics.regression_metrics import (\n    MeanSquaredError as MeanSquaredError,\n)\nfrom keras.src.metrics.regression_metrics import (\n    MeanSquaredLogarithmicError as MeanSquaredLogarithmicError,\n)\nfrom keras.src.metrics.regression_metrics import R2Score as R2Score\nfrom keras.src.metrics.regression_metrics import (\n    RootMeanSquaredError as RootMeanSquaredError,\n)\n"
  },
  {
    "path": "keras/api/mixed_precision/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.dtype_policies.dtype_policy import DTypePolicy as DTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import DTypePolicy as Policy\nfrom keras.src.dtype_policies.dtype_policy import dtype_policy as dtype_policy\nfrom keras.src.dtype_policies.dtype_policy import dtype_policy as global_policy\nfrom keras.src.dtype_policies.dtype_policy import (\n    set_dtype_policy as set_dtype_policy,\n)\nfrom keras.src.dtype_policies.dtype_policy import (\n    set_dtype_policy as set_global_policy,\n)\nfrom keras.src.optimizers.loss_scale_optimizer import (\n    LossScaleOptimizer as LossScaleOptimizer,\n)\n"
  },
  {
    "path": "keras/api/models/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.models.cloning import clone_model as clone_model\nfrom keras.src.models.model import Model as Model\nfrom keras.src.models.model import model_from_json as model_from_json\nfrom keras.src.models.sequential import Sequential as Sequential\nfrom keras.src.saving.saving_api import load_model as load_model\nfrom keras.src.saving.saving_api import save_model as save_model\n"
  },
  {
    "path": "keras/api/ops/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.ops import image as image\nfrom keras.ops import linalg as linalg\nfrom keras.ops import nn as nn\nfrom keras.ops import numpy as numpy\nfrom keras.src.ops.core import associative_scan as associative_scan\nfrom keras.src.ops.core import cast as cast\nfrom keras.src.ops.core import cond as cond\nfrom keras.src.ops.core import convert_to_numpy as convert_to_numpy\nfrom keras.src.ops.core import convert_to_tensor as convert_to_tensor\nfrom keras.src.ops.core import custom_gradient as custom_gradient\nfrom keras.src.ops.core import dtype as dtype\nfrom keras.src.ops.core import fori_loop as fori_loop\nfrom keras.src.ops.core import is_tensor as is_tensor\nfrom keras.src.ops.core import map as map\nfrom keras.src.ops.core import saturate_cast as saturate_cast\nfrom keras.src.ops.core import scan as scan\nfrom keras.src.ops.core import scatter as scatter\nfrom keras.src.ops.core import scatter_update as scatter_update\nfrom keras.src.ops.core import shape as shape\nfrom keras.src.ops.core import slice as slice\nfrom keras.src.ops.core import slice_update as slice_update\nfrom keras.src.ops.core import stop_gradient as stop_gradient\nfrom keras.src.ops.core import switch as switch\nfrom keras.src.ops.core import unstack as unstack\nfrom keras.src.ops.core import vectorized_map as vectorized_map\nfrom keras.src.ops.core import while_loop as while_loop\nfrom keras.src.ops.einops import rearrange as rearrange\nfrom keras.src.ops.linalg import cholesky as cholesky\nfrom keras.src.ops.linalg import cholesky_inverse as cholesky_inverse\nfrom keras.src.ops.linalg import det as det\nfrom keras.src.ops.linalg import eig as eig\nfrom keras.src.ops.linalg import eigh as eigh\nfrom keras.src.ops.linalg import inv as inv\nfrom keras.src.ops.linalg import jvp as jvp\nfrom keras.src.ops.linalg import lstsq as lstsq\nfrom 
keras.src.ops.linalg import lu_factor as lu_factor\nfrom keras.src.ops.linalg import norm as norm\nfrom keras.src.ops.linalg import qr as qr\nfrom keras.src.ops.linalg import solve as solve\nfrom keras.src.ops.linalg import solve_triangular as solve_triangular\nfrom keras.src.ops.linalg import svd as svd\nfrom keras.src.ops.math import erf as erf\nfrom keras.src.ops.math import erfinv as erfinv\nfrom keras.src.ops.math import extract_sequences as extract_sequences\nfrom keras.src.ops.math import fft as fft\nfrom keras.src.ops.math import fft2 as fft2\nfrom keras.src.ops.math import ifft2 as ifft2\nfrom keras.src.ops.math import in_top_k as in_top_k\nfrom keras.src.ops.math import irfft as irfft\nfrom keras.src.ops.math import istft as istft\nfrom keras.src.ops.math import logdet as logdet\nfrom keras.src.ops.math import logsumexp as logsumexp\nfrom keras.src.ops.math import rfft as rfft\nfrom keras.src.ops.math import rsqrt as rsqrt\nfrom keras.src.ops.math import segment_max as segment_max\nfrom keras.src.ops.math import segment_sum as segment_sum\nfrom keras.src.ops.math import stft as stft\nfrom keras.src.ops.math import top_k as top_k\nfrom keras.src.ops.math import view_as_complex as view_as_complex\nfrom keras.src.ops.math import view_as_real as view_as_real\nfrom keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool\nfrom keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool\nfrom keras.src.ops.nn import average_pool as average_pool\nfrom keras.src.ops.nn import batch_normalization as batch_normalization\nfrom keras.src.ops.nn import binary_crossentropy as binary_crossentropy\nfrom keras.src.ops.nn import (\n    categorical_crossentropy as categorical_crossentropy,\n)\nfrom keras.src.ops.nn import celu as celu\nfrom keras.src.ops.nn import conv as conv\nfrom keras.src.ops.nn import conv_transpose as conv_transpose\nfrom keras.src.ops.nn import ctc_decode as ctc_decode\nfrom keras.src.ops.nn import ctc_loss as ctc_loss\nfrom 
keras.src.ops.nn import depth_to_space as depth_to_space\nfrom keras.src.ops.nn import depthwise_conv as depthwise_conv\nfrom keras.src.ops.nn import dot_product_attention as dot_product_attention\nfrom keras.src.ops.nn import elu as elu\nfrom keras.src.ops.nn import fold as fold\nfrom keras.src.ops.nn import gelu as gelu\nfrom keras.src.ops.nn import glu as glu\nfrom keras.src.ops.nn import hard_shrink as hard_shrink\nfrom keras.src.ops.nn import hard_sigmoid as hard_sigmoid\nfrom keras.src.ops.nn import hard_silu as hard_silu\nfrom keras.src.ops.nn import hard_silu as hard_swish\nfrom keras.src.ops.nn import hard_tanh as hard_tanh\nfrom keras.src.ops.nn import layer_normalization as layer_normalization\nfrom keras.src.ops.nn import leaky_relu as leaky_relu\nfrom keras.src.ops.nn import log_sigmoid as log_sigmoid\nfrom keras.src.ops.nn import log_softmax as log_softmax\nfrom keras.src.ops.nn import max_pool as max_pool\nfrom keras.src.ops.nn import moments as moments\nfrom keras.src.ops.nn import multi_hot as multi_hot\nfrom keras.src.ops.nn import normalize as normalize\nfrom keras.src.ops.nn import one_hot as one_hot\nfrom keras.src.ops.nn import polar as polar\nfrom keras.src.ops.nn import psnr as psnr\nfrom keras.src.ops.nn import relu as relu\nfrom keras.src.ops.nn import relu6 as relu6\nfrom keras.src.ops.nn import rms_normalization as rms_normalization\nfrom keras.src.ops.nn import selu as selu\nfrom keras.src.ops.nn import separable_conv as separable_conv\nfrom keras.src.ops.nn import sigmoid as sigmoid\nfrom keras.src.ops.nn import silu as silu\nfrom keras.src.ops.nn import silu as swish\nfrom keras.src.ops.nn import soft_shrink as soft_shrink\nfrom keras.src.ops.nn import softmax as softmax\nfrom keras.src.ops.nn import softplus as softplus\nfrom keras.src.ops.nn import softsign as softsign\nfrom keras.src.ops.nn import space_to_depth as space_to_depth\nfrom keras.src.ops.nn import (\n    sparse_categorical_crossentropy as 
sparse_categorical_crossentropy,\n)\nfrom keras.src.ops.nn import sparse_plus as sparse_plus\nfrom keras.src.ops.nn import sparse_sigmoid as sparse_sigmoid\nfrom keras.src.ops.nn import sparsemax as sparsemax\nfrom keras.src.ops.nn import squareplus as squareplus\nfrom keras.src.ops.nn import tanh_shrink as tanh_shrink\nfrom keras.src.ops.nn import threshold as threshold\nfrom keras.src.ops.nn import unfold as unfold\nfrom keras.src.ops.numpy import abs as abs\nfrom keras.src.ops.numpy import absolute as absolute\nfrom keras.src.ops.numpy import add as add\nfrom keras.src.ops.numpy import all as all\nfrom keras.src.ops.numpy import allclose as allclose\nfrom keras.src.ops.numpy import amax as amax\nfrom keras.src.ops.numpy import amin as amin\nfrom keras.src.ops.numpy import angle as angle\nfrom keras.src.ops.numpy import any as any\nfrom keras.src.ops.numpy import append as append\nfrom keras.src.ops.numpy import arange as arange\nfrom keras.src.ops.numpy import arccos as arccos\nfrom keras.src.ops.numpy import arccosh as arccosh\nfrom keras.src.ops.numpy import arcsin as arcsin\nfrom keras.src.ops.numpy import arcsinh as arcsinh\nfrom keras.src.ops.numpy import arctan as arctan\nfrom keras.src.ops.numpy import arctan2 as arctan2\nfrom keras.src.ops.numpy import arctanh as arctanh\nfrom keras.src.ops.numpy import argmax as argmax\nfrom keras.src.ops.numpy import argmin as argmin\nfrom keras.src.ops.numpy import argpartition as argpartition\nfrom keras.src.ops.numpy import argsort as argsort\nfrom keras.src.ops.numpy import array as array\nfrom keras.src.ops.numpy import array_split as array_split\nfrom keras.src.ops.numpy import average as average\nfrom keras.src.ops.numpy import bartlett as bartlett\nfrom keras.src.ops.numpy import bincount as bincount\nfrom keras.src.ops.numpy import bitwise_and as bitwise_and\nfrom keras.src.ops.numpy import bitwise_invert as bitwise_invert\nfrom keras.src.ops.numpy import bitwise_left_shift as bitwise_left_shift\nfrom 
keras.src.ops.numpy import bitwise_not as bitwise_not\nfrom keras.src.ops.numpy import bitwise_or as bitwise_or\nfrom keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift\nfrom keras.src.ops.numpy import bitwise_xor as bitwise_xor\nfrom keras.src.ops.numpy import blackman as blackman\nfrom keras.src.ops.numpy import broadcast_to as broadcast_to\nfrom keras.src.ops.numpy import cbrt as cbrt\nfrom keras.src.ops.numpy import ceil as ceil\nfrom keras.src.ops.numpy import clip as clip\nfrom keras.src.ops.numpy import concatenate as concatenate\nfrom keras.src.ops.numpy import conj as conj\nfrom keras.src.ops.numpy import conjugate as conjugate\nfrom keras.src.ops.numpy import copy as copy\nfrom keras.src.ops.numpy import corrcoef as corrcoef\nfrom keras.src.ops.numpy import correlate as correlate\nfrom keras.src.ops.numpy import cos as cos\nfrom keras.src.ops.numpy import cosh as cosh\nfrom keras.src.ops.numpy import count_nonzero as count_nonzero\nfrom keras.src.ops.numpy import cross as cross\nfrom keras.src.ops.numpy import cumprod as cumprod\nfrom keras.src.ops.numpy import cumsum as cumsum\nfrom keras.src.ops.numpy import deg2rad as deg2rad\nfrom keras.src.ops.numpy import diag as diag\nfrom keras.src.ops.numpy import diagflat as diagflat\nfrom keras.src.ops.numpy import diagonal as diagonal\nfrom keras.src.ops.numpy import diff as diff\nfrom keras.src.ops.numpy import digitize as digitize\nfrom keras.src.ops.numpy import divide as divide\nfrom keras.src.ops.numpy import divide_no_nan as divide_no_nan\nfrom keras.src.ops.numpy import dot as dot\nfrom keras.src.ops.numpy import dstack as dstack\nfrom keras.src.ops.numpy import einsum as einsum\nfrom keras.src.ops.numpy import empty as empty\nfrom keras.src.ops.numpy import empty_like as empty_like\nfrom keras.src.ops.numpy import equal as equal\nfrom keras.src.ops.numpy import exp as exp\nfrom keras.src.ops.numpy import exp2 as exp2\nfrom keras.src.ops.numpy import expand_dims as expand_dims\nfrom 
keras.src.ops.numpy import expm1 as expm1\nfrom keras.src.ops.numpy import eye as eye\nfrom keras.src.ops.numpy import flip as flip\nfrom keras.src.ops.numpy import floor as floor\nfrom keras.src.ops.numpy import floor_divide as floor_divide\nfrom keras.src.ops.numpy import fmod as fmod\nfrom keras.src.ops.numpy import full as full\nfrom keras.src.ops.numpy import full_like as full_like\nfrom keras.src.ops.numpy import gcd as gcd\nfrom keras.src.ops.numpy import geomspace as geomspace\nfrom keras.src.ops.numpy import get_item as get_item\nfrom keras.src.ops.numpy import greater as greater\nfrom keras.src.ops.numpy import greater_equal as greater_equal\nfrom keras.src.ops.numpy import hamming as hamming\nfrom keras.src.ops.numpy import hanning as hanning\nfrom keras.src.ops.numpy import heaviside as heaviside\nfrom keras.src.ops.numpy import histogram as histogram\nfrom keras.src.ops.numpy import hsplit as hsplit\nfrom keras.src.ops.numpy import hstack as hstack\nfrom keras.src.ops.numpy import hypot as hypot\nfrom keras.src.ops.numpy import identity as identity\nfrom keras.src.ops.numpy import imag as imag\nfrom keras.src.ops.numpy import inner as inner\nfrom keras.src.ops.numpy import isclose as isclose\nfrom keras.src.ops.numpy import isfinite as isfinite\nfrom keras.src.ops.numpy import isin as isin\nfrom keras.src.ops.numpy import isinf as isinf\nfrom keras.src.ops.numpy import isnan as isnan\nfrom keras.src.ops.numpy import isneginf as isneginf\nfrom keras.src.ops.numpy import isposinf as isposinf\nfrom keras.src.ops.numpy import isreal as isreal\nfrom keras.src.ops.numpy import kaiser as kaiser\nfrom keras.src.ops.numpy import kron as kron\nfrom keras.src.ops.numpy import lcm as lcm\nfrom keras.src.ops.numpy import ldexp as ldexp\nfrom keras.src.ops.numpy import left_shift as left_shift\nfrom keras.src.ops.numpy import less as less\nfrom keras.src.ops.numpy import less_equal as less_equal\nfrom keras.src.ops.numpy import linspace as linspace\nfrom keras.src.ops.numpy import log as log\nfrom keras.src.ops.numpy import log1p as log1p\nfrom keras.src.ops.numpy import log2 as log2\nfrom keras.src.ops.numpy import log10 as log10\nfrom keras.src.ops.numpy import logaddexp as logaddexp\nfrom keras.src.ops.numpy import logaddexp2 as logaddexp2\nfrom keras.src.ops.numpy import logical_and as logical_and\nfrom keras.src.ops.numpy import logical_not as logical_not\nfrom keras.src.ops.numpy import logical_or as logical_or\nfrom keras.src.ops.numpy import logical_xor as logical_xor\nfrom keras.src.ops.numpy import logspace as logspace\nfrom keras.src.ops.numpy import matmul as matmul\nfrom keras.src.ops.numpy import max as max\nfrom keras.src.ops.numpy import maximum as maximum\nfrom keras.src.ops.numpy import mean as mean\nfrom keras.src.ops.numpy import median as median\nfrom keras.src.ops.numpy import meshgrid as meshgrid\nfrom keras.src.ops.numpy import min as min\nfrom keras.src.ops.numpy import minimum as minimum\nfrom keras.src.ops.numpy import mod as mod\nfrom keras.src.ops.numpy import moveaxis as moveaxis\nfrom keras.src.ops.numpy import multiply as multiply\nfrom keras.src.ops.numpy import nan_to_num as nan_to_num\nfrom keras.src.ops.numpy import nanargmax as nanargmax\nfrom keras.src.ops.numpy import nanargmin as nanargmin\nfrom keras.src.ops.numpy import nancumprod as nancumprod\nfrom keras.src.ops.numpy import nancumsum as nancumsum\nfrom keras.src.ops.numpy import nanmax as nanmax\nfrom keras.src.ops.numpy import nanmean as nanmean\nfrom keras.src.ops.numpy import nanmin as nanmin\nfrom keras.src.ops.numpy import nanprod as nanprod\nfrom keras.src.ops.numpy import nanstd as nanstd\nfrom keras.src.ops.numpy import nansum as nansum\nfrom keras.src.ops.numpy import nanvar as nanvar\nfrom keras.src.ops.numpy import ndim as ndim\nfrom keras.src.ops.numpy import negative as negative\nfrom keras.src.ops.numpy import nextafter as nextafter\nfrom keras.src.ops.numpy import nonzero as nonzero\nfrom keras.src.ops.numpy import not_equal as not_equal\nfrom keras.src.ops.numpy import ones as ones\nfrom keras.src.ops.numpy import ones_like as ones_like\nfrom keras.src.ops.numpy import outer as outer\nfrom keras.src.ops.numpy import pad as pad\nfrom keras.src.ops.numpy import power as power\nfrom keras.src.ops.numpy import prod as prod\nfrom keras.src.ops.numpy import ptp as ptp\nfrom keras.src.ops.numpy import quantile as quantile\nfrom keras.src.ops.numpy import ravel as ravel\nfrom keras.src.ops.numpy import real as real\nfrom keras.src.ops.numpy import reciprocal as reciprocal\nfrom keras.src.ops.numpy import repeat as repeat\nfrom keras.src.ops.numpy import reshape as reshape\nfrom keras.src.ops.numpy import right_shift as right_shift\nfrom keras.src.ops.numpy import roll as roll\nfrom keras.src.ops.numpy import rot90 as rot90\nfrom keras.src.ops.numpy import round as round\nfrom keras.src.ops.numpy import searchsorted as searchsorted\nfrom keras.src.ops.numpy import select as select\nfrom keras.src.ops.numpy import sign as sign\nfrom keras.src.ops.numpy import signbit as signbit\nfrom keras.src.ops.numpy import sin as sin\nfrom keras.src.ops.numpy import sinc as sinc\nfrom keras.src.ops.numpy import sinh as sinh\nfrom keras.src.ops.numpy import size as size\nfrom keras.src.ops.numpy import slogdet as slogdet\nfrom keras.src.ops.numpy import sort as sort\nfrom keras.src.ops.numpy import split as split\nfrom keras.src.ops.numpy import sqrt as sqrt\nfrom keras.src.ops.numpy import square as square\nfrom keras.src.ops.numpy import squeeze as squeeze\nfrom keras.src.ops.numpy import stack as stack\nfrom keras.src.ops.numpy import std as std\nfrom keras.src.ops.numpy import subtract as subtract\nfrom keras.src.ops.numpy import sum as sum\nfrom keras.src.ops.numpy import swapaxes as swapaxes\nfrom keras.src.ops.numpy import take as take\nfrom keras.src.ops.numpy import take_along_axis as take_along_axis\nfrom keras.src.ops.numpy import tan as tan\nfrom keras.src.ops.numpy import tanh as tanh\nfrom keras.src.ops.numpy import tensordot as tensordot\nfrom keras.src.ops.numpy import tile as tile\nfrom keras.src.ops.numpy import trace as trace\nfrom keras.src.ops.numpy import transpose as transpose\nfrom keras.src.ops.numpy import trapezoid as trapezoid\nfrom keras.src.ops.numpy import tri as tri\nfrom keras.src.ops.numpy import tril as tril\nfrom keras.src.ops.numpy import triu as triu\nfrom keras.src.ops.numpy import true_divide as true_divide\nfrom keras.src.ops.numpy import trunc as trunc\nfrom keras.src.ops.numpy import unravel_index as unravel_index\nfrom keras.src.ops.numpy import vander as vander\nfrom keras.src.ops.numpy import var as var\nfrom keras.src.ops.numpy import vdot as vdot\nfrom keras.src.ops.numpy import vectorize as vectorize\nfrom keras.src.ops.numpy import view as view\nfrom keras.src.ops.numpy import vsplit as vsplit\nfrom keras.src.ops.numpy import vstack as vstack\nfrom keras.src.ops.numpy import where as where\nfrom keras.src.ops.numpy import zeros as zeros\nfrom keras.src.ops.numpy import zeros_like as zeros_like\n"
  },
  {
    "path": "keras/api/ops/image/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.ops.image import affine_transform as affine_transform\nfrom keras.src.ops.image import crop_images as crop_images\nfrom keras.src.ops.image import elastic_transform as elastic_transform\nfrom keras.src.ops.image import extract_patches as extract_patches\nfrom keras.src.ops.image import extract_patches_3d as extract_patches_3d\nfrom keras.src.ops.image import gaussian_blur as gaussian_blur\nfrom keras.src.ops.image import hsv_to_rgb as hsv_to_rgb\nfrom keras.src.ops.image import map_coordinates as map_coordinates\nfrom keras.src.ops.image import pad_images as pad_images\nfrom keras.src.ops.image import perspective_transform as perspective_transform\nfrom keras.src.ops.image import resize as resize\nfrom keras.src.ops.image import rgb_to_grayscale as rgb_to_grayscale\nfrom keras.src.ops.image import rgb_to_hsv as rgb_to_hsv\nfrom keras.src.ops.image import scale_and_translate as scale_and_translate\n"
  },
  {
    "path": "keras/api/ops/linalg/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.ops.linalg import cholesky as cholesky\nfrom keras.src.ops.linalg import cholesky_inverse as cholesky_inverse\nfrom keras.src.ops.linalg import det as det\nfrom keras.src.ops.linalg import eig as eig\nfrom keras.src.ops.linalg import eigh as eigh\nfrom keras.src.ops.linalg import inv as inv\nfrom keras.src.ops.linalg import jvp as jvp\nfrom keras.src.ops.linalg import lstsq as lstsq\nfrom keras.src.ops.linalg import lu_factor as lu_factor\nfrom keras.src.ops.linalg import norm as norm\nfrom keras.src.ops.linalg import qr as qr\nfrom keras.src.ops.linalg import solve as solve\nfrom keras.src.ops.linalg import solve_triangular as solve_triangular\nfrom keras.src.ops.linalg import svd as svd\n"
  },
  {
    "path": "keras/api/ops/nn/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool\nfrom keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool\nfrom keras.src.ops.nn import average_pool as average_pool\nfrom keras.src.ops.nn import batch_normalization as batch_normalization\nfrom keras.src.ops.nn import binary_crossentropy as binary_crossentropy\nfrom keras.src.ops.nn import (\n    categorical_crossentropy as categorical_crossentropy,\n)\nfrom keras.src.ops.nn import celu as celu\nfrom keras.src.ops.nn import conv as conv\nfrom keras.src.ops.nn import conv_transpose as conv_transpose\nfrom keras.src.ops.nn import ctc_decode as ctc_decode\nfrom keras.src.ops.nn import ctc_loss as ctc_loss\nfrom keras.src.ops.nn import depth_to_space as depth_to_space\nfrom keras.src.ops.nn import depthwise_conv as depthwise_conv\nfrom keras.src.ops.nn import dot_product_attention as dot_product_attention\nfrom keras.src.ops.nn import elu as elu\nfrom keras.src.ops.nn import fold as fold\nfrom keras.src.ops.nn import gelu as gelu\nfrom keras.src.ops.nn import glu as glu\nfrom keras.src.ops.nn import hard_shrink as hard_shrink\nfrom keras.src.ops.nn import hard_sigmoid as hard_sigmoid\nfrom keras.src.ops.nn import hard_silu as hard_silu\nfrom keras.src.ops.nn import hard_silu as hard_swish\nfrom keras.src.ops.nn import hard_tanh as hard_tanh\nfrom keras.src.ops.nn import layer_normalization as layer_normalization\nfrom keras.src.ops.nn import leaky_relu as leaky_relu\nfrom keras.src.ops.nn import log_sigmoid as log_sigmoid\nfrom keras.src.ops.nn import log_softmax as log_softmax\nfrom keras.src.ops.nn import max_pool as max_pool\nfrom keras.src.ops.nn import moments as moments\nfrom keras.src.ops.nn import multi_hot as multi_hot\nfrom keras.src.ops.nn import normalize as normalize\nfrom keras.src.ops.nn import one_hot as one_hot\nfrom keras.src.ops.nn import polar as polar\nfrom keras.src.ops.nn import psnr as psnr\nfrom keras.src.ops.nn import relu as relu\nfrom keras.src.ops.nn import relu6 as relu6\nfrom keras.src.ops.nn import rms_normalization as rms_normalization\nfrom keras.src.ops.nn import selu as selu\nfrom keras.src.ops.nn import separable_conv as separable_conv\nfrom keras.src.ops.nn import sigmoid as sigmoid\nfrom keras.src.ops.nn import silu as silu\nfrom keras.src.ops.nn import silu as swish\nfrom keras.src.ops.nn import soft_shrink as soft_shrink\nfrom keras.src.ops.nn import softmax as softmax\nfrom keras.src.ops.nn import softplus as softplus\nfrom keras.src.ops.nn import softsign as softsign\nfrom keras.src.ops.nn import space_to_depth as space_to_depth\nfrom keras.src.ops.nn import (\n    sparse_categorical_crossentropy as sparse_categorical_crossentropy,\n)\nfrom keras.src.ops.nn import sparse_plus as sparse_plus\nfrom keras.src.ops.nn import sparse_sigmoid as sparse_sigmoid\nfrom keras.src.ops.nn import sparsemax as sparsemax\nfrom keras.src.ops.nn import squareplus as squareplus\nfrom keras.src.ops.nn import tanh_shrink as tanh_shrink\nfrom keras.src.ops.nn import threshold as threshold\nfrom keras.src.ops.nn import unfold as unfold\n"
  },
  {
    "path": "keras/api/ops/numpy/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.ops.numpy import abs as abs\nfrom keras.src.ops.numpy import absolute as absolute\nfrom keras.src.ops.numpy import add as add\nfrom keras.src.ops.numpy import all as all\nfrom keras.src.ops.numpy import allclose as allclose\nfrom keras.src.ops.numpy import amax as amax\nfrom keras.src.ops.numpy import amin as amin\nfrom keras.src.ops.numpy import angle as angle\nfrom keras.src.ops.numpy import any as any\nfrom keras.src.ops.numpy import append as append\nfrom keras.src.ops.numpy import arange as arange\nfrom keras.src.ops.numpy import arccos as arccos\nfrom keras.src.ops.numpy import arccosh as arccosh\nfrom keras.src.ops.numpy import arcsin as arcsin\nfrom keras.src.ops.numpy import arcsinh as arcsinh\nfrom keras.src.ops.numpy import arctan as arctan\nfrom keras.src.ops.numpy import arctan2 as arctan2\nfrom keras.src.ops.numpy import arctanh as arctanh\nfrom keras.src.ops.numpy import argmax as argmax\nfrom keras.src.ops.numpy import argmin as argmin\nfrom keras.src.ops.numpy import argpartition as argpartition\nfrom keras.src.ops.numpy import argsort as argsort\nfrom keras.src.ops.numpy import array as array\nfrom keras.src.ops.numpy import array_split as array_split\nfrom keras.src.ops.numpy import average as average\nfrom keras.src.ops.numpy import bartlett as bartlett\nfrom keras.src.ops.numpy import bincount as bincount\nfrom keras.src.ops.numpy import bitwise_and as bitwise_and\nfrom keras.src.ops.numpy import bitwise_invert as bitwise_invert\nfrom keras.src.ops.numpy import bitwise_left_shift as bitwise_left_shift\nfrom keras.src.ops.numpy import bitwise_not as bitwise_not\nfrom keras.src.ops.numpy import bitwise_or as bitwise_or\nfrom keras.src.ops.numpy import bitwise_right_shift as bitwise_right_shift\nfrom keras.src.ops.numpy import bitwise_xor as bitwise_xor\nfrom keras.src.ops.numpy import blackman as blackman\nfrom keras.src.ops.numpy import broadcast_to as broadcast_to\nfrom keras.src.ops.numpy import cbrt as cbrt\nfrom keras.src.ops.numpy import ceil as ceil\nfrom keras.src.ops.numpy import clip as clip\nfrom keras.src.ops.numpy import concatenate as concatenate\nfrom keras.src.ops.numpy import conj as conj\nfrom keras.src.ops.numpy import conjugate as conjugate\nfrom keras.src.ops.numpy import copy as copy\nfrom keras.src.ops.numpy import corrcoef as corrcoef\nfrom keras.src.ops.numpy import correlate as correlate\nfrom keras.src.ops.numpy import cos as cos\nfrom keras.src.ops.numpy import cosh as cosh\nfrom keras.src.ops.numpy import count_nonzero as count_nonzero\nfrom keras.src.ops.numpy import cross as cross\nfrom keras.src.ops.numpy import cumprod as cumprod\nfrom keras.src.ops.numpy import cumsum as cumsum\nfrom keras.src.ops.numpy import deg2rad as deg2rad\nfrom keras.src.ops.numpy import diag as diag\nfrom keras.src.ops.numpy import diagflat as diagflat\nfrom keras.src.ops.numpy import diagonal as diagonal\nfrom keras.src.ops.numpy import diff as diff\nfrom keras.src.ops.numpy import digitize as digitize\nfrom keras.src.ops.numpy import divide as divide\nfrom keras.src.ops.numpy import divide_no_nan as divide_no_nan\nfrom keras.src.ops.numpy import dot as dot\nfrom keras.src.ops.numpy import dstack as dstack\nfrom keras.src.ops.numpy import einsum as einsum\nfrom keras.src.ops.numpy import empty as empty\nfrom keras.src.ops.numpy import empty_like as empty_like\nfrom keras.src.ops.numpy import equal as equal\nfrom keras.src.ops.numpy import exp as exp\nfrom keras.src.ops.numpy import exp2 as exp2\nfrom keras.src.ops.numpy import expand_dims as expand_dims\nfrom keras.src.ops.numpy import expm1 as expm1\nfrom keras.src.ops.numpy import eye as eye\nfrom keras.src.ops.numpy import flip as flip\nfrom keras.src.ops.numpy import floor as floor\nfrom keras.src.ops.numpy import floor_divide as floor_divide\nfrom keras.src.ops.numpy import fmod as fmod\nfrom keras.src.ops.numpy import full as full\nfrom keras.src.ops.numpy import full_like as full_like\nfrom keras.src.ops.numpy import gcd as gcd\nfrom keras.src.ops.numpy import geomspace as geomspace\nfrom keras.src.ops.numpy import get_item as get_item\nfrom keras.src.ops.numpy import greater as greater\nfrom keras.src.ops.numpy import greater_equal as greater_equal\nfrom keras.src.ops.numpy import hamming as hamming\nfrom keras.src.ops.numpy import hanning as hanning\nfrom keras.src.ops.numpy import heaviside as heaviside\nfrom keras.src.ops.numpy import histogram as histogram\nfrom keras.src.ops.numpy import hsplit as hsplit\nfrom keras.src.ops.numpy import hstack as hstack\nfrom keras.src.ops.numpy import hypot as hypot\nfrom keras.src.ops.numpy import identity as identity\nfrom keras.src.ops.numpy import imag as imag\nfrom keras.src.ops.numpy import inner as inner\nfrom keras.src.ops.numpy import isclose as isclose\nfrom keras.src.ops.numpy import isfinite as isfinite\nfrom keras.src.ops.numpy import isin as isin\nfrom keras.src.ops.numpy import isinf as isinf\nfrom keras.src.ops.numpy import isnan as isnan\nfrom keras.src.ops.numpy import isneginf as isneginf\nfrom keras.src.ops.numpy import isposinf as isposinf\nfrom keras.src.ops.numpy import isreal as isreal\nfrom keras.src.ops.numpy import kaiser as kaiser\nfrom keras.src.ops.numpy import kron as kron\nfrom keras.src.ops.numpy import lcm as lcm\nfrom keras.src.ops.numpy import ldexp as ldexp\nfrom keras.src.ops.numpy import left_shift as left_shift\nfrom keras.src.ops.numpy import less as less\nfrom keras.src.ops.numpy import less_equal as less_equal\nfrom keras.src.ops.numpy import linspace as linspace\nfrom keras.src.ops.numpy import log as log\nfrom keras.src.ops.numpy import log1p as log1p\nfrom keras.src.ops.numpy import log2 as log2\nfrom keras.src.ops.numpy import log10 as log10\nfrom keras.src.ops.numpy import logaddexp as logaddexp\nfrom keras.src.ops.numpy import logaddexp2 as logaddexp2\nfrom keras.src.ops.numpy import logical_and as logical_and\nfrom keras.src.ops.numpy import logical_not as logical_not\nfrom keras.src.ops.numpy import logical_or as logical_or\nfrom keras.src.ops.numpy import logical_xor as logical_xor\nfrom keras.src.ops.numpy import logspace as logspace\nfrom keras.src.ops.numpy import matmul as matmul\nfrom keras.src.ops.numpy import max as max\nfrom keras.src.ops.numpy import maximum as maximum\nfrom keras.src.ops.numpy import mean as mean\nfrom keras.src.ops.numpy import median as median\nfrom keras.src.ops.numpy import meshgrid as meshgrid\nfrom keras.src.ops.numpy import min as min\nfrom keras.src.ops.numpy import minimum as minimum\nfrom keras.src.ops.numpy import mod as mod\nfrom keras.src.ops.numpy import moveaxis as moveaxis\nfrom keras.src.ops.numpy import multiply as multiply\nfrom keras.src.ops.numpy import nan_to_num as nan_to_num\nfrom keras.src.ops.numpy import nanargmax as nanargmax\nfrom keras.src.ops.numpy import nanargmin as nanargmin\nfrom keras.src.ops.numpy import nancumprod as nancumprod\nfrom keras.src.ops.numpy import nancumsum as nancumsum\nfrom keras.src.ops.numpy import nanmax as nanmax\nfrom keras.src.ops.numpy import nanmean as nanmean\nfrom keras.src.ops.numpy import nanmin as nanmin\nfrom keras.src.ops.numpy import nanprod as nanprod\nfrom keras.src.ops.numpy import nanstd as nanstd\nfrom keras.src.ops.numpy import nansum as nansum\nfrom keras.src.ops.numpy import nanvar as nanvar\nfrom keras.src.ops.numpy import ndim as ndim\nfrom keras.src.ops.numpy import negative as negative\nfrom keras.src.ops.numpy import nextafter as nextafter\nfrom keras.src.ops.numpy import nonzero as nonzero\nfrom keras.src.ops.numpy import not_equal as not_equal\nfrom keras.src.ops.numpy import ones as ones\nfrom keras.src.ops.numpy import ones_like as ones_like\nfrom keras.src.ops.numpy import outer as outer\nfrom keras.src.ops.numpy import pad as pad\nfrom keras.src.ops.numpy import power as power\nfrom keras.src.ops.numpy import prod as prod\nfrom keras.src.ops.numpy import ptp as ptp\nfrom keras.src.ops.numpy import quantile as quantile\nfrom keras.src.ops.numpy import ravel as ravel\nfrom keras.src.ops.numpy import real as real\nfrom keras.src.ops.numpy import reciprocal as reciprocal\nfrom keras.src.ops.numpy import repeat as repeat\nfrom keras.src.ops.numpy import reshape as reshape\nfrom keras.src.ops.numpy import right_shift as right_shift\nfrom keras.src.ops.numpy import roll as roll\nfrom keras.src.ops.numpy import rot90 as rot90\nfrom keras.src.ops.numpy import round as round\nfrom keras.src.ops.numpy import searchsorted as searchsorted\nfrom keras.src.ops.numpy import select as select\nfrom keras.src.ops.numpy import sign as sign\nfrom keras.src.ops.numpy import signbit as signbit\nfrom keras.src.ops.numpy import sin as sin\nfrom keras.src.ops.numpy import sinc as sinc\nfrom keras.src.ops.numpy import sinh as sinh\nfrom keras.src.ops.numpy import size as size\nfrom keras.src.ops.numpy import slogdet as slogdet\nfrom keras.src.ops.numpy import sort as sort\nfrom keras.src.ops.numpy import split as split\nfrom keras.src.ops.numpy import sqrt as sqrt\nfrom keras.src.ops.numpy import square as square\nfrom keras.src.ops.numpy import squeeze as squeeze\nfrom keras.src.ops.numpy import stack as stack\nfrom keras.src.ops.numpy import std as std\nfrom keras.src.ops.numpy import subtract as subtract\nfrom keras.src.ops.numpy import sum as sum\nfrom keras.src.ops.numpy import swapaxes as swapaxes\nfrom keras.src.ops.numpy import take as take\nfrom keras.src.ops.numpy import take_along_axis as take_along_axis\nfrom keras.src.ops.numpy import tan as tan\nfrom keras.src.ops.numpy import tanh as tanh\nfrom keras.src.ops.numpy import tensordot as tensordot\nfrom keras.src.ops.numpy import tile as tile\nfrom keras.src.ops.numpy import trace as trace\nfrom keras.src.ops.numpy import transpose as transpose\nfrom keras.src.ops.numpy import trapezoid as trapezoid\nfrom keras.src.ops.numpy import tri as tri\nfrom keras.src.ops.numpy import tril as tril\nfrom keras.src.ops.numpy import triu as triu\nfrom keras.src.ops.numpy import true_divide as true_divide\nfrom keras.src.ops.numpy import trunc as trunc\nfrom keras.src.ops.numpy import unravel_index as unravel_index\nfrom keras.src.ops.numpy import vander as vander\nfrom keras.src.ops.numpy import var as var\nfrom keras.src.ops.numpy import vdot as vdot\nfrom keras.src.ops.numpy import vectorize as vectorize\nfrom keras.src.ops.numpy import view as view\nfrom keras.src.ops.numpy import vsplit as vsplit\nfrom keras.src.ops.numpy import vstack as vstack\nfrom keras.src.ops.numpy import where as where\nfrom keras.src.ops.numpy import zeros as zeros\nfrom keras.src.ops.numpy import zeros_like as zeros_like\n"
  },
  {
    "path": "keras/api/optimizers/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.optimizers import legacy as legacy\nfrom keras.optimizers import schedules as schedules\nfrom keras.src.optimizers import deserialize as deserialize\nfrom keras.src.optimizers import get as get\nfrom keras.src.optimizers import serialize as serialize\nfrom keras.src.optimizers.adadelta import Adadelta as Adadelta\nfrom keras.src.optimizers.adafactor import Adafactor as Adafactor\nfrom keras.src.optimizers.adagrad import Adagrad as Adagrad\nfrom keras.src.optimizers.adam import Adam as Adam\nfrom keras.src.optimizers.adamax import Adamax as Adamax\nfrom keras.src.optimizers.adamw import AdamW as AdamW\nfrom keras.src.optimizers.ftrl import Ftrl as Ftrl\nfrom keras.src.optimizers.lamb import Lamb as Lamb\nfrom keras.src.optimizers.lion import Lion as Lion\nfrom keras.src.optimizers.loss_scale_optimizer import (\n    LossScaleOptimizer as LossScaleOptimizer,\n)\nfrom keras.src.optimizers.muon import Muon as Muon\nfrom keras.src.optimizers.nadam import Nadam as Nadam\nfrom keras.src.optimizers.optimizer import Optimizer as Optimizer\nfrom keras.src.optimizers.rmsprop import RMSprop as RMSprop\nfrom keras.src.optimizers.schedule_free_adamw import (\n    ScheduleFreeAdamW as ScheduleFreeAdamW,\n)\nfrom keras.src.optimizers.sgd import SGD as SGD\n"
  },
  {
    "path": "keras/api/optimizers/legacy/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.optimizers import LegacyOptimizerWarning as Adagrad\nfrom keras.src.optimizers import LegacyOptimizerWarning as Adam\nfrom keras.src.optimizers import LegacyOptimizerWarning as Ftrl\nfrom keras.src.optimizers import LegacyOptimizerWarning as Optimizer\nfrom keras.src.optimizers import LegacyOptimizerWarning as RMSprop\nfrom keras.src.optimizers import LegacyOptimizerWarning as SGD\n"
  },
  {
    "path": "keras/api/optimizers/schedules/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    CosineDecay as CosineDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    CosineDecayRestarts as CosineDecayRestarts,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    ExponentialDecay as ExponentialDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    InverseTimeDecay as InverseTimeDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    LearningRateSchedule as LearningRateSchedule,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    PiecewiseConstantDecay as PiecewiseConstantDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    PolynomialDecay as PolynomialDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    deserialize as deserialize,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    serialize as serialize,\n)\n"
  },
  {
    "path": "keras/api/preprocessing/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.preprocessing import image as image\nfrom keras.preprocessing import sequence as sequence\nfrom keras.src.utils.image_dataset_utils import (\n    image_dataset_from_directory as image_dataset_from_directory,\n)\nfrom keras.src.utils.text_dataset_utils import (\n    text_dataset_from_directory as text_dataset_from_directory,\n)\nfrom keras.src.utils.timeseries_dataset_utils import (\n    timeseries_dataset_from_array as timeseries_dataset_from_array,\n)\n"
  },
  {
    "path": "keras/api/preprocessing/image/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.utils.image_utils import array_to_img as array_to_img\nfrom keras.src.utils.image_utils import img_to_array as img_to_array\nfrom keras.src.utils.image_utils import load_img as load_img\nfrom keras.src.utils.image_utils import save_img as save_img\nfrom keras.src.utils.image_utils import smart_resize as smart_resize\n"
  },
  {
    "path": "keras/api/preprocessing/sequence/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.utils.sequence_utils import pad_sequences as pad_sequences\n"
  },
  {
    "path": "keras/api/quantizers/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.quantizers import deserialize as deserialize\nfrom keras.src.quantizers import get as get\nfrom keras.src.quantizers import serialize as serialize\nfrom keras.src.quantizers.awq_config import AWQConfig as AWQConfig\nfrom keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig\nfrom keras.src.quantizers.quantization_config import (\n    Float8QuantizationConfig as Float8QuantizationConfig,\n)\nfrom keras.src.quantizers.quantization_config import (\n    Int4QuantizationConfig as Int4QuantizationConfig,\n)\nfrom keras.src.quantizers.quantization_config import (\n    Int8QuantizationConfig as Int8QuantizationConfig,\n)\nfrom keras.src.quantizers.quantization_config import (\n    QuantizationConfig as QuantizationConfig,\n)\nfrom keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer\nfrom keras.src.quantizers.quantizers import Quantizer as Quantizer\nfrom keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize\nfrom keras.src.quantizers.quantizers import (\n    abs_max_quantize_grouped_with_zero_point as abs_max_quantize_grouped_with_zero_point,\n)\nfrom keras.src.quantizers.quantizers import (\n    compute_float8_amax_history as compute_float8_amax_history,\n)\nfrom keras.src.quantizers.quantizers import (\n    compute_float8_scale as compute_float8_scale,\n)\nfrom keras.src.quantizers.quantizers import (\n    fake_quant_with_min_max_vars as fake_quant_with_min_max_vars,\n)\nfrom keras.src.quantizers.quantizers import pack_int4 as pack_int4\nfrom keras.src.quantizers.quantizers import (\n    quantize_and_dequantize as quantize_and_dequantize,\n)\nfrom keras.src.quantizers.quantizers import unpack_int4 as unpack_int4\n"
  },
  {
    "path": "keras/api/random/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.random.random import beta as beta\nfrom keras.src.random.random import binomial as binomial\nfrom keras.src.random.random import categorical as categorical\nfrom keras.src.random.random import dropout as dropout\nfrom keras.src.random.random import gamma as gamma\nfrom keras.src.random.random import normal as normal\nfrom keras.src.random.random import randint as randint\nfrom keras.src.random.random import shuffle as shuffle\nfrom keras.src.random.random import truncated_normal as truncated_normal\nfrom keras.src.random.random import uniform as uniform\nfrom keras.src.random.seed_generator import SeedGenerator as SeedGenerator\n"
  },
  {
    "path": "keras/api/regularizers/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.regularizers import deserialize as deserialize\nfrom keras.src.regularizers import get as get\nfrom keras.src.regularizers import serialize as serialize\nfrom keras.src.regularizers.regularizers import L1 as L1\nfrom keras.src.regularizers.regularizers import L1 as l1\nfrom keras.src.regularizers.regularizers import L1L2 as L1L2\nfrom keras.src.regularizers.regularizers import L1L2 as l1_l2\nfrom keras.src.regularizers.regularizers import L2 as L2\nfrom keras.src.regularizers.regularizers import L2 as l2\nfrom keras.src.regularizers.regularizers import (\n    OrthogonalRegularizer as OrthogonalRegularizer,\n)\nfrom keras.src.regularizers.regularizers import (\n    OrthogonalRegularizer as orthogonal_regularizer,\n)\nfrom keras.src.regularizers.regularizers import Regularizer as Regularizer\n"
  },
  {
    "path": "keras/api/saving/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.saving.file_editor import KerasFileEditor as KerasFileEditor\nfrom keras.src.saving.object_registration import (\n    CustomObjectScope as CustomObjectScope,\n)\nfrom keras.src.saving.object_registration import (\n    CustomObjectScope as custom_object_scope,\n)\nfrom keras.src.saving.object_registration import (\n    get_custom_objects as get_custom_objects,\n)\nfrom keras.src.saving.object_registration import (\n    get_registered_name as get_registered_name,\n)\nfrom keras.src.saving.object_registration import (\n    get_registered_object as get_registered_object,\n)\nfrom keras.src.saving.object_registration import (\n    register_keras_serializable as register_keras_serializable,\n)\nfrom keras.src.saving.saving_api import load_model as load_model\nfrom keras.src.saving.saving_api import load_weights as load_weights\nfrom keras.src.saving.saving_api import save_model as save_model\nfrom keras.src.saving.saving_api import save_weights as save_weights\nfrom keras.src.saving.serialization_lib import (\n    deserialize_keras_object as deserialize_keras_object,\n)\nfrom keras.src.saving.serialization_lib import (\n    serialize_keras_object as serialize_keras_object,\n)\n"
  },
  {
    "path": "keras/api/tree/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.tree.tree_api import MAP_TO_NONE as MAP_TO_NONE\nfrom keras.src.tree.tree_api import assert_same_paths as assert_same_paths\nfrom keras.src.tree.tree_api import (\n    assert_same_structure as assert_same_structure,\n)\nfrom keras.src.tree.tree_api import flatten as flatten\nfrom keras.src.tree.tree_api import flatten_with_path as flatten_with_path\nfrom keras.src.tree.tree_api import is_nested as is_nested\nfrom keras.src.tree.tree_api import lists_to_tuples as lists_to_tuples\nfrom keras.src.tree.tree_api import map_shape_structure as map_shape_structure\nfrom keras.src.tree.tree_api import map_structure as map_structure\nfrom keras.src.tree.tree_api import map_structure_up_to as map_structure_up_to\nfrom keras.src.tree.tree_api import pack_sequence_as as pack_sequence_as\nfrom keras.src.tree.tree_api import traverse as traverse\n"
  },
  {
    "path": "keras/api/utils/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.backend.common.global_state import clear_session as clear_session\nfrom keras.src.backend.common.keras_tensor import (\n    is_keras_tensor as is_keras_tensor,\n)\nfrom keras.src.backend.common.variables import (\n    standardize_dtype as standardize_dtype,\n)\nfrom keras.src.layers.preprocessing.feature_space import (\n    FeatureSpace as FeatureSpace,\n)\nfrom keras.src.ops.operation_utils import get_source_inputs as get_source_inputs\nfrom keras.src.saving.object_registration import (\n    CustomObjectScope as CustomObjectScope,\n)\nfrom keras.src.saving.object_registration import (\n    CustomObjectScope as custom_object_scope,\n)\nfrom keras.src.saving.object_registration import (\n    get_custom_objects as get_custom_objects,\n)\nfrom keras.src.saving.object_registration import (\n    get_registered_name as get_registered_name,\n)\nfrom keras.src.saving.object_registration import (\n    get_registered_object as get_registered_object,\n)\nfrom keras.src.saving.object_registration import (\n    register_keras_serializable as register_keras_serializable,\n)\nfrom keras.src.saving.serialization_lib import (\n    deserialize_keras_object as deserialize_keras_object,\n)\nfrom keras.src.saving.serialization_lib import (\n    serialize_keras_object as serialize_keras_object,\n)\nfrom keras.src.trainers.data_adapters.data_adapter_utils import (\n    pack_x_y_sample_weight as pack_x_y_sample_weight,\n)\nfrom keras.src.trainers.data_adapters.data_adapter_utils import (\n    unpack_x_y_sample_weight as unpack_x_y_sample_weight,\n)\nfrom keras.src.trainers.data_adapters.py_dataset_adapter import (\n    PyDataset as PyDataset,\n)\nfrom keras.src.trainers.data_adapters.py_dataset_adapter import (\n    PyDataset as Sequence,\n)\nfrom keras.src.utils.audio_dataset_utils import (\n    audio_dataset_from_directory as audio_dataset_from_directory,\n)\nfrom keras.src.utils.config import Config as Config\nfrom keras.src.utils.dataset_utils import split_dataset as split_dataset\nfrom keras.src.utils.file_utils import get_file as get_file\nfrom keras.src.utils.image_dataset_utils import (\n    image_dataset_from_directory as image_dataset_from_directory,\n)\nfrom keras.src.utils.image_utils import array_to_img as array_to_img\nfrom keras.src.utils.image_utils import img_to_array as img_to_array\nfrom keras.src.utils.image_utils import load_img as load_img\nfrom keras.src.utils.image_utils import save_img as save_img\nfrom keras.src.utils.io_utils import (\n    disable_interactive_logging as disable_interactive_logging,\n)\nfrom keras.src.utils.io_utils import (\n    enable_interactive_logging as enable_interactive_logging,\n)\nfrom keras.src.utils.io_utils import (\n    is_interactive_logging_enabled as is_interactive_logging_enabled,\n)\nfrom keras.src.utils.model_visualization import model_to_dot as model_to_dot\nfrom keras.src.utils.model_visualization import plot_model as plot_model\nfrom keras.src.utils.numerical_utils import normalize as normalize\nfrom keras.src.utils.numerical_utils import to_categorical as to_categorical\nfrom keras.src.utils.progbar import Progbar as Progbar\nfrom keras.src.utils.rng_utils import set_random_seed as set_random_seed\nfrom keras.src.utils.sequence_utils import pad_sequences as pad_sequences\nfrom keras.src.utils.text_dataset_utils import (\n    text_dataset_from_directory as text_dataset_from_directory,\n)\nfrom keras.src.utils.timeseries_dataset_utils import (\n    timeseries_dataset_from_array as timeseries_dataset_from_array,\n)\nfrom keras.utils import bounding_boxes as bounding_boxes\nfrom keras.utils import legacy as legacy\n"
  },
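Among the helpers re-exported above are `pack_x_y_sample_weight` and `unpack_x_y_sample_weight`, which normalize data into the `(x, y, sample_weight)` tuple convention used by Keras data adapters. A simplified pure-Python sketch of that convention (not the Keras implementation itself, which also handles nested structures):

```python
def pack_x_y_sample_weight(x, y=None, sample_weight=None):
    """Pack features, targets and weights into the Keras tuple convention."""
    if y is None:
        # A bare `x` stays as-is so `fit(x)` round-trips cleanly;
        # tuple/list inputs are wrapped to avoid ambiguity.
        return x if not isinstance(x, (tuple, list)) else (x,)
    if sample_weight is None:
        return (x, y)
    return (x, y, sample_weight)


def unpack_x_y_sample_weight(data):
    """Inverse of the packing above; missing entries come back as None."""
    if not isinstance(data, (tuple, list)):
        return (data, None, None)
    if len(data) == 1:
        return (data[0], None, None)
    if len(data) == 2:
        return (data[0], data[1], None)
    return (data[0], data[1], data[2])
```

This is why custom `PyDataset` batches can be yielded as `(x,)`, `(x, y)`, or `(x, y, sample_weight)` interchangeably.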
  {
    "path": "keras/api/utils/bounding_boxes/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    affine_transform as affine_transform,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    clip_to_image_size as clip_to_image_size,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    convert_format as convert_format,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    crop as crop,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    decode_deltas_to_boxes as decode_deltas_to_boxes,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    encode_box_to_deltas as encode_box_to_deltas,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (\n    pad as pad,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.iou import (\n    compute_ciou as compute_ciou,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.iou import (\n    compute_iou as compute_iou,\n)\n"
  },
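The `convert_format` utility re-exported above translates between bounding-box layouts. As an illustration of what such a conversion involves, here is a hypothetical NumPy sketch for one direction, corner coordinates (`xyxy`) to corner-plus-size (`xywh`); the function name and shape handling are illustrative, not Keras's actual implementation:

```python
import numpy as np


def xyxy_to_xywh(boxes):
    """Convert (x1, y1, x2, y2) corner boxes to (x, y, width, height).

    boxes: array-like of shape (num_boxes, 4).
    """
    boxes = np.asarray(boxes, dtype="float32")
    # Split the last axis into the four coordinate columns.
    x1, y1, x2, y2 = np.split(boxes, 4, axis=-1)
    # Width/height are the differences between opposing corners.
    return np.concatenate([x1, y1, x2 - x1, y2 - y1], axis=-1)
```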
  {
    "path": "keras/api/utils/legacy/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.legacy.saving.serialization import (\n    deserialize_keras_object as deserialize_keras_object,\n)\nfrom keras.src.legacy.saving.serialization import (\n    serialize_keras_object as serialize_keras_object,\n)\n"
  },
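The legacy `serialize_keras_object`/`deserialize_keras_object` pair round-trips objects through a config dict. A toy sketch of the `class_name`/`config` convention those functions rely on (the registry and class here are hypothetical, and the real serializer handles far more cases):

```python
REGISTRY = {}


def register(cls):
    """Record a class so it can be found again at deserialization time."""
    REGISTRY[cls.__name__] = cls
    return cls


def serialize(obj):
    # Mirror the {"class_name": ..., "config": ...} shape Keras emits.
    return {"class_name": type(obj).__name__, "config": obj.get_config()}


def deserialize(cfg):
    # Look the class up by name and rebuild it from its config.
    return REGISTRY[cfg["class_name"]](**cfg["config"])


@register
class Scale:
    def __init__(self, factor=1.0):
        self.factor = factor

    def get_config(self):
        return {"factor": self.factor}
```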
  {
    "path": "keras/api/visualization/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.visualization.draw_bounding_boxes import (\n    draw_bounding_boxes as draw_bounding_boxes,\n)\nfrom keras.src.visualization.draw_segmentation_masks import (\n    draw_segmentation_masks as draw_segmentation_masks,\n)\nfrom keras.src.visualization.plot_bounding_box_gallery import (\n    plot_bounding_box_gallery as plot_bounding_box_gallery,\n)\nfrom keras.src.visualization.plot_image_gallery import (\n    plot_image_gallery as plot_image_gallery,\n)\nfrom keras.src.visualization.plot_segmentation_mask_gallery import (\n    plot_segmentation_mask_gallery as plot_segmentation_mask_gallery,\n)\n"
  },
  {
    "path": "keras/api/wrappers/__init__.py",
    "content": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\"\n\nfrom keras.src.wrappers.sklearn_wrapper import (\n    SKLearnClassifier as SKLearnClassifier,\n)\nfrom keras.src.wrappers.sklearn_wrapper import (\n    SKLearnRegressor as SKLearnRegressor,\n)\nfrom keras.src.wrappers.sklearn_wrapper import (\n    SKLearnTransformer as SKLearnTransformer,\n)\n"
  },
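The `SKLearnClassifier`/`SKLearnRegressor`/`SKLearnTransformer` wrappers above adapt Keras models to the scikit-learn estimator protocol. A minimal illustration of that protocol in isolation (a plain mean-predicting regressor; no Keras or scikit-learn required, and the class is purely illustrative):

```python
class MeanRegressor:
    """Tiny estimator following the fit/predict convention that the
    Keras scikit-learn wrappers expose."""

    def fit(self, X, y):
        # Trailing underscore marks an attribute learned during fit.
        self.mean_ = sum(y) / len(y)
        return self  # fit returns self, per the sklearn contract

    def predict(self, X):
        return [self.mean_ for _ in X]
```

Because `fit` returns `self`, such estimators compose with pipelines and allow chained calls like `est.fit(X, y).predict(X_new)`.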
  {
    "path": "keras/src/__init__.py",
    "content": "from keras.src import activations\nfrom keras.src import applications\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import datasets\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src import regularizers\nfrom keras.src import utils\nfrom keras.src import visualization\nfrom keras.src.backend import KerasTensor\nfrom keras.src.layers import Input\nfrom keras.src.layers import Layer\nfrom keras.src.models import Functional\nfrom keras.src.models import Model\nfrom keras.src.models import Sequential\nfrom keras.src.version import __version__\n"
  },
  {
    "path": "keras/src/activations/__init__.py",
    "content": "import types\n\nfrom keras.src.activations.activations import celu\nfrom keras.src.activations.activations import elu\nfrom keras.src.activations.activations import exponential\nfrom keras.src.activations.activations import gelu\nfrom keras.src.activations.activations import glu\nfrom keras.src.activations.activations import hard_shrink\nfrom keras.src.activations.activations import hard_sigmoid\nfrom keras.src.activations.activations import hard_silu\nfrom keras.src.activations.activations import hard_tanh\nfrom keras.src.activations.activations import leaky_relu\nfrom keras.src.activations.activations import linear\nfrom keras.src.activations.activations import log_sigmoid\nfrom keras.src.activations.activations import log_softmax\nfrom keras.src.activations.activations import mish\nfrom keras.src.activations.activations import relu\nfrom keras.src.activations.activations import relu6\nfrom keras.src.activations.activations import selu\nfrom keras.src.activations.activations import sigmoid\nfrom keras.src.activations.activations import silu\nfrom keras.src.activations.activations import soft_shrink\nfrom keras.src.activations.activations import softmax\nfrom keras.src.activations.activations import softplus\nfrom keras.src.activations.activations import softsign\nfrom keras.src.activations.activations import sparse_plus\nfrom keras.src.activations.activations import sparse_sigmoid\nfrom keras.src.activations.activations import sparsemax\nfrom keras.src.activations.activations import squareplus\nfrom keras.src.activations.activations import tanh\nfrom keras.src.activations.activations import tanh_shrink\nfrom keras.src.activations.activations import threshold\nfrom keras.src.api_export import keras_export\nfrom keras.src.saving import object_registration\nfrom keras.src.saving import serialization_lib\n\nALL_OBJECTS = {\n    relu,\n    leaky_relu,\n    relu6,\n    softmax,\n    celu,\n    elu,\n    selu,\n    softplus,\n    softsign,\n    
squareplus,\n    soft_shrink,\n    sparse_plus,\n    silu,\n    gelu,\n    glu,\n    tanh,\n    tanh_shrink,\n    threshold,\n    sigmoid,\n    sparse_sigmoid,\n    exponential,\n    hard_sigmoid,\n    hard_silu,\n    hard_tanh,\n    hard_shrink,\n    linear,\n    mish,\n    log_softmax,\n    log_sigmoid,\n    sparsemax,\n}\n\nALL_OBJECTS_DICT = {fn.__name__: fn for fn in ALL_OBJECTS}\n# Additional aliases\nALL_OBJECTS_DICT[\"swish\"] = silu\nALL_OBJECTS_DICT[\"hard_swish\"] = hard_silu\n\n\n@keras_export(\"keras.activations.serialize\")\ndef serialize(activation):\n    fn_config = serialization_lib.serialize_keras_object(activation)\n    if \"config\" not in fn_config:\n        raise ValueError(\n            f\"Unknown activation function '{activation}' cannot be \"\n            \"serialized due to invalid function name. Make sure to use \"\n            \"an activation name that matches the references defined in \"\n            \"activations.py or use \"\n            \"`@keras.saving.register_keras_serializable()` \"\n            \"to register any custom activations. 
\"\n            f\"config={fn_config}\"\n        )\n    if not isinstance(activation, types.FunctionType):\n        # Case for additional custom activations represented by objects\n        return fn_config\n    if (\n        isinstance(fn_config[\"config\"], str)\n        and fn_config[\"config\"] not in globals()\n    ):\n        # Case for custom activation functions from external activations modules\n        fn_config[\"config\"] = object_registration.get_registered_name(\n            activation\n        )\n        return fn_config\n    # Case for keras.activations builtins (simply return name)\n    return fn_config[\"config\"]\n\n\n@keras_export(\"keras.activations.deserialize\")\ndef deserialize(config, custom_objects=None):\n    \"\"\"Return a Keras activation function via its config.\"\"\"\n    return serialization_lib.deserialize_keras_object(\n        config,\n        module_objects=ALL_OBJECTS_DICT,\n        custom_objects=custom_objects,\n    )\n\n\n@keras_export(\"keras.activations.get\")\ndef get(identifier):\n    \"\"\"Retrieve a Keras activation function via an identifier.\"\"\"\n    if identifier is None:\n        return linear\n    if isinstance(identifier, dict):\n        obj = serialization_lib.deserialize_keras_object(identifier)\n    elif isinstance(identifier, str):\n        obj = ALL_OBJECTS_DICT.get(identifier, None)\n    else:\n        obj = identifier\n    if callable(obj):\n        return obj\n    raise ValueError(\n        f\"Could not interpret activation function identifier: {identifier}\"\n    )\n"
  },
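The `keras.activations.get` function above resolves an identifier that may be `None`, a string, a dict config, or a callable. The dispatch can be sketched in isolation as follows (the lookup table and the two sample activations are illustrative stand-ins for `ALL_OBJECTS_DICT`):

```python
def linear(x):
    return x


def relu(x):
    return max(x, 0)


ALL_OBJECTS_DICT = {"linear": linear, "relu": relu}


def get(identifier):
    # None falls back to the identity activation, strings go through
    # the lookup table, and callables pass through unchanged.
    if identifier is None:
        return linear
    if isinstance(identifier, str):
        obj = ALL_OBJECTS_DICT.get(identifier)
    else:
        obj = identifier
    if callable(obj):
        return obj
    raise ValueError(
        f"Could not interpret activation function identifier: {identifier}"
    )
```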
  {
    "path": "keras/src/activations/activations.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\n\n\n@keras_export(\"keras.activations.relu\")\ndef relu(x, negative_slope=0.0, max_value=None, threshold=0.0):\n    \"\"\"Applies the rectified linear unit activation function.\n\n    With default values, this returns the standard ReLU activation:\n    `max(x, 0)`, the element-wise maximum of 0 and the input tensor.\n\n    Modifying default parameters allows you to use non-zero thresholds,\n    change the max value of the activation,\n    and to use a non-zero multiple of the input for values below the threshold.\n\n    Examples:\n\n    >>> x = [-10, -5, 0.0, 5, 10]\n    >>> keras.activations.relu(x)\n    [ 0.,  0.,  0.,  5., 10.]\n    >>> keras.activations.relu(x, negative_slope=0.5)\n    [-5. , -2.5,  0. ,  5. , 10. ]\n    >>> keras.activations.relu(x, max_value=5.)\n    [0., 0., 0., 5., 5.]\n    >>> keras.activations.relu(x, threshold=5.)\n    [-0., -0.,  0.,  0., 10.]\n\n    Args:\n        x: Input tensor.\n        negative_slope: A `float` that controls the slope\n            for values lower than the threshold.\n        max_value: A `float` that sets the saturation threshold (the largest\n            value the function will return).\n        threshold: A `float` giving the threshold value of the activation\n            function below which values will be damped or set to zero.\n\n    Returns:\n        A tensor with the same shape and dtype as input `x`.\n    \"\"\"\n    if backend.any_symbolic_tensors((x,)):\n        return ReLU(\n            negative_slope=negative_slope,\n            max_value=max_value,\n            threshold=threshold,\n        )(x)\n    return ReLU.static_call(\n        x,\n        negative_slope=negative_slope,\n        max_value=max_value,\n        threshold=threshold,\n    )\n\n\nclass ReLU(ops.Operation):\n    def __init__(\n        self, negative_slope=0.0, max_value=None, threshold=0.0, name=None\n    ):\n        
super().__init__(name=name)\n        self.negative_slope = negative_slope\n        self.max_value = max_value\n        self.threshold = threshold\n\n    def call(self, x):\n        return self.static_call(\n            x,\n            negative_slope=self.negative_slope,\n            max_value=self.max_value,\n            threshold=self.threshold,\n        )\n\n    def compute_output_spec(self, x):\n        return backend.KerasTensor(x.shape, x.dtype)\n\n    @staticmethod\n    def static_call(x, negative_slope=0.0, max_value=None, threshold=0.0):\n        x = backend.convert_to_tensor(x)\n        if negative_slope != 0.0:\n            if max_value is None and threshold == 0:\n                return backend.nn.leaky_relu(x, negative_slope=negative_slope)\n\n            if threshold != 0:\n                negative_part = backend.nn.relu(-x + threshold)\n            else:\n                negative_part = backend.nn.relu(-x)\n        else:\n            negative_part = 1\n\n        clip_max = max_value is not None\n        if threshold != 0:\n            # computes x for x > threshold else 0\n            threshold = ops.cast(threshold, dtype=x.dtype)\n            x = x * backend.cast(\n                backend.numpy.greater(x, threshold), dtype=x.dtype\n            )\n        elif max_value == 6:\n            # if no threshold, then can use nn.relu6 native op for performance\n            x = backend.nn.relu6(x)\n            clip_max = False\n        else:\n            x = backend.nn.relu(x)\n\n        if clip_max:\n            min_value = ops.cast(0.0, dtype=x.dtype)\n            max_value = ops.cast(max_value, dtype=x.dtype)\n            x = backend.numpy.clip(x, min_value, max_value)\n\n        if negative_slope != 0.0:\n            x -= negative_slope * negative_part\n        return x\n\n\n@keras_export(\"keras.activations.leaky_relu\")\ndef leaky_relu(x, negative_slope=0.2):\n    \"\"\"Leaky relu activation function.\n\n    Args:\n        x: Input tensor.\n        
negative_slope: A `float` that controls the slope\n            for values lower than the threshold.\n    \"\"\"\n    return ops.leaky_relu(x, negative_slope=negative_slope)\n\n\n@keras_export(\"keras.activations.relu6\")\ndef relu6(x):\n    \"\"\"Relu6 activation function.\n\n    It's the ReLU function, but truncated to a maximum value of 6.\n\n    Args:\n        x: Input tensor.\n    \"\"\"\n    return ops.relu6(x)\n\n\n@keras_export(\"keras.activations.softmax\")\ndef softmax(x, axis=-1):\n    \"\"\"Softmax converts a vector of values to a probability distribution.\n\n    The elements of the output vector are in range `[0, 1]` and sum to 1.\n\n    Each input vector is handled independently.\n    The `axis` argument sets which axis of the input the function\n    is applied along.\n\n    Softmax is often used as the activation for the last\n    layer of a classification network because the result could be interpreted as\n    a probability distribution.\n\n    The softmax of each vector x is computed as\n    `exp(x) / sum(exp(x))`.\n\n    The input values are the log-odds of the resulting probability.\n\n    Args:\n        x: Input tensor.\n        axis: Integer, axis along which the softmax is applied.\n    \"\"\"\n    output = ops.softmax(x, axis=axis)\n    # Cache the logits to use for crossentropy loss.\n    try:\n        output._keras_logits = x\n    except AttributeError:\n        # We're dealing with a C-type.\n        pass\n    return output\n\n\n@keras_export(\"keras.activations.elu\")\ndef elu(x, alpha=1.0):\n    \"\"\"Exponential Linear Unit.\n\n    The exponential linear unit (ELU) with `alpha > 0` is defined as:\n\n    - `x` if `x > 0`\n    - `alpha * (exp(x) - 1)` if `x < 0`\n\n    ELUs have negative values which push the mean of the activations\n    closer to zero.\n\n    Mean activations that are closer to zero enable faster learning as they\n    bring the gradient closer to the natural gradient.\n    ELUs saturate to a negative value when the 
argument gets smaller.\n    Saturation means a small derivative which decreases the variation\n    and the information that is propagated to the next layer.\n\n    Args:\n        x: Input tensor.\n        alpha: A scalar, slope of positive section. Defaults to `1.0`.\n\n    Reference:\n\n    - [Clevert et al., 2016](https://arxiv.org/abs/1511.07289)\n    \"\"\"\n    return ops.elu(x, alpha=alpha)\n\n\n@keras_export(\"keras.activations.selu\")\ndef selu(x):\n    \"\"\"Scaled Exponential Linear Unit (SELU).\n\n    The Scaled Exponential Linear Unit (SELU) activation function is defined as:\n\n    - `scale * x` if `x > 0`\n    - `scale * alpha * (exp(x) - 1)` if `x < 0`\n\n    where `alpha` and `scale` are pre-defined constants\n    (`alpha=1.67326324` and `scale=1.05070098`).\n\n    Basically, the SELU activation function multiplies `scale` (> 1) with the\n    output of the `keras.activations.elu` function to ensure a slope larger\n    than one for positive inputs.\n\n    The values of `alpha` and `scale` are\n    chosen so that the mean and variance of the inputs are preserved\n    between two consecutive layers as long as the weights are initialized\n    correctly (see `keras.initializers.LecunNormal` initializer)\n    and the number of input units is \"large enough\"\n    (see reference paper for more information).\n\n    Args:\n        x: Input tensor.\n\n    Notes:\n\n    - To be used together with the\n        `keras.initializers.LecunNormal` initializer.\n    - To be used together with the dropout variant\n        `keras.layers.AlphaDropout` (rather than regular dropout).\n\n    Reference:\n\n    - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)\n    \"\"\"\n    return ops.selu(x)\n\n\n@keras_export(\"keras.activations.softplus\")\ndef softplus(x):\n    \"\"\"Softplus activation function.\n\n    It is defined as: `softplus(x) = log(exp(x) + 1)`.\n\n    Args:\n        x: Input tensor.\n    \"\"\"\n    return 
ops.softplus(x)\n\n\n@keras_export(\"keras.activations.softsign\")\ndef softsign(x):\n    \"\"\"Softsign activation function.\n\n    Softsign is defined as: `softsign(x) = x / (abs(x) + 1)`.\n\n    Args:\n        x: Input tensor.\n    \"\"\"\n    return ops.softsign(x)\n\n\n@keras_export(\"keras.activations.soft_shrink\")\ndef soft_shrink(x, threshold=0.5):\n    \"\"\"Soft Shrink activation function.\n\n    It is defined as:\n\n    `soft_shrink(x) = x - threshold` if `x > threshold`,\n    `soft_shrink(x) = x + threshold` if `x < -threshold`,\n    `soft_shrink(x) = 0` otherwise.\n\n    Args:\n        x: Input tensor.\n        threshold: Threshold value. Defaults to 0.5.\n\n    \"\"\"\n    return ops.soft_shrink(x, threshold=threshold)\n\n\n@keras_export(\"keras.activations.sparse_plus\")\ndef sparse_plus(x):\n    \"\"\"SparsePlus activation function.\n\n    SparsePlus is defined as:\n\n    `sparse_plus(x) = 0` for `x <= -1`.\n    `sparse_plus(x) = (1/4) * (x + 1)^2` for `-1 < x < 1`.\n    `sparse_plus(x) = x` for `x >= 1`.\n\n    Args:\n        x: Input tensor.\n\n    \"\"\"\n    return ops.sparse_plus(x)\n\n\n@keras_export([\"keras.activations.silu\", \"keras.activations.swish\"])\ndef silu(x):\n    \"\"\"Swish (or Silu) activation function.\n\n    It is defined as: `swish(x) = x * sigmoid(x)`.\n\n    The Swish (or Silu) activation function is a smooth,\n    non-monotonic function that is unbounded above and\n    bounded below.\n\n    Args:\n        x: Input tensor.\n\n    Reference:\n\n    - [Ramachandran et al., 2017](https://arxiv.org/abs/1710.05941)\n    \"\"\"\n    return ops.silu(x)\n\n\n@keras_export(\"keras.activations.squareplus\")\ndef squareplus(x, b=4):\n    \"\"\"Squareplus activation function.\n\n    The Squareplus activation function is defined as:\n\n    `f(x) = (x + sqrt(x^2 + b)) / 2`\n\n    Where `b` is a smoothness parameter.\n\n    Args:\n        x: Input tensor.\n        b: Smoothness parameter. 
Defaults to 4.\n\n    Reference:\n\n    - [Ramachandran et al., 2021](https://arxiv.org/abs/2112.11687)\n    \"\"\"\n    return ops.squareplus(x, b=b)\n\n\n@keras_export(\"keras.activations.gelu\")\ndef gelu(x, approximate=False):\n    \"\"\"Gaussian error linear unit (GELU) activation function.\n\n    The Gaussian error linear unit (GELU) is defined as:\n\n    `gelu(x) = x * P(X <= x)` where `P(X) ~ N(0, 1)`,\n    i.e. `gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))`.\n\n    GELU weights inputs by their value, rather than gating\n    inputs by their sign as in ReLU.\n\n    Args:\n        x: Input tensor.\n        approximate: A `bool`, whether to enable approximation.\n\n    Reference:\n\n    - [Hendrycks et al., 2016](https://arxiv.org/abs/1606.08415)\n    \"\"\"\n    return ops.gelu(x, approximate=approximate)\n\n\n@keras_export(\"keras.activations.celu\")\ndef celu(x, alpha=1.0):\n    \"\"\"Continuously Differentiable Exponential Linear Unit.\n\n    The CeLU activation function is defined as:\n\n    `celu(x) = alpha * (exp(x / alpha) - 1) for x < 0`,`celu(x) = x for x >= 0`.\n\n    where `alpha` is a scaling parameter that controls the activation's shape.\n\n    Args:\n        x: Input tensor.\n        alpha: The α value for the CeLU formulation. Defaults to `1.0`.\n\n    Reference:\n\n    - [Barron, J. T., 2017](https://arxiv.org/abs/1704.07483)\n    \"\"\"\n    return ops.celu(x, alpha=alpha)\n\n\n@keras_export(\"keras.activations.glu\")\ndef glu(x, axis=-1):\n    \"\"\"Gated Linear Unit (GLU) activation function.\n\n    The GLU activation function is defined as:\n\n    `glu(x) = a * sigmoid(b)`,\n\n    where `x` is split into two equal parts `a` and `b` along the given axis.\n\n    Args:\n        x: Input tensor.\n        axis: The axis along which to split the input tensor. 
Defaults to `-1`.\n\n    Reference:\n\n    - [Dauphin et al., 2017](https://arxiv.org/abs/1612.08083)\n    \"\"\"\n    return ops.glu(x, axis=axis)\n\n\n@keras_export(\"keras.activations.tanh\")\ndef tanh(x):\n    \"\"\"Hyperbolic tangent activation function.\n\n    It is defined as:\n    `tanh(x) = sinh(x) / cosh(x)`, i.e.\n    `tanh(x) = ((exp(x) - exp(-x)) / (exp(x) + exp(-x)))`.\n\n    Args:\n        x: Input tensor.\n    \"\"\"\n    return ops.tanh(x)\n\n\n@keras_export(\"keras.activations.tanh_shrink\")\ndef tanh_shrink(x):\n    \"\"\"Tanh shrink activation function.\n\n    It is defined as:\n\n    `f(x) = x - tanh(x)`.\n\n    Args:\n        x: Input tensor.\n    \"\"\"\n    return ops.tanh_shrink(x)\n\n\n@keras_export(\"keras.activations.hard_tanh\")\ndef hard_tanh(x):\n    \"\"\"HardTanh activation function.\n\n    It is defined as:\n    `hard_tanh(x) = -1 for x < -1`,\n    `hard_tanh(x) = x for -1 <= x <= 1`,\n    `hard_tanh(x) = 1 for x > 1`.\n\n    Args:\n        x: Input tensor.\n    \"\"\"\n    return ops.hard_tanh(x)\n\n\n@keras_export(\"keras.activations.hard_shrink\")\ndef hard_shrink(x, threshold=0.5):\n    \"\"\"Hard Shrink activation function.\n\n    It is defined as:\n\n    `hard_shrink(x) = x` if `|x| > threshold`,\n    `hard_shrink(x) = 0` otherwise.\n\n    Args:\n        x: Input tensor.\n        threshold: Threshold value. 
Defaults to 0.5.\n\n    \"\"\"\n    return ops.hard_shrink(x, threshold=threshold)\n\n\n@keras_export(\"keras.activations.threshold\")\ndef threshold(x, threshold, default_value):\n    \"\"\"Threshold activation function.\n\n    It is defined as:\n\n    `threshold(x) = x` if `x > threshold`,\n    `threshold(x) = default_value` otherwise.\n\n    Args:\n        x: Input tensor.\n        threshold: The value that decides when to retain or replace x.\n        default_value: Value to assign when `x <= threshold`.\n\n    \"\"\"\n    return ops.threshold(x, threshold, default_value)\n\n\n@keras_export(\"keras.activations.sigmoid\")\ndef sigmoid(x):\n    \"\"\"Sigmoid activation function.\n\n    It is defined as: `sigmoid(x) = 1 / (1 + exp(-x))`.\n\n    For small values (<-5),\n    `sigmoid` returns a value close to zero, and for large values (>5)\n    the result of the function gets close to 1.\n\n    Sigmoid is equivalent to a 2-element softmax, where the second element is\n    assumed to be zero. 
The sigmoid function always returns a value between\n    0 and 1.\n\n    Args:\n        x: Input tensor.\n    \"\"\"\n    output = ops.sigmoid(x)\n    # Cache the logits to use for crossentropy loss.\n    try:\n        output._keras_logits = x\n    except AttributeError:\n        # We're dealing with a C-type.\n        pass\n    return output\n\n\n@keras_export(\"keras.activations.exponential\")\ndef exponential(x):\n    \"\"\"Exponential activation function.\n\n    Args:\n        x: Input tensor.\n    \"\"\"\n    return ops.exp(x)\n\n\n@keras_export(\"keras.activations.hard_sigmoid\")\ndef hard_sigmoid(x):\n    \"\"\"Hard sigmoid activation function.\n\n    The hard sigmoid activation is defined as:\n\n    - `0` if `x <= -3`\n    - `1` if `x >= 3`\n    - `(x/6) + 0.5` if `-3 < x < 3`\n\n    It's a faster, piecewise linear approximation\n    of the sigmoid activation.\n\n    Args:\n        x: Input tensor.\n\n    Reference:\n\n    - [Wikipedia \"Hard sigmoid\"](https://en.wikipedia.org/wiki/Hard_sigmoid)\n    \"\"\"\n    return ops.hard_sigmoid(x)\n\n\n@keras_export(\"keras.activations.log_sigmoid\")\ndef log_sigmoid(x):\n    \"\"\"Logarithm of the sigmoid activation function.\n\n    It is defined as `f(x) = log(1 / (1 + exp(-x)))`.\n\n    Args:\n        x: Input tensor.\n\n    \"\"\"\n    return ops.log_sigmoid(x)\n\n\n@keras_export(\"keras.activations.sparse_sigmoid\")\ndef sparse_sigmoid(x):\n    \"\"\"Sparse sigmoid activation function.\n\n    It is defined as\n\n    `f(x) = 0` for `x <= -1`,\n    `f(x) = 0.5 * (x + 1)` for `-1 < x < 1`,\n    `f(x) = 1` for `x >= 1`.\n\n    Args:\n        x: Input tensor.\n\n    Reference:\n\n    - [M. Blondel, A. F. T. Martins, V. 
Niculae, 2019](https://arxiv.org/pdf/1901.02324)\n\n    \"\"\"\n    return ops.sparse_sigmoid(x)\n\n\n@keras_export([\"keras.activations.hard_silu\", \"keras.activations.hard_swish\"])\ndef hard_silu(x):\n    \"\"\"Hard SiLU activation function, also known as Hard Swish.\n\n    It is defined as:\n\n    - `0` if `x < -3`\n    - `x` if `x > 3`\n    - `x * (x + 3) / 6` if `-3 <= x <= 3`\n\n    It's a faster, piecewise linear approximation of the silu activation.\n\n    Args:\n        x: Input tensor.\n\n    Reference:\n\n    - [A Howard, 2019](https://arxiv.org/abs/1905.02244)\n    \"\"\"\n    x = backend.convert_to_tensor(x)\n    return ops.hard_silu(x)\n\n\n@keras_export(\"keras.activations.linear\")\ndef linear(x):\n    \"\"\"Linear activation function (pass-through).\n\n    A \"linear\" activation is an identity function:\n    it returns the input, unmodified.\n\n    Args:\n        x: Input tensor.\n    \"\"\"\n    return x\n\n\nclass Mish(ops.Operation):\n    def call(self, x):\n        return self.static_call(x)\n\n    def compute_output_spec(self, x):\n        return backend.KerasTensor(x.shape, x.dtype)\n\n    @staticmethod\n    def static_call(x):\n        return x * backend.nn.tanh(backend.nn.softplus(x))\n\n\n@keras_export(\"keras.activations.mish\")\ndef mish(x):\n    \"\"\"Mish activation function.\n\n    It is defined as:\n\n    `mish(x) = x * tanh(softplus(x))`\n\n    where `softplus` is defined as:\n\n    `softplus(x) = log(exp(x) + 1)`\n\n    Args:\n        x: Input tensor.\n\n    Reference:\n\n    - [Misra, 2019](https://arxiv.org/abs/1908.08681)\n    \"\"\"\n    x = backend.convert_to_tensor(x)\n    return Mish.static_call(x)\n\n\n@keras_export(\"keras.activations.log_softmax\")\ndef log_softmax(x, axis=-1):\n    \"\"\"Log-Softmax activation function.\n\n    Each input vector is handled independently.\n    The `axis` argument sets which axis of the input the function\n    is applied along.\n\n    Args:\n        x: Input tensor.\n        axis: 
Integer, axis along which the softmax is applied.\n    \"\"\"\n    return ops.log_softmax(x, axis=axis)\n\n\n@keras_export([\"keras.activations.sparsemax\"])\ndef sparsemax(x, axis=-1):\n    \"\"\"Sparsemax activation function.\n\n    For each batch `i`, and class `j`,\n    sparsemax activation function is defined as:\n\n    `sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0).`\n\n    Args:\n        x: Input tensor.\n        axis: `int`, axis along which the sparsemax operation is applied.\n\n    Returns:\n        A tensor, output of sparsemax transformation. Has the same type and\n        shape as `x`.\n\n    Reference:\n\n    - [Martins et.al., 2016](https://arxiv.org/abs/1602.02068)\n    \"\"\"\n    x = backend.convert_to_tensor(x)\n    return ops.sparsemax(x, axis)\n"
  },
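The piecewise and closed-form definitions in the docstrings above (hard sigmoid, softplus, silu) are easy to sanity-check with scalar reference implementations. A pure-Python sketch, illustrative only and operating on single floats rather than tensors:

```python
import math


def ref_hard_sigmoid(x):
    # 0 for x <= -3, 1 for x >= 3, linear ramp (x/6 + 0.5) in between.
    return min(max(x / 6.0 + 0.5, 0.0), 1.0)


def ref_softplus(x):
    # softplus(x) = log(exp(x) + 1)
    return math.log(1.0 + math.exp(x))


def ref_silu(x):
    # swish(x) = x * sigmoid(x) = x / (1 + exp(-x))
    return x / (1.0 + math.exp(-x))
```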
  {
    "path": "keras/src/activations/activations_test.py",
    "content": "import numpy as np\n\nfrom keras.src import activations\nfrom keras.src import backend\nfrom keras.src import testing\n\n\ndef _ref_softmax(values):\n    m = np.max(values)\n    e = np.exp(values - m)\n    return e / np.sum(e)\n\n\ndef _ref_softplus(x):\n    return np.log(np.ones_like(x) + np.exp(x))\n\n\ndef _ref_log_softmax(values):\n    max_val = np.max(values)  # for numerical stability\n    stabilized_values = values - max_val\n    log_sum_exp = np.log(np.sum(np.exp(stabilized_values)))\n    return stabilized_values - log_sum_exp\n\n\ndef _ref_leaky_relu(x, alpha=0.2):\n    return x if x > 0 else alpha * x\n\n\ndef _ref_relu6(x):\n    return min(max(0, x), 6)\n\n\ndef _ref_silu(x):\n    return x / (1 + np.exp(-x))\n\n\ndef _ref_hard_sigmoid(x):\n    x = (x / 6.0) + 0.5\n    z = 0.0 if x <= 0 else (1.0 if x >= 1 else x)\n    return z\n\n\ndef _ref_sparse_sigmoid(x):\n    return np.where(x <= -1, 0, np.where(x >= 1, 1, 0.5 * (x + 1)))\n\n\ndef _ref_log_sigmoid(x):\n    return -1 * _ref_softplus(-x)\n\n\ndef _ref_hard_silu(x):\n    return x * np.minimum(np.maximum(0.0, x + 3.0), 6.0) * (1.0 / 6.0)\n\n\ndef _ref_sigmoid(x):\n    if x >= 0:\n        return 1 / (1 + np.exp(-x))\n    else:\n        z = np.exp(x)\n        return z / (1 + z)\n\n\ndef _ref_softsign(x):\n    return np.divide(x, np.ones_like(x) + np.absolute(x))\n\n\nclass ActivationsTest(testing.TestCase):\n    def test_softmax(self):\n        x = np.random.random((2, 5))\n\n        result = activations.softmax(x[np.newaxis, :])[0]\n        expected = _ref_softmax(x[0])\n        self.assertAllClose(result[0], expected, rtol=1e-05)\n\n    def test_softmax_2d_axis_0(self):\n        x = np.random.random((2, 5))\n        result = activations.softmax(x[np.newaxis, :], axis=1)[0]\n        expected = np.zeros((2, 5))\n        for i in range(5):\n            expected[:, i] = _ref_softmax(x[:, i])\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def 
test_softmax_3d_axis_tuple(self):\n        x = np.random.random((2, 3, 5))\n        result = activations.softmax(x, axis=(1, 2))\n        expected = np.zeros((2, 3, 5))\n        for i in range(2):\n            expected[i, :, :] = _ref_softmax(x[i, :, :])\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_softmax_1d(self):\n        x = np.random.random(5)\n        result = activations.softmax(x)\n        expected = _ref_softmax(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_softmax_higher_dim(self):\n        x = np.random.random((2, 3, 4, 5))\n        result = activations.softmax(x, axis=(2, 3))\n        expected = np.zeros((2, 3, 4, 5))\n        for i in range(2):\n            for j in range(3):\n                expected[i, j, :, :] = _ref_softmax(x[i, j, :, :])\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_softmax_higher_dim_multiple_axes(self):\n        x = np.random.random((2, 3, 4, 5, 6))\n        result = activations.softmax(x, axis=(2, 3, 4))\n        expected = np.zeros((2, 3, 4, 5, 6))\n        for i in range(2):\n            for j in range(3):\n                expected[i, j, :, :, :] = _ref_softmax(x[i, j, :, :, :])\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_softmax_negative_axis(self):\n        x = np.random.random((2, 5))\n        result = activations.softmax(x, axis=-1)\n        expected = np.zeros((2, 5))\n        for i in range(2):\n            expected[i, :] = _ref_softmax(x[i, :])\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_temporal_softmax(self):\n        x = np.random.random((2, 2, 3)) * 10\n        result = activations.softmax(x[np.newaxis, :])[0]\n        expected = _ref_softmax(x[0, 0])\n        self.assertAllClose(result[0, 0], expected, rtol=1e-05)\n\n    def test_log_softmax_2d_axis_0(self):\n        x = np.random.random((2, 5))\n        result = activations.log_softmax(x[np.newaxis, :], 
axis=1)[0]\n        expected = np.zeros((2, 5))\n        for i in range(5):\n            expected[:, i] = _ref_log_softmax(x[:, i])\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_log_softmax_3d_axis_tuple(self):\n        x = np.random.random((2, 3, 5))\n        result = activations.log_softmax(x, axis=(1, 2))\n        expected = np.zeros((2, 3, 5))\n        for i in range(2):\n            expected[i, :, :] = _ref_log_softmax(x[i, :, :])\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_log_softmax_1d(self):\n        x = np.random.random(5)\n        result = activations.log_softmax(x)\n        expected = _ref_log_softmax(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_log_softmax_higher_dim(self):\n        x = np.random.random((2, 3, 4, 5))\n        result = activations.log_softmax(x, axis=(2, 3))\n        expected = np.zeros((2, 3, 4, 5))\n        for i in range(2):\n            for j in range(3):\n                expected[i, j, :, :] = _ref_log_softmax(x[i, j, :, :])\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_log_softmax_higher_dim_multiple_axes(self):\n        x = np.random.random((2, 3, 4, 5, 6))\n        result = activations.log_softmax(x, axis=(2, 3, 4))\n        expected = np.zeros((2, 3, 4, 5, 6))\n        for i in range(2):\n            for j in range(3):\n                expected[i, j, :, :, :] = _ref_log_softmax(x[i, j, :, :, :])\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_log_softmax_negative_axis(self):\n        x = np.random.random((2, 5))\n        result = activations.log_softmax(x, axis=-1)\n        expected = np.zeros((2, 5))\n        for i in range(2):\n            expected[i, :] = _ref_log_softmax(x[i, :])\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_temporal_log_softmax(self):\n        x = np.random.random((2, 2, 3)) * 10\n        result = 
activations.log_softmax(x[np.newaxis, :])[0]\n        expected = _ref_log_softmax(x[0, 0])\n        self.assertAllClose(result[0, 0], expected, rtol=1e-05)\n\n    def test_selu(self):\n        alpha = 1.6732632423543772848170429916717\n        scale = 1.0507009873554804934193349852946\n\n        positive_values = np.array([[1, 2]], dtype=backend.floatx())\n        result = activations.selu(positive_values[np.newaxis, :])[0]\n        self.assertAllClose(result, positive_values * scale, rtol=1e-05)\n\n        negative_values = np.array([[-1, -2]], dtype=backend.floatx())\n        result = activations.selu(negative_values[np.newaxis, :])[0]\n        true_result = (np.exp(negative_values) - 1) * scale * alpha\n        self.assertAllClose(result, true_result)\n\n    def test_softplus(self):\n        # Basic test for random values between 0 and 1\n        x = np.random.uniform(0, 1, (2, 5))\n        result = activations.softplus(x[np.newaxis, :])[0]\n        expected = np.vectorize(_ref_softplus)(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test with 1D array\n        x_1d = np.random.uniform(-10, 10, 5)\n        result_1d = activations.softplus(x_1d)\n        expected_1d = np.vectorize(_ref_softplus)(x_1d)\n        self.assertAllClose(result_1d, expected_1d, rtol=1e-05)\n\n        # Test with 3D array\n        x_3d = np.random.uniform(-10, 10, (3, 3, 3))\n        result_3d = activations.softplus(x_3d)\n        expected_3d = np.vectorize(_ref_softplus)(x_3d)\n        self.assertAllClose(result_3d, expected_3d, rtol=1e-05)\n\n        # Test near zero values\n        x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5))\n        result_zero = activations.softplus(x_zero)\n        expected_zero = np.vectorize(_ref_softplus)(x_zero)\n        self.assertAllClose(result_zero, expected_zero, rtol=1e-05)\n\n        # Test large positive values\n        x_large_positive = np.random.uniform(10, 100, (2, 5))\n        result_large_positive = 
activations.softplus(x_large_positive)\n        expected_large_positive = np.vectorize(_ref_softplus)(x_large_positive)\n        self.assertAllClose(\n            result_large_positive, expected_large_positive, rtol=1e-05\n        )\n\n        # Test large negative values\n        x_large_negative = np.random.uniform(-100, -10, (2, 5))\n        result_large_negative = activations.softplus(x_large_negative)\n        expected_large_negative = np.vectorize(_ref_softplus)(x_large_negative)\n        self.assertAllClose(\n            result_large_negative, expected_large_negative, rtol=1e-05\n        )\n\n    def test_softsign(self):\n        # Basic test for random values between 0 and 1\n        x = np.random.uniform(0, 1, (2, 5))\n        result = activations.softsign(x[np.newaxis, :])[0]\n        expected = np.vectorize(_ref_softsign)(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test with 1D array\n        x_1d = np.random.uniform(-10, 10, 5)\n        result_1d = activations.softsign(x_1d)\n        expected_1d = np.vectorize(_ref_softsign)(x_1d)\n        self.assertAllClose(result_1d, expected_1d, rtol=1e-05)\n\n        # Test with 3D array\n        x_3d = np.random.uniform(-10, 10, (3, 3, 3))\n        result_3d = activations.softsign(x_3d)\n        expected_3d = np.vectorize(_ref_softsign)(x_3d)\n        self.assertAllClose(result_3d, expected_3d, rtol=1e-05)\n\n        # Test near zero values\n        x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5))\n        result_zero = activations.softsign(x_zero)\n        expected_zero = np.vectorize(_ref_softsign)(x_zero)\n        self.assertAllClose(result_zero, expected_zero, rtol=1e-05)\n\n        # Test large positive values\n        x_large_positive = np.random.uniform(10, 100, (2, 5))\n        result_large_positive = activations.softsign(x_large_positive)\n        expected_large_positive = np.vectorize(_ref_softsign)(x_large_positive)\n        self.assertAllClose(\n            
result_large_positive, expected_large_positive, rtol=1e-05\n        )\n\n        # Test large negative values\n        x_large_negative = np.random.uniform(-100, -10, (2, 5))\n        result_large_negative = activations.softsign(x_large_negative)\n        expected_large_negative = np.vectorize(_ref_softsign)(x_large_negative)\n        self.assertAllClose(\n            result_large_negative, expected_large_negative, rtol=1e-05\n        )\n\n    def test_sigmoid(self):\n        # Basic test for random values between 0 and 1\n        x = np.random.uniform(0, 1, (2, 5))\n        result = activations.sigmoid(x[np.newaxis, :])[0]\n        expected = np.vectorize(_ref_sigmoid)(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test with 1D array\n        x_1d = np.random.uniform(-10, 10, 5)\n        result_1d = activations.sigmoid(x_1d)\n        expected_1d = np.vectorize(_ref_sigmoid)(x_1d)\n        self.assertAllClose(result_1d, expected_1d, rtol=1e-05)\n\n        # Test with 3D array\n        x_3d = np.random.uniform(-10, 10, (3, 3, 3))\n        result_3d = activations.sigmoid(x_3d)\n        expected_3d = np.vectorize(_ref_sigmoid)(x_3d)\n        self.assertAllClose(result_3d, expected_3d, rtol=1e-05)\n\n        # Test near zero values\n        x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5))\n        result_zero = activations.sigmoid(x_zero)\n        expected_zero = np.vectorize(_ref_sigmoid)(x_zero)\n        self.assertAllClose(result_zero, expected_zero, rtol=1e-05)\n\n        # Test large positive values\n        x_large_positive = np.random.uniform(10, 100, (2, 5))\n        result_large_positive = activations.sigmoid(x_large_positive)\n        expected_large_positive = np.vectorize(_ref_sigmoid)(x_large_positive)\n        self.assertAllClose(\n            result_large_positive, expected_large_positive, rtol=1e-05\n        )\n\n        # Test large negative values\n        x_large_negative = np.random.uniform(-100, -10, (2, 5))\n        
result_large_negative = activations.sigmoid(x_large_negative)\n        expected_large_negative = np.vectorize(_ref_sigmoid)(x_large_negative)\n        self.assertAllClose(\n            result_large_negative, expected_large_negative, rtol=1e-05\n        )\n\n    def test_hard_sigmoid(self):\n        # Basic test for random values between 0 and 1\n        x = np.random.uniform(0, 1, (2, 5))\n        result = activations.hard_sigmoid(x[np.newaxis, :])[0]\n        expected = np.vectorize(_ref_hard_sigmoid)(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test with 1D array\n        x_1d = np.random.uniform(-10, 10, 5)\n        result_1d = activations.hard_sigmoid(x_1d)\n        expected_1d = np.vectorize(_ref_hard_sigmoid)(x_1d)\n        self.assertAllClose(result_1d, expected_1d, rtol=1e-05)\n\n        # Test with 3D array\n        x_3d = np.random.uniform(-10, 10, (3, 3, 3))\n        result_3d = activations.hard_sigmoid(x_3d)\n        expected_3d = np.vectorize(_ref_hard_sigmoid)(x_3d)\n        self.assertAllClose(result_3d, expected_3d, rtol=1e-05)\n\n        # Test with strictly positive values much larger than 1\n        x_positive_above_1 = np.random.uniform(5, 10, (2, 5))\n        result_positive_above_1 = activations.hard_sigmoid(x_positive_above_1)\n        expected_positive_above_1 = np.ones((2, 5))\n        self.assertAllClose(\n            result_positive_above_1, expected_positive_above_1, rtol=1e-05\n        )\n\n    def test_sparse_sigmoid(self):\n        # Basic test for random values between 0 and 1\n        x = np.random.uniform(0, 1, (2, 5))\n        result = activations.sparse_sigmoid(x[np.newaxis, :])[0]\n        expected = np.vectorize(_ref_sparse_sigmoid)(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test with 1D array\n        x_1d = np.random.uniform(-10, 10, 5)\n        result_1d = activations.sparse_sigmoid(x_1d)\n        expected_1d = 
np.vectorize(_ref_sparse_sigmoid)(x_1d)\n        self.assertAllClose(result_1d, expected_1d, rtol=1e-05)\n\n        # Test with 3D array\n        x_3d = np.random.uniform(-10, 10, (3, 3, 3))\n        result_3d = activations.sparse_sigmoid(x_3d)\n        expected_3d = np.vectorize(_ref_sparse_sigmoid)(x_3d)\n        self.assertAllClose(result_3d, expected_3d, rtol=1e-05)\n\n        # Test large positive values\n        x_large_positive = np.random.uniform(10, 100, (2, 5))\n        result_large_positive = activations.sparse_sigmoid(x_large_positive)\n        expected_large_positive = np.vectorize(_ref_sparse_sigmoid)(\n            x_large_positive\n        )\n        self.assertAllClose(\n            result_large_positive, expected_large_positive, rtol=1e-05\n        )\n\n        # Test large negative values\n        x_large_negative = np.random.uniform(-100, -10, (2, 5))\n        result_large_negative = activations.sparse_sigmoid(x_large_negative)\n        expected_large_negative = np.vectorize(_ref_sparse_sigmoid)(\n            x_large_negative\n        )\n        self.assertAllClose(\n            result_large_negative, expected_large_negative, rtol=1e-05\n        )\n\n    def test_log_sigmoid(self):\n        # Basic test for random values between 0 and 1\n        x = np.random.uniform(0, 1, (2, 5))\n        result = activations.log_sigmoid(x[np.newaxis, :])[0]\n        expected = np.vectorize(_ref_log_sigmoid)(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test with 1D array\n        x_1d = np.random.uniform(-10, 10, 5)\n        result_1d = activations.log_sigmoid(x_1d)\n        expected_1d = np.vectorize(_ref_log_sigmoid)(x_1d)\n        self.assertAllClose(result_1d, expected_1d, rtol=1e-05)\n\n        # Test with 3D array\n        x_3d = np.random.uniform(-10, 10, (3, 3, 3))\n        result_3d = activations.log_sigmoid(x_3d)\n        expected_3d = np.vectorize(_ref_log_sigmoid)(x_3d)\n        self.assertAllClose(result_3d, 
expected_3d, rtol=1e-05)\n\n        # Test large positive values\n        x_large_positive = np.random.uniform(10, 100, (2, 5))\n        result_large_positive = activations.log_sigmoid(x_large_positive)\n        expected_large_positive = np.vectorize(_ref_log_sigmoid)(\n            x_large_positive\n        )\n        self.assertAllClose(\n            result_large_positive, expected_large_positive, rtol=1e-05\n        )\n\n        # Test large negative values\n        x_large_negative = np.random.uniform(-100, -10, (2, 5))\n        result_large_negative = activations.log_sigmoid(x_large_negative)\n        expected_large_negative = np.vectorize(_ref_log_sigmoid)(\n            x_large_negative\n        )\n        self.assertAllClose(\n            result_large_negative, expected_large_negative, rtol=1e-05\n        )\n\n    def test_hard_silu(self):\n        # Basic test for random values between -3 and 3\n        x = np.random.uniform(-3, 3, (2, 5)).astype(\"float32\")\n        result = activations.hard_silu(x[np.newaxis, :])[0]\n        expected = np.vectorize(_ref_hard_silu)(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test with 1D array\n        x_1d = np.random.uniform(-10, 10, 5).astype(\"float32\")\n        result_1d = activations.hard_silu(x_1d)\n        expected_1d = np.vectorize(_ref_hard_silu)(x_1d)\n        self.assertAllClose(result_1d, expected_1d, rtol=1e-05)\n\n        # Test with 3D array\n        x_3d = np.random.uniform(-10, 10, (3, 3, 3)).astype(\"float32\")\n        result_3d = activations.hard_silu(x_3d)\n        expected_3d = np.vectorize(_ref_hard_silu)(x_3d)\n        self.assertAllClose(result_3d, expected_3d, rtol=1e-05)\n\n        # Test with strictly positive values much larger than 3\n        x_positive_above_3 = np.random.uniform(5, 10, (2, 5)).astype(\"float32\")\n        result_positive_above_3 = activations.hard_silu(x_positive_above_3)\n        expected_positive_above_3 = x_positive_above_3\n        
self.assertAllClose(\n            result_positive_above_3, expected_positive_above_3, rtol=1e-05\n        )\n\n        # Test with strictly negative values much smaller than -3\n        x_negatives = np.random.uniform(-10, -5, (2, 5)).astype(\"float32\")\n        result = activations.hard_silu(x_negatives)\n        expected_zeros = np.zeros_like(x_negatives)\n        self.assertAllClose(result, expected_zeros, rtol=1e-05)\n\n    def test_relu_negative_slope(self):\n        # Define the input tensor\n        x = np.array([-10, -5, 0.0, 5, 10])\n\n        # Test with only negative_slope\n        result_negative_slope = activations.relu(x, negative_slope=0.5)\n        expected_negative_slope = np.array([-5.0, -2.5, 0.0, 5.0, 10.0])\n        self.assertAllClose(\n            result_negative_slope, expected_negative_slope, rtol=1e-05\n        )\n\n    def test_relu_max_value(self):\n        # Define the input tensor\n        x = np.array([-10, -5, 0.0, 5, 10])\n\n        # Test with only max_value\n        result_max_value = activations.relu(x, max_value=5.0)\n        expected_max_value = np.array([0.0, 0.0, 0.0, 5.0, 5.0])\n        self.assertAllClose(result_max_value, expected_max_value, rtol=1e-05)\n\n    def test_relu_threshold(self):\n        # Define the input tensor\n        x = np.array([-10, -5, 0.0, 5, 10])\n\n        # Test with only threshold\n        result_threshold = activations.relu(x, threshold=5.0)\n        expected_threshold = np.array([-0.0, -0.0, 0.0, 0.0, 10.0])\n        self.assertAllClose(result_threshold, expected_threshold, rtol=1e-05)\n\n    def test_relu_combined_threshold_and_max_value(self):\n        # Define the input tensor\n        x = np.array([-10, -5, 0.0, 5, 10])\n\n        # Test with threshold and max_value\n        result_combined = activations.relu(x, threshold=5.0, max_value=5.0)\n        expected_combined = np.array([0.0, 0.0, 0.0, 0.0, 5.0])\n        self.assertAllClose(result_combined, expected_combined, rtol=1e-05)\n\n    
def test_relu_combined_all_parameters(self):\n        # Define the input tensor\n        x = np.array([-10, -5, 0.0, 5, 10])\n\n        # Test with negative_slope, max_value, and threshold\n        result_combined = activations.relu(\n            x, negative_slope=0.5, max_value=5.0, threshold=5.0\n        )\n        expected_combined = np.array([-7.5, -5.0, -2.5, 0.0, 5.0])\n        self.assertAllClose(result_combined, expected_combined, rtol=1e-05)\n\n    def test_relu_to_trigger_relu6(self):\n        x = np.array([-10, -5, 0.0, 5, 10, 12])\n        result_relu6 = activations.relu(x, max_value=6.0)\n        expected_relu6 = np.array([0.0, 0.0, 0.0, 5.0, 6.0, 6.0])\n        self.assertAllClose(result_relu6, expected_relu6, rtol=1e-05)\n\n    def test_relu_to_trigger_leaky(self):\n        x = np.array([-10, -5, 0.0, 5, 10])\n        result_leaky = activations.relu(x, negative_slope=0.5)\n        expected_leaky = np.array([-5.0, -2.5, 0.0, 5.0, 10.0])\n        self.assertAllClose(result_leaky, expected_leaky, rtol=1e-05)\n\n    def test_relu(self):\n        # Basic test for positive values\n        positive_values = np.random.uniform(0.1, 10, (2, 5))\n        result = activations.relu(positive_values[np.newaxis, :])[0]\n        self.assertAllClose(result, positive_values, rtol=1e-05)\n\n        # Basic test for negative values\n        negative_values = np.random.uniform(-10, -0.1, (2, 5))\n        result = activations.relu(negative_values[np.newaxis, :])[0]\n        expected = np.zeros((2, 5))\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test with 1D array\n        x_1d = np.random.uniform(-10, 10, 5)\n        result_1d = activations.relu(x_1d)\n        expected_1d = np.maximum(0, x_1d)\n        self.assertAllClose(result_1d, expected_1d, rtol=1e-05)\n\n        # Test with 3D array\n        x_3d = np.random.uniform(-10, 10, (3, 3, 3))\n        result_3d = activations.relu(x_3d)\n        expected_3d = np.maximum(0, x_3d)\n        
self.assertAllClose(result_3d, expected_3d, rtol=1e-05)\n\n        # Test near zero values\n        x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5))\n        result_zero = activations.relu(x_zero)\n        expected_zero = np.maximum(0, x_zero)\n        self.assertAllClose(result_zero, expected_zero, rtol=1e-05)\n\n        # Test large positive values\n        x_large_positive = np.random.uniform(1e4, 1e5, (2, 5))\n        result_large_positive = activations.relu(x_large_positive)\n        self.assertAllClose(result_large_positive, x_large_positive, rtol=1e-05)\n\n        # Test large negative values\n        x_large_negative = np.random.uniform(-1e5, -1e4, (2, 5))\n        result_large_negative = activations.relu(x_large_negative)\n        expected_large_negative = np.zeros((2, 5))\n        self.assertAllClose(\n            result_large_negative, expected_large_negative, rtol=1e-05\n        )\n\n    def test_leaky_relu(self):\n        leaky_relu_vectorized = np.vectorize(_ref_leaky_relu)\n\n        # Test for negative_slope = 0.01\n        # Test positive values\n        positive_values = np.random.random((2, 5))\n        result = activations.leaky_relu(\n            positive_values[np.newaxis, :], negative_slope=0.01\n        )[0]\n        expected = leaky_relu_vectorized(positive_values, alpha=0.01)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test negative values\n        negative_values = np.random.uniform(-1, 0, (2, 5))\n        result = activations.leaky_relu(\n            negative_values[np.newaxis, :], negative_slope=0.01\n        )[0]\n        expected = leaky_relu_vectorized(negative_values, alpha=0.01)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test for negative_slope = 0.3\n        # Test positive values\n        positive_values = np.random.random((2, 5))\n        result = activations.leaky_relu(\n            positive_values[np.newaxis, :], negative_slope=0.3\n        )[0]\n        expected = 
leaky_relu_vectorized(positive_values, alpha=0.3)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test negative values\n        negative_values = np.random.uniform(-1, 0, (2, 5))\n        result = activations.leaky_relu(\n            negative_values[np.newaxis, :], negative_slope=0.3\n        )[0]\n        expected = leaky_relu_vectorized(negative_values, alpha=0.3)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_relu6(self):\n        relu6_vectorized = np.vectorize(_ref_relu6)\n\n        # Test positive values less than 6\n        positive_values = np.random.uniform(0, 5.9, (2, 5))\n        result = activations.relu6(positive_values[np.newaxis, :])[0]\n        expected = relu6_vectorized(positive_values)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test positive values greater than 6\n        positive_values_above_6 = np.random.uniform(6.1, 10, (2, 5))\n        result = activations.relu6(positive_values_above_6[np.newaxis, :])[0]\n        expected = relu6_vectorized(positive_values_above_6)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test negative values\n        negative_values = np.random.uniform(-1, 0, (2, 5))\n        result = activations.relu6(negative_values[np.newaxis, :])[0]\n        expected = relu6_vectorized(negative_values)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_silu(self):\n        silu_vectorized = np.vectorize(_ref_silu)\n\n        # Test positive values\n        positive_values = np.random.uniform(0, 5.9, (2, 5))\n        result = activations.silu(positive_values[np.newaxis, :])[0]\n        expected = silu_vectorized(positive_values)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test values around zero (to ensure sigmoid behaves correctly)\n        around_zero_values = np.random.uniform(-1, 1, (2, 5))\n        result = activations.silu(around_zero_values[np.newaxis, :])[0]\n        
expected = silu_vectorized(around_zero_values)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test negative values\n        negative_values = np.random.uniform(-5.9, 0, (2, 5))\n        result = activations.silu(negative_values[np.newaxis, :])[0]\n        expected = silu_vectorized(negative_values)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_gelu(self):\n        def gelu(x, approximate=False):\n            if approximate:\n                return (\n                    0.5\n                    * x\n                    * (\n                        1.0\n                        + np.tanh(\n                            np.sqrt(2.0 / np.pi)\n                            * (x + 0.044715 * np.power(x, 3))\n                        )\n                    )\n                )\n            else:\n                from scipy.stats import norm\n\n                return x * norm.cdf(x)\n\n        x = np.random.random((2, 5))\n        result = activations.gelu(x[np.newaxis, :])[0]\n        expected = gelu(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        x = np.random.random((2, 5))\n        result = activations.gelu(x[np.newaxis, :], approximate=True)[0]\n        expected = gelu(x, True)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_celu(self):\n        def celu(x, alpha=1.0):\n            return np.maximum(x, 0.0) + alpha * np.expm1(\n                np.minimum(x, 0.0) / alpha\n            )\n\n        x = np.random.random((2, 5))\n        result = activations.celu(x[np.newaxis, :])[0]\n        expected = celu(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        x = np.random.random((2, 5))\n        result = activations.celu(x[np.newaxis, :], alpha=0.5)[0]\n        expected = celu(x, alpha=0.5)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_glu(self):\n        def glu(x, axis=-1):\n            x1, x2 = np.split(x, 2, axis)\n     
       return x1 * (1 / (1 + np.exp(-x2)))\n\n        x = np.random.random((2, 4))\n        result = activations.glu(x[np.newaxis, :])[0]\n        expected = glu(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        x = np.random.random((2, 4))\n        result = activations.glu(x[np.newaxis, :], axis=-2)[0]\n        expected = glu(x, axis=-2)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_tanh_shrink(self):\n        def tanh_shrink(x):\n            return x - np.tanh(x)\n\n        x = np.random.random((2, 5))\n        result = activations.tanh_shrink(x[np.newaxis, :])[0]\n        expected = tanh_shrink(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_hard_tanh(self):\n        def hard_tanh(x):\n            return np.clip(x, -1.0, 1.0)\n\n        x = np.random.random((2, 5))\n        result = activations.hard_tanh(x[np.newaxis, :])[0]\n        expected = hard_tanh(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_hard_shrink(self):\n        def hard_shrink(x):\n            return np.where(np.abs(x) > 0.5, x, 0.0)\n\n        x = np.random.random((2, 5))\n        result = activations.hard_shrink(x[np.newaxis, :])[0]\n        expected = hard_shrink(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_threshold(self):\n        def threshold(x, threshold_value, value):\n            return np.where(\n                x > threshold_value, x, np.array(value, dtype=x.dtype)\n            )\n\n        x = np.random.random((2, 5))\n        result = activations.threshold(x[np.newaxis, :], 0, 0)[0]\n        expected = threshold(x, 0, 0)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_squareplus(self):\n        def squareplus(x, b=4):\n            y = x + np.sqrt(x**2 + b)\n            return y / 2\n\n        x = np.random.random((2, 5))\n        result = activations.squareplus(x[np.newaxis, :])[0]\n        expected = 
squareplus(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_soft_shrink(self):\n        def soft_shrink(x, threshold=0.5):\n            return np.where(\n                x > threshold,\n                x - threshold,\n                np.where(x < -threshold, x + threshold, 0.0),\n            )\n\n        x = np.random.random((2, 5))\n        result = activations.soft_shrink(x[np.newaxis, :])[0]\n        expected = soft_shrink(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_sparse_plus(self):\n        def sparse_plus(x):\n            return np.where(\n                x <= -1,\n                np.zeros_like(x),\n                np.where(x < 1, (1 / 4) * (x + 1) ** 2, x),\n            )\n\n        x = np.random.random((2, 5))\n        result = activations.sparse_plus(x[np.newaxis, :])[0]\n        expected = sparse_plus(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n    def test_elu(self):\n        x = np.random.random((2, 5))\n        result = activations.elu(x[np.newaxis, :])[0]\n        self.assertAllClose(result, x, rtol=1e-05)\n        negative_values = np.array([[-1, -2]], dtype=backend.floatx())\n        result = activations.elu(negative_values[np.newaxis, :])[0]\n        true_result = np.exp(negative_values) - 1\n        self.assertAllClose(result, true_result)\n\n    def test_tanh(self):\n        # Basic test for the tanh activation function\n        x = np.random.random((2, 5))\n        result = activations.tanh(x[np.newaxis, :])[0]\n        expected = np.tanh(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Basic test for the tanh activation function\n        x = np.random.uniform(-10, 10, (2, 5))\n        result = activations.tanh(x[np.newaxis, :])[0]\n        expected = np.tanh(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test with 1D array\n        x_1d = np.random.uniform(-10, 10, 5)\n        result_1d = 
activations.tanh(x_1d)\n        expected_1d = np.tanh(x_1d)\n        self.assertAllClose(result_1d, expected_1d, rtol=1e-05)\n\n        # Test with 3D array\n        x_3d = np.random.uniform(-10, 10, (3, 3, 3))\n        result_3d = activations.tanh(x_3d)\n        expected_3d = np.tanh(x_3d)\n        self.assertAllClose(result_3d, expected_3d, rtol=1e-05)\n\n        # Test with strictly positive values\n        x_positive = np.random.uniform(0, 10, (2, 5))\n        result_positive = activations.tanh(x_positive)\n        expected_positive = np.tanh(x_positive)\n        self.assertAllClose(result_positive, expected_positive, rtol=1e-05)\n\n        # Test with strictly negative values\n        x_negative = np.random.uniform(-10, 0, (2, 5))\n        result_negative = activations.tanh(x_negative)\n        expected_negative = np.tanh(x_negative)\n        self.assertAllClose(result_negative, expected_negative, rtol=1e-05)\n\n        # Test near zero values\n        x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5))\n        result_zero = activations.tanh(x_zero)\n        expected_zero = np.tanh(x_zero)\n        self.assertAllClose(result_zero, expected_zero, rtol=1e-05)\n\n        # Test large values to check stability\n        x_large = np.random.uniform(1e4, 1e5, (2, 5))\n        result_large = activations.tanh(x_large)\n        expected_large = np.tanh(x_large)\n        self.assertAllClose(result_large, expected_large, rtol=1e-05)\n\n    def test_exponential(self):\n        # Basic test for the exponential activation function\n        x = np.random.random((2, 5))\n        result = activations.exponential(x[np.newaxis, :])[0]\n        expected = np.exp(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        x = np.random.uniform(-10, 10, (2, 5))\n        result = activations.exponential(x[np.newaxis, :])[0]\n        expected = np.exp(x)\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test with 1D array\n        x_1d = 
np.random.uniform(-10, 10, 5)\n        result_1d = activations.exponential(x_1d)\n        expected_1d = np.exp(x_1d)\n        self.assertAllClose(result_1d, expected_1d, rtol=1e-05)\n\n        # Test with 3D array\n        x_3d = np.random.uniform(-10, 10, (3, 3, 3))\n        result_3d = activations.exponential(x_3d)\n        expected_3d = np.exp(x_3d)\n        self.assertAllClose(result_3d, expected_3d, rtol=1e-05)\n\n        # Test with strictly positive values\n        x_positive = np.random.uniform(0, 10, (2, 5))\n        result_positive = activations.exponential(x_positive)\n        expected_positive = np.exp(x_positive)\n        self.assertAllClose(result_positive, expected_positive, rtol=1e-05)\n\n        # Test with strictly negative values\n        x_negative = np.random.uniform(-10, 0, (2, 5))\n        result_negative = activations.exponential(x_negative)\n        expected_negative = np.exp(x_negative)\n        self.assertAllClose(result_negative, expected_negative, rtol=1e-05)\n\n        # Test near zero values\n        x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5))\n        result_zero = activations.exponential(x_zero)\n        expected_zero = np.exp(x_zero)\n        self.assertAllClose(result_zero, expected_zero, rtol=1e-05)\n\n        # Test large values to check stability\n        x_large = np.random.uniform(1e4, 1e5, (2, 5))\n        result_large = activations.exponential(x_large)\n        expected_large = np.exp(x_large)\n        self.assertAllClose(result_large, expected_large, rtol=1e-05)\n\n    def test_mish(self):\n        # Basic test for the mish activation function\n        x = np.random.random((2, 5))\n        result = activations.mish(x[np.newaxis, :])[0]\n        expected = x * np.tanh(_ref_softplus(x))\n        self.assertAllClose(result, expected, rtol=1e-05)\n\n        x = np.random.uniform(-10, 10, (2, 5))\n        result = activations.mish(x[np.newaxis, :])[0]\n        expected = x * np.tanh(_ref_softplus(x))\n        
self.assertAllClose(result, expected, rtol=1e-05)\n\n        # Test with 1D array\n        x_1d = np.random.uniform(-10, 10, 5)\n        result_1d = activations.mish(x_1d)\n        expected_1d = x_1d * np.tanh(_ref_softplus(x_1d))\n        self.assertAllClose(result_1d, expected_1d, rtol=1e-05)\n\n        # Test with 3D array\n        x_3d = np.random.uniform(-10, 10, (3, 3, 3))\n        result_3d = activations.mish(x_3d)\n        expected_3d = x_3d * np.tanh(_ref_softplus(x_3d))\n        self.assertAllClose(result_3d, expected_3d, rtol=1e-05)\n\n        # Test with strictly positive values\n        x_positive = np.random.uniform(0, 10, (2, 5))\n        result_positive = activations.mish(x_positive)\n        expected_positive = x_positive * np.tanh(_ref_softplus(x_positive))\n        self.assertAllClose(result_positive, expected_positive, rtol=1e-05)\n\n        # Test with strictly negative values\n        x_negative = np.random.uniform(-10, 0, (2, 5))\n        result_negative = activations.mish(x_negative)\n        expected_negative = x_negative * np.tanh(_ref_softplus(x_negative))\n        self.assertAllClose(result_negative, expected_negative, rtol=1e-05)\n\n        # Test near zero values\n        x_zero = np.random.uniform(-1e-7, 1e-7, (2, 5))\n        result_zero = activations.mish(x_zero)\n        expected_zero = x_zero * np.tanh(_ref_softplus(x_zero))\n        self.assertAllClose(result_zero, expected_zero, rtol=1e-05)\n\n        # Test large values to check stability\n        x_large = np.random.uniform(1e4, 1e5, (2, 5))\n        result_large = activations.mish(x_large)\n        expected_large = x_large * np.tanh(_ref_softplus(x_large))\n        self.assertAllClose(result_large, expected_large, rtol=1e-05)\n\n    def test_linear(self):\n        x = np.random.random((10, 5))\n        self.assertAllClose(x, activations.linear(x))\n\n        # Test with 1D array\n        x_1d = np.random.uniform(-10, 10, 5)\n        self.assertAllClose(x_1d, 
activations.linear(x_1d))\n\n        # Test with 2D array\n        x = np.random.uniform(-10, 10, (10, 5))\n        self.assertAllClose(x, activations.linear(x))\n\n        # Test with 3D array\n        x_3d = np.random.uniform(-10, 10, (5, 5, 5))\n        self.assertAllClose(x_3d, activations.linear(x_3d))\n\n        # Test with float32 data type\n        x_float32 = np.random.uniform(-10, 10, (10, 5)).astype(np.float32)\n        self.assertAllClose(x_float32, activations.linear(x_float32))\n        # Test with int32 data type\n        x_int32 = np.random.randint(-10, 10, (10, 5)).astype(np.int32)\n        self.assertAllClose(x_int32, activations.linear(x_int32))\n\n    def test_sparsemax(self):\n        # result check with 1d\n        x_1d = np.linspace(1, 12, num=12)\n        expected_result = np.zeros_like(x_1d)\n        expected_result[-1] = 1.0\n        self.assertAllClose(expected_result, activations.sparsemax(x_1d))\n\n        # result check with 2d\n        x_2d = np.linspace(1, 12, num=12).reshape(-1, 2)\n        expected_result = np.zeros_like(x_2d)\n        expected_result[:, -1] = 1.0\n        self.assertAllClose(expected_result, activations.sparsemax(x_2d))\n\n        # result check with 3d\n        x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)\n        expected_result = np.zeros_like(x_3d)\n        expected_result[:, :, -1] = 1.0\n        self.assertAllClose(expected_result, activations.sparsemax(x_3d))\n\n        # result check with axis=-2 with 2d input\n        x_2d = np.linspace(1, 12, num=12).reshape(-1, 2)\n        expected_result = np.zeros_like(x_2d)\n        expected_result[-1, :] = 1.0\n        self.assertAllClose(\n            expected_result, activations.sparsemax(x_2d, axis=-2)\n        )\n\n        # result check with axis=-2 with 3d input\n        x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)\n        expected_result = np.ones_like(x_3d)\n        self.assertAllClose(\n            expected_result, activations.sparsemax(x_3d, 
axis=-2)\n        )\n\n        # result check with axis=-3 with 3d input\n        x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)\n        expected_result = np.zeros_like(x_3d)\n        expected_result[-1, :, :] = 1.0\n        self.assertAllClose(\n            expected_result, activations.sparsemax(x_3d, axis=-3)\n        )\n\n        # result check with axis=-3 with 4d input\n        x_4d = np.linspace(1, 12, num=12).reshape(-1, 1, 1, 2)\n        expected_result = np.ones_like(x_4d)\n        self.assertAllClose(\n            expected_result, activations.sparsemax(x_4d, axis=-3)\n        )\n\n    def test_get_method(self):\n        obj = activations.get(\"relu\")\n        self.assertEqual(obj, activations.relu)\n\n        obj = activations.get(None)\n        self.assertEqual(obj, activations.linear)\n\n        with self.assertRaises(ValueError):\n            activations.get(\"typo\")\n"
  },
  {
    "path": "keras/src/api_export.py",
    "content": "try:\n    import namex\nexcept ImportError:\n    namex = None\n\n\n# These dicts reference \"canonical names\" only\n# (i.e. the first name an object was registered with).\nREGISTERED_NAMES_TO_OBJS = {}\nREGISTERED_OBJS_TO_NAMES = {}\n\n\ndef register_internal_serializable(path, symbol):\n    global REGISTERED_NAMES_TO_OBJS\n    if isinstance(path, (list, tuple)):\n        name = path[0]\n    else:\n        name = path\n    REGISTERED_NAMES_TO_OBJS[name] = symbol\n    REGISTERED_OBJS_TO_NAMES[symbol] = name\n\n\ndef get_symbol_from_name(name):\n    return REGISTERED_NAMES_TO_OBJS.get(name, None)\n\n\ndef get_name_from_symbol(symbol):\n    return REGISTERED_OBJS_TO_NAMES.get(symbol, None)\n\n\nif namex:\n\n    class keras_export(namex.export):\n        def __init__(self, path):\n            super().__init__(package=\"keras\", path=path)\n\n        def __call__(self, symbol):\n            register_internal_serializable(self.path, symbol)\n            return super().__call__(symbol)\n\nelse:\n\n    class keras_export:\n        def __init__(self, path):\n            self.path = path\n\n        def __call__(self, symbol):\n            register_internal_serializable(self.path, symbol)\n            return symbol\n"
  },
  {
    "path": "keras/src/applications/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/applications/applications_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.applications import convnext\nfrom keras.src.applications import densenet\nfrom keras.src.applications import efficientnet\nfrom keras.src.applications import efficientnet_v2\nfrom keras.src.applications import inception_resnet_v2\nfrom keras.src.applications import inception_v3\nfrom keras.src.applications import mobilenet\nfrom keras.src.applications import mobilenet_v2\nfrom keras.src.applications import mobilenet_v3\nfrom keras.src.applications import nasnet\nfrom keras.src.applications import resnet\nfrom keras.src.applications import resnet_v2\nfrom keras.src.applications import vgg16\nfrom keras.src.applications import vgg19\nfrom keras.src.applications import xception\nfrom keras.src.layers import Conv2D\nfrom keras.src.layers import Input\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils import file_utils\nfrom keras.src.utils import image_utils\n\ntry:\n    import PIL\nexcept ImportError:\n    PIL = None\n\nMODEL_LIST = [\n    # vgg\n    (vgg16.VGG16, 512, vgg16),\n    (vgg19.VGG19, 512, vgg19),\n    # xception\n    (xception.Xception, 2048, xception),\n    # inception\n    (inception_v3.InceptionV3, 2048, inception_v3),\n    (inception_resnet_v2.InceptionResNetV2, 1536, inception_resnet_v2),\n    # mobilenet\n    (mobilenet.MobileNet, 1024, mobilenet),\n    (mobilenet_v2.MobileNetV2, 1280, mobilenet_v2),\n    (mobilenet_v3.MobileNetV3Small, 576, mobilenet_v3),\n    (mobilenet_v3.MobileNetV3Large, 960, mobilenet_v3),\n    # efficientnet\n    (efficientnet.EfficientNetB0, 1280, efficientnet),\n    (efficientnet.EfficientNetB1, 1280, efficientnet),\n    (efficientnet.EfficientNetB2, 1408, efficientnet),\n    (efficientnet.EfficientNetB3, 1536, efficientnet),\n    (efficientnet.EfficientNetB4, 1792, efficientnet),\n    (efficientnet.EfficientNetB5, 2048, 
efficientnet),\n    (efficientnet.EfficientNetB6, 2304, efficientnet),\n    (efficientnet.EfficientNetB7, 2560, efficientnet),\n    (efficientnet_v2.EfficientNetV2B0, 1280, efficientnet_v2),\n    (efficientnet_v2.EfficientNetV2B1, 1280, efficientnet_v2),\n    (efficientnet_v2.EfficientNetV2B2, 1408, efficientnet_v2),\n    (efficientnet_v2.EfficientNetV2B3, 1536, efficientnet_v2),\n    (efficientnet_v2.EfficientNetV2S, 1280, efficientnet_v2),\n    (efficientnet_v2.EfficientNetV2M, 1280, efficientnet_v2),\n    (efficientnet_v2.EfficientNetV2L, 1280, efficientnet_v2),\n    # densenet\n    (densenet.DenseNet121, 1024, densenet),\n    (densenet.DenseNet169, 1664, densenet),\n    (densenet.DenseNet201, 1920, densenet),\n    # convnext\n    (convnext.ConvNeXtTiny, 768, convnext),\n    (convnext.ConvNeXtSmall, 768, convnext),\n    (convnext.ConvNeXtBase, 1024, convnext),\n    (convnext.ConvNeXtLarge, 1536, convnext),\n    (convnext.ConvNeXtXLarge, 2048, convnext),\n    # nasnet\n    (nasnet.NASNetMobile, 1056, nasnet),\n    (nasnet.NASNetLarge, 4032, nasnet),\n    # resnet\n    (resnet.ResNet50, 2048, resnet),\n    (resnet.ResNet101, 2048, resnet),\n    (resnet.ResNet152, 2048, resnet),\n    (resnet_v2.ResNet50V2, 2048, resnet_v2),\n    (resnet_v2.ResNet101V2, 2048, resnet_v2),\n    (resnet_v2.ResNet152V2, 2048, resnet_v2),\n]\nMODELS_UNSUPPORTED_CHANNELS_FIRST = [\"ConvNeXt\", \"DenseNet\", \"NASNet\"]\n\n# Add names for `named_parameters`, and add each data format for each model\ntest_parameters = [\n    (\n        \"{}_{}\".format(model[0].__name__, image_data_format),\n        *model,\n        image_data_format,\n    )\n    for image_data_format in [\"channels_first\", \"channels_last\"]\n    for model in MODEL_LIST\n]\n\n\ndef _get_elephant(target_size):\n    # For models that don't include a Flatten step,\n    # the default is to accept variable-size inputs\n    # even when loading ImageNet weights (since it is possible).\n    # In this case, default to 299x299.\n    
TEST_IMAGE_PATH = (\n        \"https://storage.googleapis.com/tensorflow/\"\n        \"keras-applications/tests/elephant.jpg\"\n    )\n\n    if target_size[0] is None:\n        target_size = (299, 299)\n    test_image = file_utils.get_file(\"elephant.jpg\", TEST_IMAGE_PATH)\n    img = image_utils.load_img(test_image, target_size=tuple(target_size))\n    x = image_utils.img_to_array(img)\n    return np.expand_dims(x, axis=0)\n\n\n@pytest.mark.skipif(\n    os.environ.get(\"SKIP_APPLICATIONS_TESTS\"),\n    reason=\"Env variable set to skip.\",\n)\n@pytest.mark.requires_trainable_backend\nclass ApplicationsTest(testing.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        cls.original_image_data_format = backend.image_data_format()\n\n    @classmethod\n    def tearDownClass(cls):\n        backend.set_image_data_format(cls.original_image_data_format)\n\n    def skip_if_invalid_image_data_format_for_model(\n        self, app, image_data_format\n    ):\n        does_not_support_channels_first = any(\n            [\n                unsupported_name.lower() in app.__name__.lower()\n                for unsupported_name in MODELS_UNSUPPORTED_CHANNELS_FIRST\n            ]\n        )\n        if (\n            image_data_format == \"channels_first\"\n            and does_not_support_channels_first\n        ):\n            self.skipTest(\n                \"{} does not support channels first\".format(app.__name__)\n            )\n\n    @parameterized.named_parameters(test_parameters)\n    def test_application_notop_variable_input_channels(\n        self, app, last_dim, _, image_data_format\n    ):\n        if app == nasnet.NASNetMobile and backend.backend() == \"torch\":\n            self.skipTest(\n                \"NASNetMobile pretrained incorrect with torch backend.\"\n            )\n        self.skip_if_invalid_image_data_format_for_model(app, image_data_format)\n        backend.set_image_data_format(image_data_format)\n\n        # Test compatibility with 1 
channel\n        if image_data_format == \"channels_first\":\n            input_shape = (1, None, None)\n            correct_output_shape = [None, last_dim, None, None]\n        else:\n            input_shape = (None, None, 1)\n            correct_output_shape = [None, None, None, last_dim]\n\n        model = app(weights=None, include_top=False, input_shape=input_shape)\n        output_shape = list(model.outputs[0].shape)\n        self.assertEqual(output_shape, correct_output_shape)\n\n        # Test compatibility with 4 channels\n        if image_data_format == \"channels_first\":\n            input_shape = (4, None, None)\n        else:\n            input_shape = (None, None, 4)\n        model = app(weights=None, include_top=False, input_shape=input_shape)\n        output_shape = list(model.outputs[0].shape)\n        self.assertEqual(output_shape, correct_output_shape)\n\n    @parameterized.named_parameters(test_parameters)\n    @pytest.mark.skipif(PIL is None, reason=\"Requires PIL.\")\n    def test_application_base(self, app, _, app_module, image_data_format):\n        import tensorflow as tf\n\n        if app == nasnet.NASNetMobile and backend.backend() == \"torch\":\n            self.skipTest(\n                \"NASNetMobile pretrained incorrect with torch backend.\"\n            )\n        if (\n            image_data_format == \"channels_first\"\n            and len(tf.config.list_physical_devices(\"GPU\")) == 0\n            and backend.backend() == \"tensorflow\"\n        ):\n            self.skipTest(\n                \"Conv2D doesn't support channels_first using CPU with \"\n                \"tensorflow backend\"\n            )\n        self.skip_if_invalid_image_data_format_for_model(app, image_data_format)\n        backend.set_image_data_format(image_data_format)\n\n        # Can be instantiated with default arguments\n        model = app(weights=\"imagenet\")\n\n        # Can run a correct inference on a test image\n        if image_data_format == 
\"channels_first\":\n            shape = model.input_shape[2:4]\n        else:\n            shape = model.input_shape[1:3]\n        x = _get_elephant(shape)\n\n        x = app_module.preprocess_input(x)\n        preds = model.predict(x)\n        names = [p[1] for p in app_module.decode_predictions(preds)[0]]\n        # Test correct label is in top 3 (weak correctness test).\n        self.assertIn(\"African_elephant\", names[:3])\n\n        # Can be serialized and deserialized\n        config = serialization_lib.serialize_keras_object(model)\n        reconstructed_model = serialization_lib.deserialize_keras_object(config)\n        self.assertEqual(len(model.weights), len(reconstructed_model.weights))\n\n    @parameterized.named_parameters(test_parameters)\n    def test_application_notop_custom_input_shape(\n        self, app, last_dim, _, image_data_format\n    ):\n        if app == nasnet.NASNetMobile and backend.backend() == \"torch\":\n            self.skipTest(\n                \"NASNetMobile pretrained incorrect with torch backend.\"\n            )\n        self.skip_if_invalid_image_data_format_for_model(app, image_data_format)\n        backend.set_image_data_format(image_data_format)\n\n        if image_data_format == \"channels_first\":\n            input_shape = (3, 123, 123)\n            last_dim_axis = 1\n        else:\n            input_shape = (123, 123, 3)\n            last_dim_axis = -1\n        model = app(weights=None, include_top=False, input_shape=input_shape)\n        output_shape = list(model.outputs[0].shape)\n        self.assertEqual(output_shape[last_dim_axis], last_dim)\n\n    @parameterized.named_parameters(test_parameters)\n    def test_application_notop_custom_input_tensor(\n        self, app, last_dim, _, image_data_format\n    ):\n        if app == nasnet.NASNetMobile and backend.backend() == \"torch\":\n            self.skipTest(\n                \"NASNetMobile pretrained incorrect with torch backend.\"\n            )\n        
self.skip_if_invalid_image_data_format_for_model(app, image_data_format)\n        backend.set_image_data_format(image_data_format)\n\n        if image_data_format == \"channels_first\":\n            input_shape = (4, 200, 200)\n            last_dim_axis = 1\n        else:\n            input_shape = (200, 200, 4)\n            last_dim_axis = -1\n\n        inputs_custom = Input(shape=input_shape, name=\"custom_input\")\n        inputs_custom = Conv2D(3, (2, 2), padding=\"valid\", strides=(2, 2))(\n            inputs_custom\n        )\n        model = app(weights=None, include_top=False, input_tensor=inputs_custom)\n        output_shape = list(model.outputs[0].shape)\n        self.assertEqual(output_shape[last_dim_axis], last_dim)\n\n    @parameterized.named_parameters(test_parameters)\n    def test_application_pooling(self, app, last_dim, _, image_data_format):\n        if app == nasnet.NASNetMobile and backend.backend() == \"torch\":\n            self.skipTest(\n                \"NASNetMobile pretrained incorrect with torch backend.\"\n            )\n        self.skip_if_invalid_image_data_format_for_model(app, image_data_format)\n        backend.set_image_data_format(image_data_format)\n\n        model = app(weights=None, include_top=False, pooling=\"max\")\n        output_shape = list(model.outputs[0].shape)\n        self.assertEqual(output_shape, [None, last_dim])\n\n    @parameterized.named_parameters(test_parameters)\n    def test_application_classifier_activation(self, app, *_):\n        if app == nasnet.NASNetMobile and backend.backend() == \"torch\":\n            self.skipTest(\n                \"NASNetMobile pretrained incorrect with torch backend.\"\n            )\n\n        model = app(\n            weights=None, include_top=True, classifier_activation=\"softmax\"\n        )\n        last_layer_act = model.layers[-1].activation.__name__\n        self.assertEqual(last_layer_act, \"softmax\")\n"
  },
  {
    "path": "keras/src/applications/convnext.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import random\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.layers.layer import Layer\nfrom keras.src.models import Functional\nfrom keras.src.models import Sequential\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nBASE_WEIGHTS_PATH = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/convnext/\"\n)\n\nWEIGHTS_HASHES = {\n    \"convnext_tiny\": (\n        \"8ae6e78ce2933352b1ef4008e6dd2f17bc40771563877d156bc6426c7cf503ff\",\n        \"d547c096cabd03329d7be5562c5e14798aa39ed24b474157cef5e85ab9e49ef1\",\n    ),\n    \"convnext_small\": (\n        \"ce1277d8f1ee5a0ef0e171469089c18f5233860ceaf9b168049cb9263fd7483c\",\n        \"6fc8009faa2f00c1c1dfce59feea9b0745eb260a7dd11bee65c8e20843da6eab\",\n    ),\n    \"convnext_base\": (\n        \"52cbb006d3dadd03f6e095a8ca1aca47aecdd75acb4bc74bce1f5c695d0086e6\",\n        \"40a20c5548a5e9202f69735ecc06c990e6b7c9d2de39f0361e27baeb24cb7c45\",\n    ),\n    \"convnext_large\": (\n        \"070c5ed9ed289581e477741d3b34beffa920db8cf590899d6d2c67fba2a198a6\",\n        \"96f02b6f0753d4f543261bc9d09bed650f24dd6bc02ddde3066135b63d23a1cd\",\n    ),\n    \"convnext_xlarge\": (\n        \"c1f5ccab661354fc3a79a10fa99af82f0fbf10ec65cb894a3ae0815f17a889ee\",\n        \"de3f8a54174130e0cecdc71583354753d557fcf1f4487331558e2a16ba0cfe05\",\n    ),\n}\n\n\nMODEL_CONFIGS = {\n    \"tiny\": {\n        \"depths\": [3, 3, 9, 3],\n        \"projection_dims\": [96, 192, 384, 768],\n        \"default_size\": 224,\n    },\n    \"small\": {\n        \"depths\": [3, 3, 27, 3],\n        \"projection_dims\": [96, 192, 384, 768],\n        \"default_size\": 224,\n    },\n    \"base\": {\n        \"depths\": [3, 3, 27, 3],\n        \"projection_dims\": [128, 256, 
512, 1024],\n        \"default_size\": 224,\n    },\n    \"large\": {\n        \"depths\": [3, 3, 27, 3],\n        \"projection_dims\": [192, 384, 768, 1536],\n        \"default_size\": 224,\n    },\n    \"xlarge\": {\n        \"depths\": [3, 3, 27, 3],\n        \"projection_dims\": [256, 512, 1024, 2048],\n        \"default_size\": 224,\n    },\n}\n\nBASE_DOCSTRING = \"\"\"Instantiates the {name} architecture.\n\nReferences:\n- [A ConvNet for the 2020s](https://arxiv.org/abs/2201.03545)\n(CVPR 2022)\n\nFor image classification use cases, see\n[this page for detailed examples](\nhttps://keras.io/api/applications/#usage-examples-for-image-classification-models).\nFor transfer learning use cases, make sure to read the\n[guide to transfer learning & fine-tuning](\nhttps://keras.io/guides/transfer_learning/).\n\nThe `base`, `large`, and `xlarge` models were first pre-trained on the\nImageNet-21k dataset and then fine-tuned on the ImageNet-1k dataset. The\npre-trained parameters of the models were assembled from the\n[official repository](https://github.com/facebookresearch/ConvNeXt). To get a\nsense of how these parameters were converted to Keras compatible parameters,\nplease refer to\n[this repository](https://github.com/sayakpaul/keras-convnext-conversion).\n\nNote: Each Keras Application expects a specific kind of input preprocessing.\nFor ConvNeXt, preprocessing is included in the model using a `Normalization`\nlayer.  ConvNeXt models expect their inputs to be float or uint8 tensors of\npixels with values in the [0-255] range.\n\nWhen calling the `summary()` method after instantiating a ConvNeXt model,\nprefer setting the `expand_nested` argument of `summary()` to `True` to better\ninvestigate the instantiated model.\n\nArgs:\n    include_top: Whether to include the fully-connected\n        layer at the top of the network. 
Defaults to `True`.\n    weights: One of `None` (random initialization),\n        `\"imagenet\"` (pre-training on ImageNet-1k), or the path to the weights\n        file to be loaded. Defaults to `\"imagenet\"`.\n    input_tensor: Optional Keras tensor\n        (i.e. output of `layers.Input()`)\n        to use as image input for the model.\n    input_shape: Optional shape tuple, only to be specified\n        if `include_top` is `False`.\n        It should have exactly 3 input channels.\n    pooling: Optional pooling mode for feature extraction\n        when `include_top` is `False`. Defaults to None.\n        - `None` means that the output of the model will be\n        the 4D tensor output of the last convolutional layer.\n        - `avg` means that global average pooling\n        will be applied to the output of the\n        last convolutional layer, and thus\n        the output of the model will be a 2D tensor.\n        - `max` means that global max pooling will\n        be applied.\n    classes: Optional number of classes to classify images\n        into, only to be specified if `include_top` is `True`, and\n        if no `weights` argument is specified. Defaults to 1000 (number of\n        ImageNet classes).\n    classifier_activation: A `str` or callable. The activation function to use\n        on the \"top\" layer. Ignored unless `include_top=True`. Set\n        `classifier_activation=None` to return the logits of the \"top\" layer.\n        Defaults to `\"softmax\"`.\n        When loading pretrained weights, `classifier_activation` can only\n        be `None` or `\"softmax\"`.\n    name: The name of the model (string).\n\nReturns:\n    A model instance.\n\"\"\"\n\n\nclass StochasticDepth(Layer):\n    \"\"\"Stochastic Depth module.\n\n    It performs batch-wise dropping rather than sample-wise. 
In libraries like\n    `timm`, it's similar to `DropPath` layers that drop residual paths\n    sample-wise.\n\n    References:\n    - https://github.com/rwightman/pytorch-image-models\n\n    Args:\n      drop_path_rate (float): Probability of dropping paths. Should be within\n        [0, 1].\n\n    Returns:\n      Tensor either with the residual path dropped or kept.\n    \"\"\"\n\n    def __init__(self, drop_path_rate, **kwargs):\n        super().__init__(**kwargs)\n        self.drop_path_rate = drop_path_rate\n\n    def call(self, x, training=None):\n        if training:\n            keep_prob = 1 - self.drop_path_rate\n            shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)\n            random_tensor = keep_prob + random.uniform(shape, 0, 1)\n            random_tensor = ops.floor(random_tensor)\n            return (x / keep_prob) * random_tensor\n        return x\n\n    def get_config(self):\n        config = super().get_config()\n        config.update({\"drop_path_rate\": self.drop_path_rate})\n        return config\n\n\nclass LayerScale(Layer):\n    \"\"\"Layer scale module.\n\n    References:\n\n    - https://arxiv.org/abs/2103.17239\n\n    Args:\n        init_values (float): Initial value for layer scale. 
Should be within\n            [0, 1].\n        projection_dim (int): Projection dimensionality.\n\n    Returns:\n        Tensor multiplied to the scale.\n    \"\"\"\n\n    def __init__(self, init_values, projection_dim, **kwargs):\n        super().__init__(**kwargs)\n        self.init_values = init_values\n        self.projection_dim = projection_dim\n\n    def build(self, _):\n        self.gamma = self.add_weight(\n            shape=(self.projection_dim,),\n            initializer=initializers.Constant(self.init_values),\n            trainable=True,\n        )\n\n    def call(self, x):\n        return x * self.gamma\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"init_values\": self.init_values,\n                \"projection_dim\": self.projection_dim,\n            }\n        )\n        return config\n\n\ndef ConvNeXtBlock(\n    projection_dim, drop_path_rate=0.0, layer_scale_init_value=1e-6, name=None\n):\n    \"\"\"ConvNeXt block.\n\n    References:\n    - https://arxiv.org/abs/2201.03545\n    - https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py\n\n    Notes:\n        In the original ConvNeXt implementation (linked above), the authors use\n        `Dense` layers for pointwise convolutions for increased efficiency.\n        Following that, this implementation also uses the same.\n\n    Args:\n        projection_dim (int): Number of filters for convolution layers. In the\n            ConvNeXt paper, this is referred to as projection dimension.\n        drop_path_rate (float): Probability of dropping paths. 
Should be within\n            [0, 1].\n        layer_scale_init_value (float): Layer scale value.\n            Should be a small float number.\n        name: name prefix for the keras layers.\n\n    Returns:\n        A function representing a ConvNeXtBlock block.\n    \"\"\"\n    if name is None:\n        name = f\"prestem{str(backend.get_uid('prestem'))}\"\n\n    def apply(inputs):\n        x = inputs\n\n        x = layers.Conv2D(\n            filters=projection_dim,\n            kernel_size=7,\n            padding=\"same\",\n            groups=projection_dim,\n            name=f\"{name}_depthwise_conv\",\n        )(x)\n        x = layers.LayerNormalization(epsilon=1e-6, name=f\"{name}_layernorm\")(x)\n        x = layers.Dense(4 * projection_dim, name=f\"{name}_pointwise_conv_1\")(x)\n        x = layers.Activation(\"gelu\", name=f\"{name}_gelu\")(x)\n        x = layers.Dense(projection_dim, name=f\"{name}_pointwise_conv_2\")(x)\n\n        if layer_scale_init_value is not None:\n            x = LayerScale(\n                layer_scale_init_value,\n                projection_dim,\n                name=f\"{name}_layer_scale\",\n            )(x)\n        if drop_path_rate:\n            layer = StochasticDepth(\n                drop_path_rate, name=f\"{name}_stochastic_depth\"\n            )\n        else:\n            layer = layers.Activation(\"linear\", name=f\"{name}_identity\")\n\n        return inputs + layer(x)\n\n    return apply\n\n\ndef PreStem(name=None):\n    \"\"\"Normalizes inputs with ImageNet-1k mean and std.\"\"\"\n    if name is None:\n        name = \"prestem{0}\".format(str(backend.get_uid(\"prestem\")))\n\n    def apply(x):\n        x = layers.Normalization(\n            mean=[0.485 * 255, 0.456 * 255, 0.406 * 255],\n            variance=[\n                (0.229 * 255) ** 2,\n                (0.224 * 255) ** 2,\n                (0.225 * 255) ** 2,\n            ],\n            name=f\"{name}_prestem_normalization\",\n        )(x)\n        return 
x\n\n    return apply\n\n\ndef Head(num_classes=1000, classifier_activation=None, name=None):\n    \"\"\"Implementation of classification head of ConvNeXt.\n\n    Args:\n        num_classes: number of classes for Dense layer\n        classifier_activation: activation function for the Dense layer\n        name: name prefix\n\n    Returns:\n        Classification head function.\n    \"\"\"\n    if name is None:\n        name = str(backend.get_uid(\"head\"))\n\n    def apply(x):\n        x = layers.GlobalAveragePooling2D(name=f\"{name}_head_gap\")(x)\n        x = layers.LayerNormalization(\n            epsilon=1e-6, name=f\"{name}_head_layernorm\"\n        )(x)\n        x = layers.Dense(\n            num_classes,\n            activation=classifier_activation,\n            name=f\"{name}_head_dense\",\n        )(x)\n        return x\n\n    return apply\n\n\ndef ConvNeXt(\n    depths,\n    projection_dims,\n    drop_path_rate=0.0,\n    layer_scale_init_value=1e-6,\n    default_size=224,\n    name=\"convnext\",\n    include_preprocessing=True,\n    include_top=True,\n    weights=None,\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    weights_name=None,\n):\n    \"\"\"Instantiates ConvNeXt architecture given specific configuration.\n\n    Args:\n        depths: An iterable containing depths for each individual stage.\n        projection_dims: An iterable containing output number of channels of\n        each individual stage.\n        drop_path_rate: Stochastic depth probability. If 0.0, then stochastic\n            depth won't be used.\n        layer_scale_init_value: Layer scale coefficient. 
If 0.0, layer scaling\n            won't be used.\n        default_size: Default input image size.\n        name: An optional name for the model.\n        include_preprocessing: boolean denoting whether to\n            include preprocessing in the model.\n            When `weights=\"imagenet\"` this should always be `True`.\n            But for other models (e.g., randomly initialized) you should set it\n            to `False` and apply preprocessing to data accordingly.\n        include_top: Boolean denoting whether to include the classification\n            head in the model.\n        weights: one of `None` (random initialization), `\"imagenet\"`\n            (pre-training on ImageNet-1k),\n            or the path to the weights file to be loaded.\n        input_tensor: optional Keras tensor (i.e. output of `layers.Input()`) to\n            use as image input for the model.\n        input_shape: optional shape tuple, only to be specified if `include_top`\n            is `False`. It should have exactly 3 input channels.\n        pooling: optional pooling mode for feature extraction when `include_top`\n            is `False`.\n            - `None` means that the output of the model will be\n                the 4D tensor output of the last convolutional layer.\n            - `avg` means that global average pooling will be applied\n                to the output of the last convolutional layer,\n                and thus the output of the model will be a 2D tensor.\n            - `max` means that global max pooling will be applied.\n        classes: optional number of classes to classify images into,\n            only to be specified if `include_top` is `True`,\n            and if no `weights` argument is specified.\n        classifier_activation: A `str` or callable.\n            The activation function to use\n            on the \"top\" layer. 
Ignored unless `include_top=True`.\n            Set `classifier_activation=None` to return the logits\n            of the \"top\" layer.\n\n    Returns:\n        A model instance.\n    \"\"\"\n    if backend.image_data_format() == \"channels_first\":\n        raise ValueError(\n            \"ConvNeXt does not support the `channels_first` image data \"\n            \"format. Switch to `channels_last` by editing your local \"\n            \"config file at ~/.keras/keras.json\"\n        )\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), `imagenet` \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded.\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            'If using `weights=\"imagenet\"` with `include_top=True`, '\n            \"`classes` should be 1000. 
\"\n            f\"Received classes={classes}\"\n        )\n\n    # Determine proper input shape.\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=default_size,\n        min_size=32,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)[0]\n        x = input_tensor\n    else:\n        inputs = img_input\n        x = inputs\n\n    if include_preprocessing:\n        channel_axis = (\n            3 if backend.image_data_format() == \"channels_last\" else 1\n        )\n        num_channels = input_shape[channel_axis - 1]\n        if num_channels == 3:\n            x = PreStem(name=name)(x)\n\n    # Stem block.\n    stem = Sequential(\n        [\n            layers.Conv2D(\n                projection_dims[0],\n                kernel_size=4,\n                strides=4,\n                name=f\"{name}_stem_conv\",\n            ),\n            layers.LayerNormalization(\n                epsilon=1e-6, name=f\"{name}_stem_layernorm\"\n            ),\n        ],\n        name=f\"{name}_stem\",\n    )\n\n    # Downsampling blocks.\n    downsample_layers = []\n    downsample_layers.append(stem)\n\n    num_downsample_layers = 3\n    for i in range(num_downsample_layers):\n        downsample_layer = Sequential(\n            [\n                layers.LayerNormalization(\n                    epsilon=1e-6,\n                    name=f\"{name}_downsampling_layernorm_{i}\",\n                ),\n                layers.Conv2D(\n                    projection_dims[i + 1],\n                    
kernel_size=2,\n                    strides=2,\n                    name=f\"{name}_downsampling_conv_{i}\",\n                ),\n            ],\n            name=f\"{name}_downsampling_block_{i}\",\n        )\n        downsample_layers.append(downsample_layer)\n\n    # Stochastic depth schedule.\n    # This schedule is taken from the original ConvNeXt codebase:\n    # https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py#L86\n    depth_drop_rates = [\n        float(x) for x in np.linspace(0.0, drop_path_rate, sum(depths))\n    ]\n\n    # First apply downsampling blocks and then apply ConvNeXt stages.\n    cur = 0\n\n    num_convnext_blocks = 4\n    for i in range(num_convnext_blocks):\n        x = downsample_layers[i](x)\n        for j in range(depths[i]):\n            x = ConvNeXtBlock(\n                projection_dim=projection_dims[i],\n                drop_path_rate=depth_drop_rates[cur + j],\n                layer_scale_init_value=layer_scale_init_value,\n                name=name + f\"_stage_{i}_block_{j}\",\n            )(x)\n        cur += depths[i]\n\n    if include_top:\n        imagenet_utils.validate_activation(classifier_activation, weights)\n        x = Head(\n            num_classes=classes,\n            classifier_activation=classifier_activation,\n            name=name,\n        )(x)\n\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D()(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D()(x)\n        x = layers.LayerNormalization(epsilon=1e-6)(x)\n\n    model = Functional(inputs=inputs, outputs=x, name=name)\n\n    # Validate weights before requesting them from the API\n    if weights == \"imagenet\":\n        expected_config = MODEL_CONFIGS[weights_name.split(\"convnext_\")[-1]]\n        if (\n            depths != expected_config[\"depths\"]\n            or projection_dims != expected_config[\"projection_dims\"]\n        ):\n            raise ValueError(\n          
      f\"Architecture configuration does not match {weights_name} \"\n                f\"variant. When using pre-trained weights, the model \"\n                f\"architecture must match the pre-trained configuration \"\n                f\"exactly. Expected depths: {expected_config['depths']}, \"\n                f\"got: {depths}. Expected projection_dims: \"\n                f\"{expected_config['projection_dims']}, got: {projection_dims}.\"\n            )\n\n        if weights_name not in name:\n            raise ValueError(\n                f'Model name \"{name}\" does not match weights variant '\n                f'\"{weights_name}\". When using imagenet weights, model name '\n                f'must contain the weights variant (e.g., \"convnext_'\n                f'{weights_name.split(\"convnext_\")[-1]}\").'\n            )\n\n    # Load weights.\n    if weights == \"imagenet\":\n        if include_top:\n            file_suffix = \".h5\"\n            file_hash = WEIGHTS_HASHES[weights_name][0]\n        else:\n            file_suffix = \"_notop.h5\"\n            file_hash = WEIGHTS_HASHES[weights_name][1]\n        file_name = name + file_suffix\n        weights_path = file_utils.get_file(\n            file_name,\n            BASE_WEIGHTS_PATH + file_name,\n            cache_subdir=\"models\",\n            file_hash=file_hash,\n        )\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n\n    return model\n\n\n## Instantiating variants ##\n\n\n@keras_export(\n    [\n        \"keras.applications.convnext.ConvNeXtTiny\",\n        \"keras.applications.ConvNeXtTiny\",\n    ]\n)\ndef ConvNeXtTiny(\n    include_top=True,\n    include_preprocessing=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"convnext_tiny\",\n):\n    return ConvNeXt(\n        weights_name=\"convnext_tiny\",\n        
depths=MODEL_CONFIGS[\"tiny\"][\"depths\"],\n        projection_dims=MODEL_CONFIGS[\"tiny\"][\"projection_dims\"],\n        drop_path_rate=0.0,\n        layer_scale_init_value=1e-6,\n        default_size=MODEL_CONFIGS[\"tiny\"][\"default_size\"],\n        name=name,\n        include_top=include_top,\n        include_preprocessing=include_preprocessing,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.convnext.ConvNeXtSmall\",\n        \"keras.applications.ConvNeXtSmall\",\n    ]\n)\ndef ConvNeXtSmall(\n    include_top=True,\n    include_preprocessing=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"convnext_small\",\n):\n    return ConvNeXt(\n        weights_name=\"convnext_small\",\n        depths=MODEL_CONFIGS[\"small\"][\"depths\"],\n        projection_dims=MODEL_CONFIGS[\"small\"][\"projection_dims\"],\n        drop_path_rate=0.0,\n        layer_scale_init_value=1e-6,\n        default_size=MODEL_CONFIGS[\"small\"][\"default_size\"],\n        name=name,\n        include_top=include_top,\n        include_preprocessing=include_preprocessing,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.convnext.ConvNeXtBase\",\n        \"keras.applications.ConvNeXtBase\",\n    ]\n)\ndef ConvNeXtBase(\n    include_top=True,\n    include_preprocessing=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    
name=\"convnext_base\",\n):\n    return ConvNeXt(\n        weights_name=\"convnext_base\",\n        depths=MODEL_CONFIGS[\"base\"][\"depths\"],\n        projection_dims=MODEL_CONFIGS[\"base\"][\"projection_dims\"],\n        drop_path_rate=0.0,\n        layer_scale_init_value=1e-6,\n        default_size=MODEL_CONFIGS[\"base\"][\"default_size\"],\n        name=name,\n        include_top=include_top,\n        include_preprocessing=include_preprocessing,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.convnext.ConvNeXtLarge\",\n        \"keras.applications.ConvNeXtLarge\",\n    ]\n)\ndef ConvNeXtLarge(\n    include_top=True,\n    include_preprocessing=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"convnext_large\",\n):\n    return ConvNeXt(\n        weights_name=\"convnext_large\",\n        depths=MODEL_CONFIGS[\"large\"][\"depths\"],\n        projection_dims=MODEL_CONFIGS[\"large\"][\"projection_dims\"],\n        drop_path_rate=0.0,\n        layer_scale_init_value=1e-6,\n        default_size=MODEL_CONFIGS[\"large\"][\"default_size\"],\n        name=name,\n        include_top=include_top,\n        include_preprocessing=include_preprocessing,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.convnext.ConvNeXtXLarge\",\n        \"keras.applications.ConvNeXtXLarge\",\n    ]\n)\ndef ConvNeXtXLarge(\n    include_top=True,\n    include_preprocessing=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    
input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"convnext_xlarge\",\n):\n    return ConvNeXt(\n        weights_name=\"convnext_xlarge\",\n        depths=MODEL_CONFIGS[\"xlarge\"][\"depths\"],\n        projection_dims=MODEL_CONFIGS[\"xlarge\"][\"projection_dims\"],\n        drop_path_rate=0.0,\n        layer_scale_init_value=1e-6,\n        default_size=MODEL_CONFIGS[\"xlarge\"][\"default_size\"],\n        name=name,\n        include_top=include_top,\n        include_preprocessing=include_preprocessing,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n    )\n\n\nConvNeXtTiny.__doc__ = BASE_DOCSTRING.format(name=\"ConvNeXtTiny\")\nConvNeXtSmall.__doc__ = BASE_DOCSTRING.format(name=\"ConvNeXtSmall\")\nConvNeXtBase.__doc__ = BASE_DOCSTRING.format(name=\"ConvNeXtBase\")\nConvNeXtLarge.__doc__ = BASE_DOCSTRING.format(name=\"ConvNeXtLarge\")\nConvNeXtXLarge.__doc__ = BASE_DOCSTRING.format(name=\"ConvNeXtXLarge\")\n\n\n@keras_export(\"keras.applications.convnext.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    \"\"\"A placeholder method for backward compatibility.\n\n    The preprocessing logic has been included in the convnext model\n    implementation. Users are no longer required to call this method to\n    normalize the input data. This method does nothing and is only kept as a\n    placeholder to align the API surface between old and new versions of the\n    model.\n\n    Args:\n        x: A floating point `numpy.array` or a tensor.\n        data_format: Optional data format of the image tensor/array. 
Defaults to\n            None, in which case the global setting\n            `keras.backend.image_data_format()` is used\n            (unless you changed it, it defaults to `\"channels_last\"`).\n\n    Returns:\n        Unchanged `numpy.array` or tensor.\n    \"\"\"\n    return x\n\n\n@keras_export(\"keras.applications.convnext.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n"
  },
  {
    "path": "keras/src/applications/densenet.py",
    "content": "from keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nBASE_WEIGHTS_PATH = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/densenet/\"\n)\nDENSENET121_WEIGHT_PATH = (\n    f\"{BASE_WEIGHTS_PATH}densenet121_weights_tf_dim_ordering_tf_kernels.h5\"\n)\nDENSENET121_WEIGHT_PATH_NO_TOP = (\n    f\"{BASE_WEIGHTS_PATH}\"\n    \"densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5\"\n)\nDENSENET169_WEIGHT_PATH = (\n    f\"{BASE_WEIGHTS_PATH}densenet169_weights_tf_dim_ordering_tf_kernels.h5\"\n)\nDENSENET169_WEIGHT_PATH_NO_TOP = (\n    f\"{BASE_WEIGHTS_PATH}\"\n    \"densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5\"\n)\nDENSENET201_WEIGHT_PATH = (\n    f\"{BASE_WEIGHTS_PATH}densenet201_weights_tf_dim_ordering_tf_kernels.h5\"\n)\nDENSENET201_WEIGHT_PATH_NO_TOP = (\n    f\"{BASE_WEIGHTS_PATH}\"\n    \"densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5\"\n)\n\n\ndef dense_block(x, blocks, name):\n    \"\"\"A dense block.\n\n    Args:\n        x: input tensor.\n        blocks: integer, the number of building blocks.\n        name: string, block label.\n\n    Returns:\n        Output tensor for the block.\n    \"\"\"\n    for i in range(blocks):\n        x = conv_block(x, 32, name=f\"{name}_block{i + 1}\")\n    return x\n\n\ndef transition_block(x, reduction, name):\n    \"\"\"A transition block.\n\n    Args:\n        x: input tensor.\n        reduction: float, compression rate at transition layers.\n        name: string, block label.\n\n    Returns:\n        Output tensor for the block.\n    \"\"\"\n    bn_axis = 3 if backend.image_data_format() == \"channels_last\" else 1\n    x = layers.BatchNormalization(\n        axis=bn_axis, epsilon=1.001e-5, name=f\"{name}_bn\"\n    )(x)\n    x = 
layers.Activation(\"relu\", name=f\"{name}_relu\")(x)\n    x = layers.Conv2D(\n        int(x.shape[bn_axis] * reduction),\n        1,\n        use_bias=False,\n        name=f\"{name}_conv\",\n    )(x)\n    x = layers.AveragePooling2D(2, strides=2, name=f\"{name}_pool\")(x)\n    return x\n\n\ndef conv_block(x, growth_rate, name):\n    \"\"\"A building block for a dense block.\n\n    Args:\n        x: input tensor.\n        growth_rate: float, growth rate at dense layers.\n        name: string, block label.\n\n    Returns:\n        Output tensor for the block.\n    \"\"\"\n    bn_axis = 3 if backend.image_data_format() == \"channels_last\" else 1\n    x1 = layers.BatchNormalization(\n        axis=bn_axis, epsilon=1.001e-5, name=f\"{name}_0_bn\"\n    )(x)\n    x1 = layers.Activation(\"relu\", name=f\"{name}_0_relu\")(x1)\n    x1 = layers.Conv2D(\n        4 * growth_rate, 1, use_bias=False, name=f\"{name}_1_conv\"\n    )(x1)\n    x1 = layers.BatchNormalization(\n        axis=bn_axis, epsilon=1.001e-5, name=f\"{name}_1_bn\"\n    )(x1)\n    x1 = layers.Activation(\"relu\", name=f\"{name}_1_relu\")(x1)\n    x1 = layers.Conv2D(\n        growth_rate, 3, padding=\"same\", use_bias=False, name=f\"{name}_2_conv\"\n    )(x1)\n    x = layers.Concatenate(axis=bn_axis, name=f\"{name}_concat\")([x, x1])\n    return x\n\n\ndef DenseNet(\n    blocks,\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"densenet\",\n):\n    \"\"\"Instantiates the DenseNet architecture.\n\n    Reference:\n    - [Densely Connected Convolutional Networks](\n        https://arxiv.org/abs/1608.06993) (CVPR 2017)\n\n    This function returns a Keras image classification model,\n    optionally loaded with weights pre-trained on ImageNet.\n\n    For image classification use cases, see\n    [this page for detailed examples](\n      
https://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\n    For transfer learning use cases, make sure to read the\n    [guide to transfer learning & fine-tuning](\n      https://keras.io/guides/transfer_learning/).\n\n    Note: each Keras Application expects a specific kind of input preprocessing.\n    For DenseNet, call `keras.applications.densenet.preprocess_input`\n    on your inputs before passing them to the model.\n    `densenet.preprocess_input` will scale pixels between 0 and 1 and then\n    normalize each channel with respect to the ImageNet\n    dataset statistics.\n\n    Args:\n        blocks: numbers of building blocks for the four dense layers.\n        include_top: whether to include the fully-connected\n            layer at the top of the network.\n        weights: one of `None` (random initialization),\n            `\"imagenet\"` (pre-training on ImageNet),\n            or the path to the weights file to be loaded.\n        input_tensor: optional Keras tensor\n            (i.e. output of `layers.Input()`)\n            to use as image input for the model.\n        input_shape: optional shape tuple, only to be specified\n            if `include_top` is False (otherwise the input shape\n            has to be `(224, 224, 3)`\n            (with `'channels_last'` data format)\n            or `(3, 224, 224)` (with `'channels_first'` data format)).\n            It should have exactly 3 input channels,\n            and width and height should be no smaller than 32.\n            E.g. 
`(200, 200, 3)` would be one valid value.\n        pooling: optional pooling mode for feature extraction\n            when `include_top` is `False`.\n            - `None` means that the output of the model will be\n                the 4D tensor output of the\n                last convolutional block.\n            - `avg` means that global average pooling\n                will be applied to the output of the\n                last convolutional block, and thus\n                the output of the model will be a 2D tensor.\n            - `max` means that global max pooling will\n                be applied.\n        classes: optional number of classes to classify images\n            into, only to be specified if `include_top` is `True`, and\n            if no `weights` argument is specified. Defaults to `1000`.\n        classifier_activation: A `str` or callable.\n            The activation function to use\n            on the \"top\" layer. Ignored unless `include_top=True`. Set\n            `classifier_activation=None` to return the logits of the \"top\"\n            layer. When loading pretrained weights, `classifier_activation`\n            can only be `None` or `\"softmax\"`.\n        name: The name of the model (string).\n\n    Returns:\n        A model instance.\n    \"\"\"\n    if backend.image_data_format() == \"channels_first\":\n        raise ValueError(\n            \"DenseNet does not support the `channels_first` image data \"\n            \"format. 
Switch to `channels_last` by editing your local \"\n            \"config file at ~/.keras/keras.json\"\n        )\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), `imagenet` \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded.\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            'If using `weights` as `\"imagenet\"` with `include_top`'\n            \" as true, `classes` should be 1000\"\n        )\n\n    # Determine proper input shape\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=224,\n        min_size=32,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n\n    bn_axis = 3 if backend.image_data_format() == \"channels_last\" else 1\n\n    x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)\n    x = layers.Conv2D(64, 7, strides=2, use_bias=False, name=\"conv1_conv\")(x)\n    x = layers.BatchNormalization(\n        axis=bn_axis, epsilon=1.001e-5, name=\"conv1_bn\"\n    )(x)\n    x = layers.Activation(\"relu\", name=\"conv1_relu\")(x)\n    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)))(x)\n    x = layers.MaxPooling2D(3, strides=2, name=\"pool1\")(x)\n\n    x = dense_block(x, blocks[0], name=\"conv2\")\n    x = transition_block(x, 0.5, name=\"pool2\")\n    x = dense_block(x, blocks[1], name=\"conv3\")\n    x = transition_block(x, 0.5, name=\"pool3\")\n    x = dense_block(x, 
blocks[2], name=\"conv4\")\n    x = transition_block(x, 0.5, name=\"pool4\")\n    x = dense_block(x, blocks[3], name=\"conv5\")\n\n    x = layers.BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=\"bn\")(x)\n    x = layers.Activation(\"relu\", name=\"relu\")(x)\n\n    if include_top:\n        x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(x)\n\n        imagenet_utils.validate_activation(classifier_activation, weights)\n        x = layers.Dense(\n            classes, activation=classifier_activation, name=\"predictions\"\n        )(x)\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D(name=\"max_pool\")(x)\n\n    # Ensure that the model takes into account\n    # any potential predecessors of `input_tensor`.\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n\n    # Create model.\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if weights == \"imagenet\":\n        if include_top:\n            if blocks == [6, 12, 24, 16]:\n                weights_path = file_utils.get_file(\n                    \"densenet121_weights_tf_dim_ordering_tf_kernels.h5\",\n                    DENSENET121_WEIGHT_PATH,\n                    cache_subdir=\"models\",\n                    file_hash=\"9d60b8095a5708f2dcce2bca79d332c7\",\n                )\n            elif blocks == [6, 12, 32, 32]:\n                weights_path = file_utils.get_file(\n                    \"densenet169_weights_tf_dim_ordering_tf_kernels.h5\",\n                    DENSENET169_WEIGHT_PATH,\n                    cache_subdir=\"models\",\n                    file_hash=\"d699b8f76981ab1b30698df4c175e90b\",\n                )\n            elif blocks == [6, 12, 48, 32]:\n                weights_path = file_utils.get_file(\n                    
\"densenet201_weights_tf_dim_ordering_tf_kernels.h5\",\n                    DENSENET201_WEIGHT_PATH,\n                    cache_subdir=\"models\",\n                    file_hash=\"1ceb130c1ea1b78c3bf6114dbdfd8807\",\n                )\n            else:\n                raise ValueError(\"weights_path undefined\")\n        else:\n            if blocks == [6, 12, 24, 16]:\n                weights_path = file_utils.get_file(\n                    \"densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5\",\n                    DENSENET121_WEIGHT_PATH_NO_TOP,\n                    cache_subdir=\"models\",\n                    file_hash=\"30ee3e1110167f948a6b9946edeeb738\",\n                )\n            elif blocks == [6, 12, 32, 32]:\n                weights_path = file_utils.get_file(\n                    \"densenet169_weights_tf_dim_ordering_tf_kernels_notop.h5\",\n                    DENSENET169_WEIGHT_PATH_NO_TOP,\n                    cache_subdir=\"models\",\n                    file_hash=\"b8c4d4c20dd625c148057b9ff1c1176b\",\n                )\n            elif blocks == [6, 12, 48, 32]:\n                weights_path = file_utils.get_file(\n                    \"densenet201_weights_tf_dim_ordering_tf_kernels_notop.h5\",\n                    DENSENET201_WEIGHT_PATH_NO_TOP,\n                    cache_subdir=\"models\",\n                    file_hash=\"c13680b51ded0fb44dff2d8f86ac8bb1\",\n                )\n            else:\n                raise ValueError(\"weights_path undefined\")\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n\n    return model\n\n\n@keras_export(\n    [\n        \"keras.applications.densenet.DenseNet121\",\n        \"keras.applications.DenseNet121\",\n    ]\n)\ndef DenseNet121(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    
name=\"densenet121\",\n):\n    \"\"\"Instantiates the Densenet121 architecture.\"\"\"\n    return DenseNet(\n        [6, 12, 24, 16],\n        include_top,\n        weights,\n        input_tensor,\n        input_shape,\n        pooling,\n        classes,\n        classifier_activation,\n        name=name,\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.densenet.DenseNet169\",\n        \"keras.applications.DenseNet169\",\n    ]\n)\ndef DenseNet169(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"densenet169\",\n):\n    \"\"\"Instantiates the Densenet169 architecture.\"\"\"\n    return DenseNet(\n        [6, 12, 32, 32],\n        include_top,\n        weights,\n        input_tensor,\n        input_shape,\n        pooling,\n        classes,\n        classifier_activation,\n        name=name,\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.densenet.DenseNet201\",\n        \"keras.applications.DenseNet201\",\n    ]\n)\ndef DenseNet201(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"densenet201\",\n):\n    \"\"\"Instantiates the Densenet201 architecture.\"\"\"\n    return DenseNet(\n        [6, 12, 48, 32],\n        include_top,\n        weights,\n        input_tensor,\n        input_shape,\n        pooling,\n        classes,\n        classifier_activation,\n        name=name,\n    )\n\n\n@keras_export(\"keras.applications.densenet.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    return imagenet_utils.preprocess_input(\n        x, data_format=data_format, mode=\"torch\"\n    )\n\n\n@keras_export(\"keras.applications.densenet.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, 
top=top)\n\n\npreprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(\n    mode=\"\",\n    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TORCH,\n    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,\n)\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n\nDOC = \"\"\"\n\nReference:\n- [Densely Connected Convolutional Networks](\n    https://arxiv.org/abs/1608.06993) (CVPR 2017)\n\nOptionally loads weights pre-trained on ImageNet.\nNote that the data format convention used by the model is\nthe one specified in your Keras config at `~/.keras/keras.json`.\n\nNote: each Keras Application expects a specific kind of input preprocessing.\nFor DenseNet, call `keras.applications.densenet.preprocess_input`\non your inputs before passing them to the model.\n\nArgs:\n    include_top: whether to include the fully-connected\n        layer at the top of the network.\n    weights: one of `None` (random initialization),\n        `\"imagenet\"` (pre-training on ImageNet),\n        or the path to the weights file to be loaded.\n    input_tensor: optional Keras tensor\n        (i.e. output of `layers.Input()`)\n        to use as image input for the model.\n    input_shape: optional shape tuple, only to be specified\n        if `include_top` is False (otherwise the input shape\n        has to be `(224, 224, 3)` (with `'channels_last'` data format)\n        or `(3, 224, 224)` (with `'channels_first'` data format)).\n        It should have exactly 3 input channels,\n        and width and height should be no smaller than 32.\n        E.g. 
`(200, 200, 3)` would be one valid value.\n    pooling: Optional pooling mode for feature extraction\n        when `include_top` is `False`.\n        - `None` means that the output of the model will be\n            the 4D tensor output of the\n            last convolutional block.\n        - `avg` means that global average pooling\n            will be applied to the output of the\n            last convolutional block, and thus\n            the output of the model will be a 2D tensor.\n        - `max` means that global max pooling will\n            be applied.\n    classes: optional number of classes to classify images\n        into, only to be specified if `include_top` is `True`, and\n        if no `weights` argument is specified. Defaults to 1000.\n    classifier_activation: A `str` or callable.\n        The activation function to use\n        on the \"top\" layer. Ignored unless `include_top=True`. Set\n        `classifier_activation=None` to return the logits\n        of the \"top\" layer. When loading pretrained weights,\n        `classifier_activation` can only be `None` or `\"softmax\"`.\n    name: The name of the model (string).\n\nReturns:\n    A Keras model instance.\n\"\"\"\n\nsetattr(DenseNet121, \"__doc__\", DenseNet121.__doc__ + DOC)\nsetattr(DenseNet169, \"__doc__\", DenseNet169.__doc__ + DOC)\nsetattr(DenseNet201, \"__doc__\", DenseNet201.__doc__ + DOC)\n"
  },
  {
    "path": "keras/src/applications/efficientnet.py",
    "content": "import copy\nimport math\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nBASE_WEIGHTS_PATH = \"https://storage.googleapis.com/keras-applications/\"\n\nWEIGHTS_HASHES = {\n    \"b0\": (\n        \"902e53a9f72be733fc0bcb005b3ebbac\",\n        \"50bc09e76180e00e4465e1a485ddc09d\",\n    ),\n    \"b1\": (\n        \"1d254153d4ab51201f1646940f018540\",\n        \"74c4e6b3e1f6a1eea24c589628592432\",\n    ),\n    \"b2\": (\n        \"b15cce36ff4dcbd00b6dd88e7857a6ad\",\n        \"111f8e2ac8aa800a7a99e3239f7bfb39\",\n    ),\n    \"b3\": (\n        \"ffd1fdc53d0ce67064dc6a9c7960ede0\",\n        \"af6d107764bb5b1abb91932881670226\",\n    ),\n    \"b4\": (\n        \"18c95ad55216b8f92d7e70b3a046e2fc\",\n        \"ebc24e6d6c33eaebbd558eafbeedf1ba\",\n    ),\n    \"b5\": (\n        \"ace28f2a6363774853a83a0b21b9421a\",\n        \"38879255a25d3c92d5e44e04ae6cec6f\",\n    ),\n    \"b6\": (\n        \"165f6e37dce68623721b423839de8be5\",\n        \"9ecce42647a20130c1f39a5d4cb75743\",\n    ),\n    \"b7\": (\n        \"8c03f828fec3ef71311cd463b6759d99\",\n        \"cbcfe4450ddf6f3ad90b1b398090fe4a\",\n    ),\n}\n\nDEFAULT_BLOCKS_ARGS = [\n    {\n        \"kernel_size\": 3,\n        \"repeats\": 1,\n        \"filters_in\": 32,\n        \"filters_out\": 16,\n        \"expand_ratio\": 1,\n        \"id_skip\": True,\n        \"strides\": 1,\n        \"se_ratio\": 0.25,\n    },\n    {\n        \"kernel_size\": 3,\n        \"repeats\": 2,\n        \"filters_in\": 16,\n        \"filters_out\": 24,\n        \"expand_ratio\": 6,\n        \"id_skip\": True,\n        \"strides\": 2,\n        \"se_ratio\": 0.25,\n    },\n    {\n        \"kernel_size\": 5,\n        \"repeats\": 2,\n        \"filters_in\": 24,\n        \"filters_out\": 40,\n       
 \"expand_ratio\": 6,\n        \"id_skip\": True,\n        \"strides\": 2,\n        \"se_ratio\": 0.25,\n    },\n    {\n        \"kernel_size\": 3,\n        \"repeats\": 3,\n        \"filters_in\": 40,\n        \"filters_out\": 80,\n        \"expand_ratio\": 6,\n        \"id_skip\": True,\n        \"strides\": 2,\n        \"se_ratio\": 0.25,\n    },\n    {\n        \"kernel_size\": 5,\n        \"repeats\": 3,\n        \"filters_in\": 80,\n        \"filters_out\": 112,\n        \"expand_ratio\": 6,\n        \"id_skip\": True,\n        \"strides\": 1,\n        \"se_ratio\": 0.25,\n    },\n    {\n        \"kernel_size\": 5,\n        \"repeats\": 4,\n        \"filters_in\": 112,\n        \"filters_out\": 192,\n        \"expand_ratio\": 6,\n        \"id_skip\": True,\n        \"strides\": 2,\n        \"se_ratio\": 0.25,\n    },\n    {\n        \"kernel_size\": 3,\n        \"repeats\": 1,\n        \"filters_in\": 192,\n        \"filters_out\": 320,\n        \"expand_ratio\": 6,\n        \"id_skip\": True,\n        \"strides\": 1,\n        \"se_ratio\": 0.25,\n    },\n]\n\nCONV_KERNEL_INITIALIZER = {\n    \"class_name\": \"VarianceScaling\",\n    \"config\": {\n        \"scale\": 2.0,\n        \"mode\": \"fan_out\",\n        \"distribution\": \"truncated_normal\",\n    },\n}\n\nDENSE_KERNEL_INITIALIZER = {\n    \"class_name\": \"VarianceScaling\",\n    \"config\": {\n        \"scale\": 1.0 / 3.0,\n        \"mode\": \"fan_out\",\n        \"distribution\": \"uniform\",\n    },\n}\n\nBASE_DOCSTRING = \"\"\"Instantiates the {name} architecture.\n\nReference:\n- [EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks](\n    https://arxiv.org/abs/1905.11946) (ICML 2019)\n\nThis function returns a Keras image classification model,\noptionally loaded with weights pre-trained on ImageNet.\n\nFor image classification use cases, see\n[this page for detailed examples](\nhttps://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\nFor 
transfer learning use cases, make sure to read the\n[guide to transfer learning & fine-tuning](\nhttps://keras.io/guides/transfer_learning/).\n\nNote: each Keras Application expects a specific kind of input preprocessing.\nFor EfficientNet, input preprocessing is included as part of the model\n(as a `Rescaling` layer), and thus\n`keras.applications.efficientnet.preprocess_input` is actually a\npass-through function. EfficientNet models expect their inputs to be float\ntensors of pixels with values in the `[0-255]` range.\n\nArgs:\n    include_top: Whether to include the fully-connected\n        layer at the top of the network. Defaults to `True`.\n    weights: One of `None` (random initialization),\n        `\"imagenet\"` (pre-training on ImageNet),\n        or the path to the weights file to be loaded.\n        Defaults to `\"imagenet\"`.\n    input_tensor: Optional Keras tensor\n        (i.e. output of `layers.Input()`)\n        to use as image input for the model.\n    input_shape: Optional shape tuple, only to be specified\n        if `include_top` is False.\n        It should have exactly 3 input channels.\n    pooling: Optional pooling mode for feature extraction\n        when `include_top` is `False`. Defaults to `None`.\n        - `None` means that the output of the model will be\n            the 4D tensor output of the\n            last convolutional layer.\n        - `avg` means that global average pooling\n            will be applied to the output of the\n            last convolutional layer, and thus\n            the output of the model will be a 2D tensor.\n        - `max` means that global max pooling will\n            be applied.\n    classes: Optional number of classes to classify images\n        into, only to be specified if `include_top` is True, and\n        if no `weights` argument is specified. Defaults to `1000`\n        (the number of ImageNet classes).\n    classifier_activation: A `str` or callable. 
The activation function to use\n        on the \"top\" layer. Ignored unless `include_top=True`. Set\n        `classifier_activation=None` to return the logits of the \"top\" layer.\n        Defaults to `'softmax'`.\n        When loading pretrained weights, `classifier_activation` can only\n        be `None` or `\"softmax\"`.\n    name: The name of the model (string).\n\nReturns:\n    A model instance.\n\"\"\"\n\n\nIMAGENET_STDDEV_RGB = [0.229, 0.224, 0.225]\n\n\ndef EfficientNet(\n    width_coefficient,\n    depth_coefficient,\n    default_size,\n    dropout_rate=0.2,\n    drop_connect_rate=0.2,\n    depth_divisor=8,\n    activation=\"swish\",\n    blocks_args=\"default\",\n    name=\"efficientnet\",\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    weights_name=None,\n):\n    \"\"\"Instantiates the EfficientNet architecture.\n\n    Args:\n      width_coefficient: float, scaling coefficient for network width.\n      depth_coefficient: float, scaling coefficient for network depth.\n      default_size: integer, default input image size.\n      dropout_rate: float, dropout rate before final classifier layer.\n      drop_connect_rate: float, dropout rate at skip connections.\n      depth_divisor: integer, a unit of network width.\n      activation: activation function.\n      blocks_args: list of dicts, parameters to construct block modules.\n      name: string, model name.\n      include_top: whether to include the fully-connected\n          layer at the top of the network.\n      weights: one of `None` (random initialization),\n            'imagenet' (pre-training on ImageNet),\n            or the path to the weights file to be loaded.\n      input_tensor: optional Keras tensor\n          (i.e. 
output of `layers.Input()`)\n          to use as image input for the model.\n      input_shape: optional shape tuple, only to be specified\n          if `include_top` is False.\n          It should have exactly 3 input channels.\n      pooling: optional pooling mode for feature extraction\n          when `include_top` is `False`.\n          - `None` means that the output of the model will be\n              the 4D tensor output of the\n              last convolutional layer.\n          - `avg` means that global average pooling\n              will be applied to the output of the\n              last convolutional layer, and thus\n              the output of the model will be a 2D tensor.\n          - `max` means that global max pooling will\n              be applied.\n      classes: optional number of classes to classify images\n          into, only to be specified if `include_top` is True, and\n          if no `weights` argument is specified.\n      classifier_activation: A `str` or callable. The activation function to use\n          on the \"top\" layer. Ignored unless `include_top=True`. 
Set\n          `classifier_activation=None` to return the logits of the \"top\" layer.\n\n    Returns:\n        A model instance.\n    \"\"\"\n    if blocks_args == \"default\":\n        blocks_args = DEFAULT_BLOCKS_ARGS\n\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), `imagenet` \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded.\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            'If using `weights=\"imagenet\"` with `include_top`'\n            \" as true, `classes` should be 1000\"\n        )\n\n    # Determine proper input shape\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=default_size,\n        min_size=32,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n\n    bn_axis = 3 if backend.image_data_format() == \"channels_last\" else 1\n\n    def round_filters(filters, divisor=depth_divisor):\n        \"\"\"Round number of filters based on depth multiplier.\"\"\"\n        filters *= width_coefficient\n        new_filters = max(\n            divisor, int(filters + divisor / 2) // divisor * divisor\n        )\n        # Make sure that round down does not go down by more than 10%.\n        if new_filters < 0.9 * filters:\n            new_filters += divisor\n        return int(new_filters)\n\n    def round_repeats(repeats):\n        \"\"\"Round number of repeats based on depth 
multiplier.\"\"\"\n        return int(math.ceil(depth_coefficient * repeats))\n\n    # Build stem\n    x = img_input\n    x = layers.Rescaling(1.0 / 255.0)(x)\n    x = layers.Normalization(axis=bn_axis)(x)\n\n    if weights == \"imagenet\":\n        # Note that the normalization layer uses square value of STDDEV as the\n        # variance for the layer: result = (input - mean) / sqrt(var)\n        # However, the original implementation uses (input - mean) / var to\n        # normalize the input, we need to divide another sqrt(var) to match the\n        # original implementation.\n        # See https://github.com/tensorflow/tensorflow/issues/49930 for more\n        # details\n        x = layers.Rescaling(\n            [1.0 / math.sqrt(stddev) for stddev in IMAGENET_STDDEV_RGB]\n        )(x)\n\n    x = layers.ZeroPadding2D(\n        padding=imagenet_utils.correct_pad(x, 3), name=\"stem_conv_pad\"\n    )(x)\n    x = layers.Conv2D(\n        round_filters(32),\n        3,\n        strides=2,\n        padding=\"valid\",\n        use_bias=False,\n        kernel_initializer=CONV_KERNEL_INITIALIZER,\n        name=\"stem_conv\",\n    )(x)\n    x = layers.BatchNormalization(axis=bn_axis, name=\"stem_bn\")(x)\n    x = layers.Activation(activation, name=\"stem_activation\")(x)\n\n    # Build blocks\n    blocks_args = copy.deepcopy(blocks_args)\n\n    b = 0\n    blocks = float(sum(round_repeats(args[\"repeats\"]) for args in blocks_args))\n    for i, args in enumerate(blocks_args):\n        if args[\"repeats\"] <= 0:\n            raise ValueError(\n                f\"The number of repeats in `EfficientNet` must be > 0. 
\"\n                f\"Received: repeats={args['repeats']}\"\n            )\n        # Update block input and output filters based on depth multiplier.\n        args[\"filters_in\"] = round_filters(args[\"filters_in\"])\n        args[\"filters_out\"] = round_filters(args[\"filters_out\"])\n\n        for j in range(round_repeats(args.pop(\"repeats\"))):\n            # The first block needs to take care of stride and filter size\n            # increase.\n            if j > 0:\n                args[\"strides\"] = 1\n                args[\"filters_in\"] = args[\"filters_out\"]\n            x = block(\n                x,\n                activation,\n                drop_connect_rate * b / blocks,\n                name=f\"block{i + 1}{chr(j + 97)}_\",\n                **args,\n            )\n            b += 1\n\n    # Build top\n    x = layers.Conv2D(\n        round_filters(1280),\n        1,\n        padding=\"same\",\n        use_bias=False,\n        kernel_initializer=CONV_KERNEL_INITIALIZER,\n        name=\"top_conv\",\n    )(x)\n    x = layers.BatchNormalization(axis=bn_axis, name=\"top_bn\")(x)\n    x = layers.Activation(activation, name=\"top_activation\")(x)\n    if include_top:\n        x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(x)\n        if dropout_rate > 0:\n            x = layers.Dropout(dropout_rate, name=\"top_dropout\")(x)\n        imagenet_utils.validate_activation(classifier_activation, weights)\n        x = layers.Dense(\n            classes,\n            activation=classifier_activation,\n            kernel_initializer=DENSE_KERNEL_INITIALIZER,\n            name=\"predictions\",\n        )(x)\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D(name=\"max_pool\")(x)\n\n    # Ensure that the model takes into account\n    # any potential predecessors of `input_tensor`.\n    if input_tensor is not None:\n    
    inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n\n    # Create model.\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if weights == \"imagenet\":\n        if include_top:\n            file_suffix = \".h5\"\n            file_hash = WEIGHTS_HASHES[weights_name][0]\n        else:\n            file_suffix = \"_notop.h5\"\n            file_hash = WEIGHTS_HASHES[weights_name][1]\n        file_name = name + file_suffix\n        weights_path = file_utils.get_file(\n            file_name,\n            BASE_WEIGHTS_PATH + file_name,\n            cache_subdir=\"models\",\n            file_hash=file_hash,\n        )\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n    return model\n\n\ndef block(\n    inputs,\n    activation=\"swish\",\n    drop_rate=0.0,\n    name=\"\",\n    filters_in=32,\n    filters_out=16,\n    kernel_size=3,\n    strides=1,\n    expand_ratio=1,\n    se_ratio=0.0,\n    id_skip=True,\n):\n    \"\"\"An inverted residual block.\n\n    Args:\n        inputs: input tensor.\n        activation: activation function.\n        drop_rate: float between 0 and 1, fraction of the input units to drop.\n        name: string, block label.\n        filters_in: integer, the number of input filters.\n        filters_out: integer, the number of output filters.\n        kernel_size: integer, the dimension of the convolution window.\n        strides: integer, the stride of the convolution.\n        expand_ratio: integer, scaling coefficient for the input filters.\n        se_ratio: float between 0 and 1, fraction to squeeze the input filters.\n        id_skip: boolean.\n\n    Returns:\n        output tensor for the block.\n    \"\"\"\n    bn_axis = 3 if backend.image_data_format() == \"channels_last\" else 1\n\n    # Expansion phase\n    filters = filters_in * expand_ratio\n    if expand_ratio != 1:\n        x = layers.Conv2D(\n           
 filters,\n            1,\n            padding=\"same\",\n            use_bias=False,\n            kernel_initializer=CONV_KERNEL_INITIALIZER,\n            name=f\"{name}expand_conv\",\n        )(inputs)\n        x = layers.BatchNormalization(axis=bn_axis, name=f\"{name}expand_bn\")(x)\n        x = layers.Activation(activation, name=f\"{name}expand_activation\")(x)\n    else:\n        x = inputs\n\n    # Depthwise Convolution\n    if strides == 2:\n        x = layers.ZeroPadding2D(\n            padding=imagenet_utils.correct_pad(x, kernel_size),\n            name=f\"{name}dwconv_pad\",\n        )(x)\n        conv_pad = \"valid\"\n    else:\n        conv_pad = \"same\"\n    x = layers.DepthwiseConv2D(\n        kernel_size,\n        strides=strides,\n        padding=conv_pad,\n        use_bias=False,\n        depthwise_initializer=CONV_KERNEL_INITIALIZER,\n        name=f\"{name}dwconv\",\n    )(x)\n    x = layers.BatchNormalization(axis=bn_axis, name=f\"{name}bn\")(x)\n    x = layers.Activation(activation, name=f\"{name}activation\")(x)\n\n    # Squeeze and Excitation phase\n    if 0 < se_ratio <= 1:\n        filters_se = max(1, int(filters_in * se_ratio))\n        se = layers.GlobalAveragePooling2D(name=f\"{name}se_squeeze\")(x)\n        if bn_axis == 1:\n            se_shape = (filters, 1, 1)\n        else:\n            se_shape = (1, 1, filters)\n        se = layers.Reshape(se_shape, name=f\"{name}se_reshape\")(se)\n        se = layers.Conv2D(\n            filters_se,\n            1,\n            padding=\"same\",\n            activation=activation,\n            kernel_initializer=CONV_KERNEL_INITIALIZER,\n            name=f\"{name}se_reduce\",\n        )(se)\n        se = layers.Conv2D(\n            filters,\n            1,\n            padding=\"same\",\n            activation=\"sigmoid\",\n            kernel_initializer=CONV_KERNEL_INITIALIZER,\n            name=f\"{name}se_expand\",\n        )(se)\n        x = layers.multiply([x, se], 
name=f\"{name}se_excite\")\n\n    # Output phase\n    x = layers.Conv2D(\n        filters_out,\n        1,\n        padding=\"same\",\n        use_bias=False,\n        kernel_initializer=CONV_KERNEL_INITIALIZER,\n        name=f\"{name}project_conv\",\n    )(x)\n    x = layers.BatchNormalization(axis=bn_axis, name=f\"{name}project_bn\")(x)\n    if id_skip and strides == 1 and filters_in == filters_out:\n        if drop_rate > 0:\n            x = layers.Dropout(\n                drop_rate, noise_shape=(None, 1, 1, 1), name=f\"{name}drop\"\n            )(x)\n        x = layers.add([x, inputs], name=f\"{name}add\")\n    return x\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet.EfficientNetB0\",\n        \"keras.applications.EfficientNetB0\",\n    ]\n)\ndef EfficientNetB0(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"efficientnetb0\",\n):\n    return EfficientNet(\n        1.0,\n        1.0,\n        224,\n        0.2,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        weights_name=\"b0\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet.EfficientNetB1\",\n        \"keras.applications.EfficientNetB1\",\n    ]\n)\ndef EfficientNetB1(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"efficientnetb1\",\n):\n    return EfficientNet(\n        1.0,\n        1.1,\n        240,\n        0.2,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n       
 pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        weights_name=\"b1\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet.EfficientNetB2\",\n        \"keras.applications.EfficientNetB2\",\n    ]\n)\ndef EfficientNetB2(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"efficientnetb2\",\n):\n    return EfficientNet(\n        1.1,\n        1.2,\n        260,\n        0.3,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        weights_name=\"b2\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet.EfficientNetB3\",\n        \"keras.applications.EfficientNetB3\",\n    ]\n)\ndef EfficientNetB3(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"efficientnetb3\",\n):\n    return EfficientNet(\n        1.2,\n        1.4,\n        300,\n        0.3,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        weights_name=\"b3\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet.EfficientNetB4\",\n        \"keras.applications.EfficientNetB4\",\n    ]\n)\ndef EfficientNetB4(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    
name=\"efficientnetb4\",\n):\n    return EfficientNet(\n        1.4,\n        1.8,\n        380,\n        0.4,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        weights_name=\"b4\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet.EfficientNetB5\",\n        \"keras.applications.EfficientNetB5\",\n    ]\n)\ndef EfficientNetB5(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"efficientnetb5\",\n):\n    return EfficientNet(\n        1.6,\n        2.2,\n        456,\n        0.4,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        weights_name=\"b5\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet.EfficientNetB6\",\n        \"keras.applications.EfficientNetB6\",\n    ]\n)\ndef EfficientNetB6(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"efficientnetb6\",\n):\n    return EfficientNet(\n        1.8,\n        2.6,\n        528,\n        0.5,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        weights_name=\"b6\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet.EfficientNetB7\",\n     
   \"keras.applications.EfficientNetB7\",\n    ]\n)\ndef EfficientNetB7(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"efficientnetb7\",\n):\n    return EfficientNet(\n        2.0,\n        3.1,\n        600,\n        0.5,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        weights_name=\"b7\",\n    )\n\n\nEfficientNetB0.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetB0\")\nEfficientNetB1.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetB1\")\nEfficientNetB2.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetB2\")\nEfficientNetB3.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetB3\")\nEfficientNetB4.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetB4\")\nEfficientNetB5.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetB5\")\nEfficientNetB6.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetB6\")\nEfficientNetB7.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetB7\")\n\n\n@keras_export(\"keras.applications.efficientnet.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    \"\"\"A placeholder method for backward compatibility.\n\n    The preprocessing logic has been included in the efficientnet model\n    implementation. Users are no longer required to call this method to\n    normalize the input data. This method does nothing and only kept as a\n    placeholder to align the API surface between old and new version of model.\n\n    Args:\n        x: A floating point `numpy.array` or a tensor.\n        data_format: Optional data format of the image tensor/array. 
`None`\n            means the global setting `keras.backend.image_data_format()`\n            is used (unless you changed it, it uses `\"channels_last\"`).\n            Defaults to `None`.\n\n    Returns:\n        Unchanged `numpy.array` or tensor.\n    \"\"\"\n    return x\n\n\n@keras_export(\"keras.applications.efficientnet.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n"
  },
  {
    "path": "keras/src/applications/efficientnet_v2.py",
    "content": "import copy\nimport math\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nBASE_WEIGHTS_PATH = \"https://storage.googleapis.com/tensorflow/keras-applications/efficientnet_v2/\"  # noqa: E501\n\nWEIGHTS_HASHES = {\n    \"b0\": (\n        \"21ecbf6da12460d5c40bb2f29ceb2188\",\n        \"893217f2bb855e2983157299931e43ff\",\n    ),\n    \"b1\": (\n        \"069f0534ff22adf035c89e2d9547a9dc\",\n        \"0e80663031ca32d657f9caa404b6ec37\",\n    ),\n    \"b2\": (\n        \"424e49f28180edbde1e94797771950a7\",\n        \"1dfe2e7a5d45b6632553a8961ea609eb\",\n    ),\n    \"b3\": (\n        \"1f1fc43bd98a6e4fd8fdfd551e02c7a0\",\n        \"f6abf7b5849ac99a89b50dd3fd532856\",\n    ),\n    \"-s\": (\n        \"e1d88a8495beba45748fedd0cecbe016\",\n        \"af0682fb74e8c54910f2d4393339c070\",\n    ),\n    \"-m\": (\n        \"a3bf6aa3276309f4fc6a34aa114c95cd\",\n        \"1b8dc055df72dde80d614482840fe342\",\n    ),\n    \"-l\": (\n        \"27e6d408b53c7ebc868fefa357689935\",\n        \"b0b66b5c863aef5b46e8608fe1711615\",\n    ),\n}\n\nDEFAULT_BLOCKS_ARGS = {\n    \"efficientnetv2-s\": [\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 2,\n            \"input_filters\": 24,\n            \"output_filters\": 24,\n            \"expand_ratio\": 1,\n            \"se_ratio\": 0.0,\n            \"strides\": 1,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 4,\n            \"input_filters\": 24,\n            \"output_filters\": 48,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0.0,\n            \"strides\": 2,\n            \"conv_type\": 1,\n        },\n        {\n            \"conv_type\": 1,\n  
          \"expand_ratio\": 4,\n            \"input_filters\": 48,\n            \"kernel_size\": 3,\n            \"num_repeat\": 4,\n            \"output_filters\": 64,\n            \"se_ratio\": 0,\n            \"strides\": 2,\n        },\n        {\n            \"conv_type\": 0,\n            \"expand_ratio\": 4,\n            \"input_filters\": 64,\n            \"kernel_size\": 3,\n            \"num_repeat\": 6,\n            \"output_filters\": 128,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n        },\n        {\n            \"conv_type\": 0,\n            \"expand_ratio\": 6,\n            \"input_filters\": 128,\n            \"kernel_size\": 3,\n            \"num_repeat\": 9,\n            \"output_filters\": 160,\n            \"se_ratio\": 0.25,\n            \"strides\": 1,\n        },\n        {\n            \"conv_type\": 0,\n            \"expand_ratio\": 6,\n            \"input_filters\": 160,\n            \"kernel_size\": 3,\n            \"num_repeat\": 15,\n            \"output_filters\": 256,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n        },\n    ],\n    \"efficientnetv2-m\": [\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 3,\n            \"input_filters\": 24,\n            \"output_filters\": 24,\n            \"expand_ratio\": 1,\n            \"se_ratio\": 0,\n            \"strides\": 1,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 5,\n            \"input_filters\": 24,\n            \"output_filters\": 48,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0,\n            \"strides\": 2,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 5,\n            \"input_filters\": 48,\n            \"output_filters\": 80,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0,\n            \"strides\": 2,\n            \"conv_type\": 1,\n        
},\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 7,\n            \"input_filters\": 80,\n            \"output_filters\": 160,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 14,\n            \"input_filters\": 160,\n            \"output_filters\": 176,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 1,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 18,\n            \"input_filters\": 176,\n            \"output_filters\": 304,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 5,\n            \"input_filters\": 304,\n            \"output_filters\": 512,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 1,\n            \"conv_type\": 0,\n        },\n    ],\n    \"efficientnetv2-l\": [\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 4,\n            \"input_filters\": 32,\n            \"output_filters\": 32,\n            \"expand_ratio\": 1,\n            \"se_ratio\": 0,\n            \"strides\": 1,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 7,\n            \"input_filters\": 32,\n            \"output_filters\": 64,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0,\n            \"strides\": 2,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 7,\n            \"input_filters\": 64,\n            \"output_filters\": 96,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0,\n            
\"strides\": 2,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 10,\n            \"input_filters\": 96,\n            \"output_filters\": 192,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 19,\n            \"input_filters\": 192,\n            \"output_filters\": 224,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 1,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 25,\n            \"input_filters\": 224,\n            \"output_filters\": 384,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 7,\n            \"input_filters\": 384,\n            \"output_filters\": 640,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 1,\n            \"conv_type\": 0,\n        },\n    ],\n    \"efficientnetv2-b0\": [\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 1,\n            \"input_filters\": 32,\n            \"output_filters\": 16,\n            \"expand_ratio\": 1,\n            \"se_ratio\": 0,\n            \"strides\": 1,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 2,\n            \"input_filters\": 16,\n            \"output_filters\": 32,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0,\n            \"strides\": 2,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 2,\n            \"input_filters\": 32,\n            \"output_filters\": 48,\n            
\"expand_ratio\": 4,\n            \"se_ratio\": 0,\n            \"strides\": 2,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 3,\n            \"input_filters\": 48,\n            \"output_filters\": 96,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 5,\n            \"input_filters\": 96,\n            \"output_filters\": 112,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 1,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 8,\n            \"input_filters\": 112,\n            \"output_filters\": 192,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n            \"conv_type\": 0,\n        },\n    ],\n    \"efficientnetv2-b1\": [\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 1,\n            \"input_filters\": 32,\n            \"output_filters\": 16,\n            \"expand_ratio\": 1,\n            \"se_ratio\": 0,\n            \"strides\": 1,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 2,\n            \"input_filters\": 16,\n            \"output_filters\": 32,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0,\n            \"strides\": 2,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 2,\n            \"input_filters\": 32,\n            \"output_filters\": 48,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0,\n            \"strides\": 2,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 3,\n            \"input_filters\": 48,\n    
        \"output_filters\": 96,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 5,\n            \"input_filters\": 96,\n            \"output_filters\": 112,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 1,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 8,\n            \"input_filters\": 112,\n            \"output_filters\": 192,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n            \"conv_type\": 0,\n        },\n    ],\n    \"efficientnetv2-b2\": [\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 1,\n            \"input_filters\": 32,\n            \"output_filters\": 16,\n            \"expand_ratio\": 1,\n            \"se_ratio\": 0,\n            \"strides\": 1,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 2,\n            \"input_filters\": 16,\n            \"output_filters\": 32,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0,\n            \"strides\": 2,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 2,\n            \"input_filters\": 32,\n            \"output_filters\": 48,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0,\n            \"strides\": 2,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 3,\n            \"input_filters\": 48,\n            \"output_filters\": 96,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            
\"num_repeat\": 5,\n            \"input_filters\": 96,\n            \"output_filters\": 112,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 1,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 8,\n            \"input_filters\": 112,\n            \"output_filters\": 192,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n            \"conv_type\": 0,\n        },\n    ],\n    \"efficientnetv2-b3\": [\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 1,\n            \"input_filters\": 32,\n            \"output_filters\": 16,\n            \"expand_ratio\": 1,\n            \"se_ratio\": 0,\n            \"strides\": 1,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 2,\n            \"input_filters\": 16,\n            \"output_filters\": 32,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0,\n            \"strides\": 2,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 2,\n            \"input_filters\": 32,\n            \"output_filters\": 48,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0,\n            \"strides\": 2,\n            \"conv_type\": 1,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 3,\n            \"input_filters\": 48,\n            \"output_filters\": 96,\n            \"expand_ratio\": 4,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n            \"conv_type\": 0,\n        },\n        {\n            \"kernel_size\": 3,\n            \"num_repeat\": 5,\n            \"input_filters\": 96,\n            \"output_filters\": 112,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 1,\n            \"conv_type\": 0,\n        },\n       
 {\n            \"kernel_size\": 3,\n            \"num_repeat\": 8,\n            \"input_filters\": 112,\n            \"output_filters\": 192,\n            \"expand_ratio\": 6,\n            \"se_ratio\": 0.25,\n            \"strides\": 2,\n            \"conv_type\": 0,\n        },\n    ],\n}\n\nCONV_KERNEL_INITIALIZER = {\n    \"class_name\": \"VarianceScaling\",\n    \"config\": {\n        \"scale\": 2.0,\n        \"mode\": \"fan_out\",\n        \"distribution\": \"truncated_normal\",\n    },\n}\n\nDENSE_KERNEL_INITIALIZER = {\n    \"class_name\": \"VarianceScaling\",\n    \"config\": {\n        \"scale\": 1.0 / 3.0,\n        \"mode\": \"fan_out\",\n        \"distribution\": \"uniform\",\n    },\n}\n\nBASE_DOCSTRING = \"\"\"Instantiates the {name} architecture.\n\nReference:\n- [EfficientNetV2: Smaller Models and Faster Training](\n    https://arxiv.org/abs/2104.00298) (ICML 2021)\n\nThis function returns a Keras image classification model,\noptionally loaded with weights pre-trained on ImageNet.\n\nFor image classification use cases, see\n[this page for detailed examples](\nhttps://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\nFor transfer learning use cases, make sure to read the\n[guide to transfer learning & fine-tuning](\nhttps://keras.io/guides/transfer_learning/).\n\nNote: each Keras Application expects a specific kind of input preprocessing.\nFor EfficientNetV2, by default input preprocessing is included as a part of\nthe model (as a `Rescaling` layer), and thus\n`keras.applications.efficientnet_v2.preprocess_input` is actually a\npass-through function. In this use case, EfficientNetV2 models expect their\ninputs to be float tensors of pixels with values in the `[0, 255]` range.\nAt the same time, preprocessing as a part of the model (i.e. 
`Rescaling`\nlayer) can be disabled by setting `include_preprocessing` to `False`.\nWith preprocessing disabled, EfficientNetV2 models expect their inputs to be\nfloat tensors of pixels with values in the `[-1, 1]` range.\n\nArgs:\n    include_top: Boolean, whether to include the fully-connected\n        layer at the top of the network. Defaults to `True`.\n    weights: One of `None` (random initialization),\n        `\"imagenet\"` (pre-training on ImageNet),\n        or the path to the weights file to be loaded. Defaults to `\"imagenet\"`.\n    input_tensor: Optional Keras tensor\n        (i.e. output of `layers.Input()`)\n        to use as image input for the model.\n    input_shape: Optional shape tuple, only to be specified\n        if `include_top` is `False`.\n        It should have exactly 3 input channels.\n    pooling: Optional pooling mode for feature extraction\n        when `include_top` is `False`. Defaults to `None`.\n        - `None` means that the output of the model will be\n            the 4D tensor output of the\n            last convolutional layer.\n        - `\"avg\"` means that global average pooling\n            will be applied to the output of the\n            last convolutional layer, and thus\n            the output of the model will be a 2D tensor.\n        - `\"max\"` means that global max pooling will\n            be applied.\n    classes: Optional number of classes to classify images\n        into, only to be specified if `include_top` is `True`, and\n        if no `weights` argument is specified. Defaults to 1000 (number of\n        ImageNet classes).\n    classifier_activation: A string or callable. The activation function to use\n        on the \"top\" layer. Ignored unless `include_top=True`. 
Set\n        `classifier_activation=None` to return the logits of the \"top\" layer.\n        Defaults to `\"softmax\"`.\n        When loading pretrained weights, `classifier_activation` can only\n        be `None` or `\"softmax\"`.\n    name: The name of the model (string).\n\nReturns:\n    A model instance.\n\"\"\"\n\n\ndef round_filters(filters, width_coefficient, min_depth, depth_divisor):\n    \"\"\"Round number of filters based on depth multiplier.\"\"\"\n    filters *= width_coefficient\n    minimum_depth = min_depth or depth_divisor\n    new_filters = max(\n        minimum_depth,\n        int(filters + depth_divisor / 2) // depth_divisor * depth_divisor,\n    )\n    return int(new_filters)\n\n\ndef round_repeats(repeats, depth_coefficient):\n    \"\"\"Round number of repeats based on depth multiplier.\"\"\"\n    return int(math.ceil(depth_coefficient * repeats))\n\n\ndef MBConvBlock(\n    input_filters,\n    output_filters,\n    expand_ratio=1,\n    kernel_size=3,\n    strides=1,\n    se_ratio=0.0,\n    bn_momentum=0.9,\n    activation=\"swish\",\n    survival_probability=0.8,\n    name=None,\n):\n    \"\"\"MBConv block: Mobile Inverted Residual Bottleneck.\"\"\"\n    bn_axis = 3 if backend.image_data_format() == \"channels_last\" else 1\n\n    if name is None:\n        name = backend.get_uid(\"block0\")\n\n    def apply(inputs):\n        # Expansion phase\n        filters = input_filters * expand_ratio\n        if expand_ratio != 1:\n            x = layers.Conv2D(\n                filters=filters,\n                kernel_size=1,\n                strides=1,\n                kernel_initializer=CONV_KERNEL_INITIALIZER,\n                padding=\"same\",\n                data_format=backend.image_data_format(),\n                use_bias=False,\n                name=f\"{name}expand_conv\",\n            )(inputs)\n            x = layers.BatchNormalization(\n                axis=bn_axis,\n                momentum=bn_momentum,\n                
name=f\"{name}expand_bn\",\n            )(x)\n            x = layers.Activation(activation, name=f\"{name}expand_activation\")(\n                x\n            )\n        else:\n            x = inputs\n\n        # Depthwise conv\n        x = layers.DepthwiseConv2D(\n            kernel_size=kernel_size,\n            strides=strides,\n            depthwise_initializer=CONV_KERNEL_INITIALIZER,\n            padding=\"same\",\n            data_format=backend.image_data_format(),\n            use_bias=False,\n            name=f\"{name}dwconv2\",\n        )(x)\n        x = layers.BatchNormalization(\n            axis=bn_axis, momentum=bn_momentum, name=f\"{name}bn\"\n        )(x)\n        x = layers.Activation(activation, name=f\"{name}activation\")(x)\n\n        # Squeeze and excite\n        if 0 < se_ratio <= 1:\n            filters_se = max(1, int(input_filters * se_ratio))\n            se = layers.GlobalAveragePooling2D(name=f\"{name}se_squeeze\")(x)\n            if bn_axis == 1:\n                se_shape = (filters, 1, 1)\n            else:\n                se_shape = (1, 1, filters)\n            se = layers.Reshape(se_shape, name=f\"{name}se_reshape\")(se)\n\n            se = layers.Conv2D(\n                filters_se,\n                1,\n                padding=\"same\",\n                activation=activation,\n                kernel_initializer=CONV_KERNEL_INITIALIZER,\n                name=f\"{name}se_reduce\",\n            )(se)\n            se = layers.Conv2D(\n                filters,\n                1,\n                padding=\"same\",\n                activation=\"sigmoid\",\n                kernel_initializer=CONV_KERNEL_INITIALIZER,\n                name=f\"{name}se_expand\",\n            )(se)\n\n            x = layers.multiply([x, se], name=f\"{name}se_excite\")\n\n        # Output phase\n        x = layers.Conv2D(\n            filters=output_filters,\n            kernel_size=1,\n            strides=1,\n            
kernel_initializer=CONV_KERNEL_INITIALIZER,\n            padding=\"same\",\n            data_format=backend.image_data_format(),\n            use_bias=False,\n            name=f\"{name}project_conv\",\n        )(x)\n        x = layers.BatchNormalization(\n            axis=bn_axis, momentum=bn_momentum, name=f\"{name}project_bn\"\n        )(x)\n\n        if strides == 1 and input_filters == output_filters:\n            if survival_probability:\n                x = layers.Dropout(\n                    survival_probability,\n                    noise_shape=(None, 1, 1, 1),\n                    name=f\"{name}drop\",\n                )(x)\n            x = layers.add([x, inputs], name=f\"{name}add\")\n\n        return x\n\n    return apply\n\n\ndef FusedMBConvBlock(\n    input_filters,\n    output_filters,\n    expand_ratio=1,\n    kernel_size=3,\n    strides=1,\n    se_ratio=0.0,\n    bn_momentum=0.9,\n    activation=\"swish\",\n    survival_probability=0.8,\n    name=None,\n):\n    \"\"\"Fuses the proj conv1x1 and depthwise_conv into a conv2d.\"\"\"\n    bn_axis = 3 if backend.image_data_format() == \"channels_last\" else 1\n\n    if name is None:\n        name = backend.get_uid(\"block0\")\n\n    def apply(inputs):\n        filters = input_filters * expand_ratio\n        if expand_ratio != 1:\n            x = layers.Conv2D(\n                filters,\n                kernel_size=kernel_size,\n                strides=strides,\n                kernel_initializer=CONV_KERNEL_INITIALIZER,\n                data_format=backend.image_data_format(),\n                padding=\"same\",\n                use_bias=False,\n                name=f\"{name}expand_conv\",\n            )(inputs)\n            x = layers.BatchNormalization(\n                axis=bn_axis, momentum=bn_momentum, name=f\"{name}expand_bn\"\n            )(x)\n            x = layers.Activation(\n                activation=activation, name=f\"{name}expand_activation\"\n            )(x)\n        else:\n            x 
= inputs\n\n        # Squeeze and excite\n        if 0 < se_ratio <= 1:\n            filters_se = max(1, int(input_filters * se_ratio))\n            se = layers.GlobalAveragePooling2D(name=f\"{name}se_squeeze\")(x)\n            if bn_axis == 1:\n                se_shape = (filters, 1, 1)\n            else:\n                se_shape = (1, 1, filters)\n\n            se = layers.Reshape(se_shape, name=f\"{name}se_reshape\")(se)\n\n            se = layers.Conv2D(\n                filters_se,\n                1,\n                padding=\"same\",\n                activation=activation,\n                kernel_initializer=CONV_KERNEL_INITIALIZER,\n                name=f\"{name}se_reduce\",\n            )(se)\n            se = layers.Conv2D(\n                filters,\n                1,\n                padding=\"same\",\n                activation=\"sigmoid\",\n                kernel_initializer=CONV_KERNEL_INITIALIZER,\n                name=f\"{name}se_expand\",\n            )(se)\n\n            x = layers.multiply([x, se], name=f\"{name}se_excite\")\n\n        # Output phase:\n        x = layers.Conv2D(\n            output_filters,\n            kernel_size=1 if expand_ratio != 1 else kernel_size,\n            strides=1 if expand_ratio != 1 else strides,\n            kernel_initializer=CONV_KERNEL_INITIALIZER,\n            padding=\"same\",\n            use_bias=False,\n            name=f\"{name}project_conv\",\n        )(x)\n        x = layers.BatchNormalization(\n            axis=bn_axis, momentum=bn_momentum, name=f\"{name}project_bn\"\n        )(x)\n        if expand_ratio == 1:\n            x = layers.Activation(\n                activation=activation, name=f\"{name}project_activation\"\n            )(x)\n\n        # Residual:\n        if strides == 1 and input_filters == output_filters:\n            if survival_probability:\n                x = layers.Dropout(\n                    survival_probability,\n                    noise_shape=(None, 1, 1, 1),\n            
        name=f\"{name}drop\",\n                )(x)\n            x = layers.add([x, inputs], name=f\"{name}add\")\n        return x\n\n    return apply\n\n\ndef EfficientNetV2(\n    width_coefficient,\n    depth_coefficient,\n    default_size,\n    dropout_rate=0.2,\n    drop_connect_rate=0.2,\n    depth_divisor=8,\n    min_depth=8,\n    bn_momentum=0.9,\n    activation=\"swish\",\n    blocks_args=\"default\",\n    name=\"efficientnetv2\",\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    include_preprocessing=True,\n    weights_name=None,\n):\n    \"\"\"Instantiates the EfficientNetV2 architecture using given scaling\n    coefficients.\n\n    Args:\n        width_coefficient: float, scaling coefficient for network width.\n        depth_coefficient: float, scaling coefficient for network depth.\n        default_size: integer, default input image size.\n        dropout_rate: float, dropout rate before final classifier layer.\n        drop_connect_rate: float, dropout rate at skip connections.\n        depth_divisor: integer, a unit of network width.\n        min_depth: integer, minimum number of filters.\n        bn_momentum: float. Momentum parameter for Batch Normalization layers.\n        activation: activation function.\n        blocks_args: list of dicts, parameters to construct block modules.\n        name: string, model name.\n        include_top: whether to include the fully-connected layer at the top of\n            the network.\n        weights: one of `None` (random initialization), `\"imagenet\"`\n            (pre-training on ImageNet),\n            or the path to the weights file to be loaded.\n        input_tensor: optional Keras tensor (i.e. 
output of `layers.Input()`) or\n            numpy array to use as image input for the model.\n        input_shape: optional shape tuple, only to be specified if `include_top`\n            is `False`. It should have exactly 3 input channels.\n        pooling: optional pooling mode for feature extraction when `include_top`\n            is `False`.\n            - `None` means that the output of the model will be the\n                4D tensor output of the last convolutional layer.\n            - `\"avg\"` means that global average pooling will be applied to\n                the output of the last convolutional layer,\n                and thus the output of the model will be a 2D tensor.\n            - `\"max\"` means that global max pooling will be applied.\n        classes: optional number of classes to classify images into,\n            only to be specified if `include_top` is `True`, and if no `weights`\n            argument is specified.\n        classifier_activation: A string or callable. The activation function to\n            use on the \"top\" layer. Ignored unless `include_top=True`. 
Set\n            `classifier_activation=None` to return the logits of the \"top\"\n            layer.\n        include_preprocessing: Boolean, whether to include the preprocessing\n            layer (`Rescaling`) at the bottom of the network.\n            Defaults to `True`.\n\n    Returns:\n        A model instance.\n    \"\"\"\n\n    if blocks_args == \"default\":\n        blocks_args = DEFAULT_BLOCKS_ARGS[name]\n\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), `imagenet` \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded. \"\n            f\"Received: weights={weights}\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            'If using `weights=\"imagenet\"` with `include_top`'\n            \" as true, `classes` should be 1000\"\n        )\n\n    # Determine proper input shape\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=default_size,\n        min_size=32,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n\n    bn_axis = 3 if backend.image_data_format() == \"channels_last\" else 1\n\n    x = img_input\n\n    if include_preprocessing:\n        # Apply original V1 preprocessing for Bx variants\n        # if number of channels allows it\n        num_channels = input_shape[bn_axis - 1]\n        if name.split(\"-\")[-1].startswith(\"b\") and num_channels == 3:\n            x = 
layers.Rescaling(scale=1.0 / 255)(x)\n            mean = [0.485, 0.456, 0.406]\n            variance = [0.229**2, 0.224**2, 0.225**2]\n            x = layers.Normalization(\n                mean=mean,\n                variance=variance,\n                axis=bn_axis,\n            )(x)\n        else:\n            x = layers.Rescaling(scale=1.0 / 128.0, offset=-1)(x)\n\n    # Build stem\n    stem_filters = round_filters(\n        filters=blocks_args[0][\"input_filters\"],\n        width_coefficient=width_coefficient,\n        min_depth=min_depth,\n        depth_divisor=depth_divisor,\n    )\n    x = layers.Conv2D(\n        filters=stem_filters,\n        kernel_size=3,\n        strides=2,\n        kernel_initializer=CONV_KERNEL_INITIALIZER,\n        padding=\"same\",\n        use_bias=False,\n        name=\"stem_conv\",\n    )(x)\n    x = layers.BatchNormalization(\n        axis=bn_axis,\n        momentum=bn_momentum,\n        name=\"stem_bn\",\n    )(x)\n    x = layers.Activation(activation, name=\"stem_activation\")(x)\n\n    # Build blocks\n    blocks_args = copy.deepcopy(blocks_args)\n    b = 0\n    blocks = float(sum(args[\"num_repeat\"] for args in blocks_args))\n\n    for i, args in enumerate(blocks_args):\n        if args[\"num_repeat\"] <= 0:\n            raise ValueError(\n                f\"The number of repeats in `EfficientNetV2` must be > 0. 
\"\n                f\"Received: num_repeat={args['num_repeat']}\"\n            )\n\n        # Update block input and output filters based on depth multiplier.\n        args[\"input_filters\"] = round_filters(\n            filters=args[\"input_filters\"],\n            width_coefficient=width_coefficient,\n            min_depth=min_depth,\n            depth_divisor=depth_divisor,\n        )\n        args[\"output_filters\"] = round_filters(\n            filters=args[\"output_filters\"],\n            width_coefficient=width_coefficient,\n            min_depth=min_depth,\n            depth_divisor=depth_divisor,\n        )\n\n        # Determine which conv type to use:\n        block = {0: MBConvBlock, 1: FusedMBConvBlock}[args.pop(\"conv_type\")]\n        repeats = round_repeats(\n            repeats=args.pop(\"num_repeat\"), depth_coefficient=depth_coefficient\n        )\n        for j in range(repeats):\n            # The first block needs to take care of stride and filter size\n            # increase.\n            if j > 0:\n                args[\"strides\"] = 1\n                args[\"input_filters\"] = args[\"output_filters\"]\n\n            x = block(\n                activation=activation,\n                bn_momentum=bn_momentum,\n                survival_probability=drop_connect_rate * b / blocks,\n                name=f\"block{i + 1}{chr(j + 97)}_\",\n                **args,\n            )(x)\n            b += 1\n\n    # Build top\n    top_filters = round_filters(\n        filters=1280,\n        width_coefficient=width_coefficient,\n        min_depth=min_depth,\n        depth_divisor=depth_divisor,\n    )\n    x = layers.Conv2D(\n        filters=top_filters,\n        kernel_size=1,\n        strides=1,\n        kernel_initializer=CONV_KERNEL_INITIALIZER,\n        padding=\"same\",\n        data_format=backend.image_data_format(),\n        use_bias=False,\n        name=\"top_conv\",\n    )(x)\n    x = layers.BatchNormalization(\n        axis=bn_axis,\n        
momentum=bn_momentum,\n        name=\"top_bn\",\n    )(x)\n    x = layers.Activation(activation=activation, name=\"top_activation\")(x)\n\n    if include_top:\n        x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(x)\n        if dropout_rate > 0:\n            x = layers.Dropout(dropout_rate, name=\"top_dropout\")(x)\n        imagenet_utils.validate_activation(classifier_activation, weights)\n        x = layers.Dense(\n            classes,\n            activation=classifier_activation,\n            kernel_initializer=DENSE_KERNEL_INITIALIZER,\n            bias_initializer=initializers.Constant(0.0),\n            name=\"predictions\",\n        )(x)\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D(name=\"max_pool\")(x)\n\n    # Ensure that the model takes into account\n    # any potential predecessors of `input_tensor`.\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n\n    # Create model.\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if weights == \"imagenet\":\n        if include_top:\n            file_suffix = \".h5\"\n            file_hash = WEIGHTS_HASHES[weights_name][0]\n        else:\n            file_suffix = \"_notop.h5\"\n            file_hash = WEIGHTS_HASHES[weights_name][1]\n        file_name = name + file_suffix\n        weights_path = file_utils.get_file(\n            file_name,\n            BASE_WEIGHTS_PATH + file_name,\n            cache_subdir=\"models\",\n            file_hash=file_hash,\n        )\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n\n    return model\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet_v2.EfficientNetV2B0\",\n        \"keras.applications.EfficientNetV2B0\",\n    ]\n)\ndef 
EfficientNetV2B0(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    include_preprocessing=True,\n    name=\"efficientnetv2-b0\",\n):\n    return EfficientNetV2(\n        width_coefficient=1.0,\n        depth_coefficient=1.0,\n        default_size=224,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        include_preprocessing=include_preprocessing,\n        weights_name=\"b0\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet_v2.EfficientNetV2B1\",\n        \"keras.applications.EfficientNetV2B1\",\n    ]\n)\ndef EfficientNetV2B1(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    include_preprocessing=True,\n    name=\"efficientnetv2-b1\",\n):\n    return EfficientNetV2(\n        width_coefficient=1.0,\n        depth_coefficient=1.1,\n        default_size=240,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        include_preprocessing=include_preprocessing,\n        weights_name=\"b1\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet_v2.EfficientNetV2B2\",\n        \"keras.applications.EfficientNetV2B2\",\n    ]\n)\ndef EfficientNetV2B2(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    
include_preprocessing=True,\n    name=\"efficientnetv2-b2\",\n):\n    return EfficientNetV2(\n        width_coefficient=1.1,\n        depth_coefficient=1.2,\n        default_size=260,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        include_preprocessing=include_preprocessing,\n        weights_name=\"b2\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet_v2.EfficientNetV2B3\",\n        \"keras.applications.EfficientNetV2B3\",\n    ]\n)\ndef EfficientNetV2B3(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    include_preprocessing=True,\n    name=\"efficientnetv2-b3\",\n):\n    return EfficientNetV2(\n        width_coefficient=1.2,\n        depth_coefficient=1.4,\n        default_size=300,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        include_preprocessing=include_preprocessing,\n        weights_name=\"b3\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet_v2.EfficientNetV2S\",\n        \"keras.applications.EfficientNetV2S\",\n    ]\n)\ndef EfficientNetV2S(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    include_preprocessing=True,\n    name=\"efficientnetv2-s\",\n):\n    return EfficientNetV2(\n        width_coefficient=1.0,\n        depth_coefficient=1.0,\n        default_size=384,\n        name=name,\n        
include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        include_preprocessing=include_preprocessing,\n        weights_name=\"-s\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet_v2.EfficientNetV2M\",\n        \"keras.applications.EfficientNetV2M\",\n    ]\n)\ndef EfficientNetV2M(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    include_preprocessing=True,\n    name=\"efficientnetv2-m\",\n):\n    return EfficientNetV2(\n        width_coefficient=1.0,\n        depth_coefficient=1.0,\n        default_size=480,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n        include_preprocessing=include_preprocessing,\n        weights_name=\"-m\",\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.efficientnet_v2.EfficientNetV2L\",\n        \"keras.applications.EfficientNetV2L\",\n    ]\n)\ndef EfficientNetV2L(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    include_preprocessing=True,\n    name=\"efficientnetv2-l\",\n):\n    return EfficientNetV2(\n        width_coefficient=1.0,\n        depth_coefficient=1.0,\n        default_size=480,\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n      
  include_preprocessing=include_preprocessing,\n        weights_name=\"-l\",\n    )\n\n\nEfficientNetV2B0.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetV2B0\")\nEfficientNetV2B1.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetV2B1\")\nEfficientNetV2B2.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetV2B2\")\nEfficientNetV2B3.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetV2B3\")\nEfficientNetV2S.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetV2S\")\nEfficientNetV2M.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetV2M\")\nEfficientNetV2L.__doc__ = BASE_DOCSTRING.format(name=\"EfficientNetV2L\")\n\n\n@keras_export(\"keras.applications.efficientnet_v2.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    \"\"\"A placeholder method for backward compatibility.\n\n    The preprocessing logic has been included in the EfficientNetV2 model\n    implementation. Users are no longer required to call this method to\n    normalize the input data. This method does nothing and is only kept as a\n    placeholder to align the API surface between the old and new versions of\n    the model.\n\n    Args:\n        x: A floating point `numpy.array` or a tensor.\n        data_format: Optional data format of the image tensor/array. Defaults to\n            `None`, in which case the global setting\n            `keras.backend.image_data_format()` is used\n            (unless you changed it, it defaults to \"channels_last\").\n\n    Returns:\n        Unchanged `numpy.array` or tensor.\n    \"\"\"\n    return x\n\n\n@keras_export(\"keras.applications.efficientnet_v2.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n"
  },
  {
    "path": "keras/src/applications/imagenet_utils.py",
    "content": "import json\nimport warnings\n\nimport numpy as np\n\nfrom keras.src import activations\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils import file_utils\n\nCLASS_INDEX = None\nCLASS_INDEX_PATH = (\n    \"https://storage.googleapis.com/download.tensorflow.org/\"\n    \"data/imagenet_class_index.json\"\n)\n\n\nPREPROCESS_INPUT_DOC = \"\"\"\n  Preprocesses a tensor or Numpy array encoding a batch of images.\n\n  Usage example with `applications.MobileNet`:\n\n  ```python\n  i = keras.layers.Input([None, None, 3], dtype=\"uint8\")\n  x = ops.cast(i, \"float32\")\n  x = keras.applications.mobilenet.preprocess_input(x)\n  core = keras.applications.MobileNet()\n  x = core(x)\n  model = keras.Model(inputs=[i], outputs=[x])\n  result = model(image)\n  ```\n\n  Args:\n        x: A floating point `numpy.array` or a backend-native tensor,\n            3D or 4D with 3 color\n            channels, with values in the range [0, 255].\n            The preprocessed data are written over the input data\n            if the data types are compatible. To avoid this\n            behaviour, `numpy.copy(x)` can be used.\n        data_format: Optional data format of the image tensor/array. 
`None` means\n        the global setting `keras.backend.image_data_format()` is used\n        (unless you changed it, it uses \"channels_last\").{mode}\n        Defaults to `None`.\n\n  Returns:\n      Preprocessed array with type `float32`.\n      {ret}\n\n  Raises:\n      {error}\n  \"\"\"\n\nPREPROCESS_INPUT_MODE_DOC = \"\"\"\n    mode: One of \"caffe\", \"tf\" or \"torch\".\n      - caffe: will convert the images from RGB to BGR,\n          then will zero-center each color channel with\n          respect to the ImageNet dataset,\n          without scaling.\n      - tf: will scale pixels between -1 and 1,\n          sample-wise.\n      - torch: will scale pixels between 0 and 1 and then\n          will normalize each channel with respect to the\n          ImageNet dataset.\n      Defaults to `\"caffe\"`.\n  \"\"\"\n\nPREPROCESS_INPUT_DEFAULT_ERROR_DOC = \"\"\"\n    ValueError: In case of unknown `mode` or `data_format` argument.\"\"\"\n\nPREPROCESS_INPUT_ERROR_DOC = \"\"\"\n    ValueError: In case of unknown `data_format` argument.\"\"\"\n\nPREPROCESS_INPUT_RET_DOC_TF = \"\"\"\n      The input pixel values are scaled between -1 and 1, sample-wise.\"\"\"\n\nPREPROCESS_INPUT_RET_DOC_TORCH = \"\"\"\n      The input pixel values are scaled between 0 and 1 and each channel is\n      normalized with respect to the ImageNet dataset.\"\"\"\n\nPREPROCESS_INPUT_RET_DOC_CAFFE = \"\"\"\n      The images are converted from RGB to BGR, then each color channel is\n      zero-centered with respect to the ImageNet dataset, without scaling.\"\"\"\n\n\n@keras_export(\"keras.applications.imagenet_utils.preprocess_input\")\ndef preprocess_input(x, data_format=None, mode=\"caffe\"):\n    \"\"\"Preprocesses a tensor or Numpy array encoding a batch of images.\"\"\"\n    if mode not in {\"caffe\", \"tf\", \"torch\"}:\n        raise ValueError(\n            \"Expected mode to be one of `caffe`, `tf` or `torch`. 
\"\n            f\"Received: mode={mode}\"\n        )\n\n    if data_format is None:\n        data_format = backend.image_data_format()\n    elif data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(\n            \"Expected data_format to be one of `channels_first` or \"\n            f\"`channels_last`. Received: data_format={data_format}\"\n        )\n\n    if isinstance(x, np.ndarray):\n        return _preprocess_numpy_input(x, data_format=data_format, mode=mode)\n    else:\n        return _preprocess_tensor_input(x, data_format=data_format, mode=mode)\n\n\npreprocess_input.__doc__ = PREPROCESS_INPUT_DOC.format(\n    mode=PREPROCESS_INPUT_MODE_DOC,\n    ret=\"\",\n    error=PREPROCESS_INPUT_DEFAULT_ERROR_DOC,\n)\n\n\n@keras_export(\"keras.applications.imagenet_utils.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    \"\"\"Decodes the prediction of an ImageNet model.\n\n    Args:\n        preds: NumPy array encoding a batch of predictions.\n        top: Integer, how many top-guesses to return. Defaults to `5`.\n\n    Returns:\n        A list of lists of top class prediction tuples\n        `(class_name, class_description, score)`.\n        One list of tuples per sample in batch input.\n\n    Raises:\n        ValueError: In case of invalid shape of the `preds` array\n            (must be 2D).\n    \"\"\"\n    global CLASS_INDEX\n\n    if len(preds.shape) != 2 or preds.shape[1] != 1000:\n        raise ValueError(\n            \"`decode_predictions` expects \"\n            \"a batch of predictions \"\n            \"(i.e. a 2D array of shape (samples, 1000)). 
\"\n            f\"Received array with shape: {preds.shape}\"\n        )\n    if CLASS_INDEX is None:\n        fpath = file_utils.get_file(\n            \"imagenet_class_index.json\",\n            CLASS_INDEX_PATH,\n            cache_subdir=\"models\",\n            file_hash=\"c2c37ea517e94d9795004a39431a14cb\",\n        )\n        with open(fpath) as f:\n            CLASS_INDEX = json.load(f)\n    results = []\n    preds = ops.convert_to_numpy(preds)\n    for pred in preds:\n        top_indices = pred.argsort()[-top:][::-1]\n        result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices]\n        result.sort(key=lambda x: x[2], reverse=True)\n        results.append(result)\n    return results\n\n\ndef _preprocess_numpy_input(x, data_format, mode):\n    \"\"\"Preprocesses a NumPy array encoding a batch of images.\n\n    Args:\n      x: Input array, 3D or 4D.\n      data_format: Data format of the image array.\n      mode: One of \"caffe\", \"tf\" or \"torch\".\n        - caffe: will convert the images from RGB to BGR,\n            then will zero-center each color channel with\n            respect to the ImageNet dataset,\n            without scaling.\n        - tf: will scale pixels between -1 and 1,\n            sample-wise.\n        - torch: will scale pixels between 0 and 1 and then\n            will normalize each channel with respect to the\n            ImageNet dataset.\n\n    Returns:\n        Preprocessed Numpy array.\n    \"\"\"\n    if not issubclass(x.dtype.type, np.floating):\n        x = x.astype(backend.floatx(), copy=False)\n\n    if mode == \"tf\":\n        x /= 127.5\n        x -= 1.0\n        return x\n    elif mode == \"torch\":\n        x /= 255.0\n        mean = [0.485, 0.456, 0.406]\n        std = [0.229, 0.224, 0.225]\n    else:\n        if data_format == \"channels_first\":\n            # 'RGB'->'BGR'\n            if len(x.shape) == 3:\n                x = x[::-1, ...]\n            else:\n                x = x[:, ::-1, ...]\n 
       else:\n            # 'RGB'->'BGR'\n            x = x[..., ::-1]\n        mean = [103.939, 116.779, 123.68]\n        std = None\n\n    # Zero-center by mean pixel\n    if data_format == \"channels_first\":\n        if len(x.shape) == 3:\n            x[0, :, :] -= mean[0]\n            x[1, :, :] -= mean[1]\n            x[2, :, :] -= mean[2]\n            if std is not None:\n                x[0, :, :] /= std[0]\n                x[1, :, :] /= std[1]\n                x[2, :, :] /= std[2]\n        else:\n            x[:, 0, :, :] -= mean[0]\n            x[:, 1, :, :] -= mean[1]\n            x[:, 2, :, :] -= mean[2]\n            if std is not None:\n                x[:, 0, :, :] /= std[0]\n                x[:, 1, :, :] /= std[1]\n                x[:, 2, :, :] /= std[2]\n    else:\n        x[..., 0] -= mean[0]\n        x[..., 1] -= mean[1]\n        x[..., 2] -= mean[2]\n        if std is not None:\n            x[..., 0] /= std[0]\n            x[..., 1] /= std[1]\n            x[..., 2] /= std[2]\n    return x\n\n\ndef _preprocess_tensor_input(x, data_format, mode):\n    \"\"\"Preprocesses a tensor encoding a batch of images.\n\n    Args:\n      x: Input tensor, 3D or 4D.\n      data_format: Data format of the image tensor.\n      mode: One of \"caffe\", \"tf\" or \"torch\".\n        - caffe: will convert the images from RGB to BGR,\n            then will zero-center each color channel with\n            respect to the ImageNet dataset,\n            without scaling.\n        - tf: will scale pixels between -1 and 1,\n            sample-wise.\n        - torch: will scale pixels between 0 and 1 and then\n            will normalize each channel with respect to the\n            ImageNet dataset.\n\n    Returns:\n        Preprocessed tensor.\n    \"\"\"\n    ndim = len(x.shape)\n\n    if mode == \"tf\":\n        x /= 127.5\n        x -= 1.0\n        return x\n    elif mode == \"torch\":\n        x /= 255.0\n        mean = [0.485, 0.456, 0.406]\n        std = [0.229, 0.224, 
0.225]\n    else:\n        if data_format == \"channels_first\":\n            # 'RGB'->'BGR'\n            if len(x.shape) == 3:\n                x = ops.stack([x[i, ...] for i in (2, 1, 0)], axis=0)\n            else:\n                x = ops.stack([x[:, i, :] for i in (2, 1, 0)], axis=1)\n        else:\n            # 'RGB'->'BGR'\n            x = ops.stack([x[..., i] for i in (2, 1, 0)], axis=-1)\n        mean = [103.939, 116.779, 123.68]\n        std = None\n\n    mean_tensor = ops.convert_to_tensor(-np.array(mean), dtype=x.dtype)\n\n    # Zero-center by mean pixel\n    if data_format == \"channels_first\":\n        if len(x.shape) == 3:\n            mean_tensor = ops.reshape(mean_tensor, (3, 1, 1))\n        else:\n            mean_tensor = ops.reshape(mean_tensor, (1, 3) + (1,) * (ndim - 2))\n    else:\n        mean_tensor = ops.reshape(mean_tensor, (1,) * (ndim - 1) + (3,))\n    x += mean_tensor\n    if std is not None:\n        std_tensor = ops.convert_to_tensor(np.array(std), dtype=x.dtype)\n        if data_format == \"channels_first\":\n            std_tensor = ops.reshape(std_tensor, (-1, 1, 1))\n        x /= std_tensor\n    return x\n\n\ndef obtain_input_shape(\n    input_shape,\n    default_size,\n    min_size,\n    data_format,\n    require_flatten,\n    weights=None,\n):\n    \"\"\"Internal utility to compute/validate a model's input shape.\n\n    Args:\n      input_shape: Either None (will return the default network input shape),\n        or a user-provided shape to be validated.\n      default_size: Default input width/height for the model.\n      min_size: Minimum input width/height accepted by the model.\n      data_format: Image data format to use.\n      require_flatten: Whether the model is expected to\n        be linked to a classifier via a Flatten layer.\n      weights: One of `None` (random initialization)\n        or 'imagenet' (pre-training on ImageNet).\n        If weights='imagenet' input channels must be equal to 3.\n\n    Returns:\n     
 An integer shape tuple (may include None entries).\n\n    Raises:\n      ValueError: In case of invalid argument values.\n    \"\"\"\n    if weights != \"imagenet\" and input_shape and len(input_shape) == 3:\n        if data_format == \"channels_first\":\n            correct_channel_axis = 1 if len(input_shape) == 4 else 0\n            if input_shape[correct_channel_axis] not in {1, 3}:\n                warnings.warn(\n                    \"This model usually expects 1 or 3 input channels. \"\n                    \"However, it was passed an input_shape \"\n                    f\"with {input_shape[0]} input channels.\",\n                    stacklevel=2,\n                )\n            default_shape = (input_shape[0], default_size, default_size)\n        else:\n            if input_shape[-1] not in {1, 3}:\n                warnings.warn(\n                    \"This model usually expects 1 or 3 input channels. \"\n                    \"However, it was passed an input_shape \"\n                    f\"with {input_shape[-1]} input channels.\",\n                    stacklevel=2,\n                )\n            default_shape = (default_size, default_size, input_shape[-1])\n    else:\n        if data_format == \"channels_first\":\n            default_shape = (3, default_size, default_size)\n        else:\n            default_shape = (default_size, default_size, 3)\n    if weights == \"imagenet\" and require_flatten:\n        if input_shape is not None:\n            if input_shape != default_shape:\n                raise ValueError(\n                    \"When setting `include_top=True` \"\n                    \"and loading `imagenet` weights, \"\n                    f\"`input_shape` should be {default_shape}.  
\"\n                    f\"Received: input_shape={input_shape}\"\n                )\n        return default_shape\n    if input_shape:\n        if data_format == \"channels_first\":\n            if input_shape is not None:\n                if len(input_shape) != 3:\n                    raise ValueError(\n                        \"`input_shape` must be a tuple of three integers.\"\n                    )\n                if input_shape[0] != 3 and weights == \"imagenet\":\n                    raise ValueError(\n                        \"The input must have 3 channels; Received \"\n                        f\"`input_shape={input_shape}`\"\n                    )\n                if (\n                    input_shape[1] is not None and input_shape[1] < min_size\n                ) or (input_shape[2] is not None and input_shape[2] < min_size):\n                    raise ValueError(\n                        f\"Input size must be at least {min_size}\"\n                        f\"x{min_size}; Received: \"\n                        f\"input_shape={input_shape}\"\n                    )\n        else:\n            if input_shape is not None:\n                if len(input_shape) != 3:\n                    raise ValueError(\n                        \"`input_shape` must be a tuple of three integers.\"\n                    )\n                if input_shape[-1] != 3 and weights == \"imagenet\":\n                    raise ValueError(\n                        \"The input must have 3 channels; Received \"\n                        f\"`input_shape={input_shape}`\"\n                    )\n                if (\n                    input_shape[0] is not None and input_shape[0] < min_size\n                ) or (input_shape[1] is not None and input_shape[1] < min_size):\n                    raise ValueError(\n                        \"Input size must be at least \"\n                        f\"{min_size}x{min_size}; Received: \"\n                        f\"input_shape={input_shape}\"\n           
         )\n    else:\n        if require_flatten:\n            input_shape = default_shape\n        else:\n            if data_format == \"channels_first\":\n                input_shape = (3, None, None)\n            else:\n                input_shape = (None, None, 3)\n    if require_flatten:\n        if None in input_shape:\n            raise ValueError(\n                \"If `include_top` is True, \"\n                \"you should specify a static `input_shape`. \"\n                f\"Received: input_shape={input_shape}\"\n            )\n    return input_shape\n\n\ndef correct_pad(inputs, kernel_size):\n    \"\"\"Returns a tuple for zero-padding for 2D convolution with downsampling.\n\n    Args:\n      inputs: Input tensor.\n      kernel_size: An integer or tuple/list of 2 integers.\n\n    Returns:\n      A tuple.\n    \"\"\"\n    img_dim = 2 if backend.image_data_format() == \"channels_first\" else 1\n    input_size = inputs.shape[img_dim : (img_dim + 2)]\n    if isinstance(kernel_size, int):\n        kernel_size = (kernel_size, kernel_size)\n    if input_size[0] is None:\n        adjust = (1, 1)\n    else:\n        adjust = (1 - input_size[0] % 2, 1 - input_size[1] % 2)\n    correct = (kernel_size[0] // 2, kernel_size[1] // 2)\n    return (\n        (correct[0] - adjust[0], correct[0]),\n        (correct[1] - adjust[1], correct[1]),\n    )\n\n\ndef validate_activation(classifier_activation, weights):\n    \"\"\"Validates that `classifier_activation` is compatible with the weights.\n\n    Args:\n      classifier_activation: str or callable activation function\n      weights: The pretrained weights to load.\n\n    Raises:\n      ValueError: if an activation other than `None` or `softmax` is used with\n        pretrained weights.\n    \"\"\"\n    if weights is None:\n        return\n\n    classifier_activation = activations.get(classifier_activation)\n    if classifier_activation not in {\n        activations.get(\"softmax\"),\n        activations.get(None),\n    }:\n        raise ValueError(\n            \"Only `None` and `softmax` activations are allowed \"\n            \"for the `classifier_activation` argument when using \"\n            \"pretrained weights, with `include_top=True`; Received: \"\n            f\"classifier_activation={classifier_activation}\"\n        )\n"
  },
  {
    "path": "keras/src/applications/imagenet_utils_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.applications import imagenet_utils as utils\nfrom keras.src.dtype_policies.dtype_policy import set_dtype_policy\n\n\nclass TestImageNetUtils(testing.TestCase):\n    def test_preprocess_input(self):\n        # Test invalid mode check\n        x = np.random.uniform(0, 255, (10, 10, 3))\n        with self.assertRaises(ValueError):\n            utils.preprocess_input(x, mode=\"some_unknown_mode\")\n\n        # Test image batch with float and int image input\n        x = np.random.uniform(0, 255, (2, 10, 10, 3))\n        xint = x.astype(\"int32\")\n        self.assertEqual(utils.preprocess_input(x).shape, x.shape)\n        self.assertEqual(utils.preprocess_input(xint).shape, xint.shape)\n\n        out1 = utils.preprocess_input(x, \"channels_last\")\n        out1int = utils.preprocess_input(xint, \"channels_last\")\n        out2 = utils.preprocess_input(\n            np.transpose(x, (0, 3, 1, 2)), \"channels_first\"\n        )\n        out2int = utils.preprocess_input(\n            np.transpose(xint, (0, 3, 1, 2)), \"channels_first\"\n        )\n        self.assertAllClose(out1, out2.transpose(0, 2, 3, 1))\n        self.assertAllClose(out1int, out2int.transpose(0, 2, 3, 1))\n\n        # Test single image\n        x = np.random.uniform(0, 255, (10, 10, 3))\n        xint = x.astype(\"int32\")\n        self.assertEqual(utils.preprocess_input(x).shape, x.shape)\n        self.assertEqual(utils.preprocess_input(xint).shape, xint.shape)\n\n        out1 = utils.preprocess_input(x, \"channels_last\")\n        out1int = utils.preprocess_input(xint, \"channels_last\")\n        out2 = utils.preprocess_input(\n            np.transpose(x, (2, 0, 1)), \"channels_first\"\n        )\n        out2int = utils.preprocess_input(\n            np.transpose(xint, (2, 0, 1)), \"channels_first\"\n        )\n        
self.assertAllClose(out1, out2.transpose(1, 2, 0))\n        self.assertAllClose(out1int, out2int.transpose(1, 2, 0))\n\n        # Test that writing over the input data works predictably\n        for mode in [\"torch\", \"tf\"]:\n            x = np.random.uniform(0, 255, (2, 10, 10, 3))\n            xint = x.astype(\"int\")\n            x2 = utils.preprocess_input(x, \"channels_last\", mode=mode)\n            xint2 = utils.preprocess_input(xint, \"channels_last\")\n            self.assertAllClose(x, x2)\n            self.assertNotEqual(xint.astype(\"float\").max(), xint2.max())\n\n        # Caffe mode works differently from the others\n        x = np.random.uniform(0, 255, (2, 10, 10, 3))\n        xint = x.astype(\"int\")\n        x2 = utils.preprocess_input(\n            x, data_format=\"channels_last\", mode=\"caffe\"\n        )\n        xint2 = utils.preprocess_input(xint, data_format=\"channels_last\")\n        self.assertAllClose(x, x2[..., ::-1])\n        self.assertNotEqual(xint.astype(\"float\").max(), xint2.max())\n\n    @parameterized.named_parameters(\n        [\n            {\"testcase_name\": \"mode_torch\", \"mode\": \"torch\"},\n            {\"testcase_name\": \"mode_tf\", \"mode\": \"tf\"},\n            {\"testcase_name\": \"mode_caffe\", \"mode\": \"caffe\"},\n        ]\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_preprocess_input_symbolic(self, mode):\n        backend_data_format = backend.image_data_format()\n        # Test image batch\n        if backend_data_format == \"channels_last\":\n            x = np.random.uniform(0, 255, (2, 10, 10, 3))\n        elif backend_data_format == \"channels_first\":\n            x = np.random.uniform(0, 255, (2, 3, 10, 10))\n        inputs = keras.layers.Input(shape=x.shape[1:])\n        outputs = keras.layers.Lambda(\n            lambda x: utils.preprocess_input(x, mode=mode),\n            output_shape=x.shape[1:],\n        )(inputs)\n        model = keras.Model(inputs, outputs)\n        
self.assertEqual(model.predict(x).shape, x.shape)\n\n        x = np.random.uniform(0, 255, (2, 10, 10, 3))\n        inputs = keras.layers.Input(shape=x.shape[1:])\n        outputs1 = keras.layers.Lambda(\n            lambda x: utils.preprocess_input(x, \"channels_last\", mode=mode),\n            output_shape=x.shape[1:],\n        )(inputs)\n        model1 = keras.Model(inputs, outputs1)\n        out1 = model1.predict(x)\n        x2 = np.transpose(x, (0, 3, 1, 2))\n        inputs2 = keras.layers.Input(shape=x2.shape[1:])\n        outputs2 = keras.layers.Lambda(\n            lambda x: utils.preprocess_input(x, \"channels_first\", mode=mode),\n            output_shape=x2.shape[1:],\n        )(inputs2)\n        model2 = keras.Model(inputs2, outputs2)\n        out2 = model2.predict(x2)\n        self.assertAllClose(out1, out2.transpose(0, 2, 3, 1))\n\n        # Test single image\n        if backend_data_format == \"channels_last\":\n            x = np.random.uniform(0, 255, (10, 10, 3))\n        elif backend_data_format == \"channels_first\":\n            x = np.random.uniform(0, 255, (3, 10, 10))\n        inputs = keras.layers.Input(shape=x.shape)\n        outputs = keras.layers.Lambda(\n            lambda x: utils.preprocess_input(x, mode=mode), output_shape=x.shape\n        )(inputs)\n        model = keras.Model(inputs, outputs)\n        self.assertEqual(model.predict(x[np.newaxis])[0].shape, x.shape)\n\n        x = np.random.uniform(0, 255, (10, 10, 3))\n        inputs = keras.layers.Input(shape=x.shape)\n        outputs1 = keras.layers.Lambda(\n            lambda x: utils.preprocess_input(x, \"channels_last\", mode=mode),\n            output_shape=x.shape,\n        )(inputs)\n        model1 = keras.Model(inputs, outputs1)\n        out1 = model1.predict(x[np.newaxis])[0]\n        x2 = np.transpose(x, (2, 0, 1))\n        inputs2 = keras.layers.Input(shape=x2.shape)\n        outputs2 = keras.layers.Lambda(\n            lambda x: utils.preprocess_input(x, 
\"channels_first\", mode=mode),\n            output_shape=x2.shape,\n        )(inputs2)\n        model2 = keras.Model(inputs2, outputs2)\n        out2 = model2.predict(x2[np.newaxis])[0]\n        self.assertAllClose(out1, out2.transpose(1, 2, 0))\n\n    @parameterized.named_parameters(\n        [\n            {\"testcase_name\": \"mode_torch\", \"mode\": \"torch\"},\n            {\"testcase_name\": \"mode_tf\", \"mode\": \"tf\"},\n            {\"testcase_name\": \"mode_caffe\", \"mode\": \"caffe\"},\n        ]\n    )\n    def test_preprocess_input_symbolic_mixed_precision(self, mode):\n        set_dtype_policy(\"mixed_float16\")\n        shape = (20, 20, 3)\n        inputs = keras.layers.Input(shape=shape)\n        try:\n            keras.layers.Lambda(\n                lambda x: utils.preprocess_input(x, mode=mode),\n                output_shape=shape,\n            )(inputs)\n        finally:\n            set_dtype_policy(\"float32\")\n\n    @parameterized.named_parameters(\n        [\n            {\n                \"testcase_name\": \"channels_last_format\",\n                \"data_format\": \"channels_last\",\n            },\n            {\n                \"testcase_name\": \"channels_first_format\",\n                \"data_format\": \"channels_first\",\n            },\n        ]\n    )\n    def test_obtain_input_shape(self, data_format):\n        # input_shape and default_size are not identical.\n        with self.assertRaises(ValueError):\n            utils.obtain_input_shape(\n                input_shape=(224, 224, 3),\n                default_size=299,\n                min_size=139,\n                data_format=\"channels_last\",\n                require_flatten=True,\n                weights=\"imagenet\",\n            )\n\n        # Test invalid use cases\n\n        shape = (139, 139)\n        if data_format == \"channels_last\":\n            input_shape = shape + (99,)\n        else:\n            input_shape = (99,) + shape\n\n        # input_shape is 
smaller than min_size.\n        shape = (100, 100)\n        if data_format == \"channels_last\":\n            input_shape = shape + (3,)\n        else:\n            input_shape = (3,) + shape\n        with self.assertRaises(ValueError):\n            utils.obtain_input_shape(\n                input_shape=input_shape,\n                default_size=None,\n                min_size=139,\n                data_format=data_format,\n                require_flatten=False,\n            )\n\n        # shape is 1D.\n        shape = (100,)\n        if data_format == \"channels_last\":\n            input_shape = shape + (3,)\n        else:\n            input_shape = (3,) + shape\n        with self.assertRaises(ValueError):\n            utils.obtain_input_shape(\n                input_shape=input_shape,\n                default_size=None,\n                min_size=139,\n                data_format=data_format,\n                require_flatten=False,\n            )\n\n        # the number of channels is 5 not 3.\n        shape = (100, 100)\n        if data_format == \"channels_last\":\n            input_shape = shape + (5,)\n        else:\n            input_shape = (5,) + shape\n        with self.assertRaises(ValueError):\n            utils.obtain_input_shape(\n                input_shape=input_shape,\n                default_size=None,\n                min_size=139,\n                data_format=data_format,\n                require_flatten=False,\n            )\n\n        # require_flatten=True with dynamic input shape.\n        with self.assertRaises(ValueError):\n            utils.obtain_input_shape(\n                input_shape=None,\n                default_size=None,\n                min_size=139,\n                data_format=\"channels_first\",\n                require_flatten=True,\n            )\n\n        # test include top\n        self.assertEqual(\n            utils.obtain_input_shape(\n                input_shape=(3, 200, 200),\n                default_size=None,\n    
            min_size=139,\n                data_format=\"channels_first\",\n                require_flatten=True,\n            ),\n            (3, 200, 200),\n        )\n\n        self.assertEqual(\n            utils.obtain_input_shape(\n                input_shape=None,\n                default_size=None,\n                min_size=139,\n                data_format=\"channels_last\",\n                require_flatten=False,\n            ),\n            (None, None, 3),\n        )\n\n        self.assertEqual(\n            utils.obtain_input_shape(\n                input_shape=None,\n                default_size=None,\n                min_size=139,\n                data_format=\"channels_first\",\n                require_flatten=False,\n            ),\n            (3, None, None),\n        )\n\n        self.assertEqual(\n            utils.obtain_input_shape(\n                input_shape=None,\n                default_size=None,\n                min_size=139,\n                data_format=\"channels_last\",\n                require_flatten=False,\n            ),\n            (None, None, 3),\n        )\n\n        self.assertEqual(\n            utils.obtain_input_shape(\n                input_shape=(150, 150, 3),\n                default_size=None,\n                min_size=139,\n                data_format=\"channels_last\",\n                require_flatten=False,\n            ),\n            (150, 150, 3),\n        )\n\n        self.assertEqual(\n            utils.obtain_input_shape(\n                input_shape=(3, None, None),\n                default_size=None,\n                min_size=139,\n                data_format=\"channels_first\",\n                require_flatten=False,\n            ),\n            (3, None, None),\n        )\n"
  },
  {
    "path": "keras/src/applications/inception_resnet_v2.py",
    "content": "from keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.layers.layer import Layer\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nBASE_WEIGHT_URL = (\n    \"https://storage.googleapis.com/tensorflow/\"\n    \"keras-applications/inception_resnet_v2/\"\n)\n\n\n@keras_export(\n    [\n        \"keras.applications.inception_resnet_v2.InceptionResNetV2\",\n        \"keras.applications.InceptionResNetV2\",\n    ]\n)\ndef InceptionResNetV2(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"inception_resnet_v2\",\n):\n    \"\"\"Instantiates the Inception-ResNet v2 architecture.\n\n    Reference:\n    - [Inception-v4, Inception-ResNet and the Impact of\n       Residual Connections on Learning](https://arxiv.org/abs/1602.07261)\n      (AAAI 2017)\n\n    This function returns a Keras image classification model,\n    optionally loaded with weights pre-trained on ImageNet.\n\n    For image classification use cases, see\n    [this page for detailed examples](\n      https://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\n    For transfer learning use cases, make sure to read the\n    [guide to transfer learning & fine-tuning](\n      https://keras.io/guides/transfer_learning/).\n\n    Note: each Keras Application expects a specific kind of\n    input preprocessing. 
For InceptionResNetV2, call\n    `keras.applications.inception_resnet_v2.preprocess_input`\n    on your inputs before passing them to the model.\n    `inception_resnet_v2.preprocess_input`\n    will scale input pixels between -1 and 1.\n\n    Args:\n        include_top: whether to include the fully-connected\n            layer at the top of the network.\n        weights: one of `None` (random initialization),\n            `\"imagenet\"` (pre-training on ImageNet),\n            or the path to the weights file to be loaded.\n        input_tensor: optional Keras tensor\n            (i.e. output of `layers.Input()`)\n            to use as image input for the model.\n        input_shape: optional shape tuple, only to be specified\n            if `include_top` is `False` (otherwise the input shape\n            has to be `(299, 299, 3)`\n            (with `'channels_last'` data format)\n            or `(3, 299, 299)` (with `'channels_first'` data format)).\n            It should have exactly 3 input channels,\n            and width and height should be no smaller than 75.\n            E.g. 
`(150, 150, 3)` would be one valid value.\n        pooling: Optional pooling mode for feature extraction\n            when `include_top` is `False`.\n            - `None` means that the output of the model will be\n                the 4D tensor output of the last convolutional block.\n            - `'avg'` means that global average pooling\n                will be applied to the output of the\n                last convolutional block, and thus\n                the output of the model will be a 2D tensor.\n            - `'max'` means that global max pooling will be applied.\n        classes: optional number of classes to classify images\n            into, only to be specified if `include_top` is `True`,\n            and if no `weights` argument is specified.\n        classifier_activation: A `str` or callable.\n            The activation function to use on the \"top\" layer.\n            Ignored unless `include_top=True`.\n            Set `classifier_activation=None` to return the logits\n            of the \"top\" layer. When loading pretrained weights,\n            `classifier_activation` can only be `None` or `\"softmax\"`.\n        name: The name of the model (string).\n\n    Returns:\n        A model instance.\n    \"\"\"\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), `imagenet` \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded.\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            'If using `weights=\"imagenet\"` with `include_top=True`, '\n            \"`classes` should be 1000. 
\"\n            f\"Received classes={classes}\"\n        )\n\n    # Determine proper input shape\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=299,\n        min_size=75,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n\n    # Stem block: 35 x 35 x 192\n    x = conv2d_bn(img_input, 32, 3, strides=2, padding=\"valid\")\n    x = conv2d_bn(x, 32, 3, padding=\"valid\")\n    x = conv2d_bn(x, 64, 3)\n    x = layers.MaxPooling2D(3, strides=2)(x)\n    x = conv2d_bn(x, 80, 1, padding=\"valid\")\n    x = conv2d_bn(x, 192, 3, padding=\"valid\")\n    x = layers.MaxPooling2D(3, strides=2)(x)\n\n    # Mixed 5b (Inception-A block): 35 x 35 x 320\n    branch_0 = conv2d_bn(x, 96, 1)\n    branch_1 = conv2d_bn(x, 48, 1)\n    branch_1 = conv2d_bn(branch_1, 64, 5)\n    branch_2 = conv2d_bn(x, 64, 1)\n    branch_2 = conv2d_bn(branch_2, 96, 3)\n    branch_2 = conv2d_bn(branch_2, 96, 3)\n    branch_pool = layers.AveragePooling2D(3, strides=1, padding=\"same\")(x)\n    branch_pool = conv2d_bn(branch_pool, 64, 1)\n    branches = [branch_0, branch_1, branch_2, branch_pool]\n    channel_axis = 1 if backend.image_data_format() == \"channels_first\" else 3\n    x = layers.Concatenate(axis=channel_axis, name=\"mixed_5b\")(branches)\n\n    # 10x block35 (Inception-ResNet-A block): 35 x 35 x 320\n    for block_idx in range(1, 11):\n        x = inception_resnet_block(\n            x, scale=0.17, block_type=\"block35\", block_idx=block_idx\n        )\n\n    # Mixed 6a (Reduction-A block): 17 x 17 x 1088\n    branch_0 = conv2d_bn(x, 384, 3, strides=2, padding=\"valid\")\n    branch_1 = conv2d_bn(x, 
256, 1)\n    branch_1 = conv2d_bn(branch_1, 256, 3)\n    branch_1 = conv2d_bn(branch_1, 384, 3, strides=2, padding=\"valid\")\n    branch_pool = layers.MaxPooling2D(3, strides=2, padding=\"valid\")(x)\n    branches = [branch_0, branch_1, branch_pool]\n    x = layers.Concatenate(axis=channel_axis, name=\"mixed_6a\")(branches)\n\n    # 20x block17 (Inception-ResNet-B block): 17 x 17 x 1088\n    for block_idx in range(1, 21):\n        x = inception_resnet_block(\n            x, scale=0.1, block_type=\"block17\", block_idx=block_idx\n        )\n\n    # Mixed 7a (Reduction-B block): 8 x 8 x 2080\n    branch_0 = conv2d_bn(x, 256, 1)\n    branch_0 = conv2d_bn(branch_0, 384, 3, strides=2, padding=\"valid\")\n    branch_1 = conv2d_bn(x, 256, 1)\n    branch_1 = conv2d_bn(branch_1, 288, 3, strides=2, padding=\"valid\")\n    branch_2 = conv2d_bn(x, 256, 1)\n    branch_2 = conv2d_bn(branch_2, 288, 3)\n    branch_2 = conv2d_bn(branch_2, 320, 3, strides=2, padding=\"valid\")\n    branch_pool = layers.MaxPooling2D(3, strides=2, padding=\"valid\")(x)\n    branches = [branch_0, branch_1, branch_2, branch_pool]\n    x = layers.Concatenate(axis=channel_axis, name=\"mixed_7a\")(branches)\n\n    # 10x block8 (Inception-ResNet-C block): 8 x 8 x 2080\n    for block_idx in range(1, 10):\n        x = inception_resnet_block(\n            x, scale=0.2, block_type=\"block8\", block_idx=block_idx\n        )\n    x = inception_resnet_block(\n        x, scale=1.0, activation=None, block_type=\"block8\", block_idx=10\n    )\n\n    # Final convolution block: 8 x 8 x 1536\n    x = conv2d_bn(x, 1536, 1, name=\"conv_7b\")\n\n    if include_top:\n        # Classification block\n        x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(x)\n        imagenet_utils.validate_activation(classifier_activation, weights)\n        x = layers.Dense(\n            classes, activation=classifier_activation, name=\"predictions\"\n        )(x)\n    else:\n        if pooling == \"avg\":\n            x = 
layers.GlobalAveragePooling2D()(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D()(x)\n\n    # Ensure that the model takes into account\n    # any potential predecessors of `input_tensor`.\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n\n    # Create model.\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if weights == \"imagenet\":\n        if include_top:\n            fname = \"inception_resnet_v2_weights_tf_dim_ordering_tf_kernels.h5\"\n            weights_path = file_utils.get_file(\n                fname,\n                BASE_WEIGHT_URL + fname,\n                cache_subdir=\"models\",\n                file_hash=\"e693bd0210a403b3192acc6073ad2e96\",\n            )\n        else:\n            fname = (\n                \"inception_resnet_v2_weights_\"\n                \"tf_dim_ordering_tf_kernels_notop.h5\"\n            )\n            weights_path = file_utils.get_file(\n                fname,\n                BASE_WEIGHT_URL + fname,\n                cache_subdir=\"models\",\n                file_hash=\"d19885ff4a710c122648d3b5c3b684e4\",\n            )\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n\n    return model\n\n\ndef conv2d_bn(\n    x,\n    filters,\n    kernel_size,\n    strides=1,\n    padding=\"same\",\n    activation=\"relu\",\n    use_bias=False,\n    name=None,\n):\n    \"\"\"Utility function to apply conv + BN.\n\n    Args:\n        x: input tensor.\n        filters: filters in `Conv2D`.\n        kernel_size: kernel size as in `Conv2D`.\n        strides: strides in `Conv2D`.\n        padding: padding mode in `Conv2D`.\n        activation: activation in `Conv2D`.\n        use_bias: whether to use a bias in `Conv2D`.\n        name: name of the ops; will become `name + '_ac'`\n            for the activation and `name + '_bn'` for the 
batch norm layer.\n\n    Returns:\n        Output tensor after applying `Conv2D` and `BatchNormalization`.\n    \"\"\"\n    x = layers.Conv2D(\n        filters,\n        kernel_size,\n        strides=strides,\n        padding=padding,\n        use_bias=use_bias,\n        name=name,\n    )(x)\n    if not use_bias:\n        bn_axis = 1 if backend.image_data_format() == \"channels_first\" else 3\n        bn_name = None if name is None else f\"{name}_bn\"\n        x = layers.BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(\n            x\n        )\n    if activation is not None:\n        ac_name = None if name is None else f\"{name}_ac\"\n        x = layers.Activation(activation, name=ac_name)(x)\n    return x\n\n\nclass CustomScaleLayer(Layer):\n    def __init__(self, scale, **kwargs):\n        super().__init__(**kwargs)\n        self.scale = scale\n\n    def get_config(self):\n        config = super().get_config()\n        config.update({\"scale\": self.scale})\n        return config\n\n    def call(self, inputs):\n        return inputs[0] + inputs[1] * self.scale\n\n\ndef inception_resnet_block(x, scale, block_type, block_idx, activation=\"relu\"):\n    \"\"\"Adds an Inception-ResNet block.\n\n    Args:\n        x: input tensor.\n        scale: scaling factor to scale the residuals\n            (i.e., the output of passing `x` through an inception module)\n            before adding them to the shortcut\n            branch. Let `r` be the output from the residual branch,\n            the output of this block will be `x + scale * r`.\n        block_type: `'block35'`, `'block17'` or `'block8'`,\n            determines the network structure in the residual branch.\n        block_idx: an `int` used for generating layer names.\n            The Inception-ResNet blocks are repeated many times\n            in this network. We use `block_idx` to identify each\n            of the repetitions. 
For example, the first\n            Inception-ResNet-A block will have\n            `block_type='block35', block_idx=0`, and the layer names\n            will have a common prefix `'block35_0'`.\n        activation: activation function to use at the end of the block.\n\n    Returns:\n        Output tensor for the block.\n    \"\"\"\n    if block_type == \"block35\":\n        branch_0 = conv2d_bn(x, 32, 1)\n        branch_1 = conv2d_bn(x, 32, 1)\n        branch_1 = conv2d_bn(branch_1, 32, 3)\n        branch_2 = conv2d_bn(x, 32, 1)\n        branch_2 = conv2d_bn(branch_2, 48, 3)\n        branch_2 = conv2d_bn(branch_2, 64, 3)\n        branches = [branch_0, branch_1, branch_2]\n    elif block_type == \"block17\":\n        branch_0 = conv2d_bn(x, 192, 1)\n        branch_1 = conv2d_bn(x, 128, 1)\n        branch_1 = conv2d_bn(branch_1, 160, [1, 7])\n        branch_1 = conv2d_bn(branch_1, 192, [7, 1])\n        branches = [branch_0, branch_1]\n    elif block_type == \"block8\":\n        branch_0 = conv2d_bn(x, 192, 1)\n        branch_1 = conv2d_bn(x, 192, 1)\n        branch_1 = conv2d_bn(branch_1, 224, [1, 3])\n        branch_1 = conv2d_bn(branch_1, 256, [3, 1])\n        branches = [branch_0, branch_1]\n    else:\n        raise ValueError(\n            \"Unknown Inception-ResNet block type. 
\"\n            'Expects \"block35\", \"block17\" or \"block8\", '\n            f\"but got: {block_type}\"\n        )\n\n    block_name = f\"{block_type}_{block_idx}\"\n    channel_axis = 1 if backend.image_data_format() == \"channels_first\" else 3\n    mixed = layers.Concatenate(axis=channel_axis, name=f\"{block_name}_mixed\")(\n        branches\n    )\n    up = conv2d_bn(\n        mixed,\n        x.shape[channel_axis],\n        1,\n        activation=None,\n        use_bias=True,\n        name=f\"{block_name}_conv\",\n    )\n\n    x = CustomScaleLayer(scale)([x, up])\n    if activation is not None:\n        x = layers.Activation(activation, name=f\"{block_name}_ac\")(x)\n    return x\n\n\n@keras_export(\"keras.applications.inception_resnet_v2.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    return imagenet_utils.preprocess_input(\n        x, data_format=data_format, mode=\"tf\"\n    )\n\n\n@keras_export(\"keras.applications.inception_resnet_v2.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\npreprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(\n    mode=\"\",\n    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,\n    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,\n)\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n"
  },
  {
    "path": "keras/src/applications/inception_v3.py",
    "content": "from keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nWEIGHTS_PATH = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/\"\n    \"inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels.h5\"\n)\nWEIGHTS_PATH_NO_TOP = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/\"\n    \"inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5\"\n)\n\n\n@keras_export(\n    [\n        \"keras.applications.inception_v3.InceptionV3\",\n        \"keras.applications.InceptionV3\",\n    ]\n)\ndef InceptionV3(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"inception_v3\",\n):\n    \"\"\"Instantiates the Inception v3 architecture.\n\n    Reference:\n    - [Rethinking the Inception Architecture for Computer Vision](\n        http://arxiv.org/abs/1512.00567) (CVPR 2016)\n\n    This function returns a Keras image classification model,\n    optionally loaded with weights pre-trained on ImageNet.\n\n    For image classification use cases, see\n    [this page for detailed examples](\n      https://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\n    For transfer learning use cases, make sure to read the\n    [guide to transfer learning & fine-tuning](\n      https://keras.io/guides/transfer_learning/).\n\n    Note: each Keras Application expects a specific kind of input preprocessing.\n    For `InceptionV3`, call\n    `keras.applications.inception_v3.preprocess_input` on your inputs\n    before passing them to the model.\n    `inception_v3.preprocess_input` will scale input pixels between -1 and 1.\n\n    Args:\n        
include_top: Boolean, whether to include the fully-connected\n            layer at the top, as the last layer of the network.\n            Defaults to `True`.\n        weights: One of `None` (random initialization),\n            `imagenet` (pre-training on ImageNet),\n            or the path to the weights file to be loaded.\n            Defaults to `\"imagenet\"`.\n        input_tensor: Optional Keras tensor (i.e. output of `layers.Input()`)\n            to use as image input for the model. `input_tensor` is useful for\n            sharing inputs between multiple different networks.\n            Defaults to `None`.\n        input_shape: Optional shape tuple, only to be specified\n            if `include_top` is `False` (otherwise the input shape\n            has to be `(299, 299, 3)` (with `channels_last` data format)\n            or `(3, 299, 299)` (with `channels_first` data format)).\n            It should have exactly 3 input channels,\n            and width and height should be no smaller than 75.\n            E.g. `(150, 150, 3)` would be one valid value.\n            `input_shape` will be ignored if the `input_tensor` is provided.\n        pooling: Optional pooling mode for feature extraction\n            when `include_top` is `False`.\n            - `None` (default) means that the output of the model will be\n                the 4D tensor output of the last convolutional block.\n            - `avg` means that global average pooling\n                will be applied to the output of the\n                last convolutional block, and thus\n                the output of the model will be a 2D tensor.\n            - `max` means that global max pooling will be applied.\n        classes: optional number of classes to classify images\n            into, only to be specified if `include_top` is `True`, and\n            if no `weights` argument is specified. Defaults to 1000.\n        classifier_activation: A `str` or callable. 
The activation function\n            to use on the \"top\" layer. Ignored unless `include_top=True`.\n            Set `classifier_activation=None` to return the logits of the \"top\"\n            layer. When loading pretrained weights, `classifier_activation`\n            can only be `None` or `\"softmax\"`.\n        name: The name of the model (string).\n\n    Returns:\n        A model instance.\n    \"\"\"\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), `imagenet` \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded; \"\n            f\"Received: weights={weights}\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            'If using `weights=\"imagenet\"` with `include_top=True`, '\n            \"`classes` should be 1000. 
\"\n            f\"Received classes={classes}\"\n        )\n\n    # Determine proper input shape\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=299,\n        min_size=75,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n\n    if backend.image_data_format() == \"channels_first\":\n        channel_axis = 1\n    else:\n        channel_axis = 3\n\n    x = conv2d_bn(img_input, 32, 3, 3, strides=(2, 2), padding=\"valid\")\n    x = conv2d_bn(x, 32, 3, 3, padding=\"valid\")\n    x = conv2d_bn(x, 64, 3, 3)\n    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)\n\n    x = conv2d_bn(x, 80, 1, 1, padding=\"valid\")\n    x = conv2d_bn(x, 192, 3, 3, padding=\"valid\")\n    x = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)\n\n    # mixed 0: 35 x 35 x 256\n    branch1x1 = conv2d_bn(x, 64, 1, 1)\n\n    branch5x5 = conv2d_bn(x, 48, 1, 1)\n    branch5x5 = conv2d_bn(branch5x5, 64, 5, 5)\n\n    branch3x3dbl = conv2d_bn(x, 64, 1, 1)\n    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)\n    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)\n\n    branch_pool = layers.AveragePooling2D(\n        (3, 3), strides=(1, 1), padding=\"same\"\n    )(x)\n    branch_pool = conv2d_bn(branch_pool, 32, 1, 1)\n    x = layers.concatenate(\n        [branch1x1, branch5x5, branch3x3dbl, branch_pool],\n        axis=channel_axis,\n        name=\"mixed0\",\n    )\n\n    # mixed 1: 35 x 35 x 288\n    branch1x1 = conv2d_bn(x, 64, 1, 1)\n\n    branch5x5 = conv2d_bn(x, 48, 1, 1)\n    branch5x5 = conv2d_bn(branch5x5, 64, 5, 5)\n\n    branch3x3dbl = conv2d_bn(x, 64, 1, 1)\n    branch3x3dbl = conv2d_bn(branch3x3dbl, 
96, 3, 3)\n    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)\n\n    branch_pool = layers.AveragePooling2D(\n        (3, 3), strides=(1, 1), padding=\"same\"\n    )(x)\n    branch_pool = conv2d_bn(branch_pool, 64, 1, 1)\n    x = layers.concatenate(\n        [branch1x1, branch5x5, branch3x3dbl, branch_pool],\n        axis=channel_axis,\n        name=\"mixed1\",\n    )\n\n    # mixed 2: 35 x 35 x 288\n    branch1x1 = conv2d_bn(x, 64, 1, 1)\n\n    branch5x5 = conv2d_bn(x, 48, 1, 1)\n    branch5x5 = conv2d_bn(branch5x5, 64, 5, 5)\n\n    branch3x3dbl = conv2d_bn(x, 64, 1, 1)\n    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)\n    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)\n\n    branch_pool = layers.AveragePooling2D(\n        (3, 3), strides=(1, 1), padding=\"same\"\n    )(x)\n    branch_pool = conv2d_bn(branch_pool, 64, 1, 1)\n    x = layers.concatenate(\n        [branch1x1, branch5x5, branch3x3dbl, branch_pool],\n        axis=channel_axis,\n        name=\"mixed2\",\n    )\n\n    # mixed 3: 17 x 17 x 768\n    branch3x3 = conv2d_bn(x, 384, 3, 3, strides=(2, 2), padding=\"valid\")\n\n    branch3x3dbl = conv2d_bn(x, 64, 1, 1)\n    branch3x3dbl = conv2d_bn(branch3x3dbl, 96, 3, 3)\n    branch3x3dbl = conv2d_bn(\n        branch3x3dbl, 96, 3, 3, strides=(2, 2), padding=\"valid\"\n    )\n\n    branch_pool = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)\n    x = layers.concatenate(\n        [branch3x3, branch3x3dbl, branch_pool], axis=channel_axis, name=\"mixed3\"\n    )\n\n    # mixed 4: 17 x 17 x 768\n    branch1x1 = conv2d_bn(x, 192, 1, 1)\n\n    branch7x7 = conv2d_bn(x, 128, 1, 1)\n    branch7x7 = conv2d_bn(branch7x7, 128, 1, 7)\n    branch7x7 = conv2d_bn(branch7x7, 192, 7, 1)\n\n    branch7x7dbl = conv2d_bn(x, 128, 1, 1)\n    branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1)\n    branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 1, 7)\n    branch7x7dbl = conv2d_bn(branch7x7dbl, 128, 7, 1)\n    branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)\n\n    branch_pool = 
layers.AveragePooling2D(\n        (3, 3), strides=(1, 1), padding=\"same\"\n    )(x)\n    branch_pool = conv2d_bn(branch_pool, 192, 1, 1)\n    x = layers.concatenate(\n        [branch1x1, branch7x7, branch7x7dbl, branch_pool],\n        axis=channel_axis,\n        name=\"mixed4\",\n    )\n\n    # mixed 5, 6: 17 x 17 x 768\n    for i in range(2):\n        branch1x1 = conv2d_bn(x, 192, 1, 1)\n\n        branch7x7 = conv2d_bn(x, 160, 1, 1)\n        branch7x7 = conv2d_bn(branch7x7, 160, 1, 7)\n        branch7x7 = conv2d_bn(branch7x7, 192, 7, 1)\n\n        branch7x7dbl = conv2d_bn(x, 160, 1, 1)\n        branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1)\n        branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 1, 7)\n        branch7x7dbl = conv2d_bn(branch7x7dbl, 160, 7, 1)\n        branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)\n\n        branch_pool = layers.AveragePooling2D(\n            (3, 3), strides=(1, 1), padding=\"same\"\n        )(x)\n        branch_pool = conv2d_bn(branch_pool, 192, 1, 1)\n        x = layers.concatenate(\n            [branch1x1, branch7x7, branch7x7dbl, branch_pool],\n            axis=channel_axis,\n            name=\"mixed{0}\".format(5 + i),\n        )\n\n    # mixed 7: 17 x 17 x 768\n    branch1x1 = conv2d_bn(x, 192, 1, 1)\n\n    branch7x7 = conv2d_bn(x, 192, 1, 1)\n    branch7x7 = conv2d_bn(branch7x7, 192, 1, 7)\n    branch7x7 = conv2d_bn(branch7x7, 192, 7, 1)\n\n    branch7x7dbl = conv2d_bn(x, 192, 1, 1)\n    branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1)\n    branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)\n    branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 7, 1)\n    branch7x7dbl = conv2d_bn(branch7x7dbl, 192, 1, 7)\n\n    branch_pool = layers.AveragePooling2D(\n        (3, 3), strides=(1, 1), padding=\"same\"\n    )(x)\n    branch_pool = conv2d_bn(branch_pool, 192, 1, 1)\n    x = layers.concatenate(\n        [branch1x1, branch7x7, branch7x7dbl, branch_pool],\n        axis=channel_axis,\n        name=\"mixed7\",\n    )\n\n    # mixed 8: 
8 x 8 x 1280\n    branch3x3 = conv2d_bn(x, 192, 1, 1)\n    branch3x3 = conv2d_bn(branch3x3, 320, 3, 3, strides=(2, 2), padding=\"valid\")\n\n    branch7x7x3 = conv2d_bn(x, 192, 1, 1)\n    branch7x7x3 = conv2d_bn(branch7x7x3, 192, 1, 7)\n    branch7x7x3 = conv2d_bn(branch7x7x3, 192, 7, 1)\n    branch7x7x3 = conv2d_bn(\n        branch7x7x3, 192, 3, 3, strides=(2, 2), padding=\"valid\"\n    )\n\n    branch_pool = layers.MaxPooling2D((3, 3), strides=(2, 2))(x)\n    x = layers.concatenate(\n        [branch3x3, branch7x7x3, branch_pool], axis=channel_axis, name=\"mixed8\"\n    )\n\n    # mixed 9: 8 x 8 x 2048\n    for i in range(2):\n        branch1x1 = conv2d_bn(x, 320, 1, 1)\n\n        branch3x3 = conv2d_bn(x, 384, 1, 1)\n        branch3x3_1 = conv2d_bn(branch3x3, 384, 1, 3)\n        branch3x3_2 = conv2d_bn(branch3x3, 384, 3, 1)\n        branch3x3 = layers.concatenate(\n            [branch3x3_1, branch3x3_2],\n            axis=channel_axis,\n            name=f\"mixed9_{i}\",\n        )\n\n        branch3x3dbl = conv2d_bn(x, 448, 1, 1)\n        branch3x3dbl = conv2d_bn(branch3x3dbl, 384, 3, 3)\n        branch3x3dbl_1 = conv2d_bn(branch3x3dbl, 384, 1, 3)\n        branch3x3dbl_2 = conv2d_bn(branch3x3dbl, 384, 3, 1)\n        branch3x3dbl = layers.concatenate(\n            [branch3x3dbl_1, branch3x3dbl_2], axis=channel_axis\n        )\n\n        branch_pool = layers.AveragePooling2D(\n            (3, 3), strides=(1, 1), padding=\"same\"\n        )(x)\n        branch_pool = conv2d_bn(branch_pool, 192, 1, 1)\n        x = layers.concatenate(\n            [branch1x1, branch3x3, branch3x3dbl, branch_pool],\n            axis=channel_axis,\n            name=f\"mixed{9 + i}\",\n        )\n    if include_top:\n        # Classification block\n        x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(x)\n        imagenet_utils.validate_activation(classifier_activation, weights)\n        x = layers.Dense(\n            classes, activation=classifier_activation, name=\"predictions\"\n 
       )(x)\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D()(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D()(x)\n\n    # Ensure that the model takes into account\n    # any potential predecessors of `input_tensor`.\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n    # Create model.\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if weights == \"imagenet\":\n        if include_top:\n            weights_path = file_utils.get_file(\n                \"inception_v3_weights_tf_dim_ordering_tf_kernels.h5\",\n                WEIGHTS_PATH,\n                cache_subdir=\"models\",\n                file_hash=\"9a0d58056eeedaa3f26cb7ebd46da564\",\n            )\n        else:\n            weights_path = file_utils.get_file(\n                \"inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5\",\n                WEIGHTS_PATH_NO_TOP,\n                cache_subdir=\"models\",\n                file_hash=\"bcbd6486424b2319ff4ef7d526e38f63\",\n            )\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n\n    return model\n\n\ndef conv2d_bn(\n    x, filters, num_row, num_col, padding=\"same\", strides=(1, 1), name=None\n):\n    \"\"\"Utility function to apply conv + BN.\n\n    Args:\n        x: input tensor.\n        filters: filters in `Conv2D`.\n        num_row: height of the convolution kernel.\n        num_col: width of the convolution kernel.\n        padding: padding mode in `Conv2D`.\n        strides: strides in `Conv2D`.\n        name: name of the ops; will become `name + '_conv'`\n            for the convolution and `name + '_bn'` for the\n            batch norm layer.\n\n    Returns:\n        Output tensor after applying `Conv2D` and `BatchNormalization`.\n    \"\"\"\n    if name is not None:\n        bn_name = 
f\"{name}_bn\"\n        conv_name = f\"{name}_conv\"\n    else:\n        bn_name = None\n        conv_name = None\n    if backend.image_data_format() == \"channels_first\":\n        bn_axis = 1\n    else:\n        bn_axis = 3\n    x = layers.Conv2D(\n        filters,\n        (num_row, num_col),\n        strides=strides,\n        padding=padding,\n        use_bias=False,\n        name=conv_name,\n    )(x)\n    x = layers.BatchNormalization(axis=bn_axis, scale=False, name=bn_name)(x)\n    x = layers.Activation(\"relu\", name=name)(x)\n    return x\n\n\n@keras_export(\"keras.applications.inception_v3.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    return imagenet_utils.preprocess_input(\n        x, data_format=data_format, mode=\"tf\"\n    )\n\n\n@keras_export(\"keras.applications.inception_v3.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\npreprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(\n    mode=\"\",\n    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,\n    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,\n)\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n"
  },
  {
    "path": "keras/src/applications/mobilenet.py",
    "content": "import warnings\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nBASE_WEIGHT_PATH = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/mobilenet/\"\n)\n\n\n@keras_export(\n    [\n        \"keras.applications.mobilenet.MobileNet\",\n        \"keras.applications.MobileNet\",\n    ]\n)\ndef MobileNet(\n    input_shape=None,\n    alpha=1.0,\n    depth_multiplier=1,\n    dropout=1e-3,\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=None,\n):\n    \"\"\"Instantiates the MobileNet architecture.\n\n    Reference:\n    - [MobileNets: Efficient Convolutional Neural Networks\n       for Mobile Vision Applications](\n        https://arxiv.org/abs/1704.04861)\n\n    This function returns a Keras image classification model,\n    optionally loaded with weights pre-trained on ImageNet.\n\n    For image classification use cases, see\n    [this page for detailed examples](\n    https://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\n    For transfer learning use cases, make sure to read the\n    [guide to transfer learning & fine-tuning](\n    https://keras.io/guides/transfer_learning/).\n\n    Note: each Keras Application expects a specific kind of input preprocessing.\n    For MobileNet, call `keras.applications.mobilenet.preprocess_input`\n    on your inputs before passing them to the model.\n    `mobilenet.preprocess_input` will scale input pixels between -1 and 1.\n\n    Args:\n        input_shape: Optional shape tuple, only to be specified if `include_top`\n            is `False` (otherwise the input shape has to be `(224, 224, 3)`\n            (with 
`\"channels_last\"` data format) or `(3, 224, 224)`\n            (with `\"channels_first\"` data format)).\n            It should have exactly 3 input channels, and width and\n            height should be no smaller than 32. E.g. `(200, 200, 3)` would\n            be one valid value. Defaults to `None`.\n            `input_shape` will be ignored if the `input_tensor` is provided.\n        alpha: Controls the width of the network. This is known as the width\n            multiplier in the MobileNet paper.\n            - If `alpha < 1.0`, proportionally decreases the number\n                of filters in each layer.\n            - If `alpha > 1.0`, proportionally increases the number\n                of filters in each layer.\n            - If `alpha == 1`, the default number of filters from the paper\n                is used at each layer. Defaults to `1.0`.\n        depth_multiplier: Depth multiplier for depthwise convolution.\n            This is called the resolution multiplier in the MobileNet paper.\n            Defaults to `1`.\n        dropout: Dropout rate. Defaults to `0.001`.\n        include_top: Boolean, whether to include the fully-connected layer\n            at the top of the network. Defaults to `True`.\n        weights: One of `None` (random initialization), `\"imagenet\"`\n            (pre-training on ImageNet), or the path to the weights file\n            to be loaded. Defaults to `\"imagenet\"`.\n        input_tensor: Optional Keras tensor (i.e. output of `layers.Input()`)\n            to use as image input for the model. 
`input_tensor` is useful\n            for sharing inputs between multiple different networks.\n            Defaults to `None`.\n        pooling: Optional pooling mode for feature extraction when `include_top`\n            is `False`.\n            - `None` (default) means that the output of the model will be\n                the 4D tensor output of the last convolutional block.\n            - `avg` means that global average pooling\n                will be applied to the output of the\n                last convolutional block, and thus\n                the output of the model will be a 2D tensor.\n            - `max` means that global max pooling will be applied.\n        classes: Optional number of classes to classify images into,\n            only to be specified if `include_top` is `True`, and if\n            no `weights` argument is specified. Defaults to `1000`.\n        classifier_activation: A `str` or callable. The activation function\n            to use on the \"top\" layer. Ignored unless `include_top=True`.\n            Set `classifier_activation=None` to return the logits of the \"top\"\n            layer. When loading pretrained weights, `classifier_activation`\n            can only be `None` or `\"softmax\"`.\n        name: String, the name of the model.\n\n    Returns:\n        A model instance.\n    \"\"\"\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), 'imagenet' \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded. \"\n            f\"Received weights={weights}\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            \"If using `weights='imagenet'` with `include_top=True`, \"\n            \"`classes` should be 1000.  
\"\n            f\"Received classes={classes}\"\n        )\n\n    # Determine proper input shape and default size.\n    if input_shape is None:\n        default_size = 224\n    else:\n        if backend.image_data_format() == \"channels_first\":\n            rows = input_shape[1]\n            cols = input_shape[2]\n        else:\n            rows = input_shape[0]\n            cols = input_shape[1]\n\n        if rows == cols and rows in [128, 160, 192, 224]:\n            default_size = rows\n        else:\n            default_size = 224\n\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=default_size,\n        min_size=32,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if backend.image_data_format() == \"channels_last\":\n        row_axis, col_axis = (0, 1)\n    else:\n        row_axis, col_axis = (1, 2)\n    rows = input_shape[row_axis]\n    cols = input_shape[col_axis]\n\n    if weights == \"imagenet\":\n        if depth_multiplier != 1:\n            raise ValueError(\n                \"If imagenet weights are being loaded, \"\n                \"depth multiplier must be 1.  \"\n                f\"Received depth_multiplier={depth_multiplier}\"\n            )\n\n        if alpha not in [0.25, 0.50, 0.75, 1.0]:\n            raise ValueError(\n                \"If imagenet weights are being loaded, \"\n                \"alpha can be one of \"\n                \"`0.25`, `0.50`, `0.75` or `1.0` only.  \"\n                f\"Received alpha={alpha}\"\n            )\n\n        if rows != cols or rows not in [128, 160, 192, 224]:\n            rows = 224\n            warnings.warn(\n                \"`input_shape` is undefined or non-square, \"\n                \"or `rows` is not in [128, 160, 192, 224]. 
\"\n                \"Weights for input shape (224, 224) will be \"\n                \"loaded as the default.\",\n                stacklevel=2,\n            )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n\n    x = _conv_block(img_input, 32, alpha, strides=(2, 2))\n    x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1)\n\n    x = _depthwise_conv_block(\n        x, 128, alpha, depth_multiplier, strides=(2, 2), block_id=2\n    )\n    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3)\n\n    x = _depthwise_conv_block(\n        x, 256, alpha, depth_multiplier, strides=(2, 2), block_id=4\n    )\n    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5)\n\n    x = _depthwise_conv_block(\n        x, 512, alpha, depth_multiplier, strides=(2, 2), block_id=6\n    )\n    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7)\n    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8)\n    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9)\n    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10)\n    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11)\n\n    x = _depthwise_conv_block(\n        x, 1024, alpha, depth_multiplier, strides=(2, 2), block_id=12\n    )\n    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13)\n\n    if include_top:\n        x = layers.GlobalAveragePooling2D(keepdims=True)(x)\n        x = layers.Dropout(dropout, name=\"dropout\")(x)\n        x = layers.Conv2D(classes, (1, 1), padding=\"same\", name=\"conv_preds\")(x)\n        x = layers.Reshape((classes,), name=\"reshape_2\")(x)\n        imagenet_utils.validate_activation(classifier_activation, weights)\n 
       x = layers.Activation(\n            activation=classifier_activation, name=\"predictions\"\n        )(x)\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D()(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D()(x)\n\n    # Ensure that the model takes into account\n    # any potential predecessors of `input_tensor`.\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n\n    # Create model.\n    if name is None:\n        name = f\"mobilenet_{alpha:0.2f}_{rows}\"\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if weights == \"imagenet\":\n        if alpha == 1.0:\n            alpha_text = \"1_0\"\n        elif alpha == 0.75:\n            alpha_text = \"7_5\"\n        elif alpha == 0.50:\n            alpha_text = \"5_0\"\n        else:\n            alpha_text = \"2_5\"\n\n        if include_top:\n            model_name = \"mobilenet_%s_%d_tf.h5\" % (alpha_text, rows)\n            weight_path = BASE_WEIGHT_PATH + model_name\n            weights_path = file_utils.get_file(\n                model_name, weight_path, cache_subdir=\"models\"\n            )\n        else:\n            model_name = \"mobilenet_%s_%d_tf_no_top.h5\" % (alpha_text, rows)\n            weight_path = BASE_WEIGHT_PATH + model_name\n            weights_path = file_utils.get_file(\n                model_name, weight_path, cache_subdir=\"models\"\n            )\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n\n    return model\n\n\ndef _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):\n    \"\"\"Adds an initial convolution layer (with batch normalization and relu6).\n\n    Args:\n        inputs: Input tensor of shape `(rows, cols, 3)` (with `channels_last`\n            data format) or (3, rows, cols) (with `channels_first` data format).\n  
          It should have exactly 3 input channels, and width and height\n            should be no smaller than 32. E.g. `(224, 224, 3)` would be\n            one valid value.\n        filters: Integer, the dimensionality of the output space (i.e. the\n            number of output filters in the convolution).\n        alpha: controls the width of the network.\n            - If `alpha` < 1.0, proportionally decreases the number of\n                filters in each layer.\n            - If `alpha` > 1.0, proportionally increases the number of filters\n                in each layer.\n            - If `alpha` = 1, default number of filters from the paper are\n                used at each layer.\n        kernel: An integer or tuple/list of 2 integers, specifying the width\n            and height of the 2D convolution window.\n            Can be a single integer to specify the same value for\n            all spatial dimensions.\n        strides: An integer or tuple/list of 2 integers, specifying the strides\n            of the convolution along the width and height.\n            Can be a single integer to specify the same value for all\n            spatial dimensions. Specifying any stride value != 1 is\n            incompatible with specifying any `dilation_rate`\n            value != 1.\n\n    Input shape:\n        4D tensor with shape: `(samples, channels, rows, cols)` if\n            data_format='channels_first'\n        or 4D tensor with shape: `(samples, rows, cols, channels)` if\n            data_format='channels_last'.\n\n    Output shape:\n        4D tensor with shape: `(samples, filters, new_rows, new_cols)`\n            if data_format='channels_first'\n        or 4D tensor with shape: `(samples, new_rows, new_cols, filters)`\n            if data_format='channels_last'. 
`rows` and `cols` values\n            might have changed due to stride.\n\n    Returns:\n        Output tensor of block.\n    \"\"\"\n    channel_axis = 1 if backend.image_data_format() == \"channels_first\" else -1\n    filters = int(filters * alpha)\n    x = layers.Conv2D(\n        filters,\n        kernel,\n        padding=\"same\",\n        use_bias=False,\n        strides=strides,\n        name=\"conv1\",\n    )(inputs)\n    x = layers.BatchNormalization(axis=channel_axis, name=\"conv1_bn\")(x)\n    return layers.ReLU(6.0, name=\"conv1_relu\")(x)\n\n\ndef _depthwise_conv_block(\n    inputs,\n    pointwise_conv_filters,\n    alpha,\n    depth_multiplier=1,\n    strides=(1, 1),\n    block_id=1,\n):\n    \"\"\"Adds a depthwise convolution block.\n\n    A depthwise convolution block consists of a depthwise conv,\n    batch normalization, relu6, pointwise convolution,\n    batch normalization and relu6 activation.\n\n    Args:\n        inputs: Input tensor of shape `(rows, cols, channels)` (with\n            `channels_last` data format) or (channels, rows, cols) (with\n            `channels_first` data format).\n        pointwise_conv_filters: Integer, the dimensionality of the output space\n            (i.e. the number of output filters in the pointwise convolution).\n        alpha: controls the width of the network. - If `alpha` < 1.0,\n            proportionally decreases the number of filters in each layer.\n            - If `alpha` > 1.0, proportionally increases the number of filters\n                in each layer.\n            - If `alpha` = 1, default number of filters from the paper are\n                used at each layer.\n        depth_multiplier: The number of depthwise convolution output channels\n            for each input channel. 
The total number of depthwise convolution\n            output channels will be equal to `filters_in * depth_multiplier`.\n        strides: An integer or tuple/list of 2 integers, specifying the strides\n            of the convolution along the width and height.\n            Can be a single integer to specify the same value for\n            all spatial dimensions. Specifying any stride value != 1 is\n            incompatible with specifying any `dilation_rate` value != 1.\n        block_id: Integer, a unique identification designating the block number.\n\n    Input shape:\n        4D tensor with shape: `(batch, channels, rows, cols)` if\n            data_format='channels_first'\n        or 4D tensor with shape: `(batch, rows, cols, channels)` if\n            data_format='channels_last'.\n\n    Output shape:\n        4D tensor with shape: `(batch, filters, new_rows, new_cols)` if\n            data_format='channels_first'\n        or 4D tensor with shape: `(batch, new_rows, new_cols, filters)` if\n            data_format='channels_last'. 
`rows` and `cols` values might have\n            changed due to stride.\n\n    Returns:\n        Output tensor of block.\n    \"\"\"\n    channel_axis = 1 if backend.image_data_format() == \"channels_first\" else -1\n    pointwise_conv_filters = int(pointwise_conv_filters * alpha)\n\n    if strides == (1, 1):\n        x = inputs\n    else:\n        x = layers.ZeroPadding2D(\n            ((0, 1), (0, 1)), name=\"conv_pad_%d\" % block_id\n        )(inputs)\n    x = layers.DepthwiseConv2D(\n        (3, 3),\n        padding=\"same\" if strides == (1, 1) else \"valid\",\n        depth_multiplier=depth_multiplier,\n        strides=strides,\n        use_bias=False,\n        name=\"conv_dw_%d\" % block_id,\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis, name=\"conv_dw_%d_bn\" % block_id\n    )(x)\n    x = layers.ReLU(6.0, name=\"conv_dw_%d_relu\" % block_id)(x)\n\n    x = layers.Conv2D(\n        pointwise_conv_filters,\n        (1, 1),\n        padding=\"same\",\n        use_bias=False,\n        strides=(1, 1),\n        name=\"conv_pw_%d\" % block_id,\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis, name=\"conv_pw_%d_bn\" % block_id\n    )(x)\n    return layers.ReLU(6.0, name=\"conv_pw_%d_relu\" % block_id)(x)\n\n\n@keras_export(\"keras.applications.mobilenet.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    return imagenet_utils.preprocess_input(\n        x, data_format=data_format, mode=\"tf\"\n    )\n\n\n@keras_export(\"keras.applications.mobilenet.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\npreprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(\n    mode=\"\",\n    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,\n    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,\n)\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n"
  },
  {
    "path": "keras/src/applications/mobilenet_v2.py",
    "content": "import warnings\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nBASE_WEIGHT_PATH = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/\"\n)\n\n\n@keras_export(\n    [\n        \"keras.applications.mobilenet_v2.MobileNetV2\",\n        \"keras.applications.MobileNetV2\",\n    ]\n)\ndef MobileNetV2(\n    input_shape=None,\n    alpha=1.0,\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=None,\n):\n    \"\"\"Instantiates the MobileNetV2 architecture.\n\n    MobileNetV2 is very similar to the original MobileNet,\n    except that it uses inverted residual blocks with\n    bottlenecking features. It has a drastically lower\n    parameter count than the original MobileNet.\n    MobileNets support any input size greater\n    than 32 x 32, with larger image sizes\n    offering better performance.\n\n    Reference:\n    - [MobileNetV2: Inverted Residuals and Linear Bottlenecks](\n        https://arxiv.org/abs/1801.04381) (CVPR 2018)\n\n    This function returns a Keras image classification model,\n    optionally loaded with weights pre-trained on ImageNet.\n\n    For image classification use cases, see\n    [this page for detailed examples](\n      https://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\n    For transfer learning use cases, make sure to read the\n    [guide to transfer learning & fine-tuning](\n      https://keras.io/guides/transfer_learning/).\n\n    Note: each Keras Application expects a specific kind of input preprocessing.\n    For MobileNetV2, call\n    `keras.applications.mobilenet_v2.preprocess_input`\n    on your inputs before 
passing them to the model.\n    `mobilenet_v2.preprocess_input` will scale input pixels between -1 and 1.\n\n    Args:\n        input_shape: Optional shape tuple, only to be specified if `include_top`\n            is `False` (otherwise the input shape has to be `(224, 224, 3)`\n            (with `\"channels_last\"` data format) or `(3, 224, 224)`\n            (with `\"channels_first\"` data format).\n            It should have exactly 3 input channels, and width and\n            height should be no smaller than 32. E.g. `(200, 200, 3)` would\n            be one valid value. Defaults to `None`.\n            `input_shape` will be ignored if the `input_tensor` is provided.\n        alpha: Controls the width of the network. This is known as the width\n            multiplier in the MobileNet paper.\n            - If `alpha < 1.0`, proportionally decreases the number\n                of filters in each layer.\n            - If `alpha > 1.0`, proportionally increases the number\n                of filters in each layer.\n            - If `alpha == 1`, default number of filters from the paper\n                are used at each layer. Defaults to `1.0`.\n        include_top: Boolean, whether to include the fully-connected layer\n            at the top of the network. Defaults to `True`.\n        weights: One of `None` (random initialization), `\"imagenet\"`\n            (pre-training on ImageNet), or the path to the weights file\n            to be loaded. Defaults to `\"imagenet\"`.\n        input_tensor: Optional Keras tensor (i.e. output of `layers.Input()`)\n            to use as image input for the model. 
`input_tensor` is useful\n            for sharing inputs between multiple different networks.\n            Defaults to `None`.\n        pooling: Optional pooling mode for feature extraction when `include_top`\n            is `False`.\n            - `None` (default) means that the output of the model will be\n                the 4D tensor output of the last convolutional block.\n            - `avg` means that global average pooling\n                will be applied to the output of the\n                last convolutional block, and thus\n                the output of the model will be a 2D tensor.\n            - `max` means that global max pooling will be applied.\n        classes: Optional number of classes to classify images into,\n            only to be specified if `include_top` is `True`, and if\n            no `weights` argument is specified. Defaults to `1000`.\n        classifier_activation: A `str` or callable. The activation function\n            to use on the \"top\" layer. Ignored unless `include_top=True`.\n            Set `classifier_activation=None` to return the logits of the \"top\"\n            layer. When loading pretrained weights, `classifier_activation`\n            can only be `None` or `\"softmax\"`.\n        name: String, the name of the model.\n\n    Returns:\n        A model instance.\n    \"\"\"\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), `imagenet` \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded.  \"\n            f\"Received `weights={weights}`\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            'If using `weights=\"imagenet\"` with `include_top` '\n            f\"as true, `classes` should be 1000. 
Received `classes={classes}`\"\n        )\n\n    # Determine proper input shape and default size.\n    # If both input_shape and input_tensor are used, they should match\n    if input_shape is not None and input_tensor is not None:\n        try:\n            is_input_t_tensor = backend.is_keras_tensor(input_tensor)\n        except ValueError:\n            try:\n                is_input_t_tensor = backend.is_keras_tensor(\n                    operation_utils.get_source_inputs(input_tensor)\n                )\n            except ValueError:\n                raise ValueError(\n                    f\"input_tensor: {input_tensor} \"\n                    \"is not a valid input tensor. \"\n                    f\"Received `type(input_tensor)={type(input_tensor)}`\"\n                )\n        if is_input_t_tensor:\n            if backend.image_data_format() == \"channels_first\":\n                if input_tensor.shape[1] != input_shape[1]:\n                    raise ValueError(\n                        \"input_shape[1] must equal shape(input_tensor)[1] \"\n                        \"when `image_data_format` is `channels_first`; \"\n                        \"Received `input_tensor.shape=\"\n                        f\"{input_tensor.shape}`\"\n                        f\", `input_shape={input_shape}`\"\n                    )\n            else:\n                if input_tensor.shape[2] != input_shape[1]:\n                    raise ValueError(\n                        \"input_tensor.shape[2] must equal input_shape[1]; \"\n                        \"Received `input_tensor.shape=\"\n                        f\"{input_tensor.shape}`, \"\n                        f\"`input_shape={input_shape}`\"\n                    )\n        else:\n            raise ValueError(\n                \"input_tensor is not a Keras tensor; \"\n                f\"Received `input_tensor={input_tensor}`\"\n            )\n\n    # If input_shape is None, infer shape from input_tensor.\n    if input_shape is None and 
input_tensor is not None:\n        try:\n            backend.is_keras_tensor(input_tensor)\n        except ValueError:\n            raise ValueError(\n                \"input_tensor must be a valid Keras tensor type; \"\n                f\"Received {input_tensor} of type {type(input_tensor)}\"\n            )\n\n        if input_shape is None and not backend.is_keras_tensor(input_tensor):\n            default_size = 224\n        elif input_shape is None and backend.is_keras_tensor(input_tensor):\n            if backend.image_data_format() == \"channels_first\":\n                rows = input_tensor.shape[2]\n                cols = input_tensor.shape[3]\n            else:\n                rows = input_tensor.shape[1]\n                cols = input_tensor.shape[2]\n\n            if rows == cols and rows in [96, 128, 160, 192, 224]:\n                default_size = rows\n            else:\n                default_size = 224\n\n    # If input_shape is None and no input_tensor\n    elif input_shape is None:\n        default_size = 224\n\n    # If input_shape is not None, assume default size.\n    else:\n        if backend.image_data_format() == \"channels_first\":\n            rows = input_shape[1]\n            cols = input_shape[2]\n        else:\n            rows = input_shape[0]\n            cols = input_shape[1]\n\n        if rows == cols and rows in [96, 128, 160, 192, 224]:\n            default_size = rows\n        else:\n            default_size = 224\n\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=default_size,\n        min_size=32,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if backend.image_data_format() == \"channels_last\":\n        row_axis, col_axis = (0, 1)\n    else:\n        row_axis, col_axis = (1, 2)\n    rows = input_shape[row_axis]\n    cols = input_shape[col_axis]\n\n    if weights == \"imagenet\":\n        if alpha 
not in [0.35, 0.50, 0.75, 1.0, 1.3, 1.4]:\n            raise ValueError(\n                \"If imagenet weights are being loaded, \"\n                \"alpha must be one of `0.35`, `0.50`, `0.75`, \"\n                \"`1.0`, `1.3` or `1.4` only;\"\n                f\" Received `alpha={alpha}`\"\n            )\n\n        if rows != cols or rows not in [96, 128, 160, 192, 224]:\n            rows = 224\n            warnings.warn(\n                \"`input_shape` is undefined or non-square, \"\n                \"or `rows` is not in [96, 128, 160, 192, 224]. \"\n                \"Weights for input shape (224, 224) will be \"\n                \"loaded as the default.\",\n                stacklevel=2,\n            )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n\n    channel_axis = 1 if backend.image_data_format() == \"channels_first\" else -1\n\n    first_block_filters = _make_divisible(32 * alpha, 8)\n    x = layers.Conv2D(\n        first_block_filters,\n        kernel_size=3,\n        strides=(2, 2),\n        padding=\"same\",\n        use_bias=False,\n        name=\"Conv1\",\n    )(img_input)\n    x = layers.BatchNormalization(\n        axis=channel_axis, epsilon=1e-3, momentum=0.999, name=\"bn_Conv1\"\n    )(x)\n    x = layers.ReLU(6.0, name=\"Conv1_relu\")(x)\n\n    x = _inverted_res_block(\n        x, filters=16, alpha=alpha, stride=1, expansion=1, block_id=0\n    )\n\n    x = _inverted_res_block(\n        x, filters=24, alpha=alpha, stride=2, expansion=6, block_id=1\n    )\n    x = _inverted_res_block(\n        x, filters=24, alpha=alpha, stride=1, expansion=6, block_id=2\n    )\n\n    x = _inverted_res_block(\n        x, filters=32, alpha=alpha, stride=2, expansion=6, block_id=3\n    )\n    x = _inverted_res_block(\n        x, 
filters=32, alpha=alpha, stride=1, expansion=6, block_id=4\n    )\n    x = _inverted_res_block(\n        x, filters=32, alpha=alpha, stride=1, expansion=6, block_id=5\n    )\n\n    x = _inverted_res_block(\n        x, filters=64, alpha=alpha, stride=2, expansion=6, block_id=6\n    )\n    x = _inverted_res_block(\n        x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=7\n    )\n    x = _inverted_res_block(\n        x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=8\n    )\n    x = _inverted_res_block(\n        x, filters=64, alpha=alpha, stride=1, expansion=6, block_id=9\n    )\n\n    x = _inverted_res_block(\n        x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=10\n    )\n    x = _inverted_res_block(\n        x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=11\n    )\n    x = _inverted_res_block(\n        x, filters=96, alpha=alpha, stride=1, expansion=6, block_id=12\n    )\n\n    x = _inverted_res_block(\n        x, filters=160, alpha=alpha, stride=2, expansion=6, block_id=13\n    )\n    x = _inverted_res_block(\n        x, filters=160, alpha=alpha, stride=1, expansion=6, block_id=14\n    )\n    x = _inverted_res_block(\n        x, filters=160, alpha=alpha, stride=1, expansion=6, block_id=15\n    )\n\n    x = _inverted_res_block(\n        x, filters=320, alpha=alpha, stride=1, expansion=6, block_id=16\n    )\n\n    # no alpha applied to last conv as stated in the paper:\n    # if the width multiplier is greater than 1 we increase the number of output\n    # channels.\n    if alpha > 1.0:\n        last_block_filters = _make_divisible(1280 * alpha, 8)\n    else:\n        last_block_filters = 1280\n\n    x = layers.Conv2D(\n        last_block_filters, kernel_size=1, use_bias=False, name=\"Conv_1\"\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis, epsilon=1e-3, momentum=0.999, name=\"Conv_1_bn\"\n    )(x)\n    x = layers.ReLU(6.0, name=\"out_relu\")(x)\n\n    if include_top:\n        x = 
layers.GlobalAveragePooling2D()(x)\n        imagenet_utils.validate_activation(classifier_activation, weights)\n        x = layers.Dense(\n            classes, activation=classifier_activation, name=\"predictions\"\n        )(x)\n\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D()(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D()(x)\n\n    # Ensure that the model takes into account any potential predecessors of\n    # `input_tensor`.\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n\n    # Create model.\n    if name is None:\n        name = f\"mobilenetv2_{alpha:0.2f}_{rows}\"\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if weights == \"imagenet\":\n        if include_top:\n            model_name = (\n                \"mobilenet_v2_weights_tf_dim_ordering_tf_kernels\"\n                f\"_{float(alpha)}_{rows}.h5\"\n            )\n            weight_path = BASE_WEIGHT_PATH + model_name\n            weights_path = file_utils.get_file(\n                model_name, weight_path, cache_subdir=\"models\"\n            )\n        else:\n            model_name = (\n                \"mobilenet_v2_weights_tf_dim_ordering_tf_kernels_\"\n                f\"{float(alpha)}_{rows}_no_top.h5\"\n            )\n            weight_path = BASE_WEIGHT_PATH + model_name\n            weights_path = file_utils.get_file(\n                model_name, weight_path, cache_subdir=\"models\"\n            )\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n\n    return model\n\n\ndef _inverted_res_block(inputs, expansion, stride, alpha, filters, block_id):\n    \"\"\"Inverted ResNet block.\"\"\"\n    channel_axis = 1 if backend.image_data_format() == \"channels_first\" else -1\n\n    in_channels = inputs.shape[channel_axis]\n    
pointwise_conv_filters = int(filters * alpha)\n    # Ensure the number of filters on the last 1x1 convolution is divisible by\n    # 8.\n    pointwise_filters = _make_divisible(pointwise_conv_filters, 8)\n    x = inputs\n    prefix = f\"block_{block_id}_\"\n\n    if block_id:\n        # Expand with a pointwise 1x1 convolution.\n        x = layers.Conv2D(\n            expansion * in_channels,\n            kernel_size=1,\n            padding=\"same\",\n            use_bias=False,\n            activation=None,\n            name=f\"{prefix}expand\",\n        )(x)\n        x = layers.BatchNormalization(\n            axis=channel_axis,\n            epsilon=1e-3,\n            momentum=0.999,\n            name=f\"{prefix}expand_BN\",\n        )(x)\n        x = layers.ReLU(6.0, name=f\"{prefix}expand_relu\")(x)\n    else:\n        prefix = \"expanded_conv_\"\n\n    # Depthwise 3x3 convolution.\n    if stride == 2:\n        x = layers.ZeroPadding2D(\n            padding=imagenet_utils.correct_pad(x, 3), name=f\"{prefix}pad\"\n        )(x)\n    x = layers.DepthwiseConv2D(\n        kernel_size=3,\n        strides=stride,\n        activation=None,\n        use_bias=False,\n        padding=\"same\" if stride == 1 else \"valid\",\n        name=f\"{prefix}depthwise\",\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis,\n        epsilon=1e-3,\n        momentum=0.999,\n        name=f\"{prefix}depthwise_BN\",\n    )(x)\n\n    x = layers.ReLU(6.0, name=f\"{prefix}depthwise_relu\")(x)\n\n    # Project with a pointwise 1x1 convolution.\n    x = layers.Conv2D(\n        pointwise_filters,\n        kernel_size=1,\n        padding=\"same\",\n        use_bias=False,\n        activation=None,\n        name=f\"{prefix}project\",\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis,\n        epsilon=1e-3,\n        momentum=0.999,\n        name=f\"{prefix}project_BN\",\n    )(x)\n\n    if in_channels == pointwise_filters and stride == 1:\n        return 
layers.Add(name=f\"{prefix}add\")([inputs, x])\n    return x\n\n\ndef _make_divisible(v, divisor, min_value=None):\n    if min_value is None:\n        min_value = divisor\n    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)\n    # Make sure that round down does not go down by more than 10%.\n    if new_v < 0.9 * v:\n        new_v += divisor\n    return new_v\n\n\n@keras_export(\"keras.applications.mobilenet_v2.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    return imagenet_utils.preprocess_input(\n        x, data_format=data_format, mode=\"tf\"\n    )\n\n\n@keras_export(\"keras.applications.mobilenet_v2.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\npreprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(\n    mode=\"\",\n    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,\n    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,\n)\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n"
  },
  {
    "path": "keras/src/applications/mobilenet_v3.py",
    "content": "import warnings\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nBASE_WEIGHT_PATH = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v3/\"\n)\nWEIGHTS_HASHES = {\n    \"large_224_0.75_float\": (\n        \"765b44a33ad4005b3ac83185abf1d0eb\",\n        \"40af19a13ebea4e2ee0c676887f69a2e\",\n    ),\n    \"large_224_1.0_float\": (\n        \"59e551e166be033d707958cf9e29a6a7\",\n        \"07fb09a5933dd0c8eaafa16978110389\",\n    ),\n    \"large_minimalistic_224_1.0_float\": (\n        \"675e7b876c45c57e9e63e6d90a36599c\",\n        \"ec5221f64a2f6d1ef965a614bdae7973\",\n    ),\n    \"small_224_0.75_float\": (\n        \"cb65d4e5be93758266aa0a7f2c6708b7\",\n        \"ebdb5cc8e0b497cd13a7c275d475c819\",\n    ),\n    \"small_224_1.0_float\": (\n        \"8768d4c2e7dee89b9d02b2d03d65d862\",\n        \"d3e8ec802a04aa4fc771ee12a9a9b836\",\n    ),\n    \"small_minimalistic_224_1.0_float\": (\n        \"99cd97fb2fcdad2bf028eb838de69e37\",\n        \"cde8136e733e811080d9fcd8a252f7e4\",\n    ),\n}\n\n\nBASE_DOCSTRING = \"\"\"Instantiates the {name} architecture.\n\nReference:\n- [Searching for MobileNetV3](\n    https://arxiv.org/pdf/1905.02244.pdf) (ICCV 2019)\n\nThe following table describes the performance of MobileNets v3:\n------------------------------------------------------------------------\nMACs stands for Multiply Adds\n\n|Classification Checkpoint|MACs(M)|Parameters(M)|Top1 Accuracy|Pixel1 CPU(ms)|\n|---|---|---|---|---|\n| mobilenet_v3_large_1.0_224              | 217 | 5.4 |   75.6   |   51.2  |\n| mobilenet_v3_large_0.75_224             | 155 | 4.0 |   73.3   |   39.8  |\n| mobilenet_v3_large_minimalistic_1.0_224 | 209 | 3.9 |   72.3   |   44.1  |\n| mobilenet_v3_small_1.0_224          
    | 66  | 2.9 |   68.1   |   15.8  |\n| mobilenet_v3_small_0.75_224             | 44  | 2.4 |   65.4   |   12.8  |\n| mobilenet_v3_small_minimalistic_1.0_224 | 65  | 2.0 |   61.9   |   12.2  |\n\nFor image classification use cases, see\n[this page for detailed examples](\nhttps://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\nFor transfer learning use cases, make sure to read the\n[guide to transfer learning & fine-tuning](\nhttps://keras.io/guides/transfer_learning/).\n\nNote: each Keras Application expects a specific kind of input preprocessing.\nFor MobileNetV3, by default input preprocessing is included as a part of the\nmodel (as a `Rescaling` layer), and thus\n`keras.applications.mobilenet_v3.preprocess_input` is actually a\npass-through function. In this use case, MobileNetV3 models expect their\ninputs to be float tensors of pixels with values in the `[0-255]` range.\nAt the same time, preprocessing as a part of the model (i.e. `Rescaling`\nlayer) can be disabled by setting the `include_preprocessing` argument to\n`False`. With preprocessing disabled, MobileNetV3 models expect their inputs\nto be float tensors of pixels with values in the `[-1, 1]` range.\n\nArgs:\n    input_shape: Optional shape tuple, to be specified if you would\n        like to use a model with an input image resolution that is not\n        `(224, 224, 3)`.\n        It should have exactly 3 input channels.\n        You can also omit this option if you would like\n        to infer input_shape from an input_tensor.\n        If you choose to include both input_tensor and input_shape, then\n        input_shape will be used if they match; if the shapes\n        do not match, an error will be raised.\n        E.g. `(160, 160, 3)` would be one valid value.\n    alpha: controls the width of the network. 
This is known as the\n        depth multiplier in the MobileNetV3 paper, but the name is kept for\n        consistency with MobileNetV1 in Keras.\n        When `weights` is `imagenet`, `alpha` can be one of `0.75` or `1.0`\n        for non-minimalistic models, and `1.0` for minimalistic models.\n        - If `alpha < 1.0`, proportionally decreases the number\n            of filters in each layer.\n        - If `alpha > 1.0`, proportionally increases the number\n            of filters in each layer.\n        - If `alpha == 1`, the default number of filters from the paper\n            is used at each layer.\n    minimalistic: In addition to large and small models, this module also\n        contains so-called minimalistic models; these models have the same\n        per-layer dimensions as MobileNetV3, but they don't\n        utilize any of the advanced blocks (squeeze-and-excite units,\n        hard-swish, and 5x5 convolutions).\n        While these models are less efficient on CPU, they\n        are much more performant on GPU/DSP.\n    include_top: Boolean, whether to include the fully-connected\n        layer at the top of the network. Defaults to `True`.\n    weights: String, one of `None` (random initialization),\n        `\"imagenet\"` (pre-training on ImageNet),\n        or the path to the weights file to be loaded.\n    input_tensor: Optional Keras tensor (i.e. 
output of\n        `layers.Input()`)\n        to use as image input for the model.\n    pooling: String, optional pooling mode for feature extraction\n        when `include_top` is `False`.\n        - `None` means that the output of the model\n            will be the 4D tensor output of the\n            last convolutional block.\n        - `avg` means that global average pooling\n            will be applied to the output of the\n            last convolutional block, and thus\n            the output of the model will be a\n            2D tensor.\n        - `max` means that global max pooling will\n            be applied.\n    classes: Integer, optional number of classes to classify images\n        into, only to be specified if `include_top` is `True`, and\n        if no `weights` argument is specified.\n    dropout_rate: fraction of the input units to drop on the last layer.\n    classifier_activation: A `str` or callable. The activation function to use\n        on the \"top\" layer. Ignored unless `include_top=True`. Set\n        `classifier_activation=None` to return the logits of the \"top\" layer.\n        When loading pretrained weights, `classifier_activation` can only\n        be `None` or `\"softmax\"`.\n    include_preprocessing: Boolean, whether to include the preprocessing\n        layer (`Rescaling`) at the bottom of the network. 
Defaults to `True`.\n    name: String, the name of the model.\n\nCall arguments:\n    inputs: A floating point `numpy.array` or backend-native tensor,\n        4D with 3 color channels, with values in the range `[0, 255]`\n        if `include_preprocessing` is `True` and in the range `[-1, 1]`\n        otherwise.\n\nReturns:\n    A model instance.\n\"\"\"\n\n\ndef MobileNetV3(\n    stack_fn,\n    last_point_ch,\n    input_shape=None,\n    alpha=1.0,\n    model_type=\"large\",\n    minimalistic=False,\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    classes=1000,\n    pooling=None,\n    dropout_rate=0.2,\n    classifier_activation=\"softmax\",\n    include_preprocessing=True,\n    name=None,\n):\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), `imagenet` \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded.  \"\n            f\"Received weights={weights}\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            'If using `weights=\"imagenet\"` with `include_top` '\n            \"as true, `classes` should be 1000.  
\"\n            f\"Received classes={classes}\"\n        )\n\n    # Determine proper input shape and default size.\n    # If both input_shape and input_tensor are used, they should match\n    if input_shape is not None and input_tensor is not None:\n        try:\n            is_input_t_tensor = backend.is_keras_tensor(input_tensor)\n        except ValueError:\n            try:\n                is_input_t_tensor = backend.is_keras_tensor(\n                    operation_utils.get_source_inputs(input_tensor)\n                )\n            except ValueError:\n                raise ValueError(\n                    \"input_tensor: \",\n                    input_tensor,\n                    \"is not type input_tensor.  \"\n                    f\"Received type(input_tensor)={type(input_tensor)}\",\n                )\n        if is_input_t_tensor:\n            if backend.image_data_format() == \"channels_first\":\n                if input_tensor.shape[1] != input_shape[1]:\n                    raise ValueError(\n                        \"When backend.image_data_format()=channels_first, \"\n                        \"input_shape[1] must equal \"\n                        \"input_tensor.shape[1].  Received \"\n                        f\"input_shape={input_shape}, \"\n                        \"input_tensor.shape=\"\n                        f\"{input_tensor.shape}\"\n                    )\n            else:\n                if input_tensor.shape[2] != input_shape[1]:\n                    raise ValueError(\n                        \"input_shape[1] must equal \"\n                        \"input_tensor.shape[2].  
Received \"\n                        f\"input_shape={input_shape}, \"\n                        \"input_tensor.shape=\"\n                        f\"{input_tensor.shape}\"\n                    )\n        else:\n            raise ValueError(\n                \"input_tensor specified: \",\n                input_tensor,\n                \"is not a keras tensor\",\n            )\n\n    # If input_shape is None, infer shape from input_tensor\n    if input_shape is None and input_tensor is not None:\n        try:\n            backend.is_keras_tensor(input_tensor)\n        except ValueError:\n            raise ValueError(\n                \"input_tensor: \",\n                input_tensor,\n                \"is type: \",\n                type(input_tensor),\n                \"which is not a valid type\",\n            )\n\n        if backend.is_keras_tensor(input_tensor):\n            if backend.image_data_format() == \"channels_first\":\n                rows = input_tensor.shape[2]\n                cols = input_tensor.shape[3]\n                input_shape = (3, cols, rows)\n            else:\n                rows = input_tensor.shape[1]\n                cols = input_tensor.shape[2]\n                input_shape = (cols, rows, 3)\n    # If input_shape is None and input_tensor is None using standard shape\n    if input_shape is None and input_tensor is None:\n        if backend.image_data_format() == \"channels_last\":\n            input_shape = (None, None, 3)\n        else:\n            input_shape = (3, None, None)\n\n    if backend.image_data_format() == \"channels_last\":\n        row_axis, col_axis = (0, 1)\n    else:\n        row_axis, col_axis = (1, 2)\n    rows = input_shape[row_axis]\n    cols = input_shape[col_axis]\n    if rows and cols and (rows < 32 or cols < 32):\n        raise ValueError(\n            \"Input size must be at least 32x32; Received `input_shape=\"\n            f\"{input_shape}`\"\n        )\n    if weights == \"imagenet\":\n        if (\n          
  not minimalistic\n            and alpha not in [0.75, 1.0]\n            or minimalistic\n            and alpha != 1.0\n        ):\n            raise ValueError(\n                \"If imagenet weights are being loaded, \"\n                \"alpha can be one of `0.75`, `1.0` for non minimalistic \"\n                \"or `1.0` for minimalistic only.\"\n            )\n\n        if rows != cols or rows != 224:\n            warnings.warn(\n                \"`input_shape` is undefined or non-square, \"\n                \"or `rows` is not 224. \"\n                \"Weights for input shape (224, 224) will be \"\n                \"loaded as the default.\",\n                stacklevel=2,\n            )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n\n    channel_axis = 1 if backend.image_data_format() == \"channels_first\" else -1\n\n    if minimalistic:\n        kernel = 3\n        activation = relu\n        se_ratio = None\n    else:\n        kernel = 5\n        activation = hard_swish\n        se_ratio = 0.25\n\n    x = img_input\n    if include_preprocessing:\n        x = layers.Rescaling(scale=1.0 / 127.5, offset=-1.0)(x)\n    x = layers.Conv2D(\n        16,\n        kernel_size=3,\n        strides=(2, 2),\n        padding=\"same\",\n        use_bias=False,\n        name=\"conv\",\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis, epsilon=1e-3, momentum=0.999, name=\"conv_bn\"\n    )(x)\n    x = activation(x)\n\n    x = stack_fn(x, kernel, activation, se_ratio)\n\n    last_conv_ch = _depth(x.shape[channel_axis] * 6)\n\n    # if the width multiplier is greater than 1 we\n    # increase the number of output channels\n    if alpha > 1.0:\n        last_point_ch = _depth(last_point_ch * alpha)\n    x = layers.Conv2D(\n  
      last_conv_ch,\n        kernel_size=1,\n        padding=\"same\",\n        use_bias=False,\n        name=\"conv_1\",\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis, epsilon=1e-3, momentum=0.999, name=\"conv_1_bn\"\n    )(x)\n    x = activation(x)\n    if include_top:\n        x = layers.GlobalAveragePooling2D(keepdims=True)(x)\n        x = layers.Conv2D(\n            last_point_ch,\n            kernel_size=1,\n            padding=\"same\",\n            use_bias=True,\n            name=\"conv_2\",\n        )(x)\n        x = activation(x)\n\n        if dropout_rate > 0:\n            x = layers.Dropout(dropout_rate)(x)\n        x = layers.Conv2D(\n            classes, kernel_size=1, padding=\"same\", name=\"logits\"\n        )(x)\n        x = layers.Flatten()(x)\n        imagenet_utils.validate_activation(classifier_activation, weights)\n        x = layers.Activation(\n            activation=classifier_activation, name=\"predictions\"\n        )(x)\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D(name=\"max_pool\")(x)\n    # Ensure that the model takes into account\n    # any potential predecessors of `input_tensor`.\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n\n    # Create model.\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if weights == \"imagenet\":\n        model_name = \"{}{}_224_{}_float\".format(\n            model_type, \"_minimalistic\" if minimalistic else \"\", str(alpha)\n        )\n        if include_top:\n            file_name = f\"weights_mobilenet_v3_{model_name}.h5\"\n            file_hash = WEIGHTS_HASHES[model_name][0]\n        else:\n            file_name = f\"weights_mobilenet_v3_{model_name}_no_top_v2.h5\"\n            file_hash = 
WEIGHTS_HASHES[model_name][1]\n        weights_path = file_utils.get_file(\n            file_name,\n            BASE_WEIGHT_PATH + file_name,\n            cache_subdir=\"models\",\n            file_hash=file_hash,\n        )\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n\n    return model\n\n\n@keras_export(\"keras.applications.MobileNetV3Small\")\ndef MobileNetV3Small(\n    input_shape=None,\n    alpha=1.0,\n    minimalistic=False,\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    classes=1000,\n    pooling=None,\n    dropout_rate=0.2,\n    classifier_activation=\"softmax\",\n    include_preprocessing=True,\n    name=\"MobileNetV3Small\",\n):\n    def stack_fn(x, kernel, activation, se_ratio):\n        def depth(d):\n            return _depth(d * alpha)\n\n        x = _inverted_res_block(x, 1, depth(16), 3, 2, se_ratio, relu, 0)\n        x = _inverted_res_block(x, 72.0 / 16, depth(24), 3, 2, None, relu, 1)\n        x = _inverted_res_block(x, 88.0 / 24, depth(24), 3, 1, None, relu, 2)\n        x = _inverted_res_block(\n            x, 4, depth(40), kernel, 2, se_ratio, activation, 3\n        )\n        x = _inverted_res_block(\n            x, 6, depth(40), kernel, 1, se_ratio, activation, 4\n        )\n        x = _inverted_res_block(\n            x, 6, depth(40), kernel, 1, se_ratio, activation, 5\n        )\n        x = _inverted_res_block(\n            x, 3, depth(48), kernel, 1, se_ratio, activation, 6\n        )\n        x = _inverted_res_block(\n            x, 3, depth(48), kernel, 1, se_ratio, activation, 7\n        )\n        x = _inverted_res_block(\n            x, 6, depth(96), kernel, 2, se_ratio, activation, 8\n        )\n        x = _inverted_res_block(\n            x, 6, depth(96), kernel, 1, se_ratio, activation, 9\n        )\n        x = _inverted_res_block(\n            x, 6, depth(96), kernel, 1, se_ratio, activation, 10\n        )\n        return x\n\n    
return MobileNetV3(\n        stack_fn,\n        1024,\n        input_shape,\n        alpha,\n        \"small\",\n        minimalistic,\n        include_top,\n        weights,\n        input_tensor,\n        classes,\n        pooling,\n        dropout_rate,\n        classifier_activation,\n        include_preprocessing,\n        name=name,\n    )\n\n\n@keras_export(\"keras.applications.MobileNetV3Large\")\ndef MobileNetV3Large(\n    input_shape=None,\n    alpha=1.0,\n    minimalistic=False,\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    classes=1000,\n    pooling=None,\n    dropout_rate=0.2,\n    classifier_activation=\"softmax\",\n    include_preprocessing=True,\n    name=\"MobileNetV3Large\",\n):\n    def stack_fn(x, kernel, activation, se_ratio):\n        def depth(d):\n            return _depth(d * alpha)\n\n        x = _inverted_res_block(x, 1, depth(16), 3, 1, None, relu, 0)\n        x = _inverted_res_block(x, 4, depth(24), 3, 2, None, relu, 1)\n        x = _inverted_res_block(x, 3, depth(24), 3, 1, None, relu, 2)\n        x = _inverted_res_block(x, 3, depth(40), kernel, 2, se_ratio, relu, 3)\n        x = _inverted_res_block(x, 3, depth(40), kernel, 1, se_ratio, relu, 4)\n        x = _inverted_res_block(x, 3, depth(40), kernel, 1, se_ratio, relu, 5)\n        x = _inverted_res_block(x, 6, depth(80), 3, 2, None, activation, 6)\n        x = _inverted_res_block(x, 2.5, depth(80), 3, 1, None, activation, 7)\n        x = _inverted_res_block(x, 2.3, depth(80), 3, 1, None, activation, 8)\n        x = _inverted_res_block(x, 2.3, depth(80), 3, 1, None, activation, 9)\n        x = _inverted_res_block(\n            x, 6, depth(112), 3, 1, se_ratio, activation, 10\n        )\n        x = _inverted_res_block(\n            x, 6, depth(112), 3, 1, se_ratio, activation, 11\n        )\n        x = _inverted_res_block(\n            x, 6, depth(160), kernel, 2, se_ratio, activation, 12\n        )\n        x = _inverted_res_block(\n            x, 6, 
depth(160), kernel, 1, se_ratio, activation, 13\n        )\n        x = _inverted_res_block(\n            x, 6, depth(160), kernel, 1, se_ratio, activation, 14\n        )\n        return x\n\n    return MobileNetV3(\n        stack_fn,\n        1280,\n        input_shape,\n        alpha,\n        \"large\",\n        minimalistic,\n        include_top,\n        weights,\n        input_tensor,\n        classes,\n        pooling,\n        dropout_rate,\n        classifier_activation,\n        include_preprocessing,\n        name=name,\n    )\n\n\nMobileNetV3Small.__doc__ = BASE_DOCSTRING.format(name=\"MobileNetV3Small\")\nMobileNetV3Large.__doc__ = BASE_DOCSTRING.format(name=\"MobileNetV3Large\")\n\n\ndef relu(x):\n    return layers.ReLU()(x)\n\n\ndef hard_sigmoid(x):\n    return layers.ReLU(6.0)(x + 3.0) * (1.0 / 6.0)\n\n\ndef hard_swish(x):\n    return layers.Activation(\"hard_swish\")(x)\n\n\n# This function is taken from the original tf repo.\n# It ensures that all layers have a channel number that is divisible by 8\n# It can be seen here:\n# https://github.com/tensorflow/models/blob/master/research/\n# slim/nets/mobilenet/mobilenet.py\n\n\ndef _depth(v, divisor=8, min_value=None):\n    if min_value is None:\n        min_value = divisor\n    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)\n    # Make sure that round down does not go down by more than 10%.\n    if new_v < 0.9 * v:\n        new_v += divisor\n    return new_v\n\n\ndef _se_block(inputs, filters, se_ratio, prefix):\n    x = layers.GlobalAveragePooling2D(\n        keepdims=True, name=f\"{prefix}squeeze_excite_avg_pool\"\n    )(inputs)\n    x = layers.Conv2D(\n        _depth(filters * se_ratio),\n        kernel_size=1,\n        padding=\"same\",\n        name=f\"{prefix}squeeze_excite_conv\",\n    )(x)\n    x = layers.ReLU(name=f\"{prefix}squeeze_excite_relu\")(x)\n    x = layers.Conv2D(\n        filters,\n        kernel_size=1,\n        padding=\"same\",\n        
name=f\"{prefix}squeeze_excite_conv_1\",\n    )(x)\n    x = hard_sigmoid(x)\n    x = layers.Multiply(name=f\"{prefix}squeeze_excite_mul\")([inputs, x])\n    return x\n\n\ndef _inverted_res_block(\n    x, expansion, filters, kernel_size, stride, se_ratio, activation, block_id\n):\n    channel_axis = 1 if backend.image_data_format() == \"channels_first\" else -1\n    shortcut = x\n    prefix = \"expanded_conv_\"\n    infilters = x.shape[channel_axis]\n    if block_id:\n        # Expand\n        prefix = f\"expanded_conv_{block_id}_\"\n        x = layers.Conv2D(\n            _depth(infilters * expansion),\n            kernel_size=1,\n            padding=\"same\",\n            use_bias=False,\n            name=f\"{prefix}expand\",\n        )(x)\n        x = layers.BatchNormalization(\n            axis=channel_axis,\n            epsilon=1e-3,\n            momentum=0.999,\n            name=f\"{prefix}expand_bn\",\n        )(x)\n        x = activation(x)\n\n    if stride == 2:\n        x = layers.ZeroPadding2D(\n            padding=imagenet_utils.correct_pad(x, kernel_size),\n            name=f\"{prefix}depthwise_pad\",\n        )(x)\n    x = layers.DepthwiseConv2D(\n        kernel_size,\n        strides=stride,\n        padding=\"same\" if stride == 1 else \"valid\",\n        use_bias=False,\n        name=f\"{prefix}depthwise\",\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis,\n        epsilon=1e-3,\n        momentum=0.999,\n        name=f\"{prefix}depthwise_bn\",\n    )(x)\n    x = activation(x)\n\n    if se_ratio:\n        x = _se_block(x, _depth(infilters * expansion), se_ratio, prefix)\n\n    x = layers.Conv2D(\n        filters,\n        kernel_size=1,\n        padding=\"same\",\n        use_bias=False,\n        name=f\"{prefix}project\",\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis,\n        epsilon=1e-3,\n        momentum=0.999,\n        name=f\"{prefix}project_bn\",\n    )(x)\n\n    if stride == 1 and infilters 
== filters:\n        x = layers.Add(name=f\"{prefix}add\")([shortcut, x])\n    return x\n\n\n@keras_export(\"keras.applications.mobilenet_v3.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    \"\"\"A placeholder method for backward compatibility.\n\n    The preprocessing logic has been included in the mobilenet_v3 model\n    implementation. Users are no longer required to call this method to\n    normalize the input data. This method does nothing and is only kept as a\n    placeholder to align the API surface between old and new versions of the\n    model.\n\n    Args:\n        x: A floating point `numpy.array` or a tensor.\n        data_format: Optional data format of the image tensor/array.\n            `None` means the global setting\n            `keras.config.image_data_format()` is used\n            (unless you changed it, it uses `\"channels_last\"`).\n            Defaults to `None`.\n\n    Returns:\n        Unchanged `numpy.array` or tensor.\n    \"\"\"\n    return x\n\n\n@keras_export(\"keras.applications.mobilenet_v3.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n"
  },
  {
    "path": "keras/src/applications/nasnet.py",
    "content": "import warnings\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nBASE_WEIGHTS_PATH = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/nasnet/\"\n)\nNASNET_MOBILE_WEIGHT_PATH = f\"{BASE_WEIGHTS_PATH}NASNet-mobile.h5\"\nNASNET_MOBILE_WEIGHT_PATH_NO_TOP = f\"{BASE_WEIGHTS_PATH}NASNet-mobile-no-top.h5\"\nNASNET_LARGE_WEIGHT_PATH = f\"{BASE_WEIGHTS_PATH}NASNet-large.h5\"\nNASNET_LARGE_WEIGHT_PATH_NO_TOP = f\"{BASE_WEIGHTS_PATH}NASNet-large-no-top.h5\"\n\n\ndef NASNet(\n    input_shape=None,\n    penultimate_filters=4032,\n    num_blocks=6,\n    stem_block_filters=96,\n    skip_reduction=True,\n    filter_multiplier=2,\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    pooling=None,\n    classes=1000,\n    default_size=None,\n    classifier_activation=\"softmax\",\n    name=\"NASNet\",\n):\n    \"\"\"Instantiates a NASNet model.\n\n    Reference:\n    - [Learning Transferable Architectures for Scalable Image Recognition](\n        https://arxiv.org/abs/1707.07012) (CVPR 2018)\n\n    For image classification use cases, see\n    [this page for detailed examples](\n      https://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\n    For transfer learning use cases, make sure to read the\n    [guide to transfer learning & fine-tuning](\n      https://keras.io/guides/transfer_learning/).\n\n    Note: each Keras Application expects a specific kind of input preprocessing.\n    For NasNet, call `keras.applications.nasnet.preprocess_input`\n    on your inputs before passing them to the model.\n    `nasnet.preprocess_input` will scale input pixels between -1 and 1.\n\n    Args:\n        input_shape: Optional shape tuple, the input shape\n            is by 
default `(331, 331, 3)` for NASNetLarge and\n            `(224, 224, 3)` for NASNetMobile.\n            It should have exactly 3 input channels,\n            and width and height should be no smaller than 32.\n            E.g. `(224, 224, 3)` would be one valid value.\n        penultimate_filters: Number of filters in the penultimate layer.\n            NASNet models use the notation `NASNet (N @ P)`, where:\n                -   N is the number of blocks\n                -   P is the number of penultimate filters\n        num_blocks: Number of repeated blocks of the NASNet model.\n            NASNet models use the notation `NASNet (N @ P)`, where:\n                -   N is the number of blocks\n                -   P is the number of penultimate filters\n        stem_block_filters: Number of filters in the initial stem block\n        skip_reduction: Whether to skip the reduction step at the tail\n            end of the network.\n        filter_multiplier: Controls the width of the network.\n            - If `filter_multiplier` < 1.0, proportionally decreases the number\n                of filters in each layer.\n            - If `filter_multiplier` > 1.0, proportionally increases the number\n                of filters in each layer.\n            - If `filter_multiplier` = 1, default number of filters from the\n                paper are used at each layer.\n        include_top: Whether to include the fully-connected\n            layer at the top of the network.\n        weights: `None` (random initialization) or\n            `imagenet` (ImageNet weights)\n        input_tensor: Optional Keras tensor (i.e. 
output of\n            `layers.Input()`)\n            to use as image input for the model.\n        pooling: Optional pooling mode for feature extraction\n            when `include_top` is `False`.\n            - `None` means that the output of the model\n                will be the 4D tensor output of the\n                last convolutional block.\n            - `avg` means that global average pooling\n                will be applied to the output of the\n                last convolutional block, and thus\n                the output of the model will be a\n                2D tensor.\n            - `max` means that global max pooling will\n                be applied.\n        classes: Optional number of classes to classify images\n            into, only to be specified if `include_top` is `True`, and\n            if no `weights` argument is specified.\n        default_size: Specifies the default image size of the model\n        classifier_activation: A `str` or callable.\n            The activation function to use on the \"top\" layer.\n            Ignored unless `include_top=True`.\n            Set `classifier_activation=None` to return the logits\n            of the \"top\" layer. When loading pretrained weights,\n            `classifier_activation` can only be `None` or `\"softmax\"`.\n        name: The name of the model (string).\n\n    Returns:\n        A model instance.\n    \"\"\"\n    if backend.image_data_format() == \"channels_first\":\n        raise ValueError(\n            \"NASNet does not support the `channels_first` image data \"\n            \"format. 
Switch to `channels_last` by editing your local \"\n            \"config file at ~/.keras/keras.json\"\n        )\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), `imagenet` \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded.\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            'If using `weights` as `\"imagenet\"` with `include_top` '\n            \"as true, `classes` should be 1000\"\n        )\n\n    if (\n        isinstance(input_shape, tuple)\n        and None in input_shape\n        and weights == \"imagenet\"\n    ):\n        raise ValueError(\n            \"When specifying the input shape of a NASNet and loading \"\n            \"`ImageNet` weights, the input_shape argument must be static\"\n            f\" (no None entries). Got: `input_shape={input_shape}`.\"\n        )\n\n    if default_size is None:\n        default_size = 331\n\n    # Determine proper input shape and default size.\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=default_size,\n        min_size=32,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if backend.image_data_format() != \"channels_last\":\n        warnings.warn(\n            \"The NASNet family of models is only available \"\n            'for the input data format \"channels_last\" '\n            \"(width, height, channels). \"\n            \"However your settings specify the default \"\n            'data format \"channels_first\" (channels, width, height).'\n            ' You should set `image_data_format=\"channels_last\"` '\n            \"in your Keras config located at ~/.keras/keras.json. 
\"\n            \"The model being returned right now will expect inputs \"\n            'to follow the \"channels_last\" data format.',\n            stacklevel=2,\n        )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n\n    if penultimate_filters % (24 * (filter_multiplier**2)) != 0:\n        raise ValueError(\n            \"For NASNet-A models, the `penultimate_filters` must be a multiple \"\n            \"of 24 * (`filter_multiplier` ** 2). \"\n            f\"Current value: {penultimate_filters}\"\n        )\n\n    channel_dim = 1 if backend.image_data_format() == \"channels_first\" else -1\n    filters = penultimate_filters // 24\n\n    x = layers.Conv2D(\n        stem_block_filters,\n        (3, 3),\n        strides=(2, 2),\n        padding=\"valid\",\n        use_bias=False,\n        name=\"stem_conv1\",\n        kernel_initializer=\"he_normal\",\n    )(img_input)\n\n    x = layers.BatchNormalization(\n        axis=channel_dim, momentum=0.9997, epsilon=1e-3, name=\"stem_bn1\"\n    )(x)\n\n    p = None\n    x, p = _reduction_a_cell(\n        x, p, filters // (filter_multiplier**2), block_id=\"stem_1\"\n    )\n    x, p = _reduction_a_cell(\n        x, p, filters // filter_multiplier, block_id=\"stem_2\"\n    )\n\n    for i in range(num_blocks):\n        x, p = _normal_a_cell(x, p, filters, block_id=f\"{i}\")\n\n    x, p0 = _reduction_a_cell(\n        x, p, filters * filter_multiplier, block_id=f\"reduce_{num_blocks}\"\n    )\n\n    p = p0 if not skip_reduction else p\n\n    for i in range(num_blocks):\n        x, p = _normal_a_cell(\n            x,\n            p,\n            filters * filter_multiplier,\n            block_id=f\"{num_blocks + i + 1}\",\n        )\n\n    x, p0 = _reduction_a_cell(\n        x,\n        p,\n      
  filters * filter_multiplier**2,\n        block_id=f\"reduce_{2 * num_blocks}\",\n    )\n\n    p = p0 if not skip_reduction else p\n\n    for i in range(num_blocks):\n        x, p = _normal_a_cell(\n            x,\n            p,\n            filters * filter_multiplier**2,\n            block_id=f\"{2 * num_blocks + i + 1}\",\n        )\n\n    x = layers.Activation(\"relu\")(x)\n\n    if include_top:\n        x = layers.GlobalAveragePooling2D()(x)\n        imagenet_utils.validate_activation(classifier_activation, weights)\n        x = layers.Dense(\n            classes, activation=classifier_activation, name=\"predictions\"\n        )(x)\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D()(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D()(x)\n\n    # Ensure that the model takes into account\n    # any potential predecessors of `input_tensor`.\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if weights == \"imagenet\":\n        if default_size == 224:  # mobile version\n            if include_top:\n                weights_path = file_utils.get_file(\n                    \"nasnet_mobile.h5\",\n                    NASNET_MOBILE_WEIGHT_PATH,\n                    cache_subdir=\"models\",\n                    file_hash=\"020fb642bf7360b370c678b08e0adf61\",\n                )\n            else:\n                weights_path = file_utils.get_file(\n                    \"nasnet_mobile_no_top.h5\",\n                    NASNET_MOBILE_WEIGHT_PATH_NO_TOP,\n                    cache_subdir=\"models\",\n                    file_hash=\"1ed92395b5b598bdda52abe5c0dbfd63\",\n                )\n            model.load_weights(weights_path)\n        elif default_size == 331:  # large version\n            if include_top:\n                weights_path = 
file_utils.get_file(\n                    \"nasnet_large.h5\",\n                    NASNET_LARGE_WEIGHT_PATH,\n                    cache_subdir=\"models\",\n                    file_hash=\"11577c9a518f0070763c2b964a382f17\",\n                )\n            else:\n                weights_path = file_utils.get_file(\n                    \"nasnet_large_no_top.h5\",\n                    NASNET_LARGE_WEIGHT_PATH_NO_TOP,\n                    cache_subdir=\"models\",\n                    file_hash=\"d81d89dc07e6e56530c4e77faddd61b5\",\n                )\n            model.load_weights(weights_path)\n        else:\n            raise ValueError(\n                \"ImageNet weights can only be loaded with NASNetLarge\"\n                \" or NASNetMobile\"\n            )\n    elif weights is not None:\n        model.load_weights(weights)\n\n    return model\n\n\n@keras_export(\n    [\n        \"keras.applications.nasnet.NASNetMobile\",\n        \"keras.applications.NASNetMobile\",\n    ]\n)\ndef NASNetMobile(\n    input_shape=None,\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"nasnet_mobile\",\n):\n    \"\"\"Instantiates a Mobile NASNet model in ImageNet mode.\n\n    Reference:\n    - [Learning Transferable Architectures for Scalable Image Recognition](\n        https://arxiv.org/abs/1707.07012) (CVPR 2018)\n\n    Optionally loads weights pre-trained on ImageNet.\n    Note that the data format convention used by the model is\n    the one specified in your Keras config at `~/.keras/keras.json`.\n\n    Note: each Keras Application expects a specific kind of input preprocessing.\n    For NASNet, call `keras.applications.nasnet.preprocess_input` on your\n    inputs before passing them to the model.\n\n    Args:\n        input_shape: Optional shape tuple, only to be specified\n            if `include_top` is False (otherwise the input shape\n            has to be 
`(224, 224, 3)` for NASNetMobile).\n            It should have exactly 3 input channels,\n            and width and height should be no smaller than 32.\n            E.g. `(224, 224, 3)` would be one valid value.\n        include_top: Whether to include the fully-connected\n            layer at the top of the network.\n        weights: `None` (random initialization) or\n            `imagenet` (ImageNet weights). For loading `imagenet` weights,\n            `input_shape` should be `(224, 224, 3)`.\n        input_tensor: Optional Keras tensor (i.e. output of\n            `layers.Input()`)\n            to use as image input for the model.\n        pooling: Optional pooling mode for feature extraction\n            when `include_top` is `False`.\n            - `None` means that the output of the model\n                will be the 4D tensor output of the\n                last convolutional layer.\n            - `avg` means that global average pooling\n                will be applied to the output of the\n                last convolutional layer, and thus\n                the output of the model will be a\n                2D tensor.\n            - `max` means that global max pooling will\n                be applied.\n        classes: Optional number of classes to classify images\n            into, only to be specified if `include_top` is `True`, and\n            if no `weights` argument is specified.\n        classifier_activation: A `str` or callable. The activation function to\n            use on the \"top\" layer. Ignored unless `include_top=True`. Set\n            `classifier_activation=None` to return the logits of the \"top\"\n            layer.  
When loading pretrained weights, `classifier_activation` can\n            only be `None` or `\"softmax\"`.\n        name: The name of the model (string).\n\n    Returns:\n        A Keras model instance.\n    \"\"\"\n    if backend.backend() == \"torch\":\n        raise ValueError(\n            \"NASNetMobile is not available with the torch backend \"\n            \"at this time due to an outstanding bug. \"\n            \"If interested, please open a PR.\"\n        )\n    if not include_top and input_shape is None:\n        input_shape = (224, 224, 3)\n    return NASNet(\n        input_shape,\n        penultimate_filters=1056,\n        num_blocks=4,\n        stem_block_filters=32,\n        skip_reduction=False,\n        filter_multiplier=2,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        pooling=pooling,\n        classes=classes,\n        default_size=224,\n        classifier_activation=classifier_activation,\n        name=name,\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.nasnet.NASNetLarge\",\n        \"keras.applications.NASNetLarge\",\n    ]\n)\ndef NASNetLarge(\n    input_shape=None,\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"nasnet_large\",\n):\n    \"\"\"Instantiates a NASNet model in ImageNet mode.\n\n    Reference:\n    - [Learning Transferable Architectures for Scalable Image Recognition](\n        https://arxiv.org/abs/1707.07012) (CVPR 2018)\n\n    Optionally loads weights pre-trained on ImageNet.\n    Note that the data format convention used by the model is\n    the one specified in your Keras config at `~/.keras/keras.json`.\n\n    Note: each Keras Application expects a specific kind of input preprocessing.\n    For NASNet, call `keras.applications.nasnet.preprocess_input` on your\n    inputs before passing them to the model.\n\n    Args:\n        
input_shape: Optional shape tuple, only to be specified\n            if `include_top` is `False` (otherwise the input shape\n            has to be `(331, 331, 3)` for NASNetLarge).\n            It should have exactly 3 input channels,\n            and width and height should be no smaller than 32.\n            E.g. `(224, 224, 3)` would be one valid value.\n        include_top: Whether to include the fully-connected\n            layer at the top of the network.\n        weights: `None` (random initialization) or\n            `imagenet` (ImageNet weights). For loading `imagenet` weights,\n            `input_shape` should be `(331, 331, 3)`.\n        input_tensor: Optional Keras tensor (i.e. output of\n            `layers.Input()`)\n            to use as image input for the model.\n        pooling: Optional pooling mode for feature extraction\n            when `include_top` is `False`.\n            - `None` means that the output of the model\n                will be the 4D tensor output of the\n                last convolutional layer.\n            - `avg` means that global average pooling\n                will be applied to the output of the\n                last convolutional layer, and thus\n                the output of the model will be a\n                2D tensor.\n            - `max` means that global max pooling will\n                be applied.\n        classes: Optional number of classes to classify images\n            into, only to be specified if `include_top` is `True`, and\n            if no `weights` argument is specified.\n        classifier_activation: A `str` or callable. The activation function to\n            use on the \"top\" layer. Ignored unless `include_top=True`. Set\n            `classifier_activation=None` to return the logits of the \"top\"\n            layer.  
When loading pretrained weights, `classifier_activation`\n            can only be `None` or `\"softmax\"`.\n        name: The name of the model (string).\n\n    Returns:\n        A Keras model instance.\n    \"\"\"\n    return NASNet(\n        input_shape,\n        penultimate_filters=4032,\n        num_blocks=6,\n        stem_block_filters=96,\n        skip_reduction=True,\n        filter_multiplier=2,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        pooling=pooling,\n        classes=classes,\n        default_size=331,\n        classifier_activation=classifier_activation,\n        name=name,\n    )\n\n\ndef _separable_conv_block(\n    ip, filters, kernel_size=(3, 3), strides=(1, 1), block_id=None\n):\n    \"\"\"Adds 2 blocks of [relu-separable conv-batchnorm].\n\n    Args:\n        ip: Input tensor\n        filters: Number of output filters per layer\n        kernel_size: Kernel size of separable convolutions\n        strides: Strided convolution for downsampling\n        block_id: String block_id\n\n    Returns:\n        A Keras tensor\n    \"\"\"\n    channel_dim = 1 if backend.image_data_format() == \"channels_first\" else -1\n\n    with backend.name_scope(f\"separable_conv_block_{block_id}\"):\n        x = layers.Activation(\"relu\")(ip)\n        if strides == (2, 2):\n            x = layers.ZeroPadding2D(\n                padding=imagenet_utils.correct_pad(x, kernel_size),\n                name=f\"separable_conv_1_pad_{block_id}\",\n            )(x)\n            conv_pad = \"valid\"\n        else:\n            conv_pad = \"same\"\n        x = layers.SeparableConv2D(\n            filters,\n            kernel_size,\n            strides=strides,\n            name=f\"separable_conv_1_{block_id}\",\n            padding=conv_pad,\n            use_bias=False,\n        )(x)\n        x = layers.BatchNormalization(\n            axis=channel_dim,\n            momentum=0.9997,\n            epsilon=1e-3,\n            
name=f\"separable_conv_1_bn_{block_id}\",\n        )(x)\n        x = layers.Activation(\"relu\")(x)\n        x = layers.SeparableConv2D(\n            filters,\n            kernel_size,\n            name=f\"separable_conv_2_{block_id}\",\n            padding=\"same\",\n            use_bias=False,\n        )(x)\n        x = layers.BatchNormalization(\n            axis=channel_dim,\n            momentum=0.9997,\n            epsilon=1e-3,\n            name=f\"separable_conv_2_bn_{block_id}\",\n        )(x)\n    return x\n\n\ndef _adjust_block(p, ip, filters, block_id=None):\n    \"\"\"Adjusts the input `previous path` to match the shape of the `input`.\n\n    Used in situations where the output number of filters needs to be changed.\n\n    Args:\n        p: Input tensor which needs to be modified\n        ip: Input tensor whose shape needs to be matched\n        filters: Number of output filters to be matched\n        block_id: String block_id\n\n    Returns:\n        Adjusted Keras tensor\n    \"\"\"\n    channel_dim = 1 if backend.image_data_format() == \"channels_first\" else -1\n    img_dim = 2 if backend.image_data_format() == \"channels_first\" else -2\n\n    with backend.name_scope(\"adjust_block\"):\n        if p is None:\n            p = ip\n\n        elif p.shape[img_dim] != ip.shape[img_dim]:\n            with backend.name_scope(f\"adjust_reduction_block_{block_id}\"):\n                p = layers.Activation(\"relu\", name=f\"adjust_relu_1_{block_id}\")(\n                    p\n                )\n                p1 = layers.AveragePooling2D(\n                    (1, 1),\n                    strides=(2, 2),\n                    padding=\"valid\",\n                    name=f\"adjust_avg_pool_1_{block_id}\",\n                )(p)\n                p1 = layers.Conv2D(\n                    filters // 2,\n                    (1, 1),\n                    padding=\"same\",\n                    use_bias=False,\n                    name=f\"adjust_conv_1_{block_id}\",\n  
                  kernel_initializer=\"he_normal\",\n                )(p1)\n\n                p2 = layers.ZeroPadding2D(padding=((0, 1), (0, 1)))(p)\n                p2 = layers.Cropping2D(cropping=((1, 0), (1, 0)))(p2)\n                p2 = layers.AveragePooling2D(\n                    (1, 1),\n                    strides=(2, 2),\n                    padding=\"valid\",\n                    name=f\"adjust_avg_pool_2_{block_id}\",\n                )(p2)\n                p2 = layers.Conv2D(\n                    filters // 2,\n                    (1, 1),\n                    padding=\"same\",\n                    use_bias=False,\n                    name=f\"adjust_conv_2_{block_id}\",\n                    kernel_initializer=\"he_normal\",\n                )(p2)\n\n                p = layers.concatenate([p1, p2], axis=channel_dim)\n                p = layers.BatchNormalization(\n                    axis=channel_dim,\n                    momentum=0.9997,\n                    epsilon=1e-3,\n                    name=f\"adjust_bn_{block_id}\",\n                )(p)\n\n        elif p.shape[channel_dim] != filters:\n            with backend.name_scope(f\"adjust_projection_block_{block_id}\"):\n                p = layers.Activation(\"relu\")(p)\n                p = layers.Conv2D(\n                    filters,\n                    (1, 1),\n                    strides=(1, 1),\n                    padding=\"same\",\n                    name=f\"adjust_conv_projection_{block_id}\",\n                    use_bias=False,\n                    kernel_initializer=\"he_normal\",\n                )(p)\n                p = layers.BatchNormalization(\n                    axis=channel_dim,\n                    momentum=0.9997,\n                    epsilon=1e-3,\n                    name=f\"adjust_bn_{block_id}\",\n                )(p)\n    return p\n\n\ndef _normal_a_cell(ip, p, filters, block_id=None):\n    \"\"\"Adds a Normal cell for NASNet-A (Fig. 
4 in the paper).\n\n    Args:\n        ip: Input tensor `x`\n        p: Input tensor `p`\n        filters: Number of output filters\n        block_id: String block_id\n\n    Returns:\n        A Keras tensor\n    \"\"\"\n    channel_dim = 1 if backend.image_data_format() == \"channels_first\" else -1\n\n    with backend.name_scope(f\"normal_A_block_{block_id}\"):\n        p = _adjust_block(p, ip, filters, block_id)\n\n        h = layers.Activation(\"relu\")(ip)\n        h = layers.Conv2D(\n            filters,\n            (1, 1),\n            strides=(1, 1),\n            padding=\"same\",\n            name=f\"normal_conv_1_{block_id}\",\n            use_bias=False,\n            kernel_initializer=\"he_normal\",\n        )(h)\n        h = layers.BatchNormalization(\n            axis=channel_dim,\n            momentum=0.9997,\n            epsilon=1e-3,\n            name=f\"normal_bn_1_{block_id}\",\n        )(h)\n\n        with backend.name_scope(\"block_1\"):\n            x1_1 = _separable_conv_block(\n                h,\n                filters,\n                kernel_size=(5, 5),\n                block_id=f\"normal_left1_{block_id}\",\n            )\n            x1_2 = _separable_conv_block(\n                p, filters, block_id=f\"normal_right1_{block_id}\"\n            )\n            x1 = layers.add([x1_1, x1_2], name=f\"normal_add_1_{block_id}\")\n\n        with backend.name_scope(\"block_2\"):\n            x2_1 = _separable_conv_block(\n                p, filters, (5, 5), block_id=f\"normal_left2_{block_id}\"\n            )\n            x2_2 = _separable_conv_block(\n                p, filters, (3, 3), block_id=f\"normal_right2_{block_id}\"\n            )\n            x2 = layers.add([x2_1, x2_2], name=f\"normal_add_2_{block_id}\")\n\n        with backend.name_scope(\"block_3\"):\n            x3 = layers.AveragePooling2D(\n                (3, 3),\n                strides=(1, 1),\n                padding=\"same\",\n                
name=f\"normal_left3_{block_id}\",\n            )(h)\n            x3 = layers.add([x3, p], name=f\"normal_add_3_{block_id}\")\n\n        with backend.name_scope(\"block_4\"):\n            x4_1 = layers.AveragePooling2D(\n                (3, 3),\n                strides=(1, 1),\n                padding=\"same\",\n                name=f\"normal_left4_{block_id}\",\n            )(p)\n            x4_2 = layers.AveragePooling2D(\n                (3, 3),\n                strides=(1, 1),\n                padding=\"same\",\n                name=f\"normal_right4_{block_id}\",\n            )(p)\n            x4 = layers.add([x4_1, x4_2], name=f\"normal_add_4_{block_id}\")\n\n        with backend.name_scope(\"block_5\"):\n            x5 = _separable_conv_block(\n                h, filters, block_id=f\"normal_left5_{block_id}\"\n            )\n            x5 = layers.add([x5, h], name=f\"normal_add_5_{block_id}\")\n\n        x = layers.concatenate(\n            [p, x1, x2, x3, x4, x5],\n            axis=channel_dim,\n            name=f\"normal_concat_{block_id}\",\n        )\n    return x, ip\n\n\ndef _reduction_a_cell(ip, p, filters, block_id=None):\n    \"\"\"Adds a Reduction cell for NASNet-A (Fig. 
4 in the paper).\n\n    Args:\n      ip: Input tensor `x`\n      p: Input tensor `p`\n      filters: Number of output filters\n      block_id: String block_id\n\n    Returns:\n      A Keras tensor\n    \"\"\"\n    channel_dim = 1 if backend.image_data_format() == \"channels_first\" else -1\n\n    with backend.name_scope(f\"reduction_A_block_{block_id}\"):\n        p = _adjust_block(p, ip, filters, block_id)\n\n        h = layers.Activation(\"relu\")(ip)\n        h = layers.Conv2D(\n            filters,\n            (1, 1),\n            strides=(1, 1),\n            padding=\"same\",\n            name=f\"reduction_conv_1_{block_id}\",\n            use_bias=False,\n            kernel_initializer=\"he_normal\",\n        )(h)\n        h = layers.BatchNormalization(\n            axis=channel_dim,\n            momentum=0.9997,\n            epsilon=1e-3,\n            name=f\"reduction_bn_1_{block_id}\",\n        )(h)\n        h3 = layers.ZeroPadding2D(\n            padding=imagenet_utils.correct_pad(h, 3),\n            name=f\"reduction_pad_1_{block_id}\",\n        )(h)\n\n        with backend.name_scope(\"block_1\"):\n            x1_1 = _separable_conv_block(\n                h,\n                filters,\n                (5, 5),\n                strides=(2, 2),\n                block_id=f\"reduction_left1_{block_id}\",\n            )\n            x1_2 = _separable_conv_block(\n                p,\n                filters,\n                (7, 7),\n                strides=(2, 2),\n                block_id=f\"reduction_right1_{block_id}\",\n            )\n            x1 = layers.add([x1_1, x1_2], name=f\"reduction_add_1_{block_id}\")\n\n        with backend.name_scope(\"block_2\"):\n            x2_1 = layers.MaxPooling2D(\n                (3, 3),\n                strides=(2, 2),\n                padding=\"valid\",\n                name=f\"reduction_left2_{block_id}\",\n            )(h3)\n            x2_2 = _separable_conv_block(\n                p,\n                
filters,\n                (7, 7),\n                strides=(2, 2),\n                block_id=f\"reduction_right2_{block_id}\",\n            )\n            x2 = layers.add([x2_1, x2_2], name=f\"reduction_add_2_{block_id}\")\n\n        with backend.name_scope(\"block_3\"):\n            x3_1 = layers.AveragePooling2D(\n                (3, 3),\n                strides=(2, 2),\n                padding=\"valid\",\n                name=f\"reduction_left3_{block_id}\",\n            )(h3)\n            x3_2 = _separable_conv_block(\n                p,\n                filters,\n                (5, 5),\n                strides=(2, 2),\n                block_id=f\"reduction_right3_{block_id}\",\n            )\n            x3 = layers.add([x3_1, x3_2], name=f\"reduction_add3_{block_id}\")\n\n        with backend.name_scope(\"block_4\"):\n            x4 = layers.AveragePooling2D(\n                (3, 3),\n                strides=(1, 1),\n                padding=\"same\",\n                name=f\"reduction_left4_{block_id}\",\n            )(x1)\n            x4 = layers.add([x2, x4])\n\n        with backend.name_scope(\"block_5\"):\n            x5_1 = _separable_conv_block(\n                x1, filters, (3, 3), block_id=f\"reduction_left4_{block_id}\"\n            )\n            x5_2 = layers.MaxPooling2D(\n                (3, 3),\n                strides=(2, 2),\n                padding=\"valid\",\n                name=f\"reduction_right5_{block_id}\",\n            )(h3)\n            x5 = layers.add([x5_1, x5_2], name=f\"reduction_add4_{block_id}\")\n\n        x = layers.concatenate(\n            [x2, x3, x4, x5],\n            axis=channel_dim,\n            name=f\"reduction_concat_{block_id}\",\n        )\n        return x, ip\n\n\n@keras_export(\"keras.applications.nasnet.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    return imagenet_utils.preprocess_input(\n        x, data_format=data_format, mode=\"tf\"\n    
)\n\n\n@keras_export(\"keras.applications.nasnet.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\npreprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(\n    mode=\"\",\n    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,\n    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,\n)\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n"
  },
  {
    "path": "keras/src/applications/resnet.py",
    "content": "from keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nBASE_WEIGHTS_PATH = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/resnet/\"\n)\nWEIGHTS_HASHES = {\n    \"resnet50\": (\n        \"2cb95161c43110f7111970584f804107\",\n        \"4d473c1dd8becc155b73f8504c6f6626\",\n    ),\n    \"resnet101\": (\n        \"f1aeb4b969a6efcfb50fad2f0c20cfc5\",\n        \"88cf7a10940856eca736dc7b7e228a21\",\n    ),\n    \"resnet152\": (\n        \"100835be76be38e30d865e96f2aaae62\",\n        \"ee4c566cf9a93f14d82f913c2dc6dd0c\",\n    ),\n    \"resnet50v2\": (\n        \"3ef43a0b657b3be2300d5770ece849e0\",\n        \"fac2f116257151a9d068a22e544a4917\",\n    ),\n    \"resnet101v2\": (\n        \"6343647c601c52e1368623803854d971\",\n        \"c0ed64b8031c3730f411d2eb4eea35b5\",\n    ),\n    \"resnet152v2\": (\n        \"a49b44d1979771252814e80f8ec446f9\",\n        \"ed17cf2e0169df9d443503ef94b23b33\",\n    ),\n    \"resnext50\": (\n        \"67a5b30d522ed92f75a1f16eef299d1a\",\n        \"62527c363bdd9ec598bed41947b379fc\",\n    ),\n    \"resnext101\": (\n        \"34fb605428fcc7aa4d62f44404c11509\",\n        \"0f678c91647380debd923963594981b3\",\n    ),\n}\n\n\ndef ResNet(\n    stack_fn,\n    preact,\n    use_bias,\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"resnet\",\n    weights_name=None,\n):\n    \"\"\"Instantiates the ResNet, ResNetV2, and ResNeXt architecture.\n\n    Args:\n        stack_fn: A function that returns output tensor for the\n            stacked residual blocks.\n        preact: Whether to use pre-activation or not. 
`True` for ResNetV2,\n            `False` for ResNet and ResNeXt.\n        use_bias: Whether to use biases for convolutional layers or not.\n            `True` for ResNet and ResNetV2, `False` for ResNeXt.\n        include_top: Whether to include the fully-connected\n            layer at the top of the network.\n        weights: One of `None` (random initialization),\n            `\"imagenet\"` (pre-training on ImageNet),\n            or the path to the weights file to be loaded.\n        input_tensor: Optional Keras tensor (i.e. output of `layers.Input()`)\n            to use as image input for the model.\n        input_shape: Optional shape tuple, only to be specified\n            if `include_top` is `False` (otherwise the input shape\n            has to be `(224, 224, 3)` (with `\"channels_last\"` data format)\n            or `(3, 224, 224)` (with `\"channels_first\"` data format)). It\n            should have exactly 3 input channels.\n        pooling: Optional pooling mode for feature extraction\n            when `include_top` is `False`.\n            - `None` means that the output of the model will be\n                the 4D tensor output of the\n                last convolutional layer.\n            - `avg` means that global average pooling\n                will be applied to the output of the\n                last convolutional layer, and thus\n                the output of the model will be a 2D tensor.\n            - `max` means that global max pooling will\n                be applied.\n        classes: Optional number of classes to classify images\n            into, only to be specified if `include_top` is `True`,\n            and if no `weights` argument is specified.\n        classifier_activation: A `str` or callable. The activation\n            function to use on the \"top\" layer. Ignored unless\n            `include_top=True`. Set `classifier_activation=None` to\n            return the logits of the \"top\" layer. 
When loading\n            pretrained weights, `classifier_activation` can only be\n            `None` or `\"softmax\"`.\n        name: The name of the model (string).\n\n    Returns:\n        A Model instance.\n    \"\"\"\n\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), 'imagenet' \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded.  Received: \"\n            f\"weights={weights}\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            \"If using `weights='imagenet'` with `include_top=True`, \"\n            \"`classes` should be 1000.  \"\n            f\"Received classes={classes}\"\n        )\n\n    # Determine proper input shape\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=224,\n        min_size=32,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n\n    if backend.image_data_format() == \"channels_last\":\n        bn_axis = 3\n    else:\n        bn_axis = 1\n\n    x = layers.ZeroPadding2D(padding=((3, 3), (3, 3)), name=\"conv1_pad\")(\n        img_input\n    )\n    x = layers.Conv2D(64, 7, strides=2, use_bias=use_bias, name=\"conv1_conv\")(x)\n\n    if not preact:\n        x = layers.BatchNormalization(\n            axis=bn_axis, epsilon=1.001e-5, name=\"conv1_bn\"\n        )(x)\n        x = layers.Activation(\"relu\", name=\"conv1_relu\")(x)\n\n    x = 
layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name=\"pool1_pad\")(x)\n    x = layers.MaxPooling2D(3, strides=2, name=\"pool1_pool\")(x)\n\n    x = stack_fn(x)\n\n    if preact:\n        x = layers.BatchNormalization(\n            axis=bn_axis, epsilon=1.001e-5, name=\"post_bn\"\n        )(x)\n        x = layers.Activation(\"relu\", name=\"post_relu\")(x)\n\n    if include_top:\n        x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(x)\n\n        # Validate activation for the classifier layer\n        imagenet_utils.validate_activation(classifier_activation, weights)\n\n        x = layers.Dense(\n            classes, activation=classifier_activation, name=\"predictions\"\n        )(x)\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D(name=\"max_pool\")(x)\n\n    # Ensure that the model takes into account\n    # any potential predecessors of `input_tensor`.\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n\n    # Create model.\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if (weights == \"imagenet\") and (weights_name in WEIGHTS_HASHES):\n        if include_top:\n            file_name = f\"{weights_name}_weights_tf_dim_ordering_tf_kernels.h5\"\n            file_hash = WEIGHTS_HASHES[weights_name][0]\n        else:\n            file_name = (\n                f\"{weights_name}_weights_tf_dim_ordering_tf_kernels_notop.h5\"\n            )\n            file_hash = WEIGHTS_HASHES[weights_name][1]\n        weights_path = file_utils.get_file(\n            file_name,\n            f\"{BASE_WEIGHTS_PATH}{file_name}\",\n            cache_subdir=\"models\",\n            file_hash=file_hash,\n        )\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n\n    
return model\n\n\ndef residual_block_v1(\n    x, filters, kernel_size=3, stride=1, conv_shortcut=True, name=None\n):\n    \"\"\"A residual block for ResNet*_v1.\n\n    Args:\n        x: Input tensor.\n        filters: Number of filters in the bottleneck layer.\n        kernel_size: Kernel size of the bottleneck layer. Defaults to `3`.\n        stride: Stride of the first layer. Defaults to `1`.\n        conv_shortcut: Use convolution shortcut if `True`, otherwise\n            use identity shortcut. Defaults to `True`.\n        name: Name of the block (optional).\n\n    Returns:\n        Output tensor for the residual block.\n    \"\"\"\n\n    if backend.image_data_format() == \"channels_last\":\n        bn_axis = 3\n    else:\n        bn_axis = 1\n\n    if conv_shortcut:\n        shortcut = layers.Conv2D(\n            4 * filters, 1, strides=stride, name=f\"{name}_0_conv\"\n        )(x)\n        shortcut = layers.BatchNormalization(\n            axis=bn_axis, epsilon=1.001e-5, name=f\"{name}_0_bn\"\n        )(shortcut)\n    else:\n        shortcut = x\n\n    x = layers.Conv2D(filters, 1, strides=stride, name=f\"{name}_1_conv\")(x)\n    x = layers.BatchNormalization(\n        axis=bn_axis, epsilon=1.001e-5, name=f\"{name}_1_bn\"\n    )(x)\n    x = layers.Activation(\"relu\", name=f\"{name}_1_relu\")(x)\n\n    x = layers.Conv2D(\n        filters, kernel_size, padding=\"SAME\", name=f\"{name}_2_conv\"\n    )(x)\n    x = layers.BatchNormalization(\n        axis=bn_axis, epsilon=1.001e-5, name=f\"{name}_2_bn\"\n    )(x)\n    x = layers.Activation(\"relu\", name=f\"{name}_2_relu\")(x)\n\n    x = layers.Conv2D(4 * filters, 1, name=f\"{name}_3_conv\")(x)\n    x = layers.BatchNormalization(\n        axis=bn_axis, epsilon=1.001e-5, name=f\"{name}_3_bn\"\n    )(x)\n\n    x = layers.Add(name=f\"{name}_add\")([shortcut, x])\n    x = layers.Activation(\"relu\", name=f\"{name}_out\")(x)\n    return x\n\n\ndef stack_residual_blocks_v1(x, filters, blocks, stride1=2, name=None):\n    
\"\"\"A set of stacked residual blocks.\n\n    Args:\n        x: Input tensor.\n        filters: Number of filters in the bottleneck layer in a block.\n        blocks: Number of blocks in the stacked blocks.\n        stride1: Stride of the first layer in the first block. Defaults to `2`.\n        name: Stack label.\n\n    Returns:\n        Output tensor for the stacked blocks.\n    \"\"\"\n\n    x = residual_block_v1(x, filters, stride=stride1, name=f\"{name}_block1\")\n    for i in range(2, blocks + 1):\n        x = residual_block_v1(\n            x, filters, conv_shortcut=False, name=f\"{name}_block{i}\"\n        )\n    return x\n\n\ndef residual_block_v2(\n    x, filters, kernel_size=3, stride=1, conv_shortcut=False, name=None\n):\n    \"\"\"A residual block for ResNet*_v2.\n\n    Args:\n        x: Input tensor.\n        filters: Number of filters in the bottleneck layer.\n        kernel_size: Kernel size of the bottleneck layer. Defaults to `3`.\n        stride: Stride of the first layer. Defaults to `1`.\n        conv_shortcut: Use convolution shortcut if `True`, otherwise\n            use identity shortcut. 
Defaults to `False`.\n        name: Name of the block (optional).\n\n    Returns:\n        Output tensor for the residual block.\n    \"\"\"\n\n    if backend.image_data_format() == \"channels_last\":\n        bn_axis = 3\n    else:\n        bn_axis = 1\n\n    preact = layers.BatchNormalization(\n        axis=bn_axis, epsilon=1.001e-5, name=f\"{name}_preact_bn\"\n    )(x)\n    preact = layers.Activation(\"relu\", name=f\"{name}_preact_relu\")(preact)\n\n    if conv_shortcut:\n        shortcut = layers.Conv2D(\n            4 * filters, 1, strides=stride, name=f\"{name}_0_conv\"\n        )(preact)\n    else:\n        shortcut = (\n            layers.MaxPooling2D(1, strides=stride)(x) if stride > 1 else x\n        )\n\n    x = layers.Conv2D(\n        filters, 1, strides=1, use_bias=False, name=f\"{name}_1_conv\"\n    )(preact)\n    x = layers.BatchNormalization(\n        axis=bn_axis, epsilon=1.001e-5, name=f\"{name}_1_bn\"\n    )(x)\n    x = layers.Activation(\"relu\", name=f\"{name}_1_relu\")(x)\n\n    x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name=f\"{name}_2_pad\")(x)\n    x = layers.Conv2D(\n        filters,\n        kernel_size,\n        strides=stride,\n        use_bias=False,\n        name=f\"{name}_2_conv\",\n    )(x)\n    x = layers.BatchNormalization(\n        axis=bn_axis, epsilon=1.001e-5, name=f\"{name}_2_bn\"\n    )(x)\n    x = layers.Activation(\"relu\", name=f\"{name}_2_relu\")(x)\n\n    x = layers.Conv2D(4 * filters, 1, name=f\"{name}_3_conv\")(x)\n    x = layers.Add(name=f\"{name}_out\")([shortcut, x])\n    return x\n\n\ndef stack_residual_blocks_v2(x, filters, blocks, stride1=2, name=None):\n    \"\"\"A set of stacked residual blocks.\n\n    Args:\n        x: Input tensor.\n        filters: Number of filters in the bottleneck layer in a block.\n        blocks: Number of blocks in the stacked blocks.\n        stride1: Stride of the first layer in the first block. 
Defaults to `2`.\n        name: Stack label.\n\n    Returns:\n        Output tensor for the stacked blocks.\n    \"\"\"\n\n    x = residual_block_v2(x, filters, conv_shortcut=True, name=f\"{name}_block1\")\n    for i in range(2, blocks):\n        x = residual_block_v2(x, filters, name=f\"{name}_block{i}\")\n    x = residual_block_v2(\n        x, filters, stride=stride1, name=f\"{name}_block{str(blocks)}\"\n    )\n    return x\n\n\n@keras_export(\n    [\n        \"keras.applications.resnet50.ResNet50\",\n        \"keras.applications.resnet.ResNet50\",\n        \"keras.applications.ResNet50\",\n    ]\n)\ndef ResNet50(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"resnet50\",\n):\n    \"\"\"Instantiates the ResNet50 architecture.\"\"\"\n\n    def stack_fn(x):\n        x = stack_residual_blocks_v1(x, 64, 3, stride1=1, name=\"conv2\")\n        x = stack_residual_blocks_v1(x, 128, 4, name=\"conv3\")\n        x = stack_residual_blocks_v1(x, 256, 6, name=\"conv4\")\n        return stack_residual_blocks_v1(x, 512, 3, name=\"conv5\")\n\n    return ResNet(\n        stack_fn,\n        preact=False,\n        use_bias=True,\n        weights_name=\"resnet50\",\n        name=name,\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.resnet.ResNet101\",\n        \"keras.applications.ResNet101\",\n    ]\n)\ndef ResNet101(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"resnet101\",\n):\n    \"\"\"Instantiates the ResNet101 architecture.\"\"\"\n\n    def 
stack_fn(x):\n        x = stack_residual_blocks_v1(x, 64, 3, stride1=1, name=\"conv2\")\n        x = stack_residual_blocks_v1(x, 128, 4, name=\"conv3\")\n        x = stack_residual_blocks_v1(x, 256, 23, name=\"conv4\")\n        return stack_residual_blocks_v1(x, 512, 3, name=\"conv5\")\n\n    return ResNet(\n        stack_fn,\n        preact=False,\n        use_bias=True,\n        name=name,\n        weights_name=\"resnet101\",\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.resnet.ResNet152\",\n        \"keras.applications.ResNet152\",\n    ]\n)\ndef ResNet152(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"resnet152\",\n):\n    \"\"\"Instantiates the ResNet152 architecture.\"\"\"\n\n    def stack_fn(x):\n        x = stack_residual_blocks_v1(x, 64, 3, stride1=1, name=\"conv2\")\n        x = stack_residual_blocks_v1(x, 128, 8, name=\"conv3\")\n        x = stack_residual_blocks_v1(x, 256, 36, name=\"conv4\")\n        return stack_residual_blocks_v1(x, 512, 3, name=\"conv5\")\n\n    return ResNet(\n        stack_fn,\n        preact=False,\n        use_bias=True,\n        name=name,\n        weights_name=\"resnet152\",\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.resnet50.preprocess_input\",\n        \"keras.applications.resnet.preprocess_input\",\n    ]\n)\ndef preprocess_input(x, data_format=None):\n    return 
imagenet_utils.preprocess_input(\n        x, data_format=data_format, mode=\"caffe\"\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.resnet50.decode_predictions\",\n        \"keras.applications.resnet.decode_predictions\",\n    ]\n)\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\npreprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(\n    mode=\"\",\n    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE,\n    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,\n)\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n\nDOC = \"\"\"\n\nReference:\n- [Deep Residual Learning for Image Recognition](\n    https://arxiv.org/abs/1512.03385) (CVPR 2016)\n\nFor image classification use cases, see [this page for detailed examples](\n    https://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\nFor transfer learning use cases, make sure to read the\n[guide to transfer learning & fine-tuning](\n    https://keras.io/guides/transfer_learning/).\n\nNote: each Keras Application expects a specific kind of input preprocessing.\nFor ResNet, call `keras.applications.resnet.preprocess_input` on your\ninputs before passing them to the model. `resnet.preprocess_input` will convert\nthe input images from RGB to BGR, then will zero-center each color channel with\nrespect to the ImageNet dataset, without scaling.\n\nArgs:\n    include_top: whether to include the fully-connected\n        layer at the top of the network.\n    weights: one of `None` (random initialization),\n        `\"imagenet\"` (pre-training on ImageNet), or the path to the weights\n        file to be loaded.\n    input_tensor: optional Keras tensor (i.e. 
output of `layers.Input()`)\n        to use as image input for the model.\n    input_shape: optional shape tuple, only to be specified if `include_top`\n        is `False` (otherwise the input shape has to be `(224, 224, 3)`\n        (with `\"channels_last\"` data format) or `(3, 224, 224)`\n        (with `\"channels_first\"` data format)). It should have exactly 3\n        input channels, and width and height should be no smaller than 32.\n        E.g. `(200, 200, 3)` would be one valid value.\n    pooling: Optional pooling mode for feature extraction when `include_top`\n        is `False`.\n        - `None` means that the output of the model will be the 4D tensor\n                output of the last convolutional block.\n        - `avg` means that global average pooling will be applied to the output\n                of the last convolutional block, and thus the output of the\n                model will be a 2D tensor.\n        - `max` means that global max pooling will be applied.\n    classes: optional number of classes to classify images into, only to be\n        specified if `include_top` is `True`, and if no `weights` argument is\n        specified. Defaults to `1000`.\n    classifier_activation: A `str` or callable. The activation function to\n        use on the \"top\" layer. Ignored unless `include_top=True`. Set\n        `classifier_activation=None` to return the logits of the \"top\" layer.\n        When loading pretrained weights, `classifier_activation` can only\n        be `None` or `\"softmax\"`.\n    name: The name of the model (string).\n\nReturns:\n    A Model instance.\n\"\"\"\n\nsetattr(ResNet50, \"__doc__\", ResNet50.__doc__ + DOC)\nsetattr(ResNet101, \"__doc__\", ResNet101.__doc__ + DOC)\nsetattr(ResNet152, \"__doc__\", ResNet152.__doc__ + DOC)\n"
  },
  {
    "path": "keras/src/applications/resnet_v2.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.applications import resnet\n\n\n@keras_export(\n    [\n        \"keras.applications.ResNet50V2\",\n        \"keras.applications.resnet_v2.ResNet50V2\",\n    ]\n)\ndef ResNet50V2(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"resnet50v2\",\n):\n    \"\"\"Instantiates the ResNet50V2 architecture.\"\"\"\n\n    def stack_fn(x):\n        x = resnet.stack_residual_blocks_v2(x, 64, 3, name=\"conv2\")\n        x = resnet.stack_residual_blocks_v2(x, 128, 4, name=\"conv3\")\n        x = resnet.stack_residual_blocks_v2(x, 256, 6, name=\"conv4\")\n        return resnet.stack_residual_blocks_v2(\n            x, 512, 3, stride1=1, name=\"conv5\"\n        )\n\n    return resnet.ResNet(\n        stack_fn,\n        True,\n        True,\n        name=name,\n        weights_name=\"resnet50v2\",\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.ResNet101V2\",\n        \"keras.applications.resnet_v2.ResNet101V2\",\n    ]\n)\ndef ResNet101V2(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"resnet101v2\",\n):\n    \"\"\"Instantiates the ResNet101V2 architecture.\"\"\"\n\n    def stack_fn(x):\n        x = resnet.stack_residual_blocks_v2(x, 64, 3, name=\"conv2\")\n        x = resnet.stack_residual_blocks_v2(x, 128, 4, name=\"conv3\")\n        x = resnet.stack_residual_blocks_v2(x, 256, 23, name=\"conv4\")\n        return 
resnet.stack_residual_blocks_v2(\n            x, 512, 3, stride1=1, name=\"conv5\"\n        )\n\n    return resnet.ResNet(\n        stack_fn,\n        True,\n        True,\n        name=name,\n        weights_name=\"resnet101v2\",\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n    )\n\n\n@keras_export(\n    [\n        \"keras.applications.ResNet152V2\",\n        \"keras.applications.resnet_v2.ResNet152V2\",\n    ]\n)\ndef ResNet152V2(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"resnet152v2\",\n):\n    \"\"\"Instantiates the ResNet152V2 architecture.\"\"\"\n\n    def stack_fn(x):\n        x = resnet.stack_residual_blocks_v2(x, 64, 3, name=\"conv2\")\n        x = resnet.stack_residual_blocks_v2(x, 128, 8, name=\"conv3\")\n        x = resnet.stack_residual_blocks_v2(x, 256, 36, name=\"conv4\")\n        return resnet.stack_residual_blocks_v2(\n            x, 512, 3, stride1=1, name=\"conv5\"\n        )\n\n    return resnet.ResNet(\n        stack_fn,\n        True,\n        True,\n        name=name,\n        weights_name=\"resnet152v2\",\n        include_top=include_top,\n        weights=weights,\n        input_tensor=input_tensor,\n        input_shape=input_shape,\n        pooling=pooling,\n        classes=classes,\n        classifier_activation=classifier_activation,\n    )\n\n\n@keras_export(\"keras.applications.resnet_v2.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    return imagenet_utils.preprocess_input(\n        x, data_format=data_format, mode=\"tf\"\n    )\n\n\n@keras_export(\"keras.applications.resnet_v2.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return 
imagenet_utils.decode_predictions(preds, top=top)\n\n\npreprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(\n    mode=\"\",\n    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,\n    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,\n)\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n\n\nDOC = \"\"\"\n\nReference:\n- [Identity Mappings in Deep Residual Networks](\n    https://arxiv.org/abs/1603.05027) (ECCV 2016)\n\nFor image classification use cases, see [this page for detailed examples](\n    https://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\nFor transfer learning use cases, make sure to read the\n[guide to transfer learning & fine-tuning](\n    https://keras.io/guides/transfer_learning/).\n\nNote: each Keras Application expects a specific kind of input preprocessing.\nFor ResNet, call `keras.applications.resnet_v2.preprocess_input` on your\ninputs before passing them to the model. `resnet_v2.preprocess_input` will\nscale input pixels between -1 and 1.\n\nArgs:\n    include_top: whether to include the fully-connected\n        layer at the top of the network.\n    weights: one of `None` (random initialization),\n        `\"imagenet\"` (pre-training on ImageNet), or the path to the weights\n        file to be loaded.\n    input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)\n        to use as image input for the model.\n    input_shape: optional shape tuple, only to be specified if `include_top`\n        is `False` (otherwise the input shape has to be `(224, 224, 3)`\n        (with `\"channels_last\"` data format) or `(3, 224, 224)`\n        (with `\"channels_first\"` data format)). It should have exactly 3\n        input channels, and width and height should be no smaller than 32.\n        E.g. 
`(200, 200, 3)` would be one valid value.\n    pooling: Optional pooling mode for feature extraction when `include_top`\n        is `False`.\n        - `None` means that the output of the model will be the 4D tensor\n                output of the last convolutional block.\n        - `avg` means that global average pooling will be applied to the output\n                of the last convolutional block, and thus the output of the\n                model will be a 2D tensor.\n        - `max` means that global max pooling will be applied.\n    classes: optional number of classes to classify images into, only to be\n        specified if `include_top` is `True`, and if no `weights` argument is\n        specified.\n    classifier_activation: A `str` or callable. The activation function to\n        use on the \"top\" layer. Ignored unless `include_top=True`. Set\n        `classifier_activation=None` to return the logits of the \"top\" layer.\n        When loading pretrained weights, `classifier_activation` can only\n        be `None` or `\"softmax\"`.\n    name: The name of the model (string).\n\nReturns:\n    A Model instance.\n\"\"\"\n\nsetattr(ResNet50V2, \"__doc__\", ResNet50V2.__doc__ + DOC)\nsetattr(ResNet101V2, \"__doc__\", ResNet101V2.__doc__ + DOC)\nsetattr(ResNet152V2, \"__doc__\", ResNet152V2.__doc__ + DOC)\n"
  },
  {
    "path": "keras/src/applications/vgg16.py",
    "content": "from keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nWEIGHTS_PATH = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/\"\n    \"vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5\"\n)\nWEIGHTS_PATH_NO_TOP = (\n    \"https://storage.googleapis.com/tensorflow/\"\n    \"keras-applications/vgg16/\"\n    \"vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5\"\n)\n\n\n@keras_export([\"keras.applications.vgg16.VGG16\", \"keras.applications.VGG16\"])\ndef VGG16(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"vgg16\",\n):\n    \"\"\"Instantiates the VGG16 model.\n\n    Reference:\n    - [Very Deep Convolutional Networks for Large-Scale Image Recognition](\n    https://arxiv.org/abs/1409.1556) (ICLR 2015)\n\n    For image classification use cases, see\n    [this page for detailed examples](\n      https://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\n    For transfer learning use cases, make sure to read the\n    [guide to transfer learning & fine-tuning](\n      https://keras.io/guides/transfer_learning/).\n\n    The default input size for this model is 224x224.\n\n    Note: each Keras Application expects a specific kind of input preprocessing.\n    For VGG16, call `keras.applications.vgg16.preprocess_input` on your\n    inputs before passing them to the model.\n    `vgg16.preprocess_input` will convert the input images from RGB to BGR,\n    then will zero-center each color channel with respect to the ImageNet\n    dataset, without scaling.\n\n    Args:\n        include_top: whether to include the 3 fully-connected\n            layers at the 
top of the network.\n        weights: one of `None` (random initialization),\n            `\"imagenet\"` (pre-training on ImageNet),\n            or the path to the weights file to be loaded.\n        input_tensor: optional Keras tensor\n            (i.e. output of `layers.Input()`)\n            to use as image input for the model.\n        input_shape: optional shape tuple, only to be specified\n            if `include_top` is `False` (otherwise the input shape\n            has to be `(224, 224, 3)`\n            (with `\"channels_last\"` data format) or\n            `(3, 224, 224)` (with `\"channels_first\"` data format)).\n            It should have exactly 3 input channels,\n            and width and height should be no smaller than 32.\n            E.g. `(200, 200, 3)` would be one valid value.\n        pooling: Optional pooling mode for feature extraction\n            when `include_top` is `False`.\n            - `None` means that the output of the model will be\n                the 4D tensor output of the\n                last convolutional block.\n            - `avg` means that global average pooling\n                will be applied to the output of the\n                last convolutional block, and thus\n                the output of the model will be a 2D tensor.\n            - `max` means that global max pooling will\n                be applied.\n        classes: optional number of classes to classify images\n            into, only to be specified if `include_top` is `True`, and\n            if no `weights` argument is specified.\n        classifier_activation: A `str` or callable. The activation function to\n            use on the \"top\" layer. Ignored unless `include_top=True`. Set\n            `classifier_activation=None` to return the logits of the \"top\"\n            layer.  
When loading pretrained weights, `classifier_activation`\n            can only be `None` or `\"softmax\"`.\n        name: The name of the model (string).\n\n    Returns:\n        A `Model` instance.\n    \"\"\"\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), 'imagenet' \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded.  Received: \"\n            f\"weights={weights}\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            \"If using `weights='imagenet'` with `include_top=True`, \"\n            \"`classes` should be 1000.  \"\n            f\"Received classes={classes}\"\n        )\n\n    # Determine proper input shape\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=224,\n        min_size=32,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n    # Block 1\n    x = layers.Conv2D(\n        64, (3, 3), activation=\"relu\", padding=\"same\", name=\"block1_conv1\"\n    )(img_input)\n    x = layers.Conv2D(\n        64, (3, 3), activation=\"relu\", padding=\"same\", name=\"block1_conv2\"\n    )(x)\n    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name=\"block1_pool\")(x)\n\n    # Block 2\n    x = layers.Conv2D(\n        128, (3, 3), activation=\"relu\", padding=\"same\", name=\"block2_conv1\"\n    )(x)\n    x = layers.Conv2D(\n        128, (3, 3), activation=\"relu\", padding=\"same\", 
name=\"block2_conv2\"\n    )(x)\n    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name=\"block2_pool\")(x)\n\n    # Block 3\n    x = layers.Conv2D(\n        256, (3, 3), activation=\"relu\", padding=\"same\", name=\"block3_conv1\"\n    )(x)\n    x = layers.Conv2D(\n        256, (3, 3), activation=\"relu\", padding=\"same\", name=\"block3_conv2\"\n    )(x)\n    x = layers.Conv2D(\n        256, (3, 3), activation=\"relu\", padding=\"same\", name=\"block3_conv3\"\n    )(x)\n    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name=\"block3_pool\")(x)\n\n    # Block 4\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block4_conv1\"\n    )(x)\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block4_conv2\"\n    )(x)\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block4_conv3\"\n    )(x)\n    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name=\"block4_pool\")(x)\n\n    # Block 5\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block5_conv1\"\n    )(x)\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block5_conv2\"\n    )(x)\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block5_conv3\"\n    )(x)\n    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name=\"block5_pool\")(x)\n\n    if include_top:\n        # Classification block\n        x = layers.Flatten(name=\"flatten\")(x)\n        x = layers.Dense(4096, activation=\"relu\", name=\"fc1\")(x)\n        x = layers.Dense(4096, activation=\"relu\", name=\"fc2\")(x)\n\n        imagenet_utils.validate_activation(classifier_activation, weights)\n        x = layers.Dense(\n            classes, activation=classifier_activation, name=\"predictions\"\n        )(x)\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D()(x)\n        elif 
pooling == \"max\":\n            x = layers.GlobalMaxPooling2D()(x)\n\n    # Ensure that the model takes into account\n    # any potential predecessors of `input_tensor`.\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n\n    # Create model.\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if weights == \"imagenet\":\n        if include_top:\n            weights_path = file_utils.get_file(\n                \"vgg16_weights_tf_dim_ordering_tf_kernels.h5\",\n                WEIGHTS_PATH,\n                cache_subdir=\"models\",\n                file_hash=\"64373286793e3c8b2b4e3219cbf3544b\",\n            )\n        else:\n            weights_path = file_utils.get_file(\n                \"vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5\",\n                WEIGHTS_PATH_NO_TOP,\n                cache_subdir=\"models\",\n                file_hash=\"6d6bbae143d832006294945121d1f1fc\",\n            )\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n\n    return model\n\n\n@keras_export(\"keras.applications.vgg16.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    return imagenet_utils.preprocess_input(\n        x, data_format=data_format, mode=\"caffe\"\n    )\n\n\n@keras_export(\"keras.applications.vgg16.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\npreprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(\n    mode=\"\",\n    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE,\n    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,\n)\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n"
  },
  {
    "path": "keras/src/applications/vgg19.py",
    "content": "from keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nWEIGHTS_PATH = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/\"\n    \"vgg19/vgg19_weights_tf_dim_ordering_tf_kernels.h5\"\n)\nWEIGHTS_PATH_NO_TOP = (\n    \"https://storage.googleapis.com/tensorflow/\"\n    \"keras-applications/vgg19/\"\n    \"vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5\"\n)\n\n\n@keras_export([\"keras.applications.vgg19.VGG19\", \"keras.applications.VGG19\"])\ndef VGG19(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"vgg19\",\n):\n    \"\"\"Instantiates the VGG19 model.\n\n    Reference:\n    - [Very Deep Convolutional Networks for Large-Scale Image Recognition](\n    https://arxiv.org/abs/1409.1556) (ICLR 2015)\n\n    For image classification use cases, see\n    [this page for detailed examples](\n      https://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\n    For transfer learning use cases, make sure to read the\n    [guide to transfer learning & fine-tuning](\n      https://keras.io/guides/transfer_learning/).\n\n    The default input size for this model is 224x224.\n\n    Note: each Keras Application expects a specific kind of input preprocessing.\n    For VGG19, call `keras.applications.vgg19.preprocess_input` on your\n    inputs before passing them to the model.\n    `vgg19.preprocess_input` will convert the input images from RGB to BGR,\n    then will zero-center each color channel with respect to the ImageNet\n    dataset, without scaling.\n\n    Args:\n        include_top: whether to include the 3 fully-connected\n            layers at the 
top of the network.\n        weights: one of `None` (random initialization),\n            `\"imagenet\"` (pre-training on ImageNet),\n            or the path to the weights file to be loaded.\n        input_tensor: optional Keras tensor\n            (i.e. output of `layers.Input()`)\n            to use as image input for the model.\n        input_shape: optional shape tuple, only to be specified\n            if `include_top` is `False` (otherwise the input shape\n            has to be `(224, 224, 3)`\n            (with `\"channels_last\"` data format) or\n            `(3, 224, 224)` (with `\"channels_first\"` data format)).\n            It should have exactly 3 input channels,\n            and width and height should be no smaller than 32.\n            E.g. `(200, 200, 3)` would be one valid value.\n        pooling: Optional pooling mode for feature extraction\n            when `include_top` is `False`.\n            - `None` means that the output of the model will be\n                the 4D tensor output of the\n                last convolutional block.\n            - `avg` means that global average pooling\n                will be applied to the output of the\n                last convolutional block, and thus\n                the output of the model will be a 2D tensor.\n            - `max` means that global max pooling will\n                be applied.\n        classes: optional number of classes to classify images\n            into, only to be specified if `include_top` is `True`, and\n            if no `weights` argument is specified.\n        classifier_activation: A `str` or callable. The activation function to\n            use on the \"top\" layer. Ignored unless `include_top=True`. Set\n            `classifier_activation=None` to return the logits of the \"top\"\n            layer.  
When loading pretrained weights, `classifier_activation` can\n            only be `None` or `\"softmax\"`.\n        name: The name of the model (string).\n\n    Returns:\n        A model instance.\n    \"\"\"\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), 'imagenet' \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded.  Received: \"\n            f\"weights={weights}\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            \"If using `weights='imagenet'` with `include_top=True`, \"\n            \"`classes` should be 1000.  \"\n            f\"Received classes={classes}\"\n        )\n\n    # Determine proper input shape\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=224,\n        min_size=32,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n    # Block 1\n    x = layers.Conv2D(\n        64, (3, 3), activation=\"relu\", padding=\"same\", name=\"block1_conv1\"\n    )(img_input)\n    x = layers.Conv2D(\n        64, (3, 3), activation=\"relu\", padding=\"same\", name=\"block1_conv2\"\n    )(x)\n    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name=\"block1_pool\")(x)\n\n    # Block 2\n    x = layers.Conv2D(\n        128, (3, 3), activation=\"relu\", padding=\"same\", name=\"block2_conv1\"\n    )(x)\n    x = layers.Conv2D(\n        128, (3, 3), activation=\"relu\", padding=\"same\", 
name=\"block2_conv2\"\n    )(x)\n    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name=\"block2_pool\")(x)\n\n    # Block 3\n    x = layers.Conv2D(\n        256, (3, 3), activation=\"relu\", padding=\"same\", name=\"block3_conv1\"\n    )(x)\n    x = layers.Conv2D(\n        256, (3, 3), activation=\"relu\", padding=\"same\", name=\"block3_conv2\"\n    )(x)\n    x = layers.Conv2D(\n        256, (3, 3), activation=\"relu\", padding=\"same\", name=\"block3_conv3\"\n    )(x)\n    x = layers.Conv2D(\n        256, (3, 3), activation=\"relu\", padding=\"same\", name=\"block3_conv4\"\n    )(x)\n    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name=\"block3_pool\")(x)\n\n    # Block 4\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block4_conv1\"\n    )(x)\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block4_conv2\"\n    )(x)\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block4_conv3\"\n    )(x)\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block4_conv4\"\n    )(x)\n    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name=\"block4_pool\")(x)\n\n    # Block 5\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block5_conv1\"\n    )(x)\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block5_conv2\"\n    )(x)\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block5_conv3\"\n    )(x)\n    x = layers.Conv2D(\n        512, (3, 3), activation=\"relu\", padding=\"same\", name=\"block5_conv4\"\n    )(x)\n    x = layers.MaxPooling2D((2, 2), strides=(2, 2), name=\"block5_pool\")(x)\n\n    if include_top:\n        # Classification block\n        x = layers.Flatten(name=\"flatten\")(x)\n        x = layers.Dense(4096, activation=\"relu\", name=\"fc1\")(x)\n        x = 
layers.Dense(4096, activation=\"relu\", name=\"fc2\")(x)\n        imagenet_utils.validate_activation(classifier_activation, weights)\n        x = layers.Dense(\n            classes, activation=classifier_activation, name=\"predictions\"\n        )(x)\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D()(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D()(x)\n\n    # Ensure that the model takes into account\n    # any potential predecessors of `input_tensor`.\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n\n    # Create model.\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if weights == \"imagenet\":\n        if include_top:\n            weights_path = file_utils.get_file(\n                \"vgg19_weights_tf_dim_ordering_tf_kernels.h5\",\n                WEIGHTS_PATH,\n                cache_subdir=\"models\",\n                file_hash=\"cbe5617147190e668d6c5d5026f83318\",\n            )\n        else:\n            weights_path = file_utils.get_file(\n                \"vgg19_weights_tf_dim_ordering_tf_kernels_notop.h5\",\n                WEIGHTS_PATH_NO_TOP,\n                cache_subdir=\"models\",\n                file_hash=\"253f8cb515780f3b799900260a226db6\",\n            )\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n\n    return model\n\n\n@keras_export(\"keras.applications.vgg19.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    return imagenet_utils.preprocess_input(\n        x, data_format=data_format, mode=\"caffe\"\n    )\n\n\n@keras_export(\"keras.applications.vgg19.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\npreprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(\n    mode=\"\",\n    
ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_CAFFE,\n    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,\n)\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n"
  },
  {
    "path": "keras/src/applications/xception.py",
    "content": "from keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.applications import imagenet_utils\nfrom keras.src.models import Functional\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import file_utils\n\nWEIGHTS_PATH = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/\"\n    \"xception/xception_weights_tf_dim_ordering_tf_kernels.h5\"\n)\nWEIGHTS_PATH_NO_TOP = (\n    \"https://storage.googleapis.com/tensorflow/keras-applications/\"\n    \"xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5\"\n)\n\n\n@keras_export(\n    [\n        \"keras.applications.xception.Xception\",\n        \"keras.applications.Xception\",\n    ]\n)\ndef Xception(\n    include_top=True,\n    weights=\"imagenet\",\n    input_tensor=None,\n    input_shape=None,\n    pooling=None,\n    classes=1000,\n    classifier_activation=\"softmax\",\n    name=\"xception\",\n):\n    \"\"\"Instantiates the Xception architecture.\n\n    Reference:\n    - [Xception: Deep Learning with Depthwise Separable Convolutions](\n        https://arxiv.org/abs/1610.02357) (CVPR 2017)\n\n    For image classification use cases, see\n    [this page for detailed examples](\n      https://keras.io/api/applications/#usage-examples-for-image-classification-models).\n\n    For transfer learning use cases, make sure to read the\n    [guide to transfer learning & fine-tuning](\n      https://keras.io/guides/transfer_learning/).\n\n    The default input image size for this model is 299x299.\n\n    Note: each Keras Application expects a specific kind of input preprocessing.\n    For Xception, call `keras.applications.xception.preprocess_input`\n    on your inputs before passing them to the model.\n    `xception.preprocess_input` will scale input pixels between -1 and 1.\n\n    Args:\n        include_top: whether to include the fully-connected\n            layer at the top of the network.\n        weights: 
one of `None` (random initialization),\n            `\"imagenet\"` (pre-training on ImageNet),\n            or the path to the weights file to be loaded.\n        input_tensor: optional Keras tensor\n            (i.e. output of `layers.Input()`)\n            to use as image input for the model.\n        input_shape: optional shape tuple, only to be specified\n            if `include_top` is `False` (otherwise the input shape\n            has to be `(299, 299, 3)`).\n            It should have exactly 3 input channels,\n            and width and height should be no smaller than 71.\n            E.g. `(150, 150, 3)` would be one valid value.\n        pooling: Optional pooling mode for feature extraction\n            when `include_top` is `False`.\n            - `None` means that the output of the model will be\n                the 4D tensor output of the\n                last convolutional block.\n            - `avg` means that global average pooling\n                will be applied to the output of the\n                last convolutional block, and thus\n                the output of the model will be a 2D tensor.\n            - `max` means that global max pooling will\n                be applied.\n        classes: optional number of classes to classify images\n            into, only to be specified if `include_top` is `True`, and\n            if no `weights` argument is specified. Defaults to `1000`.\n        classifier_activation: A `str` or callable. The activation function to\n            use on the \"top\" layer. Ignored unless `include_top=True`. Set\n            `classifier_activation=None` to return the logits of the \"top\"\n            layer.  
When loading pretrained weights, `classifier_activation` can\n            only be `None` or `\"softmax\"`.\n        name: The name of the model (string).\n\n    Returns:\n        A model instance.\n    \"\"\"\n    if not (weights in {\"imagenet\", None} or file_utils.exists(weights)):\n        raise ValueError(\n            \"The `weights` argument should be either \"\n            \"`None` (random initialization), 'imagenet' \"\n            \"(pre-training on ImageNet), \"\n            \"or the path to the weights file to be loaded.\"\n        )\n\n    if weights == \"imagenet\" and include_top and classes != 1000:\n        raise ValueError(\n            \"If using `weights='imagenet'` with `include_top=True`, \"\n            \"`classes` should be 1000.  \"\n            f\"Received classes={classes}\"\n        )\n\n    # Determine proper input shape\n    input_shape = imagenet_utils.obtain_input_shape(\n        input_shape,\n        default_size=299,\n        min_size=71,\n        data_format=backend.image_data_format(),\n        require_flatten=include_top,\n        weights=weights,\n    )\n\n    if input_tensor is None:\n        img_input = layers.Input(shape=input_shape)\n    else:\n        if not backend.is_keras_tensor(input_tensor):\n            img_input = layers.Input(tensor=input_tensor, shape=input_shape)\n        else:\n            img_input = input_tensor\n\n    channel_axis = 1 if backend.image_data_format() == \"channels_first\" else -1\n\n    x = layers.Conv2D(\n        32, (3, 3), strides=(2, 2), use_bias=False, name=\"block1_conv1\"\n    )(img_input)\n    x = layers.BatchNormalization(axis=channel_axis, name=\"block1_conv1_bn\")(x)\n    x = layers.Activation(\"relu\", name=\"block1_conv1_act\")(x)\n    x = layers.Conv2D(64, (3, 3), use_bias=False, name=\"block1_conv2\")(x)\n    x = layers.BatchNormalization(axis=channel_axis, name=\"block1_conv2_bn\")(x)\n    x = layers.Activation(\"relu\", name=\"block1_conv2_act\")(x)\n\n    residual = 
layers.Conv2D(\n        128, (1, 1), strides=(2, 2), padding=\"same\", use_bias=False\n    )(x)\n    residual = layers.BatchNormalization(axis=channel_axis)(residual)\n\n    x = layers.SeparableConv2D(\n        128, (3, 3), padding=\"same\", use_bias=False, name=\"block2_sepconv1\"\n    )(x)\n    x = layers.BatchNormalization(axis=channel_axis, name=\"block2_sepconv1_bn\")(\n        x\n    )\n    x = layers.Activation(\"relu\", name=\"block2_sepconv2_act\")(x)\n    x = layers.SeparableConv2D(\n        128, (3, 3), padding=\"same\", use_bias=False, name=\"block2_sepconv2\"\n    )(x)\n    x = layers.BatchNormalization(axis=channel_axis, name=\"block2_sepconv2_bn\")(\n        x\n    )\n\n    x = layers.MaxPooling2D(\n        (3, 3), strides=(2, 2), padding=\"same\", name=\"block2_pool\"\n    )(x)\n    x = layers.add([x, residual])\n\n    residual = layers.Conv2D(\n        256, (1, 1), strides=(2, 2), padding=\"same\", use_bias=False\n    )(x)\n    residual = layers.BatchNormalization(axis=channel_axis)(residual)\n\n    x = layers.Activation(\"relu\", name=\"block3_sepconv1_act\")(x)\n    x = layers.SeparableConv2D(\n        256, (3, 3), padding=\"same\", use_bias=False, name=\"block3_sepconv1\"\n    )(x)\n    x = layers.BatchNormalization(axis=channel_axis, name=\"block3_sepconv1_bn\")(\n        x\n    )\n    x = layers.Activation(\"relu\", name=\"block3_sepconv2_act\")(x)\n    x = layers.SeparableConv2D(\n        256, (3, 3), padding=\"same\", use_bias=False, name=\"block3_sepconv2\"\n    )(x)\n    x = layers.BatchNormalization(axis=channel_axis, name=\"block3_sepconv2_bn\")(\n        x\n    )\n\n    x = layers.MaxPooling2D(\n        (3, 3), strides=(2, 2), padding=\"same\", name=\"block3_pool\"\n    )(x)\n    x = layers.add([x, residual])\n\n    residual = layers.Conv2D(\n        728, (1, 1), strides=(2, 2), padding=\"same\", use_bias=False\n    )(x)\n    residual = layers.BatchNormalization(axis=channel_axis)(residual)\n\n    x = layers.Activation(\"relu\", 
name=\"block4_sepconv1_act\")(x)\n    x = layers.SeparableConv2D(\n        728, (3, 3), padding=\"same\", use_bias=False, name=\"block4_sepconv1\"\n    )(x)\n    x = layers.BatchNormalization(axis=channel_axis, name=\"block4_sepconv1_bn\")(\n        x\n    )\n    x = layers.Activation(\"relu\", name=\"block4_sepconv2_act\")(x)\n    x = layers.SeparableConv2D(\n        728, (3, 3), padding=\"same\", use_bias=False, name=\"block4_sepconv2\"\n    )(x)\n    x = layers.BatchNormalization(axis=channel_axis, name=\"block4_sepconv2_bn\")(\n        x\n    )\n\n    x = layers.MaxPooling2D(\n        (3, 3), strides=(2, 2), padding=\"same\", name=\"block4_pool\"\n    )(x)\n    x = layers.add([x, residual])\n\n    for i in range(8):\n        residual = x\n        prefix = f\"block{i + 5}\"\n\n        x = layers.Activation(\"relu\", name=f\"{prefix}_sepconv1_act\")(x)\n        x = layers.SeparableConv2D(\n            728,\n            (3, 3),\n            padding=\"same\",\n            use_bias=False,\n            name=f\"{prefix}_sepconv1\",\n        )(x)\n        x = layers.BatchNormalization(\n            axis=channel_axis, name=f\"{prefix}_sepconv1_bn\"\n        )(x)\n        x = layers.Activation(\"relu\", name=f\"{prefix}_sepconv2_act\")(x)\n        x = layers.SeparableConv2D(\n            728,\n            (3, 3),\n            padding=\"same\",\n            use_bias=False,\n            name=f\"{prefix}_sepconv2\",\n        )(x)\n        x = layers.BatchNormalization(\n            axis=channel_axis, name=f\"{prefix}_sepconv2_bn\"\n        )(x)\n        x = layers.Activation(\"relu\", name=f\"{prefix}_sepconv3_act\")(x)\n        x = layers.SeparableConv2D(\n            728,\n            (3, 3),\n            padding=\"same\",\n            use_bias=False,\n            name=f\"{prefix}_sepconv3\",\n        )(x)\n        x = layers.BatchNormalization(\n            axis=channel_axis, name=f\"{prefix}_sepconv3_bn\"\n        )(x)\n\n        x = layers.add([x, residual])\n\n    
residual = layers.Conv2D(\n        1024, (1, 1), strides=(2, 2), padding=\"same\", use_bias=False\n    )(x)\n    residual = layers.BatchNormalization(axis=channel_axis)(residual)\n\n    x = layers.Activation(\"relu\", name=\"block13_sepconv1_act\")(x)\n    x = layers.SeparableConv2D(\n        728, (3, 3), padding=\"same\", use_bias=False, name=\"block13_sepconv1\"\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis, name=\"block13_sepconv1_bn\"\n    )(x)\n    x = layers.Activation(\"relu\", name=\"block13_sepconv2_act\")(x)\n    x = layers.SeparableConv2D(\n        1024, (3, 3), padding=\"same\", use_bias=False, name=\"block13_sepconv2\"\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis, name=\"block13_sepconv2_bn\"\n    )(x)\n\n    x = layers.MaxPooling2D(\n        (3, 3), strides=(2, 2), padding=\"same\", name=\"block13_pool\"\n    )(x)\n    x = layers.add([x, residual])\n\n    x = layers.SeparableConv2D(\n        1536, (3, 3), padding=\"same\", use_bias=False, name=\"block14_sepconv1\"\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis, name=\"block14_sepconv1_bn\"\n    )(x)\n    x = layers.Activation(\"relu\", name=\"block14_sepconv1_act\")(x)\n\n    x = layers.SeparableConv2D(\n        2048, (3, 3), padding=\"same\", use_bias=False, name=\"block14_sepconv2\"\n    )(x)\n    x = layers.BatchNormalization(\n        axis=channel_axis, name=\"block14_sepconv2_bn\"\n    )(x)\n    x = layers.Activation(\"relu\", name=\"block14_sepconv2_act\")(x)\n\n    if include_top:\n        x = layers.GlobalAveragePooling2D(name=\"avg_pool\")(x)\n        imagenet_utils.validate_activation(classifier_activation, weights)\n        x = layers.Dense(\n            classes, activation=classifier_activation, name=\"predictions\"\n        )(x)\n    else:\n        if pooling == \"avg\":\n            x = layers.GlobalAveragePooling2D()(x)\n        elif pooling == \"max\":\n            x = layers.GlobalMaxPooling2D()(x)\n\n    
# Ensure that the model takes into account\n    # any potential predecessors of `input_tensor`.\n    if input_tensor is not None:\n        inputs = operation_utils.get_source_inputs(input_tensor)\n    else:\n        inputs = img_input\n    # Create model.\n    model = Functional(inputs, x, name=name)\n\n    # Load weights.\n    if weights == \"imagenet\":\n        if include_top:\n            weights_path = file_utils.get_file(\n                \"xception_weights_tf_dim_ordering_tf_kernels.h5\",\n                WEIGHTS_PATH,\n                cache_subdir=\"models\",\n                file_hash=\"0a58e3b7378bc2990ea3b43d5981f1f6\",\n            )\n        else:\n            weights_path = file_utils.get_file(\n                \"xception_weights_tf_dim_ordering_tf_kernels_notop.h5\",\n                WEIGHTS_PATH_NO_TOP,\n                cache_subdir=\"models\",\n                file_hash=\"b0042744bf5b25fce3cb969f33bebb97\",\n            )\n        model.load_weights(weights_path)\n    elif weights is not None:\n        model.load_weights(weights)\n\n    return model\n\n\n@keras_export(\"keras.applications.xception.preprocess_input\")\ndef preprocess_input(x, data_format=None):\n    return imagenet_utils.preprocess_input(\n        x, data_format=data_format, mode=\"tf\"\n    )\n\n\n@keras_export(\"keras.applications.xception.decode_predictions\")\ndef decode_predictions(preds, top=5):\n    return imagenet_utils.decode_predictions(preds, top=top)\n\n\npreprocess_input.__doc__ = imagenet_utils.PREPROCESS_INPUT_DOC.format(\n    mode=\"\",\n    ret=imagenet_utils.PREPROCESS_INPUT_RET_DOC_TF,\n    error=imagenet_utils.PREPROCESS_INPUT_ERROR_DOC,\n)\ndecode_predictions.__doc__ = imagenet_utils.decode_predictions.__doc__\n"
  },
  {
    "path": "keras/src/backend/__init__.py",
    "content": "from keras.src.backend.config import backend\n\nif backend() == \"torch\":\n    # When using the torch backend,\n    # torch needs to be imported first, otherwise it will segfault\n    # upon import.\n    import torch\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common.dtypes import result_type\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.backend.common.keras_tensor import any_symbolic_tensors\nfrom keras.src.backend.common.keras_tensor import is_keras_tensor\nfrom keras.src.backend.common.masking import get_keras_mask\nfrom keras.src.backend.common.masking import set_keras_mask\nfrom keras.src.backend.common.stateless_scope import StatelessScope\nfrom keras.src.backend.common.stateless_scope import get_stateless_scope\nfrom keras.src.backend.common.stateless_scope import in_stateless_scope\nfrom keras.src.backend.common.symbolic_scope import SymbolicScope\nfrom keras.src.backend.common.symbolic_scope import in_symbolic_scope\nfrom keras.src.backend.common.variables import AutocastScope\nfrom keras.src.backend.common.variables import Variable\nfrom keras.src.backend.common.variables import get_autocast_scope\nfrom keras.src.backend.common.variables import is_float_dtype\nfrom keras.src.backend.common.variables import is_int_dtype\nfrom keras.src.backend.common.variables import standardize_dtype\nfrom keras.src.backend.common.variables import standardize_shape\nfrom keras.src.backend.config import epsilon\nfrom keras.src.backend.config import floatx\nfrom keras.src.backend.config import image_data_format\nfrom keras.src.backend.config import set_epsilon\nfrom keras.src.backend.config import set_floatx\nfrom keras.src.backend.config import set_image_data_format\nfrom keras.src.backend.config import standardize_data_format\n\n# Import backend functions.\nif backend() == \"tensorflow\":\n    from keras.src.backend.tensorflow import *  # noqa: F403\n    from keras.src.backend.tensorflow.core 
import Variable as BackendVariable\nelif backend() == \"jax\":\n    from keras.src.backend.jax import *  # noqa: F403\n    from keras.src.backend.jax.core import Variable as BackendVariable\nelif backend() == \"torch\":\n    from keras.src.backend.torch import *  # noqa: F403\n    from keras.src.backend.torch.core import Variable as BackendVariable\n\n    distribution_lib = None\nelif backend() == \"numpy\":\n    from keras.src.backend.numpy import *  # noqa: F403\n    from keras.src.backend.numpy.core import Variable as BackendVariable\n\n    distribution_lib = None\nelif backend() == \"openvino\":\n    from keras.src.backend.openvino import *  # noqa: F403\n    from keras.src.backend.openvino.core import Variable as BackendVariable\n\n    distribution_lib = None\nelse:\n    raise ValueError(f\"Unable to import backend : {backend()}\")\n\n\n@keras_export(\"keras.Variable\")\nclass Variable(BackendVariable):  # noqa: F811\n    pass\n\n\nbackend_name_scope = name_scope  # noqa: F405\n\n\n@keras_export(\"keras.name_scope\")\nclass name_scope(backend_name_scope):\n    pass\n\n\n@keras_export(\"keras.device\")\ndef device(device_name):\n    return device_scope(device_name)  # noqa: F405\n"
  },
  {
    "path": "keras/src/backend/common/__init__.py",
    "content": "from keras.src.backend.common import backend_utils\nfrom keras.src.backend.common.dtypes import result_type\nfrom keras.src.backend.common.variables import AutocastScope\nfrom keras.src.backend.common.variables import Variable as KerasVariable\nfrom keras.src.backend.common.variables import get_autocast_scope\nfrom keras.src.backend.common.variables import is_float_dtype\nfrom keras.src.backend.common.variables import is_int_dtype\nfrom keras.src.backend.common.variables import standardize_dtype\nfrom keras.src.backend.common.variables import standardize_shape\nfrom keras.src.random import random\n"
  },
  {
    "path": "keras/src/backend/common/backend_utils.py",
    "content": "import functools\nimport math\nimport operator\nimport re\nimport warnings\n\n\ndef _convert_conv_transpose_padding_args_from_keras_to_jax(\n    kernel_size, stride, dilation_rate, padding, output_padding\n):\n    \"\"\"Convert the padding arguments from Keras to the ones used by JAX.\n    JAX starts with an shape of size `(input-1) * stride - kernel_size + 2`,\n    then adds `left_pad` on the left, and `right_pad` on the right.\n    In Keras, the `padding` argument determines a base shape, to which\n    `output_padding` is added on the right. If `output_padding` is None, it will\n    be given a default value.\n    \"\"\"\n\n    if padding.lower() not in {\"valid\", \"same\"}:\n        raise ValueError(\n            f\"The `padding` argument must be one of 'valid', 'same'. \"\n            f\"Received: padding={padding}\"\n        )\n    kernel_size = (kernel_size - 1) * dilation_rate + 1\n\n    if padding.lower() == \"valid\":\n        # If output_padding is None, we fill it so that the shape of the output\n        # is `(input-1)*s + max(kernel_size, stride)`\n        output_padding = (\n            max(kernel_size, stride) - kernel_size\n            if output_padding is None\n            else output_padding\n        )\n        left_pad = kernel_size - 1\n        right_pad = kernel_size - 1 + output_padding\n\n    else:\n        if output_padding is None:\n            # When output_padding is None, we want the shape of the output to\n            # be `input * s`, therefore a total padding of\n            # `stride + kernel_size - 2`\n            pad_len = stride + kernel_size - 2\n        else:\n            # When output_padding is filled, we want the shape of the output to\n            # be `(input-1)*stride + kernel_size%2 + output_padding`\n            pad_len = kernel_size + kernel_size % 2 - 2 + output_padding\n        left_pad = min(pad_len // 2 + pad_len % 2, kernel_size - 1)\n        right_pad = pad_len - left_pad\n\n    return left_pad, 
right_pad\n\n\ndef _convert_conv_transpose_padding_args_from_keras_to_torch(\n    kernel_size, stride, dilation_rate, padding, output_padding\n):\n    \"\"\"Convert the padding arguments from Keras to the ones used by Torch.\n    Torch starts with an output shape of `(input-1) * stride + kernel_size`,\n    then removes `torch_padding` from both sides, and adds\n    `torch_output_padding` on the right.\n    Because in Torch the output_padding can only be added to the right,\n    consistency with Tensorflow is not always possible. In particular this is\n    the case when both the Torch padding and output_padding values are\n    strictly positive.\n    \"\"\"\n    if padding.lower() not in {\"valid\", \"same\"}:\n        raise ValueError(\n            f\"The `padding` argument must be one of 'valid', 'same'. \"\n            f\"Received: padding={padding}\"\n        )\n    original_kernel_size = kernel_size\n    kernel_size = (kernel_size - 1) * dilation_rate + 1\n\n    if padding.lower() == \"valid\":\n        # If output_padding is None, we fill it so that the shape of the output\n        # is `(i-1)*s + max(k, s)`\n        output_padding = (\n            max(kernel_size, stride) - kernel_size\n            if output_padding is None\n            else output_padding\n        )\n        torch_padding = 0\n        torch_output_padding = output_padding\n\n    else:\n        # When output_padding is None, we want the shape of the output to be\n        # `input * s`, otherwise we use the value provided.\n        output_padding = (\n            stride - kernel_size % 2\n            if output_padding is None\n            else output_padding\n        )\n        torch_padding = max(\n            -((kernel_size % 2 - kernel_size + output_padding) // 2), 0\n        )\n        torch_output_padding = (\n            2 * torch_padding + kernel_size % 2 - kernel_size + output_padding\n        )\n\n    if torch_padding > 0 and torch_output_padding > 0:\n        warnings.warn(\n         
   f\"You might experience inconsistencies across backends when \"\n            f\"calling conv transpose with kernel_size={original_kernel_size}, \"\n            f\"stride={stride}, dilation_rate={dilation_rate}, \"\n            f\"padding={padding}, output_padding={output_padding}.\"\n        )\n\n    if torch_output_padding >= stride:\n        warnings.warn(\n            f\"Torch backend requires output_padding < stride. \"\n            f\"Clamping output_padding {torch_output_padding} -> {stride - 1} \"\n            f\"for stride {stride}.\",\n            UserWarning,\n        )\n        torch_output_padding = stride - 1\n\n    return torch_padding, torch_output_padding\n\n\ndef compute_conv_transpose_padding_args_for_jax(\n    input_shape,\n    kernel_shape,\n    strides,\n    padding,\n    output_padding,\n    dilation_rate,\n):\n    num_spatial_dims = len(input_shape) - 2\n    kernel_spatial_shape = kernel_shape[:-2]\n\n    jax_padding = []\n    for i in range(num_spatial_dims):\n        output_padding_i = (\n            output_padding\n            if output_padding is None or isinstance(output_padding, int)\n            else output_padding[i]\n        )\n        strides_i = strides if isinstance(strides, int) else strides[i]\n        dilation_rate_i = (\n            dilation_rate\n            if isinstance(dilation_rate, int)\n            else dilation_rate[i]\n        )\n        (\n            pad_left,\n            pad_right,\n        ) = _convert_conv_transpose_padding_args_from_keras_to_jax(\n            kernel_size=kernel_spatial_shape[i],\n            stride=strides_i,\n            dilation_rate=dilation_rate_i,\n            padding=padding,\n            output_padding=output_padding_i,\n        )\n        jax_padding.append((pad_left, pad_right))\n\n    return jax_padding\n\n\ndef compute_conv_transpose_padding_args_for_torch(\n    input_shape,\n    kernel_shape,\n    strides,\n    padding,\n    output_padding,\n    dilation_rate,\n):\n    
num_spatial_dims = len(input_shape) - 2\n    kernel_spatial_shape = kernel_shape[:-2]\n\n    torch_paddings = []\n    torch_output_paddings = []\n    for i in range(num_spatial_dims):\n        output_padding_i = (\n            output_padding\n            if output_padding is None or isinstance(output_padding, int)\n            else output_padding[i]\n        )\n        strides_i = strides if isinstance(strides, int) else strides[i]\n        dilation_rate_i = (\n            dilation_rate\n            if isinstance(dilation_rate, int)\n            else dilation_rate[i]\n        )\n        (\n            torch_padding,\n            torch_output_padding,\n        ) = _convert_conv_transpose_padding_args_from_keras_to_torch(\n            kernel_size=kernel_spatial_shape[i],\n            stride=strides_i,\n            dilation_rate=dilation_rate_i,\n            padding=padding,\n            output_padding=output_padding_i,\n        )\n        torch_paddings.append(torch_padding)\n        torch_output_paddings.append(torch_output_padding)\n\n    # --- FIX FOR TORCH CONSTRAINT: output_padding < stride ---\n    corrected_output_paddings = []\n    for s, op in zip(\n        strides\n        if isinstance(strides, (list, tuple))\n        else [strides] * num_spatial_dims,\n        torch_output_paddings,\n    ):\n        max_allowed = max(0, s - 1)\n        if op > max_allowed:\n            corrected_output_paddings.append(max_allowed)\n        else:\n            corrected_output_paddings.append(op)\n\n    torch_output_paddings = corrected_output_paddings\n\n    return torch_paddings, torch_output_paddings\n\n\ndef _get_output_shape_given_tf_padding(\n    input_size, kernel_size, strides, padding, output_padding, dilation_rate\n):\n    if input_size is None:\n        return None\n\n    if padding.lower() not in {\"valid\", \"same\"}:\n        raise ValueError(\n            f\"The `padding` argument must be one of 'valid', 'same'. 
\"\n            f\"Received: padding={padding}\"\n        )\n\n    kernel_size = (kernel_size - 1) * dilation_rate + 1\n\n    if padding.lower() == \"valid\":\n        output_padding = (\n            max(kernel_size, strides) - kernel_size\n            if output_padding is None\n            else output_padding\n        )\n        return (input_size - 1) * strides + kernel_size + output_padding\n\n    else:\n        if output_padding is None:\n            return input_size * strides\n        else:\n            return (input_size - 1) * strides + kernel_size % 2 + output_padding\n\n\ndef compute_conv_transpose_output_shape(\n    input_shape,\n    kernel_size,\n    filters,\n    strides,\n    padding,\n    output_padding=None,\n    data_format=\"channels_last\",\n    dilation_rate=1,\n):\n    num_spatial_dims = len(input_shape) - 2\n    kernel_spatial_shape = kernel_size\n\n    if isinstance(output_padding, int):\n        output_padding = (output_padding,) * len(kernel_spatial_shape)\n    if isinstance(strides, int):\n        strides = (strides,) * num_spatial_dims\n    if isinstance(dilation_rate, int):\n        dilation_rate = (dilation_rate,) * num_spatial_dims\n\n    if data_format == \"channels_last\":\n        input_spatial_shape = input_shape[1:-1]\n    else:\n        input_spatial_shape = input_shape[2:]\n\n    output_shape = []\n    for i in range(num_spatial_dims):\n        current_output_padding = (\n            None if output_padding is None else output_padding[i]\n        )\n\n        shape_i = _get_output_shape_given_tf_padding(\n            input_size=input_spatial_shape[i],\n            kernel_size=kernel_spatial_shape[i],\n            strides=strides[i],\n            padding=padding,\n            output_padding=current_output_padding,\n            dilation_rate=dilation_rate[i],\n        )\n        output_shape.append(shape_i)\n\n    if data_format == \"channels_last\":\n        output_shape = [input_shape[0]] + output_shape + [filters]\n    else:\n   
     output_shape = [input_shape[0], filters] + output_shape\n    return output_shape\n\n\ndef canonicalize_axis(axis, num_dims):\n    \"\"\"Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).\"\"\"\n    axis = operator.index(axis)\n    if not -num_dims <= axis < num_dims:\n        raise ValueError(\n            f\"axis {axis} is out of bounds for an array with dimension \"\n            f\"{num_dims}.\"\n        )\n    if axis < 0:\n        axis = axis + num_dims\n    return axis\n\n\ndef standardize_axis_for_numpy(axis):\n    \"\"\"Standardize an axis to a tuple if it is a list in the numpy backend.\"\"\"\n    return tuple(axis) if isinstance(axis, list) else axis\n\n\ndef to_tuple_or_list(value):\n    \"\"\"Convert the non-`None` value to either a tuple or a list.\"\"\"\n    if value is None:\n        return value\n    if not isinstance(value, (int, tuple, list)):\n        raise ValueError(\n            \"`value` must be an integer, tuple or list. \"\n            f\"Received: value={value}\"\n        )\n    if isinstance(value, int):\n        return (value,)\n    return value\n\n\n### Code for ops.vectorize() used for TF and torch backends.\n\n# See http://docs.scipy.org/doc/numpy/reference/c-api.generalized-ufuncs.html\n_DIMENSION_NAME = r\"\\w+\"\n_CORE_DIMENSION_LIST = \"(?:{0:}(?:,{0:})*)?\".format(_DIMENSION_NAME)\n_ARGUMENT = rf\"\\({_CORE_DIMENSION_LIST}\\)\"\n_ARGUMENT_LIST = \"{0:}(?:,{0:})*\".format(_ARGUMENT)\n_SIGNATURE = \"^{0:}->{0:}$\".format(_ARGUMENT_LIST)\n\n\ndef _vectorize_parse_gufunc_signature(\n    signature,\n):\n    if not re.match(_SIGNATURE, signature):\n        raise ValueError(f\"not a valid gufunc signature: {signature}\")\n    args, retvals = (\n        [\n            tuple(re.findall(_DIMENSION_NAME, arg))\n            for arg in re.findall(_ARGUMENT, arg_list)\n        ]\n        for arg_list in signature.split(\"->\")\n    )\n    return args, retvals\n\n\ndef _vectorize_update_dim_sizes(dim_sizes, shape, core_dims, 
is_input=True):\n    num_core_dims = len(core_dims)\n    if is_input:\n        if len(shape) < num_core_dims:\n            raise ValueError(\n                f\"input with shape {shape} does not \"\n                \"have enough dimensions for all core \"\n                f\"dimensions {core_dims}\"\n            )\n    else:\n        if len(shape) != num_core_dims:\n            raise ValueError(\n                f\"output shape {shape} does not \"\n                f\"match core dimensions {core_dims}\"\n            )\n\n    core_shape = shape[-num_core_dims:] if core_dims else ()\n    for dim, size in zip(core_dims, core_shape):\n        if dim not in dim_sizes:\n            dim_sizes[dim] = size\n        elif size != dim_sizes[dim]:\n            raise ValueError(\n                f\"inconsistent size for core dimension {dim}: \"\n                f\"{size} vs {dim_sizes[dim]}\"\n            )\n\n\ndef _vectorize_parse_input_dimensions(\n    args,\n    input_core_dims,\n):\n    from keras.src import ops\n\n    if len(args) != len(input_core_dims):\n        raise TypeError(\n            \"wrong number of positional arguments: \"\n            f\"expected {len(input_core_dims)}, got {len(args)}\"\n        )\n    shapes = []\n    dim_sizes = {}\n    for arg, core_dims in zip(args, input_core_dims):\n        _vectorize_update_dim_sizes(\n            dim_sizes, arg.shape, core_dims, is_input=True\n        )\n        ndim = arg.ndim - len(core_dims)\n        shapes.append(arg.shape[:ndim])\n    broadcast_shape = shapes[0]\n    for s in shapes:\n        broadcast_shape = ops.broadcast_shapes(broadcast_shape, s)\n    return broadcast_shape, dim_sizes\n\n\ndef _vectorize_check_output_dims(\n    func,\n    dim_sizes,\n    expected_output_core_dims,\n):\n    from keras.src import ops\n\n    def wrapped(*args):\n        out = func(*args)\n        if isinstance(out, (list, tuple)):\n            out_shapes = [ops.shape(x) for x in out]\n        else:\n            out_shapes = 
[out.shape]\n\n        if expected_output_core_dims is None:\n            output_core_dims = [()] * len(out_shapes)\n        else:\n            output_core_dims = expected_output_core_dims\n            if len(output_core_dims) > 1 and not isinstance(out, tuple):\n                raise TypeError(\n                    \"output must be a tuple when multiple outputs \"\n                    f\"are expected, got: {out}\"\n                )\n            if len(out_shapes) != len(output_core_dims):\n                raise TypeError(\n                    \"wrong number of output arguments: \"\n                    f\"expected {len(output_core_dims)}, got {len(out_shapes)}\"\n                )\n\n        sizes = dict(dim_sizes)\n        for shape, core_dims in zip(out_shapes, output_core_dims):\n            _vectorize_update_dim_sizes(sizes, shape, core_dims, is_input=False)\n\n        return out\n\n    return wrapped\n\n\ndef _vectorize_apply_excluded(func, excluded, args, kwargs):\n    if not excluded:\n        return func, args, kwargs\n\n    dynamic_args = [arg for i, arg in enumerate(args) if i not in excluded]\n    dynamic_kwargs = {\n        key: val for key, val in kwargs.items() if key not in excluded\n    }\n    static_args = [\n        (i, args[i])\n        for i in sorted(e for e in excluded if isinstance(e, int))\n        if i < len(args)\n    ]\n    static_kwargs = {key: val for key, val in kwargs.items() if key in excluded}\n\n    def new_func(*args, **kwargs):\n        args = list(args)\n        for i, arg in static_args:\n            args.insert(i, arg)\n        return func(*args, **kwargs, **static_kwargs)\n\n    return new_func, dynamic_args, dynamic_kwargs\n\n\ndef vectorize_impl(pyfunc, vmap_fn, *, excluded=None, signature=None):\n    \"\"\"Implementation adapted from JAX and NumPy.\"\"\"\n\n    from keras.src import ops\n\n    # Default to an empty set when no arguments are excluded. (`None or set()`\n    # would silently discard a user-provided `excluded`.)\n    excluded = excluded or set()\n\n    @functools.wraps(pyfunc)\n    def wrapped(*args, **kwargs):\n        excluded_func, args, kwargs = 
_vectorize_apply_excluded(\n            pyfunc, excluded, args, kwargs\n        )\n\n        if signature is not None:\n            input_core_dims, output_core_dims = (\n                _vectorize_parse_gufunc_signature(signature)\n            )\n        else:\n            input_core_dims = [()] * len(args)\n            output_core_dims = None\n\n        none_args = {i for i, arg in enumerate(args) if arg is None}\n        # Check set emptiness directly; `any(none_args)` would miss a `None`\n        # passed at position 0 since its index is falsy.\n        if none_args:\n            if any(input_core_dims[i] != () for i in none_args):\n                raise ValueError(\n                    f\"Cannot pass None at locations {none_args} \"\n                    f\"with signature={signature}\"\n                )\n            excluded_func, args, _ = _vectorize_apply_excluded(\n                excluded_func, none_args, args, {}\n            )\n            input_core_dims = [\n                dim\n                for i, dim in enumerate(input_core_dims)\n                if i not in none_args\n            ]\n\n        args = tuple(map(ops.convert_to_tensor, args))\n\n        broadcast_shape, dim_sizes = _vectorize_parse_input_dimensions(\n            args, input_core_dims\n        )\n        checked_func = _vectorize_check_output_dims(\n            excluded_func, dim_sizes, output_core_dims\n        )\n        squeezed_args = []\n        rev_filled_shapes = []\n        for arg, core_dims in zip(args, input_core_dims):\n            noncore_shape = arg.shape[: arg.ndim - len(core_dims)]\n\n            pad_ndim = len(broadcast_shape) - len(noncore_shape)\n            filled_shape = pad_ndim * (1,) + noncore_shape\n            rev_filled_shapes.append(filled_shape[::-1])\n\n            squeeze_indices = tuple(\n                i for i, size in enumerate(noncore_shape) if size == 1\n            )\n            squeezed_arg = ops.squeeze(arg, axis=squeeze_indices)\n            squeezed_args.append(squeezed_arg)\n\n        vectorized_func = checked_func\n        dims_to_expand = []\n        for negdim, 
axis_sizes in enumerate(zip(*rev_filled_shapes)):\n            in_axes = tuple(None if size == 1 else 0 for size in axis_sizes)\n            if all(axis is None for axis in in_axes):\n                dims_to_expand.append(len(broadcast_shape) - 1 - negdim)\n            else:\n                vectorized_func = vmap_fn(vectorized_func, in_axes)\n        result = vectorized_func(*squeezed_args)\n\n        if not dims_to_expand:\n            return result\n        elif isinstance(result, tuple):\n            return tuple(\n                ops.expand_dims(r, axis=dims_to_expand) for r in result\n            )\n        else:\n            return ops.expand_dims(result, axis=dims_to_expand)\n\n    return wrapped\n\n\ndef slice_along_axis(x, start=0, stop=None, step=1, axis=0):\n    \"\"\"Slice a Tensor along the given axis.\"\"\"\n    # Ref: same util function defined in tfp.math.scan_associative\n    if axis >= 0:\n        slices = [slice(None)] * axis + [slice(start, stop, step)]\n    else:\n        slices = [Ellipsis, slice(start, stop, step)] + [slice(None)] * (\n            -1 - axis\n        )\n    return x[tuple(slices)]\n\n\ndef compute_adaptive_pooling_window_sizes(input_dim, output_dim):\n    \"\"\"Compute small and big window sizes for adaptive pooling.\"\"\"\n    small = math.ceil(input_dim / output_dim)\n    big = small + 1\n    return small, big\n"
  },
  {
    "path": "keras/src/backend/common/backend_utils_test.py",
    "content": "from keras.src.backend.common.backend_utils import (\n    _convert_conv_transpose_padding_args_from_keras_to_jax,\n)\nfrom keras.src.backend.common.backend_utils import (\n    _convert_conv_transpose_padding_args_from_keras_to_torch,\n)\nfrom keras.src.backend.common.backend_utils import (\n    _get_output_shape_given_tf_padding,\n)\nfrom keras.src.backend.common.backend_utils import (\n    compute_conv_transpose_padding_args_for_jax,\n)\nfrom keras.src.backend.common.backend_utils import (\n    compute_conv_transpose_padding_args_for_torch,\n)\nfrom keras.src.testing import test_case\n\n\nclass ConvertConvTransposePaddingArgsJAXTest(test_case.TestCase):\n    def test_valid_padding_without_output_padding(self):\n        \"\"\"Test conversion with 'valid' padding and no output padding\"\"\"\n        (\n            left_pad,\n            right_pad,\n        ) = _convert_conv_transpose_padding_args_from_keras_to_jax(\n            kernel_size=3,\n            stride=2,\n            dilation_rate=1,\n            padding=\"valid\",\n            output_padding=None,\n        )\n        self.assertEqual(left_pad, 2)\n        self.assertEqual(right_pad, 2)\n\n    def test_same_padding_without_output_padding(self):\n        \"\"\"Test conversion with 'same' padding and no output padding.\"\"\"\n        (\n            left_pad,\n            right_pad,\n        ) = _convert_conv_transpose_padding_args_from_keras_to_jax(\n            kernel_size=3,\n            stride=2,\n            dilation_rate=1,\n            padding=\"same\",\n            output_padding=None,\n        )\n        self.assertEqual(left_pad, 2)\n        self.assertEqual(right_pad, 1)\n\n\nclass ConvertConvTransposePaddingArgsTorchTest(test_case.TestCase):\n    def test_valid_padding_without_output_padding(self):\n        \"\"\"Test conversion with 'valid' padding and no output padding\"\"\"\n        (\n            torch_padding,\n            torch_output_padding,\n        ) = 
_convert_conv_transpose_padding_args_from_keras_to_torch(\n            kernel_size=3,\n            stride=2,\n            dilation_rate=1,\n            padding=\"valid\",\n            output_padding=None,\n        )\n        self.assertEqual(torch_padding, 0)\n        self.assertEqual(torch_output_padding, 0)\n\n    def test_same_padding_without_output_padding(self):\n        \"\"\"Test conversion with 'same' padding and no output padding\"\"\"\n        (\n            torch_padding,\n            torch_output_padding,\n        ) = _convert_conv_transpose_padding_args_from_keras_to_torch(\n            kernel_size=3,\n            stride=2,\n            dilation_rate=1,\n            padding=\"same\",\n            output_padding=None,\n        )\n        self.assertEqual(torch_padding, 1)\n        self.assertEqual(torch_output_padding, 1)\n\n\nclass ComputeConvTransposePaddingArgsForJAXTest(test_case.TestCase):\n    def test_valid_padding_without_output_padding(self):\n        \"\"\"Test computation with 'valid' padding and no output padding\"\"\"\n        jax_padding = compute_conv_transpose_padding_args_for_jax(\n            input_shape=(1, 5, 5, 3),\n            kernel_shape=(3, 3, 3, 3),\n            strides=2,\n            padding=\"valid\",\n            output_padding=None,\n            dilation_rate=1,\n        )\n        self.assertEqual(jax_padding, [(2, 2), (2, 2)])\n\n    def test_same_padding_without_output_padding(self):\n        \"\"\"Test computation with 'same' padding and no output padding\"\"\"\n        jax_padding = compute_conv_transpose_padding_args_for_jax(\n            input_shape=(1, 5, 5, 3),\n            kernel_shape=(3, 3, 3, 3),\n            strides=2,\n            padding=\"same\",\n            output_padding=None,\n            dilation_rate=1,\n        )\n\n        self.assertEqual(jax_padding, [(2, 1), (2, 1)])\n\n\nclass ComputeConvTransposePaddingArgsForTorchTest(test_case.TestCase):\n    def 
test_valid_padding_without_output_padding(self):\n        \"\"\"Test computation with 'valid' padding and no output padding\"\"\"\n        (\n            torch_paddings,\n            torch_output_paddings,\n        ) = compute_conv_transpose_padding_args_for_torch(\n            input_shape=(1, 5, 5, 3),\n            kernel_shape=(3, 3, 3, 3),\n            strides=2,\n            padding=\"valid\",\n            output_padding=None,\n            dilation_rate=1,\n        )\n        self.assertEqual(torch_paddings, [0, 0])\n        self.assertEqual(torch_output_paddings, [0, 0])\n\n    def test_same_padding_without_output_padding(self):\n        \"\"\"Test computation with 'same' padding and no output padding\"\"\"\n        (\n            torch_paddings,\n            torch_output_paddings,\n        ) = compute_conv_transpose_padding_args_for_torch(\n            input_shape=(1, 5, 5, 3),\n            kernel_shape=(3, 3, 3, 3),\n            strides=2,\n            padding=\"same\",\n            output_padding=None,\n            dilation_rate=1,\n        )\n        self.assertEqual(torch_paddings, [1, 1])\n        self.assertEqual(torch_output_paddings, [1, 1])\n\n    def test_valid_padding_with_none_output_padding(self):\n        \"\"\"Test conversion with 'valid' padding and no output padding\"\"\"\n        (\n            torch_padding,\n            torch_output_padding,\n        ) = _convert_conv_transpose_padding_args_from_keras_to_torch(\n            kernel_size=3,\n            stride=2,\n            dilation_rate=1,\n            padding=\"valid\",\n            output_padding=None,\n        )\n        self.assertEqual(torch_padding, 0)\n        self.assertEqual(torch_output_padding, 0)\n\n    def test_valid_padding_with_output_padding(self):\n        \"\"\"Test conversion with 'valid' padding and output padding for Torch.\"\"\"\n        (\n            torch_padding,\n            torch_output_padding,\n        ) = 
_convert_conv_transpose_padding_args_from_keras_to_torch(\n            kernel_size=3,\n            stride=2,\n            dilation_rate=1,\n            padding=\"valid\",\n            output_padding=1,\n        )\n        self.assertEqual(torch_padding, 0)\n        self.assertEqual(torch_output_padding, 1)\n\n    def test_output_padding_clamped_for_torch_constraint(self):\n        \"\"\"Test that output_padding is clamped\n        when >= stride (Torch constraint).\n        \"\"\"\n        (\n            torch_paddings,\n            torch_output_paddings,\n        ) = compute_conv_transpose_padding_args_for_torch(\n            input_shape=(1, 8, 8, 8, 16),  # any shape\n            kernel_shape=(2, 2, 2, 16, 32),  # Keras kernel shape\n            strides=1,\n            padding=\"same\",\n            output_padding=1,  # Keras wants this\n            dilation_rate=1,\n        )\n        # Torch expects output_padding < stride (1)\n        # so output_padding should be clamped to 0\n        self.assertEqual(torch_output_paddings, [0, 0, 0])\n\n\nclass GetOutputShapeGivenTFPaddingTest(test_case.TestCase):\n    def test_valid_padding_without_output_padding(self):\n        \"\"\"Test computation with 'valid' padding and no output padding.\"\"\"\n        output_shape = _get_output_shape_given_tf_padding(\n            input_size=5,\n            kernel_size=3,\n            strides=2,\n            padding=\"valid\",\n            output_padding=None,\n            dilation_rate=1,\n        )\n        self.assertEqual(output_shape, 11)\n\n    def test_same_padding_without_output_padding(self):\n        \"\"\"Test computation with 'same' padding and no output padding.\"\"\"\n        output_shape = _get_output_shape_given_tf_padding(\n            input_size=5,\n            kernel_size=3,\n            strides=2,\n            padding=\"same\",\n            output_padding=None,\n            dilation_rate=1,\n        )\n        self.assertEqual(output_shape, 10)\n\n    def 
test_valid_padding_with_output_padding(self):\n        \"\"\"Test computation with 'valid' padding and output padding.\"\"\"\n        output_shape = _get_output_shape_given_tf_padding(\n            input_size=5,\n            kernel_size=3,\n            strides=2,\n            padding=\"valid\",\n            output_padding=1,\n            dilation_rate=1,\n        )\n        self.assertEqual(output_shape, 12)\n\n    def test_warning_for_inconsistencies(self):\n        \"\"\"Test that a warning is raised for potential inconsistencies\"\"\"\n        with self.assertWarns(Warning):\n            _convert_conv_transpose_padding_args_from_keras_to_torch(\n                kernel_size=3,\n                stride=2,\n                dilation_rate=1,\n                padding=\"same\",\n                output_padding=1,\n            )\n\n    def test_same_padding_without_output_padding_for_torch_(self):\n        \"\"\"Test conversion with 'same' padding and no output padding.\"\"\"\n        (\n            torch_padding,\n            torch_output_padding,\n        ) = _convert_conv_transpose_padding_args_from_keras_to_torch(\n            kernel_size=3,\n            stride=2,\n            dilation_rate=1,\n            padding=\"same\",\n            output_padding=None,\n        )\n        self.assertEqual(torch_padding, max(-((3 % 2 - 3) // 2), 0))\n        self.assertEqual(torch_output_padding, 1)\n"
  },
  {
    "path": "keras/src/backend/common/compute_output_spec_test.py",
    "content": "import pytest\n\nfrom keras.src import backend\nfrom keras.src import testing\n\n\ndef example_fn(x):\n    x = (x + 2) * backend.numpy.ones_like(x)\n    x = backend.numpy.stack([x, x], axis=-1)\n    return x\n\n\nclass ComputeOutputSpecTest(testing.TestCase):\n    def test_basics(self):\n        out = backend.compute_output_spec(\n            example_fn, backend.KerasTensor((2, 3))\n        )\n        self.assertIsInstance(out, backend.KerasTensor)\n        self.assertEqual(out.shape, (2, 3, 2))\n\n        out = backend.compute_output_spec(\n            example_fn, backend.KerasTensor((None, 3))\n        )\n        self.assertIsInstance(out, backend.KerasTensor)\n        self.assertEqual(out.shape, (None, 3, 2))\n\n        out = backend.compute_output_spec(\n            example_fn, backend.KerasTensor((2, None))\n        )\n        self.assertIsInstance(out, backend.KerasTensor)\n        self.assertEqual(out.shape, (2, None, 2))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"torch\", reason=\"Only applicable for torch\"\n    )\n    def test_torch_meta_device_incompatible_ops(self):\n        class Container:\n            def __init__(self):\n                self.canary = False\n\n            def example_meta_fn(self, x):\n                y = backend.numpy.ones(x.shape)\n                if str(y.device) == \"meta\":\n                    self.canary = True\n                    raise ValueError(\"Erroring out on meta device\")\n                x = (x + 2) * y\n                x = backend.numpy.stack([x, x], axis=-1)\n                return x\n\n        instance = Container()\n        out = backend.compute_output_spec(\n            instance.example_meta_fn, backend.KerasTensor((2, 3))\n        )\n        self.assertIsInstance(out, backend.KerasTensor)\n        self.assertTrue(instance.canary)\n        self.assertEqual(out.shape, (2, 3, 2))\n\n        instance = Container()\n        out = backend.compute_output_spec(\n            
instance.example_meta_fn, backend.KerasTensor((2, None))\n        )\n        self.assertIsInstance(out, backend.KerasTensor)\n        self.assertTrue(instance.canary)\n        self.assertEqual(out.shape, (2, None, 2))\n"
  },
  {
    "path": "keras/src/backend/common/dtypes.py",
    "content": "import functools\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import config\nfrom keras.src.backend.common.variables import standardize_dtype\n\nBOOL_TYPES = (\"bool\",)\nINT_TYPES = (\n    \"uint8\",\n    \"uint16\",\n    \"uint32\",\n    \"uint64\",\n    \"int8\",\n    \"int16\",\n    \"int32\",\n    \"int64\",\n)\nFLOAT_TYPES = (\"bfloat16\", \"float16\", \"float32\", \"float64\")\nWEAK_TYPES = (\"int\", \"float\")\nCOMPLEX_TYPES = (\"complex64\", \"complex128\")\n# We need to separate float8 from float because there are no implicit\n# conversions from float8 dtypes to other dtypes.\n# Ref: https://github.com/google/jax/issues/16705\nFLOAT8_TYPES = (\"float8_e4m3fn\", \"float8_e5m2\")\n\n# All supported dtypes in Keras\nALLOWED_DTYPES = (\n    \"float16\",\n    \"float32\",\n    \"float64\",\n    \"uint8\",\n    \"uint16\",\n    \"uint32\",\n    \"uint64\",\n    \"int8\",\n    \"int16\",\n    \"int32\",\n    \"int64\",\n    \"bfloat16\",\n    \"bool\",\n    \"string\",\n    \"float8_e4m3fn\",\n    \"float8_e5m2\",\n    \"complex64\",\n    \"complex128\",\n)\nPYTHON_DTYPES_MAP = {\n    bool: \"bool\",\n    int: \"int64\" if config.backend() == \"tensorflow\" else \"int32\",\n    float: \"float32\",\n    str: \"string\",\n    # special case for string value\n    \"int\": \"int64\" if config.backend() == \"tensorflow\" else \"int32\",\n    complex: \"complex128\" if config.backend() == \"tensorflow\" else \"complex64\",\n}\n\n# We adapted the type promotion lattice from JAX. 
Ref:\n# https://github.com/google/jax/blob/main/jax/_src/dtypes.py\n\n\ndef _type_promotion_lattice():\n    \"\"\"\n    Return the type promotion lattice in the form of a DAG.\n    This DAG maps each type to its immediately higher type on the lattice.\n    \"\"\"\n    (b1,) = BOOL_TYPES\n    (u1, u2, u4, u8, i1, i2, i4, i8) = INT_TYPES\n    bf, f2, f4, f8 = FLOAT_TYPES\n    i_, f_ = WEAK_TYPES\n    c64, c128 = COMPLEX_TYPES\n    out = {\n        b1: [i_],\n        u1: [i2, u2],\n        u2: [i4, u4],\n        u4: [i8, u8],\n        u8: [f_],\n        i_: [u1, i1, c64],\n        i1: [i2],\n        i2: [i4],\n        i4: [i8],\n        i8: [f_],\n        f_: [bf, f2],\n        bf: [f4],\n        f2: [f4],\n        f4: [f8, c64],\n        f8: [c128],\n        c64: [c128],\n        c128: [],\n    }\n    return out\n\n\ndef _make_lattice_upper_bounds():\n    lattice = _type_promotion_lattice()\n    upper_bounds = {node: {node} for node in lattice}\n    for n in lattice:\n        while True:\n            new_upper_bounds = set().union(\n                *(lattice[b] for b in upper_bounds[n])\n            )\n            if n in new_upper_bounds:\n                raise ValueError(\n                    f\"cycle detected in type promotion lattice for node {n}\"\n                )\n            if new_upper_bounds.issubset(upper_bounds[n]):\n                break\n            upper_bounds[n] |= new_upper_bounds\n    return upper_bounds\n\n\nLATTICE_UPPER_BOUNDS = _make_lattice_upper_bounds()\n\n\n@functools.lru_cache(512)\ndef _least_upper_bound(*nodes):\n    \"\"\"Compute the least upper bound of a set of nodes.\n\n    Args:\n        nodes: sequence of entries from dtypes + weak_types\n\n    Returns:\n        The type representing the least upper bound of the input nodes on the\n        promotion lattice.\n    \"\"\"\n    # This function computes the least upper bound of a set of nodes N within a\n    # partially ordered set defined by the lattice generated above.\n    # Given 
a partially ordered set S, let the set of upper bounds of n ∈ S be\n    #   UB(n) ≡ {m ∈ S | n ≤ m}\n    # Further, for a set of nodes N ⊆ S, let the set of common upper bounds be\n    # given by\n    #   CUB(N) ≡ {a ∈ S | ∀ b ∈ N: a ∈ UB(b)}\n    # Then the least upper bound of N is defined as\n    #   LUB(N) ≡ {c ∈ CUB(N) | ∀ d ∈ CUB(N), c ≤ d}\n    # The definition of an upper bound implies that\n    #   c ≤ d if and only if d ∈ UB(c),\n    # so the LUB can be expressed:\n    #   LUB(N) = {c ∈ CUB(N) | ∀ d ∈ CUB(N): d ∈ UB(c)}\n    # or, equivalently:\n    #   LUB(N) = {c ∈ CUB(N) | CUB(N) ⊆ UB(c)}\n    # By definition, LUB(N) has a cardinality of 1 for a partially ordered set.\n    # Note a potential algorithmic shortcut: from the definition of CUB(N),\n    # we have\n    #   ∀ c ∈ N: CUB(N) ⊆ UB(c)\n    # So if N ∩ CUB(N) is nonempty, it follows that LUB(N) = N ∩ CUB(N).\n    N = set(nodes)\n    UB = LATTICE_UPPER_BOUNDS\n    try:\n        bounds = [UB[n] for n in N]\n    except KeyError:\n        dtype = next(n for n in N if n not in UB)\n        raise ValueError(\n            f\"{dtype=} is not a valid dtype for Keras type promotion.\"\n        )\n    CUB = set.intersection(*bounds)\n    LUB = (CUB & N) or {c for c in CUB if CUB.issubset(UB[c])}\n    if len(LUB) == 1:\n        return LUB.pop()\n    elif len(LUB) == 0:\n        msg = (\n            f\"Input dtypes {tuple(str(n) for n in nodes)} have no available \"\n            \"implicit dtype promotion path. Try explicitly casting inputs to \"\n            \"the desired output type.\"\n        )\n        raise ValueError(msg)\n    else:\n        # If we get here, it means the lattice is ill-formed.\n        raise ValueError(\n            f\"Internal Type Promotion error: {nodes} do not have a unique \"\n            f\"least upper bound on the specified lattice; options are {LUB}. 
\"\n            \"This is an unexpected error in Keras's internal logic; \"\n            \"please report it to the maintainers.\"\n        )\n\n\ndef _dtype_and_weaktype(value):\n    \"\"\"Return a (dtype, weak_type) tuple for the given input.\"\"\"\n    is_weak_type = False\n    if value is int or value is float:\n        # Note that we can't use `value in [int, float]` because the dtype\n        # might be equal to Python scalar types.\n        # e.g., tf.float32 == float is True\n        is_weak_type = True\n    return standardize_dtype(value), is_weak_type\n\n\n@functools.lru_cache(maxsize=None)\ndef _respect_weak_type(dtype, weak_type):\n    \"\"\"Return the weak dtype of `dtype` if `weak_type==True`.\"\"\"\n    if weak_type:\n        if dtype == \"bool\":\n            return dtype\n        elif \"float\" in dtype:\n            return \"float\"\n        elif \"int\" in dtype:\n            return \"int\"\n        elif \"complex\" in dtype:\n            return \"complex\"\n        else:\n            raise ValueError(\n                \"Invalid value for argument `dtype`. Expected one of \"\n                f\"{ALLOWED_DTYPES}. Received: dtype={dtype}\"\n            )\n    return dtype\n\n\n@functools.lru_cache(maxsize=None)\ndef _resolve_weak_type(dtype, precision=\"32\"):\n    \"\"\"Resolve weak type by the precision of `backend.floatx()`.\"\"\"\n    extended_allowed_dtypes = set(ALLOWED_DTYPES).union(WEAK_TYPES)\n    if dtype not in extended_allowed_dtypes:\n        raise ValueError(\n            \"Invalid value for argument `dtype`. Expected one of \"\n            f\"{extended_allowed_dtypes}. Received: dtype={dtype}\"\n        )\n    if precision not in [\"16\", \"32\", \"64\"]:\n        raise ValueError(\n            f\"Invalid value for argument `precision`. Expected one of \"\n            f\"('16', '32', '64'). 
Received: precision={precision}\"\n        )\n    if dtype == \"bfloat16\":  # special case for bfloat16\n        dtype_indicator = \"f\"\n    else:\n        dtype_indicator = dtype[:1]\n\n    if dtype_indicator == \"b\":\n        return \"bool\"\n    elif dtype_indicator == \"i\":\n        return f\"int{precision}\"\n    elif dtype_indicator == \"u\":\n        return f\"uint{precision}\"\n    else:\n        return f\"float{precision}\"\n\n\nBIT64_TO_BIT32_DTYPE = {\n    # Since TF variables require int64 to be placed on the GPU, we exclusively\n    # enable the int64 dtype for TF.\n    \"int64\": \"int64\" if config.backend() == \"tensorflow\" else \"int32\",\n    \"uint64\": \"uint32\",\n    \"float64\": \"float64\" if config.backend() == \"tensorflow\" else \"float32\",\n    \"complex128\": \"complex64\",\n}\n\n\ndef _lattice_result_type(*args):\n    dtypes, weak_types = zip(*(_dtype_and_weaktype(arg) for arg in args))\n    if len(dtypes) == 1:\n        out_dtype = dtypes[0]\n        out_weak_type = weak_types[0]\n    elif len(set(dtypes)) == 1 and not all(weak_types):\n        # Trivial promotion case. This allows extended dtypes through.\n        out_dtype = dtypes[0]\n        out_weak_type = False\n    elif all(weak_types):\n        # If all inputs are weakly typed, we compute the bound of the\n        # strongly-typed counterparts and apply the weak type at the end. This\n        # avoids returning the incorrect result with non-canonical weak types\n        # (e.g. 
weak int16).\n        out_dtype = _least_upper_bound(\n            *{_respect_weak_type(d, False) for d in dtypes}\n        )\n        out_weak_type = True\n    else:\n        out_dtype = _least_upper_bound(\n            *{_respect_weak_type(d, w) for d, w in zip(dtypes, weak_types)}\n        )\n        out_weak_type = any(out_dtype is t for t in WEAK_TYPES)\n\n    out_weak_type = (out_dtype != \"bool\") and out_weak_type\n    precision = config.floatx()[-2:]\n    if out_weak_type:\n        out_dtype = _resolve_weak_type(out_dtype, precision=precision)\n\n    # Force to be 32-bit dtype when encountering 64-bit dtype. This is to\n    # be aligned with JAX's default behavior.\n    out_dtype = BIT64_TO_BIT32_DTYPE.get(out_dtype, out_dtype)\n    return out_dtype\n\n\n@keras_export(\"keras.backend.result_type\")\ndef result_type(*dtypes):\n    \"\"\"Returns the type from applying the Keras type promotion rules.\n\n    In general, each argument is first parsed by `backend.standardize_dtype`,\n    and the resulting dtype is determined by the least upper bound of the type\n    promotion lattice.\n\n    Note: This function attempts to match the result of `jnp.result_type`.\n\n    Args:\n        dtypes: Input dtypes.\n\n    Returns:\n        The result dtype.\n\n    Examples:\n\n    >>> x = keras.ops.ones((1,), dtype=\"bfloat16\")\n    >>> keras.backend.result_type(x.dtype, int)\n    \"bfloat16\"\n\n    >>> x = keras.ops.ones((1,), dtype=\"int32\")\n    >>> y = keras.ops.ones((1,), dtype=\"float32\")\n    >>> keras.backend.result_type(x.dtype, y.dtype)\n    \"float32\"\n\n    >>> z = keras.ops.ones((1,), dtype=\"complex64\")\n    >>> keras.backend.result_type(z.dtype, int)\n    \"complex64\"\n\n    \"\"\"\n    if len(dtypes) == 0:\n        # If no dtypes are provided, default to floatx; this matches\n        # `ops.convert_to_tensor([])`\n        return config.floatx()\n    for dtype in dtypes:\n        if dtype in FLOAT8_TYPES:\n            raise ValueError(\n                \"There 
are no implicit conversions from float8 dtypes to others.\"\n                f\" You must cast them explicitly. Received: {dtypes}\"\n            )\n    return _lattice_result_type(\n        *(config.floatx() if arg is None else arg for arg in dtypes),\n    )\n"
  },
  {
    "path": "keras/src/backend/common/dtypes_test.py",
    "content": "from unittest.mock import patch\n\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.backend.common import dtypes\nfrom keras.src.testing import test_case\nfrom keras.src.testing.test_utils import named_product\n\n\nclass DtypesTest(test_case.TestCase):\n    \"\"\"Test the dtype to verify that the behavior matches JAX.\"\"\"\n\n    ALL_DTYPES = [\n        x\n        for x in dtypes.ALLOWED_DTYPES\n        if x\n        not in (\n            \"string\",\n            \"complex128\",\n            \"float64\",\n            \"uint64\",\n            \"int64\",\n        )\n        + dtypes.FLOAT8_TYPES  # Remove float8 dtypes for the following tests\n    ] + [None]\n    if backend.backend() == \"torch\":\n        ALL_DTYPES = [x for x in ALL_DTYPES if x not in (\"uint16\", \"uint32\")]\n    elif backend.backend() == \"tensorflow\":\n        # TODO(hongyu): Re-enable uint32 tests once we determine how to handle\n        # dtypes.result_type(uint32, int*) -> int64 promotion.\n        # Since TF variables require int64 to be placed on the GPU, we\n        # exclusively enable the int64 dtype for TF. 
However, JAX does not\n        # natively support int64, which prevents us from comparing the dtypes.\n        ALL_DTYPES = [x for x in ALL_DTYPES if x not in (\"uint32\",)]\n    elif backend.backend() == \"openvino\":\n        ALL_DTYPES = [x for x in ALL_DTYPES if x not in (\"complex64\",)]\n\n    @parameterized.named_parameters(\n        named_product(dtype1=ALL_DTYPES, dtype2=[bool, int, float])\n    )\n    def test_result_type_with_python_scalar_types(self, dtype1, dtype2):\n        import jax.numpy as jnp\n\n        out = backend.result_type(dtype1, dtype2)\n        expected = jnp.result_type(dtype1, dtype2).name\n        self.assertEqual(out, expected)\n\n    @parameterized.named_parameters(\n        named_product(dtype1=ALL_DTYPES, dtype2=ALL_DTYPES)\n    )\n    def test_result_type_with_tensor(self, dtype1, dtype2):\n        import jax.numpy as jnp\n\n        x1 = ops.ones((1,), dtype=dtype1)\n        x2 = ops.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n\n        out = backend.result_type(x1.dtype, x2.dtype)\n        expected = jnp.result_type(x1_jax, x2_jax).name\n        self.assertEqual(out, expected)\n\n    @parameterized.named_parameters(\n        named_product(\n            dtype=[\n                \"int8\",\n                \"int16\",\n                \"int32\",\n                \"int64\",\n                \"uint8\",\n                \"uint16\",\n                \"uint32\",\n            ]\n        )\n    )\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"TensorFlow only\"\n    )\n    def test_result_type_with_int64(self, dtype):\n        # https://github.com/keras-team/keras/issues/21677\n        x1 = ops.ones((1,), dtype=\"int64\")\n        x2 = ops.ones((1,), dtype=dtype)\n        out = backend.result_type(x1.dtype, x2.dtype)\n        self.assertEqual(out, \"int64\")\n\n    @parameterized.named_parameters(\n        named_product(\n           
 dtype=[\n                \"float16\",\n                \"bfloat16\",\n                \"float32\",\n                \"float64\",\n                \"int8\",\n                \"int16\",\n                \"int32\",\n                \"int64\",\n                \"uint8\",\n                \"uint16\",\n            ]\n        )\n    )\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"TensorFlow only\"\n    )\n    def test_result_type_with_float64(self, dtype):\n        # Float types have the same issue as int64 in TF:\n        # https://github.com/keras-team/keras/issues/21677\n        x1 = ops.ones((1,), dtype=\"float64\")\n        x2 = ops.ones((1,), dtype=dtype)\n        out = backend.result_type(x1.dtype, x2.dtype)\n        self.assertEqual(out, \"float64\")\n\n    def test_result_type_with_none(self):\n        import jax.numpy as jnp\n\n        self.assertEqual(backend.result_type(None), jnp.result_type(None).name)\n\n    def test_result_type_empty_list(self):\n        self.assertEqual(backend.result_type(), \"float32\")\n\n    def test_respect_weak_type_for_bool(self):\n        self.assertEqual(dtypes._respect_weak_type(\"bool\", True), \"bool\")\n\n    def test_respect_weak_type_for_int(self):\n        self.assertEqual(dtypes._respect_weak_type(\"int32\", True), \"int\")\n\n    def test_respect_weak_type_for_float(self):\n        self.assertEqual(dtypes._respect_weak_type(\"float32\", True), \"float\")\n\n    def test_resolve_weak_type_for_bfloat16(self):\n        self.assertEqual(dtypes._resolve_weak_type(\"bfloat16\"), \"float32\")\n\n    def test_resolve_weak_type_for_bfloat16_with_precision(self):\n        self.assertEqual(\n            dtypes._resolve_weak_type(\"bfloat16\", precision=\"64\"), \"float64\"\n        )\n\n    def test_respect_weak_type_for_complex64(self):\n        self.assertAllEqual(\n            dtypes._respect_weak_type(\"complex64\", True), \"complex\"\n        )\n\n    def 
test_respect_weak_type_for_complex128(self):\n        self.assertAllEqual(\n            dtypes._respect_weak_type(\"complex128\", True), \"complex\"\n        )\n\n    def test_invalid_dtype_for_keras_promotion(self):\n        with self.assertRaisesRegex(\n            ValueError, \"is not a valid dtype for Keras type promotion.\"\n        ):\n            dtypes._least_upper_bound(\"invalid_dtype\")\n\n    def test_resolve_weak_type_for_invalid_dtype(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid value for argument `dtype`. Expected one of\"\n        ):\n            dtypes._resolve_weak_type(\"invalid_dtype\")\n\n    def test_resolve_weak_type_for_invalid_precision(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value for argument `precision`. Expected one of\",\n        ):\n            dtypes._resolve_weak_type(\"int32\", precision=\"invalid_precision\")\n\n    def test_cycle_detection_in_make_lattice_upper_bounds(self):\n        original_lattice_function = dtypes._type_promotion_lattice\n\n        def mock_lattice():\n            lattice = original_lattice_function()\n            lattice[\"int32\"].append(\"float32\")\n            lattice[\"float32\"].append(\"int32\")\n            return lattice\n\n        dtypes._type_promotion_lattice = mock_lattice\n\n        with self.assertRaisesRegex(\n            ValueError, \"cycle detected in type promotion lattice for node\"\n        ):\n            dtypes._make_lattice_upper_bounds()\n\n        dtypes._type_promotion_lattice = original_lattice_function\n\n    def test_respect_weak_type_for_invalid_dtype(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid value for argument `dtype`. 
Expected one of\"\n        ):\n            dtypes._respect_weak_type(\"invalid_dtype\", True)\n\n    def test_invalid_dtype_in_least_upper_bound(self):\n        invalid_dtype = \"non_existent_dtype\"\n        with self.assertRaisesRegex(\n            ValueError, \"is not a valid dtype for Keras type promotion\"\n        ):\n            dtypes._least_upper_bound(invalid_dtype)\n\n    def test_empty_lub_in_least_upper_bound(self):\n        dtype1 = \"float32\"\n        dtype2 = \"int32\"\n        with patch.dict(\n            dtypes.LATTICE_UPPER_BOUNDS,\n            {\"float32\": set(), \"int32\": set()},\n            clear=True,\n        ):\n            with self.assertRaisesRegex(\n                ValueError, \"no available implicit dtype promotion path\"\n            ):\n                dtypes._least_upper_bound(dtype1, dtype2)\n\n    def test_valid_dtype_leading_to_single_lub_element(self):\n        self.assertEqual(\n            dtypes._least_upper_bound(\"float32\", \"int32\"), \"float32\"\n        )\n\n    def test_valid_dtype_leading_to_keyerror_and_valueerror(self):\n        invalid_dtype = \"non_existent_dtype\"\n        with self.assertRaisesRegex(\n            ValueError, \"is not a valid dtype for Keras type promotion\"\n        ):\n            dtypes._least_upper_bound(invalid_dtype)\n\n    def test_resolve_weak_type_bool(self):\n        self.assertEqual(dtypes._resolve_weak_type(\"bool\"), \"bool\")\n\n    def test_resolve_weak_type_int(self):\n        self.assertEqual(\n            dtypes._resolve_weak_type(\"int32\", precision=\"32\"), \"int32\"\n        )\n        self.assertEqual(\n            dtypes._resolve_weak_type(\"int64\", precision=\"64\"), \"int64\"\n        )\n\n    def test_resolve_weak_type_uint(self):\n        self.assertEqual(\n            dtypes._resolve_weak_type(\"uint32\", precision=\"32\"), \"uint32\"\n        )\n        self.assertEqual(\n            dtypes._resolve_weak_type(\"uint64\", precision=\"64\"), \"uint64\"\n        
)\n\n    def test_resolve_weak_type_float(self):\n        self.assertEqual(\n            dtypes._resolve_weak_type(\"float32\", precision=\"32\"), \"float32\"\n        )\n        self.assertEqual(\n            dtypes._resolve_weak_type(\"float64\", precision=\"64\"), \"float64\"\n        )\n\n    def test_least_upper_bound_ensure_order_independence(self):\n        # Test to ensure _least_upper_bound is order-independent.\n        result1 = dtypes._least_upper_bound(\"float32\", \"int32\")\n        result2 = dtypes._least_upper_bound(\"int32\", \"float32\")\n        self.assertEqual(result1, result2)\n\n    def test_least_upper_bound_single_element(self):\n        dtypes.LATTICE_UPPER_BOUNDS[\"test_dtype\"] = {\"test_dtype\"}\n        self.assertEqual(dtypes._least_upper_bound(\"test_dtype\"), \"test_dtype\")\n\n    def test_least_upper_bound_no_element(self):\n        dtypes.LATTICE_UPPER_BOUNDS[\"test_dtype\"] = set()\n        with self.assertRaisesRegex(\n            ValueError, \"no available implicit dtype promotion path\"\n        ):\n            dtypes._least_upper_bound(\"test_dtype\")\n\n    def test_least_upper_bound_with_no_common_upper_bound(self):\n        with patch.dict(\n            dtypes.LATTICE_UPPER_BOUNDS,\n            {\"test_dtype1\": set(), \"test_dtype2\": set()},\n            clear=True,\n        ):\n            with self.assertRaisesRegex(\n                ValueError, \"no available implicit dtype promotion path\"\n            ):\n                dtypes._least_upper_bound(\"test_dtype1\", \"test_dtype2\")\n\n    def test_invalid_float8_dtype(self):\n        with self.assertRaisesRegex(\n            ValueError, \"There is no implicit conversions from float8 dtypes\"\n        ):\n            dtypes.result_type(\"float8_e4m3fn\", \"bfloat16\")\n        with self.assertRaisesRegex(\n            ValueError, \"There is no implicit conversions from float8 dtypes\"\n        ):\n            dtypes.result_type(\"float8_e5m2\", \"bfloat16\")\n"
  },
  {
    "path": "keras/src/backend/common/global_state.py",
    "content": "import gc\nimport threading\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\n\nGLOBAL_STATE_TRACKER = threading.local()\nGLOBAL_SETTINGS_TRACKER = threading.local()\n\n\ndef set_global_attribute(name, value):\n    setattr(GLOBAL_STATE_TRACKER, name, value)\n\n\ndef get_global_attribute(name, default=None, set_to_default=False):\n    attr = getattr(GLOBAL_STATE_TRACKER, name, None)\n    if attr is None and default is not None:\n        attr = default\n        if set_to_default:\n            set_global_attribute(name, attr)\n    return attr\n\n\n@keras_export([\"keras.utils.clear_session\", \"keras.backend.clear_session\"])\ndef clear_session(free_memory=True):\n    \"\"\"Resets all state generated by Keras.\n\n    Keras manages a global state, which it uses to implement the Functional\n    model-building API and to uniquify autogenerated layer names.\n\n    If you are creating many models in a loop, this global state will consume\n    an increasing amount of memory over time, and you may want to clear it.\n    Calling `clear_session()` releases the global state: this helps avoid\n    clutter from old models and layers, especially when memory is limited.\n\n    Args:\n        free_memory: Whether to call Python garbage collection.\n            It's usually a good practice to call it to make sure\n            memory used by deleted objects is immediately freed.\n            However, it may take a few seconds to execute, so\n            when using `clear_session()` in a short loop,\n            you may want to skip it.\n\n    Example 1: calling `clear_session()` when creating models in a loop\n\n    ```python\n    for _ in range(100):\n      # Without `clear_session()`, each iteration of this loop will\n      # slightly increase the size of the global state managed by Keras\n      model = keras.Sequential([\n          keras.layers.Dense(10) for _ in range(10)])\n\n    for _ in range(100):\n      # With `clear_session()` 
called at the beginning,\n      # Keras starts with a blank state at each iteration\n      # and memory consumption is constant over time.\n      keras.backend.clear_session()\n      model = keras.Sequential([\n          keras.layers.Dense(10) for _ in range(10)])\n    ```\n\n    Example 2: resetting the layer name generation counter\n\n    >>> layers = [keras.layers.Dense(10) for _ in range(10)]\n    >>> new_layer = keras.layers.Dense(10)\n    >>> print(new_layer.name)\n    dense_10\n    >>> keras.backend.clear_session()\n    >>> new_layer = keras.layers.Dense(10)\n    >>> print(new_layer.name)\n    dense\n    \"\"\"\n    global GLOBAL_STATE_TRACKER\n    global GLOBAL_SETTINGS_TRACKER\n\n    GLOBAL_STATE_TRACKER = threading.local()\n    GLOBAL_SETTINGS_TRACKER = threading.local()\n\n    if backend.backend() == \"tensorflow\":\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        tf.compat.v1.reset_default_graph()\n        if tf.executing_eagerly():\n            # Clear pending nodes in eager executors, kernel caches and\n            # step_containers.\n            from tensorflow.python.eager import context\n\n            context.context().clear_kernel_cache()\n    elif backend.backend() == \"torch\":\n        import torch._dynamo as dynamo\n\n        # Resets torchdynamo's cache so that cached guards, compiled\n        # functions, etc. do not persist between clear_session() calls.\n        dynamo.reset()\n\n    if free_memory:\n        # Manually trigger garbage collection.\n        gc.collect()\n"
  },
  {
    "path": "keras/src/backend/common/global_state_test.py",
    "content": "from keras.src.backend.common import global_state\nfrom keras.src.testing import test_case\nfrom keras.src.utils.naming import auto_name\n\n\nclass GlobalStateTest(test_case.TestCase):\n    def test_clear_session(self):\n        name0 = auto_name(\"somename\")\n        self.assertEqual(name0, \"somename\")\n        name1 = auto_name(\"somename\")\n        self.assertEqual(name1, \"somename_1\")\n        global_state.clear_session()\n        name0 = auto_name(\"somename\")\n        self.assertEqual(name0, \"somename\")\n"
  },
  {
    "path": "keras/src/backend/common/keras_tensor.py",
    "content": "from keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils.naming import auto_name\n\n\n@keras_export(\"keras.KerasTensor\")\nclass KerasTensor:\n    \"\"\"Symbolic tensor -- encapsulates a shape and a dtype.\n\n    You can use `KerasTensor` instances to build computation\n    graphs of Keras operations, such as `keras.Function`\n    objects or Functional `keras.models.Model` objects.\n\n    Example:\n\n    >>> x = keras.KerasTensor(shape=(3, 4), dtype=\"float32\")\n    >>> x.shape\n    (3, 4)\n    >>> x.dtype\n    float32\n\n    Calling a Keras operation (including a layer or a model)\n    on a `KerasTensor` instance will return another `KerasTensor`\n    instance with the appropriate shape and dtype. This is\n    called a \"symbolic call\" (since there is no actual data\n    involved). The computation of the correct output shape and\n    dtype is called \"static shape inference\".\n    \"\"\"\n\n    def __init__(\n        self,\n        shape,\n        dtype=\"float32\",\n        sparse=False,\n        ragged=False,\n        record_history=True,\n        name=None,\n        **kwargs,\n    ):\n        from keras.src import backend\n\n        ragged_rank = kwargs.pop(\"ragged_rank\", None)\n        row_splits_dtype = kwargs.pop(\"row_splits_dtype\", None)\n        if kwargs:\n            raise TypeError(\n                f\"Unexpected keyword arguments: {', '.join(kwargs.keys())}\"\n            )\n\n        self._shape = backend.standardize_shape(shape)\n        self._dtype = backend.standardize_dtype(dtype)\n        self._sparse = bool(sparse)\n        self._ragged = bool(ragged)\n        if self._sparse and self._ragged:\n            raise ValueError(\n                \"KerasTensor cannot have `sparse=True` and `ragged=True` at \"\n                \"the same time.\"\n            )\n        self._ragged_rank = (\n            int(ragged_rank) if ragged_rank is not None else None\n        )\n        
self._row_splits_dtype = (\n            backend.standardize_dtype(row_splits_dtype)\n            if row_splits_dtype is not None\n            else None\n        )\n        self.name = name or auto_name(self.__class__.__name__)\n        self.record_history = record_history\n\n    @property\n    def shape(self):\n        return self._shape\n\n    @shape.setter\n    def shape(self, value):\n        raise AttributeError(\n            \"The `shape` attribute of KerasTensor is immutable. One should \"\n            \"create a new instance of KerasTensor for this.\"\n        )\n\n    @property\n    def dtype(self):\n        return self._dtype\n\n    @dtype.setter\n    def dtype(self, value):\n        raise AttributeError(\n            \"The `dtype` attribute of KerasTensor is immutable. One should \"\n            \"create a new instance of KerasTensor for this.\"\n        )\n\n    @property\n    def sparse(self):\n        return self._sparse\n\n    @sparse.setter\n    def sparse(self, value):\n        raise AttributeError(\n            \"The `sparse` attribute of KerasTensor is immutable. One should \"\n            \"create a new instance of KerasTensor for this.\"\n        )\n\n    @property\n    def ragged_rank(self):\n        return self._ragged_rank\n\n    @ragged_rank.setter\n    def ragged_rank(self, value):\n        raise AttributeError(\n            \"The `ragged_rank` attribute of KerasTensor is immutable. One \"\n            \"should create a new instance of KerasTensor for this.\"\n        )\n\n    @property\n    def row_splits_dtype(self):\n        return self._row_splits_dtype\n\n    @row_splits_dtype.setter\n    def row_splits_dtype(self, value):\n        raise AttributeError(\n            \"The `row_splits_dtype` attribute of KerasTensor is immutable. 
One \"\n            \"should create a new instance of KerasTensor for this.\"\n        )\n\n    @property\n    def ragged(self):\n        return self._ragged\n\n    @ragged.setter\n    def ragged(self, value):\n        raise AttributeError(\n            \"The `ragged` attribute of KerasTensor is immutable. One should \"\n            \"create a new instance of KerasTensor for this.\"\n        )\n\n    @property\n    def ndim(self):\n        return len(self.shape)\n\n    def reshape(self, newshape):\n        from keras.src import ops\n\n        return ops.Reshape(newshape)(self)\n\n    def squeeze(self, axis=None):\n        from keras.src import ops\n\n        return ops.Squeeze(axis)(self)\n\n    def __int__(self):\n        raise ValueError(\n            \"A KerasTensor is symbolic: it's a placeholder for a shape \"\n            \"and a dtype. It doesn't have any actual numerical value. \"\n            \"You cannot convert it to an int.\"\n        )\n\n    def __float__(self):\n        raise ValueError(\n            \"A KerasTensor is symbolic: it's a placeholder for a shape \"\n            \"and a dtype. It doesn't have any actual numerical value. \"\n            \"You cannot convert it to a float.\"\n        )\n\n    def __array__(self):\n        raise ValueError(\n            \"A KerasTensor is symbolic: it's a placeholder for a shape \"\n            \"and a dtype. It doesn't have any actual numerical value. \"\n            \"You cannot convert it to a NumPy array.\"\n        )\n\n    def __jax_array__(self):\n        raise ValueError(\n            \"A KerasTensor cannot be used as input to a JAX function. \"\n            \"A KerasTensor is a symbolic placeholder for a shape and dtype, \"\n            \"used when constructing Keras Functional models \"\n            \"or Keras Functions. You can only use it as input to a Keras layer \"\n            \"or a Keras operation (from the namespaces `keras.layers` \"\n            \"and `keras.ops`). 
\"\n            \"You are likely doing something like:\\n\\n\"\n            \"```\\n\"\n            \"x = Input(...)\\n\"\n            \"...\\n\"\n            \"jax_fn(x)  # Invalid.\\n\"\n            \"```\\n\\n\"\n            \"What you should do instead is wrap `jax_fn` in a layer:\\n\\n\"\n            \"```\\n\"\n            \"class MyLayer(Layer):\\n\"\n            \"    def call(self, x):\\n\"\n            \"        return jax_fn(x)\\n\\n\"\n            \"x = MyLayer()(x)\\n\"\n            \"```\\n\"\n        )\n\n    def __tf_tensor__(self, dtype=None, name=None):\n        raise ValueError(\n            \"A KerasTensor cannot be used as input to a TensorFlow function. \"\n            \"A KerasTensor is a symbolic placeholder for a shape and dtype, \"\n            \"used when constructing Keras Functional models \"\n            \"or Keras Functions. You can only use it as input to a Keras layer \"\n            \"or a Keras operation (from the namespaces `keras.layers` \"\n            \"and `keras.ops`). 
\"\n            \"You are likely doing something like:\\n\\n\"\n            \"```\\n\"\n            \"x = Input(...)\\n\"\n            \"...\\n\"\n            \"tf_fn(x)  # Invalid.\\n\"\n            \"```\\n\\n\"\n            \"What you should do instead is wrap `tf_fn` in a layer:\\n\\n\"\n            \"```\\n\"\n            \"class MyLayer(Layer):\\n\"\n            \"    def call(self, x):\\n\"\n            \"        return tf_fn(x)\\n\\n\"\n            \"x = MyLayer()(x)\\n\"\n            \"```\\n\"\n        )\n\n    def __repr__(self):\n        return (\n            f\"<KerasTensor shape={self.shape}, dtype={self.dtype}, \"\n            f\"sparse={self.sparse}, ragged={self.ragged}, name={self.name}>\"\n        )\n\n    def __iter__(self):\n        raise NotImplementedError(\n            \"Iterating over a symbolic KerasTensor is not supported.\"\n        )\n\n    def __bool__(self):\n        raise TypeError(\"A symbolic KerasTensor cannot be used as a boolean.\")\n\n    def __add__(self, other):\n        from keras.src import ops\n\n        return ops.Add().symbolic_call(self, other)\n\n    def __radd__(self, other):\n        from keras.src import ops\n\n        return ops.Add().symbolic_call(other, self)\n\n    def __sub__(self, other):\n        from keras.src import ops\n\n        return ops.Subtract().symbolic_call(self, other)\n\n    def __rsub__(self, other):\n        from keras.src import ops\n\n        return ops.Subtract().symbolic_call(other, self)\n\n    def __mul__(self, other):\n        from keras.src import ops\n\n        return ops.Multiply().symbolic_call(self, other)\n\n    def __rmul__(self, other):\n        from keras.src import ops\n\n        return ops.Multiply().symbolic_call(other, self)\n\n    def __matmul__(self, other):\n        from keras.src import ops\n\n        return ops.Matmul().symbolic_call(self, other)\n\n    def __rmatmul__(self, other):\n        from keras.src import ops\n\n        return ops.Matmul().symbolic_call(other, 
self)\n\n    def __div__(self, other):\n        from keras.src import ops\n\n        return ops.Divide().symbolic_call(self, other)\n\n    def __rdiv__(self, other):\n        from keras.src import ops\n\n        return ops.Divide().symbolic_call(other, self)\n\n    def __truediv__(self, other):\n        from keras.src import ops\n\n        return ops.TrueDivide().symbolic_call(self, other)\n\n    def __rtruediv__(self, other):\n        from keras.src import ops\n\n        return ops.TrueDivide().symbolic_call(other, self)\n\n    def __neg__(self):\n        from keras.src import ops\n\n        return ops.Negative().symbolic_call(self)\n\n    def __abs__(self):\n        from keras.src import ops\n\n        return ops.Absolute().symbolic_call(self)\n\n    def __pow__(self, other):\n        from keras.src import ops\n\n        return ops.Power().symbolic_call(self, other)\n\n    def __rpow__(self, other):\n        from keras.src import ops\n\n        return ops.Power().symbolic_call(other, self)\n\n    def __floordiv__(self, other):\n        from keras.src import ops\n\n        return ops.FloorDivide().symbolic_call(self, other)\n\n    def __rfloordiv__(self, other):\n        from keras.src import ops\n\n        return ops.FloorDivide().symbolic_call(other, self)\n\n    def __mod__(self, other):\n        from keras.src import ops\n\n        return ops.Mod().symbolic_call(self, other)\n\n    def __rmod__(self, other):\n        from keras.src import ops\n\n        return ops.Mod().symbolic_call(other, self)\n\n    def __lt__(self, other):\n        from keras.src import ops\n\n        return ops.Less().symbolic_call(self, other)\n\n    def __le__(self, other):\n        from keras.src import ops\n\n        return ops.LessEqual().symbolic_call(self, other)\n\n    def __gt__(self, other):\n        from keras.src import ops\n\n        return ops.Greater().symbolic_call(self, other)\n\n    def __ge__(self, other):\n        from keras.src import ops\n\n        return 
ops.GreaterEqual().symbolic_call(self, other)\n\n    def __ne__(self, other):\n        from keras.src import ops\n\n        return ops.NotEqual().symbolic_call(self, other)\n\n    def __and__(self, other):\n        from keras.src import ops\n\n        return ops.LogicalAnd().symbolic_call(self, other)\n\n    def __rand__(self, other):\n        from keras.src import ops\n\n        return ops.LogicalAnd().symbolic_call(other, self)\n\n    def __or__(self, other):\n        from keras.src import ops\n\n        return ops.LogicalOr().symbolic_call(self, other)\n\n    def __ror__(self, other):\n        from keras.src import ops\n\n        return ops.LogicalOr().symbolic_call(other, self)\n\n    def __invert__(self):\n        from keras.src import ops\n\n        return ops.LogicalNot().symbolic_call(self)\n\n    def __xor__(self, other):\n        from keras.src import ops\n\n        return ops.LogicalXor().symbolic_call(self, other)\n\n    def __rxor__(self, other):\n        from keras.src import ops\n\n        return ops.LogicalXor().symbolic_call(other, self)\n\n    def __getitem__(self, key):\n        from keras.src import ops\n\n        return ops.GetItem().symbolic_call(self, key)\n\n    def __round__(self, ndigits=None):\n        from keras.src import ops\n\n        decimals = ndigits or 0\n        return ops.Round(decimals=decimals).symbolic_call(self)\n\n\ndef any_symbolic_tensors(args=None, kwargs=None):\n    args = args or ()\n    kwargs = kwargs or {}\n    for x in tree.flatten((args, kwargs)):\n        if isinstance(x, KerasTensor):\n            return True\n    return False\n\n\n@keras_export([\"keras.utils.is_keras_tensor\", \"keras.backend.is_keras_tensor\"])\ndef is_keras_tensor(x):\n    \"\"\"Returns whether `x` is a Keras tensor.\n\n    A \"Keras tensor\" is a *symbolic tensor*, such as a tensor\n    that was created via `Input()`. 
A \"symbolic tensor\"\n    can be understood as a placeholder -- it does not\n    contain any actual numerical data, only a shape and dtype.\n    It can be used for building Functional models, but it\n    cannot be used in actual computations.\n    \"\"\"\n    return isinstance(x, KerasTensor)\n"
  },
  {
    "path": "keras/src/backend/common/keras_tensor_test.py",
    "content": "from unittest.mock import Mock\nfrom unittest.mock import patch\n\nimport numpy as np\nimport tensorflow as tf\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.backend.common import keras_tensor\n\n\nclass KerasTensorTest(testing.TestCase):\n    def test_attributes(self):\n        x = keras_tensor.KerasTensor(shape=(3,), dtype=\"float32\", sparse=True)\n        self.assertEqual(x.dtype, \"float32\")\n        self.assertEqual(x.shape, (3,))\n        self.assertEqual(x.sparse, True)\n\n        # Raise error if trying to set attributes\n        with self.assertRaisesRegex(\n            AttributeError, \"The `shape` attribute of KerasTensor is immutable.\"\n        ):\n            x.shape = [3, 2]\n        with self.assertRaisesRegex(\n            AttributeError, \"The `dtype` attribute of KerasTensor is immutable.\"\n        ):\n            x.dtype = \"int32\"\n\n    def test_attributes_sparse(self):\n        x = keras_tensor.KerasTensor(shape=(3,), dtype=\"float32\", sparse=True)\n        self.assertEqual(x.sparse, True)\n\n        # Raise error if trying to set attributes\n        with self.assertRaisesRegex(\n            AttributeError,\n            \"The `sparse` attribute of KerasTensor is immutable.\",\n        ):\n            x.sparse = False\n\n    def test_attributes_ragged(self):\n        x = keras_tensor.KerasTensor(shape=(3,), dtype=\"float32\", ragged=True)\n        self.assertEqual(x.ragged, True)\n\n        # Raise error if trying to set attributes\n        with self.assertRaisesRegex(\n            AttributeError,\n            \"The `ragged` attribute of KerasTensor is immutable.\",\n        ):\n            x.ragged = False\n\n    def test_init_sparse_ragged_raises(self):\n        with self.assertRaisesRegex(\n            ValueError, \"cannot have `sparse=True` and `ragged=True`\"\n        ):\n            keras_tensor.KerasTensor(shape=(3,), sparse=True, ragged=True)\n\n    def 
test_numpy_methods(self):\n        x = keras_tensor.KerasTensor(shape=(3, 2), dtype=\"float32\")\n\n        # reshape\n        x = x.reshape((6,))\n        self.assertEqual(x.shape, (6,))\n\n        # expand_dims, squeeze\n        x = ops.expand_dims(x, -1)\n        self.assertEqual(x.shape, (6, 1))\n        x = x.squeeze()\n        self.assertEqual(x.shape, (6,))\n        x = ops.expand_dims(x, axis=0)\n        self.assertEqual(x.shape, (1, 6))\n        x = x.squeeze(axis=0)\n        self.assertEqual(x.shape, (6,))\n\n    def test_invalid_usage(self):\n        x = keras_tensor.KerasTensor(shape=(3,), dtype=\"float32\")\n        with self.assertRaisesRegex(\n            ValueError, \"doesn't have any actual numerical value\"\n        ):\n            np.array(x)\n\n        if backend.backend() == \"jax\":\n            from jax import numpy as jnp\n\n            with self.assertRaisesRegex(\n                ValueError, \"cannot be used as input to a JAX function\"\n            ):\n                jnp.array(x)\n\n        with self.assertRaisesRegex(\n            ValueError, \"cannot be used as input to a TensorFlow function\"\n        ):\n            tf.convert_to_tensor(x)\n\n    def test_bool(self):\n        tensor = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        with self.assertRaisesRegex(TypeError, \"cannot be used as a boolean.\"):\n            bool(tensor)\n\n    def test_representation(self):\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        self.assertIn(\"<KerasTensor shape=(3, 4)\", repr(x))\n\n    def test_iterating(self):\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        with self.assertRaises(NotImplementedError):\n            iter(x)\n\n    def test_any_symbolic_tensors(self):\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = np.array([1, 2, 3])\n        self.assertTrue(keras_tensor.any_symbolic_tensors(args=[x, y]))\n        
self.assertFalse(keras_tensor.any_symbolic_tensors(args=[y]))\n\n    def test_is_keras_tensor(self):\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        self.assertTrue(keras_tensor.is_keras_tensor(x))\n        y = np.array([1, 2, 3])\n        self.assertFalse(keras_tensor.is_keras_tensor(y))\n\n    @patch(\"keras.src.ops.Absolute.symbolic_call\")\n    def test_abs_method(self, mock_symbolic_call):\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        abs_x = abs(x)  # this will internally call x.__abs__()\n        mock_symbolic_call.assert_called_once_with(x)\n        self.assertEqual(abs_x, mock_tensor)\n\n    @patch(\"keras.src.ops.Negative.symbolic_call\")\n    def test_neg_method(self, mock_method):\n        self._test_unary_op_method(mock_method, lambda x: -x)\n\n    @patch(\"keras.src.ops.Subtract.symbolic_call\")\n    def test_sub_method(self, mock_method):\n        y = Mock()\n        self._test_binary_op_method(mock_method, y, lambda x, y: x - y)\n\n    @patch(\"keras.src.ops.Multiply.symbolic_call\")\n    def test_mul_method(self, mock_method):\n        y = Mock()\n        self._test_binary_op_method(mock_method, y, lambda x, y: x * y)\n\n    @patch(\"keras.src.ops.Matmul.symbolic_call\")\n    def test_matmul_method(self, mock_method):\n        y = Mock()\n        self._test_binary_op_method(mock_method, y, lambda x, y: x @ y)\n\n    @patch(\"keras.src.ops.Power.symbolic_call\")\n    def test_pow_method(self, mock_method):\n        y = Mock()\n        self._test_binary_op_method(mock_method, y, lambda x, y: x**y)\n\n    @patch(\"keras.src.ops.Mod.symbolic_call\")\n    def test_mod_method(self, mock_method):\n        y = Mock()\n        self._test_binary_op_method(mock_method, y, lambda x, y: x % y)\n\n    @patch(\"keras.src.ops.Less.symbolic_call\")\n    def test_lt_method(self, mock_method):\n        y = Mock()\n     
   self._test_binary_op_method(mock_method, y, lambda x, y: x < y)\n\n    @patch(\"keras.src.ops.LogicalAnd.symbolic_call\")\n    def test_and_method(self, mock_method):\n        y = Mock()\n        self._test_binary_op_method(mock_method, y, lambda x, y: x & y)\n\n    @patch(\"keras.src.ops.LogicalOr.symbolic_call\")\n    def test_or_method(self, mock_method):\n        y = Mock()\n        self._test_binary_op_method(mock_method, y, lambda x, y: x | y)\n\n    @patch(\"keras.src.ops.GetItem.symbolic_call\")\n    def test_getitem_method(self, mock_method):\n        y = Mock()\n        self._test_binary_op_method(mock_method, y, lambda x, y: x[y])\n\n    def _test_unary_op_method(self, mock_method, operator):\n        mock_tensor = Mock()\n        mock_method.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        result = operator(x)\n        mock_method.assert_called_once_with(x)\n        self.assertEqual(result, mock_tensor)\n\n    def _test_binary_op_method(self, mock_method, other, operator):\n        mock_tensor = Mock()\n        mock_method.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        result = operator(x, other)\n        mock_method.assert_called_once_with(x, other)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.Add.symbolic_call\")\n    def test_radd_method(self, mock_symbolic_call):\n        \"\"\"Test __radd__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = y + x\n        mock_symbolic_call.assert_called_once_with(y, x)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.Subtract.symbolic_call\")\n    def test_rsub_method(self, mock_symbolic_call):\n        \"\"\"Test __rsub__ method\"\"\"\n        mock_tensor = Mock()\n        
mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = y - x\n        mock_symbolic_call.assert_called_once_with(y, x)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.Multiply.symbolic_call\")\n    def test_rmul_method(self, mock_symbolic_call):\n        \"\"\"Test __rmul__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = y * x\n        mock_symbolic_call.assert_called_once_with(y, x)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.Matmul.symbolic_call\")\n    def test_rmatmul_method(self, mock_symbolic_call):\n        \"\"\"Test __rmatmul__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = y @ x\n        mock_symbolic_call.assert_called_once_with(y, x)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.Power.symbolic_call\")\n    def test_rpow_method(self, mock_symbolic_call):\n        \"\"\"Test __rpow__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = y**x\n        mock_symbolic_call.assert_called_once_with(y, x)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.FloorDivide.symbolic_call\")\n    def test_floordiv_method(self, mock_symbolic_call):\n        \"\"\"Test __floordiv__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = x // y\n  
      mock_symbolic_call.assert_called_once_with(x, y)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.FloorDivide.symbolic_call\")\n    def test_rfloordiv_method(self, mock_symbolic_call):\n        \"\"\"Test __rfloordiv__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = y // x\n        mock_symbolic_call.assert_called_once_with(y, x)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.Mod.symbolic_call\")\n    def test_rmod_method(self, mock_symbolic_call):\n        \"\"\"Test __rmod__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = y % x\n        mock_symbolic_call.assert_called_once_with(y, x)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.LessEqual.symbolic_call\")\n    def test_le_method(self, mock_symbolic_call):\n        \"\"\"Test __le__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = x <= y\n        mock_symbolic_call.assert_called_once_with(x, y)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.Greater.symbolic_call\")\n    def test_gt_method(self, mock_symbolic_call):\n        \"\"\"Test __gt__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = x > y\n        mock_symbolic_call.assert_called_once_with(x, y)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.GreaterEqual.symbolic_call\")\n    def 
test_ge_method(self, mock_symbolic_call):\n        \"\"\"Test __ge__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = x >= y\n        mock_symbolic_call.assert_called_once_with(x, y)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.NotEqual.symbolic_call\")\n    def test_ne_method(self, mock_symbolic_call):\n        \"\"\"Test __ne__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = x != y\n        mock_symbolic_call.assert_called_once_with(x, y)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.LogicalAnd.symbolic_call\")\n    def test_rand_method(self, mock_symbolic_call):\n        \"\"\"Test __rand__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"bool\")\n        y = Mock()\n        result = y & x\n        mock_symbolic_call.assert_called_once_with(y, x)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.LogicalOr.symbolic_call\")\n    def test_ror_method(self, mock_symbolic_call):\n        \"\"\"Test __ror__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"bool\")\n        y = Mock()\n        result = y | x\n        mock_symbolic_call.assert_called_once_with(y, x)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.LogicalNot.symbolic_call\")\n    def test_invert_method(self, mock_symbolic_call):\n        \"\"\"Test __invert__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        
x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"bool\")\n        result = ~x\n        mock_symbolic_call.assert_called_once_with(x)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.LogicalXor.symbolic_call\")\n    def test_xor_method(self, mock_symbolic_call):\n        \"\"\"Test __xor__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"bool\")\n        y = Mock()\n        result = x ^ y\n        mock_symbolic_call.assert_called_once_with(x, y)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.LogicalXor.symbolic_call\")\n    def test_rxor_method(self, mock_symbolic_call):\n        \"\"\"Test __rxor__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"bool\")\n        y = Mock()\n        result = y ^ x\n        mock_symbolic_call.assert_called_once_with(y, x)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.TrueDivide.symbolic_call\")\n    def test_truediv_method(self, mock_symbolic_call):\n        \"\"\"Test __truediv__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = x / y\n        mock_symbolic_call.assert_called_once_with(x, y)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.TrueDivide.symbolic_call\")\n    def test_rtruediv_method(self, mock_symbolic_call):\n        \"\"\"Test __rtruediv__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = Mock()\n        result = y / x\n        mock_symbolic_call.assert_called_once_with(y, x)\n        
self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.Divide.symbolic_call\")\n    def test_div_method(self, mock_symbolic_call):\n        \"\"\"Test __div__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        # to ensure compatibility across Python versions\n        result = x.__div__(y)\n        mock_symbolic_call.assert_called_once_with(x, y)\n        self.assertEqual(result, mock_tensor)\n\n    @patch(\"keras.src.ops.Divide.symbolic_call\")\n    def test_rdiv_method(self, mock_symbolic_call):\n        \"\"\"Test __rdiv__ method\"\"\"\n        mock_tensor = Mock()\n        mock_symbolic_call.return_value = mock_tensor\n        x = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        y = keras_tensor.KerasTensor(shape=(3, 4), dtype=\"float32\")\n        # to ensure compatibility across Python versions\n        result = x.__rdiv__(y)\n        mock_symbolic_call.assert_called_once_with(y, x)\n        self.assertEqual(result, mock_tensor)\n"
  },
  {
    "path": "keras/src/backend/common/masking.py",
    "content": "from keras.src.backend.common.tensor_attributes import get_tensor_attr\nfrom keras.src.backend.common.tensor_attributes import set_tensor_attr\n\n\ndef set_keras_mask(x, mask):\n    \"\"\"Sets the Keras mask attribute for the given tensor in-place.\n\n    Args:\n        x: Input tensor.\n        mask: The mask tensor to be set. If `None`, the `_keras_mask` attribute\n            will be cleared.\n    \"\"\"\n    set_tensor_attr(x, \"_keras_mask\", mask)\n\n\ndef get_keras_mask(x):\n    \"\"\"Gets the Keras mask attribute from the given tensor.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        The mask tensor associated with the input tensor, or `None` if no mask\n        has been set.\n    \"\"\"\n    return get_tensor_attr(x, \"_keras_mask\")\n"
  },
  {
    "path": "keras/src/backend/common/masking_test.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.backend.common.masking import get_keras_mask\nfrom keras.src.backend.common.masking import set_keras_mask\n\n\nclass MaskingTest(testing.TestCase):\n    def test_mask_on_eager_tensor(self):\n        x = ops.zeros((2, 3))\n        self.assertIsNone(get_keras_mask(x))\n\n        set_keras_mask(x, None)\n        self.assertIsNone(get_keras_mask(x))\n\n        mask = ops.ones((2, 3))\n        set_keras_mask(x, mask)\n        self.assertIs(get_keras_mask(x), mask)\n\n        set_keras_mask(x, None)\n        self.assertIsNone(get_keras_mask(x))\n\n        set_keras_mask(x, None)\n        self.assertIsNone(get_keras_mask(x))\n\n    def test_mask_on_tracer_tensor(self):\n        def fn(x):\n            self.assertIsNone(get_keras_mask(x))\n\n            set_keras_mask(x, None)\n            self.assertIsNone(get_keras_mask(x))\n\n            mask = ops.ones((2, 3))\n            set_keras_mask(x, mask)\n            self.assertIs(get_keras_mask(x), mask)\n\n            set_keras_mask(x, None)\n            self.assertIsNone(get_keras_mask(x))\n\n            set_keras_mask(x, None)  # key is now deleted, should be a no-op\n            self.assertIsNone(get_keras_mask(x))\n\n        backend.compute_output_spec(fn, backend.KerasTensor((2, 3)))\n"
  },
  {
    "path": "keras/src/backend/common/name_scope.py",
    "content": "from keras.src.backend.common import global_state\n\n\nclass name_scope:\n    \"\"\"Creates a sub-namespace for variable paths.\n\n    Args:\n        name: Name of the current scope (string).\n        caller: Optional ID of a caller object (e.g. class instance).\n        deduplicate: If `True`, if `caller` was passed,\n            and the previous caller matches the current caller,\n            and the previous name matches the current name,\n            do not reenter a new namespace.\n        override_parent: Can be used to provide an absolute path\n            which would override any previously opened name scopes.\n    \"\"\"\n\n    def __init__(\n        self, name, caller=None, deduplicate=True, override_parent=None\n    ):\n        if not isinstance(name, str) or \"/\" in name:\n            raise ValueError(\n                \"Argument `name` must be a string and \"\n                \"cannot contain character `/`. \"\n                f\"Received: name={name}\"\n            )\n        self.name = name\n        self.caller = caller\n        self.deduplicate = deduplicate\n        self.override_parent = override_parent\n        if (\n            override_parent is None\n            and deduplicate\n            and getattr(caller, \"_parent_path\", None) is not None\n        ):\n            self.override_parent = caller._parent_path\n        self._pop_on_exit = False\n\n    def __enter__(self):\n        name_scope_stack = global_state.get_global_attribute(\n            \"name_scope_stack\", default=[], set_to_default=True\n        )\n        if self.deduplicate and name_scope_stack:\n            parent_caller = name_scope_stack[-1].caller\n            parent_name = name_scope_stack[-1].name\n            if (\n                self.caller is not None\n                and self.caller is parent_caller\n                and self.name == parent_name\n            ):\n                return self\n        name_scope_stack.append(self)\n        
self._pop_on_exit = True\n        return self\n\n    def __exit__(self, *args, **kwargs):\n        if self._pop_on_exit:\n            name_scope_stack = global_state.get_global_attribute(\n                \"name_scope_stack\"\n            )\n            if name_scope_stack:\n                name_scope_stack.pop()\n\n\ndef current_path():\n    name_scope_stack = global_state.get_global_attribute(\"name_scope_stack\")\n    if name_scope_stack is None:\n        return \"\"\n    parts = []\n    for entry in name_scope_stack:\n        if entry.override_parent is not None:\n            parts = [p for p in entry.override_parent.split(\"/\") if p]\n        parts.append(entry.name)\n    return \"/\".join(parts)\n"
  },
  {
    "path": "keras/src/backend/common/name_scope_test.py",
    "content": "import threading\n\nfrom keras.src import testing\nfrom keras.src.backend.common import global_state\nfrom keras.src.backend.common.name_scope import current_path\nfrom keras.src.backend.common.name_scope import name_scope\n\n\nclass NameScopeTest(testing.TestCase):\n    def test_stacking(self):\n        self.assertEqual(current_path(), \"\")\n        with name_scope(\"outer\") as outer:\n            self.assertEqual(outer.name, \"outer\")\n            self.assertEqual(current_path(), \"outer\")\n            with name_scope(\"middle\") as middle:\n                self.assertEqual(middle.name, \"middle\")\n                self.assertEqual(current_path(), \"outer/middle\")\n                with name_scope(\"inner\") as inner:\n                    self.assertEqual(inner.name, \"inner\")\n                    self.assertEqual(current_path(), \"outer/middle/inner\")\n                self.assertEqual(current_path(), \"outer/middle\")\n            self.assertEqual(current_path(), \"outer\")\n        self.assertEqual(current_path(), \"\")\n\n    def test_deduplication(self):\n        self.assertEqual(current_path(), \"\")\n        with name_scope(\"name\", caller=1):\n            with name_scope(\"name\", caller=1):\n                self.assertEqual(current_path(), \"name\")\n        self.assertEqual(current_path(), \"\")\n        with name_scope(\"name\"):\n            with name_scope(\"name\"):\n                self.assertEqual(current_path(), \"name/name\")\n\n    def test_errors(self):\n        with self.assertRaisesRegex(ValueError, \"must be a string\"):\n            name_scope(\"foo/bar\")\n        with self.assertRaisesRegex(ValueError, \"must be a string\"):\n            name_scope(4)\n\n    def test_override_parent(self):\n        self.assertEqual(current_path(), \"\")\n        with name_scope(\"outer\"):\n            self.assertEqual(current_path(), \"outer\")\n            with name_scope(\"middle\", override_parent=\"/absolute/path\"):\n          
      self.assertEqual(current_path(), \"absolute/path/middle\")\n                with name_scope(\"inner\"):\n                    self.assertEqual(\n                        current_path(), \"absolute/path/middle/inner\"\n                    )\n            self.assertEqual(current_path(), \"outer\")\n\n    def test_exit_with_none_stack(self):\n        \"\"\"Test that __exit__ handles None name_scope_stack gracefully.\"\"\"\n        # Create a name_scope instance\n        scope = name_scope(\"test\")\n        # Enter the scope normally\n        scope.__enter__()\n\n        # Simulate the scenario where global state is cleared\n        # (e.g., in a different thread)\n        global_state.set_global_attribute(\"name_scope_stack\", None)\n\n        # Exit should not raise an AttributeError\n        scope.__exit__()\n\n        # Clean up: reset the stack\n        global_state.set_global_attribute(\"name_scope_stack\", [])\n\n    def test_exit_with_empty_stack(self):\n        \"\"\"Test that __exit__ handles empty name_scope_stack gracefully.\"\"\"\n        # Create a name_scope instance\n        scope = name_scope(\"test\")\n        # Enter the scope normally\n        scope.__enter__()\n\n        # Simulate the scenario where the stack is cleared\n        name_scope_stack = global_state.get_global_attribute(\"name_scope_stack\")\n        name_scope_stack.clear()\n\n        # Exit should not raise an IndexError\n        scope.__exit__()\n\n        # Verify stack is still empty\n        name_scope_stack = global_state.get_global_attribute(\n            \"name_scope_stack\", default=[]\n        )\n        self.assertEqual(len(name_scope_stack), 0)\n\n    def test_multithreaded_name_scope(self):\n        \"\"\"Test name_scope in multithreaded environment.\"\"\"\n        results = []\n\n        def thread_function(thread_id):\n            # Each thread should have its own name_scope_stack\n            with name_scope(f\"thread_{thread_id}\"):\n                path = 
current_path()\n                results.append(path)\n                # Verify we get the expected path\n                self.assertEqual(path, f\"thread_{thread_id}\")\n\n        # Create and start multiple threads\n        threads = []\n        for i in range(5):\n            thread = threading.Thread(target=thread_function, args=(i,))\n            threads.append(thread)\n            thread.start()\n\n        # Wait for all threads to complete\n        for thread in threads:\n            thread.join()\n\n        # Verify all threads executed successfully\n        self.assertEqual(len(results), 5)\n\n    def test_exit_without_pop_on_exit(self):\n        \"\"\"Test that __exit__ respects _pop_on_exit flag.\"\"\"\n        # Create a name_scope but don't enter it\n        scope = name_scope(\"test\")\n        # _pop_on_exit should be False\n        self.assertFalse(scope._pop_on_exit)\n\n        # Set up a stack manually\n        global_state.set_global_attribute(\"name_scope_stack\", [scope])\n\n        scope.__exit__()\n\n        # Verify the stack still contains the scope\n        name_scope_stack = global_state.get_global_attribute(\"name_scope_stack\")\n        self.assertEqual(len(name_scope_stack), 1)\n\n        # Clean up\n        global_state.set_global_attribute(\"name_scope_stack\", [])\n"
  },
  {
    "path": "keras/src/backend/common/remat.py",
    "content": "from collections import namedtuple\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\n\n\n@keras_export(\"keras.RematScope\")\nclass RematScope:\n    \"\"\"A context manager for enabling rematerialization in Keras.\n\n    Rematerialization (gradient checkpointing) trades memory for computation by\n    recomputing intermediate activations during the backward pass. This is\n    particularly useful for training large models or large batch sizes within\n    limited memory constraints.\n\n    This should be used when initializing the layer (e.g., `layer(input)`).\n    Rematerialization applies at execution time, not at creation time.\n\n    Args:\n        mode: Rematerialization mode to apply.\n            Options:\n            - `\"full\"`: Apply rematerialization globally to all supported\n              operations.\n            - `\"activations\"`: Apply rematerialization to activations on any\n              layers that contain `keras.activations` (e.g., `Dense(...,\n              activation=relu)`).\n            - `\"larger_than\"`: Apply rematerialization to layers with output\n              sizes larger than `output_size_threshold`.\n            - `\"list_of_layers\"`: Apply rematerialization to a specific list of\n              layer names.\n            - `None`: Disable rematerialization.\n        output_size_threshold: Output size threshold for the\n            `\"larger_than\"` mode. Layers producing outputs larger than this\n            threshold will be rematerialized. Default is `1024`.\n        layer_names: List of layer names for the\n            `\"list_of_layers\"` mode. 
Default is an empty list.\n\n    Examples:\n    Using \"list_of_layers\" mode:\n\n    ```python\n    import keras\n    from keras import RematScope\n\n    input_tensor = keras.random.normal((1, 32, 32, 3))\n    with RematScope(mode=\"list_of_layers\", layer_names=[\"dense_1\",\n    \"conv2d_1\"]):\n        layer1 = keras.layers.Dense(128, name=\"dense_1\")\n        layer2 = keras.layers.Conv2D(64, (3, 3), name=\"conv2d_1\")\n        layer3 = keras.layers.Dense(64, name=\"dense_2\")\n        # Only layer1 and layer2 will apply rematerialization\n        output1 = layer1(input_tensor)\n        output2 = layer2(output1)\n        output3 = layer3(output2)\n    ```\n\n    Using \"larger_than\" mode with a specific output size threshold:\n\n    ```python\n    with RematScope(mode=\"larger_than\", output_size_threshold=2048):\n        layer = keras.layers.Conv2D(64, (3, 3))\n        output = layer(input_tensor)  # rematerialized if output > 2048\n    ```\n\n    Nested scopes for fine-grained control:\n\n    ```python\n    with RematScope(mode=\"full\"):\n        # Create layers\n        layer1 = keras.layers.Dense(128, activation='relu')\n        output1 = layer1(input_tensor)  # layer1 is fully rematerialized\n        with RematScope(mode=\"larger_than\", output_size_threshold=512):\n            layer2 = keras.layers.Conv2D(32, (3, 3))\n            output2 = layer2(output1) # layer2 is conditionally rematerialized\n            # if output > 512\n    ```\n    \"\"\"\n\n    def __init__(\n        self, mode=\"full\", output_size_threshold=1024, layer_names=None\n    ):\n        if mode not in {\n            \"full\",\n            \"activations\",\n            \"larger_than\",\n            \"list_of_layers\",\n            None,\n        }:\n            raise ValueError(\n                f\"Invalid mode '{mode}'. 
Supported modes are: \"\n                \"'full', 'activations', 'larger_than', 'list_of_layers', or \"\n                \"None.\"\n            )\n        self.mode = mode\n        self.output_size_threshold = output_size_threshold\n        self.layer_names = layer_names or []\n        self._pop_on_exit = False\n\n    def __enter__(self):\n        remat_scope_stack = global_state.get_global_attribute(\n            \"remat_scope_stack\", default=[], set_to_default=True\n        )\n        remat_scope_stack.append(self)\n        self._pop_on_exit = True\n        return self\n\n    def __exit__(self, *args, **kwargs):\n        if self._pop_on_exit:\n            remat_scope_stack = global_state.get_global_attribute(\n                \"remat_scope_stack\"\n            )\n            remat_scope_stack.pop()\n\n\nRematMode = namedtuple(\n    \"RematMode\", [\"mode\", \"output_size_threshold\", \"layer_names\"]\n)\n\n\ndef get_current_remat_mode():\n    \"\"\"Get the current rematerialization mode and associated settings.\n\n    Returns:\n        RematMode or None: The current rematerialization mode, or None if not\n        set.\n    \"\"\"\n    remat_scope_stack = global_state.get_global_attribute(\"remat_scope_stack\")\n    if not remat_scope_stack:\n        return None\n    active_scope = remat_scope_stack[-1]\n    return RematMode(\n        active_scope.mode,\n        active_scope.output_size_threshold,\n        active_scope.layer_names,\n    )\n\n\n@keras_export(\"keras.remat\")\ndef remat(f):\n    \"\"\"Applies rematerialization to a function or layer for memory optimization.\n\n    Rematerialization is a memory optimization technique that trades off\n    computation for memory. Instead of storing intermediate results\n    (e.g. activations) for backpropagation, they are recomputed during the\n    backward pass. 
This reduces peak memory usage at the cost of increased\n    computation time, allowing the training of larger models or using larger\n    batch sizes within the same memory constraints.\n\n    Args:\n        f: A callable function, to which rematerialization is\n           applied. This is typically a computationally expensive operation\n           where intermediate states can be recomputed instead of stored.\n\n    Returns:\n        A wrapped function that applies rematerialization. The returned\n        function defines a custom gradient, ensuring that during the backward\n        pass, the forward computation is recomputed as needed.\n\n    Example:\n\n    ```python\n    from keras import Model\n    from keras import layers\n    from keras import remat\n\n    class CustomRematLayer(layers.Layer):\n        def __init__(self, **kwargs):\n            super().__init__(**kwargs)\n            self.remat_function = remat(self.intermediate_function)\n\n        def intermediate_function(self, x):\n            for _ in range(2):\n                x = x + x * 0.1  # Simple scaled transformation\n            return x\n\n        def call(self, inputs):\n            return self.remat_function(inputs)\n\n    # Define a simple model using the custom layer\n    inputs = layers.Input(shape=(4,))\n    x = layers.Dense(4, activation=\"relu\")(inputs)\n    x = CustomRematLayer()(x)  # Custom layer with rematerialization\n    outputs = layers.Dense(1)(x)\n\n    # Create and compile the model\n    model = Model(inputs=inputs, outputs=outputs)\n    model.compile(optimizer=\"sgd\", loss=\"mse\")\n    ```\n    \"\"\"\n    return backend.core.remat(f)\n"
  },
  {
    "path": "keras/src/backend/common/remat_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src.backend.common import global_state\nfrom keras.src.backend.common.remat import RematScope\nfrom keras.src.backend.common.remat import get_current_remat_mode\nfrom keras.src.layers import activations\n\n\nclass TestRematScope(testing.TestCase):\n    def test_remat_scope_activation(self):\n        self.assertIsNone(\n            get_current_remat_mode()\n        )  # Initially, no mode is active\n\n        with RematScope(mode=\"full\"):\n            self.assertEqual(\n                get_current_remat_mode().mode, \"full\"\n            )  # Mode is set to \"full\"\n\n        self.assertIsNone(\n            get_current_remat_mode()\n        )  # Mode is restored to None after scope ends\n\n    def test_remat_scope_nested(self):\n        \"\"\"Test nested scopes with different rematerialization modes.\"\"\"\n        with RematScope(mode=\"full\"):\n            self.assertEqual(\n                get_current_remat_mode().mode, \"full\"\n            )  # Outer scope is \"full\"\n\n            with RematScope(mode=\"activations\"):\n                self.assertEqual(\n                    get_current_remat_mode().mode, \"activations\"\n                )  # Inner scope is \"activations\"\n\n            self.assertEqual(\n                get_current_remat_mode().mode, \"full\"\n            )  # Back to outer scope\n\n        self.assertIsNone(\n            get_current_remat_mode()\n        )  # Mode is restored to None after all scopes\n\n    def test_remat_scope_stack_management(self):\n        \"\"\"Test that the remat_scope_stack is managed correctly.\"\"\"\n        self.assertIsNone(\n            global_state.get_global_attribute(\"remat_scope_stack\")\n        )  # No stack initially\n\n        with RematScope(mode=\"full\"):\n            remat_stack = 
global_state.get_global_attribute(\"remat_scope_stack\")\n            self.assertIsNotNone(remat_stack)  # Stack is initialized\n            self.assertEqual(len(remat_stack), 1)  # Stack contains one entry\n\n            with RematScope(mode=\"activations\"):\n                remat_stack = global_state.get_global_attribute(\n                    \"remat_scope_stack\"\n                )\n                self.assertEqual(\n                    len(remat_stack), 2\n                )  # Stack contains two entries\n\n            remat_stack = global_state.get_global_attribute(\"remat_scope_stack\")\n            self.assertEqual(len(remat_stack), 1)  # Back to one entry\n\n        self.assertEqual(\n            global_state.get_global_attribute(\"remat_scope_stack\"), []\n        )  # Stack is cleared\n\n    def test_invalid_mode(self):\n        \"\"\"Test that invalid rematerialization modes raise an error.\"\"\"\n        with self.assertRaises(ValueError):\n            RematScope(mode=\"invalid\")  # Invalid mode should raise ValueError\n\n\n@pytest.mark.skipif(\n    backend.backend() in (\"openvino\", \"numpy\"),\n    reason=\"remat not supported on OpenVino and Numpy\",\n)\nclass RematTest(testing.TestCase):\n    def test_remat_basic_call(self):\n        # Generate dummy data\n        data_size = 10**5\n        x_train = np.random.normal(size=(data_size, 4))\n        y_train = np.random.normal(size=(data_size, 1))\n\n        epochs = 5\n        batch_size = 512\n        # test applying remat\n        output_with_remat = backend.core.remat(activations.ReLU())(x_train)\n        output_without_remat = activations.ReLU()(x_train)\n        self.assertAllClose(output_with_remat, output_without_remat)\n        # test remat in a model\n        intermediate_function = backend.core.remat(activations.ReLU())\n        inputs = layers.Input(shape=(4,))\n        x = layers.Dense(4)(inputs)\n        x = layers.Lambda(intermediate_function)(x)\n        outputs = layers.Dense(1)(x)\n  
      model = models.Model(inputs=inputs, outputs=outputs)\n        model.predict(x_train)\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n\n        # Train model\n        model.fit(\n            x_train,\n            y_train,\n            epochs=epochs,\n            batch_size=batch_size,\n            verbose=0,\n        )\n\n    def test_remat_with_kwargs(self):\n        # Define a function that uses keyword arguments\n        def fn_with_kwargs(x, scale=1.0, offset=0.0):\n            return x * scale + offset\n\n        x = np.array([1.0, 2.0, 3.0], dtype=np.float32)\n\n        # Test with keyword arguments\n        remat_fn = backend.core.remat(fn_with_kwargs)\n        result_with_kwargs = remat_fn(x, scale=2.0, offset=1.0)\n        expected = fn_with_kwargs(x, scale=2.0, offset=1.0)\n        self.assertAllClose(result_with_kwargs, expected)\n\n        # Test with default keyword arguments\n        result_with_defaults = remat_fn(x)\n        expected_defaults = fn_with_kwargs(x)\n        self.assertAllClose(result_with_defaults, expected_defaults)\n\n        # Test with partial keyword arguments\n        result_partial = remat_fn(x, scale=3.0)\n        expected_partial = fn_with_kwargs(x, scale=3.0)\n        self.assertAllClose(result_partial, expected_partial)\n"
  },
  {
    "path": "keras/src/backend/common/stateless_scope.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\n\n\n@keras_export(\"keras.StatelessScope\")\nclass StatelessScope:\n    \"\"\"Scope to prevent any update to Keras Variables.\n\n    The values of variables to be used inside the scope\n    should be passed via the `state_mapping` argument, a\n    list of tuples `(k, v)` where `k` is a `Variable`\n    and `v` is the intended value for this variable\n    (a backend tensor).\n\n    Updated values can be collected on scope exit via\n    `value = scope.get_current_value(variable)`. No updates\n    will be applied in-place to any variables for the duration\n    of the scope.\n\n    Example:\n\n    ```python\n    state_mapping = [(k, ops.ones(k.shape, k.dtype)) for k in model.weights]\n    with keras.StatelessScope(state_mapping) as scope:\n        outputs = model.some_function(inputs)\n\n    # All model variables remain unchanged. Their new values can be\n    # collected via:\n    for k in model.weights:\n        new_value = scope.get_current_value(k)\n        print(f\"New value for {k}: {new_value}\")\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        state_mapping=None,\n        collect_losses=False,\n        initialize_variables=True,\n    ):\n        from keras.src import backend\n        from keras.src.backend.common.variables import Variable\n\n        self.collect_losses = collect_losses\n        self.initialize_variables = initialize_variables\n        self.losses = []\n        self.state_mapping = {}\n        state_mapping = state_mapping or {}\n        for k, v in state_mapping:\n            if not isinstance(k, Variable):\n                raise ValueError(\n                    \"Invalid reference variable in StatelessScope: \"\n                    \"all keys in argument `mapping` must be Variable \"\n                    f\"instances. 
Received instead: {k}\"\n                )\n            if isinstance(v, Variable):\n                v = backend.cast(v.value, dtype=k.dtype)\n            else:\n                v = backend.convert_to_tensor(v, dtype=k.dtype)\n            if k.shape != v.shape:\n                raise ValueError(\n                    \"Invalid variable value in StatelessScope: \"\n                    \"all values in argument `mapping` must be tensors with \"\n                    \"a shape that matches the corresponding variable shape. \"\n                    f\"For variable {k}, received invalid value {v} with shape \"\n                    f\"{v.shape}.\"\n                )\n            self.state_mapping[id(k)] = v\n\n    def __enter__(self):\n        self.original_scope = get_stateless_scope()\n        global_state.set_global_attribute(\"stateless_scope\", self)\n        return self\n\n    def add_loss(self, loss):\n        self.losses.append(loss)\n\n    def add_update(self, update):\n        variable, value = update\n        self.state_mapping[id(variable)] = value\n\n    def get_current_value(self, variable):\n        return self.state_mapping.get(id(variable), None)\n\n    def __exit__(self, *args, **kwargs):\n        global_state.set_global_attribute(\n            \"stateless_scope\", self.original_scope\n        )\n        if self.original_scope is None and self.initialize_variables:\n            # We're back in eager scope;\n            # if any variables were created within the stateless\n            # scope, we initialize them here.\n            from keras.src.backend.common.variables import (\n                initialize_all_variables,\n            )\n\n            initialize_all_variables()\n\n\ndef in_stateless_scope():\n    return global_state.get_global_attribute(\"stateless_scope\") is not None\n\n\ndef get_stateless_scope():\n    return global_state.get_global_attribute(\"stateless_scope\")\n"
  },
  {
    "path": "keras/src/backend/common/stateless_scope_test.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.backend.common.stateless_scope import StatelessScope\n\n\nclass TestStatelessScope(testing.TestCase):\n    def test_basic_flow(self):\n        var1 = backend.Variable(np.zeros((2,)))\n        var2 = backend.Variable(np.zeros((2,)))\n        var_out = backend.Variable(np.zeros((2,)))\n\n        value1 = ops.ones(shape=(2,))\n        value2 = ops.ones(shape=(2,))\n        with StatelessScope(\n            state_mapping=[(var1, value1), (var2, value2)]\n        ) as scope:\n            out = var1 + var2\n            var_out.assign(out)\n            var_out_value = var_out + 0.0\n            # Inside scope: new value is used.\n            self.assertAllClose(var_out_value, 2 * np.ones((2,)))\n\n        # Out of scope: old value is used.\n        var_out_value = var_out + 0.0\n        self.assertAllClose(var_out_value, np.zeros((2,)))\n\n        # Updates are tracked.\n        var_out_value = scope.get_current_value(var_out)\n        self.assertAllClose(var_out_value, 2 * np.ones((2,)))\n\n        # Updates can be reapplied.\n        var_out.assign(scope.get_current_value(var_out))\n        self.assertAllClose(var_out_value, 2 * np.ones((2,)))\n\n    def test_invalid_key_in_state_mapping(self):\n        invalid_key = \"not_a_keras_variable\"\n        value1 = ops.ones(shape=(2,))\n\n        with self.assertRaisesRegex(\n            ValueError, \"all keys in argument `mapping` must be Variable\"\n        ):\n            StatelessScope(state_mapping=[(invalid_key, value1)])\n\n    def test_invalid_value_shape_in_state_mapping(self):\n        var1 = backend.Variable(np.zeros((2,)))\n        invalid_value = ops.ones(shape=(3,))  # Incorrect shape\n\n        with self.assertRaisesRegex(\n            ValueError, \"all values in argument `mapping` must be tensors with\"\n        ):\n            StatelessScope(state_mapping=[(var1, invalid_value)])\n"
  },
  {
    "path": "keras/src/backend/common/symbolic_scope.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\n\n\n@keras_export(\"keras.SymbolicScope\")\nclass SymbolicScope:\n    \"\"\"Scope to indicate the symbolic stage.\"\"\"\n\n    def __enter__(self):\n        self.original_scope = get_symbolic_scope()\n        global_state.set_global_attribute(\"symbolic_scope\", self)\n        return self\n\n    def __exit__(self, *args, **kwargs):\n        global_state.set_global_attribute(\"symbolic_scope\", self.original_scope)\n\n\ndef in_symbolic_scope():\n    return global_state.get_global_attribute(\"symbolic_scope\") is not None\n\n\ndef get_symbolic_scope():\n    return global_state.get_global_attribute(\"symbolic_scope\")\n"
  },
  {
    "path": "keras/src/backend/common/symbolic_scope_test.py",
    "content": "import numpy as np\n\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.backend.common.symbolic_scope import SymbolicScope\nfrom keras.src.backend.common.symbolic_scope import in_symbolic_scope\n\n\nclass TestSymbolicScope(testing.TestCase):\n    def test_basic_flow(self):\n        # Define a function that behaves differently according to\n        # `in_symbolic_scope`.\n        def compute_loss(y, y_pred):\n            if in_symbolic_scope():\n                return ops.zeros_like(y)\n            return ops.add(y, y_pred)\n\n        y = ops.ones(shape=(2,))\n        y_pred = ops.ones(shape=(2,))\n        with SymbolicScope():\n            loss = compute_loss(y, y_pred)\n        self.assertAllClose(loss, np.zeros((2,)))\n\n        loss = compute_loss(y, y_pred)\n        self.assertAllClose(loss, 2 * np.ones((2,)))\n"
  },
  {
    "path": "keras/src/backend/common/tensor_attributes.py",
    "content": "import weakref\n\nfrom keras.src.backend.common import global_state\n\n\ndef _clear_tensor_attr(tensor_id, attr):\n    attr_dict = global_state.get_global_attribute(f\"{attr}_dict\")\n    if attr_dict is not None and tensor_id in attr_dict:\n        del attr_dict[tensor_id]\n\n\ndef set_tensor_attr(tensor, attr, value):\n    try:\n        setattr(tensor, attr, value)\n    except AttributeError:\n        attr_dict = global_state.get_global_attribute(f\"{attr}_dict\")\n        if attr_dict is None:\n            if value is None:\n                return\n            attr_dict = {}\n            global_state.set_global_attribute(f\"{attr}_dict\", attr_dict)\n        if value is not None:\n            attr_dict[id(tensor)] = value\n            weakref.finalize(tensor, _clear_tensor_attr, id(tensor), attr)\n        elif id(tensor) in attr_dict:\n            del attr_dict[id(tensor)]\n\n\ndef get_tensor_attr(tensor, attr):\n    if not hasattr(tensor, attr):\n        attr_dict = global_state.get_global_attribute(f\"{attr}_dict\")\n        if attr_dict is not None:\n            return attr_dict.get(id(tensor), None)\n        else:\n            return None\n    return getattr(tensor, attr, None)\n"
  },
  {
    "path": "keras/src/backend/common/thread_safe_test.py",
    "content": "import concurrent.futures\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass TestThreadSafe(testing.TestCase):\n    def test_is_thread_safe(self):\n        if backend.IS_THREAD_SAFE:\n            executor = concurrent.futures.ThreadPoolExecutor()\n\n            def sum(x, axis):\n                return ops.sum(x, axis=axis)\n\n            futures = []\n\n            for i in range(10000):\n                futures.clear()\n                x = ops.convert_to_tensor(np.random.rand(100, 100))\n                futures.append(executor.submit(sum, x, 1))\n                x = ops.convert_to_tensor(np.random.rand(100))\n                futures.append(executor.submit(sum, x, 0))\n                concurrent.futures.wait(\n                    futures, return_when=concurrent.futures.ALL_COMPLETED\n                )\n                [future.result() for future in futures]\n"
  },
  {
    "path": "keras/src/backend/common/variables.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import config\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common import global_state\nfrom keras.src.backend.common.name_scope import current_path\nfrom keras.src.backend.common.stateless_scope import get_stateless_scope\nfrom keras.src.backend.common.stateless_scope import in_stateless_scope\nfrom keras.src.utils.module_utils import tensorflow as tf\nfrom keras.src.utils.naming import auto_name\n\n\nclass Variable:\n    \"\"\"Represents a backend-agnostic variable in Keras.\n\n    A `Variable` acts as a container for state. It holds a tensor value and can\n    be updated. With the JAX backend, variables are used to implement\n    \"functionalization\", the pattern of lifting stateful operations out of\n    a piece of computation to turn it into a stateless function.\n\n    Args:\n        initializer: Initial value or callable for initialization.\n            If a callable is used, it should take the arguments\n            `shape` and `dtype`.\n        shape: Optional. Tuple for the variable's shape.\n            Required if `initializer` is a callable.\n        dtype: Optional. Data type of the variable. Defaults to the global float\n            dtype type (`\"float32\"` if never configured).\n        trainable: Optional. Boolean indicating if variable is trainable.\n            Defaults to `True`.\n        autocast: Optional. Boolean indicating whether the variable supports\n            autocasting. If `True`, the layer may first convert the variable\n            to the compute data type when accessed. Defaults to `True`.\n        aggregation: Optional string, one of `None`, `\"none\"`, `\"mean\"`,\n            `\"sum\"` or `\"only_first_replica\"` specifying how a distributed\n            variable will be aggregated. 
This serves as a semantic annotation,\n            to be taken into account by downstream backends or users. Defaults\n            to `\"none\"`.\n        name: Optional. A unique name for the variable. Automatically generated\n            if not set.\n\n    Attributes:\n        shape: The shape of the variable (tuple of integers).\n        ndim: The number of dimensions of the variable (integer).\n        dtype: The data type of the variable (string).\n        trainable: Whether the variable is trainable (boolean).\n        autocast: Whether the variable supports autocasting (boolean).\n        aggregation: How a distributed variable will be aggregated (string).\n        value: The current value of the variable (NumPy array or tensor).\n        name: The name of the variable (string).\n        path: The path of the variable within the Keras model or layer (string).\n        kwargs: Additional backend-specific keyword arguments.\n\n    Examples:\n\n    **Initializing a `Variable` with a NumPy array:**\n\n    ```python\n    import numpy as np\n    import keras\n    initial_array = np.ones((3, 3))\n    variable_from_array = keras.Variable(initializer=initial_array)\n    ```\n\n    **Using a Keras initializer to create a `Variable`:**\n\n    ```python\n    from keras.src.initializers import Ones\n    variable_from_initializer = keras.Variable(\n        initializer=Ones(), shape=(3, 3), dtype=\"float32\"\n    )\n    ```\n\n    **Updating the value of a `Variable`:**\n\n    ```python\n    new_value = np.zeros((3, 3), dtype=\"float32\")\n    variable_from_array.assign(new_value)\n    ```\n\n    **Marking a `Variable` as non-trainable:**\n\n    ```python\n    non_trainable_variable = keras.Variable(\n        initializer=np.ones((3, 3), dtype=\"float32\"), trainable=False\n    )\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        initializer,\n        shape=None,\n        dtype=None,\n        trainable=True,\n        autocast=True,\n        
aggregation=\"none\",\n        synchronization=\"auto\",\n        name=None,\n        **kwargs,\n    ):\n        del kwargs\n        name = name or auto_name(self.__class__.__name__)\n        if not isinstance(name, str) or \"/\" in name:\n            raise ValueError(\n                \"Argument `name` must be a string and \"\n                \"cannot contain character `/`. \"\n                f\"Received: name={name}\"\n            )\n        if aggregation not in (\n            None,\n            \"none\",\n            \"mean\",\n            \"sum\",\n            \"only_first_replica\",\n        ):\n            raise ValueError(\n                \"Invalid value for argument `aggregation`. Expected \"\n                \"one of `None`, `'none'`, `'mean'`, `'sum'`, \"\n                \"`'only_first_replica'`. \"\n                f\"Received: aggregation={aggregation}\"\n            )\n        if aggregation is None:\n            aggregation = \"none\"\n        if synchronization not in (\n            None,\n            \"none\",\n            \"on_read\",\n            \"on_write\",\n            \"auto\",\n        ):\n            raise ValueError(\n                \"Invalid value for argument `synchronization`. Expected \"\n                \"one of `None`, `'none'`, `'on_read'`, `'on_write'`, \"\n                \"`'auto'`. 
\"\n                f\"Received: synchronization={synchronization}\"\n            )\n        if synchronization is None:\n            synchronization = \"none\"\n        self._name = name\n        parent_path = current_path()\n        if parent_path:\n            self._path = f\"{parent_path}/{name}\"\n        else:\n            self._path = name\n        self._shape = None\n        self._initializer = None\n        self._regularizer = None\n        self._constraint = None\n        self._trainable = bool(trainable)\n        self._autocast = bool(autocast)\n        self._aggregation = aggregation\n        self._synchronization = synchronization\n        # `self._overwrite_with_gradient` is an internal property to determine\n        # whether this variable should be overwritten by the computed gradient.\n        # Ref: https://github.com/google/flax/blob/main/flax/linen/fp8_ops.py\n        self._overwrite_with_gradient = False\n        if isinstance(initializer, str):\n            from keras.src import initializers\n\n            initializer = initializers.get(initializer)\n        if callable(initializer):\n            if shape is None:\n                raise ValueError(\n                    \"When creating a Variable from an initializer, \"\n                    \"the `shape` argument should be specified. 
\"\n                    f\"Received: initializer={initializer} \"\n                    f\"and shape={shape}\"\n                )\n        else:\n            initializer = self._convert_to_tensor(initializer, dtype=dtype)\n            # If dtype is None and `initializer` is an array, use its dtype.\n            if dtype is None:\n                dtype = initializer.dtype\n        self._dtype = standardize_dtype(dtype)\n\n        if in_stateless_scope():\n            if callable(initializer):\n                self._value = None\n                self._initializer = initializer\n                self._shape = self._validate_shape(shape)\n                register_uninitialized_variable(self)\n            else:\n                raise ValueError(\n                    \"You are attempting to create a variable \"\n                    \"while in a stateless scope. This is disallowed. \"\n                    \"Make sure that all variables are created \"\n                    \"before you start using your layer/model objects.\\n\\n\"\n                    \"In some cases, you might be seeing this error \"\n                    \"because you need to \"\n                    \"implement a `def build(self, input_shape)` method \"\n                    \"on your layer/model, which will \"\n                    \"create its variables.\\n\\n\"\n                    \"In some other cases, you might be seeing this error \"\n                    \"because you are instantiating a `Variable` and \"\n                    \"assigning it to a layer without going through \"\n                    \"self.add_variable()/self.add_weight(). 
Always prefer \"\n                    \"using these methods \"\n                    \"(with a `shape` and `initializer` argument).\"\n                )\n        else:\n            if callable(initializer):\n                self._shape = self._validate_shape(shape)\n                self._initialize_with_initializer(initializer)\n            else:\n                self._initialize(initializer)\n                self._shape = self._validate_shape(self._value.shape)\n        self._ndim = len(self._shape)\n\n    def _deferred_initialize(self):\n        if self._value is not None:\n            # If NNX is enabled, it's possible the variable was already\n            # initialized by a concrete call. In this case, _deferred_initialize\n            # returns early and does not raise an error.\n            if config.is_nnx_enabled():\n                return\n            raise ValueError(f\"Variable {self.path} is already initialized.\")\n\n        if in_stateless_scope():\n            raise ValueError(\n                \"You are attempting to initialize a variable \"\n                \"while in a stateless scope. This is disallowed. \"\n                \"Make sure that all variables are initialized \"\n                \"before you start using your layer/model objects.\"\n            )\n        self._initialize_with_initializer(self._initializer)\n        self._initializer = None\n\n    def _validate_shape(self, shape):\n        shape = standardize_shape(shape)\n        if None in shape:\n            raise ValueError(\n                \"Shapes used to initialize variables must be \"\n                \"fully-defined (no `None` dimensions). 
Received: \"\n                f\"shape={shape} for variable path='{self.path}'\"\n            )\n        return shape\n\n    def _maybe_autocast(self, value):\n        autocast_scope = get_autocast_scope()\n        if self._autocast and autocast_scope is not None:\n            return autocast_scope.maybe_cast(value)\n        return value\n\n    def numpy(self):\n        return np.array(self)\n\n    @property\n    def aggregation(self):\n        \"\"\"The strategy for aggregating this variable.\"\"\"\n        return self._aggregation\n\n    @property\n    def synchronization(self):\n        \"\"\"The strategy for synchronizing this variable.\"\"\"\n        return self._synchronization\n\n    @property\n    def value(self):\n        \"\"\"The current value of the variable (numpy array or backend tensor).\"\"\"\n        if in_stateless_scope():\n            scope = get_stateless_scope()\n            value = scope.get_current_value(self)\n            if value is not None:\n                return self._maybe_autocast(value)\n        if self._value is None:\n            # Uninitialized variable. Return a placeholder.\n            # This is fine because it's only ever used\n            # during shape inference / graph tracing\n            # (anything else would be a bug, to be fixed.)\n            return self._maybe_autocast(\n                self._initializer(self._shape, dtype=self._dtype)\n            )\n        return self._maybe_autocast(self._value)\n\n    def assign(self, value):\n        value = self._convert_to_tensor(value, dtype=self._dtype)\n        if not shape_equal(value.shape, self.shape):\n            raise ValueError(\n                \"The shape of the target variable and \"\n                \"the shape of the target value in \"\n                \"`variable.assign(value)` must match. \"\n                f\"variable.shape={self.shape}, \"\n                f\"Received: value.shape={value.shape}. 
\"\n                f\"Target variable: {self}\"\n            )\n        if in_stateless_scope():\n            scope = get_stateless_scope()\n            scope.add_update((self, value))\n        else:\n            self._direct_assign(value)\n        return value\n\n    def assign_add(self, value):\n        return self.assign(self + value)\n\n    def assign_sub(self, value):\n        return self.assign(self - value)\n\n    @property\n    def dtype(self):\n        \"\"\"The data type of the variable.\"\"\"\n        autocast_scope = get_autocast_scope()\n        if (\n            self._autocast\n            and autocast_scope is not None\n            and is_float_dtype(self._dtype)\n        ):\n            dtype = autocast_scope.dtype\n        else:\n            dtype = self._dtype\n        return backend.standardize_dtype(dtype)\n\n    @property\n    def shape(self):\n        \"\"\"The shape of the variable.\"\"\"\n        return self._shape\n\n    @property\n    def ndim(self):\n        \"\"\"The number of dimensions of the variable.\"\"\"\n        return self._ndim\n\n    @property\n    def trainable(self):\n        \"\"\"Whether the variable is trainable.\"\"\"\n        return self._trainable\n\n    @trainable.setter\n    def trainable(self, value):\n        self._trainable = bool(value)\n\n    @property\n    def name(self):\n        \"\"\"The name of the variable.\"\"\"\n        return self._name\n\n    @property\n    def path(self):\n        \"\"\"The path of the variable within the Keras model or layer.\"\"\"\n        return self._path\n\n    @property\n    def overwrite_with_gradient(self):\n        \"\"\"Whether this variable should be overwritten by the gradient.\n\n        This property is designed for a special case where we want to overwrite\n        the variable directly with its computed gradient. 
For example, in float8\n        training, new `scale` and `amax_history` are computed as gradients, and\n        we want to overwrite them directly instead of following the typical\n        procedure such as gradient descent with a learning rate, gradient\n        clipping and weight decaying.\n        \"\"\"\n        return self._overwrite_with_gradient\n\n    @overwrite_with_gradient.setter\n    def overwrite_with_gradient(self, value):\n        if not isinstance(value, bool):\n            raise TypeError(\n                \"`overwrite_with_gradient` must be a boolean. \"\n                f\"Received: {value}\"\n            )\n        self._overwrite_with_gradient = value\n\n    @property\n    def regularizer(self):\n        return self._regularizer\n\n    @regularizer.setter\n    def regularizer(self, value):\n        from keras.src.regularizers import Regularizer\n\n        if value is not None and not isinstance(value, Regularizer):\n            raise ValueError(\n                \"Invalid value for attribute `regularizer`. Expected an \"\n                \"instance of `keras.regularizers.Regularizer`, or `None`. \"\n                f\"Received: regularizer={value}\"\n            )\n        self._regularizer = value\n\n    @property\n    def constraint(self):\n        return self._constraint\n\n    @constraint.setter\n    def constraint(self, value):\n        from keras.src.constraints import Constraint\n\n        if value is not None and not isinstance(value, Constraint):\n            raise ValueError(\n                \"Invalid value for attribute `constraint`. Expected an \"\n                \"instance of `keras.constraints.Constraint`, or `None`. 
\"\n                f\"Received: constraint={value}\"\n            )\n        self._constraint = value\n\n    def __repr__(self):\n        value = None\n        if hasattr(self, \"_value\") and self._value is not None:\n            try:\n                value = backend.core.convert_to_numpy(self._value)\n            except:\n                # In some cases the conversion to numpy can fail.\n                pass\n        value_str = f\", value={value}\" if value is not None else \"\"\n        return (\n            f\"<Variable path={self.path}, shape={self.shape}, \"\n            f\"dtype={self.dtype}{value_str}>\"\n        )\n\n    def _initialize(self, value):\n        raise NotImplementedError\n\n    def _initialize_with_initializer(self, initializer):\n        value = self._convert_to_tensor(\n            initializer(self._shape, dtype=self._dtype)\n        )\n        self._initialize(value)\n\n    def _convert_to_tensor(self, value, dtype=None):\n        raise NotImplementedError\n\n    def __getitem__(self, idx):\n        return self.value.__getitem__(idx)\n\n    def __int__(self):\n        if self.ndim > 0:\n            raise TypeError(\n                \"Only scalar arrays can be converted to Python scalars. \"\n                f\"Got: shape={self.shape}\"\n            )\n        return int(self.value)\n\n    def __float__(self):\n        if self.ndim > 0:\n            raise TypeError(\n                \"Only scalar arrays can be converted to Python scalars. \"\n                f\"Got: shape={self.shape}\"\n            )\n        return float(self.value)\n\n    def __array__(self, dtype=None):\n        # We can't directly use self.value.__array__ here because of scalar.\n        # Numpy require this method to return as array like object. In the case\n        # of scalar, it will fail the type checking from numpy. 
We need to\n        # return a 0d array via numpy.\n        return np.asarray(self.value.__array__(dtype))\n\n    def __bool__(self):\n        raise TypeError(\"A Keras Variable cannot be used as a boolean.\")\n\n    def __neg__(self):\n        return self.value.__neg__()\n\n    def __pos__(self):\n        return self.value\n\n    def __abs__(self):\n        return self.value.__abs__()\n\n    def __invert__(self):\n        return self.value.__invert__()\n\n    def __eq__(self, other):\n        return backend.numpy.equal(self.value, other)\n\n    def __ne__(self, other):\n        return backend.numpy.not_equal(self.value, other)\n\n    def __lt__(self, other):\n        return backend.numpy.less(self.value, other)\n\n    def __le__(self, other):\n        return backend.numpy.less_equal(self.value, other)\n\n    def __gt__(self, other):\n        return backend.numpy.greater(self.value, other)\n\n    def __ge__(self, other):\n        return backend.numpy.greater_equal(self.value, other)\n\n    def __add__(self, other):\n        return backend.numpy.add(self.value, other)\n\n    def __radd__(self, other):\n        return backend.numpy.add(other, self.value)\n\n    def __sub__(self, other):\n        return backend.numpy.subtract(self.value, other)\n\n    def __rsub__(self, other):\n        return backend.numpy.subtract(other, self.value)\n\n    def __mul__(self, other):\n        return backend.numpy.multiply(self.value, other)\n\n    def __rmul__(self, other):\n        return backend.numpy.multiply(other, self.value)\n\n    def __truediv__(self, other):\n        return backend.numpy.true_divide(self.value, other)\n\n    def __rtruediv__(self, other):\n        return backend.numpy.true_divide(other, self.value)\n\n    def __floordiv__(self, other):\n        return backend.numpy.floor_divide(self.value, other)\n\n    def __rfloordiv__(self, other):\n        return backend.numpy.floor_divide(other, self.value)\n\n    def __mod__(self, other):\n        return 
backend.numpy.mod(self.value, other)\n\n    def __rmod__(self, other):\n        return backend.numpy.mod(other, self.value)\n\n    def __pow__(self, other):\n        return backend.numpy.power(self.value, other)\n\n    def __rpow__(self, other):\n        return backend.numpy.power(other, self.value)\n\n    def __matmul__(self, other):\n        return backend.numpy.matmul(self.value, other)\n\n    def __rmatmul__(self, other):\n        return backend.numpy.matmul(other, self.value)\n\n    def __and__(self, other):\n        return backend.numpy.logical_and(self.value, other)\n\n    def __rand__(self, other):\n        return backend.numpy.logical_and(other, self.value)\n\n    def __or__(self, other):\n        return backend.numpy.logical_or(self.value, other)\n\n    def __ror__(self, other):\n        return backend.numpy.logical_or(other, self.value)\n\n    def __xor__(self, other):\n        return backend.numpy.logical_xor(self.value, other)\n\n    def __rxor__(self, other):\n        return backend.numpy.logical_xor(other, self.value)\n\n    def __round__(self, ndigits=None):\n        decimals = ndigits or 0\n        return backend.numpy.round(self.value, decimals=decimals)\n\n\ndef register_uninitialized_variable(variable):\n    uninitialized_variables = global_state.get_global_attribute(\n        \"uninitialized_variables\", [], set_to_default=True\n    )\n    uninitialized_variables.append(variable)\n\n\ndef initialize_all_variables():\n    collection = global_state.get_global_attribute(\"uninitialized_variables\")\n    if collection:\n        for v in collection:\n            v._deferred_initialize()\n    global_state.set_global_attribute(\"uninitialized_variables\", [])\n\n\n@keras_export(\n    [\"keras.utils.standardize_dtype\", \"keras.backend.standardize_dtype\"]\n)\ndef standardize_dtype(dtype):\n    if dtype is None:\n        return config.floatx()\n    dtype = dtypes.PYTHON_DTYPES_MAP.get(dtype, dtype)\n    if hasattr(dtype, \"name\"):\n        dtype = 
dtype.name\n    elif hasattr(dtype, \"__name__\"):\n        dtype = dtype.__name__\n    elif hasattr(dtype, \"__str__\") and (\n        \"torch\" in str(dtype) or \"jax.numpy\" in str(dtype)\n    ):\n        dtype = str(dtype).split(\".\")[-1]\n\n    if dtype not in dtypes.ALLOWED_DTYPES:\n        raise ValueError(f\"Invalid dtype: {dtype}\")\n    return dtype\n\n\ndef standardize_shape(shape):\n    if not isinstance(shape, tuple):\n        if shape is None:\n            raise ValueError(\"Undefined shapes are not supported.\")\n        if not hasattr(shape, \"__iter__\"):\n            raise ValueError(f\"Cannot convert '{shape}' to a shape.\")\n        if config.backend() == \"tensorflow\":\n            if isinstance(shape, tf.TensorShape):\n                # `tf.TensorShape` may contain `Dimension` objects.\n                # We need to convert the items in it to either int or `None`\n                shape = shape.as_list()\n\n    if config.backend() == \"jax\":\n        # Replace `_DimExpr` (dimension expression) with None\n        from jax import export as jax_export\n\n        shape = tuple(\n            None if jax_export.is_symbolic_dim(d) else d for d in shape\n        )\n\n    if config.backend() == \"torch\":\n        # Replace symbolic dimensions with None to preserve dynamic shapes\n        # during torch.export tracing\n        import torch\n\n        shape = tuple(None if isinstance(d, torch.SymInt) else d for d in shape)\n\n    # Handle dimensions that are not ints and not None, verify they're >= 0.\n    standardized_shape = []\n    for d in shape:\n        if d is None:\n            standardized_shape.append(d)\n            continue\n\n        # Reject these even if they can be cast to int successfully.\n        if isinstance(d, (str, float)):\n            raise ValueError(\n                f\"Cannot convert '{shape}' to a shape. \"\n                f\"Found invalid dimension '{d}' of type '{type(d)}'. 
\"\n            )\n\n        try:\n            # Cast numpy scalars, tf constant tensors, etc.\n            d = int(d)\n        except Exception as e:\n            raise ValueError(\n                f\"Cannot convert '{shape}' to a shape. \"\n                f\"Found invalid dimension '{d}' of type '{type(d)}'. \"\n            ) from e\n        if d < 0:\n            raise ValueError(\n                f\"Cannot convert '{shape}' to a shape. \"\n                \"Negative dimensions are not allowed.\"\n            )\n        standardized_shape.append(d)\n\n    # This also turns subclasses of `tuple` (e.g. `torch.Size`) to plain tuple.\n    return tuple(standardized_shape)\n\n\ndef shape_equal(a_shape, b_shape):\n    \"\"\"Return whether a_shape == b_shape (allows None entries).\"\"\"\n    if len(a_shape) != len(b_shape):\n        return False\n    for e1, e2 in zip(a_shape, b_shape):\n        if e1 is not None and e2 is not None and e1 != e2:\n            return False\n    return True\n\n\n@keras_export(\"keras.backend.is_float_dtype\")\ndef is_float_dtype(dtype):\n    dtype = standardize_dtype(dtype)\n    return dtype.startswith(\"float\") or dtype.startswith(\"bfloat\")\n\n\n@keras_export(\"keras.backend.is_int_dtype\")\ndef is_int_dtype(dtype):\n    dtype = standardize_dtype(dtype)\n    return dtype.startswith(\"int\") or dtype.startswith(\"uint\")\n\n\ndef get_autocast_scope():\n    return global_state.get_global_attribute(\"autocast_scope\")\n\n\nclass AutocastScope:\n    \"\"\"Context manager that enables the autocasting of float variables.\n\n    Under this context manager, float `Variable`s will be cast to `dtype`\n    (note that `dtype` must also be float).\n    \"\"\"\n\n    def __init__(self, dtype):\n        if dtype is not None:\n            dtype = standardize_dtype(dtype)\n            if not is_float_dtype(dtype):\n                raise ValueError(\n                    \"`AutocastScope` can only be used with \"\n                    \"a floating-point 
target dtype, such as 'float16'. \"\n                    f\"Received: dtype={dtype}\"\n                )\n        self.dtype = dtype\n        self.original_scope = None\n\n    def maybe_cast(self, value):\n        from keras.src import backend\n\n        if self.dtype is not None and is_float_dtype(value.dtype):\n            return backend.cast(value, dtype=self.dtype)\n        return value\n\n    def __enter__(self):\n        self.original_scope = get_autocast_scope()\n        global_state.set_global_attribute(\"autocast_scope\", self)\n\n    def __exit__(self, *args, **kwargs):\n        global_state.set_global_attribute(\"autocast_scope\", self.original_scope)\n"
  },
  {
    "path": "keras/src/backend/common/variables_test.py",
    "content": "import itertools\nfrom unittest.mock import create_autospec\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom conftest import skip_if_backend\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common.variables import AutocastScope\nfrom keras.src.backend.common.variables import shape_equal\nfrom keras.src.backend.common.variables import standardize_dtype\nfrom keras.src.backend.common.variables import standardize_shape\nfrom keras.src.testing import test_case\nfrom keras.src.testing.test_utils import named_product\n\n\nclass VariableInitializationTest(test_case.TestCase):\n    \"\"\"Tests for Variable.__init__()\"\"\"\n\n    def test_deferred_initialization(self):\n        \"\"\"Tests deferred initialization of variables.\"\"\"\n        with backend.StatelessScope():\n            v = backend.Variable(\n                initializer=initializers.RandomNormal(), shape=(2, 2)\n            )\n            self.assertEqual(v._value, None)\n            # Variables can nevertheless be accessed\n            _ = v + 1\n        self.assertEqual(v._value.shape, (2, 2))\n\n        with self.assertRaisesRegex(ValueError, \"while in a stateless scope\"):\n            with backend.StatelessScope():\n                v = backend.Variable(initializer=0)\n\n    def test_variable_initialization_with_numpy_array(self):\n        \"\"\"Test variable init with numpy array initializer.\"\"\"\n        v = backend.Variable(\n            initializer=np.ones((2, 2), dtype=np.int32), trainable=False\n        )\n        self.assertAllClose(v.value, np.ones((2, 2)))\n        self.assertEqual(v.dtype, \"int32\")\n\n    def test_variable_initialization_with_native_array(self):\n        \"\"\"Test variable init with native array initializer.\"\"\"\n        v = backend.Variable(\n            initializer=ops.ones((2, 2), dtype=\"int32\"), 
trainable=False\n        )\n        self.assertAllClose(v.value, np.ones((2, 2)))\n        self.assertEqual(v.dtype, \"int32\")\n\n    def test_variable_initialization_with_python_array(self):\n        \"\"\"Test variable init with python array initializer.\"\"\"\n        v = backend.Variable(initializer=[[1, 1], [1, 1]], trainable=False)\n        self.assertAllClose(v.value, np.ones((2, 2)))\n        self.assertEqual(v.dtype, \"int32\")\n        v = backend.Variable(\n            initializer=[[1.0, 1.0], [1.0, 1.0]], trainable=False\n        )\n        self.assertAllClose(v.value, np.ones((2, 2)))\n        self.assertEqual(v.dtype, \"float32\")\n\n    def test_variable_initialization_with_lambda_expression(self):\n        # Test Python number\n        v = backend.Variable(\n            initializer=lambda *a, **kw: 1.0,\n            shape=(),\n            dtype=\"float32\",\n        )\n        self.assertAllClose(v.value, 1.0)\n        self.assertEqual(v.dtype, \"float32\")\n\n        # Test Python array\n        v = backend.Variable(\n            initializer=lambda *a, **kw: [1.0],\n            shape=(1,),\n            dtype=\"float32\",\n        )\n        self.assertAllClose(v.value, np.ones((1,)))\n        self.assertEqual(v.dtype, \"float32\")\n\n        # Test numpy array\n        v = backend.Variable(\n            initializer=lambda *a, **kw: np.ones((1,)),\n            shape=(1,),\n            dtype=\"float32\",\n        )\n        self.assertAllClose(v.value, np.ones((1,)))\n        self.assertEqual(v.dtype, \"float32\")\n\n        # Test backend array\n        v = backend.Variable(\n            initializer=lambda *a, **kw: ops.ones((1,)),\n            shape=(1,),\n            dtype=\"float32\",\n        )\n        self.assertAllClose(v.value, np.ones((1,)))\n        self.assertEqual(v.dtype, \"float32\")\n\n    def test_variable_initialization_with_strings(self):\n        \"\"\"Test variable init with non-callable initializer.\"\"\"\n        v = 
backend.Variable(initializer=\"ones\", shape=(2, 2))\n        self.assertAllClose(v.value, np.ones((2, 2)))\n\n    def test_variable_initialization_with_non_trainable(self):\n        \"\"\"Test variable initialization with non-trainable flag.\"\"\"\n        v = backend.Variable(initializer=np.ones((2, 2)), trainable=False)\n        self.assertFalse(v.trainable)\n\n    def test_variable_initialization_without_shape(self):\n        \"\"\"Test variable init without a shape.\"\"\"\n        with self.assertRaisesRegex(\n            ValueError,\n            \"When creating a Variable from an initializer, the `shape` \",\n        ):\n            backend.Variable(initializer=initializers.RandomNormal())\n\n    def test_deferred_initialize_already_initialized(self):\n        \"\"\"Test deferred init on an already initialized variable.\"\"\"\n        v = backend.Variable(initializer=np.ones((2, 2)))\n        with self.assertRaisesRegex(\n            ValueError, f\"Variable {v.path} is already initialized.\"\n        ):\n            v._deferred_initialize()\n\n    def test_variable_initialize(self):\n        \"\"\"Test initializing a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        init_value = np.array([4.0, 5.0, 6.0])\n        v._initialize(value=init_value)\n        self.assertAllClose(v.value, init_value)\n\n    def test_variable_without_shape_from_callable_initializer(self):\n        \"\"\"Test that Variable raises error\n        if shape is not provided for callable initializer.\"\"\"\n        with self.assertRaisesRegex(\n            ValueError, \"When creating a Variable from an initializer\"\n        ):\n            backend.Variable(initializer=lambda: np.ones((2, 2)))\n\n\nclass VariablePropertiesTest(test_case.TestCase):\n    \"\"\"Tests for Variable._deferred_initialize Variable._maybe_autocast\"\"\"\n\n    @skip_if_backend(\n        \"openvino\", \"Can not constant fold eltwise node by CPU plugin\"\n    )\n    def 
test_deferred_assignment(self):\n        \"\"\"Tests deferred assignment to variables.\"\"\"\n        with backend.StatelessScope() as scope:\n            v = backend.Variable(\n                initializer=initializers.RandomNormal(), shape=(2, 2)\n            )\n            self.assertEqual(v._value, None)\n            v.assign(np.zeros((2, 2)))\n            v.assign_add(2 * np.ones((2, 2)))\n            v.assign_sub(np.ones((2, 2)))\n        out = scope.get_current_value(v)\n        self.assertAllClose(out, np.ones((2, 2)))\n\n    def test_trainable_setter(self):\n        \"\"\"Tests the trainable setter.\"\"\"\n        v = backend.Variable(\n            initializer=initializers.RandomNormal(),\n            shape=(2, 2),\n        )\n        self.assertTrue(v.trainable)\n        v.trainable = False\n        self.assertFalse(v.trainable)\n\n        if backend.backend() == \"torch\":\n            v.trainable = True\n            self.assertTrue(v._value.requires_grad)\n            v.trainable = False\n            self.assertFalse(v._value.requires_grad)\n\n    def test_autocasting_float(self):\n        # Tests autocasting of float variables\n        v = backend.Variable(\n            initializer=initializers.RandomNormal(),\n            shape=(2, 2),\n            dtype=\"float32\",\n        )\n        self.assertEqual(v.dtype, \"float32\")\n        self.assertEqual(backend.standardize_dtype(v.value.dtype), \"float32\")\n        with AutocastScope(\"float16\"):\n            self.assertEqual(\n                backend.standardize_dtype(v.value.dtype), \"float16\"\n            )\n        self.assertEqual(backend.standardize_dtype(v.value.dtype), \"float32\")\n\n    def test_autocasting_float_assign(self):\n        # Tests assigning value to variable within an autocast scope\n        v = backend.Variable(\n            initializer=initializers.RandomNormal(),\n            shape=(2, 2),\n            dtype=\"float32\",\n        )\n        self.assertEqual(v.dtype, 
\"float32\")\n        self.assertEqual(backend.standardize_dtype(v.value.dtype), \"float32\")\n\n        # Assign float16 value within float16 scope\n        with AutocastScope(\"float16\"):\n            self.assertEqual(\n                backend.standardize_dtype(v.value.dtype), \"float16\"\n            )\n            v.assign(ops.ones((2, 2), \"float16\"))\n        self.assertEqual(backend.standardize_dtype(v.value.dtype), \"float32\")\n\n        # Assign float32 value within float16 scope\n        with AutocastScope(\"float16\"):\n            self.assertEqual(\n                backend.standardize_dtype(v.value.dtype), \"float16\"\n            )\n            v.assign(ops.zeros((2, 2), \"float32\"))\n        self.assertEqual(backend.standardize_dtype(v.value.dtype), \"float32\")\n\n    def test_autocasting_int(self):\n        # Test non-float variables are not affected\n        v = backend.Variable(\n            initializer=initializers.Ones(),\n            shape=(2, 2),\n            dtype=\"int32\",\n            trainable=False,\n        )\n        self.assertEqual(v.dtype, \"int32\")\n        self.assertEqual(backend.standardize_dtype(v.value.dtype), \"int32\")\n\n        with AutocastScope(\"float16\"):\n            self.assertEqual(backend.standardize_dtype(v.value.dtype), \"int32\")\n\n    def test_autocasting_float_with_autocast_off(self):\n        # Test autocast argument\n        v = backend.Variable(\n            initializer=initializers.RandomNormal(),\n            shape=(2, 2),\n            dtype=\"float32\",\n            autocast=False,\n        )\n        self.assertEqual(v.dtype, \"float32\")\n        self.assertEqual(backend.standardize_dtype(v.value.dtype), \"float32\")\n        with AutocastScope(\"float16\"):\n            self.assertEqual(\n                backend.standardize_dtype(v.value.dtype),\n                \"float32\",  # ignore AutocastScope\n            )\n        self.assertEqual(backend.standardize_dtype(v.value.dtype), 
\"float32\")\n\n    @parameterized.named_parameters(\n        named_product(\n            dtype=(\n                dtype for dtype in dtypes.ALLOWED_DTYPES if dtype != \"string\"\n            )\n        )\n    )\n    def test_standardize_dtype(self, dtype):\n        \"\"\"Tests standardize_dtype for all ALLOWED_DTYPES except string.\"\"\"\n        if backend.backend() == \"torch\" and dtype in (\n            \"uint16\",\n            \"uint32\",\n            \"uint64\",\n            \"complex64\",\n            \"complex128\",\n        ):\n            self.skipTest(f\"torch backend does not support dtype {dtype}\")\n\n        if backend.backend() == \"jax\":\n            if dtype in (\"complex128\",):\n                self.skipTest(f\"jax backend does not support dtype {dtype}\")\n            import jax\n\n            if not jax.config.x64_enabled and \"64\" in dtype:\n                self.skipTest(\n                    f\"jax backend does not support {dtype} without x64 enabled\"\n                )\n\n        if backend.backend() == \"openvino\" and dtype in (\n            \"complex64\",\n            \"complex128\",\n        ):\n            self.skipTest(f\"openvino backend does not support dtype {dtype}\")\n\n        x = backend.convert_to_tensor(np.zeros(()), dtype)\n        actual = standardize_dtype(x.dtype)\n        self.assertEqual(actual, dtype)\n\n    def test_standardize_dtype_with_torch_dtype(self):\n        \"\"\"Tests dtype standardization with PyTorch dtypes.\"\"\"\n        import torch\n\n        x = torch.randn(4, 4)\n        backend.standardize_dtype(x.dtype)\n\n    def test_name_validation(self):\n        \"\"\"Tests validation of variable names.\"\"\"\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `name` must be a string\"\n        ):\n            backend.Variable(\n                initializer=initializers.RandomNormal(), name=12345\n            )\n\n        with self.assertRaisesRegex(ValueError, \"cannot contain 
character `/`\"):\n            backend.Variable(\n                initializer=initializers.RandomNormal(), name=\"invalid/name\"\n            )\n\n    def test_standardize_shape_with_none(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Undefined shapes are not supported.\"\n        ):\n            standardize_shape(None)\n\n    def test_standardize_shape_with_non_iterable(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Cannot convert '42' to a shape.\"\n        ):\n            standardize_shape(42)\n\n    def test_standardize_shape_with_valid_input(self):\n        shape = (3, 4, 5)\n        standardized_shape = standardize_shape(shape)\n        self.assertEqual(standardized_shape, (3, 4, 5))\n\n    def test_standardize_shape_with_valid_input_with_none(self):\n        shape = (3, None, 5)\n        standardized_shape = standardize_shape(shape)\n        self.assertEqual(standardized_shape, (3, None, 5))\n\n    def test_standardize_shape_with_valid_not_tuple_input(self):\n        shape = [3, 4, 5]\n        standardized_shape = standardize_shape(shape)\n        self.assertEqual(standardized_shape, (3, 4, 5))\n\n    def test_standardize_shape_with_numpy(self):\n        shape = [3, np.int32(4), np.int64(5)]\n        standardized_shape = standardize_shape(shape)\n        self.assertEqual(standardized_shape, (3, 4, 5))\n        for d in standardized_shape:\n            self.assertIsInstance(d, int)\n\n    def test_standardize_shape_with_string(self):\n        shape_with_string = (3, 4, \"5\")\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Cannot convert .* to a shape. Found invalid dimension '5'.\",\n        ):\n            standardize_shape(shape_with_string)\n\n    def test_standardize_shape_with_float(self):\n        shape_with_float = (3, 4, 5.0)\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Cannot convert .* to a shape. 
Found invalid dimension '5.0'.\",\n        ):\n            standardize_shape(shape_with_float)\n\n    def test_standardize_shape_with_object(self):\n        shape_with_object = (3, 4, object())\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Cannot convert .* to a shape. Found invalid dimension .*object\",\n        ):\n            standardize_shape(shape_with_object)\n\n    def test_standardize_shape_with_negative_dimension(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Cannot convert .* to a shape. Negative dimensions\",\n        ):\n            standardize_shape((3, 4, -5))\n\n    @parameterized.named_parameters(\n        (\"all_dynamic\", (None, None, None, 64), (None, None, None, 64)),\n        (\"mixed\", (None, 224, 224, 3), (None, 224, 224, 3)),\n        (\"all_static\", (1, 224, 224, 3), (1, 224, 224, 3)),\n    )\n    def test_standardize_shape_preserves_none(self, input_shape, expected):\n        \"\"\"Test that None dimensions are preserved correctly.\"\"\"\n        result = standardize_shape(input_shape)\n        self.assertEqual(result, expected)\n\n    def test_shape_equal_length_mismatch(self):\n        \"\"\"Test mismatch in lengths of shapes.\"\"\"\n        self.assertFalse(shape_equal((3, 2), (3, 2, 4)))\n        self.assertFalse(shape_equal((), (3,)))\n        self.assertFalse(shape_equal((3, 2, 4, 5), (3, 2, 4)))\n\n    def test_autocast_scope_with_non_float_dtype(self):\n        \"\"\"Tests autocast scope with non-float dtype.\"\"\"\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`AutocastScope` can only be used with a floating-point\",\n        ):\n            _ = AutocastScope(\"int32\")\n\n    def test_variable_path_creation(self):\n        \"\"\"Test path creation for a variable.\"\"\"\n        v = backend.Variable(initializer=np.ones((2, 2)), name=\"test_var\")\n        self.assertEqual(v.path, \"test_var\")\n\n        with 
backend.name_scope(\"test_scope\"):\n            v = backend.Variable(initializer=np.ones((2, 2)), name=\"test_var\")\n            self.assertEqual(v.path, \"test_scope/test_var\")\n\n    def test_overwrite_with_gradient_setter(self):\n        v = backend.Variable(\n            initializer=initializers.RandomNormal(),\n            shape=(2, 2),\n        )\n        self.assertFalse(v.overwrite_with_gradient)\n        v.overwrite_with_gradient = True\n        self.assertTrue(v.overwrite_with_gradient)\n\n        with self.assertRaisesRegex(TypeError, \"must be a boolean.\"):\n            v.overwrite_with_gradient = \"true\"\n\n\nclass VariableNumpyValueAndAssignmentTest(test_case.TestCase):\n    \"\"\"tests for Variable.numpy(), Variable.value() and Variable.assign()\"\"\"\n\n    def test_variable_numpy(self):\n        \"\"\"Test retrieving the value of a variable as a numpy array.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        self.assertIsInstance(v.numpy(), np.ndarray)\n        self.assertAllClose(v.numpy(), np.array([1.0, 2.0, 3.0]))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Tests for MirroredVariable under tf backend\",\n    )\n    def test_variable_numpy_scalar(self):\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        strategy = tf.distribute.MirroredStrategy([\"cpu:0\", \"cpu:1\"])\n        with strategy.scope():\n            v = backend.Variable(initializer=0.0)\n\n        np_value = backend.convert_to_numpy(v)\n        self.assertIsInstance(np_value, np.ndarray)\n        self.assertAllClose(np_value, 0.0)\n\n    def test_variable_value(self):\n        \"\"\"Test retrieving the value of a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        self.assertAllClose(v.value, np.array([1.0, 2.0, 3.0]))\n\n    def test_variable_assign(self):\n        \"\"\"Test assigning a new value to a variable.\"\"\"\n        v 
= backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v.assign(np.array([4.0, 5.0, 6.0]))\n        self.assertAllClose(v.value, np.array([4.0, 5.0, 6.0]))\n\n    def test_variable_assign_return(self):\n        \"\"\"Test assigning a new value and returning.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        r = v.assign(np.array([4.0, 5.0, 6.0]))\n        self.assertAllClose(r, np.array([4.0, 5.0, 6.0]))\n\n    def test_variable_assign_add(self):\n        \"\"\"Test the assign_add method on a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v.assign_add(np.array([1.0, 1.0, 1.0]))\n        self.assertAllClose(v.value, np.array([2.0, 3.0, 4.0]))\n\n    def test_variable_assign_add_return(self):\n        \"\"\"Test assign_add a new value and returning.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        r = v.assign_add(np.array([1.0, 1.0, 1.0]))\n        self.assertAllClose(r, np.array([2.0, 3.0, 4.0]))\n\n    def test_variable_assign_sub(self):\n        \"\"\"Test the assign_sub method on a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([2.0, 3.0, 4.0]))\n        v.assign_sub(np.array([1.0, 1.0, 1.0]))\n        self.assertAllClose(v.value, np.array([1.0, 2.0, 3.0]))\n\n    def test_variable_assign_sub_return(self):\n        \"\"\"Test assign_sub a new value and returning.\"\"\"\n        v = backend.Variable(initializer=np.array([2.0, 3.0, 4.0]))\n        r = v.assign_sub(np.array([1.0, 1.0, 1.0]))\n        self.assertAllClose(r, np.array([1.0, 2.0, 3.0]))\n\n    def test_deferred_initialize_within_stateless_scope(self):\n        \"\"\"Test deferred init within a stateless scope.\"\"\"\n        with backend.StatelessScope():\n            v = backend.Variable(\n                initializer=initializers.RandomNormal(), shape=(2, 2)\n            )\n            with self.assertRaisesRegex(\n                ValueError,\n             
   \"You are attempting to initialize a variable \"\n                \"while in a stateless scope. This is disallowed.\",\n            ):\n                v._deferred_initialize()\n\n\nclass VariableDtypeShapeNdimRepr(test_case.TestCase):\n    \"\"\"tests for dtype, shape, ndim, __repr__\"\"\"\n\n    def test_variable_dtype(self):\n        \"\"\"Test retrieving the dtype of a variable.\"\"\"\n        v = backend.Variable(\n            initializer=np.array([1.0, 2.0, 3.0], dtype=np.float32)\n        )\n        self.assertEqual(v.dtype, \"float32\")\n\n    def test_variable_shape(self):\n        \"\"\"Test retrieving the shape of a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([[1.0, 2.0], [3.0, 4.0]]))\n        self.assertEqual(v.shape, (2, 2))\n\n    def test_variable_ndim(self):\n        \"\"\"Test retrieving the number of dimensions of a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([[1.0, 2.0], [3.0, 4.0]]))\n        self.assertEqual(v.ndim, 2)\n\n    def test_variable_repr(self):\n        \"\"\"Test the string representation of a variable.\"\"\"\n        v = backend.Variable(\n            initializer=np.array([1.0, 2.0, 3.0], dtype=np.float32),\n            name=\"test_var\",\n        )\n        expected_repr = (\n            \"<Variable path=test_var, shape=(3,), dtype=float32, \"\n            \"value=[1. 2. 
3.]>\"\n        )\n        self.assertEqual(repr(v), expected_repr)\n\n        # Test with `backend.StatelessScope()`\n        with backend.StatelessScope():\n            v = backend.Variable(\n                initializer=\"zeros\", shape=(3,), name=\"test_var\"\n            )\n            expected_repr = (\n                \"<Variable path=test_var, shape=(3,), dtype=float32>\"\n            )\n            self.assertEqual(repr(v), expected_repr)\n\n    def test_variable_getitem(self):\n        \"\"\"Test getting an item from a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        self.assertEqual(v[0], 1)\n\n    def test_variable_initialize(self):\n        \"\"\"Test initializing a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        init_value = np.array([4.0, 5.0, 6.0])\n        v._initialize(value=init_value)\n        self.assertAllClose(v.value, init_value)\n\n    def test_variable_convert_to_tensor(self):\n        \"\"\"Test converting a variable to a tensor.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        self.assertAllClose(\n            v._convert_to_tensor(v.value), np.array([1.0, 2.0, 3.0])\n        )\n\n    def test_variable_convert_to_tensor_with_dtype(self):\n        \"\"\"Test converting a variable to a tensor with a dtype.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        self.assertAllClose(\n            v._convert_to_tensor(v.value, dtype=\"float32\"),\n            np.array([1.0, 2.0, 3.0]),\n        )\n\n    def test_variable_array(self):\n        \"\"\"Test converting a variable to an array.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        self.assertAllClose(v.__array__(), np.array([1.0, 2.0, 3.0]))\n\n\nclass VariableOpsCorrectnessTest(test_case.TestCase):\n    \"\"\"Tests for operations on Variable.\"\"\"\n\n    def test_int(self):\n        v = 
backend.Variable(initializer=np.array(-1.1))\n        self.assertAllClose(int(v), np.array(-1))\n\n    def test_float(self):\n        v = backend.Variable(initializer=np.array(-1.1))\n        self.assertAllClose(float(v), np.array(-1.1))\n\n    def test__neg__(self):\n        \"\"\"Test negating a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([-1.0, 2.0]), trainable=False)\n        self.assertAllClose(v.__neg__(), np.array([1.0, -2.0]))\n\n    def test__abs__(self):\n        \"\"\"Test absolute value on a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([-1.0, 2.0]), trainable=False)\n        self.assertAllClose(v.__abs__(), np.array([1.0, 2.0]))\n\n    def test__invert__(self):\n        \"\"\"Test bitwise not on a variable.\"\"\"\n        v = backend.Variable(\n            initializer=np.array([True, False]), trainable=False, dtype=\"bool\"\n        )\n        self.assertAllClose(v.__invert__(), np.array([False, True]))\n\n    def test__eq__(self):\n        \"\"\"Test equality comparison on a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0]), trainable=False)\n        self.assertAllClose(\n            v.__eq__(np.array([1.0, 2.0])), np.array([True, True])\n        )\n\n    def test__ne__(self):\n        \"\"\"Test inequality comparison on a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0]), trainable=False)\n        self.assertAllClose(\n            v.__ne__(np.array([1.0, 2.0])), np.array([False, False])\n        )\n\n    def test__lt__(self):\n        \"\"\"Test less than comparison on a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0]), trainable=False)\n        self.assertAllClose(\n            v.__lt__(np.array([1.0, 2.0])), np.array([False, False])\n        )\n\n    def test__le__(self):\n        \"\"\"Test less than or equal to comparison on a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0]), 
trainable=False)\n        self.assertAllClose(\n            v.__le__(np.array([1.0, 2.0])), np.array([True, True])\n        )\n\n    def test__gt__(self):\n        \"\"\"Test greater than comparison on a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0]), trainable=False)\n        self.assertAllClose(\n            v.__gt__(np.array([1.0, 2.0])), np.array([False, False])\n        )\n\n    def test__ge__(self):\n        \"\"\"Test greater than or equal to comparison on a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([1.0, 2.0]), trainable=False)\n        self.assertAllClose(\n            v.__ge__(np.array([1.0, 2.0])), np.array([True, True])\n        )\n\n    def test__add__(self):\n        \"\"\"Test addition operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0]))\n        self.assertAllClose(v1.__add__(v2), np.array([5.0, 7.0, 9.0]))\n\n    def test__radd__(self):\n        \"\"\"Test reverse addition operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0]))\n        self.assertAllClose(v1.__radd__(v2), np.array([5.0, 7.0, 9.0]))\n\n    def test__sub__(self):\n        \"\"\"Test subtraction operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0]))\n        self.assertAllClose(v1.__sub__(v2), np.array([-3.0, -3.0, -3.0]))\n\n    def test__rsub__(self):\n        \"\"\"Test reverse subtraction operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0]))\n        v2 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        self.assertAllClose(v1.__rsub__(v2), np.array([-3.0, -3.0, -3.0]))\n\n    def test__mul__(self):\n        \"\"\"Test 
multiplication operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0]))\n        self.assertAllClose(v1.__mul__(v2), np.array([4.0, 10.0, 18.0]))\n\n    def test__rmul__(self):\n        \"\"\"Test reverse multiplication operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0]))\n        self.assertAllClose(v1.__rmul__(v2), np.array([4.0, 10.0, 18.0]))\n\n    def test__truediv__(self):\n        \"\"\"Test true division operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0]))\n        self.assertAllClose(v1.__truediv__(v2), np.array([0.25, 0.4, 0.5]))\n\n    def test__rtruediv__(self):\n        \"\"\"Test reverse true division operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0]))\n        v2 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        self.assertAllClose(v1.__rtruediv__(v2), np.array([0.25, 0.4, 0.5]))\n\n    @skip_if_backend(\n        \"openvino\", \"`floor_divide` is not supported with openvino backend\"\n    )\n    def test__floordiv__(self):\n        \"\"\"Test floordiv operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([-4.0, 5.0, 6.0]))\n        self.assertAllClose(v1.__floordiv__(v2), np.array([-1.0, 0.0, 0.0]))\n\n    @skip_if_backend(\n        \"openvino\", \"`floor_divide` is not supported with openvino backend\"\n    )\n    def test__rfloordiv__(self):\n        \"\"\"Test reverse floordiv operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([-4.0, 5.0, 6.0]))\n        v2 = backend.Variable(initializer=np.array([1.0, 2.0, 
3.0]))\n        self.assertAllClose(v1.__rfloordiv__(v2), np.array([-1.0, 0.0, 0.0]))\n\n    def test__mod__(self):\n        \"\"\"Test mod operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([-4.0, 5.0, 6.0]))\n        self.assertAllClose(v1.__mod__(v2), np.array([-3.0, 2.0, 3.0]))\n\n    def test__rmod__(self):\n        \"\"\"Test reverse mod operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        self.assertAllClose(v1.__rmod__(v2), np.array([0.0, 0.0, 0.0]))\n\n    def test__pow__(self):\n        \"\"\"Test pow operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([-4.0, 5.0, 6.0]))\n        self.assertAllClose(v1.__pow__(v2), np.array([1.0, 32.0, 729.0]))\n\n    def test__rpow__(self):\n        \"\"\"Test reverse power operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        self.assertAllClose(v1.__rpow__(v2), np.array([1.0, 4.0, 27.0]))\n\n    def test__matmul__(self):\n        \"\"\"Test matmul operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([[1.0, 2.0], [3.0, 4.0]]))\n        v2 = backend.Variable(initializer=np.array([[5.0, 6.0], [7.0, 8.0]]))\n        self.assertAllClose(\n            v1.__matmul__(v2), np.array([[19.0, 22.0], [43.0, 50.0]])\n        )\n\n    def test__rmatmul__(self):\n        \"\"\"Test reverse matmul operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([[1.0, 2.0], [3.0, 4.0]]))\n        v2 = backend.Variable(initializer=np.array([[5.0, 6.0], [7.0, 8.0]]))\n        self.assertAllClose(\n            v1.__rmatmul__(v2), 
np.array([[23.0, 34.0], [31.0, 46.0]])\n        )\n\n    def test__and__(self):\n        \"\"\"Test bitwise and operation on a variable.\"\"\"\n        v1 = backend.Variable(\n            initializer=np.array([True, False]), dtype=\"bool\", trainable=False\n        )\n        v2 = backend.Variable(\n            initializer=np.array([True, True]), dtype=\"bool\", trainable=False\n        )\n        self.assertAllClose(v1.__and__(v2), np.array([True, False]))\n\n    def test__rand__(self):\n        \"\"\"Test reverse bitwise and operation on a variable.\"\"\"\n        v1 = backend.Variable(\n            initializer=np.array([True, False]), dtype=\"bool\", trainable=False\n        )\n        v2 = backend.Variable(\n            initializer=np.array([True, True]), dtype=\"bool\", trainable=False\n        )\n        self.assertAllClose(v1.__rand__(v2), np.array([True, False]))\n\n    def test__or__(self):\n        \"\"\"Test bitwise or operation on a variable.\"\"\"\n        v1 = backend.Variable(\n            initializer=np.array([True, False]), dtype=\"bool\", trainable=False\n        )\n        v2 = backend.Variable(\n            initializer=np.array([True, True]), dtype=\"bool\", trainable=False\n        )\n        self.assertAllClose(v1.__or__(v2), np.array([True, True]))\n\n    def test__ror__(self):\n        \"\"\"Test reverse bitwise or operation on a variable.\"\"\"\n        v1 = backend.Variable(\n            initializer=np.array([True, False]), dtype=\"bool\", trainable=False\n        )\n        v2 = backend.Variable(\n            initializer=np.array([True, True]), dtype=\"bool\", trainable=False\n        )\n        self.assertAllClose(v1.__ror__(v2), np.array([True, True]))\n\n    def test__xor__(self):\n        \"\"\"Test bitwise xor operation on a variable.\"\"\"\n        v1 = backend.Variable(\n            initializer=np.array([True, False]), dtype=\"bool\", trainable=False\n        )\n        v2 = backend.Variable(\n            
initializer=np.array([True, True]), dtype=\"bool\", trainable=False\n        )\n        self.assertAllClose(v1.__xor__(v2), np.array([False, True]))\n\n    def test__rxor__(self):\n        \"\"\"Test reverse bitwise xor operation on a variable.\"\"\"\n        v1 = backend.Variable(\n            initializer=np.array([True, False]), dtype=\"bool\", trainable=False\n        )\n        v2 = backend.Variable(\n            initializer=np.array([True, True]), dtype=\"bool\", trainable=False\n        )\n        self.assertAllClose(v1.__rxor__(v2), np.array([False, True]))\n\n    def test__pos__(self):\n        \"\"\"Test unary plus on a variable.\"\"\"\n        v = backend.Variable(initializer=np.array([-1.0, 2.0]), trainable=False)\n        self.assertAllClose(v.__pos__(), np.array([-1.0, 2.0]))\n\n    def test_variable_pow(self):\n        \"\"\"Test pow operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0]))\n        result = v1**v2\n        self.assertAllClose(result, np.array([1.0, 32.0, 729.0]))\n\n    def test_variable_rpow(self):\n        \"\"\"Test reverse power operation on a variable.\"\"\"\n        v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))\n        v2 = backend.Variable(initializer=np.array([4.0, 5.0, 6.0]))\n        result = v2**v1\n        self.assertAllClose(result, np.array([4.0, 25.0, 216.0]))\n\n    def test_round(self):\n        v = backend.Variable(initializer=np.array([1.1, 2.2, 3.3]))\n        self.assertAllClose(round(v), np.array([1.0, 2.0, 3.0]))\n\n\nclass VariableOpsBehaviorTest(test_case.TestCase):\n    def test_invalid_bool(self):\n        \"\"\"Test converting a variable to boolean.\"\"\"\n        v = backend.Variable(initializer=np.ones((2, 2)))\n        with self.assertRaisesRegex(\n            TypeError, \"A Keras Variable cannot be used as a boolean.\"\n        ):\n            bool(v)\n\n    def 
test_invalid_int(self):\n        v = backend.Variable(initializer=np.ones((2, 2)))\n        with self.assertRaisesRegex(\n            TypeError, \"Only scalar arrays can be converted to Python scalars.\"\n        ):\n            int(v)\n\n    def test_invalid_float(self):\n        v = backend.Variable(initializer=np.ones((2, 2)))\n        with self.assertRaisesRegex(\n            TypeError, \"Only scalar arrays can be converted to Python scalars.\"\n        ):\n            float(v)\n\n\nclass VariableOpsDTypeTest(test_case.TestCase):\n    \"\"\"Test the dtype to verify that the behavior matches JAX.\"\"\"\n\n    ALL_DTYPES = [\n        x\n        for x in dtypes.ALLOWED_DTYPES\n        if x\n        not in (\n            \"string\",\n            \"complex128\",\n            # Remove 64-bit dtypes.\n            \"float64\",\n            \"uint64\",\n            \"int64\",\n        )\n        + dtypes.FLOAT8_TYPES  # Remove float8 dtypes for the following tests\n    ] + [None]\n    INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in (\"uint64\", \"int64\")]\n    FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in (\"float64\",)]\n    COMPLEX_DTYPES = [\"complex32\", \"complex64\"]\n    if backend.backend() == \"torch\":\n        ALL_DTYPES = [\n            x for x in ALL_DTYPES if x not in (\"uint16\", \"uint32\", \"complex64\")\n        ]\n        INT_DTYPES = [x for x in INT_DTYPES if x not in (\"uint16\", \"uint32\")]\n    elif backend.backend() == \"tensorflow\":\n        # TODO(hongyu): Re-enable uint32 tests once we determine how to handle\n        # dtypes.result_type(uint32, int*) -> int64 promotion.\n        # Since TF variables require int64 to be placed on the GPU, we\n        # exclusively enable the int64 dtype for TF. 
However, JAX does not\n        # natively support int64, which prevents us from comparing the dtypes.\n        ALL_DTYPES = [x for x in ALL_DTYPES if x not in (\"uint32\",)]\n        INT_DTYPES = [x for x in INT_DTYPES if x not in (\"uint32\",)]\n    elif backend.backend() == \"openvino\":\n        ALL_DTYPES = [x for x in ALL_DTYPES if x not in (\"complex64\",)]\n    NON_COMPLEX_DTYPES = [\n        x for x in ALL_DTYPES if x and x not in [\"complex32\", \"complex64\"]\n    ]\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_eq(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.equal(x1_jax, x2_jax).dtype)\n\n        self.assertDType(x1 == x2, expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_ne(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.not_equal(x1_jax, x2_jax).dtype)\n\n        self.assertDType(x1 != x2, expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2))\n    )\n    def test_lt(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, 
trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.less(x1_jax, x2_jax).dtype)\n\n        self.assertDType(x1 < x2, expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2))\n    )\n    def test_le(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.less_equal(x1_jax, x2_jax).dtype)\n\n        self.assertDType(x1 <= x2, expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2))\n    )\n    def test_gt(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.greater(x1_jax, x2_jax).dtype)\n\n        self.assertDType(x1 > x2, expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2))\n    )\n    def test_ge(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), 
dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.greater_equal(x1_jax, x2_jax).dtype\n        )\n\n        self.assertDType(x1 >= x2, expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_add(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.add(x1_jax, x2_jax).dtype)\n\n        self.assertDType(x1 + x2, expected_dtype)\n        self.assertDType(x1.__radd__(x2), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_sub(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.subtract(x1_jax, x2_jax).dtype)\n\n        self.assertDType(x1 - x2, expected_dtype)\n        self.assertDType(x1.__rsub__(x2), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_mul(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        
x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.multiply(x1_jax, x2_jax).dtype)\n\n        self.assertDType(x1 * x2, expected_dtype)\n        self.assertDType(x1.__rmul__(x2), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2))\n    )\n    def test_truediv(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.true_divide(x1_jax, x2_jax).dtype\n        )\n\n        self.assertDType(x1 / x2, expected_dtype)\n        self.assertDType(x1.__rtruediv__(x2), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2))\n    )\n    @skip_if_backend(\n        \"openvino\", \"`floor_divide` is not supported with openvino backend\"\n    )\n    def test_floordiv(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.floor_divide(x1_jax, x2_jax).dtype\n        )\n\n        self.assertDType(x1 // x2, expected_dtype)\n        self.assertDType(x1.__rfloordiv__(x2), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2))\n    )\n    def test_mod(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n      
  x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.mod(x1_jax, x2_jax).dtype)\n\n        self.assertDType(x1 % x2, expected_dtype)\n        self.assertDType(x1.__rmod__(x2), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_pow(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.power(x1_jax, x2_jax).dtype)\n\n        self.assertDType(x1**x2, expected_dtype)\n        self.assertDType(x1.__rpow__(x2), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_matmul(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.matmul(x1_jax, x2_jax).dtype)\n\n        self.assertDType(x1 @ x2, expected_dtype)\n        self.assertDType(x1.__rmatmul__(x2), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_and(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = 
dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.logical_and(x1_jax, x2_jax).dtype\n        )\n\n        self.assertDType(x1 & x2, expected_dtype)\n        self.assertDType(x1.__rand__(x2), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_or(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.logical_or(x1_jax, x2_jax).dtype)\n\n        self.assertDType(x1 | x2, expected_dtype)\n        self.assertDType(x1.__ror__(x2), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_xor(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = backend.Variable(\"ones\", shape=(1,), dtype=dtype1, trainable=False)\n        x2 = backend.Variable(\"ones\", shape=(1,), dtype=dtype2, trainable=False)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.logical_xor(x1_jax, x2_jax).dtype\n        )\n\n        self.assertDType(x1 ^ x2, expected_dtype)\n        self.assertDType(x1.__rxor__(x2), expected_dtype)\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"torch\",\n    reason=\"Tests for standardize_shape with Torch backend\",\n)\nclass 
TestStandardizeShapeWithTorch(test_case.TestCase):\n    def test_standardize_shape_with_torch_size(self):\n        import torch\n\n        tensor = torch.randn(3, 4, 5)\n        shape = tensor.size()\n        standardized_shape = standardize_shape(shape)\n        self.assertEqual(standardized_shape, (3, 4, 5))\n        self.assertIs(type(standardized_shape), tuple)\n        for d in standardized_shape:\n            self.assertIsInstance(d, int)\n\n    def test_standardize_shape_with_torch_symint(self):\n        \"\"\"Test that torch.SymInt dimensions are converted to None.\n\n        This validates the fix for GitHub issue #22102 where torch.SymInt\n        objects from torch.export were causing \"Constraints violated\" errors.\n        \"\"\"\n        import torch\n\n        # Create a mock SymInt object\n        sym_int = create_autospec(torch.SymInt, instance=True)\n        shape_with_sym_int = (sym_int, 224, 224, 64)\n\n        # SymInt should be converted to None\n        result = standardize_shape(shape_with_sym_int)\n        self.assertEqual(result, (None, 224, 224, 64))\n\n        # Test with multiple SymInts\n        sym_int2 = create_autospec(torch.SymInt, instance=True)\n        shape_with_multiple_sym_ints = (sym_int, sym_int2, 64)\n        result = standardize_shape(shape_with_multiple_sym_ints)\n        self.assertEqual(result, (None, None, 64))\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"tensorflow\",\n    reason=\"Tests for standardize_shape with TensorFlow backend\",\n)\nclass TestStandardizeShapeWithTensorflow(test_case.TestCase):\n    def test_standardize_shape_with_tensor_size(self):\n        import tensorflow as tf\n\n        shape = (3, tf.constant(4, dtype=tf.int64), 5)\n        standardized_shape = standardize_shape(shape)\n        self.assertEqual(standardized_shape, (3, 4, 5))\n        self.assertIs(type(standardized_shape), tuple)\n        for d in standardized_shape:\n            self.assertIsInstance(d, int)\n"
  },
  {
    "path": "keras/src/backend/config.py",
    "content": "import json\nimport os\n\nfrom keras.src.api_export import keras_export\n\n# The type of float to use throughout a session.\n_FLOATX = \"float32\"\n\n# Epsilon fuzz factor used throughout the codebase.\n_EPSILON = 1e-7\n\n# Default image data format, one of \"channels_last\", \"channels_first\".\n_IMAGE_DATA_FORMAT = \"channels_last\"\n\n# Default backend: TensorFlow.\n_BACKEND = \"tensorflow\"\n\n# Whether NNX is enabled.\n_NNX_ENABLED = False\n\n# Cap run duration for debugging.\n_MAX_EPOCHS = None\n_MAX_STEPS_PER_EPOCH = None\n\n\n@keras_export([\"keras.config.floatx\", \"keras.backend.floatx\"])\ndef floatx():\n    \"\"\"Return the default float type, as a string.\n\n    E.g. `'bfloat16'`, `'float16'`, `'float32'`, `'float64'`.\n\n    Returns:\n        String, the current default float type.\n\n    Example:\n\n    >>> keras.config.floatx()\n    'float32'\n\n    \"\"\"\n    return _FLOATX\n\n\n@keras_export([\"keras.config.set_floatx\", \"keras.backend.set_floatx\"])\ndef set_floatx(value):\n    \"\"\"Set the default float dtype.\n\n    Note: It is not recommended to set this to `\"float16\"` for training,\n    as this will likely cause numeric stability issues.\n    Instead, mixed precision, which leverages\n    a mix of `float16` and `float32`. It can be configured by calling\n    `keras.mixed_precision.set_dtype_policy('mixed_float16')`.\n\n    Args:\n        value: String; `'bfloat16'`, `'float16'`, `'float32'`, or `'float64'`.\n\n    Examples:\n    >>> keras.config.floatx()\n    'float32'\n\n    >>> keras.config.set_floatx('float64')\n    >>> keras.config.floatx()\n    'float64'\n\n    >>> # Set it back to float32\n    >>> keras.config.set_floatx('float32')\n\n    Raises:\n        ValueError: In case of invalid value.\n    \"\"\"\n    global _FLOATX\n    accepted_dtypes = {\"bfloat16\", \"float16\", \"float32\", \"float64\"}\n    if value not in accepted_dtypes:\n        raise ValueError(\n            f\"Unknown `floatx` value: {value}. 
\"\n            f\"Expected one of {accepted_dtypes}\"\n        )\n    _FLOATX = str(value)\n\n\n@keras_export([\"keras.config.epsilon\", \"keras.backend.epsilon\"])\ndef epsilon():\n    \"\"\"Return the value of the fuzz factor used in numeric expressions.\n\n    Returns:\n        A float.\n\n    Example:\n\n    >>> keras.config.epsilon()\n    1e-07\n\n    \"\"\"\n    return _EPSILON\n\n\n@keras_export([\"keras.config.set_epsilon\", \"keras.backend.set_epsilon\"])\ndef set_epsilon(value):\n    \"\"\"Set the value of the fuzz factor used in numeric expressions.\n\n    Args:\n        value: float. New value of epsilon.\n\n    Examples:\n    >>> keras.config.epsilon()\n    1e-07\n\n    >>> keras.config.set_epsilon(1e-5)\n    >>> keras.config.epsilon()\n    1e-05\n\n    >>> # Set it back to the default value.\n    >>> keras.config.set_epsilon(1e-7)\n\n    \"\"\"\n    global _EPSILON\n    _EPSILON = value\n\n\n@keras_export(\n    [\n        \"keras.config.image_data_format\",\n        \"keras.backend.image_data_format\",\n    ]\n)\ndef image_data_format():\n    \"\"\"Return the default image data format convention.\n\n    Returns:\n        A string, either `'channels_first'` or `'channels_last'`.\n\n    Example:\n\n    >>> keras.config.image_data_format()\n    'channels_last'\n\n    \"\"\"\n    return _IMAGE_DATA_FORMAT\n\n\n@keras_export(\n    [\n        \"keras.config.set_image_data_format\",\n        \"keras.backend.set_image_data_format\",\n    ]\n)\ndef set_image_data_format(data_format):\n    \"\"\"Set the value of the image data format convention.\n\n    Args:\n        data_format: string. 
`'channels_first'` or `'channels_last'`.\n\n    Examples:\n\n    >>> keras.config.image_data_format()\n    'channels_last'\n\n    >>> keras.config.set_image_data_format('channels_first')\n    >>> keras.config.image_data_format()\n    'channels_first'\n\n    >>> # Set it back to `'channels_last'`\n    >>> keras.config.set_image_data_format('channels_last')\n\n    \"\"\"\n    global _IMAGE_DATA_FORMAT\n    data_format = str(data_format).lower()\n    if data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(\n            \"The `data_format` argument must be one of \"\n            \"{'channels_first', 'channels_last'}. \"\n            f\"Received: data_format={data_format}\"\n        )\n    _IMAGE_DATA_FORMAT = data_format\n\n\n@keras_export(\"keras.config.enable_flash_attention\")\ndef enable_flash_attention():\n    \"\"\"Enable flash attention.\n\n    Flash attention offers performance optimization for attention layers,\n    making it especially useful for large language models (LLMs) that\n    benefit from faster and more memory-efficient attention computations.\n\n    Once enabled, supported layers like `MultiHeadAttention` will **attempt** to\n    use flash attention for faster computations. By default, this feature is\n    enabled.\n\n    Note that enabling flash attention does not guarantee it will always be\n    used. 
Typically, the inputs must be in `float16` or `bfloat16` dtype, and\n    input layout requirements may vary depending on the backend.\n    \"\"\"\n    from keras.src.backend.common import global_state\n\n    global_state.set_global_attribute(\"flash_attention\", None)\n\n\n@keras_export(\"keras.config.disable_flash_attention\")\ndef disable_flash_attention():\n    \"\"\"Disable flash attention.\n\n    Flash attention offers performance optimization for attention layers,\n    making it especially useful for large language models (LLMs) that\n    benefit from faster and more memory-efficient attention computations.\n\n    Once disabled, supported layers like `MultiHeadAttention` will not\n    use flash attention for faster computations.\n    \"\"\"\n    from keras.src.backend.common import global_state\n\n    global_state.set_global_attribute(\"flash_attention\", False)\n\n\n@keras_export(\"keras.config.is_flash_attention_enabled\")\ndef is_flash_attention_enabled():\n    \"\"\"Checks whether flash attention is globally enabled in Keras.\n\n    Flash attention is a performance-optimized method for computing attention\n    in large models, such as transformers, allowing for faster and more\n    memory-efficient operations. This function checks the global Keras\n    configuration to determine if flash attention is enabled for compatible\n    layers (e.g., `MultiHeadAttention`).\n\n    Note that enabling flash attention does not guarantee it will always be\n    used. 
Typically, the inputs must be in `float16` or `bfloat16` dtype, and\n    input layout requirements may vary depending on the backend.\n\n    Returns:\n        `False` if flash attention is disabled; otherwise `None`, indicating\n        that it is enabled.\n    \"\"\"\n    from keras.src.backend.common import global_state\n\n    return global_state.get_global_attribute(\"flash_attention\", default=None)\n\n\n@keras_export(\"keras.config.is_nnx_enabled\")\ndef is_nnx_enabled():\n    \"\"\"Checks whether NNX specific features are enabled for the JAX backend.\n\n    Returns:\n        bool: `True` if NNX backend features are enabled, `False` otherwise.\n        Defaults to `False`.\n    \"\"\"\n    return _NNX_ENABLED\n\n\ndef set_nnx_enabled(value):\n    global _NNX_ENABLED\n    from keras.src.backend.common import global_state\n\n    _NNX_ENABLED = bool(value)\n    if _NNX_ENABLED:\n        try:\n            from flax import nnx  # noqa: F401\n        except ImportError:\n            raise ImportError(\n                \"To use NNX with the JAX backend, you must install `flax`.\"\n            )\n    global_state.set_global_attribute(\"nnx_enabled\", bool(value))\n\n\ndef standardize_data_format(data_format):\n    if data_format is None:\n        return image_data_format()\n    data_format = str(data_format).lower()\n    if data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(\n            \"The `data_format` argument must be one of \"\n            \"{'channels_first', 'channels_last'}. 
\"\n            f\"Received: data_format={data_format}\"\n        )\n    return data_format\n\n\n# Set Keras base dir path given KERAS_HOME env variable, if applicable.\n# Otherwise either ~/.keras or /tmp.\nif \"KERAS_HOME\" in os.environ:\n    _KERAS_DIR = os.environ.get(\"KERAS_HOME\")\nelse:\n    _keras_base_dir = os.path.expanduser(\"~\")\n    if not os.access(_keras_base_dir, os.W_OK):\n        _keras_base_dir = \"/tmp\"\n    _KERAS_DIR = os.path.join(_keras_base_dir, \".keras\")\n\n\ndef keras_home():\n    # Private accessor for the keras home location.\n    return _KERAS_DIR\n\n\n# Attempt to read Keras config file.\n_config_path = os.path.expanduser(os.path.join(_KERAS_DIR, \"keras.json\"))\nif os.path.exists(_config_path):\n    try:\n        with open(_config_path) as f:\n            _config = json.load(f)\n    except ValueError:\n        _config = {}\n    _floatx = _config.get(\"floatx\", floatx())\n    if _floatx not in {\"float16\", \"float32\", \"float64\"}:\n        raise ValueError(\n            \"Invalid `floatx` configuration. \"\n            \"Expected one of {'float16', 'float32', 'float64'}. \"\n            f\"Received: floatx={_floatx}\"\n        )\n    _epsilon = _config.get(\"epsilon\", epsilon())\n    if not isinstance(_epsilon, float):\n        raise ValueError(\n            \"Invalid `epsilon` configuration. \"\n            \"Expected a float. \"\n            f\"Received: epsilon={_epsilon}\"\n        )\n    _backend = _config.get(\"backend\", _BACKEND)\n    _image_data_format = _config.get(\"image_data_format\", image_data_format())\n    if _image_data_format not in {\"channels_last\", \"channels_first\"}:\n        raise ValueError(\n            \"Invalid `image_data_format` configuration. \"\n            \"Expected one of {'channels_last', 'channels_first'}. 
\"\n            f\"Received: image_data_format={_image_data_format}\"\n        )\n    _nnx_enabled_config = _config.get(\"nnx_enabled\", _NNX_ENABLED)\n\n    # Apply basic configs that don't cause circular import\n    set_floatx(_floatx)\n    _NNX_ENABLED = _nnx_enabled_config\n    set_epsilon(_epsilon)\n    set_image_data_format(_image_data_format)\n    _BACKEND = _backend\n\n# Save config file, if possible.\nif not os.path.exists(_KERAS_DIR):\n    try:\n        os.makedirs(_KERAS_DIR)\n    except OSError:\n        # Except permission denied and potential race conditions\n        # in multi-threaded environments.\n        pass\n\nif not os.path.exists(_config_path):\n    _config = {\n        \"floatx\": floatx(),\n        \"epsilon\": epsilon(),\n        \"backend\": _BACKEND,\n        \"image_data_format\": image_data_format(),\n    }\n    try:\n        with open(_config_path, \"w\") as f:\n            f.write(json.dumps(_config, indent=4))\n    except IOError:\n        # Except permission denied.\n        pass\n\n# Set backend based on KERAS_BACKEND flag, if applicable.\nif \"KERAS_BACKEND\" in os.environ:\n    _backend = os.environ[\"KERAS_BACKEND\"]\n    if _backend:\n        _BACKEND = _backend\nif \"KERAS_MAX_EPOCHS\" in os.environ:\n    _MAX_EPOCHS = int(os.environ[\"KERAS_MAX_EPOCHS\"])\nif \"KERAS_MAX_STEPS_PER_EPOCH\" in os.environ:\n    _MAX_STEPS_PER_EPOCH = int(os.environ[\"KERAS_MAX_STEPS_PER_EPOCH\"])\n\n\nif _BACKEND != \"tensorflow\":\n    # If we are not running on the tensorflow backend, we should stop tensorflow\n    # from using all available GPU memory. 
See\n    # https://www.tensorflow.org/guide/gpu#limiting_gpu_memory_growth\n    os.environ[\"TF_FORCE_GPU_ALLOW_GROWTH\"] = \"true\"\n\n\n@keras_export(\n    [\n        \"keras.config.backend\",\n        \"keras.backend.backend\",\n    ]\n)\ndef backend():\n    \"\"\"Publicly accessible method for determining the current backend.\n\n    Returns:\n        String, the name of the backend Keras is currently using. One of\n            `\"tensorflow\"`, `\"torch\"`, or `\"jax\"`.\n\n    Example:\n\n    >>> keras.config.backend()\n    'tensorflow'\n\n    \"\"\"\n    return _BACKEND\n\n\n@keras_export([\"keras.config.set_max_epochs\"])\ndef set_max_epochs(max_epochs):\n    \"\"\"Limit the maximum number of epochs for any call to fit.\n\n    This will cap the number of epochs for any training run using `model.fit()`.\n    This is purely for debugging, and can also be set via the `KERAS_MAX_EPOCHS`\n    environment variable to quickly run a script without modifying its source.\n\n    Args:\n        max_epochs: The integer limit on the number of epochs or `None`. If\n            `None`, no limit is applied.\n    \"\"\"\n    global _MAX_EPOCHS\n    _MAX_EPOCHS = max_epochs\n\n\n@keras_export([\"keras.config.set_max_steps_per_epoch\"])\ndef set_max_steps_per_epoch(max_steps_per_epoch):\n    \"\"\"Limit the maximum number of steps for any call to fit/evaluate/predict.\n\n    This will cap the number of steps for a single epoch of a call to `fit()`,\n    `evaluate()`, or `predict()`. This is purely for debugging, and can also be\n    set via the `KERAS_MAX_STEPS_PER_EPOCH` environment variable to quickly run\n    a script without modifying its source.\n\n    Args:\n        max_steps_per_epoch: The integer limit on the number of steps or `None`. 
If\n            `None`, no limit is applied.\n    \"\"\"\n    global _MAX_STEPS_PER_EPOCH\n    _MAX_STEPS_PER_EPOCH = max_steps_per_epoch\n\n\n@keras_export([\"keras.config.max_epochs\"])\ndef max_epochs():\n    \"\"\"Get the maximum number of epochs for any call to fit.\n\n    Retrieves the limit on the number of epochs set by\n    `keras.config.set_max_epochs` or the `KERAS_MAX_EPOCHS` environment\n    variable.\n\n    Returns:\n        The integer limit on the number of epochs or `None`, if no limit has\n        been set.\n    \"\"\"\n    return _MAX_EPOCHS\n\n\n@keras_export([\"keras.config.max_steps_per_epoch\"])\ndef max_steps_per_epoch():\n    \"\"\"Get the maximum number of steps for any call to fit/evaluate/predict.\n\n    Retrieves the limit on the number of steps set by\n    `keras.config.set_max_steps_per_epoch` or the `KERAS_MAX_STEPS_PER_EPOCH`\n    environment variable.\n\n    Returns:\n        The integer limit on the number of steps per epoch or `None`, if no\n        limit has been set.\n    \"\"\"\n    return _MAX_STEPS_PER_EPOCH\n\n\nif \"KERAS_NNX_ENABLED\" in os.environ:\n    env_val = os.environ[\"KERAS_NNX_ENABLED\"].lower()\n    if env_val == \"true\" or env_val == \"1\":\n        _NNX_ENABLED = True\n    else:\n        _NNX_ENABLED = False\n\nset_nnx_enabled(_NNX_ENABLED)\n"
  },
  {
    "path": "keras/src/backend/jax/__init__.py",
    "content": "from keras.src.backend.config import is_nnx_enabled\nfrom keras.src.backend.jax import core\nfrom keras.src.backend.jax import distribution_lib\nfrom keras.src.backend.jax import image\nfrom keras.src.backend.jax import linalg\nfrom keras.src.backend.jax import math\nfrom keras.src.backend.jax import nn\nfrom keras.src.backend.jax import numpy\nfrom keras.src.backend.jax import random\nfrom keras.src.backend.jax import tensorboard\nfrom keras.src.backend.jax.core import IS_THREAD_SAFE\nfrom keras.src.backend.jax.core import SUPPORTS_RAGGED_TENSORS\nfrom keras.src.backend.jax.core import SUPPORTS_SPARSE_TENSORS\nfrom keras.src.backend.jax.core import Variable\nfrom keras.src.backend.jax.core import cast\nfrom keras.src.backend.jax.core import compute_output_spec\nfrom keras.src.backend.jax.core import cond\nfrom keras.src.backend.jax.core import convert_to_numpy\nfrom keras.src.backend.jax.core import convert_to_tensor\nfrom keras.src.backend.jax.core import device_scope\nfrom keras.src.backend.jax.core import is_tensor\nfrom keras.src.backend.jax.core import name_scope\nfrom keras.src.backend.jax.core import random_seed_dtype\nfrom keras.src.backend.jax.core import scatter\nfrom keras.src.backend.jax.core import shape\nfrom keras.src.backend.jax.core import stop_gradient\nfrom keras.src.backend.jax.core import vectorized_map\nfrom keras.src.backend.jax.rnn import cudnn_ok\nfrom keras.src.backend.jax.rnn import gru\nfrom keras.src.backend.jax.rnn import lstm\nfrom keras.src.backend.jax.rnn import rnn\n"
  },
  {
    "path": "keras/src/backend/jax/core.py",
    "content": "import math\n\nimport jax\nimport jax.experimental.sparse as jax_sparse\nimport jax.numpy as jnp\nimport ml_dtypes\nimport numpy as np\nfrom jax import export as jax_export\n\nfrom keras.src import tree\nfrom keras.src.backend import config\nfrom keras.src.backend.common import KerasVariable\nfrom keras.src.backend.common import global_state\nfrom keras.src.backend.common import standardize_dtype\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.backend.common.name_scope import name_scope as base_name_scope\nfrom keras.src.backend.common.stateless_scope import StatelessScope\nfrom keras.src.backend.common.stateless_scope import get_stateless_scope\nfrom keras.src.backend.common.stateless_scope import in_stateless_scope\nfrom keras.src.backend.common.symbolic_scope import SymbolicScope\nfrom keras.src.backend.jax import distribution_lib\n\nSUPPORTS_SPARSE_TENSORS = True\nSUPPORTS_RAGGED_TENSORS = False\nIS_THREAD_SAFE = True\n\n\nclass JaxVariable(KerasVariable):\n    def __init__(self, *args, layout=None, **kwargs):\n        # Intercept layout parameter so that it is available\n        # during initialization.\n        self._layout = layout\n        super().__init__(*args, **kwargs)\n\n    def _initialize_layout(self):\n        # We can't import the keras/distribution/distribution_lib\n        # due to circular dependency.\n        distribution = global_state.get_global_attribute(\"distribution\")\n        if self._layout is None and distribution is not None:\n            tensor_layout = distribution.get_variable_layout(self)\n            from keras.src.distribution import TensorLayout\n\n            if isinstance(tensor_layout, TensorLayout):\n                self._layout = tensor_layout.backend_layout\n            else:\n                self._layout = tensor_layout\n\n    def _initialize(self, value):\n        # Note that variable.shape is needed by distribution_lib\n        self._shape = 
self._validate_shape(value.shape)\n        self._initialize_layout()\n        self._direct_assign(value)\n\n    def _initialize_with_initializer(self, initializer):\n        self._initialize_layout()\n        layout = self._layout\n        shape = self._shape\n        if should_shard_at_init(layout, shape):\n            jitted_initializer = jax.jit(\n                initializer.__call__,\n                out_shardings=layout,\n                static_argnames=[\"shape\", \"dtype\"],\n            )\n            value = jitted_initializer(shape=self._shape, dtype=self._dtype)\n            self._value = value\n        else:\n            super()._initialize_with_initializer(initializer)\n\n    def _direct_assign(self, value):\n        if self._layout is not None:\n            value = distribution_lib.distribute_tensor(value, self._layout)\n        self._value = value\n\n    def _convert_to_tensor(self, value, dtype=None):\n        return convert_to_tensor(value, dtype=dtype, sparse=False)\n\n    # Overload native accessor.\n    def __jax_array__(self):\n        return self.value\n\n\nVariable = JaxVariable\nif config.is_nnx_enabled():\n    from flax import nnx\n\n    class NnxVariable(JaxVariable, nnx.Variable):\n        def __init__(\n            self,\n            initializer,\n            shape=None,\n            dtype=None,\n            trainable=True,\n            autocast=True,\n            aggregation=\"none\",\n            synchronization=\"auto\",\n            name=None,\n            layout=None,\n            mutable=None,\n            **nnx_metadata,\n        ):\n            # Ensure 'mutable' is in nnx_metadata, but explicit 'mutable'\n            # param takes precedence.\n            nnx_metadata[\"mutable\"] = True if mutable is None else mutable\n\n            # First, initialize a basic nnx.Variable with a dummy value\n            # This sets up the NNX variable structure\n            if shape is None:\n                dummy_value = jnp.array(0.0)\n      
      else:\n                dummy_value = jnp.zeros(shape, dtype=standardize_dtype(dtype))\n\n            # Initialize nnx.Variable first\n            nnx.Variable.__init__(self, value=dummy_value, **nnx_metadata)\n\n            # Now we can safely set layout\n            self._layout = layout\n\n            # Initialize JaxVariable (which will call KerasVariable.__init__\n            # and set up the real value).\n            JaxVariable.__init__(\n                self,\n                initializer=initializer,\n                shape=shape,\n                dtype=dtype,\n                trainable=trainable,\n                autocast=autocast,\n                aggregation=aggregation,\n                synchronization=synchronization,\n                name=name,\n            )\n\n            # The real value is now set in self._value, sync it to raw_value\n            object.__setattr__(self, \"raw_value\", self._value)\n\n        def _initialize_with_initializer(self, initializer):\n            value = self._convert_to_tensor(\n                initializer(self._shape, dtype=self._dtype)\n            )\n            self._initialize(value)\n\n        @property\n        def _value(self):\n            if hasattr(self, \"raw_value\"):\n                return self.raw_value\n            return None\n\n        @_value.setter\n        def _value(self, new_keras_value):\n            self._direct_assign(new_keras_value)\n\n        def __getstate__(self):\n            # Get the state from KerasVariable (attributes in __dict__)\n            # KerasVariable does not have a custom __getstate__, so we mimic\n            # default behavior.\n            try:\n                keras_state = KerasVariable.__getstate__(self)\n            except AttributeError:\n                keras_state = object.__getstate__(self)\n\n            # Get the state from nnx.Variable\n            nnx_specific_state = nnx.Variable.__getstate__(self)\n\n            # Merge them. Keras state is primary. 
NNX specific state adds\n            # to it.\n            if \"raw_value\" in nnx_specific_state:\n                keras_state[\"_value\"] = nnx_specific_state[\"raw_value\"]\n\n            # Add NNX attributes that are not in Keras's __dict__\n            if \"_trace_state\" in nnx_specific_state:\n                keras_state[\"_trace_state\"] = nnx_specific_state[\"_trace_state\"]\n            if \"_var_metadata\" in nnx_specific_state:\n                keras_state[\"_var_metadata\"] = nnx_specific_state[\n                    \"_var_metadata\"\n                ]\n\n            # Remove \"raw_value\" since it has been merged into \"_value\".\n            keras_state.pop(\"raw_value\", None)\n\n            return keras_state\n\n        def __setstate__(self, state):\n            # Pull out the NNX-specific entries that __getstate__ merged\n            # into the main state dictionary.\n            nnx_raw_value = state[\"_value\"]  # This was raw_value\n            nnx_trace_state = state.pop(\"_trace_state\", None)\n            nnx_var_metadata = state.pop(\"_var_metadata\", None)\n\n            # Populate the instance's __dict__ with the Keras attributes.\n            self.__dict__.update(state)\n\n            # Restore the nnx.Variable specific slotted attributes.\n            object.__setattr__(self, \"raw_value\", nnx_raw_value)\n\n            if nnx_trace_state is not None:\n                object.__setattr__(self, \"_trace_state\", nnx_trace_state)\n\n            if nnx_var_metadata is not None:\n                object.__setattr__(self, \"_var_metadata\", nnx_var_metadata)\n\n            # Ensure Keras's self._value is also consistent with the\n            # restored raw_value\n            self._value = nnx_raw_value\n\n            if hasattr(self, \"_shape\") and self._shape 
is not None:\n                self._ndim = len(self._shape)\n            else:\n                # Fallback if shape isn't immediately available.\n                self._ndim = len(self.raw_value.shape)\n\n        def _direct_assign(self, value):\n            # Apply JAX-specific distribution if layout is present\n            if self._layout is not None:\n                value = distribution_lib.distribute_tensor(value, self._layout)\n\n            # Apply on_set_value hook if it exists\n            if (\n                hasattr(self, \"_var_metadata\")\n                and \"on_set_value\" in self._var_metadata\n            ):\n                value = self._var_metadata[\"on_set_value\"](self, value)\n\n            # Set the value for both Keras and NNX parts\n            # This ensures both systems see the same value\n            object.__setattr__(self, \"raw_value\", value)\n\n        @property\n        def value(self):\n            if in_stateless_scope():\n                scope = get_stateless_scope()\n                stateless_value = scope.get_current_value(self)\n                if stateless_value is not None:\n                    return self._maybe_autocast(stateless_value)\n            if not hasattr(self, \"raw_value\"):\n                if self._initializer is not None:\n                    self._initialize(\n                        self._initializer(self.shape, dtype=self.dtype)\n                    )\n                else:\n                    raise AttributeError(\n                        \"Variable is not properly initialized (raw_value \"\n                        \"missing) and has no initializer.\"\n                    )\n            current_value = self.raw_value\n            if (\n                hasattr(self, \"_var_metadata\")\n                and \"on_get_value\" in self._var_metadata\n            ):\n                current_value = self._var_metadata[\"on_get_value\"](\n                    self, current_value\n                )\n            
return self._maybe_autocast(current_value)\n\n    Variable = NnxVariable\n\n    def _flatten_nnx_variable(variable):\n        children = (variable.raw_value,)\n        # We copy __dict__ to avoid side effects\n        keras_state = variable.__dict__.copy()\n        # Remove elements that might be problematic or redundant if\n        # nnx.Variable's __getstate__\n        keras_state.pop(\"raw_value\", None)\n        aux_data = (\n            variable._var_metadata,\n            getattr(variable, \"_trace_state\", None),\n            keras_state,\n        )\n        return children, aux_data\n\n    def _unflatten_nnx_variable(aux_data, children):\n        var_metadata, trace_state, keras_state = aux_data\n        raw_value = children[0]\n\n        # Create uninitialized instance\n        variable = NnxVariable.__new__(NnxVariable)\n\n        # Restore state\n        variable._var_metadata = var_metadata\n        if trace_state is not None:\n            variable._trace_state = trace_state\n        variable.__dict__.update(keras_state)\n        variable.raw_value = raw_value\n\n        return variable\n\n    try:\n        jax.tree_util.register_pytree_node(\n            NnxVariable,\n            _flatten_nnx_variable,\n            _unflatten_nnx_variable,\n        )\n    except ValueError:\n        pass\n\n    def __setattr__(self, name, value):\n        # Mirror Keras attributes to _var_metadata to ensure persistence\n        # if the Pytree registration is not respected by NNX.\n        if (\n            name != \"_var_metadata\"\n            and name not in (\"_raw_value\", \"_trace_state\")\n            and hasattr(self, \"_var_metadata\")\n        ):\n            self._var_metadata[name] = value\n\n        object.__setattr__(self, name, value)\n\n    NnxVariable.__setattr__ = __setattr__\n\n\ndef should_shard_at_init(init_layout, shape):\n    if not isinstance(init_layout, jax.sharding.NamedSharding):\n        return False\n\n    size_threshold = 250 * 1024 * 
1024\n    # We multiply by the mesh size here to take into account the worst case\n    # scenario of the array being first duplicated in the memory of one device\n    # before being transferred to the other devices.\n    size = math.prod(shape) * 4 * init_layout.mesh.devices.size\n    return size >= size_threshold\n\n\ndef convert_to_tensor(x, dtype=None, sparse=None, ragged=None):\n    if ragged:\n        raise ValueError(\"`ragged=True` is not supported with jax backend\")\n    if dtype is not None:\n        dtype = standardize_dtype(dtype)\n    if isinstance(x, (jnp.ndarray, jax.Array)) and (\n        dtype is None or x.dtype == dtype\n    ):\n        # Skip the conversion early if the instance is already a JAX array.\n        # This is important in the multi-process context since jax.array(x) for\n        # an existing distributed jax array will raise error.\n        return x\n\n    if isinstance(x, Variable):\n        if dtype is not None and x.dtype != dtype:\n            return x.value.astype(dtype)\n        return x.value\n\n    if isinstance(x, jax_sparse.JAXSparse):\n        if sparse is not None and not sparse:\n            x = x.todense()\n        elif dtype is not None and x.dtype != dtype:\n            return x.astype(dtype)\n        else:\n            return x\n\n    if not is_tensor(x) and standardize_dtype(dtype) == \"bfloat16\":\n        # Can't create bfloat16 arrays on the fly (e.g. 
from a h5 Dataset).\n        # Instead we convert \"as is\" (to stored dtype) and cast.\n        return jnp.asarray(x).astype(dtype)\n    return jnp.asarray(x, dtype=dtype)\n\n\ndef convert_to_numpy(x):\n    if isinstance(x, jax_sparse.JAXSparse):\n        x = x.todense()\n    if is_tensor(x) and x.dtype == \"bfloat16\":\n        return np.array(x, dtype=ml_dtypes.bfloat16)\n    return np.array(x)\n\n\ndef is_tensor(x):\n    if isinstance(x, (jnp.ndarray, jax_sparse.JAXSparse)):\n        return True\n    return False\n\n\ndef shape(x):\n    return x.shape\n\n\ndef cast(x, dtype):\n    return convert_to_tensor(x, dtype=dtype)\n\n\n# Shape / dtype / sparseness inference util\ndef compute_output_spec(fn, *args, **kwargs):\n    with StatelessScope(), SymbolicScope():\n        built_in_types = (type(None), int, float, str, bool, complex, bytes)\n\n        # First, separate symbolic args from other args\n        static_args_idx = []\n        static_args = []\n        maybe_symbolic_args = []\n        static_kwargs = {}\n        maybe_symbolic_kwargs = {}\n        for idx, arg in enumerate(args):\n            if isinstance(arg, built_in_types):\n                static_args_idx.append(idx)\n                static_args.append(arg)\n            else:\n                maybe_symbolic_args.append(arg)\n        maybe_symbolic_args = tuple(maybe_symbolic_args)\n        for k, v in kwargs.items():\n            if isinstance(v, built_in_types):\n                static_kwargs[k] = v\n            else:\n                maybe_symbolic_kwargs[k] = v\n\n        # Create a _DimExpr instance for one dimension by creating a symbolic\n        # shape with one dimension and extracting it.\n        #\n        # We create a single dynamic dimension and reuse it instead of creating\n        # N dynamic dimensions. This is for backwards compatibility. 
Previously\n        # we would fill all dynamic dimensions with the same concrete value.\n        # This can handle the case where there is an implicit assumption that\n        # two dimensions are the same (e.g. square images).\n        #\n        # We add the constraint \"dynamic_dimension>=2\" to prevent JAX from\n        # assuming that the dimension can be broadcastable or squeezable. It\n        # removes this ambiguity.\n        dynamic_dimension = jax_export.symbolic_shape(\n            \"(dynamic_dimension)\",\n            constraints=[\"dynamic_dimension>=2\"],\n        )[0]\n\n        def convert_keras_tensor_to_jax(x):\n            if isinstance(x, KerasTensor):\n                shape = tuple(\n                    [d if d is not None else dynamic_dimension for d in x.shape]\n                )\n                return jax.ShapeDtypeStruct(shape, dtype=x.dtype)\n            return x\n\n        def wrapped_fn(*args, **kwargs):\n            # Turn inputs that are sparse to BCOO tensors\n            def to_bcoo_if_sparse(x, maybe_symbolic_x):\n                if (\n                    isinstance(maybe_symbolic_x, KerasTensor)\n                    and maybe_symbolic_x.sparse\n                ):\n                    return jax_sparse.BCOO.fromdense(x, nse=1)\n                return x\n\n            args, kwargs = tree.map_structure(\n                to_bcoo_if_sparse,\n                (args, kwargs),\n                (maybe_symbolic_args, maybe_symbolic_kwargs),\n            )\n\n            rec_args = []\n            idx_static = 0\n            idx_sym = 0\n            i = 0\n            while idx_static < len(static_args) or idx_sym < len(args):\n                if i in static_args_idx:\n                    rec_args.append(static_args[idx_static])\n                    idx_static += 1\n                else:\n                    rec_args.append(args[idx_sym])\n                    idx_sym += 1\n\n                i += 1\n            with StatelessScope():\n       
         return fn(*rec_args, **kwargs, **static_kwargs)\n\n        maybe_symbolic_args_jax, maybe_symbolic_kwargs_jax = tree.map_structure(\n            convert_keras_tensor_to_jax,\n            (maybe_symbolic_args, maybe_symbolic_kwargs),\n        )\n        jax_out = jax.eval_shape(\n            wrapped_fn, *maybe_symbolic_args_jax, **maybe_symbolic_kwargs_jax\n        )\n\n        def convert_jax_spec_to_keras_tensor(x):\n            if isinstance(x, jax.ShapeDtypeStruct):\n                shape = tuple(\n                    d if isinstance(d, int) else None for d in x.shape\n                )\n                return KerasTensor(shape, x.dtype)\n            elif isinstance(x, jax_sparse.BCOO):\n                shape = tuple(\n                    d if isinstance(d, int) else None for d in x.shape\n                )\n                return KerasTensor(shape, x.dtype, sparse=True)\n            return x\n\n        return tree.map_structure(convert_jax_spec_to_keras_tensor, jax_out)\n\n\ndef cond(pred, true_fn, false_fn):\n    return jax.lax.cond(pred, true_fun=true_fn, false_fun=false_fn)\n\n\ndef vectorized_map(function, elements):\n    return jax.vmap(function)(elements)\n\n\ndef map(f, xs):\n    return jax.lax.map(f, xs)\n\n\ndef scan(f, init, xs=None, length=None, reverse=False, unroll=1):\n    if not isinstance(unroll, bool):\n        if not isinstance(unroll, int) or unroll < 1:\n            raise ValueError(\n                \"`unroll` must be a positive integer or boolean. 
\"\n                f\"Received: unroll={unroll}\"\n            )\n    return jax.lax.scan(\n        f, init=init, xs=xs, length=length, reverse=reverse, unroll=unroll\n    )\n\n\ndef associative_scan(f, elems, reverse=False, axis=0):\n    return jax.lax.associative_scan(f, elems, reverse, axis)\n\n\ndef scatter(indices, values, shape):\n    zeros = jnp.zeros(shape, values.dtype)\n    key = tuple(jnp.moveaxis(indices, -1, 0))\n    return zeros.at[key].add(values)\n\n\ndef scatter_update(inputs, indices, updates, reduction=None):\n    inputs = convert_to_tensor(inputs)\n    indices = jnp.array(indices)\n    indices = jnp.transpose(indices)\n    idx = tuple(indices)\n    if reduction is None:\n        return inputs.at[idx].set(updates)\n    elif reduction == \"add\":\n        return inputs.at[idx].add(updates)\n    elif reduction == \"max\":\n        return inputs.at[idx].max(updates)\n    elif reduction == \"min\":\n        return inputs.at[idx].min(updates)\n    elif reduction == \"mul\":\n        return inputs.at[idx].multiply(updates)\n    else:\n        raise ValueError(f\"Unsupported reduction: {reduction}\")\n\n\ndef slice(inputs, start_indices, shape):\n    # If shape[i] is -1, all remaining elements in dimension i are included in\n    # the slice.\n    final_shape = tuple(\n        inputs.shape[i] - start_indices[i] if s == -1 else s\n        for i, s in enumerate(shape)\n    )\n    return jax.lax.dynamic_slice(inputs, start_indices, final_shape)\n\n\ndef slice_update(inputs, start_indices, updates):\n    return jax.lax.dynamic_update_slice(inputs, updates, start_indices)\n\n\ndef switch(index, branches, *operands):\n    return jax.lax.switch(index, branches, *operands)\n\n\ndef while_loop(\n    cond,\n    body,\n    loop_vars,\n    maximum_iterations=None,\n):\n    is_tuple = isinstance(loop_vars, (tuple, list))\n    loop_vars = tuple(loop_vars) if is_tuple else (loop_vars,)\n    if maximum_iterations is not None:\n        current_iter = 0\n        
loop_vars = loop_vars + (current_iter,)\n\n        # Unpack list/tuple args. The last argument is `current_iter`.\n        def _cond(args):\n            return cond(*args[:-1]) & (args[-1] < maximum_iterations)\n\n        def _body(args):\n            outputs = body(*args[:-1])\n            outputs = tuple(outputs) if is_tuple else (outputs,)\n            return outputs + (args[-1] + 1,)\n\n    else:\n\n        def _cond(args):\n            return cond(*args)\n\n        def _body(args):\n            outputs = body(*args)\n            return tuple(outputs) if is_tuple else (outputs,)\n\n    outputs = jax.lax.while_loop(_cond, _body, loop_vars)\n    if maximum_iterations is not None:\n        outputs = outputs[:-1]\n    return outputs if is_tuple else outputs[0]\n\n\ndef fori_loop(lower, upper, body_fun, init_val):\n    return jax.lax.fori_loop(lower, upper, body_fun, init_val)\n\n\ndef stop_gradient(variable):\n    if isinstance(variable, Variable):\n        variable = variable.value\n    return jax.lax.stop_gradient(variable)\n\n\ndef unstack(x, num=None, axis=0):\n    return [\n        jax.lax.index_in_dim(x, i, axis, keepdims=False)\n        for i in range(x.shape[axis])\n    ]\n\n\ndef random_seed_dtype():\n    # jax random seed uses uint32.\n    return \"uint32\"\n\n\ndef custom_gradient(fun):\n    fun_with_custom_gradient = jax.custom_gradient(fun=fun)\n\n    # Add a wrapper to unwrap variables, otherwise custom_gradient will fail\n    def fun_with_custom_gradient_wrapper(*args, **kwargs):\n        args, kwargs = tree.map_shape_structure(\n            lambda x: x.value if isinstance(x, KerasVariable) else x,\n            (args, kwargs),\n        )\n        return fun_with_custom_gradient(*args, **kwargs)\n\n    return fun_with_custom_gradient_wrapper\n\n\ndef remat(f):\n    \"\"\"Implementation of rematerialization.\n\n    Args:\n        f: The function or operation to rematerialize.\n    Returns:\n        A function wrapping f that defines a custom gradient, 
which\n        recomputes f on the backwards pass of a gradient call.\n    \"\"\"\n    return jax.checkpoint(f)\n\n\nclass name_scope(base_name_scope):\n    def __init__(self, name, **kwargs):\n        super().__init__(name, **kwargs)\n        self._jax_name_scope = jax.named_scope(name)\n\n    def __enter__(self):\n        name_scope_stack = global_state.get_global_attribute(\n            \"name_scope_stack\", default=[], set_to_default=True\n        )\n        if self.deduplicate and name_scope_stack:\n            parent_caller = name_scope_stack[-1].caller\n            parent_name = name_scope_stack[-1].name\n            if (\n                self.caller is not None\n                and self.caller is parent_caller\n                and self.name == parent_name\n            ):\n                return self\n        name_scope_stack.append(self)\n        self._pop_on_exit = True\n        self._jax_name_scope.__enter__()\n        return self\n\n    def __exit__(self, *args, **kwargs):\n        super().__exit__(*args, **kwargs)\n        if self._pop_on_exit:\n            self._jax_name_scope.__exit__(*args, **kwargs)\n\n\ndef device_scope(device_name):\n    if isinstance(device_name, str):\n        # We support string value like \"cpu:0\", \"gpu:1\", etc.\n        device_name = device_name.lower()\n        jax_device = distribution_lib._to_backend_device(device_name)\n    elif not isinstance(device_name, jax.Device):\n        raise ValueError(\n            \"Invalid value for argument `device_name`. \"\n            \"Expected a string like 'gpu:0' or a `jax.Device` instance. \"\n            f\"Received: device_name='{device_name}'\"\n        )\n    else:\n        jax_device = device_name\n    return jax.default_device(jax_device)\n"
  },
  {
    "path": "keras/src/backend/jax/core_test.py",
    "content": "import os\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport pytest\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.backend.config import is_nnx_enabled\n\nif is_nnx_enabled():\n    from flax import nnx\n\n    from keras.src.backend.jax.core import NnxVariable\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"jax\",\n    reason=\"JAX backend specific test for core Variable integration with NNX.\",\n)\n@pytest.mark.skipif(\n    not is_nnx_enabled(),\n    reason=\"Test requires NNX backend to be enabled by default for setup.\",\n)\nclass NnxVariableTest(testing.TestCase):\n    def setup(self):\n        super().setup()\n\n        class NNXModel(nnx.Module):\n            def __init__(self, rngs):\n                self.linear = nnx.Linear(2, 3, rngs=rngs)\n                # Use NnxVariable directly as KerasJaxVariable\n                # might be JaxVariable if NNX is disabled globally.\n                self.custom_variable = NnxVariable(jnp.ones((1, 3)))\n\n            def __call__(self, x):\n                return self.linear(x) + self.custom_variable\n\n        self.nnx_model = NNXModel(rngs=nnx.Rngs(0))\n        self.keras_nnx_model = keras.Sequential(\n            [keras.layers.Dense(units=1, input_shape=(10,))]\n        )\n        self.single_dummy_input = np.random.rand(1, 10)\n\n    def test_variable_in_nnx_module(self):\n        self.assertTrue(hasattr(self.nnx_model.custom_variable, \"_trace_state\"))\n        self.assertIsNotNone(self.nnx_model.custom_variable._trace_state)\n        self.assertAllEqual(self.nnx_model.custom_variable.value, [[1, 1, 1]])\n        self.assertTrue(\n            isinstance(self.nnx_model.custom_variable, nnx.Variable)\n        )\n\n    def test_model_saving(self):\n        path = os.path.join(self.get_temp_dir(), \"model.keras\")\n        original_outputs = self.keras_nnx_model(self.single_dummy_input)\n        self.keras_nnx_model.save(path, 
save_format=\"keras_v3\")\n        restored_model = keras.models.load_model(path)\n        restored_outputs = restored_model(self.single_dummy_input)\n        self.assertAllEqual(original_outputs, restored_outputs)\n\n    def test_keras_variable_nnx_split_merge_sync(self):\n        variable1 = keras.Variable(jnp.array(1.0))\n        graphdef, state = nnx.split(variable1)\n        state = jax.tree.map(lambda x: x + 1, state)\n        variable2 = nnx.merge(graphdef, state)\n        self.assertEqual(variable2._value, variable2.value)\n"
  },
  {
    "path": "keras/src/backend/jax/distribution_lib.py",
    "content": "\"\"\"Utilities for distribution strategy with JAX backend.\"\"\"\n\nimport jax\nimport numpy as np\n\nfrom keras.src.random import seed_generator\nfrom keras.src.utils import jax_utils\nfrom keras.src.utils import rng_utils\n\n\ndef list_devices(device_type=None):\n    \"\"\"Return all the available devices based on the device type.\n\n    Note that this should return the global devices in a distributed setting.\n\n    Args:\n        device_type: string of `\"cpu\"`, `\"gpu\"` or `\"tpu\"`. Defaults to `\"gpu\"`\n            or `\"tpu\"` if available when device_type is not provided. Otherwise\n            will return the `\"cpu\"` devices.\n\n    Returns:\n        List of devices that are available for distributed computation.\n    \"\"\"\n    device_type = device_type.lower() if device_type else None\n    jax_devices = jax.devices(backend=device_type)\n    return [f\"{device.platform}:{device.id}\" for device in jax_devices]\n\n\ndef get_device_count(device_type=None):\n    \"\"\"Returns the number of available JAX devices.\n\n    Args:\n        device_type: Optional device type to count (e.g., \"cpu\", \"gpu\", \"tpu\").\n            If `None`, it defaults to counting \"gpu\" or \"tpu\" devices if\n            available, otherwise it counts \"cpu\" devices. It does not\n            return the sum of all device types.\n\n    Returns:\n        int: The total number of JAX devices for the specified type.\n    \"\"\"\n    device_type = device_type.lower() if device_type else None\n    return jax.device_count(device_type)\n\n\ndef distribute_tensor(tensor, layout):\n    \"\"\"Distribute the tensor based on the layout.\n\n    Note that this function can be used both in an eager context and within a\n    jitted function.\n\n    Args:\n        tensor: `jax.Array` that needs to be distributed.\n        layout: `TensorLayout` for the created variable, or a\n            JAX-supported layout instance (e.g. 
`jax.sharding.Sharding`).\n\n    Returns:\n        Distributed value.\n    \"\"\"\n    # Avoid circular imports.\n    from keras.src.distribution import TensorLayout\n\n    if isinstance(layout, TensorLayout):\n        layout = layout.backend_layout\n\n    if jax_utils.is_in_jax_tracing_scope(tensor):\n        return jax.lax.with_sharding_constraint(tensor, layout)\n\n    # Skip relayout if unnecessary.\n    if isinstance(tensor, jax.Array):\n        if isinstance(\n            layout, jax.sharding.Sharding\n        ) and tensor.sharding.is_equivalent_to(layout, ndim=len(tensor.shape)):\n            return tensor\n        # JAX explicit \"layout\" support.\n        elif hasattr(layout, \"layout\"):\n            current_layout = getattr(tensor, \"layout\", None)\n            if current_layout == layout:\n                return tensor\n        # JAX explicit \"format\" support.\n        elif hasattr(layout, \"format\"):\n            current_layout = getattr(tensor, \"format\", None)\n            if current_layout == layout:\n                return tensor\n\n    return jax.device_put(tensor, layout)\n\n\ndef distribute_data_input(per_process_batch, layout, batch_dim_name):\n    \"\"\"Distribute the input data with the corresponding layout.\n\n    Note that the input here is a local worker batch. 
Within the local worker,\n    the data needs to be further partitioned to map to each of the devices.\n\n    Args:\n        per_process_batch: `jax.Array` that is already sharded to a local\n            process size.\n        layout: `TensorLayout` for the distribution information, or a\n            `jax.sharding.Sharding` instance.\n        batch_dim_name: Name of the mesh axis that corresponds to the batch\n            dimension.\n\n    Returns:\n        A global batch distributed according to `layout`.\n    \"\"\"\n    # Avoid circular imports.\n    from keras.src.distribution import TensorLayout\n\n    if isinstance(layout, TensorLayout):\n        layout = layout.backend_layout\n\n    return jax.make_array_from_process_local_data(layout, per_process_batch)\n\n\ndef initialize_rng():\n    \"\"\"Initializes the global random number generator across processes.\n\n    This is required for consistent initialization in multi-host settings.\n    \"\"\"\n    global_seed = rng_utils.get_random_seed()\n    # Only set a random seed if not already set\n    # via keras.config.set_random_seed()\n    if global_seed is None:\n        # Generate a random seed on each CPU host and psum them to get a single\n        # consistent seed across all processes.\n        cpu_devices = jax.devices(\"cpu\")\n        num_local_cpu_devices = jax.local_device_count(\"cpu\")\n        # Seed must be in range [0, 2^32 - 1], so to ensure proper range and\n        # avoid signed integer overflow, we use uint32.\n        local_seed = jax.numpy.asarray(\n            [seed_generator.make_default_seed()] * num_local_cpu_devices,\n            dtype=jax.numpy.uint32,\n        )\n        # Sum across processes and pull out the first item.\n        global_seed = jax.pmap(\n            lambda x: jax.lax.psum(x, \"all\"),\n            axis_name=\"all\",\n            devices=cpu_devices,\n        )(local_seed).item(0)\n        # Set the global seed.\n        rng_utils.set_random_seed(global_seed)\n\n\ndef initialize(job_addresses, num_processes, process_id):\n    if job_addresses and \",\" in job_addresses:\n        # When the user provides all the job addresses, we split them and\n        # use the first one as the coordinator.\n        job_addresses = job_addresses.split(\",\")\n        # Do a sanity check to make sure the number of addresses matches\n        # num_processes.\n        if num_processes is not None and num_processes != len(job_addresses):\n            raise ValueError(\n                f\"The provided job_addresses {job_addresses} has \"\n                f\"{len(job_addresses)} jobs, but num_processes is \"\n                f\"{num_processes}\"\n            )\n        coordinator_address = job_addresses[0]\n    else:\n        coordinator_address = job_addresses\n\n    jax.distributed.initialize(\n        coordinator_address=coordinator_address,\n        num_processes=num_processes,\n        process_id=process_id,\n    )\n\n    # Ensure the random number generator is initialized across processes.\n    initialize_rng()\n\n\n
def num_processes():\n    \"\"\"Return the number of processes for the current distribution setting.\"\"\"\n    return jax.process_count()\n\n\ndef process_id():\n    \"\"\"Return the current process ID for the distribution setting.\"\"\"\n    return jax.process_index()\n\n\ndef _to_backend_device(device_name):\n    if isinstance(device_name, jax.Device):\n        return device_name\n    device_name = str(device_name)\n    if \":\" not in device_name:\n        device_type, device_id = device_name, 0\n    else:\n        device_type, device_id = device_name.split(\":\")\n\n    devices = jax.devices(backend=device_type)\n    for device in devices:\n        if device.platform == device_type and device.id == int(device_id):\n            return device\n    raise ValueError(f\"Device not found: {device_name}\")\n\n\ndef _to_backend_mesh(device_mesh):\n    \"\"\"Convert the DeviceMesh to a JAX backend-specific Mesh.\n\n    Args:\n        device_mesh: DeviceMesh instance to convert.\n\n    Returns:\n        A `jax.sharding.Mesh` instance.\n    \"\"\"\n    
shape = device_mesh.devices.shape\n    devices = [_to_backend_device(d) for d in device_mesh.devices.flatten()]\n    devices = np.array(devices).reshape(shape)\n    return jax.sharding.Mesh(devices, device_mesh.axis_names)\n\n\ndef _to_backend_layout(tensor_layout):\n    \"\"\"Convert the TensorLayout to JAX backend specific Sharding.\n\n    Args:\n        tensor_layout: TensorLayout instance to convert.\n\n    Returns:\n        A `jax.sharding.NamedSharding` instance.\n    \"\"\"\n    if tensor_layout.device_mesh is None:\n        raise ValueError(\n            \"Cannot create sharding when device mesh is not set \"\n            \"for TensorLayout.\"\n        )\n    partition_spec = jax.sharding.PartitionSpec(*tensor_layout.axes)\n    jax_mesh = tensor_layout.device_mesh.backend_mesh\n    return jax.sharding.NamedSharding(jax_mesh, partition_spec)\n"
  },
  {
    "path": "keras/src/backend/jax/distribution_lib_test.py",
    "content": "\"\"\"Test for distribution_lib.py.\"\"\"\n\nimport functools\nimport os\nfrom unittest import mock\n\nimport jax\nimport numpy as np\nimport pytest\nfrom jax.experimental import layout as jax_layout\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src.backend import distribution_lib as backend_dlib\nfrom keras.src.distribution import distribution_lib\n\nif backend.backend() == \"jax\":\n    # Due to https://github.com/google/jax/issues/17188, we can't\n    # override the XLA flag after the JAX back init. We have to\n    # run this at top level to let JAX pick the flag value.\n    xla_flags = os.getenv(\"XLA_FLAGS\") or \"\"\n    # Don't override user-specified device count, or other XLA flags.\n    if \"xla_force_host_platform_device_count\" not in xla_flags:\n        os.environ[\"XLA_FLAGS\"] = (\n            f\"{xla_flags} --xla_force_host_platform_device_count=8\"\n        )\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"jax\" or len(jax.devices()) != 8,\n    reason=\"Backend specific test and requires 8 devices\",\n)\nclass JaxDistributionLibTest(testing.TestCase):\n    def _create_jax_layout(self, sharding):\n        # Use jax_layout.Format or jax_layout.Layout if available.\n        if hasattr(jax_layout, \"Format\"):\n            return jax_layout.Format(sharding=sharding)\n        elif hasattr(jax_layout, \"Layout\"):\n            return jax_layout.Layout(sharding=sharding)\n\n        return sharding\n\n    def test_get_device_count(self):\n        self.assertEqual(backend_dlib.get_device_count(), 8)\n        self.assertEqual(backend_dlib.get_device_count(\"cpu\"), 8)\n\n    def test_list_devices(self):\n        self.assertEqual(len(distribution_lib.list_devices()), 8)\n        self.assertEqual(len(distribution_lib.list_devices(\"cpu\")), 8)\n        self.assertEqual(len(distribution_lib.list_devices(\"cpu\")), 8)\n\n    def 
test_device_conversion(self):\n        devices = distribution_lib.list_devices(\"cpu\")\n        jax_devices = jax.devices(\"cpu\")\n\n        for d, jax_d in zip(devices, jax_devices):\n            converted_jax_device = backend_dlib._to_backend_device(d)\n            self.assertIsInstance(converted_jax_device, jax.Device)\n            self.assertEqual(jax_d, converted_jax_device)\n\n    @mock.patch.object(jax.distributed, \"initialize\", return_value=None)\n    def test_initialize_with_all_job_addresses(self, mock_jax_initialize):\n        backend_dlib.initialize(\"10.0.0.1:1234,10.0.0.2:2345\", 2, 0)\n        mock_jax_initialize.assert_called_once_with(\n            coordinator_address=\"10.0.0.1:1234\", num_processes=2, process_id=0\n        )\n\n    def test_initialize_validate_job_and_process(self):\n        with self.assertRaisesRegex(\n            ValueError, \"has 2 jobs, but num_processes is 3\"\n        ):\n            backend_dlib.initialize(\"10.0.0.1:1234,10.0.0.2:2345\", 3, 0)\n\n    @mock.patch.object(jax.distributed, \"initialize\", return_value=None)\n    def test_initialize_with_coordinator_address(self, mock_jax_initialize):\n        backend_dlib.initialize(\"10.0.0.1:1234\", 2, 0)\n        mock_jax_initialize.assert_called_once_with(\n            coordinator_address=\"10.0.0.1:1234\", num_processes=2, process_id=0\n        )\n\n    def test_distribute_tensor(self):\n        jax_mesh = jax.sharding.Mesh(\n            np.array(jax.devices()).reshape(2, 4), (\"batch\", \"model\")\n        )\n\n        inputs = jax.numpy.array(np.random.normal(size=(16, 8)))\n        target_layout = jax.sharding.NamedSharding(\n            jax_mesh, jax.sharding.PartitionSpec(\"batch\", None)\n        )\n\n        @functools.partial(jax.jit, static_argnames=\"target_layout\")\n        def test_function(inputs, target_layout):\n            return distribution_lib.distribute_tensor(inputs, target_layout)\n\n        result = test_function(inputs, target_layout)\n      
  # Note that the returned tensor has a different sharding implementation\n        # which is GSPMDSharding, but it should be equivalent to the specified\n        # target layout.\n        self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))\n\n        # Test without jit.\n        result = distribution_lib.distribute_tensor(inputs, target_layout)\n        self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))\n\n    def test_distribute_tensor_with_jax_layout(self):\n        jax_mesh = jax.sharding.Mesh(\n            np.array(jax.devices()).reshape(2, 4), (\"batch\", \"model\")\n        )\n\n        inputs = jax.numpy.array(np.random.normal(size=(16, 8)))\n        target_layout = self._create_jax_layout(\n            sharding=jax.sharding.NamedSharding(\n                jax_mesh, jax.sharding.PartitionSpec(\"batch\", None)\n            )\n        )\n\n        @functools.partial(jax.jit, static_argnames=\"target_layout\")\n        def test_function(inputs, target_layout):\n            return distribution_lib.distribute_tensor(inputs, target_layout)\n\n        result = test_function(inputs, target_layout)\n        # Note that the returned tensor has a different sharding implementation\n        # which is GSPMDSharding, but it should be equivalent to the specified\n        # target layout.\n        self.assertTrue(\n            result.sharding.is_equivalent_to(target_layout.sharding, ndim=2)\n        )\n\n        # Test without jit.\n        result = distribution_lib.distribute_tensor(inputs, target_layout)\n        self.assertTrue(\n            result.sharding.is_equivalent_to(target_layout.sharding, ndim=2)\n        )\n\n    def test_processes(self):\n        self.assertEqual(backend_dlib.process_id(), 0)\n        self.assertEqual(backend_dlib.num_processes(), 1)\n\n    def test_to_backend_mesh(self):\n        devices = [f\"cpu:{i}\" for i in range(8)]\n        shape = (4, 2)\n        axis_names = [\"batch\", \"model\"]\n\n        
mesh = distribution_lib.DeviceMesh(shape, axis_names, devices)\n        jax_mesh = backend_dlib._to_backend_mesh(mesh)\n\n        self.assertIsInstance(jax_mesh, jax.sharding.Mesh)\n        self.assertEqual(jax_mesh.devices.shape, shape)\n        self.assertEqual(jax_mesh.axis_names, (\"batch\", \"model\"))\n\n    def test_to_backend_layout(self):\n        axes = [\"data\", None]\n        mesh = distribution_lib.DeviceMesh(\n            (4, 2), [\"data\", \"model\"], [f\"cpu:{i}\" for i in range(8)]\n        )\n        layout = distribution_lib.TensorLayout(axes, mesh)\n        jax_sharding = backend_dlib._to_backend_layout(layout)\n        jax_mesh = backend_dlib._to_backend_mesh(mesh)\n        self.assertEqual(\n            jax_sharding,\n            jax.sharding.NamedSharding(\n                jax_mesh, jax.sharding.PartitionSpec(\"data\", None)\n            ),\n        )\n\n    def test_validation_for_device_mesh(self):\n        axes = [\"data\", None]\n        layout = distribution_lib.TensorLayout(axes, device_mesh=None)\n\n        with self.assertRaisesRegex(\n            ValueError, \"Cannot create sharding when device mesh is not set\"\n        ):\n            backend_dlib._to_backend_layout(layout)\n\n    def test_variable_assignment_reuse_layout(self):\n        shape = (4, 2)\n        axis_names = [\"batch\", \"model\"]\n        device_mesh = distribution_lib.DeviceMesh(\n            shape, axis_names, backend_dlib.list_devices()\n        )\n        layout_map = distribution_lib.LayoutMap(device_mesh)\n        layout_map[\".*dense.*kernel\"] = distribution_lib.TensorLayout(\n            [None, \"model\"]\n        )\n        layout_map[\".*dense.*bias\"] = distribution_lib.TensorLayout([\"model\"])\n\n        distribution = distribution_lib.ModelParallel(\n            layout_map=layout_map, batch_dim_name=\"batch\"\n        )\n\n        with distribution.scope():\n            dense_layer = layers.Dense(8)\n            dense_layer.build((16, 16))\n\n       
 self.assertEqual(\n            dense_layer.kernel._value.sharding.spec, (None, \"model\")\n        )\n        self.assertEqual(dense_layer.bias._value.sharding.spec, (\"model\",))\n\n        # Assign a numpy value to the dense layer to mimic model weight\n        # loading.\n        new_kernel = np.random.normal(size=(16, 8))\n        new_bias = np.random.normal(size=(8))\n        dense_layer.kernel.assign(new_kernel)\n        dense_layer.bias.assign(new_bias)\n\n        # Make sure the loaded value still uses the layout it was initialized\n        # with, even outside of the distribution scope.\n        self.assertEqual(\n            dense_layer.kernel._value.sharding.spec, (None, \"model\")\n        )\n        self.assertEqual(dense_layer.bias._value.sharding.spec, (\"model\",))\n\n    def test_e2e_data_parallel_model(self):\n        distribution = distribution_lib.DataParallel(\n            devices=backend_dlib.list_devices()\n        )\n\n        with distribution.scope():\n            inputs = layers.Input(shape=[28, 28, 1])\n            y = layers.Flatten()(inputs)\n            y = layers.Dense(units=200, use_bias=False, activation=\"relu\")(y)\n            y = layers.Dropout(0.4)(y)\n            y = layers.Dense(units=10, activation=\"softmax\")(y)\n            model = models.Model(inputs=inputs, outputs=y)\n\n        # Make sure all the weights are properly sharded.\n        for weight in model.weights:\n            self.assertTrue(weight._value.sharding.is_fully_replicated)\n\n        inputs = np.random.normal(size=(32, 28, 28, 1))\n        labels = np.random.normal(size=(32, 10))\n\n        with distribution.scope():\n            model.compile(loss=\"mse\")\n            model.fit(inputs, labels)\n\n    def test_e2e_model_parallel_model(self):\n        shape = (4, 2)\n        axis_names = [\"batch\", \"model\"]\n        device_mesh = distribution_lib.DeviceMesh(\n            shape, axis_names, backend_dlib.list_devices()\n        )\n\n        layout_map = 
distribution_lib.LayoutMap(device_mesh)\n        layout_map[\".*dense.*kernel\"] = distribution_lib.TensorLayout(\n            [None, \"model\"]\n        )\n        layout_map[\".*dense.*bias\"] = distribution_lib.TensorLayout([\"model\"])\n\n        distribution = distribution_lib.ModelParallel(\n            layout_map=layout_map, batch_dim_name=\"batch\"\n        )\n        with distribution.scope():\n            inputs = layers.Input(shape=[28, 28, 1])\n            y = layers.Flatten()(inputs)\n            y = layers.Dense(units=200, use_bias=False, activation=\"relu\")(y)\n            y = layers.Dropout(0.4)(y)\n            y = layers.Dense(units=10, activation=\"softmax\")(y)\n            model = models.Model(inputs=inputs, outputs=y)\n\n        for weight in model.weights:\n            if \"kernel\" in weight.name:\n                self.assertEqual(weight._value.sharding.spec, (None, \"model\"))\n            elif \"bias\" in weight.name:\n                self.assertEqual(weight._value.sharding.spec, (\"model\",))\n            else:\n                self.assertTrue(weight._value.sharding.is_fully_replicated)\n\n        inputs = np.random.normal(size=(32, 28, 28, 1))\n        labels = np.random.normal(size=(32, 10))\n\n        with distribution.scope():\n            model.compile(loss=\"mse\")\n            model.fit(inputs, labels)\n\n    def test_e2e_model_parallel_with_output_sharding(self):\n        shape = (4, 2)\n        axis_names = [\"batch\", \"model\"]\n        device_mesh = distribution_lib.DeviceMesh(\n            shape, axis_names, backend_dlib.list_devices()\n        )\n\n        layout_map = distribution_lib.LayoutMap(device_mesh)\n        layout_map[\".*dense.*kernel\"] = distribution_lib.TensorLayout(\n            [None, \"model\"]\n        )\n        layout_map[\".*dense.*bias\"] = distribution_lib.TensorLayout([\"model\"])\n        # Force the dense layer output to be batch parallel only, and not\n        # sharded on model dimension.\n        
layout_map[\".*dense.*output\"] = (\"batch\", None)\n\n        distribution = distribution_lib.ModelParallel(\n            layout_map=layout_map, batch_dim_name=\"batch\"\n        )\n        sharding_capture = ShardingCaptureLayer()\n        with distribution.scope():\n            inputs = layers.Input(shape=[28, 28, 1])\n            y = layers.Flatten()(inputs)\n            y = layers.Dense(units=200, use_bias=False, activation=\"relu\")(y)\n            y = sharding_capture(y)\n            y = layers.Dropout(0.4)(y)\n            y = layers.Dense(units=10, activation=\"softmax\")(y)\n            model = models.Model(inputs=inputs, outputs=y)\n\n        for weight in model.weights:\n            if \"kernel\" in weight.name:\n                self.assertEqual(weight._value.sharding.spec, (None, \"model\"))\n            elif \"bias\" in weight.name:\n                self.assertEqual(weight._value.sharding.spec, (\"model\",))\n            else:\n                self.assertTrue(weight._value.sharding.is_fully_replicated)\n\n        inputs = np.random.normal(size=(32, 28, 28, 1))\n        labels = np.random.normal(size=(32, 10))\n\n        with distribution.scope():\n            model.compile(loss=\"mse\")\n            model.fit(inputs, labels)\n\n        # Note that the intermediate_tensor_layout is only captured during the\n        # actual training, and not at the model building time.\n        intermediate_tensor_layout = jax.sharding.NamedSharding(\n            backend_dlib._to_backend_mesh(distribution.device_mesh),\n            jax.sharding.PartitionSpec(\"batch\", None),\n        )\n        self.assertTrue(\n            sharding_capture.captured_input_sharding.is_equivalent_to(\n                intermediate_tensor_layout, ndim=2\n            )\n        )\n\n    def test_distribute_data_input(self):\n        per_process_batch = jax.numpy.arange(24).reshape(\n            6, 4\n        )  # Example input array\n        devices = jax.devices()[:4]  # Simulate 4 
devices\n        batch_dim_size, model_dim_size = 2, 2\n        mesh = jax.sharding.Mesh(\n            np.array(devices).reshape(batch_dim_size, model_dim_size),\n            axis_names=[\"batch\", \"model\"],\n        )\n        layout = jax.sharding.NamedSharding(\n            mesh, jax.sharding.PartitionSpec(\"batch\", None)\n        )\n\n        result = backend_dlib.distribute_data_input(\n            per_process_batch, layout, \"batch\"\n        )\n\n        # Check the shape of the global batch array\n        self.assertEqual(\n            result.shape, (6, 4)\n        )  # (per_replica_batch_size * num_model_replicas_total, 4)\n\n        # Check the sharding of the global batch array\n        self.assertEqual(len(result.addressable_shards), len(devices))\n        # Since batch_dim_size=2, the batch dimension is sharded 2-way across\n        # the \"batch\" mesh axis, so each batch shard holds 1/2 of\n        # per_process_batch. Each batch shard is then replicated across the 2\n        # devices on the \"model\" axis, yielding 4 addressable shards of shape\n        # (3, 4).\n        for shard in result.addressable_shards:\n            self.assertEqual(shard.data.shape, (3, 4))\n\n\nclass ShardingCaptureLayer(layers.Layer):\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.captured_input_sharding = None\n        self.supports_masking = True\n\n    def call(self, inputs):\n        jax.debug.inspect_array_sharding(\n            inputs, callback=lambda x: self.capture_input_sharding(x)\n        )\n        return inputs\n\n    def capture_input_sharding(self, sharding):\n        self.captured_input_sharding = sharding\n"
  },
  {
    "path": "keras/src/backend/jax/excluded_tpu_tests.txt",
    "content": "ConvTransposeBasicTest\nExportArchiveTest::test_jax_endpoint_registration_tf_function\nExportArchiveTest::test_jax_multi_unknown_endpoint_registration\nExportArchiveTest::test_layer_export\nExportArchiveTest::test_low_level_model_export_functional\nExportArchiveTest::test_low_level_model_export_sequential\nExportArchiveTest::test_low_level_model_export_subclass\nExportArchiveTest::test_low_level_model_export_with_alias\nExportArchiveTest::test_low_level_model_export_with_dynamic_dims_functional\nExportArchiveTest::test_low_level_model_export_with_dynamic_dims_sequential\nExportArchiveTest::test_low_level_model_export_with_dynamic_dims_subclass\nExportArchiveTest::test_low_level_model_export_with_jax2tf_kwargs\nExportArchiveTest::test_low_level_model_export_with_jax2tf_polymorphic_shapes\nExportArchiveTest::test_model_combined_with_tf_preprocessing\nExportArchiveTest::test_model_export_method_functional\nExportArchiveTest::test_model_export_method_sequential\nExportArchiveTest::test_model_export_method_subclass\nExportArchiveTest::test_multi_input_output_functional_model\nExportArchiveTest::test_non_standard_layer_signature\nExportArchiveTest::test_non_standard_layer_signature_with_kwargs\nExportArchiveTest::test_track_multiple_layers\nExportONNXTest::test_export_with_input_names\nExportONNXTest::test_export_with_opset_version_18\nExportONNXTest::test_export_with_opset_version_none\nExportONNXTest::test_model_with_input_structure_array\nExportONNXTest::test_model_with_input_structure_dict\nExportONNXTest::test_model_with_input_structure_tuple\nExportONNXTest::test_model_with_multiple_inputs\nExportONNXTest::test_standard_model_export_functional\nExportONNXTest::test_standard_model_export_lstm\nExportONNXTest::test_standard_model_export_sequential\nExportONNXTest::test_standard_model_export_subclass\nExportOpenVINOTest::test_model_with_input_structure_array\nExportOpenVINOTest::test_model_with_input_structure_dict\nExportOpenVINOTest::test_model_with_i
nput_structure_tuple\nExportOpenVINOTest::test_model_with_multiple_inputs\nExportOpenVINOTest::test_standard_model_export_functional\nExportOpenVINOTest::test_standard_model_export_sequential\nExportOpenVINOTest::test_standard_model_export_subclass\nExportSavedModelTest::test_input_signature_functional_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>\nExportSavedModelTest::test_input_signature_functional_backend_tensor\nExportSavedModelTest::test_input_signature_functional_inputspec(dtype=float32, shape=(none, 10), ndim=2)\nExportSavedModelTest::test_input_signature_functional_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')\nExportSavedModelTest::test_input_signature_sequential_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>\nExportSavedModelTest::test_input_signature_sequential_backend_tensor\nExportSavedModelTest::test_input_signature_sequential_inputspec(dtype=float32, shape=(none, 10), ndim=2)\nExportSavedModelTest::test_input_signature_sequential_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')\nExportSavedModelTest::test_input_signature_subclass_<kerastensor shape=(none, 10), dtype=float32, sparse=false, ragged=false, name=inputs>\nExportSavedModelTest::test_input_signature_subclass_backend_tensor\nExportSavedModelTest::test_input_signature_subclass_inputspec(dtype=float32, shape=(none, 10), ndim=2)\nExportSavedModelTest::test_input_signature_subclass_tensorspec(shape=(none, 10), dtype=tf.float32, name='inputs')\nExportSavedModelTest::test_jax_specific_kwargs_functional_false_{'enable_xla': true, 'native_serialization': true}\nExportSavedModelTest::test_jax_specific_kwargs_functional_false_none\nExportSavedModelTest::test_jax_specific_kwargs_functional_true_{'enable_xla': true, 'native_serialization': true}\nExportSavedModelTest::test_jax_specific_kwargs_functional_true_none\nExportSavedModelTest::test_jax_specific_kwargs_sequential_false_{'enable_xla': 
true, 'native_serialization': true}\nExportSavedModelTest::test_jax_specific_kwargs_sequential_false_none\nExportSavedModelTest::test_jax_specific_kwargs_sequential_true_{'enable_xla': true, 'native_serialization': true}\nExportSavedModelTest::test_jax_specific_kwargs_sequential_true_none\nExportSavedModelTest::test_jax_specific_kwargs_subclass_false_{'enable_xla': true, 'native_serialization': true}\nExportSavedModelTest::test_jax_specific_kwargs_subclass_false_none\nExportSavedModelTest::test_jax_specific_kwargs_subclass_true_{'enable_xla': true, 'native_serialization': true}\nExportSavedModelTest::test_jax_specific_kwargs_subclass_true_none\nExportSavedModelTest::test_model_with_input_structure_array\nExportSavedModelTest::test_model_with_input_structure_dict\nExportSavedModelTest::test_model_with_input_structure_tuple\nExportSavedModelTest::test_model_with_multiple_inputs\nExportSavedModelTest::test_model_with_non_trainable_state_export_functional\nExportSavedModelTest::test_model_with_non_trainable_state_export_sequential\nExportSavedModelTest::test_model_with_non_trainable_state_export_subclass\nExportSavedModelTest::test_model_with_rng_export_functional\nExportSavedModelTest::test_model_with_rng_export_sequential\nExportSavedModelTest::test_model_with_rng_export_subclass\nExportSavedModelTest::test_model_with_tf_data_layer_functional\nExportSavedModelTest::test_model_with_tf_data_layer_sequential\nExportSavedModelTest::test_model_with_tf_data_layer_subclass\nExportSavedModelTest::test_standard_model_export_functional\nExportSavedModelTest::test_standard_model_export_sequential\nExportSavedModelTest::test_standard_model_export_subclass\nTestJaxLayer::test_flax_layer_training_independent_bound_method\nTestJaxLayer::test_flax_layer_training_rng_state_no_method\nTestJaxLayer::test_flax_layer_training_rng_unbound_method\nTestJaxLayer::test_flax_layer_training_rng_unbound_method_dtype_policy\nTestJaxLayer::test_jax_layer_stateless\nTestJaxLayer::test_jax_layer_trai
ning_independent\nTestJaxLayer::test_jax_layer_training_state\nTestJaxLayer::test_jax_layer_training_state_dtype_policy"
  },
  {
    "path": "keras/src/backend/jax/export.py",
    "content": "import copy\nimport inspect\nimport itertools\nimport string\nimport warnings\n\nfrom keras.src import tree\nfrom keras.src.backend.common.stateless_scope import StatelessScope\nfrom keras.src.export.saved_model_export_archive import SavedModelExportArchive\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\nclass JaxExportArchive(SavedModelExportArchive):\n    \"\"\"JAX backend implementation of SavedModel export archive.\"\"\"\n\n    def _backend_init(self):\n        \"\"\"JAX-specific initialization.\"\"\"\n        self._backend_variables = []\n        self._backend_trainable_variables = []\n        self._backend_non_trainable_variables = []\n\n    def _backend_track_layer(self, layer):\n        # Variables in the lists below are actually part of the trackables\n        # that get saved, because the lists are created in __init__.\n        trainable_variables = layer.trainable_variables\n        non_trainable_variables = layer.non_trainable_variables\n\n        self._tf_trackable.trainable_variables += tree.map_structure(\n            self._convert_to_tf_variable, trainable_variables\n        )\n        self._tf_trackable.non_trainable_variables += tree.map_structure(\n            self._convert_to_tf_variable, non_trainable_variables\n        )\n        self._tf_trackable.variables = (\n            self._tf_trackable.trainable_variables\n            + self._tf_trackable.non_trainable_variables\n        )\n\n        self._backend_trainable_variables += trainable_variables\n        self._backend_non_trainable_variables += non_trainable_variables\n        self._backend_variables = (\n            self._backend_trainable_variables\n            + self._backend_non_trainable_variables\n        )\n\n    def _backend_add_endpoint(self, name, fn, input_signature, **kwargs):\n        jax2tf_kwargs = kwargs.pop(\"jax2tf_kwargs\", None)\n        # Use `copy.copy()` to avoid modification issues.\n        jax2tf_kwargs = copy.copy(jax2tf_kwargs) or 
{}\n        is_static = bool(kwargs.pop(\"is_static\", False))\n\n        # Configure `jax2tf_kwargs`\n        if \"native_serialization\" not in jax2tf_kwargs:\n            jax2tf_kwargs[\"native_serialization\"] = (\n                self._check_device_compatible()\n            )\n        if \"polymorphic_shapes\" not in jax2tf_kwargs:\n            jax2tf_kwargs[\"polymorphic_shapes\"] = self._to_polymorphic_shape(\n                input_signature\n            )\n\n        # Note: we truncate the number of parameters to what is specified by\n        # `input_signature`.\n        fn_signature = inspect.signature(fn)\n        fn_parameters = list(fn_signature.parameters.values())\n\n        if is_static:\n            from jax.experimental import jax2tf\n\n            jax_fn = jax2tf.convert(fn, **jax2tf_kwargs)\n            jax_fn.__signature__ = inspect.Signature(\n                parameters=fn_parameters[0 : len(input_signature)],\n                return_annotation=fn_signature.return_annotation,\n            )\n\n            decorated_fn = tf.function(\n                jax_fn,\n                input_signature=input_signature,\n                autograph=False,\n            )\n        else:\n            # 1. Create a stateless wrapper for `fn`\n            # 2. jax2tf the stateless wrapper\n            # 3. Create a stateful function that binds the variables with\n            #    the jax2tf converted stateless wrapper\n            # 4. Make the signature of the stateful function the same as the\n            #    original function\n            # 5. 
Wrap in a `tf.function`\n            def stateless_fn(variables, *args, **kwargs):\n                state_mapping = zip(self._backend_variables, variables)\n                with StatelessScope(state_mapping=state_mapping) as scope:\n                    output = fn(*args, **kwargs)\n\n                # Gather updated non-trainable variables\n                non_trainable_variables = []\n                for var in self._backend_non_trainable_variables:\n                    new_value = scope.get_current_value(var)\n                    non_trainable_variables.append(new_value)\n                return output, non_trainable_variables\n\n            jax2tf_stateless_fn = self._convert_jax2tf_function(\n                stateless_fn, input_signature, jax2tf_kwargs=jax2tf_kwargs\n            )\n\n            def stateful_fn(*args, **kwargs):\n                output, non_trainable_variables = jax2tf_stateless_fn(\n                    # Change the trackable `ListWrapper` to a plain `list`\n                    list(self._tf_trackable.variables),\n                    *args,\n                    **kwargs,\n                )\n                for var, new_value in zip(\n                    self._tf_trackable.non_trainable_variables,\n                    non_trainable_variables,\n                ):\n                    var.assign(tf.cast(new_value, var.dtype))\n                return output\n\n            stateful_fn.__signature__ = inspect.Signature(\n                parameters=fn_parameters[0 : len(input_signature)],\n                return_annotation=fn_signature.return_annotation,\n            )\n\n            decorated_fn = tf.function(\n                stateful_fn,\n                input_signature=input_signature,\n                autograph=False,\n            )\n        return decorated_fn\n\n    def _convert_jax2tf_function(self, fn, input_signature, jax2tf_kwargs=None):\n        from jax.experimental import jax2tf\n\n        variables_shapes = self._to_polymorphic_shape(\n  
          self._backend_variables, allow_none=False\n        )\n        input_shapes = list(jax2tf_kwargs[\"polymorphic_shapes\"])\n        jax2tf_kwargs[\"polymorphic_shapes\"] = [variables_shapes] + input_shapes\n        return jax2tf.convert(fn, **jax2tf_kwargs)\n\n    def _to_polymorphic_shape(self, struct, allow_none=True):\n        if allow_none:\n            # Generates unique names: a, b, ... z, aa, ab, ... az, ba, ... zz\n            # for unknown non-batch dims. Defined here to be scope per endpoint.\n            dim_names = itertools.chain(\n                string.ascii_lowercase,\n                itertools.starmap(\n                    lambda a, b: a + b,\n                    itertools.product(string.ascii_lowercase, repeat=2),\n                ),\n            )\n\n        def convert_shape(x):\n            poly_shape = []\n            for index, dim in enumerate(list(x.shape)):\n                if dim is not None:\n                    poly_shape.append(str(dim))\n                elif not allow_none:\n                    raise ValueError(\n                        f\"Illegal None dimension in {x} with shape {x.shape}\"\n                    )\n                elif index == 0:\n                    poly_shape.append(\"batch\")\n                else:\n                    poly_shape.append(next(dim_names))\n            return f\"({', '.join(poly_shape)})\"\n\n        return tree.map_structure(convert_shape, struct)\n\n    def _check_device_compatible(self):\n        from jax import default_backend as jax_device\n\n        if (\n            jax_device() == \"gpu\"\n            and len(tf.config.list_physical_devices(\"GPU\")) == 0\n        ):\n            warnings.warn(\n                \"JAX backend is using GPU for export, but installed \"\n                \"TF package cannot access GPU, so reloading the model with \"\n                \"the TF runtime in the same environment will not work. 
\"\n                \"To use JAX-native serialization for high-performance export \"\n                \"and serving, please install `tensorflow-gpu` and ensure \"\n                \"CUDA version compatibility between your JAX and TF \"\n                \"installations.\"\n            )\n            return False\n        else:\n            return True\n"
  },
  {
    "path": "keras/src/backend/jax/image.py",
    "content": "import functools\n\nimport jax\nimport jax.numpy as jnp\n\nfrom keras.src import backend\nfrom keras.src.backend.jax.core import convert_to_tensor\nfrom keras.src.random.seed_generator import draw_seed\n\nRESIZE_INTERPOLATIONS = (\n    \"bilinear\",\n    \"nearest\",\n    \"lanczos3\",\n    \"lanczos5\",\n    \"bicubic\",\n)\nAFFINE_TRANSFORM_INTERPOLATIONS = {  # map to order\n    \"nearest\": 0,\n    \"bilinear\": 1,\n}\nAFFINE_TRANSFORM_FILL_MODES = {\n    \"constant\",\n    \"nearest\",\n    \"wrap\",\n    \"mirror\",\n    \"reflect\",\n}\nMAP_COORDINATES_FILL_MODES = {\n    \"constant\",\n    \"nearest\",\n    \"wrap\",\n    \"mirror\",\n    \"reflect\",\n}\nSCALE_AND_TRANSLATE_METHODS = {\n    \"linear\",\n    \"bilinear\",\n    \"trilinear\",\n    \"cubic\",\n    \"bicubic\",\n    \"tricubic\",\n    \"lanczos3\",\n    \"lanczos5\",\n}\n\n\ndef rgb_to_grayscale(images, data_format=None):\n    images = convert_to_tensor(images)\n    data_format = backend.standardize_data_format(data_format)\n    channels_axis = -1 if data_format == \"channels_last\" else -3\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    # Convert to floats\n    original_dtype = images.dtype\n    compute_dtype = backend.result_type(images.dtype, float)\n    images = images.astype(compute_dtype)\n\n    # Ref: tf.image.rgb_to_grayscale\n    rgb_weights = convert_to_tensor(\n        [0.2989, 0.5870, 0.1140], dtype=images.dtype\n    )\n    images = jnp.tensordot(images, rgb_weights, axes=(channels_axis, -1))\n    images = jnp.expand_dims(images, axis=channels_axis)\n    return images.astype(original_dtype)\n\n\ndef rgb_to_hsv(images, data_format=None):\n    # Ref: dm_pix\n    images = convert_to_tensor(images)\n    dtype = images.dtype\n    data_format = backend.standardize_data_format(data_format)\n    channels_axis = -1 if data_format == \"channels_last\" else -3\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if not backend.is_float_dtype(dtype):\n        raise ValueError(\n            \"Invalid images dtype: expected float dtype. \"\n            f\"Received: images.dtype={backend.standardize_dtype(dtype)}\"\n        )\n    eps = jnp.finfo(dtype).eps\n    images = jnp.where(jnp.abs(images) < eps, 0.0, images)\n    red, green, blue = jnp.split(images, 3, channels_axis)\n    red = jnp.squeeze(red, channels_axis)\n    green = jnp.squeeze(green, channels_axis)\n    blue = jnp.squeeze(blue, channels_axis)\n\n    def rgb_planes_to_hsv_planes(r, g, b):\n        value = jnp.maximum(jnp.maximum(r, g), b)\n        minimum = jnp.minimum(jnp.minimum(r, g), b)\n        range_ = value - minimum\n\n        safe_value = jnp.where(value > 0, value, 1.0)\n        safe_range = jnp.where(range_ > 0, range_, 1.0)\n\n        saturation = jnp.where(value > 0, range_ / safe_value, 0.0)\n        norm = 1.0 / (6.0 * safe_range)\n\n        hue = jnp.where(\n            value == g,\n            norm * (b - r) + 2.0 / 6.0,\n            norm * (r - g) + 4.0 / 6.0,\n        )\n        hue = jnp.where(value == r, norm * (g - b), hue)\n        hue = jnp.where(range_ > 0, hue, 0.0) + (hue < 0.0).astype(hue.dtype)\n        return hue, saturation, value\n\n    images = jnp.stack(\n        rgb_planes_to_hsv_planes(red, green, blue), axis=channels_axis\n    )\n    return images\n\n\ndef hsv_to_rgb(images, data_format=None):\n    # Ref: dm_pix\n    images = convert_to_tensor(images)\n    dtype = images.dtype\n    data_format = backend.standardize_data_format(data_format)\n    channels_axis = -1 if data_format == \"channels_last\" else -3\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if not backend.is_float_dtype(dtype):\n        raise ValueError(\n            \"Invalid images dtype: expected float dtype. \"\n            f\"Received: images.dtype={backend.standardize_dtype(dtype)}\"\n        )\n    hue, saturation, value = jnp.split(images, 3, channels_axis)\n    hue = jnp.squeeze(hue, channels_axis)\n    saturation = jnp.squeeze(saturation, channels_axis)\n    value = jnp.squeeze(value, channels_axis)\n\n    def hsv_planes_to_rgb_planes(hue, saturation, value):\n        dh = jnp.mod(hue, 1.0) * 6.0\n        dr = jnp.clip(jnp.abs(dh - 3.0) - 1.0, 0.0, 1.0)\n        dg = jnp.clip(2.0 - jnp.abs(dh - 2.0), 0.0, 1.0)\n        db = jnp.clip(2.0 - jnp.abs(dh - 4.0), 0.0, 1.0)\n        one_minus_s = 1.0 - saturation\n\n        red = value * (one_minus_s + saturation * dr)\n        green = value * (one_minus_s + saturation * dg)\n        blue = value * (one_minus_s + saturation * db)\n        return red, green, blue\n\n    images = jnp.stack(\n        hsv_planes_to_rgb_planes(hue, saturation, value), axis=channels_axis\n    )\n    return images\n\n\ndef resize(\n    images,\n    size,\n    interpolation=\"bilinear\",\n    antialias=False,\n    crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n    fill_mode=\"constant\",\n    fill_value=0.0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if interpolation not in RESIZE_INTERPOLATIONS:\n        raise ValueError(\n            \"Invalid value for argument \`interpolation\`. Expected one of \"\n            f\"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}\"\n        )\n    if fill_mode != \"constant\":\n        raise ValueError(\n            \"Invalid value for argument \`fill_mode\`. Only \`'constant'\` \"\n            f\"is supported. Received: fill_mode={fill_mode}\"\n        )\n    if pad_to_aspect_ratio and crop_to_aspect_ratio:\n        raise ValueError(\n            \"Only one of \`pad_to_aspect_ratio\` & \`crop_to_aspect_ratio\` \"\n            \"can be \`True\`.\"\n        )\n    if len(size) != 2:\n        raise ValueError(\n            \"Argument \`size\` must be a tuple of two elements \"\n            f\"(height, width). Received: size={size}\"\n        )\n    size = tuple(size)\n    target_height, target_width = size\n    if len(images.shape) == 4:\n        if data_format == \"channels_last\":\n            size = (images.shape[0],) + size + (images.shape[-1],)\n        else:\n            size = (images.shape[0], images.shape[1]) + size\n        batch_size = images.shape[0]\n    elif len(images.shape) == 3:\n        if data_format == \"channels_last\":\n            size = size + (images.shape[-1],)\n        else:\n            size = (images.shape[0],) + size\n    else:\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    if crop_to_aspect_ratio:\n        shape = images.shape\n        if data_format == \"channels_last\":\n            height, width = shape[-3], shape[-2]\n        else:\n            height, width = shape[-2], shape[-1]\n        crop_height = int(float(width * target_height) / target_width)\n        crop_height = max(min(height, crop_height), 1)\n        crop_width = int(float(height * target_width) / target_height)\n        crop_width = max(min(width, crop_width), 1)\n        crop_box_hstart = int(float(height - crop_height) / 2)\n        crop_box_wstart = int(float(width - crop_width) / 2)\n        if data_format == \"channels_last\":\n            if len(images.shape) == 4:\n                images = images[\n                    :,\n                    crop_box_hstart : crop_box_hstart + crop_height,\n                    crop_box_wstart : crop_box_wstart + crop_width,\n                    :,\n                ]\n            else:\n                images = images[\n                    crop_box_hstart : crop_box_hstart + crop_height,\n                    crop_box_wstart : crop_box_wstart + crop_width,\n                    :,\n                ]\n        else:\n            if len(images.shape) == 4:\n                images = images[\n                    :,\n                    :,\n                    crop_box_hstart : crop_box_hstart + crop_height,\n                    crop_box_wstart : crop_box_wstart + crop_width,\n                ]\n            else:\n                images = images[\n                    :,\n                    crop_box_hstart : crop_box_hstart + crop_height,\n                    crop_box_wstart : crop_box_wstart + crop_width,\n                ]\n    elif pad_to_aspect_ratio:\n        shape = images.shape\n        if data_format == \"channels_last\":\n            height, width, channels = shape[-3], shape[-2], shape[-1]\n        else:\n            height, width, channels = shape[-2], shape[-1], shape[-3]\n\n        pad_height = int(float(width * target_height) / target_width)\n        pad_height = max(height, pad_height)\n        pad_width = int(float(height * target_width) / target_height)\n        pad_width = max(width, pad_width)\n        img_box_hstart = int(float(pad_height - height) / 2)\n        img_box_wstart = int(float(pad_width - width) / 2)\n        if data_format == \"channels_last\":\n            if img_box_hstart > 0:\n                if len(images.shape) == 4:\n                    padded_img = jnp.concatenate(\n                        [\n                            jnp.ones(\n                                (batch_size, img_box_hstart, width, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            jnp.ones(\n                                (batch_size, img_box_hstart, width, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=1,\n                    )\n                else:\n                    padded_img = jnp.concatenate(\n                        [\n                            jnp.ones(\n                                (img_box_hstart, width, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            jnp.ones(\n                                (img_box_hstart, width, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=0,\n                    )\n            elif img_box_wstart > 0:\n                if len(images.shape) == 4:\n                    padded_img = jnp.concatenate(\n                        [\n                            jnp.ones(\n                                (batch_size, height, img_box_wstart, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            jnp.ones(\n                                (batch_size, height, img_box_wstart, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=2,\n                    )\n                else:\n                    padded_img = jnp.concatenate(\n                        [\n                            jnp.ones(\n                                (height, img_box_wstart, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            jnp.ones(\n                                (height, img_box_wstart, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=1,\n                    )\n            else:\n                padded_img = images\n        else:\n            if img_box_hstart > 0:\n                if len(images.shape) == 4:\n                    padded_img = jnp.concatenate(\n                        [\n                            jnp.ones(\n                                (batch_size, channels, img_box_hstart, width),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            jnp.ones(\n                                (batch_size, channels, img_box_hstart, width),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=2,\n                    )\n                else:\n                    padded_img = jnp.concatenate(\n                        [\n                            jnp.ones(\n                                (channels, img_box_hstart, width),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            jnp.ones(\n                                (channels, img_box_hstart, width),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=1,\n                    )\n            elif img_box_wstart > 0:\n                if len(images.shape) == 4:\n                    padded_img = jnp.concatenate(\n                        [\n                            jnp.ones(\n                                (batch_size, channels, height, img_box_wstart),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            jnp.ones(\n                                (batch_size, channels, height, img_box_wstart),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=3,\n                    )\n                else:\n                    padded_img = jnp.concatenate(\n                        [\n                            jnp.ones(\n                                (channels, height, img_box_wstart),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            jnp.ones(\n                                (channels, height, img_box_wstart),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=2,\n                    )\n            else:\n                padded_img = images\n        images = padded_img\n\n    return jax.image.resize(\n        images, size, method=interpolation, antialias=antialias\n    )\n\n\ndef affine_transform(\n    images,\n    transform,\n    interpolation=\"bilinear\",\n    fill_mode=\"constant\",\n    fill_value=0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():\n        raise ValueError(\n            \"Invalid value for argument \`interpolation\`. Expected one of \"\n            f\"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: \"\n            f\"interpolation={interpolation}\"\n        )\n    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:\n        raise ValueError(\n            \"Invalid value for argument \`fill_mode\`. Expected one of \"\n            f\"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}\"\n        )\n\n    transform = convert_to_tensor(transform)\n\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if len(transform.shape) not in (1, 2):\n        raise ValueError(\n            \"Invalid transform rank: expected rank 1 (single transform) \"\n            \"or rank 2 (batch of transforms). Received input with shape: \"\n            f\"transform.shape={transform.shape}\"\n        )\n\n    # unbatched case\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = jnp.expand_dims(images, axis=0)\n        need_squeeze = True\n    if len(transform.shape) == 1:\n        transform = jnp.expand_dims(transform, axis=0)\n\n    if data_format == \"channels_first\":\n        images = jnp.transpose(images, (0, 2, 3, 1))\n\n    batch_size = images.shape[0]\n\n    # get indices\n    meshgrid = jnp.meshgrid(\n        *[jnp.arange(size) for size in images.shape[1:]], indexing=\"ij\"\n    )\n    indices = jnp.concatenate(\n        [jnp.expand_dims(x, axis=-1) for x in meshgrid], axis=-1\n    )\n    indices = jnp.tile(indices, (batch_size, 1, 1, 1, 1))\n\n    # swap the values\n    a0 = transform[:, 0]\n    a2 = transform[:, 2]\n    b1 = transform[:, 4]\n    b2 = transform[:, 5]\n    transform = transform.at[:, 0].set(b1)\n    transform = transform.at[:, 2].set(b2)\n    transform = transform.at[:, 4].set(a0)\n    transform = transform.at[:, 5].set(a2)\n\n    # deal with transform\n    transform = jnp.pad(\n        transform, pad_width=[[0, 0], [0, 1]], constant_values=1\n    )\n    transform = jnp.reshape(transform, (batch_size, 3, 3))\n    offset = transform[:, 0:2, 2]\n    offset = jnp.pad(offset, pad_width=[[0, 0], [0, 1]])\n    transform = transform.at[:, 0:2, 2].set(0)\n\n    # transform the indices\n    coordinates = jnp.einsum(\"Bhwij, Bjk -> Bhwik\", indices, transform)\n    coordinates = jnp.moveaxis(coordinates, source=-1, destination=1)\n    coordinates += jnp.reshape(offset, shape=(*offset.shape, 1, 1, 1))\n\n    # apply affine transformation\n    _map_coordinates = functools.partial(\n        jax.scipy.ndimage.map_coordinates,\n        order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n        mode=fill_mode,\n        cval=fill_value,\n    )\n    affined = jax.vmap(_map_coordinates)(images, coordinates)\n\n    if data_format == \"channels_first\":\n        affined = jnp.transpose(affined, (0, 3, 1, 2))\n    if need_squeeze:\n        affined = jnp.squeeze(affined, axis=0)\n    return affined\n\n\ndef perspective_transform(\n    images,\n    start_points,\n    end_points,\n    interpolation=\"bilinear\",\n    fill_value=0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():\n        raise ValueError(\n            \"Invalid value for argument \`interpolation\`. Expected one of \"\n            f\"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: \"\n            f\"interpolation={interpolation}\"\n        )\n\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    if start_points.shape[-2:] != (4, 2) or start_points.ndim not in (2, 3):\n        raise ValueError(\n            \"Invalid start_points shape: expected (4,2) for a single image\"\n            f\" or (N,4,2) for a batch. Received shape: {start_points.shape}\"\n        )\n    if end_points.shape[-2:] != (4, 2) or end_points.ndim not in (2, 3):\n        raise ValueError(\n            \"Invalid end_points shape: expected (4,2) for a single image\"\n            f\" or (N,4,2) for a batch. Received shape: {end_points.shape}\"\n        )\n    if start_points.shape != end_points.shape:\n        raise ValueError(\n            \"start_points and end_points must have the same shape.\"\n            f\" Received start_points.shape={start_points.shape}, \"\n            f\"end_points.shape={end_points.shape}\"\n        )\n\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = jnp.expand_dims(images, axis=0)\n        need_squeeze = True\n\n    if len(start_points.shape) == 2:\n        start_points = jnp.expand_dims(start_points, axis=0)\n    if len(end_points.shape) == 2:\n        end_points = jnp.expand_dims(end_points, axis=0)\n\n    if data_format == \"channels_first\":\n        images = jnp.transpose(images, (0, 2, 3, 1))\n\n    _, height, width, _ = images.shape\n    transforms = compute_homography_matrix(\n        jnp.asarray(start_points, dtype=\"float32\"),\n        jnp.asarray(end_points, dtype=\"float32\"),\n    )\n\n    x, y = jnp.meshgrid(jnp.arange(width), jnp.arange(height), indexing=\"xy\")\n    grid = jnp.stack([x.ravel(), y.ravel(), jnp.ones_like(x).ravel()], axis=0)\n\n    def transform_coordinates(transform):\n        denom = transform[6] * grid[0] + transform[7] * grid[1] + 1.0\n        x_in = (\n            transform[0] * grid[0] + transform[1] * grid[1] + transform[2]\n        ) / denom\n        y_in = (\n            transform[3] * grid[0] + transform[4] * grid[1] + transform[5]\n        ) / denom\n        return jnp.stack([y_in, x_in], axis=0)\n\n    transformed_coords = jax.vmap(transform_coordinates)(transforms)\n\n    def interpolate_image(image, coords):\n        def interpolate_channel(channel_img):\n            return jax.scipy.ndimage.map_coordinates(\n                channel_img,\n                coords,\n                order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                mode=\"constant\",\n                cval=fill_value,\n            ).reshape(height, width)\n\n        return jax.vmap(interpolate_channel, in_axes=0)(\n            jnp.moveaxis(image, -1, 0)\n        )\n\n    output = jax.vmap(interpolate_image, in_axes=(0, 0))(\n        images, transformed_coords\n    )\n    output = jnp.moveaxis(output, 1, -1)\n\n    if data_format == \"channels_first\":\n        output = jnp.transpose(output, (0, 3, 1, 2))\n    if need_squeeze:\n        output = jnp.squeeze(output, axis=0)\n\n    return output\n\n\ndef compute_homography_matrix(start_points, end_points):\n    start_x, start_y = start_points[..., 0], start_points[..., 1]\n    end_x, end_y = end_points[..., 0], end_points[..., 1]\n\n    zeros = jnp.zeros_like(end_x)\n    ones = jnp.ones_like(end_x)\n\n    x_rows = jnp.stack(\n        [\n            end_x,\n            end_y,\n            ones,\n            zeros,\n            zeros,\n            zeros,\n            -start_x * end_x,\n            -start_x * end_y,\n        ],\n        axis=-1,\n    )\n    y_rows = jnp.stack(\n        [\n            zeros,\n            zeros,\n            zeros,\n            end_x,\n            end_y,\n            ones,\n            -start_y * end_x,\n            -start_y * end_y,\n        ],\n        axis=-1,\n    )\n\n    coefficient_matrix = jnp.concatenate([x_rows, y_rows], axis=1)\n\n    target_vector = jnp.expand_dims(\n        jnp.concatenate([start_x, start_y], axis=-1), axis=-1\n    )\n\n    homography_matrix = jnp.linalg.solve(coefficient_matrix, target_vector)\n\n    return homography_matrix.squeeze(-1)\n\n\ndef map_coordinates(\n    inputs, coordinates, order, fill_mode=\"constant\", fill_value=0.0\n):\n    inputs = convert_to_tensor(inputs)\n    coordinates = convert_to_tensor(coordinates)\n    if coordinates.shape[0] != len(inputs.shape):\n        raise ValueError(\n            \"First dim of \`coordinates\` must be the same as the rank of \"\n            \"\`inputs\`. \"\n            f\"Received inputs with shape: {inputs.shape} and coordinate \"\n            f\"leading dim of {coordinates.shape[0]}\"\n        )\n    if len(coordinates.shape) < 2:\n        raise ValueError(\n            \"Invalid coordinates rank: expected at least rank 2.\"\n            f\" Received input with shape: {coordinates.shape}\"\n        )\n    if fill_mode not in MAP_COORDINATES_FILL_MODES:\n        raise ValueError(\n            \"Invalid value for argument \`fill_mode\`. Expected one of \"\n            f\"{set(MAP_COORDINATES_FILL_MODES)}. Received: \"\n            f\"fill_mode={fill_mode}\"\n        )\n    if order not in range(2):\n        raise ValueError(\n            \"Invalid value for argument \`order\`. Expected one of \"\n            f\"{[0, 1]}. Received: order={order}\"\n        )\n    return jax.scipy.ndimage.map_coordinates(\n        inputs, coordinates, order, fill_mode, fill_value\n    )\n\n\ndef gaussian_blur(\n    images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None\n):\n    def _create_gaussian_kernel(kernel_size, sigma, dtype):\n        def _get_gaussian_kernel1d(size, sigma):\n            x = jnp.arange(size, dtype=dtype) - jnp.array(\n                (size - 1) / 2, dtype=dtype\n            )\n            kernel1d = jnp.exp(-0.5 * (x / sigma) ** 2)\n            return kernel1d / jnp.sum(kernel1d)\n\n        def _get_gaussian_kernel2d(size, sigma):\n            kernel1d_x = _get_gaussian_kernel1d(size[0], sigma[0])\n            kernel1d_y = _get_gaussian_kernel1d(size[1], sigma[1])\n            return jnp.outer(kernel1d_y, kernel1d_x)\n\n        kernel = _get_gaussian_kernel2d(kernel_size, sigma)[\n            jnp.newaxis, jnp.newaxis, :, :\n        ]\n        return kernel\n\n    images = convert_to_tensor(images)\n    dtype = backend.standardize_dtype(images.dtype)\n    sigma = convert_to_tensor(sigma, dtype=dtype)\n    data_format = backend.standardize_data_format(data_format)\n\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    need_squeeze = False\n    if images.ndim == 3:\n        images = images[jnp.newaxis, ...]\n        need_squeeze = True\n\n    if data_format == \"channels_last\":\n        images = jnp.transpose(images, (0, 3, 1, 2))\n\n    num_channels = images.shape[1]\n    kernel = _create_gaussian_kernel(kernel_size, sigma, dtype)\n\n    kernel = jnp.tile(kernel, (num_channels, 1, 1, 1))\n\n    blurred_images = jax.lax.conv_general_dilated(\n        images,\n        kernel,\n        window_strides=(1, 1),\n        padding=\"SAME\",\n        dimension_numbers=(\"NCHW\", \"OIHW\", \"NCHW\"),\n        feature_group_count=num_channels,\n    )\n\n    if data_format == \"channels_last\":\n        blurred_images = jnp.transpose(blurred_images, (0, 2, 3, 1))\n\n    if need_squeeze:\n        blurred_images = blurred_images.squeeze(axis=0)\n\n    return blurred_images\n\n\ndef elastic_transform(\n    images,\n    alpha=20.0,\n    sigma=5.0,\n    interpolation=\"bilinear\",\n    fill_mode=\"reflect\",\n    fill_value=0.0,\n    seed=None,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():\n        raise ValueError(\n            \"Invalid value for argument \`interpolation\`. Expected one of \"\n            f\"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: \"\n            f\"interpolation={interpolation}\"\n        )\n    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:\n        raise ValueError(\n            \"Invalid value for argument \`fill_mode\`. Expected one of \"\n            f\"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}\"\n        )\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    images = convert_to_tensor(images)\n    alpha = convert_to_tensor(alpha)\n    sigma = convert_to_tensor(sigma)\n    input_dtype = images.dtype\n    kernel_size = (int(6 * sigma) | 1, int(6 * sigma) | 1)\n\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = jnp.expand_dims(images, axis=0)\n        need_squeeze = True\n\n    if data_format == \"channels_last\":\n        batch_size, height, width, channels = images.shape\n        channel_axis = -1\n    else:\n        batch_size, channels, height, width = images.shape\n        channel_axis = 1\n\n    seed = draw_seed(seed)\n    dx = (\n        jax.random.normal(\n            seed, shape=(batch_size, height, width), dtype=input_dtype\n        )\n        * sigma\n    )\n    dy = (\n        jax.random.normal(\n            seed, shape=(batch_size, height, width), dtype=input_dtype\n        )\n        * sigma\n    )\n\n    dx = gaussian_blur(\n        jnp.expand_dims(dx, axis=channel_axis),\n        kernel_size=kernel_size,\n        sigma=(sigma, sigma),\n        data_format=data_format,\n    )\n    dy = gaussian_blur(\n        jnp.expand_dims(dy, axis=channel_axis),\n        kernel_size=kernel_size,\n        sigma=(sigma, sigma),\n        data_format=data_format,\n    )\n\n    dx = jnp.squeeze(dx)\n    dy = jnp.squeeze(dy)\n\n    x, y = jnp.meshgrid(jnp.arange(width), jnp.arange(height))\n    x, y = x[None, :, :], y[None, :, :]\n\n    distorted_x = x + alpha * dx\n    distorted_y = y + alpha * dy\n\n    transformed_images = jnp.zeros_like(images)\n\n    if data_format == \"channels_last\":\n        for i in range(channels):\n            transformed_images = transformed_images.at[..., i].set(\n                jnp.stack(\n                    [\n                        map_coordinates(\n                            images[b, ..., i],\n                            [distorted_y[b], distorted_x[b]],\n                            order=AFFINE_TRANSFORM_INTERPOLATIONS[\n                                interpolation\n                            ],\n                            fill_mode=fill_mode,\n                            fill_value=fill_value,\n                        )\n                        for b in range(batch_size)\n                    ]\n                )\n            )\n    else:\n        for i in range(channels):\n            transformed_images = transformed_images.at[:, i, :, :].set(\n                jnp.stack(\n                    [\n                        map_coordinates(\n                            images[b, i, ...],\n                            [distorted_y[b], distorted_x[b]],\n                            order=AFFINE_TRANSFORM_INTERPOLATIONS[\n                                interpolation\n                            ],\n                            fill_mode=fill_mode,\n                            fill_value=fill_value,\n                        )\n                        for b in range(batch_size)\n                    ]\n                )\n            )\n\n    if need_squeeze:\n        transformed_images = jnp.squeeze(transformed_images, axis=0)\n    transformed_images = transformed_images.astype(input_dtype)\n\n    return transformed_images\n\n\ndef scale_and_translate(\n    images,\n    output_shape,\n    scale,\n    translation,\n    spatial_dims,\n    method,\n    antialias=True,\n):\n    if method not in SCALE_AND_TRANSLATE_METHODS:\n        raise ValueError(\n            \"Invalid value for argument \`method\`. Expected one of \"\n            f\"{SCALE_AND_TRANSLATE_METHODS}. Received: method={method}\"\n        )\n    images = convert_to_tensor(images)\n    scale = convert_to_tensor(scale)\n    translation = convert_to_tensor(translation)\n    return jax.image.scale_and_translate(\n        images,\n        output_shape,\n        spatial_dims,\n        scale,\n        translation,\n        method,\n        antialias,\n    )\n"
  },
  {
    "path": "keras/src/backend/jax/layer.py",
    "content": "from keras.src.backend.config import is_nnx_enabled\n\nif is_nnx_enabled():\n    from flax import nnx\n\n    class BaseLayer(nnx.Module):\n        def __init_subclass__(cls, **kwargs):\n            super().__init_subclass__(pytree=False, **kwargs)\nelse:\n    BaseLayer = object\n\n\nclass JaxLayer(BaseLayer):\n    pass\n"
  },
  {
    "path": "keras/src/backend/jax/linalg.py",
    "content": "import jax\nimport jax.numpy as jnp\nimport jax.scipy as jsp\n\nfrom keras.src.backend import config\nfrom keras.src.backend import standardize_dtype\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.jax.core import cast\nfrom keras.src.backend.jax.core import convert_to_tensor\n\n\ndef cholesky(a, upper=False):\n    out = jnp.linalg.cholesky(a, upper=upper)\n    try:\n        # In eager mode, raise for nan to\n        # achieve behavior consistency with numpy\n        if jnp.any(jnp.isnan(out)):\n            raise ValueError(\n                \"Cholesky decomposition failed. \"\n                \"The input might not be a valid \"\n                \"positive definite matrix.\"\n            )\n    except jax.errors.TracerBoolConversionError:\n        # Cannot raise for nan in tracing mode\n        pass\n    return out\n\n\ndef cholesky_inverse(a, upper=False):\n    identity = jnp.eye(a.shape[-1], dtype=a.dtype)\n    inv_chol = solve_triangular(a, identity, lower=not upper)\n    if upper:\n        a_inv = jnp.matmul(inv_chol, jnp.transpose(inv_chol))\n    else:\n        a_inv = jnp.matmul(jnp.transpose(inv_chol), inv_chol)\n    return a_inv\n\n\ndef det(a):\n    return jnp.linalg.det(a)\n\n\ndef eig(x):\n    return jnp.linalg.eig(x)\n\n\ndef eigh(x):\n    return jnp.linalg.eigh(x)\n\n\ndef inv(a):\n    return jnp.linalg.inv(a)\n\n\ndef lu_factor(x):\n    lu_factor_fn = jsp.linalg.lu_factor\n    if x.ndim > 2:\n        for i in range(x.ndim - 2):\n            lu_factor_fn = jax.vmap(lu_factor_fn)\n\n    return lu_factor_fn(x)\n\n\ndef norm(x, ord=None, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)\n\n\ndef qr(x, mode=\"reduced\"):\n    if mode not in {\"reduced\", \"complete\"}:\n        
raise ValueError(\n            \"`mode` argument value not supported. \"\n            \"Expected one of {'reduced', 'complete'}. \"\n            f\"Received: mode={mode}\"\n        )\n    return jnp.linalg.qr(x, mode=mode)\n\n\ndef solve(a, b):\n    return jnp.linalg.solve(a, b)\n\n\ndef solve_triangular(a, b, lower=False):\n    return jsp.linalg.solve_triangular(a, b, lower=lower)\n\n\ndef svd(x, full_matrices=True, compute_uv=True):\n    return jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)\n\n\ndef lstsq(a, b, rcond=None):\n    a = convert_to_tensor(a)\n    b = convert_to_tensor(b)\n    return jnp.linalg.lstsq(a, b, rcond=rcond)[0]\n\n\ndef jvp(fun, primals, tangents, has_aux=False):\n    return jax.jvp(fun, primals, tangents, has_aux=has_aux)\n"
  },
  {
    "path": "keras/src/backend/jax/math.py",
    "content": "import math\n\nimport jax\nimport jax.numpy as jnp\n\nfrom keras.src.backend import standardize_dtype\nfrom keras.src.backend.jax.core import convert_to_tensor\nfrom keras.src.utils.module_utils import scipy\n\n\ndef segment_sum(data, segment_ids, num_segments=None, sorted=False):\n    if num_segments is None:\n        raise ValueError(\n            \"Argument `num_segments` must be set when using the JAX backend. \"\n            \"Received: num_segments=None\"\n        )\n    return jax.ops.segment_sum(\n        data, segment_ids, num_segments, indices_are_sorted=sorted\n    )\n\n\ndef segment_max(data, segment_ids, num_segments=None, sorted=False):\n    if num_segments is None:\n        raise ValueError(\n            \"Argument `num_segments` must be set when using the JAX backend. \"\n            \"Received: num_segments=None\"\n        )\n    return jax.ops.segment_max(\n        data, segment_ids, num_segments, indices_are_sorted=sorted\n    )\n\n\ndef top_k(x, k, sorted=True):\n    # Jax does not supported `sorted`, but in the case where `sorted=False`,\n    # order is not guaranteed, so OK to return sorted output.\n    return jax.lax.top_k(x, k)\n\n\ndef in_top_k(targets, predictions, k):\n    preds_at_label = jnp.take_along_axis(\n        predictions, jnp.expand_dims(targets, axis=-1), axis=-1\n    )\n    # `nan` shouldn't be considered as large probability.\n    preds_at_label = jnp.where(\n        jnp.isnan(preds_at_label), -jnp.inf, preds_at_label\n    )\n    rank = 1 + jnp.sum(jnp.greater(predictions, preds_at_label), axis=-1)\n    return jnp.less_equal(rank, k)\n\n\ndef logsumexp(x, axis=None, keepdims=False):\n    return jax.scipy.special.logsumexp(x, axis=axis, keepdims=keepdims)\n\n\ndef qr(x, mode=\"reduced\"):\n    if mode not in {\"reduced\", \"complete\"}:\n        raise ValueError(\n            \"`mode` argument value not supported. \"\n            \"Expected one of {'reduced', 'complete'}. 
\"\n            f\"Received: mode={mode}\"\n        )\n    return jnp.linalg.qr(x, mode=mode)\n\n\ndef extract_sequences(x, sequence_length, sequence_stride):\n    *batch_shape, signal_length = x.shape\n    batch_shape = list(batch_shape)\n    x = jnp.reshape(x, (math.prod(batch_shape), signal_length, 1))\n    x = jax.lax.conv_general_dilated_patches(\n        x,\n        (sequence_length,),\n        (sequence_stride,),\n        \"VALID\",\n        dimension_numbers=(\"NTC\", \"OIT\", \"NTC\"),\n    )\n    return jnp.reshape(x, (*batch_shape, *x.shape[-2:]))\n\n\ndef _get_complex_tensor_from_tuple(x):\n    if not isinstance(x, (tuple, list)) or len(x) != 2:\n        raise ValueError(\n            \"Input `x` should be a tuple of two tensors - real and imaginary.\"\n            f\"Received: x={x}\"\n        )\n    # `convert_to_tensor` does not support passing complex tensors. We separate\n    # the input out into real and imaginary and convert them separately.\n    real, imag = x\n    # Check shapes.\n    if real.shape != imag.shape:\n        raise ValueError(\n            \"Input `x` should be a tuple of two tensors - real and imaginary.\"\n            \"Both the real and imaginary parts should have the same shape. 
\"\n            f\"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}\"\n        )\n    # Ensure dtype is float.\n    if not jnp.issubdtype(real.dtype, jnp.floating) or not jnp.issubdtype(\n        imag.dtype, jnp.floating\n    ):\n        raise ValueError(\n            \"At least one tensor in input `x` is not of type float.\"\n            f\"Received: x={x}.\"\n        )\n    complex_input = jax.lax.complex(real, imag)\n    return complex_input\n\n\ndef fft(x):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    complex_output = jnp.fft.fft(complex_input)\n    return jnp.real(complex_output), jnp.imag(complex_output)\n\n\ndef fft2(x):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    complex_output = jnp.fft.fft2(complex_input)\n    return jnp.real(complex_output), jnp.imag(complex_output)\n\n\ndef ifft2(x):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    complex_output = jnp.fft.ifft2(complex_input)\n    return jnp.real(complex_output), jnp.imag(complex_output)\n\n\ndef rfft(x, fft_length=None):\n    complex_output = jnp.fft.rfft(x, n=fft_length, axis=-1, norm=\"backward\")\n    return jnp.real(complex_output), jnp.imag(complex_output)\n\n\ndef irfft(x, fft_length=None):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    return jnp.fft.irfft(complex_input, n=fft_length, axis=-1, norm=\"backward\")\n\n\ndef stft(\n    x, sequence_length, sequence_stride, fft_length, window=\"hann\", center=True\n):\n    if standardize_dtype(x.dtype) not in {\"float32\", \"float64\"}:\n        raise TypeError(\n            \"Invalid input type. Expected `float32` or `float64`. \"\n            f\"Received: input type={x.dtype}\"\n        )\n    if fft_length < sequence_length:\n        raise ValueError(\n            \"`fft_length` must equal or larger than `sequence_length`. 
\"\n            f\"Received: sequence_length={sequence_length}, \"\n            f\"fft_length={fft_length}\"\n        )\n    if isinstance(window, str):\n        if window not in {\"hann\", \"hamming\"}:\n            raise ValueError(\n                \"If a string is passed to `window`, it must be one of \"\n                f'`\"hann\"`, `\"hamming\"`. Received: window={window}'\n            )\n    x = convert_to_tensor(x)\n\n    if center:\n        pad_width = [(0, 0) for _ in range(len(x.shape))]\n        pad_width[-1] = (fft_length // 2, fft_length // 2)\n        x = jnp.pad(x, pad_width, mode=\"reflect\")\n\n    l_pad = (fft_length - sequence_length) // 2\n    r_pad = fft_length - sequence_length - l_pad\n\n    if window is not None:\n        if isinstance(window, str):\n            win = convert_to_tensor(\n                scipy.signal.get_window(window, sequence_length), dtype=x.dtype\n            )\n        else:\n            win = convert_to_tensor(window, dtype=x.dtype)\n        if len(win.shape) != 1 or win.shape[-1] != sequence_length:\n            raise ValueError(\n                \"The shape of `window` must be equal to [sequence_length].\"\n                f\"Received: window shape={win.shape}\"\n            )\n        win = jnp.pad(win, [[l_pad, r_pad]])\n    else:\n        win = jnp.ones((sequence_length + l_pad + r_pad), dtype=x.dtype)\n\n    result = jax.scipy.signal.stft(\n        x,\n        fs=1.0,\n        window=win,\n        nperseg=(sequence_length + l_pad + r_pad),\n        noverlap=(sequence_length + l_pad + r_pad - sequence_stride),\n        nfft=fft_length,\n        boundary=None,\n        padded=False,\n    )[-1]\n    # scale and swap to (..., num_sequences, fft_bins)\n    scale = jnp.sqrt(1.0 / win.sum() ** 2)\n    result = result / scale\n    result = jnp.swapaxes(result, -2, -1)\n    return jnp.real(result), jnp.imag(result)\n\n\ndef istft(\n    x,\n    sequence_length,\n    sequence_stride,\n    fft_length,\n    length=None,\n    
window=\"hann\",\n    center=True,\n):\n    x = _get_complex_tensor_from_tuple(x)\n    dtype = jnp.real(x).dtype\n\n    if len(x.shape) < 2:\n        raise ValueError(\n            f\"Input `x` must have at least 2 dimensions. \"\n            f\"Received shape: {x.shape}\"\n        )\n\n    expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1)\n    l_pad = (fft_length - sequence_length) // 2\n    r_pad = fft_length - sequence_length - l_pad\n\n    if window is not None:\n        if isinstance(window, str):\n            win = convert_to_tensor(\n                scipy.signal.get_window(window, sequence_length), dtype=dtype\n            )\n        else:\n            win = convert_to_tensor(window, dtype=dtype)\n        if len(win.shape) != 1 or win.shape[-1] != sequence_length:\n            raise ValueError(\n                \"The shape of `window` must be equal to [sequence_length].\"\n                f\"Received: window shape={win.shape}\"\n            )\n        win = jnp.pad(win, [[l_pad, r_pad]])\n    else:\n        win = jnp.ones((sequence_length + l_pad + r_pad), dtype=dtype)\n\n    x = jax.scipy.signal.istft(\n        x,\n        fs=1.0,\n        window=win,\n        nperseg=(sequence_length + l_pad + r_pad),\n        noverlap=(sequence_length + l_pad + r_pad - sequence_stride),\n        nfft=fft_length,\n        boundary=False,\n        time_axis=-2,\n        freq_axis=-1,\n    )[-1]\n\n    # scale\n    x = x / win.sum() if window is not None else x / sequence_stride\n\n    start = 0 if center is False else fft_length // 2\n    if length is not None:\n        end = start + length\n    elif center is True:\n        end = -(fft_length // 2)\n    else:\n        end = expected_output_len\n    return x[..., start:end]\n\n\ndef rsqrt(x):\n    return jax.lax.rsqrt(x)\n\n\ndef erf(x):\n    return jax.lax.erf(x)\n\n\ndef erfinv(x):\n    return jax.lax.erf_inv(x)\n\n\ndef logdet(x):\n    from keras.src.backend.jax.numpy import slogdet\n\n    # In JAX 
(like in NumPy) slogdet is more stable than\n    # `np.log(np.linalg.det(x))`. See\n    # https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html\n    return slogdet(x)[1]\n"
  },
  {
    "path": "keras/src/backend/jax/nn.py",
    "content": "import builtins\nimport inspect\nimport math\n\nimport jax\nimport jax.experimental.sparse as jax_sparse\nimport jax.numpy as jnp\nfrom absl import logging\nfrom jax import lax\nfrom jax import nn as jnn\nfrom jax.experimental.pallas.ops.tpu.splash_attention import (\n    splash_attention_kernel,\n)\nfrom jax.experimental.pallas.ops.tpu.splash_attention import (\n    splash_attention_mask,\n)\n\nfrom keras.src import backend\nfrom keras.src.backend.common.backend_utils import (\n    compute_adaptive_pooling_window_sizes,\n)\nfrom keras.src.backend.common.backend_utils import (\n    compute_conv_transpose_padding_args_for_jax,\n)\nfrom keras.src.backend.jax.core import cast\nfrom keras.src.backend.jax.core import convert_to_tensor\n\n\ndef relu(x):\n    x = convert_to_tensor(x)\n    return jnn.relu(x)\n\n\ndef relu6(x):\n    x = convert_to_tensor(x)\n    return jnn.relu6(x)\n\n\ndef sigmoid(x):\n    x = convert_to_tensor(x)\n    return jnn.sigmoid(x)\n\n\ndef sparse_sigmoid(x):\n    x = convert_to_tensor(x)\n    return jnn.sparse_sigmoid(x)\n\n\ndef tanh(x):\n    x = convert_to_tensor(x)\n    return jnn.tanh(x)\n\n\ndef tanh_shrink(x):\n    x = convert_to_tensor(x)\n    return x - jnp.tanh(x)\n\n\ndef softplus(x):\n    x = convert_to_tensor(x)\n    return jnn.softplus(x)\n\n\ndef softsign(x):\n    x = convert_to_tensor(x)\n    return jnn.soft_sign(x)\n\n\ndef soft_shrink(x, threshold=0.5):\n    x = convert_to_tensor(x)\n    return jnp.where(\n        x > threshold,\n        x - threshold,\n        jnp.where(x < -threshold, x + threshold, 0.0),\n    )\n\n\ndef sparse_plus(x):\n    x = convert_to_tensor(x)\n    return jnn.sparse_plus(x)\n\n\ndef silu(x):\n    x = convert_to_tensor(x)\n    return jnn.silu(x)\n\n\ndef squareplus(x, b=4):\n    x = convert_to_tensor(x)\n    return jnn.squareplus(x, b=b)\n\n\ndef log_sigmoid(x):\n    x = convert_to_tensor(x)\n    return jnn.log_sigmoid(x)\n\n\ndef leaky_relu(x, negative_slope=0.2):\n    x = 
convert_to_tensor(x)\n    return jnn.leaky_relu(x, negative_slope=negative_slope)\n\n\ndef hard_sigmoid(x):\n    x = convert_to_tensor(x)\n    return jnn.hard_sigmoid(x)\n\n\ndef hard_silu(x):\n    x = convert_to_tensor(x)\n    return jnn.hard_silu(x)\n\n\ndef elu(x, alpha=1.0):\n    x = convert_to_tensor(x)\n    return jnn.elu(x, alpha=alpha)\n\n\ndef selu(x):\n    x = convert_to_tensor(x)\n    return jnn.selu(x)\n\n\ndef gelu(x, approximate=True):\n    x = convert_to_tensor(x)\n    return jnn.gelu(x, approximate)\n\n\ndef celu(x, alpha=1.0):\n    x = convert_to_tensor(x)\n    return jnn.celu(x, alpha=alpha)\n\n\ndef glu(x, axis=-1):\n    x = convert_to_tensor(x)\n    return jnn.glu(x, axis=axis)\n\n\ndef hard_tanh(x):\n    x = convert_to_tensor(x)\n    return jnn.hard_tanh(x)\n\n\ndef hard_shrink(x, threshold=0.5):\n    x = convert_to_tensor(x)\n    return jnp.where(jnp.abs(x) > threshold, x, 0.0)\n\n\ndef threshold(x, threshold, default_value):\n    x = convert_to_tensor(x)\n    return jnp.where(x > threshold, x, default_value)\n\n\ndef softmax(x, axis=-1):\n    x = convert_to_tensor(x)\n    return jnn.softmax(x, axis=axis)\n\n\ndef log_softmax(x, axis=-1):\n    x = convert_to_tensor(x)\n    return jnn.log_softmax(x, axis=axis)\n\n\ndef sparsemax(x, axis=-1):\n    # Sort logits along the specified axis in descending order\n    logits = convert_to_tensor(x)\n    logits_sorted = -1.0 * jnp.sort(logits * -1.0, axis=axis)\n    logits_cumsum = jnp.cumsum(logits_sorted, axis=axis)  # find cumulative sum\n    r = jnp.arange(1, logits.shape[axis] + 1)  # Determine the sparsity\n    r_shape = [1] * logits.ndim\n    r_shape[axis] = -1  # Broadcast to match the target axis\n    r = r.reshape(r_shape)\n    support = logits_sorted - (logits_cumsum - 1) / r > 0\n    # Find the threshold\n    k = jnp.sum(support, axis=axis, keepdims=True)\n    logits_cumsum_safe = jnp.where(support, logits_cumsum, 0.0)\n    tau = (jnp.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / 
k\n    output = jnp.maximum(logits - tau, 0.0)\n    return output\n\n\ndef _convert_to_spatial_operand(\n    x,\n    num_spatial_dims,\n    data_format=\"channels_last\",\n    include_batch_and_channels=True,\n):\n    # Helper function that converts an operand to a spatial operand.\n    x = (x,) * num_spatial_dims if isinstance(x, int) else x\n    if not include_batch_and_channels:\n        return x\n    if data_format == \"channels_last\":\n        x = (1,) + x + (1,)\n    else:\n        x = (1,) + (1,) + x\n    return x\n\n\ndef _pool(\n    inputs,\n    initial_value,\n    reduce_fn,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n):\n    \"\"\"Helper function to define pooling functions.\n\n    Args:\n        inputs: input data of shape `N+2`.\n        initial_value: the initial value for the reduction.\n        reduce_fn: a reduce function of the form `(T, T) -> T`.\n        pool_size: a sequence of `N` integers, representing the window size to\n            reduce over.\n        strides: a sequence of `N` integers, representing the inter-window\n            strides (default: `(1, ..., 1)`).\n        padding: either the string `same` or `valid`.\n\n    Returns:\n        The output of the reduction for each window slice.\n    \"\"\"\n    if padding not in (\"same\", \"valid\"):\n        raise ValueError(\n            f\"Invalid padding '{padding}', must be 'same' or 'valid'.\"\n        )\n    padding = padding.upper()\n    return lax.reduce_window(\n        inputs,\n        initial_value,\n        reduce_fn,\n        pool_size,\n        strides,\n        padding,\n    )\n\n\ndef max_pool(\n    inputs,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.ndim - 2\n    pool_size = _convert_to_spatial_operand(\n        pool_size, num_spatial_dims, data_format\n    )\n    strides = pool_size if strides is None else strides\n    
strides = _convert_to_spatial_operand(\n        strides, num_spatial_dims, data_format\n    )\n    return _pool(inputs, -jnp.inf, lax.max, pool_size, strides, padding)\n\n\ndef average_pool(\n    inputs,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.ndim - 2\n    pool_size = _convert_to_spatial_operand(\n        pool_size, num_spatial_dims, data_format\n    )\n    strides = pool_size if strides is None else strides\n    strides = _convert_to_spatial_operand(\n        strides, num_spatial_dims, data_format\n    )\n\n    pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)\n    if padding == \"valid\":\n        # Avoid the extra reduce_window.\n        return pooled / math.prod(pool_size)\n    else:\n        # Count the number of valid entries at each input point, then use that\n        # for computing average. Assumes that any two arrays of same shape will\n        # be padded the same. 
Avoid broadcasting on axis where pooling is\n        # skipped.\n        shape = [\n            (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)\n        ]\n        window_counts = _pool(\n            jnp.ones(shape, inputs.dtype),\n            0.0,\n            lax.add,\n            pool_size,\n            strides,\n            padding,\n        )\n        return pooled / window_counts\n\n\ndef _compute_adaptive_pooling_gather_indices(\n    input_dim, output_size, big_window\n):\n    \"\"\"Compute gather indices for Two-Pool Gather method.\"\"\"\n    window_starts = jnp.floor(\n        (jnp.arange(output_size) * input_dim) / output_size\n    ).astype(jnp.int32)\n\n    window_ends = jnp.ceil(\n        (jnp.arange(1, output_size + 1) * input_dim) / output_size\n    ).astype(jnp.int32)\n\n    window_sizes = window_ends - window_starts\n    is_big = window_sizes == big_window\n\n    small_window = big_window - 1\n    small_len = input_dim - small_window + 1\n\n    small_indices = window_starts\n    big_indices = window_starts + small_len\n\n    gather = jnp.where(is_big, big_indices, small_indices)\n    return gather.astype(jnp.int32)\n\n\ndef _adaptive_average_pool1d(inputs, output_size, data_format=\"channels_first\"):\n    if isinstance(output_size, int):\n        output_size = (output_size,)\n\n    if data_format == \"channels_first\":\n        inputs = jnp.transpose(inputs, (0, 2, 1))  # NCL → NLC\n\n    n, l, c = inputs.shape\n    out_l = output_size[0]\n\n    small, big = compute_adaptive_pooling_window_sizes(l, out_l)\n    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)\n\n    small_pool = (\n        lax.reduce_window(\n            inputs, 0.0, lax.add, (1, small, 1), (1, 1, 1), \"valid\"\n        )\n        / small\n    )\n\n    big_pool = (\n        lax.reduce_window(inputs, 0.0, lax.add, (1, big, 1), (1, 1, 1), \"valid\")\n        / big\n    )\n\n    combined = jnp.concatenate([small_pool, big_pool], axis=1)\n    out = 
jnp.take(combined, gather, axis=1)\n\n    if data_format == \"channels_first\":\n        out = jnp.transpose(out, (0, 2, 1))\n\n    return out\n\n\ndef _adaptive_max_pool1d(inputs, output_size, data_format=\"channels_first\"):\n    if isinstance(output_size, int):\n        output_size = (output_size,)\n\n    if data_format == \"channels_first\":\n        inputs = jnp.transpose(inputs, (0, 2, 1))\n\n    n, l, c = inputs.shape\n    out_l = output_size[0]\n\n    small, big = compute_adaptive_pooling_window_sizes(l, out_l)\n    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)\n\n    small_pool = lax.reduce_window(\n        inputs, -jnp.inf, lax.max, (1, small, 1), (1, 1, 1), \"valid\"\n    )\n\n    big_pool = lax.reduce_window(\n        inputs, -jnp.inf, lax.max, (1, big, 1), (1, 1, 1), \"valid\"\n    )\n\n    combined = jnp.concatenate([small_pool, big_pool], axis=1)\n    out = jnp.take(combined, gather, axis=1)\n\n    if data_format == \"channels_first\":\n        out = jnp.transpose(out, (0, 2, 1))\n\n    return out\n\n\ndef _adaptive_average_pool2d(inputs, output_size, data_format=\"channels_first\"):\n    if isinstance(output_size, int):\n        output_size = (output_size, output_size)\n\n    if data_format == \"channels_first\":\n        inputs = jnp.transpose(inputs, (0, 2, 3, 1))\n\n    n, h, w, c = inputs.shape\n    out_h, out_w = output_size\n\n    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)\n    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)\n\n    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)\n    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)\n\n    small_h_pool = (\n        lax.reduce_window(\n            inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), \"valid\"\n        )\n        / small_h\n    )\n\n    big_h_pool = (\n        lax.reduce_window(\n            inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), \"valid\"\n        )\n        / 
big_h\n    )\n\n    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)\n    pooled_h = jnp.take(combined_h, gather_h, axis=1)\n\n    small_w_pool = (\n        lax.reduce_window(\n            pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), \"valid\"\n        )\n        / small_w\n    )\n\n    big_w_pool = (\n        lax.reduce_window(\n            pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), \"valid\"\n        )\n        / big_w\n    )\n\n    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=2)\n    out = jnp.take(combined_w, gather_w, axis=2)\n\n    if data_format == \"channels_first\":\n        out = jnp.transpose(out, (0, 3, 1, 2))\n\n    return out\n\n\ndef _adaptive_max_pool2d(inputs, output_size, data_format=\"channels_first\"):\n    if isinstance(output_size, int):\n        output_size = (output_size, output_size)\n\n    if data_format == \"channels_first\":\n        inputs = jnp.transpose(inputs, (0, 2, 3, 1))\n\n    n, h, w, c = inputs.shape\n    out_h, out_w = output_size\n\n    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)\n    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)\n\n    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)\n    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)\n\n    small_h_pool = lax.reduce_window(\n        inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), \"valid\"\n    )\n\n    big_h_pool = lax.reduce_window(\n        inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), \"valid\"\n    )\n\n    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=1)\n    pooled_h = jnp.take(combined_h, gather_h, axis=1)\n\n    small_w_pool = lax.reduce_window(\n        pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), \"valid\"\n    )\n\n    big_w_pool = lax.reduce_window(\n        pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), \"valid\"\n    )\n\n    combined_w = 
jnp.concatenate([small_w_pool, big_w_pool], axis=2)\n    out = jnp.take(combined_w, gather_w, axis=2)\n\n    if data_format == \"channels_first\":\n        out = jnp.transpose(out, (0, 3, 1, 2))\n\n    return out\n\n\ndef _adaptive_average_pool3d(inputs, output_size, data_format=\"channels_first\"):\n    if isinstance(output_size, int):\n        output_size = (output_size, output_size, output_size)\n\n    if data_format == \"channels_first\":\n        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))\n\n    n, d, h, w, c = inputs.shape\n    out_d, out_h, out_w = output_size\n\n    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)\n    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)\n\n    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)\n    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)\n\n    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)\n    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)\n\n    small_d_pool = (\n        lax.reduce_window(\n            inputs,\n            0.0,\n            lax.add,\n            (1, small_d, 1, 1, 1),\n            (1, 1, 1, 1, 1),\n            \"valid\",\n        )\n        / small_d\n    )\n\n    big_d_pool = (\n        lax.reduce_window(\n            inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), \"valid\"\n        )\n        / big_d\n    )\n\n    combined_d = jnp.concatenate([small_d_pool, big_d_pool], axis=1)\n    pooled_d = jnp.take(combined_d, gather_d, axis=1)\n\n    small_h_pool = (\n        lax.reduce_window(\n            pooled_d,\n            0.0,\n            lax.add,\n            (1, 1, small_h, 1, 1),\n            (1, 1, 1, 1, 1),\n            \"valid\",\n        )\n        / small_h\n    )\n\n    big_h_pool = (\n        lax.reduce_window(\n            pooled_d,\n            0.0,\n            lax.add,\n            (1, 1, big_h, 1, 1),\n            (1, 1, 1, 1, 1),\n            \"valid\",\n    
    )\n        / big_h\n    )\n\n    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)\n    pooled_h = jnp.take(combined_h, gather_h, axis=2)\n\n    small_w_pool = (\n        lax.reduce_window(\n            pooled_h,\n            0.0,\n            lax.add,\n            (1, 1, 1, small_w, 1),\n            (1, 1, 1, 1, 1),\n            \"valid\",\n        )\n        / small_w\n    )\n\n    big_w_pool = (\n        lax.reduce_window(\n            pooled_h,\n            0.0,\n            lax.add,\n            (1, 1, 1, big_w, 1),\n            (1, 1, 1, 1, 1),\n            \"valid\",\n        )\n        / big_w\n    )\n\n    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)\n    out = jnp.take(combined_w, gather_w, axis=3)\n\n    if data_format == \"channels_first\":\n        out = jnp.transpose(out, (0, 4, 1, 2, 3))\n\n    return out\n\n\ndef _adaptive_max_pool3d(inputs, output_size, data_format=\"channels_first\"):\n    if isinstance(output_size, int):\n        output_size = (output_size, output_size, output_size)\n\n    if data_format == \"channels_first\":\n        inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1))\n\n    n, d, h, w, c = inputs.shape\n    out_d, out_h, out_w = output_size\n\n    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)\n    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)\n\n    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)\n    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)\n\n    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)\n    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)\n\n    small_d_pool = lax.reduce_window(\n        inputs,\n        -jnp.inf,\n        lax.max,\n        (1, small_d, 1, 1, 1),\n        (1, 1, 1, 1, 1),\n        \"valid\",\n    )\n\n    big_d_pool = lax.reduce_window(\n        inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), \"valid\"\n    )\n\n    combined_d 
= jnp.concatenate([small_d_pool, big_d_pool], axis=1)\n    pooled_d = jnp.take(combined_d, gather_d, axis=1)\n\n    small_h_pool = lax.reduce_window(\n        pooled_d,\n        -jnp.inf,\n        lax.max,\n        (1, 1, small_h, 1, 1),\n        (1, 1, 1, 1, 1),\n        \"valid\",\n    )\n\n    big_h_pool = lax.reduce_window(\n        pooled_d,\n        -jnp.inf,\n        lax.max,\n        (1, 1, big_h, 1, 1),\n        (1, 1, 1, 1, 1),\n        \"valid\",\n    )\n\n    combined_h = jnp.concatenate([small_h_pool, big_h_pool], axis=2)\n    pooled_h = jnp.take(combined_h, gather_h, axis=2)\n\n    small_w_pool = lax.reduce_window(\n        pooled_h,\n        -jnp.inf,\n        lax.max,\n        (1, 1, 1, small_w, 1),\n        (1, 1, 1, 1, 1),\n        \"valid\",\n    )\n\n    big_w_pool = lax.reduce_window(\n        pooled_h,\n        -jnp.inf,\n        lax.max,\n        (1, 1, 1, big_w, 1),\n        (1, 1, 1, 1, 1),\n        \"valid\",\n    )\n\n    combined_w = jnp.concatenate([small_w_pool, big_w_pool], axis=3)\n    out = jnp.take(combined_w, gather_w, axis=3)\n\n    if data_format == \"channels_first\":\n        out = jnp.transpose(out, (0, 4, 1, 2, 3))\n\n    return out\n\n\ndef adaptive_average_pool(inputs, output_size, data_format=None):\n    data_format = backend.standardize_data_format(data_format)\n    dims = inputs.ndim - 2\n    if dims == 1:\n        return _adaptive_average_pool1d(inputs, output_size, data_format)\n    if dims == 2:\n        return _adaptive_average_pool2d(inputs, output_size, data_format)\n    if dims == 3:\n        return _adaptive_average_pool3d(inputs, output_size, data_format)\n    raise ValueError(\"adaptive_average_pool supports only 1D/2D/3D inputs\")\n\n\ndef adaptive_max_pool(inputs, output_size, data_format=None):\n    data_format = backend.standardize_data_format(data_format)\n    dims = inputs.ndim - 2\n    if dims == 1:\n        return _adaptive_max_pool1d(inputs, output_size, data_format)\n    if dims == 2:\n        return 
_adaptive_max_pool2d(inputs, output_size, data_format)\n    if dims == 3:\n        return _adaptive_max_pool3d(inputs, output_size, data_format)\n    raise ValueError(\"adaptive_max_pool supports only 1D/2D/3D inputs\")\n\n\ndef _convert_to_lax_conv_dimension_numbers(\n    num_spatial_dims,\n    data_format=\"channels_last\",\n    transpose=False,\n):\n    \"\"\"Create a `lax.ConvDimensionNumbers` for the given inputs.\"\"\"\n    num_dims = num_spatial_dims + 2\n\n    if data_format == \"channels_last\":\n        spatial_dims = tuple(range(1, num_dims - 1))\n        inputs_dn = (0, num_dims - 1) + spatial_dims\n    else:\n        spatial_dims = tuple(range(2, num_dims))\n        inputs_dn = (0, 1) + spatial_dims\n\n    if transpose:\n        kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))\n    else:\n        kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))\n\n    return lax.ConvDimensionNumbers(\n        lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn\n    )\n\n\ndef conv(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.ndim - 2\n    dimension_numbers = _convert_to_lax_conv_dimension_numbers(\n        num_spatial_dims,\n        data_format,\n        transpose=False,\n    )\n    strides = _convert_to_spatial_operand(\n        strides,\n        num_spatial_dims,\n        data_format,\n        include_batch_and_channels=False,\n    )\n    dilation_rate = _convert_to_spatial_operand(\n        dilation_rate,\n        num_spatial_dims,\n        data_format,\n        include_batch_and_channels=False,\n    )\n    if data_format == \"channels_last\":\n        channels = inputs.shape[-1]\n    else:\n        channels = inputs.shape[1]\n    kernel_in_channels = kernel.shape[-2]\n    if channels % kernel_in_channels > 0:\n        raise ValueError(\n            \"The 
number of input channels must be evenly divisible by \"\n            f\"kernel's in_channels. Received input channels {channels} and \"\n            f\"kernel in_channels {kernel_in_channels}. \"\n        )\n    feature_group_count = channels // kernel_in_channels\n    kernel = convert_to_tensor(kernel)\n    inputs = convert_to_tensor(inputs, dtype=kernel.dtype)\n    result = jax.lax.conv_general_dilated(\n        inputs,\n        kernel,\n        strides,\n        padding,\n        rhs_dilation=dilation_rate,\n        dimension_numbers=dimension_numbers,\n        feature_group_count=feature_group_count,\n    )\n    if result.size == 0:\n        raise ValueError(\n            \"The convolution operation resulted in an empty output. \"\n            \"This can happen if the input is too small for the given \"\n            \"kernel size, strides, dilation rate, and padding mode. \"\n            \"Please check the input shape and convolution parameters.\"\n        )\n    return result\n\n\ndef depthwise_conv(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.ndim - 2\n    dimension_numbers = _convert_to_lax_conv_dimension_numbers(\n        num_spatial_dims,\n        data_format,\n        transpose=False,\n    )\n    strides = _convert_to_spatial_operand(\n        strides,\n        num_spatial_dims,\n        data_format,\n        include_batch_and_channels=False,\n    )\n    dilation_rate = _convert_to_spatial_operand(\n        dilation_rate,\n        num_spatial_dims,\n        data_format,\n        include_batch_and_channels=False,\n    )\n    feature_group_count = (\n        inputs.shape[-1] if data_format == \"channels_last\" else inputs.shape[1]\n    )\n    kernel = convert_to_tensor(kernel)\n    inputs = convert_to_tensor(inputs)\n    kernel = jnp.reshape(\n        kernel,\n        kernel.shape[:-2] + (1, 
feature_group_count * kernel.shape[-1]),\n    )\n    return jax.lax.conv_general_dilated(\n        inputs,\n        kernel,\n        strides,\n        padding,\n        rhs_dilation=dilation_rate,\n        dimension_numbers=dimension_numbers,\n        feature_group_count=feature_group_count,\n    )\n\n\ndef separable_conv(\n    inputs,\n    depthwise_kernel,\n    pointwise_kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    data_format = backend.standardize_data_format(data_format)\n    depthwise_conv_output = depthwise_conv(\n        inputs,\n        depthwise_kernel,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n    )\n    return conv(\n        depthwise_conv_output,\n        pointwise_kernel,\n        strides=1,\n        padding=\"valid\",\n        data_format=data_format,\n        dilation_rate=dilation_rate,\n    )\n\n\ndef conv_transpose(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    output_padding=None,\n    data_format=None,\n    dilation_rate=1,\n):\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.ndim - 2\n    padding_values = compute_conv_transpose_padding_args_for_jax(\n        input_shape=inputs.shape,\n        kernel_shape=kernel.shape,\n        strides=strides,\n        padding=padding,\n        output_padding=output_padding,\n        dilation_rate=dilation_rate,\n    )\n    dimension_numbers = _convert_to_lax_conv_dimension_numbers(\n        num_spatial_dims,\n        data_format,\n        transpose=False,\n    )\n    strides = _convert_to_spatial_operand(\n        strides,\n        num_spatial_dims,\n        data_format,\n        include_batch_and_channels=False,\n    )\n    dilation_rate = _convert_to_spatial_operand(\n        dilation_rate,\n        num_spatial_dims,\n        data_format,\n        include_batch_and_channels=False,\n    )\n\n    return jax.lax.conv_transpose(\n        
inputs,\n        kernel,\n        strides,\n        padding=padding_values,\n        rhs_dilation=dilation_rate,\n        dimension_numbers=dimension_numbers,\n        transpose_kernel=True,\n    )\n\n\ndef one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):\n    x = convert_to_tensor(x)\n    if sparse:\n        if axis < 0:\n            axis = axis + len(x.shape) + 1\n        if dtype is None:\n            dtype = \"float32\"\n        # We deal with negative inputs by having zeros in the output although\n        # it's useless. It makes shapes static.\n        values = jnp.greater_equal(jnp.ravel(x), 0).astype(dtype)\n        values_count = values.shape[0]\n        indices = [jnp.arange(dim) for dim in x.shape]\n        indices = list(jnp.meshgrid(*indices, indexing=\"ij\"))\n        indices.insert(axis, jnp.maximum(x, 0))  # Deal with negative indices\n        indices = [a.reshape(values_count, 1).astype(\"int32\") for a in indices]\n        indices = jnp.concatenate(indices, axis=1)\n        shape = list(x.shape)\n        shape.insert(axis, num_classes)\n        shape = tuple(shape)\n        return jax_sparse.BCOO(\n            (values, indices),\n            shape=shape,\n            indices_sorted=True,\n            unique_indices=True,\n        )\n    return jnn.one_hot(x, num_classes, axis=axis, dtype=dtype)\n\n\ndef multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):\n    x = convert_to_tensor(x)\n    reduction_axis = 1 if len(x.shape) > 1 else 0\n    if sparse:\n        result = one_hot(\n            x, num_classes, axis=axis, dtype=\"int32\", sparse=sparse\n        )\n        # JAX's BCOO does not support max reduction, use sum and compare with 0.\n        result = jax_sparse.bcoo_reduce_sum(result, axes=(reduction_axis,))\n        result = jax_sparse.bcoo_sum_duplicates(result)\n        # Entries that summed to 0 (explicit zeros from negative inputs) must\n        # remain 0, so use a strict comparison.\n        values = jnp.greater(result.data, 0).astype(dtype)\n        return jax_sparse.BCOO(\n            (values, result.indices),\n            
shape=result.shape,\n            indices_sorted=True,\n            unique_indices=True,\n        )\n    return jnp.max(\n        one_hot(cast(x, \"int32\"), num_classes, axis=axis, dtype=dtype),\n        axis=reduction_axis,\n    )\n\n\ndef categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    target = jnp.array(target)\n    output = jnp.array(output)\n\n    if target.shape != output.shape:\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same shape. \"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n    if len(target.shape) < 1:\n        raise ValueError(\n            \"Arguments `target` and `output` must be at least rank 1. \"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n\n    if from_logits:\n        log_prob = jax.nn.log_softmax(output, axis=axis)\n    else:\n        output = output / jnp.sum(output, axis, keepdims=True)\n        output = jnp.clip(output, backend.epsilon(), 1.0 - backend.epsilon())\n        log_prob = jnp.log(output)\n    return -jnp.sum(target * log_prob, axis=axis)\n\n\ndef sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    target = jnp.array(target, dtype=\"int32\")\n    output = jnp.array(output)\n    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:\n        target = jnp.squeeze(target, axis=-1)\n\n    if len(output.shape) < 1:\n        raise ValueError(\n            \"Argument `output` must be at least rank 1. 
\"\n            \"Received: \"\n            f\"output.shape={output.shape}\"\n        )\n    if target.shape != output.shape[:-1]:\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same shape \"\n            \"up until the last dimension: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n    if from_logits:\n        log_prob = jax.nn.log_softmax(output, axis=axis)\n    else:\n        output = output / jnp.sum(output, axis, keepdims=True)\n        output = jnp.clip(output, backend.epsilon(), 1.0 - backend.epsilon())\n        log_prob = jnp.log(output)\n    target = jnn.one_hot(target, output.shape[axis], axis=axis)\n    return -jnp.sum(target * log_prob, axis=axis)\n\n\ndef binary_crossentropy(target, output, from_logits=False):\n    target = jnp.array(target)\n    output = jnp.array(output)\n\n    if target.shape != output.shape:\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same shape. \"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n\n    if from_logits:\n        log_logits = jax.nn.log_sigmoid(output)\n        log_neg_logits = jax.nn.log_sigmoid(-output)\n        return -1.0 * target * log_logits - (1.0 - target) * log_neg_logits\n\n    output = jnp.clip(output, backend.epsilon(), 1.0 - backend.epsilon())\n    bce = target * jnp.log(output)\n    bce += (1.0 - target) * jnp.log(1.0 - output)\n    return -bce\n\n\ndef moments(x, axes, keepdims=False, synchronized=False):\n    if synchronized:\n        raise NotImplementedError(\n            \"Argument synchronized=True is not supported with JAX.\"\n        )\n    # The dynamic range of float16 is too limited for statistics. 
As a\n    # workaround, we simply perform the operations on float32 and convert back\n    # to float16\n    need_cast = False\n    ori_dtype = backend.standardize_dtype(x.dtype)\n    if ori_dtype in (\"float16\", \"bfloat16\"):\n        need_cast = True\n        x = cast(x, \"float32\")\n\n    mean = jnp.mean(x, axis=axes, keepdims=True)\n    variance = jnp.var(x, axis=axes, keepdims=True)\n\n    if not keepdims:\n        mean = jnp.squeeze(mean, axes)\n        variance = jnp.squeeze(variance, axes)\n    if need_cast:\n        # avoid overflow and underflow when casting from float32 back to\n        # float16\n        mean = jnp.clip(\n            mean, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max\n        )\n        variance = jnp.clip(\n            variance, jnp.finfo(jnp.float16).min, jnp.finfo(jnp.float16).max\n        )\n        mean = cast(mean, ori_dtype)\n        variance = cast(variance, ori_dtype)\n    return mean, variance\n\n\ndef batch_normalization(\n    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3\n):\n    shape = [1] * len(x.shape)\n    shape[axis] = mean.shape[0]\n    mean = jnp.reshape(mean, shape)\n    variance = jnp.reshape(variance, shape)\n\n    inv = jax.lax.rsqrt(variance + epsilon)\n    if scale is not None:\n        scale = jnp.reshape(scale, shape)\n        inv = inv * scale\n\n    res = -mean * inv\n    if offset is not None:\n        offset = jnp.reshape(offset, shape)\n        res = res + offset\n\n    return jnp.add(x * inv, res)\n\n\ndef ctc_loss(target, output, target_length, output_length, mask_index=0):\n    # Ref: https://github.com/google-deepmind/optax\n    # optax.ctc_loss_with_forward_probs\n    target = convert_to_tensor(target, dtype=\"int32\")\n    output = convert_to_tensor(output)\n    target_length = convert_to_tensor(target_length, \"int32\")\n    output_length = convert_to_tensor(output_length, \"int32\")\n    batch_size, max_input_length, num_classes = output.shape\n    batch_size, max_label_length = 
target.shape\n    log_epsilon = -1e5\n\n    # Ensure that the dtype promotion behavior matches that of `tf.nn.ctc_loss`\n    dtype = backend.result_type(output.dtype, \"float32\")\n    output = cast(output, dtype)\n\n    def _lengths_to_paddings(lengths, max_length):\n        indices = jnp.arange(max_length).reshape(\n            (1,) * lengths.ndim + (max_length,)\n        )\n        lengths = jnp.expand_dims(lengths, axis=-1)\n        elem_valid = indices < lengths\n        return jnp.logical_not(elem_valid)\n\n    target_paddings = _lengths_to_paddings(target_length, max_label_length)\n    output_paddings = _lengths_to_paddings(output_length, max_input_length)\n    target_paddings = target_paddings.astype(output.dtype)\n    output_paddings = output_paddings.astype(output.dtype)\n\n    logprobs = jnn.log_softmax(output)\n    label_lengths = max_label_length - jnp.sum(target_paddings, axis=1).astype(\n        jnp.int32\n    )\n\n    # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].\n    repeat = (target[:, :-1] == target[:, 1:]).astype(jnp.float32)\n    repeat = jnp.pad(repeat, ((0, 0), (0, 1)))\n\n    logprobs_phi = logprobs[:, :, mask_index : mask_index + 1]  # [B, T, 1]\n    logprobs_phi = jnp.transpose(logprobs_phi, (1, 0, 2))  # [T, B, 1]\n\n    _one_hot = jax.nn.one_hot(\n        target, num_classes=num_classes, dtype=logprobs.dtype\n    )  # [B, N, K]\n    logprobs_emit = jnp.einsum(\"btk,bnk->btn\", logprobs, _one_hot)\n    logprobs_emit = jnp.transpose(logprobs_emit, (1, 0, 2))  # [T, B, N]\n\n    # [B, N]\n    logalpha_phi_init = (\n        jnp.ones((batch_size, max_label_length + 1), dtype=output.dtype)\n        * log_epsilon\n    )\n    logalpha_phi_init = logalpha_phi_init.at[:, 0].set(0.0)\n    logalpha_emit_init = (\n        jnp.ones((batch_size, max_label_length), dtype=output.dtype)\n        * log_epsilon\n    )\n\n    def update_phi_score(phi, added_score):\n        # Update `phi[:, 1:]` by adding `added_score` in log space.\n        
return jnp.concatenate(\n            [phi[:, :1], jnp.logaddexp(phi[:, 1:], added_score)], axis=-1\n        )\n\n    def loop_body(prev, x):\n        prev_phi, prev_emit = prev\n        # emit-to-phi epsilon transition, except if the next label is repetition\n        prev_phi_orig = prev_phi\n        prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat)\n\n        logprob_emit, logprob_phi, pad = x\n\n        # phi-to-emit transition\n        next_emit = jnp.logaddexp(\n            prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit\n        )\n        # self-loop transition\n        next_phi = prev_phi + logprob_phi\n        # emit-to-phi blank transition only when the next label is repetition\n        next_phi = update_phi_score(\n            next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)\n        )\n\n        pad = pad.reshape((batch_size, 1))\n        next_emit = pad * prev_emit + (1.0 - pad) * next_emit\n        next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi\n\n        return (next_phi, next_emit), (next_phi, next_emit)\n\n    xs = (logprobs_emit, logprobs_phi, output_paddings.transpose((1, 0)))\n    _, (logalpha_phi, logalpha_emit) = jax.lax.scan(\n        loop_body, (logalpha_phi_init, logalpha_emit_init), xs\n    )\n\n    # last row needs to be updated with the last epsilon transition\n    logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1])\n    logalpha_phi = logalpha_phi.at[-1].set(logalpha_phi_last)\n\n    # extract per_seq_loss\n    # [B, N+1]\n    _one_hot = jax.nn.one_hot(\n        label_lengths,\n        num_classes=max_label_length + 1,\n        dtype=logalpha_phi_last.dtype,\n    )\n    per_seq_loss = -jnp.einsum(\"bn,bn->b\", logalpha_phi_last, _one_hot)\n    return per_seq_loss\n\n\ndef _ctc_greedy_decode(\n    inputs,\n    sequence_lengths,\n    merge_repeated=True,\n    mask_index=None,\n):\n    inputs = convert_to_tensor(inputs)\n    sequence_lengths = 
convert_to_tensor(sequence_lengths, dtype=\"int32\")\n    batch_size, max_length, num_classes = inputs.shape\n\n    if mask_index is None:\n        mask_index = num_classes - 1\n\n    indices = jnp.argmax(inputs, axis=-1)\n    scores = jnp.max(inputs, axis=-1)\n\n    seqlen_mask = jnp.arange(max_length)[None, :]\n    seqlen_mask = seqlen_mask >= sequence_lengths[:, None]\n\n    indices = jnp.where(seqlen_mask, mask_index, indices)\n    scores = jnp.where(seqlen_mask, 0.0, scores)\n\n    if merge_repeated:\n        repeat_mask = indices[:, 1:] == indices[:, :-1]\n        repeat_mask = jnp.pad(repeat_mask, ((0, 0), (1, 0)))\n        indices = jnp.where(repeat_mask, mask_index, indices)\n\n    # We set to -1 for blank labels\n    invalid_mask = indices == mask_index\n    indices = jnp.where(invalid_mask, -1, indices)\n\n    # We rearrange the indices by moving `mask_index` to the end of the array\n    order = jnp.expand_dims(jnp.arange(max_length), axis=0)  # [1, N]\n    order = jnp.tile(order, (batch_size, 1))  # [B, N]\n    order = jnp.where(invalid_mask, max_length, order)\n    order = jnp.argsort(order, axis=-1)\n    indices = jnp.take_along_axis(indices, order, axis=-1)\n\n    scores = -jnp.sum(scores, axis=1)[:, None]\n    indices = jnp.expand_dims(indices, axis=0)\n    return indices, scores\n\n\ndef _ctc_beam_search_decode(\n    inputs,\n    sequence_lengths,\n    beam_width=100,\n    top_paths=1,\n    mask_index=None,\n):\n    inputs = convert_to_tensor(inputs)\n    sequence_lengths = convert_to_tensor(sequence_lengths)\n\n    batch_size, max_seq_len, num_classes = inputs.shape\n    inputs = jnn.log_softmax(inputs)\n    seqlen_mask = jnp.arange(max_seq_len)[None, :] >= sequence_lengths[:, None]\n\n    if mask_index is None:\n        mask_index = num_classes - 1\n\n    # This is a workaround for the fact that jnp.argsort does not support\n    # the order parameter which is used to break ties when scores are equal.\n    # For compatibility with the tensorflow 
implementation, we flip the inputs\n    # and the mask_index, and then flip the classes back to the correct indices\n    inputs = jnp.flip(inputs, axis=2)\n    mask_index = num_classes - mask_index - 1\n\n    _pad = -1\n\n    init_paths = jnp.full(\n        (batch_size, 2 * beam_width, max_seq_len), _pad, dtype=jnp.int32\n    )\n\n    num_init_paths = builtins.min(num_classes, beam_width)\n    max_classes = jnp.argsort(inputs[:, 0], axis=1)[:, -num_init_paths:]\n    init_classes = jnp.where(max_classes == mask_index, _pad, max_classes)\n    init_paths = init_paths.at[:, :num_init_paths, 0].set(init_classes)\n\n    init_scores = (\n        jnp.full((batch_size, 2 * beam_width), -jnp.inf, dtype=inputs.dtype)\n        .at[:, :num_init_paths]\n        .set(jnp.take_along_axis(inputs[:, 0], max_classes, axis=1))\n    )\n    init_masked = init_paths[:, :, 0] == _pad\n\n    def _extend_paths(paths, scores, masked, x):\n        paths = jnp.repeat(paths, num_classes, axis=0)\n        scores = jnp.repeat(scores, num_classes)\n        masked = jnp.repeat(masked, num_classes)\n\n        path_tail_index = jnp.argmax(paths == _pad, axis=1)\n        paths_arange = jnp.arange(2 * beam_width * num_classes)\n        path_tails = paths[paths_arange, path_tail_index - 1]\n        path_tails = jnp.where(path_tail_index == 0, _pad, path_tails)\n\n        classes = jnp.arange(num_classes).at[mask_index].set(_pad)\n        classes = jnp.tile(classes, 2 * beam_width)\n\n        prev_masked = masked\n        masked = classes == _pad\n\n        masked_repeat = ~prev_masked & (path_tails == classes)\n        classes = jnp.where(masked_repeat, _pad, classes)\n        paths = paths.at[paths_arange, path_tail_index].set(classes)\n\n        x = jnp.tile(x, 2 * beam_width)\n        scores = scores + x\n\n        return paths, scores, masked\n\n    def _merge_scores(unique_inverse, scores):\n        scores_max = jnp.max(scores)\n        scores_exp = jnp.exp(scores - scores_max)\n        scores = 
jnp.zeros_like(scores).at[unique_inverse].add(scores_exp)\n        scores = jnp.log(scores) + scores_max\n        return scores\n\n    def _prune_paths(paths, scores, masked):\n        paths, unique_inverse = jnp.unique(\n            paths,\n            return_inverse=True,\n            size=2 * num_classes * beam_width,\n            axis=0,\n            fill_value=_pad,\n        )\n        if len(unique_inverse.shape) >= 2:\n            unique_inverse = jnp.squeeze(unique_inverse, axis=1)\n\n        emit_scores = jnp.where(masked, -jnp.inf, scores)\n        mask_scores = jnp.where(masked, scores, -jnp.inf)\n\n        emit_scores = _merge_scores(unique_inverse, emit_scores)\n        mask_scores = _merge_scores(unique_inverse, mask_scores)\n\n        total_scores = jnp.logaddexp(emit_scores, mask_scores)\n        top_indices = jnp.argsort(total_scores)[-beam_width:]\n\n        paths = paths[top_indices]\n        emit_scores = emit_scores[top_indices]\n        mask_scores = mask_scores[top_indices]\n\n        paths = jnp.tile(paths, (2, 1))\n        scores = jnp.concatenate([emit_scores, mask_scores])\n        masked = jnp.concatenate(\n            [jnp.zeros(beam_width, bool), jnp.ones(beam_width, bool)]\n        )\n\n        return paths, scores, masked\n\n    def _decode_step(paths, scores, masked, x):\n        paths, scores, masked = _extend_paths(paths, scores, masked, x)\n        paths, scores, masked = _prune_paths(paths, scores, masked)\n        return paths, scores, masked\n\n    def _step(prev, x):\n        paths, scores, masked = prev\n        x, seqlen_mask = x\n\n        paths, scores, masked = lax.cond(\n            seqlen_mask,\n            lambda paths, scores, masked, x: (paths, scores, masked),\n            _decode_step,\n            paths,\n            scores,\n            masked,\n            x,\n        )\n\n        return (paths, scores, masked), None\n\n    def _decode_batch(\n        init_paths, init_scores, init_masked, inputs, seqlen_mask\n  
  ):\n        (paths, scores, masked), _ = lax.scan(\n            _step,\n            (init_paths, init_scores, init_masked),\n            (inputs[1:], seqlen_mask[1:]),\n        )\n\n        paths, unique_inverse = jnp.unique(\n            paths,\n            return_inverse=True,\n            size=2 * num_classes * beam_width,\n            axis=0,\n            fill_value=_pad,\n        )\n        if len(unique_inverse.shape) >= 2:\n            unique_inverse = jnp.squeeze(unique_inverse, axis=1)\n        scores = _merge_scores(unique_inverse, scores)\n\n        top_indices = jnp.argsort(scores)[-top_paths:][::-1]\n        paths = paths[top_indices]\n        scores = scores[top_indices]\n\n        return paths, scores\n\n    paths, scores = jax.vmap(_decode_batch)(\n        init_paths, init_scores, init_masked, inputs, seqlen_mask\n    )\n\n    # convert classes back to the correct indices\n    paths = jnp.where(paths == _pad, _pad, num_classes - paths - 1)\n    paths = jnp.transpose(paths, [1, 0, 2])\n    return paths, scores\n\n\ndef ctc_decode(\n    inputs,\n    sequence_lengths,\n    strategy=\"greedy\",\n    beam_width=100,\n    top_paths=1,\n    merge_repeated=True,\n    mask_index=0,\n):\n    inputs = convert_to_tensor(inputs)\n    dtype = backend.result_type(inputs.dtype, \"float32\")\n    inputs = cast(inputs, dtype)\n\n    if strategy == \"greedy\":\n        return _ctc_greedy_decode(\n            inputs,\n            sequence_lengths,\n            merge_repeated=merge_repeated,\n            mask_index=mask_index,\n        )\n    elif strategy == \"beam_search\":\n        return _ctc_beam_search_decode(\n            inputs,\n            sequence_lengths,\n            beam_width=beam_width,\n            top_paths=top_paths,\n            mask_index=mask_index,\n        )\n    else:\n        raise ValueError(\n            f\"Invalid strategy {strategy}. 
Supported values are \"\n            \"'greedy' and 'beam_search'.\"\n        )\n\n\ndef psnr(x1, x2, max_val):\n    if x1.shape != x2.shape:\n        raise ValueError(\n            f\"Input shapes {x1.shape} and {x2.shape} must \"\n            \"match for PSNR calculation. \"\n        )\n\n    max_val = convert_to_tensor(max_val, dtype=x2.dtype)\n    mse = jnp.mean(jnp.square(x1 - x2))\n    psnr = 20 * jnp.log10(max_val) - 10 * jnp.log10(mse)\n    return psnr\n\n\ndef _can_use_flash_attention(query, key, value, bias, raise_error=False):\n    \"\"\"Verify the availability of flash attention.\"\"\"\n    try:\n        from jax._src.cudnn.fused_attention_stablehlo import _normalize_layout\n        from jax._src.cudnn.fused_attention_stablehlo import (\n            check_compute_capability,\n        )\n        from jax._src.cudnn.fused_attention_stablehlo import check_cudnn_version\n        from jax._src.cudnn.fused_attention_stablehlo import (\n            check_is_flash_attention,\n        )\n        from jax._src.cudnn.fused_attention_stablehlo import check_layout\n        from jax.nn import dot_product_attention as dot_product_attention\n    except ImportError:\n        if raise_error:\n            raise ImportError(\n                \"Flash attention is not supported in your current JAX version. 
\"\n                \"Please update it by following the official guide: \"\n                \"https://jax.readthedocs.io/en/latest/installation.html\"\n            )\n        return False\n\n    if jax.devices()[0].platform == \"tpu\":\n        return True\n    try:\n        # Check if cuDNN is installed and raise RuntimeError if cuDNN is not\n        # detected\n        cudnn_version = check_cudnn_version()\n        # Only support at least Ampere\n        if not check_compute_capability(\"8.0\"):\n            raise RuntimeError(\"Require at least Ampere arch to run\")\n\n        # Inspect inputs of `check_layout`\n        check_layout_params = list(\n            inspect.signature(check_layout).parameters.keys()\n        )\n        for known_param in (\"query\", \"key\", \"value\", \"bias\", \"layout\"):\n            check_layout_params.remove(known_param)\n        # Defaults to `None` when not specified.\n        check_layout_kwargs = {key: None for key in check_layout_params}\n        check_layout(\n            query,\n            key,\n            value,\n            bias,\n            layout=_normalize_layout(\"BTNH\"),\n            **check_layout_kwargs,\n        )\n\n        # Inspect inputs of `check_is_flash_attention`\n        check_is_flash_attention_params = inspect.signature(\n            check_is_flash_attention\n        ).parameters\n        check_is_flash_attention_kwargs = {\n            \"query\": query,\n            \"key\": key,\n            \"value\": value,\n            \"layout\": _normalize_layout(\"BTNH\"),\n            \"cudnn_version\": cudnn_version,\n            \"has_bias\": bias is not None,\n            \"is_training\": False,\n        }\n        # Remove unsupported arguments\n        for param in list(check_is_flash_attention_kwargs.keys()):\n            if param not in check_is_flash_attention_params:\n                check_is_flash_attention_kwargs.pop(param)\n        check_is_flash_attention(**check_is_flash_attention_kwargs)\n   
     return True\n    except:\n        if raise_error:\n            raise\n        return False\n\n\ndef _apply_masks(logits, mask, is_causal):\n    if mask is None and not is_causal:\n        return logits\n\n    combined_mask = jnp.ones_like(logits, dtype=\"bool\")\n    if mask is not None:\n        combined_mask = jnp.logical_and(combined_mask, mask)\n\n    if is_causal:\n        T, S = logits.shape[2], logits.shape[3]\n        mask = jnp.tril(jnp.ones((T, S), dtype=\"bool\"))\n        mask = mask[None, None, :, :]\n        combined_mask = jnp.logical_and(combined_mask, mask)\n\n    large_negative_number = jnp.asarray(\n        -0.7 * jnp.finfo(logits.dtype).max, dtype=logits.dtype\n    )\n    padded_logits = jnp.where(combined_mask, logits, large_negative_number)\n    return padded_logits\n\n\ndef _dot_product_attention_core(\n    query, key, value, bias, mask, is_causal, scale\n):\n    logits_dtype = jnp.promote_types(query.dtype, jnp.float32)\n    logits = jnp.einsum(\n        \"BTNH,BSNH->BNTS\", query, key, preferred_element_type=logits_dtype\n    )\n    logits *= jnp.array(scale, dtype=logits.dtype)\n\n    if bias is not None:\n        logits = (logits + bias).astype(logits.dtype)\n\n    padded_logits = _apply_masks(logits, mask, is_causal)\n\n    # Softmax is always carried out in fp32.\n    padded_logits = padded_logits.astype(jnp.float32)\n    probs = jax.nn.softmax(padded_logits, axis=-1).astype(key.dtype)\n    return jnp.einsum(\"BNTS,BSNH->BTNH\", probs, value)\n\n\ndef wrap_flash_attention(\n    query,\n    key,\n    value,\n    decoder_segment_ids,\n    custom_mask=None,\n    attn_logits_soft_cap=None,\n    head_shards=1,\n    q_seq_shards=1,\n):\n    \"\"\"Applies a wrapped flash attention mechanism using the Splash kernel.\n\n    This function prepares the appropriate attention mask (causal or custom),\n    constructs a multi-head mask, and applies the Splash multi-head attention\n    kernel to the provided query, key, and value tensors. 
It supports optional\n    sharding and soft capping of attention logits.\n\n    Args:\n        query: jax.Array. The query tensor of shape\n            (batch, num_heads, seq_len, head_dim).\n        key: jax.Array. The key tensor of shape\n            (batch, num_heads, seq_len, head_dim).\n        value: jax.Array. The value tensor of shape\n            (batch, num_heads, seq_len, head_dim).\n        decoder_segment_ids: Optional. Segment IDs for the decoder, used for\n            sharding or masking.\n        custom_mask: Optional[jax.Array]. A custom attention mask to apply. If\n            None, a causal mask is used.\n        attn_logits_soft_cap: Optional[float]. If provided, applies a soft cap\n            to the attention logits.\n        head_shards: int, default=1. Number of shards for the attention heads.\n        q_seq_shards: int, default=1. Number of shards for the query sequence\n            dimension.\n\n    Returns:\n        jax.Array: The result of applying the Splash multi-head attention\n            kernel to the inputs.\n\n    Raises:\n        ValueError: If sharding along the sequence dimension is attempted\n            with decoder_segment_ids.\n    \"\"\"\n    if decoder_segment_ids is not None:\n        if query.shape[2] != decoder_segment_ids.q.shape[1]:\n            raise ValueError(\n                \"Sharding along sequence dimension not allowed\"\n                \" in TPU kernel attention\"\n            )\n\n    if custom_mask is not None:\n        mask = splash_attention_mask.NumpyMask(array=custom_mask)\n    else:\n        mask = splash_attention_mask.CausalMask(\n            shape=(query.shape[2], query.shape[2])\n        )\n\n    # Create multi-head mask\n    multi_head_mask = splash_attention_mask.MultiHeadMask(\n        masks=(mask,) * query.shape[1]\n    )\n    splash_kernel = splash_attention_kernel.make_splash_mha(\n        mask=multi_head_mask,\n        head_shards=head_shards,\n        q_seq_shards=q_seq_shards,\n        
attn_logits_soft_cap=attn_logits_soft_cap,\n    )\n\n    return jax.vmap(splash_kernel)(\n        query, key, value, segment_ids=decoder_segment_ids\n    )\n\n\ndef dot_product_attention(\n    query,\n    key,\n    value,\n    bias=None,\n    mask=None,\n    scale=None,\n    is_causal=False,\n    flash_attention=None,\n    attn_logits_soft_cap=None,\n):\n    \"\"\"Computes dot-product attention given query, key, and value.\n\n    This is the core computation of attention that is used in transformers.\n    For TPU platforms, flash attention optimizations are automatically applied\n    when possible, and sharding parameters are inferred from the layout map\n    in the current distribution context.\n\n    Args:\n        query: Queries with shape `[batch, time, heads,\n            depth_k]`.\n        key: Keys with shape `[batch, time, heads,\n            depth_k]`.\n        value: Values with shape `[batch, time, heads,\n            depth_v]`.\n        bias: Optional bias with shape broadcastable to\n            `[batch, heads, dest_time, source_time]`.\n        mask: Optional mask with shape broadcastable to\n            `[batch, heads, dest_time, source_time]`.\n        scale: Float. Optional scale that is applied to the attention\n            computation.\n        is_causal: Boolean. Whether causal masking is applied.\n        flash_attention: Boolean. Whether to use flash attention optimization\n            for increased performance. Defaults to `None`, which means it will\n            be auto-determined based on the platform, input shapes and\n            compatibility.\n        attn_logits_soft_cap: Float. Optional float to softly cap attention\n            logits to avoid numerical stability issues. 
Applied as:\n            `logits = logits / (1.0 + abs(logits) / attn_logits_soft_cap)`.\n\n    Returns:\n        JAX Array of shape `[batch, time, heads, depth_v]`.\n    \"\"\"\n    query = convert_to_tensor(query)\n    key = convert_to_tensor(key)\n    value = convert_to_tensor(value)\n    if len(query.shape) != 4 or len(key.shape) != 4 or len(value.shape) != 4:\n        raise ValueError(\n            \"`dot_product_attention` only supports 4D inputs. \"\n            f\"Received: query.shape={query.shape}, key.shape={key.shape}, \"\n            f\"value.shape={value.shape}.\"\n        )\n    compute_dtype = backend.result_type(query.dtype, key.dtype, value.dtype)\n    query = cast(query, compute_dtype)\n    key = cast(key, compute_dtype)\n    value = cast(value, compute_dtype)\n    if bias is not None:\n        bias = convert_to_tensor(bias, dtype=compute_dtype)\n\n    # Check platform\n    platform = jax.devices()[0].platform\n    is_tpu = platform == \"tpu\"\n\n    # Determine flash attention compatibility\n    if flash_attention is None:\n        flash_attention = _can_use_flash_attention(query, key, value, bias)\n    elif flash_attention is True:\n        # Use `raise_error=True` to provide more details on why the inputs\n        # cannot use flash attention\n        _can_use_flash_attention(query, key, value, bias, raise_error=True)\n\n    # TPU-specific flash attention path\n    if is_tpu and flash_attention:\n        # Get sharding parameters from distribution context\n        head_shards = 1\n        # Typically keep q_seq_shards=1 for best performance\n        q_seq_shards = 1\n        try:\n            from keras.src.distribution.distribution_lib import ModelParallel\n            from keras.src.distribution.distribution_lib import (\n                distribution as get_dist,\n            )\n\n            # Get current distribution if available\n            dist = get_dist()\n            if dist and isinstance(dist, ModelParallel):\n                mesh = 
dist.device_mesh\n                if \"model\" in mesh.axis_names:\n                    model_dim_index = mesh.axis_names.index(\"model\")\n                    # Set head_shards based on the model dimension of the mesh\n                    head_shards = mesh.shape[model_dim_index]\n        except (ImportError, ValueError, AttributeError):\n            # Use default values if detection fails\n            logging.exception(\n                \"Failed to determine distribution context for sharding. \"\n                \"Using default head_shards=1 and q_seq_shards=1.\"\n            )\n        # Transpose to ('batch', 'heads', 'length', 'head_dim')\n        query_tpu_layout = jnp.transpose(query, axes=(0, 2, 1, 3))\n        key_tpu_layout = jnp.transpose(key, axes=(0, 2, 1, 3))\n        value_tpu_layout = jnp.transpose(value, axes=(0, 2, 1, 3))\n\n        bs, num_heads, q_len, head_dim = query_tpu_layout.shape\n\n        # Apply scale to query if provided\n        if scale is not None:\n            # TPU kernel applies 1/sqrt(head_dim) internally, to achieve\n            # overall QK^T * scale, scale query by (scale * sqrt(head_dim))\n            query_tpu_layout = query_tpu_layout * (scale * math.sqrt(head_dim))\n\n        # Create segment IDs for Splash Attention (for packing/batching)\n        segment_ids = jnp.zeros([bs, q_len], dtype=jnp.int32)\n        decoder_segment_ids = splash_attention_kernel.SegmentIds(\n            q=segment_ids, kv=segment_ids\n        )\n\n        # Process mask for Splash Attention\n        custom_mask = None\n        if mask is not None:\n            mask_bool = mask.astype(\"bool\") if mask.dtype != jnp.bool_ else mask\n\n            if mask_bool.ndim == 3 and mask_bool.shape[0] == bs:\n                custom_mask = mask_bool[0]\n            elif mask_bool.ndim == 4 and mask_bool.shape[0] == bs:\n                custom_mask = mask_bool[0, 0]\n\n            if is_causal and custom_mask is not None:\n                causal_mask = 
jnp.tril(\n                    jnp.ones((q_len, q_len), dtype=jnp.bool_)\n                )\n                custom_mask = jnp.logical_and(custom_mask, causal_mask)\n\n        if custom_mask is None and is_causal:\n            custom_mask = jnp.tril(jnp.ones((q_len, q_len), dtype=jnp.bool_))\n\n        # Splash attention kernel requires concrete mask values for hashing.\n        # If the mask is a tracer (e.g. inside a scan/loop), we must fall back.\n        if isinstance(mask, jax.core.Tracer) or isinstance(\n            custom_mask, jax.core.Tracer\n        ):\n            flash_attention = False\n        else:\n            try:\n                output = wrap_flash_attention(\n                    query_tpu_layout,\n                    key_tpu_layout,\n                    value_tpu_layout,\n                    decoder_segment_ids=decoder_segment_ids,\n                    custom_mask=custom_mask,\n                    attn_logits_soft_cap=attn_logits_soft_cap,\n                    head_shards=head_shards,\n                    q_seq_shards=q_seq_shards,\n                )\n                # Transpose output back to Keras layout\n                return jnp.transpose(output, axes=(0, 2, 1, 3))\n            except Exception:\n                logging.exception(\n                    \"Failed to apply Splash kernel for flash attention. 
\"\n                    \"Falling back to JAX native dot_product_attention.\"\n                )\n                flash_attention = False\n\n    # JAX native dot_product_attention for GPU or fallback for TPU\n    if hasattr(jax.nn, \"dot_product_attention\"):\n        impls = [\"cudnn\", \"xla\"] if flash_attention else [\"xla\"]\n        for impl in impls:\n            try:\n                return jax.nn.dot_product_attention(\n                    query,\n                    key,\n                    value,\n                    bias=bias,\n                    mask=mask,\n                    scale=scale,\n                    is_causal=is_causal,\n                    implementation=impl,\n                )\n            except Exception:\n                logging.exception(\n                    f\"Failed to apply {impl} implementation of \"\n                    \"jax.nn.dot_product_attention.\"\n                )\n\n    if flash_attention:\n        raise RuntimeError(\n            \"Flash attention is not supported in your current JAX version. 
\"\n            \"Please update it by following the official guide: \"\n            \"https://jax.readthedocs.io/en/latest/installation.html\"\n        )\n    # Ref: jax.nn.dot_product_attention\n    # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886\n    # Not support `query_seq_lengths` and `key_value_seq_lengths` args\n\n    # Fallback to custom XLA implementation\n    # This is the reference implementation from jax.nn.dot_product_attention\n    output_shape = query.shape\n    _, _, K, H = key.shape\n    scale = (1.0 / jnp.sqrt(H)) if scale is None else scale\n\n    # _dot_product_attention_xla\n    B, T, N, H = query.shape\n    G = N // K\n    query = jnp.reshape(query, (B, T, K, G, H))\n\n    def _reshape_to_grouped(t, t_name):\n        if t is not None:\n            while t.ndim < 4:\n                if t.ndim == 3 and t.shape[1] == N:\n                    t = jnp.expand_dims(t, axis=2)\n                else:\n                    t = jnp.expand_dims(t, axis=1)\n            tB, tN, tT, tS = t.shape\n            if tN == 1:\n                t = jnp.broadcast_to(t[:, :, None, :, :], (tB, tN, G, tT, tS))\n            else:\n                if tN != N:\n                    raise ValueError(\n                        f\"Expected `{t_name}` to have shape (B, 1, T, S) or \"\n                        f\"(B, N, T, S) with N={N} but got {t.shape}.\"\n                    )\n                t = jnp.reshape(t, (tB, K, G, tT, tS))\n        return t\n\n    bias = _reshape_to_grouped(bias, \"bias\")\n    mask = _reshape_to_grouped(mask, \"mask\")\n    vmapped_fn = jax.vmap(\n        _dot_product_attention_core,\n        in_axes=(3, None, None, 2, 2, None, None),\n        out_axes=3,\n    )\n    encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale)\n    return jnp.reshape(encoded, output_shape)\n\n\ndef unfold(input, kernel_size, dilation=1, padding=0, stride=1):\n    \"\"\"JAX implementation of Unfold.\n    Extract sliding local 
blocks from a **NCHW** batched image tensor.\n\n    Args:\n        input: 4-D tensor, shape (N, C, H, W)  **required**.\n        kernel_size: int or (kH, kW)\n        dilation: int or (dH, dW), default 1\n        padding: int or (pH, pW), default 0\n        stride: int or (sH, sW), default 1\n\n    Returns:\n        3-D tensor, shape (N, C*kH*kW, L)\n    \"\"\"\n\n    def _pair(x):\n        return (x, x) if isinstance(x, int) else x\n\n    k = _pair(kernel_size)\n    d = _pair(dilation)\n    p = _pair(padding)\n    s = _pair(stride)\n\n    N, C, H, W = input.shape\n\n    # ---- padding ----\n    if any(v > 0 for v in p):\n        input = jnp.pad(input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])))\n\n    patches = lax.conv_general_dilated_patches(\n        input,\n        filter_shape=k,\n        window_strides=s,\n        padding=\"VALID\",  # input already padded above\n        rhs_dilation=d,\n        dimension_numbers=(\"NCHW\", \"OIHW\", \"NCHW\"),  # only supports 'NCHW'\n    )  # shape: (N, C*kH*kW, oH, oW)\n\n    # ---- reshape -> (N, C*kH*kW, L) ----\n    _, CKK, oH, oW = patches.shape\n    return patches.reshape(N, CKK, oH * oW)\n\n\ndef fold(x, output_size, kernel_size, dilation=1, padding=0, stride=1):\n    \"\"\"JAX implementation of Fold (col2im).\n\n    Combine an array of sliding local blocks into a large tensor.\n\n    Uses ``lax.conv_transpose`` with an identity kernel so that the\n    entire operation is JIT-compilable and runs on XLA.\n\n    Args:\n        x: 3-D tensor, shape (N, C*kH*kW, L)  **required**.\n        output_size: int or (oH, oW)\n        kernel_size: int or (kH, kW)\n        dilation: int or (dH, dW), default 1\n        padding: int or (pH, pW), default 0\n        stride: int or (sH, sW), default 1\n\n    Returns:\n        4-D tensor, shape (N, C, oH, oW)\n    \"\"\"\n\n    def _pair(val):\n        return (val, val) if isinstance(val, int) else val\n\n    oH, oW = _pair(output_size)\n    kH, kW = _pair(kernel_size)\n    dH, dW = _pair(dilation)\n    
pH, pW = _pair(padding)\n    sH, sW = _pair(stride)\n\n    N, CKK, L = x.shape\n    C = CKK // (kH * kW)\n\n    # Number of output patches along each dimension\n    nH = (oH + 2 * pH - dH * (kH - 1) - 1) // sH + 1\n    nW = (oW + 2 * pW - dW * (kW - 1) - 1) // sW + 1\n\n    # Reshape: (N, C*kH*kW, L) -> (N, C*kH*kW, nH, nW)\n    x = jnp.reshape(x, (N, CKK, nH, nW))\n\n    # Identity kernel: maps each (c, i, j) input channel to output\n    # channel c at kernel position (i, j).\n    # eye(CKK) -> (CKK, CKK) -> reshape (CKK, C, kH, kW) ->\n    # transpose to HWIO: (kH, kW, CKK, C)\n    kernel = jnp.eye(CKK, dtype=x.dtype)\n    kernel = kernel.reshape(CKK, C, kH, kW)\n    kernel = kernel.transpose(2, 3, 0, 1)  # -> (kH, kW, CKK, C)\n    # conv_transpose flips the kernel spatially, so pre-flip to cancel\n    kernel = jnp.flip(kernel, axis=(0, 1))\n\n    # Padded output size\n    oH_pad = oH + 2 * pH\n    oW_pad = oW + 2 * pW\n\n    # conv_transpose with padding=\"VALID\" produces output of size:\n    #   (nH - 1) * sH + (kH - 1) * dH + 1  (= oH_pad)\n    output = lax.conv_transpose(\n        x,\n        kernel,\n        strides=(sH, sW),\n        padding=\"VALID\",\n        rhs_dilation=(dH, dW),\n        dimension_numbers=(\"NCHW\", \"HWIO\", \"NCHW\"),\n    )\n\n    # Remove padding\n    if pH > 0 or pW > 0:\n        output = output[:, :, pH : oH_pad - pH, pW : oW_pad - pW]\n\n    return output\n\n\ndef depth_to_space(x, block_size, data_format=\"channels_last\"):\n    \"\"\"JAX implementation of depth_to_space (pixel shuffle).\n\n    Rearranges data from depth into blocks of spatial data.\n\n    Args:\n        x: 4-D tensor with shape (N, H, W, C) for channels_last or\n            (N, C, H, W) for channels_first.\n        block_size: An integer specifying the block size.\n        data_format: \"channels_last\" or \"channels_first\".\n\n    Returns:\n        A tensor with shape (N, H*block_size, W*block_size, C/block_size**2)\n        for channels_last or (N, 
C/block_size**2, H*block_size, W*block_size)\n        for channels_first.\n    \"\"\"\n    if data_format == \"channels_last\":\n        # NHWC format\n        n, h, w, c = x.shape\n        new_c = c // (block_size**2)\n        # Reshape: (N, H, W, C) -> (N, H, W, block_size, block_size, new_C)\n        x = jnp.reshape(x, (n, h, w, block_size, block_size, new_c))\n        # Transpose to (N, H, bH, W, bW, new_C) to interleave spatial blocks.\n        x = jnp.transpose(x, (0, 1, 3, 2, 4, 5))\n        # Reshape to the final spatial dimensions.\n        x = jnp.reshape(x, (n, h * block_size, w * block_size, new_c))\n    else:\n        # NCHW format\n        n, c, h, w = x.shape\n        new_c = c // (block_size**2)\n        # Reshape: (N, C, H, W) -> (N, new_C, block_size, block_size, H, W)\n        x = jnp.reshape(x, (n, new_c, block_size, block_size, h, w))\n        # Transpose: (N, C, bH, bW, H, W) -> (N, C, H, bH, W, bW)\n        x = jnp.transpose(x, (0, 1, 4, 2, 5, 3))\n        # Reshape: (N, C, H, bH, W, bW) -> (N, C, H*bH, W*bW)\n        x = jnp.reshape(x, (n, new_c, h * block_size, w * block_size))\n    return x\n\n\ndef space_to_depth(x, block_size, data_format=\"channels_last\"):\n    \"\"\"JAX implementation of space_to_depth (pixel unshuffle).\n\n    Rearranges blocks of spatial data into depth.\n\n    Args:\n        x: 4-D tensor with shape (N, H, W, C) for channels_last or\n            (N, C, H, W) for channels_first.\n        block_size: An integer specifying the block size.\n        data_format: \"channels_last\" or \"channels_first\".\n\n    Returns:\n        A tensor with shape (N, H/block_size, W/block_size, C*block_size**2)\n        for channels_last or (N, C*block_size**2, H/block_size, W/block_size)\n        for channels_first.\n    \"\"\"\n    if data_format == \"channels_last\":\n        # NHWC format\n        n, h, w, c = x.shape\n        new_h = h // block_size\n        new_w = w // block_size\n        # Reshape: (N, H, W, C) -> (N, new_H, bH, 
new_W, bW, C)\n        x = jnp.reshape(x, (n, new_h, block_size, new_w, block_size, c))\n        # Transpose: -> (N, new_H, new_W, bH, bW, C)\n        x = jnp.transpose(x, (0, 1, 3, 2, 4, 5))\n        # Reshape: -> (N, new_H, new_W, C*bH*bW)\n        x = jnp.reshape(x, (n, new_h, new_w, c * block_size**2))\n    else:\n        # NCHW format\n        n, c, h, w = x.shape\n        new_h = h // block_size\n        new_w = w // block_size\n        # Reshape: (N, C, H, W) -> (N, C, new_H, bH, new_W, bW)\n        x = jnp.reshape(x, (n, c, new_h, block_size, new_w, block_size))\n        # Transpose: -> (N, C, bH, bW, new_H, new_W)\n        x = jnp.transpose(x, (0, 1, 3, 5, 2, 4))\n        # Reshape: -> (N, C*bH*bW, new_H, new_W)\n        x = jnp.reshape(x, (n, c * block_size**2, new_h, new_w))\n    return x\n"
  },
  {
    "path": "keras/src/backend/jax/numpy.py",
    "content": "import builtins\nimport math\n\nimport jax\nimport jax.experimental.sparse as jax_sparse\nimport jax.numpy as jnp\nfrom jax import export as jax_export\n\nfrom keras.src.backend import config\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common.backend_utils import canonicalize_axis\nfrom keras.src.backend.common.backend_utils import to_tuple_or_list\nfrom keras.src.backend.common.variables import standardize_dtype\nfrom keras.src.backend.jax import nn\nfrom keras.src.backend.jax import sparse\nfrom keras.src.backend.jax.core import cast\nfrom keras.src.backend.jax.core import convert_to_tensor\n\n\ndef _uses_cpu(x):\n    if hasattr(x, \"device\"):\n        device = x.device\n        if not isinstance(device, jax.Device):\n            # Array is sharded.\n            return False\n        return device.platform == \"cpu\"\n    else:\n        # This is a Tracer, not a concrete Array.\n        return jax.default_backend() == \"cpu\"\n\n\ndef rot90(array, k=1, axes=(0, 1)):\n    \"\"\"Rotate an array by 90 degrees in the specified plane.\"\"\"\n    if array.ndim < 2:\n        raise ValueError(\n            f\"Input array must have at least 2 dimensions. \"\n            f\"Received: array.ndim={array.ndim}\"\n        )\n    if len(axes) != 2 or axes[0] == axes[1]:\n        raise ValueError(\n            f\"Invalid axes: {axes}. 
Axes must be a tuple of \"\n            \"two different dimensions.\"\n        )\n    return jnp.rot90(array, k=k, axes=axes)\n\n\n@sparse.elementwise_binary_union(linear=True, use_sparsify=True)\ndef add(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.add(x1, x2)\n\n\ndef bartlett(x):\n    x = convert_to_tensor(x)\n    return cast(jnp.bartlett(x), config.floatx())\n\n\ndef hamming(x):\n    x = convert_to_tensor(x)\n    return cast(jnp.hamming(x), config.floatx())\n\n\ndef hanning(x):\n    x = convert_to_tensor(x)\n    return cast(jnp.hanning(x), config.floatx())\n\n\ndef heaviside(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.heaviside(x1, x2)\n\n\ndef hypot(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.hypot(x1, x2)\n\n\ndef kaiser(x, beta):\n    x = convert_to_tensor(x)\n    return cast(jnp.kaiser(x, beta), config.floatx())\n\n\ndef bincount(x, weights=None, minlength=0, sparse=False):\n    # Note: bincount is never traceable / jittable because the output shape\n    # depends on the values in x.\n    if sparse or isinstance(x, jax_sparse.BCOO):\n        if isinstance(x, jax_sparse.BCOO):\n            if weights is not None:\n                if not isinstance(weights, jax_sparse.BCOO):\n                    raise ValueError(\"`x` and `weights` must both be BCOOs\")\n                if x.indices is not weights.indices:\n                    # This test works in eager mode only\n                    if not jnp.all(jnp.equal(x.indices, weights.indices)):\n                        raise ValueError(\n                            \"`x` and `weights` BCOOs must have the same indices\"\n                        )\n                weights = weights.data\n            x = x.data\n        reduction_axis = 1 if len(x.shape) > 1 else 0\n        maxlength = jnp.maximum(jnp.max(x) + 1, minlength)\n        one_hot_encoding = nn.one_hot(x, maxlength, 
sparse=True)\n        if weights is not None:\n            expanded_weights = jnp.expand_dims(weights, reduction_axis + 1)\n            one_hot_encoding = one_hot_encoding * expanded_weights\n\n        outputs = jax_sparse.bcoo_reduce_sum(\n            one_hot_encoding,\n            axes=(reduction_axis,),\n        )\n        return outputs\n    if len(x.shape) == 2:\n        if weights is None:\n\n            def bincount_fn(arr):\n                return jnp.bincount(arr, minlength=minlength)\n\n            bincounts = list(map(bincount_fn, x))\n        else:\n\n            def bincount_fn(arr_w):\n                return jnp.bincount(\n                    arr_w[0], weights=arr_w[1], minlength=minlength\n                )\n\n            bincounts = list(map(bincount_fn, zip(x, weights)))\n\n        return jnp.stack(bincounts)\n    return jnp.bincount(x, weights=weights, minlength=minlength)\n\n\ndef einsum(subscripts, *operands, **kwargs):\n    operands = [convert_to_tensor(x) for x in operands]\n    # When all operands are of int8, specifying `preferred_element_type` as\n    # int32 to enable hardware-accelerated einsum\n    dtypes = list(set(standardize_dtype(x.dtype) for x in operands))\n    if len(dtypes) == 1 and dtypes[0] == \"int8\":\n        preferred_element_type = \"int32\"\n    else:\n        preferred_element_type = None\n    kwargs[\"preferred_element_type\"] = preferred_element_type\n    return jnp.einsum(subscripts, *operands, **kwargs)\n\n\n@sparse.elementwise_binary_union(linear=True, use_sparsify=True)\ndef subtract(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.subtract(x1, x2)\n\n\ndef matmul(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    # When both x1 and x2 are of int8, specifying `preferred_element_type` as\n    # int32 to enable hardware-accelerated matmul\n    x1_dtype = standardize_dtype(x1.dtype)\n    x2_dtype = standardize_dtype(x2.dtype)\n    if x1_dtype == 
\"int8\" and x2_dtype == \"int8\":\n        preferred_element_type = \"int32\"\n    else:\n        preferred_element_type = None\n    if isinstance(x1, jax_sparse.JAXSparse) or isinstance(\n        x2, jax_sparse.JAXSparse\n    ):\n        if not hasattr(matmul, \"sparse_matmul\"):\n            matmul.sparse_matmul = jax_sparse.sparsify(jnp.matmul)\n        if isinstance(x1, jax_sparse.BCOO):\n            x1 = jax_sparse.bcoo_update_layout(\n                x1, n_batch=len(x1.shape) - 2, on_inefficient=\"warn\"\n            )\n        if isinstance(x2, jax_sparse.BCOO):\n            x2 = jax_sparse.bcoo_update_layout(\n                x2, n_batch=len(x2.shape) - 2, on_inefficient=\"warn\"\n            )\n        return matmul.sparse_matmul(\n            x1, x2, preferred_element_type=preferred_element_type\n        )\n\n    return jnp.matmul(x1, x2, preferred_element_type=preferred_element_type)\n\n\ndef multiply(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    if isinstance(x1, jax_sparse.BCOO):\n        if isinstance(x2, jax_sparse.BCOO):\n            # x1 is sparse, x2 is sparse.\n            if x1.indices is x2.indices:\n                # `bcoo_multiply_sparse` will not detect that the indices are\n                # the same, optimize this case here.\n                if not x1.unique_indices:\n                    x1 = jax_sparse.bcoo_sum_duplicates(x1)\n                    x2 = jax_sparse.bcoo_sum_duplicates(x2)\n                return jax_sparse.BCOO(\n                    (jnp.multiply(x1.data, x2.data), x1.indices),\n                    shape=x1.shape,\n                    indices_sorted=True,\n                    unique_indices=True,\n                )\n            else:\n                return jax_sparse.bcoo_multiply_sparse(x1, x2)\n        else:\n            # x1 is sparse, x2 is dense.\n            out_data = jax_sparse.bcoo_multiply_dense(x1, x2)\n            return jax_sparse.BCOO(\n                (out_data, x1.indices),\n 
               shape=x1.shape,\n                indices_sorted=x1.indices_sorted,\n                unique_indices=x1.unique_indices,\n            )\n    elif isinstance(x2, jax_sparse.BCOO):\n        # x1 is dense, x2 is sparse.\n        out_data = jax_sparse.bcoo_multiply_dense(x2, x1)\n        return jax_sparse.BCOO(\n            (out_data, x2.indices),\n            shape=x2.shape,\n            indices_sorted=x2.indices_sorted,\n            unique_indices=x2.unique_indices,\n        )\n    return jnp.multiply(x1, x2)\n\n\ndef mean(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    # `jnp.mean` does not handle low precision (e.g., float16) overflow\n    # correctly, so we compute with float32 and cast back to the original type.\n    compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        result_dtype = compute_dtype\n    else:\n        result_dtype = ori_dtype\n    if isinstance(x, jax_sparse.BCOO):\n        if axis is None:\n            axis = tuple(range(len(x.shape)))\n        (\n            canonical_axis,\n            keep_dims_shape,\n            broadcast_dimensions,\n        ) = sparse.axis_shape_dims_for_broadcast_in_dim(\n            axis, x.shape, insert_dims=False\n        )\n        divisor = math.prod(x.shape[i] for i in canonical_axis)\n        output = jax_sparse.bcoo_reduce_sum(x, axes=canonical_axis)\n        output = jax_sparse.BCOO(\n            (output.data.astype(result_dtype) / divisor, output.indices),\n            shape=output.shape,\n        )\n        if keepdims:\n            # `bcoo_reduce_sum` does not support keepdims, neither does\n            # sparsify(jnp.sum), so we recreate the empty dimensions.\n            output = jax_sparse.bcoo_broadcast_in_dim(\n                output,\n                shape=keep_dims_shape,\n                broadcast_dimensions=broadcast_dimensions,\n            )\n        return 
output\n    else:\n        output = jnp.mean(x, axis=axis, keepdims=keepdims, dtype=compute_dtype)\n        return cast(output, result_dtype)\n\n\ndef max(x, axis=None, keepdims=False, initial=None):\n    x = convert_to_tensor(x)\n    return jnp.max(x, axis=axis, keepdims=keepdims, initial=initial)\n\n\ndef ones(shape, dtype=None):\n    dtype = dtype or config.floatx()\n    return jnp.ones(shape, dtype=dtype)\n\n\ndef zeros(shape, dtype=None):\n    dtype = dtype or config.floatx()\n    return jnp.zeros(shape, dtype=dtype)\n\n\n@sparse.elementwise_unary(linear=False)\ndef absolute(x):\n    x = convert_to_tensor(x)\n    return jnp.absolute(x)\n\n\ndef abs(x):\n    return absolute(x)\n\n\ndef all(x, axis=None, keepdims=False):\n    return jnp.all(x, axis=axis, keepdims=keepdims)\n\n\ndef angle(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.angle(x)\n\n\ndef any(x, axis=None, keepdims=False):\n    return jnp.any(x, axis=axis, keepdims=keepdims)\n\n\ndef amax(x, axis=None, keepdims=False):\n    return jnp.amax(x, axis=axis, keepdims=keepdims)\n\n\ndef amin(x, axis=None, keepdims=False):\n    return jnp.amin(x, axis=axis, keepdims=keepdims)\n\n\ndef append(x1, x2, axis=None):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.append(x1, x2, axis=axis)\n\n\ndef arange(start, stop=None, step=None, dtype=None):\n    def get_dtype(x):\n        if hasattr(x, \"dtype\"):\n            return x.dtype\n        if jax_export.is_symbolic_dim(x):\n            return int\n        return type(x)\n\n    if dtype is None:\n        dtypes_to_resolve = [get_dtype(start)]\n        if stop is not None:\n            dtypes_to_resolve.append(get_dtype(stop))\n        if step is not None:\n            dtypes_to_resolve.append(get_dtype(step))\n        dtype = 
dtypes.result_type(*dtypes_to_resolve)\n    dtype = standardize_dtype(dtype)\n    return jnp.arange(start, stop, step=step, dtype=dtype)\n\n\n@sparse.densifying_unary\ndef arccos(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.arccos(x)\n\n\n@sparse.densifying_unary\ndef arccosh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.arccosh(x)\n\n\n@sparse.elementwise_unary(linear=False)\ndef arcsin(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.arcsin(x)\n\n\n@sparse.elementwise_unary(linear=False)\ndef arcsinh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.arcsinh(x)\n\n\n@sparse.elementwise_unary(linear=False)\ndef arctan(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.arctan(x)\n\n\ndef arctan2(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    x1 = cast(x1, dtype)\n    x2 = cast(x2, dtype)\n    return jnp.arctan2(x1, x2)\n\n\n@sparse.elementwise_unary(linear=False)\ndef arctanh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = 
dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.arctanh(x)\n\n\ndef argmax(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    if \"float\" not in dtype or x.ndim == 0 or not _uses_cpu(x):\n        return jnp.argmax(x, axis=axis, keepdims=keepdims)\n\n    # Fix the flush-to-zero (FTZ) issue based on this issue:\n    # https://github.com/jax-ml/jax/issues/24280\n    dtype = dtypes.result_type(dtype, \"float32\")\n    x = cast(x, dtype)\n    is_negative_zero = (x == 0.0) & jnp.signbit(x)\n    x = jnp.where(is_negative_zero, -jnp.finfo(x.dtype).tiny, x)\n    return jnp.argmax(x, axis=axis, keepdims=keepdims)\n\n\ndef argmin(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    if \"float\" not in dtype or x.ndim == 0 or not _uses_cpu(x):\n        return jnp.argmin(x, axis=axis, keepdims=keepdims)\n\n    # Fix the flush-to-zero (FTZ) issue based on this issue:\n    # https://github.com/jax-ml/jax/issues/24280\n    dtype = dtypes.result_type(dtype, \"float32\")\n    x = cast(x, dtype)\n    is_negative_zero = (x == 0.0) & jnp.signbit(x)\n    x = jnp.where(is_negative_zero, -jnp.finfo(x.dtype).tiny, x)\n    return jnp.argmin(x, axis=axis, keepdims=keepdims)\n\n\ndef argsort(x, axis=-1):\n    x = convert_to_tensor(x)\n    if x.ndim == 0:\n        return jnp.argsort(x, axis=None)\n    return jnp.argsort(x, axis=axis)\n\n\ndef array(x, dtype=None):\n    return jnp.array(x, dtype=dtype)\n\n\ndef view(x, dtype=None):\n    x = convert_to_tensor(x)\n    return x.view(dtype=dtype)\n\n\ndef average(x, axis=None, weights=None):\n    x = convert_to_tensor(x)\n    dtypes_to_resolve = [x.dtype, float]\n    if weights is not None:\n        weights = convert_to_tensor(weights)\n        dtypes_to_resolve.append(weights.dtype)\n    dtype = dtypes.result_type(*dtypes_to_resolve)\n    x = cast(x, dtype)\n    if weights is not None:\n        weights = 
cast(weights, dtype)\n    return jnp.average(x, weights=weights, axis=axis)\n\n\ndef bitwise_and(x, y):\n    x = convert_to_tensor(x)\n    y = convert_to_tensor(y)\n    return jnp.bitwise_and(x, y)\n\n\ndef bitwise_invert(x):\n    x = convert_to_tensor(x)\n    return jnp.invert(x)\n\n\ndef bitwise_not(x):\n    return bitwise_invert(x)\n\n\ndef bitwise_or(x, y):\n    x = convert_to_tensor(x)\n    y = convert_to_tensor(y)\n    return jnp.bitwise_or(x, y)\n\n\ndef bitwise_xor(x, y):\n    x = convert_to_tensor(x)\n    y = convert_to_tensor(y)\n    return jnp.bitwise_xor(x, y)\n\n\ndef bitwise_left_shift(x, y):\n    x = convert_to_tensor(x)\n    if not isinstance(y, int):\n        y = convert_to_tensor(y)\n    return jnp.left_shift(x, y)\n\n\ndef left_shift(x, y):\n    return bitwise_left_shift(x, y)\n\n\ndef bitwise_right_shift(x, y):\n    x = convert_to_tensor(x)\n    if not isinstance(y, int):\n        y = convert_to_tensor(y)\n    return jnp.right_shift(x, y)\n\n\ndef right_shift(x, y):\n    return bitwise_right_shift(x, y)\n\n\ndef blackman(x):\n    x = convert_to_tensor(x)\n    return cast(jnp.blackman(x), config.floatx())\n\n\ndef broadcast_to(x, shape):\n    x = convert_to_tensor(x)\n    return jnp.broadcast_to(x, shape)\n\n\ndef cbrt(x):\n    x = convert_to_tensor(x)\n    return jnp.cbrt(x)\n\n\n@sparse.elementwise_unary(linear=False)\ndef ceil(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.ceil(x)\n\n\ndef clip(x, x_min, x_max):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"bool\":\n        x = cast(x, \"int32\")\n    return jnp.clip(x, x_min, x_max)\n\n\ndef concatenate(xs, axis=0):\n    bcoo_count = builtins.sum(isinstance(x, jax_sparse.BCOO) for x in xs)\n    if bcoo_count == len(xs):\n        axis = canonicalize_axis(axis, len(xs[0].shape))\n        return 
jax_sparse.bcoo_concatenate(xs, dimension=axis)\n    elif bcoo_count:\n        xs = [\n            x.todense()\n            if isinstance(x, jax_sparse.JAXSparse)\n            else convert_to_tensor(x)\n            for x in xs\n        ]\n    else:\n        xs = [convert_to_tensor(x) for x in xs]\n    return jnp.concatenate(xs, axis=axis)\n\n\n@sparse.elementwise_unary(linear=True)\ndef conjugate(x):\n    x = convert_to_tensor(x)\n    return jnp.conjugate(x)\n\n\n@sparse.elementwise_unary(linear=True)\ndef conj(x):\n    x = convert_to_tensor(x)\n    return jnp.conjugate(x)\n\n\n@sparse.elementwise_unary(linear=True)\ndef copy(x):\n    x = convert_to_tensor(x)\n    return jnp.copy(x)\n\n\n@sparse.densifying_unary\ndef cos(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.cos(x)\n\n\n@sparse.densifying_unary\ndef cosh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.cosh(x)\n\n\ndef count_nonzero(x, axis=None):\n    return cast(jnp.count_nonzero(x, axis=axis), \"int32\")\n\n\ndef cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.cross(\n        x1,\n        x2,\n        axisa=axisa,\n        axisb=axisb,\n        axisc=axisc,\n        axis=axis,\n    )\n\n\ndef cumprod(x, axis=None, dtype=None):\n    x = convert_to_tensor(x)\n    return jnp.cumprod(x, axis=axis, dtype=dtype)\n\n\ndef cumsum(x, axis=None, dtype=None):\n    x = convert_to_tensor(x)\n    return jnp.cumsum(x, axis=axis, dtype=dtype)\n\n\ndef deg2rad(x):\n    x = convert_to_tensor(x)\n    return jnp.deg2rad(x)\n\n\ndef diag(x, k=0):\n    x = convert_to_tensor(x)\n    return jnp.diag(x, 
k=k)\n\n\ndef diagflat(x, k=0):\n    x = convert_to_tensor(x)\n    return jnp.diagflat(x, k=k)\n\n\ndef diagonal(x, offset=0, axis1=0, axis2=1):\n    x = convert_to_tensor(x)\n    return jnp.diagonal(\n        x,\n        offset=offset,\n        axis1=axis1,\n        axis2=axis2,\n    )\n\n\ndef diff(a, n=1, axis=-1):\n    a = convert_to_tensor(a)\n    return jnp.diff(a, n=n, axis=axis)\n\n\n@sparse.elementwise_unary(linear=False)\ndef digitize(x, bins):\n    x = convert_to_tensor(x)\n    bins = convert_to_tensor(bins)\n    return jnp.digitize(x, bins)\n\n\ndef dot(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.dot(x1, x2)\n\n\ndef dstack(xs):\n    return jnp.dstack(xs)\n\n\ndef empty(shape, dtype=None):\n    dtype = dtype or config.floatx()\n    return jnp.empty(shape, dtype=dtype)\n\n\ndef empty_like(x, dtype=None):\n    return jnp.empty_like(x, dtype=dtype)\n\n\ndef equal(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.equal(x1, x2)\n\n\n@sparse.densifying_unary\ndef exp(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = cast(x, config.floatx())\n    return jnp.exp(x)\n\n\n@sparse.densifying_unary\ndef exp2(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = cast(x, config.floatx())\n    return jnp.exp2(x)\n\n\ndef expand_dims(x, axis):\n    x = convert_to_tensor(x)\n    if isinstance(x, jax_sparse.BCOO):\n        (\n            _,\n            result_shape,\n            broadcast_dimensions,\n        ) = sparse.axis_shape_dims_for_broadcast_in_dim(\n            axis, x.shape, insert_dims=True\n        )\n        return jax_sparse.bcoo_broadcast_in_dim(\n            x, shape=result_shape, broadcast_dimensions=broadcast_dimensions\n        )\n    return jnp.expand_dims(x, 
axis)\n\n\n@sparse.elementwise_unary(linear=False)\ndef expm1(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = cast(x, config.floatx())\n    return jnp.expm1(x)\n\n\ndef flip(x, axis=None):\n    return jnp.flip(x, axis=axis)\n\n\n@sparse.elementwise_unary(linear=False)\ndef floor(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.floor(x)\n\n\ndef full(shape, fill_value, dtype=None):\n    dtype = dtype or config.floatx()\n    return jnp.full(shape, fill_value, dtype=dtype)\n\n\ndef full_like(x, fill_value, dtype=None):\n    return jnp.full_like(x, fill_value, dtype=dtype)\n\n\ndef gcd(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.gcd(x1, x2)\n\n\ndef geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):\n    return jnp.geomspace(\n        start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis\n    )\n\n\ndef greater(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.greater(x1, x2)\n\n\ndef greater_equal(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.greater_equal(x1, x2)\n\n\ndef hstack(xs):\n    return jnp.hstack(xs)\n\n\ndef hsplit(x, indices_or_sections):\n    x = convert_to_tensor(x)\n    return jnp.hsplit(x, indices_or_sections)\n\n\ndef identity(n, dtype=None):\n    dtype = dtype or config.floatx()\n    return jnp.identity(n, dtype=dtype)\n\n\n@sparse.elementwise_unary(linear=True)\ndef imag(x):\n    x = convert_to_tensor(x)\n    return jnp.imag(x)\n\n\ndef isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.isclose(x1, x2, rtol, atol, equal_nan)\n\n\ndef allclose(x1, x2, 
rtol=1e-5, atol=1e-8, equal_nan=False):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.allclose(x1, x2, rtol, atol, equal_nan)\n\n\n@sparse.densifying_unary\ndef isfinite(x):\n    x = convert_to_tensor(x)\n    return jnp.isfinite(x)\n\n\ndef isin(x1, x2, assume_unique=False, invert=False):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.isin(x1, x2, assume_unique=assume_unique, invert=invert)\n\n\n@sparse.elementwise_unary(linear=False)\ndef isinf(x):\n    x = convert_to_tensor(x)\n    return jnp.isinf(x)\n\n\n@sparse.elementwise_unary(linear=False)\ndef isnan(x):\n    x = convert_to_tensor(x)\n    return jnp.isnan(x)\n\n\ndef isneginf(x):\n    x = convert_to_tensor(x)\n    return jnp.isneginf(x)\n\n\ndef isposinf(x):\n    x = convert_to_tensor(x)\n    return jnp.isposinf(x)\n\n\ndef isreal(x):\n    x = convert_to_tensor(x)\n    return jnp.isreal(x)\n\n\ndef kron(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.kron(x1, x2)\n\n\ndef lcm(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.lcm(x1, x2)\n\n\ndef ldexp(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:\n        raise TypeError(\n            f\"ldexp exponent must be an integer type. 
\"\n            f\"Received: x2 dtype={x2.dtype}\"\n        )\n\n    return jnp.ldexp(x1, x2)\n\n\ndef less(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.less(x1, x2)\n\n\ndef less_equal(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.less_equal(x1, x2)\n\n\ndef linspace(\n    start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0\n):\n    return jnp.linspace(\n        start,\n        stop,\n        num=num,\n        endpoint=endpoint,\n        retstep=retstep,\n        dtype=dtype,\n        axis=axis,\n    )\n\n\n@sparse.densifying_unary\ndef log(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        x = cast(x, config.floatx())\n    return jnp.log(x)\n\n\n@sparse.densifying_unary\ndef log10(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        x = cast(x, config.floatx())\n    return jnp.log10(x)\n\n\n@sparse.elementwise_unary(linear=False)\ndef log1p(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        x = cast(x, config.floatx())\n    return jnp.log1p(x)\n\n\n@sparse.densifying_unary\ndef log2(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        x = cast(x, config.floatx())\n    return jnp.log2(x)\n\n\ndef logaddexp(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    x1 = cast(x1, dtype)\n    x2 = cast(x2, dtype)\n    return jnp.logaddexp(x1, x2)\n\n\ndef logaddexp2(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    x1 = cast(x1, dtype)\n    x2 = cast(x2, dtype)\n    return jnp.logaddexp2(x1, x2)\n\n\ndef logical_and(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.logical_and(x1, x2)\n\n\ndef 
logical_not(x):\n    x = convert_to_tensor(x)\n    return jnp.logical_not(x)\n\n\ndef logical_or(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.logical_or(x1, x2)\n\n\ndef logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):\n    return jnp.logspace(\n        start,\n        stop,\n        num=num,\n        endpoint=endpoint,\n        base=base,\n        dtype=dtype,\n        axis=axis,\n    )\n\n\n@sparse.elementwise_binary_union(linear=False, use_sparsify=False)\ndef maximum(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.maximum(x1, x2)\n\n\ndef median(x, axis=None, keepdims=False):\n    # axis of jnp.median must be hashable\n    if isinstance(axis, list):\n        axis = tuple(axis)\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        x = cast(x, config.floatx())\n\n    result = jnp.median(x, axis=axis, keepdims=keepdims)\n\n    # TODO: with jax < 0.4.26 jnp.median failed to keepdims when axis is None\n    if keepdims is True and axis is None:\n        while result.ndim < x.ndim:\n            result = jnp.expand_dims(result, axis=-1)\n    return result\n\n\ndef meshgrid(*x, indexing=\"xy\"):\n    return jnp.meshgrid(*x, indexing=indexing)\n\n\ndef min(x, axis=None, keepdims=False, initial=None):\n    x = convert_to_tensor(x)\n    return jnp.min(x, axis=axis, keepdims=keepdims, initial=initial)\n\n\n@sparse.elementwise_binary_union(linear=False, use_sparsify=False)\ndef minimum(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.minimum(x1, x2)\n\n\ndef mod(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.mod(x1, x2)\n\n\ndef fmod(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.fmod(x1, x2)\n\n\ndef moveaxis(x, source, destination):\n    return jnp.moveaxis(x, source=source, destination=destination)\n\n\ndef 
nanargmax(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    return jnp.nanargmax(x, axis=axis, keepdims=keepdims)\n\n\ndef nanargmin(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    return jnp.nanargmin(x, axis=axis, keepdims=keepdims)\n\n\ndef nancumsum(x, axis=None, dtype=None):\n    x = convert_to_tensor(x)\n    return jnp.nancumsum(x, axis=axis, dtype=dtype)\n\n\ndef nancumprod(x, axis=None, dtype=None):\n    x = convert_to_tensor(x)\n    return jnp.nancumprod(x, axis=axis, dtype=dtype)\n\n\ndef nanmax(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    return jnp.nanmax(x, axis=axis, keepdims=keepdims)\n\n\ndef nanmean(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    return jnp.nanmean(x, axis=axis, keepdims=keepdims)\n\n\ndef nanmin(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    return jnp.nanmin(x, axis=axis, keepdims=keepdims)\n\n\ndef nanprod(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    return jnp.nanprod(x, axis=axis, keepdims=keepdims)\n\n\ndef nanstd(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    return jnp.nanstd(x, axis=axis, keepdims=keepdims)\n\n\ndef nansum(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    return jnp.nansum(x, axis=axis, keepdims=keepdims)\n\n\ndef nanvar(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    return jnp.nanvar(x, axis=axis, keepdims=keepdims)\n\n\ndef nan_to_num(x, nan=0.0, posinf=None, neginf=None):\n    x = convert_to_tensor(x)\n    return jnp.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)\n\n\ndef ndim(x):\n    return jnp.ndim(x)\n\n\ndef nonzero(x):\n    return jnp.nonzero(x)\n\n\ndef not_equal(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.not_equal(x1, x2)\n\n\ndef ones_like(x, dtype=None):\n    return jnp.ones_like(x, dtype=dtype)\n\n\ndef zeros_like(x, dtype=None):\n    return jnp.zeros_like(x, 
dtype=dtype)\n\n\ndef outer(x1, x2):\n    return jnp.outer(x1, x2)\n\n\ndef pad(x, pad_width, mode=\"constant\", constant_values=None):\n    x = convert_to_tensor(x)\n    kwargs = {}\n    if constant_values is not None:\n        if mode != \"constant\":\n            raise ValueError(\n                \"Argument `constant_values` can only be \"\n                \"provided when `mode == 'constant'`. \"\n                f\"Received: mode={mode}\"\n            )\n        kwargs[\"constant_values\"] = constant_values\n    return jnp.pad(x, pad_width, mode=mode, **kwargs)\n\n\ndef prod(x, axis=None, keepdims=False, dtype=None):\n    x = convert_to_tensor(x)\n    return jnp.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)\n\n\ndef ptp(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    return jnp.ptp(x, axis=axis, keepdims=keepdims)\n\n\ndef quantile(x, q, axis=None, method=\"linear\", keepdims=False):\n    x = convert_to_tensor(x)\n    q = convert_to_tensor(q)\n    if standardize_dtype(x.dtype) == \"int64\":\n        x = cast(x, config.floatx())\n\n    result = jnp.quantile(x, q, axis=axis, method=method, keepdims=keepdims)\n\n    # TODO: with jax < 0.4.26 jnp.quantile failed to keepdims when axis is None\n    if keepdims is True and axis is None:\n        result_ndim = x.ndim + (1 if len(q.shape) > 0 else 0)\n        while result.ndim < result_ndim:\n            result = jnp.expand_dims(result, axis=-1)\n    return result\n\n\ndef ravel(x):\n    x = convert_to_tensor(x)\n    return jnp.ravel(x)\n\n\ndef unravel_index(indices, shape):\n    indices = convert_to_tensor(indices)\n    return jnp.unravel_index(indices, shape)\n\n\n@sparse.elementwise_unary(linear=True)\ndef real(x):\n    x = convert_to_tensor(x)\n    return jnp.real(x)\n\n\n@sparse.densifying_unary\ndef reciprocal(x):\n    x = convert_to_tensor(x)\n    return jnp.reciprocal(x)\n\n\ndef repeat(x, repeats, axis=None):\n    x = convert_to_tensor(x)\n    return jnp.repeat(x, repeats, 
axis=axis)\n\n\ndef reshape(x, newshape):\n    if isinstance(x, jax_sparse.BCOO):\n        from keras.src.ops import operation_utils\n\n        # Resolve the -1 in `new_shape` if applicable and possible\n        output_shape = operation_utils.compute_reshape_output_shape(\n            x.shape, newshape, \"new_shape\"\n        )\n        if None not in output_shape:\n            newshape = output_shape\n        return jax_sparse.bcoo_reshape(x, new_sizes=newshape)\n    x = convert_to_tensor(x)\n    return jnp.reshape(x, newshape)\n\n\ndef roll(x, shift, axis=None):\n    return jnp.roll(x, shift, axis=axis)\n\n\ndef searchsorted(sorted_sequence, values, side=\"left\"):\n    if ndim(sorted_sequence) != 1:\n        raise ValueError(\n            \"`searchsorted` only supports 1-D sorted sequences. \"\n            \"You can use `keras.ops.vectorized_map` \"\n            \"to extend it to N-D sequences. Received: \"\n            f\"sorted_sequence.shape={sorted_sequence.shape}\"\n        )\n    return jnp.searchsorted(sorted_sequence, values, side=side)\n\n\n@sparse.elementwise_unary(linear=False)\ndef sign(x):\n    x = convert_to_tensor(x)\n    return jnp.sign(x)\n\n\n@sparse.elementwise_unary(linear=False)\ndef signbit(x):\n    x = convert_to_tensor(x)\n    return jnp.signbit(x)\n\n\n@sparse.elementwise_unary(linear=False)\ndef sin(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.sin(x)\n\n\ndef sinc(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.sinc(x)\n\n\n@sparse.elementwise_unary(linear=False)\ndef sinh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    
else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.sinh(x)\n\n\ndef size(x):\n    return jnp.size(x)\n\n\ndef sort(x, axis=-1):\n    x = convert_to_tensor(x)\n    return jnp.sort(x, axis=axis)\n\n\ndef split(x, indices_or_sections, axis=0):\n    x = convert_to_tensor(x)\n    return jnp.split(x, indices_or_sections, axis=axis)\n\n\ndef array_split(x, indices_or_sections, axis=0):\n    x = convert_to_tensor(x)\n    return jnp.array_split(x, indices_or_sections, axis=axis)\n\n\ndef stack(x, axis=0):\n    x = [convert_to_tensor(t) for t in x]\n    return jnp.stack(x, axis=axis)\n\n\ndef std(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        x = cast(x, config.floatx())\n    return jnp.std(x, axis=axis, keepdims=keepdims)\n\n\ndef swapaxes(x, axis1, axis2):\n    x = convert_to_tensor(x)\n    return jnp.swapaxes(x, axis1=axis1, axis2=axis2)\n\n\ndef take(x, indices, axis=None):\n    x = convert_to_tensor(x)\n    indices = convert_to_tensor(indices, sparse=False)\n    return jnp.take(x, indices, axis=axis)\n\n\ndef take_along_axis(x, indices, axis=None):\n    x = convert_to_tensor(x)\n    indices = convert_to_tensor(indices, sparse=False)\n    return jnp.take_along_axis(x, indices, axis=axis)\n\n\n@sparse.elementwise_unary(linear=False)\ndef tan(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.tan(x)\n\n\n@sparse.elementwise_unary(linear=False)\ndef tanh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return jnp.tanh(x)\n\n\ndef tensordot(x1, x2, axes=2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return 
jnp.tensordot(x1, x2, axes=axes)\n\n\n@sparse.elementwise_unary(linear=False)\ndef round(x, decimals=0):\n    x = convert_to_tensor(x)\n\n    # jnp.round doesn't support decimals < 0 for integers\n    x_dtype = standardize_dtype(x.dtype)\n    if \"int\" in x_dtype and decimals < 0:\n        factor = cast(math.pow(10, decimals), config.floatx())\n        x = cast(x, config.floatx())\n        x = jnp.multiply(x, factor)\n        x = jnp.round(x)\n        x = jnp.divide(x, factor)\n        return cast(x, x_dtype)\n    else:\n        return jnp.round(x, decimals=decimals)\n\n\ndef tile(x, repeats):\n    return jnp.tile(x, repeats)\n\n\ndef trace(x, offset=0, axis1=0, axis2=1):\n    x = convert_to_tensor(x)\n    return jnp.trace(x, offset=offset, axis1=axis1, axis2=axis2)\n\n\ndef tri(N, M=None, k=0, dtype=None):\n    dtype = dtype or config.floatx()\n    return jnp.tri(N, M=M, k=k, dtype=dtype)\n\n\ndef tril(x, k=0):\n    x = convert_to_tensor(x)\n    return jnp.tril(x, k=k)\n\n\ndef triu(x, k=0):\n    x = convert_to_tensor(x)\n    return jnp.triu(x, k=k)\n\n\ndef trunc(x):\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    if \"int\" in dtype or \"bool\" == dtype:\n        return x\n    return jnp.trunc(x)\n\n\ndef vdot(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.vdot(x1, x2)\n\n\ndef inner(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.inner(x1, x2)\n\n\ndef vstack(xs):\n    return jnp.vstack(xs)\n\n\ndef vsplit(x, indices_or_sections):\n    x = convert_to_tensor(x)\n    return jnp.vsplit(x, indices_or_sections)\n\n\ndef vectorize(pyfunc, *, excluded=None, signature=None):\n    if excluded is None:\n        excluded = set()\n    return jnp.vectorize(pyfunc, excluded=excluded, signature=signature)\n\n\ndef where(condition, x1=None, x2=None):\n    return jnp.where(condition, x1, x2)\n\n\n@sparse.elementwise_division\ndef divide(x1, x2):\n    x1 = 
convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.divide(x1, x2)\n\n\ndef divide_no_nan(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    safe_x2 = jnp.where(x2 == 0, 1, x2)\n    return jnp.where(x2 == 0, 0, jnp.divide(x1, safe_x2))\n\n\ndef true_divide(x1, x2):\n    return divide(x1, x2)\n\n\ndef power(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.power(x1, x2)\n\n\n@sparse.elementwise_unary(linear=True)\ndef negative(x):\n    x = convert_to_tensor(x)\n    return jnp.negative(x)\n\n\ndef nextafter(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.nextafter(x1, x2)\n\n\n@sparse.elementwise_unary(linear=False)\ndef square(x):\n    x = convert_to_tensor(x)\n    return jnp.square(x)\n\n\n@sparse.elementwise_unary(linear=False)\ndef sqrt(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        x = cast(x, config.floatx())\n    return jnp.sqrt(x)\n\n\ndef squeeze(x, axis=None):\n    if isinstance(x, jax_sparse.BCOO):\n        if axis is None:\n            axis = tuple(i for i, d in enumerate(x.shape) if d == 1)\n        axis = to_tuple_or_list(axis)\n        return jax_sparse.bcoo_squeeze(x, dimensions=axis)\n    x = convert_to_tensor(x)\n    return jnp.squeeze(x, axis=axis)\n\n\ndef transpose(x, axes=None):\n    x = convert_to_tensor(x)\n    if isinstance(x, jax_sparse.BCOO):\n        num_dims = len(x.shape)\n        if axes is None:\n            permutation = tuple(range(num_dims)[::-1])\n        else:\n            permutation = []\n            for a in axes:\n                a = canonicalize_axis(a, num_dims)\n                permutation.append(a)\n        return jax_sparse.bcoo_transpose(x, permutation=permutation)\n    return jnp.transpose(x, axes=axes)\n\n\ndef trapezoid(y, x=None, dx=1.0, axis=-1):\n    y = convert_to_tensor(y)\n    if x is not None:\n        x = convert_to_tensor(x)\n    dx = 
convert_to_tensor(dx)\n    return jnp.trapezoid(y, x, dx=dx, axis=axis)\n\n\ndef vander(x, N=None, increasing=False):\n    x = convert_to_tensor(x)\n    return jnp.vander(x, N=N, increasing=increasing)\n\n\ndef var(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    # `jnp.var` does not handle low precision (e.g., float16) overflow\n    # correctly, so we compute with float32 and cast back to the original type.\n    compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n    result_dtype = dtypes.result_type(x.dtype, float)\n    return cast(\n        jnp.var(x, axis=axis, keepdims=keepdims, dtype=compute_dtype),\n        result_dtype,\n    )\n\n\ndef sum(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    if isinstance(x, jax_sparse.BCOO):\n        if axis is None:\n            axis = tuple(range(len(x.shape)))\n        (\n            canonical_axis,\n            keep_dims_shape,\n            broadcast_dimensions,\n        ) = sparse.axis_shape_dims_for_broadcast_in_dim(\n            axis, x.shape, insert_dims=False\n        )\n        output = jax_sparse.bcoo_reduce_sum(x, axes=canonical_axis)\n        if keepdims:\n            # `bcoo_reduce_sum` does not support keepdims, neither does\n            # sparsify(jnp.sum), so we recreate the empty dimensions.\n            output = jax_sparse.bcoo_broadcast_in_dim(\n                output,\n                shape=keep_dims_shape,\n                broadcast_dimensions=broadcast_dimensions,\n            )\n        return output\n    return jnp.sum(x, axis=axis, keepdims=keepdims)\n\n\ndef eye(N, M=None, k=0, dtype=None):\n    dtype = dtype or config.floatx()\n    return jnp.eye(N, M=M, k=k, dtype=dtype)\n\n\ndef floor_divide(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.floor_divide(x1, x2)\n\n\ndef logical_xor(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.logical_xor(x1, x2)\n\n\ndef corrcoef(x):\n    x = 
convert_to_tensor(x)\n    return jnp.corrcoef(x)\n\n\ndef correlate(x1, x2, mode=\"valid\"):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return jnp.correlate(x1, x2, mode=mode)\n\n\ndef select(condlist, choicelist, default=0):\n    return jnp.select(condlist, choicelist, default=default)\n\n\ndef slogdet(x):\n    x = convert_to_tensor(x)\n    # `jnp.linalg.slogdet` returns a namedtuple; unpack it to a plain tuple.\n    return tuple(jnp.linalg.slogdet(x))\n\n\ndef argpartition(x, kth, axis=-1):\n    return jnp.argpartition(x, kth, axis=axis)\n\n\ndef histogram(x, bins=10, range=None):\n    return jnp.histogram(x, bins=bins, range=range)\n"
  },
  {
    "path": "keras/src/backend/jax/optimizer.py",
    "content": "\"\"\"A class for JAX specific optimizer logic.\n\nIts purpose is to route around statelessness\nrequirements in cond ops used for EMA handling\nand gradient accumulation handling. We do this\nby skipping conditionals entirely.\n\"\"\"\n\nimport jax\nfrom jax import numpy as jnp\n\nfrom keras.src.optimizers import base_optimizer\n\n\nclass JaxOptimizer(base_optimizer.BaseOptimizer):\n    def _backend_apply_gradients(self, grads, trainable_variables):\n        if self.gradient_accumulation_steps:\n            is_update_step = (\n                self._iterations + 1\n            ) % self.gradient_accumulation_steps == 0\n            steps = self.gradient_accumulation_steps\n\n            current_trainable_vars_value = [\n                v.value for v in trainable_variables\n            ]\n            current_optimizer_vars_value = [v.value for v in self.variables]\n\n            # `trainable_variables` might have been filtered in previous\n            # processing steps, so we need to ensure the correct mapping between\n            # `self._accumulated_gradients` and `trainable_variables`\n            acc_grads = [\n                self._accumulated_gradients[self._get_variable_index(v)]\n                for v in trainable_variables\n            ]\n\n            new_g_accs = jax.lax.cond(\n                is_update_step,\n                lambda: [jnp.zeros(g.shape, dtype=g.dtype) for g in acc_grads],\n                lambda: [g + acc_g.value for g, acc_g in zip(grads, acc_grads)],\n            )\n\n            grads = jax.lax.cond(\n                is_update_step,\n                lambda: [\n                    (g + acc_g.value) / steps\n                    for g, acc_g in zip(grads, acc_grads)\n                ],\n                lambda: list(grads),\n            )\n\n            # Apply clipping and weight decay.\n            grads = self._clip_gradients(grads)\n            self._apply_weight_decay(trainable_variables)\n\n            
self._backend_update_step(\n                grads, trainable_variables, self.learning_rate\n            )\n            new_trainable_vars = jax.lax.cond(\n                is_update_step,\n                lambda: [v.value for v in trainable_variables],\n                lambda: current_trainable_vars_value,\n            )\n            new_opt_vars = jax.lax.cond(\n                is_update_step,\n                lambda: [v.value for v in self.variables],\n                lambda: current_optimizer_vars_value,\n            )\n\n            for value, v in zip(new_trainable_vars, trainable_variables):\n                v.assign(value)\n\n            for value, v in zip(new_opt_vars, self.variables):\n                v.assign(value)\n\n            for n_g_acc, g_acc in zip(new_g_accs, acc_grads):\n                g_acc.assign(n_g_acc)\n\n        else:\n            # Apply clipping and weight decay.\n            grads = self._clip_gradients(grads)\n            self._apply_weight_decay(trainable_variables)\n\n            self._backend_update_step(\n                grads, trainable_variables, self.learning_rate\n            )\n\n        if self.use_ema:\n            self._update_model_variables_moving_average(\n                self._trainable_variables\n            )\n            if self.ema_overwrite_frequency is not None:\n                should_overwrite_model_vars = (\n                    self.iterations + 1\n                ) % self.ema_overwrite_frequency == 0\n                should_overwrite_model_vars_int = (\n                    should_overwrite_model_vars.astype(\"int32\")\n                )\n                should_not_overwrite_model_vars_int = jnp.logical_not(\n                    should_overwrite_model_vars\n                ).astype(\"int32\")\n                current_trainable_vars_value = [\n                    v.value for v in self._trainable_variables\n                ]\n                for var, average_var in zip(\n                    
self._trainable_variables,\n                    self._model_variables_moving_average,\n                ):\n                    # Branchless select: exactly one of the two int masks is\n                    # 1, so this either overwrites the variable with its\n                    # moving average or leaves it unchanged.\n                    var.assign(\n                        average_var * should_overwrite_model_vars_int\n                        + var.value * should_not_overwrite_model_vars_int\n                    )\n"
  },
  {
    "path": "keras/src/backend/jax/random.py",
    "content": "import jax\n\nfrom keras.src.backend.config import floatx\nfrom keras.src.random.seed_generator import SeedGenerator\nfrom keras.src.random.seed_generator import draw_seed\nfrom keras.src.random.seed_generator import make_default_seed\n\n\ndef jax_draw_seed(seed):\n    if isinstance(seed, jax.Array):\n        return seed\n    else:\n        return draw_seed(seed)\n\n\ndef normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = jax_draw_seed(seed)\n    sample = jax.random.normal(seed, shape=shape, dtype=dtype)\n    return sample * stddev + mean\n\n\ndef uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = jax_draw_seed(seed)\n    return jax.random.uniform(\n        seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval\n    )\n\n\ndef categorical(logits, num_samples, dtype=\"int32\", seed=None):\n    seed = jax_draw_seed(seed)\n    output_shape = list(logits.shape)\n    output_shape[1] = num_samples\n    output_shape = tuple(output_shape)\n    output = jax.random.categorical(\n        seed, logits[..., None], shape=output_shape, axis=1\n    )\n    return output.astype(dtype)\n\n\ndef randint(shape, minval, maxval, dtype=\"int32\", seed=None):\n    seed = jax_draw_seed(seed)\n    return jax.random.randint(\n        seed, shape=shape, dtype=dtype, minval=minval, maxval=maxval\n    )\n\n\ndef truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = jax_draw_seed(seed)\n    sample = jax.random.truncated_normal(\n        seed, shape=shape, lower=-2.0, upper=2.0, dtype=dtype\n    )\n    return sample * stddev + mean\n\n\ndef _get_concrete_noise_shape(inputs, noise_shape):\n    if noise_shape is None:\n        return inputs.shape\n\n    concrete_inputs_shape = inputs.shape\n    concrete_noise_shape = []\n    for i, value in enumerate(noise_shape):\n        concrete_noise_shape.append(\n            
concrete_inputs_shape[i] if value is None else value\n        )\n    return concrete_noise_shape\n\n\ndef dropout(inputs, rate, noise_shape=None, seed=None):\n    if rate == 1.0:\n        return jax.numpy.zeros_like(inputs)\n    if rate == 0.0:\n        return inputs\n    seed = jax_draw_seed(seed)\n    keep_prob = 1.0 - rate\n    # The `noise_shape` may contain `None` so we need to convert it\n    # into a concrete shape before passing it on to jax.\n    noise_shape = _get_concrete_noise_shape(inputs, noise_shape)\n    mask = jax.random.bernoulli(seed, p=keep_prob, shape=noise_shape)\n    mask = jax.numpy.broadcast_to(mask, inputs.shape)\n    return jax.lax.select(\n        mask, inputs / keep_prob, jax.numpy.zeros_like(inputs)\n    )\n\n\ndef shuffle(x, axis=0, seed=None):\n    seed = jax_draw_seed(seed)\n    return jax.random.permutation(seed, x, axis, independent=True)\n\n\ndef gamma(shape, alpha, dtype=None, seed=None):\n    seed = jax_draw_seed(seed)\n    dtype = dtype or floatx()\n    return jax.random.gamma(seed, alpha, shape=shape, dtype=dtype)\n\n\ndef binomial(shape, counts, probabilities, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = jax_draw_seed(seed)\n    # jax doesn't accept python lists as arguments\n    counts = jax.numpy.array(counts)\n    probabilities = jax.numpy.array(probabilities)\n    sample = jax.random.binomial(\n        key=seed, n=counts, p=probabilities, shape=shape, dtype=dtype\n    )\n    return sample\n\n\ndef beta(shape, alpha, beta, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = jax_draw_seed(seed)\n    # jax doesn't accept python lists as arguments\n    alpha = jax.numpy.array(alpha)\n    beta = jax.numpy.array(beta)\n    sample = jax.random.beta(\n        key=seed, a=alpha, b=beta, shape=shape, dtype=dtype\n    )\n    return sample\n"
  },
  {
    "path": "keras/src/backend/jax/rnn.py",
    "content": "import contextlib\n\nfrom jax import lax\nfrom jax import numpy as jnp\n\nfrom keras.src import tree\nfrom keras.src.backend.common import stateless_scope\n\n\ndef rnn(\n    step_function,\n    inputs,\n    initial_states,\n    go_backwards=False,\n    mask=None,\n    constants=None,\n    unroll=False,\n    input_length=None,\n    time_major=False,\n    zero_output_for_mask=False,\n    return_all_outputs=True,\n):\n    def swap_batch_timestep(input_t):\n        # Swap the batch and timestep dim for the incoming tensor.\n        axes = list(range(len(input_t.shape)))\n        axes[0], axes[1] = 1, 0\n        return jnp.transpose(input_t, axes)\n\n    if not time_major:\n        inputs = tree.map_structure(swap_batch_timestep, inputs)\n\n    flattened_inputs = tree.flatten(inputs)\n    time_steps = flattened_inputs[0].shape[0]\n\n    if mask is not None:\n        if mask.dtype != \"bool\":\n            mask = mask.astype(\"bool\")\n        if len(mask.shape) == 2:\n            mask = jnp.expand_dims(mask, axis=-1)\n        if not time_major:\n            mask = swap_batch_timestep(mask)\n\n    if constants is None:\n        constants = []\n\n    def _expand_mask(mask_t, input_t, fixed_dim=1):\n        if tree.is_nested(mask_t):\n            raise ValueError(\n                f\"mask_t is expected to be tensor, but got {mask_t}\"\n            )\n        if tree.is_nested(input_t):\n            raise ValueError(\n                f\"input_t is expected to be tensor, but got {input_t}\"\n            )\n        rank_diff = len(input_t.shape) - len(mask_t.shape)\n        for _ in range(rank_diff):\n            mask_t = jnp.expand_dims(mask_t, -1)\n        multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:])\n        return jnp.tile(mask_t, multiples)\n\n    if unroll:\n        if not time_steps:\n            raise ValueError(\"Unrolling requires a fixed number of timesteps.\")\n        states = tuple(initial_states)\n        successive_states = 
[]\n        successive_outputs = []\n\n        # Process the input tensors. The input tensors need to be split along\n        # the time_step dim and reversed if go_backwards is True. In the case\n        # of nested input, the input is flattened and then transformed\n        # individually. The result is a tuple of lists; each item in the tuple\n        # is a list of tensors with shape (batch, feature).\n        def _process_single_input_t(input_t):\n            input_t = unstack(input_t)  # unstack for time_step dim\n            if go_backwards:\n                input_t.reverse()\n            return input_t\n\n        if tree.is_nested(inputs):\n            processed_input = tree.map_structure(\n                _process_single_input_t, inputs\n            )\n        else:\n            processed_input = (_process_single_input_t(inputs),)\n\n        def _get_input_tensor(time):\n            inp = [t_[time] for t_ in processed_input]\n            return tree.pack_sequence_as(inputs, inp)\n\n        if mask is not None:\n            mask_list = unstack(mask)\n            if go_backwards:\n                mask_list.reverse()\n\n            for i in range(time_steps):\n                inp = _get_input_tensor(i)\n                mask_t = mask_list[i]\n                output, new_states = step_function(\n                    inp, tuple(states) + tuple(constants)\n                )\n                tiled_mask_t = _expand_mask(mask_t, output)\n\n                if not successive_outputs:\n                    prev_output = jnp.zeros_like(output)\n                else:\n                    prev_output = successive_outputs[-1]\n\n                output = jnp.where(tiled_mask_t, output, prev_output)\n\n                flat_states = tree.flatten(states)\n                flat_new_states = tree.flatten(new_states)\n                tiled_mask_t = tuple(\n                    _expand_mask(mask_t, s) for s in flat_states\n                )\n                
flat_final_states = tuple(\n                    jnp.where(m, s, ps)\n                    for m, s, ps in zip(\n                        tiled_mask_t, flat_new_states, flat_states\n                    )\n                )\n                states = tree.pack_sequence_as(states, flat_final_states)\n\n                if return_all_outputs:\n                    successive_outputs.append(output)\n                    successive_states.append(states)\n                else:\n                    successive_outputs = [output]\n                    successive_states = [states]\n            last_output = successive_outputs[-1]\n            new_states = successive_states[-1]\n            outputs = jnp.stack(successive_outputs)\n\n        else:  # mask is None\n            for i in range(time_steps):\n                inp = _get_input_tensor(i)\n                output, states = step_function(\n                    inp, tuple(states) + tuple(constants)\n                )\n                if return_all_outputs:\n                    successive_outputs.append(output)\n                    successive_states.append(states)\n                else:\n                    successive_outputs = [output]\n                    successive_states = [states]\n            last_output = successive_outputs[-1]\n            new_states = successive_states[-1]\n            outputs = jnp.stack(successive_outputs)\n\n    else:  # Unroll == False\n        if mask is not None:\n\n            def _step(states, current_input):\n                current_input, current_mask = current_input\n                is_masked = jnp.all(\n                    jnp.logical_not(current_mask), axis=-1, keepdims=True\n                )\n\n                output_t, new_states = step_function(current_input, states)\n\n                if zero_output_for_mask:\n                    masked_outs = jnp.where(\n                        is_masked, jnp.zeros_like(output_t), output_t\n                    )\n                else:\n                   
 # Assume the first state is the previous output.\n                    output_tm1 = states[0]\n                    if tree.is_nested(output_tm1):\n                        # Stacked RNN case: assume first state of last cell.\n                        output_tm1 = states[-1][0]\n                    masked_outs = jnp.where(is_masked, output_tm1, output_t)\n\n                new_states = tree.map_structure(\n                    lambda s, ns: jnp.where(is_masked, s, ns),\n                    states,\n                    new_states,\n                )\n                return (new_states, masked_outs)\n\n            scan_xs = (inputs, mask)\n\n        else:\n\n            def _step(states, current_input):\n                output_t, new_states = step_function(current_input, states)\n                return new_states, output_t\n\n            scan_xs = inputs\n\n        if stateless_scope.in_stateless_scope():\n            # Reuse the existing parent stateless scope.\n            scope = contextlib.nullcontext()\n        else:\n            scope = stateless_scope.StatelessScope()\n        with scope:\n            # We must use a stateless scope because `scan` will involve\n            # JAX tracing -- any variable update at this stage would\n            # be a leak.\n            new_states, outputs = lax.scan(\n                f=_step,\n                init=initial_states,\n                xs=scan_xs,\n                reverse=go_backwards,\n            )\n        if go_backwards:\n            outputs = jnp.flip(outputs, axis=0)\n        last_output = outputs[-1]\n\n    if not time_major:\n        outputs = tree.map_structure(swap_batch_timestep, outputs)\n\n    return last_output, outputs, new_states\n\n\ndef cudnn_ok(*args, **kwargs):\n    return False\n\n\ndef lstm(*args, **kwargs):\n    raise NotImplementedError\n\n\ndef gru(*args, **kwargs):\n    raise NotImplementedError\n\n\ndef unstack(x, axis=0):\n    return [\n        lax.index_in_dim(x, i, axis, keepdims=False)\n    
    for i in range(x.shape[axis])\n    ]\n"
  },
  {
    "path": "keras/src/backend/jax/sparse.py",
    "content": "import functools\n\nimport jax.experimental.sparse as jax_sparse\nimport jax.numpy as jnp\n\nfrom keras.src.utils import jax_utils\n\n\ndef axis_shape_dims_for_broadcast_in_dim(axis, input_shape, insert_dims):\n    \"\"\"Turn the `axis` argument into the arguments needed by\n    `broadcast_in_dim`.\n\n    Args:\n        axis: single int or a tuple of ints for the axis argument. The list of\n          dimensions to reduce or insert.\n        input_shape: the shape of the input as a tuple of ints.\n        insert_dims: `False` turns dimensions in `axis` to 1s (use case:\n          reduction along `axis` with `keep_dims=True`). `True` inserts 1s\n          according to `axis` (use case: `expand_dims`).\n    Returns:\n        A tuple of three lists:\n        - The canonical value for `axis`: always a list, negative values have\n          been resolved and values are sorted in ascending order.\n        - The output shape: `input_shape` with 1s at the indices in `axis`, for\n          use as the `shape` argument of `broadcast_in_dim`.\n        - The broadcast dimensions: list of dimensions not in `axis`, for use as\n          the `broadcast_dimensions` argument of `broadcast_in_dim`.\n    \"\"\"\n    if axis is None:\n        raise ValueError(\"Received `None` value for `axis`\")\n    if isinstance(axis, int):\n        axis = (axis,)\n    # Check uniqueness.\n    if len(set(axis)) != len(axis):\n        raise ValueError(f\"Repeated axis in `axis`: {axis}\")\n    result_dims = len(input_shape)\n    if insert_dims:\n        result_dims += len(axis)\n\n    # Resolve negative values.\n    canonical_axis = []\n    for a in axis:\n        if not -result_dims <= a < result_dims:\n            raise ValueError(\n                f\"In `axis`, axis {a} is out of bounds for array \"\n                f\"of dimension {result_dims}\"\n            )\n        if a < 0:\n            a = a + result_dims\n        canonical_axis.append(a)\n\n    # Check uniqueness again after resolving 
negative values.\n    if len(set(canonical_axis)) != len(canonical_axis):\n        raise ValueError(f\"Repeated axis in `axis`: {canonical_axis}\")\n    canonical_axis = sorted(canonical_axis)\n\n    # Compute output shape.\n    output_shape = list(input_shape)\n    for i in canonical_axis:\n        if insert_dims:\n            output_shape.insert(i, 1)\n        else:\n            output_shape[i] = 1\n    broadcast_dims = [i for i in range(result_dims) if i not in canonical_axis]\n    return canonical_axis, output_shape, broadcast_dims\n\n\ndef bcoo_add_indices(x1, x2, sum_duplicates):\n    \"\"\"Add the indices of `x2` to `x1` with zero values.\n\n    Args:\n        x1: `BCOO` tensor to add indices to.\n        x2: `BCOO` tensor whose indices are added to `x1`.\n        sum_duplicates: if `True`, calls `bcoo_sum_duplicates` on the output.\n    Returns:\n        A `BCOO` tensor equal to `x1` but with extra zeros at indices in `x2`\n        that were missing in `x1`.\n    \"\"\"\n    x2_zeros = jnp.zeros(x2.data.shape, x1.data.dtype)\n    concat_axis = len(x1.indices.shape) - 2\n    output_indices = jnp.concatenate([x1.indices, x2.indices], axis=concat_axis)\n    output_data = jnp.concatenate([x1.data, x2_zeros], axis=concat_axis)\n    output = jax_sparse.BCOO((output_data, output_indices), shape=x1.shape)\n    if sum_duplicates:\n        output = jax_sparse.bcoo_sum_duplicates(output)\n    return output\n\n\ndef densifying_unary(func):\n    \"\"\"Decorator to add support for `JAXSparse` tensors (including `BCOO`) to a\n    non-zero-preserving element-wise unary operator.\n\n    There are requirements on the operator for this decorator to work correctly:\n\n    - The operator must be element-wise\n    - The operator must be unary (one input tensor and one output tensor)\n    - The operator must return a tensor of the same shape.\n\n    Additional arguments to the function (besides the input tensor) are\n    supported. 
The returned result is a dense tensor.\n\n    Args:\n        func: The unary operator to wrap.\n    Returns:\n        Wrapped function that supports `JAXSparse` tensors.\n    \"\"\"\n\n    @functools.wraps(func)\n    def sparse_wrapper(x, *args, **kwargs):\n        if isinstance(x, jax_sparse.JAXSparse):\n            x = x.todense()\n        return func(x, *args, **kwargs)\n\n    return sparse_wrapper\n\n\ndef elementwise_unary(linear):\n    \"\"\"Decorator to add support for `BCOO` sparse tensors to a zero-preserving\n    element-wise unary operator.\n\n    There are requirements on the operator for this decorator to work correctly:\n\n    - The operator must be element-wise\n    - The operator must be unary (one input tensor and one output tensor)\n    - The operator must return a tensor of the same shape, and if it is a\n      `BCOO` tensor, the indices of the result must be the same. Therefore:\n        - Reduction operations are not supported (e.g. `mean`).\n        - Operations for which the result may be dense (e.g. `reciprocal`), or\n          for which the sparse indices depend on the inputs (e.g. `clip`), are\n          not supported. This implies that `func(0)` must be 0.\n\n    Additional arguments to the function (besides the input tensor) are\n    supported as long as they cannot change the indices of the result. 
For\n    instance, `round` is supported, but `clip` is not supported as\n    `clip(x, 1.0, 2.0)` would always return a dense tensor.\n\n    Note that if an input sparse tensor contains zero values, the indices and\n    the zero values are preserved.\n\n    Args:\n        linear: if `True`, means that the operation is such that\n            `op(a + b) == op(a) + op(b)`.\n    Returns:\n        Wrapped function that supports `BCOO` sparse tensors.\n    \"\"\"\n\n    def wrap_elementwise_unary(func):\n        @functools.wraps(func)\n        def sparse_wrapper(x, *args, **kwargs):\n            if isinstance(x, jax_sparse.BCOO):\n                if not linear and not x.unique_indices:\n                    x = jax_sparse.bcoo_sum_duplicates(x)\n                return jax_sparse.BCOO(\n                    (func(x.data, *args, **kwargs), x.indices), shape=x.shape\n                )\n            else:\n                return func(x, *args, **kwargs)\n\n        return sparse_wrapper\n\n    return wrap_elementwise_unary\n\n\ndef elementwise_binary_union(linear, use_sparsify):\n    \"\"\"Decorator to add support for `JAXSparse` tensors (including `BCOO`) to an\n    element-wise binary operator such that the indices present in the result\n    are the union of the indices of the two operands.\n\n    The primary use case for this is the `add` and `subtract` operators.\n\n    There are requirements on the operator for this decorator to work correctly:\n\n    - The operator must be element-wise.\n    - The operator must be binary (two input tensors and one output tensor).\n    - Both inputs must be of the same shape or one input must be a scalar.\n    - The output must be of the same shape as the (non-scalar) inputs.\n    - The indices of the output must be the union of the indices of the inputs.\n      This implies that func(0, 0) must be 0. 
As a result, if one operand is\n      dense or a scalar, then the result will be dense.\n\n    Additional arguments to the function (besides the input tensors) are not\n    supported.\n\n    Note that if the result of the operation is zero at some indices, including\n    because the operands were zero at these indices, the zeros and indices are\n    preserved.\n\n    The `BCOO` format is the only supported one in all cases. Other formats are\n    not supported when `use_sparsify` is `False`.\n\n    Args:\n        linear: if `True`, means that the operation is such that\n            `op(a + b, c) == op(a, c) + op(b, c)` and\n            `op(a, c + d) == op(a, c) + op(a, d)`.\n        use_sparsify: indicates that the JAX `sparsify` transform supports this\n            operation.\n    Returns:\n        Wrapped function that supports `JAXSparse`.\n    \"\"\"\n\n    def wrap_elementwise_binary_union(func):\n        sparse_func = jax_sparse.sparsify(func) if use_sparsify else None\n\n        @functools.wraps(func)\n        def sparse_wrapper(x1, x2):\n            if isinstance(x1, jax_sparse.JAXSparse):\n                if isinstance(x2, jax_sparse.JAXSparse):\n                    # x1 and x2 are sparse.\n                    # The way we use `sparsify`, it cannot know that the\n                    # indices are the same, so we optimize this case here.\n                    if (\n                        x1.indices is x2.indices\n                        and isinstance(x1, jax_sparse.BCOO)\n                        and isinstance(x2, jax_sparse.BCOO)\n                    ):\n                        if not linear and not x1.unique_indices:\n                            x1 = jax_sparse.bcoo_sum_duplicates(x1)\n                            x2 = jax_sparse.bcoo_sum_duplicates(x2)\n                        return jax_sparse.BCOO(\n                            (func(x1.data, x2.data), x1.indices),\n                            shape=x1.shape,\n                            
indices_sorted=x1.indices_sorted,\n                            unique_indices=x1.unique_indices,\n                        )\n                    elif use_sparsify:\n                        return sparse_func(x1, x2)\n                    elif isinstance(x1, jax_sparse.BCOO) and isinstance(\n                        x2, jax_sparse.BCOO\n                    ):\n                        x1 = bcoo_add_indices(x1, x2, sum_duplicates=not linear)\n                        x2 = bcoo_add_indices(x2, x1, sum_duplicates=not linear)\n                        return jax_sparse.BCOO(\n                            (func(x1.data, x2.data), x1.indices),\n                            shape=x1.shape,\n                            indices_sorted=True,\n                            unique_indices=True,\n                        )\n                    else:\n                        raise ValueError(\n                            \"Unsupported sparse format: \"\n                            f\"{x1.__class__} and {x2.__class__}\"\n                        )\n                else:\n                    # x1 is sparse, x2 is dense, densify x1.\n                    x1 = x1.todense()\n            elif isinstance(x2, jax_sparse.JAXSparse):\n                # x1 is dense, x2 is sparse, densify x2.\n                x2 = x2.todense()\n            return func(x1, x2)\n\n        return sparse_wrapper\n\n    return wrap_elementwise_binary_union\n\n\ndef elementwise_division(func):\n    \"\"\"Decorator to add support for `BCOO` sparse tensors to element-wise binary\n    division and related operators.\n\n    This decorator is designed for operations related to the division of\n    two operands (e.g. `divide`). 
It accepts `BCOO` tensors for both the\n    dividend and the divisor, but handles them differently depending on their\n    role.\n\n    - If the divisor is sparse, it is densified and the result is dense because\n      the result contains Inf or NaN outside of the indices of the dividend.\n    - If the dividend is sparse and the divisor is dense, it finds occurrences\n      of zeros and NaNs in the divisor. The result may therefore have more\n      indices than there were in the dividend to return correct values where the\n      divisor was zero or NaN.\n    - If the dividend is sparse and the divisor is a scalar, it does the\n      division element-wise. Note that the result is incorrectly sparse if the\n      scalar divisor is zero.\n\n    Args:\n        func: The function to wrap.\n    Returns:\n        Wrapped function that supports `BCOO` sparse tensors.\n    \"\"\"\n    sparse_func = jax_sparse.sparsify(func)\n\n    @functools.wraps(func)\n    def sparse_wrapper(x1, x2):\n        if isinstance(x1, jax_sparse.JAXSparse):\n            if isinstance(x2, jax_sparse.JAXSparse):\n                # x1 is sparse and x2 is sparse.\n                # Divisor is sparse, meaning we're doing divisions by zero\n                # outside of x2.indices, so the result is dense. 
Densify both.\n                x1 = x1.todense()\n                x2 = x2.todense()\n            elif isinstance(x1, jax_sparse.BCOO):\n                if not hasattr(x2, \"shape\") or len(x2.shape) == 0:\n                    # x1 is sparse BCOO, x2 is scalar, apply func element-wise.\n                    return jax_sparse.BCOO(\n                        (func(x1.data, x2), x1.indices),\n                        shape=x1.shape,\n                        indices_sorted=x1.indices_sorted,\n                        unique_indices=x1.unique_indices,\n                    )\n                else:\n                    # x1 is sparse BCOO, x2 is dense.\n                    if not jax_utils.is_in_jax_tracing_scope(x2):\n                        # Find zeros and nans in x2 and add indices to x1.\n                        # 1. Create a dense mask for zeros and nans.\n                        x2_zeros_and_nans = jnp.equal(x2, 0)\n                        if not jnp.issubdtype(x2.dtype, jnp.integer):\n                            x2_zeros_and_nans = jnp.logical_or(\n                                x2_zeros_and_nans, jnp.isnan(x2)\n                            )\n                        # 2. Make it a BCOO of True values.\n                        x2_zeros_and_nans = jax_sparse.bcoo_fromdense(\n                            x2_zeros_and_nans,\n                            n_batch=x1.n_batch,\n                            n_dense=x1.n_dense,\n                            index_dtype=x1.indices.dtype,\n                        )\n                        # 3. 
Add the indices to x1.\n                        x1 = bcoo_add_indices(\n                            x1, x2_zeros_and_nans, sum_duplicates=True\n                        )\n                    return sparse_func(x1, x2)\n            else:\n                raise ValueError(f\"Unsupported sparse format: {x1.__class__}\")\n        elif isinstance(x2, jax_sparse.JAXSparse):\n            # x1 is dense, x2 is sparse, densify x2\n            x2 = x2.todense()\n        return func(x1, x2)\n\n    return sparse_wrapper\n"
  },
  {
    "path": "keras/src/backend/jax/tensorboard.py",
    "content": "from keras.src.utils.module_utils import jax\n\n\ndef start_trace(logdir):\n    if logdir:\n        jax.profiler.start_trace(logdir)\n\n\ndef stop_trace(save):\n    if save:\n        jax.profiler.stop_trace()\n\n\ndef start_batch_trace(batch):\n    batch_trace_context = jax.profiler.TraceAnnotation(\n        f\"Profiled batch {batch}\"\n    )\n    batch_trace_context.__enter__()\n    return batch_trace_context\n\n\ndef stop_batch_trace(batch_trace_context):\n    batch_trace_context.__exit__(None, None, None)\n"
  },
  {
    "path": "keras/src/backend/jax/trainer.py",
    "content": "import itertools\nimport warnings\nfrom functools import partial\n\nimport jax\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import callbacks as callbacks_module\nfrom keras.src import optimizers as optimizers_module\nfrom keras.src import tree\nfrom keras.src.backend import config\nfrom keras.src.backend import distribution_lib as jax_distribution_lib\nfrom keras.src.backend.config import is_nnx_enabled\nfrom keras.src.distribution import distribution_lib\nfrom keras.src.trainers import trainer as base_trainer\nfrom keras.src.trainers.data_adapters import array_slicing\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.trainers.epoch_iterator import EpochIterator\nfrom keras.src.utils import traceback_utils\nfrom keras.src.utils.python_utils import pythonify_logs\n\nif is_nnx_enabled():\n    from flax import nnx\n\n    jit = nnx.jit\nelse:\n    jit = jax.jit\n\n\nclass JAXTrainer(base_trainer.Trainer):\n    def __init__(self):\n        super().__init__()\n        self.train_function = None\n        self.test_function = None\n        self.predict_function = None\n        self._jax_state_synced = True\n\n    def compute_loss_and_updates(\n        self,\n        trainable_variables,\n        non_trainable_variables,\n        metrics_variables,\n        x,\n        y,\n        sample_weight,\n        training=False,\n        optimizer_variables=None,\n    ):\n        \"\"\"This method is stateless and is intended for use with jax.grad.\"\"\"\n        kwargs = {}\n        if self._call_has_training_arg:\n            kwargs[\"training\"] = training\n\n        # Run stateless forward pass\n        y_pred, non_trainable_variables, losses = self.stateless_call(\n            trainable_variables,\n            non_trainable_variables,\n            x,\n            return_losses=True,\n            **kwargs,\n        )\n        if losses:\n            # Make forward pass losses available to compute_loss.\n      
      self._losses_override.clear()\n            self._losses_override = losses\n\n        loss, variables = self.stateless_compute_loss(\n            trainable_variables,\n            non_trainable_variables,\n            metrics_variables,\n            x=x,\n            y=y,\n            y_pred=y_pred,\n            sample_weight=sample_weight,\n            training=training,\n        )\n        if losses:\n            self._losses_override.clear()\n        (trainable_variables, non_trainable_variables, metrics_variables) = (\n            variables\n        )\n\n        # Handle loss scaling\n        unscaled_loss = loss\n        if training and self.optimizer is not None:\n            # Scale loss with a StatelessScope, to use an update scale variable.\n            mapping = list(zip(self.optimizer.variables, optimizer_variables))\n            with backend.StatelessScope(state_mapping=mapping):\n                loss = self.optimizer.scale_loss(loss)\n        return loss, (\n            unscaled_loss,\n            y_pred,\n            non_trainable_variables,\n            metrics_variables,\n        )\n\n    def _update_metrics_variables(\n        self, metrics_variables, unscaled_loss, x, y, y_pred, sample_weight\n    ):\n        with backend.StatelessScope(\n            state_mapping=[\n                (ref_v, v)\n                for ref_v, v in zip(self.metrics_variables, metrics_variables)\n            ]\n        ) as scope:\n            self._loss_tracker.update_state(\n                unscaled_loss,\n                sample_weight=next(\n                    i for i in tree.flatten(x) if i is not None\n                ).shape[0],\n            )\n            logs = self.compute_metrics(x, y, y_pred, sample_weight)\n\n        new_metrics_variables = []\n        for ref_v in self.metrics_variables:\n            new_v = scope.get_current_value(ref_v)\n            if new_v is None:\n                new_v = ref_v.value\n            
new_metrics_variables.append(new_v)\n        return logs, new_metrics_variables\n\n    def train_step(self, state, data):\n        (\n            trainable_variables,\n            non_trainable_variables,\n            optimizer_variables,\n            metrics_variables,\n        ) = state\n        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)\n        grad_fn = jax.value_and_grad(\n            self.compute_loss_and_updates, has_aux=True\n        )\n        (loss, aux), grads = grad_fn(\n            trainable_variables,\n            non_trainable_variables,\n            metrics_variables,\n            x,\n            y,\n            sample_weight,\n            training=True,\n            optimizer_variables=optimizer_variables,\n        )\n        (unscaled_loss, y_pred, non_trainable_variables, metrics_variables) = (\n            aux\n        )\n\n        (\n            trainable_variables,\n            optimizer_variables,\n        ) = self.optimizer.stateless_apply(\n            optimizer_variables, grads, trainable_variables\n        )\n\n        logs, metrics_variables = self._update_metrics_variables(\n            metrics_variables, unscaled_loss, x, y, y_pred, sample_weight\n        )\n\n        state = (\n            trainable_variables,\n            non_trainable_variables,\n            optimizer_variables,\n            metrics_variables,\n        )\n        return logs, state\n\n    def test_step(self, state, data):\n        (\n            trainable_variables,\n            non_trainable_variables,\n            metrics_variables,\n        ) = state\n        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)\n        loss, aux = self.compute_loss_and_updates(\n            trainable_variables,\n            non_trainable_variables,\n            metrics_variables,\n            x,\n            y,\n            sample_weight,\n            training=False,\n        )\n        (unscaled_loss, y_pred, 
non_trainable_variables, metrics_variables) = (\n            aux\n        )\n\n        logs, metrics_variables = self._update_metrics_variables(\n            metrics_variables, unscaled_loss, x, y, y_pred, sample_weight\n        )\n\n        state = (\n            trainable_variables,\n            non_trainable_variables,\n            metrics_variables,\n        )\n        return logs, state\n\n    def predict_step(self, state, data):\n        trainable_variables, non_trainable_variables = state\n        kwargs = {}\n        if self._call_has_training_arg:\n            kwargs[\"training\"] = False\n\n        x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)\n        outputs, non_trainable_variables = self.stateless_call(\n            trainable_variables, non_trainable_variables, x, **kwargs\n        )\n        return outputs, non_trainable_variables\n\n    def _make_function(self, step_function, concatenate_outputs=False):\n        if self.steps_per_execution > 1:\n            if concatenate_outputs:\n\n                def concatenate(outputs):\n                    output = outputs[0]\n                    for next_output in outputs[1:]:\n                        output = tree.map_structure(\n                            lambda t1, t2: jax.numpy.concatenate([t1, t2]),\n                            output,\n                            next_output,\n                        )\n                    return output\n\n                if not self.run_eagerly and self.jit_compile:\n                    concatenate = jit(concatenate)\n\n                def iterator_step(state, iterator):\n                    data = next(iterator)\n                    outputs, state = step_function(state, data)\n                    outputs = [outputs]\n                    try:\n                        for _ in range(self.steps_per_execution - 1):\n                            data = next(iterator)\n                            _outputs, state = step_function(state, data)\n                   
         outputs.append(_outputs)\n                    except StopIteration:\n                        pass\n                    outputs = concatenate(outputs)\n                    return outputs, state\n\n            else:\n\n                def iterator_step(state, iterator):\n                    data = next(iterator)\n                    outputs, state = step_function(state, data)\n                    try:\n                        for _ in range(self.steps_per_execution - 1):\n                            data = next(iterator)\n                            outputs, state = step_function(state, data)\n                    except StopIteration:\n                        pass\n                    return outputs, state\n\n        else:\n\n            def iterator_step(state, iterator):\n                return step_function(state, next(iterator))\n\n        return iterator_step\n\n    def make_train_function(self, force=False):\n        if self.train_function is not None and not force:\n            return\n        if not self.run_eagerly and self.jit_compile:\n            out_shardings = None\n            if distribution_lib.distribution() is not None:\n                state_shardings = self._get_state_sharding_spec()\n                out_shardings = (None, state_shardings)\n            if is_nnx_enabled():\n                step_fn = lambda state, data: type(self).train_step(\n                    self, state, data\n                )\n            else:\n                step_fn = self.train_step\n            train_step = jit(\n                step_fn,\n                donate_argnums=0,\n                out_shardings=out_shardings,\n            )\n        else:\n            train_step = self.train_step\n\n        step_function = self._make_function(train_step)\n\n        self.train_function = step_function\n\n    def make_test_function(self, force=False):\n        if self.test_function is not None and not force:\n            return\n        if not self.run_eagerly and 
self.jit_compile:\n            out_shardings = None\n            if distribution_lib.distribution() is not None:\n                (\n                    trainable_shardings,\n                    non_trainable_shardings,\n                    _,  # optimizer_shardings\n                    metrics_shardings,\n                ) = self._get_state_sharding_spec()\n                state_shardings = (\n                    trainable_shardings,\n                    non_trainable_shardings,\n                    metrics_shardings,\n                )\n                out_shardings = (None, state_shardings)\n            if is_nnx_enabled():\n                step_fn = lambda state, data: type(self).test_step(\n                    self, state, data\n                )\n            else:\n                step_fn = self.test_step\n            test_step = jit(\n                step_fn,\n                donate_argnums=0,\n                out_shardings=out_shardings,\n            )\n        else:\n            test_step = self.test_step\n\n        step_function = self._make_function(test_step)\n\n        self.test_function = step_function\n\n    def make_predict_function(self, force=False):\n        if self.predict_function is not None and not force:\n            return self.predict_function\n\n        def predict_step(state, data):\n            outputs, non_trainable_variables = self.predict_step(state, data)\n            return outputs, (state[0], non_trainable_variables)\n\n        if not self.run_eagerly and self.jit_compile:\n            out_shardings = None\n            if distribution_lib.distribution() is not None:\n                (\n                    trainable_shardings,\n                    non_trainable_shardings,\n                    _,  # optimizer_shardings\n                    _,  # metrics_shardings\n                ) = self._get_state_sharding_spec()\n                state_shardings = (\n                    trainable_shardings,\n                    
non_trainable_shardings,\n                )\n                out_shardings = (None, state_shardings)\n            predict_step = jit(\n                predict_step,\n                donate_argnums=0,\n                out_shardings=out_shardings,\n            )\n\n        _step_function = self._make_function(\n            predict_step, concatenate_outputs=True\n        )\n\n        def step_function(state, iterator):\n            outputs, state = _step_function(state, iterator)\n            return outputs, state\n\n        self.predict_function = step_function\n\n    @traceback_utils.filter_traceback\n    def fit(\n        self,\n        x=None,\n        y=None,\n        batch_size=None,\n        epochs=1,\n        verbose=\"auto\",\n        callbacks=None,\n        validation_split=0.0,\n        validation_data=None,\n        shuffle=True,\n        class_weight=None,\n        sample_weight=None,\n        initial_epoch=0,\n        steps_per_epoch=None,\n        validation_steps=None,\n        validation_batch_size=None,\n        validation_freq=1,\n    ):\n        self._assert_compile_called(\"fit\")\n        # Possibly cap epochs for debugging runs.\n        max_epochs = config.max_epochs()\n        if max_epochs and max_epochs < epochs:\n            warnings.warn(\"Limiting epochs to %d\" % max_epochs)\n            epochs = max_epochs\n        # TODO: respect compiled trainable state\n        self._eval_epoch_iterator = None\n        if validation_split and validation_data is None:\n            # Create the validation data using the training data. 
Only supported\n            # for TF/numpy/jax arrays.\n            (\n                (x, y, sample_weight),\n                validation_data,\n            ) = array_slicing.train_validation_split(\n                (x, y, sample_weight), validation_split=validation_split\n            )\n\n        if validation_data is not None:\n            (\n                val_x,\n                val_y,\n                val_sample_weight,\n            ) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)\n\n        # Create an iterator that yields batches for one epoch.\n        epoch_iterator = JAXEpochIterator(\n            x=x,\n            y=y,\n            sample_weight=sample_weight,\n            batch_size=batch_size,\n            steps_per_epoch=steps_per_epoch,\n            shuffle=shuffle,\n            class_weight=class_weight,\n            steps_per_execution=self.steps_per_execution,\n        )\n\n        self._symbolic_build(iterator=epoch_iterator)\n        epoch_iterator.reset()\n\n        # Container that configures and calls callbacks.\n        if not isinstance(callbacks, callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_history=True,\n                add_progbar=verbose != 0,\n                verbose=verbose,\n                epochs=epochs,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        self.make_train_function()\n        self.stop_training = False\n        training_logs = {}\n        training_finished = False\n        callbacks.on_train_begin()\n        initial_epoch = self._initial_epoch or initial_epoch\n        try:\n            for epoch in range(initial_epoch, epochs):\n                self.reset_metrics()\n                callbacks.on_epoch_begin(epoch)\n\n                self._jax_state_synced = True\n                with epoch_iterator.catch_stop_iteration():\n                    for begin_step, 
end_step, iterator in epoch_iterator:\n                        # Callbacks\n                        callbacks.on_train_batch_begin(begin_step)\n\n                        # Train step\n                        if self._jax_state_synced:\n                            # The state may have been synced by a callback.\n                            state = self._get_jax_state(\n                                trainable_variables=True,\n                                non_trainable_variables=True,\n                                optimizer_variables=True,\n                                metrics_variables=True,\n                                purge_model_variables=True,\n                            )\n                            self._jax_state_synced = False\n\n                        logs, state = self.train_function(state, iterator)\n                        (\n                            trainable_variables,\n                            non_trainable_variables,\n                            optimizer_variables,\n                            metrics_variables,\n                        ) = state\n\n                        # Setting _jax_state enables callbacks to force a state\n                        # sync if they need to.\n                        self._jax_state = {\n                            \"trainable_variables\": trainable_variables,\n                            \"non_trainable_variables\": non_trainable_variables,\n                            \"optimizer_variables\": optimizer_variables,\n                            \"metrics_variables\": metrics_variables,\n                        }\n                        # Dispatch callbacks. 
This takes care of async dispatch.\n                        callbacks.on_train_batch_end(end_step, logs)\n\n                        if self.stop_training:\n                            # Stop training if a callback has set\n                            # this flag in on_(train_)batch_end.\n                            break\n\n                # Reattach state to the model\n                # (if not already done by a callback).\n                # NOTE: doing this after each step would be a big performance\n                # bottleneck.\n                self.jax_state_sync()\n\n                # Override with model metrics instead of last step logs if\n                # needed.\n                epoch_logs = dict(self._get_metrics_result_or_logs(logs))\n\n                # Run validation.\n                if validation_data is not None and self._should_eval(\n                    epoch, validation_freq\n                ):\n                    # Create JAXEpochIterator for evaluation and cache it.\n                    if getattr(self, \"_eval_epoch_iterator\", None) is None:\n                        self._eval_epoch_iterator = JAXEpochIterator(\n                            x=val_x,\n                            y=val_y,\n                            sample_weight=val_sample_weight,\n                            batch_size=validation_batch_size or batch_size,\n                            steps_per_execution=self.steps_per_execution,\n                            steps_per_epoch=validation_steps,\n                            shuffle=False,\n                        )\n                    val_logs = self.evaluate(\n                        x=val_x,\n                        y=val_y,\n                        sample_weight=val_sample_weight,\n                        batch_size=validation_batch_size or batch_size,\n                        steps=validation_steps,\n                        callbacks=callbacks,\n                        return_dict=True,\n                        
_use_cached_eval_dataset=True,\n                    )\n                    val_logs = {\n                        f\"val_{name}\": val for name, val in val_logs.items()\n                    }\n                    epoch_logs.update(val_logs)\n\n                callbacks.on_epoch_end(epoch, epoch_logs)\n                training_logs = epoch_logs\n                if self.stop_training:\n                    break\n            training_finished = True\n\n        finally:\n            self.jax_state_sync()\n            if (\n                isinstance(self.optimizer, optimizers_module.Optimizer)\n                and epochs > 0\n            ):\n                self.optimizer.finalize_variable_values(self.trainable_weights)\n\n            # If _eval_epoch_iterator exists, delete it after all epochs\n            # are done.\n            if getattr(self, \"_eval_epoch_iterator\", None) is not None:\n                del self._eval_epoch_iterator\n            if training_finished:\n                callbacks.on_train_end(logs=training_logs)\n            self._jax_state = None\n        return self.history\n\n    @traceback_utils.filter_traceback\n    def evaluate(\n        self,\n        x=None,\n        y=None,\n        batch_size=None,\n        verbose=\"auto\",\n        sample_weight=None,\n        steps=None,\n        callbacks=None,\n        return_dict=False,\n        **kwargs,\n    ):\n        self._assert_compile_called(\"evaluate\")\n        # TODO: respect compiled trainable state\n        use_cached_eval_dataset = kwargs.pop(\"_use_cached_eval_dataset\", False)\n        if kwargs:\n            raise ValueError(f\"Arguments not recognized: {kwargs}\")\n\n        if use_cached_eval_dataset:\n            epoch_iterator = self._eval_epoch_iterator\n        else:\n            # Create an iterator that yields batches of\n            # input/target data.\n            epoch_iterator = JAXEpochIterator(\n                x=x,\n                y=y,\n                
sample_weight=sample_weight,\n                batch_size=batch_size,\n                steps_per_epoch=steps,\n                shuffle=False,\n                steps_per_execution=self.steps_per_execution,\n            )\n\n        self._symbolic_build(iterator=epoch_iterator)\n        epoch_iterator.reset()\n\n        # Container that configures and calls callbacks.\n        if not isinstance(callbacks, callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_progbar=verbose != 0,\n                verbose=verbose,\n                epochs=1,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        self.make_test_function()\n        self.stop_evaluating = False\n        callbacks.on_test_begin()\n        logs = {}\n        self.reset_metrics()\n\n        self._jax_state_synced = True\n        with epoch_iterator.catch_stop_iteration():\n            for begin_step, end_step, iterator in epoch_iterator:\n                callbacks.on_test_batch_begin(begin_step)\n\n                if self._jax_state_synced:\n                    # The state may have been synced by a callback.\n                    state = self._get_jax_state(\n                        trainable_variables=True,\n                        non_trainable_variables=True,\n                        metrics_variables=True,\n                        purge_model_variables=True,\n                    )\n                    self._jax_state_synced = False\n\n                logs, state = self.test_function(state, iterator)\n                (\n                    trainable_variables,\n                    non_trainable_variables,\n                    metrics_variables,\n                ) = state\n\n                # Setting _jax_state enables callbacks to force a state sync\n                # if they need to.\n                self._jax_state = {\n                    # I wouldn't recommend modifying 
non-trainable model state\n                    # during evaluate(), but it's allowed.\n                    \"trainable_variables\": trainable_variables,\n                    \"non_trainable_variables\": non_trainable_variables,\n                    \"metrics_variables\": metrics_variables,\n                }\n\n                # Dispatch callbacks. This takes care of async dispatch.\n                callbacks.on_test_batch_end(end_step, logs)\n\n                if self.stop_evaluating:\n                    break\n\n        # Reattach state back to model (if not already done by a callback).\n        self.jax_state_sync()\n\n        logs = pythonify_logs(self._get_metrics_result_or_logs(logs))\n        callbacks.on_test_end(logs)\n        self._jax_state = None\n\n        if return_dict:\n            return logs\n        return self._flatten_metrics_in_order(logs)\n\n    @traceback_utils.filter_traceback\n    def predict(\n        self, x, batch_size=None, verbose=\"auto\", steps=None, callbacks=None\n    ):\n        # Create an iterator that yields batches of input data.\n        epoch_iterator = JAXEpochIterator(\n            x=x,\n            batch_size=batch_size,\n            steps_per_epoch=steps,\n            shuffle=False,\n            steps_per_execution=self.steps_per_execution,\n        )\n\n        if not all(layer.built for layer in self._flatten_layers()):\n            # Build the model on one batch of data.\n            for _, _, iterator in epoch_iterator:\n                # Build model\n                x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(\n                    next(iterator)\n                )\n                if is_nnx_enabled():\n                    self(x)\n                else:\n                    with backend.StatelessScope():\n                        self(x)\n                break\n            epoch_iterator.reset()\n        # Container that configures and calls callbacks.\n        if not isinstance(callbacks, 
callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_progbar=verbose != 0,\n                verbose=verbose,\n                epochs=1,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        self.make_predict_function()\n        self.stop_predicting = False\n        callbacks.on_predict_begin()\n\n        def append_to_outputs(batch_outputs, outputs):\n            if outputs is None:\n                outputs = tree.map_structure(\n                    lambda batch_output: [batch_output],\n                    batch_outputs,\n                )\n            else:\n                tree.map_structure_up_to(\n                    batch_outputs,\n                    lambda output, batch_output: output.append(batch_output),\n                    outputs,\n                    batch_outputs,\n                )\n            return outputs\n\n        self._jax_state_synced = True\n        outputs = None\n        non_trainable_variables = None\n        with epoch_iterator.catch_stop_iteration():\n            for begin_step, end_step, iterator in epoch_iterator:\n                callbacks.on_predict_batch_begin(begin_step)\n                if self._jax_state_synced:\n                    # The state may have been synced by a callback.\n                    state = self._get_jax_state(\n                        trainable_variables=True,\n                        non_trainable_variables=True,\n                        purge_model_variables=True,\n                    )\n                    self._jax_state_synced = False\n                batch_outputs, state = self.predict_function(state, iterator)\n                (\n                    trainable_variables,\n                    non_trainable_variables,\n                ) = state\n                self._jax_state = {\n                    \"trainable_variables\": trainable_variables,\n                    # I 
wouldn't recommend modifying non-trainable model state\n                    # during predict(), but it's allowed.\n                    \"non_trainable_variables\": non_trainable_variables,\n                }\n                outputs = append_to_outputs(batch_outputs, outputs)\n\n                # Dispatch callbacks. This takes care of async dispatch.\n                callbacks.on_predict_batch_end(\n                    end_step, {\"outputs\": batch_outputs}\n                )\n\n                if self.stop_predicting:\n                    break\n\n        self.jax_state_sync()\n        callbacks.on_predict_end()\n        self._jax_state = None\n        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)\n\n    def train_on_batch(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        class_weight=None,\n        return_dict=False,\n    ):\n        self._assert_compile_called(\"train_on_batch\")\n        if class_weight is not None:\n            if sample_weight is not None:\n                raise ValueError(\n                    \"Arguments `sample_weight` and `class_weight` \"\n                    \"cannot be specified at the same time. 
\"\n                    f\"Received: sample_weight={sample_weight}, \"\n                    f\"class_weight={class_weight}\"\n                )\n            sample_weight = data_adapter_utils.class_weight_to_sample_weights(\n                y, class_weight\n            )\n\n        def data():\n            yield _distribute_data((x, y, sample_weight))\n\n        # Maybe build model\n        self._symbolic_build(data_batch=next(data()))\n        self.make_train_function()\n\n        # Train step\n        state = self._get_jax_state(\n            trainable_variables=True,\n            non_trainable_variables=True,\n            optimizer_variables=True,\n            metrics_variables=True,\n            purge_model_variables=False,\n        )\n        self._jax_state_synced = False\n        logs, state = self.train_function(state, data())\n\n        # State sync\n        (\n            trainable_variables,\n            non_trainable_variables,\n            optimizer_variables,\n            metrics_variables,\n        ) = state\n        self._jax_state = {\n            \"trainable_variables\": trainable_variables,\n            \"non_trainable_variables\": non_trainable_variables,\n            \"optimizer_variables\": optimizer_variables,\n            \"metrics_variables\": metrics_variables,\n        }\n        self.jax_state_sync()\n\n        # Format return values\n        logs = pythonify_logs(logs)\n        if return_dict:\n            return logs\n        return self._flatten_metrics_in_order(logs)\n\n    def test_on_batch(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        return_dict=False,\n    ):\n        self._assert_compile_called(\"test_on_batch\")\n\n        def data():\n            yield _distribute_data((x, y, sample_weight))\n\n        # Maybe build model\n        self._symbolic_build(data_batch=next(data()))\n        self.make_test_function()\n\n        # Test step\n        state = self._get_jax_state(\n            
trainable_variables=True,\n            non_trainable_variables=True,\n            metrics_variables=True,\n            purge_model_variables=False,\n        )\n        self._jax_state_synced = False\n        logs, state = self.test_function(state, data())\n\n        # State sync\n        trainable_variables, non_trainable_variables, metrics_variables = state\n        self._jax_state = {\n            \"trainable_variables\": trainable_variables,\n            \"non_trainable_variables\": non_trainable_variables,\n            \"metrics_variables\": metrics_variables,\n        }\n        self.jax_state_sync()\n\n        # Format return values.\n        logs = pythonify_logs(logs)\n        if return_dict:\n            return logs\n        return self._flatten_metrics_in_order(logs)\n\n    def predict_on_batch(self, x):\n        if not all(layer.built for layer in self._flatten_layers()):\n            # Build model\n            with backend.StatelessScope():\n                self(x)\n        self.make_predict_function()\n\n        state = self._get_jax_state(\n            trainable_variables=True,\n            non_trainable_variables=True,\n            metrics_variables=False,\n            purge_model_variables=False,\n        )\n        self._jax_state_synced = False\n\n        def data():\n            yield (x,)\n\n        batch_outputs, state = self.predict_function(state, data())\n        trainable_variables, non_trainable_variables = state\n        self._jax_state = {\n            \"trainable_variables\": trainable_variables,\n            \"non_trainable_variables\": non_trainable_variables,\n        }\n        self.jax_state_sync()\n        batch_outputs = tree.map_structure(lambda x: np.array(x), batch_outputs)\n        return batch_outputs\n\n    def jax_state_sync(self):\n        if not getattr(self, \"_jax_state\", None) or self._jax_state_synced:\n            return\n\n        trainable_variables = self._jax_state.get(\"trainable_variables\", None)\n        
non_trainable_variables = self._jax_state.get(\n            \"non_trainable_variables\", None\n        )\n        optimizer_variables = self._jax_state.get(\"optimizer_variables\", None)\n        metrics_variables = self._jax_state.get(\"metrics_variables\", None)\n        if trainable_variables:\n            for ref_v, v in zip(self.trainable_variables, trainable_variables):\n                ref_v.assign(v)\n        if non_trainable_variables:\n            for ref_v, v in zip(\n                self.non_trainable_variables, non_trainable_variables\n            ):\n                ref_v.assign(v)\n        if optimizer_variables:\n            for ref_v, v in zip(self.optimizer.variables, optimizer_variables):\n                ref_v.assign(v)\n        if metrics_variables:\n            for ref_v, v in zip(self.metrics_variables, metrics_variables):\n                ref_v.assign(v)\n        self._jax_state_synced = True\n\n    def _get_state_sharding_spec(self):\n        trainable_shardings = [\n            v.value.sharding for v in self.trainable_variables\n        ]\n        non_trainable_shardings = [\n            v.value.sharding for v in self.non_trainable_variables\n        ]\n        if hasattr(self, \"optimizer\") and self.optimizer is not None:\n            optimizer_shardings = [\n                v.value.sharding for v in self.optimizer.variables\n            ]\n        else:\n            optimizer_shardings = []\n        metrics_shardings = [v.value.sharding for v in self.metrics_variables]\n\n        self._check_sharding_consistency(\n            trainable_shardings,\n            non_trainable_shardings,\n            optimizer_shardings,\n            metrics_shardings,\n        )\n\n        return (\n            trainable_shardings,\n            non_trainable_shardings,\n            optimizer_shardings,\n            metrics_shardings,\n        )\n\n    def _check_sharding_consistency(\n        self,\n        trainable_shardings,\n        
non_trainable_shardings,\n        optimizer_shardings,\n        metrics_shardings,\n    ):\n        \"\"\"Warn if there is a mix of local and distributed variable shardings.\n\n        When some variables have SingleDeviceSharding (created outside the\n        distribution scope) and others have mesh-aware shardings (created\n        inside), passing them together as `out_shardings` to `jax.jit`\n        raises ``ValueError: Received incompatible devices for jitted\n        computation``. This helper detects the mismatch early and emits\n        an actionable warning.\n        \"\"\"\n        if distribution_lib.distribution() is None:\n            return\n\n        var_shard_pairs = itertools.chain(\n            zip(self.trainable_variables, trainable_shardings),\n            zip(self.non_trainable_variables, non_trainable_shardings),\n            zip(\n                (\n                    self.optimizer.variables\n                    if hasattr(self, \"optimizer\") and self.optimizer\n                    else []\n                ),\n                optimizer_shardings,\n            ),\n            zip(self.metrics_variables, metrics_shardings),\n        )\n\n        first_local_var_path = None\n        has_mesh = False\n        for v, s in var_shard_pairs:\n            if isinstance(s, jax.sharding.SingleDeviceSharding):\n                if first_local_var_path is None:\n                    first_local_var_path = v.path\n            else:\n                has_mesh = True\n            # Early exit: we know there is a mix as soon as we have\n            # seen at least one of each kind.\n            if first_local_var_path and has_mesh:\n                break\n\n        if not (first_local_var_path and has_mesh):\n            return\n\n        warnings.warn(\n            \"Detected a mix of local (SingleDeviceSharding) and \"\n            \"distributed (mesh-aware) variables. 
This will cause \"\n            \"a 'ValueError: Received incompatible devices for \"\n            \"jitted computation' when JAX tries to compile the \"\n            \"training step.\\n\\n\"\n            f\"First local variable found: {first_local_var_path!r}\\n\\n\"\n            \"This typically happens when the model is built or \"\n            \"weights are loaded before the distribution is set. \"\n            \"To fix this, call set_distribution() before creating \"\n            \"any Keras objects:\\n\\n\"\n            \"    import keras\\n\"\n            \"    keras.distribution.set_distribution(distribution)\\n\"\n            \"    model = create_model()\\n\"\n            \"    model.compile(...)\\n\"\n            \"    model.fit(...)\\n\\n\"\n            \"Alternatively, use the distribution scope context \"\n            \"manager:\\n\\n\"\n            \"    with distribution.scope():\\n\"\n            \"        model = create_model()\\n\"\n            \"        model.compile(...)\\n\"\n            \"        model.fit(...)\\n\",\n            stacklevel=3,\n        )\n\n    def _purge_model_variables(\n        self,\n        trainable_variables=False,\n        non_trainable_variables=False,\n        optimizer_variables=False,\n        metrics_variables=False,\n    ):\n        \"\"\"Remove all model variables to save memory.\n\n        During JAX training, since the training function is stateless, we have\n        to pass the model weights in and out at every step, while the copies\n        attached to the `Variable` objects still occupy extra memory. 
We remove those values at the beginning of the epoch to save memory (for\n        better memory utilization), and reattach them to the variables at the\n        end of the epoch via\n        `jax_state_sync()`.\n        \"\"\"\n        if trainable_variables:\n            for v in self.trainable_variables:\n                v._value = None\n        if non_trainable_variables:\n            for v in self.non_trainable_variables:\n                v._value = None\n        if optimizer_variables:\n            for v in self.optimizer.variables:\n                v._value = None\n        if metrics_variables:\n            for v in self.metrics_variables:\n                v._value = None\n\n    def _get_jax_state(\n        self,\n        trainable_variables=False,\n        non_trainable_variables=False,\n        optimizer_variables=False,\n        metrics_variables=False,\n        purge_model_variables=False,\n    ):\n        state = []\n        if trainable_variables:\n            state.append([v.value for v in self.trainable_variables])\n        if non_trainable_variables:\n            state.append([v.value for v in self.non_trainable_variables])\n        if optimizer_variables:\n            state.append([v.value for v in self.optimizer.variables])\n        if metrics_variables:\n            state.append([v.value for v in self.metrics_variables])\n        if purge_model_variables:\n            self._purge_model_variables(\n                trainable_variables=trainable_variables,\n                non_trainable_variables=non_trainable_variables,\n                optimizer_variables=optimizer_variables,\n                metrics_variables=metrics_variables,\n            )\n        return tuple(state)\n\n\ndef _distribute_data(data, layouts=None):\n    distribution = distribution_lib.distribution()\n\n    if distribution is not None:\n        if layouts is None:\n            layouts = tree.map_structure(\n                lambda d: 
distribution.get_data_layout(d.shape),\n                data,\n            )\n        jax_dist_data_input = partial(\n            jax_distribution_lib.distribute_data_input,\n            batch_dim_name=distribution.batch_dim_name,\n        )\n        return tree.map_structure(jax_dist_data_input, data, layouts)\n\n    return tree.map_structure(jax.device_put, data)\n\n\nclass JAXEpochIterator(EpochIterator):\n    def __next__(self):\n        return next(self._epoch_iterator)\n\n    def _get_iterator(self):\n        distribution = distribution_lib.distribution()\n        if distribution is not None:\n            return self._get_distributed_iterator(distribution)\n        else:\n            return self._one_batch_ahead_iterator(\n                self.data_adapter.get_jax_iterator()\n            )\n\n    def _get_distributed_iterator(self, distribution):\n        \"\"\"Lazily compute layouts to reduce host-to-device transfer latency.\"\"\"\n        layouts = None\n        for data in self.data_adapter.get_jax_iterator():\n            if layouts is None:\n                layouts = tree.map_structure(\n                    lambda d: (\n                        distribution.get_data_layout(d.shape).backend_layout\n                    ),\n                    data,\n                )\n            yield _distribute_data(data, layouts)\n\n    def _one_batch_ahead_iterator(self, numpy_iterator):\n        \"\"\"Initiate transfers to the device one batch ahead.\n\n        This utility takes an iterator and returns a new iterator that\n        initiates the transfer to the device one batch ahead. 
This can improve the\n        performance of training loops significantly by overlapping compute and\n        data transfer.\n        \"\"\"\n        next_batch = None\n        for batch in numpy_iterator:\n            batch = _distribute_data(batch)\n            if next_batch is None:\n                next_batch = batch\n            else:\n                current_batch = next_batch\n                next_batch = batch\n                yield current_batch\n\n        if next_batch is not None:\n            yield next_batch\n"
  },
  {
    "path": "keras/src/backend/jax/trainer_test.py",
    "content": "import warnings\n\nimport numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src.backend import distribution_lib as backend_dlib\nfrom keras.src.distribution import distribution_lib\n\n\nclass JAXTrainerTest(testing.TestCase, parameterized.TestCase):\n    def _skip_if_not_distributed(self):\n        if backend.backend() != \"jax\":\n            self.skipTest(\"Requires JAX backend\")\n        if len(backend_dlib.list_devices()) < 2:\n            self.skipTest(\"Requires at least 2 devices\")\n\n    def _make_distribution(self, dist_type):\n        if dist_type == \"data_parallel\":\n            return distribution_lib.DataParallel()\n        devices = backend_dlib.list_devices()\n        n = len(devices)\n        mesh = distribution_lib.DeviceMesh((n,), [\"model\"], devices)\n        layout_map = distribution_lib.LayoutMap(mesh)\n        layout_map[\".*dense.*kernel\"] = distribution_lib.TensorLayout(\n            [None, \"model\"]\n        )\n        layout_map[\".*dense.*bias\"] = distribution_lib.TensorLayout([\"model\"])\n        return distribution_lib.ModelParallel(layout_map=layout_map)\n\n    # ----------------------------------------------------------------\n    # Mixed-sharding warning tests\n    # ----------------------------------------------------------------\n    @parameterized.named_parameters(\n        {\"testcase_name\": \"DataParallel\", \"dist_type\": \"data_parallel\"},\n        {\"testcase_name\": \"ModelParallel\", \"dist_type\": \"model_parallel\"},\n    )\n    def test_warns_when_model_built_outside_scope(self, dist_type):\n        \"\"\"Model built outside distribution -> mixed warning on compile.\"\"\"\n        self._skip_if_not_distributed()\n        import jax\n\n        n = len(backend_dlib.list_devices())\n        units = n * max(1, 4 // n)\n        dist = 
self._make_distribution(dist_type)\n\n        # Model created outside any distribution scope — weights are local.\n        model = models.Sequential([layers.Dense(units, input_shape=(16,))])\n\n        for w in model.weights:\n            self.assertIsInstance(\n                w.value.sharding, jax.sharding.SingleDeviceSharding\n            )\n\n        inputs = np.random.normal(size=(8, 16)).astype(\"float32\")\n        labels = np.random.normal(size=(8, units)).astype(\"float32\")\n\n        with dist.scope():\n            model.compile(loss=\"mse\", optimizer=\"adam\")\n            with warnings.catch_warnings(record=True) as caught:\n                warnings.simplefilter(\"always\")\n                model._symbolic_build(data_batch=(inputs[:2], labels[:2]))\n                model._get_state_sharding_spec()\n\n            mixed = [w for w in caught if \"mix of local\" in str(w.message)]\n            self.assertGreater(\n                len(mixed),\n                0,\n                \"Expected a mixed-sharding warning but none was raised\",\n            )\n            msg = str(mixed[0].message)\n            self.assertIn(\"SingleDeviceSharding\", msg)\n            self.assertIn(\"set_distribution\", msg)\n\n    @parameterized.named_parameters(\n        {\"testcase_name\": \"DataParallel\", \"dist_type\": \"data_parallel\"},\n        {\"testcase_name\": \"ModelParallel\", \"dist_type\": \"model_parallel\"},\n    )\n    def test_no_warning_when_model_built_inside_scope(self, dist_type):\n        \"\"\"Model built inside distribution scope -> no warning.\"\"\"\n        self._skip_if_not_distributed()\n\n        n = len(backend_dlib.list_devices())\n        units = n * max(1, 4 // n)\n        dist = self._make_distribution(dist_type)\n\n        # Model created inside scope — weights get proper sharding.\n        with dist.scope():\n            model = models.Sequential([layers.Dense(units, input_shape=(16,))])\n\n        inputs = np.random.normal(size=(8, 
16)).astype(\"float32\")\n        labels = np.random.normal(size=(8, units)).astype(\"float32\")\n\n        with dist.scope():\n            model.compile(loss=\"mse\", optimizer=\"adam\")\n            with warnings.catch_warnings(record=True) as caught:\n                warnings.simplefilter(\"always\")\n                model._symbolic_build(data_batch=(inputs[:2], labels[:2]))\n                model._get_state_sharding_spec()\n\n            mixed = [w for w in caught if \"mix of local\" in str(w.message)]\n            self.assertEqual(\n                len(mixed),\n                0,\n                \"Unexpected mixed-sharding warning when model is \"\n                \"built inside scope\",\n            )\n"
  },
  {
    "path": "keras/src/backend/numpy/__init__.py",
    "content": "from keras.src.backend.common.name_scope import name_scope\nfrom keras.src.backend.numpy import core\nfrom keras.src.backend.numpy import image\nfrom keras.src.backend.numpy import linalg\nfrom keras.src.backend.numpy import math\nfrom keras.src.backend.numpy import nn\nfrom keras.src.backend.numpy import numpy\nfrom keras.src.backend.numpy import random\nfrom keras.src.backend.numpy.core import IS_THREAD_SAFE\nfrom keras.src.backend.numpy.core import SUPPORTS_RAGGED_TENSORS\nfrom keras.src.backend.numpy.core import SUPPORTS_SPARSE_TENSORS\nfrom keras.src.backend.numpy.core import Variable\nfrom keras.src.backend.numpy.core import cast\nfrom keras.src.backend.numpy.core import compute_output_spec\nfrom keras.src.backend.numpy.core import cond\nfrom keras.src.backend.numpy.core import convert_to_numpy\nfrom keras.src.backend.numpy.core import convert_to_tensor\nfrom keras.src.backend.numpy.core import device_scope\nfrom keras.src.backend.numpy.core import is_tensor\nfrom keras.src.backend.numpy.core import random_seed_dtype\nfrom keras.src.backend.numpy.core import shape\nfrom keras.src.backend.numpy.core import vectorized_map\nfrom keras.src.backend.numpy.rnn import cudnn_ok\nfrom keras.src.backend.numpy.rnn import gru\nfrom keras.src.backend.numpy.rnn import lstm\nfrom keras.src.backend.numpy.rnn import rnn\n"
  },
  {
    "path": "keras/src/backend/numpy/core.py",
    "content": "import builtins\nimport contextlib\nimport functools\nimport warnings\n\nimport numpy as np\n\nfrom keras.src import tree\nfrom keras.src.backend.common import KerasVariable\nfrom keras.src.backend.common import standardize_dtype\nfrom keras.src.backend.common.backend_utils import slice_along_axis\nfrom keras.src.backend.common.dtypes import result_type\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.backend.common.stateless_scope import StatelessScope\nfrom keras.src.backend.common.symbolic_scope import SymbolicScope\n\nSUPPORTS_SPARSE_TENSORS = False\nSUPPORTS_RAGGED_TENSORS = False\nIS_THREAD_SAFE = True\n\n\nclass Variable(KerasVariable):\n    def _initialize(self, value):\n        self._value = value\n\n    def _direct_assign(self, value):\n        self._value = np.array(value, dtype=self._dtype)\n\n    def _convert_to_tensor(self, value, dtype=None):\n        return convert_to_tensor(value, dtype=dtype)\n\n    # Overload native accessor.\n    def __array__(self):\n        return self.value\n\n\ndef convert_to_tensor(x, dtype=None, sparse=None, ragged=None):\n    if sparse:\n        raise ValueError(\"`sparse=True` is not supported with numpy backend\")\n    if ragged:\n        raise ValueError(\"`ragged=True` is not supported with numpy backend\")\n    if dtype is not None:\n        dtype = standardize_dtype(dtype)\n    if isinstance(x, Variable):\n        if dtype and dtype != x.dtype:\n            return x.value.astype(dtype)\n        return x.value\n    if not is_tensor(x) and standardize_dtype(dtype) == \"bfloat16\":\n        # Can't create bfloat16 arrays on the fly (e.g. 
from a h5 Dataset).\n        # Instead we convert \"as is\" (to stored dtype) and cast.\n        return np.asarray(x).astype(dtype)\n    if dtype is None:\n        dtype = result_type(\n            *[getattr(item, \"dtype\", type(item)) for item in tree.flatten(x)]\n        )\n    return np.array(x, dtype=dtype)\n\n\ndef convert_to_numpy(x):\n    return np.array(x)\n\n\ndef is_tensor(x):\n    if isinstance(x, (np.generic, np.ndarray)):\n        return True\n    return False\n\n\ndef shape(x):\n    return x.shape\n\n\ndef cast(x, dtype):\n    return convert_to_tensor(x, dtype=dtype)\n\n\ndef cond(pred, true_fn, false_fn):\n    if pred:\n        return true_fn()\n    return false_fn()\n\n\ndef vectorized_map(function, elements):\n    if not isinstance(elements, (list, tuple)):\n        return np.stack([function(x) for x in elements])\n    else:\n        batch_size = elements[0].shape[0]\n        output_store = []\n        for index in range(batch_size):\n            output_store.append(function([x[index] for x in elements]))\n        return np.stack(output_store)\n\n\n# Shape / dtype inference util\ndef compute_output_spec(fn, *args, **kwargs):\n    with StatelessScope(), SymbolicScope():\n\n        def has_none_shape(x):\n            if isinstance(x, KerasTensor):\n                return None in x.shape\n            return False\n\n        none_in_shape = any(\n            builtins.map(has_none_shape, tree.flatten((args, kwargs)))\n        )\n\n        def convert_keras_tensor_to_numpy(x, fill_value=None):\n            if isinstance(x, KerasTensor):\n                shape = list(x.shape)\n                if fill_value:\n                    for i, e in enumerate(shape):\n                        if e is None:\n                            shape[i] = fill_value\n                return np.empty(\n                    shape=shape,\n                    dtype=x.dtype,\n                )\n            return x\n\n        args_1, kwargs_1 = tree.map_structure(\n            
lambda x: convert_keras_tensor_to_numpy(x, fill_value=83),\n            (args, kwargs),\n        )\n        outputs_1 = fn(*args_1, **kwargs_1)\n\n        outputs = outputs_1\n\n        if none_in_shape:\n            args_2, kwargs_2 = tree.map_structure(\n                lambda x: convert_keras_tensor_to_numpy(x, fill_value=89),\n                (args, kwargs),\n            )\n            outputs_2 = fn(*args_2, **kwargs_2)\n\n            flat_out_1 = tree.flatten(outputs_1)\n            flat_out_2 = tree.flatten(outputs_2)\n\n            flat_out = []\n            for x1, x2 in zip(flat_out_1, flat_out_2):\n                shape = list(x1.shape)\n                for i, e in enumerate(x2.shape):\n                    if e != shape[i]:\n                        shape[i] = None\n                flat_out.append(KerasTensor(shape, standardize_dtype(x1.dtype)))\n            outputs = tree.pack_sequence_as(outputs_1, flat_out)\n\n        def convert_numpy_to_keras_tensor(x):\n            if is_tensor(x):\n                return KerasTensor(x.shape, standardize_dtype(x.dtype))\n            return x\n\n        output_spec = tree.map_structure(convert_numpy_to_keras_tensor, outputs)\n    return output_spec\n\n\ndef map(f, xs):\n    def g(_, x):\n        return (), f(x)\n\n    _, ys = scan(g, (), xs)\n    return ys\n\n\ndef scan(f, init, xs=None, length=None, reverse=False, unroll=1):\n    # Ref: jax.lax.scan\n    if not callable(f):\n        raise TypeError(f\"`f` should be a callable. Received: f={f}\")\n    if not isinstance(unroll, bool):\n        if not isinstance(unroll, int) or unroll < 1:\n            raise ValueError(\n                \"`unroll` must be a positive integer or boolean. 
\"\n                f\"Received: unroll={unroll}\"\n            )\n    if xs is None and length is None:\n        raise ValueError(\"Got no `xs` to scan over and `length` not provided.\")\n\n    input_is_sequence = tree.is_nested(xs)\n    output_is_sequence = tree.is_nested(init)\n\n    def pack_input(x):\n        return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]\n\n    def pack_output(x):\n        return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]\n\n    if xs is None:\n        xs_flat = []\n        n = int(length)\n    else:\n        xs_flat = tree.flatten(xs)\n        xs_flat = [convert_to_tensor(elem) for elem in xs_flat]\n        n = int(length) if length is not None else shape(xs_flat[0])[0]\n\n    init_flat = tree.flatten(init)\n    init_flat = [convert_to_tensor(init) for init in init_flat]\n    init = pack_output(init_flat)\n    dummy_y = [np.zeros_like(init) for init in init_flat]\n\n    carry = init\n    ys = []\n    maybe_reversed = reversed if reverse else lambda x: x\n    for i in maybe_reversed(range(n)):\n        xs_slice = [x[i] for x in xs_flat]\n        packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None\n        carry, y = f(carry, packed_xs)\n        ys.append(y if y is not None else dummy_y)\n    stacked_y = tree.map_structure(\n        lambda *ys: np.stack(ys), *maybe_reversed(ys)\n    )\n    return carry, stacked_y\n\n\ndef associative_scan(f, elems, reverse=False, axis=0):\n    # Ref: jax.lax.associative_scan\n    if not callable(f):\n        raise TypeError(f\"`f` should be a callable. 
Received: f={f}\")\n    elems_flat = tree.flatten(elems)\n    elems_flat = [convert_to_tensor(elem) for elem in elems_flat]\n    if reverse:\n        elems_flat = [np.flip(elem, (axis,)) for elem in elems_flat]\n\n    def _combine(a_flat, b_flat):\n        a = tree.pack_sequence_as(elems, a_flat)\n        b = tree.pack_sequence_as(elems, b_flat)\n        c = f(a, b)\n        c_flat = tree.flatten(c)\n        return c_flat\n\n    num_elems = int(elems_flat[0].shape[axis])\n    if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):\n        raise ValueError(\n            \"Array inputs to associative_scan must have the same \"\n            \"first dimension. (saw: {})\".format(\n                [elem.shape for elem in elems_flat]\n            )\n        )\n\n    def _interleave(a, b, axis):\n        \"\"\"Given two Tensors of static shape, interleave them along axis.\"\"\"\n        if not (\n            a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1\n        ):\n            raise ValueError(\n                \"Shapes are incompatible for associative_scan interleaving. 
\"\n                f\"a.shape[{axis}]={a.shape[axis]}, \"\n                f\"b.shape[{axis}]={b.shape[axis]}\"\n            )\n\n        # we want to get a: [a1, a2], b: [b1, b2]\n        # to a: [a1, 0, a2, 0], b: [0, b1, 0, b2]\n        a_shape = list(a.shape)\n        a_shape[axis] = a.shape[axis] * 2 - 1\n\n        b_shape = list(b.shape)\n        b_shape[axis] = b.shape[axis] * 2 - 1\n\n        a_dil = np.zeros(a_shape)\n        np.copyto(slice_along_axis(a_dil, 0, None, 2, axis), a)\n        b_dil = np.zeros(b_shape)\n        np.copyto(slice_along_axis(b_dil, 0, None, 2, axis), b)\n\n        a_pad = [[0, 0] for _ in range(a.ndim)]\n        a_pad[axis][-1] = 1 if a.shape[axis] == b.shape[axis] else 0\n\n        b_pad = [[0, 0] for _ in range(b.ndim)]\n        b_pad[axis] = [1, 0] if a.shape[axis] == b.shape[axis] else [1, 1]\n\n        op = np.bitwise_or if a.dtype == np.bool_ else np.add\n        return op(\n            np.pad(a_dil, a_pad),\n            np.pad(b_dil, b_pad),\n        )\n\n    def _scan(elems):\n        num_elems = elems[0].shape[axis]\n        if num_elems < 2:\n            return elems\n\n        reduced_elems = _combine(\n            [\n                slice_along_axis(elem, 0, -1, step=2, axis=axis)\n                for elem in elems\n            ],\n            [\n                slice_along_axis(elem, 1, None, step=2, axis=axis)\n                for elem in elems\n            ],\n        )\n\n        odd_elems = _scan(reduced_elems)\n        if num_elems % 2 == 0:\n            even_elems = _combine(\n                [slice_along_axis(e, 0, -1, axis=axis) for e in odd_elems],\n                [\n                    slice_along_axis(e, 2, None, step=2, axis=axis)\n                    for e in elems\n                ],\n            )\n        else:\n            even_elems = _combine(\n                odd_elems,\n                [\n                    slice_along_axis(e, 2, None, step=2, axis=axis)\n                    for e in elems\n    
            ],\n            )\n\n        even_elems = [\n            np.concatenate(\n                [slice_along_axis(elem, 0, 1, axis=axis), result],\n                axis=axis,\n            )\n            for (elem, result) in zip(elems, even_elems)\n        ]\n        return list(\n            builtins.map(\n                functools.partial(_interleave, axis=axis), even_elems, odd_elems\n            )\n        )\n\n    scans = _scan(elems_flat)\n    if reverse:\n        scans = [np.flip(scanned, (axis,)) for scanned in scans]\n\n    return tree.pack_sequence_as(elems, scans)\n\n\ndef scatter(indices, values, shape):\n    indices = convert_to_tensor(indices)\n    values = convert_to_tensor(values)\n    zeros = np.zeros(shape, dtype=values.dtype)\n\n    index_length = indices.shape[-1]\n    value_shape = shape[index_length:]\n    indices = np.reshape(indices, [-1, index_length])\n    values = np.reshape(values, [-1] + list(value_shape))\n\n    idx = tuple(indices.T)\n    np.add.at(zeros, idx, values)\n    return zeros\n\n\ndef scatter_update(inputs, indices, updates, reduction=None):\n    indices = np.array(indices)\n    indices = np.transpose(indices)\n    idx = tuple(indices)\n    if reduction is None:\n        inputs[idx] = updates\n    elif reduction == \"add\":\n        np.add.at(inputs, idx, updates)\n    elif reduction == \"max\":\n        np.maximum.at(inputs, idx, updates)\n    elif reduction == \"min\":\n        np.minimum.at(inputs, idx, updates)\n    elif reduction == \"mul\":\n        np.multiply.at(inputs, idx, updates)\n    else:\n        raise ValueError(f\"Unsupported reduction: {reduction}\")\n    return inputs\n\n\ndef slice(inputs, start_indices, shape):\n    # Validate inputs\n    if len(start_indices) != len(shape):\n        raise ValueError(\n            \"Length of `start_indices` must match length of `shape`. 
\"\n            f\"Received: start_indices={start_indices}, shape={shape}\"\n        )\n\n    # Generate list of indices arrays for each dimension\n    indices = [\n        np.arange(start, start + length)\n        for start, length in zip(start_indices, shape)\n    ]\n\n    # Use np.ix_ to create a multidimensional index array\n    mesh = np.ix_(*indices)\n\n    return inputs[mesh]\n\n\ndef slice_update(inputs, start_indices, updates):\n    # Generate list of indices arrays for each dimension\n    indices = [\n        np.arange(start, start + length)\n        for start, length in zip(start_indices, updates.shape)\n    ]\n\n    # Use np.ix_ to create a multidimensional index array\n    mesh = np.ix_(*indices)\n    inputs[mesh] = updates\n    return inputs\n\n\ndef switch(index, branches, *operands):\n    index = convert_to_tensor(index, \"int32\")\n    index = np.clip(index, 0, len(branches) - 1)\n    return branches[index](*operands)\n\n\ndef while_loop(\n    cond,\n    body,\n    loop_vars,\n    maximum_iterations=None,\n):\n    current_iter = 0\n    iteration_check = lambda iter: (\n        maximum_iterations is None or iter < maximum_iterations\n    )\n    is_tuple = isinstance(loop_vars, (tuple, list))\n    loop_vars = tuple(loop_vars) if is_tuple else (loop_vars,)\n    loop_vars = tree.map_structure(convert_to_tensor, loop_vars)\n    while cond(*loop_vars) and iteration_check(current_iter):\n        loop_vars = body(*loop_vars)\n        if not isinstance(loop_vars, (list, tuple)):\n            loop_vars = (loop_vars,)\n        loop_vars = tuple(loop_vars)\n        current_iter += 1\n    return loop_vars if is_tuple else loop_vars[0]\n\n\ndef fori_loop(lower, upper, body_fun, init_val):\n    val = init_val\n    for i in range(lower, upper):\n        val = body_fun(i, val)\n    return val\n\n\ndef stop_gradient(variable):\n    return variable\n\n\ndef unstack(x, num=None, axis=0):\n    x = np.moveaxis(x, axis, 0)\n    return [x[i] for i in 
range(x.shape[0])]\n\n\ndef random_seed_dtype():\n    return \"uint32\"\n\n\nclass custom_gradient:\n    \"\"\"Decorator for custom gradients.\n\n    Args:\n        fun: Forward pass function.\n    \"\"\"\n\n    def __init__(self, fun):\n        warnings.warn(\n            \"`custom_gradient` for the numpy backend acts as a pass-through to \"\n            \"support the forward pass. No gradient computation or modification \"\n            \"takes place.\"\n        )\n        self.fun = fun\n\n    def __call__(self, *args, **kwargs):\n        outputs, _ = self.fun(*args, **kwargs)\n        return outputs\n\n\n@contextlib.contextmanager\ndef device_scope(device_name):\n    yield\n\n\ndef remat(f):\n    warnings.warn(\n        \"Rematerialization memory optimization is not supported by the \"\n        \"Numpy backend. Please switch to JAX, TensorFlow, or PyTorch to \"\n        \"utilize this feature.\"\n    )\n    return f\n"
  },
  {
    "path": "keras/src/backend/numpy/export.py",
    "content": "class NumpyExportArchive:\n    def track(self, resource):\n        raise NotImplementedError(\n            \"`track` is not implemented in the numpy backend.\"\n        )\n\n    def add_endpoint(self, name, fn, input_signature=None, **kwargs):\n        raise NotImplementedError(\n            \"`add_endpoint` is not implemented in the numpy backend.\"\n        )\n"
  },
  {
    "path": "keras/src/backend/numpy/image.py",
    "content": "import ml_dtypes\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.backend.numpy.core import convert_to_tensor\nfrom keras.src.random.seed_generator import draw_seed\nfrom keras.src.utils.module_utils import scipy\n\nRESIZE_INTERPOLATIONS = (\n    \"bilinear\",\n    \"nearest\",\n    \"lanczos3\",\n    \"lanczos5\",\n    \"bicubic\",\n)\nAFFINE_TRANSFORM_INTERPOLATIONS = {  # map to order\n    \"nearest\": 0,\n    \"bilinear\": 1,\n}\nAFFINE_TRANSFORM_FILL_MODES = {\n    \"constant\",\n    \"nearest\",\n    \"wrap\",\n    \"mirror\",\n    \"reflect\",\n}\nMAP_COORDINATES_FILL_MODES = {\n    \"constant\",\n    \"nearest\",\n    \"wrap\",\n    \"mirror\",\n    \"reflect\",\n}\nSCALE_AND_TRANSLATE_METHODS = {\n    \"linear\",\n    \"bilinear\",\n    \"trilinear\",\n    \"cubic\",\n    \"bicubic\",\n    \"tricubic\",\n    \"lanczos3\",\n    \"lanczos5\",\n}\n\n\ndef rgb_to_grayscale(images, data_format=None):\n    images = convert_to_tensor(images)\n    data_format = backend.standardize_data_format(data_format)\n    channels_axis = -1 if data_format == \"channels_last\" else -3\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). 
Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    # Convert to floats\n    original_dtype = images.dtype\n    compute_dtype = backend.result_type(images.dtype, float)\n    images = images.astype(compute_dtype)\n\n    # Ref: tf.image.rgb_to_grayscale\n    rgb_weights = np.array([0.2989, 0.5870, 0.1140], dtype=images.dtype)\n    grayscales = np.tensordot(images, rgb_weights, axes=(channels_axis, -1))\n    grayscales = np.expand_dims(grayscales, axis=channels_axis)\n    return grayscales.astype(original_dtype)\n\n\ndef rgb_to_hsv(images, data_format=None):\n    # Ref: dm_pix\n    images = convert_to_tensor(images)\n    dtype = backend.standardize_dtype(images.dtype)\n    data_format = backend.standardize_data_format(data_format)\n    channels_axis = -1 if data_format == \"channels_last\" else -3\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if not backend.is_float_dtype(dtype):\n        raise ValueError(\n            \"Invalid images dtype: expected float dtype. 
\"\n            f\"Received: images.dtype={dtype}\"\n        )\n    eps = ml_dtypes.finfo(dtype).eps\n    images = np.where(np.abs(images) < eps, 0.0, images)\n    red, green, blue = np.split(images, 3, channels_axis)\n    red = np.squeeze(red, channels_axis)\n    green = np.squeeze(green, channels_axis)\n    blue = np.squeeze(blue, channels_axis)\n\n    def rgb_planes_to_hsv_planes(r, g, b):\n        value = np.maximum(np.maximum(r, g), b)\n        minimum = np.minimum(np.minimum(r, g), b)\n        range_ = value - minimum\n\n        safe_value = np.where(value > 0, value, 1.0)\n        safe_range = np.where(range_ > 0, range_, 1.0)\n\n        saturation = np.where(value > 0, range_ / safe_value, 0.0)\n        norm = 1.0 / (6.0 * safe_range)\n\n        hue = np.where(\n            value == g,\n            norm * (b - r) + 2.0 / 6.0,\n            norm * (r - g) + 4.0 / 6.0,\n        )\n        hue = np.where(value == r, norm * (g - b), hue)\n        hue = np.where(range_ > 0, hue, 0.0) + (hue < 0.0).astype(hue.dtype)\n        return hue, saturation, value\n\n    images = np.stack(\n        rgb_planes_to_hsv_planes(red, green, blue), axis=channels_axis\n    )\n    return images.astype(dtype)\n\n\ndef hsv_to_rgb(images, data_format=None):\n    # Ref: dm_pix\n    images = convert_to_tensor(images)\n    dtype = images.dtype\n    data_format = backend.standardize_data_format(data_format)\n    channels_axis = -1 if data_format == \"channels_last\" else -3\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if not backend.is_float_dtype(dtype):\n        raise ValueError(\n            \"Invalid images dtype: expected float dtype. 
\"\n            f\"Received: images.dtype={backend.standardize_dtype(dtype)}\"\n        )\n    hue, saturation, value = np.split(images, 3, channels_axis)\n    hue = np.squeeze(hue, channels_axis)\n    saturation = np.squeeze(saturation, channels_axis)\n    value = np.squeeze(value, channels_axis)\n\n    def hsv_planes_to_rgb_planes(hue, saturation, value):\n        dh = np.mod(hue, 1.0) * 6.0\n        dr = np.clip(np.abs(dh - 3.0) - 1.0, 0.0, 1.0)\n        dg = np.clip(2.0 - np.abs(dh - 2.0), 0.0, 1.0)\n        db = np.clip(2.0 - np.abs(dh - 4.0), 0.0, 1.0)\n        one_minus_s = 1.0 - saturation\n\n        red = value * (one_minus_s + saturation * dr)\n        green = value * (one_minus_s + saturation * dg)\n        blue = value * (one_minus_s + saturation * db)\n        return red, green, blue\n\n    images = np.stack(\n        hsv_planes_to_rgb_planes(hue, saturation, value), axis=channels_axis\n    )\n    return images.astype(dtype)\n\n\ndef resize(\n    images,\n    size,\n    interpolation=\"bilinear\",\n    antialias=False,\n    crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n    fill_mode=\"constant\",\n    fill_value=0.0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if interpolation not in RESIZE_INTERPOLATIONS:\n        raise ValueError(\n            \"Invalid value for argument `interpolation`. Expected of one \"\n            f\"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}\"\n        )\n    if fill_mode != \"constant\":\n        raise ValueError(\n            \"Invalid value for argument `fill_mode`. Only `'constant'` \"\n            f\"is supported. 
Received: fill_mode={fill_mode}\"\n        )\n    if pad_to_aspect_ratio and crop_to_aspect_ratio:\n        raise ValueError(\n            \"Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` \"\n            \"can be `True`.\"\n        )\n    if not len(size) == 2:\n        raise ValueError(\n            \"Argument `size` must be a tuple of two elements \"\n            f\"(height, width). Received: size={size}\"\n        )\n    size = tuple(size)\n    target_height, target_width = size\n    if len(images.shape) == 4:\n        if data_format == \"channels_last\":\n            size = (images.shape[0],) + size + (images.shape[-1],)\n        else:\n            size = (images.shape[0], images.shape[1]) + size\n    elif len(images.shape) == 3:\n        if data_format == \"channels_last\":\n            size = size + (images.shape[-1],)\n        else:\n            size = (images.shape[0],) + size\n    else:\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). 
Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    if crop_to_aspect_ratio:\n        shape = images.shape\n        if data_format == \"channels_last\":\n            height, width = shape[-3], shape[-2]\n        else:\n            height, width = shape[-2], shape[-1]\n        crop_height = int(float(width * target_height) / target_width)\n        crop_height = max(min(height, crop_height), 1)\n        crop_width = int(float(height * target_width) / target_height)\n        crop_width = max(min(width, crop_width), 1)\n        crop_box_hstart = int(float(height - crop_height) / 2)\n        crop_box_wstart = int(float(width - crop_width) / 2)\n        if data_format == \"channels_last\":\n            if len(images.shape) == 4:\n                images = images[\n                    :,\n                    crop_box_hstart : crop_box_hstart + crop_height,\n                    crop_box_wstart : crop_box_wstart + crop_width,\n                    :,\n                ]\n            else:\n                images = images[\n                    crop_box_hstart : crop_box_hstart + crop_height,\n                    crop_box_wstart : crop_box_wstart + crop_width,\n                    :,\n                ]\n        else:\n            if len(images.shape) == 4:\n                images = images[\n                    :,\n                    :,\n                    crop_box_hstart : crop_box_hstart + crop_height,\n                    crop_box_wstart : crop_box_wstart + crop_width,\n                ]\n            else:\n                images = images[\n                    :,\n                    crop_box_hstart : crop_box_hstart + crop_height,\n                    crop_box_wstart : crop_box_wstart + crop_width,\n                ]\n    elif pad_to_aspect_ratio:\n        shape = images.shape\n        batch_size = images.shape[0]\n        if data_format == \"channels_last\":\n            height, width, channels = shape[-3], shape[-2], shape[-1]\n   
     else:\n            channels, height, width = shape[-3], shape[-2], shape[-1]\n        pad_height = int(float(width * target_height) / target_width)\n        pad_height = max(height, pad_height)\n        pad_width = int(float(height * target_width) / target_height)\n        pad_width = max(width, pad_width)\n        img_box_hstart = int(float(pad_height - height) / 2)\n        img_box_wstart = int(float(pad_width - width) / 2)\n\n        if data_format == \"channels_last\":\n            if img_box_hstart > 0:\n                if len(images.shape) == 4:\n                    padded_img = np.concatenate(\n                        [\n                            np.ones(\n                                (batch_size, img_box_hstart, width, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            np.ones(\n                                (batch_size, img_box_hstart, width, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=1,\n                    )\n                else:\n                    padded_img = np.concatenate(\n                        [\n                            np.ones(\n                                (img_box_hstart, width, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            np.ones(\n                                (img_box_hstart, width, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=0,\n                    )\n            elif img_box_wstart > 0:\n                if 
len(images.shape) == 4:\n                    padded_img = np.concatenate(\n                        [\n                            np.ones(\n                                (batch_size, height, img_box_wstart, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            np.ones(\n                                (batch_size, height, img_box_wstart, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=2,\n                    )\n                else:\n                    padded_img = np.concatenate(\n                        [\n                            np.ones(\n                                (height, img_box_wstart, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            np.ones(\n                                (height, img_box_wstart, channels),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=1,\n                    )\n            else:\n                padded_img = images\n        else:\n            if img_box_hstart > 0:\n                if len(images.shape) == 4:\n                    padded_img = np.concatenate(\n                        [\n                            np.ones(\n                                (batch_size, channels, img_box_hstart, width),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            np.ones(\n
                                (batch_size, channels, img_box_hstart, width),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=2,\n                    )\n                else:\n                    padded_img = np.concatenate(\n                        [\n                            np.ones(\n                                (channels, img_box_hstart, width),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            np.ones(\n                                (channels, img_box_hstart, width),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=1,\n                    )\n            elif img_box_wstart > 0:\n                if len(images.shape) == 4:\n                    padded_img = np.concatenate(\n                        [\n                            np.ones(\n                                (batch_size, channels, height, img_box_wstart),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            np.ones(\n                                (batch_size, channels, height, img_box_wstart),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=3,\n                    )\n                else:\n                    padded_img = np.concatenate(\n                        [\n                            np.ones(\n                                (channels, height, img_box_wstart),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                            images,\n                            np.ones(\n                                (channels, height, img_box_wstart),\n                                dtype=images.dtype,\n                            )\n                            * fill_value,\n                        ],\n                        axis=2,\n                    )\n            else:\n                padded_img = images\n        images = padded_img\n\n    return _resize(images, size, method=interpolation, antialias=antialias)\n\n\ndef _compute_weight_mat(\n    input_size, output_size, scale, translation, kernel, antialias\n):\n    dtype = 
np.result_type(scale, translation)\n    inv_scale = 1.0 / scale\n    kernel_scale = np.maximum(inv_scale, 1.0) if antialias else 1.0\n\n    sample_f = (\n        (np.arange(output_size, dtype=dtype) + 0.5) * inv_scale\n        - translation * inv_scale\n        - 0.5\n    )\n\n    x = (\n        np.abs(\n            sample_f[np.newaxis, :]\n            - np.arange(input_size, dtype=dtype)[:, np.newaxis]\n        )\n        / kernel_scale\n    )\n\n    weights = kernel(x)\n\n    total_weight_sum = np.sum(weights, axis=0, keepdims=True)\n    weights = np.where(\n        np.abs(total_weight_sum) > 1000.0 * np.finfo(np.float32).eps,\n        np.divide(\n            weights, np.where(total_weight_sum != 0, total_weight_sum, 1)\n        ),\n        0,\n    )\n\n    input_size_minus_0_5 = input_size - 0.5\n    return np.where(\n        np.logical_and(sample_f >= -0.5, sample_f <= input_size_minus_0_5)[\n            np.newaxis, :\n        ],\n        weights,\n        0,\n    )\n\n\ndef _resize(image, shape, method, antialias):\n    if method == \"nearest\":\n        return _resize_nearest(image, shape)\n    else:\n        kernel = _kernels.get(method, None)\n    if kernel is None:\n        raise ValueError(\"Unknown resize method\")\n\n    spatial_dims = tuple(\n        i for i in range(len(shape)) if image.shape[i] != shape[i]\n    )\n    scale = [\n        shape[d] / image.shape[d] if image.shape[d] != 0 else 1.0\n        for d in spatial_dims\n    ]\n\n    return _scale_and_translate(\n        image,\n        shape,\n        spatial_dims,\n        scale,\n        [0.0] * len(spatial_dims),\n        kernel,\n        antialias,\n    )\n\n\ndef _resize_nearest(x, output_shape):\n    input_shape = x.shape\n    spatial_dims = tuple(\n        i for i in range(len(input_shape)) if input_shape[i] != output_shape[i]\n    )\n\n    for d in spatial_dims:\n        m, n = input_shape[d], output_shape[d]\n        offsets = (np.arange(n, dtype=np.float32) + 0.5) * m / n\n        
offsets = np.floor(offsets).astype(np.int32)\n        indices = [slice(None)] * len(input_shape)\n        indices[d] = offsets\n        x = x[tuple(indices)]\n    return x\n\n\ndef _fill_triangle_kernel(x):\n    return np.maximum(0, 1 - np.abs(x))\n\n\ndef _fill_keys_cubic_kernel(x):\n    out = ((1.5 * x - 2.5) * x) * x + 1.0\n    out = np.where(x >= 1.0, ((-0.5 * x + 2.5) * x - 4.0) * x + 2.0, out)\n    return np.where(x >= 2.0, 0.0, out)\n\n\ndef _fill_lanczos_kernel(radius, x):\n    y = radius * np.sin(np.pi * x) * np.sin(np.pi * x / radius)\n    out = np.where(\n        x > 1e-3, np.divide(y, np.where(x != 0, np.pi**2 * x**2, 1)), 1\n    )\n    return np.where(x > radius, 0.0, out)\n\n\n_kernels = {\n    \"linear\": _fill_triangle_kernel,\n    \"bilinear\": _fill_triangle_kernel,  # For `resize`.\n    \"cubic\": _fill_keys_cubic_kernel,\n    \"bicubic\": _fill_keys_cubic_kernel,  # For `resize`.\n    \"lanczos3\": lambda x: _fill_lanczos_kernel(3.0, x),\n    \"lanczos5\": lambda x: _fill_lanczos_kernel(5.0, x),\n}\n\n\ndef _scale_and_translate(\n    x, output_shape, spatial_dims, scale, translation, kernel, antialias\n):\n    input_shape = x.shape\n\n    if len(spatial_dims) == 0:\n        return x\n\n    if np.issubdtype(x.dtype, np.integer):\n        output = x.astype(np.float32)\n        use_rounding = True\n    else:\n        output = x.copy()\n        use_rounding = False\n\n    for i, d in enumerate(spatial_dims):\n        d = d % x.ndim\n        m, n = input_shape[d], output_shape[d]\n\n        w = _compute_weight_mat(\n            m, n, scale[i], translation[i], kernel, antialias\n        ).astype(output.dtype)\n        output = np.tensordot(output, w, axes=(d, 0))\n        output = np.moveaxis(output, -1, d)\n\n    if use_rounding:\n        output = np.clip(np.round(output), x.min(), x.max())\n        output = output.astype(x.dtype)\n    return output\n\n\ndef affine_transform(\n    images,\n    transform,\n    interpolation=\"bilinear\",\n    
fill_mode=\"constant\",\n    fill_value=0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():\n        raise ValueError(\n            \"Invalid value for argument \`interpolation\`. Expected one of \"\n            f\"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: \"\n            f\"interpolation={interpolation}\"\n        )\n    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:\n        raise ValueError(\n            \"Invalid value for argument \`fill_mode\`. Expected one of \"\n            f\"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}\"\n        )\n\n    images = convert_to_tensor(images)\n    transform = convert_to_tensor(transform)\n\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if len(transform.shape) not in (1, 2):\n        raise ValueError(\n            \"Invalid transform rank: expected rank 1 (single transform) \"\n            \"or rank 2 (batch of transforms). 
Received input with shape: \"\n            f\"transform.shape={transform.shape}\"\n        )\n\n    # `scipy.ndimage.map_coordinates` lacks support for float16 and bfloat16.\n    input_dtype = backend.standardize_dtype(images.dtype)\n    compute_dtype = backend.result_type(input_dtype, \"float32\")\n    images = images.astype(compute_dtype)\n    transform = transform.astype(compute_dtype)\n\n    # unbatched case\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = np.expand_dims(images, axis=0)\n        need_squeeze = True\n    if len(transform.shape) == 1:\n        transform = np.expand_dims(transform, axis=0)\n\n    if data_format == \"channels_first\":\n        images = np.transpose(images, (0, 2, 3, 1))\n\n    batch_size = images.shape[0]\n\n    # get indices\n    meshgrid = np.meshgrid(\n        *[np.arange(size) for size in images.shape[1:]], indexing=\"ij\"\n    )\n    indices = np.concatenate(\n        [np.expand_dims(x, axis=-1) for x in meshgrid], axis=-1\n    )\n    indices = np.tile(indices, (batch_size, 1, 1, 1, 1))\n\n    # swap the values\n    a0 = transform[:, 0].copy()\n    a2 = transform[:, 2].copy()\n    b1 = transform[:, 4].copy()\n    b2 = transform[:, 5].copy()\n    transform[:, 0] = b1\n    transform[:, 2] = b2\n    transform[:, 4] = a0\n    transform[:, 5] = a2\n\n    # deal with transform\n    transform = np.pad(transform, pad_width=[[0, 0], [0, 1]], constant_values=1)\n    transform = np.reshape(transform, (batch_size, 3, 3))\n    offset = transform[:, 0:2, 2].copy()\n    offset = np.pad(offset, pad_width=[[0, 0], [0, 1]])\n    transform[:, 0:2, 2] = 0\n\n    # transform the indices\n    coordinates = np.einsum(\"Bhwij, Bjk -> Bhwik\", indices, transform)\n    coordinates = np.moveaxis(coordinates, source=-1, destination=1)\n    coordinates += np.reshape(offset, (*offset.shape, 1, 1, 1))\n\n    # apply affine transformation\n    affined = np.stack(\n        [\n            map_coordinates(\n                
images[i],\n                coordinates[i],\n                order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                fill_mode=fill_mode,\n                fill_value=fill_value,\n            )\n            for i in range(batch_size)\n        ],\n        axis=0,\n    )\n\n    if data_format == \"channels_first\":\n        affined = np.transpose(affined, (0, 3, 1, 2))\n    if need_squeeze:\n        affined = np.squeeze(affined, axis=0)\n    return affined.astype(input_dtype)\n\n\ndef perspective_transform(\n    images,\n    start_points,\n    end_points,\n    interpolation=\"bilinear\",\n    fill_value=0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    start_points = convert_to_tensor(start_points)\n    end_points = convert_to_tensor(end_points)\n\n    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:\n        raise ValueError(\n            \"Invalid value for argument \`interpolation\`. Expected one of \"\n            f\"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: \"\n            f\"interpolation={interpolation}\"\n        )\n\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    if start_points.ndim not in (2, 3) or start_points.shape[-2:] != (4, 2):\n        raise ValueError(\n            \"Invalid start_points shape: expected (4,2) for a single image\"\n            f\" or (N,4,2) for a batch. Received shape: {start_points.shape}\"\n        )\n    if end_points.ndim not in (2, 3) or end_points.shape[-2:] != (4, 2):\n        raise ValueError(\n            \"Invalid end_points shape: expected (4,2) for a single image\"\n            f\" or (N,4,2) for a batch. 
Received shape: {end_points.shape}\"\n        )\n    if start_points.shape != end_points.shape:\n        raise ValueError(\n            \"start_points and end_points must have the same shape.\"\n            f\" Received start_points.shape={start_points.shape}, \"\n            f\"end_points.shape={end_points.shape}\"\n        )\n\n    input_dtype = images.dtype\n    if input_dtype == \"float16\":\n        images = images.astype(\"float32\")\n\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = np.expand_dims(images, axis=0)\n        need_squeeze = True\n\n    if len(start_points.shape) == 2:\n        start_points = np.expand_dims(start_points, axis=0)\n    if len(end_points.shape) == 2:\n        end_points = np.expand_dims(end_points, axis=0)\n\n    if data_format == \"channels_first\":\n        images = np.transpose(images, (0, 2, 3, 1))\n\n    batch_size, height, width, channels = images.shape\n\n    transforms = compute_homography_matrix(start_points, end_points)\n\n    if len(transforms.shape) == 1:\n        transforms = np.expand_dims(transforms, axis=0)\n    if transforms.shape[0] == 1 and batch_size > 1:\n        transforms = np.tile(transforms, (batch_size, 1))\n\n    x, y = np.meshgrid(\n        np.arange(width, dtype=np.float32),\n        np.arange(height, dtype=np.float32),\n        indexing=\"xy\",\n    )\n\n    output = np.empty((batch_size, height, width, channels))\n\n    for i in range(batch_size):\n        a0, a1, a2, a3, a4, a5, a6, a7 = transforms[i]\n        denom = a6 * x + a7 * y + 1.0\n        x_in = (a0 * x + a1 * y + a2) / denom\n        y_in = (a3 * x + a4 * y + a5) / denom\n\n        coords = np.stack([y_in.ravel(), x_in.ravel()], axis=0)\n\n        mapped_channels = []\n        for channel in range(channels):\n            channel_img = images[i, :, :, channel]\n\n            mapped_channel = map_coordinates(\n                channel_img,\n                coords,\n                
order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                fill_mode=\"constant\",\n                fill_value=fill_value,\n            )\n            mapped_channels.append(mapped_channel.reshape(height, width))\n\n        output[i] = np.stack(mapped_channels, axis=-1)\n\n    if data_format == \"channels_first\":\n        output = np.transpose(output, (0, 3, 1, 2))\n    if need_squeeze:\n        output = np.squeeze(output, axis=0)\n    output = output.astype(input_dtype)\n\n    return output\n\n\ndef compute_homography_matrix(start_points, end_points):\n    start_points = convert_to_tensor(start_points)\n    end_points = convert_to_tensor(end_points)\n    dtype = backend.result_type(start_points.dtype, end_points.dtype, float)\n    # `np.linalg.solve` lacks support for float16 and bfloat16.\n    compute_dtype = backend.result_type(dtype, \"float32\")\n    start_points = start_points.astype(dtype)\n    end_points = end_points.astype(dtype)\n\n    start_x1, start_y1 = start_points[:, 0, 0], start_points[:, 0, 1]\n    start_x2, start_y2 = start_points[:, 1, 0], start_points[:, 1, 1]\n    start_x3, start_y3 = start_points[:, 2, 0], start_points[:, 2, 1]\n    start_x4, start_y4 = start_points[:, 3, 0], start_points[:, 3, 1]\n\n    end_x1, end_y1 = end_points[:, 0, 0], end_points[:, 0, 1]\n    end_x2, end_y2 = end_points[:, 1, 0], end_points[:, 1, 1]\n    end_x3, end_y3 = end_points[:, 2, 0], end_points[:, 2, 1]\n    end_x4, end_y4 = end_points[:, 3, 0], end_points[:, 3, 1]\n\n    coefficient_matrix = np.stack(\n        [\n            np.stack(\n                [\n                    end_x1,\n                    end_y1,\n                    np.ones_like(end_x1),\n                    np.zeros_like(end_x1),\n                    np.zeros_like(end_x1),\n                    np.zeros_like(end_x1),\n                    -start_x1 * end_x1,\n                    -start_x1 * end_y1,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n  
              [\n                    np.zeros_like(end_x1),\n                    np.zeros_like(end_x1),\n                    np.zeros_like(end_x1),\n                    end_x1,\n                    end_y1,\n                    np.ones_like(end_x1),\n                    -start_y1 * end_x1,\n                    -start_y1 * end_y1,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    end_x2,\n                    end_y2,\n                    np.ones_like(end_x2),\n                    np.zeros_like(end_x2),\n                    np.zeros_like(end_x2),\n                    np.zeros_like(end_x2),\n                    -start_x2 * end_x2,\n                    -start_x2 * end_y2,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    np.zeros_like(end_x2),\n                    np.zeros_like(end_x2),\n                    np.zeros_like(end_x2),\n                    end_x2,\n                    end_y2,\n                    np.ones_like(end_x2),\n                    -start_y2 * end_x2,\n                    -start_y2 * end_y2,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    end_x3,\n                    end_y3,\n                    np.ones_like(end_x3),\n                    np.zeros_like(end_x3),\n                    np.zeros_like(end_x3),\n                    np.zeros_like(end_x3),\n                    -start_x3 * end_x3,\n                    -start_x3 * end_y3,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    np.zeros_like(end_x3),\n                    np.zeros_like(end_x3),\n                    np.zeros_like(end_x3),\n                    end_x3,\n                    end_y3,\n                    np.ones_like(end_x3),\n                    -start_y3 * end_x3,\n                    -start_y3 * 
end_y3,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    end_x4,\n                    end_y4,\n                    np.ones_like(end_x4),\n                    np.zeros_like(end_x4),\n                    np.zeros_like(end_x4),\n                    np.zeros_like(end_x4),\n                    -start_x4 * end_x4,\n                    -start_x4 * end_y4,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    np.zeros_like(end_x4),\n                    np.zeros_like(end_x4),\n                    np.zeros_like(end_x4),\n                    end_x4,\n                    end_y4,\n                    np.ones_like(end_x4),\n                    -start_y4 * end_x4,\n                    -start_y4 * end_y4,\n                ],\n                axis=-1,\n            ),\n        ],\n        axis=1,\n    )\n\n    target_vector = np.stack(\n        [\n            start_x1,\n            start_y1,\n            start_x2,\n            start_y2,\n            start_x3,\n            start_y3,\n            start_x4,\n            start_y4,\n        ],\n        axis=-1,\n    )\n    target_vector = np.expand_dims(target_vector, axis=-1)\n    coefficient_matrix = coefficient_matrix.astype(compute_dtype)\n    target_vector = target_vector.astype(compute_dtype)\n    homography_matrix = np.linalg.solve(coefficient_matrix, target_vector)\n    homography_matrix = np.reshape(homography_matrix, [-1, 8])\n    return homography_matrix.astype(dtype)\n\n\ndef map_coordinates(\n    inputs, coordinates, order, fill_mode=\"constant\", fill_value=0.0\n):\n    inputs = convert_to_tensor(inputs)\n    coordinates = convert_to_tensor(coordinates)\n    if coordinates.shape[0] != len(inputs.shape):\n        raise ValueError(\n            \"First dim of `coordinates` must be the same as the rank of \"\n            \"`inputs`. 
\"\n            f\"Received inputs with shape: {inputs.shape} and coordinate \"\n            f\"leading dim of {coordinates.shape[0]}\"\n        )\n    if len(coordinates.shape) < 2:\n        raise ValueError(\n            \"Invalid coordinates rank: expected at least rank 2.\"\n            f\" Received input with shape: {coordinates.shape}\"\n        )\n    if fill_mode not in MAP_COORDINATES_FILL_MODES:\n        raise ValueError(\n            \"Invalid value for argument `fill_mode`. Expected one of \"\n            f\"{set(MAP_COORDINATES_FILL_MODES.keys())}. Received: \"\n            f\"fill_mode={fill_mode}\"\n        )\n    if order not in range(2):\n        raise ValueError(\n            \"Invalid value for argument `order`. Expected one of \"\n            f\"{[0, 1]}. Received: order={order}\"\n        )\n    # SciPy's implementation of map_coordinates handles boundaries incorrectly,\n    # unless mode='reflect'. For order=1, this only affects interpolation\n    # outside the bounds of the original array.\n    # https://github.com/scipy/scipy/issues/2640\n    padding = [\n        (\n            max(-np.floor(c.min()).astype(int) + 1, 0),\n            max(np.ceil(c.max()).astype(int) + 1 - size, 0),\n        )\n        for c, size in zip(coordinates, inputs.shape)\n    ]\n    shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)]\n    pad_mode = {\n        \"nearest\": \"edge\",\n        \"mirror\": \"reflect\",\n        \"reflect\": \"symmetric\",\n    }.get(fill_mode, fill_mode)\n    if fill_mode == \"constant\":\n        padded = np.pad(\n            inputs, padding, mode=pad_mode, constant_values=fill_value\n        )\n    else:\n        padded = np.pad(inputs, padding, mode=pad_mode)\n\n    # `scipy.ndimage.map_coordinates` lacks support for float16 and bfloat16.\n    if backend.is_float_dtype(padded.dtype):\n        padded = padded.astype(\"float32\")\n    result = scipy.ndimage.map_coordinates(\n        padded, shifted_coords, order=order, 
mode=fill_mode, cval=fill_value\n    )\n    return result.astype(inputs.dtype)\n\n\ndef gaussian_blur(\n    images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None\n):\n    def _create_gaussian_kernel(kernel_size, sigma, num_channels, dtype):\n        def _get_gaussian_kernel1d(size, sigma):\n            x = np.arange(size, dtype=dtype) - (size - 1) / 2\n            kernel1d = np.exp(-0.5 * (x / sigma) ** 2)\n            return kernel1d / np.sum(kernel1d)\n\n        def _get_gaussian_kernel2d(size, sigma):\n            size = np.asarray(size, dtype)\n            kernel1d_x = _get_gaussian_kernel1d(size[0], sigma[0])\n            kernel1d_y = _get_gaussian_kernel1d(size[1], sigma[1])\n            return np.outer(kernel1d_y, kernel1d_x)\n\n        kernel = _get_gaussian_kernel2d(kernel_size, sigma)\n        kernel = kernel[:, :, np.newaxis]\n        kernel = np.tile(kernel, (1, 1, num_channels))\n        return kernel.astype(dtype)\n\n    data_format = backend.standardize_data_format(data_format)\n    images = convert_to_tensor(images)\n    kernel_size = convert_to_tensor(kernel_size)\n    sigma = convert_to_tensor(sigma)\n    input_dtype = backend.standardize_dtype(images.dtype)\n    # \`scipy.signal.convolve2d\` lacks support for float16 and bfloat16.\n    compute_dtype = backend.result_type(input_dtype, \"float32\")\n    images = images.astype(compute_dtype)\n    sigma = sigma.astype(compute_dtype)\n\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). 
Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = np.expand_dims(images, axis=0)\n        need_squeeze = True\n\n    if data_format == \"channels_first\":\n        images = np.transpose(images, (0, 2, 3, 1))\n\n    batch_size, height, width, num_channels = images.shape\n\n    kernel = _create_gaussian_kernel(\n        kernel_size, sigma, num_channels, input_dtype\n    )\n\n    kernel_h, kernel_w = kernel.shape[0], kernel.shape[1]\n    pad_h = (kernel_h - 1) // 2\n    pad_h_after = kernel_h - 1 - pad_h\n    pad_w = (kernel_w - 1) // 2\n    pad_w_after = kernel_w - 1 - pad_w\n\n    blurred_images = np.empty_like(images)\n\n    for b in range(batch_size):\n        for ch in range(num_channels):\n            padded = np.pad(\n                images[b, :, :, ch],\n                ((pad_h, pad_h_after), (pad_w, pad_w_after)),\n                mode=\"constant\",\n            )\n            blurred_images[b, :, :, ch] = scipy.signal.convolve2d(\n                padded, kernel[:, :, ch], mode=\"valid\"\n            )\n\n    if data_format == \"channels_first\":\n        blurred_images = np.transpose(blurred_images, (0, 3, 1, 2))\n    if need_squeeze:\n        blurred_images = np.squeeze(blurred_images, axis=0)\n    return blurred_images.astype(input_dtype)\n\n\ndef elastic_transform(\n    images,\n    alpha=20.0,\n    sigma=5.0,\n    interpolation=\"bilinear\",\n    fill_mode=\"reflect\",\n    fill_value=0.0,\n    seed=None,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():\n        raise ValueError(\n            \"Invalid value for argument \`interpolation\`. Expected one of \"\n            f\"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. 
Received: \"\n            f\"interpolation={interpolation}\"\n        )\n    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:\n        raise ValueError(\n            \"Invalid value for argument \`fill_mode\`. Expected one of \"\n            f\"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}\"\n        )\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    images = convert_to_tensor(images)\n    input_dtype = images.dtype\n\n    alpha = convert_to_tensor(alpha, dtype=input_dtype)\n    sigma = convert_to_tensor(sigma, dtype=input_dtype)\n\n    kernel_size = (int(6 * sigma) | 1, int(6 * sigma) | 1)\n\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = np.expand_dims(images, axis=0)\n        need_squeeze = True\n\n    if data_format == \"channels_last\":\n        batch_size, height, width, channels = images.shape\n        channel_axis = -1\n    else:\n        batch_size, channels, height, width = images.shape\n        channel_axis = 1\n\n    seed = draw_seed(seed)\n    rng = np.random.default_rng(seed)\n    dx = (\n        rng.normal(size=(batch_size, height, width), loc=0.0, scale=1.0).astype(\n            input_dtype\n        )\n        * sigma\n    )\n    dy = (\n        rng.normal(size=(batch_size, height, width), loc=0.0, scale=1.0).astype(\n            input_dtype\n        )\n        * sigma\n    )\n\n    dx = gaussian_blur(\n        np.expand_dims(dx, axis=channel_axis),\n        kernel_size=kernel_size,\n        sigma=(sigma, sigma),\n        data_format=data_format,\n    )\n    dy = gaussian_blur(\n        np.expand_dims(dy, axis=channel_axis),\n        kernel_size=kernel_size,\n        sigma=(sigma, sigma),\n        data_format=data_format,\n    )\n\n    dx = np.squeeze(dx)\n    dy = np.squeeze(dy)\n\n
    x, y = np.meshgrid(np.arange(width), np.arange(height))\n    x, y = x[None, :, :], y[None, :, :]\n\n    distorted_x = x + alpha * dx\n    distorted_y = y + alpha * dy\n\n    transformed_images = np.zeros_like(images)\n\n    if data_format == \"channels_last\":\n        for i in range(channels):\n            transformed_images[..., i] = np.stack(\n                [\n                    map_coordinates(\n                        images[b, ..., i],\n                        [distorted_y[b], distorted_x[b]],\n                        order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                        fill_mode=fill_mode,\n                        fill_value=fill_value,\n                    )\n                    for b in range(batch_size)\n                ]\n            )\n    else:\n        for i in range(channels):\n            transformed_images[:, i, :, :] = np.stack(\n                [\n                    map_coordinates(\n                        images[b, i, ...],\n                        [distorted_y[b], distorted_x[b]],\n                        order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                        fill_mode=fill_mode,\n                        fill_value=fill_value,\n                    )\n                    for b in range(batch_size)\n                ]\n            )\n\n    if need_squeeze:\n        transformed_images = np.squeeze(transformed_images, axis=0)\n    transformed_images = transformed_images.astype(input_dtype)\n\n    return transformed_images\n\n\ndef scale_and_translate(\n    images,\n    output_shape,\n    scale,\n    translation,\n    spatial_dims,\n    method,\n    antialias=True,\n):\n    if method not in SCALE_AND_TRANSLATE_METHODS:\n        raise ValueError(\n            \"Invalid value for argument \`method\`. Expected one of \"\n            f\"{SCALE_AND_TRANSLATE_METHODS}. 
Received: method={method}\"\n        )\n    if method in (\"linear\", \"bilinear\", \"trilinear\", \"triangle\"):\n        method = \"linear\"\n    elif method in (\"cubic\", \"bicubic\", \"tricubic\"):\n        method = \"cubic\"\n\n    images = convert_to_tensor(images)\n    scale = convert_to_tensor(scale)\n    translation = convert_to_tensor(translation)\n    kernel = _kernels[method]\n    dtype = backend.result_type(scale.dtype, translation.dtype)\n    scale = scale.astype(dtype)\n    translation = translation.astype(dtype)\n    return _scale_and_translate(\n        images,\n        output_shape,\n        spatial_dims,\n        scale,\n        translation,\n        kernel,\n        antialias,\n    )\n"
  },
  {
    "path": "keras/src/backend/numpy/layer.py",
    "content": "class NumpyLayer:\n    pass\n"
  },
  {
    "path": "keras/src/backend/numpy/linalg.py",
    "content": "import numpy as np\nimport scipy.linalg as sl\n\nfrom keras.src.backend import standardize_dtype\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.numpy.core import convert_to_tensor\n\n\ndef cholesky(a, upper=False):\n    return np.linalg.cholesky(a, upper=upper)\n\n\ndef cholesky_inverse(a, upper=False):\n    identity = np.eye(a.shape[-1], dtype=a.dtype)\n    inv_chol = solve_triangular(a, identity, lower=not upper)\n    if upper:\n        a_inv = np.matmul(inv_chol, inv_chol.T)\n    else:\n        a_inv = np.matmul(inv_chol.T, inv_chol)\n    return a_inv\n\n\ndef det(a):\n    return np.linalg.det(a)\n\n\ndef eig(a):\n    return np.linalg.eig(a)\n\n\ndef eigh(a):\n    return np.linalg.eigh(a)\n\n\ndef inv(a):\n    return np.linalg.inv(a)\n\n\ndef lu_factor(a):\n    if a.ndim == 2:\n        return sl.lu_factor(a)\n\n    m, n = a.shape[-2:]\n    signature = \"(m,n) -> (m,n), \"\n    signature += \"(m)\" if m <= n else \"(n)\"\n    _lu_factor_gufunc = np.vectorize(\n        sl.lu_factor,\n        signature=signature,\n    )\n    return _lu_factor_gufunc(a)\n\n\ndef norm(x, ord=None, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    if \"int\" in dtype or dtype == \"bool\":\n        dtype = dtypes.result_type(x.dtype, \"float32\")\n    return np.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims).astype(\n        dtype\n    )\n\n\ndef qr(x, mode=\"reduced\"):\n    if mode not in {\"reduced\", \"complete\"}:\n        raise ValueError(\n            \"`mode` argument value not supported. \"\n            \"Expected one of {'reduced', 'complete'}. 
\"\n            f\"Received: mode={mode}\"\n        )\n    return np.linalg.qr(x, mode=mode)\n\n\ndef solve(a, b):\n    return np.linalg.solve(a, b)\n\n\ndef solve_triangular(a, b, lower=False):\n    if a.ndim == 2:\n        return sl.solve_triangular(a, b, lower=lower)\n\n    _vectorized_solve_triangular = np.vectorize(\n        lambda a, b: sl.solve_triangular(a, b, lower=lower),\n        signature=\"(n,n),(n,m)->(n,m)\",\n    )\n    if b.ndim == a.ndim - 1:\n        b = np.expand_dims(b, axis=-1)\n        return _vectorized_solve_triangular(a, b).squeeze(axis=-1)\n    return _vectorized_solve_triangular(a, b)\n\n\ndef svd(x, full_matrices=True, compute_uv=True):\n    return np.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)\n\n\ndef lstsq(a, b, rcond=None):\n    a = convert_to_tensor(a)\n    b = convert_to_tensor(b)\n    return np.linalg.lstsq(a, b, rcond=rcond)[0]\n\n\ndef jvp(fun, primals, tangents, has_aux=False):\n    raise NotImplementedError(\"JVP is not supported by the Numpy backend.\")\n"
  },
  {
    "path": "keras/src/backend/numpy/math.py",
    "content": "import numpy as np\n\nfrom keras.src.backend import standardize_dtype\nfrom keras.src.backend.jax.math import fft as jax_fft\nfrom keras.src.backend.jax.math import fft2 as jax_fft2\nfrom keras.src.backend.numpy.core import convert_to_tensor\nfrom keras.src.utils.module_utils import scipy\n\n\ndef _segment_reduction_fn(\n    data, segment_ids, reduction_method, num_segments, sorted\n):\n    if num_segments is None:\n        num_segments = np.amax(segment_ids) + 1\n\n    valid_indices = segment_ids >= 0  # Ignore segment_ids that are -1\n    valid_data = data[valid_indices]\n    valid_segment_ids = segment_ids[valid_indices]\n\n    data_shape = list(valid_data.shape)\n    data_shape[0] = (\n        num_segments  # Replace first dimension (which corresponds to segments)\n    )\n\n    if reduction_method == np.maximum:\n        result = np.ones(data_shape, dtype=valid_data.dtype) * -np.inf\n    else:\n        result = np.zeros(data_shape, dtype=valid_data.dtype)\n\n    if sorted:\n        reduction_method.at(result, valid_segment_ids, valid_data)\n    else:\n        sort_indices = np.argsort(valid_segment_ids)\n        sorted_segment_ids = valid_segment_ids[sort_indices]\n        sorted_data = valid_data[sort_indices]\n\n        reduction_method.at(result, sorted_segment_ids, sorted_data)\n\n    return result\n\n\ndef segment_sum(data, segment_ids, num_segments=None, sorted=False):\n    return _segment_reduction_fn(\n        data, segment_ids, np.add, num_segments, sorted\n    )\n\n\ndef segment_max(data, segment_ids, num_segments=None, sorted=False):\n    return _segment_reduction_fn(\n        data, segment_ids, np.maximum, num_segments, sorted\n    )\n\n\ndef top_k(x, k, sorted=True):\n    if sorted:\n        # Take the k largest values.\n        sorted_indices = np.argsort(x, axis=-1)[..., ::-1]\n        sorted_values = np.take_along_axis(x, sorted_indices, axis=-1)\n        top_k_values = sorted_values[..., :k]\n        top_k_indices = 
sorted_indices[..., :k]\n    else:\n        # Partition the array such that all values larger than the k-th\n        # largest value are to the right of it.\n        top_k_indices = np.argpartition(x, -k, axis=-1)[..., -k:]\n        top_k_values = np.take_along_axis(x, top_k_indices, axis=-1)\n    return top_k_values, top_k_indices\n\n\ndef in_top_k(targets, predictions, k):\n    targets = targets[:, None]\n    topk_values = top_k(predictions, k)[0]\n    targets_values = np.take_along_axis(predictions, targets, axis=-1)\n    mask = targets_values >= topk_values\n    return np.any(mask, axis=-1)\n\n\ndef logsumexp(x, axis=None, keepdims=False):\n    return scipy.special.logsumexp(x, axis=axis, keepdims=keepdims)\n\n\ndef qr(x, mode=\"reduced\"):\n    if mode not in {\"reduced\", \"complete\"}:\n        raise ValueError(\n            \"`mode` argument value not supported. \"\n            \"Expected one of {'reduced', 'complete'}. \"\n            f\"Received: mode={mode}\"\n        )\n    return np.linalg.qr(x, mode=mode)\n\n\ndef extract_sequences(x, sequence_length, sequence_stride):\n    *batch_shape, _ = x.shape\n    batch_shape = list(batch_shape)\n    shape = x.shape[:-1] + (\n        (x.shape[-1] - (sequence_length - sequence_stride)) // sequence_stride,\n        sequence_length,\n    )\n    strides = x.strides[:-1] + (\n        sequence_stride * x.strides[-1],\n        x.strides[-1],\n    )\n    x = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides)\n    return np.reshape(x, (*batch_shape, *x.shape[-2:]))\n\n\ndef _get_complex_tensor_from_tuple(x):\n    if not isinstance(x, (tuple, list)) or len(x) != 2:\n        raise ValueError(\n            \"Input `x` should be a tuple of two tensors - real and imaginary. \"\n            f\"Received: x={x}\"\n        )\n    # `convert_to_tensor` does not support passing complex tensors. 
We separate\n    # the input out into real and imaginary and convert them separately.\n    real, imag = x\n    # Check shapes.\n    if real.shape != imag.shape:\n        raise ValueError(\n            \"Input `x` should be a tuple of two tensors - real and imaginary. \"\n            \"Both the real and imaginary parts should have the same shape. \"\n            f\"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}\"\n        )\n    # Ensure dtype is float.\n    if not np.issubdtype(real.dtype, np.floating) or not np.issubdtype(\n        imag.dtype, np.floating\n    ):\n        raise ValueError(\n            \"At least one tensor in input `x` is not of type float. \"\n            f\"Received: x={x}.\"\n        )\n    complex_input = real + 1j * imag\n    return complex_input\n\n\ndef fft(x):\n    real, imag = jax_fft(x)\n    return np.array(real), np.array(imag)\n\n\ndef fft2(x):\n    real, imag = jax_fft2(x)\n    return np.array(real), np.array(imag)\n\n\ndef ifft2(x):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    complex_output = np.fft.ifft2(complex_input)\n    return np.real(complex_output), np.imag(complex_output)\n\n\ndef rfft(x, fft_length=None):\n    complex_output = np.fft.rfft(x, n=fft_length, axis=-1, norm=\"backward\")\n    # numpy always outputs complex128, so we need to recast the dtype\n    return (\n        np.real(complex_output).astype(x.dtype),\n        np.imag(complex_output).astype(x.dtype),\n    )\n\n\ndef irfft(x, fft_length=None):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    # numpy always outputs float64, so we need to recast the dtype\n    return np.fft.irfft(\n        complex_input, n=fft_length, axis=-1, norm=\"backward\"\n    ).astype(x[0].dtype)\n\n\ndef stft(\n    x, sequence_length, sequence_stride, fft_length, window=\"hann\", center=True\n):\n    if standardize_dtype(x.dtype) not in {\"float32\", \"float64\"}:\n        raise TypeError(\n            \"Invalid input type. 
Expected `float32` or `float64`. \"\n            f\"Received: input type={x.dtype}\"\n        )\n    if fft_length < sequence_length:\n        raise ValueError(\n            \"`fft_length` must be equal to or larger than `sequence_length`. \"\n            f\"Received: sequence_length={sequence_length}, \"\n            f\"fft_length={fft_length}\"\n        )\n    if isinstance(window, str):\n        if window not in {\"hann\", \"hamming\"}:\n            raise ValueError(\n                \"If a string is passed to `window`, it must be one of \"\n                f'`\"hann\"`, `\"hamming\"`. Received: window={window}'\n            )\n    x = convert_to_tensor(x)\n    ori_dtype = x.dtype\n\n    if center:\n        pad_width = [(0, 0) for _ in range(len(x.shape))]\n        pad_width[-1] = (fft_length // 2, fft_length // 2)\n        x = np.pad(x, pad_width, mode=\"reflect\")\n\n    l_pad = (fft_length - sequence_length) // 2\n    r_pad = fft_length - sequence_length - l_pad\n\n    if window is not None:\n        if isinstance(window, str):\n            win = convert_to_tensor(\n                scipy.signal.get_window(window, sequence_length), dtype=x.dtype\n            )\n        else:\n            win = convert_to_tensor(window, dtype=x.dtype)\n        if len(win.shape) != 1 or win.shape[-1] != sequence_length:\n            raise ValueError(\n                \"The shape of `window` must be equal to [sequence_length]. \"\n                f\"Received: window shape={win.shape}\"\n            )\n        win = np.pad(win, [[l_pad, r_pad]])\n    else:\n        win = np.ones((sequence_length + l_pad + r_pad), dtype=x.dtype)\n\n    x = scipy.signal.stft(\n        x,\n        fs=1.0,\n        window=win,\n        nperseg=(sequence_length + l_pad + r_pad),\n        noverlap=(sequence_length + l_pad + r_pad - sequence_stride),\n        nfft=fft_length,\n        boundary=None,\n        padded=False,\n    )[-1]\n\n    # scale and swap to (..., num_sequences, fft_bins)\n    x = x / 
np.sqrt(1.0 / win.sum() ** 2)\n    x = np.swapaxes(x, -2, -1)\n    return np.real(x).astype(ori_dtype), np.imag(x).astype(ori_dtype)\n\n\ndef istft(\n    x,\n    sequence_length,\n    sequence_stride,\n    fft_length,\n    length=None,\n    window=\"hann\",\n    center=True,\n):\n    x = _get_complex_tensor_from_tuple(x)\n    dtype = np.real(x).dtype\n\n    expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1)\n    l_pad = (fft_length - sequence_length) // 2\n    r_pad = fft_length - sequence_length - l_pad\n\n    if window is not None:\n        if isinstance(window, str):\n            win = convert_to_tensor(\n                scipy.signal.get_window(window, sequence_length), dtype=dtype\n            )\n        else:\n            win = convert_to_tensor(window, dtype=dtype)\n        if len(win.shape) != 1 or win.shape[-1] != sequence_length:\n            raise ValueError(\n                \"The shape of `window` must be equal to [sequence_length]. \"\n                f\"Received: window shape={win.shape}\"\n            )\n        win = np.pad(win, [[l_pad, r_pad]])\n    else:\n        win = np.ones((sequence_length + l_pad + r_pad), dtype=dtype)\n\n    x = scipy.signal.istft(\n        x,\n        fs=1.0,\n        window=win,\n        nperseg=(sequence_length + l_pad + r_pad),\n        noverlap=(sequence_length + l_pad + r_pad - sequence_stride),\n        nfft=fft_length,\n        boundary=False,\n        time_axis=-2,\n        freq_axis=-1,\n    )[-1]\n\n    # scale\n    x = x / win.sum() if window is not None else x / sequence_stride\n\n    start = 0 if center is False else fft_length // 2\n    if length is not None:\n        end = start + length\n    elif center is True:\n        end = -(fft_length // 2)\n    else:\n        end = expected_output_len\n    return x[..., start:end]\n\n\ndef rsqrt(x):\n    return 1.0 / np.sqrt(x)\n\n\ndef erf(x):\n    return np.array(scipy.special.erf(x))\n\n\ndef erfinv(x):\n    return 
np.array(scipy.special.erfinv(x))\n\n\ndef logdet(x):\n    from keras.src.backend.numpy.numpy import slogdet\n\n    # In NumPy, `slogdet` is more stable than `np.log(np.linalg.det(x))`. See\n    # https://numpy.org/doc/stable/reference/generated/numpy.linalg.slogdet.html\n    return slogdet(x)[1]\n"
  },
  {
    "path": "keras/src/backend/numpy/nn.py",
"content": "import jax\nimport numpy as np\nfrom jax import lax\n\nfrom keras.src import backend\nfrom keras.src.backend.common.backend_utils import (\n    compute_adaptive_pooling_window_sizes,\n)\nfrom keras.src.backend.common.backend_utils import (\n    compute_conv_transpose_padding_args_for_jax,\n)\nfrom keras.src.backend.numpy.core import cast\nfrom keras.src.backend.numpy.core import convert_to_tensor\nfrom keras.src.backend.numpy.core import is_tensor\nfrom keras.src.utils.module_utils import scipy\n\n\ndef relu(x):\n    x = convert_to_tensor(x)\n    return np.maximum(x, np.array(0.0, x.dtype))\n\n\ndef relu6(x):\n    x = convert_to_tensor(x)\n    # np.clip incorrectly promotes bfloat16 to float32, so we replace it with\n    # np.minimum and np.maximum here\n    return np.minimum(\n        np.maximum(x, np.array(0.0, x.dtype)), np.array(6.0, x.dtype)\n    )\n\n\ndef sigmoid(x):\n    x = convert_to_tensor(x)\n    return np.array(1.0, x.dtype) / (np.array(1.0, x.dtype) + np.exp(-x))\n\n\ndef sparse_sigmoid(x):\n    x = convert_to_tensor(x)\n    return np.where(\n        x <= -1,\n        np.array(0.0, x.dtype),\n        np.where(\n            x >= 1, np.array(1.0, x.dtype), np.array(0.5 * (x + 1), x.dtype)\n        ),\n    )\n\n\ndef tanh(x):\n    return np.tanh(x)\n\n\ndef tanh_shrink(x):\n    x = convert_to_tensor(x)\n    return x - np.tanh(x)\n\n\ndef softplus(x):\n    x = convert_to_tensor(x)\n    return np.logaddexp(x, np.array(0.0, x.dtype))\n\n\ndef softsign(x):\n    x = convert_to_tensor(x)\n    return x / (np.array(1.0, x.dtype) + np.abs(x))\n\n\ndef soft_shrink(x, threshold=0.5):\n    return np.where(\n        x > threshold,\n        np.array(x - threshold, dtype=x.dtype),\n        np.where(\n            x < -threshold,\n            np.array(x + threshold, dtype=x.dtype),\n            np.array(0.0, dtype=x.dtype),\n        ),\n    )\n\n\ndef sparse_plus(x):\n    return np.where(\n        x <= -1,\n        np.zeros_like(x, dtype=x.dtype),\n        
np.where(x < 1, np.array((1 / 4) * (x + 1) ** 2, dtype=x.dtype), x),\n    )\n\n\ndef silu(x):\n    x = convert_to_tensor(x)\n    return x * sigmoid(x)\n\n\ndef squareplus(x, b=4):\n    x = convert_to_tensor(x)\n    b = convert_to_tensor(b, dtype=x.dtype)\n    y = x + np.sqrt(x**2 + b)\n    return y / 2\n\n\ndef log_sigmoid(x):\n    x = convert_to_tensor(x)\n    return -softplus(-x)\n\n\ndef leaky_relu(x, negative_slope=0.2):\n    x = convert_to_tensor(x)\n    return np.maximum(x, np.array(negative_slope, x.dtype) * x)\n\n\ndef hard_sigmoid(x):\n    # python numbers will be promoted to float64 by np, so it's necessary to\n    # first convert the python numbers to np scalars\n    x = x / np.array(6.0, x.dtype) + np.array(0.5, x.dtype)\n    return np.where(\n        x <= 0.0,\n        np.array(0.0, x.dtype),\n        np.where(x >= 1.0, np.array(1.0, x.dtype), x),\n    )\n\n\ndef hard_silu(x):\n    return x * hard_sigmoid(x)\n\n\ndef elu(x, alpha=1.0):\n    x = convert_to_tensor(x)\n    return np.where(\n        x >= np.array(0.0, x.dtype), x, np.array(alpha, x.dtype) * np.expm1(x)\n    )\n\n\ndef selu(x):\n    alpha = 1.6732632423543772848170429916717\n    scale = 1.0507009873554804934193349852946\n    x = convert_to_tensor(x)\n    return np.array(scale, x.dtype) * elu(x, alpha)\n\n\ndef gelu(x, approximate=True):\n    x = convert_to_tensor(x)\n    # follows JAX's implementation\n    if approximate:\n        sqrt_2_over_pi = np.sqrt(2 / np.pi).astype(x.dtype)\n        cdf = np.array(0.5, x.dtype) * (\n            np.array(1.0, x.dtype)\n            + np.tanh(\n                sqrt_2_over_pi\n                * (x + np.array(0.044715, x.dtype) * (x**3).astype(x.dtype))\n            )\n        )\n        return x * cdf\n    else:\n        sqrt_2 = np.sqrt(2).astype(x.dtype)\n        return (\n            x\n            * (scipy.special.erf(x / sqrt_2) + 1).astype(x.dtype)\n            / np.array(2, x.dtype)\n        )\n\n\ndef celu(x, alpha=1.0):\n    x = 
convert_to_tensor(x)\n    alpha = np.array(alpha, x.dtype)\n    return np.maximum(x, np.array(0.0, dtype=x.dtype)) + alpha * np.expm1(\n        np.minimum(x, np.array(0.0, dtype=x.dtype)) / alpha\n    )\n\n\ndef glu(x, axis=-1):\n    x = convert_to_tensor(x)\n    dtype = x.dtype\n    if x.shape[axis] % 2 != 0:\n        raise ValueError(\n            \"axis size must be divisible by 2. \"\n            f\"Received: x.shape={x.shape} with axis={axis}\"\n        )\n    x1, x2 = np.split(x, 2, axis)\n    return (x1 * sigmoid(x2)).astype(dtype)\n\n\ndef hard_tanh(x):\n    x = convert_to_tensor(x)\n    min_val = np.asarray(-1.0, x.dtype)\n    max_val = np.asarray(1.0, x.dtype)\n    return np.array(np.clip(x, min_val, max_val), dtype=x.dtype)\n\n\ndef hard_shrink(x, threshold=0.5):\n    x = convert_to_tensor(x)\n    threshold = np.asarray(threshold, x.dtype)\n    return np.array(\n        np.where(np.abs(x) > threshold, x, np.array(0.0, dtype=x.dtype)),\n        dtype=x.dtype,\n    )\n\n\ndef threshold(x, threshold, default_value):\n    x = convert_to_tensor(x)\n    return np.where(x > threshold, x, np.array(default_value, dtype=x.dtype))\n\n\ndef softmax(x, axis=-1):\n    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))\n    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)\n\n\ndef log_softmax(x, axis=-1):\n    max_x = np.max(x, axis=axis, keepdims=True)\n    logsumexp = np.log(np.exp(x - max_x).sum(axis=axis, keepdims=True))\n    return x - max_x - logsumexp\n\n\ndef sparsemax(x, axis=-1):\n    # Sort logits along the specified axis in descending order\n    logits = convert_to_tensor(x)\n    logits_sorted = -1.0 * np.sort(-1.0 * logits, axis=axis)\n    logits_cumsum = np.cumsum(logits_sorted, axis=axis)\n    r = np.arange(1, logits.shape[axis] + 1)\n    r_shape = [1] * logits.ndim\n    r_shape[axis] = -1  # Broadcast to match the target axis\n    r = r.reshape(r_shape)\n    support = logits_sorted - (logits_cumsum - 1) / r > 0\n    # Find the threshold\n    
k = np.sum(support, axis=axis, keepdims=True)\n    logits_cumsum_safe = np.where(support, logits_cumsum, 0.0)\n    tau = (np.sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k\n    output = np.maximum(logits - tau, 0.0)\n    return output\n\n\ndef _convert_to_spatial_operand(\n    x,\n    num_spatial_dims,\n    data_format=\"channels_last\",\n    include_batch_and_channels=True,\n):\n    # Helper function that converts an operand to a spatial operand.\n    x = (x,) * num_spatial_dims if isinstance(x, int) else x\n    if not include_batch_and_channels:\n        return x\n    if data_format == \"channels_last\":\n        x = (1,) + x + (1,)\n    else:\n        x = (1,) + (1,) + x\n    return x\n\n\ndef _pool(\n    inputs,\n    initial_value,\n    reduce_fn,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n):\n    \"\"\"Helper function to define pooling functions.\n\n    Args:\n        inputs: input data of shape `N+2`.\n        initial_value: the initial value for the reduction.\n        reduce_fn: a reduce function of the form `(T, T) -> T`.\n        pool_size: a sequence of `N` integers, representing the window size to\n            reduce over.\n        strides: a sequence of `N` integers, representing the inter-window\n            strides (default: `(1, ..., 1)`).\n        padding: either the string `same` or `valid`.\n\n    Returns:\n        The output of the reduction for each window slice.\n    \"\"\"\n    if padding not in (\"same\", \"valid\"):\n        raise ValueError(\n            f\"Invalid padding '{padding}', must be 'same' or 'valid'.\"\n        )\n    padding = padding.upper()\n    return np.array(\n        lax.reduce_window(\n            inputs,\n            initial_value,\n            reduce_fn,\n            pool_size,\n            strides,\n            padding,\n        )\n    )\n\n\ndef max_pool(\n    inputs,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n):\n    data_format = 
backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.ndim - 2\n    pool_size = _convert_to_spatial_operand(\n        pool_size, num_spatial_dims, data_format\n    )\n    strides = pool_size if strides is None else strides\n    strides = _convert_to_spatial_operand(\n        strides, num_spatial_dims, data_format\n    )\n    return _pool(inputs, -np.inf, lax.max, pool_size, strides, padding)\n\n\ndef average_pool(\n    inputs,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.ndim - 2\n    pool_size = _convert_to_spatial_operand(\n        pool_size, num_spatial_dims, data_format\n    )\n    strides = pool_size if strides is None else strides\n    strides = _convert_to_spatial_operand(\n        strides, num_spatial_dims, data_format\n    )\n\n    pooled = _pool(inputs, 0.0, lax.add, pool_size, strides, padding)\n    if padding == \"valid\":\n        # Avoid the extra reduce_window.\n        return pooled / np.prod(pool_size)\n    else:\n        # Count the number of valid entries at each input point, then use that\n        # for computing average. Assumes that any two arrays of same shape will\n        # be padded the same. 
Avoid broadcasting on axis where pooling is\n        # skipped.\n        shape = [\n            (a if b != 1 else 1) for (a, b) in zip(inputs.shape, pool_size)\n        ]\n        window_counts = _pool(\n            np.ones(shape, inputs.dtype),\n            0.0,\n            lax.add,\n            pool_size,\n            strides,\n            padding,\n        )\n        return pooled / window_counts\n\n\ndef _compute_adaptive_pooling_gather_indices(\n    input_dim, output_size, big_window\n):\n    window_starts = np.floor(\n        (np.arange(output_size) * input_dim) / output_size\n    ).astype(np.int32)\n\n    window_ends = np.ceil(\n        (np.arange(1, output_size + 1) * input_dim) / output_size\n    ).astype(np.int32)\n\n    window_sizes = window_ends - window_starts\n    is_big = window_sizes == big_window\n\n    small_window = big_window - 1\n    small_pool_len = input_dim - small_window + 1\n\n    small_indices = window_starts\n    big_indices = window_starts + small_pool_len\n\n    gather = np.where(is_big, big_indices, small_indices)\n    return gather.astype(np.int32)\n\n\ndef _strided_view_1d(x, window_size):\n    n, l, c = x.shape\n    out = l - window_size + 1\n\n    strides = x.strides\n    shape = (n, out, window_size, c)\n    new_strides = (strides[0], strides[1], strides[1], strides[2])\n\n    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=new_strides)\n\n\ndef _adaptive_pool1d_impl(inputs, output_size, mode, data_format):\n    if isinstance(output_size, int):\n        output_size = (output_size,)\n\n    if data_format == \"channels_first\":\n        inputs = np.transpose(inputs, (0, 2, 1))\n\n    n, l, c = inputs.shape\n    out_l = output_size[0]\n\n    small, big = compute_adaptive_pooling_window_sizes(l, out_l)\n    gather = _compute_adaptive_pooling_gather_indices(l, out_l, big)\n\n    sv_small = _strided_view_1d(inputs, small)\n    small_pool = (\n        np.mean(sv_small, axis=2)\n        if mode == \"average\"\n        
else np.max(sv_small, axis=2)\n    )\n\n    sv_big = _strided_view_1d(inputs, big)\n    big_pool = (\n        np.mean(sv_big, axis=2) if mode == \"average\" else np.max(sv_big, axis=2)\n    )\n\n    combined = np.concatenate([small_pool, big_pool], axis=1)\n    out = combined[:, gather, :]\n\n    if data_format == \"channels_first\":\n        out = np.transpose(out, (0, 2, 1))\n\n    return out\n\n\ndef _adaptive_pool2d_impl(inputs, output_size, mode, data_format):\n    if isinstance(output_size, int):\n        output_size = (output_size, output_size)\n\n    if data_format == \"channels_first\":\n        inputs = np.transpose(inputs, (0, 2, 3, 1))\n\n    n, h, w, c = inputs.shape\n    out_h, out_w = output_size\n\n    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)\n    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)\n\n    x_h = np.transpose(inputs, (0, 2, 1, 3)).reshape(n * w, h, c)\n\n    sv_small_h = _strided_view_1d(x_h, small_h)\n    small_pool_h = (\n        np.mean(sv_small_h, axis=2)\n        if mode == \"average\"\n        else np.max(sv_small_h, axis=2)\n    )\n\n    sv_big_h = _strided_view_1d(x_h, big_h)\n    big_pool_h = (\n        np.mean(sv_big_h, axis=2)\n        if mode == \"average\"\n        else np.max(sv_big_h, axis=2)\n    )\n\n    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)\n    pooled_h = combined_h[:, gather_h, :]\n\n    pooled_h = pooled_h.reshape(n, w, out_h, c)\n    pooled_h = np.transpose(pooled_h, (0, 2, 1, 3))\n\n    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)\n    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)\n\n    x_w = pooled_h.reshape(n * out_h, w, c)\n\n    sv_small_w = _strided_view_1d(x_w, small_w)\n    small_pool_w = (\n        np.mean(sv_small_w, axis=2)\n        if mode == \"average\"\n        else np.max(sv_small_w, axis=2)\n    )\n\n    sv_big_w = _strided_view_1d(x_w, big_w)\n    big_pool_w = (\n        
np.mean(sv_big_w, axis=2)\n        if mode == \"average\"\n        else np.max(sv_big_w, axis=2)\n    )\n\n    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)\n    out = combined_w[:, gather_w, :].reshape(n, out_h, out_w, c)\n\n    if data_format == \"channels_first\":\n        out = np.transpose(out, (0, 3, 1, 2))\n\n    return out\n\n\ndef _adaptive_pool3d_impl(inputs, output_size, mode, data_format):\n    if isinstance(output_size, int):\n        output_size = (output_size, output_size, output_size)\n\n    if data_format == \"channels_first\":\n        inputs = np.transpose(inputs, (0, 2, 3, 4, 1))\n\n    n, d, h, w, c = inputs.shape\n    out_d, out_h, out_w = output_size\n\n    small_d, big_d = compute_adaptive_pooling_window_sizes(d, out_d)\n    gather_d = _compute_adaptive_pooling_gather_indices(d, out_d, big_d)\n\n    x_d = np.transpose(inputs, (0, 2, 3, 1, 4)).reshape(n * h * w, d, c)\n\n    sv_small_d = _strided_view_1d(x_d, small_d)\n    small_pool_d = (\n        np.mean(sv_small_d, axis=2)\n        if mode == \"average\"\n        else np.max(sv_small_d, axis=2)\n    )\n\n    sv_big_d = _strided_view_1d(x_d, big_d)\n    big_pool_d = (\n        np.mean(sv_big_d, axis=2)\n        if mode == \"average\"\n        else np.max(sv_big_d, axis=2)\n    )\n\n    combined_d = np.concatenate([small_pool_d, big_pool_d], axis=1)\n    pooled_d = combined_d[:, gather_d, :].reshape(n, h, w, out_d, c)\n    pooled_d = np.transpose(pooled_d, (0, 3, 1, 2, 4))\n\n    small_h, big_h = compute_adaptive_pooling_window_sizes(h, out_h)\n    gather_h = _compute_adaptive_pooling_gather_indices(h, out_h, big_h)\n\n    x_h = np.transpose(pooled_d, (0, 1, 3, 2, 4)).reshape(n * out_d * w, h, c)\n\n    sv_small_h = _strided_view_1d(x_h, small_h)\n    small_pool_h = (\n        np.mean(sv_small_h, axis=2)\n        if mode == \"average\"\n        else np.max(sv_small_h, axis=2)\n    )\n\n    sv_big_h = _strided_view_1d(x_h, big_h)\n    big_pool_h = (\n        
np.mean(sv_big_h, axis=2)\n        if mode == \"average\"\n        else np.max(sv_big_h, axis=2)\n    )\n\n    combined_h = np.concatenate([small_pool_h, big_pool_h], axis=1)\n    pooled_h = combined_h[:, gather_h, :].reshape(n, out_d, w, out_h, c)\n    pooled_h = np.transpose(pooled_h, (0, 1, 3, 2, 4))\n\n    small_w, big_w = compute_adaptive_pooling_window_sizes(w, out_w)\n    gather_w = _compute_adaptive_pooling_gather_indices(w, out_w, big_w)\n\n    x_w = pooled_h.reshape(n * out_d * out_h, w, c)\n\n    sv_small_w = _strided_view_1d(x_w, small_w)\n    small_pool_w = (\n        np.mean(sv_small_w, axis=2)\n        if mode == \"average\"\n        else np.max(sv_small_w, axis=2)\n    )\n\n    sv_big_w = _strided_view_1d(x_w, big_w)\n    big_pool_w = (\n        np.mean(sv_big_w, axis=2)\n        if mode == \"average\"\n        else np.max(sv_big_w, axis=2)\n    )\n\n    combined_w = np.concatenate([small_pool_w, big_pool_w], axis=1)\n    out = combined_w[:, gather_w, :].reshape(n, out_d, out_h, out_w, c)\n\n    if data_format == \"channels_first\":\n        out = np.transpose(out, (0, 4, 1, 2, 3))\n\n    return out\n\n\ndef adaptive_average_pool(inputs, output_size, data_format=None):\n    data_format = backend.standardize_data_format(data_format)\n    dims = inputs.ndim - 2\n    if dims == 1:\n        return _adaptive_pool1d_impl(\n            inputs, output_size, \"average\", data_format\n        )\n    if dims == 2:\n        return _adaptive_pool2d_impl(\n            inputs, output_size, \"average\", data_format\n        )\n    if dims == 3:\n        return _adaptive_pool3d_impl(\n            inputs, output_size, \"average\", data_format\n        )\n    raise ValueError(\"adaptive_average_pool supports only 1D/2D/3D\")\n\n\ndef adaptive_max_pool(inputs, output_size, data_format=None):\n    data_format = backend.standardize_data_format(data_format)\n    dims = inputs.ndim - 2\n    if dims == 1:\n        return _adaptive_pool1d_impl(inputs, output_size, \"max\", 
data_format)\n    if dims == 2:\n        return _adaptive_pool2d_impl(inputs, output_size, \"max\", data_format)\n    if dims == 3:\n        return _adaptive_pool3d_impl(inputs, output_size, \"max\", data_format)\n    raise ValueError(\"adaptive_max_pool supports only 1D/2D/3D\")\n\n\ndef _convert_to_lax_conv_dimension_numbers(\n    num_spatial_dims,\n    data_format=\"channels_last\",\n    transpose=False,\n):\n    \"\"\"Create a `lax.ConvDimensionNumbers` for the given inputs.\"\"\"\n    num_dims = num_spatial_dims + 2\n\n    if data_format == \"channels_last\":\n        spatial_dims = tuple(range(1, num_dims - 1))\n        inputs_dn = (0, num_dims - 1) + spatial_dims\n    else:\n        spatial_dims = tuple(range(2, num_dims))\n        inputs_dn = (0, 1) + spatial_dims\n\n    if transpose:\n        kernel_dn = (num_dims - 2, num_dims - 1) + tuple(range(num_dims - 2))\n    else:\n        kernel_dn = (num_dims - 1, num_dims - 2) + tuple(range(num_dims - 2))\n\n    return lax.ConvDimensionNumbers(\n        lhs_spec=inputs_dn, rhs_spec=kernel_dn, out_spec=inputs_dn\n    )\n\n\ndef conv(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.ndim - 2\n    dimension_numbers = _convert_to_lax_conv_dimension_numbers(\n        num_spatial_dims,\n        data_format,\n        transpose=False,\n    )\n    strides = _convert_to_spatial_operand(\n        strides,\n        num_spatial_dims,\n        data_format,\n        include_batch_and_channels=False,\n    )\n    dilation_rate = _convert_to_spatial_operand(\n        dilation_rate,\n        num_spatial_dims,\n        data_format,\n        include_batch_and_channels=False,\n    )\n    if data_format == \"channels_last\":\n        channels = inputs.shape[-1]\n    else:\n        channels = inputs.shape[1]\n    kernel_in_channels = kernel.shape[-2]\n    if channels % 
kernel_in_channels > 0:\n        raise ValueError(\n            \"The number of input channels must be evenly divisible by \"\n            f\"kernel's in_channels. Received input channels {channels} and \"\n            f\"kernel in_channels {kernel_in_channels}. \"\n        )\n    feature_group_count = channels // kernel_in_channels\n    result = np.array(\n        jax.lax.conv_general_dilated(\n            inputs,\n            kernel if is_tensor(kernel) else kernel.numpy(),\n            strides,\n            padding,\n            rhs_dilation=dilation_rate,\n            dimension_numbers=dimension_numbers,\n            feature_group_count=feature_group_count,\n        )\n    )\n    if result.size == 0:\n        raise ValueError(\n            \"The convolution operation resulted in an empty output. \"\n            \"This can happen if the input is too small for the given \"\n            \"kernel size, strides, dilation rate, and padding mode. \"\n            \"Please check the input shape and convolution parameters.\"\n        )\n    return result\n\n\ndef depthwise_conv(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.ndim - 2\n    dimension_numbers = _convert_to_lax_conv_dimension_numbers(\n        num_spatial_dims,\n        data_format,\n        transpose=False,\n    )\n    strides = _convert_to_spatial_operand(\n        strides,\n        num_spatial_dims,\n        data_format,\n        include_batch_and_channels=False,\n    )\n    dilation_rate = _convert_to_spatial_operand(\n        dilation_rate,\n        num_spatial_dims,\n        data_format,\n        include_batch_and_channels=False,\n    )\n    feature_group_count = (\n        inputs.shape[-1] if data_format == \"channels_last\" else inputs.shape[1]\n    )\n    kernel = np.reshape(\n        kernel if is_tensor(kernel) else kernel.numpy(),\n      
  kernel.shape[:-2] + (1, feature_group_count * kernel.shape[-1]),\n    )\n    return np.array(\n        jax.lax.conv_general_dilated(\n            inputs,\n            kernel,\n            strides,\n            padding,\n            rhs_dilation=dilation_rate,\n            dimension_numbers=dimension_numbers,\n            feature_group_count=feature_group_count,\n        )\n    )\n\n\ndef separable_conv(\n    inputs,\n    depthwise_kernel,\n    pointwise_kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    data_format = backend.standardize_data_format(data_format)\n    depthwise_conv_output = depthwise_conv(\n        inputs,\n        depthwise_kernel,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n    )\n    return conv(\n        depthwise_conv_output,\n        pointwise_kernel,\n        strides=1,\n        padding=\"valid\",\n        data_format=data_format,\n        dilation_rate=dilation_rate,\n    )\n\n\ndef conv_transpose(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    output_padding=None,\n    data_format=None,\n    dilation_rate=1,\n):\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.ndim - 2\n    padding_values = compute_conv_transpose_padding_args_for_jax(\n        input_shape=inputs.shape,\n        kernel_shape=kernel.shape,\n        strides=strides,\n        padding=padding,\n        output_padding=output_padding,\n        dilation_rate=dilation_rate,\n    )\n    dimension_numbers = _convert_to_lax_conv_dimension_numbers(\n        num_spatial_dims,\n        data_format,\n        transpose=False,\n    )\n    strides = _convert_to_spatial_operand(\n        strides,\n        num_spatial_dims,\n        data_format,\n        include_batch_and_channels=False,\n    )\n    dilation_rate = _convert_to_spatial_operand(\n        dilation_rate,\n        num_spatial_dims,\n        data_format,\n        
include_batch_and_channels=False,\n    )\n\n    return np.array(\n        jax.lax.conv_transpose(\n            inputs,\n            kernel if is_tensor(kernel) else kernel.numpy(),\n            strides,\n            padding=padding_values,\n            rhs_dilation=dilation_rate,\n            dimension_numbers=dimension_numbers,\n            transpose_kernel=True,\n        )\n    )\n\n\ndef one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):\n    if sparse:\n        raise ValueError(\"Unsupported value `sparse=True` with numpy backend\")\n    if dtype is None:\n        dtype = \"float32\"\n    x = convert_to_tensor(x)\n    input_shape = x.shape\n\n    x = x.reshape(-1)\n    if not num_classes:\n        num_classes = np.max(x) + 1\n\n    batch_size = x.shape[0]\n    categorical = np.zeros((batch_size, num_classes), dtype=dtype)\n    valid_indices = x >= 0\n    categorical[np.arange(batch_size)[valid_indices], x[valid_indices]] = 1\n\n    # First, reshape the array with the extra dimension at the end\n    output_shape = input_shape + (num_classes,)\n    categorical = np.reshape(categorical, output_shape)\n\n    # Then, move this new dimension to the right place (according to axis)\n    if axis != -1:\n        categorical = np.moveaxis(categorical, -1, axis)\n\n    return categorical\n\n\ndef multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):\n    if sparse:\n        raise ValueError(\"Unsupported value `sparse=True` with numpy backend\")\n    x = convert_to_tensor(x)\n    reduction_axis = 1 if len(x.shape) > 1 else 0\n    outputs = np.max(\n        one_hot(cast(x, \"int32\"), num_classes, axis=axis, dtype=dtype),\n        axis=reduction_axis,\n    )\n    return outputs\n\n\ndef categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    target = np.array(target)\n    output = np.array(output)\n\n    if target.shape != output.shape:\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same shape. 
\"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n    if len(target.shape) < 1:\n        raise ValueError(\n            \"Arguments `target` and `output` must be at least rank 1. \"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n\n    if from_logits:\n        log_prob = log_softmax(output, axis=axis)\n    else:\n        output = output / np.sum(output, axis, keepdims=True)\n        output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon())\n        log_prob = np.log(output)\n    return -np.sum(target * log_prob, axis=axis)\n\n\ndef sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    target = np.array(target, dtype=\"int32\")\n    output = np.array(output)\n    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:\n        target = np.squeeze(target, axis=-1)\n\n    if len(output.shape) < 1:\n        raise ValueError(\n            \"Argument `output` must be at least rank 1. 
\"\n            \"Received: \"\n            f\"output.shape={output.shape}\"\n        )\n    if target.shape != output.shape[:-1]:\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same shape \"\n            \"up until the last dimension: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n    if from_logits:\n        log_prob = log_softmax(output, axis=axis)\n    else:\n        output = output / np.sum(output, axis, keepdims=True)\n        output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon())\n        log_prob = np.log(output)\n    target = one_hot(target, output.shape[axis], axis=axis)\n    return -np.sum(target * log_prob, axis=axis)\n\n\ndef binary_crossentropy(target, output, from_logits=False):\n    target = np.array(target)\n    output = np.array(output)\n\n    if target.shape != output.shape:\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same shape. \"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n\n    if from_logits:\n        output = sigmoid(output)\n\n    output = np.clip(output, backend.epsilon(), 1.0 - backend.epsilon())\n    bce = target * np.log(output)\n    bce += (1.0 - target) * np.log(1.0 - output)\n    return -bce\n\n\ndef moments(x, axes, keepdims=False, synchronized=False):\n    if synchronized:\n        raise NotImplementedError(\n            \"Argument synchronized=True is not supported with NumPy.\"\n        )\n    axes = tuple(axes) if isinstance(axes, list) else axes\n    # The dynamic range of float16 is too limited for statistics. 
As a\n    # workaround, we simply perform the operations on float32 and convert back\n    # to float16\n    need_cast = False\n    ori_dtype = backend.standardize_dtype(x.dtype)\n    if ori_dtype == \"float16\":\n        need_cast = True\n        x = cast(x, \"float32\")\n\n    mean = np.mean(x, axes, keepdims=True)\n\n    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster\n    # but less numerically stable.\n    variance = np.mean(np.square(x), axis=axes, keepdims=True) - np.square(mean)\n\n    if not keepdims:\n        mean = np.squeeze(mean, axes)\n        variance = np.squeeze(variance, axes)\n    if need_cast:\n        # avoid overflow and underflow when casting from float32 back to float16\n        mean = np.clip(mean, np.finfo(np.float16).min, np.finfo(np.float16).max)\n        variance = np.clip(\n            variance, np.finfo(np.float16).min, np.finfo(np.float16).max\n        )\n        mean = cast(mean, ori_dtype)\n        variance = cast(variance, ori_dtype)\n    return mean, variance\n\n\ndef batch_normalization(\n    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3\n):\n    shape = [1] * len(x.shape)\n    shape[axis] = mean.shape[0]\n    mean = np.reshape(mean, shape)\n    variance = np.reshape(variance, shape)\n\n    inv = 1.0 / np.sqrt(variance + epsilon)\n    if scale is not None:\n        scale = np.reshape(scale, shape)\n        inv = inv * scale\n\n    res = -mean * inv\n    if offset is not None:\n        offset = np.reshape(offset, shape)\n        res = res + offset\n\n    return x * inv + res\n\n\ndef ctc_loss(target, output, target_length, output_length, mask_index=0):\n    # Ref: https://github.com/google-deepmind/optax\n    # optax.ctc_loss_with_forward_probs\n    target = convert_to_tensor(target, dtype=\"int32\")\n    output = convert_to_tensor(output)\n    target_length = convert_to_tensor(target_length, \"int32\")\n    output_length = convert_to_tensor(output_length, \"int32\")\n    batch_size, 
max_input_length, num_classes = output.shape\n    batch_size, max_label_length = target.shape\n    log_epsilon = -1e5\n\n    # Ensure that the dtype promotion behavior matches that of `tf.nn.ctc_loss`\n    dtype = backend.result_type(output.dtype, \"float32\")\n    output = output.astype(dtype)\n\n    def _lengths_to_paddings(lengths, max_length):\n        indices = np.arange(max_length).reshape(\n            (1,) * lengths.ndim + (max_length,)\n        )\n        lengths = np.expand_dims(lengths, axis=-1)\n        elem_valid = indices < lengths\n        return np.logical_not(elem_valid)\n\n    target_paddings = _lengths_to_paddings(target_length, max_label_length)\n    output_paddings = _lengths_to_paddings(output_length, max_input_length)\n    target_paddings = target_paddings.astype(output.dtype)\n    output_paddings = output_paddings.astype(output.dtype)\n\n    logprobs = log_softmax(output, axis=-1)\n    label_lengths = max_label_length - np.sum(target_paddings, axis=1).astype(\n        np.int32\n    )\n\n    # repeat[b, n] == 1.0 when label[b, n] == label[b, n+1].\n    repeat = (target[:, :-1] == target[:, 1:]).astype(np.float32)\n    repeat = np.pad(repeat, ((0, 0), (0, 1)))\n\n    logprobs_phi = logprobs[:, :, mask_index : mask_index + 1]  # [B, T, 1]\n    logprobs_phi = np.transpose(logprobs_phi, (1, 0, 2))  # [T, B, 1]\n\n    _one_hot = one_hot(target, num_classes=num_classes)  # [B, N, K]\n    logprobs_emit = np.einsum(\"btk,bnk->btn\", logprobs, _one_hot)\n    logprobs_emit = np.transpose(logprobs_emit, (1, 0, 2))  # [T, B, N]\n\n    # [B, N]\n    logalpha_phi_init = (\n        np.ones((batch_size, max_label_length + 1), dtype=output.dtype)\n        * log_epsilon\n    )\n    logalpha_phi_init[:, 0] = 0.0\n    logalpha_emit_init = (\n        np.ones((batch_size, max_label_length), dtype=output.dtype)\n        * log_epsilon\n    )\n\n    def update_phi_score(phi, added_score):\n        # Update `phi[:, 1:]`` with adding `added_score` in log space.\n       
 return np.concatenate(\n            [phi[:, :1], np.logaddexp(phi[:, 1:], added_score)], axis=-1\n        )\n\n    def loop_body(prev, x):\n        prev_phi, prev_emit = prev\n        # emit-to-phi epsilon transition, except if the next label is repetition\n        prev_phi_orig = prev_phi\n        prev_phi = update_phi_score(prev_phi, prev_emit + log_epsilon * repeat)\n\n        logprob_emit, logprob_phi, pad = x\n\n        # phi-to-emit transition\n        next_emit = np.logaddexp(\n            prev_phi[:, :-1] + logprob_emit, prev_emit + logprob_emit\n        )\n        # self-loop transition\n        next_phi = prev_phi + logprob_phi\n        # emit-to-phi blank transition only when the next label is repetition\n        next_phi = update_phi_score(\n            next_phi, prev_emit + logprob_phi + log_epsilon * (1.0 - repeat)\n        )\n\n        pad = pad.reshape((batch_size, 1))\n        next_emit = pad * prev_emit + (1.0 - pad) * next_emit\n        next_phi = pad * prev_phi_orig + (1.0 - pad) * next_phi\n\n        return (next_phi, next_emit), (next_phi, next_emit)\n\n    def np_scan(f, init, xs):\n        carry = init\n        ys = []\n        for x in zip(*xs):\n            carry, y = f(carry, x)\n            ys.append(y)\n        result = []\n        for i in range(len(ys[0])):\n            result.append(np.stack([y[i] for y in ys]))\n        return carry, result\n\n    xs = (logprobs_emit, logprobs_phi, output_paddings.transpose((1, 0)))\n    _, (logalpha_phi, logalpha_emit) = np_scan(\n        loop_body, (logalpha_phi_init, logalpha_emit_init), xs\n    )\n\n    # last row needs to be updated with the last epsilon transition\n    logalpha_phi_last = update_phi_score(logalpha_phi[-1], logalpha_emit[-1])\n    logalpha_phi[-1] = logalpha_phi_last\n\n    # extract per_seq_loss\n    # [B, N+1]\n    _one_hot = one_hot(label_lengths, num_classes=max_label_length + 1)\n    per_seq_loss = -np.einsum(\"bn,bn->b\", logalpha_phi_last, _one_hot)\n    return 
per_seq_loss\n\n\ndef _ctc_greedy_decode(\n    inputs,\n    sequence_lengths,\n    merge_repeated=True,\n    mask_index=None,\n):\n    inputs = convert_to_tensor(inputs)\n    sequence_lengths = convert_to_tensor(sequence_lengths, dtype=\"int32\")\n    batch_size, max_length, num_classes = inputs.shape\n\n    if mask_index is None:\n        mask_index = num_classes - 1\n\n    indices = np.argmax(inputs, axis=-1).astype(\"int32\")\n    scores = np.max(inputs, axis=-1)\n\n    seqlen_mask = np.arange(max_length)[None, :]\n    seqlen_mask = seqlen_mask >= sequence_lengths[:, None]\n\n    indices = np.where(seqlen_mask, mask_index, indices)\n    scores = np.where(seqlen_mask, 0.0, scores)\n\n    if merge_repeated:\n        repeat_mask = indices[:, 1:] == indices[:, :-1]\n        repeat_mask = np.pad(repeat_mask, ((0, 0), (1, 0)))\n        indices = np.where(repeat_mask, mask_index, indices)\n\n    # We set to -1 for blank labels\n    invalid_mask = indices == mask_index\n    indices = np.where(invalid_mask, -1, indices)\n\n    # We rearrange the indices by moving `mask_index` to the end of the array\n    order = np.expand_dims(np.arange(max_length), axis=0)  # [1, N]\n    order = np.tile(order, (batch_size, 1))  # [B, N]\n    order = np.where(invalid_mask, max_length, order)\n    order = np.argsort(order, axis=-1)\n    indices = np.take_along_axis(indices, order, axis=-1)\n\n    scores = -np.sum(scores, axis=1)[:, None]\n    indices = np.expand_dims(indices, axis=0)\n    return indices, scores\n\n\ndef _ctc_beam_search_decode(\n    inputs,\n    sequence_lengths,\n    beam_width=100,\n    top_paths=1,\n    mask_index=None,\n):\n    inputs = convert_to_tensor(inputs)\n    sequence_lengths = convert_to_tensor(sequence_lengths)\n\n    batch_size, max_seq_len, num_classes = inputs.shape\n    inputs = log_softmax(inputs, axis=-1)\n    seqlen_mask = np.arange(max_seq_len)[None, :] >= sequence_lengths[:, None]\n\n    if mask_index is None:\n        mask_index = num_classes - 
1\n\n    # This is a workaround for the fact that np.argsort does not support\n    # the order parameter which is used to break ties when scores are equal.\n    # For compatibility with the tensorflow implementation, we flip the inputs\n    # and the mask_index, and then flip the classes back to the correct indices\n    inputs = np.flip(inputs, axis=2)\n    mask_index = num_classes - mask_index - 1\n\n    _pad = -1\n\n    init_paths = np.full(\n        (batch_size, 2 * beam_width, max_seq_len), _pad, dtype=np.int32\n    )\n\n    num_init_paths = np.min(np.array([num_classes, beam_width]))\n    max_classes = np.argsort(inputs[:, 0], axis=1)[:, -num_init_paths:]\n    init_classes = np.where(max_classes == mask_index, _pad, max_classes)\n    init_paths[:, :num_init_paths, 0] = init_classes\n\n    init_scores = np.full(\n        (batch_size, 2 * beam_width), -np.inf, dtype=inputs.dtype\n    )\n    init_scores[:, :num_init_paths] = np.take_along_axis(\n        inputs[:, 0], max_classes, axis=1\n    )\n    init_masked = init_paths[:, :, 0] == _pad\n\n    def _extend_paths(paths, scores, masked, x):\n        paths = np.repeat(paths, num_classes, axis=0)\n        scores = np.repeat(scores, num_classes)\n        masked = np.repeat(masked, num_classes)\n\n        path_tail_index = np.argmax(paths == _pad, axis=1)\n        paths_arange = np.arange(2 * beam_width * num_classes)\n        path_tails = paths[paths_arange, path_tail_index - 1]\n        path_tails = np.where(path_tail_index == 0, _pad, path_tails)\n\n        classes = np.arange(num_classes)\n        classes[mask_index] = _pad\n        classes = np.tile(classes, 2 * beam_width)\n\n        prev_masked = masked\n        masked = classes == _pad\n\n        masked_repeat = ~prev_masked & (path_tails == classes)\n        classes = np.where(masked_repeat, _pad, classes)\n        paths[paths_arange, path_tail_index] = classes\n\n        x = np.tile(x, 2 * beam_width)\n        scores = scores + x\n\n        return paths, 
scores, masked\n\n    def _merge_scores(unique_inverse, scores):\n        scores_max = np.max(scores)\n        scores_exp = np.exp(scores - scores_max)\n        scores = np.zeros_like(scores)\n        for i, u in enumerate(unique_inverse):\n            scores[u] += scores_exp[i]\n        scores = np.log(scores) + scores_max\n        return scores\n\n    def _prune_paths(paths, scores, masked):\n        paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0)\n        pad_size = (2 * num_classes * beam_width) - len(paths)\n        if pad_size > 0:\n            paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad)\n        paths = paths[: 2 * num_classes * beam_width]\n        if len(unique_inverse.shape) >= 2:\n            unique_inverse = np.squeeze(unique_inverse, axis=1)\n\n        emit_scores = np.where(masked, -np.inf, scores)\n        mask_scores = np.where(masked, scores, -np.inf)\n\n        emit_scores = _merge_scores(unique_inverse, emit_scores)\n        mask_scores = _merge_scores(unique_inverse, mask_scores)\n\n        total_scores = np.logaddexp(emit_scores, mask_scores)\n        top_indices = np.argsort(total_scores, kind=\"stable\")[-beam_width:]\n\n        paths = paths[top_indices]\n        emit_scores = emit_scores[top_indices]\n        mask_scores = mask_scores[top_indices]\n\n        paths = np.tile(paths, (2, 1))\n        scores = np.concatenate([emit_scores, mask_scores])\n        masked = np.concatenate(\n            [np.zeros(beam_width, bool), np.ones(beam_width, bool)]\n        )\n\n        return paths, scores, masked\n\n    def _decode_step(paths, scores, masked, x):\n        paths, scores, masked = _extend_paths(paths, scores, masked, x)\n        paths, scores, masked = _prune_paths(paths, scores, masked)\n        return paths, scores, masked\n\n    def _step(prev, x):\n        paths, scores, masked = prev\n        x, seqlen_mask = x\n        if not seqlen_mask:\n            paths, scores, masked = 
_decode_step(paths, scores, masked, x)\n        return (paths, scores, masked), None\n\n    def _decode_batch(\n        init_paths, init_scores, init_masked, inputs, seqlen_mask\n    ):\n        def np_scan_only_carry(f, init, xs):\n            carry = init\n            for x in zip(*xs):\n                carry, y = f(carry, x)\n            return carry, None\n\n        (paths, scores, masked), _ = np_scan_only_carry(\n            _step,\n            (init_paths, init_scores, init_masked),\n            (inputs[1:], seqlen_mask[1:]),\n        )\n\n        paths, unique_inverse = np.unique(paths, return_inverse=True, axis=0)\n        pad_size = (2 * num_classes * beam_width) - len(paths)\n        if pad_size > 0:\n            paths = np.pad(paths, [[0, pad_size], [0, 0]], constant_values=_pad)\n        paths = paths[: 2 * num_classes * beam_width]\n        if len(unique_inverse.shape) >= 2:\n            unique_inverse = np.squeeze(unique_inverse, axis=1)\n        scores = _merge_scores(unique_inverse, scores)\n\n        top_indices = np.argsort(scores)[-top_paths:][::-1]\n        paths = paths[top_indices]\n        scores = scores[top_indices]\n\n        return paths, scores\n\n    results = [\n        _decode_batch(p, s, m, i, sm)\n        for p, s, m, i, sm in zip(\n            init_paths, init_scores, init_masked, inputs, seqlen_mask\n        )\n    ]\n    paths = np.stack([r[0] for r in results])\n    scores = np.stack([r[1] for r in results])\n\n    # convert classes back to the correct indices\n    paths = np.where(paths == _pad, _pad, num_classes - paths - 1)\n    paths = np.transpose(paths, [1, 0, 2])\n    return paths, scores\n\n\ndef ctc_decode(\n    inputs,\n    sequence_lengths,\n    strategy=\"greedy\",\n    beam_width=100,\n    top_paths=1,\n    merge_repeated=True,\n    mask_index=0,\n):\n    inputs = convert_to_tensor(inputs)\n    dtype = backend.result_type(inputs.dtype, \"float32\")\n    inputs = cast(inputs, dtype)\n\n    if strategy == 
\"greedy\":\n        return _ctc_greedy_decode(\n            inputs,\n            sequence_lengths,\n            merge_repeated=merge_repeated,\n            mask_index=mask_index,\n        )\n    elif strategy == \"beam_search\":\n        return _ctc_beam_search_decode(\n            inputs,\n            sequence_lengths,\n            beam_width=beam_width,\n            top_paths=top_paths,\n            mask_index=mask_index,\n        )\n    else:\n        raise ValueError(\n            f\"Invalid strategy {strategy}. Supported values are \"\n            \"'greedy' and 'beam_search'.\"\n        )\n\n\ndef psnr(x1, x2, max_val):\n    if x1.shape != x2.shape:\n        raise ValueError(\n            f\"Input shapes {x1.shape} and {x2.shape} must \"\n            \"match for PSNR calculation. \"\n        )\n\n    max_val = convert_to_tensor(max_val, dtype=x2.dtype)\n    mse = np.mean(np.square(x1 - x2))\n    psnr = 20 * np.log10(max_val) - 10 * np.log10(mse)\n    return psnr\n\n\ndef _get_large_negative(dtype):\n    dtype = backend.standardize_dtype(dtype)\n    val = 65500.0 if dtype == \"float16\" else 3.38953e38\n    return np.asarray(val * -0.7, dtype=dtype)\n\n\ndef _apply_masks(logits, mask, is_causal):\n    if mask is None and not is_causal:\n        return logits\n\n    combined_mask = np.ones_like(logits, dtype=np.bool_)\n    if mask is not None:\n        combined_mask = np.logical_and(combined_mask, mask)\n\n    if is_causal:\n        T, S = logits.shape[2], logits.shape[3]\n        mask = np.tril(np.ones((T, S), dtype=np.bool_))\n        mask = mask[None, None, :, :]\n        combined_mask = np.logical_and(combined_mask, mask)\n\n    padded_logits = np.where(\n        combined_mask, logits, _get_large_negative(logits.dtype)\n    )\n    return padded_logits\n\n\ndef _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):\n    original_dtype = key.dtype\n    logits_dtype = np.promote_types(query.dtype, np.float32)\n    if 
backend.standardize_dtype(key.dtype) == \"bfloat16\":\n        # `np.einsum` doesn't support bfloat16\n        key = key.astype(\"float32\")\n        value = value.astype(\"float32\")\n    logits = np.einsum(\"BTNH,BSNH->BNTS\", query, key)\n    logits = logits.astype(logits_dtype)\n    logits *= np.array(scale, dtype=logits.dtype)\n\n    if bias is not None:\n        logits = (logits + bias).astype(logits.dtype)\n\n    padded_logits = _apply_masks(logits, mask, is_causal)\n\n    # Softmax is always carried out in fp32.\n    padded_logits = padded_logits.astype(np.float32)\n    probs = softmax(padded_logits, axis=-1).astype(original_dtype)\n    encoded_dtype = probs.dtype\n    if backend.standardize_dtype(probs.dtype) == \"bfloat16\":\n        # `np.einsum` doesn't support bfloat16\n        probs = probs.astype(\"float32\")\n        value = value.astype(\"float32\")\n    encoded = np.einsum(\"BNTS,BSNH->BTNH\", probs, value)\n    encoded = encoded.astype(encoded_dtype)\n    return encoded\n\n\ndef dot_product_attention(\n    query,\n    key,\n    value,\n    bias=None,\n    mask=None,\n    scale=None,\n    is_causal=False,\n    flash_attention=None,\n    attn_logits_soft_cap=None,\n):\n    if flash_attention is None:\n        flash_attention = False\n    if flash_attention:\n        raise ValueError(\"Flash attention is not supported in numpy backend.\")\n\n    # Ref: jax.nn.dot_product_attention\n    # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828\n    # Does not support `query_seq_lengths` and `key_value_seq_lengths` args\n    query = convert_to_tensor(query)\n    key = convert_to_tensor(key)\n    value = convert_to_tensor(value)\n    if len(query.shape) != 4:\n        raise ValueError(\n            \"`dot_product_attention` only supports 4D inputs. 
\"\n            f\"Received: query.shape={query.shape}, key.shape={key.shape}, \"\n            f\"value.shape={value.shape}.\"\n        )\n    compute_dtype = backend.result_type(query.dtype, key.dtype, value.dtype)\n    query = cast(query, compute_dtype)\n    key = cast(key, compute_dtype)\n    value = cast(value, compute_dtype)\n    if bias is not None:\n        bias = convert_to_tensor(bias, dtype=compute_dtype)\n\n    _, _, _, H = key.shape\n    scale = (1.0 / np.sqrt(H)) if scale is None else scale\n    return _dot_product_attention_xla(\n        query, key, value, bias, mask, is_causal, scale\n    )\n\n\ndef unfold(input, kernel_size, dilation=1, padding=0, stride=1):\n    \"\"\"NumPy implementation of Unfold.\n    Extract sliding local blocks from a **NCHW** batched image tensor.\n\n    Args:\n        input: 4-D tensor, shape (N, C, H, W)  **required**.\n        kernel_size: int or (kH, kW)\n        dilation: int or (dH, dW), default 1\n        padding: int or (pH, pW), default 0\n        stride: int or (sH, sW), default 1\n\n    Returns:\n        3-D tensor, shape (N, C*kH*kW, L)\n    \"\"\"\n\n    def _pair(x):\n        return (x, x) if isinstance(x, int) else x\n\n    k = _pair(kernel_size)\n    d = _pair(dilation)\n    p = _pair(padding)\n    s = _pair(stride)\n\n    N, C, H, W = input.shape\n\n    # ---- padding ----\n    if any(_ > 0 for _ in p):\n        input = np.pad(\n            input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])), mode=\"constant\"\n        )\n\n    # ----  spatial size ----\n    oH = (input.shape[2] - (k[0] - 1) * d[0] - 1) // s[0] + 1\n    oW = (input.shape[3] - (k[1] - 1) * d[1] - 1) // s[1] + 1\n\n    i0 = np.arange(0, oH) * s[0]\n    j0 = np.arange(0, oW) * s[1]\n    i, j = np.meshgrid(i0, j0, indexing=\"ij\")  # shape (oH, oW)\n    i = i.reshape(-1)\n    j = j.reshape(-1)\n\n    # ---- flatten patches ----\n    patches = np.empty((N, C, k[0], k[1], oH * oW), dtype=input.dtype)\n    for idx in range(k[0]):\n        for jdx in 
range(k[1]):\n            patches[:, :, idx, jdx, :] = input[\n                :, :, i + idx * d[0], j + jdx * d[1]\n            ]\n\n    # ---- reshape -> (N, C*kH*kW, L) ----\n    return patches.reshape(N, C * k[0] * k[1], -1)\n\n\ndef fold(x, output_size, kernel_size, dilation=1, padding=0, stride=1):\n    \"\"\"NumPy implementation of Fold (col2im).\n    Combine an array of sliding local blocks into a large tensor.\n\n    Args:\n        x: 3-D tensor, shape (N, C*kH*kW, L)  **required**.\n        output_size: int or (oH, oW)\n        kernel_size: int or (kH, kW)\n        dilation: int or (dH, dW), default 1\n        padding: int or (pH, pW), default 0\n        stride: int or (sH, sW), default 1\n\n    Returns:\n        4-D tensor, shape (N, C, oH, oW)\n    \"\"\"\n\n    def _pair(val):\n        return (val, val) if isinstance(val, int) else val\n\n    oH, oW = _pair(output_size)\n    kH, kW = _pair(kernel_size)\n    dH, dW = _pair(dilation)\n    pH, pW = _pair(padding)\n    sH, sW = _pair(stride)\n\n    N, CKK, L = x.shape\n    C = CKK // (kH * kW)\n\n    # Number of output patches along each dimension\n    nH = (oH + 2 * pH - dH * (kH - 1) - 1) // sH + 1\n    nW = (oW + 2 * pW - dW * (kW - 1) - 1) // sW + 1\n\n    # Reshape: (N, C*kH*kW, L) -> (N, C, kH, kW, nH, nW)\n    x = np.reshape(x, (N, C, kH, kW, nH, nW))\n\n    # Padded output size\n    oH_pad = oH + 2 * pH\n    oW_pad = oW + 2 * pW\n\n    output = np.zeros((N, C, oH_pad, oW_pad), dtype=x.dtype)\n\n    for i in range(kH):\n        for j in range(kW):\n            h_start = i * dH\n            w_start = j * dW\n            h_indices = h_start + np.arange(nH) * sH\n            w_indices = w_start + np.arange(nW) * sW\n            h_ix, w_ix = np.ix_(h_indices, w_indices)\n            output[:, :, h_ix, w_ix] += x[:, :, i, j, :, :]\n\n    # Remove padding\n    if pH > 0 or pW > 0:\n        output = output[:, :, pH : oH_pad - pH, pW : oW_pad - pW]\n\n    return output\n\n\ndef depth_to_space(x, block_size, 
data_format=\"channels_last\"):\n    \"\"\"NumPy implementation of depth_to_space (pixel shuffle).\n\n    Rearranges data from depth into blocks of spatial data.\n\n    Args:\n        x: 4-D tensor with shape (N, H, W, C) for channels_last or\n            (N, C, H, W) for channels_first.\n        block_size: An integer specifying the block size.\n        data_format: \"channels_last\" or \"channels_first\".\n\n    Returns:\n        A tensor with shape (N, H*block_size, W*block_size, C/block_size**2)\n        for channels_last or (N, C/block_size**2, H*block_size, W*block_size)\n        for channels_first.\n    \"\"\"\n    if data_format == \"channels_last\":\n        # NHWC format\n        n, h, w, c = x.shape\n        new_c = c // (block_size**2)\n        # Reshape: (N, H, W, C) -> (N, H, W, block_size, block_size, new_C)\n        x = np.reshape(x, (n, h, w, block_size, block_size, new_c))\n        # Transpose to (N, H, bH, W, bW, new_C) to interleave spatial blocks.\n        x = np.transpose(x, (0, 1, 3, 2, 4, 5))\n        # Reshape to the final spatial dimensions.\n        x = np.reshape(x, (n, h * block_size, w * block_size, new_c))\n    else:\n        # NCHW format\n        n, c, h, w = x.shape\n        new_c = c // (block_size**2)\n        # Reshape: (N, C, H, W) -> (N, new_C, block_size, block_size, H, W)\n        x = np.reshape(x, (n, new_c, block_size, block_size, h, w))\n        # Transpose: (N, C, bH, bW, H, W) -> (N, C, H, bH, W, bW)\n        x = np.transpose(x, (0, 1, 4, 2, 5, 3))\n        # Reshape: (N, C, H, bH, W, bW) -> (N, C, H*bH, W*bW)\n        x = np.reshape(x, (n, new_c, h * block_size, w * block_size))\n    return x\n\n\ndef space_to_depth(x, block_size, data_format=\"channels_last\"):\n    \"\"\"NumPy implementation of space_to_depth (pixel unshuffle).\n\n    Rearranges blocks of spatial data into depth.\n\n    Args:\n        x: 4-D tensor with shape (N, H, W, C) for channels_last or\n            (N, C, H, W) for channels_first.\n        
block_size: An integer specifying the block size.\n        data_format: \"channels_last\" or \"channels_first\".\n\n    Returns:\n        A tensor with shape (N, H/block_size, W/block_size, C*block_size**2)\n        for channels_last or (N, C*block_size**2, H/block_size, W/block_size)\n        for channels_first.\n    \"\"\"\n    if data_format == \"channels_last\":\n        # NHWC format\n        n, h, w, c = x.shape\n        new_h = h // block_size\n        new_w = w // block_size\n        # Reshape: (N, H, W, C) -> (N, new_H, bH, new_W, bW, C)\n        x = np.reshape(x, (n, new_h, block_size, new_w, block_size, c))\n        # Transpose: -> (N, new_H, new_W, bH, bW, C)\n        x = np.transpose(x, (0, 1, 3, 2, 4, 5))\n        # Reshape: -> (N, new_H, new_W, C*bH*bW)\n        x = np.reshape(x, (n, new_h, new_w, c * block_size**2))\n    else:\n        # NCHW format\n        n, c, h, w = x.shape\n        new_h = h // block_size\n        new_w = w // block_size\n        # Reshape: (N, C, H, W) -> (N, C, new_H, bH, new_W, bW)\n        x = np.reshape(x, (n, c, new_h, block_size, new_w, block_size))\n        # Transpose: -> (N, C, bH, bW, new_H, new_W)\n        x = np.transpose(x, (0, 1, 3, 5, 2, 4))\n        # Reshape: -> (N, C*bH*bW, new_H, new_W)\n        x = np.reshape(x, (n, c * block_size**2, new_h, new_w))\n    return x\n"
  },
  {
    "path": "keras/src/backend/numpy/numpy.py",
    "content": "import numpy as np\n\nfrom keras.src import tree\nfrom keras.src.backend import config\nfrom keras.src.backend import standardize_dtype\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common.backend_utils import standardize_axis_for_numpy\nfrom keras.src.backend.numpy.core import convert_to_tensor\n\n\ndef rot90(array, k=1, axes=(0, 1)):\n    \"\"\"Rotate an array by 90 degrees in the specified plane.\"\"\"\n    if array.ndim < 2:\n        raise ValueError(\n            \"Input array must have at least 2 dimensions. \"\n            f\"Received: array.ndim={array.ndim}\"\n        )\n    if len(axes) != 2 or axes[0] == axes[1]:\n        raise ValueError(\n            f\"Invalid axes: {axes}. Axes must be a tuple \"\n            \"of two different dimensions.\"\n        )\n    return np.rot90(array, k=k, axes=axes)\n\n\ndef add(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return np.add(x1, x2)\n\n\ndef einsum(subscripts, *operands, **kwargs):\n    operands = tree.map_structure(convert_to_tensor, operands)\n    dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands))\n    # When operands are of int8, we cast the result to int32 to align with\n    # the behavior of jax.\n    if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == \"int8\":\n        compute_dtype = \"int32\"  # prevent overflow\n        result_dtype = \"int32\"\n    else:\n        result_dtype = dtypes.result_type(*dtypes_to_resolve)\n        compute_dtype = result_dtype\n        # TODO: np.einsum doesn't support bfloat16\n        if compute_dtype == \"bfloat16\":\n            compute_dtype = \"float32\"\n    operands = 
tree.map_structure(lambda x: x.astype(compute_dtype), operands)\n    return np.einsum(subscripts, *operands, **kwargs).astype(result_dtype)\n\n\ndef subtract(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return np.subtract(x1, x2)\n\n\ndef matmul(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    # When both x1 and x2 are of int8, we cast the outputs to int32 to align\n    # with jax\n    x1_dtype = standardize_dtype(x1.dtype)\n    x2_dtype = standardize_dtype(x2.dtype)\n    if x1_dtype == \"int8\" and x2_dtype == \"int8\":\n        dtype = \"int32\"\n    else:\n        dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = x1.astype(dtype)\n    x2 = x2.astype(dtype)\n    return np.matmul(x1, x2).astype(dtype)\n\n\ndef multiply(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return np.multiply(x1, x2)\n\n\ndef mean(x, axis=None, keepdims=False):\n    axis = standardize_axis_for_numpy(axis)\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        result_dtype = dtypes.result_type(x.dtype, \"float32\")\n    else:\n        result_dtype = ori_dtype\n    return np.mean(x, axis=axis, keepdims=keepdims).astype(result_dtype)\n\n\ndef max(x, axis=None, keepdims=False, initial=None):\n    axis = 
standardize_axis_for_numpy(axis)\n    return np.max(x, axis=axis, keepdims=keepdims, initial=initial)\n\n\ndef ones(shape, dtype=None):\n    dtype = dtype or config.floatx()\n    return np.ones(shape, dtype=dtype)\n\n\ndef zeros(shape, dtype=None):\n    dtype = dtype or config.floatx()\n    return np.zeros(shape, dtype=dtype)\n\n\ndef absolute(x):\n    return np.absolute(x)\n\n\ndef abs(x):\n    return absolute(x)\n\n\ndef all(x, axis=None, keepdims=False):\n    axis = standardize_axis_for_numpy(axis)\n    return np.all(x, axis=axis, keepdims=keepdims)\n\n\ndef allclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):\n    return np.allclose(x1, x2, rtol=rtol, atol=atol, equal_nan=equal_nan)\n\n\ndef angle(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.angle(x)\n\n\ndef any(x, axis=None, keepdims=False):\n    axis = standardize_axis_for_numpy(axis)\n    return np.any(x, axis=axis, keepdims=keepdims)\n\n\ndef amax(x, axis=None, keepdims=False):\n    axis = standardize_axis_for_numpy(axis)\n    return np.amax(x, axis=axis, keepdims=keepdims)\n\n\ndef amin(x, axis=None, keepdims=False):\n    axis = standardize_axis_for_numpy(axis)\n    return np.amin(x, axis=axis, keepdims=keepdims)\n\n\ndef append(x1, x2, axis=None):\n    axis = standardize_axis_for_numpy(axis)\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = x1.astype(dtype)\n    x2 = x2.astype(dtype)\n    return np.append(x1, x2, axis=axis)\n\n\ndef arange(start, stop=None, step=None, dtype=None):\n    if dtype is None:\n        dtypes_to_resolve = [getattr(start, \"dtype\", type(start))]\n        if stop is not None:\n            dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n        if step is not None:\n            
dtypes_to_resolve.append(getattr(step, \"dtype\", type(step)))\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n    if stop is None:\n        start, stop = 0, start\n    if step is None:\n        step = 1\n    return np.arange(start, stop, step=step, dtype=dtype)\n\n\ndef arccos(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.arccos(x)\n\n\ndef arccosh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.arccosh(x)\n\n\ndef arcsin(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.arcsin(x)\n\n\ndef arcsinh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.arcsinh(x)\n\n\ndef arctan(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.arctan(x)\n\n\ndef arctan2(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    x1 = x1.astype(dtype)\n    x2 = x2.astype(dtype)\n    return np.arctan2(x1, x2)\n\n\ndef arctanh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.arctanh(x)\n\n\ndef 
argmax(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    axis = standardize_axis_for_numpy(axis)\n    dtype = standardize_dtype(x.dtype)\n    if \"float\" not in dtype or x.ndim == 0:\n        return np.argmax(x, axis=axis, keepdims=keepdims).astype(\"int32\")\n\n    dtype = dtypes.result_type(dtype, \"float32\")\n    x = x.astype(dtype)\n    is_negative_zero = (x == 0.0) & np.signbit(x)\n    x = np.where(is_negative_zero, -np.finfo(x.dtype).tiny, x)\n    return np.argmax(x, axis=axis, keepdims=keepdims).astype(\"int32\")\n\n\ndef argmin(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    axis = standardize_axis_for_numpy(axis)\n    dtype = standardize_dtype(x.dtype)\n    if \"float\" not in dtype or x.ndim == 0:\n        return np.argmin(x, axis=axis, keepdims=keepdims).astype(\"int32\")\n\n    dtype = dtypes.result_type(dtype, \"float32\")\n    x = x.astype(dtype)\n    is_negative_zero = (x == 0.0) & np.signbit(x)\n    x = np.where(is_negative_zero, -np.finfo(x.dtype).tiny, x)\n    return np.argmin(x, axis=axis, keepdims=keepdims).astype(\"int32\")\n\n\ndef argsort(x, axis=-1):\n    axis = standardize_axis_for_numpy(axis)\n    return np.argsort(x, axis=axis).astype(\"int32\")\n\n\ndef array(x, dtype=None):\n    return convert_to_tensor(x, dtype=dtype)\n\n\ndef view(x, dtype=None):\n    x = convert_to_tensor(x)\n    return x.view(dtype=dtype)\n\n\ndef average(x, axis=None, weights=None):\n    axis = standardize_axis_for_numpy(axis)\n    x = convert_to_tensor(x)\n    dtypes_to_resolve = [x.dtype, float]\n    if weights is not None:\n        weights = convert_to_tensor(weights)\n        dtypes_to_resolve.append(weights.dtype)\n    dtype = dtypes.result_type(*dtypes_to_resolve)\n    x = x.astype(dtype)\n    if weights is not None:\n        weights = weights.astype(dtype)\n    return np.average(x, weights=weights, axis=axis)\n\n\ndef bartlett(x):\n    x = convert_to_tensor(x)\n    return np.bartlett(x).astype(config.floatx())\n\n\ndef 
hamming(x):\n    x = convert_to_tensor(x)\n    return np.hamming(x).astype(config.floatx())\n\n\ndef hanning(x):\n    x = convert_to_tensor(x)\n    return np.hanning(x).astype(config.floatx())\n\n\ndef heaviside(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype in [\"int8\", \"int16\", \"int32\", \"uint8\", \"uint16\", \"uint32\"]:\n        dtype = config.floatx()\n    elif dtype in [\"int64\"]:\n        dtype = \"float64\"\n\n    return np.heaviside(x1, x2).astype(dtype)\n\n\ndef kaiser(x, beta):\n    x = convert_to_tensor(x)\n    return np.kaiser(x, beta).astype(config.floatx())\n\n\ndef bincount(x, weights=None, minlength=0, sparse=False):\n    if sparse:\n        raise ValueError(\"Unsupported value `sparse=True` with numpy backend\")\n    x = convert_to_tensor(x)\n    dtypes_to_resolve = [x.dtype]\n    if weights is not None:\n        weights = convert_to_tensor(weights)\n        dtypes_to_resolve.append(weights.dtype)\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n    else:\n        dtype = \"int32\"\n    if len(x.shape) == 2:\n        if weights is None:\n\n            def bincount_fn(arr):\n                return np.bincount(arr, minlength=minlength)\n\n            bincounts = list(map(bincount_fn, x))\n        else:\n\n            def bincount_fn(arr_w):\n                return np.bincount(\n                    arr_w[0], weights=arr_w[1], minlength=minlength\n                )\n\n            bincounts = list(map(bincount_fn, zip(x, weights)))\n\n        return np.stack(bincounts).astype(dtype)\n    return np.bincount(x, weights, minlength).astype(dtype)\n\n\ndef bitwise_and(x, y):\n    x = convert_to_tensor(x)\n    y = convert_to_tensor(y)\n    dtype = dtypes.result_type(x.dtype, y.dtype)\n    x = x.astype(dtype)\n    y = y.astype(dtype)\n    return np.bitwise_and(x, y)\n\n\ndef bitwise_invert(x):\n    x = convert_to_tensor(x)\n    return 
np.bitwise_not(x)\n\n\ndef bitwise_not(x):\n    return bitwise_invert(x)\n\n\ndef bitwise_or(x, y):\n    x = convert_to_tensor(x)\n    y = convert_to_tensor(y)\n    dtype = dtypes.result_type(x.dtype, y.dtype)\n    x = x.astype(dtype)\n    y = y.astype(dtype)\n    return np.bitwise_or(x, y)\n\n\ndef bitwise_xor(x, y):\n    x = convert_to_tensor(x)\n    y = convert_to_tensor(y)\n    dtype = dtypes.result_type(x.dtype, y.dtype)\n    x = x.astype(dtype)\n    y = y.astype(dtype)\n    return np.bitwise_xor(x, y)\n\n\ndef bitwise_left_shift(x, y):\n    x = convert_to_tensor(x)\n    if not isinstance(y, int):\n        y = convert_to_tensor(y)\n        dtype = dtypes.result_type(x.dtype, y.dtype)\n        x = x.astype(dtype)\n        y = y.astype(dtype)\n    return np.left_shift(x, y)\n\n\ndef left_shift(x, y):\n    return bitwise_left_shift(x, y)\n\n\ndef bitwise_right_shift(x, y):\n    x = convert_to_tensor(x)\n    if not isinstance(y, int):\n        y = convert_to_tensor(y)\n        dtype = dtypes.result_type(x.dtype, y.dtype)\n        x = x.astype(dtype)\n        y = y.astype(dtype)\n    return np.right_shift(x, y)\n\n\ndef right_shift(x, y):\n    return bitwise_right_shift(x, y)\n\n\ndef blackman(x):\n    x = convert_to_tensor(x)\n    return np.blackman(x).astype(config.floatx())\n\n\ndef broadcast_to(x, shape):\n    return np.broadcast_to(x, shape)\n\n\ndef cbrt(x):\n    x = convert_to_tensor(x)\n\n    dtype = standardize_dtype(x.dtype)\n    if dtype in [\"bool\", \"int8\", \"int16\", \"int32\", \"uint8\", \"uint16\", \"uint32\"]:\n        dtype = config.floatx()\n    elif dtype == \"int64\":\n        dtype = \"float64\"\n\n    return np.cbrt(x).astype(dtype)\n\n\ndef ceil(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.ceil(x)\n\n\ndef clip(x, x_min, x_max):\n    x = convert_to_tensor(x)\n    dtype 
= standardize_dtype(x.dtype)\n    if dtype == \"bool\":\n        dtype = \"int32\"\n    return np.clip(x, x_min, x_max).astype(dtype)\n\n\ndef concatenate(xs, axis=0):\n    axis = standardize_axis_for_numpy(axis)\n    dtype_set = set([getattr(x, \"dtype\", type(x)) for x in xs])\n    if len(dtype_set) > 1:\n        dtype = dtypes.result_type(*dtype_set)\n        xs = tree.map_structure(\n            lambda x: convert_to_tensor(x).astype(dtype), xs\n        )\n    return np.concatenate(xs, axis=axis)\n\n\ndef conjugate(x):\n    return np.conjugate(x)\n\n\ndef conj(x):\n    return conjugate(x)\n\n\ndef copy(x):\n    return np.copy(x)\n\n\ndef cos(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.cos(x)\n\n\ndef cosh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.cosh(x)\n\n\ndef count_nonzero(x, axis=None):\n    axis = standardize_axis_for_numpy(axis)\n    # np.count_nonzero will return python int when axis=None, so we need\n    # to convert_to_tensor\n    return convert_to_tensor(np.count_nonzero(x, axis=axis)).astype(\"int32\")\n\n\ndef cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):\n    axis = standardize_axis_for_numpy(axis)\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = x1.astype(dtype)\n    x2 = x2.astype(dtype)\n    return np.cross(\n        x1,\n        x2,\n        axisa=axisa,\n        axisb=axisb,\n        axisc=axisc,\n        axis=axis,\n    )\n\n\ndef cumprod(x, axis=None, dtype=None):\n    axis = standardize_axis_for_numpy(axis)\n    dtype = dtypes.result_type(dtype or x.dtype)\n    if dtype == \"bool\":\n        dtype = 
\"int32\"\n    return np.cumprod(x, axis=axis, dtype=dtype)\n\n\ndef cumsum(x, axis=None, dtype=None):\n    axis = standardize_axis_for_numpy(axis)\n    dtype = dtypes.result_type(dtype or x.dtype)\n    if dtype == \"bool\":\n        dtype = \"int32\"\n    return np.cumsum(x, axis=axis, dtype=dtype)\n\n\ndef deg2rad(x):\n    x = convert_to_tensor(x)\n\n    if x.dtype in [\"int64\", \"float64\"]:\n        dtype = \"float64\"\n    elif x.dtype in [\"bfloat16\", \"float16\"]:\n        dtype = x.dtype\n    else:\n        dtype = config.floatx()\n\n    return np.deg2rad(x).astype(dtype)\n\n\ndef diag(x, k=0):\n    return np.diag(x, k=k)\n\n\ndef diagflat(x, k=0):\n    return np.diagflat(x, k=k)\n\n\ndef diagonal(x, offset=0, axis1=0, axis2=1):\n    axis1 = standardize_axis_for_numpy(axis1)\n    axis2 = standardize_axis_for_numpy(axis2)\n    return np.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)\n\n\ndef diff(a, n=1, axis=-1):\n    return np.diff(a, n=n, axis=axis)\n\n\ndef digitize(x, bins):\n    return np.digitize(x, bins).astype(np.int32)\n\n\ndef dot(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = x1.astype(dtype)\n    x2 = x2.astype(dtype)\n    return np.dot(x1, x2)\n\n\ndef dstack(xs):\n    dtype_set = set([getattr(x, \"dtype\", type(x)) for x in xs])\n    if len(dtype_set) > 1:\n        dtype = dtypes.result_type(*dtype_set)\n        xs = tree.map_structure(\n            lambda x: convert_to_tensor(x).astype(dtype), xs\n        )\n    return np.dstack(xs)\n\n\ndef empty(shape, dtype=None):\n    dtype = dtype or config.floatx()\n    return np.empty(shape, dtype=dtype)\n\n\ndef empty_like(x, dtype=None):\n    return np.empty_like(x, dtype=dtype)\n\n\ndef equal(x1, x2):\n    return np.equal(x1, x2)\n\n\ndef exp(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = 
x.astype(config.floatx())\n    return np.exp(x)\n\n\ndef exp2(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = x.astype(config.floatx())\n    return np.exp2(x)\n\n\ndef expand_dims(x, axis):\n    axis = standardize_axis_for_numpy(axis)\n    return np.expand_dims(x, axis)\n\n\ndef expm1(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = x.astype(config.floatx())\n    return np.expm1(x)\n\n\ndef flip(x, axis=None):\n    axis = standardize_axis_for_numpy(axis)\n    return np.flip(x, axis=axis)\n\n\ndef floor(x):\n    x = convert_to_tensor(x)\n    dtype = (\n        config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    x = x.astype(dtype)\n    return np.floor(x)\n\n\ndef full(shape, fill_value, dtype=None):\n    dtype = dtype or config.floatx()\n    return np.full(shape, fill_value, dtype=dtype)\n\n\ndef full_like(x, fill_value, dtype=None):\n    return np.full_like(x, fill_value, dtype=dtype)\n\n\ndef gcd(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    return np.gcd(x1, x2).astype(dtype)\n\n\ndef geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):\n    dtype = dtype or config.floatx()\n    return np.geomspace(\n        start, stop, num=num, endpoint=endpoint, dtype=dtype, axis=axis\n    )\n\n\ndef greater(x1, x2):\n    return np.greater(x1, x2)\n\n\ndef greater_equal(x1, x2):\n    return np.greater_equal(x1, x2)\n\n\ndef hstack(xs):\n    dtype_set = set([getattr(x, \"dtype\", type(x)) for x in xs])\n    if len(dtype_set) > 1:\n        dtype = dtypes.result_type(*dtype_set)\n        xs = tree.map_structure(\n            lambda x: convert_to_tensor(x).astype(dtype), xs\n        )\n    return 
np.hstack(xs)\n\n\ndef hsplit(x, indices_or_sections):\n    x = convert_to_tensor(x)\n    return np.hsplit(x, indices_or_sections)\n\n\ndef hypot(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype in [\"int8\", \"int16\", \"int32\", \"uint8\", \"uint16\", \"uint32\"]:\n        dtype = config.floatx()\n    elif dtype in [\"int64\"]:\n        dtype = \"float64\"\n\n    return np.hypot(x1, x2).astype(dtype)\n\n\ndef identity(n, dtype=None):\n    dtype = dtype or config.floatx()\n    return np.identity(n, dtype=dtype)\n\n\ndef imag(x):\n    return np.imag(x)\n\n\ndef isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):\n    return np.isclose(x1, x2, rtol, atol, equal_nan)\n\n\ndef isfinite(x):\n    return np.isfinite(x)\n\n\ndef isin(x1, x2, assume_unique=False, invert=False):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return np.isin(x1, x2, assume_unique=assume_unique, invert=invert)\n\n\ndef isinf(x):\n    return np.isinf(x)\n\n\ndef isnan(x):\n    return np.isnan(x)\n\n\ndef isneginf(x):\n    x = convert_to_tensor(x)\n    return np.isneginf(x)\n\n\ndef isposinf(x):\n    x = convert_to_tensor(x)\n    return np.isposinf(x)\n\n\ndef isreal(x):\n    x = convert_to_tensor(x)\n    return np.isreal(x)\n\n\ndef kron(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    return np.kron(x1, x2).astype(dtype)\n\n\ndef lcm(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    return np.lcm(x1, x2).astype(dtype)\n\n\ndef ldexp(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n\n    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:\n        raise TypeError(\n            f\"ldexp exponent must be an integer type. 
\"\n            f\"Received: x2 dtype={x2.dtype}\"\n        )\n    return np.ldexp(x1, x2).astype(dtype)\n\n\ndef less(x1, x2):\n    return np.less(x1, x2)\n\n\ndef less_equal(x1, x2):\n    return np.less_equal(x1, x2)\n\n\ndef linspace(\n    start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0\n):\n    axis = standardize_axis_for_numpy(axis)\n    if dtype is None:\n        dtypes_to_resolve = [\n            getattr(start, \"dtype\", type(start)),\n            getattr(stop, \"dtype\", type(stop)),\n            float,\n        ]\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n    return np.linspace(\n        start,\n        stop,\n        num=num,\n        endpoint=endpoint,\n        retstep=retstep,\n        dtype=dtype,\n        axis=axis,\n    )\n\n\ndef log(x):\n    x = convert_to_tensor(x)\n    dtype = (\n        config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    return np.log(x, dtype=dtype)\n\n\ndef log10(x):\n    x = convert_to_tensor(x)\n    dtype = (\n        config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    return np.log10(x, dtype=dtype)\n\n\ndef log1p(x):\n    x = convert_to_tensor(x)\n    dtype = (\n        config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    return np.log1p(x, dtype=dtype)\n\n\ndef log2(x):\n    x = convert_to_tensor(x)\n    dtype = (\n        config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    return np.log2(x, dtype=dtype)\n\n\ndef logaddexp(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    x1 = x1.astype(dtype)\n    x2 = x2.astype(dtype)\n    return np.logaddexp(x1, x2)\n\n\ndef logaddexp2(x1, x2):\n    x1 = 
convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    return np.logaddexp2(x1, x2).astype(dtype)\n\n\ndef logical_and(x1, x2):\n    return np.logical_and(x1, x2)\n\n\ndef logical_not(x):\n    return np.logical_not(x)\n\n\ndef logical_or(x1, x2):\n    return np.logical_or(x1, x2)\n\n\ndef logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):\n    if dtype is None:\n        dtypes_to_resolve = [\n            getattr(start, \"dtype\", type(start)),\n            getattr(stop, \"dtype\", type(stop)),\n            float,\n        ]\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n    return np.logspace(\n        start,\n        stop,\n        num=num,\n        endpoint=endpoint,\n        base=base,\n        dtype=dtype,\n        axis=axis,\n    )\n\n\ndef maximum(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return np.maximum(x1, x2)\n\n\ndef median(x, axis=None, keepdims=False):\n    dtype = dtypes.result_type(x.dtype, float)\n    return np.median(x, axis=axis, keepdims=keepdims).astype(dtype)\n\n\ndef meshgrid(*x, indexing=\"xy\"):\n    return np.meshgrid(*x, indexing=indexing)\n\n\ndef min(x, axis=None, keepdims=False, initial=None):\n    axis = standardize_axis_for_numpy(axis)\n    return np.min(x, axis=axis, keepdims=keepdims, initial=initial)\n\n\ndef minimum(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 
= convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return np.minimum(x1, x2)\n\n\ndef mod(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype == \"bool\":\n        dtype = \"int32\"\n    x1 = x1.astype(dtype)\n    x2 = x2.astype(dtype)\n    return np.mod(x1, x2)\n\n\ndef fmod(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype == \"bool\":\n        dtype = \"int32\"\n    x1 = x1.astype(dtype)\n    x2 = x2.astype(dtype)\n    return np.fmod(x1, x2)\n\n\ndef moveaxis(x, source, destination):\n    return np.moveaxis(x, source=source, destination=destination)\n\n\ndef nanargmax(x, axis=None, keepdims=False):\n    if not np.issubdtype(x.dtype, np.floating):\n        return argmax(x, axis=axis, keepdims=keepdims)\n\n    nan_mask = np.isnan(x)\n\n    return np.where(\n        np.all(nan_mask, axis=axis, keepdims=keepdims),\n        -1,\n        np.nanargmax(\n            np.where(nan_mask, -np.inf, x), axis=axis, keepdims=keepdims\n        ).astype(\"int32\"),\n    )\n\n\ndef nanargmin(x, axis=None, keepdims=False):\n    if not np.issubdtype(x.dtype, np.floating):\n        return argmin(x, axis=axis, keepdims=keepdims)\n\n    nan_mask = np.isnan(x)\n\n    return np.where(\n        np.all(nan_mask, axis=axis, keepdims=keepdims),\n        -1,\n        np.nanargmin(\n            np.where(nan_mask, np.inf, x), axis=axis, keepdims=keepdims\n        ).astype(\"int32\"),\n    )\n\n\ndef nancumsum(x, axis=None, dtype=None):\n    axis = standardize_axis_for_numpy(axis)\n    dtype = dtypes.result_type(dtype or x.dtype)\n    if dtype == \"bool\":\n        dtype = \"int32\"\n    return np.nancumsum(x, axis=axis, dtype=dtype)\n\n\ndef nancumprod(x, axis=None, dtype=None):\n    axis = standardize_axis_for_numpy(axis)\n    dtype = dtypes.result_type(dtype or x.dtype)\n    if dtype == 
\"bool\":\n        dtype = \"int32\"\n    return np.nancumprod(x, axis=axis, dtype=dtype)\n\n\ndef nanmax(x, axis=None, keepdims=False):\n    return np.nanmax(x, axis=axis, keepdims=keepdims)\n\n\ndef nanmean(x, axis=None, keepdims=False):\n    dtype = dtypes.result_type(standardize_dtype(x.dtype), float)\n    return np.nanmean(x, axis=axis, keepdims=keepdims).astype(dtype)\n\n\ndef nanmin(x, axis=None, keepdims=False):\n    return np.nanmin(x, axis=axis, keepdims=keepdims)\n\n\ndef nanprod(x, axis=None, keepdims=False):\n    axis = standardize_axis_for_numpy(axis)\n\n    x = convert_to_tensor(x)\n\n    dtype = dtypes.result_type(x.dtype)\n    if dtype in (\"bool\", \"int8\", \"int16\"):\n        dtype = \"int32\"\n    elif dtype in (\"uint8\", \"uint16\"):\n        dtype = \"uint32\"\n    return np.nanprod(x, axis=axis, keepdims=keepdims, dtype=dtype)\n\n\ndef nanstd(x, axis=None, keepdims=False):\n    axis = standardize_axis_for_numpy(axis)\n    x = convert_to_tensor(x)\n    compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n    result_dtype = dtypes.result_type(x.dtype, float)\n    return np.nanstd(\n        x, axis=axis, keepdims=keepdims, dtype=compute_dtype\n    ).astype(result_dtype)\n\n\ndef nansum(x, axis=None, keepdims=False):\n    axis = standardize_axis_for_numpy(axis)\n    dtype = standardize_dtype(x.dtype)\n\n    if dtype in (\"bool\", \"int8\", \"int16\"):\n        dtype = \"int32\"\n    elif dtype in (\"uint8\", \"uint16\"):\n        dtype = \"uint32\"\n    return np.nansum(x, axis=axis, keepdims=keepdims).astype(dtype)\n\n\ndef nanvar(x, axis=None, keepdims=False):\n    axis = standardize_axis_for_numpy(axis)\n    x = convert_to_tensor(x)\n    compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n    result_dtype = dtypes.result_type(x.dtype, float)\n    return np.nanvar(\n        x, axis=axis, keepdims=keepdims, dtype=compute_dtype\n    ).astype(result_dtype)\n\n\ndef nan_to_num(x, nan=0.0, posinf=None, neginf=None):\n    return 
np.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)\n\n\ndef ndim(x):\n    return np.ndim(x)\n\n\ndef nonzero(x):\n    return tuple(indices.astype(\"int32\") for indices in np.nonzero(x))\n\n\ndef not_equal(x1, x2):\n    return np.not_equal(x1, x2)\n\n\ndef zeros_like(x, dtype=None):\n    return np.zeros_like(x, dtype=dtype)\n\n\ndef ones_like(x, dtype=None):\n    return np.ones_like(x, dtype=dtype)\n\n\ndef outer(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = x1.astype(dtype)\n    x2 = x2.astype(dtype)\n    return np.outer(x1, x2)\n\n\ndef pad(x, pad_width, mode=\"constant\", constant_values=None):\n    kwargs = {}\n    if constant_values is not None:\n        if mode != \"constant\":\n            raise ValueError(\n                \"Argument `constant_values` can only be \"\n                \"provided when `mode == 'constant'`. \"\n                f\"Received: mode={mode}\"\n            )\n        kwargs[\"constant_values\"] = constant_values\n    return np.pad(x, pad_width, mode=mode, **kwargs)\n\n\ndef prod(x, axis=None, keepdims=False, dtype=None):\n    axis = standardize_axis_for_numpy(axis)\n    x = convert_to_tensor(x)\n    if dtype is None:\n        dtype = dtypes.result_type(x.dtype)\n        if dtype in (\"bool\", \"int8\", \"int16\"):\n            dtype = \"int32\"\n        elif dtype in (\"uint8\", \"uint16\"):\n            dtype = \"uint32\"\n    return np.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)\n\n\ndef ptp(x, axis=None, keepdims=False):\n    return np.ptp(x, axis=axis, keepdims=keepdims)\n\n\ndef quantile(x, q, axis=None, method=\"linear\", keepdims=False):\n    axis = standardize_axis_for_numpy(axis)\n    x = convert_to_tensor(x)\n\n    ori_dtype = standardize_dtype(x.dtype)\n    # np.quantile doesn't support bool\n    if ori_dtype == \"bool\":\n        x = x.astype(config.floatx())\n    if ori_dtype == \"int64\":\n        dtype = 
config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    return np.quantile(\n        x, q, axis=axis, method=method, keepdims=keepdims\n    ).astype(dtype)\n\n\ndef ravel(x):\n    return np.ravel(x)\n\n\ndef unravel_index(indices, shape):\n    dtype = dtypes.result_type(indices.dtype)\n    return tuple(\n        indices.astype(dtype) for indices in np.unravel_index(indices, shape)\n    )\n\n\ndef real(x):\n    return np.real(x)\n\n\ndef reciprocal(x):\n    return np.reciprocal(x)\n\n\ndef repeat(x, repeats, axis=None):\n    return np.repeat(x, repeats, axis=axis)\n\n\ndef reshape(x, newshape):\n    return np.reshape(x, newshape)\n\n\ndef roll(x, shift, axis=None):\n    return np.roll(x, shift, axis=axis)\n\n\ndef searchsorted(sorted_sequence, values, side=\"left\"):\n    if ndim(sorted_sequence) != 1:\n        raise ValueError(\n            \"`searchsorted` only supports 1-D sorted sequences. \"\n            \"You can use `keras.ops.vectorized_map` \"\n            \"to extend it to N-D sequences. 
Received: \"\n            f\"sorted_sequence.shape={sorted_sequence.shape}\"\n        )\n    out_type = (\n        \"int32\"\n        if sorted_sequence.shape[0] <= np.iinfo(np.int32).max\n        else \"int64\"\n    )\n    return np.searchsorted(sorted_sequence, values, side=side).astype(out_type)\n\n\ndef sign(x):\n    return np.sign(x)\n\n\ndef signbit(x):\n    return np.signbit(x)\n\n\ndef sin(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.sin(x)\n\n\ndef sinc(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.sinc(x).astype(dtype)\n\n\ndef sinh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.sinh(x)\n\n\ndef size(x):\n    return np.size(x)\n\n\ndef sort(x, axis=-1):\n    axis = standardize_axis_for_numpy(axis)\n    return np.sort(x, axis=axis)\n\n\ndef split(x, indices_or_sections, axis=0):\n    axis = standardize_axis_for_numpy(axis)\n    return np.split(x, indices_or_sections, axis=axis)\n\n\ndef array_split(x, indices_or_sections, axis=0):\n    axis = standardize_axis_for_numpy(axis)\n    return np.array_split(x, indices_or_sections, axis=axis)\n\n\ndef stack(x, axis=0):\n    axis = standardize_axis_for_numpy(axis)\n    dtype_set = set([getattr(a, \"dtype\", type(a)) for a in x])\n    if len(dtype_set) > 1:\n        dtype = dtypes.result_type(*dtype_set)\n        x = tree.map_structure(lambda a: convert_to_tensor(a).astype(dtype), x)\n    return np.stack(x, axis=axis)\n\n\ndef std(x, axis=None, keepdims=False):\n    axis = 
standardize_axis_for_numpy(axis)\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = x.astype(config.floatx())\n    return np.std(x, axis=axis, keepdims=keepdims)\n\n\ndef swapaxes(x, axis1, axis2):\n    return np.swapaxes(x, axis1=axis1, axis2=axis2)\n\n\ndef take(x, indices, axis=None):\n    axis = standardize_axis_for_numpy(axis)\n    return np.take(x, indices, axis=axis)\n\n\ndef take_along_axis(x, indices, axis=None):\n    axis = standardize_axis_for_numpy(axis)\n    return np.take_along_axis(x, indices, axis=axis)\n\n\ndef tan(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.tan(x)\n\n\ndef tanh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = x.astype(dtype)\n    return np.tanh(x)\n\n\ndef tensordot(x1, x2, axes=2):\n    axes = tuple(axes) if isinstance(axes, list) else axes\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = x1.astype(dtype)\n    x2 = x2.astype(dtype)\n    return np.tensordot(x1, x2, axes=axes)\n\n\ndef round(x, decimals=0):\n    return np.round(x, decimals=decimals)\n\n\ndef tile(x, repeats):\n    return np.tile(x, repeats)\n\n\ndef trace(x, offset=0, axis1=0, axis2=1):\n    axis1 = standardize_axis_for_numpy(axis1)\n    axis2 = standardize_axis_for_numpy(axis2)\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    if dtype in (\"bool\", \"int8\", \"int16\"):\n        dtype = \"int32\"\n    elif dtype in (\"uint8\", \"uint16\"):\n        dtype = \"uint32\"\n    return np.trace(x, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)\n\n\ndef tri(N, 
M=None, k=0, dtype=None):\n    dtype = dtype or config.floatx()\n    return np.tri(N, M=M, k=k, dtype=dtype)\n\n\ndef tril(x, k=0):\n    return np.tril(x, k=k)\n\n\ndef triu(x, k=0):\n    return np.triu(x, k=k)\n\n\ndef trunc(x):\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    if \"int\" in dtype or \"bool\" == dtype:\n        return x\n    return np.trunc(x)\n\n\ndef vdot(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = x1.astype(dtype)\n    x2 = x2.astype(dtype)\n    return np.vdot(x1, x2)\n\n\ndef inner(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = x1.astype(dtype)\n    x2 = x2.astype(dtype)\n    return np.inner(x1, x2)\n\n\ndef vstack(xs):\n    dtype_set = set([getattr(x, \"dtype\", type(x)) for x in xs])\n    if len(dtype_set) > 1:\n        dtype = dtypes.result_type(*dtype_set)\n        xs = tree.map_structure(\n            lambda x: convert_to_tensor(x).astype(dtype), xs\n        )\n    return np.vstack(xs)\n\n\ndef vsplit(x, indices_or_sections):\n    x = convert_to_tensor(x)\n    return np.vsplit(x, indices_or_sections)\n\n\ndef vectorize(pyfunc, *, excluded=None, signature=None):\n    return np.vectorize(pyfunc, excluded=excluded, signature=signature)\n\n\ndef where(condition, x1=None, x2=None):\n    if x1 is not None and x2 is not None:\n        if not isinstance(x1, (int, float)):\n            x1 = convert_to_tensor(x1)\n        if not isinstance(x2, (int, float)):\n            x2 = convert_to_tensor(x2)\n        dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        x1 = convert_to_tensor(x1, dtype)\n        x2 = convert_to_tensor(x2, dtype)\n        return np.where(condition, x1, x2)\n    else:\n        return np.where(condition)\n\n\ndef divide(x1, x2):\n    if not 
isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n        float,\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return np.divide(x1, x2)\n\n\ndef divide_no_nan(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n        float,\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    # No need for the double-where trick since we don't calculate gradients in\n    # numpy backend.\n    return np.where(x2 == 0, np.array(0, dtype=dtype), np.divide(x1, x2))\n\n\ndef true_divide(x1, x2):\n    return divide(x1, x2)\n\n\ndef power(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return np.power(x1, x2)\n\n\ndef negative(x):\n    return np.negative(x)\n\n\ndef nextafter(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n\n    return np.nextafter(x1, x2).astype(dtype)\n\n\ndef square(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"bool\":\n        x = x.astype(\"int32\")\n    return np.square(x)\n\n\ndef sqrt(x):\n    x = convert_to_tensor(x)\n    # upcast to float64 for int64 which matches JAX's behavior\n    dtype = (\n        
config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    return np.sqrt(x, dtype=dtype)\n\n\ndef squeeze(x, axis=None):\n    axis = standardize_axis_for_numpy(axis)\n    return np.squeeze(x, axis=axis)\n\n\ndef transpose(x, axes=None):\n    axes = tuple(axes) if isinstance(axes, list) else axes\n    return np.transpose(x, axes=axes)\n\n\ndef trapezoid(y, x=None, dx=1.0, axis=-1):\n    y = convert_to_tensor(y)\n    result_dtype = dtypes.result_type(y.dtype, float)\n    if x is not None:\n        x = convert_to_tensor(x)\n    dx = convert_to_tensor(dx)\n    return np.trapezoid(y, x, dx=dx, axis=axis).astype(result_dtype)\n\n\ndef vander(x, N=None, increasing=False):\n    x = convert_to_tensor(x)\n    result_dtype = dtypes.result_type(x.dtype)\n    compute_dtype = dtypes.result_type(x.dtype, config.floatx())\n    x = x.astype(compute_dtype)\n    return np.vander(x, N=N, increasing=increasing).astype(result_dtype)\n\n\ndef var(x, axis=None, keepdims=False):\n    axis = standardize_axis_for_numpy(axis)\n    x = convert_to_tensor(x)\n    compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n    result_dtype = dtypes.result_type(x.dtype, float)\n    return np.var(x, axis=axis, keepdims=keepdims, dtype=compute_dtype).astype(\n        result_dtype\n    )\n\n\ndef sum(x, axis=None, keepdims=False):\n    axis = standardize_axis_for_numpy(axis)\n    dtype = standardize_dtype(x.dtype)\n    # follow jax's rule\n    if dtype in (\"bool\", \"int8\", \"int16\"):\n        dtype = \"int32\"\n    elif dtype in (\"uint8\", \"uint16\"):\n        dtype = \"uint32\"\n    return np.sum(x, axis=axis, keepdims=keepdims).astype(dtype)\n\n\ndef eye(N, M=None, k=0, dtype=None):\n    dtype = dtype or config.floatx()\n    return np.eye(N, M=M, k=k, dtype=dtype)\n\n\ndef floor_divide(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 
= convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)), getattr(x2, \"dtype\", type(x2))\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return np.floor_divide(x1, x2)\n\n\ndef logical_xor(x1, x2):\n    return np.logical_xor(x1, x2)\n\n\ndef corrcoef(x):\n    if x.dtype in [\"int64\", \"float64\"]:\n        dtype = \"float64\"\n    elif x.dtype in [\"bfloat16\", \"float16\"]:\n        dtype = x.dtype\n    else:\n        dtype = config.floatx()\n\n    x = convert_to_tensor(x)\n\n    return np.corrcoef(x).astype(dtype)\n\n\ndef correlate(x1, x2, mode=\"valid\"):\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    if dtype == \"int64\":\n        dtype = \"float64\"\n    elif dtype not in [\"bfloat16\", \"float16\", \"float64\"]:\n        dtype = \"float32\"\n\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return np.correlate(x1, x2, mode)\n\n\ndef select(condlist, choicelist, default=0):\n    return np.select(condlist, choicelist, default=default)\n\n\ndef slogdet(x):\n    return tuple(np.linalg.slogdet(x))\n\n\ndef argpartition(x, kth, axis=-1):\n    return np.argpartition(x, kth, axis).astype(\"int32\")\n\n\ndef histogram(x, bins=10, range=None):\n    return np.histogram(x, bins=bins, range=range)\n"
  },
  {
    "path": "keras/src/backend/numpy/random.py",
    "content": "import numpy as np\n\nfrom keras.src.backend.config import floatx\nfrom keras.src.backend.numpy.nn import softmax\nfrom keras.src.random.seed_generator import SeedGenerator\nfrom keras.src.random.seed_generator import draw_seed\nfrom keras.src.random.seed_generator import make_default_seed\n\n\ndef normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = draw_seed(seed)\n    rng = np.random.default_rng(seed)\n    return rng.normal(size=shape, loc=mean, scale=stddev).astype(dtype)\n\n\ndef uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = draw_seed(seed)\n    rng = np.random.default_rng(seed)\n    return rng.uniform(size=shape, low=minval, high=maxval).astype(dtype)\n\n\ndef categorical(logits, num_samples, dtype=\"int64\", seed=None):\n    seed = draw_seed(seed)\n    rng = np.random.default_rng(seed)\n    output = []\n    for logits_instance in logits:\n        probabilities = softmax(logits_instance)\n        classes = np.arange(logits_instance.shape[-1])\n        samples = rng.choice(classes, size=num_samples, p=probabilities)\n        output.append(samples)\n    return np.array(output).astype(dtype)\n\n\ndef randint(shape, minval, maxval, dtype=\"int32\", seed=None):\n    seed = draw_seed(seed)\n    rng = np.random.default_rng(seed)\n    output = rng.integers(low=minval, high=maxval, size=shape, dtype=dtype)\n    return output\n\n\ndef truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = draw_seed(seed)\n    rng = np.random.default_rng(seed)\n\n    lower_bound = mean - 2 * stddev\n    upper_bound = mean + 2 * stddev\n\n    flat_shape = np.prod(shape)\n    random_numbers = np.empty(0)\n\n    # loop until we have enough valid numbers to fill our desired shape\n    while random_numbers.shape[0] < flat_shape:\n        # Generate a batch of random numbers from a normal distribution\n        
batch = rng.normal(loc=mean, scale=stddev, size=flat_shape)\n\n        # Filter the numbers to keep only those within the specified bounds\n        valid = batch[(batch >= lower_bound) & (batch <= upper_bound)]\n\n        # Append the valid numbers to the result array\n        random_numbers = np.append(random_numbers, valid)\n\n    # Truncate the result array to the desired size and reshape it\n    return random_numbers[:flat_shape].astype(dtype).reshape(shape)\n\n\ndef dropout(inputs, rate, noise_shape=None, seed=None):\n    if rate == 1.0:\n        return np.zeros_like(inputs)\n    if rate == 0.0:\n        return inputs\n    dtype = inputs.dtype\n    seed = draw_seed(seed)\n\n    keep_prob = 1.0 - rate\n    # If noise_shape is not provided, use the shape of inputs\n    if noise_shape is None:\n        noise_shape = inputs.shape\n    else:\n        # If noise_shape is provided, replace None with corresponding\n        # input shape\n        noise_shape = [\n            n if n is not None else inputs.shape[i]\n            for i, n in enumerate(noise_shape)\n        ]\n\n    rng = np.random.default_rng(seed)\n    mask = rng.uniform(size=noise_shape) < keep_prob\n    mask = np.broadcast_to(mask, inputs.shape)\n    return np.where(\n        mask, (inputs / keep_prob).astype(dtype), np.zeros_like(inputs)\n    )\n\n\ndef shuffle(x, axis=0, seed=None):\n    seed = draw_seed(seed)\n    rng = np.random.default_rng(seed)\n    return rng.permuted(x, axis=axis)\n\n\ndef gamma(shape, alpha, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = draw_seed(seed)\n    rng = np.random.default_rng(seed)\n    return rng.gamma(alpha, scale=1.0, size=shape).astype(dtype)\n\n\ndef binomial(shape, counts, probabilities, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = draw_seed(seed)\n    rng = np.random.default_rng(seed)\n    sample = rng.binomial(n=counts, p=probabilities, size=shape).astype(dtype)\n    return sample\n\n\ndef beta(shape, alpha, beta, 
dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = draw_seed(seed)\n    rng = np.random.default_rng(seed)\n    sample = rng.beta(a=alpha, b=beta, size=shape).astype(dtype)\n    return sample\n"
  },
  {
    "path": "keras/src/backend/numpy/rnn.py",
    "content": "import numpy as np\n\nfrom keras.src import tree\n\n\ndef rnn(\n    step_function,\n    inputs,\n    initial_states,\n    go_backwards=False,\n    mask=None,\n    constants=None,\n    unroll=False,\n    input_length=None,\n    time_major=False,\n    zero_output_for_mask=False,\n    return_all_outputs=True,\n):\n    def swap_batch_timestep(input_t):\n        # Swap the batch and timestep dim for the incoming tensor.\n        axes = list(range(len(input_t.shape)))\n        axes[0], axes[1] = 1, 0\n        return np.transpose(input_t, axes)\n\n    if not time_major:\n        inputs = tree.map_structure(swap_batch_timestep, inputs)\n\n    flattened_inputs = tree.flatten(inputs)\n    time_steps = flattened_inputs[0].shape[0]\n\n    if mask is not None:\n        if mask.dtype != \"bool\":\n            mask = mask.astype(\"bool\")\n        if len(mask.shape) == 2:\n            mask = np.expand_dims(mask, axis=-1)\n        if not time_major:\n            mask = swap_batch_timestep(mask)\n\n    if constants is None:\n        constants = []\n\n    def _expand_mask(mask_t, input_t, fixed_dim=1):\n        if tree.is_nested(mask_t):\n            raise ValueError(\n                f\"mask_t is expected to be tensor, but got {mask_t}\"\n            )\n        if tree.is_nested(input_t):\n            raise ValueError(\n                f\"input_t is expected to be tensor, but got {input_t}\"\n            )\n        rank_diff = len(input_t.shape) - len(mask_t.shape)\n        for _ in range(rank_diff):\n            mask_t = np.expand_dims(mask_t, -1)\n        multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:])\n        return np.tile(mask_t, multiples)\n\n    if unroll:\n        if not time_steps:\n            raise ValueError(\"Unrolling requires a fixed number of timesteps.\")\n        states = tuple(initial_states)\n        successive_states = []\n        successive_outputs = []\n\n        # Process the input tensors. 
The input tensor needs to be split on the\n        # time_step dim, and reversed if go_backwards is True. In the case of\n        # nested input, the input is flattened and then transformed\n        # individually.  The result of this will be a tuple of lists; each\n        # item in the tuple is a list of tensors with shape (batch, feature)\n        def _process_single_input_t(input_t):\n            input_t = unstack(input_t)  # unstack for time_step dim\n            if go_backwards:\n                input_t.reverse()\n            return input_t\n\n        if tree.is_nested(inputs):\n            processed_input = tree.map_structure(\n                _process_single_input_t, inputs\n            )\n        else:\n            processed_input = (_process_single_input_t(inputs),)\n\n        def _get_input_tensor(time):\n            inp = [t_[time] for t_ in processed_input]\n            return tree.pack_sequence_as(inputs, inp)\n\n        if mask is not None:\n            mask_list = unstack(mask)\n            if go_backwards:\n                mask_list.reverse()\n\n            for i in range(time_steps):\n                inp = _get_input_tensor(i)\n                mask_t = mask_list[i]\n                output, new_states = step_function(\n                    inp, tuple(states) + tuple(constants)\n                )\n                tiled_mask_t = _expand_mask(mask_t, output)\n\n                if not successive_outputs:\n                    prev_output = np.zeros_like(output)\n                else:\n                    prev_output = successive_outputs[-1]\n\n                output = np.where(tiled_mask_t, output, prev_output)\n\n                flat_states = tree.flatten(states)\n                flat_new_states = tree.flatten(new_states)\n                tiled_mask_t = tuple(\n                    _expand_mask(mask_t, s) for s in flat_states\n                )\n                flat_final_states = tuple(\n                    np.where(m, s, ps)\n                    for 
m, s, ps in zip(\n                        tiled_mask_t, flat_new_states, flat_states\n                    )\n                )\n                states = tree.pack_sequence_as(states, flat_final_states)\n\n                if return_all_outputs:\n                    successive_outputs.append(output)\n                    successive_states.append(states)\n                else:\n                    successive_outputs = [output]\n                    successive_states = [states]\n            last_output = successive_outputs[-1]\n            new_states = successive_states[-1]\n            outputs = np.stack(successive_outputs)\n\n        else:  # mask is None\n            for i in range(time_steps):\n                inp = _get_input_tensor(i)\n                output, states = step_function(\n                    inp, tuple(states) + tuple(constants)\n                )\n                if return_all_outputs:\n                    successive_outputs.append(output)\n                    successive_states.append(states)\n                else:\n                    successive_outputs = [output]\n                    successive_states = [states]\n            last_output = successive_outputs[-1]\n            new_states = successive_states[-1]\n            outputs = np.stack(successive_outputs)\n\n    else:  # Unroll == False\n        if mask is not None:\n\n            def _step(states, current_input):\n                current_input, current_mask = current_input\n                is_masked = np.all(\n                    np.logical_not(current_mask), axis=-1, keepdims=True\n                )\n\n                output_t, new_states = step_function(current_input, states)\n\n                if zero_output_for_mask:\n                    masked_outs = np.where(\n                        is_masked, np.zeros_like(output_t), output_t\n                    )\n                else:\n                    # Assume the first state is the previous output.\n                    output_tm1 = states[0]\n    
                if tree.is_nested(output_tm1):\n                        # Stacked RNN case: assume first state of last cell.\n                        output_tm1 = states[-1][0]\n                    masked_outs = np.where(is_masked, output_tm1, output_t)\n\n                new_states = tree.map_structure(\n                    lambda s, ns: np.where(is_masked, s, ns),\n                    states,\n                    new_states,\n                )\n                return (new_states, masked_outs)\n\n            scan_xs = (inputs, mask)\n\n        else:\n\n            def _step(states, current_input):\n                output_t, new_states = step_function(current_input, states)\n                return new_states, output_t\n\n            scan_xs = inputs\n\n        new_states, outputs = numpy_scan(\n            f=_step,\n            init=initial_states,\n            xs=scan_xs,\n            reverse=go_backwards,\n            mask=mask,\n        )\n\n        if go_backwards:\n            outputs = np.flip(outputs, axis=0)\n        last_output = outputs[-1]\n\n    if not time_major:\n        outputs = tree.map_structure(swap_batch_timestep, outputs)\n\n    return last_output, outputs, new_states\n\n\ndef lstm(*args, **kwargs):\n    raise NotImplementedError\n\n\ndef gru(*args, **kwargs):\n    raise NotImplementedError\n\n\ndef unstack(x, axis=0):\n    return [x.take(i, axis) for i in range(x.shape[axis])]\n\n\ndef numpy_scan(f, init, xs, reverse=False, mask=None):\n    states = init\n    outputs = []\n\n    if mask is not None:\n        x, mask = xs\n        x = np.flip(x, axis=0) if reverse else x\n        mask = np.flip(mask, axis=0) if reverse else mask\n\n        for each_x, each_mask in zip(x, mask):\n            states, output = f(states, (each_x, each_mask))\n            outputs.append(output)\n    else:\n        xs = np.flip(xs, axis=0) if reverse else xs\n\n        for x in xs:\n            states, output = f(states, x)\n            outputs.append(output)\n\n    
outputs = np.array(outputs)\n\n    if reverse:\n        outputs = np.flip(outputs, axis=0)\n\n    return states, outputs\n\n\ndef cudnn_ok(*args, **kwargs):\n    return False\n"
  },
  {
    "path": "keras/src/backend/numpy/trainer.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import callbacks as callbacks_module\nfrom keras.src import tree\nfrom keras.src.backend.common import standardize_dtype\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.backend.numpy.core import is_tensor\nfrom keras.src.trainers import trainer as base_trainer\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.trainers.epoch_iterator import EpochIterator\nfrom keras.src.utils import traceback_utils\nfrom keras.src.utils.python_utils import pythonify_logs\n\n\nclass NumpyTrainer(base_trainer.Trainer):\n    def __init__(self):\n        super().__init__()\n        self.test_function = None\n        self.predict_function = None\n\n    def test_step(self, data):\n        (\n            x,\n            y,\n            sample_weight,\n        ) = data_adapter_utils.unpack_x_y_sample_weight(data)\n        if self._call_has_training_arg:\n            y_pred = self(x, training=False)\n        else:\n            y_pred = self(x)\n        loss = self._compute_loss(\n            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False\n        )\n        self._loss_tracker.update_state(\n            loss, sample_weight=tree.flatten(x)[0].shape[0]\n        )\n        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)\n\n    def predict_step(self, data):\n        x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)\n        if self._call_has_training_arg:\n            y_pred = self(x, training=False)\n        else:\n            y_pred = self(x)\n        return y_pred\n\n    def make_test_function(self, force=False):\n        if self.test_function is not None and not force:\n            return self.test_function\n\n        def one_test_step(data):\n            data = data[0]\n            return self.test_step(data)\n\n        def multi_test_steps(data):\n            for single_step_data in data:\n    
            logs = one_test_step([single_step_data])\n            return logs\n\n        if self.steps_per_execution > 1:\n            test_step = multi_test_steps\n        else:\n            test_step = one_test_step\n\n        self.test_function = test_step\n\n    def make_predict_function(self, force=False):\n        if self.predict_function is not None and not force:\n            return self.predict_function\n\n        def one_predict_step(data):\n            data = data[0]\n            return self.predict_step(data)\n\n        def multi_predict_steps(data):\n            outputs = one_predict_step(data[:1])\n\n            for single_step_data in data[1:]:\n                step_outputs = one_predict_step([single_step_data])\n                outputs = tree.map_structure(\n                    lambda t1, t2: np.concatenate([t1, t2]),\n                    outputs,\n                    step_outputs,\n                )\n            return outputs\n\n        if self.steps_per_execution > 1:\n            predict_step = multi_predict_steps\n        else:\n            predict_step = one_predict_step\n\n        self.predict_function = predict_step\n\n    def _symbolic_build(self, data_batch):\n        model_unbuilt = not all(layer.built for layer in self._flatten_layers())\n        compile_metrics_unbuilt = (\n            self._compile_metrics is not None\n            and not self._compile_metrics.built\n        )\n        compile_loss_unbuilt = (\n            self._compile_loss is not None and not self._compile_loss.built\n        )\n        if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt:\n            # Create symbolic tensors matching an input batch.\n\n            def to_symbolic_input(v):\n                if is_tensor(v):\n                    return KerasTensor(v.shape, standardize_dtype(v.dtype))\n                return v\n\n            data_batch = tree.map_structure(to_symbolic_input, data_batch)\n            (\n                x,\n              
  y,\n                sample_weight,\n            ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)\n            # Build all model state with `backend.compute_output_spec`.\n            try:\n                y_pred = backend.compute_output_spec(self, x)\n            except:\n                raise RuntimeError(\n                    \"Unable to automatically build the model. \"\n                    \"Please build it yourself before calling \"\n                    \"fit/evaluate/predict. \"\n                    \"A model is 'built' when its variables have \"\n                    \"been created and its `self.built` attribute \"\n                    \"is True. Usually, calling the model on a batch \"\n                    \"of data is the right way to build it.\"\n                )\n            if compile_metrics_unbuilt:\n                # Build all metric state with `backend.compute_output_spec`.\n                backend.compute_output_spec(\n                    self.compute_metrics,\n                    x,\n                    y,\n                    y_pred,\n                    sample_weight=sample_weight,\n                )\n            if compile_loss_unbuilt:\n                # Build `CompileLoss` state with `backend.compute_output_spec`.\n                backend.compute_output_spec(\n                    self._compute_loss,\n                    x,\n                    y,\n                    y_pred,\n                    sample_weight=sample_weight,\n                )\n        self._post_build()\n\n    def fit(\n        self,\n        x=None,\n        y=None,\n        batch_size=None,\n        epochs=1,\n        verbose=\"auto\",\n        callbacks=None,\n        validation_split=0.0,\n        validation_data=None,\n        shuffle=True,\n        class_weight=None,\n        sample_weight=None,\n        initial_epoch=0,\n        steps_per_epoch=None,\n        validation_steps=None,\n        validation_batch_size=None,\n        validation_freq=1,\n    ):\n  
      raise NotImplementedError(\"fit not implemented for NumPy backend.\")\n\n    @traceback_utils.filter_traceback\n    def predict(\n        self, x, batch_size=None, verbose=\"auto\", steps=None, callbacks=None\n    ):\n        # Create an iterator that yields batches of input data.\n        epoch_iterator = EpochIterator(\n            x=x,\n            batch_size=batch_size,\n            steps_per_epoch=steps,\n            shuffle=False,\n            steps_per_execution=self.steps_per_execution,\n        )\n\n        # Container that configures and calls callbacks.\n        if not isinstance(callbacks, callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_progbar=verbose != 0,\n                verbose=verbose,\n                epochs=1,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        def append_to_outputs(batch_outputs, outputs):\n            if outputs is None:\n                outputs = tree.map_structure(\n                    lambda batch_output: [batch_output],\n                    batch_outputs,\n                )\n            else:\n                tree.map_structure_up_to(\n                    batch_outputs,\n                    lambda output, batch_output: output.append(batch_output),\n                    outputs,\n                    batch_outputs,\n                )\n            return outputs\n\n        self.make_predict_function()\n        self.stop_predicting = False\n        callbacks.on_predict_begin()\n        outputs = None\n        for begin_step, end_step, data in epoch_iterator:\n            callbacks.on_predict_batch_begin(begin_step)\n            batch_outputs = self.predict_function(data)\n            outputs = append_to_outputs(batch_outputs, outputs)\n            callbacks.on_predict_batch_end(end_step, {\"outputs\": batch_outputs})\n            if self.stop_predicting:\n                break\n    
    callbacks.on_predict_end()\n        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)\n\n    @traceback_utils.filter_traceback\n    def evaluate(\n        self,\n        x=None,\n        y=None,\n        batch_size=None,\n        verbose=\"auto\",\n        sample_weight=None,\n        steps=None,\n        callbacks=None,\n        return_dict=False,\n        **kwargs,\n    ):\n        # TODO: respect compiled trainable state\n        use_cached_eval_dataset = kwargs.pop(\"_use_cached_eval_dataset\", False)\n        if kwargs:\n            raise ValueError(f\"Arguments not recognized: {kwargs}\")\n\n        if use_cached_eval_dataset:\n            epoch_iterator = self._eval_epoch_iterator\n        else:\n            # Create an iterator that yields batches of input/target data.\n            epoch_iterator = EpochIterator(\n                x=x,\n                y=y,\n                sample_weight=sample_weight,\n                batch_size=batch_size,\n                steps_per_epoch=steps,\n                shuffle=False,\n                steps_per_execution=self.steps_per_execution,\n            )\n\n        if not all(layer.built for layer in self._flatten_layers()):\n            # Build the model on one batch of data.\n            for _, _, data in epoch_iterator:\n                data_batch = data[0]\n                self._symbolic_build(data_batch)\n                break\n\n        # Container that configures and calls callbacks.\n        if not isinstance(callbacks, callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_progbar=verbose != 0,\n                verbose=verbose,\n                epochs=1,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        self.make_test_function()\n        self.stop_evaluating = False\n        callbacks.on_test_begin()\n        logs = {}\n        self.reset_metrics()\n     
   for begin_step, end_step, data in epoch_iterator:\n            callbacks.on_test_batch_begin(begin_step)\n            logs = self.test_function(data)\n            callbacks.on_test_batch_end(end_step, logs)\n            if self.stop_evaluating:\n                break\n        logs = pythonify_logs(self._get_metrics_result_or_logs(logs))\n        callbacks.on_test_end(logs)\n\n        if return_dict:\n            return logs\n        return self._flatten_metrics_in_order(logs)\n\n    def train_on_batch(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        class_weight=None,\n        return_dict=False,\n    ):\n        raise NotImplementedError(\n            \"train_on_batch not implemented for NumPy backend.\"\n        )\n\n    def test_on_batch(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        return_dict=False,\n    ):\n        self._assert_compile_called(\"test_on_batch\")\n\n        data = (x, y, sample_weight)\n\n        # Maybe build model\n        self._symbolic_build(data)\n        self.make_test_function()\n\n        logs = self.test_function([data])\n        logs = pythonify_logs(logs)\n        if return_dict:\n            return logs\n        return self._flatten_metrics_in_order(logs)\n\n    def predict_on_batch(self, x):\n        self.make_predict_function()\n        batch_outputs = self.predict_function([(x,)])\n        batch_outputs = tree.map_structure(\n            backend.convert_to_numpy, batch_outputs\n        )\n        return batch_outputs\n"
  },
  {
    "path": "keras/src/backend/openvino/__init__.py",
    "content": "from keras.src.backend.common.name_scope import name_scope\nfrom keras.src.backend.openvino import core\nfrom keras.src.backend.openvino import image\nfrom keras.src.backend.openvino import linalg\nfrom keras.src.backend.openvino import math\nfrom keras.src.backend.openvino import nn\nfrom keras.src.backend.openvino import numpy\nfrom keras.src.backend.openvino import random\nfrom keras.src.backend.openvino.core import IS_THREAD_SAFE\nfrom keras.src.backend.openvino.core import SUPPORTS_RAGGED_TENSORS\nfrom keras.src.backend.openvino.core import SUPPORTS_SPARSE_TENSORS\nfrom keras.src.backend.openvino.core import Variable\nfrom keras.src.backend.openvino.core import cast\nfrom keras.src.backend.openvino.core import compute_output_spec\nfrom keras.src.backend.openvino.core import cond\nfrom keras.src.backend.openvino.core import convert_to_numpy\nfrom keras.src.backend.openvino.core import convert_to_tensor\nfrom keras.src.backend.openvino.core import device_scope\nfrom keras.src.backend.openvino.core import is_tensor\nfrom keras.src.backend.openvino.core import random_seed_dtype\nfrom keras.src.backend.openvino.core import shape\nfrom keras.src.backend.openvino.core import vectorized_map\nfrom keras.src.backend.openvino.rnn import cudnn_ok\nfrom keras.src.backend.openvino.rnn import gru\nfrom keras.src.backend.openvino.rnn import lstm\nfrom keras.src.backend.openvino.rnn import rnn\n"
  },
  {
    "path": "keras/src/backend/openvino/core.py",
    "content": "import builtins\nimport contextlib\nimport warnings\n\nimport numpy as np\nimport openvino as ov\nimport openvino.opset15 as ov_opset\nfrom openvino import Model\nfrom openvino import Tensor\nfrom openvino import Type\nfrom openvino import compile_model\n\nfrom keras.src import tree\nfrom keras.src.backend.common import KerasVariable\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common import standardize_dtype\nfrom keras.src.backend.common.backend_utils import slice_along_axis\nfrom keras.src.backend.common.dtypes import result_type\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.backend.common.stateless_scope import StatelessScope\n\nSUPPORTS_SPARSE_TENSORS = False\nSUPPORTS_RAGGED_TENSORS = False\nIS_THREAD_SAFE = True\n\nOPENVINO_DTYPES = {\n    \"float16\": ov.Type.f16,\n    \"float32\": ov.Type.f32,\n    \"float64\": ov.Type.f64,\n    \"uint8\": ov.Type.u8,\n    \"uint16\": ov.Type.u16,\n    \"uint32\": ov.Type.u32,\n    \"uint64\": ov.Type.u64,\n    \"int8\": ov.Type.i8,\n    \"int16\": ov.Type.i16,\n    \"int32\": ov.Type.i32,\n    \"int64\": ov.Type.i64,\n    \"bfloat16\": ov.Type.bf16,\n    \"bool\": ov.Type.boolean,\n    \"float8_e4m3fn\": ov.Type.f8e4m3,\n    \"float8_e5m2\": ov.Type.f8e5m2,\n    \"string\": ov.Type.string,\n}\n\nDTYPES_MAX = {\n    ov.Type.bf16: 3.38953139e38,\n    ov.Type.f16: np.finfo(np.float16).max,\n    ov.Type.f32: np.finfo(np.float32).max,\n    ov.Type.f64: np.finfo(np.float64).max,\n    ov.Type.u8: np.iinfo(np.uint8).max,\n    ov.Type.u16: np.iinfo(np.uint16).max,\n    ov.Type.u32: np.iinfo(np.uint32).max,\n    ov.Type.u64: np.iinfo(np.uint64).max,\n    ov.Type.i8: np.iinfo(np.int8).max,\n    ov.Type.i16: np.iinfo(np.int16).max,\n    ov.Type.i32: np.iinfo(np.int32).max,\n    ov.Type.i64: np.iinfo(np.int64).max,\n    ov.Type.boolean: 1,\n}\n\nDTYPES_MIN = {\n    ov.Type.bf16: -3.38953139e38,\n    ov.Type.f16: np.finfo(np.float16).min,\n    ov.Type.f32: 
np.finfo(np.float32).min,\n    ov.Type.f64: np.finfo(np.float64).min,\n    ov.Type.u8: np.iinfo(np.uint8).min,\n    ov.Type.u16: np.iinfo(np.uint16).min,\n    ov.Type.u32: np.iinfo(np.uint32).min,\n    ov.Type.u64: np.iinfo(np.uint64).min,\n    ov.Type.i8: np.iinfo(np.int8).min,\n    ov.Type.i16: np.iinfo(np.int16).min,\n    ov.Type.i32: np.iinfo(np.int32).min,\n    ov.Type.i64: np.iinfo(np.int64).min,\n    ov.Type.boolean: 0,\n}\n\n\ndef align_operand_types(x1, x2, op_name):\n    x1_type = x1.element_type\n    x2_type = x2.element_type\n    if x1_type.is_dynamic() or x2_type.is_dynamic():\n        raise ValueError(\n            f\"'{op_name}' operation is not supported for dynamic operand type \"\n            \"with openvino backend\"\n        )\n    x1_type = ov_to_keras_type(x1_type)\n    x2_type = ov_to_keras_type(x2_type)\n    result_type = dtypes.result_type(x1_type, x2_type)\n    result_type = OPENVINO_DTYPES[result_type]\n    if x1_type != result_type:\n        x1 = ov_opset.convert(x1, result_type).output(0)\n    if x2_type != result_type:\n        x2 = ov_opset.convert(x2, result_type).output(0)\n    return x1, x2\n\n\n# create ov.Output (symbolic OpenVINO tensor)\n# for different input `x`\ndef get_ov_output(x, ov_type=None):\n    if isinstance(x, float):\n        if ov_type is None:\n            ov_type = Type.f32\n        x = ov_opset.constant(x, ov_type).output(0)\n    elif isinstance(x, int):\n        if ov_type is None:\n            ov_type = Type.i32\n        x = ov_opset.constant(x, ov_type).output(0)\n    elif isinstance(x, np.ndarray):\n        if x.dtype == np.dtype(\"bfloat16\"):\n            x = ov_opset.constant(x, OPENVINO_DTYPES[\"bfloat16\"]).output(0)\n        else:\n            x = ov_opset.constant(x).output(0)\n    elif isinstance(x, (list, tuple)):\n        if isinstance(x, tuple):\n            x = list(x)\n        if ov_type is None:\n            x = ov_opset.constant(x).output(0)\n        else:\n            x = ov_opset.constant(x, 
ov_type).output(0)\n    elif np.isscalar(x):\n        x = ov_opset.constant(x).output(0)\n    elif isinstance(x, KerasVariable):\n        if isinstance(x.value, OpenVINOKerasTensor):\n            return x.value.output\n        x = ov_opset.constant(x.value.data).output(0)\n    elif isinstance(x, OpenVINOKerasTensor):\n        x = x.output\n    elif isinstance(x, ov.Output):\n        return x\n    elif isinstance(x, Tensor):\n        x = ov_opset.constant(x.data).output(0)\n    else:\n        raise ValueError(\n            \"unsupported type of `x` to create ov.Output: {}\".format(type(x))\n        )\n    return x\n\n\n# wrapper for OpenVINO symbolic tensor ov.Output\n# that provides interface similar to KerasTensor\n# with dtype and shape members\nclass OpenVINOKerasTensor:\n    def __init__(self, x, data=None):\n        x_shape = x.get_partial_shape()\n        if x_shape.rank.is_dynamic:\n            x_keras_shape = None\n        else:\n            x_keras_shape = [\n                None if dim.is_dynamic else dim.get_length()\n                for dim in list(x_shape)\n            ]\n        x_type = x.get_element_type()\n        x_keras_type = ov_to_keras_type(x_type)\n        self.output = x\n        if x_keras_shape is not None:\n            self.shape = tuple(x_keras_shape)\n        else:\n            self.shape = None\n        self.dtype = x_keras_type\n        self.ndim = None\n        self.data = data\n        if x.get_partial_shape().rank.is_static:\n            self.ndim = x.get_partial_shape().rank.get_length()\n\n    def __add__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__add__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.add(first, other).output(0))\n\n    def __radd__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            
first, other, \"OpenVINOKerasTensor::__radd__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.add(first, other).output(0))\n\n    def __sub__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__sub__\"\n        )\n        if first.get_element_type() == Type.boolean:\n            return OpenVINOKerasTensor(\n                ov_opset.logical_xor(first, other).output(0)\n            )\n        return OpenVINOKerasTensor(ov_opset.subtract(first, other).output(0))\n\n    def __rsub__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__rsub__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.subtract(other, first).output(0))\n\n    def __mul__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__mul__\"\n        )\n        if first.get_element_type() == Type.boolean:\n            return OpenVINOKerasTensor(\n                ov_opset.logical_and(first, other).output(0)\n            )\n        return OpenVINOKerasTensor(ov_opset.multiply(first, other).output(0))\n\n    def __rmul__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__rmul__\"\n        )\n        if first.get_element_type() == Type.boolean:\n            return OpenVINOKerasTensor(\n                ov_opset.logical_and(first, other).output(0)\n            )\n        return OpenVINOKerasTensor(ov_opset.multiply(first, other).output(0))\n\n    def __truediv__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            
first, other, \"OpenVINOKerasTensor::__truediv__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.divide(first, other).output(0))\n\n    def __rtruediv__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__rtruediv__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.divide(other, first).output(0))\n\n    def __floordiv__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__floordiv__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.divide(first, other).output(0))\n\n    def __rfloordiv__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__rfloordiv__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.divide(other, first).output(0))\n\n    def __neg__(self):\n        first = self.output\n        return OpenVINOKerasTensor(ov_opset.negative(first).output(0))\n\n    def __abs__(self):\n        first = self.output\n        return OpenVINOKerasTensor(ov_opset.absolute(first).output(0))\n\n    def __invert__(self):\n        first = self.output\n        return OpenVINOKerasTensor(ov_opset.logical_not(first).output(0))\n\n    def __pow__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__pow__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.power(first, other).output(0))\n\n    def __rpow__(self, other):\n        other = get_ov_output(other)\n        first = self.output\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__rpow__\"\n        )\n        return 
OpenVINOKerasTensor(ov_opset.power(other, first).output(0))\n\n    def __lt__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__lt__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.less(first, other).output(0))\n\n    def __gt__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__gt__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.greater(first, other).output(0))\n\n    def __le__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__le__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.less_equal(first, other).output(0))\n\n    def __ge__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__ge__\"\n        )\n        return OpenVINOKerasTensor(\n            ov_opset.greater_equal(first, other).output(0)\n        )\n\n    def __eq__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__eq__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.equal(first, other).output(0))\n\n    def __ne__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__ne__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.not_equal(first, other).output(0))\n\n    def __getitem__(self, indices):\n        data = self.output\n        rank = len(data.get_partial_shape())\n        axes, gather_indices_nodes = [], 
[]\n        slice_axes, slice_starts, slice_ends, slice_steps = [], [], [], []\n        unsqueeze_axes = []\n\n        if not isinstance(indices, tuple):\n            indices = (indices,)\n\n        if any(i is Ellipsis for i in indices):\n            ellipsis_pos = indices.index(Ellipsis)\n            num_specified = sum(\n                i is not Ellipsis and i is not None for i in indices\n            )\n            num_missing = rank - num_specified\n            indices = (\n                indices[:ellipsis_pos]\n                + (builtins.slice(None),) * num_missing\n                + indices[ellipsis_pos + 1 :]\n            )\n\n        def count_unsqueeze_before(dim):\n            return sum(1 for i in range(dim) if indices[i] is None)\n\n        partial_shape = ov_opset.shape_of(data, Type.i32)\n        zero_const = ov_opset.constant(0, Type.i32)\n\n        for dim, index in enumerate(indices):\n            if isinstance(index, bool):\n                raise ValueError(\n                    \"OpenVINO backend does not support boolean indexing\"\n                )\n            elif isinstance(index, (int, np.integer, np.ndarray)):\n                if isinstance(index, (np.ndarray, np.integer)):\n                    if isinstance(index, np.ndarray) and len(index.shape) != 0:\n                        raise ValueError(\n                            \"OpenVINO backend does not support \"\n                            \"multi-dimensional indexing\"\n                        )\n                    index = int(index)\n                actual_dim = dim - count_unsqueeze_before(dim)\n                if not (0 <= actual_dim < rank):\n                    raise IndexError(\n                        f\"Index {index} is out of bounds for \"\n                        f\"axis {dim} with rank {rank}\"\n                    )\n                length = ov_opset.gather(\n                    partial_shape,\n                    ov_opset.constant([actual_dim], Type.i32),\n                
    zero_const,\n                )\n                if index >= 0:\n                    idx_value = ov_opset.constant([index], Type.i32)\n                else:\n                    idx_value = ov_opset.add(\n                        ov_opset.constant([index], Type.i32), length\n                    )\n                axes.append(dim)\n                gather_indices_nodes.append(idx_value.output(0))\n            elif isinstance(index, builtins.slice):\n                if index == builtins.slice(None):\n                    continue\n                if index.step is not None and index.step < 0:\n                    raise ValueError(\"OpenVINO doesn't support negative steps\")\n                slice_axes.append(dim)\n                slice_starts.append(0 if index.start is None else index.start)\n                slice_ends.append(\n                    2**31 - 1 if index.stop is None else index.stop\n                )\n                slice_steps.append(1 if index.step is None else index.step)\n            elif index is None:\n                unsqueeze_axes.append(dim)\n            elif isinstance(index, OpenVINOKerasTensor):\n                index = get_ov_output(index)\n                index_type = index.get_element_type()\n                index_shape = index.get_partial_shape()\n                if index_type == Type.boolean or not index_type.is_integral():\n                    raise ValueError(\n                        \"OpenVINO backend does not \"\n                        f\"support {index_type} indexing\"\n                    )\n                axes.append(dim)\n                if len(index_shape) > 1:\n                    raise ValueError(\n                        \"OpenVINO backend does not \"\n                        \"support multi-dimensional indexing\"\n                    )\n                if len(index_shape) == 0:\n                    index = ov_opset.unsqueeze(index, zero_const).output(0)\n                if index_type != Type.i32:\n                    
index = ov_opset.convert(index, Type.i32).output(0)\n                shape_tensor = ov_opset.shape_of(data, Type.i32)\n                axis_i32 = ov_opset.constant([dim], dtype=Type.i32)\n                dim_size = ov_opset.gather(shape_tensor, axis_i32, zero_const)\n                is_negative = ov_opset.less(index, zero_const)\n                adjusted_index = ov_opset.add(index, dim_size)\n                index = ov_opset.select(\n                    is_negative, adjusted_index, index\n                ).output(0)\n                gather_indices_nodes.append(index)\n            else:\n                raise ValueError(\n                    f\"Unsupported index type {type(index)} \"\n                    \"in OpenVINOKerasTensor.__getitem__\"\n                )\n\n        if slice_axes:\n            step = ov_opset.constant(slice_steps, Type.i32).output(0)\n            start = ov_opset.constant(slice_starts, Type.i32).output(0)\n            stop = ov_opset.constant(slice_ends, Type.i32).output(0)\n            adjusted_slice_axes = [\n                ax - sum(1 for unsq in unsqueeze_axes if unsq <= ax)\n                for ax in slice_axes\n            ]\n            axes_const = ov_opset.constant(\n                adjusted_slice_axes, Type.i32\n            ).output(0)\n            data = ov_opset.slice(data, start, stop, step, axes_const).output(0)\n\n        if axes:\n            gather_indices_const = (\n                gather_indices_nodes[0]\n                if len(gather_indices_nodes) == 1\n                else ov_opset.concat(gather_indices_nodes, axis=0).output(0)\n            )\n            adjusted_axes = [\n                ax - sum(1 for unsq in unsqueeze_axes if unsq <= ax)\n                for ax in axes\n            ]\n            if len(axes) == 1:\n                data = ov_opset.gather(\n                    data, gather_indices_const, adjusted_axes[0]\n                ).output(0)\n                data = ov_opset.squeeze(data, 
adjusted_axes[0]).output(0)\n            else:\n                rank = len(data.get_partial_shape())\n                remaining_axes = [\n                    i for i in range(rank) if i not in adjusted_axes\n                ]\n                perm = ov_opset.constant(\n                    adjusted_axes + remaining_axes, Type.i32\n                )\n                data = ov_opset.transpose(data, perm).output(0)\n                data = ov_opset.gather_nd(data, gather_indices_const).output(0)\n\n        if unsqueeze_axes:\n            adjusted_unsqueeze = []\n            for ax in unsqueeze_axes:\n                ax -= sum(1 for s in axes if s < ax)\n                ax -= sum(1 for s in slice_axes if s < ax)\n                adjusted_unsqueeze.append(ax)\n            unsqueeze_const = ov_opset.constant(\n                adjusted_unsqueeze, Type.i32\n            ).output(0)\n            data = ov_opset.unsqueeze(data, unsqueeze_const).output(0)\n\n        return OpenVINOKerasTensor(data)\n\n    def __len__(self):\n        ov_output = self.output\n        ov_shape = ov_output.get_partial_shape()\n        if not (ov_shape.rank.is_static and ov_shape.rank.get_length() > 0):\n            raise ValueError(\n                \"Rank must be static and greater than zero to compute `len()`. \"\n                f\"rank={ov_shape.rank}\"\n            )\n        if not ov_shape[0].is_static:\n            raise ValueError(\n                \"The first dimension must be static to compute `len()`. 
\"\n                f\"shape={ov_shape}\"\n            )\n        return ov_shape[0].get_length()\n\n    def __iter__(self):\n        if self.shape is None or len(self.shape) == 0:\n            raise TypeError(\"iteration over a 0-d tensor\")\n        for i in range(self.shape[0]):\n            yield self[i]\n\n    def __bool__(self):\n        return bool(self.numpy())\n\n    def __mod__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__mod__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.mod(first, other).output(0))\n\n    def __array__(self, dtype=None):\n        try:\n            tensor = cast(self, dtype=dtype) if dtype is not None else self\n            return convert_to_numpy(tensor)\n        except Exception as e:\n            raise RuntimeError(\n                \"An OpenVINOKerasTensor is symbolic: it's a placeholder \"\n                \"for a shape and a dtype.\\n\"\n                \"It doesn't have any actual numerical value.\\n\"\n                \"You cannot convert it to a NumPy array.\"\n            ) from e\n\n    def numpy(self):\n        return self.__array__()\n\n    def __rmod__(self, other):\n        other = get_ov_output(other)\n        first = self.output\n        other, first = align_operand_types(\n            other, first, \"OpenVINOKerasTensor::__rmod__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.mod(other, first).output(0))\n\n    def __matmul__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__matmul__\"\n        )\n        return OpenVINOKerasTensor(\n            ov_opset.matmul(first, other, False, False).output(0)\n        )\n\n    def __rmatmul__(self, other):\n        other = get_ov_output(other)\n        first = self.output\n        other, first = 
align_operand_types(\n            other, first, \"OpenVINOKerasTensor::__rmatmul__\"\n        )\n        return OpenVINOKerasTensor(\n            ov_opset.matmul(other, first, False, False).output(0)\n        )\n\n    def __div__(self, other):\n        return self.__truediv__(other)\n\n    def __rdiv__(self, other):\n        return self.__rtruediv__(other)\n\n    def __and__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__and__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.logical_and(first, other).output(0))\n\n    def __rand__(self, other):\n        other = get_ov_output(other)\n        first = self.output\n        other, first = align_operand_types(\n            other, first, \"OpenVINOKerasTensor::__rand__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.logical_and(other, first).output(0))\n\n    def __or__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__or__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.logical_or(first, other).output(0))\n\n    def __ror__(self, other):\n        other = get_ov_output(other)\n        first = self.output\n        other, first = align_operand_types(\n            other, first, \"OpenVINOKerasTensor::__ror__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.logical_or(other, first).output(0))\n\n    def __xor__(self, other):\n        first = self.output\n        other = get_ov_output(other)\n        first, other = align_operand_types(\n            first, other, \"OpenVINOKerasTensor::__xor__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.logical_xor(first, other).output(0))\n\n    def __rxor__(self, other):\n        other = get_ov_output(other)\n        first = self.output\n        other, first = align_operand_types(\n      
      other, first, \"OpenVINOKerasTensor::__rxor__\"\n        )\n        return OpenVINOKerasTensor(ov_opset.logical_xor(other, first).output(0))\n\n    def __int__(self):\n        arr = convert_to_numpy(self)\n        if arr.ndim > 0:\n            raise TypeError(\n                \"Only scalar arrays can be converted to Python scalars. \"\n                f\"Got: shape={arr.shape}\"\n            )\n        return int(arr)\n\n    def __float__(self):\n        arr = convert_to_numpy(self)\n        if arr.ndim > 0:\n            raise TypeError(\n                \"Only scalar arrays can be converted to Python scalars. \"\n                f\"Got: shape={arr.shape}\"\n            )\n        return float(arr)\n\n    def __repr__(self):\n        return f\"<OpenVINOKerasTensor shape={self.shape}, dtype={self.dtype}>\"\n\n    def __round__(self, ndigits=None):\n        first = self.output\n        decimals = ndigits or 0\n        if decimals == 0:\n            result = ov_opset.round(first, \"half_to_even\")\n        else:\n            factor = ov_opset.constant(10.0**decimals, first.get_element_type())\n            scaled = ov_opset.multiply(first, factor)\n            rounded = ov_opset.round(scaled, \"half_to_even\")\n            result = ov_opset.divide(rounded, factor)\n        return OpenVINOKerasTensor(result.output(0))\n\n    def reshape(self, new_shape):\n        first = self.output\n        shape_const = get_ov_output(new_shape)\n        return OpenVINOKerasTensor(\n            ov_opset.reshape(first, shape_const, False).output(0)\n        )\n\n    def squeeze(self, axis=None):\n        first = self.output\n        if axis is not None:\n            axes = get_ov_output([axis] if isinstance(axis, int) else axis)\n        else:\n            axes = get_ov_output(\n                [i for i, d in enumerate(self.shape) if d == 1]\n            )\n        return OpenVINOKerasTensor(ov_opset.squeeze(first, axes).output(0))\n\n\ndef ov_to_keras_type(ov_type):\n    for 
_keras_type, _ov_type in OPENVINO_DTYPES.items():\n        if ov_type == _ov_type:\n            return _keras_type\n    raise ValueError(\n        f\"Requested OpenVINO type has no keras analogue '{ov_type.to_string()}'\"\n    )\n\n\n@contextlib.contextmanager\ndef device_scope(device_name):\n    yield\n\n\ndef get_device():\n    return \"CPU\"\n\n\nclass Variable(KerasVariable):\n    def _initialize(self, value):\n        if isinstance(value, OpenVINOKerasTensor):\n            self._value = value\n        elif isinstance(value, Tensor):\n            value_const = ov_opset.constant(\n                value.data, dtype=OPENVINO_DTYPES[self._dtype]\n            )\n            self._value = OpenVINOKerasTensor(value_const.output(0))\n        else:\n            value_const = ov_opset.constant(\n                value, dtype=OPENVINO_DTYPES[self._dtype]\n            )\n            self._value = OpenVINOKerasTensor(value_const.output(0))\n\n    def _direct_assign(self, value):\n        self._value = value\n\n    def _convert_to_tensor(self, value, dtype=None):\n        return convert_to_tensor(value, dtype=dtype)\n\n    def __array__(self):\n        return convert_to_numpy(self)\n\n    def __getitem__(self, idx):\n        arr = convert_to_numpy(self)\n        return arr.__getitem__(idx)\n\n    def __int__(self):\n        arr = convert_to_numpy(self)\n        if arr.ndim > 0:\n            raise TypeError(\n                \"Only scalar arrays can be converted to Python scalars. \"\n                f\"Got: shape={arr.shape}\"\n            )\n        return int(arr)\n\n    def __float__(self):\n        arr = convert_to_numpy(self)\n        if arr.ndim > 0:\n            raise TypeError(\n                \"Only scalar arrays can be converted to Python scalars. 
\"\n                f\"Got: shape={arr.shape}\"\n            )\n        return float(arr)\n\n\ndef _is_scalar(elem):\n    return not isinstance(elem, (list, tuple, set, dict))\n\n\ndef convert_to_tensor(x, dtype=None, sparse=None, ragged=None):\n    if sparse:\n        raise ValueError(\"`sparse=True` is not supported with openvino backend\")\n    if ragged:\n        raise ValueError(\"`ragged=True` is not supported with openvino backend\")\n    if dtype is not None:\n        dtype = standardize_dtype(dtype)\n    if isinstance(x, OpenVINOKerasTensor):\n        if dtype and dtype != standardize_dtype(x.dtype):\n            x = cast(x, dtype)\n        return x\n    elif isinstance(x, np.ndarray):\n        if dtype is not None:\n            ov_type = OPENVINO_DTYPES[dtype]\n        else:\n            ov_type = OPENVINO_DTYPES[standardize_dtype(x.dtype)]\n        return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0))\n    elif isinstance(x, (list, tuple)):\n        if dtype is None:\n            dtype = result_type(\n                *[\n                    getattr(item, \"dtype\", type(item))\n                    for item in tree.flatten(x)\n                ]\n            )\n        x = np.array(x, dtype=dtype)\n        ov_type = OPENVINO_DTYPES[dtype]\n        return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0), x)\n    elif isinstance(x, (float, int, bool)):\n        if dtype is None:\n            dtype = standardize_dtype(type(x))\n        ov_type = OPENVINO_DTYPES[dtype]\n        return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0), x)\n    elif isinstance(x, ov.Output):\n        return OpenVINOKerasTensor(x)\n    if isinstance(x, Variable):\n        x = x.value\n        if dtype and dtype != x.dtype:\n            x = cast(x, dtype)\n        return x\n    original_type = type(x)\n    try:\n        if dtype is None:\n            dtype = getattr(x, \"dtype\", original_type)\n            ov_type = 
OPENVINO_DTYPES[standardize_dtype(dtype)]\n        else:\n            ov_type = OPENVINO_DTYPES[dtype]\n        x = np.array(x)\n        return OpenVINOKerasTensor(ov_opset.constant(x, ov_type).output(0))\n    except Exception as e:\n        raise TypeError(\n            f\"Cannot convert object of type {original_type} \"\n            f\"to OpenVINOKerasTensor: {e}\"\n        )\n\n\ndef convert_to_numpy(x):\n    if isinstance(x, np.ndarray):\n        return x\n    elif isinstance(x, (int, float)):\n        return np.array(x)\n    elif isinstance(x, (list, tuple)):\n        x_new = []\n        for elem in x:\n            x_new.append(convert_to_numpy(elem))\n        return np.array(x_new)\n    elif np.isscalar(x):\n        return x\n    elif isinstance(x, ov.Tensor):\n        return x.data\n    elif x is None:\n        return x\n    elif isinstance(x, KerasVariable):\n        if isinstance(x.value, OpenVINOKerasTensor):\n            x = x.value\n        else:\n            return x.value.data\n    if not isinstance(x, OpenVINOKerasTensor):\n        raise ValueError(f\"unsupported type {type(x)} for `convert_to_numpy`.\")\n    # if the tensor is backed by a Constant OV node, extract\n    # its data array directly without compiling a model.\n    try:\n        node = x.output.get_node()\n        if node.get_type_name() == \"Constant\":\n            return np.array(node.data)\n    except Exception:\n        # fall back to the slow path.\n        pass\n    try:\n        ov_result = x.output\n        ov_model = Model(results=[ov_result], parameters=[])\n        ov_compiled_model = compile_model(ov_model, get_device())\n        result = ov_compiled_model({})[0]\n    except Exception as inner_exception:\n        raise RuntimeError(\n            \"`convert_to_numpy` failed to convert the tensor.\"\n        ) from inner_exception\n    return np.array(result)\n\n\ndef is_tensor(x):\n    if isinstance(x, OpenVINOKerasTensor):\n        return True\n    if isinstance(x, 
ov.Tensor):\n        return True\n    return False\n\n\ndef shape(x):\n    return tuple(x.shape)\n\n\ndef cast(x, dtype):\n    dtype = standardize_dtype(dtype)\n    ov_type = OPENVINO_DTYPES[dtype]\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.convert(x, ov_type).output(0))\n\n\ndef cond(pred, true_fn, false_fn):\n    true_val = true_fn()\n    false_val = false_fn()\n\n    if true_val is None:\n        return None\n\n    if isinstance(pred, bool):\n        pred_ov = ov_opset.constant(pred, Type.boolean).output(0)\n    else:\n        pred_ov = get_ov_output(pred)\n        if pred_ov.get_element_type() != Type.boolean:\n            pred_ov = ov_opset.convert(pred_ov, Type.boolean).output(0)\n\n    def _select(t, f):\n        t_ov, f_ov = align_operand_types(\n            get_ov_output(t), get_ov_output(f), \"cond\"\n        )\n        return OpenVINOKerasTensor(\n            ov_opset.select(pred_ov, t_ov, f_ov).output(0)\n        )\n\n    if isinstance(true_val, (list, tuple)):\n        return type(true_val)(\n            _select(t, f) for t, f in zip(true_val, false_val)\n        )\n    return _select(true_val, false_val)\n\n\ndef vectorized_map(function, elements):\n    return map(function, elements)\n\n\n# Shape / dtype inference util\ndef compute_output_spec(fn, *args, **kwargs):\n    with StatelessScope():\n\n        def convert_keras_tensor_to_openvino(x):\n            if isinstance(x, KerasTensor):\n                x_shape = list(x.shape)\n                x_shape = [-1 if dim is None else dim for dim in x_shape]\n                x_type = OPENVINO_DTYPES[x.dtype]\n                param = ov_opset.parameter(shape=x_shape, dtype=x_type)\n                return OpenVINOKerasTensor(param.output(0))\n            return x\n\n        args_1, kwargs_1 = tree.map_structure(\n            lambda x: convert_keras_tensor_to_openvino(x),\n            (args, kwargs),\n        )\n        outputs_1 = fn(*args_1, **kwargs_1)\n\n        outputs = 
outputs_1\n\n        def convert_openvino_to_keras_tensor(x):\n            if is_tensor(x) or isinstance(x, OpenVINOKerasTensor):\n                return KerasTensor(x.shape, x.dtype)\n            return x\n\n        output_spec = tree.map_structure(\n            convert_openvino_to_keras_tensor, outputs\n        )\n    return output_spec\n\n\ndef map(f, xs):\n    def g(_, x):\n        return (), f(x)\n\n    _, ys = scan(g, (), xs)\n    return ys\n\n\ndef scan(f, init, xs=None, length=None, reverse=False, unroll=1):\n    # Ref: jax.lax.scan\n    if not callable(f):\n        raise TypeError(f\"`f` should be a callable. Received: f={f}\")\n    if not isinstance(unroll, bool):\n        if not isinstance(unroll, int) or unroll < 1:\n            raise ValueError(\n                \"`unroll` must be a positive integer or boolean. 
\"\n                f\"Received: unroll={unroll}\"\n            )\n    if xs is None and length is None:\n        raise ValueError(\"Got no `xs` to scan over and `length` not provided.\")\n\n    input_is_sequence = tree.is_nested(xs)\n    output_is_sequence = tree.is_nested(init)\n\n    def pack_input(x):\n        return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]\n\n    def pack_output(x):\n        return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]\n\n    if xs is None:\n        xs_flat = []\n        n = int(length)\n    else:\n        xs_flat = tree.flatten(xs)\n        xs_flat = [convert_to_tensor(elem) for elem in xs_flat]\n        n = (\n            int(length)\n            if length is not None\n            else (shape(xs_flat[0])[0] if xs_flat else 0)\n        )\n\n    init_flat = tree.flatten(init)\n    init_flat = [convert_to_tensor(i) for i in init_flat]\n    init = pack_output(init_flat)\n\n    dummy_y = []\n    for i in init_flat:\n        i_ov = get_ov_output(i)\n        zero = ov_opset.constant(0, i_ov.get_element_type()).output(0)\n        shape_node = ov_opset.shape_of(i_ov, Type.i32).output(0)\n        dummy_y.append(\n            OpenVINOKerasTensor(ov_opset.broadcast(zero, shape_node).output(0))\n        )\n\n    carry = init\n    ys = []\n    maybe_reversed = reversed if reverse else lambda x: x\n    for i in maybe_reversed(range(n)):\n        xs_slice = [x[i] for x in xs_flat]\n        packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None\n        carry, y = f(carry, packed_xs)\n        ys.append(y if y is not None else dummy_y)\n\n    def _stack(tensors):\n        elems = [get_ov_output(t) for t in tensors]\n        const_axis = ov_opset.constant(0, Type.i32).output(0)\n        elems = [ov_opset.unsqueeze(e, const_axis).output(0) for e in elems]\n        return OpenVINOKerasTensor(ov_opset.concat(elems, 0).output(0))\n\n    stacked_y = tree.map_structure(\n        lambda *y: _stack(list(y)), 
*maybe_reversed(ys)\n    )\n    return carry, stacked_y\n\n\ndef associative_scan(f, elems, reverse=False, axis=0):\n    # Ref: jax.lax.associative_scan\n    if not callable(f):\n        raise TypeError(f\"`f` should be a callable. Received: f={f}\")\n    elems_flat = tree.flatten(elems)\n    elems_flat = [convert_to_tensor(elem) for elem in elems_flat]\n\n    def _flip(x, axis):\n        x_ov = get_ov_output(x)\n        ndim = len(x_ov.get_partial_shape())\n        begin = [0] * ndim\n        end = [0] * ndim\n        strides = [1] * ndim\n        strides[axis] = -1\n        mask = [1] * ndim\n        result = ov_opset.strided_slice(\n            data=x_ov,\n            begin=begin,\n            end=end,\n            strides=strides,\n            begin_mask=mask,\n            end_mask=mask,\n        ).output(0)\n        return OpenVINOKerasTensor(result)\n\n    def _concat(tensors, axis):\n        elems = [get_ov_output(t) for t in tensors]\n        keras_types = [ov_to_keras_type(e.get_element_type()) for e in elems]\n        if keras_types:\n            target = OPENVINO_DTYPES[result_type(*keras_types)]\n            elems = [\n                ov_opset.convert(e, target).output(0)\n                if e.get_element_type() != target\n                else e\n                for e in elems\n            ]\n        return OpenVINOKerasTensor(ov_opset.concat(elems, axis).output(0))\n\n    def _unsqueeze(x, axis):\n        x_ov = get_ov_output(x)\n        const_axis = ov_opset.constant(axis, Type.i32).output(0)\n        return OpenVINOKerasTensor(\n            ov_opset.unsqueeze(x_ov, const_axis).output(0)\n        )\n\n    if reverse:\n        elems_flat = [_flip(elem, axis) for elem in elems_flat]\n\n    def _combine(a_flat, b_flat):\n        a = tree.pack_sequence_as(elems, a_flat)\n        b = tree.pack_sequence_as(elems, b_flat)\n        c = f(a, b)\n        return tree.flatten(c)\n\n    num_elems = int(elems_flat[0].shape[axis])\n    if not 
all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):\n        raise ValueError(\n            \"Array inputs to associative_scan must have the same \"\n            \"first dimension. (saw: {})\".format(\n                [elem.shape for elem in elems_flat]\n            )\n        )\n\n    def _interleave(a, b, axis):\n        n_a = a.shape[axis]\n        n_b = b.shape[axis]\n\n        a_common = slice_along_axis(a, 0, n_b, axis=axis)\n        a_exp = _unsqueeze(a_common, axis + 1)\n        b_exp = _unsqueeze(b, axis + 1)\n        interleaved = _concat([a_exp, b_exp], axis + 1)\n\n        interleaved_ov = get_ov_output(interleaved)\n        orig_shape = ov_opset.shape_of(interleaved_ov, Type.i32).output(0)\n        ndim = len(interleaved_ov.get_partial_shape())\n        pre = ov_opset.slice(\n            orig_shape,\n            ov_opset.constant([0], Type.i32),\n            ov_opset.constant([axis], Type.i32),\n            ov_opset.constant([1], Type.i32),\n        ).output(0)\n        merged_dim = ov_opset.constant([n_b * 2], Type.i32).output(0)\n        post = ov_opset.slice(\n            orig_shape,\n            ov_opset.constant([axis + 2], Type.i32),\n            ov_opset.constant([ndim], Type.i32),\n            ov_opset.constant([1], Type.i32),\n        ).output(0)\n        target_shape = ov_opset.concat([pre, merged_dim, post], 0).output(0)\n        interleaved = OpenVINOKerasTensor(\n            ov_opset.reshape(interleaved_ov, target_shape, False).output(0)\n        )\n\n        if n_a > n_b:\n            last = slice_along_axis(a, n_b, n_b + 1, axis=axis)\n            interleaved = _concat([interleaved, last], axis)\n\n        return interleaved\n\n    def _scan(elems):\n        num_elems = elems[0].shape[axis]\n        if num_elems < 2:\n            return elems\n\n        reduced_elems = _combine(\n            [slice_along_axis(e, 0, -1, step=2, axis=axis) for e in elems],\n            [slice_along_axis(e, 1, None, step=2, axis=axis) for e 
in elems],\n        )\n        odd_elems = _scan(reduced_elems)\n\n        if num_elems % 2 == 0:\n            even_elems = _combine(\n                [slice_along_axis(e, 0, -1, axis=axis) for e in odd_elems],\n                [\n                    slice_along_axis(e, 2, None, step=2, axis=axis)\n                    for e in elems\n                ],\n            )\n        else:\n            even_elems = _combine(\n                odd_elems,\n                [\n                    slice_along_axis(e, 2, None, step=2, axis=axis)\n                    for e in elems\n                ],\n            )\n        even_elems = [\n            _concat(\n                [slice_along_axis(elem, 0, 1, axis=axis), result],\n                axis,\n            )\n            for elem, result in zip(elems, even_elems)\n        ]\n        return [_interleave(e, o, axis) for e, o in zip(even_elems, odd_elems)]\n\n    scanned_elems = _scan(elems_flat)\n    if reverse:\n        scanned_elems = [_flip(elem, axis) for elem in scanned_elems]\n    return tree.pack_sequence_as(elems, scanned_elems)\n\n\ndef scatter(indices, values, shape):\n    indices = get_ov_output(indices)\n    values = get_ov_output(values)\n\n    # Create a zeros tensor of the target shape.\n    shape = get_ov_output(shape)\n    zero_const = ov_opset.constant(0, values.get_element_type())\n    zeros = ov_opset.broadcast(zero_const, shape).output(0)\n\n    return scatter_update(zeros, indices, values, \"add\")\n\n\ndef scatter_update(inputs, indices, updates, reduction=None):\n    inputs = get_ov_output(inputs)\n    indices = get_ov_output(indices)\n    updates = get_ov_output(updates)\n\n    inputs, updates = align_operand_types(inputs, updates, \"scatter_update\")\n\n    # Map Keras reduction to OpenVINO ScatterNDUpdate reduction.\n    # OpenVINO Opset 15 supports: \"none\", \"sum\", \"sub\", \"prod\", \"min\", \"max\".\n    if reduction is None:\n        ov_reduction = \"none\"\n    elif reduction == \"add\":\n   
     ov_reduction = \"sum\"\n    elif reduction == \"mul\":\n        ov_reduction = \"prod\"\n    elif reduction in (\"max\", \"min\"):\n        ov_reduction = reduction\n    else:\n        raise ValueError(f\"Unsupported reduction: {reduction}\")\n\n    result = ov_opset.scatter_nd_update(\n        inputs, indices, updates, reduction=ov_reduction\n    ).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef slice(inputs, start_indices, shape):\n    inputs = get_ov_output(inputs)\n    if isinstance(start_indices, (list, np.ndarray)):\n        start_indices = tuple(start_indices)\n    if isinstance(shape, (list, np.ndarray)):\n        shape = tuple(shape)\n    if not isinstance(start_indices, tuple):\n        raise ValueError(\n            \"`slice` operation requires a tuple for `start_indices` with the \"\n            f\"openvino backend. Received: start_indices={start_indices}\"\n        )\n    if not isinstance(shape, tuple):\n        raise ValueError(\n            \"`slice` operation requires a tuple for `shape` with the \"\n            f\"openvino backend. 
Received: shape={shape}\"\n        )\n\n    axes = []\n    start = []\n    stop = []\n\n    def prepare_slice_index(val):\n        val_type = val.get_element_type()\n        if not val_type.is_integral():\n            raise ValueError(\n                \"`slice` is not supported by OpenVINO backend \"\n                \"for `start_indices` or `shape` with non-integer types\"\n            )\n        if val_type != Type.i32:\n            val = ov_opset.convert(val, Type.i32).output(0)\n        if len(val.get_partial_shape()) == 0:\n            val = ov_opset.unsqueeze(\n                val, ov_opset.constant(0, Type.i32)\n            ).output(0)\n        return val\n\n    for idx, length in enumerate(shape):\n        if length is not None and length >= 0:\n            axes.append(idx)\n            start_val = prepare_slice_index(get_ov_output(start_indices[idx]))\n            stop_val = prepare_slice_index(\n                get_ov_output(start_indices[idx] + length)\n            )\n            start.append(start_val)\n            stop.append(stop_val)\n\n    if len(axes) == 0:\n        return inputs\n\n    step = [1] * len(start)\n    step = ov_opset.constant(step, Type.i32).output(0)\n    start = ov_opset.concat(start, axis=0).output(0)\n    stop = ov_opset.concat(stop, axis=0).output(0)\n    axes = ov_opset.constant(axes, Type.i32).output(0)\n    result = ov_opset.slice(inputs, start, stop, step, axes).output(0)\n\n    # Apply reshape to ensure output matches expected shape\n    # Convert None (dynamic) dimensions to -1 for OpenVINO compatibility\n    if all(dim is None or (isinstance(dim, int) and dim >= 0) for dim in shape):\n        reshape_pattern = [(-1 if dim is None else dim) for dim in shape]\n        target_shape = ov_opset.constant(reshape_pattern, Type.i32).output(0)\n        result = ov_opset.reshape(result, target_shape, False).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef slice_update(inputs, start_indices, updates):\n    inputs = 
get_ov_output(inputs)\n    updates_tensor = get_ov_output(updates)\n\n    if isinstance(start_indices, (list, np.ndarray)):\n        start_indices = tuple(start_indices)\n    if not isinstance(start_indices, tuple):\n        raise ValueError(\n            \"`slice_update` is not supported by openvino backend\"\n            \" for `start_indices` of type {}\".format(type(start_indices))\n        )\n\n    zero_scalar = ov_opset.constant(0, Type.i32)\n    one_scalar = ov_opset.constant(1, Type.i32)\n    zero_tensor = ov_opset.constant([0], Type.i32)\n    one_tensor = ov_opset.constant([1], Type.i32)\n\n    processed_start_indices = []\n    for idx in start_indices:\n        val = get_ov_output(idx)\n        if not val.get_element_type().is_integral():\n            raise ValueError(\"`slice_update` requires integral start_indices\")\n        if val.get_element_type() != Type.i32:\n            val = ov_opset.convert(val, Type.i32).output(0)\n        if val.get_partial_shape().rank.get_length() == 0:\n            val = ov_opset.unsqueeze(val, zero_scalar).output(0)\n        processed_start_indices.append(val)\n\n    updates_shape = ov_opset.shape_of(updates_tensor, Type.i32).output(0)\n    rank = updates_tensor.get_partial_shape().rank.get_length()\n    if rank == 0:\n        # Handle scalar update\n        start_tensor = ov_opset.concat(processed_start_indices, axis=0).output(\n            0\n        )\n        # For scatter_nd_update,\n        # indices should be of shape [num_updates, rank_of_inputs]\n        # and updates should be of shape [num_updates]. 
Here num_updates is 1.\n        absolute_indices = ov_opset.unsqueeze(start_tensor, zero_scalar).output(\n            0\n        )\n        updates_flat = ov_opset.unsqueeze(updates_tensor, zero_scalar).output(0)\n        result = ov_opset.scatter_nd_update(\n            inputs, absolute_indices, updates_flat\n        ).output(0)\n        return OpenVINOKerasTensor(result)\n\n    # Compute the total number of elements in the updates tensor.\n    # Example:\n    # if updates.shape = [2, 3], total_elements = 6.\n    total_elements = ov_opset.reduce_prod(\n        updates_shape, zero_tensor, keep_dims=False\n    ).output(0)\n\n    # Generate a flat range [0, 1, ..., total_elements-1].\n    # This will be used to enumerate all positions in the updates tensor.\n    flat_indices = ov_opset.range(\n        zero_scalar, total_elements, one_scalar, output_type=Type.i32\n    ).output(0)\n\n    dim_sizes = []\n    strides = []\n\n    # For each dimension, compute its size and the stride.\n    # (number of elements to skip to move to the next index in this dimension).\n    # Example:\n    # for shape [2, 3], strides = [3, 1].\n    for dim in range(rank):\n        dim_size = ov_opset.gather(\n            updates_shape, ov_opset.constant([dim], Type.i32), zero_scalar\n        ).output(0)\n        dim_size_scalar = ov_opset.squeeze(dim_size, zero_tensor).output(0)\n        dim_sizes.append(dim_size_scalar)\n\n        # Strides to convert a flat index into a multi-dimensional index.\n        # This allows us to map each element in the flattened updates tensor\n        # to its correct N-dimensional position, so we can compute the absolute\n        # index in the input tensor for the scatter update.\n        # Stride for a dimension is the product of all dimensions after it.\n        # For the last dimension, stride is 1.\n        # Example:\n        # For a 3D tensor with shape [2, 3, 4]:\n        #   - stride for dim=0 (first axis) is 3*4=12\n        #     (to move to the next 
\"block\" along axis 0)\n        #   - stride for dim=1 is 4 (to move to the next row along axis 1)\n        #   - stride for dim=2 is 1 (to move to the next element along axis 2)\n        # This is equivalent to how numpy flattens multi-dimensional arrays.\n        if dim < rank - 1:\n            remaining_dims = ov_opset.slice(\n                updates_shape,\n                ov_opset.constant([dim + 1], Type.i32),\n                ov_opset.constant([rank], Type.i32),\n                one_tensor,\n                zero_tensor,\n            ).output(0)\n            stride = ov_opset.reduce_prod(\n                remaining_dims, zero_tensor, keep_dims=False\n            ).output(0)\n        else:\n            stride = one_scalar\n        strides.append(stride)\n\n    coord_tensors = []\n    # For each dimension, compute the coordinate for every flat index.\n    # Example:\n    # for shape [2, 3], flat index 4 -> coordinates [1, 1] (row 1, col 1).\n    for dim in range(rank):\n        coords = ov_opset.mod(\n            ov_opset.divide(flat_indices, strides[dim]).output(0),\n            dim_sizes[dim],\n        ).output(0)\n        coord_tensors.append(coords)\n\n    coord_tensors_unsqueezed = []\n    for coord in coord_tensors:\n        # Unsqueeze to make each coordinate a column vector for concatenation.\n        coord_unsqueezed = ov_opset.unsqueeze(coord, one_tensor).output(0)\n        coord_tensors_unsqueezed.append(coord_unsqueezed)\n\n    # Concatenate all coordinate columns to form [total_elements, rank] matrix.\n    # Each row is a multi-dimensional index into the updates tensor.\n    # Example:\n    # for shape [2, 3], row 4 = [1, 1].\n    indices_matrix = ov_opset.concat(coord_tensors_unsqueezed, axis=1).output(0)\n\n    # Broadcast start indices to match the number of updates.\n    # Example:\n    # start_indices = (2, 3), indices_matrix = [[0,0],[0,1],...],\n    # start_broadcast = [[2,3],[2,3],...]\n    start_tensor = 
ov_opset.concat(processed_start_indices, axis=0).output(0)\n    start_reshaped = ov_opset.reshape(\n        start_tensor, ov_opset.constant([1, rank], Type.i32), special_zero=False\n    ).output(0)\n\n    broadcast_shape = ov_opset.concat(\n        [\n            ov_opset.unsqueeze(total_elements, zero_tensor).output(0),\n            one_tensor,\n        ],\n        axis=0,\n    ).output(0)\n\n    start_broadcast = ov_opset.tile(start_reshaped, broadcast_shape).output(0)\n\n    # Add the broadcasted start indices to the relative indices\n    # to get absolute indices in the input tensor.\n    # Example:\n    # if start=(2,3), update index [1,1] -> absolute index [3,4].\n    absolute_indices = ov_opset.add(indices_matrix, start_broadcast).output(0)\n\n    # Flatten the updates tensor to match the flat indices.\n    updates_flat = ov_opset.reshape(\n        updates_tensor,\n        ov_opset.unsqueeze(total_elements, zero_tensor).output(0),\n        special_zero=False,\n    ).output(0)\n\n    # Perform the scatter update: for each absolute index,\n    # set the corresponding value from updates_flat.\n    result = ov_opset.scatter_nd_update(\n        inputs, absolute_indices, updates_flat\n    ).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef switch(index, branches, *operands):\n    if len(branches) == 1:\n        return branches[0](*operands)\n\n    n = len(branches)\n    index_ov = get_ov_output(convert_to_tensor(index, \"int32\"))\n    index_ov = ov_opset.clamp(index_ov, 0, n - 1).output(0)\n    operands_ov = [get_ov_output(op_val) for op_val in operands]\n\n    def _trace_branch(branch_fn):\n        params, wrapped = [], []\n        for ov_out in operands_ov:\n            p = ov_opset.parameter(\n                ov_out.get_partial_shape(), ov_out.get_element_type()\n            )\n            params.append(p)\n            wrapped.append(OpenVINOKerasTensor(p.output(0)))\n        raw = branch_fn(*wrapped)\n        if raw is None:\n            flat = []\n  
      elif isinstance(raw, (list, tuple)):\n            flat = [get_ov_output(o) for o in raw]\n        else:\n            flat = [get_ov_output(raw)]\n        return params, Model(flat, params), raw\n\n    def _build(branch_idx):\n        inner_outputs = None\n        then_params, then_body, then_raw = _trace_branch(branches[branch_idx])\n        if branch_idx == n - 2:\n            else_params, else_body, _ = _trace_branch(branches[branch_idx + 1])\n        else:\n            inner_outputs, _ = _build(branch_idx + 1)\n            else_params, pt_results = [], []\n            for inner_out in inner_outputs:\n                ep = ov_opset.parameter(\n                    inner_out.get_partial_shape(),\n                    inner_out.get_element_type(),\n                )\n                else_params.append(ep)\n                pt_results.append(ep.output(0))\n            else_body = Model(pt_results, else_params)\n\n        cond = ov_opset.equal(\n            index_ov,\n            ov_opset.constant(branch_idx, Type.i32).output(0),\n        ).output(0)\n        if_node = ov_opset.if_op(cond)\n        if_node.set_then_body(then_body)\n        if_node.set_else_body(else_body)\n\n        if inner_outputs is None:\n            for ov_inp, tp, ep in zip(operands_ov, then_params, else_params):\n                if_node.set_input(ov_inp, tp, ep)\n        else:\n            for ov_inp, tp in zip(operands_ov, then_params):\n                if_node.set_input(ov_inp, tp, None)\n            for inner_out, ep in zip(inner_outputs, else_params):\n                if_node.set_input(inner_out, None, ep)\n\n        outputs = [\n            if_node.set_output(then_body.results[i], else_body.results[i])\n            for i in range(len(then_body.results))\n        ]\n        return outputs, then_raw\n\n    final_outputs, template_raw = _build(0)\n    wrapped = [OpenVINOKerasTensor(o) for o in final_outputs]\n\n    if template_raw is None:\n        return None\n    elif 
isinstance(template_raw, tuple):\n        return tuple(wrapped)\n    elif isinstance(template_raw, list):\n        return list(wrapped)\n    else:\n        return wrapped[0]\n\n\ndef while_loop(\n    cond,\n    body,\n    loop_vars,\n    maximum_iterations=None,\n):\n    def flatten_structure(data):\n        if isinstance(data, dict):\n            return [v for k in sorted(data) for v in flatten_structure(data[k])]\n        elif isinstance(data, (tuple, list)):\n            return [v for item in data for v in flatten_structure(item)]\n        else:\n            return [data]\n\n    def pack_structure(template, flat):\n        if isinstance(template, dict):\n            keys = sorted(template)\n            packed = {}\n            for k in keys:\n                value, flat = pack_structure(template[k], flat)\n                packed[k] = value\n            return packed, flat\n        elif isinstance(template, (tuple, list)):\n            packed = []\n            for item in template:\n                value, flat = pack_structure(item, flat)\n                packed.append(value)\n            return (\n                tuple(packed) if isinstance(template, tuple) else packed\n            ), flat\n        else:\n            return flat[0], flat[1:]\n\n    is_scalar_input = _is_scalar(loop_vars)\n\n    if is_scalar_input:\n        loop_vars = (loop_vars,)\n    elif isinstance(loop_vars, (list, np.ndarray)):\n        loop_vars = tuple(loop_vars)\n    else:\n        if not isinstance(loop_vars, (tuple, dict)):\n            raise ValueError(\n                \"Expected tuple or dict for `loop_vars`, \"\n                f\"Received: {type(loop_vars)}\"\n            )\n\n    flat_loop_vars = flatten_structure(loop_vars)\n    loop_vars_ov = [get_ov_output(var) for var in flat_loop_vars]\n\n    maximum_iterations = (\n        ov_opset.constant(-1, Type.i32).output(0)\n        if maximum_iterations is None\n        else get_ov_output(maximum_iterations)\n    )\n\n    trip_count 
= maximum_iterations\n    execution_condition = ov_opset.constant(True, Type.boolean).output(0)\n    loop = ov_opset.loop(trip_count, execution_condition)\n\n    shapes = [var.get_partial_shape() for var in loop_vars_ov]\n    types = [var.get_element_type() for var in loop_vars_ov]\n    params = [\n        ov_opset.parameter(shape, dtype) for shape, dtype in zip(shapes, types)\n    ]\n    param_tensors = [OpenVINOKerasTensor(p.output(0)) for p in params]\n\n    packed_args, _ = pack_structure(loop_vars, param_tensors)\n    if isinstance(packed_args, dict):\n        body_out = body(packed_args)\n    else:\n        body_out = body(*packed_args)\n\n    if not isinstance(body_out, (list, tuple, dict)):\n        body_out = (body_out,)\n\n    flat_body_out = flatten_structure(body_out)\n    if isinstance(packed_args, dict):\n        cond_output = get_ov_output(cond(body_out))\n    else:\n        cond_output = get_ov_output(cond(*body_out))\n\n    if len(cond_output.get_partial_shape()) != 0:\n        raise ValueError(\n            \"`cond` function must return a scalar boolean value, \"\n            \"but got shape {}\".format(cond_output.get_partial_shape())\n        )\n\n    for p, out in zip(params, flat_body_out):\n        out_shape = get_ov_output(out).get_partial_shape()\n        p.set_partial_shape(out_shape)\n\n    results = [cond_output] + [get_ov_output(x) for x in flat_body_out]\n    body_func = Model(results=results, parameters=params)\n    loop.set_function(body_func)\n    loop.set_special_body_ports([-1, 0])\n\n    for param, init_val, next_val in zip(params, loop_vars_ov, flat_body_out):\n        loop.set_merged_input(param, init_val, get_ov_output(next_val))\n\n    outputs_flat = [\n        OpenVINOKerasTensor(loop.get_iter_value(get_ov_output(val)))\n        for val in flat_body_out\n    ]\n    final_output, _ = pack_structure(loop_vars, outputs_flat)\n\n    if is_scalar_input:\n        if isinstance(final_output, tuple):\n            return 
final_output[0]\n        else:\n            return final_output\n    else:\n        return final_output\n\n\ndef fori_loop(lower, upper, body_fun, init_val):\n    return while_loop(\n        lambda i, val: i < upper,\n        lambda i, val: (i + 1, body_fun(i, val)),\n        (lower, init_val),\n    )[1]\n\n\ndef stop_gradient(variable):\n    return variable\n\n\ndef unstack(x, num=None, axis=0):\n    x_ov = get_ov_output(x)\n    axis_ov = get_ov_output(axis)\n\n    if num is None:\n        shape = x_ov.get_partial_shape()\n        num = shape[axis].get_length()\n\n    split_ov = ov_opset.split(x_ov, axis_ov, num)\n\n    return [\n        OpenVINOKerasTensor(ov_opset.squeeze(out, axis_ov).output(0))\n        for out in split_ov.outputs()\n    ]\n\n\ndef random_seed_dtype():\n    # OpenVINO arithmetic promotes uint32 * int32 → int32 (Python ints are\n    # i32 in get_ov_output), so the seed tensor from SeedGenerator.next()\n    # ends up as int32. Returning int32 keeps the declared dtype consistent\n    # with what the backend actually produces.\n    return \"int32\"\n\n\nclass custom_gradient:\n    \"\"\"Decorator for custom gradients.\n\n    Args:\n        fun: Forward pass function.\n    \"\"\"\n\n    def __init__(self, fun):\n        warnings.warn(\n            \"`custom_gradient` for the openvino backend \"\n            \"acts as a pass-through to support the forward pass. \"\n            \"No gradient computation or modification takes place.\"\n        )\n        self.fun = fun\n\n    def __call__(self, *args, **kwargs):\n        outputs, _ = self.fun(*args, **kwargs)\n        return outputs\n\n\ndef remat(f):\n    warnings.warn(\n        \"Rematerialization memory optimization is not supported by the \"\n        \"OpenVINO backend. Please switch to JAX, TensorFlow, or PyTorch to \"\n        \"utilize this feature.\"\n    )\n    return f\n"
  },
  {
    "path": "keras/src/backend/openvino/excluded_concrete_tests.txt",
    "content": "AdadeltaTest::test_correctness_with_golden\nAdadeltaTest::test_weight_decay\nAdafactorTest::test_weight_decay\nAdagradTest::test_correctness_with_golden\nAdagradTest::test_weight_decay\nAdamaxTest::test_correctness_with_golden\nAdamaxTest::test_weight_decay\nAdamTest::test_correctness_with_golden\nAdamTest::test_weight_decay\nAdamWTest::test_correctness_with_golden\nAdamWTest::test_weight_decay\nAdditiveAttentionTest::test_attention_basics\nAttentionTest::test_attention_basics\nAttentionTest::test_attention_calculate_scores_with_scale\nAUCTest::test_config\nAUCTest::test_weighted_pr_interpolation\nAUCTest::test_weighted_pr_interpolation_negative_weights\nAudioDatasetFromDirectoryTest::test_audio_dataset_from_directory_manual_labels\nAveragePoolingBasicTest::test_average_pooling1d0\nAveragePoolingBasicTest::test_average_pooling1d1\nAveragePoolingBasicTest::test_average_pooling1d2\nAveragePoolingBasicTest::test_average_pooling2d0\nAveragePoolingBasicTest::test_average_pooling2d1\nAveragePoolingBasicTest::test_average_pooling2d2\nAveragePoolingBasicTest::test_average_pooling2d3\nAveragePoolingBasicTest::test_average_pooling2d4\nAveragePoolingBasicTest::test_average_pooling2d5\nAveragePoolingBasicTest::test_average_pooling2d6\nAveragePoolingBasicTest::test_average_pooling3d0\nAveragePoolingBasicTest::test_average_pooling3d1\nAveragePoolingBasicTest::test_average_pooling3d2\nCircleTest::test_correctness\nCircleTest::test_correctness_weighted\nCircleTest::test_dtype_arg\nCircleTest::test_mean_with_sample_weight_reduction\nCircleTest::test_no_reduction\nCircleTest::test_sum_reduction\nComputeScaleZeroTest::test_dequantize_with_sz_map_logic\nComputeScaleZeroTest::test_dtype_and_finiteness_sym_true\nComputeScaleZeroTest::test_per_tensor_shapes_and_basic_invariants_bits2_sym\nComputeScaleZeroTest::test_per_tensor_shapes_and_basic_invariants_bits4_sym\nComputeScaleZeroTest::test_per_tensor_shapes_and_basic_invariants_bits8_sym\nComputeScaleZeroTest::test_per_te
nsor_symmetric_on_constant_input_uses_safe_range\nComputeScaleZeroTest::test_quantize_with_sz_map_logic\nConvLSTM1DTest::test_correctness\nConvLSTM2DTest::test_correctness\nConvLSTMTest::test_correctness\nCoreOpsCallsTests::test_unstack_basic_functionality\nCTCTest::test_correctness\nCTCTest::test_dtype_arg\nDenseTest::test_dense_quantize_config_int4\nDenseTest::test_dense_quantize_config_int8\nDenseTest::test_dense_quantize_config_int8_weight_only\nDenseTest::test_int4_block_size_serialization_grouped_block_128\nDenseTest::test_int4_block_size_serialization_grouped_block_64\nDenseTest::test_int4_block_size_serialization_per_channel_none\nDenseTest::test_int4_block_size_with_lora_grouped_block_64\nDenseTest::test_int4_block_size_with_lora_per_channel\nDenseTest::test_int4_quantization_block_size_grouped_block_128\nDenseTest::test_int4_quantization_block_size_grouped_block_64\nDenseTest::test_int4_quantization_block_size_per_channel_neg1\nDenseTest::test_int4_quantization_block_size_per_channel_none\nDenseTest::test_int4_subchannel_g_idx_serialization\nDenseTest::test_quantize_float8_inference\nDenseTest::test_quantize_int_int4\nDenseTest::test_quantize_int_int8\nDTypePolicyMapTest::test_basic_usage\nDtypesTest::test_empty_lub_in_least_upper_bound\nDtypesTest::test_result_type_with_float64_bfloat16\nDtypesTest::test_result_type_with_float64_float16\nDtypesTest::test_result_type_with_float64_float32\nDtypesTest::test_result_type_with_float64_float64\nDtypesTest::test_result_type_with_float64_int16\nDtypesTest::test_result_type_with_float64_int32\nDtypesTest::test_result_type_with_float64_int64\nDtypesTest::test_result_type_with_float64_int8\nDtypesTest::test_result_type_with_float64_uint16\nDtypesTest::test_result_type_with_float64_uint8\nDtypesTest::test_result_type_with_int64_int16\nDtypesTest::test_result_type_with_int64_int32\nDtypesTest::test_result_type_with_int64_int64\nDtypesTest::test_result_type_with_int64_int8\nDtypesTest::test_result_type_with_int64_uint16
\nDtypesTest::test_result_type_with_int64_uint32\nDtypesTest::test_result_type_with_int64_uint8\nEarlyStoppingTest::test_early_stopping\nEarlyStoppingTest::test_early_stopping_reuse\nEinsumDenseTest::test_einsum_dense_quantize_int4\nEinsumDenseTest::test_einsum_dense_quantize_int4_weight_only\nEinsumDenseTest::test_einsum_dense_quantize_int8\nEinsumDenseTest::test_einsum_dense_quantize_int8_weight_only\nEinsumDenseTest::test_int4_block_size_serialization_grouped_block_128\nEinsumDenseTest::test_int4_block_size_serialization_grouped_block_64\nEinsumDenseTest::test_int4_block_size_serialization_per_channel_none\nEinsumDenseTest::test_int4_block_size_various_equations_ab_bcd_acd_grouped\nEinsumDenseTest::test_int4_block_size_various_equations_ab_bcd_acd_pc\nEinsumDenseTest::test_int4_block_size_various_equations_btd_df_btf_grouped\nEinsumDenseTest::test_int4_block_size_various_equations_btd_df_btf_pc\nEinsumDenseTest::test_int4_block_size_with_lora_grouped_block_64\nEinsumDenseTest::test_int4_block_size_with_lora_per_channel\nEinsumDenseTest::test_int4_grouped_multi_reduced_axes_attn_output_grouped\nEinsumDenseTest::test_int4_grouped_multi_reduced_axes_attn_output_pc\nEinsumDenseTest::test_int4_grouped_multi_reduced_axes_mha_value_grouped\nEinsumDenseTest::test_int4_grouped_multi_reduced_axes_mha_value_pc\nEinsumDenseTest::test_int4_multi_reduced_axes_serialization\nEinsumDenseTest::test_int4_multi_reduced_vs_single_reduced\nEinsumDenseTest::test_int4_quantization_block_size_grouped_block_128\nEinsumDenseTest::test_int4_quantization_block_size_grouped_block_64\nEinsumDenseTest::test_int4_quantization_block_size_per_channel_neg1\nEinsumDenseTest::test_int4_quantization_block_size_per_channel_none\nEinsumDenseTest::test_int4_subchannel_g_idx_serialization\nEinsumDenseTest::test_quantize_float8_inference\nEinsumDenseTest::test_quantize_int_int4\nEinsumDenseTest::test_quantize_int_int8\nEinsumDenseTest::test_quantize_with_specific_equations_int4_btd,df->btf\nEinsumDenseTes
t::test_quantize_with_specific_equations_int4_btd,ndh->btnh\nEinsumDenseTest::test_quantize_with_specific_equations_int4_btnh,nhd->btd\nEinsumDenseTest::test_quantize_with_specific_equations_int8_btd,df->btf\nEinsumDenseTest::test_quantize_with_specific_equations_int8_btd,ndh->btnh\nEinsumDenseTest::test_quantize_with_specific_equations_int8_btnh,nhd->btd\nFalseNegativesTest::test_unweighted\nFalseNegativesTest::test_weighted\nFalsePositivesTest::test_unweighted\nFalsePositivesTest::test_weighted\nFilterSafePathsTest::test_invalid_path_warning\nFilterSafePathsTest::test_member_within_base_dir\nFilterSafePathsTest::test_symbolic_link_in_base_dir\nFilterSafePathsTest::test_symlink_within_base_dir\nGetFileTest::test_join_simple\nGPTQQuantizerTest::test_quantize_clipping_behavior_extremes\nGPTQQuantizerTest::test_zero_scale_guard_no_nans_for_finite_inputs\nGroupedQuantizationParametersTest::test_grouped_weight_shapes_divisible\nGroupedQuantizationParametersTest::test_grouped_weight_shapes_non_divisible\nGroupedQueryAttentionTest::test_basics\nHistogramTest::test_histogram_predict_jit_compile_false\nHistogramTest::test_histogram_predict_jit_compile_true\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_binary_grain\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_color_modes_grain\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_crop_to_aspect_ratio_grain\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_follow_links_grain\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_manual_labels_grain\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_manual_labels_tf\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_multiclass_grain\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_no_labels_grain\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_not_batched_grain\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_pad_to_a
spect_ratio_grain\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_shuffle_grain\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_shuffle_tf\nImageDatasetFromDirectoryTest::test_image_dataset_from_directory_validation_split_grain\nImageDatasetFromDirectoryTest::test_sample_count_grain\nImageOpsBehaviorTests::test_affine_transform\nImageOpsBehaviorTests::test_elastic_transform\nImageOpsBehaviorTests::test_gaussian_blur\nImageOpsBehaviorTests::test_map_coordinates\nImageOpsBehaviorTests::test_perspective_transform\nImageOpsBehaviorTests::test_resize\nImageOpsCorrectnessTest::test_affine_transform\nImageOpsCorrectnessTest::test_crop_images\nImageOpsCorrectnessTest::test_elastic_transform\nImageOpsCorrectnessTest::test_extract_patches\nImageOpsCorrectnessTest::test_gaussian_blur\nImageOpsCorrectnessTest::test_map_coordinates\nImageOpsCorrectnessTest::test_pad_images\nImageOpsCorrectnessTest::test_perspective_transform\nImageOpsCorrectnessTest::test_resize\nImageOpsCorrectnessTest::test_scale_and_translate\nImageOpsDtypeTest::test_affine_transform\nImageOpsDtypeTest::test_elastic_transform\nImageOpsDtypeTest::test_gaussian_blur\nImageOpsDtypeTest::test_map_coordinates\nImageOpsDtypeTest::test_perspective_transform\nImageOpsDtypeTest::test_resize\nImageOpsDtypeTest::test_scale_and_translate\nIsLinkInDirTest::test_is_link_in_dir_with_absolute_paths\nIsLinkInDirTest::test_is_link_in_dir_with_relative_paths\nIsPathInDirTest::test_is_path_in_dir_with_absolute_paths\nLambTest::test_correctness_with_golden\nLambTest::test_weight_decay\nLayerTest::test_call_context_args_with_custom_layers\nLayerTest::test_call_context_args_with_func_seq_models_as_layers\nLayerTest::test_complex_dtype_support\nLayerTest::test_context_arg_propagation_without_declaration\nLayerTest::test_context_args_with_triple_nesting_and_priority\nLayerTest::test_end_to_end_masking\nLayerTest::test_quantized_layer_with_remat\nLayerTest::test_stateless_call\nLinalgOpsCorrectnessT
est::test_cholesky\nLinalgOpsCorrectnessTest::test_eig\nLinalgOpsCorrectnessTest::test_lstsq\nLinalgOpsCorrectnessTest::test_lu_factor\nLinalgOpsCorrectnessTest::test_norm_2_-2\nLinalgOpsCorrectnessTest::test_norm_2_2\nLinalgOpsCorrectnessTest::test_norm_2_nuc\nLinalgOpsCorrectnessTest::test_qr\nLinalgOpsCorrectnessTest::test_solve_triangular\nLinalgOpsCorrectnessTest::test_svd\nLinalgOpsDynamicShapeTest::test_qr\nLionTest::test_correctness_with_golden\nLionTest::test_weight_decay\nLoadModelTests::test_basic_load_bfloat16\nLoadWeightsTests::test_load_weights_h5_bfloat16_bfloat16\nLoadWeightsTests::test_load_weights_h5_bfloat16_float16\nLoadWeightsTests::test_load_weights_h5_bfloat16_float32\nLoadWeightsTests::test_load_weights_h5_bfloat16_float64\nLoadWeightsTests::test_load_weights_keras_bfloat16_bfloat16\nLoadWeightsTests::test_load_weights_keras_bfloat16_float16\nLoadWeightsTests::test_load_weights_keras_bfloat16_float32\nLoadWeightsTests::test_load_weights_keras_bfloat16_float64\nLoadWeightsTests::test_load_weights_keras_float16_bfloat16\nLoadWeightsTests::test_load_weights_keras_float32_bfloat16\nLoadWeightsTests::test_load_weights_keras_float64_bfloat16\nLoadWeightsTests::test_load_weights_weights.h5_bfloat16_bfloat16\nLoadWeightsTests::test_load_weights_weights.h5_bfloat16_float16\nLoadWeightsTests::test_load_weights_weights.h5_bfloat16_float32\nLoadWeightsTests::test_load_weights_weights.h5_bfloat16_float64\nLoadWeightsTests::test_load_weights_weights.h5_float16_bfloat16\nLoadWeightsTests::test_load_weights_weights.h5_float32_bfloat16\nLoadWeightsTests::test_load_weights_weights.h5_float64_bfloat16\nLossScaleOptimizerTest::test_apply_with_no_vars\nLossScaleOptimizerTest::test_downscaling_stateful\nLossScaleOptimizerTest::test_downscaling_stateless\nLossScaleOptimizerTest::test_finite_step_stateful\nLossScaleOptimizerTest::test_finite_step_stateless\nLossScaleOptimizerTest::test_finite_step_with_inner_loss_scale_stateful\nLossScaleOptimizerTest::test_finite_s
tep_with_inner_loss_scale_stateless\nLossScaleOptimizerTest::test_finite_step_with_overwrite_stateful\nLossScaleOptimizerTest::test_finite_step_with_overwrite_stateless\nLossScaleOptimizerTest::test_infinite_step_stateful\nLossScaleOptimizerTest::test_infinite_step_stateless\nLossScaleOptimizerTest::test_iterations_update_stateful\nLossScaleOptimizerTest::test_iterations_update_stateless\nLossScaleOptimizerTest::test_upscaling_stateful\nLossScaleOptimizerTest::test_upscaling_stateless\nMathOpsCorrectnessTest::test_erfinv_operation_basic\nMathOpsCorrectnessTest::test_erfinv_operation_dtype\nMathOpsCorrectnessTest::test_logdet\nMaxPoolingBasicTest::test_max_pooling1d0\nMaxPoolingBasicTest::test_max_pooling1d1\nMaxPoolingBasicTest::test_max_pooling1d2\nMaxPoolingBasicTest::test_max_pooling2d0\nMaxPoolingBasicTest::test_max_pooling2d1\nMaxPoolingBasicTest::test_max_pooling2d2\nMaxPoolingBasicTest::test_max_pooling3d0\nMaxPoolingBasicTest::test_max_pooling3d1\nMaxPoolingBasicTest::test_max_pooling3d2\nMeanIoUTest::test_big_chunk\nMeanTest::test_weighted_dynamic_shapes\nMetricTest::test_stateless_update_state\nMetricWrapperTest::test_weighted_dynamic_shape\nMultiAUCTest::test_label_weights\nMultiAUCTest::test_manual_thresholds\nMultiAUCTest::test_pr_interpolation\nMultiAUCTest::test_pr_interpolation_unweighted\nMultiAUCTest::test_reset_state\nMultiAUCTest::test_unweighted\nMultiAUCTest::test_unweighted_all_correct\nMultiAUCTest::test_unweighted_from_logits\nMultiAUCTest::test_weighted_roc_interpolation\nMultiHeadAttentionTest::test_attention_axes_negative_indexing\nMultiHeadAttentionTest::test_basics\nMultiHeadAttentionTest::test_high_dim_attention_4d_inputs_1freebatch_mask2\nMultiHeadAttentionTest::test_high_dim_attention_4d_inputs_1freebatch_mask4\nMultiHeadAttentionTest::test_high_dim_attention_5d_inputs_2d_attention\nMultiHeadAttentionTest::test_multi_head_attention_output_shape_as_int\nMultiHeadAttentionTest::test_multi_head_attention_output_shape_as_tuple\nMultiHead
AttentionTest::test_quantize_int8\nNadamTest::test_correctness_with_golden\nNadamTest::test_weight_decay\nNNOpsBehaviorTest::test_invalid_strategy_ctc_decode\nNNOpsCorrectnessTest::test_ctc_decode\nNNOpsCorrectnessTest::test_glu\nNNOpsCorrectnessTest::test_polar_corectness\nNNOpsCorrectnessTest::test_sparsemax\nNNOpsDtypeTest::test_ctc_decode\nNNOpsDtypeTest::test_glu_\nNNOpsDtypeTest::test_polar_\nNNOpsDynamicShapeTest::test_glu\nNumpyDtypeTest::test_view\nNumpyOneInputOpsCorrectnessTest::test_conj\nNumpyOneInputOpsCorrectnessTest::test_imag\nNumpyOneInputOpsCorrectnessTest::test_isreal\nNumpyOneInputOpsCorrectnessTest::test_real\nNumpyOneInputOpsCorrectnessTest::test_view\nNumpyOneInputOpsDynamicShapeTest::test_view\nNumpyOneInputOpsStaticShapeTest::test_view\nOptimizerTest::test_constraints_are_applied\nOptimizerTest::test_ema\nOptimizerTest::test_gradient_accumulation\nOptimizerTest::test_gradient_accumulation_with_weigth_decay0\nOptimizerTest::test_gradient_accumulation_with_weigth_decay1\nOptimizerTest::test_gradient_accumulation_with_weigth_decay10\nOptimizerTest::test_gradient_accumulation_with_weigth_decay2\nOptimizerTest::test_gradient_accumulation_with_weigth_decay3\nOptimizerTest::test_gradient_accumulation_with_weigth_decay4\nOptimizerTest::test_gradient_accumulation_with_weigth_decay5\nOptimizerTest::test_gradient_accumulation_with_weigth_decay6\nOptimizerTest::test_gradient_accumulation_with_weigth_decay7\nOptimizerTest::test_gradient_accumulation_with_weigth_decay8\nOptimizerTest::test_gradient_accumulation_with_weigth_decay9\nOptimizerTest::test_overwrite_with_gradient_with_gradient_accumulation\nQrOpTest\nQuantizersTest::test_abs_max_quantizer\nQuantizersTest::test_compute_float8_amax_history\nQuantizersTest::test_compute_float8_scale\nQuantizersTest::test_grouped_quantize_with_padding\nQuantizersTest::test_grouped_vs_perchannel_accuracy\nRandomBehaviorTest::test_beta_tf_data_compatibility\nRandomCorrectnessTest::test_truncated_normal0\nRandomCorre
ctnessTest::test_truncated_normal1\nRandomCorrectnessTest::test_truncated_normal2\nRandomCorrectnessTest::test_truncated_normal3\nRandomCorrectnessTest::test_truncated_normal4\nRandomCorrectnessTest::test_truncated_normal5\nRandomCorrectnessTest::test_uniform0\nRandomCorrectnessTest::test_uniform1\nRandomCorrectnessTest::test_uniform2\nRandomCorrectnessTest::test_uniform3\nRandomCorrectnessTest::test_uniform4\nRandomDTypeTest::test_normal_bfloat16\nRandomDTypeTest::test_truncated_normal_bfloat16\nRandomDTypeTest::test_uniform_bfloat16\nReshapeTest::test_reshape_with_dynamic_batch_size_and_minus_one\nRMSpropTest::test_correctness_with_golden\nRMSpropTest::test_weight_decay\nSaveModelTests::test_objects_to_skip\nSavingAPITest::test_normalization_kpl\nSavingBattleTest::test_bidirectional_lstm_saving\nSavingTest::test_basics\nSavingTest::test_load_weights_only_with_unbuilt_model\nSavingTest::test_partial_load\nScheduleFreeAdamWTest::test_multiple_steps\nScheduleFreeAdamWTest::test_warmup\nScheduleFreeAdamWTest::test_weight_decay\nSerializationLibTest::test_custom_fn\nSerializationLibTest::test_custom_layer\nSGDTest::test_correctness_with_golden\nSGDTest::test_weight_decay\nSparseCategoricalCrossentropyTest::test_all_correct_unweighted\nSparseCategoricalCrossentropyTest::test_binary_segmentation\nSparseCategoricalCrossentropyTest::test_dtype_arg\nSparseCategoricalCrossentropyTest::test_ignore_class\nSparseCategoricalCrossentropyTest::test_multi_class_segmentation\nSparseCategoricalCrossentropyTest::test_no_reduction\nSparseCategoricalCrossentropyTest::test_sample_weighted\nSparseCategoricalCrossentropyTest::test_scalar_weighted\nSparseCategoricalCrossentropyTest::test_unweighted\nSpectralNormalizationTest::test_apply_layer\nStackedRNNTest::test_return_state_stacked_lstm_cell\nTestNumericalUtils::test_build_pos_neg_masks\nTestTrainer::test_predict_dropout\nTestTrainer::test_predict_flow_eager\nTestTrainer::test_predict_flow_graph_fn\nTestTrainer::test_predict_flow_jit\nTe
stTrainer::test_predict_flow_struct_eager\nTestTrainer::test_predict_flow_struct_graph_fn\nTestTrainer::test_predict_flow_struct_jit\nTestTrainer::test_predict_preserve_order_1_eager\nTestTrainer::test_predict_preserve_order_1_jit\nTestTrainer::test_predict_preserve_order_1_non_jit\nTestTrainer::test_predict_preserve_order_50_eager\nTestTrainer::test_predict_preserve_order_50_jit\nTestTrainer::test_predict_preserve_order_50_non_jit\nTestTrainer::test_symbolic_build\nTextDatasetFromDirectoryTest::test_sample_count_grain\nTextDatasetFromDirectoryTest::test_text_dataset_from_directory_binary_grain\nTextDatasetFromDirectoryTest::test_text_dataset_from_directory_follow_links_grain\nTextDatasetFromDirectoryTest::test_text_dataset_from_directory_manual_labels_grain\nTextDatasetFromDirectoryTest::test_text_dataset_from_directory_manual_labels_tf\nTextDatasetFromDirectoryTest::test_text_dataset_from_directory_multiclass_grain\nTextDatasetFromDirectoryTest::test_text_dataset_from_directory_not_batched_grain\nTextDatasetFromDirectoryTest::test_text_dataset_from_directory_standalone_grain\nTextDatasetFromDirectoryTest::test_text_dataset_from_directory_validation_split_grain\nTimeseriesDatasetTest::test_basics_grain\nTimeseriesDatasetTest::test_basics_tf\nTimeseriesDatasetTest::test_no_targets_grain\nTimeseriesDatasetTest::test_no_targets_tf\nTimeseriesDatasetTest::test_not_batched_grain\nTimeseriesDatasetTest::test_sampling_rate_grain\nTimeseriesDatasetTest::test_sampling_rate_tf\nTimeseriesDatasetTest::test_sequence_stride_grain\nTimeseriesDatasetTest::test_sequence_stride_tf\nTimeseriesDatasetTest::test_shuffle_grain\nTimeseriesDatasetTest::test_shuffle_tf\nTimeseriesDatasetTest::test_start_and_end_index_grain\nTimeseriesDatasetTest::test_timeseries_regression_grain\nTimeseriesDatasetTest::test_timeseries_regression_tf\nTrueNegativesTest::test_unweighted\nTrueNegativesTest::test_weighted\nTruePositiveTest::test_unweighted\nTruePositiveTest::test_weighted\nUnitNormalizationTes
t::test_un_basics\nUpSampling2dTest::test_upsampling_2d_lanczos_interpolation_methods\nUpSampling2dTest::test_upsampling_2d_various_interpolation_methods\n"
  },
  {
    "path": "keras/src/backend/openvino/excluded_tests.txt",
    "content": "keras/src/activations\nkeras/src/layers/preprocessing\nkeras/src/layers/regularization\nkeras/src/legacy"
  },
  {
    "path": "keras/src/backend/openvino/export.py",
    "content": "class OpenvinoExportArchive:\n    def track(self, resource):\n        raise NotImplementedError(\n            \"`track` is not implemented in the openvino backend.\"\n        )\n\n    def add_endpoint(self, name, fn, input_signature=None, **kwargs):\n        raise NotImplementedError(\n            \"`add_endpoint` is not implemented in the openvino backend.\"\n        )\n"
  },
  {
    "path": "keras/src/backend/openvino/image.py",
    "content": "import openvino.opset15 as ov_opset\n\nfrom keras.src import backend\nfrom keras.src.backend.openvino.core import OpenVINOKerasTensor\nfrom keras.src.backend.openvino.core import get_ov_output\n\n\ndef rgb_to_grayscale(images, data_format=None):\n    images = get_ov_output(images)\n    data_format = backend.standardize_data_format(data_format)\n    if images.get_partial_shape().rank not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    channel_axis = -3 if data_format == \"channels_first\" else -1\n    if images.shape[channel_axis] not in (1, 3):\n        raise ValueError(\n            \"Invalid channel size: expected 3 (RGB) or 1 (Grayscale). \"\n            f\"Received input with shape: images.shape={images.shape}\"\n        )\n\n    if images.shape[channel_axis] == 3:\n        original_type = images.get_element_type()\n        rgb_weights = ov_opset.constant(\n            [0.2989, 0.5870, 0.1140], dtype=original_type\n        ).output(0)\n        if data_format == \"channels_first\":\n            rgb_weights = ov_opset.unsqueeze(rgb_weights, axes=[-2, -1]).output(\n                0\n            )\n        grayscales = ov_opset.multiply(images, rgb_weights).output(0)\n        grayscales = ov_opset.reduce_sum(\n            grayscales, reduction_axes=[channel_axis]\n        ).output(0)\n        grayscales = ov_opset.unsqueeze(grayscales, axes=[channel_axis]).output(\n            0\n        )\n        if grayscales.get_element_type() != original_type:\n            # Type of grayscales may be changed after unsqueeze, so we need to\n            # convert it back to the original type.\n            grayscales = ov_opset.convert(grayscales, original_type).output(0)\n\n    return OpenVINOKerasTensor(grayscales)\n\n\ndef rgb_to_hsv(images, data_format=None):\n    
dtype = images.dtype\n    images = get_ov_output(images)\n    ov_type = images.get_element_type()\n    data_format = backend.standardize_data_format(data_format)\n    channels_axis = -1 if data_format == \"channels_last\" else -3\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if not backend.is_float_dtype(dtype):\n        raise ValueError(\n            \"Invalid images dtype: expected float dtype. \"\n            f\"Received: images.dtype={dtype}\"\n        )\n    eps = ov_opset.constant(backend.epsilon(), dtype=ov_type).output(0)\n    images = ov_opset.select(\n        ov_opset.less(ov_opset.abs(images), eps),\n        ov_opset.constant(0.0, dtype=ov_type),\n        images,\n    ).output(0)\n    rgb_channels = ov_opset.split(images, axis=channels_axis, num_splits=3)\n    r, g, b = (\n        rgb_channels.output(0),\n        rgb_channels.output(1),\n        rgb_channels.output(2),\n    )\n\n    def rgb_planes_to_hsv_planes(r, g, b):\n        value = ov_opset.maximum(ov_opset.maximum(r, g), b).output(0)\n        minimum = ov_opset.minimum(ov_opset.minimum(r, g), b).output(0)\n        range_ = ov_opset.subtract(value, minimum).output(0)\n\n        safe_value = ov_opset.select(\n            ov_opset.greater(value, ov_opset.constant(0.0, dtype=ov_type)),\n            value,\n            ov_opset.constant(1.0, dtype=ov_type),\n        ).output(0)\n        safe_range = ov_opset.select(\n            ov_opset.greater(range_, ov_opset.constant(0.0, dtype=ov_type)),\n            range_,\n            ov_opset.constant(1.0, dtype=ov_type),\n        ).output(0)\n\n        saturation = ov_opset.select(\n            ov_opset.greater(value, ov_opset.constant(0.0, dtype=ov_type)),\n            ov_opset.divide(range_, safe_value),\n            
ov_opset.constant(0.0, dtype=ov_type),\n        ).output(0)\n        norm = ov_opset.divide(\n            ov_opset.constant(1.0, dtype=ov_type),\n            ov_opset.multiply(\n                ov_opset.constant(6.0, dtype=ov_type), safe_range\n            ),\n        ).output(0)\n\n        hue = ov_opset.select(\n            ov_opset.equal(value, g),\n            ov_opset.add(\n                ov_opset.multiply(norm, ov_opset.subtract(b, r)),\n                ov_opset.constant(2.0 / 6.0, dtype=ov_type),\n            ),\n            ov_opset.add(\n                ov_opset.multiply(norm, ov_opset.subtract(r, g)),\n                ov_opset.constant(4.0 / 6.0, dtype=ov_type),\n            ),\n        ).output(0)\n        hue = ov_opset.select(\n            ov_opset.equal(value, r),\n            ov_opset.multiply(norm, ov_opset.subtract(g, b)),\n            hue,\n        ).output(0)\n        hue = ov_opset.select(\n            ov_opset.greater(range_, ov_opset.constant(0.0, dtype=ov_type)),\n            hue,\n            ov_opset.constant(0.0, dtype=ov_type),\n        ).output(0)\n        hue = ov_opset.add(\n            hue,\n            ov_opset.convert(\n                ov_opset.less(hue, ov_opset.constant(0.0, dtype=ov_type)),\n                ov_type,\n            ),\n        ).output(0)\n        return hue, saturation, value\n\n    images = ov_opset.concat(\n        rgb_planes_to_hsv_planes(r, g, b), axis=channels_axis\n    ).output(0)\n    return OpenVINOKerasTensor(images)\n\n\ndef hsv_to_rgb(images, data_format=None):\n    dtype = images.dtype\n    images = get_ov_output(images)\n    ov_type = images.get_element_type()\n    data_format = backend.standardize_data_format(data_format)\n    channels_axis = -1 if data_format == \"channels_last\" else -3\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). 
Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if not backend.is_float_dtype(dtype):\n        raise ValueError(\n            \"Invalid images dtype: expected float dtype. \"\n            f\"Received: images.dtype={dtype}\"\n        )\n    hsv_channels = ov_opset.split(images, axis=channels_axis, num_splits=3)\n    hue, saturation, value = (\n        hsv_channels.output(0),\n        hsv_channels.output(1),\n        hsv_channels.output(2),\n    )\n\n    def hsv_planes_to_rgb_planes(hue, saturation, value):\n        def channel_value(channel_delta, one_minus_saturation):\n            return ov_opset.multiply(\n                value,\n                ov_opset.add(\n                    one_minus_saturation,\n                    ov_opset.multiply(saturation, channel_delta),\n                ),\n            )\n\n        dh = ov_opset.multiply(\n            ov_opset.mod(hue, ov_opset.constant(1.0, dtype=ov_type)),\n            ov_opset.constant(6.0, dtype=ov_type),\n        ).output(0)\n        one_const = ov_opset.constant(1.0, dtype=ov_type).output(0)\n        two_const = ov_opset.constant(2.0, dtype=ov_type).output(0)\n        three_const = ov_opset.constant(3.0, dtype=ov_type).output(0)\n        four_const = ov_opset.constant(4.0, dtype=ov_type).output(0)\n        dr = ov_opset.subtract(\n            ov_opset.abs(ov_opset.subtract(dh, three_const)), one_const\n        ).output(0)\n        dr = ov_opset.clamp(dr, 0.0, 1.0).output(0)\n        dg = ov_opset.subtract(\n            two_const, ov_opset.abs(ov_opset.subtract(dh, two_const))\n        ).output(0)\n        dg = ov_opset.clamp(dg, 0.0, 1.0).output(0)\n        db = ov_opset.subtract(\n            two_const, ov_opset.abs(ov_opset.subtract(dh, four_const))\n        ).output(0)\n        db = ov_opset.clamp(db, 0.0, 1.0).output(0)\n        one_minus_saturation = ov_opset.subtract(one_const, saturation).output(\n            0\n        )\n\n        red = channel_value(dr, 
one_minus_saturation)\n        green = channel_value(dg, one_minus_saturation)\n        blue = channel_value(db, one_minus_saturation)\n        return red, green, blue\n\n    images = ov_opset.concat(\n        hsv_planes_to_rgb_planes(hue, saturation, value), axis=channels_axis\n    ).output(0)\n    return OpenVINOKerasTensor(images)\n\n\ndef resize(\n    image,\n    size,\n    interpolation=\"bilinear\",\n    antialias=False,\n    crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n    fill_mode=\"constant\",\n    fill_value=0.0,\n    data_format=\"channels_last\",\n):\n    raise NotImplementedError(\"`resize` is not supported with openvino backend\")\n\n\ndef affine_transform(\n    images,\n    transform,\n    interpolation=\"bilinear\",\n    fill_mode=\"constant\",\n    fill_value=0,\n    data_format=None,\n):\n    raise NotImplementedError(\n        \"`affine_transform` is not supported with openvino backend\"\n    )\n\n\ndef perspective_transform(\n    images,\n    start_points,\n    end_points,\n    interpolation=\"bilinear\",\n    fill_value=0,\n    data_format=None,\n):\n    raise NotImplementedError(\n        \"`perspective_transform` is not supported with openvino backend\"\n    )\n\n\ndef map_coordinates(\n    inputs, coordinates, order, fill_mode=\"constant\", fill_value=0\n):\n    raise NotImplementedError(\n        \"`map_coordinates` is not supported with openvino backend\"\n    )\n\n\ndef gaussian_blur(\n    images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None\n):\n    raise NotImplementedError(\n        \"`gaussian_blur` is not supported with openvino backend\"\n    )\n\n\ndef elastic_transform(\n    images,\n    alpha=20.0,\n    sigma=5.0,\n    interpolation=\"bilinear\",\n    fill_mode=\"reflect\",\n    fill_value=0.0,\n    seed=None,\n    data_format=None,\n):\n    raise NotImplementedError(\n        \"`elastic_transform` is not supported with openvino backend\"\n    )\n\n\ndef scale_and_translate(\n    images,\n    
output_shape,\n    scale,\n    translation,\n    spatial_dims,\n    method,\n    antialias=True,\n):\n    raise NotImplementedError(\n        \"`scale_and_translate` is not supported with openvino backend\"\n    )\n"
  },
  {
    "path": "keras/src/backend/openvino/layer.py",
    "content": "class OpenvinoLayer:\n    pass\n"
  },
  {
    "path": "keras/src/backend/openvino/linalg.py",
    "content": "import openvino as ov\nimport openvino.opset15 as ov_opset\nfrom openvino import Model\nfrom openvino import Type\n\nfrom keras.src.backend import config\nfrom keras.src.backend import standardize_dtype\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.openvino.core import OpenVINOKerasTensor\nfrom keras.src.backend.openvino.core import cast\nfrom keras.src.backend.openvino.core import convert_to_tensor\nfrom keras.src.backend.openvino.core import get_ov_output\n\n\ndef cholesky(a, upper=False):\n    raise NotImplementedError(\n        \"`cholesky` is not supported with openvino backend.\"\n    )\n\n\ndef cholesky_inverse(a, upper=False):\n    a = convert_to_tensor(a)\n    a_ov = get_ov_output(a)\n    if upper:\n        # Reconstruct A = U^T @ U, then invert\n        reconstructed_matrix = ov_opset.matmul(a_ov, a_ov, True, False).output(\n            0\n        )\n    else:\n        # Reconstruct A = L @ L^T, then invert\n        reconstructed_matrix = ov_opset.matmul(a_ov, a_ov, False, True).output(\n            0\n        )\n    result = ov_opset.inverse(reconstructed_matrix, adjoint=False).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef det(a):\n    a = convert_to_tensor(a)\n    a_ov = get_ov_output(a)\n    original_type = a_ov.get_element_type()\n\n    # Avoid constant folding bug for f64 in OpenVINO CPU Loop evaluate\n    if original_type == Type.f64:\n        a_ov = ov_opset.convert(a_ov, Type.f32).output(0)\n\n    a_shape = ov_opset.shape_of(a_ov, output_type=\"i32\").output(0)\n\n    rank = a_ov.get_partial_shape().rank.get_length()\n\n    minus_1 = ov_opset.constant([-1], Type.i32).output(0)\n    minus_2 = ov_opset.constant([-2], Type.i32).output(0)\n\n    N_node_1d = ov_opset.gather(\n        a_shape, minus_1, ov_opset.constant(0, Type.i32).output(0)\n    ).output(0)\n    N_node_scalar = ov_opset.squeeze(\n        N_node_1d, ov_opset.constant([0], Type.i32).output(0)\n    ).output(0)\n\n    num_batch_dims = 
rank - 2\n    if num_batch_dims > 0:\n        batch_dims_shape = ov_opset.broadcast(\n            ov_opset.constant([1], Type.i32).output(0),\n            ov_opset.constant([num_batch_dims], Type.i32).output(0),\n        ).output(0)\n        eye_shape = ov_opset.concat(\n            [batch_dims_shape, N_node_1d, N_node_1d], 0\n        ).output(0)\n    else:\n        eye_shape = ov_opset.concat([N_node_1d, N_node_1d], 0).output(0)\n\n    eye = ov_opset.eye(\n        N_node_scalar,\n        N_node_scalar,\n        ov_opset.constant(0, Type.i32).output(0),\n        a_ov.get_element_type(),\n    ).output(0)\n    eye_reshaped = ov_opset.reshape(eye, eye_shape, False).output(0)\n\n    trip_count = N_node_scalar\n    loop = ov_opset.loop(\n        trip_count, ov_opset.constant(True, Type.boolean).output(0)\n    )\n\n    M_param = ov_opset.parameter([-1] * rank, a_ov.get_element_type(), \"M\")\n    k_param = ov_opset.parameter([], Type.i32, \"k\")\n    A_body_param = ov_opset.parameter(\n        [-1] * rank, a_ov.get_element_type(), \"A_body\"\n    )\n    eye_body_param = ov_opset.parameter(\n        [-1] * rank, a_ov.get_element_type(), \"eye_body\"\n    )\n\n    k_next = ov_opset.add(\n        k_param.output(0), ov_opset.constant(1, Type.i32).output(0)\n    ).output(0)\n    k_f32 = ov_opset.convert(k_next, a_ov.get_element_type()).output(0)\n\n    M_diag = ov_opset.multiply(\n        M_param.output(0), eye_body_param.output(0)\n    ).output(0)\n    trace_axes = ov_opset.concat([minus_2, minus_1], 0).output(0)\n    trace = ov_opset.reduce_sum(M_diag, trace_axes, keep_dims=True).output(0)\n\n    minus_one = ov_opset.constant(-1.0, a_ov.get_element_type()).output(0)\n    c_k_factor = ov_opset.divide(minus_one, k_f32).output(0)\n    c_k = ov_opset.multiply(c_k_factor, trace).output(0)\n\n    c_k_I = ov_opset.multiply(c_k, eye_body_param.output(0)).output(0)\n    M_plus_c_k_I = ov_opset.add(M_param.output(0), c_k_I).output(0)\n\n    M_next = ov_opset.matmul(\n        
A_body_param.output(0), M_plus_c_k_I, False, False\n    ).output(0)\n\n    cond_next = ov_opset.constant(True, Type.boolean).output(0)\n\n    body = ov.Model(\n        [M_next, k_next, c_k, cond_next],\n        [M_param, k_param, A_body_param, eye_body_param],\n    )\n    loop.set_function(body)\n    loop.set_special_body_ports([-1, 3])\n\n    loop.set_merged_input(M_param, a_ov, M_next)\n    loop.set_merged_input(\n        k_param, ov_opset.constant(0, Type.i32).output(0), k_next\n    )\n    loop.set_invariant_input(A_body_param, a_ov)\n    loop.set_invariant_input(eye_body_param, eye_reshaped)\n\n    out_c_k = loop.get_iter_value(c_k, -1)\n\n    det_c_k = ov_opset.squeeze(out_c_k, trace_axes).output(0)\n\n    N_mod_2 = ov_opset.mod(\n        N_node_scalar, ov_opset.constant(2, Type.i32).output(0)\n    ).output(0)\n    N_mod_2_f32 = ov_opset.convert(N_mod_2, a_ov.get_element_type()).output(0)\n    one = ov_opset.constant(1.0, a_ov.get_element_type()).output(0)\n    two = ov_opset.constant(2.0, a_ov.get_element_type()).output(0)\n    sign = ov_opset.subtract(\n        one, ov_opset.multiply(two, N_mod_2_f32).output(0)\n    ).output(0)\n\n    det = ov_opset.multiply(det_c_k, sign).output(0)\n\n    if original_type == Type.f64:\n        det = ov_opset.convert(det, Type.f64).output(0)\n\n    return OpenVINOKerasTensor(det)\n\n\ndef eig(a):\n    raise NotImplementedError(\"`eig` is not supported with openvino backend\")\n\n\ndef eigh(a):\n    a = convert_to_tensor(a)\n    a_ov = get_ov_output(a)\n    a_ov_type = a_ov.get_element_type()\n    if not a_ov_type.is_real():\n        a_ov = ov_opset.convert(a_ov, Type.f32).output(0)\n        out_ov_type = Type.f32\n    else:\n        out_ov_type = a_ov_type\n    zero_const = ov_opset.constant(0, Type.i32).output(0)\n    one_const = ov_opset.constant(1, Type.i32).output(0)\n    minus_one_const = ov_opset.constant(-1, Type.i32).output(0)\n    a_shape = ov_opset.shape_of(a_ov, Type.i32).output(0)\n    rank = 
a_ov.get_partial_shape().rank.get_length()\n    if rank == 2:\n        n = ov_opset.gather(\n            a_shape, ov_opset.constant(0, Type.i32), zero_const\n        ).output(0)\n        n_int = n\n        batch_size_prod = ov_opset.constant(1, Type.i32).output(0)\n    else:\n        n = ov_opset.gather(a_shape, minus_one_const, zero_const).output(0)\n        n_int = n\n        batch_shape = ov_opset.slice(\n            a_shape,\n            ov_opset.constant([0], Type.i32),\n            ov_opset.constant([-2], Type.i32),\n            ov_opset.constant([1], Type.i32),\n            ov_opset.constant([0], Type.i32),\n        ).output(0)\n        batch_size_prod = ov_opset.reduce_prod(\n            batch_shape, zero_const, False\n        ).output(0)\n    a_flat_shape = ov_opset.concat(\n        [\n            ov_opset.unsqueeze(batch_size_prod, zero_const).output(0),\n            ov_opset.unsqueeze(n, zero_const).output(0),\n            ov_opset.unsqueeze(n, zero_const).output(0),\n        ],\n        axis=0,\n    ).output(0)\n    A_flat = ov_opset.reshape(a_ov, a_flat_shape, False).output(0)\n    range_n = ov_opset.range(\n        zero_const, n, one_const, output_type=Type.i32\n    ).output(0)\n    eye_n = ov_opset.one_hot(\n        range_n,\n        n,\n        ov_opset.constant(1.0, out_ov_type),\n        ov_opset.constant(0.0, out_ov_type),\n        axis=-1,\n    ).output(0)\n    V_flat = ov_opset.broadcast(eye_n, a_flat_shape).output(0)\n    n_minus_one = ov_opset.subtract(n_int, one_const).output(0)\n    n_squared_minus_n = ov_opset.multiply(n_int, n_minus_one).output(0)\n    sweep_iters = ov_opset.divide(\n        n_squared_minus_n, ov_opset.constant(2, Type.i32)\n    ).output(0)\n    max_iter = ov_opset.multiply(\n        ov_opset.constant(15, Type.i32), sweep_iters\n    ).output(0)\n    trip_count = max_iter\n    execution_cond = ov_opset.constant(True, Type.boolean).output(0)\n    loop = ov_opset.loop(trip_count, execution_cond)\n    A_param = 
ov_opset.parameter(\n        A_flat.get_partial_shape(), A_flat.get_element_type()\n    )\n    V_param = ov_opset.parameter(\n        V_flat.get_partial_shape(), V_flat.get_element_type()\n    )\n    A_curr = A_param.output(0)\n    V_curr = V_param.output(0)\n    A_curr_shape = ov_opset.shape_of(A_curr, Type.i32).output(0)\n    l_batch_size_prod = ov_opset.gather(\n        A_curr_shape, zero_const, zero_const\n    ).output(0)\n    l_n = ov_opset.gather(A_curr_shape, minus_one_const, zero_const).output(0)\n    l_flat_shape = A_curr_shape\n    l_range_n = ov_opset.range(\n        zero_const, l_n, one_const, output_type=Type.i32\n    ).output(0)\n    l_eye_n = ov_opset.one_hot(\n        l_range_n,\n        l_n,\n        ov_opset.constant(1.0, out_ov_type),\n        ov_opset.constant(0.0, out_ov_type),\n        axis=-1,\n    ).output(0)\n    mask = ov_opset.subtract(\n        ov_opset.constant(1.0, out_ov_type), l_eye_n\n    ).output(0)\n    mask_b = ov_opset.broadcast(mask, l_flat_shape).output(0)\n    A_off = ov_opset.multiply(A_curr, mask_b).output(0)\n    A_off_abs = ov_opset.abs(A_off).output(0)\n    flat_n2 = ov_opset.concat(\n        [\n            ov_opset.unsqueeze(l_batch_size_prod, zero_const).output(0),\n            ov_opset.unsqueeze(ov_opset.multiply(l_n, l_n), zero_const).output(\n                0\n            ),\n        ],\n        axis=0,\n    ).output(0)\n    A_off_abs_flat = ov_opset.reshape(A_off_abs, flat_n2, False).output(0)\n    max_val = ov_opset.reduce_max(A_off_abs_flat, one_const, False).output(0)\n    epsilon = ov_opset.constant(1e-6, out_ov_type).output(0)\n    continue_cond = ov_opset.reduce_logical_or(\n        ov_opset.greater(max_val, epsilon), zero_const, False\n    ).output(0)\n    topk = ov_opset.topk(\n        A_off_abs_flat, ov_opset.constant(1, Type.i32), 1, \"max\", \"value\"\n    )\n    argmax_flat = topk.output(1)  # shape [B, 1]\n    argmax_flat_sq = ov_opset.squeeze(argmax_flat, one_const).output(\n        0\n    )  # shape 
[B]\n    p = ov_opset.divide(argmax_flat_sq, l_n).output(0)  # shape [B]\n    q = ov_opset.mod(argmax_flat_sq, l_n).output(0)  # shape [B]\n    p_unsqueezed = ov_opset.unsqueeze(p, one_const).output(0)\n    q_unsqueezed = ov_opset.unsqueeze(q, one_const).output(0)\n    b_indices = ov_opset.range(\n        zero_const, l_batch_size_prod, one_const, output_type=Type.i32\n    ).output(0)\n    b_unsqueezed = ov_opset.unsqueeze(b_indices, one_const).output(0)\n    pp_indices = ov_opset.concat(\n        [b_unsqueezed, p_unsqueezed, p_unsqueezed], axis=1\n    ).output(0)\n    qq_indices = ov_opset.concat(\n        [b_unsqueezed, q_unsqueezed, q_unsqueezed], axis=1\n    ).output(0)\n    pq_indices = ov_opset.concat(\n        [b_unsqueezed, p_unsqueezed, q_unsqueezed], axis=1\n    ).output(0)\n    App = ov_opset.gather_nd(A_curr, pp_indices).output(0)  # shape [B]\n    Aqq = ov_opset.gather_nd(A_curr, qq_indices).output(0)\n    Apq = ov_opset.gather_nd(A_curr, pq_indices).output(0)\n    zero_out = ov_opset.constant(0.0, out_ov_type).output(0)\n    is_p_eq_q = ov_opset.equal(p, q).output(0)\n    is_apq_zero = ov_opset.logical_or(\n        ov_opset.equal(Apq, zero_out).output(0), is_p_eq_q\n    ).output(0)\n    safe_Apq = ov_opset.select(\n        is_apq_zero, ov_opset.constant(1.0, out_ov_type), Apq\n    ).output(0)\n    theta = ov_opset.divide(\n        ov_opset.subtract(Aqq, App),\n        ov_opset.multiply(ov_opset.constant(2.0, out_ov_type), safe_Apq),\n    ).output(0)\n    theta_abs = ov_opset.abs(theta).output(0)\n    theta_sign = ov_opset.sign(theta).output(0)\n    theta_sign = ov_opset.select(\n        ov_opset.equal(theta, zero_out),\n        ov_opset.constant(1.0, out_ov_type),\n        theta_sign,\n    ).output(0)\n    sqrt_term = ov_opset.sqrt(\n        ov_opset.add(\n            ov_opset.multiply(theta, theta), ov_opset.constant(1.0, out_ov_type)\n        )\n    ).output(0)\n    t = ov_opset.divide(theta_sign, ov_opset.add(theta_abs, sqrt_term)).output(\n        
0\n    )\n    t = ov_opset.select(is_apq_zero, zero_out, t).output(0)\n    c = ov_opset.divide(\n        ov_opset.constant(1.0, out_ov_type),\n        ov_opset.sqrt(\n            ov_opset.add(\n                ov_opset.multiply(t, t), ov_opset.constant(1.0, out_ov_type)\n            )\n        ),\n    ).output(0)\n    s = ov_opset.multiply(c, t).output(0)\n    R = ov_opset.broadcast(l_eye_n, l_flat_shape).output(0)\n    c_safe = ov_opset.select(\n        is_p_eq_q, ov_opset.constant(1.0, out_ov_type), c\n    ).output(0)\n    s_safe = ov_opset.select(is_p_eq_q, zero_out, s).output(0)\n    c_updates = c_safe\n    s_updates = s_safe\n    neg_s_updates = ov_opset.negative(s_safe).output(0)\n    p_safe = ov_opset.select(\n        is_p_eq_q, ov_opset.constant(0, Type.i32), p\n    ).output(0)\n    q_safe = ov_opset.select(\n        is_p_eq_q, ov_opset.constant(1, Type.i32), q\n    ).output(0)\n    p_safe_unsqueezed = ov_opset.unsqueeze(p_safe, one_const).output(0)\n    q_safe_unsqueezed = ov_opset.unsqueeze(q_safe, one_const).output(0)\n    pp_safe_indices = ov_opset.concat(\n        [b_unsqueezed, p_safe_unsqueezed, p_safe_unsqueezed], axis=1\n    ).output(0)\n    qq_safe_indices = ov_opset.concat(\n        [b_unsqueezed, q_safe_unsqueezed, q_safe_unsqueezed], axis=1\n    ).output(0)\n    pq_safe_indices = ov_opset.concat(\n        [b_unsqueezed, p_safe_unsqueezed, q_safe_unsqueezed], axis=1\n    ).output(0)\n    qp_safe_indices = ov_opset.concat(\n        [b_unsqueezed, q_safe_unsqueezed, p_safe_unsqueezed], axis=1\n    ).output(0)\n\n    R = ov_opset.scatter_nd_update(R, pp_safe_indices, c_updates).output(0)\n    R = ov_opset.scatter_nd_update(R, qq_safe_indices, c_updates).output(0)\n    R = ov_opset.scatter_nd_update(R, pq_safe_indices, s_updates).output(0)\n    R = ov_opset.scatter_nd_update(R, qp_safe_indices, neg_s_updates).output(0)\n\n    # Transpose R to R^T: swap last two dims\n    RT = ov_opset.transpose(R, ov_opset.constant([0, 2, 1], 
Type.i32)).output(0)\n\n    A_next = ov_opset.matmul(\n        RT, ov_opset.matmul(A_curr, R, False, False), False, False\n    ).output(0)\n    V_next = ov_opset.matmul(V_curr, R, False, False).output(0)\n\n    # If not continue_cond, just return A_curr and V_curr\n    A_next = ov_opset.select(\n        ov_opset.unsqueeze(\n            ov_opset.unsqueeze(continue_cond, zero_const), zero_const\n        ).output(0),\n        A_next,\n        A_curr,\n    ).output(0)\n\n    V_next = ov_opset.select(\n        ov_opset.unsqueeze(\n            ov_opset.unsqueeze(continue_cond, zero_const), zero_const\n        ).output(0),\n        V_next,\n        V_curr,\n    ).output(0)\n    body = Model(\n        [continue_cond, A_next, V_next], [A_param, V_param], \"jacobi_loop\"\n    )\n    loop.set_function(body)\n    loop.set_special_body_ports([-1, 0])\n    loop.set_merged_input(A_param, A_flat, A_next)\n    loop.set_merged_input(V_param, V_flat, V_next)\n    A_out = loop.get_iter_value(A_next)\n    V_out = loop.get_iter_value(V_next)\n    eigenvalues_flat = ov_opset.reduce_sum(\n        ov_opset.multiply(A_out, eye_n), minus_one_const, False\n    ).output(0)\n    neg_eigenvalues = ov_opset.negative(eigenvalues_flat).output(0)\n    topk_sort = ov_opset.topk(neg_eigenvalues, n, -1, \"max\", \"value\")\n    w_flat = ov_opset.negative(topk_sort.output(0)).output(0)\n    sort_indices = topk_sort.output(1)\n    v_indices = ov_opset.broadcast(\n        ov_opset.unsqueeze(sort_indices, one_const).output(0), a_flat_shape\n    ).output(0)\n    v_flat = ov_opset.gather_elements(V_out, v_indices, -1).output(0)\n    if rank == 2:\n        w_final = ov_opset.squeeze(w_flat, zero_const).output(0)\n        v_final = ov_opset.squeeze(v_flat, zero_const).output(0)\n    else:\n        w_shape_final = ov_opset.concat(\n            [batch_shape, ov_opset.unsqueeze(n, zero_const).output(0)], axis=0\n        ).output(0)\n        w_final = ov_opset.reshape(w_flat, w_shape_final, False).output(0)\n      
  v_final = ov_opset.reshape(v_flat, a_shape, False).output(0)\n    if out_ov_type == Type.f64:\n        w_final = ov_opset.convert(w_final, Type.f64).output(0)\n        v_final = ov_opset.convert(v_final, Type.f64).output(0)\n\n    return (\n        OpenVINOKerasTensor(w_final),\n        OpenVINOKerasTensor(v_final),\n    )\n\n\ndef inv(a):\n    a = convert_to_tensor(a)\n    a_ov = get_ov_output(a)\n    result = ov_opset.inverse(a_ov, adjoint=False).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef lu_factor(a):\n    raise NotImplementedError(\n        \"`lu_factor` is not supported with openvino backend\"\n    )\n\n\ndef norm(x, ord=None, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    x_shape = tuple(x.shape)\n    ndim = len(x_shape)\n\n    if axis is None:\n        axis = tuple(range(ndim))\n    elif isinstance(axis, int):\n        axis = (axis,)\n    if any(a < -ndim or a >= ndim for a in axis):\n        raise ValueError(\n            \"All `axis` values must be in the range [-ndim, ndim). 
\"\n            f\"Received inputs with ndim={ndim}, while axis={axis}\"\n        )\n    axis = axis[0] if len(axis) == 1 else axis\n    num_axes = 1 if isinstance(axis, int) else len(axis)\n\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n\n    x_ov = get_ov_output(x)\n\n    # Ref: jax.numpy.linalg.norm\n    if num_axes == 1:\n        if ord is None or ord == 2:\n            # L2 norm: sqrt(sum(x * conj(x)))\n            x_conj = x_ov\n            x_sq = ov_opset.multiply(x_conj, x_conj).output(0)\n            axis_for_const = list(axis) if isinstance(axis, tuple) else axis\n            axis_const = ov_opset.constant(axis_for_const, Type.i32).output(0)\n            norm_result = ov_opset.reduce_sum(\n                x_sq, axis_const, keepdims\n            ).output(0)\n            norm_result = ov_opset.sqrt(norm_result).output(0)\n        elif ord == float(\"inf\"):\n            axis_for_const = list(axis) if isinstance(axis, tuple) else axis\n            axis_const = ov_opset.constant(axis_for_const, Type.i32).output(0)\n            x_abs = ov_opset.abs(x_ov).output(0)\n            norm_result = ov_opset.reduce_max(\n                x_abs, axis_const, keepdims\n            ).output(0)\n        elif ord == float(\"-inf\"):\n            axis_for_const = list(axis) if isinstance(axis, tuple) else axis\n            axis_const = ov_opset.constant(axis_for_const, Type.i32).output(0)\n            x_abs = ov_opset.abs(x_ov).output(0)\n            norm_result = ov_opset.reduce_min(\n                x_abs, axis_const, keepdims\n            ).output(0)\n        elif ord == 0:\n            # Count non-zero elements\n            axis_for_const = list(axis) if isinstance(axis, tuple) else axis\n            axis_const = ov_opset.constant(axis_for_const, Type.i32).output(0)\n            zero = ov_opset.constant(0.0, Type.f32).output(0)\n            not_equal 
= ov_opset.not_equal(x_ov, zero).output(0)\n            not_equal_float = ov_opset.convert(not_equal, Type.f32).output(0)\n            norm_result = ov_opset.reduce_sum(\n                not_equal_float, axis_const, keepdims\n            ).output(0)\n        elif ord == 1:\n            # L1 norm: sum(|x|)\n            axis_for_const = list(axis) if isinstance(axis, tuple) else axis\n            axis_const = ov_opset.constant(axis_for_const, Type.i32).output(0)\n            x_abs = ov_opset.abs(x_ov).output(0)\n            norm_result = ov_opset.reduce_sum(\n                x_abs, axis_const, keepdims\n            ).output(0)\n        elif isinstance(ord, str):\n            raise ValueError(\n                f\"Invalid `ord` argument for vector norm. Received: ord={ord}\"\n            )\n        else:\n            # p-norm: (sum(|x|^p))^(1/p)\n            ord_tensor = convert_to_tensor(ord, dtype=dtype)\n            ord_ov = get_ov_output(ord_tensor)\n            axis_for_const = list(axis) if isinstance(axis, tuple) else axis\n            axis_const = ov_opset.constant(axis_for_const, Type.i32).output(0)\n            x_abs = ov_opset.abs(x_ov).output(0)\n            x_pow = ov_opset.power(x_abs, ord_ov).output(0)\n            sum_pow = ov_opset.reduce_sum(x_pow, axis_const, keepdims).output(0)\n            one = convert_to_tensor(1.0, dtype=dtype)\n            one_ov = get_ov_output(one)\n            inv_ord = ov_opset.divide(one_ov, ord_ov).output(0)\n            norm_result = ov_opset.power(sum_pow, inv_ord).output(0)\n\n    elif num_axes == 2:\n        row_axis, col_axis = axis[0], axis[1]\n        row_axis = row_axis + ndim if row_axis < 0 else row_axis\n        col_axis = col_axis + ndim if col_axis < 0 else col_axis\n\n        if ord is None or ord == \"fro\":\n            # Frobenius norm: sqrt(sum(x * conj(x)))\n            x_sq = ov_opset.multiply(x_ov, x_ov).output(0)\n            axis_for_const = list(axis) if isinstance(axis, tuple) else axis\n          
  axis_const = ov_opset.constant(axis_for_const, Type.i32).output(0)\n            sum_sq = ov_opset.reduce_sum(x_sq, axis_const, keepdims).output(0)\n            norm_result = ov_opset.sqrt(sum_sq).output(0)\n        elif ord == 1:\n            # Maximum absolute column sum\n            if not keepdims and col_axis > row_axis:\n                col_axis -= 1\n            row_axis_const = ov_opset.constant(row_axis, Type.i32).output(0)\n            col_axis_const = ov_opset.constant(col_axis, Type.i32).output(0)\n\n            x_abs = ov_opset.abs(x_ov).output(0)\n            col_sum = ov_opset.reduce_sum(\n                x_abs, row_axis_const, keep_dims=keepdims\n            ).output(0)\n            norm_result = ov_opset.reduce_max(\n                col_sum, col_axis_const, keep_dims=keepdims\n            ).output(0)\n        elif ord == -1:\n            # Minimum absolute column sum\n            if not keepdims and col_axis > row_axis:\n                col_axis -= 1\n            row_axis_const = ov_opset.constant(row_axis, Type.i32).output(0)\n            col_axis_const = ov_opset.constant(col_axis, Type.i32).output(0)\n\n            x_abs = ov_opset.abs(x_ov).output(0)\n            col_sum = ov_opset.reduce_sum(\n                x_abs, row_axis_const, keep_dims=keepdims\n            ).output(0)\n            norm_result = ov_opset.reduce_min(\n                col_sum, col_axis_const, keep_dims=keepdims\n            ).output(0)\n        elif ord == float(\"inf\"):\n            # Maximum absolute row sum\n            if not keepdims and row_axis > col_axis:\n                row_axis -= 1\n            col_axis_const = ov_opset.constant(col_axis, Type.i32).output(0)\n            row_axis_const = ov_opset.constant(row_axis, Type.i32).output(0)\n\n            x_abs = ov_opset.abs(x_ov).output(0)\n            row_sum = ov_opset.reduce_sum(\n                x_abs, col_axis_const, keep_dims=keepdims\n            ).output(0)\n            norm_result = 
ov_opset.reduce_max(\n                row_sum, row_axis_const, keep_dims=keepdims\n            ).output(0)\n        elif ord == float(\"-inf\"):\n            # Minimum absolute row sum\n            if not keepdims and row_axis > col_axis:\n                row_axis -= 1\n            col_axis_const = ov_opset.constant(col_axis, Type.i32).output(0)\n            row_axis_const = ov_opset.constant(row_axis, Type.i32).output(0)\n\n            x_abs = ov_opset.abs(x_ov).output(0)\n            row_sum = ov_opset.reduce_sum(\n                x_abs, col_axis_const, keep_dims=keepdims\n            ).output(0)\n            norm_result = ov_opset.reduce_min(\n                row_sum, row_axis_const, keep_dims=keepdims\n            ).output(0)\n        elif ord in (\"nuc\", 2, -2):\n            # Nuclear norm, spectral norm, and minimum singular value\n            # These require SVD which is not supported in OpenVINO backend\n            raise NotImplementedError(\n                f\"`norm` with ord={ord} for matrix norms requires SVD \"\n                \"which is not supported with openvino backend\"\n            )\n        else:\n            raise ValueError(\n                f\"Invalid `ord` argument for matrix norm. Received: ord={ord}\"\n            )\n    else:\n        raise ValueError(f\"Invalid axis values. 
Received: axis={axis}\")\n\n    return OpenVINOKerasTensor(norm_result)\n\n\ndef qr(x, mode=\"reduced\"):\n    raise NotImplementedError(\"`qr` is not supported with openvino backend\")\n\n\ndef solve(a, b):\n    a = convert_to_tensor(a)\n    b = convert_to_tensor(b)\n    a_ov = get_ov_output(a)\n    b_ov = get_ov_output(b)\n    squeeze = b.ndim == a.ndim - 1\n    if squeeze:\n        minus_one = ov_opset.constant([-1], Type.i32).output(0)\n        b_ov = ov_opset.unsqueeze(b_ov, minus_one).output(0)\n    a_inv = ov_opset.inverse(a_ov, adjoint=False).output(0)\n    result = ov_opset.matmul(a_inv, b_ov, False, False).output(0)\n    if squeeze:\n        minus_one = ov_opset.constant([-1], Type.i32).output(0)\n        result = ov_opset.squeeze(result, minus_one).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef solve_triangular(a, b, lower=False):\n    raise NotImplementedError(\n        \"`solve_triangular` is not supported with openvino backend\"\n    )\n\n\ndef svd(x, full_matrices=True, compute_uv=True):\n    raise NotImplementedError(\"`svd` is not supported with openvino backend\")\n\n\ndef lstsq(a, b, rcond=None):\n    raise NotImplementedError(\"`lstsq` is not supported with openvino backend\")\n\n\ndef jvp(fun, primals, tangents, has_aux=False):\n    raise NotImplementedError(\"`jvp` is not supported with openvino backend\")\n"
  },
  {
    "path": "keras/src/backend/openvino/math.py",
    "content": "import numpy as np\nimport openvino.opset15 as ov_opset\nimport scipy.signal\nfrom openvino import Type\n\nfrom keras.src.backend.openvino.core import OpenVINOKerasTensor\nfrom keras.src.backend.openvino.core import cast\nfrom keras.src.backend.openvino.core import get_ov_output\nfrom keras.src.backend.openvino.core import standardize_dtype\nfrom keras.src.backend.openvino.numpy import stack\n\nINT32_MAX = 2**31 - 1\n\n\ndef _segment_reduction_fn(\n    data, segment_ids, reduction_method, num_segments, sorted\n):\n    data = get_ov_output(data)\n    segment_ids = get_ov_output(segment_ids)\n\n    if num_segments is None:\n        max_id = ov_opset.reduce_max(\n            segment_ids, ov_opset.constant([0], Type.i32), keep_dims=False\n        ).output(0)\n        num_segments = ov_opset.add(\n            max_id, ov_opset.constant(1, max_id.get_element_type())\n        ).output(0)\n    else:\n        num_segments = ov_opset.constant(\n            num_segments, segment_ids.get_element_type()\n        ).output(0)\n\n    is_negative = ov_opset.less(\n        segment_ids, ov_opset.constant(0, segment_ids.get_element_type())\n    ).output(0)\n    safe_segment_ids = ov_opset.select(\n        is_negative, num_segments, segment_ids\n    ).output(0)\n    indices = ov_opset.unsqueeze(\n        safe_segment_ids, ov_opset.constant(-1, Type.i32)\n    ).output(0)\n\n    num_segments_plus_1 = ov_opset.add(\n        num_segments, ov_opset.constant(1, num_segments.get_element_type())\n    ).output(0)\n\n    data_shape = data.get_partial_shape()\n    rank = data_shape.rank.get_length() if data_shape.rank.is_static else -1\n\n    if rank > 1:\n        data_shape_node = ov_opset.shape_of(data, output_type=Type.i32).output(\n            0\n        )\n        rest_shape = ov_opset.slice(\n            data_shape_node,\n            start=ov_opset.constant([1], Type.i32),\n            stop=ov_opset.constant([2147483647], Type.i32),\n            step=ov_opset.constant([1], 
Type.i32),\n            axes=ov_opset.constant([0], Type.i32),\n        ).output(0)\n        num_seg_node = ov_opset.unsqueeze(\n            num_segments_plus_1, ov_opset.constant(0, Type.i32)\n        ).output(0)\n        buffer_shape = ov_opset.concat(\n            [num_seg_node, rest_shape], axis=0\n        ).output(0)\n    else:\n        buffer_shape = ov_opset.unsqueeze(\n            num_segments_plus_1, ov_opset.constant(0, Type.i32)\n        ).output(0)\n\n    if reduction_method == \"max\":\n        from keras.src.backend.openvino.core import DTYPES_MIN\n\n        data_type = data.get_element_type()\n        if data_type.is_real():\n            init_val = np.array(-np.inf, dtype=np.float32)\n        else:\n            init_val = DTYPES_MIN[data_type]\n    else:\n        init_val = 0\n\n    init_val_node = ov_opset.constant(init_val, data.get_element_type()).output(\n        0\n    )\n    buffer = ov_opset.broadcast(init_val_node, buffer_shape).output(0)\n\n    scattered = ov_opset.scatter_nd_update(\n        buffer, indices, data, reduction=reduction_method\n    ).output(0)\n\n    start = ov_opset.constant([0], Type.i32).output(0)\n    end = ov_opset.unsqueeze(\n        num_segments, ov_opset.constant(0, Type.i32)\n    ).output(0)\n    axes = ov_opset.constant([0], Type.i32).output(0)\n    step = ov_opset.constant([1], Type.i32).output(0)\n    result = ov_opset.slice(scattered, start, end, step, axes).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef segment_sum(data, segment_ids, num_segments=None, sorted=False):\n    return _segment_reduction_fn(data, segment_ids, \"sum\", num_segments, sorted)\n\n\ndef segment_max(data, segment_ids, num_segments=None, sorted=False):\n    return _segment_reduction_fn(data, segment_ids, \"max\", num_segments, sorted)\n\n\ndef top_k(x, k, sorted=True):\n    x = get_ov_output(x)\n    k_tensor = ov_opset.constant(k, dtype=Type.i32)\n    axis = -1\n    sort_type = \"value\" if sorted else \"none\"\n    topk_node = 
ov_opset.topk(x, k_tensor, axis, \"max\", sort_type)\n    values = topk_node.output(0)\n    indices = topk_node.output(1)\n    return OpenVINOKerasTensor(values), OpenVINOKerasTensor(indices)\n\n\ndef in_top_k(targets, predictions, k):\n    from keras.src.backend.openvino.numpy import take_along_axis\n\n    # Expand targets: (batch,) → (batch, 1) for use with take_along_axis\n    targets = ov_opset.unsqueeze(\n        get_ov_output(targets), ov_opset.constant(1, Type.i32)\n    ).output(0)\n    predictions = get_ov_output(predictions)\n\n    # top_k returns (batch, k) sorted descending; last col is the k-th largest\n    topk_values = top_k(predictions, k)[0]\n    # Grab only the last column (index k-1): threshold value, shape (batch,)\n    k_minus_1_idx = ov_opset.constant([k - 1], dtype=Type.i32).output(0)\n    topk_values_axis = ov_opset.constant(1, dtype=Type.i32).output(0)\n    topk_min = ov_opset.gather(\n        topk_values, k_minus_1_idx, topk_values_axis\n    ).output(0)\n\n    # Gather the prediction score at each true class index → shape (batch, 1)\n    targets_values = take_along_axis(predictions, targets, axis=-1)\n    # target score >= k-th largest score means it belongs in the top-k\n    mask = ov_opset.greater_equal(targets_values, topk_min).output(0)\n    return OpenVINOKerasTensor(mask)\n\n\ndef logsumexp(x, axis=None, keepdims=False):\n    x = get_ov_output(x)\n    if axis is None:\n        flatten_shape = ov_opset.constant([-1], Type.i32).output(0)\n        x = ov_opset.reshape(x, flatten_shape, False).output(0)\n        axis = 0\n    if isinstance(axis, tuple):\n        axis = list(axis)\n    axis = ov_opset.constant(axis, Type.i32).output(0)\n    const_zero = ov_opset.constant(0, x.get_element_type()).output(0)\n    # Use keepdims=True for reduce_max to ensure proper broadcasting\n    reduce_max = ov_opset.reduce_max(x, axis, True).output(0)\n    is_finite = ov_opset.is_finite(reduce_max).output(0)\n    norm_max = ov_opset.select(is_finite, 
reduce_max, const_zero).output(0)\n    norm_max_sub = ov_opset.subtract(x, norm_max).output(0)\n    exp_norm_max = ov_opset.exp(norm_max_sub).output(0)\n    sum_exp = ov_opset.reduce_sum(exp_norm_max, axis, keepdims).output(0)\n    log_sum_exp = ov_opset.log(sum_exp).output(0)\n    # Squeeze norm_max if needed to match dimensions\n    if not keepdims:\n        norm_max = ov_opset.squeeze(norm_max, axis).output(0)\n    log_sum_exp = ov_opset.add(norm_max, log_sum_exp).output(0)\n    return OpenVINOKerasTensor(log_sum_exp)\n\n\ndef qr(x, mode=\"reduced\"):\n    raise NotImplementedError(\"`qr` is not supported with openvino backend\")\n\n\ndef extract_sequences(x, sequence_length, sequence_stride):\n    x = get_ov_output(x)\n    x_shape = x.partial_shape\n    ndim = len(x_shape)\n\n    # Define common constants for reuse\n    zero_const_1d = ov_opset.constant([0], Type.i32)\n    shape_tensor = ov_opset.shape_of(x, output_type=Type.i32).output(0)\n\n    last_idx = ov_opset.constant([ndim - 1], Type.i32)\n    axis0 = ov_opset.constant(0, Type.i32)\n    signal_len_1d = ov_opset.gather(shape_tensor, last_idx, axis0).output(0)\n    signal_len_scalar = ov_opset.squeeze(signal_len_1d, zero_const_1d).output(0)\n\n    minus_one = ov_opset.constant([-1], Type.i32).output(0)\n    shape_2d = ov_opset.concat([minus_one, signal_len_1d], axis=0).output(0)\n    x_2d = ov_opset.reshape(x, shape_2d, False).output(0)\n\n    seq_len_c = ov_opset.constant(sequence_length, Type.i32).output(0)\n    stride_c = ov_opset.constant(sequence_stride, Type.i32).output(0)\n    diff = ov_opset.subtract(signal_len_scalar, seq_len_c).output(0)\n    num_seq_scalar = ov_opset.add(\n        ov_opset.divide(diff, stride_c).output(0),\n        ov_opset.constant(1, Type.i32).output(0),\n    ).output(0)\n\n    row_stop = ov_opset.multiply(num_seq_scalar, stride_c).output(0)\n    row_idx = ov_opset.range(\n        ov_opset.constant(0, Type.i32).output(0),\n        row_stop,\n        stride_c,\n        
output_type=Type.i32,\n    ).output(0)\n    row_idx_2d = ov_opset.unsqueeze(\n        row_idx, ov_opset.constant([1], Type.i32)\n    ).output(0)\n\n    col_idx = ov_opset.constant(\n        np.arange(sequence_length, dtype=np.int32)\n    ).output(0)\n    col_idx_2d = ov_opset.unsqueeze(col_idx, zero_const_1d).output(0)\n\n    indices = ov_opset.add(row_idx_2d, col_idx_2d).output(0)\n\n    gathered = ov_opset.gather(\n        x_2d, indices, ov_opset.constant(1, Type.i32)\n    ).output(0)\n\n    batch_shape = ov_opset.slice(\n        shape_tensor,\n        start=zero_const_1d,\n        stop=ov_opset.constant([ndim - 1], Type.i32),\n        step=ov_opset.constant([1], Type.i32),\n        axes=zero_const_1d,\n    ).output(0)\n    num_seq_1d = ov_opset.unsqueeze(num_seq_scalar, zero_const_1d).output(0)\n    seq_len_1d = ov_opset.constant([sequence_length], Type.i32).output(0)\n    out_shape = ov_opset.concat(\n        [batch_shape, num_seq_1d, seq_len_1d], axis=0\n    ).output(0)\n    result = ov_opset.reshape(gathered, out_shape, False).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef _dft(x, axes_offsets, inverse=False):\n    \"\"\"Shared helper for fft, fft2, and ifft2.\n\n    Args:\n        x: Tuple of (real, imag) KerasTensors.\n        axes_offsets: List of negative axis offsets relative to the\n            complex-data rank (e.g. 
[-2] for fft, [-3, -2] for fft2/ifft2).\n        inverse: If True, use ov_opset.idft; otherwise use ov_opset.dft.\n    \"\"\"\n    ori_dtype = x[0].dtype\n    x0 = cast(x[0], \"float32\") if ori_dtype == \"float64\" else x[0]\n    x1 = cast(x[1], \"float32\") if ori_dtype == \"float64\" else x[1]\n\n    real = ov_opset.unsqueeze(\n        get_ov_output(x0), ov_opset.constant([-1], Type.i32)\n    ).output(0)\n    imag = ov_opset.unsqueeze(\n        get_ov_output(x1), ov_opset.constant([-1], Type.i32)\n    ).output(0)\n    complex_data = ov_opset.concat([real, imag], -1).output(0)\n\n    rank = len(x[0].shape) + 1\n    axes = ov_opset.constant(\n        [rank + off for off in axes_offsets], Type.i32\n    ).output(0)\n\n    op = ov_opset.idft if inverse else ov_opset.dft\n    result = op(complex_data, axes).output(0)\n\n    out_real = ov_opset.gather(\n        result, ov_opset.constant(0, Type.i32), ov_opset.constant(-1, Type.i32)\n    ).output(0)\n    out_imag = ov_opset.gather(\n        result, ov_opset.constant(1, Type.i32), ov_opset.constant(-1, Type.i32)\n    ).output(0)\n\n    if ori_dtype == \"float64\":\n        out_real = ov_opset.convert(out_real, Type.f64).output(0)\n        out_imag = ov_opset.convert(out_imag, Type.f64).output(0)\n\n    return OpenVINOKerasTensor(out_real), OpenVINOKerasTensor(out_imag)\n\n\ndef fft(x):\n    # axes_offsets=[-2]: last axis of complex data (rank = input_rank + 1)\n    return _dft(x, axes_offsets=[-2], inverse=False)\n\n\ndef fft2(x):\n    # axes_offsets=[-3, -2]: two trailing axes of complex data\n    return _dft(x, axes_offsets=[-3, -2], inverse=False)\n\n\ndef ifft2(x):\n    # Same axes as fft2 but with the inverse DFT\n    return _dft(x, axes_offsets=[-3, -2], inverse=True)\n\n\ndef rfft(x, fft_length=None):\n    ori_dtype = x.dtype\n    x = cast(x, \"float32\") if x.dtype == \"float64\" else x\n\n    x_node = get_ov_output(x)\n    rank = len(x_node.shape)\n    axes = ov_opset.constant([rank - 1], Type.i32).output(0)\n\n 
   if fft_length is not None:\n        signal_size = ov_opset.constant([fft_length], Type.i32).output(0)\n        # Pad input if signal_size > input_size (OpenVINO limitation)\n        last_dim = x_node.shape[-1]\n        if isinstance(last_dim, int) and last_dim < fft_length:\n            pad_begin = [0] * rank\n            pad_end = [0] * rank\n            pad_end[-1] = fft_length - last_dim\n            pad_begin_node = ov_opset.constant(pad_begin, Type.i32).output(0)\n            pad_end_node = ov_opset.constant(pad_end, Type.i32).output(0)\n            x_node = ov_opset.pad(\n                x_node, pad_begin_node, pad_end_node, \"constant\"\n            ).output(0)\n\n        rdft = ov_opset.rdft(x_node, axes, signal_size).output(0)\n    else:\n        rdft = ov_opset.rdft(x_node, axes).output(0)\n\n    out_real = ov_opset.gather(\n        rdft, ov_opset.constant(0, Type.i32), ov_opset.constant(-1, Type.i32)\n    ).output(0)\n    out_imag = ov_opset.gather(\n        rdft, ov_opset.constant(1, Type.i32), ov_opset.constant(-1, Type.i32)\n    ).output(0)\n\n    if ori_dtype == \"float64\":\n        out_real = ov_opset.convert(out_real, Type.f64).output(0)\n        out_imag = ov_opset.convert(out_imag, Type.f64).output(0)\n\n    return OpenVINOKerasTensor(out_real), OpenVINOKerasTensor(out_imag)\n\n\ndef irfft(x, fft_length=None):\n    ori_dtype = x[0].dtype\n    if ori_dtype == \"float64\":\n        x = (cast(x[0], \"float32\"), cast(x[1], \"float32\"))\n\n    complex_data = get_ov_output(stack(x, axis=-1))\n    rank = len(complex_data.shape)\n    axes = ov_opset.constant([rank - 2], Type.i32).output(0)\n\n    if fft_length is not None:\n        signal_size = ov_opset.constant([fft_length], Type.i32).output(0)\n        irdft = ov_opset.irdft(complex_data, axes, signal_size).output(0)\n    else:\n        irdft = ov_opset.irdft(complex_data, axes).output(0)\n\n    if ori_dtype == \"float64\":\n        irdft = ov_opset.convert(irdft, Type.f64).output(0)\n\n    
return OpenVINOKerasTensor(irdft)\n\n\ndef stft(\n    x, sequence_length, sequence_stride, fft_length, window=\"hann\", center=True\n):\n    if standardize_dtype(x.dtype) not in {\"float32\", \"float64\"}:\n        raise TypeError(\n            \"Invalid input type. Expected `float32` or `float64`. \"\n            f\"Received: input type={x.dtype}\"\n        )\n    if fft_length < sequence_length:\n        raise ValueError(\n            \"`fft_length` must be equal to or larger than `sequence_length`. \"\n            f\"Received: sequence_length={sequence_length}, \"\n            f\"fft_length={fft_length}\"\n        )\n    if isinstance(window, str):\n        if window not in {\"hann\", \"hamming\"}:\n            raise ValueError(\n                \"If a string is passed to `window`, it must be one of \"\n                f'`\"hann\"`, `\"hamming\"`. Received: window={window}'\n            )\n\n    ori_dtype = x.dtype\n    x = get_ov_output(x)\n\n    ori_shape = x.shape\n    num_dims = len(ori_shape)\n\n    if num_dims > 2:\n        flatten_shape = ov_opset.constant([-1, ori_shape[-1]], Type.i32).output(\n            0\n        )\n        x = ov_opset.reshape(x, flatten_shape, False).output(0)\n\n    if center:\n        # pad x with reflect mode\n        pad_begin = [0] * len(x.shape)\n        pad_end = [0] * len(x.shape)\n        pad_begin[-1] = fft_length // 2\n        pad_end[-1] = fft_length // 2\n        pad_begin_node = ov_opset.constant(pad_begin, Type.i32).output(0)\n        pad_end_node = ov_opset.constant(pad_end, Type.i32).output(0)\n        x = ov_opset.pad(x, pad_begin_node, pad_end_node, \"reflect\").output(0)\n\n    l_pad = (fft_length - sequence_length) // 2\n    r_pad = fft_length - sequence_length - l_pad\n\n    element_type = x.get_element_type()\n    if element_type == Type.f64:\n        x = ov_opset.convert(x, Type.f32).output(0)\n        element_type = Type.f32\n\n    if window is not None:\n        if isinstance(window, str):\n            win = 
scipy.signal.get_window(window, sequence_length)\n        else:\n            win = window\n        if len(win.shape) != 1 or win.shape[-1] != sequence_length:\n            raise ValueError(\n                \"The shape of `window` must be equal to [sequence_length]. \"\n                f\"Received: window shape={win.shape}\"\n            )\n        win = np.pad(win, [[l_pad, r_pad]])\n        win_node = ov_opset.constant(win, element_type).output(0)\n    else:\n        win = np.ones((sequence_length + l_pad + r_pad))\n        win_node = ov_opset.constant(win, element_type).output(0)\n\n    frame_size_node = ov_opset.constant(fft_length, Type.i32).output(0)\n    frame_step_node = ov_opset.constant(sequence_stride, Type.i32).output(0)\n\n    stft_node = ov_opset.stft(\n        x, win_node, frame_size_node, frame_step_node, transpose_frames=False\n    ).output(0)\n\n    out_real = ov_opset.gather(\n        stft_node,\n        ov_opset.constant(0, Type.i32),\n        ov_opset.constant(-1, Type.i32),\n    ).output(0)\n    out_imag = ov_opset.gather(\n        stft_node,\n        ov_opset.constant(1, Type.i32),\n        ov_opset.constant(-1, Type.i32),\n    ).output(0)\n\n    if num_dims > 2:\n        target_shape = list(ori_shape[:-1]) + [-1, fft_length // 2 + 1]\n        target_shape_node = ov_opset.constant(target_shape, Type.i32).output(0)\n        out_real = ov_opset.reshape(out_real, target_shape_node, False).output(\n            0\n        )\n        out_imag = ov_opset.reshape(out_imag, target_shape_node, False).output(\n            0\n        )\n\n    if ori_dtype == \"float64\":\n        out_real = ov_opset.convert(out_real, Type.f64).output(0)\n        out_imag = ov_opset.convert(out_imag, Type.f64).output(0)\n\n    return OpenVINOKerasTensor(out_real), OpenVINOKerasTensor(out_imag)\n\n\ndef _overlap_sequences_ov(x, sequence_stride, fft_length):\n    \"\"\"Perform overlap-and-add using OpenVINO ops.\n\n    Takes a tensor x of shape [batch, num_sequences, 
fft_length] and\n    reconstructs a time-domain signal by striding each frame by\n    `sequence_stride` and summing the overlapping contributions.\n\n    Returns the reconstructed signal and its length as an OV scalar node.\n    \"\"\"\n    nstep = 1 + (fft_length - 1) // sequence_stride\n    padded_len = nstep * sequence_stride\n    pad_amount = padded_len - fft_length\n\n    zero = ov_opset.constant(0, Type.i32).output(0)\n\n    # Extract dynamic batch size and number of sequences from input shape.\n    x_shape = ov_opset.shape_of(x, output_type=Type.i32).output(0)\n    flat_batch = ov_opset.gather(\n        x_shape, ov_opset.constant(0, Type.i32), zero\n    ).output(0)\n    num_sequences = ov_opset.gather(\n        x_shape, ov_opset.constant(1, Type.i32), zero\n    ).output(0)\n\n    # Compute expected output length: (num_sequences - 1) * stride + fft_length\n    output_size = ov_opset.add(\n        ov_opset.multiply(\n            ov_opset.constant(sequence_stride, Type.i32),\n            ov_opset.subtract(num_sequences, ov_opset.constant(1, Type.i32)),\n        ),\n        ov_opset.constant(fft_length, Type.i32),\n    ).output(0)\n\n    # Pad each frame along the last axis so its length is a multiple of stride.\n    if pad_amount > 0:\n        pad_begin = ov_opset.constant([0, 0, 0], Type.i32).output(0)\n        pad_end = ov_opset.constant([0, 0, pad_amount], Type.i32).output(0)\n        x = ov_opset.pad(x, pad_begin, pad_end, \"constant\").output(0)\n\n    # Reshape to [batch, num_sequences, nstep, sequence_stride] to expose\n    # the overlap structure within each frame.\n    overlap_shape = ov_opset.concat(\n        [\n            ov_opset.unsqueeze(\n                flat_batch, ov_opset.constant(0, Type.i32)\n            ).output(0),\n            ov_opset.unsqueeze(\n                num_sequences, ov_opset.constant(0, Type.i32)\n            ).output(0),\n            ov_opset.constant([nstep, sequence_stride], Type.i32).output(0),\n        ],\n        0,\n   
 ).output(0)\n    x = ov_opset.reshape(x, overlap_shape, False).output(0)\n\n    # Transpose to [batch, nstep, num_sequences, sequence_stride] so that\n    # overlapping frames become adjacent along axis 2.\n    x = ov_opset.transpose(x, ov_opset.constant([0, 2, 1, 3], Type.i32)).output(\n        0\n    )\n\n    # Pad num_sequences axis by num_sequences to create interleaved zeros,\n    # enabling a flat reshape that places each frame at its strided offset.\n    pad_begin_n = ov_opset.constant([0, 0, 0, 0], Type.i32).output(0)\n    pad_end_n = ov_opset.concat(\n        [\n            ov_opset.constant([0, 0], Type.i32).output(0),\n            ov_opset.unsqueeze(\n                num_sequences, ov_opset.constant(0, Type.i32)\n            ).output(0),\n            ov_opset.constant([0], Type.i32).output(0),\n        ],\n        0,\n    ).output(0)\n    x = ov_opset.pad(x, pad_begin_n, pad_end_n, \"constant\").output(0)\n\n    # overlapping_dim_size = 2 * num_sequences - 1: the number of distinct\n    # stride-aligned positions covered by all frames after interleaving.\n    overlapping_dim_size = ov_opset.subtract(\n        ov_opset.multiply(ov_opset.constant(2, Type.i32), num_sequences),\n        ov_opset.constant(1, Type.i32),\n    ).output(0)\n\n    # Flatten to [batch, nstep * 2 * sequence_stride * num_sequences] so the\n    # interleaved frame data forms a contiguous sequence.\n    total_inner = ov_opset.multiply(\n        ov_opset.constant(nstep * 2 * sequence_stride, Type.i32),\n        num_sequences,\n    ).output(0)\n    flatten_shape = ov_opset.concat(\n        [\n            ov_opset.unsqueeze(\n                flat_batch, ov_opset.constant(0, Type.i32)\n            ).output(0),\n            ov_opset.unsqueeze(\n                total_inner, ov_opset.constant(0, Type.i32)\n            ).output(0),\n        ],\n        0,\n    ).output(0)\n    x = ov_opset.reshape(x, flatten_shape, False).output(0)\n\n    # Slice away the trailing padding introduced by the 
interleaving step.\n    slice_len = ov_opset.multiply(\n        overlapping_dim_size,\n        ov_opset.constant(nstep * sequence_stride, Type.i32),\n    ).output(0)\n    x = ov_opset.slice(\n        x,\n        ov_opset.constant([0], Type.i32).output(0),\n        ov_opset.unsqueeze(slice_len, ov_opset.constant(0, Type.i32)).output(0),\n        ov_opset.constant([1], Type.i32).output(0),\n        ov_opset.constant([1], Type.i32).output(0),\n    ).output(0)\n\n    # Reshape to [batch, nstep, overlapping_dim_size * sequence_stride] and\n    # reduce-sum over the nstep axis to accumulate overlapping frame values.\n    inner_size = ov_opset.multiply(\n        overlapping_dim_size, ov_opset.constant(sequence_stride, Type.i32)\n    ).output(0)\n    sum_shape = ov_opset.concat(\n        [\n            ov_opset.unsqueeze(\n                flat_batch, ov_opset.constant(0, Type.i32)\n            ).output(0),\n            ov_opset.constant([nstep], Type.i32).output(0),\n            ov_opset.unsqueeze(\n                inner_size, ov_opset.constant(0, Type.i32)\n            ).output(0),\n        ],\n        0,\n    ).output(0)\n    x = ov_opset.reshape(x, sum_shape, False).output(0)\n\n    x = ov_opset.reduce_sum(\n        x,\n        ov_opset.constant([1], Type.i32).output(0),\n        keep_dims=False,\n    ).output(0)\n\n    # Trim the result to the expected signal length.\n    x = ov_opset.slice(\n        x,\n        ov_opset.constant([0], Type.i32).output(0),\n        ov_opset.unsqueeze(output_size, ov_opset.constant(0, Type.i32)).output(\n            0\n        ),\n        ov_opset.constant([1], Type.i32).output(0),\n        ov_opset.constant([1], Type.i32).output(0),\n    ).output(0)\n\n    return x, output_size\n\n\ndef istft(\n    x,\n    sequence_length,\n    sequence_stride,\n    fft_length,\n    length=None,\n    window=\"hann\",\n    center=True,\n):\n    if isinstance(window, str):\n        if window not in {\"hann\", \"hamming\"}:\n            raise ValueError(\n 
               \"If a string is passed to `window`, it must be one of \"\n                f'`\"hann\"`, `\"hamming\"`. Received: window={window}'\n            )\n\n    ori_dtype = x[0].dtype\n\n    x0 = get_ov_output(x[0])\n    ori_partial_shape = x0.get_partial_shape()\n    num_dims = ori_partial_shape.rank.get_length()\n    ori_shape_list = [\n        None if dim.is_dynamic else dim.get_length()\n        for dim in ori_partial_shape\n    ]\n\n    l_pad = (fft_length - sequence_length) // 2\n    r_pad = fft_length - sequence_length - l_pad\n\n    if window is not None:\n        if isinstance(window, str):\n            win = scipy.signal.get_window(window, sequence_length)\n        else:\n            win = np.asarray(window, dtype=np.float64)\n        if len(win.shape) != 1 or win.shape[-1] != sequence_length:\n            raise ValueError(\n                \"The shape of `window` must be equal to [sequence_length]. \"\n                f\"Received: window shape={win.shape}\"\n            )\n        win = np.pad(win, [[l_pad, r_pad]])\n\n        denom = np.square(win)\n        overlaps = -(-fft_length // sequence_stride)\n        denom = np.pad(denom, [(0, overlaps * sequence_stride - fft_length)])\n        denom = denom.reshape([overlaps, sequence_stride])\n        denom = denom.sum(axis=0, keepdims=True)\n        denom = np.tile(denom, [overlaps, 1])\n        denom = denom.reshape([overlaps * sequence_stride])\n        win = win / denom[:fft_length]\n    else:\n        win = None\n\n    frames = irfft(x, fft_length)\n    frames = get_ov_output(frames)\n\n    element_type = frames.get_element_type()\n    if element_type == Type.f64:\n        frames = ov_opset.convert(frames, Type.f32).output(0)\n        element_type = Type.f32\n\n    if win is not None:\n        win_node = ov_opset.constant(win.astype(np.float32), Type.f32).output(0)\n        if element_type != Type.f32:\n            win_node = ov_opset.convert(win_node, element_type).output(0)\n        frames = 
ov_opset.multiply(frames, win_node).output(0)\n\n    if num_dims == 2:\n        frames = ov_opset.unsqueeze(\n            frames, ov_opset.constant(0, Type.i32)\n        ).output(0)\n    elif num_dims > 2:\n        frames_shp = ov_opset.shape_of(frames, output_type=Type.i32).output(0)\n        num_seq_node = ov_opset.gather(\n            frames_shp,\n            ov_opset.constant(num_dims - 2, Type.i32),\n            ov_opset.constant(0, Type.i32),\n        ).output(0)\n        flatten_shp = ov_opset.concat(\n            [\n                ov_opset.constant([-1], Type.i32).output(0),\n                ov_opset.unsqueeze(\n                    num_seq_node, ov_opset.constant(0, Type.i32)\n                ).output(0),\n                ov_opset.constant([fft_length], Type.i32).output(0),\n            ],\n            0,\n        ).output(0)\n        frames = ov_opset.reshape(frames, flatten_shp, False).output(0)\n\n    frames, output_size = _overlap_sequences_ov(\n        frames, sequence_stride, fft_length\n    )\n\n    start_val = fft_length // 2 if center else 0\n\n    if length is not None:\n        frames = ov_opset.slice(\n            frames,\n            ov_opset.constant([start_val], Type.i32).output(0),\n            ov_opset.constant([start_val + length], Type.i32).output(0),\n            ov_opset.constant([1], Type.i32).output(0),\n            ov_opset.constant([1], Type.i32).output(0),\n        ).output(0)\n    else:\n        if start_val > 0:\n            frames = ov_opset.slice(\n                frames,\n                ov_opset.constant([start_val], Type.i32).output(0),\n                ov_opset.constant([INT32_MAX], Type.i32).output(0),\n                ov_opset.constant([1], Type.i32).output(0),\n                ov_opset.constant([1], Type.i32).output(0),\n            ).output(0)\n        if center:\n            cur_len = ov_opset.gather(\n                ov_opset.shape_of(frames, output_type=Type.i32).output(0),\n                ov_opset.constant(1, 
Type.i32),\n                ov_opset.constant(0, Type.i32),\n            ).output(0)\n            end_node = ov_opset.subtract(\n                cur_len,\n                ov_opset.constant(fft_length // 2, Type.i32),\n            ).output(0)\n            frames = ov_opset.slice(\n                frames,\n                ov_opset.constant([0], Type.i32).output(0),\n                ov_opset.unsqueeze(\n                    end_node, ov_opset.constant(0, Type.i32)\n                ).output(0),\n                ov_opset.constant([1], Type.i32).output(0),\n                ov_opset.constant([1], Type.i32).output(0),\n            ).output(0)\n\n    if num_dims == 2:\n        frames = ov_opset.squeeze(\n            frames, ov_opset.constant([0], Type.i32)\n        ).output(0)\n    elif num_dims > 2:\n        batch_dims = ori_shape_list[:-2]\n        target_shape = [d if d is not None else -1 for d in batch_dims] + [-1]\n        target_shape_node = ov_opset.constant(target_shape, Type.i32).output(0)\n        frames = ov_opset.reshape(frames, target_shape_node, False).output(0)\n\n    if ori_dtype == \"float64\":\n        frames = ov_opset.convert(frames, Type.f64).output(0)\n\n    return OpenVINOKerasTensor(frames)\n\n\ndef rsqrt(x):\n    x = get_ov_output(x)\n    const_one = ov_opset.constant(1, x.get_element_type()).output(0)\n    sqrt = ov_opset.sqrt(x).output(0)\n    return OpenVINOKerasTensor(ov_opset.divide(const_one, sqrt).output(0))\n\n\ndef erf(x):\n    x = get_ov_output(x)\n    erf = ov_opset.erf(x).output(0)\n    return OpenVINOKerasTensor(erf)\n\n\ndef erfinv(x):\n    # TODO: Float64 infinity values are clamped on CPU backend,\n    # breaking erfinv(±1) = ±inf\n    # See https://github.com/openvinotoolkit/openvino/issues/34138\n    # Tests excluded: test_erfinv_operation_basic, test_erfinv_operation_dtype\n    x = get_ov_output(x)\n    dtype = x.get_element_type()\n\n    a = 0.147\n    two_over_pi_a = 2.0 / (np.pi * a)\n    two_over_sqrt_pi = 2.0 / 
np.sqrt(np.pi)\n\n    one = ov_opset.constant(1.0, dtype).output(0)\n    half = ov_opset.constant(0.5, dtype).output(0)\n\n    x_sq = ov_opset.multiply(x, x).output(0)\n    log_term = ov_opset.log(ov_opset.subtract(one, x_sq).output(0)).output(0)\n\n    k = ov_opset.add(\n        ov_opset.constant(two_over_pi_a, dtype).output(0),\n        ov_opset.multiply(half, log_term).output(0),\n    ).output(0)\n\n    inner = ov_opset.subtract(\n        ov_opset.multiply(k, k).output(0),\n        ov_opset.multiply(\n            ov_opset.constant(1.0 / a, dtype).output(0), log_term\n        ).output(0),\n    ).output(0)\n\n    y0 = ov_opset.multiply(\n        ov_opset.sign(x).output(0),\n        ov_opset.sqrt(\n            ov_opset.subtract(ov_opset.sqrt(inner).output(0), k).output(0)\n        ).output(0),\n    ).output(0)\n\n    erf_err = ov_opset.subtract(ov_opset.erf(y0).output(0), x).output(0)\n\n    y0_sq = ov_opset.multiply(y0, y0).output(0)\n    exp_term = ov_opset.exp(ov_opset.negative(y0_sq).output(0)).output(0)\n    deriv = ov_opset.multiply(\n        ov_opset.constant(two_over_sqrt_pi, dtype).output(0),\n        exp_term,\n    ).output(0)\n    y1 = ov_opset.subtract(\n        y0, ov_opset.divide(erf_err, deriv).output(0)\n    ).output(0)\n\n    return OpenVINOKerasTensor(y1)\n"
  },
  {
    "path": "keras/src/backend/openvino/nn.py",
    "content": "import numpy as np\nimport openvino.opset15 as ov_opset\nfrom openvino import Type\n\nimport keras.src.backend.openvino.numpy as onp\nfrom keras.src import backend\nfrom keras.src.backend.common.backend_utils import (\n    _get_output_shape_given_tf_padding,\n)\nfrom keras.src.backend.openvino.core import OPENVINO_DTYPES\nfrom keras.src.backend.openvino.core import OpenVINOKerasTensor\nfrom keras.src.backend.openvino.core import get_ov_output\n\n\ndef relu(x):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.relu(x).output(0))\n\n\ndef relu6(x):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.clamp(x, 0.0, 6.0).output(0))\n\n\ndef celu(x, alpha=1.0):\n    x = get_ov_output(x)\n    const_zero = get_ov_output(0.0, x.get_element_type())\n    const_alpha = get_ov_output(alpha, x.get_element_type())\n    const_one = get_ov_output(1.0, x.get_element_type())\n    exp_x_div_alpha = ov_opset.exp(ov_opset.divide(x, const_alpha)).output(0)\n    negative_branch = ov_opset.multiply(\n        const_alpha, ov_opset.subtract(exp_x_div_alpha, const_one)\n    )\n\n    celu_x = ov_opset.add(\n        ov_opset.maximum(x, const_zero).output(0),\n        ov_opset.minimum(negative_branch, const_zero).output(0),\n    )\n    return OpenVINOKerasTensor(celu_x.output(0))\n\n\ndef sigmoid(x):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.sigmoid(x).output(0))\n\n\ndef tanh(x):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.tanh(x).output(0))\n\n\ndef tanh_shrink(x):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.subtract(x, ov_opset.tanh(x)).output(0))\n\n\ndef hard_tanh(x):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.clamp(x, -1.0, 1.0).output(0))\n\n\ndef soft_shrink(x, threshold=0.5):\n    x = get_ov_output(x)\n    et = x.get_element_type()\n    thr = get_ov_output(threshold, et)\n    zero = get_ov_output(0.0, et)\n    abs_x = ov_opset.abs(x)\n    sub = 
ov_opset.subtract(abs_x, thr)\n    shrunk = ov_opset.maximum(sub, zero)\n    sign = ov_opset.sign(x)\n    out = ov_opset.multiply(sign, shrunk)\n    return OpenVINOKerasTensor(out.output(0))\n\n\ndef hard_shrink(x, threshold=0.5):\n    x = get_ov_output(x)\n    et = x.get_element_type()\n    thr = get_ov_output(threshold, et)\n    zero = get_ov_output(0.0, et)\n    cond = ov_opset.greater(ov_opset.abs(x), thr)\n    out = ov_opset.select(cond, x, zero)\n    return OpenVINOKerasTensor(out.output(0))\n\n\ndef softplus(x):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.softplus(x).output(0))\n\n\ndef softsign(x):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.softsign(x).output(0))\n\n\ndef silu(x):\n    x = get_ov_output(x)\n    beta = get_ov_output(1.0, x.get_element_type())\n    return OpenVINOKerasTensor(ov_opset.swish(x, beta=beta).output(0))\n\n\ndef log_sigmoid(x):\n    x = get_ov_output(x)\n    neg_x = ov_opset.negative(x)\n    return OpenVINOKerasTensor(\n        ov_opset.negative(ov_opset.softplus(neg_x)).output(0)\n    )\n\n\ndef leaky_relu(x, negative_slope=0.2):\n    x = get_ov_output(x)\n    slope_const = ov_opset.constant(\n        negative_slope, x.get_element_type()\n    ).output(0)\n    leaky_relu = ov_opset.prelu(x, slope_const).output(0)\n    return OpenVINOKerasTensor(leaky_relu)\n\n\ndef sparse_sigmoid(x):\n    x = get_ov_output(x)\n    et = x.get_element_type()\n    one = get_ov_output(1.0, et)\n    neg_one = get_ov_output(-1.0, et)\n    half = get_ov_output(0.5, et)\n    y = ov_opset.minimum(ov_opset.maximum(x, neg_one), one)\n    out = ov_opset.multiply(half, ov_opset.add(y, one))\n    return OpenVINOKerasTensor(out.output(0))\n\n\ndef hard_sigmoid(x):\n    x = get_ov_output(x)\n    alpha = get_ov_output(1.0 / 6.0, x.get_element_type())\n    beta = get_ov_output(0.5, x.get_element_type())\n    return OpenVINOKerasTensor(ov_opset.hard_sigmoid(x, alpha, beta).output(0))\n\n\ndef hard_silu(x):\n    
hard_sigmoid_output = get_ov_output(hard_sigmoid(x))\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(\n        ov_opset.multiply(x, hard_sigmoid_output).output(0)\n    )\n\n\ndef elu(x, alpha=1.0):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.elu(x, alpha).output(0))\n\n\ndef selu(x):\n    alpha = 1.6732632423543772848170429916717\n    scale = 1.0507009873554804934193349852946\n    x = get_ov_output(x)\n    alpha = get_ov_output(alpha, x.get_element_type())\n    scale = get_ov_output(scale, x.get_element_type())\n    return OpenVINOKerasTensor(ov_opset.selu(x, alpha, scale).output(0))\n\n\ndef gelu(x, approximate=True):\n    x = get_ov_output(x)\n    approximate_mode = \"erf\"\n    if approximate:\n        approximate_mode = \"tanh\"\n    return OpenVINOKerasTensor(ov_opset.gelu(x, approximate_mode).output(0))\n\n\ndef softmax(x, axis=-1):\n    x = get_ov_output(x)\n    if axis is None:\n        x_shape = ov_opset.shape_of(x)\n        flatten_shape = ov_opset.constant([-1], Type.i32).output(0)\n        flatten_x = ov_opset.reshape(x, flatten_shape, False).output(0)\n        softmax_x = ov_opset.softmax(flatten_x, 0).output(0)\n        return OpenVINOKerasTensor(\n            ov_opset.reshape(softmax_x, x_shape, False).output(0)\n        )\n    if isinstance(axis, (tuple, list)):\n        if not axis:\n            return OpenVINOKerasTensor(x)\n        axes_const = ov_opset.constant(sorted(axis), Type.i32).output(0)\n        x_max = ov_opset.reduce_max(x, axes_const, True).output(0)\n        exp_x = ov_opset.exp(ov_opset.subtract(x, x_max).output(0)).output(0)\n        sum_exp = ov_opset.reduce_sum(exp_x, axes_const, True).output(0)\n        return OpenVINOKerasTensor(ov_opset.divide(exp_x, sum_exp).output(0))\n    return OpenVINOKerasTensor(ov_opset.softmax(x, axis).output(0))\n\n\ndef log_softmax(x, axis=-1):\n    x = get_ov_output(x)\n    if isinstance(axis, (tuple, list)) and not axis:\n        return OpenVINOKerasTensor(x)\n    
restore_shape = None\n    if axis is None:\n        restore_shape = ov_opset.shape_of(x)\n        flatten_shape = ov_opset.constant([-1], Type.i32).output(0)\n        x = ov_opset.reshape(x, flatten_shape, False).output(0)\n        axes = [0]\n    elif isinstance(axis, (tuple, list)):\n        axes = sorted(axis)\n    else:\n        axes = [axis]\n    axes_const = ov_opset.constant(axes, Type.i32).output(0)\n    x_max = ov_opset.reduce_max(x, axes_const, True).output(0)\n    x_shifted = ov_opset.subtract(x, x_max).output(0)\n    sum_exp = ov_opset.reduce_sum(\n        ov_opset.exp(x_shifted).output(0), axes_const, True\n    ).output(0)\n    log_sum_exp = ov_opset.log(sum_exp).output(0)\n    result = ov_opset.subtract(x_shifted, log_sum_exp).output(0)\n    if restore_shape is not None:\n        result = ov_opset.reshape(result, restore_shape, False).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef squareplus(x, b=4):\n    x = get_ov_output(x)\n    et = x.get_element_type()\n    b = get_ov_output(b, et)\n    two = get_ov_output(2.0, et)\n    x_squared = ov_opset.multiply(x, x)\n    inside = ov_opset.add(x_squared, b)\n    root = ov_opset.sqrt(inside)\n    summed = ov_opset.add(x, root)\n    out = ov_opset.divide(summed, two)\n    return OpenVINOKerasTensor(out.output(0))\n\n\ndef sparse_plus(x):\n    x = get_ov_output(x)\n    et = x.get_element_type()\n    one = get_ov_output(1.0, et)\n    neg_one = get_ov_output(-1.0, et)\n    zero = get_ov_output(0.0, et)\n    quarter = get_ov_output(0.25, et)\n    x_plus_1 = ov_opset.add(x, one)\n    quad = ov_opset.multiply(quarter, ov_opset.multiply(x_plus_1, x_plus_1))\n    leq_than_neg_one = ov_opset.less_equal(x, neg_one)\n    less_than_one = ov_opset.less(x, one)\n    out = ov_opset.select(\n        leq_than_neg_one,\n        zero,\n        ov_opset.select(less_than_one, quad, x),\n    )\n    return OpenVINOKerasTensor(out.output(0))\n\n\ndef threshold(x, threshold, default_value):\n    x = get_ov_output(x)\n    et 
= x.get_element_type()\n    thr = get_ov_output(threshold, et)\n    dv = get_ov_output(default_value, et)\n    cond = ov_opset.greater(x, thr)\n    out = ov_opset.select(cond, x, dv)\n    return OpenVINOKerasTensor(out.output(0))\n\n\ndef max_pool(\n    inputs,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n):\n    num_spatial_dims = (\n        get_ov_output(inputs).get_partial_shape().rank.get_length() - 2\n    )\n    kwargs = {\"dilations\": [1] * num_spatial_dims}  # required for ov max_pool\n    return _pool(\n        inputs,\n        pool_size,\n        ov_opset.max_pool,\n        strides,\n        padding,\n        data_format,\n        **kwargs,\n    )\n\n\ndef average_pool(\n    inputs,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n):\n    return _pool(\n        inputs,\n        pool_size,\n        ov_opset.avg_pool,\n        strides,\n        padding,\n        data_format,\n        exclude_pad=True,\n    )\n\n\ndef _compute_adaptive_gather_indices(\n    input_dim, output_size, small_window, big_window\n):\n    \"\"\"Compute gather indices for the two-pool gather method.\"\"\"\n    window_starts = np.floor(\n        np.arange(output_size) * input_dim / output_size\n    ).astype(np.int32)\n    window_ends = np.minimum(\n        np.ceil(np.arange(1, output_size + 1) * input_dim / output_size).astype(\n            np.int32\n        ),\n        input_dim,\n    )\n    window_starts = np.minimum(window_starts, input_dim - 1)\n    window_sizes = window_ends - window_starts\n    small_pool_len = max(1, input_dim - small_window + 1)\n    return np.where(\n        window_sizes == big_window,\n        window_starts + small_pool_len,\n        window_starts,\n    ).tolist()\n\n\ndef _adaptive_pool_ov(\n    inputs, output_size, pool_type, data_format, num_spatial_dims\n):\n    \"\"\"Shared OpenVINO implementation for adaptive average/max pooling.\n\n    Uses the two-pool gather method: for each 
spatial axis independently,\n    apply pooling with the small and big kernel sizes (stride=1, VALID),\n    concatenate the results, then gather the correct output positions.\n    \"\"\"\n    if isinstance(output_size, int):\n        output_size = (output_size,) * num_spatial_dims\n\n    data_format = backend.standardize_data_format(data_format)\n    inputs = get_ov_output(inputs)\n\n    current = _adjust_input(inputs, num_spatial_dims, data_format)\n\n    for spatial_idx in range(num_spatial_dims):\n        gather_axis = 2 + spatial_idx\n        current_ps = current.get_partial_shape()\n        input_dim = current_ps[gather_axis].get_length()\n        output_dim = output_size[spatial_idx]\n\n        if input_dim == output_dim:\n            continue\n\n        small_w = int(np.ceil(input_dim / output_dim))\n        big_w = small_w + 1\n        gather_indices = _compute_adaptive_gather_indices(\n            input_dim, output_dim, small_w, big_w\n        )\n\n        strides = [1] * num_spatial_dims\n\n        small_kernel = [1] * num_spatial_dims\n        small_kernel[spatial_idx] = small_w\n\n        if pool_type == \"avg\":\n            small_pool = ov_opset.avg_pool(\n                current,\n                strides=strides,\n                pads_begin=[],\n                pads_end=[],\n                kernel_shape=small_kernel,\n                exclude_pad=True,\n                auto_pad=\"VALID\",\n            ).output(0)\n        else:\n            small_pool = ov_opset.max_pool(\n                current,\n                strides=strides,\n                dilations=[1] * num_spatial_dims,\n                pads_begin=[],\n                pads_end=[],\n                kernel_shape=small_kernel,\n                auto_pad=\"VALID\",\n            ).output(0)\n\n        if big_w <= input_dim:\n            big_kernel = [1] * num_spatial_dims\n            big_kernel[spatial_idx] = big_w\n\n            if pool_type == \"avg\":\n                big_pool = 
ov_opset.avg_pool(\n                    current,\n                    strides=strides,\n                    pads_begin=[],\n                    pads_end=[],\n                    kernel_shape=big_kernel,\n                    exclude_pad=True,\n                    auto_pad=\"VALID\",\n                ).output(0)\n            else:\n                big_pool = ov_opset.max_pool(\n                    current,\n                    strides=strides,\n                    dilations=[1] * num_spatial_dims,\n                    pads_begin=[],\n                    pads_end=[],\n                    kernel_shape=big_kernel,\n                    auto_pad=\"VALID\",\n                ).output(0)\n\n            combined = ov_opset.concat(\n                [small_pool, big_pool], gather_axis\n            ).output(0)\n        else:\n            # big_w > input_dim: the big pool produces no outputs and all\n            # gather indices come from the small pool, so skip the big pool.\n            combined = small_pool\n\n        indices_node = ov_opset.constant(gather_indices, Type.i32).output(0)\n        axis_node = ov_opset.constant(gather_axis, Type.i32).output(0)\n        current = ov_opset.gather(combined, indices_node, axis_node).output(0)\n\n    result = _adjust_outputs(current, num_spatial_dims, data_format)\n    return OpenVINOKerasTensor(result)\n\n\ndef adaptive_average_pool(inputs, output_size, data_format=None):\n    inputs_ov = get_ov_output(inputs)\n    num_spatial_dims = inputs_ov.get_partial_shape().rank.get_length() - 2\n    if num_spatial_dims not in (1, 2, 3):\n        raise ValueError(\n            \"adaptive_average_pool supports 1D, 2D, or 3D inputs only.\"\n        )\n    return _adaptive_pool_ov(\n        inputs, output_size, \"avg\", data_format, num_spatial_dims\n    )\n\n\ndef adaptive_max_pool(inputs, output_size, data_format=None):\n    inputs_ov = get_ov_output(inputs)\n    num_spatial_dims = inputs_ov.get_partial_shape().rank.get_length() - 2\n    if 
num_spatial_dims not in (1, 2, 3):\n        raise ValueError(\n            \"adaptive_max_pool supports 1D, 2D, or 3D inputs only.\"\n        )\n    return _adaptive_pool_ov(\n        inputs, output_size, \"max\", data_format, num_spatial_dims\n    )\n\n\ndef _pool(\n    inputs,\n    pool_size,\n    pooling_func,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n    **kwargs,\n):\n    data_format = backend.standardize_data_format(data_format)\n    inputs = get_ov_output(inputs)\n\n    num_spatial_dims = inputs.get_partial_shape().rank.get_length() - 2\n    if isinstance(pool_size, int):\n        pool_size = [pool_size] * num_spatial_dims\n\n    if strides is None:\n        strides = pool_size\n\n    strides = _adjust_strides_dilation(strides, num_spatial_dims)\n    pad_mode, pads_begin, pads_end = _adjust_padding(padding)\n    inputs = _adjust_input(inputs, num_spatial_dims, data_format)\n    pool_kwargs = {\n        \"kernel_shape\": pool_size,\n        \"strides\": strides,\n        \"auto_pad\": pad_mode,\n        \"pads_begin\": pads_begin,\n        \"pads_end\": pads_end,\n        **kwargs,\n    }\n    pooled = pooling_func(inputs, **pool_kwargs).output(0)\n    adjusted_pooled = _adjust_outputs(pooled, num_spatial_dims, data_format)\n    return OpenVINOKerasTensor(adjusted_pooled)\n\n\ndef _adjust_strides_dilation(\n    x,\n    num_spatial_dims,\n):\n    # Helper function that converts an operand to a spatial operand.\n    x = (x,) * num_spatial_dims if isinstance(x, int) else x\n    # OpenVINO expects input in NCHW layout\n    # x = [1, 1] + list(x)\n    x = list(x)\n    return x\n\n\ndef _adjust_padding(\n    padding,\n):\n    padding = padding.lower() if isinstance(padding, str) else padding\n    if padding == \"same\":\n        return \"SAME_UPPER\", [], []\n    elif padding == \"same_lower\":\n        return \"SAME_LOWER\", [], []\n    elif padding == \"valid\":\n        return \"VALID\", [], []\n    pads_begin = []\n    pads_end = []\n   
 for padding_pair in padding:\n        pads_begin.append(padding_pair[0])\n        pads_end.append(padding_pair[1])\n    return \"EXPLICIT\", pads_begin, pads_end\n\n\ndef _adjust_input(inputs, num_spatial_dims, data_format):\n    if data_format == \"channels_first\":\n        return inputs\n    if num_spatial_dims == 1:\n        permutation = [0, 2, 1]\n    elif num_spatial_dims == 2:\n        permutation = [0, 3, 1, 2]\n    else:\n        permutation = [0, 4, 1, 2, 3]\n    permutation = ov_opset.constant(permutation, Type.i32)\n    return ov_opset.transpose(inputs, permutation).output(0)\n\n\ndef _adjust_kernel(kernel, num_spatial_dims):\n    if num_spatial_dims == 1:\n        permutation = [2, 1, 0]\n    elif num_spatial_dims == 2:\n        permutation = [3, 2, 0, 1]\n    else:\n        permutation = [4, 3, 0, 1, 2]\n    permutation = ov_opset.constant(permutation, Type.i32)\n    return ov_opset.transpose(kernel, permutation).output(0)\n\n\ndef _adjust_depthwise_kernel(kernel, num_spatial_dims):\n    # kernel layout: filter_H, filter_W, C_IN, Ch_mul\n    if num_spatial_dims == 1:\n        # kernel layout: filter_H, C_IN, Ch_mul\n        permutation = [1, 2, 0]\n    elif num_spatial_dims == 2:\n        # kernel layout: filter_H, filter_W, C_IN, Ch_mul\n        permutation = [2, 3, 0, 1]\n    else:\n        # kernel layout: filter_H, filter_W, filter_Z, C_IN, Ch_mul\n        permutation = [3, 4, 0, 1, 2]\n    permutation = ov_opset.constant(permutation, Type.i32)\n    return ov_opset.transpose(kernel, permutation).output(0)\n\n\ndef _adjust_outputs(outputs, num_spatial_dims, data_format):\n    if data_format == \"channels_first\":\n        return outputs\n    # convert a tensor from NCHW to NHWC layout\n    if num_spatial_dims == 1:\n        permutation = [0, 2, 1]\n    elif num_spatial_dims == 2:\n        permutation = [0, 2, 3, 1]\n    else:\n        permutation = [0, 2, 3, 4, 1]\n    permutation = ov_opset.constant(permutation, Type.i32)\n    return 
ov_opset.transpose(outputs, permutation).output(0)\n\n\ndef conv(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    inputs = get_ov_output(inputs)\n    kernel = get_ov_output(kernel)\n\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.get_partial_shape().rank.get_length() - 2\n\n    if data_format == \"channels_last\":\n        inputs_in_channels = inputs.get_partial_shape()[\n            2 + num_spatial_dims - 1\n        ]\n    else:\n        inputs_in_channels = inputs.get_partial_shape()[1]\n    kernel_in_channels = kernel.get_partial_shape()[-2]\n\n    strides = _adjust_strides_dilation(strides, num_spatial_dims)\n    dilation_rate = _adjust_strides_dilation(dilation_rate, num_spatial_dims)\n    pad_mode, pads_begin, pads_end = _adjust_padding(padding)\n    inputs = _adjust_input(inputs, num_spatial_dims, data_format)\n    kernel = _adjust_kernel(kernel, num_spatial_dims)\n\n    num_groups = (\n        inputs_in_channels.get_length() // kernel_in_channels.get_length()\n    )\n    if num_groups == 1:\n        conv = ov_opset.convolution(\n            inputs,\n            kernel,\n            strides,\n            pads_begin,\n            pads_end,\n            dilation_rate,\n            pad_mode,\n        )\n    else:\n        input_shape = ov_opset.shape_of(inputs).output(0)\n        filter_shape = ov_opset.shape_of(kernel).output(0)\n        zero_const = ov_opset.constant([0], Type.i32).output(0)\n        one_const = ov_opset.constant([1], Type.i32).output(0)\n        two_const = ov_opset.constant([2], Type.i32).output(0)\n        input_cin = ov_opset.slice(\n            input_shape, one_const, two_const, one_const\n        ).output(0)\n        filter_cin = ov_opset.slice(\n            filter_shape, one_const, two_const, one_const\n        ).output(0)\n        num_groups = ov_opset.divide(input_cin, filter_cin).output(0)\n\n        # reshape the 
filter based on the number of groups information\n        int_max_const = ov_opset.constant([2**31 - 1], Type.i32).output(0)\n        filter_cout = ov_opset.slice(\n            filter_shape, zero_const, one_const, one_const\n        ).output(0)\n        filter_new_cout = ov_opset.divide(filter_cout, num_groups).output(0)\n        shape_cin_xy = ov_opset.slice(\n            filter_shape, one_const, int_max_const, one_const\n        ).output(0)\n        filter_new_shape = ov_opset.concat(\n            [num_groups, filter_new_cout, shape_cin_xy], 0\n        ).output(0)\n        new_filter = ov_opset.reshape(kernel, filter_new_shape, False).output(0)\n        conv = ov_opset.group_convolution(\n            inputs,\n            new_filter,\n            strides,\n            pads_begin,\n            pads_end,\n            dilation_rate,\n            pad_mode,\n        )\n    conv = _adjust_outputs(conv.output(0), num_spatial_dims, data_format)\n    return OpenVINOKerasTensor(conv)\n\n\ndef depthwise_conv(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    inputs = get_ov_output(inputs)\n    kernel = get_ov_output(kernel)\n\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.get_partial_shape().rank.get_length() - 2\n\n    if data_format != \"channels_last\":\n        raise ValueError(\n            \"OpenVINO depthwise_conv only supports 'channels_last' \"\n            f\"data format. 
Received: data_format={data_format}\"\n        )\n\n    strides = _adjust_strides_dilation(strides, num_spatial_dims)\n    dilation_rate = _adjust_strides_dilation(dilation_rate, num_spatial_dims)\n    pad_mode, pads_begin, pads_end = _adjust_padding(padding)\n\n    inputs = _adjust_input(inputs, num_spatial_dims, data_format)\n    kernel = _adjust_depthwise_kernel(kernel, num_spatial_dims)\n    unsqueeze_dim = ov_opset.constant([2], Type.i32)\n    kernel = ov_opset.unsqueeze(kernel, unsqueeze_dim)\n\n    group_conv = ov_opset.group_convolution(\n        inputs, kernel, strides, pads_begin, pads_end, dilation_rate, pad_mode\n    )\n    group_conv = _adjust_outputs(\n        group_conv.output(0), num_spatial_dims, data_format\n    )\n    return OpenVINOKerasTensor(group_conv)\n\n\ndef separable_conv(\n    inputs,\n    depthwise_kernel,\n    pointwise_kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    data_format = backend.standardize_data_format(data_format)\n    depthwise_conv_output = depthwise_conv(\n        inputs,\n        depthwise_kernel,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n    )\n    return conv(\n        depthwise_conv_output,\n        pointwise_kernel,\n        strides=1,\n        padding=\"valid\",\n        data_format=data_format,\n        dilation_rate=dilation_rate,\n    )\n\n\ndef conv_transpose(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    output_padding=None,\n    data_format=None,\n    dilation_rate=1,\n):\n    inputs = get_ov_output(inputs)\n    kernel = get_ov_output(kernel)\n\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = inputs.get_partial_shape().rank.get_length() - 2\n\n    strides = _adjust_strides_dilation(strides, num_spatial_dims)\n    dilation_rate = _adjust_strides_dilation(dilation_rate, num_spatial_dims)\n\n    # Convert to channels-first (NCHW) layout\n    inputs = 
_adjust_input(inputs, num_spatial_dims, data_format)\n    # Rearrange kernel from Keras (*kernel, C_out, C_in)\n    # to OpenVINO format (C_in, C_out, *kernel)\n    kernel = _adjust_kernel(kernel, num_spatial_dims)\n\n    # inputs: (N, C_in, *spatial), kernel: (C_in, C_out, *kernel_size)\n    input_pshape = inputs.get_partial_shape()\n    kernel_pshape = kernel.get_partial_shape()\n\n    spatial_output_shape = []\n    all_static = True\n    for i in range(num_spatial_dims):\n        in_dim = input_pshape[2 + i]\n        k_dim = kernel_pshape[2 + i]\n        s = strides[i]\n        d = dilation_rate[i]\n        op_i = (\n            output_padding\n            if output_padding is None or isinstance(output_padding, int)\n            else output_padding[i]\n        )\n        if in_dim.is_static and k_dim.is_static:\n            out_dim = _get_output_shape_given_tf_padding(\n                input_size=in_dim.get_length(),\n                kernel_size=k_dim.get_length(),\n                strides=s,\n                padding=padding,\n                output_padding=op_i,\n                dilation_rate=d,\n            )\n            spatial_output_shape.append(out_dim)\n        else:\n            all_static = False\n            break\n\n    pad_mode = \"SAME_LOWER\" if padding.lower() == \"same\" else \"VALID\"\n\n    if all_static:\n        output_shape_node = ov_opset.constant(\n            spatial_output_shape, Type.i64\n        ).output(0)\n    else:\n        output_shape_node = None\n\n    conv_t = ov_opset.convolution_backprop_data(\n        inputs,\n        kernel,\n        strides=strides,\n        output_shape=output_shape_node,\n        dilations=dilation_rate,\n        auto_pad=pad_mode,\n    )\n    result = _adjust_outputs(conv_t.output(0), num_spatial_dims, data_format)\n    return OpenVINOKerasTensor(result)\n\n\ndef one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):\n    if sparse:\n        raise ValueError(\"`sparse=True` is not supported with 
openvino backend\")\n    x = get_ov_output(x)\n    if dtype is None:\n        dtype = backend.floatx()\n    ov_dtype = OPENVINO_DTYPES[dtype]\n    on_value = get_ov_output(1, ov_dtype)\n    off_value = get_ov_output(0, ov_dtype)\n    one_hot_encoded = ov_opset.one_hot(\n        x,\n        depth=num_classes,\n        axis=axis,\n        on_value=on_value,\n        off_value=off_value,\n    ).output(0)\n    return OpenVINOKerasTensor(one_hot_encoded)\n\n\ndef multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):\n    reduction_axis = 1 if len(x.shape) > 1 else 0\n    if backend.standardize_dtype(dtype) == \"bool\":\n        outputs = one_hot(x, num_classes, axis=axis, dtype=dtype, sparse=sparse)\n        result = ov_opset.reduce_logical_or(outputs, reduction_axis)\n    else:\n        outputs = one_hot(x, num_classes, axis=axis, dtype=dtype)\n        result = ov_opset.reduce_max(outputs, reduction_axis)\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    target = get_ov_output(target)\n    output = get_ov_output(output)\n\n    if target.shape != output.shape:\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same shape. \"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n    if len(target.shape) < 1:\n        raise ValueError(\n            \"Arguments `target` and `output` must be at least rank 1. 
\"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n\n    if from_logits:\n        log_prob = ov_opset.log_softmax(output, axis).output(0)\n    else:\n        sum_result = ov_opset.reduce_sum(output, axis, keep_dims=True).output(0)\n        output = ov_opset.divide(output, sum_result).output(0)\n        output = ov_opset.clamp(\n            output, min_value=backend.epsilon(), max_value=1 - backend.epsilon()\n        ).output(0)\n        log_prob = ov_opset.log(output).output(0)\n    result = ov_opset.multiply(target, log_prob).output(0)\n    loss = ov_opset.reduce_sum(result, axis).output(0)\n    loss = ov_opset.negative(loss).output(0)\n    return OpenVINOKerasTensor(loss)\n\n\ndef sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    target = get_ov_output(target)\n    output = get_ov_output(output)\n\n    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:\n        target = ov_opset.squeeze(target, -1).output(0)\n\n    if len(output.shape) < 1:\n        raise ValueError(\n            \"Argument `output` must be at least rank 1. 
\"\n            \"Received: \"\n            f\"output.shape={output.shape}\"\n        )\n\n    output_shape_without_class_dim = list(output.shape)\n    del output_shape_without_class_dim[axis]\n\n    if list(target.shape) != output_shape_without_class_dim:\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same shape \"\n            \"up until the last dimension: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n\n    if from_logits:\n        log_prob = ov_opset.log_softmax(output, axis).output(0)\n    else:\n        sum_result = ov_opset.reduce_sum(output, axis, keep_dims=True).output(0)\n        output = ov_opset.divide(output, sum_result).output(0)\n        output = ov_opset.clamp(\n            output, min_value=backend.epsilon(), max_value=1 - backend.epsilon()\n        ).output(0)\n        log_prob = ov_opset.log(output).output(0)\n\n    output_type = output.get_element_type()\n    on_val = ov_opset.constant(1, output_type).output(0)\n    off_val = ov_opset.constant(0, output_type).output(0)\n    one_hot_target = ov_opset.one_hot(\n        target,\n        depth=output.shape[axis],\n        on_value=on_val,\n        off_value=off_val,\n        axis=axis,\n    ).output(0)\n    result = ov_opset.multiply(one_hot_target, log_prob).output(0)\n    loss = ov_opset.reduce_sum(result, axis).output(0)\n    loss = ov_opset.negative(loss).output(0)\n    return OpenVINOKerasTensor(loss)\n\n\ndef binary_crossentropy(target, output, from_logits=False):\n    target = get_ov_output(target)\n    output = get_ov_output(output)\n    if target.get_element_type() != output.get_element_type():\n        output = ov_opset.convert(output, target.get_element_type()).output(0)\n\n    if target.shape != output.shape:\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same shape. 
\"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n\n    if from_logits:\n        one = ov_opset.constant(1, target.get_element_type()).output(0)\n        neg_output = ov_opset.negative(output).output(0)\n        one_minus_target = ov_opset.subtract(one, target).output(0)\n        bce = ov_opset.add(\n            ov_opset.multiply(\n                target, ov_opset.softplus(neg_output).output(0)\n            ).output(0),\n            ov_opset.multiply(\n                one_minus_target, ov_opset.softplus(output).output(0)\n            ).output(0),\n        ).output(0)\n        return OpenVINOKerasTensor(bce)\n\n    output = ov_opset.clamp(\n        output, min_value=backend.epsilon(), max_value=1 - backend.epsilon()\n    ).output(0)\n    one = ov_opset.constant(1, target.get_element_type()).output(0)\n\n    minus_output = ov_opset.subtract(one, output).output(0)\n    minus_target = ov_opset.subtract(one, target).output(0)\n\n    log_prob = ov_opset.log(output).output(0)\n    minus_log_prob = ov_opset.log(minus_output).output(0)\n    result = ov_opset.multiply(target, log_prob).output(0)\n    minus_result = ov_opset.multiply(minus_target, minus_log_prob).output(0)\n    bce = ov_opset.add(result, minus_result).output(0)\n    bce = ov_opset.negative(bce).output(0)\n    return OpenVINOKerasTensor(bce)\n\n\ndef moments(x, axes, keepdims=False, synchronized=False):\n    x = get_ov_output(x)\n    ori_type = x.get_element_type()\n    if ori_type == Type.f16:\n        x = ov_opset.convert(x, Type.f32).output(0)\n    axes_c = ov_opset.constant(axes, Type.i32).output(0)\n    const_two = ov_opset.constant(2, x.get_element_type()).output(0)\n    mean = ov_opset.reduce_mean(x, axes_c, keepdims).output(0)\n    squared_x_mean = ov_opset.reduce_mean(\n        ov_opset.power(x, const_two).output(0), axes_c, keepdims\n    ).output(0)\n    variance = ov_opset.subtract(\n        squared_x_mean, ov_opset.power(mean, 
const_two).output(0)\n    ).output(0)\n    if ori_type == Type.f16:\n        fp16_max = float(np.finfo(np.float16).max)\n        fp16_min = float(np.finfo(np.float16).min)\n        mean = ov_opset.convert(\n            ov_opset.clamp(mean, fp16_min, fp16_max).output(0), Type.f16\n        ).output(0)\n        variance = ov_opset.convert(\n            ov_opset.clamp(variance, 0.0, fp16_max).output(0), Type.f16\n        ).output(0)\n    return OpenVINOKerasTensor(mean), OpenVINOKerasTensor(variance)\n\n\ndef batch_normalization(\n    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3\n):\n    x = get_ov_output(x)\n    mean = get_ov_output(mean)\n    variance = get_ov_output(variance)\n    if offset is not None:\n        offset = get_ov_output(offset)\n    else:\n        mean_shape = ov_opset.shape_of(mean)\n        mean_type = mean.get_element_type()\n        zero_const = ov_opset.constant([0], mean_type)\n        offset = ov_opset.broadcast(zero_const, mean_shape)\n    if scale is not None:\n        scale = get_ov_output(scale)\n    else:\n        mean_shape = ov_opset.shape_of(mean)\n        mean_type = mean.get_element_type()\n        one_const = ov_opset.constant([1], mean_type)\n        scale = ov_opset.broadcast(one_const, mean_shape)\n\n    # adjust x input to have the second dimension representing the channel axis\n    x_rank = x.get_partial_shape().rank.get_length()\n    if axis < 0:\n        axis += x_rank\n    if axis != 1:\n        perm_vector = list(range(0, x_rank))\n        perm_vector[1] = axis\n        perm_vector[axis] = 1\n        perm_vector = ov_opset.constant(perm_vector, Type.i32).output(0)\n        x = ov_opset.transpose(x, perm_vector).output(0)\n    batch_norm = ov_opset.batch_norm_inference(\n        x, scale, offset, mean, variance, epsilon\n    ).output(0)\n    if axis != 1:\n        perm_vector = list(range(0, x_rank))\n        perm_vector[1] = axis\n        perm_vector[axis] = 1\n        perm_vector = 
ov_opset.constant(perm_vector, Type.i32).output(0)\n        batch_norm = ov_opset.transpose(batch_norm, perm_vector).output(0)\n    return OpenVINOKerasTensor(batch_norm)\n\n\ndef ctc_loss(target, output, target_length, output_length, mask_index=0):\n    target = get_ov_output(target)\n    output = get_ov_output(output)\n    target_length = get_ov_output(target_length)\n    output_length = get_ov_output(output_length)\n    ctc_loss_ = ov_opset.ctc_loss(\n        output, output_length, target, target_length, blank_index=mask_index\n    )\n    ctc_loss_ = ov_opset.convert(ctc_loss_, OPENVINO_DTYPES[backend.floatx()])\n    return OpenVINOKerasTensor(ctc_loss_.output(0))\n\n\ndef ctc_decode(\n    inputs,\n    sequence_lengths,\n    strategy=\"greedy\",\n    beam_width=100,\n    top_paths=1,\n    merge_repeated=True,\n    mask_index=0,\n):\n    raise NotImplementedError(\n        \"`ctc_decode` is not supported with openvino backend\"\n    )\n\n\ndef psnr(x1, x2, max_val):\n    from keras.src.backend.openvino.numpy import log10\n\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    max_val = get_ov_output(max_val, x1.get_element_type())\n    diff = ov_opset.subtract(x1, x2)\n    squared_diff = ov_opset.multiply(diff, diff)\n    reduction_axes = list(range(0, x1.get_partial_shape().rank.get_length()))\n    mse = ov_opset.reduce_mean(squared_diff, reduction_axes).output(0)\n    log_max_val = get_ov_output(log10(OpenVINOKerasTensor(max_val)))\n    log_mse = get_ov_output(log10(OpenVINOKerasTensor(mse)))\n\n    psnr = ov_opset.subtract(\n        ov_opset.multiply(\n            ov_opset.constant(20, log_max_val.get_element_type()), log_max_val\n        ),\n        ov_opset.multiply(\n            ov_opset.constant(10, log_mse.get_element_type()), log_mse\n        ),\n    ).output(0)\n    return OpenVINOKerasTensor(psnr)\n\n\ndef dot_product_attention(\n    query,\n    key,\n    value,\n    bias=None,\n    mask=None,\n    scale=None,\n    is_causal=False,\n    
flash_attention=None,\n    attn_logits_soft_cap=None,\n):\n    if bias is not None:\n        raise NotImplementedError(\n            \"`dot_product_attention` with `bias` is not supported \"\n            \"with openvino backend\"\n        )\n    if flash_attention:\n        raise NotImplementedError(\n            \"`dot_product_attention` with `flash_attention` is not supported \"\n            \"with openvino backend\"\n        )\n    if attn_logits_soft_cap is not None:\n        raise NotImplementedError(\n            \"`dot_product_attention` with `attn_logits_soft_cap` is not \"\n            \"supported with openvino backend\"\n        )\n    query = get_ov_output(query)\n    key = get_ov_output(key)\n    value = get_ov_output(value)\n    if query.get_element_type() != key.get_element_type():\n        ov_type = OPENVINO_DTYPES[backend.floatx()]\n        query = ov_opset.convert(query, ov_type).output(0)\n        key = ov_opset.convert(key, ov_type).output(0)\n    if value.get_element_type() != query.get_element_type():\n        value = ov_opset.convert(value, query.get_element_type()).output(0)\n    axes_const = ov_opset.constant([0, 2, 1, 3], Type.i32).output(0)\n\n    query = ov_opset.transpose(query, axes_const)\n    key = ov_opset.transpose(key, axes_const)\n    value = ov_opset.transpose(value, axes_const)\n    mask = get_ov_output(mask) if mask is not None else None\n    scale = (\n        get_ov_output(scale, query.get_element_type())\n        if scale is not None\n        else None\n    )\n    dpa = ov_opset.scaled_dot_product_attention(\n        query, key, value, attention_mask=mask, scale=scale, causal=is_causal\n    )\n    dpa = ov_opset.transpose(dpa, axes_const)\n    return OpenVINOKerasTensor(dpa.output(0))\n\n\ndef unfold(input, kernel_size, dilation=1, padding=0, stride=1):\n    def _pair(x):\n        return (x, x) if isinstance(x, int) else x\n\n    k = _pair(kernel_size)\n    d = _pair(dilation)\n    p = _pair(padding)\n    s = 
_pair(stride)\n\n    N, C, H, W = input.shape\n\n    # ---- padding ----\n    if any(_ > 0 for _ in p):\n        input = onp.pad(input, ((0, 0), (0, 0), (p[0], p[0]), (p[1], p[1])))\n\n    # ---- extract patches ----\n    patches = ov_opset.extract_image_patches(\n        image=input,\n        sizes=[k[0], k[1]],\n        strides=[s[0], s[1]],\n        rates=[d[0], d[1]],\n        auto_pad=\"VALID\",\n    )  # (N, kH*kW*C, nH, nW)\n    N, D, nH, nW = patches.shape\n    patches = ov_opset.reshape(patches, [N, k[0], k[1], C, nH, nW], False)\n    patches = ov_opset.transpose(\n        patches, [0, 3, 1, 2, 4, 5]\n    )  # (N, C, kH, kW, nH, nW)\n    patches = ov_opset.reshape(patches, [N, C * k[0] * k[1], nH * nW], False)\n    return OpenVINOKerasTensor(patches.output(0))\n\n\ndef fold(x, output_size, kernel_size, dilation=1, padding=0, stride=1):\n    def _pair(val):\n        return (val, val) if isinstance(val, int) else val\n\n    oH, oW = _pair(output_size)\n    kH, kW = _pair(kernel_size)\n    dH, dW = _pair(dilation)\n    pH, pW = _pair(padding)\n    sH, sW = _pair(stride)\n\n    x = get_ov_output(x)\n\n    # Derive CKK and C dynamically from the input shape to support dynamic dims.\n    shape_x = ov_opset.shape_of(x).output(0)\n    one_i64 = ov_opset.constant([1], Type.i64).output(0)\n    two_i64 = ov_opset.constant([2], Type.i64).output(0)\n    step_i64 = ov_opset.constant([1], Type.i64).output(0)\n    CKK_1d = ov_opset.slice(shape_x, one_i64, two_i64, step_i64).output(0)\n    KK_node = ov_opset.constant([kH * kW], Type.i64).output(0)\n    C_1d = ov_opset.divide(CKK_1d, KK_node).output(0)\n    CKK_scalar = ov_opset.squeeze(CKK_1d, [0]).output(0)\n\n    nH = (oH + 2 * pH - dH * (kH - 1) - 1) // sH + 1\n    nW = (oW + 2 * pW - dW * (kW - 1) - 1) // sW + 1\n\n    # Reshape: (N, CKK, L) -> (N, CKK, nH, nW); 0 copies the dim from input.\n    new_shape = ov_opset.constant([0, 0, nH, nW], Type.i64).output(0)\n    x = ov_opset.reshape(x, new_shape, True).output(0)\n\n 
   # Build identity kernel dynamically: shape (CKK, CKK) via one_hot,\n    # then reshape to (CKK, C, kH, kW).\n    indices = ov_opset.range(\n        ov_opset.constant(0, Type.i64),\n        CKK_scalar,\n        ov_opset.constant(1, Type.i64),\n        Type.i64,\n    ).output(0)\n    on_val = ov_opset.constant(1, Type.f32).output(0)\n    off_val = ov_opset.constant(0, Type.f32).output(0)\n    eye = ov_opset.one_hot(\n        indices, depth=CKK_scalar, on_value=on_val, off_value=off_val, axis=-1\n    ).output(0)\n    kernel_shape = ov_opset.concat(\n        [CKK_1d, C_1d, ov_opset.constant([kH, kW], Type.i64).output(0)], axis=0\n    ).output(0)\n    kernel = ov_opset.reshape(eye, kernel_shape, False).output(0)\n    kernel = ov_opset.convert(kernel, x.get_element_type()).output(0)\n\n    oH_pad = oH + 2 * pH\n    oW_pad = oW + 2 * pW\n\n    output_shape_node = ov_opset.constant([oH_pad, oW_pad], Type.i64).output(0)\n    result = ov_opset.convolution_backprop_data(\n        x,\n        kernel,\n        strides=[sH, sW],\n        output_shape=output_shape_node,\n        dilations=[dH, dW],\n        auto_pad=\"VALID\",\n    ).output(0)\n\n    if pH > 0 or pW > 0:\n        axes = ov_opset.constant([2, 3], Type.i32).output(0)\n        start = ov_opset.constant([pH, pW], Type.i32).output(0)\n        stop = ov_opset.constant([oH_pad - pH, oW_pad - pW], Type.i32).output(0)\n        step = ov_opset.constant([1, 1], Type.i32).output(0)\n        result = ov_opset.slice(result, start, stop, step, axes).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef depth_to_space(x, block_size, data_format=\"channels_last\"):\n    \"\"\"OpenVINO implementation of depth_to_space (pixel shuffle).\n\n    Rearranges data from depth into blocks of spatial data.\n\n    Args:\n        x: 4-D tensor with shape (N, H, W, C) for channels_last or\n            (N, C, H, W) for channels_first.\n        block_size: An integer specifying the block size.\n        data_format: \"channels_last\" or 
\"channels_first\".\n\n    Returns:\n        A tensor with shape (N, H*block_size, W*block_size, C/block_size**2)\n        for channels_last or (N, C/block_size**2, H*block_size, W*block_size)\n        for channels_first.\n    \"\"\"\n    x = get_ov_output(x)\n    # OpenVINO depth_to_space uses \"blocks_first\" mode by default\n    # and expects NCHW format\n    if data_format == \"channels_last\":\n        # Convert NHWC to NCHW\n        axes = ov_opset.constant([0, 3, 1, 2], Type.i32).output(0)\n        x = ov_opset.transpose(x, axes).output(0)\n        result = ov_opset.depth_to_space(x, \"blocks_first\", block_size).output(\n            0\n        )\n        # Convert back to NHWC\n        axes_back = ov_opset.constant([0, 2, 3, 1], Type.i32).output(0)\n        result = ov_opset.transpose(result, axes_back).output(0)\n    else:\n        result = ov_opset.depth_to_space(x, \"blocks_first\", block_size).output(\n            0\n        )\n    return OpenVINOKerasTensor(result)\n\n\ndef space_to_depth(x, block_size, data_format=\"channels_last\"):\n    \"\"\"OpenVINO implementation of space_to_depth (pixel unshuffle).\n\n    Rearranges blocks of spatial data into depth.\n\n    Args:\n        x: 4-D tensor with shape (N, H, W, C) for channels_last or\n            (N, C, H, W) for channels_first.\n        block_size: An integer specifying the block size.\n        data_format: \"channels_last\" or \"channels_first\".\n\n    Returns:\n        A tensor with shape (N, H/block_size, W/block_size, C*block_size**2)\n        for channels_last or (N, C*block_size**2, H/block_size, W/block_size)\n        for channels_first.\n    \"\"\"\n    x = get_ov_output(x)\n    # OpenVINO space_to_depth uses \"blocks_first\" mode by default\n    # and expects NCHW format\n    if data_format == \"channels_last\":\n        # Convert NHWC to NCHW\n        axes = ov_opset.constant([0, 3, 1, 2], Type.i32).output(0)\n        x = ov_opset.transpose(x, axes).output(0)\n        result = 
ov_opset.space_to_depth(x, \"blocks_first\", block_size).output(\n            0\n        )\n        # Convert back to NHWC\n        axes_back = ov_opset.constant([0, 2, 3, 1], Type.i32).output(0)\n        result = ov_opset.transpose(result, axes_back).output(0)\n    else:\n        result = ov_opset.space_to_depth(x, \"blocks_first\", block_size).output(\n            0\n        )\n    return OpenVINOKerasTensor(result)\n"
  },
  {
    "path": "keras/src/backend/openvino/numpy.py",
    "content": "import numpy as np\nimport openvino as ov\nimport openvino.opset15 as ov_opset\nfrom openvino import Type\n\nfrom keras.src.backend import config\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common.backend_utils import canonicalize_axis\nfrom keras.src.backend.common.variables import standardize_dtype\nfrom keras.src.backend.openvino.core import DTYPES_MAX\nfrom keras.src.backend.openvino.core import DTYPES_MIN\nfrom keras.src.backend.openvino.core import OPENVINO_DTYPES\nfrom keras.src.backend.openvino.core import OpenVINOKerasTensor\nfrom keras.src.backend.openvino.core import (\n    align_operand_types as _align_operand_types,\n)\nfrom keras.src.backend.openvino.core import convert_to_tensor\nfrom keras.src.backend.openvino.core import get_ov_output\nfrom keras.src.backend.openvino.core import ov_to_keras_type\nfrom keras.src.backend.openvino.core import while_loop\n\n\ndef _promote_binary_op_types(x1, x2):\n    t1 = (\n        ov_to_keras_type(x1.get_element_type())\n        if isinstance(x1, ov.Output)\n        else getattr(x1, \"dtype\", type(x1))\n    )\n    t2 = (\n        ov_to_keras_type(x2.get_element_type())\n        if isinstance(x2, ov.Output)\n        else getattr(x2, \"dtype\", type(x2))\n    )\n    target_type = OPENVINO_DTYPES[dtypes.result_type(t1, t2)]\n    x1 = get_ov_output(x1, target_type)\n    x2 = get_ov_output(x2, target_type)\n    return x1, x2\n\n\ndef add(x1, x2):\n    x1, x2 = _promote_binary_op_types(x1, x2)\n    x1, x2 = _align_operand_types(x1, x2, \"add()\")\n    if x1.get_element_type() == Type.boolean:\n        return OpenVINOKerasTensor(ov_opset.logical_or(x1, x2).output(0))\n    return OpenVINOKerasTensor(ov_opset.add(x1, x2).output(0))\n\n\ndef einsum(subscripts, *operands, **kwargs):\n    inputs = [get_ov_output(operand) for operand in operands]\n    keras_types = [ov_to_keras_type(inp.get_element_type()) for inp in inputs]\n    result_dtype = (\n        dtypes.result_type(*keras_types) 
if keras_types else config.floatx()\n    )\n    if set(keras_types) == {\"int8\"}:\n        result_dtype = \"int32\"\n    ov_result_type = OPENVINO_DTYPES[result_dtype]\n    # OV Einsum supports float*/int32/int64; promote unsupported types\n    _ov_einsum_ok = {\n        OPENVINO_DTYPES[t]\n        for t in (\"float16\", \"bfloat16\", \"float32\", \"float64\", \"int32\", \"int64\")\n    }\n    if ov_result_type not in _ov_einsum_ok:\n        ov_compute_type = OPENVINO_DTYPES[\n            \"int64\" if result_dtype in (\"uint32\", \"uint64\") else \"int32\"\n        ]\n    else:\n        ov_compute_type = ov_result_type\n    inputs = [\n        ov_opset.convert(inp, ov_compute_type).output(0)\n        if inp.get_element_type() != ov_compute_type\n        else inp\n        for inp in inputs\n    ]\n    result = ov_opset.einsum(inputs, subscripts).output(0)\n    if result.get_element_type() != ov_result_type:\n        result = ov_opset.convert(result, ov_result_type).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef subtract(x1, x2):\n    x1, x2 = _promote_binary_op_types(x1, x2)\n    x1, x2 = _align_operand_types(x1, x2, \"subtract()\")\n    if x1.get_element_type() == Type.boolean:\n        return OpenVINOKerasTensor(ov_opset.logical_xor(x1, x2).output(0))\n    return OpenVINOKerasTensor(ov_opset.subtract(x1, x2).output(0))\n\n\ndef matmul(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n\n    # When both inputs are int8, promote to int32 to align with other backends.\n    if (\n        ov_to_keras_type(x1.get_element_type()) == \"int8\"\n        and ov_to_keras_type(x2.get_element_type()) == \"int8\"\n    ):\n        int32_type = OPENVINO_DTYPES[\"int32\"]\n        x1 = ov_opset.convert(x1, 
int32_type).output(0)\n        x2 = ov_opset.convert(x2, int32_type).output(0)\n\n    x1, x2 = _align_operand_types(x1, x2, \"matmul()\")\n    return OpenVINOKerasTensor(ov_opset.matmul(x1, x2, False, False).output(0))\n\n\ndef multiply(x1, x2):\n    x1, x2 = _promote_binary_op_types(x1, x2)\n    x1, x2 = _align_operand_types(x1, x2, \"multiply()\")\n    if x1.get_element_type() == Type.boolean:\n        return OpenVINOKerasTensor(ov_opset.logical_and(x1, x2).output(0))\n    return OpenVINOKerasTensor(ov_opset.multiply(x1, x2).output(0))\n\n\ndef mean(x, axis=None, keepdims=False):\n    x_ov = get_ov_output(x)\n    x_type = x_ov.get_element_type()\n\n    was_axis_none = axis is None\n    x_resolved, axis_resolved = _resolve_axis(x_ov, axis)\n\n    if axis_resolved is None:\n        return OpenVINOKerasTensor(x_ov)\n\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x_resolved = ov_opset.convert(x_resolved, ov_type).output(0)\n\n    result = ov_opset.reduce_mean(x_resolved, axis_resolved, keepdims).output(0)\n\n    if keepdims and was_axis_none:\n        rank = x_ov.get_partial_shape().rank.get_length()\n        result_shape = [1] * rank\n        result = ov_opset.reshape(\n            result,\n            ov_opset.constant(result_shape, Type.i32).output(0),\n            False,\n        ).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef max(x, axis=None, keepdims=False, initial=None):\n    return _compute_extrema(x, \"max\", axis, keepdims, initial)\n\n\ndef _compute_extrema(x, operation, axis=None, keepdims=False, initial=None):\n    if operation == \"min\":\n        reduction_op = ov_opset.reduce_min\n        elementwise_op = ov_opset.minimum\n    elif operation == \"max\":\n        reduction_op = ov_opset.reduce_max\n        elementwise_op = ov_opset.maximum\n    else:\n        raise ValueError(\n            f\"Operation must be 'min' or 'max', received {operation}\"\n        )\n\n    x = get_ov_output(x)\n    
x_type = x.get_element_type()\n    x_for_rank = x\n\n    is_bool = x_type == Type.boolean\n    if is_bool:\n        x = ov_opset.convert(x, Type.i32).output(0)\n        x_type = Type.i32\n\n    if isinstance(axis, tuple) and len(axis) == 0:\n        return OpenVINOKerasTensor(x)\n\n    was_axis_none = axis is None\n    x, axis = _resolve_axis(x, axis)\n\n    result = reduction_op(x, axis, keepdims).output(0)\n\n    if initial is not None:\n        initial_tensor = ov_opset.constant(initial, x_type).output(0)\n        result = elementwise_op(result, initial_tensor).output(0)\n\n    if keepdims and was_axis_none:\n        orig_shape = ov_opset.shape_of(x_for_rank, Type.i32).output(0)\n        orig_rank_shape = ov_opset.shape_of(orig_shape, Type.i32).output(0)\n        one = ov_opset.constant(1, Type.i32).output(0)\n        result_shape = ov_opset.broadcast(one, orig_rank_shape).output(0)\n        result = ov_opset.reshape(result, result_shape, False).output(0)\n\n    if is_bool:\n        result = ov_opset.convert(result, Type.boolean).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef ones(shape, dtype=None):\n    dtype = standardize_dtype(dtype) or config.floatx()\n    ov_type = OPENVINO_DTYPES[dtype]\n    const_one = ov_opset.constant(1, ov_type).output(0)\n    if isinstance(shape, tuple):\n        shape = list(shape)\n    elif isinstance(shape, int):\n        shape = [shape]\n    output_shape = ov_opset.constant(shape, dtype=Type.i32).output(0)\n    ones = ov_opset.broadcast(const_one, output_shape)\n    return OpenVINOKerasTensor(ones.output(0))\n\n\ndef zeros(shape, dtype=None):\n    dtype = standardize_dtype(dtype) or config.floatx()\n    ov_type = OPENVINO_DTYPES[dtype]\n    const_zero = ov_opset.constant(0, dtype=ov_type).output(0)\n    if isinstance(shape, tuple):\n        shape = list(shape)\n    elif isinstance(shape, int):\n        shape = [shape]\n    output_shape = ov_opset.constant(shape, dtype=Type.i32).output(0)\n    zeros = 
ov_opset.broadcast(const_zero, output_shape)\n    return OpenVINOKerasTensor(zeros.output(0))\n\n\ndef absolute(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type == Type.boolean:\n        return OpenVINOKerasTensor(x)\n    return OpenVINOKerasTensor(ov_opset.absolute(x).output(0))\n\n\ndef abs(x):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.absolute(x).output(0))\n\n\ndef all(x, axis=None, keepdims=False):\n    x = get_ov_output(x)\n    x, axis = _resolve_axis(x, axis)\n    if axis is None:\n        return OpenVINOKerasTensor(x)\n    x = ov_opset.convert(x, Type.boolean).output(0)\n    return OpenVINOKerasTensor(\n        ov_opset.reduce_logical_and(x, axis, keepdims).output(0)\n    )\n\n\ndef allclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):\n    if (\n        not isinstance(x1, OpenVINOKerasTensor)\n        and not isinstance(x2, OpenVINOKerasTensor)\n        and not isinstance(x1, ov.Output)\n        and not isinstance(x2, ov.Output)\n    ):\n        try:\n            return OpenVINOKerasTensor(\n                ov_opset.constant(\n                    np.allclose(\n                        x1, x2, rtol=rtol, atol=atol, equal_nan=equal_nan\n                    ),\n                    Type.boolean,\n                ).output(0)\n            )\n        except Exception:\n            pass\n\n    return all(isclose(x1, x2, rtol=rtol, atol=atol, equal_nan=equal_nan))\n\n\ndef angle(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n\n    # angle() is defined over real values in OpenVINO backend.\n    # Cast integral and boolean to floatx\n    if x_type.is_integral() or x_type == Type.boolean:\n        x = ov_opset.convert(x, OPENVINO_DTYPES[config.floatx()]).output(0)\n        x_type = x.get_element_type()\n\n    zero = ov_opset.constant(0, x_type).output(0)\n    pi = ov_opset.constant(float(np.pi), x_type).output(0)\n    is_negative = ov_opset.less(x, zero).output(0)\n    result = 
ov_opset.select(is_negative, pi, zero).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef any(x, axis=None, keepdims=False):\n    x = get_ov_output(x)\n    x, axis = _resolve_axis(x, axis)\n    if axis is None:\n        return OpenVINOKerasTensor(x)\n    x = ov_opset.convert(x, Type.boolean).output(0)\n    return OpenVINOKerasTensor(\n        ov_opset.reduce_logical_or(x, axis, keepdims).output(0)\n    )\n\n\ndef amax(x, axis=None, keepdims=False):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    x, axis = _resolve_axis(x, axis)\n    if axis is None:\n        return OpenVINOKerasTensor(x)\n    if x_type == Type.boolean:\n        return OpenVINOKerasTensor(\n            ov_opset.reduce_logical_or(x, axis, keepdims).output(0)\n        )\n    return OpenVINOKerasTensor(ov_opset.reduce_max(x, axis, keepdims).output(0))\n\n\ndef amin(x, axis=None, keepdims=False):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    x, axis = _resolve_axis(x, axis)\n    if axis is None:\n        return OpenVINOKerasTensor(x)\n    if x_type == Type.boolean:\n        return OpenVINOKerasTensor(\n            ov_opset.reduce_logical_and(x, axis, keepdims).output(0)\n        )\n    return OpenVINOKerasTensor(ov_opset.reduce_min(x, axis, keepdims).output(0))\n\n\ndef _resolve_axis(x, axis):\n    if axis == () or axis == []:\n        return x, None\n    if axis is None:\n        flatten_shape = ov_opset.constant([-1], Type.i32).output(0)\n        x = ov_opset.reshape(x, flatten_shape, False).output(0)\n        axis = 0\n    if isinstance(axis, tuple):\n        axis = list(axis)\n    axis = ov_opset.constant(axis, Type.i32).output(0)\n    return x, axis\n\n\ndef _upcast_type_if_needed(x):\n    x_type = x.get_element_type()\n    if x_type == Type.boolean:\n        x = ov_opset.convert(x, Type.i32).output(0)\n    elif x_type in (Type.i8, Type.i16):\n        x = ov_opset.convert(x, Type.i32).output(0)\n    elif x_type in (Type.u8, Type.u16):\n        x = 
ov_opset.convert(x, Type.u32).output(0)\n    return x\n\n\ndef append(x1, x2, axis=None):\n    x1, x2 = get_ov_output(x1), get_ov_output(x2)\n    x1, x2 = _align_operand_types(x1, x2, \"append()\")\n    if axis is None:\n        flatten_shape = ov_opset.constant([-1], Type.i32).output(0)\n        x1 = ov_opset.reshape(x1, flatten_shape, False).output(0)\n        x2 = ov_opset.reshape(x2, flatten_shape, False).output(0)\n        axis = 0\n    return OpenVINOKerasTensor(ov_opset.concat([x1, x2], axis).output(0))\n\n\ndef arange(start, stop=None, step=None, dtype=None):\n    if stop is None:\n        start, stop = get_ov_output(0), get_ov_output(start)\n    else:\n        start, stop = get_ov_output(start), get_ov_output(stop)\n\n    step = get_ov_output(1) if step is None else get_ov_output(step)\n\n    ov_type = None\n    if dtype is not None:\n        ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)]\n    else:\n        ov_type = OPENVINO_DTYPES[\n            dtypes.result_type(\n                ov_to_keras_type(start.get_element_type()),\n                ov_to_keras_type(stop.get_element_type()),\n                ov_to_keras_type(step.get_element_type()),\n                \"int32\",\n            )\n        ]\n\n    start_node = ov_opset.convert(start, ov_type)\n    stop_node = ov_opset.convert(stop, ov_type)\n    step_node = ov_opset.convert(step, ov_type)\n\n    return OpenVINOKerasTensor(\n        ov_opset.range(start_node, stop_node, step_node, ov_type).output(0)\n    )\n\n\ndef arccos(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return OpenVINOKerasTensor(ov_opset.acos(x).output(0))\n\n\ndef arccosh(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return 
OpenVINOKerasTensor(ov_opset.acosh(x).output(0))\n\n\ndef arcsin(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return OpenVINOKerasTensor(ov_opset.asin(x).output(0))\n\n\ndef arcsinh(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return OpenVINOKerasTensor(ov_opset.asinh(x).output(0))\n\n\ndef arctan(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return OpenVINOKerasTensor(ov_opset.atan(x).output(0))\n\n\ndef arctan2(x1, x2):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n\n    x1_type = ov_to_keras_type(x1.get_element_type())\n    x2_type = ov_to_keras_type(x2.get_element_type())\n    result_type = dtypes.result_type(x1_type, x2_type, float)\n    result_type = OPENVINO_DTYPES[result_type]\n    x1 = ov_opset.convert(x1, result_type)\n    x2 = ov_opset.convert(x2, result_type)\n\n    nan_x1 = ov_opset.is_nan(x1)\n    nan_x2 = ov_opset.is_nan(x2)\n    nan_mask = ov_opset.logical_or(nan_x1, nan_x2)\n\n    x = ov_opset.divide(x1, x2)\n    y = ov_opset.atan(x)\n\n    ov_type = x1.get_element_type()\n    pi = ov_opset.constant(float(np.pi), ov_type)\n    half_pi = ov_opset.constant(float(np.pi / 2), ov_type)\n    neg_half_pi = ov_opset.constant(-float(np.pi / 2), ov_type)\n    zero_const = ov_opset.constant(0.0, ov_type)\n\n    cond_x2_gt0 = ov_opset.greater(x2, zero_const)\n    cond_x2_lt0 = ov_opset.less(x2, zero_const)\n\n    cond_x1_ge0 = ov_opset.greater_equal(x1, zero_const)\n    cond_x1_gt0 = ov_opset.greater(x1, zero_const)\n    cond_x1_eq0 = ov_opset.equal(x1, zero_const)\n\n    out_x2_lt0 = ov_opset.select(\n        
cond_x1_ge0,\n        ov_opset.add(y, pi),\n        ov_opset.subtract(y, pi),\n    )\n\n    out_x1_zero = ov_opset.select(cond_x1_eq0, zero_const, neg_half_pi)\n    out_x2_zero = ov_opset.select(cond_x1_gt0, half_pi, out_x1_zero)\n\n    out_not_pos = ov_opset.select(cond_x2_lt0, out_x2_lt0, out_x2_zero)\n\n    value_out = ov_opset.select(cond_x2_gt0, y, out_not_pos)\n\n    # Generate NaN safely for all floating dtypes (including bf16)\n    nan_value = ov_opset.divide(zero_const, zero_const)\n    final_out = ov_opset.select(nan_mask, nan_value, value_out)\n    return OpenVINOKerasTensor(final_out.output(0))\n\n\ndef arctanh(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return OpenVINOKerasTensor(ov_opset.atanh(x).output(0))\n\n\ndef argmax(x, axis=None, keepdims=False):\n    x = get_ov_output(x)\n    x_shape = x.get_partial_shape()\n    rank = x_shape.rank.get_length()\n    if rank == 0:\n        return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0))\n    if axis is None:\n        flatten_shape = ov_opset.constant(\n            [-1] + [1] * (rank - 1), Type.i32\n        ).output(0)\n        x = ov_opset.reshape(x, flatten_shape, False).output(0)\n        axis = 0\n        k = ov_opset.constant(1, Type.i32).output(0)\n    else:\n        if axis < 0:\n            axis = rank + axis\n        k = ov_opset.constant(1, Type.i32).output(0)\n    topk_outputs = ov_opset.topk(\n        x,\n        k=k,\n        axis=axis,\n        mode=\"max\",\n        sort=\"value\",\n        stable=True,\n        index_element_type=Type.i32,\n    )\n    topk_indices = topk_outputs.output(1)\n    if not keepdims:\n        topk_indices = ov_opset.squeeze(topk_indices, [axis]).output(0)\n    return OpenVINOKerasTensor(topk_indices)\n\n\ndef argmin(x, axis=None, keepdims=False):\n    x = get_ov_output(x)\n    x_shape = 
x.get_partial_shape()\n    rank = x_shape.rank.get_length()\n    if rank == 0:\n        return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0))\n    if axis is None:\n        flatten_shape = ov_opset.constant(\n            [-1] + [1] * (rank - 1), Type.i32\n        ).output(0)\n        x = ov_opset.reshape(x, flatten_shape, False).output(0)\n        axis = 0\n        k = ov_opset.constant(1, Type.i32).output(0)\n    else:\n        if axis < 0:\n            axis = rank + axis\n        k = ov_opset.constant(1, Type.i32).output(0)\n    topk_outputs = ov_opset.topk(\n        x,\n        k=k,\n        axis=axis,\n        mode=\"min\",\n        sort=\"value\",\n        stable=True,\n        index_element_type=Type.i32,\n    )\n    topk_indices = topk_outputs.output(1)\n    if not keepdims:\n        topk_indices = ov_opset.squeeze(topk_indices, [axis]).output(0)\n    return OpenVINOKerasTensor(topk_indices)\n\n\ndef argsort(x, axis=-1):\n    x = get_ov_output(x)\n    x_shape = x.get_partial_shape()\n    rank = x_shape.rank.get_length()\n    if rank == 0:\n        return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0))\n    if axis is None:\n        flatten_shape = ov_opset.constant([-1], Type.i32).output(0)\n        x = ov_opset.reshape(x, flatten_shape, False).output(0)\n        x_shape_tensor = ov_opset.shape_of(x, Type.i32).output(0)\n        k = ov_opset.reduce_prod(\n            x_shape_tensor, ov_opset.constant([0], Type.i32), keep_dims=False\n        )\n        axis = 0\n    else:\n        if axis < 0:\n            axis = rank + axis\n        x_shape_tensor = ov_opset.shape_of(x, Type.i32).output(0)\n        k = ov_opset.gather(\n            x_shape_tensor,\n            ov_opset.constant(axis, Type.i32).output(0),\n            ov_opset.constant(0, Type.i32).output(0),\n        ).output(0)\n    sorted_indices = ov_opset.topk(\n        x,\n        k=k,\n        axis=axis,\n        mode=\"min\",\n        sort=\"value\",\n    ).output(1)\n 
   return OpenVINOKerasTensor(sorted_indices)\n\n\ndef array(x, dtype=None):\n    return convert_to_tensor(x, dtype=dtype)\n\n\ndef view(x, dtype=None):\n    raise NotImplementedError(\"`view` is not supported with openvino backend\")\n\n\ndef average(x, axis=None, weights=None):\n    x = get_ov_output(x)\n    if weights is not None:\n        weights = get_ov_output(weights)\n    if axis is None:\n        flatten_shape = ov_opset.constant([-1], Type.i32).output(0)\n        x = ov_opset.reshape(x, flatten_shape, False).output(0)\n        if weights is not None:\n            weights = ov_opset.reshape(weights, flatten_shape, False).output(0)\n        axis = 0\n\n    if weights is not None:\n        x_type = x.get_element_type()\n        weights_type = weights.get_element_type()\n        if (weights_type.is_integral() or weights_type == Type.boolean) and (\n            x_type.is_integral() or x_type == Type.boolean\n        ):\n            x = ov_opset.convert(x, Type.f32).output(0)\n            weights = ov_opset.convert(weights, Type.f32).output(0)\n        x, weights = _align_operand_types(x, weights, \"multiply()\")\n        x = ov_opset.multiply(x, weights)\n\n    if isinstance(axis, tuple):\n        axis = list(axis)\n    if axis == []:\n        return OpenVINOKerasTensor(x)\n\n    axis_const = ov_opset.constant(axis, dtype=Type.i32).output(0)\n    mean_ops = ov_opset.reduce_mean(x, axis_const, False)\n    return OpenVINOKerasTensor(mean_ops.output(0))\n\n\ndef bartlett(x):\n    x = get_ov_output(x)\n    zero_const = ov_opset.constant(0, Type.i64)\n    one_const = ov_opset.constant(1, Type.i64)\n    two_const = ov_opset.constant(2, Type.i64)\n    two_const_f64 = ov_opset.constant(2.0, Type.f64)\n    if x.get_element_type() != Type.i64:\n        x = ov_opset.convert(x, Type.i64)\n    half = ov_opset.convert(\n        ov_opset.divide(ov_opset.subtract(x, one_const), two_const), Type.f64\n    )\n    n = ov_opset.range(zero_const, x, one_const, Type.f64)\n    
condition = ov_opset.less_equal(n, half)\n    first_half = ov_opset.divide(\n        ov_opset.multiply(two_const_f64, n),\n        ov_opset.convert(ov_opset.subtract(x, one_const), Type.f64),\n    )\n    second_half = ov_opset.subtract(two_const_f64, first_half)\n    window = ov_opset.select(condition, first_half, second_half)\n    window = ov_opset.convert(window, OPENVINO_DTYPES[config.floatx()]).output(\n        0\n    )\n    return OpenVINOKerasTensor(window)\n\n\ndef hamming(x):\n    m = get_ov_output(x)\n\n    m_i64 = (\n        m if m.get_element_type() == Type.i64 else ov_opset.convert(m, Type.i64)\n    )\n\n    start = ov_opset.constant(0, Type.i64)\n    step = ov_opset.constant(1, Type.i64)\n    n = ov_opset.range(start, m_i64, step, Type.f64)\n\n    one_i64 = ov_opset.constant(1, Type.i64)\n    denom_i64 = ov_opset.subtract(m_i64, one_i64)\n    denom = ov_opset.convert(denom_i64, Type.f64)\n\n    two_pi = ov_opset.constant(2.0 * np.pi, Type.f64)\n    two_pi_over_m_minus_1 = ov_opset.divide(two_pi, denom)\n\n    x = ov_opset.multiply(two_pi_over_m_minus_1, n)\n    c = ov_opset.cos(x)\n\n    # 0.54 - 0.46 * cos(...)\n    a = ov_opset.constant(0.54, Type.f64)\n    b = ov_opset.constant(0.46, Type.f64)\n    hamming_window = ov_opset.subtract(a, ov_opset.multiply(b, c))\n    hamming_window = ov_opset.convert(\n        hamming_window, OPENVINO_DTYPES[config.floatx()]\n    )\n\n    return OpenVINOKerasTensor(hamming_window.output(0))\n\n\ndef hanning(x):\n    m = get_ov_output(x)\n\n    m_i64 = (\n        m if m.get_element_type() == Type.i64 else ov_opset.convert(m, Type.i64)\n    )\n\n    start = ov_opset.constant(0, Type.i64)\n    step = ov_opset.constant(1, Type.i64)\n    n = ov_opset.range(start, m_i64, step, Type.f64)\n\n    one_i64 = ov_opset.constant(1, Type.i64)\n    denom_i64 = ov_opset.subtract(m_i64, one_i64)\n    denom = ov_opset.convert(denom_i64, Type.f64)\n\n    # Handle M=1 case to avoid division by zero\n    one_f64 = ov_opset.constant(1.0, 
Type.f64)\n    is_zero = ov_opset.equal(denom_i64, ov_opset.constant(0, Type.i64))\n    safe_denom = ov_opset.select(is_zero, one_f64, denom)\n\n    two_pi = ov_opset.constant(2.0 * np.pi, Type.f64)\n    two_pi_over_m_minus_1 = ov_opset.divide(two_pi, safe_denom)\n\n    x = ov_opset.multiply(two_pi_over_m_minus_1, n)\n    c = ov_opset.cos(x)\n\n    # 0.5 - 0.5 * cos(...)\n    a = ov_opset.constant(0.5, Type.f64)\n    b = ov_opset.constant(0.5, Type.f64)\n    hanning_window = ov_opset.subtract(a, ov_opset.multiply(b, c))\n\n    # Fix for M=1: NumPy returns [1.], but formula gives [0.]\n    # Broadcast 1.0 to the shape of hanning_window\n    ones = ov_opset.broadcast(one_f64, ov_opset.shape_of(hanning_window))\n    hanning_window = ov_opset.select(is_zero, ones, hanning_window)\n\n    hanning_window = ov_opset.convert(\n        hanning_window, OPENVINO_DTYPES[config.floatx()]\n    )\n\n    return OpenVINOKerasTensor(hanning_window.output(0))\n\n\ndef heaviside(x1, x2):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    x1, x2 = _align_operand_types(x1, x2, \"heaviside()\")\n\n    x_type = ov_to_keras_type(x1.get_element_type())\n    if x_type in [\n        \"int8\",\n        \"int16\",\n        \"int32\",\n        \"uint8\",\n        \"uint16\",\n        \"uint32\",\n        \"bool\",\n    ]:\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x1 = ov_opset.convert(x1, ov_type).output(0)\n        x2 = ov_opset.convert(x2, ov_type).output(0)\n    elif x_type in [\"int64\", \"uint64\"]:\n        ov_type = OPENVINO_DTYPES[\"float64\"]\n        x1 = ov_opset.convert(x1, ov_type).output(0)\n        x2 = ov_opset.convert(x2, ov_type).output(0)\n\n    x1_type = x1.get_element_type()\n\n    zero_scalar = ov_opset.constant(0, x1_type).output(0)\n    one_scalar = ov_opset.constant(1, x1_type).output(0)\n\n    neg = ov_opset.less(x1, zero_scalar).output(0)\n    pos = ov_opset.greater(x1, zero_scalar).output(0)\n    eq = ov_opset.equal(x1, 
zero_scalar).output(0)\n\n    x = ov_opset.select(neg, zero_scalar, x1).output(0)\n    x = ov_opset.select(pos, one_scalar, x).output(0)\n    x = ov_opset.select(eq, x2, x).output(0)\n    return OpenVINOKerasTensor(x)\n\n\ndef _i0_node(x):\n    x = ov_opset.abs(x).output(0)\n    x_type = x.get_element_type()\n    three_point_seven_five = ov_opset.constant(3.75, x_type).output(0)\n    p1_coeffs = [\n        1.0,\n        3.5156229,\n        3.0899424,\n        1.2067492,\n        0.2659732,\n        0.0360768,\n        0.0045813,\n    ]\n    p2_coeffs = [\n        0.39894228,\n        0.01328592,\n        0.00225319,\n        -0.00157565,\n        0.00916281,\n        -0.02057706,\n        0.02635537,\n        -0.01647633,\n        0.00392377,\n    ]\n    t_A = ov_opset.divide(x, three_point_seven_five).output(0)\n    t_A = ov_opset.multiply(t_A, t_A).output(0)\n    res_A = ov_opset.constant(p1_coeffs[6], x_type).output(0)\n    for i in range(5, -1, -1):\n        c = ov_opset.constant(p1_coeffs[i], x_type).output(0)\n        res_A = ov_opset.add(ov_opset.multiply(res_A, t_A), c).output(0)\n    safe_x = ov_opset.maximum(x, three_point_seven_five).output(0)\n    t_B = ov_opset.divide(three_point_seven_five, safe_x).output(0)\n    res_B = ov_opset.constant(p2_coeffs[8], x_type).output(0)\n    for i in range(7, -1, -1):\n        c = ov_opset.constant(p2_coeffs[i], x_type).output(0)\n        res_B = ov_opset.add(ov_opset.multiply(res_B, t_B), c).output(0)\n    exp_x = ov_opset.exp(x).output(0)\n    sqrt_safe_x = ov_opset.sqrt(safe_x).output(0)\n    factor = ov_opset.divide(exp_x, sqrt_safe_x).output(0)\n    res_B = ov_opset.multiply(factor, res_B).output(0)\n    condition = ov_opset.less_equal(x, three_point_seven_five).output(0)\n    result = ov_opset.select(condition, res_A, res_B).output(0)\n\n    return result\n\n\ndef kaiser(x, beta):\n    m = get_ov_output(x)\n    beta = get_ov_output(beta)\n    if m.get_element_type() != Type.i64:\n        m_i64 = 
ov_opset.convert(m, Type.i64).output(0)\n    else:\n        m_i64 = m\n    calc_type = Type.f64\n    if m.get_element_type() != calc_type:\n        m_float = ov_opset.convert(m, calc_type).output(0)\n    else:\n        m_float = m\n    if beta.get_element_type() != calc_type:\n        beta = ov_opset.convert(beta, calc_type).output(0)\n    start = ov_opset.constant(0, Type.i64).output(0)\n    step = ov_opset.constant(1, Type.i64).output(0)\n    n = ov_opset.range(start, m_i64, step, calc_type).output(0)\n    one_float = ov_opset.constant(1.0, calc_type).output(0)\n    two_float = ov_opset.constant(2.0, calc_type).output(0)\n    alpha = ov_opset.divide(\n        ov_opset.subtract(m_float, one_float), two_float\n    ).output(0)\n    zero_float = ov_opset.constant(0.0, calc_type).output(0)\n    is_alpha_zero = ov_opset.equal(alpha, zero_float).output(0)\n    safe_alpha = ov_opset.select(is_alpha_zero, one_float, alpha).output(0)\n    val = ov_opset.divide(ov_opset.subtract(n, alpha), safe_alpha).output(0)\n    val_sq = ov_opset.multiply(val, val).output(0)\n    term = ov_opset.subtract(one_float, val_sq).output(0)\n    term = ov_opset.maximum(term, zero_float).output(0)\n    sqrt_term = ov_opset.sqrt(term).output(0)\n    arg = ov_opset.multiply(beta, sqrt_term).output(0)\n    num = _i0_node(arg)\n    den = _i0_node(beta)\n    result = ov_opset.divide(num, den).output(0)\n    result = ov_opset.convert(result, OPENVINO_DTYPES[config.floatx()]).output(\n        0\n    )\n    return OpenVINOKerasTensor(result)\n\n\ndef bitwise_left_shift(x, y):\n    element_type = None\n    if isinstance(x, OpenVINOKerasTensor):\n        element_type = x.output.get_element_type()\n    if isinstance(y, OpenVINOKerasTensor):\n        element_type = y.output.get_element_type()\n    x = get_ov_output(x, element_type)\n    y = get_ov_output(y, element_type)\n    x, y = _align_operand_types(x, y, \"bitwise_left_shift()\")\n    return OpenVINOKerasTensor(ov_opset.bitwise_left_shift(x, 
y).output(0))\n\n\ndef left_shift(x, y):\n    return bitwise_left_shift(x, y)\n\n\ndef bitwise_right_shift(x, y):\n    element_type = None\n    if isinstance(x, OpenVINOKerasTensor):\n        element_type = x.output.get_element_type()\n    if isinstance(y, OpenVINOKerasTensor):\n        element_type = y.output.get_element_type()\n    x = get_ov_output(x, element_type)\n    y = get_ov_output(y, element_type)\n    x, y = _align_operand_types(x, y, \"bitwise_right_shift()\")\n    return OpenVINOKerasTensor(ov_opset.bitwise_right_shift(x, y).output(0))\n\n\ndef right_shift(x, y):\n    return bitwise_right_shift(x, y)\n\n\ndef bincount(x, weights=None, minlength=0, sparse=False):\n    if x is None:\n        raise ValueError(\"input x is None\")\n    if sparse:\n        raise ValueError(\"Unsupported value `sparse=True`\")\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    shape_x = ov_opset.shape_of(x, \"i64\").output(0)\n    rank_x = ov_opset.shape_of(shape_x, \"i64\").output(0)\n    rank_x = ov_opset.convert(rank_x, x_type).output(0)\n    scalar_shape = ov_opset.constant([], x_type).output(0)\n    rank_x = ov_opset.reshape(rank_x, scalar_shape, False).output(0)\n    const_minus_one = ov_opset.constant(-1, x_type).output(0)\n    rank_minus_one = ov_opset.add(rank_x, const_minus_one).output(0)\n    minlength = get_ov_output(minlength)\n    minlength = ov_opset.convert(minlength, x_type).output(0)\n    const_one = ov_opset.constant(1, x_type).output(0)\n    const_zero = ov_opset.constant(0, x_type).output(0)\n    max_element = ov_opset.reduce_max(x, const_zero, keep_dims=False).output(0)\n    depth = ov_opset.add(max_element, const_one).output(0)\n    depth = ov_opset.maximum(depth, minlength).output(0)\n    depth_scalar = ov_opset.reduce_max(\n        depth, const_zero, keep_dims=False\n    ).output(0)\n    one_hot = ov_opset.one_hot(\n        x, depth_scalar, const_one, const_zero, axis=-1\n    ).output(0)\n    if weights is not None:\n        weights = 
get_ov_output(weights)\n        weights_type = weights.get_element_type()\n        weights_new = ov_opset.reshape(weights, [-1, 1], False).output(0)\n        one_hot = ov_opset.convert(one_hot, weights_type).output(0)\n        final_one_hot = ov_opset.multiply(one_hot, weights_new).output(0)\n        final_output = ov_opset.reduce_sum(\n            final_one_hot, rank_minus_one, keep_dims=False\n        ).output(0)\n        return OpenVINOKerasTensor(final_output)\n    else:\n        final_output = ov_opset.reduce_sum(\n            one_hot, rank_minus_one, keep_dims=False\n        ).output(0)\n        final_output = ov_opset.convert(final_output, Type.i32).output(0)\n        return OpenVINOKerasTensor(final_output)\n\n\ndef _bitwise_op_i8u8(ov_op, x, y):\n    \"\"\"Apply an OV bitwise op, working around a SIMD bug in 8-bit kernels.\n\n    OpenVINO's int8/uint8 bitwise kernels have a vectorization bug: when the\n    last dimension is >= 32, elements at non-stride-4 positions (0..31 range)\n    receive wrong values.  
Casting to int32/uint32 avoids the buggy kernel.\n    \"\"\"\n    elem_type = x.get_element_type()\n    if elem_type in (Type.i8, Type.u8):\n        cast_type = Type.i32 if elem_type == Type.i8 else Type.u32\n        x = ov_opset.convert(x, cast_type).output(0)\n        y = ov_opset.convert(y, cast_type).output(0)\n        result = ov_op(x, y).output(0)\n        return OpenVINOKerasTensor(\n            ov_opset.convert(result, elem_type).output(0)\n        )\n    return OpenVINOKerasTensor(ov_op(x, y).output(0))\n\n\ndef bitwise_and(x, y):\n    x = get_ov_output(x)\n    y = get_ov_output(y)\n    x, y = _align_operand_types(x, y, \"bitwise_and()\")\n    return _bitwise_op_i8u8(ov_opset.bitwise_and, x, y)\n\n\ndef bitwise_xor(x, y):\n    x = get_ov_output(x)\n    y = get_ov_output(y)\n    x, y = _align_operand_types(x, y, \"bitwise_xor()\")\n    return _bitwise_op_i8u8(ov_opset.bitwise_xor, x, y)\n\n\ndef bitwise_invert(x):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.bitwise_not(x).output(0))\n\n\ndef bitwise_not(x):\n    return bitwise_invert(x)\n\n\ndef bitwise_or(x, y):\n    x = get_ov_output(x)\n    y = get_ov_output(y)\n    x, y = _align_operand_types(x, y, \"bitwise_or()\")\n    return _bitwise_op_i8u8(ov_opset.bitwise_or, x, y)\n\n\ndef blackman(x):\n    x = get_ov_output(x)\n    zero_const = ov_opset.constant(0, Type.i64)\n    one_const = ov_opset.constant(1, Type.i64)\n    two_pi = ov_opset.constant(2.0 * np.pi, Type.f64)\n    term_1 = ov_opset.constant(0.42, Type.f64)\n    term_2 = ov_opset.constant(0.5, Type.f64)\n    term_3 = ov_opset.constant(0.08, Type.f64)\n    if x.get_element_type() != Type.i64:\n        x = ov_opset.convert(x, Type.i64)\n    n = ov_opset.range(zero_const, x, one_const, Type.f64)\n    n_minus_1 = ov_opset.subtract(\n        ov_opset.convert(x, Type.f64), ov_opset.constant(1.0, Type.f64)\n    ).output(0)\n    angle_2pi = ov_opset.divide(ov_opset.multiply(two_pi, n), n_minus_1)\n    angle_4pi = 
ov_opset.multiply(angle_2pi, ov_opset.constant(2.0, Type.f64))\n    cos_2pi = ov_opset.cos(angle_2pi)\n    cos_4pi = ov_opset.cos(angle_4pi)\n    term_2_final = ov_opset.multiply(term_2, cos_2pi)\n    term_3_final = ov_opset.multiply(term_3, cos_4pi)\n    window = ov_opset.add(ov_opset.subtract(term_1, term_2_final), term_3_final)\n    window = ov_opset.convert(window, OPENVINO_DTYPES[config.floatx()]).output(\n        0\n    )\n    return OpenVINOKerasTensor(window)\n\n\ndef broadcast_to(x, shape):\n    if not isinstance(shape, (tuple, list)):\n        raise ValueError(\n            f\"`broadcast_to` is supported only for tuple and list `shape`. \"\n            f\"Received: shape={shape} (type {type(shape)})\"\n        )\n    target_shape = ov_opset.constant(list(shape), Type.i32).output(0)\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.broadcast(x, target_shape).output(0))\n\n\ndef cbrt(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral() or x_type == Type.boolean:\n        x = ov_opset.convert(x, OPENVINO_DTYPES[config.floatx()]).output(0)\n    sign_x = ov_opset.sign(x)\n    abs_x = ov_opset.absolute(x)\n    one_third = ov_opset.constant(1.0 / 3.0, x.get_element_type())\n    root_abs = ov_opset.power(abs_x, one_third)\n    res = ov_opset.multiply(sign_x, root_abs)\n    return OpenVINOKerasTensor(res.output(0))\n\n\ndef ceil(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        x = ov_opset.convert(x, OPENVINO_DTYPES[config.floatx()]).output(0)\n    ceiling = ov_opset.ceil(x).output(0)\n    return OpenVINOKerasTensor(ceiling)\n\n\ndef clip(x, x_min, x_max):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type == Type.boolean:\n        x = ov_opset.convert(x, Type.i32).output(0)\n    x_min = get_ov_output(x_min, x.get_element_type())\n    x_max = get_ov_output(x_max, x.get_element_type())\n    clip_by_min = ov_opset.maximum(x, 
x_min).output(0)\n    clip_by_max = ov_opset.minimum(clip_by_min, x_max).output(0)\n    return OpenVINOKerasTensor(clip_by_max)\n\n\ndef concatenate(xs, axis=0):\n    elems = [get_ov_output(x) for x in xs]\n    if axis is None:\n        flatten_shape = ov_opset.constant([-1], Type.i32).output(0)\n        elems = [\n            ov_opset.reshape(x, flatten_shape, False).output(0) for x in elems\n        ]\n        axis = 0\n    keras_types = [ov_to_keras_type(x.get_element_type()) for x in elems]\n    if keras_types:\n        target_type = dtypes.result_type(*keras_types)\n        ov_target_type = OPENVINO_DTYPES[target_type]\n        elems = [\n            ov_opset.convert(x, ov_target_type).output(0)\n            if x.get_element_type() != ov_target_type\n            else x\n            for x in elems\n        ]\n    res = ov_opset.concat(elems, axis).output(0)\n    return OpenVINOKerasTensor(res)\n\n\ndef conjugate(x):\n    # TODO: Implement complex support when OpenVINO adds complex dtypes.\n    # Currently, all supported dtypes are real-valued.\n    return convert_to_tensor(x)\n\n\ndef conj(x):\n    return conjugate(x)\n\n\ndef copy(x):\n    return x\n\n\ndef cos(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return OpenVINOKerasTensor(ov_opset.cos(x).output(0))\n\n\ndef cosh(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return OpenVINOKerasTensor(ov_opset.cosh(x).output(0))\n\n\ndef count_nonzero(x, axis=None):\n    x = get_ov_output(x)\n    zero_constant = ov_opset.constant(0, dtype=Type.i32).output(0)\n    zero_constant = ov_opset.convert_like(zero_constant, x)\n    x = ov_opset.not_equal(x, zero_constant).output(0)\n    x = ov_opset.convert(x, Type.i32).output(0)\n    
x, axis = _resolve_axis(x, axis)\n    if not axis:\n        return OpenVINOKerasTensor(x)\n    return OpenVINOKerasTensor(ov_opset.reduce_sum(x, axis, False).output(0))\n\n\ndef cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):\n    if axis is not None:\n        axisa = axisb = axisc = axis\n\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n\n    x1, x2 = _align_operand_types(x1, x2, \"cross()\")\n\n    shape1 = x1.get_partial_shape()\n    shape2 = x2.get_partial_shape()\n\n    # Rank Normalization\n    rank1 = shape1.rank.get_length()\n    rank2 = shape2.rank.get_length()\n\n    axisa = canonicalize_axis(axisa, rank1)\n    axisb = canonicalize_axis(axisb, rank2)\n    axisc = canonicalize_axis(axisc, rank1 if rank1 > rank2 else rank2)\n\n    d1 = shape1[axisa].get_length()\n    d2 = shape2[axisb].get_length()\n\n    if d1 not in (2, 3) or d2 not in (2, 3):\n        raise ValueError(\n            \"Dimension of vectors for cross product must be 2 or 3. \"\n            f\"Got dimensions {d1} and {d2} for inputs x1 and x2.\"\n        )\n\n    # Pad to 3D by adding a zero component.\n    def pad_to_3d(x, dim, ax):\n        if dim == 3:\n            return x\n\n        # Create a slice of zeros with the same type as x\n        slice0 = ov_opset.gather(\n            x,\n            ov_opset.constant([0], Type.i32),\n            ov_opset.constant(ax, Type.i32),\n        )\n        zeros = ov_opset.multiply(\n            slice0,\n            ov_opset.constant(0, x.get_element_type()),\n        )\n\n        return ov_opset.concat([x, zeros], ax)\n\n    x1_3d = pad_to_3d(x1, d1, axisa)\n    x2_3d = pad_to_3d(x2, d2, axisb)\n\n    # Split Vectors\n    u = ov_opset.split(x1_3d, ov_opset.constant(axisa, Type.i32), 3).outputs()\n    v = ov_opset.split(x2_3d, ov_opset.constant(axisb, Type.i32), 3).outputs()\n\n    # u x v = (u2*v3 - u3*v2, u3*v1 - u1*v3, u1*v2 - u2*v1)\n    res_x = ov_opset.subtract(\n        ov_opset.multiply(u[1], v[2]), ov_opset.multiply(u[2], 
v[1])\n    )\n    res_y = ov_opset.subtract(\n        ov_opset.multiply(u[2], v[0]), ov_opset.multiply(u[0], v[2])\n    )\n    res_z = ov_opset.subtract(\n        ov_opset.multiply(u[0], v[1]), ov_opset.multiply(u[1], v[0])\n    )\n\n    # If dim was 2D, we remove the padded zero component.\n    if d1 == 2 and d2 == 2:\n        result = res_z\n        result = ov_opset.squeeze(result, ov_opset.constant([axisc], Type.i32))\n    else:\n        result = ov_opset.concat([res_x, res_y, res_z], axisc)\n\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef cumprod(x, axis=None, dtype=None):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n\n    # Determine output dtype following numpy backend logic\n    if dtype is not None:\n        ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)]\n        if ov_type == Type.boolean:\n            ov_type = Type.i32\n    else:\n        ov_type = x_type\n        if ov_type == Type.boolean:\n            ov_type = Type.i32\n\n    # Convert boolean to int32 for computation\n    if x_type == Type.boolean:\n        x = ov_opset.convert(x, Type.i32).output(0)\n        x_type = Type.i32\n\n    compute_as_float = False\n    if x_type.is_integral():\n        compute_dtype = Type.f32\n        x = ov_opset.convert(x, compute_dtype).output(0)\n        compute_as_float = True\n    else:\n        compute_dtype = x_type\n\n    x, axis = _resolve_axis(x, axis)\n\n    signs = ov_opset.sign(x).output(0)\n\n    is_zero_sign = ov_opset.equal(\n        signs, ov_opset.constant(0, compute_dtype)\n    ).output(0)\n    signs_no_zeros = ov_opset.select(\n        is_zero_sign, ov_opset.constant(1, compute_dtype), signs\n    ).output(0)\n\n    is_negative = ov_opset.less(\n        signs_no_zeros, ov_opset.constant(0, compute_dtype)\n    ).output(0)\n    num_negatives = ov_opset.cumsum(\n        ov_opset.convert(is_negative, Type.i32), axis\n    ).output(0)\n    is_odd = ov_opset.mod(num_negatives, ov_opset.constant(2, Type.i32)).output(\n       
 0\n    )\n\n    cum_sign = ov_opset.subtract(\n        ov_opset.constant(1, Type.i32),\n        ov_opset.multiply(ov_opset.constant(2, Type.i32), is_odd),\n    ).output(0)\n    cum_sign = ov_opset.convert(cum_sign, compute_dtype).output(0)\n\n    abs_x = ov_opset.absolute(x).output(0)\n    is_zero_abs = ov_opset.equal(\n        abs_x, ov_opset.constant(0, compute_dtype)\n    ).output(0)\n    abs_x_safe = ov_opset.select(\n        is_zero_abs, ov_opset.constant(1, compute_dtype), abs_x\n    ).output(0)\n\n    log_abs_x = ov_opset.log(abs_x_safe).output(0)\n    cumsum_log_abs = ov_opset.cumsum(log_abs_x, axis).output(0)\n    cumprod_abs = ov_opset.exp(cumsum_log_abs).output(0)\n\n    result = ov_opset.multiply(cumprod_abs, cum_sign).output(0)\n\n    is_zero = ov_opset.equal(x, ov_opset.constant(0, compute_dtype)).output(0)\n    has_zero_before = ov_opset.cumsum(\n        ov_opset.convert(is_zero, Type.i32), axis\n    ).output(0)\n    zero_mask = ov_opset.equal(\n        has_zero_before, ov_opset.constant(0, Type.i32)\n    ).output(0)\n    result = ov_opset.multiply(\n        result, ov_opset.convert(zero_mask, compute_dtype)\n    ).output(0)\n\n    if compute_as_float and ov_type.is_integral():\n        result = ov_opset.round(result).output(0)\n\n    if result.get_element_type() != ov_type:\n        result = ov_opset.convert(result, ov_type).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef cumsum(x, axis=None, dtype=None):\n    x = get_ov_output(x)\n    if dtype is not None:\n        ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)]\n        x = ov_opset.convert(x, ov_type).output(0)\n    x, axis = _resolve_axis(x, axis)\n    if x.get_element_type() == Type.boolean:\n        x = ov_opset.convert(x, Type.i32).output(0)\n    return OpenVINOKerasTensor(ov_opset.cumsum(x, axis).output(0))\n\n\ndef deg2rad(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    pi_over_180 = np.pi / 180.0\n\n    if x_type == Type.i64:\n        output_type = 
Type.f64\n    elif x_type.is_integral():\n        output_type = OPENVINO_DTYPES[config.floatx()]\n    else:\n        output_type = x_type\n\n    if x_type != output_type:\n        x = ov_opset.convert(x, output_type)\n\n    const_pi_over_180 = ov_opset.constant(pi_over_180, output_type).output(0)\n    result = ov_opset.multiply(x, const_pi_over_180).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef diag(x, k=0):\n    x = get_ov_output(x)\n    x_shape = x.get_partial_shape()\n    rank = x_shape.rank.get_length()\n\n    if rank == 1:\n        N_dim = x_shape[0]\n        if not N_dim.is_static:\n            raise ValueError(\n                \"diag requires input with static shape for 1D input.\"\n            )\n        N = N_dim.get_length()\n        output_size = N + np.abs(k)\n        out_shape = ov_opset.constant(\n            [output_size, output_size], dtype=Type.i32\n        ).output(0)\n        zeros_const = ov_opset.constant(0, x.get_element_type()).output(0)\n        diag_matrix = ov_opset.broadcast(zeros_const, out_shape)\n\n        indices = []\n        if k >= 0:\n            for i in range(N):\n                indices.append([i, i + k])\n        else:\n            for i in range(N):\n                indices.append([i - k, i])\n\n        indices = np.array(indices, dtype=np.int32)\n        indices_const = ov_opset.constant(indices, dtype=Type.i32).output(0)\n        updated = ov_opset.scatter_nd_update(diag_matrix, indices_const, x)\n        return OpenVINOKerasTensor(updated.output(0))\n\n    elif rank == 2:\n        M_dim = x_shape[0]\n        N_dim = x_shape[1]\n        if not M_dim.is_static or not N_dim.is_static:\n            raise ValueError(\n                \"diag requires input with static shape for 2D input.\"\n            )\n        M = M_dim.get_length()\n        N = N_dim.get_length()\n\n        if k >= 0:\n            L = np.minimum(M, N - k) if (N - k) > 0 else 0\n            indices = [[i, i + k] for i in range(L)]\n        
else:\n            L = np.minimum(M + k, N) if (M + k) > 0 else 0\n            indices = [[i - k, i] for i in range(L)]\n\n        if L <= 0:\n            keras_dtype = ov_to_keras_type(x.get_element_type())\n            np_dtype = np.dtype(keras_dtype)\n            empty_np = np.empty((0,), dtype=np_dtype)\n            empty_const = ov_opset.constant(\n                empty_np, x.get_element_type()\n            ).output(0)\n            return OpenVINOKerasTensor(empty_const)\n\n        indices = np.array(indices, dtype=np.int32)\n        indices_const = ov_opset.constant(indices, dtype=Type.i32).output(0)\n        diag_vec = ov_opset.gather_nd(x, indices_const)\n        return OpenVINOKerasTensor(diag_vec.output(0))\n\n    else:\n        raise ValueError(\"diag supports only 1D or 2D tensors\")\n\n\ndef diagflat(x, k=0):\n    x = get_ov_output(x)\n\n    flatten_shape = ov_opset.constant([-1], dtype=Type.i32).output(0)\n    v_flat = ov_opset.reshape(x, flatten_shape, False).output(0)\n\n    v_flat_shape = ov_opset.shape_of(v_flat, Type.i32).output(0)\n    zero_node = ov_opset.constant(0, dtype=Type.i32).output(0)\n    n = ov_opset.gather(v_flat_shape, zero_node, zero_node).output(0)\n\n    k_val = int(k)\n    if k_val < 0:\n        abs_k = -k_val\n    else:\n        abs_k = k_val\n\n    n_plus_k = ov_opset.add(\n        n, ov_opset.constant(abs_k, dtype=Type.i32).output(0)\n    ).output(0)\n\n    target_shape_vec = ov_opset.concat(\n        [\n            ov_opset.reshape(\n                n_plus_k,\n                ov_opset.constant([1], dtype=Type.i32).output(0),\n                False,\n            ).output(0),\n            ov_opset.reshape(\n                n_plus_k,\n                ov_opset.constant([1], dtype=Type.i32).output(0),\n                False,\n            ).output(0),\n        ],\n        0,\n    ).output(0)\n\n    v_type = x.get_element_type()\n    zero_const = ov_opset.constant(0, dtype=v_type).output(0)\n\n    zeros_mat = 
ov_opset.broadcast(zero_const, target_shape_vec).output(0)\n\n    one_node = ov_opset.constant(1, dtype=Type.i32).output(0)\n    rng = ov_opset.range(zero_node, n, one_node, Type.i32).output(0)\n\n    k_const = ov_opset.constant(k_val, dtype=Type.i32).output(0)\n\n    if k_val >= 0:\n        rows = rng\n        cols = ov_opset.add(rng, k_const).output(0)\n    else:\n        neg_k_const = ov_opset.constant(-k_val, dtype=Type.i32).output(0)\n        rows = ov_opset.add(rng, neg_k_const).output(0)\n        cols = rng\n\n    rows_expanded = ov_opset.reshape(rows, [-1, 1], False).output(0)\n    cols_expanded = ov_opset.reshape(cols, [-1, 1], False).output(0)\n    indices = ov_opset.concat([rows_expanded, cols_expanded], 1).output(0)\n\n    result = ov_opset.scatter_nd_update(zeros_mat, indices, v_flat).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef diagonal(x, offset=0, axis1=0, axis2=1):\n    x = get_ov_output(x)\n    shape = x.get_partial_shape()\n    if not shape.rank.is_static:\n        raise ValueError(\"`diagonal` requires input tensor with static rank.\")\n    rank = shape.rank.get_length()\n    if rank < 2:\n        raise ValueError(\n            f\"diagonal requires input tensor with rank >= 2. Given rank: {rank}\"\n        )\n    axis1 = canonicalize_axis(axis1, rank)\n    axis2 = canonicalize_axis(axis2, rank)\n    if axis1 == axis2:\n        raise ValueError(\"`axis1` and `axis2` cannot be the same.\")\n\n    perm_order = [axis1, axis2] + [\n        i for i in range(rank) if i != axis1 and i != axis2\n    ]\n    perm_const = ov_opset.constant(perm_order, dtype=Type.i32).output(0)\n    x_transposed = ov_opset.transpose(x, perm_const)\n\n    N_dim = shape[axis1]\n    M_dim = shape[axis2]\n    if not N_dim.is_static or not M_dim.is_static:\n        raise ValueError(\n            \"`diagonal` requires input tensor with static shape for axes \"\n            f\"`axis1` ({axis1}) and `axis2` ({axis2}).\"\n        )\n    N = N_dim.get_length()\n    M 
= M_dim.get_length()\n    if offset >= 0:\n        L = np.minimum(N, M - offset) if (M - offset) > 0 else 0\n        indices = [[i, i + offset] for i in range(L)]\n    else:\n        L = np.minimum(N + offset, M) if (N + offset) > 0 else 0\n        indices = [[i - offset, i] for i in range(L)]\n\n    indices = np.array(indices, dtype=np.int32).reshape(L, 2)\n    indices_const = ov_opset.constant(indices, dtype=Type.i32).output(0)\n\n    diag_gathered = ov_opset.gather_nd(x_transposed, indices_const)\n\n    out_rank = rank - 1\n    out_perm_order = list(range(1, out_rank)) + [0]\n    out_perm_const = ov_opset.constant(out_perm_order, dtype=Type.i32).output(0)\n\n    final_output = ov_opset.transpose(diag_gathered, out_perm_const)\n    return OpenVINOKerasTensor(final_output.output(0))\n\n\ndef diff(a, n=1, axis=-1):\n    if n == 0:\n        return OpenVINOKerasTensor(get_ov_output(a))\n    if n < 0:\n        raise ValueError(f\"order must be non-negative but got {repr(n)}\")\n    a = get_ov_output(a)\n    a_type = a.get_element_type()\n    # `a` is always an OpenVINO output after `get_ov_output`, so query its\n    # rank directly.\n    rank = a.get_partial_shape().rank.get_length()\n    if axis < 0:\n        axis = axis + rank\n    result = a\n    for _ in range(n):\n        rank = result.get_partial_shape().rank.get_length()\n        strides = ov_opset.constant(\n            np.array([1] * rank, dtype=np.int64), Type.i64\n        ).output(0)\n\n        begin_upper_list = [0] * rank\n        begin_upper_list[axis] = 1\n        begin_upper = ov_opset.constant(\n            np.array(begin_upper_list, dtype=np.int64), Type.i64\n        ).output(0)\n        end_upper = ov_opset.constant(\n            np.array([0] * rank, dtype=np.int64), Type.i64\n        ).output(0)\n        begin_mask_upper = [1] * rank\n        begin_mask_upper[axis] = 0\n        end_mask_upper = [1] * rank\n        upper = ov_opset.strided_slice(\n            data=result,\n            begin=begin_upper,\n            
end=end_upper,\n            strides=strides,\n            begin_mask=begin_mask_upper,\n            end_mask=end_mask_upper,\n            new_axis_mask=[],\n            shrink_axis_mask=[],\n            ellipsis_mask=[],\n        ).output(0)\n\n        begin_lower = ov_opset.constant(\n            np.array([0] * rank, dtype=np.int64), Type.i64\n        ).output(0)\n        end_lower_list = [0] * rank\n        end_lower_list[axis] = -1\n        end_lower = ov_opset.constant(\n            np.array(end_lower_list, dtype=np.int64), Type.i64\n        ).output(0)\n        begin_mask_lower = [1] * rank\n        end_mask_lower = [1] * rank\n        end_mask_lower[axis] = 0\n        lower = ov_opset.strided_slice(\n            data=result,\n            begin=begin_lower,\n            end=end_lower,\n            strides=strides,\n            begin_mask=begin_mask_lower,\n            end_mask=end_mask_lower,\n            new_axis_mask=[],\n            shrink_axis_mask=[],\n            ellipsis_mask=[],\n        ).output(0)\n\n        if a_type == Type.boolean:\n            result = ov_opset.not_equal(upper, lower).output(0)\n        else:\n            result = ov_opset.subtract(upper, lower).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef digitize(x, bins):\n    x_node = get_ov_output(x)\n\n    if isinstance(bins, OpenVINOKerasTensor):\n        bins_node = get_ov_output(bins)\n    else:\n        bins_np = np.asarray(bins)\n        if bins_np.ndim != 1:\n            raise ValueError(\"`bins` must be 1-D array-like\")\n        bins_node = ov_opset.constant(bins_np).output(0)\n\n    x_node, bins_node = _align_operand_types(x_node, bins_node, \"digitize()\")\n\n    if x_node.get_element_type() == Type.boolean:\n        x_node = ov_opset.convert(x_node, Type.f32).output(0)\n        bins_node = ov_opset.convert(bins_node, Type.f32).output(0)\n\n    result = ov_opset.bucketize(\n        x_node,\n        bins_node,\n        output_type=Type.i32,\n        
with_right_bound=False,\n    ).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef dot(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"dot()\")\n    if x1.get_partial_shape().rank == 0 or x2.get_partial_shape().rank == 0:\n        return OpenVINOKerasTensor(ov_opset.multiply(x1, x2).output(0))\n    return OpenVINOKerasTensor(ov_opset.matmul(x1, x2, False, False).output(0))\n\n\ndef dstack(xs):\n    if not isinstance(xs, (list, tuple)):\n        xs = (xs,)\n    elems = [convert_to_tensor(elem) for elem in xs]\n    element_type = elems[0].output.get_element_type()\n    elems = [get_ov_output(elem, element_type) for elem in elems]\n\n    processed_elems = []\n    for elem in elems:\n        shape = elem.get_partial_shape()\n        rank = shape.rank\n        shape_len = rank.get_length()\n        if shape_len == 0:\n            elem = ov_opset.unsqueeze(\n                elem, ov_opset.constant(0, Type.i32)\n            ).output(0)\n            elem = ov_opset.unsqueeze(\n                elem, ov_opset.constant(1, Type.i32)\n            ).output(0)\n            elem = ov_opset.unsqueeze(\n                elem, ov_opset.constant(2, Type.i32)\n            ).output(0)\n        elif shape_len == 1:\n            elem = ov_opset.unsqueeze(\n                elem, ov_opset.constant(0, Type.i32)\n            ).output(0)\n            elem = ov_opset.unsqueeze(\n                elem, ov_opset.constant(2, Type.i32)\n            ).output(0)\n        elif shape_len == 2:\n            elem = ov_opset.unsqueeze(\n                elem, ov_opset.constant(2, Type.i32)\n            ).output(0)\n        processed_elems.append(elem)\n\n    for i in range(1, 
len(processed_elems)):\n        processed_elems[0], processed_elems[i] = _align_operand_types(\n            processed_elems[0], processed_elems[i], \"dstack()\"\n        )\n    return OpenVINOKerasTensor(ov_opset.concat(processed_elems, 2).output(0))\n\n\ndef empty(shape, dtype=None):\n    dtype = standardize_dtype(dtype) or config.floatx()\n    ov_type = OPENVINO_DTYPES[dtype]\n    if isinstance(shape, tuple):\n        shape = list(shape)\n    elif isinstance(shape, int):\n        shape = [shape]\n    shape_node = ov_opset.constant(shape, Type.i32).output(0)\n    const_zero = ov_opset.constant(0, dtype=ov_type).output(0)\n    empty_tensor = ov_opset.broadcast(const_zero, shape_node).output(0)\n    return OpenVINOKerasTensor(empty_tensor)\n\n\ndef empty_like(x, dtype=None):\n    return zeros_like(x, dtype=dtype)\n\n\ndef equal(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"equal()\")\n    return OpenVINOKerasTensor(ov_opset.equal(x1, x2).output(0))\n\n\ndef exp(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return OpenVINOKerasTensor(ov_opset.exp(x).output(0))\n\n\ndef exp2(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral() or x_type == Type.boolean:\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type).output(0)\n    two = ov_opset.constant(2.0, x.get_element_type()).output(0)\n    result = ov_opset.power(two, x).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef expand_dims(x, axis):\n    x = get_ov_output(x)\n    if 
isinstance(axis, tuple):\n        axis = list(axis)\n    axis = ov_opset.constant(axis, Type.i32).output(0)\n    return OpenVINOKerasTensor(ov_opset.unsqueeze(x, axis).output(0))\n\n\ndef expm1(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    exp_x = ov_opset.exp(x).output(0)\n    const_one = ov_opset.constant(1, exp_x.get_element_type())\n    result = ov_opset.subtract(exp_x, const_one).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef flip(x, axis=None):\n    x_node = get_ov_output(x)\n\n    # The slice masks below need a static rank, so reject dynamic-rank inputs.\n    rank = x_node.get_partial_shape().rank\n    if not rank.is_static:\n        raise ValueError(\n            \"The `flip` operation does not support tensors with dynamic rank \"\n            \"for the OpenVINO backend.\"\n        )\n    ndim = rank.get_length()\n\n    if axis is None:\n        axis = list(range(ndim))\n    elif isinstance(axis, int):\n        axis = [axis]\n\n    axis = [a + ndim if a < 0 else a for a in axis]\n\n    begin = [0] * ndim\n    end = [0] * ndim\n    strides = [1] * ndim\n    for a in axis:\n        strides[a] = -1\n\n    all_ones_mask = [1] * ndim\n    result = ov_opset.strided_slice(\n        data=x_node,\n        begin=begin,\n        end=end,\n        strides=strides,\n        begin_mask=all_ones_mask,\n        end_mask=all_ones_mask,\n    )\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef rot90(array, k=1, axes=(0, 1)):\n    \"\"\"Rotate an array by 90 degrees in the plane specified by axes.\"\"\"\n    array = get_ov_output(array)\n\n    if not isinstance(axes, (tuple, list)) or len(axes) != 2:\n        raise ValueError(\"axes must be a tuple of length 2\")\n\n    shape = array.get_partial_shape()\n    ndim = shape.rank.get_length() if shape.rank.is_static else None\n    if ndim is None:\n        raise ValueError(\n            \"`rot90` does not support tensors with dynamic rank \"\n            \"for the OpenVINO 
backend.\"\n        )\n\n    axis1 = canonicalize_axis(axes[0], ndim)\n    axis2 = canonicalize_axis(axes[1], ndim)\n\n    if axis1 == axis2:\n        raise ValueError(\"axes must be different\")\n\n    k = k % 4\n    if k == 0:\n        return OpenVINOKerasTensor(array)\n\n    result = array\n\n    for _ in range(k):\n        # 1. Transpose axis1 <-> axis2\n        perm = list(range(ndim))\n        perm[axis1], perm[axis2] = perm[axis2], perm[axis1]\n        perm_const = ov_opset.constant(perm, Type.i32).output(0)\n        result = ov_opset.transpose(result, perm_const).output(0)\n\n        # 2. Reverse along axis1 using StridedSlice\n        begin = [0] * ndim\n        end = [0] * ndim\n        strides = [1] * ndim\n        strides[axis1] = -1\n\n        begin_mask = [1] * ndim\n        end_mask = [1] * ndim\n\n        result = ov_opset.strided_slice(\n            data=result,\n            begin=begin,\n            end=end,\n            strides=strides,\n            begin_mask=begin_mask,\n            end_mask=end_mask,\n        ).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef floor(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        x = ov_opset.convert(x, OPENVINO_DTYPES[config.floatx()])\n    return OpenVINOKerasTensor(ov_opset.floor(x).output(0))\n\n\ndef full(shape, fill_value, dtype=None):\n    dtype = standardize_dtype(dtype) or config.floatx()\n    ov_type = OPENVINO_DTYPES[dtype]\n    fill_value = get_ov_output(fill_value, ov_type)\n    if isinstance(shape, tuple):\n        shape = list(shape)\n    target_shape = ov_opset.constant(shape, Type.i32)\n    return OpenVINOKerasTensor(\n        ov_opset.broadcast(fill_value, target_shape).output(0)\n    )\n\n\ndef full_like(x, fill_value, dtype=None):\n    x = get_ov_output(x)\n    shape_x = ov_opset.shape_of(x)\n    if dtype is not None:\n        ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)]\n    else:\n        ov_type = 
x.get_element_type()\n    const_value = ov_opset.constant(fill_value, ov_type).output(0)\n    res = ov_opset.broadcast(const_value, shape_x).output(0)\n    return OpenVINOKerasTensor(res)\n\n\ndef gcd(x1, x2):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    x1, x2 = _align_operand_types(x1, x2, \"gcd()\")\n\n    x1 = ov_opset.abs(x1).output(0)\n    x2 = ov_opset.abs(x2).output(0)\n\n    # Broadcast to common shape\n    temp_sum = ov_opset.add(x1, x2).output(0)\n    target_shape = ov_opset.shape_of(temp_sum, Type.i32).output(0)\n    x1 = ov_opset.broadcast(x1, target_shape).output(0)\n    x2 = ov_opset.broadcast(x2, target_shape).output(0)\n\n    def cond(a, b):\n        b = get_ov_output(b)\n        zero = ov_opset.constant(0, b.get_element_type()).output(0)\n        not_zero = ov_opset.not_equal(b, zero).output(0)\n\n        shape_b = ov_opset.shape_of(b, Type.i64).output(0)\n        rank_b = ov_opset.shape_of(shape_b, Type.i64).output(0)\n        rank_b_scalar = ov_opset.squeeze(\n            rank_b, ov_opset.constant(0, Type.i32)\n        ).output(0)\n        axes = ov_opset.range(\n            ov_opset.constant(0, Type.i64).output(0),\n            rank_b_scalar,\n            ov_opset.constant(1, Type.i64).output(0),\n            Type.i64,\n        ).output(0)\n\n        return ov_opset.reduce_logical_or(not_zero, axes, False).output(0)\n\n    def body(a, b):\n        a = get_ov_output(a)\n        b = get_ov_output(b)\n\n        zero = ov_opset.constant(0, b.get_element_type()).output(0)\n        mask = ov_opset.not_equal(b, zero).output(0)\n\n        one = ov_opset.constant(1, b.get_element_type()).output(0)\n        safe_b = ov_opset.select(mask, b, one).output(0)\n\n        mod_val = ov_opset.floor_mod(a, safe_b).output(0)\n\n        next_a = ov_opset.select(mask, b, a).output(0)\n        next_b = ov_opset.select(mask, mod_val, b).output(0)\n\n        return OpenVINOKerasTensor(next_a), OpenVINOKerasTensor(next_b)\n\n    x1_kt = 
OpenVINOKerasTensor(x1)\n    x2_kt = OpenVINOKerasTensor(x2)\n\n    results = while_loop(cond, body, (x1_kt, x2_kt))\n\n    return results[0]\n\n\ndef geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):\n    start = get_ov_output(start)\n    stop = get_ov_output(stop)\n\n    if dtype is None:\n        output_type = OPENVINO_DTYPES[config.floatx()]\n    else:\n        output_type = OPENVINO_DTYPES[dtype]\n\n    start = ov_opset.convert(start, output_type).output(0)\n    stop = ov_opset.convert(stop, output_type).output(0)\n\n    abs_start = ov_opset.abs(start).output(0)\n    abs_stop = ov_opset.abs(stop).output(0)\n\n    log_start = ov_opset.log(abs_start).output(0)\n    log_stop = ov_opset.log(abs_stop).output(0)\n\n    const_10 = ov_opset.constant(10.0, output_type).output(0)\n    log_10 = ov_opset.log(const_10).output(0)\n\n    log10_start = ov_opset.divide(log_start, log_10).output(0)\n    log10_stop = ov_opset.divide(log_stop, log_10).output(0)\n\n    result = logspace(\n        OpenVINOKerasTensor(log10_start),\n        OpenVINOKerasTensor(log10_stop),\n        num=num,\n        endpoint=endpoint,\n        base=10,\n        dtype=dtype,\n        axis=axis,\n    )\n\n    if num == 0:\n        return result\n\n    start_sign = ov_opset.sign(start).output(0)\n    result_output = get_ov_output(result)\n    return OpenVINOKerasTensor(\n        ov_opset.multiply(result_output, start_sign).output(0)\n    )\n\n\ndef greater(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"greater()\")\n    return OpenVINOKerasTensor(ov_opset.greater(x1, x2).output(0))\n\n\ndef greater_equal(x1, x2):\n    element_type = None\n    if isinstance(x1, 
OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"greater_equal()\")\n    return OpenVINOKerasTensor(ov_opset.greater_equal(x1, x2).output(0))\n\n\ndef hstack(xs):\n    if not isinstance(xs, (list, tuple)):\n        xs = (xs,)\n    elems = [convert_to_tensor(elem) for elem in xs]\n    element_type = elems[0].output.get_element_type()\n    elems = [get_ov_output(elem, element_type) for elem in elems]\n    is_1d = elems and len(elems[0].get_partial_shape().to_shape()) == 1\n    axis = 0 if is_1d else 1\n    for i in range(1, len(elems)):\n        elems[0], elems[i] = _align_operand_types(\n            elems[0], elems[i], \"hstack()\"\n        )\n    return OpenVINOKerasTensor(ov_opset.concat(elems, axis).output(0))\n\n\ndef hsplit(x, indices_or_sections):\n    x_ov = get_ov_output(x)\n    if len(x_ov.get_partial_shape()) == 1:\n        return split(x, indices_or_sections, axis=0)\n    return split(x, indices_or_sections, axis=1)\n\n\ndef hypot(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"hypot()\")\n    x_type = x1.get_element_type()\n    if x_type.is_integral() or x_type == Type.boolean:\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x1 = ov_opset.convert(x1, ov_type)\n        x2 = ov_opset.convert(x2, ov_type)\n    x1_abs = ov_opset.absolute(x1)\n    x2_abs = ov_opset.absolute(x2)\n    max_val = ov_opset.maximum(x1_abs, x2_abs)\n    min_val = ov_opset.minimum(x1_abs, x2_abs)\n    one = 
ov_opset.constant(1, max_val.get_element_type())\n    is_zero_mask = ov_opset.equal(\n        max_val, ov_opset.constant(0, max_val.get_element_type())\n    )\n    safe_divisor = ov_opset.select(is_zero_mask, one, max_val)\n    ratio = ov_opset.divide(min_val, safe_divisor)\n    result = ov_opset.multiply(\n        max_val,\n        ov_opset.sqrt(ov_opset.add(one, ov_opset.multiply(ratio, ratio))),\n    )\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef identity(n, dtype=None):\n    n = get_ov_output(n)\n    dtype = Type.f32 if dtype is None else dtype\n    if isinstance(dtype, str):\n        ov_dtype = OPENVINO_DTYPES[dtype]\n    else:\n        ov_dtype = dtype\n    n32 = ov_opset.convert(n, Type.i32).output(0)\n    identity_matrix = ov_opset.eye(\n        num_rows=n32, num_columns=n32, diagonal_index=0, output_type=ov_dtype\n    )\n    return OpenVINOKerasTensor(identity_matrix.output(0))\n\n\ndef imag(x):\n    # Implement properly when OpenVINO supports complex inputs\n    x = convert_to_tensor(x)\n    return zeros(x.shape, dtype=x.dtype)\n\n\ndef inner(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1_out = get_ov_output(x1, element_type)\n    x2_out = get_ov_output(x2, element_type)\n\n    x1_rank = x1_out.get_partial_shape().rank\n    x2_rank = x2_out.get_partial_shape().rank\n\n    is_x1_scalar = x1_rank.is_static and x1_rank.get_length() == 0\n    is_x2_scalar = x2_rank.is_static and x2_rank.get_length() == 0\n\n    if is_x1_scalar or is_x2_scalar:\n        x1_out, x2_out = _align_operand_types(x1_out, x2_out, \"inner()\")\n        return OpenVINOKerasTensor(ov_opset.multiply(x1_out, x2_out).output(0))\n\n    return tensordot(x1, x2, axes=((-1,), (-1,)))\n\n\ndef isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):\n    dtype = 
OPENVINO_DTYPES[config.floatx()]\n\n    x1 = ov_opset.convert(get_ov_output(x1), dtype)\n    x2 = ov_opset.convert(get_ov_output(x2), dtype)\n    rtol = ov_opset.convert(get_ov_output(rtol), dtype)\n    atol = ov_opset.convert(get_ov_output(atol), dtype)\n\n    abs_diff = ov_opset.abs(ov_opset.subtract(x1, x2))\n    abs_x2 = ov_opset.abs(x2)\n    total_tolerance = ov_opset.add(atol, ov_opset.multiply(rtol, abs_x2))\n    is_close = ov_opset.less_equal(abs_diff, total_tolerance)\n    if equal_nan:\n        both_nan = ov_opset.logical_and(ov_opset.isnan(x1), ov_opset.isnan(x2))\n        is_close = ov_opset.logical_or(is_close, both_nan)\n\n    return OpenVINOKerasTensor(is_close.output(0))\n\n\ndef isfinite(x):\n    # NOTE: openvino has an is_finite operation but it does not properly\n    # catch np.inf and -np.inf as not finite values. Hence we bootstrap here. If\n    # that ever changes, we could simplify this to just call that operation.\n    inf_values = get_ov_output(isinf(x))\n    nan_values = get_ov_output(isnan(x))\n    all_non_finite_values = ov_opset.logical_or(inf_values, nan_values).output(\n        0\n    )\n    is_finite = ov_opset.logical_not(all_non_finite_values).output(0)\n    return OpenVINOKerasTensor(is_finite)\n\n\ndef isin(x1, x2, assume_unique=False, invert=False):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    output_shape = ov_opset.shape_of(x1).output(0)\n    x1, x2 = _align_operand_types(x1, x2, \"isin()\")\n\n    minus_one = ov_opset.constant([-1], dtype=Type.i64)\n    x1 = ov_opset.reshape(x1, minus_one, special_zero=False).output(0)\n    x2 = ov_opset.reshape(x2, minus_one, special_zero=False).output(0)\n    if not assume_unique:\n        x2 = ov_opset.unique(x2).output(0)\n    x1 = ov_opset.unsqueeze(x1, 1).output(0)\n    x2 = ov_opset.unsqueeze(x2, 0).output(0)\n    cmp = ov_opset.equal(x1, x2).output(0)\n    result_flat = ov_opset.reduce_logical_or(cmp, 1).output(0)\n\n    if invert:\n        result_flat = 
ov_opset.logical_not(result_flat).output(0)\n    result = ov_opset.reshape(result_flat, output_shape, False).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef isinf(x):\n    pos_inf = get_ov_output(isposinf(x))\n    neg_inf = get_ov_output(isneginf(x))\n    inf = ov_opset.logical_or(pos_inf, neg_inf).output(0)\n    return OpenVINOKerasTensor(inf)\n\n\ndef isnan(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        x = ov_opset.convert(x, OPENVINO_DTYPES[config.floatx()])\n    return OpenVINOKerasTensor(ov_opset.is_nan(x).output(0))\n\n\ndef isneginf(x):\n    return _is_inf(x, pos=False)\n\n\ndef isposinf(x):\n    return _is_inf(x)\n\n\ndef _is_inf(x, pos=True):\n    # NOTE: there is an ov_opset.is_inf but it does not catch\n    # numpy infinite values like np.inf and -np.inf, hence why we have this\n    # if this ever changes in OpenVINO, we can do this instead:\n    # ov_opset.is_inf(x, {\"detect_positive\": pos, \"detect_negative\": not pos})\n    # for each infinite sign\n    inf_value = np.inf if pos else -np.inf\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n\n    if x_type.is_integral() or x_type == Type.boolean:\n        shape = ov_opset.shape_of(x, \"i32\").output(0)\n        false_const = ov_opset.constant(False, Type.boolean).output(0)\n        return OpenVINOKerasTensor(\n            ov_opset.broadcast(false_const, shape).output(0)\n        )\n\n    if x_type == Type.bf16:\n        x_f32 = ov_opset.convert(x, Type.f32).output(0)\n        inf = ov_opset.constant(inf_value, Type.f32).output(0)\n        is_inf = ov_opset.equal(x_f32, inf).output(0)\n    else:\n        if x_type == Type.f16:\n            inf = ov_opset.constant(inf_value, Type.f16).output(0)\n        elif x_type == Type.f32:\n            inf = ov_opset.constant(inf_value, Type.f32).output(0)\n        elif x_type == Type.f64:\n            inf = ov_opset.constant(inf_value, Type.f64).output(0)\n        else:\n            
inf = ov_opset.constant(inf_value, Type.f32).output(0)\n        is_inf = ov_opset.equal(x, inf).output(0)\n    return OpenVINOKerasTensor(is_inf)\n\n\ndef isreal(x):\n    # Implement complex support when OpenVINO adds complex dtypes.\n    x = convert_to_tensor(x)\n    return ones(x.shape, dtype=\"bool\")\n\n\ndef kron(x1, x2):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    x1, x2 = _align_operand_types(x1, x2, \"kron()\")\n    x1_shape = x1.get_partial_shape()\n    x2_shape = x2.get_partial_shape()\n    if x1_shape.rank.is_dynamic or x2_shape.rank.is_dynamic:\n        raise ValueError(\n            \"`kron` does not support tensors with dynamic rank for \"\n            \"the OpenVINO backend.\"\n        )\n    ndim1 = x1_shape.rank.get_length()\n    ndim2 = x2_shape.rank.get_length()\n    if ndim1 < ndim2:\n        axes = ov_opset.range(\n            ov_opset.constant(0, Type.i32),\n            ov_opset.constant(ndim2 - ndim1, Type.i32),\n            ov_opset.constant(1, Type.i32),\n        )\n        x1 = ov_opset.unsqueeze(x1, axes)\n        ndim1 = ndim2\n    elif ndim2 < ndim1:\n        axes = ov_opset.range(\n            ov_opset.constant(0, Type.i32),\n            ov_opset.constant(ndim1 - ndim2, Type.i32),\n            ov_opset.constant(1, Type.i32),\n        )\n        x2 = ov_opset.unsqueeze(x2, axes)\n        ndim2 = ndim1\n    shape1 = ov_opset.shape_of(x1, Type.i32)\n    shape2 = ov_opset.shape_of(x2, Type.i32)\n    ones = ov_opset.broadcast(\n        ov_opset.constant(1, Type.i32), ov_opset.constant([ndim1], Type.i32)\n    )\n    axis = ov_opset.constant(1, Type.i32)\n    flatten = ov_opset.constant([-1], Type.i32)\n    unsqueezed_ones = ov_opset.unsqueeze(ones, axis)\n    x1_new_shape = ov_opset.reshape(\n        ov_opset.concat(\n            [ov_opset.unsqueeze(shape1, axis), unsqueezed_ones],\n            axis=1,\n        ),\n        flatten,\n        False,\n    )\n    x2_new_shape = ov_opset.reshape(\n        ov_opset.concat(\n       
     [unsqueezed_ones, ov_opset.unsqueeze(shape2, axis)],\n            axis=1,\n        ),\n        flatten,\n        False,\n    )\n    result = ov_opset.multiply(\n        ov_opset.reshape(x1, x1_new_shape, False),\n        ov_opset.reshape(x2, x2_new_shape, False),\n    )\n    result = ov_opset.reshape(\n        result, ov_opset.multiply(shape1, shape2), False\n    ).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef lcm(x1, x2):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    x1, x2 = _align_operand_types(x1, x2, \"lcm()\")\n    if not x1.get_element_type().is_integral():\n        raise ValueError(\"`lcm` is only supported for integer types.\")\n    x1_abs = ov_opset.abs(x1).output(0)\n    x2_abs = ov_opset.abs(x2).output(0)\n\n    gcd_val = gcd(x1, x2)\n    gcd_val = get_ov_output(gcd_val)\n\n    zero = ov_opset.constant(0, gcd_val.get_element_type()).output(0)\n    one = ov_opset.constant(1, gcd_val.get_element_type()).output(0)\n\n    is_zero = ov_opset.equal(gcd_val, zero).output(0)\n    safe_gcd = ov_opset.select(is_zero, one, gcd_val).output(0)\n\n    term1 = ov_opset.divide(x1_abs, safe_gcd).output(0)\n    result = ov_opset.multiply(term1, x2_abs).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef ldexp(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"ldexp()\")\n\n    float_dtype = OPENVINO_DTYPES[config.floatx()]\n    if x1.get_element_type().is_integral():\n        x1 = ov_opset.convert(x1, float_dtype)\n    if x2.get_element_type().is_integral():\n        x2 = ov_opset.convert(x2, float_dtype)\n\n    const_two = ov_opset.constant(2, x2.get_element_type())\n    result = ov_opset.multiply(x1, ov_opset.power(const_two, 
x2))\n\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef less(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"less()\")\n    return OpenVINOKerasTensor(ov_opset.less(x1, x2).output(0))\n\n\ndef less_equal(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"less_equal()\")\n    return OpenVINOKerasTensor(ov_opset.less_equal(x1, x2).output(0))\n\n\ndef linspace(\n    start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0\n):\n    \"\"\"Return evenly spaced numbers over a specified interval.\n\n    Supports axis=0 (prepend) and axis=-1 (append). 
Intermediate axis values are\n    treated as axis=-1.\n\n    If `retstep` is True, also returns the step size between values.\n\n    \"\"\"\n\n    start = get_ov_output(start)\n    stop = get_ov_output(stop)\n\n    if hasattr(num, \"output\") or isinstance(num, OpenVINOKerasTensor):\n        num_tensor = get_ov_output(num)\n        try:\n            if num_tensor.get_node().get_type_name() == \"Constant\":\n                num_value = num_tensor.get_node().get_vector()[0]\n                num = int(num_value)\n            else:\n                raise NotImplementedError(\n                    \"Dynamic num values not fully supported\"\n                )\n        except Exception as e:\n            raise NotImplementedError(\n                \"Could not extract num value from tensor\"\n            ) from e\n    else:\n        num = int(num)\n\n    if dtype is None:\n        output_type = OPENVINO_DTYPES[config.floatx()]\n    else:\n        output_type = OPENVINO_DTYPES[dtype]\n\n    start = ov_opset.convert(start, output_type).output(0)\n    stop = ov_opset.convert(stop, output_type).output(0)\n\n    if num < 0:\n        raise ValueError(\"Number of samples, `num`, must be non-negative.\")\n\n    if num == 0:\n        empty_shape = ov_opset.constant([0], Type.i32).output(0)\n        result = ov_opset.broadcast(\n            ov_opset.constant(0.0, output_type).output(0), empty_shape\n        ).output(0)\n        if retstep:\n            nan_step = ov_opset.constant(np.nan, output_type).output(0)\n            return OpenVINOKerasTensor(result), OpenVINOKerasTensor(nan_step)\n        return OpenVINOKerasTensor(result)\n\n    if num == 1:\n        result_val = start\n        axis_const = ov_opset.constant([axis], Type.i32).output(0)\n        result = ov_opset.unsqueeze(result_val, axis_const).output(0)\n        if retstep:\n            if endpoint:\n                step = ov_opset.constant(np.nan, output_type).output(0)\n            else:\n                step = 
ov_opset.subtract(stop, start).output(0)\n            return OpenVINOKerasTensor(result), OpenVINOKerasTensor(step)\n        return OpenVINOKerasTensor(result)\n\n    zero_i32 = ov_opset.constant(0, Type.i32).output(0)\n    one_i32 = ov_opset.constant(1, Type.i32).output(0)\n    one_i32_array = ov_opset.constant([1], Type.i32).output(0)\n\n    num_const = ov_opset.constant(num, output_type).output(0)\n\n    if endpoint:\n        divisor = ov_opset.subtract(\n            num_const, ov_opset.constant(1, output_type).output(0)\n        ).output(0)\n    else:\n        divisor = num_const\n\n    step = ov_opset.divide(\n        ov_opset.subtract(stop, start).output(0), divisor\n    ).output(0)\n\n    indices = ov_opset.range(\n        zero_i32,\n        ov_opset.constant(num, Type.i32).output(0),\n        one_i32,\n        output_type,\n    ).output(0)\n\n    start_shape = ov_opset.convert(\n        ov_opset.shape_of(start).output(0), Type.i32\n    ).output(0)\n    indices_shape = ov_opset.convert(\n        ov_opset.shape_of(indices).output(0), Type.i32\n    ).output(0)\n\n    start_rank = ov_opset.shape_of(start_shape).output(0)\n    ones_for_start = ov_opset.broadcast(one_i32, start_rank).output(0)\n\n    if axis == 0:\n        indices_target_shape = ov_opset.concat(\n            [indices_shape, ones_for_start], 0\n        ).output(0)\n        start_target_shape = ov_opset.concat(\n            [one_i32_array, start_shape], 0\n        ).output(0)\n    else:\n        indices_target_shape = ov_opset.concat(\n            [ones_for_start, indices_shape], 0\n        ).output(0)\n        start_target_shape = ov_opset.concat(\n            [start_shape, one_i32_array], 0\n        ).output(0)\n\n    indices_reshaped = ov_opset.reshape(\n        indices, indices_target_shape, False\n    ).output(0)\n    start_reshaped = ov_opset.reshape(start, start_target_shape, False).output(\n        0\n    )\n    step_reshaped = ov_opset.reshape(step, start_target_shape, False).output(0)\n\n    scaled_indices = 
ov_opset.multiply(indices_reshaped, step_reshaped).output(\n        0\n    )\n    result = ov_opset.add(start_reshaped, scaled_indices).output(0)\n\n    if retstep:\n        return OpenVINOKerasTensor(result), OpenVINOKerasTensor(step)\n    return OpenVINOKerasTensor(result)\n\n\ndef log(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        x_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, x_type)\n    return OpenVINOKerasTensor(ov_opset.log(x).output(0))\n\n\ndef log10(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        x_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, x_type)\n    log_x = ov_opset.log(x).output(0)\n    const_10 = ov_opset.constant(10, x_type).output(0)\n    log_10 = ov_opset.log(const_10).output(0)\n    result = ov_opset.divide(log_x, log_10).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef log1p(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n\n    if x_type.is_integral():\n        x_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, x_type)\n\n    one_const = ov_opset.constant(1, x_type).output(0)\n    added = ov_opset.add(x, one_const).output(0)\n    result = ov_opset.log(added).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef log2(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        x_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, x_type)\n    log_x = ov_opset.log(x).output(0)\n    const_2 = ov_opset.constant(2, x_type).output(0)\n    log_2 = ov_opset.log(const_2).output(0)\n    result = ov_opset.divide(log_x, log_2).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef logaddexp(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, 
OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"logaddexp()\")\n\n    if x1.element_type.is_integral() or x2.element_type.is_integral():\n        float_dtype = OPENVINO_DTYPES[config.floatx()]\n        if x1.element_type.is_integral():\n            x1 = ov_opset.convert(x1, float_dtype)\n        if x2.element_type.is_integral():\n            x2 = ov_opset.convert(x2, float_dtype)\n\n    # Get the output nodes properly\n    max_val_node = ov_opset.maximum(x1, x2)\n    max_val = max_val_node.output(0)\n\n    # Compute absolute difference\n    sub_node = ov_opset.subtract(x1, x2)\n    abs_diff_node = ov_opset.abs(sub_node.output(0))\n    abs_diff = abs_diff_node.output(0)\n\n    # Compute negative absolute difference and its exponential\n    neg_abs_diff_node = ov_opset.negative(abs_diff)\n    neg_abs_diff = neg_abs_diff_node.output(0)\n    exp_neg_abs_node = ov_opset.exp(neg_abs_diff)\n    exp_neg_abs = exp_neg_abs_node.output(0)\n\n    # Get the element type from the node, not the output\n    element_type = exp_neg_abs_node.get_element_type()\n    one_node = ov_opset.constant(1, element_type)\n    one = one_node.output(0)\n\n    # Compute log term\n    one_plus_exp_node = ov_opset.add(one, exp_neg_abs)\n    one_plus_exp = one_plus_exp_node.output(0)\n    log_term_node = ov_opset.log(one_plus_exp)\n    log_term = log_term_node.output(0)\n\n    # Final result\n    result_node = ov_opset.add(max_val, log_term)\n    result = result_node.output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef logaddexp2(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n 
   x1, x2 = _align_operand_types(x1, x2, \"logaddexp2()\")\n\n    if x1.element_type.is_integral() or x2.element_type.is_integral():\n        float_dtype = OPENVINO_DTYPES[config.floatx()]\n        if x1.get_element_type().is_integral():\n            x1 = ov_opset.convert(x1, float_dtype)\n        if x2.get_element_type().is_integral():\n            x2 = ov_opset.convert(x2, float_dtype)\n\n    max_val = ov_opset.maximum(x1, x2)\n\n    sub = ov_opset.subtract(x1, x2)\n    abs_diff = ov_opset.abs(sub)\n\n    neg_abs_diff = ov_opset.negative(abs_diff)\n\n    element_type = neg_abs_diff.get_element_type()\n\n    two = ov_opset.constant(2, dtype=element_type)\n\n    power_of_2 = ov_opset.power(two, neg_abs_diff)\n\n    one_plus_power = ov_opset.add(\n        ov_opset.constant(1, dtype=element_type), power_of_2\n    )\n    log2_term = ov_opset.divide(ov_opset.log(one_plus_power), ov_opset.log(two))\n    result = ov_opset.add(max_val, log2_term).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef logical_and(x1, x2):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    x1 = ov_opset.convert(x1, Type.boolean).output(0)\n    x2 = ov_opset.convert(x2, Type.boolean).output(0)\n    return OpenVINOKerasTensor(ov_opset.logical_and(x1, x2).output(0))\n\n\ndef logical_not(x):\n    x = get_ov_output(x)\n    x = ov_opset.convert(x, Type.boolean).output(0)\n    return OpenVINOKerasTensor(ov_opset.logical_not(x).output(0))\n\n\ndef logical_or(x1, x2):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    x1 = ov_opset.convert(x1, Type.boolean).output(0)\n    x2 = ov_opset.convert(x2, Type.boolean).output(0)\n    return OpenVINOKerasTensor(ov_opset.logical_or(x1, x2).output(0))\n\n\ndef logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):\n    linear_samples = linspace(\n        start=start,\n        stop=stop,\n        num=num,\n        endpoint=endpoint,\n        retstep=False,\n        dtype=dtype,\n        axis=axis,\n    )\n\n    if 
dtype is None:\n        output_type = OPENVINO_DTYPES[config.floatx()]\n    else:\n        output_type = OPENVINO_DTYPES[dtype]\n\n    linear_output = get_ov_output(linear_samples)\n    base_tensor = get_ov_output(base)\n\n    base_tensor = ov_opset.convert(base_tensor, output_type).output(0)\n\n    result = ov_opset.power(base_tensor, linear_output).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef maximum(x1, x2):\n    x1, x2 = _promote_binary_op_types(x1, x2)\n    x1, x2 = _align_operand_types(x1, x2, \"maximum()\")\n    return OpenVINOKerasTensor(ov_opset.maximum(x1, x2).output(0))\n\n\ndef median(x, axis=None, keepdims=False):\n    x = get_ov_output(x)\n    x_shape = x.get_partial_shape()\n    rank = x_shape.rank.get_length()\n\n    if rank == 0:\n        return OpenVINOKerasTensor(x)\n\n    # Handle axis=None by flattening the input\n    flattened_all = False\n    if axis is None:\n        x = ov_opset.reshape(x, [-1], False).output(0)\n        axis = 0\n        original_rank = rank\n        rank = 1\n        flattened_all = True\n    else:\n        # Handle tuple axis - for median, we only support single axis\n        if isinstance(axis, (tuple, list)):\n            if len(axis) != 1:\n                raise ValueError(\"median only supports single axis reduction\")\n            axis = axis[0]\n\n        # Handle negative axis\n        if axis < 0:\n            axis = rank + axis\n        original_rank = rank\n\n    # Get the size of the dimension to sort\n    shape_tensor = ov_opset.shape_of(x, output_type=Type.i32).output(0)\n    k = ov_opset.gather(\n        shape_tensor,\n        ov_opset.constant([axis], Type.i32).output(0),\n        ov_opset.constant(0, Type.i32).output(0),\n    ).output(0)\n\n    # Convert k to a scalar value\n    k_scalar = ov_opset.squeeze(k, [0]).output(0)\n\n    # Use topk with k=size_of_axis to get all elements sorted\n    topk_outputs = ov_opset.topk(\n        x, k=k_scalar, axis=axis, mode=\"min\", sort=\"value\", 
stable=True\n    )\n\n    # Get the sorted values\n    sorted_values = topk_outputs.output(0)\n\n    # Convert to float for median calculation\n    x1_type = ov_to_keras_type(sorted_values.get_element_type())\n    result_type = dtypes.result_type(x1_type, float)\n    result_type = OPENVINO_DTYPES[result_type]\n    sorted_values = ov_opset.convert(sorted_values, result_type).output(0)\n\n    # Calculate median indices\n    # For odd length: median_idx = (k-1) // 2\n    # For even length: we need indices (k//2 - 1) and k//2, then average\n\n    k_minus_1 = ov_opset.subtract(\n        k_scalar, ov_opset.constant(1, Type.i32).output(0)\n    ).output(0)\n    k_div_2 = ov_opset.divide(\n        k_scalar, ov_opset.constant(2, Type.i32).output(0)\n    ).output(0)\n    k_minus_1_div_2 = ov_opset.divide(\n        k_minus_1, ov_opset.constant(2, Type.i32).output(0)\n    ).output(0)\n\n    # Check if k is odd\n    k_mod_2 = ov_opset.mod(\n        k_scalar, ov_opset.constant(2, Type.i32).output(0)\n    ).output(0)\n    is_odd = ov_opset.equal(\n        k_mod_2, ov_opset.constant(1, Type.i32).output(0)\n    ).output(0)\n\n    # For odd case: take the middle element\n    odd_idx = k_minus_1_div_2\n\n    # For even case: take average of two middle elements\n    even_idx1 = ov_opset.subtract(\n        k_div_2, ov_opset.constant(1, Type.i32).output(0)\n    ).output(0)\n    even_idx2 = k_div_2\n\n    # Gather elements for both cases\n    # Create gather indices tensor for the axis\n    gather_indices_odd = ov_opset.unsqueeze(odd_idx, [0]).output(0)\n    gather_indices_even1 = ov_opset.unsqueeze(even_idx1, [0]).output(0)\n    gather_indices_even2 = ov_opset.unsqueeze(even_idx2, [0]).output(0)\n\n    # Gather the median elements\n    odd_result = ov_opset.gather(\n        sorted_values,\n        gather_indices_odd,\n        ov_opset.constant(axis, Type.i32).output(0),\n    ).output(0)\n    even_result1 = ov_opset.gather(\n        sorted_values,\n        gather_indices_even1,\n        
ov_opset.constant(axis, Type.i32).output(0),\n    ).output(0)\n    even_result2 = ov_opset.gather(\n        sorted_values,\n        gather_indices_even2,\n        ov_opset.constant(axis, Type.i32).output(0),\n    ).output(0)\n\n    # Average the two middle elements for even case\n    even_sum = ov_opset.add(even_result1, even_result2).output(0)\n    even_result = ov_opset.divide(\n        even_sum, ov_opset.constant(2.0, result_type).output(0)\n    ).output(0)\n\n    # Select between odd and even results\n    median_result = ov_opset.select(is_odd, odd_result, even_result).output(0)\n\n    # Remove the gathered dimension (squeeze)\n    median_result = ov_opset.squeeze(median_result, [axis]).output(0)\n\n    # Handle keepdims\n    if keepdims:\n        if flattened_all:\n            # When axis=None, keepdims should restore all dimensions as 1\n            ones_shape = ov_opset.constant(\n                [1] * original_rank, Type.i32\n            ).output(0)\n            median_result = ov_opset.reshape(\n                median_result, ones_shape, False\n            ).output(0)\n        else:\n            median_result = ov_opset.unsqueeze(median_result, [axis]).output(0)\n\n    return OpenVINOKerasTensor(median_result)\n\n\ndef meshgrid(*x, indexing=\"xy\"):\n    if len(x) < 2:\n        raise ValueError(\n            \"meshgrid requires at least 2 input arrays. 
\"\n            f\"Received: {len(x)} input array(s).\"\n        )\n    if indexing not in (\"xy\", \"ij\"):\n        raise ValueError(\"indexing must be either 'xy' or 'ij'\")\n\n    tensors = [get_ov_output(xi) for xi in x]\n    n = len(tensors)\n\n    shapes = [\n        ov_opset.shape_of(t, Type.i64).output(0) for t in tensors\n    ]  # each is [Ni]\n    one = ov_opset.constant([1], Type.i64).output(0)\n\n    if indexing == \"xy\":\n        shape_list = [shapes[1], shapes[0]] + shapes[2:]\n        out_shape = ov_opset.concat(shape_list, axis=0).output(0)\n    else:\n        out_shape = ov_opset.concat(shapes, axis=0).output(0)\n\n    outputs = []\n    for i, t in enumerate(tensors):\n        reshape_parts = [one] * n\n        if indexing == \"xy\":\n            if i == 0:\n                reshape_parts[1] = shapes[0]\n            elif i == 1:\n                reshape_parts[0] = shapes[1]\n            else:\n                reshape_parts[i] = shapes[i]\n        else:\n            reshape_parts[i] = shapes[i]\n\n        reshape_shape = ov_opset.concat(reshape_parts, axis=0).output(0)\n        reshaped = ov_opset.reshape(t, reshape_shape, False).output(0)\n        broadcasted = ov_opset.broadcast(reshaped, out_shape).output(0)\n        outputs.append(OpenVINOKerasTensor(broadcasted))\n\n    return outputs\n\n\ndef min(x, axis=None, keepdims=False, initial=None):\n    return _compute_extrema(x, \"min\", axis, keepdims, initial)\n\n\ndef minimum(x1, x2):\n    x1, x2 = _promote_binary_op_types(x1, x2)\n    x1, x2 = _align_operand_types(x1, x2, \"minimum()\")\n    return OpenVINOKerasTensor(ov_opset.minimum(x1, x2).output(0))\n\n\ndef mod(x1, x2):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    x1, x2 = _align_operand_types(x1, x2, \"mod()\")\n    return OpenVINOKerasTensor(ov_opset.floor_mod(x1, x2).output(0))\n\n\ndef fmod(x1, x2):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    x1, x2 = _align_operand_types(x1, x2, \"fmod()\")\n    return 
OpenVINOKerasTensor(ov_opset.mod(x1, x2).output(0))\n\n\ndef moveaxis(x, source, destination):\n    x = get_ov_output(x)\n    if isinstance(source, int):\n        source = [source]\n    if isinstance(destination, int):\n        destination = [destination]\n\n    ndim = x.get_partial_shape().rank.get_length()\n    source = [axis if axis >= 0 else axis + ndim for axis in source]\n    destination = [axis if axis >= 0 else axis + ndim for axis in destination]\n\n    axes = list(range(ndim))\n    for src, dst in zip(source, destination):\n        axes.remove(src)\n        axes.insert(dst, src)\n\n    axes_const = ov_opset.constant(axes, Type.i32).output(0)\n    return OpenVINOKerasTensor(ov_opset.transpose(x, axes_const).output(0))\n\n\ndef nanargmax(x, axis=None, keepdims=False):\n    if isinstance(x, np.ndarray) and x.dtype == np.float64:\n        # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = x.astype(np.float32)\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type == Type.f64:\n        # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = ov_opset.convert(x, Type.f32).output(0)\n        x_type = Type.f32\n\n    original_axis = axis\n\n    if x_type.is_integral() or x_type == Type.boolean:\n        return argmax(\n            OpenVINOKerasTensor(x), axis=original_axis, keepdims=keepdims\n        )\n\n    x, resolved_axis = _resolve_axis(x, original_axis)\n    if resolved_axis is None:\n        return OpenVINOKerasTensor(x)\n\n    nan_mask = ov_opset.is_nan(x)\n    neg_inf = ov_opset.constant(np.array(-np.inf, dtype=np.float32))\n    if x_type != Type.f32:\n        neg_inf = ov_opset.convert(neg_inf, x_type)\n    x_replaced = ov_opset.select(nan_mask, neg_inf, x).output(0)\n\n    result = argmax(\n        OpenVINOKerasTensor(x_replaced), axis=original_axis, keepdims=keepdims\n    )\n    result_ov = get_ov_output(result)\n\n    all_nan = 
ov_opset.reduce_logical_and(\n        nan_mask, resolved_axis, keepdims\n    ).output(0)\n    nan_value = ov_opset.constant(-1, Type.i32).output(0)\n    if result_ov.get_element_type() != Type.i32:\n        nan_value = ov_opset.convert(nan_value, result_ov.get_element_type())\n    result_ov = ov_opset.select(all_nan, nan_value, result_ov).output(0)\n\n    return OpenVINOKerasTensor(result_ov)\n\n\ndef nanargmin(x, axis=None, keepdims=False):\n    if isinstance(x, np.ndarray) and x.dtype == np.float64:\n        # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = x.astype(np.float32)\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type == Type.f64:\n        # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = ov_opset.convert(x, Type.f32).output(0)\n        x_type = Type.f32\n\n    original_axis = axis\n\n    if x_type.is_integral() or x_type == Type.boolean:\n        return argmin(\n            OpenVINOKerasTensor(x), axis=original_axis, keepdims=keepdims\n        )\n\n    x, resolved_axis = _resolve_axis(x, original_axis)\n    if resolved_axis is None:\n        return OpenVINOKerasTensor(x)\n\n    nan_mask = ov_opset.is_nan(x)\n    pos_inf = ov_opset.constant(np.array(np.inf, dtype=np.float32))\n    if x_type != Type.f32:\n        pos_inf = ov_opset.convert(pos_inf, x_type)\n    x_replaced = ov_opset.select(nan_mask, pos_inf, x).output(0)\n\n    result = argmin(\n        OpenVINOKerasTensor(x_replaced), axis=original_axis, keepdims=keepdims\n    )\n    result_ov = get_ov_output(result)\n\n    all_nan = ov_opset.reduce_logical_and(\n        nan_mask, resolved_axis, keepdims\n    ).output(0)\n    nan_value = ov_opset.constant(-1, Type.i32).output(0)\n    if result_ov.get_element_type() != Type.i32:\n        nan_value = ov_opset.convert(nan_value, result_ov.get_element_type())\n    result_ov = ov_opset.select(all_nan, nan_value, result_ov).output(0)\n\n    
return OpenVINOKerasTensor(result_ov)\n\n\ndef nancumsum(x, axis=None, dtype=None):\n    return cumsum(nan_to_num(x, nan=0.0), axis=axis, dtype=dtype)\n\n\ndef nancumprod(x, axis=None, dtype=None):\n    return cumprod(nan_to_num(x, nan=1.0), axis=axis, dtype=dtype)\n\n\ndef nanmax(x, axis=None, keepdims=False):\n    if isinstance(x, np.ndarray) and x.dtype == np.float64:\n        # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = x.astype(np.float32)\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type == Type.f64:\n        # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = ov_opset.convert(x, Type.f32).output(0)\n        x_type = Type.f32\n\n    if x_type.is_integral() or x_type == Type.boolean:\n        return amax(OpenVINOKerasTensor(x), axis=axis, keepdims=keepdims)\n\n    x, axis = _resolve_axis(x, axis)\n    if axis is None:\n        return OpenVINOKerasTensor(x)\n\n    nan_mask = ov_opset.is_nan(x)\n    neg_inf = ov_opset.constant(np.array(-np.inf, dtype=np.float32))\n    if x_type != Type.f32:\n        neg_inf = ov_opset.convert(neg_inf, x_type)\n    x_replaced = ov_opset.select(nan_mask, neg_inf, x).output(0)\n\n    result = ov_opset.reduce_max(x_replaced, axis, keepdims).output(0)\n\n    all_nan = ov_opset.reduce_logical_and(nan_mask, axis, keepdims).output(0)\n    nan_value = ov_opset.constant(np.array(np.nan, dtype=np.float32))\n    if x_type != Type.f32:\n        nan_value = ov_opset.convert(nan_value, x_type)\n    result = ov_opset.select(all_nan, nan_value, result).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef nanmean(x, axis=None, keepdims=False):\n    if isinstance(x, np.ndarray) and x.dtype == np.float64:\n        # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = x.astype(np.float32)\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type == Type.f64:\n  
      # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = ov_opset.convert(x, Type.f32).output(0)\n        x_type = Type.f32\n\n    if x_type.is_integral() or x_type == Type.boolean:\n        return mean(OpenVINOKerasTensor(x), axis=axis, keepdims=keepdims)\n\n    x, axis = _resolve_axis(x, axis)\n    if axis is None:\n        return OpenVINOKerasTensor(x)\n\n    nan_mask = ov_opset.is_nan(x)\n    zero = ov_opset.constant(0, x_type)\n    x_no_nan = ov_opset.select(nan_mask, zero, x).output(0)\n\n    not_nan = ov_opset.logical_not(nan_mask).output(0)\n    not_nan_float = ov_opset.convert(not_nan, x_type).output(0)\n\n    nan_sum = ov_opset.reduce_sum(x_no_nan, axis, keepdims).output(0)\n    count = ov_opset.reduce_sum(not_nan_float, axis, keepdims).output(0)\n    result = ov_opset.divide(nan_sum, count).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef nanmin(x, axis=None, keepdims=False):\n    if isinstance(x, np.ndarray) and x.dtype == np.float64:\n        # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = x.astype(np.float32)\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type == Type.f64:\n        # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = ov_opset.convert(x, Type.f32).output(0)\n        x_type = Type.f32\n\n    if x_type.is_integral() or x_type == Type.boolean:\n        return amin(OpenVINOKerasTensor(x), axis=axis, keepdims=keepdims)\n\n    x, axis = _resolve_axis(x, axis)\n    if axis is None:\n        return OpenVINOKerasTensor(x)\n\n    nan_mask = ov_opset.is_nan(x)\n    pos_inf = ov_opset.constant(np.array(np.inf, dtype=np.float32))\n    if x_type != Type.f32:\n        pos_inf = ov_opset.convert(pos_inf, x_type)\n    x_replaced = ov_opset.select(nan_mask, pos_inf, x).output(0)\n\n    result = ov_opset.reduce_min(x_replaced, axis, keepdims).output(0)\n\n    all_nan = 
ov_opset.reduce_logical_and(nan_mask, axis, keepdims).output(0)\n    nan_value = ov_opset.constant(np.array(np.nan, dtype=np.float32))\n    if x_type != Type.f32:\n        nan_value = ov_opset.convert(nan_value, x_type)\n    result = ov_opset.select(all_nan, nan_value, result).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef nanprod(x, axis=None, keepdims=False):\n    if isinstance(x, np.ndarray) and x.dtype == np.float64:\n        # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = x.astype(np.float32)\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type == Type.f64:\n        # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = ov_opset.convert(x, Type.f32).output(0)\n        x_type = Type.f32\n\n    if not x_type.is_integral() and x_type != Type.boolean:\n        nan_mask = ov_opset.is_nan(x)\n        one = ov_opset.constant(1, x_type)\n        x = ov_opset.select(nan_mask, one, x).output(0)\n\n    x = _upcast_type_if_needed(x)\n    x, axis = _resolve_axis(x, axis)\n    if axis is None:\n        return OpenVINOKerasTensor(x)\n\n    result = ov_opset.reduce_prod(x, axis, keepdims).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef nanstd(x, axis=None, keepdims=False):\n    return sqrt(nanvar(x, axis=axis, keepdims=keepdims))\n\n\ndef nansum(x, axis=None, keepdims=False):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n\n    if not x_type.is_integral() and x_type != Type.boolean:\n        nan_mask = ov_opset.is_nan(x)\n        zero = ov_opset.constant(0, x_type)\n        x = ov_opset.select(nan_mask, zero, x).output(0)\n\n    x, axis = _resolve_axis(x, axis)\n    if axis is None:\n        return OpenVINOKerasTensor(x)\n\n    x = _upcast_type_if_needed(x)\n    result = ov_opset.reduce_sum(x, axis, keepdims).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef nanvar(x, axis=None, keepdims=False):\n    x = 
get_ov_output(x)\n    x_type = x.get_element_type()\n    x_keras = ov_to_keras_type(x_type)\n    result_dtype = dtypes.result_type(x_keras, float)\n    ov_result_type = OPENVINO_DTYPES[result_dtype]\n\n    # Compute in float32 due to OpenVINO f64 limitation\n    if x_type == Type.f64:\n        x = ov_opset.convert(x, Type.f32).output(0)\n        x_type = Type.f32\n\n    if x_type.is_integral() or x_type == Type.boolean:\n        result = var(OpenVINOKerasTensor(x), axis=axis, keepdims=keepdims)\n        result = get_ov_output(result)\n        if result.get_element_type() != ov_result_type:\n            result = ov_opset.convert(result, ov_result_type).output(0)\n        return OpenVINOKerasTensor(result)\n\n    if axis == () or axis == []:\n        nan_mask = ov_opset.is_nan(x).output(0)\n        zero = ov_opset.constant(0, x_type).output(0)\n        result = ov_opset.select(nan_mask, x, zero).output(0)\n        if x_type != ov_result_type:\n            result = ov_opset.convert(result, ov_result_type).output(0)\n        return OpenVINOKerasTensor(result)\n\n    # Compute mean ignoring NaN, keeping dims for broadcasting\n    mean_val = get_ov_output(\n        nanmean(OpenVINOKerasTensor(x), axis=axis, keepdims=True)\n    )\n\n    nan_mask = ov_opset.is_nan(x)\n    zero = ov_opset.constant(0, x_type)\n    not_nan = ov_opset.logical_not(nan_mask).output(0)\n\n    # Squared deviations, zeroed where NaN\n    centered = ov_opset.subtract(x, mean_val).output(0)\n    centered = ov_opset.select(nan_mask, zero, centered).output(0)\n    squared = ov_opset.multiply(centered, centered).output(0)\n\n    if axis is None:\n        flatten_shape = ov_opset.constant([-1], Type.i32).output(0)\n        squared = ov_opset.reshape(squared, flatten_shape, False).output(0)\n        not_nan = ov_opset.reshape(not_nan, flatten_shape, False).output(0)\n        axis_const = ov_opset.constant(0, Type.i32).output(0)\n    else:\n        if isinstance(axis, (tuple, list)):\n            
axis_const = ov_opset.constant(list(axis), Type.i32).output(0)\n        else:\n            axis_const = ov_opset.constant(axis, Type.i32).output(0)\n\n    not_nan_float = ov_opset.convert(not_nan, x_type).output(0)\n    sq_sum = ov_opset.reduce_sum(squared, axis_const, keepdims).output(0)\n    count = ov_opset.reduce_sum(not_nan_float, axis_const, keepdims).output(0)\n    result = ov_opset.divide(sq_sum, count).output(0)\n    if result.get_element_type() != ov_result_type:\n        result = ov_opset.convert(result, ov_result_type).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef nan_to_num(x, nan=0.0, posinf=None, neginf=None):\n    x = get_ov_output(x)\n    dtype = x.get_element_type()\n    if dtype.is_integral():\n        return OpenVINOKerasTensor(x)\n    isfloat64 = dtype == Type.f64\n    if isfloat64:\n        # conversion to f32 due to https://github.com/openvinotoolkit/openvino/issues/34138\n        x = ov_opset.convert(x, Type.f32).output(0)\n        dtype = Type.f32\n    nan_val = ov_opset.constant(nan, dtype).output(0)\n    posinf_val = ov_opset.constant(\n        posinf if posinf is not None else DTYPES_MAX[dtype], dtype\n    ).output(0)\n    neginf_val = ov_opset.constant(\n        neginf if neginf is not None else DTYPES_MIN[dtype], dtype\n    ).output(0)\n    posinf_mask = ov_opset.is_inf(\n        x,\n        {\"detect_positive\": True, \"detect_negative\": False},\n    ).output(0)\n    neginf_mask = ov_opset.is_inf(\n        x,\n        {\"detect_positive\": False, \"detect_negative\": True},\n    ).output(0)\n    nan_mask = ov_opset.is_nan(x).output(0)\n    x = ov_opset.select(nan_mask, nan_val, x).output(0)\n    x = ov_opset.select(posinf_mask, posinf_val, x).output(0)\n    x = ov_opset.select(neginf_mask, neginf_val, x).output(0)\n    if isfloat64:\n        x = ov_opset.convert(x, Type.f64).output(0)\n    return OpenVINOKerasTensor(x)\n\n\ndef ndim(x):\n    x = get_ov_output(x)\n    # The rank of x is the shape of its shape tensor.\n    shape_tensor = ov_opset.shape_of(x, 
Type.i64).output(0)\n    rank_tensor = ov_opset.shape_of(shape_tensor, Type.i64).output(0)\n    return OpenVINOKerasTensor(rank_tensor)\n\n\ndef nonzero(x):\n    x = get_ov_output(x)\n    res = ov_opset.non_zero(data=x, output_type=\"i32\").output(0)\n    return OpenVINOKerasTensor(res)\n\n\ndef not_equal(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"not_equal()\")\n    return OpenVINOKerasTensor(ov_opset.not_equal(x1, x2).output(0))\n\n\ndef zeros_like(x, dtype=None):\n    x = get_ov_output(x)\n    shape_x = ov_opset.shape_of(x)\n    if dtype is not None:\n        ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)]\n        const_zero = ov_opset.constant(0, ov_type).output(0)\n    else:\n        const_zero = ov_opset.constant(0, x.get_element_type()).output(0)\n    res = ov_opset.broadcast(const_zero, shape_x).output(0)\n    return OpenVINOKerasTensor(res)\n\n\ndef ones_like(x, dtype=None):\n    x = get_ov_output(x)\n    shape_x = ov_opset.shape_of(x)\n    if dtype is not None:\n        ov_type = OPENVINO_DTYPES[standardize_dtype(dtype)]\n        const_one = ov_opset.constant(1, ov_type).output(0)\n    else:\n        const_one = ov_opset.constant(1, x.get_element_type()).output(0)\n    res = ov_opset.broadcast(const_one, shape_x).output(0)\n    return OpenVINOKerasTensor(res)\n\n\ndef outer(x1, x2):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n\n    x1, x2 = _align_operand_types(x1, x2, \"outer()\")\n\n    new_shape_x1 = ov_opset.constant([-1, 1], Type.i32).output(0)\n    new_shape_x2 = ov_opset.constant([1, -1], Type.i32).output(0)\n\n    # Reshape directly from original tensors\n    x1_reshaped = ov_opset.reshape(x1, new_shape_x1, 
False).output(0)\n    x2_reshaped = ov_opset.reshape(x2, new_shape_x2, False).output(0)\n\n    result = ov_opset.multiply(x1_reshaped, x2_reshaped).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef pad(x, pad_width, mode=\"constant\", constant_values=None):\n    x = get_ov_output(x)\n    pad_value = None\n    if constant_values is not None:\n        if mode != \"constant\":\n            raise ValueError(\n                \"Argument `constant_values` can only be \"\n                \"provided when `mode == 'constant'`. \"\n                f\"Received: mode={mode}\"\n            )\n        if not isinstance(constant_values, int):\n            raise ValueError(\n                \"`pad` operation supports only scalar pad value \"\n                \"in constant mode with the openvino backend. \"\n                f\"Received: constant_values={constant_values}\"\n            )\n        pad_value = ov_opset.constant(\n            constant_values, x.get_element_type()\n        ).output(0)\n\n    # split pad_width into two tensors pads_begin and pads_end\n    pads_begin = []\n    pads_end = []\n    for pads_pair in pad_width:\n        pads_begin.append(pads_pair[0])\n        pads_end.append(pads_pair[1])\n    pads_begin = ov_opset.constant(pads_begin, Type.i32).output(0)\n    pads_end = ov_opset.constant(pads_end, Type.i32).output(0)\n    return OpenVINOKerasTensor(\n        ov_opset.pad(x, pads_begin, pads_end, mode, pad_value).output(0)\n    )\n\n\ndef prod(x, axis=None, keepdims=False, dtype=None):\n    x = get_ov_output(x)\n\n    # If a specific dtype is requested, cast the input to that dtype.\n    if dtype is not None:\n        ov_dtype = OPENVINO_DTYPES[standardize_dtype(dtype)]\n        x = ov_opset.convert(x, ov_dtype).output(0)\n    # Otherwise, apply dtype promotion rules before reduction.\n    else:\n        x = _upcast_type_if_needed(x)\n    x, axis = _resolve_axis(x, axis)\n    if axis is None:\n        return OpenVINOKerasTensor(x)\n    # Compute 
the product\n    result = ov_opset.reduce_prod(x, axis, keepdims).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef ptp(x, axis=None, keepdims=False):\n    if axis == ():\n        return zeros_like(x)\n    x = get_ov_output(x)\n\n    x_resolved, resolved_axis = _resolve_axis(x, axis)\n\n    max_val = ov_opset.reduce_max(x_resolved, resolved_axis, keepdims).output(0)\n    min_val = ov_opset.reduce_min(x_resolved, resolved_axis, keepdims).output(0)\n\n    return OpenVINOKerasTensor(ov_opset.subtract(max_val, min_val).output(0))\n\n\ndef quantile(x, q, axis=None, method=\"linear\", keepdims=False):\n    x = get_ov_output(x)\n    q_ov = get_ov_output(q)\n\n    x_keras_type = ov_to_keras_type(x.get_element_type())\n    compute_dtype = (\n        config.floatx()\n        if x_keras_type in (\"int64\", \"bool\")\n        else dtypes.result_type(x_keras_type, float)\n    )\n    compute_ov_type = OPENVINO_DTYPES[compute_dtype]\n    x = ov_opset.convert(x, compute_ov_type).output(0)\n    q_f64 = ov_opset.convert(q_ov, Type.f64).output(0)\n    q_rank = q_ov.get_partial_shape().rank.get_length()\n    x_ndim = x.get_partial_shape().rank.get_length()\n\n    # Flatten axis dims to the last position, then sort along it\n    if axis is None:\n        y = ov_opset.reshape(\n            x, ov_opset.constant([-1], Type.i64).output(0), False\n        ).output(0)\n        norm_axis = None\n    else:\n        if isinstance(axis, int):\n            axis = [axis]\n        axis = [a % x_ndim for a in axis]\n        other_dims = sorted(set(range(x_ndim)).difference(axis))\n        x_t = ov_opset.transpose(\n            x, ov_opset.constant(other_dims + list(axis), Type.i32).output(0)\n        ).output(0)\n        x_shape = ov_opset.shape_of(x, Type.i64).output(0)\n        if other_dims:\n            other_shape = ov_opset.gather(\n                x_shape,\n                ov_opset.constant(other_dims, Type.i32).output(0),\n                ov_opset.constant(0, Type.i32).output(0),\n            
).output(0)\n            flat_shape = ov_opset.concat(\n                [other_shape, ov_opset.constant([-1], Type.i64).output(0)],\n                axis=0,\n            ).output(0)\n        else:\n            flat_shape = ov_opset.constant([-1], Type.i64).output(0)\n        y = ov_opset.reshape(x_t, flat_shape, False).output(0)\n        norm_axis = axis\n\n    sorted_y = sort(OpenVINOKerasTensor(y)).output\n\n    # Size of the last (sorted) dimension, needed for index computation\n    y_ndim = y.get_partial_shape().rank.get_length()\n    n_i32 = ov_opset.squeeze(\n        ov_opset.gather(\n            ov_opset.shape_of(y, Type.i32).output(0),\n            ov_opset.constant([y_ndim - 1], Type.i32).output(0),\n            ov_opset.constant(0, Type.i32).output(0),\n        ).output(0),\n        ov_opset.constant([0], Type.i32).output(0),\n    ).output(0)\n\n    # exact_idx = (n - 1) * q  in float64 for precision\n    n_f64 = ov_opset.convert(n_i32, Type.f64).output(0)\n    exact_idx = ov_opset.multiply(\n        ov_opset.subtract(\n            n_f64, ov_opset.constant(np.float64(1.0)).output(0)\n        ).output(0),\n        q_f64,\n    ).output(0)\n\n    zero_i32 = ov_opset.constant(np.int32(0)).output(0)\n    n_minus1_i32 = ov_opset.subtract(\n        n_i32, ov_opset.constant(np.int32(1)).output(0)\n    ).output(0)\n    last_ax = ov_opset.constant(y_ndim - 1, Type.i32).output(0)\n\n    def _clamp_idx(f64_idx):\n        i = ov_opset.convert(f64_idx, Type.i32).output(0)\n        return ov_opset.minimum(\n            ov_opset.maximum(i, zero_i32).output(0), n_minus1_i32\n        ).output(0)\n\n    def _gather(idx):\n        return ov_opset.gather(sorted_y, idx, last_ax).output(0)\n\n    lo_idx = _clamp_idx(ov_opset.floor(exact_idx).output(0))\n    hi_idx = _clamp_idx(ov_opset.ceiling(exact_idx).output(0))\n\n    if method == \"lower\":\n        gathered = _gather(lo_idx)\n    elif method == \"higher\":\n        gathered = _gather(hi_idx)\n    elif method == 
\"nearest\":\n        gathered = _gather(\n            _clamp_idx(ov_opset.round(exact_idx, \"half_to_even\").output(0))\n        )\n    elif method == \"midpoint\":\n        two = ov_opset.convert(\n            ov_opset.constant(np.float32(2.0)).output(0), compute_ov_type\n        ).output(0)\n        gathered = ov_opset.divide(\n            ov_opset.add(_gather(lo_idx), _gather(hi_idx)).output(0), two\n        ).output(0)\n    else:  # linear\n        # preserve_gradients: ensure interp_lo_idx < interp_hi_idx\n        one_i32 = ov_opset.constant(np.int32(1)).output(0)\n        interp_lo_idx = ov_opset.maximum(\n            ov_opset.subtract(hi_idx, one_i32).output(0), zero_i32\n        ).output(0)\n        interp_hi_idx = ov_opset.minimum(\n            ov_opset.add(interp_lo_idx, one_i32).output(0), n_minus1_i32\n        ).output(0)\n        frac = ov_opset.convert(\n            ov_opset.subtract(\n                ov_opset.convert(interp_hi_idx, Type.f64).output(0), exact_idx\n            ).output(0),\n            compute_ov_type,\n        ).output(0)\n        one_val = ov_opset.convert(\n            ov_opset.constant(np.float32(1.0)).output(0), compute_ov_type\n        ).output(0)\n        gathered = ov_opset.add(\n            ov_opset.multiply(\n                _gather(interp_hi_idx),\n                ov_opset.subtract(one_val, frac).output(0),\n            ).output(0),\n            ov_opset.multiply(_gather(interp_lo_idx), frac).output(0),\n        ).output(0)\n\n    # keepdims: insert size-1 dims before rotating q to front\n    if keepdims:\n        axes_to_add = (\n            list(range(x_ndim)) if norm_axis is None else sorted(norm_axis)\n        )\n        for i in axes_to_add:\n            gathered = ov_opset.unsqueeze(\n                gathered, ov_opset.constant([i], Type.i32).output(0)\n            ).output(0)\n\n    # For 1-D q, rotate the q dim from last to first\n    if q_rank > 0:\n        g_ndim = gathered.get_partial_shape().rank.get_length()\n  
      if g_ndim >= 2:\n            gathered = ov_opset.transpose(\n                gathered,\n                ov_opset.constant(\n                    [g_ndim - 1] + list(range(g_ndim - 1)), Type.i32\n                ).output(0),\n            ).output(0)\n\n    return OpenVINOKerasTensor(gathered)\n\n\ndef ravel(x):\n    x = get_ov_output(x)\n    target_shape = ov_opset.constant([-1], dtype=Type.i32).output(0)\n    return OpenVINOKerasTensor(\n        ov_opset.reshape(x, target_shape, special_zero=False).output(0)\n    )\n\n\ndef real(x):\n    # TODO: Implement complex support when OpenVINO adds complex dtypes.\n    # Currently, all supported dtypes are real-valued.\n    return convert_to_tensor(x)\n\n\ndef reciprocal(x):\n    x = get_ov_output(x)\n    one_constant = ov_opset.constant(1, dtype=x.get_element_type()).output(0)\n    x = ov_opset.divide(one_constant, x).output(0)\n    return OpenVINOKerasTensor(x)\n\n\ndef repeat(x, repeats, axis=None):\n    x = get_ov_output(x)\n    const_0 = ov_opset.constant(0, Type.i32)\n    const_1 = ov_opset.constant(1, Type.i32)\n    const_neg_1 = ov_opset.constant([-1], Type.i32)\n\n    if axis is not None and axis < 0:\n        axis += len(x.get_partial_shape())\n\n    if axis is None:\n        x = ov_opset.reshape(x, const_neg_1, special_zero=False)\n        axis = 0\n\n    if isinstance(repeats, np.integer):\n        repeats = int(repeats)\n    elif (\n        isinstance(repeats, np.ndarray)\n        and repeats.size == 1\n        and repeats.ndim <= 1\n    ):\n        repeats = int(repeats.item())\n\n    if isinstance(repeats, int):\n        dim_len = ov_opset.gather(\n            ov_opset.shape_of(x, Type.i32),\n            ov_opset.constant([axis], Type.i32),\n            const_0,\n        )\n        dim_len = ov_opset.squeeze(dim_len, ov_opset.constant([0], Type.i32))\n        idx_range = ov_opset.range(\n            const_0, dim_len, const_1, output_type=Type.i32\n        )\n        idx_range = 
ov_opset.unsqueeze(idx_range, const_1)\n        tiled = ov_opset.tile(\n            idx_range, ov_opset.constant([1, repeats], Type.i32)\n        )\n        idx = ov_opset.reshape(tiled, const_neg_1, special_zero=False)\n        result = ov_opset.gather(x, idx, ov_opset.constant(axis, Type.i32))\n        return OpenVINOKerasTensor(result.output(0))\n    repeats_tensor = get_ov_output(repeats)\n    cumsum = ov_opset.cumsum(repeats_tensor, const_0)\n    total = ov_opset.reduce_sum(\n        repeats_tensor, ov_opset.constant([0], Type.i32), keep_dims=False\n    )\n    total = ov_opset.convert(total, Type.i32)\n    out_indices = ov_opset.range(const_0, total, const_1, output_type=Type.i32)\n    cumsum_unsq = ov_opset.unsqueeze(cumsum, const_0)\n    out_indices_unsq = ov_opset.unsqueeze(out_indices, const_1)\n    cumsum_unsq = ov_opset.convert(cumsum_unsq, Type.i32)\n    mask = ov_opset.greater_equal(out_indices_unsq, cumsum_unsq)\n    gather_indices = ov_opset.reduce_sum(\n        ov_opset.convert(mask, Type.i32), ov_opset.constant([1], Type.i32)\n    )\n    result = ov_opset.gather(\n        x, gather_indices, ov_opset.constant(axis, Type.i32)\n    )\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef reshape(x, newshape):\n    x = get_ov_output(x)\n    if isinstance(newshape, int):\n        newshape = [newshape]\n    elif isinstance(newshape, tuple):\n        newshape = list(newshape)\n    if isinstance(newshape, list):\n        newshape = [-1 if d is None else d for d in newshape]\n    if isinstance(newshape, OpenVINOKerasTensor):\n        newshape = get_ov_output(newshape)\n    else:\n        newshape = ov_opset.constant(newshape, Type.i32).output(0)\n    return OpenVINOKerasTensor(ov_opset.reshape(x, newshape, False).output(0))\n\n\ndef roll(x, shift, axis=None):\n    x = get_ov_output(x)\n    if axis is not None:\n        result = ov_opset.roll(x, shift, axis).output(0)\n    else:\n        output_shape = ov_opset.shape_of(x).output(0)\n        flattened = 
ov_opset.reshape(\n            x, ov_opset.constant([-1], Type.i32), False\n        ).output(0)\n        result = ov_opset.roll(flattened, shift, 0).output(0)\n        result = ov_opset.reshape(result, output_shape, False).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef searchsorted(sorted_sequence, values, side=\"left\"):\n    if side not in (\"left\", \"right\"):\n        raise ValueError(\n            f\"`side` must be either 'left' or 'right'. Received: side={side}\"\n        )\n    sorted_sequence = get_ov_output(sorted_sequence)\n    values = get_ov_output(values)\n\n    if sorted_sequence.get_partial_shape().rank.get_length() != 1:\n        raise ValueError(\n            \"`searchsorted` only supports 1-D sorted sequences. \"\n            \"You can use `keras.ops.vectorized_map` \"\n            \"to extend it to N-D sequences. Received: \"\n            f\"sorted_sequence.shape={sorted_sequence.get_partial_shape()}\"\n        )\n\n    sorted_sequence, values = _align_operand_types(\n        sorted_sequence, values, \"searchsorted()\"\n    )\n\n    # Note: OpenVINO's bucketize with_right_bound has opposite semantics\n    # with_right_bound=True means search from right (side='left' in numpy)\n    # with_right_bound=False means search from left (side='right' in numpy)\n    with_right_bound = side == \"left\"\n    result = ov_opset.bucketize(\n        values,\n        sorted_sequence,\n        output_type=Type.i32,\n        with_right_bound=with_right_bound,\n    ).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef sign(x):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.sign(x).output(0))\n\n\ndef signbit(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    zero = ov_opset.constant(0, dtype=x_type).output(0)\n    is_negative = ov_opset.less(x, zero).output(0)\n    if x_type.is_real():\n        one = ov_opset.constant(1.0, dtype=x_type).output(0)\n        recip = ov_opset.divide(one, x).output(0)\n        
recip_neg = ov_opset.less(recip, zero).output(0)\n        is_zero = ov_opset.equal(x, zero).output(0)\n        neg_zero = ov_opset.logical_and(is_zero, recip_neg).output(0)\n        is_negative = ov_opset.logical_or(is_negative, neg_zero).output(0)\n    return OpenVINOKerasTensor(is_negative)\n\n\ndef sin(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return OpenVINOKerasTensor(ov_opset.sin(x).output(0))\n\n\ndef sinc(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type).output(0)\n    elem_type = x.get_element_type()\n    pi = ov_opset.constant(np.pi, elem_type)\n    one = ov_opset.constant(1.0, elem_type)\n    zero = ov_opset.constant(0.0, elem_type)\n    pi_x = ov_opset.multiply(pi, x)\n    sin_pi_x = ov_opset.sin(pi_x)\n    sinc_val = ov_opset.divide(sin_pi_x, pi_x)\n    is_zero = ov_opset.equal(x, zero)\n    result = ov_opset.select(is_zero, one, sinc_val)\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef sinh(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return OpenVINOKerasTensor(ov_opset.sinh(x).output(0))\n\n\ndef size(x):\n    x = get_ov_output(x)\n    shape_tensor = ov_opset.shape_of(x, output_type=Type.i64)\n    final_size = ov_opset.reduce_prod(\n        shape_tensor,\n        ov_opset.constant([0], Type.i64),\n        keep_dims=False,\n    )\n    return OpenVINOKerasTensor(final_size.output(0))\n\n\ndef sort(x, axis=-1):\n    x = get_ov_output(x)\n    x_shape = x.get_partial_shape()\n    rank = x_shape.rank.get_length()\n\n    if rank == 0:\n        return OpenVINOKerasTensor(x)\n\n    # Handle axis=None by flattening 
the input\n    if axis is None:\n        x = ov_opset.reshape(\n            x, ov_opset.constant([-1], Type.i32), False\n        ).output(0)\n        axis = 0\n    # Handle negative axis\n    elif axis < 0:\n        axis = rank + axis\n\n    # Get the size of the dimension to sort\n    shape_tensor = ov_opset.shape_of(x, output_type=Type.i32).output(0)\n    k = ov_opset.gather(\n        shape_tensor,\n        ov_opset.constant([axis], Type.i32).output(0),\n        ov_opset.constant(0, Type.i32).output(0),\n    ).output(0)\n\n    # Convert k to a scalar value\n    k_scalar = ov_opset.squeeze(k, ov_opset.constant([0], Type.i32)).output(0)\n\n    # Use topk with k=size_of_axis to get all elements sorted\n    topk_outputs = ov_opset.topk(\n        x, k=k_scalar, axis=axis, mode=\"min\", sort=\"value\", stable=True\n    )\n\n    # Get the sorted values\n    sorted_values = topk_outputs.output(0)\n\n    return OpenVINOKerasTensor(sorted_values)\n\n\ndef split(x, indices_or_sections, axis=0):\n    x = get_ov_output(x)\n    axis_tensor = ov_opset.constant(axis, dtype=Type.i32).output(0)\n\n    shape_tensor = ov_opset.shape_of(x)\n    axis_i32 = ov_opset.constant([axis], dtype=Type.i32)\n    dim_at_axis_tensor = ov_opset.gather(\n        shape_tensor, axis_i32, ov_opset.constant(0, dtype=Type.i32)\n    )\n\n    if isinstance(indices_or_sections, OpenVINOKerasTensor):\n        indices_or_sections = (\n            indices_or_sections.output.get_node().get_data().tolist()\n        )\n\n    if isinstance(indices_or_sections, int):\n        num_splits = indices_or_sections\n        splits = ov_opset.split(x, axis_tensor, num_splits=num_splits)\n        result = []\n        for i in range(num_splits):\n            result.append(OpenVINOKerasTensor(splits.output(i)))\n        return result\n\n    if isinstance(indices_or_sections, (list, tuple, np.ndarray)):\n        indices = list(indices_or_sections)\n        split_lengths = []\n        split_lengths.append(indices[0])\n        
for i in range(1, len(indices)):\n            split_lengths.append(indices[i] - indices[i - 1])\n\n        last_index_tensor = ov_opset.constant(indices[-1], dtype=Type.i64)\n        remaining_length_tensor = ov_opset.subtract(\n            dim_at_axis_tensor, last_index_tensor\n        )\n\n        length_parts = []\n        length_parts.append(ov_opset.constant(split_lengths, dtype=Type.i64))\n        length_parts.append(remaining_length_tensor)\n        length_tensor = ov_opset.concat(length_parts, axis=0)\n\n        splits = ov_opset.variadic_split(x, axis_tensor, length_tensor)\n        result = []\n        for i in range(len(split_lengths) + 1):\n            result.append(OpenVINOKerasTensor(splits.output(i)))\n        return result\n\n    raise TypeError(\n        f\"unsupported type of indices_or_sections: {type(indices_or_sections)}\"\n    )\n\n\ndef array_split(x, indices_or_sections, axis=0):\n    original_shape = x.shape\n    x = get_ov_output(x)\n\n    num_splits_val = indices_or_sections\n    total_size = original_shape[axis]\n    if total_size is None:\n        raise ValueError(\n            f\"Cannot use array_split with static Python logic on dynamic axis. 
\"\n            f\"Axis {axis} has unknown dimension for shape {original_shape}.\"\n        )\n\n    base_size = total_size // num_splits_val\n    remainder = total_size % num_splits_val\n\n    split_lengths = [base_size + 1] * remainder + [base_size] * (\n        num_splits_val - remainder\n    )\n    split_lengths_tensor = ov_opset.constant(\n        split_lengths, dtype=Type.i64\n    ).output(0)\n\n    axis_tensor = ov_opset.constant(axis, dtype=Type.i32).output(0)\n    splits = ov_opset.variadic_split(x, axis_tensor, split_lengths_tensor)\n\n    result = []\n    for i in range(num_splits_val):\n        result.append(OpenVINOKerasTensor(splits.output(i)))\n    return result\n\n\ndef stack(x, axis=0):\n    if isinstance(x, tuple):\n        x = list(x)\n    if not isinstance(x, list):\n        raise ValueError(\n            f\"`stack` supports only `x` as list or tuple. Received: {type(x)}\"\n        )\n    elems = [get_ov_output(e) for e in x]\n    ref = elems[0]\n    for i in range(1, len(elems)):\n        ref, elems[i] = _align_operand_types(ref, elems[i], \"stack()\")\n    elems[0] = ref\n    const_axis = ov_opset.constant(axis, Type.i32).output(0)\n    elems = [ov_opset.unsqueeze(e, const_axis).output(0) for e in elems]\n    res = ov_opset.concat(elems, axis).output(0)\n    return OpenVINOKerasTensor(res)\n\n\ndef std(x, axis=None, keepdims=False):\n    var_x = var(x, axis, keepdims)\n    std_dev = ov_opset.sqrt(var_x.output).output(0)\n    return OpenVINOKerasTensor(std_dev)\n\n\ndef swapaxes(x, axis1, axis2):\n    x = get_ov_output(x)\n    x_shape = x.get_partial_shape()\n    if x_shape.rank.is_dynamic:\n        raise ValueError(\n            \"`swapaxes` does not support tensors with dynamic rank for the \"\n            \"OpenVINO backend.\"\n        )\n    rank = x_shape.rank.get_length()\n    axis1 = canonicalize_axis(axis1, rank)\n    axis2 = canonicalize_axis(axis2, rank)\n    axes = list(range(rank))\n    axes[axis1], axes[axis2] = axes[axis2], 
axes[axis1]\n    result = ov_opset.transpose(x, ov_opset.constant(axes, Type.i32))\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef take(x, indices, axis=None):\n    x = get_ov_output(x)\n    indices = get_ov_output(indices)\n    if axis is None:\n        target_shape = ov_opset.constant([-1], dtype=Type.i32).output(0)\n        x = ov_opset.reshape(x, target_shape, False).output(0)\n        axis = ov_opset.constant(0, dtype=Type.i32).output(0)\n    else:\n        axis = ov_opset.constant(axis, dtype=Type.i32).output(0)\n    return OpenVINOKerasTensor(ov_opset.gather(x, indices, axis).output(0))\n\n\ndef take_along_axis(x, indices, axis=None):\n    x = get_ov_output(x)\n    indices = get_ov_output(indices)\n\n    if axis is None:\n        target_shape = ov_opset.constant([-1], dtype=Type.i32).output(0)\n        x_flat = ov_opset.reshape(x, target_shape, False).output(0)\n        indices_flat = ov_opset.reshape(indices, target_shape, False).output(0)\n        result = ov_opset.gather_elements(x_flat, indices_flat, 0).output(0)\n        return OpenVINOKerasTensor(result)\n\n    x_rank = len(x.get_partial_shape())\n    if axis < 0:\n        axis += x_rank\n\n    x_shape = ov_opset.shape_of(x, Type.i32).output(0)\n    indices_shape = ov_opset.shape_of(indices, Type.i32).output(0)\n\n    zero_const = ov_opset.constant(0, dtype=Type.i32).output(0)\n    axis_index = ov_opset.constant([axis], dtype=Type.i32).output(0)\n\n    # Fix negative indices\n    dim_size = ov_opset.squeeze(\n        ov_opset.gather(x_shape, axis_index, zero_const).output(0), zero_const\n    ).output(0)\n    zero_scalar = ov_opset.constant(0, indices.get_element_type()).output(0)\n    is_neg = ov_opset.less(indices, zero_scalar).output(0)\n    dim_size_cast = ov_opset.convert(\n        dim_size, indices.get_element_type()\n    ).output(0)\n    indices = ov_opset.select(\n        is_neg, ov_opset.add(indices, dim_size_cast).output(0), indices\n    ).output(0)\n    indices = 
ov_opset.convert(indices, Type.i32).output(0)\n\n    x_target_parts, indices_target_parts = [], []\n\n    for i in range(x_rank):\n        dim_idx = ov_opset.constant([i], dtype=Type.i32).output(0)\n        x_dim = ov_opset.gather(x_shape, dim_idx, zero_const).output(0)\n        indices_dim = ov_opset.gather(\n            indices_shape, dim_idx, zero_const\n        ).output(0)\n\n        if i == axis:\n            # For axis dimension: keep original dimensions\n            x_target_parts.append(x_dim)\n            indices_target_parts.append(indices_dim)\n        else:\n            # For other dimensions: use maximum for broadcasting\n            max_dim = ov_opset.maximum(x_dim, indices_dim).output(0)\n            x_target_parts.append(max_dim)\n            indices_target_parts.append(max_dim)\n\n    x_target_shape = ov_opset.concat(x_target_parts, axis=0).output(0)\n    indices_target_shape = ov_opset.concat(indices_target_parts, axis=0).output(\n        0\n    )\n\n    # Broadcast to target shapes and gather elements\n    x_broadcasted = ov_opset.broadcast(x, x_target_shape).output(0)\n    indices_broadcasted = ov_opset.broadcast(\n        indices, indices_target_shape\n    ).output(0)\n    result = ov_opset.gather_elements(\n        x_broadcasted, indices_broadcasted, axis\n    ).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef tan(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return OpenVINOKerasTensor(ov_opset.tan(x).output(0))\n\n\ndef tanh(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type)\n    return OpenVINOKerasTensor(ov_opset.tanh(x).output(0))\n\n\ndef tensordot(x1, x2, axes=2):\n    a = get_ov_output(x1)\n    b = get_ov_output(x2)\n    a, b = 
_align_operand_types(a, b, \"tensordot()\")\n\n    rank_a = a.get_partial_shape().rank.get_length()\n    rank_b = b.get_partial_shape().rank.get_length()\n\n    if isinstance(axes, int):\n        axes_a = list(range(rank_a - axes, rank_a))\n        axes_b = list(range(axes))\n    else:\n        axes_a, axes_b = [\n            list(ax) if isinstance(ax, (list, tuple)) else [ax] for ax in axes\n        ]\n        axes_a = [canonicalize_axis(i, rank_a) for i in axes_a]\n        axes_b = [canonicalize_axis(i, rank_b) for i in axes_b]\n\n    notin_a = [i for i in range(rank_a) if i not in axes_a]\n    notin_b = [i for i in range(rank_b) if i not in axes_b]\n\n    # Transpose so contraction axes are at the end of A and beginning of B\n    a_transpose = ov_opset.transpose(\n        a, ov_opset.constant(notin_a + axes_a, Type.i32)\n    )\n    b_transpose = ov_opset.transpose(\n        b, ov_opset.constant(axes_b + notin_b, Type.i32)\n    )\n\n    # Calculate the product of the contraction dimensions\n    shape_a = ov_opset.shape_of(a, Type.i32)\n    contract_dims = ov_opset.gather(\n        shape_a, ov_opset.constant(axes_a, Type.i32), 0\n    )\n    contract_size = ov_opset.reduce_prod(contract_dims, 0, keep_dims=True)\n\n    # Reshape A to [-1, contract_size] and B to [contract_size, -1]\n    a_2d = ov_opset.reshape(\n        a_transpose,\n        ov_opset.concat([ov_opset.constant([-1], Type.i32), contract_size], 0),\n        False,\n    )\n    b_2d = ov_opset.reshape(\n        b_transpose,\n        ov_opset.concat([contract_size, ov_opset.constant([-1], Type.i32)], 0),\n        False,\n    )\n\n    result = ov_opset.matmul(a_2d, b_2d, False, False)\n\n    # Reconstruct final shape from free dimensions\n    if not notin_a and not notin_b:\n        # Scalar output case\n        result = ov_opset.reshape(\n            result, ov_opset.constant([], Type.i32), False\n        )\n    else:\n        shape_b = ov_opset.shape_of(b, Type.i32)\n        final_parts = []\n        if 
notin_a:\n            final_parts.append(\n                ov_opset.gather(\n                    shape_a, ov_opset.constant(notin_a, Type.i32), 0\n                )\n            )\n        if notin_b:\n            final_parts.append(\n                ov_opset.gather(\n                    shape_b, ov_opset.constant(notin_b, Type.i32), 0\n                )\n            )\n\n        result = ov_opset.reshape(\n            result, ov_opset.concat(final_parts, 0), False\n        )\n\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef round(x, decimals=0):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral() or x_type == Type.boolean:\n        x = ov_opset.convert(x, OPENVINO_DTYPES[config.floatx()])\n\n    if decimals == 0:\n        result = ov_opset.round(x, \"half_to_even\")\n    else:\n        factor = ov_opset.constant(10.0**decimals, x.get_element_type())\n        scaled = ov_opset.multiply(x, factor)\n        rounded = ov_opset.round(scaled, \"half_to_even\")\n        result = ov_opset.divide(rounded, factor)\n\n    if x_type.is_integral():\n        result = ov_opset.convert(result, x_type)\n\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef trunc(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        return OpenVINOKerasTensor(x)\n    sign_x = ov_opset.sign(x)\n    abs_x = ov_opset.abs(x)\n    floor_abs_x = ov_opset.floor(abs_x)\n    result = ov_opset.multiply(sign_x, floor_abs_x)\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef tile(x, repeats):\n    x = get_ov_output(x)\n\n    if isinstance(repeats, int):\n        repeats = [repeats]\n    repeats = get_ov_output(repeats)\n\n    if repeats.get_element_type() != Type.i64:\n        repeats = ov_opset.convert(repeats, Type.i64)\n\n    if len(repeats.get_partial_shape()) != 1:\n        repeats = ov_opset.reshape(repeats, [-1], False)\n\n    shape_x = ov_opset.shape_of(x, Type.i64)\n    rank_x = 
ov_opset.shape_of(shape_x, Type.i64)\n    rank_r = ov_opset.shape_of(repeats, Type.i64)\n\n    one = ov_opset.constant(1, Type.i64)\n    zero = ov_opset.constant(0, Type.i64)\n\n    pad_x = ov_opset.maximum(ov_opset.subtract(rank_r, rank_x), zero)\n    new_x_shape = ov_opset.concat(\n        [ov_opset.broadcast(one, pad_x).output(0), shape_x], 0\n    )\n    x = ov_opset.reshape(x, new_x_shape, False)\n\n    pad_r = ov_opset.maximum(ov_opset.subtract(rank_x, rank_r), zero)\n    repeats = ov_opset.concat(\n        [ov_opset.broadcast(one, pad_r).output(0), repeats], 0\n    )\n\n    return OpenVINOKerasTensor(ov_opset.tile(x, repeats).output(0))\n\n\ndef trace(x, offset=0, axis1=0, axis2=1):\n    x = diagonal(x, offset=offset, axis1=axis1, axis2=axis2)\n    return sum(x, axis=-1)\n\n\ndef tri(N, M=None, k=0, dtype=None):\n    if M is None:\n        M = N\n    if dtype is None:\n        dtype = \"float32\"\n\n    ov_dtype = OPENVINO_DTYPES[dtype]\n\n    def ensure_constant(value, default_type=Type.i32):\n        if isinstance(value, (int, float)):\n            return ov_opset.constant(value, default_type)\n        elif hasattr(value, \"get_element_type\"):\n            if value.get_element_type() != Type.i32:\n                value = ov_opset.convert(value, Type.i32)\n            return ov_opset.squeeze(value, ov_opset.constant([0], Type.i32))\n        else:\n            return ov_opset.constant(value, default_type)\n\n    N_const = ensure_constant(N)\n    M_const = ensure_constant(M)\n    k_const = ensure_constant(k)\n\n    # Create row and column indices\n    row_range = ov_opset.range(\n        ov_opset.constant(0, Type.i32),\n        N_const,\n        ov_opset.constant(1, Type.i32),\n        output_type=Type.i32,\n    )\n    col_range = ov_opset.range(\n        ov_opset.constant(0, Type.i32),\n        M_const,\n        ov_opset.constant(1, Type.i32),\n        output_type=Type.i32,\n    )\n\n    # Reshape indices for broadcasting\n    row_idx = 
ov_opset.unsqueeze(row_range, ov_opset.constant([1], Type.i32))\n    col_idx = ov_opset.unsqueeze(col_range, ov_opset.constant([0], Type.i32))\n\n    mask = ov_opset.less_equal(col_idx, ov_opset.add(row_idx, k_const))\n\n    if ov_dtype == Type.boolean:\n        result = mask\n    else:\n        result = ov_opset.convert(mask, ov_dtype)\n\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef tril(x, k=0):\n    x = get_ov_output(x)\n    ov_type = x.get_element_type()\n    shape = ov_opset.shape_of(x, Type.i32)\n    zero_const = ov_opset.constant(0, Type.i32)\n    minus2 = ov_opset.constant([-2], Type.i32)\n    minus1 = ov_opset.constant([-1], Type.i32)\n    M = ov_opset.squeeze(ov_opset.gather(shape, minus2, zero_const), zero_const)\n    N = ov_opset.squeeze(ov_opset.gather(shape, minus1, zero_const), zero_const)\n    tri_mask = tri(M, N, k=k, dtype=\"bool\").output\n    mask = ov_opset.convert(tri_mask, ov_type)\n    if ov_type == Type.boolean:\n        out = ov_opset.logical_and(x, mask)\n    else:\n        out = ov_opset.multiply(x, mask)\n    return OpenVINOKerasTensor(out.output(0))\n\n\ndef triu(x, k=0):\n    x = get_ov_output(x)\n    ov_type = x.get_element_type()\n    shape = ov_opset.shape_of(x, Type.i32)\n    zero_const = ov_opset.constant(0, Type.i32)\n    minus2 = ov_opset.constant([-2], Type.i32)\n    minus1 = ov_opset.constant([-1], Type.i32)\n    M = ov_opset.squeeze(ov_opset.gather(shape, minus2, zero_const), zero_const)\n    N = ov_opset.squeeze(ov_opset.gather(shape, minus1, zero_const), zero_const)\n    tri_mask = tri(M, N, k=k - 1, dtype=\"bool\").output\n    if ov_type == Type.boolean:\n        mask = ov_opset.logical_not(tri_mask)\n    else:\n        const_one = ov_opset.constant(1, ov_type)\n        converted_mask = ov_opset.convert(tri_mask, ov_type)\n        mask = ov_opset.subtract(const_one, converted_mask)\n    if ov_type == Type.boolean:\n        out = ov_opset.logical_and(x, mask)\n    else:\n        out = ov_opset.multiply(x, 
mask)\n    return OpenVINOKerasTensor(out.output(0))\n\n\ndef vdot(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"vdot()\")\n    if x1.get_partial_shape().rank == 0 or x2.get_partial_shape().rank == 0:\n        return OpenVINOKerasTensor(ov_opset.multiply(x1, x2).output(0))\n    flatten_shape = ov_opset.constant([-1], Type.i32).output(0)\n    x1 = ov_opset.reshape(x1, flatten_shape, False).output(0)\n    x2 = ov_opset.reshape(x2, flatten_shape, False).output(0)\n    return OpenVINOKerasTensor(ov_opset.matmul(x1, x2, False, False).output(0))\n\n\ndef vstack(xs):\n    if not isinstance(xs, (list, tuple)):\n        xs = (xs,)\n    elems = [convert_to_tensor(elem) for elem in xs]\n    element_type = elems[0].output.get_element_type()\n    elems = [get_ov_output(elem, element_type) for elem in elems]\n    axis = 0\n    for i in range(1, len(elems)):\n        elems[0], elems[i] = _align_operand_types(\n            elems[0], elems[i], \"vstack()\"\n        )\n    return OpenVINOKerasTensor(ov_opset.concat(elems, axis).output(0))\n\n\ndef vsplit(x, indices_or_sections):\n    return split(x, indices_or_sections, axis=0)\n\n\ndef vectorize(pyfunc, *, excluded=None, signature=None):\n    # `excluded` and `signature` are accepted for API compatibility with\n    # `np.vectorize` but are currently ignored by this backend; `pyfunc` is\n    # called directly on the converted arguments.\n    def wrapper(*args, **kwargs):\n        converted_args = tuple(convert_to_tensor(arg) for arg in args)\n        return pyfunc(*converted_args, **kwargs)\n\n    return wrapper\n\n\ndef where(condition, x1=None, x2=None):\n    condition = get_ov_output(condition)\n    if x1 is None and x2 is None:\n        nonzero_indices = ov_opset.non_zero(condition)\n        return OpenVINOKerasTensor(nonzero_indices.output(0))\n    if x1 is None:\n        return OpenVINOKerasTensor(condition)\n    if 
x2 is None:\n        raise ValueError(\"x2 must be provided if x1 is specified.\")\n\n    def cast_literal_like_tensor(literal, x):\n        ov_type = get_ov_output(x).get_element_type()\n        is_bool = ov_type == Type.boolean\n        is_float_to_int = isinstance(literal, float) and ov_type.is_integral()\n        if is_bool or is_float_to_int:\n            return get_ov_output(literal), get_ov_output(x)\n        return get_ov_output(literal, ov_type), get_ov_output(x)\n\n    if isinstance(x1, (int, float)):\n        x1, x2 = cast_literal_like_tensor(x1, x2)\n    elif isinstance(x2, (int, float)):\n        x2, x1 = cast_literal_like_tensor(x2, x1)\n    else:\n        x1 = get_ov_output(x1)\n        x2 = get_ov_output(x2)\n    x1, x2 = _align_operand_types(x1, x2, \"select()\")\n    return OpenVINOKerasTensor(ov_opset.select(condition, x1, x2).output(0))\n\n\ndef divide(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1_type = ov_to_keras_type(x1.get_element_type())\n    x2_type = ov_to_keras_type(x2.get_element_type())\n    result_type = dtypes.result_type(x1_type, x2_type, float)\n    result_type = OPENVINO_DTYPES[result_type]\n    x1 = ov_opset.convert(x1, result_type).output(0)\n    x2 = ov_opset.convert(x2, result_type).output(0)\n    return OpenVINOKerasTensor(ov_opset.divide(x1, x2).output(0))\n\n\ndef divide_no_nan(x1, x2):\n    element_type = None\n    if isinstance(x1, OpenVINOKerasTensor):\n        element_type = x1.output.get_element_type()\n    if isinstance(x2, OpenVINOKerasTensor):\n        element_type = x2.output.get_element_type()\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"divide_no_nan()\")\n\n 
   zero = ov_opset.constant(0, x2.get_element_type())\n    div = ov_opset.divide(x1, x2)\n    is_zero = ov_opset.equal(x2, zero)\n    result = ov_opset.select(is_zero, zero, div)\n    return OpenVINOKerasTensor(result.output(0))\n\n\ndef true_divide(x1, x2):\n    return divide(x1, x2)\n\n\ndef power(x1, x2):\n    t1 = (\n        ov_to_keras_type(x1.get_element_type())\n        if isinstance(x1, ov.Output)\n        else getattr(x1, \"dtype\", type(x1))\n    )\n    t2 = (\n        ov_to_keras_type(x2.get_element_type())\n        if isinstance(x2, ov.Output)\n        else getattr(x2, \"dtype\", type(x2))\n    )\n    element_type = OPENVINO_DTYPES[dtypes.result_type(t1, t2)]\n    x1 = get_ov_output(x1, element_type)\n    x2 = get_ov_output(x2, element_type)\n    x1, x2 = _align_operand_types(x1, x2, \"power()\")\n    return OpenVINOKerasTensor(ov_opset.power(x1, x2).output(0))\n\n\ndef negative(x):\n    x = get_ov_output(x)\n    return OpenVINOKerasTensor(ov_opset.negative(x).output(0))\n\n\ndef nextafter(x1, x2):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    x1, x2 = _align_operand_types(x1, x2, \"nextafter()\")\n\n    x1_keras = ov_to_keras_type(x1.get_element_type())\n    x2_keras = ov_to_keras_type(x2.get_element_type())\n    dtype = dtypes.result_type(x1_keras, x2_keras, float)\n    ov_dtype = OPENVINO_DTYPES[dtype]\n\n    # Work in float64 for precision (matches TF/PyTorch approach)\n    x1 = ov_opset.convert(x1, Type.f64).output(0)\n    x2 = ov_opset.convert(x2, Type.f64).output(0)\n\n    zero = ov_opset.constant(0.0, Type.f64).output(0)\n    two = ov_opset.constant(2.0, Type.f64).output(0)\n    half = ov_opset.constant(0.5, Type.f64).output(0)\n\n    eq_mask = ov_opset.equal(x1, x2).output(0)\n    direction = ov_opset.sign(ov_opset.subtract(x2, x1)).output(0)\n    abs_x1 = ov_opset.abs(x1).output(0)\n\n    # Compute ULP = 2^(floor(log2(|x1|)) - 52) for normal float64 numbers\n    ln2 = ov_opset.constant(np.log(2.0), Type.f64).output(0)\n    
log2_abs = ov_opset.floor(\n        ov_opset.divide(ov_opset.log(abs_x1), ln2)\n    ).output(0)\n    min_exp = ov_opset.constant(-1022.0, Type.f64).output(0)\n    clamped_exp = ov_opset.maximum(log2_abs, min_exp).output(0)\n    mantissa_bits = ov_opset.constant(52.0, Type.f64).output(0)\n    ulp_exp = ov_opset.subtract(clamped_exp, mantissa_bits).output(0)\n    ulp = ov_opset.power(two, ulp_exp).output(0)\n\n    # At power-of-2 boundaries going towards zero, the ULP is halved\n    # because we step into the adjacent binade with finer spacing\n    pow2_floor = ov_opset.power(two, log2_abs).output(0)\n    is_pow2 = ov_opset.equal(abs_x1, pow2_floor).output(0)\n    going_towards_zero = ov_opset.less(\n        ov_opset.multiply(x1, direction), zero\n    ).output(0)\n    halve_mask = ov_opset.logical_and(is_pow2, going_towards_zero).output(0)\n    ulp = ov_opset.select(halve_mask, ov_opset.multiply(ulp, half), ulp).output(\n        0\n    )\n\n    result = ov_opset.add(x1, ov_opset.multiply(direction, ulp)).output(0)\n\n    # Handle x1 == 0: result is the smallest subnormal towards x2\n    min_subnormal = ov_opset.constant(5e-324, Type.f64).output(0)\n    zero_result = ov_opset.multiply(ov_opset.sign(x2), min_subnormal).output(0)\n    is_zero = ov_opset.equal(x1, zero).output(0)\n    result = ov_opset.select(is_zero, zero_result, result).output(0)\n\n    # Handle x1 == x2: return x2\n    result = ov_opset.select(eq_mask, x2, result).output(0)\n\n    return OpenVINOKerasTensor(ov_opset.convert(result, ov_dtype).output(0))\n\n\ndef square(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type == Type.boolean:\n        x = ov_opset.convert(x, Type.i32).output(0)\n    const_two = ov_opset.constant(2, x.get_element_type()).output(0)\n    return OpenVINOKerasTensor(ov_opset.power(x, const_two).output(0))\n\n\ndef sqrt(x):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    if x_type.is_integral():\n        ov_type = 
OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, ov_type).output(0)\n    return OpenVINOKerasTensor(ov_opset.sqrt(x).output(0))\n\n\ndef squeeze(x, axis=None):\n    x = get_ov_output(x)\n    if axis is None:\n        axis = []\n        for idx, dim in enumerate(x.get_partial_shape()):\n            if dim == 1:\n                axis.append(idx)\n    if isinstance(axis, tuple):\n        axis = list(axis)\n    axis = ov_opset.constant(axis, Type.i32).output(0)\n    return OpenVINOKerasTensor(ov_opset.squeeze(x, axis).output(0))\n\n\ndef transpose(x, axes=None):\n    x = get_ov_output(x)\n    if axes is None:\n        # generate reverse permutation vector\n        shape_x = ov_opset.shape_of(x, \"i64\").output(0)\n        rank_x = ov_opset.shape_of(shape_x, \"i64\").output(0)\n        scalar_shape = ov_opset.constant([], Type.i32).output(0)\n        rank_x = ov_opset.reshape(rank_x, scalar_shape, False).output(0)\n        const_minus_one = ov_opset.constant(-1, Type.i64).output(0)\n        rank_minus_one = ov_opset.add(rank_x, const_minus_one).output(0)\n        axes = ov_opset.range(\n            rank_minus_one, const_minus_one, const_minus_one, \"i64\"\n        ).output(0)\n    else:\n        if isinstance(axes, tuple):\n            axes = list(axes)\n        axes = ov_opset.constant(axes, Type.i32).output(0)\n    return OpenVINOKerasTensor(ov_opset.transpose(x, axes).output(0))\n\n\ndef _helper_trapezoid(y, axis):\n    rank = y.get_partial_shape().rank.get_length()\n    strides = ov_opset.constant([1] * rank, dtype=Type.i64).output(0)\n\n    # y[:-1]\n    begin1 = ov_opset.constant([0] * rank, dtype=Type.i64).output(0)\n    end1_list = [0] * rank\n    end1_list[axis] = -1\n    end1 = ov_opset.constant(end1_list, dtype=Type.i64).output(0)\n    begin_mask1 = [1] * rank\n    begin_mask1[axis] = 0\n    end_mask1 = [1] * rank\n    end_mask1[axis] = 0\n    y1 = ov_opset.strided_slice(\n        y, begin1, end1, strides, begin_mask1, end_mask1\n    
).output(0)\n\n    # y[1:]\n    begin2_list = [0] * rank\n    begin2_list[axis] = 1\n    begin2 = ov_opset.constant(begin2_list, dtype=Type.i64).output(0)\n    end2 = ov_opset.constant([0] * rank, dtype=Type.i64).output(0)\n    begin_mask2 = [1] * rank\n    begin_mask2[axis] = 0\n    end_mask2 = [1] * rank\n    y2 = ov_opset.strided_slice(\n        y, begin2, end2, strides, begin_mask2, end_mask2\n    ).output(0)\n\n    return y1, y2\n\n\ndef trapezoid(y, x=None, dx=1.0, axis=-1):\n    y = get_ov_output(y)\n    y_type = y.get_element_type()\n\n    if y_type.is_integral():\n        y_type = OPENVINO_DTYPES[config.floatx()]\n        y = ov_opset.convert(y, y_type).output(0)\n\n    y1, y2 = _helper_trapezoid(y, axis)\n    y_final = ov_opset.add(y1, y2).output(0)\n    const_two = ov_opset.constant(2, dtype=y_type).output(0)\n    y_final = ov_opset.divide(y_final, const_two).output(0)\n\n    if x is not None:\n        x = get_ov_output(x)\n        x_type = x.get_element_type()\n        if x_type.is_integral():\n            x_type = OPENVINO_DTYPES[config.floatx()]\n            x = ov_opset.convert(x, x_type).output(0)\n\n        x1, x2 = _helper_trapezoid(x, axis)\n        x_final = ov_opset.subtract(x2, x1).output(0)\n\n    else:\n        x_final = ov_opset.constant(dx, dtype=y_type).output(0)\n\n    result = ov_opset.multiply(y_final, x_final).output(0)\n    const_axis = ov_opset.constant([axis], Type.i64).output(0)\n    result = ov_opset.reduce_sum(result, const_axis, False).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef unravel_index(indices, shape):\n    indices = get_ov_output(indices)\n    if not indices.get_element_type().is_integral():\n        indices = ov_opset.convert(indices, Type.i64).output(0)\n    indices_dtype = indices.get_element_type()\n\n    if None in shape:\n        raise ValueError(\n            f\"`shape` argument cannot contain `None`. 
Received: shape={shape}\"\n        )\n\n    if isinstance(shape, tuple):\n        shape = list(shape)\n\n    # Handle negative indices\n    total_size = np.prod(shape)\n    total_size_const = ov_opset.constant(total_size, indices_dtype).output(0)\n\n    zero = ov_opset.constant(0, indices_dtype).output(0)\n    is_negative = ov_opset.less(indices, zero).output(0)\n    indices = ov_opset.select(\n        is_negative, ov_opset.add(indices, total_size_const), indices\n    ).output(0)\n\n    coords = []\n    for dim_size in reversed(shape):\n        dim_const = ov_opset.constant(dim_size, indices_dtype).output(0)\n        coord = ov_opset.floor_mod(indices, dim_const).output(0)\n        coords.append(coord)\n        indices = ov_opset.divide(indices, dim_const).output(0)\n\n    coords = list(reversed(coords))\n    return tuple(OpenVINOKerasTensor(coord) for coord in coords)\n\n\ndef vander(x, N=None, increasing=False):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n\n    shape_x = ov_opset.shape_of(x, Type.i64).output(0)\n\n    const_zero_1D = ov_opset.constant([0], dtype=Type.i64).output(0)\n    const_zero = ov_opset.constant(0, dtype=Type.i64).output(0)\n    const_one = ov_opset.constant(1, dtype=Type.i64).output(0)\n    const_mone = ov_opset.constant(-1, dtype=Type.i64).output(0)\n\n    if N is None:\n        const_N = ov_opset.squeeze(shape_x, const_zero_1D).output(0)\n        const_N_1D = shape_x\n    else:\n        const_N = ov_opset.constant(N, Type.i64).output(0)\n        const_N_1D = ov_opset.constant([N], Type.i64).output(0)\n\n    const_N_minus_one = ov_opset.subtract(const_N, const_one).output(0)\n    if increasing:\n        powers = ov_opset.range(const_zero, const_N, const_one, x_type).output(\n            0\n        )\n    else:\n        powers = ov_opset.range(\n            const_N_minus_one, const_mone, const_mone, x_type\n        ).output(0)\n\n    target_shape = ov_opset.concat([shape_x, const_N_1D], 0).output(0)\n\n    const_one_1D = 
ov_opset.constant([1], dtype=Type.i64).output(0)\n\n    powers = ov_opset.unsqueeze(powers, const_zero_1D).output(0)\n    x = ov_opset.unsqueeze(x, const_one_1D).output(0)\n\n    result = ov_opset.broadcast(x, target_shape).output(0)\n\n    result = ov_opset.power(result, powers).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef var(x, axis=None, keepdims=False):\n    x = get_ov_output(x)\n    x_type = x.get_element_type()\n    x, axis = _resolve_axis(x, axis)\n\n    if x_type.is_integral() or x_type == Type.boolean:\n        work_dtype = OPENVINO_DTYPES[config.floatx()]\n        x = ov_opset.convert(x, work_dtype).output(0)\n    else:\n        work_dtype = x_type\n\n    if axis is None:\n        const_zero = ov_opset.constant(0, dtype=work_dtype).output(0)\n        return OpenVINOKerasTensor(\n            ov_opset.broadcast(const_zero, ov_opset.shape_of(x)).output(0)\n        )\n    # The variance is computed as $Var = E[|x|^2] - |E[x]|^2$. This form is\n    # faster but less numerically stable than subtracting the mean first.\n    mean = ov_opset.reduce_mean(x, axis, keepdims).output(0)\n    const_two = ov_opset.constant(2, work_dtype).output(0)\n\n    squared_x = ov_opset.power(x, const_two).output(0)\n    squared_mean = ov_opset.power(mean, const_two).output(0)\n\n    squared_x_mean = ov_opset.reduce_mean(squared_x, axis, keepdims).output(0)\n    variance = OpenVINOKerasTensor(\n        ov_opset.subtract(squared_x_mean, squared_mean).output(0)\n    )\n    return variance\n\n\ndef sum(x, axis=None, keepdims=False):\n    x = get_ov_output(x)\n    x, axis = _resolve_axis(x, axis)\n    if axis is None:\n        return OpenVINOKerasTensor(x)\n    x = _upcast_type_if_needed(x)\n    summed_value = ov_opset.reduce_sum(x, axis, keepdims).output(0)\n    return OpenVINOKerasTensor(summed_value)\n\n\ndef eye(N, M=None, k=0, dtype=None):\n    dtype = standardize_dtype(dtype) or config.floatx()\n    ov_type = OPENVINO_DTYPES[dtype]\n    if M is None:\n        M = N\n    return OpenVINOKerasTensor(\n 
       ov_opset.eye(\n            ov_opset.constant(N, Type.i32),\n            ov_opset.constant(M, Type.i32),\n            ov_opset.constant(k, Type.i32),\n            output_type=ov_type,\n        ).output(0)\n    )\n\n\ndef floor_divide(x1, x2):\n    x1_output = get_ov_output(x1)\n    x2_output = get_ov_output(x2)\n    if x1_output.get_element_type() == Type.boolean:\n        x1_output = ov_opset.convert(x1_output, Type.i32).output(0)\n    if isinstance(x2, (int, float)):\n        if x1_output.get_element_type().is_integral() and isinstance(x2, float):\n            ov_type = OPENVINO_DTYPES[config.floatx()]\n        else:\n            ov_type = x1_output.get_element_type()\n        x1 = ov_opset.convert(x1_output, ov_type).output(0)\n        x2 = ov_opset.convert(x2_output, ov_type).output(0)\n    else:\n        x1, x2 = _align_operand_types(x1_output, x2_output, \"floor_divide()\")\n    div = ov_opset.divide(x1, x2).output(0)\n    floored_div = ov_opset.floor(div).output(0)\n    return OpenVINOKerasTensor(floored_div)\n\n\ndef logical_xor(x1, x2):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    x1 = ov_opset.convert(x1, Type.boolean).output(0)\n    x2 = ov_opset.convert(x2, Type.boolean).output(0)\n    return OpenVINOKerasTensor(ov_opset.logical_xor(x1, x2).output(0))\n\n\ndef corrcoef(x):\n    x_ov = get_ov_output(x)\n    x_type = x_ov.get_element_type()\n    ov_type = x_type\n\n    if x_type.is_integral():\n        ov_type = OPENVINO_DTYPES[config.floatx()]\n        x_ov = ov_opset.convert(x_ov, ov_type).output(0)\n\n    const_one = ov_opset.constant(1, dtype=Type.i64).output(0)\n    const_two = ov_opset.constant(2, dtype=ov_type).output(0)\n\n    mean = ov_opset.reduce_mean(x_ov, const_one, True).output(0)\n    x_ov = ov_opset.subtract(x_ov, mean).output(0)\n\n    cov = ov_opset.matmul(x_ov, x_ov, False, True).output(0)\n    xsqr = ov_opset.power(x_ov, const_two).output(0)\n    xvar = ov_opset.reduce_sum(xsqr, const_one, True).output(0)\n    
xstd = ov_opset.sqrt(xvar).output(0)\n\n    den = ov_opset.matmul(xstd, xstd, False, True).output(0)\n\n    result = ov_opset.divide(cov, den).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef correlate(x1, x2, mode=\"valid\"):\n    x1 = get_ov_output(x1)\n    x2 = get_ov_output(x2)\n    x1_type = x1.get_element_type()\n    x2_type = x2.get_element_type()\n    x1_type = ov_to_keras_type(x1_type)\n    x2_type = ov_to_keras_type(x2_type)\n    result_type = dtypes.result_type(x1_type, x2_type, float)\n\n    result_type = OPENVINO_DTYPES[result_type]\n    x1 = ov_opset.convert(x1, result_type).output(0)\n    x2 = ov_opset.convert(x2, result_type).output(0)\n\n    shape_filter = ov_opset.shape_of(x2, Type.i64).output(0)\n    const_two = ov_opset.constant(2, Type.f64).output(0)\n    const_one = ov_opset.constant(1, Type.i64).output(0)\n    const_zero = ov_opset.constant(0, result_type).output(0)\n    shape_filter_minus_one = ov_opset.subtract(shape_filter, const_one).output(\n        0\n    )\n\n    # padding x1\n    if mode == \"valid\":\n        pass\n\n    elif mode == \"same\":\n        shape_minus_one_float = ov_opset.convert(\n            shape_filter_minus_one, Type.f64\n        ).output(0)\n\n        right = ov_opset.divide(shape_minus_one_float, const_two).output(0)\n        left = ov_opset.ceil(right).output(0)\n        right = ov_opset.floor(right).output(0)\n        left = ov_opset.convert(left, Type.i64).output(0)\n        right = ov_opset.convert(right, Type.i64).output(0)\n        x1 = ov_opset.pad(x1, left, right, \"constant\", const_zero).output(0)\n\n    elif mode == \"full\":\n        pad = shape_filter_minus_one\n        x1 = ov_opset.pad(x1, pad, pad, \"constant\", const_zero).output(0)\n\n    else:\n        raise ValueError(\n            f\"Invalid mode: {mode}. Choose from 'valid', 'same', 'full'.\"\n        )\n\n    axes = ov_opset.constant([0, 1], dtype=Type.i64).output(0)\n    x2 = ov_opset.unsqueeze(x2, axes).output(0)\n    x1 = 
ov_opset.unsqueeze(x1, axes).output(0)\n\n    result = ov_opset.convolution(x1, x2, [1], [0], [0], [1]).output(0)\n\n    result = ov_opset.squeeze(result, axes).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef select(condlist, choicelist, default=0):\n    if len(condlist) != len(choicelist):\n        raise ValueError(\n            \"select(): condlist and choicelist must have the same length\"\n        )\n    conds = [get_ov_output(c) for c in condlist]\n    choices = [get_ov_output(v) for v in choicelist]\n\n    result = get_ov_output(default)\n    for cond_idx in reversed(range(len(conds))):\n        cond = conds[cond_idx]\n        choice = choices[cond_idx]\n        choice, result = _align_operand_types(choice, result, \"select()\")\n        result = ov_opset.select(cond, choice, result).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef slogdet(x):\n    x = convert_to_tensor(x)\n    x_ov = get_ov_output(x)\n    x_ov_type = x_ov.get_element_type()\n\n    # Cast integer/boolean inputs to float\n    if x_ov_type.is_integral() or x_ov_type == Type.boolean:\n        float_type = OPENVINO_DTYPES[config.floatx()]\n        x_ov = ov_opset.convert(x_ov, float_type).output(0)\n        x_ov_type = x_ov.get_element_type()\n\n    # Promote to result type (e.g. 
float32 -> float64 if needed)\n    keras_type = ov_to_keras_type(x_ov_type)\n    result_ov_type = OPENVINO_DTYPES[dtypes.result_type(keras_type, float)]\n    if x_ov_type != result_ov_type:\n        x_ov = ov_opset.convert(x_ov, result_ov_type).output(0)\n        x_ov_type = result_ov_type\n\n    x_shape = x_ov.get_partial_shape()\n    x_rank = len(x_shape)\n    n = x_shape[-1].get_length()\n\n    # Flatten batch dims: (..., n, n) -> (batch, n, n)\n    flat_shape = ov_opset.constant([-1, n, n], Type.i32).output(0)\n    x_batched = ov_opset.reshape(x_ov, flat_shape, False).output(0)\n\n    batch_shape = ov_opset.shape_of(x_batched, Type.i32).output(0)\n    batch_size = ov_opset.gather(\n        batch_shape,\n        ov_opset.constant([0], Type.i32).output(0),\n        ov_opset.constant(0, Type.i32).output(0),\n    ).output(0)\n\n    zero = ov_opset.constant(0.0, x_ov_type).output(0)\n    one = ov_opset.constant(1.0, x_ov_type).output(0)\n    two = ov_opset.constant(2.0, x_ov_type).output(0)\n\n    # Accumulators — one value per batch element\n    log_abs_det = ov_opset.broadcast(zero, batch_size).output(0)\n    sign_det = ov_opset.broadcast(one, batch_size).output(0)\n\n    row_axis = ov_opset.constant(1, Type.i32).output(0)\n    col_axis = ov_opset.constant(2, Type.i32).output(0)\n\n    # LU decomposition with partial pivoting\n    for k in range(n):\n        # Find pivot row: max |value| in column k, from row k downward\n        col_k = ov_opset.gather(\n            x_batched, ov_opset.constant(k, Type.i32).output(0), col_axis\n        ).output(0)\n        abs_col_k = ov_opset.absolute(col_k).output(0)\n\n        # Slice rows [k:n] of the column\n        abs_col_k_sub = ov_opset.slice(\n            abs_col_k,\n            ov_opset.constant([0, k], Type.i32).output(0),\n            ov_opset.constant([2**30, n], Type.i32).output(0),\n            ov_opset.constant([1, 1], Type.i32).output(0),\n            ov_opset.constant([0, 1], Type.i32).output(0),\n        
).output(0)\n\n        topk_result = ov_opset.topk(\n            abs_col_k_sub,\n            ov_opset.constant(1, Type.i32).output(0),\n            axis=1,\n            mode=\"max\",\n            sort=\"none\",\n        )\n        local_max_idx = ov_opset.squeeze(\n            ov_opset.convert(topk_result.output(1), Type.i32).output(0),\n            ov_opset.constant([1], Type.i32).output(0),\n        ).output(0)\n\n        # Absolute pivot row index (local index is relative to row k)\n        pivot_row = ov_opset.add(\n            local_max_idx, ov_opset.constant(k, Type.i32).output(0)\n        ).output(0)\n\n        # Track sign change caused by row swap\n        swap_needed = ov_opset.not_equal(\n            pivot_row, ov_opset.constant(k, Type.i32).output(0)\n        ).output(0)\n        swap_needed_f = ov_opset.convert(swap_needed, x_ov_type).output(0)\n        # sign_flip = 1 - 2*swap_needed_f  →  no swap: +1, swap: -1\n        sign_flip = ov_opset.subtract(\n            ov_opset.broadcast(one, batch_size).output(0),\n            ov_opset.multiply(two, swap_needed_f).output(0),\n        ).output(0)\n        sign_det = ov_opset.multiply(sign_det, sign_flip).output(0)\n\n        # Swap row k with pivot_row\n        row_k = ov_opset.gather(\n            x_batched, ov_opset.constant([k], Type.i32).output(0), row_axis\n        ).output(0)\n        pivot_row_2d = ov_opset.unsqueeze(\n            pivot_row, ov_opset.constant([1], Type.i32).output(0)\n        ).output(0)\n        pivot_row_data = ov_opset.gather(\n            x_batched, pivot_row_2d, row_axis, batch_dims=1\n        ).output(0)\n\n        # Write pivot row data into position k\n        x_batched = ov_opset.scatter_update(\n            x_batched,\n            ov_opset.constant([k], Type.i32).output(0),\n            pivot_row_data,\n            row_axis,\n        ).output(0)\n\n        # Write old row k into position pivot_row (mask-based scatter)\n        all_row_indices = ov_opset.unsqueeze(\n         
   ov_opset.range(\n                ov_opset.constant(0, Type.i32).output(0),\n                ov_opset.constant(n, Type.i32).output(0),\n                ov_opset.constant(1, Type.i32).output(0),\n                output_type=Type.i32,\n            ).output(0),\n            ov_opset.constant([0, 2], Type.i32).output(0),\n        ).output(0)\n\n        pivot_row_3d = ov_opset.unsqueeze(\n            pivot_row_2d, ov_opset.constant([2], Type.i32).output(0)\n        ).output(0)\n        swap_mask = ov_opset.equal(all_row_indices, pivot_row_3d).output(0)\n        row_k_tiled = ov_opset.broadcast(\n            row_k, ov_opset.shape_of(x_batched, Type.i32).output(0)\n        ).output(0)\n        x_batched = ov_opset.select(swap_mask, row_k_tiled, x_batched).output(0)\n\n        # Extract pivot element and accumulate log|det| and sign\n        k_idx = ov_opset.constant([k], Type.i32).output(0)\n        pivot_row_cur = ov_opset.gather(x_batched, k_idx, row_axis).output(0)\n        pivot_elem = ov_opset.gather(pivot_row_cur, k_idx, col_axis).output(0)\n        pivot_scalar = ov_opset.squeeze(\n            pivot_elem, ov_opset.constant([1, 2], Type.i32).output(0)\n        ).output(0)\n\n        abs_pivot = ov_opset.absolute(pivot_scalar).output(0)\n        safe_abs = ov_opset.maximum(\n            abs_pivot, ov_opset.constant(1e-38, x_ov_type).output(0)\n        ).output(0)\n        log_abs_det = ov_opset.add(\n            log_abs_det, ov_opset.log(safe_abs).output(0)\n        ).output(0)\n        sign_det = ov_opset.multiply(\n            sign_det, ov_opset.sign(pivot_scalar).output(0)\n        ).output(0)\n\n        # Protect against division by zero during elimination\n        safe_pivot = ov_opset.select(\n            ov_opset.equal(\n                pivot_elem, ov_opset.constant(0.0, x_ov_type).output(0)\n            ).output(0),\n            ov_opset.constant(1.0, x_ov_type).output(0),\n            pivot_elem,\n        ).output(0)\n\n        # Gaussian elimination: zero 
out entries below pivot\n        for i in range(k + 1, n):\n            i_idx = ov_opset.constant([i], Type.i32).output(0)\n            row_i = ov_opset.gather(x_batched, i_idx, row_axis).output(0)\n            elem_ik = ov_opset.gather(row_i, k_idx, col_axis).output(0)\n            multiplier = ov_opset.divide(elem_ik, safe_pivot).output(0)\n            row_i_new = ov_opset.subtract(\n                row_i, ov_opset.multiply(multiplier, pivot_row_cur).output(0)\n            ).output(0)\n            x_batched = ov_opset.scatter_update(\n                x_batched, i_idx, row_i_new, row_axis\n            ).output(0)\n\n    # For singular matrices: sign=0, logabsdet=-inf\n    is_singular = ov_opset.equal(\n        sign_det, ov_opset.broadcast(zero, batch_size).output(0)\n    ).output(0)\n    neg_inf = ov_opset.constant(float(\"-inf\"), x_ov_type).output(0)\n    log_abs_det = ov_opset.select(\n        is_singular,\n        ov_opset.broadcast(neg_inf, batch_size).output(0),\n        log_abs_det,\n    ).output(0)\n\n    # Reshape outputs back to batch shape (drop last two dims)\n    if x_rank > 2:\n        batch_dims = [x_shape[i].get_length() for i in range(x_rank - 2)]\n        out_shape = ov_opset.constant(batch_dims, Type.i32).output(0)\n    else:\n        out_shape = ov_opset.constant([], Type.i32).output(0)\n\n    sign_result = ov_opset.reshape(sign_det, out_shape, False).output(0)\n    logabsdet_result = ov_opset.reshape(log_abs_det, out_shape, False).output(0)\n\n    return OpenVINOKerasTensor(sign_result), OpenVINOKerasTensor(\n        logabsdet_result\n    )\n\n\ndef argpartition(x, kth, axis=-1):\n    x = get_ov_output(x)\n    x_shape = x.get_partial_shape()\n    rank = x_shape.rank.get_length()\n    axis = canonicalize_axis(axis, rank)\n    axes = list(range(rank))\n    axes[axis], axes[-1] = axes[-1], axes[axis]\n    x = ov_opset.transpose(x, ov_opset.constant(axes))\n    x_shape_tensor = ov_opset.shape_of(x)\n    n = ov_opset.gather(\n        
x_shape_tensor,\n        ov_opset.constant(-1),\n        ov_opset.constant(0),\n    )\n    if isinstance(kth, int) and kth < 0:\n        kth_tensor = ov_opset.add(\n            n,\n            ov_opset.constant(kth, n.get_element_type()),\n        )\n    else:\n        kth_tensor = ov_opset.constant(kth, n.get_element_type())\n    one = ov_opset.constant(1, kth_tensor.get_element_type())\n    k_val = ov_opset.add(kth_tensor, one)\n    bottom_ind = ov_opset.topk(\n        ov_opset.negative(x),\n        k=k_val,\n        axis=-1,\n        mode=\"max\",\n        sort=\"value\",\n    ).output(1)\n    one_hot_mask = ov_opset.one_hot(\n        bottom_ind,\n        n,\n        ov_opset.constant(1),\n        ov_opset.constant(0),\n        axis=-1,\n    )\n    mask = ov_opset.reduce_sum(\n        one_hot_mask,\n        ov_opset.constant([-2]),\n        keep_dims=False,\n    )\n    ones = ov_opset.broadcast(\n        ov_opset.constant(1),\n        x_shape_tensor,\n    )\n    proxy = ov_opset.subtract(ones, mask)\n    remaining_k = ov_opset.subtract(n, k_val)\n    top_ind = ov_opset.topk(\n        proxy,\n        k=remaining_k,\n        axis=-1,\n        mode=\"max\",\n        sort=\"value\",\n    ).output(1)\n    result = ov_opset.concat([bottom_ind, top_ind], axis=-1)\n    inv_axes = [0] * rank\n    for i, a in enumerate(axes):\n        inv_axes[a] = i\n    result = ov_opset.transpose(\n        result,\n        ov_opset.constant(inv_axes),\n    ).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef histogram(x, bins=10, range=None):\n    x = get_ov_output(x)\n    x = ov_opset.reshape(x, [-1], False).output(0)\n\n    float_type = OPENVINO_DTYPES[config.floatx()]\n    x_float = ov_opset.convert(x, float_type).output(0)\n\n    if range is None:\n        min_val = ov_opset.reduce_min(x_float, 0).output(0)\n        max_val = ov_opset.reduce_max(x_float, 0).output(0)\n\n        is_equal = ov_opset.equal(min_val, max_val)\n        half = ov_opset.constant(0.5, 
float_type).output(0)\n        min_val = ov_opset.select(\n            is_equal, ov_opset.subtract(min_val, half), min_val\n        )\n        max_val = ov_opset.select(\n            is_equal, ov_opset.add(max_val, half), max_val\n        )\n\n        min_val = min_val.output(0)\n        max_val = max_val.output(0)\n    else:\n        min_val = ov_opset.constant(range[0], float_type).output(0)\n        max_val = ov_opset.constant(range[1], float_type).output(0)\n\n    bins_const = ov_opset.constant(bins, float_type).output(0)\n    step = ov_opset.divide(\n        ov_opset.subtract(max_val, min_val), bins_const\n    ).output(0)\n\n    idx_float = ov_opset.range(\n        ov_opset.constant(0, float_type),\n        ov_opset.constant(bins + 1, float_type),\n        ov_opset.constant(1, float_type),\n        output_type=float_type,\n    ).output(0)\n\n    bin_edges = ov_opset.add(\n        min_val, ov_opset.multiply(idx_float, step)\n    ).output(0)\n\n    inds = ov_opset.bucketize(\n        x_float, bin_edges, output_type=Type.i32, with_right_bound=False\n    ).output(0)\n\n    inds_shifted = ov_opset.subtract(\n        inds, ov_opset.constant(1, Type.i32).output(0)\n    )\n\n    trash_idx = ov_opset.constant(bins, Type.i32).output(0)\n\n    is_under = ov_opset.less(\n        inds_shifted, ov_opset.constant(0, Type.i32).output(0)\n    )\n    is_over = ov_opset.greater_equal(inds_shifted, trash_idx)\n\n    is_max = ov_opset.equal(x_float, max_val)\n\n    final_inds = inds_shifted\n    final_inds = ov_opset.select(is_under, trash_idx, final_inds)\n\n    bins_minus_1 = ov_opset.constant(bins - 1, Type.i32).output(0)\n    replacement = ov_opset.select(is_max, bins_minus_1, trash_idx)\n    final_inds = ov_opset.select(is_over, replacement, final_inds)\n\n    depth = ov_opset.constant(bins + 1, Type.i32).output(0)\n    on_val = ov_opset.constant(1, Type.i32).output(0)\n    off_val = ov_opset.constant(0, Type.i32).output(0)\n\n    one_hot = ov_opset.one_hot(final_inds, depth, 
on_val, off_val, axis=-1)\n    counts = ov_opset.reduce_sum(\n        one_hot, ov_opset.constant(0, Type.i32).output(0), keep_dims=False\n    )\n\n    hist = ov_opset.slice(\n        counts,\n        ov_opset.constant([0], Type.i32).output(0),\n        ov_opset.constant([bins], Type.i32).output(0),\n        ov_opset.constant([1], Type.i32).output(0),\n        ov_opset.constant([0], Type.i32).output(0),\n    )\n\n    return OpenVINOKerasTensor(hist.output(0)), OpenVINOKerasTensor(bin_edges)\n"
  },
  {
    "path": "keras/src/backend/openvino/random.py",
    "content": "import numpy as np\nimport openvino.opset15 as ov_opset\nfrom openvino import Type\n\nfrom keras.src.backend.config import floatx\nfrom keras.src.backend.openvino import numpy as ov_numpy\nfrom keras.src.backend.openvino.core import OPENVINO_DTYPES\nfrom keras.src.backend.openvino.core import OpenVINOKerasTensor\nfrom keras.src.backend.openvino.core import convert_to_numpy\nfrom keras.src.backend.openvino.core import get_ov_output\nfrom keras.src.random.seed_generator import SeedGenerator\nfrom keras.src.random.seed_generator import draw_seed\nfrom keras.src.random.seed_generator import make_default_seed\n\n\ndef _rng_from_seed_data(seed_data):\n    \"\"\"Create a NumPy RNG from seed tensor data.\n\n    Seed tensors are stored as int32 and may be negative due to C-style\n    wrapping of large user-supplied values.  Reinterpret the bit pattern\n    as uint32 so that np.random.default_rng receives non-negative entropy.\n    \"\"\"\n    if seed_data is None:\n        return np.random.default_rng()\n    arr = np.asarray(seed_data, dtype=np.int32).view(np.uint32)\n    return np.random.default_rng(arr)\n\n\ndef _random_uniform(shape, minval, maxval, dtype, seed1, seed2):\n    \"\"\"Wrapper for `ov_opset.random_uniform` that sanitizes seed values.\n\n    Pybind11 will sign-flip values >= 2**31, and int32 wrap-around can\n    produce negative values. 
Masks seeds to 31 bits so they are in the range\n    [0, 2**31-1], then clamps to at least 1, as the OpenVINO C++ layer\n    requires each seed to be strictly positive (> 0).\n\n    Args:\n        shape: The shape of the random tensor.\n        minval: The lower bound of the random distribution.\n        maxval: The upper bound of the random distribution.\n        dtype: The data type of the output tensor.\n        seed1: The first part of the seed.\n        seed2: The second part of the seed.\n\n    Returns:\n        An OpenVINO output tensor with random values.\n    \"\"\"\n    safe_seed1 = max(1, int(seed1) & 0x7FFFFFFF)\n    safe_seed2 = max(1, int(seed2) & 0x7FFFFFFF)\n    return ov_opset.random_uniform(\n        shape, minval, maxval, dtype, safe_seed1, safe_seed2\n    ).output(0)\n\n\ndef normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = draw_seed(seed)\n    rng = _rng_from_seed_data(seed.data)\n    normal_const = rng.normal(size=shape, loc=mean, scale=stddev).astype(dtype)\n    return OpenVINOKerasTensor(ov_opset.constant(normal_const).output(0))\n\n\ndef uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed_val = draw_seed(seed)\n    if isinstance(seed_val, OpenVINOKerasTensor):\n        seed_data = convert_to_numpy(seed_val)\n    else:\n        seed_data = seed_val.data\n    rng = _rng_from_seed_data(seed_data)\n    random_values = rng.uniform(minval, maxval, size=shape).astype(dtype)\n    return OpenVINOKerasTensor(ov_opset.constant(random_values).output(0))\n\n\ndef categorical(logits, num_samples, dtype=\"int64\", seed=None):\n    dtype = dtype or \"int64\"\n    ov_dtype = OPENVINO_DTYPES[dtype]\n    logits = get_ov_output(logits)\n\n    zero_const = ov_opset.constant(0, Type.i32).output(0)\n    one_const = ov_opset.constant(1, Type.i32).output(0)\n    neg_one_const = ov_opset.constant(-1, Type.i32).output(0)\n\n    # Compute probabilities and 
cumulative sum\n    probs = ov_opset.softmax(logits, axis=-1).output(0)\n    cumsum_probs = ov_opset.cumsum(probs, neg_one_const).output(0)\n\n    # Get shape and compute batch dimensions\n    logits_shape = ov_opset.shape_of(logits, Type.i32).output(0)\n    rank = ov_opset.shape_of(logits_shape, Type.i32).output(0)\n    rank_scalar = ov_opset.squeeze(rank, zero_const).output(0)\n    rank_minus_1 = ov_opset.subtract(rank_scalar, one_const).output(0)\n\n    # Extract batch shape (all dimensions except last)\n    batch_indices = ov_opset.range(\n        zero_const, rank_minus_1, one_const, output_type=Type.i32\n    ).output(0)\n    batch_shape = ov_opset.gather(logits_shape, batch_indices, axis=0).output(0)\n\n    # Create final shape [batch_dims..., num_samples]\n    num_samples_const = ov_opset.constant([num_samples], Type.i32).output(0)\n    final_shape = ov_opset.concat(\n        [batch_shape, num_samples_const], axis=0\n    ).output(0)\n\n    seed_tensor = draw_seed(seed)\n    if isinstance(seed_tensor, OpenVINOKerasTensor):\n        seed1, seed2 = convert_to_numpy(seed_tensor)\n    else:\n        seed1, seed2 = seed_tensor.data\n\n    probs_dtype = probs.get_element_type()\n    zero_float = ov_opset.constant(0.0, probs_dtype).output(0)\n    one_float = ov_opset.constant(1.0, probs_dtype).output(0)\n\n    rand = _random_uniform(\n        final_shape, zero_float, one_float, probs_dtype, seed1, seed2\n    )\n\n    rand_unsqueezed = ov_opset.unsqueeze(rand, neg_one_const).output(0)\n    cumsum_unsqueezed = ov_opset.unsqueeze(cumsum_probs, one_const).output(0)\n\n    # Count how many cumulative probabilities each random number exceeds\n    greater = ov_opset.greater(rand_unsqueezed, cumsum_unsqueezed).output(0)\n    samples = ov_opset.reduce_sum(\n        ov_opset.convert(greater, Type.i32).output(0), neg_one_const\n    ).output(0)\n\n    result = ov_opset.convert(samples, ov_dtype).output(0)\n    return OpenVINOKerasTensor(result)\n\n\ndef randint(shape, minval, 
maxval, dtype=\"int32\", seed=None):\n    dtype = dtype or \"int32\"\n    ov_dtype = OPENVINO_DTYPES[dtype]\n    seed_val = draw_seed(seed)\n    if isinstance(seed_val, OpenVINOKerasTensor):\n        seed1, seed2 = convert_to_numpy(seed_val)\n    else:\n        seed1, seed2 = seed_val.data\n    if ov_dtype in (Type.i64, Type.u64, Type.u32):\n        gen_dtype = Type.i64\n    else:\n        gen_dtype = Type.i32\n    if isinstance(shape, (list, tuple)):\n        shape = ov_opset.constant(list(shape), Type.i32).output(0)\n    elif isinstance(shape, OpenVINOKerasTensor):\n        shape = shape.output\n    elif isinstance(shape, int):\n        shape = ov_opset.constant([shape], Type.i32).output(0)\n    else:\n        shape = get_ov_output(shape, Type.i32)\n    minval = get_ov_output(minval, gen_dtype)\n    maxval = get_ov_output(maxval, gen_dtype)\n    if minval.get_element_type() != gen_dtype:\n        minval = ov_opset.convert(minval, gen_dtype).output(0)\n    if maxval.get_element_type() != gen_dtype:\n        maxval = ov_opset.convert(maxval, gen_dtype).output(0)\n    rand = _random_uniform(shape, minval, maxval, gen_dtype, seed1, seed2)\n    if ov_dtype != gen_dtype:\n        result = ov_opset.convert(rand, ov_dtype).output(0)\n    else:\n        result = rand\n    return OpenVINOKerasTensor(result)\n\n\ndef truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = draw_seed(seed)\n    rng = _rng_from_seed_data(seed.data)\n\n    lower_bound = mean - 2 * stddev\n    upper_bound = mean + 2 * stddev\n\n    flat_shape = np.prod(shape)\n    random_numbers = np.empty(0)\n\n    # loop until we have enough valid numbers to fill our desired shape\n    while random_numbers.shape[0] < flat_shape:\n        # Generate a batch of random numbers from a normal distribution\n        batch = rng.normal(loc=mean, scale=stddev, size=flat_shape)\n\n        # Filter the numbers to keep only those within the specified bounds\n        
valid = batch[(batch >= lower_bound) & (batch <= upper_bound)]\n\n        # Append the valid numbers to the result array\n        random_numbers = np.append(random_numbers, valid)\n\n    # Truncate the result array to the desired size and reshape it\n    np_array_res = random_numbers[:flat_shape].astype(dtype).reshape(shape)\n    return OpenVINOKerasTensor(ov_opset.constant(np_array_res).output(0))\n\n\ndef dropout(inputs, rate, noise_shape=None, seed=None):\n    inputs_ov = get_ov_output(inputs)\n    dtype = inputs_ov.get_element_type()\n\n    seed_val = draw_seed(seed)\n    if isinstance(seed_val, OpenVINOKerasTensor):\n        seed1, seed2 = convert_to_numpy(seed_val)\n    else:\n        seed1, seed2 = seed_val.data\n\n    if not isinstance(rate, (int, float)):\n        rate = get_ov_output(rate, dtype)\n    else:\n        rate = ov_opset.constant(rate, dtype).output(0)\n\n    one = ov_opset.constant(1.0, dtype).output(0)\n    keep_prob = ov_opset.subtract(one, rate).output(0)\n\n    if noise_shape is None:\n        noise_shape_node = ov_opset.shape_of(inputs_ov, Type.i32).output(0)\n    else:\n        shape_elements = []\n        input_shape_node = ov_opset.shape_of(inputs_ov, Type.i32).output(0)\n        zero_index = ov_opset.constant(0, Type.i32).output(0)\n\n        for i, dim in enumerate(noise_shape):\n            if dim is None:\n                indices = ov_opset.constant([i], Type.i32).output(0)\n                dim_node = ov_opset.gather(\n                    input_shape_node, indices, zero_index\n                ).output(0)\n                shape_elements.append(dim_node)\n            else:\n                shape_elements.append(\n                    ov_opset.constant([dim], Type.i32).output(0)\n                )\n\n        noise_shape_node = ov_opset.concat(shape_elements, 0).output(0)\n\n    gen_dtype = dtype\n    if dtype in (Type.bf16, Type.f16):\n        gen_dtype = Type.f32\n\n    min_val = ov_opset.constant(0.0, gen_dtype).output(0)\n    
max_val = ov_opset.constant(1.0, gen_dtype).output(0)\n\n    rand = _random_uniform(\n        noise_shape_node, min_val, max_val, gen_dtype, seed1, seed2\n    )\n\n    if gen_dtype != dtype:\n        keep_prob_gen = ov_opset.convert(keep_prob, gen_dtype).output(0)\n        mask = ov_opset.less(rand, keep_prob_gen).output(0)\n    else:\n        mask = ov_opset.less(rand, keep_prob).output(0)\n\n    zero = ov_opset.constant(0.0, dtype).output(0)\n    one_dtype = ov_opset.constant(1.0, dtype).output(0)\n\n    is_zero_prob = ov_opset.equal(keep_prob, zero).output(0)\n    safe_prob = ov_opset.select(is_zero_prob, one_dtype, keep_prob).output(0)\n    inv_prob = ov_opset.divide(one_dtype, safe_prob).output(0)\n    scale = ov_opset.select(is_zero_prob, zero, inv_prob).output(0)\n\n    mask_casted = ov_opset.convert(mask, dtype).output(0)\n\n    masked_inputs = ov_opset.multiply(inputs_ov, mask_casted).output(0)\n    result = ov_opset.multiply(masked_inputs, scale).output(0)\n\n    return OpenVINOKerasTensor(result)\n\n\ndef shuffle(x, axis=0, seed=None):\n    seed_tensor = draw_seed(seed)\n    if isinstance(seed_tensor, OpenVINOKerasTensor):\n        seed1, seed2 = convert_to_numpy(seed_tensor)\n    else:\n        seed1, seed2 = seed_tensor.data\n    x_ov = get_ov_output(x)\n    x_shape = x_ov.get_partial_shape()\n    rank = x_shape.rank.get_length()\n    if axis < 0:\n        axis += rank\n    shape_tensor = ov_opset.shape_of(x_ov, Type.i32).output(0)\n    dim_size = ov_opset.gather(\n        shape_tensor,\n        ov_opset.constant([axis], Type.i32).output(0),\n        ov_opset.constant(0, Type.i32).output(0),\n    ).output(0)\n    min_val = ov_opset.constant(0.0, Type.f32).output(0)\n    max_val = ov_opset.constant(1.0, Type.f32).output(0)\n    rand_shape = ov_opset.reshape(\n        dim_size, ov_opset.constant([1], Type.i32).output(0), False\n    ).output(0)\n    rand_values = _random_uniform(\n        rand_shape, min_val, max_val, Type.f32, seed1, seed2\n    )\n    
indices = ov_numpy.argsort(OpenVINOKerasTensor(rand_values), axis=0)\n    return ov_numpy.take(x, indices, axis=axis)\n\n\ndef _const(val, dtype):\n    if dtype == Type.bf16:\n        return ov_opset.convert(\n            ov_opset.constant(val, Type.f32), Type.bf16\n        ).output(0)\n    return ov_opset.constant(val, dtype).output(0)\n\n\ndef _random_normal(shape, dtype, seed1, seed2):\n    zero = _const(0.0, dtype)\n    one = _const(1.0, dtype)\n    two_pi = _const(2 * np.pi, dtype)\n    minus_two = _const(-2.0, dtype)\n    epsilon = _const(1e-7, dtype)\n    u1 = _random_uniform(shape, zero, one, dtype, seed1, seed2)\n    u2 = _random_uniform(shape, zero, one, dtype, seed1 + 123, seed2)\n    u1 = ov_opset.add(u1, epsilon).output(0)\n    mag = ov_opset.sqrt(ov_opset.multiply(minus_two, ov_opset.log(u1))).output(\n        0\n    )\n    angle = ov_opset.multiply(two_pi, u2).output(0)\n    z0 = ov_opset.multiply(mag, ov_opset.cos(angle)).output(0)\n    return z0\n\n\ndef gamma(shape, alpha, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    ov_dtype = OPENVINO_DTYPES[dtype]\n    seed_val = draw_seed(seed)\n    if isinstance(seed_val, OpenVINOKerasTensor):\n        seed1, seed2 = convert_to_numpy(seed_val)\n    else:\n        seed1, seed2 = seed_val.data\n    seed1 = int(seed1)\n    seed2 = int(seed2)\n    if isinstance(shape, (list, tuple)):\n        shape = ov_opset.constant(list(shape), Type.i32).output(0)\n    elif isinstance(shape, OpenVINOKerasTensor):\n        shape = shape.output\n    else:\n        shape = get_ov_output(shape, Type.i32)\n    alpha = get_ov_output(alpha, ov_dtype)\n    one = _const(1.0, ov_dtype)\n    one_third = _const(1.0 / 3.0, ov_dtype)\n    zero = _const(0.0, ov_dtype)\n    is_small_alpha = ov_opset.less(alpha, one).output(0)\n    alpha_boosted = ov_opset.select(\n        is_small_alpha, ov_opset.add(alpha, one), alpha\n    ).output(0)\n    d = ov_opset.subtract(alpha_boosted, one_third).output(0)\n    c = ov_opset.divide(\n    
    one,\n        ov_opset.sqrt(ov_opset.multiply(_const(9.0, ov_dtype), d)),\n    ).output(0)\n    samples = ov_opset.broadcast(zero, shape).output(0)\n    mask = ov_opset.broadcast(\n        ov_opset.constant(False, Type.boolean), shape\n    ).output(0)\n    num_iters = 10\n    for i in range(num_iters):\n        iter_seed = seed1 + i * 1000\n        x = _random_normal(shape, ov_dtype, iter_seed, seed2)\n        cx = ov_opset.multiply(c, x).output(0)\n        v_base = ov_opset.add(one, cx).output(0)\n        v = ov_opset.power(v_base, _const(3.0, ov_dtype)).output(0)\n        v_pos = ov_opset.greater(v, zero).output(0)\n        u = _random_uniform(shape, zero, one, ov_dtype, iter_seed + 500, seed2)\n        x2 = ov_opset.multiply(x, x).output(0)\n        x4 = ov_opset.multiply(x2, x2).output(0)\n        c1_val = ov_opset.subtract(\n            one, ov_opset.multiply(_const(0.0331, ov_dtype), x4)\n        ).output(0)\n        accept1 = ov_opset.less(u, c1_val).output(0)\n        v_safe = ov_opset.select(v_pos, v, one).output(0)\n        log_u = ov_opset.log(u).output(0)\n        log_v = ov_opset.log(v_safe).output(0)\n        term2 = ov_opset.multiply(\n            d, ov_opset.add(ov_opset.subtract(one, v), log_v)\n        ).output(0)\n        rhs = ov_opset.add(\n            ov_opset.multiply(_const(0.5, ov_dtype), x2), term2\n        ).output(0)\n        accept2 = ov_opset.less(log_u, rhs).output(0)\n        accepted = ov_opset.logical_or(accept1, accept2).output(0)\n        accepted = ov_opset.logical_and(accepted, v_pos).output(0)\n        dv = ov_opset.multiply(d, v).output(0)\n        update_mask = ov_opset.logical_and(\n            ov_opset.logical_not(mask), accepted\n        ).output(0)\n        samples = ov_opset.select(update_mask, dv, samples).output(0)\n        mask = ov_opset.logical_or(mask, accepted).output(0)\n    u_final = _random_uniform(shape, zero, one, ov_dtype, seed1 + 9999, seed2)\n    pow_exp = ov_opset.divide(one, alpha).output(0)\n    
u_pow = ov_opset.power(u_final, pow_exp).output(0)\n    adjusted_samples = ov_opset.multiply(samples, u_pow).output(0)\n    final_samples = ov_opset.select(\n        is_small_alpha, adjusted_samples, samples\n    ).output(0)\n    return OpenVINOKerasTensor(final_samples)\n\n\ndef binomial(shape, counts, probabilities, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    ov_dtype = OPENVINO_DTYPES[dtype]\n    seed_val = draw_seed(seed)\n    if isinstance(seed_val, OpenVINOKerasTensor):\n        seed1, seed2 = convert_to_numpy(seed_val)\n    else:\n        seed1, seed2 = seed_val.data\n    counts = get_ov_output(counts)\n    probabilities = get_ov_output(probabilities)\n    calc_dtype = Type.f32\n    counts_f = ov_opset.convert(counts, calc_dtype).output(0)\n    probs_f = ov_opset.convert(probabilities, calc_dtype).output(0)\n    if isinstance(shape, (list, tuple)):\n        shape_tensor = ov_opset.constant(list(shape), Type.i32).output(0)\n    elif isinstance(shape, OpenVINOKerasTensor):\n        shape_tensor = shape.output\n    else:\n        shape_tensor = get_ov_output(shape, Type.i32)\n    zero = ov_opset.constant(0.0, calc_dtype).output(0)\n    one = ov_opset.constant(1.0, calc_dtype).output(0)\n    u1 = _random_uniform(shape_tensor, zero, one, calc_dtype, seed1, seed2)\n    u2 = _random_uniform(shape_tensor, zero, one, calc_dtype, seed1, seed2 + 1)\n    epsilon = 1e-7\n    epsilon_const = ov_opset.constant(epsilon, calc_dtype).output(0)\n    u1_safe = ov_opset.maximum(u1, epsilon_const).output(0)\n    log_u1 = ov_opset.log(u1_safe).output(0)\n    neg_two = ov_opset.constant(-2.0, calc_dtype).output(0)\n    two_pi = ov_opset.constant(2 * np.pi, calc_dtype).output(0)\n    r = ov_opset.sqrt(ov_opset.multiply(neg_two, log_u1)).output(0)\n    theta = ov_opset.multiply(two_pi, u2).output(0)\n    z = ov_opset.multiply(r, ov_opset.cos(theta)).output(0)\n    mean = ov_opset.multiply(counts_f, probs_f).output(0)\n    one_minus_p = ov_opset.subtract(one, 
probs_f).output(0)\n    var = ov_opset.multiply(mean, one_minus_p).output(0)\n    std = ov_opset.sqrt(var).output(0)\n    res_normal = ov_opset.add(mean, ov_opset.multiply(std, z)).output(0)\n    res_normal = ov_opset.round(res_normal, mode=\"half_to_even\").output(0)\n    res_normal = ov_opset.maximum(res_normal, zero).output(0)\n    res_normal = ov_opset.minimum(res_normal, counts_f).output(0)\n    is_one = ov_opset.equal(counts_f, one).output(0)\n    bernoulli = ov_opset.less(u1, probs_f).output(0)\n    bernoulli_f = ov_opset.convert(bernoulli, calc_dtype).output(0)\n    res = ov_opset.select(is_one, bernoulli_f, res_normal).output(0)\n    if ov_dtype != calc_dtype:\n        res = ov_opset.convert(res, ov_dtype).output(0)\n    return OpenVINOKerasTensor(res)\n\n\ndef beta(shape, alpha, beta, dtype=None, seed=None):\n    seed1 = seed\n    seed2 = seed\n    if isinstance(seed, int):\n        seed2 += 123\n\n    gamma_a = get_ov_output(gamma(shape, alpha, dtype=dtype, seed=seed1))\n    gamma_b = get_ov_output(gamma(shape, beta, dtype=dtype, seed=seed2))\n\n    sum_ab = ov_opset.add(gamma_a, gamma_b).output(0)\n    z = ov_opset.divide(gamma_a, sum_ab).output(0)\n\n    return OpenVINOKerasTensor(z)\n"
  },
  {
    "path": "keras/src/backend/openvino/rnn.py",
    "content": "import openvino.opset15 as ov_opset\nfrom openvino import Model\nfrom openvino import Type\n\nfrom keras.src import tree\nfrom keras.src.backend.openvino.core import OpenVINOKerasTensor\nfrom keras.src.backend.openvino.core import get_ov_output\n\n\ndef rnn(\n    step_function,\n    inputs,\n    initial_states,\n    go_backwards=False,\n    mask=None,\n    constants=None,\n    unroll=False,\n    input_length=None,\n    time_major=False,\n    zero_output_for_mask=False,\n    return_all_outputs=True,\n):\n    def swap_batch_timestep(input_t):\n        axes = list(range(len(input_t.shape)))\n        axes[0], axes[1] = 1, 0\n        perm_const = ov_opset.constant(axes, dtype=Type.i32).output(0)\n        input_ov = get_ov_output(input_t)\n        return OpenVINOKerasTensor(\n            ov_opset.transpose(input_ov, perm_const).output(0)\n        )\n\n    if not time_major:\n        inputs = tree.map_structure(swap_batch_timestep, inputs)\n        if mask is not None:\n            mask = swap_batch_timestep(mask)\n    flattened_inputs = tree.flatten(inputs)\n    flattened_states = tree.flatten(initial_states)\n    flattened_constants = tree.flatten(constants) if constants else []\n    input_0 = flattened_inputs[0]\n    input_0_ov = get_ov_output(input_0)\n    input_shape = ov_opset.shape_of(input_0_ov, Type.i32).output(0)\n    time_steps = ov_opset.gather(\n        input_shape,\n        ov_opset.constant([0], dtype=Type.i32).output(0),\n        ov_opset.constant(0, dtype=Type.i32).output(0),\n    ).output(0)\n    time_steps = ov_opset.squeeze(\n        time_steps, ov_opset.constant([0], dtype=Type.i32).output(0)\n    ).output(0)\n    if mask is None and input_length is not None:\n        input_len_ov = get_ov_output(input_length)\n        if input_len_ov.get_partial_shape().rank.get_length() == 1:\n            indices = ov_opset.range(\n                ov_opset.constant(0, dtype=Type.i32).output(0),\n                time_steps,\n                
ov_opset.constant(1, dtype=Type.i32).output(0),\n                output_type=Type.i32,\n            ).output(0)\n            indices = ov_opset.unsqueeze(\n                indices, ov_opset.constant([1], dtype=Type.i32).output(0)\n            ).output(0)\n            input_len_casted = ov_opset.convert(input_len_ov, Type.i32).output(\n                0\n            )\n            input_len_expanded = ov_opset.unsqueeze(\n                input_len_casted,\n                ov_opset.constant([0], dtype=Type.i32).output(0),\n            ).output(0)\n            mask_bool = ov_opset.less(indices, input_len_expanded).output(0)\n            mask = OpenVINOKerasTensor(mask_bool)\n    if mask is not None:\n        mask_ov = get_ov_output(mask)\n        if mask_ov.get_element_type() != Type.boolean:\n            mask_ov = ov_opset.convert(mask_ov, Type.boolean).output(0)\n        pshape = mask_ov.get_partial_shape()\n        rank = pshape.rank.get_length()\n        if rank == 2:\n            mask_ov = ov_opset.unsqueeze(\n                mask_ov, ov_opset.constant([-1], dtype=Type.i32).output(0)\n            ).output(0)\n        mask = OpenVINOKerasTensor(mask_ov)\n    if go_backwards:\n\n        def reverse_time(x):\n            x_ov = get_ov_output(x)\n            start = ov_opset.constant([0], dtype=Type.i32).output(0)\n            idx = ov_opset.range(\n                ov_opset.subtract(\n                    time_steps, ov_opset.constant(1, dtype=Type.i32).output(0)\n                ).output(0),\n                ov_opset.constant(-1, dtype=Type.i32).output(0),\n                ov_opset.constant(-1, dtype=Type.i32).output(0),\n                output_type=Type.i32,\n            ).output(0)\n            return OpenVINOKerasTensor(\n                ov_opset.gather(x_ov, idx, start).output(0)\n            )\n\n        inputs = tree.map_structure(reverse_time, inputs)\n        if mask is not None:\n            mask = reverse_time(mask)\n        flattened_inputs = 
tree.flatten(inputs)\n\n    def _slice_at_0(x):\n        x_ov = get_ov_output(x)\n        slice_0 = ov_opset.gather(\n            x_ov,\n            ov_opset.constant(0, dtype=Type.i32).output(0),\n            ov_opset.constant(0, dtype=Type.i32).output(0),\n        ).output(0)\n        return OpenVINOKerasTensor(slice_0)\n\n    inputs_0 = tree.map_structure(_slice_at_0, inputs)\n    output_0, _ = step_function(\n        inputs_0, tuple(initial_states) + tuple(constants or [])\n    )\n    flattened_output_0 = tree.flatten(output_0)\n    last_output_states = []\n    for out in flattened_output_0:\n        out_ov = get_ov_output(out)\n        shape = ov_opset.shape_of(out_ov, Type.i32).output(0)\n        dtype = out_ov.get_element_type()\n        zeros = ov_opset.broadcast(\n            ov_opset.constant(0, dtype).output(0), shape\n        ).output(0)\n        last_output_states.append(OpenVINOKerasTensor(zeros))\n    params = []\n    sliced_inputs_params = []\n    for inp in flattened_inputs:\n        inp_ov = get_ov_output(inp)\n        pshape = inp_ov.get_partial_shape()\n        if pshape.rank.is_static:\n            new_shape = [1] + list(pshape)[1:]\n        else:\n            new_shape = None\n        param = ov_opset.parameter(new_shape, inp_ov.get_element_type())\n        sliced_inputs_params.append(param)\n        params.append(param)\n    sliced_mask_params = []\n    if mask is not None:\n        mask_ov = get_ov_output(mask)\n        pshape = mask_ov.get_partial_shape()\n        new_shape = [1] + list(pshape)[1:] if pshape.rank.is_static else None\n        param = ov_opset.parameter(new_shape, mask_ov.get_element_type())\n        sliced_mask_params.append(param)\n        params.append(param)\n    merged_states_params = []\n    for st in flattened_states:\n        st_ov = get_ov_output(st)\n        param = ov_opset.parameter(\n            st_ov.get_partial_shape(), st_ov.get_element_type()\n        )\n        merged_states_params.append(param)\n        
params.append(param)\n    last_output_params = []\n    for lo in last_output_states:\n        lo_ov = get_ov_output(lo)\n        param = ov_opset.parameter(\n            lo_ov.get_partial_shape(), lo_ov.get_element_type()\n        )\n        last_output_params.append(param)\n        params.append(param)\n    constants_params = []\n    for c in flattened_constants:\n        c_ov = get_ov_output(c)\n        param = ov_opset.parameter(\n            c_ov.get_partial_shape(), c_ov.get_element_type()\n        )\n        constants_params.append(param)\n        params.append(param)\n    sliced_inputs_t = [\n        OpenVINOKerasTensor(\n            ov_opset.squeeze(\n                p.output(0),\n                ov_opset.constant([0], dtype=Type.i32).output(0),\n            ).output(0)\n        )\n        for p in sliced_inputs_params\n    ]\n    merged_states_t = [\n        OpenVINOKerasTensor(p.output(0)) for p in merged_states_params\n    ]\n    constants_t = [OpenVINOKerasTensor(p.output(0)) for p in constants_params]\n\n    packed_inputs = tree.pack_sequence_as(inputs, sliced_inputs_t)\n    packed_states = tree.pack_sequence_as(initial_states, merged_states_t)\n    step_output, step_new_states = step_function(\n        packed_inputs, tuple(packed_states) + tuple(constants_t)\n    )\n    flat_step_output = tree.flatten(step_output)\n    flat_step_new_states = tree.flatten(step_new_states)\n    final_output_list = []\n    final_states_list = []\n    final_last_output_list = []\n    if mask is not None:\n        mask_t = ov_opset.squeeze(\n            sliced_mask_params[0].output(0),\n            ov_opset.constant([0], dtype=Type.i32).output(0),\n        ).output(0)\n        for i, (new_st, old_st) in enumerate(\n            zip(flat_step_new_states, merged_states_t)\n        ):\n            new_st_ov = get_ov_output(new_st)\n            old_st_ov = get_ov_output(old_st)\n            res = ov_opset.select(mask_t, new_st_ov, old_st_ov).output(0)\n            
final_states_list.append(res)\n        for i, (new_out, old_last_out) in enumerate(\n            zip(flat_step_output, last_output_params)\n        ):\n            new_out_ov = get_ov_output(new_out)\n            old_last_out_ov = old_last_out.output(0)\n            last_out_res = ov_opset.select(\n                mask_t, new_out_ov, old_last_out_ov\n            ).output(0)\n            final_last_output_list.append(last_out_res)\n            if zero_output_for_mask:\n                zero = ov_opset.broadcast(\n                    ov_opset.constant(0, new_out_ov.get_element_type()).output(\n                        0\n                    ),\n                    ov_opset.shape_of(new_out_ov, Type.i32).output(0),\n                ).output(0)\n                seq_out_res = ov_opset.select(mask_t, new_out_ov, zero).output(\n                    0\n                )\n            else:\n                seq_out_res = last_out_res\n            final_output_list.append(seq_out_res)\n    else:\n        final_states_list = [get_ov_output(x) for x in flat_step_new_states]\n        final_output_list = [get_ov_output(x) for x in flat_step_output]\n        final_last_output_list = [get_ov_output(x) for x in flat_step_output]\n    unsq_ax = ov_opset.constant([0], dtype=Type.i32).output(0)\n    final_output_list = [\n        ov_opset.unsqueeze(x, unsq_ax).output(0) for x in final_output_list\n    ]\n    cond_const = ov_opset.constant(True, Type.boolean).output(0)\n    results = (\n        [cond_const]\n        + final_states_list\n        + final_last_output_list\n        + final_output_list\n    )\n    body_model = Model(results, params)\n    exec_cond_in = ov_opset.constant(True, Type.boolean).output(0)\n    loop = ov_opset.loop(time_steps, exec_cond_in)\n    loop.set_function(body_model)\n    loop.set_special_body_ports([-1, 0])\n    for param, inp in zip(sliced_inputs_params, flattened_inputs):\n        loop.set_sliced_input(param, get_ov_output(inp), 0, 1, 1, -1, 0)\n    if mask 
is not None:\n        mask_ov = get_ov_output(mask)\n        loop.set_sliced_input(sliced_mask_params[0], mask_ov, 0, 1, 1, -1, 0)\n    current_res_idx = 1\n    for param, init in zip(merged_states_params, flattened_states):\n        loop.set_merged_input(\n            param, get_ov_output(init), results[current_res_idx]\n        )\n        current_res_idx += 1\n    final_last_outputs_res = []\n    for param, init in zip(last_output_params, last_output_states):\n        loop.set_merged_input(\n            param, get_ov_output(init), results[current_res_idx]\n        )\n        final_last_outputs_res.append(results[current_res_idx])\n        current_res_idx += 1\n    for param, val in zip(constants_params, flattened_constants):\n        loop.set_invariant_input(param, get_ov_output(val))\n    loop_outputs = []\n    for _ in final_output_list:\n        out = loop.get_concatenated_slices(\n            results[current_res_idx], 0, 1, 1, -1, 0\n        )\n        loop_outputs.append(OpenVINOKerasTensor(out))\n        current_res_idx += 1\n    loop_final_states = []\n    st_res_idx = 1\n    for _ in flattened_states:\n        out = loop.get_iter_value(results[st_res_idx], -1)\n        loop_final_states.append(OpenVINOKerasTensor(out))\n        st_res_idx += 1\n    lo_res_idx = st_res_idx\n    loop_final_last_outputs = []\n    for _ in last_output_states:\n        out = loop.get_iter_value(results[lo_res_idx], -1)\n        loop_final_last_outputs.append(OpenVINOKerasTensor(out))\n        lo_res_idx += 1\n    outputs = tree.pack_sequence_as(output_0, loop_outputs)\n    new_states = tree.pack_sequence_as(initial_states, loop_final_states)\n    last_output = tree.pack_sequence_as(output_0, loop_final_last_outputs)\n    if not time_major:\n        outputs = tree.map_structure(swap_batch_timestep, outputs)\n    return last_output, outputs, new_states\n\n\ndef _reorder_gates(x_ov, from_order, to_order, axis):\n    \"\"\"Reorder gate slices of a tensor along `axis`.\n\n    
`from_order` and `to_order` are lists of single-char gate names, e.g.\n    from_order=['i','f','c','o'], to_order=['f','i','c','o'].\n    The tensor dimension along `axis` must be divisible by len(from_order).\n    \"\"\"\n    n_gates = len(from_order)\n    axis_const = ov_opset.constant(axis, dtype=Type.i32).output(0)\n    chunks = ov_opset.split(x_ov, axis_const, n_gates).outputs()\n    gate_map = {g: chunks[i] for i, g in enumerate(from_order)}\n    reordered = [gate_map[g] for g in to_order]\n    return ov_opset.concat(reordered, axis=axis).output(0)\n\n\ndef _seq_lengths(inputs_ov):\n    \"\"\"Return int32 sequence-length tensor [batch] equal to full time steps.\"\"\"\n    input_shape = ov_opset.shape_of(inputs_ov, Type.i32).output(0)\n    batch_size = ov_opset.gather(\n        input_shape,\n        ov_opset.constant([0], dtype=Type.i32).output(0),\n        ov_opset.constant(0, dtype=Type.i32).output(0),\n    ).output(0)\n    time_steps = ov_opset.gather(\n        input_shape,\n        ov_opset.constant([1], dtype=Type.i32).output(0),\n        ov_opset.constant(0, dtype=Type.i32).output(0),\n    ).output(0)\n    return ov_opset.broadcast(time_steps, batch_size).output(0)\n\n\ndef lstm(\n    inputs,\n    initial_h,\n    initial_c,\n    mask,\n    kernel,\n    recurrent_kernel,\n    bias,\n    activation,\n    recurrent_activation,\n    return_sequences=False,\n    go_backwards=False,\n    unroll=False,\n):\n    act_name = getattr(activation, \"__name__\", None)\n    rec_act_name = getattr(recurrent_activation, \"__name__\", None)\n    if not (\n        act_name == \"tanh\"\n        and rec_act_name == \"sigmoid\"\n        and not unroll\n        and bias is not None\n        and mask is None\n    ):\n        raise NotImplementedError\n\n    inputs_ov = get_ov_output(inputs)\n    initial_h_ov = get_ov_output(initial_h)\n    initial_c_ov = get_ov_output(initial_c)\n    kernel_ov = get_ov_output(kernel)\n    recurrent_kernel_ov = get_ov_output(recurrent_kernel)\n  
  bias_ov = get_ov_output(bias)\n\n    hidden_size = recurrent_kernel_ov.get_partial_shape()[0].get_length()\n\n    kt = ov_opset.transpose(\n        kernel_ov,\n        ov_opset.constant([1, 0], dtype=Type.i32).output(0),\n    ).output(0)\n    w = _reorder_gates(kt, [\"i\", \"f\", \"c\", \"o\"], [\"f\", \"i\", \"c\", \"o\"], axis=0)\n    w = ov_opset.unsqueeze(\n        w, ov_opset.constant([0], dtype=Type.i32).output(0)\n    ).output(0)\n\n    rt = ov_opset.transpose(\n        recurrent_kernel_ov,\n        ov_opset.constant([1, 0], dtype=Type.i32).output(0),\n    ).output(0)\n    r = _reorder_gates(rt, [\"i\", \"f\", \"c\", \"o\"], [\"f\", \"i\", \"c\", \"o\"], axis=0)\n    r = ov_opset.unsqueeze(\n        r, ov_opset.constant([0], dtype=Type.i32).output(0)\n    ).output(0)\n\n    b = _reorder_gates(\n        bias_ov, [\"i\", \"f\", \"c\", \"o\"], [\"f\", \"i\", \"c\", \"o\"], axis=0\n    )\n    b = ov_opset.unsqueeze(\n        b, ov_opset.constant([0], dtype=Type.i32).output(0)\n    ).output(0)\n\n    h0 = ov_opset.unsqueeze(\n        initial_h_ov, ov_opset.constant([1], dtype=Type.i32).output(0)\n    ).output(0)\n    c0 = ov_opset.unsqueeze(\n        initial_c_ov, ov_opset.constant([1], dtype=Type.i32).output(0)\n    ).output(0)\n\n    seq_lens = _seq_lengths(inputs_ov)\n    direction = \"reverse\" if go_backwards else \"forward\"\n\n    lstm_out = ov_opset.lstm_sequence(\n        inputs_ov, h0, c0, seq_lens, w, r, b, hidden_size, direction\n    )\n    dir_axis = ov_opset.constant([1], dtype=Type.i32).output(0)\n    all_outputs = ov_opset.squeeze(lstm_out.output(0), dir_axis).output(0)\n    h_n = ov_opset.squeeze(lstm_out.output(1), dir_axis).output(0)\n    c_n = ov_opset.squeeze(lstm_out.output(2), dir_axis).output(0)\n\n    if go_backwards:\n        shape = ov_opset.shape_of(all_outputs, Type.i32).output(0)\n        time_len = ov_opset.gather(\n            shape,\n            ov_opset.constant(1, dtype=Type.i32).output(0),\n            ov_opset.constant(0, 
dtype=Type.i32).output(0),\n        ).output(0)\n        idx = ov_opset.range(\n            ov_opset.subtract(\n                time_len, ov_opset.constant(1, dtype=Type.i32).output(0)\n            ).output(0),\n            ov_opset.constant(-1, dtype=Type.i32).output(0),\n            ov_opset.constant(-1, dtype=Type.i32).output(0),\n            output_type=Type.i32,\n        ).output(0)\n        all_outputs = ov_opset.gather(\n            all_outputs,\n            idx,\n            ov_opset.constant(1, dtype=Type.i32).output(0),\n        ).output(0)\n\n    return (\n        OpenVINOKerasTensor(h_n),\n        OpenVINOKerasTensor(all_outputs),\n        [OpenVINOKerasTensor(h_n), OpenVINOKerasTensor(c_n)],\n    )\n\n\ndef gru(\n    inputs,\n    initial_state,\n    mask,\n    kernel,\n    recurrent_kernel,\n    bias,\n    activation,\n    recurrent_activation,\n    return_sequences=False,\n    go_backwards=False,\n    unroll=False,\n    reset_after=True,\n):\n    act_name = getattr(activation, \"__name__\", None)\n    rec_act_name = getattr(recurrent_activation, \"__name__\", None)\n    if not (\n        act_name == \"tanh\"\n        and rec_act_name == \"sigmoid\"\n        and not unroll\n        and bias is not None\n        and reset_after\n        and mask is None\n    ):\n        raise NotImplementedError\n\n    inputs_ov = get_ov_output(inputs)\n    initial_state_ov = get_ov_output(initial_state)\n    kernel_ov = get_ov_output(kernel)\n    recurrent_kernel_ov = get_ov_output(recurrent_kernel)\n    bias_ov = get_ov_output(bias)\n\n    hidden_size = recurrent_kernel_ov.get_partial_shape()[0].get_length()\n\n    w = ov_opset.transpose(\n        kernel_ov,\n        ov_opset.constant([1, 0], dtype=Type.i32).output(0),\n    ).output(0)\n    w = ov_opset.unsqueeze(\n        w, ov_opset.constant([0], dtype=Type.i32).output(0)\n    ).output(0)\n\n    r = ov_opset.transpose(\n        recurrent_kernel_ov,\n        ov_opset.constant([1, 0], dtype=Type.i32).output(0),\n    
).output(0)\n    r = ov_opset.unsqueeze(\n        r, ov_opset.constant([0], dtype=Type.i32).output(0)\n    ).output(0)\n\n    # Keras bias [2, 3*units]: row 0 = input biases [b_z, b_r, b_h],\n    # row 1 = recurrent biases [rb_z, rb_r, rb_h].\n    # OV gru_sequence (linear_before_reset=True) wants B [1, 4*units]:\n    # [b_z+rb_z, b_r+rb_r, b_h, rb_h]\n    ax = ov_opset.constant(0, dtype=Type.i32).output(0)\n    b_input = ov_opset.gather(\n        bias_ov, ov_opset.constant(0, dtype=Type.i32).output(0), ax\n    ).output(0)\n    b_recur = ov_opset.gather(\n        bias_ov, ov_opset.constant(1, dtype=Type.i32).output(0), ax\n    ).output(0)\n    split_ax = ov_opset.constant(0, dtype=Type.i32).output(0)\n    b_in_parts = ov_opset.split(b_input, split_ax, 3).outputs()\n    b_rc_parts = ov_opset.split(b_recur, split_ax, 3).outputs()\n    b_z = ov_opset.add(b_in_parts[0], b_rc_parts[0]).output(0)\n    b_r = ov_opset.add(b_in_parts[1], b_rc_parts[1]).output(0)\n    b_h = b_in_parts[2]\n    rb_h = b_rc_parts[2]\n    b = ov_opset.concat([b_z, b_r, b_h, rb_h], axis=0).output(0)\n    b = ov_opset.unsqueeze(\n        b, ov_opset.constant([0], dtype=Type.i32).output(0)\n    ).output(0)\n\n    h0 = ov_opset.unsqueeze(\n        initial_state_ov, ov_opset.constant([1], dtype=Type.i32).output(0)\n    ).output(0)\n\n    seq_lens = _seq_lengths(inputs_ov)\n    direction = \"reverse\" if go_backwards else \"forward\"\n\n    gru_out = ov_opset.gru_sequence(\n        inputs_ov,\n        h0,\n        seq_lens,\n        w,\n        r,\n        b,\n        hidden_size,\n        direction,\n        linear_before_reset=True,\n    )\n    dir_axis = ov_opset.constant([1], dtype=Type.i32).output(0)\n    all_outputs = ov_opset.squeeze(gru_out.output(0), dir_axis).output(0)\n    h_n = ov_opset.squeeze(gru_out.output(1), dir_axis).output(0)\n\n    if go_backwards:\n        # OV direction=\"reverse\" outputs Y in original time order\n        # (Y[0]=fully-accumulated state). 
Keras go_backwards expects\n        # Y[0]=state after first reversed step. Flip time axis to match.\n        shape = ov_opset.shape_of(all_outputs, Type.i32).output(0)\n        time_len = ov_opset.gather(\n            shape,\n            ov_opset.constant(1, dtype=Type.i32).output(0),\n            ov_opset.constant(0, dtype=Type.i32).output(0),\n        ).output(0)\n        idx = ov_opset.range(\n            ov_opset.subtract(\n                time_len, ov_opset.constant(1, dtype=Type.i32).output(0)\n            ).output(0),\n            ov_opset.constant(-1, dtype=Type.i32).output(0),\n            ov_opset.constant(-1, dtype=Type.i32).output(0),\n            output_type=Type.i32,\n        ).output(0)\n        all_outputs = ov_opset.gather(\n            all_outputs,\n            idx,\n            ov_opset.constant(1, dtype=Type.i32).output(0),\n        ).output(0)\n\n    return (\n        OpenVINOKerasTensor(h_n),\n        OpenVINOKerasTensor(all_outputs),\n        [OpenVINOKerasTensor(h_n)],\n    )\n\n\ndef cudnn_ok(*args, **kwargs):\n    return False\n"
  },
  {
    "path": "keras/src/backend/openvino/trainer.py",
    "content": "import numpy as np\nimport openvino as ov\nimport openvino.opset15 as ov_opset\n\nfrom keras.src import backend\nfrom keras.src import callbacks as callbacks_module\nfrom keras.src import tree\nfrom keras.src.backend.openvino.core import OPENVINO_DTYPES\nfrom keras.src.backend.openvino.core import OpenVINOKerasTensor\nfrom keras.src.backend.openvino.core import get_device\nfrom keras.src.trainers import trainer as base_trainer\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.trainers.epoch_iterator import EpochIterator\nfrom keras.src.utils import traceback_utils\nfrom keras.src.utils.python_utils import pythonify_logs\n\n\nclass OpenVINOTrainer(base_trainer.Trainer):\n    def __init__(self):\n        super().__init__()\n        self.test_function = None\n        self.predict_function = None\n        self.ov_compiled_model = None\n        self.ov_device = None\n        self.struct_params = None\n        self.struct_outputs = None\n\n    def _unpack_singleton(self, x):\n        if isinstance(x, (list, tuple)) and len(x) == 1:\n            return x[0]\n        return x\n\n    def test_step(self, data):\n        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)\n        ov_compiled_model = self._get_compiled_model(x)\n        flatten_x = tree.flatten(x)\n        y_pred = ov_compiled_model(flatten_x)\n        y_pred = self._unpack_singleton(\n            tree.pack_sequence_as(self.struct_outputs, y_pred.to_tuple())\n        )\n        loss = self._compute_loss(\n            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False\n        )\n        loss = backend.convert_to_numpy(loss)\n        self._loss_tracker.update_state(\n            loss, sample_weight=tree.flatten(x)[0].shape[0]\n        )\n        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)\n\n    def predict_step(self, data):\n        x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)\n       
 ov_compiled_model = self._get_compiled_model(x)\n        flatten_x = tree.flatten(x)\n        y_pred = ov_compiled_model(flatten_x)\n        # recover structure of the model output\n        y_pred = self._unpack_singleton(\n            tree.pack_sequence_as(self.struct_outputs, y_pred.to_tuple())\n        )\n        return y_pred\n\n    def make_test_function(self, force=False):\n        if self.test_function is not None and not force:\n            return self.test_function\n\n        def one_test_step(data):\n            data = data[0]\n            return self.test_step(data)\n\n        def multi_test_steps(data):\n            for single_step_data in data:\n                logs = one_test_step([single_step_data])\n            return logs\n\n        if self.steps_per_execution > 1:\n            test_step = multi_test_steps\n        else:\n            test_step = one_test_step\n\n        self.test_function = test_step\n\n    def _parameterize_data(self, data):\n        if isinstance(data, (list, tuple)):\n            parametrize_data = []\n            for elem in data:\n                param_elem = self._parameterize_data(elem)\n                parametrize_data.append(param_elem)\n        elif isinstance(data, dict):\n            parametrize_data = dict()\n            for elem_name, elem in data.items():\n                param_elem = self._parameterize_data(elem)\n                parametrize_data[elem_name] = param_elem\n        elif isinstance(data, (np.ndarray, np.generic)):\n            # NumPy arrays and NumPy scalars both carry `dtype` and `shape`;\n            # plain Python ints/floats are handled by the branches below.\n            ov_type = OPENVINO_DTYPES[str(data.dtype)]\n            ov_shape = list(data.shape)\n            param = ov_opset.parameter(shape=ov_shape, dtype=ov_type)\n            parametrize_data = OpenVINOKerasTensor(param.output(0))\n        elif isinstance(data, int):\n            param = ov_opset.parameter(shape=[], dtype=ov.Type.i32)\n            parametrize_data = OpenVINOKerasTensor(param.output(0))\n        elif isinstance(data, float):\n            param = 
ov_opset.parameter(shape=[], dtype=ov.Type.f32)\n            parametrize_data = OpenVINOKerasTensor(param.output(0))\n        else:\n            raise TypeError(f\"Unknown type of input data {type(data)}\")\n        return parametrize_data\n\n    def _get_data_shapes(self, data):\n        shapes = []\n        for x in tree.flatten(data):\n            if isinstance(x, np.ndarray):\n                shapes.append(x.shape)\n            else:\n                shapes.append(None)\n        return shapes\n\n    def _get_compiled_model(self, data):\n        current_shapes = self._get_data_shapes(data)\n        if (\n            self.ov_compiled_model is not None\n            and get_device() == self.ov_device\n            and getattr(self, \"ov_input_shapes\", None) == current_shapes\n        ):\n            return self.ov_compiled_model\n\n        # remove the previously cached compiled model, if it exists\n        del self.ov_compiled_model\n\n        # prepare parameterized input\n        self.struct_params = self._parameterize_data(data)\n        # construct the OpenVINO graph by calling the Keras model\n        self.struct_outputs = self(self.struct_params)\n\n        parameters = []\n        for p in tree.flatten(self.struct_params):\n            parameters.append(p.output.get_node())\n        results = []\n        for r in tree.flatten(self.struct_outputs):\n            results.append(ov_opset.result(r.output))\n\n        # prepare compiled model from scratch\n        ov_model = ov.Model(results=results, parameters=parameters)\n        self.ov_compiled_model = ov.compile_model(ov_model, get_device())\n        self.ov_device = get_device()\n        self.ov_input_shapes = current_shapes\n        return self.ov_compiled_model\n\n    def make_predict_function(self, force=False):\n        if self.predict_function is not None and not force:\n            return self.predict_function\n\n        def one_predict_step(data):\n            data = data[0]\n            return 
self.predict_step(data)\n\n        def multi_predict_steps(data):\n            outputs = one_predict_step(data[:1])\n\n            for single_step_data in data[1:]:\n                step_outputs = one_predict_step([single_step_data])\n                outputs = tree.map_structure(\n                    lambda t1, t2: np.concatenate([t1, t2]),\n                    outputs,\n                    step_outputs,\n                )\n            return outputs\n\n        if self.steps_per_execution > 1:\n            predict_step = multi_predict_steps\n        else:\n            predict_step = one_predict_step\n\n        self.predict_function = predict_step\n\n    def fit(\n        self,\n        x=None,\n        y=None,\n        batch_size=None,\n        epochs=1,\n        verbose=\"auto\",\n        callbacks=None,\n        validation_split=0.0,\n        validation_data=None,\n        shuffle=True,\n        class_weight=None,\n        sample_weight=None,\n        initial_epoch=0,\n        steps_per_epoch=None,\n        validation_steps=None,\n        validation_batch_size=None,\n        validation_freq=1,\n    ):\n        raise NotImplementedError(\n            \"`fit` is not supported with openvino backend\"\n        )\n\n    @traceback_utils.filter_traceback\n    def predict(\n        self, x, batch_size=None, verbose=\"auto\", steps=None, callbacks=None\n    ):\n        # Create an iterator that yields batches of input data.\n        epoch_iterator = EpochIterator(\n            x=x,\n            batch_size=batch_size,\n            steps_per_epoch=steps,\n            shuffle=False,\n            steps_per_execution=self.steps_per_execution,\n        )\n\n        # Container that configures and calls callbacks.\n        if not isinstance(callbacks, callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_history=True,\n                add_progbar=verbose != 0,\n                verbose=verbose,\n  
              epochs=1,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        def append_to_outputs(batch_outputs, outputs):\n            if outputs is None:\n                outputs = tree.map_structure(\n                    lambda batch_output: [batch_output],\n                    batch_outputs,\n                )\n            else:\n                tree.map_structure_up_to(\n                    batch_outputs,\n                    lambda output, batch_output: output.append(batch_output),\n                    outputs,\n                    batch_outputs,\n                )\n            return outputs\n\n        self.make_predict_function()\n        self.stop_predicting = False\n        callbacks.on_predict_begin()\n        outputs = None\n        for begin_step, end_step, data in epoch_iterator.enumerate_epoch():\n            callbacks.on_predict_batch_begin(begin_step)\n            batch_outputs = self.predict_function(data)\n            outputs = append_to_outputs(batch_outputs, outputs)\n            callbacks.on_predict_batch_end(end_step, {\"outputs\": batch_outputs})\n            if self.stop_predicting:\n                break\n        callbacks.on_predict_end()\n        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)\n\n    @traceback_utils.filter_traceback\n    def evaluate(\n        self,\n        x=None,\n        y=None,\n        batch_size=None,\n        verbose=\"auto\",\n        sample_weight=None,\n        steps=None,\n        callbacks=None,\n        return_dict=False,\n        **kwargs,\n    ):\n        use_cached_eval_dataset = kwargs.pop(\"_use_cached_eval_dataset\", False)\n        if kwargs:\n            raise ValueError(f\"Arguments not recognized: {kwargs}\")\n\n        if use_cached_eval_dataset:\n            epoch_iterator = self._eval_epoch_iterator\n        else:\n            epoch_iterator = EpochIterator(\n                x=x,\n                y=y,\n               
 sample_weight=sample_weight,\n                batch_size=batch_size,\n                steps_per_epoch=steps,\n                shuffle=False,\n                steps_per_execution=self.steps_per_execution,\n            )\n\n        if not isinstance(callbacks, callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_progbar=verbose != 0,\n                verbose=verbose,\n                epochs=1,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        self.make_test_function()\n        self.stop_evaluating = False\n        callbacks.on_test_begin()\n        logs = {}\n        self.reset_metrics()\n        for begin_step, end_step, data in epoch_iterator:\n            callbacks.on_test_batch_begin(begin_step)\n            logs = self.test_function(data)\n            callbacks.on_test_batch_end(end_step, logs)\n            if self.stop_evaluating:\n                break\n        logs = pythonify_logs(self._get_metrics_result_or_logs(logs))\n        callbacks.on_test_end(logs)\n\n        if return_dict:\n            return logs\n        return self._flatten_metrics_in_order(logs)\n\n    def train_on_batch(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        class_weight=None,\n        return_dict=False,\n    ):\n        raise NotImplementedError(\n            \"`train_on_batch` is not supported with openvino backend\"\n        )\n\n    def test_on_batch(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        return_dict=False,\n    ):\n        self._assert_compile_called(\"test_on_batch\")\n        self.make_test_function()\n        self.reset_metrics()\n        logs = self.test_function([(x, y, sample_weight)])\n        logs = pythonify_logs(self._get_metrics_result_or_logs(logs))\n        if return_dict:\n            return logs\n        return 
self._flatten_metrics_in_order(logs)\n\n    def predict_on_batch(self, x):\n        self.make_predict_function()\n        batch_outputs = self.predict_function([(x,)])\n        batch_outputs = tree.map_structure(\n            backend.convert_to_numpy, batch_outputs\n        )\n        return batch_outputs\n"
  },
  {
    "path": "keras/src/backend/tensorflow/__init__.py",
    "content": "from keras.src.backend.tensorflow import core\nfrom keras.src.backend.tensorflow import distribution_lib\nfrom keras.src.backend.tensorflow import image\nfrom keras.src.backend.tensorflow import linalg\nfrom keras.src.backend.tensorflow import math\nfrom keras.src.backend.tensorflow import nn\nfrom keras.src.backend.tensorflow import numpy\nfrom keras.src.backend.tensorflow import random\nfrom keras.src.backend.tensorflow import tensorboard\nfrom keras.src.backend.tensorflow.core import IS_THREAD_SAFE\nfrom keras.src.backend.tensorflow.core import SUPPORTS_RAGGED_TENSORS\nfrom keras.src.backend.tensorflow.core import SUPPORTS_SPARSE_TENSORS\nfrom keras.src.backend.tensorflow.core import Variable\nfrom keras.src.backend.tensorflow.core import cast\nfrom keras.src.backend.tensorflow.core import compute_output_spec\nfrom keras.src.backend.tensorflow.core import cond\nfrom keras.src.backend.tensorflow.core import convert_to_numpy\nfrom keras.src.backend.tensorflow.core import convert_to_tensor\nfrom keras.src.backend.tensorflow.core import device_scope\nfrom keras.src.backend.tensorflow.core import is_tensor\nfrom keras.src.backend.tensorflow.core import name_scope\nfrom keras.src.backend.tensorflow.core import random_seed_dtype\nfrom keras.src.backend.tensorflow.core import scatter\nfrom keras.src.backend.tensorflow.core import shape\nfrom keras.src.backend.tensorflow.core import stop_gradient\nfrom keras.src.backend.tensorflow.core import vectorized_map\nfrom keras.src.backend.tensorflow.rnn import cudnn_ok\nfrom keras.src.backend.tensorflow.rnn import gru\nfrom keras.src.backend.tensorflow.rnn import lstm\nfrom keras.src.backend.tensorflow.rnn import rnn\n"
  },
  {
    "path": "keras/src/backend/tensorflow/core.py",
    "content": "import builtins\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.compiler.tf2xla.python.xla import dynamic_update_slice\n\nfrom keras.src import tree\nfrom keras.src.backend.common import KerasVariable\nfrom keras.src.backend.common import global_state\nfrom keras.src.backend.common import is_int_dtype\nfrom keras.src.backend.common import standardize_dtype\nfrom keras.src.backend.common.backend_utils import slice_along_axis\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.backend.common.name_scope import name_scope as base_name_scope\nfrom keras.src.backend.common.stateless_scope import StatelessScope\nfrom keras.src.backend.common.stateless_scope import in_stateless_scope\nfrom keras.src.backend.common.symbolic_scope import SymbolicScope\nfrom keras.src.backend.tensorflow.sparse import sparse_to_dense\nfrom keras.src.utils.naming import auto_name\n\nSUPPORTS_SPARSE_TENSORS = True\nSUPPORTS_RAGGED_TENSORS = True\n# https://github.com/tensorflow/tensorflow/issues/78338\nIS_THREAD_SAFE = False\n\n\nclass Variable(\n    KerasVariable,\n    tf.__internal__.types.Tensor,\n    tf.__internal__.tracking.Trackable,\n):\n    _should_act_as_resource_variable = True\n\n    @property\n    def handle(self):\n        return self.value.handle\n\n    def _initialize(self, value):\n        if isinstance(value, tf.Variable):\n            self._value = value\n        else:\n            self._value = tf.Variable(\n                value,\n                dtype=self._dtype,\n                trainable=self.trainable,\n                name=self.name,\n                aggregation=self._map_aggregation(self.aggregation),\n                synchronization=self._map_synchronization(self.synchronization),\n            )\n\n    def _initialize_with_initializer(self, initializer):\n        self._initialize(lambda: initializer(self._shape, dtype=self._dtype))\n\n    def _deferred_initialize(self):\n        if self._value is not None:\n   
         raise ValueError(f\"Variable {self.path} is already initialized.\")\n\n        if in_stateless_scope():\n            raise ValueError(\n                \"You are attempting to initialize a variable \"\n                \"while in a stateless scope. This is disallowed. \"\n                \"Make sure that all variables are initialized \"\n                \"before you start using your layer/model objects.\"\n            )\n        with tf.init_scope():\n            self._initialize_with_initializer(self._initializer)\n            self._initializer = None\n\n    def _direct_assign(self, value):\n        self._value.assign(tf.cast(value, self._value.dtype))\n\n    def _convert_to_tensor(self, value, dtype=None):\n        return convert_to_tensor(value, dtype=dtype)\n\n    def numpy(self):  # noqa: F811\n        return self.value.numpy()\n\n    @property\n    def shape(self):\n        return tf.TensorShape(super().shape)\n\n    # Overload native accessor.\n    def __tf_tensor__(self, dtype=None, name=None):\n        return tf.convert_to_tensor(self.value, dtype=dtype, name=name)\n\n    # Methods below are for SavedModel support\n    @property\n    def _shared_name(self):\n        return self.value._shared_name\n\n    def _serialize_to_tensors(self):\n        try:\n            return self.value._serialize_to_tensors()\n        except NotImplementedError:\n            return {\"VARIABLE_VALUE\": self.value}\n\n    def _restore_from_tensors(self, restored_tensors):\n        try:\n            return self.value._restore_from_tensors(restored_tensors)\n        except NotImplementedError:\n            self.assign(restored_tensors[\"VARIABLE_VALUE\"])\n            return self.value\n\n    def _copy_trackable_to_cpu(self, object_map):\n        self.value._copy_trackable_to_cpu(object_map)\n        object_map[self] = tf.Variable(object_map[self.value])\n\n    def _export_to_saved_model_graph(\n        self, object_map, tensor_map, options, **kwargs\n    ):\n        
resource_list = self.value._export_to_saved_model_graph(\n            object_map, tensor_map, options, **kwargs\n        )\n        object_map[self] = tf.Variable(object_map[self.value])\n        return resource_list\n\n    def _write_object_proto(self, proto, options):\n        return self.value._write_object_proto(proto, options)\n\n    def _map_aggregation(self, aggregation):\n        mapping = {\n            \"none\": tf.VariableAggregation.NONE,\n            \"sum\": tf.VariableAggregation.SUM,\n            \"mean\": tf.VariableAggregation.MEAN,\n            \"only_first_replica\": tf.VariableAggregation.ONLY_FIRST_REPLICA,\n        }\n        return mapping[aggregation]\n\n    def _map_synchronization(self, synchronization):\n        mapping = {\n            \"none\": tf.VariableSynchronization.NONE,\n            \"on_read\": tf.VariableSynchronization.ON_READ,\n            \"on_write\": tf.VariableSynchronization.ON_WRITE,\n            \"auto\": tf.VariableSynchronization.AUTO,\n        }\n        return mapping[synchronization]\n\n\ndef convert_to_tensor(x, dtype=None, sparse=None, ragged=None):\n    if isinstance(x, tf.SparseTensor) and sparse is not None and not sparse:\n        x = sparse_to_dense(x)\n    if isinstance(x, tf.RaggedTensor) and ragged is not None and not ragged:\n        x = x.to_tensor()\n    if dtype is not None:\n        dtype = standardize_dtype(dtype)\n    if not tf.is_tensor(x):\n        if dtype == \"bool\" or is_int_dtype(dtype):\n            # TensorFlow conversion is stricter than other backends, it does not\n            # allow ints for bools or floats for ints. 
We convert without dtype\n            # and cast instead.\n            x = tf.convert_to_tensor(x)\n            return tf.cast(x, dtype)\n        return tf.convert_to_tensor(x, dtype=dtype)\n    elif dtype is not None and not standardize_dtype(x.dtype) == dtype:\n        if isinstance(x, tf.SparseTensor):\n            x_shape = x.shape\n            x = tf.cast(x, dtype)\n            x.set_shape(x_shape)\n            return x\n        return tf.cast(x, dtype=dtype)\n    return x\n\n\ndef convert_to_numpy(x):\n    if isinstance(x, tf.SparseTensor):\n        x = sparse_to_dense(x)\n    elif isinstance(x, tf.IndexedSlices):\n        x = tf.convert_to_tensor(x)\n    elif isinstance(x, tf.RaggedTensor):\n        x = x.to_tensor()\n    return np.array(x)\n\n\ndef is_tensor(x):\n    return tf.is_tensor(x)\n\n\ndef shape(x):\n    \"\"\"Always return a tuple shape.\n\n    `tf.shape` will return a `tf.Tensor`, which differs from the tuple return\n    type on the torch and jax backends. We write our own method instead which\n    always returns a tuple, with integer values when the shape is known, and\n    tensor values when the shape is unknown (this is tf specific, as dynamic\n    shapes do not apply in other backends).\n    \"\"\"\n    if isinstance(x, KerasTensor):\n        return x.shape\n    if not tf.is_tensor(x):\n        x = tf.convert_to_tensor(x)\n    if x.shape == tf.TensorShape(None):\n        raise ValueError(\n            \"All tensors passed to `ops.shape` must have a statically known \"\n            f\"rank. 
Received: x={x} with unknown rank.\"\n        )\n    shape = x.shape.as_list()\n    dynamic = tf.shape(x)\n    for i in range(len(shape)):\n        if shape[i] is None:\n            try:\n                shape[i] = dynamic[i]\n            except Exception:\n                # With RaggedTensors, accessing a ragged dimension will fail,\n                # so we leave it as None.\n                pass\n    return tuple(shape)\n\n\ndef cast(x, dtype):\n    dtype = standardize_dtype(dtype)\n    if isinstance(x, tf.SparseTensor):\n        x_shape = x.shape\n        x = tf.cast(x, dtype)\n        x.set_shape(x_shape)\n        return x\n    else:\n        return tf.cast(x, dtype=dtype)\n\n\ndef compute_output_spec(fn, *args, **kwargs):\n    with StatelessScope(), SymbolicScope():\n        graph_name = auto_name(\"scratch_graph\")\n        with tf.__internal__.FuncGraph(graph_name).as_default():\n\n            def convert_keras_tensor_to_tf(x):\n                if isinstance(x, KerasTensor):\n                    if x.sparse:\n                        return tf.compat.v1.sparse_placeholder(\n                            shape=x.shape, dtype=x.dtype\n                        )\n                    else:\n                        return tf.compat.v1.placeholder(\n                            shape=x.shape, dtype=x.dtype\n                        )\n                return x\n\n            args, kwargs = tree.map_structure(\n                convert_keras_tensor_to_tf, (args, kwargs)\n            )\n            tf_out = fn(*args, **kwargs)\n\n            def convert_tf_to_keras_tensor(x):\n                if tf.is_tensor(x):\n                    return KerasTensor(\n                        x.shape, x.dtype, sparse=isinstance(x, tf.SparseTensor)\n                    )\n                return x\n\n            output_spec = tree.map_structure(convert_tf_to_keras_tensor, tf_out)\n    return output_spec\n\n\ndef cond(pred, true_fn, false_fn):\n    if isinstance(pred, tf.Variable):\n        return 
tf.cond(pred, true_fn=true_fn, false_fn=false_fn)\n    return tf.__internal__.smart_cond.smart_cond(\n        pred, true_fn=true_fn, false_fn=false_fn\n    )\n\n\ndef vectorized_map(function, elements):\n    return tf.vectorized_map(function, elements)\n\n\ndef map(f, xs):\n    xs = tree.map_structure(convert_to_tensor, xs)\n\n    def get_fn_output_signature(x):\n        out = f(x)\n        return tree.map_structure(tf.TensorSpec.from_tensor, out)\n\n    if tree.is_nested(xs):\n        input = tree.pack_sequence_as(xs, [x[0] for x in tree.flatten(xs)])\n        fn_output_signature = get_fn_output_signature(input)\n        return tf.map_fn(f, xs, fn_output_signature=fn_output_signature)\n    else:\n        fn_output_signature = get_fn_output_signature(xs[0])\n        return tf.map_fn(f, xs, fn_output_signature=fn_output_signature)\n\n\ndef scan(f, init, xs=None, length=None, reverse=False, unroll=1):\n    # We have reimplemented `scan` to match the behavior of `jax.lax.scan`\n    # Ref: tf.scan, jax.lax.scan\n    if not callable(f):\n        raise TypeError(f\"`f` should be a callable. Received: f={f}\")\n    if not isinstance(unroll, bool):\n        if not isinstance(unroll, int) or unroll < 1:\n            raise ValueError(\n                \"`unroll` must be a positive integer or boolean. 
\"\n                f\"Received: unroll={unroll}\"\n            )\n    if xs is None and length is None:\n        raise ValueError(\"Got no `xs` to scan over and `length` not provided.\")\n\n    input_is_sequence = tree.is_nested(xs)\n    output_is_sequence = tree.is_nested(init)\n\n    def pack_input(x):\n        return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]\n\n    def pack_output(x):\n        return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]\n\n    if xs is None:\n        xs_flat = []\n        n = int(length)\n    else:\n        # xs_flat = flatten_input(xs)\n        xs_flat = tree.flatten(xs)\n        xs_flat = [tf.convert_to_tensor(elem) for elem in xs_flat]\n        n = int(length) if length is not None else tf.shape(xs_flat[0])[0]\n\n    # TensorArrays are always flat\n    xs_array = [\n        tf.TensorArray(\n            dtype=x.dtype,\n            size=n,\n            dynamic_size=False,\n            element_shape=x.shape[1:],\n            infer_shape=True,\n        )\n        for x in xs_flat\n    ]\n    xs_array = [x_a.unstack(x) for x_a, x in zip(xs_array, xs_flat)]\n\n    init_flat = tree.flatten(init)\n    carry_flat = [tf.convert_to_tensor(init) for init in init_flat]\n\n    # Store the intermediate values\n    # Note: there is a constraint that the output of `f` must have the same\n    # shape and dtype as carry (`init`).\n    ys_array = [\n        tf.TensorArray(\n            dtype=carry.dtype,\n            size=n,\n            dynamic_size=False,\n            element_shape=carry.shape,\n            infer_shape=True,\n        )\n        for carry in carry_flat\n    ]\n    carry_array = [\n        tf.TensorArray(\n            dtype=carry.dtype,\n            size=1,\n            dynamic_size=False,\n            clear_after_read=False,\n            element_shape=carry.shape,\n            infer_shape=True,\n        )\n        for carry in carry_flat\n    ]\n    carry_array = [\n        carry.write(0, c) for 
(carry, c) in zip(carry_array, carry_flat)\n    ]\n\n    def loop_body(i, carry_array, ys_array):\n        packed_xs = (\n            pack_input([xs.read(i) for xs in xs_array])\n            if len(xs_array) > 0\n            else None\n        )\n        packed_carry = pack_output([carry.read(0) for carry in carry_array])\n\n        carry, ys = f(packed_carry, packed_xs)\n\n        if ys is not None:\n            flat_ys = tree.flatten(ys)\n            ys_array = [ys.write(i, v) for (ys, v) in zip(ys_array, flat_ys)]\n        if carry is not None:\n            flat_carry = tree.flatten(carry)\n            carry_array = [\n                carry.write(0, v) for (carry, v) in zip(carry_array, flat_carry)\n            ]\n        next_i = i + 1 if not reverse else i - 1\n        return (next_i, carry_array, ys_array)\n\n    if isinstance(unroll, bool):\n        unroll = max(n, 1) if unroll else 1\n\n    _, carry_array, ys_array = tf.while_loop(\n        lambda i, _1, _2: i >= 0 if reverse else i < n,\n        loop_body,\n        (n - 1 if reverse else 0, carry_array, ys_array),\n        parallel_iterations=unroll,\n    )\n\n    ys_flat = [ys.stack() for ys in ys_array]\n    carry_flat = [carry.read(0) for carry in carry_array]\n    if xs is not None:\n        n_static = xs_flat[0].get_shape().with_rank_at_least(1)[0]\n        if not isinstance(n_static, int):\n            for x in xs_flat[1:]:\n                n_static.assert_is_compatible_with(\n                    x.get_shape().with_rank_at_least(1)[0]\n                )\n        for r in ys_flat:\n            r.set_shape(tf.TensorShape(n_static).concatenate(r.get_shape()[1:]))\n    return pack_output(carry_flat), pack_output(ys_flat)\n\n\ndef associative_scan(f, elems, reverse=False, axis=0):\n    # Implementation is the same as tfp.math.scan_associative\n    # with additional checks to ensure similar behavior with jax\n    if not callable(f):\n        raise TypeError(f\"`f` should be a callable. 
Received: f={f}\")\n    elems_flat = tree.flatten(elems)\n    elems_flat = [tf.convert_to_tensor(elem) for elem in elems_flat]\n    if reverse:\n        elems_flat = [tf.reverse(elem, [axis]) for elem in elems_flat]\n\n    def _combine(a_flat, b_flat):\n        a = tree.pack_sequence_as(elems, a_flat)\n        b = tree.pack_sequence_as(elems, b_flat)\n        c = f(a, b)\n        c_flat = tree.flatten(c)\n        return c_flat\n\n    def _get_dim(x):\n        return shape(x)[axis]\n\n    # TODO add constant dim check\n    num_elems = _get_dim(elems_flat[0])\n    if not all(_get_dim(elem) == num_elems for elem in elems_flat[1:]):\n        raise ValueError(\n            \"Array inputs to associative_scan must have the same \"\n            \"first dimension. (saw: {})\".format(\n                [tf.shape(elem) for elem in elems_flat]\n            )\n        )\n\n    def _interleave(a, b, axis):\n        # [a b c ...] [d e f ...] -> [a d b e c f ...]\n        num_elems_a = _get_dim(a)\n        num_elems_b = _get_dim(b)\n\n        # Note that interleaving implies rank(a)==rank(b).\n        axis = tf.where(axis >= 0, axis, tf.rank(a) + axis)\n        axis = (\n            int(axis)  # Avoid ndarray values.\n            if tf.get_static_value(axis) is not None\n            else axis\n        )\n\n        def _interleave_with_b(a):\n            return tf.reshape(\n                # Work around lack of support for Tensor axes in\n                # `tf.stack` by using `concat` and `expand_dims` instead.\n                tf.concat(\n                    [\n                        tf.expand_dims(a, axis=axis + 1),\n                        tf.expand_dims(b, axis=axis + 1),\n                    ],\n                    axis=axis + 1,\n                ),\n                tf.concat(\n                    [\n                        a.get_shape()[:axis],\n                        [2 * num_elems_b],\n                        a.get_shape()[axis + 1 :],\n                    ],\n             
       axis=0,\n                ),\n            )\n\n        return tf.cond(\n            tf.equal(num_elems_a, num_elems_b + 1),\n            lambda: tf.concat(\n                [\n                    _interleave_with_b(\n                        slice_along_axis(a, None, -1, axis=axis)\n                    ),\n                    slice_along_axis(a, -1, None, axis=axis),\n                ],\n                axis=axis,\n            ),\n            lambda: _interleave_with_b(a),\n        )\n\n    def _scan(elems):\n        elem_length = _get_dim(elems[0])\n        a = [slice_along_axis(elem, 0, -1, step=2, axis=axis) for elem in elems]\n        b = [\n            slice_along_axis(elem, 1, None, step=2, axis=axis) for elem in elems\n        ]\n        reduced_elems = _combine(a, b)\n\n        def _handle_base_case_elem_length_two():\n            return [\n                tf.concat(\n                    [slice_along_axis(elem, 0, 1, axis=axis), reduced_elem],\n                    axis=axis,\n                )\n                for (reduced_elem, elem) in zip(reduced_elems, elems)\n            ]\n\n        def _handle_base_case_elem_length_three():\n            reduced_reduced_elems = _combine(\n                reduced_elems,\n                [slice_along_axis(elem, 2, 3, axis=axis) for elem in elems],\n            )\n            return [\n                tf.concat(\n                    [\n                        slice_along_axis(elem, 0, 1, axis=axis),\n                        reduced_elem,\n                        reduced_reduced_elem,\n                    ],\n                    axis=axis,\n                )\n                for (reduced_reduced_elem, reduced_elem, elem) in zip(\n                    reduced_reduced_elems, reduced_elems, elems\n                )\n            ]\n\n        at_base_case = tf.logical_or(\n            tf.equal(elem_length, 2), tf.equal(elem_length, 3)\n        )\n\n        def _base_case():\n            return tf.cond(\n                
tf.equal(elem_length, 2),\n                _handle_base_case_elem_length_two,\n                _handle_base_case_elem_length_three,\n            )\n\n        def _recursive_case():\n            odd_elems = _scan(reduced_elems)\n\n            def _even_length_case():\n                return _combine(\n                    [\n                        slice_along_axis(odd_elem, 0, -1, axis=axis)\n                        for odd_elem in odd_elems\n                    ],\n                    [\n                        slice_along_axis(elem, 2, None, 2, axis=axis)\n                        for elem in elems\n                    ],\n                )\n\n            def _odd_length_case():\n                return _combine(\n                    [odd_elem for odd_elem in odd_elems],\n                    [\n                        slice_along_axis(elem, 2, None, 2, axis=axis)\n                        for elem in elems\n                    ],\n                )\n\n            results = tf.cond(\n                tf.equal(elem_length % 2, 0),\n                _even_length_case,\n                _odd_length_case,\n            )\n\n            even_elems = [\n                tf.concat(\n                    [slice_along_axis(elem, 0, 1, axis=axis), result], axis=axis\n                )\n                for (elem, result) in zip(elems, results)\n            ]\n            return list(\n                builtins.map(\n                    lambda a, b: _interleave(a, b, axis=axis),\n                    even_elems,\n                    odd_elems,\n                )\n            )\n\n        return tf.cond(at_base_case, _base_case, _recursive_case)\n\n    scans = _scan(elems_flat)\n    if reverse:\n        scans = [tf.reverse(scanned, [axis]) for scanned in scans]\n\n    return tree.pack_sequence_as(elems, scans)\n\n\ndef scatter(indices, values, shape):\n    return tf.scatter_nd(indices, values, shape)\n\n\ndef scatter_update(inputs, indices, updates, reduction=None):\n    if reduction is 
None:\n        return tf.tensor_scatter_nd_update(inputs, indices, updates)\n    elif reduction == \"add\":\n        return tf.tensor_scatter_nd_add(inputs, indices, updates)\n    elif reduction == \"max\":\n        return tf.tensor_scatter_nd_max(inputs, indices, updates)\n    elif reduction == \"min\":\n        return tf.tensor_scatter_nd_min(inputs, indices, updates)\n    elif reduction == \"mul\":\n        # TensorFlow doesn't have tensor_scatter_nd_mul, implement manually\n        # Use while_loop to handle both scalar and slice updates correctly\n        num_updates = tf.shape(indices)[0]\n\n        def body(i, result):\n            idx = indices[i : i + 1]  # Shape (1, index_depth)\n            current = tf.gather_nd(result, idx)  # Shape (1, *slice_shape)\n            new_value = (\n                current * updates[i]\n            )  # Maintains shape (1, *slice_shape)\n            return i + 1, tf.tensor_scatter_nd_update(result, idx, new_value)\n\n        _, result = tf.while_loop(\n            lambda i, _: i < num_updates,\n            body,\n            [0, inputs],\n        )\n        return result\n    else:\n        raise ValueError(f\"Unsupported reduction: {reduction}\")\n\n\ndef slice(inputs, start_indices, shape):\n    return tf.slice(inputs, start_indices, shape)\n\n\ndef slice_update(inputs, start_indices, updates):\n    return dynamic_update_slice(inputs, updates, start_indices)\n\n\ndef switch(index, branches, *operands):\n    index = convert_to_tensor(index, \"int32\")\n    index = tf.clip_by_value(index, 0, len(branches) - 1)\n\n    # Workaround to deal with python closures. 
More details:\n    # https://github.com/tensorflow/tensorflow/issues/8776#issuecomment-311383887\n    def gen_fn(i):\n        return lambda: branches[i](*operands)\n\n    branch_fns = [gen_fn(i) for i in range(len(branches))]\n    return tf.switch_case(index, branch_fns)\n\n\ndef while_loop(\n    cond,\n    body,\n    loop_vars,\n    maximum_iterations=None,\n):\n    is_tuple = isinstance(loop_vars, (tuple, list))\n    loop_vars = tuple(loop_vars) if is_tuple else (loop_vars,)\n\n    def _body(*args):\n        outputs = body(*args)\n        return tuple(outputs) if is_tuple else (outputs,)\n\n    outputs = tf.while_loop(\n        cond,\n        _body,\n        loop_vars,\n        maximum_iterations=maximum_iterations,\n    )\n    return outputs if is_tuple else outputs[0]\n\n\ndef fori_loop(lower, upper, body_fun, init_val):\n    return tf.while_loop(\n        lambda i, val: i < upper,\n        lambda i, val: (i + 1, body_fun(i, val)),\n        (lower, init_val),\n    )[1]\n\n\ndef stop_gradient(variable):\n    return tf.stop_gradient(variable)\n\n\ndef unstack(x, num=None, axis=0):\n    return tf.unstack(x, num=num, axis=axis)\n\n\ndef random_seed_dtype():\n    # tensorflow random operation only works on int32/int64, not uint32.\n    return \"int64\"\n\n\ndef custom_gradient(fun):\n    return tf.custom_gradient(f=fun)\n\n\ndef remat(f):\n    \"\"\"Implementation of rematerialization.\n\n    Args:\n        f: The function or operation to rematerialize.\n    Returns:\n        A function wrapping f that defines a custom gradient, which\n        recomputes f on the backwards pass of a gradient call.\n    \"\"\"\n    return tf.recompute_grad(f)\n\n\nclass name_scope(base_name_scope):\n    def __init__(self, name, **kwargs):\n        super().__init__(name, **kwargs)\n        self._tf_name_scope = tf.name_scope(name)\n\n    def __enter__(self):\n        name_scope_stack = global_state.get_global_attribute(\n            \"name_scope_stack\", default=[], 
set_to_default=True\n        )\n        if self.deduplicate and name_scope_stack:\n            parent_caller = name_scope_stack[-1].caller\n            parent_name = name_scope_stack[-1].name\n            if (\n                self.caller is not None\n                and self.caller is parent_caller\n                and self.name == parent_name\n            ):\n                return self\n        name_scope_stack.append(self)\n        self._pop_on_exit = True\n        self._tf_name_scope.__enter__()\n        return self\n\n    def __exit__(self, *args, **kwargs):\n        super().__exit__(*args, **kwargs)\n        if self._pop_on_exit:\n            self._tf_name_scope.__exit__(*args, **kwargs)\n\n\ndef device_scope(device_name):\n    return tf.device(device_name)\n"
  },
  {
    "path": "keras/src/backend/tensorflow/distribute_test.py",
    "content": "\"\"\"Tests for tf.distribute related functionality under tf implementation.\"\"\"\n\nimport numpy as np\nimport pytest\nimport tensorflow as tf\nfrom tensorflow.python.eager import context\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src.backend.tensorflow import trainer as tf_trainer\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"tensorflow\",\n    reason=\"The distribute test can only run with TF backend.\",\n)\nclass DistributeTest(testing.TestCase):\n    def setUp(self):\n        super().setUp()\n        # Need at least 2 devices for distribution related tests.\n        cpus = tf.config.list_physical_devices(\"CPU\")\n        context._reset_context()\n        tf.config.set_logical_device_configuration(\n            cpus[0],\n            [\n                tf.config.LogicalDeviceConfiguration(),\n                tf.config.LogicalDeviceConfiguration(),\n            ],\n        )\n\n    def test_variable_creation(self):\n        strategy = tf.distribute.MirroredStrategy([\"CPU:0\", \"CPU:1\"])\n        with strategy.scope():\n            dense = layers.Dense(2)\n            dense.build([4, 2])\n\n        self.assertIsInstance(dense.kernel, backend.Variable)\n        self.assertIsInstance(\n            dense.kernel.value, tf.distribute.DistributedValues\n        )\n        self.assertIn(\"MirroredVariable\", dense.kernel.value.__class__.__name__)\n\n        self.assertIsInstance(dense.kernel, backend.Variable)\n        self.assertIsInstance(dense.bias.value, tf.distribute.DistributedValues)\n        self.assertIn(\"MirroredVariable\", dense.bias.value.__class__.__name__)\n\n    def test_strategy_run(self):\n        strategy = tf.distribute.MirroredStrategy([\"CPU:0\", \"CPU:1\"])\n\n        with strategy.scope():\n            inputs = layers.Input(shape=[4])\n            dense = layers.Dense(2)\n            output = dense(inputs)\n        
    model = models.Functional(inputs, output)\n\n        self.assertIsInstance(dense.kernel, backend.Variable)\n        self.assertIsInstance(\n            dense.kernel.value, tf.distribute.DistributedValues\n        )\n\n        def input_fn(ctx):\n            if ctx.replica_id_in_sync_group == 1:\n                return tf.ones([8, 4])\n            else:\n                return tf.zeros([8, 4])\n\n        distributed_inputs = (\n            strategy.experimental_distribute_values_from_function(input_fn)\n        )\n\n        @tf.function\n        def run_fn(data):\n            return model(data)\n\n        result = strategy.run(run_fn, args=(distributed_inputs,))\n\n        self.assertIsInstance(\n            result, tf.types.experimental.distributed.PerReplica\n        )\n        self.assertLen(result.values, 2)\n        self.assertEqual(result.values[0].shape, [8, 2])\n        self.assertEqual(result.values[1].shape, [8, 2])\n        self.assertNotAllClose(result.values[0], result.values[1])\n        self.assertAllClose(result.values[0], tf.zeros([8, 2]))\n\n    def test_epoch_iterator(self):\n        x = np.random.random((100, 16))\n        y = np.random.random((100, 4))\n        sample_weight = np.random.random((100,))\n        batch_size = 16\n        shuffle = True\n\n        strategy = tf.distribute.MirroredStrategy([\"CPU:0\", \"CPU:1\"])\n\n        epoch_iterator = tf_trainer.TFEpochIterator(\n            x=x,\n            y=y,\n            sample_weight=sample_weight,\n            batch_size=batch_size,\n            shuffle=shuffle,\n            distribute_strategy=strategy,\n        )\n        steps_seen = []\n        for step, _, data_iterator in epoch_iterator:\n            steps_seen.append(step)\n            batch = next(data_iterator)\n            self.assertEqual(len(batch), 3)\n            x, y, sample_weight = batch\n            self.assertTrue(\n                isinstance(x, tf.types.experimental.distributed.PerReplica)\n            )\n        
    # Make sure the local batch size is 8\n            if step < 6:\n                self.assertEqual(x.values[0].shape, [8, 16])\n                self.assertEqual(y.values[0].shape, [8, 4])\n                self.assertEqual(sample_weight.values[0].shape, [8])\n            else:\n                # Last partial batch\n                self.assertEqual(x.values[0].shape, [2, 16])\n                self.assertEqual(y.values[0].shape, [2, 4])\n                self.assertEqual(sample_weight.values[0].shape, [2])\n        self.assertEqual(steps_seen, [0, 1, 2, 3, 4, 5, 6])\n\n    def test_variable_aggregation(self):\n        strategy = tf.distribute.MirroredStrategy([\"CPU:0\", \"CPU:1\"])\n\n        with strategy.scope():\n            x = np.random.random((4, 4))\n            v1 = backend.Variable(x, dtype=\"float32\")\n            self.assertEqual(v1.aggregation, \"none\")\n            self.assertEqual(v1.value.aggregation, tf.VariableAggregation.NONE)\n\n            v2 = backend.Variable(x, dtype=\"float32\", aggregation=\"sum\")\n            self.assertEqual(v2.aggregation, \"sum\")\n            self.assertEqual(v2.value.aggregation, tf.VariableAggregation.SUM)\n\n    def test_variable_synchronization(self):\n        strategy = tf.distribute.MirroredStrategy([\"CPU:0\", \"CPU:1\"])\n\n        with strategy.scope():\n            x = np.random.random((4, 4))\n            v1 = backend.Variable(x, dtype=\"float32\")\n            self.assertEqual(v1.synchronization, \"auto\")\n            # AUTO with MirroredStrategy defaults to ON_WRITE\n            self.assertEqual(\n                v1.value.synchronization, tf.VariableSynchronization.ON_WRITE\n            )\n\n            v2 = backend.Variable(x, dtype=\"float32\", synchronization=\"on_read\")\n            self.assertEqual(v2.synchronization, \"on_read\")\n            self.assertEqual(\n                v2.value.synchronization, tf.VariableSynchronization.ON_READ\n            )\n\n    def test_seed_generator(self):\n      
  strategy = tf.distribute.MirroredStrategy([\"CPU:0\", \"CPU:1\"])\n        with strategy.scope():\n            seed_generator = keras.random.SeedGenerator(42)\n            states = strategy.run(lambda: seed_generator.state.value).values\n            for s in states:\n                self.assertAllClose(keras.ops.convert_to_numpy(s), (42, 0))\n\n    def test_correctness_with_fit_and_regularizer(self):\n        strategy = tf.distribute.MirroredStrategy([\"CPU:0\", \"CPU:1\"])\n\n        batch_size = 12\n        x = keras.ops.ones((batch_size, 1))\n        y = keras.ops.zeros((batch_size, 1))\n\n        # Runs without a strategy to get expected weights.\n        inputs = layers.Input(shape=(1,))\n        layer = layers.Dense(\n            1,\n            use_bias=False,\n            kernel_initializer=keras.initializers.Constant(1),\n            kernel_regularizer=keras.regularizers.L1L2(l1=0.01, l2=0.01),\n        )\n        model = models.Model(inputs, layer(inputs))\n        model.compile(loss=\"mse\", optimizer=\"sgd\")\n        history = model.fit(x, y, batch_size=batch_size, epochs=1)\n        expected_loss = history.history[\"loss\"]\n        expected_weights = keras.ops.convert_to_numpy(layer.kernel)\n\n        # Runs with a mirrored strategy.\n        with strategy.scope():\n            inputs = layers.Input(shape=(1,))\n            layer = layers.Dense(\n                1,\n                use_bias=False,\n                kernel_initializer=keras.initializers.Constant(1),\n                kernel_regularizer=keras.regularizers.L1L2(l1=0.01, l2=0.01),\n            )\n            model = models.Model(inputs, layer(inputs))\n            model.compile(loss=\"mse\", optimizer=\"sgd\")\n            history = model.fit(x, y, batch_size=batch_size, epochs=1)\n            weights = strategy.run(lambda: layer.kernel.value).values\n\n            self.assertAllClose(history.history[\"loss\"], expected_loss)\n            for w in weights:\n                
self.assertAllClose(\n                    keras.ops.convert_to_numpy(w), expected_weights\n                )\n"
  },
  {
    "path": "keras/src/backend/tensorflow/distribution_lib.py",
    "content": "\"\"\"!!!DO NOT USE!!!\n\nDistribution related class for Tensorflow backend.\n\nThis is just a prototype and we might want to unify it\nwith other backends in the future.\n\"\"\"\n\nimport tensorflow as tf\nfrom tensorflow.experimental import dtensor\n\n\ndef list_devices(device_type=None):\n    \"\"\"Return all the available devices based on the device type.\n\n    Note that this should return the global devices in a distributed setting.\n\n    Args:\n        device_type: string of `\"cpu\"`, `\"gpu\"` or `\"tpu\"`. Default to `gpu` or\n        `tpu` if available when device_type is not provided. Otherwise will\n        return the `cpu` devices.\n\n    Return:\n        List of devices that are available for distribute computation.\n    \"\"\"\n    device_type = device_type.upper() if device_type else None\n\n    # DTensor doesn't support getting global devices, even when knowing the\n    # Mesh. Use TF API instead to get global devices. Coordinator service is\n    # enabled by default with DTensor, so that list_logical_devices() returns\n    # a list of global devices. 
More context can be found in b/254911601.\n    tf_devices = tf.config.list_logical_devices(device_type=device_type)\n    cpu_devices = []\n    other_devices = []\n    for device in tf_devices:\n        if device.device_type.lower() == \"cpu\":\n            cpu_devices.append(device)\n        else:\n            other_devices.append(device)\n    if device_type is None:\n        tf_devices = other_devices if len(other_devices) > 0 else cpu_devices\n    return [\n        f\"{device.device_type.lower()}:{device.name.split(':')[-1]}\"\n        for device in tf_devices\n    ]\n\n\ndef distribute_value(value, tensor_layout):\n    # TODO\n    pass\n\n\ndef _to_backend_mesh(device_mesh):\n    \"\"\"Convert the DeviceMesh to Tensorflow backend specific Mesh.\n\n    Args:\n        device_mesh: DeviceMesh instance to convert.\n\n    Returns:\n        A `tf.dtensor.Mesh` instance.\n    \"\"\"\n    mesh_dims = list(zip(device_mesh.axis_names, device_mesh.shape))\n    return dtensor.create_distributed_mesh(\n        mesh_dims=mesh_dims, local_devices=device_mesh.devices.flatten()\n    )\n\n\ndef _to_backend_layout(tensor_layout):\n    \"\"\"Convert the TensorLayout to Tensorflow backend specific Sharding.\n\n    Args:\n        tensor_layout: TensorLayout instance to convert.\n\n    Returns:\n        A `tf.dtensor.Layout` instance.\n    \"\"\"\n    if tensor_layout.device_mesh is None:\n        raise ValueError(\n            \"Cannot create sharding when device mesh is not set for \"\n            \"TensorLayout.\"\n        )\n\n    sharding_specs = [\n        axis if axis else dtensor.UNSHARDED for axis in tensor_layout.axes\n    ]\n    dtensor_mesh = tensor_layout.device_mesh.backend_mesh\n    return dtensor.Layout(sharding_specs=sharding_specs, mesh=dtensor_mesh)\n"
  },
  {
    "path": "keras/src/backend/tensorflow/export.py",
    "content": "import tensorflow as tf\n\nfrom keras.src.export.saved_model_export_archive import SavedModelExportArchive\n\n\nclass TFExportArchive(SavedModelExportArchive):\n    \"\"\"TensorFlow backend implementation of SavedModel export archive.\"\"\"\n\n    def _backend_track_layer(self, layer):\n        # Variables in the lists below are actually part of the trackables\n        # that get saved, because the lists are created in __init__.\n        variables = layer.variables\n        trainable_variables = layer.trainable_variables\n        non_trainable_variables = layer.non_trainable_variables\n        self._tf_trackable.variables += variables\n        self._tf_trackable.trainable_variables += trainable_variables\n        self._tf_trackable.non_trainable_variables += non_trainable_variables\n\n    def _backend_add_endpoint(self, name, fn, input_signature, **kwargs):\n        decorated_fn = tf.function(\n            fn, input_signature=input_signature, autograph=False\n        )\n        return decorated_fn\n"
  },
  {
    "path": "keras/src/backend/tensorflow/image.py",
    "content": "import functools\nimport itertools\nimport operator\n\nimport numpy as np\nimport tensorflow as tf\n\nfrom keras.src import backend\nfrom keras.src.backend.tensorflow.core import convert_to_tensor\nfrom keras.src.backend.tensorflow.numpy import moveaxis\nfrom keras.src.random.seed_generator import draw_seed\n\nRESIZE_INTERPOLATIONS = (\n    \"bilinear\",\n    \"nearest\",\n    \"lanczos3\",\n    \"lanczos5\",\n    \"bicubic\",\n    \"area\",\n)\nAFFINE_TRANSFORM_INTERPOLATIONS = (\n    \"nearest\",\n    \"bilinear\",\n)\nAFFINE_TRANSFORM_FILL_MODES = (\n    \"constant\",\n    \"nearest\",\n    \"wrap\",\n    # \"mirror\", not supported by TF\n    \"reflect\",\n)\nMAP_COORDINATES_FILL_MODES = {\n    \"constant\",\n    \"nearest\",\n    \"wrap\",\n    \"mirror\",\n    \"reflect\",\n}\nSCALE_AND_TRANSLATE_METHODS = {\n    \"linear\",\n    \"bilinear\",\n    \"trilinear\",\n    \"cubic\",\n    \"bicubic\",\n    \"tricubic\",\n    \"lanczos3\",\n    \"lanczos5\",\n}\n\n\ndef rgb_to_grayscale(images, data_format=None):\n    images = convert_to_tensor(images)\n    data_format = backend.standardize_data_format(data_format)\n    channels_axis = -1 if data_format == \"channels_last\" else -3\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). 
Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    # Convert to floats\n    original_dtype = images.dtype\n    compute_dtype = backend.result_type(images.dtype, float)\n    images = tf.cast(images, compute_dtype)\n\n    # Ref: tf.image.rgb_to_grayscale\n    rgb_weights = convert_to_tensor(\n        [0.2989, 0.5870, 0.1140], dtype=images.dtype\n    )\n    images = tf.tensordot(images, rgb_weights, axes=(channels_axis, -1))\n    images = tf.expand_dims(images, axis=channels_axis)\n    return tf.cast(images, original_dtype)\n\n\ndef rgb_to_hsv(images, data_format=None):\n    images = convert_to_tensor(images)\n    dtype = images.dtype\n    data_format = backend.standardize_data_format(data_format)\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if not backend.is_float_dtype(dtype):\n        raise ValueError(\n            \"Invalid images dtype: expected float dtype. 
\"\n            f\"Received: images.dtype={backend.standardize_dtype(dtype)}\"\n        )\n    if data_format == \"channels_first\":\n        if len(images.shape) == 4:\n            images = tf.transpose(images, (0, 2, 3, 1))\n        else:\n            images = tf.transpose(images, (1, 2, 0))\n    images = tf.image.rgb_to_hsv(images)\n    if data_format == \"channels_first\":\n        if len(images.shape) == 4:\n            images = tf.transpose(images, (0, 3, 1, 2))\n        elif len(images.shape) == 3:\n            images = tf.transpose(images, (2, 0, 1))\n    return images\n\n\ndef hsv_to_rgb(images, data_format=None):\n    images = convert_to_tensor(images)\n    dtype = images.dtype\n    data_format = backend.standardize_data_format(data_format)\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if not backend.is_float_dtype(dtype):\n        raise ValueError(\n            \"Invalid images dtype: expected float dtype. 
 \"\n            f\"Received: images.dtype={backend.standardize_dtype(dtype)}\"\n        )\n    if data_format == \"channels_first\":\n        if len(images.shape) == 4:\n            images = tf.transpose(images, (0, 2, 3, 1))\n        else:\n            images = tf.transpose(images, (1, 2, 0))\n    images = tf.image.hsv_to_rgb(images)\n    if data_format == \"channels_first\":\n        if len(images.shape) == 4:\n            images = tf.transpose(images, (0, 3, 1, 2))\n        elif len(images.shape) == 3:\n            images = tf.transpose(images, (2, 0, 1))\n    return images\n\n\ndef resize(\n    images,\n    size,\n    interpolation=\"bilinear\",\n    antialias=False,\n    crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n    fill_mode=\"constant\",\n    fill_value=0.0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if interpolation not in RESIZE_INTERPOLATIONS:\n        raise ValueError(\n            \"Invalid value for argument `interpolation`. Expected one of \"\n            f\"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}\"\n        )\n    if fill_mode != \"constant\":\n        raise ValueError(\n            \"Invalid value for argument `fill_mode`. Only `'constant'` \"\n            f\"is supported. Received: fill_mode={fill_mode}\"\n        )\n    if pad_to_aspect_ratio and crop_to_aspect_ratio:\n        raise ValueError(\n            \"Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` \"\n            \"can be `True`.\"\n        )\n    if len(size) != 2:\n        raise ValueError(\n            \"Argument `size` must be a tuple of two elements \"\n            f\"(height, width). Received: size={size}\"\n        )\n    size = tuple(size)\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). 
Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if data_format == \"channels_first\":\n        if len(images.shape) == 4:\n            images = tf.transpose(images, (0, 2, 3, 1))\n        else:\n            images = tf.transpose(images, (1, 2, 0))\n\n    if crop_to_aspect_ratio:\n        shape = tf.shape(images)\n        height, width = shape[-3], shape[-2]\n        target_height, target_width = size\n        crop_height = tf.cast(\n            tf.cast(width * target_height, \"float32\") / target_width,\n            \"int32\",\n        )\n        crop_height = tf.maximum(tf.minimum(height, crop_height), 1)\n        crop_height = tf.cast(crop_height, \"int32\")\n        crop_width = tf.cast(\n            tf.cast(height * target_width, \"float32\") / target_height,\n            \"int32\",\n        )\n        crop_width = tf.maximum(tf.minimum(width, crop_width), 1)\n        crop_width = tf.cast(crop_width, \"int32\")\n\n        crop_box_hstart = tf.cast(\n            tf.cast(height - crop_height, \"float32\") / 2, \"int32\"\n        )\n        crop_box_wstart = tf.cast(\n            tf.cast(width - crop_width, \"float32\") / 2, \"int32\"\n        )\n        if len(images.shape) == 4:\n            images = images[\n                :,\n                crop_box_hstart : crop_box_hstart + crop_height,\n                crop_box_wstart : crop_box_wstart + crop_width,\n                :,\n            ]\n        else:\n            images = images[\n                crop_box_hstart : crop_box_hstart + crop_height,\n                crop_box_wstart : crop_box_wstart + crop_width,\n                :,\n            ]\n    elif pad_to_aspect_ratio:\n        shape = tf.shape(images)\n        height, width = shape[-3], shape[-2]\n        target_height, target_width = size\n        pad_height = tf.cast(\n            tf.cast(width * target_height, \"float32\") / target_width,\n            \"int32\",\n        )\n        pad_height = 
tf.maximum(height, pad_height)\n        pad_height = tf.cast(pad_height, \"int32\")\n        pad_width = tf.cast(\n            tf.cast(height * target_width, \"float32\") / target_height,\n            \"int32\",\n        )\n        pad_width = tf.maximum(width, pad_width)\n        pad_width = tf.cast(pad_width, \"int32\")\n\n        img_box_hstart = tf.cast(\n            tf.cast(pad_height - height, \"float32\") / 2, \"int32\"\n        )\n        img_box_wstart = tf.cast(\n            tf.cast(pad_width - width, \"float32\") / 2, \"int32\"\n        )\n        if len(images.shape) == 4:\n            batch_size = tf.shape(images)[0]\n            channels = tf.shape(images)[3]\n            padded_img = tf.cond(\n                img_box_hstart > 0,\n                lambda: tf.concat(\n                    [\n                        tf.ones(\n                            (batch_size, img_box_hstart, width, channels),\n                            dtype=images.dtype,\n                        )\n                        * fill_value,\n                        images,\n                        tf.ones(\n                            (batch_size, img_box_hstart, width, channels),\n                            dtype=images.dtype,\n                        )\n                        * fill_value,\n                    ],\n                    axis=1,\n                ),\n                lambda: images,\n            )\n            padded_img = tf.cond(\n                img_box_wstart > 0,\n                lambda: tf.concat(\n                    [\n                        tf.ones(\n                            (batch_size, height, img_box_wstart, channels),\n                            dtype=images.dtype,\n                        )\n                        * fill_value,\n                        padded_img,\n                        tf.ones(\n                            (batch_size, height, img_box_wstart, channels),\n                            dtype=images.dtype,\n                        )\n 
                       * fill_value,\n                    ],\n                    axis=2,\n                ),\n                lambda: padded_img,\n            )\n        else:\n            channels = tf.shape(images)[2]\n            padded_img = tf.cond(\n                img_box_hstart > 0,\n                lambda: tf.concat(\n                    [\n                        tf.ones(\n                            (img_box_hstart, width, channels),\n                            dtype=images.dtype,\n                        )\n                        * fill_value,\n                        images,\n                        tf.ones(\n                            (img_box_hstart, width, channels),\n                            dtype=images.dtype,\n                        )\n                        * fill_value,\n                    ],\n                    axis=0,\n                ),\n                lambda: images,\n            )\n            padded_img = tf.cond(\n                img_box_wstart > 0,\n                lambda: tf.concat(\n                    [\n                        tf.ones(\n                            (height, img_box_wstart, channels),\n                            dtype=images.dtype,\n                        )\n                        * fill_value,\n                        padded_img,\n                        tf.ones(\n                            (height, img_box_wstart, channels),\n                            dtype=images.dtype,\n                        )\n                        * fill_value,\n                    ],\n                    axis=1,\n                ),\n                lambda: padded_img,\n            )\n        images = padded_img\n\n    resized = tf.image.resize(\n        images, size, method=interpolation, antialias=antialias\n    )\n    if data_format == \"channels_first\":\n        if len(images.shape) == 4:\n            resized = tf.transpose(resized, (0, 3, 1, 2))\n        elif len(images.shape) == 3:\n            resized = 
tf.transpose(resized, (2, 0, 1))\n    return resized\n\n\ndef affine_transform(\n    images,\n    transform,\n    interpolation=\"bilinear\",\n    fill_mode=\"constant\",\n    fill_value=0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:\n        raise ValueError(\n            \"Invalid value for argument `interpolation`. Expected one of \"\n            f\"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: \"\n            f\"interpolation={interpolation}\"\n        )\n    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:\n        raise ValueError(\n            \"Invalid value for argument `fill_mode`. Expected one of \"\n            f\"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}\"\n        )\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if len(transform.shape) not in (1, 2):\n        raise ValueError(\n            \"Invalid transform rank: expected rank 1 (single transform) \"\n            \"or rank 2 (batch of transforms). 
Received input with shape: \"\n            f\"transform.shape={transform.shape}\"\n        )\n    # unbatched case\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = tf.expand_dims(images, axis=0)\n        need_squeeze = True\n    if len(transform.shape) == 1:\n        transform = tf.expand_dims(transform, axis=0)\n\n    if data_format == \"channels_first\":\n        images = tf.transpose(images, (0, 2, 3, 1))\n\n    affined = tf.raw_ops.ImageProjectiveTransformV3(\n        images=images,\n        transforms=tf.cast(transform, dtype=tf.float32),\n        output_shape=tf.shape(images)[1:-1],\n        fill_value=fill_value,\n        interpolation=interpolation.upper(),\n        fill_mode=fill_mode.upper(),\n    )\n    affined = tf.ensure_shape(affined, images.shape)\n\n    if data_format == \"channels_first\":\n        affined = tf.transpose(affined, (0, 3, 1, 2))\n    if need_squeeze:\n        affined = tf.squeeze(affined, axis=0)\n    return affined\n\n\ndef perspective_transform(\n    images,\n    start_points,\n    end_points,\n    interpolation=\"bilinear\",\n    fill_value=0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    start_points = convert_to_tensor(start_points, dtype=tf.float32)\n    end_points = convert_to_tensor(end_points, dtype=tf.float32)\n\n    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:\n        raise ValueError(\n            \"Invalid value for argument `interpolation`. Expected one of \"\n            f\"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: \"\n            f\"interpolation={interpolation}\"\n        )\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). 
Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    if start_points.shape.rank not in (2, 3) or start_points.shape[-2:] != (\n        4,\n        2,\n    ):\n        raise ValueError(\n            \"Invalid start_points shape: expected (4,2) for a single image\"\n            f\" or (N,4,2) for a batch. Received shape: {start_points.shape}\"\n        )\n    if end_points.shape.rank not in (2, 3) or end_points.shape[-2:] != (4, 2):\n        raise ValueError(\n            \"Invalid end_points shape: expected (4,2) for a single image\"\n            f\" or (N,4,2) for a batch. Received shape: {end_points.shape}\"\n        )\n    if start_points.shape != end_points.shape:\n        raise ValueError(\n            \"start_points and end_points must have the same shape.\"\n            f\" Received start_points.shape={start_points.shape}, \"\n            f\"end_points.shape={end_points.shape}\"\n        )\n\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = tf.expand_dims(images, axis=0)\n        need_squeeze = True\n\n    if len(start_points.shape) == 2:\n        start_points = tf.expand_dims(start_points, axis=0)\n    if len(end_points.shape) == 2:\n        end_points = tf.expand_dims(end_points, axis=0)\n\n    if data_format == \"channels_first\":\n        images = tf.transpose(images, (0, 2, 3, 1))\n\n    transform = compute_homography_matrix(start_points, end_points)\n    if len(transform.shape) == 1:\n        transform = tf.expand_dims(transform, axis=0)\n\n    output = tf.raw_ops.ImageProjectiveTransformV3(\n        images=images,\n        transforms=tf.cast(transform, dtype=tf.float32),\n        output_shape=tf.shape(images)[1:-1],\n        fill_value=fill_value,\n        interpolation=interpolation.upper(),\n    )\n    output = tf.ensure_shape(output, images.shape)\n\n    if data_format == \"channels_first\":\n        output = tf.transpose(output, (0, 3, 1, 2))\n    if need_squeeze:\n        output = 
tf.squeeze(output, axis=0)\n    return output\n\n\ndef compute_homography_matrix(start_points, end_points):\n    start_x1, start_y1 = start_points[:, 0, 0], start_points[:, 0, 1]\n    start_x2, start_y2 = start_points[:, 1, 0], start_points[:, 1, 1]\n    start_x3, start_y3 = start_points[:, 2, 0], start_points[:, 2, 1]\n    start_x4, start_y4 = start_points[:, 3, 0], start_points[:, 3, 1]\n\n    end_x1, end_y1 = end_points[:, 0, 0], end_points[:, 0, 1]\n    end_x2, end_y2 = end_points[:, 1, 0], end_points[:, 1, 1]\n    end_x3, end_y3 = end_points[:, 2, 0], end_points[:, 2, 1]\n    end_x4, end_y4 = end_points[:, 3, 0], end_points[:, 3, 1]\n\n    coefficient_matrix = tf.stack(\n        [\n            tf.stack(\n                [\n                    end_x1,\n                    end_y1,\n                    tf.ones_like(end_x1),\n                    tf.zeros_like(end_x1),\n                    tf.zeros_like(end_x1),\n                    tf.zeros_like(end_x1),\n                    -start_x1 * end_x1,\n                    -start_x1 * end_y1,\n                ],\n                axis=-1,\n            ),\n            tf.stack(\n                [\n                    tf.zeros_like(end_x1),\n                    tf.zeros_like(end_x1),\n                    tf.zeros_like(end_x1),\n                    end_x1,\n                    end_y1,\n                    tf.ones_like(end_x1),\n                    -start_y1 * end_x1,\n                    -start_y1 * end_y1,\n                ],\n                axis=-1,\n            ),\n            tf.stack(\n                [\n                    end_x2,\n                    end_y2,\n                    tf.ones_like(end_x2),\n                    tf.zeros_like(end_x2),\n                    tf.zeros_like(end_x2),\n                    tf.zeros_like(end_x2),\n                    -start_x2 * end_x2,\n                    -start_x2 * end_y2,\n                ],\n                axis=-1,\n            ),\n            tf.stack(\n                [\n     
               tf.zeros_like(end_x2),\n                    tf.zeros_like(end_x2),\n                    tf.zeros_like(end_x2),\n                    end_x2,\n                    end_y2,\n                    tf.ones_like(end_x2),\n                    -start_y2 * end_x2,\n                    -start_y2 * end_y2,\n                ],\n                axis=-1,\n            ),\n            tf.stack(\n                [\n                    end_x3,\n                    end_y3,\n                    tf.ones_like(end_x3),\n                    tf.zeros_like(end_x3),\n                    tf.zeros_like(end_x3),\n                    tf.zeros_like(end_x3),\n                    -start_x3 * end_x3,\n                    -start_x3 * end_y3,\n                ],\n                axis=-1,\n            ),\n            tf.stack(\n                [\n                    tf.zeros_like(end_x3),\n                    tf.zeros_like(end_x3),\n                    tf.zeros_like(end_x3),\n                    end_x3,\n                    end_y3,\n                    tf.ones_like(end_x3),\n                    -start_y3 * end_x3,\n                    -start_y3 * end_y3,\n                ],\n                axis=-1,\n            ),\n            tf.stack(\n                [\n                    end_x4,\n                    end_y4,\n                    tf.ones_like(end_x4),\n                    tf.zeros_like(end_x4),\n                    tf.zeros_like(end_x4),\n                    tf.zeros_like(end_x4),\n                    -start_x4 * end_x4,\n                    -start_x4 * end_y4,\n                ],\n                axis=-1,\n            ),\n            tf.stack(\n                [\n                    tf.zeros_like(end_x4),\n                    tf.zeros_like(end_x4),\n                    tf.zeros_like(end_x4),\n                    end_x4,\n                    end_y4,\n                    tf.ones_like(end_x4),\n                    -start_y4 * end_x4,\n                    -start_y4 * end_y4,\n              
  ],\n                axis=-1,\n            ),\n        ],\n        axis=1,\n    )\n\n    target_vector = tf.stack(\n        [\n            start_x1,\n            start_y1,\n            start_x2,\n            start_y2,\n            start_x3,\n            start_y3,\n            start_x4,\n            start_y4,\n        ],\n        axis=-1,\n    )\n    target_vector = tf.expand_dims(target_vector, axis=-1)\n\n    homography_matrix = tf.linalg.solve(coefficient_matrix, target_vector)\n    homography_matrix = tf.reshape(homography_matrix, [-1, 8])\n\n    return homography_matrix\n\n\ndef _mirror_index_fixer(index, size):\n    s = size - 1  # Half-wavelength of triangular wave\n    # Scaled, integer-valued version of the triangular wave |x - round(x)|\n    return tf.abs((index + s) % (2 * s) - s)\n\n\ndef _reflect_index_fixer(index, size):\n    return tf.math.floordiv(\n        _mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2\n    )\n\n\ndef _nearest_indices_and_weights(coordinate):\n    coordinate = (\n        coordinate if coordinate.dtype.is_integer else tf.round(coordinate)\n    )\n    index = tf.cast(coordinate, tf.int32)\n    weight = tf.constant(1, coordinate.dtype)\n    return [(index, weight)]\n\n\ndef _linear_indices_and_weights(coordinate):\n    lower = tf.floor(coordinate)\n    upper_weight = coordinate - lower\n    lower_weight = 1 - upper_weight\n    index = tf.cast(lower, tf.int32)\n    return [(index, lower_weight), (index + 1, upper_weight)]\n\n\ndef map_coordinates(\n    inputs, coordinates, order, fill_mode=\"constant\", fill_value=0.0\n):\n    input_arr = convert_to_tensor(inputs)\n    coordinate_arrs = convert_to_tensor(coordinates)\n\n    if coordinate_arrs.shape[0] != len(input_arr.shape):\n        raise ValueError(\n            \"First dim of `coordinates` must be the same as the rank of \"\n            \"`inputs`. 
 \"\n            f\"Received inputs with shape: {input_arr.shape} and coordinate \"\n            f\"leading dim of {coordinate_arrs.shape[0]}\"\n        )\n    if len(coordinate_arrs.shape) < 2:\n        raise ValueError(\n            \"Invalid coordinates rank: expected at least rank 2.\"\n            f\" Received input with shape: {coordinate_arrs.shape}\"\n        )\n    if fill_mode not in MAP_COORDINATES_FILL_MODES:\n        raise ValueError(\n            \"Invalid value for argument `fill_mode`. Expected one of \"\n            f\"{MAP_COORDINATES_FILL_MODES}. Received: \"\n            f\"fill_mode={fill_mode}\"\n        )\n\n    fill_value = convert_to_tensor(fill_value, dtype=input_arr.dtype)\n\n    coordinate_arrs = tf.unstack(coordinate_arrs, axis=0)\n\n    if order == 0:\n        interp_fun = _nearest_indices_and_weights\n    elif order == 1:\n        interp_fun = _linear_indices_and_weights\n    else:\n        raise NotImplementedError(\"map_coordinates currently requires order<=1\")\n\n    def process_coordinates(coords, size):\n        if fill_mode == \"constant\":\n            valid = (coords >= 0) & (coords < size)\n            safe_coords = tf.clip_by_value(coords, 0, size - 1)\n            return safe_coords, valid\n        elif fill_mode == \"nearest\":\n            return tf.clip_by_value(coords, 0, size - 1), tf.ones_like(\n                coords, dtype=tf.bool\n            )\n        elif fill_mode in [\"mirror\", \"reflect\"]:\n            coords = tf.abs(coords)\n            size_2 = size * 2\n            mod = tf.math.mod(coords, size_2)\n            under = mod < size\n            over = ~under\n            # reflect mode is same as mirror for under\n            coords = tf.where(under, mod, size_2 - mod)\n            # for reflect mode, adjust the over case\n            if fill_mode == \"reflect\":\n                coords = tf.where(over, coords - 1, coords)\n            return coords, tf.ones_like(coords, dtype=tf.bool)\n       
 elif fill_mode == \"wrap\":\n            coords = tf.math.mod(coords, size)\n            return coords, tf.ones_like(coords, dtype=tf.bool)\n        else:\n            raise ValueError(f\"Unknown fill_mode: {fill_mode}\")\n\n    valid_1d_interpolations = []\n    for coordinate, size in zip(coordinate_arrs, input_arr.shape):\n        interp_nodes = interp_fun(coordinate)\n        valid_interp = []\n        for index, weight in interp_nodes:\n            safe_index, valid = process_coordinates(index, size)\n            valid_interp.append((safe_index, valid, weight))\n        valid_1d_interpolations.append(valid_interp)\n\n    outputs = []\n    for items in itertools.product(*valid_1d_interpolations):\n        indices, validities, weights = zip(*items)\n        indices = tf.transpose(tf.stack(indices))\n\n        gathered = tf.transpose(tf.gather_nd(input_arr, indices))\n\n        # Cast to computation dtype early to avoid type issues\n        dtype = weights[0].dtype\n        gathered = tf.cast(gathered, dtype)\n\n        if fill_mode == \"constant\":\n            all_valid = tf.reduce_all(validities, axis=0)\n            fill_value_typed = tf.cast(fill_value, dtype)\n            gathered = tf.where(all_valid, gathered, fill_value_typed)\n\n        outputs.append(functools.reduce(operator.mul, weights) * gathered)\n\n    result = functools.reduce(operator.add, outputs)\n\n    if input_arr.dtype.is_integer:\n        result = tf.round(result)\n    return tf.cast(result, input_arr.dtype)\n\n\ndef gaussian_blur(\n    images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None\n):\n    def _create_gaussian_kernel(kernel_size, sigma, num_channels, dtype):\n        def _get_gaussian_kernel1d(size, sigma):\n            x = tf.range(size, dtype=dtype) - (size - 1) / 2\n            kernel1d = tf.exp(-0.5 * (x / sigma) ** 2)\n            return kernel1d / tf.reduce_sum(kernel1d)\n\n        def 
_get_gaussian_kernel2d(size, sigma):\n            size = tf.cast(size, dtype)\n            kernel1d_x = _get_gaussian_kernel1d(size[0], sigma[0])\n            kernel1d_y = _get_gaussian_kernel1d(size[1], sigma[1])\n            return tf.tensordot(kernel1d_y, kernel1d_x, axes=0)\n\n        kernel = _get_gaussian_kernel2d(kernel_size, sigma)\n        kernel = tf.reshape(kernel, (kernel_size[0], kernel_size[1], 1, 1))\n        kernel = tf.tile(kernel, [1, 1, num_channels, 1])\n        kernel = tf.cast(kernel, dtype)\n        return kernel\n\n    data_format = backend.standardize_data_format(data_format)\n    images = convert_to_tensor(images)\n    dtype = backend.standardize_dtype(images.dtype)\n    kernel_size = convert_to_tensor(kernel_size, dtype=dtype)\n    sigma = convert_to_tensor(sigma, dtype=dtype)\n\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = tf.expand_dims(images, axis=0)\n        need_squeeze = True\n\n    if data_format == \"channels_first\":\n        images = tf.transpose(images, (0, 2, 3, 1))\n\n    num_channels = tf.shape(images)[-1]\n    kernel = _create_gaussian_kernel(kernel_size, sigma, num_channels, dtype)\n\n    blurred_images = tf.nn.depthwise_conv2d(\n        images, kernel, strides=[1, 1, 1, 1], padding=\"SAME\"\n    )\n\n    if data_format == \"channels_first\":\n        blurred_images = tf.transpose(blurred_images, (0, 3, 1, 2))\n    if need_squeeze:\n        blurred_images = tf.squeeze(blurred_images, axis=0)\n\n    return blurred_images\n\n\ndef elastic_transform(\n    images,\n    alpha=20.0,\n    sigma=5.0,\n    interpolation=\"bilinear\",\n    fill_mode=\"reflect\",\n    fill_value=0.0,\n    seed=None,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if 
interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS:\n        raise ValueError(\n            \"Invalid value for argument `interpolation`. Expected one of \"\n            f\"{AFFINE_TRANSFORM_INTERPOLATIONS}. Received: \"\n            f\"interpolation={interpolation}\"\n        )\n    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:\n        raise ValueError(\n            \"Invalid value for argument `fill_mode`. Expected one of \"\n            f\"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}\"\n        )\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    images = convert_to_tensor(images)\n    input_dtype = images.dtype\n\n    alpha = convert_to_tensor(alpha, dtype=input_dtype)\n    sigma = convert_to_tensor(sigma, dtype=input_dtype)\n    kernel_factor = convert_to_tensor(sigma, dtype=\"int32\")\n    kernel_size = (6 * kernel_factor | 1, 6 * kernel_factor | 1)\n\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = tf.expand_dims(images, axis=0)\n        need_squeeze = True\n\n    if data_format == \"channels_last\":\n        batch_size, height, width, channels = images.shape\n        channel_axis = -1\n    else:\n        batch_size, channels, height, width = images.shape\n        channel_axis = 1\n\n    seed = draw_seed(seed)\n\n    if batch_size is None:\n        batch_size = 1\n\n    dx = (\n        tf.random.stateless_normal(\n            shape=(batch_size, height, width),\n            mean=0.0,\n            stddev=1.0,\n            dtype=input_dtype,\n            seed=seed,\n        )\n        * sigma\n    )\n    dy = (\n        tf.random.stateless_normal(\n            shape=(batch_size, height, width),\n            mean=0.0,\n            stddev=1.0,\n            dtype=input_dtype,\n            
seed=seed,\n        )\n        * sigma\n    )\n\n    dx = gaussian_blur(\n        tf.expand_dims(dx, axis=channel_axis),\n        kernel_size=kernel_size,\n        sigma=(sigma, sigma),\n        data_format=data_format,\n    )\n    dy = gaussian_blur(\n        tf.expand_dims(dy, axis=channel_axis),\n        kernel_size=kernel_size,\n        sigma=(sigma, sigma),\n        data_format=data_format,\n    )\n\n    dx = tf.squeeze(dx, axis=channel_axis)\n    dy = tf.squeeze(dy, axis=channel_axis)\n\n    x, y = tf.meshgrid(\n        tf.range(width, dtype=input_dtype),\n        tf.range(height, dtype=input_dtype),\n        indexing=\"xy\",\n    )\n    x = tf.expand_dims(x, axis=0)\n    y = tf.expand_dims(y, axis=0)\n\n    distorted_x = x + alpha * dx\n    distorted_y = y + alpha * dy\n\n    channel_outputs = []\n    if data_format == \"channels_last\":\n        for i in range(channels):\n            channel_transformed = tf.stack(\n                [\n                    map_coordinates(\n                        images[b, ..., i],\n                        [distorted_y[b], distorted_x[b]],\n                        order=AFFINE_TRANSFORM_INTERPOLATIONS.index(\n                            interpolation\n                        ),\n                        fill_mode=fill_mode,\n                        fill_value=fill_value,\n                    )\n                    for b in range(batch_size)\n                ],\n                axis=0,\n            )\n            channel_outputs.append(channel_transformed)\n        transformed_images = tf.stack(channel_outputs, axis=-1)\n    else:\n        for i in range(channels):\n            channel_transformed = tf.stack(\n                [\n                    map_coordinates(\n                        images[b, i, ...],\n                        [distorted_y[b], distorted_x[b]],\n                        order=AFFINE_TRANSFORM_INTERPOLATIONS.index(\n                            interpolation\n                        ),\n                      
  fill_mode=fill_mode,\n                        fill_value=fill_value,\n                    )\n                    for b in range(batch_size)\n                ],\n                axis=0,\n            )\n            channel_outputs.append(channel_transformed)\n        transformed_images = tf.stack(channel_outputs, axis=1)\n\n    if need_squeeze:\n        transformed_images = tf.squeeze(transformed_images, axis=0)\n    transformed_images = tf.cast(transformed_images, input_dtype)\n\n    return transformed_images\n\n\ndef _fill_triangle_kernel(x):\n    return tf.maximum(tf.constant(0, dtype=x.dtype), 1 - tf.abs(x))\n\n\ndef _fill_keys_cubic_kernel(x):\n    out = ((1.5 * x - 2.5) * x) * x + 1.0\n    out = tf.where(x >= 1.0, ((-0.5 * x + 2.5) * x - 4.0) * x + 2.0, out)\n    return tf.where(x >= 2.0, 0.0, out)\n\n\ndef _fill_lanczos_kernel(radius, x):\n    y = radius * tf.sin(np.pi * x) * tf.sin(np.pi * x / radius)\n    out = tf.where(\n        x > 1e-3, tf.divide(y, tf.where(x != 0, np.pi**2 * x**2, 1)), 1\n    )\n    return tf.where(x > radius, 0.0, out)\n\n\n_kernels = {\n    \"linear\": _fill_triangle_kernel,\n    \"cubic\": _fill_keys_cubic_kernel,\n    \"lanczos3\": lambda x: _fill_lanczos_kernel(3.0, x),\n    \"lanczos5\": lambda x: _fill_lanczos_kernel(5.0, x),\n}\n\n\ndef _compute_weight_mat(\n    input_size, output_size, scale, translation, kernel, antialias\n):\n    dtype = backend.result_type(scale.dtype, translation.dtype)\n    inv_scale = 1.0 / scale\n    kernel_scale = tf.maximum(inv_scale, 1.0) if antialias else 1.0\n    sample_f = (\n        (tf.range(output_size, dtype=dtype) + 0.5) * inv_scale\n        - translation * inv_scale\n        - 0.5\n    )\n    x = (\n        tf.abs(\n            sample_f[tf.newaxis, :]\n            - tf.range(input_size, dtype=dtype)[:, tf.newaxis]\n        )\n        / kernel_scale\n    )\n    weights = kernel(x)\n    total_weight_sum = tf.reduce_sum(weights, axis=0, keepdims=True)\n    weights = tf.where(\n        
tf.abs(total_weight_sum) > 1000.0 * float(np.finfo(np.float32).eps),\n        tf.divide(\n            weights, tf.where(total_weight_sum != 0, total_weight_sum, 1)\n        ),\n        0,\n    )\n    input_size_minus_0_5 = tf.cast(input_size, dtype=dtype) - 0.5\n    return tf.where(\n        tf.logical_and(sample_f >= -0.5, sample_f <= input_size_minus_0_5)[\n            tf.newaxis, :\n        ],\n        weights,\n        0,\n    )\n\n\ndef _scale_and_translate(\n    x, output_shape, spatial_dims, scale, translation, kernel, antialias\n):\n    x = convert_to_tensor(x)\n    input_shape = tf.shape(x)\n    if len(spatial_dims) == 0:\n        return x\n    if backend.is_int_dtype(x.dtype):\n        output = tf.cast(x, tf.float32)\n        use_rounding = True\n    else:\n        output = tf.identity(x)\n        use_rounding = False\n    for i, d in enumerate(spatial_dims):\n        d = d % x.ndim\n        m, n = input_shape[d], output_shape[d]\n        w = tf.cast(\n            _compute_weight_mat(\n                m, n, scale[i], translation[i], kernel, antialias\n            ),\n            output.dtype,\n        )\n        output = tf.tensordot(output, w, axes=(d, 0))\n        output = moveaxis(output, -1, d)\n    if use_rounding:\n        output = tf.clip_by_value(\n            tf.round(output), tf.reduce_min(x), tf.reduce_max(x)\n        )\n        output = tf.cast(output, x.dtype)\n    return output\n\n\ndef scale_and_translate(\n    images,\n    output_shape,\n    scale,\n    translation,\n    spatial_dims,\n    method,\n    antialias=True,\n):\n    if method not in SCALE_AND_TRANSLATE_METHODS:\n        raise ValueError(\n            \"Invalid value for argument `method`. Expected one of \"\n            f\"{SCALE_AND_TRANSLATE_METHODS}. 
Received: method={method}\"\n        )\n    if method in (\"linear\", \"bilinear\", \"trilinear\", \"triangle\"):\n        method = \"linear\"\n    elif method in (\"cubic\", \"bicubic\", \"tricubic\"):\n        method = \"cubic\"\n\n    images = convert_to_tensor(images)\n    scale = convert_to_tensor(scale)\n    translation = convert_to_tensor(translation)\n    kernel = _kernels[method]\n    dtype = backend.result_type(scale.dtype, translation.dtype)\n    scale = tf.cast(scale, dtype)\n    translation = tf.cast(translation, dtype)\n    return _scale_and_translate(\n        images,\n        output_shape,\n        spatial_dims,\n        scale,\n        translation,\n        kernel,\n        antialias,\n    )\n"
  },
  {
    "path": "keras/src/backend/tensorflow/layer.py",
    "content": "import collections\n\nimport tensorflow as tf\n\nfrom keras.src import tree\nfrom keras.src.backend.tensorflow.trackable import KerasAutoTrackable\nfrom keras.src.utils import tf_utils\nfrom keras.src.utils import tracking\n\n\nclass TFLayer(KerasAutoTrackable):\n    def __init__(self, *args, **kwargs):\n        # Export-related attributes\n        self._saved_model_inputs_spec = None\n        self._saved_model_arg_spec = None\n        self._tracked = []\n\n    def _set_save_spec(self, inputs, args=None, kwargs=None):\n        \"\"\"Defines the save spec so that serialization can trace layer calls.\n\n        The TensorSpecs of the call function `inputs`, `args`, and `kwargs` are\n        saved into a tuple of `([inputs] + args, kwargs)`.\n\n        Args:\n          inputs: possibly nested inputs passed into the call function.\n          args: a list of positional arguments passed into call.\n          kwargs: a dictionary of keyword arguments passed into call.\n        \"\"\"\n        if self._saved_model_inputs_spec is not None:\n            return  # Already set.\n\n        inputs_spec = tree.map_structure(tf_utils.get_tensor_spec, inputs)\n        args_spec = tree.map_structure(tf_utils.get_tensor_spec, args or [])\n        kwargs_spec = {}\n        # Filter out non-tensor arguments from kwargs.\n        for key, kwarg in kwargs.items():\n            flat_kwarg = tree.flatten(kwarg)\n            flat_specs = [tf_utils.get_tensor_spec(x) for x in flat_kwarg]\n            if any(s is None for s in flat_specs):\n                continue\n            kwargs_spec[key] = tree.pack_sequence_as(kwarg, flat_specs)\n\n        self._saved_model_inputs_spec = inputs_spec\n        self._saved_model_arg_spec = (\n            [inputs_spec] + list(args_spec),\n            kwargs_spec,\n        )\n\n    @tf.__internal__.tracking.no_automatic_dependency_tracking\n    def _trackable_children(self, save_type=\"checkpoint\", **kwargs):\n        if save_type == 
\"savedmodel\":\n            # SavedModel needs to ignore the execution functions.\n            train_function = getattr(self, \"train_function\", None)\n            test_function = getattr(self, \"test_function\", None)\n            predict_function = getattr(self, \"predict_function\", None)\n            self.train_function = None\n            self.test_function = None\n            self.predict_function = None\n\n        children = super()._trackable_children(save_type, **kwargs)\n\n        if save_type == \"savedmodel\":\n            self.train_function = train_function\n            self.test_function = test_function\n            self.predict_function = predict_function\n\n            # Convert Keras tracked collections to plain Python structures\n            # without creating TensorFlow trackable dependencies\n            self._convert_tracked_collections(children)\n\n        return children\n\n    def _convert_tracked_collections(self, children):\n        \"\"\"Convert TrackedList/Dict/Set to plain Python structures.\"\"\"\n        for tracked_attr in self._tracked:\n            tracked_item = getattr(self, tracked_attr)\n            if isinstance(tracked_item, tracking.TrackedList):\n                children[tracked_attr] = list(tracked_item)\n            elif isinstance(tracked_item, tracking.TrackedOrderedDict):\n                children[tracked_attr] = collections.OrderedDict(tracked_item)\n            elif isinstance(tracked_item, tracking.TrackedDict):\n                children[tracked_attr] = dict(tracked_item)\n            elif isinstance(tracked_item, tracking.TrackedSet):\n                children[tracked_attr] = list(tracked_item)\n\n    def _get_save_spec(self, dynamic_batch=True):\n        \"\"\"Compatibility shim for TensorFlow saving utilities.\n\n        TensorFlow's SavedModel / TFLite export paths (e.g.,\n        tf.lite.TFLiteConverter.from_keras_model) expect a `_get_save_spec`\n        method on models. 
This method generates TensorSpec objects\n        describing the model's input signature.\n\n        Args:\n            dynamic_batch: whether to set the batch dimension to `None`.\n\n        Returns:\n            A TensorSpec, list or dict mirroring the model inputs, or\n            `None` when specs cannot be inferred.\n        \"\"\"\n        # Lazy import to avoid circular dependency\n        from keras.src.export.export_utils import make_tf_tensor_spec\n\n        # Fall back to building specs from `self.inputs`\n        inputs = getattr(self, \"inputs\", None)\n        if inputs is None:\n            return None\n\n        return tree.map_structure(\n            lambda x: make_tf_tensor_spec(x, dynamic_batch=dynamic_batch),\n            inputs,\n        )\n\n    @property\n    def _default_save_signature(self):\n        \"\"\"For SavedModel support: returns the default serving signature.\"\"\"\n\n        from keras.src.models.functional import Functional\n        from keras.src.models.model import Model\n        from keras.src.models.sequential import Sequential\n\n        if not isinstance(self, Model):\n            return None\n\n        inputs = None\n        if (\n            isinstance(self, Sequential)\n            and getattr(self, \"_functional\", None) is not None\n        ):\n            inputs = self._functional.input\n        elif isinstance(self, Functional):\n            inputs = self.input\n\n        if inputs is not None:\n            input_signature = (\n                tree.map_structure(\n                    lambda x: tf.TensorSpec(x.shape, x.dtype), inputs\n                ),\n            )\n        else:\n            input_signature = tuple(\n                tree.map_shape_structure(\n                    lambda s: tf.TensorSpec(s, self.input_dtype), value\n                )\n                for value in self._build_shapes_dict.values()\n            )\n\n        @tf.function(input_signature=input_signature)\n        def 
serving_default(inputs):\n            return self(inputs)\n\n        return serving_default\n"
  },
  {
    "path": "keras/src/backend/tensorflow/linalg.py",
    "content": "import tensorflow as tf\n\nfrom keras.src.backend import config\nfrom keras.src.backend import standardize_dtype\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.tensorflow.core import cast\nfrom keras.src.backend.tensorflow.core import convert_to_tensor\n\n\ndef cholesky(a, upper=False):\n    out = tf.linalg.cholesky(a)\n    # tf.linalg.cholesky simply returns NaNs for non-positive definite matrices\n    out = tf.debugging.check_numerics(out, \"Cholesky\")\n    if upper:\n        return tf.linalg.adjoint(out)\n    return out\n\n\ndef cholesky_inverse(a, upper=False):\n    identity = tf.eye(num_rows=tf.shape(a)[-1], dtype=a.dtype)\n    inv_chol = tf.linalg.triangular_solve(a, identity, lower=not upper)\n    if upper:\n        a_inv = tf.matmul(inv_chol, inv_chol, transpose_b=True)\n    else:\n        a_inv = tf.matmul(inv_chol, inv_chol, transpose_a=True)\n    return a_inv\n\n\ndef det(a):\n    return tf.linalg.det(a)\n\n\ndef eig(a):\n    return tf.linalg.eig(a)\n\n\ndef eigh(a):\n    return tf.linalg.eigh(a)\n\n\ndef inv(a):\n    return tf.linalg.inv(a)\n\n\ndef lu_factor(a):\n    lu, p = tf.linalg.lu(a)\n    return lu, tf.math.invert_permutation(p)\n\n\ndef norm(x, ord=None, axis=None, keepdims=False):\n    from keras.src.backend.tensorflow.numpy import moveaxis\n\n    x = convert_to_tensor(x)\n    x_shape = x.shape\n    ndim = x_shape.rank\n\n    if axis is None:\n        axis = tuple(range(ndim))\n    elif isinstance(axis, int):\n        axis = (axis,)\n    if any(a < -ndim or a >= ndim for a in axis):\n        raise ValueError(\n            \"All `axis` values must be in the range [-ndim, ndim). 
\"\n            f\"Received inputs with ndim={ndim}, while axis={axis}\"\n        )\n    axis = axis[0] if len(axis) == 1 else axis\n    num_axes = 1 if isinstance(axis, int) else len(axis)\n\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n\n    # Ref: jax.numpy.linalg.norm\n    if num_axes == 1:\n        if ord is None or ord == 2:\n            return tf.sqrt(\n                tf.reduce_sum(x * tf.math.conj(x), axis=axis, keepdims=keepdims)\n            )\n        elif ord == float(\"inf\"):\n            return tf.math.reduce_max(\n                tf.math.abs(x), axis=axis, keepdims=keepdims\n            )\n        elif ord == float(\"-inf\"):\n            return tf.math.reduce_min(\n                tf.math.abs(x), axis=axis, keepdims=keepdims\n            )\n        elif ord == 0:\n            return tf.math.reduce_sum(\n                tf.cast(tf.not_equal(x, 0), dtype=x.dtype),\n                axis=axis,\n                keepdims=keepdims,\n            )\n        elif isinstance(ord, str):\n            raise ValueError(\n                f\"Invalid `ord` argument for vector norm. 
Received: ord={ord}\"\n            )\n        else:\n            ord = convert_to_tensor(ord, dtype=x.dtype)\n            out = tf.math.reduce_sum(\n                tf.pow(tf.math.abs(x), ord), axis=axis, keepdims=keepdims\n            )\n            return tf.pow(out, 1.0 / ord)\n    elif num_axes == 2:\n        row_axis, col_axis = axis[0], axis[1]\n        row_axis = row_axis + ndim if row_axis < 0 else row_axis\n        col_axis = col_axis + ndim if col_axis < 0 else col_axis\n        if ord is None or ord == \"fro\":\n            return tf.sqrt(\n                tf.reduce_sum(x * tf.math.conj(x), axis=axis, keepdims=keepdims)\n            )\n        elif ord == 1:\n            if not keepdims and col_axis > row_axis:\n                col_axis -= 1\n            x = tf.math.reduce_max(\n                tf.reduce_sum(tf.math.abs(x), axis=row_axis, keepdims=keepdims),\n                axis=col_axis,\n                keepdims=keepdims,\n            )\n        elif ord == -1:\n            if not keepdims and col_axis > row_axis:\n                col_axis -= 1\n            x = tf.math.reduce_min(\n                tf.reduce_sum(tf.math.abs(x), axis=row_axis, keepdims=keepdims),\n                axis=col_axis,\n                keepdims=keepdims,\n            )\n        elif ord == float(\"inf\"):\n            if not keepdims and row_axis > col_axis:\n                row_axis -= 1\n            x = tf.math.reduce_max(\n                tf.reduce_sum(tf.math.abs(x), axis=col_axis, keepdims=keepdims),\n                axis=row_axis,\n                keepdims=keepdims,\n            )\n        elif ord == float(\"-inf\"):\n            if not keepdims and row_axis > col_axis:\n                row_axis -= 1\n            x = tf.math.reduce_min(\n                tf.reduce_sum(tf.math.abs(x), axis=col_axis, keepdims=keepdims),\n                axis=row_axis,\n                keepdims=keepdims,\n            )\n        elif ord in (\"nuc\", 2, -2):\n            x = moveaxis(x, axis, 
(-2, -1))\n            if ord == -2:\n                x = tf.math.reduce_min(\n                    tf.linalg.svd(x, compute_uv=False), axis=-1\n                )\n            elif ord == 2:\n                x = tf.math.reduce_max(\n                    tf.linalg.svd(x, compute_uv=False), axis=-1\n                )\n            else:\n                x = tf.math.reduce_sum(\n                    tf.linalg.svd(x, compute_uv=False), axis=-1\n                )\n            if keepdims:\n                x = tf.expand_dims(x, axis[0])\n                x = tf.expand_dims(x, axis[1])\n        else:\n            raise ValueError(\n                f\"Invalid `ord` argument for matrix norm. Received: ord={ord}\"\n            )\n        return x\n    else:\n        raise ValueError(f\"Invalid axis values. Received: axis={axis}\")\n\n\ndef qr(x, mode=\"reduced\"):\n    if mode not in {\"reduced\", \"complete\"}:\n        raise ValueError(\n            \"`mode` argument value not supported. \"\n            \"Expected one of {'reduced', 'complete'}. 
\"\n            f\"Received: mode={mode}\"\n        )\n    if mode == \"reduced\":\n        return tf.linalg.qr(x)\n    return tf.linalg.qr(x, full_matrices=True)\n\n\ndef solve(a, b):\n    # tensorflow.linalg.solve only supports same rank inputs\n    if b.shape.ndims == a.shape.ndims - 1:\n        b = tf.expand_dims(b, axis=-1)\n        return tf.squeeze(tf.linalg.solve(a, b), axis=-1)\n    return tf.linalg.solve(a, b)\n\n\ndef solve_triangular(a, b, lower=False):\n    if b.shape.ndims == a.shape.ndims - 1:\n        b = tf.expand_dims(b, axis=-1)\n        return tf.squeeze(\n            tf.linalg.triangular_solve(a, b, lower=lower), axis=-1\n        )\n    return tf.linalg.triangular_solve(a, b, lower=lower)\n\n\ndef svd(x, full_matrices=True, compute_uv=True):\n    if compute_uv is False:\n        return tf.linalg.svd(x, full_matrices=full_matrices, compute_uv=False)\n    s, u, v = tf.linalg.svd(\n        x, full_matrices=full_matrices, compute_uv=compute_uv\n    )\n    return u, s, tf.linalg.adjoint(v)\n\n\ndef lstsq(a, b, rcond=None):\n    a = convert_to_tensor(a)\n    b = convert_to_tensor(b)\n    if a.shape[0] != b.shape[0]:\n        raise ValueError(\"Leading dimensions of input arrays must match\")\n    b_orig_ndim = b.ndim\n    if b_orig_ndim == 1:\n        b = b[:, None]\n    if a.ndim != 2:\n        raise TypeError(\n            f\"{a.ndim}-dimensional array given. Array must be two-dimensional\"\n        )\n    if b.ndim != 2:\n        raise TypeError(\n            f\"{b.ndim}-dimensional array given. 
\"\n            \"Array must be one or two-dimensional\"\n        )\n    m, n = a.shape\n    dtype = a.dtype\n    eps = tf.experimental.numpy.finfo(dtype).eps\n    if a.shape == ():\n        s = tf.zeros(0, dtype=a.dtype)\n        x = tf.zeros((n, *b.shape[1:]), dtype=a.dtype)\n    else:\n        if rcond is None:\n            rcond = eps * max(n, m)\n        else:\n            rcond = tf.where(rcond < 0, eps, rcond)\n        u, s, vt = svd(a, full_matrices=False)\n        mask = s >= tf.convert_to_tensor(rcond, dtype=s.dtype) * s[0]\n        safe_s = tf.cast(tf.where(mask, s, 1), dtype=a.dtype)\n        s_inv = tf.where(mask, 1 / safe_s, 0)[:, tf.newaxis]\n        u_t_b = tf.matmul(tf.transpose(tf.math.conj(u)), b)\n        x = tf.matmul(tf.transpose(tf.math.conj(vt)), s_inv * u_t_b)\n\n    if b_orig_ndim == 1:\n        x = tf.reshape(x, [-1])\n    return x\n\n\ndef jvp(fun, primals, tangents, has_aux=False):\n    primal_flat = tf.nest.flatten(primals)\n    tangent_flat = tf.nest.flatten(tangents)\n\n    tangent_flat = [\n        tf.cast(t, p.dtype) for t, p in zip(tangent_flat, primal_flat)\n    ]\n\n    with tf.autodiff.ForwardAccumulator(primal_flat, tangent_flat) as acc:\n        if has_aux:\n            primals_out, aux = fun(*primals)\n        else:\n            primals_out = fun(*primals)\n\n        primals_out_flat = tf.nest.flatten(primals_out)\n        tangents_out_flat = [acc.jvp(po) for po in primals_out_flat]\n\n    tangents_out = tf.nest.pack_sequence_as(primals_out, tangents_out_flat)\n\n    if has_aux:\n        return primals_out, tangents_out, aux\n    return primals_out, tangents_out\n"
  },
  {
    "path": "keras/src/backend/tensorflow/math.py",
    "content": "import tensorflow as tf\n\nfrom keras.src.backend import standardize_dtype\nfrom keras.src.backend.tensorflow.core import cast\nfrom keras.src.backend.tensorflow.core import convert_to_tensor\n\n\ndef segment_sum(data, segment_ids, num_segments=None, sorted=False):\n    if sorted:\n        if num_segments is not None:\n            raise ValueError(\n                \"Argument `num_segments` cannot be set when sorted is True \"\n                \"when using the tensorflow backend.\"\n                f\"Received: num_segments={num_segments}, sorted={sorted}.\"\n            )\n        return tf.math.segment_sum(data, segment_ids)\n    else:\n        if num_segments is None:\n            unique_segment_ids, _ = tf.unique(segment_ids)\n            num_segments = tf.shape(unique_segment_ids)[0]\n        return tf.math.unsorted_segment_sum(data, segment_ids, num_segments)\n\n\ndef segment_max(data, segment_ids, num_segments=None, sorted=False):\n    if sorted:\n        if num_segments is not None:\n            raise ValueError(\n                \"Argument `num_segments` cannot be set when sorted is True \"\n                \"when using the tensorflow backend.\"\n                f\"Received: num_segments={num_segments}, sorted={sorted}.\"\n            )\n        return tf.math.segment_max(data, segment_ids)\n    else:\n        if num_segments is None:\n            unique_segment_ids, _ = tf.unique(segment_ids)\n            num_segments = tf.shape(unique_segment_ids)[0]\n        return tf.math.unsorted_segment_max(data, segment_ids, num_segments)\n\n\ndef top_k(x, k, sorted=True):\n    return tf.math.top_k(x, k, sorted=sorted)\n\n\ndef in_top_k(targets, predictions, k):\n    return tf.math.in_top_k(targets, predictions, k)\n\n\ndef logsumexp(x, axis=None, keepdims=False):\n    return tf.math.reduce_logsumexp(x, axis=axis, keepdims=keepdims)\n\n\ndef qr(x, mode=\"reduced\"):\n    if mode not in {\"reduced\", \"complete\"}:\n        raise ValueError(\n         
   \"`mode` argument value not supported. \"\n            \"Expected one of {'reduced', 'complete'}. \"\n            f\"Received: mode={mode}\"\n        )\n    if mode == \"reduced\":\n        return tf.linalg.qr(x)\n    return tf.linalg.qr(x, full_matrices=True)\n\n\ndef extract_sequences(x, sequence_length, sequence_stride):\n    return tf.signal.frame(\n        x,\n        frame_length=sequence_length,\n        frame_step=sequence_stride,\n        axis=-1,\n        pad_end=False,\n    )\n\n\ndef _get_complex_tensor_from_tuple(x):\n    if not isinstance(x, (tuple, list)) or len(x) != 2:\n        raise ValueError(\n            \"Input `x` should be a tuple of two tensors - real and imaginary.\"\n            f\"Received: x={x}\"\n        )\n    # `convert_to_tensor` does not support passing complex tensors. We separate\n    # the input out into real and imaginary and convert them separately.\n    real, imag = x\n    real = convert_to_tensor(real)\n    imag = convert_to_tensor(imag)\n    # Check shapes.\n    if real.shape != imag.shape:\n        raise ValueError(\n            \"Input `x` should be a tuple of two tensors - real and imaginary.\"\n            \"Both the real and imaginary parts should have the same shape. 
\"\n            f\"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}\"\n        )\n    # Ensure dtype is float.\n    if not real.dtype.is_floating or not imag.dtype.is_floating:\n        raise ValueError(\n            \"At least one tensor in input `x` is not of type float.\"\n            f\"Received: x={x}.\"\n        )\n    complex_input = tf.dtypes.complex(real, imag)\n    return complex_input\n\n\ndef fft(x):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    complex_output = tf.signal.fft(complex_input)\n    return tf.math.real(complex_output), tf.math.imag(complex_output)\n\n\ndef fft2(x):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    complex_output = tf.signal.fft2d(complex_input)\n    return tf.math.real(complex_output), tf.math.imag(complex_output)\n\n\ndef ifft2(x):\n    real, imag = x\n    h = cast(tf.shape(real)[-2], real.dtype)\n    w = cast(tf.shape(real)[-1], real.dtype)\n    real_conj, imag_conj = real, -imag\n    fft_real, fft_imag = fft2((real_conj, imag_conj))\n    return fft_real / (h * w), -fft_imag / (h * w)\n\n\ndef rfft(x, fft_length=None):\n    if fft_length is not None:\n        fft_length = [fft_length]\n    complex_output = tf.signal.rfft(x, fft_length=fft_length)\n    return tf.math.real(complex_output), tf.math.imag(complex_output)\n\n\ndef irfft(x, fft_length=None):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    if fft_length is not None:\n        fft_length = [fft_length]\n    return tf.signal.irfft(complex_input, fft_length)\n\n\ndef stft(\n    x, sequence_length, sequence_stride, fft_length, window=\"hann\", center=True\n):\n    if standardize_dtype(x.dtype) not in {\"float32\", \"float64\"}:\n        raise TypeError(\n            \"Invalid input type. Expected `float32` or `float64`. \"\n            f\"Received: input type={x.dtype}\"\n        )\n    if fft_length < sequence_length:\n        raise ValueError(\n            \"`fft_length` must equal or larger than `sequence_length`. 
\"\n            f\"Received: sequence_length={sequence_length}, \"\n            f\"fft_length={fft_length}\"\n        )\n    if isinstance(window, str):\n        if window not in {\"hann\", \"hamming\"}:\n            raise ValueError(\n                \"If a string is passed to `window`, it must be one of \"\n                f'`\"hann\"`, `\"hamming\"`. Received: window={window}'\n            )\n    x = convert_to_tensor(x)\n\n    if center:\n        pad_width = [(0, 0) for _ in range(len(x.shape))]\n        pad_width[-1] = (fft_length // 2, fft_length // 2)\n        x = tf.pad(x, pad_width, mode=\"reflect\")\n\n    l_pad = (fft_length - sequence_length) // 2\n    r_pad = fft_length - sequence_length - l_pad\n\n    if window is not None:\n        if isinstance(window, str):\n            if window == \"hann\":\n                win_array = tf.signal.hann_window(\n                    sequence_length, periodic=True, dtype=x.dtype\n                )\n            else:\n                win_array = tf.signal.hamming_window(\n                    sequence_length, periodic=True, dtype=x.dtype\n                )\n        else:\n            win_array = convert_to_tensor(window, dtype=x.dtype)\n        if len(win_array.shape) != 1 or win_array.shape[-1] != sequence_length:\n            raise ValueError(\n                \"The shape of `window` must be equal to [sequence_length].\"\n                f\"Received: window shape={win_array.shape}\"\n            )\n        win_array = tf.pad(win_array, [[l_pad, r_pad]])\n\n        def win(frame_step, dtype):\n            return win_array\n\n    else:\n        win = None\n\n    result = tf.signal.stft(\n        x,\n        frame_length=(sequence_length + l_pad + r_pad),\n        frame_step=sequence_stride,\n        fft_length=fft_length,\n        window_fn=win,\n    )\n    return tf.math.real(result), tf.math.imag(result)\n\n\ndef istft(\n    x,\n    sequence_length,\n    sequence_stride,\n    fft_length,\n    length=None,\n    
window=\"hann\",\n    center=True,\n):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    dtype = tf.math.real(complex_input).dtype\n\n    expected_output_len = fft_length + sequence_stride * (\n        tf.shape(complex_input)[-2] - 1\n    )\n    l_pad = (fft_length - sequence_length) // 2\n    r_pad = fft_length - sequence_length - l_pad\n\n    if window is not None:\n        if isinstance(window, str):\n            if window == \"hann\":\n                win_array = tf.signal.hann_window(\n                    sequence_length, periodic=True, dtype=dtype\n                )\n            else:\n                win_array = tf.signal.hamming_window(\n                    sequence_length, periodic=True, dtype=dtype\n                )\n        else:\n            win_array = convert_to_tensor(window, dtype=dtype)\n        if len(win_array.shape) != 1 or win_array.shape[-1] != sequence_length:\n            raise ValueError(\n                \"The shape of `window` must be equal to [sequence_length].\"\n                f\"Received: window shape={win_array.shape}\"\n            )\n        win_array = tf.pad(win_array, [[l_pad, r_pad]])\n        win = tf.signal.inverse_stft_window_fn(\n            sequence_stride, lambda frame_step, dtype: win_array\n        )\n    else:\n        win = None\n\n    x = tf.signal.inverse_stft(\n        complex_input,\n        frame_length=(sequence_length + l_pad + r_pad),\n        frame_step=sequence_stride,\n        fft_length=fft_length,\n        window_fn=win,\n    )\n\n    start = 0 if center is False else fft_length // 2\n    if length is not None:\n        end = start + length\n    elif center is True:\n        end = -(fft_length // 2)\n    else:\n        end = expected_output_len\n    return x[..., start:end]\n\n\ndef rsqrt(x):\n    return tf.math.rsqrt(x)\n\n\ndef erf(x):\n    return tf.math.erf(x)\n\n\ndef erfinv(x):\n    return tf.math.erfinv(x)\n\n\ndef logdet(x):\n    x = convert_to_tensor(x)\n    return 
tf.linalg.logdet(x)\n"
  },
  {
    "path": "keras/src/backend/tensorflow/name_scope_test.py",
    "content": "import tensorflow as tf\n\nfrom keras.src.backend.tensorflow.core import name_scope\nfrom keras.src.testing import TestCase\n\n\nclass TFNameScopeTest(TestCase):\n    def test_stacking(self):\n        self.assertEqual(tf.Variable(0, name=\"x\").name, \"x:0\")\n        with name_scope(\"outer\") as outer:\n            self.assertEqual(outer.name, \"outer\")\n            self.assertEqual(tf.Variable(0, name=\"x\").name, \"outer/x:0\")\n            with name_scope(\"middle\") as middle:\n                self.assertEqual(middle.name, \"middle\")\n                self.assertEqual(\n                    tf.Variable(0, name=\"x\").name, \"outer/middle/x:0\"\n                )\n                with name_scope(\"inner\") as inner:\n                    self.assertEqual(inner.name, \"inner\")\n                    self.assertEqual(\n                        tf.Variable(0, name=\"x\").name, \"outer/middle/inner/x:0\"\n                    )\n                self.assertEqual(\n                    tf.Variable(0, name=\"x\").name, \"outer/middle/x:0\"\n                )\n            self.assertEqual(tf.Variable(0, name=\"x\").name, \"outer/x:0\")\n        self.assertEqual(tf.Variable(0, name=\"x\").name, \"x:0\")\n\n    def test_deduplicate(self):\n        self.assertEqual(tf.Variable(0, name=\"x\").name, \"x:0\")\n        with name_scope(\"name\", caller=1):\n            with name_scope(\"name\", caller=1):\n                self.assertEqual(tf.Variable(0, name=\"x\").name, \"name/x:0\")\n        self.assertEqual(tf.Variable(0, name=\"x\").name, \"x:0\")\n        with name_scope(\"name\"):\n            with name_scope(\"name\"):\n                self.assertEqual(tf.Variable(0, name=\"x\").name, \"name/name/x:0\")\n"
  },
  {
    "path": "keras/src/backend/tensorflow/nn.py",
    "content": "import math\nimport warnings\n\nimport tensorflow as tf\n\nfrom keras.src import backend\nfrom keras.src.backend.common.backend_utils import (\n    compute_adaptive_pooling_window_sizes,\n)\nfrom keras.src.backend.common.backend_utils import (\n    compute_conv_transpose_output_shape,\n)\nfrom keras.src.backend.tensorflow.core import cast\nfrom keras.src.backend.tensorflow.core import convert_to_tensor\n\n\ndef relu(x):\n    return tf.nn.relu(x)\n\n\ndef relu6(x):\n    return tf.nn.relu6(x)\n\n\ndef sigmoid(x):\n    logits = x\n    output = tf.nn.sigmoid(x)\n    output._keras_logits = logits\n    return output\n\n\ndef sparse_sigmoid(x):\n    x = convert_to_tensor(x)\n    return tf.where(\n        x <= -1,\n        tf.constant(0.0, dtype=x.dtype),\n        tf.where(x >= 1, tf.constant(1.0, dtype=x.dtype), 0.5 * (x + 1)),\n    )\n\n\ndef tanh(x):\n    return tf.nn.tanh(x)\n\n\ndef tanh_shrink(x):\n    return x - tf.math.tanh(x)\n\n\ndef softplus(x):\n    return tf.math.softplus(x)\n\n\ndef softsign(x):\n    return tf.nn.softsign(x)\n\n\ndef soft_shrink(x, threshold=0.5):\n    return tf.where(\n        x > threshold,\n        x - threshold,\n        tf.where(x < -threshold, x + threshold, tf.zeros_like(x)),\n    )\n\n\ndef sparse_plus(x):\n    return tf.where(\n        x <= -1,\n        tf.zeros_like(x),\n        tf.where(x < 1, (1 / 4) * tf.pow(x + 1, 2), x),\n    )\n\n\ndef silu(x):\n    return tf.nn.silu(x)\n\n\ndef squareplus(x, b=4):\n    x = convert_to_tensor(x)\n    b = convert_to_tensor(b, dtype=x.dtype)\n    y = x + tf.sqrt(tf.square(x) + b)\n    return y / 2\n\n\ndef log_sigmoid(x):\n    return tf.math.log_sigmoid(x)\n\n\ndef leaky_relu(x, negative_slope=0.2):\n    return tf.nn.leaky_relu(x, alpha=negative_slope)\n\n\ndef hard_sigmoid(x):\n    x = convert_to_tensor(x)\n    return relu6(x + tf.constant(3.0, x.dtype)) / tf.constant(6.0, x.dtype)\n\n\ndef hard_silu(x):\n    return x * hard_sigmoid(x)\n\n\ndef elu(x, alpha=1.0):\n    res = 
tf.nn.elu(x)\n    if alpha == 1:\n        return res\n    else:\n        return tf.where(x > 0, res, alpha * res)\n\n\ndef selu(x):\n    return tf.nn.selu(x)\n\n\ndef gelu(x, approximate=True):\n    x = convert_to_tensor(x)\n    return tf.nn.gelu(x, approximate=approximate)\n\n\ndef celu(x, alpha=1.0):\n    return tf.maximum(x, 0.0) + alpha * tf.math.expm1(\n        tf.minimum(x, 0.0) / alpha\n    )\n\n\ndef glu(x, axis=-1):\n    if x.shape[axis] % 2 != 0:\n        raise ValueError(\n            \"axis size must be divisible by 2. \"\n            f\"Received: x.shape={x.shape} with axis={axis}\"\n        )\n    x1, x2 = tf.split(x, num_or_size_splits=2, axis=axis)\n    return x1 * tf.sigmoid(x2)\n\n\ndef hard_tanh(x):\n    return tf.clip_by_value(x, clip_value_min=-1.0, clip_value_max=1.0)\n\n\ndef hard_shrink(x, threshold=0.5):\n    return tf.where(tf.abs(x) > threshold, x, tf.zeros_like(x))\n\n\ndef threshold(x, threshold, default_value):\n    return tf.where(x > threshold, x, default_value)\n\n\ndef softmax(x, axis=-1):\n    logits = x\n    if axis is None:\n        # Unlike numpy, tf will handle axis=None as axis=-1.\n        # We need this workaround for the reduction on every dim.\n        output = tf.reshape(x, [-1])\n        output = tf.nn.softmax(output, axis=-1)\n        output = tf.reshape(output, tf.shape(x))\n    else:\n        output = tf.nn.softmax(x, axis=axis)\n    output._keras_logits = logits\n    return output\n\n\ndef log_softmax(x, axis=-1):\n    if axis is None:\n        # Unlike numpy, tf will handle axis=None as axis=-1.\n        # We need this workaround for the reduction on every dim.\n        output = tf.reshape(x, [-1])\n        output = tf.nn.log_softmax(output, axis=-1)\n        return tf.reshape(output, tf.shape(x))\n    return tf.nn.log_softmax(x, axis=axis)\n\n\ndef sparsemax(x, axis=-1):\n    # Sort logits along the specified axis in descending order\n    logits = convert_to_tensor(x)\n    logits_sorted = tf.sort(logits, 
direction=\"DESCENDING\", axis=axis)\n    logits_cumsum = tf.cumsum(logits_sorted, axis=axis)\n    r = tf.range(1, tf.shape(logits)[axis] + 1, dtype=logits.dtype)\n    r_shape = [1] * len(logits.shape)\n    r_shape[axis] = -1  # Broadcast to match the target axis\n    r = tf.reshape(r, r_shape)  # Reshape for broadcasting\n    support = logits_sorted - (logits_cumsum - 1) / r > 0\n    # Find the threshold\n    logits_cumsum_safe = tf.where(support, logits_cumsum, 0.0)\n    k = tf.reduce_sum(tf.cast(support, logits.dtype), axis=axis, keepdims=True)\n    tau = (tf.reduce_sum(logits_cumsum_safe, axis=axis, keepdims=True) - 1) / k\n    output = tf.maximum(logits - tau, 0.0)\n    return output\n\n\ndef _transpose_spatial_inputs(inputs):\n    num_spatial_dims = len(inputs.shape) - 2\n    # Tensorflow pooling does not support `channels_first` format, so\n    # we need to transpose to `channels_last` format.\n    if num_spatial_dims == 1:\n        inputs = tf.transpose(inputs, (0, 2, 1))\n    elif num_spatial_dims == 2:\n        inputs = tf.transpose(inputs, (0, 2, 3, 1))\n    elif num_spatial_dims == 3:\n        inputs = tf.transpose(inputs, (0, 2, 3, 4, 1))\n    else:\n        raise ValueError(\n            \"Pooling inputs' rank must be 3, 4 or 5, corresponding to 1D, 2D \"\n            f\"and 3D inputs. 
But received shape: {inputs.shape}.\"\n        )\n    return inputs\n\n\ndef _transpose_spatial_outputs(outputs):\n    # Undo the transpose in `_transpose_spatial_inputs`.\n    num_spatial_dims = len(outputs.shape) - 2\n    if num_spatial_dims == 1:\n        outputs = tf.transpose(outputs, (0, 2, 1))\n    elif num_spatial_dims == 2:\n        outputs = tf.transpose(outputs, (0, 3, 1, 2))\n    elif num_spatial_dims == 3:\n        outputs = tf.transpose(outputs, (0, 4, 1, 2, 3))\n    return outputs\n\n\ndef max_pool(\n    inputs,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    strides = pool_size if strides is None else strides\n    padding = padding.upper()\n    tf_data_format = _convert_data_format(\"channels_last\", len(inputs.shape))\n    if data_format == \"channels_first\":\n        # Tensorflow pooling does not support `channels_first` format, so\n        # we need to transpose to `channels_last` format.\n        inputs = _transpose_spatial_inputs(inputs)\n\n    outputs = tf.nn.max_pool(\n        inputs,\n        pool_size,\n        strides,\n        padding,\n        tf_data_format,\n    )\n    if data_format == \"channels_first\":\n        outputs = _transpose_spatial_outputs(outputs)\n    return outputs\n\n\ndef average_pool(\n    inputs,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    strides = pool_size if strides is None else strides\n    padding = padding.upper()\n    tf_data_format = _convert_data_format(\"channels_last\", len(inputs.shape))\n    if data_format == \"channels_first\":\n        # Tensorflow pooling does not support `channels_first` format, so\n        # we need to transpose to `channels_last` format.\n        inputs = _transpose_spatial_inputs(inputs)\n\n    outputs = tf.nn.avg_pool(\n        inputs,\n        pool_size,\n        
strides,\n        padding,\n        tf_data_format,\n    )\n    if data_format == \"channels_first\":\n        outputs = _transpose_spatial_outputs(outputs)\n    return outputs\n\n\ndef _compute_static_gather_indices(\n    input_dim, output_size, small_window, big_window\n):\n    \"\"\"Compute gather indices for Two-Pool Gather method (corrected).\"\"\"\n    window_starts = tf.cast(\n        tf.floor(\n            tf.cast(tf.range(output_size), tf.float32)\n            * tf.cast(input_dim, tf.float32)\n            / tf.cast(output_size, tf.float32)\n        ),\n        tf.int32,\n    )\n    window_ends = tf.cast(\n        tf.math.ceil(\n            tf.cast(tf.range(1, output_size + 1), tf.float32)\n            * tf.cast(input_dim, tf.float32)\n            / tf.cast(output_size, tf.float32)\n        ),\n        tf.int32,\n    )\n\n    window_ends = tf.minimum(window_ends, input_dim)\n    window_starts = tf.minimum(window_starts, input_dim - 1)\n\n    window_sizes = window_ends - window_starts\n    is_big_window = tf.equal(window_sizes, big_window)\n\n    small_pool_len = max(1, input_dim - small_window + 1)\n\n    small_indices = window_starts\n    big_indices = window_starts + small_pool_len\n\n    gather_indices = tf.where(is_big_window, big_indices, small_indices)\n    return tf.cast(gather_indices, tf.int32)\n\n\ndef _adaptive_average_pool1d(inputs, output_size, data_format=\"channels_first\"):\n    if isinstance(output_size, int):\n        output_size = (output_size,)\n    if data_format == \"channels_first\":\n        inputs = tf.transpose(inputs, (0, 2, 1))\n\n    static_shape = inputs.shape.as_list()\n    l_static = static_shape[1]\n    out_l = output_size[0]\n\n    if l_static is None:\n        raise ValueError(\n            \"Input length must be statically known for adaptive pooling\"\n        )\n\n    small_l, big_l = compute_adaptive_pooling_window_sizes(l_static, out_l)\n    gather_l = _compute_static_gather_indices(l_static, out_l, small_l, big_l)\n\n 
   small_pool_l = tf.nn.pool(\n        inputs,\n        window_shape=(small_l,),\n        pooling_type=\"AVG\",\n        strides=(1,),\n        padding=\"VALID\",\n        data_format=\"NWC\",\n    )\n    big_pool_l = tf.nn.pool(\n        inputs,\n        window_shape=(big_l,),\n        pooling_type=\"AVG\",\n        strides=(1,),\n        padding=\"VALID\",\n        data_format=\"NWC\",\n    )\n\n    combined_l = tf.concat([small_pool_l, big_pool_l], axis=1)\n    pooled_l = tf.gather(combined_l, gather_l, axis=1)\n\n    if data_format == \"channels_first\":\n        pooled_l = tf.transpose(pooled_l, (0, 2, 1))\n    return pooled_l\n\n\ndef _adaptive_max_pool1d(inputs, output_size, data_format=\"channels_first\"):\n    if isinstance(output_size, int):\n        output_size = (output_size,)\n    if data_format == \"channels_first\":\n        inputs = tf.transpose(inputs, (0, 2, 1))\n\n    static_shape = inputs.shape.as_list()\n    l_static = static_shape[1]\n    out_l = output_size[0]\n\n    if l_static is None:\n        raise ValueError(\n            \"Input length must be statically known for adaptive pooling\"\n        )\n\n    small_l, big_l = compute_adaptive_pooling_window_sizes(l_static, out_l)\n    gather_l = _compute_static_gather_indices(l_static, out_l, small_l, big_l)\n\n    small_pool_l = tf.nn.pool(\n        inputs,\n        window_shape=(small_l,),\n        pooling_type=\"MAX\",\n        strides=(1,),\n        padding=\"VALID\",\n        data_format=\"NWC\",\n    )\n    big_pool_l = tf.nn.pool(\n        inputs,\n        window_shape=(big_l,),\n        pooling_type=\"MAX\",\n        strides=(1,),\n        padding=\"VALID\",\n        data_format=\"NWC\",\n    )\n\n    combined_l = tf.concat([small_pool_l, big_pool_l], axis=1)\n    pooled_l = tf.gather(combined_l, gather_l, axis=1)\n\n    if data_format == \"channels_first\":\n        pooled_l = tf.transpose(pooled_l, (0, 2, 1))\n    return pooled_l\n\n\ndef _adaptive_average_pool2d(inputs, output_size, 
data_format=\"channels_first\"):\n    if isinstance(output_size, int):\n        output_size = (output_size, output_size)\n\n    if data_format == \"channels_first\":\n        inputs = tf.transpose(inputs, (0, 2, 3, 1))\n\n    static_shape = inputs.shape.as_list()\n    h_static = static_shape[1]\n    w_static = static_shape[2]\n    out_h, out_w = output_size\n\n    if h_static is None or w_static is None:\n        raise ValueError(\n            \"Input spatial dimensions must be \"\n            \"statically known for adaptive pooling\"\n        )\n\n    small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h)\n    small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w)\n\n    gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h)\n    gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w)\n\n    small_pool_h = tf.nn.pool(\n        inputs,\n        window_shape=(small_h, 1),\n        pooling_type=\"AVG\",\n        strides=(1, 1),\n        padding=\"VALID\",\n        data_format=\"NHWC\",\n    )\n    big_pool_h = tf.nn.pool(\n        inputs,\n        window_shape=(big_h, 1),\n        pooling_type=\"AVG\",\n        strides=(1, 1),\n        padding=\"VALID\",\n        data_format=\"NHWC\",\n    )\n\n    combined_h = tf.concat([small_pool_h, big_pool_h], axis=1)\n    pooled_h = tf.gather(combined_h, gather_h, axis=1)\n\n    small_pool_w = tf.nn.pool(\n        pooled_h,\n        window_shape=(1, small_w),\n        pooling_type=\"AVG\",\n        strides=(1, 1),\n        padding=\"VALID\",\n        data_format=\"NHWC\",\n    )\n    big_pool_w = tf.nn.pool(\n        pooled_h,\n        window_shape=(1, big_w),\n        pooling_type=\"AVG\",\n        strides=(1, 1),\n        padding=\"VALID\",\n        data_format=\"NHWC\",\n    )\n\n    combined_w = tf.concat([small_pool_w, big_pool_w], axis=2)\n    pooled_w = tf.gather(combined_w, gather_w, axis=2)\n\n    if data_format == \"channels_first\":\n        
pooled_w = tf.transpose(pooled_w, (0, 3, 1, 2))\n\n    return pooled_w\n\n\ndef _adaptive_max_pool2d(inputs, output_size, data_format=\"channels_first\"):\n    \"\"\"Adaptive Max Pooling 2D using Two-Pool Gather method.\"\"\"\n    if isinstance(output_size, int):\n        output_size = (output_size, output_size)\n\n    if data_format == \"channels_first\":\n        inputs = tf.transpose(inputs, (0, 2, 3, 1))\n\n    static_shape = inputs.shape.as_list()\n    h_static = static_shape[1]\n    w_static = static_shape[2]\n    out_h, out_w = output_size\n\n    if h_static is None or w_static is None:\n        raise ValueError(\n            \"Input spatial dimensions must be \"\n            \"statically known for adaptive pooling\"\n        )\n\n    small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h)\n    small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w)\n\n    gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h)\n    gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w)\n\n    small_pool_h = tf.nn.pool(\n        inputs,\n        window_shape=(small_h, 1),\n        pooling_type=\"MAX\",\n        strides=(1, 1),\n        padding=\"VALID\",\n        data_format=\"NHWC\",\n    )\n    big_pool_h = tf.nn.pool(\n        inputs,\n        window_shape=(big_h, 1),\n        pooling_type=\"MAX\",\n        strides=(1, 1),\n        padding=\"VALID\",\n        data_format=\"NHWC\",\n    )\n\n    combined_h = tf.concat([small_pool_h, big_pool_h], axis=1)\n    pooled_h = tf.gather(combined_h, gather_h, axis=1)\n\n    small_pool_w = tf.nn.pool(\n        pooled_h,\n        window_shape=(1, small_w),\n        pooling_type=\"MAX\",\n        strides=(1, 1),\n        padding=\"VALID\",\n        data_format=\"NHWC\",\n    )\n    big_pool_w = tf.nn.pool(\n        pooled_h,\n        window_shape=(1, big_w),\n        pooling_type=\"MAX\",\n        strides=(1, 1),\n        padding=\"VALID\",\n        
data_format=\"NHWC\",\n    )\n\n    combined_w = tf.concat([small_pool_w, big_pool_w], axis=2)\n    pooled_w = tf.gather(combined_w, gather_w, axis=2)\n\n    if data_format == \"channels_first\":\n        pooled_w = tf.transpose(pooled_w, (0, 3, 1, 2))\n\n    return pooled_w\n\n\ndef _adaptive_average_pool3d(inputs, output_size, data_format=\"channels_first\"):\n    if isinstance(output_size, int):\n        output_size = (output_size, output_size, output_size)\n\n    if data_format == \"channels_first\":\n        inputs = tf.transpose(inputs, (0, 2, 3, 4, 1))\n\n    static_shape = inputs.shape.as_list()\n    d_static = static_shape[1]\n    h_static = static_shape[2]\n    w_static = static_shape[3]\n    out_d, out_h, out_w = output_size\n\n    if d_static is None or h_static is None or w_static is None:\n        raise ValueError(\n            \"Input spatial dimensions must be \"\n            \"statically known for adaptive pooling\"\n        )\n\n    small_d, big_d = compute_adaptive_pooling_window_sizes(d_static, out_d)\n    small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h)\n    small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w)\n\n    gather_d = _compute_static_gather_indices(d_static, out_d, small_d, big_d)\n    gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h)\n    gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w)\n\n    small_pool_d = tf.nn.pool(\n        inputs,\n        window_shape=(small_d, 1, 1),\n        pooling_type=\"AVG\",\n        strides=(1, 1, 1),\n        padding=\"VALID\",\n        data_format=\"NDHWC\",\n    )\n    big_pool_d = tf.nn.pool(\n        inputs,\n        window_shape=(big_d, 1, 1),\n        pooling_type=\"AVG\",\n        strides=(1, 1, 1),\n        padding=\"VALID\",\n        data_format=\"NDHWC\",\n    )\n\n    combined_d = tf.concat([small_pool_d, big_pool_d], axis=1)\n    pooled_d = tf.gather(combined_d, gather_d, axis=1)\n\n    
small_pool_h = tf.nn.pool(\n        pooled_d,\n        window_shape=(1, small_h, 1),\n        pooling_type=\"AVG\",\n        strides=(1, 1, 1),\n        padding=\"VALID\",\n        data_format=\"NDHWC\",\n    )\n    big_pool_h = tf.nn.pool(\n        pooled_d,\n        window_shape=(1, big_h, 1),\n        pooling_type=\"AVG\",\n        strides=(1, 1, 1),\n        padding=\"VALID\",\n        data_format=\"NDHWC\",\n    )\n\n    combined_h = tf.concat([small_pool_h, big_pool_h], axis=2)\n    pooled_h = tf.gather(combined_h, gather_h, axis=2)\n\n    small_pool_w = tf.nn.pool(\n        pooled_h,\n        window_shape=(1, 1, small_w),\n        pooling_type=\"AVG\",\n        strides=(1, 1, 1),\n        padding=\"VALID\",\n        data_format=\"NDHWC\",\n    )\n    big_pool_w = tf.nn.pool(\n        pooled_h,\n        window_shape=(1, 1, big_w),\n        pooling_type=\"AVG\",\n        strides=(1, 1, 1),\n        padding=\"VALID\",\n        data_format=\"NDHWC\",\n    )\n\n    combined_w = tf.concat([small_pool_w, big_pool_w], axis=3)\n    pooled_w = tf.gather(combined_w, gather_w, axis=3)\n\n    if data_format == \"channels_first\":\n        pooled_w = tf.transpose(pooled_w, (0, 4, 1, 2, 3))\n\n    return pooled_w\n\n\ndef _adaptive_max_pool3d(inputs, output_size, data_format=\"channels_first\"):\n    \"\"\"Adaptive Max Pooling 3D using Two-Pool Gather method.\"\"\"\n    if isinstance(output_size, int):\n        output_size = (output_size, output_size, output_size)\n\n    if data_format == \"channels_first\":\n        inputs = tf.transpose(inputs, (0, 2, 3, 4, 1))\n\n    static_shape = inputs.shape.as_list()\n    d_static = static_shape[1]\n    h_static = static_shape[2]\n    w_static = static_shape[3]\n    out_d, out_h, out_w = output_size\n\n    if d_static is None or h_static is None or w_static is None:\n        raise ValueError(\n            \"Input spatial dimensions must be \"\n            \"statically known for adaptive pooling\"\n        )\n\n    small_d, big_d = 
compute_adaptive_pooling_window_sizes(d_static, out_d)\n    small_h, big_h = compute_adaptive_pooling_window_sizes(h_static, out_h)\n    small_w, big_w = compute_adaptive_pooling_window_sizes(w_static, out_w)\n\n    gather_d = _compute_static_gather_indices(d_static, out_d, small_d, big_d)\n    gather_h = _compute_static_gather_indices(h_static, out_h, small_h, big_h)\n    gather_w = _compute_static_gather_indices(w_static, out_w, small_w, big_w)\n\n    small_pool_d = tf.nn.pool(\n        inputs,\n        window_shape=(small_d, 1, 1),\n        pooling_type=\"MAX\",\n        strides=(1, 1, 1),\n        padding=\"VALID\",\n        data_format=\"NDHWC\",\n    )\n    big_pool_d = tf.nn.pool(\n        inputs,\n        window_shape=(big_d, 1, 1),\n        pooling_type=\"MAX\",\n        strides=(1, 1, 1),\n        padding=\"VALID\",\n        data_format=\"NDHWC\",\n    )\n\n    combined_d = tf.concat([small_pool_d, big_pool_d], axis=1)\n    pooled_d = tf.gather(combined_d, gather_d, axis=1)\n\n    small_pool_h = tf.nn.pool(\n        pooled_d,\n        window_shape=(1, small_h, 1),\n        pooling_type=\"MAX\",\n        strides=(1, 1, 1),\n        padding=\"VALID\",\n        data_format=\"NDHWC\",\n    )\n    big_pool_h = tf.nn.pool(\n        pooled_d,\n        window_shape=(1, big_h, 1),\n        pooling_type=\"MAX\",\n        strides=(1, 1, 1),\n        padding=\"VALID\",\n        data_format=\"NDHWC\",\n    )\n\n    combined_h = tf.concat([small_pool_h, big_pool_h], axis=2)\n    pooled_h = tf.gather(combined_h, gather_h, axis=2)\n\n    small_pool_w = tf.nn.pool(\n        pooled_h,\n        window_shape=(1, 1, small_w),\n        pooling_type=\"MAX\",\n        strides=(1, 1, 1),\n        padding=\"VALID\",\n        data_format=\"NDHWC\",\n    )\n    big_pool_w = tf.nn.pool(\n        pooled_h,\n        window_shape=(1, 1, big_w),\n        pooling_type=\"MAX\",\n        strides=(1, 1, 1),\n        padding=\"VALID\",\n        data_format=\"NDHWC\",\n    )\n\n    combined_w 
= tf.concat([small_pool_w, big_pool_w], axis=3)\n    pooled_w = tf.gather(combined_w, gather_w, axis=3)\n\n    if data_format == \"channels_first\":\n        pooled_w = tf.transpose(pooled_w, (0, 4, 1, 2, 3))\n\n    return pooled_w\n\n\ndef adaptive_average_pool(inputs, output_size, data_format=None):\n    data_format = backend.standardize_data_format(data_format)\n    ndims = len(inputs.shape) - 2\n    if ndims == 1:\n        return _adaptive_average_pool1d(inputs, output_size, data_format)\n    elif ndims == 2:\n        return _adaptive_average_pool2d(inputs, output_size, data_format)\n    elif ndims == 3:\n        return _adaptive_average_pool3d(inputs, output_size, data_format)\n    else:\n        raise ValueError(\n            \"adaptive_average_pool supports 1D, 2D, or 3D inputs only.\"\n        )\n\n\ndef adaptive_max_pool(inputs, output_size, data_format=None):\n    data_format = backend.standardize_data_format(data_format)\n    ndims = len(inputs.shape) - 2\n    if ndims == 1:\n        return _adaptive_max_pool1d(inputs, output_size, data_format)\n    elif ndims == 2:\n        return _adaptive_max_pool2d(inputs, output_size, data_format)\n    elif ndims == 3:\n        return _adaptive_max_pool3d(inputs, output_size, data_format)\n    else:\n        raise ValueError(\n            \"adaptive_max_pool supports 1D, 2D, or 3D inputs only.\"\n        )\n\n\ndef _convert_data_format(data_format, ndim):\n    if data_format == \"channels_last\":\n        if ndim == 3:\n            return \"NWC\"\n        elif ndim == 4:\n            return \"NHWC\"\n        elif ndim == 5:\n            return \"NDHWC\"\n        else:\n            raise ValueError(\n                f\"Input rank not supported: {ndim}. 
\"\n                \"Expected values are [3, 4, 5]\"\n            )\n    elif data_format == \"channels_first\":\n        if ndim == 3:\n            return \"NCW\"\n        elif ndim == 4:\n            return \"NCHW\"\n        elif ndim == 5:\n            return \"NCDHW\"\n        else:\n            raise ValueError(\n                f\"Input rank not supported: {ndim}. \"\n                \"Expected values are [3, 4, 5]\"\n            )\n    else:\n        raise ValueError(\n            f\"Invalid data_format: {data_format}. \"\n            'Expected values are [\"channels_first\", \"channels_last\"]'\n        )\n\n\ndef conv(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    def _conv():\n        tf_data_format = _convert_data_format(data_format, len(inputs.shape))\n        result = tf.nn.convolution(\n            inputs,\n            kernel,\n            strides,\n            padding.upper(),\n            data_format=tf_data_format,\n            dilations=dilation_rate,\n        )\n        result_shape = result.shape\n        if (\n            result_shape.is_fully_defined()\n            and math.prod(result_shape.as_list()) == 0\n        ):\n            raise ValueError(\n                \"The convolution operation resulted in an empty output. \"\n                \"Output shape:\"\n                f\" {result_shape}. This can happen if the input is too small \"\n                \"for the given kernel size, strides, dilation rate, and \"\n                \"padding mode. 
Please check the input shape and convolution \"\n                \"parameters.\"\n            )\n        return result\n\n    # Certain ops are broken in Tensorflow on CPU only.\n    # We can work around this by compiling the op with XLA.\n    @tf.function(jit_compile=True)\n    def _conv_xla():\n        return _conv()\n\n    # Channels first \"NCDHW\" (3d convolutions) are broken on CPU without XLA.\n    needs_xla = data_format == \"channels_first\" and len(inputs.shape) == 5\n    # grouped convolutions are broken on CPU without XLA.\n    data_format = backend.standardize_data_format(data_format)\n    if data_format == \"channels_last\":\n        channels = inputs.shape[-1]\n    else:\n        channels = inputs.shape[1]\n    needs_xla = needs_xla or channels != kernel.shape[-2]\n    if needs_xla:\n        return _conv_xla()\n    else:\n        return _conv()\n\n\ndef depthwise_conv(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = len(inputs.shape) - 2\n    if num_spatial_dims > 2:\n        raise ValueError(\n            \"`inputs` rank must be 3 (1D conv) or 4 (2D conv). 
Received: \"\n            f\"{inputs.ndim}.\"\n        )\n    # Because we use `tf.nn.depthwise_conv2d` for both 1D and 2D convs, we set\n    # `tf_data_format` using 2D conv format.\n    tf_data_format = _convert_data_format(data_format, 4)\n    padding = padding.upper()\n    if isinstance(strides, int):\n        strides = (strides,) * num_spatial_dims\n    if isinstance(dilation_rate, int):\n        dilation_rate = (dilation_rate,) * num_spatial_dims\n    if num_spatial_dims == 1:\n        # 1D depthwise conv.\n        if data_format == \"channels_last\":\n            strides = (1,) + strides * 2 + (1,)\n            spatial_start_dim = 1\n        else:\n            strides = (1, 1) + strides * 2\n            spatial_start_dim = 2\n        inputs = tf.expand_dims(inputs, spatial_start_dim)\n        kernel = tf.expand_dims(kernel, axis=0)\n\n        dilation_rate = None if dilation_rate is None else (1,) + dilation_rate\n\n        outputs = tf.nn.depthwise_conv2d(\n            inputs,\n            kernel,\n            strides,\n            padding,\n            data_format=tf_data_format,\n            dilations=dilation_rate,\n        )\n        return tf.squeeze(outputs, [spatial_start_dim])\n\n    if data_format == \"channels_last\":\n        strides = (1,) + strides + (1,)\n        spatial_start_dim = 1\n    else:\n        strides = (1, 1) + strides\n        spatial_start_dim = 2\n    return tf.nn.depthwise_conv2d(\n        inputs,\n        kernel,\n        strides,\n        padding,\n        data_format=tf_data_format,\n        dilations=dilation_rate,\n    )\n\n\ndef separable_conv(\n    inputs,\n    depthwise_kernel,\n    pointwise_kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    data_format = backend.standardize_data_format(data_format)\n    num_spatial_dims = len(inputs.shape) - 2\n    if num_spatial_dims > 2:\n        raise ValueError(\n            \"`num_spatial_dims` must be 1 or 2. 
Received: \"\n            f\"num_spatial_dims={num_spatial_dims}.\"\n        )\n    # Because we use `tf.nn.separable_conv2d` for both 1D and 2D convs, we set\n    # `tf_data_format` using 2D conv format.\n    tf_data_format = _convert_data_format(data_format, 4)\n    padding = padding.upper()\n    if isinstance(strides, int):\n        strides = (strides,) * num_spatial_dims\n    if isinstance(dilation_rate, int):\n        dilation_rate = (dilation_rate,) * num_spatial_dims\n    if num_spatial_dims == 1:\n        # 1D depthwise conv.\n        if data_format == \"channels_last\":\n            strides = (1,) + strides * 2 + (1,)\n            spatial_start_dim = 1\n        else:\n            strides = (1, 1) + strides * 2\n            spatial_start_dim = 2\n        inputs = tf.expand_dims(inputs, spatial_start_dim)\n        depthwise_kernel = tf.expand_dims(depthwise_kernel, axis=0)\n        pointwise_kernel = tf.expand_dims(pointwise_kernel, axis=0)\n        dilation_rate = None if dilation_rate is None else (1,) + dilation_rate\n\n        outputs = tf.nn.separable_conv2d(\n            inputs,\n            depthwise_kernel,\n            pointwise_kernel,\n            strides,\n            padding,\n            data_format=tf_data_format,\n            dilations=dilation_rate,\n        )\n        return tf.squeeze(outputs, [spatial_start_dim])\n\n    if data_format == \"channels_last\":\n        strides = (1,) + strides + (1,)\n    else:\n        strides = (1, 1) + strides\n    return tf.nn.separable_conv2d(\n        inputs,\n        depthwise_kernel,\n        pointwise_kernel,\n        strides,\n        padding,\n        data_format=tf_data_format,\n        dilations=dilation_rate,\n    )\n\n\ndef conv_transpose(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    output_padding=None,\n    data_format=None,\n    dilation_rate=1,\n):\n    data_format = backend.standardize_data_format(data_format)\n    tf_data_format = 
_convert_data_format(data_format, len(inputs.shape))\n    kernel_size = kernel.shape[:-2]\n    filters = kernel.shape[-2]\n    input_shape = list(inputs.shape)\n    symbolic_shape = tf.shape(inputs)\n    for i, e in enumerate(input_shape):\n        if e is None:\n            input_shape[i] = symbolic_shape[i]\n    output_shape = compute_conv_transpose_output_shape(\n        input_shape,\n        kernel_size,\n        filters,\n        strides,\n        padding,\n        output_padding,\n        data_format,\n        dilation_rate,\n    )\n\n    return tf.nn.conv_transpose(\n        inputs,\n        kernel,\n        output_shape,\n        strides,\n        padding=padding.upper(),\n        data_format=tf_data_format,\n        dilations=dilation_rate,\n    )\n\n\ndef one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):\n    x = convert_to_tensor(x, dtype=\"int64\")\n    if dtype is None:\n        dtype = \"float32\"\n    else:\n        dtype = backend.standardize_dtype(dtype)\n    if sparse:\n        # We don't use `tf.sparse.bincount` because it doesn't handle negative\n        # indices and only supports rank 1 and 2 tensors (`one_hot` adds a\n        # dimension).\n        if axis < 0:\n            axis = axis + len(x.shape) + 1\n        values_count = math.prod(x.shape)\n        values = tf.reshape(x, (values_count,))\n        # We deal with negative inputs by having zeros in the output although\n        # it's useless. 
It makes shapes static.\n        values = tf.cast(tf.greater_equal(values, 0), dtype=dtype)\n        indices = [tf.range(dim) for dim in x.shape]\n        indices = tf.meshgrid(*indices, indexing=\"ij\")\n        indices.insert(axis, tf.maximum(x, 0))  # Deal with negative indices\n        indices = [tf.reshape(a, (values_count, 1)) for a in indices]\n        indices = [tf.cast(a, tf.int64) for a in indices]\n        indices = tf.concat(indices, axis=1)\n        shape = list(x.shape)\n        shape.insert(axis, num_classes)\n        return tf.SparseTensor(indices, values, shape)\n    on_value, off_value = (True, False) if dtype == \"bool\" else (None, None)\n    return tf.one_hot(\n        x,\n        num_classes,\n        on_value=on_value,\n        off_value=off_value,\n        axis=axis,\n        dtype=dtype,\n    )\n\n\ndef multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):\n    reduction_axis = 1 if len(x.shape) > 1 else 0\n    if backend.standardize_dtype(dtype) == \"bool\":\n        if sparse:\n            # `tf.sparse.reduce_max` doesn't work on bool and there is no\n            # `tf.sparse.reduce_any`.\n            outputs = one_hot(\n                x, num_classes, axis=axis, dtype=\"int8\", sparse=True\n            )\n            outputs = tf.sparse.reduce_max(\n                outputs, axis=reduction_axis, output_is_sparse=True\n            )\n            outputs_shape = outputs.shape\n            outputs = tf.cast(outputs, dtype)\n            outputs.set_shape(outputs_shape)\n            return outputs\n        else:\n            outputs = one_hot(x, num_classes, axis=axis, dtype=dtype)\n            return tf.reduce_any(outputs, axis=reduction_axis)\n    else:\n        if sparse:\n            # We don't use `tf.sparse.bincount`, it doesn't handle negative\n            # indices and has a rank limitation.\n            outputs = one_hot(\n                x, num_classes, axis=axis, dtype=dtype, sparse=True\n            )\n            return 
tf.sparse.reduce_max(\n                outputs, axis=reduction_axis, output_is_sparse=True\n            )\n        else:\n            outputs = one_hot(x, num_classes, axis=axis, dtype=dtype)\n            return tf.reduce_max(outputs, axis=reduction_axis)\n\n\ndef _get_logits(output, from_logits, op_type, fn_name):\n    \"\"\"Retrieves logits tensor from maybe-softmax or maybe-sigmoid tensor.\"\"\"\n    output_ = output\n    from_logits_ = from_logits\n\n    has_keras_logits = hasattr(output, \"_keras_logits\")\n    if has_keras_logits:\n        output_ = output._keras_logits\n        from_logits_ = True\n\n    from_expected_op_type = (\n        hasattr(output, \"op\")\n        and not isinstance(output, (tf.__internal__.EagerTensor, tf.Variable))\n        and output.op.type == op_type\n    ) and not has_keras_logits\n\n    if from_expected_op_type:\n        # When softmax activation function is used for output operation, we\n        # use logits from the softmax function directly to compute loss in order\n        # to prevent collapsing zero when training.\n        if len(output.op.inputs) != 1:\n            raise ValueError(f\"Expected 1 input for {op_type}.\")\n        output_ = output.op.inputs[0]\n        from_logits_ = True\n\n    if from_logits and (has_keras_logits or from_expected_op_type):\n        warnings.warn(\n            f'\"`{fn_name}` received `from_logits=True`, but '\n            f\"the `output` argument was produced by a {op_type} \"\n            \"activation and thus does not represent logits. 
\"\n            \"Was this intended?\",\n            stacklevel=2,\n        )\n    return output_, from_logits_\n\n\ndef categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    \"\"\"Categorical crossentropy between an output tensor and a target tensor.\n\n    Args:\n        target: A tensor of the same shape as `output`.\n        output: A tensor resulting from a softmax\n            (unless `from_logits` is `True`, in which\n            case `output` is expected to be the logits).\n        from_logits: Boolean, whether `output` is the\n            result of a softmax, or is a tensor of logits.\n        axis: Int specifying the channels axis. `axis=-1` corresponds to data\n            format `channels_last`, and `axis=1` corresponds to data format\n            `channels_first`.\n\n    Returns:\n        Output tensor.\n\n    Example:\n\n    >>> a = tf.constant([1., 0., 0., 0., 1., 0., 0., 0., 1.], shape=[3,3])\n    >>> print(a)\n    tf.Tensor(\n      [[1. 0. 0.]\n       [0. 1. 0.]\n       [0. 0. 1.]], shape=(3, 3), dtype=float32)\n    >>> b = tf.constant([.9, .05, .05, .05, .89, .06, .05, .01, .94],\n    ...                 shape=[3, 3])\n    >>> print(b)\n    tf.Tensor(\n      [[0.9  0.05 0.05]\n       [0.05 0.89 0.06]\n       [0.05 0.01 0.94]], shape=(3, 3), dtype=float32)\n    >>> loss = categorical_crossentropy(a, b)\n    >>> print(np.around(loss, 5))\n    [0.10536 0.11653 0.06188]\n    >>> loss = categorical_crossentropy(a, a)\n    >>> print(np.around(loss, 5))\n    [0. 0. 0.]\n    \"\"\"\n    target = tf.convert_to_tensor(target)\n    output = tf.convert_to_tensor(output)\n\n    if len(target.shape) < 1:\n        raise ValueError(\n            \"Arguments `target` and `output` must be at least rank 1. 
\"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n    if len(target.shape) != len(output.shape):\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same rank \"\n            \"(ndim). Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n    for e1, e2 in zip(target.shape, output.shape):\n        if e1 is not None and e2 is not None and e1 != e2:\n            raise ValueError(\n                \"Arguments `target` and `output` must have the same shape. \"\n                \"Received: \"\n                f\"target.shape={target.shape}, output.shape={output.shape}\"\n            )\n\n    output, from_logits = _get_logits(\n        output, from_logits, \"Softmax\", \"categorical_crossentropy\"\n    )\n    if from_logits:\n        return tf.nn.softmax_cross_entropy_with_logits(\n            labels=target, logits=output, axis=axis\n        )\n\n    # Adjust the predictions so that the probability of\n    # each class for every sample adds up to 1\n    # This is needed to ensure that the cross entropy is\n    # computed correctly.\n    output = output / tf.reduce_sum(output, axis, keepdims=True)\n\n    # Compute cross entropy from probabilities.\n    output = tf.clip_by_value(\n        output, backend.epsilon(), 1.0 - backend.epsilon()\n    )\n    return -tf.reduce_sum(target * tf.math.log(output), axis)\n\n\ndef sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    \"\"\"Categorical crossentropy with integer targets.\n\n    Args:\n        target: An integer tensor.\n        output: A tensor resulting from a softmax\n            (unless `from_logits` is True, in which\n            case `output` is expected to be the logits).\n        from_logits: Boolean, whether `output` is the\n            result of a softmax, or is a tensor of logits.\n        axis: Int specifying the channels axis. 
`axis=-1` corresponds to data\n            format `channels_last`, and `axis=1` corresponds to data format\n            `channels_first`.\n\n    Returns:\n        Output tensor.\n    \"\"\"\n    if axis != -1 and axis != len(output.shape) - 1:\n        raise ValueError(\n            f\"Only axis=-1 is currently supported. Received: axis={axis}\"\n        )\n    output, from_logits = _get_logits(\n        output, from_logits, \"Softmax\", \"sparse_categorical_crossentropy\"\n    )\n\n    target = tf.convert_to_tensor(target)\n    target = tf.cast(target, dtype=\"int64\")\n    output = tf.convert_to_tensor(output)\n    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:\n        target = tf.squeeze(target, axis=-1)\n\n    if len(output.shape) < 1:\n        raise ValueError(\n            \"Argument `output` must be at least rank 1. \"\n            \"Received: \"\n            f\"output.shape={output.shape}\"\n        )\n    if len(target.shape) != len(output.shape[:-1]):\n        raise ValueError(\n            \"Argument `output` must have rank (ndim) `target.ndim - 1`. 
\"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n    for e1, e2 in zip(target.shape, output.shape[:-1]):\n        if e1 is not None and e2 is not None and e1 != e2:\n            raise ValueError(\n                \"Arguments `target` and `output` must have the same shape \"\n                \"up until the last dimension: \"\n                f\"target.shape={target.shape}, output.shape={output.shape}\"\n            )\n\n    if not from_logits:\n        output = tf.clip_by_value(\n            output, backend.epsilon(), 1 - backend.epsilon()\n        )\n        output = tf.math.log(output)\n\n    result = tf.nn.sparse_softmax_cross_entropy_with_logits(\n        labels=target, logits=output\n    )\n    return result\n\n\ndef binary_crossentropy(target, output, from_logits=False):\n    \"\"\"Binary crossentropy between an output tensor and a target tensor.\n\n    Args:\n        target: A tensor with the same shape as `output`.\n        output: A tensor.\n        from_logits: Whether `output` is expected to be a logits tensor.\n            By default, we consider that `output`\n            encodes a probability distribution.\n\n    Returns:\n        A tensor.\n    \"\"\"\n    target = tf.convert_to_tensor(target)\n    output = tf.convert_to_tensor(output)\n\n    if len(target.shape) != len(output.shape):\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same rank \"\n            \"(ndim). Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n    for e1, e2 in zip(target.shape, output.shape):\n        if e1 is not None and e2 is not None and e1 != e2:\n            raise ValueError(\n                \"Arguments `target` and `output` must have the same shape. 
\"\n                \"Received: \"\n                f\"target.shape={target.shape}, output.shape={output.shape}\"\n            )\n\n    output, from_logits = _get_logits(\n        output, from_logits, \"Sigmoid\", \"binary_crossentropy\"\n    )\n\n    if from_logits:\n        return tf.nn.sigmoid_cross_entropy_with_logits(\n            labels=target, logits=output\n        )\n\n    # Compute cross entropy from probabilities.\n    output = tf.clip_by_value(\n        output, backend.epsilon(), 1.0 - backend.epsilon()\n    )\n    bce = target * tf.math.log(output)\n    bce += (1 - target) * tf.math.log(1 - output)\n    return -bce\n\n\ndef moments(x, axes, keepdims=False, synchronized=False):\n    # The dynamic range of float16 is too limited for statistics. As a\n    # workaround, we simply perform the operations on float32 and convert back\n    # to float16\n    need_cast = False\n    ori_dtype = backend.standardize_dtype(x.dtype)\n    if ori_dtype in (\"float16\", \"bfloat16\"):\n        need_cast = True\n        x = cast(x, \"float32\")\n\n    if synchronized:\n        mean, variance = _compute_moments_sync(x, axes, keepdims)\n    else:\n        mean, variance = _compute_moments(x, axes, keepdims)\n    if need_cast:\n        # avoid overflow and underflow when casting from float16 to float32\n        mean = tf.clip_by_value(mean, tf.float16.min, tf.float16.max)\n        variance = tf.clip_by_value(variance, tf.float16.min, tf.float16.max)\n        mean = cast(mean, ori_dtype)\n        variance = cast(variance, ori_dtype)\n    return mean, variance\n\n\ndef _compute_moments_sync(x, axes, keepdims):\n    replica_ctx = tf.distribute.get_replica_context()\n    if not replica_ctx:\n        return _compute_moments(x, axes, keepdims)\n\n    local_count = tf.ones_like(x, name=\"count\")\n\n    local_sum = tf.reduce_sum(x, axis=axes, keepdims=True)\n    local_squared_sum = tf.reduce_sum(tf.square(x), axis=axes, keepdims=True)\n    local_count = tf.reduce_sum(local_count, 
axis=axes, keepdims=True)\n\n    # TODO(b/163099951): batch the all-reduces once we sort out the\n    # ordering issue for NCCL. We don't have a mechanism to launch\n    # NCCL in the same order in each replica nowadays, so we limit\n    # NCCL to batch all-reduces.\n    y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)\n    y_squared_sum = replica_ctx.all_reduce(\n        tf.distribute.ReduceOp.SUM, local_squared_sum\n    )\n    count_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_count)\n\n    mean = tf.math.divide_no_nan(y_sum, count_sum)\n    y_squared_mean = tf.math.divide_no_nan(y_squared_sum, count_sum)\n    # var = E(x^2) - E(x)^2\n    variance = tf.maximum(y_squared_mean - tf.square(mean), 0.0)\n    if not keepdims:\n        mean = tf.squeeze(mean, axes)\n        variance = tf.squeeze(variance, axes)\n\n    return mean, variance\n\n\ndef _compute_moments(x, axes, keepdims):\n    return tf.nn.moments(x, axes, keepdims=keepdims)\n\n\ndef batch_normalization(\n    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3\n):\n    if axis != -1:\n        shape = [1] * len(x.shape)\n        shape[axis] = mean.shape[0]\n        mean = tf.reshape(mean, shape)\n        variance = tf.reshape(variance, shape)\n        if offset is not None:\n            offset = tf.reshape(offset, shape)\n        if scale is not None:\n            scale = tf.reshape(scale, shape)\n\n    return tf.nn.batch_normalization(\n        x=x,\n        mean=mean,\n        variance=variance,\n        offset=offset,\n        scale=scale,\n        variance_epsilon=epsilon,\n    )\n\n\ndef ctc_loss(target, output, target_length, output_length, mask_index=0):\n    target = convert_to_tensor(target)\n    output = convert_to_tensor(output)\n    target = tf.cast(target, dtype=\"int32\")\n\n    # `tf.nn.ctc_loss` will internally cast to float32 when the input is float16\n    # or bfloat16. Additionally, it will raise an error when the input is\n    # float64. 
As a result, we perform the casting externally and add support\n    # for float64.\n    result_dtype = backend.result_type(output.dtype, \"float32\")\n    compute_dtype = \"float32\" if result_dtype == \"float64\" else result_dtype\n    output = tf.cast(output, compute_dtype)\n    loss = tf.nn.ctc_loss(\n        labels=target,\n        logits=output,\n        label_length=target_length,\n        logit_length=output_length,\n        blank_index=mask_index,\n        logits_time_major=False,\n    )\n    return tf.cast(loss, result_dtype)\n\n\ndef ctc_decode(\n    inputs,\n    sequence_lengths,\n    strategy=\"greedy\",\n    beam_width=100,\n    top_paths=1,\n    merge_repeated=True,\n    mask_index=0,\n):\n    inputs = convert_to_tensor(inputs)\n    input_shape = tf.shape(inputs)\n    num_samples, num_steps = input_shape[0], input_shape[1]\n    inputs = tf.transpose(inputs, (1, 0, 2))\n\n    dtype = backend.result_type(inputs.dtype, \"float32\")\n    inputs = tf.cast(inputs, dtype)\n\n    sequence_lengths = convert_to_tensor(sequence_lengths, dtype=\"int32\")\n    if strategy == \"greedy\":\n        (decoded, scores) = tf.nn.ctc_greedy_decoder(\n            inputs=inputs,\n            sequence_length=sequence_lengths,\n            merge_repeated=merge_repeated,\n            blank_index=mask_index,\n        )\n    elif strategy == \"beam_search\":\n        # Move `mask_index` column to the last position since this is the\n        # default for `tf.nn.ctc_beam_search_decoder`\n        if mask_index is not None:\n            inputs_before = inputs[..., :mask_index]\n            inputs_mask = inputs[..., mask_index : mask_index + 1]\n            inputs_after = inputs[..., mask_index + 1 :]\n            inputs = tf.concat(\n                [inputs_before, inputs_after, inputs_mask], axis=-1\n            )\n        (decoded, scores) = tf.nn.ctc_beam_search_decoder(\n            inputs=inputs,\n            sequence_length=sequence_lengths,\n            
beam_width=beam_width,\n            top_paths=top_paths,\n        )\n    else:\n        raise ValueError(\n            f\"Invalid strategy {strategy}. Supported values are \"\n            \"'greedy' and 'beam_search'.\"\n        )\n\n    # Postprocess sparse tensor\n    decoded_dense = []\n    for st in decoded:\n        st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))\n        decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))\n    decoded_dense = tf.stack(decoded_dense, axis=0)\n    decoded_dense = tf.cast(decoded_dense, \"int32\")\n\n    # We need to recover the labels because we swapped the indices earlier\n    if strategy == \"beam_search\" and mask_index is not None:\n        if mask_index < 0:\n            mask_index = mask_index + input_shape[-1]\n        decoded_dense = tf.where(\n            decoded_dense >= mask_index, decoded_dense + 1, decoded_dense\n        )\n    return decoded_dense, scores\n\n\ndef psnr(x1, x2, max_val):\n    from keras.src.backend.tensorflow.numpy import log10\n\n    if x1.shape != x2.shape:\n        raise ValueError(\n            f\"Input shapes {x1.shape} and {x2.shape} must \"\n            \"match for PSNR calculation. 
\"\n        )\n\n    max_val = convert_to_tensor(max_val, dtype=x2.dtype)\n    mse = tf.reduce_mean(tf.square(x1 - x2))\n    psnr = 20 * log10(max_val) - 10 * log10(mse)\n    return psnr\n\n\ndef _get_large_negative(dtype):\n    dtype = backend.standardize_dtype(dtype)\n    val = 65500.0 if dtype == \"float16\" else 3.38953e38\n    return tf.constant(val * -0.7, dtype=dtype)\n\n\ndef _apply_masks(logits, mask, is_causal):\n    if mask is None and not is_causal:\n        return logits\n\n    combined_mask = tf.ones_like(logits, dtype=\"bool\")\n    if mask is not None:\n        combined_mask = tf.logical_and(combined_mask, mask)\n\n    if is_causal:\n        logits_shape = tf.shape(logits)\n        T, S = logits_shape[2], logits_shape[3]\n        mask = tf.linalg.band_part(tf.ones((T, S), \"bool\"), -1, 0)\n        mask = mask[None, None, :, :]\n        combined_mask = tf.logical_and(combined_mask, mask)\n\n    padded_logits = tf.where(\n        combined_mask, logits, _get_large_negative(logits.dtype)\n    )\n    return padded_logits\n\n\ndef _dot_product_attention_xla(query, key, value, bias, mask, is_causal, scale):\n    logits_dtype = backend.result_type(query.dtype, \"float32\")\n    logits = tf.einsum(\"BTNH,BSNH->BNTS\", query, key, optimize=\"optimal\")\n    logits = tf.cast(logits, logits_dtype)\n    logits = tf.multiply(logits, tf.cast(scale, logits.dtype))\n\n    if bias is not None:\n        logits = tf.add(logits, tf.cast(bias, logits.dtype))\n\n    padded_logits = _apply_masks(logits, mask, is_causal)\n\n    # Softmax is always carried out in high precision.\n    probs_dtype = backend.result_type(padded_logits.dtype, \"float32\")\n    probs = tf.cast(\n        tf.nn.softmax(tf.cast(padded_logits, probs_dtype), axis=-1), key.dtype\n    )\n    return tf.einsum(\"BNTS,BSNH->BTNH\", probs, value, optimize=\"optimal\")\n\n\ndef dot_product_attention(\n    query,\n    key,\n    value,\n    bias=None,\n    mask=None,\n    scale=None,\n    is_causal=False,\n    
flash_attention=None,\n    attn_logits_soft_cap=None,\n):\n    if flash_attention is None:\n        flash_attention = False\n    if flash_attention:\n        raise ValueError(\n            \"Flash attention is not supported in tensorflow backend.\"\n        )\n\n    # Ref: jax.nn.dot_product_attention\n    # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828\n    # Not support `query_seq_lengths` and `key_value_seq_lengths` args\n    query = convert_to_tensor(query)\n    key = convert_to_tensor(key)\n    value = convert_to_tensor(value)\n    if len(query.shape) != 4:\n        raise ValueError(\n            \"`dot_product_attention` only supports 4D inputs. \"\n            f\"Received: query.shape={query.shape}, key.shape={key.shape}, \"\n            f\"value.shape={value.shape}.\"\n        )\n    compute_dtype = backend.result_type(query.dtype, key.dtype, value.dtype)\n    query = cast(query, compute_dtype)\n    key = cast(key, compute_dtype)\n    value = cast(value, compute_dtype)\n    if bias is not None:\n        bias = convert_to_tensor(bias, dtype=compute_dtype)\n\n    H = tf.shape(key)[-1]\n    scale = (1.0 / tf.sqrt(tf.cast(H, \"float32\"))) if scale is None else scale\n    return _dot_product_attention_xla(\n        query, key, value, bias, mask, is_causal, scale\n    )\n\n\ndef unfold(input, kernel_size, dilation=1, padding=0, stride=1):\n    \"\"\"Tensorflow implementation of Unfold.\n    Extract sliding local blocks from a **NCHW** batched image tensor.\n\n    Args:\n        input: 4-D tensor, shape (N, C, H, W)  **required**.\n        kernel_size: int or (kH, kW)\n        dilation: int or (dH, dW), default 1\n        padding: int or (pH, pW), default 0\n        stride: int or (sH, sW), default 1\n\n    Returns:\n        3-D tensor, shape (N, C*kH*kW, L)\n    \"\"\"\n    k = (\n        (kernel_size, kernel_size)\n        if isinstance(kernel_size, int)\n        else kernel_size\n    )\n    d = (dilation, dilation) if 
isinstance(dilation, int) else dilation\n    p = (padding, padding) if isinstance(padding, int) else padding\n    s = (stride, stride) if isinstance(stride, int) else stride\n    N, C, H, W = input.shape\n\n    # ---- padding ----\n    if any(_ > 0 for _ in p):\n        input = tf.pad(input, [[0, 0], [0, 0], [p[0], p[0]], [p[1], p[1]]])\n    x = tf.transpose(input, [0, 2, 3, 1])  # (N, H, W, C)\n    patches = tf.image.extract_patches(\n        images=x,\n        sizes=[1, k[0], k[1], 1],\n        strides=[1, s[0], s[1], 1],\n        rates=[1, d[0], d[1], 1],\n        padding=\"VALID\",\n    )  # (N, nH, nW, kH*kW*C)\n\n    N, nH, nW, D = patches.shape\n    patches = tf.reshape(\n        patches, [N, nH, nW, k[0], k[1], C]\n    )  # (N, nH, nW, kH, kW, C)\n    patches = tf.transpose(\n        patches, [0, 5, 3, 4, 1, 2]\n    )  # (N, C, kH, kW, nH, nW)\n    patches = tf.reshape(patches, [N, C * k[0] * k[1], nH * nW])\n    return patches\n\n\ndef fold(x, output_size, kernel_size, dilation=1, padding=0, stride=1):\n    \"\"\"TensorFlow implementation of Fold (col2im).\n    Combine an array of sliding local blocks into a large tensor.\n\n    Args:\n        x: 3-D tensor, shape (N, C*kH*kW, L)  **required**.\n        output_size: int or (oH, oW)\n        kernel_size: int or (kH, kW)\n        dilation: int or (dH, dW), default 1\n        padding: int or (pH, pW), default 0\n        stride: int or (sH, sW), default 1\n\n    Returns:\n        4-D tensor, shape (N, C, oH, oW)\n    \"\"\"\n    k = (\n        (kernel_size, kernel_size)\n        if isinstance(kernel_size, int)\n        else kernel_size\n    )\n    o = (\n        (output_size, output_size)\n        if isinstance(output_size, int)\n        else output_size\n    )\n    d = (dilation, dilation) if isinstance(dilation, int) else dilation\n    p = (padding, padding) if isinstance(padding, int) else padding\n    s = (stride, stride) if isinstance(stride, int) else stride\n\n    N = tf.shape(x)[0]\n    CKK = 
x.shape[1]\n    kH, kW = k\n    oH, oW = o\n    C = CKK // (kH * kW)\n\n    # Number of output patches along each dimension\n    nH = (oH + 2 * p[0] - d[0] * (kH - 1) - 1) // s[0] + 1\n    nW = (oW + 2 * p[1] - d[1] * (kW - 1) - 1) // s[1] + 1\n\n    # Reshape: (N, C*kH*kW, L) -> (N, C, kH, kW, nH, nW)\n    x = tf.reshape(x, [N, C, kH, kW, nH, nW])\n\n    # Padded output size\n    oH_pad = oH + 2 * p[0]\n    oW_pad = oW + 2 * p[1]\n\n    # Build scatter indices for all kernel positions\n    # Process one sample at a time using vectorized_map\n    def _fold_single(x_single):\n        # x_single: (C, kH, kW, nH, nW)\n        output = tf.zeros([C, oH_pad, oW_pad], dtype=x.dtype)\n        for i in range(kH):\n            for j in range(kW):\n                h_start = i * d[0]\n                w_start = j * d[1]\n                h_indices = tf.range(nH) * s[0] + h_start\n                w_indices = tf.range(nW) * s[1] + w_start\n                # x_single[:, i, j, :, :] has shape (C, nH, nW)\n                patch = x_single[:, i, j, :, :]\n                # Build indices for scatter\n                c_idx = tf.repeat(tf.range(C), nH * nW)  # (C*nH*nW,)\n                h_idx = tf.tile(tf.repeat(h_indices, nW), [C])  # (C*nH*nW,)\n                w_idx = tf.tile(tf.tile(w_indices, [nH]), [C])  # (C*nH*nW,)\n                indices = tf.stack(\n                    [c_idx, h_idx, w_idx], axis=1\n                )  # (C*nH*nW, 3)\n                values = tf.reshape(patch, [-1])  # (C*nH*nW,)\n                output = tf.tensor_scatter_nd_add(output, indices, values)\n        return output\n\n    output = tf.vectorized_map(_fold_single, x)  # (N, C, oH_pad, oW_pad)\n\n    # Remove padding\n    if p[0] > 0 or p[1] > 0:\n        output = output[:, :, p[0] : oH_pad - p[0], p[1] : oW_pad - p[1]]\n\n    return output\n\n\ndef depth_to_space(x, block_size, data_format=\"channels_last\"):\n    \"\"\"TensorFlow implementation of depth_to_space.\n\n    Rearranges data from depth 
into blocks of spatial data.\n\n    Args:\n        x: 4-D tensor with shape (N, H, W, C) for channels_last or\n            (N, C, H, W) for channels_first.\n        block_size: An integer specifying the block size.\n        data_format: \"channels_last\" or \"channels_first\".\n\n    Returns:\n        A tensor with shape (N, H*block_size, W*block_size, C/block_size**2)\n        for channels_last or (N, C/block_size**2, H*block_size, W*block_size)\n        for channels_first.\n    \"\"\"\n    if data_format == \"channels_last\":\n        return tf.nn.depth_to_space(x, block_size, data_format=\"NHWC\")\n    else:\n        # NCHW format is not supported on CPU, so we transpose manually\n        # NCHW -> NHWC\n        x = tf.transpose(x, [0, 2, 3, 1])\n        x = tf.nn.depth_to_space(x, block_size, data_format=\"NHWC\")\n        # NHWC -> NCHW\n        x = tf.transpose(x, [0, 3, 1, 2])\n        return x\n\n\ndef space_to_depth(x, block_size, data_format=\"channels_last\"):\n    \"\"\"TensorFlow implementation of space_to_depth.\n\n    Rearranges blocks of spatial data into depth.\n\n    Args:\n        x: 4-D tensor with shape (N, H, W, C) for channels_last or\n            (N, C, H, W) for channels_first.\n        block_size: An integer specifying the block size.\n        data_format: \"channels_last\" or \"channels_first\".\n\n    Returns:\n        A tensor with shape (N, H/block_size, W/block_size, C*block_size**2)\n        for channels_last or (N, C*block_size**2, H/block_size, W/block_size)\n        for channels_first.\n    \"\"\"\n    if data_format == \"channels_last\":\n        return tf.nn.space_to_depth(x, block_size, data_format=\"NHWC\")\n    else:\n        # NCHW format is not supported on CPU, so we transpose manually\n        # NCHW -> NHWC\n        x = tf.transpose(x, [0, 2, 3, 1])\n        x = tf.nn.space_to_depth(x, block_size, data_format=\"NHWC\")\n        # NHWC -> NCHW\n        x = tf.transpose(x, [0, 3, 1, 2])\n        return x\n"
  },
  {
    "path": "keras/src/backend/tensorflow/numpy.py",
    "content": "import builtins\nimport collections\nimport functools\nimport math\nimport string\nimport warnings\n\nimport numpy as np\nimport tensorflow as tf\nfrom tensorflow.python.ops.linalg.sparse import sparse_csr_matrix_ops\nfrom tensorflow.python.ops.math_ops import is_nan\n\nfrom keras.src import tree\nfrom keras.src.backend import config\nfrom keras.src.backend import standardize_dtype\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common.backend_utils import canonicalize_axis\nfrom keras.src.backend.common.backend_utils import to_tuple_or_list\nfrom keras.src.backend.common.backend_utils import vectorize_impl\nfrom keras.src.backend.tensorflow import sparse\nfrom keras.src.backend.tensorflow.core import cast\nfrom keras.src.backend.tensorflow.core import convert_to_tensor\nfrom keras.src.backend.tensorflow.core import shape as shape_op\n\n\ndef rot90(array, k=1, axes=(0, 1)):\n    \"\"\"Rotate an array by 90 degrees in the specified plane.\n\n    Args:\n        array: Input tensor\n        k: Number of 90-degree rotations (default=1)\n        axes: Tuple of two axes that define the plane of rotation.\n        Defaults to (0, 1).\n\n    Returns:\n        Rotated tensor with correct shape transformation\n    \"\"\"\n    array = convert_to_tensor(array)\n\n    if array.shape.rank < 2:\n        raise ValueError(\n            f\"Input array must have at least 2 dimensions. \"\n            f\"Received: array.ndim={array.shape.rank}\"\n        )\n\n    if len(axes) != 2 or axes[0] == axes[1]:\n        raise ValueError(\n            f\"Invalid axes: {axes}. 
Axes must be a tuple of \"\n            \"two different dimensions.\"\n        )\n\n    k = k % 4\n    if k == 0:\n        return array\n\n    axes = tuple(\n        axis if axis >= 0 else array.shape.rank + axis for axis in axes\n    )\n\n    perm = [i for i in range(array.shape.rank) if i not in axes]\n    perm.extend(axes)\n    array = tf.transpose(array, perm)\n\n    shape = tf.shape(array)\n    non_rot_shape = shape[:-2]\n    h, w = shape[-2], shape[-1]\n\n    array = tf.reshape(array, tf.concat([[-1], [h, w]], axis=0))\n\n    array = tf.reverse(array, axis=[2])\n    array = tf.transpose(array, [0, 2, 1])\n\n    if k % 2 == 1:\n        final_h, final_w = w, h\n    else:\n        final_h, final_w = h, w\n\n    if k > 1:\n        array = tf.reshape(array, tf.concat([[-1], [final_h, final_w]], axis=0))\n        for _ in range(k - 1):\n            array = tf.reverse(array, axis=[2])\n            array = tf.transpose(array, [0, 2, 1])\n\n    final_shape = tf.concat([non_rot_shape, [final_h, final_w]], axis=0)\n    array = tf.reshape(array, final_shape)\n\n    inv_perm = [0] * len(perm)\n    for i, p in enumerate(perm):\n        inv_perm[p] = i\n    array = tf.transpose(array, inv_perm)\n\n    return array\n\n\n@sparse.elementwise_binary_union(tf.sparse.add)\ndef add(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n\n    # Special case of `tf.add`: `tf.nn.bias_add`\n    # `BiasAdd` can be fused with `MatMul` and `Conv*` kernels\n    # Expecting `x1` to be `inputs` and `x2` to be `bias` (no swapping)\n    x2_squeeze_shape = [d for d in x2.shape.as_list() if d is None or d > 1]\n    if (\n        # `x2` looks like bias (can be squeezed to vector)\n      
  1 == len(x2_squeeze_shape)\n        # `x1` looks like input tensor (rank >= 2)\n        and len(x1.shape) > 1\n        # `x2` non-squeezable dimension defined\n        and x2_squeeze_shape[0] is not None\n        # `x2` non-squeezable dimension match `x1` channel dimension\n        and x2_squeeze_shape[0]\n        in {x1.shape.as_list()[1], x1.shape.as_list()[-1]}\n    ):\n        if x1.shape[-1] == x2_squeeze_shape[0]:\n            data_format = \"NHWC\"\n        else:\n            data_format = \"NCHW\"\n        if len(x2.shape) > 1:\n            x2 = tf.squeeze(x2)\n        return tf.nn.bias_add(x1, x2, data_format=data_format)\n\n    return tf.add(x1, x2)\n\n\ndef bartlett(x):\n    x = convert_to_tensor(x, dtype=config.floatx())\n    if x == 0:\n        return tf.constant([])\n    if x == 1:\n        return tf.ones([1])\n\n    n = tf.range(x)\n    half = (x - 1) / 2\n\n    window = tf.where(n <= half, 2.0 * n / (x - 1), 2.0 - 2.0 * n / (x - 1))\n\n    return window\n\n\ndef hamming(x):\n    x = convert_to_tensor(x, dtype=tf.int32)\n    return tf.signal.hamming_window(x, periodic=False)\n\n\ndef hanning(x):\n    x = convert_to_tensor(x, dtype=tf.int32)\n    return tf.signal.hann_window(x, periodic=False)\n\n\ndef heaviside(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype in [\"int8\", \"int16\", \"int32\", \"uint8\", \"uint16\", \"uint32\"]:\n        dtype = config.floatx()\n    elif dtype in [\"int64\"]:\n        dtype = \"float64\"\n\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n\n    return tf.where(\n        x1 < 0,\n        tf.zeros_like(x1),\n        tf.where(x1 > 0, tf.ones_like(x1), x2),\n    )\n\n\ndef kaiser(x, beta):\n    x = convert_to_tensor(x, dtype=tf.int32)\n    return tf.signal.kaiser_window(x, beta=beta)\n\n\ndef bincount(x, weights=None, minlength=0, sparse=False):\n    x = convert_to_tensor(x)\n    dtypes_to_resolve = [x.dtype]\n    if 
standardize_dtype(x.dtype) not in [\"int32\", \"int64\"]:\n        x = tf.cast(x, tf.int32)\n    if weights is not None:\n        weights = convert_to_tensor(weights)\n        dtypes_to_resolve.append(weights.dtype)\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n        if standardize_dtype(weights.dtype) not in [\n            \"int32\",\n            \"int64\",\n            \"float32\",\n            \"float64\",\n        ]:\n            if \"int\" in standardize_dtype(weights.dtype):\n                weights = tf.cast(weights, tf.int32)\n            else:\n                weights = tf.cast(weights, tf.float32)\n    else:\n        dtype = \"int32\"\n    if sparse or isinstance(x, tf.SparseTensor):\n        output = tf.sparse.bincount(\n            x,\n            weights=weights,\n            minlength=minlength,\n            axis=-1,\n        )\n        actual_length = output.shape[-1]\n        if actual_length is None:\n            actual_length = tf.shape(output)[-1]\n        output = cast(output, dtype)\n        if x.shape.rank == 1:\n            output_shape = (actual_length,)\n        else:\n            batch_size = output.shape[0]\n            if batch_size is None:\n                batch_size = tf.shape(output)[0]\n            output_shape = (batch_size, actual_length)\n        return tf.SparseTensor(\n            indices=output.indices,\n            values=output.values,\n            dense_shape=output_shape,\n        )\n    return tf.cast(\n        tf.math.bincount(x, weights=weights, minlength=minlength, axis=-1),\n        dtype,\n    )\n\n\n@functools.lru_cache(512)\ndef _normalize_einsum_subscripts(subscripts):\n    # string.ascii_letters\n    mapping = {}\n    normalized_subscripts = \"\"\n    for c in subscripts:\n        if c in string.ascii_letters:\n            if c not in mapping:\n                mapping[c] = string.ascii_letters[len(mapping)]\n            normalized_subscripts += mapping[c]\n        else:\n            
normalized_subscripts += c\n    return normalized_subscripts\n\n\ndef einsum(subscripts, *operands, **kwargs):\n    operands = tree.map_structure(convert_to_tensor, operands)\n    subscripts = _normalize_einsum_subscripts(subscripts)\n\n    def is_valid_for_custom_ops(subscripts, *operands):\n        # Check that `subscripts` is supported and the shape of operands is not\n        # `None`.\n        if subscripts in [\n            \"a,b->ab\",\n            \"ab,b->a\",\n            \"ab,bc->ac\",\n            \"ab,cb->ac\",\n            \"abc,cd->abd\",\n            \"abc,dc->abd\",\n            \"abcd,abde->abce\",\n            \"abcd,abed->abce\",\n            \"abcd,acbe->adbe\",\n            \"abcd,adbe->acbe\",\n            \"abcd,aecd->acbe\",\n            \"abcd,aecd->aceb\",\n        ]:\n            # These subscripts don't require the shape information\n            return True\n        elif subscripts == \"abc,cde->abde\":\n            _, b1, c1 = operands[0].shape\n            c2, d2, e2 = operands[1].shape\n            b, c, d, e = b1, c1 or c2, d2, e2\n            if None in (b, c, d, e):\n                return False\n            return True\n        elif subscripts == \"abc,dce->abde\":\n            _, b1, c1 = operands[0].shape\n            d2, c2, e2 = operands[1].shape\n            b, c, d, e = b1, c1 or c2, d2, e2\n            if None in (b, c, d, e):\n                return False\n            return True\n        elif subscripts == \"abc,dec->abde\":\n            _, b1, c1 = operands[0].shape\n            d2, e2, c2 = operands[1].shape\n            b, c, d, e = b1, c1 or c2, d2, e2\n            if None in (b, c, d, e):\n                return False\n            return True\n        elif subscripts == \"abcd,cde->abe\":\n            _, b1, c1, d1 = operands[0].shape\n            c2, d2, e2 = operands[1].shape\n            b, c, d, e = b1, c1 or c2, d1 or d2, e2\n            if None in (b, c, d, e):\n                return False\n            return 
True\n        elif subscripts == \"abcd,ced->abe\":\n            _, b1, c1, d1 = operands[0].shape\n            c2, e2, d2 = operands[1].shape\n            b, c, d, e = b1, c1 or c2, d1 or d2, e2\n            if None in (b, c, d, e):\n                return False\n            return True\n        elif subscripts == \"abcd,ecd->abe\":\n            _, b1, c1, d1 = operands[0].shape\n            e2, c2, d2 = operands[1].shape\n            b, c, d, e = b1, c1 or c2, d1 or d2, e2\n            if None in (b, c, d, e):\n                return False\n            return True\n        elif subscripts == \"abcde,aebf->adbcf\":\n            _, b1, c1, d1, e1 = operands[0].shape\n            _, e2, b2, f2 = operands[1].shape\n            b, c, d, e, f = b1 or b2, c1, d1, e1 or e2, f2\n            if None in (b, c, d, e, f):\n                return False\n            return True\n        elif subscripts == \"abcde,afce->acdbf\":\n            _, b1, c1, d1, e1 = operands[0].shape\n            _, f2, c2, e2 = operands[1].shape\n            b, c, d, e, f = b1, c1 or c2, d1, e1 or e2, f2\n            if None in (b, c, d, e, f):\n                return False\n            return True\n        else:\n            # No match in subscripts\n            return False\n\n    def use_custom_ops(subscripts, *operands, output_type):\n        # Replace tf.einsum with custom ops to utilize hardware-accelerated\n        # matmul\n        x, y = operands[0], operands[1]\n        if subscripts == \"a,b->ab\":\n            x = tf.expand_dims(x, axis=-1)\n            y = tf.expand_dims(y, axis=0)\n            return tf.matmul(x, y, output_type=output_type)\n        elif subscripts == \"ab,b->a\":\n            y = tf.expand_dims(y, axis=-1)\n            result = tf.matmul(x, y, output_type=output_type)\n            return tf.squeeze(result, axis=-1)\n        elif subscripts == \"ab,bc->ac\":\n            return tf.matmul(x, y, output_type=output_type)\n        elif subscripts == \"ab,cb->ac\":\n        
    y = tf.transpose(y, [1, 0])\n            return tf.matmul(x, y, output_type=output_type)\n        elif subscripts == \"abc,cd->abd\":\n            return tf.matmul(x, y, output_type=output_type)\n        elif subscripts == \"abc,cde->abde\":\n            _, b1, c1 = x.shape\n            c2, d2, e2 = y.shape\n            b, c, d, e = b1, c1 or c2, d2, e2\n            y = tf.reshape(y, [c, -1])\n            result = tf.matmul(x, y, output_type=output_type)\n            return tf.reshape(result, [-1, b, d, e])\n        elif subscripts == \"abc,dc->abd\":\n            y = tf.transpose(y, [1, 0])\n            return tf.matmul(x, y, output_type=output_type)\n        elif subscripts == \"abc,dce->abde\":\n            _, b1, c1 = x.shape\n            d2, c2, e2 = y.shape\n            b, c, d, e = b1, c1 or c2, d2, e2\n            y = tf.transpose(y, [1, 0, 2])  # cde\n            y = tf.reshape(y, [c, -1])\n            result = tf.matmul(x, y, output_type=output_type)\n            return tf.reshape(result, [-1, b, d, e])\n        elif subscripts == \"abc,dec->abde\":\n            _, b1, c1 = x.shape\n            d2, e2, c2 = y.shape\n            b, c, d, e = b1, c1 or c2, d2, e2\n            y = tf.transpose(y, [2, 0, 1])  # cde\n            y = tf.reshape(y, [c, -1])\n            result = tf.matmul(x, y, output_type=output_type)\n            return tf.reshape(result, [-1, b, d, e])\n        elif subscripts == \"abcd,abde->abce\":\n            return tf.matmul(x, y, output_type=output_type)\n        elif subscripts == \"abcd,abed->abce\":\n            y = tf.transpose(y, [0, 1, 3, 2])\n            return tf.matmul(x, y, output_type=output_type)\n        elif subscripts == \"abcd,acbe->adbe\":\n            x = tf.transpose(x, [0, 1, 3, 2])\n            y = tf.transpose(y, [0, 2, 1, 3])\n            result = tf.matmul(x, y, output_type=output_type)\n            return tf.transpose(result, [0, 2, 1, 3])\n        elif subscripts == \"abcd,adbe->acbe\":\n            y = 
tf.transpose(y, [0, 2, 1, 3])  # abde\n            result = tf.matmul(x, y, output_type=output_type)  # abce\n            return tf.transpose(result, [0, 2, 1, 3])\n        elif subscripts == \"abcd,aecd->acbe\":\n            x = tf.transpose(x, [0, 2, 1, 3])  # acbd\n            y = tf.transpose(y, [0, 2, 3, 1])  # acde\n            return tf.matmul(x, y, output_type=output_type)\n        elif subscripts == \"abcd,aecd->aceb\":\n            x = tf.transpose(x, [0, 2, 1, 3])\n            y = tf.transpose(y, [0, 2, 3, 1])\n            result = tf.matmul(x, y, output_type=output_type)  # acbe\n            return tf.transpose(result, [0, 1, 3, 2])\n        elif subscripts == \"abcd,cde->abe\":\n            _, b1, c1, d1 = x.shape\n            c2, d2, e2 = y.shape\n            b, c, d, e = b1, c1 or c2, d1 or d2, e2\n            x = tf.reshape(x, [-1, b, c * d])\n            y = tf.reshape(y, [-1, e])\n            return tf.matmul(x, y, output_type=output_type)\n        elif subscripts == \"abcd,ced->abe\":\n            _, b1, c1, d1 = x.shape\n            c2, e2, d2 = y.shape\n            b, c, d, e = b1, c1 or c2, d1 or d2, e2\n            x = tf.reshape(x, [-1, b, c * d])\n            y = tf.transpose(y, [0, 2, 1])\n            y = tf.reshape(y, [-1, e])\n            return tf.matmul(x, y, output_type=output_type)\n        elif subscripts == \"abcd,ecd->abe\":\n            _, b1, c1, d1 = x.shape\n            e2, c2, d2 = y.shape\n            b, c, d, e = b1, c1 or c2, d1 or d2, e2\n            x = tf.reshape(x, [-1, b, c * d])\n            y = tf.transpose(y, [1, 2, 0])\n            y = tf.reshape(y, [-1, e])\n            return tf.matmul(x, y, output_type=output_type)\n        elif subscripts == \"abcde,aebf->adbcf\":\n            _, b1, c1, d1, e1 = x.shape\n            _, e2, b2, f2 = y.shape\n            b, c, d, e, f = b1 or b2, c1, d1, e1 or e2, f2\n            x = tf.reshape(x, [-1, b, c * d, e])  # ab(cd)e\n            y = tf.transpose(y, [0, 2, 1, 3])  # 
abef\n            result = tf.matmul(x, y, output_type=output_type)  # ab(cd)f\n            result = tf.reshape(result, [-1, b, c, d, f])  # abcdf\n            return tf.transpose(result, [0, 3, 1, 2, 4])\n        elif subscripts == \"abcde,afce->acdbf\":\n            _, b1, c1, d1, e1 = x.shape\n            _, f2, c2, e2 = y.shape\n            b, c, d, e, f = b1, c1 or c2, d1, e1 or e2, f2\n            x = tf.transpose(x, [0, 2, 3, 1, 4])  # acdbe\n            x = tf.reshape(x, [-1, c, d * b, e])  # ac(db)e\n            y = tf.transpose(y, [0, 2, 3, 1])  # acef\n            result = tf.matmul(x, y, output_type=output_type)  # ac(db)f\n            return tf.reshape(result, [-1, c, d, b, f])\n        else:\n            raise NotImplementedError\n\n    dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands))\n    # When operands are of int8, we cast the result to int32 to align with\n    # the behavior of jax.\n    if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == \"int8\":\n        compute_dtype = \"int8\"\n        result_dtype = \"int32\"\n        output_type = \"int32\"\n    else:\n        result_dtype = dtypes.result_type(*dtypes_to_resolve)\n        compute_dtype = result_dtype\n        output_type = None\n\n    # TODO: Remove the condition once `tf.einsum` supports int8xint8->int32\n    if is_valid_for_custom_ops(subscripts, *operands) and not kwargs:\n        # TODO: tf.matmul doesn't support integer dtype if not specifying\n        # output_type=\"int32\"\n        if \"int\" in compute_dtype and output_type is None:\n            compute_dtype = config.floatx()\n        operands = tree.map_structure(\n            lambda x: tf.cast(x, compute_dtype), operands\n        )\n        result = use_custom_ops(subscripts, *operands, output_type=output_type)\n    else:\n        # TODO: tf.einsum doesn't support integer dtype with gpu\n        if \"int\" in compute_dtype:\n            compute_dtype = config.floatx()\n        operands = 
tree.map_structure(\n            lambda x: tf.cast(x, compute_dtype), operands\n        )\n        result = tf.einsum(subscripts, *operands, **kwargs)\n    return tf.cast(result, result_dtype)\n\n\n@sparse.elementwise_binary_union(sparse.sparse_subtract)\ndef subtract(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return tf.subtract(x1, x2)\n\n\ndef matmul(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    x1_shape = x1.shape\n    x2_shape = x2.shape\n    x1_sparse = isinstance(x1, tf.SparseTensor)\n    x2_sparse = isinstance(x2, tf.SparseTensor)\n    # When both x1 and x2 are dense int8 tensors, specify `output_type` as\n    # int32 to enable hardware-accelerated matmul\n    x1_dtype = standardize_dtype(x1.dtype)\n    x2_dtype = standardize_dtype(x2.dtype)\n    if (\n        x1_dtype == \"int8\"\n        and x2_dtype == \"int8\"\n        and not x1_sparse\n        and not x2_sparse\n        and x1_shape.rank != 1  # TODO: support tf.tensordot\n        and x2_shape.rank != 1  # TODO: support tf.tensordot\n    ):\n        compute_dtype = \"int8\"\n        result_dtype = \"int32\"\n        output_type = result_dtype\n    else:\n        # TODO: Typically, GPU and XLA only support float types\n        compute_dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n        result_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n        output_type = None\n    x1 = tf.cast(x1, compute_dtype)\n    x2 = tf.cast(x2, compute_dtype)\n\n    def with_combined_batch_dimensions(a, b, output_shape, fn_3d):\n        a_sparse = isinstance(a, tf.SparseTensor)\n        b_sparse = isinstance(b, tf.SparseTensor)\n        batch_shape 
= b.shape[:-2] if b_sparse else a.shape[:-2]\n        batch_size = math.prod(batch_shape)\n        a3d_shape = [batch_size] + a.shape[-2:]\n        a_3d = (\n            tf.sparse.reshape(a, a3d_shape)\n            if a_sparse\n            else tf.reshape(a, a3d_shape)\n        )\n        b3d_shape = [batch_size] + b.shape[-2:]\n        b_3d = (\n            tf.sparse.reshape(b, b3d_shape)\n            if b_sparse\n            else tf.reshape(b, b3d_shape)\n        )\n        result_3d = fn_3d(a_3d, b_3d)\n        return (\n            tf.sparse.reshape(result_3d, output_shape)\n            if isinstance(result_3d, tf.SparseTensor)\n            else tf.reshape(result_3d, output_shape)\n        )\n\n    def sparse_sparse_matmul(a, b):\n        dtype = a.values.dtype\n        # Convert SparseTensors to CSR SparseMatrix.\n        a_csr = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(\n            a.indices, a.values, a.dense_shape\n        )\n        b_csr = sparse_csr_matrix_ops.sparse_tensor_to_csr_sparse_matrix(\n            b.indices, b.values, b.dense_shape\n        )\n        # Compute the CSR SparseMatrix matrix multiplication.\n        result_csr = sparse_csr_matrix_ops.sparse_matrix_sparse_mat_mul(\n            a_csr, b_csr, dtype\n        )\n        # Convert the CSR SparseMatrix to a SparseTensor.\n        res = sparse_csr_matrix_ops.csr_sparse_matrix_to_sparse_tensor(\n            result_csr, dtype\n        )\n        return tf.SparseTensor(res.indices, res.values, res.dense_shape)\n\n    def embedding_lookup_sparse_dense_matmul(a, b):\n        # We need at least one id per row for embedding_lookup_sparse,\n        # otherwise there will be missing rows in the output.\n        a, _ = tf.sparse.fill_empty_rows(a, 0)\n        # We need to split x1 into separate ids and weights tensors. The ids\n        # should be the column indices of x1 and the values of the weights\n        # can continue to be the actual x1. 
The column arrangement of ids\n        # and weights does not matter as we sum over columns. See details in\n        # the documentation for sparse_ops.sparse_tensor_dense_matmul.\n        ids = tf.SparseTensor(\n            indices=a.indices,\n            values=a.indices[:, 1],\n            dense_shape=a.dense_shape,\n        )\n        return tf.nn.embedding_lookup_sparse(b, ids, a, combiner=\"sum\")\n\n    # Either a or b is sparse\n    def sparse_dense_matmul_3d(a, b):\n        return tf.map_fn(\n            lambda x: tf.sparse.sparse_dense_matmul(x[0], x[1]),\n            elems=(a, b),\n            fn_output_signature=a.dtype,\n        )\n\n    if x1_sparse or x2_sparse:\n        from keras.src.ops.operation_utils import compute_matmul_output_shape\n\n        output_shape = compute_matmul_output_shape(x1_shape, x2_shape)\n        if x1_sparse and x2_sparse:\n            if x1_shape.rank <= 3:\n                output = sparse_sparse_matmul(x1, x2)\n            else:\n                output = with_combined_batch_dimensions(\n                    x1, x2, output_shape, sparse_sparse_matmul\n                )\n        else:\n            # Sparse * dense or dense * sparse\n            sparse_rank = x1_shape.rank if x1_sparse else x2_shape.rank\n\n            # Special case: embedding_lookup_sparse for sparse * dense, rank 2\n            if x1_sparse and sparse_rank == 2:\n                output = embedding_lookup_sparse_dense_matmul(x1, x2)\n            elif sparse_rank == 2:\n                output = tf.sparse.sparse_dense_matmul(x1, x2)\n            elif sparse_rank == 3:\n                output = sparse_dense_matmul_3d(x1, x2)\n            else:\n                output = with_combined_batch_dimensions(\n                    x1, x2, output_shape, sparse_dense_matmul_3d\n                )\n        output = tf.cast(output, result_dtype)\n        output.set_shape(output_shape)\n        return output\n    else:\n        if x1_shape.rank == 2 and x2_shape.rank == 2:\n   
         output = tf.matmul(x1, x2, output_type=output_type)\n        elif x2_shape.rank == 1:\n            output = tf.tensordot(x1, x2, axes=1)\n        elif x1_shape.rank == 1:\n            output = tf.tensordot(x1, x2, axes=[[0], [-2]])\n        else:\n            output = tf.matmul(x1, x2, output_type=output_type)\n        return tf.cast(output, result_dtype)\n\n\n@sparse.elementwise_binary_intersection\ndef multiply(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return tf.multiply(x1, x2)\n\n\ndef mean(x, axis=None, keepdims=False):\n    if isinstance(x, tf.IndexedSlices):\n        if axis is None:\n            # Reduce against all axes, result is a single value and dense.\n            # The denominator has to account for `dense_shape`.\n            sum = tf.reduce_sum(x.values, keepdims=keepdims)\n            return sum / tf.cast(tf.reduce_prod(x.dense_shape), dtype=sum.dtype)\n\n        axis = to_tuple_or_list(axis)\n        if not axis:\n            # Empty axis tuple, this is a no-op\n            return x\n\n        dense_shape = tf.convert_to_tensor(x.dense_shape)\n        rank = tf.shape(dense_shape)[0]\n        # Normalize axis: convert negative values and sort\n        axis = [canonicalize_axis(a, rank) for a in axis]\n        axis.sort()\n\n        if axis == [0]:\n            # Reduce against `axis=0` only, result is dense.\n            # The denominator has to account for `dense_shape[0]`.\n            sum = tf.reduce_sum(x.values, axis=0, keepdims=keepdims)\n            return sum / tf.cast(dense_shape[0], dtype=sum.dtype)\n        elif axis[0] == 0:\n            # Reduce against axis 0 and other axes, result is dense.\n       
     # We do `axis=0` separately first. The denominator has to account\n            # for `dense_shape[0]`.\n            # We use `keepdims=True` in `reduce_sum` so that we can leave the\n            # 0 in axis and do `reduce_mean` with `keepdims` to apply it for all\n            # axes.\n            sum = tf.reduce_sum(x.values, axis=0, keepdims=True)\n            axis_0_mean = sum / tf.cast(dense_shape[0], dtype=sum.dtype)\n            return tf.reduce_mean(axis_0_mean, axis=axis, keepdims=keepdims)\n        elif keepdims:\n            # With `keepdims=True`, result is an `IndexedSlices` with the same\n            # indices since axis 0 is not touched. The only thing to do is to\n            # correct `dense_shape` to account for dimensions that became 1.\n            new_values = tf.reduce_mean(x.values, axis=axis, keepdims=True)\n            new_dense_shape = tf.concat(\n                [dense_shape[0:1], new_values.shape[1:]], axis=0\n            )\n            return tf.IndexedSlices(new_values, x.indices, new_dense_shape)\n        elif rank == len(axis) + 1:\n            # `keepdims=False` and reducing against all axes except 0, result is\n            # a 1D tensor, which cannot be `IndexedSlices`. We have to scatter\n            # the computed means to construct the correct dense tensor.\n            return tf.scatter_nd(\n                tf.expand_dims(x.indices, axis=1),\n                tf.reduce_mean(x.values, axis=axis),\n                [dense_shape[0]],\n            )\n        else:\n            # `keepdims=False`, not reducing against axis 0 and there is at\n            # least one other axis we are not reducing against. 
We simply need\n            # to fix `dense_shape` to remove dimensions that were reduced.\n            gather_indices = [i for i in range(rank) if i not in axis]\n            return tf.IndexedSlices(\n                tf.reduce_mean(x.values, axis=axis),\n                x.indices,\n                tf.gather(x.dense_shape, gather_indices, axis=0),\n            )\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n    # `tf.reduce_mean` does not handle low precision (e.g., float16) overflow\n    # correctly, so we compute with float32 and cast back to the original type.\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        result_dtype = compute_dtype\n    else:\n        result_dtype = ori_dtype\n    output = tf.reduce_mean(\n        tf.cast(x, compute_dtype), axis=axis, keepdims=keepdims\n    )\n    return tf.cast(output, result_dtype)\n\n\ndef max(x, axis=None, keepdims=False, initial=None):\n    x = convert_to_tensor(x)\n\n    # The TensorFlow numpy API implementation doesn't support `initial` so we\n    # handle it manually here.\n    if initial is not None:\n        if standardize_dtype(x.dtype) == \"bool\":\n            x = tf.reduce_any(x, axis=axis, keepdims=keepdims)\n            x = tf.math.maximum(tf.cast(x, \"int32\"), tf.cast(initial, \"int32\"))\n            return tf.cast(x, \"bool\")\n        else:\n            x = tf.reduce_max(x, axis=axis, keepdims=keepdims)\n            return tf.math.maximum(x, initial)\n\n    # TensorFlow returns -inf by default for an empty list, but for consistency\n    # with other backends and the numpy API we want to throw in this case.\n    if tf.executing_eagerly():\n        size_x = size(x)\n        tf.assert_greater(\n            size_x,\n            tf.constant(0, dtype=size_x.dtype),\n            message=\"Cannot compute the max of an empty tensor.\",\n        )\n\n    if standardize_dtype(x.dtype) == \"bool\":\n        
return tf.reduce_any(x, axis=axis, keepdims=keepdims)\n    else:\n        return tf.reduce_max(x, axis=axis, keepdims=keepdims)\n\n\ndef ones(shape, dtype=None):\n    dtype = dtype or config.floatx()\n    return tf.ones(shape, dtype=dtype)\n\n\ndef zeros(shape, dtype=None):\n    dtype = dtype or config.floatx()\n    return tf.zeros(shape, dtype=dtype)\n\n\n@sparse.elementwise_unary\ndef absolute(x):\n    x = convert_to_tensor(x)\n    # uintx and bool are always non-negative\n    dtype = standardize_dtype(x.dtype)\n    if \"uint\" in dtype or dtype == \"bool\":\n        return x\n    return tf.abs(x)\n\n\ndef abs(x):\n    return absolute(x)\n\n\ndef all(x, axis=None, keepdims=False):\n    x = tf.cast(x, \"bool\")\n    return tf.reduce_all(x, axis=axis, keepdims=keepdims)\n\n\ndef angle(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.angle(x)\n\n\ndef any(x, axis=None, keepdims=False):\n    x = tf.cast(x, \"bool\")\n    return tf.reduce_any(x, axis=axis, keepdims=keepdims)\n\n\ndef amax(x, axis=None, keepdims=False):\n    return max(x, axis=axis, keepdims=keepdims)\n\n\ndef amin(x, axis=None, keepdims=False):\n    return min(x, axis=axis, keepdims=keepdims)\n\n\ndef append(x1, x2, axis=None):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    if axis is None:\n        return tf.concat([tf.reshape(x1, [-1]), tf.reshape(x2, [-1])], axis=0)\n    else:\n        return tf.concat([x1, x2], axis=axis)\n\n\ndef arange(start, stop=None, step=None, dtype=None):\n    if dtype is None:\n        dtypes_to_resolve = [getattr(start, \"dtype\", type(start))]\n        if stop is not None:\n            dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n        if step is 
not None:\n            dtypes_to_resolve.append(getattr(step, \"dtype\", type(step)))\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n    dtype = standardize_dtype(dtype)\n    if step is None:\n        step = 1\n    try:\n        out = tf.range(start, stop, delta=step, dtype=dtype)\n    except tf.errors.NotFoundError:\n        # Some dtypes may not work in eager mode on CPU or GPU.\n        out = tf.range(start, stop, delta=step, dtype=\"float32\")\n        out = tf.cast(out, dtype)\n    return out\n\n\n@sparse.densifying_unary(0.5 * np.pi)\ndef arccos(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.acos(x)\n\n\n@sparse.densifying_unary(np.nan)\ndef arccosh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.acosh(x)\n\n\n@sparse.elementwise_unary\ndef arcsin(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.asin(x)\n\n\n@sparse.elementwise_unary\ndef arcsinh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.asinh(x)\n\n\n@sparse.elementwise_unary\ndef arctan(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.atan(x)\n\n\ndef arctan2(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = 
convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    return tf.math.atan2(x1, x2)\n\n\n@sparse.elementwise_unary\ndef arctanh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.atanh(x)\n\n\ndef _keepdims(x, y, axis):\n    if axis is None:\n        shape = [1 for _ in range(len(x.shape))]\n    else:\n        shape = list(shape_op(x))\n        for axis in tree.flatten(axis):\n            shape[axis] = 1\n    y = tf.reshape(y, shape)\n    return y\n\n\ndef argmax(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    if \"float\" not in dtype or x.ndim == 0:\n        _x = x\n        if axis is None:\n            x = tf.reshape(x, [-1])\n        y = tf.argmax(x, axis=axis, output_type=\"int32\")\n        if keepdims:\n            y = _keepdims(_x, y, axis)\n        return y\n\n    # Fix the flush-to-zero (FTZ) issue based on this issue:\n    # https://github.com/jax-ml/jax/issues/24280\n    dtype = dtypes.result_type(dtype, \"float32\")\n    x = cast(x, dtype)\n    is_negative_zero = tf.logical_and(tf.equal(x, 0.0), signbit(x))\n    x = tf.where(\n        is_negative_zero, -np.finfo(standardize_dtype(x.dtype)).tiny, x\n    )\n    _x = x\n    if axis is None:\n        x = tf.reshape(x, [-1])\n    y = tf.argmax(x, axis=axis, output_type=\"int32\")\n    if keepdims:\n        y = _keepdims(_x, y, axis)\n    return y\n\n\ndef argmin(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    if \"float\" not in dtype or x.ndim == 0:\n        _x = x\n        if axis is None:\n            x = tf.reshape(x, [-1])\n        y = tf.argmin(x, axis=axis, output_type=\"int32\")\n        if keepdims:\n            y = 
_keepdims(_x, y, axis)\n        return y\n\n    # Fix the flush-to-zero (FTZ) issue based on this issue:\n    # https://github.com/jax-ml/jax/issues/24280\n    dtype = dtypes.result_type(dtype, \"float32\")\n    x = cast(x, dtype)\n    is_negative_zero = tf.logical_and(tf.equal(x, 0.0), signbit(x))\n    x = tf.where(\n        is_negative_zero, -np.finfo(standardize_dtype(x.dtype)).tiny, x\n    )\n    _x = x\n    if axis is None:\n        x = tf.reshape(x, [-1])\n    y = tf.argmin(x, axis=axis, output_type=\"int32\")\n    if keepdims:\n        y = _keepdims(_x, y, axis)\n    return y\n\n\ndef argsort(x, axis=-1):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"bool\":\n        x = tf.cast(x, \"uint8\")\n\n    x_shape = x.shape\n    if x_shape.rank == 0:\n        return tf.cast([0], \"int32\")\n\n    if axis is None:\n        x = tf.reshape(x, [-1])\n        axis = 0\n    return tf.argsort(x, axis=axis)\n\n\ndef array(x, dtype=None):\n    return convert_to_tensor(x, dtype=dtype)\n\n\ndef view(x, dtype=None):\n    from keras.src import backend\n\n    x = convert_to_tensor(x)\n    old_dtype = tf.as_dtype(backend.standardize_dtype(x.dtype))\n    new_dtype = tf.as_dtype(\n        backend.standardize_dtype(dtype if dtype else x.dtype)\n    )\n\n    old_itemsize = old_dtype.size\n    new_itemsize = new_dtype.size\n\n    old_shape = list(shape_op(x))\n    last_dim_size = old_shape[-1] if len(old_shape) > 0 else -1\n    if (last_dim_size == -1 and old_itemsize != new_itemsize) or (\n        last_dim_size * old_itemsize % new_itemsize != 0\n    ):\n        raise ValueError(\n            f\"Cannot view array of shape {x.shape} and dtype {old_dtype} \"\n            f\"as dtype {new_dtype} because the total number of bytes \"\n            f\"is not divisible by the new itemsize.\"\n        )\n\n    if old_itemsize == new_itemsize:\n        return tf.bitcast(x, type=new_dtype)\n    elif old_itemsize > new_itemsize:\n        ratio = old_itemsize // 
new_itemsize\n        new_shape = list(shape_op(x))\n        new_shape[-1] *= ratio\n        flat_tensor = tf.reshape(x, [-1])\n        cast_tensor = tf.bitcast(flat_tensor, type=new_dtype)\n        return tf.reshape(cast_tensor, new_shape)\n    else:\n        ratio = new_itemsize // old_itemsize\n        if isinstance(last_dim_size, int) and last_dim_size % ratio != 0:\n            raise ValueError(\n                f\"Cannot view dtype. Last dimension size ({last_dim_size}) \"\n                f\"must be divisible by the ratio of new/old item sizes \"\n                f\"({ratio}).\"\n            )\n        intermediate_shape = old_shape[:-1] + [last_dim_size // ratio, ratio]\n        reshaped_tensor = tf.reshape(x, intermediate_shape)\n        return tf.bitcast(reshaped_tensor, new_dtype)\n\n\ndef average(x, axis=None, weights=None):\n    x = convert_to_tensor(x)\n\n    if weights is None:  # Treat all weights as 1\n        dtype = dtypes.result_type(x.dtype, float)\n        x = tf.cast(x, dtype)\n        avg = tf.reduce_mean(x, axis=axis)\n    else:\n        weights = convert_to_tensor(weights)\n        dtype = dtypes.result_type(x.dtype, weights.dtype, float)\n        x = tf.cast(x, dtype)\n        weights = tf.cast(weights, dtype)\n\n        def _rank_equal_case():\n            weights_sum = tf.reduce_sum(weights, axis=axis)\n            return tf.reduce_sum(x * weights, axis=axis) / weights_sum\n\n        def _rank_not_equal_case():\n            weights_sum = tf.reduce_sum(weights)\n            axes = tf.convert_to_tensor([[axis], [0]])\n            return tf.tensordot(x, weights, axes) / weights_sum\n\n        if axis is None:\n            avg = _rank_equal_case()\n        else:\n            if len(x.shape) == len(weights.shape):\n                avg = _rank_equal_case()\n            else:\n                avg = _rank_not_equal_case()\n    return avg\n\n\ndef bitwise_and(x, y):\n    x = convert_to_tensor(x)\n    y = convert_to_tensor(y)\n    dtype = 
dtypes.result_type(x.dtype, y.dtype)\n    x = tf.cast(x, dtype)\n    y = tf.cast(y, dtype)\n    return tf.bitwise.bitwise_and(x, y)\n\n\ndef bitwise_invert(x):\n    x = convert_to_tensor(x)\n    return tf.bitwise.invert(x)\n\n\ndef bitwise_not(x):\n    return bitwise_invert(x)\n\n\ndef bitwise_or(x, y):\n    x = convert_to_tensor(x)\n    y = convert_to_tensor(y)\n    dtype = dtypes.result_type(x.dtype, y.dtype)\n    x = tf.cast(x, dtype)\n    y = tf.cast(y, dtype)\n    return tf.bitwise.bitwise_or(x, y)\n\n\ndef bitwise_xor(x, y):\n    x = convert_to_tensor(x)\n    y = convert_to_tensor(y)\n    dtype = dtypes.result_type(x.dtype, y.dtype)\n    x = tf.cast(x, dtype)\n    y = tf.cast(y, dtype)\n    return tf.bitwise.bitwise_xor(x, y)\n\n\ndef bitwise_left_shift(x, y):\n    x = convert_to_tensor(x)\n    if not isinstance(y, int):\n        y = convert_to_tensor(y)\n        dtype = dtypes.result_type(x.dtype, y.dtype)\n        x = tf.cast(x, dtype)\n        y = tf.cast(y, dtype)\n    return tf.bitwise.left_shift(x, y)\n\n\ndef left_shift(x, y):\n    return bitwise_left_shift(x, y)\n\n\ndef bitwise_right_shift(x, y):\n    x = convert_to_tensor(x)\n    if not isinstance(y, int):\n        y = convert_to_tensor(y)\n        dtype = dtypes.result_type(x.dtype, y.dtype)\n        x = tf.cast(x, dtype)\n        y = tf.cast(y, dtype)\n    return tf.bitwise.right_shift(x, y)\n\n\ndef right_shift(x, y):\n    return bitwise_right_shift(x, y)\n\n\ndef blackman(x):\n    dtype = config.floatx()\n    x = tf.cast(x, dtype)\n    n = tf.range(x, dtype=dtype)\n    n_minus_1 = tf.cast(x - 1, dtype)\n    term1 = 0.42\n    term2 = -0.5 * tf.cos(2 * np.pi * n / n_minus_1)\n    term3 = 0.08 * tf.cos(4 * np.pi * n / n_minus_1)\n    window = term1 + term2 + term3\n    return window\n\n\ndef broadcast_to(x, shape):\n    return tf.broadcast_to(x, shape)\n\n\ndef cbrt(x):\n    x = convert_to_tensor(x)\n\n    dtype = standardize_dtype(x.dtype)\n    if dtype == \"int64\":\n        x = tf.cast(x, 
\"float64\")\n    elif dtype not in [\"bfloat16\", \"float16\", \"float64\"]:\n        x = tf.cast(x, config.floatx())\n\n    return tf.sign(x) * tf.pow(tf.abs(x), 1.0 / 3.0)\n\n\n@sparse.elementwise_unary\ndef ceil(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.ceil(x)\n\n\ndef clip(x, x_min, x_max):\n    dtype = standardize_dtype(x.dtype)\n    if dtype == \"bool\":\n        x = tf.cast(x, \"int32\")\n    return tf.clip_by_value(x, x_min, x_max)\n\n\ndef concatenate(xs, axis=0):\n    sparse_count = builtins.sum(isinstance(x, tf.SparseTensor) for x in xs)\n    if sparse_count:\n        if sparse_count == len(xs):\n            return tf.sparse.concat(axis=axis, sp_inputs=xs)\n        else:\n            xs = [\n                (\n                    convert_to_tensor(x, sparse=False)\n                    if isinstance(x, tf.SparseTensor)\n                    else x\n                )\n                for x in xs\n            ]\n    xs = tree.map_structure(convert_to_tensor, xs)\n    dtype_set = set([x.dtype for x in xs])\n    if len(dtype_set) > 1:\n        dtype = dtypes.result_type(*dtype_set)\n        xs = tree.map_structure(lambda x: tf.cast(x, dtype), xs)\n    return tf.concat(xs, axis=axis)\n\n\n@sparse.elementwise_unary\ndef conjugate(x):\n    return tf.math.conj(x)\n\n\n@sparse.elementwise_unary\ndef conj(x):\n    return tf.math.conj(x)\n\n\n@sparse.elementwise_unary\ndef copy(x):\n    x = convert_to_tensor(x)\n    return tf.identity(x)\n\n\n@sparse.densifying_unary(1)\ndef cos(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.cos(x)\n\n\n@sparse.densifying_unary(1)\ndef cosh(x):\n    x = convert_to_tensor(x)\n 
   if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.cosh(x)\n\n\ndef count_nonzero(x, axis=None):\n    return tf.math.count_nonzero(x, axis=axis, dtype=\"int32\")\n\n\ndef cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n\n    if axis is not None:\n        axisa = axis\n        axisb = axis\n        axisc = axis\n    x1 = moveaxis(x1, axisa, -1)\n    x2 = moveaxis(x2, axisb, -1)\n\n    def maybe_pad_zeros(x, size_of_last_dim):\n        def pad_zeros(x):\n            return tf.pad(\n                x,\n                tf.concat(\n                    [\n                        tf.zeros([tf.rank(x) - 1, 2], \"int32\"),\n                        tf.constant([[0, 1]], \"int32\"),\n                    ],\n                    axis=0,\n                ),\n            )\n\n        if isinstance(size_of_last_dim, int):\n            if size_of_last_dim == 2:\n                return pad_zeros(x)\n            return x\n\n        return tf.cond(\n            tf.equal(size_of_last_dim, 2), lambda: pad_zeros(x), lambda: x\n        )\n\n    x1_dim = shape_op(x1)[-1]\n    x2_dim = shape_op(x2)[-1]\n\n    x1 = maybe_pad_zeros(x1, x1_dim)\n    x2 = maybe_pad_zeros(x2, x2_dim)\n\n    # Broadcast each other\n    shape = shape_op(x1)\n\n    shape = tf.broadcast_dynamic_shape(shape, shape_op(x2))\n    x1 = tf.broadcast_to(x1, shape)\n    x2 = tf.broadcast_to(x2, shape)\n\n    c = tf.linalg.cross(x1, x2)\n\n    if isinstance(x1_dim, int) and isinstance(x2_dim, int):\n        if (x1_dim == 2) & (x2_dim == 2):\n            return c[..., 2]\n        return moveaxis(c, -1, axisc)\n\n    return tf.cond(\n        (x1_dim == 2) & (x2_dim == 2),\n        lambda: c[..., 2],\n        
lambda: moveaxis(c, -1, axisc),\n    )\n\n\ndef cumprod(x, axis=None, dtype=None):\n    x = convert_to_tensor(x, dtype=dtype)\n    # tf.math.cumprod doesn't support bool\n    if standardize_dtype(x.dtype) == \"bool\":\n        x = tf.cast(x, \"int32\")\n    if axis is None:\n        x = tf.reshape(x, [-1])\n        axis = 0\n    return tf.math.cumprod(x, axis=axis)\n\n\ndef cumsum(x, axis=None, dtype=None):\n    x = convert_to_tensor(x, dtype=dtype)\n    # tf.math.cumsum doesn't support bool\n    if standardize_dtype(x.dtype) == \"bool\":\n        x = tf.cast(x, \"int32\")\n    if axis is None:\n        x = tf.reshape(x, [-1])\n        axis = 0\n    return tf.math.cumsum(x, axis=axis)\n\n\ndef deg2rad(x):\n    x = convert_to_tensor(x)\n\n    dtype = x.dtype\n    if standardize_dtype(dtype) in [\n        \"bool\",\n        \"int8\",\n        \"int16\",\n        \"int32\",\n        \"uint8\",\n        \"uint16\",\n        \"uint32\",\n    ]:\n        dtype = config.floatx()\n    elif standardize_dtype(dtype) in [\"int64\"]:\n        dtype = \"float64\"\n    x = tf.cast(x, dtype)\n\n    pi = tf.constant(math.pi, dtype=dtype)\n    return x * (pi / tf.constant(180.0, dtype=dtype))\n\n\ndef diag(x, k=0):\n    x = convert_to_tensor(x)\n    if len(x.shape) == 1:\n        return tf.linalg.diag(x, k=k)\n    elif len(x.shape) == 2:\n        return diagonal(x, offset=k)\n    else:\n        raise ValueError(f\"`x` must be 1d or 2d. 
Received: x.shape={x.shape}\")\n\n\ndef diagflat(x, k=0):\n    x = convert_to_tensor(x)\n    return diag(tf.reshape(x, [-1]), k)\n\n\ndef diagonal(x, offset=0, axis1=0, axis2=1):\n    x = convert_to_tensor(x)\n    x_rank = x.ndim\n    if (\n        offset == 0\n        and (axis1 == x_rank - 2 or axis1 == -2)\n        and (axis2 == x_rank - 1 or axis2 == -1)\n    ):\n        return tf.linalg.diag_part(x)\n\n    x = moveaxis(x, (axis1, axis2), (-2, -1))\n    x_shape = shape_op(x)\n\n    def _zeros():\n        return tf.zeros(tf.concat([x_shape[:-1], [0]], 0), dtype=x.dtype)\n\n    if isinstance(x_shape[-1], int) and isinstance(x_shape[-2], int):\n        if offset <= -1 * x_shape[-2] or offset >= x_shape[-1]:\n            x = _zeros()\n    else:\n        x = tf.cond(\n            tf.logical_or(\n                tf.less_equal(offset, -1 * x_shape[-2]),\n                tf.greater_equal(offset, x_shape[-1]),\n            ),\n            lambda: _zeros(),\n            lambda: x,\n        )\n    return tf.linalg.diag_part(x, k=offset)\n\n\ndef diff(a, n=1, axis=-1):\n    a = convert_to_tensor(a)\n    if n == 0:\n        return a\n    elif n < 0:\n        raise ValueError(f\"Order `n` must be non-negative. Received n={n}\")\n    elif a.ndim == 0:\n        raise ValueError(\n            \"`diff` requires input that is at least one dimensional. 
\"\n            f\"Received: a={a}\"\n        )\n    axis = canonicalize_axis(axis, a.ndim)\n    slice1 = [slice(None)] * a.ndim\n    slice2 = [slice(None)] * a.ndim\n    slice1[axis] = slice(1, None)\n    slice2[axis] = slice(None, -1)\n    slice1_tuple = tuple(slice1)\n    slice2_tuple = tuple(slice2)\n    for _ in range(n):\n        if standardize_dtype(a.dtype) == \"bool\":\n            a = tf.not_equal(a[slice1_tuple], a[slice2_tuple])\n        else:\n            a = tf.subtract(a[slice1_tuple], a[slice2_tuple])\n    return a\n\n\ndef digitize(x, bins):\n    x = convert_to_tensor(x)\n    bins = list(bins)\n\n    # bins must be float type\n    bins = tree.map_structure(lambda x: float(x), bins)\n\n    # TODO: tf.raw_ops.Bucketize doesn't support bool, bfloat16, float16, int8\n    # int16, uint8, uint16, uint32\n    ori_dtype = standardize_dtype(x.dtype)\n    if ori_dtype in (\"bool\", \"int8\", \"int16\", \"uint8\", \"uint16\"):\n        x = cast(x, \"int32\")\n    elif ori_dtype == \"uint32\":\n        x = cast(x, \"int64\")\n    elif ori_dtype in (\"bfloat16\", \"float16\"):\n        x = cast(x, \"float32\")\n\n    if isinstance(x, tf.RaggedTensor):\n        return tf.ragged.map_flat_values(\n            lambda y: tf.raw_ops.Bucketize(input=y, boundaries=bins), x\n        )\n    elif isinstance(x, tf.SparseTensor):\n        output = tf.SparseTensor(\n            indices=tf.identity(x.indices),\n            values=tf.raw_ops.Bucketize(input=x.values, boundaries=bins),\n            dense_shape=tf.identity(x.dense_shape),\n        )\n        output.set_shape(x.shape)\n        return output\n    return tf.raw_ops.Bucketize(input=x, boundaries=bins)\n\n\ndef dot(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    result_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    # GPU only supports float types\n    compute_dtype = dtypes.result_type(result_dtype, float)\n    x1 = tf.cast(x1, compute_dtype)\n    x2 = tf.cast(x2, compute_dtype)\n\n 
   x_shape = x1.shape\n    y_shape = x2.shape\n    if x_shape.rank == 0 or y_shape.rank == 0:\n        output = x1 * x2\n    elif y_shape.rank == 1:\n        output = tf.tensordot(x1, x2, axes=[[-1], [-1]])\n    else:\n        output = tf.tensordot(x1, x2, axes=[[-1], [-2]])\n    return tf.cast(output, result_dtype)\n\n\ndef dstack(xs):\n    xs = [convert_to_tensor(x) for x in xs]\n    if len(xs) > 1:\n        unique_dtypes = {x.dtype for x in xs}\n        if len(unique_dtypes) > 1:\n            dtype = dtypes.result_type(*[x.dtype for x in xs])\n            xs = [cast(x, dtype) for x in xs]\n    xs_reshaped = []\n    for x in xs:\n        shape = x.shape\n        if len(shape) == 0:\n            x = tf.reshape(x, (1, 1, 1))\n        elif len(shape) == 1:\n            x = tf.expand_dims(x, axis=0)\n            x = tf.expand_dims(x, axis=2)\n        elif len(shape) == 2:\n            x = tf.expand_dims(x, axis=2)\n        xs_reshaped.append(x)\n    return tf.concat(xs_reshaped, axis=2)\n\n\ndef empty(shape, dtype=None):\n    dtype = dtype or config.floatx()\n    return tf.zeros(shape, dtype=dtype)\n\n\ndef empty_like(x, dtype=None):\n    return tf.zeros_like(x, dtype=dtype)\n\n\ndef equal(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    return tf.equal(x1, x2)\n\n\n@sparse.densifying_unary(1)\ndef exp(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = tf.cast(x, config.floatx())\n    return tf.exp(x)\n\n\n@sparse.densifying_unary(1)\ndef exp2(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = tf.cast(x, config.floatx())\n    return tf.math.pow(2.0, x)\n\n\ndef expand_dims(x, axis):\n    x = convert_to_tensor(x)\n    axis = 
to_tuple_or_list(axis)\n    out_ndim = len(x.shape) + len(axis)\n    axis = sorted([canonicalize_axis(a, out_ndim) for a in axis])\n    if isinstance(x, tf.SparseTensor):\n        from keras.src.ops.operation_utils import (\n            compute_expand_dims_output_shape,\n        )\n\n        output_shape = compute_expand_dims_output_shape(x.shape, axis)\n        for a in axis:\n            x = tf.sparse.expand_dims(x, a)\n        x.set_shape(output_shape)\n        return x\n    for a in axis:\n        x = tf.expand_dims(x, a)\n    return x\n\n\n@sparse.elementwise_unary\ndef expm1(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = tf.cast(x, config.floatx())\n    return tf.math.expm1(x)\n\n\ndef flip(x, axis=None):\n    x = convert_to_tensor(x)\n    if axis is None:\n        return tf.reverse(x, tf.range(tf.rank(x)))\n    return tf.reverse(x, [axis])\n\n\n@sparse.elementwise_unary\ndef floor(x):\n    x = convert_to_tensor(x)\n    dtype = (\n        config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    x = tf.cast(x, dtype)\n    return tf.floor(x)\n\n\ndef full(shape, fill_value, dtype=None):\n    dtype = dtype or config.floatx()\n    fill_value = convert_to_tensor(fill_value, dtype)\n    return tf.broadcast_to(fill_value, shape)\n\n\ndef full_like(x, fill_value, dtype=None):\n    x = convert_to_tensor(x)\n    dtype = dtypes.result_type(dtype or x.dtype)\n    fill_value = convert_to_tensor(fill_value, dtype)\n    return tf.broadcast_to(fill_value, tf.shape(x))\n\n\ndef gcd(x1, x2):\n    x1 = tf.convert_to_tensor(x1)\n    x2 = tf.convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n\n    if not x1.dtype.is_integer:\n        raise TypeError(\"Arguments to gcd must be integers.\")\n\n    target_shape = 
tf.broadcast_static_shape(x1.shape, x2.shape)\n    x1 = tf.broadcast_to(x1, target_shape)\n    x2 = tf.broadcast_to(x2, target_shape)\n\n    def cond(a, b):\n        return tf.reduce_any(b != 0)\n\n    def body(a, b):\n        b_safe = tf.where(tf.equal(b, 0), tf.ones_like(b), b)\n        return (\n            tf.where(tf.not_equal(b, 0), b, a),\n            tf.where(\n                tf.not_equal(b, 0),\n                tf.math.floormod(a, b_safe),\n                tf.zeros_like(b),\n            ),\n        )\n\n    if dtype not in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]:\n        x1 = tf.abs(x1)\n        x2 = tf.abs(x2)\n\n    gcd_val, _ = tf.while_loop(cond, body, [x1, x2])\n    return gcd_val\n\n\ndef geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):\n    start = convert_to_tensor(start)\n    stop = convert_to_tensor(stop)\n    log_start = tf.math.log(tf.cast(tf.abs(start), dtype or config.floatx()))\n    log_stop = tf.math.log(tf.cast(tf.abs(stop), dtype or config.floatx()))\n    log_base = tf.math.log(tf.constant(10.0, dtype=log_start.dtype))\n    result = logspace(\n        log_start / log_base,\n        log_stop / log_base,\n        num=num,\n        endpoint=endpoint,\n        base=10,\n        dtype=dtype,\n        axis=axis,\n    )\n    # Handle sign: start and stop must have the same sign (or be zero)\n    start_sign = tf.cast(tf.sign(tf.cast(start, log_start.dtype)), result.dtype)\n    return result * start_sign\n\n\ndef greater(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    return tf.greater(x1, x2)\n\n\ndef greater_equal(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    return tf.greater_equal(x1, x2)\n\n\ndef hstack(xs):\n    dtype_set = set([getattr(x, 
\"dtype\", type(x)) for x in xs])\n    if len(dtype_set) > 1:\n        dtype = dtypes.result_type(*dtype_set)\n        xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs)\n    if len(xs[0].shape) == 1:\n        return tf.concat(xs, axis=0)\n    return tf.concat(xs, axis=1)\n\n\ndef hsplit(x, indices_or_sections):\n    x = convert_to_tensor(x)\n    if x.ndim == 1:\n        return split(x, indices_or_sections, axis=0)\n    return split(x, indices_or_sections, axis=1)\n\n\ndef hypot(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype in [\"int8\", \"int16\", \"int32\", \"uint8\", \"uint16\", \"uint32\"]:\n        dtype = config.floatx()\n    elif dtype in [\"int64\"]:\n        dtype = \"float64\"\n\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n\n    x1_abs = tf.abs(x1)\n    x2_abs = tf.abs(x2)\n    max_val = tf.maximum(x1_abs, x2_abs)\n    min_val = tf.minimum(x1_abs, x2_abs)\n\n    ratio = tf.math.divide_no_nan(min_val, max_val)\n    return max_val * tf.sqrt(1.0 + tf.square(ratio))\n\n\ndef identity(n, dtype=None):\n    return eye(N=n, M=n, dtype=dtype)\n\n\n@sparse.elementwise_unary\ndef imag(x):\n    return tf.math.imag(x)\n\n\ndef isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    if \"float\" in dtype:\n        result = tf.abs(x1 - x2) <= (atol + rtol * tf.abs(x2))\n        if equal_nan:\n            result = result | (is_nan(x1) & is_nan(x2))\n        return result\n    else:\n        return tf.equal(x1, x2)\n\n\ndef allclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n\n    if \"float\" in 
standardize_dtype(dtype):\n        finite = tf.math.is_finite(x1) & tf.math.is_finite(x2)\n        close = tf.abs(x1 - x2) <= (atol + rtol * tf.abs(x2))\n        result = (finite & close) | tf.equal(x1, x2)\n        if equal_nan:\n            result = result | (isnan(x1) & isnan(x2))\n        return tf.reduce_all(result)\n    else:\n        return tf.reduce_all(tf.equal(x1, x2))\n\n\n@sparse.densifying_unary(True)\ndef isfinite(x):\n    x = convert_to_tensor(x)\n    dtype_as_dtype = tf.as_dtype(x.dtype)\n    if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric:\n        return tf.ones(x.shape, tf.bool)\n    return tf.math.is_finite(x)\n\n\ndef isin(x1, x2, assume_unique=False, invert=False):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n\n    output_shape = tf.shape(x1)\n\n    x1 = tf.reshape(x1, [-1])\n    x2 = tf.reshape(x2, [-1])\n\n    if not assume_unique:\n        x2 = tf.unique(x2)[0]\n\n    if tf.size(x1) == 0 or tf.size(x2) == 0:\n        return tf.zeros(output_shape, dtype=tf.bool)\n\n    cmp = tf.equal(tf.expand_dims(x1, 1), tf.expand_dims(x2, 0))\n    result_flat = tf.reduce_any(cmp, axis=1)\n\n    if invert:\n        result_flat = tf.logical_not(result_flat)\n\n    return tf.reshape(result_flat, output_shape)\n\n\ndef isinf(x):\n    x = convert_to_tensor(x)\n    dtype_as_dtype = tf.as_dtype(x.dtype)\n    if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric:\n        return tf.zeros(x.shape, tf.bool)\n    return tf.math.is_inf(x)\n\n\ndef isnan(x):\n    x = convert_to_tensor(x)\n    dtype_as_dtype = tf.as_dtype(x.dtype)\n    if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric:\n        return tf.zeros(x.shape, tf.bool)\n    return tf.math.is_nan(x)\n\n\ndef isneginf(x):\n    x = convert_to_tensor(x)\n    dtype_as_dtype = tf.as_dtype(x.dtype)\n    if dtype_as_dtype.is_integer or not 
dtype_as_dtype.is_numeric:\n        return tf.zeros_like(x, dtype=tf.bool)\n    return tf.math.equal(x, -tf.constant(float(\"inf\"), dtype=x.dtype))\n\n\ndef isposinf(x):\n    x = convert_to_tensor(x)\n    dtype_as_dtype = tf.as_dtype(x.dtype)\n    if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric:\n        return tf.zeros_like(x, dtype=tf.bool)\n    return tf.math.equal(x, tf.constant(float(\"inf\"), dtype=x.dtype))\n\n\ndef isreal(x):\n    x = convert_to_tensor(x)\n    if x.dtype.is_complex:\n        return tf.equal(tf.math.imag(x), 0)\n    else:\n        return tf.ones_like(x, dtype=tf.bool)\n\n\ndef kron(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n\n    ndim_x1 = tf.rank(x1)\n    ndim_x2 = tf.rank(x2)\n\n    def expand_front(x, num):\n        for _ in range(num):\n            x = tf.expand_dims(x, axis=0)\n        return x\n\n    x1 = tf.cond(\n        ndim_x1 < ndim_x2,\n        lambda: expand_front(x1, ndim_x2 - ndim_x1),\n        lambda: x1,\n    )\n    x2 = tf.cond(\n        ndim_x2 < ndim_x1,\n        lambda: expand_front(x2, ndim_x1 - ndim_x2),\n        lambda: x2,\n    )\n\n    x1_reshaped = tf.reshape(\n        x1,\n        tf.reshape(\n            tf.stack([tf.shape(x1), tf.ones_like(tf.shape(x1))], axis=1), [-1]\n        ),\n    )\n    x2_reshaped = tf.reshape(\n        x2,\n        tf.reshape(\n            tf.stack([tf.ones_like(tf.shape(x2)), tf.shape(x2)], axis=1), [-1]\n        ),\n    )\n\n    out = tf.multiply(x1_reshaped, x2_reshaped)\n    out_shape = tf.multiply(tf.shape(x1), tf.shape(x2))\n    out = tf.reshape(out, out_shape)\n    return out\n\n\ndef lcm(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    if not (x1.dtype.is_integer and x2.dtype.is_integer):\n        raise TypeError(\n            f\"Arguments to lcm must be integers. 
\"\n            f\"Received: x1.dtype={x1.dtype.name}, x2.dtype={x2.dtype.name}\"\n        )\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n\n    if dtype not in [tf.uint8, tf.uint16, tf.uint32, tf.uint64]:\n        x1 = tf.math.abs(x1)\n        x2 = tf.math.abs(x2)\n\n    divisor = gcd(x1, x2)\n    divisor_safe = tf.where(\n        divisor == 0, tf.constant(1, dtype=divisor.dtype), divisor\n    )\n\n    result = x1 * (x2 // divisor_safe)\n    result = tf.where(divisor == 0, tf.zeros_like(result), result)\n\n    return result\n\n\ndef ldexp(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n\n    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:\n        raise TypeError(\n            f\"ldexp exponent must be an integer type. \"\n            f\"Received: x2 dtype={x2.dtype}\"\n        )\n\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, x1.dtype)\n    result = x1 * tf.pow(tf.constant(2.0, dtype=x1.dtype), x2)\n    return tf.cast(tf.where(tf.math.is_inf(x1) | (x1 == 0), x1, result), dtype)\n\n\ndef less(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    return tf.less(x1, x2)\n\n\ndef less_equal(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    return tf.less_equal(x1, x2)\n\n\ndef linspace(\n    start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0\n):\n    if num < 0:\n        raise ValueError(\n            f\"`num` must be a non-negative integer. 
Received: num={num}\"\n        )\n    if dtype is None:\n        dtypes_to_resolve = [\n            getattr(start, \"dtype\", type(start)),\n            getattr(stop, \"dtype\", type(stop)),\n            float,\n        ]\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n    else:\n        dtype = standardize_dtype(dtype)\n    start = convert_to_tensor(start, dtype=dtype)\n    stop = convert_to_tensor(stop, dtype=dtype)\n    step = convert_to_tensor(np.nan)\n    if endpoint:\n        result = tf.linspace(start, stop, num, axis=axis)\n        if num > 1:\n            step = (stop - start) / (tf.cast(num, dtype) - 1)\n    else:\n        # tf.linspace doesn't support endpoint=False, so we manually handle it\n        if num > 0:\n            step = (stop - start) / tf.cast(num, dtype)\n        if num > 1:\n            new_stop = tf.cast(stop, step.dtype) - step\n            start = tf.cast(start, new_stop.dtype)\n            result = tf.linspace(start, new_stop, num, axis=axis)\n        else:\n            result = tf.linspace(start, stop, num, axis=axis)\n    if dtype is not None:\n        if \"int\" in dtype:\n            result = tf.floor(result)\n        result = tf.cast(result, dtype)\n    if retstep:\n        return (result, step)\n    else:\n        return result\n\n\n@sparse.densifying_unary(-np.inf)\ndef log(x):\n    x = convert_to_tensor(x)\n    dtype = (\n        config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    x = tf.cast(x, dtype)\n    return tf.math.log(x)\n\n\n@sparse.densifying_unary(-np.inf)\ndef log10(x):\n    x = convert_to_tensor(x)\n    dtype = (\n        config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    x = tf.cast(x, dtype)\n    return tf.math.log(x) / tf.math.log(tf.constant(10, x.dtype))\n\n\n@sparse.elementwise_unary\ndef log1p(x):\n    x = convert_to_tensor(x)\n    dtype = (\n     
   config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    x = tf.cast(x, dtype)\n    return tf.math.log1p(x)\n\n\n@sparse.densifying_unary(-np.inf)\ndef log2(x):\n    x = convert_to_tensor(x)\n    dtype = (\n        config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    x = tf.cast(x, dtype)\n    return tf.math.log(x) / tf.math.log(tf.constant(2, x.dtype))\n\n\ndef logaddexp(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    delta = x1 - x2\n    return tf.where(\n        tf.math.is_nan(delta),\n        x1 + x2,\n        tf.maximum(x1, x2) + tf.math.log1p(tf.math.exp(-tf.abs(delta))),\n    )\n\n\ndef logaddexp2(x1, x2):\n    x1 = tf.convert_to_tensor(x1)\n    x2 = tf.convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    delta = x1 - x2\n    log2 = tf.cast(tf.math.log(2.0), dtype)\n    return tf.where(\n        tf.math.is_nan(delta),\n        x1 + x2,\n        tf.maximum(x1, x2)\n        + tf.math.log1p(tf.math.exp(-tf.abs(delta) * log2)) / log2,\n    )\n\n\ndef logical_and(x1, x2):\n    x1 = tf.cast(x1, \"bool\")\n    x2 = tf.cast(x2, \"bool\")\n    return tf.logical_and(x1, x2)\n\n\ndef logical_not(x):\n    x = tf.cast(x, \"bool\")\n    return tf.logical_not(x)\n\n\ndef logical_or(x1, x2):\n    x1 = tf.cast(x1, \"bool\")\n    x2 = tf.cast(x2, \"bool\")\n    return tf.logical_or(x1, x2)\n\n\ndef logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):\n    result = linspace(\n        start=start,\n        stop=stop,\n        num=num,\n        endpoint=endpoint,\n        dtype=dtype,\n        axis=axis,\n    )\n    return tf.pow(tf.cast(base, result.dtype), 
result)\n\n\n@sparse.elementwise_binary_union(tf.sparse.maximum, densify_mixed=True)\ndef maximum(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return tf.maximum(x1, x2)\n\n\ndef median(x, axis=None, keepdims=False):\n    return quantile(x, 0.5, axis=axis, keepdims=keepdims)\n\n\ndef meshgrid(*x, indexing=\"xy\"):\n    return tf.meshgrid(*x, indexing=indexing)\n\n\ndef min(x, axis=None, keepdims=False, initial=None):\n    x = convert_to_tensor(x)\n\n    # The TensorFlow numpy API implementation doesn't support `initial` so we\n    # handle it manually here.\n    if initial is not None:\n        if standardize_dtype(x.dtype) == \"bool\":\n            x = tf.reduce_all(x, axis=axis, keepdims=keepdims)\n            x = tf.math.minimum(tf.cast(x, \"int32\"), tf.cast(initial, \"int32\"))\n            return tf.cast(x, \"bool\")\n        else:\n            x = tf.reduce_min(x, axis=axis, keepdims=keepdims)\n        return tf.math.minimum(x, initial)\n\n    # TensorFlow returns inf by default for an empty list, but for consistency\n    # with other backends and the numpy API we want to throw in this case.\n    if tf.executing_eagerly():\n        size_x = size(x)\n        tf.assert_greater(\n            size_x,\n            tf.constant(0, dtype=size_x.dtype),\n            message=\"Cannot compute the min of an empty tensor.\",\n        )\n\n    if standardize_dtype(x.dtype) == \"bool\":\n        return tf.reduce_all(x, axis=axis, keepdims=keepdims)\n    else:\n        return tf.reduce_min(x, axis=axis, keepdims=keepdims)\n\n\n@sparse.elementwise_binary_union(tf.sparse.minimum, densify_mixed=True)\ndef minimum(x1, x2):\n    if not isinstance(x1, (int, 
float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return tf.minimum(x1, x2)\n\n\ndef mod(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype == \"bool\":\n        dtype = \"int32\"\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    return tf.math.mod(x1, x2)\n\n\ndef fmod(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype == \"bool\":\n        dtype = \"int32\"\n    # tf.math.floormod does not support uint8/uint16; compute in int32\n    compute_dtype = dtype\n    if dtype in (\"uint8\", \"uint16\", \"uint32\"):\n        compute_dtype = \"int32\"\n    x1 = tf.cast(x1, compute_dtype)\n    x2 = tf.cast(x2, compute_dtype)\n    result = tf.sign(x1) * tf.math.floormod(tf.abs(x1), tf.abs(x2))\n    return tf.cast(result, dtype)\n\n\ndef moveaxis(x, source, destination):\n    x = convert_to_tensor(x)\n\n    _source = to_tuple_or_list(source)\n    _destination = to_tuple_or_list(destination)\n    _source = tuple(canonicalize_axis(i, x.ndim) for i in _source)\n    _destination = tuple(canonicalize_axis(i, x.ndim) for i in _destination)\n    if len(_source) != len(_destination):\n        raise ValueError(\n            \"Inconsistent number of `source` and `destination`. 
\"\n            f\"Received: source={source}, destination={destination}\"\n        )\n    # Directly return x if no movement is required\n    if _source == _destination:\n        return x\n    perm = [i for i in range(x.ndim) if i not in _source]\n    for dest, src in sorted(zip(_destination, _source)):\n        perm.insert(dest, src)\n    return tf.transpose(x, perm)\n\n\ndef nanargmax(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    if not x.dtype.is_floating:\n        return argmax(x, axis=axis, keepdims=keepdims)\n\n    nan_mask = tf.math.is_nan(x)\n\n    return tf.where(\n        tf.reduce_all(nan_mask, axis=axis, keepdims=keepdims),\n        tf.constant(-1, dtype=tf.int32),\n        argmax(\n            tf.where(nan_mask, tf.constant(float(\"-inf\"), dtype=x.dtype), x),\n            axis=axis,\n            keepdims=keepdims,\n        ),\n    )\n\n\ndef nanargmin(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    if not x.dtype.is_floating:\n        return argmin(x, axis=axis, keepdims=keepdims)\n\n    nan_mask = tf.math.is_nan(x)\n\n    return tf.where(\n        tf.reduce_all(nan_mask, axis=axis, keepdims=keepdims),\n        tf.constant(-1, dtype=tf.int32),\n        argmin(\n            tf.where(nan_mask, tf.constant(float(\"inf\"), dtype=x.dtype), x),\n            axis=axis,\n            keepdims=keepdims,\n        ),\n    )\n\n\ndef nancumsum(x, axis=None, dtype=None):\n    x = nan_to_num(x)\n    return cumsum(x, axis=axis, dtype=dtype)\n\n\ndef nancumprod(x, axis=None, dtype=None):\n    x = nan_to_num(x, nan=1.0)\n    return cumprod(x, axis=axis, dtype=dtype)\n\n\ndef nanmax(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    if not x.dtype.is_floating:\n        dtype = standardize_dtype(x.dtype)\n        if dtype == \"bool\":\n            return tf.reduce_any(x, axis=axis, keepdims=keepdims)\n        return tf.reduce_max(x, axis=axis, keepdims=keepdims)\n\n    x_clean = tf.where(\n        
tf.math.is_nan(x), tf.constant(float(\"-inf\"), dtype=x.dtype), x\n    )\n\n    return tf.where(\n        tf.reduce_all(tf.math.is_nan(x), axis=axis, keepdims=keepdims),\n        tf.constant(float(\"nan\"), dtype=x.dtype),\n        tf.reduce_max(x_clean, axis=axis, keepdims=keepdims),\n    )\n\n\ndef nanmean(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    if axis == () or axis == []:\n        return x\n\n    if not x.dtype.is_floating:\n        return tf.reduce_mean(\n            tf.cast(x, \"float32\"), axis=axis, keepdims=keepdims\n        )\n\n    dtype = dtypes.result_type(standardize_dtype(x.dtype), float)\n    total_sum = cast(nansum(x, axis=axis, keepdims=keepdims), dtype)\n    normalizer = tf.reduce_sum(\n        cast(~tf.math.is_nan(x), dtype),\n        axis=axis,\n        keepdims=keepdims,\n    )\n    return tf.divide(total_sum, normalizer)\n\n\ndef nanmin(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    if not x.dtype.is_floating:\n        dtype = standardize_dtype(x.dtype)\n        if dtype == \"bool\":\n            return tf.reduce_all(x, axis=axis, keepdims=keepdims)\n        return tf.reduce_min(x, axis=axis, keepdims=keepdims)\n\n    x_clean = tf.where(\n        tf.math.is_nan(x), tf.constant(float(\"inf\"), dtype=x.dtype), x\n    )\n\n    return tf.where(\n        tf.reduce_all(tf.math.is_nan(x), axis=axis, keepdims=keepdims),\n        tf.constant(float(\"nan\"), dtype=x.dtype),\n        tf.reduce_min(x_clean, axis=axis, keepdims=keepdims),\n    )\n\n\ndef nanprod(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    if not x.dtype.is_floating:\n        return prod(x, axis=axis, keepdims=keepdims)\n\n    x_safe = tf.where(tf.math.is_nan(x), tf.ones((), dtype=x.dtype), x)\n    return prod(x_safe, axis=axis, keepdims=keepdims)\n\n\ndef nanstd(x, axis=None, keepdims=False):\n    var_val = nanvar(x, axis=axis, keepdims=keepdims)\n    return tf.sqrt(var_val)\n\n\ndef nansum(x, axis=None, 
keepdims=False):\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    x_clean = tf.where(\n        tf.math.is_nan(cast(x, config.floatx())), tf.zeros((), dtype=dtype), x\n    )\n\n    if dtype in (\"bool\", \"int8\", \"int16\"):\n        dtype = \"int32\"\n    elif dtype in (\"uint8\", \"uint16\"):\n        dtype = \"uint32\"\n    x_clean = cast(x_clean, dtype)\n\n    return tf.reduce_sum(x_clean, axis=axis, keepdims=keepdims)\n\n\ndef nanvar(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    result_dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, result_dtype)\n\n    mean = nanmean(x, axis=axis, keepdims=True)\n\n    valid = ~tf.math.is_nan(x)\n\n    centered = tf.where(valid, x - mean, tf.zeros_like(x))\n\n    if centered.dtype.is_complex:\n        centered = tf.math.real(centered * tf.math.conj(centered))\n    else:\n        centered = tf.square(centered)\n\n    count = tf.reduce_sum(\n        tf.cast(valid, centered.dtype), axis=axis, keepdims=keepdims\n    )\n\n    var = tf.reduce_sum(centered, axis=axis, keepdims=keepdims) / count\n\n    return var\n\n\ndef nan_to_num(x, nan=0.0, posinf=None, neginf=None):\n    x = convert_to_tensor(x)\n\n    dtype = x.dtype\n    dtype_as_dtype = tf.as_dtype(dtype)\n    if dtype_as_dtype.is_integer or not dtype_as_dtype.is_numeric:\n        return x\n\n    # Replace NaN with `nan`\n    x = tf.where(tf.math.is_nan(x), tf.constant(nan, dtype), x)\n\n    # Replace positive infinity with `posinf` or `dtype.max`\n    if posinf is None:\n        posinf = dtype.max\n    x = tf.where(tf.math.is_inf(x) & (x > 0), tf.constant(posinf, dtype), x)\n\n    # Replace negative infinity with `neginf` or `dtype.min`\n    if neginf is None:\n        neginf = dtype.min\n    x = tf.where(tf.math.is_inf(x) & (x < 0), tf.constant(neginf, dtype), x)\n\n    return x\n\n\ndef ndim(x):\n    x = convert_to_tensor(x)\n    return x.shape.rank\n\n\ndef nonzero(x):\n    x = convert_to_tensor(x)\n    result = 
tf.unstack(tf.where(tf.cast(x, \"bool\")), x.shape.rank, axis=1)\n    return tree.map_structure(lambda indices: tf.cast(indices, \"int32\"), result)\n\n\ndef not_equal(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    return tf.not_equal(x1, x2)\n\n\ndef ones_like(x, dtype=None):\n    return tf.ones_like(x, dtype=dtype)\n\n\ndef zeros_like(x, dtype=None):\n    return tf.zeros_like(x, dtype=dtype)\n\n\ndef outer(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = tf.cast(x1, dtype)\n    x2 = tf.cast(x2, dtype)\n    return tf.reshape(x1, [-1, 1]) * tf.reshape(x2, [-1])\n\n\ndef pad(x, pad_width, mode=\"constant\", constant_values=None):\n    x = convert_to_tensor(x)\n    kwargs = {}\n    if constant_values is not None:\n        if mode != \"constant\":\n            raise ValueError(\n                \"Argument `constant_values` can only be \"\n                \"provided when `mode == 'constant'`. 
\"\n                f\"Received: mode={mode}\"\n            )\n        kwargs[\"constant_values\"] = constant_values\n    pad_width = convert_to_tensor(pad_width, \"int32\")\n    return tf.pad(x, pad_width, mode.upper(), **kwargs)\n\n\ndef prod(x, axis=None, keepdims=False, dtype=None):\n    x = convert_to_tensor(x)\n    if dtype is None:\n        dtype = dtypes.result_type(x.dtype)\n        if dtype == \"bool\":\n            dtype = \"int32\"\n        elif dtype in (\"int8\", \"int16\"):\n            dtype = \"int32\"\n        elif dtype in (\"uint8\", \"uint16\"):\n            dtype = \"uint32\"\n        x = tf.cast(x, dtype)\n    return tf.reduce_prod(x, axis=axis, keepdims=keepdims)\n\n\ndef ptp(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    return tf.reduce_max(x, axis=axis, keepdims=keepdims) - tf.reduce_min(\n        x, axis=axis, keepdims=keepdims\n    )\n\n\ndef _quantile(x, q, axis=None, method=\"linear\", keepdims=False):\n    # ref: tfp.stats.percentile\n    # float64 is needed here and below, else we get the wrong index if the array\n    # is huge along axis.\n    q = tf.cast(q, \"float64\")\n\n    # Move `axis` dims of `x` to the rightmost, call it `y`.\n    if axis is None:\n        y = tf.reshape(x, [-1])\n    else:\n        x_ndims = len(x.shape)\n        # _make_static_axis_non_negative_list\n        axis = [canonicalize_axis(a, x_ndims) for a in axis]\n\n        # _move_dims_to_flat_end\n        other_dims = sorted(set(range(x_ndims)).difference(axis))\n        perm = other_dims + list(axis)\n        x_permed = tf.transpose(a=x, perm=perm)\n        if None not in x.shape:\n            x_shape = list(x.shape)\n            other_shape = [x_shape[i] for i in other_dims]\n            end_shape = [math.prod([x_shape[i] for i in axis])]\n            full_shape = other_shape + end_shape\n        else:\n            other_shape = tf.gather(tf.shape(x), tf.cast(other_dims, tf.int64))\n            full_shape = tf.concat([other_shape, 
[-1]], axis=0)\n        y = tf.reshape(x_permed, shape=full_shape)\n\n    # Sort (in ascending order) everything which allows multiple calls to sort\n    # only once (under the hood) and use CSE.\n    sorted_y = tf.sort(y, axis=-1, direction=\"ASCENDING\")\n\n    d = tf.cast(tf.shape(y)[-1], \"float64\")\n\n    def _get_indices(method):\n        \"\"\"Get values of y at the indices implied by method.\"\"\"\n        if method == \"lower\":\n            indices = tf.math.floor((d - 1) * q)\n        elif method == \"higher\":\n            indices = tf.math.ceil((d - 1) * q)\n        elif method == \"nearest\":\n            indices = tf.round((d - 1) * q)\n        # d - 1 will be distinct from d in int32, but not necessarily double.\n        # So clip to avoid out of bounds errors.\n        return tf.clip_by_value(\n            tf.cast(indices, \"int32\"), 0, tf.shape(y)[-1] - 1\n        )\n\n    if method in [\"nearest\", \"lower\", \"higher\"]:\n        gathered_y = tf.gather(sorted_y, _get_indices(method), axis=-1)\n    elif method == \"midpoint\":\n        gathered_y = 0.5 * (\n            tf.gather(sorted_y, _get_indices(\"lower\"), axis=-1)\n            + tf.gather(sorted_y, _get_indices(\"higher\"), axis=-1)\n        )\n    elif method == \"linear\":\n        larger_y_idx = _get_indices(\"higher\")\n        exact_idx = (d - 1) * q\n        # preserve_gradients\n        smaller_y_idx = tf.maximum(larger_y_idx - 1, 0)\n        larger_y_idx = tf.minimum(smaller_y_idx + 1, tf.shape(y)[-1] - 1)\n        fraction = tf.cast(larger_y_idx, tf.float64) - exact_idx\n        fraction = tf.cast(fraction, y.dtype)\n        gathered_y = (\n            tf.gather(sorted_y, larger_y_idx, axis=-1) * (1 - fraction)\n            + tf.gather(sorted_y, smaller_y_idx, axis=-1) * fraction\n        )\n\n    # Propagate NaNs\n    if x.dtype in (tf.bfloat16, tf.float16, tf.float32, tf.float64):\n        # Apparently tf.is_nan doesn't like other dtypes\n        nan_batch_members = 
tf.reduce_any(tf.math.is_nan(x), axis=axis)\n        right_rank_matched_shape = tf.pad(\n            tf.shape(nan_batch_members),\n            paddings=[[0, tf.rank(q)]],\n            constant_values=1,\n        )\n        nan_batch_members = tf.reshape(\n            nan_batch_members, shape=right_rank_matched_shape\n        )\n        nan_value = tf.constant(float(\"NaN\"), dtype=x.dtype)\n        gathered_y = tf.where(nan_batch_members, nan_value, gathered_y)\n\n    # Expand dimensions if requested\n    if keepdims:\n        if axis is None:\n            ones_vec = tf.ones(shape=[tf.rank(x) + tf.rank(q)], dtype=\"int32\")\n            gathered_y *= tf.ones(ones_vec, dtype=gathered_y.dtype)\n        else:\n            for i in sorted(axis):\n                gathered_y = tf.expand_dims(gathered_y, axis=i)\n\n    # rotate_transpose\n    shift_value_static = tf.get_static_value(tf.rank(q))\n    ndims = tf.TensorShape(gathered_y.shape).rank\n    if ndims < 2:\n        return gathered_y\n    shift_value_static = int(\n        math.copysign(1, shift_value_static)\n        * (builtins.abs(shift_value_static) % ndims)\n    )\n    if shift_value_static == 0:\n        return gathered_y\n    perm = collections.deque(range(ndims))\n    perm.rotate(shift_value_static)\n    return tf.transpose(a=gathered_y, perm=list(perm))\n\n\ndef quantile(x, q, axis=None, method=\"linear\", keepdims=False):\n    x = convert_to_tensor(x)\n    q = convert_to_tensor(q)\n    axis = to_tuple_or_list(axis)\n    compute_dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, compute_dtype)\n    return _quantile(x, q, axis=axis, method=method, keepdims=keepdims)\n\n\ndef ravel(x):\n    x = convert_to_tensor(x)\n    return tf.reshape(x, [-1])\n\n\ndef unravel_index(indices, shape):\n    indices = tf.convert_to_tensor(indices)\n    input_dtype = indices.dtype\n\n    if None in shape:\n        raise ValueError(\n            f\"`shape` argument cannot contain `None`. 
Received: shape={shape}\"\n        )\n\n    if indices.ndim == 1:\n        coords = []\n        for dim in reversed(shape):\n            coords.append(tf.cast(indices % dim, input_dtype))\n            indices = indices // dim\n        return tuple(reversed(coords))\n\n    indices_shape = indices.shape\n    coords = []\n    for dim in shape:\n        coords.append(\n            tf.reshape(tf.cast(indices % dim, input_dtype), indices_shape)\n        )\n        indices = indices // dim\n\n    return tuple(reversed(coords))\n\n\n@sparse.elementwise_unary\ndef real(x):\n    x = convert_to_tensor(x)\n    return tf.math.real(x)\n\n\n@sparse.densifying_unary(np.inf)\ndef reciprocal(x):\n    x = convert_to_tensor(x)\n    return tf.math.reciprocal(x)\n\n\ndef repeat(x, repeats, axis=None):\n    x = convert_to_tensor(x)\n    # TODO: tf.repeat doesn't support uint16\n    if standardize_dtype(x.dtype) == \"uint16\":\n        x = tf.cast(x, \"uint32\")\n        return tf.cast(tf.repeat(x, repeats, axis=axis), \"uint16\")\n    return tf.repeat(x, repeats, axis=axis)\n\n\ndef reshape(x, newshape):\n    x = convert_to_tensor(x)\n    if isinstance(x, tf.SparseTensor):\n        from keras.src.ops.operation_utils import compute_reshape_output_shape\n\n        output_shape = compute_reshape_output_shape(\n            x.shape, newshape, \"newshape\"\n        )\n        output = tf.sparse.reshape(x, newshape)\n        output.set_shape(output_shape)\n        return output\n    return tf.reshape(x, newshape)\n\n\ndef roll(x, shift, axis=None):\n    x = convert_to_tensor(x)\n    if axis is not None:\n        return tf.roll(x, shift=shift, axis=axis)\n\n    # If axis is None, the roll happens as a 1-d tensor.\n    original_shape = tf.shape(x)\n    x = tf.roll(tf.reshape(x, [-1]), shift, 0)\n    return tf.reshape(x, original_shape)\n\n\ndef searchsorted(sorted_sequence, values, side=\"left\"):\n    if ndim(sorted_sequence) != 1:\n        raise ValueError(\n            \"`searchsorted` only 
supports 1-D sorted sequences. \"\n            \"You can use `keras.ops.vectorized_map` \"\n            \"to extend it to N-D sequences. Received: \"\n            f\"sorted_sequence.shape={sorted_sequence.shape}\"\n        )\n    sequence_len = sorted_sequence.shape[0]\n    out_type = (\n        \"int32\"\n        if sequence_len is not None and sequence_len <= np.iinfo(np.int32).max\n        else \"int64\"\n    )\n    return tf.searchsorted(\n        sorted_sequence, values, side=side, out_type=out_type\n    )\n\n\n@sparse.elementwise_unary\ndef sign(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    # TODO: tf.sign doesn't support uint8, uint16, uint32\n    if ori_dtype in (\"uint8\", \"uint16\", \"uint32\"):\n        x = tf.cast(x, \"int32\")\n        return tf.cast(tf.sign(x), ori_dtype)\n    return tf.sign(x)\n\n\n@sparse.elementwise_unary\ndef signbit(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if ori_dtype == \"bool\":\n        return tf.fill(tf.shape(x), False)\n    elif \"int\" in ori_dtype:\n        return x < 0\n    else:\n        x = cast(x, \"float32\")\n        return tf.less(\n            tf.bitwise.bitwise_and(\n                tf.bitcast(x, tf.int32),\n                # tf.float32 sign bit\n                tf.constant(tf.int32.min, dtype=tf.int32),\n            ),\n            0,\n        )\n\n\n@sparse.elementwise_unary\ndef sin(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.sin(x)\n\n\ndef sinc(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    pi_x = x * tf.constant(np.pi, dtype=x.dtype)\n    return tf.where(\n        tf.equal(x, 0),\n        
tf.ones_like(x),\n        tf.math.sin(pi_x) / pi_x,\n    )\n\n\n@sparse.elementwise_unary\ndef sinh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.sinh(x)\n\n\ndef size(x):\n    x = convert_to_tensor(x)\n    return tf.size(x)\n\n\ndef sort(x, axis=-1):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    # TODO: tf.sort doesn't support bool\n    if ori_dtype == \"bool\":\n        x = tf.cast(x, \"int8\")\n        return tf.cast(tf.sort(x, axis=axis), ori_dtype)\n    return tf.sort(x, axis=axis)\n\n\ndef split(x, indices_or_sections, axis=0):\n    if not isinstance(indices_or_sections, int):\n        # `tf.split` requires `num_or_size_splits`, so we need to convert\n        # `indices_or_sections` to the appropriate format.\n        total_size = x.shape[axis]\n        indices_or_sections = convert_to_tensor(indices_or_sections)\n        start_size = indices_or_sections[0:1]\n        end_size = total_size - indices_or_sections[-1:]\n        num_or_size_splits = tf.concat(\n            [start_size, diff(indices_or_sections), end_size], axis=0\n        )\n    else:\n        num_or_size_splits = indices_or_sections\n    return tf.split(x, num_or_size_splits, axis=axis)\n\n\ndef array_split(x, indices_or_sections, axis=0):\n    x = tf.convert_to_tensor(x)\n    num_splits = indices_or_sections\n    total_size = shape_op(x)[axis]\n    avg_size = total_size // num_splits\n    remainder = total_size % num_splits\n    sizes = [avg_size + 1] * remainder + [avg_size] * (num_splits - remainder)\n\n    return tf.split(x, sizes, axis=axis)\n\n\ndef stack(x, axis=0):\n    dtype_set = set([getattr(a, \"dtype\", type(a)) for a in x])\n    if len(dtype_set) > 1:\n        dtype = dtypes.result_type(*dtype_set)\n        x = tree.map_structure(lambda a: convert_to_tensor(a, dtype), x)\n 
   return tf.stack(x, axis=axis)\n\n\ndef std(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = tf.cast(x, config.floatx())\n    return tf.math.reduce_std(x, axis=axis, keepdims=keepdims)\n\n\ndef swapaxes(x, axis1, axis2):\n    x = convert_to_tensor(x)\n\n    if (\n        x.shape.rank is not None\n        and isinstance(axis1, int)\n        and isinstance(axis2, int)\n    ):\n        # This branch makes sure `perm` is statically known, to avoid a\n        # not-compile-time-constant XLA error.\n        axis1 = canonicalize_axis(axis1, x.ndim)\n        axis2 = canonicalize_axis(axis2, x.ndim)\n\n        # Directly return x if no movement is required\n        if axis1 == axis2:\n            return x\n\n        perm = list(range(x.ndim))\n        perm[axis1] = axis2\n        perm[axis2] = axis1\n    else:\n        x_rank = tf.rank(x)\n        axis1 = tf.where(axis1 < 0, tf.add(axis1, x_rank), axis1)\n        axis2 = tf.where(axis2 < 0, tf.add(axis2, x_rank), axis2)\n        perm = tf.range(x_rank)\n        perm = tf.tensor_scatter_nd_update(\n            perm, [[axis1], [axis2]], [axis2, axis1]\n        )\n    return tf.transpose(x, perm)\n\n\ndef take(x, indices, axis=None):\n    x = convert_to_tensor(x)\n    if axis is None:\n        x = tf.reshape(x, (-1,))\n        axis = 0\n\n    def fix_negative_indices(i):\n        # Correct the indices using \"fill\" mode which is the same as in jax\n        return tf.where(i < 0, i + tf.cast(tf.shape(x)[axis], i.dtype), i)\n\n    if isinstance(indices, tf.SparseTensor):\n        if x.dtype not in (tf.float16, tf.float32, tf.float64, tf.bfloat16):\n            warnings.warn(\n                \"`take` with the TensorFlow backend does not support \"\n                f\"`x.dtype={x.dtype}` when `indices` is a sparse tensor; \"\n                \"densifying `indices`.\"\n            )\n            indices = 
convert_to_tensor(indices, sparse=False)\n        elif axis != 0:\n            warnings.warn(\n                \"`take` with the TensorFlow backend does not support \"\n                f\"`axis={axis}` when `indices` is a sparse tensor; \"\n                \"densifying `indices`.\"\n            )\n            indices = convert_to_tensor(indices, sparse=False)\n        else:\n            indices = sparse.sparse_with_values(\n                indices, fix_negative_indices(indices.values)\n            )\n            # `expand_dims` on `indices` prevents combiner from being applied.\n            output = tf.nn.safe_embedding_lookup_sparse(\n                embedding_weights=tf.convert_to_tensor(x),\n                sparse_ids=tf.sparse.expand_dims(indices, axis=-1),\n                default_id=0,\n            )\n            output.set_shape(indices.shape + output.shape[len(indices.shape) :])\n            return output\n    elif isinstance(indices, tf.RaggedTensor):\n        indices = indices.with_values(fix_negative_indices(indices.values))\n        if axis == 0:\n            return tf.nn.embedding_lookup(x, indices)\n        else:\n            return tf.gather(x, indices, axis=axis)\n\n    indices = fix_negative_indices(convert_to_tensor(indices))\n    return tf.gather(x, indices, axis=axis)\n\n\ndef take_along_axis(x, indices, axis=None):\n    from keras.src.ops import operation_utils\n\n    x = convert_to_tensor(x)\n    indices = convert_to_tensor(indices, \"int64\")\n    if axis is None:\n        if indices.ndim != 1:\n            raise ValueError(\n                \"`indices` must be 1D if axis=None. 
\"\n                f\"Received: indices.shape={indices.shape}\"\n            )\n        return take_along_axis(tf.reshape(x, [-1]), indices, 0)\n\n    # Compute the static output shape as later on, all shapes manipulations\n    # use dynamic shapes.\n    static_output_shape = operation_utils.compute_take_along_axis_output_shape(\n        x.shape, indices.shape, axis\n    )\n    rank = x.ndim\n    static_axis = axis\n    axis = axis + rank if axis < 0 else axis\n\n    if axis >= rank:\n        raise ValueError(f\"Invalid axis: {static_axis} for input rank: {rank}\")\n\n    x_original_shape = shape_op(x)\n    indices_original_shape = shape_op(indices)\n\n    # Broadcast the static shapes first, but not for the `axis` dimension.\n    x_static_shape = list(x.shape)\n    indices_static_shape = list(indices.shape)\n    x_static_shape[axis] = 1\n    indices_static_shape[axis] = 1\n    broadcast_shape = operation_utils.broadcast_shapes(\n        x_static_shape, indices_static_shape\n    )\n\n    if None in broadcast_shape:\n        # Dynamic broadcast case. 
Note that `tf.broadcast_dynamic_shape` is\n        # not always XLA compilable with dynamic dimensions.\n        # We replace `None`s with the dynamic dimensions.\n        # `maximum` is the correct formula only when shapes are broadcastable;\n        # we rely on the broadcast itself to fail in the incorrect case rather\n        # than make some expensive dynamic checks here.\n        broadcast_shape = [\n            tf.maximum(x_original_shape[i], indices_original_shape[i])\n            if dim is None\n            else dim\n            for i, dim in enumerate(broadcast_shape)\n        ]\n\n    x_shape = list(broadcast_shape)\n    x_shape[axis] = x_original_shape[axis]\n    indices_shape = list(broadcast_shape)\n    indices_shape[axis] = indices_original_shape[axis]\n    x = tf.broadcast_to(x, x_shape)\n    indices = tf.broadcast_to(indices, indices_shape)\n\n    # Correct the indices using \"fill\" mode which is the same as in jax\n    indices = tf.where(\n        indices < 0,\n        indices + tf.cast(x_shape[static_axis], dtype=indices.dtype),\n        indices,\n    )\n\n    x = swapaxes(x, static_axis, -1)\n    indices = swapaxes(indices, static_axis, -1)\n\n    x_shape = tf.shape(x)\n    x = tf.reshape(x, [-1, x_shape[-1]])\n    indices_shape = tf.shape(indices)\n    indices = tf.reshape(indices, [-1, indices_shape[-1]])\n\n    result = tf.gather(x, indices, batch_dims=1)\n    result = tf.reshape(result, indices_shape)\n    result = swapaxes(result, static_axis, -1)\n    result.set_shape(static_output_shape)\n    return result\n\n\n@sparse.elementwise_unary\ndef tan(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.tan(x)\n\n\n@sparse.elementwise_unary\ndef tanh(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    
else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, dtype)\n    return tf.math.tanh(x)\n\n\ndef tensordot(x1, x2, axes=2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    result_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    # TODO: tf.tensordot only supports float types\n    compute_dtype = dtypes.result_type(result_dtype, float)\n    x1 = tf.cast(x1, compute_dtype)\n    x2 = tf.cast(x2, compute_dtype)\n    return tf.cast(tf.tensordot(x1, x2, axes=axes), dtype=result_dtype)\n\n\n@sparse.elementwise_unary\ndef round(x, decimals=0):\n    if decimals == 0:\n        return tf.round(x)\n    x_dtype = x.dtype\n    if tf.as_dtype(x_dtype).is_integer:\n        # int\n        if decimals > 0:\n            return x\n        # temporarily convert to floats\n        factor = tf.cast(math.pow(10, decimals), config.floatx())\n        x = tf.cast(x, config.floatx())\n    else:\n        # float\n        factor = tf.cast(math.pow(10, decimals), x.dtype)\n    x = tf.multiply(x, factor)\n    x = tf.round(x)\n    x = tf.divide(x, factor)\n    return tf.cast(x, x_dtype)\n\n\ndef tile(x, repeats):\n    x = convert_to_tensor(x)\n\n    # Convert repeats to a list (works for both sequences and 1D tensors)\n    if isinstance(repeats, int):\n        repeats = [repeats]\n    else:\n        repeats = [v for v in repeats]\n\n    # Process list elements: convert concrete scalar tensors to Python ints\n    processed_repeats = []\n    for r in repeats:\n        if hasattr(r, \"numpy\") and r.shape == ():\n            processed_repeats.append(int(r.numpy()))\n        else:\n            processed_repeats.append(r)\n    repeats = processed_repeats\n\n    # Get x rank\n    x_rank = x.shape.rank\n\n    # Pad repeats if needed\n    if len(repeats) < x_rank:\n        repeats = [1] * (x_rank - len(repeats)) + repeats\n\n    # Add dimensions to x if needed using tf.expand_dims\n    while len(repeats) > x.shape.rank:\n        x = tf.expand_dims(x, 0)\n\n  
  return tf.tile(x, repeats)\n\n\ndef trace(x, offset=0, axis1=0, axis2=1):\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    if dtype in (\"bool\", \"int8\", \"int16\"):\n        dtype = \"int32\"\n    elif dtype in (\"uint8\", \"uint16\"):\n        dtype = \"uint32\"\n    x = tf.cast(x, dtype)\n    x_shape = tf.shape(x)\n    x = moveaxis(x, (axis1, axis2), (-2, -1))\n    # Mask out the diagonal and reduce.\n    x = tf.where(\n        eye(x_shape[axis1], x_shape[axis2], k=offset, dtype=\"bool\"),\n        x,\n        tf.zeros_like(x),\n    )\n    return tf.reduce_sum(x, axis=(-2, -1))\n\n\ndef tri(N, M=None, k=0, dtype=None):\n    M = M if M is not None else N\n    dtype = standardize_dtype(dtype or config.floatx())\n    if k < 0:\n        lower = -k - 1\n        if lower > N:\n            r = tf.zeros([N, M], dtype=dtype)\n        else:\n            o = tf.ones([N, M], dtype=\"bool\")\n            r = tf.cast(\n                tf.logical_not(tf.linalg.band_part(o, lower, -1)), dtype=dtype\n            )\n    else:\n        o = tf.ones([N, M], dtype=dtype)\n        if k > M:\n            r = o\n        else:\n            r = tf.linalg.band_part(o, -1, k)\n    return r\n\n\ndef tril(x, k=0):\n    x = convert_to_tensor(x)\n\n    def _negative_k_branch():\n        shape = tf.shape(x)\n        rows, cols = shape[-2], shape[-1]\n        i, j = tf.meshgrid(tf.range(rows), tf.range(cols), indexing=\"ij\")\n        mask = i >= j - k\n        return tf.where(tf.broadcast_to(mask, shape), x, tf.zeros_like(x))\n\n    if isinstance(k, int):\n        if k >= 0:\n            return tf.linalg.band_part(x, -1, k)\n        return _negative_k_branch()\n\n    # when `k` is a tensor\n    return tf.cond(\n        tf.greater_equal(k, 0),\n        lambda: tf.linalg.band_part(x, -1, k),\n        _negative_k_branch,\n    )\n\n\ndef triu(x, k=0):\n    x = convert_to_tensor(x)\n\n    def _positive_k_branch():\n        shape = tf.shape(x)\n        rows, cols = 
shape[-2], shape[-1]\n        i, j = tf.meshgrid(tf.range(rows), tf.range(cols), indexing=\"ij\")\n        mask = i <= j - k\n        return tf.where(tf.broadcast_to(mask, shape), x, tf.zeros_like(x))\n\n    if isinstance(k, int):\n        if k <= 0:\n            return tf.linalg.band_part(x, -k, -1)\n        return _positive_k_branch()\n\n    # when `k` is a tensor\n    return tf.cond(\n        tf.less_equal(k, 0),\n        lambda: tf.linalg.band_part(x, -k, -1),\n        _positive_k_branch,\n    )\n\n\ndef trunc(x):\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    if dtype == \"bool\" or \"int\" in dtype:\n        return x\n    return tf.where(x < 0, tf.math.ceil(x), tf.math.floor(x))\n\n\ndef vdot(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    result_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    compute_dtype = dtypes.result_type(result_dtype, float)\n    x1 = tf.cast(x1, compute_dtype)\n    x2 = tf.cast(x2, compute_dtype)\n    x1 = tf.reshape(x1, [-1])\n    x2 = tf.reshape(x2, [-1])\n    return tf.cast(dot(x1, x2), result_dtype)\n\n\ndef inner(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    result_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    compute_dtype = dtypes.result_type(result_dtype, float)\n    x1 = tf.cast(x1, compute_dtype)\n    x2 = tf.cast(x2, compute_dtype)\n    x = tf.cond(\n        tf.math.logical_or(\n            tf.math.equal(tf.rank(x1), 0),\n            tf.math.equal(tf.rank(x2), 0),\n        ),\n        lambda: x1 * x2,\n        lambda: tf.tensordot(x1, x2, axes=[[-1], [-1]]),\n    )\n    return tf.cast(x, result_dtype)\n\n\ndef vstack(xs):\n    dtype_set = set([getattr(x, \"dtype\", type(x)) for x in xs])\n    if len(dtype_set) > 1:\n        dtype = dtypes.result_type(*dtype_set)\n        xs = tree.map_structure(lambda x: convert_to_tensor(x, dtype), xs)\n    return tf.concat(xs, axis=0)\n\n\ndef vsplit(x, indices_or_sections):\n    return 
split(x, indices_or_sections, axis=0)\n\n\ndef _vmap_fn(fn, in_axes=0):\n    if in_axes != 0:\n        raise ValueError(\n            \"`in_axes` other than 0 is not supported with \"\n            \"`vectorize()` on the TensorFlow backend.\"\n        )\n\n    @functools.wraps(fn)\n    def wrapped(x):\n        return tf.vectorized_map(fn, x)\n\n    return wrapped\n\n\ndef vectorize(pyfunc, *, excluded=None, signature=None):\n    return vectorize_impl(\n        pyfunc, _vmap_fn, excluded=excluded, signature=signature\n    )\n\n\ndef where(condition, x1=None, x2=None):\n    condition = tf.cast(condition, \"bool\")\n    if x1 is not None and x2 is not None:\n        if not isinstance(x1, (int, float)):\n            x1 = convert_to_tensor(x1)\n        if not isinstance(x2, (int, float)):\n            x2 = convert_to_tensor(x2)\n        dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        x1 = convert_to_tensor(x1, dtype)\n        x2 = convert_to_tensor(x2, dtype)\n        return tf.where(condition, x1, x2)\n    if x1 is None and x2 is None:\n        return nonzero(condition)\n    raise ValueError(\n        \"`x1` and `x2` either both should be `None`\"\n        \" or both should have non-None value.\"\n    )\n\n\n@sparse.elementwise_division\ndef divide(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n        float,\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return tf.divide(x1, x2)\n\n\ndef divide_no_nan(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n    
    getattr(x2, \"dtype\", type(x2)),\n        float,\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return tf.math.divide_no_nan(x1, x2)\n\n\ndef true_divide(x1, x2):\n    return divide(x1, x2)\n\n\ndef power(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    # TODO: tf.pow doesn't support uint* types\n    if \"uint\" in dtype:\n        x1 = convert_to_tensor(x1, \"int32\")\n        x2 = convert_to_tensor(x2, \"int32\")\n        return tf.cast(tf.pow(x1, x2), dtype)\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return tf.pow(x1, x2)\n\n\n@sparse.elementwise_unary\ndef negative(x):\n    return tf.negative(x)\n\n\ndef nextafter(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    x1 = tf.cast(x1, tf.float64)\n    x2 = tf.cast(x2, tf.float64)\n    return tf.cast(tf.math.nextafter(x1, x2), dtype)\n\n\n@sparse.elementwise_unary\ndef square(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"bool\":\n        x = tf.cast(x, \"int32\")\n    return tf.square(x)\n\n\n@sparse.elementwise_unary\ndef sqrt(x):\n    x = convert_to_tensor(x)\n    dtype = (\n        config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    x = tf.cast(x, dtype)\n    return tf.math.sqrt(x)\n\n\ndef squeeze(x, axis=None):\n    x = convert_to_tensor(x)\n    axis = to_tuple_or_list(axis)\n    static_shape = x.shape.as_list()\n    if axis is not None:\n        for a in axis:\n            if static_shape[a] != 1:\n                raise ValueError(\n                    f\"Cannot squeeze axis={a}, because the 
dimension is not 1.\"\n                )\n        axis = sorted([canonicalize_axis(a, len(static_shape)) for a in axis])\n    if isinstance(x, tf.SparseTensor):\n        dynamic_shape = tf.shape(x)\n        new_shape = []\n        gather_indices = []\n        for i, dim in enumerate(static_shape):\n            if not (dim == 1 if axis is None else i in axis):\n                new_shape.append(dim if dim is not None else dynamic_shape[i])\n                gather_indices.append(i)\n        new_indices = tf.gather(x.indices, gather_indices, axis=1)\n        return tf.SparseTensor(new_indices, x.values, tuple(new_shape))\n    return tf.squeeze(x, axis=axis)\n\n\ndef transpose(x, axes=None):\n    if isinstance(x, tf.SparseTensor):\n        from keras.src.ops.operation_utils import compute_transpose_output_shape\n\n        output = tf.sparse.transpose(x, perm=axes)\n        output.set_shape(compute_transpose_output_shape(x.shape, axes))\n        return output\n    return tf.transpose(x, perm=axes)\n\n\ndef trapezoid(y, x=None, dx=1.0, axis=-1):\n    def _move_axis_to_last(tensor, axis):\n        if axis == -1:\n            return tensor\n        rank = tf.rank(tensor)\n        if axis < 0:\n            axis = rank + axis\n        perm = tf.concat(\n            [\n                tf.range(axis, dtype=tf.int32),\n                tf.range(axis + 1, rank, dtype=tf.int32),\n                tf.constant([axis], dtype=tf.int32),\n            ],\n            axis=0,\n        )\n        return tf.transpose(tensor, perm=perm)\n\n    y = convert_to_tensor(y)\n    dtype = dtypes.result_type(y.dtype, float)\n    y = tf.cast(y, dtype)\n\n    if x is None:\n        dx_array = tf.cast(dx, dtype)\n    else:\n        x = convert_to_tensor(x, dtype=dtype)\n        dx_array = diff(x, axis=axis)\n        dx_array = _move_axis_to_last(dx_array, axis)\n\n    y = _move_axis_to_last(y, axis)\n\n    avg_heights = 0.5 * (y[..., 1:] + y[..., :-1])\n    result = tf.reduce_sum(avg_heights * dx_array, 
axis=-1)\n\n    return result\n\n\ndef vander(x, N=None, increasing=False):\n    x = convert_to_tensor(x)\n    result_dtype = dtypes.result_type(x.dtype)\n\n    if N is None:\n        N = shape_op(x)[0]\n\n    if increasing:\n        powers = tf.range(N)\n    else:\n        powers = tf.range(N - 1, -1, -1)\n\n    x_exp = tf.expand_dims(x, axis=-1)\n\n    compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n    vander = tf.math.pow(\n        tf.cast(x_exp, compute_dtype), tf.cast(powers, compute_dtype)\n    )\n    return tf.cast(vander, result_dtype)\n\n\ndef var(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n    result_dtype = dtypes.result_type(x.dtype, float)\n    x = tf.cast(x, compute_dtype)\n    return tf.cast(\n        tf.math.reduce_variance(x, axis=axis, keepdims=keepdims),\n        result_dtype,\n    )\n\n\ndef sum(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    # follow jax's rule\n    if dtype in (\"bool\", \"int8\", \"int16\"):\n        dtype = \"int32\"\n    elif dtype in (\"uint8\", \"uint16\"):\n        dtype = \"uint32\"\n    x = cast(x, dtype)\n    if isinstance(x, tf.SparseTensor):\n        return tf.sparse.reduce_sum(\n            x, axis=axis, keepdims=keepdims, output_is_sparse=True\n        )\n    return tf.reduce_sum(x, axis=axis, keepdims=keepdims)\n\n\ndef eye(N, M=None, k=0, dtype=None):\n    dtype = dtype or config.floatx()\n    M = N if M is None else M\n    if isinstance(k, int) and k == 0:\n        return tf.eye(N, M, dtype=dtype)\n    # Create a smaller square eye and pad appropriately.\n    return tf.pad(\n        tf.eye(tf.minimum(M - k, N + k), dtype=dtype),\n        paddings=(\n            (tf.maximum(-k, 0), tf.maximum(N - M + k, 0)),\n            (tf.maximum(k, 0), tf.maximum(M - N - k, 0)),\n        ),\n    )\n\n\ndef floor_divide(x1, x2):\n    if not isinstance(x1, (int, float)):\n        
x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return tf.math.floordiv(x1, x2)\n\n\ndef logical_xor(x1, x2):\n    x1 = tf.cast(x1, \"bool\")\n    x2 = tf.cast(x2, \"bool\")\n    return tf.math.logical_xor(x1, x2)\n\n\ndef corrcoef(x):\n    dtype = x.dtype\n    if dtype in [\"bool\", \"int8\", \"int16\", \"int32\", \"uint8\", \"uint16\", \"uint32\"]:\n        dtype = config.floatx()\n    x = convert_to_tensor(x, dtype)\n\n    if tf.rank(x) == 0:\n        return tf.constant(float(\"nan\"), dtype=config.floatx())\n\n    mean = tf.reduce_mean(x, axis=-1, keepdims=True)\n    x_centered = x - mean\n\n    num_samples = tf.cast(tf.shape(x)[-1], x.dtype)\n    cov_matrix = tf.matmul(x_centered, x_centered, adjoint_b=True) / (\n        num_samples - 1\n    )\n\n    diag = tf.linalg.diag_part(cov_matrix)\n    stddev = tf.sqrt(tf.math.real(diag))\n\n    outer_std = tf.tensordot(stddev, stddev, axes=0)\n    outer_std = tf.cast(outer_std, cov_matrix.dtype)\n    correlation = cov_matrix / outer_std\n\n    correlation_clipped = tf.clip_by_value(tf.math.real(correlation), -1.0, 1.0)\n    if correlation.dtype.is_complex:\n        imag_clipped = tf.clip_by_value(tf.math.imag(correlation), -1.0, 1.0)\n        return tf.complex(correlation_clipped, imag_clipped)\n    else:\n        return correlation_clipped\n\n\ndef correlate(x1, x2, mode=\"valid\"):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    if dtype == tf.int64:\n        dtype = tf.float64\n    elif dtype not in [tf.bfloat16, tf.float16, tf.float64]:\n        dtype = tf.float32\n\n    x1 = tf.cast(x1, dtype)\n    x2 = 
tf.cast(x2, dtype)\n\n    def _pack(a, b):\n        # a: input [N] -> [1,N,1];\n        # b: filter [M] -> [M,1,1]\n        return (\n            tf.reshape(a, (1, shape_op(a)[0], 1)),\n            tf.reshape(b, (shape_op(b)[0], 1, 1)),\n        )\n\n    def _full_corr(x1, x2):\n        \"\"\"Compute 'full' correlation result (length = n + m - 1).\"\"\"\n        m = shape_op(x2)[0]\n        pad = (\n            builtins.max(m - 1, 0)\n            if isinstance(m, int)\n            else tf.maximum(m - 1, 0)\n        )\n        x1 = tf.pad(x1, [[pad, pad]])  # pad input with zeros\n        x1, x2 = _pack(x1, x2)\n        out = tf.nn.conv1d(x1, x2, stride=1, padding=\"VALID\")\n        return tf.squeeze(out, axis=[0, 2])\n\n    n = shape_op(x1)[0]\n    m = shape_op(x2)[0]\n\n    if mode == \"full\":\n        return _full_corr(x1, x2)\n    elif mode == \"same\":\n        # unfortunately we can't leverage 'SAME' padding directly like\n        # we can with \"valid\"\n        # it works fine for odd-length filters, but for even-length filters\n        # the output is off by 1 compared to numpy, due to how\n        # tf handles centering\n        full_corr = _full_corr(x1, x2)\n        full_len = n + m - 1\n        out_len = (\n            max([n, m])\n            if isinstance(n, int) and isinstance(m, int)\n            else tf.maximum(n, m)\n        )\n        start = (full_len - out_len) // 2\n        return tf.slice(full_corr, [start], [out_len])\n    elif mode == \"valid\":\n        x1, x2 = _pack(x1, x2)\n        return tf.squeeze(\n            tf.nn.conv1d(x1, x2, stride=1, padding=\"VALID\"), axis=[0, 2]\n        )\n    else:\n        raise ValueError(\n            f\"Invalid mode: '{mode}'. 
Mode must be one of:\"\n            f\" 'full', 'same', 'valid'.\"\n        )\n\n\ndef select(condlist, choicelist, default=0):\n    return tf.experimental.numpy.select(condlist, choicelist, default=default)\n\n\ndef slogdet(x):\n    x = convert_to_tensor(x)\n    return tuple(tf.linalg.slogdet(x))\n\n\ndef argpartition(x, kth, axis=-1):\n    x = convert_to_tensor(x, tf.int32)\n\n    x = swapaxes(x, axis, -1)\n    bottom_ind = tf.math.top_k(-x, kth + 1).indices\n\n    n = tf.shape(x)[-1]\n\n    mask = tf.reduce_sum(tf.one_hot(bottom_ind, n, dtype=tf.int32), axis=0)\n\n    indices = tf.where(mask)\n    updates = tf.squeeze(tf.zeros(tf.shape(indices)[0], dtype=tf.int32))\n\n    final_mask = tf.tensor_scatter_nd_update(x, indices, updates)\n\n    top_ind = tf.math.top_k(final_mask, tf.shape(x)[-1] - kth - 1).indices\n\n    out = tf.concat([bottom_ind, top_ind], axis=x.ndim - 1)\n    return swapaxes(out, -1, axis)\n\n\ndef histogram(x, bins=10, range=None):\n    \"\"\"Computes a histogram of the data tensor `x`.\n\n    Note: the `tf.histogram_fixed_width()` and\n    `tf.histogram_fixed_width_bins()` functions\n    yield slight numerical differences for some edge cases.\n    \"\"\"\n\n    x = tf.convert_to_tensor(x, dtype=x.dtype)\n\n    # Handle the range argument\n    if range is None:\n        min_val = tf.reduce_min(x)\n        max_val = tf.reduce_max(x)\n    else:\n        min_val, max_val = range\n\n    x = tf.boolean_mask(x, (x >= min_val) & (x <= max_val))\n    bin_edges = tf.linspace(min_val, max_val, bins + 1)\n    bin_edges = tf.cast(bin_edges, x.dtype)\n    bin_indices = tf.searchsorted(bin_edges[1:-1], x, side=\"right\")\n\n    # tf.math.bincount does not work with XLA in this case. So, we use\n    # `scatter_nd`.\n    bin_counts = tf.scatter_nd(\n        indices=tf.expand_dims(bin_indices, axis=-1),\n        updates=tf.ones_like(bin_indices, dtype=x.dtype),\n        shape=(bins,),\n    )\n    return bin_counts, bin_edges\n"
  },
  {
    "path": "keras/src/backend/tensorflow/optimizer.py",
    "content": "\"\"\"A class for Tensorflow specific optimizer logic.\n\nThe major behavior change for this class is for tf.distribute.\n\nIt will override methods from base Keras core Optimizer,\nwhich provide distribute specific functionality, e.g. variable\ncreation, loss reduction, etc.\n\"\"\"\n\nimport warnings\n\nimport tensorflow as tf\n\nfrom keras.src import backend\nfrom keras.src.backend.tensorflow.trackable import KerasAutoTrackable\nfrom keras.src.optimizers import base_optimizer\n\n\nclass TFOptimizer(KerasAutoTrackable, base_optimizer.BaseOptimizer):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._distribution_strategy = tf.distribute.get_strategy()\n\n    def add_variable_from_reference(\n        self, reference_variable, name=None, initializer=\"zeros\"\n    ):\n        if isinstance(reference_variable, backend.Variable):\n            colocate_var = reference_variable.value\n        else:\n            colocate_var = reference_variable\n\n        with self._distribution_strategy.extended.colocate_vars_with(\n            colocate_var\n        ):\n            return super().add_variable_from_reference(\n                reference_variable, name=name, initializer=initializer\n            )\n\n    def stateless_apply(self, optimizer_variables, grads, trainable_variables):\n        # This is mainly due to the interaction with tf.distribute.Strategy,\n        # which requires tf.Variable as the inputs for most of its APIs.\n        raise ValueError(\n            \"stateless_apply is not supported with the TensorFlow backend \"\n            \"(as it is incompatible with tf.distribute).\"\n        )\n\n    def assign(self, variable, value):\n        if isinstance(variable, backend.Variable):\n            variable = variable.value\n        value = tf.cast(value, variable.dtype)\n        if isinstance(value, tf.IndexedSlices):\n            variable.scatter_update(value)\n        else:\n            
variable.assign(value)\n\n    def assign_add(self, variable, value):\n        if isinstance(variable, backend.Variable):\n            variable = variable.value\n        value = tf.cast(value, variable.dtype)\n        if isinstance(value, tf.IndexedSlices):\n            variable.scatter_add(value)\n        else:\n            variable.assign_add(value)\n\n    def assign_sub(self, variable, value):\n        if isinstance(variable, backend.Variable):\n            variable = variable.value\n        value = tf.cast(value, variable.dtype)\n        if isinstance(value, tf.IndexedSlices):\n            variable.scatter_sub(value)\n        else:\n            variable.assign_sub(value)\n\n    def _var_key(self, variable):\n        if isinstance(variable, backend.Variable):\n            variable = variable.value  # Convert to tf.Variable\n        if hasattr(variable, \"_distributed_container\"):\n            variable = variable._distributed_container()\n        elif (\n            isinstance(variable, tf.__internal__.CompositeTensor)\n            and hasattr(variable, \"handle\")\n            and hasattr(variable.handle, \"_distributed_container\")\n        ):\n            # For ResourceVariables, the _distributed_container attribute\n            # is added to their handle tensors.\n            variable = variable.handle._distributed_container()\n        return variable._unique_id\n\n    def _apply_weight_decay(self, variables):\n        if self.weight_decay is None:\n            return\n\n        def distributed_apply_weight_decay(distribution, variables, **kwargs):\n            def weight_decay_fn(variable):\n                if self._use_weight_decay(variable):\n                    lr = tf.cast(self.learning_rate, variable.dtype)\n                    wd = tf.cast(self.weight_decay, variable.dtype)\n                    variable.assign_sub(variable * wd * lr)\n\n            for variable in variables:\n                if isinstance(variable, backend.Variable):\n                  
  variable = variable.value  # Convert to tf.Variable\n                distribution.extended.update(\n                    variable, weight_decay_fn, group=False\n                )\n\n        tf.__internal__.distribute.interim.maybe_merge_call(\n            distributed_apply_weight_decay,\n            self._distribution_strategy,\n            variables,\n        )\n\n    def _backend_update_step(self, grads, trainable_variables, learning_rate):\n        trainable_variables = [\n            v.value if isinstance(v, backend.Variable) else v\n            for v in trainable_variables\n        ]\n        grads_and_vars = list(zip(grads, trainable_variables))\n        grads_and_vars = self._all_reduce_sum_gradients(grads_and_vars)\n        tf.__internal__.distribute.interim.maybe_merge_call(\n            self._distributed_tf_update_step,\n            self._distribution_strategy,\n            grads_and_vars,\n            learning_rate,\n        )\n\n    def _distributed_tf_update_step(\n        self, distribution, grads_and_vars, learning_rate\n    ):\n        def apply_grad_to_update_var(var, grad, learning_rate):\n            return self.update_step(grad, var, learning_rate)\n\n        for grad, var in grads_and_vars:\n            distribution.extended.update(\n                var,\n                apply_grad_to_update_var,\n                args=(grad, learning_rate),\n                group=False,\n            )\n\n    def _all_reduce_sum_gradients(self, grads_and_vars):\n        \"\"\"Returns all-reduced gradients aggregated via summation.\n\n        Args:\n            grads_and_vars: List of (gradient, variable) pairs.\n\n        Returns:\n            List of (gradient, variable) pairs\n            where gradients have been all-reduced.\n        \"\"\"\n        replica_context = tf.distribute.get_replica_context()\n        if not replica_context:\n            return grads_and_vars\n\n        grads_and_vars = list(grads_and_vars)\n        filtered_grads_and_vars = 
filter_empty_gradients(grads_and_vars)\n        if filtered_grads_and_vars:\n            grads = [pair[0] for pair in filtered_grads_and_vars]\n            reduced = tf.distribute.get_replica_context().all_reduce(\n                tf.distribute.ReduceOp.SUM, grads\n            )\n        else:\n            reduced = []\n        # Copy 'reduced' but add None gradients back in\n        reduced_with_nones = []\n        reduced_pos = 0\n        for g, v in grads_and_vars:\n            if g is None:\n                reduced_with_nones.append((None, v))\n            else:\n                reduced_with_nones.append((reduced[reduced_pos], v))\n                reduced_pos += 1\n        if reduced_pos != len(reduced):\n            raise ValueError(\n                \"Internal error: Failed to add all gradients. Expected to \"\n                f\"process {len(reduced)} gradients, but processed \"\n                f\"{reduced_pos}.\"\n            )\n        return reduced_with_nones\n\n    def _overwrite_model_variables_with_average_value(\n        self, trainable_variables\n    ):\n        \"\"\"Overwrite model variables with their moving average values.\n\n        This function overwrites variables on each device.\n\n        Args:\n          var_list: list of model variables.\n        \"\"\"\n        trainable_variables = [\n            v.value if isinstance(v, backend.Variable) else v\n            for v in trainable_variables\n        ]\n        # Override model variable by the stored average value on all devices.\n        for var, average_var in zip(\n            trainable_variables, self._model_variables_moving_average\n        ):\n            self._distribution_strategy.extended.update(\n                var, lambda a, b: a.assign(b), args=(average_var,)\n            )\n\n    def _backend_increment_gradient_accumulators(self, grads, acc_grads):\n        def update_accumulator(var, grad):\n            var.assign(var + grad)\n\n        accumulators = [v.value for v in 
  acc_grads]\n\n        def _distributed_tf_increment_grad_acc(\n            distribution, grads, accumulators\n        ):\n            for grad, var in zip(grads, accumulators):\n                distribution.extended.update(\n                    var, update_accumulator, args=(grad,), group=False\n                )\n\n        tf.__internal__.distribute.interim.maybe_merge_call(\n            _distributed_tf_increment_grad_acc,\n            self._distribution_strategy,\n            grads,\n            accumulators,\n        )\n\n    def _clip_by_norm(self, values, axes=None):\n        # We need to use a TF-specific op to support the case where\n        # `values` is a `tf.IndexedSlices`.\n        return tf.clip_by_norm(values, self.clipnorm, axes)\n\n\ndef filter_empty_gradients(grads_and_vars):\n    \"\"\"Filter out `(grad, var)` pairs that have a gradient equal to `None`.\"\"\"\n    grads_and_vars = tuple(grads_and_vars)\n    if not grads_and_vars:\n        return grads_and_vars\n\n    filtered = []\n    vars_with_empty_grads = []\n    for grad, var in grads_and_vars:\n        if grad is None:\n            vars_with_empty_grads.append(var)\n        else:\n            filtered.append((grad, var))\n    filtered = tuple(filtered)\n\n    if not filtered:\n        variable_names = [v.name for _, v in grads_and_vars]\n        raise ValueError(\n            f\"No gradients provided for any variable: {variable_names}. \"\n            f\"Provided `grads_and_vars` is {grads_and_vars}.\"\n        )\n    if vars_with_empty_grads:\n        # Use an f-string: passing a second positional argument to\n        # `warnings.warn` would be interpreted as the warning category.\n        warnings.warn(\n            \"Gradients do not exist for variables \"\n            f\"{[v.name for v in vars_with_empty_grads]} when minimizing the \"\n            \"loss. If you're using `model.compile()`, did you forget to \"\n            \"provide a `loss` argument?\"\n        )\n    return filtered\n"
  },
  {
    "path": "keras/src/backend/tensorflow/optimizer_distribute_test.py",
    "content": "# flake8: noqa\n\nimport numpy as np\nimport pytest\nimport tensorflow as tf\nfrom absl.testing import parameterized\nfrom tensorflow.python.eager import context\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.optimizers.sgd import SGD\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"tensorflow\",\n    reason=\"The distribute test can only run with TF backend.\",\n)\nclass OptimizerDistributeTest(testing.TestCase):\n    def setUp(self):\n        super().setUp()\n        # Need at least 2 devices for distribution related tests.\n        cpus = tf.config.list_physical_devices(\"CPU\")\n        context._reset_context()\n        tf.config.set_logical_device_configuration(\n            cpus[0],\n            [\n                tf.config.LogicalDeviceConfiguration(),\n                tf.config.LogicalDeviceConfiguration(),\n            ],\n        )\n        self.strategy = tf.distribute.MirroredStrategy([\"CPU:0\", \"CPU:1\"])\n\n    def test_config(self):\n        with self.strategy.scope():\n            optimizer = SGD(\n                learning_rate=0.5,\n                momentum=0.06,\n                nesterov=True,\n                weight_decay=0.004,\n            )\n        self.run_class_serialization_test(optimizer)\n\n    @parameterized.parameters([(\"keras_sgd\",), (\"tf_keras_sgd\",)])\n    def test_single_step(self, optimizer_type):\n        if optimizer_type == \"tf_keras_sgd\":\n            try:\n                import tf_keras\n\n                optimizer_fn = tf_keras.optimizers.SGD\n            except (ImportError, AttributeError):\n                self.skipTest(\"tf_keras not installed\")\n        else:\n            optimizer_fn = SGD\n        with self.strategy.scope():\n            optimizer = optimizer_fn(\n                learning_rate=0.5,\n                momentum=0.06,\n            )\n            # use tf variable to work both in k2 & k3.\n            vars = tf.Variable([1.0, 2.0, 3.0, 
4.0])\n\n            def update():\n                grads = tf.constant([1.0, 6.0, 7.0, 2.0])\n                optimizer.apply_gradients(zip([grads], [vars]))\n\n            self.strategy.run(update)\n            self.assertAllClose(\n                vars, [0.0, -4.0, -4.0, 2.0], rtol=1e-4, atol=1e-4\n            )\n\n    def test_weight_decay(self):\n        with self.strategy.scope():\n            grads, var1, var2, var3 = (\n                tf.zeros(()),\n                backend.Variable(2.0),\n                backend.Variable(3.0, name=\"exclude\"),\n                backend.Variable(4.0),\n            )\n            optimizer_1 = SGD(learning_rate=1.0, weight_decay=0.004)\n            self.strategy.run(\n                lambda: optimizer_1.apply_gradients(zip([grads], [var1]))\n            )\n\n            optimizer_2 = SGD(learning_rate=1.0, weight_decay=0.004)\n\n            def opt2_run():\n                optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n                optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n            self.strategy.run(opt2_run)\n\n            optimizer_3 = SGD(learning_rate=1.0, weight_decay=0.004)\n\n            def opt3_run():\n                optimizer_3.exclude_from_weight_decay(var_list=[var3])\n                optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))\n\n            self.strategy.run(opt3_run)\n\n        self.assertAlmostEqual(var1.numpy(), 1.9760959)\n        self.assertAlmostEqual(var2.numpy(), 3.0)\n        self.assertAlmostEqual(var3.numpy(), 4.0)\n\n    def test_correctness_with_golden(self):\n        with self.strategy.scope():\n            optimizer = SGD(nesterov=True)\n            x = backend.Variable(np.ones([10]))\n\n            def update_grads():\n                grads = backend.convert_to_tensor(np.arange(0.1, 1.1, 0.1))\n                optimizer.apply_gradients(zip([grads], [x]))\n\n            def update_first_grads():\n                first_grads = 
backend.convert_to_tensor(np.full((10,), 0.01))\n                optimizer.apply_gradients(zip([first_grads], [x]))\n\n        # fmt: off\n        golden = np.array(\n            [\n                [0.9980, 0.9960, 0.9940, 0.9920, 0.9900, 0.9880, 0.9860, 0.9840, 0.9820, 0.9800],\n                [0.9978, 0.9958, 0.9938, 0.9918, 0.9898, 0.9878, 0.9858, 0.9838, 0.9818, 0.9798],\n                [0.9976, 0.9956, 0.9936, 0.9916, 0.9896, 0.9876, 0.9856, 0.9836, 0.9816, 0.9796],\n                [0.9974, 0.9954, 0.9934, 0.9914, 0.9894, 0.9874, 0.9854, 0.9834, 0.9814, 0.9794],\n                [0.9972, 0.9952, 0.9932, 0.9912, 0.9892, 0.9872, 0.9852, 0.9832, 0.9812, 0.9792],\n            ]\n        )\n        # fmt: on\n\n        self.strategy.run(update_grads)\n        for i in range(5):\n            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            self.strategy.run(update_first_grads)\n\n    def test_clip_norm(self):\n        with self.strategy.scope():\n            optimizer = SGD(clipnorm=1)\n            grad = [np.array([100.0, 100.0])]\n            clipped_grad = optimizer._clip_gradients(grad)\n            self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        with self.strategy.scope():\n            optimizer = SGD(clipvalue=1)\n            grad = [np.array([100.0, 100.0])]\n            clipped_grad = optimizer._clip_gradients(grad)\n            self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n\n    def test_stateless_not_supported(self):\n        optimizer = SGD(learning_rate=0.5)\n        grads = [np.array([1.0, 6.0, 7.0, 2.0])]\n        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]\n        optimizer.build(vars)\n        with self.assertRaisesRegex(ValueError, \"not supported\"):\n            optimizer.stateless_apply(optimizer.variables, grads, vars)\n\n    def test_ema(self):\n        with self.strategy.scope():\n            v = backend.Variable([[3.0, 4.0], [5.0, 6.0]])\n            
grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])\n            optimizer = SGD(\n                learning_rate=1.0,\n                use_ema=True,\n                ema_momentum=0.9,\n                ema_overwrite_frequency=3,\n            )\n            self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)]))\n            self.assertAllClose(v, [[2.0, 3.0], [4.0, 5.0]])\n            self.assertAllClose(\n                optimizer._model_variables_moving_average[0],\n                [[2.0, 3.0], [4.0, 5.0]],  # initialized after first step\n            )\n            self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)]))\n            self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]])\n            self.assertAllClose(\n                optimizer._model_variables_moving_average[0],\n                [[1.9, 2.9], [3.9, 4.9]],\n            )\n            self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)]))\n            # Variables were overwritten with EMA\n            self.assertAllClose(v, [[1.71, 2.71], [3.71, 4.71]])\n            self.assertAllClose(\n                optimizer._model_variables_moving_average[0],\n                [[1.71, 2.71], [3.71, 4.71]],\n            )\n\n    def test_gradient_accumulation(self):\n        with self.strategy.scope():\n            v = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n            grads = backend.convert_to_tensor([[1.0, 1.0], [2.0, 2.0]])\n            optimizer = SGD(learning_rate=1.0, gradient_accumulation_steps=3)\n            self.assertEqual(optimizer.gradient_accumulation_steps, 3)\n            self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)]))\n            self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]])\n            self.assertAllClose(\n                optimizer._accumulated_gradients[0], [[1.0, 1.0], [2.0, 2.0]]\n            )\n            self.assertAllClose(optimizer._iterations, 1)\n            self.assertAllClose(optimizer.iterations, 0)\n            
self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)]))\n            self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]])\n            self.assertAllClose(\n                optimizer._accumulated_gradients[0], [[2.0, 2.0], [4.0, 4.0]]\n            )\n            self.assertAllClose(optimizer._iterations, 2)\n            self.assertAllClose(optimizer.iterations, 0)\n            self.strategy.run(lambda: optimizer.apply_gradients([(grads, v)]))\n            self.assertAllClose(v, [[-1.0, 0.0], [-1.0, 0.0]])\n            self.assertAllClose(\n                optimizer._accumulated_gradients[0], [[0.0, 0.0], [0.0, 0.0]]\n            )\n            self.assertAllClose(optimizer._iterations, 3)\n            self.assertAllClose(optimizer.iterations, 1)\n"
  },
  {
    "path": "keras/src/backend/tensorflow/random.py",
    "content": "import tensorflow as tf\n\nfrom keras.src.backend.common import standardize_dtype\nfrom keras.src.backend.config import floatx\nfrom keras.src.random.seed_generator import SeedGenerator\nfrom keras.src.random.seed_generator import draw_seed\nfrom keras.src.random.seed_generator import make_default_seed\n\n\ndef _cast_seed(seed):\n    # TensorFlow has a device placement issue that `Variable` must be int64\n    # in `SeedGenerator`. However, all `tf.random.stateless_*` expect the seed\n    # to be int32 to run with XLA.\n    # This function addresses the inconsistency using `floormod`.\n    # Ref: https://www.tensorflow.org/api_docs/python/tf/random\n    if standardize_dtype(seed.dtype) == \"int32\":\n        return seed\n    else:\n        seed = tf.cast(tf.math.floormod(seed, tf.int32.max - 1), dtype=\"int32\")\n        return seed\n\n\ndef normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = _cast_seed(draw_seed(seed))\n    return tf.random.stateless_normal(\n        shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed\n    )\n\n\ndef uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = _cast_seed(draw_seed(seed))\n    return tf.random.stateless_uniform(\n        shape=shape,\n        minval=tf.cast(minval, dtype),\n        maxval=tf.cast(maxval, dtype),\n        dtype=dtype,\n        seed=seed,\n    )\n\n\ndef categorical(logits, num_samples, dtype=\"int64\", seed=None):\n    seed = _cast_seed(draw_seed(seed))\n    output = tf.random.stateless_categorical(logits, num_samples, seed=seed)\n    return tf.cast(output, dtype)\n\n\ndef randint(shape, minval, maxval, dtype=\"int32\", seed=None):\n    intermediate_dtype = dtype\n    if standardize_dtype(dtype) not in [\"int32\", \"int64\"]:\n        intermediate_dtype = \"int64\"\n    seed = _cast_seed(draw_seed(seed))\n    output = tf.random.stateless_uniform(\n        shape=shape,\n        
minval=minval,\n        maxval=maxval,\n        dtype=intermediate_dtype,\n        seed=seed,\n    )\n    return tf.cast(output, dtype)\n\n\ndef truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = _cast_seed(draw_seed(seed))\n    return tf.random.stateless_truncated_normal(\n        shape=shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed\n    )\n\n\ndef _get_concrete_noise_shape(inputs, noise_shape):\n    if noise_shape is None:\n        return tf.shape(inputs)\n\n    concrete_inputs_shape = tf.shape(inputs)\n    concrete_noise_shape = []\n    for i, value in enumerate(noise_shape):\n        concrete_noise_shape.append(\n            concrete_inputs_shape[i] if value is None else value\n        )\n    return concrete_noise_shape\n\n\ndef dropout(inputs, rate, noise_shape=None, seed=None):\n    if rate == 1.0:\n        return tf.zeros_like(inputs)\n    if rate == 0.0:\n        return inputs\n    seed = _cast_seed(draw_seed(seed))\n    noise_shape = _get_concrete_noise_shape(inputs, noise_shape)\n    return tf.nn.experimental.stateless_dropout(\n        inputs,\n        rate=rate,\n        noise_shape=noise_shape,\n        seed=seed,\n    )\n\n\ndef shuffle(x, axis=0, seed=None):\n    seed = _cast_seed(draw_seed(seed))\n    indices = tf.argsort(\n        tf.random.stateless_uniform(shape=[tf.shape(x)[axis]], seed=seed)\n    )\n    return tf.gather(x, indices, axis=axis)\n\n\ndef gamma(shape, alpha, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = _cast_seed(draw_seed(seed))\n    # TODO: `tf.random.stateless_gamma` doesn't support bfloat16\n    intermediate_dtype = dtype\n    if standardize_dtype(dtype) == \"bfloat16\":\n        intermediate_dtype = \"float32\"\n    return tf.cast(\n        tf.random.stateless_gamma(\n            shape,\n            alpha=alpha,\n            dtype=intermediate_dtype,\n            seed=seed,\n        ),\n        dtype,\n    )\n\n\ndef binomial(shape, 
counts, probabilities, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    seed = _cast_seed(draw_seed(seed))\n    # TODO: `tf.random.stateless_binomial` doesn't support bfloat16\n    intermediate_dtype = dtype\n    if standardize_dtype(dtype) == \"bfloat16\":\n        intermediate_dtype = \"float32\"\n    return tf.cast(\n        tf.random.stateless_binomial(\n            shape=shape,\n            seed=seed,\n            counts=counts,\n            probs=probabilities,\n            output_dtype=intermediate_dtype,\n        ),\n        dtype,\n    )\n\n\ndef beta(shape, alpha, beta, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    # Since TensorFlow doesn't offer a beta distribution function,\n    # we use the formula U(a,b) = X(a) / (X(a) + Y(b)),\n    # where U(a,b) is a beta-distributed random variable with\n    # parameters a and b, and X(a) and Y(b) are gamma-distributed\n    # random variables with parameters a and b respectively.\n\n    # Additionally, we'll use two different seeds for our two\n    # gamma random variables to prevent any unintended\n    # dependencies and correlations between the generated values\n    # due to the usage of the same seed.\n    seed_1 = _cast_seed(draw_seed(seed))\n    # The choice of 12 is totally arbitrary, as we're\n    # incrementing the first drawn seed by a CONSTANT to\n    # ensure deterministic results.\n    seed_2 = seed_1 + 12\n\n    # TODO: `tf.random.stateless_gamma` doesn't support bfloat16\n    intermediate_dtype = dtype\n    if standardize_dtype(dtype) == \"bfloat16\":\n        intermediate_dtype = \"float32\"\n    alpha = tf.convert_to_tensor(alpha, dtype=intermediate_dtype)\n    beta = tf.convert_to_tensor(beta, dtype=intermediate_dtype)\n\n    # TensorFlow's tf.random.stateless_gamma has a somewhat unconventional\n    # implementation: it checks the broadcastability of alpha's shape with\n    # ONLY the RIGHTMOST dimension of the specified output 
shape instead of considering the whole.\n    # Consequently, it then results in errors for perfectly broadcastable shapes\n    # such as for output shape of (2, 3) and alpha shape of (1, 3)\n    # So to resolve this, we explicitly broadcast alpha and beta to shape before\n    # passing them to the stateless_gamma function.\n    alpha = tf.broadcast_to(alpha, shape)\n    beta = tf.broadcast_to(beta, shape)\n\n    gamma_a = tf.cast(\n        tf.random.stateless_gamma(\n            shape=shape, seed=seed_1, alpha=alpha, dtype=intermediate_dtype\n        ),\n        dtype,\n    )\n    gamma_b = tf.cast(\n        tf.random.stateless_gamma(\n            shape=shape, seed=seed_2, alpha=beta, dtype=intermediate_dtype\n        ),\n        dtype,\n    )\n    sample = gamma_a / (gamma_a + gamma_b)\n    return sample\n"
  },
  {
    "path": "keras/src/backend/tensorflow/rnn.py",
    "content": "import tensorflow as tf\n\nfrom keras.src import tree\n\n\ndef rnn(\n    step_function,\n    inputs,\n    initial_states,\n    go_backwards=False,\n    mask=None,\n    constants=None,\n    unroll=False,\n    input_length=None,\n    time_major=False,\n    zero_output_for_mask=False,\n    return_all_outputs=True,\n):\n    \"\"\"Iterates over the time dimension of a tensor.\n\n    Args:\n        step_function: RNN step function.\n            Args;\n                `input`; Tensor with shape `(samples, ...)` (no time dimension),\n                    representing input for the batch of samples at a certain\n                    time step.\n                `states`; List of tensors.\n            Returns;\n                `output`; Tensor with shape `(samples, output_dim)`\n                    (no time dimension).\n                `new_states`; List of tensors, same length and shapes\n                    as 'states'. The first state in the list must be the\n                    output tensor at the previous timestep.\n        inputs: Tensor of temporal data of shape `(samples, time, ...)`\n            (at least 3D), or nested tensors, and each of which has shape\n            `(samples, time, ...)`.\n        initial_states: Tensor with shape `(samples, state_size)`\n            (no time dimension), containing the initial values for the states\n            used in the step function. In the case that state_size is in a\n            nested shape, the shape of initial_states will also follow the\n            nested structure.\n        go_backwards: Boolean. 
If `True`, do the iteration over the time\n            dimension in reverse order and return the reversed sequence.\n        mask: Binary tensor with shape `(samples, time, 1)`,\n            with a zero for every element that is masked.\n        constants: List of constant values passed at each step.\n        unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.\n        input_length: An integer or a 1-D Tensor, depending on whether\n            the time dimension is fixed-length or not. In case of variable\n            length input, it is used for masking in case there's no mask\n            specified.\n        time_major: Boolean. If `True`, the inputs and outputs will be in shape\n            `(timesteps, batch, ...)`, whereas in the False case, it will be\n            `(batch, timesteps, ...)`. Using `time_major = True` is a bit more\n            efficient because it avoids transposes at the beginning and end of\n            the RNN calculation. However, most TensorFlow data is batch-major,\n            so by default this function accepts input and emits output in\n            batch-major form.\n        zero_output_for_mask: Boolean. If `True`, the output for masked timestep\n            will be zeros, whereas in the `False` case, output from previous\n            timestep is returned.\n        return_all_outputs: Boolean. If `True`, return the recurrent outputs for\n            all timesteps in the sequence. 
If `False`, only return the output\n            for the last timestep (which consumes less memory).\n\n    Returns:\n        A tuple, `(last_output, outputs, new_states)`.\n            - `last_output`: the latest output of the rnn,\n                with shape `(samples, ...)`.\n            - `outputs`:\n                - If `return_all_outputs=True`: a tensor with shape\n                  `(samples, time, ...)` where each entry `outputs[s, t]` is the\n                  output of the step function at time `t` for sample `s`\n                - Else, a tensor equal to `last_output` with shape\n                  `(samples, 1, ...)`\n            - `new_states`: list of tensors, latest states returned by\n                the step function, of shape `(samples, ...)`.\n    \"\"\"\n    input_length = input_length or inputs.shape[1]\n\n    def swap_batch_timestep(input_t):\n        # Swap the batch and timestep dim for the incoming tensor.\n        axes = list(range(len(input_t.shape)))\n        axes[0], axes[1] = 1, 0\n        return tf.transpose(input_t, axes)\n\n    if not time_major:\n        inputs = tree.map_structure(swap_batch_timestep, inputs)\n\n    flattened_inputs = tree.flatten(inputs)\n    time_steps = flattened_inputs[0].shape[0]\n    time_steps_t = (\n        tf.shape(flattened_inputs[0])[0] if time_steps is None else time_steps\n    )\n\n    for input_ in flattened_inputs:\n        input_.shape.with_rank_at_least(3)\n\n    if mask is not None:\n        if mask.dtype != tf.bool:\n            mask = tf.cast(mask, tf.bool)\n        if len(mask.shape) == 2:\n            mask = tf.expand_dims(mask, axis=-1)\n        if not time_major:\n            mask = swap_batch_timestep(mask)\n\n    if constants is None:\n        constants = []\n\n    # tf.where needs its condition tensor to be the same shape as its two\n    # result tensors, but in our case the condition (mask) tensor is\n    # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.\n    # So we 
need to broadcast the mask to match the shape of inputs.\n    # That's what the tile call does, it just repeats the mask along its\n    # second dimension n times.\n    def _expand_mask(mask_t, input_t, fixed_dim=1):\n        if tree.is_nested(mask_t):\n            raise ValueError(\n                f\"mask_t is expected to be tensor, but got {mask_t}\"\n            )\n        if tree.is_nested(input_t):\n            raise ValueError(\n                f\"input_t is expected to be tensor, but got {input_t}\"\n            )\n        rank_diff = len(input_t.shape) - len(mask_t.shape)\n        for _ in range(rank_diff):\n            mask_t = tf.expand_dims(mask_t, -1)\n        multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]\n        return tf.tile(mask_t, multiples)\n\n    if unroll:\n        if not time_steps:\n            raise ValueError(\"Unrolling requires a fixed number of timesteps.\")\n        states = tuple(initial_states)\n        successive_states = []\n        successive_outputs = []\n\n        # Process the input tensors. The input tensor need to be split on the\n        # time_step dim, and reverse if go_backwards is True. In the case of\n        # nested input, the input is flattened and then transformed\n        # individually.  
The result of this will be a tuple of lists, each of\n        # the item in tuple is list of the tensor with shape (batch, feature)\n        def _process_single_input_t(input_t):\n            input_t = tf.unstack(input_t)  # unstack for time_step dim\n            if go_backwards:\n                input_t.reverse()\n            return input_t\n\n        if tree.is_nested(inputs):\n            processed_input = tree.map_structure(\n                _process_single_input_t, inputs\n            )\n        else:\n            processed_input = (_process_single_input_t(inputs),)\n\n        def _get_input_tensor(time):\n            inp = [t_[time] for t_ in processed_input]\n            return tree.pack_sequence_as(inputs, inp)\n\n        if mask is not None:\n            mask_list = tf.unstack(mask)\n            if go_backwards:\n                mask_list.reverse()\n\n            for i in range(time_steps):\n                inp = _get_input_tensor(i)\n                mask_t = mask_list[i]\n                output, new_states = step_function(\n                    inp, tuple(states) + tuple(constants)\n                )\n                tiled_mask_t = _expand_mask(mask_t, output)\n\n                if not successive_outputs:\n                    prev_output = tf.zeros_like(output)\n                else:\n                    prev_output = successive_outputs[-1]\n\n                output = tf.where(tiled_mask_t, output, prev_output)\n\n                flat_states = tree.flatten(states)\n                flat_new_states = tree.flatten(new_states)\n                tiled_mask_t = tuple(\n                    _expand_mask(mask_t, s) for s in flat_states\n                )\n                flat_final_states = tuple(\n                    tf.where(m, s, ps)\n                    for m, s, ps in zip(\n                        tiled_mask_t, flat_new_states, flat_states\n                    )\n                )\n                states = tree.pack_sequence_as(states, flat_final_states)\n\n    
            if return_all_outputs:\n                    successive_outputs.append(output)\n                    successive_states.append(states)\n                else:\n                    successive_outputs = [output]\n                    successive_states = [states]\n            last_output = successive_outputs[-1]\n            new_states = successive_states[-1]\n            outputs = tf.stack(successive_outputs)\n\n            if zero_output_for_mask:\n                last_output = tf.where(\n                    _expand_mask(mask_list[-1], last_output),\n                    last_output,\n                    tf.zeros_like(last_output),\n                )\n                outputs = tf.where(\n                    _expand_mask(mask, outputs, fixed_dim=2),\n                    outputs,\n                    tf.zeros_like(outputs),\n                )\n\n        else:  # mask is None\n            for i in range(time_steps):\n                inp = _get_input_tensor(i)\n                output, states = step_function(\n                    inp, tuple(states) + tuple(constants)\n                )\n                if return_all_outputs:\n                    successive_outputs.append(output)\n                    successive_states.append(states)\n                else:\n                    successive_outputs = [output]\n                    successive_states = [states]\n            last_output = successive_outputs[-1]\n            new_states = successive_states[-1]\n            outputs = tf.stack(successive_outputs)\n\n    else:  # Unroll == False\n        states = tuple(initial_states)\n\n        # Create input tensor array, if the inputs is nested tensors, then it\n        # will be flattened first, and tensor array will be created one per\n        # flattened tensor.\n        input_ta = tuple(\n            tf.TensorArray(\n                dtype=inp.dtype,\n                size=time_steps_t,\n                tensor_array_name=f\"input_ta_{i}\",\n            )\n            for i, 
inp in enumerate(flattened_inputs)\n        )\n        input_ta = tuple(\n            (\n                ta.unstack(input_)\n                if not go_backwards\n                else ta.unstack(tf.reverse(input_, [0]))\n            )\n            for ta, input_ in zip(input_ta, flattened_inputs)\n        )\n\n        # Get the time(0) input and compute the output for that, the output will\n        # be used to determine the dtype of output tensor array. Don't read from\n        # input_ta due to TensorArray clear_after_read default to True.\n        input_time_zero = tree.pack_sequence_as(\n            inputs, [inp[0] for inp in flattened_inputs]\n        )\n        # output_time_zero is used to determine the cell output shape and its\n        # dtype.  the value is discarded.\n        output_time_zero, _ = step_function(\n            input_time_zero, tuple(initial_states) + tuple(constants)\n        )\n\n        output_ta_size = time_steps_t if return_all_outputs else 1\n        output_ta = tuple(\n            tf.TensorArray(\n                dtype=out.dtype,\n                size=output_ta_size,\n                element_shape=out.shape,\n                tensor_array_name=f\"output_ta_{i}\",\n            )\n            for i, out in enumerate(tree.flatten(output_time_zero))\n        )\n\n        time = tf.constant(0, dtype=\"int32\", name=\"time\")\n\n        if input_length is None:\n            max_iterations = time_steps_t\n        else:\n            max_iterations = tf.reduce_max(input_length)\n\n        while_loop_kwargs = {\n            \"cond\": lambda time, *_: time < time_steps_t,\n            \"maximum_iterations\": max_iterations,\n            \"parallel_iterations\": 32,\n            \"swap_memory\": True,\n        }\n        if mask is not None:\n            if go_backwards:\n                mask = tf.reverse(mask, [0])\n\n            mask_ta = tf.TensorArray(\n                dtype=tf.bool, size=time_steps_t, tensor_array_name=\"mask_ta\"\n           
 )\n            mask_ta = mask_ta.unstack(mask)\n\n            def masking_fn(time):\n                return mask_ta.read(time)\n\n            def compute_masked_output(mask_t, flat_out, flat_mask):\n                tiled_mask_t = tuple(\n                    _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))\n                    for o in flat_out\n                )\n                return tuple(\n                    tf.where(m, o, fm)\n                    for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)\n                )\n\n        elif isinstance(input_length, tf.Tensor):\n            if go_backwards:\n                max_len = tf.reduce_max(input_length, axis=0)\n                rev_input_length = tf.subtract(max_len - 1, input_length)\n\n                def masking_fn(time):\n                    return tf.less(rev_input_length, time)\n\n            else:\n\n                def masking_fn(time):\n                    return tf.greater(input_length, time)\n\n            def compute_masked_output(mask_t, flat_out, flat_mask):\n                return tuple(\n                    tf.where(mask_t, o, zo)\n                    for (o, zo) in zip(flat_out, flat_mask)\n                )\n\n        else:\n            masking_fn = None\n\n        if masking_fn is not None:\n            # Mask for the T output will be based on the output of T - 1. 
In the\n            # case T = 0, a zero filled tensor will be used.\n            flat_zero_output = tuple(\n                tf.zeros_like(o) for o in tree.flatten(output_time_zero)\n            )\n\n            def _step(time, output_ta_t, prev_output, *states):\n                \"\"\"RNN step function.\n\n                Args:\n                    time: Current timestep value.\n                    output_ta_t: TensorArray.\n                    prev_output: tuple of outputs from time - 1.\n                    *states: List of states.\n\n                Returns:\n                    Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`\n                \"\"\"\n                current_input = tuple(ta.read(time) for ta in input_ta)\n                # maybe set shape.\n                current_input = tree.pack_sequence_as(inputs, current_input)\n                mask_t = masking_fn(time)\n                output, new_states = step_function(\n                    current_input, tuple(states) + tuple(constants)\n                )\n                # mask output\n                flat_output = tree.flatten(output)\n                flat_mask_output = (\n                    flat_zero_output\n                    if zero_output_for_mask\n                    else tree.flatten(prev_output)\n                )\n                flat_new_output = compute_masked_output(\n                    mask_t, flat_output, flat_mask_output\n                )\n\n                # mask states\n                flat_state = tree.flatten(states)\n                flat_new_state = tree.flatten(new_states)\n                flat_final_state = compute_masked_output(\n                    mask_t, flat_new_state, flat_state\n                )\n                new_states = tree.pack_sequence_as(new_states, flat_final_state)\n\n                ta_index_to_write = time if return_all_outputs else 0\n                output_ta_t = tuple(\n                    ta.write(ta_index_to_write, out)\n                  
  for ta, out in zip(output_ta_t, flat_new_output)\n                )\n\n                return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(\n                    new_states\n                )\n\n            final_outputs = tf.while_loop(\n                body=_step,\n                loop_vars=(time, output_ta, flat_zero_output) + states,\n                **while_loop_kwargs,\n            )\n            # Skip final_outputs[2] which is the output for final timestep.\n            new_states = final_outputs[3:]\n        else:\n\n            def _step(time, output_ta_t, *states):\n                \"\"\"RNN step function.\n\n                Args:\n                    time: Current timestep value.\n                    output_ta_t: TensorArray.\n                    *states: List of states.\n\n                Returns:\n                    Tuple: `(time + 1,output_ta_t) + tuple(new_states)`\n                \"\"\"\n                current_input = tuple(ta.read(time) for ta in input_ta)\n                current_input = tree.pack_sequence_as(inputs, current_input)\n                output, new_states = step_function(\n                    current_input, tuple(states) + tuple(constants)\n                )\n                flat_new_state = tree.flatten(new_states)\n\n                flat_output = tree.flatten(output)\n                ta_index_to_write = time if return_all_outputs else 0\n                output_ta_t = tuple(\n                    ta.write(ta_index_to_write, out)\n                    for ta, out in zip(output_ta_t, flat_output)\n                )\n\n                new_states = tree.pack_sequence_as(\n                    initial_states, flat_new_state\n                )\n                return (time + 1, output_ta_t) + tuple(new_states)\n\n            final_outputs = tf.while_loop(\n                body=_step,\n                loop_vars=(time, output_ta) + states,\n                **while_loop_kwargs,\n            )\n            new_states = 
final_outputs[2:]\n\n        output_ta = final_outputs[1]\n\n        outputs = tuple(o.stack() for o in output_ta)\n        last_output = tuple(o[-1] for o in outputs)\n\n        outputs = tree.pack_sequence_as(output_time_zero, outputs)\n        last_output = tree.pack_sequence_as(output_time_zero, last_output)\n\n    if not time_major:\n        outputs = tree.map_structure(swap_batch_timestep, outputs)\n\n    return last_output, outputs, new_states\n\n\ndef gru(\n    inputs,\n    initial_state,\n    mask,\n    kernel,\n    recurrent_kernel,\n    bias,\n    activation,\n    recurrent_activation,\n    return_sequences=False,\n    go_backwards=False,\n    unroll=False,\n    time_major=False,\n    reset_after=True,\n):\n    cudnn_supported = cudnn_ok(\n        activation,\n        recurrent_activation,\n        unroll,\n        use_bias=bias is not None,\n        reset_after=reset_after,\n    )\n    if not cudnn_supported:\n        raise NotImplementedError\n\n    from keras.src.backend.tensorflow import Variable\n\n    if isinstance(kernel, Variable):\n        kernel = kernel.value\n    if isinstance(recurrent_kernel, Variable):\n        recurrent_kernel = recurrent_kernel.value\n    if isinstance(bias, Variable):\n        bias = bias.value\n\n    try:\n        return _cudnn_gru(\n            inputs,\n            initial_state,\n            kernel,\n            recurrent_kernel,\n            bias,\n            mask,\n            time_major,\n            go_backwards,\n            return_sequences,\n        )\n    except tf.errors.InvalidArgumentError:\n        # cuDNN op not found.\n        raise NotImplementedError\n    except tf.errors.NotFoundError:\n        # alternative error: device not found for op\n        raise NotImplementedError\n\n\ndef _do_gru_arguments_support_cudnn(\n    activation,\n    recurrent_activation,\n    unroll,\n    use_bias,\n    reset_after,\n):\n    from keras.src import activations\n    from keras.src import ops\n\n    return (\n        
activation in (activations.tanh, tf.tanh, ops.tanh)\n        and recurrent_activation\n        in (activations.sigmoid, tf.sigmoid, ops.sigmoid)\n        and not unroll\n        and use_bias\n        and reset_after\n    )\n\n\ndef _do_lstm_arguments_support_cudnn(\n    activation,\n    recurrent_activation,\n    unroll,\n    use_bias,\n):\n    from keras.src import activations\n    from keras.src import ops\n\n    return (\n        activation in (activations.tanh, tf.tanh, ops.tanh)\n        and recurrent_activation\n        in (activations.sigmoid, tf.sigmoid, ops.sigmoid)\n        and not unroll\n        and use_bias\n    )\n\n\ndef _has_fully_masked_sequence(mask):\n    \"\"\"Check if input sequence contains any fully masked data.\n\n    cuDNN kernel will error out if the input sequence contains any fully masked\n    data. We work around this issue by rerouting the computation to the\n    standard kernel until the issue on the cuDNN side has been fixed. For a\n    fully masked sequence, it will contain all `False` values. To make it easy\n    to check, we invert the boolean and check if any of the sequences has all\n    `True` values.\n\n    Args:\n        mask: The mask tensor.\n\n    Returns:\n        A boolean tensor, `True` if the mask contains a fully masked sequence.\n    \"\"\"\n    return tf.reduce_any(\n        tf.reduce_all(tf.logical_not(tf.cast(mask, dtype=\"bool\")), axis=1)\n    )\n\n\ndef _assert_valid_mask(mask):\n    valid = tf.logical_and(\n        tf.logical_not(_has_fully_masked_sequence(mask)),\n        _is_sequence_right_padded(mask),\n    )\n    tf.Assert(\n        valid,\n        [\n            (\n                \"You are passing a RNN mask that does not correspond to \"\n                \"right-padded sequences, while using cuDNN, which is not \"\n                \"supported. With cuDNN, RNN masks can only be used for \"\n                \"right-padding, e.g. 
`[[True, True, False, False]]` would \"\n                \"be a valid mask, but any mask that isn't just contiguous \"\n                \"`True`'s on the left and contiguous `False`'s on the right \"\n                \"would be invalid. You can pass `use_cudnn=False` to your \"\n                \"RNN layer to stop using cuDNN (this may be slower).\"\n            )\n        ],\n    )\n\n\ndef _standardize_cudnn_weights(weights, biases, shape, transpose_weights=False):\n    \"\"\"Utility function to convert variables to cuDNN-compatible parameters.\n\n    Note that Keras weights for kernels are different from the cuDNN format.\n    Eg.:\n\n    ```\n      Keras                 cuDNN\n      [[0, 1, 2],  <--->  [[0, 2, 4],\n       [3, 4, 5]]          [1, 3, 5]]\n    ```\n\n    If the input weights need to be in a unified format, then set\n    `transpose_weights=True` to convert the weights.\n\n    Args:\n        weights: list of weights for the kernels and recurrent kernels.\n        biases: list of biases for the individual gates.\n        shape: the shape for the converted variables that will be fed to cuDNN.\n        transpose_weights: boolean, whether to transpose the weights.\n\n    Returns:\n        The converted weights that can be fed to cuDNN ops as params.\n    \"\"\"\n\n    def convert(w):\n        return tf.transpose(w) if transpose_weights else w\n\n    weights = [tf.reshape(convert(x), shape) for x in weights]\n    biases = [tf.reshape(x, shape) for x in biases]\n    return tf.concat(weights + biases, axis=0)\n\n\ndef _is_sequence_right_padded(mask):\n    \"\"\"Check the mask tensor and see if it is right padded.\n\n    cuDNN uses the sequence length param to skip the trailing\n    timesteps. 
If the data is left padded, or not strictly right padded (has a\n    masked value in the middle of the sequence), then cuDNN won't work\n    properly in those cases.\n\n    Left padded data: [[False, False, True, True, True]].\n    Right padded data: [[True, True, True, False, False]].\n    Mixture of mask/unmasked data: [[True, False, True, False, False]].\n\n    Note that for the mixed data example above, the actual data the RNN should\n    see is those 2 Trues (index 0 and 2); the index 1 False should be ignored\n    and not pollute the internal states.\n\n    Args:\n        mask: the Boolean tensor with shape [batch, timestep]\n\n    Returns:\n        boolean scalar tensor, whether the mask is strictly right padded.\n    \"\"\"\n    max_seq_length = tf.shape(mask)[1]\n    count_of_true = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)\n    right_padded_mask = tf.sequence_mask(count_of_true, maxlen=max_seq_length)\n    return tf.reduce_all(\n        tf.equal(\n            tf.cast(mask, dtype=\"bool\"),\n            tf.cast(right_padded_mask, dtype=\"bool\"),\n        )\n    )\n\n\ndef _compute_sequence_length_from_mask(mask, time_major):\n    \"\"\"Calculate the sequence length tensor (1-D) based on the masking tensor.\n\n    The masking tensor is a 2D boolean tensor with shape [batch, timestep]. For\n    any timestep that should be masked, the corresponding field will be False.\n    Consider the following example:\n      a = [[True, True, False, False],\n           [True, True, True, False]]\n    It is a (2, 4) tensor, and the corresponding sequence length result should\n    be a 1D tensor with value [2, 3]. 
Note that the masking tensor must be right\n    padded, which can be checked by, e.g., `is_sequence_right_padded()`.\n\n    Args:\n        mask: Boolean tensor with shape [batch, timestep] or [timestep, batch]\n            if time_major=True.\n        time_major: Boolean, which indicates whether the mask is time major or\n            batch major.\n\n    Returns:\n        sequence_length: 1D int32 tensor.\n    \"\"\"\n    timestep_index = 0 if time_major else 1\n    return tf.reduce_sum(tf.cast(mask, tf.int32), axis=timestep_index)\n\n\ndef _is_gpu_available():\n    return bool(tf.config.list_logical_devices(\"GPU\"))\n\n\ndef _cudnn_gru(\n    inputs,\n    initial_state,\n    kernel,\n    recurrent_kernel,\n    bias,\n    mask,\n    time_major,\n    go_backwards,\n    return_sequences,\n):\n    \"\"\"GRU with cuDNN implementation, which is only available on GPU.\"\"\"\n    if mask is not None:\n        _assert_valid_mask(mask)\n        sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)\n    else:\n        if time_major:\n            batch_dim = tf.shape(inputs)[1]\n            max_sequence_length = tf.shape(inputs)[0]\n        else:\n            batch_dim = tf.shape(inputs)[0]\n            max_sequence_length = tf.shape(inputs)[1]\n        sequence_lengths = tf.fill([batch_dim], max_sequence_length)\n\n    if not time_major and sequence_lengths is None:\n        inputs = tf.transpose(inputs, perm=(1, 0, 2))\n        seq_axis, batch_axis = (0, 1)\n    else:\n        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)\n\n    # For init_h, cuDNN expects one more dim of num_layers before or after batch\n    # dim for time major or batch major inputs respectively\n    init_h = tf.expand_dims(initial_state, axis=seq_axis)\n\n    weights = tf.split(kernel, 3, axis=1)\n    weights += tf.split(recurrent_kernel, 3, axis=1)\n    # Note that the bias was initialized as shape (2, 3 * units), flatten it to\n    # (6 * units)\n    bias = 
tf.split(tf.reshape(bias, [-1]), 6)\n\n    if tf.sysconfig.get_build_info()[\"is_cuda_build\"]:\n        # Note that the gate order for cuDNN is different from the canonical\n        # format. The canonical format is [z, r, h], whereas cuDNN is [r, z, h].\n        # The swap needs to be done for kernel, recurrent_kernel, input_bias,\n        # recurrent_bias.\n        # z is update gate weights.\n        # r is reset gate weights.\n        # h is the candidate (new) gate weights.\n        weights[0], weights[1] = weights[1], weights[0]\n        weights[3], weights[4] = weights[4], weights[3]\n        bias[0], bias[1] = bias[1], bias[0]\n        bias[3], bias[4] = bias[4], bias[3]\n\n    params = _standardize_cudnn_weights(\n        weights=weights,\n        biases=bias,\n        shape=tf.constant([-1]),\n        transpose_weights=True,\n    )\n\n    if go_backwards:\n        # Three reversals are required. E.g.,\n        # normal input = [1, 2, 3, 0, 0]  # where 0s need to be masked\n        # reversed_input_to_cudnn = [3, 2, 1, 0, 0]\n        # output_from_cudnn = [6, 5, 4, 0, 0]\n        # expected_output = [0, 0, 6, 5, 4]\n        inputs = tf.reverse_sequence(\n            inputs,\n            sequence_lengths,\n            seq_axis=seq_axis,\n            batch_axis=batch_axis,\n        )\n    outputs, h, _, _, _ = tf.raw_ops.CudnnRNNV3(\n        input=inputs,\n        input_h=init_h,\n        input_c=0,\n        params=params,\n        is_training=True,\n        rnn_mode=\"gru\",\n        sequence_lengths=sequence_lengths,\n        time_major=time_major,\n    )\n    if go_backwards:\n        outputs = tf.reverse_sequence(\n            outputs,\n            sequence_lengths,\n            seq_axis=seq_axis,\n            batch_axis=batch_axis,\n        )\n        outputs = tf.reverse(outputs, axis=[seq_axis])\n\n    last_output = outputs[-1]\n    if not time_major and sequence_lengths is None and return_sequences:\n        outputs = tf.transpose(outputs, perm=[1, 0, 2])\n    state 
= tf.squeeze(h, axis=seq_axis)\n\n    # In the case of variable length input, the cuDNN kernel will fill zeros for\n    # the output, whereas the default Keras behavior is to carry over the\n    # previous output from t-1, so that in the return_sequences=False case, the\n    # user can quickly get the final effective output instead of just 0s at the\n    # last timestep.  In order to mimic the default Keras behavior, we copy the\n    # final h state as the last_output, since it is numerically the same as the\n    # output.\n    if sequence_lengths is not None:\n        last_output = state\n\n    # Match CPU return format\n    if not return_sequences:\n        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)\n\n    return (\n        last_output,\n        outputs,\n        [state],\n    )\n\n\ndef cudnn_ok(\n    activation,\n    recurrent_activation,\n    unroll,\n    use_bias,\n    reset_after=None,\n):\n    if reset_after is None:\n        args_supported = _do_lstm_arguments_support_cudnn(\n            activation=activation,\n            recurrent_activation=recurrent_activation,\n            unroll=unroll,\n            use_bias=use_bias,\n        )\n    else:\n        args_supported = _do_gru_arguments_support_cudnn(\n            activation=activation,\n            recurrent_activation=recurrent_activation,\n            unroll=unroll,\n            use_bias=use_bias,\n            reset_after=reset_after,\n        )\n    return args_supported and _is_gpu_available()\n\n\ndef lstm(\n    inputs,\n    initial_state_h,\n    initial_state_c,\n    mask,\n    kernel,\n    recurrent_kernel,\n    bias,\n    activation,\n    recurrent_activation,\n    return_sequences=False,\n    go_backwards=False,\n    unroll=False,\n    time_major=False,\n):\n    cudnn_supported = cudnn_ok(\n        activation, recurrent_activation, unroll, use_bias=bias is not None\n    )\n    if not cudnn_supported:\n        raise NotImplementedError\n\n    from keras.src.backend.tensorflow import 
Variable\n\n    if isinstance(kernel, Variable):\n        kernel = kernel.value\n    if isinstance(recurrent_kernel, Variable):\n        recurrent_kernel = recurrent_kernel.value\n    if isinstance(bias, Variable):\n        bias = bias.value\n\n    try:\n        return _cudnn_lstm(\n            inputs,\n            initial_state_h,\n            initial_state_c,\n            kernel,\n            recurrent_kernel,\n            bias,\n            mask,\n            time_major,\n            go_backwards,\n            return_sequences,\n        )\n    except tf.errors.InvalidArgumentError:\n        # cuDNN op not found.\n        raise NotImplementedError\n    except tf.errors.NotFoundError:\n        # alternative error: device not found for op\n        raise NotImplementedError\n\n\ndef _cudnn_lstm(\n    inputs,\n    initial_state_h,\n    initial_state_c,\n    kernel,\n    recurrent_kernel,\n    bias,\n    mask,\n    time_major,\n    go_backwards,\n    return_sequences,\n):\n    if mask is not None:\n        _assert_valid_mask(mask)\n        sequence_lengths = _compute_sequence_length_from_mask(mask, time_major)\n    else:\n        if time_major:\n            batch_dim = tf.shape(inputs)[1]\n            max_sequence_length = tf.shape(inputs)[0]\n        else:\n            batch_dim = tf.shape(inputs)[0]\n            max_sequence_length = tf.shape(inputs)[1]\n        sequence_lengths = tf.fill([batch_dim], max_sequence_length)\n\n    if not time_major and sequence_lengths is None:\n        inputs = tf.transpose(inputs, perm=(1, 0, 2))\n        seq_axis, batch_axis = (0, 1)\n    else:\n        seq_axis, batch_axis = (0, 1) if time_major else (1, 0)\n    # For init_h and init_c, cuDNN expects one more dim of num_layers before or\n    # after batch dim for time major or batch major inputs respectively\n    init_h = tf.expand_dims(initial_state_h, axis=seq_axis)\n    init_c = tf.expand_dims(initial_state_c, axis=seq_axis)\n\n    weights = tf.split(kernel, 4, axis=1)\n    
weights += tf.split(recurrent_kernel, 4, axis=1)\n    # cuDNN has an extra set of bias for inputs, we disable them (setting to 0),\n    # so that mathematically it is same as the canonical LSTM implementation.\n    full_bias = tf.concat((tf.zeros_like(bias), bias), 0)\n\n    if tf.sysconfig.get_build_info()[\"is_rocm_build\"]:\n        # ROCm MIOpen's weight sequence for LSTM is different from both\n        # canonical and cuDNN format\n        # MIOpen: [i, f, o, c] cuDNN/Canonical: [i, f, c, o]\n        # i is input gate weights.\n        # f is forget gate weights.\n        # o is output gate weights.\n        # c is cell gate weights.\n        weights = [weights[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]\n        # full_bias is a tensor of shape (8*n,)\n        full_bias = tf.split(full_bias, 8, axis=0)\n        full_bias = [full_bias[x] for x in (0, 1, 3, 2, 4, 5, 7, 6)]\n\n    params = _standardize_cudnn_weights(\n        weights=weights,\n        biases=tf.split(full_bias, 8),\n        shape=tf.constant([-1]),\n        transpose_weights=True,\n    )\n\n    if go_backwards:\n        # Three reversals are required. 
E.g.,\n        # normal input = [1, 2, 3, 0, 0]  # where 0 need to be masked\n        # reversed_input_to_cudnn = [3, 2, 1, 0, 0]\n        # output_from_cudnn = [6, 5, 4, 0, 0]\n        # expected_output = [0, 0, 6, 5 ,4]\n        inputs = tf.reverse_sequence(\n            inputs,\n            sequence_lengths,\n            seq_axis=seq_axis,\n            batch_axis=batch_axis,\n        )\n    outputs, h, c, _, _ = tf.raw_ops.CudnnRNNV3(\n        input=inputs,\n        input_h=init_h,\n        input_c=init_c,\n        params=params,\n        is_training=True,\n        rnn_mode=\"lstm\",\n        sequence_lengths=sequence_lengths,\n        time_major=time_major,\n    )\n    if go_backwards:\n        outputs = tf.reverse_sequence(\n            outputs,\n            sequence_lengths,\n            seq_axis=seq_axis,\n            batch_axis=batch_axis,\n        )\n        outputs = tf.reverse(outputs, axis=[seq_axis])\n\n    last_output = outputs[-1]\n    if not time_major and sequence_lengths is None and return_sequences:\n        outputs = tf.transpose(outputs, perm=[1, 0, 2])\n    h = tf.squeeze(h, axis=seq_axis)\n    c = tf.squeeze(c, axis=seq_axis)\n\n    # In the case of variable length input, the cudnn kernel will fill zeros for\n    # the output, whereas the default keras behavior is to bring over the\n    # previous output for t-1, so that in the return_sequence=False case, user\n    # can quickly get the final effect output instead just 0s at the last\n    # timestep.  In order to mimic the default keras behavior, we copy the final\n    # h state as the last_output, since it is numerically same as the output.\n    if sequence_lengths is not None:\n        last_output = h\n\n    # Match CPU return format\n    if not return_sequences:\n        outputs = tf.expand_dims(last_output, axis=0 if time_major else 1)\n\n    return (last_output, outputs, [h, c])\n"
  },
  {
    "path": "keras/src/backend/tensorflow/saved_model_test.py",
    "content": "\"\"\"Tests for SavedModel functionality under tf implementation.\"\"\"\n\nimport os\n\nimport numpy as np\nimport pytest\nimport tensorflow as tf\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import metrics\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src import testing\nfrom keras.src.saving import object_registration\nfrom keras.src.testing.test_utils import named_product\n\n\n@object_registration.register_keras_serializable(package=\"my_package\")\nclass CustomModelX(models.Model):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.dense1 = layers.Dense(1)\n        self.dense2 = layers.Dense(1)\n\n    def call(self, inputs):\n        out = self.dense1(inputs)\n        return self.dense2(out)\n\n    def one(self):\n        return 1\n\n\n@object_registration.register_keras_serializable(package=\"my_package\")\nclass CustomSignatureModel(models.Model):\n    def __init__(self):\n        super(CustomSignatureModel, self).__init__()\n        self.v = tf.Variable(1.0)\n\n    @tf.function\n    def __call__(self, x):\n        return x * self.v\n\n    @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])\n    def mutate(self, new_v):\n        self.v.assign(new_v)\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"tensorflow\",\n    reason=\"The SavedModel test can only run with TF backend.\",\n)\nclass SavedModelTest(testing.TestCase):\n    def test_sequential(self):\n        model = models.Sequential([layers.Dense(1)])\n        model.compile(loss=\"mse\", optimizer=\"adam\")\n        X_train = np.random.rand(100, 3)\n        y_train = np.random.rand(100, 1)\n        model.fit(X_train, y_train)\n        path = os.path.join(self.get_temp_dir(), \"my_keras_model\")\n        tf.saved_model.save(model, path)\n        restored_model = tf.saved_model.load(path)\n        
self.assertAllClose(\n            model(X_train),\n            restored_model.signatures[\"serving_default\"](\n                tf.convert_to_tensor(X_train, dtype=tf.float32)\n            )[\"output_0\"],\n            rtol=1e-4,\n            atol=1e-4,\n        )\n\n    def test_functional(self):\n        inputs = layers.Input(shape=(3,))\n        x = layers.Dense(1, name=\"first_dense\")(inputs)\n        outputs = layers.Dense(1, name=\"second_dense\")(x)\n        model = models.Model(inputs, outputs)\n        model.compile(\n            optimizer=\"adam\",\n            loss=\"mse\",\n        )\n        X_train = np.random.rand(100, 3)\n        y_train = np.random.rand(100, 1)\n        model.fit(X_train, y_train)\n        path = os.path.join(self.get_temp_dir(), \"my_keras_model\")\n        tf.saved_model.save(model, path)\n        restored_model = tf.saved_model.load(path)\n        self.assertAllClose(\n            model(X_train),\n            restored_model.signatures[\"serving_default\"](\n                tf.convert_to_tensor(X_train, dtype=tf.float32)\n            )[\"output_0\"],\n            rtol=1e-4,\n            atol=1e-4,\n        )\n\n    def test_subclassed(self):\n        model = CustomModelX()\n        model.compile(\n            optimizer=\"adam\",\n            loss=\"mse\",\n            metrics=[metrics.Hinge(), \"mse\"],\n        )\n        X_train = np.random.rand(100, 3)\n        y_train = np.random.rand(100, 1)\n        model.fit(X_train, y_train)\n        path = os.path.join(self.get_temp_dir(), \"my_keras_model\")\n        tf.saved_model.save(model, path)\n        restored_model = tf.saved_model.load(path)\n        self.assertAllClose(\n            model(X_train),\n            restored_model.signatures[\"serving_default\"](\n                tf.convert_to_tensor(X_train, dtype=tf.float32)\n            )[\"output_0\"],\n            rtol=1e-4,\n            atol=1e-4,\n        )\n\n    def test_custom_model_and_layer(self):\n        
@object_registration.register_keras_serializable(package=\"my_package\")\n        class CustomLayer(layers.Layer):\n            def __call__(self, inputs):\n                return inputs\n\n        @object_registration.register_keras_serializable(package=\"my_package\")\n        class Model(models.Model):\n            def __init__(self):\n                super().__init__()\n                self.layer = CustomLayer()\n\n            @tf.function(input_signature=[tf.TensorSpec([None, 1])])\n            def call(self, inputs):\n                return self.layer(inputs)\n\n        model = Model()\n        inp = np.array([[1.0]])\n        result = model(inp)\n        path = os.path.join(self.get_temp_dir(), \"my_keras_model\")\n        tf.saved_model.save(model, path)\n        restored_model = tf.saved_model.load(path)\n        self.assertAllClose(\n            result,\n            restored_model.call(inp),\n            rtol=1e-4,\n            atol=1e-4,\n        )\n\n    @parameterized.named_parameters(\n        named_product(struct_type=[\"tuple\", \"array\", \"dict\"])\n    )\n    def test_model_with_input_structure(self, struct_type):\n        class TupleModel(models.Model):\n            def call(self, inputs):\n                x, y = inputs\n                return x + ops.mean(y, axis=1)\n\n        class ArrayModel(models.Model):\n            def call(self, inputs):\n                x = inputs[0]\n                y = inputs[1]\n                return x + ops.mean(y, axis=1)\n\n        class DictModel(models.Model):\n            def call(self, inputs):\n                x = inputs[\"x\"]\n                y = inputs[\"y\"]\n                return x + ops.mean(y, axis=1)\n\n        input_x = tf.constant([1.0])\n        input_y = tf.constant([[1.0, 0.0, 2.0]])\n        if struct_type == \"tuple\":\n            model = TupleModel()\n            inputs = (input_x, input_y)\n        elif struct_type == \"array\":\n            model = ArrayModel()\n            inputs = 
[input_x, input_y]\n        elif struct_type == \"dict\":\n            model = DictModel()\n            inputs = {\"x\": input_x, \"y\": input_y}\n\n        result = model(inputs)\n        path = os.path.join(self.get_temp_dir(), \"my_keras_model\")\n        tf.saved_model.save(model, path)\n        restored_model = tf.saved_model.load(path)\n        outputs = restored_model.signatures[\"serving_default\"](\n            inputs=input_x, inputs_1=input_y\n        )\n        self.assertAllClose(result, outputs[\"output_0\"], rtol=1e-4, atol=1e-4)\n\n    def test_multi_input_model(self):\n        input_1 = layers.Input(shape=(3,))\n        input_2 = layers.Input(shape=(5,))\n\n        y1 = layers.Dense(1)(input_1)\n        y2 = layers.Dense(1)(input_2)\n        layer_2 = layers.Dense(1, activation=\"relu\")\n        output_1 = layer_2(y1)\n        output_2 = layer_2(y2)\n        model = models.Model([input_1, input_2], [output_1, output_2])\n\n        input_arr_1 = np.random.random((1, 3)).astype(\"float32\")\n        input_arr_2 = np.random.random((1, 5)).astype(\"float32\")\n\n        model = models.Model([input_1, input_2], [output_1, output_2])\n        path = os.path.join(self.get_temp_dir(), \"my_keras_model\")\n        outputs_1 = model(\n            inputs=[\n                tf.convert_to_tensor(input_arr_1, dtype=tf.float32),\n                tf.convert_to_tensor(input_arr_2, dtype=tf.float32),\n            ],\n        )\n        tf.saved_model.save(model, path)\n        restored_model = tf.saved_model.load(path)\n\n        outputs_2 = restored_model.signatures[\"serving_default\"](\n            inputs=tf.convert_to_tensor(input_arr_1, dtype=tf.float32),\n            inputs_1=tf.convert_to_tensor(input_arr_2, dtype=tf.float32),\n        )\n        self.assertAllClose(\n            outputs_1[0], outputs_2[\"output_0\"], rtol=1e-4, atol=1e-4\n        )\n        self.assertAllClose(\n            outputs_1[1], outputs_2[\"output_1\"], rtol=1e-4, atol=1e-4\n        
)\n\n    def test_multi_input_custom_model_and_layer(self):\n        @object_registration.register_keras_serializable(package=\"my_package\")\n        class CustomLayer(layers.Layer):\n            def build(self, *input_shape):\n                pass\n\n            def call(self, *input_list):\n                self.add_loss(input_list[-2] * 2)\n                return sum(input_list)\n\n        @object_registration.register_keras_serializable(package=\"my_package\")\n        class CustomModel(models.Model):\n            def build(self, *input_shape):\n                self.layer = CustomLayer()\n                self.layer.build(*input_shape)\n\n            @tf.function\n            def call(self, *inputs):\n                inputs = list(inputs)\n                return self.layer(*inputs)\n\n        model = CustomModel()\n        inp = [\n            tf.constant(i, shape=[1, 1], dtype=tf.float32) for i in range(1, 4)\n        ]\n        expected = model(*inp)\n        path = os.path.join(self.get_temp_dir(), \"my_keras_model\")\n        tf.saved_model.save(model, path)\n        restored_model = tf.saved_model.load(path)\n        output = restored_model.call(*inp)\n        self.assertAllClose(expected, output, rtol=1e-4, atol=1e-4)\n\n    def test_list_trackable_children_tracking(self):\n        @object_registration.register_keras_serializable(package=\"my_package\")\n        class CustomLayerList(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.sublayers = [\n                    layers.Dense(2),\n                    layers.Dense(2),\n                ]\n\n            def call(self, inputs):\n                x = inputs\n                for sublayer in self.sublayers:\n                    x = sublayer(x)\n                return x\n\n        inputs = layers.Input(shape=(1,))\n        outputs = CustomLayerList()(inputs)\n        model = models.Model(inputs, outputs)\n\n        inp = np.array([[1.0]])\n        expected 
= model(inp)\n        path = os.path.join(self.get_temp_dir(), \"my_keras_model\")\n        tf.saved_model.save(model, path)\n        restored_model = tf.saved_model.load(path)\n        self.assertAllClose(\n            expected,\n            restored_model.signatures[\"serving_default\"](\n                tf.convert_to_tensor(inp, dtype=tf.float32)\n            )[\"output_0\"],\n            rtol=1e-4,\n            atol=1e-4,\n        )\n\n    def test_dict_trackable_children_tracking(self):\n        @object_registration.register_keras_serializable(package=\"my_package\")\n        class CustomLayerDict(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.sublayers = {\n                    \"first_layer\": layers.Dense(2),\n                    \"second_layer\": layers.Dense(2),\n                }\n\n            def call(self, inputs):\n                x = inputs\n                for key, sublayer in self.sublayers.items():\n                    x = sublayer(x)\n                return x\n\n        inputs = layers.Input(shape=(1,))\n        outputs = CustomLayerDict()(inputs)\n        model = models.Model(inputs, outputs)\n\n        inp = np.array([[1.0]])\n        expected = model(inp)\n        path = os.path.join(self.get_temp_dir(), \"my_keras_model\")\n        tf.saved_model.save(model, path)\n        restored_model = tf.saved_model.load(path)\n        self.assertAllClose(\n            expected,\n            restored_model.signatures[\"serving_default\"](\n                tf.convert_to_tensor(inp, dtype=tf.float32)\n            )[\"output_0\"],\n            rtol=1e-4,\n            atol=1e-4,\n        )\n\n    def test_fixed_signature_string_dtype(self):\n        @object_registration.register_keras_serializable(package=\"my_package\")\n        class Adder(models.Model):\n            @tf.function(\n                input_signature=[tf.TensorSpec(shape=[], dtype=tf.string)]\n            )\n            def 
concat(self, x):\n                return x + x\n\n        model = Adder()\n        path = os.path.join(self.get_temp_dir(), \"my_keras_model\")\n        tf.saved_model.save(model, path)\n        restored_model = tf.saved_model.load(path)\n        self.assertEqual(model.concat(\"hello\"), restored_model.concat(\"hello\"))\n\n    def test_non_fixed_signature_string_dtype(self):\n        @object_registration.register_keras_serializable(package=\"my_package\")\n        class Adder(models.Model):\n            @tf.function\n            def concat(self, x):\n                return x + x\n\n        model = Adder()\n\n        no_fn_path = os.path.join(self.get_temp_dir(), \"my_keras_model_no_fn\")\n        tf.saved_model.save(model, no_fn_path)\n        restored_model = tf.saved_model.load(no_fn_path)\n        with self.assertRaisesRegex(ValueError, \"zero restored functions\"):\n            _ = restored_model.concat(\"hello\")\n\n        path = os.path.join(self.get_temp_dir(), \"my_keras_model\")\n        tf.saved_model.save(\n            model,\n            path,\n            signatures=model.concat.get_concrete_function(\n                tf.TensorSpec(shape=[], dtype=tf.string, name=\"string_input\")\n            ),\n        )\n        restored_model = tf.saved_model.load(path)\n        self.assertEqual(model.concat(\"hello\"), restored_model.concat(\"hello\"))\n\n    def test_fine_tuning(self):\n        model = CustomSignatureModel()\n        model_no_signatures_path = os.path.join(\n            self.get_temp_dir(), \"model_no_signatures\"\n        )\n        _ = model(tf.constant(0.0))\n\n        tf.saved_model.save(model, model_no_signatures_path)\n        restored_model = tf.saved_model.load(model_no_signatures_path)\n\n        self.assertLen(list(restored_model.signatures.keys()), 0)\n        self.assertEqual(restored_model(tf.constant(3.0)).numpy(), 3)\n        restored_model.mutate(tf.constant(2.0))\n        
self.assertEqual(restored_model(tf.constant(3.0)).numpy(), 6)\n        optimizer = optimizers.SGD(0.05)\n\n        def train_step():\n            with tf.GradientTape() as tape:\n                loss = (10.0 - restored_model(tf.constant(2.0))) ** 2\n            variables = tape.watched_variables()\n            grads = tape.gradient(loss, variables)\n            optimizer.apply_gradients(zip(grads, variables))\n            return loss\n\n        for _ in range(10):\n            # \"v\" approaches 5, \"loss\" approaches 0\n            loss = train_step()\n\n        self.assertAllClose(loss, 0.0, rtol=1e-2, atol=1e-2)\n        self.assertAllClose(restored_model.v.numpy(), 5.0, rtol=1e-2, atol=1e-2)\n\n    def test_signatures_path(self):\n        model = CustomSignatureModel()\n        model_with_signature_path = os.path.join(\n            self.get_temp_dir(), \"model_with_signature\"\n        )\n        call = model.__call__.get_concrete_function(\n            tf.TensorSpec(None, tf.float32)\n        )\n\n        tf.saved_model.save(model, model_with_signature_path, signatures=call)\n        restored_model = tf.saved_model.load(model_with_signature_path)\n        self.assertEqual(\n            list(restored_model.signatures.keys()), [\"serving_default\"]\n        )\n\n    def test_multiple_signatures_dict_path(self):\n        model = CustomSignatureModel()\n        model_multiple_signatures_path = os.path.join(\n            self.get_temp_dir(), \"model_with_multiple_signatures\"\n        )\n        call = model.__call__.get_concrete_function(\n            tf.TensorSpec(None, tf.float32)\n        )\n        signatures = {\n            \"serving_default\": call,\n            \"array_input\": model.__call__.get_concrete_function(\n                tf.TensorSpec([None], tf.float32)\n            ),\n        }\n\n        tf.saved_model.save(\n            model, model_multiple_signatures_path, signatures=signatures\n        )\n        restored_model = 
tf.saved_model.load(model_multiple_signatures_path)\n        self.assertEqual(\n            list(restored_model.signatures.keys()),\n            [\"serving_default\", \"array_input\"],\n        )\n"
  },
  {
    "path": "keras/src/backend/tensorflow/sparse.py",
"content": "import functools\n\nimport tensorflow as tf\n\nones_bool = functools.partial(tf.ones, dtype=tf.bool)\nones_int8 = functools.partial(tf.ones, dtype=tf.int8)\nzeros_int8 = functools.partial(tf.zeros, dtype=tf.int8)\nones_like_int8 = functools.partial(tf.ones_like, dtype=tf.int8)\nzeros_like_int8 = functools.partial(tf.zeros_like, dtype=tf.int8)\n\n\ndef sparse_to_dense(x, default_value=None):\n    x_shape = x.shape\n    if x_shape.rank == 0:\n        # Workaround for bug on GPU when sparse tensor represents a scalar.\n        if x.values.shape[0] == 0:\n            return tf.constant(default_value, dtype=x.dtype)\n        else:\n            return tf.reshape(x.values, ())\n    x = tf.sparse.to_dense(x, default_value=default_value)\n    x.set_shape(x_shape)\n    return x\n\n\ndef sparse_with_values(x, values):\n    x_shape = x.shape\n    x = tf.SparseTensor(x.indices, values, x.dense_shape)\n    x.set_shape(x_shape)\n    return x\n\n\ndef broadcast_scalar_to_sparse_shape(scalar, sparse):\n    output = tf.broadcast_to(scalar, sparse.dense_shape)\n    output.set_shape(sparse.shape)\n    return output\n\n\ndef sparse_subtract(x1, x2):\n    \"\"\"Subtraction for `tf.SparseTensor`s.\n\n    Either `x1` or `x2` or both can be `tf.SparseTensor`s.\n\n    Args:\n        x1: first tensor, to subtract from.\n        x2: second tensor, to subtract.\n    Returns:\n        The difference of `x1` and `x2`, which is a `tf.SparseTensor` if and\n        only if both `x1` and `x2` are `tf.SparseTensor`s.\n    \"\"\"\n    if isinstance(x2, tf.SparseTensor):\n        return tf.sparse.add(x1, tf.sparse.map_values(tf.negative, x2))\n    else:\n        return tf.sparse.add(x1, tf.negative(x2))\n\n\ndef sparse_union_indices_and_values(x1, x2_indices, x2_values=None):\n    \"\"\"Compute the indices for the union of the indices of the provided\n    `tf.SparseTensor` and another set of indices, and return the values\n    modified to cover the union indices.\n\n    Args:\n        x1: a `tf.SparseTensor`.\n        x2_indices: another set of indices, in the `tf.SparseTensor` format.\n        x2_values: (optional) values for the second set of indices.\n    Returns: A tuple containing:\n        - the indices for the union\n        - `x1` values for the union indices (some zeros were added)\n        - `x2` values for the union indices (some zeros were added) or `None` if\n          `x2_values` was `None`.\n    \"\"\"\n    # Add zeros at the x2 indices to x1 to create the union.\n    zeros2 = tf.SparseTensor(\n        x2_indices,\n        tf.zeros((tf.shape(x2_indices)[0],), x1.values.dtype),\n        x1.dense_shape,\n    )\n    x1_for_union = tf.sparse.add(x1, zeros2)\n    if x2_values is not None:\n        # Add zeros at the x1 indices to x2 to create the union.\n        x2 = tf.SparseTensor(x2_indices, x2_values, x1.dense_shape)\n        zeros1 = tf.sparse.map_values(tf.zeros_like, x1)\n        x2_for_union = tf.sparse.add(x2, zeros1)\n        return x1_for_union.indices, x1_for_union.values, x2_for_union.values\n    else:\n        return x1_for_union.indices, x1_for_union.values, None\n\n\ndef indexed_slices_union_indices_and_values(x1, x2_indices, x2_values=None):\n    \"\"\"Compute the indices for the union of two `tf.IndexedSlices` and modify\n    the values for these indices.\n\n    Args:\n        x1: the first `tf.IndexedSlices`.\n        x2_indices: the indices for the second `tf.IndexedSlices`.\n        x2_values: (optional) the values for the second `tf.IndexedSlices`.\n    Returns: A tuple containing:\n        - the indices for the union\n        - `x1` values for the union indices (some zeros were added)\n        - `x2` values for the union indices (some zeros were added) or `None` if\n          `x2_values` was `None`.\n    \"\"\"\n    # Compute the union of the indices by doing a logical or between the one-hot\n    # encoded indices for x1 and x2.\n    dim_0 = x1.dense_shape[0]\n    x1_indices_expanded = tf.expand_dims(x1.indices, axis=1)\n    x2_indices_expanded = tf.expand_dims(x2_indices, axis=1)\n    x1_indices_count = 
tf.shape(x1_indices_expanded)[0]\n    x2_indices_count = tf.shape(x2_indices_expanded)[0]\n    x1_indices_one_hot = tf.scatter_nd(\n        x1_indices_expanded,\n        ones_bool((x1_indices_count,)),\n        (dim_0,),\n    )\n    x2_indices_one_hot = tf.scatter_nd(\n        x2_indices_expanded,\n        ones_bool((x2_indices_count,)),\n        (dim_0,),\n    )\n    union_indices = tf.squeeze(\n        tf.where(tf.math.logical_or(x1_indices_one_hot, x2_indices_one_hot)),\n        axis=-1,\n    )\n    union_indices_count = tf.shape(union_indices)[0]\n\n    # Re-gather the values with extra zeros added at indices that are part of\n    # the union but were not in x1 or x2.\n    def values_for_union(indices_expanded, indices_count, values):\n        indices_indices = tf.scatter_nd(\n            indices_expanded,\n            tf.range(1, indices_count + 1),\n            (dim_0,),\n        )\n        to_union_indices = tf.gather(indices_indices, union_indices)\n        values_with_leading_zeros = tf.concat(\n            [tf.zeros_like(values[0:1]), values], axis=0\n        )\n        return tf.gather(values_with_leading_zeros, to_union_indices)\n\n    # Only recompute values if some indices were added.\n    x1_values_for_union_indices = tf.cond(\n        tf.equal(x1_indices_count, union_indices_count),\n        lambda: x1.values,\n        lambda: values_for_union(\n            x1_indices_expanded, x1_indices_count, x1.values\n        ),\n    )\n    if x2_values is not None:\n        x2_values_for_union_indices = tf.cond(\n            tf.equal(x2_indices_count, union_indices_count),\n            lambda: x2_values,\n            lambda: values_for_union(\n                x2_indices_expanded, x2_indices_count, x2_values\n            ),\n        )\n    else:\n        x2_values_for_union_indices = None\n\n    return (\n        union_indices,\n        x1_values_for_union_indices,\n        x2_values_for_union_indices,\n    )\n\n\ndef sparse_intersection_indices_and_values(x1, 
x2):\n    \"\"\"Compute the indices for the intersection of two `tf.SparseTensor`s and\n    modify the values for these indices.\n\n    Args:\n        x1: the first `tf.SparseTensor`.\n        x2: the second `tf.SparseTensor`.\n    Returns: A tuple containing:\n        - the indices for the intersection\n        - `x1` values for the intersection indices (some values were removed)\n        - `x2` values for the intersection indices (some values were removed)\n    \"\"\"\n    # Compute the intersection of indices in the form of a sparse\n    # tensor containing ones as values.\n    ones1 = tf.sparse.map_values(ones_like_int8, x1)\n    ones2 = tf.sparse.map_values(ones_like_int8, x2)\n    # tf.sets.intersection ignores the last dimension, so we\n    # need to add a dummy extra dimension and then remove it.\n    intersection_extra_dim = tf.sets.intersection(\n        tf.sparse.expand_dims(ones1, axis=-1),\n        tf.sparse.expand_dims(ones2, axis=-1),\n    )\n\n    def empty_intersection():\n        return (\n            tf.zeros((0, x1.shape.rank), dtype=tf.int64),\n            tf.zeros((0,), dtype=x1.values.dtype),\n            tf.zeros((0,), dtype=x2.values.dtype),\n        )\n\n    def non_empty_intersection():\n        intersection = tf.sparse.reshape(intersection_extra_dim, x1.dense_shape)\n\n        # Compute the masks to remove indices in x1 and x2 that are not\n        # in the intersection, then trim x1 and x2.\n        zeros1 = tf.sparse.map_values(zeros_like_int8, x1)\n        zeros2 = tf.sparse.map_values(zeros_like_int8, x2)\n        mask1 = tf.sparse.add(zeros1, intersection)\n        mask2 = tf.sparse.add(zeros2, intersection)\n        return (\n            intersection.indices,\n            tf.sparse.retain(x1, tf.cast(mask1.values, tf.bool)).values,\n            tf.sparse.retain(x2, tf.cast(mask2.values, tf.bool)).values,\n        )\n\n    return tf.cond(\n        tf.equal(tf.size(intersection_extra_dim), 0),\n        empty_intersection,\n      
  non_empty_intersection,\n    )\n\n\ndef indexed_slices_intersection_indices_and_values(x1, x2):\n    \"\"\"Compute the indices for the intersection of two `tf.IndexedSlices` and\n    modify the values for these indices.\n\n    Args:\n        x1: the first `tf.IndexedSlices`.\n        x2: the second `tf.IndexedSlices`.\n    Returns: A tuple containing:\n        - the indices for the intersection\n        - `x1` values for the intersection indices (some values were removed)\n        - `x2` values for the intersection indices (some values were removed)\n    \"\"\"\n    # Compute the intersection of the indices by doing a logical\n    # and between the one hot encoded indices for x1 and x2.\n    dim_0 = x1.dense_shape[0]\n    x1_indices_expanded = tf.expand_dims(x1.indices, axis=1)\n    x2_indices_expanded = tf.expand_dims(x2.indices, axis=1)\n    x1_indices_count = x1_indices_expanded.shape[0]\n    x2_indices_count = x2_indices_expanded.shape[0]\n    x1_indices_one_hot = tf.scatter_nd(\n        x1_indices_expanded,\n        ones_bool((x1_indices_count,)),\n        (dim_0,),\n    )\n    x2_indices_one_hot = tf.scatter_nd(\n        x2_indices_expanded,\n        ones_bool((x2_indices_count,)),\n        (dim_0,),\n    )\n    intersection_indices = tf.squeeze(\n        tf.where(tf.math.logical_and(x1_indices_one_hot, x2_indices_one_hot)),\n        axis=-1,\n    )\n    intersection_indices_count = tf.shape(intersection_indices)[0]\n\n    def empty_intersection():\n        return (\n            intersection_indices,\n            tf.zeros((0,) + x1.values.shape[1:], x1.dtype),\n            tf.zeros((0,) + x2.values.shape[1:], x2.dtype),\n        )\n\n    def non_empty_intersection():\n        # Re-gather sub parts of the values that are part of the intersection.\n        def values_for_intersection(indices_expanded, indices_count, values):\n            indices_indices = tf.scatter_nd(\n                indices_expanded,\n                tf.range(indices_count),\n             
   (dim_0,),\n            )\n            to_intersection_indices = tf.gather(\n                indices_indices, intersection_indices\n            )\n            return tf.gather(values, to_intersection_indices)\n\n        # Only recompute values if some indices were removed.\n        x1_values_for_intersection = tf.cond(\n            tf.equal(x1_indices_count, intersection_indices_count),\n            lambda: x1.values,\n            lambda: values_for_intersection(\n                x1_indices_expanded, x1_indices_count, x1.values\n            ),\n        )\n        x2_values_for_intersection = tf.cond(\n            tf.equal(x2_indices_count, intersection_indices_count),\n            lambda: x2.values,\n            lambda: values_for_intersection(\n                x2_indices_expanded, x2_indices_count, x2.values\n            ),\n        )\n\n        return (\n            intersection_indices,\n            x1_values_for_intersection,\n            x2_values_for_intersection,\n        )\n\n    return tf.cond(\n        tf.equal(intersection_indices_count, 0),\n        empty_intersection,\n        non_empty_intersection,\n    )\n\n\ndef densifying_unary(default_value):\n    \"\"\"Decorator to add support for `tf.SparseTensor` and `tf.IndexedSlices` to\n    a non-zero-preserving element-wise unary operator.\n\n    There are requirements on the operator for this decorator to work correctly:\n\n    - The operator must be element-wise\n    - The operator must be unary (one input tensor and one output tensor)\n    - The operator must return a tensor of the same shape.\n\n    Additional arguments to the function (besides the input tensor) are\n    supported. The returned result is a dense tensor and contains\n    `default_value` outside of the indices of the input tensor.\n\n    Args:\n        default_value: The value to use outside of indices. 
It must be the value\n        that the operator returns for zero values.\n    Returns:\n        Wrapped function that supports `tf.SparseTensor` and `tf.IndexedSlices`.\n    \"\"\"\n\n    def wrap_densifying_unary(func):\n        @functools.wraps(func)\n        def sparse_wrapper(x, *args, **kwargs):\n            if isinstance(x, tf.SparseTensor):\n                sparse_output = sparse_with_values(\n                    x, func(x.values, *args, **kwargs)\n                )\n                return sparse_to_dense(\n                    sparse_output,\n                    tf.cast(default_value, sparse_output.values.dtype),\n                )\n            elif isinstance(x, tf.IndexedSlices):\n                sparse_output_values = func(x.values, *args, **kwargs)\n                output = tf.fill(\n                    x.dense_shape,\n                    tf.cast(default_value, sparse_output_values.dtype),\n                )\n                return tf.tensor_scatter_nd_update(\n                    output, tf.expand_dims(x.indices, 1), sparse_output_values\n                )\n            return func(x, *args, **kwargs)\n\n        return sparse_wrapper\n\n    return wrap_densifying_unary\n\n\ndef elementwise_unary(func):\n    \"\"\"Decorator to add support for `tf.SparseTensor` and `tf.IndexedSlices` to\n    a zero-preserving element-wise unary operator.\n\n    There are requirements on the operator for this decorator to work correctly:\n\n    - The operator must be element-wise\n    - The operator must be unary (one input tensor and one output tensor)\n    - The operator must return a tensor of the same shape, and if it is a\n      `tf.SparseTensor` or `tf.IndexedSlices`, the indices of the result must be\n      the same. Therefore:\n        - Reduction operations are not supported (e.g. `mean`).\n        - Operations for which the result may be dense (e.g. `reciprocal`), or\n          the sparse indices depend on the inputs are not supported (e.g.\n          `clip`). 
This implies that `func(0)` must be 0.\n\n    Additional arguments to the function (besides the input tensor) are\n    supported as long as they cannot change the indices of the result. For\n    instance, `round` is supported, but `clip` is not supported as\n    `clip(x, 1.0, 2.0)` would always return a dense tensor.\n\n    Note that if an input sparse tensor contains zero values, the indices and\n    the zero values are preserved.\n\n    Args:\n        func: The function to wrap.\n    Returns:\n        Wrapped function that supports `tf.SparseTensor` and `tf.IndexedSlices`.\n    \"\"\"\n\n    @functools.wraps(func)\n    def sparse_wrapper(x, *args, **kwargs):\n        if isinstance(x, tf.SparseTensor):\n            return sparse_with_values(x, func(x.values, *args, **kwargs))\n        elif isinstance(x, tf.IndexedSlices):\n            return tf.IndexedSlices(\n                func(x.values, *args, **kwargs), x.indices, x.dense_shape\n            )\n        else:\n            return func(x, *args, **kwargs)\n\n    return sparse_wrapper\n\n\ndef elementwise_binary_union(sparse_op, densify_mixed=False):\n    \"\"\"Decorator to add support for `tf.SparseTensor` and `tf.IndexedSlices` to\n    an element-wise binary operator such that the indices present in the result\n    are the union of the indices in the two operands.\n\n    The primary use case for this is the `add` and `subtract` operators.\n\n    There are requirements on the operator for this decorator to work correctly:\n\n    - The operator must be element-wise.\n    - The operator must be binary (two input tensors and one output tensor).\n    - Both inputs must be of the same shape or one input must be a scalar.\n    - The output must be of the same shape as the (non-scalar) inputs.\n    - The indices of the output must be the union of the indices of the inputs.\n      This implies that `func(0, 0)` must be 0. 
As a result, if one operand is\n      dense or a scalar, then the result will be dense.\n\n    Additional arguments to the function (besides the input tensors) are not\n    supported.\n\n    Note that if the result of the operation is zero at some indices, including\n    because the operands were zero at these indices, the zeros and indices are\n    preserved.\n\n    Args:\n        sparse_op: implementation of the operation for `tf.SparseTensor`. Must\n            work if both of the operands are `tf.SparseTensor`s and can\n            optionally work if one of the operands is a `tf.SparseTensor` and\n            the other one is a dense tensor, see `densify_mixed`.\n        densify_mixed: if `True`, `sparse_op` does not support mixing a\n            `tf.SparseTensor` with a dense tensor, so the `tf.SparseTensor`\n            operand is densified before the operation in that case.\n    Returns:\n        Wrapped function that supports `tf.SparseTensor` and `tf.IndexedSlices`.\n    \"\"\"\n\n    def wrap_elementwise_binary_union(func):\n        @functools.wraps(func)\n        def sparse_wrapper(x1, x2):\n            if isinstance(x1, tf.SparseTensor):\n                if isinstance(x2, tf.SparseTensor):\n                    # x1 is a SparseTensor and x2 is a SparseTensor.\n                    if x1.indices is x2.indices:\n                        return sparse_with_values(\n                            x1, func(x1.values, x2.values)\n                        )\n                    else:\n                        output = sparse_op(x1, x2)\n                        output.set_shape(x1.shape)\n                        return output\n                else:\n                    # x1 is a SparseTensor.\n                    if densify_mixed:\n                        x1 = sparse_to_dense(x1)\n                    else:\n                        if not hasattr(x2, \"shape\") or len(x2.shape) == 0:\n                            # x2 is a scalar, broadcast.\n                            x2 
= broadcast_scalar_to_sparse_shape(x2, x1)\n                        return sparse_op(x1, x2)\n            elif isinstance(x2, tf.SparseTensor):\n                # x2 is a SparseTensor.\n                if densify_mixed:\n                    x2 = sparse_to_dense(x2)\n                else:\n                    if not hasattr(x1, \"shape\") or len(x1.shape) == 0:\n                        # x1 is a scalar, broadcast.\n                        x1 = broadcast_scalar_to_sparse_shape(x1, x2)\n                    return sparse_op(x1, x2)\n            elif isinstance(x1, tf.IndexedSlices):\n                if isinstance(x2, tf.IndexedSlices):\n                    # x1 is an IndexedSlices and x2 is an IndexedSlices.\n                    if x1.indices is x2.indices:\n                        return tf.IndexedSlices(\n                            func(x1.values, x2.values),\n                            x1.indices,\n                            x1.dense_shape,\n                        )\n                    else:\n                        # Compute the union of indices.\n                        (\n                            union_indices,\n                            x1_values_for_union,\n                            x2_values_for_union,\n                        ) = indexed_slices_union_indices_and_values(\n                            x1, x2.indices, x2.values\n                        )\n                        # Now, it is an element-wise operation on the union.\n                        return tf.IndexedSlices(\n                            func(\n                                x1_values_for_union,\n                                x2_values_for_union,\n                            ),\n                            union_indices,\n                            x1.dense_shape,\n                        )\n                else:\n                    # x1 is an IndexedSlices, densify.\n                    x1 = tf.convert_to_tensor(x1)\n            elif isinstance(x2, tf.IndexedSlices):\n       
         # x2 is an IndexedSlices, densify.\n                x2 = tf.convert_to_tensor(x2)\n            return func(x1, x2)\n\n        return sparse_wrapper\n\n    return wrap_elementwise_binary_union\n\n\ndef elementwise_binary_intersection(func):\n    \"\"\"Decorator to add support for `tf.SparseTensor` and `tf.IndexedSlices` to\n    an element-wise binary operator such that the indices present in the result\n    are the intersection of the indices in the two operands.\n\n    The primary use case for this is the `multiply` operator.\n\n    There are requirements on the operator for this decorator to work correctly:\n\n    - The operator must be element-wise.\n    - The operator must be binary (two input tensors and one output tensor).\n    - Both inputs must be of the same shape or one input must be a scalar.\n    - The output must be of the same shape as the (non-scalar) inputs.\n    - The indices of the output must be the intersection of the indices of the\n      inputs. This implies that `func(0, x)` and `func(x, 0)` must be 0 for any\n      `x`. As a result, if one operand is dense or a scalar, then the indices\n      are the ones from the other operand.\n\n    Additional arguments to the function (besides the input tensors) are not\n    supported.\n\n    Note that if the operands contain zero values at some common indices, the\n    indices and the zero values are preserved.\n\n    Args:\n        func: The function to wrap.\n    Returns:\n        Wrapped function that supports `tf.SparseTensor` and `tf.IndexedSlices`.\n    \"\"\"\n\n    @functools.wraps(func)\n    def sparse_wrapper(x1, x2):\n        if isinstance(x1, tf.SparseTensor):\n            if isinstance(x2, tf.SparseTensor):\n                # x1 is a SparseTensor and x2 is a SparseTensor.\n                if x1.indices is x2.indices:\n                    return sparse_with_values(x1, func(x1.values, x2.values))\n                else:\n                    # Compute the intersection of indices.\n            
        (\n                        intersection_indices,\n                        x1_values_for_intersection,\n                        x2_values_for_intersection,\n                    ) = sparse_intersection_indices_and_values(x1, x2)\n                    # Now, it is an element-wise operation on the intersection.\n                    output = tf.SparseTensor(\n                        intersection_indices,\n                        func(\n                            x1_values_for_intersection,\n                            x2_values_for_intersection,\n                        ),\n                        x1.dense_shape,\n                    )\n                    output.set_shape(x1.shape)\n                    return output\n            else:\n                # x1 is a SparseTensor.\n                if not hasattr(x2, \"shape\") or len(x2.shape) == 0:\n                    # x2 is a scalar, apply func element-wise.\n                    return sparse_with_values(x1, func(x1.values, x2))\n                else:\n                    # x2 is dense, gather values from x1 indices.\n                    return sparse_with_values(\n                        x1, func(x1.values, tf.gather_nd(x2, x1.indices))\n                    )\n        elif isinstance(x2, tf.SparseTensor):\n            # x2 is a SparseTensor.\n            if not hasattr(x1, \"shape\") or len(x1.shape) == 0:\n                # x1 is a scalar, apply func element-wise.\n                return sparse_with_values(x2, func(x1, x2.values))\n            else:\n                # x1 is dense, gather values from x2 indices.\n                return sparse_with_values(\n                    x2, func(tf.gather_nd(x1, x2.indices), x2.values)\n                )\n        elif isinstance(x1, tf.IndexedSlices):\n            if isinstance(x2, tf.IndexedSlices):\n                # x1 is an IndexedSlices and x2 is an IndexedSlices.\n                if x1.indices is x2.indices:\n                    return tf.IndexedSlices(\n             
           func(x1.values, x2.values), x1.indices, x1.dense_shape\n                    )\n                else:\n                    # Compute the intersection of indices.\n                    (\n                        intersection_indices,\n                        x1_values_for_intersection,\n                        x2_values_for_intersection,\n                    ) = indexed_slices_intersection_indices_and_values(x1, x2)\n                    # Now, it is an element-wise operation on the intersection.\n                    return tf.IndexedSlices(\n                        func(\n                            x1_values_for_intersection,\n                            x2_values_for_intersection,\n                        ),\n                        intersection_indices,\n                        x1.dense_shape,\n                    )\n            else:\n                # x1 is an IndexedSlices.\n                if not hasattr(x2, \"shape\") or len(x2.shape) == 0:\n                    # x2 is a scalar, apply func element-wise.\n                    return tf.IndexedSlices(\n                        func(x1.values, x2), x1.indices, x1.dense_shape\n                    )\n                else:\n                    # x2 is dense, gather values from x1 indices.\n                    return tf.IndexedSlices(\n                        func(x1.values, tf.gather(x2, x1.indices)),\n                        x1.indices,\n                        x1.dense_shape,\n                    )\n        elif isinstance(x2, tf.IndexedSlices):\n            # x2 is an IndexedSlices.\n            if not hasattr(x1, \"shape\") or len(x1.shape) == 0:\n                # x1 is a scalar, apply func element-wise.\n                return tf.IndexedSlices(\n                    func(x1, x2.values), x2.indices, x2.dense_shape\n                )\n            else:\n                # x1 is dense, gather values from x2 indices.\n                return tf.IndexedSlices(\n                    func(tf.gather(x1, 
x2.indices), x2.values),\n                    x2.indices,\n                    x2.dense_shape,\n                )\n        # Default case, no SparseTensor and no IndexedSlices.\n        return func(x1, x2)\n\n    return sparse_wrapper\n\n\ndef elementwise_division(func):\n    \"\"\"Decorator to add support for `tf.SparseTensor` and `tf.IndexedSlices` to\n    element-wise binary division and related operators.\n\n    This decorator is designed for operations related to the division of two\n    operands (e.g. `divide`). It accepts `tf.SparseTensor` and\n    `tf.IndexedSlices` for both the dividend and the divisor, but handles them\n    differently based on whether they are the dividend or the divisor.\n\n    - If the divisor is a `tf.SparseTensor` or `tf.IndexedSlices`, it is\n      densified and the result is dense because the result contains Inf or NaN\n      outside of the indices of the divisor.\n    - If the dividend is a `tf.SparseTensor` or `tf.IndexedSlices` and the\n      divisor is dense, it finds occurrences of zeros and NaNs in the divisor.\n      The result may therefore have more indices than there were in the dividend\n      to return correct values where the divisor was zero or NaN.\n    - If the dividend is a `tf.SparseTensor` or `tf.IndexedSlices` and the\n      divisor is a scalar, it does the division element-wise. Note that the\n      result is incorrectly sparse if the scalar divisor is zero.\n\n    Args:\n        func: The function to wrap.\n    Returns:\n        Wrapped function that supports `tf.SparseTensor` and `tf.IndexedSlices`.\n    \"\"\"\n\n    @functools.wraps(func)\n    def sparse_wrapper(x1, x2):\n        if isinstance(x1, tf.SparseTensor):\n            if isinstance(x2, tf.SparseTensor):\n                # x1 is a SparseTensor and x2 is a SparseTensor.\n                # Divisor is sparse, meaning we're doing divisions by zero\n                # outside of x2.indices, so the result is dense. 
Densify both.\n                x1 = sparse_to_dense(x1)\n                x2 = sparse_to_dense(x2)\n            else:\n                # x1 is a SparseTensor.\n                if not hasattr(x2, \"shape\") or len(x2.shape) == 0:\n                    # x2 is a scalar, apply func element-wise.\n                    return sparse_with_values(x1, func(x1.values, x2))\n                else:\n                    # x2 is dense.\n                    x2_zeros_and_nans = tf.equal(x2, 0)\n                    if not tf.as_dtype(x2.dtype).is_integer:\n                        x2_zeros_and_nans = tf.math.logical_or(\n                            x2_zeros_and_nans, tf.math.is_nan(x2)\n                        )\n\n                    def func_for_x1_indices():\n                        # Gather values from x1 indices.\n                        return sparse_with_values(\n                            x1, func(x1.values, tf.gather_nd(x2, x1.indices))\n                        )\n\n                    def func_for_union_indices():\n                        # Compute the union of indices to keep zeros and NaNs.\n                        x2_zeros_and_nan_indices = tf.where(x2_zeros_and_nans)\n                        (\n                            union_indices,\n                            x1_values_for_union,\n                            _,\n                        ) = sparse_union_indices_and_values(\n                            x1, x2_zeros_and_nan_indices\n                        )\n                        output = tf.SparseTensor(\n                            union_indices,\n                            func(\n                                x1_values_for_union,\n                                tf.gather_nd(x2, union_indices),\n                            ),\n                            x1.dense_shape,\n                        )\n                        output.set_shape(x1.shape)\n                        return output\n\n                    return tf.cond(\n                        
tf.reduce_any(x2_zeros_and_nans),\n                        func_for_union_indices,\n                        func_for_x1_indices,\n                    )\n        elif isinstance(x2, tf.SparseTensor):\n            # x2 is a SparseTensor.\n            # Divisor is sparse, densify to do the divisions by zero correctly.\n            x2 = sparse_to_dense(x2)\n        elif isinstance(x1, tf.IndexedSlices):\n            if isinstance(x2, tf.IndexedSlices):\n                # x1 is an IndexedSlices and x2 is an IndexedSlices.\n                # Divisor is slices, meaning we're doing divisions by zero\n                # outside of x2.indices, so the result is dense. Densify both.\n                x1 = tf.convert_to_tensor(x1)\n                x2 = tf.convert_to_tensor(x2)\n            else:\n                # x1 is an IndexedSlices.\n                if not hasattr(x2, \"shape\") or len(x2.shape) == 0:\n                    # x2 is a scalar, apply func element-wise.\n                    return tf.IndexedSlices(\n                        func(x1.values, x2), x1.indices, x1.dense_shape\n                    )\n                else:\n                    # x2 is dense.\n                    x2_zeros_and_nans = tf.equal(x2, 0)\n                    if not tf.as_dtype(x2.dtype).is_integer:\n                        x2_zeros_and_nans = tf.math.logical_or(\n                            x2_zeros_and_nans, tf.math.is_nan(x2)\n                        )\n                    x2_zeros_and_nans = tf.reduce_any(\n                        x2_zeros_and_nans, axis=tuple(range(1, x2.shape.rank))\n                    )\n\n                    def func_for_x1_indices():\n                        # Gather values from x1 indices.\n                        return tf.IndexedSlices(\n                            func(x1.values, tf.gather(x2, x1.indices)),\n                            x1.indices,\n                            x1.dense_shape,\n                        )\n\n                    def 
func_for_union_indices():\n                        x2_zeros_and_nan_indices = tf.squeeze(\n                            tf.where(x2_zeros_and_nans), axis=-1\n                        )\n                        # Compute the union of indices to keep zeros and NaNs.\n                        (\n                            union_indices,\n                            x1_values_for_union,\n                            _,\n                        ) = indexed_slices_union_indices_and_values(\n                            x1, x2_zeros_and_nan_indices\n                        )\n                        return tf.IndexedSlices(\n                            func(\n                                x1_values_for_union,\n                                tf.gather(x2, union_indices),\n                            ),\n                            union_indices,\n                            x1.dense_shape,\n                        )\n\n                    return tf.cond(\n                        tf.reduce_any(x2_zeros_and_nans),\n                        func_for_union_indices,\n                        func_for_x1_indices,\n                    )\n        elif isinstance(x2, tf.IndexedSlices):\n            # x2 is an IndexedSlices.\n            # Divisor is slices, densify to do the divisions by zero correctly.\n            x2 = tf.convert_to_tensor(x2)\n        # Default case, no SparseTensor and no IndexedSlices.\n        return func(x1, x2)\n\n    return sparse_wrapper\n"
  },
  {
    "path": "keras/src/backend/tensorflow/tensorboard.py",
    "content": "from keras.src.utils.module_utils import tensorflow as tf\n\n\ndef start_trace(logdir):\n    tf.profiler.experimental.start(logdir=logdir)\n\n\ndef stop_trace(save):\n    tf.profiler.experimental.stop(save=save)\n\n\ndef start_batch_trace(batch):\n    batch_trace_context = tf.profiler.experimental.Trace(\n        \"Profiled batch\", step_num=batch\n    )\n    batch_trace_context.__enter__()\n    return batch_trace_context\n\n\ndef stop_batch_trace(batch_trace_context):\n    batch_trace_context.__exit__(None, None, None)\n"
  },
  {
    "path": "keras/src/backend/tensorflow/trackable.py",
    "content": "import tensorflow as tf\n\nfrom keras.src.utils import tracking\n\n\nclass KerasAutoTrackable(tf.__internal__.tracking.AutoTrackable):\n    \"\"\"Manages dependencies on other objects with Keras tracking.\n\n    Similar to TF AutoTrackable, but disabling tracking is based\n    on tracking within Keras.\n\n    This serves as an interface between Keras tracking and TF tracking.\n    \"\"\"\n\n    def __setattr__(self, name, value):\n        \"\"\"Support self.foo = trackable syntax.\"\"\"\n        try:\n            if getattr(self, name) is value:\n                # Short circuit for `self.$x = self.$x`.\n                return\n        except AttributeError:\n            pass\n\n        if getattr(self, \"_self_setattr_tracking\", True):\n            value = sticky_attribute_assignment(\n                trackable=self, value=value, name=name\n            )\n        super().__setattr__(name, value)\n\n\ndef sticky_attribute_assignment(trackable, name, value):\n    \"\"\"Adds dependencies, called from __setattr__.\n\n    Args:\n        trackable: The object to add dependencies to (generally the one having\n        an attribute assigned).\n        name: The attribute name being assigned.\n        value: The value being assigned. 
Not necessarily a trackable object.\n\n    Returns:\n        The value which should be stored in the attribute.\n    \"\"\"\n    if isinstance(\n        value,\n        (\n            tracking.TrackedList,\n            tracking.TrackedDict,\n            tracking.TrackedOrderedDict,\n            tracking.TrackedSet,\n        ),\n    ) and hasattr(trackable, \"_tracked\"):\n        trackable._tracked.append(name)\n    if not tracking.is_tracking_enabled():\n        return value\n    if isinstance(value, tf.__internal__.tracking.Trackable):\n        trackable._track_trackable(  # pylint: disable=protected-access\n            value,\n            name=name,\n            # Allow the user to switch the Trackable which is tracked by this\n            # name, since assigning a new variable to an attribute has\n            # historically been fine (e.g. Adam did this).\n            overwrite=True,\n        )\n    return value\n"
  },
  {
    "path": "keras/src/backend/tensorflow/trainer.py",
    "content": "import contextlib\nimport functools\nimport warnings\n\nimport tensorflow as tf\nfrom tensorflow.python.eager import context as tf_context\n\nfrom keras.src import callbacks as callbacks_module\nfrom keras.src import metrics as metrics_module\nfrom keras.src import optimizers as optimizers_module\nfrom keras.src import tree\nfrom keras.src.backend import config\nfrom keras.src.losses import loss as loss_module\nfrom keras.src.trainers import trainer as base_trainer\nfrom keras.src.trainers.data_adapters import array_slicing\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.trainers.epoch_iterator import EpochIterator\nfrom keras.src.utils import traceback_utils\nfrom keras.src.utils.python_utils import pythonify_logs\n\n\nclass TensorFlowTrainer(base_trainer.Trainer):\n    def __init__(self):\n        super().__init__()\n        self.train_function = None\n        self.test_function = None\n        self.predict_function = None\n\n        # Specifies how many steps of the step_per_execution loop to unroll.\n        # Increasing this value can reduce kernel launch overhead,\n        # but will increase memory usage and compilation time.\n        self.unrolled_steps_per_execution = 1\n\n        # Model must be created under scope of DistStrat it will be trained\n        # with.\n        if tf.distribute.has_strategy():\n            self._distribute_strategy = tf.distribute.get_strategy()\n        else:\n            self._distribute_strategy = None\n\n    @property\n    def distribute_strategy(self):\n        return self._distribute_strategy or tf.distribute.get_strategy()\n\n    @property\n    def distribute_reduction_method(self):\n        return self._distribute_reduction_method or \"auto\"\n\n    @distribute_reduction_method.setter\n    def distribute_reduction_method(self, value):\n        self._distribute_reduction_method = value\n\n    def train_step(self, data):\n        x, y, sample_weight = 
data_adapter_utils.unpack_x_y_sample_weight(data)\n\n        # Forward pass\n        with tf.GradientTape() as tape:\n            if self._call_has_training_arg:\n                y_pred = self(x, training=True)\n            else:\n                y_pred = self(x)\n            loss = self._compute_loss(\n                x=x,\n                y=y,\n                y_pred=y_pred,\n                sample_weight=sample_weight,\n                training=True,\n            )\n            self._loss_tracker.update_state(\n                loss_module.unscale_loss_for_distribution(loss),\n                sample_weight=tf.shape(\n                    next(i for i in tree.flatten(x) if i is not None)\n                )[0],\n            )\n            if self.optimizer is not None:\n                loss = self.optimizer.scale_loss(loss)\n\n        # Compute gradients\n        if self.trainable_weights:\n            trainable_weights = self.trainable_weights\n            gradients = tape.gradient(loss, trainable_weights)\n\n            # Update weights\n            self.optimizer.apply_gradients(zip(gradients, trainable_weights))\n        else:\n            warnings.warn(\"The model does not have any trainable weights.\")\n\n        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)\n\n    def test_step(self, data):\n        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)\n        if self._call_has_training_arg:\n            y_pred = self(x, training=False)\n        else:\n            y_pred = self(x)\n        loss = self._compute_loss(\n            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False\n        )\n        self._loss_tracker.update_state(\n            loss_module.unscale_loss_for_distribution(loss),\n            sample_weight=tf.shape(\n                next(i for i in tree.flatten(x) if i is not None)\n            )[0],\n        )\n        return self.compute_metrics(x, y, y_pred, 
sample_weight=sample_weight)\n\n    def predict_step(self, data):\n        x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)\n        if self._call_has_training_arg:\n            y_pred = self(x, training=False)\n        else:\n            y_pred = self(x)\n        return y_pred\n\n    def _autoconvert_optionals(self, step_func):\n        # Wrapper converting (nested) TF Optional in input data to None\n        @functools.wraps(step_func)\n        def wrapper(data):\n            converted_data = tree.map_structure(\n                lambda i: (\n                    None if isinstance(i, tf.experimental.Optional) else i\n                ),\n                data,\n            )\n            result = step_func(converted_data)\n            return result\n\n        return wrapper\n\n    def _make_function(self, step_function):\n        @tf.autograph.experimental.do_not_convert\n        def one_step_on_data(data):\n            \"\"\"Runs a single training step on a batch of data.\"\"\"\n            outputs = self.distribute_strategy.run(step_function, args=(data,))\n            outputs = reduce_per_replica(\n                outputs,\n                self.distribute_strategy,\n                reduction=\"auto\",\n            )\n            return outputs\n\n        if not self.run_eagerly:\n            one_step_on_data = tf.function(\n                one_step_on_data,\n                reduce_retracing=True,\n                jit_compile=self.jit_compile,\n            )\n        one_step_on_data = self._autoconvert_optionals(one_step_on_data)\n\n        @tf.autograph.experimental.do_not_convert\n        def multi_step_on_iterator(iterator):\n            if self.steps_per_execution == 1:\n                return tf.experimental.Optional.from_value(\n                    one_step_on_data(iterator.get_next())\n                )\n\n            # the spec is set lazily during the tracing of `tf.while_loop`\n            empty_outputs = 
tf.experimental.Optional.empty(None)\n\n            def cond(execution_step, optional_outputs, next_optional_inputs):\n                return tf.logical_and(\n                    tf.less(execution_step, self.steps_per_execution),\n                    next_optional_inputs.has_value(),\n                )\n\n            def inner_body(\n                execution_step, optional_outputs, next_optional_inputs\n            ):\n                def has_next():\n                    next_optional_outputs = tf.experimental.Optional.from_value(\n                        one_step_on_data(next_optional_inputs.get_value())\n                    )\n                    empty_outputs._element_spec = (\n                        next_optional_outputs.element_spec\n                    )\n                    return next_optional_outputs\n\n                def no_has_next():\n                    optional_outputs._element_spec = empty_outputs._element_spec\n                    return optional_outputs\n\n                next_optional_outputs = tf.cond(\n                    tf.logical_and(\n                        tf.less(execution_step, self.steps_per_execution),\n                        next_optional_inputs.has_value(),\n                    ),\n                    has_next,\n                    no_has_next,\n                )\n\n                return (\n                    execution_step + 1,\n                    next_optional_outputs,\n                    # We don't want to iterate if we have reached\n                    # `steps_per_execution` steps\n                    tf.cond(\n                        tf.less(execution_step + 1, self.steps_per_execution),\n                        lambda: iterator.get_next_as_optional(),\n                        lambda: next_optional_inputs,\n                    ),\n                )\n\n            def body(execution_step, optional_outputs, next_optional_inputs):\n                for _ in range(\n                    min(\n                        
self.unrolled_steps_per_execution,\n                        self.steps_per_execution,\n                    )\n                ):\n                    execution_step, optional_outputs, next_optional_inputs = (\n                        inner_body(\n                            execution_step,\n                            optional_outputs,\n                            next_optional_inputs,\n                        )\n                    )\n\n                return (execution_step, optional_outputs, next_optional_inputs)\n\n            execution_step = tf.constant(0)\n            next_optional_inputs = iterator.get_next_as_optional()\n\n            # Run the while loop\n            _, final_optional_outputs, _ = tf.while_loop(\n                cond,\n                body,\n                loop_vars=[execution_step, empty_outputs, next_optional_inputs],\n            )\n            final_optional_outputs._element_spec = empty_outputs.element_spec\n            return final_optional_outputs\n\n        if not self.run_eagerly:\n            multi_step_on_iterator = tf.function(\n                multi_step_on_iterator, reduce_retracing=True\n            )\n\n        def function(iterator):\n            if isinstance(\n                iterator, (tf.data.Iterator, tf.distribute.DistributedIterator)\n            ):\n                opt_outputs = multi_step_on_iterator(iterator)\n                if not opt_outputs.has_value():\n                    raise StopIteration\n                return opt_outputs.get_value()\n            else:\n                for step, data in zip(\n                    range(self.steps_per_execution), iterator\n                ):\n                    outputs = one_step_on_data(data)\n                return outputs\n\n        return function\n\n    def make_train_function(self, force=False):\n        if self.train_function is not None and not force:\n            return self.train_function\n        self.train_function = 
self._make_function(self.train_step)\n\n    def make_test_function(self, force=False):\n        if self.test_function is not None and not force:\n            return self.test_function\n        self.test_function = self._make_function(self.test_step)\n\n    def make_predict_function(self, force=False):\n        if self.predict_function is not None and not force:\n            return self.predict_function\n\n        @tf.autograph.experimental.do_not_convert\n        def one_step_on_data(data):\n            \"\"\"Runs a predict test step on a batch of data.\"\"\"\n            return self.predict_step(data)\n\n        if not self.run_eagerly and self.jit_compile:\n            one_step_on_data = tf.function(\n                one_step_on_data, reduce_retracing=True, jit_compile=True\n            )\n        one_step_on_data = self._autoconvert_optionals(one_step_on_data)\n\n        @tf.autograph.experimental.do_not_convert\n        def one_step_on_data_distributed(data):\n            data = data[0]\n            outputs = self.distribute_strategy.run(\n                one_step_on_data, args=(data,)\n            )\n            outputs = reduce_per_replica(\n                outputs,\n                self.distribute_strategy,\n                reduction=\"concat\",\n            )\n            return outputs\n\n        @tf.autograph.experimental.do_not_convert\n        def multi_step_on_data(data):\n            outputs = one_step_on_data_distributed(data[:1])\n            for single_step_data in data[1:]:\n                step_outputs = one_step_on_data_distributed([single_step_data])\n                outputs = tree.map_structure(\n                    lambda t1, t2: concat([t1, t2]), outputs, step_outputs\n                )\n            return outputs\n\n        if self.steps_per_execution > 1:\n            predict_function = multi_step_on_data\n        else:\n            predict_function = one_step_on_data_distributed\n\n        if not self.run_eagerly:\n            
predict_function = tf.function(\n                predict_function, reduce_retracing=True\n            )\n\n        self.predict_function = predict_function\n\n    @traceback_utils.filter_traceback\n    def fit(\n        self,\n        x=None,\n        y=None,\n        batch_size=None,\n        epochs=1,\n        verbose=\"auto\",\n        callbacks=None,\n        validation_split=0.0,\n        validation_data=None,\n        shuffle=True,\n        class_weight=None,\n        sample_weight=None,\n        initial_epoch=0,\n        steps_per_epoch=None,\n        validation_steps=None,\n        validation_batch_size=None,\n        validation_freq=1,\n    ):\n        self._assert_compile_called(\"fit\")\n        # Possibly cap epochs for debugging runs.\n        max_epochs = config.max_epochs()\n        if max_epochs and max_epochs < epochs:\n            warnings.warn(\"Limiting epochs to %d\" % max_epochs)\n            epochs = max_epochs\n        # TODO: respect compiled trainable state\n        self._eval_epoch_iterator = None\n        if validation_split and validation_data is None:\n            # Create the validation data using the training data. 
Only supported\n            # for TF/numpy/jax arrays.\n            (\n                (x, y, sample_weight),\n                validation_data,\n            ) = array_slicing.train_validation_split(\n                (x, y, sample_weight), validation_split=validation_split\n            )\n\n        if validation_data is not None:\n            (\n                val_x,\n                val_y,\n                val_sample_weight,\n            ) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)\n\n        # Create an iterator that yields batches for one epoch.\n        epoch_iterator = TFEpochIterator(\n            x=x,\n            y=y,\n            sample_weight=sample_weight,\n            batch_size=batch_size,\n            steps_per_epoch=steps_per_epoch,\n            shuffle=shuffle,\n            class_weight=class_weight,\n            distribute_strategy=self.distribute_strategy,\n            steps_per_execution=self.steps_per_execution,\n        )\n\n        self._maybe_symbolic_build(iterator=epoch_iterator)\n        epoch_iterator.reset()\n\n        # Container that configures and calls callbacks.\n        if not isinstance(callbacks, callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_history=True,\n                add_progbar=verbose != 0,\n                verbose=verbose,\n                epochs=epochs,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        self.stop_training = False\n        self.make_train_function()\n        callbacks.on_train_begin()\n        training_logs = None\n        logs = {}\n        initial_epoch = self._initial_epoch or initial_epoch\n        for epoch in range(initial_epoch, epochs):\n            self.reset_metrics()\n            callbacks.on_epoch_begin(epoch)\n            with epoch_iterator.catch_stop_iteration():\n                for begin_step, end_step, iterator in 
epoch_iterator:\n                    callbacks.on_train_batch_begin(begin_step)\n                    logs = self.train_function(iterator)\n                    callbacks.on_train_batch_end(end_step, logs)\n                    if self.stop_training:\n                        break\n\n            # Override with model metrics instead of last step logs if needed.\n            epoch_logs = dict(self._get_metrics_result_or_logs(logs))\n\n            # Run validation.\n            if validation_data is not None and self._should_eval(\n                epoch, validation_freq\n            ):\n                # Create EpochIterator for evaluation and cache it.\n                if getattr(self, \"_eval_epoch_iterator\", None) is None:\n                    self._eval_epoch_iterator = TFEpochIterator(\n                        x=val_x,\n                        y=val_y,\n                        sample_weight=val_sample_weight,\n                        batch_size=validation_batch_size or batch_size,\n                        distribute_strategy=self.distribute_strategy,\n                        steps_per_execution=self.steps_per_execution,\n                        steps_per_epoch=validation_steps,\n                        shuffle=False,\n                    )\n                val_logs = self.evaluate(\n                    x=val_x,\n                    y=val_y,\n                    sample_weight=val_sample_weight,\n                    batch_size=validation_batch_size or batch_size,\n                    steps=validation_steps,\n                    callbacks=callbacks,\n                    return_dict=True,\n                    _use_cached_eval_dataset=True,\n                )\n                val_logs = {\n                    f\"val_{name}\": val for name, val in val_logs.items()\n                }\n                epoch_logs.update(val_logs)\n\n            callbacks.on_epoch_end(epoch, epoch_logs)\n            training_logs = epoch_logs\n            if self.stop_training:\n            
    break\n\n        if (\n            isinstance(self.optimizer, optimizers_module.Optimizer)\n            and epochs > 0\n        ):\n            self.optimizer.finalize_variable_values(self.trainable_weights)\n\n        # If _eval_epoch_iterator exists, delete it after all epochs are done.\n        if getattr(self, \"_eval_epoch_iterator\", None) is not None:\n            del self._eval_epoch_iterator\n        callbacks.on_train_end(logs=training_logs)\n        return self.history\n\n    @traceback_utils.filter_traceback\n    def evaluate(\n        self,\n        x=None,\n        y=None,\n        batch_size=None,\n        verbose=\"auto\",\n        sample_weight=None,\n        steps=None,\n        callbacks=None,\n        return_dict=False,\n        **kwargs,\n    ):\n        self._assert_compile_called(\"evaluate\")\n        # TODO: respect compiled trainable state\n        use_cached_eval_dataset = kwargs.pop(\"_use_cached_eval_dataset\", False)\n        if kwargs:\n            raise ValueError(f\"Arguments not recognized: {kwargs}\")\n\n        if use_cached_eval_dataset:\n            epoch_iterator = self._eval_epoch_iterator\n        else:\n            # Create an iterator that yields batches of input/target data.\n            epoch_iterator = TFEpochIterator(\n                x=x,\n                y=y,\n                sample_weight=sample_weight,\n                batch_size=batch_size,\n                steps_per_epoch=steps,\n                shuffle=False,\n                distribute_strategy=self.distribute_strategy,\n                steps_per_execution=self.steps_per_execution,\n            )\n\n        self._maybe_symbolic_build(iterator=epoch_iterator)\n        epoch_iterator.reset()\n\n        # Container that configures and calls callbacks.\n        if not isinstance(callbacks, callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_progbar=verbose != 0,\n               
 verbose=verbose,\n                epochs=1,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        self.make_test_function()\n        self.stop_evaluating = False\n        callbacks.on_test_begin()\n        logs = {}\n        self.reset_metrics()\n        with epoch_iterator.catch_stop_iteration():\n            for begin_step, end_step, iterator in epoch_iterator:\n                callbacks.on_test_batch_begin(begin_step)\n                logs = self.test_function(iterator)\n                callbacks.on_test_batch_end(end_step, logs)\n                if self.stop_evaluating:\n                    break\n        logs = pythonify_logs(self._get_metrics_result_or_logs(logs))\n        callbacks.on_test_end(logs)\n\n        if return_dict:\n            return logs\n        return self._flatten_metrics_in_order(logs)\n\n    @traceback_utils.filter_traceback\n    def predict(\n        self, x, batch_size=None, verbose=\"auto\", steps=None, callbacks=None\n    ):\n        # Create an iterator that yields batches of input data.\n        epoch_iterator = TFEpochIterator(\n            x=x,\n            batch_size=batch_size,\n            steps_per_epoch=steps,\n            shuffle=False,\n            distribute_strategy=self.distribute_strategy,\n            steps_per_execution=self.steps_per_execution,\n        )\n\n        # Container that configures and calls callbacks.\n        if not isinstance(callbacks, callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_progbar=verbose != 0,\n                verbose=verbose,\n                epochs=1,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        def append_to_outputs(batch_outputs, outputs):\n            if outputs is None:\n                outputs = tree.map_structure(\n                    lambda batch_output: [batch_output],\n              
      batch_outputs,\n                )\n            else:\n                tree.map_structure_up_to(\n                    batch_outputs,\n                    lambda output, batch_output: output.append(batch_output),\n                    outputs,\n                    batch_outputs,\n                )\n            return outputs\n\n        def get_data(iterator):\n            \"\"\"Returns data for the next execution.\"\"\"\n            data = []\n            for _ in range(self.steps_per_execution):\n                try:\n                    single_step_data = next(iterator)\n                except (StopIteration, tf.errors.OutOfRangeError) as e:\n                    if hasattr(data, \"__len__\") and len(data) > 0:\n                        # Suppress the error when we still have data remaining.\n                        return data\n                    else:\n                        # Re-raise the error for\n                        # EpochIterator.catch_stop_iteration() to catch when\n                        # no data is left.\n                        raise e\n                data.append(single_step_data)\n            return data\n\n        self.make_predict_function()\n        self.stop_predicting = False\n        callbacks.on_predict_begin()\n        outputs = None\n        with epoch_iterator.catch_stop_iteration():\n            for begin_step, end_step, iterator in epoch_iterator:\n                callbacks.on_predict_batch_begin(begin_step)\n                data = get_data(iterator)\n                batch_outputs = self.predict_function(data)\n                outputs = append_to_outputs(batch_outputs, outputs)\n                callbacks.on_predict_batch_end(\n                    end_step, {\"outputs\": batch_outputs}\n                )\n                if self.stop_predicting:\n                    break\n        callbacks.on_predict_end()\n        outputs = tree.map_structure_up_to(\n            batch_outputs, potentially_ragged_concat, outputs\n        )\n        
return tree.map_structure(convert_to_np_if_not_ragged, outputs)\n\n    def train_on_batch(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        class_weight=None,\n        return_dict=False,\n    ):\n        self._assert_compile_called(\"train_on_batch\")\n        if class_weight is not None:\n            if sample_weight is not None:\n                raise ValueError(\n                    \"Arguments `sample_weight` and `class_weight` \"\n                    \"cannot be specified at the same time. \"\n                    f\"Received: sample_weight={sample_weight}, \"\n                    f\"class_weight={class_weight}\"\n                )\n            sample_weight = data_adapter_utils.class_weight_to_sample_weights(\n                y, class_weight\n            )\n\n        # Maybe build model\n        self._maybe_symbolic_build(data_batch=(x, y, sample_weight))\n        self.make_train_function()\n\n        def data():\n            yield (x, y, sample_weight)\n\n        logs = self.train_function(data())\n        logs = pythonify_logs(logs)\n        if return_dict:\n            return logs\n        return self._flatten_metrics_in_order(logs)\n\n    def test_on_batch(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        return_dict=False,\n    ):\n        self._assert_compile_called(\"test_on_batch\")\n\n        def data():\n            yield (x, y, sample_weight)\n\n        # Maybe build model\n        self._maybe_symbolic_build(data_batch=(x, y, sample_weight))\n        self.make_test_function()\n\n        logs = self.test_function(data())\n        logs = pythonify_logs(logs)\n        if return_dict:\n            return logs\n        return self._flatten_metrics_in_order(logs)\n\n    def predict_on_batch(self, x):\n        self.make_predict_function()\n        batch_outputs = self.predict_function([(x,)])\n        batch_outputs = tree.map_structure(\n            convert_to_np_if_not_ragged, 
batch_outputs\n        )\n        return batch_outputs\n\n    # Backwards compatibility shims.\n    @property\n    def compiled_metrics(self):\n        class DeprecatedCompiledMetric:\n            def update_state(_, y, y_pred, sample_weight=None):\n                return self._compiled_metrics_update_state(\n                    y, y_pred, sample_weight=sample_weight\n                )\n\n        return DeprecatedCompiledMetric()\n\n    def _compiled_metrics_update_state(self, y, y_pred, sample_weight=None):\n        warnings.warn(\n            \"`model.compiled_metrics()` is deprecated. \"\n            \"Instead, use e.g.:\\n\"\n            \"```\\n\"\n            \"for metric in self.metrics:\\n\"\n            \"    metric.update_state(y, y_pred)\\n\"\n            \"```\\n\",\n            stacklevel=2,\n        )\n        for metric in self.metrics:\n            if isinstance(metric, metrics_module.Mean):\n                metric.update_state(y_pred, sample_weight=sample_weight)\n            else:\n                metric.update_state(y, y_pred, sample_weight=sample_weight)\n\n    def compiled_loss(\n        self, y, y_pred, sample_weight=None, regularization_losses=None\n    ):\n        warnings.warn(\n            \"`model.compiled_loss()` is deprecated. Instead, use \"\n            \"`model.compute_loss(x, y, y_pred, sample_weight, training)`.\",\n        )\n        return self.compute_loss(\n            x=None, y=y, y_pred=y_pred, sample_weight=sample_weight\n        )\n\n    def loss(self, y, y_pred, sample_weight=None):\n        warnings.warn(\n            \"`model.loss()` is deprecated. 
Instead, use \"\n            \"`model.compute_loss(x, y, y_pred, sample_weight, training)`.\",\n        )\n        return self.compute_loss(\n            x=None, y=y, y_pred=y_pred, sample_weight=sample_weight\n        )\n\n    def _maybe_symbolic_build(self, iterator=None, data_batch=None):\n        # Only build symbolically when a distribute strategy was created in\n        # the TF trainer.\n        if self._distribute_strategy is None:\n            # When no distribution strategy is set, defer building\n            # to when the train/test/predict function gets traced.\n            # This maximizes backwards compatibility.\n            return\n\n        # Unlike the jax/torch iterators, the tf iterator yields an iterator\n        # rather than a data batch in `iterator`.\n        if iterator is not None:\n            for _, _, it in iterator:\n                maybe_distributed_data_batch = next(it)\n                has_distributed_values = tree.map_structure(\n                    lambda x: isinstance(x, tf.distribute.DistributedValues),\n                    maybe_distributed_data_batch,\n                )\n                if all(tree.flatten(has_distributed_values)):\n                    data_batch = self.distribute_strategy.reduce(\n                        \"MEAN\",\n                        maybe_distributed_data_batch,\n                        axis=None,\n                    )\n                else:\n                    data_batch = maybe_distributed_data_batch\n                break\n        with self.distribute_strategy.scope():\n            self._symbolic_build(data_batch=data_batch)\n\n    def _aggregate_additional_loss(self, loss):\n        loss = super()._aggregate_additional_loss(loss)\n        return loss_module.scale_loss_for_distribution(loss)\n\n\nclass TFEpochIterator(EpochIterator):\n    def __init__(self, distribute_strategy=None, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self._distribute_strategy = distribute_strategy\n        dataset = 
self.data_adapter.get_tf_dataset()\n        if not isinstance(dataset, tf.distribute.DistributedDataset):\n            dataset = self._distribute_strategy.experimental_distribute_dataset(\n                dataset\n            )\n        self._distributed_dataset = dataset\n\n    def _get_iterator(self):\n        return self._distributed_dataset\n\n    def tf_sync(self):\n        tf_context.async_wait()\n\n    def __next__(self):\n        return next(self._epoch_iterator)\n\n    @contextlib.contextmanager\n    def catch_stop_iteration(self):\n        \"\"\"Catches errors when an iterator runs out of data.\"\"\"\n        with super().catch_stop_iteration():\n            try:\n                yield\n                self.tf_sync()\n            except tf.errors.OutOfRangeError:\n                raise StopIteration\n\n\ndef reduce_per_replica(values, strategy, reduction):\n    \"\"\"Attempt to reduce the structure `values` to single values.\n\n    Given `values` (a `tf.Tensor` or a `PerReplica` structure),\n    which represents the values across all the replicas, `reduce_per_replica`\n    attempts to \"reduce\" those values and returns the corresponding structure\n    that represents only single values.\n\n    Currently, `reduce_per_replica` is only used for reducing the metric results\n    from `tf.distribute.Strategy.run()`. Depending on the underlying\n    `Strategy` implementation, `values` may be a `PerReplica` object,\n    which can be thought of as a collection of values across the replicas,\n    or a `tf.Tensor`, if the strategy has already conducted the reduction\n    for the downstream library.\n\n    There are five possible outcomes of reduction:\n\n    1) if the `values` is a structure of simple `tf.Tensor`s, meaning that\n       reduction is not actually needed, `reduce_per_replica` returns the\n       structure as-is.\n    2) else, if `reduction=\"auto\"`, then the best reduction strategy is\n       chosen based on the current environment. 
This should only be used\n       for training cases (`fit()`).\n    3) else, if `reduction=\"first\"`, then `reduce_per_replica`\n       returns the values of the first replica. This is used in the case of\n       training and evaluation, where `values` is expected to hold the same\n       value across the replicas as a result of `Strategy`'s synchronization\n       across the replicas.\n       `reduce_per_replica` does not synchronize the values.\n    4) else, if `reduction=\"sum\"`, then `reduce_per_replica` returns the sum\n       of values for all replicas. This may be used in the custom training loop\n       case, where each replica contains different values which are not\n       synchronized.\n    5) else, if `reduction=\"concat\"`, then `reduce_per_replica`\n       returns the concatenation of the values across the replicas, along the\n       axis of dimension 0. This is used in the inference case (`predict()`).\n\n    Args:\n        values: Structure of `PerReplica` objects or `tf.Tensor`s.\n            `tf.Tensor`s are returned as-is.\n        strategy: `tf.distribute.Strategy` object.\n        reduction: One of `\"auto\"`, `\"first\"`, `\"concat\"`, `\"mean\"`, or `\"sum\"`.\n            `\"auto\"` will select `\"first\"` when used under a TPUStrategy, or\n            `\"mean\"` otherwise.\n\n    Returns:\n        Structure of `Tensor`s, representing the result of reduction.\n    \"\"\"\n\n    if reduction == \"auto\":\n        if isinstance(strategy, tf.distribute.TPUStrategy):\n            reduction = \"first\"\n        else:\n            reduction = \"mean\"\n\n    def _reduce(v):\n        \"\"\"Reduce a single `PerReplica` object.\"\"\"\n        if _collective_all_reduce_multi_worker(strategy):\n            if reduction == \"concat\":\n                return _multi_worker_concat(v, strategy)\n            elif reduction == \"sum\":\n                return strategy.reduce(\"SUM\", v, axis=None)\n            elif reduction == \"mean\":\n                return 
strategy.reduce(\"MEAN\", v, axis=0)\n\n        if not _is_per_replica_instance(v):\n            return v\n        elif reduction == \"first\":\n            return strategy.experimental_local_results(v)[0]\n        elif reduction == \"concat\":\n            if _is_tpu_multi_host(strategy):\n                return _tpu_multi_host_concat(v, strategy)\n            else:\n                return concat(strategy.experimental_local_results(v))\n        elif reduction == \"sum\":\n            return tf.reduce_sum(strategy.experimental_local_results(v))\n        elif reduction == \"mean\":\n            return tf.reduce_mean(\n                strategy.experimental_local_results(v), axis=0\n            )\n        else:\n            raise ValueError(\n                \"`reduction` must be one of \"\n                '\"first\", \"concat\", \"mean\", \"sum\", or \"auto\". '\n                f\"Received: reduction={reduction}.\"\n            )\n\n    return tree.map_structure(_reduce, values)\n\n\ndef _multi_worker_concat(v, strategy):\n    \"\"\"Order PerReplica objects for CollectiveAllReduceStrategy and concat.\"\"\"\n    replicas = strategy.gather(v, axis=0)\n    # v might not have the same shape on different replicas\n    if _is_per_replica_instance(v):\n        shapes = tf.concat(\n            [\n                tf.expand_dims(tf.shape(single_value)[0], axis=0)\n                for single_value in v.values\n            ],\n            axis=0,\n        )\n        all_shapes = strategy.gather(shapes, axis=0)\n    else:\n        # v is a tensor. 
This may happen when, say, we have 2x1 multi-worker.\n        all_shapes = strategy.gather(\n            tf.expand_dims(tf.shape(v)[0], axis=0), axis=0\n        )\n\n    replicas = tf.split(\n        replicas,\n        num_or_size_splits=all_shapes,\n        num=strategy.num_replicas_in_sync,\n    )\n    ordered_replicas = []\n    num_replicas_per_worker = len(strategy.extended.worker_devices)\n    for replica_id in range(num_replicas_per_worker):\n        ordered_replicas += replicas[replica_id::num_replicas_per_worker]\n    return concat(ordered_replicas)\n\n\ndef concat(tensors, axis=0):\n    \"\"\"Concats `tensor`s along `axis`.\"\"\"\n    if isinstance(tensors[0], tf.SparseTensor):\n        return tf.sparse.concat(axis=axis, sp_inputs=tensors)\n    elif _is_scalar(tensors[0]):\n        return tf.stack(tensors, axis=axis)\n    else:\n        return tf.concat(tensors, axis=axis)\n\n\ndef _tpu_multi_host_concat(v, strategy):\n    \"\"\"Correctly order TPU PerReplica objects.\"\"\"\n    replicas = strategy.experimental_local_results(v)\n    # When distributed datasets are created from Tensors / NumPy,\n    # TPUStrategy.experimental_distribute_dataset shards data in\n    # (Replica, Host) order, and TPUStrategy.experimental_local_results returns\n    # it in (Host, Replica) order.\n    num_replicas_per_host = strategy.extended.num_replicas_per_host\n    ordered_replicas = []\n    for replica_id in range(num_replicas_per_host):\n        ordered_replicas += replicas[replica_id::num_replicas_per_host]\n    return concat(ordered_replicas)\n\n\ndef _collective_all_reduce_multi_worker(strategy):\n    return (\n        isinstance(strategy, tf.distribute.MultiWorkerMirroredStrategy)\n    ) and strategy.extended._in_multi_worker_mode()\n\n\ndef _is_per_replica_instance(obj):\n    return isinstance(obj, tf.distribute.DistributedValues) and isinstance(\n        obj, tf.__internal__.CompositeTensor\n    )\n\n\ndef _is_scalar(x):\n    return isinstance(x, (tf.Tensor, 
tf.Variable)) and x.shape.rank == 0\n\n\ndef _is_tpu_multi_host(strategy):\n    return _is_tpu_strategy(strategy) and strategy.extended.num_hosts > 1\n\n\ndef _is_tpu_strategy(strategy):\n    return _is_tpu_strategy_class(strategy.__class__)\n\n\ndef _is_tpu_strategy_class(clz):\n    def is_tpu_strat(k):\n        return k.__name__.startswith(\"TPUStrategy\")\n\n    if is_tpu_strat(clz):\n        return True\n    return any(map(_is_tpu_strategy_class, clz.__bases__))\n\n\ndef convert_to_np_if_not_ragged(x):\n    if isinstance(x, tf.RaggedTensor):\n        return x\n    elif isinstance(x, tf.SparseTensor):\n        return x\n    return x.numpy()\n\n\ndef potentially_ragged_concat(tensors):\n    \"\"\"Concats `Tensor`s along their first dimension.\n\n    Args:\n        tensors: List of `Tensor`s.\n\n    Returns:\n        Concatenation of the inputs along the first dimension -- of type\n        `np.ndarray` if all input shapes are compatible, or `tf.RaggedTensor`\n        if not.\n    \"\"\"\n    if len(tensors) == 1:\n        return tensors[0]\n    elif isinstance(tensors[0], tf.SparseTensor):\n        return tf.sparse.concat(axis=0, sp_inputs=tensors)\n    elif isinstance(tensors[0], tf.RaggedTensor):\n        return tf.concat(tensors, axis=0)\n\n    non_batch_shapes = tf.stack([tf.shape(tensor)[1:] for tensor in tensors])\n    constant_dims = tf.math.reduce_all(\n        non_batch_shapes == non_batch_shapes[:1], axis=0\n    )\n    if tf.math.reduce_all(constant_dims).numpy().item():\n        # All non-batch dims are constant\n        if _is_scalar(tensors[0]):\n            return tf.stack(tensors, axis=0)\n        else:\n            return tf.concat(tensors, axis=0)\n\n    # First, identify constant inner dimensions by finding the\n    # rightmost dimension that is not constant\n    constant_inner_dimensions = (\n        constant_dims.numpy().tolist()[::-1].index(False)\n    )\n    # If there are constant inner dimensions, define a constant inner shape\n    if 
constant_inner_dimensions == 0:\n        constant_inner_shape = None\n    else:\n        constant_inner_shape = tensors[0].shape[-constant_inner_dimensions:]\n    return tf.ragged.constant(\n        [tensor.numpy() for tensor in tensors], inner_shape=constant_inner_shape\n    ).merge_dims(0, 1)\n"
  },
  {
    "path": "keras/src/backend/tests/compute_output_spec_test.py",
    "content": "import unittest\n\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.backend.common.keras_tensor import KerasTensor\n\n\ndef single_arg_test_fn(x):\n    return ops.concatenate([(x + 1) ** 2, x], axis=-1)\n\n\ndef three_args_2_kwarg_test_fn(x1, x2, x3=None):\n    x1 = ops.max(x1, axis=1)\n    x2 = ops.max(x2, axis=1)\n    if x3 is not None:\n        x1 += ops.max(x3, axis=1)\n    return x1 + x2\n\n\nclass ComputeOutputSpecTest(unittest.TestCase):\n    def test_dynamic_batch_size(self):\n        x = KerasTensor(shape=(None, 3, 5))\n        y = backend.compute_output_spec(single_arg_test_fn, x)\n        self.assertEqual(y.shape, (None, 3, 10))\n\n        x1 = KerasTensor(shape=(None, 3, 5))\n        x2 = KerasTensor(shape=(None, 3, 5))\n        x3 = KerasTensor(shape=(None, 3, 5))\n        y = backend.compute_output_spec(\n            three_args_2_kwarg_test_fn, x1, x2, x3=x3\n        )\n        self.assertEqual(y.shape, (None, 5))\n\n    def test_dynamic_everything(self):\n        x = KerasTensor(shape=(2, None, 3))\n        y = backend.compute_output_spec(single_arg_test_fn, x)\n        self.assertEqual(y.shape, (2, None, 6))\n\n        x1 = KerasTensor(shape=(None, None, 5))\n        x2 = KerasTensor(shape=(None, None, 5))\n        x3 = KerasTensor(shape=(None, None, 5))\n        y = backend.compute_output_spec(\n            three_args_2_kwarg_test_fn, x1, x2, x3=x3\n        )\n        self.assertEqual(y.shape, (None, 5))\n\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=\"Backend does not support sparse tensors.\",\n    )\n    def test_sparse_to_sparse(self):\n        def single_arg_sparse_fn(x):\n            y0 = ops.transpose(x, axes=(0, 2, 1))\n            y1 = ops.squeeze(ops.expand_dims(x, axis=3), axis=3)\n            return (y0, y1)\n\n        x = KerasTensor(shape=(None, 3, 3), sparse=True)\n        ys = backend.compute_output_spec(single_arg_sparse_fn, x)\n     
   for y in ys:\n            self.assertEqual(y.shape, (None, 3, 3))\n            self.assertTrue(y.sparse)\n\n        def three_args_sparse_fn(x1, x2, x3=None):\n            y0 = ops.add(x1, x2)  # sparse, sparse\n            y1 = ops.divide(x1, x3)  # sparse, dense\n            y2 = ops.matmul(x1, x2)  # sparse, sparse\n            y3 = ops.multiply(x1, x3)  # sparse, dense\n            return (y0, y1, y2, y3)\n\n        x1 = KerasTensor(shape=(None, 3, 3), sparse=True)\n        x2 = KerasTensor(shape=(None, 3, 3), sparse=True)\n        x3 = KerasTensor(shape=(None, 3, 3), sparse=False)\n        ys = backend.compute_output_spec(three_args_sparse_fn, x1, x2, x3=x3)\n        for y in ys:\n            self.assertEqual(y.shape, (None, 3, 3))\n            self.assertTrue(y.sparse)\n\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=\"Backend does not support sparse tensors.\",\n    )\n    def test_sparse_to_dense(self):\n        def single_arg_dense_fn(x):\n            y0 = ops.exp(x)\n            return (y0,)\n\n        x = KerasTensor(shape=(None, 3, 3), sparse=True)\n        ys = backend.compute_output_spec(single_arg_dense_fn, x)\n        for y in ys:\n            self.assertEqual(y.shape, (None, 3, 3))\n            self.assertFalse(y.sparse)\n\n        def three_args_dense_fn(x1, x2, x3=None):\n            y0 = ops.add(x1, x2)  # sparse, dense\n            y1 = ops.add(x2, x1)  # dense, sparse\n            y2 = ops.concatenate([x1, x2], axis=0)  # sparse, dense\n            y3 = ops.matmul(x1, x2)  # sparse, dense\n            y4 = ops.matmul(x2, x1)  # dense, sparse\n            y5 = ops.take(x2, indices=x3, axis=1)  # dense, sparse\n            y6 = ops.divide(x1, x1)  # sparse, sparse\n            return (y0, y1, y2, y3, y4, y5, y6)\n\n        x1 = KerasTensor(shape=(None, 3, 3), sparse=True)\n        x2 = KerasTensor(shape=(None, 3, 3), sparse=False)\n        x3 = KerasTensor(shape=(3,), dtype=\"int64\", sparse=True)\n    
    ys = backend.compute_output_spec(three_args_dense_fn, x1, x2, x3=x3)\n        for y in ys:\n            self.assertEqual(y.shape, (None, 3, 3))\n            self.assertFalse(y.sparse)\n"
  },
  {
    "path": "keras/src/backend/tests/device_scope_test.py",
    "content": "import pytest\n\nfrom keras.src import backend\nfrom keras.src import testing\n\n\nclass DeviceTest(testing.TestCase):\n    @pytest.mark.skipif(\n        not testing.tensorflow_uses_gpu(), reason=\"tf on GPU only\"\n    )\n    def test_tf_device_scope(self):\n        with backend.device(\"cpu:0\"):\n            t = backend.numpy.ones((2, 1))\n            self.assertIn(\"CPU:0\", t.device)\n        with backend.device(\"CPU:0\"):\n            t = backend.numpy.ones((2, 1))\n            self.assertIn(\"CPU:0\", t.device)\n\n        # When leaving the scope, the device should be back with gpu:0\n        t = backend.numpy.ones((2, 1))\n        self.assertIn(\"GPU:0\", t.device)\n\n        # Also verify the explicit gpu device\n        with backend.device(\"gpu:0\"):\n            t = backend.numpy.ones((2, 1))\n            self.assertIn(\"GPU:0\", t.device)\n\n    @pytest.mark.skipif(not testing.jax_uses_gpu(), reason=\"jax on GPU only\")\n    def test_jax_device_scope(self):\n        import jax\n\n        with backend.device(\"cpu:0\"):\n            t = backend.numpy.ones((2, 1))\n            self.assertEqual(t.device, jax.devices(\"cpu\")[0])\n        with backend.device(\"CPU:0\"):\n            t = backend.numpy.ones((2, 1))\n            self.assertEqual(t.device, jax.devices(\"cpu\")[0])\n\n        # When leaving the scope, the device should be back with gpu:0\n        t = backend.numpy.ones((2, 1))\n        self.assertEqual(t.device, jax.devices(\"gpu\")[0])\n\n        # Also verify the explicit gpu device\n        with backend.device(\"gpu:0\"):\n            t = backend.numpy.ones((2, 1))\n            self.assertEqual(t.device, jax.devices(\"gpu\")[0])\n\n    @pytest.mark.skipif(backend.backend() != \"jax\", reason=\"jax only\")\n    def test_invalid_jax_device(self):\n        with self.assertRaisesRegex(ValueError, \"Received: device_name='123'\"):\n            backend.device(123).__enter__()\n\n    @pytest.mark.skipif(\n        not 
testing.torch_uses_gpu(), reason=\"torch on GPU only\"\n    )\n    def test_torch_device_scope(self):\n        import torch\n\n        with backend.device(\"cpu:0\"):\n            t = backend.numpy.ones((2, 1))\n            self.assertEqual(t.device, torch.device(\"cpu\"))\n        with backend.device(\"CPU:0\"):\n            t = backend.numpy.ones((2, 1))\n            self.assertEqual(t.device, torch.device(\"cpu\"))\n\n        # When leaving the scope, the device should be back with gpu:0\n        t = backend.numpy.ones((2, 1))\n        self.assertEqual(t.device, torch.device(\"cuda\", 0))\n\n        # Also verify the explicit gpu -> cuda conversion\n        with backend.device(\"gpu:0\"):\n            t = backend.numpy.ones((2, 1))\n            self.assertEqual(t.device, torch.device(\"cuda\", 0))\n\n    @pytest.mark.skipif(backend.backend() != \"torch\", reason=\"torch only\")\n    def test_invalid_torch_device(self):\n        with self.assertRaisesRegex(ValueError, \"Received: device_name='123'\"):\n            backend.device(123).__enter__()\n\n    @pytest.mark.skipif(backend.backend() != \"torch\", reason=\"torch only\")\n    def test_torch_meta_device(self):\n        import torch\n\n        with torch.device(\"meta\"):\n            x = torch.ones(5)\n\n        t = backend.convert_to_tensor(x)\n\n        if not torch.cuda.is_available():\n            self.assertEqual(t.device, torch.device(\"cpu\"))\n        else:\n            self.assertEqual(t.device, torch.device(\"cuda\", 0))\n"
  },
  {
    "path": "keras/src/backend/torch/__init__.py",
    "content": "\"\"\"Torch backend APIs.\n\n# Note on device placement\n\nTorch has a different device placement style compared to TF and JAX.\nIn short, variables/tensors are not created on GPU by default,\nand the GPU cannot directly communicate with the CPU.\nTo bring Torch behavior in line with TF and JAX automated device placement,\nwe are doing the following to automate device placement if a GPU is available:\n\n- Variables are created on GPU.\n- Input data will be placed on GPU at the first `keras.layers.Layer` call.\n- Tensor creation happens on GPU, e.g., `zeros()` will create a tensor on GPU.\n- `convert_to_numpy` will bring the tensor to CPU before converting it to NumPy.\n\"\"\"\n\nfrom keras.src.backend.common.name_scope import name_scope\nfrom keras.src.backend.torch import core\nfrom keras.src.backend.torch import image\nfrom keras.src.backend.torch import linalg\nfrom keras.src.backend.torch import math\nfrom keras.src.backend.torch import nn\nfrom keras.src.backend.torch import numpy\nfrom keras.src.backend.torch import random\nfrom keras.src.backend.torch.core import IS_THREAD_SAFE\nfrom keras.src.backend.torch.core import SUPPORTS_RAGGED_TENSORS\nfrom keras.src.backend.torch.core import SUPPORTS_SPARSE_TENSORS\nfrom keras.src.backend.torch.core import Variable\nfrom keras.src.backend.torch.core import cast\nfrom keras.src.backend.torch.core import compute_output_spec\nfrom keras.src.backend.torch.core import cond\nfrom keras.src.backend.torch.core import convert_to_numpy\nfrom keras.src.backend.torch.core import convert_to_tensor\nfrom keras.src.backend.torch.core import device_scope\nfrom keras.src.backend.torch.core import is_tensor\nfrom keras.src.backend.torch.core import random_seed_dtype\nfrom keras.src.backend.torch.core import scatter\nfrom keras.src.backend.torch.core import shape\nfrom keras.src.backend.torch.core import stop_gradient\nfrom keras.src.backend.torch.core import to_torch_dtype\nfrom keras.src.backend.torch.core import 
vectorized_map\nfrom keras.src.backend.torch.rnn import cudnn_ok\nfrom keras.src.backend.torch.rnn import gru\nfrom keras.src.backend.torch.rnn import lstm\nfrom keras.src.backend.torch.rnn import rnn\n"
  },
  {
    "path": "keras/src/backend/torch/core.py",
    "content": "import builtins\nimport contextlib\nimport functools\nimport os\n\nimport ml_dtypes\nimport numpy as np\nimport torch\n\nfrom keras.src import tree\nfrom keras.src.backend.common import KerasVariable\nfrom keras.src.backend.common import global_state\nfrom keras.src.backend.common import standardize_dtype\nfrom keras.src.backend.common.backend_utils import slice_along_axis\nfrom keras.src.backend.common.dtypes import result_type\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.backend.common.stateless_scope import StatelessScope\nfrom keras.src.backend.common.stateless_scope import get_stateless_scope\nfrom keras.src.backend.common.stateless_scope import in_stateless_scope\nfrom keras.src.backend.common.symbolic_scope import SymbolicScope\nfrom keras.src.backend.config import floatx\n\nSUPPORTS_SPARSE_TENSORS = False\nSUPPORTS_RAGGED_TENSORS = False\nIS_THREAD_SAFE = True\n\n# Some operators such as 'aten::_foreach_mul_.Scalar'\n# are not currently implemented for the MPS device.\n# check https://github.com/pytorch/pytorch/issues/77764.\nif \"KERAS_TORCH_DEVICE\" in os.environ:\n    DEFAULT_DEVICE = os.environ[\"KERAS_TORCH_DEVICE\"]\nelif torch.backends.mps.is_available():\n    DEFAULT_DEVICE = \"mps\"\nelif torch.cuda.is_available():\n    DEFAULT_DEVICE = \"cuda\"\nelif hasattr(torch, \"xpu\") and torch.xpu.is_available():\n    DEFAULT_DEVICE = \"xpu\"\nelse:\n    DEFAULT_DEVICE = \"cpu\"\n\nTORCH_DTYPES = {\n    \"float16\": torch.float16,\n    \"float32\": torch.float32,\n    \"float64\": torch.float64,\n    \"uint8\": torch.uint8,\n    \"uint16\": torch.int32,  # TODO: Torch doesn't have `uint16` dtype.\n    \"uint32\": torch.int64,  # TODO: Torch doesn't have `uint32` dtype.\n    \"int8\": torch.int8,\n    \"int16\": torch.int16,\n    \"int32\": torch.int32,\n    \"int64\": torch.int64,\n    \"bfloat16\": torch.bfloat16,\n    \"bool\": torch.bool,\n    \"float8_e4m3fn\": torch.float8_e4m3fn,\n    \"float8_e5m2\": 
torch.float8_e5m2,\n    \"complex32\": torch.complex32,\n    \"complex64\": torch.complex64,\n    \"complex128\": torch.complex128,\n}\n\n\n@contextlib.contextmanager\ndef device_scope(device_name):\n    previous_device = global_state.get_global_attribute(\"torch_device\", None)\n    current_device = _parse_device_input(device_name)\n    global_state.set_global_attribute(\"torch_device\", current_device)\n    try:\n        yield torch.device(current_device)\n    finally:\n        global_state.set_global_attribute(\"torch_device\", previous_device)\n\n\ndef get_device():\n    device = global_state.get_global_attribute(\"torch_device\", None)\n    if device is None:\n        return DEFAULT_DEVICE\n    return device\n\n\ndef _parse_device_input(device_name):\n    if isinstance(device_name, str):\n        # We support string value like \"cpu:0\", \"gpu:1\", and need to convert\n        # \"gpu\" to \"cuda\"\n        device_name = device_name.lower()\n        if \"gpu\" in device_name:\n            device_name = device_name.replace(\"gpu\", \"cuda\")\n    else:\n        raise ValueError(\n            \"Invalid value for argument `device_name`. \"\n            \"Expected a string like 'gpu:0' or 'cpu'. 
\"\n            f\"Received: device_name='{device_name}'\"\n        )\n    # The torch.Device instance can be used directly.\n    return device_name\n\n\ndef to_torch_dtype(dtype):\n    standardized_dtype = TORCH_DTYPES.get(standardize_dtype(dtype), None)\n    if standardized_dtype is None:\n        raise ValueError(f\"Unsupported dtype for PyTorch: {dtype}\")\n    return standardized_dtype\n\n\nclass Variable(KerasVariable):\n    def _initialize(self, value):\n        if isinstance(value, torch.nn.Parameter):\n            # Reuse same parameter\n            self._value = value\n        else:\n            self._value = torch.nn.Parameter(\n                convert_to_tensor(value, dtype=self._dtype),\n                requires_grad=self.trainable,\n            ).to(get_device())\n\n    def _direct_assign(self, value):\n        with torch.no_grad():\n            self.value.copy_(value)\n\n    def _convert_to_tensor(self, value, dtype=None):\n        return convert_to_tensor(value, dtype=dtype)\n\n    # Overload native accessor.\n    @classmethod\n    def __torch_function__(cls, func, types, args=(), kwargs=None):\n        args = [arg.value if isinstance(arg, Variable) else arg for arg in args]\n        if kwargs is None:\n            kwargs = {}\n        kwargs = {\n            key: value.value if isinstance(value, Variable) else value\n            for key, value in kwargs.items()\n        }\n        return func(*args, **kwargs)\n\n    def __array__(self, dtype=None):\n        value = convert_to_numpy(self.value)\n        if dtype:\n            return value.astype(dtype)\n        return value\n\n    @property\n    def value(self):\n        # We cannot chain super() here because it will fail TorchDynamo. 
The\n        # reason why is unclear.\n        def maybe_use_symbolic_tensor(value):\n            # Create and use a symbolic tensor stub in symbolic calls.\n            if str(get_device()) == \"meta\" and str(value.device) != \"meta\":\n                return torch.nn.Parameter(\n                    torch.empty(\n                        size=self._shape,\n                        dtype=to_torch_dtype(self._dtype),\n                        device=\"meta\",\n                    ),\n                    requires_grad=self.trainable,\n                )\n            return value\n\n        if in_stateless_scope():\n            scope = get_stateless_scope()\n            value = scope.get_current_value(self)\n            if value is not None:\n                value = self._maybe_autocast(value)\n                return maybe_use_symbolic_tensor(value)\n        if self._value is None:\n            # Uninitialized variable. Return a placeholder.\n            # This is fine because it's only ever used\n            # during shape inference / graph tracing\n            # (anything else would be a bug, to be fixed.)\n            value = self._maybe_autocast(\n                self._initializer(self._shape, dtype=self._dtype)\n            )\n        else:\n            value = self._maybe_autocast(self._value)\n        return maybe_use_symbolic_tensor(value)\n\n    @property\n    def trainable(self):\n        return self._trainable\n\n    @trainable.setter\n    def trainable(self, value):\n        self._trainable = value\n        if self._value is not None:\n            self._value.requires_grad = value\n\n    def __eq__(self, other):\n        try:\n            return super().__eq__(other)\n        except Exception:\n            return False\n\n\ndef convert_to_tensor(x, dtype=None, sparse=None, ragged=None):\n    if sparse:\n        raise ValueError(\"`sparse=True` is not supported with torch backend\")\n    if ragged:\n        raise ValueError(\"`ragged=True` is not supported 
with torch backend\")\n    if isinstance(x, Variable) or is_tensor(x):\n        if isinstance(x, Variable):\n            x = x.value\n        device = get_device()\n        if x.device != device:\n            if x.is_meta:\n                x = torch.empty_like(x, device=device)\n            else:\n                x = x.to(device)\n        if dtype is not None:\n            x = x.to(to_torch_dtype(dtype))\n        return x\n    if dtype is None:\n        if isinstance(x, bool):\n            return torch.as_tensor(x, dtype=torch.bool, device=get_device())\n        elif isinstance(x, int):\n            if x < -(2**31) or x >= 2**31:\n                return torch.as_tensor(\n                    x, dtype=torch.int64, device=get_device()\n                )\n            return torch.as_tensor(x, dtype=torch.int32, device=get_device())\n        elif isinstance(x, float):\n            return torch.as_tensor(\n                x, dtype=to_torch_dtype(floatx()), device=get_device()\n            )\n\n    # Convert to np in case of any array-like that is not list or tuple.\n    if not isinstance(x, (list, tuple)):\n        x = np.array(x)\n    elif len(x) > 0 and any(isinstance(x1, torch.Tensor) for x1 in x):\n        # Handle list or tuple of torch tensors\n        return torch.stack([convert_to_tensor(x1) for x1 in x])\n    if isinstance(x, np.ndarray):\n        if x.dtype == np.uint32:\n            # Torch backend does not support uint32.\n            x = x.astype(np.int64)\n        if standardize_dtype(x.dtype) == \"bfloat16\":\n            # Torch backend does not support converting bfloat16 ndarray.\n            x = x.astype(np.float32)\n            dtype = \"bfloat16\"\n        dtype = dtype or x.dtype\n    if dtype is None:\n        dtype = result_type(\n            *[getattr(item, \"dtype\", type(item)) for item in tree.flatten(x)]\n        )\n    dtype = to_torch_dtype(dtype)\n    return torch.as_tensor(x, dtype=dtype, device=get_device())\n\n\ndef 
convert_to_numpy(x):\n    def transform(x):\n        if is_tensor(x):\n            if x.requires_grad:\n                x = x.detach()\n            # Tensor has to be moved to CPU before converting to numpy.\n            if x.device != torch.device(\"cpu\"):\n                x = x.cpu()\n            if x.dtype == torch.bfloat16:\n                # Attempting to call .numpy() on a bfloat16 torch tensor leads\n                # to an immediate error. Instead we upcast to float32 and then\n                # convert to the numpy friendly bfloat16 type.\n                # https://github.com/pytorch/pytorch/issues/90574\n                return np.array(x.to(torch.float32)).astype(ml_dtypes.bfloat16)\n        return np.array(x)\n\n    if isinstance(x, (list, tuple)):\n        return np.array([transform(e) for e in x])\n    return transform(x)\n\n\ndef is_tensor(x):\n    # Using the built-in `isinstance` is recommended by pytorch\n    # over using torch.is_tensor\n    # see: https://pytorch.org/docs/stable/generated/torch.is_tensor.html\n    #\n    # Also, `torch.is_tensor()` causes issues with dynamo caching when\n    # a torch.Tensor and numpy.ndarray of the same size, shape, and dtype\n    # is passed, if called on a Tensor first the second call with ndarray\n    # will return `True` and vice-versa.\n    return isinstance(x, torch.Tensor)\n\n\ndef shape(x):\n    # Convert from `torch.Size` to plain tuple.\n    return tuple(x.shape)\n\n\ndef cast(x, dtype):\n    dtype = to_torch_dtype(dtype)\n    if isinstance(x, Variable):\n        x = x.value\n    if is_tensor(x):\n        if x.dtype == dtype:\n            return x\n        else:\n            return x.to(dtype)\n    return convert_to_tensor(x, dtype)\n\n\n# Shape / dtype inference util\ndef compute_output_spec(fn, *args, **kwargs):\n    def has_none_shape(x):\n        \"\"\"Check for if a `KerasTensor` has dynamic shape.\"\"\"\n        if isinstance(x, KerasTensor):\n            return None in x.shape\n        return 
False\n\n    def convert_keras_tensor_to_torch(x, fill_value=None):\n        \"\"\"Convert `KerasTensor`s to `torch.Tensor`s.\"\"\"\n        if isinstance(x, KerasTensor):\n            shape = list(x.shape)\n            if fill_value:\n                for i, e in enumerate(shape):\n                    if e is None:\n                        shape[i] = fill_value\n            return torch.ones(\n                size=shape,\n                dtype=TORCH_DTYPES[x.dtype],\n                device=get_device(),\n            )\n        return x\n\n    def convert_torch_to_keras_tensor(x):\n        \"\"\"Convert `torch.Tensor`s to `KerasTensor`s.\"\"\"\n        if is_tensor(x):\n            return KerasTensor(x.shape, standardize_dtype(x.dtype))\n        return x\n\n    def symbolic_call(fn, args, kwargs, fill_value):\n        \"\"\"Call `fn` to infer output shape and dtype.\"\"\"\n        try:\n            # First try instantiating all tensors on the `\"meta\"` device,\n            # which  should give a \"zero flop\" way to trace shape, but does\n            # not have universal support with torch operations.\n            with device_scope(\"meta\"):\n                meta_args, meta_kwargs = tree.map_structure(\n                    lambda x: convert_keras_tensor_to_torch(x, fill_value),\n                    (args, kwargs),\n                )\n                return fn(*meta_args, **meta_kwargs)\n        except:\n            with device_scope(DEFAULT_DEVICE):\n                # If the `\"meta\"` device placement fails, fall back to tracing\n                # eagerly with tensors on the default device. 
This will be\n                # more robust, but more expensive.\n                eager_args, eager_kwargs = tree.map_structure(\n                    lambda x: convert_keras_tensor_to_torch(x, fill_value),\n                    (args, kwargs),\n                )\n                return fn(*eager_args, **eager_kwargs)\n\n    with StatelessScope(), SymbolicScope(), torch.no_grad():\n        outputs = symbolic_call(fn, args, kwargs, fill_value=83)\n\n        none_in_shape = any(\n            builtins.map(has_none_shape, tree.flatten((args, kwargs)))\n        )\n        if none_in_shape:\n            outputs_1 = outputs\n            outputs_2 = symbolic_call(fn, args, kwargs, fill_value=89)\n\n            flat_out_1 = tree.flatten(outputs_1)\n            flat_out_2 = tree.flatten(outputs_2)\n\n            flat_out = []\n            for x1, x2 in zip(flat_out_1, flat_out_2):\n                shape = list(x1.shape)\n                for i, e in enumerate(x2.shape):\n                    if e != shape[i]:\n                        shape[i] = None\n                flat_out.append(KerasTensor(shape, standardize_dtype(x1.dtype)))\n            outputs = tree.pack_sequence_as(outputs_1, flat_out)\n\n        output_spec = tree.map_structure(convert_torch_to_keras_tensor, outputs)\n    return output_spec\n\n\ndef cond(pred, true_fn, false_fn):\n    # During symbolic execution, assume pred is true.\n    if get_device() == \"meta\":\n        return true_fn()\n\n    if pred:\n        return true_fn()\n    return false_fn()\n\n\ndef vectorized_map(function, elements):\n    return torch.vmap(function)(elements)\n\n\ndef map(f, xs):\n    def g(_, x):\n        return (), f(x)\n\n    _, ys = scan(g, (), xs)\n    return ys\n\n\ndef scan(f, init, xs=None, length=None, reverse=False, unroll=1):\n    # Ref: jax.lax.scan\n    if not callable(f):\n        raise TypeError(f\"`f` should be a callable. 
Received: f={f}\")\n    if not isinstance(unroll, bool):\n        if not isinstance(unroll, int) or unroll < 1:\n            raise ValueError(\n                \"`unroll` must be an positive integer or boolean. \"\n                f\"Received: unroll={unroll}\"\n            )\n    if xs is None and length is None:\n        raise ValueError(\"Got no `xs` to scan over and `length` not provided.\")\n\n    input_is_sequence = tree.is_nested(xs)\n    output_is_sequence = tree.is_nested(init)\n\n    def pack_input(x):\n        return tree.pack_sequence_as(xs, x) if input_is_sequence else x[0]\n\n    def pack_output(x):\n        return tree.pack_sequence_as(init, x) if output_is_sequence else x[0]\n\n    if xs is None:\n        xs_flat = []\n        n = int(length)\n    else:\n        xs_flat = tree.flatten(xs)\n        xs_flat = [convert_to_tensor(elem) for elem in xs_flat]\n        n = int(length) if length is not None else shape(xs_flat[0])[0]\n\n    init_flat = tree.flatten(init)\n    init_flat = [convert_to_tensor(init) for init in init_flat]\n    init = pack_output(init_flat)\n    dummy_y = [torch.zeros_like(init) for init in init_flat]\n\n    carry = init\n    ys = []\n    maybe_reversed = reversed if reverse else lambda x: x\n    for i in maybe_reversed(range(n)):\n        xs_slice = [x[i] for x in xs_flat]\n        packed_xs = pack_input(xs_slice) if len(xs_slice) > 0 else None\n        carry, y = f(carry, packed_xs)\n        ys.append(y if y is not None else dummy_y)\n    stacked_y = tree.map_structure(\n        lambda *ys: torch.stack(ys), *maybe_reversed(ys)\n    )\n    return carry, stacked_y\n\n\ndef associative_scan(f, elems, reverse=False, axis=0):\n    # Ref: jax.lax.associative_scan\n    if not callable(f):\n        raise TypeError(f\"`f` should be a callable. 
Received: f={f}\")\n    elems_flat = tree.flatten(elems)\n    elems_flat = [convert_to_tensor(elem) for elem in elems_flat]\n    if reverse:\n        elems_flat = [torch.flip(elem, (axis,)) for elem in elems_flat]\n\n    def _combine(a_flat, b_flat):\n        a_flat = [convert_to_tensor(a) for a in a_flat]\n        b_flat = [convert_to_tensor(b) for b in b_flat]\n\n        a = tree.pack_sequence_as(elems, a_flat)\n        b = tree.pack_sequence_as(elems, b_flat)\n        c = f(a, b)\n        c_flat = tree.flatten(c)\n        return c_flat\n\n    num_elems = int(elems_flat[0].shape[axis])\n    if not all(int(elem.shape[axis]) == num_elems for elem in elems_flat[1:]):\n        raise ValueError(\n            \"Array inputs to associative_scan must have the same \"\n            \"first dimension. (saw: {})\".format(\n                [elem.shape for elem in elems_flat]\n            )\n        )\n\n    def _interleave(a, b, axis):\n        \"\"\"Given two Tensors of static shape, interleave them along axis.\"\"\"\n        if not (\n            a.shape[axis] == b.shape[axis] or a.shape[axis] == b.shape[axis] + 1\n        ):\n            raise ValueError(\n                \"Shapes are incompatible for associative_scan interleaving. 
\"\n                f\"a.shape[{axis}]={a.shape[axis]}, \"\n                f\"b.shape[{axis}]={b.shape[axis]}\"\n            )\n\n        # we want to get a: [a1, a2], b: [b1, b2]\n        # to a: [a1, 0, a2, 0], b: [0, b1, 0, b2]\n        a_shape = list(a.shape)\n        a_shape[axis] = a.shape[axis] * 2 - 1\n\n        b_shape = list(b.shape)\n        b_shape[axis] = b.shape[axis] * 2 - 1\n\n        a_dil = torch.zeros(a_shape)\n        slice_along_axis(a_dil, 0, None, 2, axis).copy_(a)\n\n        b_dil = torch.zeros(b_shape)\n        slice_along_axis(b_dil, 0, None, 2, axis).copy_(b)\n\n        a_pad = [[0, 0] for _ in range(a.dim())]\n        a_pad[axis][-1] = 1 if a.shape[axis] == b.shape[axis] else 0\n        a_pad = a_pad[::-1]\n        a_pad = tree.flatten(a_pad)\n\n        b_pad = [[0, 0] for _ in range(b.dim())]\n        b_pad[axis] = [1, 0] if a.shape[axis] == b.shape[axis] else [1, 1]\n        b_pad = b_pad[::-1]\n        b_pad = tree.flatten(b_pad)\n\n        op = torch.bitwise_or if a.dtype == torch.bool else torch.add\n        return op(\n            torch.nn.functional.pad(a_dil, a_pad),\n            torch.nn.functional.pad(b_dil, b_pad),\n        )\n\n    def _scan(elems):\n        num_elems = elems[0].shape[axis]\n        if num_elems < 2:\n            return elems\n\n        reduced_elems = _combine(\n            [\n                slice_along_axis(elem, 0, -1, step=2, axis=axis)\n                for elem in elems\n            ],\n            [\n                slice_along_axis(elem, 1, None, step=2, axis=axis)\n                for elem in elems\n            ],\n        )\n\n        odd_elems = _scan(reduced_elems)\n        if num_elems % 2 == 0:\n            even_elems = _combine(\n                [slice_along_axis(e, 0, -1, axis=axis) for e in odd_elems],\n                [\n                    slice_along_axis(e, 2, None, step=2, axis=axis)\n                    for e in elems\n                ],\n            )\n        else:\n            
even_elems = _combine(\n                odd_elems,\n                [\n                    slice_along_axis(e, 2, None, step=2, axis=axis)\n                    for e in elems\n                ],\n            )\n\n        even_elems = [\n            torch.cat(\n                [slice_along_axis(elem, 0, 1, axis=axis), result],\n                dim=axis,\n            )\n            for (elem, result) in zip(elems, even_elems)\n        ]\n        return list(\n            builtins.map(\n                functools.partial(_interleave, axis=axis), even_elems, odd_elems\n            )\n        )\n\n    scans = _scan(elems_flat)\n    if reverse:\n        scans = [torch.flip(scanned, (axis,)) for scanned in scans]\n\n    return tree.pack_sequence_as(elems, scans)\n\n\ndef scatter(indices, values, shape):\n    indices = convert_to_tensor(indices)\n    values = convert_to_tensor(values)\n    zeros = torch.zeros(shape, dtype=values.dtype, device=get_device())\n\n    index_length = indices.shape[-1]\n    value_shape = shape[index_length:]\n    indices = torch.reshape(indices, [-1, index_length])\n    values = torch.reshape(values, [-1] + list(value_shape))\n\n    for i in range(indices.shape[0]):\n        index = indices[i]\n        zeros[tuple(index)] += values[i]\n    return zeros\n\n\ndef scatter_update(inputs, indices, updates, reduction=None):\n    inputs = convert_to_tensor(inputs)\n    indices = convert_to_tensor(indices, dtype=\"int64\")\n    updates = convert_to_tensor(updates, dtype=inputs.dtype)\n    indices = torch.transpose(indices, 0, 1)\n    idx = tuple(indices)\n\n    outputs = torch.clone(inputs)\n    if reduction is None:\n        outputs[idx] = updates\n    elif reduction == \"add\":\n        # Use index_put_ with accumulate=True for proper accumulation\n        outputs.index_put_(idx, updates, accumulate=True)\n    elif reduction == \"max\":\n        # Loop-based approach handles both scalar and slice updates.\n        # Associative, so sequential 
application handles duplicates.\n        indices_t = indices.T\n        for i in range(indices_t.shape[0]):\n            idx = tuple(indices_t[i])\n            outputs[idx] = torch.maximum(outputs[idx], updates[i])\n    elif reduction == \"min\":\n        indices_t = indices.T\n        for i in range(indices_t.shape[0]):\n            idx = tuple(indices_t[i])\n            outputs[idx] = torch.minimum(outputs[idx], updates[i])\n    elif reduction == \"mul\":\n        indices_t = indices.T\n        for i in range(indices_t.shape[0]):\n            idx = tuple(indices_t[i])\n            outputs[idx] = outputs[idx] * updates[i]\n    else:\n        raise ValueError(f\"Unsupported reduction: {reduction}\")\n    return outputs\n\n\ndef slice(inputs, start_indices, shape):\n    shape_dtype = to_torch_dtype(\"int64\")\n    inputs = convert_to_tensor(inputs)\n    start_indices = convert_to_tensor(start_indices).to(shape_dtype)\n    shape = convert_to_tensor(shape).to(shape_dtype)\n\n    python_slice = __builtins__[\"slice\"]\n    slices = [\n        python_slice(start_index, start_index + length)\n        for start_index, length in zip(start_indices, shape)\n    ]\n    return inputs[slices]\n\n\ndef slice_update(inputs, start_indices, updates):\n    shape_dtype = to_torch_dtype(\"int64\")\n    inputs = convert_to_tensor(inputs)\n    start_indices = convert_to_tensor(start_indices).to(shape_dtype)\n    updates = convert_to_tensor(updates)\n\n    python_slice = __builtins__[\"slice\"]\n    slices = [\n        python_slice(start_index, start_index + update_length)\n        for start_index, update_length in zip(start_indices, updates.shape)\n    ]\n    outputs = torch.clone(inputs)\n    outputs[slices] = updates\n    return outputs\n\n\ndef switch(index, branches, *operands):\n    index = convert_to_tensor(index, \"int32\")\n    index = torch.clamp(index, 0, len(branches) - 1)\n    return branches[index](*operands)\n\n\ndef while_loop(\n    cond,\n    body,\n    loop_vars,\n    
maximum_iterations=None,\n):\n    current_iter = 0\n    iteration_check = lambda iter: (\n        maximum_iterations is None or iter < maximum_iterations\n    )\n    is_tuple = isinstance(loop_vars, (tuple, list))\n    loop_vars = tuple(loop_vars) if is_tuple else (loop_vars,)\n    loop_vars = tree.map_structure(convert_to_tensor, loop_vars)\n    while cond(*loop_vars) and iteration_check(current_iter):\n        loop_vars = body(*loop_vars)\n        if not isinstance(loop_vars, (list, tuple)):\n            loop_vars = (loop_vars,)\n        loop_vars = tuple(loop_vars)\n        current_iter += 1\n    return loop_vars if is_tuple else loop_vars[0]\n\n\ndef fori_loop(lower, upper, body_fun, init_val):\n    val = init_val\n    for i in range(lower, upper):\n        val = body_fun(i, val)\n    return val\n\n\ndef stop_gradient(variable):\n    if isinstance(variable, Variable):\n        variable = variable.value\n    # We can't use `.requires_grad_(False)` here since it only\n    # works when the tensor is a leaf node in the graph.\n    return variable.detach()\n\n\ndef unstack(x, num=None, axis=0):\n    return x.unbind(axis)\n\n\ndef random_seed_dtype():\n    # uint32 doesn't exist in torch. 
Seeds are conceptually uint32 values;\n    # int32 is used and the bit pattern is reinterpreted as uint32 at each\n    # call site (torch_seed_generator / torch.manual_seed) via & 0xFFFFFFFF.\n    return \"int32\"\n\n\ndef remat(f):\n    \"\"\"Implementation of rematerialization.\n\n    Args:\n        f: The function or operation to rematerialize.\n    Returns:\n        A function wrapping f that defines a custom gradient, which\n        recomputes f on the backwards pass of a gradient call.\n    \"\"\"\n\n    def wrapped(*args, **kwargs):\n        return torch.utils.checkpoint.checkpoint(\n            f, *args, use_reentrant=False, **kwargs\n        )\n\n    return wrapped\n\n\nclass custom_gradient:\n    \"\"\"Decorator for custom gradients.\n\n    Args:\n        forward_fn: Forward pass function.\n    \"\"\"\n\n    def __init__(self, forward_fn):\n        self.forward_fn = forward_fn\n\n    def __call__(self, *args, **kwargs):\n        return CustomGradientFunction.apply(self.forward_fn, *args, **kwargs)\n\n\nclass CustomGradientFunction(torch.autograd.Function):\n    \"\"\"Enables custom forward & backward passes for gradient computation.\"\"\"\n\n    @staticmethod\n    def forward(ctx, forward_fn, *args, **kwargs):\n        \"\"\"Forward pass computation specification.\n\n        Args:\n            ctx: Context object.\n            forward_fn: Function to compute forward pass.\n            *args: Arguments for the forward pass.\n            **kwargs: Keyword arguments for the forward pass.\n        \"\"\"\n        ctx.forward_fn = forward_fn\n        ctx.save_for_backward(*args)\n        try:\n            output, ctx.grad_fn = forward_fn(*args, **kwargs)\n        except:\n            output = forward_fn(*args, **kwargs)\n            ctx.grad_fn = lambda *args, **kwargs: torch.full((), float(\"nan\"))\n        return output\n\n    @staticmethod\n    def backward(ctx, grad_output):\n        \"\"\"Backward pass computation specification.\n\n        Args:\n        
    ctx: Context object.\n            grad_output: Gradient with respect to the output.\n        \"\"\"\n        args = ctx.saved_tensors\n        grad_fn = ctx.grad_fn\n        if grad_fn is None:\n            raise ValueError(\"grad_fn must be provided for custom gradient\")\n        grads = grad_fn(*args, upstream=grad_output)\n        if not isinstance(grads, tuple):\n            grads = (grads,)\n        return (None,) + grads\n"
  },
  {
    "path": "keras/src/backend/torch/export.py",
    "content": "import copy\nimport warnings\n\nimport torch\n\nfrom keras.src import layers\nfrom keras.src import tree\nfrom keras.src.export.export_utils import convert_spec_to_tensor\nfrom keras.src.export.export_utils import make_tf_tensor_spec\nfrom keras.src.export.saved_model_export_archive import SavedModelExportArchive\nfrom keras.src.utils.module_utils import tensorflow as tf\nfrom keras.src.utils.module_utils import torch_xla\n\n\nclass TorchExportArchive(SavedModelExportArchive):\n    \"\"\"Torch backend implementation of SavedModel export archive.\"\"\"\n\n    def _backend_track_layer(self, layer):\n        raise NotImplementedError(\n            \"`track` is not supported for `Layer`s and `Model`s in the torch \"\n            \"backend. Use `track_and_add_endpoint` instead.\"\n        )\n\n    def _backend_add_endpoint(self, name, fn, input_signature, **kwargs):\n        raise NotImplementedError(\n            \"`add_endpoint` is not supported for `Layer`s and `Model`s in the \"\n            \"torch backend. Use `track_and_add_endpoint` instead.\"\n        )\n\n    def track_and_add_endpoint(self, name, resource, input_signature, **kwargs):\n        if name in self._endpoint_names:\n            raise ValueError(f\"Endpoint name '{name}' is already taken.\")\n        if not isinstance(resource, layers.Layer):\n            raise ValueError(\n                \"Invalid resource type. Expected an instance of a Keras \"\n                \"`Layer` or `Model`. \"\n                f\"Received: resource={resource} (of type {type(resource)})\"\n            )\n        if not resource.built:\n            raise ValueError(\n                \"The layer provided has not yet been built. 
\"\n                \"It must be built before export.\"\n            )\n\n        input_signature = tree.map_structure(\n            make_tf_tensor_spec, input_signature\n        )\n        # Disable false alarms related to lifting parameters.\n        warnings.filterwarnings(\"ignore\", message=\".*created when tracing.*\")\n        warnings.filterwarnings(\n            \"ignore\", message=\".*Unable to find the path of the module.*\"\n        )\n\n        if not isinstance(resource, torch.nn.Module):\n            raise TypeError(\n                \"`resource` must be an instance of `torch.nn.Module`. \"\n                f\"Received: resource={resource} (of type {type(resource)})\"\n            )\n\n        sample_inputs = tree.map_structure(\n            lambda x: convert_spec_to_tensor(x, replace_none_number=2),\n            input_signature,\n        )\n        sample_inputs = tuple(sample_inputs)\n\n        # Build dynamic_shapes from input_signature where shape has None\n        # Use a shared \"batch\" dim for dimension 0 across all inputs\n        batch_dim = torch.export.Dim(\"batch\", min=1)\n        dynamic_shapes = []\n        for spec in input_signature:\n            dim_spec = {}\n            for dim_idx, dim_val in enumerate(spec.shape):\n                if dim_val is None:\n                    if dim_idx == 0:\n                        dim_spec[dim_idx] = batch_dim\n                    else:\n                        dim_spec[dim_idx] = torch.export.Dim(\n                            f\"dim_{len(dynamic_shapes)}_{dim_idx}\", min=1\n                        )\n            dynamic_shapes.append(dim_spec if dim_spec else None)\n        dynamic_shapes = tuple(dynamic_shapes) if any(dynamic_shapes) else None\n\n        # Ref: torch_xla.tf_saved_model_integration\n        exported = torch.export.export(\n            resource, sample_inputs, dynamic_shapes=dynamic_shapes, strict=False\n        )\n        options = torch_xla.stablehlo.StableHLOExportOptions(\n   
         override_tracing_arguments=sample_inputs\n        )\n        stablehlo_model = torch_xla.stablehlo.exported_program_to_stablehlo(\n            exported, options\n        )\n        state_dict_keys = list(stablehlo_model._bundle.state_dict.keys())\n\n        # Remove unused variables.\n        for k in state_dict_keys:\n            if \"lifted\" not in k:\n                stablehlo_model._bundle.state_dict.pop(k)\n\n        bundle = copy.deepcopy(stablehlo_model._bundle)\n        bundle.state_dict = {\n            k: tf.Variable(v, trainable=False, name=k)\n            for k, v in bundle.state_dict.items()\n        }\n        bundle.additional_constants = [\n            tf.Variable(v, trainable=False) for v in bundle.additional_constants\n        ]\n\n        # Track variables in `bundle` for `write_out`.\n        self._tf_trackable.variables += (\n            list(bundle.state_dict.values()) + bundle.additional_constants\n        )\n\n        # Ref: torch_xla.tf_saved_model_integration.save_stablehlo_graph_as_tf\n        def make_tf_function(func, bundle):\n            from tensorflow.compiler.tf2xla.python import xla as tfxla\n\n            def _get_shape_with_dynamic(signature):\n                shape = copy.copy(signature.shape)\n                for i in signature.dynamic_dims:\n                    shape[i] = None\n                return shape\n\n            def _extract_call_parameters(args, meta, bundle):\n                call_args = []\n                if meta.input_pytree_spec is not None:\n                    args = tree.flatten(args)\n                for loc in meta.input_locations:\n                    if loc.type_ == torch_xla.stablehlo.VariableType.PARAMETER:\n                        call_args.append(bundle.state_dict[loc.name])\n                    elif loc.type_ == torch_xla.stablehlo.VariableType.CONSTANT:\n                        call_args.append(\n                            bundle.additional_constants[loc.position]\n                       
 )\n                    else:\n                        call_args.append(args[loc.position])\n                return call_args\n\n            def inner(*args):\n                Touts = [sig.dtype for sig in func.meta.output_signature]\n                Souts = [\n                    _get_shape_with_dynamic(sig)\n                    for sig in func.meta.output_signature\n                ]\n                call_args = _extract_call_parameters(args, func.meta, bundle)\n                results = tfxla.call_module(\n                    tuple(call_args),\n                    version=5,\n                    Tout=Touts,  # dtype information\n                    Sout=Souts,  # Shape information\n                    function_list=[],\n                    module=func.bytecode,\n                )\n                if len(Souts) == 1:\n                    results = results[0]\n                return results\n\n            return inner\n\n        decorated_fn = tf.function(\n            make_tf_function(\n                stablehlo_model._bundle.stablehlo_funcs[0], bundle\n            ),\n            input_signature=input_signature,\n        )\n        self._endpoint_signatures[name] = input_signature\n        setattr(self._tf_trackable, name, decorated_fn)\n        self._endpoint_names.append(name)\n        return decorated_fn\n"
  },
  {
    "path": "keras/src/backend/torch/image.py",
    "content": "import functools\nimport itertools\nimport operator\n\nimport numpy as np\nimport torch\nimport torch._dynamo as dynamo\nimport torch.nn.functional as F\n\nfrom keras.src import backend\nfrom keras.src.backend.torch.core import cast\nfrom keras.src.backend.torch.core import convert_to_tensor\nfrom keras.src.backend.torch.core import get_device\nfrom keras.src.backend.torch.core import to_torch_dtype\nfrom keras.src.random.seed_generator import draw_seed\n\nRESIZE_INTERPOLATIONS = {\n    \"bilinear\": \"bilinear\",\n    \"nearest\": \"nearest-exact\",\n    \"bicubic\": \"bicubic\",\n}\nUNSUPPORTED_INTERPOLATIONS = (\n    \"lanczos3\",\n    \"lanczos5\",\n)\nAFFINE_TRANSFORM_INTERPOLATIONS = {\n    \"nearest\": 0,\n    \"bilinear\": 1,\n}\nAFFINE_TRANSFORM_FILL_MODES = {\n    \"constant\",\n    \"nearest\",\n    \"wrap\",\n    \"mirror\",\n    \"reflect\",\n}\nSCALE_AND_TRANSLATE_METHODS = {\n    \"linear\",\n    \"bilinear\",\n    \"trilinear\",\n    \"cubic\",\n    \"bicubic\",\n    \"tricubic\",\n    \"lanczos3\",\n    \"lanczos5\",\n}\n\n\ndef rgb_to_grayscale(images, data_format=None):\n    images = convert_to_tensor(images)\n    data_format = backend.standardize_data_format(data_format)\n    if images.ndim not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    channel_axis = -3 if data_format == \"channels_first\" else -1\n    if images.shape[channel_axis] not in (1, 3):\n        raise ValueError(\n            \"Invalid channel size: expected 3 (RGB) or 1 (Grayscale). 
\"\n            f\"Received input with shape: images.shape={images.shape}\"\n        )\n\n    # This implementation is based on\n    # https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py\n    if images.shape[channel_axis] == 3:\n        r, g, b = images.unbind(dim=channel_axis)\n        images = (0.2989 * r + 0.587 * g + 0.114 * b).to(images.dtype)\n        images = images.unsqueeze(dim=channel_axis)\n    else:\n        images = images.clone()\n    return images\n\n\ndef rgb_to_hsv(images, data_format=None):\n    # Ref: dm_pix\n    images = convert_to_tensor(images)\n    dtype = images.dtype\n    data_format = backend.standardize_data_format(data_format)\n    channels_axis = -1 if data_format == \"channels_last\" else -3\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if not backend.is_float_dtype(dtype):\n        raise ValueError(\n            \"Invalid images dtype: expected float dtype. 
\"\n            f\"Received: images.dtype={backend.standardize_dtype(dtype)}\"\n        )\n    eps = torch.finfo(dtype).eps\n    images = torch.where(torch.abs(images) < eps, 0.0, images)\n    red, green, blue = torch.split(images, [1, 1, 1], channels_axis)\n    red = torch.squeeze(red, channels_axis)\n    green = torch.squeeze(green, channels_axis)\n    blue = torch.squeeze(blue, channels_axis)\n\n    def rgb_planes_to_hsv_planes(r, g, b):\n        value = torch.maximum(torch.maximum(r, g), b)\n        minimum = torch.minimum(torch.minimum(r, g), b)\n        range_ = value - minimum\n\n        safe_value = torch.where(value > 0, value, 1.0)\n        safe_range = torch.where(range_ > 0, range_, 1.0)\n\n        saturation = torch.where(value > 0, range_ / safe_value, 0.0)\n        norm = 1.0 / (6.0 * safe_range)\n\n        hue = torch.where(\n            value == g,\n            norm * (b - r) + 2.0 / 6.0,\n            norm * (r - g) + 4.0 / 6.0,\n        )\n        hue = torch.where(value == r, norm * (g - b), hue)\n        hue = torch.where(range_ > 0, hue, 0.0) + (hue < 0.0).to(hue.dtype)\n        return hue, saturation, value\n\n    images = torch.stack(\n        rgb_planes_to_hsv_planes(red, green, blue), axis=channels_axis\n    )\n    return images\n\n\ndef hsv_to_rgb(images, data_format=None):\n    # Ref: dm_pix\n    images = convert_to_tensor(images)\n    dtype = images.dtype\n    data_format = backend.standardize_data_format(data_format)\n    channels_axis = -1 if data_format == \"channels_last\" else -3\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if not backend.is_float_dtype(dtype):\n        raise ValueError(\n            \"Invalid images dtype: expected float dtype. 
\"\n            f\"Received: images.dtype={backend.standardize_dtype(dtype)}\"\n        )\n    hue, saturation, value = torch.split(images, [1, 1, 1], channels_axis)\n    hue = torch.squeeze(hue, channels_axis)\n    saturation = torch.squeeze(saturation, channels_axis)\n    value = torch.squeeze(value, channels_axis)\n\n    def hsv_planes_to_rgb_planes(hue, saturation, value):\n        dh = torch.remainder(hue, 1.0) * 6.0\n        dr = torch.clip(torch.abs(dh - 3.0) - 1.0, 0.0, 1.0)\n        dg = torch.clip(2.0 - torch.abs(dh - 2.0), 0.0, 1.0)\n        db = torch.clip(2.0 - torch.abs(dh - 4.0), 0.0, 1.0)\n        one_minus_s = 1.0 - saturation\n\n        red = value * (one_minus_s + saturation * dr)\n        green = value * (one_minus_s + saturation * dg)\n        blue = value * (one_minus_s + saturation * db)\n        return red, green, blue\n\n    images = torch.stack(\n        hsv_planes_to_rgb_planes(hue, saturation, value), axis=channels_axis\n    )\n    return images\n\n\ndef _cast_squeeze_in(image, req_dtypes):\n    need_squeeze = False\n    # make image NCHW\n    if image.ndim < 4:\n        image = image.unsqueeze(dim=0)\n        need_squeeze = True\n\n    out_dtype = image.dtype\n    need_cast = False\n    if out_dtype not in req_dtypes:\n        need_cast = True\n        req_dtype = req_dtypes[0]\n        image = image.to(req_dtype)\n    return image, need_cast, need_squeeze, out_dtype\n\n\ndef _cast_squeeze_out(image, need_cast, need_squeeze, out_dtype):\n    if need_squeeze:\n        image = image.squeeze(dim=0)\n\n    if need_cast:\n        if out_dtype in (\n            torch.uint8,\n            torch.int8,\n            torch.int16,\n            torch.int32,\n            torch.int64,\n        ):\n            # it is better to round before cast\n            image = torch.round(image)\n        image = image.to(out_dtype)\n    return image\n\n\ndef resize(\n    images,\n    size,\n    interpolation=\"bilinear\",\n    antialias=False,\n    
crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n    fill_mode=\"constant\",\n    fill_value=0.0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if interpolation in UNSUPPORTED_INTERPOLATIONS:\n        raise ValueError(\n            \"Resizing with Lanczos interpolation is \"\n            \"not supported by the PyTorch backend. \"\n            f\"Received: interpolation={interpolation}.\"\n        )\n    if interpolation not in RESIZE_INTERPOLATIONS:\n        raise ValueError(\n            \"Invalid value for argument `interpolation`. Expected one of \"\n            f\"{RESIZE_INTERPOLATIONS}. Received: interpolation={interpolation}\"\n        )\n    if fill_mode != \"constant\":\n        raise ValueError(\n            \"Invalid value for argument `fill_mode`. Only `'constant'` \"\n            f\"is supported. Received: fill_mode={fill_mode}\"\n        )\n    if pad_to_aspect_ratio and crop_to_aspect_ratio:\n        raise ValueError(\n            \"Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` \"\n            \"can be `True`.\"\n        )\n    if len(size) != 2:\n        raise ValueError(\n            \"Argument `size` must be a tuple of two elements \"\n            f\"(height, width). Received: size={size}\"\n        )\n    size = tuple(size)\n    images = convert_to_tensor(images)\n    if images.ndim not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). 
Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    images, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(\n        images, [torch.float32, torch.float64]\n    )\n    if data_format == \"channels_last\":\n        images = images.permute((0, 3, 1, 2))\n\n    if crop_to_aspect_ratio:\n        shape = images.shape\n        height, width = shape[-2], shape[-1]\n        target_height, target_width = size\n        crop_height = int(float(width * target_height) / target_width)\n        crop_height = max(min(height, crop_height), 1)\n        crop_width = int(float(height * target_width) / target_height)\n        crop_width = max(min(width, crop_width), 1)\n        crop_box_hstart = int(float(height - crop_height) / 2)\n        crop_box_wstart = int(float(width - crop_width) / 2)\n        images = images[\n            :,\n            :,\n            crop_box_hstart : crop_box_hstart + crop_height,\n            crop_box_wstart : crop_box_wstart + crop_width,\n        ]\n    elif pad_to_aspect_ratio:\n        shape = images.shape\n        height, width = shape[-2], shape[-1]\n        target_height, target_width = size\n        pad_height = int(float(width * target_height) / target_width)\n        pad_height = max(height, pad_height)\n        pad_width = int(float(height * target_width) / target_height)\n        pad_width = max(width, pad_width)\n        img_box_hstart = int(float(pad_height - height) / 2)\n        img_box_wstart = int(float(pad_width - width) / 2)\n\n        batch_size = images.shape[0]\n        channels = images.shape[1]\n        if img_box_hstart > 0:\n            padded_img = torch.cat(\n                [\n                    torch.ones(\n                        (batch_size, channels, img_box_hstart, width),\n                        dtype=images.dtype,\n                        device=images.device,\n                    )\n                    * fill_value,\n                    images,\n                    
torch.ones(\n                        (batch_size, channels, img_box_hstart, width),\n                        dtype=images.dtype,\n                        device=images.device,\n                    )\n                    * fill_value,\n                ],\n                axis=2,\n            )\n        else:\n            padded_img = images\n        if img_box_wstart > 0:\n            padded_img = torch.cat(\n                [\n                    torch.ones(\n                        (batch_size, channels, height, img_box_wstart),\n                        dtype=images.dtype,\n                        device=images.device,\n                    )\n                    * fill_value,\n                    padded_img,\n                    torch.ones(\n                        (batch_size, channels, height, img_box_wstart),\n                        dtype=images.dtype,\n                        device=images.device,\n                    )\n                    * fill_value,\n                ],\n                axis=3,\n            )\n        images = padded_img\n\n    # This implementation is based on\n    # https://github.com/pytorch/vision/blob/main/torchvision/transforms/_functional_tensor.py\n    if antialias and interpolation not in (\"bilinear\", \"bicubic\"):\n        # We manually set it to False to avoid an error downstream in\n        # interpolate(). This behaviour is documented: the parameter is\n        # irrelevant for modes that are not bilinear or bicubic. 
We used to\n        # raise an error here, but now we don't use True as the default.\n        antialias = False\n    # Define align_corners to avoid warnings\n    align_corners = False if interpolation in (\"bilinear\", \"bicubic\") else None\n    resized = F.interpolate(\n        images,\n        size=size,\n        mode=RESIZE_INTERPOLATIONS[interpolation],\n        align_corners=align_corners,\n        antialias=antialias,\n    )\n    if interpolation == \"bicubic\" and out_dtype == torch.uint8:\n        resized = resized.clamp(min=0, max=255)\n    if data_format == \"channels_last\":\n        resized = resized.permute((0, 2, 3, 1))\n    resized = _cast_squeeze_out(\n        resized,\n        need_cast=need_cast,\n        need_squeeze=need_squeeze,\n        out_dtype=out_dtype,\n    )\n    return resized\n\n\ndef affine_transform(\n    images,\n    transform,\n    interpolation=\"bilinear\",\n    fill_mode=\"constant\",\n    fill_value=0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():\n        raise ValueError(\n            \"Invalid value for argument `interpolation`. Expected one of \"\n            f\"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: \"\n            f\"interpolation={interpolation}\"\n        )\n    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:\n        raise ValueError(\n            \"Invalid value for argument `fill_mode`. Expected one of \"\n            f\"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}\"\n        )\n\n    images = convert_to_tensor(images)\n    transform = convert_to_tensor(transform)\n\n    if images.ndim not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). 
Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if transform.ndim not in (1, 2):\n        raise ValueError(\n            \"Invalid transform rank: expected rank 1 (single transform) \"\n            \"or rank 2 (batch of transforms). Received input with shape: \"\n            f\"transform.shape={transform.shape}\"\n        )\n\n    # unbatched case\n    need_squeeze = False\n    if images.ndim == 3:\n        images = images.unsqueeze(dim=0)\n        need_squeeze = True\n    if transform.ndim == 1:\n        transform = transform.unsqueeze(dim=0)\n\n    if data_format == \"channels_first\":\n        images = images.permute((0, 2, 3, 1))\n\n    batch_size = images.shape[0]\n\n    # get indices\n    meshgrid = torch.meshgrid(\n        *[\n            torch.arange(size, dtype=transform.dtype, device=transform.device)\n            for size in images.shape[1:]\n        ],\n        indexing=\"ij\",\n    )\n    indices = torch.concatenate(\n        [torch.unsqueeze(x, dim=-1) for x in meshgrid], dim=-1\n    )\n    indices = torch.tile(indices, (batch_size, 1, 1, 1, 1))\n\n    # swap the values\n    a0 = transform[:, 0].clone()\n    a2 = transform[:, 2].clone()\n    b1 = transform[:, 4].clone()\n    b2 = transform[:, 5].clone()\n    transform[:, 0] = b1\n    transform[:, 2] = b2\n    transform[:, 4] = a0\n    transform[:, 5] = a2\n\n    # deal with transform\n    transform = torch.nn.functional.pad(\n        transform, pad=[0, 1, 0, 0], mode=\"constant\", value=1\n    )\n    transform = torch.reshape(transform, (batch_size, 3, 3))\n    offset = transform[:, 0:2, 2].clone()\n    offset = torch.nn.functional.pad(offset, pad=[0, 1, 0, 0])\n    transform[:, 0:2, 2] = 0\n\n    # transform the indices\n    coordinates = torch.einsum(\"Bhwij, Bjk -> Bhwik\", indices, transform)\n    coordinates = torch.moveaxis(coordinates, source=-1, destination=1)\n    coordinates += torch.reshape(offset, shape=(*offset.shape, 1, 1, 1))\n\n    # Note: 
torch.stack is faster than torch.vmap when the batch size is small.\n    affined = torch.stack(\n        [\n            map_coordinates(\n                images[i],\n                coordinates[i],\n                order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                fill_mode=fill_mode,\n                fill_value=fill_value,\n            )\n            for i in range(len(images))\n        ],\n    )\n\n    if data_format == \"channels_first\":\n        affined = affined.permute((0, 3, 1, 2))\n    if need_squeeze:\n        affined = affined.squeeze(dim=0)\n    return affined\n\n\ndef perspective_transform(\n    images,\n    start_points,\n    end_points,\n    interpolation=\"bilinear\",\n    fill_value=0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n\n    images = convert_to_tensor(images)\n    dtype = backend.standardize_dtype(images.dtype)\n    start_points = convert_to_tensor(start_points, dtype=dtype)\n    end_points = convert_to_tensor(end_points, dtype=dtype)\n\n    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():\n        raise ValueError(\n            \"Invalid value for argument `interpolation`. Expected one of \"\n            f\"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: \"\n            f\"interpolation={interpolation}\"\n        )\n\n    if images.ndim not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    if start_points.shape[-2:] != (4, 2) or start_points.dim() not in (2, 3):\n        raise ValueError(\n            \"Invalid start_points shape: expected (4,2) for a single image\"\n            f\" or (N,4,2) for a batch. 
Received shape: {start_points.shape}\"\n        )\n    if end_points.shape[-2:] != (4, 2) or end_points.dim() not in (2, 3):\n        raise ValueError(\n            \"Invalid end_points shape: expected (4,2) for a single image\"\n            f\" or (N,4,2) for a batch. Received shape: {end_points.shape}\"\n        )\n    if start_points.shape != end_points.shape:\n        raise ValueError(\n            \"start_points and end_points must have the same shape.\"\n            f\" Received start_points.shape={start_points.shape}, \"\n            f\"end_points.shape={end_points.shape}\"\n        )\n\n    need_squeeze = False\n    if images.ndim == 3:\n        images = images.unsqueeze(dim=0)\n        need_squeeze = True\n\n    if start_points.ndim == 2:\n        start_points = start_points.unsqueeze(dim=0)\n    if end_points.ndim == 2:\n        end_points = end_points.unsqueeze(dim=0)\n\n    if data_format == \"channels_first\":\n        images = images.permute((0, 2, 3, 1))\n\n    batch_size, height, width, channels = images.shape\n\n    transforms = compute_homography_matrix(start_points, end_points)\n\n    if transforms.dim() == 1:\n        transforms = transforms.unsqueeze(0)\n    if transforms.shape[0] == 1 and batch_size > 1:\n        transforms = transforms.repeat(batch_size, 1)\n\n    grid_x, grid_y = torch.meshgrid(\n        torch.arange(width, dtype=to_torch_dtype(dtype), device=images.device),\n        torch.arange(height, dtype=to_torch_dtype(dtype), device=images.device),\n        indexing=\"xy\",\n    )\n\n    output = torch.empty(\n        [batch_size, height, width, channels],\n        dtype=to_torch_dtype(dtype),\n        device=images.device,\n    )\n\n    for i in range(batch_size):\n        a0, a1, a2, a3, a4, a5, a6, a7 = transforms[i]\n        denom = a6 * grid_x + a7 * grid_y + 1.0\n        x_in = (a0 * grid_x + a1 * grid_y + a2) / denom\n        y_in = (a3 * grid_x + a4 * grid_y + a5) / denom\n\n        coords = torch.stack([y_in.flatten(), 
x_in.flatten()], dim=0)\n        mapped_channels = []\n        for channel in range(channels):\n            channel_img = images[i, :, :, channel]\n            mapped_channel = map_coordinates(\n                channel_img,\n                coords,\n                order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                fill_mode=\"constant\",\n                fill_value=fill_value,\n            )\n            mapped_channels.append(mapped_channel.reshape(height, width))\n        output[i] = torch.stack(mapped_channels, dim=-1)\n\n    if data_format == \"channels_first\":\n        output = output.permute((0, 3, 1, 2))\n    if need_squeeze:\n        output = output.squeeze(dim=0)\n\n    return output\n\n\ndef compute_homography_matrix(start_points, end_points):\n    start_points = convert_to_tensor(start_points)\n    end_points = convert_to_tensor(end_points)\n    dtype = backend.result_type(start_points.dtype, end_points.dtype, float)\n    # `torch.linalg.solve` requires float32.\n    compute_dtype = backend.result_type(dtype, \"float32\")\n    start_points = cast(start_points, dtype)\n    end_points = cast(end_points, dtype)\n\n    start_x1, start_y1 = start_points[:, 0, 0], start_points[:, 0, 1]\n    start_x2, start_y2 = start_points[:, 1, 0], start_points[:, 1, 1]\n    start_x3, start_y3 = start_points[:, 2, 0], start_points[:, 2, 1]\n    start_x4, start_y4 = start_points[:, 3, 0], start_points[:, 3, 1]\n\n    end_x1, end_y1 = end_points[:, 0, 0], end_points[:, 0, 1]\n    end_x2, end_y2 = end_points[:, 1, 0], end_points[:, 1, 1]\n    end_x3, end_y3 = end_points[:, 2, 0], end_points[:, 2, 1]\n    end_x4, end_y4 = end_points[:, 3, 0], end_points[:, 3, 1]\n\n    coefficient_matrix = torch.stack(\n        [\n            torch.stack(\n                [\n                    end_x1,\n                    end_y1,\n                    torch.ones_like(end_x1),\n                    torch.zeros_like(end_x1),\n                    torch.zeros_like(end_x1),\n     
               torch.zeros_like(end_x1),\n                    -start_x1 * end_x1,\n                    -start_x1 * end_y1,\n                ],\n                dim=-1,\n            ),\n            torch.stack(\n                [\n                    torch.zeros_like(end_x1),\n                    torch.zeros_like(end_x1),\n                    torch.zeros_like(end_x1),\n                    end_x1,\n                    end_y1,\n                    torch.ones_like(end_x1),\n                    -start_y1 * end_x1,\n                    -start_y1 * end_y1,\n                ],\n                dim=-1,\n            ),\n            torch.stack(\n                [\n                    end_x2,\n                    end_y2,\n                    torch.ones_like(end_x2),\n                    torch.zeros_like(end_x2),\n                    torch.zeros_like(end_x2),\n                    torch.zeros_like(end_x2),\n                    -start_x2 * end_x2,\n                    -start_x2 * end_y2,\n                ],\n                dim=-1,\n            ),\n            torch.stack(\n                [\n                    torch.zeros_like(end_x2),\n                    torch.zeros_like(end_x2),\n                    torch.zeros_like(end_x2),\n                    end_x2,\n                    end_y2,\n                    torch.ones_like(end_x2),\n                    -start_y2 * end_x2,\n                    -start_y2 * end_y2,\n                ],\n                dim=-1,\n            ),\n            torch.stack(\n                [\n                    end_x3,\n                    end_y3,\n                    torch.ones_like(end_x3),\n                    torch.zeros_like(end_x3),\n                    torch.zeros_like(end_x3),\n                    torch.zeros_like(end_x3),\n                    -start_x3 * end_x3,\n                    -start_x3 * end_y3,\n                ],\n                dim=-1,\n            ),\n            torch.stack(\n                [\n                    
torch.zeros_like(end_x3),\n                    torch.zeros_like(end_x3),\n                    torch.zeros_like(end_x3),\n                    end_x3,\n                    end_y3,\n                    torch.ones_like(end_x3),\n                    -start_y3 * end_x3,\n                    -start_y3 * end_y3,\n                ],\n                dim=-1,\n            ),\n            torch.stack(\n                [\n                    end_x4,\n                    end_y4,\n                    torch.ones_like(end_x4),\n                    torch.zeros_like(end_x4),\n                    torch.zeros_like(end_x4),\n                    torch.zeros_like(end_x4),\n                    -start_x4 * end_x4,\n                    -start_x4 * end_y4,\n                ],\n                dim=-1,\n            ),\n            torch.stack(\n                [\n                    torch.zeros_like(end_x4),\n                    torch.zeros_like(end_x4),\n                    torch.zeros_like(end_x4),\n                    end_x4,\n                    end_y4,\n                    torch.ones_like(end_x4),\n                    -start_y4 * end_x4,\n                    -start_y4 * end_y4,\n                ],\n                dim=-1,\n            ),\n        ],\n        dim=1,\n    )\n\n    target_vector = torch.stack(\n        [\n            start_x1,\n            start_y1,\n            start_x2,\n            start_y2,\n            start_x3,\n            start_y3,\n            start_x4,\n            start_y4,\n        ],\n        dim=-1,\n    ).unsqueeze(-1)\n\n    coefficient_matrix = cast(coefficient_matrix, compute_dtype)\n    target_vector = cast(target_vector, compute_dtype)\n    homography_matrix = torch.linalg.solve(coefficient_matrix, target_vector)\n    homography_matrix = homography_matrix.reshape(-1, 8)\n    homography_matrix = cast(homography_matrix, dtype)\n    return homography_matrix\n\n\ndef _mirror_index_fixer(index, size):\n    s = size - 1  # Half-wavelength of triangular wave\n    
# Scaled, integer-valued version of the triangular wave |x - round(x)|\n    return torch.abs((index + s) % (2 * s) - s)\n\n\ndef _reflect_index_fixer(index, size):\n    return torch.floor_divide(\n        _mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2\n    )\n\n\n_INDEX_FIXERS = {\n    # we need to take care of out-of-bound indices in torch\n    \"constant\": lambda index, size: torch.clip(index, 0, size - 1),\n    \"nearest\": lambda index, size: torch.clip(index, 0, size - 1),\n    \"wrap\": lambda index, size: index % size,\n    \"mirror\": _mirror_index_fixer,\n    \"reflect\": _reflect_index_fixer,\n}\n\n\ndef _is_integer(a):\n    if not torch.is_floating_point(a) and not torch.is_complex(a):\n        return True\n    return False\n\n\ndef _nearest_indices_and_weights(coordinate):\n    coordinate = (\n        coordinate if _is_integer(coordinate) else torch.round(coordinate)\n    )\n    index = coordinate.to(torch.int32)\n    return [(index, 1)]\n\n\ndef _linear_indices_and_weights(coordinate):\n    lower = torch.floor(coordinate)\n    upper_weight = coordinate - lower\n    lower_weight = 1 - upper_weight\n    index = lower.to(torch.int32)\n    return [(index, lower_weight), (index + 1, upper_weight)]\n\n\ndef map_coordinates(\n    inputs, coordinates, order, fill_mode=\"constant\", fill_value=0.0\n):\n    input_arr = convert_to_tensor(inputs)\n    coordinate_arrs = [convert_to_tensor(c) for c in coordinates]\n\n    if len(coordinate_arrs) != len(input_arr.shape):\n        raise ValueError(\n            \"First dim of `coordinates` must be the same as the rank of \"\n            \"`inputs`. 
\"\n            f\"Received inputs with shape: {input_arr.shape} and coordinate \"\n            f\"leading dim of {len(coordinate_arrs)}\"\n        )\n    if len(coordinate_arrs[0].shape) < 1:\n        dim = len(coordinate_arrs)\n        shape = (dim,) + coordinate_arrs[0].shape\n        raise ValueError(\n            \"Invalid coordinates rank: expected at least rank 2.\"\n            f\" Received input with shape: {shape}\"\n        )\n\n    # Skip tensor creation when possible.\n    if isinstance(fill_value, (int, float)) and _is_integer(input_arr):\n        fill_value = int(fill_value)\n\n    index_fixer = _INDEX_FIXERS.get(fill_mode)\n    if index_fixer is None:\n        raise ValueError(\n            \"Invalid value for argument `fill_mode`. Expected one of \"\n            f\"{set(_INDEX_FIXERS.keys())}. 
Received: fill_mode={fill_mode}\"\n        )\n\n    if order == 0:\n        interp_fun = _nearest_indices_and_weights\n    elif order == 1:\n        interp_fun = _linear_indices_and_weights\n    else:\n        raise NotImplementedError(\"map_coordinates currently requires order<=1\")\n\n    if fill_mode == \"constant\":\n\n        def is_valid(index, size):\n            return (0 <= index) & (index < size)\n\n    else:\n\n        def is_valid(index, size):\n            return True\n\n    valid_1d_interpolations = []\n    for coordinate, size in zip(coordinate_arrs, input_arr.shape):\n        interp_nodes = interp_fun(coordinate)\n        valid_interp = []\n        for index, weight in interp_nodes:\n            fixed_index = index_fixer(index, size)\n            valid = is_valid(index, size)\n            valid_interp.append((fixed_index, valid, weight))\n        valid_1d_interpolations.append(valid_interp)\n\n    outputs = []\n    for items in itertools.product(*valid_1d_interpolations):\n        indices, validities, weights = zip(*items)\n        if all(valid is True for valid in validities):\n            # fast path\n            contribution = input_arr[indices]\n        else:\n            all_valid = functools.reduce(operator.and_, validities)\n            contribution = torch.where(\n                all_valid, input_arr[indices], fill_value\n            )\n        outputs.append(functools.reduce(operator.mul, weights) * contribution)\n    result = functools.reduce(operator.add, outputs)\n    if _is_integer(input_arr):\n        result = result if _is_integer(result) else torch.round(result)\n    return result.to(input_arr.dtype)\n\n\ndef gaussian_blur(\n    images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None\n):\n    def _create_gaussian_kernel(kernel_size, sigma, dtype):\n        def _get_gaussian_kernel1d(size, sigma):\n            x = (\n                torch.arange(size, dtype=dtype, device=sigma.device)\n                - (size - 1) / 2\n         
   )\n            kernel1d = torch.exp(-0.5 * (x / sigma) ** 2)\n            return kernel1d / torch.sum(kernel1d)\n\n        def _get_gaussian_kernel2d(size, sigma):\n            kernel1d_x = _get_gaussian_kernel1d(size[0], sigma[0])\n            kernel1d_y = _get_gaussian_kernel1d(size[1], sigma[1])\n            return torch.outer(kernel1d_y, kernel1d_x)\n\n        kernel = _get_gaussian_kernel2d(kernel_size, sigma)\n\n        kernel = kernel.view(1, 1, kernel_size[0], kernel_size[1])\n        return kernel\n\n    images = convert_to_tensor(images)\n    kernel_size = convert_to_tensor(kernel_size)\n    sigma = convert_to_tensor(sigma)\n    dtype = images.dtype\n\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    need_squeeze = False\n    if images.ndim == 3:\n        images = images.unsqueeze(dim=0)\n        need_squeeze = True\n\n    if data_format == \"channels_last\":\n        images = images.permute(0, 3, 1, 2)\n\n    num_channels = images.shape[1]\n    kernel = _create_gaussian_kernel(kernel_size, sigma, dtype)\n\n    kernel = kernel.expand(num_channels, 1, kernel_size[0], kernel_size[1])\n\n    blurred_images = torch.nn.functional.conv2d(\n        images,\n        kernel,\n        stride=1,\n        padding=int(kernel_size[0] // 2),\n        groups=num_channels,\n    )\n\n    if data_format == \"channels_last\":\n        blurred_images = blurred_images.permute(0, 2, 3, 1)\n\n    if need_squeeze:\n        blurred_images = blurred_images.squeeze(dim=0)\n\n    return blurred_images\n\n\n@dynamo.disable()\ndef _torch_seed_generator(seed):\n    first_seed, second_seed = draw_seed(seed)\n    device = get_device()\n    if device == \"meta\":\n        return None\n    generator = torch.Generator(device=get_device())\n    
generator.manual_seed(int(first_seed + second_seed))\n    return generator\n\n\ndef elastic_transform(\n    images,\n    alpha=20.0,\n    sigma=5.0,\n    interpolation=\"bilinear\",\n    fill_mode=\"reflect\",\n    fill_value=0.0,\n    seed=None,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    if interpolation not in AFFINE_TRANSFORM_INTERPOLATIONS.keys():\n        raise ValueError(\n            \"Invalid value for argument `interpolation`. Expected one of \"\n            f\"{set(AFFINE_TRANSFORM_INTERPOLATIONS.keys())}. Received: \"\n            f\"interpolation={interpolation}\"\n        )\n    if fill_mode not in AFFINE_TRANSFORM_FILL_MODES:\n        raise ValueError(\n            \"Invalid value for argument `fill_mode`. Expected one of \"\n            f\"{AFFINE_TRANSFORM_FILL_MODES}. Received: fill_mode={fill_mode}\"\n        )\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). 
Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    images = convert_to_tensor(images)\n    alpha = convert_to_tensor(alpha)\n    sigma = convert_to_tensor(sigma)\n    input_dtype = images.dtype\n    kernel_size = (int(6 * sigma) | 1, int(6 * sigma) | 1)\n\n    need_squeeze = False\n    if images.ndim == 3:\n        images = images.unsqueeze(dim=0)\n        need_squeeze = True\n\n    if data_format == \"channels_last\":\n        batch_size, height, width, channels = images.shape\n        channel_axis = -1\n    else:\n        batch_size, channels, height, width = images.shape\n        channel_axis = 1\n\n    # The meta device does not support `torch.Generator`.\n    generator = _torch_seed_generator(seed) if get_device() != \"meta\" else None\n    dx = (\n        torch.normal(\n            0.0,\n            1.0,\n            size=(batch_size, height, width),\n            generator=generator,\n            dtype=input_dtype,\n            device=images.device,\n        )\n        * sigma\n    )\n\n    dy = (\n        torch.normal(\n            0.0,\n            1.0,\n            size=(batch_size, height, width),\n            generator=generator,\n            dtype=input_dtype,\n            device=images.device,\n        )\n        * sigma\n    )\n\n    dx = gaussian_blur(\n        dx.unsqueeze(dim=channel_axis),\n        kernel_size=kernel_size,\n        sigma=(sigma, sigma),\n        data_format=data_format,\n    )\n    dy = gaussian_blur(\n        dy.unsqueeze(dim=channel_axis),\n        kernel_size=kernel_size,\n        sigma=(sigma, sigma),\n        data_format=data_format,\n    )\n\n    dx = dx.squeeze()\n    dy = dy.squeeze()\n\n    x, y = torch.meshgrid(\n        torch.arange(width), torch.arange(height), indexing=\"xy\"\n    )\n    x, y = x.unsqueeze(0).to(images.device), y.unsqueeze(0).to(images.device)\n\n    distorted_x = x + alpha * dx\n    distorted_y = y + alpha * dy\n\n    transformed_images = torch.zeros_like(images)\n\n    if data_format == \"channels_last\":\n        
for i in range(channels):\n            transformed_images[..., i] = torch.stack(\n                [\n                    map_coordinates(\n                        images[b, ..., i],\n                        [distorted_y[b], distorted_x[b]],\n                        order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                        fill_mode=fill_mode,\n                        fill_value=fill_value,\n                    )\n                    for b in range(batch_size)\n                ]\n            )\n    else:\n        for i in range(channels):\n            transformed_images[:, i, :, :] = torch.stack(\n                [\n                    map_coordinates(\n                        images[b, i, ...],\n                        [distorted_y[b], distorted_x[b]],\n                        order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                        fill_mode=fill_mode,\n                        fill_value=fill_value,\n                    )\n                    for b in range(batch_size)\n                ]\n            )\n\n    if need_squeeze:\n        transformed_images = transformed_images.squeeze(0)\n    transformed_images = transformed_images.to(input_dtype)\n\n    return transformed_images\n\n\ndef _fill_triangle_kernel(x):\n    return torch.maximum(torch.tensor(0, dtype=x.dtype), 1 - torch.abs(x))\n\n\ndef _fill_keys_cubic_kernel(x):\n    out = ((1.5 * x - 2.5) * x) * x + 1.0\n    out = torch.where(x >= 1.0, ((-0.5 * x + 2.5) * x - 4.0) * x + 2.0, out)\n    return torch.where(x >= 2.0, 0.0, out)\n\n\ndef _fill_lanczos_kernel(radius, x):\n    y = radius * torch.sin(np.pi * x) * torch.sin(np.pi * x / radius)\n    out = torch.where(\n        x > 1e-3, torch.divide(y, torch.where(x != 0, np.pi**2 * x**2, 1)), 1\n    )\n    return torch.where(x > radius, 0.0, out)\n\n\n_kernels = {\n    \"linear\": _fill_triangle_kernel,\n    \"cubic\": _fill_keys_cubic_kernel,\n    \"lanczos3\": lambda x: _fill_lanczos_kernel(3.0, x),\n    \"lanczos5\": 
lambda x: _fill_lanczos_kernel(5.0, x),\n}\n\n\ndef _compute_weight_mat(\n    input_size, output_size, scale, translation, kernel, antialias\n):\n    dtype = to_torch_dtype(backend.result_type(scale.dtype, translation.dtype))\n    inv_scale = 1.0 / scale\n    kernel_scale = (\n        torch.maximum(\n            inv_scale,\n            torch.tensor(1.0, dtype=inv_scale.dtype, device=inv_scale.device),\n        )\n        if antialias\n        else 1.0\n    )\n    sample_f = (\n        (torch.arange(output_size, dtype=dtype, device=inv_scale.device) + 0.5)\n        * inv_scale\n        - translation * inv_scale\n        - 0.5\n    )\n    x = (\n        torch.abs(\n            sample_f[torch.newaxis, :]\n            - torch.arange(input_size, dtype=dtype, device=sample_f.device)[\n                :, torch.newaxis\n            ]\n        )\n        / kernel_scale\n    )\n    weights = kernel(x)\n    total_weight_sum = torch.sum(weights, dim=0, keepdims=True)\n    weights = torch.where(\n        torch.abs(total_weight_sum) > 1000.0 * float(np.finfo(np.float32).eps),\n        torch.divide(\n            weights, torch.where(total_weight_sum != 0, total_weight_sum, 1)\n        ),\n        0,\n    )\n    input_size_minus_0_5 = input_size - 0.5\n    return torch.where(\n        torch.logical_and(sample_f >= -0.5, sample_f <= input_size_minus_0_5)[\n            torch.newaxis, :\n        ],\n        weights,\n        0,\n    )\n\n\ndef _scale_and_translate(\n    x, output_shape, spatial_dims, scale, translation, kernel, antialias\n):\n    x = convert_to_tensor(x)\n    input_shape = x.shape\n    if len(spatial_dims) == 0:\n        return x\n    if backend.is_int_dtype(x.dtype):\n        output = cast(x, \"float32\")\n        use_rounding = True\n    else:\n        output = torch.clone(x)\n        use_rounding = False\n    for i, d in enumerate(spatial_dims):\n        d = d % x.ndim\n        m, n = input_shape[d], output_shape[d]\n        w = cast(\n            
_compute_weight_mat(\n                m, n, scale[i], translation[i], kernel, antialias\n            ),\n            output.dtype,\n        )\n        output = torch.tensordot(output, w, dims=((d,), (0,)))\n        output = torch.moveaxis(output, -1, d)\n    if use_rounding:\n        output = torch.clip(torch.round(output), torch.min(x), torch.max(x))\n        output = cast(output, x.dtype)\n    return output\n\n\ndef scale_and_translate(\n    images,\n    output_shape,\n    scale,\n    translation,\n    spatial_dims,\n    method,\n    antialias=True,\n):\n    if method not in SCALE_AND_TRANSLATE_METHODS:\n        raise ValueError(\n            \"Invalid value for argument `method`. Expected one of \"\n            f\"{SCALE_AND_TRANSLATE_METHODS}. Received: method={method}\"\n        )\n    if method in (\"linear\", \"bilinear\", \"trilinear\", \"triangle\"):\n        method = \"linear\"\n    elif method in (\"cubic\", \"bicubic\", \"tricubic\"):\n        method = \"cubic\"\n\n    images = convert_to_tensor(images)\n    scale = convert_to_tensor(scale)\n    translation = convert_to_tensor(translation)\n    kernel = _kernels[method]\n    dtype = backend.result_type(scale.dtype, translation.dtype)\n    scale = cast(scale, dtype)\n    translation = cast(translation, dtype)\n    return _scale_and_translate(\n        images,\n        output_shape,\n        spatial_dims,\n        scale,\n        translation,\n        kernel,\n        antialias,\n    )\n"
  },
  {
    "path": "keras/src/backend/torch/layer.py",
    "content": "import torch\n\nfrom keras.src.backend.common.stateless_scope import in_stateless_scope\nfrom keras.src.ops.operation import Operation\n\n\nclass TorchLayer(torch.nn.Module):\n    @property\n    def torch_params(self):\n        if not hasattr(self, \"_torch_params\"):\n            self._track_variables()\n        return self._torch_params\n\n    def _post_build(self):\n        # Do not track variables when in a stateless scope.\n        # The variables are not initialized.\n        if in_stateless_scope():\n            return\n        self._track_variables()\n\n    def _track_variables(self):\n        # Setting the `_torch_params` attribute makes the module\n        # automatically track the parameters.\n        self._torch_params = torch.nn.ParameterDict(\n            {variable.path: variable.value for variable in self.variables}\n        )\n\n    def named_parameters(\n        self,\n        prefix=\"\",\n        recurse=True,\n        remove_duplicate=True,\n    ):\n        if not hasattr(self, \"_torch_params\"):\n            self._track_variables()\n        return torch.nn.Module.named_parameters(\n            self, prefix, recurse, remove_duplicate\n        )\n\n    def forward(self, *args, **kwargs):\n        return Operation.__call__(self, *args, **kwargs)\n\n    def _setattr_hook(self, name, value):\n        from keras.src.layers import Layer\n\n        if (\n            isinstance(value, torch.nn.Module)\n            and not isinstance(value, Layer)\n            and not name == \"_torch_params\"\n        ):\n            from keras.src.utils.torch_utils import TorchModuleWrapper\n\n            if not isinstance(self, TorchModuleWrapper):\n                value = TorchModuleWrapper(value)\n        return name, value\n\n    def _post_track_variable(self, variable):\n        if hasattr(self, \"_torch_params\"):\n            if variable.path not in self.torch_params:\n                self.torch_params[variable.path] = variable.value\n\n    def 
_post_untrack_variable(self, variable):\n        if hasattr(self, \"_torch_params\"):\n            if variable.path in self.torch_params:\n                self.torch_params.pop(variable.path)\n"
  },
  {
    "path": "keras/src/backend/torch/linalg.py",
    "content": "import torch\n\nfrom keras.src.backend import config\nfrom keras.src.backend import standardize_dtype\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.torch.core import cast\nfrom keras.src.backend.torch.core import convert_to_tensor\n\n\ndef cholesky(x, upper=False):\n    return torch.linalg.cholesky(x, upper=upper)\n\n\ndef cholesky_inverse(x, upper=False):\n    return torch.cholesky_inverse(x, upper=upper)\n\n\ndef det(x):\n    return torch.det(x)\n\n\ndef eig(x):\n    return torch.linalg.eig(x)\n\n\ndef eigh(x):\n    return torch.linalg.eigh(x)\n\n\ndef inv(x):\n    return torch.linalg.inv(x)\n\n\ndef lu_factor(x):\n    LU, pivots = torch.linalg.lu_factor(x)\n    # torch returns pivots with 1-based indexing\n    return LU, pivots - 1\n\n\ndef norm(x, ord=None, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, dtype)\n    return torch.linalg.norm(x, ord=ord, dim=axis, keepdim=keepdims)\n\n\ndef qr(x, mode=\"reduced\"):\n    if mode not in {\"reduced\", \"complete\"}:\n        raise ValueError(\n            \"`mode` argument value not supported. \"\n            \"Expected one of {'reduced', 'complete'}. 
\"\n            f\"Received: mode={mode}\"\n        )\n    return torch.linalg.qr(x, mode=mode)\n\n\ndef solve(a, b):\n    return torch.linalg.solve(a, b)\n\n\ndef solve_triangular(a, b, lower=False):\n    if b.ndim == a.ndim - 1:\n        b = torch.unsqueeze(b, axis=-1)\n        return torch.linalg.solve_triangular(a, b, upper=not lower).squeeze(\n            axis=-1\n        )\n    return torch.linalg.solve_triangular(a, b, upper=not lower)\n\n\ndef svd(x, full_matrices=True, compute_uv=True):\n    if not compute_uv:\n        return torch.linalg.svdvals(x)\n    return torch.linalg.svd(x, full_matrices=full_matrices)\n\n\ndef lstsq(a, b, rcond=None):\n    a = convert_to_tensor(a)\n    b = convert_to_tensor(b)\n    return torch.linalg.lstsq(a, b, rcond=rcond)[0]\n\n\ndef jvp(fun, primals, tangents, has_aux=False):\n    return torch.func.jvp(fun, primals, tangents, has_aux=has_aux)\n"
  },
  {
    "path": "keras/src/backend/torch/math.py",
    "content": "import math\n\nimport torch\n\nfrom keras.src.backend import standardize_dtype\nfrom keras.src.backend.torch.core import convert_to_tensor\nfrom keras.src.backend.torch.core import get_device\nfrom keras.src.backend.torch.numpy import pad\n\n\ndef _segment_reduction_fn(data, segment_ids, reduction_method, num_segments):\n    num_repeats = torch.prod(\n        torch.tensor(data.shape[1:], device=get_device())\n    ).long()\n    # To use `scatter_add` in torch, we need to replicate `segment_ids` into the\n    # shape of `data`.\n    segment_ids = (\n        segment_ids.repeat_interleave(num_repeats)\n        .view(*data.shape)\n        .type(torch.int64)\n    )\n    num_segments = num_segments or len(torch.unique(segment_ids))\n\n    # .scatter_add does not support -1 in the indices.\n    # Add all out-of-bound indices value to an extra dimension after\n    # num_segments, which is removed before returning the result.\n\n    # Replacing the out-of-bound indices.\n    segment_ids = torch.where(segment_ids >= 0, segment_ids, num_segments)\n    segment_ids = torch.where(\n        segment_ids < num_segments, segment_ids, num_segments\n    )\n\n    # Add one more dimension to the result shape with the \"+1\".\n    shape = (num_segments + 1,) + tuple(data.shape[1:])\n\n    if reduction_method == \"amax\":\n        result = torch.ones(*shape, device=get_device()) * -float(\"Inf\")\n    else:\n        result = torch.zeros(*shape, device=get_device())\n\n    result = result.scatter_reduce(\n        0, segment_ids, data.float(), reduction_method\n    )\n\n    # Removing the extra dimension.\n    result = result[:-1, ...]\n\n    return result.type(data.dtype)\n\n\ndef segment_sum(data, segment_ids, num_segments=None, sorted=False):\n    data = convert_to_tensor(data)\n    segment_ids = convert_to_tensor(segment_ids)\n    return _segment_reduction_fn(data, segment_ids, \"sum\", num_segments)\n\n\ndef segment_max(data, segment_ids, num_segments=None, 
sorted=False):\n    data = convert_to_tensor(data)\n    segment_ids = convert_to_tensor(segment_ids)\n    return _segment_reduction_fn(data, segment_ids, \"amax\", num_segments)\n\n\ndef top_k(x, k, sorted=True):\n    x = convert_to_tensor(x)\n    return torch.topk(x, k, sorted=sorted)\n\n\ndef in_top_k(targets, predictions, k):\n    targets = convert_to_tensor(targets).type(torch.int64)\n    targets = targets[:, None]\n    predictions = convert_to_tensor(predictions)\n    topk_values = top_k(predictions, k).values\n    targets_values = torch.take_along_dim(predictions, targets, dim=-1)\n    mask = targets_values >= topk_values\n    return torch.any(mask, axis=-1)\n\n\ndef logsumexp(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    axis = tuple(range(x.dim())) if axis is None else axis\n    return torch.logsumexp(x, dim=axis, keepdim=keepdims)\n\n\ndef qr(x, mode=\"reduced\"):\n    x = convert_to_tensor(x)\n    if mode not in {\"reduced\", \"complete\"}:\n        raise ValueError(\n            \"`mode` argument value not supported. \"\n            \"Expected one of {'reduced', 'complete'}. \"\n            f\"Received: mode={mode}\"\n        )\n    return torch.linalg.qr(x, mode=mode)\n\n\ndef extract_sequences(x, sequence_length, sequence_stride):\n    x = convert_to_tensor(x)\n    return torch.unfold_copy(\n        x, dimension=-1, size=sequence_length, step=sequence_stride\n    )\n\n\ndef _overlap_sequences(x, sequence_stride):\n    # Ref: https://github.com/google/jax/blob/main/jax/_src/scipy/signal.py\n    x = convert_to_tensor(x)\n    *batch_shape, num_sequences, sequence_length = x.shape\n    if sequence_stride > sequence_length:\n        raise ValueError(\n            \"`sequence_stride` must be less than or equal to x.shape[-1]. 
\"\n            f\"Received: sequence_stride={sequence_stride}, \"\n            f\"x.shape[-1]={sequence_length}\"\n        )\n    if sequence_stride < (sequence_length / num_sequences):\n        raise ValueError(\n            \"`sequence_stride` must be greater than or equal to \"\n            \"x.shape[-1] / x.shape[-2]. \"\n            f\"Received: sequence_stride={sequence_stride}, \"\n            f\"x.shape[-1]={sequence_length}, x.shape[-2]={num_sequences}\"\n        )\n    flat_batchsize = math.prod(batch_shape)\n    x = torch.reshape(x, (flat_batchsize, num_sequences, sequence_length))\n    output_size = sequence_stride * (num_sequences - 1) + sequence_length\n    nstep_per_segment = 1 + (sequence_length - 1) // sequence_stride\n    # Here, we use shorter notation for axes.\n    # B: batch_size, N: num_sequences, S: nstep_per_segment,\n    # T: sequence_length divided by S\n    padded_segment_len = nstep_per_segment * sequence_stride\n    x = torch.nn.functional.pad(\n        x, (0, padded_segment_len - sequence_length, 0, 0, 0, 0)\n    )\n    x = torch.reshape(\n        x, (flat_batchsize, num_sequences, nstep_per_segment, sequence_stride)\n    )\n    # For obtaining shifted signals, this routine reinterprets flattened array\n    # with a shrunken axis.  
With appropriate truncation/padding, this\n    # operation pushes the last padded elements of the previous row to the head\n    # of the current row.\n    # See implementation of `overlap_and_add` in TensorFlow for details.\n    x = torch.permute(x, (0, 2, 1, 3))  # x: (B, S, N, T)\n    x = torch.nn.functional.pad(x, (0, 0, 0, num_sequences, 0, 0, 0, 0))\n    # x: (B, S, N*2, T)\n    shrinked = x.shape[2] - 1\n    x = torch.reshape(x, (flat_batchsize, -1))\n    x = x[:, : (nstep_per_segment * shrinked * sequence_stride)]\n    x = torch.reshape(\n        x, (flat_batchsize, nstep_per_segment, shrinked * sequence_stride)\n    )\n    # Finally, sum shifted segments, and truncate results to the output_size.\n    x = torch.sum(x, dim=1)[:, :output_size]\n    return torch.reshape(x, tuple(batch_shape) + (-1,))\n\n\ndef _get_complex_tensor_from_tuple(x):\n    if not isinstance(x, (tuple, list)) or len(x) != 2:\n        raise ValueError(\n            \"Input `x` should be a tuple of two tensors - real and imaginary. \"\n            f\"Received: x={x}\"\n        )\n    # `convert_to_tensor` does not support passing complex tensors. We separate\n    # the input out into real and imaginary and convert them separately.\n    real, imag = x\n    real = convert_to_tensor(real)\n    imag = convert_to_tensor(imag)\n    # Check shape.\n    if real.shape != imag.shape:\n        raise ValueError(\n            \"Input `x` should be a tuple of two tensors - real and imaginary. \"\n            \"Both the real and imaginary parts should have the same shape. 
\"\n            f\"Received: x[0].shape = {real.shape}, x[1].shape = {imag.shape}\"\n        )\n    # Ensure dtype is float.\n    if not torch.is_floating_point(real) or not torch.is_floating_point(imag):\n        raise ValueError(\n            \"At least one tensor in input `x` is not of type float. \"\n            f\"Received: x={x}.\"\n        )\n\n    complex_input = torch.complex(real, imag)\n    return complex_input\n\n\ndef fft(x):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    complex_output = torch.fft.fft(complex_input)\n    return complex_output.real, complex_output.imag\n\n\ndef fft2(x):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    complex_output = torch.fft.fft2(complex_input)\n    return complex_output.real, complex_output.imag\n\n\ndef ifft2(x):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    complex_output = torch.fft.ifft2(complex_input)\n    return complex_output.real, complex_output.imag\n\n\ndef rfft(x, fft_length=None):\n    x = convert_to_tensor(x)\n    complex_output = torch.fft.rfft(x, n=fft_length, dim=-1, norm=\"backward\")\n    return complex_output.real, complex_output.imag\n\n\ndef irfft(x, fft_length=None):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    return torch.fft.irfft(complex_input, n=fft_length, dim=-1, norm=\"backward\")\n\n\ndef stft(\n    x, sequence_length, sequence_stride, fft_length, window=\"hann\", center=True\n):\n    if standardize_dtype(x.dtype) not in {\"float32\", \"float64\"}:\n        raise TypeError(\n            \"Invalid input type. Expected `float32` or `float64`. \"\n            f\"Received: input type={x.dtype}\"\n        )\n    if fft_length < sequence_length:\n        raise ValueError(\n            \"`fft_length` must be equal to or larger than `sequence_length`. 
\"\n            f\"Received: sequence_length={sequence_length}, \"\n            f\"fft_length={fft_length}\"\n        )\n    if isinstance(window, str):\n        if window not in {\"hann\", \"hamming\"}:\n            raise ValueError(\n                \"If a string is passed to `window`, it must be one of \"\n                f'`\"hann\"`, `\"hamming\"`. Received: window={window}'\n            )\n    x = convert_to_tensor(x)\n\n    if window is not None:\n        if isinstance(window, str):\n            if window == \"hann\":\n                win = torch.hann_window(\n                    sequence_length,\n                    periodic=True,\n                    dtype=x.dtype,\n                    device=get_device(),\n                )\n            else:\n                win = torch.hamming_window(\n                    sequence_length,\n                    periodic=True,\n                    dtype=x.dtype,\n                    device=get_device(),\n                )\n        else:\n            win = convert_to_tensor(window, dtype=x.dtype)\n        if len(win.shape) != 1 or win.shape[-1] != sequence_length:\n            raise ValueError(\n                \"The shape of `window` must be equal to [sequence_length]. \"\n                f\"Received: window shape={win.shape}\"\n            )\n    else:\n        win = torch.ones((sequence_length,), dtype=x.dtype, device=get_device())\n\n    need_unpack = False\n    *batch_shape, samples = x.shape\n    if len(x.shape) > 2:\n        need_unpack = True\n        flat_batchsize = math.prod(batch_shape)\n        x = torch.reshape(x, (flat_batchsize, samples))\n\n    x = torch.stft(\n        x,\n        n_fft=fft_length,\n        hop_length=sequence_stride,\n        win_length=sequence_length,\n        window=win,\n        center=center,\n        return_complex=True,\n    )\n    if need_unpack:\n        fft_unique_bins, num_sequences = x.shape[-2:]\n        x = torch.reshape(x, (*batch_shape, fft_unique_bins, num_sequences))\n\n   
 x = torch.swapaxes(x, -2, -1)\n    return x.real, x.imag\n\n\ndef istft(\n    x,\n    sequence_length,\n    sequence_stride,\n    fft_length,\n    length=None,\n    window=\"hann\",\n    center=True,\n):\n    complex_input = _get_complex_tensor_from_tuple(x)\n    dtype = complex_input.real.dtype\n    win = None\n    if window is not None:\n        if isinstance(window, str):\n            if window == \"hann\":\n                win = torch.hann_window(\n                    sequence_length,\n                    periodic=True,\n                    dtype=dtype,\n                    device=get_device(),\n                )\n            else:\n                win = torch.hamming_window(\n                    sequence_length,\n                    periodic=True,\n                    dtype=dtype,\n                    device=get_device(),\n                )\n        else:\n            win = convert_to_tensor(window, dtype=dtype)\n        if len(win.shape) != 1 or win.shape[-1] != sequence_length:\n            raise ValueError(\n                \"The shape of `window` must be equal to [sequence_length].\"\n                f\"Received: window shape={win.shape}\"\n            )\n\n    if sequence_length == fft_length and center is True and win is not None:\n        # can be fallen back to torch.istft\n        need_unpack = False\n        *batch_shape, num_sequences, fft_unique_bins = complex_input.shape\n        if len(complex_input.shape) > 3:\n            need_unpack = True\n            flat_batchsize = math.prod(batch_shape)\n            complex_input = torch.reshape(\n                complex_input, (flat_batchsize, num_sequences, fft_unique_bins)\n            )\n        complex_input = torch.swapaxes(complex_input, -2, -1)\n        x = torch.istft(\n            complex_input,\n            n_fft=fft_length,\n            hop_length=sequence_stride,\n            win_length=sequence_length,\n            window=win,\n            center=center,\n            length=length,\n        
    return_complex=False,\n        )\n        if need_unpack:\n            samples = x.shape[-1]\n            x = torch.reshape(x, (*batch_shape, samples))\n        return x\n\n    # custom implementation with irfft and _overlap_sequences\n    # references:\n    # torch: aten/src/ATen/native/SpectralOps.cpp\n    # tf: tf.signal.inverse_stft_window_fn\n    x = irfft(x, fft_length)\n\n    expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1)\n\n    if win is not None:\n        l_pad = (fft_length - sequence_length) // 2\n        r_pad = fft_length - sequence_length - l_pad\n        win = pad(win, [[l_pad, r_pad]], \"constant\")\n\n        # square and sum\n        _sequence_length = sequence_length + l_pad + r_pad\n        denom = torch.square(win)\n        overlaps = -(-_sequence_length // sequence_stride)\n        denom = pad(denom, [(0, overlaps * sequence_stride - _sequence_length)])\n        denom = torch.reshape(denom, [overlaps, sequence_stride])\n        denom = torch.sum(denom, 0, keepdims=True)\n        denom = torch.tile(denom, [overlaps, 1])\n        denom = torch.reshape(denom, [overlaps * sequence_stride])\n        win = torch.divide(win, denom[:_sequence_length])\n        x = torch.multiply(x, win)\n\n    x = _overlap_sequences(x, sequence_stride)\n\n    start = 0 if center is False else fft_length // 2\n    if length is not None:\n        end = start + length\n    elif center is True:\n        end = -(fft_length // 2)\n    else:\n        end = expected_output_len\n    return x[..., start:end]\n\n\ndef rsqrt(x):\n    x = convert_to_tensor(x)\n    return torch.rsqrt(x)\n\n\ndef erf(x):\n    x = convert_to_tensor(x)\n    return torch.erf(x)\n\n\ndef erfinv(x):\n    x = convert_to_tensor(x)\n    return torch.erfinv(x)\n\n\ndef logdet(x):\n    x = convert_to_tensor(x)\n    return torch.logdet(x)\n"
  },
  {
    "path": "keras/src/backend/torch/nn.py",
    "content": "import torch\nimport torch.nn.functional as tnn\n\nfrom keras.src import backend\nfrom keras.src.backend.common.backend_utils import (\n    compute_conv_transpose_padding_args_for_torch,\n)\nfrom keras.src.backend.torch.core import cast\nfrom keras.src.backend.torch.core import convert_to_tensor\nfrom keras.src.backend.torch.core import get_device\nfrom keras.src.backend.torch.numpy import expand_dims\nfrom keras.src.backend.torch.numpy import where\nfrom keras.src.utils.argument_validation import standardize_tuple\n\n\ndef relu(x):\n    x = convert_to_tensor(x)\n    return tnn.relu(x)\n\n\ndef relu6(x):\n    x = convert_to_tensor(x)\n    return tnn.relu6(x)\n\n\ndef sigmoid(x):\n    x = convert_to_tensor(x)\n    return tnn.sigmoid(x)\n\n\ndef sparse_sigmoid(x):\n    x = convert_to_tensor(x)\n    return torch.where(\n        x <= -1,\n        torch.tensor(0.0, device=x.device, dtype=x.dtype),\n        torch.where(\n            x >= 1,\n            torch.tensor(1.0, device=x.device, dtype=x.dtype),\n            0.5 * (x + 1),\n        ),\n    )\n\n\ndef tanh(x):\n    x = convert_to_tensor(x)\n    return tnn.tanh(x)\n\n\ndef tanh_shrink(x):\n    x = convert_to_tensor(x)\n    return tnn.tanhshrink(x)\n\n\ndef softplus(x):\n    x = convert_to_tensor(x)\n    return tnn.softplus(x)\n\n\ndef softsign(x):\n    x = convert_to_tensor(x)\n    return tnn.softsign(x)\n\n\ndef soft_shrink(x, threshold=0.5):\n    x = convert_to_tensor(x)\n    return tnn.softshrink(x, lambd=threshold)\n\n\ndef sparse_plus(x):\n    x = convert_to_tensor(x)\n    return torch.where(\n        x <= -1,\n        torch.zeros_like(x),\n        torch.where(x < 1, (1 / 4) * (x + 1) ** 2, x),\n    )\n\n\ndef silu(x):\n    x = convert_to_tensor(x)\n    return tnn.silu(x)\n\n\ndef squareplus(x, b=4):\n    x = convert_to_tensor(x)\n    b = convert_to_tensor(b)\n    y = x + torch.sqrt(x**2 + b)\n    return y / 2\n\n\ndef log_sigmoid(x):\n    x = convert_to_tensor(x)\n    return 
tnn.logsigmoid(x)\n\n\ndef leaky_relu(x, negative_slope=0.2):\n    x = convert_to_tensor(x)\n    return tnn.leaky_relu(x, negative_slope=negative_slope)\n\n\ndef hard_sigmoid(x):\n    x = convert_to_tensor(x)\n    return tnn.hardsigmoid(x)\n\n\ndef hard_silu(x):\n    x = convert_to_tensor(x)\n    return tnn.hardswish(x)\n\n\ndef elu(x, alpha=1.0):\n    x = convert_to_tensor(x)\n    return tnn.elu(x, alpha)\n\n\ndef selu(x):\n    x = convert_to_tensor(x)\n    return tnn.selu(x)\n\n\ndef gelu(x, approximate=True):\n    # TODO: torch.nn.gelu expects string approximate of `\"none\"` or `\"tanh\"`\n    x = convert_to_tensor(x)\n    if approximate:\n        return tnn.gelu(x, approximate=\"tanh\")\n    return tnn.gelu(x)\n\n\ndef celu(x, alpha=1.0):\n    x = convert_to_tensor(x)\n    return tnn.celu(x, alpha=alpha)\n\n\ndef glu(x, axis=-1):\n    x = convert_to_tensor(x)\n    return tnn.glu(x, dim=axis)\n\n\ndef hard_tanh(x):\n    x = convert_to_tensor(x)\n    return tnn.hardtanh(x, min_val=-1.0, max_val=1.0)\n\n\ndef hard_shrink(x, threshold=0.5):\n    x = convert_to_tensor(x)\n    return tnn.hardshrink(x, lambd=threshold)\n\n\ndef threshold(x, threshold, default_value):\n    x = convert_to_tensor(x)\n    return tnn.threshold(x, threshold=threshold, value=default_value)\n\n\ndef softmax(x, axis=-1):\n    x = convert_to_tensor(x)\n    dtype = backend.standardize_dtype(x.dtype)\n    # TODO: tnn.softmax doesn't support float16 using cpu\n    if (\n        get_device() == \"cpu\"\n        and backend.standardize_dtype(x.dtype) == \"float16\"\n    ):\n        x = cast(x, \"float32\")\n    if axis is None:\n        # Unlike numpy, PyTorch will handle axis=None as axis=-1.\n        # We need this workaround for the reduction on every dim.\n        output = torch.reshape(x, [-1])\n        output = tnn.softmax(output, dim=-1)\n        output = torch.reshape(output, x.shape)\n    else:\n        output = tnn.softmax(x, dim=axis)\n    return cast(output, dtype)\n\n\ndef 
log_softmax(x, axis=-1):\n    x = convert_to_tensor(x)\n    dtype = backend.standardize_dtype(x.dtype)\n    # TODO: tnn.log_softmax doesn't support float16 using cpu\n    if (\n        get_device() == \"cpu\"\n        and backend.standardize_dtype(x.dtype) == \"float16\"\n    ):\n        x = cast(x, \"float32\")\n    if axis is None:\n        # Unlike numpy, PyTorch will handle axis=None as axis=-1.\n        # We need this workaround for the reduction on every dim.\n        output = torch.reshape(x, [-1])\n        output = tnn.log_softmax(output, dim=-1)\n        output = torch.reshape(output, x.shape)\n    else:\n        output = tnn.log_softmax(x, dim=axis)\n    return cast(output, dtype)\n\n\ndef sparsemax(x, axis=-1):\n    # Sort logits along the specified axis in descending order\n    logits = convert_to_tensor(x)\n    logits_sorted, _ = torch.sort(logits, dim=axis, descending=True)\n    logits_cumsum = torch.cumsum(logits_sorted, dim=axis)\n    r = torch.arange(\n        1, logits.size(axis) + 1, device=logits.device, dtype=logits.dtype\n    )\n    r_shape = [1] * logits.ndim\n    r_shape[axis] = -1  # Broadcast to match the target axis\n    r = r.view(r_shape)\n    support = logits_sorted - (logits_cumsum - 1) / r > 0\n    # Find the threshold\n    k = torch.sum(support, dim=axis, keepdim=True)\n    logits_cumsum_safe = torch.where(\n        support, logits_cumsum, torch.tensor(0.0, device=logits.device)\n    )\n    tau = (torch.sum(logits_cumsum_safe, dim=axis, keepdim=True) - 1) / k\n    output = torch.clamp(logits - tau, min=0.0)\n    return output\n\n\ndef _compute_padding_length(\n    input_length, kernel_length, stride, dilation_rate=1\n):\n    \"\"\"Compute padding length along one dimension with support\n    for asymmetric padding.\"\"\"\n    effective_k_size = (kernel_length - 1) * dilation_rate + 1\n    if stride == 1:\n        # total padding is kernel_size - 1\n        total_padding = effective_k_size - 1\n    else:\n        # calc. 
needed padding for case with stride involved\n        output_size = (input_length + stride - 1) // stride\n        total_padding = max(\n            0, (output_size - 1) * stride + effective_k_size - input_length\n        )\n\n    # divide padding evenly, with extra pixel going at the end if needed\n    left_padding = total_padding // 2\n    right_padding = total_padding - left_padding\n    return (left_padding, right_padding)\n\n\ndef _apply_same_padding(\n    inputs, kernel_size, strides, data_format, operation_type, dilation_rate=1\n):\n    \"\"\"Apply same padding to the input tensor.\n\n    This function will evaluate if the padding value is compatible with torch\n    functions. To avoid calling `pad()` as much as possible, which may cause\n    performance or memory issues, when compatible, it does not apply the padding\n    to the tensor, but returns the input tensor and the padding value to pass to\n    the torch functions. If not compatible, it returns the padded tensor and 0\n    as the padding value.\n\n    Returns:\n        tensor: A padded tensor or the inputs.\n        padding: The padding value, ready to pass to the torch functions.\n    \"\"\"\n    spatial_shape = inputs.shape[2:]\n    num_spatial_dims = len(spatial_shape)\n    padding = []\n\n    if operation_type != \"pooling\":\n        dilation_rate = standardize_tuple(\n            dilation_rate, num_spatial_dims, \"dilation_rate\"\n        )\n\n    for i in range(num_spatial_dims):\n        dil = 1 if operation_type == \"pooling\" else dilation_rate[i]\n        pad = _compute_padding_length(\n            spatial_shape[i], kernel_size[i], strides[i], dil\n        )\n        padding.append(pad)\n\n    # convert padding to torch format\n    if all(left == right for left, right in padding):\n        return inputs, [left for left, _ in padding]\n\n    # else, need to pad manually\n    flattened_padding = []\n    for pad in reversed(padding):\n        flattened_padding.extend(pad)\n\n    mode = 
\"replicate\" if operation_type == \"pooling\" else \"constant\"\n    return tnn.pad(inputs, pad=tuple(flattened_padding), mode=mode), 0\n\n\ndef _transpose_spatial_inputs(inputs):\n    \"\"\"Transpose inputs from channels_last to channels_first format.\"\"\"\n    # Torch pooling does not support `channels_last` format, so\n    # we need to transpose to `channels_first` format.\n    ndim = inputs.ndim - 2\n    if ndim == 1:  # 1D case\n        return torch.permute(inputs, (0, 2, 1))\n    elif ndim == 2:  # 2D case\n        return torch.permute(inputs, (0, 3, 1, 2))\n    elif ndim == 3:  # 3D case\n        return torch.permute(inputs, (0, 4, 1, 2, 3))\n    raise ValueError(\n        \"Inputs must have ndim=3, 4 or 5, \"\n        \"corresponding to 1D, 2D and 3D inputs. \"\n        f\"Received input shape: {inputs.shape}.\"\n    )\n\n\ndef _transpose_spatial_outputs(outputs):\n    # Undo the transpose in `_transpose_spatial_inputs`.\n    num_spatial_dims = len(outputs.shape) - 2\n    if num_spatial_dims == 1:\n        outputs = torch.permute(outputs, (0, 2, 1))\n    elif num_spatial_dims == 2:\n        outputs = torch.permute(outputs, (0, 2, 3, 1))\n    elif num_spatial_dims == 3:\n        outputs = torch.permute(outputs, (0, 2, 3, 4, 1))\n    return outputs\n\n\ndef _transpose_conv_kernel(kernel):\n    # Torch requires conv kernel of format\n    # `(out_channels, in_channels, spatial_dims)`, we need to transpose.\n    num_spatial_dims = len(kernel.shape) - 2\n    if num_spatial_dims == 1:\n        kernel = torch.permute(kernel, (2, 1, 0))\n    elif num_spatial_dims == 2:\n        kernel = torch.permute(kernel, (3, 2, 0, 1))\n    elif num_spatial_dims == 3:\n        kernel = torch.permute(kernel, (4, 3, 0, 1, 2))\n    return kernel\n\n\ndef _get_channels_last_memory_format(ndim):\n    if ndim == 4:\n        return torch.channels_last\n    elif ndim == 5:\n        return torch.channels_last_3d\n    return None\n\n\ndef _maybe_convert_to_channels_last(tensor):\n    
mem_fmt = _get_channels_last_memory_format(tensor.ndim)\n    if mem_fmt is not None and not tensor.is_contiguous(memory_format=mem_fmt):\n        return tensor.contiguous(memory_format=mem_fmt)\n    return tensor\n\n\ndef max_pool(\n    inputs,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n):\n    \"\"\"Fixed max pooling implementation.\"\"\"\n    inputs = convert_to_tensor(inputs)\n    num_spatial_dims = inputs.ndim - 2\n    pool_size = standardize_tuple(pool_size, num_spatial_dims, \"pool_size\")\n    if strides is None:\n        strides = pool_size\n    else:\n        strides = standardize_tuple(strides, num_spatial_dims, \"strides\")\n\n    data_format = backend.standardize_data_format(data_format)\n    if data_format == \"channels_last\":\n        inputs = _transpose_spatial_inputs(inputs)\n\n    if padding == \"same\":\n        # Torch does not natively support `\"same\"` padding, we need to manually\n        # apply the right amount of padding to `inputs`.\n        inputs, padding = _apply_same_padding(\n            inputs, pool_size, strides, data_format, \"pooling\"\n        )\n    else:\n        padding = 0\n\n    device = get_device()\n    # Torch max pooling ops do not support symbolic tensors.\n    # Create a real tensor to execute the ops.\n    if device == \"meta\":\n        inputs = torch.empty(\n            size=inputs.shape, dtype=inputs.dtype, device=\"cpu\"\n        )\n\n    if num_spatial_dims == 1:\n        outputs = tnn.max_pool1d(\n            inputs, kernel_size=pool_size, stride=strides, padding=padding\n        )\n    elif num_spatial_dims == 2:\n        outputs = tnn.max_pool2d(\n            inputs, kernel_size=pool_size, stride=strides, padding=padding\n        )\n    elif num_spatial_dims == 3:\n        outputs = tnn.max_pool3d(\n            inputs, kernel_size=pool_size, stride=strides, padding=padding\n        )\n    else:\n        raise ValueError(\n            \"Inputs to pooling op must have 
ndim=3, 4 or 5, \"\n            \"corresponding to 1D, 2D and 3D inputs. \"\n            f\"Received input shape: {inputs.shape}.\"\n        )\n\n    outputs = outputs.to(device)\n    if data_format == \"channels_last\":\n        outputs = _transpose_spatial_outputs(outputs)\n    return outputs\n\n\ndef average_pool(\n    inputs,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n):\n    \"\"\"Fixed average pooling with correct padding calculation.\"\"\"\n    inputs = convert_to_tensor(inputs)\n    num_spatial_dims = inputs.ndim - 2\n    pool_size = standardize_tuple(pool_size, num_spatial_dims, \"pool_size\")\n    strides = (\n        pool_size\n        if strides is None\n        else standardize_tuple(strides, num_spatial_dims, \"strides\")\n    )\n\n    data_format = backend.standardize_data_format(data_format)\n    orig_format = data_format\n\n    if data_format == \"channels_last\":\n        inputs = _transpose_spatial_inputs(inputs)\n\n    if padding == \"same\":\n        # Torch does not natively support `\"same\"` padding, we need to manually\n        # apply the right amount of padding to `inputs`.\n        inputs, padding = _apply_same_padding(\n            inputs,\n            pool_size,\n            strides,\n            \"channels_first\",  # we're in channels_first here\n            \"pooling\",\n        )\n    else:\n        padding = 0\n\n    # apply pooling\n    if num_spatial_dims == 1:\n        outputs = tnn.avg_pool1d(\n            inputs,\n            kernel_size=pool_size,\n            stride=strides,\n            padding=padding,\n            count_include_pad=False,\n        )\n    elif num_spatial_dims == 2:\n        outputs = tnn.avg_pool2d(\n            inputs,\n            kernel_size=pool_size,\n            stride=strides,\n            padding=padding,\n            count_include_pad=False,\n        )\n    elif num_spatial_dims == 3:\n        outputs = tnn.avg_pool3d(\n            inputs,\n            
kernel_size=pool_size,\n            stride=strides,\n            padding=padding,\n            count_include_pad=False,\n        )\n    else:\n        raise ValueError(\n            \"Inputs to pooling op must have ndim=3, 4 or 5, \"\n            \"corresponding to 1D, 2D and 3D inputs. \"\n            f\"Received input shape: {inputs.shape}.\"\n        )\n\n    if orig_format == \"channels_last\":\n        outputs = _transpose_spatial_outputs(outputs)\n\n    return outputs\n\n\ndef adaptive_average_pool(inputs, output_size, data_format=None):\n    \"\"\"Adaptive average pooling (1D/2D/3D) with channels_last support.\"\"\"\n    inputs = convert_to_tensor(inputs)\n    num_spatial_dims = inputs.ndim - 2\n\n    data_format = backend.standardize_data_format(data_format)\n    orig_format = data_format\n    if data_format == \"channels_last\":\n        inputs = _transpose_spatial_inputs(inputs)\n\n    if isinstance(output_size, int):\n        torch_output_size = (\n            output_size\n            if num_spatial_dims == 1\n            else (output_size,) * num_spatial_dims\n        )\n    else:\n        torch_output_size = standardize_tuple(\n            output_size, num_spatial_dims, \"output_size\"\n        )\n\n    if get_device() == \"meta\":\n        inputs = torch.empty(\n            size=inputs.shape, dtype=inputs.dtype, device=\"cpu\"\n        )\n\n    if num_spatial_dims == 1:\n        outputs = tnn.adaptive_avg_pool1d(inputs, output_size=torch_output_size)\n    elif num_spatial_dims == 2:\n        outputs = tnn.adaptive_avg_pool2d(inputs, output_size=torch_output_size)\n    elif num_spatial_dims == 3:\n        outputs = tnn.adaptive_avg_pool3d(inputs, output_size=torch_output_size)\n    else:\n        raise ValueError(\n            \"Inputs to adaptive average pooling must have ndim=3, 4 or 5. \"\n            f\"Received input shape: {inputs.shape}.\"\n        )\n\n    if orig_format == \"channels_last\":\n        outputs = 
_transpose_spatial_outputs(outputs)\n    return outputs\n\n\ndef adaptive_max_pool(inputs, output_size, data_format=None):\n    \"\"\"Adaptive max pooling (1D/2D/3D) with channels_last support.\"\"\"\n    inputs = convert_to_tensor(inputs)\n    num_spatial_dims = inputs.ndim - 2\n\n    data_format = backend.standardize_data_format(data_format)\n    orig_format = data_format\n    if data_format == \"channels_last\":\n        inputs = _transpose_spatial_inputs(inputs)\n\n    if isinstance(output_size, int):\n        torch_output_size = (\n            output_size\n            if num_spatial_dims == 1\n            else (output_size,) * num_spatial_dims\n        )\n    else:\n        torch_output_size = standardize_tuple(\n            output_size, num_spatial_dims, \"output_size\"\n        )\n\n    if get_device() == \"meta\":\n        inputs = torch.empty(\n            size=inputs.shape, dtype=inputs.dtype, device=\"cpu\"\n        )\n\n    if num_spatial_dims == 1:\n        res = tnn.adaptive_max_pool1d(inputs, output_size=torch_output_size)\n    elif num_spatial_dims == 2:\n        res = tnn.adaptive_max_pool2d(inputs, output_size=torch_output_size)\n    elif num_spatial_dims == 3:\n        res = tnn.adaptive_max_pool3d(inputs, output_size=torch_output_size)\n    else:\n        raise ValueError(\n            \"Inputs to adaptive max pooling must have ndim=3, 4 or 5. \"\n            f\"Received input shape: {inputs.shape}.\"\n        )\n\n    outputs = res[0] if isinstance(res, tuple) else res\n\n    if orig_format == \"channels_last\":\n        outputs = _transpose_spatial_outputs(outputs)\n    return outputs\n\n\ndef conv(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    \"\"\"Convolution with fixed group handling.\"\"\"\n    inputs = convert_to_tensor(inputs)\n    kernel = convert_to_tensor(kernel)\n    num_spatial_dims = inputs.ndim - 2\n    strides = standardize_tuple(strides, num_spatial_dims, 
\"strides\")\n\n    data_format = backend.standardize_data_format(data_format)\n    if data_format == \"channels_last\":\n        inputs = _transpose_spatial_inputs(inputs)\n\n    kernel = _transpose_conv_kernel(kernel)\n\n    if data_format == \"channels_last\":\n        inputs = _maybe_convert_to_channels_last(inputs)\n        kernel = _maybe_convert_to_channels_last(kernel)\n\n    # compute the number of groups\n    in_channels = inputs.shape[1]\n    kernel_in_channels = kernel.shape[1]\n    if in_channels % kernel_in_channels != 0:\n        raise ValueError(\n            f\"Input channels ({in_channels}) must be divisible by \"\n            f\"kernel input channels ({kernel_in_channels})\"\n        )\n    groups = in_channels // kernel_in_channels\n\n    # handle padding\n    if padding == \"same\":\n        inputs, padding = _apply_same_padding(\n            inputs,\n            kernel.shape[2:],\n            strides,\n            data_format,\n            \"conv\",\n            dilation_rate,\n        )\n    else:\n        padding = 0\n\n    # apply convolution\n    if num_spatial_dims == 1:\n        outputs = tnn.conv1d(\n            inputs,\n            kernel,\n            stride=strides,\n            padding=padding,\n            dilation=dilation_rate,\n            groups=groups,\n        )\n    elif num_spatial_dims == 2:\n        outputs = tnn.conv2d(\n            inputs,\n            kernel,\n            stride=strides,\n            padding=padding,\n            dilation=dilation_rate,\n            groups=groups,\n        )\n    elif num_spatial_dims == 3:\n        outputs = tnn.conv3d(\n            inputs,\n            kernel,\n            stride=strides,\n            padding=padding,\n            dilation=dilation_rate,\n            groups=groups,\n        )\n    else:\n        raise ValueError(\n            \"Inputs to conv operation should have ndim=3, 4, or 5, \"\n            \"corresponding to 1D, 2D and 3D inputs. 
Received input \"\n            f\"shape: {inputs.shape}.\"\n        )\n\n    if data_format == \"channels_last\":\n        outputs = _transpose_spatial_outputs(outputs)\n    return outputs\n\n\ndef depthwise_conv(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    kernel = convert_to_tensor(kernel)\n    kernel = torch.reshape(\n        kernel, kernel.shape[:-2] + (1, kernel.shape[-2] * kernel.shape[-1])\n    )\n    return conv(inputs, kernel, strides, padding, data_format, dilation_rate)\n\n\ndef separable_conv(\n    inputs,\n    depthwise_kernel,\n    pointwise_kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    depthwise_conv_output = depthwise_conv(\n        inputs,\n        depthwise_kernel,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n    )\n    return conv(\n        depthwise_conv_output,\n        pointwise_kernel,\n        strides=1,\n        padding=\"valid\",\n        data_format=data_format,\n        dilation_rate=dilation_rate,\n    )\n\n\ndef conv_transpose(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    output_padding=None,\n    data_format=None,\n    dilation_rate=1,\n):\n    inputs = convert_to_tensor(inputs)\n    kernel = convert_to_tensor(kernel)\n    num_spatial_dims = inputs.ndim - 2\n    strides = standardize_tuple(strides, num_spatial_dims, \"strides\")\n\n    data_format = backend.standardize_data_format(data_format)\n    (\n        torch_padding,\n        torch_output_padding,\n    ) = compute_conv_transpose_padding_args_for_torch(\n        input_shape=inputs.shape,\n        kernel_shape=kernel.shape,\n        strides=strides,\n        padding=padding,\n        output_padding=output_padding,\n        dilation_rate=dilation_rate,\n    )\n    if data_format == \"channels_last\":\n        inputs = _transpose_spatial_inputs(inputs)\n    # Transpose kernel from keras 
format to torch format.\n    kernel = _transpose_conv_kernel(kernel)\n\n    if data_format == \"channels_last\":\n        inputs = _maybe_convert_to_channels_last(inputs)\n        kernel = _maybe_convert_to_channels_last(kernel)\n\n    kernel_spatial_shape = kernel.shape[2:]\n    if isinstance(dilation_rate, int):\n        dilation_rate = [dilation_rate] * len(kernel_spatial_shape)\n\n    if num_spatial_dims == 1:\n        outputs = tnn.conv_transpose1d(\n            inputs,\n            kernel,\n            stride=strides,\n            padding=torch_padding,\n            output_padding=torch_output_padding,\n            dilation=dilation_rate,\n        )\n    elif num_spatial_dims == 2:\n        outputs = tnn.conv_transpose2d(\n            inputs,\n            kernel,\n            stride=strides,\n            padding=torch_padding,\n            output_padding=torch_output_padding,\n            dilation=dilation_rate,\n        )\n    elif num_spatial_dims == 3:\n        outputs = tnn.conv_transpose3d(\n            inputs,\n            kernel,\n            stride=strides,\n            padding=torch_padding,\n            output_padding=torch_output_padding,\n            dilation=dilation_rate,\n        )\n    else:\n        raise ValueError(\n            \"Inputs to conv transpose operation should have ndim=3, 4, or 5, \"\n            \"corresponding to 1D, 2D and 3D inputs. Received input \"\n            f\"shape: {inputs.shape}.\"\n        )\n    if data_format == \"channels_last\":\n        outputs = _transpose_spatial_outputs(outputs)\n    return outputs\n\n\ndef one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):\n    if sparse:\n        raise ValueError(\"Unsupported value `sparse=True` with torch backend\")\n    # Axis is the output axis. 
By default, PyTorch outputs to the last axis.\n    # If axis is not last, change output to axis and shift remaining elements.\n    x = convert_to_tensor(x, dtype=torch.long)\n    zero = convert_to_tensor(0, dtype=torch.long)\n\n    # Torch one_hot does not natively handle negative values, so we add some\n    # manual handling for negatives in the input to one_hot by using max(x, 0).\n    # The output will have some invalid results, so we set them back to 0 using\n    # `where` afterwards.\n    output = tnn.one_hot(torch.clamp(x, min=0), num_classes)\n    output = where(expand_dims(x, axis=-1) >= 0, output, zero)\n    output = convert_to_tensor(output, dtype=dtype)\n    dims = output.dim()\n    if axis != -1 and axis != dims:\n        new_axes_order = list(range(dims))\n        new_axes_order[axis] = -1  # Shifts output to axis position\n        # Shift remaining axes with offset by 1 since output moved to `axis`.\n        for ax in range(axis + 1, dims):\n            new_axes_order[ax] -= 1\n        output = output.permute(new_axes_order)\n    return output\n\n\ndef multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):\n    if sparse:\n        raise ValueError(\"Unsupported value `sparse=True` with torch backend\")\n    x = convert_to_tensor(x)\n    reduction_axis = 1 if len(x.shape) > 1 else 0\n    outputs = torch.amax(\n        one_hot(cast(x, \"int32\"), num_classes, axis=axis, dtype=dtype),\n        dim=reduction_axis,\n    )\n    return outputs\n\n\ndef categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    target = convert_to_tensor(target)\n    output = convert_to_tensor(output)\n\n    if target.shape != output.shape:\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same shape. 
\"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n    if len(target.shape) < 1:\n        raise ValueError(\n            \"Arguments `target` and `output` must be at least rank 1. \"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n\n    if from_logits:\n        log_prob = tnn.log_softmax(output, dim=axis)\n    else:\n        output = output / torch.sum(output, dim=axis, keepdim=True)\n        output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon())\n        log_prob = torch.log(output)\n    return -torch.sum(target * log_prob, dim=axis)\n\n\ndef sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    target = convert_to_tensor(target, dtype=torch.long)\n    output = convert_to_tensor(output)\n\n    if len(target.shape) == len(output.shape) and target.shape[-1] == 1:\n        target = torch.squeeze(target, dim=-1)\n\n    if len(output.shape) < 1:\n        raise ValueError(\n            \"Argument `output` must be at least rank 1. \"\n            \"Received: \"\n            f\"output.shape={output.shape}\"\n        )\n    output_shape_without_class_dim = list(output.shape)\n    del output_shape_without_class_dim[axis]\n\n    if list(target.shape) != output_shape_without_class_dim:\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same shape \"\n            \"up until the last dimension: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n    # Use PyTorch native cross-entropy ops to avoid allocating a full\n    # one-hot matrix of shape (batch, ..., num_classes).  
For large\n    # vocabularies this saves gigabytes of GPU memory per step.\n    # F.cross_entropy / F.nll_loss expect the class dim at position 1,\n    if output.dim() == 1:\n        output = output.unsqueeze(0)\n        target = target.unsqueeze(0)\n        squeeze = True\n    else:\n        squeeze = False\n        class_axis = axis % output.dim()\n        if class_axis != 1:\n            output = output.movedim(class_axis, 1)\n\n    if from_logits:\n        result = tnn.cross_entropy(output, target, reduction=\"none\")\n    else:\n        output = output / torch.sum(output, dim=1, keepdim=True)\n        output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon())\n        log_prob = torch.log(output)\n        result = tnn.nll_loss(log_prob, target, reduction=\"none\")\n\n    if squeeze:\n        result = result.squeeze(0)\n    return result\n\n\ndef binary_crossentropy(target, output, from_logits=False):\n    target = convert_to_tensor(target)\n    output = convert_to_tensor(output)\n\n    # We only apply the squeeze fix if we are on an MPS device,\n    # as this change breaks tests on other platforms that\n    # expect the original tensor shape to be preserved.\n    if (\n        torch.backends.mps.is_available()\n        and target.ndim > 1\n        and output.ndim == target.ndim\n        and target.shape[-1] == 1\n        and output.shape[-1] == 1\n    ):\n        target = torch.squeeze(target, -1).contiguous()\n        output = torch.squeeze(output, -1).contiguous()\n\n    if target.shape != output.shape:\n        raise ValueError(\n            \"Arguments `target` and `output` must have the same shape. 
\"\n            \"Received: \"\n            f\"target.shape={target.shape}, output.shape={output.shape}\"\n        )\n\n    # By default, PyTorch averages the loss over all elements\n    # (reduction=\"mean\"); use reduction=\"none\" to keep the shape.\n    if from_logits:\n        return tnn.binary_cross_entropy_with_logits(\n            output, target, reduction=\"none\"\n        )\n    else:\n        output = torch.clip(output, backend.epsilon(), 1.0 - backend.epsilon())\n        return tnn.binary_cross_entropy(output, target, reduction=\"none\")\n\n\ndef moments(x, axes, keepdims=False, synchronized=False):\n    if synchronized:\n        raise NotImplementedError(\n            \"Argument synchronized=True is not supported with PyTorch.\"\n        )\n    x = convert_to_tensor(x)\n    # The dynamic range of float16 is too limited for statistics. As a\n    # workaround, we simply perform the operations on float32 and convert back\n    # to float16\n    need_cast = False\n    ori_dtype = backend.standardize_dtype(x.dtype)\n    if ori_dtype == \"float16\":\n        need_cast = True\n        x = cast(x, \"float32\")\n\n    mean = torch.mean(x, dim=axes, keepdim=True)\n\n    # The variance is computed using $Var = E[|x|^2] - |E[x]|^2$. It is faster\n    # but less numerically stable.\n    # Note: stop_gradient does not change the gradient to the mean, because that\n    # gradient is zero.\n    variance = torch.mean(\n        torch.square(x), dim=axes, keepdim=True\n    ) - torch.square(mean)\n\n    if not keepdims:\n        mean = torch.squeeze(mean, axes)\n        variance = torch.squeeze(variance, axes)\n    if need_cast:\n        # avoid overflow and underflow when casting from float32 to float16\n        mean = torch.clip(\n            mean,\n            torch.finfo(torch.float16).min,\n            torch.finfo(torch.float16).max,\n        )\n        variance = torch.clip(\n            variance,\n            torch.finfo(torch.float16).min,\n            torch.finfo(torch.float16).max,\n 
       )\n        mean = cast(mean, ori_dtype)\n        variance = cast(variance, ori_dtype)\n    return mean, variance\n\n\ndef batch_normalization(\n    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3\n):\n    x = convert_to_tensor(x)\n    mean = convert_to_tensor(mean)\n    variance = convert_to_tensor(variance)\n\n    shape = [1] * len(x.shape)\n    shape[axis] = mean.shape[0]\n    mean = torch.reshape(mean, shape)\n    variance = torch.reshape(variance, shape)\n\n    if offset is not None:\n        offset = convert_to_tensor(offset)\n        offset = torch.reshape(offset, shape)\n    else:\n        offset = torch.zeros_like(mean)\n    if scale is not None:\n        scale = convert_to_tensor(scale)\n        scale = torch.reshape(scale, shape)\n    else:\n        scale = torch.ones_like(variance)\n\n    return (\n        x.subtract(mean)\n        .mul_(variance.add(epsilon).rsqrt_().mul(scale))\n        .add_(offset)\n    )\n\n\ndef ctc_loss(target, output, target_length, output_length, mask_index=0):\n    target = convert_to_tensor(target)\n    output = convert_to_tensor(output)\n    target_length = convert_to_tensor(target_length)\n    output_length = convert_to_tensor(output_length)\n\n    # Ensure that the dtype promotion behavior matches that of `tf.nn.ctc_loss`\n    dtype = backend.result_type(output.dtype, \"float32\")\n    output = cast(output, dtype)\n\n    output = torch.transpose(output, 1, 0)\n    logits = tnn.log_softmax(output, dim=-1)\n    loss = tnn.ctc_loss(\n        logits,\n        target,\n        output_length,\n        target_length,\n        blank=mask_index,\n        reduction=\"none\",\n    )\n    return loss\n\n\ndef _ctc_greedy_decode(\n    inputs,\n    sequence_lengths,\n    merge_repeated=True,\n    mask_index=None,\n):\n    inputs = convert_to_tensor(inputs)\n    sequence_lengths = convert_to_tensor(sequence_lengths, dtype=\"int32\")\n    batch_size, max_length, num_classes = inputs.shape\n\n    if mask_index is 
None:\n        mask_index = num_classes - 1\n\n    indices = torch.argmax(inputs, axis=-1)\n    indices = cast(indices, \"int32\")\n    scores = torch.max(inputs, axis=-1)[0]\n\n    seqlen_mask = torch.arange(max_length, device=indices.device)[None, :]\n    seqlen_mask = seqlen_mask >= sequence_lengths[:, None]\n\n    indices = torch.where(seqlen_mask, mask_index, indices)\n    scores = torch.where(seqlen_mask, 0.0, scores)\n\n    if merge_repeated:\n        repeat = indices[:, 1:] == indices[:, :-1]\n        repeat = tnn.pad(repeat, (1, 0, 0, 0))\n        indices = torch.where(repeat, mask_index, indices)\n\n    # We set to -1 for blank labels\n    invalid_mask = indices == mask_index\n    indices = torch.where(invalid_mask, -1, indices)\n\n    # We rearrange the indices by moving `mask_index` to the end of the array\n    order = torch.unsqueeze(\n        torch.arange(max_length, device=indices.device), dim=0\n    )  # [1, N]\n    order = torch.tile(order, (batch_size, 1))  # [B, N]\n    order = torch.where(invalid_mask, max_length, order)\n    order = torch.argsort(order, dim=-1)\n    indices = torch.take_along_dim(indices, order, dim=-1)\n\n    scores = -torch.sum(scores, axis=1)[:, None]\n    indices = torch.unsqueeze(indices, dim=0)\n    return indices, scores\n\n\ndef ctc_decode(\n    inputs,\n    sequence_lengths,\n    strategy=\"greedy\",\n    beam_width=100,\n    top_paths=1,\n    merge_repeated=True,\n    mask_index=0,\n):\n    inputs = convert_to_tensor(inputs)\n    dtype = backend.result_type(inputs.dtype, \"float32\")\n    inputs = cast(inputs, dtype)\n\n    if strategy == \"greedy\":\n        return _ctc_greedy_decode(\n            inputs,\n            sequence_lengths,\n            merge_repeated=merge_repeated,\n            mask_index=mask_index,\n        )\n    elif strategy == \"beam_search\":\n        raise NotImplementedError(\n            \"Torch backend doesn't yet support the beam search strategy for CTC\"\n            \" decoding.\"\n        
)\n    else:\n        raise ValueError(\n            f\"Invalid strategy {strategy}. Supported values are \"\n            \"'greedy' and 'beam_search'.\"\n        )\n\n\ndef psnr(x1, x2, max_val):\n    if x1.shape != x2.shape:\n        raise ValueError(\n            f\"Input shapes {x1.shape} and {x2.shape} must \"\n            \"match for PSNR calculation. \"\n        )\n\n    x1, x2 = (\n        convert_to_tensor(x1),\n        convert_to_tensor(x2),\n    )\n    max_val = convert_to_tensor(max_val, dtype=x1.dtype)\n    mse = torch.mean((x1 - x2) ** 2)\n    psnr = 20 * torch.log10(max_val) - 10 * torch.log10(mse)\n    return psnr\n\n\ndef _get_large_negative(dtype):\n    dtype = backend.standardize_dtype(dtype)\n    if dtype == \"float16\":\n        val = 65500.0\n    else:\n        val = 3.38953e38\n    return convert_to_tensor(val * -0.7, dtype=dtype)\n\n\ndef _can_use_flash_attention(\n    query, key, value, mask=None, is_causal=False, raise_error=False\n):\n    \"\"\"Verify the availability of flash attention.\"\"\"\n    try:\n        from torch.backends.cuda import SDPAParams\n        from torch.backends.cuda import can_use_flash_attention\n    except ImportError:\n        if raise_error:\n            raise ImportError(\n                \"Flash attention is not supported in your current PyTorch \"\n                \"version. 
Please update it by following the official guide: \"\n                \"https://pytorch.org/get-started/locally/\"\n            )\n        return False\n\n    try:\n        spda_params = SDPAParams(\n            query,\n            key,\n            value,\n            mask,\n            0.0,  # dropout_p\n            is_causal,\n            False,  # enable_gqa\n        )\n    except TypeError:\n        # The old function signature for the older version of PyTorch\n        spda_params = SDPAParams(\n            query,\n            key,\n            value,\n            mask,\n            0.0,  # dropout_p\n            is_causal,\n        )\n    if raise_error and can_use_flash_attention(spda_params, True) is False:\n        raise RuntimeError(\n            \"Flash attention is not supported with the provided inputs. \"\n            \"Please check the warnings for more details.\"\n        )\n    return can_use_flash_attention(spda_params, False)\n\n\ndef dot_product_attention(\n    query,\n    key,\n    value,\n    bias=None,\n    mask=None,\n    scale=None,\n    is_causal=False,\n    flash_attention=None,\n    attn_logits_soft_cap=None,\n):\n    query = convert_to_tensor(query)\n    key = convert_to_tensor(key)\n    value = convert_to_tensor(value)\n    if len(query.shape) != 4 or len(key.shape) != 4 or len(value.shape) != 4:\n        raise ValueError(\n            \"`dot_product_attention` only supports 4D inputs. \"\n            f\"Received: query.shape={query.shape}, key.shape={key.shape}, \"\n            f\"value.shape={value.shape}.\"\n        )\n    if bias is not None and mask is not None:\n        raise ValueError(\n            \"Only one of `bias` and `mask` can be provided. 
Received both.\"\n        )\n    compute_dtype = backend.result_type(query.dtype, key.dtype, value.dtype)\n    query = cast(query, compute_dtype)\n    key = cast(key, compute_dtype)\n    value = cast(value, compute_dtype)\n\n    mask = mask if mask is None else convert_to_tensor(mask, dtype=\"bool\")\n    if mask is not None:\n        # Explicitly set `is_causal` to `False` when `mask` is not `None`.\n        is_causal = False\n        mask = torch.where(mask, 0.0, _get_large_negative(query.dtype))\n    if bias is not None:\n        bias = convert_to_tensor(bias, dtype=compute_dtype)\n        mask = bias  # Use `bias` as `mask` for scaled_dot_product_attention.\n\n    axis0, axis1 = 1, 2\n    query = torch.transpose(query, axis0, axis1)\n    key = torch.transpose(key, axis0, axis1)\n    value = torch.transpose(value, axis0, axis1)\n\n    if flash_attention is None:\n        flash_attention = _can_use_flash_attention(\n            query, key, value, mask, is_causal\n        )\n    elif flash_attention is True:\n        # Use `raise_error=True` to provide more details if the inputs fail to\n        # use flash attention.\n        _can_use_flash_attention(\n            query, key, value, mask, is_causal, raise_error=True\n        )\n    if flash_attention:\n        with torch.nn.attention.sdpa_kernel(\n            backends=[torch.nn.attention.SDPBackend.FLASH_ATTENTION],\n        ):\n            attention_output = torch.nn.functional.scaled_dot_product_attention(\n                query,\n                key,\n                value,\n                attn_mask=mask,\n                is_causal=is_causal,\n                scale=scale,\n            )\n    else:\n        if mask is not None:\n            mask = mask.contiguous()\n        attention_output = torch.nn.functional.scaled_dot_product_attention(\n            query.contiguous(),\n            key.contiguous(),\n            value.contiguous(),\n            attn_mask=mask,\n            is_causal=is_causal,\n            
scale=scale,\n        )\n    return torch.transpose(attention_output, axis1, axis0)\n\n\ndef unfold(input, kernel_size, dilation=1, padding=0, stride=1):\n    \"\"\"Native PyTorch implementation of Unfold.\n    Extract sliding local blocks from a **NCHW** batched image tensor.\n\n    Args:\n        input: 4-D tensor, shape (N, C, H, W)  **required**.\n        kernel_size: int or (kH, kW)\n        dilation: int or (dH, dW), default 1\n        padding: int or (pH, pW), default 0\n        stride: int or (sH, sW), default 1\n\n    Returns:\n        3-D tensor, shape (N, C*kH*kW, L)\n    \"\"\"\n    return tnn.unfold(\n        input,\n        kernel_size=kernel_size,\n        dilation=dilation,\n        padding=padding,\n        stride=stride,\n    )\n\n\ndef fold(x, output_size, kernel_size, dilation=1, padding=0, stride=1):\n    \"\"\"Native PyTorch implementation of Fold.\n    Combine an array of sliding local blocks into a large tensor (col2im).\n\n    Args:\n        x: 3-D tensor, shape (N, C*kH*kW, L)  **required**.\n        output_size: int or (oH, oW)\n        kernel_size: int or (kH, kW)\n        dilation: int or (dH, dW), default 1\n        padding: int or (pH, pW), default 0\n        stride: int or (sH, sW), default 1\n\n    Returns:\n        4-D tensor, shape (N, C, oH, oW)\n    \"\"\"\n    return tnn.fold(\n        x,\n        output_size=output_size,\n        kernel_size=kernel_size,\n        dilation=dilation,\n        padding=padding,\n        stride=stride,\n    )\n\n\ndef depth_to_space(x, block_size, data_format=\"channels_last\"):\n    \"\"\"PyTorch implementation of depth_to_space.\n\n    Rearranges data from depth into blocks of spatial data.\n    Matches TensorFlow's depth_to_space behavior.\n\n    Args:\n        x: 4-D tensor with shape (N, H, W, C) for channels_last or\n            (N, C, H, W) for channels_first.\n        block_size: An integer specifying the block size.\n        data_format: \"channels_last\" or \"channels_first\".\n\n    
Returns:\n        A tensor with shape (N, H*block_size, W*block_size, C/block_size**2)\n        for channels_last or (N, C/block_size**2, H*block_size, W*block_size)\n        for channels_first.\n    \"\"\"\n    x = convert_to_tensor(x)\n    if data_format == \"channels_last\":\n        # NHWC format\n        n, h, w, c = x.shape\n        new_c = c // (block_size**2)\n        # Reshape: (N, H, W, C) -> (N, H, W, block_size, block_size, new_C)\n        x = x.reshape(n, h, w, block_size, block_size, new_c)\n        # Permute to (N, H, bH, W, bW, new_C) to interleave spatial blocks.\n        x = x.permute(0, 1, 3, 2, 4, 5)\n        # Reshape to the final spatial dimensions.\n        x = x.reshape(n, h * block_size, w * block_size, new_c)\n    else:\n        # NCHW format\n        n, c, h, w = x.shape\n        new_c = c // (block_size**2)\n        # Reshape: (N, C, H, W) -> (N, new_C, block_size, block_size, H, W)\n        x = x.reshape(n, new_c, block_size, block_size, h, w)\n        # Permute: (N, C, bH, bW, H, W) -> (N, C, H, bH, W, bW)\n        x = x.permute(0, 1, 4, 2, 5, 3)\n        # Reshape: (N, C, H, bH, W, bW) -> (N, C, H*bH, W*bW)\n        x = x.reshape(n, new_c, h * block_size, w * block_size)\n    return x\n\n\ndef space_to_depth(x, block_size, data_format=\"channels_last\"):\n    \"\"\"PyTorch implementation of space_to_depth.\n\n    Rearranges blocks of spatial data into depth.\n    Matches TensorFlow's space_to_depth behavior.\n\n    Args:\n        x: 4-D tensor with shape (N, H, W, C) for channels_last or\n            (N, C, H, W) for channels_first.\n        block_size: An integer specifying the block size.\n        data_format: \"channels_last\" or \"channels_first\".\n\n    Returns:\n        A tensor with shape (N, H/block_size, W/block_size, C*block_size**2)\n        for channels_last or (N, C*block_size**2, H/block_size, W/block_size)\n        for channels_first.\n    \"\"\"\n    x = convert_to_tensor(x)\n    if data_format == \"channels_last\":\n 
       # NHWC format\n        n, h, w, c = x.shape\n        new_h = h // block_size\n        new_w = w // block_size\n        # Reshape: (N, H, W, C) -> (N, new_H, bH, new_W, bW, C)\n        x = x.reshape(n, new_h, block_size, new_w, block_size, c)\n        # Permute: (N, new_H, bH, new_W, bW, C) -> (N, new_H, new_W, bH, bW, C)\n        x = x.permute(0, 1, 3, 2, 4, 5)\n        # Reshape: (N, new_H, new_W, bH, bW, C) -> (N, new_H, new_W, C*bH*bW)\n        x = x.reshape(n, new_h, new_w, c * block_size**2)\n    else:\n        # NCHW format\n        n, c, h, w = x.shape\n        new_h = h // block_size\n        new_w = w // block_size\n        # Reshape: (N, C, H, W) -> (N, C, new_H, bH, new_W, bW)\n        x = x.reshape(n, c, new_h, block_size, new_w, block_size)\n        # Permute: (N, C, new_H, bH, new_W, bW) -> (N, C, bH, bW, new_H, new_W)\n        x = x.permute(0, 1, 3, 5, 2, 4)\n        # Reshape: (N, C, bH, bW, new_H, new_W) -> (N, C*bH*bW, new_H, new_W)\n        x = x.reshape(n, c * block_size**2, new_h, new_w)\n    return x\n"
  },
  {
    "path": "keras/src/backend/torch/numpy.py",
    "content": "import builtins\nimport math\n\nimport numpy as np\nimport torch\n\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend import config\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common.backend_utils import canonicalize_axis\nfrom keras.src.backend.common.backend_utils import to_tuple_or_list\nfrom keras.src.backend.common.backend_utils import vectorize_impl\nfrom keras.src.backend.common.variables import standardize_dtype\nfrom keras.src.backend.torch.core import cast\nfrom keras.src.backend.torch.core import convert_to_tensor\nfrom keras.src.backend.torch.core import get_device\nfrom keras.src.backend.torch.core import is_tensor\nfrom keras.src.backend.torch.core import to_torch_dtype\n\nTORCH_INT_TYPES = (\n    torch.int8,\n    torch.int16,\n    torch.int32,\n    torch.int64,\n)\n\n\ndef rot90(array, k=1, axes=(0, 1)):\n    \"\"\"Rotate an array by 90 degrees in the specified plane using PyTorch.\n\n    Args:\n        array: Input tensor\n        k: Number of 90-degree rotations (default=1)\n        axes: Tuple of two axes that define the\n            plane of rotation (defaults to `(0, 1)`).\n\n    Returns:\n        Rotated tensor\n    \"\"\"\n    array = convert_to_tensor(array)\n\n    if array.ndim < 2:\n        raise ValueError(\n            \"Input array must have at least 2 dimensions. \"\n            f\"Received: array.ndim={array.ndim}\"\n        )\n    if len(axes) != 2 or axes[0] == axes[1]:\n        raise ValueError(\n            f\"Invalid axes: {axes}. 
Axes must be a tuple \"\n            \"of two different dimensions.\"\n        )\n\n    axes = tuple(axis if axis >= 0 else array.ndim + axis for axis in axes)\n\n    if not builtins.all(0 <= axis < array.ndim for axis in axes):\n        raise ValueError(\n            f\"Invalid axes {axes} for tensor with {array.ndim} dimensions\"\n        )\n\n    rotated = torch.rot90(array, k=k, dims=axes)\n    if isinstance(array, np.ndarray):\n        rotated = rotated.cpu().numpy()\n\n    return rotated\n\n\ndef add(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return torch.add(x1, x2)\n\n\ndef einsum(subscripts, *operands, **kwargs):\n    operands = [convert_to_tensor(operand) for operand in operands]\n    # When all operands are of int8, we cast the result to int32 to align with\n    # the behavior of jax.\n    dtypes_to_resolve = list(set(standardize_dtype(x.dtype) for x in operands))\n    if len(dtypes_to_resolve) == 1 and dtypes_to_resolve[0] == \"int8\":\n        compute_dtype = \"int32\"\n        if get_device() == \"cuda\":\n            # TODO: torch.einsum doesn't support int32 when using cuda\n            compute_dtype = config.floatx()\n        # prevent overflow\n        operands = [cast(operand, compute_dtype) for operand in operands]\n        return cast(torch.einsum(subscripts, *operands), \"int32\")\n    return torch.einsum(subscripts, *operands)\n\n\ndef subtract(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    # TODO: torch.subtract doesn't support bool\n    if standardize_dtype(x1.dtype) == \"bool\":\n        x1 = cast(x1, x2.dtype)\n    if standardize_dtype(x2.dtype) == \"bool\":\n        x2 = cast(x2, x1.dtype)\n    return torch.subtract(x1, x2)\n\n\ndef matmul(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    def can_use_int_matmul(x1, x2):\n        # torch._int_mm only accepts the following conditions:\n        # 1. cuda\n        # 2. 
both inputs must have int8 dtype\n        # 3. both inputs must be 2d\n        # 4. x1.shape must be [> 16, >= 16 and a multiple of 8]\n        # 5. x2.shape must be [>= 16 and a multiple of 8, a multiple of 8]\n        if get_device() != \"cuda\":\n            return False\n        x1_dtype = standardize_dtype(x1.dtype)\n        x2_dtype = standardize_dtype(x2.dtype)\n        if x1_dtype != \"int8\" or x2_dtype != \"int8\":\n            return False\n        x1_shape = x1.shape\n        x2_shape = x2.shape\n        if x1.ndim != 2 or x2.ndim != 2:\n            return False\n        if x1_shape[0] <= 16 or x1_shape[1] < 16 or x1_shape[1] % 8 != 0:\n            return False\n        if x2_shape[0] < 16 or x2_shape[0] % 8 != 0 or x2_shape[1] % 8 != 0:\n            return False\n        return True\n\n    # Shortcut for torch._int_mm\n    # TODO: Loosen the restriction of the usage of torch._int_mm\n    # TODO: We should replace torch._int_mm with the public API if possible\n    if can_use_int_matmul(x1, x2):\n        return torch._int_mm(x1, x2)\n\n    x1_dtype = standardize_dtype(x1.dtype)\n    x2_dtype = standardize_dtype(x2.dtype)\n    if x1_dtype == \"int8\" and x2_dtype == \"int8\":\n        result_dtype = \"int32\"\n    else:\n        result_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    compute_dtype = result_dtype\n\n    # TODO: torch.matmul doesn't support bool\n    if compute_dtype == \"bool\":\n        compute_dtype = config.floatx()\n    # TODO: torch.matmul doesn't support float16 with cpu\n    if get_device() == \"cpu\" and compute_dtype == \"float16\":\n        compute_dtype = \"float32\"\n    # TODO: torch.matmul doesn't support integer types with cuda\n    if get_device() == \"cuda\" and \"int\" in compute_dtype:\n        compute_dtype = config.floatx()\n\n    x1 = cast(x1, compute_dtype)\n    x2 = cast(x2, compute_dtype)\n    return cast(torch.matmul(x1, x2), result_dtype)\n\n\ndef multiply(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = 
convert_to_tensor(x2)\n    return torch.multiply(x1, x2)\n\n\ndef mean(x, axis=None, keepdims=False):\n    if isinstance(x, (list, tuple)):\n        x = stack(x)\n    x = convert_to_tensor(x)\n    if axis == () or axis == []:\n        # Torch handles the empty axis case differently from numpy.\n        return x\n    axis = to_tuple_or_list(axis)  # see [NB] below\n\n    ori_dtype = standardize_dtype(x.dtype)\n    # torch.mean only supports floating point inputs\n    compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        result_dtype = compute_dtype\n    else:\n        result_dtype = ori_dtype\n\n    # [NB] the python torch op torch.mean() is generated into\n    # `torch._C._VariableFunctions.pyi`, and the method\n    # signature is overloaded.\n    # Dynamo won't actually find the correct signature of\n    # `torch.mean()` if arguments are passed via kwargs\n    # So we have to pass the arguments via positional args\n    # EXCEPT for those that are forced as kwargs via the `*`\n    # delimiter in the overloaded method signatures.\n    # Additionally, we have to create a singleton-tuple\n    # when `axis` is an int to match the existing fn signature\n\n    # Cast input to compute dtype before mean to avoid dtype kwarg\n    # which causes issues with ONNX export (dtype kwarg not supported)\n    x = cast(x, compute_dtype)\n    result = torch.mean(x, axis, keepdims)\n    return cast(result, result_dtype)\n\n\ndef max(x, axis=None, keepdims=False, initial=None):\n    x = convert_to_tensor(x)\n    if 0 in x.shape:\n        if initial is None:\n            raise ValueError(\"Cannot compute the max of an empty tensor.\")\n        elif keepdims:\n            return torch.full((1,) * len(x.shape), initial)\n        else:\n            return torch.tensor(initial)\n\n    if axis is None:\n        result = torch.max(x)\n    else:\n        result = amax(x, axis=axis, keepdims=keepdims)\n    if 
isinstance(getattr(result, \"values\", None), torch.Tensor):\n        result = result.values\n\n    if initial is not None:\n        dtype = to_torch_dtype(result.dtype)\n        initial = convert_to_tensor(initial, dtype=dtype)\n        return torch.maximum(\n            result, torch.full(result.shape, initial, dtype=dtype)\n        )\n    return result\n\n\ndef ones(shape, dtype=None):\n    dtype = to_torch_dtype(dtype or config.floatx())\n    if isinstance(shape, int):\n        shape = (shape,)\n    return torch.ones(size=shape, dtype=dtype, device=get_device())\n\n\ndef zeros(shape, dtype=None):\n    dtype = to_torch_dtype(dtype or config.floatx())\n    if isinstance(shape, int):\n        shape = (shape,)\n    return torch.zeros(size=shape, dtype=dtype, device=get_device())\n\n\ndef zeros_like(x, dtype=None):\n    x = convert_to_tensor(x)\n    dtype = to_torch_dtype(dtype or x.dtype)\n    return torch.zeros_like(x, dtype=dtype)\n\n\ndef absolute(x):\n    x = convert_to_tensor(x)\n    # bool are always non-negative\n    if standardize_dtype(x.dtype) == \"bool\":\n        return x\n    return torch.abs(x)\n\n\ndef abs(x):\n    return absolute(x)\n\n\ndef all(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    if axis is None:\n        return cast(torch.all(x), \"bool\")\n    axis = to_tuple_or_list(axis)\n    for a in axis:\n        # `torch.all` does not handle multiple axes.\n        x = torch.all(x, dim=a, keepdim=keepdims)\n    return cast(x, \"bool\")\n\n\ndef angle(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n\n    # torch.angle doesn't support float16 with cuda\n    if get_device() != \"cpu\" and ori_dtype == \"float16\":\n        x = cast(x, \"float32\")\n        return cast(torch.angle(x), \"float16\")\n    return torch.angle(x)\n\n\ndef any(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    if axis is None:\n        return cast(torch.any(x), \"bool\")\n    axis = to_tuple_or_list(axis)\n    
for a in axis:\n        # `torch.any` does not handle multiple axes.\n        x = torch.any(x, dim=a, keepdim=keepdims)\n    return cast(x, \"bool\")\n\n\ndef amax(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    if axis is None:\n        return torch.amax(x)\n    if axis == () or axis == []:\n        # Torch handles the empty axis case differently from numpy.\n        return x\n    return torch.amax(x, dim=axis, keepdim=keepdims)\n\n\ndef amin(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    if axis is None:\n        return torch.amin(x)\n    if axis == () or axis == []:\n        # Torch handles the empty axis case differently from numpy.\n        return x\n    return torch.amin(x, dim=axis, keepdim=keepdims)\n\n\ndef append(x1, x2, axis=None):\n    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)\n    if axis is None:\n        return torch.cat((x1.flatten(), x2.flatten()))\n    return torch.cat((x1, x2), dim=axis)\n\n\ndef arange(start, stop=None, step=None, dtype=None):\n    if dtype is None:\n        dtypes_to_resolve = [getattr(start, \"dtype\", type(start))]\n        if stop is not None:\n            dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n        if step is not None:\n            dtypes_to_resolve.append(getattr(step, \"dtype\", type(step)))\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n    dtype = to_torch_dtype(dtype)\n    if stop is None:\n        start, stop = 0, start\n    if step is None:\n        step = 1\n    return torch.arange(\n        start, stop, step=step, dtype=dtype, device=get_device()\n    )\n\n\ndef arccos(x):\n    x = convert_to_tensor(x)\n    return torch.arccos(x)\n\n\ndef arccosh(x):\n    x = convert_to_tensor(x)\n    return torch.arccosh(x)\n\n\ndef arcsin(x):\n    x = convert_to_tensor(x)\n    return torch.arcsin(x)\n\n\ndef arcsinh(x):\n    x = convert_to_tensor(x)\n    return torch.arcsinh(x)\n\n\ndef arctan(x):\n    x = convert_to_tensor(x)\n    return 
torch.arctan(x)\n\n\ndef arctan2(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    result_dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    compute_dtype = result_dtype\n    # TODO: torch.arctan2 doesn't support float16 with cpu\n    if get_device() == \"cpu\" and compute_dtype == \"float16\":\n        compute_dtype = \"float32\"\n    x1 = cast(x1, compute_dtype)\n    x2 = cast(x2, compute_dtype)\n    return cast(torch.arctan2(x1, x2), result_dtype)\n\n\ndef arctanh(x):\n    x = convert_to_tensor(x)\n    return torch.arctanh(x)\n\n\ndef argmax(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    # TODO: torch.argmax doesn't support bool\n    if standardize_dtype(x.dtype) == \"bool\":\n        x = cast(x, \"uint8\")\n\n    return cast(torch.argmax(x, dim=axis, keepdim=keepdims), dtype=\"int32\")\n\n\ndef argmin(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    # TODO: torch.argmin doesn't support bool\n    if standardize_dtype(x.dtype) == \"bool\":\n        x = cast(x, \"uint8\")\n\n    return cast(torch.argmin(x, dim=axis, keepdim=keepdims), dtype=\"int32\")\n\n\ndef argsort(x, axis=-1):\n    x = convert_to_tensor(x)\n\n    # TODO: torch.argsort doesn't support bool\n    if standardize_dtype(x.dtype) == \"bool\":\n        x = cast(x, \"uint8\")\n\n    if axis is None:\n        axis = -1\n        x = x.reshape(-1)\n    return cast(torch.argsort(x, dim=axis, stable=True), dtype=\"int32\")\n\n\ndef array(x, dtype=None):\n    return convert_to_tensor(x, dtype=dtype)\n\n\ndef view(x, dtype=None):\n    dtype = to_torch_dtype(dtype)\n    x = convert_to_tensor(x)\n    return x.view(dtype=dtype)\n\n\ndef average(x, axis=None, weights=None):\n    x = convert_to_tensor(x)\n    dtypes_to_resolve = [x.dtype, float]\n    if weights is not None:\n        weights = convert_to_tensor(weights)\n        dtypes_to_resolve.append(weights.dtype)\n    dtype = dtypes.result_type(*dtypes_to_resolve)\n    x = cast(x, 
dtype)\n    if weights is not None:\n        weights = cast(weights, dtype)\n    if axis == () or axis == []:\n        # Torch handles the empty axis case differently from numpy.\n        return x\n    if weights is not None:\n        return torch.sum(torch.mul(x, weights), dim=axis) / torch.sum(\n            weights, dim=-1\n        )\n    return torch.mean(x, axis)\n\n\ndef bartlett(x):\n    x = convert_to_tensor(x)\n    return torch.signal.windows.bartlett(x)\n\n\ndef hamming(x):\n    x = convert_to_tensor(x)\n    return torch.signal.windows.hamming(x)\n\n\ndef hanning(x):\n    x = convert_to_tensor(x)\n    return torch.signal.windows.hann(x)\n\n\ndef heaviside(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype in [\"int8\", \"int16\", \"int32\", \"uint8\", \"uint16\", \"uint32\"]:\n        dtype = config.floatx()\n    elif dtype == \"int64\":\n        dtype = \"float64\"\n\n    x1 = cast(x1, dtype)\n    x2 = cast(x2, dtype)\n\n    return torch.heaviside(x1, x2)\n\n\ndef kaiser(x, beta):\n    x = convert_to_tensor(x)\n    return torch.signal.windows.kaiser(x, beta=beta)\n\n\ndef bincount(x, weights=None, minlength=0, sparse=False):\n    if sparse:\n        raise ValueError(\"Unsupported value `sparse=True` with torch backend\")\n    x = convert_to_tensor(x)\n    dtypes_to_resolve = [x.dtype]\n    if weights is not None:\n        weights = convert_to_tensor(weights)\n        dtypes_to_resolve.append(weights.dtype)\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n    else:\n        dtype = \"int32\"\n    if len(x.shape) == 2:\n        if weights is None:\n\n            def bincount_fn(arr):\n                return torch.bincount(arr, minlength=minlength)\n\n            bincounts = list(map(bincount_fn, x))\n        else:\n\n            def bincount_fn(arr_w):\n                return torch.bincount(\n                    arr_w[0], weights=arr_w[1], minlength=minlength\n 
               )\n\n            bincounts = list(map(bincount_fn, zip(x, weights)))\n\n        return cast(torch.stack(bincounts), dtype)\n    return cast(torch.bincount(x, weights, minlength), dtype)\n\n\ndef bitwise_and(x, y):\n    x = convert_to_tensor(x)\n    y = convert_to_tensor(y)\n    return torch.bitwise_and(x, y)\n\n\ndef bitwise_invert(x):\n    x = convert_to_tensor(x)\n    return torch.bitwise_not(x)\n\n\ndef bitwise_not(x):\n    return bitwise_invert(x)\n\n\ndef bitwise_or(x, y):\n    x = convert_to_tensor(x)\n    y = convert_to_tensor(y)\n    return torch.bitwise_or(x, y)\n\n\ndef bitwise_xor(x, y):\n    x = convert_to_tensor(x)\n    y = convert_to_tensor(y)\n    return torch.bitwise_xor(x, y)\n\n\ndef bitwise_left_shift(x, y):\n    x = convert_to_tensor(x)\n    if not isinstance(y, int):\n        y = convert_to_tensor(y)\n    return torch.bitwise_left_shift(x, y)\n\n\ndef left_shift(x, y):\n    return bitwise_left_shift(x, y)\n\n\ndef bitwise_right_shift(x, y):\n    x = convert_to_tensor(x)\n    if not isinstance(y, int):\n        y = convert_to_tensor(y)\n    return torch.bitwise_right_shift(x, y)\n\n\ndef right_shift(x, y):\n    return bitwise_right_shift(x, y)\n\n\ndef blackman(x):\n    x = convert_to_tensor(x)\n    return torch.signal.windows.blackman(x)\n\n\ndef broadcast_to(x, shape):\n    x = convert_to_tensor(x)\n    return torch.broadcast_to(x, shape)\n\n\ndef cbrt(x):\n    x = convert_to_tensor(x)\n\n    dtype = standardize_dtype(x.dtype)\n    if dtype == \"bool\":\n        x = cast(x, \"int32\")\n    elif dtype == \"int64\":\n        x = cast(x, \"float64\")\n\n    return torch.sign(x) * torch.abs(x) ** (1.0 / 3.0)\n\n\ndef ceil(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n\n    # TODO: torch.ceil doesn't support bool\n    if ori_dtype == \"bool\":\n        x = cast(x, \"uint8\")\n    # TODO: torch.ceil doesn't support float16 with cpu\n    elif get_device() == \"cpu\" and ori_dtype == \"float16\":\n       
 x = cast(x, config.floatx())\n\n    if ori_dtype == \"int64\":\n        dtype = config.floatx()\n    else:\n        dtype = dtypes.result_type(ori_dtype, float)\n    return cast(torch.ceil(x), dtype=dtype)\n\n\ndef clip(x, x_min, x_max):\n    x = convert_to_tensor(x)\n    x_min = convert_to_tensor(x_min)\n    x_max = convert_to_tensor(x_max)\n    ori_dtype = standardize_dtype(x.dtype)\n\n    # TODO: torch.clip doesn't support float16 with cpu\n    if get_device() == \"cpu\" and ori_dtype == \"float16\":\n        x = cast(x, \"float32\")\n        return cast(torch.clip(x, min=x_min, max=x_max), \"float16\")\n\n    if ori_dtype == \"bool\":\n        x = cast(x, \"int32\")\n    return torch.clip(x, min=x_min, max=x_max)\n\n\ndef concatenate(xs, axis=0):\n    xs = [convert_to_tensor(x) for x in xs]\n    return torch.cat(xs, dim=axis)\n\n\ndef conjugate(x):\n    if not isinstance(x, torch.Tensor):\n        x = torch.from_numpy(x)  # needed for complex type conversion\n    return torch.conj(x).resolve_conj()\n\n\ndef conj(x):\n    if not isinstance(x, torch.Tensor):\n        x = torch.from_numpy(x)  # needed for complex type conversion\n    return torch.conj(x).resolve_conj()\n\n\ndef copy(x):\n    x = convert_to_tensor(x)\n    return torch.clone(x)\n\n\ndef cos(x):\n    x = convert_to_tensor(x)\n    return torch.cos(x)\n\n\ndef cosh(x):\n    x = convert_to_tensor(x)\n    return torch.cosh(x)\n\n\ndef count_nonzero(x, axis=None):\n    x = convert_to_tensor(x)\n    if axis == () or axis == []:\n        # Torch handles the empty axis case differently from numpy.\n        return cast(torch.ne(x, 0), \"int32\")\n    return cast(torch.count_nonzero(x, dim=axis).T, \"int32\")\n\n\ndef cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):\n    if axisa != -1 or axisb != -1 or axisc != -1:\n        raise ValueError(\n            \"Torch backend does not support `axisa`, `axisb`, or `axisc`. \"\n            f\"Received: axisa={axisa}, axisb={axisb}, axisc={axisc}. 
Please \"\n            \"use `axis` arg in torch backend.\"\n        )\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    compute_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    result_dtype = compute_dtype\n    # TODO: torch.cross doesn't support bfloat16 with gpu\n    if get_device() == \"cuda\" and compute_dtype == \"bfloat16\":\n        compute_dtype = \"float32\"\n    # TODO: torch.cross doesn't support float16 with cpu\n    elif get_device() == \"cpu\" and compute_dtype == \"float16\":\n        compute_dtype = \"float32\"\n    x1 = cast(x1, compute_dtype)\n    x2 = cast(x2, compute_dtype)\n    return cast(torch.cross(x1, x2, dim=axis), result_dtype)\n\n\ndef cumprod(x, axis=None, dtype=None):\n    x = convert_to_tensor(x)\n    if axis is None:\n        x = x.flatten()\n        axis = 0\n    dtype = dtypes.result_type(dtype or x.dtype)\n    if dtype == \"bool\":\n        dtype = \"int32\"\n    # TODO: torch.cumprod doesn't support float16 with cpu\n    elif get_device() == \"cpu\" and dtype == \"float16\":\n        return cast(\n            torch.cumprod(x, dim=axis, dtype=to_torch_dtype(\"float32\")),\n            \"float16\",\n        )\n    return torch.cumprod(x, dim=axis, dtype=to_torch_dtype(dtype))\n\n\ndef cumsum(x, axis=None, dtype=None):\n    x = convert_to_tensor(x)\n    if axis is None:\n        x = x.flatten()\n        axis = 0\n    dtype = dtypes.result_type(dtype or x.dtype)\n    if dtype == \"bool\":\n        dtype = \"int32\"\n    # TODO: torch.cumsum doesn't support float16 with cpu\n    elif get_device() == \"cpu\" and dtype == \"float16\":\n        return cast(\n            torch.cumsum(x, dim=axis, dtype=to_torch_dtype(\"float32\")),\n            \"float16\",\n        )\n    return torch.cumsum(x, dim=axis, dtype=to_torch_dtype(dtype))\n\n\ndef deg2rad(x):\n    x = convert_to_tensor(x)\n\n    if standardize_dtype(x.dtype) == \"int64\":\n        return cast(torch.deg2rad(x), \"float64\")\n\n    return 
torch.deg2rad(x)\n\n\ndef diag(x, k=0):\n    x = convert_to_tensor(x)\n    return torch.diag(x, diagonal=k)\n\n\ndef diagflat(x, k=0):\n    x = convert_to_tensor(x)\n    return torch.diagflat(x, offset=k)\n\n\ndef diagonal(x, offset=0, axis1=0, axis2=1):\n    x = convert_to_tensor(x)\n    return torch.diagonal(\n        x,\n        offset=offset,\n        dim1=axis1,\n        dim2=axis2,\n    )\n\n\ndef diff(a, n=1, axis=-1):\n    a = convert_to_tensor(a)\n    return torch.diff(a, n=n, dim=axis)\n\n\ndef digitize(x, bins):\n    x = convert_to_tensor(x)\n    bins = convert_to_tensor(bins)\n    if standardize_dtype(x.dtype) == \"bool\":\n        x = cast(x, \"uint8\")\n    return cast(torch.bucketize(x, bins, right=True), \"int32\")\n\n\ndef dot(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    result_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    # GPU only supports float types\n    compute_dtype = dtypes.result_type(result_dtype, float)\n\n    # TODO: torch.matmul doesn't support float16 with cpu\n    if get_device() == \"cpu\" and compute_dtype == \"float16\":\n        compute_dtype = \"float32\"\n\n    x1 = cast(x1, compute_dtype)\n    x2 = cast(x2, compute_dtype)\n    if x1.ndim == 0 or x2.ndim == 0:\n        return cast(torch.multiply(x1, x2), result_dtype)\n    return cast(torch.matmul(x1, x2), result_dtype)\n\n\ndef dstack(xs):\n    xs = [convert_to_tensor(x) for x in xs]\n    return torch.dstack(xs)\n\n\ndef empty(shape, dtype=None):\n    dtype = to_torch_dtype(dtype or config.floatx())\n    return torch.empty(size=shape, dtype=dtype, device=get_device())\n\n\ndef empty_like(x, dtype=None):\n    x = convert_to_tensor(x)\n    dtype = to_torch_dtype(dtype or x.dtype)\n    return torch.empty_like(x, dtype=dtype, device=get_device())\n\n\ndef equal(x1, x2):\n    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)\n    return torch.eq(x1, x2)\n\n\ndef exp(x):\n    x = convert_to_tensor(x)\n    ori_dtype = 
standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = cast(x, config.floatx())\n    return torch.exp(x)\n\n\ndef exp2(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = cast(x, config.floatx())\n    return torch.exp2(x)\n\n\ndef expand_dims(x, axis):\n    x = convert_to_tensor(x)\n    axis = to_tuple_or_list(axis)\n    out_ndim = len(x.shape) + len(axis)\n    axis = sorted([canonicalize_axis(a, out_ndim) for a in axis])\n    for a in axis:\n        x = torch.unsqueeze(x, dim=a)\n    return x\n\n\ndef expm1(x):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = cast(x, config.floatx())\n    return torch.expm1(x)\n\n\ndef flip(x, axis=None):\n    x = convert_to_tensor(x)\n    if axis is None:\n        axis = tuple(range(x.ndim))\n    axis = to_tuple_or_list(axis)\n    return torch.flip(x, dims=axis)\n\n\ndef floor(x):\n    x = convert_to_tensor(x)\n    dtype = (\n        config.floatx()\n        if standardize_dtype(x.dtype) == \"int64\"\n        else dtypes.result_type(x.dtype, float)\n    )\n    x = cast(x, dtype)\n    return torch.floor(x)\n\n\ndef full(shape, fill_value, dtype=None):\n    dtype = to_torch_dtype(dtype)\n    fill_value = convert_to_tensor(fill_value, dtype=dtype)\n    if len(fill_value.shape) > 0:\n        # `torch.full` only supports a scalar `fill_value`.\n        expand_size = len(shape) - len(fill_value.shape)\n        tile_shape = tuple(shape[:expand_size]) + (1,) * len(fill_value.shape)\n        return torch.tile(fill_value, tile_shape)\n    return torch.full(\n        size=shape, fill_value=fill_value, dtype=dtype, device=get_device()\n    )\n\n\ndef full_like(x, fill_value, dtype=None):\n    dtype = dtype or x.dtype\n    return full(shape=x.shape, fill_value=fill_value, dtype=dtype)\n\n\ndef gcd(x1, x2):\n    x1 = 
convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return torch.gcd(x1, x2)\n\n\ndef geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):\n    if axis != 0:\n        raise ValueError(\n            \"torch does not support an `axis` argument for geomspace. \"\n            f\"Received axis={axis}\"\n        )\n    if dtype is None:\n        dtypes_to_resolve = [\n            getattr(start, \"dtype\", type(start)),\n            getattr(stop, \"dtype\", type(stop)),\n            float,\n        ]\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n    dtype = to_torch_dtype(dtype)\n\n    start = convert_to_tensor(start, dtype=dtype)\n    stop = convert_to_tensor(stop, dtype=dtype)\n\n    log_start = torch.log10(torch.abs(start))\n    log_stop = torch.log10(torch.abs(stop))\n\n    result = logspace(\n        log_start, log_stop, num=num, endpoint=endpoint, base=10, dtype=dtype\n    )\n    return result * torch.sign(start)\n\n\ndef greater(x1, x2):\n    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)\n    return torch.greater(x1, x2)\n\n\ndef greater_equal(x1, x2):\n    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)\n    return torch.greater_equal(x1, x2)\n\n\ndef hstack(xs):\n    xs = [convert_to_tensor(x) for x in xs]\n    return torch.hstack(xs)\n\n\ndef hsplit(x, indices_or_sections):\n    x = convert_to_tensor(x)\n    if not isinstance(indices_or_sections, int):\n        indices_or_sections = convert_to_tensor(indices_or_sections).tolist()\n    return list(torch.hsplit(x, indices_or_sections))\n\n\ndef hypot(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype in [\"int8\", \"int16\", \"int32\", \"uint8\", \"uint16\", \"uint32\"]:\n        dtype = config.floatx()\n    elif dtype == \"int64\":\n        dtype = \"float64\"\n\n    x1 = cast(x1, dtype)\n    x2 = cast(x2, dtype)\n\n    return torch.hypot(x1, x2)\n\n\ndef identity(n, 
dtype=None):\n    dtype = to_torch_dtype(dtype or config.floatx())\n\n    # TODO: torch.eye doesn't support bfloat16 with cpu\n    if get_device() == \"cpu\" and dtype == torch.bfloat16:\n        return cast(\n            torch.eye(n, dtype=to_torch_dtype(\"float32\"), device=get_device()),\n            dtype,\n        )\n    return torch.eye(n, dtype=dtype, device=get_device())\n\n\ndef imag(x):\n    if not isinstance(x, torch.Tensor):\n        x = torch.from_numpy(x)  # needed for complex type conversion\n    return torch.imag(x)\n\n\ndef isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    result_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = cast(x1, result_dtype)\n    x2 = cast(x2, result_dtype)\n    if \"float\" in standardize_dtype(result_dtype):\n        return torch.isclose(x1, x2, rtol=rtol, atol=atol, equal_nan=equal_nan)\n    return torch.eq(x1, x2)\n\n\ndef allclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    result_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    x1 = cast(x1, result_dtype)\n    x2 = cast(x2, result_dtype)\n    if \"float\" in standardize_dtype(result_dtype):\n        return torch.all(\n            torch.isclose(x1, x2, rtol=rtol, atol=atol, equal_nan=equal_nan)\n        )\n    return torch.all(torch.eq(x1, x2))\n\n\ndef isfinite(x):\n    x = convert_to_tensor(x)\n    return torch.isfinite(x)\n\n\ndef isin(x1, x2, assume_unique=False, invert=False):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype == \"bool\":\n        x1 = cast(x1, \"int32\")\n        x2 = cast(x2, \"int32\")\n\n    if standardize_dtype(x1.dtype) == \"bool\":\n        x1 = cast(x1, x2.dtype)\n    if standardize_dtype(x2.dtype) == \"bool\":\n        x2 = cast(x2, x1.dtype)\n\n    return torch.isin(x1, x2, 
assume_unique=assume_unique, invert=invert)\n\n\ndef isinf(x):\n    x = convert_to_tensor(x)\n    return torch.isinf(x)\n\n\ndef isnan(x):\n    x = convert_to_tensor(x)\n    return torch.isnan(x)\n\n\ndef isneginf(x):\n    x = convert_to_tensor(x)\n    return torch.isneginf(x)\n\n\ndef isposinf(x):\n    x = convert_to_tensor(x)\n    return torch.isposinf(x)\n\n\ndef isreal(x):\n    x = convert_to_tensor(x)\n    return torch.isreal(x)\n\n\ndef kron(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return torch.kron(x1, x2)\n\n\ndef lcm(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    return torch.lcm(x1, x2)\n\n\ndef ldexp(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n\n    if standardize_dtype(x2.dtype) not in dtypes.INT_TYPES:\n        raise TypeError(\n            f\"ldexp exponent must be an integer type. \"\n            f\"Received: x2 dtype={x2.dtype}\"\n        )\n\n    return cast(torch.ldexp(x1, x2), dtype)\n\n\ndef less(x1, x2):\n    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)\n    return torch.less(x1, x2)\n\n\ndef less_equal(x1, x2):\n    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)\n    return torch.less_equal(x1, x2)\n\n\ndef linspace(\n    start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0\n):\n    if axis != 0:\n        raise ValueError(\n            \"torch.linspace does not support an `axis` argument. 
\"\n            f\"Received axis={axis}\"\n        )\n    if dtype is None:\n        dtypes_to_resolve = [\n            getattr(start, \"dtype\", type(start)),\n            getattr(stop, \"dtype\", type(stop)),\n            float,\n        ]\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n    dtype = to_torch_dtype(dtype)\n\n    step = convert_to_tensor(torch.nan)\n    if endpoint:\n        if num > 1:\n            step = (stop - start) / (num - 1)\n    else:\n        if num > 0:\n            step = (stop - start) / num\n        if num > 1:\n            stop = stop - ((stop - start) / num)\n    if hasattr(start, \"__len__\") and hasattr(stop, \"__len__\"):\n        start = convert_to_tensor(start, dtype=dtype)\n        stop = convert_to_tensor(stop, dtype=dtype)\n        steps = torch.arange(num, dtype=dtype, device=get_device()) / (num - 1)\n\n        # reshape `steps` to allow for broadcasting\n        for i in range(start.ndim):\n            steps = steps.unsqueeze(-1)\n\n        # increments from `start` to `stop` in each dimension\n        linspace = start[None] + steps * (stop - start)[None]\n    else:\n        linspace = torch.linspace(\n            start=start,\n            end=stop,\n            steps=num,\n            dtype=dtype,\n            device=get_device(),\n        )\n    if retstep is True:\n        return (linspace, step)\n    return linspace\n\n\ndef log(x):\n    x = convert_to_tensor(x)\n    return torch.log(x)\n\n\ndef log10(x):\n    x = convert_to_tensor(x)\n    return torch.log10(x)\n\n\ndef log1p(x):\n    x = convert_to_tensor(x)\n    return torch.log1p(x)\n\n\ndef log2(x):\n    x = convert_to_tensor(x)\n    return torch.log2(x)\n\n\ndef logaddexp(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n\n    # TODO: torch.logaddexp doesn't support float16 with cpu\n    if get_device() == \"cpu\" and dtype == \"float16\":\n        x1 = cast(x1, 
\"float32\")\n        x2 = cast(x2, \"float32\")\n        return cast(torch.logaddexp(x1, x2), dtype)\n    else:\n        x1 = cast(x1, dtype)\n        x2 = cast(x2, dtype)\n        return torch.logaddexp(x1, x2)\n\n\ndef logaddexp2(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    x1 = cast(x1, dtype)\n    x2 = cast(x2, dtype)\n    return torch.logaddexp2(x1, x2)\n\n\ndef logical_and(x1, x2):\n    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)\n    return torch.logical_and(x1, x2)\n\n\ndef logical_not(x):\n    x = convert_to_tensor(x)\n    return torch.logical_not(x)\n\n\ndef logical_or(x1, x2):\n    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)\n    return torch.logical_or(x1, x2)\n\n\ndef logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, axis=0):\n    if axis != 0:\n        raise ValueError(\n            \"torch.logspace does not support an `axis` argument. \"\n            f\"Received axis={axis}\"\n        )\n    if dtype is None:\n        dtypes_to_resolve = [\n            getattr(start, \"dtype\", type(start)),\n            getattr(stop, \"dtype\", type(stop)),\n            float,\n        ]\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n    dtype = to_torch_dtype(dtype)\n\n    if endpoint is False:\n        stop = stop - ((stop - start) / num)\n    if hasattr(start, \"__len__\") and hasattr(stop, \"__len__\"):\n        start = convert_to_tensor(start, dtype=dtype)\n        stop = convert_to_tensor(stop, dtype=dtype)\n        steps = torch.arange(num, dtype=dtype, device=get_device()) / (num - 1)\n\n        # reshape `steps` to allow for broadcasting\n        for i in range(start.ndim):\n            steps = steps.unsqueeze(-1)\n\n        # increments from `start` to `stop` in each dimension\n        linspace = start[None] + steps * (stop - start)[None]\n        logspace = base**linspace\n    else:\n        compute_dtype = dtype\n     
   # TODO: torch.logspace doesn't support float16 with cpu\n        if get_device() == \"cpu\" and dtype == torch.float16:\n            compute_dtype = torch.float32\n        logspace = cast(\n            torch.logspace(\n                start=start,\n                end=stop,\n                steps=num,\n                base=base,\n                dtype=compute_dtype,\n                device=get_device(),\n            ),\n            dtype,\n        )\n    return logspace\n\n\ndef maximum(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return torch.maximum(x1, x2)\n\n\ndef median(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n    result_dtype = dtypes.result_type(x.dtype, float)\n    x = cast(x, compute_dtype)\n\n    if axis is None and keepdims is False:\n        return cast(torch.median(x), result_dtype)\n    elif isinstance(axis, int):\n        return cast(\n            torch.median(x, dim=axis, keepdim=keepdims)[0], result_dtype\n        )\n\n    # support multiple axes\n    if axis is None:\n        y = reshape(x, [-1])\n    else:\n        # transpose\n        axis = [canonicalize_axis(a, x.ndim) for a in axis]\n        other_dims = sorted(set(range(x.ndim)).difference(axis))\n        perm = other_dims + list(axis)\n        x_permed = torch.permute(x, dims=perm)\n        # reshape\n        x_shape = list(x.shape)\n        other_shape = [x_shape[i] for i in other_dims]\n        end_shape = [math.prod([x_shape[i] for i in axis])]\n        full_shape = other_shape + end_shape\n        y = reshape(x_permed, full_shape)\n\n    y = torch.median(y, dim=-1)[0]\n\n    if 
keepdims:\n        if axis is None:\n            for _ in range(x.ndim):\n                y = expand_dims(y, axis=-1)\n        else:\n            for i in sorted(axis):\n                y = expand_dims(y, axis=i)\n\n    return cast(y, result_dtype)\n\n\ndef meshgrid(*x, indexing=\"xy\"):\n    x = [convert_to_tensor(sc_tensor) for sc_tensor in x]\n    return torch.meshgrid(x, indexing=indexing)\n\n\ndef min(x, axis=None, keepdims=False, initial=None):\n    x = convert_to_tensor(x)\n    if 0 in x.shape:\n        if initial is None:\n            raise ValueError(\"Cannot compute the min of an empty tensor.\")\n        elif keepdims:\n            return torch.full((1,) * len(x.shape), initial)\n        else:\n            return torch.tensor(initial)\n\n    if axis is None:\n        result = torch.min(x)\n    else:\n        result = amin(x, axis=axis, keepdims=keepdims)\n\n    if isinstance(getattr(result, \"values\", None), torch.Tensor):\n        result = result.values\n\n    if initial is not None:\n        dtype = to_torch_dtype(result.dtype)\n        initial = convert_to_tensor(initial, dtype=dtype)\n        return torch.minimum(result, initial)\n    return result\n\n\ndef minimum(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    x1 = convert_to_tensor(x1, dtype)\n    x2 = convert_to_tensor(x2, dtype)\n    return torch.minimum(x1, x2)\n\n\ndef mod(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype == \"bool\":\n        x1 = cast(x1, \"int32\")\n        x2 = cast(x2, \"int32\")\n    return torch.remainder(x1, x2)\n\n\ndef fmod(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    dtype = 
dtypes.result_type(x1.dtype, x2.dtype)\n    if dtype == \"bool\":\n        x1 = cast(x1, \"int32\")\n        x2 = cast(x2, \"int32\")\n    return torch.fmod(x1, x2)\n\n\ndef moveaxis(x, source, destination):\n    x = convert_to_tensor(x)\n    return torch.moveaxis(x, source=source, destination=destination)\n\n\ndef nanargmax(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    if not torch.is_floating_point(x):\n        return argmax(x, axis=axis, keepdims=keepdims)\n\n    x_clean = torch.where(torch.isnan(x), float(\"-inf\"), x)\n\n    # All-NaN slices yield index -1 (unlike numpy, which raises).\n    return torch.where(\n        torch.isnan(x).all(dim=axis, keepdim=keepdims),\n        torch.tensor(-1, dtype=torch.int32, device=get_device()),\n        argmax(x_clean, axis=axis, keepdims=keepdims),\n    )\n\n\ndef nanargmin(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    if not torch.is_floating_point(x):\n        return argmin(x, axis=axis, keepdims=keepdims)\n\n    x_clean = torch.where(torch.isnan(x), float(\"inf\"), x)\n\n    # All-NaN slices yield index -1 (unlike numpy, which raises).\n    return torch.where(\n        torch.isnan(x).all(dim=axis, keepdim=keepdims),\n        torch.tensor(-1, dtype=torch.int32, device=get_device()),\n        argmin(x_clean, axis=axis, keepdims=keepdims),\n    )\n\n\ndef nancumsum(x, axis=None, dtype=None):\n    x = nan_to_num(x)\n    return cumsum(x, axis=axis, dtype=dtype)\n\n\ndef nancumprod(x, axis=None, dtype=None):\n    x = nan_to_num(x, nan=1.0)\n    return cumprod(x, axis=axis, dtype=dtype)\n\n\ndef nanmax(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    if not torch.is_floating_point(x):\n        return torch.amax(x, dim=axis, keepdim=keepdims)\n\n    if axis == () or axis == []:\n        return x\n\n    x_clean = torch.where(torch.isnan(x), float(\"-inf\"), x)\n    out = torch.amax(x_clean, dim=axis, keepdim=keepdims)\n\n    return torch.where(\n        torch.isnan(x).all(dim=axis, keepdim=keepdims),\n        torch.tensor(float(\"nan\"), dtype=x.dtype, device=get_device()),\n        out,\n    
)\n\n\ndef nanmean(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    if axis == () or axis == []:\n        return x\n\n    dtype = dtypes.result_type(standardize_dtype(x.dtype), float)\n    return torch.nanmean(cast(x, dtype), dim=axis, keepdim=keepdims)\n\n\ndef nanmin(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    if not torch.is_floating_point(x):\n        return torch.amin(x, dim=axis, keepdim=keepdims)\n\n    if axis == () or axis == []:\n        return x\n\n    x_clean = torch.where(torch.isnan(x), float(\"inf\"), x)\n    out = torch.amin(x_clean, dim=axis, keepdim=keepdims)\n\n    return torch.where(\n        torch.isnan(x).all(dim=axis, keepdim=keepdims),\n        torch.tensor(float(\"nan\"), dtype=x.dtype, device=get_device()),\n        out,\n    )\n\n\ndef nanprod(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    if axis == () or axis == []:\n        return torch.nan_to_num(x, nan=1)\n\n    if isinstance(axis, (list, tuple)):\n        axis = sorted(axis, reverse=True)\n\n    if not torch.is_floating_point(x):\n        return prod(x, axis=axis, keepdims=keepdims)\n\n    return prod(\n        torch.where(torch.isnan(x), torch.ones((), dtype=x.dtype), x),\n        axis=axis,\n        keepdims=keepdims,\n    )\n\n\ndef nanstd(x, axis=None, keepdims=False):\n    var_val = nanvar(x, axis=axis, keepdims=keepdims)\n    return torch.sqrt(var_val)\n\n\ndef nansum(x, axis=None, keepdims=False):\n    if isinstance(x, (list, tuple)):\n        x = stack(x)\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n\n    if dtype in (\"bool\", \"uint8\", \"int8\", \"int16\"):\n        dtype = \"int32\"\n\n    if axis == () or axis == []:\n        return cast(torch.nan_to_num(x, nan=0), dtype)\n    return cast(torch.nansum(x, dim=axis, keepdim=keepdims), dtype)\n\n\ndef nanvar(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n\n    result_dtype = dtypes.result_type(x.dtype, float)\n    x = 
cast(x, result_dtype)\n\n    if axis == () or axis == []:\n        return torch.where(torch.isnan(x), x, torch.zeros(()))\n\n    mean = nanmean(x, axis=axis, keepdims=True)\n\n    valid = ~torch.isnan(x)\n    centered = torch.where(valid, x - mean, torch.zeros_like(x))\n\n    if torch.is_complex(centered):\n        centered = centered.real * centered.real + centered.imag * centered.imag\n    else:\n        centered = centered.square()\n\n    count = valid.sum(dim=axis, keepdim=keepdims)\n    var = centered.sum(dim=axis, keepdim=keepdims) / count\n    return var\n\n\ndef nan_to_num(x, nan=0.0, posinf=None, neginf=None):\n    x = convert_to_tensor(x)\n    return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)\n\n\ndef ndim(x):\n    x = convert_to_tensor(x)\n    return x.ndim\n\n\ndef nonzero(x):\n    x = convert_to_tensor(x)\n    return cast(torch.nonzero(x).T, \"int32\")\n\n\ndef not_equal(x1, x2):\n    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)\n    return torch.not_equal(x1, x2)\n\n\ndef ones_like(x, dtype=None):\n    x = convert_to_tensor(x)\n    dtype = to_torch_dtype(dtype or x.dtype)\n    return torch.ones_like(x, dtype=dtype)\n\n\ndef outer(x1, x2):\n    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)\n    return torch.outer(x1.flatten(), x2.flatten())\n\n\ndef pad(x, pad_width, mode=\"constant\", constant_values=None):\n    kwargs = {}\n    if constant_values is not None:\n        if mode != \"constant\":\n            raise ValueError(\n                \"Argument `constant_values` can only be \"\n                \"provided when `mode == 'constant'`. 
\"\n                f\"Received: mode={mode}\"\n            )\n        kwargs[\"value\"] = constant_values\n    x = convert_to_tensor(x)\n    pad_sum = []\n    pad_width = list(pad_width)[::-1]  # torch uses reverse order\n    pad_width_sum = 0\n    for pad in pad_width:\n        pad_width_sum += pad[0] + pad[1]\n    for pad in pad_width:\n        pad_sum += pad\n        pad_width_sum -= pad[0] + pad[1]\n        if pad_width_sum == 0:  # stop early once outer dims need no padding\n            break\n    if mode == \"symmetric\":\n        mode = \"replicate\"\n    if mode == \"constant\":\n        return torch.nn.functional.pad(x, pad=pad_sum, mode=mode, **kwargs)\n    # TODO: reflect and symmetric padding are implemented for padding the\n    # last 3 dimensions of a 4D or 5D input tensor, the last 2 dimensions of a\n    # 3D or 4D input tensor, or the last dimension of a 2D or 3D input tensor.\n    # https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html\n    ori_dtype = x.dtype\n    ori_ndim = x.ndim\n    need_squeeze = False\n    if x.ndim < 3:\n        need_squeeze = True\n        new_dims = [1] * (3 - x.ndim)\n        x = x.view(*new_dims, *x.shape)\n    need_cast = False\n    if x.dtype not in (torch.float32, torch.float64):\n        # TODO: reflect and symmetric padding are only supported with float32/64\n        # https://github.com/pytorch/pytorch/issues/40763\n        need_cast = True\n        x = cast(x, torch.float32)\n    x = torch.nn.functional.pad(x, pad=pad_sum, mode=mode)\n    if need_cast:\n        x = cast(x, ori_dtype)\n    if need_squeeze:\n        x = torch.squeeze(x, dim=tuple(range(3 - ori_ndim)))\n    return x\n\n\ndef prod(x, axis=None, keepdims=False, dtype=None):\n    x = convert_to_tensor(x)\n    if dtype is None:\n        dtype = dtypes.result_type(x.dtype)\n        if dtype == \"bool\":\n            dtype = \"int32\"\n        elif dtype in (\"int8\", \"int16\"):\n            dtype = \"int32\"\n        # TODO: 
torch.prod doesn't support uint32\n        elif dtype == \"uint8\":\n            dtype = \"int32\"\n    compute_dtype = dtype\n    # TODO: torch.prod doesn't support float16 with cpu\n    if get_device() == \"cpu\" and compute_dtype == \"float16\":\n        compute_dtype = \"float32\"\n    if axis is None:\n        return cast(torch.prod(x, dtype=to_torch_dtype(compute_dtype)), dtype)\n    axis = to_tuple_or_list(axis)\n    for a in axis:\n        # `torch.prod` does not handle multiple axes.\n        x = cast(\n            torch.prod(\n                x, dim=a, keepdim=keepdims, dtype=to_torch_dtype(compute_dtype)\n            ),\n            dtype,\n        )\n    return x\n\n\ndef ptp(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    if axis is None:\n        return x.max() - x.min()\n    elif axis == ():\n        return torch.zeros_like(x)\n    else:\n        return torch.amax(x, dim=axis, keepdim=keepdims) - torch.amin(\n            x, dim=axis, keepdim=keepdims\n        )\n\n\ndef quantile(x, q, axis=None, method=\"linear\", keepdims=False):\n    x = convert_to_tensor(x)\n    q = convert_to_tensor(q)\n    axis = to_tuple_or_list(axis)\n\n    compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n    result_dtype = dtypes.result_type(x.dtype, float)\n\n    x = cast(x, compute_dtype)\n    # q must be same dtype as x\n    if x.dtype != q.dtype:\n        q = cast(q, x.dtype)\n\n    # support multiple axes\n    if axis is None:\n        y = reshape(x, [-1])\n    else:\n        # transpose\n        axis = [canonicalize_axis(a, x.ndim) for a in axis]\n        other_dims = sorted(set(range(x.ndim)).difference(axis))\n        perm = other_dims + list(axis)\n        x_permed = torch.permute(x, dims=perm)\n        # reshape\n        x_shape = list(x.shape)\n        other_shape = [x_shape[i] for i in other_dims]\n        end_shape = [math.prod([x_shape[i] for i in axis])]\n        full_shape = other_shape + end_shape\n        y = reshape(x_permed, 
full_shape)\n\n    y = torch.quantile(y, q, dim=-1, interpolation=method)\n\n    if keepdims:\n        if axis is None:\n            for _ in range(x.ndim):\n                y = expand_dims(y, axis=-1)\n        else:\n            for i in sorted(axis):\n                i = i + 1 if q.ndim > 0 else i\n                y = expand_dims(y, axis=i)\n\n    return cast(y, result_dtype)\n\n\ndef ravel(x):\n    x = convert_to_tensor(x)\n    return torch.ravel(x)\n\n\ndef unravel_index(indices, shape):\n    indices = convert_to_tensor(indices)\n    dtype = dtypes.result_type(indices.dtype)\n    return tuple(\n        cast(idx, dtype) for idx in torch.unravel_index(indices, shape)\n    )\n\n\ndef real(x):\n    if not isinstance(x, torch.Tensor):\n        x = torch.from_numpy(x)  # needed for complex type conversion\n    return torch.real(x)\n\n\ndef reciprocal(x):\n    x = convert_to_tensor(x)\n    return torch.reciprocal(x)\n\n\ndef repeat(x, repeats, axis=None):\n    x = convert_to_tensor(x)\n\n    if get_device() == \"meta\":\n        x = KerasTensor(x.shape, standardize_dtype(x.dtype))\n        outputs = repeat(x, repeats, axis=axis)\n\n        return torch.empty(\n            size=outputs.shape,\n            dtype=to_torch_dtype(outputs.dtype),\n            device=get_device(),\n        )\n\n    repeats = convert_to_tensor(repeats, dtype=int)\n\n    return torch.repeat_interleave(x, repeats, dim=axis)\n\n\ndef reshape(x, newshape):\n    if not isinstance(newshape, (list, tuple)):\n        newshape = (newshape,)\n    x = convert_to_tensor(x)\n    return torch.reshape(x, newshape)\n\n\ndef roll(x, shift, axis=None):\n    x = convert_to_tensor(x)\n    return torch.roll(x, shift, dims=axis)\n\n\ndef searchsorted(sorted_sequence, values, side=\"left\"):\n    if ndim(sorted_sequence) != 1:\n        raise ValueError(\n            \"`searchsorted` only supports 1-D sorted sequences. 
\"\n            \"You can use `keras.ops.vectorized_map` \"\n            \"to extend it to N-D sequences. Received: \"\n            f\"sorted_sequence.shape={sorted_sequence.shape}\"\n        )\n    out_int32 = sorted_sequence.shape[0] <= np.iinfo(np.int32).max\n    return torch.searchsorted(\n        sorted_sequence, values, side=side, out_int32=out_int32\n    )\n\n\ndef sign(x):\n    x = convert_to_tensor(x)\n    return torch.sign(x)\n\n\ndef signbit(x):\n    x = convert_to_tensor(x)\n    return torch.signbit(x)\n\n\ndef sin(x):\n    x = convert_to_tensor(x)\n    return torch.sin(x)\n\n\ndef sinc(x):\n    x = convert_to_tensor(x)\n    return torch.sinc(x)\n\n\ndef sinh(x):\n    x = convert_to_tensor(x)\n    return torch.sinh(x)\n\n\ndef size(x):\n    x_shape = convert_to_tensor(tuple(x.shape))\n    return torch.prod(x_shape)\n\n\ndef sort(x, axis=-1):\n    x = convert_to_tensor(x)\n    # TODO: torch.sort doesn't support bool with cuda\n    if get_device() == \"cuda\" and standardize_dtype(x.dtype) == \"bool\":\n        x = cast(x, \"uint8\")\n        return cast(torch.sort(x, dim=axis).values, \"bool\")\n    return torch.sort(x, dim=axis).values\n\n\ndef split(x, indices_or_sections, axis=0):\n    x = convert_to_tensor(x)\n    dim = x.shape[axis]\n    if not isinstance(indices_or_sections, int):\n        indices_or_sections = convert_to_tensor(indices_or_sections)\n        start_size = indices_or_sections[0:1]\n        end_size = dim - indices_or_sections[-1:]\n        chunk_sizes = torch.concat(\n            [start_size, torch.diff(indices_or_sections), end_size], dim=0\n        )\n        # torch.split doesn't support tensor input for `split_size_or_sections`\n        chunk_sizes = chunk_sizes.tolist()\n    else:\n        if dim % indices_or_sections != 0:\n            raise ValueError(\n                f\"Received indices_or_sections={indices_or_sections} \"\n                f\"(interpreted as a number of sections) and axis={axis}, \"\n                f\"but 
input dimension x.shape[{axis}]={x.shape[axis]} \"\n                f\"is not divisible by {indices_or_sections}. \"\n                f\"Full input shape: x.shape={x.shape}\"\n            )\n        chunk_sizes = dim // indices_or_sections\n    out = torch.split(\n        tensor=x,\n        split_size_or_sections=chunk_sizes,\n        dim=axis,\n    )\n    if dim == 0 and isinstance(indices_or_sections, int):\n        out = [out[0].clone() for _ in range(indices_or_sections)]\n    return list(out)\n\n\ndef array_split(x, indices_or_sections, axis=0):\n    x = convert_to_tensor(x)\n    out = torch.tensor_split(x, indices_or_sections, dim=axis)\n    return list(out)\n\n\ndef stack(x, axis=0):\n    x = [convert_to_tensor(elem) for elem in x]\n    return torch.stack(x, dim=axis)\n\n\ndef std(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    if \"int\" in ori_dtype or ori_dtype == \"bool\":\n        x = cast(x, \"float32\")\n    # Remove Bessel correction to align with numpy\n    return torch.std(x, dim=axis, keepdim=keepdims, unbiased=False)\n\n\ndef swapaxes(x, axis1, axis2):\n    x = convert_to_tensor(x)\n    return torch.swapaxes(x, axis0=axis1, axis1=axis2)\n\n\ndef take(x, indices, axis=None):\n    x = convert_to_tensor(x)\n    indices = convert_to_tensor(indices).long()\n    # Correct the indices using \"fill\" mode which is the same as in jax\n    x_dim = x.shape[axis] if axis is not None else x.shape[0]\n    indices = torch.where(\n        indices < 0,\n        indices + x_dim,\n        indices,\n    )\n    if x.ndim == 2 and axis == 0:\n        # This case is equivalent to embedding lookup.\n        return torch.nn.functional.embedding(indices, x)\n    if axis is None:\n        x = torch.reshape(x, (-1,))\n        axis = 0\n    if axis is not None:\n        axis = canonicalize_axis(axis, x.ndim)\n        shape = x.shape[:axis] + indices.shape + x.shape[axis + 1 :]\n        # ravel the `indices` since 
`index_select` expects `indices`\n        # to be a vector (1-D tensor).\n        indices = indices.ravel()\n        out = torch.index_select(x, dim=axis, index=indices).squeeze(axis)\n        return out.reshape(shape)\n    return torch.take(x, index=indices)\n\n\ndef take_along_axis(x, indices, axis=None):\n    x = convert_to_tensor(x)\n    indices = convert_to_tensor(indices).long()\n    # Correct the indices using \"fill\" mode which is the same as in jax\n    x_dim = x.shape[axis] if axis is not None else x.shape[0]\n    indices = torch.where(\n        indices < 0,\n        indices + x_dim,\n        indices,\n    )\n    return torch.take_along_dim(x, indices, dim=axis)\n\n\ndef tan(x):\n    x = convert_to_tensor(x)\n    return torch.tan(x)\n\n\ndef tanh(x):\n    x = convert_to_tensor(x)\n    return torch.tanh(x)\n\n\ndef tensordot(x1, x2, axes=2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    result_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    # TODO: torch.tensordot only supports float types\n    compute_dtype = dtypes.result_type(result_dtype, float)\n    # TODO: torch.tensordot doesn't support float16 with cpu\n    if get_device() == \"cpu\" and compute_dtype == \"float16\":\n        compute_dtype = \"float32\"\n    x1 = cast(x1, compute_dtype)\n    x2 = cast(x2, compute_dtype)\n    # torch only handles dims=((0,), (1,)), numpy accepts axes=(0, 1).\n    if isinstance(axes, (list, tuple)):\n        first, second = axes\n        if not isinstance(first, (list, tuple)):\n            first = (first,)\n        if not isinstance(second, (list, tuple)):\n            second = (second,)\n        axes = (first, second)\n    return cast(torch.tensordot(x1, x2, dims=axes), result_dtype)\n\n\ndef round(x, decimals=0):\n    x = convert_to_tensor(x)\n    ori_dtype = standardize_dtype(x.dtype)\n    # TODO: torch.round doesn't support int8, int16, int32, int64, uint8\n    if \"int\" in ori_dtype:\n        x = cast(x, config.floatx())\n        
return cast(torch.round(x, decimals=decimals), ori_dtype)\n    return torch.round(x, decimals=decimals)\n\n\ndef tile(x, repeats):\n    if is_tensor(repeats):\n        repeats = tuple(repeats.int().numpy())\n    if isinstance(repeats, int):\n        repeats = (repeats,)\n    x = convert_to_tensor(x)\n    return torch.tile(x, dims=repeats)\n\n\ndef trace(x, offset=0, axis1=0, axis2=1):\n    x = convert_to_tensor(x)\n    dtype = standardize_dtype(x.dtype)\n    if dtype in (\"bool\", \"int8\", \"int16\", \"uint8\"):\n        # Torch backend doesn't support uint32 dtype.\n        dtype = \"int32\"\n    return torch.sum(\n        torch.diagonal(x, offset, axis1, axis2),\n        dim=-1,\n        dtype=to_torch_dtype(dtype),\n    )\n\n\ndef tri(N, M=None, k=0, dtype=None):\n    dtype = to_torch_dtype(dtype or config.floatx())\n    M = M or N\n    x = torch.ones((N, M), dtype=dtype, device=get_device())\n    return torch.tril(x, diagonal=k)\n\n\ndef tril(x, k=0):\n    x = convert_to_tensor(x)\n    return torch.tril(x, diagonal=k)\n\n\ndef triu(x, k=0):\n    x = convert_to_tensor(x)\n    return torch.triu(x, diagonal=k)\n\n\ndef trunc(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"bool\":\n        return x\n    return torch.trunc(x)\n\n\ndef vdot(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    result_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    # TODO: torch.vdot only supports float types\n    compute_dtype = dtypes.result_type(result_dtype, float)\n\n    # TODO: torch.vdot doesn't support float16 with cpu\n    if get_device() == \"cpu\" and compute_dtype == \"float16\":\n        compute_dtype = \"float32\"\n\n    x1 = cast(x1, compute_dtype)\n    x2 = cast(x2, compute_dtype)\n    return cast(torch.vdot(x1, x2), result_dtype)\n\n\ndef inner(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n    result_dtype = dtypes.result_type(x1.dtype, x2.dtype)\n    compute_dtype = 
dtypes.result_type(result_dtype, float)\n\n    if get_device() == \"cpu\" and compute_dtype == \"float16\":\n        compute_dtype = \"float32\"\n\n    x1 = cast(x1, compute_dtype)\n    x2 = cast(x2, compute_dtype)\n    return cast(torch.inner(x1, x2), result_dtype)\n\n\ndef vstack(xs):\n    xs = [convert_to_tensor(x) for x in xs]\n    return torch.vstack(xs)\n\n\ndef vsplit(x, indices_or_sections):\n    x = convert_to_tensor(x)\n    if not isinstance(indices_or_sections, int):\n        indices_or_sections = convert_to_tensor(indices_or_sections).tolist()\n    return list(torch.vsplit(x, indices_or_sections))\n\n\ndef vectorize(pyfunc, *, excluded=None, signature=None):\n    return vectorize_impl(\n        pyfunc, torch.vmap, excluded=excluded, signature=signature\n    )\n\n\ndef where(condition, x1=None, x2=None):\n    condition = convert_to_tensor(condition, dtype=bool)\n    if x1 is not None and x2 is not None:\n        x1 = convert_to_tensor(x1)\n        x2 = convert_to_tensor(x2)\n        return torch.where(condition, x1, x2)\n    else:\n        return torch.where(condition)\n\n\ndef divide(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    return torch.divide(x1, x2)\n\n\ndef divide_no_nan(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    return torch.where(x2 == 0, 0, torch.divide(x1, x2))\n\n\ndef true_divide(x1, x2):\n    return divide(x1, x2)\n\n\ndef power(x1, x2):\n    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)\n    return torch.pow(x1, x2)\n\n\ndef negative(x):\n    x = convert_to_tensor(x)\n    return torch.negative(x)\n\n\ndef nextafter(x1, x2):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(x1.dtype, x2.dtype, float)\n    x1 = cast(x1, torch.float64)\n   
 x2 = cast(x2, torch.float64)\n    return cast(torch.nextafter(x1, x2), dtype)\n\n\ndef square(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"bool\":\n        x = cast(x, \"int32\")\n    return torch.square(x)\n\n\ndef sqrt(x):\n    x = convert_to_tensor(x)\n    if standardize_dtype(x.dtype) == \"int64\":\n        x = cast(x, config.floatx())\n    return torch.sqrt(x)\n\n\ndef squeeze(x, axis=None):\n    x = convert_to_tensor(x)\n    if axis is not None:\n        return torch.squeeze(x, dim=axis)\n    return torch.squeeze(x)\n\n\ndef transpose(x, axes=None):\n    x = convert_to_tensor(x)\n    if axes is not None:\n        return torch.permute(x, dims=axes)\n    return x.T\n\n\ndef trapezoid(y, x=None, dx=1.0, axis=-1):\n    y = convert_to_tensor(y)\n    if standardize_dtype(y.dtype) == \"bool\":\n        y = cast(y, config.floatx())\n    if x is not None:\n        x = convert_to_tensor(x)\n        return torch.trapz(y, x=x, dim=axis)\n    else:\n        dx = convert_to_tensor(dx)\n        return torch.trapz(y, dx=dx, dim=axis)\n\n\ndef vander(x, N=None, increasing=False):\n    x = convert_to_tensor(x)\n    result_dtype = dtypes.result_type(x.dtype)\n    return cast(torch.vander(x, N=N, increasing=increasing), result_dtype)\n\n\ndef var(x, axis=None, keepdims=False):\n    x = convert_to_tensor(x)\n    compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n    result_dtype = dtypes.result_type(x.dtype, float)\n    if axis == [] or axis == ():\n        # Torch handles the empty axis case differently from numpy.\n        return zeros_like(x, result_dtype)\n    # Bessel correction removed for numpy compatibility\n    x = cast(x, compute_dtype)\n    return cast(\n        torch.var(x, dim=axis, keepdim=keepdims, correction=0), result_dtype\n    )\n\n\ndef sum(x, axis=None, keepdims=False):\n    if isinstance(x, (list, tuple)):\n        x = stack(x)\n    x = convert_to_tensor(x)\n    if axis == () or axis == []:\n        # Torch handles the 
empty axis case differently from numpy.\n        return x\n    dtype = standardize_dtype(x.dtype)\n    # follow jax's rule\n    # TODO: torch doesn't support uint32\n    if dtype in (\"bool\", \"uint8\", \"int8\", \"int16\"):\n        dtype = \"int32\"\n    if axis is not None:\n        return cast(torch.sum(x, axis=axis, keepdim=keepdims), dtype)\n    return cast(torch.sum(x), dtype)\n\n\ndef eye(N, M=None, k=0, dtype=None):\n    dtype = to_torch_dtype(dtype or config.floatx())\n    M = N if M is None else M\n    k = 0 if k is None else k\n    if k == 0:\n        # TODO: torch.eye doesn't support bfloat16 with cpu\n        if get_device() == \"cpu\" and dtype == torch.bfloat16:\n            return cast(\n                torch.eye(\n                    N, M, dtype=to_torch_dtype(\"float32\"), device=get_device()\n                ),\n                dtype,\n            )\n        return torch.eye(N, M, dtype=dtype, device=get_device())\n    diag_length = builtins.max(N, M)\n    diag = torch.ones(diag_length, dtype=dtype, device=get_device())\n    return torch.diag(diag, diagonal=k)[:N, :M]\n\n\ndef floor_divide(x1, x2):\n    if not isinstance(x1, (int, float)):\n        x1 = convert_to_tensor(x1)\n    if not isinstance(x2, (int, float)):\n        x2 = convert_to_tensor(x2)\n    dtype = dtypes.result_type(\n        getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    return cast(torch.floor_divide(x1, x2), dtype)\n\n\ndef logical_xor(x1, x2):\n    x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)\n    return torch.logical_xor(x1, x2)\n\n\ndef corrcoef(x):\n    x = convert_to_tensor(x)\n\n    if standardize_dtype(x.dtype) == \"bool\":\n        x = cast(x, config.floatx())\n    elif standardize_dtype(x.dtype) == \"int64\":\n        x = cast(x, \"float64\")\n\n    return torch.corrcoef(x)\n\n\ndef correlate(x1, x2, mode=\"valid\"):\n    x1 = convert_to_tensor(x1)\n    x2 = convert_to_tensor(x2)\n\n    dtype = dtypes.result_type(\n 
       getattr(x1, \"dtype\", type(x1)),\n        getattr(x2, \"dtype\", type(x2)),\n    )\n    if dtype == \"int64\":\n        dtype = \"float64\"\n    elif dtype not in [\"bfloat16\", \"float16\", \"float64\"]:\n        dtype = \"float32\"\n\n    x1 = cast(x1, dtype)\n    x2 = cast(x2, dtype)\n\n    x1_len, x2_len = x1.size(0), x2.size(0)\n\n    if x1.shape[:-1] != x2.shape[:-1]:\n        new_shape = [max(i, j) for i, j in zip(x1.shape[:-1], x2.shape[:-1])]\n        x1 = torch.broadcast_to(x1, new_shape + [x1.shape[-1]])\n        x2 = torch.broadcast_to(x2, new_shape + [x2.shape[-1]])\n\n    num_signals = torch.tensor(x1.shape[:-1]).prod()\n    x1 = torch.reshape(x1, (int(num_signals), x1.size(-1)))\n    x2 = torch.reshape(x2, (int(num_signals), x2.size(-1)))\n\n    output = torch.nn.functional.conv1d(\n        x1, x2.unsqueeze(1), groups=x1.size(0), padding=x2.size(-1) - 1\n    )\n    output_shape = x1.shape[:-1] + (-1,)\n    result = output.reshape(output_shape)\n\n    if mode == \"valid\":\n        target_length = (\n            builtins.max(x1_len, x2_len) - builtins.min(x1_len, x2_len) + 1\n        )\n        start_idx = (result.size(-1) - target_length) // 2\n        result = result[..., start_idx : start_idx + target_length]\n\n    if mode == \"same\":\n        start_idx = (result.size(-1) - x1_len) // 2\n        result = result[..., start_idx : start_idx + x1_len]\n\n    return torch.squeeze(result)\n\n\ndef select(condlist, choicelist, default=0):\n    condlist = [convert_to_tensor(c) for c in condlist]\n    choicelist = [convert_to_tensor(c) for c in choicelist]\n    out = convert_to_tensor(default)\n    for c, v in reversed(list(zip(condlist, choicelist))):\n        out = torch.where(c, v, out)\n    return out\n\n\ndef slogdet(x):\n    x = convert_to_tensor(x)\n    return tuple(torch.linalg.slogdet(x))\n\n\ndef argpartition(x, kth, axis=-1):\n    x = convert_to_tensor(x, \"int32\")\n    x = torch.transpose(x, axis, -1)\n    bottom_ind = torch.topk(-x, 
kth + 1)[1]\n\n    def set_to_zero(a, i):\n        a[i] = torch.zeros(1, dtype=a.dtype, device=a.device)\n        return a\n\n    for _ in range(x.dim() - 1):\n        set_to_zero = torch.vmap(set_to_zero)\n    proxy = set_to_zero(torch.ones_like(x, dtype=torch.int32), bottom_ind)\n    top_ind = torch.topk(proxy, x.shape[-1] - kth - 1)[1]\n    out = torch.cat([bottom_ind, top_ind], dim=x.dim() - 1)\n    return cast(torch.transpose(out, -1, axis), \"int32\")\n\n\ndef histogram(x, bins=10, range=None):\n    hist_result = torch.histogram(x, bins=bins, range=range)\n    return hist_result.hist, hist_result.bin_edges\n"
  },
  {
    "path": "keras/src/backend/torch/optimizers/__init__.py",
    "content": "from keras.src.backend.torch.optimizers.torch_optimizer import TorchOptimizer\n"
  },
  {
    "path": "keras/src/backend/torch/optimizers/torch_adadelta.py",
    "content": "import torch\n\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src.backend.torch.optimizers import torch_parallel_optimizer\n\n\nclass Adadelta(\n    torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Adadelta\n):\n    def _parallel_update_step(\n        self,\n        grads,\n        variables,\n        learning_rate,\n    ):\n        keras_variables = variables\n        variables = [v.value for v in variables]\n\n        dtype = variables[0].dtype\n        lr = ops.cast(learning_rate, dtype)\n        rho = self.rho\n\n        accumulated_grads = [\n            self._accumulated_grads[self._get_variable_index(variable)].value\n            for variable in keras_variables\n        ]\n        accumulated_delta_vars = [\n            self._accumulated_delta_vars[\n                self._get_variable_index(variable)\n            ].value\n            for variable in keras_variables\n        ]\n        torch._foreach_mul_(accumulated_grads, rho)\n        torch._foreach_add_(\n            accumulated_grads, torch._foreach_mul(grads, grads), alpha=1 - rho\n        )\n\n        def rms(x):\n            return torch._foreach_sqrt(torch._foreach_add(x, self.epsilon))\n\n        delta_vars = torch._foreach_mul(\n            torch._foreach_div(\n                torch._foreach_mul(rms(accumulated_delta_vars), grads),\n                rms(accumulated_grads),\n            ),\n            -1,\n        )\n        torch._foreach_mul_(accumulated_delta_vars, rho)\n        torch._foreach_add_(\n            accumulated_delta_vars,\n            torch._foreach_mul(delta_vars, delta_vars),\n            alpha=1 - rho,\n        )\n\n        torch._foreach_add_(variables, delta_vars, alpha=lr)\n"
  },
  {
    "path": "keras/src/backend/torch/optimizers/torch_adagrad.py",
    "content": "import torch\n\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src.backend.torch.optimizers import torch_parallel_optimizer\n\n\nclass Adagrad(\n    torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Adagrad\n):\n    def _parallel_update_step(\n        self,\n        grads,\n        variables,\n        learning_rate,\n    ):\n        keras_variables = variables\n        variables = [v.value for v in variables]\n\n        dtype = variables[0].dtype\n        lr = ops.cast(learning_rate, dtype)\n\n        accumulators = [\n            self._accumulators[self._get_variable_index(variable)].value\n            for variable in keras_variables\n        ]\n        torch._foreach_add_(accumulators, torch._foreach_mul(grads, grads))\n        torch._foreach_add_(\n            variables,\n            torch._foreach_div(\n                torch._foreach_mul(grads, lr),\n                torch._foreach_sqrt(\n                    torch._foreach_add(accumulators, self.epsilon)\n                ),\n            ),\n            alpha=-1,\n        )\n"
  },
  {
    "path": "keras/src/backend/torch/optimizers/torch_adam.py",
    "content": "import torch\n\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src.backend.torch.optimizers import torch_parallel_optimizer\n\n\nclass Adam(torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Adam):\n    def _parallel_update_step(\n        self,\n        grads,\n        variables,\n        learning_rate,\n    ):\n        keras_variables = variables\n        variables = [v.value for v in variables]\n\n        dtype = variables[0].dtype\n        lr = ops.cast(learning_rate, dtype)\n        local_step = ops.cast(self.iterations + 1, dtype)\n\n        beta_1_power = ops.power(ops.cast(self.beta_1, dtype), local_step)\n        beta_2_power = ops.power(ops.cast(self.beta_2, dtype), local_step)\n        alpha = lr * ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)\n\n        m_list = [\n            self._momentums[self._get_variable_index(variable)].value\n            for variable in keras_variables\n        ]\n        v_list = [\n            self._velocities[self._get_variable_index(variable)].value\n            for variable in keras_variables\n        ]\n\n        torch._foreach_mul_(m_list, self.beta_1)\n        torch._foreach_add_(m_list, grads, alpha=1 - self.beta_1)\n\n        torch._foreach_mul_(v_list, self.beta_2)\n        torch._foreach_add_(\n            v_list, torch._foreach_mul(grads, grads), alpha=1 - self.beta_2\n        )\n\n        if self.amsgrad:\n            v_hat_list = [\n                self._velocity_hats[self._get_variable_index(variable)].value\n                for variable in keras_variables\n            ]\n            torch._foreach_maximum_(v_hat_list, v_list)\n            v_list = v_hat_list\n\n        torch._foreach_add_(\n            variables,\n            torch._foreach_div(\n                torch._foreach_mul(m_list, alpha),\n                torch._foreach_add(torch._foreach_sqrt(v_list), self.epsilon),\n            ),\n            alpha=-1,\n        )\n"
  },
  {
    "path": "keras/src/backend/torch/optimizers/torch_adamax.py",
    "content": "import torch\n\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src.backend.torch.optimizers import torch_parallel_optimizer\n\n\nclass Adamax(\n    torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Adamax\n):\n    def _parallel_update_step(\n        self,\n        grads,\n        variables,\n        learning_rate,\n    ):\n        keras_variables = variables\n        variables = [v.value for v in variables]\n\n        dtype = variables[0].dtype\n        lr = ops.cast(learning_rate, dtype)\n\n        local_step = ops.cast(self.iterations + 1, dtype)\n\n        beta_1_power = ops.power(ops.cast(self.beta_1, dtype), local_step)\n\n        m_list = [\n            self._m[self._get_variable_index(variable)].value\n            for variable in keras_variables\n        ]\n        u_list = [\n            self._u[self._get_variable_index(variable)].value\n            for variable in keras_variables\n        ]\n\n        torch._foreach_mul_(m_list, self.beta_1)\n        torch._foreach_add_(m_list, grads, alpha=1 - self.beta_1)\n\n        torch._foreach_mul_(u_list, self.beta_2)\n        torch._foreach_maximum_(u_list, torch._foreach_abs(grads))\n\n        torch._foreach_add_(\n            variables,\n            torch._foreach_div(\n                torch._foreach_mul(m_list, lr),\n                torch._foreach_mul(\n                    torch._foreach_add(u_list, self.epsilon),\n                    1 - beta_1_power,\n                ),\n            ),\n            alpha=-1,\n        )\n"
  },
  {
    "path": "keras/src/backend/torch/optimizers/torch_adamw.py",
    "content": "from keras.src import optimizers\nfrom keras.src.backend.torch.optimizers import torch_adam\n\n\nclass AdamW(torch_adam.Adam, optimizers.AdamW):\n    pass\n"
  },
  {
    "path": "keras/src/backend/torch/optimizers/torch_lion.py",
    "content": "import torch\n\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src.backend.torch.optimizers import torch_parallel_optimizer\n\n\nclass Lion(torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Lion):\n    def _parallel_update_step(\n        self,\n        grads,\n        variables,\n        learning_rate,\n    ):\n        keras_variables = variables\n        variables = [v.value for v in variables]\n\n        dtype = variables[0].dtype\n        lr = ops.cast(learning_rate, dtype)\n\n        m_list = [\n            self._momentums[self._get_variable_index(variable)].value\n            for variable in keras_variables\n        ]\n\n        c_t = torch._foreach_mul(m_list, self.beta_1)\n        torch._foreach_add_(c_t, grads, alpha=1 - self.beta_1)\n        c_t = [c.sign() for c in c_t]\n\n        torch._foreach_add_(\n            variables,\n            torch._foreach_mul(c_t, lr),\n            alpha=-1,\n        )\n\n        torch._foreach_mul_(m_list, self.beta_2)\n        torch._foreach_add_(m_list, grads, alpha=1 - self.beta_2)\n"
  },
  {
    "path": "keras/src/backend/torch/optimizers/torch_nadam.py",
    "content": "import torch\n\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src.backend.torch import core\nfrom keras.src.backend.torch.optimizers import torch_parallel_optimizer\n\n\nclass Nadam(torch_parallel_optimizer.TorchParallelOptimizer, optimizers.Nadam):\n    def _parallel_update_step(\n        self,\n        grads,\n        variables,\n        learning_rate,\n    ):\n        keras_variables = variables\n        variables = [v.value for v in variables]\n\n        dtype = variables[0].dtype\n        lr = ops.cast(learning_rate, dtype)\n\n        local_step = ops.cast(self.iterations + 1, dtype)\n        next_step = ops.cast(self.iterations + 2, dtype)\n        decay = ops.cast(0.96, dtype)\n        beta_1 = ops.cast(self.beta_1, dtype)\n        beta_2 = ops.cast(self.beta_2, dtype)\n        u_t = beta_1 * (1.0 - 0.5 * (ops.power(decay, local_step)))\n        u_t_1 = beta_1 * (1.0 - 0.5 * (ops.power(decay, next_step)))\n        u_product_t = self._u_product.value * u_t\n        u_product_t_1 = u_product_t * u_t_1\n        beta_2_power = ops.power(beta_2, local_step)\n\n        self._u_product.assign(u_product_t)\n\n        m_list = [\n            self._momentums[self._get_variable_index(variable)].value\n            for variable in keras_variables\n        ]\n        v_list = [\n            self._velocities[self._get_variable_index(variable)].value\n            for variable in keras_variables\n        ]\n\n        torch._foreach_mul_(m_list, self.beta_1)\n        torch._foreach_add_(m_list, grads, alpha=1 - self.beta_1)\n\n        torch._foreach_mul_(v_list, self.beta_2)\n        torch._foreach_add_(\n            v_list, torch._foreach_mul(grads, grads), alpha=1 - self.beta_2\n        )\n\n        m_hat_list = torch._foreach_add(\n            torch._foreach_div(\n                torch._foreach_mul(m_list, u_t_1),\n                1 - core.convert_to_numpy(u_product_t_1),\n            ),\n            torch._foreach_div(\n           
     torch._foreach_mul(grads, 1 - u_t),\n                1 - core.convert_to_numpy(u_product_t),\n            ),\n        )\n\n        v_hat_list = torch._foreach_div(v_list, 1 - beta_2_power)\n\n        torch._foreach_add_(\n            variables,\n            torch._foreach_div(\n                torch._foreach_mul(m_hat_list, lr),\n                torch._foreach_add(\n                    torch._foreach_sqrt(v_hat_list), self.epsilon\n                ),\n            ),\n            alpha=-1,\n        )\n"
  },
  {
    "path": "keras/src/backend/torch/optimizers/torch_optimizer.py",
    "content": "import torch\n\nfrom keras.src import optimizers\nfrom keras.src.optimizers.base_optimizer import BaseOptimizer\nfrom keras.src.utils import torch_utils\n\n\nclass TorchOptimizer(BaseOptimizer):\n    def __new__(cls, *args, **kwargs):\n        # Import locally to avoid circular imports.\n        from keras.src.backend.torch.optimizers import torch_adadelta\n        from keras.src.backend.torch.optimizers import torch_adagrad\n        from keras.src.backend.torch.optimizers import torch_adam\n        from keras.src.backend.torch.optimizers import torch_adamax\n        from keras.src.backend.torch.optimizers import torch_adamw\n        from keras.src.backend.torch.optimizers import torch_lion\n        from keras.src.backend.torch.optimizers import torch_nadam\n        from keras.src.backend.torch.optimizers import torch_rmsprop\n        from keras.src.backend.torch.optimizers import torch_sgd\n\n        OPTIMIZERS = {\n            optimizers.Adadelta: torch_adadelta.Adadelta,\n            optimizers.Adagrad: torch_adagrad.Adagrad,\n            optimizers.Adam: torch_adam.Adam,\n            optimizers.Adamax: torch_adamax.Adamax,\n            optimizers.AdamW: torch_adamw.AdamW,\n            optimizers.Lion: torch_lion.Lion,\n            optimizers.Nadam: torch_nadam.Nadam,\n            optimizers.RMSprop: torch_rmsprop.RMSprop,\n            optimizers.SGD: torch_sgd.SGD,\n        }\n\n        if cls in OPTIMIZERS:\n            return OPTIMIZERS[cls](*args, **kwargs)\n        return super().__new__(cls)\n\n    @torch_utils.no_grad\n    def _apply_weight_decay(self, variables):\n        if self.weight_decay is None:\n            return\n\n        torch._foreach_mul_(\n            [v.value for v in variables if self._use_weight_decay(v)],\n            1 - self.weight_decay * self._get_current_learning_rate(),\n        )\n"
  },
  {
    "path": "keras/src/backend/torch/optimizers/torch_parallel_optimizer.py",
    "content": "import torch\n\nfrom keras.src.optimizers.base_optimizer import BaseOptimizer\nfrom keras.src.utils import torch_utils\n\n\nclass TorchParallelOptimizer(BaseOptimizer):\n    @torch_utils.no_grad\n    def _backend_update_step(self, grads, trainable_variables, learning_rate):\n        self._parallel_update_step(\n            grads,\n            trainable_variables,\n            learning_rate,\n        )\n\n    @torch_utils.no_grad\n    def _backend_reset_gradient_accumulators(self):\n        acc_list = [\n            v.value for v in self._accumulated_gradients if v is not None\n        ]\n        torch._foreach_mul_(acc_list, 0.0)\n\n    @torch_utils.no_grad\n    def _backend_increment_gradient_accumulators(self, grads, acc_grads):\n        acc_list = [v.value for v in acc_grads]\n        torch._foreach_add_(acc_list, grads, alpha=1.0)\n"
  },
  {
    "path": "keras/src/backend/torch/optimizers/torch_rmsprop.py",
    "content": "import torch\n\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src.backend.torch.optimizers import torch_parallel_optimizer\n\n\nclass RMSprop(\n    torch_parallel_optimizer.TorchParallelOptimizer, optimizers.RMSprop\n):\n    def _parallel_update_step(\n        self,\n        grads,\n        variables,\n        learning_rate,\n    ):\n        keras_variables = variables\n        variables = [v.value for v in variables]\n\n        dtype = variables[0].dtype\n        lr = ops.cast(learning_rate, dtype)\n\n        velocities = [\n            self._velocities[self._get_variable_index(variable)].value\n            for variable in keras_variables\n        ]\n\n        rho = self.rho\n\n        torch._foreach_mul_(velocities, rho)\n        torch._foreach_add_(\n            velocities, torch._foreach_mul(grads, grads), alpha=1 - rho\n        )\n\n        denominators = torch._foreach_add(velocities, self.epsilon)\n        if self.centered:\n            average_grads = [\n                self._average_gradients[\n                    self._get_variable_index(variable)\n                ].value\n                for variable in keras_variables\n            ]\n            torch._foreach_mul_(average_grads, rho)\n            torch._foreach_add_(average_grads, grads, alpha=1 - rho)\n            torch._foreach_add_(\n                denominators,\n                torch._foreach_mul(average_grads, average_grads),\n                alpha=-1,\n            )\n        torch._foreach_sqrt_(denominators)\n        increments = torch._foreach_div(\n            torch._foreach_mul(grads, lr), denominators\n        )\n\n        if self.momentum > 0:\n            momentum_list = [\n                self._momentums[self._get_variable_index(variable)].value\n                for variable in keras_variables\n            ]\n            torch._foreach_mul_(momentum_list, self.momentum)\n            torch._foreach_add_(momentum_list, increments)\n            
torch._foreach_add_(variables, momentum_list, alpha=-1)\n        else:\n            torch._foreach_add_(variables, increments, alpha=-1)\n"
  },
  {
    "path": "keras/src/backend/torch/optimizers/torch_sgd.py",
    "content": "import torch\n\nfrom keras.src import optimizers\nfrom keras.src.backend.torch.optimizers import torch_parallel_optimizer\n\n\nclass SGD(torch_parallel_optimizer.TorchParallelOptimizer, optimizers.SGD):\n    def _parallel_update_step(\n        self,\n        grads,\n        variables,\n        learning_rate,\n    ):\n        keras_variables = variables\n        variables = [v.value for v in variables]\n        if self.momentum != 0:\n            bufs = [\n                self.momentums[self._get_variable_index(variable)].value\n                for variable in keras_variables\n            ]\n\n            for i in range(len(bufs)):\n                if bufs[i] is None:\n                    bufs[i] = torch.clone(grads[i]).detach()\n\n            torch._foreach_mul_(bufs, self.momentum)\n            torch._foreach_add_(bufs, grads, alpha=-learning_rate)\n\n            if self.nesterov:\n                torch._foreach_add_(variables, grads, alpha=-learning_rate)\n                torch._foreach_add_(variables, bufs, alpha=self.momentum)\n            else:\n                torch._foreach_add_(variables, bufs)\n\n        else:\n            torch._foreach_add_(variables, grads, alpha=-learning_rate)\n"
  },
  {
    "path": "keras/src/backend/torch/random.py",
    "content": "import torch\nimport torch._dynamo as dynamo\nimport torch.nn.functional as tnn\n\nfrom keras.src.backend.config import floatx\nfrom keras.src.backend.torch.core import convert_to_tensor\nfrom keras.src.backend.torch.core import get_device\nfrom keras.src.backend.torch.core import to_torch_dtype\nfrom keras.src.random.seed_generator import SeedGenerator\nfrom keras.src.random.seed_generator import draw_seed\nfrom keras.src.random.seed_generator import make_default_seed\n\n\n# torch.Generator not supported with dynamo\n# see: https://github.com/pytorch/pytorch/issues/88576\n@dynamo.disable()\ndef torch_seed_generator(seed):\n    device = get_device()\n    if device == \"meta\":\n        # Generator is not supported by the meta device.\n        return None\n    generator = torch.Generator(device=get_device())\n    first_seed, second_seed = draw_seed(seed)\n    # Re-interpret as uint32 and combine; the counter in second_seed ensures\n    # each SeedGenerator call produces a distinct seed.\n    generator.manual_seed(int(first_seed + second_seed) & 0xFFFFFFFF)\n    return generator\n\n\ndef normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    dtype = to_torch_dtype(dtype)\n    # Do not use generator during symbolic execution.\n    if get_device() == \"meta\":\n        return torch.normal(\n            mean, stddev, size=shape, dtype=dtype, device=get_device()\n        )\n    generator = torch_seed_generator(seed)\n    return torch.normal(\n        mean,\n        stddev,\n        size=shape,\n        generator=generator,\n        dtype=dtype,\n        device=get_device(),\n    )\n\n\ndef categorical(logits, num_samples, dtype=\"int32\", seed=None):\n    logits = convert_to_tensor(logits)\n    dtype = to_torch_dtype(dtype)\n    probs = torch.softmax(logits, dim=-1)\n    # Do not use generator during symbolic execution.\n    if get_device() == \"meta\":\n        return torch.multinomial(\n            probs,\n       
     num_samples,\n            replacement=True,\n        ).type(dtype)\n    generator = torch_seed_generator(seed)\n    return torch.multinomial(\n        probs,\n        num_samples,\n        replacement=True,\n        generator=generator,\n    ).type(dtype)\n\n\ndef uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    dtype = to_torch_dtype(dtype)\n    requested_shape = shape\n    if len(requested_shape) == 0:\n        shape = (1,)\n    # Do not use generator during symbolic execution.\n    if get_device() == \"meta\":\n        rand_tensor = torch.rand(size=shape, dtype=dtype, device=get_device())\n    else:\n        generator = torch_seed_generator(seed)\n        rand_tensor = torch.rand(\n            size=shape, generator=generator, dtype=dtype, device=get_device()\n        )\n\n    output = (maxval - minval) * rand_tensor + minval\n\n    if len(requested_shape) == 0:\n        return output[0]\n    return output\n\n\ndef randint(shape, minval, maxval, dtype=\"int32\", seed=None):\n    dtype = to_torch_dtype(dtype)\n    # Do not use generator during symbolic execution.\n    if get_device() == \"meta\":\n        return torch.randint(\n            low=minval,\n            high=maxval,\n            size=shape,\n            dtype=dtype,\n            device=get_device(),\n        )\n    generator = torch_seed_generator(seed)\n    return torch.randint(\n        low=minval,\n        high=maxval,\n        size=shape,\n        generator=generator,\n        dtype=dtype,\n        device=get_device(),\n    )\n\n\ndef truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    dtype = to_torch_dtype(dtype)\n    # Take a larger standard normal dist, discard values outside 2 * stddev\n    # Offset by mean and stddev\n    x = normal(tuple(shape) + (4,), mean=0, stddev=1, dtype=dtype, seed=seed)\n    valid = (x > -2) & (x < 2)\n    indexes = valid.max(-1, keepdim=True)[1]\n    trunc_x = torch.empty(shape, dtype=dtype, 
device=get_device())\n    trunc_x.data.copy_(x.gather(-1, indexes).squeeze(-1))\n    trunc_x.data.mul_(stddev).add_(mean)\n    return trunc_x\n\n\ndef _get_concrete_noise_shape(inputs, noise_shape):\n    if noise_shape is None:\n        return inputs.shape\n\n    concrete_inputs_shape = inputs.shape\n    concrete_noise_shape = []\n    for i, value in enumerate(noise_shape):\n        concrete_noise_shape.append(\n            concrete_inputs_shape[i] if value is None else value\n        )\n    return concrete_noise_shape\n\n\ndef dropout(inputs, rate, noise_shape=None, seed=None):\n    if rate == 1.0:\n        return torch.zeros_like(inputs, device=get_device())\n    if rate == 0.0:\n        return inputs\n    if (\n        seed is not None\n        and not (isinstance(seed, SeedGenerator) and seed._initial_seed is None)\n        or noise_shape is not None\n    ):\n        keep_prob = 1.0 - rate\n        noise_shape = _get_concrete_noise_shape(inputs, noise_shape)\n        keep_prob_matrix = torch.full(\n            noise_shape, keep_prob, device=get_device()\n        )\n        generator = torch_seed_generator(seed)\n\n        # Do not use generator during symbolic execution.\n        if get_device() == \"meta\":\n            mask = torch.bernoulli(keep_prob_matrix)\n        else:\n            mask = torch.bernoulli(keep_prob_matrix, generator=generator)\n\n        mask = mask.bool()\n        mask = torch.broadcast_to(mask, inputs.shape)\n        return torch.where(\n            mask,\n            inputs / keep_prob,\n            torch.zeros_like(inputs, dtype=inputs.dtype),\n        )\n    # Fast path, unseeded (since torch doesn't support seeding dropout!!!!)\n    # Using the above implementation is possible, but much slower.\n    return torch.nn.functional.dropout(\n        inputs, p=rate, training=True, inplace=False\n    )\n\n\ndef shuffle(x, axis=0, seed=None):\n    # Ref: https://github.com/pytorch/pytorch/issues/71409\n    x = convert_to_tensor(x)\n\n    # 
Get permutation indices\n    # Do not use generator during symbolic execution.\n    if get_device() == \"meta\":\n        row_perm = torch.rand(x.shape[: axis + 1], device=get_device()).argsort(\n            axis\n        )\n    else:\n        generator = torch_seed_generator(seed)\n        row_perm = torch.rand(\n            x.shape[: axis + 1], generator=generator, device=get_device()\n        ).argsort(axis)\n    for _ in range(x.ndim - axis - 1):\n        row_perm.unsqueeze_(-1)\n\n    # Reformat this for the gather operation\n    row_perm = row_perm.repeat(\n        *[1 for _ in range(axis + 1)], *(x.shape[axis + 1 :])\n    )\n    return x.gather(axis, row_perm)\n\n\ndef gamma(shape, alpha, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    dtype = to_torch_dtype(dtype)\n    alpha = torch.broadcast_to(convert_to_tensor(alpha), shape)\n    beta = torch.ones(shape, device=get_device())\n    prev_rng_state = torch.random.get_rng_state()\n    # Do not draw seed during symbolic execution\n    if not get_device() == \"meta\":\n        first_seed, second_seed = draw_seed(seed)\n        torch.manual_seed(int(first_seed + second_seed) & 0xFFFFFFFF)\n    gamma_distribution = torch.distributions.gamma.Gamma(alpha, beta)\n    sample = gamma_distribution.sample().type(dtype)\n    torch.random.set_rng_state(prev_rng_state)\n    return sample\n\n\ndef binomial(shape, counts, probabilities, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    dtype = to_torch_dtype(dtype)\n    counts = torch.broadcast_to(convert_to_tensor(counts), shape)\n    probabilities = torch.broadcast_to(convert_to_tensor(probabilities), shape)\n    prev_rng_state = torch.random.get_rng_state()\n    # Do not draw seed during symbolic execution\n    if not get_device() == \"meta\":\n        first_seed, second_seed = draw_seed(seed)\n        torch.manual_seed(int(first_seed + second_seed) & 0xFFFFFFFF)\n    binomial_distribution = torch.distributions.binomial.Binomial(\n        
total_count=counts, probs=probabilities\n    )\n    sample = binomial_distribution.sample().type(dtype)\n    torch.random.set_rng_state(prev_rng_state)\n    return sample\n\n\ndef beta(shape, alpha, beta, dtype=None, seed=None):\n    dtype = dtype or floatx()\n    dtype = to_torch_dtype(dtype)\n    alpha = torch.broadcast_to(convert_to_tensor(alpha), shape)\n    beta = torch.broadcast_to(convert_to_tensor(beta), shape)\n    prev_rng_state = torch.random.get_rng_state()\n    # Do not draw seed during symbolic execution\n    if not get_device() == \"meta\":\n        first_seed, second_seed = draw_seed(seed)\n        torch.manual_seed(int(first_seed + second_seed) & 0xFFFFFFFF)\n    beta_distribution = torch.distributions.beta.Beta(\n        concentration1=alpha, concentration0=beta\n    )\n    sample = beta_distribution.sample().type(dtype)\n    torch.random.set_rng_state(prev_rng_state)\n    return sample\n"
  },
  {
    "path": "keras/src/backend/torch/rnn.py",
    "content": "import torch\n\nfrom keras.src import tree\nfrom keras.src.backend.torch.core import convert_to_tensor\nfrom keras.src.backend.torch.core import get_device\n\n\ndef rnn(\n    step_function,\n    inputs,\n    initial_states,\n    go_backwards=False,\n    mask=None,\n    constants=None,\n    unroll=False,\n    input_length=None,\n    time_major=False,\n    zero_output_for_mask=False,\n    return_all_outputs=True,\n):\n    input_length = input_length or inputs.shape[1]\n\n    def swap_batch_timestep(input_t):\n        # Swap the batch and timestep dim for the incoming tensor.\n        axes = list(range(len(input_t.shape)))\n        axes[0], axes[1] = 1, 0\n        return torch.permute(input_t, axes)\n\n    if not time_major:\n        inputs = tree.map_structure(swap_batch_timestep, inputs)\n\n    flattened_inputs = tree.flatten(inputs)\n    time_steps = flattened_inputs[0].shape[0]\n    time_steps_t = time_steps\n\n    if mask is not None:\n        if mask.dtype != torch.bool:\n            mask = mask.type(torch.bool)\n        if len(mask.shape) == 2:\n            mask = torch.unsqueeze(mask, -1)\n        if not time_major:\n            mask = swap_batch_timestep(mask)\n\n    if constants is None:\n        constants = []\n\n    def _expand_mask(mask_t, input_t, fixed_dim=1):\n        if tree.is_nested(mask_t):\n            raise ValueError(\n                f\"mask_t is expected to be tensor,\\\n                  but got {mask_t}\"\n            )\n        if tree.is_nested(input_t):\n            raise ValueError(\n                f\"input_t is expected to be tensor,\\\n                  but got {input_t}\"\n            )\n        rank_diff = len(input_t.shape) - len(mask_t.shape)\n        for _ in range(rank_diff):\n            mask_t = torch.unsqueeze(mask_t, -1)\n        multiples = [1] * fixed_dim + list(input_t.shape[fixed_dim:])\n        return torch.tile(mask_t, multiples)\n\n    if unroll:\n        if not time_steps:\n            raise 
ValueError(\"Unrolling requires a fixed number of timesteps.\")\n        states = tuple(initial_states)\n        successive_states = []\n        successive_outputs = []\n\n        # Process the input tensors. The input tensor need to be split on the\n        # time_step dim, and reverse if go_backwards is True. In the case of\n        # nested input, the input is flattened and then transformed\n        # individually.  The result of this will be a tuple of lists, each of\n        # the item in tuple is list of the tensor with shape (batch, feature)\n        def _process_single_input_t(input_t):\n            input_t = torch.unbind(input_t)  # unstack for time_step dim\n            if go_backwards:\n                input_t = input_t[::-1]\n            return input_t\n\n        if tree.is_nested(inputs):\n            processed_input = tree.map_structure(\n                _process_single_input_t, inputs\n            )  # noqa: E501\n        else:\n            processed_input = (_process_single_input_t(inputs),)\n\n        def _get_input_tensor(time):\n            inp = [t_[time] for t_ in processed_input]\n            return tree.pack_sequence_as(inputs, inp)\n\n        if mask is not None:\n            mask_list = torch.unbind(mask)\n            if go_backwards:\n                mask_list = torch.flip(mask_list, dims=mask_list.shape)\n\n            for i in range(time_steps):\n                inp = _get_input_tensor(i)\n                mask_t = mask_list[i]\n                output, new_states = step_function(\n                    inp, tuple(states) + tuple(constants)\n                )\n                tiled_mask_t = _expand_mask(mask_t, output)\n\n                if not successive_outputs:\n                    prev_output = torch.zeros_like(output)\n                else:\n                    prev_output = successive_outputs[-1]\n\n                output = torch.where(tiled_mask_t, output, prev_output)\n\n                flat_states = tree.flatten(states)\n            
    flat_new_states = tree.flatten(new_states)\n                tiled_mask_t = tuple(\n                    _expand_mask(mask_t, s) for s in flat_states\n                )  # noqa: E501\n                flat_final_states = tuple(\n                    torch.where(m, s, ps)\n                    for m, s, ps in zip(\n                        tiled_mask_t, flat_new_states, flat_states\n                    )  # noqa: E501\n                )\n                states = tree.pack_sequence_as(states, flat_final_states)\n\n                if return_all_outputs:\n                    successive_outputs.append(output)\n                    successive_states.append(states)\n                else:\n                    successive_outputs = [output]\n                    successive_states = [states]\n            last_output = successive_outputs[-1]\n            new_states = successive_states[-1]\n            outputs = torch.stack(successive_outputs)\n\n            if zero_output_for_mask:\n                last_output = torch.where(\n                    _expand_mask(mask_list[-1], last_output),\n                    last_output,\n                    torch.zeros_like(last_output),\n                )\n                outputs = torch.where(\n                    _expand_mask(mask, outputs, fixed_dim=2),\n                    outputs,\n                    torch.zeros_like(outputs),\n                )\n\n        else:  # mask is None\n            for i in range(time_steps):\n                inp = _get_input_tensor(i)\n                output, states = step_function(\n                    inp, tuple(states) + tuple(constants)\n                )  # noqa: E501\n                if return_all_outputs:\n                    successive_outputs.append(output)\n                    successive_states.append(states)\n                else:\n                    successive_outputs = [output]\n                    successive_states = [states]\n            last_output = successive_outputs[-1]\n            new_states 
= successive_states[-1]\n            outputs = torch.stack(successive_outputs)\n\n    else:  # Unroll == False\n        states = tuple(initial_states)\n\n        # Create the input tensor array. If the input is a nested structure\n        # of tensors, it will be flattened first, and one tensor array will\n        # be created per flattened tensor.\n\n        input_ta = tuple(\n            (\n                list(torch.unbind(input_))\n                if not go_backwards\n                else list(torch.unbind(torch.flip(input_, [0])))\n            )\n            for input_ in flattened_inputs\n        )\n\n        # Get the time(0) input and compute the output for that.\n        input_time_zero = tree.pack_sequence_as(\n            inputs, [inp[0] for inp in flattened_inputs]\n        )\n        # output_time_zero is used to determine the cell output shape.\n        output_time_zero, _ = step_function(\n            input_time_zero, tuple(initial_states) + tuple(constants)\n        )\n\n        output_ta_size = time_steps_t if return_all_outputs else 1\n        output_ta = []\n        for out in tree.flatten(output_time_zero):\n            out_list = list(out)\n            if len(out) < output_ta_size:\n                out_list.extend([[]] * (output_ta_size - len(out)))\n            output_ta.append(out_list)\n\n        time = torch.tensor(0, dtype=torch.int32)\n\n        if input_length is None:\n            max_iterations = time_steps_t\n        else:\n            if hasattr(input_length, \"__len__\"):\n                input_length = convert_to_tensor(input_length)\n                max_iterations = torch.max(input_length)\n            else:\n                max_iterations = input_length\n\n        if mask is not None:\n            if go_backwards:\n                mask = torch.flip(mask, [0])\n\n            mask_ta = list(torch.unbind(mask))\n\n            def masking_fn(time):\n                return mask_ta[time]\n\n            def compute_masked_output(mask_t, flat_out, 
flat_mask):\n                tiled_mask_t = tuple(\n                    _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))\n                    for o in flat_out\n                )\n                return tuple(\n                    torch.where(m, o, fm)\n                    for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)\n                )\n\n        elif isinstance(input_length, torch.Tensor):\n            if go_backwards:\n                max_len = torch.max(input_length, dim=0)\n                if isinstance(max_len, torch.return_types.max):\n                    max_len = max_len[0]\n                rev_input_length = torch.subtract(max_len - 1, input_length)\n\n                def masking_fn(time):\n                    return torch.less(rev_input_length, time)\n\n            else:\n\n                def masking_fn(time):\n                    return torch.greater(input_length, time)\n\n            def compute_masked_output(mask_t, flat_out, flat_mask):\n                return tuple(\n                    torch.where(mask_t, o, zo)\n                    for (o, zo) in zip(flat_out, flat_mask)  # noqa: E501\n                )\n\n        else:\n            masking_fn = None\n\n        if masking_fn is not None:\n            # Mask for the T output will be based on the output of T - 1. 
In the\n            # case T = 0, a zero filled tensor will be used.\n            flat_zero_output = tuple(\n                torch.zeros_like(o) for o in tree.flatten(output_time_zero)\n            )\n\n            def _step(time, output_ta_t, prev_output, *states):\n                \"\"\"RNN step function.\n\n                Args:\n                    time: Current timestep value.\n                    output_ta_t: TensorArray.\n                    prev_output: tuple of outputs from time - 1.\n                    *states: List of states.\n\n                Returns:\n                    Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`\n                \"\"\"\n                current_input = tuple(ta[time] for ta in input_ta)\n                # maybe set shape.\n                current_input = tree.pack_sequence_as(inputs, current_input)\n                mask_t = masking_fn(time)\n                output, new_states = step_function(\n                    current_input, tuple(states) + tuple(constants)\n                )\n                # mask output\n                flat_output = tree.flatten(output)\n                flat_mask_output = (\n                    flat_zero_output\n                    if zero_output_for_mask\n                    else tree.flatten(prev_output)\n                )\n                flat_new_output = compute_masked_output(\n                    mask_t, flat_output, flat_mask_output\n                )\n\n                # mask states\n                flat_state = tree.flatten(states)\n                flat_new_state = tree.flatten(new_states)\n                flat_final_state = compute_masked_output(\n                    mask_t, flat_new_state, flat_state\n                )\n                new_states = tree.pack_sequence_as(new_states, flat_final_state)  # noqa: E501\n\n                ta_index_to_write = time if return_all_outputs else 0\n                for ta, out in zip(output_ta_t, flat_new_output):\n                    
ta[ta_index_to_write] = out\n\n                return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(\n                    new_states\n                )\n\n            it = 0\n            output_ta_t, new_states, prev_output = (\n                output_ta,\n                states,\n                flat_zero_output,\n            )\n            while time < time_steps_t and it < max_iterations:\n                final_outputs = _step(\n                    time, output_ta_t, prev_output, *new_states\n                )  # noqa: E501\n                time, output_ta_t, prev_output = final_outputs[:3]\n                new_states = final_outputs[3:]\n                it += 1\n\n        else:\n\n            def _step(time, output_ta_t, *states):\n                \"\"\"RNN step function.\n\n                Args:\n                    time: Current timestep value.\n                    output_ta_t: TensorArray.\n                    *states: List of states.\n\n                Returns:\n                    Tuple: `(time + 1,output_ta_t) + tuple(new_states)`\n                \"\"\"\n                current_input = tuple(ta[time] for ta in input_ta)\n                current_input = tree.pack_sequence_as(inputs, current_input)\n                output, new_states = step_function(\n                    current_input, tuple(states) + tuple(constants)\n                )\n                flat_new_state = tree.flatten(new_states)\n\n                flat_output = tree.flatten(output)\n                ta_index_to_write = time if return_all_outputs else 0\n                for ta, out in zip(output_ta_t, flat_output):\n                    ta[ta_index_to_write] = out\n\n                new_states = tree.pack_sequence_as(\n                    initial_states, flat_new_state\n                )  # noqa: E501\n                return (time + 1, output_ta_t) + tuple(new_states)\n\n            it = 0\n            output_ta_t = output_ta\n            new_states = states\n            while time < 
time_steps_t and it < max_iterations:\n                final_outputs = _step(time, output_ta_t, *new_states)\n                time, output_ta_t = final_outputs[:2]\n                new_states = final_outputs[2:]\n                it += 1\n\n        def _stack(tensor_list):\n            max_ndims = max(t.ndim for t in tensor_list)\n            max_list = []\n            for t in tensor_list:\n                if t.ndim == max_ndims:\n                    max_list.append(t)\n            return torch.stack(max_list)\n\n        output_ta = final_outputs[1]\n\n        outputs = tuple(_stack(o) for o in output_ta)\n        last_output = tuple(o[-1] for o in outputs)\n\n        outputs = tree.pack_sequence_as(output_time_zero, outputs)\n        last_output = tree.pack_sequence_as(output_time_zero, last_output)\n\n    if not time_major:\n        outputs = tree.map_structure(swap_batch_timestep, outputs)\n\n    return last_output, outputs, new_states\n\n\ndef _is_sequence_right_padded(mask):\n    \"\"\"Check the mask tensor and see if it is right padded.\n\n    cuDNN uses the sequence length param to skip the trailing\n    timesteps. 
If the data is left padded, or not strictly right padded (has\n    masked values in the middle of the sequence), then cuDNN won't work\n    properly in those cases.\n\n    Left padded data: [[False, False, True, True, True]].\n    Right padded data: [[True, True, True, False, False]].\n    Mixture of mask/unmasked data: [[True, False, True, False, False]].\n\n    Note that for the mixed data example above, the actual data the RNN should\n    see are the two Trues (index 0 and 2); the index 1 False should be ignored\n    and not pollute the internal states.\n\n    Args:\n        mask: the Boolean tensor with shape [batch, timestep]\n\n    Returns:\n        boolean scalar tensor, whether the mask is strictly right padded.\n    \"\"\"\n    # Get max sequence length\n    max_seq_length = mask.shape[1]\n    # Count True values in each sequence\n    count_of_true = torch.sum(mask, dim=1)\n    # Create right padded mask\n    batch_size = mask.shape[0]\n    indices = torch.arange(max_seq_length, device=mask.device).repeat(\n        batch_size, 1\n    )  # noqa: E501\n    right_padded_mask = indices < count_of_true.unsqueeze(1)\n    return torch.all(mask == right_padded_mask)\n\n\ndef _has_fully_masked_sequence(mask):\n    \"\"\"Check if the input sequence contains any fully masked data.\n\n    The cuDNN kernel will error out if the input sequence contains any fully\n    masked data. We work around this issue by rerouting the computation to the\n    standard kernel until the issue on the cuDNN side has been fixed. For a\n    fully masked sequence, it will contain all `False` values. 
To make it easy\n    to check, we invert the boolean and check if any of the sequences has all\n    `True` values.\n\n    Args:\n        mask: The mask tensor.\n\n    Returns:\n        A boolean tensor, `True` if the mask contains a fully masked sequence.\n    \"\"\"\n    return torch.any(torch.all(~mask, dim=1))\n\n\ndef _assert_valid_mask(mask):\n    # Check if mask is valid for cuDNN\n    no_fully_masked = ~_has_fully_masked_sequence(mask)\n    is_right_padded = _is_sequence_right_padded(mask)\n    valid = no_fully_masked & is_right_padded\n\n    if not valid.item():\n        error_message = (\n            \"You are passing an RNN mask that does not correspond to \"\n            \"right-padded sequences, while using cuDNN, which is not \"\n            \"supported. With cuDNN, RNN masks can only be used for \"\n            \"right-padding, e.g. `[[True, True, False, False]]` would \"\n            \"be a valid mask, but any mask that isn't just contiguous \"\n            \"`True`'s on the left and contiguous `False`'s on the right \"\n            \"would be invalid. You can pass `use_cudnn=False` to your \"\n            \"RNN layer to stop using cuDNN (this may be slower).\"\n        )\n        raise ValueError(error_message)\n\n\ndef _compute_sequence_length_from_mask(mask, batch_first):\n    \"\"\"Calculate the sequence length tensor (1-D) based on the masking tensor.\n\n    The masking tensor is a 2D boolean tensor with shape [batch, timestep]. For\n    any timestep that should be masked, the corresponding field will be False.\n    Consider the following example:\n        a = [[True, True, False, False],\n             [True, True, True, False]]\n    It is a (2, 4) tensor, and the corresponding sequence length result should\n    be a 1D tensor with value [2, 3]. 
Note that the masking tensor must be right\n    padded, which can be checked by, e.g., `_is_sequence_right_padded()`.\n\n    Args:\n        mask: Boolean tensor with shape [batch, timestep], or [timestep,\n            batch] if batch_first=False.\n        batch_first: Boolean, which indicates whether the mask is batch major\n            or time major.\n\n    Returns:\n        sequence_length: 1D int32 tensor.\n    \"\"\"\n    timestep_index = 0 if not batch_first else 1\n    return torch.sum(mask.int(), dim=timestep_index)\n\n\ndef prepare_lstm_weights(lstm, kernel, recurrent_kernel, bias, device):\n    \"\"\"Copies kernel and recurrent kernel weights into the PyTorch format.\n\n    We split the kernel and recurrent kernel weights and create the\n    associated torch tensors in the layout expected by the cuDNN\n    optimization. After the weights have been copied, we ensure the\n    parameters are on the same device and the memory layout is optimized\n    for cuDNN.\n\n    Args:\n        lstm: The PyTorch LSTM layer to prepare weights for.\n        kernel: The kernel weights tensor.\n        recurrent_kernel: The recurrent kernel weights tensor.\n        bias: The bias tensor.\n        device: The device to place the tensors on.\n    \"\"\"\n\n    lstm = lstm.to(device)\n    hidden_size = lstm.hidden_size\n\n    # Keras gate order [i,f,c,o] corresponds to PyTorch's [i,f,g,o] (g == c)\n    i_k, f_k, c_k, o_k = torch.chunk(kernel, 4, dim=1)\n    weight_ih_data = torch.cat([i_k, f_k, c_k, o_k], dim=1).T\n\n    i_r, f_r, c_r, o_r = torch.chunk(recurrent_kernel, 4, dim=1)\n    weight_hh_data = torch.cat([i_r, f_r, c_r, o_r], dim=1).T\n\n    if bias is not None:\n        # Split Keras combined bias into input and hidden biases\n        bias_ih_data = convert_to_tensor(bias, dtype=\"float32\")\n        bias_hh_data = torch.zeros_like(bias_ih_data)\n\n    else:\n        bias_ih_data = torch.zeros(4 * hidden_size, device=device)\n        bias_hh_data = torch.zeros(4 * hidden_size, device=device)\n\n    # 
Create PyTorch tensors for weights\n    weight_ih = convert_to_tensor(weight_ih_data, dtype=\"float32\").contiguous()\n    weight_hh = convert_to_tensor(weight_hh_data, dtype=\"float32\").contiguous()\n    bias_ih = convert_to_tensor(bias_ih_data, dtype=\"float32\").contiguous()\n    bias_hh = convert_to_tensor(bias_hh_data, dtype=\"float32\").contiguous()\n\n    # Ensure the weights are all on the same device\n    weight_ih = weight_ih.to(device)\n    weight_hh = weight_hh.to(device)\n    bias_ih = bias_ih.to(device)\n    bias_hh = bias_hh.to(device)\n\n    # Copy Keras weights into Torch's flat weights\n    with torch.no_grad():\n        lstm.weight_ih_l0.copy_(weight_ih)\n        lstm.weight_hh_l0.copy_(weight_hh)\n        lstm.bias_ih_l0.copy_(bias_ih)\n        lstm.bias_hh_l0.copy_(bias_hh)\n\n    # Optimize the layout\n    lstm.flatten_parameters()\n\n    # After prepare_lstm_weights:\n    # Force all LSTM parameters to be on the correct device\n    for param in lstm.parameters():\n        if param.device != device:\n            param.data = param.data.to(device)\n\n\ndef _is_cuda_cudnn_available():\n    # We check if the cuda device and drivers are available\n    return torch.cuda.is_available() and torch.backends.cudnn.is_available()\n\n\ndef cudnn_ok(\n    activation,\n    recurrent_activation,\n    unroll,\n    use_bias=True,\n):\n    from keras.src import activations\n    from keras.src import ops\n\n    return (\n        activation in (activations.tanh, torch.tanh, ops.tanh)\n        and recurrent_activation\n        in (activations.sigmoid, torch.sigmoid, ops.sigmoid)  # noqa: E501\n        and not unroll\n        and use_bias\n        and _is_cuda_cudnn_available()\n    )\n\n\ndef lstm(\n    inputs,\n    initial_state_h,\n    initial_state_c,\n    mask,\n    kernel,\n    recurrent_kernel,\n    bias,\n    activation,\n    recurrent_activation,\n    return_sequences=False,\n    go_backwards=False,\n    unroll=False,\n    batch_first=True,\n):\n    
cudnn_supported = cudnn_ok(\n        activation,\n        recurrent_activation,\n        unroll,\n        use_bias=bias is not None,\n    )\n\n    if not cudnn_supported:\n        raise NotImplementedError\n\n    # Get device from inputs\n    device = get_device()\n\n    from keras.src.backend.torch import Variable\n\n    if isinstance(kernel, Variable):\n        kernel = kernel.value\n    if isinstance(recurrent_kernel, Variable):\n        recurrent_kernel = recurrent_kernel.value\n    if isinstance(bias, Variable):\n        bias = bias.value\n\n    # Convert to torch tensors\n    inputs = convert_to_tensor(inputs, dtype=\"float32\")\n    initial_state_h = convert_to_tensor(initial_state_h, dtype=\"float32\")\n    initial_state_c = convert_to_tensor(initial_state_c, dtype=\"float32\")\n    if mask is not None:\n        mask = convert_to_tensor(mask, dtype=\"bool\")\n\n    # Preprocess for go_backwards by flipping the sequence\n    if go_backwards:\n        seq_dim = 1 if batch_first else 0\n        inputs = torch.flip(inputs, dims=[seq_dim])\n        if mask is not None:\n            mask = torch.flip(mask, dims=[seq_dim])\n\n    # Move all tensors to the same device\n    inputs = inputs.to(device)\n    initial_state_h = initial_state_h.to(device)\n    initial_state_c = initial_state_c.to(device)\n    if mask is not None:\n        mask = mask.to(device)\n\n    try:\n        return _cudnn_lstm(\n            inputs,\n            initial_state_h,\n            initial_state_c,\n            kernel,\n            recurrent_kernel,\n            bias,\n            mask,\n            batch_first,\n            go_backwards,\n            return_sequences,\n            device,\n        )\n    except RuntimeError:\n        raise NotImplementedError\n\n\ndef _cudnn_lstm(\n    inputs,\n    initial_state_h,\n    initial_state_c,\n    kernel,\n    recurrent_kernel,\n    bias,\n    mask,\n    batch_first,\n    go_backwards,\n    return_sequences,\n    device,\n):\n    if mask is not 
None:\n        _assert_valid_mask(mask)\n        sequence_lengths = _compute_sequence_length_from_mask(mask, batch_first)\n\n    # Ensure inputs are in batch_first format for consistency\n    if not batch_first:\n        inputs = inputs.permute(1, 0, 2)\n\n    seq_axis, batch_axis = (0, 1) if not batch_first else (1, 0)\n\n    # If shape is [batch, hidden]; Make [1, batch, hidden]\n    if initial_state_h.dim() == 2:\n        initial_state_h = initial_state_h.unsqueeze(0)\n        initial_state_c = initial_state_c.unsqueeze(0)\n    # If shape is [batch, 1, hidden]\n    elif initial_state_h.dim() == 3 and initial_state_h.shape[1] == 1:\n        initial_state_h = initial_state_h.permute(1, 0, 2)\n        initial_state_c = initial_state_c.permute(1, 0, 2)\n\n    input_size = kernel.shape[0]\n    hidden_size = recurrent_kernel.shape[0]\n\n    # Configure LSTM with the provided parameters\n    lstm = torch.nn.LSTM(\n        input_size=input_size,\n        hidden_size=hidden_size,\n        num_layers=1,\n        batch_first=batch_first,\n        bidirectional=False,\n    )\n\n    prepare_lstm_weights(lstm, kernel, recurrent_kernel, bias, device)\n\n    if mask is not None:\n        # Sort and pack\n        sorted_lengths, sorted_indices = torch.sort(\n            sequence_lengths, descending=True\n        )  # noqa: E501\n        sorted_inputs = inputs[sorted_indices]\n        sorted_initial_h = initial_state_h[:, sorted_indices]\n        sorted_initial_c = initial_state_c[:, sorted_indices]\n\n        # Create the packed sequence\n        packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(\n            sorted_inputs, sorted_lengths.cpu(), batch_first\n        )\n\n        # Process with LSTM (which handles the packed sequence correctly)\n        packed_outputs, (h_n, c_n) = lstm(\n            packed_inputs, (sorted_initial_h, sorted_initial_c)\n        )\n\n        # Unpack back to padded tensor\n        outputs, _ = torch.nn.utils.rnn.pad_packed_sequence(\n         
   packed_outputs, batch_first\n        )  # noqa: E501\n\n    else:\n        # Run LSTM without packing for fixed-length sequences\n        outputs, (h_n, c_n) = lstm(inputs, (initial_state_h, initial_state_c))\n\n    # Reshape hidden states for return\n    h_n = h_n.squeeze(batch_axis)\n    c_n = c_n.squeeze(batch_axis)\n\n    # Return appropriate outputs based on return_sequences flag\n\n    if mask is not None:\n        last_output = h_n\n    else:\n        last_output = outputs[:, -1] if batch_first else outputs[-1]\n\n    if not return_sequences:\n        outputs = (\n            last_output.unsqueeze(1)\n            if batch_first\n            else last_output.unsqueeze(0)\n        )  # noqa: E501\n\n    if go_backwards and return_sequences:\n        outputs = torch.flip(outputs, dims=[seq_axis])\n\n    return last_output, outputs, [h_n, c_n]\n\n\ndef gru(\n    inputs,\n    initial_state,\n    mask,\n    kernel,\n    recurrent_kernel,\n    bias,\n    activation,\n    recurrent_activation,\n    return_sequences=False,\n    go_backwards=False,\n    unroll=False,\n    reset_after=True,\n):\n    cudnn_supported = cudnn_ok(\n        activation,\n        recurrent_activation,\n        unroll,\n        use_bias=bias is not None,\n    )\n\n    if not cudnn_supported or not reset_after or mask is not None:\n        raise NotImplementedError\n\n    # Get device from inputs\n    device = get_device()\n\n    # Convert to torch tensors (convert_to_tensor unwraps Variables)\n    kernel = convert_to_tensor(kernel)\n    recurrent_kernel = convert_to_tensor(recurrent_kernel)\n    if bias is not None:\n        bias = convert_to_tensor(bias)\n\n    inputs = convert_to_tensor(inputs, dtype=\"float32\")\n    initial_state = convert_to_tensor(initial_state, dtype=\"float32\")\n\n    # Preprocess for go_backwards by flipping the sequence\n    if go_backwards:\n        inputs = torch.flip(inputs, dims=[1])\n\n    # Move all tensors to the same device\n    inputs = 
inputs.to(device)\n    initial_state = initial_state.to(device)\n\n    try:\n        return _cudnn_gru(\n            inputs,\n            initial_state,\n            kernel,\n            recurrent_kernel,\n            bias,\n            return_sequences=return_sequences,\n            device=device,\n        )\n    except Exception:\n        raise NotImplementedError\n\n\ndef prepare_gru_params(kernel, recurrent_kernel, bias, device):\n    \"\"\"Prepares Keras GRU weights for PyTorch's functional GRU.\n\n    Reorders gates from Keras [z, r, h] to PyTorch [r, z, h] format\n    and returns weight tensors that maintain gradient connections.\n\n    Args:\n        kernel: The kernel weights tensor with shape (input_dim, 3*units).\n        recurrent_kernel: The recurrent kernel weights tensor\n            with shape (units, 3*units).\n        bias: The bias tensor with shape (2, 3*units) for reset_after=True.\n        device: The device to place the tensors on.\n\n    Returns:\n        A list of weight tensors [weight_ih, weight_hh, bias_ih, bias_hh]\n        suitable for torch._VF.gru.\n    \"\"\"\n    # Split Keras weights by gate: [z, r, h]\n    z_k, r_k, h_k = torch.chunk(kernel, 3, dim=1)\n    z_r, r_r, h_r = torch.chunk(recurrent_kernel, 3, dim=1)\n\n    # Reorder to PyTorch format [r, z, h] and transpose\n    weight_ih = torch.cat([r_k, z_k, h_k], dim=1).T.contiguous().to(device)\n    weight_hh = torch.cat([r_r, z_r, h_r], dim=1).T.contiguous().to(device)\n\n    if bias is not None:\n        # bias shape is (2, 3*units) for reset_after=True\n        # Row 0 is input bias, Row 1 is recurrent bias\n        z_bi, r_bi, h_bi = torch.chunk(bias[0], 3)\n        z_bh, r_bh, h_bh = torch.chunk(bias[1], 3)\n\n        # Reorder to [r, z, h]\n        bias_ih = torch.cat([r_bi, z_bi, h_bi]).contiguous().to(device)\n        bias_hh = torch.cat([r_bh, z_bh, h_bh]).contiguous().to(device)\n    else:\n        hidden_size = recurrent_kernel.shape[0]\n        bias_ih = 
torch.zeros(\n            3 * hidden_size, dtype=kernel.dtype, device=device\n        )\n        bias_hh = torch.zeros(\n            3 * hidden_size, dtype=kernel.dtype, device=device\n        )\n\n    return [weight_ih, weight_hh, bias_ih, bias_hh]\n\n\ndef _cudnn_gru(\n    inputs,\n    initial_state,\n    kernel,\n    recurrent_kernel,\n    bias,\n    return_sequences,\n    device,\n):\n    # If shape is [batch, hidden]; Make [1, batch, hidden]\n    if initial_state.dim() == 2:\n        initial_state = initial_state.unsqueeze(0)\n    # If shape is [batch, 1, hidden]\n    elif initial_state.dim() == 3 and initial_state.shape[1] == 1:\n        initial_state = initial_state.permute(1, 0, 2)\n\n    params = prepare_gru_params(kernel, recurrent_kernel, bias, device)\n\n    # Use functional GRU to maintain gradient flow through weight tensors\n    outputs, h_n = torch._VF.gru(\n        inputs,\n        initial_state,\n        params,\n        bias is not None,  # has_biases\n        1,  # num_layers\n        0.0,  # dropout\n        torch.is_grad_enabled(),  # training: must be True for backward pass\n        False,  # bidirectional\n        True,  # batch_first\n    )\n\n    # Reshape hidden state for return\n    h_n = h_n.squeeze(0)\n    last_output = outputs[:, -1]\n\n    if not return_sequences:\n        outputs = last_output.unsqueeze(1)\n\n    return last_output, outputs, [h_n]\n"
  },
  {
    "path": "keras/src/backend/torch/trainer.py",
    "content": "import warnings\n\nimport numpy as np\nimport torch\nfrom packaging.version import parse\n\nfrom keras.src import backend\nfrom keras.src import callbacks as callbacks_module\nfrom keras.src import optimizers as optimizers_module\nfrom keras.src import tree\nfrom keras.src.backend import config\nfrom keras.src.trainers import trainer as base_trainer\nfrom keras.src.trainers.data_adapters import array_slicing\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.trainers.epoch_iterator import EpochIterator\nfrom keras.src.utils import traceback_utils\nfrom keras.src.utils.python_utils import pythonify_logs\n\n\nclass TorchTrainer(base_trainer.Trainer):\n    def __init__(self):\n        super().__init__()\n        self.train_function = None\n        self.test_function = None\n        self.predict_function = None\n\n    def _should_torch_compile(self):\n        # require torch>=2.1.0 to enable dynamo since it\n        # includes many improvements/fixes to torch.compile()\n        # TODO eventually we want to get rid of this when\n        # torch is upgraded to >=2.1 (from 2.0.1) in g3\n        if self.jit_compile and parse(torch.__version__) < parse(\"2.1.0\"):\n            warnings.warn(\n                \"Please upgrade to torch>=2.1.0 for `jit_compile=True` \"\n                \"to take effect. 
Using `jit_compile=False`\"\n            )\n            self.jit_compile = False\n\n        return self.jit_compile\n\n    def train_step(self, data):\n        x, y, sample_weight = data_adapter_utils.unpack_x_y_sample_weight(data)\n\n        # Compute predictions\n        if self._call_has_training_arg:\n            y_pred = self(x, training=True)\n        else:\n            y_pred = self(x)\n\n        # Call torch.nn.Module.zero_grad() to clear the leftover gradients\n        # for the weights from the previous train step.\n        self.zero_grad()\n\n        loss = self._compute_loss(\n            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=True\n        )\n        self._loss_tracker.update_state(\n            loss,\n            sample_weight=next(\n                i for i in tree.flatten(x) if i is not None\n            ).shape[0],\n        )\n        if self.optimizer is not None:\n            loss = self.optimizer.scale_loss(loss)\n\n        # Compute gradients\n        if self.trainable_weights:\n            # Call torch.Tensor.backward() on the loss to compute gradients\n            # for the weights.\n            loss.backward()\n\n            trainable_weights = self.trainable_weights[:]\n            gradients = [v.value.grad for v in trainable_weights]\n\n            # Update weights\n            with torch.no_grad():\n                self.optimizer.apply(gradients, trainable_weights)\n        else:\n            warnings.warn(\"The model does not have any trainable weights.\")\n\n        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)\n\n    def test_step(self, data):\n        (\n            x,\n            y,\n            sample_weight,\n        ) = data_adapter_utils.unpack_x_y_sample_weight(data)\n        if self._call_has_training_arg:\n            y_pred = self(x, training=False)\n        else:\n            y_pred = self(x)\n        loss = self._compute_loss(\n            x=x, y=y, y_pred=y_pred, 
sample_weight=sample_weight, training=False\n        )\n        self._loss_tracker.update_state(\n            loss,\n            sample_weight=next(\n                i for i in tree.flatten(x) if i is not None\n            ).shape[0],\n        )\n        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)\n\n    def predict_step(self, data):\n        x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data)\n        if self._call_has_training_arg:\n            y_pred = self(x, training=False)\n        else:\n            y_pred = self(x)\n        return y_pred\n\n    def make_train_function(self, force=False):\n        if self.train_function is not None and not force:\n            return self.train_function\n\n        if self.steps_per_execution > 1:\n            raise ValueError(\n                \"`steps_per_execution` must be 1 with the PyTorch backend. \"\n                f\"Received: steps_per_execution={self.steps_per_execution}\"\n            )\n\n        def one_step_on_data(data):\n            \"\"\"Runs a single training step on a batch of data.\"\"\"\n            data = data[0]\n            return self.train_step(data)\n\n        if self._should_torch_compile():\n            self.train_function = torch.compile(one_step_on_data)\n        else:\n            self.train_function = one_step_on_data\n\n    def make_test_function(self, force=False):\n        if self.test_function is not None and not force:\n            return self.test_function\n\n        if self.steps_per_execution > 1:\n            raise ValueError(\n                \"`steps_per_execution` must be 1 with the PyTorch backend. 
\"\n                f\"Received: steps_per_execution={self.steps_per_execution}\"\n            )\n\n        def one_step_on_data(data):\n            \"\"\"Runs a single test step on a batch of data.\"\"\"\n            data = data[0]\n            with torch.no_grad():\n                return self.test_step(data)\n\n        if self._should_torch_compile():\n            self.test_function = torch.compile(one_step_on_data)\n        else:\n            self.test_function = one_step_on_data\n\n    def make_predict_function(self, force=False):\n        if self.predict_function is not None and not force:\n            return self.predict_function\n\n        if self.steps_per_execution > 1:\n            raise ValueError(\n                \"`steps_per_execution` must be 1 with the PyTorch backend. \"\n                f\"Received: steps_per_execution={self.steps_per_execution}\"\n            )\n\n        def one_step_on_data(data):\n            \"\"\"Runs a single predict step on a batch of data.\"\"\"\n            data = data[0]\n            with torch.no_grad():\n                return self.predict_step(data)\n\n        if self._should_torch_compile():\n            self.predict_function = torch.compile(one_step_on_data)\n        else:\n            self.predict_function = one_step_on_data\n\n    @traceback_utils.filter_traceback\n    def fit(\n        self,\n        x=None,\n        y=None,\n        batch_size=None,\n        epochs=1,\n        verbose=\"auto\",\n        callbacks=None,\n        validation_split=0.0,\n        validation_data=None,\n        shuffle=True,\n        class_weight=None,\n        sample_weight=None,\n        initial_epoch=0,\n        steps_per_epoch=None,\n        validation_steps=None,\n        validation_batch_size=None,\n        validation_freq=1,\n    ):\n        if not self.compiled:\n            raise ValueError(\n                \"You must call `compile()` before calling `fit()`.\"\n            )\n        # Possibly cap epochs for debugging 
runs.\n        max_epochs = config.max_epochs()\n        if max_epochs and max_epochs < epochs:\n            warnings.warn(\"Limiting epochs to %d\" % max_epochs)\n            epochs = max_epochs\n\n        # TODO: respect compiled trainable state\n        self._eval_epoch_iterator = None\n        if validation_split and validation_data is None:\n            # Create the validation data using the training data. Only supported\n            # for TF/numpy/jax arrays.\n            # TODO: Support torch tensors for validation data.\n            (\n                (x, y, sample_weight),\n                validation_data,\n            ) = array_slicing.train_validation_split(\n                (x, y, sample_weight), validation_split=validation_split\n            )\n\n        if validation_data is not None:\n            (\n                val_x,\n                val_y,\n                val_sample_weight,\n            ) = data_adapter_utils.unpack_x_y_sample_weight(validation_data)\n\n        # Create an iterator that yields batches for one epoch.\n        epoch_iterator = TorchEpochIterator(\n            x=x,\n            y=y,\n            sample_weight=sample_weight,\n            batch_size=batch_size,\n            steps_per_epoch=steps_per_epoch,\n            shuffle=shuffle,\n            class_weight=class_weight,\n            steps_per_execution=self.steps_per_execution,\n        )\n\n        self._symbolic_build(iterator=epoch_iterator)\n        epoch_iterator.reset()\n\n        # Container that configures and calls callbacks.\n        if not isinstance(callbacks, callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_history=True,\n                add_progbar=verbose != 0,\n                verbose=verbose,\n                epochs=epochs,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        self.stop_training = False\n        
training_logs = {}\n        self.make_train_function()\n        callbacks.on_train_begin()\n        initial_epoch = self._initial_epoch or initial_epoch\n        for epoch in range(initial_epoch, epochs):\n            self.reset_metrics()\n            callbacks.on_epoch_begin(epoch)\n\n            # Switch the torch Module to training mode. Inform torch layers to\n            # do training behavior in case the user did not use `self.training`\n            # when implementing a custom layer with torch layers.\n            self.train()\n\n            logs = {}\n            for begin_step, end_step, data in epoch_iterator:\n                # Callbacks\n                callbacks.on_train_batch_begin(begin_step)\n\n                logs = self.train_function(data)\n\n                # Callbacks\n                callbacks.on_train_batch_end(end_step, logs)\n                if self.stop_training:\n                    break\n\n            # Override with model metrics instead of last step logs if needed.\n            epoch_logs = dict(self._get_metrics_result_or_logs(logs))\n\n            # Switch the torch Module back to testing mode.\n            self.eval()\n\n            # Run validation.\n            if validation_data is not None and self._should_eval(\n                epoch, validation_freq\n            ):\n                # Create TorchEpochIterator for evaluation and cache it.\n                if getattr(self, \"_eval_epoch_iterator\", None) is None:\n                    self._eval_epoch_iterator = TorchEpochIterator(\n                        x=val_x,\n                        y=val_y,\n                        sample_weight=val_sample_weight,\n                        batch_size=validation_batch_size or batch_size,\n                        steps_per_execution=self.steps_per_execution,\n                        steps_per_epoch=validation_steps,\n                        shuffle=False,\n                    )\n                val_logs = self.evaluate(\n                    
x=val_x,\n                    y=val_y,\n                    sample_weight=val_sample_weight,\n                    batch_size=validation_batch_size or batch_size,\n                    steps=validation_steps,\n                    callbacks=callbacks,\n                    return_dict=True,\n                    _use_cached_eval_dataset=True,\n                )\n                val_logs = {\n                    f\"val_{name}\": val for name, val in val_logs.items()\n                }\n                epoch_logs.update(val_logs)\n\n            callbacks.on_epoch_end(epoch, epoch_logs)\n            training_logs = epoch_logs\n            if self.stop_training:\n                break\n\n        if (\n            isinstance(self.optimizer, optimizers_module.Optimizer)\n            and epochs > 0\n        ):\n            self.optimizer.finalize_variable_values(self.trainable_weights)\n\n        # If _eval_epoch_iterator exists, delete it after all epochs are done.\n        if getattr(self, \"_eval_epoch_iterator\", None) is not None:\n            del self._eval_epoch_iterator\n        callbacks.on_train_end(logs=training_logs)\n        return self.history\n\n    @traceback_utils.filter_traceback\n    def evaluate(\n        self,\n        x=None,\n        y=None,\n        batch_size=None,\n        verbose=\"auto\",\n        sample_weight=None,\n        steps=None,\n        callbacks=None,\n        return_dict=False,\n        **kwargs,\n    ):\n        # TODO: respect compiled trainable state\n        use_cached_eval_dataset = kwargs.pop(\"_use_cached_eval_dataset\", False)\n        if kwargs:\n            raise ValueError(f\"Arguments not recognized: {kwargs}\")\n\n        if use_cached_eval_dataset:\n            epoch_iterator = self._eval_epoch_iterator\n        else:\n            # Create an iterator that yields batches of input/target data.\n            epoch_iterator = TorchEpochIterator(\n                x=x,\n                y=y,\n                
sample_weight=sample_weight,\n                batch_size=batch_size,\n                steps_per_epoch=steps,\n                shuffle=False,\n                steps_per_execution=self.steps_per_execution,\n            )\n\n        self._symbolic_build(iterator=epoch_iterator)\n        epoch_iterator.reset()\n\n        # Container that configures and calls callbacks.\n        if not isinstance(callbacks, callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_progbar=verbose != 0,\n                verbose=verbose,\n                epochs=1,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        # Switch the torch Module back to testing mode.\n        self.eval()\n\n        self.make_test_function()\n        self.stop_evaluating = False\n        callbacks.on_test_begin()\n        logs = {}\n        self.reset_metrics()\n        for begin_step, end_step, data in epoch_iterator:\n            callbacks.on_test_batch_begin(begin_step)\n            logs = self.test_function(data)\n            callbacks.on_test_batch_end(end_step, logs)\n            if self.stop_evaluating:\n                break\n        logs = pythonify_logs(self._get_metrics_result_or_logs(logs))\n        callbacks.on_test_end(logs)\n\n        if return_dict:\n            return logs\n        return self._flatten_metrics_in_order(logs)\n\n    @traceback_utils.filter_traceback\n    def predict(\n        self, x, batch_size=None, verbose=\"auto\", steps=None, callbacks=None\n    ):\n        # Create an iterator that yields batches of input data.\n        epoch_iterator = TorchEpochIterator(\n            x=x,\n            batch_size=batch_size,\n            steps_per_epoch=steps,\n            shuffle=False,\n            steps_per_execution=self.steps_per_execution,\n        )\n\n        # Container that configures and calls callbacks.\n        if not isinstance(callbacks, 
callbacks_module.CallbackList):\n            callbacks = callbacks_module.CallbackList(\n                callbacks,\n                add_progbar=verbose != 0,\n                verbose=verbose,\n                epochs=1,\n                steps=epoch_iterator.num_batches,\n                model=self,\n            )\n\n        def append_to_outputs(batch_outputs, outputs):\n            if outputs is None:\n                outputs = tree.map_structure(\n                    lambda batch_output: [batch_output],\n                    batch_outputs,\n                )\n            else:\n                tree.map_structure_up_to(\n                    batch_outputs,\n                    lambda output, batch_output: output.append(batch_output),\n                    outputs,\n                    batch_outputs,\n                )\n            return outputs\n\n        # Switch the torch Module back to testing mode.\n        self.eval()\n\n        self.make_predict_function()\n        self.stop_predicting = False\n        callbacks.on_predict_begin()\n        outputs = None\n        for begin_step, end_step, data in epoch_iterator:\n            callbacks.on_predict_batch_begin(begin_step)\n            batch_outputs = self.predict_function(data)\n            outputs = append_to_outputs(batch_outputs, outputs)\n            callbacks.on_predict_batch_end(end_step, {\"outputs\": batch_outputs})\n            if self.stop_predicting:\n                break\n        callbacks.on_predict_end()\n        outputs = tree.map_structure(backend.convert_to_numpy, outputs)\n        return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)\n\n    def train_on_batch(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        class_weight=None,\n        return_dict=False,\n    ):\n        self._assert_compile_called(\"train_on_batch\")\n        if class_weight is not None:\n            if sample_weight is not None:\n                raise ValueError(\n            
        \"Arguments `sample_weight` and `class_weight` \"\n                    \"cannot be specified at the same time. \"\n                    f\"Received: sample_weight={sample_weight}, \"\n                    f\"class_weight={class_weight}\"\n                )\n            sample_weight = data_adapter_utils.class_weight_to_sample_weights(\n                y, class_weight\n            )\n\n        data = (x, y, sample_weight)\n\n        # Maybe build model\n        self._symbolic_build(data_batch=data)\n        self.make_train_function()\n\n        logs = self.train_function([data])\n        logs = pythonify_logs(logs)\n        if return_dict:\n            return logs\n        return self._flatten_metrics_in_order(logs)\n\n    def test_on_batch(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        return_dict=False,\n    ):\n        self._assert_compile_called(\"test_on_batch\")\n\n        data = (x, y, sample_weight)\n\n        # Maybe build model\n        self._symbolic_build(data_batch=data)\n        self.make_test_function()\n\n        logs = self.test_function([data])\n        logs = pythonify_logs(logs)\n        if return_dict:\n            return logs\n        return self._flatten_metrics_in_order(logs)\n\n    def predict_on_batch(self, x):\n        self.make_predict_function()\n        batch_outputs = self.predict_function([(x,)])\n        batch_outputs = tree.map_structure(\n            backend.convert_to_numpy, batch_outputs\n        )\n        return batch_outputs\n\n\nclass TorchEpochIterator(EpochIterator):\n    def _get_iterator(self):\n        return self.data_adapter.get_torch_dataloader()\n"
  },
  {
    "path": "keras/src/callbacks/__init__.py",
    "content": "from keras.src.callbacks.backup_and_restore import BackupAndRestore\nfrom keras.src.callbacks.callback import Callback\nfrom keras.src.callbacks.callback_list import CallbackList\nfrom keras.src.callbacks.csv_logger import CSVLogger\nfrom keras.src.callbacks.early_stopping import EarlyStopping\nfrom keras.src.callbacks.history import History\nfrom keras.src.callbacks.lambda_callback import LambdaCallback\nfrom keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler\nfrom keras.src.callbacks.model_checkpoint import ModelCheckpoint\nfrom keras.src.callbacks.monitor_callback import MonitorCallback\nfrom keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint\nfrom keras.src.callbacks.progbar_logger import ProgbarLogger\nfrom keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau\nfrom keras.src.callbacks.remote_monitor import RemoteMonitor\nfrom keras.src.callbacks.swap_ema_weights import SwapEMAWeights\nfrom keras.src.callbacks.tensorboard import TensorBoard\nfrom keras.src.callbacks.terminate_on_nan import TerminateOnNaN\n"
  },
  {
    "path": "keras/src/callbacks/backup_and_restore.py",
    "content": "import json\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.callbacks.callback import Callback\nfrom keras.src.utils import file_utils\n\n\n@keras_export(\"keras.callbacks.BackupAndRestore\")\nclass BackupAndRestore(Callback):\n    \"\"\"Callback to back up and restore the training state.\n\n    The `BackupAndRestore` callback is intended to recover training from an\n    interruption that has happened in the middle of a `Model.fit` execution, by\n    backing up the training states in a temporary checkpoint file, at the end of\n    each epoch. Each backup overwrites the previously written checkpoint file,\n    so at any given time there is at most one such checkpoint file for\n    backup/restore purposes.\n\n    If training restarts before completion, the training state (which includes\n    the `Model` weights and epoch number) is restored to the most recently saved\n    state at the beginning of a new `Model.fit` run. At the completion of a\n    `Model.fit` run, the temporary checkpoint file is deleted.\n\n    Note that the user is responsible for bringing jobs back after the\n    interruption. This callback provides the backup and restore mechanism for\n    fault tolerance purposes, and the model restored from a previous checkpoint\n    is expected to be the same as the one used to back it up. If the user\n    changes arguments passed to compile or fit, the checkpoint saved for fault\n    tolerance can become invalid.\n\n    Example:\n\n    >>> class InterruptingCallback(keras.callbacks.Callback):\n    ...   def on_epoch_begin(self, epoch, logs=None):\n    ...     if epoch == 4:\n    ...       raise RuntimeError('Interrupting!')\n    >>> callback = keras.callbacks.BackupAndRestore(backup_dir=\"/tmp/backup\")\n    >>> model = keras.models.Sequential([keras.layers.Dense(10)])\n    >>> model.compile(keras.optimizers.SGD(), loss='mse')\n    >>> model.build(input_shape=(None, 20))\n    >>> try:\n    ...   
model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,\n    ...             batch_size=1, callbacks=[callback, InterruptingCallback()],\n    ...             verbose=0)\n    ... except:\n    ...   pass\n    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),\n    ...                     epochs=10, batch_size=1, callbacks=[callback],\n    ...                     verbose=0)\n    >>> # Only 6 more epochs are run, since the first training run was\n    >>> # interrupted at zero-indexed epoch 4; the second run continues\n    >>> # from epoch 4 to 9.\n    >>> len(history.history['loss'])\n    6\n\n    Args:\n        backup_dir: String, path of the directory where to store the data\n            needed to restore the model. The directory\n            cannot be reused elsewhere to store other files, e.g. by the\n            `BackupAndRestore` callback of another training run,\n            or by another callback (e.g. `ModelCheckpoint`)\n            of the same training run.\n        save_freq: `\"epoch\"`, integer, or `False`. When set to `\"epoch\"`\n          the callback saves the checkpoint at the end of each epoch.\n          When set to an integer, the callback saves the checkpoint every\n          `save_freq` batches. Set `save_freq=False` only if using\n          preemption checkpointing (i.e. with `save_before_preemption=True`).\n        double_checkpoint: Boolean. If enabled, the `BackupAndRestore`\n          callback will save the last two training states (current and\n          previous). After an interruption, if the current state can't be\n          loaded due to an IO error (e.g. a corrupted file), it will try to\n          restore the previous one. This doubles the disk space used, but\n          increases fault tolerance. Defaults to `False`.\n        delete_checkpoint: Boolean. 
This `BackupAndRestore`\n          callback works by saving a checkpoint to back up the training state.\n          If `delete_checkpoint=True`, the checkpoint will be deleted after\n          training is finished. Use `False` if you'd like to keep the checkpoint\n          for future usage. Defaults to `True`.\n    \"\"\"\n\n    def __init__(\n        self,\n        backup_dir,\n        save_freq=\"epoch\",\n        double_checkpoint=False,\n        delete_checkpoint=True,\n    ):\n        super().__init__()\n        self.save_freq = save_freq\n        self.double_checkpoint = double_checkpoint\n        self.delete_checkpoint = delete_checkpoint\n        self._batches_seen_since_last_saving = 0\n        self._last_batch_seen = 0\n        self._current_epoch = 0\n\n        if not backup_dir:\n            raise ValueError(\"Empty `backup_dir` argument passed\")\n        self.backup_dir = backup_dir\n        self._weights_path = file_utils.join(backup_dir, \"latest.weights.h5\")\n        self._training_metadata_path = file_utils.join(\n            backup_dir, \"training_metadata.json\"\n        )\n        self._prev_weights_path = f\"{self._weights_path}.bkp\"\n        self._prev_training_metadata_path = (\n            f\"{self._training_metadata_path}.bkp\"\n        )\n        if save_freq != \"epoch\" and not isinstance(save_freq, int):\n            raise ValueError(\n                \"Invalid value for argument `save_freq`. \"\n                f\"Received: save_freq={save_freq}. \"\n                \"Expected either 'epoch' or an integer value.\"\n            )\n\n    def on_train_begin(self, logs=None):\n        try:\n            self._load_model()\n        except OSError as e:\n            # Weights may be corrupted. 
Try to load the previous one.\n            if not file_utils.exists(self._prev_weights_path):\n                raise e\n            file_utils.copy(self._prev_weights_path, self._weights_path)\n            if file_utils.exists(self._prev_training_metadata_path):\n                file_utils.copy(\n                    self._prev_training_metadata_path,\n                    self._training_metadata_path,\n                )\n            elif file_utils.exists(self._training_metadata_path):\n                file_utils.remove(self._training_metadata_path)\n            self._load_model()\n\n    def _load_model(self):\n        \"\"\"Get training state from the temporary file and restore it.\"\"\"\n        if not self.model.built:\n            raise ValueError(\n                \"To use the BackupAndRestore callback, \"\n                \"your model must be built before you call `fit()`. \"\n                f\"Model {self.model} is unbuilt. You can build it \"\n                \"beforehand by calling it on a batch of data.\"\n            )\n        if file_utils.exists(self._weights_path):\n            if (\n                self.model.optimizer is not None\n                and not self.model.optimizer.built\n            ):\n                # Make sure optimizer weights exist before loading.\n                self.model.optimizer.build(self.model.trainable_variables)\n            self.model.load_weights(self._weights_path)\n\n        if file_utils.exists(self._training_metadata_path):\n            with file_utils.File(self._training_metadata_path, \"r\") as f:\n                training_metadata = json.loads(f.read())\n            epoch = training_metadata[\"epoch\"]\n            self.model._initial_epoch = epoch\n\n    def on_epoch_end(self, epoch, logs=None):\n        self._current_epoch = epoch + 1\n        self._last_batch_seen = 0\n        if self.save_freq == \"epoch\":\n            self._save_model()\n\n    def on_train_batch_end(self, batch, logs=None):\n        if 
self._should_save_on_batch(batch):\n            self._save_model()\n\n    def _save_model(self):\n        \"\"\"Saves the model weights and training metadata to `backup_dir`.\"\"\"\n        # Create the backup directory if it doesn't exist.\n        if not file_utils.exists(self.backup_dir):\n            file_utils.makedirs(self.backup_dir)\n        if self.double_checkpoint and file_utils.exists(self._weights_path):\n            file_utils.copy(self._weights_path, self._prev_weights_path)\n        if self.double_checkpoint and file_utils.exists(\n            self._training_metadata_path\n        ):\n            file_utils.copy(\n                self._training_metadata_path, self._prev_training_metadata_path\n            )\n        self.model.save_weights(filepath=self._weights_path, overwrite=True)\n        with file_utils.File(self._training_metadata_path, \"w\") as f:\n            training_metadata = {\n                \"epoch\": self._current_epoch,\n                \"batch\": self._last_batch_seen,\n            }\n            f.write(json.dumps(training_metadata))\n\n    def _should_save_on_batch(self, batch):\n        \"\"\"Handles batch-level saving logic, supports steps_per_execution.\"\"\"\n        if self.save_freq == \"epoch\":\n            return False\n        if batch <= self._last_batch_seen:  # New epoch.\n            add_batches = batch + 1  # batches are zero-indexed.\n        else:\n            add_batches = batch - self._last_batch_seen\n        self._batches_seen_since_last_saving += add_batches\n        self._last_batch_seen = batch\n\n        if self._batches_seen_since_last_saving >= self.save_freq:\n            self._batches_seen_since_last_saving = 0\n            return True\n        return False\n\n    def 
on_train_end(self, logs=None):\n        if self.delete_checkpoint and file_utils.exists(self.backup_dir):\n            file_utils.rmtree(self.backup_dir)\n"
  },
  {
    "path": "keras/src/callbacks/backup_and_restore_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import callbacks\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.models import Sequential\nfrom keras.src.utils import file_utils\n\n\nclass InterruptingCallback(callbacks.Callback):\n    \"\"\"A callback to intentionally interrupt training.\"\"\"\n\n    def __init__(self, steps_int, epoch_int):\n        self.batch_count = 0\n        self.epoch_count = 0\n        self.steps_int = steps_int\n        self.epoch_int = epoch_int\n\n    def on_epoch_end(self, epoch, log=None):\n        self.epoch_count += 1\n        if self.epoch_int is not None and self.epoch_count == self.epoch_int:\n            raise RuntimeError(\"EpochInterruption\")\n\n    def on_batch_end(self, batch, logs=None):\n        self.batch_count += 1\n        if self.steps_int is not None and self.batch_count == self.steps_int:\n            raise RuntimeError(\"StepsInterruption\")\n\n\nclass CanaryLayer(layers.Layer):\n    def __init__(self):\n        super().__init__()\n        self.counter = self.add_weight(\n            shape=(), initializer=\"zeros\", dtype=\"float32\", trainable=False\n        )\n\n    def call(self, x):\n        self.counter.assign_add(1)\n        return x\n\n\nclass BackupAndRestoreCallbackTest(testing.TestCase):\n    def make_model(self):\n        model = Sequential(\n            [\n                layers.Input((3,)),\n                CanaryLayer(),\n                layers.Dense(1),\n            ]\n        )\n        model.compile(\n            loss=\"mse\",\n            optimizer=\"sgd\",\n            metrics=[\"mse\"],\n        )\n        return model\n\n    # Check invalid save_freq, both string and non integer\n    def test_save_freq_unknown_error(self):\n        with self.assertRaisesRegex(ValueError, expected_regex=\"Invalid value\"):\n            callbacks.BackupAndRestore(\n                backup_dir=\"backup_dir\", save_freq=\"batch\"\n            )\n\n        with 
self.assertRaisesRegex(ValueError, expected_regex=\"Invalid value\"):\n            callbacks.BackupAndRestore(backup_dir=\"backup_dir\", save_freq=0.15)\n\n    # Checking if after interruption, correct model params and\n    # weights are loaded in step-wise backup\n    @pytest.mark.requires_trainable_backend\n    def test_best_case_step(self):\n        temp_dir = self.get_temp_dir()\n        backup_dir = file_utils.join(temp_dir, \"subdir\")\n        self.assertFalse(file_utils.exists(backup_dir))\n\n        model = self.make_model()\n        cbk = callbacks.BackupAndRestore(backup_dir, save_freq=1)\n\n        x_train = np.random.random((10, 3))\n        y_train = np.random.random((10, 1))\n\n        try:\n            model.fit(\n                x_train,\n                y_train,\n                batch_size=4,\n                callbacks=[\n                    cbk,\n                    InterruptingCallback(steps_int=2, epoch_int=None),\n                ],\n                epochs=2,\n                verbose=0,\n            )\n        except RuntimeError:\n            self.assertTrue(file_utils.exists(backup_dir))\n            self.assertEqual(cbk._current_epoch, 0)\n            self.assertEqual(cbk._last_batch_seen, 1)\n            self.assertEqual(int(model.layers[0].counter.value), 2)\n\n            hist = model.fit(\n                x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5\n            )\n\n            self.assertEqual(cbk._current_epoch, 5)\n            self.assertEqual(hist.epoch[-1], 4)\n            self.assertEqual(int(model.layers[0].counter.value), 17)\n\n    # Checking if after interruption, correct model params and\n    # weights are loaded in epoch-wise backup\n    @pytest.mark.requires_trainable_backend\n    def test_best_case_epoch(self):\n        temp_dir = self.get_temp_dir()\n        backup_dir = file_utils.join(temp_dir, \"subdir\")\n        self.assertFalse(file_utils.exists(backup_dir))\n\n        model = self.make_model()\n       
 self.assertEqual(int(model.layers[0].counter.value), 0)\n        cbk = callbacks.BackupAndRestore(\n            backup_dir=backup_dir, save_freq=\"epoch\"\n        )\n\n        x_train = np.random.random((10, 3))\n        y_train = np.random.random((10, 1))\n\n        try:\n            model.fit(\n                x_train,\n                y_train,\n                batch_size=4,\n                callbacks=[\n                    cbk,\n                    InterruptingCallback(steps_int=None, epoch_int=2),\n                ],\n                epochs=6,\n                verbose=0,\n            )\n        except RuntimeError:\n            self.assertEqual(cbk._current_epoch, 2)\n            self.assertTrue(file_utils.exists(backup_dir))\n            self.assertEqual(int(model.layers[0].counter.value), 6)\n\n            hist = model.fit(\n                x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5\n            )\n            self.assertEqual(cbk._current_epoch, 5)\n            self.assertEqual(hist.epoch[-1], 4)\n            self.assertEqual(int(model.layers[0].counter.value), 5 * 3)\n\n    # Checking if after interruption and weights corruption, previous model\n    # params and weights are loaded\n    @pytest.mark.requires_trainable_backend\n    def test_backup_corrupted(self):\n        temp_dir = self.get_temp_dir()\n        backup_dir = file_utils.join(temp_dir, \"subdir\")\n        self.assertFalse(file_utils.exists(backup_dir))\n\n        model = self.make_model()\n        self.assertEqual(int(model.layers[0].counter.value), 0)\n        cbk = callbacks.BackupAndRestore(\n            backup_dir=backup_dir, save_freq=\"epoch\", double_checkpoint=True\n        )\n\n        x_train = np.random.random((10, 3))\n        y_train = np.random.random((10, 1))\n\n        try:\n            model.fit(\n                x_train,\n                y_train,\n                batch_size=4,\n                callbacks=[\n                    cbk,\n                    
InterruptingCallback(steps_int=None, epoch_int=2),\n                ],\n                epochs=6,\n                verbose=0,\n            )\n        except RuntimeError:\n            self.assertEqual(cbk._current_epoch, 2)\n            self.assertTrue(file_utils.exists(backup_dir))\n            self.assertTrue(file_utils.exists(cbk._weights_path))\n            self.assertTrue(file_utils.exists(cbk._training_metadata_path))\n            self.assertTrue(file_utils.exists(cbk._prev_weights_path))\n            self.assertTrue(file_utils.exists(cbk._prev_training_metadata_path))\n            self.assertEqual(int(model.layers[0].counter.value), 6)\n\n            # Corrupt the weights file.\n            with file_utils.File(cbk._weights_path, \"w\") as f:\n                f.write(\"0\")\n\n            hist = model.fit(\n                x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5\n            )\n            self.assertEqual(cbk._current_epoch, 5)\n            self.assertEqual(hist.epoch[-1], 4)\n            self.assertEqual(int(model.layers[0].counter.value), 5 * 3)\n\n    # Checking that the backup directory is deleted after training completes.\n    @pytest.mark.requires_trainable_backend\n    def test_model_deleted_case_epoch(self):\n        temp_dir = self.get_temp_dir()\n        backup_dir = file_utils.join(temp_dir, \"subdir\")\n        self.assertFalse(file_utils.exists(backup_dir))\n\n        model = self.make_model()\n        cbk = callbacks.BackupAndRestore(backup_dir, save_freq=\"epoch\")\n\n        x_train = np.random.random((10, 3))\n        y_train = np.random.random((10, 1))\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=4,\n            callbacks=[cbk],\n            epochs=2,\n            verbose=0,\n        )\n        self.assertFalse(file_utils.exists(backup_dir))\n\n    def test_backup_dir_empty_error(self):\n        with self.assertRaisesRegex(\n            ValueError, expected_regex=\"Empty `backup_dir` argument passed\"\n  
      ):\n            callbacks.BackupAndRestore(backup_dir=\"\", save_freq=\"epoch\")\n\n    def test_backup_dir_none_error(self):\n        with self.assertRaisesRegex(\n            ValueError, expected_regex=\"Empty `backup_dir` argument passed\"\n        ):\n            callbacks.BackupAndRestore(backup_dir=None, save_freq=\"epoch\")\n"
  },
  {
    "path": "keras/src/callbacks/callback.py",
    "content": "from keras.src import backend\nfrom keras.src import utils\nfrom keras.src.api_export import keras_export\n\n\n@keras_export(\"keras.callbacks.Callback\")\nclass Callback:\n    \"\"\"Base class used to build new callbacks.\n\n    Callbacks can be passed to keras methods such as `fit()`, `evaluate()`, and\n    `predict()` in order to hook into the various stages of the model training,\n    evaluation, and inference lifecycle.\n\n    To create a custom callback, subclass `keras.callbacks.Callback` and\n    override the method associated with the stage of interest.\n\n    Example:\n\n    >>> training_finished = False\n    >>> class MyCallback(Callback):\n    ...   def on_train_end(self, logs=None):\n    ...     global training_finished\n    ...     training_finished = True\n    >>> model = Sequential([\n    ...     layers.Dense(1, input_shape=(1,))])\n    >>> model.compile(loss='mean_squared_error')\n    >>> model.fit(np.array([[1.0]]), np.array([[1.0]]),\n    ...           callbacks=[MyCallback()])\n    >>> assert training_finished == True\n\n    If you want to use `Callback` objects in a custom training loop:\n\n    1. You should pack all your callbacks into a single `callbacks.CallbackList`\n       so they can all be called together.\n    2. You will need to manually call all the `on_*` methods at the appropriate\n       locations in your loop. Like this:\n\n    Example:\n\n    ```python\n    callbacks =  keras.callbacks.CallbackList([...])\n    callbacks.append(...)\n    callbacks.on_train_begin(...)\n    for epoch in range(EPOCHS):\n        callbacks.on_epoch_begin(epoch)\n        for i, data in dataset.enumerate():\n        callbacks.on_train_batch_begin(i)\n        batch_logs = model.train_step(data)\n        callbacks.on_train_batch_end(i, batch_logs)\n        epoch_logs = ...\n        callbacks.on_epoch_end(epoch, epoch_logs)\n    final_logs=...\n    callbacks.on_train_end(final_logs)\n    ```\n\n    Attributes:\n        params: Dict. 
Training parameters\n            (eg. verbosity, batch size, number of epochs...).\n        model: Instance of `Model`.\n            Reference of the model being trained.\n\n    The `logs` dictionary that callback methods\n    take as argument will contain keys for quantities relevant to\n    the current batch or epoch (see method-specific docstrings).\n    \"\"\"\n\n    def __init__(self):\n        self.params = None\n        self._model = None\n\n    def set_params(self, params):\n        self.params = params\n\n    def set_model(self, model):\n        self._model = model\n\n    @property\n    def model(self):\n        if backend.backend() == \"torch\":\n            from torch.nn.parallel import DistributedDataParallel\n\n            if isinstance(self._model, DistributedDataParallel):\n                # Keras Callbacks expect to work with Keras models. e.g\n                # ModelCheckpoint and EarlyStopping both attempt to call\n                # keras-specific APIs on the value returned from this\n                # property. If this callback was created against a DDP\n                # wrapper instead of the underlying keras.Model, it is\n                # likely to fail. Return self._model.module for DDP\n                # instances instead.\n                return self._model.module\n\n        if backend.backend() == \"jax\" and hasattr(\n            self._model, \"jax_state_sync\"\n        ):\n            # With JAX, by default the model state is not\n            # attached to the model in the middle of an\n            # epoch. We have to force a sync before\n            # accessing model state for e.g. 
checkpointing.\n            self._model.jax_state_sync()\n        return self._model\n\n    @utils.default\n    def on_batch_begin(self, batch, logs=None):\n        \"\"\"A backwards compatibility alias for `on_train_batch_begin`.\"\"\"\n\n    @utils.default\n    def on_batch_end(self, batch, logs=None):\n        \"\"\"A backwards compatibility alias for `on_train_batch_end`.\"\"\"\n\n    @utils.default\n    def on_epoch_begin(self, epoch, logs=None):\n        \"\"\"Called at the start of an epoch.\n\n        Subclasses should override for any actions to run. This function should\n        only be called during TRAIN mode.\n\n        Args:\n            epoch: Integer, index of epoch.\n            logs: Dict. Currently no data is passed to this argument for this\n              method but that may change in the future.\n        \"\"\"\n\n    @utils.default\n    def on_epoch_end(self, epoch, logs=None):\n        \"\"\"Called at the end of an epoch.\n\n        Subclasses should override for any actions to run. This function should\n        only be called during TRAIN mode.\n\n        Args:\n            epoch: Integer, index of epoch.\n            logs: Dict, metric results for this training epoch, and for the\n              validation epoch if validation is performed. Validation result\n              keys are prefixed with `val_`. For training epoch, the values of\n              the `Model`'s metrics are returned. 
Example:\n              `{'loss': 0.2, 'accuracy': 0.7}`.\n        \"\"\"\n\n    @utils.default\n    def on_train_batch_begin(self, batch, logs=None):\n        \"\"\"Called at the beginning of a training batch in `fit` methods.\n\n        Subclasses should override for any actions to run.\n\n        Note that if the `steps_per_execution` argument to `compile` in\n        `Model` is set to `N`, this method will only be called every\n        `N` batches.\n\n        Args:\n            batch: Integer, index of batch within the current epoch.\n            logs: Dict. Currently no data is passed to this argument for this\n              method but that may change in the future.\n        \"\"\"\n        # For backwards compatibility.\n        self.on_batch_begin(batch, logs=logs)\n\n    @utils.default\n    def on_train_batch_end(self, batch, logs=None):\n        \"\"\"Called at the end of a training batch in `fit` methods.\n\n        Subclasses should override for any actions to run.\n\n        Note that if the `steps_per_execution` argument to `compile` in\n        `Model` is set to `N`, this method will only be called every\n        `N` batches.\n\n        Args:\n            batch: Integer, index of batch within the current epoch.\n            logs: Dict. 
Aggregated metric results up until this batch.\n        \"\"\"\n        # For backwards compatibility.\n        self.on_batch_end(batch, logs=logs)\n\n    @utils.default\n    def on_test_batch_begin(self, batch, logs=None):\n        \"\"\"Called at the beginning of a batch in `evaluate` methods.\n\n        Also called at the beginning of a validation batch in the `fit`\n        methods, if validation data is provided.\n\n        Subclasses should override for any actions to run.\n\n        Note that if the `steps_per_execution` argument to `compile` in\n        `Model` is set to `N`, this method will only be called every\n        `N` batches.\n\n        Args:\n            batch: Integer, index of batch within the current epoch.\n            logs: Dict. Currently no data is passed to this argument for this\n              method but that may change in the future.\n        \"\"\"\n\n    @utils.default\n    def on_test_batch_end(self, batch, logs=None):\n        \"\"\"Called at the end of a batch in `evaluate` methods.\n\n        Also called at the end of a validation batch in the `fit`\n        methods, if validation data is provided.\n\n        Subclasses should override for any actions to run.\n\n        Note that if the `steps_per_execution` argument to `compile` in\n        `Model` is set to `N`, this method will only be called every\n        `N` batches.\n\n        Args:\n            batch: Integer, index of batch within the current epoch.\n            logs: Dict. 
Aggregated metric results up until this batch.\n        \"\"\"\n\n    @utils.default\n    def on_predict_batch_begin(self, batch, logs=None):\n        \"\"\"Called at the beginning of a batch in `predict` methods.\n\n        Subclasses should override for any actions to run.\n\n        Note that if the `steps_per_execution` argument to `compile` in\n        `Model` is set to `N`, this method will only be called every\n        `N` batches.\n\n        Args:\n            batch: Integer, index of batch within the current epoch.\n            logs: Dict. Currently no data is passed to this argument for this\n              method but that may change in the future.\n        \"\"\"\n\n    @utils.default\n    def on_predict_batch_end(self, batch, logs=None):\n        \"\"\"Called at the end of a batch in `predict` methods.\n\n        Subclasses should override for any actions to run.\n\n        Note that if the `steps_per_execution` argument to `compile` in\n        `Model` is set to `N`, this method will only be called every\n        `N` batches.\n\n        Args:\n            batch: Integer, index of batch within the current epoch.\n            logs: Dict. Aggregated metric results up until this batch.\n        \"\"\"\n\n    @utils.default\n    def on_train_begin(self, logs=None):\n        \"\"\"Called at the beginning of training.\n\n        Subclasses should override for any actions to run.\n\n        Args:\n            logs: Dict. Currently no data is passed to this argument for this\n              method but that may change in the future.\n        \"\"\"\n\n    @utils.default\n    def on_train_end(self, logs=None):\n        \"\"\"Called at the end of training.\n\n        Subclasses should override for any actions to run.\n\n        Args:\n            logs: Dict. 
Currently the output of the last call to\n              `on_epoch_end()` is passed to this argument for this method but\n              that may change in the future.\n        \"\"\"\n\n    @utils.default\n    def on_test_begin(self, logs=None):\n        \"\"\"Called at the beginning of evaluation or validation.\n\n        Subclasses should override for any actions to run.\n\n        Args:\n            logs: Dict. Currently no data is passed to this argument for this\n              method but that may change in the future.\n        \"\"\"\n\n    @utils.default\n    def on_test_end(self, logs=None):\n        \"\"\"Called at the end of evaluation or validation.\n\n        Subclasses should override for any actions to run.\n\n        Args:\n            logs: Dict. Currently the output of the last call to\n              `on_test_batch_end()` is passed to this argument for this method\n              but that may change in the future.\n        \"\"\"\n\n    @utils.default\n    def on_predict_begin(self, logs=None):\n        \"\"\"Called at the beginning of prediction.\n\n        Subclasses should override for any actions to run.\n\n        Args:\n            logs: Dict. Currently no data is passed to this argument for this\n              method but that may change in the future.\n        \"\"\"\n\n    @utils.default\n    def on_predict_end(self, logs=None):\n        \"\"\"Called at the end of prediction.\n\n        Subclasses should override for any actions to run.\n\n        Args:\n            logs: Dict. Currently no data is passed to this argument for this\n              method but that may change in the future.\n        \"\"\"\n"
  },
  {
    "path": "keras/src/callbacks/callback_list.py",
    "content": "import concurrent.futures\n\nfrom keras.src import backend\nfrom keras.src import tree\nfrom keras.src import utils\nfrom keras.src.api_export import keras_export\nfrom keras.src.callbacks.callback import Callback\nfrom keras.src.callbacks.history import History\nfrom keras.src.callbacks.progbar_logger import ProgbarLogger\nfrom keras.src.utils import python_utils\n\n\n@keras_export(\"keras.callbacks.CallbackList\")\nclass CallbackList(Callback):\n    \"\"\"Container abstracting a list of callbacks.\"\"\"\n\n    def __init__(\n        self,\n        callbacks=None,\n        add_history=False,\n        add_progbar=False,\n        model=None,\n        **params,\n    ):\n        \"\"\"Container for `Callback` instances.\n\n        This object wraps a list of `Callback` instances, making it possible\n        to call them all at once via a single endpoint\n        (e.g. `callback_list.on_epoch_end(...)`).\n\n        Args:\n            callbacks: List of `Callback` instances.\n            add_history: Whether a `History` callback should be added, if one\n                does not already exist in the `callbacks` list.\n            add_progbar: Whether a `ProgbarLogger` callback should be added, if\n                one does not already exist in the `callbacks` list.\n            model: The `Model` these callbacks are used with.\n            **params: If provided, parameters will be passed to each `Callback`\n                via `Callback.set_params`.\n        \"\"\"\n        self.callbacks = tree.flatten(callbacks) if callbacks else []\n        self._in_begin_end_block_count = 0\n        self._executor = None\n        self._async_train = False\n        self._async_test = False\n        self._async_predict = False\n        self._futures = []\n        self._configure_async_dispatch(callbacks)\n        self._add_default_callbacks(add_history, add_progbar)\n        self.set_model(model)\n        self.set_params(params)\n\n    def set_params(self, params):\n     
   self.params = params\n        if params:\n            for callback in self.callbacks:\n                callback.set_params(params)\n\n    def _configure_async_dispatch(self, callbacks):\n        # Determine whether callbacks can be dispatched asynchronously.\n        if not backend.IS_THREAD_SAFE:\n            return\n        async_train = True\n        async_test = True\n        async_predict = True\n        if callbacks:\n            if isinstance(callbacks, (list, tuple)):\n                for cbk in callbacks:\n                    if getattr(cbk, \"async_safe\", False):\n                        # Callbacks that expose self.async_safe == True\n                        # will be assumed safe for async dispatch.\n                        continue\n                    if not utils.is_default(cbk.on_batch_end):\n                        async_train = False\n                    if not utils.is_default(cbk.on_train_batch_end):\n                        async_train = False\n                    if not utils.is_default(cbk.on_test_batch_end):\n                        async_test = False\n                    if not utils.is_default(cbk.on_predict_batch_end):\n                        async_predict = False\n\n        self._async_train = async_train\n        self._async_test = async_test\n        self._async_predict = async_predict\n\n    def _add_default_callbacks(self, add_history, add_progbar):\n        \"\"\"Adds `Callback`s that are always present.\"\"\"\n        self._progbar = None\n        self._history = None\n\n        for cb in self.callbacks:\n            if isinstance(cb, ProgbarLogger):\n                self._progbar = cb\n            elif isinstance(cb, History):\n                self._history = cb\n\n        if self._history is None and add_history:\n            self._history = History()\n            self.callbacks.append(self._history)\n\n        if self._progbar is None and add_progbar:\n            self._progbar = ProgbarLogger()\n            
self.callbacks.append(self._progbar)\n\n    def set_model(self, model):\n        if not model:\n            return\n        super().set_model(model)\n        if self._history:\n            model.history = self._history\n        for callback in self.callbacks:\n            callback.set_model(model)\n\n    def _on_begin(self):\n        \"\"\"Called by `on_train/test/predict_begin`.\n\n        Start the executor for async calls if needed.\n        \"\"\"\n        self._in_begin_end_block_count += 1\n        if (\n            self._in_begin_end_block_count == 1\n            and (self._async_train or self._async_test or self._async_predict)\n            and self._executor is None\n        ):\n            self._executor = concurrent.futures.ThreadPoolExecutor()\n\n    def _on_end(self):\n        \"\"\"Called by `on_train/test/predict_end`.\n\n        Shutdown the executor for async calls if all begin/end blocks completed.\n        \"\"\"\n        self._in_begin_end_block_count -= 1\n        if self._in_begin_end_block_count < 0:\n            raise ValueError(\n                \"`on_xxx_end` called without corresponding `on_xxx_begin`\"\n            )\n        if self._in_begin_end_block_count == 0 and self._executor is not None:\n            self._executor.shutdown()\n            self._executor = None\n\n    def _async_dispatch(self, fn, *args):\n        for future in self._futures:\n            if future.done():\n                future.result()\n                self._futures.remove(future)\n        future = self._executor.submit(fn, *args)\n        self._futures.append(future)\n\n    def _flush_futures(self):\n        \"\"\"Waits for all futures to complete and clears the list.\"\"\"\n        for future in self._futures:\n            future.result()\n        self._futures = []\n\n    def on_batch_begin(self, batch, logs=None):\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_batch_begin(batch, 
logs=logs)\n\n    def on_epoch_begin(self, epoch, logs=None):\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_epoch_begin(epoch, logs)\n\n    def on_epoch_end(self, epoch, logs=None):\n        if self._async_train:\n            self._flush_futures()\n\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_epoch_end(epoch, logs)\n\n    def on_train_batch_begin(self, batch, logs=None):\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_train_batch_begin(batch, logs=logs)\n\n    def on_test_batch_begin(self, batch, logs=None):\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_test_batch_begin(batch, logs=logs)\n\n    def on_predict_batch_begin(self, batch, logs=None):\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_predict_batch_begin(batch, logs=logs)\n\n    def on_batch_end(self, batch, logs=None):\n        if self._async_train:\n            self._async_dispatch(self._on_batch_end, batch, logs)\n        else:\n            self._on_batch_end(batch, logs)\n\n    def on_train_batch_end(self, batch, logs=None):\n        if self._async_train:\n            self._async_dispatch(self._on_train_batch_end, batch, logs)\n        else:\n            self._on_train_batch_end(batch, logs)\n\n    def on_test_batch_end(self, batch, logs=None):\n        if self._async_test:\n            self._async_dispatch(self._on_test_batch_end, batch, logs)\n        else:\n            self._on_test_batch_end(batch, logs)\n\n    def on_predict_batch_end(self, batch, logs=None):\n        if self._async_predict:\n            self._async_dispatch(self._on_predict_batch_end, batch, logs)\n        else:\n            self._on_predict_batch_end(batch, logs)\n\n    def 
_on_batch_end(self, batch, logs=None):\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_batch_end(batch, logs=logs)\n\n    def _on_train_batch_end(self, batch, logs=None):\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_train_batch_end(batch, logs=logs)\n\n    def _on_test_batch_end(self, batch, logs=None):\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_test_batch_end(batch, logs=logs)\n\n    def _on_predict_batch_end(self, batch, logs=None):\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_predict_batch_end(batch, logs=logs)\n\n    def on_train_begin(self, logs=None):\n        self._on_begin()\n\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_train_begin(logs)\n\n    def on_train_end(self, logs=None):\n        if self._async_train:\n            self._flush_futures()\n\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_train_end(logs)\n\n        self._on_end()\n\n    def on_test_begin(self, logs=None):\n        self._on_begin()\n\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_test_begin(logs)\n\n    def on_test_end(self, logs=None):\n        if self._async_test:\n            self._flush_futures()\n\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_test_end(logs)\n\n        self._on_end()\n\n    def on_predict_begin(self, logs=None):\n        self._on_begin()\n\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_predict_begin(logs)\n\n    def on_predict_end(self, 
logs=None):\n        if self._async_predict:\n            self._flush_futures()\n\n        logs = python_utils.pythonify_logs(logs)\n        for callback in self.callbacks:\n            callback.on_predict_end(logs)\n\n        self._on_end()\n"
  },
  {
    "path": "keras/src/callbacks/callback_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src.callbacks.callback import Callback\n\n\nclass CallbackTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_model_state_is_current_on_epoch_end(self):\n        test_obj = self\n\n        class TestModel(models.Model):\n            def __init__(self):\n                super().__init__()\n                self.iterations = self.add_variable(\n                    shape=(), initializer=\"zeros\", trainable=False\n                )\n\n            def call(self, inputs):\n                self.iterations.assign(self.iterations + 1)\n                return inputs\n\n        class CBK(Callback):\n            def on_batch_end(self, batch, logs):\n                test_obj.assertEqual(int(self.model.iterations), batch + 1)\n\n        model = TestModel()\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        x = np.random.random((8, 1))\n        y = np.random.random((8, 1))\n        model.fit(x, y, callbacks=[CBK()], batch_size=2)\n"
  },
  {
    "path": "keras/src/callbacks/csv_logger.py",
    "content": "import collections\nimport csv\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.callbacks.callback import Callback\nfrom keras.src.utils import file_utils\n\n\n@keras_export(\"keras.callbacks.CSVLogger\")\nclass CSVLogger(Callback):\n    \"\"\"Callback that streams epoch results to a CSV file.\n\n    Supports all values that can be represented as a string,\n    including 1D iterables such as `np.ndarray`.\n\n    Args:\n        filename: Filename of the CSV file, e.g. `'run/log.csv'`.\n        separator: String used to separate elements in the CSV file.\n        append: Boolean. True: append if file exists (useful for continuing\n            training). False: overwrite existing file.\n\n    Example:\n\n    ```python\n    csv_logger = CSVLogger('training.log')\n    model.fit(X_train, Y_train, callbacks=[csv_logger])\n    ```\n    \"\"\"\n\n    def __init__(self, filename, separator=\",\", append=False):\n        super().__init__()\n        self.sep = separator\n        self.filename = file_utils.path_to_string(filename)\n        self.append = append\n        self.writer = None\n        self.keys = None\n        self.append_header = True\n        self.csv_file = None\n\n    def on_train_begin(self, logs=None):\n        if self.append:\n            if file_utils.exists(self.filename):\n                with file_utils.File(self.filename, \"r\") as f:\n                    self.append_header = not bool(len(f.readline()))\n            mode = \"a\"\n        else:\n            mode = \"w\"\n        # ensure csv_file is None or closed before reassigning\n        if self.csv_file and not self.csv_file.closed:\n            self.csv_file.close()\n        self.csv_file = file_utils.File(self.filename, mode)\n        # Reset writer and keys\n        self.writer = None\n        self.keys = None\n\n    def on_epoch_end(self, epoch, logs=None):\n        logs = logs or {}\n\n        def handle_value(k):\n            
is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0\n            if isinstance(k, str):\n                return k\n            elif (\n                isinstance(k, collections.abc.Iterable)\n                and not is_zero_dim_ndarray\n            ):\n                return f'\"[{\", \".join(map(str, k))}]\"'\n            else:\n                return k\n\n        if self.keys is None:\n            self.keys = sorted(logs.keys())\n\n            val_keys_found = False\n            for key in self.keys:\n                if key.startswith(\"val_\"):\n                    val_keys_found = True\n                    break\n            if not val_keys_found and self.keys:\n                self.keys.extend([f\"val_{k}\" for k in self.keys])\n\n        if not self.writer:\n\n            class CustomDialect(csv.excel):\n                delimiter = self.sep\n\n            fieldnames = [\"epoch\"] + (self.keys or [])\n\n            self.writer = csv.DictWriter(\n                self.csv_file, fieldnames=fieldnames, dialect=CustomDialect\n            )\n            if self.append_header:\n                self.writer.writeheader()\n\n        row_dict = collections.OrderedDict({\"epoch\": epoch})\n        row_dict.update(\n            (key, handle_value(logs.get(key, \"NA\"))) for key in self.keys\n        )\n        self.writer.writerow(row_dict)\n        self.csv_file.flush()\n\n    def on_train_end(self, logs=None):\n        if self.csv_file and not self.csv_file.closed:\n            self.csv_file.close()\n        self.writer = None\n"
  },
  {
    "path": "keras/src/callbacks/csv_logger_test.py",
    "content": "import csv\nimport os\nimport re\nimport tempfile\n\nimport numpy as np\nimport pytest\n\nfrom keras.src import callbacks\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.models import Sequential\nfrom keras.src.utils import numerical_utils\n\nTRAIN_SAMPLES = 10\nTEST_SAMPLES = 10\nINPUT_DIM = 3\nBATCH_SIZE = 4\n\n\nclass CSVLoggerTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_CSVLogger(self):\n        OUTPUT_DIM = 1\n        np.random.seed(1337)\n        temp_dir = tempfile.TemporaryDirectory()\n        filepath = os.path.join(temp_dir.name, \"log.tsv\")\n\n        sep = \"\\t\"\n        x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM))\n        y_train = np.random.random((TRAIN_SAMPLES, OUTPUT_DIM))\n        x_test = np.random.random((TEST_SAMPLES, INPUT_DIM))\n        y_test = np.random.random((TEST_SAMPLES, OUTPUT_DIM))\n\n        def make_model():\n            np.random.seed(1337)\n            model = Sequential(\n                [\n                    layers.Dense(2, activation=\"relu\"),\n                    layers.Dense(OUTPUT_DIM),\n                ]\n            )\n            model.compile(\n                loss=\"mse\",\n                optimizer=\"sgd\",\n                metrics=[\"mse\"],\n            )\n            return model\n\n        # case 1, create new file with defined separator\n        model = make_model()\n        cbks = [callbacks.CSVLogger(filepath, separator=sep)]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n\n        self.assertTrue(os.path.exists(filepath))\n        with open(filepath) as csvfile:\n            dialect = csv.Sniffer().sniff(csvfile.read())\n        self.assertEqual(dialect.delimiter, sep)\n        del model\n      
  del cbks\n\n        # case 2, append data to existing file, skip header\n        model = make_model()\n        cbks = [callbacks.CSVLogger(filepath, separator=sep, append=True)]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n\n        # case 3, reuse of CSVLogger object\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=2,\n            verbose=0,\n        )\n\n        with open(filepath) as csvfile:\n            list_lines = csvfile.readlines()\n            for line in list_lines:\n                self.assertEqual(line.count(sep), 4)\n            self.assertLen(list_lines, 5)\n            output = \" \".join(list_lines)\n            self.assertLen(re.findall(\"epoch\", output), 1)\n\n        os.remove(filepath)\n\n        # case 4, verify that val. loss is also registered when\n        # validation_freq > 1\n        model = make_model()\n        cbks = [callbacks.CSVLogger(filepath, separator=sep)]\n        hist = model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            validation_freq=3,\n            callbacks=cbks,\n            epochs=5,\n            verbose=0,\n        )\n        self.assertTrue(os.path.exists(filepath))\n        # Verify that validation loss is registered at val. 
freq\n        with open(filepath) as csvfile:\n            rows = csv.DictReader(csvfile, delimiter=sep)\n            for idx, row in enumerate(rows, 1):\n                self.assertIn(\"val_loss\", row)\n                if idx == 3:\n                    self.assertEqual(\n                        row[\"val_loss\"], str(hist.history[\"val_loss\"][0])\n                    )\n                else:\n                    self.assertEqual(row[\"val_loss\"], \"NA\")\n\n    @pytest.mark.requires_trainable_backend\n    def test_stop_training_csv(self):\n        # Test that using the CSVLogger callback with the TerminateOnNaN\n        # callback does not result in invalid CSVs.\n        tmpdir = tempfile.TemporaryDirectory()\n        csv_logfile = os.path.join(tmpdir.name, \"csv_logger.csv\")\n        NUM_CLASSES = 2\n        np.random.seed(1337)\n        x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM))\n        y_train = np.random.choice(np.arange(NUM_CLASSES), size=TRAIN_SAMPLES)\n        x_test = np.random.random((TEST_SAMPLES, INPUT_DIM))\n        y_test = np.random.choice(np.arange(NUM_CLASSES), size=TEST_SAMPLES)\n\n        y_test = numerical_utils.to_categorical(y_test)\n        y_train = numerical_utils.to_categorical(y_train)\n        model = Sequential()\n        initializer = initializers.Constant(value=1e5)\n        for _ in range(5):\n            model.add(\n                layers.Dense(\n                    2,\n                    activation=\"relu\",\n                    kernel_initializer=initializer,\n                )\n            )\n        model.add(layers.Dense(NUM_CLASSES))\n        model.compile(loss=\"mean_squared_error\", optimizer=\"sgd\")\n\n        history = model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=[\n                callbacks.TerminateOnNaN(),\n                callbacks.CSVLogger(csv_logfile),\n            ],\n            
epochs=20,\n        )\n        loss = history.history[\"loss\"]\n        self.assertEqual(len(loss), 1)\n        self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0]))\n\n        values = []\n        with open(csv_logfile) as f:\n            # On Windows, due to \\r\\n line ends, we may end up reading empty\n            # lines after each line. Skip empty lines.\n            values = [x for x in csv.reader(f) if x]\n        self.assertIn(\"nan\", values[-1], \"NaN not logged in CSV Logger.\")\n"
  },
  {
    "path": "keras/src/callbacks/early_stopping.py",
    "content": "import warnings\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.callbacks.monitor_callback import MonitorCallback\nfrom keras.src.utils import io_utils\n\n\n@keras_export(\"keras.callbacks.EarlyStopping\")\nclass EarlyStopping(MonitorCallback):\n    \"\"\"Stop training when a monitored metric has stopped improving.\n\n    Assuming the goal of training is to minimize the loss, the metric to be\n    monitored would be `'loss'`, and the mode would be `'min'`. A\n    `model.fit()` training loop will check at the end of every epoch whether\n    the loss is no longer decreasing, considering the `min_delta` and\n    `patience` settings if applicable. Once the loss is found to be no longer\n    decreasing, `model.stop_training` is set to `True` and training\n    terminates.\n\n    The quantity to be monitored needs to be available in the `logs` dict.\n    To make it so, pass the loss or metrics at `model.compile()`.\n\n    Args:\n        monitor: Quantity to be monitored. Defaults to `\"val_loss\"`.\n        min_delta: Minimum change in the monitored quantity to qualify as an\n            improvement, i.e. an absolute change of less than `min_delta` will\n            count as no improvement. Defaults to `0`.\n        patience: Number of epochs with no improvement after which training will\n            be stopped. Defaults to `0`.\n        verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1 displays\n            messages when the callback takes an action. Defaults to `0`.\n        mode: One of `{\"auto\", \"min\", \"max\"}`. In `\"min\"` mode, training will\n            stop when the quantity monitored has stopped decreasing; in `\"max\"`\n            mode it will stop when the quantity monitored has stopped\n            increasing; in `\"auto\"` mode, the direction is automatically\n            inferred from the name of the monitored quantity. Defaults to\n            `\"auto\"`.\n        baseline: Baseline value for the monitored quantity. 
If not `None`,\n            training will stop if the model doesn't show improvement over the\n            baseline. Defaults to `None`.\n        restore_best_weights: Whether to restore model weights from the epoch\n            with the best value of the monitored quantity. If `False`, the model\n            weights obtained at the last step of training are used. An epoch\n            will be restored regardless of the performance relative to the\n            `baseline`. If no epoch improves on `baseline`, training will run\n            for `patience` epochs and restore weights from the best epoch in\n            that set. Defaults to `False`.\n        start_from_epoch: Number of epochs to wait before starting to monitor\n            improvement. This allows for a warm-up period in which no\n            improvement is expected and thus training will not be stopped.\n            Defaults to `0`.\n\n    Example:\n\n    >>> callback = keras.callbacks.EarlyStopping(monitor='loss',\n    ...                                               patience=3)\n    >>> # This callback will stop the training when there is no improvement in\n    >>> # the loss for three consecutive epochs.\n    >>> model = keras.models.Sequential([keras.layers.Dense(10)])\n    >>> model.compile(keras.optimizers.SGD(), loss='mse')\n    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),\n    ...                     epochs=10, batch_size=1, callbacks=[callback],\n    ...                     
verbose=0)\n    >>> len(history.history['loss'])  # Only 4 epochs are run.\n    4\n    \"\"\"\n\n    def __init__(\n        self,\n        monitor=\"val_loss\",\n        min_delta=0,\n        patience=0,\n        verbose=0,\n        mode=\"auto\",\n        baseline=None,\n        restore_best_weights=False,\n        start_from_epoch=0,\n    ):\n        super().__init__(monitor, mode, min_delta=min_delta)\n        self.patience = patience\n        self.verbose = verbose\n        self.baseline = baseline\n        self.wait = 0\n        self.stopped_epoch = 0\n        self.restore_best_weights = restore_best_weights\n        self.best_weights = None\n        self.start_from_epoch = start_from_epoch\n\n    def on_train_begin(self, logs=None):\n        # Allow instances to be re-used\n        self.wait = 0\n        self.stopped_epoch = 0\n        self.best_weights = None\n        self.best_epoch = 0\n\n    def on_epoch_end(self, epoch, logs=None):\n        if self.monitor_op is None:\n            # Delay setup until the model's metrics are all built\n            self._set_monitor_op()\n\n        current = self.get_monitor_value(logs)\n        if current is None or epoch < self.start_from_epoch:\n            # If no monitor value exists or still in initial warm-up stage.\n            return\n        if self.restore_best_weights and self.best_weights is None:\n            # If best weights were never set,\n            # then the current weights are the best.\n            self.best_weights = self.model.get_weights()\n            self.best_epoch = epoch\n\n        self.wait += 1\n        if self._is_improvement(current, self.best):\n            self.best = current\n            self.best_epoch = epoch\n            if self.restore_best_weights:\n                self.best_weights = self.model.get_weights()\n            # Only restart wait if we beat both the baseline and our previous\n            # best.\n            if self.baseline is None or self._is_improvement(\n          
      current, self.baseline\n            ):\n                self.wait = 0\n            return\n\n        if self.wait >= self.patience and epoch > 0:\n            # Patience has been exceeded: stop training\n            self.stopped_epoch = epoch\n            self.model.stop_training = True\n\n    def on_train_end(self, logs=None):\n        if self.stopped_epoch > 0 and self.verbose > 0:\n            io_utils.print_msg(\n                f\"Epoch {self.stopped_epoch + 1}: early stopping\"\n            )\n        if self.restore_best_weights and self.best_weights is not None:\n            if self.verbose > 0:\n                io_utils.print_msg(\n                    \"Restoring model weights from \"\n                    \"the end of the best epoch: \"\n                    f\"{self.best_epoch + 1}.\"\n                )\n            self.model.set_weights(self.best_weights)\n\n    def get_monitor_value(self, logs):\n        logs = logs or {}\n        monitor_value = logs.get(self.monitor)\n        if monitor_value is None:\n            warnings.warn(\n                (\n                    f\"Early stopping conditioned on metric `{self.monitor}` \"\n                    \"which is not available. \"\n                    f\"Available metrics are: {','.join(list(logs.keys()))}\"\n                ),\n                stacklevel=2,\n            )\n        return monitor_value\n"
  },
  {
    "path": "keras/src/callbacks/early_stopping_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import callbacks\nfrom keras.src import layers\nfrom keras.src import metrics\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass EarlyStoppingTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_early_stopping(self):\n        x_train = np.random.random((10, 5))\n        y_train = np.random.random((10, 1))\n        x_test = np.random.random((10, 5))\n        y_test = np.random.random((10, 1))\n        model = models.Sequential(\n            (\n                layers.Dense(1, activation=\"relu\"),\n                layers.Dense(1, activation=\"relu\"),\n            )\n        )\n        model.compile(\n            loss=\"mae\",\n            optimizer=\"adam\",\n            metrics=[\n                \"mse\",\n                \"acc\",\n                \"accuracy\",\n                \"hinge\",\n                metrics.F1Score(name=\"f1_score\"),\n            ],\n        )\n\n        cases = [\n            (\"max\", \"val_mse\", \"max\"),\n            (\"min\", \"val_loss\", \"min\"),\n            (\"auto\", \"val_mse\", \"min\"),\n            (\"auto\", \"loss\", \"min\"),\n            (\"auto\", \"acc\", \"max\"),\n            (\"auto\", \"val_accuracy\", \"max\"),\n            (\"auto\", \"hinge\", \"min\"),\n            (\"auto\", \"f1_score\", \"max\"),\n        ]\n        for mode, monitor, expected_mode in cases:\n            patience = 0\n            cbks = [\n                callbacks.EarlyStopping(\n                    patience=patience, monitor=monitor, mode=mode\n                )\n            ]\n            model.fit(\n                x_train,\n                y_train,\n                batch_size=5,\n                validation_data=(x_test, y_test),\n                callbacks=cbks,\n                epochs=2,\n                verbose=0,\n            )\n            if expected_mode == \"max\":\n                
monitor_op = ops.greater\n            else:\n                monitor_op = ops.less\n            self.assertEqual(cbks[0].monitor_op, monitor_op)\n\n        with self.assertRaises(ValueError):\n            cbks = [\n                callbacks.EarlyStopping(patience=patience, monitor=\"unknown\")\n            ]\n            model.fit(\n                x_train,\n                y_train,\n                batch_size=5,\n                validation_data=(x_test, y_test),\n                callbacks=cbks,\n                epochs=2,\n                verbose=0,\n            )\n\n    @pytest.mark.requires_trainable_backend\n    def test_early_stopping_patience(self):\n        cases = [0, 1, 2, 3]\n        losses = [10.0, 9.0, 8.0, 9.0, 8.9, 8.8, 8.7, 8.6, 8.5]\n\n        for patience in cases:\n            stopper = callbacks.EarlyStopping(monitor=\"loss\", patience=patience)\n            stopper.set_model(models.Sequential())\n            stopper.model.compile(loss=\"mse\", optimizer=\"sgd\")\n            stopper.on_train_begin()\n\n            for epoch, loss in enumerate(losses):\n                stopper.on_epoch_end(epoch=epoch, logs={\"loss\": loss})\n                if stopper.model.stop_training:\n                    break\n\n            self.assertEqual(stopper.stopped_epoch, max(patience, 1) + 2)\n\n    @pytest.mark.requires_trainable_backend\n    def test_early_stopping_reuse(self):\n        patience = 3\n        data = np.random.random((100, 1))\n        labels = np.where(data > 0.5, 1, 0)\n        model = models.Sequential(\n            (\n                layers.Dense(1, activation=\"relu\"),\n                layers.Dense(1, activation=\"relu\"),\n            )\n        )\n        model.compile(\n            optimizer=\"sgd\",\n            loss=\"mae\",\n            metrics=[\"mse\"],\n        )\n        stopper = callbacks.EarlyStopping(monitor=\"mse\", patience=patience)\n\n        history1 = model.fit(\n            data, labels, callbacks=[stopper], verbose=0, 
epochs=20\n        )\n        self.assertGreaterEqual(len(history1.epoch), patience)\n\n        history2 = model.fit(\n            data, labels, callbacks=[stopper], verbose=0, epochs=20\n        )\n        self.assertGreaterEqual(len(history2.epoch), patience)\n\n    @pytest.mark.requires_trainable_backend\n    def test_early_stopping_with_baseline(self):\n        baseline = 0.6\n        x_train = np.random.random((10, 5))\n        y_train = np.random.random((10, 1))\n        model = models.Sequential(\n            (\n                layers.Dense(1, activation=\"relu\"),\n                layers.Dense(1, activation=\"relu\"),\n            )\n        )\n        model.compile(optimizer=\"sgd\", loss=\"mae\", metrics=[\"mse\"])\n\n        patience = 3\n        stopper = callbacks.EarlyStopping(\n            monitor=\"mse\", patience=patience, baseline=baseline\n        )\n        hist = model.fit(\n            x_train, y_train, callbacks=[stopper], verbose=0, epochs=20\n        )\n        self.assertGreaterEqual(len(hist.epoch), patience)\n\n    def test_early_stopping_final_weights_when_restoring_model_weights(self):\n        class DummyModel:\n            def __init__(self):\n                self.stop_training = False\n                self.weights = -1\n\n            def get_weights(self):\n                return self.weights\n\n            def set_weights(self, weights):\n                self.weights = weights\n\n            def set_weight_to_epoch(self, epoch):\n                self.weights = epoch\n\n        early_stop = callbacks.EarlyStopping(\n            monitor=\"val_loss\", patience=2, restore_best_weights=True\n        )\n        early_stop.set_model(DummyModel())\n        losses = [0.2, 0.15, 0.1, 0.11, 0.12]\n        # The best configuration is in the epoch 2 (loss = 0.1000).\n        epochs_trained = 0\n        early_stop.on_train_begin()\n        for epoch in range(len(losses)):\n            epochs_trained += 1\n            
early_stop.model.set_weight_to_epoch(epoch=epoch)\n            early_stop.on_epoch_end(epoch, logs={\"val_loss\": losses[epoch]})\n            if early_stop.model.stop_training:\n                break\n        early_stop.on_train_end()\n        # The best configuration is in epoch 2 (loss = 0.1000),\n        # and while patience = 2, we're restoring the best weights,\n        # so we end up at the epoch with the best weights, i.e. epoch 2\n        self.assertEqual(early_stop.model.get_weights(), 2)\n\n        # Check early stopping when no model beats the baseline.\n        early_stop = callbacks.EarlyStopping(\n            monitor=\"val_loss\",\n            patience=5,\n            baseline=0.5,\n            restore_best_weights=True,\n        )\n        early_stop.set_model(DummyModel())\n        losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73]\n        # The best configuration is in the epoch 2 (loss = 0.7000).\n        epochs_trained = 0\n        early_stop.on_train_begin()\n        for epoch in range(len(losses)):\n            epochs_trained += 1\n            early_stop.model.set_weight_to_epoch(epoch=epoch)\n            early_stop.on_epoch_end(epoch, logs={\"val_loss\": losses[epoch]})\n            if early_stop.model.stop_training:\n                break\n        early_stop.on_train_end()\n        # No epoch improves on the baseline, so we should train for only 5\n        # epochs, and restore the second model.\n        self.assertEqual(epochs_trained, 5)\n        self.assertEqual(early_stop.model.get_weights(), 2)\n\n        # Check weight restoration when another callback requests a stop.\n        early_stop = callbacks.EarlyStopping(\n            monitor=\"val_loss\",\n            patience=5,\n            baseline=0.5,\n            restore_best_weights=True,\n        )\n        early_stop.set_model(DummyModel())\n        losses = [0.9, 0.8, 0.7, 0.71, 0.72, 0.73]\n        # The best configuration is in the epoch 2 (loss = 0.7000).\n        epochs_trained = 0\n  
      early_stop.on_train_begin()\n        for epoch in range(len(losses)):\n            epochs_trained += 1\n            early_stop.model.set_weight_to_epoch(epoch=epoch)\n            early_stop.on_epoch_end(epoch, logs={\"val_loss\": losses[epoch]})\n            if epoch == 3:\n                early_stop.model.stop_training = True\n            if early_stop.model.stop_training:\n                break\n        early_stop.on_train_end()\n        # We should restore the second model.\n        self.assertEqual(epochs_trained, 4)\n        self.assertEqual(early_stop.model.get_weights(), 2)\n\n    @pytest.mark.requires_trainable_backend\n    def test_early_stopping_with_start_from_epoch(self):\n        x_train = np.random.random((10, 5))\n        y_train = np.random.random((10, 1))\n        model = models.Sequential(\n            (\n                layers.Dense(1, activation=\"relu\"),\n                layers.Dense(1, activation=\"relu\"),\n            )\n        )\n        model.compile(optimizer=\"sgd\", loss=\"mae\", metrics=[\"mse\"])\n        start_from_epoch = 2\n        patience = 3\n        stopper = callbacks.EarlyStopping(\n            monitor=\"mse\",\n            patience=patience,\n            start_from_epoch=start_from_epoch,\n        )\n        history = model.fit(\n            x_train, y_train, callbacks=[stopper], verbose=0, epochs=20\n        )\n        # Test 'patience' argument functions correctly when used\n        # in conjunction with 'start_from_epoch'.\n        self.assertGreaterEqual(len(history.epoch), patience + start_from_epoch)\n\n        start_from_epoch = 2\n        patience = 0\n        stopper = callbacks.EarlyStopping(\n            monitor=\"mse\",\n            patience=patience,\n            start_from_epoch=start_from_epoch,\n        )\n        history = model.fit(\n            x_train, y_train, callbacks=[stopper], verbose=0, epochs=20\n        )\n        # Test for boundary condition when 'patience' = 0.\n        
self.assertGreaterEqual(len(history.epoch), start_from_epoch)\n"
  },
  {
    "path": "keras/src/callbacks/history.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.callbacks.callback import Callback\n\n\n@keras_export(\"keras.callbacks.History\")\nclass History(Callback):\n    \"\"\"Callback that records events into a `History` object.\n\n    This callback is automatically applied to\n    every Keras model. The `History` object\n    gets returned by the `fit()` method of models.\n\n    Example:\n\n    >>> model = Sequential([layers.Dense(10)])\n    >>> model.compile(SGD(), loss='mse')\n    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),\n    ...                     epochs=10, verbose=1)\n    >>> print(history.params)\n    {'verbose': 1, 'epochs': 10, 'steps': 1}\n    >>> # check the keys of history object\n    >>> print(history.history.keys())\n    dict_keys(['loss'])\n\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.history = {}\n\n    def on_train_begin(self, logs=None):\n        self.epoch = []\n\n    def on_epoch_end(self, epoch, logs=None):\n        logs = logs or {}\n        self.epoch.append(epoch)\n        for k, v in logs.items():\n            self.history.setdefault(k, []).append(v)\n\n        # Set the history attribute on the model after the epoch ends. This will\n        # make sure that the state which is set is the latest one.\n        self.model.history = self\n"
  },
  {
    "path": "keras/src/callbacks/lambda_callback.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.callbacks.callback import Callback\n\n\n@keras_export(\"keras.callbacks.LambdaCallback\")\nclass LambdaCallback(Callback):\n    \"\"\"Callback for creating simple, custom callbacks on-the-fly.\n\n    This callback is constructed with anonymous functions that will be called\n    at the appropriate time (during `Model.{fit | evaluate | predict}`).\n    Note that the callbacks expect positional arguments, as follows:\n\n    - `on_epoch_begin` and `on_epoch_end` expect two positional arguments:\n      `epoch`, `logs`\n    - `on_train_begin` and `on_train_end` expect one positional argument:\n      `logs`\n    - `on_train_batch_begin` and `on_train_batch_end` expect a positional\n      argument `batch` and a keyword argument `logs`\n    - See `Callback` class definition for the full list of functions and their\n      expected arguments.\n\n    Args:\n        on_epoch_begin: called at the beginning of every epoch.\n        on_epoch_end: called at the end of every epoch.\n        on_train_begin: called at the beginning of model training.\n        on_train_end: called at the end of model training.\n        on_train_batch_begin: called at the beginning of every train batch.\n        on_train_batch_end: called at the end of every train batch.\n        kwargs: Any function in `Callback` that you want to override by\n            passing `function_name=function`. For example,\n            `LambdaCallback(..., on_train_end=train_end_fn)`. The custom\n            function needs to have the same arguments as the ones defined in\n            `Callback`.\n\n    Example:\n\n    ```python\n    # Print the batch number at the beginning of every batch.\n    batch_print_callback = LambdaCallback(\n        on_train_batch_begin=lambda batch, logs: print(batch))\n\n    # Stream the epoch loss to a file in JSON format. 
The file content\n    # is not well-formed JSON but rather has a JSON object per line.\n    import json\n    json_log = open('loss_log.json', mode='wt', buffering=1)\n    json_logging_callback = LambdaCallback(\n        on_epoch_end=lambda epoch, logs: json_log.write(\n            json.dumps({'epoch': epoch, 'loss': logs['loss']}) + '\\n'),\n        on_train_end=lambda logs: json_log.close()\n    )\n\n    # Terminate some processes after having finished model training.\n    processes = ...\n    cleanup_callback = LambdaCallback(\n        on_train_end=lambda logs: [\n            p.terminate() for p in processes if p.is_alive()])\n\n    model.fit(...,\n              callbacks=[batch_print_callback,\n                         json_logging_callback,\n                         cleanup_callback])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        on_epoch_begin=None,\n        on_epoch_end=None,\n        on_train_begin=None,\n        on_train_end=None,\n        on_train_batch_begin=None,\n        on_train_batch_end=None,\n        **kwargs,\n    ):\n        super().__init__()\n        self.__dict__.update(kwargs)\n        if on_epoch_begin is not None:\n            self.on_epoch_begin = on_epoch_begin\n        if on_epoch_end is not None:\n            self.on_epoch_end = on_epoch_end\n        if on_train_begin is not None:\n            self.on_train_begin = on_train_begin\n        if on_train_end is not None:\n            self.on_train_end = on_train_end\n        if on_train_batch_begin is not None:\n            self.on_train_batch_begin = on_train_batch_begin\n        if on_train_batch_end is not None:\n            self.on_train_batch_end = on_train_batch_end\n"
  },
  {
    "path": "keras/src/callbacks/lambda_callback_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl import logging\n\nfrom keras.src import callbacks\nfrom keras.src import layers\nfrom keras.src import losses\nfrom keras.src import optimizers\nfrom keras.src import testing\nfrom keras.src.models.sequential import Sequential\n\n\nclass LambdaCallbackTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_lambda_callback(self):\n        \"\"\"Test standard LambdaCallback functionalities with training.\"\"\"\n        batch_size = 4\n        model = Sequential(\n            [layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]\n        )\n        model.compile(\n            optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()\n        )\n        x = np.random.randn(16, 2)\n        y = np.random.randn(16, 1)\n        lambda_log_callback = callbacks.LambdaCallback(\n            on_train_begin=lambda logs: logging.warning(\"on_train_begin\"),\n            on_epoch_begin=lambda epoch, logs: logging.warning(\n                \"on_epoch_begin\"\n            ),\n            on_epoch_end=lambda epoch, logs: logging.warning(\"on_epoch_end\"),\n            on_train_end=lambda logs: logging.warning(\"on_train_end\"),\n        )\n        with self.assertLogs(level=\"WARNING\") as logs:\n            model.fit(\n                x,\n                y,\n                batch_size=batch_size,\n                validation_split=0.2,\n                callbacks=[lambda_log_callback],\n                epochs=5,\n                verbose=0,\n            )\n            self.assertTrue(any(\"on_train_begin\" in log for log in logs.output))\n            self.assertTrue(any(\"on_epoch_begin\" in log for log in logs.output))\n            self.assertTrue(any(\"on_epoch_end\" in log for log in logs.output))\n            self.assertTrue(any(\"on_train_end\" in log for log in logs.output))\n\n    @pytest.mark.requires_trainable_backend\n    def test_lambda_callback_with_batches(self):\n        
\"\"\"Test LambdaCallback's behavior with batch-level callbacks.\"\"\"\n        batch_size = 4\n        model = Sequential(\n            [layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]\n        )\n        model.compile(\n            optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()\n        )\n        x = np.random.randn(16, 2)\n        y = np.random.randn(16, 1)\n        lambda_log_callback = callbacks.LambdaCallback(\n            on_train_batch_begin=lambda batch, logs: logging.warning(\n                \"on_train_batch_begin\"\n            ),\n            on_train_batch_end=lambda batch, logs: logging.warning(\n                \"on_train_batch_end\"\n            ),\n        )\n        with self.assertLogs(level=\"WARNING\") as logs:\n            model.fit(\n                x,\n                y,\n                batch_size=batch_size,\n                validation_split=0.2,\n                callbacks=[lambda_log_callback],\n                epochs=5,\n                verbose=0,\n            )\n            self.assertTrue(\n                any(\"on_train_batch_begin\" in log for log in logs.output)\n            )\n            self.assertTrue(\n                any(\"on_train_batch_end\" in log for log in logs.output)\n            )\n\n    @pytest.mark.requires_trainable_backend\n    def test_lambda_callback_with_kwargs(self):\n        \"\"\"Test LambdaCallback's behavior with custom defined callback.\"\"\"\n        batch_size = 4\n        model = Sequential(\n            [layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]\n        )\n        model.compile(\n            optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()\n        )\n        x = np.random.randn(16, 2)\n        y = np.random.randn(16, 1)\n        model.fit(\n            x, y, batch_size=batch_size, epochs=1, verbose=0\n        )  # Train briefly for evaluation to work.\n\n        def custom_on_test_begin(logs):\n            
logging.warning(\"custom_on_test_begin_executed\")\n\n        lambda_log_callback = callbacks.LambdaCallback(\n            on_test_begin=custom_on_test_begin\n        )\n        with self.assertLogs(level=\"WARNING\") as logs:\n            model.evaluate(\n                x,\n                y,\n                batch_size=batch_size,\n                callbacks=[lambda_log_callback],\n                verbose=0,\n            )\n            self.assertTrue(\n                any(\n                    \"custom_on_test_begin_executed\" in log\n                    for log in logs.output\n                )\n            )\n\n    @pytest.mark.requires_trainable_backend\n    def test_lambda_callback_no_args(self):\n        \"\"\"Test initializing LambdaCallback without any arguments.\"\"\"\n        lambda_callback = callbacks.LambdaCallback()\n        self.assertIsInstance(lambda_callback, callbacks.LambdaCallback)\n\n    @pytest.mark.requires_trainable_backend\n    def test_lambda_callback_with_additional_kwargs(self):\n        \"\"\"Test initializing LambdaCallback with non-predefined kwargs.\"\"\"\n\n        def custom_callback(logs):\n            pass\n\n        lambda_callback = callbacks.LambdaCallback(\n            custom_method=custom_callback\n        )\n        self.assertTrue(hasattr(lambda_callback, \"custom_method\"))\n\n    @pytest.mark.requires_trainable_backend\n    def test_lambda_callback_during_prediction(self):\n        \"\"\"Test LambdaCallback's functionality during model prediction.\"\"\"\n        batch_size = 4\n        model = Sequential(\n            [layers.Input(shape=(2,), batch_size=batch_size), layers.Dense(1)]\n        )\n        model.compile(\n            optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()\n        )\n        x = np.random.randn(16, 2)\n\n        def custom_on_predict_begin(logs):\n            logging.warning(\"on_predict_begin_executed\")\n\n        lambda_callback = callbacks.LambdaCallback(\n            
on_predict_begin=custom_on_predict_begin\n        )\n        with self.assertLogs(level=\"WARNING\") as logs:\n            model.predict(\n                x, batch_size=batch_size, callbacks=[lambda_callback], verbose=0\n            )\n            self.assertTrue(\n                any(\"on_predict_begin_executed\" in log for log in logs.output)\n            )\n"
  },
  {
    "path": "keras/src/callbacks/learning_rate_scheduler.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.callbacks.callback import Callback\nfrom keras.src.utils import io_utils\n\n\n@keras_export(\"keras.callbacks.LearningRateScheduler\")\nclass LearningRateScheduler(Callback):\n    \"\"\"Learning rate scheduler.\n\n    At the beginning of every epoch, this callback gets the updated learning\n    rate value from `schedule` function provided at `__init__`, with the current\n    epoch and current learning rate, and applies the updated learning rate on\n    the optimizer.\n\n    Args:\n        schedule: A function that takes an epoch index (integer, indexed from 0)\n            and current learning rate (float) as inputs and returns a new\n            learning rate as output (float).\n        verbose: Integer. 0: quiet, 1: log update messages.\n\n    Example:\n\n    >>> # This function keeps the initial learning rate for the first ten epochs\n    >>> # and decreases it exponentially after that.\n    >>> def scheduler(epoch, lr):\n    ...     if epoch < 10:\n    ...         return lr\n    ...     else:\n    ...         return lr * ops.exp(-0.1)\n    >>>\n    >>> model = keras.models.Sequential([keras.layers.Dense(10)])\n    >>> model.compile(keras.optimizers.SGD(), loss='mse')\n    >>> round(model.optimizer.learning_rate, 5)\n    0.01\n\n    >>> callback = keras.callbacks.LearningRateScheduler(scheduler)\n    >>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),\n    ...                     
epochs=15, callbacks=[callback], verbose=0)\n    >>> round(model.optimizer.learning_rate, 5)\n    0.00607\n\n    \"\"\"\n\n    def __init__(self, schedule, verbose=0):\n        super().__init__()\n        self.schedule = schedule\n        self.verbose = verbose\n\n    def on_epoch_begin(self, epoch, logs=None):\n        if not hasattr(self.model.optimizer, \"learning_rate\"):\n            raise ValueError('Optimizer must have a \"learning_rate\" attribute.')\n\n        try:  # new API\n            learning_rate = float(\n                backend.convert_to_numpy(self.model.optimizer.learning_rate)\n            )\n            learning_rate = self.schedule(epoch, learning_rate)\n        except TypeError:  # Support for old API for backward compatibility\n            learning_rate = self.schedule(epoch)\n\n        if not isinstance(learning_rate, (float, np.float32, np.float64)):\n            raise ValueError(\n                \"The output of the `schedule` function should be a float. \"\n                f\"Got: {learning_rate}\"\n            )\n\n        self.model.optimizer.learning_rate = learning_rate\n        if self.verbose > 0:\n            io_utils.print_msg(\n                f\"\\nEpoch {epoch + 1}: LearningRateScheduler setting learning \"\n                f\"rate to {learning_rate}.\"\n            )\n\n    def on_epoch_end(self, epoch, logs=None):\n        logs = logs or {}\n        logs[\"learning_rate\"] = float(\n            backend.convert_to_numpy(self.model.optimizer.learning_rate)\n        )\n"
  },
  {
    "path": "keras/src/callbacks/learning_rate_scheduler_test.py",
    "content": "import pytest\n\nfrom keras.src import callbacks\nfrom keras.src import layers\nfrom keras.src import optimizers\nfrom keras.src import testing\nfrom keras.src.models import Sequential\nfrom keras.src.testing import test_utils\nfrom keras.src.utils import io_utils\nfrom keras.src.utils import numerical_utils\n\n\nclass LearningRateSchedulerTest(testing.TestCase):\n    def setUp(self):\n        (x_train, y_train), _ = test_utils.get_test_data(\n            train_samples=10,\n            test_samples=10,\n            input_shape=(3,),\n            num_classes=2,\n        )\n        y_train = numerical_utils.to_categorical(y_train)\n\n        model = Sequential([layers.Dense(5), layers.Dense(2)])\n\n        model.compile(\n            loss=\"mse\",\n            optimizer=\"sgd\",\n        )\n\n        self.model = model\n        self.x_train = x_train\n        self.y_train = y_train\n\n    @pytest.mark.requires_trainable_backend\n    def test_updates_learning_rate(self):\n        lr_scheduler = callbacks.LearningRateScheduler(\n            lambda step: 1.0 / (2.0 + step), verbose=1\n        )\n\n        self.model.fit(\n            self.x_train,\n            self.y_train,\n            callbacks=[lr_scheduler],\n            epochs=1,\n        )\n\n        self.assertEqual(self.model.optimizer.learning_rate.value, 0.5)\n\n    @pytest.mark.requires_trainable_backend\n    def test_verbose_logging(self):\n        lr_scheduler = callbacks.LearningRateScheduler(\n            lambda step: 1.0 / (1.0 + step), verbose=1\n        )\n        io_utils.disable_interactive_logging()\n        io_utils.set_logging_verbosity(\"INFO\")\n\n        with self.assertLogs() as logs:\n            self.model.fit(\n                self.x_train,\n                self.y_train,\n                callbacks=[lr_scheduler],\n                epochs=1,\n            )\n            expected_log = \"LearningRateScheduler setting learning rate to 1.0\"\n            
self.assertTrue(any(expected_log in log for log in logs.output))\n\n    @pytest.mark.requires_trainable_backend\n    def test_schedule_dependent_on_previous_learning_rate(self):\n        lr_scheduler = callbacks.LearningRateScheduler(lambda step, lr: lr / 2)\n\n        initial_lr = 0.03\n        self.model.compile(\n            loss=\"mse\",\n            optimizer=optimizers.Adam(initial_lr),\n        )\n\n        self.model.fit(\n            self.x_train,\n            self.y_train,\n            callbacks=[lr_scheduler],\n            epochs=2,\n        )\n        self.assertEqual(\n            self.model.optimizer.learning_rate.value, initial_lr / 4.0\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_throws_when_optimizer_has_schedule(self):\n        lr_scheduler = callbacks.LearningRateScheduler(lambda step, lr: lr / 2)\n\n        self.model.compile(\n            loss=\"mse\",\n            optimizer=optimizers.Adam(\n                optimizers.schedules.PolynomialDecay(\n                    initial_learning_rate=0.1, decay_steps=10\n                )\n            ),\n        )\n\n        with self.assertRaisesRegex(\n            TypeError,\n            \"This optimizer was created with a `LearningRateSchedule`\",\n        ):\n            self.model.fit(\n                self.x_train,\n                self.y_train,\n                callbacks=[lr_scheduler],\n                epochs=2,\n            )\n\n    @pytest.mark.requires_trainable_backend\n    def test_learning_rate_in_history(self):\n        lr_scheduler = callbacks.LearningRateScheduler(lambda step, lr: 0.5)\n\n        history = self.model.fit(\n            self.x_train,\n            self.y_train,\n            callbacks=[lr_scheduler],\n            epochs=1,\n        )\n\n        self.assertTrue(\"learning_rate\" in history.history)\n        self.assertEqual(type(history.history[\"learning_rate\"][0]), float)\n        self.assertEqual(history.history[\"learning_rate\"][0], 0.5)\n"
  },
  {
    "path": "keras/src/callbacks/model_checkpoint.py",
    "content": "import os\nimport re\nimport warnings\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.callbacks.monitor_callback import MonitorCallback\nfrom keras.src.utils import file_utils\nfrom keras.src.utils import io_utils\n\n\n@keras_export(\"keras.callbacks.ModelCheckpoint\")\nclass ModelCheckpoint(MonitorCallback):\n    \"\"\"Callback to save the Keras model or model weights at some frequency.\n\n    `ModelCheckpoint` callback is used in conjunction with training using\n    `model.fit()` to save a model or weights (in a checkpoint file) at some\n    interval, so the model or weights can be loaded later to continue the\n    training from the state saved.\n\n    A few options this callback provides include:\n\n    - Whether to only keep the model that has achieved the \"best performance\" so\n      far, or whether to save the model at the end of every epoch regardless of\n      performance.\n    - Definition of \"best\"; which quantity to monitor and whether it should be\n      maximized or minimized.\n    - The frequency it should save at. 
Currently, the callback supports saving\n      at the end of every epoch, or after a fixed number of training batches.\n    - Whether only weights are saved, or the whole model is saved.\n\n    Example:\n\n    ```python\n    # Define a learning rate schedule\n    initial_learning_rate = 0.1\n    lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n        initial_learning_rate,\n        decay_steps=100000,\n        decay_rate=0.96,\n        staircase=True,\n    )\n\n    model.compile(\n        optimizer=keras.optimizers.RMSprop(learning_rate=lr_schedule),\n        loss=\"sparse_categorical_crossentropy\",\n        metrics=[\"accuracy\"],\n    )\n\n    checkpoint_filepath = '/tmp/ckpt/checkpoint.model.keras'\n    model_checkpoint_callback = keras.callbacks.ModelCheckpoint(\n        filepath=checkpoint_filepath,\n        monitor='val_accuracy',\n        mode='max',\n        save_best_only=True)\n\n    # Model is saved at the end of every epoch, if it's the best seen so far.\n    model.fit(\n        x_train, y_train, epochs=10, callbacks=[model_checkpoint_callback]\n    )\n\n    # The model (that is considered the best) can be loaded as -\n    keras.models.load_model(checkpoint_filepath)\n\n    # Alternatively, one could checkpoint just the model weights as -\n    checkpoint_filepath = '/tmp/ckpt/checkpoint.weights.h5'\n    model_checkpoint_callback = keras.callbacks.ModelCheckpoint(\n        filepath=checkpoint_filepath,\n        save_weights_only=True,\n        monitor='val_accuracy',\n        mode='max',\n        save_best_only=True)\n\n    # Model weights are saved at the end of every epoch, if it's the best seen\n    # so far.\n    model.fit(\n        x_train, y_train, epochs=10, callbacks=[model_checkpoint_callback]\n    )\n\n    # The model weights (that are considered the best) can be loaded as -\n    model.load_weights(checkpoint_filepath)\n    ```\n\n    Resuming training from weight-only checkpoints:\n\n    When using `save_weights_only=True`, the 
weights file includes the state\n    of the optimizer (including iteration count and learning rate state)\n    if the model is compiled at the time of saving.\n\n    To correctly resume training and restore the optimizer state (e.g., to\n    continue a learning rate schedule without resetting it), you must\n    **compile the model before loading the weights**.\n\n    ```python\n    # Define a learning rate schedule\n    initial_learning_rate = 0.1\n    lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n        initial_learning_rate,\n        decay_steps=100000,\n        decay_rate=0.96,\n        staircase=True,\n    )\n\n    # 1. Create a fresh model instance\n    model = get_model()\n\n    # 2. Compile the model *before* loading weights\n    model.compile(\n        optimizer=keras.optimizers.RMSprop(learning_rate=lr_schedule),\n        loss=\"sparse_categorical_crossentropy\",\n        metrics=[\"accuracy\"],\n    )\n\n    # 3. Load weights (optimizer state is restored automatically)\n    model.load_weights(checkpoint_filepath)\n\n    # 4. Continue training\n    model.fit(x_train, y_train, epochs=10)\n    ```\n\n    Args:\n        filepath: string or `PathLike`, path to save the model file.\n            `filepath` can contain named formatting options,\n            which will be filled with the value of `epoch` and keys in `logs`\n            (passed in `on_epoch_end`).\n            The `filepath` name needs to end with `\".weights.h5\"` when\n            `save_weights_only=True` or should end with `\".keras\"` or `\".h5\"`\n            when checkpointing the whole model (default).\n            For example:\n            if `filepath` is `\"{epoch:02d}-{val_loss:.2f}.keras\"` or\n            `\"{epoch:02d}-{val_loss:.2f}.weights.h5\"`, then the model\n            checkpoints will be saved with the epoch number and the validation\n            loss in the filename. 
The directory of the filepath\n            should not be reused by any other callbacks to avoid conflicts.\n        monitor: The metric name to monitor. Typically the metrics are set by\n            the `Model.compile` method. Note:\n            * Prefix the name with `\"val_\"` to monitor validation metrics.\n            * Use `\"loss\"` or `\"val_loss\"` to monitor the model's total loss.\n            * If you specify metrics as strings, like `\"accuracy\"`, pass the\n                same string (with or without the `\"val_\"` prefix).\n            * If you pass `metrics.Metric` objects, `monitor` should be set to\n                `metric.name`\n            * If you're not sure about the metric names you can check the\n                contents of the `history.history` dictionary returned by\n                `history = model.fit()`\n            * Multi-output models set additional prefixes on the metric names.\n        verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1\n            displays messages when the callback takes an action.\n        save_best_only: if `save_best_only=True`, it only saves when the model\n            is considered the \"best\" and the latest best model according to the\n            quantity monitored will not be overwritten. If `filepath` doesn't\n            contain formatting options like `{epoch}` then `filepath` will be\n            overwritten by each new better model.\n        mode: one of {`\"auto\"`, `\"min\"`, `\"max\"`}. If `save_best_only=True`, the\n            decision to overwrite the current save file is made based on either\n            the maximization or the minimization of the monitored quantity.\n            For `val_acc`, this should be `\"max\"`, for `val_loss` this should be\n            `\"min\"`, etc. 
In `\"auto\"` mode, the direction is automatically\n            inferred from the name of the monitored quantity.\n        save_weights_only: if `True`, then only the model's weights will be\n            saved (`model.save_weights(filepath)`), else the full model is\n            saved (`model.save(filepath)`).\n        save_freq: `\"epoch\"` or integer. When using `\"epoch\"`, the callback\n            saves the model after each epoch. When using integer, the callback\n            saves the model at end of this many batches. If the `Model` is\n            compiled with `steps_per_execution=N`, then the saving criteria will\n            be checked every Nth batch. Note that if the saving isn't aligned to\n            epochs, the monitored metric may potentially be less reliable (it\n            could reflect as little as 1 batch, since the metrics get reset\n            every epoch). Defaults to `\"epoch\"`.\n        initial_value_threshold: Floating point initial \"best\" value of the\n            metric to be monitored. Only applies if `save_best_value=True`. 
Only\n            overwrites the model weights already saved if the performance of\n            current model is better than this value.\n    \"\"\"\n\n    def __init__(\n        self,\n        filepath,\n        monitor=\"val_loss\",\n        verbose=0,\n        save_best_only=False,\n        save_weights_only=False,\n        mode=\"auto\",\n        save_freq=\"epoch\",\n        initial_value_threshold=None,\n    ):\n        super().__init__(monitor, mode, initial_value_threshold)\n        self.verbose = verbose\n        self.filepath = file_utils.path_to_string(filepath)\n        self.save_best_only = save_best_only\n        self.save_weights_only = save_weights_only\n        self.save_freq = save_freq\n        self._batches_seen_since_last_saving = 0\n        self._last_batch_seen = None\n\n        if self.save_freq != \"epoch\" and not isinstance(self.save_freq, int):\n            raise ValueError(\n                f\"Unrecognized save_freq: {self.save_freq}. \"\n                \"Expected save_freq are 'epoch' or integer values\"\n            )\n\n        if save_weights_only:\n            if not self.filepath.endswith(\".weights.h5\"):\n                raise ValueError(\n                    \"When using `save_weights_only=True` in `ModelCheckpoint`\"\n                    \", the filepath provided must end in `.weights.h5` \"\n                    \"(Keras weights format). Received: \"\n                    f\"filepath={self.filepath}\"\n                )\n        else:\n            if not any(\n                self.filepath.endswith(ext) for ext in (\".keras\", \".h5\")\n            ):\n                raise ValueError(\n                    \"The filepath provided must end in `.keras` \"\n                    \"(Keras model format). 
Received: \"\n                    f\"filepath={self.filepath}\"\n                )\n\n    def on_train_batch_end(self, batch, logs=None):\n        if self._should_save_on_batch(batch):\n            self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)\n\n    def on_epoch_begin(self, epoch, logs=None):\n        self._current_epoch = epoch\n\n    def on_epoch_end(self, epoch, logs=None):\n        if self.monitor_op is None:\n            # Delay setup until the model's metrics are all built\n            self._set_monitor_op()\n\n        if self.save_freq == \"epoch\":\n            self._save_model(epoch=epoch, batch=None, logs=logs)\n\n    def _should_save_on_batch(self, batch):\n        \"\"\"Handles batch-level saving logic, supports steps_per_execution.\"\"\"\n        if self.save_freq == \"epoch\":\n            return False\n        if self._last_batch_seen is None or batch <= self._last_batch_seen:\n            # New epoch.\n            add_batches = batch + 1  # batches are zero-indexed.\n        else:\n            add_batches = batch - self._last_batch_seen\n        self._batches_seen_since_last_saving += add_batches\n        self._last_batch_seen = batch\n\n        if self._batches_seen_since_last_saving >= self.save_freq:\n            self._batches_seen_since_last_saving = 0\n            return True\n        return False\n\n    def _should_save_model(self, epoch, batch, logs, filepath):\n        \"\"\"Determines whether the model should be saved.\n\n        The model should be saved in the following cases:\n\n        - self.save_best_only is False\n        - self.save_best_only is True and `monitor` is a numpy array or\n          backend tensor (falls back to `save_best_only=False`)\n        - self.save_best_only is True and `self.monitor_op(current, self.best)`\n          evaluates to True.\n\n        Args:\n            epoch: the epoch this iteration is in.\n            batch: the batch this iteration is in. 
`None` if the `save_freq`\n                is set to `\"epoch\"`.\n            logs: the `logs` dict passed in to `on_batch_end` or\n                `on_epoch_end`.\n            filepath: the path where the model would be saved\n        \"\"\"\n        logs = logs or {}\n        if self.save_best_only:\n            current = logs.get(self.monitor)\n            if current is None:\n                warnings.warn(\n                    f\"Can save best model only with {self.monitor} available.\",\n                    stacklevel=2,\n                )\n                return True\n            elif (\n                isinstance(current, np.ndarray) or backend.is_tensor(current)\n            ) and len(current.shape) > 0:\n                warnings.warn(\n                    \"Can save best model only when `monitor` is \"\n                    f\"a scalar value. Received: {current}. \"\n                    \"Falling back to `save_best_only=False`.\"\n                )\n                return True\n            else:\n                best_str = \"None\" if self.best is None else f\"{self.best:.5f}\"\n                if self._is_improvement(current, self.best):\n                    if self.verbose > 0:\n                        io_utils.print_msg(\n                            f\"\\nEpoch {epoch + 1}: {self.monitor} \"\n                            f\"improved from {best_str} to {current:.5f}, \"\n                            f\"saving model to {filepath}\"\n                        )\n                    self.best = current\n                    return True\n                else:\n                    if self.verbose > 0:\n                        io_utils.print_msg(\n                            f\"\\nEpoch {epoch + 1}: \"\n                            f\"{self.monitor} did not improve from {best_str}\"\n                        )\n                    return False\n        else:\n            if self.verbose > 0:\n                io_utils.print_msg(\n                    f\"\\nEpoch {epoch + 1}: saving model to {filepath}\"\n                )\n            return True\n\n    def _save_model(self, epoch, batch, logs):\n        \"\"\"Saves the model.\n\n        Args:\n            epoch: the epoch this iteration is in.\n            batch: the batch this iteration is in. `None` if the `save_freq`\n                is set to `\"epoch\"`.\n            logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.\n        \"\"\"\n        filepath = self._get_file_path(epoch, batch, logs)\n\n        try:\n            if self._should_save_model(epoch, batch, logs, filepath):\n                # Create host directory if it doesn't exist.\n                dirname = os.path.dirname(filepath)\n                if dirname and not file_utils.exists(dirname):\n                    file_utils.makedirs(dirname)\n\n                if self.save_weights_only:\n                    self.model.save_weights(filepath, overwrite=True)\n                else:\n                    self.model.save(filepath, overwrite=True)\n                if self.verbose > 0:\n                    io_utils.print_msg(\n                        f\"\\nEpoch {epoch + 1}: \"\n                        f\"finished saving model to {filepath}\"\n                    )\n        except IsADirectoryError:  # h5py 3.x\n            raise IOError(\n                \"Please specify a non-directory filepath for \"\n                \"ModelCheckpoint. Filepath used is an existing \"\n                f\"directory: {filepath}\"\n            )\n        except IOError as e:  # h5py 2.x\n            # `e.errno` appears to be `None` so checking the content of\n            # `e.args[0]`.\n            if \"is a directory\" in str(e.args[0]).lower():\n                raise IOError(\n                    \"Please specify a non-directory filepath for \"\n                    \"ModelCheckpoint. 
Filepath used is an existing \"\n                    f\"directory: {filepath}\"\n                )\n            # Re-throw the error for any other causes.\n            raise e\n\n    def _get_file_path(self, epoch, batch, logs):\n        \"\"\"Returns the file path for checkpoint.\"\"\"\n\n        try:\n            # `filepath` may contain placeholders such as\n            # `{epoch:02d}`, `{batch:02d}` and `{mape:.2f}`. A mismatch between\n            # logged metrics and the path's placeholders can cause formatting to\n            # fail.\n            if batch is None or \"batch\" in logs:\n                file_path = self.filepath.format(epoch=epoch + 1, **logs)\n            else:\n                file_path = self.filepath.format(\n                    epoch=epoch + 1, batch=batch + 1, **logs\n                )\n        except KeyError as e:\n            raise KeyError(\n                f'Failed to format this callback filepath: \"{self.filepath}\". '\n                f\"Reason: {e}\"\n            )\n        return file_path\n\n    def _checkpoint_exists(self, filepath):\n        \"\"\"Returns whether the checkpoint that `filepath` refers to exists.\"\"\"\n        return file_utils.exists(filepath)\n\n    def _get_most_recently_modified_file_matching_pattern(self, pattern):\n        \"\"\"Returns the most recently modified filepath matching pattern.\n\n        In the rare case where there is more than one pattern-matching file\n        having the same modified time that is most recent among all, return the\n        filepath that is largest (by `>` operator, lexicographically using the\n        numeric equivalents). This provides a tie-breaker when multiple files\n        are most recent. 
Note that a larger `filepath` can sometimes indicate a\n        later time of modification (for instance, when epoch/batch is used as a\n        formatting option), but not necessarily (when accuracy or loss is\n        used). The tie-breaker is put in the logic as a best effort to return\n        the most recent file and to avoid a nondeterministic result.\n\n        Modified time of a file is obtained with `os.path.getmtime()`.\n\n        This utility function is best demonstrated via an example:\n\n        ```python\n        file_pattern = 'batch{batch:02d}epoch{epoch:02d}.keras'\n        test_dir = self.get_temp_dir()\n        path_pattern = os.path.join(test_dir, file_pattern)\n        file_paths = [\n            os.path.join(test_dir, file_name) for file_name in\n            ['batch03epoch02.keras',\n             'batch02epoch02.keras', 'batch01epoch01.keras']\n        ]\n        for file_path in file_paths:\n            # Write something to each of the files\n            ...\n        self.assertEqual(\n            _get_most_recently_modified_file_matching_pattern(path_pattern),\n            file_paths[-1])\n        ```\n\n        Args:\n            pattern: The file pattern that may optionally contain Python\n                placeholders such as `{epoch:02d}`.\n\n        Returns:\n            The most recently modified file's full filepath matching `pattern`.\n            If `pattern` does not contain any placeholder, this returns the\n            filepath that exactly matches `pattern`. 
Returns `None` if no match\n            is found.\n        \"\"\"\n        dir_name = os.path.dirname(pattern)\n        base_name = os.path.basename(pattern)\n        base_name_regex = f\"^{re.sub(r'{.*}', r'.*', base_name)}$\"\n\n        latest_mod_time = 0\n        file_path_with_latest_mod_time = None\n        n_file_with_latest_mod_time = 0\n        file_path_with_largest_file_name = None\n\n        if file_utils.exists(dir_name):\n            for file_name in os.listdir(dir_name):\n                # Only consider if `file_name` matches the pattern.\n                if re.match(base_name_regex, file_name):\n                    file_path = os.path.join(dir_name, file_name)\n                    mod_time = os.path.getmtime(file_path)\n                    if (\n                        file_path_with_largest_file_name is None\n                        or file_path > file_path_with_largest_file_name\n                    ):\n                        file_path_with_largest_file_name = file_path\n                    if mod_time > latest_mod_time:\n                        latest_mod_time = mod_time\n                        file_path_with_latest_mod_time = file_path\n                        # In the case a file with later modified time is found,\n                        # reset the counter for the number of files with latest\n                        # modified time.\n                        n_file_with_latest_mod_time = 1\n                    elif mod_time == latest_mod_time:\n                        # In the case a file has modified time tied with the\n                        # most recent, increment the counter for the number of\n                        # files with latest modified time by 1.\n                        n_file_with_latest_mod_time += 1\n\n        if n_file_with_latest_mod_time == 1:\n            # Return the sole file that has most recent modified time.\n            return file_path_with_latest_mod_time\n        else:\n            # If there are more than 
one file having latest modified time,\n            # return the file path with the largest file name.\n            return file_path_with_largest_file_name\n"
  },
  {
    "path": "keras/src/callbacks/model_checkpoint_test.py",
    "content": "import os\nimport warnings\n\nimport pytest\n\nfrom keras.src import callbacks\nfrom keras.src import layers\nfrom keras.src import metrics\nfrom keras.src import models\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.models import Sequential\nfrom keras.src.testing import test_utils\nfrom keras.src.utils import numerical_utils\n\ntry:\n    import h5py\nexcept ImportError:\n    h5py = None\n\nTRAIN_SAMPLES = 30\nTEST_SAMPLES = 30\nNUM_CLASSES = 3\nINPUT_DIM = 3\nNUM_HIDDEN = 5\nBATCH_SIZE = 5\n\n\nclass ModelCheckpointTest(testing.TestCase):\n    @pytest.mark.skipif(\n        h5py is None,\n        reason=\"`h5py` is a required dependency for `ModelCheckpoint` tests.\",\n    )\n    @pytest.mark.skipif(\n        testing.jax_uses_gpu(),\n        reason=\"Mysterious core dump on CI after upgrading JAX\",\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_model_checkpoint_options(self):\n        def get_model():\n            model = Sequential(\n                [\n                    layers.Dense(NUM_HIDDEN, activation=\"relu\"),\n                    layers.Dense(NUM_CLASSES, activation=\"softmax\"),\n                ]\n            )\n            model.compile(\n                loss=\"categorical_crossentropy\",\n                optimizer=\"sgd\",\n                metrics=[metrics.Accuracy(\"acc\")],\n            )\n            return model\n\n        model = get_model()\n        temp_dir = self.get_temp_dir()\n\n        # Save model to a subdir inside the temp_dir so we can test\n        # automatic directory creation.\n        filepath = os.path.join(temp_dir, \"subdir\", \"checkpoint.keras\")\n        (x_train, y_train), (x_test, y_test) = test_utils.get_test_data(\n            random_seed=42,\n            train_samples=TRAIN_SAMPLES,\n            test_samples=TEST_SAMPLES,\n            input_shape=(INPUT_DIM,),\n            num_classes=NUM_CLASSES,\n        )\n        y_test = 
numerical_utils.to_categorical(y_test, num_classes=NUM_CLASSES)\n        y_train = numerical_utils.to_categorical(\n            y_train, num_classes=NUM_CLASSES\n        )\n\n        # Case 1\n        monitor = \"val_loss\"\n        save_best_only = False\n        mode = \"auto\"\n\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath,\n                monitor=monitor,\n                save_best_only=save_best_only,\n                mode=mode,\n            )\n        ]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n        self.assertTrue(os.path.exists(filepath))\n        os.remove(filepath)\n\n        # Case 2\n        mode = \"min\"\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath,\n                monitor=monitor,\n                save_best_only=save_best_only,\n                mode=mode,\n            )\n        ]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n        self.assertTrue(os.path.exists(filepath))\n        os.remove(filepath)\n\n        # Case 3\n        mode = \"max\"\n        monitor = \"val_acc\"\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath,\n                monitor=monitor,\n                save_best_only=save_best_only,\n                mode=mode,\n            )\n        ]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n        self.assertTrue(os.path.exists(filepath))\n        
os.remove(filepath)\n\n        # Case 4\n        save_best_only = True\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath,\n                monitor=monitor,\n                save_best_only=save_best_only,\n                mode=mode,\n            )\n        ]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n        self.assertTrue(os.path.exists(filepath))\n        os.remove(filepath)\n\n        # Case 5: metric not available.\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath, monitor=\"unknown\", save_best_only=True, mode=\"min\"\n            )\n        ]\n        with pytest.warns(UserWarning):\n            model.fit(\n                x_train,\n                y_train,\n                batch_size=BATCH_SIZE,\n                validation_data=(x_test, y_test),\n                callbacks=cbks,\n                epochs=1,\n                verbose=0,\n            )\n        self.assertTrue(os.path.exists(filepath))\n\n        # Case 6\n        with warnings.catch_warnings(record=True) as warning_logs:\n            warnings.simplefilter(\"always\")\n            callbacks.ModelCheckpoint(\n                filepath,\n                monitor=monitor,\n                save_best_only=save_best_only,\n                mode=\"unknown\",\n            )\n            self.assertIn(\n                \"ModelCheckpoint mode 'unknown' is unknown\",\n                str(warning_logs[-1].message),\n            )\n\n        # Case 8a: `ModelCheckpoint` with an integer `save_freq`\n        temp_dir = self.get_temp_dir()\n        filepath = os.path.join(temp_dir, \"checkpoint.epoch{epoch:02d}.keras\")\n        save_best_only = False\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath,\n                
monitor=monitor,\n                save_best_only=save_best_only,\n                mode=mode,\n                save_freq=15,\n            )\n        ]\n        self.assertFalse(os.path.exists(filepath.format(epoch=3)))\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=6,  # 5 batches / epoch, so should backup every 3 epochs\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=10,\n            verbose=0,\n        )\n        self.assertFalse(os.path.exists(filepath.format(epoch=1)))\n        self.assertFalse(os.path.exists(filepath.format(epoch=2)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=3)))\n        self.assertFalse(os.path.exists(filepath.format(epoch=4)))\n        self.assertFalse(os.path.exists(filepath.format(epoch=5)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=6)))\n        self.assertFalse(os.path.exists(filepath.format(epoch=7)))\n        self.assertFalse(os.path.exists(filepath.format(epoch=8)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=9)))\n        os.remove(filepath.format(epoch=3))\n        os.remove(filepath.format(epoch=6))\n        os.remove(filepath.format(epoch=9))\n\n        # Case 8b: `ModelCheckpoint` with int `save_freq` & `save_weights_only`\n        temp_dir = self.get_temp_dir()\n        filepath = os.path.join(\n            temp_dir, \"checkpoint.epoch{epoch:02d}.weights.h5\"\n        )\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath, monitor=monitor, save_freq=15, save_weights_only=True\n            )\n        ]\n        self.assertFalse(os.path.exists(filepath.format(epoch=3)))\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=6,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=10,\n            verbose=0,\n        )\n        
self.assertFalse(os.path.exists(filepath.format(epoch=1)))\n        self.assertFalse(os.path.exists(filepath.format(epoch=2)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=3)))\n        self.assertFalse(os.path.exists(filepath.format(epoch=4)))\n        self.assertFalse(os.path.exists(filepath.format(epoch=5)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=6)))\n        self.assertFalse(os.path.exists(filepath.format(epoch=7)))\n        self.assertFalse(os.path.exists(filepath.format(epoch=8)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=9)))\n\n        # Case 9: `ModelCheckpoint` with valid and invalid save_freq argument.\n        with self.assertRaisesRegex(ValueError, \"Unrecognized save_freq\"):\n            callbacks.ModelCheckpoint(\n                filepath,\n                monitor=monitor,\n                save_best_only=save_best_only,\n                save_weights_only=True,\n                mode=mode,\n                save_freq=\"invalid_save_freq\",\n            )\n        # The following should not raise ValueError.\n        callbacks.ModelCheckpoint(\n            filepath,\n            monitor=monitor,\n            save_best_only=save_best_only,\n            save_weights_only=True,\n            mode=mode,\n            save_freq=\"epoch\",\n        )\n        callbacks.ModelCheckpoint(\n            filepath,\n            monitor=monitor,\n            save_best_only=save_best_only,\n            save_weights_only=True,\n            mode=mode,\n            save_freq=3,\n        )\n\n        # Case 10a: `ModelCheckpoint` save with batch in filename.\n        temp_dir = self.get_temp_dir()\n        filepath = os.path.join(\n            temp_dir, \"checkpoint.epoch{epoch:02d}batch{batch:02d}.keras\"\n        )\n        cbks = [\n            callbacks.ModelCheckpoint(filepath, monitor=monitor, save_freq=1)\n        ]\n        model.fit(\n            x_train,\n            y_train,\n            
batch_size=15,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=5,\n            verbose=1,\n        )\n        self.assertTrue(os.path.exists(filepath.format(epoch=1, batch=1)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=1, batch=2)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=2, batch=1)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=2, batch=2)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=3, batch=1)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=3, batch=2)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=4, batch=1)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=4, batch=2)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=5, batch=1)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=5, batch=2)))\n\n        # Case 10b: `ModelCheckpoint` save weights with batch in filename.\n        temp_dir = self.get_temp_dir()\n        filepath = os.path.join(\n            temp_dir, \"checkpoint.epoch{epoch:02d}batch{batch:02d}.weights.h5\"\n        )\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath, monitor=monitor, save_freq=1, save_weights_only=True\n            )\n        ]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=15,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=5,\n            verbose=1,\n        )\n\n        self.assertTrue(os.path.exists(filepath.format(epoch=1, batch=1)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=1, batch=2)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=2, batch=1)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=2, batch=2)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=3, batch=1)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=3, 
batch=2)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=4, batch=1)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=4, batch=2)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=5, batch=1)))\n        self.assertTrue(os.path.exists(filepath.format(epoch=5, batch=2)))\n\n        # Case 11: ModelCheckpoint saves model with initial_value_threshold\n        # param\n        mode = \"max\"\n        monitor = \"val_acc\"\n        initial_value_threshold = -0.01\n        save_best_only = True\n        filepath = os.path.join(temp_dir, \"checkpoint.keras\")\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath,\n                monitor=monitor,\n                save_best_only=save_best_only,\n                initial_value_threshold=initial_value_threshold,\n                mode=mode,\n            )\n        ]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n        self.assertTrue(os.path.exists(filepath))\n        os.remove(filepath)\n\n        # Case 12: ModelCheckpoint saves model with initial_value_threshold\n        # param\n        mode = \"auto\"\n        monitor = \"val_loss\"\n        initial_value_threshold = None\n        save_best_only = True\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath,\n                monitor=monitor,\n                save_best_only=save_best_only,\n                initial_value_threshold=initial_value_threshold,\n                mode=mode,\n            )\n        ]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n        
self.assertTrue(os.path.exists(filepath))\n        os.remove(filepath)\n\n        # Case 13: ModelCheckpoint doesn't save model if loss was minimum\n        # earlier\n        mode = \"min\"\n        monitor = \"val_loss\"\n        initial_value_threshold = 0\n        save_best_only = True\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath,\n                monitor=monitor,\n                save_best_only=save_best_only,\n                initial_value_threshold=initial_value_threshold,\n                mode=mode,\n            )\n        ]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n        self.assertFalse(os.path.exists(filepath))\n\n        # Case 14: ModelCheckpoint doesn't save model if loss was min earlier in\n        # auto mode\n        mode = \"auto\"\n        monitor = \"val_loss\"\n        initial_value_threshold = 0\n        save_best_only = True\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath,\n                monitor=monitor,\n                save_best_only=save_best_only,\n                initial_value_threshold=initial_value_threshold,\n                mode=mode,\n            )\n        ]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n        self.assertFalse(os.path.exists(filepath))\n\n        # Case 15: ModelCheckpoint doesn't save model if auc was max earlier in\n        # auto mode\n        mode = \"auto\"\n        monitor = \"val_auc\"\n        initial_value_threshold = 1\n        save_best_only = True\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath,\n                
monitor=monitor,\n                save_best_only=save_best_only,\n                initial_value_threshold=initial_value_threshold,\n                mode=mode,\n            )\n        ]\n        model.compile(\n            loss=\"categorical_crossentropy\",\n            optimizer=\"sgd\",\n            metrics=[metrics.AUC()],\n        )\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n        self.assertFalse(os.path.exists(filepath))\n\n    @pytest.mark.skipif(\n        h5py is None,\n        reason=\"`h5py` is a required dependency for `ModelCheckpoint` tests.\",\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_model_checkpoint_loading(self):\n        def get_model():\n            inputs = layers.Input(shape=(INPUT_DIM,), batch_size=5)\n            x = layers.Dense(NUM_HIDDEN, activation=\"relu\")(inputs)\n            outputs = layers.Dense(NUM_CLASSES, activation=\"softmax\")(x)\n            functional_model = models.Model(inputs, outputs)\n            functional_model.compile(\n                loss=\"categorical_crossentropy\",\n                optimizer=\"sgd\",\n                metrics=[metrics.Accuracy(\"acc\")],\n            )\n            return functional_model\n\n        (x_train, y_train), (x_test, y_test) = test_utils.get_test_data(\n            random_seed=42,\n            train_samples=TRAIN_SAMPLES,\n            test_samples=TEST_SAMPLES,\n            input_shape=(INPUT_DIM,),\n            num_classes=NUM_CLASSES,\n        )\n        y_test = numerical_utils.to_categorical(y_test, num_classes=NUM_CLASSES)\n        y_train = numerical_utils.to_categorical(\n            y_train, num_classes=NUM_CLASSES\n        )\n\n        # Model Checkpoint load model (default)\n        model = get_model()\n        temp_dir = self.get_temp_dir()\n        filepath = 
os.path.join(temp_dir, \"checkpoint.model.keras\")\n        mode = \"auto\"\n        monitor = \"val_loss\"\n        save_best_only = True\n\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath,\n                monitor=monitor,\n                save_best_only=save_best_only,\n                mode=mode,\n            )\n        ]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n        ref_weights = model.get_weights()\n        self.assertTrue(os.path.exists(filepath))\n        new_model = saving.load_model(filepath)\n        new_weights = new_model.get_weights()\n        self.assertEqual(len(ref_weights), len(new_weights))\n        for ref_w, w in zip(ref_weights, new_weights):\n            self.assertAllClose(ref_w, w)\n\n        # Model Checkpoint load model weights\n        model = get_model()\n        temp_dir = self.get_temp_dir()\n        filepath = os.path.join(temp_dir, \"checkpoint.weights.h5\")\n        mode = \"auto\"\n        monitor = \"val_loss\"\n        save_best_only = True\n\n        cbks = [\n            callbacks.ModelCheckpoint(\n                filepath,\n                monitor=monitor,\n                save_best_only=save_best_only,\n                save_weights_only=True,\n                mode=mode,\n            )\n        ]\n        model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=cbks,\n            epochs=1,\n            verbose=0,\n        )\n        ref_weights = model.get_weights()\n        self.assertTrue(os.path.exists(filepath))\n        new_model = get_model()\n        new_model.load_weights(filepath)\n        new_weights = new_model.get_weights()\n        self.assertEqual(len(ref_weights), 
len(new_weights))\n        for ref_w, w in zip(ref_weights, new_weights):\n            self.assertAllClose(ref_w, w)\n"
  },
  {
    "path": "keras/src/callbacks/monitor_callback.py",
    "content": "import warnings\n\nfrom keras.src import ops\nfrom keras.src.callbacks.callback import Callback\nfrom keras.src.trainers import compile_utils\n\n\nclass MonitorCallback(Callback):\n    \"\"\"Base class for callbacks that monitor a quantity and evaluate\n    improvements.\n\n    This class provides common functionality for callbacks that monitor a\n    metric during training to determine whether a condition has been met,\n    such as improvement over time. It encapsulates logic for selecting\n    the comparison operation based on a `monitor` value and `mode`, and\n    computing whether a new value is an improvement.\n\n    It is intended to be subclassed by other callbacks like `ModelCheckpoint`,\n    `EarlyStopping`, or `ReduceLROnPlateau`, and is not meant to be used\n    directly.\n\n    Arguments:\n        monitor: Quantity to be monitored. Defaults to `\"val_loss\"`.\n        mode: One of `{\"auto\", \"min\", \"max\"}`. In `\"min\"` mode, training will\n            aim to minimize the monitored quantity; in `\"max\"` mode it will\n            aim to maximize it; in `\"auto\"` mode, the direction is\n            automatically inferred from the name of the monitored quantity.\n            Defaults to `\"auto\"`.\n        baseline: Floating point initial \"best\" value of the metric to be\n            monitored. If `None` (default), the first monitored value will be\n            used.\n        min_delta: Minimum change in the monitored quantity to qualify as an\n            improvement, i.e. an absolute change of less than `min_delta` will\n            count as no improvement. 
Defaults to `0`.\n\n    Raises:\n        ValueError: If `mode='auto'` is selected and the direction of the metric\n        cannot be inferred.\n    \"\"\"\n\n    def __init__(\n        self,\n        monitor=\"val_loss\",\n        mode=\"auto\",\n        baseline=None,\n        min_delta=0,\n    ):\n        super().__init__()\n        if mode not in [\"auto\", \"min\", \"max\"]:\n            warnings.warn(\n                f\"{self.__class__.__name__} mode '{mode}' is unknown, fallback \"\n                \"to auto mode.\",\n                stacklevel=2,\n            )\n            mode = \"auto\"\n        self.monitor = monitor\n        self.mode = mode\n        self.best = baseline\n        self.min_delta = abs(min_delta)\n        self.monitor_op = None\n\n    def _set_monitor_op(self):\n        if self.mode == \"min\":\n            self.monitor_op = ops.less\n        elif self.mode == \"max\":\n            self.monitor_op = ops.greater\n        else:\n            metric_name = self.monitor.removeprefix(\"val_\")\n            if metric_name == \"loss\":\n                self.monitor_op = ops.less\n            if hasattr(self.model, \"metrics\"):\n                all_metrics = []\n                for m in self.model.metrics:\n                    if isinstance(\n                        m,\n                        (\n                            compile_utils.CompileMetrics,\n                            compile_utils.MetricsList,\n                        ),\n                    ):\n                        all_metrics.extend(m.metrics)\n                for m in all_metrics:\n                    if m.name == metric_name:\n                        if hasattr(m, \"_direction\"):\n                            if m._direction == \"up\":\n                                self.monitor_op = ops.greater\n                            else:\n                                self.monitor_op = ops.less\n            if self.monitor_op is None:\n                raise ValueError(\n        
            f\"{self.__class__.__name__} callback received \"\n                    f\"monitor={self.monitor}, but Keras isn't able to \"\n                    \"automatically determine whether that metric should be \"\n                    \"maximized or minimized. Pass `mode='max'` in order to \"\n                    \"monitor based on the highest metric value, or pass \"\n                    \"`mode='min'` in order to use the lowest value.\"\n                )\n        if self.monitor_op == ops.less:\n            self.min_delta *= -1\n\n    def _is_improvement(self, monitor_value, reference_value):\n        if reference_value is None:\n            return True\n        return self.monitor_op(monitor_value - self.min_delta, reference_value)\n"
  },
  {
    "path": "keras/src/callbacks/monitor_callback_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import callbacks\nfrom keras.src import layers\nfrom keras.src import metrics\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass MonitorCallbackTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_monitor_op_logic(self):\n        x_train = np.random.random((10, 5))\n        y_train = np.random.random((10, 1))\n        x_test = np.random.random((10, 5))\n        y_test = np.random.random((10, 1))\n        model = models.Sequential(\n            (\n                layers.Dense(1, activation=\"relu\"),\n                layers.Dense(1, activation=\"relu\"),\n            )\n        )\n        model.compile(\n            loss=\"mae\",\n            optimizer=\"adam\",\n            metrics=[\n                \"mse\",\n                \"acc\",\n                \"accuracy\",\n                \"hinge\",\n                metrics.F1Score(name=\"f1_score\"),\n            ],\n        )\n\n        cases = [\n            (\"max\", \"val_mse\", \"max\"),\n            (\"min\", \"val_loss\", \"min\"),\n            (\"auto\", \"val_mse\", \"min\"),\n            (\"auto\", \"loss\", \"min\"),\n            (\"auto\", \"acc\", \"max\"),\n            (\"auto\", \"val_accuracy\", \"max\"),\n            (\"auto\", \"hinge\", \"min\"),\n            (\"auto\", \"f1_score\", \"max\"),\n        ]\n        for mode, monitor, expected_mode in cases:\n            monitor_callback = callbacks.MonitorCallback(monitor, mode)\n            monitor_callback.set_model(model)\n            model.fit(\n                x_train,\n                y_train,\n                batch_size=5,\n                validation_data=(x_test, y_test),\n                epochs=2,\n                verbose=0,\n            )\n            monitor_callback._set_monitor_op()\n            if expected_mode == \"max\":\n                monitor_op = ops.greater\n            else:\n              
  monitor_op = ops.less\n            self.assertEqual(monitor_callback.monitor_op, monitor_op)\n\n        with self.assertRaises(ValueError):\n            monitor = \"unknown\"\n            monitor_callback = callbacks.MonitorCallback(monitor)\n            monitor_callback.set_model(model)\n            model.fit(\n                x_train,\n                y_train,\n                batch_size=5,\n                validation_data=(x_test, y_test),\n                epochs=2,\n                verbose=0,\n            )\n            monitor_callback._set_monitor_op()\n\n    @pytest.mark.requires_trainable_backend\n    def test_min_delta(self):\n        monitor_callback = callbacks.MonitorCallback(mode=\"max\", min_delta=0.5)\n        monitor_callback._set_monitor_op()\n        self.assertTrue(monitor_callback._is_improvement(0.75, 0))\n        self.assertTrue(monitor_callback._is_improvement(0.5, None))\n        self.assertFalse(monitor_callback._is_improvement(0.5, 0))\n        self.assertFalse(monitor_callback._is_improvement(0.2, 0.5))\n"
  },
  {
    "path": "keras/src/callbacks/orbax_checkpoint.py",
    "content": "import warnings\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.callbacks.monitor_callback import (\n    MonitorCallback,  # For metric monitoring logic\n)\nfrom keras.src.saving import saving_lib\nfrom keras.src.utils.module_utils import ocp\n\n# Context and AsyncOptions are accessed through the lazy-loaded ocp module\n\n# JAX monitoring compatibility: ensure record_scalar exists\n# to prevent AttributeError in older JAX versions\ntry:\n    import jax\n\n    if not hasattr(jax.monitoring, \"record_scalar\"):\n        jax.monitoring.record_scalar = lambda *args, **kwargs: None\nexcept ImportError:\n    pass\n\n\ndef _get_state_tree(model):\n    \"\"\"Get the complete model state as a nested tree structure.\"\"\"\n    # For JAX backend, preserve native arrays for performance\n    # For other backends, convert to numpy arrays\n    if backend.backend() == \"jax\":\n        state_tree = model.get_state_tree()\n        did_numpy_conversion = False\n    else:\n        state_tree = model.get_state_tree(value_format=\"numpy_array\")\n        did_numpy_conversion = True\n\n    # Convert numpy scalar types to Python types for Orbax compatibility\n    # Only needed when we did numpy conversion\n    if did_numpy_conversion:\n\n        def convert_scalars(obj):\n            if isinstance(obj, np.ndarray) and obj.ndim == 0:\n                # Convert 0-dimensional numpy arrays (scalars) to Python types\n                return obj.item()\n            elif isinstance(obj, np.generic):\n                # Convert numpy scalar types (like np.float32) to Python types\n                return obj.item()\n            else:\n                return obj\n\n        return tree.map_structure(convert_scalars, state_tree)\n    else:\n        return state_tree\n\n\n@keras_export(\"keras.callbacks.OrbaxCheckpoint\")\nclass OrbaxCheckpoint(MonitorCallback):\n    \"\"\"Callback to save 
and load model state using Orbax with a similar API to\n    ModelCheckpoint.\n\n    This callback saves the model's weights and optimizer state asynchronously\n    using Orbax, allowing training to continue without blocking for I/O.\n\n    **Multi-host Support**: When running in a multi-host distributed training\n    environment with JAX backend, this callback automatically coordinates\n    checkpointing across all hosts to ensure consistency and proper\n    synchronization. Multi-host checkpointing is only supported on JAX.\n\n    Example:\n\n    ```python\n    model.compile(loss=..., optimizer=..., metrics=['accuracy'])\n\n    EPOCHS = 10\n    checkpoint_dir = '/tmp/ckpt'\n    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(\n        directory=checkpoint_dir,\n        monitor='val_accuracy',\n        mode='max',\n        save_best_only=True)\n\n    # Model is saved at the end of every epoch, if it's the best seen so far.\n    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])\n\n    # Alternatively, save checkpoints every N batches:\n    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(\n        directory=checkpoint_dir,\n        save_freq=100)  # Save every 100 batches\n\n    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])\n    ```\n\n    Args:\n        directory: path to the directory where the checkpoints will be saved.\n        monitor: The metric name to monitor (e.g., 'val_loss').\n        verbose: Verbosity mode, 0 or 1.\n        save_best_only: if `save_best_only=True`, it only saves when the model\n            is considered the \"best\" based on the monitored quantity.\n        mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.\n        save_freq: `'epoch'` or integer. Frequency to save checkpoints.\n        max_to_keep: Integer, maximum number of recent checkpoints to keep.\n            If None, keeps all. Defaults to 1.\n        save_on_background: Boolean, whether to save asynchronously in the\n            background. Defaults to True.\n        initial_value_threshold: Floating point initial \"best\" value for the\n            monitor, used with `save_best_only`.\n        save_weights_only: Boolean. If True, only the model's weights are\n            saved; otherwise the full state, including the model config and\n            assets, is saved. Defaults to False.\n    \"\"\"\n\n    def __init__(\n        self,\n        directory,\n        monitor=\"val_loss\",\n        verbose=0,\n        save_best_only=False,\n        mode=\"auto\",\n        save_freq=\"epoch\",\n        initial_value_threshold=None,\n        max_to_keep=1,\n        save_on_background=True,\n        save_weights_only=False,\n    ):\n        # Ensure orbax is available\n        ocp.initialize()\n\n        # Initialize MonitorCallback for handling 'monitor', 'mode', 'best'\n        # logic\n        super().__init__(monitor, mode, initial_value_threshold)\n\n        self.directory = directory\n        self.verbose = verbose\n        self.save_best_only = save_best_only\n        self.save_freq = save_freq\n        self.max_to_keep = max_to_keep\n        self.save_on_background = save_on_background\n        self.save_weights_only = save_weights_only\n        self._batches_seen_since_last_saving = 0\n        self._last_batch_seen = None\n        self._total_batches_seen = 0  # Global batch counter for step tracking\n        self._async_futures = []  # Track async save futures\n\n        # Multi-host support\n        self._multihost_initialized = self._is_multihost_initialized()\n\n        if self.save_freq != \"epoch\" and not isinstance(self.save_freq, int):\n            raise ValueError(\n                f\"Unrecognized save_freq: {self.save_freq}. 
\"\n                \"Expected save_freq to be 'epoch' or an integer.\"\n            )\n\n        # --- Orbax Checkpointer Setup (V1 API) ---\n        policies = []\n        if max_to_keep is not None:\n            policies.append(\n                ocp.training.preservation_policies.LatestN(max_to_keep)\n            )\n\n        # Use AnyPreservationPolicy to combine them, or use directly\n        # if single policy\n        preservation_policy = None\n        if policies:\n            if len(policies) == 1:\n                preservation_policy = policies[0]\n            else:\n                preservation_policy = (\n                    ocp.training.preservation_policies.AnyPreservationPolicy(\n                        policies\n                    )\n                )\n\n        # Create the V1 Checkpointer with direct parameter passing\n        # Orbax will handle directory creation on all processes as needed\n        # save_decision_policy is required for proper coordination of\n        # rapid async saves\n        self.checkpointer = ocp.training.Checkpointer(\n            directory=directory,\n            preservation_policy=preservation_policy,\n            save_decision_policy=ocp.training.save_decision_policies.FixedIntervalPolicy(\n                1\n            ),\n        )\n\n    def set_model(self, model):\n        super().set_model(model)\n        if hasattr(model, \"optimizer\") and model.optimizer is not None:\n            # Recover the number of batches seen for a reloaded model\n            self._total_batches_seen = int(model.optimizer.iterations)\n\n    def _is_multihost_initialized(self):\n        \"\"\"Check if multi-host environment is initialized.\"\"\"\n        # Multi-host checkpointing is only supported on JAX backend\n        if backend.backend() != \"jax\":\n            return False\n\n        multihost = ocp.multihost\n        # Check if JAX distributed client is initialized\n        # (indicates multihost setup)\n        return multihost.is_jax_distributed_client_initialized()\n\n    def _sync_processes(self, key=None):\n        \"\"\"Synchronize all processes across hosts.\"\"\"\n        if not self._multihost_initialized:\n            return  # No-op for single host\n\n        multihost = ocp.multihost\n        sync_key = key or \"orbax_checkpoint_sync\"\n        multihost.sync_global_processes(sync_key)\n\n    def is_multihost_enabled(self):\n        \"\"\"Return True if multi-host checkpointing is enabled and initialized.\n\n        This method can be used to check if the callback is operating in\n        a multi-host distributed training environment. Multi-host checkpointing\n        is only supported on JAX backend.\n\n        Returns:\n            bool: True if multi-host support is active, False otherwise.\n        \"\"\"\n        return self._multihost_initialized\n\n    def is_primary_host(self):\n        \"\"\"Return True if this process is the primary host in multi-host setup.\n\n        In multi-host environments, only the primary host typically handles\n        logging and coordination tasks. 
Multi-host checkpointing is only\n        supported on JAX backend.\n\n        Returns:\n            bool: True if this is the primary host, False otherwise.\n            Always returns True in single-host environments.\n        \"\"\"\n        if not self._multihost_initialized:\n            return True  # Single host is always primary\n        multihost = ocp.multihost\n        return multihost.is_primary_host(primary_host=0)\n\n    def _should_save_on_batch(self, batch):\n        \"\"\"Check if we should save on this batch.\"\"\"\n        if self.save_freq == \"epoch\":\n            return False\n\n        if self._last_batch_seen is None or batch <= self._last_batch_seen:\n            # New epoch.\n            add_batches = batch + 1\n        else:\n            add_batches = batch - self._last_batch_seen\n        self._batches_seen_since_last_saving += add_batches\n        self._last_batch_seen = batch\n        self._total_batches_seen += add_batches\n\n        if self._batches_seen_since_last_saving >= self.save_freq:\n            self._batches_seen_since_last_saving = 0\n            return True\n        return False\n\n    def _save_checkpoint(self, step, logs=None):\n        \"\"\"Save a checkpoint at the given step with multi-host coordination.\"\"\"\n\n        # --- Prepare Composite State (Backend-Agnostic) ---\n        state_tree = _get_state_tree(self.model)\n\n        # Save the nested state structures directly (preserving layer\n        # names and structure)\n        if self.save_weights_only:\n            composite_state = {\n                \"trainable_variables\": state_tree[\"trainable_variables\"],\n                \"non_trainable_variables\": state_tree[\n                    \"non_trainable_variables\"\n                ],\n            }\n        else:\n            composite_state = state_tree\n\n        # Build payload with pytree (pure arrays) and optional\n        # model_config / assets as separate checkpointables.\n        payload = 
{\"pytree\": composite_state}\n\n        if not self.save_weights_only:\n            config_json, _ = saving_lib._serialize_model_as_json(self.model)\n            payload[\"model_config\"] = {\"config\": config_json}\n\n            assets_dict = saving_lib._save_assets_to_dict(self.model)\n            if assets_dict is not None:\n                payload[\"assets\"] = assets_dict\n\n        # Save within an Orbax context; Context() with no arguments uses\n        # the default options.\n        with ocp.Context():\n            # Determine sync vs async based on save_on_background setting\n            use_sync = not self.save_on_background\n\n            # Execute save based on sync/async mode\n            if use_sync:\n                self.checkpointer.save_checkpointables(step, payload)\n            else:\n                future = self.checkpointer.save_checkpointables_async(\n                    step, payload\n                )\n                self._async_futures.append(future)\n\n    def on_train_batch_end(self, batch, logs=None):\n        if self._should_save_on_batch(batch):\n            # Handle save_best_only logic for batch-level saving\n            should_save = True\n            if self.save_best_only:\n                current = logs.get(self.monitor) if logs else None\n                if current is None:\n                    warnings.warn(\n                        f\"Can save best model only with {self.monitor} \"\n                        f\"available, skipping save at batch {batch}.\",\n                        stacklevel=2,\n                    )\n                    should_save = False\n                elif not self._is_improvement(current, self.best):\n                    should_save = False\n                else:\n                    # Update best value when there's improvement\n                    self.best = current\n\n            if should_save:\n                # Use global batch count for Orbax save step\n                step = 
self._total_batches_seen\n                self._save_checkpoint(step=step, logs=logs)\n\n    def on_epoch_end(self, epoch, logs=None):\n        if self.monitor_op is None:\n            self._set_monitor_op()  # From MonitorCallback\n\n        # For save_freq=\"epoch\", save at every epoch\n        should_save = self.save_freq == \"epoch\"\n\n        # Handle save_best_only logic\n        if should_save and self.save_best_only:\n            current = logs.get(self.monitor) if logs else None\n            if current is None:\n                warnings.warn(\n                    f\"Can save best model only with {self.monitor} available, \"\n                    f\"skipping save at epoch {epoch}.\",\n                    stacklevel=2,\n                )\n                should_save = False\n            elif not self._is_improvement(current, self.best):\n                should_save = False\n            else:\n                # Update best value when there's improvement\n                self.best = current\n\n        if should_save:\n            # Use epoch number as the step for Orbax save\n            self._save_checkpoint(step=epoch, logs=logs)\n\n    def on_train_end(self, logs=None):\n        # Close the Checkpointer - this waits for any pending async saves\n        # to complete before closing\n        try:\n            self.checkpointer.close()\n        except Exception:\n            pass  # Ignore errors during cleanup\n\n        # Multi-host synchronization: ensure all hosts complete cleanup\n        self._sync_processes(\"checkpoint_cleanup\")\n\n    def wait_until_finished(self):\n        \"\"\"Wait for any in-progress checkpoint operations to complete.\n        This method blocks until all asynchronous checkpoint save operations\n        have completed across all hosts in a multi-host setup.\n        \"\"\"\n        # Wait for all tracked async futures to complete\n        for future in self._async_futures:\n            future.result()  # Wait for completion\n    
    self._async_futures.clear()  # Clear completed futures\n\n        # Wait for any remaining async operations to complete on this host\n        self.checkpointer.wait()\n\n        # Multi-host synchronization: ensure all hosts complete\n        self._sync_processes(\"checkpoint_wait_complete\")\n"
  },
  {
    "path": "keras/src/callbacks/orbax_checkpoint_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src import utils\nfrom keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint\nfrom keras.src.distribution import DeviceMesh\nfrom keras.src.distribution import LayoutMap\nfrom keras.src.distribution import ModelParallel\nfrom keras.src.distribution import TensorLayout\nfrom keras.src.distribution import distribution as get_distribution\nfrom keras.src.distribution import set_distribution\nfrom keras.src.saving import register_keras_serializable\nfrom keras.src.testing.test_utils import named_product\n\n\nclass OrbaxCheckpointTest(testing.TestCase, parameterized.TestCase):\n    def _create_test_model(self, steps_per_execution=1):\n        \"\"\"Create a simple test model compatible with 2-device sharding.\"\"\"\n        inputs = layers.Input(shape=(10,), name=\"input_layer\")\n        x = layers.Dense(6, name=\"dense_layer\")(inputs)  # 6 units (div by 2)\n        outputs = layers.Dense(2, name=\"output_layer\")(x)\n        model = models.Model(inputs, outputs, name=\"test_model\")\n        model.compile(\n            optimizer=\"adam\",\n            loss=\"mse\",\n            steps_per_execution=steps_per_execution,\n        )\n        return model\n\n    def _create_dummy_data(self, num_samples=100):\n        \"\"\"Create dummy training data.\"\"\"\n        x = np.random.randn(num_samples, 10)\n        y = np.random.randn(num_samples, 2)  # Match 2 outputs\n        return x, y\n\n    # Shared constants for distributed tests — fixed sizes divisible by\n    # 1, 2, 4, and 8 devices.\n    _DIST_DENSE_UNITS = 32\n    _DIST_OUT_UNITS = 16\n    _DIST_NUM_SAMPLES = 64\n    _DIST_PREDICT_BATCH = 8\n\n    def _setup_distributed_test(self):\n        \"\"\"Validate distributed prerequisites and return 
common objects.\n\n        Returns:\n            (num_devices, device_mesh, original_distribution)\n\n        Calls self.skipTest if fewer than 2 devices are available or if\n        the fixed layer sizes don't divide evenly by num_devices.\n        \"\"\"\n        import jax\n\n        devices = jax.devices()\n        num_devices = len(devices)\n        if num_devices < 2:\n            self.skipTest(\n                \"Test requires distributed setup with multiple devices\"\n            )\n        if (\n            self._DIST_DENSE_UNITS % num_devices != 0\n            or self._DIST_OUT_UNITS % num_devices != 0\n        ):\n            self.skipTest(\n                f\"num_devices={num_devices} does not evenly divide \"\n                f\"dense_units={self._DIST_DENSE_UNITS} or \"\n                f\"out_units={self._DIST_OUT_UNITS}\"\n            )\n        device_mesh = DeviceMesh(\n            (num_devices,), axis_names=[\"data\"], devices=devices\n        )\n        return num_devices, device_mesh, get_distribution()\n\n    def _build_distributed_model(self, dense_units, out_units):\n        \"\"\"Build and compile the shared two-layer functional model.\"\"\"\n        inputs_l = layers.Input(shape=(10,), name=\"input_layer\")\n        h = layers.Dense(dense_units, name=\"dense_layer\")(inputs_l)\n        outputs_l = layers.Dense(out_units, name=\"output_layer\")(h)\n        model = models.Model(inputs_l, outputs_l, name=\"test_model\")\n        model.compile(optimizer=\"adam\", loss=\"mse\")\n        return model\n\n    def _make_layout_map(self, device_mesh, *layer_names):\n        \"\"\"Build a LayoutMap that shards kernel+bias for each named layer.\n\n        Each named Dense layer's kernel is sharded along the output-units\n        axis (axes=(None, \"data\")) and its bias along axis 0 (axes=(\"data\",)).\n        Layers not listed are left replicated.\n        \"\"\"\n        layout_map = LayoutMap(device_mesh)\n        for name in layer_names:\n        
    layout_map[f\"{name}/kernel\"] = TensorLayout(axes=(None, \"data\"))\n            layout_map[f\"{name}/bias\"] = TensorLayout(axes=(\"data\",))\n        return layout_map\n\n    @parameterized.parameters(\n        {\"save_freq\": 10, \"epochs\": 1, \"batch_size\": 5},  # batch-level\n        {\"save_freq\": \"epoch\", \"epochs\": 3, \"batch_size\": None},  # epoch-level\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_checkpoint_saving_basic(self, save_freq, epochs, batch_size):\n        \"\"\"Test basic checkpoint saving with different frequencies.\"\"\"\n        model = self._create_test_model()\n        x, y = self._create_dummy_data(num_samples=50)\n\n        checkpoint_dir = os.path.join(\n            self.get_temp_dir(), f\"test_save_{save_freq}_{id(self)}\"\n        )\n        callback = OrbaxCheckpoint(\n            directory=checkpoint_dir, save_freq=save_freq\n        )\n\n        # Train with specified configuration\n        fit_kwargs = {\"callbacks\": [callback], \"verbose\": 0}\n        if batch_size:\n            fit_kwargs[\"batch_size\"] = batch_size\n        model.fit(x, y, epochs=epochs, **fit_kwargs)\n\n        # Verify checkpoint files were created\n        checkpoint_files = os.listdir(checkpoint_dir)\n        self.assertGreater(\n            len(checkpoint_files), 0, \"Should have checkpoint files\"\n        )\n\n    @parameterized.parameters(\n        {\"mode\": \"min\", \"monitor\": \"loss\"},\n        {\"mode\": \"max\", \"monitor\": \"loss\"},\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_save_best_only(self, mode, monitor):\n        \"\"\"Test save_best_only with different modes.\"\"\"\n        model = self._create_test_model()\n        x, y = self._create_dummy_data(num_samples=100)\n\n        checkpoint_dir = os.path.join(\n            self.get_temp_dir(), f\"test_best_{mode}_{id(self)}\"\n        )\n        callback = OrbaxCheckpoint(\n            directory=checkpoint_dir,\n            
monitor=monitor,\n            save_best_only=True,\n            mode=mode,\n            save_freq=\"epoch\",\n        )\n\n        model.fit(x, y, epochs=5, callbacks=[callback], verbose=0)\n\n        checkpoint_files = os.listdir(checkpoint_dir)\n        self.assertGreater(\n            len(checkpoint_files), 0, \"Should have checkpoint files\"\n        )\n\n    @parameterized.parameters(\n        {\"save_on_background\": False},\n        {\"save_on_background\": True},\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_async_vs_sync_saving(self, save_on_background):\n        \"\"\"Test synchronous vs asynchronous saving.\"\"\"\n        model = self._create_test_model()\n        x, y = self._create_dummy_data()\n\n        checkpoint_dir = os.path.join(\n            self.get_temp_dir(), f\"test_async_{save_on_background}_{id(self)}\"\n        )\n        callback = OrbaxCheckpoint(\n            directory=checkpoint_dir,\n            save_freq=\"epoch\",\n            save_on_background=save_on_background,\n        )\n\n        model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)\n\n        checkpoint_files = os.listdir(checkpoint_dir)\n        self.assertGreater(\n            len(checkpoint_files), 0, \"Should have checkpoint files\"\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_max_to_keep(self):\n        \"\"\"Test max_to_keep parameter limits number of checkpoints.\"\"\"\n        model = self._create_test_model()\n        x, y = self._create_dummy_data()\n\n        checkpoint_dir = os.path.join(\n            self.get_temp_dir(), f\"test_max_keep_{id(self)}\"\n        )\n        callback = OrbaxCheckpoint(\n            directory=checkpoint_dir, save_freq=\"epoch\", max_to_keep=2\n        )\n\n        model.fit(x, y, epochs=5, callbacks=[callback], verbose=0)\n\n        checkpoint_files = os.listdir(checkpoint_dir)\n        self.assertLessEqual(len(checkpoint_files), 5)\n\n    @pytest.mark.requires_trainable_backend\n   
 def test_load_weights_from_orbax_checkpoint(self):\n        \"\"\"Test loading weights from Orbax checkpoint using load_weights.\"\"\"\n\n        # Create and train model to create checkpoint\n        model = self._create_test_model()\n        x, y = self._create_dummy_data()\n\n        checkpoint_dir = os.path.join(\n            self.get_temp_dir(), \"test_load_weights_orbax\"\n        )\n        callback = OrbaxCheckpoint(\n            directory=checkpoint_dir,\n            save_freq=\"epoch\",\n            save_weights_only=True,  # Only save weights for load_weights test\n        )\n\n        # Train to create checkpoint\n        model.fit(x, y, epochs=1, callbacks=[callback], verbose=0)\n\n        # Get original weights after training\n        original_weights = model.get_weights()\n\n        # Create a new model with the same architecture\n        new_model = self._create_test_model()\n\n        # Initialize with different weights to ensure loading works\n        different_weights = [w * 2 for w in original_weights]\n        new_model.set_weights(different_weights)\n\n        # Verify weights are different initially\n        new_weights_before = new_model.get_weights()\n        for orig, new in zip(original_weights, new_weights_before):\n            self.assertNotAllClose(\n                orig, new, msg=\"Weights should be different before loading\"\n            )\n\n        # Load weights from Orbax checkpoint\n        new_model.load_weights(checkpoint_dir)\n\n        # Verify weights were loaded correctly\n        loaded_weights = new_model.get_weights()\n        for orig, loaded in zip(original_weights, loaded_weights):\n            self.assertAllClose(\n                orig,\n                loaded,\n                msg=\"Weights should match after loading from checkpoint\",\n            )\n\n    @pytest.mark.requires_trainable_backend\n    def test_save_freq_epoch(self):\n        \"\"\"Test save_freq='epoch' functionality.\"\"\"\n        model = 
self._create_test_model()\n        x, y = self._create_dummy_data()\n\n        checkpoint_dir = os.path.join(\n            self.get_temp_dir(), f\"test_epoch_freq_{id(self)}\"\n        )\n        callback = OrbaxCheckpoint(\n            directory=checkpoint_dir,\n            save_freq=\"epoch\",\n        )\n\n        # Train for 3 epochs\n        model.fit(x, y, epochs=3, callbacks=[callback], verbose=0)\n\n        # Should have only the latest checkpoint (epoch 2) due to max_to_keep=1\n        checkpoint_files = os.listdir(checkpoint_dir)\n        self.assertEqual(\n            len(checkpoint_files),\n            1,\n            f\"Should have exactly 1 checkpoint due to max_to_keep=1, \"\n            f\"found {len(checkpoint_files)}: {checkpoint_files}\",\n        )\n\n        # Check for the latest epoch directory (should be the highest numbered)\n        # Note: Due to preservation policy behavior, the actual latest kept\n        # may vary\n        # So we check that at least one checkpoint exists and has a reasonable\n        # name\n        self.assertTrue(\n            len(checkpoint_files) == 1 and checkpoint_files[0].isdigit(),\n            f\"Should have exactly one checkpoint with numeric name, \"\n            f\"found {checkpoint_files}\",\n        )\n\n    def test_invalid_save_freq(self):\n        \"\"\"Test error handling for invalid save_freq parameter.\"\"\"\n        checkpoint_dir = os.path.join(self.get_temp_dir(), \"test_invalid_freq\")\n        with self.assertRaises(ValueError):\n            OrbaxCheckpoint(directory=checkpoint_dir, save_freq=\"invalid\")\n\n    @pytest.mark.requires_trainable_backend\n    def test_initial_value_threshold(self):\n        \"\"\"Test initial_value_threshold parameter.\"\"\"\n        model = self._create_test_model()\n        x, y = self._create_dummy_data()\n\n        checkpoint_dir = os.path.join(self.get_temp_dir(), \"test_threshold\")\n        callback = OrbaxCheckpoint(\n            
directory=checkpoint_dir,\n            monitor=\"loss\",\n            save_best_only=True,\n            mode=\"min\",\n            initial_value_threshold=1.0,\n            save_freq=\"epoch\",\n        )\n\n        model.fit(x, y, epochs=3, callbacks=[callback], verbose=0)\n        self.assertTrue(os.path.exists(checkpoint_dir))\n\n    @parameterized.parameters(\n        {\"save_on_background\": False},\n        {\"save_on_background\": True},\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_checkpoint_loading_comprehensive(self, save_on_background):\n        \"\"\"Test checkpoint loading with async and sync saving.\"\"\"\n        model = self._create_test_model()\n        model.compile(optimizer=\"adam\", loss=\"mse\")\n        x, y = self._create_dummy_data(num_samples=200)\n\n        checkpoint_dir = os.path.join(\n            self.get_temp_dir(),\n            f\"test_loading_{save_on_background}_{id(self)}\",\n        )\n\n        callback = OrbaxCheckpoint(\n            directory=checkpoint_dir,\n            save_freq=\"epoch\",\n            save_on_background=save_on_background,\n            save_weights_only=True,\n        )\n\n        model.fit(x, y, epochs=1, callbacks=[callback], verbose=0)\n        original_weights = model.get_weights()\n\n        # Test load_weights functionality\n        new_model = self._create_test_model()\n        new_model.compile(optimizer=\"adam\", loss=\"mse\")\n        new_x, new_y = self._create_dummy_data(num_samples=10)\n        new_model.fit(new_x, new_y, epochs=1, batch_size=5, verbose=0)\n\n        different_weights = [w * 2 for w in original_weights]\n        new_model.set_weights(different_weights)\n\n        # Verify different before loading\n        for orig, new in zip(original_weights, new_model.get_weights()):\n            self.assertNotAllClose(orig, new)\n\n        # Load and verify\n        new_model.load_weights(checkpoint_dir)\n        for orig, loaded in zip(original_weights, 
new_model.get_weights()):\n            self.assertAllClose(orig, loaded)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"jax\",\n        reason=\"Requires JAX backend for distribution\",\n    )\n    def test_distributed_checkpoint_functionality(self):\n        \"\"\"Test OrbaxCheckpoint with distributed training.\n\n        Verifies that a full-model checkpoint (weights + optimizer state +\n        config) round-trips correctly under ModelParallel sharding.\n        All predict/load calls stay inside the distribution scope so that\n        JAX JIT sees the correct context mesh for sharded variables.\n        \"\"\"\n        num_devices, device_mesh, original_distribution = (\n            self._setup_distributed_test()\n        )\n\n        layout_map = self._make_layout_map(\n            device_mesh, \"dense_layer\", \"output_layer\"\n        )\n\n        dense_units = self._DIST_DENSE_UNITS\n        out_units = self._DIST_OUT_UNITS\n        predict_batch = self._DIST_PREDICT_BATCH\n\n        try:\n            set_distribution(ModelParallel(layout_map=layout_map))\n            model = self._build_distributed_model(dense_units, out_units)\n\n            x = np.random.randn(self._DIST_NUM_SAMPLES, 10)\n            y = np.random.randn(self._DIST_NUM_SAMPLES, out_units)\n\n            checkpoint_dir = os.path.join(\n                self.get_temp_dir(), \"test_distributed_checkpoint\"\n            )\n            callback = OrbaxCheckpoint(\n                directory=checkpoint_dir, save_freq=\"epoch\"\n            )\n            model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)\n\n            original_predictions = model.predict(x[:predict_batch], verbose=0)\n            original_weights = model.get_weights()\n            original_opt_vars = [v.numpy() for v in model.optimizer.variables]\n\n            loaded = saving.load_model(checkpoint_dir)\n\n            for orig, lw in zip(original_weights, loaded.get_weights()):\n                
self.assertAllClose(orig, lw)\n            for orig, lv in zip(original_opt_vars, loaded.optimizer.variables):\n                self.assertAllClose(orig, lv)\n\n            loaded_predictions = loaded.predict(x[:predict_batch], verbose=0)\n            self.assertAllClose(original_predictions, loaded_predictions)\n\n            self.assertEqual(model.name, loaded.name)\n            self.assertEqual(len(model.layers), len(loaded.layers))\n            self.assertTrue(loaded.compiled)\n            self.assertEqual(type(get_distribution()), ModelParallel)\n\n            original_shardings = {\n                var.path: var.value.sharding\n                for var in model.variables\n                if hasattr(var.value, \"sharding\")\n            }\n            loaded_shardings = {\n                var.path: var.value.sharding\n                for var in loaded.variables\n                if hasattr(var.value, \"sharding\")\n            }\n            for path, spec in original_shardings.items():\n                if path in loaded_shardings:\n                    self.assertEqual(\n                        spec,\n                        loaded_shardings[path],\n                        f\"Sharding mismatch for variable {path}\",\n                    )\n\n        finally:\n            if original_distribution is not None:\n                set_distribution(original_distribution)\n            else:\n                try:\n                    set_distribution(None)\n                except Exception:\n                    pass\n\n    @pytest.mark.skipif(\n        backend.backend() != \"jax\",\n        reason=\"Requires JAX backend for distribution\",\n    )\n    def test_distributed_checkpoint_resharding(self):\n        \"\"\"Test loading an Orbax checkpoint under a *different* layout.\n\n        Saves a model sharded with layout A (dense_layer + output_layer\n        sharded), then reloads it under layout B (only output_layer\n        sharded). 
The loaded model must have numerically identical\n        weights AND the new sharding layout.\n        \"\"\"\n        num_devices, device_mesh, original_distribution = (\n            self._setup_distributed_test()\n        )\n\n        dense_units = self._DIST_DENSE_UNITS\n        out_units = self._DIST_OUT_UNITS\n        predict_batch = self._DIST_PREDICT_BATCH\n\n        try:\n            # ---- Save with Layout A (both layers sharded) ----\n            layout_a = self._make_layout_map(\n                device_mesh, \"dense_layer\", \"output_layer\"\n            )\n            set_distribution(ModelParallel(layout_map=layout_a))\n            model = self._build_distributed_model(dense_units, out_units)\n\n            x = np.random.randn(self._DIST_NUM_SAMPLES, 10)\n            y = np.random.randn(self._DIST_NUM_SAMPLES, out_units)\n\n            checkpoint_dir = os.path.join(\n                self.get_temp_dir(), \"test_resharding_checkpoint\"\n            )\n            callback = OrbaxCheckpoint(\n                directory=checkpoint_dir, save_freq=\"epoch\"\n            )\n            model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)\n\n            original_weights = model.get_weights()\n            original_predictions = model.predict(x[:predict_batch], verbose=0)\n\n            # ---- Reload with Layout B (only output_layer sharded) ----\n            layout_b = self._make_layout_map(device_mesh, \"output_layer\")\n            set_distribution(ModelParallel(layout_map=layout_b))\n\n            loaded = saving.load_model(checkpoint_dir)\n\n            # Weights must be numerically identical\n            for orig, lw in zip(original_weights, loaded.get_weights()):\n                self.assertAllClose(orig, lw)\n\n            loaded_predictions = loaded.predict(x[:predict_batch], verbose=0)\n            self.assertAllClose(original_predictions, loaded_predictions)\n\n            # Verify the loaded model uses Layout B shardings\n            
self.assertEqual(model.name, loaded.name)\n            self.assertTrue(loaded.compiled)\n\n        finally:\n            if original_distribution is not None:\n                set_distribution(original_distribution)\n            else:\n                try:\n                    set_distribution(None)\n                except Exception:\n                    pass\n\n    @pytest.mark.requires_trainable_backend\n    def test_checkpoint_loading_via_saving_api(self):\n        \"\"\"Test model loading via the saving API.\"\"\"\n        model = self._create_test_model()\n        x, y = self._create_dummy_data()\n\n        # Test basic model loading\n        checkpoint_dir = os.path.join(self.get_temp_dir(), \"test_basic_loading\")\n        callback = OrbaxCheckpoint(directory=checkpoint_dir, save_freq=\"epoch\")\n        model.fit(x, y, epochs=1, callbacks=[callback], verbose=0)\n\n        original_weights = model.get_weights()\n        loaded_model = saving.load_model(checkpoint_dir)\n\n        # Verify weights and compilation\n        self.assertEqual(len(original_weights), len(loaded_model.get_weights()))\n        for orig, loaded in zip(original_weights, loaded_model.get_weights()):\n            self.assertAllClose(orig, loaded)\n        self.assertTrue(loaded_model.compiled)\n\n        # A weights-only checkpoint should fail with load_model\n        weights_only_dir = os.path.join(\n            self.get_temp_dir(), \"test_weights_only\"\n        )\n        weights_callback = OrbaxCheckpoint(\n            directory=weights_only_dir,\n            save_freq=\"epoch\",\n            save_weights_only=True,\n        )\n        model.fit(x, y, epochs=1, callbacks=[weights_callback], verbose=0)\n\n        with self.assertRaises(ValueError):\n            saving.load_model(weights_only_dir)\n\n    @parameterized.parameters(\n        {\"save_on_background\": False},\n        {\"save_on_background\": True},\n    )\n    @pytest.mark.requires_trainable_backend\n    def 
test_comprehensive_model_state_restoration(self, save_on_background):\n        \"\"\"Test comprehensive model state restoration with exact weight\n        matching.\n\n        Tests sync/async saving, exact weight matching, and complete state\n        restoration including trainable/non-trainable variables, optimizer\n        state, and custom layers.\n        \"\"\"\n        utils.set_random_seed(42)\n\n        # Create model with custom layer having non-trainable variables\n        @register_keras_serializable(package=\"test\")\n        class CustomLayer(layers.Layer):\n            def __init__(self, units, **kwargs):\n                super().__init__(**kwargs)\n                self.units = units\n\n            def build(self, input_shape):\n                self.kernel = self.add_weight(\n                    shape=(input_shape[-1], self.units), name=\"kernel\"\n                )\n                self.moving_mean = self.add_weight(\n                    shape=(self.units,), trainable=False, name=\"moving_mean\"\n                )\n                super().build(input_shape)\n\n            def call(self, inputs):\n                return inputs @ self.kernel\n\n        # Build model with both trainable and non-trainable variables\n        inputs = layers.Input(shape=(10,), name=\"input_layer\")\n        x = layers.Dense(8, name=\"dense_layer\")(inputs)\n        outputs = CustomLayer(2, name=\"custom_layer\")(x)\n        model = models.Model(inputs, outputs, name=\"comprehensive_test_model\")\n        model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n\n        x, y = self._create_dummy_data(num_samples=100)\n        checkpoint_dir = os.path.join(\n            self.get_temp_dir(),\n            f\"test_comprehensive_{save_on_background}_{id(self)}\",\n        )\n\n        # Test saving with exact weight matching\n        callback = OrbaxCheckpoint(\n            directory=checkpoint_dir,\n            save_freq=\"epoch\",\n            
save_on_background=save_on_background,\n        )\n        model.fit(x, y, epochs=2, verbose=0, callbacks=[callback])\n\n        # Verify exact weight matching functionality\n        final_saved_weights = model.get_weights()\n        self.assertIsNotNone(final_saved_weights, \"Should have saved weights\")\n\n        # Load and verify complete model restoration\n        loaded_model = saving.load_model(checkpoint_dir)\n\n        # Architecture verification\n        self.assertEqual(model.name, loaded_model.name)\n        self.assertEqual(len(model.layers), len(loaded_model.layers))\n        self.assertTrue(loaded_model.compiled)\n\n        # Exact weight matching verification\n        loaded_weights = loaded_model.get_weights()\n        self.assertEqual(len(final_saved_weights), len(loaded_weights))\n        for i, (saved, loaded) in enumerate(\n            zip(final_saved_weights, loaded_weights)\n        ):\n            self.assertAllClose(saved, loaded, msg=f\"Weight {i} mismatch\")\n\n        # Verify optimizer variables\n        for i, (saved, loaded) in enumerate(\n            zip(model.optimizer.variables, loaded_model.optimizer.variables)\n        ):\n            self.assertAllClose(\n                saved, loaded, msg=f\"Optimizer variable {i} mismatch\"\n            )\n\n    @parameterized.parameters(\n        {\"save_on_background\": False},\n        {\"save_on_background\": True},\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_checkpoint_with_assets(self, save_on_background):\n        \"\"\"Test checkpoint saving/loading with layers that have assets.\n\n        Tests that models with preprocessing layers that have vocab assets\n        can be saved and loaded correctly through Orbax checkpoints.\n\n        Passing a vocabulary *file path* (not an inline list) to\n        StringLookup causes the vocabulary to be stored via\n        save_assets / load_assets rather than inlined in get_config.\n        This test verifies the Orbax round-trip for that code path.\n        \"\"\"\n    
    # Write a vocabulary file so StringLookup stores it as an asset\n        # (inline lists are serialized in get_config, not via assets).\n        vocab_dir = self.get_temp_dir()\n        vocab_file = os.path.join(vocab_dir, \"vocab.txt\")\n        vocab_words = [\"cat\", \"dog\", \"bird\", \"fish\"]\n        with open(vocab_file, \"w\") as f:\n            f.write(\"\\n\".join(vocab_words))\n\n        string_lookup = layers.StringLookup(\n            vocabulary=vocab_file,\n            output_mode=\"int\",\n            name=\"string_lookup_layer\",\n        )\n\n        inputs = layers.Input(shape=(1,), dtype=\"string\")\n        x = string_lookup(inputs)\n        outputs = layers.Embedding(input_dim=10, output_dim=8)(x)\n        model = models.Model(inputs, outputs, name=\"model_with_assets\")\n        model.compile(optimizer=\"adam\", loss=\"mse\")\n\n        original_vocab = string_lookup.get_vocabulary()\n\n        # Save through OrbaxCheckpoint (the actual Orbax path)\n        checkpoint_dir = self.get_temp_dir()\n        callback = OrbaxCheckpoint(\n            directory=checkpoint_dir,\n            save_freq=\"epoch\",\n            save_on_background=save_on_background,\n            save_weights_only=False,\n        )\n\n        # We can't easily train with string inputs, so invoke the\n        # save path directly.\n        callback.set_model(model)\n        callback._save_checkpoint(step=0)\n        callback.checkpointer.close()\n\n        # Load the model back through the Orbax load path\n        loaded_model = saving.load_model(checkpoint_dir)\n\n        # Verify model structure\n        self.assertEqual(model.name, loaded_model.name)\n        self.assertEqual(len(model.layers), len(loaded_model.layers))\n\n        # Verify vocabulary (assets) was restored correctly\n        loaded_string_lookup = loaded_model.get_layer(\"string_lookup_layer\")\n        loaded_vocab = loaded_string_lookup.get_vocabulary()\n\n        self.assertEqual(original_vocab, 
loaded_vocab)\n\n    @parameterized.named_parameters(named_product(steps_per_execution=(1, 2)))\n    @pytest.mark.requires_trainable_backend\n    def test_training_resumption(self, steps_per_execution):\n        if backend.backend() == \"torch\" and steps_per_execution != 1:\n            pytest.skip(\"steps_per_execution unsupported on torch\")\n\n        model = self._create_test_model(steps_per_execution)\n        x, y = self._create_dummy_data(num_samples=50)\n        checkpoint_dir = self.get_temp_dir()\n\n        # Train with specified configuration\n        oc1 = OrbaxCheckpoint(checkpoint_dir, save_freq=1, max_to_keep=10)\n        model.fit(x, y, epochs=2, batch_size=25, callbacks=[oc1], verbose=0)\n\n        # Verify checkpoint files were created\n        checkpoint_files_1 = os.listdir(checkpoint_dir)\n        self.assertGreater(\n            len(checkpoint_files_1), 0, \"Should have checkpoint files\"\n        )\n\n        reloaded_model = saving.load_model(checkpoint_dir)\n        # Resume training with the same folder for checkpoints\n        oc2 = OrbaxCheckpoint(checkpoint_dir, save_freq=1, max_to_keep=10)\n        reloaded_model.fit(\n            x, y, epochs=1, batch_size=25, callbacks=[oc2], verbose=0\n        )\n\n        checkpoint_files_2 = os.listdir(checkpoint_dir)\n        self.assertGreater(\n            len(checkpoint_files_2),\n            len(checkpoint_files_1),\n            \"Should have more checkpoint files\",\n        )\n"
  },
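The OrbaxCheckpoint tests above all assert one round-trip invariant: weights written at a checkpoint step must come back numerically identical on load. A minimal sketch of that invariant, using `pickle` as a stand-in for Orbax (the `save_checkpoint`/`load_latest` helpers are hypothetical, not Keras or Orbax APIs, and step filenames are assumed single-digit so lexicographic sort works):

```python
# Illustrative sketch: one checkpoint file per step, latest-step restore.
# pickle stands in for Orbax; no Keras dependency.
import os
import pickle
import tempfile


def save_checkpoint(directory, step, weights):
    """Write one checkpoint per step, mirroring save_freq="epoch"."""
    os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, f"step_{step}.pkl"), "wb") as f:
        pickle.dump(weights, f)


def load_latest(directory):
    """Load the highest-numbered step, like restoring the latest epoch."""
    steps = sorted(os.listdir(directory))  # assumes single-digit steps
    with open(os.path.join(directory, steps[-1]), "rb") as f:
        return pickle.load(f)


tmp = tempfile.mkdtemp()
for epoch, w in enumerate([[1.0, 2.0], [1.5, 2.5]]):
    save_checkpoint(tmp, epoch, w)

restored = load_latest(tmp)  # weights from the last epoch
```

The tests additionally check optimizer variables and sharding layouts, but the save-then-load-latest equality is the core contract.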
  {
    "path": "keras/src/callbacks/progbar_logger.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.callbacks.callback import Callback\nfrom keras.src.utils import io_utils\nfrom keras.src.utils.progbar import Progbar\n\n\n@keras_export(\"keras.callbacks.ProgbarLogger\")\nclass ProgbarLogger(Callback):\n    \"\"\"Callback that prints metrics to stdout.\n\n    Args:\n        count_mode: One of `\"steps\"` or `\"samples\"`.\n            Whether the progress bar should\n            count samples seen or steps (batches) seen.\n\n    Raises:\n        ValueError: In case of invalid `count_mode`.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__()\n        self.seen = 0\n        self.progbar = None\n        self.target = None\n        self.verbose = 1\n        self.epochs = 1\n\n        self._called_in_fit = False\n\n    def set_params(self, params):\n        verbose = params[\"verbose\"]\n        if verbose == \"auto\":\n            verbose = 1\n        self.verbose = verbose\n        self.epochs = params[\"epochs\"]\n        self.target = params[\"steps\"]\n\n    def on_train_begin(self, logs=None):\n        # When this logger is called inside `fit`, validation is silent.\n        self._called_in_fit = True\n\n    def on_test_begin(self, logs=None):\n        if not self._called_in_fit:\n            self._reset_progbar()\n            self._maybe_init_progbar()\n\n    def on_predict_begin(self, logs=None):\n        self._reset_progbar()\n        self._maybe_init_progbar()\n\n    def on_epoch_begin(self, epoch, logs=None):\n        self._reset_progbar()\n        self._maybe_init_progbar()\n        if self.verbose and self.epochs > 1:\n            io_utils.print_msg(f\"Epoch {epoch + 1}/{self.epochs}\")\n\n    def on_train_batch_end(self, batch, logs=None):\n        self._update_progbar(batch, logs)\n\n    def on_test_batch_end(self, batch, logs=None):\n        if not self._called_in_fit:\n            self._update_progbar(batch, logs)\n\n    def on_predict_batch_end(self, batch, 
logs=None):\n        # Don't pass prediction results.\n        self._update_progbar(batch, None)\n\n    def on_epoch_end(self, epoch, logs=None):\n        self._finalize_progbar(logs)\n\n    def on_test_end(self, logs=None):\n        if not self._called_in_fit:\n            self._finalize_progbar(logs)\n\n    def on_predict_end(self, logs=None):\n        self._finalize_progbar(logs)\n\n    def _reset_progbar(self):\n        self.seen = 0\n        self.progbar = None\n\n    def _maybe_init_progbar(self):\n        if self.progbar is None:\n            self.progbar = Progbar(\n                target=self.target, verbose=self.verbose, unit_name=\"step\"\n            )\n\n    def _update_progbar(self, batch, logs=None):\n        \"\"\"Updates the progbar.\"\"\"\n        logs = logs or {}\n        self._maybe_init_progbar()\n        self.seen = batch + 1  # One-indexed.\n\n        if self.verbose == 1:\n            self.progbar.update(self.seen, list(logs.items()), finalize=False)\n\n    def _finalize_progbar(self, logs):\n        logs = logs or {}\n        if self.target is None:\n            self.target = self.seen\n            self.progbar.target = self.target\n        self.progbar.update(self.target, list(logs.items()), finalize=True)\n"
  },
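The interesting part of `ProgbarLogger` is its bookkeeping: `seen` is the one-indexed batch count, and when the step target is unknown (`params["steps"]` is `None`) the target is fixed to `seen` at finalization. A minimal sketch of that state machine (`MiniProgbarState` is an illustrative stand-in, with no rendering):

```python
# Illustrative sketch of ProgbarLogger's seen/target bookkeeping.
class MiniProgbarState:
    def __init__(self, target=None):
        self.target = target  # None when the number of steps is unknown
        self.seen = 0

    def update(self, batch):
        self.seen = batch + 1  # one-indexed, mirroring _update_progbar

    def finalize(self):
        # Mirrors _finalize_progbar: pin the target to what was seen.
        if self.target is None:
            self.target = self.seen
        return self.target


state = MiniProgbarState(target=None)
for batch in range(5):  # five batches, zero-indexed
    state.update(batch)
final_target = state.finalize()
```

This is why a bar with unknown length still ends at 100%: the target is retrofitted from the observed batch count.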
  {
    "path": "keras/src/callbacks/reduce_lr_on_plateau.py",
    "content": "import warnings\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.callbacks.monitor_callback import MonitorCallback\nfrom keras.src.utils import io_utils\n\n\n@keras_export(\"keras.callbacks.ReduceLROnPlateau\")\nclass ReduceLROnPlateau(MonitorCallback):\n    \"\"\"Reduce learning rate when a metric has stopped improving.\n\n    Models often benefit from reducing the learning rate by a factor\n    of 2-10 once learning stagnates. This callback monitors a\n    quantity and if no improvement is seen for a 'patience' number\n    of epochs, the learning rate is reduced.\n\n    Example:\n\n    ```python\n    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2,\n                                  patience=5, min_lr=0.001)\n    model.fit(x_train, y_train, callbacks=[reduce_lr])\n    ```\n\n    Args:\n        monitor: String. Quantity to be monitored.\n        factor: Float. Factor by which the learning rate will be reduced.\n            `new_lr = lr * factor`.\n        patience: Integer. Number of epochs with no improvement after which\n            learning rate will be reduced.\n        verbose: Integer. 0: quiet, 1: update messages.\n        mode: String. One of `{'auto', 'min', 'max'}`. In `'min'` mode,\n            the learning rate will be reduced when the\n            quantity monitored has stopped decreasing; in `'max'` mode it will\n            be reduced when the quantity monitored has stopped increasing; in\n            `'auto'` mode, the direction is automatically inferred from the name\n            of the monitored quantity.\n        min_delta: Float. Threshold for measuring the new optimum, to only focus\n            on significant changes.\n        cooldown: Integer. Number of epochs to wait before resuming normal\n            operation after the learning rate has been reduced.\n        min_lr: Float. 
Lower bound on the learning rate.\n    \"\"\"\n\n    def __init__(\n        self,\n        monitor=\"val_loss\",\n        factor=0.1,\n        patience=10,\n        verbose=0,\n        mode=\"auto\",\n        min_delta=1e-4,\n        cooldown=0,\n        min_lr=0.0,\n        **kwargs,\n    ):\n        super().__init__(monitor, mode, min_delta=min_delta)\n        if factor >= 1.0:\n            raise ValueError(\n                \"ReduceLROnPlateau does not support a factor >= 1.0. \"\n                f\"Received factor={factor}\"\n            )\n\n        self.factor = factor\n        self.min_lr = min_lr\n        self.patience = patience\n        self.verbose = verbose\n        self.cooldown = cooldown\n        self.cooldown_counter = 0  # Cooldown counter.\n        self.wait = 0\n\n    def _reset(self):\n        \"\"\"Resets wait counter and cooldown counter.\"\"\"\n        self.cooldown_counter = 0\n        self.wait = 0\n\n    def on_train_begin(self, logs=None):\n        self._reset()\n\n    def on_epoch_end(self, epoch, logs=None):\n        if self.monitor_op is None:\n            # Delay setup until the model's metrics are all built\n            self._set_monitor_op()\n        logs = logs or {}\n        logs[\"learning_rate\"] = float(\n            backend.convert_to_numpy(self.model.optimizer.learning_rate)\n        )\n        current = logs.get(self.monitor)\n\n        if current is None:\n            warnings.warn(\n                \"Learning rate reduction is conditioned on metric \"\n                f\"`{self.monitor}` which is not available. 
Available metrics \"\n                f\"are: {','.join(list(logs.keys()))}.\",\n                stacklevel=2,\n            )\n        else:\n            if self.in_cooldown():\n                self.cooldown_counter -= 1\n                self.wait = 0\n\n            if self._is_improvement(current, self.best):\n                self.best = current\n                self.wait = 0\n            elif not self.in_cooldown():\n                self.wait += 1\n                if self.wait >= self.patience:\n                    old_lr = float(\n                        backend.convert_to_numpy(\n                            self.model.optimizer.learning_rate\n                        )\n                    )\n                    if old_lr > np.float32(self.min_lr):\n                        new_lr = old_lr * self.factor\n                        new_lr = max(new_lr, self.min_lr)\n                        self.model.optimizer.learning_rate = new_lr\n                        if self.verbose > 0:\n                            io_utils.print_msg(\n                                f\"\\nEpoch {epoch + 1}: \"\n                                \"ReduceLROnPlateau reducing \"\n                                f\"learning rate to {new_lr}.\"\n                            )\n                        self.cooldown_counter = self.cooldown\n                        self.wait = 0\n\n    def in_cooldown(self):\n        return self.cooldown_counter > 0\n"
  },
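The patience/cooldown arithmetic in `ReduceLROnPlateau.on_epoch_end` can be sketched standalone. `simulate_plateau` below is an illustrative helper, not a Keras API: a cooldown epoch freezes the `wait` counter, a reduction arms the cooldown, and the first epoch is modeled as an improvement because `best` starts at infinity.

```python
# Illustrative mirror of ReduceLROnPlateau's per-epoch bookkeeping.
def simulate_plateau(lr, improvements, factor, patience, cooldown, min_lr=0.0):
    wait, cooldown_counter = 0, 0
    for improved in improvements:
        if cooldown_counter > 0:  # in cooldown: tick down, freeze wait
            cooldown_counter -= 1
            wait = 0
        if improved:
            wait = 0
        elif cooldown_counter == 0:  # plateau, and cooldown has expired
            wait += 1
            if wait >= patience and lr > min_lr:
                lr = max(lr * factor, min_lr)
                cooldown_counter = cooldown
                wait = 0
    return lr


# First epoch always "improves" (best starts at infinity); later ones plateau.
four_epochs = [True, False, False, False]
lr_no_cooldown = simulate_plateau(0.1, four_epochs[:2], 0.1, 1, 0)
lr_with_cooldown = simulate_plateau(0.1, four_epochs, 0.1, 1, 2)
```

This reproduces the expectations in the tests above: one reduction after two epochs without cooldown (0.1 → 0.01), and two reductions over four epochs with `cooldown=2` (0.1 → 0.01 → 0.001), since the intervening cooldown epochs suppress the wait counter.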
  {
    "path": "keras/src/callbacks/reduce_lr_on_plateau_test.py",
    "content": "import pytest\n\nfrom keras.src import callbacks\nfrom keras.src import layers\nfrom keras.src import optimizers\nfrom keras.src import testing\nfrom keras.src.models import Sequential\nfrom keras.src.testing import test_utils\nfrom keras.src.utils import io_utils\nfrom keras.src.utils import numerical_utils\n\n\nclass ReduceLROnPlateauTest(testing.TestCase):\n    def setUp(self):\n        (x_train, y_train), (x_test, y_test) = test_utils.get_test_data(\n            train_samples=10,\n            test_samples=10,\n            input_shape=(3,),\n            num_classes=2,\n        )\n        y_test = numerical_utils.to_categorical(y_test)\n        y_train = numerical_utils.to_categorical(y_train)\n\n        model = Sequential([layers.Dense(5), layers.Dense(2)])\n\n        model.compile(\n            loss=\"mse\",\n            optimizer=optimizers.Adam(0.1),\n        )\n\n        self.model = model\n        self.x_train = x_train\n        self.x_test = x_test\n        self.y_train = y_train\n        self.y_test = y_test\n\n    @pytest.mark.requires_trainable_backend\n    def test_reduces_lr_with_model_fit(self):\n        reduce_lr = callbacks.ReduceLROnPlateau(\n            patience=1, factor=0.1, monitor=\"val_loss\", min_delta=100\n        )\n\n        self.model.fit(\n            self.x_train,\n            self.y_train,\n            validation_data=(self.x_test, self.y_test),\n            callbacks=[reduce_lr],\n            epochs=2,\n        )\n\n        self.assertEqual(self.model.optimizer.learning_rate.value, 0.01)\n\n    @pytest.mark.requires_trainable_backend\n    def test_throws_when_optimizer_has_schedule(self):\n        reduce_lr = callbacks.ReduceLROnPlateau(\n            patience=1, factor=0.1, monitor=\"val_loss\", min_delta=100\n        )\n\n        self.model.compile(\n            loss=\"mse\",\n            optimizer=optimizers.Adam(\n                optimizers.schedules.PolynomialDecay(\n                    initial_learning_rate=0.1, 
decay_steps=10\n                )\n            ),\n        )\n\n        with self.assertRaisesRegex(\n            TypeError,\n            \"This optimizer was created with a `LearningRateSchedule`\",\n        ):\n            self.model.fit(\n                self.x_train,\n                self.y_train,\n                validation_data=(self.x_test, self.y_test),\n                callbacks=[reduce_lr],\n                epochs=2,\n            )\n\n    @pytest.mark.requires_trainable_backend\n    def test_verbose_logging(self):\n        reduce_lr = callbacks.ReduceLROnPlateau(\n            patience=1, factor=0.1, monitor=\"val_loss\", min_delta=100, verbose=1\n        )\n        io_utils.disable_interactive_logging()\n        io_utils.set_logging_verbosity(\"INFO\")\n\n        with self.assertLogs() as logs:\n            self.model.fit(\n                self.x_train,\n                self.y_train,\n                validation_data=(self.x_test, self.y_test),\n                callbacks=[reduce_lr],\n                epochs=2,\n            )\n            expected_log = \"ReduceLROnPlateau reducing learning rate to 0.01\"\n            self.assertTrue(any(expected_log in log for log in logs.output))\n\n    @pytest.mark.requires_trainable_backend\n    def test_honors_min_lr(self):\n        reduce_lr = callbacks.ReduceLROnPlateau(\n            patience=1,\n            factor=0.1,\n            monitor=\"val_loss\",\n            min_delta=10,\n            min_lr=0.005,\n        )\n\n        self.model.fit(\n            self.x_train,\n            self.y_train,\n            validation_data=(self.x_test, self.y_test),\n            callbacks=[reduce_lr],\n            epochs=4,\n        )\n\n        self.assertEqual(self.model.optimizer.learning_rate.value, 0.005)\n\n    @pytest.mark.requires_trainable_backend\n    def test_cooldown(self):\n        reduce_lr = callbacks.ReduceLROnPlateau(\n            patience=1,\n            factor=0.1,\n            monitor=\"val_loss\",\n           
 min_delta=100,\n            cooldown=2,\n        )\n\n        self.model.fit(\n            self.x_train,\n            self.y_train,\n            validation_data=(self.x_test, self.y_test),\n            callbacks=[reduce_lr],\n            epochs=4,\n        )\n\n        # With a cooldown of 2 epochs, we should only reduce the LR every other\n        # epoch, so after 4 epochs we will have reduced 2 times.\n        self.assertAllClose(self.model.optimizer.learning_rate.value, 0.001)\n"
  },
  {
    "path": "keras/src/callbacks/remote_monitor.py",
    "content": "import json\nimport warnings\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.callbacks.callback import Callback\n\ntry:\n    import requests\nexcept ImportError:\n    requests = None\n\n\n@keras_export(\"keras.callbacks.RemoteMonitor\")\nclass RemoteMonitor(Callback):\n    \"\"\"Callback used to stream events to a server.\n\n    Requires the `requests` library.\n    Events are sent to `root + '/publish/epoch/end/'` by default. Calls are\n    HTTP POST, with a `data` argument which is a\n    JSON-encoded dictionary of event data.\n    If `send_as_json=True`, the content type of the request will be\n    `\"application/json\"`.\n    Otherwise the serialized JSON will be sent within a form.\n\n    Args:\n        root: String; root url of the target server.\n        path: String; path relative to `root` to which the events will be sent.\n        field: String; JSON field under which the data will be stored.\n            The field is used only if the payload is sent within a form\n            (i.e. 
when `send_as_json=False`).\n        headers: Dictionary; optional custom HTTP headers.\n        send_as_json: Boolean; whether the request should be\n            sent as `\"application/json\"`.\n    \"\"\"\n\n    def __init__(\n        self,\n        root=\"http://localhost:9000\",\n        path=\"/publish/epoch/end/\",\n        field=\"data\",\n        headers=None,\n        send_as_json=False,\n    ):\n        super().__init__()\n\n        self.root = root\n        self.path = path\n        self.field = field\n        self.headers = headers\n        self.send_as_json = send_as_json\n\n    def on_epoch_end(self, epoch, logs=None):\n        if requests is None:\n            raise ImportError(\"RemoteMonitor requires the `requests` library.\")\n        logs = logs or {}\n        send = {}\n        send[\"epoch\"] = epoch\n        for k, v in logs.items():\n            # np.ndarray and np.generic are not scalar types\n            # therefore we must unwrap their scalar values and\n            # pass to the json-serializable dict 'send'\n            if isinstance(v, (np.ndarray, np.generic)):\n                send[k] = v.item()\n            else:\n                send[k] = v\n        try:\n            if self.send_as_json:\n                requests.post(\n                    self.root + self.path, json=send, headers=self.headers\n                )\n            else:\n                requests.post(\n                    self.root + self.path,\n                    {self.field: json.dumps(send)},\n                    headers=self.headers,\n                )\n        except requests.exceptions.RequestException:\n            warnings.warn(\n                f\"Could not reach RemoteMonitor root server at {self.root}\",\n                stacklevel=2,\n            )\n"
  },
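`RemoteMonitor` posts one of two payload shapes per epoch, as described in its docstring: with `send_as_json=True` the event dict goes out as the JSON request body; otherwise the serialized JSON string is nested under `field` in a form body. A minimal sketch of the two shapes (no network call; the metric values are made up):

```python
# Illustrative payloads for the two RemoteMonitor modes.
import json

field = "data"  # RemoteMonitor's default `field`
send = {"epoch": 0, "loss": 0.25, "val_loss": 0.31}

json_payload = send                       # requests.post(url, json=send, ...)
form_payload = {field: json.dumps(send)}  # requests.post(url, form_payload, ...)

# A server receiving the form variant recovers the event like this:
decoded = json.loads(form_payload[field])
```

The `field` key only matters in the form case, which is why the docstring notes it is ignored when `send_as_json=True`.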
  {
    "path": "keras/src/callbacks/remote_monitor_test.py",
    "content": "import warnings\nfrom unittest import mock\n\nimport numpy as np\n\nfrom conftest import skip_if_backend\nfrom keras.src import backend\nfrom keras.src import callbacks\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.models import Sequential\nfrom keras.src.utils import numerical_utils\n\ntry:\n    import requests\nexcept ImportError:\n    requests = None\n\n\nclass TerminateOnNaNTest(testing.TestCase):\n    def test_RemoteMonitor(self):\n        if requests is None:\n            self.skipTest(\"`requests` required to run this test\")\n\n        monitor = callbacks.RemoteMonitor()\n        # This will raise a warning since the default address in unreachable:\n        warning_msg = \"Could not reach RemoteMonitor root server\"\n        with warnings.catch_warnings(record=True) as warning_logs:\n            warnings.simplefilter(\"always\")\n            monitor.on_epoch_end(0, logs={\"loss\": 0.0})\n            self.assertIn(warning_msg, str(warning_logs[-1].message))\n\n    def test_RemoteMonitor_np_array(self):\n        if requests is None:\n            self.skipTest(\"`requests` required to run this test\")\n\n        with mock.patch(\"requests.post\") as requests_post:\n            monitor = callbacks.RemoteMonitor(send_as_json=True)\n            a = np.arange(1)  # a 1 by 1 array\n            logs = {\"loss\": 0.0, \"val\": a}\n            monitor.on_epoch_end(0, logs=logs)\n            send = {\"loss\": 0.0, \"epoch\": 0, \"val\": 0}\n            requests_post.assert_called_once_with(\n                monitor.root + monitor.path, json=send, headers=monitor.headers\n            )\n\n    def test_RemoteMonitor_np_float32(self):\n        if requests is None:\n            self.skipTest(\"`requests` required to run this test\")\n\n        with mock.patch(\"requests.post\") as requests_post:\n            monitor = callbacks.RemoteMonitor(send_as_json=True)\n            a = np.float32(1.0)  # a float32 generic type\n        
    logs = {\"loss\": 0.0, \"val\": a}\n            monitor.on_epoch_end(0, logs=logs)\n            send = {\"loss\": 0.0, \"epoch\": 0, \"val\": 1.0}\n            requests_post.assert_called_once_with(\n                monitor.root + monitor.path, json=send, headers=monitor.headers\n            )\n\n    @skip_if_backend(\n        \"openvino\", \"openvino backend does not support `fit` method\"\n    )\n    def test_RemoteMonitorWithJsonPayload(self):\n        if requests is None:\n            self.skipTest(\"`requests` required to run this test\")\n\n        if backend.backend() == \"numpy\":\n            self.skipTest(\"Trainer not implemented from NumPy backend.\")\n        TRAIN_SAMPLES = 10\n        TEST_SAMPLES = 10\n        INPUT_DIM = 3\n        NUM_CLASSES = 2\n        BATCH_SIZE = 4\n\n        np.random.seed(1337)\n        x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM))\n        y_train = np.random.choice(np.arange(NUM_CLASSES), size=TRAIN_SAMPLES)\n        x_test = np.random.random((TEST_SAMPLES, INPUT_DIM))\n        y_test = np.random.choice(np.arange(NUM_CLASSES), size=TEST_SAMPLES)\n        y_test = numerical_utils.to_categorical(y_test)\n        y_train = numerical_utils.to_categorical(y_train)\n\n        model = Sequential([layers.Dense(NUM_CLASSES)])\n        model.compile(loss=\"mean_squared_error\", optimizer=\"sgd\")\n\n        with mock.patch(\"requests.post\") as requests_post:\n            monitor = callbacks.RemoteMonitor(send_as_json=True)\n            hist = model.fit(\n                x_train,\n                y_train,\n                batch_size=BATCH_SIZE,\n                validation_data=(x_test, y_test),\n                callbacks=[monitor],\n                epochs=1,\n            )\n            send = {\n                \"epoch\": 0,\n                \"loss\": hist.history[\"loss\"][0],\n                \"val_loss\": hist.history[\"val_loss\"][0],\n            }\n            requests_post.assert_called_once_with(\n              
  monitor.root + monitor.path, json=send, headers=monitor.headers\n            )\n"
  },
  {
    "path": "keras/src/callbacks/swap_ema_weights.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.callbacks.callback import Callback\n\n\n@keras_export(\"keras.callbacks.SwapEMAWeights\")\nclass SwapEMAWeights(Callback):\n    \"\"\"Swaps model weights and EMA weights before and after evaluation.\n\n    This callbacks replaces the model's weight values with the values of\n    the optimizer's EMA weights (the exponential moving average of the past\n    model weights values, implementing \"Polyak averaging\") before model\n    evaluation, and restores the previous weights after evaluation.\n\n    The `SwapEMAWeights` callback is to be used in conjunction with\n    an optimizer that sets `use_ema=True`.\n\n    Note that the weights are swapped in-place in order to save memory.\n    The behavior is undefined if you modify the EMA weights\n    or model weights in other callbacks.\n\n    Example:\n\n    ```python\n    # Remember to set `use_ema=True` in the optimizer\n    optimizer = SGD(use_ema=True)\n    model.compile(optimizer=optimizer, loss=..., metrics=...)\n\n    # Metrics will be computed with EMA weights\n    model.fit(X_train, Y_train, callbacks=[SwapEMAWeights()])\n\n    # If you want to save model checkpoint with EMA weights, you can set\n    # `swap_on_epoch=True` and place ModelCheckpoint after SwapEMAWeights.\n    model.fit(\n        X_train,\n        Y_train,\n        callbacks=[SwapEMAWeights(swap_on_epoch=True), ModelCheckpoint(...)]\n    )\n    ```\n\n    Args:\n        swap_on_epoch: whether to perform swapping at `on_epoch_begin()`\n            and `on_epoch_end()`. 
This is useful if you want to use\n            EMA weights for other callbacks such as `ModelCheckpoint`.\n            Defaults to `False`.\n    \"\"\"\n\n    def __init__(self, swap_on_epoch=False):\n        super().__init__()\n        self.swap_on_epoch = swap_on_epoch\n\n        self._ema_weights_in_model = False\n\n    def _tf_swap_variables(self, optimizer):\n        for var, average_var in zip(\n            self.model.trainable_variables,\n            optimizer._model_variables_moving_average,\n        ):\n            if isinstance(var, backend.Variable):\n                var = var.value\n            if isinstance(average_var, backend.Variable):\n                average_var = average_var.value\n            # swap using addition to prevent variable creation\n            optimizer._distribution_strategy.extended.update(\n                var,\n                lambda a, b: a.assign_add(b),\n                args=(average_var,),\n            )\n            optimizer._distribution_strategy.extended.update(\n                var,\n                lambda a, b: b.assign(a - b),\n                args=(average_var,),\n            )\n            optimizer._distribution_strategy.extended.update(\n                var,\n                lambda a, b: a.assign(a - b),\n                args=(average_var,),\n            )\n\n    def _backend_swap_variables(self, optimizer):\n        for var, average_var in zip(\n            self.model.trainable_variables,\n            optimizer._model_variables_moving_average,\n        ):\n            temporary_variable = ops.convert_to_numpy(var)\n            var.assign(average_var)\n            average_var.assign(temporary_variable)\n\n    def _tf_finalize_ema_values(self, optimizer):\n        for var, average_var in zip(\n            self.model.trainable_variables,\n            optimizer._model_variables_moving_average,\n        ):\n            if isinstance(var, backend.Variable):\n                var = var.value\n            if 
isinstance(average_var, backend.Variable):\n                average_var = average_var.value\n            optimizer._distribution_strategy.extended.update(\n                average_var,\n                lambda a, b: a.assign(b),\n                args=(var,),\n            )\n\n    def _backend_finalize_ema_values(self, optimizer):\n        for var, average_var in zip(\n            self.model.trainable_variables,\n            optimizer._model_variables_moving_average,\n        ):\n            average_var.assign(var)\n\n    def _swap_variables(self):\n        if hasattr(self.model.optimizer, \"inner_optimizer\"):\n            # LossScaleOptimizer\n            optimizer = self.model.optimizer.inner_optimizer\n        else:\n            optimizer = self.model.optimizer\n        if not hasattr(optimizer, \"_model_variables_moving_average\"):\n            raise ValueError(\n                \"SwapEMAWeights must be used when \"\n                \"`use_ema=True` is set on the optimizer. \"\n                f\"Received: use_ema={optimizer.use_ema}\"\n            )\n        if backend.backend() == \"tensorflow\":\n            self._tf_swap_variables(optimizer)\n        else:\n            self._backend_swap_variables(optimizer)\n\n    def _finalize_ema_values(self):\n        if hasattr(self.model.optimizer, \"inner_optimizer\"):\n            # LossScaleOptimizer\n            optimizer = self.model.optimizer.inner_optimizer\n        else:\n            optimizer = self.model.optimizer\n        if not hasattr(optimizer, \"_model_variables_moving_average\"):\n            raise ValueError(\n                \"SwapEMAWeights must be used when \"\n                \"`use_ema=True` is set on the optimizer. 
\"\n                f\"Received: use_ema={optimizer.use_ema}\"\n            )\n        if backend.backend() == \"tensorflow\":\n            self._tf_finalize_ema_values(optimizer)\n        else:\n            self._backend_finalize_ema_values(optimizer)\n\n    def on_epoch_begin(self, epoch, logs=None):\n        if self.swap_on_epoch and self._ema_weights_in_model:\n            self._swap_variables()\n            self._ema_weights_in_model = False\n\n    def on_epoch_end(self, epoch, logs=None):\n        if self.swap_on_epoch and not self._ema_weights_in_model:\n            self._swap_variables()\n            self._ema_weights_in_model = True\n            # We need to recover EMA weights from the previously swapped weights\n            # in the last epoch. This is because, at the end of the fitting,\n            # `finalize_variable_values` will be called to assign\n            # `_model_variables_moving_average` to `trainable_variables`.\n            if epoch == self.params[\"epochs\"] - 1:\n                self._finalize_ema_values()\n\n    def on_test_begin(self, logs=None):\n        if not self._ema_weights_in_model:\n            self._swap_variables()\n            self._ema_weights_in_model = True\n\n    def on_test_end(self, logs=None):\n        if self._ema_weights_in_model:\n            self._swap_variables()\n            self._ema_weights_in_model = False\n\n    def on_predict_begin(self, logs=None):\n        if not self._ema_weights_in_model:\n            self._swap_variables()\n            self._ema_weights_in_model = True\n\n    def on_predict_end(self, logs=None):\n        if not self._ema_weights_in_model:\n            self._swap_variables()\n            self._ema_weights_in_model = False\n"
  },
  {
    "path": "keras/src/callbacks/swap_ema_weights_test.py",
    "content": "import os.path\nimport tempfile\n\nimport pytest\nimport tensorflow as tf\nfrom tensorflow.python.eager import context\n\nfrom keras.src import backend\nfrom keras.src import callbacks\nfrom keras.src import layers\nfrom keras.src import losses\nfrom keras.src import metrics\nfrom keras.src import optimizers\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.models import Sequential\nfrom keras.src.testing import test_utils\nfrom keras.src.utils import numerical_utils\n\n\nclass SwapEMAWeightsTest(testing.TestCase):\n    def setUp(self):\n        (x_train, y_train), _ = test_utils.get_test_data(\n            train_samples=10,\n            test_samples=10,\n            input_shape=(3,),\n            num_classes=2,\n            random_seed=2023,\n        )\n        y_train = numerical_utils.to_categorical(y_train)\n\n        self.x_train = x_train\n        self.y_train = y_train\n\n    def _get_compiled_model(\n        self, use_ema=True, jit_compile=True, loss_scale=False\n    ):\n        optimizer = optimizers.SGD(use_ema=use_ema, ema_momentum=0.9)\n        if loss_scale:\n            optimizer = optimizers.LossScaleOptimizer(optimizer)\n        model = Sequential(\n            [layers.Dense(2, kernel_initializer=\"ones\", use_bias=False)]\n        )\n        model.compile(\n            optimizer=optimizer,\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n            jit_compile=jit_compile,\n        )\n        return model\n\n    @pytest.mark.requires_trainable_backend\n    def test_swap_ema_weights_with_invalid_optimizer(self):\n        model = self._get_compiled_model(use_ema=False)\n        with self.assertRaisesRegex(\n            ValueError,\n            (\"SwapEMAWeights must be used when `use_ema=True` is set\"),\n        ):\n            model.fit(\n                self.x_train,\n                self.y_train,\n                epochs=2,\n                
callbacks=[callbacks.SwapEMAWeights()],\n                validation_data=(self.x_train, self.y_train),\n            )\n\n    @pytest.mark.requires_trainable_backend\n    def test_swap_ema_weights(self):\n        # not using SwapEMAWeights\n        model = self._get_compiled_model()\n        history = model.fit(\n            self.x_train,\n            self.y_train,\n            epochs=2,\n            validation_data=(self.x_train, self.y_train),\n        )\n        logs = model.evaluate(self.x_train, self.y_train, return_dict=True)\n        # final metric during fitting is different from the evaluation\n        self.assertNotEqual(\n            history.history[\"val_mean_squared_error\"][-1],\n            logs[\"mean_squared_error\"],\n        )\n\n        # using SwapEMAWeights\n        model = self._get_compiled_model()\n        history = model.fit(\n            self.x_train,\n            self.y_train,\n            epochs=2,\n            callbacks=[callbacks.SwapEMAWeights()],\n            validation_data=(self.x_train, self.y_train),\n        )\n        logs = model.evaluate(self.x_train, self.y_train, return_dict=True)\n        # final metric during fitting is same as the evaluation\n        self.assertEqual(\n            history.history[\"val_mean_squared_error\"][-1],\n            logs[\"mean_squared_error\"],\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_swap_ema_weights_on_epoch(self):\n        # using SwapEMAWeights together with ModelCheckpoint\n        model = self._get_compiled_model()\n        with tempfile.TemporaryDirectory() as temp_dir:\n            model.fit(\n                self.x_train,\n                self.y_train,\n                epochs=2,\n                callbacks=[\n                    callbacks.SwapEMAWeights(swap_on_epoch=True),\n                    callbacks.ModelCheckpoint(\n                        os.path.join(temp_dir, \"{epoch:1d}.keras\")\n                    ),\n                ],\n                
validation_data=(self.x_train, self.y_train),\n            )\n            model2 = saving.load_model(os.path.join(temp_dir, \"2.keras\"))\n\n        logs = model.evaluate(self.x_train, self.y_train, return_dict=True)\n        logs2 = model2.evaluate(self.x_train, self.y_train, return_dict=True)\n        # the saved checkpoint was written with the EMA weights applied\n        self.assertEqual(\n            logs[\"mean_squared_error\"],\n            logs2[\"mean_squared_error\"],\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_swap_ema_weights_with_loss_scale_optimizer(self):\n        model = self._get_compiled_model(loss_scale=True)\n        history = model.fit(\n            self.x_train,\n            self.y_train,\n            epochs=2,\n            callbacks=[callbacks.SwapEMAWeights()],\n            validation_data=(self.x_train, self.y_train),\n        )\n        logs = model.evaluate(self.x_train, self.y_train, return_dict=True)\n        # final metric during fitting is same as the evaluation\n        self.assertEqual(\n            history.history[\"val_mean_squared_error\"][-1],\n            logs[\"mean_squared_error\"],\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"The distribute test can only run with TF backend.\",\n    )\n    def test_swap_ema_weights_with_tf_distribute(self):\n        # Need at least 2 devices for distribution related tests.\n        cpus = tf.config.list_physical_devices(\"CPU\")\n        context._reset_context()\n        tf.config.set_logical_device_configuration(\n            cpus[0],\n            [\n                tf.config.LogicalDeviceConfiguration(),\n                tf.config.LogicalDeviceConfiguration(),\n            ],\n        )\n        strategy = tf.distribute.MirroredStrategy([\"CPU:0\", \"CPU:1\"])\n        with strategy.scope():\n            # TODO: set jit_compile=True once the issue is resolved in\n            # 
integration_tests/tf_distribute_training_test.py#L52\n            model = self._get_compiled_model(jit_compile=False)\n            with tempfile.TemporaryDirectory() as temp_dir:\n                model.fit(\n                    self.x_train,\n                    self.y_train,\n                    epochs=2,\n                    callbacks=[\n                        callbacks.SwapEMAWeights(swap_on_epoch=True),\n                        callbacks.ModelCheckpoint(\n                            os.path.join(\n                                temp_dir, \"distributed_{epoch:1d}.keras\"\n                            )\n                        ),\n                    ],\n                    validation_data=(self.x_train, self.y_train),\n                )\n                model2 = saving.load_model(\n                    os.path.join(temp_dir, \"distributed_2.keras\")\n                )\n        logs = model.evaluate(self.x_train, self.y_train, return_dict=True)\n        logs2 = model2.evaluate(self.x_train, self.y_train, return_dict=True)\n        # the saved checkpoint was written with the EMA weights applied\n        self.assertEqual(\n            logs[\"mean_squared_error\"],\n            logs2[\"mean_squared_error\"],\n        )\n"
  },
  {
    "path": "keras/src/callbacks/tensorboard.py",
    "content": "import logging\nimport os\nimport sys\nimport time\nimport warnings\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.callbacks.callback import Callback\nfrom keras.src.layers import Embedding\nfrom keras.src.optimizers import Optimizer\nfrom keras.src.utils import file_utils\n\n\n@keras_export(\"keras.callbacks.TensorBoard\")\nclass TensorBoard(Callback):\n    \"\"\"Enable visualizations for TensorBoard.\n\n    TensorBoard is a visualization tool provided with TensorFlow. A TensorFlow\n    installation is required to use this callback.\n\n    This callback logs events for TensorBoard, including:\n\n    * Metrics summary plots\n    * Training graph visualization\n    * Weight histograms\n    * Sampled profiling\n\n    When used in `model.evaluate()` or regular validation\n    in addition to epoch summaries, there will be a summary that records\n    evaluation metrics vs `model.optimizer.iterations` written. The metric names\n    will be prepended with `evaluation`, with `model.optimizer.iterations` being\n    the step in the visualized TensorBoard.\n\n    If you have installed TensorFlow with pip, you should be able\n    to launch TensorBoard from the command line:\n\n    ```\n    tensorboard --logdir=path_to_your_logs\n    ```\n\n    You can find more information about TensorBoard\n    [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).\n\n    Args:\n        log_dir: the path of the directory where to save the log files to be\n            parsed by TensorBoard. e.g.,\n            `log_dir = os.path.join(working_dir, 'logs')`.\n            This directory should not be reused by any other callbacks.\n        histogram_freq: frequency (in epochs) at which to compute\n            weight histograms for the layers of the model. If set to 0,\n            histograms won't be computed. 
Validation data (or split) must be\n            specified for histogram visualizations.\n        write_graph:  (Not supported at this time)\n            Whether to visualize the graph in TensorBoard.\n            Note that the log file can become quite large\n            when `write_graph` is set to `True`.\n        write_images: whether to write model weights to visualize as image in\n            TensorBoard.\n        write_steps_per_second: whether to log the training steps per second\n            into TensorBoard. This supports both epoch and batch frequency\n            logging.\n        update_freq: `\"batch\"` or `\"epoch\"` or integer. When using `\"epoch\"`,\n            writes the losses and metrics to TensorBoard after every epoch.\n            If using an integer, let's say `1000`, all metrics and losses\n            (including custom ones added by `Model.compile`) will be logged to\n            TensorBoard every 1000 batches. `\"batch\"` is a synonym for 1,\n            meaning that they will be written every batch.\n            Note however that writing too frequently to TensorBoard can slow\n            down your training, especially when used with distribution\n            strategies as it will incur additional synchronization overhead.\n            Batch-level summary writing is also available via `train_step`\n            override. Please see\n            [TensorBoard Scalars tutorial](\n                https://www.tensorflow.org/tensorboard/scalars_and_keras#batch-level_logging)\n            for more details.\n        profile_batch: Profile the batch(es) to sample compute characteristics.\n            profile_batch must be a non-negative integer or a tuple of integers.\n            A pair of positive integers signify a range of batches to profile.\n            By default, profiling is disabled.\n        embeddings_freq: frequency (in epochs) at which embedding layers will be\n            visualized. 
If set to 0, embeddings won't be visualized.\n        embeddings_metadata: Dictionary which maps embedding layer names to the\n            filename of a file in which to save metadata for the embedding layer.\n            In case the same metadata file is to be\n            used for all embedding layers, a single filename can be passed.\n\n    Examples:\n\n    ```python\n    tensorboard_callback = keras.callbacks.TensorBoard(log_dir=\"./logs\")\n    model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])\n    # Then run the tensorboard command to view the visualizations.\n    ```\n\n    Custom batch-level summaries in a subclassed Model:\n\n    ```python\n    class MyModel(keras.Model):\n\n        def build(self, _):\n            self.dense = keras.layers.Dense(10)\n\n        def call(self, x):\n            outputs = self.dense(x)\n            tf.summary.histogram('outputs', outputs)\n            return outputs\n\n    model = MyModel()\n    model.compile('sgd', 'mse')\n\n    # Make sure to set `update_freq=N` to log a batch-level summary every N\n    # batches.  In addition to any `tf.summary` contained in `model.call()`,\n    # metrics added in `Model.compile` will be logged every N batches.\n    tb_callback = keras.callbacks.TensorBoard('./logs', update_freq=1)\n    model.fit(x_train, y_train, callbacks=[tb_callback])\n    ```\n\n    Custom batch-level summaries in a Functional API Model:\n\n    ```python\n    def my_summary(x):\n        tf.summary.histogram('x', x)\n        return x\n\n    inputs = keras.Input(10)\n    x = keras.layers.Dense(10)(inputs)\n    outputs = keras.layers.Lambda(my_summary)(x)\n    model = keras.Model(inputs, outputs)\n    model.compile('sgd', 'mse')\n\n    # Make sure to set `update_freq=N` to log a batch-level summary every N\n    # batches. 
In addition to any `tf.summary` contained in `Model.call`,\n    # metrics added in `Model.compile` will be logged every N batches.\n    tb_callback = keras.callbacks.TensorBoard('./logs', update_freq=1)\n    model.fit(x_train, y_train, callbacks=[tb_callback])\n    ```\n\n    Profiling:\n\n    ```python\n    # Profile a single batch, e.g. the 5th batch.\n    tensorboard_callback = keras.callbacks.TensorBoard(\n        log_dir='./logs', profile_batch=5)\n    model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])\n\n    # Profile a range of batches, e.g. from 10 to 20.\n    tensorboard_callback = keras.callbacks.TensorBoard(\n        log_dir='./logs', profile_batch=(10,20))\n    model.fit(x_train, y_train, epochs=2, callbacks=[tensorboard_callback])\n    ```\n    \"\"\"  # noqa: E501\n\n    def __init__(\n        self,\n        log_dir=\"logs\",\n        histogram_freq=0,\n        write_graph=True,\n        write_images=False,\n        write_steps_per_second=False,\n        update_freq=\"epoch\",\n        profile_batch=0,\n        embeddings_freq=0,\n        embeddings_metadata=None,\n    ):\n        super().__init__()\n\n        self.log_dir = str(log_dir)\n        self.histogram_freq = histogram_freq\n        self.write_graph = write_graph\n        self.write_images = write_images\n        self.write_steps_per_second = write_steps_per_second\n        self.update_freq = 1 if update_freq == \"batch\" else update_freq\n        self.embeddings_freq = embeddings_freq\n        self.embeddings_metadata = embeddings_metadata\n        if profile_batch:\n            if backend.backend() not in (\"jax\", \"tensorflow\"):\n                # TODO: profiling not available in torch, numpy\n                raise ValueError(\n                    \"Profiling is not yet available with the \"\n                    f\"{backend.backend()} backend. Please open a PR \"\n                    \"if you'd like to add this feature. 
Received: \"\n                    f\"profile_batch={profile_batch} (must be 0)\"\n                )\n            elif backend.backend() == \"jax\":\n                if sys.version_info[1] < 12:\n                    warnings.warn(\n                        \"Profiling with the \"\n                        f\"{backend.backend()} backend requires python >= 3.12.\"\n                    )\n                    profile_batch = 0\n\n        self._init_profile_batch(profile_batch)\n        self._global_train_batch = 0\n        self._global_test_batch = 0\n        self._previous_epoch_iterations = 0\n        self._train_accumulated_time = 0\n        self._batch_start_time = 0\n        self._summary_module = None\n\n        # Lazily initialized in order to avoid creating event files when\n        # not needed.\n        self._writers = {}\n\n        # Used to restore any existing `SummaryWriter` after training ends.\n        self._prev_summary_state = []\n\n    def set_model(self, model):\n        \"\"\"Sets Keras model and writes graph if specified.\"\"\"\n        self._model = model\n        self._log_write_dir = self.log_dir\n\n        self._train_dir = os.path.join(self._log_write_dir, \"train\")\n        self._val_dir = os.path.join(self._log_write_dir, \"validation\")\n        self._writers = {}  # Resets writers.\n\n        self._should_write_train_graph = False\n        if self.write_graph:\n            self._write_keras_model_summary()\n            self._should_write_train_graph = True\n        if self.embeddings_freq:\n            self._configure_embeddings()\n\n    @property\n    def summary(self):\n        if self._summary_module is None:\n            import tensorflow.summary as summary\n\n            self._summary_module = summary\n        return self._summary_module\n\n    @property\n    def _train_writer(self):\n        if \"train\" not in self._writers:\n            self._writers[\"train\"] = self.summary.create_file_writer(\n                self._train_dir\n    
        )\n        return self._writers[\"train\"]\n\n    @property\n    def _val_writer(self):\n        if \"val\" not in self._writers:\n            self._writers[\"val\"] = self.summary.create_file_writer(\n                self._val_dir\n            )\n        return self._writers[\"val\"]\n\n    def _write_keras_model_train_graph(self):\n        \"\"\"Writes Keras model train_function graph to TensorBoard.\"\"\"\n        with self._train_writer.as_default():\n            train_fn = self.model.train_function\n            # If the train_function is a `tf.function`, we can write out a\n            # graph\n            if hasattr(train_fn, \"function_spec\"):\n                # TODO(b/243822285): Use _variable_creation_fn directly.\n                if hasattr(train_fn, \"_concrete_stateful_fn\"):\n                    self.summary.graph(train_fn._concrete_stateful_fn.graph)\n                else:\n                    self.summary.graph(\n                        train_fn._concrete_variable_creation_fn.graph\n                    )\n\n    def _write_keras_model_summary(self):\n        \"\"\"Writes Keras graph network summary to TensorBoard.\"\"\"\n        with self._train_writer.as_default():\n            if (\n                self.model.__class__.__name__ == \"Functional\"\n                or self.model.__class__.__name__ == \"Sequential\"\n            ):\n                keras_model_summary(\"keras\", self.model, step=0)\n\n    def _configure_embeddings(self):\n        \"\"\"Configure the Projector for embeddings.\"\"\"\n        from google.protobuf import text_format\n        from tensorboard.plugins import projector\n\n        config = projector.ProjectorConfig()\n        for layer in self.model.layers:\n            if isinstance(layer, Embedding):\n                embedding = config.embeddings.add()\n                # Embeddings are always the first layer, so this naming should\n                # be consistent in any keras models checkpoints.\n                name 
= (\n                    \"layer_with_weights-0/embeddings/.ATTRIBUTES/VARIABLE_VALUE\"\n                )\n                embedding.tensor_name = name\n\n                if self.embeddings_metadata is not None:\n                    if isinstance(self.embeddings_metadata, str):\n                        embedding.metadata_path = self.embeddings_metadata\n                    else:\n                        if layer.name in self.embeddings_metadata.keys():\n                            embedding.metadata_path = (\n                                self.embeddings_metadata.pop(layer.name)\n                            )\n\n        if self.embeddings_metadata and not isinstance(\n            self.embeddings_metadata, str\n        ):\n            raise ValueError(\n                \"Unrecognized `Embedding` layer names passed to \"\n                \"`keras.callbacks.TensorBoard` `embeddings_metadata` \"\n                f\"argument: {self.embeddings_metadata.keys()}\"\n            )\n\n        config_pbtxt = text_format.MessageToString(config)\n        path = os.path.join(self._log_write_dir, \"projector_config.pbtxt\")\n        with file_utils.File(path, \"w\") as f:\n            f.write(config_pbtxt)\n\n    def _push_writer(self, writer, step):\n        \"\"\"Sets the default writer for custom batch-level summaries.\"\"\"\n        if self.update_freq == \"epoch\":\n            return\n\n        def should_record():\n            return step % self.update_freq == 0\n\n        summary_context = (\n            writer.as_default(step),\n            self.summary.record_if(should_record),\n        )\n        self._prev_summary_state.append(summary_context)\n        summary_context[0].__enter__()\n        summary_context[1].__enter__()\n\n    def _pop_writer(self):\n        \"\"\"Pops the current writer.\"\"\"\n        if self.update_freq == \"epoch\":\n            return\n\n        # See _push_writer for the content of the previous_context, which is\n        # pair of context.\n 
       previous_context = self._prev_summary_state.pop()\n        previous_context[1].__exit__(*sys.exc_info())\n        previous_context[0].__exit__(*sys.exc_info())\n\n    def _close_writers(self):\n        for writer in self._writers.values():\n            writer.close()\n\n    def _init_profile_batch(self, profile_batch):\n        \"\"\"Validate profile_batch value and set the range of batches to profile.\n\n        Sets values of _start_batch and _stop_batch attributes,\n        specifying the start and stop batch to profile.\n        Setting `profile_batch=0` disables profiling.\n\n        Args:\n          profile_batch: The range of batches to profile. Should be a\n            non-negative integer or a comma separated string of pair of positive\n            integers. A pair of positive integers signify a range of batches to\n            profile.\n\n        Raises:\n          ValueError: If profile_batch is not an integer or a comma separated\n            pair of positive integers.\n\n        \"\"\"\n        profile_batch_error_message = (\n            \"profile_batch must be a non-negative integer or \"\n            \"2-tuple of positive \"\n            \"integers. A pair of positive integers \"\n            \"signifies a range of batches \"\n            f\"to profile. 
Found: {profile_batch}\"\n        )\n\n        # Support legacy way of specifying \"start,stop\" or \"start\" as str.\n        if isinstance(profile_batch, str):\n            profile_batch = str(profile_batch).split(\",\")\n            profile_batch = tree.map_structure(int, profile_batch)\n\n        if isinstance(profile_batch, int):\n            self._start_batch = profile_batch\n            self._stop_batch = profile_batch\n        elif (\n            isinstance(profile_batch, (tuple, list)) and len(profile_batch) == 2\n        ):\n            self._start_batch, self._stop_batch = profile_batch\n        else:\n            raise ValueError(profile_batch_error_message)\n\n        if self._start_batch < 0 or self._stop_batch < self._start_batch:\n            raise ValueError(profile_batch_error_message)\n\n        # True when the profiler was successfully started by this callback.\n        # We track the status here to make sure callbacks do not interfere with\n        # each other. The callback will only stop the profiler it started.\n        self._profiler_started = False\n        self._batch_trace_context = None\n\n        if self._start_batch > 0:\n            # Warm up and improve the profiling accuracy.\n            self._start_profiler(logdir=\"\")\n            self._stop_profiler(save=False)\n        # True when a trace is running.\n        self._is_tracing = False\n\n        # Setting `profile_batch=0` disables profiling.\n        self._should_trace = not (\n            self._start_batch == 0 and self._stop_batch == 0\n        )\n\n    def on_train_begin(self, logs=None):\n        self._global_train_batch = 0\n        self._previous_epoch_iterations = 0\n        self._push_writer(self._train_writer, self._global_train_batch)\n\n    def on_train_end(self, logs=None):\n        self._pop_writer()\n\n        if self._is_tracing:\n            self._stop_trace()\n\n        self._close_writers()\n\n    def on_test_begin(self, logs=None):\n        
self._push_writer(self._val_writer, self._global_test_batch)\n\n    def on_test_end(self, logs=None):\n        if self.model.optimizer and hasattr(self.model.optimizer, \"iterations\"):\n            with self._val_writer.as_default():\n                for name, value in logs.items():\n                    self.summary.scalar(\n                        f\"evaluation_{name}_vs_iterations\",\n                        value,\n                        step=self.model.optimizer.iterations,\n                    )\n        self._pop_writer()\n\n    def on_train_batch_begin(self, batch, logs=None):\n        self._global_train_batch += 1\n        if self.write_steps_per_second:\n            self._batch_start_time = time.time()\n        if not self._should_trace:\n            return\n\n        if self._global_train_batch == self._start_batch:\n            self._start_trace()\n        if self._profiler_started:\n            self._batch_trace_context = backend.tensorboard.start_batch_trace(\n                batch\n            )\n\n    def on_train_batch_end(self, batch, logs=None):\n        if self._should_write_train_graph:\n            self._write_keras_model_train_graph()\n            self._should_write_train_graph = False\n        if self.write_steps_per_second:\n            batch_run_time = time.time() - self._batch_start_time\n            self.summary.scalar(\n                \"batch_steps_per_second\",\n                1.0 / batch_run_time,\n                step=self._global_train_batch,\n            )\n\n        # `logs` isn't necessarily always a dict\n        if isinstance(logs, dict):\n            for name, value in logs.items():\n                self.summary.scalar(\n                    f\"batch_{name}\", value, step=self._global_train_batch\n                )\n\n        if not self._should_trace:\n            return\n\n        if self._is_tracing:\n            if self._profiler_started and self._batch_trace_context is not None:\n                
backend.tensorboard.stop_batch_trace(self._batch_trace_context)\n                self._batch_trace_context = None\n            if self._global_train_batch >= self._stop_batch:\n                self._stop_trace()\n\n    def on_test_batch_begin(self, batch, logs=None):\n        self._global_test_batch += 1\n\n    def on_epoch_begin(self, epoch, logs=None):\n        # Keeps track of epoch for profiling.\n        if self.write_steps_per_second:\n            self._previous_epoch_iterations = ops.convert_to_tensor(\n                self.model.optimizer.iterations, \"float32\"\n            )\n            self._epoch_start_time = time.time()\n\n    def on_epoch_end(self, epoch, logs=None):\n        \"\"\"Runs metrics and histogram summaries at epoch end.\"\"\"\n        self._log_epoch_metrics(epoch, logs)\n\n        if self.histogram_freq and epoch % self.histogram_freq == 0:\n            self._log_weights(epoch)\n\n        if self.embeddings_freq and epoch % self.embeddings_freq == 0:\n            self._log_embeddings(epoch)\n\n    def _start_trace(self):\n        self.summary.trace_on(graph=True, profiler=False)\n        self._start_profiler(logdir=self._train_dir)\n        self._is_tracing = True\n\n    def _stop_trace(self, batch=None):\n        \"\"\"Logs the trace graph to TensorBoard.\"\"\"\n        if batch is None:\n            batch = self._stop_batch\n        with self._train_writer.as_default():\n            # TODO(b/126388999): Remove step info in the summary name.\n            self.summary.trace_export(name=\"batch_%d\" % batch, step=batch)\n        self._stop_profiler()\n        self._is_tracing = False\n\n    def _collect_learning_rate(self, logs):\n        if isinstance(self.model.optimizer, Optimizer):\n            logs[\"learning_rate\"] = float(\n                ops.convert_to_numpy(self.model.optimizer.learning_rate)\n            )\n        return logs\n\n    def _compute_steps_per_second(self):\n        current_iteration = 
self.model.optimizer.iterations\n        time_since_epoch_begin = time.time() - self._epoch_start_time\n        current_iteration = ops.convert_to_tensor(current_iteration, \"float32\")\n        time_since_epoch_begin = ops.convert_to_tensor(\n            time_since_epoch_begin, \"float32\"\n        )\n\n        steps_per_second = (\n            current_iteration - self._previous_epoch_iterations\n        ) / time_since_epoch_begin\n        return float(steps_per_second)\n\n    def _log_epoch_metrics(self, epoch, logs):\n        \"\"\"Writes epoch metrics out as scalar summaries.\n\n        Args:\n            epoch: Int. The global step to use for TensorBoard.\n            logs: Dict. Keys are scalar summary names, values are scalars.\n        \"\"\"\n        if not logs:\n            return\n\n        train_logs = {k: v for k, v in logs.items() if not k.startswith(\"val_\")}\n        val_logs = {k: v for k, v in logs.items() if k.startswith(\"val_\")}\n        train_logs = self._collect_learning_rate(train_logs)\n        if self.write_steps_per_second:\n            train_logs[\"steps_per_second\"] = self._compute_steps_per_second()\n\n        if train_logs:\n            with self._train_writer.as_default():\n                for name, value in train_logs.items():\n                    self.summary.scalar(f\"epoch_{name}\", value, step=epoch)\n        if val_logs:\n            with self._val_writer.as_default():\n                for name, value in val_logs.items():\n                    name = name[4:]  # Remove 'val_' prefix.\n                    self.summary.scalar(f\"epoch_{name}\", value, step=epoch)\n\n    def _log_weights(self, epoch):\n        \"\"\"Logs the weights of the Model to TensorBoard.\"\"\"\n        with self._train_writer.as_default():\n            for layer in self.model.layers:\n                for weight in layer.weights:\n                    weight_name = weight.path.replace(\":\", \"_\")\n                    # Add a suffix to prevent summary tag 
name collision.\n                    histogram_weight_name = f\"{weight_name}/histogram\"\n                    self.summary.histogram(\n                        histogram_weight_name, weight, step=epoch\n                    )\n                    if self.write_images:\n                        # Add a suffix to prevent summary tag name\n                        # collision.\n                        image_weight_name = f\"{weight_name}/image\"\n                        self._log_weight_as_image(\n                            weight, image_weight_name, epoch\n                        )\n            self._train_writer.flush()\n\n    def _log_weight_as_image(self, weight, weight_name, epoch):\n        \"\"\"Logs a weight as a TensorBoard image.\"\"\"\n        w_img = ops.squeeze(weight)\n        shape = w_img.shape\n        if len(shape) == 1:  # Bias case\n            w_img = ops.reshape(w_img, [1, shape[0], 1, 1])\n        elif len(shape) == 2:  # Dense layer kernel case\n            if shape[0] > shape[1]:\n                w_img = ops.transpose(w_img)\n                shape = w_img.shape\n            w_img = ops.reshape(w_img, [1, shape[0], shape[1], 1])\n        elif len(shape) == 3:  # ConvNet case\n            if backend.image_data_format() == \"channels_last\":\n                # Switch to channels_first to display every kernel as a separate\n                # image.\n                w_img = ops.transpose(w_img, [2, 0, 1])\n                shape = w_img.shape\n            w_img = ops.reshape(w_img, [shape[0], shape[1], shape[2], 1])\n\n        w_img = backend.convert_to_numpy(w_img)\n        shape = w_img.shape\n        # Not possible to handle 3D convnets etc.\n        if len(shape) == 4 and shape[-1] in [1, 3, 4]:\n            self.summary.image(weight_name, w_img, step=epoch)\n\n    def _log_embeddings(self, epoch):\n        embeddings_ckpt = os.path.join(\n            self._log_write_dir,\n            \"train\",\n            
f\"keras_embedding.ckpt-{epoch}.weights.h5\",\n        )\n        self.model.save_weights(embeddings_ckpt)\n\n    def _start_profiler(self, logdir):\n        \"\"\"Starts the profiler if currently inactive.\n\n        Args:\n          logdir: Directory where profiler results will be saved.\n        \"\"\"\n        if self._profiler_started:\n            return\n        try:\n            backend.tensorboard.start_trace(logdir)\n            self._profiler_started = True\n        except Exception as e:\n            # Profiler errors should not be fatal.\n            logging.error(\"Failed to start profiler: %s\", e)\n\n    def _stop_profiler(self, save=True):\n        \"\"\"Stops the profiler if currently active.\n\n        Args:\n          save: Whether to save the profiler results to TensorBoard.\n        \"\"\"\n        if not self._profiler_started:\n            return\n        try:\n            backend.tensorboard.stop_trace(save=save)\n        except Exception as e:\n            # Profiler errors should not be fatal.\n            logging.error(\"Failed to stop profiler: %s\", e)\n        finally:\n            self._profiler_started = False\n\n\ndef keras_model_summary(name, data, step=None):\n    \"\"\"Writes a Keras model as JSON to a Summary.\n\n    Writing the Keras model configuration allows the TensorBoard graph plugin to\n    render a conceptual graph, as opposed to a graph of ops. If the model\n    fails to serialize as JSON, the error is ignored and False is returned.\n\n    Args:\n        name: A name for this summary. 
The summary tag used for TensorBoard will\n            be this name prefixed by any active name scopes.\n        data: A Keras Model to write.\n        step: Explicit `int64`-castable monotonic step value for this summary.\n            If omitted, this defaults to `tf.summary.experimental.get_step()`,\n            which must not be `None`.\n\n    Returns:\n        True on success, or False if no summary was written because no default\n        summary writer was available.\n\n    Raises:\n        ValueError: if a default writer exists, but no step was provided and\n            `tf.summary.experimental.get_step()` is `None`.\n    \"\"\"\n    import tensorflow.summary as summary\n    from tensorflow.compat.v1 import SummaryMetadata\n\n    summary_metadata = SummaryMetadata()\n    # Hard coding a plugin name. Please refer to go/tb-plugin-name-hardcode for\n    # the rationale.\n    summary_metadata.plugin_data.plugin_name = \"graph_keras_model\"\n    # version number = 1\n    summary_metadata.plugin_data.content = b\"1\"\n\n    try:\n        json_string = data.to_json()\n    except Exception as exc:\n        # An exception should not break the user's model code.\n        warnings.warn(f\"Model failed to serialize as JSON. Ignoring... {exc}\")\n        return False\n\n    with summary.experimental.summary_scope(\n        name, \"graph_keras_model\", [data, step]\n    ) as (tag, _):\n        return summary.write(\n            tag=tag, tensor=json_string, step=step, metadata=summary_metadata\n        )\n"
  },
  {
    "path": "keras/src/callbacks/tensorboard_test.py",
    "content": "import collections\nimport os\nimport random\nimport sys\n\nimport numpy as np\nimport pytest\nimport tensorflow.summary as summary\nfrom tensorflow.compat.v1 import SummaryMetadata\nfrom tensorflow.core.util import event_pb2\nfrom tensorflow.python.lib.io import tf_record\n\nfrom keras.src import backend\nfrom keras.src import callbacks\nfrom keras.src import layers\nfrom keras.src import losses\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src import testing\nfrom keras.src.optimizers import schedules\n\n# Note: this file and tensorboard in general has a dependency on tensorflow\n\n# A summary that was emitted during a test. Fields:\n#   logdir: str. The logdir of the FileWriter to which the summary was\n#     written.\n#   tag: str. The name of the summary.\n_ObservedSummary = collections.namedtuple(\"_ObservedSummary\", (\"logdir\", \"tag\"))\n\n\nclass _SummaryIterator:\n    \"\"\"Yields `Event` protocol buffers from a given path.\"\"\"\n\n    def __init__(self, path):\n        self._tf_record_iterator = tf_record.tf_record_iterator(path)\n\n    def __iter__(self):\n        return self\n\n    def __next__(self):\n        r = next(self._tf_record_iterator)\n        return event_pb2.Event.FromString(r)\n\n    next = __next__\n\n\nclass _SummaryFile:\n    \"\"\"A record of summary tags and the files to which they were written.\n\n    Fields `scalars`, `images`, `histograms`, and `tensors` are sets\n    containing `_ObservedSummary` values.\n    \"\"\"\n\n    def __init__(self):\n        self.scalars = set()\n        self.images = set()\n        self.histograms = set()\n        self.tensors = set()\n        self.graph_defs = []\n        self.convert_from_v2_summary_proto = False\n\n\ndef list_summaries(logdir):\n    \"\"\"Read all summaries under the logdir into a `_SummaryFile`.\n\n    Args:\n      logdir: A path to a directory that contains zero or more event\n        files, either as 
direct children or in transitive subdirectories.\n        Summaries in these events must only contain old-style scalars,\n        images, and histograms. Non-summary events, like `graph_def`s, are\n        ignored.\n\n    Returns:\n      A `_SummaryFile` object reflecting all summaries written to any\n      event files in the logdir or any of its descendant directories.\n\n    Raises:\n      ValueError: If an event file contains a summary of an unexpected kind.\n    \"\"\"\n    result = _SummaryFile()\n    for dirpath, _, filenames in os.walk(logdir):\n        for filename in filenames:\n            if not filename.startswith(\"events.out.\"):\n                continue\n            path = os.path.join(dirpath, filename)\n            for event in _SummaryIterator(path):\n                if event.graph_def:\n                    result.graph_defs.append(event.graph_def)\n                if not event.summary:  # (e.g., it's a `graph_def` event)\n                    continue\n                for value in event.summary.value:\n                    tag = value.tag\n                    # Case on the `value` rather than the summary metadata\n                    # because the Keras callback uses `summary_ops_v2` to emit\n                    # old-style summaries. 
See b/124535134.\n                    kind = value.WhichOneof(\"value\")\n                    container = {\n                        \"simple_value\": result.scalars,\n                        \"image\": result.images,\n                        \"histo\": result.histograms,\n                        \"tensor\": result.tensors,\n                    }.get(kind)\n                    if container is None:\n                        raise ValueError(\n                            \"Unexpected summary kind %r in event file %s:\\n%r\"\n                            % (kind, path, event)\n                        )\n                    elif kind == \"tensor\" and tag != \"keras\":\n                        # Convert the tf2 summary proto to old style for type\n                        # checking.\n                        plugin_name = value.metadata.plugin_data.plugin_name\n                        container = {\n                            \"images\": result.images,\n                            \"histograms\": result.histograms,\n                            \"scalars\": result.scalars,\n                        }.get(plugin_name)\n                        if container is not None:\n                            result.convert_from_v2_summary_proto = True\n                        else:\n                            container = result.tensors\n                    container.add(_ObservedSummary(logdir=dirpath, tag=tag))\n    return result\n\n\nclass TestTensorBoardV2(testing.TestCase):\n    def _get_log_dirs(self):\n        logdir = os.path.join(\n            self.get_temp_dir(), str(random.randint(1, int(1e7))), \"tb\"\n        )\n        train_dir = os.path.join(logdir, \"train\")\n        validation_dir = os.path.join(logdir, \"validation\")\n        return logdir, train_dir, validation_dir\n\n    def _get_model(self, compile_model=True):\n        model = models.Sequential(\n            [\n                layers.Input((10, 10, 1)),\n                layers.Flatten(),\n                
layers.Dense(1),\n            ]\n        )\n        if compile_model:\n            model.compile(\"sgd\", \"mse\")\n        return model\n\n    @pytest.mark.requires_trainable_backend\n    def test_TensorBoard_basic(self):\n        model = self._get_model()\n        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))\n        logdir, train_dir, validation_dir = self._get_log_dirs()\n        tb_cbk = callbacks.TensorBoard(logdir)\n\n        model.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=2,\n            validation_data=(x, y),\n            callbacks=[tb_cbk],\n        )\n\n        summary_file = list_summaries(logdir)\n        self.assertEqual(\n            summary_file.scalars,\n            {\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_loss\"),\n                _ObservedSummary(logdir=validation_dir, tag=\"epoch_loss\"),\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_learning_rate\"),\n                _ObservedSummary(\n                    logdir=validation_dir,\n                    tag=\"evaluation_loss_vs_iterations\",\n                ),\n            },\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_TensorBoard_across_invocations(self):\n        \"\"\"Regression test for summary writer resource use-after-free.\"\"\"\n        model = self._get_model()\n        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))\n        logdir, train_dir, validation_dir = self._get_log_dirs()\n        tb_cbk = callbacks.TensorBoard(logdir)\n\n        for _ in (1, 2):\n            model.fit(\n                x,\n                y,\n                batch_size=2,\n                epochs=2,\n                validation_data=(x, y),\n                callbacks=[tb_cbk],\n            )\n\n        summary_file = list_summaries(logdir)\n        self.assertEqual(\n            summary_file.scalars,\n            {\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_loss\"),\n        
        _ObservedSummary(logdir=validation_dir, tag=\"epoch_loss\"),\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_learning_rate\"),\n                _ObservedSummary(\n                    logdir=validation_dir,\n                    tag=\"evaluation_loss_vs_iterations\",\n                ),\n            },\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_TensorBoard_no_spurious_event_files(self):\n        model = self._get_model()\n        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))\n        logdir, train_dir, _ = self._get_log_dirs()\n        tb_cbk = callbacks.TensorBoard(logdir)\n        model.fit(x, y, batch_size=2, epochs=2, callbacks=[tb_cbk])\n\n        events_file_run_basenames = set()\n        for dirpath, _, filenames in os.walk(train_dir):\n            if any(fn.startswith(\"events.out.\") for fn in filenames):\n                events_file_run_basenames.add(os.path.basename(dirpath))\n        self.assertEqual(events_file_run_basenames, {\"train\"})\n\n    @pytest.mark.requires_trainable_backend\n    def test_TensorBoard_batch_metrics(self):\n        model = self._get_model()\n        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))\n        logdir, train_dir, validation_dir = self._get_log_dirs()\n        tb_cbk = callbacks.TensorBoard(logdir, update_freq=1)\n\n        model.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=2,\n            validation_data=(x, y),\n            callbacks=[tb_cbk],\n        )\n\n        summary_file = list_summaries(logdir)\n        self.assertEqual(\n            summary_file.scalars,\n            {\n                _ObservedSummary(logdir=train_dir, tag=\"batch_loss\"),\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_loss\"),\n                _ObservedSummary(logdir=validation_dir, tag=\"epoch_loss\"),\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_learning_rate\"),\n                _ObservedSummary(\n         
           logdir=validation_dir,\n                    tag=\"evaluation_loss_vs_iterations\",\n                ),\n            },\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_TensorBoard_learning_rate_schedules(self):\n        model = self._get_model(compile_model=False)\n        opt = optimizers.SGD(schedules.CosineDecay(0.01, 1))\n        model.compile(opt, \"mse\")\n        logdir, train_dir, _ = self._get_log_dirs()\n        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))\n\n        model.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=2,\n            callbacks=[callbacks.TensorBoard(logdir)],\n        )\n\n        summary_file = list_summaries(logdir)\n        self.assertEqual(\n            summary_file.scalars,\n            {\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_loss\"),\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_learning_rate\"),\n            },\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_TensorBoard_global_step(self):\n        model = self._get_model(compile_model=False)\n        opt = optimizers.SGD(schedules.CosineDecay(0.01, 1))\n        model.compile(opt, \"mse\")\n        logdir, train_dir, _ = self._get_log_dirs()\n        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))\n\n        model.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=2,\n            verbose=0,\n            callbacks=[\n                callbacks.TensorBoard(\n                    logdir,\n                    update_freq=1,\n                    profile_batch=0,\n                    write_steps_per_second=True,\n                )\n            ],\n        )\n\n        summary_file = list_summaries(logdir)\n        self.assertEqual(\n            summary_file.scalars,\n            {\n                _ObservedSummary(logdir=train_dir, tag=\"batch_loss\"),\n                _ObservedSummary(logdir=train_dir, 
tag=\"epoch_loss\"),\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_learning_rate\"),\n                _ObservedSummary(\n                    logdir=train_dir, tag=\"epoch_steps_per_second\"\n                ),\n                _ObservedSummary(\n                    logdir=train_dir, tag=\"batch_steps_per_second\"\n                ),\n            },\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_TensorBoard_weight_histograms(self):\n        model = self._get_model()\n        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))\n        logdir, train_dir, validation_dir = self._get_log_dirs()\n        tb_cbk = callbacks.TensorBoard(logdir, histogram_freq=1)\n\n        model.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=2,\n            validation_data=(x, y),\n            callbacks=[tb_cbk],\n        )\n        summary_file = list_summaries(logdir)\n\n        self.assertEqual(\n            summary_file.scalars,\n            {\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_loss\"),\n                _ObservedSummary(logdir=validation_dir, tag=\"epoch_loss\"),\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_learning_rate\"),\n                _ObservedSummary(\n                    logdir=validation_dir,\n                    tag=\"evaluation_loss_vs_iterations\",\n                ),\n            },\n        )\n        self.assertEqual(\n            self._strip_layer_names(summary_file.histograms, \"sequential\"),\n            {_ObservedSummary(logdir=train_dir, tag=\"histogram\")},\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_TensorBoard_weight_images(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (10, 10, 1)\n            x_shape = (10, 10, 10, 1)\n        else:\n            input_shape = (1, 10, 10)\n            x_shape = (10, 1, 10, 10)\n        x, y = np.ones(x_shape), 
np.ones((10, 1))\n        logdir, train_dir, validation_dir = self._get_log_dirs()\n        tb_cbk = callbacks.TensorBoard(\n            logdir, histogram_freq=1, write_images=True\n        )\n        model_type = \"sequential\"\n        model = models.Sequential(\n            [\n                layers.Input(input_shape),\n                layers.Conv2D(3, 10),\n                layers.GlobalAveragePooling2D(),\n                layers.Dense(1),\n            ]\n        )\n        model.compile(\"sgd\", \"mse\")\n        model.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=2,\n            validation_data=(x, y),\n            callbacks=[tb_cbk],\n        )\n        summary_file = list_summaries(logdir)\n\n        self.assertEqual(\n            summary_file.scalars,\n            {\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_loss\"),\n                _ObservedSummary(logdir=validation_dir, tag=\"epoch_loss\"),\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_learning_rate\"),\n                _ObservedSummary(\n                    logdir=validation_dir,\n                    tag=\"evaluation_loss_vs_iterations\",\n                ),\n            },\n        )\n        self.assertEqual(\n            self._strip_layer_names(summary_file.histograms, model_type),\n            {\n                _ObservedSummary(logdir=train_dir, tag=\"histogram\"),\n            },\n        )\n        expected_image_summaries = {\n            _ObservedSummary(logdir=train_dir, tag=\"image\"),\n        }\n        self.assertEqual(\n            self._strip_layer_names(summary_file.images, model_type),\n            expected_image_summaries,\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_TensorBoard_projector_callback(self):\n        model = models.Sequential(\n            [\n                layers.Input((10,)),\n                layers.Embedding(10, 10, name=\"test_embedding\"),\n                
layers.Dense(1, activation=\"sigmoid\"),\n            ]\n        )\n        model.compile(\n            optimizer=\"adam\", loss=losses.BinaryCrossentropy(from_logits=True)\n        )\n        x, y = np.ones((10, 10)), np.ones((10, 10))\n        logdir, _, _ = self._get_log_dirs()\n        tb_cbk = callbacks.TensorBoard(\n            logdir,\n            embeddings_freq=1,\n            embeddings_metadata={\"test_embedding\": \"metadata.tsv\"},\n        )\n\n        model.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=2,\n            validation_data=(x, y),\n            callbacks=[tb_cbk],\n        )\n\n        with open(os.path.join(logdir, \"projector_config.pbtxt\")) as f:\n            self.assertEqual(\n                f.readlines(),\n                [\n                    \"embeddings {\\n\",\n                    \"  tensor_name: \"\n                    '\"layer_with_weights-0/embeddings/.ATTRIBUTES/'\n                    'VARIABLE_VALUE\"\\n',\n                    '  metadata_path: \"metadata.tsv\"\\n',\n                    \"}\\n\",\n                ],\n            )\n\n    @pytest.mark.requires_trainable_backend\n    def test_custom_summary(self):\n        def scalar_v2_mock(name, data, step=None):\n            \"\"\"A reimplementation of the scalar plugin to avoid circular\n            deps.\"\"\"\n            metadata = SummaryMetadata()\n            # Should match value in tensorboard/plugins/scalar/metadata.py.\n            metadata.plugin_data.plugin_name = \"scalars\"\n            with summary.experimental.summary_scope(\n                name, \"scalar_summary\", values=[data, step]\n            ) as (tag, _):\n                tensor = backend.convert_to_tensor(data, dtype=\"float32\")\n                if backend.backend() == \"torch\":\n                    # TODO: Use device scope after the API is added.\n                    if tensor.is_cuda:\n                        tensor = tensor.cpu()\n                
summary.write(\n                    tag=tag,\n                    tensor=tensor,\n                    step=step,\n                    metadata=metadata,\n                )\n\n        class LayerWithSummary(layers.Layer):\n            def call(self, x):\n                scalar_v2_mock(\"custom_summary\", ops.sum(x))\n                return x\n\n        model = models.Sequential(\n            [\n                layers.Input((5,)),\n                LayerWithSummary(),\n            ]\n        )\n\n        # summary ops not compatible with XLA\n        model.compile(\"sgd\", \"mse\", jit_compile=False)\n        logdir, train_dir, validation_dir = self._get_log_dirs()\n        tb_cbk = callbacks.TensorBoard(logdir, update_freq=1)\n        x, y = np.ones((10, 5)), np.ones((10, 5))\n        model.fit(\n            x, y, batch_size=2, validation_data=(x, y), callbacks=[tb_cbk]\n        )\n        summary_file = list_summaries(logdir)\n        # TODO: tensorflow will tag with model/layer_with_summary/custom_summary\n        # Jax will only use custom_summary tag\n        self.assertEqual(\n            self._strip_to_only_final_name(summary_file.scalars),\n            {\n                _ObservedSummary(logdir=train_dir, tag=\"batch_loss\"),\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_loss\"),\n                _ObservedSummary(logdir=train_dir, tag=\"epoch_learning_rate\"),\n                _ObservedSummary(logdir=validation_dir, tag=\"epoch_loss\"),\n                _ObservedSummary(\n                    logdir=validation_dir,\n                    tag=\"evaluation_loss_vs_iterations\",\n                ),\n                _ObservedSummary(\n                    logdir=train_dir,\n                    tag=\"custom_summary\",\n                ),\n                _ObservedSummary(\n                    logdir=validation_dir,\n                    tag=\"custom_summary\",\n                ),\n            },\n        )\n        # self.assertEqual(\n        #     
summary_file.scalars,\n        #     {\n        #         _ObservedSummary(logdir=train_dir, tag=\"batch_loss\"),\n        #         _ObservedSummary(logdir=train_dir, tag=\"epoch_loss\"),\n        #         _ObservedSummary(logdir=validation_dir,\n        #               tag=\"epoch_loss\"),\n        #         _ObservedSummary(\n        #             logdir=validation_dir,\n        #             tag=\"evaluation_loss_vs_iterations\",\n        #         ),\n        #         _ObservedSummary(\n        #             logdir=train_dir,\n        #             tag=\"model/layer_with_summary/custom_summary\",\n        #         ),\n        #         _ObservedSummary(\n        #             logdir=validation_dir,\n        #             tag=\"model/layer_with_summary/custom_summary\",\n        #         ),\n        #     },\n        # )\n\n    def _strip_to_only_final_name(self, summaries):\n        \"\"\"Removes all leading names in a summary tag.\n\n        Args:\n            summaries: A `set` of `_ObservedSummary` values.\n\n        Returns:\n            A new `set` of `_ObservedSummary` values stripped of all\n            names except for the terminal one.\n        \"\"\"\n        result = set()\n        for s in summaries:\n            if \"/\" not in s.tag:\n                result.add(s)\n            else:\n                new_tag = s.tag.split(\"/\")[-1]\n                result.add(s._replace(tag=new_tag))\n        return result\n\n    def _strip_layer_names(self, summaries, model_type):\n        \"\"\"Deduplicate summary names modulo layer prefix.\n\n        This removes all leading slash-components of each tag name: for\n        instance, \"foo/bar/baz\" becomes \"baz\".\n\n        Args:\n            summaries: A `set` of `_ObservedSummary` values.\n            model_type: The model type currently being tested.\n\n        Returns:\n            A new `set` of `_ObservedSummary` values with layer prefixes\n            removed.\n        \"\"\"\n        result = set()\n  
      for s in summaries:\n            if \"/\" not in s.tag:\n                raise ValueError(f\"tag has no layer name: {s.tag!r}\")\n            new_tag = s.tag.split(\"/\")[-1]\n            result.add(s._replace(tag=new_tag))\n        return result\n\n    def _strip_variable_names(self, summaries):\n        \"\"\"Remove `variable_n` from summary tags.\n\n        `variable_n` tag names are added with random numbers. Removing them\n        ensures deterministic tag names.\n\n        Args:\n            summaries: A `set` of `_ObservedSummary` values.\n\n        Returns:\n            A new `set` of `_ObservedSummary` values with variable prefixes\n            removed.\n        \"\"\"\n        result = set()\n        for s in summaries:\n            if \"/\" not in s.tag:\n                result.add(s)\n            else:\n                split_tag = s.tag.split(\"/\")\n                if \"variable\" in split_tag[0]:\n                    result.add(s._replace(tag=split_tag[-1]))\n                else:\n                    result.add(s)\n        return result\n\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"Torch backend requires blocking numpy conversion.\",\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_TensorBoard_non_blocking(self):\n        logdir, _, _ = self._get_log_dirs()\n        model = models.Sequential([layers.Dense(1)])\n        model.optimizer = optimizers.Adam()\n        tb = callbacks.TensorBoard(logdir)\n        cb_list = callbacks.CallbackList(\n            [tb], model=model, epochs=1, steps=100, verbose=0\n        )\n        tensor = ops.convert_to_tensor(1.0)\n\n        def mock_numpy():\n            raise RuntimeError(\n                \"If this error is seen, TensorBoard is causing a blocking \"\n                \"NumPy conversion.\"\n            )\n\n        tensor.numpy = mock_numpy\n\n        logs = {\"metric\": tensor}\n\n        cb_list.on_train_begin(logs)\n        
cb_list.on_epoch_begin(0, logs)\n        cb_list.on_train_batch_begin(0, logs)\n        cb_list.on_train_batch_end(0, logs)\n        cb_list.on_epoch_end(0, logs)\n        cb_list.on_train_end(logs)\n\n        cb_list.on_test_begin(logs)\n        cb_list.on_test_batch_begin(0, logs)\n        cb_list.on_test_batch_end(0, logs)\n        cb_list.on_test_end(logs)\n\n        cb_list.on_predict_begin(logs)\n        cb_list.on_predict_batch_begin(logs)\n        cb_list.on_predict_batch_end(logs)\n        cb_list.on_predict_end(logs)\n\n    def _count_xplane_file(self, logdir):\n        profile_dir = os.path.join(logdir, \"plugins\", \"profile\")\n        count = 0\n        for dirpath, dirnames, filenames in os.walk(profile_dir):\n            del dirpath  # unused\n            del dirnames  # unused\n            for filename in filenames:\n                if filename.endswith(\".xplane.pb\"):\n                    count += 1\n        return count\n\n    def fitModelAndAssertKerasModelWritten(self, model):\n        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))\n        logdir, train_dir, validation_dir = self._get_log_dirs()\n        tb_cbk = callbacks.TensorBoard(\n            logdir, write_graph=True, profile_batch=0\n        )\n        model.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=3,\n            validation_data=(x, y),\n            callbacks=[tb_cbk],\n        )\n        summary_file = list_summaries(logdir)\n        self.assertEqual(\n            summary_file.tensors,\n            {\n                _ObservedSummary(logdir=train_dir, tag=\"keras\"),\n            },\n        )\n        if not model.run_eagerly:\n            # There should be one train graph\n            self.assertLen(summary_file.graph_defs, 1)\n            for graph_def in summary_file.graph_defs:\n                graph_def_str = str(graph_def)\n\n                # All the model layers should appear in the graphs\n                for layer in 
model.layers:\n                    if \"input\" not in layer.name:\n                        self.assertIn(layer.name, graph_def_str)\n\n    def test_TensorBoard_write_sequential_model_no_input_shape(self):\n        # TODO: Requires to_json implementation in trainer\n        # model = models.Sequential(\n        #     [\n        #         Conv2D(8, (3, 3)),\n        #         Flatten(),\n        #         Dense(1),\n        #     ]\n        # )\n        # model.compile(\"sgd\", \"mse\")\n        # self.fitModelAndAssertKerasModelWritten(model)\n        pass\n\n    def test_TensorBoard_write_sequential_model_with_input_shape(self):\n        # TODO: Requires to_json implementation in trainer\n        # model = models.Sequential(\n        #     [\n        #         Input(input_shape=(10, 10, 1)),\n        #         Conv2D(8, (3, 3)),\n        #         Flatten(),\n        #         Dense(1),\n        #     ]\n        # )\n        # model.compile(\"sgd\", \"mse\")\n        # self.fitModelAndAssertKerasModelWritten(model)\n        pass\n\n    def test_TensorBoard_write_model(self):\n        # TODO: Requires to_json implementation in trainer\n        # See https://github.com/keras-team/keras/blob/ \\\n        # a8d4a7f1ffc9de3c5932828a107e4e95e8803fb4/ \\\n        # keras/engine/training.py#L3313\n        # inputs = Input([10, 10, 1])\n        # x = Conv2D(8, (3, 3), activation=\"relu\")(inputs)\n        # x = Flatten()(x)\n        # x = Dense(1)(x)\n        # model = models.Model(inputs=inputs, outputs=[x])\n        # model.compile(\"sgd\", \"mse\")\n        # breakpoint()\n        # self.fitModelAndAssertKerasModelWritten(model)\n        pass\n\n    @pytest.mark.skipif(\n        backend.backend() not in (\"jax\", \"tensorflow\"),\n        reason=\"The profiling test can only run with TF and JAX backends.\",\n    )\n    def test_TensorBoard_auto_trace(self):\n        logdir, train_dir, validation_dir = self._get_log_dirs()\n        model = models.Sequential(\n            
[\n                layers.Input((10, 10, 1)),\n                layers.Flatten(),\n                layers.Dense(1),\n            ]\n        )\n        x, y = np.ones((10, 10, 10, 1)), np.ones((10, 1))\n        if backend.backend() == \"jax\" and sys.version_info[1] < 12:\n            with pytest.warns(match=\"backend requires python >= 3.12\"):\n                callbacks.TensorBoard(\n                    logdir, histogram_freq=1, profile_batch=1, write_graph=False\n                )\n            self.skipTest(\n                \"Profiling with JAX and python < 3.12 \"\n                \"raises segmentation fault.\"\n            )\n\n        tb_cbk = callbacks.TensorBoard(\n            logdir, histogram_freq=1, profile_batch=1, write_graph=False\n        )\n        model.compile(\"sgd\", \"mse\")\n        model.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=2,\n            validation_data=(x, y),\n            callbacks=[tb_cbk],\n        )\n        summary_file = list_summaries(logdir)\n\n        self.assertEqual(\n            summary_file.tensors,\n            {\n                _ObservedSummary(logdir=train_dir, tag=\"batch_1\"),\n            },\n        )\n        self.assertEqual(1, self._count_xplane_file(logdir=train_dir))\n"
  },
  {
    "path": "keras/src/callbacks/terminate_on_nan.py",
    "content": "import numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.callbacks.callback import Callback\nfrom keras.src.utils import io_utils\n\n\n@keras_export(\"keras.callbacks.TerminateOnNaN\")\nclass TerminateOnNaN(Callback):\n    \"\"\"Callback that terminates training when a NaN loss is encountered.\n\n    This callback monitors the loss value during training\n    and terminates training when a NaN or Inf loss is detected.\n    By default, training is stopped gracefully\n    by setting `model.stop_training = True`, which triggers all callback cleanup\n    methods including `on_train_end()`.\n\n    Alternatively, you can use `raise_error=True` to immediately raise a\n    RuntimeError when NaN/Inf is detected. This raise_error termination\n    prevents `on_train_end()` from being called on other callbacks, which\n    is useful for preserving backup states or preventing unintended cleanup\n    when training fails.\n\n    Args:\n        raise_error: Boolean, default False. If False, uses graceful stop via\n            `model.stop_training = True`. 
If True, immediately raises\n            RuntimeError on NaN/Inf loss, bypassing callback cleanup methods.\n\n    Example:\n\n    ```\n    # Graceful termination (default)\n    callback = keras.callbacks.TerminateOnNaN()\n    model.fit(x, y, callbacks=[callback])\n\n    # raise_error termination (strict failure)\n    callback = keras.callbacks.TerminateOnNaN(raise_error=True)\n    model.fit(x, y, callbacks=[callback])\n    ```\n    \"\"\"\n\n    def __init__(self, raise_error: bool = False):\n        super().__init__()\n        self.raise_error = raise_error\n\n    def on_batch_end(self, batch, logs=None):\n        \"\"\"Check for NaN/Inf loss at the end of each batch.\n\n        Args:\n            batch: Integer, index of batch within the current epoch.\n            logs: Dict, contains the return value of `model.train_step()`.\n\n        Raises:\n            RuntimeError: If loss is NaN/Inf and raise_error=True.\n        \"\"\"\n        logs = logs or {}\n        loss = logs.get(\"loss\")\n        if loss is not None:\n            if np.isnan(loss) or np.isinf(loss):\n                if self.raise_error:\n                    raise RuntimeError(\n                        f\"NaN or Inf loss encountered at batch {batch}. \"\n                        f\"Loss value: {loss}. Terminating training immediately.\"\n                    )\n                else:\n                    io_utils.print_msg(\n                        f\"Batch {batch}: Invalid loss, terminating training\"\n                    )\n                    self.model.stop_training = True\n"
  },
  {
    "path": "keras/src/callbacks/terminate_on_nan_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import callbacks\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src.callbacks import BackupAndRestore\nfrom keras.src.callbacks import TerminateOnNaN\nfrom keras.src.models import Sequential\nfrom keras.src.utils import numerical_utils\n\n\n@pytest.mark.requires_trainable_backend\nclass TerminateOnNaNTest(testing.TestCase):\n    \"\"\"Test suite for TerminateOnNaN callback.\"\"\"\n\n    def test_TerminateOnNaN(self):\n        TRAIN_SAMPLES = 10\n        TEST_SAMPLES = 10\n        INPUT_DIM = 3\n        NUM_CLASSES = 2\n        BATCH_SIZE = 4\n\n        np.random.seed(1337)\n        x_train = np.random.random((TRAIN_SAMPLES, INPUT_DIM))\n        y_train = np.random.choice(np.arange(NUM_CLASSES), size=TRAIN_SAMPLES)\n        x_test = np.random.random((TEST_SAMPLES, INPUT_DIM))\n        y_test = np.random.choice(np.arange(NUM_CLASSES), size=TEST_SAMPLES)\n\n        y_test = numerical_utils.to_categorical(y_test)\n        y_train = numerical_utils.to_categorical(y_train)\n        model = Sequential()\n        initializer = initializers.Constant(value=1e5)\n        for _ in range(5):\n            model.add(\n                layers.Dense(\n                    2,\n                    activation=\"relu\",\n                    kernel_initializer=initializer,\n                )\n            )\n        model.add(layers.Dense(NUM_CLASSES))\n        model.compile(loss=\"mean_squared_error\", optimizer=\"sgd\")\n\n        history = model.fit(\n            x_train,\n            y_train,\n            batch_size=BATCH_SIZE,\n            validation_data=(x_test, y_test),\n            callbacks=[callbacks.TerminateOnNaN()],\n            epochs=20,\n        )\n        loss = history.history[\"loss\"]\n        self.assertEqual(len(loss), 1)\n        
self.assertTrue(np.isnan(loss[0]) or np.isinf(loss[0]))\n\n    def test_terminate_on_nan_graceful_stop(self):\n        \"\"\"Test that TerminateOnNaN (default) gracefully stops training.\"\"\"\n        model = models.Sequential([layers.Dense(1, input_shape=(1,))])\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n\n        x = np.array([[1.0], [2.0]])\n        y = np.array([[np.inf], [np.inf]])\n\n        callback = TerminateOnNaN(raise_error=False)\n\n        # Training should complete without raising RuntimeError\n        history = model.fit(\n            x, y, epochs=2, batch_size=1, callbacks=[callback], verbose=0\n        )\n\n        # Training should stop early\n        self.assertLess(len(history.history[\"loss\"]), 4)\n\n    def test_terminate_on_nan_raise_error_raises_error(self):\n        \"\"\"Test that TerminateOnNaN(raise_error=True) raises\n        RuntimeError on NaN loss.\n        \"\"\"\n        model = models.Sequential([layers.Dense(1, input_shape=(1,))])\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n\n        x = np.array([[1.0], [2.0]])\n        y = np.array([[np.inf], [np.inf]])\n\n        callback = TerminateOnNaN(raise_error=True)\n\n        # Training should raise RuntimeError\n        with self.assertRaisesRegex(\n            RuntimeError,\n            \"NaN or Inf loss encountered\",\n        ):\n            model.fit(\n                x, y, epochs=1, batch_size=1, callbacks=[callback], verbose=0\n            )\n\n    def test_raise_error_terminate_does_not_trigger_on_train_end(self):\n        \"\"\"Test that on_train_end is NOT called when\n        TerminateOnNaN(raise_error=True) raises.\n        \"\"\"\n\n        class TrackingCallback(callbacks.Callback):\n            def __init__(self):\n                super().__init__()\n                self.train_end_called = False\n\n            def on_train_end(self, logs=None):\n                self.train_end_called = True\n\n        model = models.Sequential([layers.Dense(1, 
input_shape=(1,))])\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n\n        x = np.array([[1.0]])\n        y = np.array([[np.inf]])\n\n        tracking_callback = TrackingCallback()\n        raise_error_terminate_callback = TerminateOnNaN(raise_error=True)\n\n        # Should raise RuntimeError\n        with self.assertRaises(RuntimeError):\n            model.fit(\n                x,\n                y,\n                epochs=1,\n                callbacks=[tracking_callback, raise_error_terminate_callback],\n                verbose=0,\n            )\n\n        # on_train_end should NOT have been called\n        self.assertFalse(tracking_callback.train_end_called)\n\n    def test_raise_error_terminate_preserves_backup(self):\n        \"\"\"Ensure BackupAndRestore directory is preserved when\n        TerminateOnNaN(raise_error=True) triggers.\n        \"\"\"\n        tmpdir = self.get_temp_dir()\n        backup_dir = os.path.join(tmpdir, \"backups\")\n        os.makedirs(backup_dir, exist_ok=True)\n\n        fake_file = os.path.join(backup_dir, \"checkpoint.txt\")\n        with open(fake_file, \"w\") as f:\n            f.write(\"dummy checkpoint\")\n\n        model = models.Sequential([layers.Dense(1, input_shape=(1,))])\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n\n        x_nan = np.array([[1.0]])\n        y_nan = np.array([[np.inf]])\n\n        raise_error_terminate_callback = TerminateOnNaN(raise_error=True)\n        backup_callback = BackupAndRestore(backup_dir=backup_dir)\n\n        # Monkeypatch BackupAndRestore to prevent cleanup on train_end\n        backup_callback.on_train_end = lambda logs=None: None\n\n        # Training should raise RuntimeError\n        with self.assertRaises(RuntimeError):\n            model.fit(\n                x_nan,\n                y_nan,\n                epochs=1,\n                callbacks=[backup_callback, raise_error_terminate_callback],\n                verbose=0,\n            )\n\n        # Verify 
backup directory still exists and file inside is untouched\n        self.assertTrue(\n            os.path.exists(backup_dir),\n            f\"Backup dir deleted: {backup_dir}\",\n        )\n        self.assertTrue(\n            os.path.exists(fake_file),\n            \"Backup file missing unexpectedly.\",\n        )\n\n    @parameterized.named_parameters(\n        (\"raise_error_false\", False),\n        (\"raise_error_true\", True),\n    )\n    def test_normal_training_does_not_raise(self, raise_error):\n        \"\"\"Test that TerminateOnNaN does not raise on normal training.\"\"\"\n        model = models.Sequential([layers.Dense(1, input_shape=(1,))])\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n\n        x = np.array([[1.0], [2.0]])\n        y = np.array([[1.0], [2.0]])\n\n        callback = TerminateOnNaN(raise_error=raise_error)\n\n        # Should complete without raising RuntimeError\n        history = model.fit(x, y, epochs=2, callbacks=[callback], verbose=0)\n\n        # Should have completed 2 epochs\n        self.assertEqual(len(history.history[\"loss\"]), 2)\n\n    def test_raise_error_terminate_stops_on_later_batch(self):\n        \"\"\"Ensure TerminateOnNaN(raise_error=True) stops training\n        if NaN appears in later batch.\n        \"\"\"\n        model = models.Sequential([layers.Dense(1, input_shape=(1,))])\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n\n        # Batch 1: normal loss, Batch 2: NaN loss\n        x = np.array([[1.0], [2.0]])\n        y = np.array([[1.0], [np.inf]])  # NaN/Inf appears only in 2nd batch\n\n        callback = TerminateOnNaN(raise_error=True)\n\n        with self.assertRaises(RuntimeError) as exc:\n            model.fit(\n                x, y, epochs=1, batch_size=1, callbacks=[callback], verbose=0\n            )\n\n        self.assertTrue(any(f\"batch {i}\" in str(exc.exception) for i in [0, 1]))\n"
  },
  {
    "path": "keras/src/constraints/__init__.py",
    "content": "import inspect\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.constraints.constraints import Constraint\nfrom keras.src.constraints.constraints import MaxNorm\nfrom keras.src.constraints.constraints import MinMaxNorm\nfrom keras.src.constraints.constraints import NonNeg\nfrom keras.src.constraints.constraints import UnitNorm\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils.naming import to_snake_case\n\nALL_OBJECTS = {\n    Constraint,\n    MaxNorm,\n    MinMaxNorm,\n    NonNeg,\n    UnitNorm,\n}\n\nALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}\nALL_OBJECTS_DICT.update(\n    {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS}\n)\n\n\n@keras_export(\"keras.constraints.serialize\")\ndef serialize(constraint):\n    return serialization_lib.serialize_keras_object(constraint)\n\n\n@keras_export(\"keras.constraints.deserialize\")\ndef deserialize(config, custom_objects=None):\n    \"\"\"Return a Keras constraint object via its config.\"\"\"\n    return serialization_lib.deserialize_keras_object(\n        config,\n        module_objects=ALL_OBJECTS_DICT,\n        custom_objects=custom_objects,\n    )\n\n\n@keras_export(\"keras.constraints.get\")\ndef get(identifier):\n    \"\"\"Retrieve a Keras constraint object via an identifier.\"\"\"\n    if identifier is None:\n        return None\n    if isinstance(identifier, dict):\n        obj = deserialize(identifier)\n    elif isinstance(identifier, str):\n        obj = ALL_OBJECTS_DICT.get(identifier, None)\n    else:\n        obj = identifier\n\n    if callable(obj):\n        if inspect.isclass(obj):\n            obj = obj()\n        return obj\n    else:\n        raise ValueError(\n            f\"Could not interpret constraint identifier: {identifier}\"\n        )\n"
  },
  {
    "path": "keras/src/constraints/constraints.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\n\n\n@keras_export(\"keras.constraints.Constraint\")\nclass Constraint:\n    \"\"\"Base class for weight constraints.\n\n    A `Constraint` instance works like a stateless function.\n    Users who subclass this\n    class should override the `__call__()` method, which takes a single\n    weight parameter and return a projected version of that parameter\n    (e.g. normalized or clipped). Constraints can be used with various Keras\n    layers via the `kernel_constraint` or `bias_constraint` arguments.\n\n    Here's a simple example of a non-negative weight constraint:\n\n    >>> class NonNegative(keras.constraints.Constraint):\n    ...\n    ...  def __call__(self, w):\n    ...    return w * ops.cast(ops.greater_equal(w, 0.), dtype=w.dtype)\n\n    >>> weight = ops.convert_to_tensor((-1.0, 1.0))\n    >>> NonNegative()(weight)\n    [0.,  1.]\n\n    Usage in a layer:\n\n    >>> keras.layers.Dense(4, kernel_constraint=NonNegative())\n    \"\"\"\n\n    def __call__(self, w):\n        \"\"\"Applies the constraint to the input weight variable.\n\n        By default, the inputs weight variable is not modified.\n        Users should override this method to implement their own projection\n        function.\n\n        Args:\n            w: Input weight variable.\n\n        Returns:\n            Projected variable (by default, returns unmodified inputs).\n        \"\"\"\n        return w\n\n    def get_config(self):\n        \"\"\"Returns a Python dict of the object config.\n\n        A constraint config is a Python dictionary (JSON-serializable) that can\n        be used to reinstantiate the same object.\n\n        Returns:\n            Python dict containing the configuration of the constraint object.\n        \"\"\"\n        return {}\n\n    @classmethod\n    def from_config(cls, config):\n        \"\"\"Instantiates a weight constraint from a configuration 
dictionary.\n\n        Example:\n\n        ```python\n        constraint = UnitNorm()\n        config = constraint.get_config()\n        constraint = UnitNorm.from_config(config)\n        ```\n\n        Args:\n            config: A Python dictionary, the output of `get_config()`.\n\n        Returns:\n            A `keras.constraints.Constraint` instance.\n        \"\"\"\n        return cls(**config)\n\n\n@keras_export([\"keras.constraints.MaxNorm\", \"keras.constraints.max_norm\"])\nclass MaxNorm(Constraint):\n    \"\"\"MaxNorm weight constraint.\n\n    Constrains the weights incident to each hidden unit\n    to have a norm less than or equal to a desired value.\n\n    Also available via the shortcut function `keras.constraints.max_norm`.\n\n    Args:\n        max_value: the maximum norm value for the incoming weights.\n        axis: integer, axis along which to calculate weight norms.\n            For instance, in a `Dense` layer the weight matrix\n            has shape `(input_dim, output_dim)`,\n            set `axis` to `0` to constrain each weight vector\n            of length `(input_dim,)`.\n            In a `Conv2D` layer with `data_format=\"channels_last\"`,\n            the weight tensor has shape\n            `(rows, cols, input_depth, output_depth)`,\n            set `axis` to `[0, 1, 2]`\n            to constrain the weights of each filter tensor of size\n            `(rows, cols, input_depth)`.\n\n    \"\"\"\n\n    def __init__(self, max_value=2, axis=0):\n        self.max_value = max_value\n        self.axis = axis\n\n    def __call__(self, w):\n        w = backend.convert_to_tensor(w)\n        norms = ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True))\n        desired = ops.clip(norms, 0, self.max_value)\n        return ops.cast(w, norms.dtype) * (\n            desired / (backend.epsilon() + norms)\n        )\n\n    def get_config(self):\n        return {\"max_value\": self.max_value, \"axis\": 
self.axis}\n\n\n@keras_export([\"keras.constraints.NonNeg\", \"keras.constraints.non_neg\"])\nclass NonNeg(Constraint):\n    \"\"\"Constrains the weights to be non-negative.\"\"\"\n\n    def __call__(self, w):\n        w = backend.convert_to_tensor(w)\n        return w * ops.cast(ops.greater_equal(w, 0.0), dtype=w.dtype)\n\n\n@keras_export([\"keras.constraints.UnitNorm\", \"keras.constraints.unit_norm\"])\nclass UnitNorm(Constraint):\n    \"\"\"Constrains the weights incident to each hidden unit to have unit norm.\n\n    Args:\n        axis: integer, axis along which to calculate weight norms.\n            For instance, in a `Dense` layer the weight matrix\n            has shape `(input_dim, output_dim)`,\n            set `axis` to `0` to constrain each weight vector\n            of length `(input_dim,)`.\n            In a `Conv2D` layer with `data_format=\"channels_last\"`,\n            the weight tensor has shape\n            `(rows, cols, input_depth, output_depth)`,\n            set `axis` to `[0, 1, 2]`\n            to constrain the weights of each filter tensor of size\n            `(rows, cols, input_depth)`.\n    \"\"\"\n\n    def __init__(self, axis=0):\n        self.axis = axis\n\n    def __call__(self, w):\n        w = backend.convert_to_tensor(w)\n        norms = ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True))\n        return ops.cast(w, norms.dtype) / (backend.epsilon() + norms)\n\n    def get_config(self):\n        return {\"axis\": self.axis}\n\n\n@keras_export(\n    [\"keras.constraints.MinMaxNorm\", \"keras.constraints.min_max_norm\"]\n)\nclass MinMaxNorm(Constraint):\n    \"\"\"MinMaxNorm weight constraint.\n\n    Constrains the weights incident to each hidden unit\n    to have the norm between a lower bound and an upper bound.\n\n    Args:\n        min_value: the minimum norm for the incoming weights.\n        max_value: the maximum norm for the incoming weights.\n        rate: rate for enforcing the constraint: weights will be\n            
rescaled to yield\n            `(1 - rate) * norm + rate * norm.clip(min_value, max_value)`.\n            Effectively, this means that rate=1.0 stands for strict\n            enforcement of the constraint, while rate<1.0 means that\n            weights will be rescaled at each step to slowly move\n            towards a value inside the desired interval.\n        axis: integer, axis along which to calculate weight norms.\n            For instance, in a `Dense` layer the weight matrix\n            has shape `(input_dim, output_dim)`,\n            set `axis` to `0` to constrain each weight vector\n            of length `(input_dim,)`.\n            In a `Conv2D` layer with `data_format=\"channels_last\"`,\n            the weight tensor has shape\n            `(rows, cols, input_depth, output_depth)`,\n            set `axis` to `[0, 1, 2]`\n            to constrain the weights of each filter tensor of size\n            `(rows, cols, input_depth)`.\n    \"\"\"\n\n    def __init__(self, min_value=0.0, max_value=1.0, rate=1.0, axis=0):\n        self.min_value = min_value\n        self.max_value = max_value\n        self.rate = rate\n        self.axis = axis\n\n    def __call__(self, w):\n        w = backend.convert_to_tensor(w)\n        norms = ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True))\n        desired = (\n            self.rate * ops.clip(norms, self.min_value, self.max_value)\n            + (1 - self.rate) * norms\n        )\n        return ops.cast(w, norms.dtype) * (\n            desired / (backend.epsilon() + norms)\n        )\n\n    def get_config(self):\n        return {\n            \"min_value\": self.min_value,\n            \"max_value\": self.max_value,\n            \"rate\": self.rate,\n            \"axis\": self.axis,\n        }\n"
  },
  {
    "path": "keras/src/constraints/constraints_test.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import testing\n\n\ndef get_example_array():\n    np.random.seed(3537)\n    example_array = np.random.random((100, 100)) * 100.0 - 50.0\n    example_array[0, 0] = 0.0  # Possible edge case\n    return example_array\n\n\nclass ConstraintsTest(testing.TestCase):\n    def test_max_norm(self):\n        constraint_fn = constraints.MaxNorm(2.0)\n        x = np.array([[0, 0, 0], [1.0, 0, 0], [3, 0, 0], [3, 3, 3]]).T\n        target = np.array(\n            [\n                [0, 0, 0],\n                [1.0, 0, 0],\n                [2.0, 0, 0],\n                [2.0 / np.sqrt(3), 2.0 / np.sqrt(3), 2.0 / np.sqrt(3)],\n            ]\n        ).T\n        output = constraint_fn(x)\n        self.assertAllClose(target, output)\n\n    def test_non_neg(self):\n        constraint_fn = constraints.NonNeg()\n        output = constraint_fn(get_example_array())\n        output = backend.convert_to_numpy(output)\n        self.assertTrue((np.min(output, axis=1) >= 0.0).all())\n\n    def test_unit_norm(self):\n        constraint_fn = constraints.UnitNorm()\n        output = constraint_fn(get_example_array())\n        output = backend.convert_to_numpy(output)\n        l2 = np.sqrt(np.sum(np.square(output), axis=0))\n        self.assertAllClose(l2, 1.0)\n\n    def test_min_max_norm(self):\n        constraint_fn = constraints.MinMaxNorm(min_value=0.2, max_value=0.5)\n        output = constraint_fn(get_example_array())\n        output = backend.convert_to_numpy(output)\n        l2 = np.sqrt(np.sum(np.square(output), axis=0))\n        self.assertTrue(np.all(l2 >= 0.2))\n        self.assertTrue(np.all(l2 <= 0.5 + 1e-6))\n\n    def test_get_method(self):\n        obj = constraints.get(\"unit_norm\")\n        self.assertTrue(obj, constraints.UnitNorm)\n\n        obj = constraints.get(None)\n        self.assertEqual(obj, None)\n\n        with self.assertRaises(ValueError):\n     
       constraints.get(\"typo\")\n\n    def test_default_constraint_call(self):\n        constraint_fn = constraints.Constraint()\n        x = np.array([1.0, 2.0, 3.0])\n        output = constraint_fn(x)\n        self.assertAllClose(x, output)\n\n    def test_constraint_get_config(self):\n        constraint_fn = constraints.Constraint()\n        config = constraint_fn.get_config()\n        self.assertEqual(config, {})\n\n    def test_constraint_from_config(self):\n        constraint_fn = constraints.Constraint()\n        config = constraint_fn.get_config()\n        recreated_constraint_fn = constraints.Constraint.from_config(config)\n        self.assertIsInstance(recreated_constraint_fn, constraints.Constraint)\n\n    def test_max_norm_get_config(self):\n        constraint_fn = constraints.MaxNorm(max_value=3.0, axis=1)\n        config = constraint_fn.get_config()\n        expected_config = {\"max_value\": 3.0, \"axis\": 1}\n        self.assertEqual(config, expected_config)\n\n    def test_unit_norm_get_config(self):\n        constraint_fn = constraints.UnitNorm(axis=1)\n        config = constraint_fn.get_config()\n        expected_config = {\"axis\": 1}\n        self.assertEqual(config, expected_config)\n\n    def test_min_max_norm_get_config(self):\n        constraint_fn = constraints.MinMaxNorm(\n            min_value=0.5, max_value=2.0, rate=0.7, axis=1\n        )\n        config = constraint_fn.get_config()\n        expected_config = {\n            \"min_value\": 0.5,\n            \"max_value\": 2.0,\n            \"rate\": 0.7,\n            \"axis\": 1,\n        }\n        self.assertEqual(config, expected_config)\n"
  },
  {
    "path": "keras/src/datasets/__init__.py",
    "content": "\"\"\"Small NumPy datasets for debugging/testing.\"\"\"\n\nfrom keras.src.datasets import boston_housing\nfrom keras.src.datasets import california_housing\nfrom keras.src.datasets import cifar10\nfrom keras.src.datasets import cifar100\nfrom keras.src.datasets import fashion_mnist\nfrom keras.src.datasets import imdb\nfrom keras.src.datasets import mnist\nfrom keras.src.datasets import reuters\n"
  },
  {
    "path": "keras/src/datasets/boston_housing.py",
    "content": "import numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils.file_utils import get_file\n\n\n@keras_export(\"keras.datasets.boston_housing.load_data\")\ndef load_data(path=\"boston_housing.npz\", test_split=0.2, seed=113):\n    \"\"\"Loads the Boston Housing dataset.\n\n    This is a dataset taken from the StatLib library which is maintained at\n    Carnegie Mellon University.\n\n    **WARNING:** This dataset has an ethical problem: the authors of this\n    dataset included a variable, \"B\", that may appear to assume that racial\n    self-segregation influences house prices. As such, we strongly discourage\n    the use of this dataset, unless in the context of illustrating ethical\n    issues in data science and machine learning.\n\n    Samples contain 13 attributes of houses at different locations around the\n    Boston suburbs in the late 1970s. Targets are the median values of\n    the houses at a location (in k$).\n\n    The attributes themselves are defined in the\n    [StatLib website](http://lib.stat.cmu.edu/datasets/boston).\n\n    Args:\n        path: path where to cache the dataset locally\n            (relative to `~/.keras/datasets`).\n        test_split: fraction of the data to reserve as test set.\n        seed: Random seed for shuffling the data\n            before computing the test split.\n\n    Returns:\n        Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.\n\n    **x_train, x_test**: NumPy arrays with shape `(num_samples, 13)`\n        containing either the training samples (for x_train),\n        or test samples (for y_train).\n\n    **y_train, y_test**: NumPy arrays of shape `(num_samples,)` containing the\n        target scalars. The targets are float scalars typically between 10 and\n        50 that represent the home prices in k$.\n    \"\"\"\n    if not (0 <= test_split < 1):\n        raise ValueError(\n            f\"Invalid `test_split` argument: {test_split}. 
\"\n            \"It must be between 0 and 1 (exclusive of 1).\"\n        )\n    origin_folder = (\n        \"https://storage.googleapis.com/tensorflow/tf-keras-datasets/\"\n    )\n    path = get_file(\n        path,\n        origin=f\"{origin_folder}boston_housing.npz\",\n        file_hash=(  # noqa: E501\n            \"f553886a1f8d56431e820c5b82552d9d95cfcb96d1e678153f8839538947dff5\"\n        ),\n    )\n    with np.load(path, allow_pickle=True) as f:\n        x = f[\"x\"]\n        y = f[\"y\"]\n\n    rng = np.random.RandomState(seed)\n    indices = np.arange(len(x))\n    rng.shuffle(indices)\n    x = x[indices]\n    y = y[indices]\n\n    x_train = np.array(x[: int(len(x) * (1 - test_split))])\n    y_train = np.array(y[: int(len(x) * (1 - test_split))])\n    x_test = np.array(x[int(len(x) * (1 - test_split)) :])\n    y_test = np.array(y[int(len(x) * (1 - test_split)) :])\n    return (x_train, y_train), (x_test, y_test)\n"
  },
  {
    "path": "keras/src/datasets/california_housing.py",
    "content": "\"\"\"Boston housing price regression dataset.\"\"\"\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils.file_utils import get_file\n\n\n@keras_export(\"keras.datasets.california_housing.load_data\")\ndef load_data(\n    version=\"large\", path=\"california_housing.npz\", test_split=0.2, seed=113\n):\n    \"\"\"Loads the California Housing dataset.\n\n    This dataset was obtained from the [StatLib repository](\n    https://www.dcc.fc.up.pt/~ltorgo/Regression/cal_housing.html).\n\n    It's a continuous regression dataset with 20,640 samples with\n    8 features each.\n\n    The target variable is a scalar: the median house value\n    for California districts, in dollars.\n\n    The 8 input features are the following:\n\n    - MedInc: median income in block group\n    - HouseAge: median house age in block group\n    - AveRooms: average number of rooms per household\n    - AveBedrms: average number of bedrooms per household\n    - Population: block group population\n    - AveOccup: average number of household members\n    - Latitude: block group latitude\n    - Longitude: block group longitude\n\n    This dataset was derived from the 1990 U.S. census, using one row\n    per census block group. A block group is the smallest geographical\n    unit for which the U.S. Census Bureau publishes sample data\n    (a block group typically has a population of 600 to 3,000 people).\n\n    A household is a group of people residing within a home.\n    Since the average number of rooms and bedrooms in this dataset are\n    provided per household, these columns may take surprisingly large\n    values for block groups with few households and many empty houses,\n    such as vacation resorts.\n\n    Args:\n        version: `\"small\"` or `\"large\"`. The small version\n            contains 600 samples, the large version contains\n            20,640 samples. 
The purpose of the small version is\n            to serve as an approximate replacement for the\n            deprecated `boston_housing` dataset.\n        path: path where to cache the dataset locally\n            (relative to `~/.keras/datasets`).\n        test_split: fraction of the data to reserve as test set.\n        seed: Random seed for shuffling the data\n            before computing the test split.\n\n    Returns:\n        Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.\n\n    **`x_train`, `x_test`**: NumPy arrays with shape `(num_samples, 8)`\n      containing either the training samples (for `x_train`),\n      or test samples (for `x_test`).\n\n    **`y_train`, `y_test`**: NumPy arrays of shape `(num_samples,)`\n        containing the target scalars. The targets are float scalars\n        typically between 25,000 and 500,000 that represent\n        the home prices in dollars.\n    \"\"\"\n    if not (0 <= test_split < 1):\n        raise ValueError(\n            f\"Invalid `test_split` argument: {test_split}. \"\n            \"It must be between 0 and 1 (exclusive of 1).\"\n        )\n    origin_folder = (\n        \"https://storage.googleapis.com/tensorflow/tf-keras-datasets/\"\n    )\n    path = get_file(\n        path,\n        origin=f\"{origin_folder}california_housing.npz\",\n        file_hash=(  # noqa: E501\n            \"1a2e3a52e0398de6463aebe6f4a8da34fb21fbb6b934cf88c3425e766f2a1a6f\"\n        ),\n    )\n    with np.load(path, allow_pickle=True) as f:\n        x = f[\"x\"]\n        y = f[\"y\"]\n\n    if version == \"small\":\n        x = x[:600]\n        y = y[:600]\n    elif version != \"large\":\n        raise ValueError(\n            \"Argument `version` must be one of 'small', 'large'. \"\n            f\"Received: version={version}\"\n        )\n\n    rng = np.random.RandomState(seed)\n    indices = np.arange(len(x))\n    rng.shuffle(indices)\n    x = x[indices]\n    y = y[indices]\n\n    x_train = np.array(x[: int(len(x) * (1 - test_split))])\n    y_train = np.array(y[: int(len(x) * (1 - test_split))])\n    x_test = np.array(x[int(len(x) * (1 - test_split)) :])\n    y_test = np.array(y[int(len(x) * (1 - test_split)) :])\n    return (x_train, y_train), (x_test, y_test)\n"
  },
  {
    "path": "keras/src/datasets/cifar.py",
    "content": "\"\"\"Utilities common to CIFAR10 and CIFAR100 datasets.\"\"\"\n\nimport _pickle as cPickle\n\n\ndef load_batch(fpath, label_key=\"labels\"):\n    \"\"\"Internal utility for parsing CIFAR data.\n\n    Args:\n        fpath: path to the file to parse.\n        label_key: key for label data in the retrieved\n            dictionary.\n\n    Returns:\n        A tuple `(data, labels)`.\n    \"\"\"\n    with open(fpath, \"rb\") as f:\n        d = cPickle.load(f, encoding=\"bytes\")\n        # decode utf8\n        d_decoded = {}\n        for k, v in d.items():\n            d_decoded[k.decode(\"utf8\")] = v\n        d = d_decoded\n    data = d[\"data\"]\n    labels = d[label_key]\n\n    data = data.reshape(data.shape[0], 3, 32, 32)\n    return data, labels\n"
  },
  {
    "path": "keras/src/datasets/cifar10.py",
    "content": "\"\"\"CIFAR10 small images classification dataset.\"\"\"\n\nimport os\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.datasets.cifar import load_batch\nfrom keras.src.utils.file_utils import get_file\n\n\n@keras_export(\"keras.datasets.cifar10.load_data\")\ndef load_data():\n    \"\"\"Loads the CIFAR10 dataset.\n\n    This is a dataset of 50,000 32x32 color training images and 10,000 test\n    images, labeled over 10 categories. See more info at the\n    [CIFAR homepage](https://www.cs.toronto.edu/~kriz/cifar.html).\n\n    The classes are:\n\n    | Label | Description |\n    |:-----:|-------------|\n    |   0   | airplane    |\n    |   1   | automobile  |\n    |   2   | bird        |\n    |   3   | cat         |\n    |   4   | deer        |\n    |   5   | dog         |\n    |   6   | frog        |\n    |   7   | horse       |\n    |   8   | ship        |\n    |   9   | truck       |\n\n    Returns:\n        Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.\n\n    **`x_train`**: `uint8` NumPy array of RGB image data with shapes\n      `(50000, 32, 32, 3)`, containing the training data. Pixel values range\n      from 0 to 255.\n\n    **`y_train`**: `uint8` NumPy array of labels (integers in range 0-9)\n      with shape `(50000, 1)` for the training data.\n\n    **`x_test`**: `uint8` NumPy array of RGB image data with shapes\n      `(10000, 32, 32, 3)`, containing the test data. 
Pixel values range\n      from 0 to 255.\n\n    **`y_test`**: `uint8` NumPy array of labels (integers in range 0-9)\n      with shape `(10000, 1)` for the test data.\n\n    Example:\n\n    ```python\n    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()\n    assert x_train.shape == (50000, 32, 32, 3)\n    assert x_test.shape == (10000, 32, 32, 3)\n    assert y_train.shape == (50000, 1)\n    assert y_test.shape == (10000, 1)\n    ```\n\n    **Note**: The CIFAR-10 dataset is known to have a small percentage of\n    mislabeled samples, which is inherent to the original dataset. This label\n    noise may impact training and evaluation. For more details, refer to\n    discussions in the research literature on CIFAR-10 label quality.\n    \"\"\"\n    dirname = \"cifar-10-batches-py-target\"\n    origin = \"https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz\"\n    path = get_file(\n        fname=dirname,\n        origin=origin,\n        extract=True,\n        file_hash=(  # noqa: E501\n            \"6d958be074577803d12ecdefd02955f39262c83c16fe9348329d7fe0b5c001ce\"\n        ),\n    )\n\n    num_train_samples = 50000\n\n    x_train = np.empty((num_train_samples, 3, 32, 32), dtype=\"uint8\")\n    y_train = np.empty((num_train_samples,), dtype=\"uint8\")\n\n    # batches are within an inner folder\n    path = os.path.join(path, \"cifar-10-batches-py\")\n    for i in range(1, 6):\n        fpath = os.path.join(path, f\"data_batch_{i}\")\n        (\n            x_train[(i - 1) * 10000 : i * 10000, :, :, :],\n            y_train[(i - 1) * 10000 : i * 10000],\n        ) = load_batch(fpath)\n\n    fpath = os.path.join(path, \"test_batch\")\n    x_test, y_test = load_batch(fpath)\n\n    y_train = np.reshape(y_train, (len(y_train), 1))\n    y_test = np.reshape(y_test, (len(y_test), 1))\n\n    if backend.image_data_format() == \"channels_last\":\n        x_train = x_train.transpose(0, 2, 3, 1)\n        x_test = x_test.transpose(0, 2, 3, 1)\n\n    x_test = 
x_test.astype(x_train.dtype)\n    y_test = y_test.astype(y_train.dtype)\n\n    return (x_train, y_train), (x_test, y_test)\n"
  },
  {
    "path": "keras/src/datasets/cifar100.py",
    "content": "\"\"\"CIFAR100 small images classification dataset.\"\"\"\n\nimport os\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.datasets.cifar import load_batch\nfrom keras.src.utils.file_utils import get_file\n\n\n@keras_export(\"keras.datasets.cifar100.load_data\")\ndef load_data(label_mode=\"fine\"):\n    \"\"\"Loads the CIFAR100 dataset.\n\n    This is a dataset of 50,000 32x32 color training images and\n    10,000 test images, labeled over 100 fine-grained classes that are\n    grouped into 20 coarse-grained classes. See more info at the\n    [CIFAR homepage](https://www.cs.toronto.edu/~kriz/cifar.html).\n\n    Args:\n        label_mode: one of `\"fine\"`, `\"coarse\"`.\n            If it is `\"fine\"`, the category labels\n            are the fine-grained labels, and if it is `\"coarse\"`,\n            the output labels are the coarse-grained superclasses.\n\n    Returns:\n        Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.\n\n    **`x_train`**: `uint8` NumPy array of RGB image data with shapes\n      `(50000, 32, 32, 3)`, containing the training data. Pixel values range\n      from 0 to 255.\n\n    **`y_train`**: `uint8` NumPy array of labels (integers in range 0-99)\n      with shape `(50000, 1)` for the training data.\n\n    **`x_test`**: `uint8` NumPy array of RGB image data with shapes\n      `(10000, 32, 32, 3)`, containing the test data. 
Pixel values range\n      from 0 to 255.\n\n    **`y_test`**: `uint8` NumPy array of labels (integers in range 0-99)\n      with shape `(10000, 1)` for the test data.\n\n    Example:\n\n    ```python\n    (x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()\n    assert x_train.shape == (50000, 32, 32, 3)\n    assert x_test.shape == (10000, 32, 32, 3)\n    assert y_train.shape == (50000, 1)\n    assert y_test.shape == (10000, 1)\n    ```\n    \"\"\"\n    if label_mode not in [\"fine\", \"coarse\"]:\n        raise ValueError(\n            '`label_mode` must be one of `\"fine\"`, `\"coarse\"`. '\n            f\"Received: label_mode={label_mode}.\"\n        )\n\n    dirname = \"cifar-100-python-target\"\n    origin = \"https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz\"\n    path = get_file(\n        fname=dirname,\n        origin=origin,\n        extract=True,\n        file_hash=(  # noqa: E501\n            \"85cd44d02ba6437773c5bbd22e183051d648de2e7d6b014e1ef29b855ba677a7\"\n        ),\n    )\n\n    path = os.path.join(path, \"cifar-100-python\")\n    fpath = os.path.join(path, \"train\")\n    x_train, y_train = load_batch(fpath, label_key=f\"{label_mode}_labels\")\n\n    fpath = os.path.join(path, \"test\")\n    x_test, y_test = load_batch(fpath, label_key=f\"{label_mode}_labels\")\n\n    y_train = np.reshape(y_train, (len(y_train), 1))\n    y_test = np.reshape(y_test, (len(y_test), 1))\n\n    if backend.image_data_format() == \"channels_last\":\n        x_train = x_train.transpose(0, 2, 3, 1)\n        x_test = x_test.transpose(0, 2, 3, 1)\n\n    return (x_train, y_train), (x_test, y_test)\n"
  },
  {
    "path": "keras/src/datasets/fashion_mnist.py",
    "content": "\"\"\"Fashion-MNIST dataset.\"\"\"\n\nimport gzip\nimport os\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils.file_utils import get_file\n\n\n@keras_export(\"keras.datasets.fashion_mnist.load_data\")\ndef load_data():\n    \"\"\"Loads the Fashion-MNIST dataset.\n\n    This is a dataset of 60,000 28x28 grayscale images of 10 fashion categories,\n    along with a test set of 10,000 images. This dataset can be used as\n    a drop-in replacement for MNIST.\n\n    The classes are:\n\n    | Label | Description |\n    |:-----:|-------------|\n    |   0   | T-shirt/top |\n    |   1   | Trouser     |\n    |   2   | Pullover    |\n    |   3   | Dress       |\n    |   4   | Coat        |\n    |   5   | Sandal      |\n    |   6   | Shirt       |\n    |   7   | Sneaker     |\n    |   8   | Bag         |\n    |   9   | Ankle boot  |\n\n    Returns:\n\n    Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.\n\n    **`x_train`**: `uint8` NumPy array of grayscale image data with shapes\n      `(60000, 28, 28)`, containing the training data.\n\n    **`y_train`**: `uint8` NumPy array of labels (integers in range 0-9)\n      with shape `(60000,)` for the training data.\n\n    **`x_test`**: `uint8` NumPy array of grayscale image data with shapes\n      (10000, 28, 28), containing the test data.\n\n    **`y_test`**: `uint8` NumPy array of labels (integers in range 0-9)\n      with shape `(10000,)` for the test data.\n\n    Example:\n\n    ```python\n    (x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()\n    assert x_train.shape == (60000, 28, 28)\n    assert x_test.shape == (10000, 28, 28)\n    assert y_train.shape == (60000,)\n    assert y_test.shape == (10000,)\n    ```\n\n    License:\n\n    The copyright for Fashion-MNIST is held by Zalando SE.\n    Fashion-MNIST is licensed under the [MIT license](\n        https://github.com/zalandoresearch/fashion-mnist/blob/master/LICENSE).\n    \"\"\"\n    
dirname = os.path.join(\"datasets\", \"fashion-mnist\")\n    base = \"https://storage.googleapis.com/tensorflow/tf-keras-datasets/\"\n    files = [\n        \"train-labels-idx1-ubyte.gz\",\n        \"train-images-idx3-ubyte.gz\",\n        \"t10k-labels-idx1-ubyte.gz\",\n        \"t10k-images-idx3-ubyte.gz\",\n    ]\n\n    paths = []\n    for fname in files:\n        paths.append(get_file(fname, origin=base + fname, cache_subdir=dirname))\n\n    with gzip.open(paths[0], \"rb\") as lbpath:\n        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)\n\n    with gzip.open(paths[1], \"rb\") as imgpath:\n        x_train = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(\n            len(y_train), 28, 28\n        )\n\n    with gzip.open(paths[2], \"rb\") as lbpath:\n        y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)\n\n    with gzip.open(paths[3], \"rb\") as imgpath:\n        x_test = np.frombuffer(imgpath.read(), np.uint8, offset=16).reshape(\n            len(y_test), 28, 28\n        )\n\n    return (x_train, y_train), (x_test, y_test)\n"
  },
  {
    "path": "keras/src/datasets/imdb.py",
    "content": "\"\"\"IMDB sentiment classification dataset.\"\"\"\n\nimport json\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils.file_utils import get_file\nfrom keras.src.utils.python_utils import remove_long_seq\n\n\n@keras_export(\"keras.datasets.imdb.load_data\")\ndef load_data(\n    path=\"imdb.npz\",\n    num_words=None,\n    skip_top=0,\n    maxlen=None,\n    seed=113,\n    start_char=1,\n    oov_char=2,\n    index_from=3,\n    **kwargs,\n):\n    \"\"\"Loads the [IMDB dataset](https://ai.stanford.edu/~amaas/data/sentiment/).\n\n    This is a dataset of 25,000 movie reviews from IMDB, labeled by sentiment\n    (positive/negative). Reviews have been preprocessed, and each review is\n    encoded as a list of word indexes (integers).\n    For convenience, words are indexed by overall frequency in the dataset,\n    so that for instance the integer \"3\" encodes the 3rd most frequent word in\n    the data. This allows for quick filtering operations such as:\n    \"only consider the top 10,000 most\n    common words, but eliminate the top 20 most common words\".\n\n    As a convention, \"0\" does not stand for a specific word, but instead is used\n    to encode the pad token.\n\n    Args:\n        path: where to cache the data (relative to `~/.keras/dataset`).\n        num_words: integer or None. Words are\n            ranked by how often they occur (in the training set) and only\n            the `num_words` most frequent words are kept. Any less frequent word\n            will appear as `oov_char` value in the sequence data. If None,\n            all words are kept. Defaults to `None`.\n        skip_top: skip the top N most frequently occurring words\n            (which may not be informative). These words will appear as\n            `oov_char` value in the dataset. When 0, no words are\n            skipped. Defaults to `0`.\n        maxlen: int or None. 
Maximum sequence length.\n            Any longer sequence will be truncated. None means no truncation.\n            Defaults to `None`.\n        seed: int. Seed for reproducible data shuffling.\n        start_char: int. The start of a sequence will be marked with this\n            character. 0 is usually the padding character. Defaults to `1`.\n        oov_char: int. The out-of-vocabulary character.\n            Words that were cut out because of the `num_words` or\n            `skip_top` limits will be replaced with this character.\n        index_from: int. Index actual words with this index and higher.\n\n    Returns:\n        Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.\n\n    **`x_train`, `x_test`**: lists of sequences, which are lists of indexes\n      (integers). If the `num_words` argument was specified, the maximum\n      possible index value is `num_words - 1`. If the `maxlen` argument was\n      specified, the largest possible sequence length is `maxlen`.\n\n    **`y_train`, `y_test`**: lists of integer labels (1 or 0).\n\n    **Note**: The 'out of vocabulary' character is only used for\n    words that were present in the training set but are not included\n    because they did not make the `num_words` cut.\n    Words that were not seen in the training set but are in the test set\n    have simply been skipped.\n    \"\"\"\n    origin_folder = (\n        \"https://storage.googleapis.com/tensorflow/tf-keras-datasets/\"\n    )\n    path = get_file(\n        fname=path,\n        origin=f\"{origin_folder}imdb.npz\",\n        file_hash=(  # noqa: E501\n            \"69664113be75683a8fe16e3ed0ab59fda8886cb3cd7ada244f7d9544e4676b9f\"\n        ),\n    )\n    with np.load(path, allow_pickle=True) as f:\n        x_train, labels_train = f[\"x_train\"], f[\"y_train\"]\n        x_test, labels_test = f[\"x_test\"], f[\"y_test\"]\n\n    rng = np.random.RandomState(seed)\n    indices = np.arange(len(x_train))\n    rng.shuffle(indices)\n    x_train = 
x_train[indices]\n    labels_train = labels_train[indices]\n\n    indices = np.arange(len(x_test))\n    rng.shuffle(indices)\n    x_test = x_test[indices]\n    labels_test = labels_test[indices]\n\n    if start_char is not None:\n        x_train = [[start_char] + [w + index_from for w in x] for x in x_train]\n        x_test = [[start_char] + [w + index_from for w in x] for x in x_test]\n    elif index_from:\n        x_train = [[w + index_from for w in x] for x in x_train]\n        x_test = [[w + index_from for w in x] for x in x_test]\n    else:\n        x_train = [[w for w in x] for x in x_train]\n        x_test = [[w for w in x] for x in x_test]\n\n    if maxlen:\n        x_train, labels_train = remove_long_seq(maxlen, x_train, labels_train)\n        x_test, labels_test = remove_long_seq(maxlen, x_test, labels_test)\n        if not x_train or not x_test:\n            raise ValueError(\n                \"After filtering for sequences shorter than maxlen=\"\n                f\"{str(maxlen)}, no sequence was kept. 
Increase maxlen.\"\n            )\n\n    xs = x_train + x_test\n    labels = np.concatenate([labels_train, labels_test])\n\n    if not num_words:\n        num_words = max(max(x) for x in xs)\n\n    # by convention, use 2 as OOV word\n    # reserve 'index_from' (=3 by default) characters:\n    # 0 (padding), 1 (start), 2 (OOV)\n    if oov_char is not None:\n        xs = [\n            [w if (skip_top <= w < num_words) else oov_char for w in x]\n            for x in xs\n        ]\n    else:\n        xs = [[w for w in x if skip_top <= w < num_words] for x in xs]\n\n    idx = len(x_train)\n    x_train, y_train = np.array(xs[:idx], dtype=\"object\"), labels[:idx]\n    x_test, y_test = np.array(xs[idx:], dtype=\"object\"), labels[idx:]\n    return (x_train, y_train), (x_test, y_test)\n\n\n@keras_export(\"keras.datasets.imdb.get_word_index\")\ndef get_word_index(path=\"imdb_word_index.json\"):\n    \"\"\"Retrieves a dict mapping words to their index in the IMDB dataset.\n\n    Args:\n        path: where to cache the data (relative to `~/.keras/dataset`).\n\n    Returns:\n        The word index dictionary. 
Keys are word strings, values are their\n        index.\n\n    Example:\n\n    ```python\n    # Use the default parameters to keras.datasets.imdb.load_data\n    start_char = 1\n    oov_char = 2\n    index_from = 3\n    # Retrieve the training sequences.\n    (x_train, _), _ = keras.datasets.imdb.load_data(\n        start_char=start_char, oov_char=oov_char, index_from=index_from\n    )\n    # Retrieve the word index file mapping words to indices\n    word_index = keras.datasets.imdb.get_word_index()\n    # Reverse the word index to obtain a dict mapping indices to words\n    # And add `index_from` to indices to sync with `x_train`\n    inverted_word_index = dict(\n        (i + index_from, word) for (word, i) in word_index.items()\n    )\n    # Update `inverted_word_index` to include `start_char` and `oov_char`\n    inverted_word_index[start_char] = \"[START]\"\n    inverted_word_index[oov_char] = \"[OOV]\"\n    # Decode the first sequence in the dataset\n    decoded_sequence = \" \".join(inverted_word_index[i] for i in x_train[0])\n    ```\n    \"\"\"\n    origin_folder = (\n        \"https://storage.googleapis.com/tensorflow/tf-keras-datasets/\"\n    )\n    path = get_file(\n        fname=path,\n        origin=f\"{origin_folder}imdb_word_index.json\",\n        file_hash=\"bfafd718b763782e994055a2d397834f\",\n    )\n    with open(path) as f:\n        return json.load(f)\n"
  },
  {
    "path": "keras/src/datasets/mnist.py",
    "content": "\"\"\"MNIST handwritten digits dataset.\"\"\"\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils.file_utils import get_file\n\n\n@keras_export(\"keras.datasets.mnist.load_data\")\ndef load_data(path=\"mnist.npz\"):\n    \"\"\"Loads the MNIST dataset.\n\n    This is a dataset of 60,000 28x28 grayscale images of the 10 digits,\n    along with a test set of 10,000 images.\n    More info can be found at the\n    [MNIST homepage](http://yann.lecun.com/exdb/mnist/).\n\n    Args:\n        path: path where to cache the dataset locally\n            (relative to `~/.keras/datasets`).\n\n    Returns:\n        Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.\n\n    **`x_train`**: `uint8` NumPy array of grayscale image data with shapes\n      `(60000, 28, 28)`, containing the training data. Pixel values range\n      from 0 to 255.\n\n    **`y_train`**: `uint8` NumPy array of digit labels (integers in range 0-9)\n      with shape `(60000,)` for the training data.\n\n    **`x_test`**: `uint8` NumPy array of grayscale image data with shapes\n      `(10000, 28, 28)`, containing the test data. 
Pixel values range\n      from 0 to 255.\n\n    **`y_test`**: `uint8` NumPy array of digit labels (integers in range 0-9)\n      with shape `(10000,)` for the test data.\n\n    Example:\n\n    ```python\n    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n    assert x_train.shape == (60000, 28, 28)\n    assert x_test.shape == (10000, 28, 28)\n    assert y_train.shape == (60000,)\n    assert y_test.shape == (10000,)\n    ```\n\n    License:\n\n    Yann LeCun and Corinna Cortes hold the copyright of MNIST dataset,\n    which is a derivative work from original NIST datasets.\n    MNIST dataset is made available under the terms of the\n    [Creative Commons Attribution-Share Alike 3.0 license.](\n        https://creativecommons.org/licenses/by-sa/3.0/)\n    \"\"\"\n    origin_folder = (\n        \"https://storage.googleapis.com/tensorflow/tf-keras-datasets/\"\n    )\n    path = get_file(\n        fname=path,\n        origin=f\"{origin_folder}mnist.npz\",\n        file_hash=(  # noqa: E501\n            \"731c5ac602752760c8e48fbffcf8c3b850d9dc2a2aedcf2cc48468fc17b673d1\"\n        ),\n    )\n    with np.load(path, allow_pickle=True) as f:\n        x_train, y_train = f[\"x_train\"], f[\"y_train\"]\n        x_test, y_test = f[\"x_test\"], f[\"y_test\"]\n\n        return (x_train, y_train), (x_test, y_test)\n"
  },
  {
    "path": "keras/src/datasets/reuters.py",
    "content": "\"\"\"Reuters topic classification dataset.\"\"\"\n\nimport json\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils.file_utils import get_file\nfrom keras.src.utils.python_utils import remove_long_seq\n\n\n@keras_export(\"keras.datasets.reuters.load_data\")\ndef load_data(\n    path=\"reuters.npz\",\n    num_words=None,\n    skip_top=0,\n    maxlen=None,\n    test_split=0.2,\n    seed=113,\n    start_char=1,\n    oov_char=2,\n    index_from=3,\n):\n    \"\"\"Loads the Reuters newswire classification dataset.\n\n    This is a dataset of 11,228 newswires from Reuters, labeled over 46 topics.\n\n    This was originally generated by parsing and preprocessing the classic\n    Reuters-21578 dataset, but the preprocessing code is no longer packaged\n    with Keras. See this\n    [GitHub discussion](https://github.com/keras-team/keras/issues/12072)\n    for more info.\n\n    Each newswire is encoded as a list of word indexes (integers).\n    For convenience, words are indexed by overall frequency in the dataset,\n    so that for instance the integer \"3\" encodes the 3rd most frequent word in\n    the data. This allows for quick filtering operations such as:\n    \"only consider the top 10,000 most\n    common words, but eliminate the top 20 most common words\".\n\n    As a convention, \"0\" does not stand for a specific word, but instead is used\n    to encode any unknown word.\n\n    Args:\n        path: where to cache the data (relative to `~/.keras/dataset`).\n        num_words: integer or None. Words are\n            ranked by how often they occur (in the training set) and only\n            the `num_words` most frequent words are kept. Any less frequent word\n            will appear as `oov_char` value in the sequence data. If None,\n            all words are kept. Defaults to `None`.\n        skip_top: skip the top N most frequently occurring words\n            (which may not be informative). 
These words will appear as\n            `oov_char` value in the dataset. 0 means no words are\n            skipped. Defaults to `0`.\n        maxlen: int or None. Maximum sequence length.\n            Any longer sequence will be truncated. None means no truncation.\n            Defaults to `None`.\n        test_split: Float between `0.` and `1.`. Fraction of the dataset to be\n            used as test data. `0.2` means that 20% of the dataset is used as\n            test data. Defaults to `0.2`.\n        seed: int. Seed for reproducible data shuffling.\n        start_char: int. The start of a sequence will be marked with this\n            character. 0 is usually the padding character. Defaults to `1`.\n        oov_char: int. The out-of-vocabulary character.\n            Words that were cut out because of the `num_words` or\n            `skip_top` limits will be replaced with this character.\n        index_from: int. Index actual words with this index and higher.\n\n    Returns:\n        Tuple of NumPy arrays: `(x_train, y_train), (x_test, y_test)`.\n\n    **`x_train`, `x_test`**: lists of sequences, which are lists of indexes\n      (integers). If the `num_words` argument was specified, the maximum\n      possible index value is `num_words - 1`. 
If the `maxlen` argument was\n      specified, the largest possible sequence length is `maxlen`.\n\n    **`y_train`, `y_test`**: lists of integer labels (0 to 45).\n\n    **Note**: The 'out of vocabulary' character is only used for\n    words that were present in the training set but are not included\n    because they did not make the `num_words` cut.\n    Words that were not seen in the training set but are in the test set\n    have simply been skipped.\n    \"\"\"\n    origin_folder = (\n        \"https://storage.googleapis.com/tensorflow/tf-keras-datasets/\"\n    )\n    path = get_file(\n        fname=path,\n        origin=f\"{origin_folder}reuters.npz\",\n        file_hash=(  # noqa: E501\n            \"d6586e694ee56d7a4e65172e12b3e987c03096cb01eab99753921ef915959916\"\n        ),\n    )\n    with np.load(path, allow_pickle=True) as f:\n        xs, labels = f[\"x\"], f[\"y\"]\n\n    rng = np.random.RandomState(seed)\n    indices = np.arange(len(xs))\n    rng.shuffle(indices)\n    xs = xs[indices]\n    labels = labels[indices]\n\n    if start_char is not None:\n        xs = [[start_char] + [w + index_from for w in x] for x in xs]\n    elif index_from:\n        xs = [[w + index_from for w in x] for x in xs]\n\n    if maxlen:\n        xs, labels = remove_long_seq(maxlen, xs, labels)\n\n    if not num_words:\n        num_words = max(max(x) for x in xs)\n\n    # by convention, use 2 as OOV word\n    # reserve 'index_from' (=3 by default) characters:\n    # 0 (padding), 1 (start), 2 (OOV)\n    if oov_char is not None:\n        xs = [\n            [w if skip_top <= w < num_words else oov_char for w in x]\n            for x in xs\n        ]\n    else:\n        xs = [[w for w in x if skip_top <= w < num_words] for x in xs]\n\n    idx = int(len(xs) * (1 - test_split))\n    x_train, y_train = (\n        np.array(xs[:idx], dtype=\"object\"),\n        np.array(labels[:idx]),\n    )\n    x_test, y_test = np.array(xs[idx:], dtype=\"object\"), np.array(labels[idx:])\n\n    return (x_train, y_train), (x_test, y_test)\n\n\n@keras_export(\"keras.datasets.reuters.get_word_index\")\ndef get_word_index(path=\"reuters_word_index.json\"):\n    \"\"\"Retrieves a dict mapping words to their index in the Reuters dataset.\n\n    Actual word indices start from 3, with 3 indices reserved for:\n    0 (padding), 1 (start), 2 (oov).\n\n    E.g. the word index of 'the' is 1, but in the actual training data, the\n    index of 'the' will be 1 + 3 = 4. Conversely, to translate word indices in\n    the training data back to words using this mapping, subtract 3 from the\n    indices.\n\n    Args:\n        path: where to cache the data (relative to `~/.keras/dataset`).\n\n    Returns:\n        The word index dictionary. Keys are word strings, values are their\n        index.\n    \"\"\"\n    origin_folder = (\n        \"https://storage.googleapis.com/tensorflow/tf-keras-datasets/\"\n    )\n    path = get_file(\n        path,\n        origin=f\"{origin_folder}reuters_word_index.json\",\n        file_hash=\"4d44cc38712099c9e383dc6e5f11a921\",\n    )\n    with open(path) as f:\n        return json.load(f)\n\n\n@keras_export(\"keras.datasets.reuters.get_label_names\")\ndef get_label_names():\n    \"\"\"Returns labels as a list of strings with indices matching training data.\n\n    Reference:\n\n    - [Reuters Dataset](https://martin-thoma.com/nlp-reuters/)\n    \"\"\"\n    return (\n        \"cocoa\",\n        \"grain\",\n        \"veg-oil\",\n        \"earn\",\n        \"acq\",\n        \"wheat\",\n        \"copper\",\n        \"housing\",\n        \"money-supply\",\n        \"coffee\",\n        \"sugar\",\n        \"trade\",\n        \"reserves\",\n        \"ship\",\n        \"cotton\",\n        \"carcass\",\n        \"crude\",\n        \"nat-gas\",\n        \"cpi\",\n        \"money-fx\",\n        \"interest\",\n        \"gnp\",\n        \"meal-feed\",\n        \"alum\",\n        \"oilseed\",\n        \"gold\",\n        \"tin\",\n        \"strategic-metal\",\n       
 \"livestock\",\n        \"retail\",\n        \"ipi\",\n        \"iron-steel\",\n        \"rubber\",\n        \"heat\",\n        \"jobs\",\n        \"lei\",\n        \"bop\",\n        \"zinc\",\n        \"orange\",\n        \"pet-chem\",\n        \"dlr\",\n        \"gas\",\n        \"silver\",\n        \"wpi\",\n        \"hog\",\n        \"lead\",\n    )\n"
  },
  {
    "path": "keras/src/distillation/__init__.py",
    "content": "\"\"\"Distillation module for knowledge distillation in Keras.\"\"\"\n"
  },
  {
    "path": "keras/src/distillation/distillation_loss.py",
    "content": "import keras\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils import tracking\n\n\ndef _convert_loss_to_function(loss_item):\n    \"\"\"Convert a loss string identifier to a loss function.\n\n    Arguments:\n        loss_item: Either a string identifier, a loss function instance,\n            or `None`.\n\n    Returns:\n        A loss function instance, or `None`.\n\n    Raises:\n        ValueError: If the loss string identifier is unknown.\n    \"\"\"\n    if loss_item is None:\n        return None\n    elif isinstance(loss_item, str):\n        loss_fn = keras.losses.get(loss_item)\n        if loss_fn is None:\n            raise ValueError(f\"Unknown loss function: '{loss_item}'.\")\n        return loss_fn\n    else:\n        return loss_item\n\n\n@keras_export(\"keras.distillation.DistillationLoss\")\nclass DistillationLoss:\n    \"\"\"Base class for distillation loss computation.\n\n    Distillation losses define how to compute the distillation loss\n    between teacher and student outputs. Each loss implements a specific\n    approach to knowledge transfer, from simple logits matching to feature-based\n    distillation.\n\n    To create custom distillation losses, subclass this class and\n    override the `compute_loss` method.\n    \"\"\"\n\n    def compute_loss(self, teacher_outputs, student_outputs, **kwargs):\n        \"\"\"Compute distillation loss between teacher and student outputs.\n\n        This method should implement the specific distillation logic for\n        transferring knowledge from teacher to student.\n\n        Arguments:\n            teacher_outputs: Outputs from the teacher model. Can be a single\n                tensor or a list/tuple of tensors for multi-output models.\n            student_outputs: Outputs from the student model. 
Can be a single\n                tensor or a list/tuple of tensors for multi-output models.\n            **kwargs: Additional arguments for custom distillation_loss.\n        Returns:\n            Distillation loss tensor.\n        \"\"\"\n        raise NotImplementedError(\"Subclasses must implement compute_loss\")\n\n    def validate_outputs(self, teacher_outputs, student_outputs):\n        \"\"\"Validate that teacher and student outputs are compatible.\n\n        Arguments:\n            teacher_outputs: Outputs from the teacher model.\n            student_outputs: Outputs from the student model.\n        Raises:\n            ValueError: If outputs are not compatible.\n        \"\"\"\n        keras.tree.assert_same_structure(teacher_outputs, student_outputs)\n\n    def validate_model_compatibility(self, teacher, student):\n        \"\"\"Validate that teacher and student models are compatible.\n\n        Arguments:\n            teacher: The teacher model.\n            student: The student model.\n        Raises:\n            ValueError: If models are not compatible with this distillation\n                loss.\n        \"\"\"\n        pass\n\n\n@keras_export(\"keras.distillation.FeatureDistillation\")\nclass FeatureDistillation(DistillationLoss):\n    \"\"\"Feature distillation loss.\n\n    Feature distillation transfers knowledge from intermediate layers of the\n    teacher model to corresponding layers of the student model. This approach\n    helps the student learn better internal representations and often leads\n    to better performance compared to logits-only distillation.\n\n    Arguments:\n        loss: Loss function to use for feature distillation. 
Can be:\n            - String identifier (e.g., 'mse', 'cosine_similarity', 'mae')\n            - Keras loss instance\n            - Nested structure of losses matching the layer output structure\n            - `None` to skip distillation for that output (useful for\n              multi-output models where you only want to distill some outputs)\n            At least one loss must be non-`None`. Defaults to 'mse'.\n        teacher_layer_name: Name of the teacher layer to extract features from.\n            If `None`, uses the final output. Defaults to `None`.\n        student_layer_name: Name of the student layer to extract features from.\n            If `None`, uses the final output. Defaults to `None`.\n\n    Examples:\n\n    ```python\n    # Basic feature distillation from final outputs\n    distillation_loss = FeatureDistillation(loss=\"mse\")\n\n    # Distill from specific intermediate layers\n    distillation_loss = FeatureDistillation(\n        loss=\"mse\",\n        teacher_layer_name=\"dense_1\",\n        student_layer_name=\"dense_1\"\n    )\n\n    # Use cosine similarity for different feature sizes\n    distillation_loss = FeatureDistillation(\n        loss=\"cosine_similarity\",\n        teacher_layer_name=\"conv2d_2\",\n        student_layer_name=\"conv2d_1\"\n    )\n\n    # With custom loss instance\n    distillation_loss = FeatureDistillation(\n        loss=keras.losses.MeanAbsoluteError()\n    )\n\n    # For multi-output models\n    distillation_loss = FeatureDistillation(\n        loss=[\"mse\", \"cosine_similarity\"]\n    )\n\n    # For multi-output models, only distill some outputs\n    distillation_loss = FeatureDistillation(\n        loss=[\"mse\", None, \"cosine_similarity\"]  # Skip middle output\n    )\n    ```\n    \"\"\"\n\n    @tracking.no_automatic_dependency_tracking\n    def __init__(\n        self, loss=\"mse\", teacher_layer_name=None, student_layer_name=None\n    ):\n        self.teacher_layer_name = teacher_layer_name\n        
self.student_layer_name = student_layer_name\n        self.loss = tree.map_structure(_convert_loss_to_function, loss)\n\n        flat_losses = tree.flatten(self.loss)\n        if all(l is None for l in flat_losses):\n            raise ValueError(\n                \"The `loss` argument in `FeatureDistillation` must \"\n                \"contain at least one non-`None` value.\"\n            )\n\n    def validate_model_compatibility(self, teacher, student):\n        \"\"\"Validate that teacher and student models are compatible for feature\n        distillation.\"\"\"\n        if (\n            self.teacher_layer_name is not None\n            or self.student_layer_name is not None\n        ):\n            teacher_is_subclassed = (\n                not hasattr(teacher, \"inputs\") or teacher.inputs is None\n            )\n            student_is_subclassed = (\n                not hasattr(student, \"inputs\") or student.inputs is None\n            )\n\n            if teacher_is_subclassed or student_is_subclassed:\n                subclassed_models = []\n                if teacher_is_subclassed:\n                    subclassed_models.append(\"teacher\")\n                if student_is_subclassed:\n                    subclassed_models.append(\"student\")\n\n                models_str = \" and \".join(subclassed_models)\n                raise ValueError(\n                    f\"FeatureDistillation with specific layer names requires \"\n                    f\"Functional or Sequential models. The {models_str} \"\n                    f\"model(s) appear to be subclassed (no symbolic \"\n                    f\"inputs/outputs). 
Either use Functional/Sequential \"\n                    f\"models, or use FeatureDistillation without layer names \"\n                    f\"(to distill final outputs only), or use \"\n                    f\"LogitsDistillation instead.\"\n                )\n\n        if self.teacher_layer_name is not None:\n            try:\n                teacher.get_layer(name=self.teacher_layer_name)\n            except ValueError as e:\n                raise ValueError(f\"In teacher model: {e}\")\n\n        if self.student_layer_name is not None:\n            try:\n                student.get_layer(name=self.student_layer_name)\n            except ValueError as e:\n                raise ValueError(f\"In student model: {e}\")\n\n    def validate_outputs(self, teacher_outputs, student_outputs):\n        \"\"\"Validate that outputs are compatible for feature distillation.\"\"\"\n        super().validate_outputs(teacher_outputs, student_outputs)\n\n        try:\n            tree.assert_same_structure(self.loss, teacher_outputs)\n        except ValueError as e:\n            raise ValueError(\n                f\"Loss structure mismatch. \"\n                f\"Loss structure: {tree.structure(self.loss)}, \"\n                f\"Output structure: {tree.structure(teacher_outputs)}. 
\"\n                f\"Error: {e}\"\n            )\n\n    def compute_loss(self, teacher_outputs, student_outputs, **kwargs):\n        \"\"\"Compute feature distillation loss using extracted features.\n\n        Arguments:\n            teacher_outputs: Extracted features from teacher layer.\n            student_outputs: Extracted features from student layer.\n            **kwargs: Additional arguments (ignored).\n        Returns:\n            Scalar distillation loss tensor.\n        \"\"\"\n\n        def apply_loss(loss_fn, teacher_features, student_features):\n            if loss_fn is None:\n                return 0.0\n\n            loss = keras.ops.mean(loss_fn(teacher_features, student_features))\n\n            return loss\n\n        loss_values = tree.map_structure(\n            apply_loss, self.loss, teacher_outputs, student_outputs\n        )\n\n        flat_losses = tree.flatten(loss_values)\n        return keras.ops.sum(keras.ops.stack(flat_losses))\n\n    def get_config(self):\n        \"\"\"Get configuration for serialization.\"\"\"\n        return {\n            \"loss\": keras.losses.serialize(self.loss),\n            \"teacher_layer_name\": self.teacher_layer_name,\n            \"student_layer_name\": self.student_layer_name,\n        }\n\n    @classmethod\n    def from_config(cls, config):\n        \"\"\"Create instance from configuration.\"\"\"\n        config = config.copy()\n        config[\"loss\"] = keras.losses.deserialize(config[\"loss\"])\n        return cls(**config)\n\n\n@keras_export(\"keras.distillation.LogitsDistillation\")\nclass LogitsDistillation(DistillationLoss):\n    \"\"\"Distillation loss that transfers knowledge from final model outputs.\n\n    This distillation loss applies temperature scaling to the teacher's logits\n    before computing the loss between teacher and student predictions. It's the\n    most common approach for knowledge distillation.\n\n    Arguments:\n        temperature: Temperature for softmax scaling. 
Higher values produce\n            softer probability distributions that are easier for the student to\n            learn. Typical values range from 3 to 5. Defaults to 3.0.\n        loss: Loss function to use for distillation. Can be:\n            - String identifier (e.g., 'kl_divergence',\n              'categorical_crossentropy')\n            - Keras loss instance\n            - Nested structure of losses matching the model output structure\n            - `None` to skip distillation for that output (useful for\n              multi-output models where you only want to distill some outputs)\n            At least one loss must be non-`None`. Defaults to 'kl_divergence'.\n\n    Examples:\n\n    ```python\n    # Basic logits distillation with KL divergence\n    distillation_loss = LogitsDistillation(temperature=3.0)\n\n    # With categorical crossentropy loss\n    distillation_loss = LogitsDistillation(\n        temperature=4.0,\n        loss=\"categorical_crossentropy\"\n    )\n\n    # With custom loss instance\n    distillation_loss = LogitsDistillation(\n        temperature=4.0,\n        loss=keras.losses.CategoricalCrossentropy(from_logits=True)\n    )\n\n    # For multi-output models\n    distillation_loss = LogitsDistillation(\n        temperature=3.0,\n        loss=[\"kl_divergence\", \"categorical_crossentropy\"]\n    )\n\n    # For multi-output models, only distill some outputs\n    distillation_loss = LogitsDistillation(\n        temperature=3.0,\n        loss=[\"kl_divergence\", None]  # Skip second output\n    )\n    ```\n    \"\"\"\n\n    @tracking.no_automatic_dependency_tracking\n    def __init__(\n        self,\n        temperature=3.0,\n        loss=\"kl_divergence\",\n    ):\n        self.temperature = temperature\n        self.loss = tree.map_structure(_convert_loss_to_function, loss)\n\n        flat_losses = tree.flatten(self.loss)\n        if all(l is None for l in flat_losses):\n            raise ValueError(\"At least one loss must be 
non-`None`.\")\n\n        if not isinstance(self.temperature, (int, float)):\n            raise ValueError(\n                f\"temperature must be a number, got {type(self.temperature)}\"\n            )\n        if self.temperature <= 0.0:\n            raise ValueError(\"temperature must be positive.\")\n\n    def compute_loss(self, teacher_outputs, student_outputs, **kwargs):\n        \"\"\"Compute distillation loss using the configured loss function.\n\n        Arguments:\n            teacher_outputs: Logits from teacher model. Can be a single tensor,\n                list/tuple of tensors, or dict of tensors.\n            student_outputs: Logits from student model. Can be a single tensor,\n                list/tuple of tensors, or dict of tensors.\n            **kwargs: Additional arguments (ignored).\n        Returns:\n            Distillation loss tensor.\n        \"\"\"\n        # Apply temperature scaling using tree.map_structure\n        teacher_scaled = tree.map_structure(\n            lambda x: keras.ops.divide(x, self.temperature), teacher_outputs\n        )\n        student_scaled = tree.map_structure(\n            lambda x: keras.ops.divide(x, self.temperature), student_outputs\n        )\n\n        # Apply loss function(s) to corresponding outputs\n        def apply_loss(loss_fn, teacher_logits, student_logits):\n            if loss_fn is None:\n                return 0.0\n\n            # Special handling for KL divergence (needs probabilities)\n            if isinstance(loss_fn, keras.losses.KLDivergence):\n                teacher_probs = keras.ops.softmax(teacher_logits, axis=-1)\n                student_probs = keras.ops.softmax(student_logits, axis=-1)\n                loss = keras.ops.mean(loss_fn(teacher_probs, student_probs))\n                # Scale by temperature^2 for KL (per literature)\n                return loss * (self.temperature**2)\n            else:\n                # For other losses, use logits directly\n                return 
keras.ops.mean(loss_fn(teacher_logits, student_logits))\n\n        # Apply losses using tree.map_structure\n        loss_values = tree.map_structure(\n            apply_loss, self.loss, teacher_scaled, student_scaled\n        )\n\n        # Sum all losses and return scalar\n        flat_losses = tree.flatten(loss_values)\n        return keras.ops.sum(keras.ops.stack(flat_losses))\n\n    def get_config(self):\n        \"\"\"Get configuration for serialization.\"\"\"\n        return {\n            \"temperature\": self.temperature,\n            \"loss\": serialization_lib.serialize_keras_object(self.loss),\n        }\n\n    @classmethod\n    def from_config(cls, config):\n        \"\"\"Create instance from configuration.\"\"\"\n        config = config.copy()\n        config[\"loss\"] = keras.losses.deserialize(config[\"loss\"])\n        return cls(**config)\n"
  },
  {
    "path": "keras/src/distillation/distillation_loss_test.py",
    "content": "import numpy as np\nimport pytest\n\nimport keras\nfrom keras.src.distillation.distillation_loss import FeatureDistillation\nfrom keras.src.distillation.distillation_loss import LogitsDistillation\nfrom keras.src.distillation.distiller import Distiller\nfrom keras.src.testing import TestCase\n\n\n@pytest.mark.requires_trainable_backend\nclass TestLogitsDistillation(TestCase):\n    \"\"\"Test cases for the LogitsDistillation loss.\"\"\"\n\n    def test_logits_distillation_basic(self):\n        \"\"\"Test basic logits distillation structure validation.\"\"\"\n        # Create dummy logits\n        teacher_logits = keras.ops.convert_to_tensor(\n            np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), dtype=\"float32\"\n        )\n        student_logits = keras.ops.convert_to_tensor(\n            np.array([[2.0, 1.0, 4.0], [3.0, 6.0, 2.0]]), dtype=\"float32\"\n        )\n\n        distillation_loss = LogitsDistillation(temperature=3.0)\n        distillation_loss.validate_outputs(teacher_logits, student_logits)\n        incompatible_logits = {\"output\": teacher_logits}\n        with self.assertRaises(ValueError):\n            distillation_loss.validate_outputs(\n                teacher_logits, incompatible_logits\n            )\n\n\n@pytest.mark.requires_trainable_backend\nclass TestFeatureDistillation(TestCase):\n    \"\"\"Test cases for the FeatureDistillation loss.\"\"\"\n\n    def test_feature_distillation_basic(self):\n        \"\"\"Test basic feature distillation structure validation.\"\"\"\n        # Create dummy features\n        teacher_features = keras.ops.convert_to_tensor(\n            np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), dtype=\"float32\"\n        )\n        student_features = keras.ops.convert_to_tensor(\n            np.array([[1.1, 2.1, 3.1], [4.1, 5.1, 6.1]]), dtype=\"float32\"\n        )\n\n        distillation_loss = FeatureDistillation(loss=\"mse\")\n        
distillation_loss.validate_outputs(teacher_features, student_features)\n        incompatible_features = [teacher_features, teacher_features]\n        with self.assertRaises(ValueError):\n            distillation_loss.validate_outputs(\n                teacher_features, incompatible_features\n            )\n\n\n@pytest.mark.requires_trainable_backend\nclass TestEndToEndDistillation(TestCase):\n    \"\"\"End-to-end distillation tests with real models.\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up models and test data for all tests.\"\"\"\n        super().setUp()\n\n        # Create teacher model\n        self.teacher = keras.Sequential(\n            [\n                keras.layers.Dense(\n                    32, activation=\"relu\", name=\"teacher_dense_1\"\n                ),\n                keras.layers.Dense(\n                    16, activation=\"relu\", name=\"teacher_dense_2\"\n                ),\n                keras.layers.Dense(10, name=\"teacher_output\"),\n            ]\n        )\n\n        # Create student model\n        self.student = keras.Sequential(\n            [\n                keras.layers.Dense(\n                    32, activation=\"relu\", name=\"student_dense_1\"\n                ),\n                keras.layers.Dense(\n                    16, activation=\"relu\", name=\"student_dense_2\"\n                ),\n                keras.layers.Dense(10, name=\"student_output\"),\n            ]\n        )\n\n        self.x = np.random.random((32, 20)).astype(np.float32)\n        self.y = np.random.randint(0, 10, (32,)).astype(np.int32)\n\n        self.teacher(self.x[:2])\n        self.student(self.x[:2])\n\n    def test_logits_distillation_end_to_end(self):\n        \"\"\"Test end-to-end logits distillation with real models.\"\"\"\n        # Create distiller\n        distiller = Distiller(\n            teacher=self.teacher,\n            student=self.student,\n            distillation_losses=LogitsDistillation(temperature=3.0),\n            
student_loss_weight=0.5,\n        )\n\n        # Compile distiller\n        distiller.compile(\n            optimizer=\"adam\",\n            loss=\"sparse_categorical_crossentropy\",\n            metrics=[\"accuracy\"],\n        )\n\n        # Test training\n        history = distiller.fit(self.x, self.y, epochs=2, verbose=0)\n\n        # Verify training completed\n        self.assertIn(\"total_loss\", history.history)\n        self.assertIn(\"student_loss\", history.history)\n        self.assertIn(\"distillation_loss\", history.history)\n\n        # Verify loss values are reasonable\n        final_loss = history.history[\"total_loss\"][-1]\n        self.assertTrue(np.isfinite(final_loss))\n        self.assertGreater(final_loss, 0.0)\n\n        # Test prediction\n        predictions = distiller.predict(self.x[:5], verbose=0)\n        self.assertEqual(predictions.shape, (5, 10))\n\n        # Test student model access\n        student_model = distiller.student\n        self.assertIsInstance(student_model, keras.Model)\n\n    def test_feature_distillation_end_to_end(self):\n        \"\"\"Test end-to-end feature distillation with real models.\"\"\"\n        # Create distiller with feature distillation\n        distiller = Distiller(\n            teacher=self.teacher,\n            student=self.student,\n            distillation_losses=FeatureDistillation(\n                loss=\"mse\",\n                teacher_layer_name=\"teacher_dense_1\",\n                student_layer_name=\"student_dense_1\",\n            ),\n            student_loss_weight=0.5,\n        )\n\n        # Compile distiller\n        distiller.compile(\n            optimizer=\"adam\",\n            loss=\"sparse_categorical_crossentropy\",\n            metrics=[\"accuracy\"],\n        )\n\n        # Test training\n        history = distiller.fit(self.x, self.y, epochs=2, verbose=0)\n\n        # Verify training completed\n        self.assertIn(\"total_loss\", history.history)\n        
self.assertIn(\"student_loss\", history.history)\n        self.assertIn(\"distillation_loss\", history.history)\n\n        # Verify feature extraction worked\n        self.assertIsNotNone(distiller._teacher_feature_extractor)\n        self.assertIsNotNone(distiller._student_feature_extractor)\n\n        # Test that feature extractors have correct outputs\n        self.assertEqual(\n            len(distiller._teacher_feature_extractor.outputs), 2\n        )  # final + dense_1\n        self.assertEqual(\n            len(distiller._student_feature_extractor.outputs), 2\n        )  # final + dense_1\n\n    def test_multi_distillation_loss_distillation_end_to_end(self):\n        \"\"\"Test end-to-end distillation with multiple distillation losses.\"\"\"\n        # Create multiple distillation losses\n        distillation_losses = [\n            LogitsDistillation(temperature=3.0),\n            FeatureDistillation(\n                loss=\"mse\",\n                teacher_layer_name=\"teacher_dense_1\",\n                student_layer_name=\"student_dense_1\",\n            ),\n            FeatureDistillation(\n                loss=\"mse\",\n                teacher_layer_name=\"teacher_dense_2\",\n                student_layer_name=\"student_dense_2\",\n            ),\n        ]\n\n        # Create distiller\n        distiller = Distiller(\n            teacher=self.teacher,\n            student=self.student,\n            distillation_losses=distillation_losses,\n            distillation_loss_weights=[1.0, 0.5, 0.3],\n            student_loss_weight=0.5,\n        )\n\n        # Compile distiller\n        distiller.compile(\n            optimizer=\"adam\",\n            loss=\"sparse_categorical_crossentropy\",\n            metrics=[\"accuracy\"],\n        )\n\n        # Test training\n        history = distiller.fit(self.x, self.y, epochs=2, verbose=0)\n\n        # Verify training completed\n        self.assertIn(\"total_loss\", history.history)\n        
self.assertIn(\"student_loss\", history.history)\n        self.assertIn(\"distillation_loss\", history.history)\n\n        # Verify efficient feature extraction\n        self.assertIsNotNone(distiller._teacher_feature_extractor)\n        self.assertIsNotNone(distiller._student_feature_extractor)\n\n        # Should have 3 outputs: final + dense_1 + dense_2\n        self.assertEqual(len(distiller._teacher_feature_extractor.outputs), 3)\n        self.assertEqual(len(distiller._student_feature_extractor.outputs), 3)\n\n        # Test that loss decreases (learning is happening)\n        initial_loss = history.history[\"total_loss\"][0]\n        final_loss = history.history[\"total_loss\"][-1]\n        self.assertTrue(np.isfinite(initial_loss))\n        self.assertTrue(np.isfinite(final_loss))\n"
  },
  {
    "path": "keras/src/distillation/distiller.py",
    "content": "import keras\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.distillation.distillation_loss import _convert_loss_to_function\nfrom keras.src.models.model import Model\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.distillation.Distiller\")\nclass Distiller(Model):\n    \"\"\"Distillation model for transferring knowledge from teacher to student.\n\n    Knowledge distillation transfers knowledge from a large, complex model\n    (teacher) to a smaller, simpler model (student). The student learns\n    from both ground truth labels and the teacher's predictions, often\n    achieving better performance than training on labels alone.\n\n    Arguments:\n        teacher: A trained `keras.Model` that serves as the knowledge source.\n            The teacher model is frozen during distillation.\n        student: A `keras.Model` to be trained through distillation.\n        distillation_losses: List of distillation losses to apply. Can be a\n            single distillation loss or a list of distillation losses like\n            `keras.distillation.LogitsDistillation`,\n            `keras.distillation.FeatureDistillation`, or custom distillation\n            losses.\n        distillation_loss_weights: List of weights for each distillation loss.\n            Must have the same length as `distillation_losses`. If `None`,\n            equal weights are used.\n        student_loss_weight: Weight for the student's supervised loss component.\n            Must be between 0 and 1. Defaults to 0.5.\n        name: Name for the distiller model. Defaults to `\"distiller\"`.\n        **kwargs: Additional keyword arguments passed to the parent `Model`\n            class.\n\n    Attributes:\n        student: The student model being trained. Access this to get the trained\n            student model for independent use after distillation training.\n        teacher: The teacher model providing knowledge. 
This model is frozen\n            during training.\n\n    Examples:\n\n    ```python\n    # Basic distillation with KerasHub models\n    import keras_hub as hub\n\n    teacher = hub.models.CausalLM.from_preset(\"gemma_2b_en\")\n    student = hub.models.CausalLM.from_preset(\n        \"gemma_1.1_2b_en\", load_weights=False\n    )\n\n    # Single distillation loss\n    distiller = Distiller(\n        teacher=teacher,\n        student=student,\n        distillation_losses=LogitsDistillation(temperature=3.0),\n    )\n\n    # Compile the distiller (like any Keras model)\n    distiller.compile(\n        optimizer='adam',\n        loss='sparse_categorical_crossentropy',\n        metrics=['accuracy']\n    )\n\n    # Train the distiller\n    distiller.fit(x_train, y_train, epochs=10)\n\n    # Access the trained student model\n    trained_student = distiller.student\n\n    # Multiple distillation losses\n    distiller = Distiller(\n        teacher=teacher,\n        student=student,\n        distillation_losses=[\n            LogitsDistillation(temperature=3.0),\n            FeatureDistillation(\n                teacher_layer_name=\"dense_1\",\n                student_layer_name=\"dense_1\"\n            )\n        ],\n        distillation_loss_weights=[1.0, 0.5],\n    )\n\n    # Compile with custom settings\n    distiller.compile(\n        optimizer='adam',\n        loss='sparse_categorical_crossentropy',\n        metrics=['accuracy']\n    )\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        teacher,\n        student,\n        distillation_losses,\n        distillation_loss_weights=None,\n        student_loss_weight=0.5,\n        name=\"distiller\",\n        **kwargs,\n    ):\n        super().__init__(name=name, **kwargs)\n\n        # Validate inputs\n        self._validate_models(teacher, student)\n\n        # Store configuration\n        self.teacher = teacher\n        self.student = student\n\n        # Validate student_loss_weight\n        if not 
isinstance(student_loss_weight, (int, float)):\n            raise ValueError(\n                f\"student_loss_weight must be a number, got \"\n                f\"{type(student_loss_weight)}\"\n            )\n        if student_loss_weight < 0.0 or student_loss_weight > 1.0:\n            raise ValueError(\n                f\"student_loss_weight must be between 0.0 and 1.0, \"\n                f\"got {student_loss_weight}\"\n            )\n        self.student_loss_weight = student_loss_weight\n\n        # Handle distillation losses configuration\n        if distillation_losses is None:\n            raise ValueError(\n                \"'distillation_losses' cannot be `None`. Provide a \"\n                \"distillation loss (e.g., LogitsDistillation or \"\n                \"FeatureDistillation) or a list of distillation losses.\"\n            )\n\n        # Convert single distillation loss to list for uniform handling\n        if not isinstance(distillation_losses, (list, tuple)):\n            self.distillation_losses = [distillation_losses]\n            self.distillation_loss_weights = [1.0]\n        else:\n            self.distillation_losses = distillation_losses\n            # Set default weights if not provided\n            if distillation_loss_weights is None:\n                self.distillation_loss_weights = [1.0] * len(\n                    distillation_losses\n                )\n            else:\n                if len(distillation_loss_weights) != len(distillation_losses):\n                    raise ValueError(\n                        f\"Number of distillation_loss_weights \"\n                        f\"({len(distillation_loss_weights)}) must match \"\n                        f\"number of distillation_losses \"\n                        f\"({len(distillation_losses)})\"\n                    )\n                self.distillation_loss_weights = distillation_loss_weights\n\n        # Validate distillation loss compatibility and create extractors\n        for 
distillation_loss in self.distillation_losses:\n            self._validate_distillation_loss_compatibility(\n                teacher, student, distillation_loss\n            )\n\n        self._create_multi_feature_extractors()\n\n        # Freeze teacher model\n        self.teacher.trainable = False\n\n        # Initialize loss tracking metrics\n        self.student_loss_tracker = keras.metrics.Mean(name=\"student_loss\")\n        self.distillation_loss_tracker = keras.metrics.Mean(\n            name=\"distillation_loss\"\n        )\n        self.total_loss_tracker = keras.metrics.Mean(name=\"total_loss\")\n\n    def _validate_models(self, teacher, student):\n        \"\"\"Validate that teacher and student models are compatible.\"\"\"\n        if not isinstance(teacher, keras.Model):\n            raise ValueError(\n                f\"Teacher must be a keras.Model, got {type(teacher)}\"\n            )\n        if not isinstance(student, keras.Model):\n            raise ValueError(\n                f\"Student must be a keras.Model, got {type(student)}\"\n            )\n\n        self._validate_input_compatibility(teacher, student)\n        self._validate_output_compatibility(teacher, student)\n        self._validate_dtype_compatibility(teacher, student)\n\n    def _assert_shapes_are_compatible(self, shape1, shape2, context):\n        \"\"\"Assert that two shapes are compatible.\"\"\"\n        if len(shape1) != len(shape2):\n            raise ValueError(\n                f\"Teacher and student {context} shapes have different \"\n                f\"dimensions. Teacher: {shape1}, Student: {shape2}.\"\n            )\n\n        for dim1, dim2 in zip(shape1, shape2):\n            if dim1 is not None and dim2 is not None and dim1 != dim2:\n                raise ValueError(\n                    f\"Teacher and student {context} shapes are incompatible. \"\n                    f\"Teacher: {shape1}, Student: {shape2}. 
\"\n                    f\"All dimensions must match.\"\n                )\n\n    def _assert_same_dtype(self, teacher_dtype, student_dtype, context):\n        \"\"\"Assert that teacher and student dtypes are the same.\"\"\"\n        if teacher_dtype != student_dtype:\n            raise ValueError(\n                f\"Teacher and student {context} dtypes must match. \"\n                f\"Teacher: {teacher_dtype}, Student: {student_dtype}.\"\n            )\n\n    def _validate_input_compatibility(self, teacher, student):\n        \"\"\"Validate that teacher and student have compatible input shapes.\"\"\"\n        if not hasattr(teacher, \"inputs\") or not hasattr(student, \"inputs\"):\n            return\n        teacher_inputs = getattr(teacher, \"inputs\")\n        student_inputs = getattr(student, \"inputs\")\n        if teacher_inputs is None or student_inputs is None:\n            return\n\n        tree.map_structure(\n            lambda ti, si: self._assert_shapes_are_compatible(\n                ti.shape, si.shape, \"input\"\n            ),\n            teacher_inputs,\n            student_inputs,\n        )\n\n    def _validate_output_compatibility(self, teacher, student):\n        \"\"\"Validate that teacher and student have compatible output shapes.\"\"\"\n        if not hasattr(teacher, \"outputs\") or not hasattr(student, \"outputs\"):\n            return\n        teacher_outputs = getattr(teacher, \"outputs\")\n        student_outputs = getattr(student, \"outputs\")\n        if teacher_outputs is None or student_outputs is None:\n            return\n\n        tree.map_structure(\n            lambda to, so: self._assert_shapes_are_compatible(\n                to.shape, so.shape, \"output\"\n            ),\n            teacher_outputs,\n            student_outputs,\n        )\n\n    def _validate_dtype_compatibility(self, teacher, student):\n        \"\"\"Validate that teacher and student have compatible data types.\"\"\"\n        if not hasattr(teacher, 
\"inputs\") or not hasattr(student, \"inputs\"):\n            return\n        if teacher.inputs is None or student.inputs is None:\n            return\n\n        tree.map_structure(\n            lambda ti, si: self._assert_same_dtype(ti.dtype, si.dtype, \"input\"),\n            teacher.inputs,\n            student.inputs,\n        )\n\n        if not hasattr(teacher, \"outputs\") or not hasattr(student, \"outputs\"):\n            return\n        if teacher.outputs is None or student.outputs is None:\n            return\n\n        tree.map_structure(\n            lambda to, so: self._assert_same_dtype(\n                to.dtype, so.dtype, \"output\"\n            ),\n            teacher.outputs,\n            student.outputs,\n        )\n\n    def _validate_distillation_loss_compatibility(\n        self, teacher, student, distillation_loss\n    ):\n        \"\"\"Validate that the distillation loss is compatible with teacher\n        and student models.\"\"\"\n        distillation_loss.validate_model_compatibility(teacher, student)\n\n    def _create_multi_feature_extractors(self):\n        \"\"\"Create feature extractors for efficient multi-layer extraction.\"\"\"\n        teacher_layer_names = []\n        student_layer_names = []\n\n        for distillation_loss in self.distillation_losses:\n            if (\n                hasattr(distillation_loss, \"teacher_layer_name\")\n                and distillation_loss.teacher_layer_name\n            ):\n                if (\n                    distillation_loss.teacher_layer_name\n                    not in teacher_layer_names\n                ):\n                    teacher_layer_names.append(\n                        distillation_loss.teacher_layer_name\n                    )\n            if (\n                hasattr(distillation_loss, \"student_layer_name\")\n                and distillation_loss.student_layer_name\n            ):\n                if (\n                    distillation_loss.student_layer_name\n       
             not in student_layer_names\n                ):\n                    student_layer_names.append(\n                        distillation_loss.student_layer_name\n                    )\n\n        self._teacher_feature_extractor = self._create_feature_extractor(\n            self.teacher, teacher_layer_names\n        )\n        self._student_feature_extractor = self._create_feature_extractor(\n            self.student, student_layer_names\n        )\n\n    def _create_feature_extractor(self, model, layer_names):\n        \"\"\"Create a feature extractor for a model.\n\n        Arguments:\n            model: The model to create an extractor for.\n            layer_names: List of layer names to extract features from.\n\n        Returns:\n            Feature extractor model or `None` if no layer names provided.\n\n        Raises:\n            ValueError: If model has no symbolic inputs/outputs.\n        \"\"\"\n        if not layer_names:\n            return None\n\n        if not hasattr(model, \"inputs\") or model.inputs is None:\n            raise ValueError(\n                f\"Cannot create feature extractor for {model.name}. 
\"\n                f\"The model has no symbolic inputs attribute.\"\n            )\n\n        if isinstance(model, keras.Sequential):\n            final_output = model.layers[-1].output\n        else:\n            final_output = model.output\n\n        outputs = {\"final_output\": final_output}\n        for layer_name in layer_names:\n            layer = model.get_layer(name=layer_name)\n            outputs[layer_name] = layer.output\n\n        return keras.Model(\n            inputs=model.inputs,\n            outputs=outputs,\n            name=f\"{model.name}_multi_feature_extractor\",\n        )\n\n    def _extract_all_teacher_features(self, x):\n        \"\"\"Extract all teacher features in a single forward pass.\"\"\"\n        if self._teacher_feature_extractor is not None:\n            return self._teacher_feature_extractor(x, training=False)\n        else:\n            return {\"final_output\": self.teacher(x, training=False)}\n\n    def _extract_all_student_features(self, x, y_pred):\n        \"\"\"Extract all student features in a single forward pass.\"\"\"\n        if self._student_feature_extractor is not None:\n            return self._student_feature_extractor(x, training=True)\n        else:\n            return {\"final_output\": y_pred}\n\n    def _get_distillation_loss_features(\n        self, distillation_loss, all_features, is_teacher\n    ):\n        \"\"\"Get the specific features needed by a distillation loss.\"\"\"\n        if is_teacher:\n            layer_name = distillation_loss.teacher_layer_name or \"final_output\"\n        else:\n            layer_name = distillation_loss.student_layer_name or \"final_output\"\n\n        if layer_name not in all_features:\n            raise ValueError(\n                f\"Layer '{layer_name}' not found in extracted features. 
\"\n                f\"Available: {list(all_features.keys())}\"\n            )\n\n        return all_features[layer_name]\n\n    def compile(self, optimizer=\"adam\", loss=None, metrics=None, **kwargs):\n        \"\"\"Compile the distiller with proper integration.\n\n        Arguments:\n            optimizer: Optimizer for training the student model.\n            loss: Student loss function for the student's supervised learning.\n                Can be a string identifier or a loss function instance.\n            metrics: Additional metrics to track during training.\n            **kwargs: Additional arguments passed to parent compile.\n        \"\"\"\n        if loss is None:\n            raise ValueError(\"'loss' cannot be `None`.\")\n\n        self._student_loss = tree.map_structure(_convert_loss_to_function, loss)\n        self._student_loss_for_serialization = loss\n\n        if metrics is not None and not isinstance(metrics, (list, tuple)):\n            raise ValueError(\n                f\"metrics must be a list or tuple, got {type(metrics)}\"\n            )\n\n        super().compile(\n            optimizer=optimizer,\n            loss=None,\n            metrics=metrics,\n            **kwargs,\n        )\n\n    def call(self, inputs, training=None, **kwargs):\n        \"\"\"Forward pass returns student predictions.\"\"\"\n        return self.student(inputs, training=training, **kwargs)\n\n    def compute_loss(\n        self, x=None, y=None, y_pred=None, sample_weight=None, training=True\n    ):\n        \"\"\"Compute combined distillation loss.\n\n        Arguments:\n            x: Input data.\n            y: Target data.\n            y_pred: Model predictions.\n            sample_weight: Sample weights (currently unused).\n            training: Whether the model is in training mode.\n\n        Returns:\n            Combined loss tensor.\n        \"\"\"\n        # Handle case where y_pred is not provided\n        if y_pred is None:\n            y_pred = 
self(x, training=training)\n        # Compute student loss\n        student_loss = 0.0\n        if self.student_loss_weight > 0.0 and y is not None:\n            loss_values = tree.map_structure(\n                lambda l, o, o_pred: l(o, o_pred),\n                self._student_loss,\n                y,\n                y_pred,\n            )\n            flat_losses = tree.flatten(loss_values)\n            student_loss = (\n                keras.ops.sum(keras.ops.stack(flat_losses))\n                if len(flat_losses) > 1\n                else flat_losses[0]\n            )\n\n            # Ensure student_loss is a scalar\n            if hasattr(student_loss, \"shape\") and len(student_loss.shape) > 0:\n                student_loss = keras.ops.mean(student_loss)\n\n        # Compute distillation loss\n        distillation_loss = 0.0\n        if self.student_loss_weight < 1.0:\n            teacher_features = self._extract_all_teacher_features(x)\n            student_features = self._extract_all_student_features(x, y_pred)\n\n            # Apply distillation losses using pre-extracted features\n            for distillation_loss_fn, weight in zip(\n                self.distillation_losses, self.distillation_loss_weights\n            ):\n                # Get appropriate outputs/features for this distillation loss\n                if (\n                    hasattr(distillation_loss_fn, \"teacher_layer_name\")\n                    and distillation_loss_fn.teacher_layer_name is not None\n                ):\n                    # FeatureDistillation with specific layers\n                    try:\n                        distillation_loss_teacher_output = (\n                            self._get_distillation_loss_features(\n                                distillation_loss_fn,\n                                teacher_features,\n                                is_teacher=True,\n                            )\n                        )\n                        
distillation_loss_student_output = (\n                            self._get_distillation_loss_features(\n                                distillation_loss_fn,\n                                student_features,\n                                is_teacher=False,\n                            )\n                        )\n                    except ValueError as e:\n                        # Re-raise with context about which loss failed\n                        raise RuntimeError(\n                            f\"Failed to extract features for \"\n                            f\"{type(distillation_loss_fn).__name__} \"\n                            f\"targeting teacher layer \"\n                            f\"'{distillation_loss_fn.teacher_layer_name}' \"\n                            f\"and student layer \"\n                            f\"'{distillation_loss_fn.student_layer_name}'. \"\n                            f\"Original error: {e}\"\n                        ) from e\n                else:\n                    # LogitsDistillation or FeatureDistillation (final outputs)\n                    distillation_loss_teacher_output = teacher_features[\n                        \"final_output\"\n                    ]\n                    distillation_loss_student_output = y_pred\n\n                # Validate outputs are compatible for this distillation loss\n                distillation_loss_fn.validate_outputs(\n                    distillation_loss_teacher_output,\n                    distillation_loss_student_output,\n                )\n\n                # Compute loss for this distillation loss\n                current_distillation_loss = distillation_loss_fn.compute_loss(\n                    distillation_loss_teacher_output,\n                    distillation_loss_student_output,\n                )\n\n                # Validate that distillation loss returns a scalar\n                if (\n                    hasattr(current_distillation_loss, \"shape\")\n                   
 and len(current_distillation_loss.shape) > 0\n                ):\n                    raise ValueError(\n                        f\"Distillation loss \"\n                        f\"{distillation_loss_fn.__class__.__name__} \"\n                        f\"returned a non-scalar loss with shape \"\n                        f\"{current_distillation_loss.shape}. \"\n                        f\"The compute_loss method must return a scalar \"\n                        f\"tensor.\"\n                    )\n\n                # Apply weight and add to total\n                distillation_loss = keras.ops.add(\n                    distillation_loss,\n                    keras.ops.multiply(weight, current_distillation_loss),\n                )\n\n        # Combine losses\n        total_loss = keras.ops.add(\n            keras.ops.multiply(self.student_loss_weight, student_loss),\n            keras.ops.multiply(\n                keras.ops.subtract(1.0, self.student_loss_weight),\n                distillation_loss,\n            ),\n        )\n\n        # Update metrics\n        self.student_loss_tracker.update_state(student_loss)\n        self.distillation_loss_tracker.update_state(distillation_loss)\n        self.total_loss_tracker.update_state(total_loss)\n\n        return total_loss\n\n    def reset_metrics(self):\n        \"\"\"Reset all metrics.\"\"\"\n        super().reset_metrics()\n        self.student_loss_tracker.reset_state()\n        self.distillation_loss_tracker.reset_state()\n        self.total_loss_tracker.reset_state()\n\n    def get_config(self):\n        \"\"\"Get configuration for serialization.\"\"\"\n        config = super().get_config()\n        config.update(\n            {\n                \"teacher\": serialization_lib.serialize_keras_object(\n                    self.teacher\n                ),\n                \"student\": serialization_lib.serialize_keras_object(\n                    self.student\n                ),\n                
\"distillation_losses\": [\n                    serialization_lib.serialize_keras_object(distillation_loss)\n                    for distillation_loss in self.distillation_losses\n                ],\n                \"distillation_loss_weights\": self.distillation_loss_weights,\n                \"student_loss_weight\": self.student_loss_weight,\n            }\n        )\n        return config\n\n    @classmethod\n    def from_config(cls, config):\n        \"\"\"Create instance from configuration.\"\"\"\n        config = config.copy()\n\n        # Deserialize objects\n        config[\"teacher\"] = serialization_lib.deserialize_keras_object(\n            config[\"teacher\"]\n        )\n        config[\"student\"] = serialization_lib.deserialize_keras_object(\n            config[\"student\"]\n        )\n        config[\"distillation_losses\"] = [\n            serialization_lib.deserialize_keras_object(distillation_loss)\n            for distillation_loss in config[\"distillation_losses\"]\n        ]\n\n        return cls(**config)\n"
  },
  {
    "path": "keras/src/distillation/distiller_test.py",
    "content": "import json\nimport os\n\nimport numpy as np\nimport pytest\n\nimport keras\nfrom keras.src.distillation.distillation_loss import LogitsDistillation\nfrom keras.src.distillation.distiller import Distiller\nfrom keras.src.testing import TestCase\n\n\nclass SimpleTeacher(keras.Model):\n    \"\"\"Simple teacher model for testing.\"\"\"\n\n    def __init__(self, vocab_size=10, hidden_dim=32):\n        super().__init__()\n        self.dense1 = keras.layers.Dense(hidden_dim, activation=\"relu\")\n        self.dense2 = keras.layers.Dense(vocab_size)\n\n    def call(self, inputs, training=None):\n        x = self.dense1(inputs)\n        return self.dense2(x)\n\n\nclass SimpleStudent(keras.Model):\n    \"\"\"Simple student model for testing.\"\"\"\n\n    def __init__(self, vocab_size=10, hidden_dim=16):\n        super().__init__()\n        self.dense1 = keras.layers.Dense(hidden_dim, activation=\"relu\")\n        self.dense2 = keras.layers.Dense(vocab_size)\n\n    def call(self, inputs, training=None):\n        x = self.dense1(inputs)\n        return self.dense2(x)\n\n\n@pytest.mark.requires_trainable_backend\nclass TestDistiller(TestCase):\n    \"\"\"Essential test cases for the Distiller class.\"\"\"\n\n    def setUp(self):\n        \"\"\"Set up test fixtures.\"\"\"\n        super().setUp()\n\n        # Create test data\n        self.x = np.random.random((20, 5)).astype(np.float32)\n        self.y = np.random.randint(0, 10, (20,)).astype(np.int32)\n\n        # Create teacher and student models\n        self.teacher = SimpleTeacher(vocab_size=10, hidden_dim=32)\n        self.student = SimpleStudent(vocab_size=10, hidden_dim=16)\n\n        # Build models\n        dummy_input = self.x[:2]\n        self.teacher(dummy_input)\n        self.student(dummy_input)\n\n        # Create distillation distillation_loss\n        self.distillation_loss = LogitsDistillation(temperature=2.0)\n\n        # Create distiller\n        self.distiller = Distiller(\n            
teacher=self.teacher,\n            student=self.student,\n            distillation_losses=self.distillation_loss,\n            student_loss_weight=0.5,\n        )\n\n        # Compile distiller\n        self.distiller.compile(\n            optimizer=\"adam\",\n            loss=\"sparse_categorical_crossentropy\",\n            metrics=[\"accuracy\"],\n        )\n\n    def test_distiller_initialization(self):\n        \"\"\"Test Distiller initialization.\"\"\"\n        # Check that teacher is frozen\n        self.assertFalse(self.teacher.trainable)\n\n        # Check that student is trainable\n        self.assertTrue(self.student.trainable)\n\n        # Check student_loss_weight\n        self.assertEqual(self.distiller.student_loss_weight, 0.5)\n\n        # Check distillation_loss (should be a list with one distillation_loss)\n        self.assertIsInstance(self.distiller.distillation_losses, list)\n        self.assertEqual(len(self.distiller.distillation_losses), 1)\n        self.assertIsInstance(\n            self.distiller.distillation_losses[0], LogitsDistillation\n        )\n\n        # Check that distillation_loss has the correct temperature\n        self.assertEqual(self.distiller.distillation_losses[0].temperature, 2.0)\n\n        # Check that model is compiled\n        self.assertIsNotNone(self.distiller.optimizer)\n        # Check if the model has been compiled (different backends may handle\n        # this differently)\n        self.assertTrue(\n            hasattr(self.distiller, \"_compile_config\")\n            or hasattr(self.distiller, \"compiled_loss\"),\n            \"Model should be compiled\",\n        )\n\n    def test_distiller_call(self):\n        \"\"\"Test Distiller call method (inference).\"\"\"\n        # Call should return student outputs\n        outputs = self.distiller(self.x)\n\n        # Check output shape\n        expected_shape = (20, 10)  # batch_size, vocab_size\n        self.assertEqual(outputs.shape, expected_shape)\n\n        # 
Check that outputs are from student, not teacher\n        student_outputs = self.student(self.x)\n        self.assertAllClose(outputs, student_outputs)\n\n    def test_teacher_freezing(self):\n        \"\"\"Test that teacher is properly frozen.\"\"\"\n        # Teacher should be frozen\n        self.assertFalse(self.teacher.trainable)\n\n        # Student should be trainable\n        self.assertTrue(self.student.trainable)\n\n        # Create a new teacher that is trainable and verify it gets frozen\n        new_teacher = SimpleTeacher(vocab_size=10, hidden_dim=32)\n        self.assertTrue(new_teacher.trainable)  # Should be trainable initially\n\n        # Create distiller - should freeze the teacher\n        Distiller(\n            teacher=new_teacher,\n            student=self.student,\n            distillation_losses=self.distillation_loss,\n            student_loss_weight=0.5,\n        )\n\n        # Teacher should now be frozen\n        self.assertFalse(new_teacher.trainable)\n\n    def test_model_compatibility_validation(self):\n        \"\"\"Test model compatibility validation.\"\"\"\n        # Test with non-Keras objects\n        with self.assertRaises(ValueError):\n            Distiller(\n                teacher=\"not_a_model\",\n                student=self.student,\n                distillation_losses=self.distillation_loss,\n            )\n\n        with self.assertRaises(ValueError):\n            Distiller(\n                teacher=self.teacher,\n                student=\"not_a_model\",\n                distillation_losses=self.distillation_loss,\n            )\n\n    def test_multi_distillation_loss_functionality(self):\n        \"\"\"Test functionality with multiple distillation losses.\"\"\"\n        # Create multiple distillation losses\n        distillation_loss = [\n            LogitsDistillation(temperature=3.0),\n            LogitsDistillation(temperature=2.0),\n        ]\n        distillation_loss_weights = [0.7, 0.3]\n\n        # Create distiller with multiple distillation losses\n        distiller = Distiller(\n            teacher=self.teacher,\n            student=self.student,\n            distillation_losses=distillation_loss,\n            distillation_loss_weights=distillation_loss_weights,\n            student_loss_weight=0.5,\n        )\n\n        # Compile distiller\n        distiller.compile(\n            optimizer=\"adam\",\n            loss=\"sparse_categorical_crossentropy\",\n            metrics=[\"accuracy\"],\n        )\n\n        # Test that distillation losses are stored correctly\n        self.assertEqual(len(distiller.distillation_losses), 2)\n        self.assertEqual(distiller.distillation_loss_weights, [0.7, 0.3])\n\n        # Test training\n        x = np.random.random((10, 5)).astype(np.float32)\n        y = np.random.randint(0, 10, (10,))\n        history = distiller.fit(x, y, epochs=1, verbose=0)\n\n        # Check metrics\n        self.assertIn(\"total_loss\", history.history)\n        self.assertIn(\"student_loss\", history.history)\n        self.assertIn(\"distillation_loss\", history.history)\n\n    def test_multi_distillation_loss_validation(self):\n        \"\"\"Test validation with multiple distillation losses.\"\"\"\n        distillation_loss = [\n            LogitsDistillation(temperature=3.0),\n            LogitsDistillation(temperature=2.0),\n        ]\n\n        # Test that validation passes for valid configurations\n        distiller = Distiller(\n            teacher=self.teacher,\n            student=self.student,\n            distillation_losses=distillation_loss,\n            student_loss_weight=0.5,\n        )\n\n        self.assertEqual(len(distiller.distillation_losses), 2)\n\n        # Test invalid distillation_loss_weights length\n        with self.assertRaises(ValueError):\n            Distiller(\n                teacher=self.teacher,\n                student=self.student,\n                distillation_losses=distillation_loss,\n                distillation_loss_weights=[1.0],  # 
Wrong length\n                student_loss_weight=0.5,\n            )\n\n    def test_student_loss_weighting(self):\n        \"\"\"Test student loss weighting functionality.\"\"\"\n        # Test with student_loss_weight = 0.0 (only distillation loss)\n        distiller_0 = Distiller(\n            teacher=self.teacher,\n            student=self.student,\n            distillation_losses=self.distillation_loss,\n            student_loss_weight=0.0,\n        )\n\n        # Test with student_loss_weight = 1.0 (only student loss)\n        distiller_1 = Distiller(\n            teacher=self.teacher,\n            student=self.student,\n            distillation_losses=self.distillation_loss,\n            student_loss_weight=1.0,\n        )\n\n        # Compile both distillers\n        distiller_0.compile(\n            optimizer=\"adam\",\n            loss=\"sparse_categorical_crossentropy\",\n            metrics=[\"accuracy\"],\n        )\n        distiller_1.compile(\n            optimizer=\"adam\",\n            loss=\"sparse_categorical_crossentropy\",\n            metrics=[\"accuracy\"],\n        )\n\n        # Test that they can be used for training without errors\n        small_x = self.x[:5]\n        small_y = self.y[:5]\n\n        # Both should train without errors\n        history_0 = distiller_0.fit(small_x, small_y, epochs=1, verbose=0)\n        history_1 = distiller_1.fit(small_x, small_y, epochs=1, verbose=0)\n\n        # Check that training completed\n        self.assertIn(\"total_loss\", history_0.history)\n        self.assertIn(\"total_loss\", history_1.history)\n\n    def test_full_training_workflow(self):\n        \"\"\"Test complete training workflow with model.fit() - MOST IMPORTANT.\"\"\"\n        # Create larger dataset for training\n        np.random.seed(42)\n        x_train = np.random.random((100, 5)).astype(np.float32)\n        y_train = np.random.randint(0, 10, (100,)).astype(np.int32)\n        x_val = np.random.random((20, 
5)).astype(np.float32)\n        y_val = np.random.randint(0, 10, (20,)).astype(np.int32)\n\n        # Create fresh models for training\n        teacher = SimpleTeacher(vocab_size=10, hidden_dim=32)\n        student = SimpleStudent(vocab_size=10, hidden_dim=16)\n\n        # Build models to avoid JAX tracer issues\n        dummy_input = x_train[:2]\n        teacher(dummy_input)\n        student(dummy_input)\n\n        # Create distiller\n        distiller = Distiller(\n            teacher=teacher,\n            student=student,\n            distillation_losses=self.distillation_loss,\n            student_loss_weight=0.5,\n        )\n\n        # Compile distiller\n        distiller.compile(\n            optimizer=\"adam\",\n            loss=\"sparse_categorical_crossentropy\",\n            metrics=[\"accuracy\"],\n        )\n\n        # Train the model\n        history = distiller.fit(\n            x_train,\n            y_train,\n            validation_data=(x_val, y_val),\n            epochs=3,\n            batch_size=16,\n            verbose=0,\n        )\n\n        # Check that training completed\n        self.assertIn(\"total_loss\", history.history)\n        self.assertIn(\"val_total_loss\", history.history)\n        self.assertIn(\"student_loss\", history.history)\n        self.assertIn(\"distillation_loss\", history.history)\n\n        # Check that losses are finite\n        for loss_name in [\"total_loss\", \"student_loss\", \"distillation_loss\"]:\n            losses = history.history[loss_name]\n            self.assertGreater(len(losses), 0)\n            for loss in losses:\n                self.assertTrue(np.isfinite(loss))\n\n        # Check that the model can make predictions\n        predictions = distiller.predict(x_val[:5], verbose=0)\n        self.assertEqual(predictions.shape, (5, 10))  # batch_size, vocab_size\n\n        # Check that student weights have changed (indicating learning)\n        initial_weights = [w.numpy().copy() for w in 
student.trainable_weights]\n\n        # Train a bit more\n        distiller.fit(x_train[:10], y_train[:10], epochs=1, verbose=0)\n\n        final_weights = [w.numpy() for w in student.trainable_weights]\n\n        # At least some weights should have changed\n        weights_changed = any(\n            not np.allclose(initial, final, atol=1e-6)\n            for initial, final in zip(initial_weights, final_weights)\n        )\n        self.assertTrue(\n            weights_changed, \"Student weights should change during training\"\n        )\n\n    def test_evaluation_workflow(self):\n        \"\"\"Test evaluation workflow with model.evaluate().\"\"\"\n        # Create dataset\n        np.random.seed(42)\n        x_test = np.random.random((30, 5)).astype(np.float32)\n        y_test = np.random.randint(0, 10, (30,)).astype(np.int32)\n\n        # Create fresh models\n        teacher = SimpleTeacher(vocab_size=10, hidden_dim=32)\n        student = SimpleStudent(vocab_size=10, hidden_dim=16)\n\n        # Build models to avoid JAX tracer issues\n        dummy_input = x_test[:2]\n        teacher(dummy_input)\n        student(dummy_input)\n\n        # Create distiller\n        distiller = Distiller(\n            teacher=teacher,\n            student=student,\n            distillation_losses=self.distillation_loss,\n            student_loss_weight=0.5,\n        )\n\n        # Compile distiller\n        distiller.compile(\n            optimizer=\"adam\",\n            loss=\"sparse_categorical_crossentropy\",\n            metrics=[\"accuracy\"],\n        )\n\n        # Train briefly\n        distiller.fit(x_test[:10], y_test[:10], epochs=1, verbose=0)\n\n        # Evaluate the model\n        results = distiller.evaluate(x_test, y_test, verbose=0)\n\n        # Check that evaluation returns expected metrics\n        self.assertIsInstance(results, list)\n        self.assertGreater(len(results), 0)\n\n        # All results should be finite\n        for result in results:\n          
  self.assertTrue(np.isfinite(result))\n\n    def test_prediction_workflow(self):\n        \"\"\"Test prediction workflow with model.predict().\"\"\"\n        # Create dataset\n        np.random.seed(42)\n        x_test = np.random.random((20, 5)).astype(np.float32)\n\n        # Create fresh models\n        teacher = SimpleTeacher(vocab_size=10, hidden_dim=32)\n        student = SimpleStudent(vocab_size=10, hidden_dim=16)\n\n        # Build models to avoid JAX tracer issues\n        dummy_input = x_test[:2]\n        teacher(dummy_input)\n        student(dummy_input)\n\n        # Create distiller\n        distiller = Distiller(\n            teacher=teacher,\n            student=student,\n            distillation_losses=self.distillation_loss,\n            student_loss_weight=0.5,\n        )\n\n        # Make predictions\n        predictions = distiller.predict(x_test, verbose=0)\n\n        # Check prediction shape\n        self.assertEqual(predictions.shape, (20, 10))  # batch_size, vocab_size\n\n        # Check that predictions are finite\n        self.assertTrue(np.all(np.isfinite(predictions)))\n\n        # Check predictions sum to reasonable values (not zeros/infinities)\n        prediction_sums = np.sum(predictions, axis=1)\n        self.assertTrue(np.all(np.isfinite(prediction_sums)))\n\n    def test_distiller_serialization_and_saving(self):\n        \"\"\"Test Distiller serialization, saving, and loading.\"\"\"\n\n        # Use standard Sequential models for serialization testing\n        teacher = keras.Sequential(\n            [\n                keras.layers.Dense(\n                    32, activation=\"relu\", name=\"teacher_dense_1\"\n                ),\n                keras.layers.Dense(\n                    16, activation=\"relu\", name=\"teacher_dense_2\"\n                ),\n                keras.layers.Dense(10, name=\"teacher_output\"),\n            ]\n        )\n\n        student = keras.Sequential(\n            [\n                
keras.layers.Dense(\n                    16, activation=\"relu\", name=\"student_dense_1\"\n                ),\n                keras.layers.Dense(\n                    8, activation=\"relu\", name=\"student_dense_2\"\n                ),\n                keras.layers.Dense(10, name=\"student_output\"),\n            ]\n        )\n\n        # Create distiller with single distillation_loss\n        distillation_loss = LogitsDistillation(\n            temperature=3.0, loss=\"kl_divergence\"\n        )\n\n        original_distiller = Distiller(\n            teacher=teacher,\n            student=student,\n            distillation_losses=distillation_loss,\n            student_loss_weight=0.7,\n        )\n\n        # Build the models by calling them\n        x_test = np.random.random((2, 20)).astype(np.float32)\n        _ = original_distiller(x_test)\n\n        # Test get_config\n        config = original_distiller.get_config()\n\n        # Verify all components are in config\n        required_keys = [\n            \"teacher\",\n            \"student\",\n            \"distillation_losses\",\n            \"distillation_loss_weights\",\n            \"student_loss_weight\",\n        ]\n        for key in required_keys:\n            self.assertIn(key, config, f\"Missing key: {key}\")\n\n        # Test JSON serialization\n        json_str = json.dumps(config)\n        self.assertIsInstance(json_str, str)\n\n        # Test from_config reconstruction\n        reconstructed_distiller = Distiller.from_config(config)\n\n        # Verify reconstruction\n        self.assertEqual(reconstructed_distiller.student_loss_weight, 0.7)\n        self.assertIsInstance(\n            reconstructed_distiller.distillation_losses[0], LogitsDistillation\n        )\n\n        # Verify distillation_loss parameters\n        self.assertEqual(\n            reconstructed_distiller.distillation_losses[0].temperature, 3.0\n        )\n\n        # Test that reconstructed distiller can be used for inference\n  
      reconstructed_output = reconstructed_distiller(x_test)\n        self.assertEqual(reconstructed_output.shape, (2, 10))\n\n        # Test model saving and loading (full integration test)\n        temp_dir = self.get_temp_dir()\n        model_path = os.path.join(temp_dir, \"distiller_model.keras\")\n\n        # Compile original distiller\n        original_distiller.compile(\n            loss=\"sparse_categorical_crossentropy\",\n        )\n\n        # Save the model\n        original_distiller.save(model_path)\n\n        # Load the model\n        loaded_distiller = keras.models.load_model(model_path)\n\n        # Verify loaded model works\n        loaded_output = loaded_distiller(x_test)\n        self.assertEqual(loaded_output.shape, (2, 10))\n\n        # Verify parameters are preserved\n        self.assertEqual(loaded_distiller.student_loss_weight, 0.7)\n"
  },
  {
    "path": "keras/src/distribution/__init__.py",
    "content": "from keras.src.distribution.distribution_lib import DataParallel\nfrom keras.src.distribution.distribution_lib import DeviceMesh\nfrom keras.src.distribution.distribution_lib import Distribution\nfrom keras.src.distribution.distribution_lib import LayoutMap\nfrom keras.src.distribution.distribution_lib import ModelParallel\nfrom keras.src.distribution.distribution_lib import TensorLayout\nfrom keras.src.distribution.distribution_lib import distribute_tensor\nfrom keras.src.distribution.distribution_lib import distribution\nfrom keras.src.distribution.distribution_lib import initialize\nfrom keras.src.distribution.distribution_lib import list_devices\nfrom keras.src.distribution.distribution_lib import set_distribution\n"
  },
  {
    "path": "keras/src/distribution/distribution_lib.py",
    "content": "\"\"\"Unified high-level distribution APIs across backends.\"\"\"\n\nimport collections\nimport contextlib\nimport os\nimport re\nimport warnings\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend import distribution_lib\nfrom keras.src.backend.common import global_state\n\nDEFAULT_BATCH_DIM_NAME = \"batch\"\nGLOBAL_ATTRIBUTE_NAME = \"distribution\"\n\n\n@keras_export(\"keras.distribution.list_devices\")\ndef list_devices(device_type=None):\n    \"\"\"Return all the available devices based on the device type.\n\n    Note: in a distributed setting, global devices are returned.\n\n    When `device_type` is not provided, devices of the default type are\n    returned. This function nevers return a mix of device types, for instance\n    GPUs and CPUs.\n\n    Args:\n        device_type: string, one of `\"cpu\"`, `\"gpu\"` or `\"tpu\"`. Defaults to\n            `\"gpu\"` or `\"tpu\"` if available when `device_type` is not provided.\n            Otherwise returns the `\"cpu\"` devices.\n\n    Return:\n        List of string, the devices that are available for distributed\n        computation. Each device is formatted as \"device_type:id\", for instance\n        \"gpu:1\" or \"cpu:0\".\n\n    \"\"\"\n    return distribution_lib.list_devices(device_type)\n\n\n@keras_export(\"keras.distribution.get_device_count\")\ndef get_device_count(device_type=None):\n    \"\"\"Returns the number of available devices based on the device type.\n\n    When `device_type` is not provided, the count of devices of the default type\n    is returned. This function nevers counts a mix of device types, for instance\n    GPUs and CPUs.\n\n    Args:\n        device_type: string, one of `\"cpu\"`, `\"gpu\"` or `\"tpu\"`. 
Defaults to\n            `\"gpu\"` or `\"tpu\"` if available when `device_type` is not provided.\n            Otherwise returns the `\"cpu\"` devices.\n\n    Returns:\n        int: The total number of devices for the specified type.\n    \"\"\"\n    return distribution_lib.get_device_count(device_type=device_type)\n\n\n@keras_export(\"keras.distribution.initialize\")\ndef initialize(job_addresses=None, num_processes=None, process_id=None):\n    \"\"\"Initialize the distribution system for a multi-host/process setting.\n\n    Calling `initialize` will prepare the backend for execution on multi-host\n    GPUs or TPUs. It should be called before any computations.\n\n    Note that the parameters can also be injected via environment variables,\n    which can be better controlled by the launch script at startup time.\n    For backends that also rely on environment variables for\n    configuration, Keras will properly forward them.\n\n    Args:\n        job_addresses: string. Comma-separated IP addresses for all the jobs\n            that will form the whole computation cluster. Note that for the JAX\n            backend, only the address for job 0 (coordinator) is needed. For\n            certain runtimes like cloud TPU, this value can be `None`, and the\n            backend will figure it out with the TPU environment variables. You\n            can also configure this value via the environment variable\n            `KERAS_DISTRIBUTION_JOB_ADDRESSES`.\n        num_processes: int. The number of workers/processes that will form the\n            whole computation cluster. For certain runtimes like cloud TPU, this\n            value can be `None`, and the backend will figure it out with the TPU\n            environment variables. You can also configure this value via the\n            environment variable `KERAS_DISTRIBUTION_NUM_PROCESSES`.\n        process_id: int. The ID number of the current worker/process. The value\n            must be in the range `0` to `num_processes - 1`. 
`0` indicates that\n            the current worker/process is the master/coordinator job. You can\n            also configure this value via the environment variable\n            `KERAS_DISTRIBUTION_PROCESS_ID`.\n\n    Example:\n        Suppose there are two GPU processes, and process 0 is running at\n        address `10.0.0.1:1234`, and process 1 is running at address\n        `10.0.0.2:2345`. To configure such a cluster, you can run\n\n        On process 0:\n        ```python\n        keras.distribution.initialize(\n            job_addresses=\"10.0.0.1:1234,10.0.0.2:2345\",\n            num_processes=2,\n            process_id=0)\n        ```\n\n        On process 1:\n        ```python\n        keras.distribution.initialize(\n            job_addresses=\"10.0.0.1:1234,10.0.0.2:2345\",\n            num_processes=2,\n            process_id=1)\n        ```\n\n        or via the environment variables:\n        On process 0:\n        ```python\n        os.environ[\n            \"KERAS_DISTRIBUTION_JOB_ADDRESSES\"] = \"10.0.0.1:1234,10.0.0.2:2345\"\n        os.environ[\"KERAS_DISTRIBUTION_NUM_PROCESSES\"] = \"2\"\n        os.environ[\"KERAS_DISTRIBUTION_PROCESS_ID\"] = \"0\"\n        keras.distribution.initialize()\n        ```\n\n        On process 1:\n        ```python\n        os.environ[\n            \"KERAS_DISTRIBUTION_JOB_ADDRESSES\"] = \"10.0.0.1:1234,10.0.0.2:2345\"\n        os.environ[\"KERAS_DISTRIBUTION_NUM_PROCESSES\"] = \"2\"\n        os.environ[\"KERAS_DISTRIBUTION_PROCESS_ID\"] = \"1\"\n        keras.distribution.initialize()\n        ```\n\n        Also note that for the JAX backend, the `job_addresses` can be further\n        reduced to just the master/coordinator address, which is\n        `10.0.0.1:1234`.\n    \"\"\"\n    if (\n        job_addresses is None\n        and \"KERAS_DISTRIBUTION_JOB_ADDRESSES\" in os.environ\n    ):\n        job_addresses = os.environ[\"KERAS_DISTRIBUTION_JOB_ADDRESSES\"]\n    if (\n        num_processes is None\n        and 
\"KERAS_DISTRIBUTION_NUM_PROCESSES\" in os.environ\n    ):\n        num_processes = int(os.environ[\"KERAS_DISTRIBUTION_NUM_PROCESSES\"])\n    if process_id is None and \"KERAS_DISTRIBUTION_PROCESS_ID\" in os.environ:\n        process_id = int(os.environ[\"KERAS_DISTRIBUTION_PROCESS_ID\"])\n    distribution_lib.initialize(job_addresses, num_processes, process_id)\n\n\n@keras_export(\"keras.distribution.DeviceMesh\")\nclass DeviceMesh:\n    \"\"\"A cluster of computation devices for distributed computation.\n\n    This API is aligned with `jax.sharding.Mesh`, which represents the\n    computation devices in the global context.\n\n    See more details in [jax.sharding.Mesh](\n        https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh).\n\n    Args:\n        shape: tuple or list of integers. The shape of the overall\n            `DeviceMesh`, e.g. `(8,)` for a data parallel only distribution,\n            or `(4, 2)` for a model+data parallel distribution.\n        axis_names: List of strings. The logical name of each axis for\n            the `DeviceMesh`. The length of the `axis_names` should match\n            the rank of the `shape`. The `axis_names` will be used to\n            match/create the `TensorLayout` when distributing the data and\n            variables.\n        devices: Optional list of devices. Defaults to all the available\n            devices locally from `keras.distribution.list_devices()`.\n    \"\"\"\n\n    def __init__(\n        self,\n        shape,\n        axis_names,\n        devices=None,\n    ):\n        if not shape or not axis_names:\n            raise ValueError(\n                \"Shape and axis_names cannot be empty. Received: \"\n                f\"shape={shape}, axis_names={axis_names}\"\n            )\n\n        if len(shape) != len(axis_names):\n            raise ValueError(\n                \"Shape and axis_names must have the same size. 
\"\n                f\"Received: shape={shape}, axis_names={axis_names}\"\n            )\n        if devices is None:\n            devices = list_devices()\n        devices = np.array(devices)\n        if np.prod(shape) != np.prod(devices.shape):\n            raise ValueError(\n                \"Shape does not match the number of devices. \"\n                f\"Received: shape={shape}; devices.shape=\"\n                f\"{devices.shape}\"\n            )\n\n        self._shape = shape\n        self._axis_names = axis_names\n        self._devices = np.reshape(devices, shape)\n\n    @property\n    def shape(self):\n        return self._shape\n\n    @property\n    def axis_names(self):\n        return self._axis_names\n\n    @property\n    def devices(self):\n        return self._devices\n\n    @property\n    def backend_mesh(self):\n        if not hasattr(self, \"_backend_mesh\"):\n            self._backend_mesh = distribution_lib._to_backend_mesh(self)\n        return self._backend_mesh\n\n    def __repr__(self):\n        return (\n            f\"<{self.__class__.__name__} \"\n            f\"shape={self.shape}, axis_names={self.axis_names}>\"\n        )\n\n    def __str__(self):\n        return self.__repr__()\n\n\n@keras_export(\"keras.distribution.TensorLayout\")\nclass TensorLayout:\n    \"\"\"A layout to apply to a tensor.\n\n    This API is aligned with `jax.sharding.NamedSharding`.\n\n    See more details in [jax.sharding.NamedSharding](\n        https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding).\n\n    Args:\n        axes: tuple of strings that should map to the `axis_names` in\n            a `DeviceMesh`. For any dimension that doesn't need sharding,\n            `None` can be used as a placeholder.\n        device_mesh: Optional `DeviceMesh` that will be used to create\n            the layout. 
The actual mapping of the tensor to physical devices\n            is not known until the mesh is specified.\n    \"\"\"\n\n    def __init__(self, axes, device_mesh=None):\n        self._axes = tuple(axes)\n        self._device_mesh = device_mesh\n        self._validate_axes()\n\n    @property\n    def axes(self):\n        return self._axes\n\n    @property\n    def device_mesh(self):\n        return self._device_mesh\n\n    @device_mesh.setter\n    def device_mesh(self, device_mesh):\n        if self._device_mesh is not None:\n            raise ValueError(\n                \"Cannot override device mesh value. Existing \"\n                f\"value is {self._device_mesh}\"\n            )\n        self._device_mesh = device_mesh\n        self._validate_axes()\n\n    @property\n    def backend_layout(self):\n        if not hasattr(self, \"_backend_layout\"):\n            self._backend_layout = distribution_lib._to_backend_layout(self)\n        return self._backend_layout\n\n    def _validate_axes(self):\n        if self._device_mesh:\n            valid_axis_names = set(self._device_mesh.axis_names)\n            axis_names = set(self._axes) - set([None])\n            if axis_names - valid_axis_names:\n                raise ValueError(\n                    \"Invalid axis names for Layout. Valid axis \"\n                    f\"names: {valid_axis_names}, got {axis_names}\"\n                )\n\n    def __repr__(self):\n        return (\n            f\"<{self.__class__.__name__} \"\n            f\"axes={self.axes}, device_mesh={self.device_mesh}>\"\n        )\n\n    def __str__(self):\n        return self.__repr__()\n\n\nclass Distribution:\n    \"\"\"Base class for variable distribution strategies.\n\n    A `Distribution` has the following key functionalities:\n\n    1. Distribute the model variables to a `DeviceMesh`.\n    2. Distribute the input data to a `DeviceMesh`.\n    3. 
Distribute an intermediate state tensor in the model.\n\n    It can create a context scope so that the framework can properly detect the\n    `Distribution` and distribute the variable/data accordingly.\n\n    Args:\n        device_mesh: A `DeviceMesh` instance.\n        batch_dim_name: Optional string name for the batch dimension.\n            Defaults to `None`.\n        auto_shard_dataset: Automatically shard the dataset amongst\n            processes in a multi-process setting. Set to `False` if the dataset\n            is already sharded across hosts.  Defaults to `True`.\n    \"\"\"\n\n    def __init__(\n        self, device_mesh, batch_dim_name=None, auto_shard_dataset=True\n    ):\n        self._device_mesh = device_mesh\n        self._batch_dim_name = batch_dim_name\n        self._auto_shard_dataset = auto_shard_dataset\n\n    def get_data_layout(self, data_shape):\n        \"\"\"Retrieve the `TensorLayout` for the input data.\n\n        Args:\n            data_shape: shape for the input data in list or tuple format.\n\n        Returns:\n            The `TensorLayout` for the data, which can be used by\n            `backend.distribute_value()` to redistribute the input data.\n        \"\"\"\n        raise NotImplementedError()\n\n    def get_variable_layout(self, variable):\n        \"\"\"Retrieve the `TensorLayout` for the variable.\n\n        Args:\n            variable: A `Variable` instance.\n\n        Returns:\n            The `TensorLayout` for the variable, which can be used by\n            `backend.distribute_value()` to redistribute a variable.\n        \"\"\"\n        raise NotImplementedError()\n\n    def get_tensor_layout(self, path):\n        \"\"\"Retrieve the `TensorLayout` for the intermediate tensor.\n\n        Args:\n            path: a string path for the corresponding tensor.\n\n        Returns:\n            The `TensorLayout` for the intermediate tensor, which can be used\n            by `backend.relayout()` to reshard the tensor. 
Could also return\n            None.\n        \"\"\"\n        raise NotImplementedError()\n\n    @contextlib.contextmanager\n    def scope(self):\n        \"\"\"Context manager to make the `Distribution` current.\"\"\"\n        original_scope = distribution()\n        set_distribution(self)\n        try:\n            yield\n        finally:\n            set_distribution(original_scope)\n\n    @property\n    def device_mesh(self):\n        return self._device_mesh\n\n    @property\n    def batch_dim_name(self):\n        return self._batch_dim_name\n\n    @property\n    def auto_shard_dataset(self):\n        return self._auto_shard_dataset\n\n    @auto_shard_dataset.setter\n    def auto_shard_dataset(self, auto_shard_dataset):\n        self._auto_shard_dataset = auto_shard_dataset\n\n    def distribute_dataset(self, dataset):\n        \"\"\"Create a distributed dataset from the original global dataset.\n\n        Args:\n            dataset: the original global dataset instance.\n\n        Returns:\n            If `auto_shard_dataset` is `True`, returns a sharded dataset that\n            only produces data for the current local worker/process.  Otherwise,\n            returns the original dataset.\n\n        Raises:\n            ValueError: if auto-sharding is requested in a multi-process\n            setting, but the dataset type is not supported.\n        \"\"\"\n        raise NotImplementedError()\n\n    def __repr__(self):\n        return f\"<{self.__class__.__name__} device_mesh={self.device_mesh}>\"\n\n    def __str__(self):\n        return self.__repr__()\n\n\n@keras_export(\"keras.distribution.DataParallel\")\nclass DataParallel(Distribution):\n    \"\"\"Distribution for data parallelism.\n\n    You can choose to create this instance by either specifying\n    the `device_mesh` or `devices` arguments (but not both).\n\n    The `device_mesh` argument is expected to be a `DeviceMesh` instance,\n    and is expected to be 1D only. 
If the mesh has multiple axes,\n    the first axis will be treated as the data parallel dimension\n    (and a warning will be raised).\n\n    When a list of `devices` is provided, they will be used to construct a\n    1D mesh.\n\n    When both `device_mesh` and `devices` are absent, `list_devices()`\n    will be used to detect any available devices and create a 1D mesh from\n    them.\n\n    Args:\n        device_mesh: Optional `DeviceMesh` instance.\n        devices: Optional list of devices.\n        auto_shard_dataset: Automatically shard the dataset amongst\n            processes in a multi-process setting. Set to `False` if the dataset\n            is already sharded across hosts.  Defaults to `True`.\n    \"\"\"\n\n    def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True):\n        if device_mesh:\n            self._initialize_with_device_mesh(device_mesh, auto_shard_dataset)\n        elif devices:\n            self._initialize_mesh_from_devices(devices, auto_shard_dataset)\n        else:\n            self._initialize_mesh_from_list_devices(auto_shard_dataset)\n\n        # The following attributes might get converted to public methods.\n        self._num_process = distribution_lib.num_processes()\n        self._process_id = distribution_lib.process_id()\n        self._is_multi_process = self._num_process > 1\n\n    def _initialize_with_device_mesh(self, device_mesh, auto_shard_dataset):\n        if not isinstance(device_mesh, DeviceMesh):\n            raise ValueError(\n                \"Expect `mesh` to be an instance of `DeviceMesh`. \"\n                f\"Received: mesh={device_mesh} (of type {type(device_mesh)})\"\n            )\n        super().__init__(\n            device_mesh, device_mesh.axis_names[0], auto_shard_dataset\n        )\n        if self.device_mesh.devices.ndim != 1:\n            warnings.warn(\n                \"Expect the input mesh to be 1D, but received \"\n                f\"mesh.devices.ndim={device_mesh.devices.ndim}. \"\n                \"The first axis will be used for data-parallel sharding.\"\n            )\n\n    def _initialize_mesh_from_devices(self, devices, auto_shard_dataset):\n        devices = np.array(devices)\n        device_mesh = DeviceMesh(\n            shape=devices.shape,\n            axis_names=[DEFAULT_BATCH_DIM_NAME],\n            devices=devices,\n        )\n        super().__init__(\n            device_mesh, DEFAULT_BATCH_DIM_NAME, auto_shard_dataset\n        )\n\n    def _initialize_mesh_from_list_devices(self, auto_shard_dataset):\n        devices = np.array(list_devices())\n        device_mesh = DeviceMesh(\n            shape=devices.shape,\n            axis_names=[DEFAULT_BATCH_DIM_NAME],\n            devices=devices,\n        )\n        super().__init__(\n            device_mesh, DEFAULT_BATCH_DIM_NAME, auto_shard_dataset\n        )\n\n    def get_data_layout(self, data_shape):\n        data_shard_spec = [None] * len(data_shape)\n        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim\n        return TensorLayout(data_shard_spec, self.device_mesh)\n\n    def get_variable_layout(self, variable):\n        # First check if the variable already has a layout assigned.\n        if getattr(variable, \"_layout\", None) is not None:\n            return variable._layout\n        # Otherwise, replicate the variable.\n        variable_shard_spec = [None] * len(variable.shape)\n        return TensorLayout(variable_shard_spec, self.device_mesh)\n\n    def get_tensor_layout(self, path):\n        # For data parallel training, the intermediate state is not changed.\n        return None\n\n    def distribute_dataset(self, dataset):\n        if not self._is_multi_process or not self.auto_shard_dataset:\n            return dataset\n\n        # Try to distribute a global tf.data.Dataset.\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        if not tf.available or not isinstance(dataset, 
tf.data.Dataset):\n            raise ValueError(\n                \"Only `tf.data.Dataset` is supported for auto-sharding, \"\n                f\"got {type(dataset)}\"\n            )\n\n        from tensorflow.python.data.experimental.ops import (\n            distribute as tf_data_distribute,\n        )\n\n        batch_size = tf_data_distribute.compute_batch_size(dataset)\n        if batch_size.numpy() < 0:\n            raise ValueError(\n                \"The batch size of the input dataset is \"\n                \"unknown. Please configure the batch size for \"\n                \"the input dataset, e.g. via `dataset.batch(batch_size)`\"\n            )\n        per_worker_batch_size = tf_data_distribute.batch_sizes_for_worker(\n            global_batch_size=batch_size,\n            num_workers=self._num_process,\n            num_replicas_per_worker=1,  # We hard code this for now.\n            worker_index=self._process_id,\n        )\n        distributed_dataset = dataset.rebatch(per_worker_batch_size)\n        distributed_dataset = tf_data_distribute._AutoShardDataset(\n            distributed_dataset,\n            num_workers=self._num_process,\n            index=self._process_id,\n            num_replicas=self._num_process,\n        )\n        return distributed_dataset.prefetch(tf.data.AUTOTUNE)\n\n\n@keras_export(\"keras.distribution.ModelParallel\")\nclass ModelParallel(Distribution):\n    \"\"\"Distribution that shards model variables.\n\n    Compared to `DataParallel`, which replicates the variables across all\n    devices, `ModelParallel` allows you to shard variables in addition to the\n    input data.\n\n    To construct a `ModelParallel` distribution, you need to provide a\n    `DeviceMesh` and a `LayoutMap`.\n\n    1. `DeviceMesh` contains physical device information. The axis names in\n        the mesh will be used to map the variable and data layout.\n    2. 
`LayoutMap` contains the mapping from variable paths to their\n        corresponding `TensorLayout`.\n\n    Example:\n\n    ```python\n    devices = list_devices()    # Assume there are 8 devices.\n\n    # Create a mesh with 2 devices for data parallelism and 4 devices for\n    # model parallelism.\n    device_mesh = DeviceMesh(shape=(2, 4), axis_names=('batch', 'model'),\n                             devices=devices)\n    # Create a layout map that shards the `Dense` layer and `Conv2D`\n    # layer variables on the last dimension.\n    # Based on the `device_mesh`, this means the variables\n    # will be split across 4 devices. Any other variable that doesn't\n    # match any key in the layout map will be fully replicated.\n    layout_map = LayoutMap(device_mesh)\n    layout_map['.*dense.*kernel'] = (None, 'model')\n    layout_map['.*dense.*bias'] = ('model',)\n    layout_map['.*conv2d.*kernel'] = (None, None, None, 'model')\n    layout_map['.*conv2d.*bias'] = ('model',)\n\n    distribution = ModelParallel(\n        layout_map=layout_map,\n        batch_dim_name='batch',\n    )\n\n    # Set the global distribution, or via `with distribution.scope():`\n    set_distribution(distribution)\n\n    model = model_creation()\n    model.compile()\n    model.fit(data)\n    ```\n\n    You can quickly update the device mesh shape to change the sharding factor\n    of the variables. 
E.g.\n\n    ```python\n    # With only the shape change for the device mesh, the variables will be\n    # sharded across 8 devices instead of 4, which further reduces the memory\n    # footprint of variables on each device.\n    device_mesh = DeviceMesh(\n        shape=(1, 8),\n        axis_names=('batch', 'model'),\n        devices=devices,\n    )\n    ```\n\n    To figure out a proper layout mapping rule for all the model variables, you\n    can first list out all the model variable paths, which will be used as the\n    keys to map the variables to `TensorLayout`.\n\n    e.g.\n\n    ```python\n    model = create_model()\n    for v in model.variables:\n        print(v.path)\n    ```\n\n    Args:\n        layout_map: `LayoutMap` instance which maps the variable paths to the\n            corresponding tensor layouts.\n        batch_dim_name: Optional string, the axis name in the device mesh\n            (of the `layout_map` object)\n            that will be used to distribute data. If unspecified, the\n            first axis from the device mesh will be used.\n        auto_shard_dataset: Automatically shard the dataset amongst\n            processes in a multi-process setting. Set to `False` if the dataset\n            is already sharded across hosts.  Defaults to `True`.\n    \"\"\"\n\n    def __init__(\n        self,\n        *,\n        layout_map=None,\n        batch_dim_name=None,\n        auto_shard_dataset=True,\n        **kwargs,\n    ):\n        kwargs.pop(\"device_mesh\", None)\n        if layout_map is None:\n            raise ValueError(\"You must specify a layout_map argument.\")\n        if not isinstance(layout_map, LayoutMap):\n            raise ValueError(\n                \"Argument `layout_map` must be a `LayoutMap` instance. 
\"\n                f\"Received: layout_map={layout_map}\"\n            )\n        device_mesh = layout_map.device_mesh\n        batch_dim_name = batch_dim_name or device_mesh.axis_names[0]\n        super().__init__(device_mesh, batch_dim_name, auto_shard_dataset)\n        self._layout_map = layout_map\n\n        # The following attributes might get converted to public methods.\n        self._num_process = distribution_lib.num_processes()\n        self._process_id = distribution_lib.process_id()\n        self._is_multi_process = self._num_process > 1\n\n    def get_data_layout(self, data_shape):\n        data_shard_spec = [None] * len(data_shape)\n        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim\n        return TensorLayout(data_shard_spec, self.device_mesh)\n\n    def get_variable_layout(self, variable):\n        # First check if the variable already has a layout assigned.\n        if getattr(variable, \"_layout\", None) is not None:\n            return variable._layout\n        # Check the layout map.\n        variable_layout = self._layout_map[variable.path]\n        if variable_layout is not None:\n            return variable_layout\n        variable_shard_spec = [None] * len(variable.shape)\n        return TensorLayout(variable_shard_spec, self.device_mesh)\n\n    def get_tensor_layout(self, path):\n        return self._layout_map[path]\n\n    def distribute_dataset(self, dataset):\n        if not self._is_multi_process or not self.auto_shard_dataset:\n            return dataset\n\n        # Try to distribute a global tf.data.Dataset.\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        if not tf.available or not isinstance(dataset, tf.data.Dataset):\n            raise ValueError(\n                \"Only `tf.data.Dataset` is supported for auto-sharding, \"\n                f\"got {type(dataset)}\"\n            )\n\n        from tensorflow.python.data.experimental.ops import (\n            distribute as 
tf_data_distribute,\n        )\n\n        global_batch_size = tf_data_distribute.compute_batch_size(dataset)\n        if global_batch_size.numpy() < 0:\n            raise ValueError(\n                \"The batch size of the input dataset is \"\n                \"unknown. Please configure the batch size for \"\n                \"the input dataset, e.g. via `dataset.batch(batch_size)`\"\n            )\n\n        # We need to compute the per-process/worker/host batch size.\n        # This will depend on how many model replicas we have on each process.\n        # Note that this might be smaller than one if model replicas are sharded\n        # across multiple processes.\n        mesh_batch_dim_index = self.device_mesh.axis_names.index(\n            self.batch_dim_name\n        )\n        num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index]\n        if num_model_replicas == 1:\n            # No sharding is needed in this case. Each process will have the\n            # global batch size, and data from the iterator will need to be\n            # replicated across all processes.\n            return dataset.prefetch(tf.data.AUTOTUNE)\n        num_model_replicas_per_process = num_model_replicas / self._num_process\n        if num_model_replicas_per_process >= 1:\n            # Each process will have one or more full model replicas. Data will\n            # be sharded across all processes without replication.\n            if global_batch_size % self._num_process != 0:\n                raise ValueError(\n                    \"Global batch size must be divisible by the number of \"\n                    f\"processes. 
`global_batch_size`={global_batch_size} and \"\n                    f\"`num_process`={self._num_process}\"\n                )\n            per_process_batch_size = global_batch_size // self._num_process\n            distributed_dataset = dataset.rebatch(per_process_batch_size)\n            distributed_dataset = distributed_dataset.shard(\n                num_shards=self._num_process,\n                index=self._process_id,\n            )\n            return distributed_dataset.prefetch(tf.data.AUTOTUNE)\n        else:\n            # Model replicas are sharded across multiple processes. Data will be\n            # sharded across model replicas, and replicated across processes\n            # within the same model replica.\n            if global_batch_size % num_model_replicas != 0:\n                raise ValueError(\n                    \"Global batch size must be divisible by the number of \"\n                    f\"replicas. `global_batch_size`={global_batch_size} and \"\n                    f\"`num_model_replicas`={num_model_replicas}\"\n                )\n            per_process_batch_size = global_batch_size // num_model_replicas\n            distributed_dataset = dataset.rebatch(per_process_batch_size)\n            processes_per_replica = self._num_process // num_model_replicas\n            # TODO: Figure out what the convention is for data sharding id.\n            data_shard_id = self._process_id % processes_per_replica\n            distributed_dataset = distributed_dataset.shard(\n                num_shards=num_model_replicas,\n                index=data_shard_id,\n            )\n            return distributed_dataset.prefetch(tf.data.AUTOTUNE)\n\n\n@keras_export(\"keras.distribution.LayoutMap\")\nclass LayoutMap(collections.abc.MutableMapping):\n    \"\"\"A dict-like object that maps strings to `TensorLayout` instances.\n\n    `LayoutMap` uses a string as key and a `TensorLayout` as value. 
There is a\n    behavior difference between a normal Python dict and this class. The string\n    key will be treated as a regex when retrieving the value. See the docstring\n    of `get` for more details.\n\n    See below for a usage example. You can define the naming schema\n    of the `TensorLayout`, and then retrieve the corresponding\n    `TensorLayout` instance.\n\n    In the normal case, the key to query is usually the `variable.path`, which\n    is the identifier of the variable.\n\n    As a shortcut, a tuple or list of axis names is also allowed when inserting\n    a value, and will be converted to a `TensorLayout`.\n\n    ```python\n    layout_map = LayoutMap(device_mesh)\n    layout_map['.*dense.*kernel'] = (None, 'model')\n    layout_map['.*dense.*bias'] = ('model',)\n    layout_map['.*conv2d.*kernel'] = (None, None, None, 'model')\n    layout_map['.*conv2d.*bias'] = ('model',)\n\n    layout_1 = layout_map['dense_1.kernel']             # layout_1 == layout_2d\n    layout_2 = layout_map['dense_1.bias']               # layout_2 == layout_1d\n    layout_3 = layout_map['dense_2.kernel']             # layout_3 == layout_2d\n    layout_4 = layout_map['dense_2.bias']               # layout_4 == layout_1d\n    layout_5 = layout_map['my_model/conv2d_123/kernel'] # layout_5 == layout_4d\n    layout_6 = layout_map['my_model/conv2d_123/bias']   # layout_6 == layout_1d\n    layout_7 = layout_map['my_model/conv3d_1/kernel']   # layout_7 == None\n    layout_8 = layout_map['my_model/conv3d_1/bias']     # layout_8 == None\n    ```\n\n    Args:\n        device_mesh: `keras.distribution.DeviceMesh` instance.\n    \"\"\"\n\n    def __init__(self, device_mesh):\n        self._layout_map = collections.OrderedDict()\n        self._device_mesh = device_mesh\n\n    def __getitem__(self, key):\n        \"\"\"Retrieves the corresponding layout by the string key.\n\n        When there isn't an exact match, all the existing keys in the layout map\n        will be treated as a regex and 
matched against the input key again. When\n        there are multiple matches for the regex, a `ValueError` will be\n        raised. Returns `None` if there isn't any match found.\n\n        Args:\n            key: String key to query a layout.\n\n        Returns:\n            Corresponding layout based on the query.\n        \"\"\"\n        if key in self._layout_map:\n            return self._layout_map[key]\n\n        matching_keys = [\n            pattern\n            for pattern in self._layout_map\n            if re.fullmatch(pattern, key)\n        ]\n        if len(matching_keys) > 1:\n            raise ValueError(\n                f\"Path '{key}' matches multiple layout \"\n                f\"specification keys: {matching_keys}. Please make \"\n                \"sure each tensor/variable path only matches at most \"\n                \"one layout specification key in the LayoutMap.\"\n            )\n        elif len(matching_keys) == 1:\n            return self._layout_map[matching_keys[0]]\n        return None\n\n    def __setitem__(self, key, layout):\n        \"\"\"Insert a `TensorLayout` into the `LayoutMap`.\n\n        Args:\n            key: String key for the `TensorLayout`.\n            layout: The `TensorLayout`. As a shortcut, a tuple of strings and\n                `None` is also acceptable, and will be converted to a\n                `TensorLayout`.\n        \"\"\"\n        if key in self._layout_map:\n            raise ValueError(\n                f\"{key} already exists in the LayoutMap with \"\n                f\"value {self._layout_map[key]}. 
Please make sure to \"\n                \"not use duplicated keys.\"\n            )\n        if isinstance(layout, tuple):\n            layout = TensorLayout(axes=layout, device_mesh=None)\n\n        if not isinstance(layout, TensorLayout):\n            raise ValueError(\n                f\"{layout} should be a TensorLayout type, got {type(layout)}\"\n            )\n        self._maybe_populate_device_mesh(layout)\n        self._layout_map[key] = layout\n\n    def __delitem__(self, key):\n        # let the dict handle the missing key error\n        return self._layout_map.pop(key)\n\n    def __len__(self):\n        return len(self._layout_map)\n\n    def __iter__(self):\n        return iter(self._layout_map)\n\n    @property\n    def device_mesh(self):\n        return self._device_mesh\n\n    def _maybe_populate_device_mesh(self, layout):\n        if layout.device_mesh is None and self.device_mesh is not None:\n            layout.device_mesh = self.device_mesh\n\n\nLayoutMap.get.__doc__ = LayoutMap.__getitem__.__doc__\n\n\n@keras_export(\"keras.distribution.distribute_tensor\")\ndef distribute_tensor(tensor, layout):\n    \"\"\"Change the layout of a Tensor value during jit function execution.\n\n    Args:\n        tensor: a Tensor to change the layout of.\n        layout: `TensorLayout` to be applied to the value.\n\n    Returns:\n        a new value with the specified tensor layout.\n    \"\"\"\n    if isinstance(tensor, KerasTensor):\n        # A Keras tensor is only used for building a functional model, and\n        # can't be used to alter layout/sharding.\n        return tensor\n    return distribution_lib.distribute_tensor(tensor, layout)\n\n\n@keras_export(\"keras.distribution.distribution\")\ndef distribution():\n    \"\"\"Retrieve the current distribution from the global context.\"\"\"\n    return global_state.get_global_attribute(GLOBAL_ATTRIBUTE_NAME)\n\n\n@keras_export(\"keras.distribution.set_distribution\")\ndef set_distribution(value):\n    \"\"\"Set the 
distribution as the global distribution setting.\n\n    Args:\n        value: a `Distribution` instance.\n    \"\"\"\n    global_state.set_global_attribute(GLOBAL_ATTRIBUTE_NAME, value)\n"
  },
  {
    "path": "keras/src/distribution/distribution_lib_test.py",
    "content": "\"\"\"Test for distribution_lib.py.\"\"\"\n\nimport os\nfrom unittest import mock\n\nimport numpy as np\nimport pytest\nimport tensorflow as tf\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.backend import distribution_lib as backend_dlib\nfrom keras.src.distribution import distribution_lib\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"jax\",\n    reason=\"Only JAX has the backend to mock at the moment\",\n)\n@mock.patch.object(\n    backend_dlib,\n    \"initialize\",\n    return_value=None,\n)\nclass MultiProcessInitializeTest(testing.TestCase):\n    def tearDown(self):\n        super().tearDown()\n        os.environ.clear()\n\n    def test_initialize_with_explicit_param(self, mock_backend_initialize):\n        job_addresses = \"10.0.0.1:1234,10.0.0.2:2345\"\n        num_processes = 2\n        current_process_id = 0\n\n        distribution_lib.initialize(\n            job_addresses, num_processes, current_process_id\n        )\n\n        mock_backend_initialize.assert_called_once_with(\n            job_addresses, num_processes, current_process_id\n        )\n\n    def test_initialize_with_env_vars(self, mock_backend_initialize):\n        job_addresses = \"10.0.0.1:1234,10.0.0.2:2345\"\n        num_processes = 2\n        current_process_id = 0\n        os.environ[\"KERAS_DISTRIBUTION_JOB_ADDRESSES\"] = job_addresses\n        os.environ[\"KERAS_DISTRIBUTION_NUM_PROCESSES\"] = str(num_processes)\n        os.environ[\"KERAS_DISTRIBUTION_PROCESS_ID\"] = str(current_process_id)\n\n        distribution_lib.initialize()\n        mock_backend_initialize.assert_called_once_with(\n            job_addresses, num_processes, current_process_id\n        )\n\n    def test_init_with_nones(self, mock_backend_initialize):\n        # This is also a valid case for Cloud TPU on JAX.\n        distribution_lib.initialize()\n        mock_backend_initialize.assert_called_once_with(None, None, None)\n\n\nclass 
DeviceMeshTest(testing.TestCase):\n    def test_mesh_creation(self):\n        devices = [f\"cpu:{i}\" for i in range(8)]\n        shape = (4, 2)\n        axis_names = [\"batch\", \"model\"]\n\n        mesh = distribution_lib.DeviceMesh(shape, axis_names, devices)\n        self.assertEqual(mesh.shape, shape)\n        self.assertEqual(mesh.axis_names, axis_names)\n        self.assertEqual(mesh.devices.shape, shape)\n\n    def test_input_validation(self):\n        devices = [f\"cpu:{i}\" for i in range(4)]\n        with self.assertRaisesRegex(\n            ValueError, \"Shape and axis_names cannot be empty\"\n        ):\n            distribution_lib.DeviceMesh((4,), \"\", devices)\n\n        with self.assertRaisesRegex(\n            ValueError, \"Shape and axis_names should have same size\"\n        ):\n            distribution_lib.DeviceMesh((4, 2), [\"batch\"], devices)\n\n        with self.assertRaisesRegex(\n            ValueError, \"Shape does not match the number of devices\"\n        ):\n            distribution_lib.DeviceMesh((4, 2), [\"batch\", \"model\"], devices)\n\n\nclass TensorLayoutTest(testing.TestCase):\n    def setUp(self):\n        self.mesh = distribution_lib.DeviceMesh(\n            (4, 2), [\"data\", \"model\"], [f\"cpu:{i}\" for i in range(8)]\n        )\n\n    def test_tensor_layout_creation(self):\n        axes = (\"data\", None)\n        layout = distribution_lib.TensorLayout(axes, self.mesh)\n\n        self.assertEqual(layout.device_mesh, self.mesh)\n        self.assertEqual(layout.axes, axes)\n\n    def test_tensor_layout_validation(self):\n        axes = (\"data\", \"unknown\", None)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid axis names for Layout\"\n        ):\n            distribution_lib.TensorLayout(axes, self.mesh)\n\n    def test_lazy_device_mesh_injection(self):\n        axes = (\"data\", None)\n        layout = distribution_lib.TensorLayout(axes, None)\n\n        self.assertIsNone(layout.device_mesh)\n 
       self.assertEqual(layout.axes, axes)\n\n        layout.device_mesh = self.mesh\n\n        self.assertEqual(layout.device_mesh, self.mesh)\n        self.assertEqual(layout.axes, axes)\n\n    def test_lazy_device_mesh_validation(self):\n        axes = (\"data\", \"unknown\", None)\n        layout = distribution_lib.TensorLayout(axes, None)\n\n        self.assertIsNone(layout.device_mesh)\n        self.assertEqual(layout.axes, axes)\n\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid axis names for Layout\"\n        ):\n            layout.device_mesh = self.mesh\n\n\nclass DistributionTest(testing.TestCase):\n    def setUp(self):\n        super().setUp()\n        devices = [f\"cpu:{i}\" for i in range(8)]\n        shape = (4, 2)\n        axis_names = [\"batch\", \"model\"]\n\n        self.device_mesh = distribution_lib.DeviceMesh(\n            shape, axis_names, devices\n        )\n\n    def test_init_with_device_mesh(self):\n        distribution = distribution_lib.Distribution(self.device_mesh)\n        self.assertIs(distribution.device_mesh, self.device_mesh)\n\n    def test_scope(self):\n        distribution_1 = distribution_lib.Distribution(self.device_mesh)\n        distribution_2 = distribution_lib.Distribution(self.device_mesh)\n\n        self.assertIsNone(distribution_lib.distribution())\n        with distribution_1.scope():\n            self.assertIs(distribution_lib.distribution(), distribution_1)\n            with distribution_2.scope():\n                self.assertIs(distribution_lib.distribution(), distribution_2)\n\n            self.assertIs(distribution_lib.distribution(), distribution_1)\n\n        self.assertIsNone(distribution_lib.distribution())\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"jax\",\n    reason=\"Only JAX has the proper backend distribution lib\",\n)\nclass DataParallelDistributionTest(testing.TestCase):\n    def setUp(self):\n        super().setUp()\n        self.devices = [f\"cpu:{i}\" for i in 
range(8)]\n        shape = (8,)\n        axis_names = [\"data\"]\n\n        self.device_mesh = distribution_lib.DeviceMesh(\n            shape, axis_names, self.devices\n        )\n\n    def test_create_with_device_mesh(self):\n        distribution = distribution_lib.DataParallel(\n            device_mesh=self.device_mesh\n        )\n\n        device_mesh = distribution.device_mesh\n        self.assertEqual(len(device_mesh.devices), 8)\n        self.assertEqual(device_mesh.axis_names, [\"data\"])\n        self.assertEqual(distribution.batch_dim_name, \"data\")\n\n        self.assertFalse(distribution._is_multi_process)\n        self.assertEqual(distribution._process_id, 0)\n        self.assertEqual(distribution._num_process, 1)\n\n    def test_create_with_devices(self):\n        distribution = distribution_lib.DataParallel(devices=self.devices)\n        device_mesh = distribution.device_mesh\n        self.assertEqual(len(device_mesh.devices), 8)\n        self.assertEqual(device_mesh.axis_names, [\"batch\"])\n        self.assertEqual(distribution.batch_dim_name, \"batch\")\n\n    @mock.patch.object(\n        distribution_lib,\n        \"list_devices\",\n        return_value=[f\"cpu:{i}\" for i in range(8)],\n    )\n    def test_create_with_list_devices(self, mock_list_devices):\n        distribution = distribution_lib.DataParallel()\n        mock_list_devices.assert_called_once()\n\n        device_mesh = distribution.device_mesh\n        self.assertEqual(len(device_mesh.devices), 8)\n        self.assertEqual(device_mesh.axis_names, [\"batch\"])\n        self.assertEqual(distribution.batch_dim_name, \"batch\")\n\n    def test_get_data_layout(self):\n        distribution = distribution_lib.DataParallel(\n            device_mesh=self.device_mesh\n        )\n\n        data = np.arange(16).reshape((4, 2, 2))\n        data_layout = distribution.get_data_layout(data.shape)\n        self.assertIs(data_layout.device_mesh, self.device_mesh)\n        
self.assertEqual(data_layout.axes, (\"data\", None, None))\n\n    @pytest.mark.skipif(testing.jax_uses_gpu(), reason=\"CI segfault\")\n    def test_get_variable_layout(self):\n        distribution = distribution_lib.DataParallel(\n            device_mesh=self.device_mesh\n        )\n\n        variable = backend.Variable(initializer=[1, 2, 3])\n        variable_layout = distribution.get_variable_layout(variable)\n        self.assertIs(variable_layout.device_mesh, self.device_mesh)\n        self.assertEqual(variable_layout.axes, (None,))\n\n    @pytest.mark.skipif(testing.jax_uses_gpu(), reason=\"CI segfault\")\n    def test_get_variable_layout_with_explicit_layout(self):\n        distribution = distribution_lib.DataParallel(\n            device_mesh=self.device_mesh\n        )\n\n        explicit_mesh = distribution_lib.DeviceMesh((8,), [\"x\"], self.devices)\n        explicit_layout = distribution_lib.TensorLayout([\"x\"], explicit_mesh)\n\n        variable = backend.Variable(initializer=[1, 2, 3])\n        variable._layout = explicit_layout\n        variable_layout = distribution.get_variable_layout(variable)\n        self.assertIs(variable_layout.device_mesh, explicit_mesh)\n        self.assertEqual(variable_layout.axes, explicit_layout.axes)\n\n    def test_get_tensor_layout(self):\n        distribution = distribution_lib.DataParallel(\n            device_mesh=self.device_mesh\n        )\n\n        path = \"path/to/tensor\"\n        tensor_layout = distribution.get_tensor_layout(path)\n        self.assertIsNone(tensor_layout)\n\n    def test_distribute_dataset(self):\n        # We can only verify the single worker/process case in OSS for now.\n        dataset = tf.data.Dataset.range(8)\n        distribution = distribution_lib.DataParallel(\n            device_mesh=self.device_mesh\n        )\n        distributed_dataset = distribution.distribute_dataset(dataset)\n        self.assertIs(dataset, distributed_dataset)\n\n\n@pytest.mark.skipif(\n    backend.backend() 
!= \"jax\",\n    reason=\"Only JAX has the proper backend distribution lib\",\n)\nclass ModelParallelDistributionTest(testing.TestCase):\n    def setUp(self):\n        super().setUp()\n        self.devices = [f\"cpu:{i}\" for i in range(8)]\n        shape = (2, 4)\n        axis_names = [\"data\", \"model\"]\n\n        self.device_mesh = distribution_lib.DeviceMesh(\n            shape, axis_names, self.devices\n        )\n\n    @pytest.mark.skipif(testing.jax_uses_gpu(), reason=\"CI segfault\")\n    def test_distribute_weights(self):\n        layout_map = distribution_lib.LayoutMap(self.device_mesh)\n        layout_map[\".*kernel\"] = distribution_lib.TensorLayout([None, \"model\"])\n        layout_map[\".*bias\"] = distribution_lib.TensorLayout([\"model\"])\n\n        distribution = distribution_lib.ModelParallel(\n            layout_map=layout_map, batch_dim_name=\"data\"\n        )\n        kernel = backend.Variable(\n            initializer=np.arange(8).reshape((2, 4)), name=\"kernel\"\n        )\n        bias = backend.Variable(initializer=np.arange(4), name=\"bias\")\n        rng_seed = backend.Variable(initializer=[0, 1], name=\"seed\")\n\n        kernel_layout = distribution.get_variable_layout(kernel)\n        self.assertIs(kernel_layout.device_mesh, self.device_mesh)\n        self.assertEqual(kernel_layout.axes, (None, \"model\"))\n\n        bias_layout = distribution.get_variable_layout(bias)\n        self.assertIs(bias_layout.device_mesh, self.device_mesh)\n        self.assertEqual(bias_layout.axes, (\"model\",))\n\n        rng_seed_layout = distribution.get_variable_layout(rng_seed)\n        self.assertIs(rng_seed_layout.device_mesh, self.device_mesh)\n        self.assertEqual(rng_seed_layout.axes, (None,))\n\n    def test_distribute_data(self):\n        layout_map = distribution_lib.LayoutMap(self.device_mesh)\n        distribution = distribution_lib.ModelParallel(\n            layout_map=layout_map, batch_dim_name=\"data\"\n        )\n\n        data = np.arange(16).reshape((4, 2, 2))\n   
     data_layout = distribution.get_data_layout(data.shape)\n        self.assertIs(data_layout.device_mesh, self.device_mesh)\n        self.assertEqual(data_layout.axes, (\"data\", None, None))\n\n    def test_get_tensor_layout(self):\n        layout_map = distribution_lib.LayoutMap(self.device_mesh)\n        layout_map[\".*kernel\"] = distribution_lib.TensorLayout([None, \"model\"])\n        layout_map[\".*bias\"] = distribution_lib.TensorLayout([\"model\"])\n        layout_map[\"/model/layer/tensor\"] = (\"data\", None)\n\n        distribution = distribution_lib.ModelParallel(\n            layout_map=layout_map, batch_dim_name=\"data\"\n        )\n        layout = distribution.get_tensor_layout(\"/model/layer/tensor\")\n        self.assertIs(layout.device_mesh, self.device_mesh)\n        self.assertEqual(layout.axes, (\"data\", None))\n\n        layout = distribution.get_tensor_layout(\"/model/layer/other_tensor\")\n        self.assertIsNone(layout)\n\n    @pytest.mark.skipif(testing.jax_uses_gpu(), reason=\"CI segfault\")\n    def test_get_variable_layout_with_explicit_layout(self):\n        layout_map = distribution_lib.LayoutMap(self.device_mesh)\n        layout_map[\".*kernel\"] = distribution_lib.TensorLayout([None, \"model\"])\n        distribution = distribution_lib.ModelParallel(\n            layout_map=layout_map, batch_dim_name=\"data\"\n        )\n\n        explicit_mesh = distribution_lib.DeviceMesh((8,), [\"x\"], self.devices)\n        explicit_layout = distribution_lib.TensorLayout([\"x\"], explicit_mesh)\n        variable = backend.Variable(initializer=[1, 2, 3], name=\"kernel\")\n        variable._layout = explicit_layout\n        variable_layout = distribution.get_variable_layout(variable)\n        self.assertIs(variable_layout.device_mesh, explicit_mesh)\n        self.assertEqual(variable_layout.axes, explicit_layout.axes)\n\n    def test_distribute_dataset(self):\n        # We can only verify the single worker/process case in OSS for now.\n     
   dataset = tf.data.Dataset.range(8)\n        layout_map = distribution_lib.LayoutMap(self.device_mesh)\n        distribution = distribution_lib.ModelParallel(\n            layout_map=layout_map, batch_dim_name=\"data\"\n        )\n        distributed_dataset = distribution.distribute_dataset(dataset)\n        self.assertIs(dataset, distributed_dataset)\n\n\nclass LayoutMapTest(testing.TestCase):\n    def setUp(self):\n        super().setUp()\n        self.devices = [f\"cpu:{i}\" for i in range(8)]\n        shape = (4, 2)\n        axis_names = [\"data\", \"model\"]\n\n        self.device_mesh = distribution_lib.DeviceMesh(\n            shape, axis_names, self.devices\n        )\n        self.sharded_2d = distribution_lib.TensorLayout([None, \"model\"])\n        self.sharded_1d = distribution_lib.TensorLayout([\"model\"])\n\n        self.replicated_2d = distribution_lib.TensorLayout([None, None])\n        self.replicated_1d = distribution_lib.TensorLayout([None])\n\n    def test_add(self):\n        layout_map = distribution_lib.LayoutMap(self.device_mesh)\n        layout_map[\"dense/kernel\"] = self.sharded_2d\n        layout_map[\"dense/bias\"] = self.sharded_1d\n        # Test adding a list/tuple as a shortcut for a TensorLayout.\n        layout_map[\"conv/bias\"] = (\"model\",)\n\n        # Make sure there are three items in the map, accessed via the\n        # underlying container at layout_map._layout_map.\n        self.assertLen(layout_map, 3)\n\n        kernel_layout = layout_map[\"dense/kernel\"]\n        self.assertEqual(kernel_layout.axes, (None, \"model\"))\n        self.assertIs(kernel_layout.device_mesh, self.device_mesh)\n\n        bias_layout = layout_map[\"dense/bias\"]\n        self.assertEqual(bias_layout.axes, (\"model\",))\n        self.assertIs(bias_layout.device_mesh, self.device_mesh)\n\n        conv_bias_layout = layout_map[\"conv/bias\"]\n        self.assertEqual(conv_bias_layout.axes, (\"model\",))\n        
self.assertIs(conv_bias_layout.device_mesh, self.device_mesh)\n\n        with self.assertRaisesRegex(ValueError, \"dense/kernel already exist\"):\n            layout_map[\"dense/kernel\"] = self.sharded_2d\n\n        with self.assertRaisesRegex(ValueError, \"should be a TensorLayout\"):\n            layout_map[\"conv.kernel\"] = [\"a\", \"b\"]\n\n    def test_get(self):\n        layout_map = distribution_lib.LayoutMap(self.device_mesh)\n        layout_map[\"dense/kernel\"] = self.sharded_2d\n        layout_map[\"dense/bias\"] = self.sharded_1d\n\n        layout_map[\".*dense.*kernel\"] = self.replicated_2d\n        layout_map[\".*dense.*bias\"] = self.replicated_1d\n\n        layout_map[\".*bias\"] = self.sharded_1d\n\n        self.assertEqual(layout_map[\"dense/kernel\"], self.sharded_2d)\n        self.assertEqual(layout_map[\"dense/bias\"], self.sharded_1d)\n\n        self.assertEqual(layout_map[\"dense_2/kernel\"], self.replicated_2d)\n        # Map against the wildcard bias rule for dense. 
This will cause a\n        # ValueError\n        with self.assertRaisesRegex(\n            ValueError, \"Path 'dense_2/bias' matches multiple layout\"\n        ):\n            layout_map[\"dense_2/bias\"]\n\n        self.assertIsNone(layout_map[\"conv2d/kernel\"])\n        self.assertEqual(layout_map[\"conv2d/bias\"], self.sharded_1d)\n\n    def test_delete(self):\n        layout_map = distribution_lib.LayoutMap(self.device_mesh)\n\n        layout_map[\"dense/kernel\"] = self.sharded_2d\n        layout_map[\"dense/bias\"] = self.sharded_1d\n\n        self.assertEqual(layout_map.pop(\"dense/kernel\"), self.sharded_2d)\n        # Make sure to match against the exact string, not the regex\n        with self.assertRaises(KeyError):\n            layout_map.pop(\".*bias\")\n\n        # Make sure del also works\n        del layout_map[\"dense/bias\"]\n\n        self.assertLen(layout_map, 0)\n\n    def test_len(self):\n        layout_map = distribution_lib.LayoutMap(self.device_mesh)\n        self.assertLen(layout_map, 0)\n\n        layout_map[\"dense/kernel\"] = self.sharded_2d\n        layout_map[\"dense/bias\"] = self.sharded_1d\n\n        self.assertLen(layout_map, 2)\n\n    def test_iter(self):\n        layout_map = distribution_lib.LayoutMap(self.device_mesh)\n\n        layout_map[\"dense/kernel\"] = self.sharded_2d\n        layout_map[\"dense/bias\"] = self.sharded_1d\n\n        # Make sure the items are ordered based on the insertion order.\n        self.assertEqual(\n            list(layout_map.keys()), [\"dense/kernel\", \"dense/bias\"]\n        )\n\n        keys = []\n        values = []\n        for k, v in layout_map.items():\n            keys.append(k)\n            values.append(v)\n\n        self.assertEqual(keys, [\"dense/kernel\", \"dense/bias\"])\n        self.assertEqual(values, [self.sharded_2d, self.sharded_1d])\n\n\n# @pytest.mark.skipif(\n#     backend.backend() != \"tensorflow\",\n#     reason=\"Backend specific test\",\n# )\n# class 
TensorflowDistributionLibTest(testing.TestCase):\n#     def setUp(self):\n#         super().setUp()\n#         # Config virtual devices for testing.\n#         cpus = tf.config.list_physical_devices(\"cpu\")\n#         context._reset_context()\n#         tf.config.set_logical_device_configuration(\n#             cpus[0], [tf.config.LogicalDeviceConfiguration()] * 8\n#         )\n\n#         dtensor.initialize_accelerator_system(\"cpu\")\n\n#     def tearDown(self) -> None:\n#         super().tearDown()\n#         dtensor.shutdown_accelerator_system()\n\n#     def test_list_devices(self):\n#         self.assertEqual(len(distribution_lib.list_devices()), 8)\n#         self.assertEqual(len(distribution_lib.list_devices(\"cpu\")), 8)\n#         self.assertEqual(len(distribution_lib.list_devices(\"cpu\")), 8)\n\n#     def test_to_dtensor_mesh(self):\n#         devices = [f\"cpu:{i}\" for i in range(8)]\n#         shape = (4, 2)\n#         axis_names = [\"batch\", \"model\"]\n\n#         mesh = distribution_lib.DeviceMesh(shape, axis_names, devices)\n#         dtensor_mesh = backend_dlib._to_dtensor_mesh(mesh)\n\n#         self.assertIsInstance(dtensor_mesh, dtensor.Mesh)\n#         self.assertEqual(dtensor_mesh.shape(), list(shape))\n#         self.assertEqual(dtensor_mesh.dim_names, axis_names)\n\n#     def test_to_dtensor_layout(self):\n#         axes = [\"data\", None]\n#         mesh = distribution_lib.DeviceMesh(\n#             (4, 2), [\"data\", \"model\"], [f\"cpu:{i}\" for i in range(8)]\n#         )\n#         layout = distribution_lib.TensorLayout(axes, mesh)\n#         dtensor_layout = backend_dlib._to_dtensor_layout(layout)\n#         dtensor_mesh = backend_dlib._to_dtensor_mesh(mesh)\n#         self.assertEqual(\n#             dtensor_layout,\n#             dtensor.Layout([\"data\", dtensor.UNSHARDED], dtensor_mesh),\n#         )\n\n#     def test_validation_for_device_mesh(self):\n#         axes = [\"data\", None]\n#         layout = 
distribution_lib.TensorLayout(axes, device_mesh=None)\n\n#         with self.assertRaisesRegex(\n#             ValueError, \"Cannot create sharding when device mesh is not set\"\n#         ):\n#             backend_dlib._to_dtensor_layout(layout)\n"
  },
  {
    "path": "keras/src/dtype_policies/__init__.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.dtype_policies import dtype_policy\nfrom keras.src.dtype_policies.dtype_policy import QUANTIZATION_MODES\nfrom keras.src.dtype_policies.dtype_policy import AWQDTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import DTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import FloatDTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import Int4DTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy\nfrom keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap\n\nALL_OBJECTS = {\n    AWQDTypePolicy,\n    DTypePolicy,\n    FloatDTypePolicy,\n    QuantizedDTypePolicy,\n    QuantizedFloat8DTypePolicy,\n    DTypePolicyMap,\n    GPTQDTypePolicy,\n    Int4DTypePolicy,\n}\nALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}\n\n\n@keras_export(\"keras.dtype_policies.serialize\")\ndef serialize(dtype_policy):\n    \"\"\"Serializes `DTypePolicy` instance.\n\n    Args:\n        dtype_policy: A Keras `DTypePolicy` instance.\n\n    Returns:\n        `DTypePolicy` configuration dictionary.\n    \"\"\"\n    from keras.src.saving import serialization_lib\n\n    return serialization_lib.serialize_keras_object(dtype_policy)\n\n\n@keras_export(\"keras.dtype_policies.deserialize\")\ndef deserialize(config, custom_objects=None):\n    \"\"\"Deserializes a serialized `DTypePolicy` instance.\n\n    Args:\n        config: `DTypePolicy` configuration.\n        custom_objects: Optional dictionary mapping names (strings) to custom\n            objects (classes and functions) to be considered during\n            deserialization.\n\n    Returns:\n        A Keras `DTypePolicy` instance.\n    \"\"\"\n    from keras.src.saving import serialization_lib\n\n    return 
serialization_lib.deserialize_keras_object(\n        config,\n        module_objects=ALL_OBJECTS_DICT,\n        custom_objects=custom_objects,\n    )\n\n\n@keras_export(\"keras.dtype_policies.get\")\ndef get(identifier):\n    \"\"\"Retrieves a Keras `DTypePolicy` instance.\n\n    The `identifier` may be the string name of a `DTypePolicy` class.\n\n    >>> policy = dtype_policies.get(\"mixed_bfloat16\")\n    >>> type(policy)\n    <class '...DTypePolicy'>\n\n    You can also specify the `config` of the dtype policy to this function by\n    passing a dict containing `class_name` and `config` as the identifier. Also\n    note that the `class_name` must map to a `DTypePolicy` class.\n\n    >>> identifier = {\"class_name\": \"DTypePolicy\",\n    ...               \"config\": {\"name\": \"float32\"}}\n    >>> policy = dtype_policies.get(identifier)\n    >>> type(policy)\n    <class '...DTypePolicy'>\n\n    Args:\n        identifier: A dtype policy identifier. One of `None`, the string name\n            of a `DTypePolicy`, a `DTypePolicy` configuration dictionary, or a\n            `DTypePolicy` instance.\n\n    Returns:\n        A Keras `DTypePolicy` instance.\n    \"\"\"\n    from keras.src.dtype_policies.dtype_policy import (\n        _get_quantized_dtype_policy_by_str,\n    )\n\n    if identifier is None:\n        return dtype_policy.dtype_policy()\n    if isinstance(identifier, DTypePolicy):\n        return identifier\n    if isinstance(identifier, dict):\n        return deserialize(identifier)\n    if isinstance(identifier, str):\n        if identifier.startswith(QUANTIZATION_MODES):\n            return _get_quantized_dtype_policy_by_str(identifier)\n        else:\n            return DTypePolicy(identifier)\n    try:\n        return DTypePolicy(backend.standardize_dtype(identifier))\n    except Exception:\n        raise ValueError(\n            \"Cannot interpret `dtype` argument. Expected a string \"\n            f\"or an instance of DTypePolicy. 
Received: dtype={identifier}\"\n        )\n"
  },
  {
    "path": "keras/src/dtype_policies/dtype_policy.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\n\nQUANTIZATION_MODES = (\"int8\", \"float8\", \"int4\", \"gptq\", \"awq\")\n\n\n@keras_export(\n    [\n        \"keras.DTypePolicy\",\n        \"keras.dtype_policies.DTypePolicy\",\n        \"keras.mixed_precision.DTypePolicy\",  # Legacy\n        \"keras.mixed_precision.Policy\",  # Legacy\n    ]\n)\nclass DTypePolicy:\n    \"\"\"A dtype policy for a Keras layer.\n\n    A dtype policy determines a layer's computation and variable dtypes. Each\n    layer has a policy. Policies can be passed to the `dtype` argument of layer\n    constructors, or a global policy can be set with\n    `keras.config.set_dtype_policy`.\n\n    Args:\n        name: The policy name, which determines the compute and variable dtypes.\n            Can be any dtype name, such as `\"float32\"` or `\"float64\"`,\n            which causes both the compute and variable dtypes\n            to be that dtype.\n            Can also be the string `\"mixed_float16\"` or `\"mixed_bfloat16\"`,\n            which causes the compute dtype to be `float16` or `bfloat16`\n            and the variable dtype to be `float32`.\n\n    Typically you only need to interact with dtype policies when using mixed\n    precision, which is the use of float16 or bfloat16 for computations and\n    float32 for variables. This is why the term `mixed_precision` appears in the\n    API name. 
Mixed precision can be enabled by passing `\"mixed_float16\"` or\n    `\"mixed_bfloat16\"` to `keras.mixed_precision.set_dtype_policy()`.\n\n    >>> keras.config.set_dtype_policy(\"mixed_float16\")\n    >>> layer1 = keras.layers.Dense(10)\n    >>> layer1.dtype_policy  # layer1 will automatically use mixed precision\n    <DTypePolicy \"mixed_float16\">\n    >>> # Can optionally override layer to use float32\n    >>> # instead of mixed precision.\n    >>> layer2 = keras.layers.Dense(10, dtype=\"float32\")\n    >>> layer2.dtype_policy\n    <DTypePolicy \"float32\">\n    >>> # Set policy back to initial float32.\n    >>> keras.config.set_dtype_policy('float32')\n\n    In the example above, passing `dtype=\"float32\"` to the layer is\n    equivalent to passing\n    `dtype=keras.config.DTypePolicy(\"float32\")`.\n    In general, passing a dtype policy name to a layer is equivalent\n    to passing the corresponding policy, so it is never necessary\n    to explicitly construct a `DTypePolicy` object.\n    \"\"\"\n\n    def __init__(self, name=None):\n        # Use the global dtype policy if `name` is not specified\n        if name is None:\n            name = dtype_policy().name\n        self._name = name\n        self._compute_dtype, self._variable_dtype = self._parse_name(name)\n        self._quantization_mode = None\n\n    def _parse_name(self, name):\n        \"\"\"Parses a `DTypePolicy` name into a compute and variable dtype.\n\n        Args:\n            name: The name of the policy.\n\n        Returns:\n            The `(compute_dtype, variable_dtype)` pair.\n        \"\"\"\n        if not isinstance(name, str):\n            raise TypeError(\n                \"'name' must be a string, such as 'mixed_float16'. 
\"\n                f\"Received: name={name} (of type {type(name)})\"\n            )\n        if name == \"mixed_float16\":\n            return \"float16\", \"float32\"\n        elif name == \"mixed_bfloat16\":\n            return \"bfloat16\", \"float32\"\n        try:\n            dtype = backend.standardize_dtype(name)\n            return dtype, dtype\n        except ValueError:\n            raise ValueError(\n                f\"Cannot convert '{name}' to a mixed precision \"\n                \"DTypePolicy. Valid policies include 'mixed_float16', \"\n                \"'mixed_bfloat16', and the name of any float dtype such as \"\n                \"'float32'.\"\n            )\n\n    @property\n    def variable_dtype(self):\n        \"\"\"The variable dtype of this policy.\n\n        This is the dtype layers will create their variables in, unless a layer\n        explicitly chooses a different dtype. If this is different than\n        `DTypePolicy.compute_dtype`, Layers will cast variables to\n        the compute dtype to avoid type errors.\n\n        Variable regularizers are run in the variable dtype, not the compute\n        dtype.\n\n        Returns:\n            The variable dtype of this policy, as a string.\n        \"\"\"\n        return self._variable_dtype\n\n    @property\n    def compute_dtype(self):\n        \"\"\"The compute dtype of this policy.\n\n        This is the dtype layers will do their computations in. Typically layers\n        output tensors with the compute dtype as well.\n\n        Note that even if the compute dtype is float16 or bfloat16, hardware\n        devices may not do individual adds, multiplies, and other fundamental\n        operations in float16 or bfloat16, but instead may do some of them in\n        float32 for numeric stability. 
The compute dtype is the dtype of the\n        inputs and outputs of the ops that the layer executes.\n        Internally, many ops will do certain internal calculations in\n        float32 or some other device-internal intermediate format with higher\n        precision than float16/bfloat16, to increase numeric stability.\n\n        Returns:\n            The compute dtype of this policy, as a string.\n        \"\"\"\n        return self._compute_dtype\n\n    @property\n    def name(self):\n        \"\"\"Returns the name of this policy.\"\"\"\n        return self._name\n\n    @property\n    def quantization_mode(self):\n        \"\"\"The quantization mode of this policy.\n\n        Returns:\n            The quantization mode of this policy, as a string. If this policy is\n            not quantized, it will return `None`.\n        \"\"\"\n        return self._quantization_mode\n\n    def convert_input(self, x, autocast, dtype):\n        \"\"\"Converts the input dtype based on `autocast` and `dtype`.\n\n        Note that `x` can be a tensor, symbolic tensor or numpy array, and this\n        method will keep integer inputs untouched and only apply casting to\n        floats.\n        \"\"\"\n\n        dtype = backend.standardize_dtype(dtype)\n        if backend.is_tensor(x):\n            if self._should_cast(x, autocast, dtype):\n                x = backend.cast(x, dtype=dtype)\n            return x\n        elif backend.is_keras_tensor(x):\n            if self._should_cast(x, autocast, dtype):\n                x = ops.cast(x, dtype=dtype)\n            return x\n        elif hasattr(x, \"__array__\"):\n            try:\n                x = backend.convert_to_tensor(x)\n            except TypeError:\n                x = backend.convert_to_tensor(x, dtype=dtype)\n            if self._should_cast(x, autocast, dtype):\n                x = backend.cast(x, dtype=dtype)\n            return x\n        return x\n\n    def get_config(self):\n        return {\"name\": 
self.name}\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n\n    def __repr__(self):\n        class_name = self.__class__.__name__\n        if class_name == \"FloatDTypePolicy\":\n            class_name = \"DTypePolicy\"\n        return f'<{class_name} \"{self._name}\">'\n\n    def __eq__(self, other):\n        if self.__class__ in (DTypePolicy, FloatDTypePolicy):\n            if type(other) not in (DTypePolicy, FloatDTypePolicy):\n                return False\n        else:\n            if type(other) is not self.__class__:\n                return False\n        return self._name == other._name\n\n    def _should_cast(self, x, autocast, dtype):\n        x_dtype = backend.standardize_dtype(x.dtype)\n        if autocast and backend.is_float_dtype(x_dtype) and x_dtype != dtype:\n            return True\n        else:\n            return False\n\n\n@keras_export(\n    [\"keras.FloatDTypePolicy\", \"keras.dtype_policies.FloatDTypePolicy\"]\n)\nclass FloatDTypePolicy(DTypePolicy):\n    # An alias for `DTypePolicy`\n    pass\n\n\n@keras_export(\"keras.dtype_policies.QuantizedDTypePolicy\")\nclass QuantizedDTypePolicy(DTypePolicy):\n    def __init__(self, mode, source_name=None):\n        # Use the global dtype policy if `source_name` is not specified\n        if source_name is None:\n            source_name = dtype_policy().name\n        name = f\"{mode}_from_{source_name}\"\n        self._compute_dtype, self._variable_dtype = self._parse_name(\n            source_name\n        )\n        self._check_quantization_mode(mode, self._compute_dtype)\n\n        self._name = name\n        self._source_name = source_name\n        self._quantization_mode = mode\n\n    def __eq__(self, other):\n        if super().__eq__(other) is False:\n            return False\n        return (\n            self._quantization_mode == other._quantization_mode\n            and self._source_name == other._source_name\n        )\n\n    def get_config(self):\n    
    return {\n            \"mode\": self._quantization_mode,\n            \"source_name\": self._source_name,\n        }\n\n    def _check_quantization_mode(self, mode, compute_dtype):\n        if mode not in QUANTIZATION_MODES:\n            raise ValueError(\n                \"Invalid quantization mode. \"\n                f\"Expected one of {QUANTIZATION_MODES}. \"\n                f\"Received: mode={mode}\"\n            )\n        if compute_dtype == \"float16\" and mode == \"int8\":\n            raise ValueError(\n                f\"Quantization mode='{mode}' doesn't work well with \"\n                \"compute_dtype='float16'.\"\n            )\n\n\n@keras_export(\"keras.dtype_policies.QuantizedFloat8DTypePolicy\")\nclass QuantizedFloat8DTypePolicy(QuantizedDTypePolicy):\n    default_amax_history_length = 1024\n\n    def __init__(self, mode, source_name=None, amax_history_length=1024):\n        super().__init__(mode=mode, source_name=source_name)\n        if not isinstance(amax_history_length, int):\n            raise TypeError(\n                \"`amax_history_length` must be an integer. 
\"\n                f\"Received: amax_history_length={amax_history_length}\"\n            )\n        self._amax_history_length = amax_history_length\n\n    @property\n    def amax_history_length(self):\n        \"\"\"The length of the amax history window.\n\n        This property is used for scaling factor computation in float8 training.\n        \"\"\"\n        return self._amax_history_length\n\n    def __eq__(self, other):\n        if super().__eq__(other) is False:\n            return False\n        return self._amax_history_length == other._amax_history_length\n\n    def get_config(self):\n        config = super().get_config()\n        config.update({\"amax_history_length\": self.amax_history_length})\n        return config\n\n\n@keras_export(\"keras.dtype_policies.Int4DTypePolicy\")\nclass Int4DTypePolicy(QuantizedDTypePolicy):\n    \"\"\"Quantized dtype policy for int4 quantization.\n\n    This policy helps propagate quantization settings for int4 sub-channel\n    quantization when loading a quantized model in Keras format.\n\n    Args:\n        mode: The quantization mode. This should be a string in the format\n            `\"int4/<block_size>\"`.\n            -   `\"int4\"`: The identifier for the quantization algorithm.\n            -   `<block_size>`: The block size for sub-channel quantization.\n                Use -1 for per-channel (legacy) quantization. Any positive\n                integer enables sub-channel quantization with that block size.\n            Example: `\"int4/128\"` for sub-channel with 128-element groups.\n        source_name: The source dtype policy name, e.g. \"float32\".\n    \"\"\"\n\n    def __init__(\n        self,\n        mode,\n        source_name=None,\n    ):\n        parts = mode.split(\"/\")\n        expected_format = \"'int4/<block_size>'\"\n\n        # Validate format\n        if len(parts) != 2 or parts[0] != \"int4\":\n            raise ValueError(\n                \"Invalid mode for Int4DTypePolicy. 
Expected format \"\n                f\"{expected_format}, but got '{mode}'.\"\n            )\n\n        # Validate and cast block_size\n        try:\n            block_size = int(parts[1])\n        except ValueError:\n            raise ValueError(\n                \"Invalid mode for Int4DTypePolicy. <block_size> must be an \"\n                f\"integer. Expected format {expected_format}, but got '{mode}'.\"\n            )\n\n        # Validate supported values\n        if block_size < -1 or block_size == 0:\n            raise ValueError(\n                \"Invalid block_size in mode. Supported values are \"\n                \"-1 (per-channel) or a positive integer (sub-channel), \"\n                f\"but got {block_size} from '{mode}'.\"\n            )\n\n        base_mode = parts[0]\n        super().__init__(\n            mode=base_mode,\n            source_name=source_name,\n        )\n\n        self._name = f\"{mode}_from_{source_name}\"\n        self.mode = base_mode\n        self.block_size = block_size\n\n    def __eq__(self, other):\n        if super().__eq__(other) is False:\n            return False\n        return self.block_size == other.block_size\n\n    def get_config(self):\n        config = super().get_config()\n        # Reconstruct the full mode string for serialization\n        mode = f\"{self.mode}/{self.block_size}\"\n        config.update({\"mode\": mode})\n        return config\n\n\n@keras_export(\"keras.dtype_policies.GPTQDTypePolicy\")\nclass GPTQDTypePolicy(QuantizedDTypePolicy):\n    \"\"\"Quantized dtype policy for GPTQ quantization.\n\n    This policy helps propagate quantization settings for GPTQ\n    when loading a GPTQ quantized model in Keras format.\n\n    Args:\n        mode: The quantization mode. 
This should be a string in the format\n            `\"gptq/<weight_bits>/<group_size>\"`.\n            -   `\"gptq\"`: The identifier for the quantization algorithm.\n            -   `<weight_bits>`: Number of bits to quantize weights to.\n                Supported values are 2, 3, 4, and 8.\n            -   `<group_size>`: The group size for quantization. Supported\n                values are -1 (for whole-tensor quantization) or any\n                positive integer. Typically a smaller group size leads\n                to better accuracy but slower speed.\n            Example: `\"gptq/4/128\"`.\n        source_name: The source dtype policy name, e.g. \"float32\".\n    \"\"\"\n\n    def __init__(\n        self,\n        mode,\n        source_name=None,\n    ):\n        parts = mode.split(\"/\")\n        expected_format = \"'gptq/<weight_bits>/<group_size>'\"\n\n        # Validate format\n        if len(parts) != 3 or parts[0] != \"gptq\":\n            raise ValueError(\n                \"Invalid mode for GPTQDTypePolicy. Expected format \"\n                f\"{expected_format}, but got '{mode}'.\"\n            )\n\n        # Validate and cast weight_bits and group_size\n        try:\n            weight_bits = int(parts[1])\n            group_size = int(parts[2])\n        except ValueError:\n            raise ValueError(\n                \"Invalid mode for GPTQDTypePolicy. <weight_bits> and \"\n                \"<group_size> must be integers. Expected format \"\n                f\"{expected_format}, but got '{mode}'.\"\n            )\n\n        # Validate supported values\n        if weight_bits not in [2, 3, 4, 8]:\n            raise ValueError(\n                \"Invalid weight_bits in mode. Supported values are \"\n                f\"2, 3, 4, and 8, but got {weight_bits} from '{mode}'.\"\n            )\n\n        if group_size < -1 or group_size == 0:\n            raise ValueError(\n                \"Invalid group_size in mode. 
Supported values are \"\n                \"-1 (whole-tensor) or a positive integer, \"\n                f\"but got {group_size} from '{mode}'.\"\n            )\n\n        base_mode = parts[0]\n        super().__init__(\n            mode=base_mode,\n            source_name=source_name,\n        )\n\n        self._name = f\"{mode}_from_{source_name}\"\n        self.mode = base_mode\n        self.weight_bits = weight_bits\n        self.group_size = group_size\n\n    def __eq__(self, other):\n        if super().__eq__(other) is False:\n            return False\n        return (\n            self.weight_bits == other.weight_bits\n            and self.group_size == other.group_size\n        )\n\n    def get_config(self):\n        config = super().get_config()\n        # Reconstruct the full mode string for serialization\n        mode = f\"{self.mode}/{self.weight_bits}/{self.group_size}\"\n        config.update({\"mode\": mode})\n        return config\n\n\n@keras_export(\"keras.dtype_policies.AWQDTypePolicy\")\nclass AWQDTypePolicy(QuantizedDTypePolicy):\n    \"\"\"Quantized dtype policy for AWQ quantization.\n\n    This policy helps propagate quantization settings for AWQ\n    when loading an AWQ quantized model in Keras format.\n\n    Args:\n        mode: The quantization mode. This should be a string in the format\n            `\"awq/<weight_bits>/<group_size>\"`.\n            -   `\"awq\"`: The identifier for the quantization algorithm.\n            -   `<weight_bits>`: Number of bits to quantize weights to.\n                AWQ presently only supports 4-bit quantization.\n            -   `<group_size>`: The group size for quantization. Supported\n                values are -1 (for per-channel quantization) or any\n                positive integer.\n            Example: `\"awq/4/128\"`.\n        source_name: The source dtype policy name, e.g. 
\"float32\".\n    \"\"\"\n\n    def __init__(\n        self,\n        mode,\n        source_name=None,\n    ):\n        parts = mode.split(\"/\")\n        expected_format = \"'awq/<weight_bits>/<group_size>'\"\n\n        # Validate format.\n        if len(parts) != 3 or parts[0] != \"awq\":\n            raise ValueError(\n                \"Invalid mode for AWQDTypePolicy. Expected format \"\n                f\"{expected_format}, but got '{mode}'.\"\n            )\n\n        # Validate and cast weight_bits and group_size.\n        try:\n            weight_bits = int(parts[1])\n            group_size = int(parts[2])\n        except ValueError:\n            raise ValueError(\n                \"Invalid mode for AWQDTypePolicy. <weight_bits> and \"\n                \"<group_size> must be integers. Expected format \"\n                f\"{expected_format}, but got '{mode}'.\"\n            )\n\n        # AWQ presently only supports 4-bit quantization.\n        if weight_bits != 4:\n            raise ValueError(\n                \"Invalid weight_bits in mode. AWQ only supports 4-bit \"\n                f\"quantization, but got {weight_bits} from '{mode}'.\"\n            )\n\n        if group_size < -1 or group_size == 0:\n            raise ValueError(\n                \"Invalid group_size in mode. 
Supported values are \"\n                \"-1 (per-channel) or a positive integer, \"\n                f\"but got {group_size} from '{mode}'.\"\n            )\n\n        base_mode = parts[0]\n        super().__init__(\n            mode=base_mode,\n            source_name=source_name,\n        )\n\n        self._name = f\"{mode}_from_{source_name}\"\n        self.mode = base_mode\n        self.weight_bits = weight_bits\n        self.group_size = group_size\n\n    def __eq__(self, other):\n        if super().__eq__(other) is False:\n            return False\n        return (\n            self.weight_bits == other.weight_bits\n            and self.group_size == other.group_size\n        )\n\n    def get_config(self):\n        config = super().get_config()\n        # Reconstruct the full mode string for serialization\n        mode = f\"{self.mode}/{self.weight_bits}/{self.group_size}\"\n        config.update({\"mode\": mode})\n        return config\n\n\n@keras_export(\n    [\n        \"keras.config.set_dtype_policy\",\n        \"keras.mixed_precision.set_dtype_policy\",  # Legacy\n        \"keras.mixed_precision.set_global_policy\",  # Legacy\n    ]\n)\ndef set_dtype_policy(policy):\n    \"\"\"Sets the default dtype policy globally.\n\n    Example:\n\n    >>> keras.config.set_dtype_policy(\"mixed_float16\")\n    \"\"\"\n    if not isinstance(policy, DTypePolicy):\n        if isinstance(policy, str):\n            if policy.startswith(QUANTIZATION_MODES):\n                policy = _get_quantized_dtype_policy_by_str(policy)\n            else:\n                policy = DTypePolicy(policy)\n        else:\n            raise ValueError(\n                \"Invalid `policy` argument. \"\n                \"Expected the string name of a policy \"\n                \"(such as 'mixed_float16') or a `DTypePolicy` \"\n                f\"instance. 
Received: policy={policy} \"\n                f\"(of type {type(policy)})\"\n            )\n    global_state.set_global_attribute(\"dtype_policy\", policy)\n\n\n@keras_export(\n    [\n        \"keras.config.dtype_policy\",\n        \"keras.mixed_precision.dtype_policy\",  # Legacy\n        \"keras.mixed_precision.global_policy\",  # Legacy\n    ]\n)\ndef dtype_policy():\n    \"\"\"Returns the current default dtype policy object.\"\"\"\n    policy = global_state.get_global_attribute(\"dtype_policy\", None)\n    if policy is None:\n        policy = DTypePolicy(backend.floatx())\n        set_dtype_policy(policy)\n    return policy\n\n\ndef _get_quantized_dtype_policy_by_str(policy):\n    if not isinstance(policy, str):\n        raise TypeError(f\"`policy` must be a string. Received: policy={policy}\")\n    if not policy.startswith(QUANTIZATION_MODES):\n        raise ValueError(\n            \"`policy` is incompatible with the current supported quantization.\"\n        )\n    split_name = policy.split(\"_from_\")\n    if len(split_name) != 2:\n        raise ValueError(\n            \"Cannot convert `policy` into a valid pair (`mode`, `source_name`) \"\n            \"to instantiate `QuantizedDTypePolicy`. 
\"\n            f\"Received: policy={policy}\"\n        )\n    mode, source_name = split_name\n    if policy.startswith(\"int8\"):\n        return QuantizedDTypePolicy(mode, source_name)\n    elif policy.startswith(\"int4\"):\n        # Check if mode has block_size component (e.g., \"int4/128\")\n        if \"/\" in mode:\n            return Int4DTypePolicy(mode, source_name)\n        else:\n            return QuantizedDTypePolicy(mode, source_name)\n    elif policy.startswith(\"gptq\"):\n        return GPTQDTypePolicy(mode, source_name)\n    elif policy.startswith(\"awq\"):\n        return AWQDTypePolicy(mode, source_name)\n    elif policy.startswith(\"float8\"):\n        return QuantizedFloat8DTypePolicy(mode, source_name)\n    else:\n        raise NotImplementedError\n"
  },
  {
    "path": "keras/src/dtype_policies/dtype_policy_map.py",
    "content": "import re\nfrom collections.abc import MutableMapping\n\nfrom keras.src import dtype_policies\nfrom keras.src.api_export import keras_export\nfrom keras.src.dtype_policies import DTypePolicy\n\n\n@keras_export([\"keras.dtype_policies.DTypePolicyMap\"])\nclass DTypePolicyMap(DTypePolicy, MutableMapping):\n    \"\"\"Dict-like object mapping layer paths to `DTypePolicy` instances.\n\n    `DTypePolicyMap` can be used in `get_config` in layers and subclasses to\n    support a complex configurations of dtype policies.\n\n    For example, we can modify `get_config` in `layers.MultiHeadAttention` as\n    follows to support the mixing of dtype policies, such as quantization.\n\n    ```python\n    @keras.saving.register_keras_serializable(\"MyPackage\")\n    class MyMultiHeadAttention(keras.layers.MultiHeadAttention):\n        def get_config(self):\n            config = super().get_config()\n            dtype_policy_map = dtype_policies.DTypePolicyMap()\n            for layer in self._flatten_layers():\n                if layer.dtype_policy.quantization_mode is not None:\n                    dtype_policy_map[layer.path] = layer.dtype_policy\n            if len(dtype_policy_map) > 0:\n                config.update({\"dtype\": dtype_policy_map})\n            return config\n    ```\n\n    Internally, `DTypePolicyMap` uses a string as a key and a `DTypePolicy`\n    as the value. Typically, the key used for querying is the `Layer.path`.\n    However, it is also possible to set a regex as the key. See the docstring of\n    `get` for more details.\n\n    Args:\n        default_policy: An optional `DTypePolicy` instance specifying the\n            default dtype policy. If not specified, the value will default to\n            `keras.config.dtype_policy()`.\n        policy_map: An optional dict that maps string to `DTypePolicy`\n            instances. 
Defaults to `None`\n\n    Example:\n\n    ```python\n    >>> from keras.src import dtype_policies\n    >>> bfloat16 = dtype_policies.DTypePolicy(\"bfloat16\")\n    >>> float16 = dtype_policies.DTypePolicy(\"float16\")\n    >>> float32 = dtype_policies.DTypePolicy(\"float32\")\n    >>> policy_map = DTypePolicyMap(default_policy=float32)\n\n    # Set policies using an exact path and a regex pattern.\n    # Note: \"decoder\" will only match the exact path, not its children.\n    >>> policy_map[\"encoder/layer_0/dense\"] = bfloat16\n    >>> policy_map[\"encoder/.*\"] = float16\n    >>> policy_map[\"decoder\"] = bfloat16\n\n    # 1. An exact match is found and returned directly.\n    >>> policy_map[\"encoder/layer_0/dense\"].name\n    'bfloat16'\n\n    # 2. A regex match is found for a child layer.\n    # It matches the \"encoder/.*\" pattern.\n    >>> policy_map[\"encoder/attention/query\"].name\n    'float16'\n\n    # 3. No implicit prefix matching occurs.\n    # \"decoder/attention\" does not match the key \"decoder\".\n    # The default policy is returned.\n    >>> policy_map[\"decoder/attention\"].name\n    'float32'\n\n    # 4. A ValueError is raised if a path matches multiple patterns.\n    >>> policy_map[\"encoder/attention/.*\"] = bfloat16\n    # \"encoder/attention/query\" now matches two patterns:\n    # - \"encoder/.*\"\n    # - \"encoder/attention/.*\"\n    >>> try:\n    ...     policy_map[\"encoder/attention/query\"]\n    ... except ValueError as e:\n    ...     print(e)\n    Path 'encoder/attention/query' matches multiple dtype policy ..\n    ```\n    \"\"\"\n\n    def __init__(self, default_policy=None, policy_map=None):\n        if isinstance(default_policy, DTypePolicyMap):\n            raise ValueError(\"`default_policy` cannot be a `DTypePolicyMap`.\")\n        if policy_map is not None and not isinstance(policy_map, dict):\n            raise TypeError(\n                \"If specified, `policy_map` must be a dict. 
\"\n                f\"Received: policy_map={policy_map} of type {type(policy_map)}\"\n            )\n        self._default_policy_arg = default_policy\n        self._default_policy = dtype_policies.get(default_policy)\n        self._policy_map = policy_map or dict()\n\n    @property\n    def name(self):\n        return f\"map_{self.default_policy._name}\"\n\n    @property\n    def default_policy(self):\n        \"\"\"The default dtype policy.\n\n        If `default_policy` is not specified in the constructor, this property\n        will be `keras.config.dtype_policy()`.\n        \"\"\"\n        return dtype_policies.get(self._default_policy)\n\n    @property\n    def variable_dtype(self):\n        return self.default_policy.variable_dtype\n\n    @property\n    def compute_dtype(self):\n        return self.default_policy.compute_dtype\n\n    @property\n    def quantization_mode(self):\n        return self.default_policy.quantization_mode\n\n    def __getitem__(self, key):\n        \"\"\"Retrieves the corresponding `DTypePolicy` by the string key.\n\n        This method first attempts an exact key match. If no exact match is\n        found, it treats all keys in the map as regular expression patterns\n        and uses `re.fullmatch` to find a policy.\n\n        For example, to apply a policy to all sublayers of an `encoder` block,\n        the key should be explicitly set to `\"encoder/.*\"`. A key of\n        `\"encoder\"` will only match the layer with that exact path.\n\n        Args:\n            key: str. The key to query for a `DTypePolicy`.\n\n        Returns:\n            The corresponding `DTypePolicy`. 
If no match is found, this method\n            returns `self.default_policy`.\n\n        Raises:\n            ValueError: If the `key` matches more than one regex pattern in the\n            map.\n\n        Example:\n\n        ```python\n        >>> from keras.src import dtype_policies\n        >>> bfloat16 = dtype_policies.DTypePolicy(\"bfloat16\")\n        >>> float16 = dtype_policies.DTypePolicy(\"float16\")\n        >>> float32 = dtype_policies.DTypePolicy(\"float32\")\n        >>> policy_map = DTypePolicyMap(default_policy=float32)\n\n        # Set policies using an exact path and a regex pattern.\n        # Note: \"decoder\" will only match the exact path, not its children.\n        >>> policy_map[\"encoder/layer_0/dense\"] = bfloat16\n        >>> policy_map[\"encoder/.*\"] = float16\n        >>> policy_map[\"decoder\"] = bfloat16\n\n        # 1. An exact match is found and returned directly.\n        >>> policy_map[\"encoder/layer_0/dense\"].name\n        'bfloat16'\n\n        # 2. A regex match is found for a child layer.\n        # It matches the \"encoder/.*\" pattern.\n        >>> policy_map[\"encoder/attention/query\"].name\n        'float16'\n\n        # 3. No implicit prefix matching occurs.\n        # \"decoder/attention\" does not match the key \"decoder\".\n        # The default policy is returned.\n        >>> policy_map[\"decoder/attention\"].name\n        'float32'\n\n        # 4. A ValueError is raised if a path matches multiple patterns.\n        >>> policy_map[\"encoder/attention/.*\"] = bfloat16\n        # \"encoder/attention/query\" now matches two patterns:\n        # - \"encoder/.*\"\n        # - \"encoder/attention/.*\"\n        >>> try:\n        ...     policy_map[\"encoder/attention/query\"]\n        ... except ValueError as e:\n        ...     print(e)\n        Path 'encoder/attention/query' matches multiple dtype policy ..\n        ```\n        \"\"\"\n        # 1. 
Check for an exact match.\n        if key in self._policy_map:\n            return self._policy_map[key]\n\n        # 2. Fallback to a full regex match.\n        matching_keys = [\n            pattern\n            for pattern in self._policy_map\n            if re.fullmatch(pattern, key)\n        ]\n\n        # 3. Handle cases based on the number of matches found.\n        if len(matching_keys) > 1:\n            raise ValueError(\n                f\"Path '{key}' matches multiple dtype policy \"\n                f\"specification keys: {matching_keys}. Please make \"\n                \"sure each path only matches at most \"\n                \"one dtype policy specification key in the DTypePolicyMap.\"\n            )\n        elif len(matching_keys) == 1:\n            return self._policy_map[matching_keys[0]]\n\n        # 4. If there were no matches, return the default.\n        return self.default_policy\n\n    def __setitem__(self, key, policy):\n        \"\"\"Insert `DTypePolicy` to the `DTypePolicyMap`.\n\n        Args:\n            key: String key for the `DTypePolicy`.\n            policy: The `DTypePolicy`.\n        \"\"\"\n        if key in self._policy_map:\n            raise ValueError(\n                f\"{key} already exist in the DTypePolicyMap with \"\n                f\"value {self._policy_map[key]}. Please make sure to \"\n                \"not use duplicated keys.\"\n            )\n        try:\n            policy = dtype_policies.get(policy)\n        except Exception:\n            raise ValueError(\n                \"Cannot interpret the assigned value by \"\n                \"`keras.dtype_policies.get`. 
\"\n                f\"Received: {policy} of type {type(policy)}\"\n            )\n        self._policy_map[key] = policy\n\n    def __delitem__(self, key):\n        # Let the dict handle the missing key error\n        return self._policy_map.pop(key)\n\n    def __contains__(self, key):\n        return key in self._policy_map\n\n    def get_config(self):\n        from keras.src.saving import serialization_lib\n\n        policy_map = self._policy_map\n        if self._default_policy_arg is None:\n            # `default_policy=None` enables us to defer to\n            # `keras.config.dtype_policy()` during loading.\n            # To support this feature, we can set `_name` and `_source_name` to\n            # `None` in `DTypePolicy` and `QuantizedDTypePolicy`,\n            # respectively.\n            for policy in policy_map.values():\n                if isinstance(policy, dtype_policies.QuantizedDTypePolicy):\n                    policy._name = None\n                    policy._source_name = None\n                elif isinstance(policy, dtype_policies.DTypePolicy):\n                    policy._name = None\n        return {\n            \"default_policy\": self._default_policy_arg,\n            \"policy_map\": serialization_lib.serialize_keras_object(policy_map),\n        }\n\n    @classmethod\n    def from_config(cls, config, custom_objects=None):\n        from keras.src.saving import serialization_lib\n\n        config = config.copy()\n        config[\"policy_map\"] = serialization_lib.deserialize_keras_object(\n            config[\"policy_map\"], custom_objects=custom_objects\n        )\n        return cls(**config)\n\n    def __len__(self):\n        return len(self._policy_map)\n\n    def __iter__(self):\n        return iter(self._policy_map)\n\n    def __repr__(self):\n        default_policy = (\n            self._default_policy.name\n            if self._default_policy is not None\n            else None\n        )\n        mapping = []\n
        for k, v in self._policy_map.items():\n            mapping.append((k, v.name))\n        return (\n            f\"<DTypePolicyMap at {hex(id(self))} \"\n            f\"default_policy={default_policy}, \"\n            f\"mapping={mapping}>\"\n        )\n"
  },
  {
    "path": "keras/src/dtype_policies/dtype_policy_map_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import dtype_policies\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.dtype_policies.dtype_policy import dtype_policy\nfrom keras.src.dtype_policies.dtype_policy import set_dtype_policy\nfrom keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap\n\n\n@pytest.mark.skipif(testing.jax_uses_gpu(), reason=\"Leads to core dumps on CI\")\nclass DTypePolicyMapTest(testing.TestCase):\n    def setUp(self):\n        super().setUp()\n        self._global_dtype_policy = dtype_policy()\n\n    def tearDown(self):\n        super().tearDown()\n        set_dtype_policy(self._global_dtype_policy)\n\n    @pytest.mark.requires_trainable_backend\n    def test_basic_usage(self):\n        # Create a subclass that might contain mixing dtype policies for\n        # sublayers.\n        # It is important to ensure that `dtype` is passed to sublayers and\n        # that each sublayer has a unique `name`.\n        @saving.register_keras_serializable()\n        class Subclass(layers.Layer):\n            def __init__(self, dtype=None, name=\"subclass\", **kwargs):\n                super().__init__(dtype=dtype, name=name, **kwargs)\n                self.dense = layers.Dense(8, dtype=dtype, name=f\"{name}_dense\")\n                self.bn = layers.BatchNormalization(\n                    dtype=dtype, name=f\"{name}_bn\"\n                )\n                self.relu = layers.ReLU(dtype=dtype, name=f\"{name}_relu\")\n\n            def call(self, inputs, training=None):\n                return self.relu(self.bn(self.dense(inputs), training=training))\n\n            def get_config(self):\n                # Typically, we only need to record the quantized policy for\n                # `DTypePolicyMap`\n                config = super().get_config()\n                dtype_policy_map = DTypePolicyMap()\n                for layer in 
self._flatten_layers():\n                    if layer.quantization_mode is not None:\n                        dtype_policy_map[layer.path] = layer.dtype_policy\n                if len(dtype_policy_map) > 0:\n                    config.update({\"dtype\": dtype_policy_map})\n                return config\n\n        # Instantiate the model\n        inputs = layers.Input([4])\n        outputs = Subclass()(inputs)\n        model = models.Model(inputs, outputs)\n\n        # Quantize the model to make mixing of dtype policies in sublayers\n        model.quantize(\"int8\")\n        for layer in model._flatten_layers():\n            if isinstance(layer, layers.Dense):\n                self.assertEqual(\n                    layer.dtype_policy,\n                    dtype_policies.QuantizedDTypePolicy(\"int8\"),\n                )\n            elif isinstance(layer, layers.BatchNormalization):\n                self.assertEqual(\n                    layer.dtype_policy, dtype_policies.DTypePolicy()\n                )\n            elif isinstance(layer, layers.ReLU):\n                self.assertEqual(\n                    layer.dtype_policy, dtype_policies.DTypePolicy()\n                )\n\n        # Verify the output after saving and loading\n        x = np.random.uniform(size=[16, 4])\n        temp_dir = self.get_temp_dir()\n        y = model(x, training=False)\n        model.save(f\"{temp_dir}/model.keras\")\n        reloaded_model = saving.load_model(f\"{temp_dir}/model.keras\")\n        reloaded_y = reloaded_model(x, training=False)\n        self.assertAllClose(y, reloaded_y)\n\n    def test_add(self):\n        dtype_policy_map = DTypePolicyMap()\n        dtype_policy_map[\"layer/dense_0\"] = dtype_policies.DTypePolicy(\n            \"bfloat16\"\n        )\n        dtype_policy_map[\"layer/dense_1\"] = dtype_policies.QuantizedDTypePolicy(\n            \"int8\", \"mixed_bfloat16\"\n        )\n        dtype_policy_map[\"layer/dense_2\"] = (\n            
dtype_policies.QuantizedFloat8DTypePolicy(\"float8\", \"mixed_float16\")\n        )\n\n        self.assertLen(dtype_policy_map, 3)\n\n        policy = dtype_policy_map[\"layer/dense_0\"]\n        self.assertIsInstance(policy, dtype_policies.DTypePolicy)\n        self.assertEqual(policy.name, \"bfloat16\")\n\n        policy = dtype_policy_map[\"layer/dense_1\"]\n        self.assertIsInstance(policy, dtype_policies.QuantizedDTypePolicy)\n        self.assertEqual(policy._source_name, \"mixed_bfloat16\")\n        self.assertEqual(policy.quantization_mode, \"int8\")\n\n        policy = dtype_policy_map[\"layer/dense_2\"]\n        self.assertIsInstance(policy, dtype_policies.QuantizedFloat8DTypePolicy)\n        self.assertEqual(policy._source_name, \"mixed_float16\")\n        self.assertEqual(policy.quantization_mode, \"float8\")\n\n        with self.assertRaisesRegex(\n            ValueError, \"layer/dense_0 already exist in the DTypePolicyMap\"\n        ):\n            dtype_policy_map[\"layer/dense_0\"] = dtype_policies.DTypePolicy(\n                \"float32\"\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"Cannot interpret the assigned value.\"\n        ):\n            dtype_policy_map[\"layer/dense_3\"] = 123\n\n    def test_get(self):\n        # 1. 
Setup\n        bfloat16_policy = dtype_policies.DTypePolicy(\"bfloat16\")\n        int8_policy = dtype_policies.QuantizedDTypePolicy(\n            \"int8\", \"mixed_bfloat16\"\n        )\n        float32_policy = dtype_policies.DTypePolicy(\"float32\")\n        float16_policy = dtype_policies.DTypePolicy(\"float16\")\n\n        policy_map = DTypePolicyMap()\n        # Policy for an exact layer path\n        policy_map[\"model/encoder/layer_0/dense\"] = bfloat16_policy\n        # Policy for a layer that is also a prefix of another layer's name\n        policy_map[\"model/encoder/attention/query\"] = int8_policy\n        # Regex policies for entire scopes MUST include wildcards\n        policy_map[\"model/decoder/.*\"] = float32_policy\n        policy_map[\"model/decoder/attention/.*\"] = float16_policy\n\n        # 2. Test exact match\n        self.assertEqual(\n            policy_map[\"model/encoder/layer_0/dense\"], bfloat16_policy\n        )\n        self.assertEqual(\n            policy_map[\"model/encoder/attention/query\"], int8_policy\n        )\n\n        # 3. Test successful regex fallback (explicit wildcard)\n        # \"model/decoder/.*\" should match its children.\n        self.assertEqual(policy_map[\"model/decoder/layer_0\"], float32_policy)\n\n        # 4. Test that partial matches are ignored\n        # The exact key \"model/encoder/attention/query\" should not match\n        # \"model/encoder/attention/query_norm\" without a wildcard.\n        self.assertEqual(\n            policy_map[\"model/encoder/attention/query_norm\"],\n            policy_map.default_policy,\n        )\n        # A plain key \"model/decoder\" will not match \"model/decoder/layer_0\"\n        policy_map[\"model/decoder\"] = bfloat16_policy  # Add exact key\n        self.assertEqual(policy_map[\"model/decoder/layer_0\"], float32_policy)\n        # Still matches the more general regex\n        self.assertEqual(policy_map[\"model/decoder\"], bfloat16_policy)\n\n        # 5. 
Test no match\n        self.assertEqual(\n            policy_map[\"model/embedding\"], policy_map.default_policy\n        )\n\n        # 6. Test multiple regex matches causing a ValueError\n        # \"model/decoder/attention/output\" matches two regex patterns:\n        # - \"model/decoder/.*\"\n        # - \"model/decoder/attention/.*\"\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Path 'model/decoder/attention/output' matches multiple \"\n            \"dtype policy\",\n        ):\n            _ = policy_map[\"model/decoder/attention/output\"]\n\n    def test_delete(self):\n        dtype_policy_map = DTypePolicyMap()\n        dtype_policy_map[\"layer/dense_0\"] = dtype_policies.DTypePolicy(\n            \"bfloat16\"\n        )\n        dtype_policy_map[\"layer/dense_1\"] = dtype_policies.QuantizedDTypePolicy(\n            \"int8\", \"mixed_bfloat16\"\n        )\n\n        self.assertEqual(\n            dtype_policy_map.pop(\"layer/dense_0\"),\n            dtype_policies.DTypePolicy(\"bfloat16\"),\n        )\n        with self.assertRaises(KeyError):\n            dtype_policy_map.pop(\"layer/dense_0\")\n\n        # Test `del`, causing no hit\n        del dtype_policy_map[\"layer/dense_1\"]\n        self.assertEqual(\n            dtype_policy_map[\"layer/dense_1\"], dtype_policy_map.default_policy\n        )\n\n        self.assertLen(dtype_policy_map, 0)\n\n    def test_len(self):\n        dtype_policy_map = DTypePolicyMap()\n        self.assertLen(dtype_policy_map, 0)\n\n        dtype_policy_map[\"layer/dense_0\"] = dtype_policies.DTypePolicy(\n            \"bfloat16\"\n        )\n        dtype_policy_map[\"layer/dense_1\"] = dtype_policies.QuantizedDTypePolicy(\n            \"int8\", \"mixed_bfloat16\"\n        )\n        self.assertLen(dtype_policy_map, 2)\n\n    def test_iter(self):\n        dtype_policy_map = DTypePolicyMap()\n        dtype_policy_map[\"layer/dense_0\"] = dtype_policies.DTypePolicy(\n            \"bfloat16\"\n   
     )\n        dtype_policy_map[\"layer/dense_1\"] = dtype_policies.QuantizedDTypePolicy(\n            \"int8\", \"mixed_bfloat16\"\n        )\n\n        self.assertEqual(\n            list(dtype_policy_map.keys()), [\"layer/dense_0\", \"layer/dense_1\"]\n        )\n\n        keys = []\n        values = []\n        for k, v in dtype_policy_map.items():\n            keys.append(k)\n            values.append(v)\n        self.assertEqual(keys, [\"layer/dense_0\", \"layer/dense_1\"])\n        self.assertEqual(\n            values,\n            [\n                dtype_policies.DTypePolicy(\"bfloat16\"),\n                dtype_policies.QuantizedDTypePolicy(\"int8\", \"mixed_bfloat16\"),\n            ],\n        )\n\n    def test_in(self):\n        dtype_policy_map = DTypePolicyMap()\n        dtype_policy_map[\"layer/dense_0\"] = dtype_policies.DTypePolicy(\n            \"bfloat16\"\n        )\n        dtype_policy_map[\"layer/dense_1\"] = dtype_policies.QuantizedDTypePolicy(\n            \"int8\", \"mixed_bfloat16\"\n        )\n\n        self.assertTrue(\"layer/dense_0\" in dtype_policy_map)\n        self.assertTrue(\"layer/dense_1\" in dtype_policy_map)\n        self.assertFalse(\"layer/dense_2\" in dtype_policy_map)\n\n    def test_default_policy(self):\n        # Test default_policy is set to `\"mixed_bfloat16\"`\n        dtype_policy_map = DTypePolicyMap(default_policy=\"mixed_bfloat16\")\n        dtype_policy_map[\"layer/dense_0\"] = dtype_policies.DTypePolicy(\n            \"mixed_bfloat16\"\n        )\n        dtype_policy_map[\"layer/dense_1\"] = dtype_policies.QuantizedDTypePolicy(\n            \"int8\", \"mixed_bfloat16\"\n        )\n        config = dtype_policy_map.get_config()\n        dtype_policy_map = DTypePolicyMap.from_config(config)\n        self.assertEqual(\n            dtype_policy_map[\"layer/dense_0\"],\n            dtype_policies.DTypePolicy(\"mixed_bfloat16\"),\n        )\n        self.assertEqual(\n            dtype_policy_map[\"layer/dense_1\"],\n  
          dtype_policies.QuantizedDTypePolicy(\"int8\", \"mixed_bfloat16\"),\n        )\n        # No hit, defers to `dtype_policy_map.default_policy`\n        self.assertEqual(\n            dtype_policy_map[\"layer/dense_2\"], dtype_policy_map.default_policy\n        )\n\n        # Test that default_policy defers to `keras.config.dtype_policy()`\n        # during loading\n        set_dtype_policy(\"bfloat16\")\n        dtype_policy_map = DTypePolicyMap()\n        dtype_policy_map[\"layer/dense_0\"] = dtype_policies.DTypePolicy(\n            \"mixed_bfloat16\"\n        )\n        dtype_policy_map[\"layer/dense_1\"] = dtype_policies.QuantizedDTypePolicy(\n            \"int8\", \"mixed_bfloat16\"\n        )\n        config = dtype_policy_map.get_config()\n        dtype_policy_map = DTypePolicyMap.from_config(config)\n        self.assertEqual(\n            dtype_policy_map[\"layer/dense_0\"],\n            dtype_policies.DTypePolicy(\"bfloat16\"),\n        )\n        self.assertEqual(\n            dtype_policy_map[\"layer/dense_1\"],\n            dtype_policies.QuantizedDTypePolicy(\"int8\", \"bfloat16\"),\n        )\n        # No hit, defers to `dtype_policy_map.default_policy` which is\n        # `keras.config.dtype_policy()`\n        self.assertEqual(\n            dtype_policy_map[\"layer/dense_2\"], dtype_policy_map.default_policy\n        )\n        self.assertEqual(\n            dtype_policy_map[\"layer/dense_2\"], dtype_policies.get(\"bfloat16\")\n        )\n\n    def test_serialization(self):\n        dtype_policy_map = DTypePolicyMap(default_policy=\"mixed_bfloat16\")\n        dtype_policy_map[\"layer/dense_0\"] = dtype_policies.DTypePolicy(\n            \"mixed_bfloat16\"\n        )\n        dtype_policy_map[\"layer/dense_1\"] = dtype_policies.QuantizedDTypePolicy(\n            \"int8\", \"mixed_bfloat16\"\n        )\n\n        config = dtype_policies.serialize(dtype_policy_map)\n        reloaded_dtype_policy_map = dtype_policies.deserialize(config)\n        
self.assertEqual(\n            dtype_policy_map.default_policy,\n            reloaded_dtype_policy_map.default_policy,\n        )\n        for k, v in dtype_policy_map.items():\n            self.assertEqual(reloaded_dtype_policy_map[k], v)\n\n        # Test that config remains intact during deserialization\n        config = dtype_policy_map.get_config()\n        original_config = config.copy()\n        DTypePolicyMap.from_config(config)\n        self.assertDictEqual(config, original_config)\n\n    def test_repr(self):\n        dtype_policy_map = DTypePolicyMap()\n        dtype_policy_map[\"layer/dense_0\"] = dtype_policies.DTypePolicy(\n            \"mixed_bfloat16\"\n        )\n        repr_str = repr(dtype_policy_map)\n        self.assertTrue(\"DTypePolicyMap\" in repr_str)\n        self.assertTrue(\"default_policy\" in repr_str)\n        self.assertTrue(\n            \"mapping=[('layer/dense_0', 'mixed_bfloat16')]\" in repr_str\n        )\n\n    def test_invalid_policy_map(self):\n        with self.assertRaisesRegex(\n            TypeError, \"If specified, `policy_map` must be a dict.\"\n        ):\n            DTypePolicyMap(policy_map=123)\n\n        with self.assertRaisesRegex(\n            TypeError, \"If specified, `policy_map` must be a dict.\"\n        ):\n            DTypePolicyMap(\n                policy_map=dtype_policies.DTypePolicy(\"mixed_bfloat16\")\n            )\n"
  },
  {
    "path": "keras/src/dtype_policies/dtype_policy_test.py",
    "content": "from absl.testing import parameterized\n\nfrom keras.src.dtype_policies import deserialize\nfrom keras.src.dtype_policies import get\nfrom keras.src.dtype_policies import serialize\nfrom keras.src.dtype_policies.dtype_policy import AWQDTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import DTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import FloatDTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import QuantizedDTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import QuantizedFloat8DTypePolicy\nfrom keras.src.dtype_policies.dtype_policy import dtype_policy\nfrom keras.src.dtype_policies.dtype_policy import set_dtype_policy\nfrom keras.src.quantizers.gptq_config import GPTQConfig\nfrom keras.src.testing import test_case\n\n\nclass DTypePolicyTest(test_case.TestCase):\n    \"\"\"Test `DTypePolicy`.\n\n    In the tests, we also test `FloatDTypePolicy` for historical reasons.\n    \"\"\"\n\n    def setUp(self):\n        \"\"\"Record the global dtype policy before each test.\"\"\"\n        super().setUp()\n        self._global_dtype_policy = dtype_policy()\n\n    def tearDown(self):\n        \"\"\"Restore the global dtype policy after each test.\"\"\"\n        super().tearDown()\n        set_dtype_policy(self._global_dtype_policy)\n\n    def test_initialization_valid_name(self):\n        \"\"\"Test initialization with a valid name.\"\"\"\n        policy = DTypePolicy(\"mixed_float16\")\n        self.assertEqual(policy.compute_dtype, \"float16\")\n        self.assertEqual(policy.variable_dtype, \"float32\")\n\n        policy = FloatDTypePolicy(\"mixed_float16\")\n        self.assertEqual(policy.compute_dtype, \"float16\")\n        self.assertEqual(policy.variable_dtype, \"float32\")\n\n    @parameterized.named_parameters(\n        (\"float32\", \"float32\", \"float32\", \"float32\"),\n        (\"float16\", \"float16\", \"float16\", \"float16\"),\n        (\"bfloat16\", \"bfloat16\", \"bfloat16\", \"bfloat16\"),\n        
(\"mixed_float16\", \"mixed_float16\", \"float16\", \"float32\"),\n        (\"mixed_bfloat16\", \"mixed_bfloat16\", \"bfloat16\", \"float32\"),\n    )\n    def test_initialization_from_global(\n        self,\n        global_dtype_policy,\n        expected_compute_dtype,\n        expected_variable_dtype,\n    ):\n        set_dtype_policy(global_dtype_policy)\n\n        policy = DTypePolicy(name=None)\n        self.assertEqual(policy.name, global_dtype_policy)\n        self.assertEqual(policy.compute_dtype, expected_compute_dtype)\n        self.assertEqual(policy.variable_dtype, expected_variable_dtype)\n\n        policy = FloatDTypePolicy(name=None)\n        self.assertEqual(policy.name, global_dtype_policy)\n        self.assertEqual(policy.compute_dtype, expected_compute_dtype)\n        self.assertEqual(policy.variable_dtype, expected_variable_dtype)\n\n    def test_initialization_invalid_name(self):\n        \"\"\"Test initialization with an invalid name.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            DTypePolicy(\"invalid_name\")\n\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            FloatDTypePolicy(\"invalid_name\")\n\n    def test_initialization_non_string_name(self):\n        \"\"\"Test initialization with a non-string name.\"\"\"\n        with self.assertRaisesRegex(TypeError, \"'name' must be a string\"):\n            DTypePolicy(123)\n\n        with self.assertRaisesRegex(TypeError, \"'name' must be a string\"):\n            FloatDTypePolicy(123)\n\n    def test_properties_mixed_float16(self):\n        \"\"\"Test properties for 'mixed_float16'.\"\"\"\n        policy = DTypePolicy(\"mixed_float16\")\n        self.assertEqual(policy.compute_dtype, \"float16\")\n        self.assertEqual(policy.variable_dtype, \"float32\")\n\n        policy = FloatDTypePolicy(\"mixed_float16\")\n        self.assertEqual(policy.compute_dtype, \"float16\")\n        self.assertEqual(policy.variable_dtype, 
\"float32\")\n\n    def test_properties_mixed_bfloat16(self):\n        \"\"\"Test properties for 'mixed_bfloat16'.\"\"\"\n        policy = DTypePolicy(\"mixed_bfloat16\")\n        self.assertEqual(policy.compute_dtype, \"bfloat16\")\n        self.assertEqual(policy.variable_dtype, \"float32\")\n\n        policy = FloatDTypePolicy(\"mixed_bfloat16\")\n        self.assertEqual(policy.compute_dtype, \"bfloat16\")\n        self.assertEqual(policy.variable_dtype, \"float32\")\n\n    def test_initialization_with_invalid_name_behaviour(self):\n        \"\"\"Test initialization behavior with an invalid name.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            DTypePolicy(\"invalid_name\")\n\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            FloatDTypePolicy(\"invalid_name\")\n\n    def test_properties(self):\n        \"\"\"Test variable_dtype, compute_dtype, and name properties.\"\"\"\n        policy = DTypePolicy(\"mixed_float16\")\n        self.assertEqual(policy.variable_dtype, \"float32\")\n        self.assertEqual(policy.compute_dtype, \"float16\")\n        self.assertEqual(policy.name, \"mixed_float16\")\n        self.assertIsNone(policy.quantization_mode)\n\n        policy = FloatDTypePolicy(\"mixed_float16\")\n        self.assertEqual(policy.variable_dtype, \"float32\")\n        self.assertEqual(policy.compute_dtype, \"float16\")\n        self.assertEqual(policy.name, \"mixed_float16\")\n        self.assertIsNone(policy.quantization_mode)\n\n    def test_properties_uint8(self):\n        \"\"\"Test properties for 'uint8'.\"\"\"\n        policy = DTypePolicy(\"uint8\")\n        self.assertEqual(policy.compute_dtype, \"uint8\")\n        self.assertEqual(policy.variable_dtype, \"uint8\")\n        self.assertEqual(policy.name, \"uint8\")\n\n        policy = FloatDTypePolicy(\"uint8\")\n        self.assertEqual(policy.compute_dtype, \"uint8\")\n        self.assertEqual(policy.variable_dtype, 
\"uint8\")\n        self.assertEqual(policy.name, \"uint8\")\n\n    def test_repr(self):\n        \"\"\"Test __repr__ method.\"\"\"\n        policy = DTypePolicy(\"mixed_float16\")\n        self.assertEqual(repr(policy), '<DTypePolicy \"mixed_float16\">')\n\n        policy = FloatDTypePolicy(\"mixed_float16\")\n        self.assertEqual(repr(policy), '<DTypePolicy \"mixed_float16\">')\n\n    def test_get_config_from_config(self):\n        \"\"\"Test get_config and from_config methods.\"\"\"\n        # Test DTypePolicy\n        policy = DTypePolicy(\"mixed_float16\")\n        config = policy.get_config()\n        self.assertEqual(config, {\"name\": \"mixed_float16\"})\n        new_policy = DTypePolicy.from_config(config)\n        self.assertEqual(new_policy.name, \"mixed_float16\")\n\n        # Test FloatDTypePolicy\n        policy = FloatDTypePolicy(\"mixed_float16\")\n        config = policy.get_config()\n        self.assertEqual(config, {\"name\": \"mixed_float16\"})\n        new_policy = FloatDTypePolicy.from_config(config)\n        self.assertEqual(new_policy.name, \"mixed_float16\")\n\n    def test_serialization(self):\n        # Test DTypePolicy\n        policy = DTypePolicy(\"mixed_float16\")\n        config = serialize(policy)\n        reloaded_policy = deserialize(config)\n        self.assertEqual(policy.name, reloaded_policy.name)\n        reloaded_policy = get(config)\n        self.assertEqual(policy.name, reloaded_policy.name)\n\n        # Test FloatDTypePolicy\n        policy = FloatDTypePolicy(\"mixed_float16\")\n        config = serialize(policy)\n        reloaded_policy = deserialize(config)\n        self.assertEqual(policy.name, reloaded_policy.name)\n        reloaded_policy = get(config)\n        self.assertEqual(policy.name, reloaded_policy.name)\n\n    def test_python_serialization(self):\n        \"\"\"Test builtin serialization methods.\"\"\"\n        import copy\n        import pickle\n\n        # Test DTypePolicy\n        policy = 
DTypePolicy(\"mixed_float16\")\n\n        # copy.deepcopy\n        copied_policy = copy.deepcopy(policy)\n        self.assertEqual(repr(copied_policy), '<DTypePolicy \"mixed_float16\">')\n        # copy.copy\n        copied_policy = copy.copy(policy)\n        self.assertEqual(repr(copied_policy), '<DTypePolicy \"mixed_float16\">')\n        # pickle\n        temp_dir = self.get_temp_dir()\n        with open(f\"{temp_dir}/policy.pickle\", \"wb\") as f:\n            pickle.dump(policy, f)\n        with open(f\"{temp_dir}/policy.pickle\", \"rb\") as f:\n            copied_policy = pickle.load(f)\n        self.assertEqual(repr(copied_policy), '<DTypePolicy \"mixed_float16\">')\n\n        # Test FloatDTypePolicy\n        policy = FloatDTypePolicy(\"mixed_float16\")\n\n        # copy.deepcopy\n        copied_policy = copy.deepcopy(policy)\n        self.assertEqual(repr(copied_policy), '<DTypePolicy \"mixed_float16\">')\n        # copy.copy\n        copied_policy = copy.copy(policy)\n        self.assertEqual(repr(copied_policy), '<DTypePolicy \"mixed_float16\">')\n        # pickle\n        temp_dir = self.get_temp_dir()\n        with open(f\"{temp_dir}/policy.pickle\", \"wb\") as f:\n            pickle.dump(policy, f)\n        with open(f\"{temp_dir}/policy.pickle\", \"rb\") as f:\n            copied_policy = pickle.load(f)\n        self.assertEqual(repr(copied_policy), '<DTypePolicy \"mixed_float16\">')\n\n    def test_eq(self):\n        policy = DTypePolicy(\"mixed_bfloat16\")\n\n        # Test True\n        self.assertEqual(policy, DTypePolicy(\"mixed_bfloat16\"))\n        self.assertEqual(policy, FloatDTypePolicy(\"mixed_bfloat16\"))\n\n        # Test False\n        self.assertNotEqual(policy, \"mixed_float16\")\n        self.assertNotEqual(\n            policy, QuantizedDTypePolicy(\"int8\", \"mixed_bfloat16\")\n        )\n\n\nclass QuantizedDTypePolicyTest(test_case.TestCase):\n    def setUp(self):\n        \"\"\"Record the global dtype policy before each 
test.\"\"\"\n        super().setUp()\n        self._global_dtype_policy = dtype_policy()\n\n    def tearDown(self):\n        \"\"\"Restore the global dtype policy after each test.\"\"\"\n        super().tearDown()\n        set_dtype_policy(self._global_dtype_policy)\n\n    @parameterized.named_parameters(\n        (\"float32\", \"float32\", \"float32\", \"float32\"),\n        (\"bfloat16\", \"bfloat16\", \"bfloat16\", \"bfloat16\"),\n        (\"mixed_bfloat16\", \"mixed_bfloat16\", \"bfloat16\", \"float32\"),\n    )\n    def test_initialization_for_int8(\n        self, source_name, expected_compute_dtype, expected_variable_dtype\n    ):\n        name = f\"int8_from_{source_name}\"\n        policy = QuantizedDTypePolicy(mode=\"int8\", source_name=source_name)\n        self.assertEqual(policy.name, name)\n        self.assertEqual(policy.compute_dtype, expected_compute_dtype)\n        self.assertEqual(policy.variable_dtype, expected_variable_dtype)\n        self.assertEqual(repr(policy), f'<QuantizedDTypePolicy \"{name}\">')\n\n    @parameterized.named_parameters(\n        (\"float32\", \"float32\", \"float32\", \"float32\"),\n        (\"bfloat16\", \"bfloat16\", \"bfloat16\", \"bfloat16\"),\n        (\"mixed_bfloat16\", \"mixed_bfloat16\", \"bfloat16\", \"float32\"),\n    )\n    def test_initialization_for_int8_from_global(\n        self,\n        global_dtype_policy,\n        expected_compute_dtype,\n        expected_variable_dtype,\n    ):\n        set_dtype_policy(global_dtype_policy)\n        expected_name = f\"int8_from_{global_dtype_policy}\"\n\n        policy = QuantizedDTypePolicy(mode=\"int8\", source_name=None)\n        self.assertEqual(policy.name, expected_name)\n        self.assertEqual(policy.compute_dtype, expected_compute_dtype)\n        self.assertEqual(policy.variable_dtype, expected_variable_dtype)\n\n    @parameterized.named_parameters(\n        (\"float32\", \"float32\", \"float32\", \"float32\"),\n        (\"float16\", \"float16\", \"float16\", 
\"float16\"),\n        (\"bfloat16\", \"bfloat16\", \"bfloat16\", \"bfloat16\"),\n        (\"mixed_float16\", \"mixed_float16\", \"float16\", \"float32\"),\n        (\"mixed_bfloat16\", \"mixed_bfloat16\", \"bfloat16\", \"float32\"),\n    )\n    def test_initialization_for_float8(\n        self, source_name, expected_compute_dtype, expected_variable_dtype\n    ):\n        name = f\"float8_from_{source_name}\"\n        policy = QuantizedFloat8DTypePolicy(\n            mode=\"float8\", source_name=source_name\n        )\n        self.assertEqual(policy.name, name)\n        self.assertEqual(policy.compute_dtype, expected_compute_dtype)\n        self.assertEqual(policy.variable_dtype, expected_variable_dtype)\n        self.assertEqual(repr(policy), f'<QuantizedFloat8DTypePolicy \"{name}\">')\n\n    @parameterized.named_parameters(\n        (\"float32\", \"float32\", \"float32\", \"float32\"),\n        (\"float16\", \"float16\", \"float16\", \"float16\"),\n        (\"bfloat16\", \"bfloat16\", \"bfloat16\", \"bfloat16\"),\n        (\"mixed_float16\", \"mixed_float16\", \"float16\", \"float32\"),\n        (\"mixed_bfloat16\", \"mixed_bfloat16\", \"bfloat16\", \"float32\"),\n    )\n    def test_initialization_for_float8_from_global(\n        self,\n        global_dtype_policy,\n        expected_compute_dtype,\n        expected_variable_dtype,\n    ):\n        set_dtype_policy(global_dtype_policy)\n        expected_name = f\"float8_from_{global_dtype_policy}\"\n\n        policy = QuantizedFloat8DTypePolicy(mode=\"float8\", source_name=None)\n        self.assertEqual(policy.name, expected_name)\n        self.assertEqual(policy.compute_dtype, expected_compute_dtype)\n        self.assertEqual(policy.variable_dtype, expected_variable_dtype)\n\n    @parameterized.named_parameters(\n        (\"abc\", \"abc\"),\n        (\"abc_from_def\", \"def\"),\n    )\n    def test_initialization_with_invalid_name(self, invalid_name):\n        with self.assertRaisesRegex(ValueError, \"Cannot 
convert\"):\n            QuantizedDTypePolicy(mode=\"int8\", source_name=invalid_name)\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            QuantizedFloat8DTypePolicy(mode=\"float8\", source_name=invalid_name)\n\n    @parameterized.named_parameters(\n        (\"int7\", \"int7\"),\n        (\"float7\", \"float7\"),\n    )\n    def test_initialization_with_invalid_mode(self, invalid_mode):\n        with self.assertRaisesRegex(ValueError, \"Invalid quantization mode.\"):\n            QuantizedDTypePolicy(mode=invalid_mode)\n        with self.assertRaisesRegex(ValueError, \"Invalid quantization mode.\"):\n            QuantizedFloat8DTypePolicy(mode=invalid_mode)\n\n    @parameterized.named_parameters(\n        (\"int8_from_float16\", \"float16\"),\n        (\"int8_from_mixed_float16\", \"mixed_float16\"),\n    )\n    def test_initialization_with_invalid_compute_dtype(self, invalid_name):\n        with self.assertRaisesRegex(ValueError, \"doesn't work well\"):\n            QuantizedDTypePolicy(mode=\"int8\", source_name=invalid_name)\n\n    def test_initialization_non_string_name(self):\n        \"\"\"Test initialization with a non-string name.\"\"\"\n        with self.assertRaisesRegex(TypeError, \"'name' must be a string\"):\n            QuantizedDTypePolicy(mode=\"int8\", source_name=123)\n        with self.assertRaisesRegex(TypeError, \"'name' must be a string\"):\n            QuantizedFloat8DTypePolicy(mode=\"float8\", source_name=123)\n\n    def test_properties(self):\n        # Test int8\n        policy = QuantizedDTypePolicy(mode=\"int8\", source_name=\"mixed_bfloat16\")\n        self.assertEqual(policy.variable_dtype, \"float32\")\n        self.assertEqual(policy.compute_dtype, \"bfloat16\")\n        self.assertEqual(policy.name, \"int8_from_mixed_bfloat16\")\n        self.assertEqual(policy.quantization_mode, \"int8\")\n\n        # Test float8\n        policy = QuantizedFloat8DTypePolicy(\n            mode=\"float8\", 
source_name=\"mixed_bfloat16\"\n        )\n        self.assertEqual(policy.variable_dtype, \"float32\")\n        self.assertEqual(policy.compute_dtype, \"bfloat16\")\n        self.assertEqual(policy.name, \"float8_from_mixed_bfloat16\")\n        self.assertEqual(policy.quantization_mode, \"float8\")\n        self.assertEqual(policy.amax_history_length, 1024)\n\n        # Test float8 with amax_history_length\n        policy = QuantizedFloat8DTypePolicy(\n            mode=\"float8\", source_name=\"mixed_bfloat16\", amax_history_length=512\n        )\n        self.assertEqual(policy.amax_history_length, 512)\n\n        # Test float8 default_amax_history_length\n        self.assertEqual(\n            QuantizedFloat8DTypePolicy.default_amax_history_length, 1024\n        )\n\n    def test_invalid_properties_for_float8(self):\n        with self.assertRaisesRegex(TypeError, \"must be an integer.\"):\n            QuantizedFloat8DTypePolicy(\n                mode=\"float8\", source_name=\"float32\", amax_history_length=\"512\"\n            )\n        with self.assertRaisesRegex(TypeError, \"must be an integer.\"):\n            QuantizedFloat8DTypePolicy(\n                mode=\"float8\", source_name=\"float32\", amax_history_length=512.0\n            )\n\n    def test_get_config_from_config(self):\n        \"\"\"Test get_config and from_config methods.\"\"\"\n        # Test QuantizedDTypePolicy\n        policy = QuantizedDTypePolicy(mode=\"int8\", source_name=\"mixed_bfloat16\")\n        config = policy.get_config()\n        self.assertEqual(\n            config, {\"mode\": \"int8\", \"source_name\": \"mixed_bfloat16\"}\n        )\n        new_policy = QuantizedDTypePolicy.from_config(config)\n        self.assertEqual(new_policy.name, \"int8_from_mixed_bfloat16\")\n\n        # Test QuantizedFloat8DTypePolicy\n        policy = QuantizedFloat8DTypePolicy(\n            mode=\"float8\", source_name=\"mixed_bfloat16\"\n        )\n        config = policy.get_config()\n        
self.assertEqual(\n            config,\n            {\n                \"mode\": \"float8\",\n                \"source_name\": \"mixed_bfloat16\",\n                \"amax_history_length\": 1024,\n            },\n        )\n        new_policy = QuantizedFloat8DTypePolicy.from_config(config)\n        self.assertEqual(new_policy.name, \"float8_from_mixed_bfloat16\")\n\n    def test_serialization(self):\n        # Test QuantizedDTypePolicy\n        policy = QuantizedDTypePolicy(mode=\"int8\", source_name=\"float32\")\n        config = serialize(policy)\n        reloaded_policy = deserialize(config)\n        self.assertEqual(policy.name, reloaded_policy.name)\n        reloaded_policy = get(config)\n        self.assertEqual(policy.name, reloaded_policy.name)\n\n        # Test QuantizedFloat8DTypePolicy\n        policy = QuantizedFloat8DTypePolicy(\n            mode=\"float8\", source_name=\"float32\"\n        )\n        config = serialize(policy)\n        reloaded_policy = deserialize(config)\n        self.assertEqual(policy.name, reloaded_policy.name)\n        reloaded_policy = get(config)\n        self.assertEqual(policy.name, reloaded_policy.name)\n\n    @parameterized.named_parameters(\n        (\n            \"int8_from_mixed_bfloat16\",\n            \"int8\",\n            \"mixed_bfloat16\",\n            '<QuantizedDTypePolicy \"int8_from_mixed_bfloat16\">',\n        ),\n        (\n            \"float8_from_mixed_bfloat16\",\n            \"float8\",\n            \"mixed_bfloat16\",\n            '<QuantizedFloat8DTypePolicy \"float8_from_mixed_bfloat16\">',\n        ),\n    )\n    def test_python_serialization(self, mode, source_name, repr_str):\n        import copy\n        import pickle\n\n        if mode == \"int8\":\n            policy = QuantizedDTypePolicy(mode=mode, source_name=source_name)\n        else:\n            policy = QuantizedFloat8DTypePolicy(\n                mode=mode, source_name=source_name, amax_history_length=123\n            )\n\n        # 
copy.deepcopy\n        copied_policy = copy.deepcopy(policy)\n        self.assertEqual(repr(copied_policy), repr_str)\n        if mode == \"float8\":\n            self.assertEqual(copied_policy.amax_history_length, 123)\n        # copy.copy\n        copied_policy = copy.copy(policy)\n        self.assertEqual(repr(copied_policy), repr_str)\n        if mode == \"float8\":\n            self.assertEqual(copied_policy.amax_history_length, 123)\n        # pickle\n        temp_dir = self.get_temp_dir()\n        with open(f\"{temp_dir}/policy.pickle\", \"wb\") as f:\n            pickle.dump(policy, f)\n        with open(f\"{temp_dir}/policy.pickle\", \"rb\") as f:\n            copied_policy = pickle.load(f)\n        self.assertEqual(repr(copied_policy), repr_str)\n        if mode == \"float8\":\n            self.assertEqual(copied_policy.amax_history_length, 123)\n\n    def test_serialization_for_float8(self):\n        policy = QuantizedFloat8DTypePolicy(\n            mode=\"float8\", source_name=\"mixed_float16\"\n        )\n        config = serialize(policy)\n        reloaded_policy = deserialize(config)\n        self.assertEqual(policy.name, reloaded_policy.name)\n        self.assertEqual(\n            policy.amax_history_length, reloaded_policy.amax_history_length\n        )\n\n        # Test `dtype_policies.get`\n        reloaded_policy = get(config)\n        self.assertEqual(policy.name, reloaded_policy.name)\n        self.assertEqual(\n            policy.amax_history_length, reloaded_policy.amax_history_length\n        )\n\n    def test_eq(self):\n        policy = QuantizedDTypePolicy(\"int8\", \"mixed_bfloat16\")\n\n        # Test True\n        self.assertEqual(policy, QuantizedDTypePolicy(\"int8\", \"mixed_bfloat16\"))\n\n        # Test False\n        self.assertNotEqual(policy, \"mixed_bfloat16\")\n        self.assertNotEqual(policy, DTypePolicy(\"mixed_bfloat16\"))\n        self.assertNotEqual(\n            policy, QuantizedFloat8DTypePolicy(\"float8\", 
\"mixed_bfloat16\")\n        )\n\n    @parameterized.named_parameters(\n        (\"int8_from_mixed_bfloat16\", \"int8_from_mixed_bfloat16\"),\n        (\"float8_from_mixed_bfloat16\", \"float8_from_mixed_bfloat16\"),\n    )\n    def test_get_quantized_dtype_policy_by_str(self, name):\n        from keras.src.dtype_policies.dtype_policy import (\n            _get_quantized_dtype_policy_by_str,\n        )\n\n        policy = _get_quantized_dtype_policy_by_str(name)\n        self.assertEqual(policy.name, name)\n\n    def test_invalid_get_quantized_dtype_policy_by_str(self):\n        from keras.src.dtype_policies.dtype_policy import (\n            _get_quantized_dtype_policy_by_str,\n        )\n\n        with self.assertRaisesRegex(TypeError, \"must be a string.\"):\n            _get_quantized_dtype_policy_by_str(123)\n        with self.assertRaisesRegex(\n            ValueError,\n            \"is incompatible with the current supported quantization.\",\n        ):\n            _get_quantized_dtype_policy_by_str(\"float7\")\n\n\nclass DTypePolicyGlobalFunctionsTest(test_case.TestCase):\n    def setUp(self):\n        \"\"\"Reset the global dtype policy before each test.\"\"\"\n        set_dtype_policy(\"float32\")\n\n    def test_set_dtype_policy_valid_string(self):\n        \"\"\"Test set_dtype_policy with a valid string.\"\"\"\n        set_dtype_policy(\"mixed_float16\")\n        policy = dtype_policy()\n        self.assertEqual(policy.name, \"mixed_float16\")\n\n    def test_set_dtype_policy_valid_string_quantized(self):\n        \"\"\"Test set_dtype_policy with a valid string.\"\"\"\n        set_dtype_policy(\"int8_from_mixed_bfloat16\")\n        policy = dtype_policy()\n        self.assertEqual(policy.name, \"int8_from_mixed_bfloat16\")\n\n    def test_set_dtype_policy_valid_policy(self):\n        \"\"\"Test set_dtype_policy with a valid DTypePolicy object.\"\"\"\n        policy_obj = DTypePolicy(\"mixed_float16\")\n        set_dtype_policy(policy_obj)\n        
policy = dtype_policy()\n        self.assertEqual(policy.name, \"mixed_float16\")\n\n    def test_set_dtype_policy_valid_policy_quantized(self):\n        \"\"\"Test set_dtype_policy with a valid QuantizedDTypePolicy object.\"\"\"\n        policy_obj = QuantizedDTypePolicy(\n            mode=\"int8\", source_name=\"mixed_bfloat16\"\n        )\n        set_dtype_policy(policy_obj)\n        policy = dtype_policy()\n        self.assertEqual(policy.name, \"int8_from_mixed_bfloat16\")\n\n    def test_set_dtype_policy_invalid(self):\n        \"\"\"Test set_dtype_policy with an invalid input.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Invalid `policy` argument\"):\n            set_dtype_policy(12345)\n\n    def test_dtype_policy_default(self):\n        \"\"\"Test dtype_policy default value.\"\"\"\n        policy = dtype_policy()\n        self.assertEqual(policy.name, \"float32\")\n\n    def test_get_valid_policy(self):\n        policy = get(\"bfloat16\")\n        self.assertEqual(policy.name, \"bfloat16\")\n\n        policy = get(\"mixed_float16\")\n        self.assertEqual(policy.name, \"mixed_float16\")\n\n        policy = get(DTypePolicy(\"bfloat16\"))\n        self.assertEqual(policy.name, \"bfloat16\")\n\n        policy = get(FloatDTypePolicy(\"mixed_float16\"))\n        self.assertEqual(policy.name, \"mixed_float16\")\n\n    def test_get_valid_policy_quantized(self):\n        policy = get(\"int8_from_mixed_bfloat16\")\n        self.assertEqual(policy.name, \"int8_from_mixed_bfloat16\")\n\n        policy = get(\"float8_from_float32\")\n        self.assertEqual(policy.name, \"float8_from_float32\")\n\n        policy = get(QuantizedDTypePolicy(\"int8\", \"mixed_bfloat16\"))\n        self.assertEqual(policy.name, \"int8_from_mixed_bfloat16\")\n\n        policy = get(QuantizedFloat8DTypePolicy(\"float8\", \"mixed_float16\"))\n        self.assertEqual(policy.name, \"float8_from_mixed_float16\")\n\n    def test_get_invalid_policy(self):\n        with 
self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            get(\"mixed_bfloat15\")\n        with self.assertRaisesRegex(\n            ValueError, \"Cannot interpret `dtype` argument.\"\n        ):\n            get(123)\n\n    def test_get_invalid_policy_quantized(self):\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            get(\"int8_from_mixed_bfloat15\")\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            get(\"int8_from_\")\n        with self.assertRaisesRegex(\n            ValueError, \"Cannot convert `policy` into a valid pair\"\n        ):\n            get(\"int8_abc_\")\n\n\nclass DTypePolicyEdgeCasesTest(test_case.TestCase):\n    def test_empty_name(self):\n        \"\"\"Test initialization with an empty name.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            DTypePolicy(\"\")\n\n    def test_special_character_name(self):\n        \"\"\"Test initialization with special characters in the name.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            DTypePolicy(\"@mixed_float16!\")\n\n    def test_very_long_name(self):\n        \"\"\"Test initialization with a very long name.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            DTypePolicy(\"mixed_float16\" * 100)\n\n    def test_almost_valid_name(self):\n        \"\"\"Test initialization with a name close to a valid one.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            DTypePolicy(\"mixed_float15\")\n\n\nclass QuantizedDTypePolicyEdgeCasesTest(test_case.TestCase):\n    def test_empty_name(self):\n        \"\"\"Test initialization with an empty name.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            QuantizedDTypePolicy(mode=\"int8\", source_name=\"\")\n\n    def test_special_character_name(self):\n        \"\"\"Test initialization with special characters 
in the name.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            QuantizedDTypePolicy(\n                mode=\"int8\", source_name=\"@int8_from_mixed_bfloat16!\"\n            )\n\n    def test_very_long_name(self):\n        \"\"\"Test initialization with a very long name.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            QuantizedDTypePolicy(\n                mode=\"int8\", source_name=\"int8_from_mixed_bfloat16\" * 100\n            )\n\n    def test_almost_valid_name(self):\n        \"\"\"Test initialization with a name close to a valid one.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Cannot convert\"):\n            QuantizedDTypePolicy(\n                mode=\"int8\", source_name=\"int7_from_mixed_bfloat16\"\n            )\n\n\nclass DTypePolicyGlobalFunctionsEdgeCasesTest(test_case.TestCase):\n    def setUp(self):\n        \"\"\"Reset the global dtype policy before each test.\"\"\"\n        set_dtype_policy(\"float32\")\n\n    def test_set_policy_multiple_times(self):\n        \"\"\"Test setting the policy multiple times in a row.\"\"\"\n        set_dtype_policy(\"mixed_float16\")\n        policy = dtype_policy()\n        self.assertEqual(policy.name, \"mixed_float16\")\n\n        set_dtype_policy(\"float32\")\n        policy = dtype_policy()\n        self.assertEqual(policy.name, \"float32\")\n\n    def test_set_policy_none(self):\n        \"\"\"Test setting the policy to None.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Invalid `policy` argument\"):\n            set_dtype_policy(None)\n\n\nclass GPTQConfigErrorHandlingTest(test_case.TestCase):\n    \"\"\"Test error handling in GPTQConfig.\"\"\"\n\n    def test_invalid_weight_bits(self):\n        with self.assertRaisesRegex(ValueError, \"Unsupported weight_bits\"):\n            GPTQConfig(\n                dataset=None,\n                tokenizer=None,\n                weight_bits=5,\n            )\n\n    def 
test_negative_num_samples(self):\n        with self.assertRaisesRegex(\n            ValueError, \"num_samples must be a positive integer.\"\n        ):\n            GPTQConfig(\n                dataset=None,\n                tokenizer=None,\n                num_samples=-10,\n            )\n\n    def test_zero_sequence_length(self):\n        with self.assertRaisesRegex(\n            ValueError, \"sequence_length must be a positive integer.\"\n        ):\n            GPTQConfig(\n                dataset=None,\n                tokenizer=None,\n                sequence_length=0,\n            )\n\n    def test_invalid_hessian_damping(self):\n        with self.assertRaisesRegex(\n            ValueError, \"hessian_damping must be between 0 and 1.\"\n        ):\n            GPTQConfig(\n                dataset=None,\n                tokenizer=None,\n                hessian_damping=1.5,\n            )\n\n    def test_invalid_group_size(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid group_size. 
Supported values are -1\"\n        ):\n            GPTQConfig(\n                dataset=None,\n                tokenizer=None,\n                group_size=0,\n            )\n\n\nclass AWQDTypePolicyTest(test_case.TestCase):\n    \"\"\"Test AWQDTypePolicy creation and error handling.\"\"\"\n\n    def test_awq_dtype_policy_creation(self):\n        \"\"\"Test AWQDTypePolicy can be created.\"\"\"\n        policy = AWQDTypePolicy(\"awq/4/128\", source_name=\"float32\")\n        self.assertEqual(policy.weight_bits, 4)\n        self.assertEqual(policy.group_size, 128)\n        self.assertEqual(policy.mode, \"awq\")\n\n    def test_awq_dtype_policy_invalid_bits(self):\n        \"\"\"Test AWQDTypePolicy rejects non-4-bit.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"only supports 4-bit\"):\n            AWQDTypePolicy(\"awq/8/128\", source_name=\"float32\")\n\n    def test_awq_dtype_policy_invalid_format(self):\n        \"\"\"Test AWQDTypePolicy rejects invalid format.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Invalid mode\"):\n            AWQDTypePolicy(\"awq/4\", source_name=\"float32\")\n"
  },
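The quantized policy names exercised by the tests above (`"int8_from_mixed_bfloat16"`, `"float8_from_float32"`, and the rejected `"int8_from_"` / `"int8_abc_"`) follow a `"<mode>_from_<source>"` convention. As an illustrative sketch of that convention only (the helper name `split_policy_name` is hypothetical, not Keras' actual `_get_quantized_dtype_policy_by_str`):

```python
def split_policy_name(name):
    # Hypothetical sketch of the "<mode>_from_<source>" naming convention
    # checked by the tests above; not the Keras implementation.
    if not isinstance(name, str):
        raise TypeError("`name` must be a string.")
    mode, sep, source_name = name.partition("_from_")
    if not sep or not source_name:
        raise ValueError(f"Cannot convert '{name}' into a valid pair.")
    return mode, source_name


print(split_policy_name("int8_from_mixed_bfloat16"))  # ('int8', 'mixed_bfloat16')
```

Names without the `_from_` separator, or with nothing after it, fail to split into a valid pair, which is the shape of the `ValueError` cases the tests assert.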
  {
    "path": "keras/src/export/__init__.py",
    "content": "from keras.src.export.litert import LiteRTExporter\nfrom keras.src.export.litert import export_litert\nfrom keras.src.export.onnx import export_onnx\nfrom keras.src.export.openvino import export_openvino\nfrom keras.src.export.saved_model import ExportArchive\nfrom keras.src.export.saved_model import export_saved_model\nfrom keras.src.export.tfsm_layer import TFSMLayer\n"
  },
  {
    "path": "keras/src/export/export_utils.py",
    "content": "from keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import tree\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\ndef get_input_signature(model):\n    \"\"\"Get input signature for model export.\n\n    Args:\n        model: A Keras Model instance.\n\n    Returns:\n        Input signature suitable for model export (always a tuple or list).\n    \"\"\"\n    if not isinstance(model, models.Model):\n        raise TypeError(\n            \"The model must be a `keras.Model`. \"\n            f\"Received: model={model} of the type {type(model)}\"\n        )\n    if not model.built:\n        raise ValueError(\n            \"The model provided has not yet been built. It must be built \"\n            \"before export.\"\n        )\n\n    if isinstance(model, models.Functional):\n        # Functional models expect a single positional argument `inputs`\n        # containing the full nested input structure. We keep the\n        # original behavior of returning a single-element list that\n        # wraps the mapped structure so that downstream exporters\n        # build a tf.function with one positional argument.\n        input_signature = [\n            tree.map_structure(make_input_spec, model._inputs_struct)\n        ]\n    elif isinstance(model, models.Sequential):\n        input_signature = tree.map_structure(make_input_spec, model.inputs)\n    else:\n        # Subclassed models: rely on recorded shapes from the first call.\n        input_signature = _infer_input_signature_from_model(model)\n        if not input_signature or not model._called:\n            raise ValueError(\n                \"The model provided has never called. 
\"\n                \"It must be called at least once before export.\"\n            )\n    return input_signature\n\n\ndef _infer_input_signature_from_model(model):\n    shapes_dict = getattr(model, \"_build_shapes_dict\", None)\n    if not shapes_dict:\n        return None\n\n    def _make_input_spec(structure):\n        # We need to turn wrapper structures like TrackingDict or _DictWrapper\n        # into plain Python structures because they don't work with jax2tf/JAX.\n        if isinstance(structure, dict):\n            return {k: _make_input_spec(v) for k, v in structure.items()}\n        elif isinstance(structure, tuple):\n            if all(isinstance(d, (int, type(None))) for d in structure):\n                return layers.InputSpec(\n                    shape=(None,) + structure[1:], dtype=model.input_dtype\n                )\n            return tuple(_make_input_spec(v) for v in structure)\n        elif isinstance(structure, list):\n            if all(isinstance(d, (int, type(None))) for d in structure):\n                return layers.InputSpec(\n                    shape=[None] + structure[1:], dtype=model.input_dtype\n                )\n            return [_make_input_spec(v) for v in structure]\n        else:\n            raise ValueError(\n                f\"Unsupported type {type(structure)} for {structure}\"\n            )\n\n    # Always return a flat list preserving the order of shapes_dict values\n    return [_make_input_spec(value) for value in shapes_dict.values()]\n\n\ndef make_input_spec(x):\n    if isinstance(x, layers.InputSpec):\n        if x.shape is None or x.dtype is None:\n            raise ValueError(\n                f\"The `shape` and `dtype` must be provided. 
Received: x={x}\"\n            )\n        input_spec = x\n    elif isinstance(x, backend.KerasTensor):\n        shape = (None,) + backend.standardize_shape(x.shape)[1:]\n        dtype = backend.standardize_dtype(x.dtype)\n        input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=x.name)\n    elif backend.is_tensor(x):\n        shape = (None,) + backend.standardize_shape(x.shape)[1:]\n        dtype = backend.standardize_dtype(x.dtype)\n        input_spec = layers.InputSpec(dtype=dtype, shape=shape, name=None)\n    else:\n        raise TypeError(\n            f\"Unsupported x={x} of the type ({type(x)}). Supported types are: \"\n            \"`keras.InputSpec`, `keras.KerasTensor` and backend tensor.\"\n        )\n    return input_spec\n\n\ndef make_tf_tensor_spec(x, dynamic_batch=False):\n    \"\"\"Create a TensorSpec from various input types.\n\n    Args:\n        x: Input to convert (tf.TensorSpec, KerasTensor, or backend tensor).\n        dynamic_batch: If True, set the batch dimension to None.\n\n    Returns:\n        A tf.TensorSpec instance.\n    \"\"\"\n    if isinstance(x, tf.TensorSpec):\n        tensor_spec = x\n        # Adjust batch dimension if needed\n        if dynamic_batch and len(tensor_spec.shape) > 0:\n            shape = tuple(\n                None if i == 0 else s for i, s in enumerate(tensor_spec.shape)\n            )\n            tensor_spec = tf.TensorSpec(\n                shape, dtype=tensor_spec.dtype, name=tensor_spec.name\n            )\n    else:\n        input_spec = make_input_spec(x)\n        shape = input_spec.shape\n        # Adjust batch dimension if needed and shape is not None\n        if dynamic_batch and shape is not None and len(shape) > 0:\n            shape = tuple(None if i == 0 else s for i, s in enumerate(shape))\n        tensor_spec = tf.TensorSpec(\n            shape, dtype=input_spec.dtype, name=input_spec.name\n        )\n    return tensor_spec\n\n\ndef convert_spec_to_tensor(spec, 
replace_none_number=None):\n    shape = backend.standardize_shape(spec.shape)\n    if replace_none_number is not None:\n        replace_none_number = int(replace_none_number)\n        shape = tuple(\n            s if s is not None else replace_none_number for s in shape\n        )\n    return ops.ones(shape, spec.dtype)\n"
  },
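The `dynamic_batch` handling in `make_tf_tensor_spec` above boils down to replacing the leading (batch) dimension of a shape with `None` while leaving scalars untouched. A minimal standalone sketch of that shape transform, using plain tuples in place of `tf.TensorSpec`:

```python
def with_dynamic_batch(shape):
    # Replace the leading (batch) dimension with None; empty shapes
    # (scalars) pass through unchanged, mirroring the `len(shape) > 0`
    # guard in make_tf_tensor_spec.
    if len(shape) == 0:
        return shape
    return tuple(None if i == 0 else s for i, s in enumerate(shape))


print(with_dynamic_batch((32, 10)))  # (None, 10)
```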
  {
    "path": "keras/src/export/litert.py",
    "content": "from keras.src import layers\nfrom keras.src import models\nfrom keras.src import tree\nfrom keras.src.export.export_utils import get_input_signature\nfrom keras.src.utils import io_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\ndef export_litert(\n    model,\n    filepath,\n    input_signature=None,\n    **kwargs,\n):\n    \"\"\"Export the model as a LiteRT artifact for inference.\n\n    Args:\n        model: The Keras model to export.\n        filepath: The path to save the exported artifact.\n        input_signature: Optional input signature specification. If\n            `None`, it will be inferred.\n        **kwargs: Additional keyword arguments passed to the exporter.\n    \"\"\"\n\n    exporter = LiteRTExporter(\n        model=model,\n        input_signature=input_signature,\n        **kwargs,\n    )\n    exporter.export(filepath)\n    io_utils.print_msg(f\"Saved artifact at '{filepath}'.\")\n\n\nclass LiteRTExporter:\n    \"\"\"Exporter for the LiteRT (TFLite) format.\n\n    This class handles the conversion of Keras models for LiteRT runtime and\n    generates a `.tflite` model file. 
For efficient inference on mobile and\n    embedded devices, it creates a single callable signature based on the\n    model's `call()` method.\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        input_signature=None,\n        **kwargs,\n    ):\n        \"\"\"Initialize the LiteRT exporter.\n\n        Args:\n            model: The Keras model to export\n            input_signature: Input signature specification (e.g., TensorFlow\n                TensorSpec or list of TensorSpec)\n            **kwargs: Additional export parameters\n        \"\"\"\n        self.model = model\n        self.input_signature = input_signature\n        self.kwargs = kwargs\n\n    def export(self, filepath):\n        \"\"\"Exports the Keras model to a TFLite file.\n\n        Args:\n            filepath: Output path for the exported model\n\n        Returns:\n            Path to exported model\n        \"\"\"\n        # 1. Resolve / infer input signature\n        if self.input_signature is None:\n            # Use the standard get_input_signature which handles all model types\n            # and preserves nested structures (dicts, lists, etc.)\n            self.input_signature = get_input_signature(self.model)\n\n        # 2. 
Determine input structure and create adapter if needed\n        # There are 3 cases:\n        # Case 1: Single input (not nested)\n        # Case 2: Flat list of inputs (list where flattened == original)\n        # Case 3: Nested structure (dicts, nested lists, etc.)\n\n        # Special handling for Functional models: get_input_signature wraps\n        # the structure in a list, so unwrap it for analysis\n        input_struct = self.input_signature\n        if (\n            isinstance(self.input_signature, list)\n            and len(self.input_signature) == 1\n        ):\n            input_struct = self.input_signature[0]\n\n        if not tree.is_nested(input_struct):\n            # Case 1: Single input - use as-is\n            model_to_convert = self.model\n            signature_for_conversion = self.input_signature\n        elif isinstance(input_struct, list) and len(input_struct) == len(\n            tree.flatten(input_struct)\n        ):\n            # Case 2: Flat list of inputs - use as-is\n            model_to_convert = self.model\n            signature_for_conversion = self.input_signature\n        else:\n            # Case 3: Nested structure (dict, nested lists, etc.)\n            # Create adapter model that converts flat list to nested structure\n            adapted_model = self._create_nested_inputs_adapter(input_struct)\n\n            # Flatten signature for TFLite conversion\n            signature_for_conversion = tree.flatten(input_struct)\n\n            # Use adapted model and flat list signature for conversion\n            model_to_convert = adapted_model\n\n        # Store original model reference for later use\n        original_model = self.model\n\n        # Temporarily replace self.model with the model to convert\n        self.model = model_to_convert\n\n        try:\n            # Convert the model to TFLite.\n            tflite_model = self._convert_to_tflite(signature_for_conversion)\n        finally:\n            # Restore original 
model\n            self.model = original_model\n\n        # Save the TFLite model to the specified file path.\n        if not filepath.endswith(\".tflite\"):\n            raise ValueError(\n                f\"The LiteRT export requires the filepath to end with \"\n                f\"'.tflite'. Got: {filepath}\"\n            )\n\n        with open(filepath, \"wb\") as f:\n            f.write(tflite_model)\n\n        return filepath\n\n    def _create_nested_inputs_adapter(self, input_signature_struct):\n        \"\"\"Create an adapter model that converts flat list inputs to nested\n        structure.\n\n        This adapter allows models expecting nested inputs (dicts, lists, etc.)\n        to be exported to TFLite format (which only supports positional/list\n        inputs).\n\n        Args:\n            input_signature_struct: Nested structure of InputSpecs (dict, list,\n                etc.)\n\n        Returns:\n            A Functional model that accepts flat list inputs and converts to\n            nested\n        \"\"\"\n        # Get flat paths to preserve names and print input mapping\n        paths_and_specs = tree.flatten_with_path(input_signature_struct)\n        paths = [\".\".join(str(e) for e in p) for p, v in paths_and_specs]\n        io_utils.print_msg(f\"Creating adapter for inputs: {paths}\")\n\n        # Create Input layers for TFLite (flat list-based)\n        input_layers = []\n        for path, spec in paths_and_specs:\n            # Extract the input name from spec or path\n            name = (\n                spec.name\n                if hasattr(spec, \"name\") and spec.name\n                else (str(path[-1]) if path else \"input\")\n            )\n\n            input_layer = layers.Input(\n                shape=spec.shape[1:],  # Remove batch dimension\n                dtype=spec.dtype,\n                name=name,\n            )\n            input_layers.append(input_layer)\n\n        # Reconstruct the nested structure from flat list\n   
     inputs_structure = tree.pack_sequence_as(\n            input_signature_struct, input_layers\n        )\n\n        # Call the original model with nested inputs\n        outputs = self.model(inputs_structure)\n\n        # Build as Functional model (flat list inputs -> nested -> model ->\n        # output)\n        adapted_model = models.Model(inputs=input_layers, outputs=outputs)\n\n        # Preserve the original model's variables\n        adapted_model._variables = self.model.variables\n        adapted_model._trainable_variables = self.model.trainable_variables\n        adapted_model._non_trainable_variables = (\n            self.model.non_trainable_variables\n        )\n\n        return adapted_model\n\n    def _convert_to_tflite(self, input_signature):\n        \"\"\"Converts the Keras model to TFLite format.\n\n        Returns:\n            A bytes object containing the serialized TFLite model.\n        \"\"\"\n        # Try direct conversion first for all models\n        try:\n            converter = tf.lite.TFLiteConverter.from_keras_model(self.model)\n            converter.target_spec.supported_ops = [\n                tf.lite.OpsSet.TFLITE_BUILTINS,\n                tf.lite.OpsSet.SELECT_TF_OPS,\n            ]\n            # Keras 3 only supports resource variables\n            converter.experimental_enable_resource_variables = True\n\n            # Apply any additional converter settings from kwargs\n            self._apply_converter_kwargs(converter)\n\n            tflite_model = converter.convert()\n\n            return tflite_model\n\n        except Exception as e:\n            # If direct conversion fails, raise the error with helpful message\n            raise RuntimeError(\n                f\"Direct TFLite conversion failed. This may be due to model \"\n                f\"complexity or unsupported operations. 
Error: {e}\"\n            ) from e\n\n    def _apply_converter_kwargs(self, converter):\n        \"\"\"Apply additional converter settings from kwargs.\n\n        Args:\n            converter: tf.lite.TFLiteConverter instance to configure\n\n        Raises:\n            ValueError: If any kwarg is not a valid converter attribute\n        \"\"\"\n        for attr, value in self.kwargs.items():\n            if attr == \"target_spec\" and isinstance(value, dict):\n                # Handle nested target_spec settings\n                for spec_key, spec_value in value.items():\n                    if hasattr(converter.target_spec, spec_key):\n                        setattr(converter.target_spec, spec_key, spec_value)\n                    else:\n                        raise ValueError(\n                            f\"Unknown target_spec attribute '{spec_key}'\"\n                        )\n            elif hasattr(converter, attr):\n                setattr(converter, attr, value)\n            else:\n                raise ValueError(f\"Unknown converter attribute '{attr}'\")\n"
  },
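The adapter in `_create_nested_inputs_adapter` hinges on a flatten-then-repack round trip: TFLite only accepts a flat, positional list of inputs, so a nested structure is flattened for conversion and repacked before calling the original model. A self-contained sketch of that round trip, assuming plain dicts/lists in place of `keras.src.tree` structures (dict entries are visited in sorted-key order so the round trip is deterministic):

```python
def flatten(struct):
    # Depth-first flatten of nested dicts/lists/tuples into a flat list
    # of leaves, analogous to tree.flatten.
    if isinstance(struct, dict):
        return [x for k in sorted(struct) for x in flatten(struct[k])]
    if isinstance(struct, (list, tuple)):
        return [x for v in struct for x in flatten(v)]
    return [struct]


def pack_as(struct, flat):
    # Rebuild `struct`'s shape from a flat list of leaves, analogous to
    # tree.pack_sequence_as.
    it = iter(flat)

    def _pack(s):
        if isinstance(s, dict):
            return {k: _pack(s[k]) for k in sorted(s)}
        if isinstance(s, (list, tuple)):
            return type(s)(_pack(v) for v in s)
        return next(it)

    return _pack(struct)


nested = {"x": "spec_x", "y": ["spec_y0", "spec_y1"]}
leaves = flatten(nested)
print(leaves)  # ['spec_x', 'spec_y0', 'spec_y1']
print(pack_as(nested, leaves) == nested)  # True
```

In the exporter, the leaves are `InputSpec`s turned into flat `Input` layers, and the repacked structure is what gets passed to the wrapped model's `call()`.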
  {
    "path": "keras/src/export/litert_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src import tree\nfrom keras.src.saving import saving_lib\nfrom keras.src.testing.test_utils import named_product\nfrom keras.src.utils.module_utils import litert\nfrom keras.src.utils.module_utils import tensorflow\n\n# Set up LiteRT interpreter with fallback logic:\n# 1. Try AI Edge LiteRT interpreter (preferred)\n# 2. Fall back to TensorFlow Lite interpreter if AI Edge LiteRT unavailable\nAI_EDGE_LITERT_AVAILABLE = False\nLiteRTInterpreter = None\n\nif backend.backend() == \"tensorflow\":\n    if litert.available:\n        try:\n            from ai_edge_litert.interpreter import (\n                Interpreter as LiteRTInterpreter,\n            )\n\n            AI_EDGE_LITERT_AVAILABLE = True\n        except (ImportError, OSError):\n            LiteRTInterpreter = tensorflow.lite.Interpreter\n    else:\n        LiteRTInterpreter = tensorflow.lite.Interpreter\n\n# Model types to test (LSTM only if AI Edge LiteRT is available)\nmodel_types = [\"sequential\", \"functional\"]\n# TODO(#21914): `\"lstm\"` does not work with ai-edge-litert==1.3.0.\n# Unfortunately, for TF 2.20.0, this is the only version which works. 
Uncomment\n# this part when we upgrade TF and ai-edge-litert.\n# if AI_EDGE_LITERT_AVAILABLE:\n#     model_types.append(\"lstm\")\n\n\nclass CustomModel(models.Model):\n    def __init__(self, layer_list):\n        super().__init__()\n        self.layer_list = layer_list\n\n    def call(self, input):\n        output = input\n        for layer in self.layer_list:\n            output = layer(output)\n        return output\n\n\ndef get_model(type=\"sequential\", input_shape=(10,), layer_list=None):\n    layer_list = layer_list or [\n        layers.Dense(10, activation=\"relu\"),\n        layers.BatchNormalization(),\n        layers.Dense(1, activation=\"sigmoid\"),\n    ]\n    if type == \"sequential\":\n        model = models.Sequential(layer_list)\n        model.build(input_shape=(None,) + input_shape)\n        return model\n    if type == \"functional\":\n        input = output = tree.map_shape_structure(layers.Input, input_shape)\n        for layer in layer_list:\n            output = layer(output)\n        return models.Model(inputs=input, outputs=output)\n    if type == \"subclass\":\n        model = CustomModel(layer_list)\n        model.build(input_shape=(None,) + input_shape)\n        # Trace the model with dummy data to ensure it's properly built for\n        # export\n        dummy_input = np.zeros((1,) + input_shape, dtype=np.float32)\n        _ = model(dummy_input)  # This traces the model\n        return model\n    if type == \"lstm\":\n        inputs = layers.Input((4, 10))\n        x = layers.Bidirectional(\n            layers.LSTM(\n                10,\n                kernel_initializer=\"he_normal\",\n                return_sequences=True,\n                kernel_regularizer=None,\n            ),\n            merge_mode=\"sum\",\n        )(inputs)\n        outputs = layers.Bidirectional(\n            layers.LSTM(\n                10,\n                kernel_initializer=\"he_normal\",\n                return_sequences=True,\n                
kernel_regularizer=None,\n            ),\n            merge_mode=\"concat\",\n        )(x)\n        return models.Model(inputs=inputs, outputs=outputs)\n    if type == \"multi_input\":\n        input1 = layers.Input(shape=input_shape, name=\"input1\")\n        input2 = layers.Input(shape=input_shape, name=\"input2\")\n        x1 = layers.Dense(10, activation=\"relu\")(input1)\n        x2 = layers.Dense(10, activation=\"relu\")(input2)\n        combined = layers.concatenate([x1, x2])\n        output = layers.Dense(1, activation=\"sigmoid\")(combined)\n        return models.Model(inputs=[input1, input2], outputs=output)\n    if type == \"multi_output\":\n        inputs = layers.Input(shape=input_shape)\n        shared = layers.Dense(20, activation=\"relu\")(inputs)\n        output1 = layers.Dense(1, activation=\"sigmoid\", name=\"output1\")(shared)\n        output2 = layers.Dense(3, activation=\"softmax\", name=\"output2\")(shared)\n        return models.Model(inputs=inputs, outputs=[output1, output2])\n    raise ValueError(f\"Unknown model type: {type}\")\n\n\ndef _convert_to_numpy(structure):\n    return tree.map_structure(\n        lambda x: x.numpy() if hasattr(x, \"numpy\") else np.array(x), structure\n    )\n\n\ndef _normalize_name(name):\n    normalized = name.split(\":\")[0]\n    if normalized.startswith(\"serving_default_\"):\n        normalized = normalized[len(\"serving_default_\") :]\n    return normalized\n\n\ndef _set_interpreter_inputs(interpreter, inputs):\n    input_details = interpreter.get_input_details()\n    if isinstance(inputs, dict):\n        for detail in input_details:\n            key = _normalize_name(detail[\"name\"])\n            if key in inputs:\n                value = inputs[key]\n            else:\n                matched_key = None\n                for candidate in inputs:\n                    if key.endswith(candidate) or candidate.endswith(key):\n                        matched_key = candidate\n                        break\n     
           if matched_key is None:\n                    raise KeyError(\n                        f\"Unable to match input '{detail['name']}' in provided \"\n                        f\"inputs\"\n                    )\n                value = inputs[matched_key]\n            interpreter.set_tensor(detail[\"index\"], value)\n    else:\n        values = inputs\n        if not isinstance(values, (list, tuple)):\n            values = [values]\n        if len(values) != len(input_details):\n            raise ValueError(\n                \"Number of provided inputs does not match interpreter signature\"\n            )\n        for detail, value in zip(input_details, values):\n            interpreter.set_tensor(detail[\"index\"], value)\n\n\ndef _get_interpreter_outputs(interpreter):\n    output_details = interpreter.get_output_details()\n    outputs = [\n        interpreter.get_tensor(detail[\"index\"]) for detail in output_details\n    ]\n    return outputs[0] if len(outputs) == 1 else outputs\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"tensorflow\",\n    reason=\"`export_litert` currently supports the TensorFlow backend only.\",\n)\nclass ExportLitertTest(testing.TestCase):\n    \"\"\"Test suite for LiteRT (TFLite) model export functionality.\n\n    Tests use AI Edge LiteRT interpreter when available, otherwise fall back\n    to TensorFlow Lite interpreter for validation.\n    \"\"\"\n\n    @parameterized.named_parameters(named_product(model_type=model_types))\n    def test_standard_model_export(self, model_type):\n        \"\"\"Test exporting standard model types to LiteRT format.\"\"\"\n        if model_type == \"lstm\" and not AI_EDGE_LITERT_AVAILABLE:\n            self.skipTest(\"LSTM models require AI Edge LiteRT interpreter.\")\n\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"exported_model.tflite\"\n        )\n        model = get_model(model_type)\n        batch_size = 1  # LiteRT expects batch_size=1\n        if model_type == 
"lstm":
            ref_input = np.random.normal(size=(batch_size, 4, 10))
        else:
            ref_input = np.random.normal(size=(batch_size, 10))
        ref_input = ref_input.astype("float32")
        ref_output = _convert_to_numpy(model(ref_input))

        # Test with model.export()
        model.export(temp_filepath, format="litert")
        self.assertTrue(os.path.exists(temp_filepath))

        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()
        _set_interpreter_inputs(interpreter, ref_input)
        interpreter.invoke()
        litert_output = _get_interpreter_outputs(interpreter)

        self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4)

    @parameterized.named_parameters(
        named_product(struct_type=["tuple", "array", "dict"])
    )
    def test_model_with_input_structure(self, struct_type):
        """Test exporting models with structured inputs (tuple/array/dict)."""
        batch_size = 1  # LiteRT expects batch_size=1
        base_input = np.random.normal(size=(batch_size, 10)).astype("float32")

        if struct_type == "tuple":
            # Use Functional API for proper Input layer handling
            input1 = layers.Input(shape=(10,), name="input_1")
            input2 = layers.Input(shape=(10,), name="input_2")
            output = layers.Add()([input1, input2])
            model = models.Model(inputs=[input1, input2], outputs=output)
            ref_input = (base_input, base_input * 2)
        elif struct_type == "array":
            # Use Functional API for proper Input layer handling
            input1 = layers.Input(shape=(10,), name="input_1")
            input2 = layers.Input(shape=(10,), name="input_2")
            output = layers.Add()([input1, input2])
            model = models.Model(inputs=[input1, input2], outputs=output)
            ref_input = [base_input, base_input * 2]
        elif struct_type == "dict":
            # Use Functional API for proper Input layer handling
            input1 = layers.Input(shape=(10,), name="x")
            input2 = layers.Input(shape=(10,), name="y")
            output = layers.Add()([input1, input2])
            model = models.Model(
                inputs={"x": input1, "y": input2}, outputs=output
            )
            ref_input = {"x": base_input, "y": base_input * 2}
        else:
            raise AssertionError("Unexpected structure type")

        temp_filepath = os.path.join(
            self.get_temp_dir(), "exported_model.tflite"
        )
        ref_output = _convert_to_numpy(
            model(tree.map_structure(ops.convert_to_tensor, ref_input))
        )

        # Test with model.export()
        model.export(temp_filepath, format="litert")
        export_path = temp_filepath
        interpreter = LiteRTInterpreter(model_path=export_path)
        interpreter.allocate_tensors()

        feed_inputs = ref_input
        if isinstance(feed_inputs, tuple):
            feed_inputs = list(feed_inputs)
        _set_interpreter_inputs(interpreter, feed_inputs)
        interpreter.invoke()
        litert_output = _get_interpreter_outputs(interpreter)

        self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4)

        # Verify export still works after saving/loading via saving_lib.
        archive_path = os.path.join(self.get_temp_dir(), "revived.keras")
        saving_lib.save_model(model, archive_path)
        revived_model = saving_lib.load_model(archive_path)
        revived_output = _convert_to_numpy(revived_model(ref_input))
        self.assertAllClose(ref_output, revived_output)

    def test_model_with_multiple_inputs(self):
        """Test exporting models with multiple inputs and batch resizing."""
        temp_filepath = os.path.join(
            self.get_temp_dir(), "exported_model.tflite"
        )

        # Use Functional API for proper Input layer handling
        input_x = layers.Input(shape=(10,), name="x")
        input_y = layers.Input(shape=(10,), name="y")
        output = layers.Add()([input_x, input_y])
        model = models.Model(inputs=[input_x, input_y], outputs=output)

        batch_size = 1  # LiteRT expects batch_size=1
        ref_input_x = np.random.normal(size=(batch_size, 10)).astype("float32")
        ref_input_y = np.random.normal(size=(batch_size, 10)).astype("float32")
        ref_output = _convert_to_numpy(model([ref_input_x, ref_input_y]))

        # Test with model.export()
        model.export(temp_filepath, format="litert")
        export_path = temp_filepath
        interpreter = LiteRTInterpreter(model_path=export_path)
        interpreter.allocate_tensors()

        _set_interpreter_inputs(interpreter, [ref_input_x, ref_input_y])
        interpreter.invoke()
        litert_output = _get_interpreter_outputs(interpreter)

        self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4)

        # Test with a different batch size by resizing interpreter inputs.
        larger_x = np.concatenate([ref_input_x, ref_input_x], axis=0)
        larger_y = np.concatenate([ref_input_y, ref_input_y], axis=0)
        input_details = interpreter.get_input_details()
        interpreter.resize_tensor_input(
            input_details[0]["index"], larger_x.shape
        )
        interpreter.resize_tensor_input(
            input_details[1]["index"], larger_y.shape
        )
        interpreter.allocate_tensors()
        _set_interpreter_inputs(interpreter, [larger_x, larger_y])
        interpreter.invoke()
        larger_output = _get_interpreter_outputs(interpreter)
        larger_ref_output = _convert_to_numpy(model([larger_x, larger_y]))
        self.assertAllClose(
            larger_ref_output, larger_output, atol=1e-4, rtol=1e-4
        )

    def test_export_with_custom_input_signature(self):
        """Test exporting with custom input signature specification."""
        model = get_model("sequential")
        temp_filepath = os.path.join(
            self.get_temp_dir(), "exported_model.tflite"
        )
        input_signature = [layers.InputSpec(shape=(None, 10), dtype="float32")]

        # Test with model.export()
        model.export(
            temp_filepath,
            format="litert",
            input_signature=input_signature,
        )
        export_path = temp_filepath
        self.assertTrue(os.path.exists(export_path))

        interpreter = LiteRTInterpreter(model_path=export_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        self.assertEqual(len(input_details), 1)
        self.assertEqual(tuple(input_details[0]["shape"][1:]), (10,))

    def test_multi_output_model_export(self):
        """Test exporting multi-output models."""
        model = get_model("multi_output")

        # Build the model
        ref_input = np.random.normal(size=(3, 10)).astype("float32")
        model(ref_input)

        temp_filepath = os.path.join(
            self.get_temp_dir(), "exported_model.tflite"
        )
        model.export(temp_filepath, format="litert")

        tflite_path = temp_filepath
        self.assertTrue(os.path.exists(tflite_path))

        # Test inference
        interpreter = LiteRTInterpreter(model_path=tflite_path)
        interpreter.allocate_tensors()

        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()

        self.assertEqual(len(output_details), 2)

        test_input = np.random.random(input_details[0]["shape"]).astype(
            np.float32
        )
        interpreter.set_tensor(input_details[0]["index"], test_input)
        interpreter.invoke()

        for detail in output_details:
            output = interpreter.get_tensor(detail["index"])
            self.assertIsInstance(output, np.ndarray)

    def test_export_with_verbose(self):
        """Test export with verbose output."""
        model = get_model("sequential")
        dummy_input = np.random.random((3, 10)).astype(np.float32)
        model(dummy_input)

        temp_filepath = os.path.join(
            self.get_temp_dir(), "exported_model.tflite"
        )

        # Export with verbose=True
        model.export(temp_filepath, format="litert", verbose=True)

        tflite_path = temp_filepath
        self.assertTrue(os.path.exists(tflite_path))

        # Verify the exported model works
        interpreter = LiteRTInterpreter(model_path=tflite_path)
        interpreter.allocate_tensors()

        input_details = interpreter.get_input_details()
        self.assertEqual(len(input_details), 1)

    def test_export_error_handling(self):
        """Test error handling in export API."""
        model = get_model("sequential")
        dummy_input = np.random.random((3, 10)).astype(np.float32)
        model(dummy_input)

        temp_filepath = os.path.join(
            self.get_temp_dir(), "exported_model.tflite"
        )

        # Test with invalid format
        with self.assertRaises(ValueError):
            model.export(temp_filepath, format="invalid_format")

    def test_export_invalid_filepath(self):
        """Test that export fails with invalid file extension."""
        model = get_model("sequential")
        dummy_input = np.random.random((3, 10)).astype(np.float32)
        model(dummy_input)

        temp_filepath = os.path.join(self.get_temp_dir(), "exported_model.txt")

        # Should raise ValueError for wrong extension
        with self.assertRaises(ValueError):
            model.export(temp_filepath, format="litert")

    def test_export_subclass_model(self):
        """Test exporting subclass models (uses wrapper conversion path)."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        model = get_model("subclass")
        temp_filepath = os.path.join(
            self.get_temp_dir(), "exported_model.tflite"
        )

        batch_size = 1
        ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")
        ref_output = _convert_to_numpy(model(ref_input))

        # Export subclass model - this tests wrapper-based conversion
        model.export(temp_filepath, format="litert")
        self.assertTrue(os.path.exists(temp_filepath))

        # Verify inference
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()
        _set_interpreter_inputs(interpreter, ref_input)
        interpreter.invoke()
        litert_output = _get_interpreter_outputs(interpreter)

        self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4)

    def test_export_with_optimizations_default(self):
        """Test export with DEFAULT optimization."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        model = get_model("sequential")
        temp_filepath = os.path.join(
            self.get_temp_dir(), "optimized_default.tflite"
        )

        batch_size = 1
        ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")
        ref_output = _convert_to_numpy(model(ref_input))

        # Export with DEFAULT optimization
        model.export(
            temp_filepath,
            format="litert",
            optimizations=[tensorflow.lite.Optimize.DEFAULT],
        )
        self.assertTrue(os.path.exists(temp_filepath))

        # Verify inference still works
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()
        _set_interpreter_inputs(interpreter, ref_input)
        interpreter.invoke()
        litert_output = _get_interpreter_outputs(interpreter)

        # Quantized model should be close but not exact
        self.assertAllClose(ref_output, litert_output, atol=1e-2, rtol=1e-2)

    def test_export_with_optimizations_sparsity(self):
        """Test export with EXPERIMENTAL_SPARSITY optimization."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        model = get_model("functional")
        temp_filepath = os.path.join(
            self.get_temp_dir(), "optimized_sparsity.tflite"
        )

        batch_size = 1
        ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")

        # Export with EXPERIMENTAL_SPARSITY optimization
        model.export(
            temp_filepath,
            format="litert",
            optimizations=[tensorflow.lite.Optimize.EXPERIMENTAL_SPARSITY],
        )
        self.assertTrue(os.path.exists(temp_filepath))

        # Verify the model can run inference
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()
        _set_interpreter_inputs(interpreter, ref_input)
        interpreter.invoke()
        litert_output = _get_interpreter_outputs(interpreter)

        # Output should have valid shape
        self.assertEqual(litert_output.shape, (batch_size, 1))

    def test_export_with_optimizations_size(self):
        """Test export with OPTIMIZE_FOR_SIZE optimization."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        model = get_model("sequential")
        temp_filepath = os.path.join(
            self.get_temp_dir(), "optimized_size.tflite"
        )

        batch_size = 1
        ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")

        # Export with OPTIMIZE_FOR_SIZE
        model.export(
            temp_filepath,
            format="litert",
            optimizations=[tensorflow.lite.Optimize.OPTIMIZE_FOR_SIZE],
        )
        self.assertTrue(os.path.exists(temp_filepath))

        # Verify the model can run inference
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()
        _set_interpreter_inputs(interpreter, ref_input)
        interpreter.invoke()
        litert_output = _get_interpreter_outputs(interpreter)

        self.assertEqual(litert_output.shape, (batch_size, 1))

    def test_export_with_optimizations_latency(self):
        """Test export with OPTIMIZE_FOR_LATENCY optimization."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        model = get_model("functional")
        temp_filepath = os.path.join(
            self.get_temp_dir(), "optimized_latency.tflite"
        )

        batch_size = 1
        ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")

        # Export with OPTIMIZE_FOR_LATENCY
        model.export(
            temp_filepath,
            format="litert",
            optimizations=[tensorflow.lite.Optimize.OPTIMIZE_FOR_LATENCY],
        )
        self.assertTrue(os.path.exists(temp_filepath))

        # Verify the model can run inference
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()
        _set_interpreter_inputs(interpreter, ref_input)
        interpreter.invoke()
        litert_output = _get_interpreter_outputs(interpreter)

        self.assertEqual(litert_output.shape, (batch_size, 1))

    def test_export_with_multiple_optimizations(self):
        """Test export with multiple optimization options combined."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        model = get_model("sequential")
        temp_filepath = os.path.join(
            self.get_temp_dir(), "optimized_multiple.tflite"
        )

        batch_size = 1
        ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")

        # Export with multiple optimizations
        model.export(
            temp_filepath,
            format="litert",
            optimizations=[
                tensorflow.lite.Optimize.DEFAULT,
                tensorflow.lite.Optimize.EXPERIMENTAL_SPARSITY,
            ],
        )
        self.assertTrue(os.path.exists(temp_filepath))

        # Verify the model can run inference
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()
        _set_interpreter_inputs(interpreter, ref_input)
        interpreter.invoke()
        litert_output = _get_interpreter_outputs(interpreter)

        self.assertEqual(litert_output.shape, (batch_size, 1))

    def test_export_with_representative_dataset(self):
        """Test export with representative dataset for better quantization."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        model = get_model("functional")
        temp_filepath = os.path.join(
            self.get_temp_dir(), "quantized_model.tflite"
        )

        # Create representative dataset
        def representative_dataset():
            for _ in range(10):
                yield [np.random.normal(size=(1, 10)).astype("float32")]

        # Export with optimizations and representative dataset
        model.export(
            temp_filepath,
            format="litert",
            optimizations=[tensorflow.lite.Optimize.DEFAULT],
            representative_dataset=representative_dataset,
        )
        self.assertTrue(os.path.exists(temp_filepath))

        # Verify the model can run inference
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()

        batch_size = 1
        ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")
        _set_interpreter_inputs(interpreter, ref_input)
        interpreter.invoke()
        litert_output = _get_interpreter_outputs(interpreter)

        # Output should have valid shape
        self.assertEqual(litert_output.shape, (batch_size, 1))

    def test_export_with_multiple_kwargs(self):
        """Test export with multiple converter kwargs."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        # Create a larger model for quantization testing
        inputs = layers.Input(shape=(28, 28, 3))
        x = layers.Conv2D(32, 3, activation="relu")(inputs)
        x = layers.MaxPooling2D()(x)
        x = layers.Flatten()(x)
        x = layers.Dense(10, activation="softmax")(x)
        model = models.Model(inputs, x)

        temp_filepath = os.path.join(
            self.get_temp_dir(), "multi_kwargs_model.tflite"
        )

        # Create representative dataset
        def representative_dataset():
            for _ in range(5):
                yield [np.random.normal(size=(1, 28, 28, 3)).astype("float32")]

        # Export with multiple kwargs
        model.export(
            temp_filepath,
            format="litert",
            optimizations=[tensorflow.lite.Optimize.DEFAULT],
            representative_dataset=representative_dataset,
            experimental_new_quantizer=True,
        )
        self.assertTrue(os.path.exists(temp_filepath))

        # Verify the exported file is non-empty
        file_size = os.path.getsize(temp_filepath)
        self.assertGreater(file_size, 0)

    def test_export_optimization_file_size_comparison(self):
        """Test that optimizations reduce file size."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        # Create a larger model to see size differences
        inputs = layers.Input(shape=(28, 28, 3))
        x = layers.Conv2D(64, 3, activation="relu")(inputs)
        x = layers.Conv2D(64, 3, activation="relu")(x)
        x = layers.MaxPooling2D()(x)
        x = layers.Flatten()(x)
        x = layers.Dense(128, activation="relu")(x)
        x = layers.Dense(10, activation="softmax")(x)
        model = models.Model(inputs, x)

        # Export without optimization
        filepath_no_opt = os.path.join(
            self.get_temp_dir(), "model_no_opt.tflite"
        )
        model.export(filepath_no_opt, format="litert")

        # Export with optimization
        filepath_with_opt = os.path.join(
            self.get_temp_dir(), "model_with_opt.tflite"
        )
        model.export(
            filepath_with_opt,
            format="litert",
            optimizations=[tensorflow.lite.Optimize.DEFAULT],
        )

        # Optimized model should be smaller
        size_no_opt = os.path.getsize(filepath_no_opt)
        size_with_opt = os.path.getsize(filepath_with_opt)

        self.assertLess(
            size_with_opt,
            size_no_opt,
            f"Optimized model ({size_with_opt} bytes) should be smaller "
            f"than non-optimized ({size_no_opt} bytes)",
        )

        # Typically expect ~75% size reduction with quantization
        reduction_ratio = size_with_opt / size_no_opt
        self.assertLess(
            reduction_ratio,
            0.5,  # Should be less than 50% of original size
            f"Expected significant size reduction, got {reduction_ratio:.2%}",
        )

    def test_signature_def_with_named_model(self):
        """Test that exported models have SignatureDef with input names."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        # Build a model with explicit layer names
        inputs = layers.Input(shape=(10,), name="feature_input")
        x = layers.Dense(32, activation="relu", name="encoder")(inputs)
        x = layers.Dense(16, activation="relu", name="bottleneck")(x)
        outputs = layers.Dense(
            1, activation="sigmoid", name="prediction_output"
        )(x)
        model = models.Model(inputs=inputs, outputs=outputs, name="named_model")

        temp_filepath = os.path.join(self.get_temp_dir(), "named_model.tflite")

        # Export the model
        model.export(temp_filepath, format="litert")
        self.assertTrue(os.path.exists(temp_filepath))

        # Load and check SignatureDef
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()

        # Get SignatureDef information
        signature_defs = interpreter.get_signature_list()
        self.assertIn("serving_default", signature_defs)

        serving_sig = signature_defs["serving_default"]
        sig_inputs = serving_sig.get("inputs", [])
        sig_outputs = serving_sig.get("outputs", [])

        # Verify SignatureDef has inputs and outputs
        self.assertGreater(
            len(sig_inputs), 0, "Should have at least one input in SignatureDef"
        )
        self.assertGreater(
            len(sig_outputs),
            0,
            "Should have at least one output in SignatureDef",
        )

        # Verify input names are preserved (they should match Keras input names)
        self.assertIn(
            "feature_input",
            sig_inputs,
            f"Input name 'feature_input' should be in SignatureDef inputs: "
            f"{sig_inputs}",
        )

        # Verify inference works using signature runner
        batch_size = 1
        ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")
        ref_output = _convert_to_numpy(model(ref_input))

        # Note: For single-output Functional models, Keras returns a tensor
        # (not dict). SignatureDef will have generic output names like
        # 'output_0'. Only multi-output models or models with explicit dict
        # returns have named outputs.

        # Test inference using signature runner for better output name handling
        signature_runner = interpreter.get_signature_runner("serving_default")
        sig_output = signature_runner(feature_input=ref_input)

        # sig_output should be a dict with meaningful output names
        self.assertIsInstance(sig_output, dict)
        self.assertGreater(
            len(sig_output), 0, "Should have at least one output"
        )

        # For single output, extract the value
        if len(sig_output) == 1:
            litert_output = list(sig_output.values())[0]
        else:
            litert_output = list(sig_output.values())

        self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4)

    def test_signature_def_with_functional_model(self):
        """Test that SignatureDef preserves input/output names for
        Functional models."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        # Create a Functional model with named inputs and outputs
        inputs = layers.Input(shape=(10,), name="input_layer")
        x = layers.Dense(32, activation="relu", name="hidden_layer")(inputs)
        outputs = layers.Dense(1, activation="sigmoid", name="output_layer")(x)
        model = models.Model(
            inputs=inputs, outputs=outputs, name="functional_model"
        )

        temp_filepath = os.path.join(
            self.get_temp_dir(), "functional_model.tflite"
        )

        # Export the model
        model.export(temp_filepath, format="litert")
        self.assertTrue(os.path.exists(temp_filepath))

        # Load and check SignatureDef
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()

        # Get SignatureDef information
        signature_defs = interpreter.get_signature_list()
        self.assertIn("serving_default", signature_defs)

        serving_sig = signature_defs["serving_default"]
        sig_inputs = serving_sig.get("inputs", [])
        sig_outputs = serving_sig.get("outputs", [])

        # Verify SignatureDef has inputs and outputs
        self.assertGreater(
            len(sig_inputs), 0, "Should have at least one input in SignatureDef"
        )
        self.assertGreater(
            len(sig_outputs),
            0,
            "Should have at least one output in SignatureDef",
        )

        # Verify that input names are preserved
        self.assertIn(
            "input_layer",
            sig_inputs,
            f"Input name 'input_layer' should be in SignatureDef inputs: "
            f"{sig_inputs}",
        )

        # Test inference using signature runner for named outputs
        batch_size = 1
        ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")
        ref_output = _convert_to_numpy(model(ref_input))

        # Use signature runner to get outputs with meaningful names
        signature_runner = interpreter.get_signature_runner("serving_default")
        sig_output = signature_runner(input_layer=ref_input)

        # sig_output should be a dict with output names
        self.assertIsInstance(sig_output, dict)
        self.assertGreater(
            len(sig_output), 0, "Should have at least one output"
        )

        # For single output, TFLite typically uses generic names like 'output_0'
        # Extract the single output value
        if len(sig_output) == 1:
            litert_output = list(sig_output.values())[0]
        else:
            litert_output = list(sig_output.values())

        self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4)

    def test_signature_def_with_multi_input_model(self):
        """Test that SignatureDef preserves names for multi-input models."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        # Create a multi-input model
        input1 = layers.Input(shape=(10,), name="input_1")
        input2 = layers.Input(shape=(5,), name="input_2")
        concat = layers.Concatenate(name="concat_layer")([input1, input2])
        outputs = layers.Dense(1, activation="sigmoid", name="output")(concat)
        model = models.Model(
            inputs=[input1, input2], outputs=outputs, name="multi_input_model"
        )

        temp_filepath = os.path.join(
            self.get_temp_dir(), "multi_input_model.tflite"
        )

        # Export the model
        model.export(temp_filepath, format="litert")
        self.assertTrue(os.path.exists(temp_filepath))

        # Load and check SignatureDef
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()

        # Get SignatureDef information
        signature_defs = interpreter.get_signature_list()
        self.assertIn("serving_default", signature_defs)

        serving_sig = signature_defs["serving_default"]
        sig_inputs = serving_sig.get("inputs", [])
        sig_outputs = serving_sig.get("outputs", [])

        # Verify SignatureDef has correct number of inputs and outputs
        self.assertEqual(
            len(sig_inputs), 2, "Should have 2 inputs in SignatureDef"
        )
        self.assertGreater(
            len(sig_outputs),
            0,
            "Should have at least one output in SignatureDef",
        )

        # Verify that input names are preserved
        self.assertIn(
            "input_1",
            sig_inputs,
            f"Input name 'input_1' should be in SignatureDef inputs: "
            f"{sig_inputs}",
        )
        self.assertIn(
            "input_2",
            sig_inputs,
            f"Input name 'input_2' should be in SignatureDef inputs: "
            f"{sig_inputs}",
        )

        # Test inference using signature runner
        batch_size = 1
        ref_input1 = np.random.normal(size=(batch_size, 10)).astype("float32")
        ref_input2 = np.random.normal(size=(batch_size, 5)).astype("float32")
        ref_inputs = [ref_input1, ref_input2]
        ref_output = _convert_to_numpy(model(ref_inputs))

        # Use signature runner with named inputs
        signature_runner = interpreter.get_signature_runner("serving_default")
        sig_output = signature_runner(input_1=ref_input1, input_2=ref_input2)

        # sig_output should be a dict with output names
        self.assertIsInstance(sig_output, dict)
        self.assertGreater(
            len(sig_output), 0, "Should have at least one output"
        )

        # For single output, TFLite uses generic names like 'output_0'
        # Extract the single output value
        if len(sig_output) == 1:
            litert_output = list(sig_output.values())[0]
        else:
            litert_output = list(sig_output.values())

        self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4)

    def test_signature_def_with_multi_output_model(self):
        """Test that SignatureDef handles multi-output models correctly."""
        if LiteRTInterpreter is None:
            self.skipTest("No LiteRT interpreter available")

        # Create a multi-output model
        inputs = layers.Input(shape=(10,), name="input_layer")
        x = layers.Dense(32, activation="relu", name="shared_layer")(inputs)
        output1 = layers.Dense(1, activation="sigmoid", name="output_1")(x)
        output2 = layers.Dense(2, activation="softmax", name="output_2")(x)
        model = models.Model(
            inputs=inputs, outputs=[output1, output2], name="multi_output_model"
        )

        temp_filepath = os.path.join(
            self.get_temp_dir(), "multi_output_model.tflite"
        )

        # Export the model
        model.export(temp_filepath, format="litert")
        self.assertTrue(os.path.exists(temp_filepath))

        # Load and check SignatureDef
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()

        # Get SignatureDef information
        signature_defs = interpreter.get_signature_list()
        self.assertIn("serving_default", signature_defs)

        serving_sig = signature_defs["serving_default"]
        sig_inputs = serving_sig.get("inputs", [])
        sig_outputs = serving_sig.get("outputs", [])

        # Verify SignatureDef structure
        self.assertGreater(
            len(sig_inputs), 0, "Should have at least one input in SignatureDef"
        )
        self.assertEqual(
            len(sig_outputs), 2, "Should have 2 outputs in SignatureDef"
        )

        # Test inference using signature runner
        batch_size = 1
        ref_input = np.random.normal(size=(batch_size, 10)).astype("float32")
        ref_outputs = _convert_to_numpy(model(ref_input))

        # Use signature runner
        signature_runner = interpreter.get_signature_runner("serving_default")
        sig_output = signature_runner(input_layer=ref_input)

        # sig_output should be a dict with output names
        self.assertIsInstance(sig_output, dict)
        self.assertEqual(len(sig_output), 2, "Should have 2 outputs")

        # Note: TFLite uses generic names like 'output_0', 'output_1' for
        # SignatureDef outputs. These don't match the Keras layer names
        # ('output_1', 'output_2') - this is expected. The names come from
        # TensorFlow's symbolic tracing, not from our exporter code.
        # Verify outputs match by position
        sig_output_values = list(sig_output.values())
        for i, ref_out in enumerate(ref_outputs):
            self.assertAllClose(
                ref_out, sig_output_values[i], atol=1e-4, rtol=1e-4
            )

    def test_dict_input_adapter_creation(self):
        """Test that dict input adapter is created and works correctly."""

        # Create a model with dictionary inputs
        input1 = layers.Input(shape=(10,), name="x")
        input2 = layers.Input(shape=(10,), name="y")
        output = layers.Add()([input1, input2])
        model = models.Model(inputs={"x": input1, "y": input2}, outputs=output)

        temp_filepath = os.path.join(
            self.get_temp_dir(), "dict_adapter_model.tflite"
        )

        # Export with verbose to verify adapter creation messages
        model.export(temp_filepath, format="litert", verbose=True)

        # Verify the file was created
        self.assertTrue(os.path.exists(temp_filepath))

        # Load and test the model
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()

        # Check input details - should have 2 inputs in list form
        input_details = interpreter.get_input_details()
        self.assertEqual(len(input_details), 2)

        # Test inference
        batch_size = 1
        x_val = np.random.normal(size=(batch_size, 10)).astype("float32")
        y_val = np.random.normal(size=(batch_size, 10)).astype("float32")

        ref_output = _convert_to_numpy(
            model(
                {
                    "x": ops.convert_to_tensor(x_val),
                    "y": ops.convert_to_tensor(y_val),
                }
            )
        )

        # Set inputs as list (adapter converts list to dict internally)
        _set_interpreter_inputs(interpreter, [x_val, y_val])
        interpreter.invoke()
        litert_output = _get_interpreter_outputs(interpreter)

        self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4)

    def test_dict_input_signature_inference(self):
        """Test automatic inference of dict input signatures."""

        # Create a model with dictionary inputs (without calling it first)
        input1 = layers.Input(shape=(5,), name="feature_a")
        input2 = layers.Input(shape=(3,), name="feature_b")
        concat = layers.Concatenate()([input1, input2])
        output = layers.Dense(1)(concat)
        model = models.Model(
            inputs={"feature_a": input1, "feature_b": input2}, outputs=output
        )

        temp_filepath = os.path.join(
            self.get_temp_dir(), "inferred_dict_model.tflite"
        )

        # Export without providing input_signature - should be inferred
        model.export(temp_filepath, format="litert")

        # Verify successful export
        self.assertTrue(os.path.exists(temp_filepath))

        # Load and verify structure
        interpreter = LiteRTInterpreter(model_path=temp_filepath)
        interpreter.allocate_tensors()

        input_details = interpreter.get_input_details()
        self.assertEqual(len(input_details), 2)

        # Verify shapes match expected
        shapes = [tuple(d["shape"][1:]) for d in input_details]
        self.assertIn((5,), shapes)
        self.assertIn((3,), shapes)

    def test_dict_input_with_custom_signature(self):
        """Test dict input export with custom input signature."""

        # Create model with dict inputs
        input1 = layers.Input(shape=(10,), name="input_x")
        input2 = layers.Input(shape=(10,), name="input_y")
        output = layers.Multiply()([input1, input2])
      model = models.Model(\n            inputs={\"input_x\": input1, \"input_y\": input2}, outputs=output\n        )\n\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"dict_custom_sig_model.tflite\"\n        )\n\n        # Provide custom dict input signature\n        input_signature = {\n            \"input_x\": layers.InputSpec(shape=(None, 10), dtype=\"float32\"),\n            \"input_y\": layers.InputSpec(shape=(None, 10), dtype=\"float32\"),\n        }\n\n        model.export(\n            temp_filepath, format=\"litert\", input_signature=input_signature\n        )\n\n        # Verify export\n        self.assertTrue(os.path.exists(temp_filepath))\n\n        interpreter = LiteRTInterpreter(model_path=temp_filepath)\n        interpreter.allocate_tensors()\n\n        # Test inference\n        batch_size = 1\n        x_val = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        y_val = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n\n        ref_output = _convert_to_numpy(\n            model(\n                {\n                    \"input_x\": ops.convert_to_tensor(x_val),\n                    \"input_y\": ops.convert_to_tensor(y_val),\n                }\n            )\n        )\n\n        _set_interpreter_inputs(interpreter, [x_val, y_val])\n        interpreter.invoke()\n        litert_output = _get_interpreter_outputs(interpreter)\n\n        self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4)\n\n    def test_dict_input_numerical_accuracy(self):\n        \"\"\"Test numerical accuracy of dict input models with complex ops.\"\"\"\n\n        # Create a more complex model with dict inputs\n        input1 = layers.Input(shape=(20,), name=\"tokens\")\n        input2 = layers.Input(shape=(20,), name=\"mask\")\n\n        # Apply some transformations\n        x1 = layers.Dense(16, activation=\"relu\")(input1)\n        x2 = layers.Dense(16, activation=\"relu\")(input2)\n\n        # Combine\n        
combined = layers.Multiply()([x1, x2])\n        output = layers.Dense(1, activation=\"sigmoid\")(combined)\n\n        model = models.Model(\n            inputs={\"tokens\": input1, \"mask\": input2}, outputs=output\n        )\n\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"dict_numerical_model.tflite\"\n        )\n\n        model.export(temp_filepath, format=\"litert\")\n\n        # Test with multiple samples\n        batch_size = 1\n        tokens_val = np.random.normal(size=(batch_size, 20)).astype(\"float32\")\n        mask_val = np.random.normal(size=(batch_size, 20)).astype(\"float32\")\n\n        ref_output = _convert_to_numpy(\n            model(\n                {\n                    \"tokens\": ops.convert_to_tensor(tokens_val),\n                    \"mask\": ops.convert_to_tensor(mask_val),\n                }\n            )\n        )\n\n        interpreter = LiteRTInterpreter(model_path=temp_filepath)\n        interpreter.allocate_tensors()\n        _set_interpreter_inputs(interpreter, [tokens_val, mask_val])\n        interpreter.invoke()\n        litert_output = _get_interpreter_outputs(interpreter)\n\n        # Should have good numerical accuracy\n        self.assertAllClose(ref_output, litert_output, atol=1e-5, rtol=1e-5)\n\n    def test_dict_input_preserves_variable_sharing(self):\n        \"\"\"Test that adapter preserves variable sharing from original model.\"\"\"\n\n        # Create model with shared layers\n        shared_dense = layers.Dense(8, activation=\"relu\")\n\n        input1 = layers.Input(shape=(10,), name=\"branch_a\")\n        input2 = layers.Input(shape=(10,), name=\"branch_b\")\n\n        # Both inputs go through same shared layer\n        x1 = shared_dense(input1)\n        x2 = shared_dense(input2)\n\n        output = layers.Add()([x1, x2])\n        model = models.Model(\n            inputs={\"branch_a\": input1, \"branch_b\": input2}, outputs=output\n        )\n\n        # Train briefly to ensure 
weights are meaningful\n        model.compile(optimizer=\"adam\", loss=\"mse\")\n        x_train = {\n            \"branch_a\": np.random.normal(size=(5, 10)).astype(\"float32\"),\n            \"branch_b\": np.random.normal(size=(5, 10)).astype(\"float32\"),\n        }\n        y_train = np.random.normal(size=(5, 8)).astype(\"float32\")\n        model.fit(x_train, y_train, epochs=1, verbose=0)\n\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"dict_shared_vars_model.tflite\"\n        )\n\n        model.export(temp_filepath, format=\"litert\")\n\n        # Verify export works and inference matches\n        interpreter = LiteRTInterpreter(model_path=temp_filepath)\n        interpreter.allocate_tensors()\n\n        batch_size = 1\n        a_val = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        b_val = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n\n        ref_output = _convert_to_numpy(\n            model(\n                {\n                    \"branch_a\": ops.convert_to_tensor(a_val),\n                    \"branch_b\": ops.convert_to_tensor(b_val),\n                }\n            )\n        )\n\n        _set_interpreter_inputs(interpreter, [a_val, b_val])\n        interpreter.invoke()\n        litert_output = _get_interpreter_outputs(interpreter)\n\n        self.assertAllClose(ref_output, litert_output, atol=1e-4, rtol=1e-4)\n\n    def test_dict_input_multi_output_model(self):\n        \"\"\"Test dict input model with multiple outputs exports successfully.\"\"\"\n\n        # Create model with dict inputs and multiple outputs\n        input1 = layers.Input(shape=(10,), name=\"feature_1\")\n        input2 = layers.Input(shape=(10,), name=\"feature_2\")\n\n        # Two output branches\n        output1 = layers.Dense(5, name=\"output_a\")(input1)\n        output2 = layers.Dense(3, name=\"output_b\")(input2)\n\n        model = models.Model(\n            inputs={\"feature_1\": input1, \"feature_2\": 
input2},\n            outputs=[output1, output2],\n        )\n\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"dict_multi_output_model.tflite\"\n        )\n\n        # Main test: export should succeed with dict inputs + multi outputs\n        model.export(temp_filepath, format=\"litert\")\n\n        # Verify file was created\n        self.assertTrue(os.path.exists(temp_filepath))\n\n        # Verify structure\n        interpreter = LiteRTInterpreter(model_path=temp_filepath)\n        interpreter.allocate_tensors()\n\n        # Should have 2 inputs (from dict)\n        input_details = interpreter.get_input_details()\n        self.assertEqual(len(input_details), 2)\n\n        # Should have 2 outputs\n        output_details = interpreter.get_output_details()\n        self.assertEqual(len(output_details), 2)\n\n        # Verify shapes\n        output_shapes = [tuple(d[\"shape\"][1:]) for d in output_details]\n        self.assertIn((5,), output_shapes)\n        self.assertIn((3,), output_shapes)\n"
  },
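The dict-input tests above rely on an adapter that accepts inputs positionally and rebuilds the dictionary the Keras model expects. A minimal standalone sketch of that idea; the helper name `make_dict_adapter` and the sorted-key ordering are illustrative assumptions, not the exporter's actual implementation:

```python
# Hypothetical sketch of a list-to-dict input adapter, as exercised by the
# dict-input export tests: the compiled artifact receives inputs
# positionally, and the adapter maps them back onto the dict keys the
# original Keras model was built with. Sorted-key ordering is an assumption
# made here for determinism.

def make_dict_adapter(keys, fn):
    ordered = sorted(keys)  # fixed positional order for the flat inputs

    def adapted(*args):
        if len(args) != len(ordered):
            raise ValueError(
                f"expected {len(ordered)} inputs, got {len(args)}"
            )
        # Rebuild the dict the wrapped callable expects.
        return fn(dict(zip(ordered, args)))

    return adapted


# Toy "model" standing in for models.Model(inputs={"x": ..., "y": ...}).
model_fn = lambda d: d["x"] + d["y"]
adapted = make_dict_adapter({"x", "y"}, model_fn)
print(adapted(2.0, 3.0))  # positional x=2.0, y=3.0 -> 5.0
```

This mirrors why the tests can call `_set_interpreter_inputs(interpreter, [x_val, y_val])` with a plain list even though the model was defined with dict inputs.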
  {
    "path": "keras/src/export/neptune_model_export_archive.py",
    "content": "\"\"\"Base class for NeptuneModel export archive.\"\"\"\n\n\nclass NeptuneModelExportArchive:\n    def __init__(self):\n        raise NotImplementedError(\n            \"NeptuneExportArchive is an abstract class. \"\n            \"Use a subclass such as OrbaxSavedModelExportArchive.\"\n        )\n\n    def track(self, resource):\n        raise NotImplementedError()\n\n    def add_endpoint(self, name, fn, input_signature=None, **kwargs):\n        raise NotImplementedError()\n\n    def track_and_add_endpoint(self, name, resource, input_signature, **kwargs):\n        raise NotImplementedError()\n\n    def add_variable_collection(self, name, variables):\n        raise NotImplementedError()\n\n    def write_out(self, filepath, options=None, verbose=True):\n        raise NotImplementedError()\n"
  },
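`NeptuneModelExportArchive` is a pure interface: every method raises until a subclass fills it in. A toy illustration of satisfying that contract; the `InMemoryExportArchive` subclass and its dict-based storage are invented for this example, whereas a real subclass such as `OrbaxSavedModelExportArchive` would write an on-disk artifact:

```python
# Illustrative only: a minimal concrete subclass of the abstract export
# archive interface above. The interface is restated here so the sketch is
# self-contained; `InMemoryExportArchive` is hypothetical.

class NeptuneModelExportArchive:
    # Interface restated from the source file (the real base __init__
    # raises NotImplementedError to block direct instantiation).
    def track(self, resource):
        raise NotImplementedError()

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        raise NotImplementedError()

    def write_out(self, filepath, options=None, verbose=True):
        raise NotImplementedError()


class InMemoryExportArchive(NeptuneModelExportArchive):
    """Toy subclass that records endpoints in memory instead of on disk."""

    def __init__(self):
        self._resources = []
        self._endpoints = {}

    def track(self, resource):
        self._resources.append(resource)

    def add_endpoint(self, name, fn, input_signature=None, **kwargs):
        self._endpoints[name] = (fn, input_signature)

    def write_out(self, filepath, options=None, verbose=True):
        # A real implementation would serialize; here we just report.
        return {"path": filepath, "endpoints": sorted(self._endpoints)}


archive = InMemoryExportArchive()
archive.track("model-weights")
archive.add_endpoint("serve", lambda x: x)
manifest = archive.write_out("/tmp/archive")
```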
  {
    "path": "keras/src/export/onnx.py",
    "content": "import warnings\n\nfrom keras.src import backend\nfrom keras.src import tree\nfrom keras.src.export.export_utils import convert_spec_to_tensor\nfrom keras.src.export.export_utils import get_input_signature\nfrom keras.src.export.export_utils import make_tf_tensor_spec\nfrom keras.src.export.saved_model import DEFAULT_ENDPOINT_NAME\nfrom keras.src.export.saved_model import ExportArchive\nfrom keras.src.export.tf2onnx_lib import patch_tf2onnx\nfrom keras.src.utils import io_utils\n\n\ndef export_onnx(\n    model,\n    filepath,\n    verbose=None,\n    input_signature=None,\n    opset_version=None,\n    **kwargs,\n):\n    \"\"\"Export the model as a ONNX artifact for inference.\n\n    This method lets you export a model to a lightweight ONNX artifact\n    that contains the model's forward pass only (its `call()` method)\n    and can be served via e.g. ONNX Runtime.\n\n    The original code of the model (including any custom layers you may\n    have used) is *no longer* necessary to reload the artifact -- it is\n    entirely standalone.\n\n    Args:\n        filepath: `str` or `pathlib.Path` object. The path to save the artifact.\n        verbose: `bool`. Whether to print a message during export. Defaults to\n            `None`, which uses the default value set by different backends and\n            formats.\n        input_signature: Optional. Specifies the shape and dtype of the model\n            inputs. Can be a structure of `keras.InputSpec`, `tf.TensorSpec`,\n            `backend.KerasTensor`, or backend tensor. If not provided, it will\n            be automatically computed. Defaults to `None`.\n        opset_version: Optional. An integer value that specifies the ONNX opset\n            version. If not provided, the default version for the backend will\n            be used. 
Defaults to `None`.\n        **kwargs: Additional keyword arguments.\n\n    **Note:** This feature is currently supported only with TensorFlow, JAX and\n    Torch backends.\n\n    **Note:** The dtype policy must be \"float32\" for the model. You can further\n    optimize the ONNX artifact using the ONNX toolkit. Learn more here:\n    [https://onnxruntime.ai/docs/performance/](https://onnxruntime.ai/docs/performance/).\n\n    **Note:** The dynamic shape feature is not yet supported with Torch\n    backend. As a result, you must fully define the shapes of the inputs using\n    `input_signature`. If `input_signature` is not provided, all instances of\n    `None` (such as the batch size) will be replaced with `1`.\n\n    Example:\n\n    ```python\n    # Export the model as a ONNX artifact\n    model.export(\"path/to/location\", format=\"onnx\")\n\n    # Load the artifact in a different process/environment\n    ort_session = onnxruntime.InferenceSession(\"path/to/location\")\n    ort_inputs = {\n        k.name: v for k, v in zip(ort_session.get_inputs(), input_data)\n    }\n    predictions = ort_session.run(None, ort_inputs)\n    ```\n    \"\"\"\n    actual_verbose = verbose\n    if actual_verbose is None:\n        actual_verbose = True  # Defaults to `True` for all backends.\n\n    if input_signature is None:\n        input_signature = get_input_signature(model)\n        if not input_signature or not model._called:\n            raise ValueError(\"The model provided has never called. 
\")\n\n    # Extract specs for proper input name generation\n    if len(input_signature) == 1 and isinstance(input_signature[0], list):\n        # Multi-input case: input_signature = [[spec1, spec2, ...]]\n        specs_for_names = input_signature[0]\n    else:\n        # Single input case: input_signature = [spec]\n        specs_for_names = input_signature\n\n    input_names = [\n        getattr(spec, \"name\", None) or f\"input_{i}\"\n        for i, spec in enumerate(specs_for_names)\n    ]\n\n    if backend.backend() in (\"tensorflow\", \"jax\"):\n        from keras.src.utils.module_utils import tf2onnx\n\n        input_signature = tree.map_structure(\n            make_tf_tensor_spec, input_signature\n        )\n        decorated_fn = get_concrete_fn(model, input_signature, **kwargs)\n\n        # Use `tf2onnx` to convert the `decorated_fn` to the ONNX format.\n        patch_tf2onnx()  # TODO: Remove this once `tf2onnx` supports numpy 2.\n        tf2onnx.convert.from_function(\n            decorated_fn,\n            input_signature,\n            opset=opset_version,\n            output_path=filepath,\n        )\n\n    elif backend.backend() == \"torch\":\n        import torch\n\n        \"\"\"Generate dynamic_axes format for ONNX export.\"\"\"\n        dynamic_axes = {}\n\n        for input_idx, spec in enumerate(specs_for_names):\n            if not hasattr(spec, \"shape\"):\n                continue\n\n            shape = spec.shape\n            dynamic_dims = {}\n\n            for dim_idx, dim_size in enumerate(shape):\n                if dim_size is None:\n                    if dim_idx == 0:\n                        dim_name = \"batch\"\n                    else:\n                        dim_name = f\"dim_{input_idx}_{dim_idx}\"\n                    dynamic_dims[dim_idx] = dim_name\n\n            if dynamic_dims:\n                input_name = (\n                    input_names[input_idx]\n                    if input_idx < len(input_names)\n                  
  else f\"input_{input_idx}\"\n                )\n                dynamic_axes[input_name] = dynamic_dims\n\n        sample_inputs = tree.map_structure(\n            lambda x: convert_spec_to_tensor(x, replace_none_number=1),\n            input_signature,\n        )\n\n        sample_inputs = tuple(sample_inputs)\n        # TODO: Make dict model exportable.\n        if any(isinstance(x, dict) for x in sample_inputs):\n            raise ValueError(\n                \"Currently, `export_onnx` in the torch backend doesn't support \"\n                \"dictionaries as inputs.\"\n            )\n\n        if hasattr(model, \"eval\"):\n            model.eval()\n        with warnings.catch_warnings():\n            # Suppress some unuseful warnings.\n            warnings.filterwarnings(\n                \"ignore\",\n                message=r\".*\\n.*\\n*.*\\n*.*export will treat it as a constant.*\",\n            )\n            warnings.filterwarnings(\n                \"ignore\",\n                message=r\".*not properly registered as a submodule,.*\",\n            )\n            warnings.filterwarnings(\n                \"ignore\",\n                message=r\".*which is what 'get_attr' Nodes typically target.*\",\n            )\n            warnings.filterwarnings(\n                \"ignore\",\n                message=r\".*underlying reference in the owning GraphModule.*\",\n            )\n            warnings.filterwarnings(\n                \"ignore\", message=r\".*suppressed about get_attr references.*\"\n            )\n            # Suppress TorchScript tracing warnings\n            warnings.filterwarnings(\n                \"ignore\",\n                message=r\".*Converting a tensor to a Python boolean.*\",\n                category=torch.jit.TracerWarning,\n            )\n            warnings.filterwarnings(\n                \"ignore\",\n                message=r\".*Converting a tensor to a Python integer.*\",\n                category=torch.jit.TracerWarning,\n   
         )\n            warnings.filterwarnings(\n                \"ignore\",\n                message=r\".*Iterating over a tensor.*\",\n                category=torch.jit.TracerWarning,\n            )\n            warnings.filterwarnings(\n                \"ignore\",\n                message=r\".*Using len to get tensor shape.*\",\n                category=torch.jit.TracerWarning,\n            )\n            warnings.filterwarnings(\n                \"ignore\",\n                message=r\".*torch.tensor results are registered as constants.*\",\n                category=torch.jit.TracerWarning,\n            )\n\n        # When dynamic shapes are present, prefer TorchScript over\n        # TorchDynamo because TorchDynamo has constraint inference issues\n        # with dynamic dimensions\n        if not dynamic_axes:\n            try:\n                # Try the TorchDynamo-based ONNX exporter first for static\n                # shapes\n                export_kwargs = {\n                    \"verbose\": actual_verbose,\n                    \"opset_version\": opset_version,\n                    \"input_names\": input_names,\n                    \"dynamo\": True,\n                }\n\n                onnx_program = torch.onnx.export(\n                    model, sample_inputs, **export_kwargs\n                )\n                if hasattr(onnx_program, \"optimize\"):\n                    onnx_program.optimize()  # Only supported by torch>=2.6.0.\n                onnx_program.save(filepath)\n\n                return\n            except Exception:\n                pass\n\n        \"\"\"Export using TorchScript-based ONNX exporter.\"\"\"\n        # Set verbose to False for TorchScript due to file system leakage\n        torchscript_verbose = verbose\n        if verbose is None:\n            # Set to `False` due to file system leakage issue:\n            # https://github.com/keras-team/keras/issues/20826\n            torchscript_verbose = False\n\n        export_kwargs = 
{\n            \"verbose\": torchscript_verbose,\n            \"opset_version\": opset_version,\n            \"input_names\": input_names,\n            \"export_params\": True,\n            \"do_constant_folding\": True,\n            \"dynamo\": False,\n        }\n\n        # For TorchScript (dynamo=False), use dynamic_axes parameter\n        if dynamic_axes:\n            export_kwargs[\"dynamic_axes\"] = dynamic_axes\n\n        torch.onnx.export(model, sample_inputs, filepath, **export_kwargs)\n    else:\n        raise NotImplementedError(\n            \"`export_onnx` is only compatible with TensorFlow, JAX and \"\n            \"Torch backends.\"\n        )\n\n    if actual_verbose:\n        io_utils.print_msg(f\"Saved artifact at '{filepath}'.\")\n\n\ndef _check_jax_kwargs(kwargs):\n    kwargs = kwargs.copy()\n    if \"is_static\" not in kwargs:\n        kwargs[\"is_static\"] = True\n    if \"jax2tf_kwargs\" not in kwargs:\n        # TODO: These options will be deprecated in JAX. We need to\n        # find another way to export ONNX.\n        kwargs[\"jax2tf_kwargs\"] = {\n            \"enable_xla\": False,\n            \"native_serialization\": False,\n        }\n    if kwargs[\"is_static\"] is not True:\n        raise ValueError(\n            \"`is_static` must be `True` in `kwargs` when using the jax backend.\"\n        )\n    if kwargs[\"jax2tf_kwargs\"][\"enable_xla\"] is not False:\n        raise ValueError(\n            \"`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` \"\n            \"when using the jax backend.\"\n        )\n    if kwargs[\"jax2tf_kwargs\"][\"native_serialization\"] is not False:\n        raise ValueError(\n            \"`native_serialization` must be `False` in \"\n            \"`kwargs['jax2tf_kwargs']` when using the jax backend.\"\n        )\n    return kwargs\n\n\ndef get_concrete_fn(model, input_signature, **kwargs):\n    \"\"\"Get the `tf.function` associated with the model.\"\"\"\n    if backend.backend() == \"jax\":\n 
       kwargs = _check_jax_kwargs(kwargs)\n    export_archive = ExportArchive()\n    export_archive.track_and_add_endpoint(\n        DEFAULT_ENDPOINT_NAME, model, input_signature, **kwargs\n    )\n    if backend.backend() == \"tensorflow\":\n        export_archive._filter_and_track_resources()\n    return export_archive._get_concrete_fn(DEFAULT_ENDPOINT_NAME)\n"
  },
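The torch branch of `export_onnx` above derives a `dynamic_axes` mapping from the input specs: any `None` dimension becomes a named dynamic axis, with dimension 0 called `batch` and the others `dim_<input>_<dim>`. A standalone sketch of that naming scheme, where the `Spec` namedtuple stands in for `keras.InputSpec`:

```python
# Standalone replica of the dynamic-axes naming used in the torch branch
# of `export_onnx`: None dims are dynamic; dim 0 is named "batch", any
# other None dim is named f"dim_{input_idx}_{dim_idx}". `Spec` is a
# stand-in for keras.InputSpec.
from collections import namedtuple

Spec = namedtuple("Spec", ["name", "shape"])


def build_dynamic_axes(specs):
    dynamic_axes = {}
    for input_idx, spec in enumerate(specs):
        dynamic_dims = {}
        for dim_idx, dim_size in enumerate(spec.shape):
            if dim_size is None:
                dynamic_dims[dim_idx] = (
                    "batch" if dim_idx == 0 else f"dim_{input_idx}_{dim_idx}"
                )
        if dynamic_dims:
            # Fall back to a positional name when the spec is unnamed.
            name = spec.name or f"input_{input_idx}"
            dynamic_axes[name] = dynamic_dims
    return dynamic_axes


axes = build_dynamic_axes(
    [Spec("image", (None, None, None, 3)), Spec(None, (4, 10))]
)
print(axes)  # {'image': {0: 'batch', 1: 'dim_0_1', 2: 'dim_0_2'}}
```

The second spec has no `None` dimensions, so it contributes no entry; this matches why fully static inputs take the TorchDynamo path in the exporter.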
  {
    "path": "keras/src/export/onnx_test.py",
    "content": "\"\"\"Tests for ONNX exporting utilities.\"\"\"\n\nimport os\n\nimport numpy as np\nimport onnxruntime\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src import tree\nfrom keras.src.export import onnx\nfrom keras.src.layers.input_spec import InputSpec as InputSpec\nfrom keras.src.saving import saving_lib\nfrom keras.src.testing.test_utils import named_product\n\n\nclass CustomModel(models.Model):\n    def __init__(self, layer_list):\n        super().__init__()\n        self.layer_list = layer_list\n\n    def call(self, input):\n        output = input\n        for layer in self.layer_list:\n            output = layer(output)\n        return output\n\n\ndef get_model(type=\"sequential\", input_shape=(10,), layer_list=None):\n    layer_list = layer_list or [\n        layers.Dense(10, activation=\"relu\"),\n        layers.BatchNormalization(),\n        layers.Dense(1, activation=\"sigmoid\"),\n    ]\n    if type == \"sequential\":\n        return models.Sequential(layer_list)\n    elif type == \"functional\":\n        input = output = tree.map_shape_structure(layers.Input, input_shape)\n        for layer in layer_list:\n            output = layer(output)\n        return models.Model(inputs=input, outputs=output)\n    elif type == \"subclass\":\n        return CustomModel(layer_list)\n    elif type == \"lstm\":\n        # https://github.com/keras-team/keras/issues/21390\n        inputs = layers.Input((4, 10))\n        x = layers.Bidirectional(\n            layers.LSTM(\n                10,\n                kernel_initializer=\"he_normal\",\n                return_sequences=True,\n                kernel_regularizer=None,\n            ),\n            merge_mode=\"sum\",\n        )(inputs)\n        outputs = layers.Bidirectional(\n            layers.LSTM(\n                10,\n             
   kernel_initializer=\"he_normal\",\n                return_sequences=True,\n                kernel_regularizer=None,\n            ),\n            merge_mode=\"concat\",\n        )(x)\n        return models.Model(inputs=inputs, outputs=outputs)\n\n\n@pytest.mark.skipif(\n    backend.backend() not in (\"tensorflow\", \"jax\", \"torch\"),\n    reason=(\n        \"`export_onnx` only currently supports the tensorflow, jax and torch \"\n        \"backends.\"\n    ),\n)\n@pytest.mark.skipif(testing.uses_gpu(), reason=\"Fails on GPU\")\n@pytest.mark.skipif(\n    np.version.version.startswith(\"2.\"),\n    reason=\"ONNX export is currently incompatible with NumPy 2.0\",\n)\nclass ExportONNXTest(testing.TestCase):\n    @parameterized.named_parameters(\n        named_product(\n            model_type=[\"sequential\", \"functional\", \"subclass\", \"lstm\"]\n        )\n    )\n    def test_standard_model_export(self, model_type):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model(model_type)\n        batch_size = 3 if backend.backend() != \"torch\" else 1\n        if model_type == \"lstm\":\n            ref_input = np.random.normal(size=(batch_size, 4, 10))\n        else:\n            ref_input = np.random.normal(size=(batch_size, 10))\n        ref_input = ref_input.astype(\"float32\")\n        ref_output = model(ref_input)\n\n        onnx.export_onnx(model, temp_filepath)\n        ort_session = onnxruntime.InferenceSession(temp_filepath)\n        ort_inputs = {\n            k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input])\n        }\n        self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0])\n        # Test with a different batch size\n        ort_inputs = {\n            k.name: v\n            for k, v in zip(\n                ort_session.get_inputs(),\n                [np.concatenate([ref_input, ref_input], axis=0)],\n            )\n        }\n        ort_session.run(None, 
ort_inputs)\n\n    @parameterized.named_parameters(\n        named_product(struct_type=[\"tuple\", \"array\", \"dict\"])\n    )\n    def test_model_with_input_structure(self, struct_type):\n        if backend.backend() == \"torch\" and struct_type == \"dict\":\n            self.skipTest(\"The torch backend doesn't support the dict model.\")\n\n        class TupleModel(models.Model):\n            def call(self, inputs):\n                x, y = inputs\n                return ops.add(x, y)\n\n        class ArrayModel(models.Model):\n            def call(self, inputs):\n                x = inputs[0]\n                y = inputs[1]\n                return ops.add(x, y)\n\n        class DictModel(models.Model):\n            def call(self, inputs):\n                x = inputs[\"x\"]\n                y = inputs[\"y\"]\n                return ops.add(x, y)\n\n        batch_size = 3 if backend.backend() != \"torch\" else 1\n        ref_input = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        if struct_type == \"tuple\":\n            model = TupleModel()\n            ref_input = (ref_input, ref_input * 2)\n        elif struct_type == \"array\":\n            model = ArrayModel()\n            ref_input = [ref_input, ref_input * 2]\n        elif struct_type == \"dict\":\n            model = DictModel()\n            ref_input = {\"x\": ref_input, \"y\": ref_input * 2}\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input))\n\n        onnx.export_onnx(model, temp_filepath)\n        ort_session = onnxruntime.InferenceSession(temp_filepath)\n        if isinstance(ref_input, dict):\n            ort_inputs = {\n                k.name: v\n                for k, v in zip(ort_session.get_inputs(), ref_input.values())\n            }\n        else:\n            ort_inputs = {\n                k.name: v for k, v in zip(ort_session.get_inputs(), ref_input)\n           
 }\n        self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0])\n\n        # Test with keras.saving_lib\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"exported_model.keras\"\n        )\n        saving_lib.save_model(model, temp_filepath)\n        revived_model = saving_lib.load_model(\n            temp_filepath,\n            {\n                \"TupleModel\": TupleModel,\n                \"ArrayModel\": ArrayModel,\n                \"DictModel\": DictModel,\n            },\n        )\n        self.assertAllClose(ref_output, revived_model(ref_input))\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model2\")\n        onnx.export_onnx(revived_model, temp_filepath)\n\n        # Test with a different batch size\n        bigger_ref_input = tree.map_structure(\n            lambda x: np.concatenate([x, x], axis=0), ref_input\n        )\n        if isinstance(bigger_ref_input, dict):\n            bigger_ort_inputs = {\n                k.name: v\n                for k, v in zip(\n                    ort_session.get_inputs(), bigger_ref_input.values()\n                )\n            }\n        else:\n            bigger_ort_inputs = {\n                k.name: v\n                for k, v in zip(ort_session.get_inputs(), bigger_ref_input)\n            }\n        ort_session.run(None, bigger_ort_inputs)\n\n    def test_model_with_multiple_inputs(self):\n        class TwoInputsModel(models.Model):\n            def call(self, x, y):\n                return x + y\n\n            def build(self, y_shape, x_shape):\n                self.built = True\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = TwoInputsModel()\n        batch_size = 3 if backend.backend() != \"torch\" else 1\n        ref_input_x = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        ref_input_y = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        ref_output = 
model(ref_input_x, ref_input_y)\n\n        onnx.export_onnx(model, temp_filepath)\n        ort_session = onnxruntime.InferenceSession(temp_filepath)\n        ort_inputs = {\n            k.name: v\n            for k, v in zip(\n                ort_session.get_inputs(), [ref_input_x, ref_input_y]\n            )\n        }\n        self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0])\n        # Test with a different batch size\n        ort_inputs = {\n            k.name: v\n            for k, v in zip(\n                ort_session.get_inputs(),\n                [\n                    np.concatenate([ref_input_x, ref_input_x], axis=0),\n                    np.concatenate([ref_input_y, ref_input_y], axis=0),\n                ],\n            )\n        }\n        ort_session.run(None, ort_inputs)\n\n    @parameterized.named_parameters(named_product(opset_version=[None, 17]))\n    def test_export_with_opset_version(self, opset_version):\n        import onnx as onnx_lib\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model(\"sequential\")\n        batch_size = 3 if backend.backend() != \"torch\" else 1\n        ref_input = np.random.normal(size=(batch_size, 10))\n        ref_input = ref_input.astype(\"float32\")\n        ref_output = model(ref_input)\n\n        onnx.export_onnx(\n            model, temp_filepath, opset_version=opset_version, verbose=True\n        )\n        ort_session = onnxruntime.InferenceSession(temp_filepath)\n        ort_inputs = {\n            k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input])\n        }\n        self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0])\n\n        if opset_version is not None:\n            onnx_model = onnx_lib.load(temp_filepath)\n            self.assertEqual(onnx_model.opset_import[0].version, opset_version)\n\n    def test_export_with_input_names(self):\n        \"\"\"Test ONNX export uses InputSpec.name for input 
names.\"\"\"\n        import onnx as onnx_lib\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model(\"sequential\")\n        batch_size = 3 if backend.backend() != \"torch\" else 1\n        ref_input = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        ref_output = model(ref_input)\n\n        # Test with custom input name\n        input_spec = [\n            InputSpec(\n                name=\"custom_input\", shape=(batch_size, 10), dtype=\"float32\"\n            )\n        ]\n        onnx.export_onnx(model, temp_filepath, input_signature=input_spec)\n\n        onnx_model = onnx_lib.load(temp_filepath)\n        input_names = [input.name for input in onnx_model.graph.input]\n        self.assertIn(\"custom_input\", input_names)\n\n        ort_session = onnxruntime.InferenceSession(temp_filepath)\n        ort_inputs = {\n            k.name: v for k, v in zip(ort_session.get_inputs(), [ref_input])\n        }\n        self.assertAllClose(ref_output, ort_session.run(None, ort_inputs)[0])\n\n    @parameterized.named_parameters(\n        named_product(\n            model_type=[\"sequential\", \"functional\"],\n            dynamic_type=[\"batch_only\", \"height_width\"],\n        )\n    )\n    def test_dynamic_shapes_export(self, model_type, dynamic_type):\n        \"\"\"Test ONNX export with various dynamic shape configurations.\n\n        Tests two scenarios:\n        - batch_only: Only batch dimension is dynamic, spatial dims fixed\n        - height_width: Batch, height, width are dynamic, channels fixed\n        \"\"\"\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        # Define input shapes based on dynamic type\n        if dynamic_type == \"batch_only\":\n            input_shape = (32, 32, 3)  # Only batch is dynamic (None)\n            test_shapes = [(1, 32, 32, 3), (2, 32, 32, 3), (4, 32, 32, 3)]\n        elif dynamic_type == \"height_width\":\n            
input_shape = (None, None, 3)  # Height and width are dynamic\n            test_shapes = [(1, 28, 28, 3), (1, 64, 64, 3), (1, 128, 96, 3)]\n\n        # Create model with appropriate layers for dynamic shapes\n        layer_list = [\n            layers.Conv2D(16, 3, padding=\"same\", activation=\"relu\"),\n            layers.GlobalAveragePooling2D(),\n            layers.Dense(10, activation=\"softmax\"),\n        ]\n\n        if model_type == \"sequential\":\n            model = models.Sequential(\n                [layers.Input(shape=input_shape)] + layer_list\n            )\n        elif model_type == \"functional\":\n            input_layer = layers.Input(shape=input_shape)\n            output = input_layer\n            for layer in layer_list:\n                output = layer(output)\n            model = models.Model(inputs=input_layer, outputs=output)\n\n        # Build model with initial input\n        initial_input = np.random.normal(size=test_shapes[0]).astype(np.float32)\n        model(initial_input)\n\n        # Export to ONNX\n        onnx.export_onnx(model, temp_filepath)\n\n        # Verify with ONNX Runtime\n        ort_session = onnxruntime.InferenceSession(temp_filepath)\n        input_info = ort_session.get_inputs()[0]\n\n        # Check that dynamic dimensions are preserved\n        input_shape_onnx = input_info.shape\n        if dynamic_type == \"batch_only\":\n            # Batch should be dynamic, others static\n            self.assertTrue(isinstance(input_shape_onnx[0], str))  # Dynamic\n            self.assertEqual(input_shape_onnx[1:], [32, 32, 3])  # Static\n        elif dynamic_type == \"height_width\":\n            # Batch, height, width should be dynamic, channels static\n            self.assertTrue(isinstance(input_shape_onnx[0], str))  # Dynamic\n            self.assertTrue(isinstance(input_shape_onnx[1], str))  # Dynamic\n            self.assertTrue(isinstance(input_shape_onnx[2], str))  # Dynamic\n            
self.assertEqual(input_shape_onnx[3], 3)  # Static\n\n        # Test inference with different input shapes\n        for test_shape in test_shapes:\n            test_input = np.random.randn(*test_shape).astype(np.float32)\n            ort_inputs = {input_info.name: test_input}\n            result = ort_session.run(None, ort_inputs)\n\n            # Verify output shape matches expected batch size\n            expected_batch_size = test_shape[0]\n            self.assertEqual(result[0].shape[0], expected_batch_size)\n            self.assertEqual(result[0].shape[1], 10)  # Number of classes\n\n    def test_multi_input_dynamic_shapes(self):\n        \"\"\"Test ONNX export with multi-input model having dynamic shapes.\"\"\"\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        # Create multi-input model with dynamic shapes\n        text_input = layers.Input(\n            shape=(None, 64), name=\"text_input\"\n        )  # Variable sequence length\n        image_input = layers.Input(\n            shape=(None, None, 3), name=\"image_input\"\n        )  # Variable image size\n\n        # Process text input\n        text_features = layers.Dense(128, activation=\"relu\")(text_input)\n        text_pooled = layers.GlobalAveragePooling1D()(text_features)\n\n        # Process image input\n        image_features = layers.Conv2D(32, 1, activation=\"relu\")(\n            image_input\n        )  # Use 1x1 conv to avoid size issues\n        image_pooled = layers.GlobalAveragePooling2D()(image_features)\n\n        # Combine features\n        combined = layers.Concatenate()([text_pooled, image_pooled])\n        output = layers.Dense(5, activation=\"softmax\")(combined)\n\n        model = models.Model(inputs=[text_input, image_input], outputs=output)\n\n        # Build model\n        sample_text = np.random.normal(size=(1, 20, 64)).astype(np.float32)\n        sample_image = np.random.normal(size=(1, 32, 32, 3)).astype(np.float32)\n        
model([sample_text, sample_image])\n\n        # Export to ONNX\n        onnx.export_onnx(model, temp_filepath)\n\n        # Verify with ONNX Runtime\n        ort_session = onnxruntime.InferenceSession(temp_filepath)\n        inputs_info = ort_session.get_inputs()\n\n        # Check that both inputs have dynamic dimensions\n        text_shape = inputs_info[0].shape\n        image_shape = inputs_info[1].shape\n\n        # Text input: [batch, seq_len, features] - batch and seq_len dynamic\n        self.assertTrue(isinstance(text_shape[0], str))  # Dynamic\n        self.assertTrue(isinstance(text_shape[1], str))  # Dynamic\n        self.assertEqual(text_shape[2], 64)  # Static\n\n        # Image input: [batch, height, width, channels] - batch, h, w dynamic\n        self.assertTrue(isinstance(image_shape[0], str))  # Dynamic\n        self.assertTrue(isinstance(image_shape[1], str))  # Dynamic\n        self.assertTrue(isinstance(image_shape[2], str))  # Dynamic\n        self.assertEqual(image_shape[3], 3)  # Static\n\n        # Test inference with different input shapes\n        test_cases = [\n            ((1, 10, 64), (1, 28, 28, 3)),\n            ((2, 15, 64), (2, 64, 64, 3)),\n            ((1, 25, 64), (1, 48, 32, 3)),\n        ]\n\n        for text_shape, image_shape in test_cases:\n            text_input_data = np.random.randn(*text_shape).astype(np.float32)\n            image_input_data = np.random.randn(*image_shape).astype(np.float32)\n\n            ort_inputs = {\n                inputs_info[0].name: text_input_data,\n                inputs_info[1].name: image_input_data,\n            }\n            result = ort_session.run(None, ort_inputs)\n\n            # Verify output shape matches expected batch size\n            expected_batch_size = text_shape[0]\n            self.assertEqual(result[0].shape[0], expected_batch_size)\n            self.assertEqual(result[0].shape[1], 5)  # Number of classes\n"
  },
  {
    "path": "keras/src/export/openvino.py",
    "content": "import warnings\n\nfrom keras.src import backend\nfrom keras.src import tree\nfrom keras.src.export.export_utils import convert_spec_to_tensor\nfrom keras.src.export.export_utils import get_input_signature\nfrom keras.src.export.export_utils import make_tf_tensor_spec\nfrom keras.src.export.saved_model import DEFAULT_ENDPOINT_NAME\nfrom keras.src.export.saved_model import ExportArchive\nfrom keras.src.utils import io_utils\n\n\ndef export_openvino(\n    model, filepath, verbose=None, input_signature=None, **kwargs\n):\n    \"\"\"Export the model as an OpenVINO IR artifact for inference.\n\n    This method exports the model to the OpenVINO IR format,\n    which includes two files:\n    a `.xml` file containing the model structure and a `.bin` file\n    containing the weights.\n    The exported model contains only the forward pass\n    (i.e., the model's `call()` method), and can be deployed with the\n    OpenVINO Runtime for fast inference on CPU and other Intel hardware.\n\n    Args:\n        filepath: `str` or `pathlib.Path`. Path to the output `.xml` file.\n        The corresponding `.bin` file will be saved alongside it.\n        verbose: Optional `bool`. Whether to print a confirmation message\n        after export. If `None`, it uses the default verbosity configured\n        by the backend.\n        input_signature: Optional. Specifies the shape and dtype of the\n        model inputs. 
If not provided, it will be inferred.\n        **kwargs: Additional keyword arguments.\n\n    Example:\n\n    ```python\n    import keras\n\n    # Define or load a Keras model\n    model = keras.models.Sequential([\n        keras.layers.Input(shape=(128,)),\n        keras.layers.Dense(64, activation=\"relu\"),\n        keras.layers.Dense(10)\n    ])\n\n    # Export to OpenVINO IR\n    model.export(\"model.xml\", format=\"openvino\")\n    ```\n    \"\"\"\n    if not str(filepath).endswith(\".xml\"):\n        raise ValueError(\n            \"The OpenVINO export requires the filepath to end with '.xml'. \"\n            f\"Got: filepath={filepath}\"\n        )\n\n    import openvino as ov\n    import openvino.opset15 as ov_opset\n\n    from keras.src.backend.openvino.core import OPENVINO_DTYPES\n    from keras.src.backend.openvino.core import OpenVINOKerasTensor\n\n    actual_verbose = verbose if verbose is not None else True\n\n    if input_signature is None:\n        input_signature = get_input_signature(model)\n\n    if backend.backend() == \"openvino\":\n        import inspect\n\n        def parameterize_inputs(inputs, prefix=\"\"):\n            if isinstance(inputs, (list, tuple)):\n                return [\n                    parameterize_inputs(e, f\"{prefix}{i}\")\n                    for i, e in enumerate(inputs)\n                ]\n            elif isinstance(inputs, dict):\n                return {k: parameterize_inputs(v, k) for k, v in inputs.items()}\n            elif isinstance(inputs, OpenVINOKerasTensor):\n                ov_type = OPENVINO_DTYPES[str(inputs.dtype)]\n                ov_shape = list(inputs.shape)\n                param = ov_opset.parameter(shape=ov_shape, dtype=ov_type)\n                param.set_friendly_name(prefix)\n                return OpenVINOKerasTensor(param.output(0))\n            else:\n                raise TypeError(f\"Unknown input type: {type(inputs)}\")\n\n        if isinstance(input_signature, list) and 
len(input_signature) == 1:\n            input_signature = input_signature[0]\n\n        sample_inputs = tree.map_structure(\n            lambda x: convert_spec_to_tensor(x, replace_none_number=1),\n            input_signature,\n        )\n        params = parameterize_inputs(sample_inputs)\n        signature = inspect.signature(model.call)\n        if len(signature.parameters) > 1 and isinstance(params, (list, tuple)):\n            outputs = model(*params)\n        else:\n            outputs = model(params)\n        parameters = [p.output.get_node() for p in tree.flatten(params)]\n        results = [ov_opset.result(r.output) for r in tree.flatten(outputs)]\n        ov_model = ov.Model(results=results, parameters=parameters)\n        flat_specs = tree.flatten(input_signature)\n        for ov_input, spec in zip(ov_model.inputs, flat_specs):\n            # Respect the dynamic axes from the original input signature.\n            dynamic_shape_dims = [\n                -1 if dim is None else dim for dim in spec.shape\n            ]\n            dynamic_shape = ov.PartialShape(dynamic_shape_dims)\n            ov_input.get_node().set_partial_shape(dynamic_shape)\n\n    elif backend.backend() in (\"tensorflow\", \"jax\"):\n        inputs = tree.map_structure(make_tf_tensor_spec, input_signature)\n        decorated_fn = get_concrete_fn(model, inputs, **kwargs)\n        ov_model = ov.convert_model(decorated_fn)\n        set_names(ov_model, inputs)\n    elif backend.backend() == \"torch\":\n        import torch\n\n        sample_inputs = tree.map_structure(\n            lambda x: convert_spec_to_tensor(x, replace_none_number=1),\n            input_signature,\n        )\n        sample_inputs = tuple(sample_inputs)\n        if hasattr(model, \"eval\"):\n            model.eval()\n        with warnings.catch_warnings():\n            warnings.filterwarnings(\"ignore\", category=torch.jit.TracerWarning)\n            traced = torch.jit.trace(model, sample_inputs)\n            
ov_model = ov.convert_model(traced)\n            set_names(ov_model, sample_inputs)\n    else:\n        raise NotImplementedError(\n            \"`export_openvino` is only compatible with OpenVINO, \"\n            \"TensorFlow, JAX and Torch backends.\"\n        )\n\n    ov.serialize(ov_model, filepath)\n\n    if actual_verbose:\n        io_utils.print_msg(f\"Saved OpenVINO IR at '{filepath}'.\")\n\n\ndef collect_names(structure):\n    if isinstance(structure, dict):\n        for k, v in structure.items():\n            if isinstance(v, (dict, list, tuple)):\n                yield from collect_names(v)\n            else:\n                yield k\n    elif isinstance(structure, (list, tuple)):\n        for v in structure:\n            yield from collect_names(v)\n    else:\n        if hasattr(structure, \"name\") and structure.name:\n            yield structure.name\n        else:\n            yield \"input\"\n\n\ndef set_names(model, inputs):\n    names = list(collect_names(inputs))\n    for ov_input, name in zip(model.inputs, names):\n        ov_input.get_node().set_friendly_name(name)\n        ov_input.tensor.set_names({name})\n\n\ndef _check_jax_kwargs(kwargs):\n    kwargs = kwargs.copy()\n    if \"is_static\" not in kwargs:\n        kwargs[\"is_static\"] = True\n    if \"jax2tf_kwargs\" not in kwargs:\n        kwargs[\"jax2tf_kwargs\"] = {\n            \"enable_xla\": False,\n            \"native_serialization\": False,\n        }\n    if kwargs[\"is_static\"] is not True:\n        raise ValueError(\n            \"`is_static` must be `True` in `kwargs` when using the jax backend.\"\n        )\n    if kwargs[\"jax2tf_kwargs\"][\"enable_xla\"] is not False:\n        raise ValueError(\n            \"`enable_xla` must be `False` in `kwargs['jax2tf_kwargs']` \"\n            \"when using the jax backend.\"\n        )\n    if kwargs[\"jax2tf_kwargs\"][\"native_serialization\"] is not False:\n        raise ValueError(\n            \"`native_serialization` must be 
`False` in \"\n            \"`kwargs['jax2tf_kwargs']` when using the jax backend.\"\n        )\n    return kwargs\n\n\ndef get_concrete_fn(model, input_signature, **kwargs):\n    if backend.backend() == \"jax\":\n        kwargs = _check_jax_kwargs(kwargs)\n    export_archive = ExportArchive()\n    export_archive.track_and_add_endpoint(\n        DEFAULT_ENDPOINT_NAME, model, input_signature, **kwargs\n    )\n    if backend.backend() == \"tensorflow\":\n        export_archive._filter_and_track_resources()\n    return export_archive._get_concrete_fn(DEFAULT_ENDPOINT_NAME)\n"
  },
  {
    "path": "keras/src/export/openvino_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src import tree\nfrom keras.src.export import openvino\nfrom keras.src.saving import saving_lib\nfrom keras.src.testing.test_utils import named_product\n\ntry:\n    import openvino as ov\nexcept ImportError:\n    ov = None\n\n\nclass CustomModel(models.Model):\n    def __init__(self, layer_list):\n        super().__init__()\n        self.layer_list = layer_list\n\n    def call(self, input):\n        output = input\n        for layer in self.layer_list:\n            output = layer(output)\n        return output\n\n\ndef get_model(type=\"sequential\", input_shape=(10,), layer_list=None):\n    layer_list = layer_list or [\n        layers.Dense(10, activation=\"relu\"),\n        layers.BatchNormalization(),\n        layers.Dense(1, activation=\"sigmoid\"),\n    ]\n    if type == \"sequential\":\n        return models.Sequential(layer_list)\n    elif type == \"functional\":\n        input = output = tree.map_shape_structure(layers.Input, input_shape)\n        for layer in layer_list:\n            output = layer(output)\n        return models.Model(inputs=input, outputs=output)\n    elif type == \"subclass\":\n        return CustomModel(layer_list)\n    elif type == \"lstm\":\n        # https://github.com/keras-team/keras/issues/21390\n        inputs = layers.Input((4, 10))\n        x = layers.Bidirectional(\n            layers.LSTM(\n                10,\n                kernel_initializer=\"he_normal\",\n                return_sequences=True,\n                kernel_regularizer=None,\n            ),\n            merge_mode=\"sum\",\n        )(inputs)\n        outputs = layers.Bidirectional(\n            layers.LSTM(\n                10,\n                kernel_initializer=\"he_normal\",\n                
return_sequences=True,\n                kernel_regularizer=None,\n            ),\n            merge_mode=\"concat\",\n        )(x)\n        return models.Model(inputs=inputs, outputs=outputs)\n\n\n@pytest.mark.skipif(ov is None, reason=\"OpenVINO is not installed\")\n@pytest.mark.skipif(\n    backend.backend() not in (\"tensorflow\", \"openvino\", \"jax\", \"torch\"),\n    reason=(\n        \"`export_openvino` currently only supports \"\n        \"the tensorflow, jax, torch and openvino backends.\"\n    ),\n)\n@pytest.mark.skipif(testing.jax_uses_gpu(), reason=\"Leads to core dumps on CI\")\n@pytest.mark.skipif(\n    testing.tensorflow_uses_gpu(), reason=\"Leads to core dumps on CI\"\n)\nclass ExportOpenVINOTest(testing.TestCase):\n    @parameterized.named_parameters(\n        named_product(\n            model_type=[\"sequential\", \"functional\", \"subclass\", \"lstm\"]\n        )\n    )\n    def test_standard_model_export(self, model_type):\n        if model_type == \"lstm\":\n            self.skipTest(\n                \"LSTM export not supported - unimplemented QR operation\"\n            )\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model.xml\")\n        model = get_model(model_type)\n        batch_size = 3\n        if model_type == \"lstm\":\n            ref_input = np.random.normal(size=(batch_size, 4, 10))\n        else:\n            ref_input = np.random.normal(size=(batch_size, 10))\n        ref_input = ref_input.astype(\"float32\")\n        ref_output = model(ref_input)\n\n        try:\n            openvino.export_openvino(model, temp_filepath)\n        except Exception as e:\n            if \"XlaCallModule\" in str(e):\n                self.skipTest(\"OpenVINO does not support XlaCallModule yet\")\n            raise e\n\n        # Load and run inference with OpenVINO\n        core = ov.Core()\n        ov_model = core.read_model(temp_filepath)\n        compiled_model = core.compile_model(ov_model, \"CPU\")\n\n        ov_output 
= compiled_model([ref_input])[compiled_model.output(0)]\n\n        self.assertAllClose(ref_output, ov_output)\n\n        larger_input = np.concatenate([ref_input, ref_input], axis=0)\n        compiled_model([larger_input])\n\n    @parameterized.named_parameters(\n        named_product(struct_type=[\"tuple\", \"array\", \"dict\"])\n    )\n    def test_model_with_input_structure(self, struct_type):\n        class TupleModel(models.Model):\n            def call(self, inputs):\n                x, y = inputs\n                return ops.add(x, y)\n\n        class ArrayModel(models.Model):\n            def call(self, inputs):\n                x = inputs[0]\n                y = inputs[1]\n                return ops.add(x, y)\n\n        class DictModel(models.Model):\n            def call(self, inputs):\n                x = inputs[\"x\"]\n                y = inputs[\"y\"]\n                return ops.add(x, y)\n\n        batch_size = 3\n        ref_input = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        if struct_type == \"tuple\":\n            model = TupleModel()\n            ref_input = (ref_input, ref_input * 2)\n        elif struct_type == \"array\":\n            model = ArrayModel()\n            ref_input = [ref_input, ref_input * 2]\n        elif struct_type == \"dict\":\n            model = DictModel()\n            ref_input = {\"x\": ref_input, \"y\": ref_input * 2}\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model.xml\")\n        ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input))\n\n        try:\n            openvino.export_openvino(model, temp_filepath)\n        except Exception as e:\n            if \"XlaCallModule\" in str(e):\n                self.skipTest(\"OpenVINO does not support XlaCallModule yet\")\n            raise e\n\n        # Load and run inference with OpenVINO\n        core = ov.Core()\n        ov_model = core.read_model(temp_filepath)\n        compiled_model = 
core.compile_model(ov_model, \"CPU\")\n\n        if isinstance(ref_input, dict):\n            ov_inputs = [ref_input[key] for key in ref_input.keys()]\n        else:\n            ov_inputs = list(ref_input)\n\n        ov_output = compiled_model(ov_inputs)[compiled_model.output(0)]\n        self.assertAllClose(ref_output, ov_output)\n\n        # Test with keras.saving_lib\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"exported_model.keras\"\n        )\n        saving_lib.save_model(model, temp_filepath)\n        revived_model = saving_lib.load_model(\n            temp_filepath,\n            {\n                \"TupleModel\": TupleModel,\n                \"ArrayModel\": ArrayModel,\n                \"DictModel\": DictModel,\n            },\n        )\n        self.assertAllClose(ref_output, revived_model(ref_input))\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model2.xml\")\n        try:\n            openvino.export_openvino(revived_model, temp_filepath)\n        except Exception as e:\n            if \"XlaCallModule\" in str(e):\n                self.skipTest(\"OpenVINO does not support XlaCallModule yet\")\n            raise e\n\n        bigger_ref_input = tree.map_structure(\n            lambda x: np.concatenate([x, x], axis=0), ref_input\n        )\n        if isinstance(bigger_ref_input, dict):\n            bigger_ov_inputs = [\n                bigger_ref_input[key] for key in bigger_ref_input.keys()\n            ]\n        else:\n            bigger_ov_inputs = list(bigger_ref_input)\n        compiled_model(bigger_ov_inputs)\n\n    def test_model_with_multiple_inputs(self):\n        class TwoInputsModel(models.Model):\n            def call(self, x, y):\n                return x + y\n\n            def build(self, y_shape, x_shape):\n                self.built = True\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model.xml\")\n        model = TwoInputsModel()\n        batch_size = 3\n    
    ref_input_x = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        ref_input_y = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        ref_output = model(ref_input_x, ref_input_y)\n\n        try:\n            openvino.export_openvino(model, temp_filepath)\n        except Exception as e:\n            if \"XlaCallModule\" in str(e):\n                self.skipTest(\"OpenVINO does not support XlaCallModule yet\")\n            raise e\n\n        # Load and run inference with OpenVINO\n        core = ov.Core()\n        ov_model = core.read_model(temp_filepath)\n        compiled_model = core.compile_model(ov_model, \"CPU\")\n\n        ov_output = compiled_model([ref_input_x, ref_input_y])[\n            compiled_model.output(0)\n        ]\n        self.assertAllClose(ref_output, ov_output)\n        larger_input_x = np.concatenate([ref_input_x, ref_input_x], axis=0)\n        larger_input_y = np.concatenate([ref_input_y, ref_input_y], axis=0)\n        compiled_model([larger_input_x, larger_input_y])\n"
  },
  {
    "path": "keras/src/export/saved_model.py",
    "content": "\"\"\"Library for exporting SavedModel for Keras models/layers.\"\"\"\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.export.export_utils import get_input_signature\nfrom keras.src.export.neptune_model_export_archive import (\n    NeptuneModelExportArchive,\n)\n\n# Re-export for backward compatibility (used by tfsm_layer.py)\nfrom keras.src.export.saved_model_export_archive import (  # noqa: F401\n    _list_variables_used_by_fns,\n)\n\nif backend.backend() == \"tensorflow\":\n    from keras.src.backend.tensorflow.export import (\n        TFExportArchive as BackendSavedModelExportArchive,\n    )\nelif backend.backend() == \"jax\":\n    from keras.src.backend.jax.export import (\n        JaxExportArchive as BackendSavedModelExportArchive,\n    )\nelif backend.backend() == \"torch\":\n    from keras.src.backend.torch.export import (\n        TorchExportArchive as BackendSavedModelExportArchive,\n    )\nelif backend.backend() == \"numpy\":\n    from keras.src.backend.numpy.export import (\n        NumpyExportArchive as BackendSavedModelExportArchive,\n    )\nelif backend.backend() == \"openvino\":\n    from keras.src.backend.openvino.export import (\n        OpenvinoExportArchive as BackendSavedModelExportArchive,\n    )\nelse:\n    raise RuntimeError(\n        f\"Backend '{backend.backend()}' must implement ExportArchive.\"\n    )\n\nDEFAULT_ENDPOINT_NAME = \"serve\"\n\n\ndef export_saved_model(\n    model, filepath, verbose=None, input_signature=None, **kwargs\n):\n    \"\"\"Export the model as a TensorFlow SavedModel artifact for inference.\n\n    This method lets you export a model to a lightweight SavedModel artifact\n    that contains the model's forward pass only (its `call()` method)\n    and can be served via e.g. TensorFlow Serving. 
The forward pass is\n    registered under the name `serve()` (see example below).\n\n    The original code of the model (including any custom layers you may\n    have used) is *no longer* necessary to reload the artifact -- it is\n    entirely standalone.\n\n    Args:\n        filepath: `str` or `pathlib.Path` object. The path to save the artifact.\n        verbose: `bool`. Whether to print a message during export. Defaults to\n            `None`, which uses the default value set by different backends and\n            formats.\n        input_signature: Optional. Specifies the shape and dtype of the model\n            inputs. Can be a structure of `keras.InputSpec`, `tf.TensorSpec`,\n            `backend.KerasTensor`, or backend tensor. If not provided, it will\n            be automatically computed. Defaults to `None`.\n        **kwargs: Additional keyword arguments:\n            - Specific to the JAX backend:\n                - `is_static`: Optional `bool`. Indicates whether `fn` is\n                    static. Set to `False` if `fn` involves state updates\n                    (e.g., RNG seeds).\n                - `jax2tf_kwargs`: Optional `dict`. Arguments for\n                    `jax2tf.convert`. See [`jax2tf.convert`](\n                        https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).\n                    If `native_serialization` and `polymorphic_shapes` are not\n                    provided, they are automatically computed.\n\n    **Note:** This feature is currently supported only with TensorFlow, JAX and\n    Torch backends. Support for the Torch backend is experimental.\n\n    **Note:** The dynamic shape feature is not yet supported with Torch\n    backend. As a result, you must fully define the shapes of the inputs using\n    `input_signature`. 
If `input_signature` is not provided, all instances of\n    `None` (such as the batch size) will be replaced with `1`.\n\n    Example:\n\n    ```python\n    # Export the model as a TensorFlow SavedModel artifact\n    model.export(\"path/to/location\", format=\"tf_saved_model\")\n\n    # Load the artifact in a different process/environment\n    reloaded_artifact = tf.saved_model.load(\"path/to/location\")\n    predictions = reloaded_artifact.serve(input_data)\n    ```\n\n    If you would like to customize your serving endpoints, you can\n    use the lower-level `keras.export.ExportArchive` class. The\n    `export()` method relies on `ExportArchive` internally.\n    \"\"\"\n    if verbose is None:\n        verbose = True  # Defaults to `True` for all backends.\n    export_archive = ExportArchive()\n    if input_signature is None:\n        input_signature = get_input_signature(model)\n\n    export_archive.track_and_add_endpoint(\n        DEFAULT_ENDPOINT_NAME, model, input_signature, **kwargs\n    )\n    export_archive.write_out(filepath, verbose=verbose)\n\n\n@keras_export(\"keras.export.ExportArchive\")\nclass ExportArchive:\n    \"\"\"ExportArchive is used to write SavedModel artifacts for inference.\n\n    If you have a Keras model or layer that you want to export as SavedModel for\n    serving (e.g. via TensorFlow-Serving), you can use `ExportArchive`\n    to configure the different serving endpoints you need to make available,\n    as well as their signatures. 
Simply instantiate an `ExportArchive`,\n    use `track()` to register the layer(s) or model(s) to be used,\n    then use the `add_endpoint()` method to register a new serving endpoint.\n    When done, use the `write_out()` method to save the artifact.\n\n    The resulting artifact is a SavedModel and can be reloaded via\n    `tf.saved_model.load`.\n\n    Examples:\n\n    Here's how to export a model for inference.\n\n    ```python\n    export_archive = ExportArchive()\n    export_archive.track(model)\n    export_archive.add_endpoint(\n        name=\"serve\",\n        fn=model.call,\n        input_signature=[keras.InputSpec(shape=(None, 3), dtype=\"float32\")],\n    )\n    export_archive.write_out(\"path/to/location\")\n\n    # Elsewhere, we can reload the artifact and serve it.\n    # The endpoint we added is available as a method:\n    serving_model = tf.saved_model.load(\"path/to/location\")\n    outputs = serving_model.serve(inputs)\n    ```\n\n    Here's how to export a model with one endpoint for inference and one\n    endpoint for a training-mode forward pass (e.g. with dropout on).\n\n    ```python\n    export_archive = ExportArchive()\n    export_archive.track(model)\n    export_archive.add_endpoint(\n        name=\"call_inference\",\n        fn=lambda x: model.call(x, training=False),\n        input_signature=[keras.InputSpec(shape=(None, 3), dtype=\"float32\")],\n    )\n    export_archive.add_endpoint(\n        name=\"call_training\",\n        fn=lambda x: model.call(x, training=True),\n        input_signature=[keras.InputSpec(shape=(None, 3), dtype=\"float32\")],\n    )\n    export_archive.write_out(\"path/to/location\")\n    ```\n\n    **Note on resource tracking:**\n\n    `ExportArchive` is able to automatically track all `keras.Variables` used\n    by its endpoints, so most of the time calling `.track(model)`\n    is not strictly required. 
However, if your model uses lookup layers such\n    as `IntegerLookup`, `StringLookup`, or `TextVectorization`,\n    it will need to be tracked explicitly via `.track(model)`.\n\n    Explicit tracking is also required if you need to be able to access\n    the properties `variables`, `trainable_variables`, or\n    `non_trainable_variables` on the revived archive.\n    \"\"\"\n\n    def __new__(cls, format=\"saved_model\", **kwargs):\n        if format == \"saved_model\":\n            return BackendSavedModelExportArchive()\n        elif format == \"neptune_model\":\n            return NeptuneModelExportArchive()\n        else:\n            raise ValueError(f\"Unsupported format: {format}\")\n\n    def track(self, resource):\n        \"\"\"Track the variables (of a layer or model) and other assets.\n\n        By default, all variables used by an endpoint function are automatically\n        tracked when you call `add_endpoint()`. However, non-variables assets\n        such as lookup tables need to be tracked manually. Note that lookup\n        tables used by built-in Keras layers (`TextVectorization`,\n        `IntegerLookup`, `StringLookup`) are automatically tracked by\n        `add_endpoint()`.\n\n        Args:\n            resource: A layer, model or a TensorFlow trackable resource.\n        \"\"\"\n        raise NotImplementedError(\n            \"track() is not implemented for this backend.\"\n        )\n\n    def add_endpoint(self, name, fn, input_signature=None, **kwargs):\n        \"\"\"Register a new serving endpoint.\n\n        Args:\n            name: `str`. The name of the endpoint.\n            fn: A callable. It should only leverage resources\n                (e.g. 
`keras.Variable` objects or `tf.lookup.StaticHashTable`\n                objects) that are available on the models/layers tracked by the\n                `ExportArchive` (you can call `.track(model)` to track a new\n                model).\n                The shape and dtype of the inputs to the function must be\n                known. For that purpose, you can either 1) make sure that `fn`\n                is a `tf.function` that has been called at least once, or 2)\n                provide an `input_signature` argument that specifies the shape\n                and dtype of the inputs (see below).\n            input_signature: Optional. Specifies the shape and dtype of `fn`.\n                Can be a structure of `keras.InputSpec`, `tf.TensorSpec`,\n                `backend.KerasTensor`, or backend tensor (see below for an\n                example showing a `Functional` model with 2 input arguments). If\n                not provided, `fn` must be a `tf.function` that has been called\n                at least once. Defaults to `None`.\n            **kwargs: Additional keyword arguments:\n                - Specific to the JAX backend:\n                    - `is_static`: Optional `bool`. Indicates whether `fn` is\n                        static. Set to `False` if `fn` involves state updates\n                        (e.g., RNG seeds).\n                    - `jax2tf_kwargs`: Optional `dict`. Arguments for\n                        `jax2tf.convert`. 
See [`jax2tf.convert`](\n                            https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).\n                        If `native_serialization` and `polymorphic_shapes` are\n                        not provided, they are automatically computed.\n\n        Returns:\n            The `tf.function` wrapping `fn` that was added to the archive.\n\n        Example:\n\n        Adding an endpoint using the `input_signature` argument when the\n        model has a single input argument:\n\n        ```python\n        export_archive = ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\n            name=\"serve\",\n            fn=model.call,\n            input_signature=[keras.InputSpec(shape=(None, 3), dtype=\"float32\")],\n        )\n        ```\n\n        Adding an endpoint using the `input_signature` argument when the\n        model has two positional input arguments:\n\n        ```python\n        export_archive = ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\n            name=\"serve\",\n            fn=model.call,\n            input_signature=[\n                keras.InputSpec(shape=(None, 3), dtype=\"float32\"),\n                keras.InputSpec(shape=(None, 4), dtype=\"float32\"),\n            ],\n        )\n        ```\n\n        Adding an endpoint using the `input_signature` argument when the\n        model has one input argument that is a list of 2 tensors (e.g.\n        a Functional model with 2 inputs):\n\n        ```python\n        model = keras.Model(inputs=[x1, x2], outputs=outputs)\n\n        export_archive = ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\n            name=\"serve\",\n            fn=model.call,\n            input_signature=[\n                [\n                    keras.InputSpec(shape=(None, 3), dtype=\"float32\"),\n                    keras.InputSpec(shape=(None, 4), 
dtype=\"float32\"),\n                ],\n            ],\n        )\n        ```\n\n        This also works with dictionary inputs:\n\n        ```python\n        model = keras.Model(inputs={\"x1\": x1, \"x2\": x2}, outputs=outputs)\n\n        export_archive = ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\n            name=\"serve\",\n            fn=model.call,\n            input_signature=[\n                {\n                    \"x1\": keras.InputSpec(shape=(None, 3), dtype=\"float32\"),\n                    \"x2\": keras.InputSpec(shape=(None, 4), dtype=\"float32\"),\n                },\n            ],\n        )\n        ```\n\n        Adding an endpoint that is a `tf.function`:\n\n        ```python\n        @tf.function()\n        def serving_fn(x):\n            return model(x)\n\n        # The function must be traced, i.e. it must be called at least once.\n        serving_fn(tf.random.normal(shape=(2, 3)))\n\n        export_archive = ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(name=\"serve\", fn=serving_fn)\n        ```\n\n        Combining a model with some TensorFlow preprocessing, which can use\n        TensorFlow resources:\n\n        ```python\n        lookup_table = tf.lookup.StaticHashTable(initializer, default_value=0.0)\n\n        export_archive = ExportArchive()\n        model_fn = export_archive.track_and_add_endpoint(\n            \"model_fn\",\n            model,\n            input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)],\n        )\n        export_archive.track(lookup_table)\n\n        @tf.function()\n        def serving_fn(x):\n            x = lookup_table.lookup(x)\n            return model_fn(x)\n\n        export_archive.add_endpoint(name=\"serve\", fn=serving_fn)\n        ```\n        \"\"\"\n        raise NotImplementedError(\n            \"add_endpoint() is not implemented for this backend.\"\n        )\n\n    def 
track_and_add_endpoint(self, name, resource, input_signature, **kwargs):\n        \"\"\"Track the variables and register a new serving endpoint.\n\n        This function combines the functionality of `track` and `add_endpoint`.\n        It tracks the variables of the `resource` (either a layer or a model)\n        and registers a serving endpoint using `resource.__call__`.\n\n        Args:\n            name: `str`. The name of the endpoint.\n            resource: A trackable Keras resource, such as a layer or model.\n            input_signature: Specifies the shape and dtype of the inputs to\n                `resource.__call__`. Can be a structure of `keras.InputSpec`,\n                `tf.TensorSpec`, `backend.KerasTensor`, or backend tensor.\n            **kwargs: Additional keyword arguments:\n                - Specific to the JAX backend:\n                    - `is_static`: Optional `bool`. Indicates whether the\n                        endpoint function is static. Set to `False` if it\n                        involves state updates (e.g., RNG seeds).\n                    - `jax2tf_kwargs`: Optional `dict`. Arguments for\n                        `jax2tf.convert`. 
See [`jax2tf.convert`](\n                            https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).\n                        If `native_serialization` and `polymorphic_shapes` are\n                        not provided, they are automatically computed.\n\n        \"\"\"\n        raise NotImplementedError(\n            \"track_and_add_endpoint() is not implemented for this backend.\"\n        )\n\n    def add_variable_collection(self, name, variables):\n        \"\"\"Register a set of variables to be retrieved after reloading.\n\n        Arguments:\n            name: The string name for the collection.\n            variables: A tuple/list/set of `keras.Variable` instances.\n\n        Example:\n\n        ```python\n        export_archive = ExportArchive()\n        export_archive.track(model)\n        # Register an endpoint\n        export_archive.add_endpoint(\n            name=\"serve\",\n            fn=model.call,\n            input_signature=[keras.InputSpec(shape=(None, 3), dtype=\"float32\")],\n        )\n        # Save a variable collection\n        export_archive.add_variable_collection(\n            name=\"optimizer_variables\", variables=model.optimizer.variables)\n        export_archive.write_out(\"path/to/location\")\n\n        # Reload the object\n        revived_object = tf.saved_model.load(\"path/to/location\")\n        # Retrieve the variables\n        optimizer_variables = revived_object.optimizer_variables\n        ```\n        \"\"\"\n        raise NotImplementedError(\n            \"add_variable_collection() is not implemented for this backend.\"\n        )\n\n    def write_out(self, filepath, options=None, verbose=True):\n        \"\"\"Write the corresponding SavedModel to disk.\n\n        Arguments:\n            filepath: `str` or `pathlib.Path` object.\n                Path where to save the artifact.\n            options: `tf.saved_model.SaveOptions` object that specifies\n                SavedModel saving options.\n 
           verbose: whether to print all the variables of an\n                exported SavedModel.\n\n        **Note on TF-Serving**: all endpoints registered via `add_endpoint()`\n        are made visible for TF-Serving in the SavedModel artifact. In addition,\n        the first endpoint registered is made visible under the alias\n        `\"serving_default\"` (unless an endpoint with the name\n        `\"serving_default\"` was already registered manually),\n        since TF-Serving requires this endpoint to be set.\n        \"\"\"\n        raise NotImplementedError(\n            \"write_out() is not implemented for this backend.\"\n        )\n"
  },
  {
    "path": "keras/src/export/saved_model_export_archive.py",
    "content": "\"\"\"Base class for SavedModel export archive.\"\"\"\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import tree\nfrom keras.src.export.export_utils import make_tf_tensor_spec\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\nclass SavedModelExportArchive:\n    \"\"\"Base class for SavedModel export archive.\n\n    This class contains all the common SavedModel export logic that is shared\n    across different backends (TensorFlow, JAX, Torch). Backend-specific\n    implementations should extend this class and override the following methods:\n    - `_backend_track_layer(layer)`: Track variables of a layer.\n    - `_backend_add_endpoint(name, fn, input_signature, **kwargs)`: Backend-\n        specific endpoint creation logic.\n    - `_backend_init()`: Backend-specific initialization (optional).\n    \"\"\"\n\n    def __init__(self):\n        if backend.backend() not in (\"tensorflow\", \"jax\", \"torch\"):\n            raise NotImplementedError(\n                \"`ExportArchive` is only compatible with TensorFlow, JAX and \"\n                \"Torch backends.\"\n            )\n\n        self._endpoint_names = []\n        self._endpoint_signatures = {}\n        self.tensorflow_version = tf.__version__\n\n        self._tf_trackable = tf.__internal__.tracking.AutoTrackable()\n        self._tf_trackable.variables = []\n        self._tf_trackable.trainable_variables = []\n        self._tf_trackable.non_trainable_variables = []\n\n        # Call backend-specific initialization if defined\n        self._backend_init()\n\n    def _backend_init(self):\n        \"\"\"Backend-specific initialization. 
Override in subclasses.\"\"\"\n        pass\n\n    @property\n    def variables(self):\n        return self._tf_trackable.variables\n\n    @property\n    def trainable_variables(self):\n        return self._tf_trackable.trainable_variables\n\n    @property\n    def non_trainable_variables(self):\n        return self._tf_trackable.non_trainable_variables\n\n    def track(self, resource):\n        \"\"\"Track the variables (of a layer or model) and other assets.\n\n        By default, all variables used by an endpoint function are automatically\n        tracked when you call `add_endpoint()`. However, non-variable assets\n        such as lookup tables need to be tracked manually. Note that lookup\n        tables used by built-in Keras layers (`TextVectorization`,\n        `IntegerLookup`, `StringLookup`) are automatically tracked by\n        `add_endpoint()`.\n\n        Args:\n            resource: A layer, model, or a TensorFlow trackable resource.\n        \"\"\"\n        if isinstance(resource, layers.Layer) and not resource.built:\n            raise ValueError(\n                \"The layer provided has not yet been built. \"\n                \"It must be built before export.\"\n            )\n\n        # Note: with the TensorFlow backend, Layers and Models fall into both\n        # the Layer case and the Trackable case. The Trackable case is needed\n        # for preprocessing layers in order to track lookup tables.\n        if isinstance(resource, tf.__internal__.tracking.Trackable):\n            if not hasattr(self, \"_tracked\"):\n                self._tracked = []\n            self._tracked.append(resource)\n\n        if isinstance(resource, layers.Layer):\n            self._backend_track_layer(resource)\n        elif not isinstance(resource, tf.__internal__.tracking.Trackable):\n            raise ValueError(\n                \"Invalid resource type. Expected a Keras `Layer` or `Model` \"\n                \"or a TensorFlow `Trackable` object. 
\"\n                f\"Received object {resource} of type '{type(resource)}'. \"\n            )\n\n    def _backend_track_layer(self, layer):\n        raise NotImplementedError(\n            \"_backend_track_layer() must be implemented in backend subclasses.\"\n        )\n\n    def add_endpoint(self, name, fn, input_signature=None, **kwargs):\n        if name in self._endpoint_names:\n            raise ValueError(f\"Endpoint name '{name}' is already taken.\")\n\n        if backend.backend() != \"jax\":\n            if \"jax2tf_kwargs\" in kwargs or \"is_static\" in kwargs:\n                raise ValueError(\n                    \"'jax2tf_kwargs' and 'is_static' are only supported with \"\n                    f\"the jax backend. Current backend: {backend.backend()}\"\n                )\n\n        # The fast path if `fn` is already a `tf.function`.\n        if input_signature is None:\n            if isinstance(fn, tf.types.experimental.GenericFunction):\n                if not fn._list_all_concrete_functions():\n                    raise ValueError(\n                        f\"The provided tf.function '{fn}' \"\n                        \"has never been called. \"\n                        \"To specify the expected shape and dtype \"\n                        \"of the function's arguments, \"\n                        \"you must either provide a function that \"\n                        \"has been called at least once, or alternatively pass \"\n                        \"an `input_signature` argument in `add_endpoint()`.\"\n                    )\n                decorated_fn = fn\n            else:\n                raise ValueError(\n                    \"If the `fn` argument provided is not a `tf.function`, \"\n                    \"you must provide an `input_signature` argument to \"\n                    \"specify the shape and dtype of the function arguments. 
\"\n                    \"Example:\\n\\n\"\n                    \"export_archive.add_endpoint(\\n\"\n                    \"    name='call',\\n\"\n                    \"    fn=model.call,\\n\"\n                    \"    input_signature=[\\n\"\n                    \"        keras.InputSpec(\\n\"\n                    \"            shape=(None, 224, 224, 3),\\n\"\n                    \"            dtype='float32',\\n\"\n                    \"        )\\n\"\n                    \"    ],\\n\"\n                    \")\"\n                )\n            setattr(self._tf_trackable, name, decorated_fn)\n            self._endpoint_names.append(name)\n            return decorated_fn\n\n        input_signature = tree.map_structure(\n            make_tf_tensor_spec, input_signature\n        )\n        decorated_fn = self._backend_add_endpoint(\n            name, fn, input_signature, **kwargs\n        )\n        self._endpoint_signatures[name] = input_signature\n        setattr(self._tf_trackable, name, decorated_fn)\n        self._endpoint_names.append(name)\n        return decorated_fn\n\n    def _backend_add_endpoint(self, name, fn, input_signature, **kwargs):\n        raise NotImplementedError(\n            \"_backend_add_endpoint() must be implemented in backend subclasses.\"\n        )\n\n    def track_and_add_endpoint(self, name, resource, input_signature, **kwargs):\n        \"\"\"Track the variables and register a new serving endpoint.\n\n        This function combines the functionality of `track` and `add_endpoint`.\n        It tracks the variables of the `resource` (either a layer or a model)\n        and registers a serving endpoint using `resource.__call__`.\n\n        Args:\n            name: `str`. The name of the endpoint.\n            resource: A trackable Keras resource, such as a layer or model.\n            input_signature: Optional. 
Specifies the shape and dtype of the\n                inputs to `resource.__call__`. Can be a structure of\n                `keras.InputSpec`, `tf.TensorSpec`, `backend.KerasTensor`, or\n                backend tensor.\n            **kwargs: Additional keyword arguments:\n                - Specific to the JAX backend:\n                    - `is_static`: Optional `bool`. Indicates whether the\n                        endpoint function is static. Set to `False` if it\n                        involves state updates (e.g., RNG seeds).\n                    - `jax2tf_kwargs`: Optional `dict`. Arguments for\n                        `jax2tf.convert`. See [`jax2tf.convert`](\n                            https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).\n                        If `native_serialization` and `polymorphic_shapes` are\n                        not provided, they are automatically computed.\n\n        \"\"\"\n        self.track(resource)\n        return self.add_endpoint(\n            name, resource.__call__, input_signature, **kwargs\n        )\n\n    def add_variable_collection(self, name, variables):\n        \"\"\"Register a set of variables to be retrieved after reloading.\n\n        Arguments:\n            name: The string name for the collection.\n            variables: A tuple/list/set of `keras.Variable` instances.\n\n        Example:\n\n        ```python\n        export_archive = ExportArchive()\n        export_archive.track(model)\n        # Register an endpoint\n        export_archive.add_endpoint(\n            name=\"serve\",\n            fn=model.call,\n            input_signature=[keras.InputSpec(shape=(None, 3), dtype=\"float32\")],\n        )\n        # Save a variable collection\n        export_archive.add_variable_collection(\n            
name=\"optimizer_variables\", variables=model.optimizer.variables)\n        export_archive.write_out(\"path/to/location\")\n\n        # Reload the object\n        revived_object = tf.saved_model.load(\"path/to/location\")\n        # Retrieve the variables\n        optimizer_variables = revived_object.optimizer_variables\n        ```\n        \"\"\"\n        if not isinstance(variables, (list, tuple, set)):\n            raise ValueError(\n                \"Expected `variables` to be a list/tuple/set. \"\n                f\"Received instead object of type '{type(variables)}'.\"\n            )\n        # Ensure that all variables added are either tf.Variables\n        # or Variables created by Keras 3 with the TF or JAX backends.\n        if not all(\n            isinstance(v, (tf.Variable, backend.Variable)) for v in variables\n        ):\n            raise ValueError(\n                \"Expected all elements in `variables` to be \"\n                \"`tf.Variable` instances. Found instead the following types: \"\n                f\"{list(set(type(v) for v in variables))}\"\n            )\n        if backend.backend() == \"jax\":\n            variables = tree.flatten(\n                tree.map_structure(self._convert_to_tf_variable, variables)\n            )\n        setattr(self._tf_trackable, name, list(variables))\n\n    def write_out(self, filepath, options=None, verbose=True):\n        \"\"\"Write the corresponding SavedModel to disk.\n\n        Arguments:\n            filepath: `str` or `pathlib.Path` object.\n                Path where to save the artifact.\n            options: `tf.saved_model.SaveOptions` object that specifies\n                SavedModel saving options.\n            verbose: whether to print all the variables of an\n                exported SavedModel.\n\n        **Note on TF-Serving**: all endpoints registered via `add_endpoint()`\n        are made visible for TF-Serving in the SavedModel artifact. 
In addition,\n        the first endpoint registered is made visible under the alias\n        `\"serving_default\"` (unless an endpoint with the name\n        `\"serving_default\"` was already registered manually),\n        since TF-Serving requires this endpoint to be set.\n        \"\"\"\n        from keras.src.utils import io_utils\n\n        if not self._endpoint_names:\n            raise ValueError(\n                \"No endpoints have been set yet. Call add_endpoint().\"\n            )\n        self._filter_and_track_resources()\n\n        signatures = {}\n        for name in self._endpoint_names:\n            signatures[name] = self._get_concrete_fn(name)\n        # Add \"serving_default\" signature key for TFServing\n        if \"serving_default\" not in self._endpoint_names:\n            signatures[\"serving_default\"] = self._get_concrete_fn(\n                self._endpoint_names[0]\n            )\n\n        tf.saved_model.save(\n            self._tf_trackable,\n            filepath,\n            options=options,\n            signatures=signatures,\n        )\n\n        # Print out available endpoints\n        if verbose:\n            endpoints = \"\\n\\n\".join(\n                _print_signature(\n                    getattr(self._tf_trackable, name), name, verbose=verbose\n                )\n                for name in self._endpoint_names\n            )\n            io_utils.print_msg(\n                f\"Saved artifact at '{filepath}'. \"\n                \"The following endpoints are available:\\n\\n\"\n                f\"{endpoints}\"\n            )\n\n    def _convert_to_tf_variable(self, backend_variable):\n        if not isinstance(backend_variable, backend.Variable):\n            raise TypeError(\n                \"`backend_variable` must be a `backend.Variable`. 
\"\n                f\"Received: backend_variable={backend_variable} of type \"\n                f\"({type(backend_variable)})\"\n            )\n        return tf.Variable(\n            backend_variable.value,\n            dtype=backend_variable.dtype,\n            trainable=backend_variable.trainable,\n            name=backend_variable.name,\n        )\n\n    def _get_concrete_fn(self, endpoint):\n        \"\"\"Workaround for some SavedModel quirks.\"\"\"\n        if endpoint in self._endpoint_signatures:\n            return getattr(self._tf_trackable, endpoint)\n        else:\n            traces = getattr(self._tf_trackable, endpoint)._trackable_children(\n                \"saved_model\"\n            )\n            return list(traces.values())[0]\n\n    def _get_variables_used_by_endpoints(self):\n        fns = [self._get_concrete_fn(name) for name in self._endpoint_names]\n        return _list_variables_used_by_fns(fns)\n\n    def _filter_and_track_resources(self):\n        \"\"\"Track resources used by endpoints / referenced in `track()` calls.\"\"\"\n        # Start by extracting variables from endpoints.\n        fns = [self._get_concrete_fn(name) for name in self._endpoint_names]\n        tvs, ntvs = _list_variables_used_by_fns(fns)\n        self._tf_trackable._all_variables = list(tvs + ntvs)\n\n        # `tf.train.TrackableView` hardcodes the `save_type` to \"checkpoint\".\n        # We need to subclass to use a `save_type` of \"savedmodel\".\n        savedmodel_cache = {}\n\n        class SavedModelTrackableView(tf.train.TrackableView):\n            @classmethod\n            def children(cls, obj, save_type=\"savedmodel\", **kwargs):\n                return super().children(obj, save_type, cache=savedmodel_cache)\n\n        # Next, track lookup tables.\n        # Hopefully, one day this will be automated at the tf.function level.\n        self._tf_trackable._misc_assets = []\n        from tensorflow.saved_model.experimental import TrackableResource\n\n    
    if hasattr(self, \"_tracked\"):\n            for root in self._tracked:\n                descendants = SavedModelTrackableView(root).descendants()\n                for trackable in descendants:\n                    if isinstance(trackable, TrackableResource):\n                        self._tf_trackable._misc_assets.append(trackable)\n\n\ndef _print_signature(fn, name, verbose=True):\n    concrete_fn = fn._list_all_concrete_functions()[0]\n    pprinted_signature = concrete_fn.pretty_printed_signature(verbose=verbose)\n    lines = pprinted_signature.split(\"\\n\")\n    lines = [f\"* Endpoint '{name}'\"] + lines[1:]\n    endpoint = \"\\n\".join(lines)\n    return endpoint\n\n\ndef _list_variables_used_by_fns(fns):\n    trainable_variables = []\n    non_trainable_variables = []\n    trainable_variables_ids = set()\n    non_trainable_variables_ids = set()\n    for fn in fns:\n        if hasattr(fn, \"concrete_functions\"):\n            concrete_functions = fn.concrete_functions\n        elif hasattr(fn, \"get_concrete_function\"):\n            concrete_functions = [fn.get_concrete_function()]\n        else:\n            concrete_functions = [fn]\n        for concrete_fn in concrete_functions:\n            for v in concrete_fn.trainable_variables:\n                if id(v) not in trainable_variables_ids:\n                    trainable_variables.append(v)\n                    trainable_variables_ids.add(id(v))\n\n            for v in concrete_fn.variables:\n                if (\n                    id(v) not in trainable_variables_ids\n                    and id(v) not in non_trainable_variables_ids\n                ):\n                    non_trainable_variables.append(v)\n                    non_trainable_variables_ids.add(id(v))\n    return trainable_variables, non_trainable_variables\n"
  },
  {
    "path": "keras/src/export/saved_model_test.py",
    "content": "\"\"\"Tests for SavedModel exporting utilities.\"\"\"\n\nimport os\n\nimport numpy as np\nimport pytest\nimport tensorflow as tf\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import random\nfrom keras.src import testing\nfrom keras.src import tree\nfrom keras.src.export import saved_model\nfrom keras.src.saving import saving_lib\nfrom keras.src.testing.test_utils import named_product\n\n\nclass CustomModel(models.Model):\n    def __init__(self, layer_list):\n        super().__init__()\n        self.layer_list = layer_list\n\n    def call(self, input):\n        output = input\n        for layer in self.layer_list:\n            output = layer(output)\n        return output\n\n\ndef get_model(type=\"sequential\", input_shape=(10,), layer_list=None):\n    layer_list = layer_list or [\n        layers.Dense(10, activation=\"relu\"),\n        layers.BatchNormalization(),\n        layers.Dense(1, activation=\"sigmoid\"),\n    ]\n    if type == \"sequential\":\n        return models.Sequential(layer_list)\n    elif type == \"functional\":\n        input = output = tree.map_shape_structure(layers.Input, input_shape)\n        for layer in layer_list:\n            output = layer(output)\n        return models.Model(inputs=input, outputs=output)\n    elif type == \"subclass\":\n        return CustomModel(layer_list)\n\n\n@pytest.mark.skipif(\n    backend.backend() not in (\"tensorflow\", \"jax\", \"torch\"),\n    reason=(\n        \"`export_saved_model` only currently supports the tensorflow, jax and \"\n        \"torch backends.\"\n    ),\n)\n@pytest.mark.skipif(testing.jax_uses_gpu(), reason=\"Leads to core dumps on CI\")\n@pytest.mark.skipif(\n    testing.torch_uses_gpu(), reason=\"Leads to core dumps on CI\"\n)\n@pytest.mark.skipif(\n    backend.backend() == \"torch\" and np.version.version.startswith(\"2.\"),\n    reason=\"Torch 
backend export (via torch_xla) is incompatible with np 2.0\",\n)\nclass ExportSavedModelTest(testing.TestCase):\n    @parameterized.named_parameters(\n        named_product(model_type=[\"sequential\", \"functional\", \"subclass\"])\n    )\n    def test_standard_model_export(self, model_type):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model(model_type)\n        batch_size = 3 if backend.backend() != \"torch\" else 1\n        ref_input = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        ref_output = model(ref_input)\n\n        saved_model.export_saved_model(model, temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_model.serve(ref_input))\n        # Test with a different batch size\n        revived_model.serve(tf.random.normal((6, 10)))\n\n    @parameterized.named_parameters(\n        named_product(model_type=[\"sequential\", \"functional\", \"subclass\"])\n    )\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=(\n            \"RuntimeError: mutating a non-functional tensor with a \"\n            \"functional tensor is not allowed in the torch backend.\"\n        ),\n    )\n    def test_model_with_rng_export(self, model_type):\n        class RandomLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.seed_generator = backend.random.SeedGenerator()\n\n            def call(self, inputs):\n                return inputs + random.uniform(\n                    ops.shape(inputs), seed=self.seed_generator\n                )\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model(model_type, layer_list=[RandomLayer()])\n        ref_input = tf.random.normal((3, 10))\n        ref_output = model(ref_input)\n\n        saved_model.export_saved_model(model, temp_filepath)\n        revived_model 
= tf.saved_model.load(temp_filepath)\n        self.assertEqual(ref_output.shape, revived_model.serve(ref_input).shape)\n        # Test with a different batch size\n        input = tf.random.normal((6, 10))\n        output1 = revived_model.serve(input)\n        output2 = revived_model.serve(input)\n        # Verify RNG seeding works and produces random outputs\n        self.assertNotAllClose(output1, output2)\n\n    @parameterized.named_parameters(\n        named_product(model_type=[\"sequential\", \"functional\", \"subclass\"])\n    )\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=(\n            \"RuntimeError: mutating a non-functional tensor with a \"\n            \"functional tensor is not allowed in the torch backend.\"\n        ),\n    )\n    def test_model_with_non_trainable_state_export(self, model_type):\n        class StateLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.counter = self.add_variable(\n                    (), \"zeros\", \"int32\", trainable=False\n                )\n\n            def call(self, inputs):\n                self.counter.assign_add(1)\n                return ops.array(inputs), ops.array(self.counter.value)\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model(model_type, layer_list=[StateLayer()])\n        model(tf.random.normal((3, 10)))\n\n        saved_model.export_saved_model(model, temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n\n        # The non-trainable counter is expected to increment\n        input = tf.random.normal((6, 10))\n        output1, counter1 = revived_model.serve(input)\n        self.assertAllClose(output1, input)\n        self.assertAllClose(counter1, 2)\n        output2, counter2 = revived_model.serve(input)\n        self.assertAllClose(output2, input)\n        self.assertAllClose(counter2, 3)\n\n    
@parameterized.named_parameters(\n        named_product(model_type=[\"sequential\", \"functional\", \"subclass\"])\n    )\n    def test_model_with_tf_data_layer(self, model_type):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model(model_type, layer_list=[layers.Rescaling(scale=2.0)])\n        batch_size = 3 if backend.backend() != \"torch\" else 1\n        ref_input = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        ref_output = model(ref_input)\n\n        saved_model.export_saved_model(model, temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_model.serve(ref_input))\n        # Test with a different batch size\n        revived_model.serve(tf.random.normal((6, 10)))\n\n    @parameterized.named_parameters(\n        named_product(struct_type=[\"tuple\", \"array\", \"dict\"])\n    )\n    def test_model_with_input_structure(self, struct_type):\n        class TupleModel(models.Model):\n            def call(self, inputs):\n                x, y = inputs\n                return ops.add(x, y)\n\n        class ArrayModel(models.Model):\n            def call(self, inputs):\n                x = inputs[0]\n                y = inputs[1]\n                return ops.add(x, y)\n\n        class DictModel(models.Model):\n            def call(self, inputs):\n                x = inputs[\"x\"]\n                y = inputs[\"y\"]\n                return ops.add(x, y)\n\n        batch_size = 3 if backend.backend() != \"torch\" else 1\n        ref_input = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        if struct_type == \"tuple\":\n            model = TupleModel()\n            ref_input = (ref_input, ref_input * 2)\n        elif struct_type == \"array\":\n            model = ArrayModel()\n            ref_input = [ref_input, ref_input * 2]\n        elif struct_type == \"dict\":\n            model = DictModel()\n            
ref_input = {\"x\": ref_input, \"y\": ref_input * 2}\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        ref_output = model(tree.map_structure(ops.convert_to_tensor, ref_input))\n\n        saved_model.export_saved_model(model, temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_model.serve(ref_input))\n\n        # Test with keras.saving_lib\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"exported_model.keras\"\n        )\n        saving_lib.save_model(model, temp_filepath)\n        revived_model = saving_lib.load_model(\n            temp_filepath,\n            {\n                \"TupleModel\": TupleModel,\n                \"ArrayModel\": ArrayModel,\n                \"DictModel\": DictModel,\n            },\n        )\n        self.assertAllClose(ref_output, revived_model(ref_input))\n        saved_model.export_saved_model(revived_model, self.get_temp_dir())\n\n        # Test with a different batch size\n        bigger_input = tree.map_structure(\n            lambda x: tf.concat([x, x], axis=0), ref_input\n        )\n        revived_model(bigger_input)\n\n    def test_model_with_multiple_inputs(self):\n        class TwoInputsModel(models.Model):\n            def call(self, x, y):\n                return x + y\n\n            def build(self, y_shape, x_shape):\n                self.built = True\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = TwoInputsModel()\n        batch_size = 3 if backend.backend() != \"torch\" else 1\n        ref_input_x = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        ref_input_y = np.random.normal(size=(batch_size, 10)).astype(\"float32\")\n        ref_output = model(ref_input_x, ref_input_y)\n\n        saved_model.export_saved_model(model, temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        
self.assertAllClose(\n            ref_output, revived_model.serve(ref_input_x, ref_input_y)\n        )\n        # Test with a different batch size\n        revived_model.serve(\n            tf.random.normal((6, 10)), tf.random.normal((6, 10))\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            model_type=[\"sequential\", \"functional\", \"subclass\"],\n            input_signature=[\n                layers.InputSpec(\n                    dtype=\"float32\", shape=(None, 10), name=\"inputs\"\n                ),\n                tf.TensorSpec((None, 10), dtype=\"float32\", name=\"inputs\"),\n                backend.KerasTensor((None, 10), dtype=\"float32\", name=\"inputs\"),\n                \"backend_tensor\",\n            ],\n        )\n    )\n    def test_input_signature(self, model_type, input_signature):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model(model_type)\n        batch_size = 3 if backend.backend() != \"torch\" else 1\n        ref_input = ops.random.normal((batch_size, 10))\n        ref_output = model(ref_input)\n\n        if input_signature == \"backend_tensor\":\n            input_signature = (ref_input,)\n        else:\n            input_signature = (input_signature,)\n        saved_model.export_saved_model(\n            model, temp_filepath, input_signature=input_signature\n        )\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(\n            ref_output, revived_model.serve(ops.convert_to_numpy(ref_input))\n        )\n\n    def test_input_signature_error(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model(\"functional\")\n        with self.assertRaisesRegex(TypeError, \"Unsupported x=\"):\n            input_signature = (123,)\n            saved_model.export_saved_model(\n                model, temp_filepath, input_signature=input_signature\n            
)\n\n    @parameterized.named_parameters(\n        named_product(\n            model_type=[\"sequential\", \"functional\", \"subclass\"],\n            is_static=(True, False),\n            jax2tf_kwargs=(\n                None,\n                {\"enable_xla\": True, \"native_serialization\": True},\n            ),\n        )\n    )\n    @pytest.mark.skipif(\n        backend.backend() != \"jax\",\n        reason=\"This test is only for the jax backend.\",\n    )\n    def test_jax_specific_kwargs(self, model_type, is_static, jax2tf_kwargs):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model(model_type)\n        ref_input = ops.random.uniform((3, 10))\n        ref_output = model(ref_input)\n\n        saved_model.export_saved_model(\n            model,\n            temp_filepath,\n            is_static=is_static,\n            jax2tf_kwargs=jax2tf_kwargs,\n        )\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_model.serve(ref_input))\n\n\n@pytest.mark.skipif(\n    backend.backend()\n    not in (\n        \"tensorflow\",\n        \"jax\",\n        # \"torch\",  # TODO: Support low-level operations in the torch backend.\n    ),\n    reason=\"Export only currently supports the TF and JAX backends.\",\n)\n@pytest.mark.skipif(testing.jax_uses_gpu(), reason=\"Leads to core dumps on CI\")\n@pytest.mark.skipif(\n    testing.torch_uses_gpu(), reason=\"Leads to core dumps on CI\"\n)\nclass ExportArchiveTest(testing.TestCase):\n    @parameterized.named_parameters(\n        named_product(model_type=[\"sequential\", \"functional\", \"subclass\"])\n    )\n    def test_low_level_model_export(self, model_type):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        model = get_model(model_type)\n        ref_input = tf.random.normal((3, 10))\n        ref_output = model(ref_input)\n\n        # Test variable tracking\n        
export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        self.assertLen(export_archive.variables, 8)\n        self.assertLen(export_archive.trainable_variables, 6)\n        self.assertLen(export_archive.non_trainable_variables, 2)\n\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\n            \"call\",\n            model.__call__,\n            input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n        )\n        export_archive.write_out(temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_model.call(ref_input))\n        # Test with a different batch size\n        revived_model.call(tf.random.normal((6, 10)))\n\n    def test_low_level_model_export_with_alias(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        model = get_model()\n        ref_input = tf.random.normal((3, 10))\n        ref_output = model(ref_input)\n\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        fn = export_archive.add_endpoint(\n            \"call\",\n            model.__call__,\n            input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n        )\n        export_archive.write_out(\n            temp_filepath,\n            tf.saved_model.SaveOptions(function_aliases={\"call_alias\": fn}),\n        )\n        revived_model = tf.saved_model.load(\n            temp_filepath,\n            options=tf.saved_model.LoadOptions(\n                experimental_load_function_aliases=True\n            ),\n        )\n        self.assertAllClose(\n            ref_output, revived_model.function_aliases[\"call_alias\"](ref_input)\n        )\n        # Test with a different batch size\n        revived_model.function_aliases[\"call_alias\"](tf.random.normal((6, 10)))\n\n    
@parameterized.named_parameters(\n        named_product(model_type=[\"sequential\", \"functional\", \"subclass\"])\n    )\n    def test_low_level_model_export_with_dynamic_dims(self, model_type):\n        class ReductionLayer(layers.Layer):\n            def call(self, inputs):\n                return ops.max(inputs, axis=1)\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        model = get_model(\n            model_type,\n            input_shape=[(None,), (None,)],\n            layer_list=[layers.Concatenate(), ReductionLayer()],\n        )\n        ref_input = [tf.random.normal((3, 8)), tf.random.normal((3, 6))]\n        ref_output = model(ref_input)\n\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\n            \"call\",\n            model.__call__,\n            input_signature=[\n                [\n                    tf.TensorSpec(shape=(None, None), dtype=tf.float32),\n                    tf.TensorSpec(shape=(None, None), dtype=tf.float32),\n                ]\n            ],\n        )\n        export_archive.write_out(temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_model.call(ref_input))\n        # Test with a different batch size\n        revived_model.call([tf.random.normal((6, 8)), tf.random.normal((6, 6))])\n        # Test with a different batch size and different dynamic sizes\n        revived_model.call([tf.random.normal((6, 3)), tf.random.normal((6, 5))])\n\n    @pytest.mark.skipif(\n        backend.backend() != \"jax\",\n        reason=\"This test is only for the JAX backend.\",\n    )\n    def test_low_level_model_export_with_jax2tf_kwargs(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        model = get_model()\n        ref_input = tf.random.normal((3, 10))\n        ref_output = model(ref_input)\n\n        
export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\n            \"call\",\n            model.__call__,\n            input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n            jax2tf_kwargs={\n                \"native_serialization\": True,\n                \"native_serialization_platforms\": (\"cpu\", \"tpu\"),\n            },\n        )\n        with self.assertRaisesRegex(\n            ValueError, \"native_serialization_platforms.*bogus\"\n        ):\n            export_archive.add_endpoint(\n                \"call2\",\n                model.__call__,\n                input_signature=[\n                    tf.TensorSpec(shape=(None, 10), dtype=tf.float32)\n                ],\n                jax2tf_kwargs={\n                    \"native_serialization\": True,\n                    \"native_serialization_platforms\": (\"cpu\", \"bogus\"),\n                },\n            )\n        export_archive.write_out(temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_model.call(ref_input))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"jax\",\n        reason=\"This test is only for the JAX backend.\",\n    )\n    def test_low_level_model_export_with_jax2tf_polymorphic_shapes(self):\n        class SquareLayer(layers.Layer):\n            def call(self, inputs):\n                return ops.matmul(inputs, inputs)\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        model = CustomModel([SquareLayer()])\n        ref_input = tf.random.normal((3, 10, 10))\n        ref_output = model(ref_input)\n        signature = [tf.TensorSpec(shape=(None, None, None), dtype=tf.float32)]\n\n        with self.assertRaises(TypeError):\n            # This will fail because the polymorphic_shapes that is\n            # automatically generated will not account for the fact that\n            
# dynamic dimensions 1 and 2 must have the same value.\n            export_archive = saved_model.ExportArchive()\n            export_archive.track(model)\n            export_archive.add_endpoint(\n                \"call\",\n                model.__call__,\n                input_signature=signature,\n                jax2tf_kwargs={},\n            )\n            export_archive.write_out(temp_filepath)\n\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\n            \"call\",\n            model.__call__,\n            input_signature=signature,\n            jax2tf_kwargs={\"polymorphic_shapes\": [\"(batch, a, a)\"]},\n        )\n        export_archive.write_out(temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_model.call(ref_input))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"This test is native to the TF backend.\",\n    )\n    def test_endpoint_registration_tf_function(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        model = get_model()\n        ref_input = tf.random.normal((3, 10))\n        ref_output = model(ref_input)\n\n        # Test variable tracking\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        self.assertLen(export_archive.variables, 8)\n        self.assertLen(export_archive.trainable_variables, 6)\n        self.assertLen(export_archive.non_trainable_variables, 2)\n\n        @tf.function()\n        def my_endpoint(x):\n            return model(x)\n\n        # Test registering an endpoint that is a tf.function (called)\n        my_endpoint(ref_input)  # Trace fn\n\n        export_archive.add_endpoint(\n            \"call\",\n            my_endpoint,\n        )\n        export_archive.write_out(temp_filepath)\n\n        revived_model = 
tf.saved_model.load(temp_filepath)\n        self.assertFalse(hasattr(revived_model, \"_tracked\"))\n        self.assertAllClose(ref_output, revived_model.call(ref_input))\n        self.assertLen(revived_model.variables, 8)\n        self.assertLen(revived_model.trainable_variables, 6)\n        self.assertLen(revived_model.non_trainable_variables, 2)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"jax\",\n        reason=\"This test is native to the JAX backend.\",\n    )\n    def test_jax_endpoint_registration_tf_function(self):\n        model = get_model()\n        ref_input = np.random.normal(size=(3, 10))\n        model(ref_input)\n\n        # build a JAX function\n        def model_call(x):\n            return model(x)\n\n        from jax import default_backend as jax_device\n        from jax.experimental import jax2tf\n\n        native_jax_compatible = not (\n            jax_device() == \"gpu\"\n            and len(tf.config.list_physical_devices(\"GPU\")) == 0\n        )\n        # now, convert JAX function\n        converted_model_call = jax2tf.convert(\n            model_call,\n            native_serialization=native_jax_compatible,\n            polymorphic_shapes=[\"(b, 10)\"],\n        )\n\n        # you can now build a TF inference function\n        @tf.function(\n            input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n            autograph=False,\n        )\n        def infer_fn(x):\n            return converted_model_call(x)\n\n        ref_output = infer_fn(ref_input)\n\n        # Export with TF inference function as endpoint\n        temp_filepath = os.path.join(self.get_temp_dir(), \"my_model\")\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\"serve\", infer_fn)\n        export_archive.write_out(temp_filepath)\n\n        # Reload and verify outputs\n        revived_model = tf.saved_model.load(temp_filepath)\n        
self.assertFalse(hasattr(revived_model, \"_tracked\"))\n        self.assertAllClose(ref_output, revived_model.serve(ref_input))\n        self.assertLen(revived_model.variables, 8)\n        self.assertLen(revived_model.trainable_variables, 6)\n        self.assertLen(revived_model.non_trainable_variables, 2)\n\n        # Assert all variables wrapped as `tf.Variable`\n        self.assertIsInstance(export_archive.variables[0], tf.Variable)\n        self.assertIsInstance(\n            export_archive.trainable_variables[0], tf.Variable\n        )\n        self.assertIsInstance(\n            export_archive.non_trainable_variables[0], tf.Variable\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"jax\",\n        reason=\"This test is native to the JAX backend.\",\n    )\n    def test_jax_multi_unknown_endpoint_registration(self):\n        window_size = 100\n\n        X = np.random.random((1024, window_size, 1))\n        Y = np.random.random((1024, window_size, 1))\n\n        model = models.Sequential(\n            [\n                layers.Dense(128, activation=\"relu\"),\n                layers.Dense(64, activation=\"relu\"),\n                layers.Dense(1, activation=\"relu\"),\n            ]\n        )\n\n        model.compile(optimizer=\"adam\", loss=\"mse\")\n\n        model.fit(X, Y, batch_size=32)\n\n        # build a JAX function\n        def model_call(x):\n            return model(x)\n\n        from jax import default_backend as jax_device\n        from jax.experimental import jax2tf\n\n        native_jax_compatible = not (\n            jax_device() == \"gpu\"\n            and len(tf.config.list_physical_devices(\"GPU\")) == 0\n        )\n        # now, convert JAX function\n        converted_model_call = jax2tf.convert(\n            model_call,\n            native_serialization=native_jax_compatible,\n            polymorphic_shapes=[\"(b, t, 1)\"],\n        )\n\n        # you can now build a TF inference function\n        @tf.function(\n     
       input_signature=[\n                tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32)\n            ],\n            autograph=False,\n        )\n        def infer_fn(x):\n            return converted_model_call(x)\n\n        ref_input = np.random.random((1024, window_size, 1))\n        ref_output = infer_fn(ref_input)\n\n        # Export with TF inference function as endpoint\n        temp_filepath = os.path.join(self.get_temp_dir(), \"my_model\")\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\"serve\", infer_fn)\n        export_archive.write_out(temp_filepath)\n\n        # Reload and verify outputs\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertFalse(hasattr(revived_model, \"_tracked\"))\n        self.assertAllClose(ref_output, revived_model.serve(ref_input))\n        self.assertLen(revived_model.variables, 6)\n        self.assertLen(revived_model.trainable_variables, 6)\n        self.assertLen(revived_model.non_trainable_variables, 0)\n\n        # Assert all variables wrapped as `tf.Variable`\n        self.assertIsInstance(export_archive.variables[0], tf.Variable)\n        self.assertIsInstance(\n            export_archive.trainable_variables[0], tf.Variable\n        )\n\n    def test_layer_export(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_layer\")\n\n        layer = layers.BatchNormalization()\n        ref_input = tf.random.normal((3, 10))\n        ref_output = layer(ref_input)  # Build layer (important)\n\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(layer)\n        export_archive.add_endpoint(\n            \"call\",\n            layer.call,\n            input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n        )\n        export_archive.write_out(temp_filepath)\n        revived_layer = tf.saved_model.load(temp_filepath)\n        
self.assertAllClose(ref_output, revived_layer.call(ref_input))\n\n    def test_multi_input_output_functional_model(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        x1 = layers.Input((2,))\n        x2 = layers.Input((2,))\n        y1 = layers.Dense(3)(x1)\n        y2 = layers.Dense(3)(x2)\n        model = models.Model([x1, x2], [y1, y2])\n\n        ref_inputs = [tf.random.normal((3, 2)), tf.random.normal((3, 2))]\n        ref_outputs = model(ref_inputs)\n\n        model.export(temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_outputs[0], revived_model.serve(ref_inputs)[0])\n        self.assertAllClose(ref_outputs[1], revived_model.serve(ref_inputs)[1])\n        # Test with a different batch size\n        revived_model.serve(\n            [tf.random.normal((6, 2)), tf.random.normal((6, 2))]\n        )\n\n        # Now test dict inputs\n        model = models.Model({\"x1\": x1, \"x2\": x2}, [y1, y2])\n\n        ref_inputs = {\n            \"x1\": tf.random.normal((3, 2)),\n            \"x2\": tf.random.normal((3, 2)),\n        }\n        ref_outputs = model(ref_inputs)\n\n        model.export(temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_outputs[0], revived_model.serve(ref_inputs)[0])\n        self.assertAllClose(ref_outputs[1], revived_model.serve(ref_inputs)[1])\n        # Test with a different batch size\n        revived_model.serve(\n            {\n                \"x1\": tf.random.normal((6, 2)),\n                \"x2\": tf.random.normal((6, 2)),\n            }\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"String lookup requires TensorFlow backend\",\n    )\n    def test_model_with_lookup_table(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        text_vectorization = layers.TextVectorization()\n        
text_vectorization.adapt([\"one two\", \"three four\", \"five six\"])\n        model = models.Sequential(\n            [\n                layers.Input(shape=(), dtype=\"string\"),\n                text_vectorization,\n                layers.Embedding(10, 32),\n                layers.Dense(1),\n            ]\n        )\n        ref_input = tf.convert_to_tensor([\"one two three four\"])\n        ref_output = model(ref_input)\n\n        saved_model.export_saved_model(model, temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_model.serve(ref_input))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"String lookup requires TensorFlow backend\",\n    )\n    def test_model_with_tracked_collection(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        text_vectorization = layers.TextVectorization()\n        text_vectorization.adapt([\"one two\", \"three four\", \"five six\"])\n\n        # CustomModel has a list of layers. 
The `TrackedList` that Keras uses is\n        # not a TensorFlow Trackable, but `Layer._trackable_children` makes it\n        # work.\n        model = CustomModel(\n            [\n                text_vectorization,\n                layers.Embedding(10, 32),\n                layers.Dense(1),\n            ]\n        )\n        ref_input = tf.convert_to_tensor([\"one two three four\"])\n        ref_output = model(ref_input)\n\n        saved_model.export_saved_model(\n            model,\n            temp_filepath,\n            input_signature=[tf.TensorSpec(shape=[1], dtype=tf.string)],\n        )\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_model.serve(ref_input))\n\n    def test_track_multiple_layers(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        layer_1 = layers.Dense(2)\n        ref_input_1 = tf.random.normal((3, 4))\n        ref_output_1 = layer_1(ref_input_1)\n        layer_2 = layers.Dense(3)\n        ref_input_2 = tf.random.normal((3, 5))\n        ref_output_2 = layer_2(ref_input_2)\n\n        export_archive = saved_model.ExportArchive()\n        export_archive.add_endpoint(\n            \"call_1\",\n            layer_1.call,\n            input_signature=[tf.TensorSpec(shape=(None, 4), dtype=tf.float32)],\n        )\n        export_archive.add_endpoint(\n            \"call_2\",\n            layer_2.call,\n            input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)],\n        )\n        export_archive.write_out(temp_filepath)\n        revived_layer = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output_1, revived_layer.call_1(ref_input_1))\n        self.assertAllClose(ref_output_2, revived_layer.call_2(ref_input_2))\n\n    def test_non_standard_layer_signature(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_layer\")\n\n        layer = layers.MultiHeadAttention(2, 2)\n        x1 = 
tf.random.normal((3, 2, 2))\n        x2 = tf.random.normal((3, 2, 2))\n        ref_output = layer(x1, x2)  # Build layer (important)\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(layer)\n        export_archive.add_endpoint(\n            \"call\",\n            layer.call,\n            input_signature=[\n                tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32),\n                tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32),\n            ],\n        )\n        export_archive.write_out(temp_filepath)\n        revived_layer = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_layer.call(x1, x2))\n\n    def test_non_standard_layer_signature_with_kwargs(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_layer\")\n\n        layer = layers.MultiHeadAttention(2, 2)\n        x1 = tf.random.normal((3, 2, 2))\n        x2 = tf.random.normal((3, 2, 2))\n        ref_output = layer(x1, x2)  # Build layer (important)\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(layer)\n        export_archive.add_endpoint(\n            \"call\",\n            layer.call,\n            input_signature=[\n                tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32),\n                tf.TensorSpec(shape=(None, 2, 2), dtype=tf.float32),\n            ],\n        )\n        export_archive.write_out(temp_filepath)\n        revived_layer = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_layer.call(query=x1, value=x2))\n        # Test with a different batch size\n        revived_layer.call(\n            query=tf.random.normal((6, 2, 2)), value=tf.random.normal((6, 2, 2))\n        )\n\n    def test_variable_collection(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        model = models.Sequential(\n            [\n                layers.Input((10,)),\n                
layers.Dense(2),\n                layers.Dense(2),\n            ]\n        )\n\n        # Test variable tracking\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\n            \"call\",\n            model.__call__,\n            input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n        )\n        export_archive.add_variable_collection(\n            \"my_vars\", model.layers[1].weights\n        )\n\n        self.assertLen(export_archive._tf_trackable.my_vars, 2)\n        export_archive.write_out(temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertLen(revived_model.my_vars, 2)\n\n    def test_export_saved_model_errors(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        # Model has not been built\n        model = models.Sequential([layers.Dense(2)])\n        with self.assertRaisesRegex(ValueError, \"It must be built\"):\n            saved_model.export_saved_model(model, temp_filepath)\n\n        # Subclassed model has not been called\n        model = get_model(\"subclass\")\n        model.build((2, 10))\n        with self.assertRaisesRegex(ValueError, \"It must be called\"):\n            saved_model.export_saved_model(model, temp_filepath)\n\n    def test_export_archive_errors(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = models.Sequential([layers.Dense(2)])\n        model(tf.random.normal((2, 3)))\n\n        # Endpoint name reuse\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\n            \"call\",\n            model.__call__,\n            input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],\n        )\n        with self.assertRaisesRegex(ValueError, \"already taken\"):\n            export_archive.add_endpoint(\n                
\"call\",\n                model.__call__,\n                input_signature=[\n                    tf.TensorSpec(shape=(None, 3), dtype=tf.float32)\n                ],\n            )\n\n        # Write out with no endpoints\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        with self.assertRaisesRegex(ValueError, \"No endpoints have been set\"):\n            export_archive.write_out(temp_filepath)\n\n        # Invalid object type\n        with self.assertRaisesRegex(ValueError, \"Invalid resource type\"):\n            export_archive = saved_model.ExportArchive()\n            export_archive.track(\"model\")\n\n        # Set endpoint with no input signature\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        with self.assertRaisesRegex(\n            ValueError, \"you must provide an `input_signature`\"\n        ):\n            export_archive.add_endpoint(\"call\", model.__call__)\n\n        # Set endpoint that has never been called\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n\n        @tf.function()\n        def my_endpoint(x):\n            return model(x)\n\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        with self.assertRaisesRegex(\n            ValueError, \"you must either provide a function\"\n        ):\n            export_archive.add_endpoint(\"call\", my_endpoint)\n\n    def test_export_no_assets(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        # Case where there are legitimately no assets.\n        model = models.Sequential([layers.Flatten()])\n        model(tf.random.normal((2, 3)))\n        export_archive = saved_model.ExportArchive()\n        export_archive.add_endpoint(\n            \"call\",\n            model.__call__,\n            input_signature=[tf.TensorSpec(shape=(None, 3), dtype=tf.float32)],\n        )\n   
     export_archive.write_out(temp_filepath)\n\n    @parameterized.named_parameters(\n        named_product(model_type=[\"sequential\", \"functional\", \"subclass\"])\n    )\n    def test_model_export_method(self, model_type):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model(model_type)\n        ref_input = tf.random.normal((3, 10))\n        ref_output = model(ref_input)\n\n        model.export(temp_filepath)\n        revived_model = tf.saved_model.load(temp_filepath)\n        self.assertAllClose(ref_output, revived_model.serve(ref_input))\n        # Test with a different batch size\n        revived_model.serve(tf.random.normal((6, 10)))\n\n    def test_model_combined_with_tf_preprocessing(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        lookup_table = tf.lookup.StaticHashTable(\n            tf.lookup.KeyValueTensorInitializer(\n                tf.constant([\"a\", \"b\", \"c\"]), tf.constant([1.0, 2.0, 3.0])\n            ),\n            default_value=-1.0,\n        )\n        ref_input = tf.constant([[\"c\", \"b\", \"c\", \"a\", \"d\"]])\n        ref_intermediate = lookup_table.lookup(ref_input)\n\n        model = models.Sequential([layers.Dense(1)])\n        ref_output = model(ref_intermediate)\n\n        export_archive = saved_model.ExportArchive()\n        model_fn = export_archive.track_and_add_endpoint(\n            \"model\",\n            model,\n            input_signature=[tf.TensorSpec(shape=(None, 5), dtype=tf.float32)],\n        )\n        export_archive.track(lookup_table)\n\n        @tf.function()\n        def combined_fn(x):\n            x = lookup_table.lookup(x)\n            x = model_fn(x)\n            return x\n\n        self.assertAllClose(combined_fn(ref_input), ref_output)\n\n        export_archive.add_endpoint(\"combined_fn\", combined_fn)\n        export_archive.write_out(temp_filepath)\n\n        revived_model = 
tf.saved_model.load(temp_filepath)\n        self.assertAllClose(revived_model.combined_fn(ref_input), ref_output)\n"
  },
  {
    "path": "keras/src/export/tf2onnx_lib.py",
    "content": "import copy\nimport functools\nimport logging\nimport traceback\n\nimport numpy as np\n\n\n@functools.lru_cache()\ndef patch_tf2onnx():\n    \"\"\"Patches `tf2onnx` to ensure compatibility with numpy>=2.0.0.\"\"\"\n\n    from onnx import AttributeProto\n    from onnx import TensorProto\n\n    from keras.src.utils.module_utils import tf2onnx\n\n    logger = logging.getLogger(tf2onnx.__name__)\n\n    if not hasattr(np, \"object\"):\n        np.object = object\n\n    def patched_rewrite_constant_fold(g, ops):\n        \"\"\"\n        We call tensorflow transform with constant folding but in some cases\n        tensorflow does fold all constants. Since there are a bunch of ops in\n        onnx that use attributes where tensorflow has dynamic inputs, we badly\n        want constant folding to work. For cases where tensorflow missed\n        something, make another pass over the graph and fix want we care about.\n        \"\"\"\n        func_map = {\n            \"Add\": np.add,\n            \"GreaterEqual\": np.greater_equal,\n            \"Cast\": np.asarray,\n            \"ConcatV2\": np.concatenate,\n            \"Less\": np.less,\n            \"ListDiff\": np.setdiff1d,\n            \"Mul\": np.multiply,\n            \"Pack\": np.stack,\n            \"Range\": np.arange,\n            \"Sqrt\": np.sqrt,\n            \"Sub\": np.subtract,\n        }\n        ops = list(ops)\n\n        keep_looking = True\n        while keep_looking:\n            keep_looking = False\n            for idx, op in enumerate(ops):\n                func = func_map.get(op.type)\n                if func is None:\n                    continue\n                if set(op.output) & set(g.outputs):\n                    continue\n                try:\n                    inputs = []\n                    for node in op.inputs:\n                        if not node.is_const():\n                            break\n                        
inputs.append(node.get_tensor_value(as_list=False))\n\n                    logger.debug(\n                        \"op name %s, %s, %s\",\n                        op.name,\n                        len(op.input),\n                        len(inputs),\n                    )\n                    if inputs and len(op.input) == len(inputs):\n                        logger.info(\n                            \"folding node type=%s, name=%s\" % (op.type, op.name)\n                        )\n                        if op.type == \"Cast\":\n                            dst = op.get_attr_int(\"to\")\n                            np_type = tf2onnx.utils.map_onnx_to_numpy_type(dst)\n                            val = np.asarray(*inputs, dtype=np_type)\n                        elif op.type == \"ConcatV2\":\n                            axis = inputs[-1]\n                            values = inputs[:-1]\n                            val = func(tuple(values), axis)\n                        elif op.type == \"ListDiff\":\n                            out_type = op.get_attr_int(\"out_idx\")\n                            np_type = tf2onnx.utils.map_onnx_to_numpy_type(\n                                out_type\n                            )\n                            val = func(*inputs)\n                            val = val.astype(np_type)\n                        elif op.type in [\"Pack\"]:\n                            # handle ops that need input array and axis\n                            axis = op.get_attr_int(\"axis\")\n                            val = func(inputs, axis=axis)\n                        elif op.type == \"Range\":\n                            dtype = op.get_attr_int(\"Tidx\")\n                            np_type = tf2onnx.utils.map_onnx_to_numpy_type(\n                                dtype\n                            )\n                            val = func(*inputs, dtype=np_type)\n                        else:\n                            val = func(*inputs)\n\n       
                 new_node_name = tf2onnx.utils.make_name(op.name)\n                        new_output_name = new_node_name\n                        old_output_name = op.output[0]\n                        old_node_name = op.name\n                        logger.debug(\n                            \"create const node [%s] replacing [%s]\",\n                            new_node_name,\n                            old_node_name,\n                        )\n                        ops[idx] = g.make_const(new_node_name, val)\n\n                        logger.debug(\n                            \"replace old output [%s] with new output [%s]\",\n                            old_output_name,\n                            new_output_name,\n                        )\n                        # need to re-write the consumers input name to use the\n                        # const name\n                        consumers = g.find_output_consumers(old_output_name)\n                        if consumers:\n                            for consumer in consumers:\n                                g.replace_input(\n                                    consumer, old_output_name, new_output_name\n                                )\n\n                        # keep looking until there is nothing we can fold.\n                        # We keep the graph in topological order so if we\n                        # folded, the result might help a following op.\n                        keep_looking = True\n                except Exception as ex:\n                    tb = traceback.format_exc()\n                    logger.info(\"exception: %s, details: %s\", ex, tb)\n                    # ignore errors\n\n        return ops\n\n    def patched_get_value_attr(self, external_tensor_storage=None):\n        \"\"\"\n        Return onnx attr for value property of node.\n        Attr is modified to point to external tensor data stored in\n        external_tensor_storage, if included.\n        \"\"\"\n        a = 
self._attr[\"value\"]\n        if (\n            external_tensor_storage is not None\n            and self in external_tensor_storage.node_to_modified_value_attr\n        ):\n            return external_tensor_storage.node_to_modified_value_attr[self]\n        if external_tensor_storage is None or a.type != AttributeProto.TENSOR:\n            return a\n\n        def prod(x):\n            if hasattr(np, \"product\"):\n                return np.product(x)\n            else:\n                return np.prod(x)\n\n        if (\n            prod(a.t.dims)\n            > external_tensor_storage.external_tensor_size_threshold\n        ):\n            a = copy.deepcopy(a)\n            tensor_name = (\n                f\"{self.name.strip()}_{external_tensor_storage.name_counter}\"\n            )\n            for c in '~\"#%&*:<>?/\\\\{|}':\n                tensor_name = tensor_name.replace(c, \"_\")\n            external_tensor_storage.name_counter += 1\n            external_tensor_storage.name_to_tensor_data[tensor_name] = (\n                a.t.raw_data\n            )\n            external_tensor_storage.node_to_modified_value_attr[self] = a\n            a.t.raw_data = b\"\"\n            a.t.ClearField(\"raw_data\")\n            location = a.t.external_data.add()\n            location.key = \"location\"\n            location.value = tensor_name\n            a.t.data_location = TensorProto.EXTERNAL\n        return a\n\n    tf2onnx.tfonnx.rewrite_constant_fold = patched_rewrite_constant_fold\n    tf2onnx.graph.Node.get_value_attr = patched_get_value_attr\n"
  },
  {
    "path": "keras/src/export/tfsm_layer.py",
    "content": "from keras.src import backend\nfrom keras.src import layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.export.saved_model import _list_variables_used_by_fns\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\n@keras_export(\"keras.layers.TFSMLayer\")\nclass TFSMLayer(layers.Layer):\n    \"\"\"Reload a Keras model/layer that was saved via SavedModel / ExportArchive.\n\n    Arguments:\n        filepath: `str` or `pathlib.Path` object. The path to the SavedModel.\n        call_endpoint: Name of the endpoint to use as the `call()` method\n            of the reloaded layer. If the SavedModel was created\n            via `model.export()`,\n            then the default endpoint name is `'serve'`. In other cases\n            it may be named `'serving_default'`.\n\n    Example:\n\n    ```python\n    model.export(\"path/to/artifact\")\n    reloaded_layer = TFSMLayer(\"path/to/artifact\")\n    outputs = reloaded_layer(inputs)\n    ```\n\n    The reloaded object can be used like a regular Keras layer, and supports\n    training/fine-tuning of its trainable weights. Note that the reloaded\n    object retains none of the internal structure or custom methods of the\n    original object -- it's a brand new layer created around the saved\n    function.\n\n    **Limitations:**\n\n    * Only call endpoints with a single `inputs` tensor argument\n    (which may optionally be a dict/tuple/list of tensors) are supported.\n    For endpoints with multiple separate input tensor arguments, consider\n    subclassing `TFSMLayer` and implementing a `call()` method with a\n    custom signature.\n    * If you need training-time behavior to differ from inference-time behavior\n    (i.e. 
if you need the reloaded object to support a `training=True` argument\n    in `__call__()`), make sure that the training-time call function is\n    saved as a standalone endpoint in the artifact, and provide its name\n    to the `TFSMLayer` via the `call_training_endpoint` argument.\n    \"\"\"\n\n    def __init__(\n        self,\n        filepath,\n        call_endpoint=\"serve\",\n        call_training_endpoint=None,\n        trainable=True,\n        name=None,\n        dtype=None,\n    ):\n        if backend.backend() != \"tensorflow\":\n            raise NotImplementedError(\n                \"The TFSMLayer is only currently supported with the \"\n                \"TensorFlow backend.\"\n            )\n\n        # Initialize an empty layer, then add_weight() etc. as needed.\n        super().__init__(trainable=trainable, name=name, dtype=dtype)\n\n        self._reloaded_obj = tf.saved_model.load(filepath)\n\n        self.filepath = filepath\n        self.call_endpoint = call_endpoint\n        self.call_training_endpoint = call_training_endpoint\n\n        # Resolve the call function.\n        if hasattr(self._reloaded_obj, call_endpoint):\n            # Case 1: it's set as an attribute.\n            self.call_endpoint_fn = getattr(self._reloaded_obj, call_endpoint)\n        elif call_endpoint in self._reloaded_obj.signatures:\n            # Case 2: it's listed in the `signatures` field.\n            self.call_endpoint_fn = self._reloaded_obj.signatures[call_endpoint]\n        else:\n            raise ValueError(\n                f\"The endpoint '{call_endpoint}' \"\n                \"is neither an attribute of the reloaded SavedModel, \"\n                \"nor an entry in the `signatures` field of \"\n                \"the reloaded SavedModel. Select another endpoint via \"\n                \"the `call_endpoint` argument. 
Available endpoints for \"\n                \"this SavedModel: \"\n                f\"{list(self._reloaded_obj.signatures.keys())}\"\n            )\n\n        # Resolving the training function.\n        if call_training_endpoint:\n            if hasattr(self._reloaded_obj, call_training_endpoint):\n                self.call_training_endpoint_fn = getattr(\n                    self._reloaded_obj, call_training_endpoint\n                )\n            elif call_training_endpoint in self._reloaded_obj.signatures:\n                self.call_training_endpoint_fn = self._reloaded_obj.signatures[\n                    call_training_endpoint\n                ]\n            else:\n                raise ValueError(\n                    f\"The endpoint '{call_training_endpoint}' \"\n                    \"is neither an attribute of the reloaded SavedModel, \"\n                    \"nor an entry in the `signatures` field of \"\n                    \"the reloaded SavedModel. Available endpoints for \"\n                    \"this SavedModel: \"\n                    f\"{list(self._reloaded_obj.signatures.keys())}\"\n                )\n\n        # Add trainable and non-trainable weights from the call_endpoint_fn.\n        all_fns = [self.call_endpoint_fn]\n        if call_training_endpoint:\n            all_fns.append(self.call_training_endpoint_fn)\n        tvs, ntvs = _list_variables_used_by_fns(all_fns)\n        for v in tvs:\n            self._add_existing_weight(v)\n        for v in ntvs:\n            self._add_existing_weight(v)\n\n        self._build_at_init()\n\n    def _add_existing_weight(self, weight):\n        \"\"\"Tracks an existing weight.\"\"\"\n        variable = backend.Variable(\n            initializer=weight,\n            trainable=weight.trainable,\n            dtype=weight.dtype,\n            shape=weight.shape,\n            # Keras variable names cannot contain slashes.\n            name=weight.name.replace(\"/\", \"_\"),\n        )\n        
self._track_variable(variable)\n\n    def call(self, inputs, training=False, **kwargs):\n        if training:\n            if self.call_training_endpoint:\n                return self.call_training_endpoint_fn(inputs, **kwargs)\n        return self.call_endpoint_fn(inputs, **kwargs)\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            # Note: this is not intended to be portable.\n            \"filepath\": self.filepath,\n            \"call_endpoint\": self.call_endpoint,\n            \"call_training_endpoint\": self.call_training_endpoint,\n        }\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config, custom_objects=None, safe_mode=None):\n        \"\"\"Creates a TFSMLayer from its config.\n        Args:\n            config: A Python dictionary, typically the output of `get_config`.\n            custom_objects: Optional dictionary mapping names to custom objects.\n            safe_mode: Boolean, whether to disallow loading TFSMLayer.\n                When `safe_mode=True`, loading is disallowed because TFSMLayer\n                loads external SavedModels that may contain attacker-controlled\n                executable graph code. Defaults to `True`.\n        Returns:\n            A TFSMLayer instance.\n        \"\"\"\n        # Follow the same pattern as Lambda layer for safe_mode handling\n        effective_safe_mode = (\n            safe_mode\n            if safe_mode is not None\n            else serialization_lib.in_safe_mode()\n        )\n\n        if effective_safe_mode is not False:\n            raise ValueError(\n                \"Requested the deserialization of a `TFSMLayer`, which \"\n                \"loads an external SavedModel. This carries a potential risk \"\n                \"of arbitrary code execution and thus it is disallowed by \"\n                \"default. 
If you trust the source of the artifact, you can \"\n                \"override this error by passing `safe_mode=False` to the \"\n                \"loading function, or calling \"\n                \"`keras.config.enable_unsafe_deserialization()`.\"\n            )\n\n        return cls(**config)\n"
  },
  {
    "path": "keras/src/export/tfsm_layer_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nimport tensorflow as tf\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src import utils\nfrom keras.src.export import saved_model\nfrom keras.src.export import tfsm_layer\nfrom keras.src.export.saved_model_test import get_model\nfrom keras.src.saving import saving_lib\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"tensorflow\",\n    reason=\"TFSM Layer reloading is only for the TF backend.\",\n)\nclass TestTFSMLayer(testing.TestCase):\n    def test_reloading_export_archive(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model()\n        ref_input = tf.random.normal((3, 10))\n        ref_output = model(ref_input)\n\n        saved_model.export_saved_model(model, temp_filepath)\n        reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath)\n        self.assertAllClose(reloaded_layer(ref_input), ref_output, atol=1e-7)\n        self.assertLen(reloaded_layer.weights, len(model.weights))\n        self.assertLen(\n            reloaded_layer.trainable_weights, len(model.trainable_weights)\n        )\n        self.assertLen(\n            reloaded_layer.non_trainable_weights,\n            len(model.non_trainable_weights),\n        )\n\n    def test_reloading_default_saved_model(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model()\n        ref_input = tf.random.normal((3, 10))\n        ref_output = model(ref_input)\n\n        tf.saved_model.save(model, temp_filepath)\n        reloaded_layer = tfsm_layer.TFSMLayer(\n            temp_filepath, call_endpoint=\"serving_default\"\n        )\n        # The output is a dict, due to the nature of SavedModel saving.\n        new_output = reloaded_layer(ref_input)\n        self.assertAllClose(\n            new_output[list(new_output.keys())[0]],\n            
ref_output,\n            atol=1e-7,\n        )\n        self.assertLen(reloaded_layer.weights, len(model.weights))\n        self.assertLen(\n            reloaded_layer.trainable_weights, len(model.trainable_weights)\n        )\n        self.assertLen(\n            reloaded_layer.non_trainable_weights,\n            len(model.non_trainable_weights),\n        )\n        for keras_var in reloaded_layer.weights:\n            self.assertIsInstance(keras_var, backend.Variable)\n\n    def test_call_training(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        utils.set_random_seed(1337)\n        model = models.Sequential(\n            [\n                layers.Input((10,)),\n                layers.Dense(10),\n                layers.Dropout(0.99999),\n            ]\n        )\n        export_archive = saved_model.ExportArchive()\n        export_archive.track(model)\n        export_archive.add_endpoint(\n            name=\"call_inference\",\n            fn=lambda x: model(x, training=False),\n            input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n        )\n        export_archive.add_endpoint(\n            name=\"call_training\",\n            fn=lambda x: model(x, training=True),\n            input_signature=[tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],\n        )\n        export_archive.write_out(temp_filepath)\n        reloaded_layer = tfsm_layer.TFSMLayer(\n            temp_filepath,\n            call_endpoint=\"call_inference\",\n            call_training_endpoint=\"call_training\",\n        )\n        inference_output = reloaded_layer(\n            tf.random.normal((1, 10)), training=False\n        )\n        training_output = reloaded_layer(\n            tf.random.normal((1, 10)), training=True\n        )\n        self.assertAllClose(np.mean(training_output), 0.0, atol=1e-7)\n        self.assertNotAllClose(np.mean(inference_output), 0.0, atol=1e-7)\n\n    def test_serialization(self):\n        
temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = get_model()\n        ref_input = tf.random.normal((3, 10))\n        ref_output = model(ref_input)\n\n        saved_model.export_saved_model(model, temp_filepath)\n        reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath)\n\n        # Test reinstantiation from config\n        config = reloaded_layer.get_config()\n        rereloaded_layer = tfsm_layer.TFSMLayer.from_config(\n            config, safe_mode=False\n        )\n        self.assertAllClose(rereloaded_layer(ref_input), ref_output, atol=1e-7)\n\n        # Test whole model saving with reloaded layer inside\n        model = models.Sequential([reloaded_layer])\n        temp_model_filepath = os.path.join(self.get_temp_dir(), \"m.keras\")\n        model.save(temp_model_filepath, save_format=\"keras_v3\")\n        reloaded_model = saving_lib.load_model(\n            temp_model_filepath, safe_mode=False\n        )\n        self.assertAllClose(reloaded_model(ref_input), ref_output, atol=1e-7)\n\n    def test_safe_mode_blocks_model_loading(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n\n        # Create and export a model\n        model = get_model()\n        model(tf.random.normal((1, 10)))\n        saved_model.export_saved_model(model, temp_filepath)\n\n        # Wrap SavedModel in TFSMLayer and save as .keras\n        reloaded_layer = tfsm_layer.TFSMLayer(temp_filepath)\n        wrapper_model = models.Sequential([reloaded_layer])\n\n        model_path = os.path.join(self.get_temp_dir(), \"tfsm_model.keras\")\n        wrapper_model.save(model_path)\n\n        # Default safe_mode=True should block loading\n        with self.assertRaisesRegex(\n            ValueError,\n            \"arbitrary code execution\",\n        ):\n            saving_lib.load_model(model_path)\n\n        # Explicit opt-out should allow loading\n        loaded_model = saving_lib.load_model(model_path, 
safe_mode=False)\n\n        x = tf.random.normal((2, 10))\n        self.assertAllClose(loaded_model(x), wrapper_model(x))\n\n    def test_errors(self):\n        # Test missing call endpoint\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = models.Sequential([layers.Input((2,)), layers.Dense(3)])\n        saved_model.export_saved_model(model, temp_filepath)\n        with self.assertRaisesRegex(ValueError, \"The endpoint 'wrong'\"):\n            tfsm_layer.TFSMLayer(temp_filepath, call_endpoint=\"wrong\")\n\n        # Test missing call training endpoint\n        with self.assertRaisesRegex(ValueError, \"The endpoint 'wrong'\"):\n            tfsm_layer.TFSMLayer(\n                temp_filepath,\n                call_endpoint=\"serve\",\n                call_training_endpoint=\"wrong\",\n            )\n"
  },
  {
    "path": "keras/src/initializers/__init__.py",
    "content": "import inspect\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.initializers.constant_initializers import STFT\nfrom keras.src.initializers.constant_initializers import Constant\nfrom keras.src.initializers.constant_initializers import Identity\nfrom keras.src.initializers.constant_initializers import Ones\nfrom keras.src.initializers.constant_initializers import Zeros\nfrom keras.src.initializers.initializer import Initializer\nfrom keras.src.initializers.random_initializers import GlorotNormal\nfrom keras.src.initializers.random_initializers import GlorotUniform\nfrom keras.src.initializers.random_initializers import HeNormal\nfrom keras.src.initializers.random_initializers import HeUniform\nfrom keras.src.initializers.random_initializers import LecunNormal\nfrom keras.src.initializers.random_initializers import LecunUniform\nfrom keras.src.initializers.random_initializers import Orthogonal\nfrom keras.src.initializers.random_initializers import RandomNormal\nfrom keras.src.initializers.random_initializers import RandomUniform\nfrom keras.src.initializers.random_initializers import TruncatedNormal\nfrom keras.src.initializers.random_initializers import VarianceScaling\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils.naming import to_snake_case\n\nALL_OBJECTS = {\n    Initializer,\n    Constant,\n    Identity,\n    Ones,\n    STFT,\n    Zeros,\n    GlorotNormal,\n    GlorotUniform,\n    HeNormal,\n    HeUniform,\n    LecunNormal,\n    LecunUniform,\n    Orthogonal,\n    RandomNormal,\n    RandomUniform,\n    TruncatedNormal,\n    VarianceScaling,\n}\n\nALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}\nALL_OBJECTS_DICT.update(\n    {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS}\n)\n# Aliases\nALL_OBJECTS_DICT.update(\n    {\n        \"IdentityInitializer\": Identity,  # For compatibility\n        \"normal\": 
RandomNormal,\n        \"one\": Ones,\n        \"STFTInitializer\": STFT,  # For compatibility\n        \"OrthogonalInitializer\": Orthogonal,  # For compatibility\n        \"uniform\": RandomUniform,\n        \"zero\": Zeros,\n    }\n)\n\n\n@keras_export(\"keras.initializers.serialize\")\ndef serialize(initializer):\n    \"\"\"Returns the initializer configuration as a Python dict.\"\"\"\n    return serialization_lib.serialize_keras_object(initializer)\n\n\n@keras_export(\"keras.initializers.deserialize\")\ndef deserialize(config, custom_objects=None):\n    \"\"\"Returns a Keras initializer object via its configuration.\"\"\"\n    return serialization_lib.deserialize_keras_object(\n        config,\n        module_objects=ALL_OBJECTS_DICT,\n        custom_objects=custom_objects,\n    )\n\n\n@keras_export(\"keras.initializers.get\")\ndef get(identifier):\n    \"\"\"Retrieves a Keras initializer object via an identifier.\n\n    The `identifier` may be the string name of an initializer function or class\n    (case-sensitive).\n\n    >>> identifier = 'Ones'\n    >>> keras.initializers.get(identifier)\n    <...keras.initializers.initializers.Ones...>\n\n    You can also specify `config` of the initializer to this function by passing\n    a dict containing `class_name` and `config` as an identifier. 
Also note that\n    the `class_name` must map to an `Initializer` class.\n\n    >>> cfg = {'class_name': 'Ones', 'config': {}}\n    >>> keras.initializers.get(cfg)\n    <...keras.initializers.initializers.Ones...>\n\n    In the case that the `identifier` is a class, this method will return a new\n    instance of the class by its constructor.\n\n    You may also pass a callable function with a signature that includes `shape`\n    and `dtype=None` as an identifier.\n\n    >>> fn = lambda shape, dtype=None: ops.ones(shape, dtype)\n    >>> keras.initializers.get(fn)\n    <function <lambda> at ...>\n\n    Alternatively, you can pass a backend tensor or numpy array as the\n    `identifier` to define the initializer values directly. Note that when\n    calling the initializer, the specified `shape` argument must be the same as\n    the shape of the tensor.\n\n    >>> tensor = ops.ones(shape=(5, 5))\n    >>> keras.initializers.get(tensor)\n    <function get.<locals>.initialize_fn at ...>\n\n    Args:\n        identifier: A string, dict, callable function, or tensor specifying\n            the initializer. If a string, it should be the name of an\n            initializer. If a dict, it should contain the configuration of an\n            initializer. 
Callable functions or predefined tensors are also\n            accepted.\n\n    Returns:\n        Initializer instance based on the input identifier.\n    \"\"\"\n    if identifier is None:\n        return None\n    if isinstance(identifier, dict):\n        obj = deserialize(identifier)\n    elif isinstance(identifier, str):\n        config = {\"class_name\": str(identifier), \"config\": {}}\n        obj = deserialize(config)\n    elif ops.is_tensor(identifier) or isinstance(\n        identifier, (np.generic, np.ndarray)\n    ):\n\n        def initialize_fn(shape, dtype=None):\n            dtype = backend.standardize_dtype(dtype)\n            if backend.standardize_shape(shape) != backend.standardize_shape(\n                identifier.shape\n            ):\n                raise ValueError(\n                    f\"Expected `shape` to be {identifier.shape} for direct \"\n                    f\"tensor as initializer. Received shape={shape}\"\n                )\n            return ops.cast(identifier, dtype)\n\n        obj = initialize_fn\n    else:\n        obj = identifier\n\n    if callable(obj):\n        if inspect.isclass(obj):\n            obj = obj()\n        return obj\n    else:\n        raise ValueError(\n            f\"Could not interpret initializer identifier: {identifier}\"\n        )\n"
  },
  {
    "path": "keras/src/initializers/constant_initializers.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import standardize_dtype\nfrom keras.src.initializers.initializer import Initializer\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils.module_utils import scipy\n\n\n@keras_export([\"keras.initializers.Constant\", \"keras.initializers.constant\"])\nclass Constant(Initializer):\n    \"\"\"Initializer that generates tensors with constant values.\n\n    Only scalar values are allowed.\n    The constant value provided must be convertible to the dtype requested\n    when calling the initializer.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = Constant(10.)\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = Constant(10.)\n    >>> layer = Dense(3, kernel_initializer=initializer)\n\n    Args:\n        value: A Python scalar.\n    \"\"\"\n\n    def __init__(self, value=0.0):\n        self.value = value\n\n    def __call__(self, shape, dtype=None):\n        dtype = standardize_dtype(dtype)\n        return ops.cast(self.value, dtype=dtype) * ops.ones(\n            shape=shape, dtype=dtype\n        )\n\n    def get_config(self):\n        return {\"value\": serialization_lib.serialize_keras_object(self.value)}\n\n    @classmethod\n    def from_config(cls, config):\n        value = serialization_lib.deserialize_keras_object(config[\"value\"])\n        return cls(value)\n\n\n@keras_export([\"keras.initializers.Zeros\", \"keras.initializers.zeros\"])\nclass Zeros(Initializer):\n    \"\"\"Initializer that generates tensors initialized to 0.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = Zeros()\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = Zeros()\n    >>> layer = Dense(units=3, kernel_initializer=initializer)\n    \"\"\"\n\n    def __call__(self, shape, dtype=None):\n        \"\"\"Returns a 
tensor object initialized as specified by the initializer.\n\n        Args:\n            shape: Shape of the tensor.\n            dtype: Optional dtype of the tensor. Only numeric or boolean dtypes\n                are supported. If not specified, `keras.backend.floatx()`\n                is used, which defaults to `float32` unless you configured it\n                otherwise (via `keras.backend.set_floatx(float_dtype)`).\n        \"\"\"\n        dtype = standardize_dtype(dtype)\n        return ops.zeros(shape, dtype=dtype)\n\n\n@keras_export([\"keras.initializers.Ones\", \"keras.initializers.ones\"])\nclass Ones(Initializer):\n    \"\"\"Initializer that generates tensors initialized to 1.\n\n    Also available via the shortcut function `ones`.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = Ones()\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = Ones()\n    >>> layer = Dense(3, kernel_initializer=initializer)\n    \"\"\"\n\n    def __call__(self, shape, dtype=None):\n        \"\"\"Returns a tensor object initialized as specified by the initializer.\n\n        Args:\n            shape: Shape of the tensor.\n            dtype: Optional dtype of the tensor. Only numeric or boolean dtypes\n                are supported. 
If not specified, `keras.backend.floatx()`\n                is used, which defaults to `float32` unless you configured it\n                otherwise (via `keras.backend.set_floatx(float_dtype)`).\n        \"\"\"\n        dtype = standardize_dtype(dtype)\n        return ops.ones(shape, dtype=dtype)\n\n\n@keras_export(\n    [\n        \"keras.initializers.Identity\",\n        \"keras.initializers.identity\",\n        \"keras.initializers.IdentityInitializer\",\n    ]\n)\nclass Identity(Initializer):\n    \"\"\"Initializer that generates the identity matrix.\n\n    Only usable for generating 2D matrices.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = Identity()\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = Identity()\n    >>> layer = Dense(3, kernel_initializer=initializer)\n\n    Args:\n        gain: Multiplicative factor to apply to the identity matrix.\n    \"\"\"\n\n    def __init__(self, gain=1.0):\n        self.gain = gain\n\n    def __call__(self, shape, dtype=None):\n        \"\"\"Returns a tensor object initialized as specified by the initializer.\n\n        Args:\n            shape: Shape of the tensor.\n            dtype: Optional dtype of the tensor. Only numeric or boolean dtypes\n                are supported. If not specified, `keras.backend.floatx()`\n                is used, which defaults to `float32` unless you configured it\n                otherwise (via `keras.backend.set_floatx(float_dtype)`).\n        \"\"\"\n        if len(shape) != 2:\n            raise ValueError(\n                \"Identity matrix initializer can only be used for 2D matrices. 
\"\n                f\"Received: shape={shape} of rank {len(shape)}.\"\n            )\n        dtype = standardize_dtype(dtype)\n        return self.gain * ops.eye(*shape, dtype=dtype)\n\n\n@keras_export(\n    [\n        \"keras.initializers.STFT\",\n        \"keras.initializers.stft\",\n        \"keras.initializers.STFTInitializer\",\n    ]\n)\nclass STFT(Initializer):\n    \"\"\"Initializer of Conv kernels for the Short-Time Fourier Transform (STFT).\n\n    Since the formula involves complex numbers, this class computes either the\n    real or the imaginary component of the final output.\n\n    Additionally, this initializer supports windowing functions across the time\n    dimension as commonly used in STFT. Windowing functions from the module\n    `scipy.signal.windows` are supported, including the common `hann` and\n    `hamming` windowing functions. It also supports periodic windows and\n    scaling-based normalization.\n\n    This is primarily intended for use in the `STFTSpectrogram` layer.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = STFTInitializer(\"real\", \"hann\", \"density\", False)\n    >>> values = initializer(shape=(128, 1, 513))\n\n    Args:\n        side: String, `\"real\"` or `\"imag\"` deciding whether the kernel\n            computes the real or the imaginary part of the output. Defaults to\n            `\"real\"`.\n        window: String for the name of the windowing function in the\n            `scipy.signal.windows` module, or array_like for the window values,\n            or `None` for no windowing.\n        scaling: String, `\"density\"` for L2 normalization of the window,\n            `\"spectrum\"` for L1 normalization, or `None` for no scaling.\n        periodic: Boolean, if True, the window function will be treated as\n            periodic. 
Defaults to `False`.\n    \"\"\"\n\n    def __init__(\n        self, side=\"real\", window=\"hann\", scaling=\"density\", periodic=False\n    ):\n        if side not in [\"real\", \"imag\"]:\n            raise ValueError(f\"side should be 'real' or 'imag', not {side}\")\n        if isinstance(window, str):\n            # throws an exception for invalid window function\n            scipy.signal.get_window(window, 1)\n        if scaling is not None and scaling not in [\"density\", \"spectrum\"]:\n            raise ValueError(\n                \"Scaling is invalid, it must be `None`, 'density' \"\n                f\"or 'spectrum'. Received scaling={scaling}\"\n            )\n        self.side = side\n        self.window = window\n        self.scaling = scaling\n        self.periodic = periodic\n\n    def __call__(self, shape, dtype=None):\n        \"\"\"Returns a tensor object initialized as specified by the initializer.\n\n        The shape is assumed to be `(T, 1, F // 2 + 1)`, where `T` is the size\n        of the given window, and `F` is the number of frequency bands. Only half\n        the frequency bands are used, which is a common practice in STFT,\n        because the second half are the conjugates of the first half in\n        a reversed order.\n\n        Args:\n            shape: Shape of the tensor.\n            dtype: Optional dtype of the tensor. Only numeric or boolean dtypes\n                are supported. 
If not specified, `keras.backend.floatx()`\n                is used, which defaults to `float32` unless you configured it\n                otherwise (via `keras.backend.set_floatx(float_dtype)`).\n        \"\"\"\n        dtype = standardize_dtype(dtype)\n        frame_length, input_channels, fft_length = shape\n\n        win = None\n        scaling = 1\n        if self.window is not None:\n            win = self.window\n            if isinstance(win, str):\n                # Use SciPy since it provides more windowing functions and\n                # is easy to keep compatible with multiple backends.\n                win = scipy.signal.get_window(win, frame_length, self.periodic)\n            win = ops.convert_to_tensor(win, dtype=dtype)\n            if len(win.shape) != 1 or win.shape[-1] != frame_length:\n                raise ValueError(\n                    \"The shape of `window` must be equal to [frame_length]. \"\n                    f\"Received: window shape={win.shape}\"\n                )\n            win = ops.reshape(win, [frame_length, 1, 1])\n            if self.scaling == \"density\":\n                scaling = ops.sqrt(ops.sum(ops.square(win)))\n            elif self.scaling == \"spectrum\":\n                scaling = ops.sum(ops.abs(win))\n\n        _fft_length = (fft_length - 1) * 2\n        freq = ops.divide(\n            ops.reshape(\n                ops.arange(fft_length, dtype=dtype), (1, 1, fft_length)\n            ),\n            _fft_length,\n        )\n        time = ops.reshape(\n            ops.arange(frame_length, dtype=dtype), (frame_length, 1, 1)\n        )\n        args = ops.multiply(ops.multiply(-2, time), freq) * ops.arccos(\n            ops.cast(-1, dtype)\n        )\n\n        if self.side == \"real\":\n            kernel = ops.cast(ops.cos(args), dtype)\n        else:\n            kernel = ops.cast(ops.sin(args), dtype)\n\n        if win is not None:\n            kernel = ops.divide(ops.multiply(kernel, win), scaling)\n        return 
kernel\n\n    def get_config(self):\n        return {\n            \"side\": self.side,\n            \"window\": self.window,\n            \"periodic\": self.periodic,\n            \"scaling\": self.scaling,\n        }\n"
  },
  {
    "path": "keras/src/initializers/constant_initializers_test.py",
    "content": "import numpy as np\nimport scipy.signal\n\nfrom conftest import skip_if_backend\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import testing\n\n\nclass ConstantInitializersTest(testing.TestCase):\n    def test_zeros_initializer(self):\n        shape = (3, 3)\n\n        initializer = initializers.Zeros()\n        values = initializer(shape=shape)\n        self.assertEqual(values.shape, shape)\n        np_values = backend.convert_to_numpy(values)\n        self.assertAllClose(np_values, np.zeros(shape=shape))\n\n        self.run_class_serialization_test(initializer)\n\n    def test_ones_initializer(self):\n        shape = (3, 3)\n\n        initializer = initializers.Ones()\n        values = initializer(shape=shape)\n        self.assertEqual(values.shape, shape)\n        np_values = backend.convert_to_numpy(values)\n        self.assertAllClose(np_values, np.ones(shape=shape))\n\n        self.run_class_serialization_test(initializer)\n\n    def test_constant_initializer(self):\n        shape = (3, 3)\n        constant_value = 6.0\n\n        initializer = initializers.Constant(value=constant_value)\n        values = initializer(shape=shape)\n        self.assertEqual(values.shape, shape)\n        np_values = backend.convert_to_numpy(values)\n        self.assertAllClose(\n            np_values, np.full(shape=shape, fill_value=constant_value)\n        )\n\n        self.run_class_serialization_test(initializer)\n\n    def test_constant_initializer_array_value(self):\n        shape = (3, 3)\n        constant_value = np.random.random((3, 3))\n\n        initializer = initializers.Constant(value=constant_value)\n        values = initializer(shape=shape)\n        self.assertEqual(values.shape, shape)\n        np_values = backend.convert_to_numpy(values)\n        self.assertAllClose(\n            np_values, np.full(shape=shape, fill_value=constant_value)\n        )\n\n        self.run_class_serialization_test(initializer)\n\n    
@skip_if_backend(\"openvino\", \"openvino backend does not support `eye`\")\n    def test_identity_initializer(self):\n        shape = (3, 3)\n        gain = 2\n\n        initializer = initializers.Identity(gain=gain)\n        values = initializer(shape=shape)\n        self.assertEqual(values.shape, shape)\n        np_values = backend.convert_to_numpy(values)\n        self.assertAllClose(np_values, np.eye(*shape) * gain)\n\n        self.run_class_serialization_test(initializer)\n\n        # Test compatible class_name\n        initializer = initializers.get(\"IdentityInitializer\")\n        self.assertIsInstance(initializer, initializers.Identity)\n\n    @skip_if_backend(\"openvino\", \"openvino backend does not support `arange`\")\n    def test_stft_initializer(self):\n        shape = (256, 1, 513)\n        time_range = np.arange(256).reshape((-1, 1, 1))\n        freq_range = (np.arange(513) / 1024.0).reshape((1, 1, -1))\n        pi = np.arccos(np.float32(-1))\n        args = -2 * pi * time_range * freq_range\n        tol_kwargs = {\"atol\": 1e-4, \"rtol\": 1e-6}\n\n        initializer = initializers.STFT(\"real\", None)\n        values = backend.convert_to_numpy(initializer(shape))\n        self.assertAllClose(np.cos(args), values, atol=1e-4)\n        self.run_class_serialization_test(initializer)\n\n        initializer = initializers.STFT(\n            \"real\",\n            \"hamming\",\n            None,\n            True,\n        )\n        window = scipy.signal.windows.get_window(\"hamming\", 256, True)\n        window = window.astype(\"float32\").reshape((-1, 1, 1))\n        values = backend.convert_to_numpy(initializer(shape, \"float32\"))\n        self.assertAllClose(np.cos(args) * window, values, **tol_kwargs)\n        self.run_class_serialization_test(initializer)\n\n        initializer = initializers.STFT(\n            \"imag\",\n            \"tukey\",\n            \"density\",\n            False,\n        )\n        window = 
scipy.signal.windows.get_window(\"tukey\", 256, False)\n        window = window.astype(\"float32\").reshape((-1, 1, 1))\n        window = window / np.sqrt(np.sum(window**2))\n        values = backend.convert_to_numpy(initializer(shape, \"float32\"))\n        self.assertAllClose(np.sin(args) * window, values, **tol_kwargs)\n        self.run_class_serialization_test(initializer)\n\n        initializer = initializers.STFT(\n            \"imag\",\n            list(range(1, 257)),\n            \"spectrum\",\n        )\n        window = np.arange(1, 257)\n        window = window.astype(\"float32\").reshape((-1, 1, 1))\n        window = window / np.sum(window)\n        values = backend.convert_to_numpy(initializer(shape, \"float32\"))\n        self.assertAllClose(np.sin(args) * window, values, **tol_kwargs)\n        self.run_class_serialization_test(initializer)\n\n        with self.assertRaises(ValueError):\n            initializers.STFT(\"imaginary\")\n        with self.assertRaises(ValueError):\n            initializers.STFT(\"real\", scaling=\"l2\")\n        with self.assertRaises(ValueError):\n            initializers.STFT(\"real\", window=\"unknown\")\n\n        # Test compatible class_name\n        initializer = initializers.get(\"STFTInitializer\")\n        self.assertIsInstance(initializer, initializers.STFT)\n"
  },
  {
    "path": "keras/src/initializers/initializer.py",
    "content": "from keras.src.api_export import keras_export\n\n\n@keras_export([\"keras.Initializer\", \"keras.initializers.Initializer\"])\nclass Initializer:\n    \"\"\"Initializer base class: all Keras initializers inherit from this class.\n\n    Initializers should implement a `__call__()` method with the following\n    signature:\n\n    ```python\n    def __call__(self, shape, dtype=None, **kwargs):\n        # returns a tensor of shape `shape` and dtype `dtype`\n        # containing values drawn from a distribution of your choice.\n    ```\n\n    Optionally, you can also implement the method `get_config()` and the class\n    method `from_config()` in order to support serialization, just like with\n    any Keras object.\n\n    Here's a simple example: a random normal initializer.\n\n    ```python\n    class ExampleRandomNormal(Initializer):\n        def __init__(self, mean, stddev):\n            self.mean = mean\n            self.stddev = stddev\n\n        def __call__(self, shape, dtype=None, **kwargs):\n            return keras.random.normal(\n                shape, mean=self.mean, stddev=self.stddev, dtype=dtype\n            )\n\n        def get_config(self):  # To support serialization\n            return {\"mean\": self.mean, \"stddev\": self.stddev}\n    ```\n\n    Note that we don't have to implement `from_config()` in the example above\n    since the constructor arguments of the class and the keys in the config\n    returned by `get_config()` are the same. 
In this case, the default `from_config()`\n    works fine.\n    \"\"\"\n\n    def __call__(self, shape, dtype=None):\n        \"\"\"Returns a tensor object initialized as specified by the initializer.\n\n        Args:\n            shape: Shape of the tensor.\n            dtype: Optional dtype of the tensor.\n        \"\"\"\n        raise NotImplementedError(\n            \"Initializer subclasses must implement the `__call__()` method.\"\n        )\n\n    def get_config(self):\n        \"\"\"Returns the initializer's configuration as a JSON-serializable dict.\n\n        Returns:\n            A JSON-serializable Python dict.\n        \"\"\"\n        return {}\n\n    @classmethod\n    def from_config(cls, config):\n        \"\"\"Instantiates an initializer from a configuration dictionary.\n\n        Example:\n\n        ```python\n        initializer = RandomUniform(-1, 1)\n        config = initializer.get_config()\n        initializer = RandomUniform.from_config(config)\n        ```\n\n        Args:\n            config: A Python dictionary, the output of `get_config()`.\n\n        Returns:\n            An `Initializer` instance.\n        \"\"\"\n        return cls(**config)\n\n    def clone(self):\n        return self.__class__.from_config(self.get_config())\n"
  },
  {
    "path": "keras/src/initializers/random_initializers.py",
    "content": "import math\n\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import random\nfrom keras.src.initializers.initializer import Initializer\nfrom keras.src.saving import serialization_lib\n\n\nclass RandomInitializer(Initializer):\n    def __init__(self, seed=None):\n        self._init_seed = seed\n        if seed is None:\n            seed = random.make_default_seed()\n        elif isinstance(seed, dict):\n            seed = serialization_lib.deserialize_keras_object(seed)\n        elif not isinstance(seed, (int, random.SeedGenerator)):\n            raise ValueError(\n                \"`seed` argument should be an instance of \"\n                \"`keras.random.SeedGenerator()` or an integer. \"\n                f\"Received: seed={seed}\"\n            )\n        self.seed = seed\n\n    def get_config(self):\n        seed_config = serialization_lib.serialize_keras_object(self._init_seed)\n        return {\"seed\": seed_config}\n\n\n@keras_export(\n    [\n        \"keras.initializers.RandomNormal\",\n        \"keras.initializers.random_normal\",\n    ]\n)\nclass RandomNormal(RandomInitializer):\n    \"\"\"Random normal initializer.\n\n    Draws samples from a normal distribution for given parameters.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = RandomNormal(mean=0.0, stddev=1.0)\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = RandomNormal(mean=0.0, stddev=1.0)\n    >>> layer = Dense(3, kernel_initializer=initializer)\n\n    Args:\n        mean: A python scalar or a scalar keras tensor. Mean of the random\n            values to generate.\n        stddev: A python scalar or a scalar keras tensor. 
Standard deviation of\n           the random values to generate.\n        seed: A Python integer or instance of\n            `keras.backend.SeedGenerator`.\n            Used to make the behavior of the initializer\n            deterministic. Note that an initializer seeded with an integer\n            or `None` (unseeded) will produce the same random values\n            across multiple calls. To get different random values\n            across multiple calls, use as seed an instance\n            of `keras.backend.SeedGenerator`.\n    \"\"\"\n\n    def __init__(self, mean=0.0, stddev=0.05, seed=None):\n        self.mean = mean\n        self.stddev = stddev\n        super().__init__(seed=seed)\n\n    def __call__(self, shape, dtype=None):\n        return random.normal(\n            shape=shape,\n            mean=self.mean,\n            stddev=self.stddev,\n            seed=self.seed,\n            dtype=dtype,\n        )\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\"mean\": self.mean, \"stddev\": self.stddev}\n        return {**base_config, **config}\n\n\n@keras_export(\n    [\n        \"keras.initializers.TruncatedNormal\",\n        \"keras.initializers.truncated_normal\",\n    ]\n)\nclass TruncatedNormal(RandomInitializer):\n    \"\"\"Initializer that generates a truncated normal distribution.\n\n    The values generated are similar to values from a\n    `RandomNormal` initializer, except that values more\n    than two standard deviations from the mean are\n    discarded and re-drawn.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = TruncatedNormal(mean=0., stddev=1.)\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = TruncatedNormal(mean=0., stddev=1.)\n    >>> layer = Dense(3, kernel_initializer=initializer)\n\n    Args:\n        mean: A python scalar or a scalar keras tensor. 
Mean of the random\n            values to generate.\n        stddev: A python scalar or a scalar keras tensor. Standard deviation of\n           the random values to generate.\n        seed: A Python integer or instance of\n            `keras.backend.SeedGenerator`.\n            Used to make the behavior of the initializer\n            deterministic. Note that an initializer seeded with an integer\n            or `None` (unseeded) will produce the same random values\n            across multiple calls. To get different random values\n            across multiple calls, use as seed an instance\n            of `keras.backend.SeedGenerator`.\n    \"\"\"\n\n    def __init__(self, mean=0.0, stddev=0.05, seed=None):\n        self.mean = mean\n        self.stddev = stddev\n        super().__init__(seed=seed)\n\n    def __call__(self, shape, dtype=None):\n        return random.truncated_normal(\n            shape=shape,\n            mean=self.mean,\n            stddev=self.stddev,\n            seed=self.seed,\n            dtype=dtype,\n        )\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\"mean\": self.mean, \"stddev\": self.stddev}\n        return {**base_config, **config}\n\n\n@keras_export(\n    [\n        \"keras.initializers.RandomUniform\",\n        \"keras.initializers.random_uniform\",\n    ]\n)\nclass RandomUniform(RandomInitializer):\n    \"\"\"Random uniform initializer.\n\n    Draws samples from a uniform distribution for given parameters.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = RandomUniform(minval=0.0, maxval=1.0)\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = RandomUniform(minval=0.0, maxval=1.0)\n    >>> layer = Dense(3, kernel_initializer=initializer)\n\n    Args:\n        minval: A python scalar or a scalar keras tensor. 
Lower bound of the\n            range of random values to generate (inclusive).\n        maxval: A python scalar or a scalar keras tensor. Upper bound of the\n            range of random values to generate (exclusive).\n        seed: A Python integer or instance of\n            `keras.backend.SeedGenerator`.\n            Used to make the behavior of the initializer\n            deterministic. Note that an initializer seeded with an integer\n            or `None` (unseeded) will produce the same random values\n            across multiple calls. To get different random values\n            across multiple calls, use as seed an instance\n            of `keras.backend.SeedGenerator`.\n    \"\"\"\n\n    def __init__(self, minval=-0.05, maxval=0.05, seed=None):\n        self.minval = minval\n        self.maxval = maxval\n        super().__init__(seed=seed)\n\n    def __call__(self, shape, dtype=None):\n        return random.uniform(\n            shape=shape,\n            minval=self.minval,\n            maxval=self.maxval,\n            seed=self.seed,\n            dtype=dtype,\n        )\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\"minval\": self.minval, \"maxval\": self.maxval}\n        return {**base_config, **config}\n\n\n@keras_export(\n    [\n        \"keras.initializers.VarianceScaling\",\n        \"keras.initializers.variance_scaling\",\n    ]\n)\nclass VarianceScaling(RandomInitializer):\n    \"\"\"Initializer that adapts its scale to the shape of its input tensors.\n\n    With `distribution=\"truncated_normal\" or \"untruncated_normal\"`, samples are\n    drawn from a truncated/untruncated normal distribution with a mean of zero\n    and a standard deviation (after truncation, if used) `stddev = sqrt(scale /\n    n)`, where `n` is:\n\n    - number of input units in the weight tensor, if `mode=\"fan_in\"`\n    - number of output units, if `mode=\"fan_out\"`\n    - average of the numbers of input and output units, if 
`mode=\"fan_avg\"`\n\n    With `distribution=\"uniform\"`, samples are drawn from a uniform distribution\n    within `[-limit, limit]`, where `limit = sqrt(3 * scale / n)`.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = VarianceScaling(\n        scale=0.1, mode='fan_in', distribution='uniform')\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = VarianceScaling(\n        scale=0.1, mode='fan_in', distribution='uniform')\n    >>> layer = Dense(3, kernel_initializer=initializer)\n\n    Args:\n        scale: Scaling factor (positive float).\n        mode: One of `\"fan_in\"`, `\"fan_out\"`, `\"fan_avg\"`.\n        distribution: Random distribution to use.\n            One of `\"truncated_normal\"`, `\"untruncated_normal\"`, or `\"uniform\"`.\n        seed: A Python integer or instance of\n            `keras.backend.SeedGenerator`.\n            Used to make the behavior of the initializer\n            deterministic. Note that an initializer seeded with an integer\n            or `None` (unseeded) will produce the same random values\n            across multiple calls. To get different random values\n            across multiple calls, use as seed an instance\n            of `keras.backend.SeedGenerator`.\n    \"\"\"\n\n    def __init__(\n        self,\n        scale=1.0,\n        mode=\"fan_in\",\n        distribution=\"truncated_normal\",\n        seed=None,\n    ):\n        if scale <= 0.0:\n            raise ValueError(\n                \"Argument `scale` must be a positive float. \"\n                f\"Received: scale={scale}\"\n            )\n        allowed_modes = {\"fan_in\", \"fan_out\", \"fan_avg\"}\n        if mode not in allowed_modes:\n            raise ValueError(\n                f\"Invalid `mode` argument: {mode}. 
\"\n                f\"Please use one of {allowed_modes}\"\n            )\n        distribution = distribution.lower()\n        if distribution == \"normal\":\n            distribution = \"truncated_normal\"\n        allowed_distributions = {\n            \"uniform\",\n            \"truncated_normal\",\n            \"untruncated_normal\",\n        }\n        if distribution not in allowed_distributions:\n            raise ValueError(\n                f\"Invalid `distribution` argument: {distribution}. \"\n                f\"Please use one of {allowed_distributions}\"\n            )\n        self.scale = scale\n        self.mode = mode\n        self.distribution = distribution\n        super().__init__(seed=seed)\n\n    def __call__(self, shape, dtype=None):\n        scale = self.scale\n        fan_in, fan_out = compute_fans(shape)\n        if self.mode == \"fan_in\":\n            scale /= max(1.0, fan_in)\n        elif self.mode == \"fan_out\":\n            scale /= max(1.0, fan_out)\n        else:\n            scale /= max(1.0, (fan_in + fan_out) / 2.0)\n        if self.distribution == \"truncated_normal\":\n            stddev = math.sqrt(scale) / 0.87962566103423978\n            return random.truncated_normal(\n                shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed\n            )\n        elif self.distribution == \"untruncated_normal\":\n            stddev = math.sqrt(scale)\n            return random.normal(\n                shape, mean=0.0, stddev=stddev, dtype=dtype, seed=self.seed\n            )\n        else:\n            limit = math.sqrt(3.0 * scale)\n            return random.uniform(\n                shape, minval=-limit, maxval=limit, dtype=dtype, seed=self.seed\n            )\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"scale\": self.scale,\n            \"mode\": self.mode,\n            \"distribution\": self.distribution,\n        }\n        return {**base_config, 
**config}\n\n\n@keras_export(\n    [\n        \"keras.initializers.GlorotUniform\",\n        \"keras.initializers.glorot_uniform\",\n    ]\n)\nclass GlorotUniform(VarianceScaling):\n    \"\"\"The Glorot uniform initializer, also called Xavier uniform initializer.\n\n    Draws samples from a uniform distribution within `[-limit, limit]`, where\n    `limit = sqrt(6 / (fan_in + fan_out))` (`fan_in` is the number of input\n    units in the weight tensor and `fan_out` is the number of output units).\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = GlorotUniform()\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = GlorotUniform()\n    >>> layer = Dense(3, kernel_initializer=initializer)\n\n    Args:\n        seed: A Python integer or instance of\n            `keras.backend.SeedGenerator`.\n            Used to make the behavior of the initializer\n            deterministic. Note that an initializer seeded with an integer\n            or `None` (unseeded) will produce the same random values\n            across multiple calls. 
To get different random values\n            across multiple calls, use as seed an instance\n            of `keras.backend.SeedGenerator`.\n\n    Reference:\n\n    - [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)\n    \"\"\"\n\n    def __init__(self, seed=None):\n        super().__init__(\n            scale=1.0, mode=\"fan_avg\", distribution=\"uniform\", seed=seed\n        )\n\n    def get_config(self):\n        return {\n            \"seed\": serialization_lib.serialize_keras_object(self._init_seed)\n        }\n\n\n@keras_export(\n    [\n        \"keras.initializers.GlorotNormal\",\n        \"keras.initializers.glorot_normal\",\n    ]\n)\nclass GlorotNormal(VarianceScaling):\n    \"\"\"The Glorot normal initializer, also called Xavier normal initializer.\n\n    Draws samples from a truncated normal distribution centered on 0 with\n    `stddev = sqrt(2 / (fan_in + fan_out))` where `fan_in` is the number of\n    input units in the weight tensor and `fan_out` is the number of output units\n    in the weight tensor.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = GlorotNormal()\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = GlorotNormal()\n    >>> layer = Dense(3, kernel_initializer=initializer)\n\n    Args:\n        seed: A Python integer or instance of\n            `keras.backend.SeedGenerator`.\n            Used to make the behavior of the initializer\n            deterministic. Note that an initializer seeded with an integer\n            or `None` (unseeded) will produce the same random values\n            across multiple calls. 
To get different random values\n            across multiple calls, use as seed an instance\n            of `keras.backend.SeedGenerator`.\n\n    Reference:\n\n    - [Glorot et al., 2010](http://proceedings.mlr.press/v9/glorot10a.html)\n    \"\"\"\n\n    def __init__(self, seed=None):\n        super().__init__(\n            scale=1.0,\n            mode=\"fan_avg\",\n            distribution=\"truncated_normal\",\n            seed=seed,\n        )\n\n    def get_config(self):\n        return {\n            \"seed\": serialization_lib.serialize_keras_object(self._init_seed)\n        }\n\n\n@keras_export(\n    [\n        \"keras.initializers.LecunNormal\",\n        \"keras.initializers.lecun_normal\",\n    ]\n)\nclass LecunNormal(VarianceScaling):\n    \"\"\"Lecun normal initializer.\n\n    Initializers allow you to pre-specify an initialization strategy, encoded in\n    the Initializer object, without knowing the shape and dtype of the variable\n    being initialized.\n\n    Draws samples from a truncated normal distribution centered on 0 with\n    `stddev = sqrt(1 / fan_in)` where `fan_in` is the number of input units in\n    the weight tensor.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = LecunNormal()\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = LecunNormal()\n    >>> layer = Dense(3, kernel_initializer=initializer)\n\n    Args:\n        seed: A Python integer or instance of\n            `keras.backend.SeedGenerator`.\n            Used to make the behavior of the initializer\n            deterministic. Note that an initializer seeded with an integer\n            or `None` (unseeded) will produce the same random values\n            across multiple calls. 
To get different random values\n            across multiple calls, use as seed an instance\n            of `keras.backend.SeedGenerator`.\n\n    Reference:\n\n    - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)\n    \"\"\"\n\n    def __init__(self, seed=None):\n        super().__init__(\n            scale=1.0, mode=\"fan_in\", distribution=\"truncated_normal\", seed=seed\n        )\n\n    def get_config(self):\n        return {\n            \"seed\": serialization_lib.serialize_keras_object(self._init_seed)\n        }\n\n\n@keras_export(\n    [\n        \"keras.initializers.LecunUniform\",\n        \"keras.initializers.lecun_uniform\",\n    ]\n)\nclass LecunUniform(VarianceScaling):\n    \"\"\"Lecun uniform initializer.\n\n    Draws samples from a uniform distribution within `[-limit, limit]`, where\n    `limit = sqrt(3 / fan_in)` (`fan_in` is the number of input units in the\n    weight tensor).\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = LecunUniform()\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = LecunUniform()\n    >>> layer = Dense(3, kernel_initializer=initializer)\n\n    Args:\n        seed: A Python integer or instance of\n            `keras.backend.SeedGenerator`.\n            Used to make the behavior of the initializer\n            deterministic. Note that an initializer seeded with an integer\n            or `None` (unseeded) will produce the same random values\n            across multiple calls. 
To get different random values\n            across multiple calls, use as seed an instance\n            of `keras.backend.SeedGenerator`.\n\n    Reference:\n\n    - [Klambauer et al., 2017](https://arxiv.org/abs/1706.02515)\n    \"\"\"\n\n    def __init__(self, seed=None):\n        super().__init__(\n            scale=1.0, mode=\"fan_in\", distribution=\"uniform\", seed=seed\n        )\n\n    def get_config(self):\n        return {\n            \"seed\": serialization_lib.serialize_keras_object(self._init_seed)\n        }\n\n\n@keras_export([\"keras.initializers.HeNormal\", \"keras.initializers.he_normal\"])\nclass HeNormal(VarianceScaling):\n    \"\"\"He normal initializer.\n\n    It draws samples from a truncated normal distribution centered on 0 with\n    `stddev = sqrt(2 / fan_in)` where `fan_in` is the number of input units in\n    the weight tensor.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = HeNormal()\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = HeNormal()\n    >>> layer = Dense(3, kernel_initializer=initializer)\n\n    Args:\n        seed: A Python integer or instance of\n            `keras.backend.SeedGenerator`.\n            Used to make the behavior of the initializer\n            deterministic. Note that an initializer seeded with an integer\n            or `None` (unseeded) will produce the same random values\n            across multiple calls. 
To get different random values\n            across multiple calls, use as seed an instance\n            of `keras.backend.SeedGenerator`.\n\n    Reference:\n\n    - [He et al., 2015](https://arxiv.org/abs/1502.01852)\n    \"\"\"\n\n    def __init__(self, seed=None):\n        super().__init__(\n            scale=2.0, mode=\"fan_in\", distribution=\"truncated_normal\", seed=seed\n        )\n\n    def get_config(self):\n        return {\n            \"seed\": serialization_lib.serialize_keras_object(self._init_seed)\n        }\n\n\n@keras_export([\"keras.initializers.HeUniform\", \"keras.initializers.he_uniform\"])\nclass HeUniform(VarianceScaling):\n    \"\"\"He uniform variance scaling initializer.\n\n    Draws samples from a uniform distribution within `[-limit, limit]`, where\n    `limit = sqrt(6 / fan_in)` (`fan_in` is the number of input units in the\n    weight tensor).\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = HeUniform()\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = HeUniform()\n    >>> layer = Dense(3, kernel_initializer=initializer)\n\n    Args:\n        seed: A Python integer or instance of\n            `keras.backend.SeedGenerator`.\n            Used to make the behavior of the initializer\n            deterministic. Note that an initializer seeded with an integer\n            or `None` (unseeded) will produce the same random values\n            across multiple calls. 
To get different random values\n            across multiple calls, use as seed an instance\n            of `keras.backend.SeedGenerator`.\n\n    Reference:\n\n    - [He et al., 2015](https://arxiv.org/abs/1502.01852)\n    \"\"\"\n\n    def __init__(self, seed=None):\n        super().__init__(\n            scale=2.0, mode=\"fan_in\", distribution=\"uniform\", seed=seed\n        )\n\n    def get_config(self):\n        return {\n            \"seed\": serialization_lib.serialize_keras_object(self._init_seed)\n        }\n\n\ndef compute_fans(shape):\n    \"\"\"Computes the number of input and output units for a weight shape.\n\n    Args:\n        shape: Integer shape tuple.\n\n    Returns:\n        A tuple of integer scalars: `(fan_in, fan_out)`.\n    \"\"\"\n    shape = tuple(shape)\n    if len(shape) < 1:  # Just to avoid errors for constants.\n        fan_in = fan_out = 1\n    elif len(shape) == 1:\n        fan_in = fan_out = shape[0]\n    elif len(shape) == 2:\n        fan_in = shape[0]\n        fan_out = shape[1]\n    else:\n        # Assuming convolution kernels (2D, 3D, or more).\n        # kernel shape: (..., input_depth, depth)\n        receptive_field_size = 1\n        for dim in shape[:-2]:\n            receptive_field_size *= dim\n        fan_in = shape[-2] * receptive_field_size\n        fan_out = shape[-1] * receptive_field_size\n    return int(fan_in), int(fan_out)\n\n\n@keras_export(\n    [\n        \"keras.initializers.Orthogonal\",\n        \"keras.initializers.orthogonal\",\n        \"keras.initializers.OrthogonalInitializer\",\n    ]\n)\nclass Orthogonal(RandomInitializer):\n    \"\"\"Initializer that generates an orthogonal matrix.\n\n    If the shape of the tensor to initialize is two-dimensional, it is\n    initialized with an orthogonal matrix obtained from the QR decomposition of\n    a matrix of random numbers drawn from a normal distribution. 
If the matrix\n    has fewer rows than columns then the output will have orthogonal rows.\n    Otherwise, the output will have orthogonal columns.\n\n    If the shape of the tensor to initialize is more than two-dimensional,\n    a matrix of shape `(shape[0] * ... * shape[n - 2], shape[n - 1])`\n    is initialized, where `n` is the length of the shape vector.\n    The matrix is subsequently reshaped to give a tensor of the desired shape.\n\n    Examples:\n\n    >>> # Standalone usage:\n    >>> initializer = keras.initializers.Orthogonal()\n    >>> values = initializer(shape=(2, 2))\n\n    >>> # Usage in a Keras layer:\n    >>> initializer = keras.initializers.Orthogonal()\n    >>> layer = keras.layers.Dense(3, kernel_initializer=initializer)\n\n    Args:\n        gain: Multiplicative factor to apply to the orthogonal matrix.\n        seed: A Python integer. Used to make the behavior of the initializer\n            deterministic.\n\n    Reference:\n\n    - [Saxe et al., 2014](https://openreview.net/forum?id=_wzZwKpTDF_9C)\n    \"\"\"\n\n    def __init__(self, gain=1.0, seed=None):\n        self.gain = gain\n        super().__init__(seed=seed)\n\n    def __call__(self, shape, dtype=None):\n        if len(shape) < 2:\n            raise ValueError(\n                \"The tensor to initialize must be \"\n                \"at least two-dimensional. 
Received: \"\n                f\"shape={shape} of rank {len(shape)}.\"\n            )\n\n        # Flatten the input shape with the last dimension remaining\n        # its original shape so it works for conv2d\n        num_rows = 1\n        for dim in shape[:-1]:\n            num_rows *= dim\n        num_cols = shape[-1]\n        flat_shape = (max(num_cols, num_rows), min(num_cols, num_rows))\n\n        # Generate a random matrix\n        a = random.normal(flat_shape, seed=self.seed, dtype=dtype)\n        # Compute the qr factorization\n        q, r = ops.qr(a)\n        # Make Q uniform\n        d = ops.diag(r)\n        q *= ops.sign(d)\n        if num_rows < num_cols:\n            q = ops.transpose(q)\n        return self.gain * ops.reshape(q, shape)\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\"gain\": self.gain}\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/initializers/random_initializers_test.py",
    "content": "import numpy as np\n\nfrom conftest import skip_if_backend\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import random\nfrom keras.src import testing\nfrom keras.src import utils\n\n\nclass RandomInitializersTest(testing.TestCase):\n    def test_random_normal(self):\n        utils.set_random_seed(1337)\n        shape = (25, 20)\n        mean = 0.0\n        stddev = 1.0\n        seed = 1234\n        initializer = initializers.RandomNormal(\n            mean=mean, stddev=stddev, seed=seed\n        )\n        values = initializer(shape=shape)\n        self.assertEqual(initializer.mean, mean)\n        self.assertEqual(initializer.stddev, stddev)\n        self.assertEqual(initializer.seed, seed)\n        self.assertEqual(values.shape, shape)\n        self.assertAllClose(\n            np.std(backend.convert_to_numpy(values)), stddev, atol=1e-1\n        )\n\n        self.run_class_serialization_test(initializer)\n\n        # Test that a fixed seed yields the same results each call.\n        initializer = initializers.RandomNormal(\n            mean=mean, stddev=stddev, seed=1337\n        )\n        values = initializer(shape=shape)\n        next_values = initializer(shape=shape)\n        self.assertAllClose(values, next_values)\n\n        # Test that a SeedGenerator yields different results each call.\n        initializer = initializers.RandomNormal(\n            mean=mean, stddev=stddev, seed=backend.random.SeedGenerator(1337)\n        )\n        values = initializer(shape=shape)\n        next_values = initializer(shape=shape)\n        self.assertNotAllClose(values, next_values)\n\n        # Test serialization with SeedGenerator\n        initializer = initializers.RandomNormal(\n            mean=mean, stddev=stddev, seed=backend.random.SeedGenerator(1337)\n        )\n        values = initializer(shape=shape)\n\n        # Test that unseeded generator gets different results after cloning\n        initializer = 
initializers.RandomNormal(\n            mean=mean, stddev=stddev, seed=None\n        )\n        values = initializer(shape=shape)\n        cloned_initializer = initializers.RandomNormal.from_config(\n            initializer.get_config()\n        )\n        new_values = cloned_initializer(shape=shape)\n        self.assertNotAllClose(values, new_values)\n\n        # Test that seeded generator gets same results after cloning\n        initializer = initializers.RandomNormal(\n            mean=mean, stddev=stddev, seed=1337\n        )\n        values = initializer(shape=shape)\n        cloned_initializer = initializers.RandomNormal.from_config(\n            initializer.get_config()\n        )\n        new_values = cloned_initializer(shape=shape)\n        self.assertAllClose(values, new_values)\n\n    def test_random_uniform(self):\n        shape = (5, 5)\n        minval = -1.0\n        maxval = 1.0\n        seed = 1234\n        initializer = initializers.RandomUniform(\n            minval=minval, maxval=maxval, seed=seed\n        )\n        values = initializer(shape=shape)\n        self.assertEqual(initializer.minval, minval)\n        self.assertEqual(initializer.maxval, maxval)\n        self.assertEqual(initializer.seed, seed)\n        self.assertEqual(values.shape, shape)\n        values = backend.convert_to_numpy(values)\n        self.assertGreaterEqual(np.min(values), minval)\n        self.assertLess(np.max(values), maxval)\n\n        self.run_class_serialization_test(initializer)\n\n    def test_variance_scaling(self):\n        utils.set_random_seed(1337)\n        shape = (25, 20)\n        scale = 2.0\n        seed = 1234\n        initializer = initializers.VarianceScaling(\n            scale=scale, seed=seed, mode=\"fan_in\"\n        )\n        values = initializer(shape=shape)\n        self.assertEqual(initializer.scale, scale)\n        self.assertEqual(initializer.seed, seed)\n        self.assertEqual(values.shape, shape)\n        self.assertAllClose(\n         
   np.std(backend.convert_to_numpy(values)),\n            np.sqrt(scale / 25),\n            atol=1e-1,\n        )\n        self.run_class_serialization_test(initializer)\n\n        initializer = initializers.VarianceScaling(\n            scale=scale, seed=seed, mode=\"fan_out\"\n        )\n        values = initializer(shape=shape)\n        self.assertEqual(initializer.scale, scale)\n        self.assertEqual(initializer.seed, seed)\n        self.assertEqual(values.shape, shape)\n        self.assertAllClose(\n            np.std(backend.convert_to_numpy(values)),\n            np.sqrt(scale / 20),\n            atol=1e-1,\n        )\n        self.run_class_serialization_test(initializer)\n\n    @skip_if_backend(\"openvino\", \"openvino backend does not support `qr`\")\n    def test_orthogonal(self):\n        shape = (5, 5)\n        gain = 2.0\n        seed = 1234\n        initializer = initializers.Orthogonal(gain=gain, seed=seed)\n        values = initializer(shape=shape)\n        self.assertEqual(initializer.seed, seed)\n        self.assertEqual(initializer.gain, gain)\n\n        self.assertEqual(values.shape, shape)\n        array = backend.convert_to_numpy(values)\n        # Making sure that the columns have gain * unit norm value\n        for column in array.T:\n            self.assertAlmostEqual(np.linalg.norm(column), gain * 1.0)\n\n        # Making sure that each column is orthonormal to the other column\n        for i in range(array.shape[-1]):\n            for j in range(i + 1, array.shape[-1]):\n                self.assertAlmostEqual(\n                    np.dot(array[..., i], array[..., j]), 0.0\n                )\n\n        self.run_class_serialization_test(initializer)\n\n        # Test compatible class_name\n        initializer = initializers.get(\"OrthogonalInitializer\")\n        self.assertIsInstance(initializer, initializers.Orthogonal)\n\n    def test_get_method(self):\n        obj = initializers.get(\"glorot_normal\")\n        self.assertTrue(obj, 
initializers.GlorotNormal)\n\n        obj = initializers.get(None)\n        self.assertEqual(obj, None)\n\n        with self.assertRaises(ValueError):\n            initializers.get(\"typo\")\n\n    @skip_if_backend(\n        \"openvino\", \"openvino backend does not support `uniform` with None seed\"\n    )\n    def test_get_method_with_tensor(self):\n        shape = (5, 5)\n\n        # Test backend tensor\n        tensor = random.uniform(shape=shape)\n        initializer = initializers.get(tensor)\n        values = initializer(shape=shape)\n        self.assertAllClose(values, tensor)\n\n        # Test numpy array\n        tensor = np.random.uniform(size=shape).astype(\"float32\")\n        initializer = initializers.get(tensor)\n        values = initializer(shape=shape)\n        self.assertAllClose(values, tensor)\n\n        # Test bad `shape` argument\n        with self.assertRaisesRegex(ValueError, r\"Expected `shape` to be\"):\n            initializer(shape=(10, 10))\n\n    def test_variance_scaling_invalid_scale(self):\n        seed = 1234\n\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `scale` must be positive float.\"\n        ):\n            initializers.VarianceScaling(scale=-1.0, seed=seed, mode=\"fan_in\")\n\n    def test_variance_scaling_invalid_mode(self):\n        scale = 2.0\n        seed = 1234\n\n        with self.assertRaisesRegex(ValueError, \"Invalid `mode` argument:\"):\n            initializers.VarianceScaling(\n                scale=scale, seed=seed, mode=\"invalid_mode\"\n            )\n\n    def test_variance_scaling_invalid_distribution(self):\n        scale = 2.0\n        seed = 1234\n\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid `distribution` argument:\"\n        ):\n            initializers.VarianceScaling(\n                scale=scale,\n                seed=seed,\n                mode=\"fan_in\",\n                distribution=\"invalid_dist\",\n            )\n\n    def 
test_serialization_with_seed_generator(self):\n        seed = random.SeedGenerator()\n        initializer = initializers.Orthogonal(seed=seed)\n        self.run_class_serialization_test(initializer)\n\n        seed = random.SeedGenerator()\n        initializer = initializers.VarianceScaling(seed=seed)\n        self.run_class_serialization_test(initializer)\n\n        seed = random.SeedGenerator()\n        initializer = initializers.RandomUniform(seed=seed)\n        self.run_class_serialization_test(initializer)\n\n        seed = random.SeedGenerator()\n        initializer = initializers.RandomNormal(seed=seed)\n        self.run_class_serialization_test(initializer)\n"
  },
  {
    "path": "keras/src/layers/__init__.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.activations.activation import Activation\nfrom keras.src.layers.activations.elu import ELU\nfrom keras.src.layers.activations.leaky_relu import LeakyReLU\nfrom keras.src.layers.activations.prelu import PReLU\nfrom keras.src.layers.activations.relu import ReLU\nfrom keras.src.layers.activations.softmax import Softmax\nfrom keras.src.layers.attention.additive_attention import AdditiveAttention\nfrom keras.src.layers.attention.attention import Attention\nfrom keras.src.layers.attention.grouped_query_attention import (\n    GroupedQueryAttention,\n)\nfrom keras.src.layers.attention.multi_head_attention import MultiHeadAttention\nfrom keras.src.layers.convolutional.conv1d import Conv1D\nfrom keras.src.layers.convolutional.conv1d_transpose import Conv1DTranspose\nfrom keras.src.layers.convolutional.conv2d import Conv2D\nfrom keras.src.layers.convolutional.conv2d_transpose import Conv2DTranspose\nfrom keras.src.layers.convolutional.conv3d import Conv3D\nfrom keras.src.layers.convolutional.conv3d_transpose import Conv3DTranspose\nfrom keras.src.layers.convolutional.depthwise_conv1d import DepthwiseConv1D\nfrom keras.src.layers.convolutional.depthwise_conv2d import DepthwiseConv2D\nfrom keras.src.layers.convolutional.separable_conv1d import SeparableConv1D\nfrom keras.src.layers.convolutional.separable_conv2d import SeparableConv2D\nfrom keras.src.layers.core.dense import Dense\nfrom keras.src.layers.core.einsum_dense import EinsumDense\nfrom keras.src.layers.core.embedding import Embedding\nfrom keras.src.layers.core.identity import Identity\nfrom keras.src.layers.core.input_layer import Input\nfrom keras.src.layers.core.input_layer import InputLayer\nfrom keras.src.layers.core.lambda_layer import Lambda\nfrom keras.src.layers.core.masking import Masking\nfrom keras.src.layers.core.reversible_embedding import ReversibleEmbedding\nfrom keras.src.layers.core.wrapper import Wrapper\nfrom 
keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.layers.merging.add import Add\nfrom keras.src.layers.merging.add import add\nfrom keras.src.layers.merging.average import Average\nfrom keras.src.layers.merging.average import average\nfrom keras.src.layers.merging.concatenate import Concatenate\nfrom keras.src.layers.merging.concatenate import concatenate\nfrom keras.src.layers.merging.dot import Dot\nfrom keras.src.layers.merging.dot import dot\nfrom keras.src.layers.merging.maximum import Maximum\nfrom keras.src.layers.merging.maximum import maximum\nfrom keras.src.layers.merging.minimum import Minimum\nfrom keras.src.layers.merging.minimum import minimum\nfrom keras.src.layers.merging.multiply import Multiply\nfrom keras.src.layers.merging.multiply import multiply\nfrom keras.src.layers.merging.subtract import Subtract\nfrom keras.src.layers.merging.subtract import subtract\nfrom keras.src.layers.normalization.batch_normalization import (\n    BatchNormalization,\n)\nfrom keras.src.layers.normalization.group_normalization import (\n    GroupNormalization,\n)\nfrom keras.src.layers.normalization.layer_normalization import (\n    LayerNormalization,\n)\nfrom keras.src.layers.normalization.rms_normalization import RMSNormalization\nfrom keras.src.layers.normalization.spectral_normalization import (\n    SpectralNormalization,\n)\nfrom keras.src.layers.normalization.unit_normalization import UnitNormalization\nfrom keras.src.layers.pooling.adaptive_average_pooling1d import (\n    AdaptiveAveragePooling1D,\n)\nfrom keras.src.layers.pooling.adaptive_average_pooling2d import (\n    AdaptiveAveragePooling2D,\n)\nfrom keras.src.layers.pooling.adaptive_average_pooling3d import (\n    AdaptiveAveragePooling3D,\n)\nfrom keras.src.layers.pooling.adaptive_max_pooling1d import AdaptiveMaxPooling1D\nfrom keras.src.layers.pooling.adaptive_max_pooling2d import AdaptiveMaxPooling2D\nfrom 
keras.src.layers.pooling.adaptive_max_pooling3d import AdaptiveMaxPooling3D\nfrom keras.src.layers.pooling.average_pooling1d import AveragePooling1D\nfrom keras.src.layers.pooling.average_pooling2d import AveragePooling2D\nfrom keras.src.layers.pooling.average_pooling3d import AveragePooling3D\nfrom keras.src.layers.pooling.global_average_pooling1d import (\n    GlobalAveragePooling1D,\n)\nfrom keras.src.layers.pooling.global_average_pooling2d import (\n    GlobalAveragePooling2D,\n)\nfrom keras.src.layers.pooling.global_average_pooling3d import (\n    GlobalAveragePooling3D,\n)\nfrom keras.src.layers.pooling.global_max_pooling1d import GlobalMaxPooling1D\nfrom keras.src.layers.pooling.global_max_pooling2d import GlobalMaxPooling2D\nfrom keras.src.layers.pooling.global_max_pooling3d import GlobalMaxPooling3D\nfrom keras.src.layers.pooling.max_pooling1d import MaxPooling1D\nfrom keras.src.layers.pooling.max_pooling2d import MaxPooling2D\nfrom keras.src.layers.pooling.max_pooling3d import MaxPooling3D\nfrom keras.src.layers.preprocessing.category_encoding import CategoryEncoding\nfrom keras.src.layers.preprocessing.discretization import Discretization\nfrom keras.src.layers.preprocessing.hashed_crossing import HashedCrossing\nfrom keras.src.layers.preprocessing.hashing import Hashing\nfrom keras.src.layers.preprocessing.image_preprocessing.aug_mix import AugMix\nfrom keras.src.layers.preprocessing.image_preprocessing.auto_contrast import (\n    AutoContrast,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.center_crop import (\n    CenterCrop,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.clahe import (\n    ContrastLimitedAdaptiveHistogramEqualization,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.cut_mix import CutMix\nfrom keras.src.layers.preprocessing.image_preprocessing.equalization import (\n    Equalization,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.max_num_bounding_box import (\n    
MaxNumBoundingBoxes,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.mix_up import MixUp\nfrom keras.src.layers.preprocessing.image_preprocessing.rand_augment import (\n    RandAugment,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_brightness import (\n    RandomBrightness,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (\n    RandomColorDegeneration,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (\n    RandomColorJitter,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_contrast import (\n    RandomContrast,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_crop import (\n    RandomCrop,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_elastic_transform import (\n    RandomElasticTransform,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_erasing import (\n    RandomErasing,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_flip import (\n    RandomFlip,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_gaussian_blur import (\n    RandomGaussianBlur,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_grayscale import (\n    RandomGrayscale,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_hue import (\n    RandomHue,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_invert import (\n    RandomInvert,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_perspective import (\n    RandomPerspective,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_posterization import (\n    RandomPosterization,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_rotation import (\n    RandomRotation,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_saturation import (\n    RandomSaturation,\n)\nfrom 
keras.src.layers.preprocessing.image_preprocessing.random_sharpness import (\n    RandomSharpness,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_shear import (\n    RandomShear,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_translation import (\n    RandomTranslation,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.random_zoom import (\n    RandomZoom,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.resizing import Resizing\nfrom keras.src.layers.preprocessing.image_preprocessing.solarization import (\n    Solarization,\n)\nfrom keras.src.layers.preprocessing.index_lookup import IndexLookup\nfrom keras.src.layers.preprocessing.integer_lookup import IntegerLookup\nfrom keras.src.layers.preprocessing.mel_spectrogram import MelSpectrogram\nfrom keras.src.layers.preprocessing.normalization import Normalization\nfrom keras.src.layers.preprocessing.pipeline import Pipeline\nfrom keras.src.layers.preprocessing.rescaling import Rescaling\nfrom keras.src.layers.preprocessing.stft_spectrogram import STFTSpectrogram\nfrom keras.src.layers.preprocessing.string_lookup import StringLookup\nfrom keras.src.layers.preprocessing.text_vectorization import TextVectorization\nfrom keras.src.layers.regularization.activity_regularization import (\n    ActivityRegularization,\n)\nfrom keras.src.layers.regularization.alpha_dropout import AlphaDropout\nfrom keras.src.layers.regularization.dropout import Dropout\nfrom keras.src.layers.regularization.gaussian_dropout import GaussianDropout\nfrom keras.src.layers.regularization.gaussian_noise import GaussianNoise\nfrom keras.src.layers.regularization.spatial_dropout import SpatialDropout1D\nfrom keras.src.layers.regularization.spatial_dropout import SpatialDropout2D\nfrom keras.src.layers.regularization.spatial_dropout import SpatialDropout3D\nfrom keras.src.layers.reshaping.cropping1d import Cropping1D\nfrom keras.src.layers.reshaping.cropping2d import Cropping2D\nfrom 
keras.src.layers.reshaping.cropping3d import Cropping3D\nfrom keras.src.layers.reshaping.flatten import Flatten\nfrom keras.src.layers.reshaping.permute import Permute\nfrom keras.src.layers.reshaping.repeat_vector import RepeatVector\nfrom keras.src.layers.reshaping.reshape import Reshape\nfrom keras.src.layers.reshaping.up_sampling1d import UpSampling1D\nfrom keras.src.layers.reshaping.up_sampling2d import UpSampling2D\nfrom keras.src.layers.reshaping.up_sampling3d import UpSampling3D\nfrom keras.src.layers.reshaping.zero_padding1d import ZeroPadding1D\nfrom keras.src.layers.reshaping.zero_padding2d import ZeroPadding2D\nfrom keras.src.layers.reshaping.zero_padding3d import ZeroPadding3D\nfrom keras.src.layers.rnn.bidirectional import Bidirectional\nfrom keras.src.layers.rnn.conv_lstm1d import ConvLSTM1D\nfrom keras.src.layers.rnn.conv_lstm2d import ConvLSTM2D\nfrom keras.src.layers.rnn.conv_lstm3d import ConvLSTM3D\nfrom keras.src.layers.rnn.gru import GRU\nfrom keras.src.layers.rnn.gru import GRUCell\nfrom keras.src.layers.rnn.lstm import LSTM\nfrom keras.src.layers.rnn.lstm import LSTMCell\nfrom keras.src.layers.rnn.rnn import RNN\nfrom keras.src.layers.rnn.simple_rnn import SimpleRNN\nfrom keras.src.layers.rnn.simple_rnn import SimpleRNNCell\nfrom keras.src.layers.rnn.stacked_rnn_cells import StackedRNNCells\nfrom keras.src.layers.rnn.time_distributed import TimeDistributed\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.layers.serialize\")\ndef serialize(layer):\n    \"\"\"Returns the layer configuration as a Python dict.\n\n    Args:\n        layer: A `keras.layers.Layer` instance to serialize.\n\n    Returns:\n        Python dict which contains the configuration of the layer.\n    \"\"\"\n    return serialization_lib.serialize_keras_object(layer)\n\n\n@keras_export(\"keras.layers.deserialize\")\ndef deserialize(config, custom_objects=None):\n    \"\"\"Returns a Keras layer object via its configuration.\n\n    Args:\n        
config: A python dict containing a serialized layer configuration.\n        custom_objects: Optional dictionary mapping names (strings) to custom\n            objects (classes and functions) to be considered during\n            deserialization.\n\n    Returns:\n        A Keras layer instance.\n    \"\"\"\n    obj = serialization_lib.deserialize_keras_object(\n        config,\n        custom_objects=custom_objects,\n    )\n    if not isinstance(obj, Layer):\n        raise ValueError(\n            \"`keras.layers.deserialize` was passed a `config` object that is \"\n            f\"not a `keras.layers.Layer`. Received: {config}\"\n        )\n    return obj\n"
  },
  {
    "path": "keras/src/layers/activations/__init__.py",
    "content": "from keras.src.layers.activations.elu import ELU\nfrom keras.src.layers.activations.leaky_relu import LeakyReLU\nfrom keras.src.layers.activations.prelu import PReLU\nfrom keras.src.layers.activations.relu import ReLU\nfrom keras.src.layers.activations.softmax import Softmax\n"
  },
  {
    "path": "keras/src/layers/activations/activation.py",
    "content": "from keras.src import activations\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.Activation\")\nclass Activation(Layer):\n    \"\"\"Applies an activation function to an output.\n\n    Args:\n        activation: Activation function. It could be a callable, or the name of\n            an activation from the `keras.activations` namespace.\n        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.\n\n    Example:\n\n    >>> layer = keras.layers.Activation('relu')\n    >>> layer(np.array([-3.0, -1.0, 0.0, 2.0]))\n    [0.0, 0.0, 0.0, 2.0]\n    >>> layer = keras.layers.Activation(keras.activations.relu)\n    >>> layer(np.array([-3.0, -1.0, 0.0, 2.0]))\n    [0.0, 0.0, 0.0, 2.0]\n    \"\"\"\n\n    def __init__(self, activation, **kwargs):\n        super().__init__(**kwargs)\n        self.supports_masking = True\n        self.activation = activations.get(activation)\n\n        self._build_at_init()\n\n    def call(self, inputs):\n        return self.activation(inputs)\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = {\"activation\": activations.serialize(self.activation)}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/activations/activation_test.py",
    "content": "import pytest\n\nfrom keras.src import activations\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass ActivationTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_activation_basics(self):\n        self.run_layer_test(\n            layers.Activation,\n            init_kwargs={\n                \"activation\": \"relu\",\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n        self.run_layer_test(\n            layers.Activation,\n            init_kwargs={\n                \"activation\": activations.gelu,\n            },\n            input_shape=(2, 2),\n            expected_output_shape=(2, 2),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n"
  },
  {
    "path": "keras/src/layers/activations/elu.py",
    "content": "from keras.src import activations\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.ELU\")\nclass ELU(Layer):\n    \"\"\"Applies an Exponential Linear Unit function to an output.\n\n    Formula:\n\n    ```\n    f(x) = alpha * (exp(x) - 1.) for x < 0\n    f(x) = x for x >= 0\n    ```\n\n    Args:\n        alpha: float, slope of negative section. Defaults to `1.0`.\n        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.\n    \"\"\"\n\n    def __init__(self, alpha=1.0, **kwargs):\n        super().__init__(**kwargs)\n        self.alpha = alpha\n        self.supports_masking = True\n\n        self._build_at_init()\n\n    def call(self, inputs):\n        return activations.elu(inputs, alpha=self.alpha)\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n"
  },
  {
    "path": "keras/src/layers/activations/elu_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import testing\nfrom keras.src.layers.activations import elu\n\n\nclass ELUTest(testing.TestCase):\n    def test_config(self):\n        elu_layer = elu.ELU()\n        self.run_class_serialization_test(elu_layer)\n\n    @pytest.mark.requires_trainable_backend\n    def test_elu(self):\n        self.run_layer_test(\n            elu.ELU,\n            init_kwargs={},\n            input_shape=(2, 3, 4),\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n\n    def test_correctness(self):\n        def np_elu(x, alpha=1.0):\n            return (x > 0) * x + (x <= 0) * alpha * (np.exp(x) - 1)\n\n        x = np.random.random((2, 2, 5))\n        elu_layer = elu.ELU()\n        self.assertAllClose(elu_layer(x), np_elu(x))\n\n        elu_layer = elu.ELU(alpha=0.7)\n        self.assertAllClose(elu_layer(x), np_elu(x, alpha=0.7))\n"
  },
  {
    "path": "keras/src/layers/activations/leaky_relu.py",
    "content": "import warnings\n\nfrom keras.src import activations\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.LeakyReLU\")\nclass LeakyReLU(Layer):\n    \"\"\"Leaky version of a Rectified Linear Unit activation layer.\n\n    This layer allows a small gradient when the unit is not active.\n\n    Formula:\n\n    ``` python\n    f(x) = alpha * x if x < 0\n    f(x) = x if x >= 0\n    ```\n\n    Example:\n\n    ``` python\n    leaky_relu_layer = LeakyReLU(negative_slope=0.5)\n    input = np.array([-10, -5, 0.0, 5, 10])\n    result = leaky_relu_layer(input)\n    # result = [-5. , -2.5,  0. ,  5. , 10.]\n    ```\n\n    Args:\n        negative_slope: Float >= 0.0. Negative slope coefficient.\n          Defaults to `0.3`.\n        **kwargs: Base layer keyword arguments, such as\n            `name` and `dtype`.\n\n    \"\"\"\n\n    def __init__(self, negative_slope=0.3, **kwargs):\n        if \"alpha\" in kwargs:\n            negative_slope = kwargs.pop(\"alpha\")\n            warnings.warn(\n                \"Argument `alpha` is deprecated. Use `negative_slope` instead.\"\n            )\n        super().__init__(**kwargs)\n        if negative_slope is None or negative_slope < 0:\n            raise ValueError(\n                \"The negative_slope value of a Leaky ReLU layer \"\n                \"cannot be None or negative value. 
Expected a float.\"\n                f\" Received: negative_slope={negative_slope}\"\n            )\n        self.negative_slope = negative_slope\n        self.supports_masking = True\n\n        self._build_at_init()\n\n    def call(self, inputs):\n        return activations.leaky_relu(\n            inputs, negative_slope=self.negative_slope\n        )\n\n    def get_config(self):\n        config = super().get_config()\n        config.update({\"negative_slope\": self.negative_slope})\n        return config\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n"
  },
  {
    "path": "keras/src/layers/activations/leaky_relu_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import testing\nfrom keras.src.layers.activations import leaky_relu\n\n\nclass LeakyReLUTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_leaky_relu(self):\n        self.run_layer_test(\n            leaky_relu.LeakyReLU,\n            init_kwargs={\n                \"negative_slope\": 1,\n            },\n            input_shape=(2, 3, 4),\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n\n    def test_leaky_relu_correctness(self):\n        leaky_relu_layer = leaky_relu.LeakyReLU(negative_slope=0.5)\n        input = np.array([-10, -5, 0.0, 5, 10])\n        expected_output = np.array([-5.0, -2.5, 0.0, 5.0, 10.0])\n        result = leaky_relu_layer(input)\n        self.assertAllClose(result, expected_output)\n\n    def test_invalid_usage(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"The negative_slope value of a Leaky ReLU layer cannot be None\",\n        ):\n            self.run_layer_test(\n                leaky_relu.LeakyReLU,\n                init_kwargs={\"negative_slope\": None},\n                input_shape=(2, 3, 4),\n                supports_masking=True,\n            )\n"
  },
  {
    "path": "keras/src/layers/activations/prelu.py",
    "content": "from keras.src import activations\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import regularizers\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.PReLU\")\nclass PReLU(Layer):\n    \"\"\"Parametric Rectified Linear Unit activation layer.\n\n    Formula:\n    ``` python\n    f(x) = alpha * x for x < 0\n    f(x) = x for x >= 0\n    ```\n    where `alpha` is a learned array with the same shape as x.\n\n    Args:\n        alpha_initializer: Initializer function for the weights.\n        alpha_regularizer: Regularizer for the weights.\n        alpha_constraint: Constraint for the weights.\n        shared_axes: The axes along which to share learnable parameters for the\n            activation function. For example, if the incoming feature maps are\n            from a 2D convolution with output shape\n            `(batch, height, width, channels)`, and you wish to share parameters\n            across space so that each filter only has one set of parameters,\n            set `shared_axes=[1, 2]`.\n        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.\n    \"\"\"\n\n    def __init__(\n        self,\n        alpha_initializer=\"Zeros\",\n        alpha_regularizer=None,\n        alpha_constraint=None,\n        shared_axes=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.supports_masking = True\n        self.alpha_initializer = initializers.get(alpha_initializer)\n        self.alpha_regularizer = regularizers.get(alpha_regularizer)\n        self.alpha_constraint = constraints.get(alpha_constraint)\n        if shared_axes is None:\n            self.shared_axes = None\n        elif not isinstance(shared_axes, (list, tuple)):\n            self.shared_axes = [shared_axes]\n        else:\n            self.shared_axes = list(shared_axes)\n\n    
def build(self, input_shape):\n        param_shape = list(input_shape[1:])\n        if self.shared_axes is not None:\n            for i in self.shared_axes:\n                param_shape[i - 1] = 1\n        self.alpha = self.add_weight(\n            shape=param_shape,\n            name=\"alpha\",\n            initializer=self.alpha_initializer,\n            regularizer=self.alpha_regularizer,\n            constraint=self.alpha_constraint,\n        )\n        # Set input spec\n        axes = {}\n        if self.shared_axes:\n            for i in range(1, len(input_shape)):\n                if i not in self.shared_axes:\n                    axes[i] = input_shape[i]\n        self.input_spec = InputSpec(ndim=len(input_shape), axes=axes)\n\n    def call(self, inputs):\n        pos = activations.relu(inputs)\n        neg = -self.alpha * activations.relu(-inputs)\n        return pos + neg\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"alpha_initializer\": initializers.serialize(\n                    self.alpha_initializer\n                ),\n                \"alpha_regularizer\": regularizers.serialize(\n                    self.alpha_regularizer\n                ),\n                \"alpha_constraint\": constraints.serialize(\n                    self.alpha_constraint\n                ),\n                \"shared_axes\": self.shared_axes,\n            }\n        )\n        return config\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n"
  },
  {
    "path": "keras/src/layers/activations/prelu_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import testing\nfrom keras.src.layers.activations import prelu\n\n\nclass PReLUTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_prelu(self):\n        self.run_layer_test(\n            prelu.PReLU,\n            init_kwargs={\n                \"alpha_initializer\": \"zeros\",\n                \"alpha_regularizer\": \"L1\",\n                \"alpha_constraint\": \"MaxNorm\",\n                \"shared_axes\": 1,\n            },\n            input_shape=(2, 3, 4),\n            supports_masking=True,\n        )\n\n    def test_prelu_correctness(self):\n        def np_prelu(x, alpha):\n            return (x > 0) * x + (x <= 0) * alpha * x\n\n        inputs = np.random.randn(2, 10, 5, 3)\n        prelu_layer = prelu.PReLU(\n            alpha_initializer=\"glorot_uniform\",\n            alpha_regularizer=\"l1\",\n            alpha_constraint=\"non_neg\",\n            shared_axes=(1, 2),\n        )\n        prelu_layer.build(inputs.shape)\n\n        weights = np.random.random((1, 1, 3))\n        prelu_layer.alpha.assign(weights)\n        ref_out = np_prelu(inputs, weights)\n        self.assertAllClose(prelu_layer(inputs), ref_out)\n"
  },
  {
    "path": "keras/src/layers/activations/relu.py",
    "content": "from keras.src import activations\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.ReLU\")\nclass ReLU(Layer):\n    \"\"\"Rectified Linear Unit activation function layer.\n\n    Formula:\n    ``` python\n    f(x) = max(x,0)\n    f(x) = max_value if x >= max_value\n    f(x) = x if threshold <= x < max_value\n    f(x) = negative_slope * (x - threshold) otherwise\n    ```\n\n    Example:\n    ``` python\n    relu_layer = keras.layers.ReLU(\n        max_value=10,\n        negative_slope=0.5,\n        threshold=0,\n    )\n    input = np.array([-10, -5, 0.0, 5, 10])\n    result = relu_layer(input)\n    # result = [-5. , -2.5,  0. ,  5. , 10.]\n    ```\n\n    Args:\n        max_value: Float >= 0. Maximum activation value. None means unlimited.\n            Defaults to `None`.\n        negative_slope: Float >= 0. Negative slope coefficient.\n            Defaults to `0.0`.\n        threshold: Float >= 0. Threshold value for thresholded activation.\n            Defaults to `0.0`.\n        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.\n    \"\"\"\n\n    def __init__(\n        self, max_value=None, negative_slope=0.0, threshold=0.0, **kwargs\n    ):\n        super().__init__(**kwargs)\n        if max_value is not None and max_value < 0.0:\n            raise ValueError(\n                \"max_value of a ReLU layer cannot be a negative \"\n                f\"value. Received: max_value={max_value}\"\n            )\n        if negative_slope is None or negative_slope < 0.0:\n            raise ValueError(\n                \"negative_slope of a ReLU layer cannot be a negative \"\n                f\"value. Received: negative_slope={negative_slope}\"\n            )\n        if threshold is None or threshold < 0.0:\n            raise ValueError(\n                \"threshold of a ReLU layer cannot be a negative \"\n                f\"value. 
Received: threshold={threshold}\"\n            )\n\n        self.max_value = max_value\n        self.negative_slope = negative_slope\n        self.threshold = threshold\n        self.supports_masking = True\n\n        self._build_at_init()\n\n    def call(self, inputs):\n        return activations.relu(\n            inputs,\n            negative_slope=self.negative_slope,\n            max_value=self.max_value,\n            threshold=self.threshold,\n        )\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"max_value\": self.max_value,\n                \"negative_slope\": self.negative_slope,\n                \"threshold\": self.threshold,\n            }\n        )\n        return config\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n"
  },
  {
    "path": "keras/src/layers/activations/relu_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import testing\nfrom keras.src.layers.activations import relu\n\n\nclass ReLUTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_relu(self):\n        self.run_layer_test(\n            relu.ReLU,\n            init_kwargs={\n                \"max_value\": 10,\n                \"negative_slope\": 1,\n                \"threshold\": 0.5,\n            },\n            input_shape=(2, 3, 4),\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n\n    def test_normal_relu_correctness(self):\n        relu_layer = relu.ReLU(max_value=10, negative_slope=0.0, threshold=0)\n        input = np.array([-10, -5, 0.0, 5, 10])\n        expected_output = np.array([0.0, 0.0, 0.0, 5.0, 10.0])\n        result = relu_layer(input)\n        self.assertAllClose(result, expected_output)\n\n    def test_leaky_relu_correctness(self):\n        relu_layer = relu.ReLU(max_value=10, negative_slope=0.5, threshold=0)\n        input = np.array([-10, -5, 0.0, 5, 10])\n        expected_output = np.array([-5.0, -2.5, 0.0, 5.0, 10.0])\n        result = relu_layer(input)\n        self.assertAllClose(result, expected_output)\n\n    def test_threshold_relu_correctness(self):\n        relu_layer = relu.ReLU(max_value=8, negative_slope=0.0, threshold=5)\n        input = np.array([6.0, 7.0, 0.0, 5, 10])\n        expected_output = np.array([6.0, 7.0, 0.0, 0.0, 8.0])\n        result = relu_layer(input)\n        self.assertAllClose(result, expected_output)\n\n    def test_invalid_usage(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"max_value of a ReLU layer cannot be a negative value\",\n        ):\n            self.run_layer_test(\n                relu.ReLU,\n                init_kwargs={\n                    \"max_value\": -10,\n                    \"negative_slope\": 1,\n                    \"threshold\": 0.5,\n                },\n  
              input_shape=(2, 3, 4),\n                supports_masking=True,\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"negative_slope of a ReLU layer cannot be a negative value\",\n        ):\n            self.run_layer_test(\n                relu.ReLU,\n                init_kwargs={\n                    \"max_value\": 10,\n                    \"negative_slope\": -10,\n                    \"threshold\": 0.5,\n                },\n                input_shape=(2, 3, 4),\n                supports_masking=True,\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"threshold of a ReLU layer cannot be a negative value\"\n        ):\n            self.run_layer_test(\n                relu.ReLU,\n                init_kwargs={\n                    \"max_value\": 10,\n                    \"negative_slope\": 1,\n                    \"threshold\": -10,\n                },\n                input_shape=(2, 3, 4),\n                supports_masking=True,\n            )\n"
  },
  {
    "path": "keras/src/layers/activations/softmax.py",
    "content": "from keras.src import activations\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\n\n\ndef _large_negative_number(dtype):\n    \"\"\"Return a Large negative number based on dtype.\"\"\"\n    if backend.standardize_dtype(dtype) == \"float16\":\n        return -3e4\n    return -1e9\n\n\n@keras_export(\"keras.layers.Softmax\")\nclass Softmax(Layer):\n    \"\"\"Softmax activation layer.\n\n    Formula:\n    ``` python\n    exp_x = exp(x - max(x))\n    f(x) = exp_x / sum(exp_x)\n    ```\n\n    Example:\n    >>> softmax_layer = keras.layers.Softmax()\n    >>> input = np.array([1.0, 2.0, 1.0])\n    >>> result = softmax_layer(input)\n    >>> result\n    [0.21194157, 0.5761169, 0.21194157]\n\n\n    Args:\n        axis: Integer, or list of Integers, axis along which the softmax\n            normalization is applied.\n        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.\n\n    Call arguments:\n        inputs: The inputs (logits) to the softmax layer.\n        mask: A boolean mask that is broadcastable to `inputs`. The mask\n            specifies 1 to keep and 0 to mask. Each dimension of the mask\n            must either be 1 or match the corresponding dimension of\n            `inputs`; it must not be larger. Defaults to `None`.\n\n    Returns:\n        Softmaxed output with the same shape as `inputs`.\n    \"\"\"\n\n    def __init__(self, axis=-1, **kwargs):\n        super().__init__(**kwargs)\n        self.axis = axis\n        self.supports_masking = True\n\n        self._build_at_init()\n\n    def call(self, inputs, mask=None):\n        if mask is not None:\n            if len(mask.shape) > len(inputs.shape):\n                raise ValueError(\n                    \"The `mask` must be broadcastable to `inputs` \"\n                    \"and must not have more dimensions. 
\"\n                    f\"Received: inputs.shape={inputs.shape}, \"\n                    f\"mask.shape={mask.shape}\"\n                )\n            for m_dim, i_dim in zip(mask.shape[::-1], inputs.shape[::-1]):\n                if m_dim is not None and i_dim is not None:\n                    if m_dim != 1 and m_dim != i_dim:\n                        raise ValueError(\n                            \"The `mask` must be broadcastable to \"\n                            \"`inputs`. Each mask dimension must be 1 \"\n                            \"or match the corresponding input \"\n                            \"dimension. Received: \"\n                            f\"inputs.shape={inputs.shape}, \"\n                            f\"mask.shape={mask.shape}\"\n                        )\n            # We keep the positions where the mask is True or > 0.5, and set the\n            # other (masked) positions to -1e.9.\n            if backend.standardize_dtype(mask.dtype) != \"bool\":\n                mask = backend.numpy.greater(\n                    mask, backend.cast(0.5, dtype=mask.dtype)\n                )\n            inputs = backend.numpy.where(\n                mask, inputs, _large_negative_number(inputs.dtype)\n            )\n        if isinstance(self.axis, (tuple, list)):\n            if len(self.axis) > 1:\n                outputs = backend.numpy.exp(\n                    inputs\n                    - backend.math.logsumexp(\n                        inputs, axis=self.axis, keepdims=True\n                    )\n                )\n            else:\n                outputs = activations.softmax(inputs, axis=self.axis[0])\n        else:\n            outputs = activations.softmax(inputs, axis=self.axis)\n\n        # Free pre-softmax masked inputs to reduce peak memory.\n        # Without this, the masked inputs, softmax outputs, and\n        # post-masked outputs all exist simultaneously.\n        del inputs\n\n        if mask is not None:\n            # Zero out 
masked positions in case the entire axis is masked\n            # (where softmax would output a uniform distribution).\n            outputs = backend.numpy.where(mask, outputs, 0.0)\n\n        return outputs\n\n    def get_config(self):\n        config = super().get_config()\n        config.update({\"axis\": self.axis})\n        return config\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n"
  },
  {
    "path": "keras/src/layers/activations/softmax_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import testing\nfrom keras.src.layers.activations import softmax\n\n\nclass SoftmaxTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_softmax(self):\n        self.run_layer_test(\n            softmax.Softmax,\n            init_kwargs={},\n            input_shape=(2, 3, 4),\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n\n    def test_softmax_correctness(self):\n        softmax_layer = softmax.Softmax()\n        input = np.array([[1.0, 2.0, 1.0], [1.0, 2.0, 1.0]])\n        expected_output = np.array(\n            [\n                [0.21194157, 0.5761169, 0.21194157],\n                [0.21194157, 0.5761169, 0.21194157],\n            ]\n        )\n        result = softmax_layer(input)\n        self.assertAllClose(result, expected_output)\n\n    def test_softmax_correctness_with_mask(self):\n        softmax_layer = softmax.Softmax(axis=(1, 0))\n        input = np.array([[1.0, 2.0, 1.0], [1.0, 2.0, 1.0]])\n        mask = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])\n        expected_output = np.array(\n            [[0.21194154, 0.0, 0.21194154], [0.0, 0.57611686, 0.0]]\n        )\n        result = softmax_layer(input, mask=mask)\n        self.assertAllClose(result, expected_output)\n\n    def test_softmax_correctness_with_axis(self):\n        softmax_layer = softmax.Softmax(axis=(1))\n        input = np.array([[1.0, 2.0, 1.0], [1.0, 2.0, 1.0]])\n        expected_output = np.array(\n            [\n                [0.21194157, 0.5761169, 0.21194157],\n                [0.21194157, 0.5761169, 0.21194157],\n            ]\n        )\n        result = softmax_layer(input)\n        self.assertAllClose(result, expected_output)\n\n    def test_softmax_masked_values_are_zero_including_fully_masked(self):\n        \"\"\"\n        Tests softmax with mask on default axis (-1).\n        Ensures output is 0 where mask is False.\n     
   Includes a row where all elements are masked.\n        \"\"\"\n        softmax_layer = softmax.Softmax()  # Default axis = -1\n\n        input = np.array(\n            [\n                [1.0, 2.0, 5.0, 1.0],\n                [1.0, 1.0, 1.0, 1.0],\n                [3.0, 1.0, 2.0, 4.0],\n            ],\n            dtype=np.float32,\n        )\n        mask = np.array(\n            [\n                [True, True, False, False],  # Partially masked\n                [False, False, False, False],  # Fully masked\n                [True, True, True, True],  # Not masked\n            ],\n            dtype=bool,\n        )\n\n        expected_output = np.array(\n            [\n                [0.268941, 0.731059, 0.0, 0.0],  # last two masked\n                [0.0, 0.0, 0.0, 0.0],  # Fully masked row should be all zeros\n                [0.236883, 0.032059, 0.087144, 0.643914],\n            ]\n        )\n\n        result = softmax_layer(input, mask=mask)\n\n        self.assertAllClose(result, expected_output)\n\n    def test_softmax_mask_broadcastable(self):\n        softmax_layer = softmax.Softmax()\n        # mask (1, 3) broadcastable to inputs (2, 3) — should work\n        inputs = np.array([[1.0, 2.0, 1.0], [3.0, 4.0, 5.0]])\n        mask = np.array([[1.0, 0.0, 1.0]])\n        result = softmax_layer(inputs, mask=mask)\n        # Masked position (column 1) should be 0\n        self.assertAllClose(result[:, 1], [0.0, 0.0])\n\n    def test_softmax_mask_broadcastable_fewer_dims(self):\n        softmax_layer = softmax.Softmax(axis=-1)\n        # mask (3,) broadcastable to inputs (2, 3) — should work\n        inputs = np.array([[1.0, 2.0, 1.0], [3.0, 4.0, 5.0]])\n        mask = np.array([1.0, 0.0, 1.0])\n        result = softmax_layer(inputs, mask=mask)\n        self.assertAllClose(result[:, 1], [0.0, 0.0])\n\n    def test_softmax_mask_shape_mismatch(self):\n        softmax_layer = softmax.Softmax()\n        # mask (2, 3) with inputs (1, 3) — mask is larger, not allowed\n 
       inputs = np.array([[1.0, 2.0, 1.0]])\n        mask = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])\n        with self.assertRaisesRegex(ValueError, \"broadcastable\"):\n            softmax_layer(inputs, mask=mask)\n\n    def test_softmax_mask_too_many_dims(self):\n        softmax_layer = softmax.Softmax()\n        # mask has more dimensions than inputs — not allowed\n        inputs = np.array([[1.0, 2.0, 1.0]])\n        mask = np.array([[[1.0, 0.0, 1.0]]])\n        with self.assertRaisesRegex(\n            ValueError, \"must not have more dimensions\"\n        ):\n            softmax_layer(inputs, mask=mask)\n"
  },
  {
    "path": "keras/src/layers/attention/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/layers/attention/additive_attention.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.attention.attention import Attention\n\n\n@keras_export(\"keras.layers.AdditiveAttention\")\nclass AdditiveAttention(Attention):\n    \"\"\"Additive attention layer, a.k.a. Bahdanau-style attention.\n\n    Inputs are a list with 2 or 3 elements:\n    1. A `query` tensor of shape `(batch_size, Tq, dim)`.\n    2. A `value` tensor of shape `(batch_size, Tv, dim)`.\n    3. A optional `key` tensor of shape `(batch_size, Tv, dim)`. If none\n        supplied, `value` will be used as `key`.\n\n    The calculation follows the steps:\n    1. Calculate attention scores using `query` and `key` with shape\n        `(batch_size, Tq, Tv)` as a non-linear sum\n        `scores = reduce_sum(tanh(query + key), axis=-1)`.\n    2. Use scores to calculate a softmax distribution with shape\n        `(batch_size, Tq, Tv)`.\n    3. Use the softmax distribution to create a linear combination of `value`\n        with shape `(batch_size, Tq, dim)`.\n\n    Args:\n        use_scale: If `True`, will create a scalar variable to scale the\n            attention scores.\n        dropout: Float between 0 and 1. Fraction of the units to drop for the\n            attention scores. Defaults to `0.0`.\n\n    Call arguments:\n        inputs: List of the following tensors:\n            - `query`: Query tensor of shape `(batch_size, Tq, dim)`.\n            - `value`: Value tensor of shape `(batch_size, Tv, dim)`.\n            - `key`: Optional key tensor of shape `(batch_size, Tv, dim)`. 
If\n                not given, will use `value` for both `key` and `value`, which is\n                the most common case.\n        mask: List of the following tensors:\n            - `query_mask`: A boolean mask tensor of shape `(batch_size, Tq)`.\n                If given, the output will be zero at the positions where\n                `mask==False`.\n            - `value_mask`: A boolean mask tensor of shape `(batch_size, Tv)`.\n                If given, will apply the mask such that values at positions\n                 where `mask==False` do not contribute to the result.\n        return_attention_scores: bool, it `True`, returns the attention scores\n            (after masking and softmax) as an additional output argument.\n        training: Python boolean indicating whether the layer should behave in\n            training mode (adding dropout) or in inference mode (no dropout).\n        use_causal_mask: Boolean. Set to `True` for decoder self-attention. Adds\n            a mask such that position `i` cannot attend to positions `j > i`.\n            This prevents the flow of information from the future towards the\n            past. 
Defaults to `False`.\n\n    Output:\n        Attention outputs of shape `(batch_size, Tq, dim)`.\n        (Optional) Attention scores after masking and softmax with shape\n            `(batch_size, Tq, Tv)`.\n    \"\"\"\n\n    def __init__(\n        self,\n        use_scale=True,\n        dropout=0.0,\n        **kwargs,\n    ):\n        super().__init__(use_scale=use_scale, dropout=dropout, **kwargs)\n\n    def build(self, input_shape):\n        self._validate_inputs(input_shape)\n        dim = input_shape[0][-1]\n        self.scale = None\n        if self.use_scale:\n            self.scale = self.add_weight(\n                name=\"scale\",\n                shape=[dim],\n                initializer=\"glorot_uniform\",\n                dtype=self.dtype,\n                trainable=True,\n            )\n\n    def _calculate_scores(self, query, key):\n        \"\"\"Calculates attention scores as a nonlinear sum of query and key.\n\n        Args:\n            query: Query tensor of shape `(batch_size, Tq, dim)`.\n            key: Key tensor of shape `(batch_size, Tv, dim)`.\n\n        Returns:\n            Tensor of shape `(batch_size, Tq, Tv)`.\n        \"\"\"\n        # Reshape tensors to enable broadcasting.\n        # Reshape into [batch_size, Tq, 1, dim].\n        q_reshaped = ops.expand_dims(query, axis=-2)\n        # Reshape into [batch_size, 1, Tv, dim].\n        k_reshaped = ops.expand_dims(key, axis=-3)\n        scale = self.scale if self.use_scale else 1.0\n        return ops.sum(scale * ops.tanh(q_reshaped + k_reshaped), axis=-1)\n\n    def get_config(self):\n        base_config = super().get_config()\n        del base_config[\"score_mode\"]\n        return base_config\n"
  },
  {
    "path": "keras/src/layers/attention/additive_attention_test.py",
    "content": "import numpy as np\n\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass AdditiveAttentionTest(testing.TestCase):\n    def test_attention_basics(self):\n        # No scale\n        self.run_layer_test(\n            layers.AdditiveAttention,\n            init_kwargs={\n                \"use_scale\": True,\n                \"dropout\": 0.5,\n            },\n            input_shape=[(2, 3, 4), (2, 4, 4)],\n            expected_output_shape=(2, 3, 4),\n            expected_num_trainable_weights=1,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=1,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n        # With scale.\n        self.run_layer_test(\n            layers.AdditiveAttention,\n            init_kwargs={\n                \"use_scale\": False,\n                \"dropout\": 0.5,\n            },\n            input_shape=[(2, 3, 4), (2, 4, 4)],\n            expected_output_shape=(2, 3, 4),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=1,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n\n    def test_attention_correctness(self):\n        query = np.array([[[1.0, 0.0], [0.0, 1.0]]])\n        key = np.array([[[0.0, 1.0], [1.0, 0.0]]])\n        value = np.array([[[1.0, 2.0], [3.0, 4.0]]])\n\n        layer = layers.AdditiveAttention(use_scale=False)\n        output, scores = layer(\n            [query, value, key],\n            return_attention_scores=True,\n        )\n        self.assertAllClose(\n            output,\n            [[[1.727, 2.727], [2.272, 3.272]]],\n            atol=1e-3,\n            tpu_atol=1e-2,\n        )\n        self.assertAllClose(\n            scores,\n            [[[0.636, 0.363], [0.363, 0.636]]],\n            atol=1e-3,\n   
         tpu_atol=1e-2,\n        )\n\n    def test_attention_with_mask(self):\n        layer = layers.AdditiveAttention(use_scale=False)\n        query = np.array([[[1.0, 0.0], [0.0, 1.0]]])\n        value = np.array([[[1.0, 1.0], [1.0, 1.0]]])\n        query_mask = np.array([[True, False]])\n        value_mask = np.array([[True, False]])\n        output, scores = layer(\n            [query, value],\n            mask=[query_mask, value_mask],\n            return_attention_scores=True,\n        )\n        self.assertAllClose(output, [[[1.0, 1.0], [0.0, 0.0]]])\n        self.assertAllClose(scores, [[[1.0, 0.0], [1.0, 0.0]]])\n\n    def test_attention_errors(self):\n        layer = layers.AdditiveAttention()\n        tensor = np.array([[[1.0, 1.0], [1.0, 1.0]]])\n        with self.assertRaisesRegex(ValueError, \"must be called on a list\"):\n            layer(tensor)\n\n        with self.assertRaisesRegex(ValueError, \"length 2 or 3\"):\n            layer([tensor, tensor, tensor, tensor])\n\n        with self.assertRaisesRegex(ValueError, \"layer mask must be a list\"):\n            layer([tensor, tensor], mask=tensor)\n\n        with self.assertRaisesRegex(ValueError, \"length 2 or 3\"):\n            layer([tensor, tensor], mask=[tensor])\n"
  },
  {
    "path": "keras/src/layers/attention/attention.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.Attention\")\nclass Attention(Layer):\n    \"\"\"Dot-product attention layer, a.k.a. Luong-style attention.\n\n    Inputs are a list with 2 or 3 elements:\n    1. A `query` tensor of shape `(batch_size, Tq, dim)`.\n    2. A `value` tensor of shape `(batch_size, Tv, dim)`.\n    3. A optional `key` tensor of shape `(batch_size, Tv, dim)`. If none\n        supplied, `value` will be used as a `key`.\n\n    The calculation follows the steps:\n    1. Calculate attention scores using `query` and `key` with shape\n        `(batch_size, Tq, Tv)`.\n    2. Use scores to calculate a softmax distribution with shape\n        `(batch_size, Tq, Tv)`.\n    3. Use the softmax distribution to create a linear combination of `value`\n        with shape `(batch_size, Tq, dim)`.\n\n    Args:\n        use_scale: If `True`, will create a scalar variable to scale the\n            attention scores.\n        dropout: Float between 0 and 1. Fraction of the units to drop for the\n            attention scores. Defaults to `0.0`.\n        seed: A Python integer to use as random seed in case of `dropout`.\n        score_mode: Function to use to compute attention scores, one of\n            `{\"dot\", \"concat\"}`. `\"dot\"` refers to the dot product between the\n            query and key vectors. `\"concat\"` refers to the hyperbolic tangent\n            of the concatenation of the `query` and `key` vectors.\n\n    Call arguments:\n        inputs: List of the following tensors:\n            - `query`: Query tensor of shape `(batch_size, Tq, dim)`.\n            - `value`: Value tensor of shape `(batch_size, Tv, dim)`.\n            - `key`: Optional key tensor of shape `(batch_size, Tv, dim)`. 
If\n                not given, will use `value` for both `key` and `value`, which is\n                the most common case.\n        mask: List of the following tensors:\n            - `query_mask`: A boolean mask tensor of shape `(batch_size, Tq)`.\n                If given, the output will be zero at the positions where\n                `mask==False`.\n            - `value_mask`: A boolean mask tensor of shape `(batch_size, Tv)`.\n                If given, will apply the mask such that values at positions\n                 where `mask==False` do not contribute to the result.\n        return_attention_scores: bool, it `True`, returns the attention scores\n            (after masking and softmax) as an additional output argument.\n        training: Python boolean indicating whether the layer should behave in\n            training mode (adding dropout) or in inference mode (no dropout).\n        use_causal_mask: Boolean. Set to `True` for decoder self-attention. Adds\n            a mask such that position `i` cannot attend to positions `j > i`.\n            This prevents the flow of information from the future towards the\n            past. Defaults to `False`.\n\n    Output:\n        Attention outputs of shape `(batch_size, Tq, dim)`.\n        (Optional) Attention scores after masking and softmax with shape\n            `(batch_size, Tq, Tv)`.\n    \"\"\"\n\n    def __init__(\n        self,\n        use_scale=False,\n        score_mode=\"dot\",\n        dropout=0.0,\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.use_scale = use_scale\n        self.score_mode = score_mode\n        self.dropout = dropout\n        if self.dropout > 0:\n            self.seed_generator = backend.random.SeedGenerator(seed=seed)\n\n        if self.score_mode not in [\"dot\", \"concat\"]:\n            raise ValueError(\n                \"Invalid value for argument score_mode. \"\n                \"Expected one of {'dot', 'concat'}. 
\"\n                f\"Received: score_mode={score_mode}\"\n            )\n\n        self._return_attention_scores = False\n\n    def build(self, input_shape):\n        self._validate_inputs(input_shape)\n        self.scale = None\n        self.concat_score_weight = None\n        if self.use_scale:\n            self.scale = self.add_weight(\n                name=\"scale\",\n                shape=(),\n                initializer=\"ones\",\n                dtype=self.dtype,\n                trainable=True,\n            )\n        if self.score_mode == \"concat\":\n            self.concat_score_weight = self.add_weight(\n                name=\"concat_score_weight\",\n                shape=(),\n                initializer=\"ones\",\n                dtype=self.dtype,\n                trainable=True,\n            )\n\n    def _calculate_scores(self, query, key):\n        \"\"\"Calculates attention scores as a query-key dot product.\n\n        Args:\n            query: Query tensor of shape `(batch_size, Tq, dim)`.\n            key: Key tensor of shape `(batch_size, Tv, dim)`.\n\n        Returns:\n            Tensor of shape `(batch_size, Tq, Tv)`.\n        \"\"\"\n        if self.score_mode == \"dot\":\n            scores = ops.matmul(query, ops.transpose(key, axes=[0, 2, 1]))\n            if self.scale is not None:\n                scores = ops.multiply(scores, self.scale)\n        elif self.score_mode == \"concat\":\n            # Reshape tensors to enable broadcasting.\n            # Reshape into [batch_size, Tq, 1, dim].\n            q_reshaped = ops.expand_dims(query, axis=-2)\n            # Reshape into [batch_size, 1, Tv, dim].\n            k_reshaped = ops.expand_dims(key, axis=-3)\n            if self.scale is not None:\n                scores = self.concat_score_weight * ops.sum(\n                    ops.tanh(self.scale * (q_reshaped + k_reshaped)), axis=-1\n                )\n            else:\n                scores = self.concat_score_weight * ops.sum(\n     
               ops.tanh(q_reshaped + k_reshaped), axis=-1\n                )\n        else:\n            raise ValueError(\"scores not computed\")\n\n        return scores\n\n    def _apply_scores(self, scores, value, scores_mask=None, training=False):\n        \"\"\"Applies attention scores to the given value tensor.\n\n        To use this method in your attention layer, follow the steps:\n\n        * Use `query` tensor of shape `(batch_size, Tq)` and `key` tensor of\n            shape `(batch_size, Tv)` to calculate the attention `scores`.\n        * Pass `scores` and `value` tensors to this method. The method applies\n            `scores_mask`, calculates\n            `attention_distribution = softmax(scores)`, then returns\n            `matmul(attention_distribution, value)`.\n        * Apply `query_mask` and return the result.\n\n        Args:\n            scores: Scores float tensor of shape `(batch_size, Tq, Tv)`.\n            value: Value tensor of shape `(batch_size, Tv, dim)`.\n            scores_mask: A boolean mask tensor of shape `(batch_size, 1, Tv)`\n                or `(batch_size, Tq, Tv)`. If given, scores at positions where\n                `scores_mask==False` do not contribute to the result. It must\n                contain at least one `True` value in each line along the last\n                dimension.\n            training: Python boolean indicating whether the layer should behave\n                in training mode (adding dropout) or in inference mode\n                (no dropout).\n\n        Returns:\n            Tensor of shape `(batch_size, Tq, dim)`.\n            Attention scores after masking and softmax with shape\n                `(batch_size, Tq, Tv)`.\n        \"\"\"\n        if scores_mask is not None:\n            padding_mask = ops.logical_not(scores_mask)\n            # Bias so padding positions do not contribute to attention\n            # distribution.  Note 65504. 
is the max float16 value.\n            max_value = 65504.0 if scores.dtype == \"float16\" else 1.0e9\n            if len(padding_mask.shape) == 2:\n                padding_mask = ops.expand_dims(padding_mask, axis=-2)\n            scores -= max_value * ops.cast(padding_mask, dtype=scores.dtype)\n\n        weights = ops.softmax(scores, axis=-1)\n        if training and self.dropout > 0:\n            weights = backend.random.dropout(\n                weights,\n                self.dropout,\n                seed=self.seed_generator,\n            )\n        return ops.matmul(weights, value), weights\n\n    def _calculate_score_mask(self, scores, v_mask, use_causal_mask):\n        if use_causal_mask:\n            # Creates a lower triangular mask, so position i cannot attend to\n            # positions j > i. This prevents the flow of information from the\n            # future into the past.\n            score_shape = ops.shape(scores)\n            # causal_mask_shape = [1, Tq, Tv].\n            mask_shape = (1, score_shape[-2], score_shape[-1])\n            ones_mask = ops.ones(shape=mask_shape, dtype=\"int32\")\n            row_index = ops.cumsum(ones_mask, axis=-2)\n            col_index = ops.cumsum(ones_mask, axis=-1)\n            causal_mask = ops.greater_equal(row_index, col_index)\n\n            if v_mask is not None:\n                # Mask of shape [batch_size, 1, Tv].\n                v_mask = ops.expand_dims(v_mask, axis=-2)\n                return ops.logical_and(v_mask, causal_mask)\n            return causal_mask\n        else:\n            # If not using causal mask, return the value mask as is,\n            # or None if the value mask is not provided.\n            return v_mask\n\n    def call(\n        self,\n        inputs,\n        mask=None,\n        training=False,\n        return_attention_scores=False,\n        use_causal_mask=False,\n    ):\n        self._validate_inputs(inputs=inputs, mask=mask)\n        self._return_attention_scores = 
return_attention_scores\n        q = inputs[0]\n        v = inputs[1]\n        k = inputs[2] if len(inputs) > 2 else v\n        q_mask = mask[0] if mask else None\n        v_mask = mask[1] if mask else None\n        scores = self._calculate_scores(query=q, key=k)\n        scores_mask = self._calculate_score_mask(\n            scores, v_mask, use_causal_mask\n        )\n        attention_output, attention_scores = self._apply_scores(\n            scores=scores, value=v, scores_mask=scores_mask, training=training\n        )\n        if q_mask is not None:\n            # Mask of shape [batch_size, Tq, 1].\n            q_mask = ops.expand_dims(q_mask, axis=-1)\n            attention_output *= ops.cast(q_mask, dtype=attention_output.dtype)\n        if return_attention_scores:\n            return (attention_output, attention_scores)\n        else:\n            return attention_output\n\n    def compute_mask(self, inputs, mask=None):\n        self._validate_inputs(inputs=inputs, mask=mask)\n        if mask is None or mask[0] is None:\n            return None\n        return ops.convert_to_tensor(mask[0])\n\n    def compute_output_shape(self, input_shape):\n        query_shape, value_shape, key_shape = input_shape\n        if key_shape is None:\n            key_shape = value_shape\n\n        output_shape = (*query_shape[:-1], value_shape[-1])\n        if self._return_attention_scores:\n            scores_shape = (query_shape[0], query_shape[1], key_shape[1])\n            return output_shape, scores_shape\n        return output_shape\n\n    def compute_output_spec(\n        self,\n        inputs,\n        mask=None,\n        return_attention_scores=False,\n        training=None,\n        use_causal_mask=False,\n    ):\n        # Validate and unpack inputs\n        self._validate_inputs(inputs, mask)\n        query = inputs[0]\n        value = inputs[1]\n        key = inputs[2] if len(inputs) > 2 else value\n\n        # Compute primary output shape\n        output_shape = 
self.compute_output_shape(\n            [query.shape, value.shape, key.shape]\n        )\n        output_spec = KerasTensor(output_shape, dtype=self.compute_dtype)\n\n        # Handle attention scores if requested\n        if self._return_attention_scores or return_attention_scores:\n            scores_shape = (\n                query.shape[0],\n                query.shape[1],\n                key.shape[1],\n            )  # (batch_size, Tq, Tv)\n            attention_scores_spec = KerasTensor(\n                scores_shape, dtype=self.compute_dtype\n            )\n            return (output_spec, attention_scores_spec)\n\n        return output_spec\n\n    def _validate_inputs(self, inputs, mask=None):\n        \"\"\"Validates arguments of the call method.\"\"\"\n        class_name = self.__class__.__name__\n        if not isinstance(inputs, list):\n            raise ValueError(\n                f\"{class_name} layer must be called on a list of inputs, \"\n                \"namely [query, value] or [query, value, key]. \"\n                f\"Received: inputs={inputs}.\"\n            )\n        if len(inputs) < 2 or len(inputs) > 3:\n            raise ValueError(\n                f\"{class_name} layer accepts inputs list of length 2 or 3, \"\n                \"namely [query, value] or [query, value, key]. \"\n                f\"Received length: {len(inputs)}.\"\n            )\n        if mask is not None:\n            if not isinstance(mask, list):\n                raise ValueError(\n                    f\"{class_name} layer mask must be a list, \"\n                    f\"namely [query_mask, value_mask]. Received: mask={mask}.\"\n                )\n            if len(mask) < 2 or len(mask) > 3:\n                raise ValueError(\n                    f\"{class_name} layer accepts mask list of length 2 or 3. 
\"\n                    f\"Received: inputs={inputs}, mask={mask}.\"\n                )\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"use_scale\": self.use_scale,\n            \"score_mode\": self.score_mode,\n            \"dropout\": self.dropout,\n        }\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/attention/attention_test.py",
    "content": "import numpy as np\n\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass AttentionTest(testing.TestCase):\n    def test_attention_basics(self):\n        # No scale, no concat.\n        self.run_layer_test(\n            layers.Attention,\n            init_kwargs={\n                \"score_mode\": \"dot\",\n                \"dropout\": 0.5,\n            },\n            input_shape=[(2, 3, 4), (2, 4, 4)],\n            expected_output_shape=(2, 3, 4),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=1,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n        # Scale and concat.\n        self.run_layer_test(\n            layers.Attention,\n            init_kwargs={\n                \"use_scale\": True,\n                \"score_mode\": \"concat\",\n                \"dropout\": 0.5,\n            },\n            input_shape=[(2, 3, 4), (2, 4, 4)],\n            expected_output_shape=(2, 3, 4),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=1,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n\n    def test_attention_correctness(self):\n        query = np.array([[[1.0, 0.0], [0.0, 1.0]]])\n        key = np.array([[[0.0, 1.0], [1.0, 0.0]]])\n        value = np.array([[[1.0, 2.0], [3.0, 4.0]]])\n\n        # Dot.\n        layer = layers.Attention(score_mode=\"dot\")\n        output, scores = layer(\n            [query, value, key],\n            return_attention_scores=True,\n        )\n        self.assertAllClose(\n            output,\n            [[[2.462, 3.462], [1.538, 2.538]]],\n            atol=1e-3,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n        
self.assertAllClose(\n            scores,\n            [[[0.269, 0.731], [0.731, 0.269]]],\n            atol=1e-3,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n        # Concat.\n        layer = layers.Attention(score_mode=\"concat\")\n        output, scores = layer(\n            [query, value, key],\n            return_attention_scores=True,\n        )\n        self.assertAllClose(\n            output,\n            [[[1.727, 2.727], [2.272, 3.272]]],\n            atol=1e-3,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n        self.assertAllClose(\n            scores,\n            [[[0.636, 0.363], [0.363, 0.636]]],\n            atol=1e-3,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n    def test_attention_with_mask(self):\n        layer = layers.Attention()\n        query = np.array([[[1.0, 0.0], [0.0, 1.0]]])\n        value = np.array([[[1.0, 1.0], [1.0, 1.0]]])\n        query_mask = np.array([[True, False]])\n        value_mask = np.array([[True, False]])\n        output, scores = layer(\n            [query, value],\n            mask=[query_mask, value_mask],\n            return_attention_scores=True,\n        )\n        self.assertAllClose(output, [[[1.0, 1.0], [0.0, 0.0]]])\n        self.assertAllClose(scores, [[[1.0, 0.0], [1.0, 0.0]]])\n\n    def test_attention_2D_mask_shape_mismatch(self):\n        layer = layers.Attention()\n        batch_size, Tq, Tv, dim = 2, 3, 4, 5\n        query = np.random.random((batch_size, Tq, dim)).astype(np.float32)\n        value = np.random.random((batch_size, Tv, dim)).astype(np.float32)\n        query_mask = np.array([[True, False, True], [True, False, True]])\n        value_mask = np.array(\n            [[True, False, True, True], [True, False, True, True]]\n        )\n        output, scores = layer(\n            [query, value],\n            mask=[query_mask, value_mask],\n            return_attention_scores=True,\n        )\n        
self.assertEqual(output.shape, (batch_size, Tq, dim))\n        self.assertEqual(scores.shape, (batch_size, Tq, Tv))\n\n    def test_attention_errors(self):\n        layer = layers.Attention()\n        tensor = np.array([[[1.0, 1.0], [1.0, 1.0]]])\n        with self.assertRaisesRegex(ValueError, \"must be called on a list\"):\n            layer(tensor)\n\n        with self.assertRaisesRegex(ValueError, \"length 2 or 3\"):\n            layer([tensor, tensor, tensor, tensor])\n\n        with self.assertRaisesRegex(ValueError, \"layer mask must be a list\"):\n            layer([tensor, tensor], mask=tensor)\n\n        with self.assertRaisesRegex(ValueError, \"length 2 or 3\"):\n            layer([tensor, tensor], mask=[tensor])\n\n    def test_attention_with_dropout(self):\n        query = np.array([[[1.0, 0.0], [0.0, 1.0]]])\n        value = np.array([[[1.0, 1.0], [1.0, 1.0]]])\n        layer_with_dropout = layers.Attention(dropout=0.2)\n        layer_without_dropout = layers.Attention()\n\n        output1, scores1 = layer_with_dropout(\n            [query, value], return_attention_scores=True, training=True\n        )\n        output2, scores2 = layer_without_dropout(\n            [query, value], return_attention_scores=True, training=True\n        )\n        self.assertNotAllClose(output1, output2)\n        self.assertNotAllClose(scores1, scores2)\n\n    def test_attention_invalid_score_mode(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value for argument score_mode. 
\"\n            \"Expected one of {'dot', 'concat'}\",\n        ):\n            layers.Attention(score_mode=\"invalid_mode\")\n\n    def test_attention_calculate_scores_with_scale(self):\n        query = np.random.random((2, 3, 4))\n        key = np.random.random((2, 4, 4))\n        layer = layers.Attention(use_scale=True, score_mode=\"dot\")\n        layer.build(input_shape=[(2, 3, 4), (2, 4, 4)])\n        expected_scores = np.matmul(query, key.transpose((0, 2, 1)))\n        expected_scores *= layer.scale.numpy()\n        actual_scores = layer._calculate_scores(query, key)\n        self.assertAllClose(\n            actual_scores, expected_scores, tpu_atol=1e-2, tpu_rtol=1e-2\n        )\n\n    def test_attention_calculate_score_mask_no_causal_no_vmask(self):\n        scores = np.random.random((2, 3, 4))\n        layer = layers.Attention()\n        mask = layer._calculate_score_mask(\n            scores, v_mask=None, use_causal_mask=False\n        )\n        self.assertIsNone(\n            mask,\n            \"Mask should be None when no causal mask and no value mask \"\n            \"are used\",\n        )\n\n    def test_attention_calculate_score_mask_with_causal_no_vmask(self):\n        scores = np.random.random((2, 3, 4))\n        layer = layers.Attention()\n\n        causal_mask = layer._calculate_score_mask(\n            scores, v_mask=None, use_causal_mask=True\n        )\n        expected_causal_mask = np.tril(\n            np.ones((1, scores.shape[1], scores.shape[2])), k=0\n        )\n        self.assertAllClose(causal_mask, expected_causal_mask, atol=1e-6)\n\n    def test_attention_calculate_score_mask_with_causal_and_vmask(self):\n        scores = np.random.random((2, 3, 4))\n        layer = layers.Attention()\n        v_mask = np.array([[True, False, True, False]])\n\n        combined_mask = layer._calculate_score_mask(\n            scores, v_mask=v_mask, use_causal_mask=True\n        )\n        expected_causal_mask = np.tril(\n            np.ones((1, 
scores.shape[1], scores.shape[2])), k=0\n        )\n        expected_combined_mask = np.logical_and(\n            expected_causal_mask, v_mask[:, np.newaxis, :]\n        )\n        self.assertAllClose(combined_mask, expected_combined_mask, atol=1e-6)\n\n    def test_attention_compute_mask_with_no_mask(self):\n        layer = layers.Attention()\n        dummy_inputs = [\n            np.random.random((2, 3, 4)),\n            np.random.random((2, 4, 4)),\n        ]\n        self.assertIsNone(\n            layer.compute_mask(inputs=dummy_inputs, mask=None),\n            \"compute_mask should return None when mask is None\",\n        )\n\n    def test_attention_compute_mask_with_first_element_none(self):\n        layer = layers.Attention()\n        dummy_inputs = [\n            np.random.random((2, 3, 4)),\n            np.random.random((2, 4, 4)),\n        ]\n        mask = [None, np.array([True, False, True])]\n        self.assertIsNone(\n            layer.compute_mask(inputs=dummy_inputs, mask=mask),\n            \"compute_mask should return None when the first element is None\",\n        )\n\n    def test_attention_compute_mask_does_not_return_none_with_valid_mask(self):\n        layer = layers.Attention()\n        dummy_inputs = [\n            np.random.random((2, 3, 4)),\n            np.random.random((2, 4, 4)),\n        ]\n        valid_mask = np.array([True, False, True])\n        mask = [valid_mask, np.array([False, True, False])]\n        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)\n        computed_mask = ops.convert_to_numpy(computed_mask)\n        self.assertIsNotNone(\n            computed_mask,\n            \"compute_mask should not return None with a valid mask\",\n        )\n\n    def test_attention_compute_mask_returns_correct_tensor_with_valid_mask(\n        self,\n    ):\n        layer = layers.Attention()\n        dummy_inputs = [\n            np.random.random((2, 3, 4)),\n            np.random.random((2, 4, 4)),\n        ]\n   
     valid_mask = np.array([True, False, True])\n        mask = [valid_mask, np.array([False, True, False])]\n        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)\n        computed_mask = ops.convert_to_numpy(computed_mask)\n        self.assertTrue(\n            np.array_equal(computed_mask, valid_mask),\n            \"compute_mask did not return the correct mask tensor\",\n        )\n\n    def test_attention_compute_mask_returns_correct_tensor_with_all_true_mask(\n        self,\n    ):\n        layer = layers.Attention()\n        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]\n        valid_mask = np.array([True, True, True])\n        mask = [valid_mask, np.array([True, True, True])]\n        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)\n        computed_mask = ops.convert_to_numpy(computed_mask)\n        expected_mask = np.array([True, True, True])\n        self.assertTrue(\n            np.array_equal(computed_mask, expected_mask),\n            \"compute_mask did not return the correct mask tensor\",\n        )\n\n    def test_attention_compute_mask_returns_correct_tensor_with_all_false_mask(\n        self,\n    ):\n        layer = layers.Attention()\n        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]\n        valid_mask = np.array([False, False, False])\n        mask = [valid_mask, np.array([False, False, False])]\n        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)\n        computed_mask = ops.convert_to_numpy(computed_mask)\n        expected_mask = np.array([False, False, False])\n        self.assertTrue(\n            np.array_equal(computed_mask, expected_mask),\n            \"compute_mask did not return the correct mask tensor\",\n        )\n\n    def test_attention_compute_mask_with_tolerance_1e_3(self):\n        layer = layers.Attention()\n        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]\n        valid_mask = np.array([1.0, 0.0, 1.0], dtype=float)\n      
  mask = [valid_mask, np.array([0.0, 1.0, 0.0], dtype=float)]\n        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)\n        computed_mask = ops.convert_to_numpy(computed_mask)\n        expected_mask = valid_mask\n        self.assertTrue(\n            np.allclose(computed_mask, expected_mask, atol=1e-3),\n            \"Incorrect mask tensor within tolerance 1e-3\",\n        )\n\n    def test_attention_compute_mask_with_tolerance_1e_5(self):\n        layer = layers.Attention()\n        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]\n        valid_mask = np.array([1.0, 0.0, 1.0], dtype=float)\n        mask = [valid_mask, np.array([0.0, 1.0, 0.0], dtype=float)]\n        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)\n        computed_mask = ops.convert_to_numpy(computed_mask)\n        expected_mask = valid_mask\n        self.assertTrue(\n            np.allclose(computed_mask, expected_mask, atol=1e-5),\n            \"Incorrect mask tensor within tolerance 1e-5\",\n        )\n\n    def test_attention_compute_mask_with_tolerance_1e_7(self):\n        layer = layers.Attention()\n        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]\n        valid_mask = np.array([1.0, 0.0, 1.0], dtype=float)\n        mask = [valid_mask, np.array([0.0, 1.0, 0.0], dtype=float)]\n        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)\n        computed_mask = ops.convert_to_numpy(computed_mask)\n        expected_mask = valid_mask\n        self.assertTrue(\n            np.allclose(computed_mask, expected_mask, atol=1e-7),\n            \"Incorrect mask tensor within tolerance 1e-7 \",\n        )\n\n    def test_attention_compute_mask_with_single_element_masks(self):\n        layer = layers.Attention()\n        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]\n        valid_mask = np.array([True])\n        mask = [valid_mask, np.array([False])]\n        computed_mask = layer.compute_mask(inputs=dummy_inputs, 
mask=mask)\n        computed_mask = ops.convert_to_numpy(computed_mask)\n        expected_shape = (1,)\n        self.assertEqual(computed_mask.shape, expected_shape)\n\n    def test_attention_compute_mask_with_non_boolean_masks(self):\n        layer = layers.Attention()\n        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]\n        valid_mask = np.array([1, 0, 1])\n        mask = [valid_mask, np.array([0, 1, 0])]\n        computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)\n        computed_mask = ops.convert_to_numpy(computed_mask)\n        self.assertTrue(np.array_equal(computed_mask, valid_mask))\n\n    def test_attention_compute_mask_with_edge_case_masks(self):\n        layer = layers.Attention()\n        dummy_inputs = [np.ones((2, 3, 4)), np.ones((2, 4, 4))]\n        edge_case_masks = [\n            np.array([True, True, True]),\n            np.array([False, False, False]),\n            np.array([True, False, True]),\n        ]\n        for mask in edge_case_masks:\n            computed_mask = layer.compute_mask(\n                inputs=dummy_inputs, mask=[mask, mask]\n            )\n            computed_mask = ops.convert_to_numpy(computed_mask)\n            self.assertTrue(np.array_equal(computed_mask, mask))\n\n    def test_attention_compute_mask_with_different_input_shapes(self):\n        layer = layers.Attention()\n        input_shapes = [(2, 3, 4), (3, 2, 5), (4, 1, 6)]\n        valid_mask = np.array([True, False, True])\n        for shape in input_shapes:\n            dummy_inputs = [np.ones(shape), np.ones(shape)]\n            mask = [valid_mask, np.array([False, True, False])]\n            computed_mask = layer.compute_mask(inputs=dummy_inputs, mask=mask)\n            computed_mask = ops.convert_to_numpy(computed_mask)\n            self.assertTrue(np.array_equal(computed_mask, valid_mask))\n\n    def test_attention_compute_output_shape(self):\n        layer = layers.Attention()\n\n        query = np.random.random((2, 3, 
4))\n        value = np.random.random((2, 3, 5))\n        key = np.random.random((2, 3, 4))\n        layer = layers.Attention()\n        output = layer([query, value, key])\n        self.assertAllEqual(output.shape, value.shape)\n        self.assertAllEqual(\n            layer.compute_output_shape(\n                input_shape=[query.shape, value.shape, key.shape]\n            ),\n            output.shape,\n        )\n\n    def test_return_attention_scores_true(self):\n        \"\"\"Test that the layer returns attention scores along with outputs.\"\"\"\n        # Generate dummy input data\n        query = np.random.random((2, 8, 16)).astype(np.float32)\n        value = np.random.random((2, 4, 16)).astype(np.float32)\n\n        # Initialize the Attention layer\n        layer = layers.Attention()\n\n        # Call the layer with return_attention_scores=True\n        output, attention_scores = layer(\n            [query, value], return_attention_scores=True\n        )\n\n        # Check the shape of the outputs\n        self.assertEqual(output.shape, (2, 8, 16))  # Output shape\n        self.assertEqual(\n            attention_scores.shape, (2, 8, 4)\n        )  # Attention scores shape\n\n    def test_return_attention_scores_true_and_tuple(self):\n        \"\"\"Test that the layer outputs are a tuple when\n        return_attention_scores=True.\"\"\"\n        # Generate dummy input data\n        query = np.random.random((2, 8, 16)).astype(np.float32)\n        value = np.random.random((2, 4, 16)).astype(np.float32)\n\n        # Initialize the Attention layer\n        layer = layers.Attention()\n\n        # Call the layer with return_attention_scores=True\n        outputs = layer([query, value], return_attention_scores=True)\n\n        # Check that outputs is a tuple\n        self.assertIsInstance(\n            outputs, tuple, \"Expected the outputs to be a tuple\"\n        )\n\n    def test_return_attention_scores_true_tuple_then_unpack(self):\n        \"\"\"Test that 
outputs can be unpacked correctly.\"\"\"\n        # Generate dummy input data\n        query = np.random.random((2, 8, 16)).astype(np.float32)\n        value = np.random.random((2, 4, 16)).astype(np.float32)\n\n        # Initialize the Attention layer\n        layer = layers.Attention()\n\n        # Call the layer with return_attention_scores=True\n        outputs = layer([query, value], return_attention_scores=True)\n\n        # Unpack the outputs\n        output, attention_scores = outputs\n\n        # Check the shape of the unpacked outputs\n        self.assertEqual(output.shape, (2, 8, 16))  # Output shape\n        self.assertEqual(\n            attention_scores.shape, (2, 8, 4)\n        )  # Attention scores shape\n\n    def test_return_attention_scores_with_symbolic_tensors(self):\n        \"\"\"Test to check outputs with symbolic tensors with\n        return_attention_scores = True\"\"\"\n        attention = layers.Attention()\n        x = layers.Input(shape=(3, 5))\n        y = layers.Input(shape=(4, 5))\n        output, attention_scores = attention(\n            [x, y], return_attention_scores=True\n        )\n        self.assertEqual(output.shape, (None, 3, 5))  # Output shape\n        self.assertEqual(attention_scores.shape, (None, 3, 4))\n"
  },
  {
    "path": "keras/src/layers/attention/grouped_query_attention.py",
    "content": "import math\n\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.config import is_flash_attention_enabled\nfrom keras.src.layers.activations.softmax import Softmax\nfrom keras.src.layers.core.einsum_dense import EinsumDense\nfrom keras.src.layers.layer import Layer\nfrom keras.src.layers.regularization.dropout import Dropout\n\n\n@keras_export(\"keras.layers.GroupQueryAttention\")\nclass GroupedQueryAttention(Layer):\n    \"\"\"Grouped Query Attention layer.\n\n    This is an implementation of grouped-query attention introduced by\n    [Ainslie et al., 2023](https://arxiv.org/abs/2305.13245). Here\n    `num_key_value_heads` denotes number of groups, setting\n    `num_key_value_heads` to 1 is equivalent to multi-query attention, and\n    when `num_key_value_heads` is equal to `num_query_heads` it is equivalent\n    to multi-head attention.\n\n    This layer first projects `query`, `key`, and `value` tensors. Then, `key`\n    and `value` are repeated to match the number of heads of `query`.\n\n    Then, the `query` is scaled and dot-producted with `key` tensors. These are\n    softmaxed to obtain attention probabilities. The value tensors are then\n    interpolated by these probabilities and concatenated back to a single\n    tensor.\n\n    Args:\n        head_dim: Size of each attention head.\n        num_query_heads: Number of query attention heads.\n        num_key_value_heads: Number of key and value attention heads.\n        dropout: Dropout probability.\n        use_bias: Boolean, whether the dense layers use bias vectors/matrices.\n        flash_attention: If `None`, the layer attempts to use flash\n            attention for faster and more memory-efficient attention\n            computations when possible. 
This behavior can be configured using\n            `keras.config.enable_flash_attention()` or\n            `keras.config.disable_flash_attention()`.\n        kernel_initializer: Initializer for dense layer kernels.\n        bias_initializer: Initializer for dense layer biases.\n        kernel_regularizer: Regularizer for dense layer kernels.\n        bias_regularizer: Regularizer for dense layer biases.\n        activity_regularizer: Regularizer for dense layer activity.\n        kernel_constraint: Constraint for dense layer kernels.\n        bias_constraint: Constraint for dense layer biases.\n        use_gate: Boolean, whether to apply a gated attention mechanism.\n            When True, an additional gating branch is added based on\n            [Gated Attention for Large Language Models](https://arxiv.org/abs/2505.06708).\n            It applies a sigmoid-activated linear projection to the query,\n            which then gates the attention output. This helps improve training\n            stability and eliminates \"attention sinks\".\n        seed: Optional integer to seed the dropout layer.\n\n    Call arguments:\n        query: Query tensor of shape `(batch_dim, target_seq_len, feature_dim)`,\n            where `batch_dim` is the batch size, `target_seq_len` is the\n            target sequence length, and `feature_dim` is the feature dimension.\n        value: Value tensor of shape `(batch_dim, source_seq_len, feature_dim)`,\n            where `batch_dim` is the batch size, `source_seq_len` is the\n            source sequence length, and `feature_dim` is the feature dimension.\n        key: Optional key tensor of shape\n            `(batch_dim, source_seq_len, feature_dim)`. If not given, `value`\n            will be used for both `key` and `value`, which is the most common\n            case.\n        attention_mask: A boolean mask of shape\n            `(batch_dim, target_seq_len, source_seq_len)`, that prevents\n            attention to certain positions. 
The boolean mask specifies which\n            query elements can attend to which key elements, where 1 indicates\n            attention and 0 indicates no attention. Broadcasting can happen for\n            the missing batch dimensions and the head dimension.\n        return_attention_scores: A boolean to indicate whether the output\n            should be `(attention_output, attention_scores)` if `True`, or\n            `attention_output` if `False`. Defaults to `False`.\n        training: Python boolean indicating whether the layer should behave in\n            training mode (adding dropout) or in inference mode (no dropout).\n            If unspecified, this defaults to the training mode of the parent\n            layer/model, or `False` (inference) if there is no parent layer.\n        use_causal_mask: A boolean to indicate whether to apply a causal mask to\n            prevent tokens from attending to future tokens (e.g., used in a\n            decoder Transformer).\n\n    Returns:\n        attention_output: Result of the computation, of shape\n            `(batch_dim, target_seq_len, feature_dim)`, where `target_seq_len`\n            is the target sequence length and `feature_dim` is the last\n            dimension of the query input.\n        attention_scores: (Optional) attention coefficients of shape\n            `(batch_dim, num_query_heads, target_seq_len, source_seq_len)`.\n    \"\"\"\n\n    def __init__(\n        self,\n        head_dim,\n        num_query_heads,\n        num_key_value_heads,\n        dropout=0.0,\n        use_bias=True,\n        flash_attention=None,\n        kernel_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        bias_constraint=None,\n        use_gate=False,\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.supports_masking = True\n        
self.head_dim = head_dim\n        self.num_query_heads = num_query_heads\n        self.num_key_value_heads = num_key_value_heads\n        if num_query_heads % num_key_value_heads != 0:\n            raise ValueError(\n                \"`num_query_heads` must be divisible by `num_key_value_heads`.\"\n            )\n        self.num_repeats = num_query_heads // num_key_value_heads\n        self.dropout = dropout\n        self.use_bias = use_bias\n        self.use_gate = use_gate\n        self._flash_attention = flash_attention or is_flash_attention_enabled()\n        self.kernel_initializer = initializers.get(kernel_initializer)\n        self.bias_initializer = initializers.get(bias_initializer)\n        self.kernel_regularizer = regularizers.get(kernel_regularizer)\n        self.bias_regularizer = regularizers.get(bias_regularizer)\n        self.activity_regularizer = regularizers.get(activity_regularizer)\n        self.kernel_constraint = constraints.get(kernel_constraint)\n        self.bias_constraint = constraints.get(bias_constraint)\n        self.seed = seed\n\n        self._inverse_sqrt_head_dim = 1.0 / math.sqrt(float(self.head_dim))\n        self._return_attention_scores = False\n\n        # Check for flash attention constraints\n        if self._flash_attention and self.dropout > 0.0:\n            raise ValueError(\n                \"Dropout is not supported when flash attention is enabled. 
\"\n                \"Please set dropout to 0.0 to use flash attention.\"\n            )\n\n    def build(\n        self,\n        query_shape,\n        value_shape,\n        key_shape=None,\n    ):\n        # Einsum variables:\n        # b = batch size\n        # q = query length\n        # k = key/value length\n        # m = model dim\n        # u = num query heads\n        # v = num key/value heads\n        # h = head dim\n        key_shape = value_shape if key_shape is None else key_shape\n        self.feature_dim = query_shape[-1]\n        self._query_dense = EinsumDense(\n            \"bqm,muh->bquh\",\n            output_shape=(None, self.num_query_heads, self.head_dim),\n            bias_axes=\"uh\" if self.use_bias else None,\n            name=\"query\",\n            **self._get_common_kwargs_for_sublayer(),\n        )\n        self._query_dense.build(query_shape)\n\n        self._key_dense = EinsumDense(\n            \"bkm,mvh->bkvh\",\n            output_shape=(None, self.num_key_value_heads, self.head_dim),\n            bias_axes=\"vh\" if self.use_bias else None,\n            name=\"key\",\n            **self._get_common_kwargs_for_sublayer(),\n        )\n        self._key_dense.build(key_shape)\n        if self.use_gate:\n            self._gate_dense = EinsumDense(\n                \"bqm,muh->bquh\",\n                output_shape=(None, self.num_query_heads, self.head_dim),\n                bias_axes=\"uh\" if self.use_bias else None,\n                activation=\"sigmoid\",\n                name=\"gate\",\n                **self._get_common_kwargs_for_sublayer(),\n            )\n            self._gate_dense.build(query_shape)\n        self._value_dense = EinsumDense(\n            \"bkm,mvh->bkvh\",\n            output_shape=(None, self.num_key_value_heads, self.head_dim),\n            bias_axes=\"vh\" if self.use_bias else None,\n            name=\"value\",\n            **self._get_common_kwargs_for_sublayer(),\n        )\n        
self._value_dense.build(value_shape)\n\n        self._softmax = Softmax(axis=-1, dtype=self.dtype_policy)\n        self._dropout_layer = Dropout(\n            rate=self.dropout, dtype=self.dtype_policy, seed=self.seed\n        )\n\n        self._dot_product_equation = \"bquh,bkuh->buqk\"\n        self._combine_equation = \"buqk,bkuh->bquh\"\n\n        self._output_dense = EinsumDense(\n            \"bquh,uhm->bqm\",\n            output_shape=(None, self.feature_dim),\n            bias_axes=\"m\" if self.use_bias else None,\n            name=\"attention_output\",\n            **self._get_common_kwargs_for_sublayer(),\n        )\n        self._output_dense.build(\n            (None, None, self.num_query_heads, self.head_dim)\n        )\n\n    def _get_common_kwargs_for_sublayer(self):\n        common_kwargs = dict(\n            kernel_regularizer=self.kernel_regularizer,\n            bias_regularizer=self.bias_regularizer,\n            activity_regularizer=self.activity_regularizer,\n            kernel_constraint=self.kernel_constraint,\n            bias_constraint=self.bias_constraint,\n            dtype=self.dtype_policy,\n        )\n        # Create new clone of kernel/bias initializer, so that we don't reuse\n        # the initializer instance, which could lead to same init value since\n        # initializer is stateless.\n        kernel_initializer = self.kernel_initializer.__class__.from_config(\n            self.kernel_initializer.get_config()\n        )\n        bias_initializer = self.bias_initializer.__class__.from_config(\n            self.bias_initializer.get_config()\n        )\n        common_kwargs[\"kernel_initializer\"] = kernel_initializer\n        common_kwargs[\"bias_initializer\"] = bias_initializer\n        return common_kwargs\n\n    def call(\n        self,\n        query,\n        value,\n        key=None,\n        query_mask=None,\n        value_mask=None,\n        key_mask=None,\n        attention_mask=None,\n        
return_attention_scores=False,\n        training=None,\n        use_causal_mask=False,\n    ):\n        self._return_attention_scores = return_attention_scores\n        if key is None:\n            key = value\n\n        attention_mask = self._compute_attention_mask(\n            query,\n            value,\n            query_mask=query_mask,\n            value_mask=value_mask,\n            key_mask=key_mask,\n            attention_mask=attention_mask,\n            use_causal_mask=use_causal_mask,\n        )\n        if self.use_gate:\n            gate = self._gate_dense(query)\n        query = self._query_dense(query)\n        key = self._key_dense(key)\n        value = self._value_dense(value)\n\n        key = ops.repeat(\n            key, self.num_repeats, axis=2\n        )  # (batch_dim, source_seq_len, query_heads, head_dim)\n        value = ops.repeat(\n            value, self.num_repeats, axis=2\n        )  # (batch_dim, source_seq_len, query_heads, head_dim)\n\n        output, scores = self._compute_attention(\n            query,\n            key,\n            value,\n            attention_mask=attention_mask,\n            training=training,\n        )\n        # (batch_dim, target_seq_len, feature_dim)\n        if self.use_gate:\n            output = self._output_dense(ops.multiply(output, gate))\n        else:\n            output = self._output_dense(output)\n\n        if return_attention_scores:\n            return output, scores\n        return output\n\n    def _compute_attention_mask(\n        self,\n        query,\n        value,\n        query_mask=None,\n        value_mask=None,\n        key_mask=None,\n        attention_mask=None,\n        use_causal_mask=False,\n    ):\n        \"\"\"Computes the attention mask, using the Keras masks of the inputs.\n\n        * The `query`'s mask is reshaped from [B, T] to [B, T, 1].\n        * The `value`'s mask is reshaped from [B, S] to [B, 1, S].\n        * The `key`'s mask is reshaped from [B, S] to [B, 1, 
S]. The `key`'s\n          mask is ignored if `key` is `None` or if `key is value`.\n        * If `use_causal_mask=True`, then the causal mask is computed. Its shape\n          is [1, T, S].\n\n        All defined masks are merged using a logical AND operation (`&`).\n\n        In general, if the `query` and `value` are masked, then there is no need\n        to define the `attention_mask`.\n\n        Args:\n            query: Query tensor of shape `(B, T, dim)`, as passed to `call`\n                (before projection).\n            value: Value tensor of shape `(B, S, dim)`.\n            query_mask: Optional boolean mask of shape `(B, T)`.\n            value_mask: Optional boolean mask of shape `(B, S)`.\n            key_mask: Optional boolean mask of shape `(B, S)`.\n            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents\n                attention to certain positions.\n            use_causal_mask: A boolean to indicate whether to apply a causal\n                mask to prevent tokens from attending to future tokens (e.g.,\n                used in a decoder Transformer).\n\n        Returns:\n            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents\n                attention to certain positions, based on the Keras masks of the\n                `query`, `key`, `value`, and `attention_mask` tensors, and the\n                causal mask if `use_causal_mask=True`.\n        \"\"\"\n        auto_mask = None\n        if query_mask is not None:\n            query_mask = ops.cast(query_mask, \"bool\")  # defensive casting\n            # B = batch size, T = max query length\n            auto_mask = ops.expand_dims(query_mask, -1)  # shape is [B, T, 1]\n        if value_mask is not None:\n            value_mask = ops.cast(value_mask, \"bool\")  # defensive casting\n            # B = batch size, S == max value length\n            mask = ops.expand_dims(value_mask, -2)  # shape is [B, 1, S]\n            auto_mask = mask if auto_mask is None else auto_mask & mask\n        if key_mask is not None:\n            key_mask = ops.cast(key_mask, \"bool\")  # defensive casting\n        
    # B == batch size, S == max key length == max value length\n            mask = ops.expand_dims(key_mask, -2)  # shape is [B, 1, S]\n            auto_mask = mask if auto_mask is None else auto_mask & mask\n        if use_causal_mask:\n            # the shape of the causal mask is [1, T, S]\n            mask = self._compute_causal_mask(query, value)\n            auto_mask = mask if auto_mask is None else auto_mask & mask\n        if auto_mask is not None:\n            # merge attention_mask & automatic mask, to shape [B, T, S]\n            attention_mask = (\n                auto_mask\n                if attention_mask is None\n                else ops.cast(attention_mask, bool) & auto_mask\n            )\n        return attention_mask\n\n    def _compute_causal_mask(self, query, value=None):\n        \"\"\"Computes a causal mask (e.g., for masked self-attention layers).\n\n        For example, if query and value both contain sequences of length 4,\n        this function returns a boolean tensor equal to:\n\n        ```\n        [[[True,  False, False, False],\n          [True,  True,  False, False],\n          [True,  True,  True,  False],\n          [True,  True,  True,  True]]]\n        ```\n\n        Args:\n            query: query tensor of shape `(B, T, ...)`.\n            value: value tensor of shape `(B, S, ...)` (optional, defaults to\n                query).\n\n        Returns:\n            mask: a boolean tensor of shape `(1, T, S)` containing a lower\n                triangular matrix of shape `(T, S)`.\n        \"\"\"\n        q_seq_length = ops.shape(query)[1]\n        v_seq_length = q_seq_length if value is None else ops.shape(value)[1]\n        ones_mask = ops.ones((1, q_seq_length, v_seq_length), dtype=\"int32\")\n        row_index = ops.cumsum(ones_mask, axis=-2)\n        col_index = ops.cumsum(ones_mask, axis=-1)\n        return ops.greater_equal(row_index, col_index)\n\n    def _compute_attention(\n        self, query, key, value, 
attention_mask=None, training=None\n    ):\n        # Check for flash attention constraints\n        if self._flash_attention and self._return_attention_scores:\n            raise ValueError(\n                \"Returning attention scores is not supported when flash \"\n                \"attention is enabled. Please disable flash attention to access\"\n                \" attention scores.\"\n            )\n\n        # Determine whether to use dot-product attention\n        use_dot_product_attention = not (\n            self.dropout > 0.0\n            or self._return_attention_scores\n            or (len(query.shape) != 4)\n        )\n\n        if use_dot_product_attention:\n            if attention_mask is not None:\n                # Ensure attention_mask has the correct shape for broadcasting\n                # Expected shape: [batch_size, num_heads, query_seq_len,\n                # key_seq_len].\n                mask_expansion_axis = -1 * 2 - 1\n                len_attention_scores_shape = 4  # Only accepts 4D inputs\n                for _ in range(\n                    len_attention_scores_shape - len(attention_mask.shape)\n                ):\n                    attention_mask = ops.expand_dims(\n                        attention_mask, axis=mask_expansion_axis\n                    )\n                attention_mask = ops.cast(attention_mask, dtype=\"bool\")\n            # Directly compute the attention output using dot-product attention\n            attention_output = ops.dot_product_attention(\n                query=query,\n                key=key,\n                value=value,\n                bias=None,\n                mask=attention_mask,\n                scale=self._inverse_sqrt_head_dim,\n                is_causal=False,\n                flash_attention=self._flash_attention,\n            )\n            return attention_output, None\n\n        # Default behavior without flash attention, with explicit attention\n        # scores\n        query = 
ops.multiply(\n            query, ops.cast(self._inverse_sqrt_head_dim, query.dtype)\n        )\n        # Take the dot product between \"query\" and \"key\" to get the raw\n        # attention scores.\n        scores = ops.einsum(\n            self._dot_product_equation, query, key\n        )  # (batch_dim, query_heads, target_seq_len, source_seq_len)\n        scores = self._masked_softmax(scores, attention_mask=attention_mask)\n        # This is actually dropping out entire tokens to attend to, which might\n        # seem a bit unusual, but is taken from the original Transformer paper.\n        if self.dropout > 0.0:\n            scores_dropout = self._dropout_layer(scores, training=training)\n        else:\n            scores_dropout = scores\n        output = ops.einsum(self._combine_equation, scores_dropout, value)\n        return output, scores\n\n    def _masked_softmax(self, scores, attention_mask=None):\n        # Normalize the attention scores to probabilities.\n        # scores = [B, N, T, S]\n        if attention_mask is not None:\n            # The expand dim happens starting from the `num_heads` dimension,\n            # (<batch_dims>, num_heads, <query_attention_dims,\n            # key_attention_dims>)\n            mask_expansion_axis = -1 * 2 - 1\n            for _ in range(len(scores.shape) - len(attention_mask.shape)):\n                attention_mask = ops.expand_dims(\n                    attention_mask, axis=mask_expansion_axis\n                )\n        return self._softmax(scores, mask=attention_mask)\n\n    def compute_output_shape(\n        self,\n        query_shape,\n        value_shape,\n        key_shape=None,\n    ):\n        if key_shape is None:\n            key_shape = value_shape\n\n        if query_shape[-1] != value_shape[-1]:\n            raise ValueError(\n                \"The last dimension of `query_shape` and `value_shape` \"\n                f\"must be equal, but are {query_shape[-1]}, {value_shape[-1]}. 
\"\n                f\"Received: query_shape={query_shape}, \"\n                f\"value_shape={value_shape}\"\n            )\n\n        if value_shape[1:-1] != key_shape[1:-1]:\n            raise ValueError(\n                \"All dimensions of `value` and `key`, except the last one, \"\n                f\"must be equal. Received: value_shape={value_shape} and \"\n                f\"key_shape={key_shape}\"\n            )\n\n        return query_shape\n\n    def get_config(self):\n        config = {\n            \"head_dim\": self.head_dim,\n            \"num_query_heads\": self.num_query_heads,\n            \"num_key_value_heads\": self.num_key_value_heads,\n            \"use_bias\": self.use_bias,\n            \"use_gate\": self.use_gate,\n            \"dropout\": self.dropout,\n            \"kernel_initializer\": initializers.serialize(\n                self.kernel_initializer\n            ),\n            \"bias_initializer\": initializers.serialize(self.bias_initializer),\n            \"kernel_regularizer\": regularizers.serialize(\n                self.kernel_regularizer\n            ),\n            \"bias_regularizer\": regularizers.serialize(self.bias_regularizer),\n            \"activity_regularizer\": regularizers.serialize(\n                self.activity_regularizer\n            ),\n            \"kernel_constraint\": constraints.serialize(self.kernel_constraint),\n            \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
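The grouped-query mechanism described in the layer's docstring above (project, repeat key/value heads to match the query heads, scaled dot product, softmax, weighted sum) can be sketched at the shape level in pure NumPy. This is an illustrative sketch, not the layer's `EinsumDense`-based implementation; `grouped_query_attention` is a hypothetical helper, though the einsum subscripts deliberately match the layer's `_dot_product_equation` and `_combine_equation`:

```python
import numpy as np

def grouped_query_attention(q, k, v, num_query_heads, num_key_value_heads):
    """q: (batch, Tq, num_query_heads, head_dim);
    k, v: (batch, Tv, num_key_value_heads, head_dim)."""
    num_repeats = num_query_heads // num_key_value_heads
    # Repeat each key/value head so every query head has a partner,
    # mirroring `ops.repeat(..., axis=2)` in the layer's `call`.
    k = np.repeat(k, num_repeats, axis=2)
    v = np.repeat(v, num_repeats, axis=2)
    # Scaled dot product per head: scores have shape (batch, heads, Tq, Tv).
    head_dim = q.shape[-1]
    scores = np.einsum("bquh,bkuh->buqk", q, k) / np.sqrt(head_dim)
    # Softmax over the key axis.
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    probs = e / e.sum(axis=-1, keepdims=True)
    # Weighted sum of values: back to (batch, Tq, heads, head_dim).
    return np.einsum("buqk,bkuh->bquh", probs, v)

batch, Tq, Tv, head_dim = 2, 8, 4, 16
q = np.random.random((batch, Tq, 8, head_dim))
k = np.random.random((batch, Tv, 2, head_dim))
v = np.random.random((batch, Tv, 2, head_dim))
out = grouped_query_attention(q, k, v, num_query_heads=8, num_key_value_heads=2)
print(out.shape)  # (2, 8, 8, 16)
```

With `num_key_value_heads=1` this reduces to multi-query attention, and with `num_key_value_heads=8` (equal to `num_query_heads`) the repeat is a no-op and it is ordinary multi-head attention. The layer's output projection then folds the per-head axis back to `feature_dim`.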
  {
    "path": "keras/src/layers/attention/grouped_query_attention_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.backend.config import disable_flash_attention\nfrom keras.src.backend.config import enable_flash_attention\nfrom keras.src.backend.config import is_flash_attention_enabled\n\n\nclass GroupedQueryAttentionTest(testing.TestCase):\n    def setUp(self):\n        super().setUp()\n        # Flash attention is a newly introduced feature. We need to disable it\n        # for testing purposes.\n        disable_flash_attention()\n\n    def tearDown(self):\n        enable_flash_attention()\n        return super().tearDown()\n\n    def test_basics(self):\n        self.assertFalse(is_flash_attention_enabled())\n        self.run_layer_test(\n            layers.GroupedQueryAttention,\n            init_kwargs={\n                \"num_query_heads\": 2,\n                \"num_key_value_heads\": 2,\n                \"head_dim\": 2,\n            },\n            input_shape={\"query_shape\": (2, 8, 16), \"value_shape\": (2, 4, 16)},\n            expected_output_shape=(2, 8, 16),\n            expected_num_trainable_weights=8,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n\n        self.run_layer_test(\n            layers.GroupedQueryAttention,\n            init_kwargs={\n                \"num_query_heads\": 2,\n                \"num_key_value_heads\": 2,\n                \"head_dim\": 2,\n                \"use_bias\": False,\n                \"dropout\": 0.5,\n            },\n            input_shape={\"query_shape\": (2, 8, 16), \"value_shape\": (2, 4, 16)},\n            expected_output_shape=(2, 8, 16),\n            expected_num_trainable_weights=4,\n            
expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=1,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n\n        self.run_layer_test(\n            layers.GroupedQueryAttention,\n            init_kwargs={\n                \"num_query_heads\": 2,\n                \"num_key_value_heads\": 2,\n                \"head_dim\": 2,\n                \"use_gate\": True,\n            },\n            input_shape={\"query_shape\": (2, 8, 16), \"value_shape\": (2, 4, 16)},\n            expected_output_shape=(2, 8, 16),\n            expected_num_trainable_weights=10,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n\n        self.run_layer_test(\n            layers.GroupedQueryAttention,\n            init_kwargs={\n                \"num_query_heads\": 2,\n                \"num_key_value_heads\": 2,\n                \"head_dim\": 2,\n                \"use_bias\": False,\n                \"dropout\": 0.5,\n                \"use_gate\": True,\n            },\n            input_shape={\"query_shape\": (2, 8, 16), \"value_shape\": (2, 4, 16)},\n            expected_output_shape=(2, 8, 16),\n            expected_num_trainable_weights=5,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=1,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() not in (\"jax\", \"torch\"),\n        reason=\"Flash attention only supported on JAX and Torch\",\n    )\n    def test_basics_with_flash_attention(self):\n        enable_flash_attention()\n        init_kwargs = {\n            \"num_query_heads\": 2,\n            \"num_key_value_heads\": 2,\n            \"head_dim\": 
8,\n            \"dtype\": \"float16\",\n        }\n        input_shape = {\n            \"query_shape\": (2, 8, 16),\n            \"value_shape\": (2, 4, 16),\n        }\n        expected_output_shape = (2, 8, 16)\n        if backend.backend() == \"torch\":\n            try:\n                self.run_layer_test(\n                    layers.GroupedQueryAttention,\n                    init_kwargs=init_kwargs,\n                    input_shape=input_shape,\n                    expected_output_shape=expected_output_shape,\n                    expected_num_trainable_weights=8,\n                    expected_num_non_trainable_weights=0,\n                    expected_num_seed_generators=0,\n                    expected_num_losses=0,\n                    supports_masking=True,\n                    run_training_check=False,\n                )\n            except ImportError as e:\n                if \"Flash attention is not supported\" in str(e.args[0]):\n                    self.assertTrue(\n                        (\n                            \"Flash attention is not supported in your current \"\n                            \"PyTorch version.\"\n                        )\n                        in str(e.args[0])\n                    )\n            except RuntimeError as e:\n                if (\n                    \"Flash attention is not supported with the provided inputs\"\n                    in str(e.args[0])\n                ):\n                    self.assertTrue(\n                        (\n                            \"Flash attention is not supported with the \"\n                            \"provided inputs\"\n                        )\n                        in str(e.args[0])\n                    )\n        elif backend.backend() == \"jax\":\n            try:\n                self.run_layer_test(\n                    layers.GroupedQueryAttention,\n                    init_kwargs=init_kwargs,\n                    input_shape=input_shape,\n                    
expected_output_shape=expected_output_shape,\n                    expected_num_trainable_weights=8,\n                    expected_num_non_trainable_weights=0,\n                    expected_num_seed_generators=0,\n                    expected_num_losses=0,\n                    supports_masking=True,\n                    run_training_check=False,\n                )\n            except ImportError as e:\n                if \"Flash attention is not supported\" in str(e.args[0]):\n                    self.assertTrue(\n                        (\n                            \"Flash attention is not supported in your current \"\n                            \"JAX version.\"\n                        )\n                        in str(e.args[0])\n                    )\n            except RuntimeError as e:\n                if \"cuDNN\" in str(e.args[0]):\n                    self.assertTrue(\"cuDNN is not detected.\" in str(e.args[0]))\n                elif \"Require at least\" in str(e.args[0]):\n                    self.assertTrue(\n                        \"Require at least Ampere arch to run\" in str(e.args[0])\n                    )\n                elif \"Flash attention\" in str(e.args[0]):\n                    self.assertTrue(\n                        (\n                            \"Flash attention is not supported in your current \"\n                            \"JAX version.\"\n                        )\n                        in str(e.args[0])\n                    )\n\n    @parameterized.named_parameters(\n        (\"without_key_proj_mha\", (4, 8), (2, 8), None, 2, 2),\n        (\"with_key_proj_mha\", (4, 8), (2, 8), (2, 3), 2, 2),\n        (\"without_key_proj_gqa\", (4, 8), (2, 8), None, 4, 2),\n        (\"with_key_proj_gqa\", (4, 8), (2, 8), (2, 3), 4, 2),\n        (\"without_key_value_proj_mqa\", (4, 8), (2, 8), None, 4, 1),\n        (\"with_key_value_proj_mqa\", (4, 8), (2, 8), (2, 3), 4, 1),\n    )\n    def test_compute_output_shape(\n        self,\n        
query_dims,\n        value_dims,\n        key_dims,\n        num_query_heads,\n        num_key_value_heads,\n    ):\n        \"\"\"Test computed shape is equal to the layer output's shape.\"\"\"\n        layer = layers.GroupedQueryAttention(\n            num_query_heads=num_query_heads,\n            num_key_value_heads=num_key_value_heads,\n            head_dim=2,\n        )\n        batch_size = 7\n        query_shape = (batch_size,) + query_dims\n        value_shape = (batch_size,) + value_dims\n        key_shape = (batch_size,) + key_dims if key_dims else None\n\n        query = np.ones(query_shape)\n        value = np.ones(value_shape)\n        key = np.ones(key_shape) if key_shape else None\n        output = layer(query=query, value=value, key=key)\n        comp_output_shape = layer.compute_output_shape(\n            query_shape, value_shape, key_shape\n        )\n        self.assertEqual(output.shape, comp_output_shape)\n\n    @parameterized.named_parameters(\n        (\"query_value_dim_mismatch\", (2, 4, 8), (2, 2, 7), 2),\n        (\"key_value_dim_mismatch\", (2, 4, 8), (2, 2, 8), (2, 1, 7)),\n    )\n    def test_shape_mismatch_error(self, query_shape, value_shape, key_shape):\n        \"\"\"Test dimension mismatches\"\"\"\n        layer = layers.GroupedQueryAttention(\n            num_query_heads=4,\n            num_key_value_heads=4,\n            head_dim=2,\n        )\n        with self.assertRaisesRegex(ValueError, r\"must be equal\"):\n            layer.compute_output_shape(query_shape, value_shape, key_shape)\n\n    def test_initializer(self):\n        # Test with a specified initializer.\n        layer = layers.GroupedQueryAttention(\n            num_query_heads=16,\n            num_key_value_heads=16,\n            head_dim=64,\n            use_gate=True,\n            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),\n        )\n        layer.build((2, 4, 8), (2, 4, 8))\n\n        # Make sure the sub layers have different kernel init 
value.\n        self.assertNotAllClose(\n            layer._query_dense.kernel,\n            layer._key_dense.kernel,\n        )\n        self.assertNotAllClose(\n            layer._query_dense.kernel,\n            layer._value_dense.kernel,\n        )\n        self.assertNotAllClose(\n            layer._query_dense.kernel,\n            layer._output_dense.kernel,\n        )\n\n        self.assertNotAllClose(\n            layer._query_dense.kernel,\n            layer._gate_dense.kernel,\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\",\n        reason=\"Numpy backend does not support masking.\",\n    )\n    def test_query_mask_propagation(self):\n        \"\"\"Test automatic propagation of the query's mask.\"\"\"\n        layer = layers.GroupedQueryAttention(\n            num_query_heads=2, num_key_value_heads=2, head_dim=2\n        )\n        self.assertTrue(layer.supports_masking)\n        query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])\n        masked_query = layers.Embedding(4, 8, mask_zero=True)(query)\n        value = np.random.normal(size=(3, 3, 8))\n        output = layer(query=masked_query, value=value)\n        self.assertAllClose(masked_query._keras_mask, output._keras_mask)\n\n        layer = layers.GroupedQueryAttention(\n            num_query_heads=2, num_key_value_heads=2, head_dim=2, use_gate=True\n        )\n        self.assertTrue(layer.supports_masking)\n        query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])\n        masked_query = layers.Embedding(4, 8, mask_zero=True)(query)\n        value = np.random.normal(size=(3, 3, 8))\n        output = layer(query=masked_query, value=value)\n        self.assertAllClose(masked_query._keras_mask, output._keras_mask)\n\n    @parameterized.named_parameters((\"causal\", True), (\"not_causal\", 0))\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\",\n        reason=\"Numpy backend does not support masking.\",\n    )\n    
def test_masking(self, use_causal_mask):\n        \"\"\"Test that the value and causal masks are taken into account.\"\"\"\n        layer = layers.GroupedQueryAttention(\n            num_query_heads=2, num_key_value_heads=2, head_dim=2\n        )\n        query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])\n        masked_query = layers.Embedding(4, 8, mask_zero=True)(query)\n        value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]])\n        masked_value = layers.Embedding(6, 8, mask_zero=True)(value)\n        output = layer(\n            query=masked_query,\n            value=masked_value,\n            use_causal_mask=use_causal_mask,\n        )\n        mask = np.array(\n            [[[1, 1, 0]] * 3 + [[0, 0, 0]] * 2]\n            + [[[1, 0, 0]] * 5]\n            + [[[1, 1, 1]] + [[0, 0, 0]] * 4]\n        ).astype(bool)\n        if use_causal_mask:\n            mask = mask & np.array(\n                [[[1, 0, 0], [1, 1, 0]] + [[1, 1, 1]] * 3]\n            ).astype(bool)\n        del masked_query._keras_mask\n        del masked_value._keras_mask\n        output_with_manual_mask = layer(\n            query=masked_query, value=masked_value, attention_mask=mask\n        )\n        self.assertAllClose(output, output_with_manual_mask)\n\n        layer = layers.GroupedQueryAttention(\n            num_query_heads=2, num_key_value_heads=2, head_dim=2, use_gate=True\n        )\n        query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])\n        masked_query = layers.Embedding(4, 8, mask_zero=True)(query)\n        value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]])\n        masked_value = layers.Embedding(6, 8, mask_zero=True)(value)\n        output = layer(\n            query=masked_query,\n            value=masked_value,\n            use_causal_mask=use_causal_mask,\n        )\n        mask = np.array(\n            [[[1, 1, 0]] * 3 + [[0, 0, 0]] * 2]\n            + [[[1, 0, 0]] * 5]\n            + [[[1, 1, 1]] + [[0, 0, 0]] * 4]\n   
     ).astype(bool)\n        if use_causal_mask:\n            mask = mask & np.array(\n                [[[1, 0, 0], [1, 1, 0]] + [[1, 1, 1]] * 3]\n            ).astype(bool)\n        del masked_query._keras_mask\n        del masked_value._keras_mask\n        output_with_manual_mask = layer(\n            query=masked_query, value=masked_value, attention_mask=mask\n        )\n        self.assertAllClose(output, output_with_manual_mask)\n\n    @parameterized.named_parameters(\n        (\"disable_flash_attention\", False), (\"enable_flash_attention\", True)\n    )\n    def test_correctness(self, flash_attention):\n        if flash_attention:\n            # Let the backend decide whether to use flash attention\n            enable_flash_attention()\n        dtype = \"float16\"  # Flash attention only accepts float16/bfloat16\n        head_dim = 8  # head_dim % 8 == 0 to enable flash attention\n        num_query_heads = num_key_value_heads = 8\n\n        query = np.identity(head_dim)[np.newaxis, ...]\n        key = np.identity(head_dim)[np.newaxis, ...]\n        value = (\n            np.reshape(np.arange(head_dim * head_dim), (1, head_dim, head_dim))\n            / 100.0  # Prevent overflow/underflow\n        )\n\n        # Setup layer.\n        layer = layers.GroupedQueryAttention(\n            head_dim=head_dim,\n            num_query_heads=num_query_heads,\n            num_key_value_heads=num_key_value_heads,\n            dtype=dtype,\n        )\n        layer.build(query.shape, value.shape, key.shape)\n\n        # Set layer weights.\n        kernel = np.identity(head_dim)\n        # To get an identity kernel we need to add a head dim and repeat on it.\n        kernel = np.repeat(kernel[:, np.newaxis, :], num_query_heads, axis=1)\n        # Zeros for all biases.\n        bias = np.zeros((num_query_heads, head_dim))\n        output_bias = np.zeros((head_dim,))\n        layer.set_weights([kernel, bias] * 3 + [kernel, output_bias])\n\n        # Call layer and assert 
output.\n        expected_output = np.array(\n            [2.406, 2.440, 2.473, 2.504, 2.535, 2.568, 2.602, 2.633]\n        )\n        expected_output = np.tile(\n            expected_output[np.newaxis, :, np.newaxis], (1, 1, head_dim)\n        )\n        expected_score = np.array(\n            [\n                [0.1187] * 0 + [0.1691] + [0.1187] * 7,\n                [0.1187] * 1 + [0.1691] + [0.1187] * 6,\n                [0.1187] * 2 + [0.1691] + [0.1187] * 5,\n                [0.1187] * 3 + [0.1691] + [0.1187] * 4,\n                [0.1187] * 4 + [0.1691] + [0.1187] * 3,\n                [0.1187] * 5 + [0.1691] + [0.1187] * 2,\n                [0.1187] * 6 + [0.1691] + [0.1187] * 1,\n                [0.1187] * 7 + [0.1691] + [0.1187] * 0,\n            ]\n        )\n        expected_score = np.tile(\n            expected_score[np.newaxis, np.newaxis, ...], (1, head_dim, 1, 1)\n        )\n        if flash_attention:\n            output = layer(query=query, value=value, key=key)\n            self.assertAllClose(output, expected_output, atol=1e-2)\n        else:\n            output, scores = layer(\n                query=query,\n                value=value,\n                key=key,\n                return_attention_scores=True,\n            )\n            self.assertAllClose(output, expected_output, atol=1e-2)\n            self.assertAllClose(scores, expected_score, atol=1e-2)\n\n    def test_flash_attention_with_errors(self):\n        if backend.backend() in (\"numpy\", \"tensorflow\"):\n            pytest.skip(\n                reason=(\n                    \"Flash attention is not supported on tensorflow and numpy.\"\n                )\n            )\n        # Check `flash_attention=True` and `dropout=0.1`\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Dropout is not supported when flash attention is enabled.\",\n        ):\n            layer = layers.GroupedQueryAttention(\n                head_dim=2,\n                
num_query_heads=2,\n                num_key_value_heads=2,\n                flash_attention=True,\n                dropout=0.1,\n            )\n\n        # Check `flash_attention=True` and `return_attention_scores=True`\n        layer = layers.GroupedQueryAttention(\n            head_dim=2,\n            num_query_heads=2,\n            num_key_value_heads=2,\n            flash_attention=True,\n        )\n        self.assertTrue(layer._flash_attention)\n        query = np.random.random((2, 4, 8))\n        value = np.random.random((2, 4, 8))\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Returning attention scores is not supported when flash \"\n            \"attention is enabled. Please disable flash attention to access\"\n            \" attention scores.\",\n        ):\n            layer(query=query, value=value, return_attention_scores=True)\n"
  },
  {
    "path": "keras/src/layers/attention/multi_head_attention.py",
    "content": "import math\nimport string\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.config import is_flash_attention_enabled\nfrom keras.src.layers.activations.softmax import Softmax\nfrom keras.src.layers.core.einsum_dense import EinsumDense\nfrom keras.src.layers.layer import Layer\nfrom keras.src.layers.regularization.dropout import Dropout\n\n\n@keras_export(\"keras.layers.MultiHeadAttention\")\nclass MultiHeadAttention(Layer):\n    \"\"\"MultiHeadAttention layer.\n\n    This is an implementation of multi-headed attention as described in the\n    paper \"Attention is all you Need\"\n    [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762).\n    If `query`, `key,` `value` are the same, then\n    this is self-attention. Each timestep in `query` attends to the\n    corresponding sequence in `key`, and returns a fixed-width vector.\n\n    This layer first projects `query`, `key` and `value`. These are\n    (effectively) a list of tensors of length `num_attention_heads`, where the\n    corresponding shapes are `(batch_size, <query dimensions>, key_dim)`,\n    `(batch_size, <key/value dimensions>, key_dim)`,\n    `(batch_size, <key/value dimensions>, value_dim)`.\n\n    Then, the query and key tensors are dot-producted and scaled. These are\n    softmaxed to obtain attention probabilities. 
The value tensors are then\n    interpolated by these probabilities, then concatenated back to a single\n    tensor.\n\n    Finally, the result tensor with the last dimension as `value_dim` can take\n    a linear projection and be returned.\n\n    Args:\n        num_heads: Number of attention heads.\n        key_dim: Size of each attention head for query and key.\n        value_dim: Size of each attention head for value.\n        dropout: Dropout probability.\n        use_bias: Boolean, whether the dense layers use bias vectors/matrices.\n        output_shape: The expected shape of an output tensor, besides the batch\n            and sequence dims. If not specified, projects back to the query\n            feature dim (the query input's last dimension).\n        attention_axes: axes over which the attention is applied. `None` means\n            attention over all axes, but batch, heads, and features.\n        flash_attention: If `None`, the layer attempts to use flash\n            attention for faster and more memory-efficient attention\n            computations when possible. 
This behavior can be configured using\n            `keras.config.enable_flash_attention()` or\n            `keras.config.disable_flash_attention()`.\n        kernel_initializer: Initializer for dense layer kernels.\n        bias_initializer: Initializer for dense layer biases.\n        kernel_regularizer: Regularizer for dense layer kernels.\n        bias_regularizer: Regularizer for dense layer biases.\n        activity_regularizer: Regularizer for dense layer activity.\n        kernel_constraint: Constraint for dense layer kernels.\n        bias_constraint: Constraint for dense layer biases.\n        use_gate: Boolean, whether to apply a gated attention mechanism.\n            When True, an additional gating branch is added, as described in\n            [Gated Attention for Large Language Models](https://arxiv.org/abs/2505.06708).\n            It applies a sigmoid-activated linear projection to the query\n            which then gates the attention output. This helps improve training\n            stability and eliminates \"attention sinks\".\n        seed: Optional integer to seed the dropout layer.\n\n    Call arguments:\n        query: Query tensor of shape `(B, T, dim)`, where `B` is the batch size,\n            `T` is the target sequence length, and dim is the feature dimension.\n        value: Value tensor of shape `(B, S, dim)`, where `B` is the batch size,\n            `S` is the source sequence length, and dim is the feature dimension.\n        key: Optional key tensor of shape `(B, S, dim)`. If not given, will\n            use `value` for both `key` and `value`, which is the most common\n            case.\n        attention_mask: a boolean mask of shape `(B, T, S)`, that prevents\n            attention to certain positions. The boolean mask specifies which\n            query elements can attend to which key elements, 1 indicates\n            attention and 0 indicates no attention. 
Broadcasting can happen for\n            the missing batch dimensions and the head dimension.\n        return_attention_scores: A boolean to indicate whether the output should\n            be `(attention_output, attention_scores)` if `True`, or\n            `attention_output` if `False`. Defaults to `False`.\n        training: Python boolean indicating whether the layer should behave in\n            training mode (adding dropout) or in inference mode (no dropout).\n            Defaults to the training mode of the parent\n            layer/model, or `False` (inference) if there is no parent layer.\n        use_causal_mask: A boolean to indicate whether to apply a causal mask to\n            prevent tokens from attending to future tokens (e.g., used in a\n            decoder Transformer).\n\n    Returns:\n        attention_output: The result of the computation, of shape `(B, T, E)`,\n            where `T` is the target sequence length and `E` is the query input\n            last dimension if `output_shape` is `None`. 
Otherwise, the\n            multi-head outputs are projected to the shape specified by\n            `output_shape`.\n        attention_scores: (Optional) multi-head attention coefficients over\n            attention axes.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_heads,\n        key_dim,\n        value_dim=None,\n        dropout=0.0,\n        use_bias=True,\n        output_shape=None,\n        attention_axes=None,\n        flash_attention=None,\n        kernel_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        bias_constraint=None,\n        use_gate=False,\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.supports_masking = True\n        self._num_heads = num_heads\n        self._key_dim = key_dim\n        self._value_dim = value_dim if value_dim else key_dim\n        self._dropout = dropout\n        self._use_bias = use_bias\n        self._use_gate = use_gate\n        if output_shape:\n            if isinstance(output_shape, int):\n                output_shape = (output_shape,)\n            try:\n                output_shape = tuple(output_shape)\n            except TypeError:\n                raise ValueError(\n                    f\"Invalid `output_shape`: {output_shape}. 
When \"\n                    \"specified, the `output_shape` should be of type tuple, \"\n                    \"list, or int.\"\n                )\n        self._output_shape = output_shape\n        self._flash_attention = flash_attention or is_flash_attention_enabled()\n        self._kernel_initializer = initializers.get(kernel_initializer)\n        self._bias_initializer = initializers.get(bias_initializer)\n        self._kernel_regularizer = regularizers.get(kernel_regularizer)\n        self._bias_regularizer = regularizers.get(bias_regularizer)\n        self._activity_regularizer = regularizers.get(activity_regularizer)\n        self._kernel_constraint = constraints.get(kernel_constraint)\n        self._bias_constraint = constraints.get(bias_constraint)\n        if isinstance(attention_axes, int):\n            attention_axes = (attention_axes,)\n        elif attention_axes and not isinstance(attention_axes, (list, tuple)):\n            raise ValueError(\n                \"`attention_axes` must be an int, list, or tuple.\"\n                f\"Received: attention_axes={attention_axes}\"\n            )\n        self._attention_axes = attention_axes\n        self.seed = seed\n\n        self._inverse_sqrt_key_dim = 1.0 / math.sqrt(float(self._key_dim))\n\n        # Check for flash attention constraints\n        if self._flash_attention and self._dropout > 0.0:\n            raise ValueError(\n                \"Dropout is not supported when flash attention is enabled. 
\"\n                \"Please set dropout to 0.0 to use flash attention.\"\n            )\n\n    @property\n    def num_heads(self):\n        return self._num_heads\n\n    @property\n    def key_dim(self):\n        return self._key_dim\n\n    @property\n    def value_dim(self):\n        return self._value_dim\n\n    @property\n    def dropout(self):\n        return self._dropout\n\n    @property\n    def use_bias(self):\n        return self._use_bias\n\n    # Avoid exposing `output_shape` as it may conflict with `Functional` and\n    # `Sequential` models when calling `summary()`.\n\n    @property\n    def attention_axes(self):\n        return self._attention_axes\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"num_heads\": self._num_heads,\n            \"key_dim\": self._key_dim,\n            \"value_dim\": self._value_dim,\n            \"dropout\": self._dropout,\n            \"use_bias\": self._use_bias,\n            \"use_gate\": self._use_gate,\n            \"output_shape\": self._output_shape,\n            \"attention_axes\": self._attention_axes,\n            \"kernel_initializer\": initializers.serialize(\n                self._kernel_initializer\n            ),\n            \"bias_initializer\": initializers.serialize(self._bias_initializer),\n            \"kernel_regularizer\": regularizers.serialize(\n                self._kernel_regularizer\n            ),\n            \"bias_regularizer\": regularizers.serialize(self._bias_regularizer),\n            \"activity_regularizer\": regularizers.serialize(\n                self._activity_regularizer\n            ),\n            \"kernel_constraint\": constraints.serialize(self._kernel_constraint),\n            \"bias_constraint\": constraints.serialize(self._bias_constraint),\n            \"seed\": self.seed,\n        }\n        return {**base_config, **config}\n\n    def build(\n        self,\n        query_shape,\n        value_shape,\n        
key_shape=None,\n    ):\n        \"\"\"Builds layers and variables.\n\n        Args:\n            query_shape: Shape of the `query` tensor.\n            value_shape: Shape of the `value` tensor.\n            key_shape: Optional shape of the `key` tensor.\n        \"\"\"\n        key_shape = value_shape if key_shape is None else key_shape\n\n        if value_shape[1:-1] != key_shape[1:-1]:\n            raise ValueError(\n                \"All dimensions of `value` and `key`, except the last one, \"\n                f\"must be equal. Received: value_shape={value_shape} and \"\n                f\"key_shape={key_shape}\"\n            )\n\n        query_rank = len(query_shape)\n        value_rank = len(value_shape)\n        key_rank = len(key_shape)\n        einsum_equation, bias_axes, output_rank = _build_proj_equation(\n            query_rank - 1, bound_dims=1, output_dims=2\n        )\n        self._query_dense = EinsumDense(\n            einsum_equation,\n            output_shape=_get_output_shape(\n                output_rank - 1, [self._num_heads, self._key_dim]\n            ),\n            bias_axes=bias_axes if self._use_bias else None,\n            name=\"query\",\n            **self._get_common_kwargs_for_sublayer(),\n        )\n        self._query_dense.build(query_shape)\n        einsum_equation, bias_axes, output_rank = _build_proj_equation(\n            key_rank - 1, bound_dims=1, output_dims=2\n        )\n        self._key_dense = EinsumDense(\n            einsum_equation,\n            output_shape=_get_output_shape(\n                output_rank - 1, [self._num_heads, self._key_dim]\n            ),\n            bias_axes=bias_axes if self._use_bias else None,\n            name=\"key\",\n            **self._get_common_kwargs_for_sublayer(),\n        )\n        self._key_dense.build(key_shape)\n        if self._use_gate:\n            query_einsum_equation, query_bias_axes, query_output_rank = (\n                _build_proj_equation(\n                    
query_rank - 1, bound_dims=1, output_dims=2\n                )\n            )\n            self._gate_dense = EinsumDense(\n                query_einsum_equation,\n                output_shape=_get_output_shape(\n                    query_output_rank - 1, [self._num_heads, self._value_dim]\n                ),\n                bias_axes=query_bias_axes if self._use_bias else None,\n                activation=\"sigmoid\",\n                name=\"gate\",\n                **self._get_common_kwargs_for_sublayer(),\n            )\n            self._gate_dense.build(query_shape)\n        einsum_equation, bias_axes, output_rank = _build_proj_equation(\n            value_rank - 1, bound_dims=1, output_dims=2\n        )\n        self._value_dense = EinsumDense(\n            einsum_equation,\n            output_shape=_get_output_shape(\n                output_rank - 1, [self._num_heads, self._value_dim]\n            ),\n            bias_axes=bias_axes if self._use_bias else None,\n            name=\"value\",\n            **self._get_common_kwargs_for_sublayer(),\n        )\n        self._value_dense.build(value_shape)\n\n        # Builds the attention computations for multi-head dot product\n        # attention.  
These computations could be wrapped into the keras\n        # attention layer once it supports multi-head einsum computations.\n        self._build_attention(output_rank)\n        self._output_dense = self._make_output_dense(\n            query_shape,\n            self._get_common_kwargs_for_sublayer(),\n            \"attention_output\",\n        )\n        output_dense_input_shape = list(\n            self._query_dense.compute_output_shape(query_shape)\n        )\n        output_dense_input_shape[-1] = self._value_dim\n        self._output_dense.build(tuple(output_dense_input_shape))\n\n    @property\n    def query_dense(self):\n        return self._query_dense\n\n    @property\n    def key_dense(self):\n        return self._key_dense\n\n    @property\n    def value_dense(self):\n        return self._value_dense\n\n    @property\n    def output_dense(self):\n        return self._output_dense\n\n    def _get_common_kwargs_for_sublayer(self):\n        common_kwargs = dict(\n            kernel_regularizer=self._kernel_regularizer,\n            bias_regularizer=self._bias_regularizer,\n            activity_regularizer=self._activity_regularizer,\n            kernel_constraint=self._kernel_constraint,\n            bias_constraint=self._bias_constraint,\n            dtype=self.dtype_policy,\n        )\n        # Create new clone of kernel/bias initializer, so that we don't reuse\n        # the initializer instance, which could lead to same init value since\n        # initializer is stateless.\n        kernel_initializer = self._kernel_initializer.__class__.from_config(\n            self._kernel_initializer.get_config()\n        )\n        bias_initializer = self._bias_initializer.__class__.from_config(\n            self._bias_initializer.get_config()\n        )\n        common_kwargs[\"kernel_initializer\"] = kernel_initializer\n        common_kwargs[\"bias_initializer\"] = bias_initializer\n        return common_kwargs\n\n    def _make_output_dense(self, query_shape, 
common_kwargs, name=None):\n        \"\"\"Builds the output projection matrix.\n\n        Args:\n            query_shape: Shape of the `query` tensor; its rank determines the\n                free dimensions for einsum equation building.\n            common_kwargs: Common keyword arguments for einsum layer.\n            name: Name for the projection layer.\n\n        Returns:\n            Projection layer.\n        \"\"\"\n        query_rank = len(query_shape)\n        if self._output_shape:\n            output_shape = self._output_shape\n        else:\n            output_shape = [query_shape[-1]]\n        einsum_equation, bias_axes, output_rank = _build_proj_equation(\n            query_rank - 1, bound_dims=2, output_dims=len(output_shape)\n        )\n        return EinsumDense(\n            einsum_equation,\n            output_shape=_get_output_shape(output_rank - 1, output_shape),\n            bias_axes=bias_axes if self._use_bias else None,\n            name=name,\n            **common_kwargs,\n        )\n\n    def _build_attention(self, rank):\n        \"\"\"Builds multi-head dot-product attention computations.\n\n        This function builds attributes necessary for `_compute_attention` to\n        customize attention computation to replace the default dot-product\n        attention.\n\n        Args:\n            rank: the rank of query, key, value tensors.\n        \"\"\"\n        if self._attention_axes is None:\n            self._attention_axes = tuple(range(1, rank - 2))\n        else:\n            self._attention_axes = tuple(\n                axis if axis >= 0 else (rank - 1) + axis\n                for axis in self._attention_axes\n            )\n        (\n            self._dot_product_equation,\n            self._combine_equation,\n            attn_scores_rank,\n        ) = _build_attention_equation(rank, attn_axes=self._attention_axes)\n        norm_axes = tuple(\n            range(\n                attn_scores_rank - len(self._attention_axes), attn_scores_rank\n            )\n        )\n        self._softmax 
= Softmax(axis=norm_axes, dtype=self.dtype_policy)\n        self._dropout_layer = Dropout(\n            rate=self._dropout, dtype=self.dtype_policy, seed=self.seed\n        )\n\n    def _masked_softmax(self, attention_scores, attention_mask=None):\n        # Normalize the attention scores to probabilities.\n        # attention_scores = [B, N, T, S]\n        if attention_mask is not None:\n            # The expand dim happens starting from the `num_heads` dimension,\n            # (<batch_dims>, num_heads, <query_attention_dims,\n            # key_attention_dims>)\n            mask_expansion_axis = -len(self._attention_axes) * 2 - 1\n            for _ in range(\n                len(attention_scores.shape) - len(attention_mask.shape)\n            ):\n                attention_mask = ops.expand_dims(\n                    attention_mask, axis=mask_expansion_axis\n                )\n        return self._softmax(attention_scores, mask=attention_mask)\n\n    def _compute_attention(\n        self,\n        query,\n        key,\n        value,\n        attention_mask=None,\n        training=None,\n        return_attention_scores=False,\n    ):\n        \"\"\"Applies Dot-product attention with query, key, value tensors.\n\n        This function defines the computation inside `call` with projected\n        multi-head Q, K, V inputs. Users can override this function for\n        customized attention implementation.\n\n        Args:\n            query: Projected query tensor of shape `(B, T, N, key_dim)`.\n            key: Projected key tensor of shape `(B, S, N, key_dim)`.\n            value: Projected value tensor of shape `(B, S, N, value_dim)`.\n            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents\n                attention to certain positions. 
It is generally not needed if\n                the `query` and `value` (and/or `key`) are masked.\n            training: Python boolean indicating whether the layer should behave\n                in training mode (adding dropout) or in inference mode (doing\n                nothing).\n\n        Returns:\n          attention_output: Multi-headed outputs of attention computation.\n          attention_scores: Multi-headed attention weights.\n        \"\"\"\n        # Check for flash attention constraints\n        if self._flash_attention and return_attention_scores:\n            raise ValueError(\n                \"Returning attention scores is not supported when flash \"\n                \"attention is enabled. Please disable flash attention to access\"\n                \" attention scores.\"\n            )\n\n        # Determine whether to use dot-product attention\n        use_dot_product_attention = not (\n            self._dropout > 0.0\n            or return_attention_scores\n            or (len(query.shape) != 4)\n        )\n\n        if use_dot_product_attention:\n            if attention_mask is not None:\n                # Ensure attention_mask has the correct shape for broadcasting\n                # Expected shape: [batch_size, num_heads, query_seq_len,\n                # key_seq_len].\n                mask_expansion_axis = -len(self._attention_axes) * 2 - 1\n                len_attention_scores_shape = 4  # Only accepts 4D inputs\n                for _ in range(\n                    len_attention_scores_shape - len(attention_mask.shape)\n                ):\n                    attention_mask = ops.expand_dims(\n                        attention_mask, axis=mask_expansion_axis\n                    )\n                attention_mask = ops.cast(attention_mask, dtype=\"bool\")\n            # Directly compute the attention output using dot-product attention\n            attention_output = ops.dot_product_attention(\n                query=query,\n                
key=key,\n                value=value,\n                bias=None,\n                mask=attention_mask,\n                scale=self._inverse_sqrt_key_dim,\n                is_causal=False,\n                flash_attention=self._flash_attention,\n            )\n            return attention_output, None\n\n        # Default behavior without flash attention, with explicit attention\n        # scores\n        query = ops.multiply(\n            query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)\n        )\n\n        # Take the dot product between \"query\" and \"key\" to get the raw\n        # attention scores.\n        attention_scores = ops.einsum(self._dot_product_equation, key, query)\n\n        # Apply the mask using the custom masked softmax\n        attention_scores = self._masked_softmax(\n            attention_scores, attention_mask\n        )\n\n        # Apply dropout to the attention scores if needed\n        if self._dropout > 0.0:\n            final_attn_scores = self._dropout_layer(\n                attention_scores, training=training\n            )\n        else:\n            final_attn_scores = attention_scores\n\n        # `context_layer` = [B, T, N, H]\n        attention_output = ops.einsum(\n            self._combine_equation, final_attn_scores, value\n        )\n        return attention_output, attention_scores\n\n    def call(\n        self,\n        query,\n        value,\n        key=None,\n        query_mask=None,\n        value_mask=None,\n        key_mask=None,\n        attention_mask=None,\n        return_attention_scores=False,\n        training=None,\n        use_causal_mask=False,\n    ):\n        if key is None:\n            key = value\n\n        # Delete the masks because the masks are handled at the level of the\n        # layer\n        query_mask = backend.get_keras_mask(query)\n        backend.set_keras_mask(query, None)\n        backend.set_keras_mask(value, None)\n        backend.set_keras_mask(key, None)\n\n        
attention_mask = self._compute_attention_mask(\n            query,\n            value,\n            query_mask=query_mask,\n            value_mask=value_mask,\n            key_mask=key_mask,\n            attention_mask=attention_mask,\n            use_causal_mask=use_causal_mask,\n        )\n        #   N = `num_attention_heads`\n        #   H = `size_per_head`\n\n        # `gate` = [B, T, N, H]\n        if self._use_gate:\n            gate = self._gate_dense(query)\n\n        # `query` = [B, T, N, H]\n        query = self._query_dense(query)\n\n        # `key` = [B, S, N, H]\n        key = self._key_dense(key)\n\n        # `value` = [B, S, N, H]\n        value = self._value_dense(value)\n        attention_output, attention_scores = self._compute_attention(\n            query,\n            key,\n            value,\n            attention_mask,\n            training,\n            return_attention_scores,\n        )\n        if self._use_gate:\n            attention_output = self._output_dense(\n                ops.multiply(attention_output, gate)\n            )\n        else:\n            attention_output = self._output_dense(attention_output)\n\n        # Set mask on output if needed\n        if query_mask is not None:\n            backend.set_keras_mask(attention_output, query_mask)\n\n        if return_attention_scores:\n            return attention_output, attention_scores\n        return attention_output\n\n    def _compute_attention_mask(\n        self,\n        query,\n        value,\n        query_mask=None,\n        value_mask=None,\n        key_mask=None,\n        attention_mask=None,\n        use_causal_mask=False,\n    ):\n        \"\"\"Computes the attention mask, using the Keras masks of the inputs.\n\n        * The `query`'s mask is reshaped from [B, T] to [B, T, 1].\n        * The `value`'s mask is reshaped from [B, S] to [B, 1, S].\n        * The `key`'s mask is reshaped from [B, S] to [B, 1, S]. 
The `key`'s\n          mask is ignored if `key` is `None` or if `key is value`.\n        * If `use_causal_mask=True`, then the causal mask is computed. Its shape\n          is [1, T, S].\n\n        All defined masks are merged using a logical AND operation (`&`).\n\n        In general, if the `query` and `value` are masked, then there is no need\n        to define the `attention_mask`.\n\n        Args:\n            query: Projected query tensor of shape `(B, T, N, key_dim)`.\n            key: Projected key tensor of shape `(B, S, N, key_dim)`.\n            value: Projected value tensor of shape `(B, S, N, value_dim)`.\n            query_mask: Keras mask of `query` of shape `(B, T)`, or `None`.\n            value_mask: Keras mask of `value` of shape `(B, S)`, or `None`.\n            key_mask: Keras mask of `key` of shape `(B, S)`, or `None`.\n            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents\n                attention to certain positions.\n            use_causal_mask: A boolean to indicate whether to apply a causal\n                mask to prevent tokens from attending to future tokens (e.g.,\n                used in a decoder Transformer).\n\n        Returns:\n            attention_mask: a boolean mask of shape `(B, T, S)`, that prevents\n                attention to certain positions, based on the Keras masks of the\n                `query`, `key`, `value`, and `attention_mask` tensors, and the\n                causal mask if `use_causal_mask=True`.\n        \"\"\"\n        auto_mask = None\n        if query_mask is not None:\n            query_mask = ops.cast(query_mask, \"bool\")  # defensive casting\n            # B = batch size, T = max query length\n            auto_mask = ops.expand_dims(query_mask, -1)  # shape is [B, T, 1]\n        if value_mask is not None:\n            value_mask = ops.cast(value_mask, \"bool\")  # defensive casting\n            # B = batch size, S = max value length\n            mask = ops.expand_dims(value_mask, -2)  # shape is [B, 1, S]\n            auto_mask = mask if auto_mask is None else auto_mask & mask\n        if key_mask is not None:\n            key_mask = ops.cast(key_mask, \"bool\")  # defensive casting\n            
# B == batch size, S == max key length == max value length\n            mask = ops.expand_dims(key_mask, -2)  # shape is [B, 1, S]\n            auto_mask = mask if auto_mask is None else auto_mask & mask\n        if use_causal_mask:\n            # the shape of the causal mask is [1, T, S]\n            mask = self._compute_causal_mask(query, value)\n            auto_mask = mask if auto_mask is None else auto_mask & mask\n\n        if attention_mask is not None:\n            attention_mask = ops.cast(attention_mask, \"bool\")\n        if auto_mask is not None:\n            # merge attention_mask & automatic mask, to shape [B, T, S]\n            attention_mask = (\n                auto_mask\n                if attention_mask is None\n                else attention_mask & auto_mask\n            )\n        return attention_mask\n\n    def _compute_causal_mask(self, query, value=None):\n        \"\"\"Computes a causal mask (e.g., for masked self-attention layers).\n\n        For example, if query and value both contain sequences of length 4,\n        this function returns a boolean tensor equal to:\n\n        ```\n        [[[True,  False, False, False],\n          [True,  True,  False, False],\n          [True,  True,  True,  False],\n          [True,  True,  True,  True]]]\n        ```\n\n        Args:\n            query: query tensor of shape `(B, T, ...)`.\n            value: value tensor of shape `(B, S, ...)` (optional, defaults to\n                query).\n\n        Returns:\n            mask: a boolean tensor of shape `(1, T, S)` containing a lower\n                triangular matrix of shape `(T, S)`.\n        \"\"\"\n        q_seq_length = ops.shape(query)[1]\n        v_seq_length = q_seq_length if value is None else ops.shape(value)[1]\n        ones_mask = ops.ones((1, q_seq_length, v_seq_length), dtype=\"int32\")\n        row_index = ops.cumsum(ones_mask, axis=-2)\n        col_index = ops.cumsum(ones_mask, axis=-1)\n        return ops.greater_equal(row_index, 
col_index)\n\n    def compute_output_shape(\n        self,\n        query_shape,\n        value_shape,\n        key_shape=None,\n    ):\n        query_shape = tuple(query_shape)\n        value_shape = tuple(value_shape)\n        if key_shape is None:\n            key_shape = value_shape\n        else:\n            key_shape = tuple(key_shape)\n\n        if value_shape[1:-1] != key_shape[1:-1]:\n            raise ValueError(\n                \"All dimensions of `value` and `key`, except the last one, \"\n                f\"must be equal. Received: value_shape={value_shape} and \"\n                f\"key_shape={key_shape}\"\n            )\n        if self._output_shape:\n            query_shape = query_shape[:-1] + self._output_shape\n        return query_shape\n\n    def compute_output_spec(\n        self,\n        query,\n        value,\n        key=None,\n        query_mask=None,\n        value_mask=None,\n        key_mask=None,\n        attention_mask=None,\n        return_attention_scores=False,\n        training=None,\n        use_causal_mask=False,\n    ):\n        if key is not None:\n            key_shape = key.shape\n        else:\n            key_shape = None\n        output_shape = self.compute_output_shape(\n            query.shape, value.shape, key_shape\n        )\n        output_spec = backend.KerasTensor(\n            output_shape, dtype=self.compute_dtype\n        )\n        if return_attention_scores:\n            length = query.shape[1]\n            attention_shape = (query.shape[0], self.num_heads, length, length)\n            return output_spec, backend.KerasTensor(\n                attention_shape, dtype=self.compute_dtype\n            )\n        return output_spec\n\n\ndef _index_to_einsum_variable(i):\n    \"\"\"Converts an index to an einsum variable name.\n\n    We simply map indices to lowercase characters, e.g. 
0 -> 'a', 1 -> 'b'.\n    \"\"\"\n    return string.ascii_lowercase[i]\n\n\ndef _build_attention_equation(rank, attn_axes):\n    \"\"\"Builds einsum equations for the attention computation.\n\n    Query, key, value inputs after projection are expected to have the shape:\n    `(bs, <non-attention dims>, <attention dims>, num_heads, channels)`.\n    `bs` and `<non-attention dims>` are treated as `<batch dims>`.\n\n    The attention operations can be generalized:\n    1. Query-key dot product:\n        (<batch dims>, <query attention dims>, num_heads, channels),\n        (<batch dims>, <key attention dims>, num_heads, channels) ->\n        (<batch dims>, num_heads, <query attention dims>, <key attention dims>)\n    2. Combination:\n        (<batch dims>, num_heads, <query attention dims>, <key attention dims>),\n        (<batch dims>, <value attention dims>, num_heads, channels) -> (<batch\n        dims>, <query attention dims>, num_heads, channels)\n\n    Args:\n        rank: Rank of query, key, value tensors.\n        attn_axes: List/tuple of axes in `[-1, rank)`\n            that attention will be applied to.\n\n    Returns:\n        Einsum equations.\n    \"\"\"\n    target_notation = \"\"\n    for i in range(rank):\n        target_notation += _index_to_einsum_variable(i)\n    # `batch_dims` includes the head dim.\n    batch_dims = tuple(np.delete(range(rank), attn_axes + (rank - 1,)))\n    letter_offset = rank\n    source_notation = \"\"\n    for i in range(rank):\n        if i in batch_dims or i == rank - 1:\n            source_notation += target_notation[i]\n        else:\n            source_notation += _index_to_einsum_variable(letter_offset)\n            letter_offset += 1\n\n    product_notation = \"\".join(\n        [target_notation[i] for i in batch_dims]\n        + [target_notation[i] for i in attn_axes]\n        + [source_notation[i] for i in attn_axes]\n    )\n    dot_product_equation = \"%s,%s->%s\" % (\n        source_notation,\n        
target_notation,\n        product_notation,\n    )\n    attn_scores_rank = len(product_notation)\n    combine_equation = \"%s,%s->%s\" % (\n        product_notation,\n        source_notation,\n        target_notation,\n    )\n    return dot_product_equation, combine_equation, attn_scores_rank\n\n\ndef _build_proj_equation(free_dims, bound_dims, output_dims):\n    \"\"\"Builds an einsum equation for projections inside multi-head attention.\"\"\"\n    input_str = \"\"\n    kernel_str = \"\"\n    output_str = \"\"\n    bias_axes = \"\"\n    letter_offset = 0\n    for i in range(free_dims):\n        char = _index_to_einsum_variable(i + letter_offset)\n        input_str += char\n        output_str += char\n\n    letter_offset += free_dims\n    for i in range(bound_dims):\n        char = _index_to_einsum_variable(i + letter_offset)\n        input_str += char\n        kernel_str += char\n\n    letter_offset += bound_dims\n    for i in range(output_dims):\n        char = _index_to_einsum_variable(i + letter_offset)\n        kernel_str += char\n        output_str += char\n        bias_axes += char\n    equation = f\"{input_str},{kernel_str}->{output_str}\"\n\n    return equation, bias_axes, len(output_str)\n\n\ndef _get_output_shape(output_rank, known_last_dims):\n    return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)\n"
  },
  {
    "path": "keras/src/layers/attention/multi_head_attention_test.py",
    "content": "import os\nimport warnings\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import dtype_policies\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import random\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.backend.config import disable_flash_attention\nfrom keras.src.backend.config import enable_flash_attention\nfrom keras.src.backend.config import is_flash_attention_enabled\n\n\nclass MultiHeadAttentionTest(testing.TestCase):\n    def setUp(self):\n        super().setUp()\n        # Flash attention is a newly introduced feature. We need to disable it\n        # for testing purposes.\n        disable_flash_attention()\n\n    def tearDown(self):\n        enable_flash_attention()\n        return super().tearDown()\n\n    def test_basics(self):\n        self.assertFalse(is_flash_attention_enabled())\n        self.run_layer_test(\n            layers.MultiHeadAttention,\n            init_kwargs={\n                \"num_heads\": 2,\n                \"key_dim\": 2,\n            },\n            input_shape={\"query_shape\": (2, 8, 16), \"value_shape\": (2, 4, 16)},\n            expected_output_shape=(2, 8, 16),\n            expected_num_trainable_weights=8,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n\n        self.run_layer_test(\n            layers.MultiHeadAttention,\n            init_kwargs={\n                \"num_heads\": 2,\n                \"key_dim\": 2,\n                \"value_dim\": 4,\n                \"use_bias\": False,\n                \"dropout\": 0.5,\n            },\n            input_shape={\"query_shape\": (2, 8, 16), \"value_shape\": 
(2, 4, 16)},\n            expected_output_shape=(2, 8, 16),\n            expected_num_trainable_weights=4,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=1,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n\n        self.run_layer_test(\n            layers.MultiHeadAttention,\n            init_kwargs={\n                \"num_heads\": 2,\n                \"key_dim\": 2,\n                \"use_gate\": True,\n            },\n            input_shape={\"query_shape\": (2, 8, 16), \"value_shape\": (2, 4, 16)},\n            expected_output_shape=(2, 8, 16),\n            expected_num_trainable_weights=10,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n\n        self.run_layer_test(\n            layers.MultiHeadAttention,\n            init_kwargs={\n                \"num_heads\": 2,\n                \"key_dim\": 2,\n                \"value_dim\": 4,\n                \"use_bias\": False,\n                \"dropout\": 0.5,\n                \"use_gate\": True,\n            },\n            input_shape={\"query_shape\": (2, 8, 16), \"value_shape\": (2, 4, 16)},\n            expected_output_shape=(2, 8, 16),\n            expected_num_trainable_weights=5,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=1,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() not in (\"jax\", \"torch\"),\n        reason=\"Flash attention only supported on JAX and Torch\",\n    )\n    def test_basics_with_flash_attention(self):\n        enable_flash_attention()\n        if backend.backend() == \"torch\":\n            try:\n                
self.run_layer_test(\n                    layers.MultiHeadAttention,\n                    init_kwargs={\n                        \"num_heads\": 2,\n                        \"key_dim\": 8,\n                        \"dtype\": \"float16\",\n                    },\n                    input_shape={\n                        \"query_shape\": (2, 8, 16),\n                        \"value_shape\": (2, 4, 16),\n                    },\n                    expected_output_shape=(2, 8, 16),\n                    expected_num_trainable_weights=8,\n                    expected_num_non_trainable_weights=0,\n                    expected_num_seed_generators=0,\n                    expected_num_losses=0,\n                    supports_masking=True,\n                    run_training_check=False,\n                )\n            except ImportError as e:\n                if \"Flash attention is not supported\" in str(e.args[0]):\n                    self.assertTrue(\n                        (\n                            \"Flash attention is not supported in your current \"\n                            \"PyTorch version.\"\n                        )\n                        in str(e.args[0])\n                    )\n            except RuntimeError as e:\n                if (\n                    \"Flash attention is not supported with the provided inputs\"\n                    in str(e.args[0])\n                ):\n                    self.assertTrue(\n                        (\n                            \"Flash attention is not supported with the \"\n                            \"provided inputs\"\n                        )\n                        in str(e.args[0])\n                    )\n        elif backend.backend() == \"jax\":\n            try:\n                self.run_layer_test(\n                    layers.MultiHeadAttention,\n                    init_kwargs={\n                        \"num_heads\": 2,\n                        \"key_dim\": 8,\n                        \"dtype\": 
\"float16\",\n                    },\n                    input_shape={\n                        \"query_shape\": (2, 8, 16),\n                        \"value_shape\": (2, 4, 16),\n                    },\n                    expected_output_shape=(2, 8, 16),\n                    expected_num_trainable_weights=8,\n                    expected_num_non_trainable_weights=0,\n                    expected_num_seed_generators=0,\n                    expected_num_losses=0,\n                    supports_masking=True,\n                    run_training_check=False,\n                )\n            except ImportError as e:\n                if \"Flash attention is not supported\" in str(e.args[0]):\n                    self.assertTrue(\n                        (\n                            \"Flash attention is not supported in your current \"\n                            \"JAX version.\"\n                        )\n                        in str(e.args[0])\n                    )\n            except RuntimeError as e:\n                if \"cuDNN\" in str(e.args[0]):\n                    self.assertTrue(\"cuDNN is not detected.\" in str(e.args[0]))\n                elif \"Require at least\" in str(e.args[0]):\n                    self.assertTrue(\n                        \"Require at least Ampere arch to run\" in str(e.args[0])\n                    )\n                elif \"Flash attention\" in str(e.args[0]):\n                    self.assertTrue(\n                        (\n                            \"Flash attention is not supported in your current \"\n                            \"JAX version.\"\n                        )\n                        in str(e.args[0])\n                    )\n\n    @parameterized.named_parameters(\n        (\"4d_inputs_1freebatch_mask2\", (3, 4), (3, 2), (4, 2), (2,)),\n        (\"4d_inputs_1freebatch_mask3\", (3, 4), (3, 2), (3, 4, 2), (2,)),\n        (\"4d_inputs_1freebatch_mask4\", (3, 4), (3, 2), (3, 2, 4, 2), (2,)),\n        
(\"4d_inputs_2d_attention\", (3, 4), (3, 2), (3, 4, 3, 2), (1, 2)),\n        (\"5d_inputs_2d_attention\", (5, 3, 4), (5, 3, 2), (3, 4, 3, 2), (2, 3)),\n        (\n            \"5d_inputs_2d_attention_fullmask\",\n            (5, 3, 4),\n            (5, 3, 2),\n            (5, 3, 4, 3, 2),\n            (2, 3),\n        ),\n    )\n    def test_high_dim_attention(\n        self, q_dims, v_dims, mask_dims, attention_axes\n    ):\n        batch_size, hidden_size = 3, 8\n        query_shape = (batch_size,) + q_dims + (hidden_size,)\n        value_shape = (batch_size,) + v_dims + (hidden_size,)\n        self.run_layer_test(\n            layers.MultiHeadAttention,\n            init_kwargs={\n                \"num_heads\": 2,\n                \"key_dim\": 2,\n                \"attention_axes\": attention_axes,\n            },\n            input_shape={\n                \"query_shape\": query_shape,\n                \"value_shape\": value_shape,\n            },\n            expected_output_shape=query_shape,\n            expected_num_trainable_weights=8,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n        )\n\n        self.run_layer_test(\n            layers.MultiHeadAttention,\n            init_kwargs={\n                \"num_heads\": 2,\n                \"key_dim\": 2,\n                \"use_gate\": True,\n                \"attention_axes\": attention_axes,\n            },\n            input_shape={\n                \"query_shape\": query_shape,\n                \"value_shape\": value_shape,\n            },\n            expected_output_shape=query_shape,\n            expected_num_trainable_weights=10,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n       
 )\n\n    def test_attention_axes_negative_indexing(self):\n        x = np.random.normal(size=(2, 3, 8, 4))\n\n        # Create two layers with equivalent positive and negative indices\n        mha_pos = layers.MultiHeadAttention(\n            num_heads=2, key_dim=4, attention_axes=2\n        )\n        mha_neg = layers.MultiHeadAttention(\n            num_heads=2, key_dim=4, attention_axes=-2\n        )\n\n        # Initialize both layers\n        _ = mha_pos(x, x)\n        _ = mha_neg(x, x)\n\n        # Set same weights for fair comparison\n        mha_neg.set_weights(mha_pos.get_weights())\n\n        # Get outputs and attention scores\n        z_pos, a_pos = mha_pos(x, x, return_attention_scores=True)\n        z_neg, a_neg = mha_neg(x, x, return_attention_scores=True)\n\n        # Verify shapes match\n        self.assertEqual(z_pos.shape, z_neg.shape)\n        self.assertEqual(a_pos.shape, a_neg.shape)\n\n        # Verify outputs are identical\n        self.assertAllClose(z_pos, z_neg, rtol=1e-5, atol=1e-5)\n        self.assertAllClose(a_pos, a_neg, rtol=1e-5, atol=1e-5)\n\n        # Create two layers with equivalent positive and negative indices\n        mha_pos = layers.MultiHeadAttention(\n            num_heads=2, key_dim=4, attention_axes=2, use_gate=True\n        )\n        mha_neg = layers.MultiHeadAttention(\n            num_heads=2, key_dim=4, attention_axes=-2, use_gate=True\n        )\n\n        # Initialize both layers\n        _ = mha_pos(x, x)\n        _ = mha_neg(x, x)\n\n        # Set same weights for fair comparison\n        mha_neg.set_weights(mha_pos.get_weights())\n\n        # Get outputs and attention scores\n        z_pos, a_pos = mha_pos(x, x, return_attention_scores=True)\n        z_neg, a_neg = mha_neg(x, x, return_attention_scores=True)\n\n        # Verify shapes match\n        self.assertEqual(z_pos.shape, z_neg.shape)\n        self.assertEqual(a_pos.shape, a_neg.shape)\n\n        # Verify outputs are identical\n        
self.assertAllClose(z_pos, z_neg, rtol=1e-5, atol=1e-5)\n        self.assertAllClose(a_pos, a_neg, rtol=1e-5, atol=1e-5)\n\n    @parameterized.named_parameters(\n        (\"without_key_same_proj\", (4, 8), (2, 8), None, None),\n        (\"with_key_same_proj\", (4, 8), (2, 8), (2, 3), None),\n        (\"without_key_different_proj\", (4, 8), (2, 8), None, (3, 4)),\n        (\"with_key_different_proj\", (4, 8), (2, 8), (2, 3), (1, 5)),\n        (\"high_dim_same_proj\", (4, 2, 3, 8), (1, 1, 5, 8), (1, 1, 5, 2), None),\n        (\n            \"high_dim_different_proj\",\n            (4, 2, 3, 8),\n            (1, 1, 5, 8),\n            (1, 1, 5, 2),\n            (3, 2),\n        ),\n        (\n            \"different_qv_last_dims\",\n            (4, 2, 3, 8),\n            (4, 2, 3, 7),\n            (4, 2, 3, 8),\n            None,\n        ),\n    )\n    def test_compute_output_shape(\n        self, query_dims, value_dims, key_dims, output_shape\n    ):\n        \"\"\"Test computed shape is equal to the layer output's shape.\"\"\"\n        layer = layers.MultiHeadAttention(\n            num_heads=2,\n            key_dim=2,\n            value_dim=2,\n            output_shape=output_shape,\n        )\n        batch_size = 7\n        query_shape = (batch_size,) + query_dims\n        value_shape = (batch_size,) + value_dims\n        key_shape = (batch_size,) + key_dims if key_dims else None\n\n        query = np.ones(query_shape)\n        value = np.ones(value_shape)\n        key = np.ones(key_shape) if key_shape else None\n        output = layer(query=query, value=value, key=key)\n        comp_output_shape = layer.compute_output_shape(\n            query_shape, value_shape, key_shape\n        )\n        self.assertEqual(output.shape, comp_output_shape)\n\n        # Test shapes as lists.\n        comp_output_shape = layer.compute_output_shape(\n            list(query_shape),\n            list(value_shape),\n            list(key_shape) if key_shape is not None else None,\n  
      )\n        self.assertEqual(output.shape, comp_output_shape)\n\n    @parameterized.named_parameters(\n        (\"query_value_dim_mismatch\", (2, 4, 8), (2, 2, 7), (2,)),\n        (\"key_value_dim_mismatch\", (2, 4, 8), (2, 2, 8), (2, 1, 7)),\n        (\n            \"key_value_dim_mismatch_high_dim\",\n            (2, 4, 2, 3, 8),\n            (2, 1, 1, 5, 8),\n            (2, 1, 15, 5, 2),\n        ),\n    )\n    def test_shape_mismatch_error(self, query_shape, value_shape, key_shape):\n        \"\"\"Test dimension mismatches\"\"\"\n        layer = layers.MultiHeadAttention(\n            num_heads=4,\n            key_dim=2,\n            value_dim=2,\n        )\n        with self.assertRaisesRegex(ValueError, r\"must be equal\"):\n            layer.compute_output_shape(query_shape, value_shape, key_shape)\n        with self.assertRaisesRegex(ValueError, r\"must be equal\"):\n            layer(\n                np.ones(query_shape), np.ones(value_shape), np.ones(key_shape)\n            )\n\n    def test_initializer(self):\n        # Test with a specified initializer.\n        layer = layers.MultiHeadAttention(\n            num_heads=12,\n            key_dim=64,\n            use_gate=True,\n            kernel_initializer=initializers.TruncatedNormal(stddev=0.02),\n        )\n        layer.build((2, 4, 8), (2, 4, 8))\n\n        # Make sure the sub layers have different kernel init value.\n        self.assertNotAllClose(\n            layer._query_dense.kernel,\n            layer._key_dense.kernel,\n        )\n        self.assertNotAllClose(\n            layer._query_dense.kernel,\n            layer._value_dense.kernel,\n        )\n        self.assertNotAllClose(\n            layer._query_dense.kernel,\n            layer._output_dense.kernel,\n        )\n        self.assertNotAllClose(\n            layer._query_dense.kernel,\n            layer._gate_dense.kernel,\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\",\n        
reason=\"Numpy backend does not support masking.\",\n    )\n    def test_query_mask_propagation(self):\n        \"\"\"Test automatic propagation of the query's mask.\"\"\"\n        layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)\n        self.assertTrue(layer.supports_masking)\n        query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])\n        masked_query = layers.Embedding(4, 8, mask_zero=True)(query)\n        query_mask = backend.get_keras_mask(masked_query)\n        value = np.random.normal(size=(3, 3, 8))\n        output = layer(query=masked_query, value=value)\n        self.assertAllClose(query_mask, output._keras_mask)\n\n    @parameterized.named_parameters((\"causal\", True), (\"not_causal\", 0))\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\",\n        reason=\"Numpy backend does not support masking.\",\n    )\n    def test_masking(self, use_causal_mask):\n        \"\"\"Test that the value and causal masks are taken into account.\"\"\"\n        layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)\n        query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])\n        masked_query = layers.Embedding(4, 8, mask_zero=True)(query)\n        value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]])\n        masked_value = layers.Embedding(6, 8, mask_zero=True)(value)\n        output = layer(\n            query=masked_query,\n            value=masked_value,\n            use_causal_mask=use_causal_mask,\n        )\n        mask = np.array(\n            [[[1, 1, 0]] * 3 + [[0, 0, 0]] * 2]\n            + [[[1, 0, 0]] * 5]\n            + [[[1, 1, 1]] + [[0, 0, 0]] * 4]\n        )\n        if use_causal_mask:\n            mask = mask & np.array([[[1, 0, 0], [1, 1, 0]] + [[1, 1, 1]] * 3])\n        del masked_query._keras_mask\n        del masked_value._keras_mask\n        output_with_manual_mask = layer(\n            query=masked_query, value=masked_value, attention_mask=mask\n        )\n        
self.assertAllClose(output, output_with_manual_mask)\n\n    def test_masking_with_different_shapes(self):\n        x = random.uniform(shape=(2, 5, 8))\n        mask = ops.tril(ops.ones((5, 5)))  # (5, 5)\n        layer = layers.MultiHeadAttention(num_heads=2, key_dim=4)\n        output_1 = layer(query=x, value=x, attention_mask=mask)\n\n        mask = ops.tile(mask[None, ...], (2, 1, 1))  # (2, 5, 5)\n        output_2 = layer(query=x, value=x, attention_mask=mask)\n\n        mask = ops.tile(mask[:, None, ...], (1, 2, 1, 1))  # (2, 2, 5, 5)\n        output_3 = layer(query=x, value=x, attention_mask=mask)\n\n        self.assertAllClose(output_1, output_2)\n        self.assertAllClose(output_1, output_3)\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\",\n        reason=\"Numpy backend does not support masking.\",\n    )\n    def test_no_warning_with_keras_mask(self):\n        layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)\n        query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])\n        masked_query = layers.Embedding(4, 8, mask_zero=True)(query)\n        value = np.array([[5, 4, 0], [3, 0, 0], [2, 1, 1]])\n        masked_value = layers.Embedding(6, 8, mask_zero=True)(value)\n\n        with warnings.catch_warnings(record=True) as warning_logs:\n            _ = layer(query=masked_query, value=masked_value)\n            self.assertLen(warning_logs, 0)\n\n    @parameterized.named_parameters(\n        (\"disable_flash_attention\", False), (\"enable_flash_attention\", True)\n    )\n    def test_correctness(self, flash_attention):\n        if flash_attention:\n            # Let the backend decide whether to use flash attention\n            enable_flash_attention()\n        dtype = \"float16\"  # Flash attention only accepts float16/bfloat16\n\n        num_heads = 8\n        key_dim = 8  # key_dim % 8 == 0 to enable flash attention\n\n        query = np.identity(key_dim)[np.newaxis, ...]\n        key = 
np.identity(key_dim)[np.newaxis, ...]\n        value = (\n            np.reshape(np.arange(key_dim * key_dim), (1, key_dim, key_dim))\n            / 100.0  # Prevent overflow/underflow\n        )\n\n        # Setup layer.\n        layer = layers.MultiHeadAttention(\n            num_heads=num_heads, key_dim=key_dim, dtype=dtype\n        )\n        layer.build(query.shape, key.shape, value.shape)\n\n        # Set layer weights.\n        kernel = np.identity(key_dim)\n        # To get an identity kernel we need to add a head dim and repeat on it.\n        kernel = np.repeat(kernel[:, np.newaxis, :], num_heads, axis=1)\n        # Zeros for all biases.\n        bias = np.zeros((num_heads, key_dim))\n        output_bias = np.zeros((key_dim,))\n        layer.set_weights([kernel, bias] * 3 + [kernel, output_bias])\n        # Call layer and assert output.\n        expected_output = np.array(\n            [2.406, 2.440, 2.473, 2.504, 2.535, 2.568, 2.602, 2.633]\n        )\n        expected_output = np.tile(\n            expected_output[np.newaxis, :, np.newaxis], (1, 1, key_dim)\n        )\n        expected_score = np.array(\n            [\n                [0.1187] * 0 + [0.1691] + [0.1187] * 7,\n                [0.1187] * 1 + [0.1691] + [0.1187] * 6,\n                [0.1187] * 2 + [0.1691] + [0.1187] * 5,\n                [0.1187] * 3 + [0.1691] + [0.1187] * 4,\n                [0.1187] * 4 + [0.1691] + [0.1187] * 3,\n                [0.1187] * 5 + [0.1691] + [0.1187] * 2,\n                [0.1187] * 6 + [0.1691] + [0.1187] * 1,\n                [0.1187] * 7 + [0.1691] + [0.1187] * 0,\n            ]\n        )\n        expected_score = np.tile(\n            expected_score[np.newaxis, np.newaxis, ...], (1, key_dim, 1, 1)\n        )\n        if flash_attention:\n            output = layer(query=query, value=value, key=key)\n            self.assertAllClose(output, expected_output, atol=1e-2)\n        else:\n            output, scores = layer(\n                query=query,\n   
             value=value,\n                key=key,\n                return_attention_scores=True,\n            )\n            self.assertAllClose(output, expected_output, atol=1e-2)\n            self.assertAllClose(scores, expected_score, atol=1e-2)\n\n    def test_mha_constraints(self):\n        query = np.array([[[1.0, 0.0], [0.0, 1.0]]])\n        key = np.array([[[0.0, 1.0], [1.0, 0.0]]])\n        value = np.array([[[1.0, 2.0], [3.0, 4.0]]])\n        num_heads = 2\n        key_dim = 2\n        layer = layers.MultiHeadAttention(\n            num_heads=num_heads,\n            key_dim=key_dim,\n            use_gate=True,\n            kernel_constraint=\"non_neg\",\n        )\n        layer.build(query.shape, key.shape, value.shape)\n        self.assertIsInstance(\n            layer._query_dense.kernel.constraint, constraints.NonNeg\n        )\n        self.assertIsInstance(\n            layer._value_dense.kernel.constraint, constraints.NonNeg\n        )\n        self.assertIsInstance(\n            layer._key_dense.kernel.constraint, constraints.NonNeg\n        )\n        self.assertIsInstance(\n            layer._gate_dense.kernel.constraint, constraints.NonNeg\n        )\n        layer = layers.MultiHeadAttention(\n            num_heads=num_heads,\n            key_dim=key_dim,\n            use_gate=True,\n            bias_constraint=\"non_neg\",\n        )\n        layer.build(query.shape, key.shape, value.shape)\n        self.assertIsInstance(\n            layer._query_dense.bias.constraint, constraints.NonNeg\n        )\n        self.assertIsInstance(\n            layer._value_dense.bias.constraint, constraints.NonNeg\n        )\n        self.assertIsInstance(\n            layer._key_dense.bias.constraint, constraints.NonNeg\n        )\n        self.assertIsInstance(\n            layer._gate_dense.bias.constraint, constraints.NonNeg\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_lora(self):\n        query = np.array([[[1.0, 0.0], [0.0, 
1.0]]])\n        key = np.array([[[0.0, 1.0], [1.0, 0.0]]])\n        value = np.array([[[1.0, 2.0], [3.0, 4.0]]])\n        layer = layers.MultiHeadAttention(\n            num_heads=3,\n            key_dim=8,\n            use_bias=False,\n            use_gate=True,\n        )\n        layer.build(query.shape, key.shape, value.shape)\n        layer._query_dense.enable_lora(2)\n        layer._key_dense.enable_lora(2)\n        layer._value_dense.enable_lora(2)\n\n        self.assertLen(layer.trainable_variables, 8)\n        self.assertLen(layer.non_trainable_variables, 3)\n\n        # Try eager call\n        x = {\n            \"query\": query,\n            \"key\": key,\n            \"value\": value,\n        }\n        y = np.random.random((1, 2, 2))\n        _ = layer(**x)\n\n        # Try calling fit()\n        inputs = {\n            \"query\": layers.Input((2, 2)),\n            \"key\": layers.Input((2, 2)),\n            \"value\": layers.Input((2, 2)),\n        }\n        outputs = layer(inputs[\"query\"], inputs[\"key\"], inputs[\"value\"])\n        model = models.Model(inputs, outputs)\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        model.fit(x, y)\n\n        # Try saving and reloading the model\n        temp_filepath = os.path.join(self.get_temp_dir(), \"lora_model.keras\")\n        model.save(temp_filepath)\n\n        new_model = saving.load_model(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n        # Try saving and reloading the model's weights only\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"lora_model.weights.h5\"\n        )\n        model.save_weights(temp_filepath)\n\n        # Load the file into a fresh, non-lora model\n        inputs = {\n            \"query\": layers.Input((2, 2)),\n            \"key\": layers.Input((2, 2)),\n            \"value\": layers.Input((2, 2)),\n        }\n        outputs = layers.MultiHeadAttention(\n            num_heads=3,\n            
key_dim=8,\n            use_bias=False,\n            use_gate=True,\n        )(inputs[\"query\"], inputs[\"key\"], inputs[\"value\"])\n        new_model = models.Model(inputs, outputs)\n\n        new_model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n        # Try loading a normal checkpoint into a lora model\n        new_model.save_weights(temp_filepath)\n        model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n    @parameterized.parameters([((1, 2, 3),), ((2, 3, 5),)])\n    def test_symbolic_return_attention_scores(self, shape):\n        mha = layers.MultiHeadAttention(num_heads=4, key_dim=2)\n        x = layers.Input(batch_shape=shape)\n        y = layers.Input(batch_shape=shape)\n        symbolic_out = mha(x, y, return_attention_scores=True)\n        self.assertLen(symbolic_out, 2)\n\n        x = np.random.random(shape)\n        y = np.random.random(shape)\n        out = mha(x, y, return_attention_scores=True)\n        self.assertLen(out, 2)\n        self.assertEqual(symbolic_out[0].shape, out[0].shape)\n        self.assertEqual(symbolic_out[1].shape, out[1].shape)\n\n    def test_dtype_policy_map(self):\n        quantized_policy = dtype_policies.QuantizedDTypePolicy(\n            \"int8\", \"float32\"\n        )\n        policy_map = dtype_policies.DTypePolicyMap()\n\n        # Preset the quantized policy\n        policy_map[\"mha/query\"] = quantized_policy\n        policy_map[\"mha/key\"] = quantized_policy\n        policy_map[\"mha/value\"] = quantized_policy\n        query = np.array([[[1.0, 0.0], [0.0, 1.0]]])\n        key = np.array([[[0.0, 1.0], [1.0, 0.0]]])\n        value = np.array([[[1.0, 2.0], [3.0, 4.0]]])\n        layer = layers.MultiHeadAttention(\n            num_heads=3, key_dim=8, use_bias=False, dtype=policy_map, name=\"mha\"\n        )\n        layer.build(query.shape, key.shape, value.shape)\n\n        # Sublayers should be 
quantized\n        self.assertDType(layer._query_dense._kernel, \"int8\")\n        self.assertDType(layer._key_dense._kernel, \"int8\")\n        self.assertDType(layer._value_dense._kernel, \"int8\")\n\n    def test_flash_attention_with_errors(self):\n        if backend.backend() in (\"numpy\", \"tensorflow\"):\n            pytest.skip(\n                reason=(\n                    \"Flash attention is not supported on tensorflow and numpy.\"\n                )\n            )\n        # Check `flash_attention=True` and `dropout=0.1`\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Dropout is not supported when flash attention is enabled.\",\n        ):\n            layer = layers.MultiHeadAttention(\n                num_heads=2, key_dim=2, flash_attention=True, dropout=0.1\n            )\n\n        # Check `flash_attention=True` and `return_attention_scores=True`\n        layer = layers.MultiHeadAttention(\n            num_heads=2, key_dim=2, flash_attention=True\n        )\n        self.assertTrue(layer._flash_attention)\n        query = np.random.random((2, 4, 8))\n        value = np.random.random((2, 4, 8))\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Returning attention scores is not supported when flash \"\n            \"attention is enabled. 
Please disable flash attention to access\"\n            \" attention scores.\",\n        ):\n            layer(query=query, value=value, return_attention_scores=True)\n\n    def test_multi_head_attention_output_shape_as_int(self):\n        \"\"\"Test MultiHeadAttention with output_shape as an int.\"\"\"\n        mha = layers.MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8)\n        query = random.uniform((2, 4, 16))\n        value = random.uniform((2, 4, 16))\n        output = mha(query=query, value=value)\n\n        self.assertEqual(output.shape, (2, 4, 8))\n\n    def test_multi_head_attention_output_shape_as_tuple(self):\n        \"\"\"Test MultiHeadAttention with output_shape as a tuple.\"\"\"\n        mha = layers.MultiHeadAttention(\n            num_heads=2, key_dim=16, output_shape=(8, 8)\n        )\n        query = random.uniform((2, 4, 16))\n        value = random.uniform((2, 4, 16))\n        output = mha(query=query, value=value)\n\n        self.assertEqual(output.shape, (2, 4, 8, 8))\n\n    def test_multi_head_attention_output_shape_error(self):\n        with self.assertRaisesRegex(ValueError, r\"Invalid `output_shape`\"):\n            layers.MultiHeadAttention(num_heads=2, key_dim=16, output_shape=8.0)\n\n    def test_quantize_int8(self):\n        query = np.array([[[1.0, 0.0], [0.0, 1.0]]])\n        key = np.array([[[0.0, 1.0], [1.0, 0.0]]])\n        value = np.array([[[1.0, 2.0], [3.0, 4.0]]])\n        layer = layers.MultiHeadAttention(\n            num_heads=3,\n            key_dim=8,\n            use_bias=False,\n        )\n        layer.build(query.shape, value.shape, key.shape)\n        output_float = layer(query, key, value)\n        for sublayer in layer._flatten_layers():\n            try:\n                sublayer.quantize(\"int8\")\n            except Exception:\n                pass\n\n        # Verify weights dtype\n        self.assertDType(layer._query_dense._kernel, \"int8\")\n        self.assertDType(layer._key_dense._kernel, \"int8\")\n        self.assertDType(layer._value_dense._kernel, \"int8\")\n        self.assertDType(layer._output_dense._kernel, \"int8\")\n\n        # Try eager call and verify output correctness\n        output_quantized = layer(query, key, value)\n        mse = ops.mean(ops.square(output_float - output_quantized))\n        self.assertLess(mse, 1e-3)  # A weak correctness test\n\n        layer = layers.MultiHeadAttention(\n            num_heads=3,\n            key_dim=8,\n            use_gate=True,\n            use_bias=False,\n        )\n        layer.build(query.shape, value.shape, key.shape)\n        output_float = layer(query, key, value)\n        for sublayer in layer._flatten_layers():\n            try:\n                sublayer.quantize(\"int8\")\n            except Exception:\n                pass\n\n        # Verify weights dtype\n        self.assertDType(layer._query_dense._kernel, \"int8\")\n        self.assertDType(layer._key_dense._kernel, \"int8\")\n        self.assertDType(layer._value_dense._kernel, \"int8\")\n        self.assertDType(layer._gate_dense._kernel, \"int8\")\n        self.assertDType(layer._output_dense._kernel, \"int8\")\n"
  },
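The `test_quantize_int8` cases in the test file above assert that int8-quantized kernels reproduce the float output to within a small MSE. The general symmetric per-tensor int8 scheme they exercise can be sketched standalone in NumPy (an illustration of the round-trip arithmetic, not Keras's actual quantizer; `quantize_int8`/`dequantize` are hypothetical helper names):

```python
import numpy as np

def quantize_int8(w):
    # Symmetric per-tensor quantization: map the largest |w| to 127.
    scale = np.max(np.abs(w)) / 127.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q, scale):
    return q.astype(np.float32) * scale

w = np.random.default_rng(0).normal(size=(8, 8)).astype(np.float32)
q, scale = quantize_int8(w)
w_hat = dequantize(q, scale)
# Round-to-nearest bounds the per-weight error by scale / 2, which is why
# a loose MSE threshold like 1e-3 is a reasonable "weak correctness test".
mse = float(np.mean((w - w_hat) ** 2))
```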
  {
    "path": "keras/src/layers/convolutional/__init__.py",
    "content": ""
  },
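Both `test_lora` in the attention tests above and the `kernel` property of `BaseConv` below combine a frozen base kernel with a low-rank delta scaled by `lora_alpha / lora_rank`. The arithmetic can be sketched in plain NumPy (shapes and values here are arbitrary illustrative choices, not taken from the tests):

```python
import numpy as np

rng = np.random.default_rng(42)
in_dim, out_dim, rank, alpha = 16, 8, 2, 2

kernel = rng.normal(size=(in_dim, out_dim))       # frozen base kernel
lora_kernel_a = rng.normal(size=(in_dim, rank))   # trainable factor A
lora_kernel_b = np.zeros((rank, out_dim))         # trainable factor B, zero-init

# Effective kernel used in the forward pass:
#   kernel + (lora_alpha / lora_rank) * (A @ B)
effective = kernel + (alpha / rank) * (lora_kernel_a @ lora_kernel_b)
```

Because B is zero-initialized, enabling LoRA leaves the layer's output unchanged until training updates A and B; only the `rank * (in_dim + out_dim)` low-rank parameters are trainable.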
  {
    "path": "keras/src/layers/convolutional/base_conv.py",
    "content": "\"\"\"Keras base class for convolution layers.\"\"\"\n\nfrom keras.src import activations\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src.backend import standardize_data_format\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.ops.operation_utils import compute_conv_output_shape\nfrom keras.src.utils.argument_validation import standardize_padding\nfrom keras.src.utils.argument_validation import standardize_tuple\n\n\nclass BaseConv(Layer):\n    \"\"\"Abstract N-D convolution layer (private, used as implementation base).\n\n    This layer creates a convolution kernel that is convolved (actually\n    cross-correlated) with the layer input to produce a tensor of outputs. If\n    `use_bias` is True (and a `bias_initializer` is provided), a bias vector is\n    created and added to the outputs. Finally, if `activation` is not `None`, it\n    is applied to the outputs as well.\n\n    Note: layer attributes cannot be modified after the layer has been called\n    once (except the `trainable` attribute).\n\n    Args:\n        rank: int, the rank of the convolution, e.g. 2 for 2D convolution.\n        filters: int, the dimension of the output space (the number of filters\n            in the convolution).\n        kernel_size: int or tuple/list of `rank` integers, specifying the size\n            of the convolution window.\n        strides: int or tuple/list of `rank` integers, specifying the stride\n            length of the convolution. If only one int is specified, the same\n            stride size will be used for all dimensions. `strides > 1` is\n            incompatible with `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. 
`\"same\"` results in padding evenly to\n            the left/right or up/down of the input. When `padding=\"same\"` and\n            `strides=1`, the output has the same size as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, steps, features)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, steps)`. It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of `rank` integers, specifying the\n            dilation rate to use for dilated convolution. If only one int is\n            specified, the same dilation rate will be used for all dimensions.\n        groups: A positive int specifying the number of groups in which the\n            input is split along the channel axis. Each group is convolved\n            separately with `filters // groups` filters. The output is the\n            concatenation of all the `groups` results along the channel axis.\n            Input channels and `filters` must both be divisible by `groups`.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        kernel_initializer: Initializer for the convolution kernel. If `None`,\n            the default initializer (`\"glorot_uniform\"`) will be used.\n        bias_initializer: Initializer for the bias vector. 
If `None`, the\n            default initializer (`\"zeros\"`) will be used.\n        kernel_regularizer: Optional regularizer for the convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        kernel_constraint: Optional projection function to be applied to the\n            kernel after being updated by an `Optimizer` (e.g. used to implement\n            norm constraints or value constraints for layer weights). The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape). Constraints\n            are not safe to use when doing asynchronous distributed training.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n        lora_rank: Optional integer. If set, the layer's forward pass\n            will implement LoRA (Low-Rank Adaptation)\n            with the provided rank. LoRA sets the layer's kernel\n            to non-trainable and replaces it with a delta over the\n            original kernel, obtained via multiplying two lower-rank\n            trainable matrices. This can be useful to reduce the\n            computation cost of fine-tuning large dense layers.\n            You can also enable LoRA on an existing layer by calling\n            `layer.enable_lora(rank)`.\n        lora_alpha: Optional integer. If set, this parameter scales the\n            low-rank adaptation delta (computed as the product of two lower-rank\n            trainable matrices) during the forward pass. 
The delta is scaled by\n            `lora_alpha / lora_rank`, allowing you to fine-tune the strength of\n            the LoRA adjustment independently of `lora_rank`.\n    \"\"\"\n\n    def __init__(\n        self,\n        rank,\n        filters,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        groups=1,\n        activation=None,\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        bias_constraint=None,\n        lora_rank=None,\n        lora_alpha=None,\n        **kwargs,\n    ):\n        super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n        self.rank = rank\n        self.filters = filters\n        self.groups = groups\n        self.kernel_size = standardize_tuple(kernel_size, rank, \"kernel_size\")\n        self.strides = standardize_tuple(strides, rank, \"strides\")\n        self.dilation_rate = standardize_tuple(\n            dilation_rate, rank, \"dilation_rate\"\n        )\n        self.padding = standardize_padding(padding, allow_causal=rank == 1)\n        self.data_format = standardize_data_format(data_format)\n        self.activation = activations.get(activation)\n        self.use_bias = use_bias\n        self.kernel_initializer = initializers.get(kernel_initializer)\n        self.bias_initializer = initializers.get(bias_initializer)\n        self.kernel_regularizer = regularizers.get(kernel_regularizer)\n        self.bias_regularizer = regularizers.get(bias_regularizer)\n        self.kernel_constraint = constraints.get(kernel_constraint)\n        self.bias_constraint = constraints.get(bias_constraint)\n        self.lora_rank = lora_rank\n        self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank\n        self.lora_enabled = False\n    
    self.input_spec = InputSpec(min_ndim=self.rank + 2)\n\n        if self.filters is not None and self.filters <= 0:\n            raise ValueError(\n                \"Invalid value for argument `filters`. Expected a strictly \"\n                f\"positive value. Received filters={self.filters}.\"\n            )\n\n        if self.groups <= 0:\n            raise ValueError(\n                \"The number of groups must be a positive integer. \"\n                f\"Received: groups={self.groups}.\"\n            )\n\n        if self.filters is not None and self.filters % self.groups != 0:\n            raise ValueError(\n                \"The number of filters must be evenly divisible by the \"\n                f\"number of groups. Received: groups={self.groups}, \"\n                f\"filters={self.filters}.\"\n            )\n\n        if not all(self.kernel_size):\n            raise ValueError(\n                \"The argument `kernel_size` cannot contain 0. Received \"\n                f\"kernel_size={self.kernel_size}.\"\n            )\n\n        if not all(self.strides):\n            raise ValueError(\n                \"The argument `strides` cannot contain 0. Received \"\n                f\"strides={self.strides}\"\n            )\n\n        if max(self.strides) > 1 and max(self.dilation_rate) > 1:\n            raise ValueError(\n                \"`strides > 1` not supported in conjunction with \"\n                f\"`dilation_rate > 1`. 
Received: strides={self.strides} and \"\n                f\"dilation_rate={self.dilation_rate}\"\n            )\n\n    def build(self, input_shape):\n        if self.data_format == \"channels_last\":\n            channel_axis = -1\n            input_channel = input_shape[-1]\n        else:\n            channel_axis = 1\n            input_channel = input_shape[1]\n        self.input_spec = InputSpec(\n            min_ndim=self.rank + 2, axes={channel_axis: input_channel}\n        )\n        if input_channel % self.groups != 0:\n            raise ValueError(\n                \"The number of input channels must be evenly divisible by \"\n                f\"the number of groups. Received groups={self.groups}, but the \"\n                f\"input has {input_channel} channels (full input shape is \"\n                f\"{input_shape}).\"\n            )\n        kernel_shape = self.kernel_size + (\n            input_channel // self.groups,\n            self.filters,\n        )\n\n        # compute_output_shape contains some validation logic for the input\n        # shape, and make sure the output shape has all positive dimensions.\n        self.compute_output_shape(input_shape)\n\n        self._kernel = self.add_weight(\n            name=\"kernel\",\n            shape=kernel_shape,\n            initializer=self.kernel_initializer,\n            regularizer=self.kernel_regularizer,\n            constraint=self.kernel_constraint,\n            trainable=True,\n            dtype=self.dtype,\n        )\n        if self.use_bias:\n            self.bias = self.add_weight(\n                name=\"bias\",\n                shape=(self.filters,),\n                initializer=self.bias_initializer,\n                regularizer=self.bias_regularizer,\n                constraint=self.bias_constraint,\n                trainable=True,\n                dtype=self.dtype,\n            )\n        else:\n            self.bias = None\n        self.built = True\n        if self.lora_rank:\n       
     self.enable_lora(self.lora_rank, lora_alpha=self.lora_alpha)\n\n    @property\n    def kernel(self):\n        if not self.built:\n            raise AttributeError(\n                \"You must build the layer before accessing `kernel`.\"\n            )\n        if self.lora_enabled:\n            return self._kernel + (\n                self.lora_alpha / self.lora_rank\n            ) * ops.matmul(self.lora_kernel_a, self.lora_kernel_b)\n        return self._kernel\n\n    def convolution_op(self, inputs, kernel):\n        return ops.conv(\n            inputs,\n            kernel,\n            strides=list(self.strides),\n            padding=self.padding,\n            dilation_rate=self.dilation_rate,\n            data_format=self.data_format,\n        )\n\n    def call(self, inputs):\n        outputs = self.convolution_op(\n            inputs,\n            self.kernel,\n        )\n        if self.use_bias:\n            if self.data_format == \"channels_last\":\n                bias_shape = (1,) * (self.rank + 1) + (self.filters,)\n            else:\n                bias_shape = (1, self.filters) + (1,) * self.rank\n            bias = ops.reshape(self.bias, bias_shape)\n            outputs = ops.add(outputs, bias)\n\n        if self.activation is not None:\n            return self.activation(outputs)\n        return outputs\n\n    def compute_output_shape(self, input_shape):\n        return compute_conv_output_shape(\n            input_shape,\n            self.filters,\n            self.kernel_size,\n            strides=self.strides,\n            padding=self.padding,\n            data_format=self.data_format,\n            dilation_rate=self.dilation_rate,\n        )\n\n    def enable_lora(\n        self,\n        rank,\n        lora_alpha=None,\n        a_initializer=\"he_uniform\",\n        b_initializer=\"zeros\",\n    ):\n        if self.kernel_constraint:\n            raise ValueError(\n                \"Lora is incompatible with kernel constraints. 
\"\n                \"In order to enable lora on this layer, remove the \"\n                \"`kernel_constraint` argument.\"\n            )\n        if not self.built:\n            raise ValueError(\n                \"Cannot enable lora on a layer that isn't yet built.\"\n            )\n        if self.lora_enabled:\n            raise ValueError(\n                \"lora is already enabled. This can only be done once per layer.\"\n            )\n        self._tracker.unlock()\n        self.lora_kernel_a = self.add_weight(\n            name=\"lora_kernel_a\",\n            shape=self._kernel.shape[:-1] + (rank,),\n            initializer=initializers.get(a_initializer),\n            regularizer=self.kernel_regularizer,\n        )\n        self.lora_kernel_b = self.add_weight(\n            name=\"lora_kernel_b\",\n            shape=(rank, self.filters),\n            initializer=initializers.get(b_initializer),\n            regularizer=self.kernel_regularizer,\n        )\n        self._kernel.trainable = False\n        self._tracker.lock()\n        self.lora_enabled = True\n        self.lora_rank = rank\n        self.lora_alpha = lora_alpha if lora_alpha is not None else rank\n\n    def save_own_variables(self, store):\n        # Do nothing if the layer isn't yet built\n        if not self.built:\n            return\n        target_variables = [self.kernel]\n        if self.use_bias:\n            target_variables.append(self.bias)\n        for i, variable in enumerate(target_variables):\n            store[str(i)] = variable\n\n    def load_own_variables(self, store):\n        if not self.lora_enabled:\n            self._check_load_own_variables(store)\n        # Do nothing if the layer isn't yet built\n        if not self.built:\n            return\n        target_variables = [self._kernel]\n        if self.use_bias:\n            target_variables.append(self.bias)\n        for i, variable in enumerate(target_variables):\n            variable.assign(store[str(i)])\n     
   if self.lora_enabled:\n            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))\n            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"filters\": self.filters,\n                \"kernel_size\": self.kernel_size,\n                \"strides\": self.strides,\n                \"padding\": self.padding,\n                \"data_format\": self.data_format,\n                \"dilation_rate\": self.dilation_rate,\n                \"groups\": self.groups,\n                \"activation\": activations.serialize(self.activation),\n                \"use_bias\": self.use_bias,\n                \"kernel_initializer\": initializers.serialize(\n                    self.kernel_initializer\n                ),\n                \"bias_initializer\": initializers.serialize(\n                    self.bias_initializer\n                ),\n                \"kernel_regularizer\": regularizers.serialize(\n                    self.kernel_regularizer\n                ),\n                \"bias_regularizer\": regularizers.serialize(\n                    self.bias_regularizer\n                ),\n                \"activity_regularizer\": regularizers.serialize(\n                    self.activity_regularizer\n                ),\n                \"kernel_constraint\": constraints.serialize(\n                    self.kernel_constraint\n                ),\n                \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            }\n        )\n        if self.lora_rank:\n            config[\"lora_rank\"] = self.lora_rank\n            config[\"lora_alpha\"] = self.lora_alpha\n        return config\n\n    def _check_load_own_variables(self, store):\n        all_vars = self._trainable_variables + self._non_trainable_variables\n        if len(store.keys()) != len(all_vars):\n            if len(all_vars) == 0 and not 
self.built:\n                raise ValueError(\n                    f\"Layer '{self.name}' was never built \"\n                    \"and thus it doesn't have any variables. \"\n                    f\"However the weights file lists {len(store.keys())} \"\n                    \"variables for this layer.\\n\"\n                    \"In most cases, this error indicates that either:\\n\\n\"\n                    \"1. The layer is owned by a parent layer that \"\n                    \"implements a `build()` method, but calling the \"\n                    \"parent's `build()` method did NOT create the state of \"\n                    f\"the child layer '{self.name}'. A `build()` method \"\n                    \"must create ALL state for the layer, including \"\n                    \"the state of any children layers.\\n\\n\"\n                    \"2. You need to implement \"\n                    \"the `def build_from_config(self, config)` method \"\n                    f\"on layer '{self.name}', to specify how to rebuild \"\n                    \"it during loading. \"\n                    \"In this case, you might also want to implement the \"\n                    \"method that generates the build config at saving time, \"\n                    \"`def get_build_config(self)`. \"\n                    \"The method `build_from_config()` is meant \"\n                    \"to create the state \"\n                    \"of the layer (i.e. its variables) upon deserialization.\",\n                )\n            raise ValueError(\n                f\"Layer '{self.name}' expected {len(all_vars)} variables, \"\n                \"but received \"\n                f\"{len(store.keys())} variables during loading. \"\n                f\"Expected: {[v.name for v in all_vars]}\"\n            )\n"
  },
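`BaseConv.compute_output_shape` above delegates to `compute_conv_output_shape`. For a single spatial dimension, the standard output-length arithmetic behind that utility can be sketched as follows (a simplified re-derivation under the usual conv-shape formulas, not the Keras implementation; `conv_output_length` is a hypothetical helper name):

```python
import math

def conv_output_length(length, kernel_size, stride=1, padding="valid",
                       dilation=1):
    # Effective kernel extent once dilation gaps are counted.
    k_eff = dilation * (kernel_size - 1) + 1
    if padding == "valid":
        # Number of valid window positions, stepping by `stride`.
        return (length - k_eff) // stride + 1
    if padding == "same":
        # Pad so that output size depends only on `length` and `stride`.
        return math.ceil(length / stride)
    raise ValueError(f"Unknown padding: {padding}")
```

For example, a length-10 input with a size-3 kernel gives 8 outputs under `"valid"` padding and 10 under `"same"`, matching the docstring's note that `padding="same"` with `strides=1` preserves the input size.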
  {
    "path": "keras/src/layers/convolutional/base_conv_transpose.py",
    "content": "\"\"\"Keras base class for transpose convolution layers.\"\"\"\n\nfrom keras.src import activations\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src.backend import standardize_data_format\nfrom keras.src.backend.common.backend_utils import (\n    compute_conv_transpose_output_shape,\n)\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.utils.argument_validation import standardize_padding\nfrom keras.src.utils.argument_validation import standardize_tuple\n\n\nclass BaseConvTranspose(Layer):\n    \"\"\"Abstract N-D transposed convolution layer.\n\n    The need for transposed convolutions generally arises from the desire to use\n    a transformation going in the opposite direction of a normal convolution,\n    i.e., from something that has the shape of the output of some convolution to\n    something that has the shape of its input while maintaining a connectivity\n    pattern that is compatible with said convolution.\n\n    Args:\n        rank: int, the rank of the transposed convolution, e.g. 2 for 2D\n            transposed convolution.\n        filters: int, the dimension of the output space (the number of filters\n            in the transposed convolution).\n        kernel_size: int or tuple/list of `rank` integers, specifying the size\n            of the transposed convolution window.\n        strides: int or tuple/list of `rank` integers, specifying the stride\n            length of the transposed convolution. If only one int is specified,\n            the same stride size will be used for all dimensions.\n            `strides > 1` is incompatible with `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. 
`\"same\"` results in padding evenly to\n            the left/right or up/down of the input such that output has the same\n            height/width dimension as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, steps, features)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, steps)`. It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of `rank` integers, specifying the\n            dilation rate to use for dilated convolution. If only one int is\n            specified, the same dilation rate will be used for all dimensions.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        kernel_initializer: Initializer for the convolution kernel. If `None`,\n            the default initializer (`\"glorot_uniform\"`) will be used.\n        bias_initializer: Initializer for the bias vector. If `None`, the\n            default initializer (`\"zeros\"`) will be used.\n        kernel_regularizer: Optional regularizer for the convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        kernel_constraint: Optional projection function to be applied to the\n            kernel after being updated by an `Optimizer` (e.g. used to implement\n            norm constraints or value constraints for layer weights). The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape). 
Constraints\n            are not safe to use when doing asynchronous distributed training.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n    \"\"\"\n\n    def __init__(\n        self,\n        rank,\n        filters,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        output_padding=None,\n        data_format=None,\n        dilation_rate=1,\n        activation=None,\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        bias_constraint=None,\n        trainable=True,\n        name=None,\n        **kwargs,\n    ):\n        super().__init__(\n            trainable=trainable,\n            name=name,\n            activity_regularizer=activity_regularizer,\n            **kwargs,\n        )\n        self.rank = rank\n        self.filters = filters\n        self.kernel_size = standardize_tuple(kernel_size, rank, \"kernel_size\")\n        self.strides = standardize_tuple(strides, rank, \"strides\")\n        self.dilation_rate = standardize_tuple(\n            dilation_rate, rank, \"dilation_rate\"\n        )\n        self.padding = standardize_padding(padding)\n        if output_padding is None:\n            self.output_padding = None\n        else:\n            self.output_padding = standardize_tuple(\n                output_padding,\n                rank,\n                \"output_padding\",\n                allow_zero=True,\n            )\n        self.data_format = standardize_data_format(data_format)\n        self.activation = activations.get(activation)\n        self.use_bias = use_bias\n        self.kernel_initializer = initializers.get(kernel_initializer)\n        self.bias_initializer = initializers.get(bias_initializer)\n        self.kernel_regularizer = 
regularizers.get(kernel_regularizer)\n        self.bias_regularizer = regularizers.get(bias_regularizer)\n        self.kernel_constraint = constraints.get(kernel_constraint)\n        self.bias_constraint = constraints.get(bias_constraint)\n        self.input_spec = InputSpec(min_ndim=self.rank + 2)\n\n        if self.filters is not None and self.filters <= 0:\n            raise ValueError(\n                \"Invalid value for argument `filters`. Expected a strictly \"\n                f\"positive value. Received filters={self.filters}.\"\n            )\n\n        if not all(self.kernel_size):\n            raise ValueError(\n                \"The argument `kernel_size` cannot contain 0. Received \"\n                f\"kernel_size={self.kernel_size}.\"\n            )\n\n        if not all(self.strides):\n            raise ValueError(\n                \"The argument `strides` cannot contain 0. Received \"\n                f\"strides={self.strides}.\"\n            )\n\n        if max(self.strides) > 1 and max(self.dilation_rate) > 1:\n            raise ValueError(\n                \"`strides > 1` not supported in conjunction with \"\n                f\"`dilation_rate > 1`. 
Received: strides={self.strides} and \"\n                f\"dilation_rate={self.dilation_rate}\"\n            )\n\n        if self.output_padding is not None:\n            for i, (op, s) in enumerate(zip(self.output_padding, self.strides)):\n                if op >= s:\n                    raise ValueError(\n                        \"Invalid `output_padding` argument. \"\n                        \"Each value in `output_padding` must be strictly \"\n                        \"less than the corresponding `strides` value.\\n\"\n                        f\"At index {i}, `output_padding` is {op} and `strides` \"\n                        f\"is {s}.\\n\"\n                        f\"Received: output_padding={self.output_padding}, \"\n                        f\"strides={self.strides}.\"\n                    )\n\n    def build(self, input_shape):\n        if self.data_format == \"channels_last\":\n            channel_axis = -1\n            input_channel = input_shape[-1]\n        else:\n            channel_axis = 1\n            input_channel = input_shape[1]\n        self.input_spec = InputSpec(\n            min_ndim=self.rank + 2, axes={channel_axis: input_channel}\n        )\n        kernel_shape = self.kernel_size + (\n            self.filters,\n            input_channel,\n        )\n\n        self.kernel = self.add_weight(\n            name=\"kernel\",\n            shape=kernel_shape,\n            initializer=self.kernel_initializer,\n            regularizer=self.kernel_regularizer,\n            constraint=self.kernel_constraint,\n            trainable=True,\n            dtype=self.dtype,\n        )\n        if self.use_bias:\n            self.bias = self.add_weight(\n                name=\"bias\",\n                shape=(self.filters,),\n                initializer=self.bias_initializer,\n                regularizer=self.bias_regularizer,\n                constraint=self.bias_constraint,\n                trainable=True,\n                dtype=self.dtype,\n            )\n  
      else:\n            self.bias = None\n\n    def call(self, inputs):\n        outputs = ops.conv_transpose(\n            inputs,\n            self.kernel,\n            strides=list(self.strides),\n            padding=self.padding,\n            output_padding=self.output_padding,\n            dilation_rate=self.dilation_rate,\n            data_format=self.data_format,\n        )\n\n        if self.use_bias:\n            if self.data_format == \"channels_last\":\n                bias_shape = (1,) * (self.rank + 1) + (self.filters,)\n            else:\n                bias_shape = (1, self.filters) + (1,) * self.rank\n            bias = ops.reshape(self.bias, bias_shape)\n            outputs = ops.add(outputs, bias)\n\n        if self.activation is not None:\n            return self.activation(outputs)\n        return outputs\n\n    def compute_output_shape(self, input_shape):\n        return compute_conv_transpose_output_shape(\n            input_shape,\n            self.kernel_size,\n            self.filters,\n            strides=self.strides,\n            padding=self.padding,\n            output_padding=self.output_padding,\n            data_format=self.data_format,\n            dilation_rate=self.dilation_rate,\n        )\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"filters\": self.filters,\n                \"kernel_size\": self.kernel_size,\n                \"strides\": self.strides,\n                \"padding\": self.padding,\n                \"data_format\": self.data_format,\n                \"dilation_rate\": self.dilation_rate,\n                \"activation\": activations.serialize(self.activation),\n                \"use_bias\": self.use_bias,\n                \"kernel_initializer\": initializers.serialize(\n                    self.kernel_initializer\n                ),\n                \"bias_initializer\": initializers.serialize(\n                    
self.bias_initializer\n                ),\n                \"kernel_regularizer\": regularizers.serialize(\n                    self.kernel_regularizer\n                ),\n                \"bias_regularizer\": regularizers.serialize(\n                    self.bias_regularizer\n                ),\n                \"activity_regularizer\": regularizers.serialize(\n                    self.activity_regularizer\n                ),\n                \"kernel_constraint\": constraints.serialize(\n                    self.kernel_constraint\n                ),\n                \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            }\n        )\n        return config\n"
  },
  {
    "path": "keras/src/layers/convolutional/base_depthwise_conv.py",
    "content": "\"\"\"Keras base class for depthwise convolution layers.\"\"\"\n\nfrom keras.src import activations\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src.backend import standardize_data_format\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.ops.operation_utils import compute_conv_output_shape\nfrom keras.src.utils.argument_validation import standardize_padding\nfrom keras.src.utils.argument_validation import standardize_tuple\n\n\nclass BaseDepthwiseConv(Layer):\n    \"\"\"Abstract N-D depthwise convolution layer.\n\n    Depthwise convolution is a type of convolution in which each input channel\n    is convolved with a different kernel (called a depthwise kernel). You can\n    understand depthwise convolution as the first step in a depthwise separable\n    convolution.\n\n    It is implemented via the following steps:\n\n    - Split the input into individual channels.\n    - Convolve each channel with an individual depthwise kernel with\n      `depth_multiplier` output channels.\n    - Concatenate the convolved outputs along the channels axis.\n\n    Unlike a regular convolution, depthwise convolution does not mix information\n    across different input channels.\n\n    The `depth_multiplier` argument determines how many filters are applied to\n    one input channel. As such, it controls the number of output channels that\n    are generated per input channel in the depthwise step.\n\n    Args:\n        rank: int, the rank of the convolution, e.g. 2 for 2D convolution.\n        depth_multiplier: The number of depthwise convolution output channels\n            for each input channel. 
The total number of depthwise convolution\n            output channels will be equal to `input_channel * depth_multiplier`.\n        kernel_size: int or tuple/list of `rank` integers, specifying the size\n            of the depthwise convolution window.\n        strides: int or tuple/list of `rank` integers, specifying the stride\n            length of the depthwise convolution. If only one int is specified,\n            the same stride size will be used for all dimensions.\n            `strides > 1` is incompatible with `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input. When `padding=\"same\"` and\n            `strides=1`, the output has the same size as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, steps, features)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, steps)`. It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of `rank` integers, specifying the\n            dilation rate to use for dilated convolution. If only one int is\n            specified, the same dilation rate will be used for all dimensions.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        depthwise_initializer: Initializer for the depthwise convolution\n            kernel. 
If `None`, the default initializer (`\"glorot_uniform\"`)\n            will be used.\n        bias_initializer: Initializer for the bias vector. If `None`, the\n            default initializer (`\"zeros\"`) will be used.\n        depthwise_regularizer: Optional regularizer for the convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        depthwise_constraint: Optional projection function to be applied to the\n            kernel after being updated by an `Optimizer` (e.g. used to implement\n            norm constraints or value constraints for layer weights). The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape). Constraints\n            are not safe to use when doing asynchronous distributed training.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n    \"\"\"\n\n    def __init__(\n        self,\n        rank,\n        depth_multiplier,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        activation=None,\n        use_bias=True,\n        depthwise_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        depthwise_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        depthwise_constraint=None,\n        bias_constraint=None,\n        trainable=True,\n        name=None,\n        **kwargs,\n    ):\n        super().__init__(\n            trainable=trainable,\n            name=name,\n            activity_regularizer=regularizers.get(activity_regularizer),\n            **kwargs,\n        )\n        self.rank = rank\n        self.depth_multiplier = depth_multiplier\n        self.kernel_size = standardize_tuple(kernel_size, rank, 
\"kernel_size\")\n        self.strides = standardize_tuple(strides, rank, \"strides\")\n        self.dilation_rate = standardize_tuple(\n            dilation_rate, rank, \"dilation_rate\"\n        )\n        self.padding = standardize_padding(padding)\n        self.data_format = standardize_data_format(data_format)\n        self.activation = activations.get(activation)\n        self.use_bias = use_bias\n        self.depthwise_initializer = initializers.get(depthwise_initializer)\n        self.bias_initializer = initializers.get(bias_initializer)\n        self.depthwise_regularizer = regularizers.get(depthwise_regularizer)\n        self.bias_regularizer = regularizers.get(bias_regularizer)\n        self.depthwise_constraint = constraints.get(depthwise_constraint)\n        self.bias_constraint = constraints.get(bias_constraint)\n        self.input_spec = InputSpec(min_ndim=self.rank + 2)\n\n        if self.depth_multiplier is not None and self.depth_multiplier <= 0:\n            raise ValueError(\n                \"Invalid value for argument `depth_multiplier`. Expected a \"\n                \"strictly positive value. Received \"\n                f\"depth_multiplier={self.depth_multiplier}.\"\n            )\n\n        if not all(self.kernel_size):\n            raise ValueError(\n                \"The argument `kernel_size` cannot contain 0. Received \"\n                f\"kernel_size={self.kernel_size}.\"\n            )\n\n        if not all(self.strides):\n            raise ValueError(\n                \"The argument `strides` cannot contain 0. Received \"\n                f\"strides={self.strides}.\"\n            )\n\n        if max(self.strides) > 1 and max(self.dilation_rate) > 1:\n            raise ValueError(\n                \"`strides > 1` not supported in conjunction with \"\n                f\"`dilation_rate > 1`. 
Received: strides={self.strides} and \"\n                f\"dilation_rate={self.dilation_rate}\"\n            )\n\n    def build(self, input_shape):\n        if self.data_format == \"channels_last\":\n            channel_axis = -1\n            input_channel = input_shape[-1]\n        else:\n            channel_axis = 1\n            input_channel = input_shape[1]\n        self.input_spec = InputSpec(\n            min_ndim=self.rank + 2, axes={channel_axis: input_channel}\n        )\n        depthwise_shape = self.kernel_size + (\n            input_channel,\n            self.depth_multiplier,\n        )\n        self.kernel = self.add_weight(\n            name=\"kernel\",\n            shape=depthwise_shape,\n            initializer=self.depthwise_initializer,\n            regularizer=self.depthwise_regularizer,\n            constraint=self.depthwise_constraint,\n            trainable=True,\n            dtype=self.dtype,\n        )\n        if self.use_bias:\n            self.bias = self.add_weight(\n                name=\"bias\",\n                shape=(self.depth_multiplier * input_channel,),\n                initializer=self.bias_initializer,\n                regularizer=self.bias_regularizer,\n                constraint=self.bias_constraint,\n                trainable=True,\n                dtype=self.dtype,\n            )\n        else:\n            self.bias = None\n\n    def _get_input_channel(self, input_shape):\n        if self.data_format == \"channels_last\":\n            input_channel = input_shape[-1]\n        else:\n            input_channel = input_shape[1]\n        return input_channel\n\n    def call(self, inputs):\n        input_channel = self._get_input_channel(inputs.shape)\n        outputs = ops.depthwise_conv(\n            inputs,\n            self.kernel,\n            strides=self.strides,\n            padding=self.padding,\n            dilation_rate=self.dilation_rate,\n            data_format=self.data_format,\n        )\n\n        if 
self.use_bias:\n            if self.data_format == \"channels_last\":\n                bias_shape = (1,) * (self.rank + 1) + (\n                    self.depth_multiplier * input_channel,\n                )\n            else:\n                bias_shape = (1, self.depth_multiplier * input_channel) + (\n                    1,\n                ) * self.rank\n            bias = ops.reshape(self.bias, bias_shape)\n            outputs = ops.add(outputs, bias)\n\n        if self.activation is not None:\n            return self.activation(outputs)\n        return outputs\n\n    def compute_output_shape(self, input_shape):\n        input_channel = self._get_input_channel(input_shape)\n        return compute_conv_output_shape(\n            input_shape,\n            self.depth_multiplier * input_channel,\n            self.kernel_size,\n            strides=self.strides,\n            padding=self.padding,\n            data_format=self.data_format,\n            dilation_rate=self.dilation_rate,\n        )\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"depth_multiplier\": self.depth_multiplier,\n                \"kernel_size\": self.kernel_size,\n                \"strides\": self.strides,\n                \"padding\": self.padding,\n                \"data_format\": self.data_format,\n                \"dilation_rate\": self.dilation_rate,\n                \"activation\": activations.serialize(self.activation),\n                \"use_bias\": self.use_bias,\n                \"depthwise_initializer\": initializers.serialize(\n                    self.depthwise_initializer\n                ),\n                \"bias_initializer\": initializers.serialize(\n                    self.bias_initializer\n                ),\n                \"depthwise_regularizer\": regularizers.serialize(\n                    self.depthwise_regularizer\n                ),\n                \"bias_regularizer\": 
regularizers.serialize(\n                    self.bias_regularizer\n                ),\n                \"activity_regularizer\": regularizers.serialize(\n                    self.activity_regularizer\n                ),\n                \"depthwise_constraint\": constraints.serialize(\n                    self.depthwise_constraint\n                ),\n                \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            }\n        )\n        return config\n"
  },
  {
    "path": "keras/src/layers/convolutional/base_separable_conv.py",
    "content": "\"\"\"Keras abstract base layer for separable convolution.\"\"\"\n\nfrom keras.src import activations\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src.backend import standardize_data_format\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.ops.operation_utils import compute_conv_output_shape\nfrom keras.src.utils.argument_validation import standardize_padding\nfrom keras.src.utils.argument_validation import standardize_tuple\n\n\nclass BaseSeparableConv(Layer):\n    \"\"\"Abstract base layer for separable convolution.\n\n    This layer performs a depthwise convolution that acts separately on\n    channels, followed by a pointwise convolution that mixes channels. If\n    `use_bias` is True and a bias initializer is provided, it adds a bias vector\n    to the output.\n\n    Args:\n        rank: int, the rank of the convolution, e.g. 2 for 2D convolution.\n        depth_multiplier: The number of depthwise convolution output channels\n            for each input channel. The total number of depthwise convolution\n            output channels will be equal to `input_channel * depth_multiplier`.\n        filters: int, the dimensionality of the output space (i.e. the number\n            of filters in the pointwise convolution).\n        kernel_size: int or tuple/list of `rank` integers, specifying the size\n            of the depthwise convolution window.\n        strides: int or tuple/list of `rank` integers, specifying the stride\n            length of the depthwise convolution. If only one int is specified,\n            the same stride size will be used for all dimensions.\n            `strides > 1` is incompatible with `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. 
`\"same\"` results in padding evenly to\n            the left/right or up/down of the input. When `padding=\"same\"` and\n            `strides=1`, the output has the same size as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, steps, features)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, steps)`. It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of `rank` integers, specifying the\n            dilation rate to use for dilated convolution. If only one int is\n            specified, the same dilation rate will be used for all dimensions.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        depthwise_initializer: An initializer for the depthwise convolution\n            kernel. If None, then the default initializer (`\"glorot_uniform\"`)\n            will be used.\n        pointwise_initializer: An initializer for the pointwise convolution\n            kernel. If None, then the default initializer (`\"glorot_uniform\"`)\n            will be used.\n        bias_initializer: An initializer for the bias vector. 
If None, the\n            default initializer (`\"zeros\"`) will be used.\n        depthwise_regularizer: Optional regularizer for the depthwise\n            convolution kernel.\n        pointwise_regularizer: Optional regularizer for the pointwise\n            convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        depthwise_constraint: Optional projection function to be applied to the\n            depthwise kernel after being updated by an `Optimizer` (e.g. used\n            for norm constraints or value constraints for layer weights). The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape).\n        pointwise_constraint: Optional projection function to be applied to the\n            pointwise kernel after being updated by an `Optimizer`.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n    \"\"\"\n\n    def __init__(\n        self,\n        rank,\n        depth_multiplier,\n        filters,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        activation=None,\n        use_bias=True,\n        depthwise_initializer=\"glorot_uniform\",\n        pointwise_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        depthwise_regularizer=None,\n        pointwise_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        depthwise_constraint=None,\n        pointwise_constraint=None,\n        bias_constraint=None,\n        trainable=True,\n        name=None,\n        **kwargs,\n    ):\n        super().__init__(\n            trainable=trainable,\n            name=name,\n            activity_regularizer=regularizers.get(activity_regularizer),\n 
           **kwargs,\n        )\n        self.rank = rank\n        self.depth_multiplier = depth_multiplier\n        self.filters = filters\n        self.kernel_size = standardize_tuple(kernel_size, rank, \"kernel_size\")\n        self.strides = standardize_tuple(strides, rank, \"strides\")\n        self.dilation_rate = standardize_tuple(\n            dilation_rate, rank, \"dilation_rate\"\n        )\n        self.padding = standardize_padding(padding)\n        self.data_format = standardize_data_format(data_format)\n        self.activation = activations.get(activation)\n        self.use_bias = use_bias\n        self.depthwise_initializer = initializers.get(depthwise_initializer)\n        self.pointwise_initializer = initializers.get(pointwise_initializer)\n        self.bias_initializer = initializers.get(bias_initializer)\n        self.depthwise_regularizer = regularizers.get(depthwise_regularizer)\n        self.pointwise_regularizer = regularizers.get(pointwise_regularizer)\n        self.bias_regularizer = regularizers.get(bias_regularizer)\n        self.depthwise_constraint = constraints.get(depthwise_constraint)\n        self.pointwise_constraint = constraints.get(pointwise_constraint)\n        self.bias_constraint = constraints.get(bias_constraint)\n\n        self.input_spec = InputSpec(min_ndim=self.rank + 2)\n\n        if self.depth_multiplier is not None and self.depth_multiplier <= 0:\n            raise ValueError(\n                \"Invalid value for argument `depth_multiplier`. Expected a \"\n                \"strictly positive value. Received \"\n                f\"depth_multiplier={self.depth_multiplier}.\"\n            )\n\n        if self.filters is not None and self.filters <= 0:\n            raise ValueError(\n                \"Invalid value for argument `filters`. Expected a strictly \"\n                f\"positive value. 
Received filters={self.filters}.\"\n            )\n\n        if not all(self.kernel_size):\n            raise ValueError(\n                \"The argument `kernel_size` cannot contain 0. Received: \"\n                f\"kernel_size={self.kernel_size}.\"\n            )\n\n        if not all(self.strides):\n            raise ValueError(\n                \"The argument `strides` cannot contain 0. Received: \"\n                f\"strides={self.strides}.\"\n            )\n\n        if max(self.strides) > 1 and max(self.dilation_rate) > 1:\n            raise ValueError(\n                \"`strides > 1` not supported in conjunction with \"\n                f\"`dilation_rate > 1`. Received: strides={self.strides} and \"\n                f\"dilation_rate={self.dilation_rate}\"\n            )\n\n    def build(self, input_shape):\n        if self.data_format == \"channels_last\":\n            channel_axis = -1\n            input_channel = input_shape[-1]\n        else:\n            channel_axis = 1\n            input_channel = input_shape[1]\n        self.input_spec = InputSpec(\n            min_ndim=self.rank + 2, axes={channel_axis: input_channel}\n        )\n        depthwise_kernel_shape = self.kernel_size + (\n            input_channel,\n            self.depth_multiplier,\n        )\n        pointwise_kernel_shape = (1,) * self.rank + (\n            self.depth_multiplier * input_channel,\n            self.filters,\n        )\n\n        self.depthwise_kernel = self.add_weight(\n            name=\"depthwise_kernel\",\n            shape=depthwise_kernel_shape,\n            initializer=self.depthwise_initializer,\n            regularizer=self.depthwise_regularizer,\n            constraint=self.depthwise_constraint,\n            trainable=True,\n            dtype=self.dtype,\n        )\n        self.pointwise_kernel = self.add_weight(\n            name=\"pointwise_kernel\",\n            shape=pointwise_kernel_shape,\n            initializer=self.pointwise_initializer,\n     
       regularizer=self.pointwise_regularizer,\n            constraint=self.pointwise_constraint,\n            trainable=True,\n            dtype=self.dtype,\n        )\n        if self.use_bias:\n            self.bias = self.add_weight(\n                name=\"bias\",\n                shape=(self.filters,),\n                initializer=self.bias_initializer,\n                regularizer=self.bias_regularizer,\n                constraint=self.bias_constraint,\n                trainable=True,\n                dtype=self.dtype,\n            )\n        else:\n            self.bias = None\n\n    def call(self, inputs):\n        outputs = ops.separable_conv(\n            inputs,\n            self.depthwise_kernel,\n            self.pointwise_kernel,\n            strides=self.strides,\n            padding=self.padding,\n            dilation_rate=self.dilation_rate,\n            data_format=self.data_format,\n        )\n\n        if self.use_bias:\n            if self.data_format == \"channels_last\":\n                bias_shape = (1,) * (self.rank + 1) + (self.filters,)\n            else:\n                bias_shape = (1, self.filters) + (1,) * self.rank\n            bias = ops.reshape(self.bias, bias_shape)\n            outputs = ops.add(outputs, bias)\n\n        if self.activation is not None:\n            return self.activation(outputs)\n        return outputs\n\n    def compute_output_shape(self, input_shape):\n        return compute_conv_output_shape(\n            input_shape,\n            self.filters,\n            self.kernel_size,\n            strides=self.strides,\n            padding=self.padding,\n            data_format=self.data_format,\n            dilation_rate=self.dilation_rate,\n        )\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"depth_multiplier\": self.depth_multiplier,\n                \"filters\": self.filters,\n                \"kernel_size\": self.kernel_size,\n    
            \"strides\": self.strides,\n                \"padding\": self.padding,\n                \"data_format\": self.data_format,\n                \"dilation_rate\": self.dilation_rate,\n                \"activation\": activations.serialize(self.activation),\n                \"use_bias\": self.use_bias,\n                \"depthwise_initializer\": initializers.serialize(\n                    self.depthwise_initializer\n                ),\n                \"pointwise_initializer\": initializers.serialize(\n                    self.pointwise_initializer\n                ),\n                \"bias_initializer\": initializers.serialize(\n                    self.bias_initializer\n                ),\n                \"depthwise_regularizer\": regularizers.serialize(\n                    self.depthwise_regularizer\n                ),\n                \"pointwise_regularizer\": regularizers.serialize(\n                    self.pointwise_regularizer\n                ),\n                \"bias_regularizer\": regularizers.serialize(\n                    self.bias_regularizer\n                ),\n                \"activity_regularizer\": regularizers.serialize(\n                    self.activity_regularizer\n                ),\n                \"depthwise_constraint\": constraints.serialize(\n                    self.depthwise_constraint\n                ),\n                \"pointwise_constraint\": constraints.serialize(\n                    self.pointwise_constraint\n                ),\n                \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            }\n        )\n        return config\n"
  },
  {
    "path": "keras/src/layers/convolutional/conv1d.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.convolutional.base_conv import BaseConv\n\n\n@keras_export([\"keras.layers.Conv1D\", \"keras.layers.Convolution1D\"])\nclass Conv1D(BaseConv):\n    \"\"\"1D convolution layer (e.g. temporal convolution).\n\n    This layer creates a convolution kernel that is convolved with the layer\n    input over a single spatial (or temporal) dimension to produce a tensor of\n    outputs. If `use_bias` is True, a bias vector is created and added to the\n    outputs. Finally, if `activation` is not `None`, it is applied to the\n    outputs as well.\n\n    Args:\n        filters: int, the dimension of the output space (the number of filters\n            in the convolution).\n        kernel_size: int or tuple/list of 1 integer, specifying the size of the\n            convolution window.\n        strides: int or tuple/list of 1 integer, specifying the stride length\n            of the convolution. `strides > 1` is incompatible with\n            `dilation_rate > 1`.\n        padding: string, `\"valid\"`, `\"same\"` or `\"causal\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input. When `padding=\"same\"` and\n            `strides=1`, the output has the same size as the input.\n            `\"causal\"` results in causal (dilated) convolutions, e.g.\n            `output[t]` does not depend on `input[t+1:]`. Useful when modeling\n            temporal data where the model should not violate the temporal\n            order. See [WaveNet: A Generative Model for Raw Audio, section 2.1](\n            https://arxiv.org/abs/1609.03499).\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. 
`\"channels_last\"`\n            corresponds to inputs with shape `(batch, steps, features)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, steps)`. It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of 1 integers, specifying the dilation\n            rate to use for dilated convolution.\n        groups: A positive int specifying the number of groups in which the\n            input is split along the channel axis. Each group is convolved\n            separately with `filters // groups` filters. The output is the\n            concatenation of all the `groups` results along the channel axis.\n            Input channels and `filters` must both be divisible by `groups`.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        kernel_initializer: Initializer for the convolution kernel. If `None`,\n            the default initializer (`\"glorot_uniform\"`) will be used.\n        bias_initializer: Initializer for the bias vector. If `None`, the\n            default initializer (`\"zeros\"`) will be used.\n        kernel_regularizer: Optional regularizer for the convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        kernel_constraint: Optional projection function to be applied to the\n            kernel after being updated by an `Optimizer` (e.g. used to implement\n            norm constraints or value constraints for layer weights). The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape). 
Constraints\n            are not safe to use when doing asynchronous distributed training.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 3D tensor with shape: `(batch_shape, steps, channels)`\n    - If `data_format=\"channels_first\"`:\n        A 3D tensor with shape: `(batch_shape, channels, steps)`\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 3D tensor with shape: `(batch_shape, new_steps, filters)`\n    - If `data_format=\"channels_first\"`:\n        A 3D tensor with shape: `(batch_shape, filters, new_steps)`\n\n    Returns:\n        A 3D tensor representing `activation(conv1d(inputs, kernel) + bias)`.\n\n    Raises:\n        ValueError: when both `strides > 1` and `dilation_rate > 1`.\n\n    Example:\n\n    >>> # The inputs are 128-length vectors with 10 timesteps, and the\n    >>> # batch size is 4.\n    >>> x = np.random.rand(4, 10, 128)\n    >>> y = keras.layers.Conv1D(32, 3, activation='relu')(x)\n    >>> print(y.shape)\n    (4, 8, 32)\n    \"\"\"\n\n    def __init__(\n        self,\n        filters,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        groups=1,\n        activation=None,\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        bias_constraint=None,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=1,\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            groups=groups,\n            
activation=activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            bias_initializer=bias_initializer,\n            kernel_regularizer=kernel_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            kernel_constraint=kernel_constraint,\n            bias_constraint=bias_constraint,\n            **kwargs,\n        )\n\n    def _compute_causal_padding(self):\n        left_pad = self.dilation_rate[0] * (self.kernel_size[0] - 1)\n        if self.data_format == \"channels_last\":\n            causal_padding = [[0, 0], [left_pad, 0], [0, 0]]\n        else:\n            causal_padding = [[0, 0], [0, 0], [left_pad, 0]]\n        return causal_padding\n\n    def call(self, inputs):\n        padding = self.padding\n        if self.padding == \"causal\":\n            # Apply causal padding to inputs.\n            inputs = ops.pad(inputs, self._compute_causal_padding())\n            padding = \"valid\"\n\n        outputs = ops.conv(\n            inputs,\n            self.kernel,\n            strides=list(self.strides),\n            padding=padding,\n            dilation_rate=self.dilation_rate,\n            data_format=self.data_format,\n        )\n\n        if self.use_bias:\n            if self.data_format == \"channels_last\":\n                bias_shape = (1,) * (self.rank + 1) + (self.filters,)\n            else:\n                bias_shape = (1, self.filters) + (1,) * self.rank\n            bias = ops.reshape(self.bias, bias_shape)\n            outputs = ops.add(outputs, bias)\n\n        if self.activation is not None:\n            return self.activation(outputs)\n        return outputs\n"
  },
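The `_compute_causal_padding` and `call` logic of `Conv1D` above can be checked with a small pure-Python sketch of the per-axis output-length arithmetic. The helper names (`causal_left_pad`, `conv1d_output_steps`) are illustrative, not part of the Keras API:

```python
def causal_left_pad(kernel_size: int, dilation_rate: int = 1) -> int:
    # Mirrors Conv1D._compute_causal_padding: pad only on the left, so that
    # output[t] never depends on input[t+1:].
    return dilation_rate * (kernel_size - 1)


def conv1d_output_steps(steps, kernel_size, stride=1, dilation_rate=1,
                        padding="valid"):
    # Per-axis output length for the three padding modes Conv1D accepts.
    if padding == "causal":
        # Causal padding turns the op into a "valid" conv on a longer input.
        steps += causal_left_pad(kernel_size, dilation_rate)
        padding = "valid"
    if padding == "valid":
        effective_k = dilation_rate * (kernel_size - 1) + 1
        return (steps - effective_k) // stride + 1
    if padding == "same":
        return -(-steps // stride)  # ceil(steps / stride)
    raise ValueError(f"Unknown padding: {padding!r}")


# Matches the docstring example: Conv1D(32, 3) on (4, 10, 128) -> (4, 8, 32).
print(conv1d_output_steps(10, kernel_size=3))  # 8
# Causal padding with stride 1 preserves the number of steps.
print(conv1d_output_steps(10, kernel_size=3, padding="causal"))  # 10
```

Note how `"causal"` reduces to a `"valid"` convolution on a left-padded input, exactly as `Conv1D.call` does before delegating to `ops.conv`.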
  {
    "path": "keras/src/layers/convolutional/conv1d_transpose.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.convolutional.base_conv_transpose import BaseConvTranspose\n\n\n@keras_export(\n    [\n        \"keras.layers.Conv1DTranspose\",\n        \"keras.layers.Convolution1DTranspose\",\n    ]\n)\nclass Conv1DTranspose(BaseConvTranspose):\n    \"\"\"1D transposed convolution layer.\n\n    The need for transposed convolutions generally arises from the desire to\n    use a transformation going in the opposite direction of a normal\n    convolution, i.e., from something that has the shape of the output of some\n    convolution to something that has the shape of its input while maintaining\n    a connectivity pattern that is compatible with said convolution.\n\n    Args:\n        filters: int, the dimension of the output space (the number of filters\n            in the transpose convolution).\n        kernel_size: int or tuple/list of 1 integer, specifying the size of the\n            transposed convolution window.\n        strides: int or tuple/list of 1 integer, specifying the stride length\n            of the transposed convolution. `strides > 1` is incompatible with\n            `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input such that the output has the\n            same height/width dimension as the input.\n        output_padding: An integer tuple/list of 1 integer specifying the\n            amount of padding along the time dimension of the output tensor.\n            The amount of output padding must be lower than the stride.\n            If set to `None` (default), the output shape is inferred.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. 
`\"channels_last\"`\n            corresponds to inputs with shape `(batch, steps, features)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, steps)`. It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: An integer tuple/list of 1 integer, specifying\n            the dilation rate to use for dilated convolution.\n            Currently, specifying a `dilation_rate` value != 1 is\n            incompatible with specifying a stride value != 1.\n            Also dilation rate larger than 1 is not currently supported.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        kernel_initializer: Initializer for the convolution kernel. If `None`,\n            the default initializer (`\"glorot_uniform\"`) will be used.\n        bias_initializer: Initializer for the bias vector. If `None`, the\n            default initializer (`\"zeros\"`) will be used.\n        kernel_regularizer: Optional regularizer for the convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        kernel_constraint: Optional projection function to be applied to the\n            kernel after being updated by an `Optimizer` (e.g. used to implement\n            norm constraints or value constraints for layer weights). The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape). 
Constraints\n            are not safe to use when doing asynchronous distributed training.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 3D tensor with shape: `(batch_shape, steps, channels)`\n    - If `data_format=\"channels_first\"`:\n        A 3D tensor with shape: `(batch_shape, channels, steps)`\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 3D tensor with shape: `(batch_shape, new_steps, filters)`\n    - If `data_format=\"channels_first\"`:\n        A 3D tensor with shape: `(batch_shape, filters, new_steps)`\n\n    Returns:\n        A 3D tensor representing\n        `activation(conv1d_transpose(inputs, kernel) + bias)`.\n\n    Raises:\n        ValueError: when both `strides > 1` and `dilation_rate > 1`.\n\n    References:\n    - [A guide to convolution arithmetic for deep learning](\n        https://arxiv.org/abs/1603.07285v1)\n    - [Deconvolutional Networks](\n        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf)\n\n    Example:\n\n    >>> x = np.random.rand(4, 10, 128)\n    >>> y = keras.layers.Conv1DTranspose(32, 3, 2, activation='relu')(x)\n    >>> print(y.shape)\n    (4, 21, 32)\n    \"\"\"\n\n    def __init__(\n        self,\n        filters,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        output_padding=None,\n        data_format=None,\n        dilation_rate=1,\n        activation=None,\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        bias_constraint=None,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=1,\n            filters=filters,\n            kernel_size=kernel_size,\n           
 strides=strides,\n            padding=padding,\n            output_padding=output_padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            activation=activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            bias_initializer=bias_initializer,\n            kernel_regularizer=kernel_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            kernel_constraint=kernel_constraint,\n            bias_constraint=bias_constraint,\n            **kwargs,\n        )\n"
  },
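The `Conv1DTranspose` docstring example (10 steps, `kernel_size=3`, `strides=2`, giving 21 steps) follows the standard transposed-convolution length arithmetic. A minimal sketch, assuming `"valid"`/`"same"` padding only; `conv_transpose_length` is an illustrative helper, not a Keras function:

```python
def conv_transpose_length(length, kernel_size, stride=1, padding="valid",
                          output_padding=None):
    # Output length along one axis of a transposed convolution (standard
    # arithmetic, cf. Dumoulin & Visin, arXiv:1603.07285).
    # output_padding=None means "infer the shape", mirroring the layer default.
    if output_padding is None:
        if padding == "valid":
            return length * stride + max(kernel_size - stride, 0)
        if padding == "same":
            return length * stride
        raise ValueError(f"Unknown padding: {padding!r}")
    # Explicit output_padding: add it after undoing the forward-conv geometry.
    pad = 0 if padding == "valid" else kernel_size // 2
    return (length - 1) * stride + kernel_size - 2 * pad + output_padding


# Matches the docstring example: Conv1DTranspose(32, 3, 2) on (4, 10, 128)
# yields (4, 21, 32), since 10 steps upsample to (10 - 1) * 2 + 3 = 21.
print(conv_transpose_length(10, kernel_size=3, stride=2))  # 21
```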
  {
    "path": "keras/src/layers/convolutional/conv2d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.convolutional.base_conv import BaseConv\n\n\n@keras_export([\"keras.layers.Conv2D\", \"keras.layers.Convolution2D\"])\nclass Conv2D(BaseConv):\n    \"\"\"2D convolution layer.\n\n    This layer creates a convolution kernel that is convolved with the layer\n    input over 2D spatial dimensions (height and width) to\n    produce a tensor of outputs. If `use_bias` is True, a bias vector is created\n    and added to the outputs. Finally, if `activation` is not `None`, it is\n    applied to the outputs as well.\n\n    Note on numerical precision: While in general Keras operation execution\n    results are identical across backends up to 1e-7 precision in float32,\n    `Conv2D` operations may show larger variations. Due to the large\n    number of element-wise multiplications and additions in convolution\n    operations, especially with large inputs or kernel sizes, accumulated\n    floating-point differences can exceed this 1e-7 threshold. These variations\n    are particularly noticeable when using different backends (e.g., TensorFlow\n    vs JAX) or different hardware.\n\n    Args:\n        filters: int, the dimension of the output space (the number of filters\n            in the convolution).\n        kernel_size: int or tuple/list of 2 integers, specifying the size of the\n            convolution window.\n        strides: int or tuple/list of 2 integers, specifying the stride length\n            of the convolution. `strides > 1` is incompatible with\n            `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input. 
When `padding=\"same\"` and\n            `strides=1`, the output has the same size as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape\n            `(batch_size, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch_size, channels, height, width)`. It defaults to the\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. If you never set it, then it will be\n            `\"channels_last\"`.\n        dilation_rate: int or tuple/list of 2 integers, specifying the dilation\n            rate to use for dilated convolution.\n        groups: A positive int specifying the number of groups in which the\n            input is split along the channel axis. Each group is convolved\n            separately with `filters // groups` filters. The output is the\n            concatenation of all the `groups` results along the channel axis.\n            Input channels and `filters` must both be divisible by `groups`.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        kernel_initializer: Initializer for the convolution kernel. If `None`,\n            the default initializer (`\"glorot_uniform\"`) will be used.\n        bias_initializer: Initializer for the bias vector. 
If `None`, the\n            default initializer (`\"zeros\"`) will be used.\n        kernel_regularizer: Optional regularizer for the convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        kernel_constraint: Optional projection function to be applied to the\n            kernel after being updated by an `Optimizer` (e.g. used to implement\n            norm constraints or value constraints for layer weights). The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape). Constraints\n            are not safe to use when doing asynchronous distributed training.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 4D tensor with shape: `(batch_size, height, width, channels)`\n    - If `data_format=\"channels_first\"`:\n        A 4D tensor with shape: `(batch_size, channels, height, width)`\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 4D tensor with shape: `(batch_size, new_height, new_width, filters)`\n    - If `data_format=\"channels_first\"`:\n        A 4D tensor with shape: `(batch_size, filters, new_height, new_width)`\n\n    Returns:\n        A 4D tensor representing `activation(conv2d(inputs, kernel) + bias)`.\n\n    Raises:\n        ValueError: when both `strides > 1` and `dilation_rate > 1`.\n\n    Example:\n\n    >>> x = np.random.rand(4, 10, 10, 128)\n    >>> y = keras.layers.Conv2D(32, 3, activation='relu')(x)\n    >>> print(y.shape)\n    (4, 8, 8, 32)\n    \"\"\"\n\n    def __init__(\n        self,\n        filters,\n        kernel_size,\n        strides=(1, 1),\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=(1, 1),\n        
groups=1,\n        activation=None,\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        bias_constraint=None,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=2,\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            groups=groups,\n            activation=activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            bias_initializer=bias_initializer,\n            kernel_regularizer=kernel_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            kernel_constraint=kernel_constraint,\n            bias_constraint=bias_constraint,\n            **kwargs,\n        )\n"
  },
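The `groups` argument described in the `Conv2D` docstring above changes the kernel's shape: each group sees only `in_channels // groups` input channels. A small sketch of the resulting channels-last kernel shape and the divisibility constraint; `conv2d_kernel_shape` is an illustrative helper, not part of the API:

```python
def conv2d_kernel_shape(kernel_size, in_channels, filters, groups=1):
    # channels_last kernel shape: (kh, kw, in_channels // groups, filters).
    # With groups > 1, each group convolves only its slice of the input, so
    # the kernel's input-channel axis shrinks by a factor of `groups`.
    if in_channels % groups or filters % groups:
        raise ValueError(
            "Input channels and `filters` must both be divisible by `groups`."
        )
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    kh, kw = kernel_size
    return (kh, kw, in_channels // groups, filters)


print(conv2d_kernel_shape(3, 128, 32))            # (3, 3, 128, 32)
print(conv2d_kernel_shape(3, 128, 32, groups=4))  # (3, 3, 32, 32)
```

Grouped convolution thus cuts the kernel's parameter count by a factor of `groups` while keeping the same number of output filters.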
  {
    "path": "keras/src/layers/convolutional/conv2d_transpose.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.convolutional.base_conv_transpose import BaseConvTranspose\n\n\n@keras_export(\n    [\n        \"keras.layers.Conv2DTranspose\",\n        \"keras.layers.Convolution2DTranspose\",\n    ]\n)\nclass Conv2DTranspose(BaseConvTranspose):\n    \"\"\"2D transposed convolution layer.\n\n    The need for transposed convolutions generally arises from the desire to\n    use a transformation going in the opposite direction of a normal\n    convolution, i.e., from something that has the shape of the output of some\n    convolution to something that has the shape of its input while maintaining\n    a connectivity pattern that is compatible with said convolution.\n\n    Args:\n        filters: int, the dimension of the output space (the number of filters\n            in the transposed convolution).\n        kernel_size: int or tuple/list of 2 integers, specifying the size of the\n            transposed convolution window.\n        strides: int or tuple/list of 2 integers, specifying the stride length\n            of the transposed convolution. `strides > 1` is incompatible with\n            `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input. 
When `padding=\"same\"` and\n            `strides=1`, the output has the same size as the input.\n        output_padding: An integer or tuple/list of 2 integers,\n            specifying the amount of padding along the height and width\n            of the output tensor.\n            Can be a single integer to specify the same value for all\n            spatial dimensions.\n            The amount of output padding along a given dimension must be\n            lower than the stride along that same dimension.\n            If set to `None` (default), the output shape is inferred.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape\n            `(batch_size, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch_size, channels, height, width)`. It defaults to the\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. If you never set it, then it will be\n            `\"channels_last\"`.\n        dilation_rate: An integer or tuple/list of 2 integers,\n            specifying the dilation rate for\n            all spatial dimensions for dilated convolution.\n            Specifying different dilation rates\n            for different dimensions is not supported.\n            Currently, specifying any `dilation_rate` value != 1 is\n            incompatible with specifying any stride value != 1.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        kernel_initializer: Initializer for the convolution kernel. If `None`,\n            the default initializer (`\"glorot_uniform\"`) will be used.\n        bias_initializer: Initializer for the bias vector. 
If `None`, the\n            default initializer (`\"zeros\"`) will be used.\n        kernel_regularizer: Optional regularizer for the convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        kernel_constraint: Optional projection function to be applied to the\n            kernel after being updated by an `Optimizer` (e.g. used to implement\n            norm constraints or value constraints for layer weights). The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape). Constraints\n            are not safe to use when doing asynchronous distributed training.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 4D tensor with shape: `(batch_size, height, width, channels)`\n    - If `data_format=\"channels_first\"`:\n        A 4D tensor with shape: `(batch_size, channels, height, width)`\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 4D tensor with shape: `(batch_size, new_height, new_width, filters)`\n    - If `data_format=\"channels_first\"`:\n        A 4D tensor with shape: `(batch_size, filters, new_height, new_width)`\n\n    Returns:\n        A 4D tensor representing\n        `activation(conv2d_transpose(inputs, kernel) + bias)`.\n\n    Raises:\n        ValueError: when both `strides > 1` and `dilation_rate > 1`.\n\n    References:\n    - [A guide to convolution arithmetic for deep learning](\n        https://arxiv.org/abs/1603.07285v1)\n    - [Deconvolutional Networks](\n        https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf)\n\n    Example:\n\n    >>> x = np.random.rand(4, 10, 8, 128)\n    >>> y = keras.layers.Conv2DTranspose(32, 2, 2, 
activation='relu')(x)\n    >>> print(y.shape)\n    (4, 20, 16, 32)\n    \"\"\"\n\n    def __init__(\n        self,\n        filters,\n        kernel_size,\n        strides=(1, 1),\n        padding=\"valid\",\n        output_padding=None,\n        data_format=None,\n        dilation_rate=(1, 1),\n        activation=None,\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        bias_constraint=None,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=2,\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            output_padding=output_padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            activation=activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            bias_initializer=bias_initializer,\n            kernel_regularizer=kernel_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            kernel_constraint=kernel_constraint,\n            bias_constraint=bias_constraint,\n            **kwargs,\n        )\n"
  },
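The `output_padding` argument documented for `Conv2DTranspose` above can be made concrete with a short shape sketch for `"valid"` padding in channels-last layout. `conv2d_transpose_shape` is an illustrative helper, not a Keras function:

```python
def conv2d_transpose_shape(batch, height, width, filters, kernel_size,
                           stride, output_padding=None):
    # channels_last output shape for padding="valid". output_padding=None
    # means the shape is inferred; an explicit value is added per axis and
    # must be smaller than the stride.
    if output_padding is not None and not 0 <= output_padding < stride:
        raise ValueError("output_padding must be smaller than the stride.")

    def axis(length):
        if output_padding is None:
            return length * stride + max(kernel_size - stride, 0)
        return (length - 1) * stride + kernel_size + output_padding

    return (batch, axis(height), axis(width), filters)


# Matches the docstring example: Conv2DTranspose(32, 2, 2) on (4, 10, 8, 128).
print(conv2d_transpose_shape(4, 10, 8, 32, kernel_size=2, stride=2))
# (4, 20, 16, 32)
```

With `stride > 1` several input shapes map to the same convolution output shape, which is why an explicit `output_padding` is needed to disambiguate the inverse.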
  {
    "path": "keras/src/layers/convolutional/conv3d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.convolutional.base_conv import BaseConv\n\n\n@keras_export([\"keras.layers.Conv3D\", \"keras.layers.Convolution3D\"])\nclass Conv3D(BaseConv):\n    \"\"\"3D convolution layer.\n\n    This layer creates a convolution kernel that is convolved with the layer\n    input over 3D spatial dimensions (width, height and depth) to\n    produce a tensor of outputs. If `use_bias` is True, a bias vector is created\n    and added to the outputs. Finally, if `activation` is not `None`, it is\n    applied to the outputs as well.\n\n    Args:\n        filters: int, the dimension of the output space (the number of filters\n            in the convolution).\n        kernel_size: int or tuple/list of 3 integers, specifying the size of the\n            convolution window.\n        strides: int or tuple/list of 3 integers, specifying the stride length\n            of the convolution. `strides > 1` is incompatible with\n            `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input. When `padding=\"same\"` and\n            `strides=1`, the output has the same size as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape\n            `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.\n            It defaults to the `image_data_format` value found in your Keras\n            config file at `~/.keras/keras.json`. 
If you never set it, then it\n            will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of 3 integers, specifying the dilation\n            rate to use for dilated convolution.\n        groups: A positive int specifying the number of groups in which the\n            input is split along the channel axis. Each group is convolved\n            separately with `filters // groups` filters. The output is the\n            concatenation of all the `groups` results along the channel axis.\n            Input channels and `filters` must both be divisible by `groups`.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        kernel_initializer: Initializer for the convolution kernel. If `None`,\n            the default initializer (`\"glorot_uniform\"`) will be used.\n        bias_initializer: Initializer for the bias vector. If `None`, the\n            default initializer (`\"zeros\"`) will be used.\n        kernel_regularizer: Optional regularizer for the convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        kernel_constraint: Optional projection function to be applied to the\n            kernel after being updated by an `Optimizer` (e.g. used to implement\n            norm constraints or value constraints for layer weights). The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape). 
Constraints\n            are not safe to use when doing asynchronous distributed training.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        5D tensor with shape:\n        `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n    - If `data_format=\"channels_first\"`:\n        5D tensor with shape:\n        `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        5D tensor with shape:\n        `(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3,\n        filters)`\n    - If `data_format=\"channels_first\"`:\n        5D tensor with shape:\n        `(batch_size, filters, new_spatial_dim1, new_spatial_dim2,\n        new_spatial_dim3)`\n\n    Returns:\n        A 5D tensor representing `activation(conv3d(inputs, kernel) + bias)`.\n\n    Raises:\n        ValueError: when both `strides > 1` and `dilation_rate > 1`.\n\n    Example:\n\n    >>> x = np.random.rand(4, 10, 10, 10, 128)\n    >>> y = keras.layers.Conv3D(32, 3, activation='relu')(x)\n    >>> print(y.shape)\n    (4, 8, 8, 8, 32)\n    \"\"\"\n\n    def __init__(\n        self,\n        filters,\n        kernel_size,\n        strides=(1, 1, 1),\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=(1, 1, 1),\n        groups=1,\n        activation=None,\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        bias_constraint=None,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=3,\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n      
      data_format=data_format,\n            dilation_rate=dilation_rate,\n            groups=groups,\n            activation=activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            bias_initializer=bias_initializer,\n            kernel_regularizer=kernel_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            kernel_constraint=kernel_constraint,\n            bias_constraint=bias_constraint,\n            **kwargs,\n        )\n"
  },
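Because 3D kernels grow quickly with kernel size and channel count, it is worth seeing the parameter arithmetic implied by the `Conv3D` docstring above. A minimal sketch; `conv3d_param_count` is an illustrative helper, not part of the Keras API:

```python
def conv3d_param_count(kernel_size, in_channels, filters, groups=1,
                       use_bias=True):
    # Trainable-parameter count: one weight per kernel position in
    # (kd, kh, kw, in_channels // groups, filters), plus one bias per
    # filter when use_bias=True.
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size,) * 3
    kd, kh, kw = kernel_size
    params = kd * kh * kw * (in_channels // groups) * filters
    if use_bias:
        params += filters
    return params


# Conv3D(32, 3) on a 128-channel input: 3*3*3*128*32 weights + 32 biases.
print(conv3d_param_count(3, 128, 32))  # 110624
```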
  {
    "path": "keras/src/layers/convolutional/conv3d_transpose.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.convolutional.base_conv_transpose import BaseConvTranspose\n\n\n@keras_export(\n    [\n        \"keras.layers.Conv3DTranspose\",\n        \"keras.layers.Convolution3DTranspose\",\n    ]\n)\nclass Conv3DTranspose(BaseConvTranspose):\n    \"\"\"3D transposed convolution layer.\n\n    The need for transposed convolutions generally arises from the desire to\n    use a transformation going in the opposite direction of a normal\n    convolution, i.e., from something that has the shape of the output of some\n    convolution to something that has the shape of its input while maintaining\n    a connectivity pattern that is compatible with said convolution.\n\n    Args:\n        filters: int, the dimension of the output space (the number of filters\n            in the transposed convolution).\n        kernel_size: int or tuple/list of 3 integers, specifying the size of the\n            transposed convolution window.\n        strides: int or tuple/list of 3 integers, specifying the stride length\n            of the transposed convolution. `strides > 1` is incompatible with\n            `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input. 
When `padding=\"same\"` and\n            `strides=1`, the output has the same size as the input.\n         output_padding: An integer or tuple/list of 3 integers,\n            specifying the amount of padding along the depth, height, and\n            width.\n            Can be a single integer to specify the same value for all\n            spatial dimensions.\n            The amount of output padding along a given dimension must be\n            lower than the stride along that same dimension.\n            If set to `None` (default), the output shape is inferred.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape\n            `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.\n            It defaults to the `image_data_format` value found in your Keras\n            config file at `~/.keras/keras.json`. If you never set it, then it\n            will be `\"channels_last\"`.\n        dilation_rate: an integer or tuple/list of 3 integers, specifying\n            the dilation rate to use for dilated convolution.\n            Can be a single integer to specify the same value for\n            all spatial dimensions.\n            Currently, specifying any `dilation_rate` value != 1 is\n            incompatible with specifying any stride value != 1.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        kernel_initializer: Initializer for the convolution kernel. If `None`,\n            the default initializer (`\"glorot_uniform\"`) will be used.\n        bias_initializer: Initializer for the bias vector. 
If `None`, the\n            default initializer (`\"zeros\"`) will be used.\n        kernel_regularizer: Optional regularizer for the convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        kernel_constraint: Optional projection function to be applied to the\n            kernel after being updated by an `Optimizer` (e.g. used to implement\n            norm constraints or value constraints for layer weights). The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape). Constraints\n            are not safe to use when doing asynchronous distributed training.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        5D tensor with shape:\n        `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n    - If `data_format=\"channels_first\"`:\n        5D tensor with shape:\n        `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        5D tensor with shape:\n        `(batch_size, new_spatial_dim1, new_spatial_dim2, new_spatial_dim3,\n        filters)`\n    - If `data_format=\"channels_first\"`:\n        5D tensor with shape:\n        `(batch_size, filters, new_spatial_dim1, new_spatial_dim2,\n        new_spatial_dim3)`\n\n    Returns:\n        A 5D tensor representing `activation(conv3d(inputs, kernel) + bias)`.\n\n    Raises:\n        ValueError: when both `strides > 1` and `dilation_rate > 1`.\n\n    References:\n    - [A guide to convolution arithmetic for deep learning](\n        https://arxiv.org/abs/1603.07285v1)\n    - [Deconvolutional Networks](\n        
https://www.matthewzeiler.com/mattzeiler/deconvolutionalnetworks.pdf)\n\n    Example:\n\n    >>> x = np.random.rand(4, 10, 8, 12, 128)\n    >>> y = keras.layers.Conv3DTranspose(32, 2, 2, activation='relu')(x)\n    >>> print(y.shape)\n    (4, 20, 16, 24, 32)\n    \"\"\"\n\n    def __init__(\n        self,\n        filters,\n        kernel_size,\n        strides=(1, 1, 1),\n        padding=\"valid\",\n        data_format=None,\n        output_padding=None,\n        dilation_rate=(1, 1, 1),\n        activation=None,\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        bias_constraint=None,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=3,\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            output_padding=output_padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            activation=activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            bias_initializer=bias_initializer,\n            kernel_regularizer=kernel_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            kernel_constraint=kernel_constraint,\n            bias_constraint=bias_constraint,\n            **kwargs,\n        )\n"
  },
  {
    "path": "keras/src/layers/convolutional/conv_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom numpy.lib.stride_tricks import as_strided\n\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import saving\nfrom keras.src import testing\n\n\ndef _same_padding(input_size, kernel_size, stride):\n    if input_size % stride == 0:\n        padding = max(kernel_size - stride, 0)\n    else:\n        padding = max(kernel_size - (input_size % stride), 0)\n    return padding // 2, padding - padding // 2\n\n\ndef np_conv1d(\n    x,\n    kernel_weights,\n    bias_weights,\n    strides,\n    padding,\n    data_format,\n    dilation_rate,\n    groups,\n):\n    if data_format == \"channels_first\":\n        x = x.swapaxes(1, 2)\n    if isinstance(strides, (tuple, list)):\n        h_stride = strides[0]\n    else:\n        h_stride = strides\n    if isinstance(dilation_rate, (tuple, list)):\n        dilation_rate = dilation_rate[0]\n    kernel_size, ch_in, ch_out = kernel_weights.shape\n\n    if dilation_rate > 1:\n        new_kernel_size = kernel_size + (dilation_rate - 1) * (kernel_size - 1)\n        new_kernel_weights = np.zeros(\n            (new_kernel_size, ch_in, ch_out), dtype=kernel_weights.dtype\n        )\n        new_kernel_weights[::dilation_rate] = kernel_weights\n        kernel_weights = new_kernel_weights\n        kernel_size = kernel_weights.shape[0]\n\n    if padding != \"valid\":\n        n_batch, h_x, _ = x.shape\n        h_pad = _same_padding(h_x, kernel_size, h_stride)\n        npad = [(0, 0)] * x.ndim\n        if padding == \"causal\":\n            npad[1] = (h_pad[0] + h_pad[1], 0)\n        else:\n            npad[1] = h_pad\n        x = np.pad(x, pad_width=npad, mode=\"constant\", constant_values=0)\n\n    n_batch, h_x, _ = x.shape\n    h_out = int((h_x - kernel_size) / h_stride) + 1\n\n    kernel_weights = 
kernel_weights.reshape(-1, ch_out)\n    bias_weights = bias_weights.reshape(1, ch_out)\n\n    out_grps = []\n    for grp in range(1, groups + 1):\n        x_in = x[..., (grp - 1) * ch_in : grp * ch_in]\n        stride_shape = (n_batch, h_out, kernel_size, ch_in)\n        strides = (\n            x_in.strides[0],\n            h_stride * x_in.strides[1],\n            x_in.strides[1],\n            x_in.strides[2],\n        )\n        inner_dim = kernel_size * ch_in\n        x_strided = as_strided(\n            x_in, shape=stride_shape, strides=strides\n        ).reshape(n_batch, h_out, inner_dim)\n        ch_out_groups = ch_out // groups\n        kernel_weights_grp = kernel_weights[\n            ..., (grp - 1) * ch_out_groups : grp * ch_out_groups\n        ]\n        bias_weights_grp = bias_weights[\n            ..., (grp - 1) * ch_out_groups : grp * ch_out_groups\n        ]\n        out_grps.append(x_strided @ kernel_weights_grp + bias_weights_grp)\n    out = np.concatenate(out_grps, axis=-1)\n    if data_format == \"channels_first\":\n        out = out.swapaxes(1, 2)\n    return out\n\n\ndef np_conv2d(\n    x,\n    kernel_weights,\n    bias_weights,\n    strides,\n    padding,\n    data_format,\n    dilation_rate,\n    groups,\n):\n    if data_format == \"channels_first\":\n        x = x.transpose((0, 2, 3, 1))\n    if isinstance(strides, (tuple, list)):\n        h_stride, w_stride = strides\n    else:\n        h_stride = strides\n        w_stride = strides\n    if isinstance(dilation_rate, (tuple, list)):\n        h_dilation, w_dilation = dilation_rate\n    else:\n        h_dilation = dilation_rate\n        w_dilation = dilation_rate\n    h_kernel, w_kernel, ch_in, ch_out = kernel_weights.shape\n\n    if h_dilation > 1 or w_dilation > 1:\n        new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1)\n        new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1)\n        new_kernel_size_tuple = (new_h_kernel, new_w_kernel)\n        new_kernel_weights = 
np.zeros(\n            (*new_kernel_size_tuple, ch_in, ch_out),\n            dtype=kernel_weights.dtype,\n        )\n        new_kernel_weights[::h_dilation, ::w_dilation] = kernel_weights\n        kernel_weights = new_kernel_weights\n        h_kernel, w_kernel = kernel_weights.shape[:2]\n\n    if padding == \"same\":\n        n_batch, h_x, w_x, _ = x.shape\n        h_pad = _same_padding(h_x, h_kernel, h_stride)\n        w_pad = _same_padding(w_x, w_kernel, w_stride)\n        npad = [(0, 0)] * x.ndim\n        npad[1] = h_pad\n        npad[2] = w_pad\n        x = np.pad(x, pad_width=npad, mode=\"constant\", constant_values=0)\n\n    n_batch, h_x, w_x, _ = x.shape\n    h_out = int((h_x - h_kernel) / h_stride) + 1\n    w_out = int((w_x - w_kernel) / w_stride) + 1\n\n    out_grps = []\n    for grp in range(1, groups + 1):\n        x_in = x[..., (grp - 1) * ch_in : grp * ch_in]\n        stride_shape = (n_batch, h_out, w_out, h_kernel, w_kernel, ch_in)\n        strides = (\n            x_in.strides[0],\n            h_stride * x_in.strides[1],\n            w_stride * x_in.strides[2],\n            x_in.strides[1],\n            x_in.strides[2],\n            x_in.strides[3],\n        )\n        inner_dim = h_kernel * w_kernel * ch_in\n        x_strided = as_strided(\n            x_in, shape=stride_shape, strides=strides\n        ).reshape(-1, inner_dim)\n        ch_out_groups = ch_out // groups\n        kernel_weights_grp = kernel_weights[\n            ..., (grp - 1) * ch_out_groups : grp * ch_out_groups\n        ].reshape(-1, ch_out_groups)\n        bias_weights_grp = bias_weights[\n            ..., (grp - 1) * ch_out_groups : grp * ch_out_groups\n        ]\n        out_grps.append(x_strided @ kernel_weights_grp + bias_weights_grp)\n    out = np.concatenate(out_grps, axis=-1).reshape(\n        n_batch, h_out, w_out, ch_out\n    )\n    if data_format == \"channels_first\":\n        out = out.transpose((0, 3, 1, 2))\n    return out\n\n\ndef np_conv3d(\n    x,\n    
kernel_weights,\n    bias_weights,\n    strides,\n    padding,\n    data_format,\n    dilation_rate,\n    groups,\n):\n    if data_format == \"channels_first\":\n        x = x.transpose((0, 2, 3, 4, 1))\n    if isinstance(strides, (tuple, list)):\n        h_stride, w_stride, d_stride = strides\n    else:\n        h_stride = strides\n        w_stride = strides\n        d_stride = strides\n    if isinstance(dilation_rate, (tuple, list)):\n        h_dilation, w_dilation, d_dilation = dilation_rate\n    else:\n        h_dilation = dilation_rate\n        w_dilation = dilation_rate\n        d_dilation = dilation_rate\n\n    h_kernel, w_kernel, d_kernel, ch_in, ch_out = kernel_weights.shape\n\n    if h_dilation > 1 or w_dilation > 1 or d_dilation > 1:\n        new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1)\n        new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1)\n        new_d_kernel = d_kernel + (d_dilation - 1) * (d_kernel - 1)\n        new_kernel_size_tuple = (new_h_kernel, new_w_kernel, new_d_kernel)\n        new_kernel_weights = np.zeros(\n            (*new_kernel_size_tuple, ch_in, ch_out),\n            dtype=kernel_weights.dtype,\n        )\n        new_kernel_weights[::h_dilation, ::w_dilation, ::d_dilation] = (\n            kernel_weights\n        )\n        kernel_weights = new_kernel_weights\n        h_kernel, w_kernel, d_kernel = kernel_weights.shape[:3]\n\n    if padding == \"same\":\n        n_batch, h_x, w_x, d_x, _ = x.shape\n        h_pad = _same_padding(h_x, h_kernel, h_stride)\n        w_pad = _same_padding(w_x, w_kernel, w_stride)\n        d_pad = _same_padding(d_x, d_kernel, d_stride)\n        npad = [(0, 0)] * x.ndim\n        npad[1] = h_pad\n        npad[2] = w_pad\n        npad[3] = d_pad\n        x = np.pad(x, pad_width=npad, mode=\"constant\", constant_values=0)\n\n    n_batch, h_x, w_x, d_x, _ = x.shape\n    h_out = int((h_x - h_kernel) / h_stride) + 1\n    w_out = int((w_x - w_kernel) / w_stride) + 1\n    d_out = int((d_x 
- d_kernel) / d_stride) + 1\n\n    out_grps = []\n    for grp in range(1, groups + 1):\n        x_in = x[..., (grp - 1) * ch_in : grp * ch_in]\n        stride_shape = (\n            n_batch,\n            h_out,\n            w_out,\n            d_out,\n            h_kernel,\n            w_kernel,\n            d_kernel,\n            ch_in,\n        )\n        strides = (\n            x_in.strides[0],\n            h_stride * x_in.strides[1],\n            w_stride * x_in.strides[2],\n            d_stride * x_in.strides[3],\n            x_in.strides[1],\n            x_in.strides[2],\n            x_in.strides[3],\n            x_in.strides[4],\n        )\n        inner_dim = h_kernel * w_kernel * d_kernel * ch_in\n        x_strided = as_strided(\n            x_in, shape=stride_shape, strides=strides\n        ).reshape(-1, inner_dim)\n        ch_out_groups = ch_out // groups\n        kernel_weights_grp = kernel_weights[\n            ..., (grp - 1) * ch_out_groups : grp * ch_out_groups\n        ].reshape(-1, ch_out_groups)\n        bias_weights_grp = bias_weights[\n            ..., (grp - 1) * ch_out_groups : grp * ch_out_groups\n        ]\n        out_grps.append(x_strided @ kernel_weights_grp + bias_weights_grp)\n    out = np.concatenate(out_grps, axis=-1).reshape(\n        n_batch, h_out, w_out, d_out, ch_out\n    )\n    if data_format == \"channels_first\":\n        out = out.transpose((0, 4, 1, 2, 3))\n    return out\n\n\nclass ConvBasicTest(testing.TestCase):\n    @parameterized.parameters(\n        {\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 1,\n            \"input_shape\": (3, 5, 4),\n            \"output_shape\": (3, 4, 5),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n  
          \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2,),\n            \"groups\": 2,\n            \"input_shape\": (3, 4, 4),\n            \"output_shape\": (3, 4, 6),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"causal\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2,),\n            \"groups\": 2,\n            \"input_shape\": (3, 4, 4),\n            \"output_shape\": (3, 4, 6),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": (2,),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 2,\n            \"input_shape\": (3, 5, 4),\n            \"output_shape\": (3, 2, 6),\n        },\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_conv1d_basic(\n        self,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n        groups,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.Conv1D,\n            init_kwargs={\n                \"filters\": filters,\n                \"kernel_size\": kernel_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n                \"dilation_rate\": dilation_rate,\n                \"groups\": groups,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    @parameterized.parameters(\n        {\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n 
           \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 1,\n            \"input_shape\": (3, 5, 5, 4),\n            \"output_shape\": (3, 4, 4, 5),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2, 2),\n            \"groups\": 2,\n            \"input_shape\": (3, 4, 4, 4),\n            \"output_shape\": (3, 4, 4, 6),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (2, 2),\n            \"strides\": (2, 1),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1),\n            \"groups\": 2,\n            \"input_shape\": (3, 5, 5, 4),\n            \"output_shape\": (3, 2, 4, 6),\n        },\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_conv2d_basic(\n        self,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n        groups,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.Conv2D,\n            init_kwargs={\n                \"filters\": filters,\n                \"kernel_size\": kernel_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n                \"dilation_rate\": dilation_rate,\n                \"groups\": groups,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    @parameterized.parameters(\n        {\n            \"filters\": 
5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 1,\n            \"input_shape\": (3, 5, 5, 5, 4),\n            \"output_shape\": (3, 4, 4, 4, 5),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2, 2, 2),\n            \"groups\": 2,\n            \"input_shape\": (3, 4, 4, 4, 4),\n            \"output_shape\": (3, 4, 4, 4, 6),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (2, 2, 3),\n            \"strides\": (2, 1, 2),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1, 1),\n            \"groups\": 2,\n            \"input_shape\": (3, 5, 5, 5, 4),\n            \"output_shape\": (3, 2, 4, 2, 6),\n        },\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_conv3d_basic(\n        self,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n        groups,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.Conv3D,\n            init_kwargs={\n                \"filters\": filters,\n                \"kernel_size\": kernel_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n                \"dilation_rate\": dilation_rate,\n                \"groups\": groups,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            
supports_masking=False,\n        )\n\n    def test_bad_init_args(self):\n        # `filters` is not positive.\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value for argument `filters`. Expected a \"\n            \"strictly positive value. Received filters=0.\",\n        ):\n            layers.Conv1D(filters=0, kernel_size=1)\n\n        # `kernel_size` has 0.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"The `kernel_size` argument must be a tuple of \\d+ \"\n            r\"integers. Received kernel_size=\\(1, 0\\), including values \\{0\\} \"\n            r\"that do not satisfy `value > 0`\",\n        ):\n            layers.Conv2D(filters=2, kernel_size=(1, 0))\n\n        # `strides` has 0.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"The `strides` argument must be a tuple of \\d+ \"\n            r\"integers. Received strides=\\(1, 0\\), including values \\{0\\} that \"\n            r\"do not satisfy `value > 0`\",\n        ):\n            layers.Conv2D(filters=2, kernel_size=(2, 2), strides=(1, 0))\n\n        # `dilation_rate > 1` while `strides > 1`.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"`strides > 1` not supported in conjunction with \"\n            r\"`dilation_rate > 1`. Received: strides=\\(2, 2\\) and \"\n            r\"dilation_rate=\\(2, 1\\)\",\n        ):\n            layers.Conv2D(\n                filters=2, kernel_size=(2, 2), strides=2, dilation_rate=(2, 1)\n            )\n\n        # `groups` is not strictly positive.\n        with self.assertRaisesRegex(\n            ValueError,\n            \"The number of groups must be a positive integer. 
\"\n            \"Received: groups=0.\",\n        ):\n            layers.Conv2D(filters=5, kernel_size=(2, 2), groups=0)\n\n        # `filters` cannot be divided by `groups`.\n        with self.assertRaisesRegex(\n            ValueError,\n            \"The number of filters must be evenly divisible by the\"\n            \" number of groups. Received: groups=2, filters=5.\",\n        ):\n            layers.Conv2D(filters=5, kernel_size=(2, 2), groups=2)\n\n    @parameterized.named_parameters(\n        {\n            \"testcase_name\": \"conv1d_kernel_size3_strides1\",\n            \"conv_cls\": layers.Conv1D,\n            \"filters\": 6,\n            \"kernel_size\": 3,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 1,\n            \"input_shape\": (None, 5, 4),\n            \"output_shape\": (None, 3, 6),\n        },\n        {\n            \"testcase_name\": \"conv1d_kernel_size2_strides2\",\n            \"conv_cls\": layers.Conv1D,\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 2,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 2,\n            \"input_shape\": (None, 5, 4),\n            \"output_shape\": (None, 2, 6),\n        },\n        {\n            \"testcase_name\": \"conv2d_kernel_size3_strides1\",\n            \"conv_cls\": layers.Conv2D,\n            \"filters\": 6,\n            \"kernel_size\": 3,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 1,\n            \"input_shape\": (None, 5, 5, 4),\n            \"output_shape\": (None, 3, 3, 6),\n        },\n        {\n            \"testcase_name\": \"conv2d_kernel_size2_strides2\",\n            \"conv_cls\": 
layers.Conv2D,\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 2,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 2,\n            \"input_shape\": (None, 5, 5, 4),\n            \"output_shape\": (None, 2, 2, 6),\n        },\n        {\n            \"testcase_name\": \"conv3d_kernel_size3_strides1\",\n            \"conv_cls\": layers.Conv3D,\n            \"filters\": 6,\n            \"kernel_size\": 3,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 1,\n            \"input_shape\": (None, 5, 5, 5, 4),\n            \"output_shape\": (None, 3, 3, 3, 6),\n        },\n        {\n            \"testcase_name\": \"conv3d_kernel_size2_strides2\",\n            \"conv_cls\": layers.Conv3D,\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 2,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 2,\n            \"input_shape\": (None, 5, 5, 5, 4),\n            \"output_shape\": (None, 2, 2, 2, 6),\n        },\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_enable_lora(\n        self,\n        conv_cls,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n        groups,\n        input_shape,\n        output_shape,\n    ):\n        if conv_cls not in (layers.Conv1D, layers.Conv2D, layers.Conv3D):\n            raise TypeError\n        layer = conv_cls(\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            groups=groups,\n        )\n        
layer.build(input_shape)\n        layer.enable_lora(2)\n        self.assertLen(layer.trainable_weights, 3)\n        self.assertLen(layer.non_trainable_weights, 1)\n        if backend.backend() == \"torch\":\n            self.assertLen(layer.torch_params, 4)\n        # Try eager call\n        x = np.random.random((64,) + input_shape[1:])\n        y = np.random.random((64,) + output_shape[1:])\n        _ = layer(x[:2])\n\n        init_lora_a_kernel_value = layer.lora_kernel_a.numpy()\n        init_lora_b_kernel_value = layer.lora_kernel_b.numpy()\n\n        # Try calling fit()\n        model = models.Sequential([layer])\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        model.fit(x, y)\n\n        final_lora_a_kernel_value = layer.lora_kernel_a.numpy()\n        final_lora_b_kernel_value = layer.lora_kernel_b.numpy()\n        diff_a = np.max(\n            np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value)\n        )\n        diff_b = np.max(\n            np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value)\n        )\n        self.assertGreater(diff_a, 0.0)\n        self.assertGreater(diff_b, 0.0)\n\n        # Try saving and reloading the model\n        temp_filepath = os.path.join(self.get_temp_dir(), \"lora_model.keras\")\n        model.save(temp_filepath)\n\n        new_model = saving.load_model(temp_filepath)\n        self.assertTrue(new_model.layers[0].lora_enabled)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n        # Try saving and reloading the model's weights only\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"lora_model.weights.h5\"\n        )\n        model.save_weights(temp_filepath)\n\n        # Load the file into a fresh, non-lora model\n        new_model = models.Sequential(\n            [\n                conv_cls(\n                    filters=filters,\n                    kernel_size=kernel_size,\n                    strides=strides,\n                    
padding=padding,\n                    data_format=data_format,\n                    dilation_rate=dilation_rate,\n                    groups=groups,\n                )\n            ]\n        )\n        new_model.build(input_shape)\n        new_model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n        # Try loading a normal checkpoint into a lora model\n        new_model.save_weights(temp_filepath)\n        model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n    @pytest.mark.requires_trainable_backend\n    def test_lora_weight_name(self):\n        class MyModel(models.Model):\n            def __init__(self):\n                super().__init__(name=\"mymodel\")\n                self.conv2d = layers.Conv2D(4, 3, name=\"conv2d\")\n\n            def build(self, input_shape):\n                self.conv2d.build(input_shape)\n\n            def call(self, x):\n                return self.conv2d(x)\n\n        model = MyModel()\n        model.build((None, 5, 5, 4))\n        model.conv2d.enable_lora(2)\n        self.assertEqual(\n            model.conv2d.lora_kernel_a.path, \"mymodel/conv2d/lora_kernel_a\"\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_enable_lora_with_alpha(self):\n        # Create a `Conv2D` layer with a small kernel for simplicity.\n        layer = layers.Conv2D(filters=3, kernel_size=(2, 2), padding=\"valid\")\n        # Use a fixed input shape: batch size 1, height=4, width=4, channels=3.\n        input_shape = (1, 4, 4, 3)\n        layer.build(input_shape)\n\n        # Set the base kernel to known, deterministic values.\n        base_kernel = np.linspace(\n            0, 1, num=np.prod(layer.kernel.shape), dtype=np.float32\n        )\n        base_kernel = base_kernel.reshape(layer.kernel.shape)\n        layer.kernel.assign(base_kernel)\n\n        # Enable LoRA with `rank`=2 and a custom `lora_alpha` value (e.g. 
3.0).\n        layer.enable_lora(rank=2, lora_alpha=3.0)\n        self.assertEqual(layer.lora_rank, 2)\n        self.assertEqual(layer.lora_alpha, 3.0)\n\n        # For `Conv2D`, assume the LoRA weights have shapes:\n        #   `lora_kernel_a`: (kernel_height, kernel_width, in_channels, rank)\n        #   `lora_kernel_b`: (rank, out_channels)\n        lora_a_shape = layer.lora_kernel_a.shape\n        lora_b_shape = layer.lora_kernel_b.shape\n\n        # Assign known constant values to LoRA weights.\n        lora_a = np.full(lora_a_shape, 0.1, dtype=np.float32)\n        lora_b = np.full(lora_b_shape, 0.2, dtype=np.float32)\n        layer.lora_kernel_a.assign(lora_a)\n        layer.lora_kernel_b.assign(lora_b)\n\n        # Compute the expected delta.\n        # Flatten `lora_kernel_a` to shape (-1, `rank`),\n        # multiply with `lora_kernel_b`,\n        # then reshape to the kernel's shape.\n        scaling = 3.0 / 2  # `lora_alpha / lora_rank`\n        delta = np.matmul(lora_a.reshape(-1, 2), lora_b)\n        delta = delta.reshape(base_kernel.shape)\n        expected_effective_kernel = base_kernel + scaling * delta\n\n        # Compare the effective kernel computed via the property.\n        actual_effective_kernel = ops.convert_to_numpy(layer.kernel)\n        self.assertAllClose(\n            actual_effective_kernel,\n            expected_effective_kernel,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_lora_rank_argument(self):\n        self.run_layer_test(\n            layers.Conv2D,\n            init_kwargs={\n                \"filters\": 5,\n                \"kernel_size\": 3,\n                \"activation\": \"sigmoid\",\n                \"data_format\": \"channels_last\",\n                \"kernel_regularizer\": \"l2\",\n                \"lora_rank\": 2,\n            },\n            input_shape=(2, 5, 5, 4),\n            expected_output_shape=(2, 3, 3, 5),\n            
expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=1,\n            expected_num_seed_generators=0,\n            expected_num_losses=2,  # we have 2 regularizers.\n            supports_masking=False,\n        )\n\n    @parameterized.parameters(\n        (layers.Conv1D, (10, 3), 2, 10),\n        (layers.Conv2D, (10, 10, 3), (2, 2), (10, 10)),\n        (layers.Conv3D, (10, 10, 10, 3), (2, 2, 2), (10, 10, 10)),\n    )\n    def test_conv_symbolic_invalid_configuration(\n        self, layer_cls, input_shape, kernel_size, dilation_rate\n    ):\n        inputs = layers.Input(shape=input_shape)\n        layer = layer_cls(\n            filters=1,\n            kernel_size=kernel_size,\n            dilation_rate=dilation_rate,\n        )\n\n        with self.assertRaises(ValueError):\n            layer(inputs)\n\n\nclass ConvCorrectnessTest(testing.TestCase):\n    @parameterized.parameters(\n        {\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 1,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2,),\n            \"groups\": 2,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"causal\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2,),\n            \"groups\": 2,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (2,),\n            \"strides\": (2,),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 2,\n        },\n        {\n 
           \"filters\": 6,\n            \"kernel_size\": (2,),\n            \"strides\": (2,),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_first\",\n            \"dilation_rate\": 1,\n            \"groups\": 2,\n        },\n    )\n    def test_conv1d(\n        self,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n        groups,\n    ):\n        layer = layers.Conv1D(\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            groups=groups,\n        )\n\n        inputs = np.random.normal(size=[2, 8, 4])\n        layer.build(input_shape=inputs.shape)\n\n        kernel_shape = layer.kernel.shape\n        kernel_weights = np.random.normal(size=kernel_shape)\n        bias_weights = np.random.normal(size=(filters,))\n        layer.kernel.assign(kernel_weights)\n        layer.bias.assign(bias_weights)\n\n        outputs = layer(inputs)\n        expected = np_conv1d(\n            inputs,\n            kernel_weights,\n            bias_weights,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            groups=groups,\n        )\n        self.assertAllClose(outputs, expected, tpu_atol=1e-1, tpu_rtol=1e-1)\n\n    @parameterized.parameters(\n        {\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 1,\n        },\n        {\n            \"filters\": 4,\n            \"kernel_size\": 3,\n            \"strides\": 2,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n  
          \"groups\": 1,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2, 2),\n            \"groups\": 2,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2, 3),\n            \"groups\": 2,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (4, 3),\n            \"strides\": (2, 1),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1),\n            \"groups\": 2,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (4, 3),\n            \"strides\": (2, 1),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_first\",\n            \"dilation_rate\": (1, 1),\n            \"groups\": 2,\n        },\n    )\n    def test_conv2d(\n        self,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n        groups,\n    ):\n        layer = layers.Conv2D(\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            groups=groups,\n        )\n\n        inputs = np.random.normal(size=[2, 8, 8, 4])\n        layer.build(input_shape=inputs.shape)\n\n        kernel_shape = layer.kernel.shape\n        kernel_weights = np.random.normal(size=kernel_shape)\n        bias_weights = np.random.normal(size=(filters,))\n        layer.kernel.assign(kernel_weights)\n        layer.bias.assign(bias_weights)\n\n        outputs = layer(inputs)\n        
expected = np_conv2d(\n            inputs,\n            kernel_weights,\n            bias_weights,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            groups=groups,\n        )\n        self.assertAllClose(\n            outputs, expected, rtol=5e-4, tpu_atol=1e-1, tpu_rtol=1e-1\n        )\n\n    @parameterized.parameters(\n        {\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"groups\": 1,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2, 2, 2),\n            \"groups\": 2,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2, 3, 4),\n            \"groups\": 2,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (2, 2, 3),\n            \"strides\": (2, 1, 2),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1, 1),\n            \"groups\": 2,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (2, 2, 3),\n            \"strides\": (2, 1, 2),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_first\",\n            \"dilation_rate\": (1, 1, 1),\n            \"groups\": 2,\n        },\n    )\n    def test_conv3d(\n        self,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n        groups,\n    ):\n     
   layer = layers.Conv3D(\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            groups=groups,\n        )\n\n        inputs = np.random.normal(size=[2, 8, 8, 8, 4])\n        layer.build(input_shape=inputs.shape)\n\n        kernel_shape = layer.kernel.shape\n        kernel_weights = np.random.normal(size=kernel_shape)\n        bias_weights = np.random.normal(size=(filters,))\n        layer.kernel.assign(kernel_weights)\n        layer.bias.assign(bias_weights)\n\n        outputs = layer(inputs)\n        expected = np_conv3d(\n            inputs,\n            kernel_weights,\n            bias_weights,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            groups=groups,\n        )\n        self.assertAllClose(\n            outputs, expected, rtol=1e-3, tpu_atol=1e-1, tpu_rtol=1e-1\n        )\n\n    def test_conv_constraints(self):\n        layer = layers.Conv2D(\n            filters=4,\n            kernel_size=3,\n            kernel_constraint=\"non_neg\",\n        )\n        layer.build((None, 5, 5, 3))\n        self.assertIsInstance(layer.kernel.constraint, constraints.NonNeg)\n        layer = layers.Conv2D(\n            filters=4,\n            kernel_size=3,\n            bias_constraint=\"non_neg\",\n        )\n        layer.build((None, 5, 5, 3))\n        self.assertIsInstance(layer.bias.constraint, constraints.NonNeg)\n\n    def test_conv_raises_exception_on_zero_dims(self):\n        x = np.random.rand(3, 4, 4, 4)\n        l = layers.Conv2D(6, [5, 5], 1, \"valid\")\n        # The exception type can vary across backends (e.g., ValueError,\n        # tf.errors.InvalidArgumentError, RuntimeError).\n        with self.assertRaises(Exception):\n            l(x)\n"
  },
  {
    "path": "keras/src/layers/convolutional/conv_transpose_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.backend.common.backend_utils import (\n    _convert_conv_transpose_padding_args_from_keras_to_torch,\n)\nfrom keras.src.backend.common.backend_utils import (\n    compute_conv_transpose_output_shape,\n)\nfrom keras.src.backend.common.backend_utils import (\n    compute_conv_transpose_padding_args_for_jax,\n)\n\n\ndef np_conv1d_transpose(\n    x,\n    kernel_weights,\n    bias_weights,\n    strides,\n    padding,\n    output_padding,\n    data_format,\n    dilation_rate,\n):\n    if data_format == \"channels_first\":\n        x = x.transpose((0, 2, 1))\n    if isinstance(strides, (tuple, list)):\n        h_stride = strides[0]\n    else:\n        h_stride = strides\n    if isinstance(dilation_rate, (tuple, list)):\n        h_dilation = dilation_rate[0]\n    else:\n        h_dilation = dilation_rate\n\n    h_kernel, ch_out, ch_in = kernel_weights.shape\n    n_batch, h_x, _ = x.shape\n    # Get output shape and padding\n    _, h_out, _ = compute_conv_transpose_output_shape(\n        x.shape,\n        kernel_weights.shape,\n        ch_out,\n        strides,\n        padding,\n        output_padding,\n        \"channels_last\",\n        dilation_rate,\n    )\n    jax_padding = compute_conv_transpose_padding_args_for_jax(\n        input_shape=x.shape,\n        kernel_shape=kernel_weights.shape,\n        strides=strides,\n        padding=padding,\n        output_padding=output_padding,\n        dilation_rate=dilation_rate,\n    )\n    h_pad_side1 = h_kernel - 1 - jax_padding[0][0]\n\n    if h_dilation > 1:\n        # Increase kernel size\n        new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1)\n        new_kernel_size_tuple = (new_h_kernel,)\n        new_kernel_weights = np.zeros(\n            (*new_kernel_size_tuple, ch_out, ch_in),\n            
dtype=kernel_weights.dtype,\n        )\n        new_kernel_weights[::h_dilation] = kernel_weights\n        kernel_weights = new_kernel_weights\n        h_kernel = kernel_weights.shape[0]\n\n    # Compute output\n    output = np.zeros([n_batch, h_out + h_kernel, ch_out])\n    for nb in range(n_batch):\n        for h_x_idx in range(h_x):\n            h_out_idx = h_x_idx * h_stride  # Index in output\n            output[nb, h_out_idx : h_out_idx + h_kernel, :] += np.sum(\n                kernel_weights[:, :, :] * x[nb, h_x_idx, :], axis=-1\n            )\n    output = output + bias_weights\n\n    # Cut padding results from output\n    output = output[:, h_pad_side1 : h_out + h_pad_side1]\n    if data_format == \"channels_first\":\n        output = output.transpose((0, 2, 1))\n    return output\n\n\ndef np_conv2d_transpose(\n    x,\n    kernel_weights,\n    bias_weights,\n    strides,\n    padding,\n    output_padding,\n    data_format,\n    dilation_rate,\n):\n    if data_format == \"channels_first\":\n        x = x.transpose((0, 2, 3, 1))\n    if isinstance(strides, (tuple, list)):\n        h_stride, w_stride = strides\n    else:\n        h_stride = strides\n        w_stride = strides\n    if isinstance(dilation_rate, (tuple, list)):\n        h_dilation, w_dilation = dilation_rate\n    else:\n        h_dilation = dilation_rate\n        w_dilation = dilation_rate\n\n    h_kernel, w_kernel, ch_out, ch_in = kernel_weights.shape\n    n_batch, h_x, w_x, _ = x.shape\n    # Get output shape and padding\n    _, h_out, w_out, _ = compute_conv_transpose_output_shape(\n        x.shape,\n        kernel_weights.shape,\n        ch_out,\n        strides,\n        padding,\n        output_padding,\n        \"channels_last\",\n        dilation_rate,\n    )\n    jax_padding = compute_conv_transpose_padding_args_for_jax(\n        input_shape=x.shape,\n        kernel_shape=kernel_weights.shape,\n        strides=strides,\n        padding=padding,\n        output_padding=output_padding,\n 
       dilation_rate=dilation_rate,\n    )\n    h_pad_side1 = h_kernel - 1 - jax_padding[0][0]\n    w_pad_side1 = w_kernel - 1 - jax_padding[1][0]\n\n    if h_dilation > 1 or w_dilation > 1:\n        # Increase kernel size\n        new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1)\n        new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1)\n        new_kernel_size_tuple = (new_h_kernel, new_w_kernel)\n        new_kernel_weights = np.zeros(\n            (*new_kernel_size_tuple, ch_out, ch_in),\n            dtype=kernel_weights.dtype,\n        )\n        new_kernel_weights[::h_dilation, ::w_dilation] = kernel_weights\n        kernel_weights = new_kernel_weights\n        h_kernel, w_kernel = kernel_weights.shape[:2]\n\n    # Compute output\n    output = np.zeros([n_batch, h_out + h_kernel, w_out + w_kernel, ch_out])\n    for nb in range(n_batch):\n        for h_x_idx in range(h_x):\n            h_out_idx = h_x_idx * h_stride  # Index in output\n            for w_x_idx in range(w_x):\n                w_out_idx = w_x_idx * w_stride\n                output[\n                    nb,\n                    h_out_idx : h_out_idx + h_kernel,\n                    w_out_idx : w_out_idx + w_kernel,\n                    :,\n                ] += np.sum(\n                    kernel_weights[:, :, :, :] * x[nb, h_x_idx, w_x_idx, :],\n                    axis=-1,\n                )\n    output = output + bias_weights\n\n    # Cut padding results from output\n    output = output[\n        :,\n        h_pad_side1 : h_out + h_pad_side1,\n        w_pad_side1 : w_out + w_pad_side1,\n    ]\n    if data_format == \"channels_first\":\n        output = output.transpose((0, 3, 1, 2))\n    return output\n\n\ndef np_conv3d_transpose(\n    x,\n    kernel_weights,\n    bias_weights,\n    strides,\n    padding,\n    output_padding,\n    data_format,\n    dilation_rate,\n):\n    if data_format == \"channels_first\":\n        x = x.transpose((0, 2, 3, 4, 1))\n    if 
isinstance(strides, (tuple, list)):\n        h_stride, w_stride, d_stride = strides\n    else:\n        h_stride = strides\n        w_stride = strides\n        d_stride = strides\n    if isinstance(dilation_rate, (tuple, list)):\n        h_dilation, w_dilation, d_dilation = dilation_rate\n    else:\n        h_dilation = dilation_rate\n        w_dilation = dilation_rate\n        d_dilation = dilation_rate\n\n    h_kernel, w_kernel, d_kernel, ch_out, ch_in = kernel_weights.shape\n    n_batch, h_x, w_x, d_x, _ = x.shape\n    # Get output shape and padding\n    _, h_out, w_out, d_out, _ = compute_conv_transpose_output_shape(\n        x.shape,\n        kernel_weights.shape,\n        ch_out,\n        strides,\n        padding,\n        output_padding,\n        data_format,\n        dilation_rate,\n    )\n    jax_padding = compute_conv_transpose_padding_args_for_jax(\n        input_shape=x.shape,\n        kernel_shape=kernel_weights.shape,\n        strides=strides,\n        padding=padding,\n        output_padding=output_padding,\n        dilation_rate=dilation_rate,\n    )\n    h_pad_side1 = h_kernel - 1 - jax_padding[0][0]\n    w_pad_side1 = w_kernel - 1 - jax_padding[1][0]\n    d_pad_side1 = d_kernel - 1 - jax_padding[2][0]\n\n    if h_dilation > 1 or w_dilation > 1 or d_dilation > 1:\n        # Increase kernel size\n        new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1)\n        new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1)\n        new_d_kernel = d_kernel + (d_dilation - 1) * (d_kernel - 1)\n        new_kernel_size_tuple = (new_h_kernel, new_w_kernel, new_d_kernel)\n        new_kernel_weights = np.zeros(\n            (*new_kernel_size_tuple, ch_out, ch_in),\n            dtype=kernel_weights.dtype,\n        )\n        new_kernel_weights[::h_dilation, ::w_dilation, ::d_dilation] = (\n            kernel_weights\n        )\n        kernel_weights = new_kernel_weights\n        h_kernel, w_kernel, d_kernel = kernel_weights.shape[:3]\n\n    # 
Compute output\n    output = np.zeros(\n        [\n            n_batch,\n            h_out + h_kernel,\n            w_out + w_kernel,\n            d_out + d_kernel,\n            ch_out,\n        ]\n    )\n    for nb in range(n_batch):\n        for h_x_idx in range(h_x):\n            h_out_idx = h_x_idx * h_stride  # Index in output\n            for w_x_idx in range(w_x):\n                w_out_idx = w_x_idx * w_stride\n                for d_x_idx in range(d_x):\n                    d_out_idx = d_x_idx * d_stride\n                    output[\n                        nb,\n                        h_out_idx : h_out_idx + h_kernel,\n                        w_out_idx : w_out_idx + w_kernel,\n                        d_out_idx : d_out_idx + d_kernel,\n                        :,\n                    ] += np.sum(\n                        kernel_weights[:, :, :, :, :]\n                        * x[nb, h_x_idx, w_x_idx, d_x_idx, :],\n                        axis=-1,\n                    )\n    output = output + bias_weights\n\n    # Cut padding results from output\n    output = output[\n        :,\n        h_pad_side1 : h_out + h_pad_side1,\n        w_pad_side1 : w_out + w_pad_side1,\n        d_pad_side1 : d_out + d_pad_side1,\n    ]\n    if data_format == \"channels_first\":\n        output = output.transpose((0, 4, 1, 2, 3))\n    return output\n\n\nclass ConvTransposeBasicTest(testing.TestCase):\n    @parameterized.parameters(\n        {\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 2,\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"input_shape\": (2, 8, 4),\n            \"output_shape\": (2, 16, 5),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 3,\n            \"padding\": \"same\",\n            \"output_padding\": 2,\n            
\"data_format\": \"channels_last\",\n            \"dilation_rate\": (1,),\n            \"input_shape\": (2, 8, 4),\n            \"output_shape\": (2, 23, 6),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (2,),\n            \"strides\": (2,),\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"input_shape\": (2, 8, 4),\n            \"output_shape\": (2, 16, 6),\n        },\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_conv1d_transpose_basic(\n        self,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        output_padding,\n        data_format,\n        dilation_rate,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.Conv1DTranspose,\n            init_kwargs={\n                \"filters\": filters,\n                \"kernel_size\": kernel_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"output_padding\": output_padding,\n                \"data_format\": data_format,\n                \"dilation_rate\": dilation_rate,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    @parameterized.parameters(\n        {\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 2,\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"input_shape\": (2, 8, 8, 4),\n            \"output_shape\": (2, 16, 16, 5),\n        },\n        {\n            \"filters\": 6,\n            
\"kernel_size\": 2,\n            \"strides\": 3,\n            \"padding\": \"same\",\n            \"output_padding\": 2,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1),\n            \"input_shape\": (2, 8, 8, 4),\n            \"output_shape\": (2, 23, 23, 6),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (2, 3),\n            \"strides\": (2, 1),\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_first\",\n            \"dilation_rate\": (1, 1),\n            \"input_shape\": (2, 4, 8, 8),\n            \"output_shape\": (2, 6, 16, 10),\n        },\n        {\n            \"filters\": 2,\n            \"kernel_size\": (7, 7),\n            \"strides\": (16, 16),\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1),\n            \"input_shape\": (1, 14, 14, 2),\n            \"output_shape\": (1, 224, 224, 2),\n        },\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_conv2d_transpose_basic(\n        self,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        output_padding,\n        data_format,\n        dilation_rate,\n        input_shape,\n        output_shape,\n    ):\n        if (\n            data_format == \"channels_first\"\n            and backend.backend() == \"tensorflow\"\n        ):\n            pytest.skip(\"channels_first unsupported on CPU with TF\")\n\n        self.run_layer_test(\n            layers.Conv2DTranspose,\n            init_kwargs={\n                \"filters\": filters,\n                \"kernel_size\": kernel_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"output_padding\": output_padding,\n                \"data_format\": data_format,\n                \"dilation_rate\": dilation_rate,\n     
       },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    @parameterized.parameters(\n        {\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 2,\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"input_shape\": (2, 8, 8, 8, 4),\n            \"output_shape\": (2, 16, 16, 16, 5),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 3,\n            \"padding\": \"same\",\n            \"output_padding\": 2,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1, 1),\n            \"input_shape\": (2, 8, 8, 8, 4),\n            \"output_shape\": (2, 23, 23, 23, 6),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (2, 2, 3),\n            \"strides\": (2, 1, 2),\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1, 1),\n            \"input_shape\": (2, 8, 8, 8, 4),\n            \"output_shape\": (2, 16, 9, 17, 6),\n        },\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_conv3d_transpose_basic(\n        self,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        output_padding,\n        data_format,\n        dilation_rate,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.Conv3DTranspose,\n            init_kwargs={\n                \"filters\": filters,\n                \"kernel_size\": kernel_size,\n                \"strides\": strides,\n                
\"padding\": padding,\n                \"output_padding\": output_padding,\n                \"data_format\": data_format,\n                \"dilation_rate\": dilation_rate,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    def test_bad_init_args(self):\n        # `filters` is not positive.\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value for argument `filters`. Expected a \"\n            \"strictly positive value. Received filters=0.\",\n        ):\n            layers.Conv1DTranspose(filters=0, kernel_size=1)\n\n        # `kernel_size` has 0.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"The `kernel_size` argument must be a tuple of \"\n            r\"\\d+ integers. Received kernel_size=\\(1, 0\\), including values\"\n            r\" \\{0\\} that do not satisfy `value > 0`\",\n        ):\n            layers.Conv2DTranspose(filters=2, kernel_size=(1, 0))\n\n        # `strides` has 0.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"The `strides` argument must be a tuple of \\d+ \"\n            r\"integers. 
Received strides=\\(1, 0\\), including values \\{0\\} \"\n            r\"that do not satisfy `value > 0`\",\n        ):\n            layers.Conv2DTranspose(\n                filters=2, kernel_size=(2, 2), strides=(1, 0)\n            )\n\n        # `output_padding` >= `strides`.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"`output_padding` must be strictly less than `strides`\",\n        ):\n            layers.Conv1DTranspose(\n                filters=2, kernel_size=3, strides=2, output_padding=2\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"`output_padding` must be strictly less than `strides`\",\n        ):\n            layers.Conv2DTranspose(\n                filters=2,\n                kernel_size=3,\n                strides=2,\n                output_padding=3,\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"`output_padding` must be strictly less than `strides`\",\n        ):\n            layers.Conv2DTranspose(\n                filters=2,\n                kernel_size=3,\n                strides=(2, 3),\n                output_padding=(1, 3),\n            )\n\n        # `dilation_rate > 1` while `strides > 1`.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"`strides > 1` not supported in conjunction with \"\n            r\"`dilation_rate > 1`. 
Received: strides=\\(2, 2\\) and \"\n            r\"dilation_rate=\\(2, 1\\)\",\n        ):\n            layers.Conv2DTranspose(\n                filters=2, kernel_size=(2, 2), strides=2, dilation_rate=(2, 1)\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`output_padding` must be strictly less than `strides`\",\n        ):\n            layers.Conv1DTranspose(\n                filters=2, kernel_size=2, strides=2, output_padding=2\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`output_padding` must be strictly less than `strides`\",\n        ):\n            layers.Conv2DTranspose(\n                filters=16, kernel_size=3, strides=[1, 1], output_padding=1\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`output_padding` must be strictly less than `strides`\",\n        ):\n            layers.Conv3DTranspose(\n                filters=8, kernel_size=3, strides=1, output_padding=1\n            )\n\n\nclass ConvTransposeCorrectnessTest(testing.TestCase):\n    @parameterized.parameters(\n        {\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 2,\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 3,\n            \"padding\": \"same\",\n            \"output_padding\": 2,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1,),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (2,),\n            \"strides\": (2,),\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n        },\n    )\n    def 
test_conv1d_transpose(\n        self,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        output_padding,\n        data_format,\n        dilation_rate,\n    ):\n        layer = layers.Conv1DTranspose(\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            output_padding=output_padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n        )\n\n        inputs = np.random.normal(size=[2, 8, 4])\n        layer.build(input_shape=inputs.shape)\n\n        kernel_shape = layer.kernel.shape\n        kernel_weights = np.random.normal(size=kernel_shape)\n        bias_weights = np.random.normal(size=(filters,))\n        layer.kernel.assign(kernel_weights)\n        layer.bias.assign(bias_weights)\n\n        outputs = layer(inputs)\n        expected = np_conv1d_transpose(\n            inputs,\n            kernel_weights,\n            bias_weights,\n            strides,\n            padding,\n            output_padding,\n            data_format,\n            dilation_rate,\n        )\n        self.assertAllClose(\n            outputs, expected, atol=1e-5, tpu_atol=1e-1, tpu_rtol=1e-1\n        )\n\n    @parameterized.parameters(\n        {\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 2,\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": 7,\n            \"strides\": 16,\n            \"padding\": \"same\",\n            \"output_padding\": 2,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (2, 3),\n            \"strides\": (2, 1),\n            \"padding\": \"valid\",\n            
\"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1),\n        },\n        {\n            \"filters\": 2,\n            \"kernel_size\": (7, 7),\n            \"strides\": (16, 16),\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1),\n        },\n    )\n    def test_conv2d_transpose(\n        self,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        output_padding,\n        data_format,\n        dilation_rate,\n    ):\n        layer = layers.Conv2DTranspose(\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            output_padding=output_padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n        )\n\n        inputs = np.random.normal(size=[2, 14, 14, 4])\n        layer.build(input_shape=inputs.shape)\n\n        kernel_shape = layer.kernel.shape\n        kernel_weights = np.random.normal(size=kernel_shape)\n        bias_weights = np.random.normal(size=(filters,))\n        layer.kernel.assign(kernel_weights)\n        layer.bias.assign(bias_weights)\n\n        outputs = layer(inputs)\n        expected = np_conv2d_transpose(\n            inputs,\n            kernel_weights,\n            bias_weights,\n            strides,\n            padding,\n            output_padding,\n            data_format,\n            dilation_rate,\n        )\n        self.assertAllClose(\n            outputs, expected, atol=1e-5, tpu_atol=1e-1, tpu_rtol=1e-1\n        )\n\n    @parameterized.parameters(\n        {\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 2,\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n        
},\n        {\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 3,\n            \"padding\": \"same\",\n            \"output_padding\": 2,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1, 1),\n        },\n        {\n            \"filters\": 6,\n            \"kernel_size\": (2, 2, 3),\n            \"strides\": (2, 1, 2),\n            \"padding\": \"valid\",\n            \"output_padding\": None,\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1, 1),\n        },\n    )\n    def test_conv3d_transpose(\n        self,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        output_padding,\n        data_format,\n        dilation_rate,\n    ):\n        layer = layers.Conv3DTranspose(\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            output_padding=output_padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n        )\n\n        inputs = np.random.normal(size=[2, 8, 8, 8, 4])\n        layer.build(input_shape=inputs.shape)\n\n        kernel_shape = layer.kernel.shape\n        kernel_weights = np.random.normal(size=kernel_shape)\n        bias_weights = np.random.normal(size=(filters,))\n        layer.kernel.assign(kernel_weights)\n        layer.bias.assign(bias_weights)\n\n        outputs = layer(inputs)\n        expected = np_conv3d_transpose(\n            inputs,\n            kernel_weights,\n            bias_weights,\n            strides,\n            padding,\n            output_padding,\n            data_format,\n            dilation_rate,\n        )\n        self.assertAllClose(\n            outputs, expected, atol=1e-5, tpu_atol=1e-1, tpu_rtol=1e-1\n        )\n\n    @parameterized.product(\n        kernel_size=list(range(1, 5)),\n        strides=list(range(1, 5)),\n        padding=[\"same\", 
"valid\"],\n        output_padding=[None] + list(range(1, 5)),\n    )\n    def test_conv1d_transpose_consistency(\n        self, kernel_size, strides, padding, output_padding\n    ):\n        \"\"\"Test transposed convolution on a 1D array of size 3 against\n        several convolution parameters. In particular, tests whether Torch\n        inconsistencies are raised.\n        \"\"\"\n\n        # output_padding cannot be greater than or equal to strides\n        if isinstance(output_padding, int) and output_padding >= strides:\n            pytest.skip(\n                \"`output_padding` must be less than `strides`\"\n            )\n\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (1, 3, 1)\n        else:\n            input_shape = (1, 1, 3)\n\n        input = np.ones(shape=input_shape)\n        kernel_weights = np.arange(1, kernel_size + 1).reshape(\n            (kernel_size, 1, 1)\n        )\n\n        # Expected result\n        expected_res = np_conv1d_transpose(\n            x=input,\n            kernel_weights=kernel_weights,\n            bias_weights=np.zeros(shape=(1,)),\n            strides=strides,\n            padding=padding,\n            output_padding=output_padding,\n            data_format=backend.config.image_data_format(),\n            dilation_rate=1,\n        )\n\n        # keras layer\n        kc_layer = layers.Conv1DTranspose(\n            filters=1,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            output_padding=output_padding,\n            dilation_rate=1,\n        )\n        kc_layer.build(input_shape=input_shape)\n        kc_layer.kernel.assign(kernel_weights)\n\n        # Special cases for Torch\n        if backend.backend() == \"torch\":\n            # Args that cause output_padding >= strides\n            # are clamped with a warning.\n            if (kernel_size, strides, padding, output_padding) in [\n                (2, 1, \"same\", 
None),\n                (4, 1, \"same\", None),\n            ]:\n                clamped_output_padding = strides - 1  # usually 0 when stride=1\n                expected_res = np_conv1d_transpose(\n                    x=input,\n                    kernel_weights=kernel_weights,\n                    bias_weights=np.zeros(shape=(1,)),\n                    strides=strides,\n                    padding=padding,\n                    output_padding=clamped_output_padding,\n                    data_format=backend.config.image_data_format(),\n                    dilation_rate=1,\n                )\n                with pytest.warns(UserWarning):\n                    kc_res = kc_layer(input)\n                self.assertAllClose(expected_res, kc_res, atol=1e-5)\n                return\n\n            # torch_padding > 0 and torch_output_padding > 0 case\n            # Torch output differs from TF.\n            (\n                torch_padding,\n                torch_output_padding,\n            ) = _convert_conv_transpose_padding_args_from_keras_to_torch(\n                kernel_size=kernel_size,\n                stride=strides,\n                dilation_rate=1,\n                padding=padding,\n                output_padding=output_padding,\n            )\n            if torch_padding > 0 and torch_output_padding > 0:\n                with pytest.raises(AssertionError):\n                    kc_res = kc_layer(input)\n                    self.assertAllClose(expected_res, kc_res, atol=1e-5)\n                return\n\n        # Compare results\n        kc_res = kc_layer(input)\n        self.assertAllClose(expected_res, kc_res, atol=1e-5)\n\n    @parameterized.product(\n        kernel_size=list(range(1, 5)),\n        strides=list(range(1, 5)),\n        padding=[\"same\", \"valid\"],\n        output_padding=[None] + list(range(1, 5)),\n    )\n    def test_shape_inference_static_unknown_shape(\n        self, kernel_size, strides, padding, output_padding\n    ):\n        # 
output_padding cannot be greater than or equal to strides\n        if output_padding is not None and output_padding >= strides:\n            pytest.skip(\"`output_padding` must be less than `strides`\")\n\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (None, None, 3)\n            output_tensor_shape = (None, None, None, 2)\n        else:\n            input_shape = (3, None, None)\n            output_tensor_shape = (None, 2, None, None)\n        x = layers.Input(shape=input_shape)\n        x = layers.Conv2DTranspose(\n            filters=2,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            output_padding=output_padding,\n            dilation_rate=1,\n        )(x)\n        self.assertEqual(x.shape, output_tensor_shape)\n"
  },
  {
    "path": "keras/src/layers/convolutional/depthwise_conv1d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv\n\n\n@keras_export(\"keras.layers.DepthwiseConv1D\")\nclass DepthwiseConv1D(BaseDepthwiseConv):\n    \"\"\"1D depthwise convolution layer.\n\n    Depthwise convolution is a type of convolution in which each input channel\n    is convolved with a different kernel (called a depthwise kernel). You can\n    understand depthwise convolution as the first step in a depthwise separable\n    convolution.\n\n    It is implemented via the following steps:\n\n    - Split the input into individual channels.\n    - Convolve each channel with an individual depthwise kernel with\n      `depth_multiplier` output channels.\n    - Concatenate the convolved outputs along the channels axis.\n\n    Unlike a regular 1D convolution, depthwise convolution does not mix\n    information across different input channels.\n\n    The `depth_multiplier` argument determines how many filters are applied to\n    one input channel. As such, it controls the amount of output channels that\n    are generated per input channel in the depthwise step.\n\n    Args:\n        kernel_size: int or tuple/list of 1 integer, specifying the size of the\n            depthwise convolution window.\n        strides: int or tuple/list of 1 integer, specifying the stride length\n            of the convolution. `strides > 1` is incompatible with\n            `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input. When `padding=\"same\"` and\n            `strides=1`, the output has the same size as the input.\n        depth_multiplier: The number of depthwise convolution output channels\n            for each input channel. 
The total number of depthwise convolution\n            output channels will be equal to `input_channel * depth_multiplier`.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, steps, features)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, steps)`. It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of 1 integer, specifying the dilation\n            rate to use for dilated convolution.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        depthwise_initializer: Initializer for the convolution kernel.\n            If `None`, the default initializer (`\"glorot_uniform\"`)\n            will be used.\n        bias_initializer: Initializer for the bias vector. If `None`, the\n            default initializer (`\"zeros\"`) will be used.\n        depthwise_regularizer: Optional regularizer for the convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        depthwise_constraint: Optional projection function to be applied to the\n            kernel after being updated by an `Optimizer` (e.g. used to implement\n            norm constraints or value constraints for layer weights). The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape). 
Constraints\n            are not safe to use when doing asynchronous distributed training.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 3D tensor with shape: `(batch_shape, steps, channels)`\n    - If `data_format=\"channels_first\"`:\n        A 3D tensor with shape: `(batch_shape, channels, steps)`\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 3D tensor with shape:\n        `(batch_shape, new_steps, channels * depth_multiplier)`\n    - If `data_format=\"channels_first\"`:\n        A 3D tensor with shape:\n        `(batch_shape, channels * depth_multiplier, new_steps)`\n\n    Returns:\n        A 3D tensor representing\n        `activation(depthwise_conv1d(inputs, kernel) + bias)`.\n\n    Raises:\n        ValueError: when both `strides > 1` and `dilation_rate > 1`.\n\n    Example:\n\n    >>> x = np.random.rand(4, 10, 12)\n    >>> y = keras.layers.DepthwiseConv1D(\n    ...     3, 2, depth_multiplier=3, activation='relu')(x)\n    >>> print(y.shape)\n    (4, 4, 36)\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        depth_multiplier=1,\n        data_format=None,\n        dilation_rate=1,\n        activation=None,\n        use_bias=True,\n        depthwise_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        depthwise_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        depthwise_constraint=None,\n        bias_constraint=None,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=1,\n            depth_multiplier=depth_multiplier,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            activation=activation,\n            
use_bias=use_bias,\n            depthwise_initializer=depthwise_initializer,\n            bias_initializer=bias_initializer,\n            depthwise_regularizer=depthwise_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            depthwise_constraint=depthwise_constraint,\n            bias_constraint=bias_constraint,\n            **kwargs,\n        )\n"
  },
  {
    "path": "keras/src/layers/convolutional/depthwise_conv2d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv\n\n\n@keras_export(\"keras.layers.DepthwiseConv2D\")\nclass DepthwiseConv2D(BaseDepthwiseConv):\n    \"\"\"2D depthwise convolution layer.\n\n    Depthwise convolution is a type of convolution in which each input channel\n    is convolved with a different kernel (called a depthwise kernel). You can\n    understand depthwise convolution as the first step in a depthwise separable\n    convolution.\n\n    It is implemented via the following steps:\n\n    - Split the input into individual channels.\n    - Convolve each channel with an individual depthwise kernel with\n      `depth_multiplier` output channels.\n    - Concatenate the convolved outputs along the channels axis.\n\n    Unlike a regular 2D convolution, depthwise convolution does not mix\n    information across different input channels.\n\n    The `depth_multiplier` argument determines how many filters are applied to\n    one input channel. As such, it controls the amount of output channels that\n    are generated per input channel in the depthwise step.\n\n    Args:\n        kernel_size: int or tuple/list of 2 integers, specifying the size of the\n            depthwise convolution window.\n        strides: int or tuple/list of 2 integers, specifying the stride length\n            of the depthwise convolution. `strides > 1` is incompatible with\n            `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input. When `padding=\"same\"` and\n            `strides=1`, the output has the same size as the input.\n        depth_multiplier: The number of depthwise convolution output channels\n            for each input channel. 
The total number of depthwise convolution\n            output channels will be equal to `input_channel * depth_multiplier`.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`. It defaults to the\n            `image_data_format` value found in your Keras config file\n            at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of 2 integers, specifying the dilation\n            rate to use for dilated convolution.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        depthwise_initializer: Initializer for the convolution kernel.\n            If `None`, the default initializer (`\"glorot_uniform\"`)\n            will be used.\n        bias_initializer: Initializer for the bias vector. If `None`, the\n            default initializer (`\"zeros\"`) will be used.\n        depthwise_regularizer: Optional regularizer for the convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        depthwise_constraint: Optional projection function to be applied to the\n            kernel after being updated by an `Optimizer` (e.g. used to implement\n            norm constraints or value constraints for layer weights). The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape). 
Constraints\n            are not safe to use when doing asynchronous distributed training.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 4D tensor with shape: `(batch_size, height, width, channels)`\n    - If `data_format=\"channels_first\"`:\n        A 4D tensor with shape: `(batch_size, channels, height, width)`\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 4D tensor with shape:\n        `(batch_size, new_height, new_width, channels * depth_multiplier)`\n    - If `data_format=\"channels_first\"`:\n        A 4D tensor with shape:\n        `(batch_size, channels * depth_multiplier, new_height, new_width)`\n\n    Returns:\n        A 4D tensor representing\n        `activation(depthwise_conv2d(inputs, kernel) + bias)`.\n\n    Raises:\n        ValueError: when both `strides > 1` and `dilation_rate > 1`.\n\n    Example:\n\n    >>> x = np.random.rand(4, 10, 10, 12)\n    >>> y = keras.layers.DepthwiseConv2D(kernel_size=3, activation='relu')(x)\n    >>> print(y.shape)\n    (4, 8, 8, 12)\n    \"\"\"\n\n    def __init__(\n        self,\n        kernel_size,\n        strides=(1, 1),\n        padding=\"valid\",\n        depth_multiplier=1,\n        data_format=None,\n        dilation_rate=(1, 1),\n        activation=None,\n        use_bias=True,\n        depthwise_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        depthwise_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        depthwise_constraint=None,\n        bias_constraint=None,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=2,\n            depth_multiplier=depth_multiplier,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            
dilation_rate=dilation_rate,\n            activation=activation,\n            use_bias=use_bias,\n            depthwise_initializer=depthwise_initializer,\n            bias_initializer=bias_initializer,\n            depthwise_regularizer=depthwise_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            depthwise_constraint=depthwise_constraint,\n            bias_constraint=bias_constraint,\n            **kwargs,\n        )\n"
  },
  {
    "path": "keras/src/layers/convolutional/depthwise_conv_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom numpy.lib.stride_tricks import as_strided\n\nfrom keras.src import layers\nfrom keras.src import testing\n\n\ndef _same_padding(input_size, kernel_size, stride):\n    if input_size % stride == 0:\n        padding = max(kernel_size - stride, 0)\n    else:\n        padding = max(kernel_size - (input_size % stride), 0)\n    return padding // 2, padding - padding // 2\n\n\ndef np_depthwise_conv1d(\n    x,\n    kernel_weights,\n    bias_weights,\n    strides,\n    padding,\n    data_format,\n    dilation_rate,\n):\n    if data_format == \"channels_first\":\n        x = x.transpose((0, 2, 1))\n    if isinstance(strides, (tuple, list)):\n        h_stride = strides[0]\n    else:\n        h_stride = strides\n    if isinstance(dilation_rate, (tuple, list)):\n        h_dilation = dilation_rate[0]\n    else:\n        h_dilation = dilation_rate\n    h_kernel, ch_in, ch_out = kernel_weights.shape\n\n    if h_dilation > 1:\n        new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1)\n        new_kernel_weights = np.zeros(\n            (new_h_kernel, ch_in, ch_out),\n            dtype=kernel_weights.dtype,\n        )\n        new_kernel_weights[::h_dilation] = kernel_weights\n        kernel_weights = new_kernel_weights\n        h_kernel = kernel_weights.shape[0]\n\n    if padding == \"same\":\n        n_batch, h_x, _ = x.shape\n        h_pad = _same_padding(h_x, h_kernel, h_stride)\n        npad = [(0, 0)] * x.ndim\n        npad[1] = h_pad\n        x = np.pad(x, pad_width=npad, mode=\"constant\", constant_values=0)\n\n    n_batch, h_x, _ = x.shape\n    h_out = int((h_x - h_kernel) / h_stride) + 1\n\n    out_grps = []\n    bias_weights = bias_weights.reshape(ch_in, ch_out)\n    for ch_in_idx in range(ch_in):\n        for ch_out_idx in range(ch_out):\n            x_in = np.ascontiguousarray(x[..., ch_in_idx])\n            stride_shape = (n_batch, h_out, h_kernel)\n            
strides = (\n                x_in.strides[0],\n                h_stride * x_in.strides[1],\n                x_in.strides[1],\n            )\n            inner_dim = h_kernel\n            x_strided = as_strided(\n                x_in, shape=stride_shape, strides=strides\n            ).reshape(-1, inner_dim)\n            kernel_weights_grp = kernel_weights[\n                ..., ch_in_idx, ch_out_idx\n            ].reshape(-1, 1)\n            bias_weights_grp = bias_weights[..., ch_in_idx, ch_out_idx]\n            out_grps.append(\n                (x_strided @ kernel_weights_grp + bias_weights_grp).reshape(\n                    n_batch, h_out, 1\n                )\n            )\n    out = np.concatenate(out_grps, axis=-1)\n    if data_format == \"channels_first\":\n        out = out.transpose((0, 2, 1))\n    return out\n\n\ndef np_depthwise_conv2d(\n    x,\n    kernel_weights,\n    bias_weights,\n    strides,\n    padding,\n    data_format,\n    dilation_rate,\n):\n    if data_format == \"channels_first\":\n        x = x.transpose((0, 2, 3, 1))\n    if isinstance(strides, (tuple, list)):\n        h_stride, w_stride = strides\n    else:\n        h_stride = strides\n        w_stride = strides\n    if isinstance(dilation_rate, (tuple, list)):\n        h_dilation, w_dilation = dilation_rate\n    else:\n        h_dilation = dilation_rate\n        w_dilation = dilation_rate\n    h_kernel, w_kernel, ch_in, ch_out = kernel_weights.shape\n\n    if h_dilation > 1 or w_dilation > 1:\n        new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1)\n        new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1)\n        new_kernel_size_tuple = (new_h_kernel, new_w_kernel)\n        new_kernel_weights = np.zeros(\n            (*new_kernel_size_tuple, ch_in, ch_out),\n            dtype=kernel_weights.dtype,\n        )\n        new_kernel_weights[::h_dilation, ::w_dilation] = kernel_weights\n        kernel_weights = new_kernel_weights\n        h_kernel, w_kernel = 
kernel_weights.shape[:2]\n\n    if padding == \"same\":\n        n_batch, h_x, w_x, _ = x.shape\n        h_pad = _same_padding(h_x, h_kernel, h_stride)\n        w_pad = _same_padding(w_x, w_kernel, w_stride)\n        npad = [(0, 0)] * x.ndim\n        npad[1] = h_pad\n        npad[2] = w_pad\n        x = np.pad(x, pad_width=npad, mode=\"constant\", constant_values=0)\n\n    n_batch, h_x, w_x, _ = x.shape\n    h_out = int((h_x - h_kernel) / h_stride) + 1\n    w_out = int((w_x - w_kernel) / w_stride) + 1\n\n    out_grps = []\n    bias_weights = bias_weights.reshape(ch_in, ch_out)\n    for ch_in_idx in range(ch_in):\n        for ch_out_idx in range(ch_out):\n            x_in = np.ascontiguousarray(x[..., ch_in_idx])\n            stride_shape = (n_batch, h_out, w_out, h_kernel, w_kernel)\n            strides = (\n                x_in.strides[0],\n                h_stride * x_in.strides[1],\n                w_stride * x_in.strides[2],\n                x_in.strides[1],\n                x_in.strides[2],\n            )\n            inner_dim = h_kernel * w_kernel\n            x_strided = as_strided(\n                x_in, shape=stride_shape, strides=strides\n            ).reshape(-1, inner_dim)\n            kernel_weights_grp = kernel_weights[\n                ..., ch_in_idx, ch_out_idx\n            ].reshape(-1, 1)\n            bias_weights_grp = bias_weights[..., ch_in_idx, ch_out_idx]\n            out_grps.append(\n                (x_strided @ kernel_weights_grp + bias_weights_grp).reshape(\n                    n_batch, h_out, w_out, 1\n                )\n            )\n    out = np.concatenate(out_grps, axis=-1)\n    if data_format == \"channels_first\":\n        out = out.transpose((0, 3, 1, 2))\n    return out\n\n\nclass DepthwiseConvBasicTest(testing.TestCase):\n    @parameterized.parameters(\n        {\n            \"depth_multiplier\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": 
\"channels_last\",\n            \"dilation_rate\": 1,\n            \"input_shape\": (3, 5, 4),\n            \"output_shape\": (3, 4, 20),\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2,),\n            \"input_shape\": (3, 4, 4),\n            \"output_shape\": (3, 4, 24),\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"kernel_size\": 2,\n            \"strides\": (2,),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"input_shape\": (3, 5, 4),\n            \"output_shape\": (3, 2, 24),\n        },\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_depthwise_conv1d_basic(\n        self,\n        depth_multiplier,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.DepthwiseConv1D,\n            init_kwargs={\n                \"depth_multiplier\": depth_multiplier,\n                \"kernel_size\": kernel_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n                \"dilation_rate\": dilation_rate,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    @parameterized.parameters(\n        {\n            \"depth_multiplier\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n         
   \"dilation_rate\": 1,\n            \"input_shape\": (3, 5, 5, 4),\n            \"output_shape\": (3, 4, 4, 20),\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2, 2),\n            \"input_shape\": (3, 4, 4, 4),\n            \"output_shape\": (3, 4, 4, 24),\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"kernel_size\": (2, 2),\n            \"strides\": (2, 2),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1),\n            \"input_shape\": (3, 5, 5, 4),\n            \"output_shape\": (3, 2, 2, 24),\n        },\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_depthwise_conv2d_basic(\n        self,\n        depth_multiplier,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.DepthwiseConv2D,\n            init_kwargs={\n                \"depth_multiplier\": depth_multiplier,\n                \"kernel_size\": kernel_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n                \"dilation_rate\": dilation_rate,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    def test_bad_init_args(self):\n        # `depth_multiplier` is not positive.\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value for argument `depth_multiplier`. 
\"\n            \"Expected a strictly positive value. Received \"\n            \"depth_multiplier=0.\",\n        ):\n            layers.DepthwiseConv1D(depth_multiplier=0, kernel_size=1)\n\n        # `kernel_size` has 0.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"The `kernel_size` argument must be a tuple of 2 \"\n            r\"integers. Received kernel_size=\\(1, 0\\), including values \"\n            r\"\\{0\\} that do not satisfy `value > 0`\",\n        ):\n            layers.DepthwiseConv2D(depth_multiplier=2, kernel_size=(1, 0))\n\n        # `strides` has 0.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"The `strides` argument must be a tuple of \\d+ \"\n            r\"integers. Received strides=\\(1, 0\\), including values \\{0\\} \"\n            r\"that do not satisfy `value > 0`\",\n        ):\n            layers.DepthwiseConv2D(\n                depth_multiplier=2, kernel_size=(2, 2), strides=(1, 0)\n            )\n\n        # `dilation_rate > 1` while `strides > 1`.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"`strides > 1` not supported in conjunction with \"\n            r\"`dilation_rate > 1`. 
Received: strides=\\(2, 2\\) and \"\n            r\"dilation_rate=\\(2, 1\\)\",\n        ):\n            layers.DepthwiseConv2D(\n                depth_multiplier=2,\n                kernel_size=(2, 2),\n                strides=2,\n                dilation_rate=(2, 1),\n            )\n\n\nclass DepthwiseConvCorrectnessTest(testing.TestCase):\n    @parameterized.parameters(\n        {\n            \"depth_multiplier\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2,),\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"kernel_size\": (2,),\n            \"strides\": (2,),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n        },\n    )\n    def test_depthwise_conv1d(\n        self,\n        depth_multiplier,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n    ):\n        layer = layers.DepthwiseConv1D(\n            depth_multiplier=depth_multiplier,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n        )\n\n        inputs = np.random.normal(size=[2, 8, 4])\n        layer.build(input_shape=inputs.shape)\n\n        kernel_shape = layer.kernel.shape\n        kernel_weights = np.random.normal(size=kernel_shape)\n        bias_weights = np.random.normal(size=(depth_multiplier * 4,))\n        layer.kernel.assign(kernel_weights)\n        layer.bias.assign(bias_weights)\n\n        outputs = layer(inputs)\n        
expected = np_depthwise_conv1d(\n            inputs,\n            kernel_weights,\n            bias_weights,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n        )\n        self.assertAllClose(outputs, expected, tpu_atol=1e-2, tpu_rtol=1e-2)\n\n    @parameterized.parameters(\n        {\n            \"depth_multiplier\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2, 2),\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"kernel_size\": (2, 2),\n            \"strides\": (2, 2),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1),\n        },\n    )\n    def test_depthwise_conv2d(\n        self,\n        depth_multiplier,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n    ):\n        layer = layers.DepthwiseConv2D(\n            depth_multiplier=depth_multiplier,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n        )\n\n        inputs = np.random.normal(size=[2, 8, 8, 4])\n        layer.build(input_shape=inputs.shape)\n\n        kernel_shape = layer.kernel.shape\n        kernel_weights = np.random.normal(size=kernel_shape)\n        bias_weights = np.random.normal(size=(depth_multiplier * 4,))\n        layer.kernel.assign(kernel_weights)\n        layer.bias.assign(bias_weights)\n\n        outputs = layer(inputs)\n   
     expected = np_depthwise_conv2d(\n            inputs,\n            kernel_weights,\n            bias_weights,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n        )\n        self.assertAllClose(\n            outputs.shape, expected.shape, tpu_atol=1e-2, tpu_rtol=1e-2\n        )\n        self.assertAllClose(\n            outputs, expected, atol=1e-5, tpu_atol=1e-1, tpu_rtol=1e-1\n        )\n"
  },
  {
    "path": "keras/src/layers/convolutional/separable_conv1d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv\n\n\n@keras_export(\n    [\n        \"keras.layers.SeparableConv1D\",\n        \"keras.layers.SeparableConvolution1D\",\n    ]\n)\nclass SeparableConv1D(BaseSeparableConv):\n    \"\"\"1D separable convolution layer.\n\n    This layer performs a depthwise convolution that acts separately on\n    channels, followed by a pointwise convolution that mixes channels.\n    If `use_bias` is True and a bias initializer is provided,\n    it adds a bias vector to the output. It then optionally applies an\n    activation function to produce the final output.\n\n    Args:\n        filters: int, the dimensionality of the output space (i.e. the number\n            of filters in the pointwise convolution).\n        kernel_size: int or tuple/list of 1 integers, specifying the size of the\n            depthwise convolution window.\n        strides: int or tuple/list of 1 integers, specifying the stride length\n            of the depthwise convolution. If only one int is specified, the same\n            stride size will be used for all dimensions. `strides > 1` is\n            incompatible with `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input. When `padding=\"same\"` and\n            `strides=1`, the output has the same size as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, steps, features)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, steps)`. 
It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of 1 integers, specifying the dilation\n            rate to use for dilated convolution. If only one int is specified,\n            the same dilation rate will be used for all dimensions.\n        depth_multiplier: The number of depthwise convolution output channels\n            for each input channel. The total number of depthwise convolution\n            output channels will be equal to `input_channel * depth_multiplier`.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        depthwise_initializer: An initializer for the depthwise convolution\n            kernel. If None, then the default initializer (`\"glorot_uniform\"`)\n            will be used.\n        pointwise_initializer: An initializer for the pointwise convolution\n            kernel. If None, then the default initializer (`\"glorot_uniform\"`)\n            will be used.\n        bias_initializer: An initializer for the bias vector. If None, the\n            default initializer (`\"zeros\"`) will be used.\n        depthwise_regularizer: Optional regularizer for the depthwise\n            convolution kernel.\n        pointwise_regularizer: Optional regularizer for the pointwise\n            convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        depthwise_constraint: Optional projection function to be applied to the\n            depthwise kernel after being updated by an `Optimizer` (e.g. used\n            for norm constraints or value constraints for layer weights). 
The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape).\n        pointwise_constraint: Optional projection function to be applied to the\n            pointwise kernel after being updated by an `Optimizer`.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 3D tensor with shape: `(batch_shape, steps, channels)`\n    - If `data_format=\"channels_first\"`:\n        A 3D tensor with shape: `(batch_shape, channels, steps)`\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 3D tensor with shape: `(batch_shape, new_steps, filters)`\n    - If `data_format=\"channels_first\"`:\n        A 3D tensor with shape: `(batch_shape, filters, new_steps)`\n\n    Returns:\n        A 3D tensor representing\n        `activation(separable_conv1d(inputs, kernel) + bias)`.\n\n    Example:\n\n    >>> x = np.random.rand(4, 10, 12)\n    >>> y = keras.layers.SeparableConv1D(4, 3, 2, activation='relu')(x)\n    >>> print(y.shape)\n    (4, 4, 4)\n    \"\"\"\n\n    def __init__(\n        self,\n        filters,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        depth_multiplier=1,\n        activation=None,\n        use_bias=True,\n        depthwise_initializer=\"glorot_uniform\",\n        pointwise_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        depthwise_regularizer=None,\n        pointwise_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        depthwise_constraint=None,\n        pointwise_constraint=None,\n        bias_constraint=None,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=1,\n            depth_multiplier=depth_multiplier,\n          
  filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            activation=activation,\n            use_bias=use_bias,\n            depthwise_initializer=depthwise_initializer,\n            pointwise_initializer=pointwise_initializer,\n            bias_initializer=bias_initializer,\n            depthwise_regularizer=depthwise_regularizer,\n            pointwise_regularizer=pointwise_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            depthwise_constraint=depthwise_constraint,\n            pointwise_constraint=pointwise_constraint,\n            bias_constraint=bias_constraint,\n            **kwargs,\n        )\n"
  },
  {
    "path": "keras/src/layers/convolutional/separable_conv2d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv\n\n\n@keras_export(\n    [\n        \"keras.layers.SeparableConv2D\",\n        \"keras.layers.SeparableConvolution2D\",\n    ]\n)\nclass SeparableConv2D(BaseSeparableConv):\n    \"\"\"2D separable convolution layer.\n\n    This layer performs a depthwise convolution that acts separately on\n    channels, followed by a pointwise convolution that mixes channels.\n    If `use_bias` is True and a bias initializer is provided,\n    it adds a bias vector to the output. It then optionally applies an\n    activation function to produce the final output.\n\n    Args:\n        filters: int, the dimensionality of the output space (i.e. the number\n            of filters in the pointwise convolution).\n        kernel_size: int or tuple/list of 2 integers, specifying the size of the\n            depthwise convolution window.\n        strides: int or tuple/list of 2 integers, specifying the stride length\n            of the depthwise convolution. If only one int is specified, the same\n            stride size will be used for all dimensions. `strides > 1` is\n            incompatible with `dilation_rate > 1`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input. When `padding=\"same\"` and\n            `strides=1`, the output has the same size as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`. 
It defaults to the\n            `image_data_format` value found in your Keras config file\n            at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of 2 integers, specifying the dilation\n            rate to use for dilated convolution. If only one int is specified,\n            the same dilation rate will be used for all dimensions.\n        depth_multiplier: The number of depthwise convolution output channels\n            for each input channel. The total number of depthwise convolution\n            output channels will be equal to `input_channel * depth_multiplier`.\n        activation: Activation function. If `None`, no activation is applied.\n        use_bias: bool, if `True`, bias will be added to the output.\n        depthwise_initializer: An initializer for the depthwise convolution\n            kernel. If None, then the default initializer (`\"glorot_uniform\"`)\n            will be used.\n        pointwise_initializer: An initializer for the pointwise convolution\n            kernel. If None, then the default initializer (`\"glorot_uniform\"`)\n            will be used.\n        bias_initializer: An initializer for the bias vector. If None, the\n            default initializer (`\"zeros\"`) will be used.\n        depthwise_regularizer: Optional regularizer for the depthwise\n            convolution kernel.\n        pointwise_regularizer: Optional regularizer for the pointwise\n            convolution kernel.\n        bias_regularizer: Optional regularizer for the bias vector.\n        activity_regularizer: Optional regularizer function for the output.\n        depthwise_constraint: Optional projection function to be applied to the\n            depthwise kernel after being updated by an `Optimizer` (e.g. used\n            for norm constraints or value constraints for layer weights). 
The\n            function must take as input the unprojected variable and must return\n            the projected variable (which must have the same shape).\n        pointwise_constraint: Optional projection function to be applied to the\n            pointwise kernel after being updated by an `Optimizer`.\n        bias_constraint: Optional projection function to be applied to the\n            bias after being updated by an `Optimizer`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 4D tensor with shape: `(batch_size, height, width, channels)`\n    - If `data_format=\"channels_first\"`:\n        A 4D tensor with shape: `(batch_size, channels, height, width)`\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        A 4D tensor with shape: `(batch_size, new_height, new_width, filters)`\n    - If `data_format=\"channels_first\"`:\n        A 4D tensor with shape: `(batch_size, filters, new_height, new_width)`\n\n    Returns:\n        A 4D tensor representing\n        `activation(separable_conv2d(inputs, kernel) + bias)`.\n\n    Example:\n\n    >>> x = np.random.rand(4, 10, 10, 12)\n    >>> y = keras.layers.SeparableConv2D(4, 3, 2, activation='relu')(x)\n    >>> print(y.shape)\n    (4, 4, 4, 4)\n    \"\"\"\n\n    def __init__(\n        self,\n        filters,\n        kernel_size,\n        strides=(1, 1),\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=(1, 1),\n        depth_multiplier=1,\n        activation=None,\n        use_bias=True,\n        depthwise_initializer=\"glorot_uniform\",\n        pointwise_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        depthwise_regularizer=None,\n        pointwise_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        depthwise_constraint=None,\n        pointwise_constraint=None,\n        bias_constraint=None,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=2,\n     
       depth_multiplier=depth_multiplier,\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            activation=activation,\n            use_bias=use_bias,\n            depthwise_initializer=depthwise_initializer,\n            pointwise_initializer=pointwise_initializer,\n            bias_initializer=bias_initializer,\n            depthwise_regularizer=depthwise_regularizer,\n            pointwise_regularizer=pointwise_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            depthwise_constraint=depthwise_constraint,\n            pointwise_constraint=pointwise_constraint,\n            bias_constraint=bias_constraint,\n            **kwargs,\n        )\n"
  },
  {
    "path": "keras/src/layers/convolutional/separable_conv_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.layers.convolutional.conv_test import np_conv1d\nfrom keras.src.layers.convolutional.conv_test import np_conv2d\nfrom keras.src.layers.convolutional.depthwise_conv_test import (\n    np_depthwise_conv1d,\n)\nfrom keras.src.layers.convolutional.depthwise_conv_test import (\n    np_depthwise_conv2d,\n)\n\n\nclass SeparableConvBasicTest(testing.TestCase):\n    @parameterized.parameters(\n        {\n            \"depth_multiplier\": 5,\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"input_shape\": (3, 5, 4),\n            \"output_shape\": (3, 4, 5),\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2,),\n            \"input_shape\": (3, 4, 4),\n            \"output_shape\": (3, 4, 6),\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": (2,),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"input_shape\": (3, 5, 4),\n            \"output_shape\": (3, 2, 6),\n        },\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_separable_conv1d_basic(\n        self,\n        depth_multiplier,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            
layers.SeparableConv1D,\n            init_kwargs={\n                \"depth_multiplier\": depth_multiplier,\n                \"filters\": filters,\n                \"kernel_size\": kernel_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n                \"dilation_rate\": dilation_rate,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    @parameterized.parameters(\n        {\n            \"depth_multiplier\": 5,\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n            \"input_shape\": (3, 5, 5, 4),\n            \"output_shape\": (3, 4, 4, 5),\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2, 2),\n            \"input_shape\": (3, 4, 4, 4),\n            \"output_shape\": (3, 4, 4, 6),\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"filters\": 6,\n            \"kernel_size\": (2, 2),\n            \"strides\": (2, 2),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1),\n            \"input_shape\": (3, 5, 5, 4),\n            \"output_shape\": (3, 2, 2, 6),\n        },\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_separable_conv2d_basic(\n        self,\n        depth_multiplier,\n        filters,\n        kernel_size,\n        strides,\n        
padding,\n        data_format,\n        dilation_rate,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.SeparableConv2D,\n            init_kwargs={\n                \"depth_multiplier\": depth_multiplier,\n                \"filters\": filters,\n                \"kernel_size\": kernel_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n                \"dilation_rate\": dilation_rate,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    def test_bad_init_args(self):\n        # `depth_multiplier` is not positive.\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value for argument `depth_multiplier`. \"\n            \"Expected a strictly positive value. Received \"\n            \"depth_multiplier=0.\",\n        ):\n            layers.SeparableConv1D(depth_multiplier=0, filters=1, kernel_size=1)\n\n        # `filters` is not positive.\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value for argument `filters`. Expected a \"\n            \"strictly positive value. Received filters=0.\",\n        ):\n            layers.SeparableConv1D(depth_multiplier=1, filters=0, kernel_size=1)\n\n        # `kernel_size` has 0.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"The `kernel_size` argument must be a tuple of \"\n            r\"\\d+ integers. 
Received kernel_size=\\(1, 0\\), including values\"\n            r\" \\{0\\} that do not satisfy `value > 0`\",\n        ):\n            layers.SeparableConv2D(\n                depth_multiplier=2, filters=2, kernel_size=(1, 0)\n            )\n\n        # `strides` has 0.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"The `strides` argument must be a tuple of \\d+ \"\n            r\"integers. Received strides=\\(1, 0\\), including values \\{0\\} \"\n            r\"that do not satisfy `value > 0`\",\n        ):\n            layers.SeparableConv2D(\n                depth_multiplier=2,\n                filters=2,\n                kernel_size=(2, 2),\n                strides=(1, 0),\n            )\n\n        # `dilation_rate > 1` while `strides > 1`.\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"`strides > 1` not supported in conjunction with \"\n            r\"`dilation_rate > 1`. Received: strides=\\(2, 2\\) and \"\n            r\"dilation_rate=\\(2, 1\\)\",\n        ):\n            layers.SeparableConv2D(\n                depth_multiplier=2,\n                filters=2,\n                kernel_size=(2, 2),\n                strides=2,\n                dilation_rate=(2, 1),\n            )\n\n\nclass SeparableConvCorrectnessTest(testing.TestCase):\n    @parameterized.parameters(\n        {\n            \"depth_multiplier\": 5,\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2,),\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"filters\": 6,\n            
\"kernel_size\": (2,),\n            \"strides\": (2,),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n        },\n    )\n    def test_separable_conv1d(\n        self,\n        depth_multiplier,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n    ):\n        layer = layers.SeparableConv1D(\n            depth_multiplier=depth_multiplier,\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n        )\n\n        inputs = np.random.normal(size=[2, 8, 4])\n        layer.build(input_shape=inputs.shape)\n\n        depthwise_kernel_shape = layer.depthwise_kernel.shape\n        depthwise_kernel_weights = np.random.normal(size=depthwise_kernel_shape)\n        layer.depthwise_kernel.assign(depthwise_kernel_weights)\n\n        pointwise_kernel_shape = layer.pointwise_kernel.shape\n        pointwise_kernel_weights = np.random.normal(size=pointwise_kernel_shape)\n        layer.pointwise_kernel.assign(pointwise_kernel_weights)\n\n        bias_weights = np.random.normal(size=(filters,))\n        layer.bias.assign(bias_weights)\n\n        outputs = layer(inputs)\n        expected_depthwise = np_depthwise_conv1d(\n            inputs,\n            depthwise_kernel_weights,\n            np.zeros(4 * depth_multiplier),\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n        )\n        expected = np_conv1d(\n            expected_depthwise,\n            pointwise_kernel_weights,\n            bias_weights,\n            strides=1,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=1,\n            groups=1,\n        )\n\n        
self.assertAllClose(outputs.shape, expected.shape)\n        self.assertAllClose(\n            outputs,\n            expected,\n            rtol=1e-5,\n            atol=1e-5,\n            tpu_atol=1e-1,\n            tpu_rtol=1e-1,\n        )\n\n    @parameterized.parameters(\n        {\n            \"depth_multiplier\": 5,\n            \"filters\": 5,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": 1,\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"filters\": 6,\n            \"kernel_size\": 2,\n            \"strides\": 1,\n            \"padding\": \"same\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (2, 2),\n        },\n        {\n            \"depth_multiplier\": 6,\n            \"filters\": 6,\n            \"kernel_size\": (2, 2),\n            \"strides\": (2, 2),\n            \"padding\": \"valid\",\n            \"data_format\": \"channels_last\",\n            \"dilation_rate\": (1, 1),\n        },\n    )\n    def test_separable_conv2d(\n        self,\n        depth_multiplier,\n        filters,\n        kernel_size,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n    ):\n        layer = layers.SeparableConv2D(\n            depth_multiplier=depth_multiplier,\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n        )\n\n        inputs = np.random.normal(size=[2, 8, 8, 4])\n        layer.build(input_shape=inputs.shape)\n\n        depthwise_kernel_shape = layer.depthwise_kernel.shape\n        depthwise_kernel_weights = np.random.normal(size=depthwise_kernel_shape)\n        layer.depthwise_kernel.assign(depthwise_kernel_weights)\n\n        pointwise_kernel_shape = 
layer.pointwise_kernel.shape\n        pointwise_kernel_weights = np.random.normal(size=pointwise_kernel_shape)\n        layer.pointwise_kernel.assign(pointwise_kernel_weights)\n\n        bias_weights = np.random.normal(size=(filters,))\n        layer.bias.assign(bias_weights)\n\n        outputs = layer(inputs)\n        expected_depthwise = np_depthwise_conv2d(\n            inputs,\n            depthwise_kernel_weights,\n            np.zeros(4 * depth_multiplier),\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n        )\n        expected = np_conv2d(\n            expected_depthwise,\n            pointwise_kernel_weights,\n            bias_weights,\n            strides=1,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=1,\n            groups=1,\n        )\n\n        self.assertAllClose(outputs.shape, expected.shape)\n        self.assertAllClose(\n            outputs,\n            expected,\n            rtol=1e-5,\n            atol=1e-5,\n            tpu_atol=1e-1,\n            tpu_rtol=1e-1,\n        )\n"
  },
  {
    "path": "keras/src/layers/core/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/layers/core/dense.py",
    "content": "import math\n\nimport ml_dtypes\n\nfrom keras.src import activations\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import quantizers\nfrom keras.src import regularizers\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.quantizers.quantization_config import QuantizationConfig\nfrom keras.src.quantizers.quantization_config import get_block_size_for_layer\nfrom keras.src.quantizers.quantizers import dequantize_with_sz_map\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.layers.Dense\")\nclass Dense(Layer):\n    \"\"\"Just your regular densely-connected NN layer.\n\n    `Dense` implements the operation:\n    `output = activation(dot(input, kernel) + bias)`\n    where `activation` is the element-wise activation function\n    passed as the `activation` argument, `kernel` is a weights matrix\n    created by the layer, and `bias` is a bias vector created by the layer\n    (only applicable if `use_bias` is `True`). When this layer is\n    followed by a `BatchNormalization` layer, it is recommended to set\n    `use_bias=False` as `BatchNormalization` has its own bias term.\n\n    Note: If the input to the layer has a rank greater than 2, `Dense`\n    computes the dot product between the `inputs` and the `kernel` along the\n    last axis of the `inputs` and axis 0 of the `kernel` (using `tf.tensordot`).\n    For example, if input has dimensions `(batch_size, d0, d1)`, then we create\n    a `kernel` with shape `(d1, units)`, and the `kernel` operates along axis 2\n    of the `input`, on every sub-tensor of shape `(1, 1, d1)` (there are\n    `batch_size * d0` such sub-tensors). 
The output in this case will have\n    shape `(batch_size, d0, units)`.\n\n    Args:\n        units: Positive integer, dimensionality of the output space.\n        activation: Activation function to use.\n            If you don't specify anything, no activation is applied\n            (i.e. \"linear\" activation: `a(x) = x`).\n        use_bias: Boolean, whether the layer uses a bias vector.\n        kernel_initializer: Initializer for the `kernel` weights matrix.\n        bias_initializer: Initializer for the bias vector.\n        kernel_regularizer: Regularizer function applied to\n            the `kernel` weights matrix.\n        bias_regularizer: Regularizer function applied to the bias vector.\n        activity_regularizer: Regularizer function applied to\n            the output of the layer (its \"activation\").\n        kernel_constraint: Constraint function applied to\n            the `kernel` weights matrix.\n        bias_constraint: Constraint function applied to the bias vector.\n        lora_rank: Optional integer. If set, the layer's forward pass\n            will implement LoRA (Low-Rank Adaptation)\n            with the provided rank. LoRA sets the layer's kernel\n            to non-trainable and replaces it with a delta over the\n            original kernel, obtained via multiplying two lower-rank\n            trainable matrices. This can be useful to reduce the\n            computation cost of fine-tuning large dense layers.\n            You can also enable LoRA on an existing\n            `Dense` layer by calling `layer.enable_lora(rank)`.\n        lora_alpha: Optional integer. If set, this parameter scales the\n            low-rank adaptation delta (computed as the product of two lower-rank\n            trainable matrices) during the forward pass. 
The delta is scaled by\n            `lora_alpha / lora_rank`, allowing you to fine-tune the strength of\n            the LoRA adjustment independently of `lora_rank`.\n\n    Input shape:\n        N-D tensor with shape: `(batch_size, ..., input_dim)`.\n        The most common situation would be\n        a 2D input with shape `(batch_size, input_dim)`.\n\n    Output shape:\n        N-D tensor with shape: `(batch_size, ..., units)`.\n        For instance, for a 2D input with shape `(batch_size, input_dim)`,\n        the output would have shape `(batch_size, units)`.\n    \"\"\"\n\n    def __init__(\n        self,\n        units,\n        activation=None,\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        bias_constraint=None,\n        lora_rank=None,\n        lora_alpha=None,\n        quantization_config=None,\n        **kwargs,\n    ):\n        if not isinstance(units, int) or units <= 0:\n            raise ValueError(\n                \"Received an invalid value for `units`, expected a positive \"\n                f\"integer. 
Received: units={units}\"\n            )\n\n        super().__init__(activity_regularizer=activity_regularizer, **kwargs)\n        self.units = units\n        self.activation = activations.get(activation)\n        self.use_bias = use_bias\n        self.kernel_initializer = initializers.get(kernel_initializer)\n        self.bias_initializer = initializers.get(bias_initializer)\n        self.kernel_regularizer = regularizers.get(kernel_regularizer)\n        self.bias_regularizer = regularizers.get(bias_regularizer)\n        self.kernel_constraint = constraints.get(kernel_constraint)\n        self.bias_constraint = constraints.get(bias_constraint)\n        self.lora_rank = lora_rank\n        self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank\n        self.lora_enabled = False\n        self.quantization_config = quantization_config\n        self.input_spec = InputSpec(min_ndim=2)\n        self.supports_masking = True\n\n    def build(self, input_shape):\n        kernel_shape = (input_shape[-1], self.units)\n        if self.quantization_mode:\n            self.quantized_build(\n                kernel_shape,\n                mode=self.quantization_mode,\n                config=self.quantization_config,\n            )\n        if self.quantization_mode not in (\"int8\", \"int4\", \"gptq\", \"awq\"):\n            # If the layer is quantized to int8 or int4, `self._kernel` will be\n            # added in `self._int8_build` or `_int4_build`. 
Therefore, we skip\n            # it here. Likewise, GPTQ and AWQ create `quantized_kernel` in\n            # `_gptq_build`/`_awq_build`.\n            self._kernel = self.add_weight(\n                name=\"kernel\",\n                shape=kernel_shape,\n                initializer=self.kernel_initializer,\n                regularizer=self.kernel_regularizer,\n                constraint=self.kernel_constraint,\n            )\n        if self.use_bias:\n            self.bias = self.add_weight(\n                name=\"bias\",\n                shape=(self.units,),\n                initializer=self.bias_initializer,\n                regularizer=self.bias_regularizer,\n                constraint=self.bias_constraint,\n            )\n        else:\n            self.bias = None\n        self.input_spec = InputSpec(min_ndim=2, axes={-1: input_shape[-1]})\n        self.built = True\n        if self.lora_rank:\n            self.enable_lora(self.lora_rank)\n\n    @property\n    def kernel(self):\n        from keras.src.quantizers import gptq_core\n\n        if not self.built:\n            raise AttributeError(\n                \"You must build the layer before accessing `kernel`.\"\n            )\n\n        mode = self.quantization_mode\n        is_gptq = mode == \"gptq\"\n        is_awq = mode == \"awq\"\n        is_int4 = mode == \"int4\"\n        gptq_calibrated = bool(getattr(self, \"is_gptq_calibrated\", False))\n        awq_calibrated = bool(getattr(self, \"is_awq_calibrated\", False))\n        gptq_bits = (\n            gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None\n        )\n\n        # Decide the source tensor first (packed vs already-quantized vs plain\n        # kernel)\n        if is_int4:\n       
         # unpack [in, ceil(out/2)] to [in, out]\n                kernel = quantizers.unpack_int4(\n                    kernel, self._orig_output_dim, axis=-1\n                )\n            elif is_gptq and gptq_calibrated and gptq_bits == 4:\n                kernel = quantizers.unpack_int4(\n                    self.quantized_kernel,\n                    orig_len=self.units,\n                    axis=0,\n                    dtype=\"uint8\",\n                )\n            elif is_awq and awq_calibrated:\n                # AWQ always uses 4-bit quantization\n                kernel = quantizers.unpack_int4(\n                    self.quantized_kernel,\n                    orig_len=self.units,\n                    axis=0,\n                    dtype=\"uint8\",\n                )\n\n        # Apply LoRA once at the end.\n        if self.lora_enabled:\n            kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(\n                self.lora_kernel_a, self.lora_kernel_b\n            )\n\n        return kernel\n\n    def call(self, inputs, training=None):\n        x = ops.matmul(inputs, self.kernel)\n        if self.bias is not None:\n            x = ops.add(x, self.bias)\n        if self.activation is not None:\n            x = self.activation(x)\n        return x\n\n    def compute_output_shape(self, input_shape):\n        output_shape = list(input_shape)\n        output_shape[-1] = self.units\n        return tuple(output_shape)\n\n    def enable_lora(\n        self,\n        rank,\n        lora_alpha=None,\n        a_initializer=\"he_uniform\",\n        b_initializer=\"zeros\",\n    ):\n        if self.kernel_constraint:\n            raise ValueError(\n                \"Lora is incompatible with kernel constraints. 
\"\n                \"In order to enable lora on this layer, remove the \"\n                \"`kernel_constraint` argument.\"\n            )\n        if not self.built:\n            raise ValueError(\n                \"Cannot enable lora on a layer that isn't yet built.\"\n            )\n        if self.lora_enabled:\n            raise ValueError(\n                \"lora is already enabled. This can only be done once per layer.\"\n            )\n        if self.quantization_mode == \"gptq\":\n            raise NotImplementedError(\n                \"lora is not currently supported with GPTQ quantization.\"\n            )\n        self._tracker.unlock()\n        # Determine the correct input dimension for the LoRA A matrix. When\n        # the layer has been int4-quantized, `self._kernel` stores a *packed*\n        # representation whose first dimension is `ceil(input_dim/2)`. We\n        # saved the true, *unpacked* input dimension in `self._orig_input_dim`\n        # during quantization. 
Use it if available; otherwise fall back to the\n        # first dimension of `self.kernel`.\n        if self.quantization_mode == \"int4\" and hasattr(\n            self, \"_orig_input_dim\"\n        ):\n            input_dim_for_lora = self._orig_input_dim\n        else:\n            input_dim_for_lora = self.kernel.shape[0]\n\n        self.lora_kernel_a = self.add_weight(\n            name=\"lora_kernel_a\",\n            shape=(input_dim_for_lora, rank),\n            initializer=initializers.get(a_initializer),\n            regularizer=self.kernel_regularizer,\n        )\n        self.lora_kernel_b = self.add_weight(\n            name=\"lora_kernel_b\",\n            shape=(rank, self.kernel.shape[1]),\n            initializer=initializers.get(b_initializer),\n            regularizer=self.kernel_regularizer,\n        )\n        self._kernel.trainable = False\n        self._tracker.lock()\n        self.lora_enabled = True\n        self.lora_rank = rank\n        self.lora_alpha = lora_alpha if lora_alpha is not None else rank\n\n    def save_own_variables(self, store):\n        # Do nothing if the layer isn't yet built\n        if not self.built:\n            return\n        mode = self.quantization_mode\n        if mode not in self.variable_serialization_spec:\n            raise self._quantization_mode_error(mode)\n\n        # Kernel plus optional merged LoRA-aware scale/zero (returns\n        # (kernel, None, None) for None/gptq/awq)\n        kernel_value, merged_kernel_scale, merged_kernel_zero = (\n            self._get_kernel_with_merged_lora()\n        )\n        idx = 0\n        for name in self.variable_serialization_spec[mode]:\n            if name == \"kernel\":\n                store[str(idx)] = kernel_value\n            elif name == \"bias\" and self.bias is None:\n                continue\n            elif name == \"kernel_zero\":\n                if merged_kernel_zero is None:\n                    # kernel_zero only exists for sub-channel int4 
quantization\n                    continue\n                store[str(idx)] = merged_kernel_zero\n            elif name == \"g_idx\":\n                if not hasattr(self, \"g_idx\"):\n                    # g_idx only exists for sub-channel int4 quantization\n                    continue\n                store[str(idx)] = self.g_idx\n            elif name == \"kernel_scale\" and mode in (\"int4\", \"int8\"):\n                # For int4/int8, the merged LoRA scale (if any) comes from\n                # `_get_kernel_with_merged_lora()`\n                store[str(idx)] = merged_kernel_scale\n            else:\n                store[str(idx)] = getattr(self, name)\n            idx += 1\n\n    def load_own_variables(self, store):\n        if not self.lora_enabled:\n            self._check_load_own_variables(store)\n        # Do nothing if the layer isn't yet built\n        if not self.built:\n            return\n        mode = self.quantization_mode\n        if mode not in self.variable_serialization_spec:\n            raise self._quantization_mode_error(mode)\n\n        # A saved GPTQ/AWQ quantized model will always be calibrated.\n        self.is_gptq_calibrated = mode == \"gptq\"\n        self.is_awq_calibrated = mode == \"awq\"\n\n        idx = 0\n        for name in self.variable_serialization_spec[mode]:\n            if name == \"kernel\":\n                self._kernel.assign(store[str(idx)])\n            elif name == \"bias\" and self.bias is None:\n                continue\n            elif name == \"kernel_zero\" and not hasattr(self, \"kernel_zero\"):\n                # kernel_zero only exists for sub-channel int4 quantization\n                continue\n            elif name == \"g_idx\" and not hasattr(self, \"g_idx\"):\n                # g_idx only exists for sub-channel int4 quantization\n                continue\n            else:\n                getattr(self, name).assign(store[str(idx)])\n            idx += 1\n        if self.lora_enabled:\n            
self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))\n            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"units\": self.units,\n            \"activation\": activations.serialize(self.activation),\n            \"use_bias\": self.use_bias,\n            \"kernel_initializer\": initializers.serialize(\n                self.kernel_initializer\n            ),\n            \"bias_initializer\": initializers.serialize(self.bias_initializer),\n            \"kernel_regularizer\": regularizers.serialize(\n                self.kernel_regularizer\n            ),\n            \"bias_regularizer\": regularizers.serialize(self.bias_regularizer),\n            \"kernel_constraint\": constraints.serialize(self.kernel_constraint),\n            \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            \"quantization_config\": serialization_lib.serialize_keras_object(\n                self.quantization_config\n            ),\n        }\n        if self.lora_rank:\n            config[\"lora_rank\"] = self.lora_rank\n            config[\"lora_alpha\"] = self.lora_alpha\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config):\n        config = config.copy()\n        config[\"quantization_config\"] = (\n            serialization_lib.deserialize_keras_object(\n                config.get(\"quantization_config\", None)\n            )\n        )\n        return super().from_config(config)\n\n    @property\n    def variable_serialization_spec(self):\n        \"\"\"Returns a dict mapping quantization modes to variable names in order.\n\n        This spec is used by `save_own_variables` and `load_own_variables` to\n        determine the correct ordering of variables during serialization for\n        each quantization mode. 
`None` means no quantization.\n        \"\"\"\n        return {\n            None: [\n                \"kernel\",\n                \"bias\",\n            ],\n            \"int8\": [\n                \"kernel\",\n                \"bias\",\n                \"kernel_scale\",\n            ],\n            \"int4\": [\n                \"kernel\",\n                \"bias\",\n                \"kernel_scale\",\n                \"kernel_zero\",\n                \"g_idx\",\n            ],\n            \"float8\": [\n                \"kernel\",\n                \"bias\",\n                \"inputs_scale\",\n                \"inputs_amax_history\",\n                \"kernel_scale\",\n                \"kernel_amax_history\",\n                \"outputs_grad_scale\",\n                \"outputs_grad_amax_history\",\n            ],\n            \"gptq\": [\n                \"bias\",\n                \"quantized_kernel\",\n                \"kernel_scale\",\n                \"kernel_zero\",\n                \"g_idx\",\n            ],\n            \"awq\": [\n                \"bias\",\n                \"quantized_kernel\",\n                \"kernel_scale\",\n                \"kernel_zero\",\n                \"awq_scales\",\n                \"g_idx\",\n            ],\n        }\n\n    def quantized_build(self, kernel_shape, mode, config=None):\n        if mode == \"int8\":\n            self._int8_build(kernel_shape, config)\n        elif mode == \"int4\":\n            self._int4_build(kernel_shape, config)\n        elif mode == \"float8\":\n            self._float8_build()\n        elif mode == \"gptq\":\n            self._gptq_build(kernel_shape, config)\n        elif mode == \"awq\":\n            self._awq_build(kernel_shape, config)\n        else:\n            raise self._quantization_mode_error(mode)\n        self._is_quantized = True\n\n    def _int8_build(self, kernel_shape, config=None):\n        self.inputs_quantizer = (\n            
QuantizationConfig.activation_quantizer_or_default(\n                config, quantizers.AbsMaxQuantizer()\n            )\n        )\n\n        self._kernel = self.add_weight(\n            name=\"kernel\",\n            shape=kernel_shape,\n            initializer=\"zeros\",\n            dtype=\"int8\",\n            trainable=False,\n        )\n        self.kernel_scale = self.add_weight(\n            name=\"kernel_scale\",\n            shape=(self.units,),\n            initializer=\"ones\",\n            trainable=False,\n        )\n\n    def _gptq_build(self, kernel_shape, config):\n        from keras.src.quantizers import gptq_core\n\n        # Ensures the forward pass uses the original high-precision kernel\n        # until calibration has been performed.\n        self.is_gptq_calibrated = False\n        self.kernel_shape = kernel_shape\n\n        weight_bits = gptq_core.get_weight_bits_for_layer(self, config)\n        # For 4-bit weights, we pack two values per byte.\n        units = (\n            (kernel_shape[1] + 1) // 2 if weight_bits == 4 else kernel_shape[1]\n        )\n\n        self.quantized_kernel = self.add_weight(\n            name=\"kernel\",\n            shape=(units, kernel_shape[0]),\n            initializer=\"zeros\",\n            dtype=\"uint8\",\n            trainable=False,\n        )\n\n        group_size = gptq_core.get_group_size_for_layer(self, config)\n        n_groups = (\n            1\n            if group_size == -1\n            else math.ceil(self.kernel_shape[0] / group_size)\n        )\n        self.kernel_scale = self.add_weight(\n            name=\"kernel_scale\",\n            shape=(self.units, n_groups),\n            initializer=\"ones\",\n            trainable=False,\n        )\n        self.kernel_zero = self.add_weight(\n            name=\"kernel_zero\",\n            shape=(self.units, n_groups),\n            initializer=\"zeros\",\n            dtype=\"uint8\",\n            trainable=False,\n        )\n        self.g_idx = 
self.add_weight(\n            name=\"g_idx\",\n            shape=(self.kernel_shape[0],),\n            initializer=\"zeros\",\n            dtype=\"float32\",\n            trainable=False,\n        )\n\n    def _gptq_call(self, inputs, training=False):\n        from keras.src.quantizers import gptq_core\n\n        if not self.is_gptq_calibrated:\n            W = self._kernel\n        else:\n            should_unpack = (\n                gptq_core.get_weight_bits_for_layer(self, config=None) == 4\n            )\n            W = (\n                quantizers.unpack_int4(\n                    self.quantized_kernel,\n                    orig_len=self.units,\n                    axis=0,\n                    dtype=\"uint8\",\n                )\n                if should_unpack\n                else self.quantized_kernel\n            )\n            W = ops.transpose(\n                dequantize_with_sz_map(\n                    W,\n                    self.kernel_scale,\n                    self.kernel_zero,\n                    self.g_idx,\n                )\n            )\n\n        y = ops.matmul(inputs, W)\n        if self.bias is not None:\n            y = ops.add(y, self.bias)\n        if self.activation is not None:\n            y = self.activation(y)\n        return y\n\n    def _awq_build(self, kernel_shape, config):\n        \"\"\"Build variables for AWQ quantization.\n\n        AWQ uses 4-bit quantization with per-channel AWQ scales that protect\n        salient weights based on activation magnitudes.\n        \"\"\"\n        from keras.src.quantizers import awq_core\n\n        # Ensures the forward pass uses the original high-precision kernel\n        # until calibration has been performed.\n        self.is_awq_calibrated = False\n        self.kernel_shape = kernel_shape\n\n        # For 4-bit weights, we pack two values per byte.\n        units = (kernel_shape[1] + 1) // 2\n\n        self.quantized_kernel = self.add_weight(\n            name=\"kernel\",\n      
      shape=(units, kernel_shape[0]),\n            initializer=\"zeros\",\n            dtype=\"uint8\",\n            trainable=False,\n        )\n\n        group_size = awq_core.get_group_size_for_layer(self, config)\n        num_groups = (\n            1 if group_size == -1 else math.ceil(kernel_shape[0] / group_size)\n        )\n        self.kernel_scale = self.add_weight(\n            name=\"kernel_scale\",\n            shape=(self.units, num_groups),\n            initializer=\"ones\",\n            trainable=False,\n        )\n        self.kernel_zero = self.add_weight(\n            name=\"kernel_zero\",\n            shape=(self.units, num_groups),\n            initializer=\"zeros\",\n            dtype=\"uint8\",\n            trainable=False,\n        )\n\n        # Per-channel AWQ scales from activation magnitudes\n        self.awq_scales = self.add_weight(\n            name=\"awq_scales\",\n            shape=(kernel_shape[0],),\n            initializer=\"ones\",\n            trainable=False,\n        )\n        self.g_idx = self.add_weight(\n            name=\"g_idx\",\n            shape=(kernel_shape[0],),\n            initializer=\"zeros\",\n            dtype=\"float32\",\n            trainable=False,\n        )\n\n    def _awq_call(self, inputs, training=False):\n        \"\"\"Forward pass for AWQ quantized layer.\"\"\"\n        if not self.is_awq_calibrated:\n            W = self._kernel\n        else:\n            # Unpack 4-bit weights\n            W = quantizers.unpack_int4(\n                self.quantized_kernel,\n                orig_len=self.units,\n                axis=0,\n                dtype=\"uint8\",\n            )\n            # Dequantize using scale/zero maps\n            W = ops.transpose(\n                dequantize_with_sz_map(\n                    W,\n                    self.kernel_scale,\n                    self.kernel_zero,\n                    self.g_idx,\n                )\n            )\n            # Apply AWQ scales by dividing 
to restore original magnitude\n            # (We multiplied by scales before quantization, so divide to undo)\n            # awq_scales has shape [input_dim], W has shape [input_dim, units]\n            # Expand dims for proper broadcasting.\n            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))\n\n        y = ops.matmul(inputs, W)\n        if self.bias is not None:\n            y = ops.add(y, self.bias)\n        if self.activation is not None:\n            y = self.activation(y)\n        return y\n\n    def _int4_build(self, kernel_shape, config=None):\n        \"\"\"Build variables for int4 quantization.\n\n        The kernel is packed along the last axis,\n        resulting in shape `(input_dim, ceil(units/2))`.\n\n        Args:\n            kernel_shape: The original float32 kernel shape\n                `(input_dim, units)`.\n            config: Optional quantization config specifying block_size.\n        \"\"\"\n        self.inputs_quantizer = (\n            QuantizationConfig.activation_quantizer_or_default(config, None)\n        )\n        input_dim, output_dim = kernel_shape\n\n        # kernel is packed along last axis (output dimension)\n        # Stored shape: [input_dim, ceil(output_dim/2)]\n        packed_cols = (output_dim + 1) // 2\n\n        self._kernel = self.add_weight(\n            name=\"kernel\",\n            shape=(input_dim, packed_cols),\n            initializer=\"zeros\",\n            dtype=\"int8\",\n            trainable=False,\n        )\n\n        block_size = get_block_size_for_layer(self, config)\n        self._int4_block_size = block_size\n\n        if block_size is None or block_size == -1:\n            # Per-channel: one scale per output unit\n            scale_shape = (self.units,)\n        else:\n            # Sub-channel: [n_groups, out_features]\n            n_groups = math.ceil(input_dim / block_size)\n            scale_shape = (n_groups, self.units)\n\n        self.kernel_scale = self.add_weight(\n            
name=\"kernel_scale\",\n            shape=scale_shape,\n            initializer=\"ones\",\n            trainable=False,\n        )\n\n        # Sub-channel quantization uses asymmetric quantization\n        if block_size is not None and block_size > 0:\n\n            def idx_initializer(shape, dtype):\n                return ops.floor_divide(\n                    ops.arange(input_dim, dtype=dtype), block_size\n                )\n\n            self.kernel_zero = self.add_weight(\n                name=\"kernel_zero\",\n                shape=scale_shape,\n                initializer=\"zeros\",\n                dtype=\"int8\",\n                trainable=False,\n            )\n            self.g_idx = self.add_weight(\n                name=\"g_idx\",\n                shape=(input_dim,),\n                initializer=idx_initializer,\n                dtype=\"float32\",\n                trainable=False,\n            )\n\n        # Record dimensions for unpacking and reshaping at runtime.\n        self._orig_input_dim = input_dim\n        self._orig_output_dim = output_dim\n\n    def _float8_build(self):\n        from keras.src.dtype_policies import QuantizedFloat8DTypePolicy\n\n        # If `self.dtype_policy` is not QuantizedFloat8DTypePolicy, then set\n        # `amax_history_length` to its default value.\n        amax_history_length = getattr(\n            self.dtype_policy,\n            \"amax_history_length\",\n            QuantizedFloat8DTypePolicy.default_amax_history_length,\n        )\n        # We set `trainable=True` because we will use the gradients to overwrite\n        # these variables\n        scale_kwargs = {\n            \"shape\": (),\n            \"initializer\": \"ones\",\n            \"dtype\": \"float32\",  # Always be float32\n            \"trainable\": True,\n            \"autocast\": False,\n            \"overwrite_with_gradient\": True,\n        }\n        amax_history_kwargs = {\n            \"shape\": (amax_history_length,),\n            
\"initializer\": \"zeros\",\n            \"dtype\": \"float32\",  # Always be float32\n            \"trainable\": True,\n            \"autocast\": False,\n            \"overwrite_with_gradient\": True,\n        }\n        self.inputs_scale = self.add_weight(name=\"inputs_scale\", **scale_kwargs)\n        self.inputs_amax_history = self.add_weight(\n            name=\"inputs_amax_history\", **amax_history_kwargs\n        )\n        self.kernel_scale = self.add_weight(name=\"kernel_scale\", **scale_kwargs)\n        self.kernel_amax_history = self.add_weight(\n            name=\"kernel_amax_history\", **amax_history_kwargs\n        )\n        self.outputs_grad_scale = self.add_weight(\n            name=\"outputs_grad_scale\", **scale_kwargs\n        )\n        self.outputs_grad_amax_history = self.add_weight(\n            name=\"outputs_grad_amax_history\", **amax_history_kwargs\n        )\n\n    def _int8_call(self, inputs, training=None):\n        @ops.custom_gradient\n        def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):\n            \"\"\"Custom gradient function to handle the int8 quantized weights.\n\n            Automatic differentiation will not know how to handle the int8\n            quantized weights. 
So a custom gradient function is needed to\n            handle the int8 quantized weights.\n\n            The custom gradient function will use the dequantized kernel to\n            compute the gradient.\n            \"\"\"\n\n            def grad_fn(*args, upstream=None):\n                if upstream is None:\n                    (upstream,) = args\n                float_kernel = ops.divide(\n                    ops.cast(kernel, dtype=self.compute_dtype),\n                    kernel_scale,\n                )\n                inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))\n                return (inputs_grad, None, None)\n\n            output_scale = kernel_scale\n            if self.inputs_quantizer:\n                inputs, inputs_scale = self.inputs_quantizer(inputs, axis=-1)\n                output_scale = ops.multiply(output_scale, inputs_scale)\n\n            x = ops.matmul(inputs, kernel)\n            # De-scale outputs\n            x = ops.cast(x, self.compute_dtype)\n            x = ops.divide(x, output_scale)\n            return x, grad_fn\n\n        x = matmul_with_inputs_gradient(\n            inputs,\n            ops.convert_to_tensor(self._kernel),\n            ops.convert_to_tensor(self.kernel_scale),\n        )\n        if self.lora_enabled:\n            lora_x = ops.matmul(inputs, self.lora_kernel_a)\n            lora_x = ops.matmul(lora_x, self.lora_kernel_b)\n            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)\n        if self.bias is not None:\n            x = ops.add(x, self.bias)\n        if self.activation is not None:\n            x = self.activation(x)\n        return x\n\n    def _int4_call(self, inputs, training=None):\n        \"\"\"Forward pass for int4 quantized Dense layer.\n\n        Uses custom gradients to handle quantized weights since autodiff\n        cannot differentiate through int4 operations.\n        \"\"\"\n        block_size = getattr(self, \"_int4_block_size\", None)\n\n        if 
block_size is None or block_size == -1:\n            # Per-channel: symmetric quantization (no zero point needed)\n            @ops.custom_gradient\n            def matmul_per_channel_with_inputs_gradient(\n                inputs, kernel, kernel_scale\n            ):\n                \"\"\"Per-channel int4 forward pass with custom gradient.\"\"\"\n                # Unpack: stored as [in, ceil(out/2)], unpack along last axis\n                unpacked_kernel = quantizers.unpack_int4(\n                    kernel, self._orig_output_dim, axis=-1\n                )\n\n                def grad_fn(*args, upstream=None):\n                    if upstream is None:\n                        (upstream,) = args\n                    # Per-channel: unpacked is [in, out]\n                    float_kernel = ops.divide(\n                        ops.cast(unpacked_kernel, dtype=self.compute_dtype),\n                        kernel_scale,\n                    )\n                    inputs_grad = ops.matmul(\n                        upstream, ops.transpose(float_kernel)\n                    )\n                    return (inputs_grad, None, None)\n\n                # Forward pass: per-channel dequantization\n                output_scale = kernel_scale\n                if self.inputs_quantizer:\n                    inputs, inputs_scale = self.inputs_quantizer(\n                        inputs, axis=-1\n                    )\n                    output_scale = ops.multiply(output_scale, inputs_scale)\n\n                x = ops.matmul(inputs, unpacked_kernel)\n                x = ops.cast(x, self.compute_dtype)\n                x = ops.divide(x, output_scale)\n                return x, grad_fn\n\n            x = matmul_per_channel_with_inputs_gradient(\n                inputs,\n                ops.convert_to_tensor(self._kernel),\n                ops.convert_to_tensor(self.kernel_scale),\n            )\n        else:\n            # Sub-channel: asymmetric quantization (with zero point)\n        
    @ops.custom_gradient\n            def matmul_sub_channel_with_inputs_gradient(\n                inputs, kernel, kernel_scale, kernel_zero, g_idx\n            ):\n                \"\"\"Sub-channel int4 forward pass with custom gradient.\"\"\"\n                # Unpack: stored as [in, ceil(out/2)], unpack along last axis\n                unpacked_kernel = quantizers.unpack_int4(\n                    kernel, self._orig_output_dim, axis=-1\n                )\n\n                def grad_fn(*args, upstream=None):\n                    if upstream is None:\n                        (upstream,) = args\n                    float_kernel = dequantize_with_sz_map(\n                        unpacked_kernel,\n                        kernel_scale,\n                        kernel_zero,\n                        g_idx,\n                        group_axis=0,\n                    )\n                    float_kernel = ops.cast(float_kernel, self.compute_dtype)\n                    inputs_grad = ops.matmul(\n                        upstream, ops.transpose(float_kernel)\n                    )\n                    return (inputs_grad, None, None, None, None)\n\n                float_kernel = dequantize_with_sz_map(\n                    unpacked_kernel,\n                    kernel_scale,\n                    kernel_zero,\n                    g_idx,\n                    group_axis=0,\n                )\n                float_kernel = ops.cast(float_kernel, self.compute_dtype)\n                x = ops.matmul(inputs, float_kernel)\n                return x, grad_fn\n\n            x = matmul_sub_channel_with_inputs_gradient(\n                inputs,\n                ops.convert_to_tensor(self._kernel),\n                ops.convert_to_tensor(self.kernel_scale),\n                ops.convert_to_tensor(self.kernel_zero),\n                ops.convert_to_tensor(self.g_idx),\n            )\n\n        if self.lora_enabled:\n            lora_x = ops.matmul(inputs, self.lora_kernel_a)\n            
lora_x = ops.matmul(lora_x, self.lora_kernel_b)\n            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)\n\n        if self.bias is not None:\n            x = ops.add(x, self.bias)\n        if self.activation is not None:\n            x = self.activation(x)\n        return x\n\n    def _float8_call(self, inputs, training=None):\n        if self.lora_enabled:\n            raise NotImplementedError(\n                \"Currently, `_float8_call` doesn't support LoRA\"\n            )\n\n        @ops.custom_gradient\n        def quantized_dequantize_inputs(inputs, scale, amax_history):\n            if training:\n                new_scale = quantizers.compute_float8_scale(\n                    ops.max(amax_history, axis=0),\n                    scale,\n                    ops.cast(\n                        float(ml_dtypes.finfo(\"float8_e4m3fn\").max), \"float32\"\n                    ),\n                )\n                new_amax_history = quantizers.compute_float8_amax_history(\n                    inputs, amax_history\n                )\n            else:\n                new_scale = None\n                new_amax_history = None\n            qdq_inputs = quantizers.quantize_and_dequantize(\n                inputs, scale, \"float8_e4m3fn\", self.compute_dtype\n            )\n\n            def grad(*args, upstream=None, variables=None):\n                if upstream is None:\n                    (upstream,) = args\n                return upstream, new_scale, new_amax_history\n\n            return qdq_inputs, grad\n\n        @ops.custom_gradient\n        def quantized_dequantize_outputs(outputs, scale, amax_history):\n            \"\"\"Quantize-dequantize the output gradient but not the output.\"\"\"\n\n            def grad(*args, upstream=None, variables=None):\n                if upstream is None:\n                    (upstream,) = args\n                new_scale = quantizers.compute_float8_scale(\n                    ops.max(amax_history, axis=0),\n     
               scale,\n                    ops.cast(\n                        float(ml_dtypes.finfo(\"float8_e5m2\").max), \"float32\"\n                    ),\n                )\n                qdq_upstream = quantizers.quantize_and_dequantize(\n                    upstream, scale, \"float8_e5m2\", self.compute_dtype\n                )\n                new_amax_history = quantizers.compute_float8_amax_history(\n                    upstream, amax_history\n                )\n                return qdq_upstream, new_scale, new_amax_history\n\n            return outputs, grad\n\n        x = ops.matmul(\n            quantized_dequantize_inputs(\n                inputs,\n                ops.convert_to_tensor(self.inputs_scale),\n                ops.convert_to_tensor(self.inputs_amax_history),\n            ),\n            quantized_dequantize_inputs(\n                ops.convert_to_tensor(self._kernel),\n                ops.convert_to_tensor(self.kernel_scale),\n                ops.convert_to_tensor(self.kernel_amax_history),\n            ),\n        )\n        # `quantized_dequantize_outputs` is placed immediately after\n        # `ops.matmul` for the sake of pattern matching in gemm_rewrite. That\n        # way, the qdq will be adjacent to the corresponding matmul_bprop in the\n        # bprop.\n        x = quantized_dequantize_outputs(\n            x,\n            ops.convert_to_tensor(self.outputs_grad_scale),\n            ops.convert_to_tensor(self.outputs_grad_amax_history),\n        )\n        if self.bias is not None:\n            # Under non-mixed precision cases, F32 bias has to be converted to\n            # BF16 first to get the biasAdd fusion support. ref. 
PR\n            # https://github.com/tensorflow/tensorflow/pull/60306\n            bias = self.bias\n            if self.dtype_policy.compute_dtype == \"float32\":\n                bias_bf16 = ops.cast(bias, \"bfloat16\")\n                bias = ops.cast(bias_bf16, bias.dtype)\n            x = ops.add(x, bias)\n        if self.activation is not None:\n            x = self.activation(x)\n        return x\n\n    def quantize(self, mode=None, type_check=True, config=None):\n        # Prevent quantization of the subclasses\n        if type_check and (type(self) is not Dense):\n            raise self._not_implemented_error(self.quantize)\n\n        self.quantization_config = config\n\n        kernel_shape = self._kernel.shape\n        if mode == \"int8\":\n            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(\n                self.quantization_config, quantizers.AbsMaxQuantizer(axis=0)\n            )\n            kernel_value, kernel_scale = weight_quantizer(\n                self._kernel, to_numpy=True\n            )\n            kernel_scale = ops.squeeze(kernel_scale, axis=0)\n            del self._kernel\n            # Build variables for int8 mode\n            self.quantized_build(kernel_shape, mode, self.quantization_config)\n            self._kernel.assign(kernel_value)\n            self.kernel_scale.assign(kernel_scale)\n        elif mode == \"int4\":\n            from keras.src.quantizers.quantization_config import (\n                Int4QuantizationConfig,\n            )\n\n            block_size = None\n            if isinstance(self.quantization_config, Int4QuantizationConfig):\n                block_size = self.quantization_config.block_size\n\n            if block_size is None or block_size == -1:\n                # Per-channel quantization\n                weight_quantizer = (\n                    QuantizationConfig.weight_quantizer_or_default(\n                        self.quantization_config,\n                        
 quantizers.AbsMaxQuantizer(\n                            axis=0, value_range=(-8, 7), output_dtype=\"int8\"\n                        ),\n                    )\n                )\n                kernel_value_int4, kernel_scale = weight_quantizer(\n                    self._kernel, to_numpy=True\n                )\n                kernel_scale = ops.squeeze(kernel_scale, axis=0)\n            else:\n                # Sub-channel quantization with asymmetric zero point\n                # Returns kernel [in, out], scale [n_groups, out], zero\n                # [n_groups, out]\n                kernel_value_int4, kernel_scale, kernel_zero = (\n                    quantizers.abs_max_quantize_grouped_with_zero_point(\n                        self._kernel, block_size=block_size, to_numpy=True\n                    )\n                )\n\n            # Pack two int4 values per int8 byte along last axis\n            # Stored as [in, ceil(out/2)]\n            packed_kernel_value, _, _ = quantizers.pack_int4(\n                kernel_value_int4, axis=-1\n            )\n            del self._kernel\n            self.quantized_build(kernel_shape, mode, self.quantization_config)\n            self._kernel.assign(packed_kernel_value)\n            self.kernel_scale.assign(kernel_scale)\n            if block_size is not None and block_size > 0:\n                self.kernel_zero.assign(kernel_zero)\n        elif mode == \"gptq\":\n            self.quantized_build(kernel_shape, mode, self.quantization_config)\n        elif mode == \"awq\":\n            self.quantized_build(kernel_shape, mode, self.quantization_config)\n        elif mode == \"float8\":\n            self.quantized_build(kernel_shape, mode)\n        else:\n            raise self._quantization_mode_error(mode)\n\n        # Set a new quantized dtype policy only if the layer does not\n        # already have one.\n        if self.dtype_policy.quantization_mode is None:\n            from keras.src import dtype_policies  # local import to avoid cycle\n\n  
          policy_name = mode\n            if mode in (\"gptq\", \"awq\"):\n                policy_name = self.quantization_config.dtype_policy_string()\n            elif mode == \"int4\":\n                # Include block_size in policy name for sub-channel quantization\n                block_size = get_block_size_for_layer(self, config)\n                # Use -1 for per-channel, otherwise use block_size\n                block_size_value = -1 if block_size is None else block_size\n                policy_name = f\"int4/{block_size_value}\"\n            policy = dtype_policies.get(\n                f\"{policy_name}_from_{self.dtype_policy.name}\"\n            )\n            self.dtype_policy = policy\n\n    def _get_kernel_with_merged_lora(self):\n        \"\"\"Returns the kernel with LoRA matrices merged, for serialization.\n\n        This method is called by `save_own_variables` to produce a single\n        kernel tensor that includes the adaptations from LoRA. This is useful\n        for deploying the model or for continuing training after permanently\n        applying the LoRA update.\n\n        If the layer is quantized (`int8` or `int4`), the process is:\n        1. Dequantize the base kernel to float.\n        2. Compute the LoRA delta (`lora_kernel_a @ lora_kernel_b`) and add\n            it to the dequantized kernel.\n        3. Re-quantize the merged result back to the original quantized\n            type (`int8` or packed `int4`), calculating a new scale factor.\n\n        If the layer is not quantized, this method returns the result of the\n        `kernel` property (which computes the merge in floating-point) and a\n        scale of `None`.\n\n        If LoRA is not enabled, it returns the original kernel and scale\n        without modification.\n\n        Returns:\n            A tuple `(kernel_value, kernel_scale, kernel_zero)`:\n                `kernel_value`: The merged kernel. 
A quantized tensor if\n                    quantization is active, otherwise a high precision tensor.\n                `kernel_scale`: The quantization scale for the merged kernel.\n                    This is `None` if the layer is not quantized.\n                `kernel_zero`: The zero point for sub-channel int4 quantization.\n                    This is `None` for per-channel or non-int4 modes.\n        \"\"\"\n        if self.dtype_policy.quantization_mode in (None, \"gptq\", \"awq\"):\n            return self.kernel, None, None\n\n        kernel_value = self._kernel\n        kernel_scale = self.kernel_scale\n        kernel_zero = getattr(self, \"kernel_zero\", None)\n\n        if not self.lora_enabled:\n            return kernel_value, kernel_scale, kernel_zero\n\n        # Dequantize, Merge, and Re-quantize\n        block_size = getattr(self, \"_int4_block_size\", None)\n\n        # Step 1: Dequantize kernel to float\n        if self.quantization_mode == \"int4\":\n            # Unpack along last axis ([in, out])\n            unpacked_kernel = quantizers.unpack_int4(\n                kernel_value, self._orig_output_dim, axis=-1\n            )\n            if block_size is None or block_size == -1:\n                # Per-channel: kernel [in, out], scale [out]\n                float_kernel = ops.divide(\n                    ops.cast(unpacked_kernel, self.compute_dtype),\n                    kernel_scale,\n                )\n            else:\n                # Sub-channel: scale/zero are [n_groups, out]\n                float_kernel = dequantize_with_sz_map(\n                    unpacked_kernel,\n                    kernel_scale,\n                    self.kernel_zero,\n                    self.g_idx,\n                    group_axis=0,\n                )\n                float_kernel = ops.cast(float_kernel, self.compute_dtype)\n            quant_range = (-8, 7)\n        elif self.quantization_mode == \"int8\":\n            float_kernel = ops.divide(\n           
     ops.cast(kernel_value, self.compute_dtype), kernel_scale\n            )\n            quant_range = (-127, 127)\n        else:\n            raise ValueError(\n                f\"Unsupported quantization mode: {self.quantization_mode}\"\n            )\n\n        # Step 2: Merge LoRA weights in float domain\n        lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(\n            self.lora_kernel_a, self.lora_kernel_b\n        )\n        merged_float_kernel = ops.add(float_kernel, lora_delta)\n\n        # Step 3: Re-quantize the merged kernel\n        if (\n            self.quantization_mode == \"int4\"\n            and block_size is not None\n            and block_size != -1\n        ):\n            # Sub-channel: returns kernel [in, out], scale [n_groups, out]\n            requantized_kernel, kernel_scale, kernel_zero = (\n                quantizers.abs_max_quantize_grouped_with_zero_point(\n                    merged_float_kernel, block_size=block_size, to_numpy=True\n                )\n            )\n        else:\n            # Per-channel int4 or int8: quantize along the input axis\n            # (axis=0) using the `quant_range` selected above.\n            requantized_kernel, kernel_scale = quantizers.abs_max_quantize(\n                merged_float_kernel,\n                axis=0,\n                value_range=quant_range,\n                dtype=\"int8\",\n                to_numpy=True,\n            )\n            kernel_scale = ops.squeeze(kernel_scale, axis=0)\n            kernel_zero = None\n\n        if self.quantization_mode == \"int4\":\n            # Pack along last axis\n            kernel_value, _, _ = 
quantizers.pack_int4(\n                requantized_kernel, axis=-1\n            )\n        else:\n            kernel_value = requantized_kernel\n        return kernel_value, kernel_scale, kernel_zero\n"
  },
  {
    "path": "keras/src/layers/core/dense_test.py",
    "content": "import math\nimport os\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import export\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src import quantizers\nfrom keras.src import random\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.backend.common import keras_tensor\nfrom keras.src.quantizers.awq_config import AWQConfig\nfrom keras.src.quantizers.gptq_config import GPTQConfig\nfrom keras.src.quantizers.quantization_config import Int4QuantizationConfig\nfrom keras.src.quantizers.quantization_config import Int8QuantizationConfig\nfrom keras.src.quantizers.quantizers import AbsMaxQuantizer\n\n\nclass DenseTest(testing.TestCase):\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\", {\"axis\": 0}, {}),\n        (\n            \"int4\",\n            \"int4\",\n            {\"axis\": 0, \"value_range\": (-8, 7), \"output_dtype\": \"int8\"},\n            {\"axis\": -1},\n        ),\n        (\"int8_weight_only\", \"int8\", {\"axis\": 0}, None),\n    )\n    def test_dense_quantize_config(\n        self, mode, weight_quantizer_args, activation_quantizer_args\n    ):\n        \"\"\"Test Dense quantization with QuantizationConfig.\"\"\"\n        layer = layers.Dense(units=32)\n        layer.build((None, 8))\n\n        weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args)\n        if activation_quantizer_args is not None:\n            activation_quantizer = AbsMaxQuantizer(**activation_quantizer_args)\n        else:\n            activation_quantizer = None\n\n        if mode == \"int8\":\n            config = Int8QuantizationConfig(\n                weight_quantizer=weight_quantizer,\n                activation_quantizer=activation_quantizer,\n            )\n        elif mode == \"int4\":\n            # Custom 
quantizers require per-channel mode (block_size=None)\n            config = Int4QuantizationConfig(\n                weight_quantizer=weight_quantizer,\n                activation_quantizer=activation_quantizer,\n                block_size=None,\n            )\n\n        layer.quantize(mode, config=config)\n\n        if activation_quantizer_args is not None:\n            # Verify inputs_quantizer is set correctly\n            self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer)\n        else:\n            # Verify inputs_quantizer is None\n            self.assertIsNone(layer.inputs_quantizer)\n\n        # Verify call works\n        x = np.random.random((2, 8)).astype(\"float32\")\n        y = layer(x)\n        self.assertEqual(y.shape, (2, 32))\n\n        if mode == \"int4\":\n            # Verify kernel is int8 (packed int4)\n            self.assertEqual(\n                backend.standardize_dtype(layer._kernel.dtype), \"int8\"\n            )\n\n    @pytest.mark.requires_trainable_backend\n    def test_dense_basics(self):\n        # 2D case, no bias.\n        self.run_layer_test(\n            layers.Dense,\n            init_kwargs={\n                \"units\": 4,\n                \"activation\": \"relu\",\n                \"kernel_initializer\": \"random_uniform\",\n                \"bias_initializer\": \"ones\",\n                \"use_bias\": False,\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 4),\n            expected_num_trainable_weights=1,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n        # 3D case, some regularizers.\n        self.run_layer_test(\n            layers.Dense,\n            init_kwargs={\n                \"units\": 5,\n                \"activation\": \"sigmoid\",\n                \"kernel_regularizer\": \"l2\",\n                \"bias_regularizer\": \"l2\",\n 
           },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 5),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=2,  # we have 2 regularizers.\n            supports_masking=True,\n        )\n\n    @parameterized.named_parameters(\n        (\"zero\", 0),\n        (\"negative\", -3),\n        (\"float\", 2.5),\n        (\"none\", None),\n        (\"string\", \"64\"),\n    )\n    def test_dense_invalid_units_raises(self, units):\n        with self.assertRaisesRegex(ValueError, \"positive integer\"):\n            layers.Dense(units)\n\n    def test_dense_correctness(self):\n        # With bias and activation.\n        layer = layers.Dense(units=2, activation=\"relu\")\n        layer.build((1, 2))\n        layer.set_weights(\n            [\n                np.array([[1.0, -2.0], [3.0, -4.0]]),\n                np.array([5.0, -6.0]),\n            ]\n        )\n        inputs = np.array(\n            [[-1.0, 2.0]],\n        )\n        self.assertAllClose(layer(inputs), [[10.0, 0.0]])\n\n        # Just a kernel matmul.\n        layer = layers.Dense(units=2, use_bias=False)\n        layer.build((1, 2))\n        layer.set_weights(\n            [\n                np.array([[1.0, -2.0], [3.0, -4.0]]),\n            ]\n        )\n        inputs = np.array(\n            [[-1.0, 2.0]],\n        )\n        self.assertEqual(layer.bias, None)\n        self.assertAllClose(layer(inputs), [[5.0, -6.0]])\n\n    def test_dense_errors(self):\n        with self.assertRaisesRegex(ValueError, \"incompatible with the layer\"):\n            layer = layers.Dense(units=2, activation=\"relu\")\n            layer(keras_tensor.KerasTensor((1, 2)))\n            layer(keras_tensor.KerasTensor((1, 3)))\n\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=\"Backend does not support sparse tensors.\",\n    )\n   
 def test_dense_sparse(self):\n        self.run_layer_test(\n            layers.Dense,\n            init_kwargs={\n                \"units\": 4,\n            },\n            input_shape=(2, 3),\n            input_sparse=True,\n            expected_output_shape=(2, 4),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n        )\n\n        inputs = np.random.random((10, 10)).astype(\"float32\")\n        inputs = np.multiply(inputs, inputs >= 0.8)\n\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            inputs = tf.sparse.from_dense(inputs)\n        elif backend.backend() == \"jax\":\n            import jax.experimental.sparse as jax_sparse\n\n            inputs = jax_sparse.BCOO.fromdense(inputs)\n        else:\n            self.fail(f\"Sparse is unsupported with backend {backend.backend()}\")\n\n        layer = layers.Dense(units=10)\n        outputs = layer(inputs)\n\n        # Verify the computation is the same as if it had been a dense tensor\n        expected_outputs = ops.add(\n            ops.matmul(\n                backend.convert_to_tensor(inputs, sparse=False), layer.kernel\n            ),\n            layer.bias,\n        )\n        self.assertAllClose(\n            outputs, expected_outputs, tpu_atol=1e-2, tpu_rtol=1e-2\n        )\n\n        # Verify the gradient is sparse\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            with tf.GradientTape() as g:\n                outputs = layer(inputs)\n\n            self.assertIsInstance(\n                g.gradient(outputs, layer.kernel), tf.IndexedSlices\n            )\n\n    def test_dense_no_activation(self):\n        layer = 
layers.Dense(units=2, use_bias=False, activation=None)\n        layer.build((1, 2))\n        layer.set_weights(\n            [\n                np.array([[1.0, -2.0], [3.0, -4.0]]),\n            ]\n        )\n        inputs = np.array(\n            [[-1.0, 2.0]],\n        )\n        self.assertEqual(layer.bias, None)\n        self.assertAllClose(layer(inputs), [[5.0, -6.0]])\n\n    def test_dense_without_activation_set(self):\n        layer = layers.Dense(units=2, use_bias=False)\n        layer.build((1, 2))\n        layer.set_weights(\n            [\n                np.array([[1.0, -2.0], [3.0, -4.0]]),\n            ]\n        )\n        layer.activation = None\n        inputs = np.array(\n            [[-1.0, 2.0]],\n        )\n        self.assertEqual(layer.bias, None)\n        self.assertAllClose(layer(inputs), [[5.0, -6.0]])\n\n    def test_dense_with_activation(self):\n        layer = layers.Dense(units=2, use_bias=False, activation=\"relu\")\n        layer.build((1, 2))\n        layer.set_weights(\n            [\n                np.array([[1.0, -2.0], [3.0, -4.0]]),\n            ]\n        )\n\n        inputs = np.array(\n            [[-1.0, 2.0]],\n        )\n        output = layer(inputs)\n        expected_output = np.array([[5.0, 0.0]])\n        self.assertAllClose(output, expected_output)\n\n    def test_dense_constraints(self):\n        layer = layers.Dense(units=2, kernel_constraint=\"non_neg\")\n        layer.build((None, 2))\n        self.assertIsInstance(layer.kernel.constraint, constraints.NonNeg)\n        layer = layers.Dense(units=2, bias_constraint=\"non_neg\")\n        layer.build((None, 2))\n        self.assertIsInstance(layer.bias.constraint, constraints.NonNeg)\n\n    @pytest.mark.requires_trainable_backend\n    def test_enable_lora(self):\n        layer = layers.Dense(units=16)\n        layer.build((None, 8))\n        layer.enable_lora(4)\n        self.assertLen(layer.trainable_weights, 3)\n        self.assertLen(layer.non_trainable_weights, 
1)\n        if backend.backend() == \"torch\":\n            self.assertLen(layer.torch_params, 4)\n        # Try eager call\n        x = np.random.random((64, 8))\n        y = np.random.random((64, 16))\n        _ = layer(x[:2])\n\n        init_lora_a_kernel_value = layer.lora_kernel_a.numpy()\n        init_lora_b_kernel_value = layer.lora_kernel_b.numpy()\n\n        # Try calling fit()\n        model = models.Sequential(\n            [\n                layer,\n            ]\n        )\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        model.fit(x, y)\n\n        final_lora_a_kernel_value = layer.lora_kernel_a.numpy()\n        final_lora_b_kernel_value = layer.lora_kernel_b.numpy()\n        diff_a = np.max(\n            np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value)\n        )\n        diff_b = np.max(\n            np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value)\n        )\n        self.assertGreater(diff_a, 0.0)\n        self.assertGreater(diff_b, 0.0)\n\n        # Try saving and reloading the model\n        temp_filepath = os.path.join(self.get_temp_dir(), \"lora_model.keras\")\n        model.save(temp_filepath)\n\n        new_model = saving.load_model(temp_filepath)\n        self.assertTrue(new_model.layers[0].lora_enabled)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n        # Try saving and reloading the model's weights only\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"lora_model.weights.h5\"\n        )\n        model.save_weights(temp_filepath)\n\n        # Load the file into a fresh, non-lora model\n        new_model = models.Sequential(\n            [\n                layers.Dense(units=16),\n            ]\n        )\n        new_model.build((None, 8))\n        new_model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n        # Try loading a normal checkpoint into a lora model\n        
new_model.save_weights(temp_filepath)\n        model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n    @pytest.mark.requires_trainable_backend\n    def test_enable_lora_with_alpha(self):\n        # Create a `Dense` layer and build it.\n        layer = layers.Dense(units=8)\n        layer.build((None, 4))\n\n        # Enable LoRA with `rank`=2 and `lora_alpha`=3.0.\n        layer.enable_lora(2, lora_alpha=3.0)\n        self.assertEqual(layer.lora_rank, 2)\n        self.assertEqual(layer.lora_alpha, 3.0)\n\n        # Manually compute the expected effective kernel:\n        # `effective_kernel_expected` = `base_kernel` +\n        # `lora_alpha / lora_rank` * `lora_kernel_a @ lora_kernel_b`\n        base_kernel = ops.convert_to_numpy(layer._kernel)\n        lora_update = np.matmul(\n            ops.convert_to_numpy(layer.lora_kernel_a),\n            ops.convert_to_numpy(layer.lora_kernel_b),\n        )\n        effective_kernel_expected = base_kernel + (3.0 / 2) * lora_update\n\n        # Verify that the effective kernel matches expectation.\n        self.assertAllClose(\n            ops.convert_to_numpy(layer.kernel), effective_kernel_expected\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_lora_weight_name(self):\n        class MyModel(models.Model):\n            def __init__(self):\n                super().__init__(name=\"mymodel\")\n                self.dense = layers.Dense(16, name=\"dense\")\n\n            def build(self, input_shape):\n                self.dense.build(input_shape)\n\n            def call(self, x):\n                return self.dense(x)\n\n        model = MyModel()\n        model.build((None, 8))\n        model.dense.enable_lora(4)\n        self.assertEqual(\n            model.dense.lora_kernel_a.path, \"mymodel/dense/lora_kernel_a\"\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_lora_rank_argument(self):\n        self.run_layer_test(\n           
 layers.Dense,\n            init_kwargs={\n                \"units\": 5,\n                \"activation\": \"sigmoid\",\n                \"kernel_regularizer\": \"l2\",\n                \"lora_rank\": 2,\n            },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 5),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=1,\n            expected_num_seed_generators=0,\n            expected_num_losses=2,  # we have 2 regularizers.\n            supports_masking=True,\n        )\n\n    def test_enable_lora_with_kernel_constraint(self):\n        layer = layers.Dense(units=2, kernel_constraint=\"max_norm\")\n        with self.assertRaisesRegex(\n            ValueError, \"incompatible with kernel constraints\"\n        ):\n            layer.enable_lora(rank=2)\n\n    def test_enable_lora_on_unbuilt_layer(self):\n        layer = layers.Dense(units=2)\n        with self.assertRaisesRegex(\n            ValueError, \"Cannot enable lora on a layer that isn't yet built\"\n        ):\n            layer.enable_lora(rank=2)\n\n    def test_enable_lora_when_already_enabled(self):\n        layer = layers.Dense(units=2)\n        layer.build((None, 2))\n        layer.enable_lora(rank=2)\n        with self.assertRaisesRegex(ValueError, \"lora is already enabled\"):\n            layer.enable_lora(rank=2)\n\n    # Test quantization-related methods.\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\", 1e-3),\n        (\"int4\", \"int4\", 2e-3),\n    )\n    def test_quantize_int(self, mode, error_threshold):\n        if mode == \"int4\" and testing.tensorflow_uses_gpu():\n            self.skipTest(\"Segfault\")\n        layer = layers.Dense(units=16)\n        layer.build((None, 8))\n        x = np.random.random((2, 8))\n        y_float = layer(x)\n        layer.quantize(mode)\n\n        # Verify the dtype of the weights.\n        # The kernel's data type is int8, despite the int4 quantization, 
because\n        # we pack the int4 values into int8.\n        self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), \"int8\")\n        self.assertEqual(\n            backend.standardize_dtype(layer.kernel_scale.dtype),\n            layer.variable_dtype,\n        )\n\n        # Verify the correctness of the outputs.\n        y_quantized = layer(x)\n        mse = ops.mean(ops.square(y_float - y_quantized))\n        self.assertLess(mse, error_threshold)  # A weak correctness test\n\n        # Check model save / load round-trip.\n        model = models.Sequential([layer])\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"quantized_model.keras\"\n        )\n        model.save(temp_filepath)\n        new_model = saving.load_model(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n        # Check weights-only save / load round-trip.\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"quantized_model.weights.h5\"\n        )\n        model.save_weights(temp_filepath)\n        new_model = models.Sequential([layers.Dense(units=16)])\n        new_model.build((None, 8))\n        new_model.quantize(mode)\n        new_model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\"),\n        (\"int4\", \"int4\"),\n        (\"float8\", \"float8\"),\n    )\n    def test_quantize_on_unbuilt_layer(self, mode):\n        layer = layers.Dense(units=2)\n        with self.assertRaisesRegex(\n            ValueError, \"Cannot quantize a layer that isn't yet built.\"\n        ):\n            layer.quantize(mode)\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\"),\n        (\"int4\", \"int4\"),\n        (\"float8\", \"float8\"),\n    )\n    def test_quantize_on_subclass(self, mode):\n        class MyDense(layers.Dense):\n            pass\n\n        layer = 
MyDense(units=16)\n        layer.build((None, 8))\n        with self.assertRaises(NotImplementedError):\n            layer.quantize(mode)\n\n        layer.quantize(mode, type_check=False)  # No error\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\"),\n        (\"int4\", \"int4\"),\n        (\"float8\", \"float8\"),\n    )\n    def test_quantize_when_already_quantized(self, mode):\n        layer = layers.Dense(units=2)\n        layer.build((None, 2))\n        layer.quantize(mode)\n        for m in [\"int8\", \"int4\", \"float8\"]:\n            with self.assertRaisesRegex(\n                ValueError, \"is already quantized with dtype_policy=\"\n            ):\n                layer.quantize(m)\n\n        layer = layers.Dense(units=2, dtype=f\"{mode}_from_float32\")\n        layer.build((None, 2))\n        for m in [\"int8\", \"int4\", \"float8\"]:\n            with self.assertRaisesRegex(\n                ValueError, \"is already quantized with dtype_policy=\"\n            ):\n                layer.quantize(m)\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8_from_float32\", 3),\n        (\"int4\", \"int4_from_float32\", 5),  # bias + kernel + scale + zero + gidx\n        (\"float8\", \"float8_from_float32\", 8),\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_quantize_by_setting_dtype_policy(\n        self, policy, expected_num_variables\n    ):\n        layer = layers.Dense(units=2)\n        layer.build((None, 2))\n        layer.dtype_policy = policy\n        self.assertLen(layer.variables, expected_num_variables)\n\n    @parameterized.named_parameters(\n        (\"int7\", \"int7\"),\n        (\"float7\", \"float7\"),\n    )\n    def test_quantize_invalid_mode(self, mode):\n        layer = layers.Dense(units=2)\n        layer.build((None, 2))\n        x = np.random.random((1, 2))\n        # dtype_policy should not be altered by failed 
quantization\n        original_dtype_policy = layer.dtype_policy\n\n        # Test quantize\n        with self.assertRaisesRegex(ValueError, \"Invalid quantization mode.\"):\n            layer.quantize(mode)\n        self.assertEqual(layer.dtype_policy, original_dtype_policy)\n\n        # Test quantized_build\n        with self.assertRaisesRegex(\n            NotImplementedError, \"Invalid quantization mode.\"\n        ):\n            layer.quantized_build((None, 2), mode)\n        self.assertEqual(layer.dtype_policy, original_dtype_policy)\n\n        # Test quantized_call\n        with self.assertRaisesRegex(\n            NotImplementedError, \"Invalid quantization mode.\"\n        ):\n            # Explicitly set quantization_mode\n            layer._dtype_policy._quantization_mode = mode\n            layer.quantized_call(x)\n        self.assertEqual(layer.dtype_policy, original_dtype_policy)\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8_from_mixed_bfloat16\", 1, 2),\n        (\"int4\", \"int4_from_mixed_bfloat16\", 1, 2),\n        (\"float8\", \"float8_from_mixed_bfloat16\", 8, 0),\n    )\n    @pytest.mark.requires_trainable_backend\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_quantize_dtype_argument(\n        self, dtype, num_trainable_weights, num_non_trainable_weights\n    ):\n        self.run_layer_test(\n            layers.Dense,\n            init_kwargs={\"units\": 5, \"dtype\": dtype},\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 5),\n            expected_num_trainable_weights=num_trainable_weights,\n            expected_num_non_trainable_weights=num_non_trainable_weights,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\", 3, 2, 5),\n        (\"int4\", \"int4\", 3, 4, 7),  # +2 
non-trainable for zero and g_idx
    )
    @pytest.mark.requires_trainable_backend
    @pytest.mark.skipif(
        testing.tensorflow_uses_gpu(), reason="Segfault on Tensorflow GPU"
    )
    def test_quantize_lora_integration(
        self,
        mode,
        num_trainable_weights,
        num_non_trainable_weights,
        num_torch_params,
    ):
        # Note that saving and loading a model that is both LoRA-enabled
        # and quantized is lossy, so we use a weak correctness test for the
        # model outputs (atol=0.5).
        config = dict(units=16)
        layer = layers.Dense(**config)
        layer.build((None, 8))
        layer.enable_lora(4)
        layer.quantize(mode)
        self.assertLen(layer.trainable_weights, num_trainable_weights)
        self.assertLen(layer.non_trainable_weights, num_non_trainable_weights)
        if backend.backend() == "torch":
            self.assertLen(layer.torch_params, num_torch_params)

        # Try calling fit()
        init_lora_a_kernel_value = layer.lora_kernel_a.numpy()
        init_lora_b_kernel_value = layer.lora_kernel_b.numpy()
        x = np.random.random((64, 8))
        y = np.random.random((64, 16))
        model = models.Sequential([layer])
        model.compile(optimizer="sgd", loss="mse")
        model.fit(x, y, epochs=2)

        final_lora_a_kernel_value = layer.lora_kernel_a.numpy()
        final_lora_b_kernel_value = layer.lora_kernel_b.numpy()
        diff_a = np.max(
            np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value)
        )
        diff_b = np.max(
            np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value)
        )
        self.assertGreater(diff_a, 0.0)
        self.assertGreater(diff_b, 0.0)

        # Try saving and reloading the model
        temp_filepath = os.path.join(
            self.get_temp_dir(), "quantized_lora_model.keras"
        )
        model.save(temp_filepath)
        new_model = 
saving.load_model(temp_filepath)\n        self.assertTrue(new_model.layers[0].lora_enabled)\n        self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5)\n\n        # Try saving and reloading the model's weights only\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"quantized_lora_model.weights.h5\"\n        )\n        model.save_weights(temp_filepath)\n        new_model = models.Sequential([layers.Dense(**config)])\n        new_model.build((None, 8))\n        new_model.quantize(mode)\n        new_model.load_weights(temp_filepath)\n        self.assertFalse(new_model.layers[0].lora_enabled)\n        self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5)\n\n        # Try loading a normal checkpoint into a lora model\n        new_model.save_weights(temp_filepath)\n        model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5)\n\n        # Test export and TFSMLayer reloading when using tensorflow backend\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n            ref_input = tf.random.normal((2, 8))\n            ref_output = model(ref_input)\n            model.export(temp_filepath, format=\"tf_saved_model\")\n            reloaded_layer = export.TFSMLayer(temp_filepath)\n            self.assertAllClose(\n                reloaded_layer(ref_input), ref_output, atol=1e-7\n            )\n            self.assertLen(reloaded_layer.weights, len(model.weights))\n            self.assertLen(\n                reloaded_layer.trainable_weights, len(model.trainable_weights)\n            )\n            self.assertLen(\n                reloaded_layer.non_trainable_weights,\n                len(model.non_trainable_weights),\n            )\n\n    @pytest.mark.requires_trainable_backend\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), 
reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_quantize_float8(self):\n        import ml_dtypes\n\n        from keras.src import quantizers\n\n        layer = layers.Dense(units=32)\n        layer.build((None, 16))\n        layer.quantize(\"float8\")\n        optimizer = optimizers.AdamW(learning_rate=0.1)\n        optimizer.build(layer.trainable_variables)\n\n        def loss_fn(x, dy):\n            y = layer(x, training=True)\n            loss = y * ops.cast(dy, y.dtype)\n            return ops.sum(loss)\n\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            @tf.function(jit_compile=True)\n            def train_one_step(x, dy):\n                with tf.GradientTape() as tape:\n                    loss = loss_fn(x, dy)\n                grads = tape.gradient(loss, layer.trainable_variables)\n                optimizer.apply(grads, layer.trainable_variables)\n\n        elif backend.backend() == \"jax\":\n            import jax\n\n            def stateless_loss_fn(trainable_variables, x, dy):\n                y = layer.stateless_call(\n                    trainable_variables, [], x, training=True\n                )[0]\n                loss = y * ops.cast(dy, y.dtype)\n                return ops.sum(loss)\n\n            grad_fn = jax.jit(jax.grad(stateless_loss_fn))\n\n            def train_one_step(x, dy):\n                trainable_variables = [\n                    v.value for v in layer.trainable_variables\n                ]\n                optimizer_variables = [v.value for v in optimizer.variables]\n                grads = grad_fn(trainable_variables, x, dy)\n                trainable_variables, optimizer_variables = (\n                    optimizer.stateless_apply(\n                        optimizer_variables, grads, trainable_variables\n                    )\n                )\n                for variable, value in zip(\n                    layer.trainable_variables, trainable_variables\n             
   ):\n                    variable.assign(value)\n                for variable, value in zip(\n                    optimizer.variables, optimizer_variables\n                ):\n                    variable.assign(value)\n\n        elif backend.backend() == \"torch\":\n\n            def train_one_step(x, dy):\n                layer.zero_grad()\n                loss = loss_fn(x, dy)\n                loss.backward()\n                grads = [v.value.grad for v in layer.trainable_variables]\n                optimizer.apply(grads, layer.trainable_variables)\n\n        scale_x, amax_history_x = ops.ones(()), ops.zeros((1024,))\n        scale_k, amax_history_k = ops.ones(()), ops.zeros((1024,))\n        scale_g, amax_history_g = ops.ones(()), ops.zeros((1024,))\n        e4m3_max = ops.cast(\n            float(ml_dtypes.finfo(\"float8_e4m3fn\").max), \"float32\"\n        )\n        e5m2_max = ops.cast(\n            float(ml_dtypes.finfo(\"float8_e5m2\").max), \"float32\"\n        )\n\n        for _ in range(3):\n            x = random.normal((16, 16), dtype=\"float32\")\n            g = random.normal((16, 32), dtype=\"float32\")\n            k = ops.convert_to_tensor(layer._kernel)\n\n            # Manually compute the expected amax history and scaling factors.\n            amax_from_history_x = ops.max(amax_history_x)\n            amax_from_history_k = ops.max(amax_history_k)\n            amax_from_history_g = ops.max(amax_history_g)\n            scale_x = quantizers.compute_float8_scale(\n                amax_from_history_x, scale_x, e4m3_max\n            )\n            scale_k = quantizers.compute_float8_scale(\n                amax_from_history_k, scale_k, e4m3_max\n            )\n            scale_g = quantizers.compute_float8_scale(\n                amax_from_history_g, scale_g, e5m2_max\n            )\n            amax_history_x = quantizers.compute_float8_amax_history(\n                x, amax_history_x\n            )\n            amax_history_k = 
quantizers.compute_float8_amax_history(\n                k, amax_history_k\n            )\n            amax_history_g = quantizers.compute_float8_amax_history(\n                g, amax_history_g\n            )\n\n            train_one_step(x, g)\n\n            self.assertAllClose(layer.inputs_amax_history, amax_history_x)\n            self.assertAllClose(layer.kernel_amax_history, amax_history_k)\n            self.assertAllClose(layer.outputs_grad_amax_history, amax_history_g)\n            self.assertAllClose(layer.inputs_scale, scale_x)\n            self.assertAllClose(layer.kernel_scale, scale_k)\n            self.assertAllClose(layer.outputs_grad_scale, scale_g)\n\n    @pytest.mark.requires_trainable_backend\n    def test_quantize_float8_fitting(self):\n        config = dict(units=16)\n        layer = layers.Dense(**config)\n        layer.build((None, 8))\n        layer.quantize(\"float8\")\n        self.assertLen(layer.trainable_weights, 8)\n        self.assertLen(layer.non_trainable_weights, 0)\n\n        # Try calling fit()\n        x = np.random.random((64, 8))\n        y = np.random.random((64, 16))\n        model = models.Sequential([layer])\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        model.fit(x, y, epochs=2)\n\n        # Try saving and reloading the model\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"quantized_float8_model.keras\"\n        )\n        model.save(temp_filepath)\n        new_model = saving.load_model(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n        # Try saving and reloading the model's weights only\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"quantized_float8_model.weights.h5\"\n        )\n        model.save_weights(temp_filepath)\n        new_model = models.Sequential([layers.Dense(**config)])\n        new_model.build((None, 8))\n        new_model.quantize(\"float8\")\n        
new_model.load_weights(temp_filepath)
        self.assertAllClose(model.predict(x), new_model.predict(x))

        # Test export and TFSMLayer reloading when using tensorflow backend
        if backend.backend() == "tensorflow":
            import tensorflow as tf

            temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
            ref_input = tf.random.normal((2, 8))
            ref_output = model(ref_input)
            model.export(temp_filepath, format="tf_saved_model")
            reloaded_layer = export.TFSMLayer(temp_filepath)
            self.assertAllClose(reloaded_layer(ref_input), ref_output)
            self.assertLen(reloaded_layer.weights, len(model.weights))
            self.assertLen(
                reloaded_layer.trainable_weights, len(model.trainable_weights)
            )
            self.assertLen(
                reloaded_layer.non_trainable_weights,
                len(model.non_trainable_weights),
            )

    def test_quantize_float8_inference(self):
        config = dict(units=16)
        layer = layers.Dense(**config)
        layer.build((None, 8))
        layer.quantize("float8")

        # Calling with `training=False` must match calling with
        # `training=True`, because inference triggers no scale or
        # amax-history updates.
        x = np.random.random((64, 8))
        y_inference = layer(x, training=False)
        y_training = layer(x, training=True)
        self.assertAllClose(y_inference, y_training)

    def test_gptq_serialization(self):
        """Test that a GPTQ-quantized layer can be serialized and deserialized
        correctly."""
        layer = layers.Dense(units=16)
        layer.build((None, 8))
        layer.quantize(
            "gptq",
            config=GPTQConfig(
                dataset=None, tokenizer=None, weight_bits=4, group_size=8
            ),
        )
        config = layer.get_config()
        new_layer = 
layers.Dense.from_config(config)\n        new_layer.build((None, 8))\n        self.assertEqual(new_layer.quantization_mode, \"gptq\")\n\n    def test_awq_serialization(self):\n        \"\"\"Test that an AWQ-quantized layer can be serialized and deserialized\n        correctly.\"\"\"\n        layer = layers.Dense(units=16)\n        layer.build((None, 8))\n        layer.quantize(\n            \"awq\",\n            config=AWQConfig(\n                dataset=None, tokenizer=None, group_size=8, num_grid_points=10\n            ),\n        )\n        config = layer.get_config()\n        new_layer = layers.Dense.from_config(config)\n        new_layer.build((None, 8))\n        self.assertEqual(new_layer.quantization_mode, \"awq\")\n\n    def test_int4_kernel_returns_unpacked_form(self):\n        \"\"\"Test that the `kernel` property returns the unpacked int4 kernel.\"\"\"\n        layer = layers.Dense(units=2)\n        layer.build((None, 2))\n        layer.quantize(\"int4\")\n        packed_kernel = layer._kernel\n        # unpack [in, ceil(out/2)] -> [in, out]\n        expected = quantizers.unpack_int4(\n            packed_kernel, layer._orig_output_dim, axis=-1\n        )\n        self.assertAllClose(layer.kernel, expected)\n\n    def test_legacy_load_own_variables(self):\n        # In previous versions, `load_own_variables` accepted a store with\n        # numeric keys.\n        float32_store = {\n            \"0\": np.random.random((8, 16)).astype(\"float32\"),\n            \"1\": np.random.random((16,)).astype(\"float32\"),\n        }\n        int8_store = {\n            \"0\": np.random.randint(-128, 127, size=(8, 16), dtype=\"int8\"),\n            \"1\": np.random.random((16,)).astype(\"float32\"),\n            \"2\": np.random.random((16,)).astype(\"float32\"),  # kernel_scale.\n        }\n        int4_store = {\n            # kernel is [in, ceil(out/2)] = [8, 8]\n            \"0\": np.random.randint(-128, 127, size=(8, 8), dtype=\"int8\"),\n            \"1\": 
np.random.random((16,)).astype(\"float32\"),\n            \"2\": np.random.random((16,)).astype(\"float32\"),  # kernel_scale.\n        }\n        float8_store = {\n            \"0\": np.random.random((8, 16)).astype(\"float32\"),\n            \"1\": np.random.random((16,)).astype(\"float32\"),\n            # inputs_scale.\n            \"2\": np.random.random(()).astype(\"float32\"),\n            # inputs_amax_history.\n            \"3\": np.random.random((1024,)).astype(\"float32\"),\n            # kernel_scale.\n            \"4\": np.random.random(()).astype(\"float32\"),\n            # kernel_amax_history.\n            \"5\": np.random.random((1024,)).astype(\"float32\"),\n            # outputs_grad_scale.\n            \"6\": np.random.random(()).astype(\"float32\"),\n            # outputs_grad_amax_history.\n            \"7\": np.random.random((1024,)).astype(\"float32\"),\n        }\n        gptq_store = {\n            # bias\n            \"0\": np.random.random((16,)).astype(\"float32\"),\n            # quantized_kernel\n            \"1\": np.random.randint(0, 16, size=(8, 8), dtype=\"uint8\"),\n            # kernel_scale.\n            \"2\": np.random.random((16, 1)).astype(\"float32\"),\n            # kernel_zero\n            \"3\": np.random.random((16, 1)).astype(\"uint8\"),\n            # g_idx\n            \"4\": np.random.random((8,)).astype(\"float32\"),\n        }\n        awq_store = {\n            \"0\": np.random.random((16,)).astype(\"float32\"),  # bias\n            \"1\": np.random.randint(0, 16, size=(8, 8), dtype=\"uint8\"),  # kernel\n            \"2\": np.random.random((16, 1)).astype(\"float32\"),  # scale\n            \"3\": np.random.random((16, 1)).astype(\"uint8\"),  # zero\n            \"4\": np.random.random((8,)).astype(\"float32\"),  # awq_scales\n            \"5\": np.random.random((8,)).astype(\"float32\"),  # g_idx\n        }\n\n        # Test float32 layer.\n        layer = layers.Dense(units=16)\n        layer.build((None, 
8))\n        layer.load_own_variables(float32_store)\n        self.assertAllClose(layer._kernel, float32_store[\"0\"])\n        self.assertAllClose(layer.bias, float32_store[\"1\"])\n\n        # Test int8-quantized layer.\n        layer = layers.Dense(units=16, dtype=\"int8_from_float32\")\n        layer.build((None, 8))\n        layer.load_own_variables(int8_store)\n        self.assertAllClose(layer._kernel, int8_store[\"0\"])\n        self.assertAllClose(layer.bias, int8_store[\"1\"])\n        self.assertAllClose(layer.kernel_scale, int8_store[\"2\"])\n\n        # Test int4-quantized layer.\n        layer = layers.Dense(units=16, dtype=\"int4_from_float32\")\n        layer.build((None, 8))\n        layer.load_own_variables(int4_store)\n        self.assertAllClose(layer._kernel, int4_store[\"0\"])\n        self.assertAllClose(layer.bias, int4_store[\"1\"])\n        self.assertAllClose(layer.kernel_scale, int4_store[\"2\"])\n\n        # Test float8-quantized layer.\n        layer = layers.Dense(units=16, dtype=\"float8_from_float32\")\n        layer.build((None, 8))\n        layer.load_own_variables(float8_store)\n        self.assertAllClose(layer._kernel, float8_store[\"0\"])\n        self.assertAllClose(layer.bias, float8_store[\"1\"])\n        self.assertAllClose(layer.inputs_scale, float8_store[\"2\"])\n        self.assertAllClose(layer.inputs_amax_history, float8_store[\"3\"])\n        self.assertAllClose(layer.kernel_scale, float8_store[\"4\"])\n        self.assertAllClose(layer.kernel_amax_history, float8_store[\"5\"])\n        self.assertAllClose(layer.outputs_grad_scale, float8_store[\"6\"])\n        self.assertAllClose(layer.outputs_grad_amax_history, float8_store[\"7\"])\n\n        # Test gptq-quantized layer.\n        layer = layers.Dense(units=16, dtype=\"gptq/4/8_from_float32\")\n        layer.build((None, 8))\n        layer.load_own_variables(gptq_store)\n        self.assertTrue(layer.is_gptq_calibrated)\n        self.assertAllClose(layer.bias, 
gptq_store[\"0\"])\n        self.assertAllClose(layer.quantized_kernel, gptq_store[\"1\"])\n        self.assertAllClose(layer.kernel_scale, gptq_store[\"2\"])\n        self.assertAllClose(layer.kernel_zero, gptq_store[\"3\"])\n        self.assertAllClose(layer.g_idx, gptq_store[\"4\"])\n\n        # Test awq-quantized layer.\n        layer = layers.Dense(units=16, dtype=\"awq/4/8_from_float32\")\n        layer.build((None, 8))\n        layer.load_own_variables(awq_store)\n        self.assertTrue(layer.is_awq_calibrated)\n        self.assertAllClose(layer.bias, awq_store[\"0\"])\n        self.assertAllClose(layer.quantized_kernel, awq_store[\"1\"])\n        self.assertAllClose(layer.kernel_scale, awq_store[\"2\"])\n        self.assertAllClose(layer.kernel_zero, awq_store[\"3\"])\n        self.assertAllClose(layer.awq_scales, awq_store[\"4\"])\n        self.assertAllClose(layer.g_idx, awq_store[\"5\"])\n\n    def test_int4_gptq_kernel_returns_unpacked_form(self):\n        \"\"\"Test that the `kernel` property returns the unpacked int4 GPTQ\n        kernel.\"\"\"\n        layer = layers.Dense(units=2)\n        layer.build((None, 2))\n        layer.quantize(\n            \"gptq\",\n            config=GPTQConfig(\n                dataset=None, tokenizer=None, weight_bits=4, group_size=8\n            ),\n        )\n        layer.is_gptq_calibrated = True  # Bypass calibration check\n        packed_kernel = layer.quantized_kernel\n        self.assertAllClose(\n            layer.kernel, quantizers.unpack_int4(packed_kernel, 2)\n        )\n\n    def test_gptq_kernel_packing(self):\n        \"\"\"Validates that 4-bit GPTQ packing reduces the kernel size.\"\"\"\n        layer = layers.Dense(units=16, use_bias=False)\n        layer.build((None, 8))\n\n        original_kernel_params = ops.prod(layer._kernel.shape)\n\n        layer.quantize(\n            \"gptq\",\n            config=GPTQConfig(\n                dataset=None, tokenizer=None, weight_bits=4, group_size=8\n          
  ),
        )

        quantized_kernel_params = ops.prod(layer.quantized_kernel.shape)
        self.assertEqual(quantized_kernel_params, original_kernel_params // 2)

    def test_int4_awq_kernel_returns_unpacked_form(self):
        """Test that the `kernel` property returns the unpacked int4 AWQ
        kernel."""
        layer = layers.Dense(units=2)
        layer.build((None, 2))
        layer.quantize(
            "awq",
            config=AWQConfig(
                dataset=None, tokenizer=None, group_size=8, num_grid_points=10
            ),
        )
        layer.is_awq_calibrated = True  # Bypass calibration check
        packed_kernel = layer.quantized_kernel
        self.assertAllClose(
            layer.kernel, quantizers.unpack_int4(packed_kernel, 2)
        )

    def test_awq_kernel_packing(self):
        """Validates that 4-bit AWQ packing reduces the kernel size."""
        layer = layers.Dense(units=16, use_bias=False)
        layer.build((None, 8))

        original_kernel_params = ops.prod(layer._kernel.shape)

        layer.quantize(
            "awq",
            config=AWQConfig(
                dataset=None, tokenizer=None, group_size=8, num_grid_points=10
            ),
        )

        quantized_kernel_params = ops.prod(layer.quantized_kernel.shape)
        self.assertEqual(quantized_kernel_params, original_kernel_params // 2)

    def _check_quantizer_config(
        self, quantizer, valid_class, axis, value_range
    ):
        self.assertIsInstance(quantizer, valid_class)
        self.assertEqual(quantizer.axis, axis)

        # Only verify the value range when the config provides one.
        if value_range is not None:
            self.assertAllEqual(quantizer.value_range, value_range)

    def test_dense_int8_custom_quantizer(self):
        """
        Test custom quantizer serialization for dense layer.
        """
        # Setup
        weight_range = (-127, 127)
        
act_range = (-5, 5)\n        config = Int8QuantizationConfig(\n            weight_quantizer=AbsMaxQuantizer(axis=0, value_range=weight_range),\n            activation_quantizer=AbsMaxQuantizer(\n                axis=-1, value_range=act_range\n            ),\n        )\n\n        # Build & Quantize\n        layer = layers.Dense(10)\n        layer.build((None, 5))\n        layer.quantize(\"int8\", config=config)\n\n        # Serialize & Deserialize\n        serialized = layer.get_config()\n        new_layer = layers.Dense.from_config(serialized)\n\n        # Verify\n        self.assertIsInstance(\n            new_layer.quantization_config, Int8QuantizationConfig\n        )\n        self._check_quantizer_config(\n            new_layer.quantization_config.weight_quantizer,\n            AbsMaxQuantizer,\n            axis=(0,),\n            value_range=weight_range,\n        )\n        self._check_quantizer_config(\n            new_layer.quantization_config.activation_quantizer,\n            AbsMaxQuantizer,\n            axis=(-1,),\n            value_range=act_range,\n        )\n\n    def test_dense_int8_weight_only_quantizer(self):\n        \"\"\"\n        Test custom quantizer serialization for dense layer with\n        weight-only quantization.\n        \"\"\"\n        # Setup\n        config = Int8QuantizationConfig(\n            weight_quantizer=AbsMaxQuantizer(axis=0),\n            activation_quantizer=None,\n        )\n\n        # Build & Quantize\n        layer = layers.Dense(10)\n        layer.build((None, 5))\n        layer.quantize(\"int8\", config=config)\n\n        # Serialize & Deserialize\n        serialized = layer.get_config()\n        new_layer = layers.Dense.from_config(serialized)\n\n        # Verify\n        self.assertIsInstance(\n            new_layer.quantization_config, Int8QuantizationConfig\n        )\n        self.assertIsInstance(\n            new_layer.quantization_config.weight_quantizer, AbsMaxQuantizer\n        )\n        
self.assertIsNone(new_layer.quantization_config.activation_quantizer)\n\n    def test_dense_int4_custom_quantizer(self):\n        \"\"\"\n        Test custom quantizer serialization for dense layer with\n        int4 quantization.\n        \"\"\"\n        # Setup - custom quantizers require per-channel mode (block_size=None)\n        weight_range = (-8, 7)\n        act_range = (-2, 2)\n        config = Int4QuantizationConfig(\n            weight_quantizer=AbsMaxQuantizer(axis=0, value_range=weight_range),\n            activation_quantizer=AbsMaxQuantizer(\n                axis=-1, value_range=act_range\n            ),\n            block_size=None,\n        )\n\n        # Build & Quantize\n        layer = layers.Dense(10)\n        layer.build((None, 5))\n        layer.quantize(\"int4\", config=config)\n\n        # Serialize & Deserialize\n        serialized = layer.get_config()\n        new_layer = layers.Dense.from_config(serialized)\n\n        # Verify\n        self.assertIsInstance(\n            new_layer.quantization_config, Int4QuantizationConfig\n        )\n        self._check_quantizer_config(\n            new_layer.quantization_config.weight_quantizer,\n            AbsMaxQuantizer,\n            axis=(0,),\n            value_range=weight_range,\n        )\n        self._check_quantizer_config(\n            new_layer.quantization_config.activation_quantizer,\n            AbsMaxQuantizer,\n            axis=(-1,),\n            value_range=act_range,\n        )\n\n    @parameterized.named_parameters(\n        (\"grouped_block_64\", 64),\n        (\"grouped_block_128\", 128),\n        (\"per_channel_none\", None),\n        (\"per_channel_neg1\", -1),\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_quantization_block_size(self, block_size):\n        \"\"\"Test int4 quantization with different block_size configurations.\"\"\"\n        input_dim, output_dim = 256, 64\n        
layer = layers.Dense(units=output_dim)\n        layer.build((None, input_dim))\n\n        x = np.random.random((2, input_dim)).astype(\"float32\")\n        y_float = layer(x)\n\n        # Create config with specified block_size\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Verify block_size is stored\n        self.assertEqual(layer._int4_block_size, block_size)\n\n        # Verify kernel_scale shape\n        if block_size is None or block_size == -1:\n            # Per-channel: one scale per output unit\n            expected_scale_shape = (output_dim,)\n        else:\n            # Sub-channel: (n_groups, out_features)\n            n_groups = math.ceil(input_dim / block_size)\n            expected_scale_shape = (n_groups, output_dim)\n\n        self.assertEqual(layer.kernel_scale.shape, expected_scale_shape)\n\n        # Verify outputs are reasonable\n        y_quantized = layer(x)\n        mse = ops.mean(ops.square(y_float - y_quantized))\n        self.assertLess(mse, 0.01)  # Reasonable accuracy\n\n    @parameterized.named_parameters(\n        (\"grouped_block_64\", 64),\n        (\"grouped_block_128\", 128),\n        (\"per_channel_none\", None),\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_block_size_serialization(self, block_size):\n        \"\"\"Test that block_size is preserved through serialization.\"\"\"\n        layer = layers.Dense(units=32)\n        layer.build((None, 128))\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Get output before serialization\n        x = np.random.random((2, 128)).astype(\"float32\")\n        y_before = layer(x)\n\n        # Save and load model to test full serialization roundtrip\n        model = models.Sequential([layer])\n        temp_filepath = os.path.join(\n            
self.get_temp_dir(), \"int4_block_size_model.keras\"\n        )\n        model.save(temp_filepath)\n        loaded_model = saving.load_model(temp_filepath)\n\n        # Verify block_size is preserved\n        loaded_layer = loaded_model.layers[0]\n        self.assertIsInstance(\n            loaded_layer.quantization_config, Int4QuantizationConfig\n        )\n        self.assertEqual(\n            loaded_layer.quantization_config.block_size, block_size\n        )\n\n        # Verify outputs match after deserialization\n        y_after = loaded_model(x)\n        self.assertAllClose(y_before, y_after)\n\n    @parameterized.named_parameters(\n        (\"grouped_block_64\", 64),\n        (\"per_channel\", None),\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_block_size_with_lora(self, block_size):\n        \"\"\"Test int4 quantization with LoRA and different block_size.\"\"\"\n        input_dim, output_dim = 128, 64\n        layer = layers.Dense(units=output_dim)\n        layer.build((None, input_dim))\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n        layer.enable_lora(rank=4)\n\n        x = np.random.random((2, input_dim)).astype(\"float32\")\n\n        # Should run without error\n        y = layer(x)\n        self.assertEqual(y.shape, (2, output_dim))\n\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_grouped_vs_perchannel_scale_shapes(self):\n        \"\"\"Test that grouped and per-channel have different scale shapes.\"\"\"\n        input_dim, output_dim = 256, 64\n        block_size = 64\n\n        # Per-channel layer\n        layer_pc = layers.Dense(units=output_dim)\n        layer_pc.build((None, input_dim))\n        config_pc = Int4QuantizationConfig(block_size=None)\n        layer_pc.quantize(\"int4\", 
config=config_pc)\n\n        # Grouped layer\n        layer_grouped = layers.Dense(units=output_dim)\n        layer_grouped.build((None, input_dim))\n        config_grouped = Int4QuantizationConfig(block_size=block_size)\n        layer_grouped.quantize(\"int4\", config=config_grouped)\n\n        # Verify different scale shapes\n        self.assertEqual(layer_pc.kernel_scale.shape, (output_dim,))\n        # scale shape is (n_groups, out_features)\n        self.assertEqual(\n            layer_grouped.kernel_scale.shape,\n            (input_dim // block_size, output_dim),\n        )\n\n    @parameterized.named_parameters(\n        (\"grouped_block_64\", 64),\n        (\"grouped_block_128\", 128),\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_subchannel_g_idx_created(self, block_size):\n        \"\"\"Test that g_idx is created for sub-channel int4 quantization.\"\"\"\n        input_dim, output_dim = 256, 64\n        layer = layers.Dense(units=output_dim)\n        layer.build((None, input_dim))\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Verify g_idx is created\n        self.assertTrue(hasattr(layer, \"g_idx\"))\n\n        # Verify g_idx shape\n        self.assertEqual(layer.g_idx.shape, (input_dim,))\n\n        # Verify g_idx values (should map each row to its group)\n        expected_g_idx = np.arange(input_dim) // block_size\n        self.assertAllClose(layer.g_idx, expected_g_idx)\n\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_perchannel_no_g_idx(self):\n        \"\"\"Test that per-channel int4 does NOT create g_idx.\"\"\"\n        layer = layers.Dense(units=32)\n        layer.build((None, 64))\n\n        config = Int4QuantizationConfig(block_size=None)  # Per-channel\n        layer.quantize(\"int4\", 
config=config)\n\n        # Verify g_idx is NOT created for per-channel\n        self.assertFalse(hasattr(layer, \"g_idx\"))\n\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_subchannel_g_idx_serialization(self):\n        \"\"\"Test that g_idx is properly serialized and deserialized.\"\"\"\n        input_dim, output_dim = 128, 32\n        block_size = 64\n\n        layer = layers.Dense(units=output_dim)\n        layer.build((None, input_dim))\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        x = np.random.random((2, input_dim)).astype(\"float32\")\n        y_before = layer(x)\n        g_idx_before = layer.g_idx\n\n        # Save and load\n        model = models.Sequential([layer])\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"int4_g_idx_model.keras\"\n        )\n        model.save(temp_filepath)\n        loaded_model = saving.load_model(temp_filepath)\n\n        # Verify g_idx is preserved\n        loaded_layer = loaded_model.layers[0]\n        self.assertTrue(hasattr(loaded_layer, \"g_idx\"))\n        self.assertAllClose(loaded_layer.g_idx, g_idx_before)\n\n        # Verify outputs match\n        y_after = loaded_model(x)\n        self.assertAllClose(y_before, y_after)\n"
  },
  {
    "path": "keras/src/layers/core/einsum_dense.py",
    "content": "import math\nimport re\nimport string\n\nimport ml_dtypes\nimport numpy as np\n\nfrom keras.src import activations\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import dtype_policies\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import quantizers\nfrom keras.src import regularizers\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.quantizers.quantization_config import QuantizationConfig\nfrom keras.src.quantizers.quantization_config import get_block_size_for_layer\nfrom keras.src.quantizers.quantizers import dequantize_with_sz_map\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.layers.EinsumDense\")\nclass EinsumDense(Layer):\n    \"\"\"A layer that uses `einsum` as the backing computation.\n\n    This layer can perform einsum calculations of arbitrary dimensionality.\n\n    Args:\n        equation: An equation describing the einsum to perform.\n            This equation must be a valid einsum string of the form\n            `ab,bc->ac`, `...ab,bc->...ac`, or\n            `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum\n            axis expression sequence.\n        output_shape: The expected shape of the output tensor\n            (excluding the batch dimension and any dimensions\n            represented by ellipses). You can specify `None` for any dimension\n            that is unknown or can be inferred from the input shape.\n        activation: Activation function to use. If you don't specify anything,\n            no activation is applied\n            (that is, a \"linear\" activation: `a(x) = x`).\n        bias_axes: A string containing the output dimension(s)\n            to apply a bias to. 
Each character in the `bias_axes` string\n            should correspond to a character in the output portion\n            of the `equation` string.\n        kernel_initializer: Initializer for the `kernel` weights matrix.\n        bias_initializer: Initializer for the bias vector.\n        kernel_regularizer: Regularizer function applied to the `kernel` weights\n            matrix.\n        bias_regularizer: Regularizer function applied to the bias vector.\n        kernel_constraint: Constraint function applied to the `kernel` weights\n            matrix.\n        bias_constraint: Constraint function applied to the bias vector.\n        lora_rank: Optional integer. If set, the layer's forward pass\n            will implement LoRA (Low-Rank Adaptation)\n            with the provided rank. LoRA sets the layer's kernel\n            to non-trainable and replaces it with a delta over the\n            original kernel, obtained via multiplying two lower-rank\n            trainable matrices\n            (the factorization happens on the last dimension).\n            This can be useful to reduce the\n            computation cost of fine-tuning large dense layers.\n            You can also enable LoRA on an existing\n            `EinsumDense` layer by calling `layer.enable_lora(rank)`.\n         lora_alpha: Optional integer. If set, this parameter scales the\n            low-rank adaptation delta (computed as the product of two lower-rank\n            trainable matrices) during the forward pass. The delta is scaled by\n            `lora_alpha / lora_rank`, allowing you to fine-tune the strength of\n            the LoRA adjustment independently of `lora_rank`.\n        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.\n\n    Examples:\n\n    **Biased dense layer with einsums**\n\n    This example shows how to instantiate a standard Keras dense layer using\n    einsum operations. 
This example is equivalent to\n    `keras.layers.Dense(64, use_bias=True)`.\n\n    >>> layer = keras.layers.EinsumDense(\"ab,bc->ac\",\n    ...                                       output_shape=64,\n    ...                                       bias_axes=\"c\")\n    >>> input_tensor = keras.Input(shape=[32])\n    >>> output_tensor = layer(input_tensor)\n    >>> output_tensor.shape\n    (None, 64)\n\n    **Applying a dense layer to a sequence**\n\n    This example shows how to instantiate a layer that applies the same dense\n    operation to every element in a sequence. Here, the `output_shape` has two\n    values (since there are two non-batch dimensions in the output); the first\n    dimension in the `output_shape` is `None`, because the sequence dimension\n    `b` has an unknown shape.\n\n    >>> layer = keras.layers.EinsumDense(\"abc,cd->abd\",\n    ...                                       output_shape=(None, 64),\n    ...                                       bias_axes=\"d\")\n    >>> input_tensor = keras.Input(shape=[32, 128])\n    >>> output_tensor = layer(input_tensor)\n    >>> output_tensor.shape\n    (None, 32, 64)\n\n    **Applying a dense layer to a sequence using ellipses**\n\n    This example shows how to instantiate a layer that applies the same dense\n    operation to every element in a sequence, but uses the ellipsis notation\n    instead of specifying the batch and sequence dimensions.\n\n    Because we are using ellipsis notation and have specified only one axis, the\n    `output_shape` arg is a single value. When instantiated in this way, the\n    layer can handle any number of sequence dimensions - including the case\n    where no sequence dimension exists.\n\n    >>> layer = keras.layers.EinsumDense(\"...x,xy->...y\",\n    ...                                       output_shape=64,\n    ...                                       
bias_axes=\"y\")\n    >>> input_tensor = keras.Input(shape=[32, 128])\n    >>> output_tensor = layer(input_tensor)\n    >>> output_tensor.shape\n    (None, 32, 64)\n    \"\"\"\n\n    def __init__(\n        self,\n        equation,\n        output_shape,\n        activation=None,\n        bias_axes=None,\n        kernel_initializer=\"glorot_uniform\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        bias_regularizer=None,\n        kernel_constraint=None,\n        bias_constraint=None,\n        lora_rank=None,\n        lora_alpha=None,\n        gptq_unpacked_column_size=None,\n        quantization_config=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.equation = equation\n        if isinstance(output_shape, int):\n            self.partial_output_shape = (output_shape,)\n        else:\n            self.partial_output_shape = tuple(output_shape)\n        self.bias_axes = bias_axes\n        self.activation = activations.get(activation)\n        self.kernel_initializer = initializers.get(kernel_initializer)\n        self.bias_initializer = initializers.get(bias_initializer)\n        self.kernel_regularizer = regularizers.get(kernel_regularizer)\n        self.bias_regularizer = regularizers.get(bias_regularizer)\n        self.kernel_constraint = constraints.get(kernel_constraint)\n        self.bias_constraint = constraints.get(bias_constraint)\n        self.lora_rank = lora_rank\n        self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank\n        self.lora_enabled = False\n        self.gptq_unpacked_column_size = gptq_unpacked_column_size\n        self.quantization_config = quantization_config\n\n    def build(self, input_shape):\n        shape_data = _analyze_einsum_string(\n            self.equation,\n            self.bias_axes,\n            input_shape,\n            self.partial_output_shape,\n        )\n        kernel_shape, bias_shape, full_output_shape = shape_data\n        
self.full_output_shape = tuple(full_output_shape)\n        self.input_spec = InputSpec(ndim=len(input_shape))\n        if self.quantization_mode is not None:\n            self.quantized_build(\n                kernel_shape,\n                mode=self.quantization_mode,\n                config=self.quantization_config,\n            )\n        # Skip creating a duplicate kernel variable when the layer is already\n        # quantized to int8 or int4, because `quantized_build` has created the\n        # appropriate kernel variable. For other modes (e.g., float8 or no\n        # quantization), we still need the floating-point kernel.\n        if self.quantization_mode not in (\"int8\", \"int4\", \"gptq\", \"awq\"):\n            # If the layer is quantized to int8, `self._kernel` will be added\n            # in `self._int8_build`. Therefore, we skip it here.\n            self._kernel = self.add_weight(\n                name=\"kernel\",\n                shape=tuple(kernel_shape),\n                initializer=self.kernel_initializer,\n                regularizer=self.kernel_regularizer,\n                constraint=self.kernel_constraint,\n                dtype=self.dtype,\n                trainable=True,\n            )\n        if bias_shape is not None:\n            self.bias = self.add_weight(\n                name=\"bias\",\n                shape=tuple(bias_shape),\n                initializer=self.bias_initializer,\n                regularizer=self.bias_regularizer,\n                constraint=self.bias_constraint,\n                dtype=self.dtype,\n                trainable=True,\n            )\n        else:\n            self.bias = None\n        self.built = True\n        if self.lora_rank:\n            self.enable_lora(self.lora_rank, lora_alpha=self.lora_alpha)\n\n    @property\n    def kernel(self):\n        from keras.src.quantizers import gptq_core\n\n        if not self.built:\n            raise AttributeError(\n                \"You must build the layer 
before accessing `kernel`.\"\n            )\n\n        mode = self.quantization_mode\n        is_gptq = mode == \"gptq\"\n        is_awq = mode == \"awq\"\n        is_int4 = mode == \"int4\"\n        gptq_calibrated = bool(getattr(self, \"is_gptq_calibrated\", False))\n        awq_calibrated = bool(getattr(self, \"is_awq_calibrated\", False))\n        gptq_bits = (\n            gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None\n        )\n\n        # Decide the source tensor first (packed vs already-quantized vs plain\n        # kernel)\n        if is_gptq and gptq_calibrated and gptq_bits != 4:\n            # calibrated GPTQ, not 4-bit, no unpacking needed\n            kernel = self.quantized_kernel\n        else:\n            # Start with the stored kernel\n            kernel = getattr(self, \"_kernel\", None)\n\n            # Handle int4 unpacking cases in one place\n            if is_int4:\n                # unpack [rows, ceil(columns/2)] to [rows, columns]\n                kernel = quantizers.unpack_int4(\n                    kernel,\n                    self._int4_unpacked_column_size,\n                    axis=-1,\n                )\n                kernel = ops.reshape(kernel, self.original_kernel_shape)\n            elif is_gptq and gptq_calibrated and gptq_bits == 4:\n                kernel = quantizers.unpack_int4(\n                    self.quantized_kernel,\n                    orig_len=self.gptq_unpacked_column_size,\n                    axis=0,\n                    dtype=\"uint8\",\n                )\n            elif is_awq and awq_calibrated:\n                # AWQ always uses 4-bit quantization\n                kernel = quantizers.unpack_int4(\n                    self.quantized_kernel,\n                    orig_len=self.awq_unpacked_column_size,\n                    axis=0,\n                    dtype=\"uint8\",\n                )\n\n        # Apply LoRA if enabled\n        if self.lora_enabled:\n            kernel = kernel + 
(self.lora_alpha / self.lora_rank) * ops.matmul(\n                self.lora_kernel_a, self.lora_kernel_b\n            )\n\n        return kernel\n\n    def compute_output_shape(self, _):\n        return self.full_output_shape\n\n    def call(self, inputs, training=None):\n        x = ops.einsum(self.equation, inputs, self.kernel)\n        if self.bias is not None:\n            x = ops.add(x, self.bias)\n        if self.activation is not None:\n            x = self.activation(x)\n        return x\n\n    def enable_lora(\n        self,\n        rank,\n        lora_alpha=None,\n        a_initializer=\"he_uniform\",\n        b_initializer=\"zeros\",\n    ):\n        if self.kernel_constraint:\n            raise ValueError(\n                \"Lora is incompatible with kernel constraints. \"\n                \"In order to enable lora on this layer, remove the \"\n                \"`kernel_constraint` argument.\"\n            )\n        if not self.built:\n            raise ValueError(\n                \"Cannot enable lora on a layer that isn't yet built.\"\n            )\n        if self.lora_enabled:\n            raise ValueError(\n                \"lora is already enabled. This can only be done once per layer.\"\n            )\n        if self.quantization_mode == \"gptq\":\n            raise NotImplementedError(\n                \"lora is not currently supported with GPTQ quantization.\"\n            )\n        self._tracker.unlock()\n        # Determine the appropriate (unpacked) kernel shape for LoRA.\n        if self.quantization_mode == \"int4\":\n            # INT4 weights are stored in a flattened 2D layout that loses\n            # the original N-dimensional structure required by the einsum\n            # equation. 
We use `original_kernel_shape` to ensure LoRA adapters\n            # operate in the correct logical dimension space.\n            kernel_shape_for_lora = tuple(self.original_kernel_shape)\n        else:\n            kernel_shape_for_lora = self.kernel.shape\n\n        self.lora_kernel_a = self.add_weight(\n            name=\"lora_kernel_a\",\n            shape=(kernel_shape_for_lora[:-1] + (rank,)),\n            initializer=initializers.get(a_initializer),\n            regularizer=self.kernel_regularizer,\n        )\n        self.lora_kernel_b = self.add_weight(\n            name=\"lora_kernel_b\",\n            shape=(rank, kernel_shape_for_lora[-1]),\n            initializer=initializers.get(b_initializer),\n            regularizer=self.kernel_regularizer,\n        )\n        self._kernel.trainable = False\n        self._tracker.lock()\n        self.lora_enabled = True\n        self.lora_rank = rank\n        self.lora_alpha = lora_alpha if lora_alpha is not None else rank\n\n    def save_own_variables(self, store):\n        # Do nothing if the layer isn't yet built\n        if not self.built:\n            return\n        mode = self.quantization_mode\n        if mode not in self.variable_serialization_spec:\n            raise self._quantization_mode_error(mode)\n\n        # Kernel plus optional merged LoRA-aware scale/zero (returns\n        # (kernel, None, None) for None/gptq)\n        kernel_value, merged_kernel_scale, merged_kernel_zero = (\n            self._get_kernel_with_merged_lora()\n        )\n        idx = 0\n        for name in self.variable_serialization_spec[mode]:\n            if name == \"kernel\":\n                store[str(idx)] = kernel_value\n            elif name == \"bias\" and self.bias is None:\n                continue\n            elif name == \"kernel_zero\":\n                if merged_kernel_zero is None:\n                    # kernel_zero only exists for sub-channel int4 quantization\n                    continue\n                
store[str(idx)] = merged_kernel_zero\n            elif name == \"g_idx\":\n                if not hasattr(self, \"g_idx\"):\n                    # g_idx only exists for sub-channel int4 quantization\n                    continue\n                store[str(idx)] = self.g_idx\n            elif name == \"kernel_scale\" and mode in (\"int4\", \"int8\"):\n                # For int4/int8, the merged LoRA scale (if any) comes from\n                # `_get_kernel_with_merged_lora()`\n                store[str(idx)] = merged_kernel_scale\n            else:\n                store[str(idx)] = getattr(self, name)\n            idx += 1\n\n    def load_own_variables(self, store):\n        if not self.lora_enabled:\n            self._check_load_own_variables(store)\n        # Do nothing if the layer isn't yet built\n        if not self.built:\n            return\n        mode = self.quantization_mode\n        if mode not in self.variable_serialization_spec:\n            raise self._quantization_mode_error(mode)\n\n        # A saved GPTQ/AWQ quantized model will always be calibrated.\n        self.is_gptq_calibrated = mode == \"gptq\"\n        self.is_awq_calibrated = mode == \"awq\"\n\n        idx = 0\n        for name in self.variable_serialization_spec[mode]:\n            if name == \"kernel\":\n                self._kernel.assign(store[str(idx)])\n            elif name == \"bias\" and self.bias is None:\n                continue\n            elif name == \"kernel_zero\" and not hasattr(self, \"kernel_zero\"):\n                # kernel_zero only exists for sub-channel int4 quantization\n                continue\n            elif name == \"g_idx\" and not hasattr(self, \"g_idx\"):\n                # g_idx only exists for sub-channel int4 quantization\n                continue\n            else:\n                getattr(self, name).assign(store[str(idx)])\n            idx += 1\n        if self.lora_enabled:\n            
self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))\n            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"output_shape\": self.partial_output_shape,\n            \"equation\": self.equation,\n            \"activation\": activations.serialize(self.activation),\n            \"bias_axes\": self.bias_axes,\n            \"kernel_initializer\": initializers.serialize(\n                self.kernel_initializer\n            ),\n            \"bias_initializer\": initializers.serialize(self.bias_initializer),\n            \"kernel_regularizer\": regularizers.serialize(\n                self.kernel_regularizer\n            ),\n            \"bias_regularizer\": regularizers.serialize(self.bias_regularizer),\n            \"activity_regularizer\": regularizers.serialize(\n                self.activity_regularizer\n            ),\n            \"kernel_constraint\": constraints.serialize(self.kernel_constraint),\n            \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            \"quantization_config\": serialization_lib.serialize_keras_object(\n                self.quantization_config\n            ),\n        }\n        if self.lora_rank:\n            config[\"lora_rank\"] = self.lora_rank\n            config[\"lora_alpha\"] = self.lora_alpha\n        if self.gptq_unpacked_column_size:\n            config[\"gptq_unpacked_column_size\"] = self.gptq_unpacked_column_size\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config):\n        config = config.copy()\n        config[\"quantization_config\"] = (\n            serialization_lib.deserialize_keras_object(\n                config.get(\"quantization_config\", None)\n            )\n        )\n        return super().from_config(config)\n\n    @property\n    def variable_serialization_spec(self):\n        \"\"\"Returns a dict mapping 
quantization modes to variable names in order.\n\n        This spec is used by `save_own_variables` and `load_own_variables` to\n        determine the correct ordering of variables during serialization for\n        each quantization mode. `None` means no quantization.\n        \"\"\"\n        return {\n            None: [\n                \"kernel\",\n                \"bias\",\n            ],\n            \"int8\": [\n                \"kernel\",\n                \"bias\",\n                \"kernel_scale\",\n            ],\n            \"int4\": [\n                \"kernel\",\n                \"bias\",\n                \"kernel_scale\",\n                \"kernel_zero\",\n                \"g_idx\",\n            ],\n            \"float8\": [\n                \"kernel\",\n                \"bias\",\n                \"inputs_scale\",\n                \"inputs_amax_history\",\n                \"kernel_scale\",\n                \"kernel_amax_history\",\n                \"outputs_grad_scale\",\n                \"outputs_grad_amax_history\",\n            ],\n            \"gptq\": [\n                \"bias\",\n                \"quantized_kernel\",\n                \"kernel_scale\",\n                \"kernel_zero\",\n                \"g_idx\",\n            ],\n            \"awq\": [\n                \"bias\",\n                \"quantized_kernel\",\n                \"kernel_scale\",\n                \"kernel_zero\",\n                \"awq_scales\",\n                \"g_idx\",\n            ],\n        }\n\n    def quantized_build(self, kernel_shape, mode, config=None):\n        if mode == \"int8\":\n            self._int8_build(kernel_shape, config)\n        elif mode == \"int4\":\n            self._int4_build(kernel_shape, config)\n        elif mode == \"float8\":\n            self._float8_build()\n        elif mode == \"gptq\":\n            self._gptq_build(kernel_shape, config)\n        elif mode == \"awq\":\n            self._awq_build(kernel_shape, config)\n        else:\n  
          raise self._quantization_mode_error(mode)\n        self._is_quantized = True\n\n    def _int8_build(self, kernel_shape, config=None):\n        self._set_quantization_info()\n        self.inputs_quantizer = (\n            QuantizationConfig.activation_quantizer_or_default(\n                config,\n                quantizers.AbsMaxQuantizer(),\n            )\n        )\n        # If the config provided a default AbsMaxQuantizer, we need to\n        # override the axis to match the equation's reduction axes.\n        self.quantization_axis = tuple(self._input_reduced_axes)\n        self._kernel = self.add_weight(\n            name=\"kernel\",\n            shape=kernel_shape,\n            initializer=\"zeros\",\n            dtype=\"int8\",\n            trainable=False,\n        )\n        kernel_scale_shape = self._get_kernel_scale_shape(kernel_shape)\n        self.kernel_scale = self.add_weight(\n            name=\"kernel_scale\",\n            shape=kernel_scale_shape,\n            initializer=\"ones\",\n            trainable=False,\n        )\n\n    def _gptq_build(self, kernel_shape, config):\n        \"\"\"Allocate the quantized kernel and quantization params for\n        EinsumDense.\n\n        Args:\n            kernel_shape: tuple/list; the layer's original kernel shape, e.g.\n                [in_features, out_features] or [in_features, heads, head_dim].\n            config: quantization config used to look up the contiguous\n                input-group size for this layer. A group size of -1 means\n                per-output-channel quantization with no grouping.\n        \"\"\"\n        from keras.src.quantizers import gptq_core\n\n        # Ensures the forward pass uses the original high-precision kernel\n        # until calibration has been performed.\n        self.is_gptq_calibrated = False\n\n        self.original_kernel_shape = kernel_shape\n        if len(kernel_shape) == 2:\n            rows = kernel_shape[0]\n            columns = kernel_shape[1]\n        elif len(kernel_shape) == 3:\n            shape = 
list(self.original_kernel_shape)\n            d_model_dim_index = shape.index(max(shape))\n\n            if d_model_dim_index == 0:  # QKV projection case\n                in_features, heads, head_dim = shape\n                rows, columns = (\n                    in_features,\n                    heads * head_dim,\n                )\n            elif d_model_dim_index in [1, 2]:  # Attention Output case\n                heads, head_dim, out_features = shape\n                rows, columns = (\n                    heads * head_dim,\n                    out_features,\n                )\n            else:\n                raise ValueError(\"Could not determine row/column split.\")\n        else:\n            raise ValueError(\n                \"GPTQ quantization only supports 2D or 3D kernels.\"\n            )\n\n        group_size = gptq_core.get_group_size_for_layer(self, config)\n        n_groups = 1 if group_size == -1 else math.ceil(rows / group_size)\n\n        self.gptq_unpacked_column_size = columns\n\n        weight_bits = gptq_core.get_weight_bits_for_layer(self, config)\n        # For 4-bit weights, we pack two values per byte.\n        kernel_columns = (columns + 1) // 2 if weight_bits == 4 else columns\n\n        self._set_quantization_info()\n\n        self.quantized_kernel = self.add_weight(\n            name=\"kernel\",\n            shape=(kernel_columns, rows),\n            initializer=\"zeros\",\n            dtype=\"uint8\",\n            trainable=False,\n        )\n\n        self.kernel_scale = self.add_weight(\n            name=\"kernel_scale\",\n            shape=(columns, n_groups),\n            initializer=\"ones\",\n            trainable=False,\n        )\n        self.kernel_zero = self.add_weight(\n            name=\"zero_point\",\n            shape=(columns, n_groups),\n            initializer=\"zeros\",\n            dtype=\"uint8\",\n            trainable=False,\n        )\n\n        self.g_idx = self.add_weight(\n            name=\"g_idx\",\n            shape=(rows,),\n            initializer=\"zeros\",\n            dtype=\"float32\",\n            trainable=False,\n    
    )\n\n    def _gptq_call(self, inputs, training=False):\n        from keras.src.quantizers import gptq_core\n\n        if not self.is_gptq_calibrated:\n            W = self._kernel\n        else:\n            should_unpack = (\n                gptq_core.get_weight_bits_for_layer(self, config=None) == 4\n            )\n            W = (\n                quantizers.unpack_int4(\n                    self.quantized_kernel,\n                    orig_len=self.gptq_unpacked_column_size,\n                    axis=0,\n                    dtype=\"uint8\",\n                )\n                if should_unpack\n                else self.quantized_kernel\n            )\n            W = dequantize_with_sz_map(\n                W,\n                self.kernel_scale,\n                self.kernel_zero,\n                self.g_idx,\n            )\n            W = ops.transpose(W)\n\n            W = ops.reshape(W, self.original_kernel_shape)\n\n        y = ops.einsum(self.equation, inputs, W)\n        if self.bias is not None:\n            y = ops.add(y, self.bias)\n        if self.activation is not None:\n            y = self.activation(y)\n        return y\n\n    def _awq_build(self, kernel_shape, config):\n        \"\"\"Build variables for AWQ quantization.\n\n        AWQ uses 4-bit quantization with per-channel AWQ scales that protect\n        salient weights based on activation magnitudes.\n        \"\"\"\n        from keras.src.quantizers import awq_core\n\n        # Ensures the forward pass uses the original high-precision kernel\n        # until calibration has been performed.\n        self.is_awq_calibrated = False\n\n        self.original_kernel_shape = kernel_shape\n        if len(kernel_shape) == 2:\n            rows = kernel_shape[0]\n            columns = kernel_shape[1]\n        elif len(kernel_shape) == 3:\n            shape = list(self.original_kernel_shape)\n            d_model_dim_index = shape.index(max(shape))\n\n            if d_model_dim_index == 0:  # QKV 
projection case\n                in_features, heads, head_dim = shape\n                rows, columns = (\n                    in_features,\n                    heads * head_dim,\n                )\n            elif d_model_dim_index in [1, 2]:  # Attention Output case\n                heads, head_dim, out_features = shape\n                rows, columns = (\n                    heads * head_dim,\n                    out_features,\n                )\n            else:\n                raise ValueError(\"Could not determine row/column split.\")\n        else:\n            raise ValueError(\"AWQ quantization only supports 2D or 3D kernels.\")\n\n        group_size = awq_core.get_group_size_for_layer(self, config)\n        num_groups = 1 if group_size == -1 else math.ceil(rows / group_size)\n\n        self.awq_unpacked_column_size = columns\n\n        # For 4-bit weights, we pack two values per byte.\n        kernel_columns = (columns + 1) // 2\n\n        self._set_quantization_info()\n\n        self.quantized_kernel = self.add_weight(\n            name=\"kernel\",\n            shape=(kernel_columns, rows),\n            initializer=\"zeros\",\n            dtype=\"uint8\",\n            trainable=False,\n        )\n\n        self.kernel_scale = self.add_weight(\n            name=\"kernel_scale\",\n            shape=(columns, num_groups),\n            initializer=\"ones\",\n            trainable=False,\n        )\n        self.kernel_zero = self.add_weight(\n            name=\"zero_point\",\n            shape=(columns, num_groups),\n            initializer=\"zeros\",\n            dtype=\"uint8\",\n            trainable=False,\n        )\n\n        # Per-channel AWQ scales from activation magnitudes\n        self.awq_scales = self.add_weight(\n            name=\"awq_scales\",\n            shape=(rows,),\n            initializer=\"ones\",\n            trainable=False,\n        )\n\n        self.g_idx = self.add_weight(\n            name=\"g_idx\",\n            
shape=(rows,),\n            initializer="zeros",\n            dtype="float32",\n            trainable=False,\n        )\n\n    def _awq_call(self, inputs, training=False):\n        """Forward pass for AWQ quantized layer."""\n        if not self.is_awq_calibrated:\n            W = self._kernel\n        else:\n            # Unpack 4-bit weights\n            W = quantizers.unpack_int4(\n                self.quantized_kernel,\n                orig_len=self.awq_unpacked_column_size,\n                axis=0,\n                dtype="uint8",\n            )\n            # Dequantize using scale/zero maps\n            W = dequantize_with_sz_map(\n                W,\n                self.kernel_scale,\n                self.kernel_zero,\n                self.g_idx,\n            )\n            W = ops.transpose(W)\n\n            # Apply AWQ scales by dividing to restore original magnitude\n            # (We multiplied by scales before quantization, so divide to undo)\n            # awq_scales has shape [input_dim], W has shape [input_dim, out_dim]\n            # Expand dims for proper broadcasting.\n            W = ops.divide(W, ops.expand_dims(self.awq_scales, -1))\n\n            W = ops.reshape(W, self.original_kernel_shape)\n\n        y = ops.einsum(self.equation, inputs, W)\n        if self.bias is not None:\n            y = ops.add(y, self.bias)\n        if self.activation is not None:\n            y = self.activation(y)\n        return y\n\n    def _int4_build(self, kernel_shape, config=None):\n        """Build variables for int4 quantization.\n\n        The kernel is flattened to 2D [rows, columns] and packed along the\n        last axis to [rows, ceil(columns/2)].\n\n        Args:\n            kernel_shape: Original kernel shape (may be N-dimensional).\n            config: Optional quantization config specifying block_size.\n        """\n        self._set_quantization_info()\n\n        self.inputs_quantizer = (\n            
QuantizationConfig.activation_quantizer_or_default(config, None)\n        )\n        self.quantization_axis = tuple(self._input_reduced_axes)\n        self.original_kernel_shape = kernel_shape\n\n        # Flatten kernel to 2D: rows = reduced dims, columns = non-reduced dims\n        rows = 1\n        columns = 1\n        for i, dim in enumerate(kernel_shape):\n            if i in self._kernel_reduced_axes:\n                rows *= dim\n            else:\n                columns *= dim\n\n        block_size = get_block_size_for_layer(self, config)\n        use_grouped = block_size is not None and block_size != -1\n        self._int4_block_size = block_size if use_grouped else None\n        self._int4_unpacked_column_size = columns\n        self._int4_rows = rows\n\n        # Kernel packed along last axis (columns)\n        # Stored shape: [rows, ceil(columns/2)]\n        packed_cols = (columns + 1) // 2\n        self._kernel = self.add_weight(\n            name=\"kernel\",\n            shape=(rows, packed_cols),\n            initializer=\"zeros\",\n            dtype=\"int8\",\n            trainable=False,\n        )\n\n        if use_grouped:\n            # Sub-channel: [n_groups, columns]\n            n_groups = math.ceil(rows / block_size)\n            scale_shape = (n_groups, columns)\n        else:\n            scale_shape = (columns,)\n\n        self.kernel_scale = self.add_weight(\n            name=\"kernel_scale\",\n            shape=scale_shape,\n            initializer=\"ones\",\n            trainable=False,\n        )\n\n        # Sub-channel quantization uses asymmetric quantization with zero point\n        if use_grouped:\n            self.kernel_zero = self.add_weight(\n                name=\"kernel_zero\",\n                shape=scale_shape,\n                initializer=\"zeros\",\n                dtype=\"int8\",\n                trainable=False,\n            )\n            self.g_idx = self.add_weight(\n                name=\"g_idx\",\n               
 shape=(rows,),\n                initializer="zeros",\n                dtype="float32",\n                trainable=False,\n            )\n            self.g_idx.assign(\n                ops.floor_divide(ops.arange(rows, dtype="float32"), block_size)\n            )\n\n    def _float8_build(self):\n        from keras.src.dtype_policies import QuantizedFloat8DTypePolicy\n\n        # If `self.dtype_policy` is not QuantizedFloat8DTypePolicy, then set\n        # `amax_history_length` to its default value.\n        amax_history_length = getattr(\n            self.dtype_policy,\n            "amax_history_length",\n            QuantizedFloat8DTypePolicy.default_amax_history_length,\n        )\n        # We set `trainable=True` because we will use the gradients to\n        # overwrite these variables.\n        scale_kwargs = {\n            "shape": (),\n            "initializer": "ones",\n            "dtype": "float32",  # Always float32.\n            "trainable": True,\n            "autocast": False,\n            "overwrite_with_gradient": True,\n        }\n        amax_history_kwargs = {\n            "shape": (amax_history_length,),\n            "initializer": "zeros",\n            "dtype": "float32",  # Always float32.\n            "trainable": True,\n            "autocast": False,\n            "overwrite_with_gradient": True,\n        }\n        self.inputs_scale = self.add_weight(name="inputs_scale", **scale_kwargs)\n        self.inputs_amax_history = self.add_weight(\n            name="inputs_amax_history", **amax_history_kwargs\n        )\n        self.kernel_scale = self.add_weight(name="kernel_scale", **scale_kwargs)\n        self.kernel_amax_history = self.add_weight(\n            name="kernel_amax_history", **amax_history_kwargs\n        )\n        self.outputs_grad_scale = self.add_weight(\n            name="outputs_grad_scale", **scale_kwargs\n        )\n        self.outputs_grad_amax_history = 
self.add_weight(\n            name=\"outputs_grad_amax_history\", **amax_history_kwargs\n        )\n\n    def _int8_call(self, inputs, training=None):\n        @ops.custom_gradient\n        def einsum_with_inputs_gradient(inputs, kernel, kernel_scale):\n            \"\"\"Performs int8 quantized einsum with a custom gradient.\n\n            Computes the einsum operation with quantized inputs and a quantized\n            kernel, then de-quantizes the result.\n\n            Also computes the gradient with respect to the original,\n            full-precision inputs by using a de-quantized kernel.\n\n            Args:\n                inputs: The full-precision input tensor.\n                kernel: The int8 quantized kernel tensor.\n                kernel_scale: The float32 scale factor for the kernel.\n\n            Returns:\n                A tuple `(output, grad_fn)`:\n                    `output`: The de-quantized result of the einsum operation.\n                    `grad_fn`: The custom gradient function for the backward\n                        pass.\n\n            Raises:\n                ValueError: If the quantization mode is not supported.\n            \"\"\"\n\n            def grad_fn(*args, upstream=None):\n                if upstream is None:\n                    (upstream,) = args\n                # De-scale kernel\n                _kernel_scale = kernel_scale\n                _kernel_scale = self._adjust_scale_for_dequant(_kernel_scale)\n                float_kernel = ops.divide(\n                    ops.cast(kernel, dtype=self.compute_dtype),\n                    _kernel_scale,\n                )\n                # From https://stackoverflow.com/a/47609896\n                inputs_grad = ops.einsum(\n                    self._custom_gradient_equation, upstream, float_kernel\n                )\n                return (inputs_grad, None, None)\n\n            if self.inputs_quantizer:\n                inputs, inputs_scale = self.inputs_quantizer(\n          
          inputs, axis=self.quantization_axis\n                )\n                # Align `inputs_scale` axes with the output\n                # for correct broadcasting\n                inputs_scale = self._adjust_scale_for_quant(\n                    inputs_scale, \"input\"\n                )\n                x = ops.einsum(self.equation, inputs, kernel)\n                # De-scale outputs\n                x = ops.cast(x, self.compute_dtype)\n                x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))\n            else:\n                # Weight-only quantization: dequantize kernel and use float\n                # einsum. This is a workaround for PyTorch's einsum which\n                # doesn't support mixed-precision inputs (float input,\n                # int8 kernel).\n                if backend.backend() == \"torch\":\n                    kernel_scale = self._adjust_scale_for_dequant(kernel_scale)\n                    float_kernel = ops.divide(\n                        ops.cast(kernel, dtype=self.compute_dtype),\n                        kernel_scale,\n                    )\n                    x = ops.einsum(self.equation, inputs, float_kernel)\n                else:\n                    x = ops.einsum(self.equation, inputs, kernel)\n                    # De-scale outputs\n                    x = ops.cast(x, self.compute_dtype)\n                    x = ops.divide(x, kernel_scale)\n            return x, grad_fn\n\n        x = einsum_with_inputs_gradient(\n            inputs,\n            ops.convert_to_tensor(self._kernel),\n            ops.convert_to_tensor(self.kernel_scale),\n        )\n        if self.lora_enabled:\n            lora_x = ops.einsum(self.equation, inputs, self.lora_kernel_a)\n            lora_x = ops.matmul(lora_x, self.lora_kernel_b)\n            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)\n        if self.bias is not None:\n            x = ops.add(x, self.bias)\n        if self.activation is not None:\n     
       x = self.activation(x)\n        return x\n\n    def _int4_call(self, inputs, training=None):\n        \"\"\"Forward pass for int4 quantized EinsumDense.\n\n        Uses custom gradients to handle quantized weights since autodiff\n        cannot differentiate through int4 operations.\n        \"\"\"\n        block_size = getattr(self, \"_int4_block_size\", None)\n\n        if block_size is None or block_size == -1:\n\n            @ops.custom_gradient\n            def einsum_per_channel_with_inputs_gradient(\n                inputs, packed_kernel, kernel_scale\n            ):\n                \"\"\"Per-channel int4 forward pass with custom gradient.\"\"\"\n                # Unpack: stored as [rows, ceil(columns/2)],\n                # unpack along last axis\n                unpacked_kernel = quantizers.unpack_int4(\n                    packed_kernel,\n                    self._int4_unpacked_column_size,\n                    axis=-1,\n                    dtype=\"int8\",\n                )\n\n                def _dequantize_kernel(unpacked, scale):\n                    # kernel is [rows, columns], scale is [columns]\n                    float_kernel = ops.divide(\n                        ops.cast(unpacked, dtype=self.compute_dtype),\n                        scale,\n                    )\n                    return ops.reshape(float_kernel, self.original_kernel_shape)\n\n                def grad_fn(*args, upstream=None):\n                    if upstream is None:\n                        (upstream,) = args\n                    float_kernel = _dequantize_kernel(\n                        unpacked_kernel, kernel_scale\n                    )\n                    inputs_grad = ops.einsum(\n                        self._custom_gradient_equation, upstream, float_kernel\n                    )\n                    return (inputs_grad, None, None)\n\n                if self.inputs_quantizer:\n                    # Per-channel with input quantization\n                    
float_kernel = _dequantize_kernel(\n                        unpacked_kernel, kernel_scale\n                    )\n                    inputs_q, inputs_scale = self.inputs_quantizer(\n                        inputs, axis=self.quantization_axis\n                    )\n                    inputs_scale = self._adjust_scale_for_quant(\n                        inputs_scale, \"input\"\n                    )\n                    # Cast inputs to float for einsum. This is a workaround\n                    # for PyTorch's einsum which doesn't support\n                    # mixed-precision inputs (int8 input, float kernel).\n                    if backend.backend() == \"torch\":\n                        x = ops.einsum(\n                            self.equation,\n                            ops.cast(inputs_q, self.compute_dtype),\n                            float_kernel,\n                        )\n                        x = ops.divide(x, inputs_scale)\n                    else:\n                        x = ops.einsum(self.equation, inputs_q, float_kernel)\n                        x = ops.cast(x, self.compute_dtype)\n                        x = ops.divide(x, inputs_scale)\n                else:\n                    # Weight-only per-channel quantization\n                    float_kernel = _dequantize_kernel(\n                        unpacked_kernel, kernel_scale\n                    )\n                    x = ops.einsum(self.equation, inputs, float_kernel)\n                return x, grad_fn\n\n            x = einsum_per_channel_with_inputs_gradient(\n                inputs,\n                ops.convert_to_tensor(self._kernel),\n                ops.convert_to_tensor(self.kernel_scale),\n            )\n        else:\n\n            @ops.custom_gradient\n            def einsum_sub_channel_with_inputs_gradient(\n                inputs, packed_kernel, kernel_scale, kernel_zero, g_idx\n            ):\n                \"\"\"Sub-channel int4 forward pass with custom 
gradient.\"\"\"\n                # Unpack: stored as [rows, ceil(columns/2)],\n                # unpack along last axis\n                unpacked_kernel = quantizers.unpack_int4(\n                    packed_kernel,\n                    self._int4_unpacked_column_size,\n                    axis=-1,\n                    dtype=\"int8\",\n                )\n\n                def _dequantize_kernel(unpacked, scale, zero, g_idx_t):\n                    # Dequantize with group_axis=0 since\n                    # scale is [n_groups, columns]\n                    float_kernel = dequantize_with_sz_map(\n                        unpacked, scale, zero, g_idx_t, group_axis=0\n                    )\n                    float_kernel = ops.cast(float_kernel, self.compute_dtype)\n                    return ops.reshape(float_kernel, self.original_kernel_shape)\n\n                def grad_fn(*args, upstream=None):\n                    if upstream is None:\n                        (upstream,) = args\n                    float_kernel = _dequantize_kernel(\n                        unpacked_kernel, kernel_scale, kernel_zero, g_idx\n                    )\n                    inputs_grad = ops.einsum(\n                        self._custom_gradient_equation, upstream, float_kernel\n                    )\n                    return (inputs_grad, None, None, None, None)\n\n                float_kernel = _dequantize_kernel(\n                    unpacked_kernel, kernel_scale, kernel_zero, g_idx\n                )\n                x = ops.einsum(self.equation, inputs, float_kernel)\n                return x, grad_fn\n\n            x = einsum_sub_channel_with_inputs_gradient(\n                inputs,\n                ops.convert_to_tensor(self._kernel),\n                ops.convert_to_tensor(self.kernel_scale),\n                ops.convert_to_tensor(self.kernel_zero),\n                ops.convert_to_tensor(self.g_idx),\n            )\n\n        if self.lora_enabled:\n            lora_x = 
ops.einsum(self.equation, inputs, self.lora_kernel_a)\n            lora_x = ops.matmul(lora_x, self.lora_kernel_b)\n            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)\n\n        # Bias & activation\n        if self.bias is not None:\n            x = ops.add(x, self.bias)\n        if self.activation is not None:\n            x = self.activation(x)\n        return x\n\n    def _float8_call(self, inputs, training=None):\n        if self.lora_enabled:\n            raise NotImplementedError(\n                \"Currently, `_float8_call` doesn't support LoRA\"\n            )\n\n        @ops.custom_gradient\n        def quantized_dequantize_inputs(inputs, scale, amax_history):\n            if training:\n                new_scale = quantizers.compute_float8_scale(\n                    ops.max(amax_history, axis=0),\n                    scale,\n                    ops.cast(\n                        float(ml_dtypes.finfo(\"float8_e4m3fn\").max), \"float32\"\n                    ),\n                )\n                new_amax_history = quantizers.compute_float8_amax_history(\n                    inputs, amax_history\n                )\n            else:\n                new_scale = None\n                new_amax_history = None\n            qdq_inputs = quantizers.quantize_and_dequantize(\n                inputs, scale, \"float8_e4m3fn\", self.compute_dtype\n            )\n\n            def grad(*args, upstream=None, variables=None):\n                if upstream is None:\n                    (upstream,) = args\n                return upstream, new_scale, new_amax_history\n\n            return qdq_inputs, grad\n\n        @ops.custom_gradient\n        def quantized_dequantize_outputs(outputs, scale, amax_history):\n            \"\"\"Quantize-dequantize the output gradient but not the output.\"\"\"\n\n            def grad(*args, upstream=None, variables=None):\n                if upstream is None:\n                    (upstream,) = args\n                
new_scale = quantizers.compute_float8_scale(\n                    ops.max(amax_history, axis=0),\n                    scale,\n                    ops.cast(\n                        float(ml_dtypes.finfo(\"float8_e5m2\").max), \"float32\"\n                    ),\n                )\n                qdq_upstream = quantizers.quantize_and_dequantize(\n                    upstream, scale, \"float8_e5m2\", self.compute_dtype\n                )\n                new_amax_history = quantizers.compute_float8_amax_history(\n                    upstream, amax_history\n                )\n                return qdq_upstream, new_scale, new_amax_history\n\n            return outputs, grad\n\n        x = ops.einsum(\n            self.equation,\n            quantized_dequantize_inputs(\n                inputs,\n                ops.convert_to_tensor(self.inputs_scale),\n                ops.convert_to_tensor(self.inputs_amax_history),\n            ),\n            quantized_dequantize_inputs(\n                ops.convert_to_tensor(self._kernel),\n                ops.convert_to_tensor(self.kernel_scale),\n                ops.convert_to_tensor(self.kernel_amax_history),\n            ),\n        )\n        # `quantized_dequantize_outputs` is placed immediately after\n        # `ops.einsum` for the sake of pattern matching in gemm_rewrite. That\n        # way, the qdq will be adjacent to the corresponding einsum_bprop in the\n        # bprop.\n        x = quantized_dequantize_outputs(\n            x,\n            ops.convert_to_tensor(self.outputs_grad_scale),\n            ops.convert_to_tensor(self.outputs_grad_amax_history),\n        )\n        if self.bias is not None:\n            # Under non-mixed precision cases, F32 bias has to be converted to\n            # BF16 first to get the biasAdd fusion support. ref. 
PR\n            # https://github.com/tensorflow/tensorflow/pull/60306\n            bias = self.bias\n            if self.dtype_policy.compute_dtype == \"float32\":\n                bias_bf16 = ops.cast(bias, \"bfloat16\")\n                bias = ops.cast(bias_bf16, bias.dtype)\n            x = ops.add(x, bias)\n        if self.activation is not None:\n            x = self.activation(x)\n        return x\n\n    def quantize(self, mode=None, type_check=True, config=None):\n        # Prevent quantization of the subclasses\n        if type_check and (type(self) is not EinsumDense):\n            raise self._not_implemented_error(self.quantize)\n\n        self.quantization_config = config\n\n        kernel_shape = self._kernel.shape\n        if mode in (\"int8\", \"int4\", \"gptq\", \"awq\"):\n            self._set_quantization_info()\n\n        if mode == \"int8\":\n            # Quantize `self._kernel` to int8 and compute corresponding scale\n            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(\n                self.quantization_config,\n                quantizers.AbsMaxQuantizer(axis=self._kernel_reduced_axes),\n            )\n            kernel_value, kernel_scale = weight_quantizer(\n                self._kernel, to_numpy=True\n            )\n            kernel_scale = self._adjust_scale_for_quant(kernel_scale, \"kernel\")\n            del self._kernel\n        elif mode == \"int4\":\n            from keras.src.quantizers.quantization_config import (\n                Int4QuantizationConfig,\n            )\n\n            block_size = None\n            if isinstance(self.quantization_config, Int4QuantizationConfig):\n                block_size = self.quantization_config.block_size\n\n            use_grouped = block_size is not None and block_size != -1\n\n            # Flatten kernel to 2D: rows = reduced dims, columns = non-reduced\n            rows = 1\n            columns = 1\n            for i, dim in enumerate(kernel_shape):\n            
    if i in self._kernel_reduced_axes:\n                    rows *= dim\n                else:\n                    columns *= dim\n\n            flat_kernel = ops.reshape(self._kernel, (rows, columns))\n\n            if not use_grouped:\n                # Per-channel quantization\n                kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(\n                    flat_kernel,\n                    axis=0,\n                    value_range=(-8, 7),\n                    dtype=\"int8\",\n                    to_numpy=True,\n                )\n                kernel_scale = ops.squeeze(kernel_scale, axis=0)\n            else:\n                # Sub-channel quantization with asymmetric zero point\n                # Returns kernel [rows, columns], scale [n_groups, columns]\n                kernel_value_int4, kernel_scale, kernel_zero = (\n                    quantizers.abs_max_quantize_grouped_with_zero_point(\n                        flat_kernel, block_size=block_size, to_numpy=True\n                    )\n                )\n\n            # Pack two int4 values per int8 byte along last axis\n            # Stored as [rows, ceil(columns/2)]\n            packed_kernel_value, _, _ = quantizers.pack_int4(\n                kernel_value_int4, axis=-1\n            )\n            kernel_value = packed_kernel_value\n            del self._kernel\n        self.quantized_build(kernel_shape, mode, self.quantization_config)\n\n        # Assign values to the newly created variables.\n        if mode in (\"int8\", \"int4\"):\n            self._kernel.assign(kernel_value)\n            self.kernel_scale.assign(kernel_scale)\n            # Assign zero point for sub-channel int4 quantization\n            if mode == \"int4\" and use_grouped:\n                self.kernel_zero.assign(kernel_zero)\n\n        # Set new dtype policy\n        if self.dtype_policy.quantization_mode is None:\n            policy_name = mode\n            if mode in (\"gptq\", \"awq\"):\n                
policy_name = self.quantization_config.dtype_policy_string()\n            elif mode == "int4":\n                # Include block_size in policy name for sub-channel quantization\n                block_size = get_block_size_for_layer(self, config)\n                # Use -1 for per-channel, otherwise use block_size\n                block_size_value = -1 if block_size is None else block_size\n                policy_name = f"int4/{block_size_value}"\n            policy = dtype_policies.get(\n                f"{policy_name}_from_{self.dtype_policy.name}"\n            )\n            self.dtype_policy = policy\n\n    def _get_kernel_scale_shape(self, kernel_shape, block_size=None):\n        """Get the shape of the kernel scale tensor.\n\n        The kernel scale tensor holds the factors used to dequantize the\n        kernel. For grouped quantization it has the 2D shape\n        (n_groups, non_reduced). For per-channel quantization it starts\n        from the kernel shape with the reduced axes set to 1, then is\n        transposed, expanded, and squeezed to match the layout used when\n        building the layer.\n\n        Args:\n            kernel_shape: The shape of the kernel tensor.\n            block_size: If provided and positive, use grouped quantization\n                along the reduced axes with the specified block size.\n\n        Returns:\n            The shape of the kernel scale tensor.\n        """\n        if block_size is not None and block_size > 0:\n            # Grouped quantization: use simple 2D scale shape\n            # (n_groups, non_reduced) - matches dequantize_grouped format\n            total_reduced_dim = 1\n            for ax in self._kernel_reduced_axes:\n                total_reduced_dim *= kernel_shape[ax]\n            n_groups = math.ceil(total_reduced_dim / block_size)\n\n            total_non_reduced = 1\n            for i, dim in enumerate(kernel_shape):\n                if i not in self._kernel_reduced_axes:\n                    total_non_reduced *= dim\n\n            
return (n_groups, total_non_reduced)\n        else:\n            # Per-channel quantization: use the original transformation logic\n            kernel_scale_shape = np.array(kernel_shape)\n            kernel_scale_shape[self._kernel_reduced_axes] = 1\n\n            kernel_scale_shape = kernel_scale_shape[self._kernel_transpose_axes]\n            kernel_scale_shape = kernel_scale_shape.tolist()\n            for a in sorted(self._kernel_expand_axes):\n                kernel_scale_shape.insert(a, 1)\n            for a in sorted(self._kernel_squeeze_axes, reverse=True):\n                kernel_scale_shape.pop(a)\n            return kernel_scale_shape\n\n    def _get_kernel_with_merged_lora(self):\n        \"\"\"Returns the kernel with LoRA matrices merged, for serialization.\n\n        This method is called by `save_own_variables` to produce a single\n        kernel tensor that includes the adaptations from LoRA. This is useful\n        for deploying the model or for continuing training after permanently\n        applying the LoRA update.\n\n        If the layer is quantized (`int8` or `int4`), the process is:\n        1. Dequantize the base kernel to float.\n        2. Adjust the scale tensor layout for dequantization. This is the\n            reverse order of operations used when building the layer.\n        3. Compute the LoRA delta (`lora_kernel_a @ lora_kernel_b`) and add\n            it to the dequantized kernel.\n        4. Re-quantize the merged result back to the original quantized\n            type (`int8` or packed `int4`), calculating a new scale factor.\n        5. Adjust the scale tensor layout for quantization. 
This is the forward\n            order of operations used when building the layer.\n\n        If the layer is not quantized, this method returns the result of the\n        `kernel` property (which computes the merge in floating-point) and a\n        scale of `None`.\n\n        If LoRA is not enabled, it returns the original kernel and scale\n        without modification.\n\n        Returns:\n            A tuple `(kernel_value, kernel_scale, kernel_zero)`:\n                `kernel_value`: The merged kernel. A quantized tensor if\n                    quantization is active, otherwise a high precision tensor.\n                `kernel_scale`: The quantization scale for the merged kernel.\n                    This is `None` if the layer is not quantized.\n                `kernel_zero`: The zero point for sub-channel int4 quantization.\n                    This is `None` for per-channel or non-int4 modes.\n        \"\"\"\n        # If not a quantized layer, return the full-precision kernel directly.\n        if self.dtype_policy.quantization_mode in (None, \"gptq\", \"awq\"):\n            return self.kernel, None, None\n\n        kernel_zero = getattr(self, \"kernel_zero\", None)\n\n        # If quantized but LoRA is not enabled, return the original quantized\n        # kernel.\n        if not self.lora_enabled:\n            return self._kernel, self.kernel_scale, kernel_zero\n\n        # Dequantize, Merge, and Re-quantize\n\n        # 1. 
Dequantize the kernel\n        if self.quantization_mode == \"int4\":\n            # Unpack [rows, ceil(columns/2)] to [rows, columns]\n            unpacked_kernel = quantizers.unpack_int4(\n                self._kernel,\n                self._int4_unpacked_column_size,\n                axis=-1,\n            )\n            block_size = getattr(self, \"_int4_block_size\", None)\n            if block_size is not None and block_size != -1:\n                # Grouped dequantization with group_axis=0\n                kernel_fp = dequantize_with_sz_map(\n                    unpacked_kernel,\n                    self.kernel_scale,\n                    self.kernel_zero,\n                    self.g_idx,\n                    group_axis=0,\n                )\n            else:\n                # Per-channel dequantization:\n                # kernel [rows, columns], scale [columns]\n                kernel_fp = ops.divide(\n                    ops.cast(unpacked_kernel, self.compute_dtype),\n                    self.kernel_scale,\n                )\n            kernel_fp = ops.reshape(kernel_fp, self.original_kernel_shape)\n        elif self.quantization_mode == \"int8\":\n            adjusted_scale = self._adjust_scale_for_dequant(self.kernel_scale)\n            kernel_fp = ops.divide(self._kernel, adjusted_scale)\n        else:\n            raise ValueError(\n                f\"Unsupported quantization mode: {self.quantization_mode}\"\n            )\n\n        # 2. Merge the LoRA update in the float domain\n        lora_update = (self.lora_alpha / self.lora_rank) * ops.matmul(\n            self.lora_kernel_a, self.lora_kernel_b\n        )\n        merged_kernel = ops.add(kernel_fp, lora_update)\n\n        # 3. 
Re-quantize the merged float kernel back to the target format\n        if self.quantization_mode == \"int4\":\n            block_size = getattr(self, \"_int4_block_size\", None)\n            rows = self._int4_rows\n            columns = self._int4_unpacked_column_size\n\n            # Flatten to 2D [rows, columns]\n            flat_kernel = ops.reshape(merged_kernel, (rows, columns))\n\n            if block_size is not None and block_size != -1:\n                # Use abs_max_quantize_grouped_with_zero_point for proper\n                # signed quantization (same as quantize() method)\n                # Returns kernel [rows, columns], scale [n_groups, columns]\n                kernel_quant, new_scale, new_zero = (\n                    quantizers.abs_max_quantize_grouped_with_zero_point(\n                        flat_kernel, block_size=block_size, to_numpy=True\n                    )\n                )\n                kernel_zero = new_zero\n            else:\n                # Per-channel: quantize along rows axis\n                kernel_quant, new_scale = quantizers.abs_max_quantize(\n                    flat_kernel,\n                    axis=0,\n                    value_range=(-8, 7),\n                    dtype=\"int8\",\n                    to_numpy=True,\n                )\n                new_scale = ops.squeeze(new_scale, axis=0)\n                kernel_zero = None\n\n            # Pack along last axis\n            new_kernel, _, _ = quantizers.pack_int4(kernel_quant, axis=-1)\n        elif self.quantization_mode == \"int8\":\n            new_kernel, new_scale = quantizers.abs_max_quantize(\n                merged_kernel,\n                axis=self._kernel_reduced_axes,\n                to_numpy=True,\n            )\n            new_scale = self._adjust_scale_for_quant(new_scale, \"kernel\")\n            kernel_zero = None\n\n        return new_kernel, new_scale, kernel_zero\n\n    def _adjust_scale_for_dequant(self, scale):\n        \"\"\"Adjusts scale 
tensor layout for dequantization.\n\n        Helper method to handle scale adjustments before dequantization.\n        This is the reverse order of operations used when building the layer.\n\n        Args:\n            scale: The scale tensor to adjust.\n\n        Returns:\n            The adjusted scale tensor.\n        \"\"\"\n        if self._kernel_squeeze_axes:\n            scale = ops.expand_dims(scale, axis=self._kernel_squeeze_axes)\n        if self._kernel_expand_axes:\n            scale = ops.squeeze(scale, axis=self._kernel_expand_axes)\n        if self._kernel_transpose_axes:\n            # We need to reverse the transpose operation.\n            reverse_transpose = sorted(\n                range(len(self._kernel_transpose_axes)),\n                key=self._kernel_transpose_axes.__getitem__,\n            )\n            scale = ops.transpose(scale, axes=reverse_transpose)\n        return scale\n\n    def _adjust_scale_for_quant(self, scale, tensor_type=\"kernel\"):\n        \"\"\"Adjusts scale tensor layout after quantization.\n\n        Helper method to handle scale adjustments after re-quantization.\n        This is the forward order of operations used when building the layer.\n\n        Args:\n            scale: The scale tensor to adjust.\n            tensor_type: The type of tensor to adjust the scale for.\n                \"kernel\" or \"input\".\n        Returns:\n            The adjusted scale tensor.\n        \"\"\"\n        if tensor_type == \"kernel\":\n            transpose_axes = self._kernel_transpose_axes\n            expand_axes = self._kernel_expand_axes\n            squeeze_axes = self._kernel_squeeze_axes\n        elif tensor_type == \"input\":\n            transpose_axes = self._input_transpose_axes\n            expand_axes = self._input_expand_axes\n            squeeze_axes = self._input_squeeze_axes\n        else:\n            raise ValueError(f\"Invalid tensor type: {tensor_type}\")\n\n        if transpose_axes:\n            scale 
= ops.transpose(scale, transpose_axes)\n        if expand_axes:\n            scale = ops.expand_dims(scale, axis=expand_axes)\n        if squeeze_axes:\n            scale = ops.squeeze(scale, axis=squeeze_axes)\n        return scale\n\n    def _set_quantization_info(self):\n        if hasattr(self, \"_input_reduced_axes\"):\n            # Already set.\n            return\n        (\n            self._input_reduced_axes,\n            self._kernel_reduced_axes,\n            self._input_transpose_axes,\n            self._kernel_transpose_axes,\n            self._input_expand_axes,\n            self._kernel_expand_axes,\n            self._input_squeeze_axes,\n            self._kernel_squeeze_axes,\n            self._custom_gradient_equation,\n            self._kernel_reverse_transpose_axes,\n        ) = _analyze_quantization_info(self.equation, self.input_spec.ndim)\n\n\ndef _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):\n    \"\"\"Parses an einsum string to determine the shapes of the weights.\n\n    This function is the main entry point for analyzing the einsum equation.\n    It handles equations with and without ellipses (`...`) by converting them\n    to a standard format and then delegating to `_analyze_split_string` for\n    the core logic.\n\n    Args:\n        equation: The einsum equation string, e.g., \"ab,bc->ac\" or\n            \"...ab,bc->...ac\".\n        bias_axes: A string indicating which output axes to apply a bias to.\n        input_shape: The shape of the input tensor.\n        output_shape: The user-specified shape of the output tensor (may be\n            partial).\n\n    Returns:\n        A tuple `(kernel_shape, bias_shape, full_output_shape)` where:\n            `kernel_shape`: The calculated shape of the einsum kernel.\n            `bias_shape`: The calculated shape of the bias, or `None`.\n            `full_output_shape`: The fully-resolved shape of the output tensor.\n\n    Raises:\n        ValueError: If the einsum 
`equation` is not in a supported format.\n    \"\"\"\n\n    dot_replaced_string = re.sub(r\"\\.\\.\\.\", \"0\", equation)\n\n    # This is the case where no ellipses are present in the string.\n    split_string = re.match(\n        \"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)\", dot_replaced_string\n    )\n    if split_string:\n        return _analyze_split_string(\n            split_string, bias_axes, input_shape, output_shape\n        )\n\n    # This is the case where ellipses are present on the left.\n    split_string = re.match(\n        \"0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)\", dot_replaced_string\n    )\n    if split_string:\n        return _analyze_split_string(\n            split_string, bias_axes, input_shape, output_shape, left_elided=True\n        )\n\n    # This is the case where ellipses are present on the right.\n    split_string = re.match(\n        \"([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0\", dot_replaced_string\n    )\n    if split_string:\n        return _analyze_split_string(\n            split_string, bias_axes, input_shape, output_shape\n        )\n\n    raise ValueError(\n        f\"Invalid einsum equation '{equation}'. 
Equations must be in the form \"\n        \"[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....\"\n    )\n\n\ndef _analyze_split_string(\n    split_string, bias_axes, input_shape, output_shape, left_elided=False\n):\n    \"\"\"Computes kernel and bias shapes from a parsed einsum equation.\n\n    This function takes the components of an einsum equation, validates them,\n    and calculates the required shapes for the kernel and bias weights.\n\n    Args:\n        split_string: A regex match object containing the input, weight, and\n            output specifications.\n        bias_axes: A string indicating which output axes to apply a bias to.\n        input_shape: The shape of the input tensor.\n        output_shape: The user-specified partial shape of the output tensor.\n        left_elided: A boolean indicating if the ellipsis \"...\" was on the\n            left side of the equation.\n\n    Returns:\n        A tuple `(kernel_shape, bias_shape, full_output_shape)` where:\n            `kernel_shape`: The calculated shape of the einsum kernel.\n            `bias_shape`: The calculated shape of the bias, or `None`.\n            `full_output_shape`: The fully-resolved shape of the output tensor.\n\n    Raises:\n        ValueError: If there are inconsistencies between the input and output\n            shapes or if the equation specifications are invalid.\n    \"\"\"\n    input_spec = split_string.group(1)\n    weight_spec = split_string.group(2)\n    output_spec = split_string.group(3)\n    elided = len(input_shape) - len(input_spec)\n\n    if isinstance(output_shape, int):\n        output_shape = [output_shape]\n    else:\n        output_shape = list(output_shape)\n\n    output_shape.insert(0, input_shape[0])\n\n    if elided > 0 and left_elided:\n        for i in range(1, elided):\n            # We already inserted the 0th input dimension at dim 0, so we need\n            # to start at location 1 here.\n            output_shape.insert(1, input_shape[i])\n    elif 
elided > 0 and not left_elided:\n        for i in range(len(input_shape) - elided, len(input_shape)):\n            output_shape.append(input_shape[i])\n\n    if left_elided:\n        # If we have beginning dimensions elided, we need to use negative\n        # indexing to determine where in the input dimension our values are.\n        input_dim_map = {\n            dim: (i + elided) - len(input_shape)\n            for i, dim in enumerate(input_spec)\n        }\n        # Because we've constructed the full output shape already, we don't need\n        # to do negative indexing.\n        output_dim_map = {\n            dim: (i + elided) for i, dim in enumerate(output_spec)\n        }\n    else:\n        input_dim_map = {dim: i for i, dim in enumerate(input_spec)}\n        output_dim_map = {dim: i for i, dim in enumerate(output_spec)}\n\n    for dim in input_spec:\n        input_shape_at_dim = input_shape[input_dim_map[dim]]\n        if dim in output_dim_map:\n            output_shape_at_dim = output_shape[output_dim_map[dim]]\n            if (\n                output_shape_at_dim is not None\n                and output_shape_at_dim != input_shape_at_dim\n            ):\n                raise ValueError(\n                    \"Input shape and output shape do not match at shared \"\n                    f\"dimension '{dim}'. 
Input shape is {input_shape_at_dim}, \"\n                \"and output shape \"\n                f\"is {output_shape[output_dim_map[dim]]}.\"\n                )\n\n    for dim in output_spec:\n        if dim not in input_spec and dim not in weight_spec:\n            raise ValueError(\n                f\"Dimension '{dim}' was specified in the output \"\n                f\"'{output_spec}' but has no corresponding dim in the input \"\n                f\"spec '{input_spec}' or weight spec '{weight_spec}'\"\n            )\n\n    weight_shape = []\n    for dim in weight_spec:\n        if dim in input_dim_map:\n            weight_shape.append(input_shape[input_dim_map[dim]])\n        elif dim in output_dim_map:\n            weight_shape.append(output_shape[output_dim_map[dim]])\n        else:\n            raise ValueError(\n                f\"Weight dimension '{dim}' did not have a match in either \"\n                f\"the input spec '{input_spec}' or the output \"\n                f\"spec '{output_spec}'. 
For this layer, the weight must \"\n                \"be fully specified.\"\n            )\n\n    if bias_axes is not None:\n        num_left_elided = elided if left_elided else 0\n        idx_map = {\n            char: output_shape[i + num_left_elided]\n            for i, char in enumerate(output_spec)\n        }\n\n        for char in bias_axes:\n            if char not in output_spec:\n                raise ValueError(\n                    f\"Bias dimension '{char}' was requested, but is not part \"\n                    f\"of the output spec '{output_spec}'\"\n                )\n\n        first_bias_location = min(\n            [output_spec.find(char) for char in bias_axes]\n        )\n        bias_output_spec = output_spec[first_bias_location:]\n\n        bias_shape = [\n            idx_map[char] if char in bias_axes else 1\n            for char in bias_output_spec\n        ]\n\n        if not left_elided:\n            for _ in range(elided):\n                bias_shape.append(1)\n    else:\n        bias_shape = None\n\n    return weight_shape, bias_shape, output_shape\n\n\ndef _analyze_quantization_info(equation, input_shape):\n    \"\"\"Analyzes an einsum equation to derive information for quantization.\n\n    This function canonicalizes the einsum equation (handling ellipses) and\n    determines the necessary tensor manipulations (reduction, transposition,\n    expansion, squeezing) required to correctly apply per-axis quantization\n    to the inputs and kernel. 
It also derives the einsum equation needed for\n    the custom gradient.\n\n    Args:\n        equation: The einsum equation string.\n        input_shape: The shape of the input tensor.\n\n    Returns:\n        A tuple containing metadata for quantization operations:\n        `input_reduced_axes`: Axes to reduce for input quantization.\n        `kernel_reduced_axes`: Axes to reduce for kernel quantization.\n        `input_transpose_axes`: Permutation for transposing the input scale.\n        `kernel_transpose_axes`: Permutation for transposing the kernel scale.\n        `input_expand_axes`: Axes to expand for the input scale.\n        `kernel_expand_axes`: Axes to expand for the kernel scale.\n        `input_squeeze_axes`: Axes to squeeze from the input scale.\n        `kernel_squeeze_axes`: Axes to squeeze from the kernel scale.\n        `custom_gradient_equation`: Einsum equation for the backward pass.\n        `kernel_reverse_transpose_axes`: Permutation to reverse the kernel\n            scale transpose.\n    \"\"\"\n\n    def get_specs(equation, input_shape):\n        possible_labels = string.ascii_letters\n        dot_replaced_string = re.sub(r\"\\.\\.\\.\", \"0\", equation)\n\n        # This is the case where no ellipses are present in the string.\n        split_string = re.match(\n            \"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)\", dot_replaced_string\n        )\n        if split_string is not None:\n            input_spec = split_string.group(1)\n            weight_spec = split_string.group(2)\n            output_spec = split_string.group(3)\n            return input_spec, weight_spec, output_spec\n\n        # This is the case where ellipses are present on the left.\n        split_string = re.match(\n            \"0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)\", dot_replaced_string\n        )\n        if split_string is not None:\n            input_spec = split_string.group(1)\n            weight_spec = split_string.group(2)\n            output_spec = 
split_string.group(3)\n            elided = len(input_shape) - len(input_spec)\n            possible_labels = sorted(\n                set(possible_labels)\n                - set(input_spec)\n                - set(weight_spec)\n                - set(output_spec)\n            )\n            # Pad labels on the left to `input_spec` and `output_spec`\n            for i in range(elided):\n                input_spec = possible_labels[i] + input_spec\n                output_spec = possible_labels[i] + output_spec\n            return input_spec, weight_spec, output_spec\n\n        # This is the case where ellipses are present on the right.\n        split_string = re.match(\n            \"([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0\", dot_replaced_string\n        )\n        if split_string is not None:\n            input_spec = split_string.group(1)\n            weight_spec = split_string.group(2)\n            output_spec = split_string.group(3)\n            elided = len(input_shape) - len(input_spec)\n            possible_labels = sorted(\n                set(possible_labels)\n                - set(input_spec)\n                - set(weight_spec)\n                - set(output_spec)\n            )\n            # Pad labels on the right to `input_spec` and `output_spec`\n            for i in range(elided):\n                input_spec = input_spec + possible_labels[i]\n                output_spec = output_spec + possible_labels[i]\n            return input_spec, weight_spec, output_spec\n\n        raise ValueError(\n            f\"Invalid einsum equation '{equation}'. 
Equations must be in the \"\n            \"form [X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]....\"\n        )\n\n    input_spec, weight_spec, output_spec = get_specs(equation, input_shape)\n\n    # Determine the axes that should be reduced by the quantizer\n    input_reduced_axes = []\n    weight_reduced_axes = []\n    for i, label in enumerate(input_spec):\n        index = output_spec.find(label)\n        if index == -1:\n            input_reduced_axes.append(i)\n    for i, label in enumerate(weight_spec):\n        index = output_spec.find(label)\n        if index == -1:\n            weight_reduced_axes.append(i)\n\n    # Determine the axes of `ops.expand_dims`\n    input_expand_axes = []\n    weight_expand_axes = []\n    for i, label in enumerate(output_spec):\n        index_input = input_spec.find(label)\n        index_weight = weight_spec.find(label)\n        if index_input == -1:\n            input_expand_axes.append(i)\n        if index_weight == -1:\n            weight_expand_axes.append(i)\n\n    # Determine the axes of `ops.transpose`\n    input_transpose_axes = []\n    weight_transpose_axes = []\n    for i, label in enumerate(output_spec):\n        index_input = input_spec.find(label)\n        index_weight = weight_spec.find(label)\n        if index_input != -1:\n            input_transpose_axes.append(index_input)\n        if index_weight != -1:\n            weight_transpose_axes.append(index_weight)\n    # Postprocess the information:\n    # 1. Add dummy axes (1) to transpose_axes\n    # 2. Add axis to squeeze_axes if 1. 
failed\n    input_squeeze_axes = []\n    weight_squeeze_axes = []\n    for ori_index in input_reduced_axes:\n        try:\n            index = input_expand_axes.pop(0)\n        except IndexError:\n            input_squeeze_axes.append(ori_index)\n        input_transpose_axes.insert(index, ori_index)\n    for ori_index in weight_reduced_axes:\n        try:\n            index = weight_expand_axes.pop(0)\n        except IndexError:\n            weight_squeeze_axes.append(ori_index)\n        weight_transpose_axes.insert(index, ori_index)\n    # Prepare equation for `einsum_with_inputs_gradient`\n    custom_gradient_equation = f\"{output_spec},{weight_spec}->{input_spec}\"\n    weight_reverse_transpose_axes = [\n        i\n        for (_, i) in sorted(\n            (v, i) for (i, v) in enumerate(weight_transpose_axes)\n        )\n    ]\n    return (\n        input_reduced_axes,\n        weight_reduced_axes,\n        input_transpose_axes,\n        weight_transpose_axes,\n        input_expand_axes,\n        weight_expand_axes,\n        input_squeeze_axes,\n        weight_squeeze_axes,\n        custom_gradient_equation,\n        weight_reverse_transpose_axes,\n    )\n"
  },
  {
    "path": "keras/src/layers/core/einsum_dense_test.py",
    "content": "import math\nimport os\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import export\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src import quantizers\nfrom keras.src import random\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.quantizers.awq_config import AWQConfig\nfrom keras.src.quantizers.gptq_config import GPTQConfig\nfrom keras.src.quantizers.quantization_config import Int4QuantizationConfig\nfrom keras.src.quantizers.quantization_config import Int8QuantizationConfig\nfrom keras.src.quantizers.quantizers import AbsMaxQuantizer\nfrom keras.src.saving.saving_api import load_model\n\n\nclass EinsumDenseTest(testing.TestCase):\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\", {\"axis\": 0}, {\"axis\": -1}),\n        (\n            \"int4\",\n            \"int4\",\n            {\"axis\": 0, \"value_range\": (-8, 7), \"output_dtype\": \"int8\"},\n            {\"axis\": -1},\n        ),\n        (\"int8_weight_only\", \"int8\", {\"axis\": 0}, None),\n        (\n            \"int4_weight_only\",\n            \"int4\",\n            {\"axis\": 0, \"value_range\": (-8, 7), \"output_dtype\": \"int8\"},\n            None,\n        ),\n    )\n    def test_einsum_dense_quantize(\n        self, mode, weight_quantizer_args, activation_quantizer_args\n    ):\n        \"\"\"Test EinsumDense quantization with QuantizationConfig.\"\"\"\n        layer = layers.EinsumDense(\n            equation=\"ab,bcd->acd\",\n            output_shape=(8, 32),\n            bias_axes=\"d\",\n        )\n        layer.build((None, 3))\n\n        weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args)\n        if activation_quantizer_args is not None:\n            activation_quantizer = AbsMaxQuantizer(**activation_quantizer_args)\n   
     else:\n            activation_quantizer = None\n\n        if mode == \"int8\":\n            config = Int8QuantizationConfig(\n                weight_quantizer=weight_quantizer,\n                activation_quantizer=activation_quantizer,\n            )\n        elif mode == \"int4\":\n            # Custom quantizers require per-channel mode (block_size=None)\n            config = Int4QuantizationConfig(\n                weight_quantizer=weight_quantizer,\n                activation_quantizer=activation_quantizer,\n                block_size=None,\n            )\n\n        layer.quantize(mode, config=config)\n\n        if activation_quantizer_args is not None:\n            # Verify inputs_quantizer is set correctly\n            self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer)\n        else:\n            # Verify inputs_quantizer is None\n            self.assertIsNone(layer.inputs_quantizer)\n\n        # Verify call works\n        x = np.random.random((2, 3)).astype(\"float32\")\n        y = layer(x)\n        self.assertEqual(y.shape, (2, 8, 32))\n\n        if mode == \"int4\":\n            # Verify kernel is int8 (packed int4)\n            self.assertEqual(\n                backend.standardize_dtype(layer._kernel.dtype), \"int8\"\n            )\n\n    @parameterized.named_parameters(\n        {\n            \"testcase_name\": \"_1d_end_weight\",\n            \"equation\": \"ab,b->a\",\n            \"bias_axes\": None,\n            \"input_shape\": (2, 32),\n            \"output_shape\": (),\n            \"expected_kernel_shape\": (32,),\n            \"expected_bias_shape\": None,\n            \"expected_output_shape\": (2,),\n        },\n        {\n            \"testcase_name\": \"_2d_middle_weight\",\n            \"equation\": \"ab,bc->ac\",\n            \"bias_axes\": None,\n            \"input_shape\": (2, 32),\n            \"output_shape\": (64),\n            \"expected_kernel_shape\": (32, 64),\n            \"expected_bias_shape\": None,\n     
       \"expected_output_shape\": (2, 64),\n        },\n        {\n            \"testcase_name\": \"_3d_bert\",\n            \"equation\": \"abc,cde->abde\",\n            \"bias_axes\": None,\n            \"input_shape\": (2, 1, 2),\n            \"output_shape\": (1, 3, 4),\n            \"expected_kernel_shape\": (2, 3, 4),\n            \"expected_bias_shape\": None,\n            \"expected_output_shape\": (2, 1, 3, 4),\n        },\n        {\n            \"testcase_name\": \"_3d_3_bias\",\n            \"equation\": \"abc,cde->abde\",\n            \"bias_axes\": \"e\",\n            \"input_shape\": (2, 1, 2),\n            \"output_shape\": (1, 3, 4),\n            \"expected_kernel_shape\": (2, 3, 4),\n            \"expected_bias_shape\": (4,),\n            \"expected_output_shape\": (2, 1, 3, 4),\n        },\n        {\n            \"testcase_name\": \"_3d_2_bias\",\n            \"equation\": \"abc,cde->abde\",\n            \"bias_axes\": \"d\",\n            \"input_shape\": (2, 1, 2),\n            \"output_shape\": (1, 3, 4),\n            \"expected_kernel_shape\": (2, 3, 4),\n            \"expected_bias_shape\": (3, 1),\n            \"expected_output_shape\": (2, 1, 3, 4),\n        },\n        {\n            \"testcase_name\": \"_3d_1_3_bias\",\n            \"equation\": \"abc,cde->abde\",\n            \"bias_axes\": \"be\",\n            \"input_shape\": (2, 7, 2),\n            \"output_shape\": (7, 3, 4),\n            \"expected_kernel_shape\": (2, 3, 4),\n            \"expected_bias_shape\": (7, 1, 4),\n            \"expected_output_shape\": (2, 7, 3, 4),\n        },\n        {\n            \"testcase_name\": \"_3d_bert_projection\",\n            \"equation\": \"BFNH,NHD->BFD\",\n            \"bias_axes\": None,\n            \"input_shape\": (2, 1, 2, 3),\n            \"output_shape\": (1, 4),\n            \"expected_kernel_shape\": (2, 3, 4),\n            \"expected_bias_shape\": None,\n            \"expected_output_shape\": (2, 1, 4),\n        },\n        {\n 
           \"testcase_name\": \"_2d_bert\",\n            \"equation\": \"abc,cd->abd\",\n            \"bias_axes\": None,\n            \"input_shape\": (2, 1, 2),\n            \"output_shape\": (1, 4),\n            \"expected_kernel_shape\": (2, 4),\n            \"expected_bias_shape\": None,\n            \"expected_output_shape\": (2, 1, 4),\n        },\n        {\n            \"testcase_name\": \"_embedding_1d\",\n            \"equation\": \"i,d->id\",\n            \"bias_axes\": None,\n            \"input_shape\": (2,),\n            \"output_shape\": (2,),\n            \"expected_kernel_shape\": (2,),\n            \"expected_bias_shape\": None,\n            \"expected_output_shape\": (2, 2),\n        },\n        {\n            \"testcase_name\": \"_xlnet_lm\",\n            \"equation\": \"ibd,nd->ibn\",\n            \"bias_axes\": None,\n            \"input_shape\": (2, 2, 1),\n            \"output_shape\": (2, 2),\n            \"expected_kernel_shape\": (2, 1),\n            \"expected_bias_shape\": None,\n            \"expected_output_shape\": (2, 2, 2),\n        },\n        {\n            \"testcase_name\": \"_2d_precast\",\n            \"equation\": \"...b,bc->...c\",\n            \"bias_axes\": None,\n            \"input_shape\": (2, 32),\n            \"output_shape\": (64,),\n            \"expected_kernel_shape\": (32, 64),\n            \"expected_bias_shape\": None,\n            \"expected_output_shape\": (2, 64),\n        },\n        {\n            \"testcase_name\": \"_2d_precast_elided_input_used_in_output\",\n            \"equation\": \"...bc,bc->...b\",\n            \"bias_axes\": None,\n            \"input_shape\": (2, 32, 64),\n            \"output_shape\": (32,),\n            \"expected_kernel_shape\": (32, 64),\n            \"expected_bias_shape\": None,\n            \"expected_output_shape\": (2, 32),\n        },\n        {\n            \"testcase_name\": \"_2d_precast_multiple_elided_dims\",\n            \"equation\": \"...b,bc->...c\",\n        
    \"bias_axes\": None,\n            \"input_shape\": (2, 3, 32),\n            \"output_shape\": (64,),\n            \"expected_kernel_shape\": (32, 64),\n            \"expected_bias_shape\": None,\n            \"expected_output_shape\": (2, 3, 64),\n        },\n        {\n            \"testcase_name\": \"_3d_precast\",\n            \"equation\": \"...c,cde->...de\",\n            \"bias_axes\": None,\n            \"input_shape\": (2, 1, 2),\n            \"output_shape\": (3, 4),\n            \"expected_kernel_shape\": (2, 3, 4),\n            \"expected_bias_shape\": None,\n            \"expected_output_shape\": (2, 1, 3, 4),\n        },\n        {\n            \"testcase_name\": \"_3d_precast_3_bias\",\n            \"equation\": \"...c,cde->...de\",\n            \"bias_axes\": \"e\",\n            \"input_shape\": (2, 1, 2),\n            \"output_shape\": (3, 4),\n            \"expected_kernel_shape\": (2, 3, 4),\n            \"expected_bias_shape\": (4,),\n            \"expected_output_shape\": (2, 1, 3, 4),\n        },\n        {\n            \"testcase_name\": \"_3d_precast_2_bias\",\n            \"equation\": \"...c,cde->...de\",\n            \"bias_axes\": \"d\",\n            \"input_shape\": (2, 1, 2),\n            \"output_shape\": (3, 4),\n            \"expected_kernel_shape\": (2, 3, 4),\n            \"expected_bias_shape\": (3, 1),\n            \"expected_output_shape\": (2, 1, 3, 4),\n        },\n        {\n            \"testcase_name\": \"_3d_precast_2_3_bias\",\n            \"equation\": \"...c,cde->...de\",\n            \"bias_axes\": \"de\",\n            \"input_shape\": (2, 1, 2),\n            \"output_shape\": (3, 4),\n            \"expected_kernel_shape\": (2, 3, 4),\n            \"expected_bias_shape\": (3, 4),\n            \"expected_output_shape\": (2, 1, 3, 4),\n        },\n        {\n            \"testcase_name\": \"_2d_postcast\",\n            \"equation\": \"bc...,cd->bd...\",\n            \"bias_axes\": None,\n            \"input_shape\": 
(2, 1, 2, 3),\n            \"output_shape\": (4,),\n            \"expected_kernel_shape\": (1, 4),\n            \"expected_bias_shape\": None,\n            \"expected_output_shape\": (2, 4, 2, 3),\n        },\n        {\n            \"testcase_name\": \"_3d_postcast\",\n            \"equation\": \"bc...,cde->bde...\",\n            \"bias_axes\": None,\n            \"input_shape\": (2, 1, 2),\n            \"output_shape\": (3, 4),\n            \"expected_kernel_shape\": (1, 3, 4),\n            \"expected_bias_shape\": None,\n            \"expected_output_shape\": (2, 3, 4, 2),\n        },\n        {\n            \"testcase_name\": \"_3d_postcast_1_bias\",\n            \"equation\": \"bc...,cde->bde...\",\n            \"bias_axes\": \"d\",\n            \"input_shape\": (2, 1, 2),\n            \"output_shape\": (3, 4),\n            \"expected_kernel_shape\": (1, 3, 4),\n            \"expected_bias_shape\": (3, 1, 1),\n            \"expected_output_shape\": (2, 3, 4, 2),\n        },\n        {\n            \"testcase_name\": \"_3d_postcast_2_bias\",\n            \"equation\": \"bc...,cde->bde...\",\n            \"bias_axes\": \"e\",\n            \"input_shape\": (2, 1, 2),\n            \"output_shape\": (3, 4),\n            \"expected_kernel_shape\": (1, 3, 4),\n            \"expected_bias_shape\": (4, 1),\n            \"expected_output_shape\": (2, 3, 4, 2),\n        },\n        {\n            \"testcase_name\": \"_3d_postcast_1_2_bias\",\n            \"equation\": \"bc...,cde->bde...\",\n            \"bias_axes\": \"de\",\n            \"input_shape\": (2, 1, 2),\n            \"output_shape\": (3, 4),\n            \"expected_kernel_shape\": (1, 3, 4),\n            \"expected_bias_shape\": (3, 4, 1),\n            \"expected_output_shape\": (2, 3, 4, 2),\n        },\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_einsum_dense_basics(\n        self,\n        equation,\n        bias_axes,\n        input_shape,\n        output_shape,\n        
expected_kernel_shape,
        expected_bias_shape,
        expected_output_shape,
    ):
        self.run_layer_test(
            layers.EinsumDense,
            init_kwargs={
                "equation": equation,
                "output_shape": output_shape,
                "bias_axes": bias_axes,
            },
            input_shape=input_shape,
            expected_output_shape=expected_output_shape,
            expected_num_trainable_weights=(
                2 if expected_bias_shape is not None else 1
            ),
            expected_num_non_trainable_weights=0,
            expected_num_seed_generators=0,
            expected_num_losses=0,
            supports_masking=False,
        )
        layer = layers.EinsumDense(
            equation, output_shape=output_shape, bias_axes=bias_axes
        )
        layer.build(input_shape)
        self.assertEqual(layer.kernel.shape, expected_kernel_shape)
        if expected_bias_shape is not None:
            self.assertEqual(layer.bias.shape, expected_bias_shape)

    def test_einsum_dense_constraints(self):
        layer = layers.EinsumDense(
            "abc,cde->abde", (1, 3, 4), kernel_constraint="non_neg"
        )
        layer.build((2, 1, 2))
        self.assertIsInstance(layer.kernel.constraint, constraints.NonNeg)
        layer = layers.EinsumDense(
            "ab,b->a", (1, 3, 4), bias_axes="a", bias_constraint="non_neg"
        )
        layer.build((2, 1, 2))
        self.assertIsInstance(layer.bias.constraint, constraints.NonNeg)

    @pytest.mark.requires_trainable_backend
    def test_enable_lora(self):
        layer = layers.EinsumDense(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes=None,
        )
        layer.build((None, 3))
        layer.enable_lora(2)
        self.assertLen(layer.trainable_weights, 2)
        self.assertLen(layer.non_trainable_weights, 1)
        if backend.backend() == "torch":
            self.assertLen(layer.torch_params, 3)
        # Try eager call
        x = np.random.random((64, 3))
        y = np.random.random((64, 8, 32))
        _ = layer(x[:2])

        init_lora_a_kernel_value = layer.lora_kernel_a.numpy()
        init_lora_b_kernel_value = layer.lora_kernel_b.numpy()

        # Try calling fit()
        model = models.Sequential(
            [
                layer,
            ]
        )
        model.compile(optimizer="sgd", loss="mse")
        model.fit(x, y, epochs=2)

        final_lora_a_kernel_value = layer.lora_kernel_a.numpy()
        final_lora_b_kernel_value = layer.lora_kernel_b.numpy()
        diff_a = np.max(
            np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value)
        )
        diff_b = np.max(
            np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value)
        )
        self.assertGreater(diff_a, 0.0)
        self.assertGreater(diff_b, 0.0)

        # Try saving and reloading the model
        temp_filepath = os.path.join(self.get_temp_dir(), "lora_model.keras")
        model.save(temp_filepath)

        new_model = saving.load_model(temp_filepath)
        self.assertTrue(new_model.layers[0].lora_enabled)
        self.assertAllClose(model.predict(x), new_model.predict(x))

        # Try saving and reloading the model's weights only
        temp_filepath = os.path.join(
            self.get_temp_dir(), "lora_model.weights.h5"
        )
        model.save_weights(temp_filepath)

        # Load the file into a fresh, non-lora model
        new_model = models.Sequential(
            [
                layers.EinsumDense(
                    equation="ab,bcd->acd",
                    output_shape=(8, 32),
                    bias_axes=None,
                ),
            ]
        )
        new_model.build((None, 3))
        new_model.load_weights(temp_filepath)
        self.assertAllClose(model.predict(x), new_model.predict(x))

        # Try loading a normal checkpoint into a lora model
        new_model.save_weights(temp_filepath)
        model.load_weights(temp_filepath)
        self.assertAllClose(model.predict(x), new_model.predict(x))

    @pytest.mark.requires_trainable_backend
    def test_enable_lora_with_alpha(self):
        # Use a simple equation that mimics a `Dense` layer behavior.
        equation = "ab,bc->ac"
        output_shape = 3  # This means the kernel shape will be (input_dim, 3).
        bias_axes = None

        # Create and build the `EinsumDense` layer
        # with an input shape (None, 2).
        layer = layers.EinsumDense(
            equation=equation, output_shape=output_shape, bias_axes=bias_axes
        )
        # Build the layer with an input shape of (batch, 2).
        layer.build((None, 2))

        # Set the base kernel weights to a known value.
        base_kernel = np.array(
            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32
        )
        layer._kernel.assign(base_kernel)

        # Enable LoRA with `rank`=2 and a custom `lora_alpha`=3.0.
        layer.enable_lora(rank=2, lora_alpha=3.0)
        self.assertEqual(layer.lora_rank, 2)
        self.assertEqual(layer.lora_alpha, 3.0)

        # The expected shapes are:
        #   `base_kernel`: (2, 3)
        #   `lora_kernel_a`: (2, 2) and `lora_kernel_b`: (2, 3)
        a_val = np.array([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
        b_val = np.array([[0.5, 0.6, 0.7], [0.8, 0.9, 1.0]], dtype=np.float32)
        layer.lora_kernel_a.assign(a_val)
        layer.lora_kernel_b.assign(b_val)

        # Compute expected effective kernel.
        # Scaling factor is `lora_alpha / lora_rank` = 3.0 / 2 = 1.5
        expected_delta = 1.5 * np.matmul(a_val, b_val)
        expected_kernel = base_kernel + expected_delta

        # Verify that the effective kernel property returns the expected value.
        actual_kernel = ops.convert_to_numpy(layer.kernel)
        self.assertAllClose(
            actual_kernel, expected_kernel, tpu_atol=1e-3, tpu_rtol=1e-3
        )

    @pytest.mark.requires_trainable_backend
    def test_lora_rank_argument(self):
        self.run_layer_test(
            layers.EinsumDense,
            init_kwargs={
                "equation": "ab,bcd->acd",
                "output_shape": (8, 32),
                "bias_axes": None,
                "lora_rank": 2,
            },
            input_shape=(2, 3),
            expected_output_shape=(2, 8, 32),
            expected_num_trainable_weights=2,
            expected_num_non_trainable_weights=1,
            expected_num_seed_generators=0,
            expected_num_losses=0,
            supports_masking=False,
        )

    # Test quantization-related methods.

    @parameterized.named_parameters(
        ("int8", "int8", 1e-3),
        ("int4", "int4", 3e-3),
    )
    def test_quantize_int(self, mode, error_threshold):
        layer = layers.EinsumDense(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes="d",
        )
        layer.build((None, 3))
        x = np.random.random((2, 3))
        y_float = layer(x)
        layer.quantize(mode)

        # Verify weights dtype
        self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8")
        self.assertEqual(
            backend.standardize_dtype(layer.kernel_scale.dtype),
            layer.variable_dtype,
        )

        # Try eager call and verify output correctness
        y_quantized = layer(x)
        mse = ops.mean(ops.square(y_float - y_quantized))
        self.assertLess(mse, error_threshold)  # A weak correctness test

        # Try saving and reloading the model
        model = models.Sequential([layer])
        temp_filepath = os.path.join(
            self.get_temp_dir(), "quantized_model.keras"
        )
        model.save(temp_filepath)
        new_model = saving.load_model(temp_filepath)
        self.assertAllClose(model.predict(x), new_model.predict(x))

        # Try saving and reloading the model's weights only
        temp_filepath = os.path.join(
            self.get_temp_dir(), "quantized_model.weights.h5"
        )
        model.save_weights(temp_filepath)

        # Try building with quantized dtype policy
        layer = layers.EinsumDense(
            equation="abcde,afce->acdbf",  # Test reduce and transpose
            output_shape=(2, 4, 8, 16),
            bias_axes="d",
            dtype=f"{mode}_from_mixed_bfloat16",
        )
        layer.build((1, 8, 2, 4, 32))
        self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8")
        self.assertEqual(
            backend.standardize_dtype(layer.kernel_scale.dtype), "float32"
        )
        layer = layers.EinsumDense(
            equation="a,b->ab",  # Test expand
            output_shape=(4,),
            dtype=f"{mode}_from_float32",
        )
        layer.build((None,))
        self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8")
        self.assertEqual(
            backend.standardize_dtype(layer.kernel_scale.dtype), "float32"
        )
        layer = layers.EinsumDense(
            equation="ab,ab->a",  # Test squeeze
            output_shape=(2,),
            dtype="int8_from_float32",
        )
        layer.build((2, 4))
        self.assertEqual(backend.standardize_dtype(layer._kernel.dtype), "int8")
        self.assertEqual(
            backend.standardize_dtype(layer.kernel_scale.dtype), "float32"
        )

    @parameterized.named_parameters(
        (
            "int8_btnh,nhd->btd",
            "int8",
            "btnh,nhd->btd",
            (None, 8),
            (1, 2, 2, 4),
            1e-3,
        ),
        (
            "int8_btd,ndh->btnh",
            "int8",
            "btd,ndh->btnh",
            (None, 2, 8),
            (1, 2, 4),
            1e-3,
        ),
        ("int8_btd,df->btf", "int8", "btd,df->btf", (None, 4), (1, 2, 4), 1e-3),
        (
            "int4_btnh,nhd->btd",
            "int4",
            "btnh,nhd->btd",
            (None, 8),
            (1, 2, 2, 4),
            3e-3,
        ),
        (
            "int4_btd,ndh->btnh",
            "int4",
            "btd,ndh->btnh",
            (None, 2, 8),
            (1, 2, 4),
            3e-3,
        ),
        (
            "int4_btd,df->btf",
            "int4",
            "btd,df->btf",
            (None, 4),
            (1, 2, 4),
            3.5e-3,  # Slightly higher threshold for grouped quantization
        ),
    )
    def test_quantize_with_specific_equations(
        self,
        quantization_mode,
        equation,
        output_shape,
        input_shape,
        error_threshold,
    ):
        layer = layers.EinsumDense(equation=equation, output_shape=output_shape)
        layer.build(input_shape)
        x = ops.random.uniform(input_shape)
        y_float = layer(x)

        layer.quantize(quantization_mode)
        y_quantized = layer(x)
        mse = ops.mean(ops.square(y_float - y_quantized))
        self.assertLess(mse, error_threshold)  # A weak correctness test

    @parameterized.named_parameters(
        ("int8", "int8"),
        ("float8", "float8"),
        ("int4", "int4"),
    )
    def test_quantize_on_unbuilt_layer(self, mode):
        layer = layers.EinsumDense(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes="d",
        )
        with self.assertRaisesRegex(
            ValueError, "Cannot quantize a layer that isn't yet built."
        ):
            layer.quantize(mode)

    @parameterized.named_parameters(
        ("int8", "int8"),
        ("float8", "float8"),
        ("int4", "int4"),
    )
    def test_quantize_on_subclass(self, mode):
        class MyEinsumDense(layers.EinsumDense):
            pass

        layer = MyEinsumDense(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes="d",
        )
        layer.build((None, 3))
        with self.assertRaises(NotImplementedError):
            layer.quantize(mode)

        layer.quantize(mode, type_check=False)  # No error

    @parameterized.named_parameters(
        ("int8", "int8"),
        ("float8", "float8"),
        ("int4", "int4"),
    )
    def test_quantize_when_already_quantized(self, mode):
        layer = layers.EinsumDense(
            equation="ab,bcd->acd",
            output_shape=(8, 16),
            bias_axes="d",
        )
        layer.build((None, 3))
        layer.quantize(mode)
        for m in ["int8", "float8"]:
            with self.assertRaisesRegex(
                ValueError, "is already quantized with dtype_policy="
            ):
                layer.quantize(m)

        layer = layers.EinsumDense(
            equation="ab,bcd->acd",
            output_shape=(8, 16),
            bias_axes="d",
            dtype=f"{mode}_from_float32",
        )
        layer.build((None, 3))
        for m in ["int8", "float8"]:
            with self.assertRaisesRegex(
                ValueError, "is already quantized with dtype_policy="
            ):
                layer.quantize(m)

    @parameterized.named_parameters(
        ("int8", "int8_from_float32", 3),
        ("float8", "float8_from_float32", 8),
        ("int4", "int4_from_float32", 5),  # kernel + bias + scale + zero + gidx
    )
    def test_quantize_by_setting_dtype_policy(
        self, policy, expected_num_variables
    ):
        layer = layers.EinsumDense(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes="d",
        )
        layer.build((None, 3))
        layer.dtype_policy = policy
        self.assertLen(layer.variables, expected_num_variables)

    @parameterized.named_parameters(
        ("int7", "int7"),
        ("float7", "float7"),
        ("int3", "int3"),
    )
    def test_quantize_invalid_mode(self, mode):
        layer = layers.EinsumDense(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes="d",
        )
        layer.build((None, 3))
        x = np.random.random((1, 3))
        # dtype_policy should not be altered by failed quantization
        original_dtype_policy = layer.dtype_policy

        # Test quantize
        with self.assertRaisesRegex(ValueError, "Invalid quantization mode."):
            layer.quantize(mode)
        self.assertEqual(layer.dtype_policy, original_dtype_policy)

        # Test quantized_build
        with self.assertRaisesRegex(
            NotImplementedError, "Invalid quantization mode."
        ):
            layer.quantized_build((None, 2), mode)
        self.assertEqual(layer.dtype_policy, original_dtype_policy)

        # Test quantized_call
        with self.assertRaisesRegex(
            NotImplementedError, "Invalid quantization mode."
        ):
            # Explicitly set quantization_mode
            layer._dtype_policy._quantization_mode = mode
            layer.quantized_call(x)
        self.assertEqual(layer.dtype_policy, original_dtype_policy)

    @parameterized.named_parameters(
        ("int8", "int8_from_mixed_bfloat16", 1, 2),
        ("float8", "float8_from_mixed_bfloat16", 8, 0),
        ("int4", "int4_from_mixed_bfloat16", 1, 2),
    )
    @pytest.mark.requires_trainable_backend
    def test_quantize_dtype_argument(
        self, dtype, num_trainable_weights, num_non_trainable_weights
    ):
        self.run_layer_test(
            layers.EinsumDense,
            init_kwargs={
                "equation": "ab,bcd->acd",
                "output_shape": (8, 32),
                "bias_axes": "d",
                "dtype": dtype,
            },
            input_shape=(2, 3),
            expected_output_shape=(2, 8, 32),
            expected_num_trainable_weights=num_trainable_weights,
            expected_num_non_trainable_weights=num_non_trainable_weights,
            expected_num_seed_generators=0,
            expected_num_losses=0,
            supports_masking=False,
        )

    @parameterized.named_parameters(
        ("int8_ab,bcd->acd", "int8", "ab,bcd->acd", (64, 3), (64, 8, 32), 2, 4),
        (
            "int8_btd,ndh->btnh",
            "int8",
            "btd,ndh->btnh",
            (1, 4, 32),
            (1, 4, 8, 16),
            2,
            4,
        ),
        # int4 has +1 for kernel_zero and +1 for g_idx
        ("int4_ab,bcd->acd", "int4", "ab,bcd->acd", (64, 3), (64, 8, 32), 4, 6),
        (
            "int4_btd,ndh->btnh",
            "int4",
            "btd,ndh->btnh",
            (1, 4, 32),
            (1, 4, 8, 16),
            4,
            6,
        ),
    )
    @pytest.mark.requires_trainable_backend
    def test_quantize_lora_integration(
        self,
        quantization_mode,
        equation,
        input_shape,
        output_shape,
        expected_non_trainable,
        expected_torch_params,
    ):
        config = dict(
            equation=equation, output_shape=output_shape[1:], bias_axes=None
        )
        layer = layers.EinsumDense(**config)
        layer.build(input_shape)
        layer.enable_lora(2)
        layer.quantize(quantization_mode)
        self.assertLen(layer.trainable_weights, 2)
        self.assertLen(layer.non_trainable_weights, expected_non_trainable)
        if backend.backend() == "torch":
            self.assertLen(layer.torch_params, expected_torch_params)

        # Try calling fit()
        init_lora_a_kernel_value = layer.lora_kernel_a.numpy()
        init_lora_b_kernel_value = layer.lora_kernel_b.numpy()
        x = np.random.random(input_shape)
        y = np.random.random(output_shape)
        model = models.Sequential([layer])
        model.compile(optimizer="sgd", loss="mse")
        model.fit(x, y, epochs=2)

        final_lora_a_kernel_value = layer.lora_kernel_a.numpy()
        final_lora_b_kernel_value = layer.lora_kernel_b.numpy()
        diff_a = np.max(
            np.abs(init_lora_a_kernel_value - final_lora_a_kernel_value)
        )
        diff_b = np.max(
            np.abs(init_lora_b_kernel_value - final_lora_b_kernel_value)
        )
        self.assertGreater(diff_a, 0.0)
        self.assertGreater(diff_b, 0.0)

        # Try saving and reloading the model
        temp_filepath = os.path.join(
            self.get_temp_dir(), "quantized_lora_model.keras"
        )
        model.save(temp_filepath)
        new_model = saving.load_model(temp_filepath)
        self.assertTrue(new_model.layers[0].lora_enabled)
        self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5)

        # Try saving and reloading the model's weights only
        temp_filepath = os.path.join(
            self.get_temp_dir(), "quantized_lora_model.weights.h5"
        )
        model.save_weights(temp_filepath)
        new_model = models.Sequential([layers.EinsumDense(**config)])
        new_model.build(input_shape)
        new_model.quantize(quantization_mode)
        new_model.load_weights(temp_filepath)
        self.assertFalse(new_model.layers[0].lora_enabled)
        self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5)

        # Try loading a normal checkpoint into a lora model
        new_model.save_weights(temp_filepath)
        model.load_weights(temp_filepath)
        self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5)

        # Test export and TFSMLayer reloading when using tensorflow backend
        if backend.backend() == "tensorflow":
            import tensorflow as tf

            temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
            ref_input = tf.random.normal(input_shape)
            ref_output = model(ref_input)
            model.export(temp_filepath, format="tf_saved_model")
            reloaded_layer = export.TFSMLayer(temp_filepath)
            self.assertAllClose(
                reloaded_layer(ref_input), ref_output, atol=1e-7
            )
            self.assertLen(reloaded_layer.weights, len(model.weights))
            self.assertLen(
                reloaded_layer.trainable_weights, len(model.trainable_weights)
            )
            self.assertLen(
                reloaded_layer.non_trainable_weights,
                len(model.non_trainable_weights),
            )

    @pytest.mark.requires_trainable_backend
    def test_quantize_float8(self):
        import ml_dtypes

        from keras.src import quantizers

        layer = layers.EinsumDense(
            "ab,bc->ac",
            output_shape=[32],
            bias_axes="c",
        )
        layer.build((None, 16))
        layer.quantize("float8")
        optimizer = optimizers.AdamW(learning_rate=0.1)
        optimizer.build(layer.trainable_variables)

        def loss_fn(x, dy):
            y = layer(x, training=True)
            loss = y * ops.cast(dy, y.dtype)
            return ops.sum(loss)

        if backend.backend() == "tensorflow":
            import tensorflow as tf

            @tf.function(jit_compile=True)
            def train_one_step(x, dy):
                with tf.GradientTape() as tape:
                    loss = loss_fn(x, dy)
                grads = tape.gradient(loss, layer.trainable_variables)
                optimizer.apply(grads, layer.trainable_variables)

        elif backend.backend() == "jax":
            import jax

            def stateless_loss_fn(trainable_variables, x, dy):
                y = layer.stateless_call(
                    trainable_variables, [], x, training=True
                )[0]
                loss = y * ops.cast(dy, y.dtype)
                return ops.sum(loss)

            grad_fn = jax.jit(jax.grad(stateless_loss_fn))

            def train_one_step(x, dy):
                trainable_variables = [
                    v.value for v in layer.trainable_variables
                ]
                optimizer_variables = [v.value for v in optimizer.variables]
                grads = grad_fn(trainable_variables, x, dy)
                trainable_variables, optimizer_variables = (
                    optimizer.stateless_apply(
                        optimizer_variables, grads, trainable_variables
                    )
                )
                for variable, value in zip(
                    layer.trainable_variables, trainable_variables
                ):
                    variable.assign(value)
                for variable, value in zip(
                    optimizer.variables, optimizer_variables
                ):
                    variable.assign(value)

        elif backend.backend() == "torch":

            def train_one_step(x, dy):
                layer.zero_grad()
                loss = loss_fn(x, dy)
                loss.backward()
                grads = [v.value.grad for v in layer.trainable_variables]
                optimizer.apply(grads, layer.trainable_variables)

        scale_x, amax_history_x = ops.ones(()), ops.zeros((1024,))
        scale_k, amax_history_k = ops.ones(()), ops.zeros((1024,))
        scale_g, amax_history_g = ops.ones(()), ops.zeros((1024,))
        e4m3_max = ops.cast(
            float(ml_dtypes.finfo("float8_e4m3fn").max), "float32"
        )
        e5m2_max = ops.cast(
            float(ml_dtypes.finfo("float8_e5m2").max), "float32"
        )

        for _ in range(3):
            x = random.normal((16, 16), dtype="float32")
            g = random.normal((16, 32), dtype="float32")
            k = ops.convert_to_tensor(layer._kernel)

            # Manually compute the expected amax history and scaling factors.
            amax_from_history_x = ops.max(amax_history_x)
            amax_from_history_k = ops.max(amax_history_k)
            amax_from_history_g = ops.max(amax_history_g)
            scale_x = quantizers.compute_float8_scale(
                amax_from_history_x, scale_x, e4m3_max
            )
            scale_k = quantizers.compute_float8_scale(
                amax_from_history_k, scale_k, e4m3_max
            )
            scale_g = quantizers.compute_float8_scale(
                amax_from_history_g, scale_g, e5m2_max
            )
            amax_history_x = quantizers.compute_float8_amax_history(
                x, amax_history_x
            )
            amax_history_k = quantizers.compute_float8_amax_history(
                k, amax_history_k
            )
            amax_history_g = quantizers.compute_float8_amax_history(
                g, amax_history_g
            )

            train_one_step(x, g)

            self.assertAllClose(layer.inputs_amax_history, amax_history_x)
            self.assertAllClose(layer.kernel_amax_history, amax_history_k)
            self.assertAllClose(layer.outputs_grad_amax_history, amax_history_g)
            self.assertAllClose(layer.inputs_scale, scale_x)
            self.assertAllClose(layer.kernel_scale, scale_k)
            self.assertAllClose(layer.outputs_grad_scale, scale_g)

    @pytest.mark.requires_trainable_backend
    def test_quantize_float8_fitting(self):
        config = dict(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes="d",
        )
        layer = layers.EinsumDense(**config)
        layer.build((None, 3))
        layer.quantize("float8")
        self.assertLen(layer.trainable_weights, 8)
        self.assertLen(layer.non_trainable_weights, 0)

        # Try calling fit()
        x = np.random.random((64, 3))
        y = np.random.random((64, 8, 32))
        model = models.Sequential([layer])
        model.compile(optimizer="sgd", loss="mse")
        model.fit(x, y, epochs=2)

        # Try saving and reloading the model
        temp_filepath = os.path.join(
            self.get_temp_dir(), "quantized_lora_model.keras"
        )
        model.save(temp_filepath)
        new_model = saving.load_model(temp_filepath)
        self.assertAllClose(model.predict(x), new_model.predict(x))

        # Try saving and reloading the model's weights only
        temp_filepath = os.path.join(
            self.get_temp_dir(), "quantized_lora_model.weights.h5"
        )
        model.save_weights(temp_filepath)
        new_model = models.Sequential([layers.EinsumDense(**config)])
        new_model.build((None, 3))
        new_model.quantize("float8")
        new_model.load_weights(temp_filepath)
        self.assertAllClose(model.predict(x), new_model.predict(x))

        # Test export and TFSMLayer reloading when using tensorflow backend
        if backend.backend() == "tensorflow":
            import tensorflow as tf

            temp_filepath = os.path.join(self.get_temp_dir(), "exported_model")
            ref_input = tf.random.normal((2, 3))
            ref_output = model(ref_input)
            model.export(temp_filepath, format="tf_saved_model")
            reloaded_layer = export.TFSMLayer(temp_filepath)
            self.assertAllClose(reloaded_layer(ref_input), ref_output)
            self.assertLen(reloaded_layer.weights, len(model.weights))
            self.assertLen(
                reloaded_layer.trainable_weights, len(model.trainable_weights)
            )
            self.assertLen(
                reloaded_layer.non_trainable_weights,
                len(model.non_trainable_weights),
            )

    def test_quantize_float8_inference(self):
        config = dict(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes="d",
        )
        layer = layers.EinsumDense(**config)
        layer.build((None, 3))
        layer.quantize("float8")

        # Try calling with `training=False` and the result must match
        # `training=True` because there is no update.
        x = np.random.random((64, 3))
        y_inference = layer(x, training=False)
        y_training = layer(x, training=True)
        self.assertAllClose(y_inference, y_training)

    def test_gptq_serialization(self):
        """Test that a GPTQ-quantized layer can be serialized and deserialized
        correctly."""
        config = dict(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes="d",
        )
        layer = layers.EinsumDense(**config)
        layer.build((None, 3))
        layer.quantize(
            "gptq",
            config=GPTQConfig(
                dataset=None, tokenizer=None, weight_bits=4, group_size=8
            ),
        )
        config = layer.get_config()
        new_layer = layers.EinsumDense.from_config(config)
        new_layer.build((None, 3))
        self.assertEqual(new_layer.quantization_mode, "gptq")

    def test_awq_serialization(self):
        """Test that an AWQ-quantized layer can be serialized and deserialized
        correctly."""
        config = dict(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes="d",
        )
        layer = layers.EinsumDense(**config)
        layer.build((None, 3))
        layer.quantize(
            "awq",
            config=AWQConfig(
                dataset=None, tokenizer=None, group_size=8, num_grid_points=10
            ),
        )
        layer_config = layer.get_config()
        new_layer = layers.EinsumDense.from_config(layer_config)
        new_layer.build((None, 3))
        self.assertEqual(new_layer.quantization_mode, "awq")

    def test_int4_kernel_returns_unpacked_form(self):
        """Test that the `kernel` property returns the unpacked int4 kernel."""
        layer = layers.EinsumDense(
            equation="ab,bc->ac",
            output_shape=(2,),
        )
        layer.build((None, 2))
        layer.quantize("int4")
        packed_kernel = layer._kernel
        # Unpack [rows, ceil(columns/2)] -> [rows, columns],
        # then reshape to original shape
        unpacked = quantizers.unpack_int4(
            packed_kernel, layer._int4_unpacked_column_size, axis=-1
        )
        expected = ops.reshape(unpacked, layer.original_kernel_shape)
        self.assertAllClose(layer.kernel, expected)

    def test_legacy_load_own_variables(self):
        # In previous versions, `load_own_variables` accepted a store with
        # numeric keys.
        float32_store = {
            "0": np.random.random((3, 8, 32)).astype("float32"),
            "1": np.random.random((32,)).astype("float32"),
        }
        int8_store = {
            "0": np.random.randint(-128, 127, size=(3, 8, 32), dtype="int8"),
            "1": np.random.random((32,)).astype("float32"),
            "2": np.random.random((1, 8, 32)).astype("float32"),
        }
        int4_store = {
            # int4 layout: kernel is [rows, ceil(columns/2)] = [3, 128]
            # where rows=3, columns=8*32=256
            "0": np.random.randint(-128, 127, size=(3, 128), dtype="int8"),
            "1": np.random.random((32,)).astype("float32"),
            "2": np.random.random((256,)).astype("float32"),  # per-channel
        }
        float8_store = {
            "0": np.random.random((3, 8, 32)).astype("float32"),
            "1": np.random.random((32,)).astype("float32"),
            # inputs_scale.
            "2": np.random.random(()).astype("float32"),
            # inputs_amax_history.
            "3": np.random.random((1024,)).astype("float32"),
            # kernel_scale.
            "4": np.random.random(()).astype("float32"),
            # kernel_amax_history.
            "5": np.random.random((1024,)).astype("float32"),
            # outputs_grad_scale.
            "6": np.random.random(()).astype("float32"),
            # outputs_grad_amax_history.
            "7": np.random.random((1024,)).astype("float32"),
        }
        gptq_store = {
            # bias
            "0": np.random.random((32,)).astype("float32"),
            # quantized_kernel
            "1": np.random.randint(0, 16, size=(16, 24), dtype="uint8"),
            # kernel_scale.
            "2": np.random.random((32, 3)).astype("float32"),
            # kernel_zero
            "3": np.random.random((32, 3)).astype("uint8"),
            # g_idx
            "4": np.random.random((24,)).astype("float32"),
        }
        # kernel shape (3, 8, 32), packed: (16, 24) for 4-bit
        awq_store = {
            "0": np.random.random((32,)).astype("float32"),  # bias
            "1": np.random.randint(0, 16, size=(16, 24), dtype="uint8"),
            "2": np.random.random((32, 3)).astype("float32"),  # scale
            "3": np.random.random((32, 3)).astype("uint8"),  # zero
            "4": np.random.random((24,)).astype("float32"),  # awq_scales
            "5": np.random.random((24,)).astype("float32"),  # g_idx
        }
        config = dict(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes="d",
        )

        # Test float32 layer.
        layer = layers.EinsumDense(**config)
        layer.build((None, 3))
        layer.load_own_variables(float32_store)
        self.assertAllClose(layer._kernel, float32_store["0"])
        self.assertAllClose(layer.bias, float32_store["1"])

        # Test int8-quantized layer.
        layer = layers.EinsumDense(**config, dtype="int8_from_float32")
        layer.build((None, 3))
        layer.load_own_variables(int8_store)
        self.assertAllClose(layer._kernel, int8_store["0"])
        self.assertAllClose(layer.bias, int8_store["1"])
        self.assertAllClose(layer.kernel_scale, int8_store["2"])

        # Test int4-quantized layer.
        layer = layers.EinsumDense(**config, dtype="int4_from_float32")
        layer.build((None, 3))
        layer.load_own_variables(int4_store)
        self.assertAllClose(layer._kernel, int4_store["0"])
        self.assertAllClose(layer.bias, int4_store["1"])
        self.assertAllClose(layer.kernel_scale, int4_store["2"])

        # Test float8-quantized layer.
        layer = layers.EinsumDense(**config, dtype="float8_from_float32")
        layer.build((None, 3))
        layer.load_own_variables(float8_store)
        self.assertAllClose(layer._kernel, float8_store["0"])
        self.assertAllClose(layer.bias, float8_store["1"])
        self.assertAllClose(layer.inputs_scale, float8_store["2"])
        self.assertAllClose(layer.inputs_amax_history, float8_store["3"])
        self.assertAllClose(layer.kernel_scale, float8_store["4"])
        self.assertAllClose(layer.kernel_amax_history, float8_store["5"])
        self.assertAllClose(layer.outputs_grad_scale, float8_store["6"])
        self.assertAllClose(layer.outputs_grad_amax_history, float8_store["7"])

        # Test gptq-quantized layer.
        layer = layers.EinsumDense(**config, dtype="gptq/4/8_from_float32")
        layer.build((None, 3))
        layer.load_own_variables(gptq_store)
        self.assertTrue(layer.is_gptq_calibrated)
        self.assertAllClose(layer.bias, gptq_store["0"])
        self.assertAllClose(layer.quantized_kernel, gptq_store["1"])
        self.assertAllClose(layer.kernel_scale, gptq_store["2"])
        self.assertAllClose(layer.kernel_zero, gptq_store["3"])
        self.assertAllClose(layer.g_idx, gptq_store["4"])

        # Test awq-quantized layer.
        layer = layers.EinsumDense(**config, dtype="awq/4/8_from_float32")
        layer.build((None, 3))
        layer.load_own_variables(awq_store)
        self.assertTrue(layer.is_awq_calibrated)
        self.assertAllClose(layer.bias, awq_store["0"])
        self.assertAllClose(layer.quantized_kernel, awq_store["1"])
        self.assertAllClose(layer.kernel_scale, awq_store["2"])
        self.assertAllClose(layer.kernel_zero, awq_store["3"])
        self.assertAllClose(layer.awq_scales, awq_store["4"])
        self.assertAllClose(layer.g_idx, awq_store["5"])

    def test_int4_gptq_kernel_returns_unpacked_form(self):
        """Test that the `kernel` property returns the unpacked int4 GPTQ
        kernel."""
        layer = layers.EinsumDense(
            equation="ab,bc->ac",
            output_shape=(2,),
        )
        layer.build((None, 2))
        layer.quantize(
            "gptq",
            config=GPTQConfig(
                dataset=None, tokenizer=None, weight_bits=4, group_size=8
            ),
        )
        layer.is_gptq_calibrated = True  # Bypass calibration check
        packed_kernel = layer.quantized_kernel
        self.assertAllClose(
            layer.kernel, quantizers.unpack_int4(packed_kernel, 2)
        )

    def test_gptq_kernel_packing(self):
        """Validates that 4-bit GPTQ packing reduces the kernel size."""
        config = dict(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes="d",
        )
        layer = layers.EinsumDense(**config)
        layer.build((None, 3))

        original_kernel_params = ops.prod(layer._kernel.shape)

        layer.quantize(
            "gptq",
            config=GPTQConfig(
                dataset=None, tokenizer=None, weight_bits=4, group_size=8
            ),
        )

        quantized_kernel_params = ops.prod(layer.quantized_kernel.shape)
        self.assertEqual(
            quantized_kernel_params,
            original_kernel_params // 2,
        )

    def test_int4_awq_kernel_returns_unpacked_form(self):
        """Test that the `kernel` property returns the unpacked int4 AWQ
        kernel."""
        layer = layers.EinsumDense(
            equation="ab,bc->ac",
            output_shape=(2,),
        )
        layer.build((None, 2))
        layer.quantize(
            "awq",
            config=AWQConfig(
                dataset=None, tokenizer=None, group_size=8, num_grid_points=10
            ),
        )
        layer.is_awq_calibrated = True  # Bypass calibration check
        packed_kernel = layer.quantized_kernel
        self.assertAllClose(
            layer.kernel, quantizers.unpack_int4(packed_kernel, 2)
        )

    def test_awq_kernel_packing(self):
        """Validates that 4-bit AWQ packing reduces the kernel size."""
        config = dict(
            equation="ab,bcd->acd",
            output_shape=(8, 32),
            bias_axes="d",
        )
        layer = layers.EinsumDense(**config)
        layer.build((None, 3))

        original_kernel_params = ops.prod(layer._kernel.shape)

        layer.quantize(
            "awq",
            config=AWQConfig(
                dataset=None, tokenizer=None, group_size=8, num_grid_points=10
            ),
        )

        quantized_kernel_params = ops.prod(layer.quantized_kernel.shape)
        self.assertEqual(
            quantized_kernel_params,
            original_kernel_params // 2,
        )

    def test_einsum_dense_int8_custom_quantizer(self):
        """
        Test custom quantizer serialization for einsum dense layer with
        int8 quantization.
        """
        # Setup
        weight_range = (-10, 10)
        config = Int8QuantizationConfig(
            weight_quantizer=AbsMaxQuantizer(axis=0, value_range=weight_range),
            activation_quantizer=None,
        )

        # Build & Quantize
        layer = layers.EinsumDense("ab,bc->ac", output_shape=10)
        layer.build((None, 5))
        layer.quantize("int8", config=config)

        # Serialize & Deserialize
        serialized = layer.get_config()
        new_layer = layers.EinsumDense.from_config(serialized)

        # Verify
        self.assertIsInstance(
            new_layer.quantization_config, Int8QuantizationConfig
        )

        quantizer = new_layer.quantization_config.weight_quantizer
        self.assertIsInstance(quantizer, AbsMaxQuantizer)
        self.assertEqual(quantizer.axis, (0,))
        self.assertAllEqual(quantizer.value_range, weight_range)

    @parameterized.named_parameters(
        ("grouped_block_64", 64),
        ("grouped_block_128", 128),
        ("per_channel_none", None),
        ("per_channel_neg1", -1),
    )
    @pytest.mark.skipif(
        testing.tensorflow_uses_gpu(), reason="Segfault on Tensorflow GPU"
    )
    def test_int4_quantization_block_size(self, block_size):
        """Test int4 quantization with different block_size configurations."""
        # Use simple 2D equation: ab,bc->ac (like Dense)
        layer = layers.EinsumDense(
            equation="ab,bc->ac",
            output_shape=(64,),
            bias_axes="c",
        )
        layer.build((None, 256))

  x = np.random.random((2, 256)).astype(\"float32\")\n        y_float = layer(x)\n\n        # Create config with specified block_size\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # For EinsumDense, when per-channel mode is used (block_size None\n        # or -1), the stored _int4_block_size is None (not the original value)\n        if block_size is None or block_size == -1:\n            self.assertIsNone(layer._int4_block_size)\n        else:\n            self.assertEqual(layer._int4_block_size, block_size)\n\n        # Verify kernel_scale shape (GPTQ layout)\n        if block_size is None or block_size == -1:\n            # Per-channel: one scale per output unit\n            expected_scale_shape = (64,)\n        else:\n            # Sub-channel: (n_groups, columns) = (n_groups, out)\n            n_groups = math.ceil(256 / block_size)\n            expected_scale_shape = (n_groups, 64)\n\n        self.assertEqual(layer.kernel_scale.shape, expected_scale_shape)\n\n        # Verify outputs are reasonable\n        y_quantized = layer(x)\n        mse = ops.mean(ops.square(y_float - y_quantized))\n        self.assertLess(mse, 0.01)  # Reasonable accuracy\n\n    @parameterized.named_parameters(\n        (\"grouped_block_64\", 64),\n        (\"grouped_block_128\", 128),\n        (\"per_channel_none\", None),\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_block_size_serialization(self, block_size):\n        \"\"\"Test that block_size is preserved through serialization.\"\"\"\n        layer = layers.EinsumDense(\n            equation=\"ab,bc->ac\",\n            output_shape=(32,),\n            bias_axes=\"c\",\n        )\n        layer.build((None, 128))\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Get output before 
serialization\n        x = np.random.random((2, 128)).astype(\"float32\")\n        y_before = layer(x)\n\n        # Save and load model to test full serialization roundtrip\n        model = models.Sequential([layer])\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"int4_block_size_einsum_model.keras\"\n        )\n        model.save(temp_filepath)\n        loaded_model = saving.load_model(temp_filepath)\n\n        # Verify block_size is preserved\n        loaded_layer = loaded_model.layers[0]\n        self.assertIsInstance(\n            loaded_layer.quantization_config, Int4QuantizationConfig\n        )\n        self.assertEqual(\n            loaded_layer.quantization_config.block_size, block_size\n        )\n\n        # Verify outputs match after deserialization\n        y_after = loaded_model(x)\n        self.assertAllClose(y_before, y_after)\n\n    @parameterized.named_parameters(\n        (\"grouped_block_64\", 64),\n        (\"per_channel\", None),\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_block_size_with_lora(self, block_size):\n        \"\"\"Test int4 quantization with LoRA and different block_size.\"\"\"\n        layer = layers.EinsumDense(\n            equation=\"ab,bc->ac\",\n            output_shape=(64,),\n            bias_axes=\"c\",\n        )\n        layer.build((None, 128))\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n        layer.enable_lora(rank=4)\n\n        x = np.random.random((2, 128)).astype(\"float32\")\n\n        # Should run without error\n        y = layer(x)\n        self.assertEqual(y.shape, (2, 64))\n\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_grouped_vs_perchannel_scale_shapes(self):\n        \"\"\"Test that grouped and per-channel have different scale 
shapes.\"\"\"\n        input_dim, output_dim = 256, 64\n        block_size = 64\n\n        # Per-channel layer\n        layer_pc = layers.EinsumDense(\n            equation=\"ab,bc->ac\",\n            output_shape=(output_dim,),\n        )\n        layer_pc.build((None, input_dim))\n        config_pc = Int4QuantizationConfig(block_size=None)\n        layer_pc.quantize(\"int4\", config=config_pc)\n\n        # Grouped layer\n        layer_grouped = layers.EinsumDense(\n            equation=\"ab,bc->ac\",\n            output_shape=(output_dim,),\n        )\n        layer_grouped.build((None, input_dim))\n        config_grouped = Int4QuantizationConfig(block_size=block_size)\n        layer_grouped.quantize(\"int4\", config=config_grouped)\n\n        # Per-channel: (output_dim,)\n        # Grouped: (n_groups, output_dim)\n        self.assertEqual(layer_pc.kernel_scale.shape, (output_dim,))\n        self.assertEqual(\n            layer_grouped.kernel_scale.shape,\n            (input_dim // block_size, output_dim),\n        )\n\n    @parameterized.named_parameters(\n        (\"btd_df_btf_grouped\", \"btd,df->btf\", (8, 32), (None, 8, 256), 64),\n        (\"btd_df_btf_pc\", \"btd,df->btf\", (8, 32), (None, 8, 256), None),\n        (\"ab_bcd_acd_grouped\", \"ab,bcd->acd\", (8, 32), (None, 64), 32),\n        (\"ab_bcd_acd_pc\", \"ab,bcd->acd\", (8, 32), (None, 64), None),\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_block_size_various_equations(\n        self, equation, output_shape, input_shape, block_size\n    ):\n        \"\"\"Test int4 quantization with different equations and block_size.\"\"\"\n        layer = layers.EinsumDense(\n            equation=equation,\n            output_shape=output_shape,\n        )\n        layer.build(input_shape)\n\n        batch_input_shape = (2,) + input_shape[1:]\n        x = np.random.random(batch_input_shape).astype(\"float32\")\n        
y_float = layer(x)\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Verify quantization works\n        y_quantized = layer(x)\n        self.assertEqual(y_float.shape, y_quantized.shape)\n\n        # Verify reasonable accuracy\n        mse = ops.mean(ops.square(y_float - y_quantized))\n        self.assertLess(mse, 0.02)\n\n    @parameterized.named_parameters(\n        # Attention output: bnh,nhd->bd (two reduced axes: n, h)\n        # Kernel shape: (n=4, h=32, d=64), reduced axes: n, h\n        (\"attn_output_grouped\", \"bnh,nhd->bd\", (64,), (None, 4, 32), 64),\n        (\"attn_output_pc\", \"bnh,nhd->bd\", (64,), (None, 4, 32), None),\n        # Multi-head attention value projection: ab,bcd->acd (one reduced: b)\n        (\"mha_value_grouped\", \"ab,bcd->acd\", (8, 32), (None, 64), 32),\n        (\"mha_value_pc\", \"ab,bcd->acd\", (8, 32), (None, 64), None),\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_grouped_multi_reduced_axes(\n        self, equation, output_shape, input_shape, block_size\n    ):\n        \"\"\"Test int4 grouped quantization with multiple reduced axes.\"\"\"\n        layer = layers.EinsumDense(\n            equation=equation,\n            output_shape=output_shape,\n        )\n        layer.build(input_shape)\n\n        batch_input_shape = (2,) + input_shape[1:]\n        x = np.random.random(batch_input_shape).astype(\"float32\")\n        y_float = layer(x)\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Verify quantization works\n        y_quantized = layer(x)\n        self.assertEqual(y_float.shape, y_quantized.shape)\n\n        # Verify reasonable accuracy\n        mse = ops.mean(ops.square(y_float - y_quantized))\n        self.assertLess(mse, 0.1)\n\n    @pytest.mark.skipif(\n        
testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_multi_reduced_axes_scale_shape(self):\n        \"\"\"Test that scale shape is correct for multi-reduced-axis\n        equations.\"\"\"\n        # Equation: bnh,nhd->bd (two reduced axes: n, h)\n        # Kernel shape: (n=4, h=32, d=64)\n        # Reduced axes: n, h (total reduced dim = 4 * 32 = 128)\n        layer = layers.EinsumDense(\n            equation=\"bnh,nhd->bd\",\n            output_shape=(64,),\n        )\n        layer.build((None, 4, 32))\n\n        block_size = 64\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Total reduced dim = 4 * 32 = 128, so n_groups = ceil(128/64) = 2.\n        # Scale shape is (n_groups, columns) = (n_groups, output_dim),\n        # i.e. (2, 64) for this equation.\n        kernel_scale = layer.kernel_scale\n        self.assertIsNotNone(kernel_scale)\n        # (n_groups, columns) where n_groups is at index 0\n        self.assertEqual(kernel_scale.shape[0], 2)\n\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_multi_reduced_axes_serialization(self):\n        \"\"\"Test serialization with multi-reduced-axis equations.\"\"\"\n        # Equation: bnh,nhd->bd (two reduced axes: n, h)\n        layer = layers.EinsumDense(\n            equation=\"bnh,nhd->bd\",\n            output_shape=(64,),\n        )\n        layer.build((None, 4, 32))\n\n        config = Int4QuantizationConfig(block_size=64)\n        layer.quantize(\"int4\", config=config)\n\n        x = np.random.random((2, 4, 32)).astype(\"float32\")\n        y_before = layer(x)\n\n        # Save and load model\n        model = models.Sequential([layer])\n        path = os.path.join(self.get_temp_dir(), \"model.keras\")\n        model.save(path)\n        loaded_model = saving.load_model(path)\n\n        # Verify outputs 
match\n        y_after = loaded_model(x)\n        self.assertAllClose(y_before, y_after, atol=1e-5)\n\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_multi_reduced_vs_single_reduced(self):\n        \"\"\"Compare grouped quantization: single vs multi-reduced axes.\"\"\"\n        block_size = 64\n\n        # Single reduced axis: bd,df->bf (reduced: d)\n        layer_single = layers.EinsumDense(\n            equation=\"bd,df->bf\",\n            output_shape=(32,),\n        )\n        layer_single.build((None, 128))\n\n        # Multi reduced axes: bnh,nhd->bd (reduced: n, h)\n        # n=4, h=32 gives total reduced dim = 128\n        layer_multi = layers.EinsumDense(\n            equation=\"bnh,nhd->bd\",\n            output_shape=(128,),\n        )\n        layer_multi.build((None, 4, 32))\n\n        # Quantize both with grouped quantization\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer_single.quantize(\"int4\", config=config)\n        layer_multi.quantize(\"int4\", config=config)\n\n        # Both should use grouped quantization (block_size stored)\n        self.assertEqual(layer_single._int4_block_size, block_size)\n        self.assertEqual(layer_multi._int4_block_size, block_size)\n\n        # Verify forward pass works for both\n        x_single = np.random.random((2, 128)).astype(\"float32\")\n        x_multi = np.random.random((2, 4, 32)).astype(\"float32\")\n\n        y_single = layer_single(x_single)\n        y_multi = layer_multi(x_multi)\n\n        self.assertEqual(y_single.shape, (2, 32))\n        self.assertEqual(y_multi.shape, (2, 128))\n\n    @parameterized.named_parameters(\n        (\"grouped_block_64\", 64),\n        (\"grouped_block_128\", 128),\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_subchannel_g_idx_created(self, block_size):\n        
\"\"\"Test that g_idx is created for sub-channel int4 quantization.\"\"\"\n        layer = layers.EinsumDense(\n            equation=\"ab,bc->ac\",\n            output_shape=(32,),\n            bias_axes=\"c\",\n        )\n        layer.build((None, 128))\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Verify g_idx is created\n        self.assertTrue(hasattr(layer, \"g_idx\"))\n\n        # Verify g_idx shape (128 = input_dim = reduced dimension)\n        self.assertEqual(layer.g_idx.shape, (128,))\n\n        # Verify g_idx values (should map each row to its group)\n        expected_g_idx = np.arange(128) // block_size\n        self.assertAllClose(layer.g_idx, expected_g_idx)\n\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_perchannel_no_g_idx(self):\n        \"\"\"Test that per-channel int4 does NOT create g_idx.\"\"\"\n        layer = layers.EinsumDense(\n            equation=\"ab,bc->ac\",\n            output_shape=(32,),\n            bias_axes=\"c\",\n        )\n        layer.build((None, 64))\n\n        config = Int4QuantizationConfig(block_size=None)  # Per-channel\n        layer.quantize(\"int4\", config=config)\n\n        # Verify g_idx is NOT created for per-channel\n        self.assertFalse(hasattr(layer, \"g_idx\"))\n\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_subchannel_g_idx_serialization(self):\n        \"\"\"Test that g_idx is properly serialized and deserialized.\"\"\"\n        layer = layers.EinsumDense(\n            equation=\"ab,bc->ac\",\n            output_shape=(32,),\n            bias_axes=\"c\",\n        )\n        layer.build((None, 128))\n        block_size = 64\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        x = 
np.random.random((2, 128)).astype(\"float32\")\n        y_before = layer(x)\n        g_idx_before = layer.g_idx\n\n        # Save and load\n        model = models.Sequential([layer])\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"einsum_int4_g_idx_model.keras\"\n        )\n        model.save(temp_filepath)\n        loaded_model = saving.load_model(temp_filepath)\n\n        # Verify g_idx is preserved\n        loaded_layer = loaded_model.layers[0]\n        self.assertTrue(hasattr(loaded_layer, \"g_idx\"))\n        self.assertAllClose(loaded_layer.g_idx, g_idx_before)\n\n        # Verify outputs match\n        y_after = loaded_model(x)\n        self.assertAllClose(y_before, y_after)\n"
  },
  {
    "path": "keras/src/layers/core/embedding.py",
    "content": "import math\nimport warnings\n\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import dtype_policies\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import quantizers\nfrom keras.src import regularizers\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.layers.layer import Layer\nfrom keras.src.quantizers.quantization_config import QuantizationConfig\nfrom keras.src.quantizers.quantization_config import get_block_size_for_layer\nfrom keras.src.quantizers.quantizers import dequantize_with_sz_map\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.layers.Embedding\")\nclass Embedding(Layer):\n    \"\"\"Turns nonnegative integers (indexes) into dense vectors of fixed size.\n\n    e.g. `[[4], [20]] -> [[0.25, 0.1], [0.6, -0.2]]`\n\n    This layer can only be used on nonnegative integer inputs of a fixed range.\n\n    Example:\n\n    >>> model = keras.Sequential()\n    >>> model.add(keras.layers.Embedding(1000, 64))\n    >>> # The model will take as input an integer matrix of size (batch,\n    >>> # input_length), and the largest integer (i.e. word index) in the input\n    >>> # should be no larger than 999 (vocabulary size).\n    >>> # Now model.output_shape is (None, 10, 64), where `None` is the batch\n    >>> # dimension.\n    >>> input_array = np.random.randint(1000, size=(32, 10))\n    >>> model.compile('rmsprop', 'mse')\n    >>> output_array = model.predict(input_array)\n    >>> print(output_array.shape)\n    (32, 10, 64)\n\n    Args:\n        input_dim: Integer. Size of the vocabulary,\n            i.e. maximum integer index + 1.\n        output_dim: Integer. 
Dimension of the dense embedding.\n        embeddings_initializer: Initializer for the `embeddings`\n            matrix (see `keras.initializers`).\n        embeddings_regularizer: Regularizer function applied to\n            the `embeddings` matrix (see `keras.regularizers`).\n        embeddings_constraint: Constraint function applied to\n            the `embeddings` matrix (see `keras.constraints`).\n        mask_zero: Boolean, whether or not the input value 0 is a special\n            \"padding\" value that should be masked out.\n            This is useful when using recurrent layers which\n            may take variable length input. If this is `True`,\n            then all subsequent layers in the model need\n            to support masking or an exception will be raised.\n            If `mask_zero` is set to `True`, as a consequence,\n            index 0 cannot be used in the vocabulary (`input_dim` should\n            equal size of vocabulary + 1).\n        weights: Optional floating-point matrix of size\n            `(input_dim, output_dim)`. The initial embeddings values\n            to use.\n        lora_rank: Optional integer. If set, the layer's forward pass\n            will implement LoRA (Low-Rank Adaptation)\n            with the provided rank. LoRA sets the layer's embeddings\n            matrix to non-trainable and replaces it with a delta over the\n            original matrix, obtained via multiplying two lower-rank\n            trainable matrices. This can be useful to reduce the\n            computation cost of fine-tuning large embedding layers.\n            You can also enable LoRA on an existing\n            `Embedding` layer by calling `layer.enable_lora(rank)`.\n        lora_alpha: Optional integer. If set, this parameter scales the\n            low-rank adaptation delta (computed as the product of two lower-rank\n            trainable matrices) during the forward pass. 
The delta is scaled by\n            `lora_alpha / lora_rank`, allowing you to fine-tune the strength of\n            the LoRA adjustment independently of `lora_rank`.\n\n    Input shape:\n        2D tensor with shape: `(batch_size, input_length)`.\n\n    Output shape:\n        3D tensor with shape: `(batch_size, input_length, output_dim)`.\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dim,\n        output_dim,\n        embeddings_initializer=\"uniform\",\n        embeddings_regularizer=None,\n        embeddings_constraint=None,\n        mask_zero=False,\n        weights=None,\n        lora_rank=None,\n        lora_alpha=None,\n        quantization_config=None,\n        **kwargs,\n    ):\n        input_length = kwargs.pop(\"input_length\", None)\n        if input_length is not None:\n            warnings.warn(\n                \"Argument `input_length` is deprecated. Just remove it.\"\n            )\n        super().__init__(**kwargs)\n        self.input_dim = input_dim\n        self.output_dim = output_dim\n        self.embeddings_initializer = initializers.get(embeddings_initializer)\n        self.embeddings_regularizer = regularizers.get(embeddings_regularizer)\n        self.embeddings_constraint = constraints.get(embeddings_constraint)\n        self.mask_zero = mask_zero\n        self.supports_masking = mask_zero\n        self.autocast = False\n        self.lora_rank = lora_rank\n        self.lora_alpha = lora_alpha if lora_alpha is not None else lora_rank\n        self.lora_enabled = False\n        self.quantization_config = quantization_config\n\n        if weights is not None:\n            self.build()\n            if not (isinstance(weights, list) and len(weights) == 1):\n                weights = [weights]\n            self.set_weights(weights)\n\n    def build(self, input_shape=None):\n        if self.built:\n            return\n        embeddings_shape = (self.input_dim, self.output_dim)\n        if self.quantization_mode:\n            
self.quantized_build(\n                embeddings_shape,\n                mode=self.quantization_mode,\n                config=self.quantization_config,\n            )\n        if self.quantization_mode not in (\"int8\", \"int4\"):\n            self._embeddings = self.add_weight(\n                shape=embeddings_shape,\n                initializer=self.embeddings_initializer,\n                name=\"embeddings\",\n                regularizer=self.embeddings_regularizer,\n                constraint=self.embeddings_constraint,\n                trainable=True,\n            )\n        self.built = True\n        if self.lora_rank:\n            self.enable_lora(self.lora_rank)\n\n    @property\n    def embeddings(self):\n        if not self.built:\n            raise AttributeError(\n                \"You must build the layer before accessing `embeddings`.\"\n            )\n        embeddings = self._embeddings\n        if self.quantization_mode == \"int4\":\n            embeddings = quantizers.unpack_int4(\n                embeddings, self._orig_output_dim, axis=-1\n            )\n        if self.lora_enabled:\n            return embeddings + (self.lora_alpha / self.lora_rank) * ops.matmul(\n                self.lora_embeddings_a, self.lora_embeddings_b\n            )\n        return embeddings\n\n    def call(self, inputs):\n        if inputs.dtype != \"int32\" and inputs.dtype != \"int64\":\n            inputs = ops.cast(inputs, \"int32\")\n        outputs = ops.take(self.embeddings, inputs, axis=0)\n        return ops.cast(outputs, dtype=self.compute_dtype)\n\n    def compute_mask(self, inputs, mask=None):\n        if not self.mask_zero:\n            return None\n        return ops.not_equal(inputs, 0)\n\n    def compute_output_shape(self, input_shape):\n        return (*input_shape, self.output_dim)\n\n    def compute_output_spec(self, inputs):\n        output_shape = self.compute_output_shape(inputs.shape)\n        ragged = getattr(inputs, \"ragged\", False)\n      
  return KerasTensor(\n            output_shape, dtype=self.compute_dtype, ragged=ragged\n        )\n\n    def enable_lora(\n        self,\n        rank,\n        lora_alpha=None,\n        a_initializer=\"he_uniform\",\n        b_initializer=\"zeros\",\n    ):\n        if self.embeddings_constraint:\n            raise ValueError(\n                \"Lora is incompatible with embedding constraints. \"\n                \"In order to enable lora on this layer, remove the \"\n                \"`embeddings_constraint` argument.\"\n            )\n        if not self.built:\n            raise ValueError(\n                \"Cannot enable lora on a layer that isn't yet built.\"\n            )\n        if self.lora_enabled:\n            raise ValueError(\n                \"lora is already enabled. This can only be done once per layer.\"\n            )\n        self._tracker.unlock()\n        self.lora_embeddings_a = self.add_weight(\n            name=\"lora_embeddings_a\",\n            shape=(self.input_dim, rank),\n            initializer=initializers.get(a_initializer),\n            regularizer=self.embeddings_regularizer,\n        )\n        self.lora_embeddings_b = self.add_weight(\n            name=\"lora_embeddings_b\",\n            shape=(rank, self.output_dim),\n            initializer=initializers.get(b_initializer),\n            regularizer=self.embeddings_regularizer,\n        )\n        self.embeddings.trainable = False\n        self._tracker.lock()\n        self.lora_enabled = True\n        self.lora_rank = rank\n        self.lora_alpha = lora_alpha if lora_alpha is not None else rank\n\n    def save_own_variables(self, store):\n        # Do nothing if the layer isn't yet built\n        if not self.built:\n            return\n        mode = self.quantization_mode\n        if mode not in self.variable_serialization_spec:\n            raise self._quantization_mode_error(mode)\n\n        # Embeddings plus optional merged LoRA-aware scale/zero (returns\n        # 
(embeddings, None, None) for `None` mode).\n        embeddings_value, merged_embeddings_scale, merged_embeddings_zero = (\n            self._get_embeddings_with_merged_lora()\n        )\n        idx = 0\n        for name in self.variable_serialization_spec[mode]:\n            if name == \"embeddings\":\n                store[str(idx)] = embeddings_value\n            elif name == \"embeddings_zero\":\n                if merged_embeddings_zero is None:\n                    # embeddings_zero only exists for sub-channel int4\n                    # quantization\n                    continue\n                store[str(idx)] = merged_embeddings_zero\n            elif name == \"g_idx\" and not hasattr(self, \"g_idx\"):\n                # g_idx only exists for sub-channel int4 quantization\n                continue\n            elif name == \"embeddings_scale\" and mode in (\"int4\", \"int8\"):\n                # For int4/int8, the merged LoRA scale (if any) comes from\n                # `_get_embeddings_with_merged_lora()`\n                store[str(idx)] = merged_embeddings_scale\n            else:\n                # Generic handling for subclass variables:\n                # Check if the attribute exists on the instance before saving.\n                # This supports optional variables in subclasses (e.g.,\n                # `reverse_embeddings_zero` in ReversibleEmbedding) that are\n                # present in the spec but may not exist on the object depending\n                # on configuration (e.g., per-channel vs. 
sub-channel).\n                if not hasattr(self, name):\n                    continue\n                store[str(idx)] = getattr(self, name)\n            idx += 1\n\n    def load_own_variables(self, store):\n        if not self.lora_enabled:\n            self._check_load_own_variables(store)\n        # Do nothing if the layer isn't yet built\n        if not self.built:\n            return\n        mode = self.quantization_mode\n        if mode not in self.variable_serialization_spec:\n            raise self._quantization_mode_error(mode)\n\n        idx = 0\n        for name in self.variable_serialization_spec[mode]:\n            if name == \"embeddings\":\n                self._embeddings.assign(store[str(idx)])\n            elif name == \"embeddings_zero\" and not hasattr(\n                self, \"embeddings_zero\"\n            ):\n                # embeddings_zero only exists for sub-channel int4 quantization\n                continue\n            elif name == \"g_idx\" and not hasattr(self, \"g_idx\"):\n                # g_idx only exists for sub-channel int4 quantization\n                continue\n            else:\n                # Generic handling for subclass variables:\n                # Check if the attribute exists before attempting to assign.\n                # If the variable is in the spec but missing from the object,\n                # we skip it to prevent AttributeError.\n                if not hasattr(self, name):\n                    continue\n                getattr(self, name).assign(store[str(idx)])\n            idx += 1\n        if self.lora_enabled:\n            self.lora_embeddings_a.assign(\n                ops.zeros(self.lora_embeddings_a.shape)\n            )\n            self.lora_embeddings_b.assign(\n                ops.zeros(self.lora_embeddings_b.shape)\n            )\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"input_dim\": self.input_dim,\n            
\"output_dim\": self.output_dim,\n            \"embeddings_initializer\": initializers.serialize(\n                self.embeddings_initializer\n            ),\n            \"embeddings_regularizer\": regularizers.serialize(\n                self.embeddings_regularizer\n            ),\n            \"activity_regularizer\": regularizers.serialize(\n                self.activity_regularizer\n            ),\n            \"embeddings_constraint\": constraints.serialize(\n                self.embeddings_constraint\n            ),\n            \"mask_zero\": self.mask_zero,\n            \"quantization_config\": serialization_lib.serialize_keras_object(\n                self.quantization_config\n            ),\n        }\n        if self.lora_rank:\n            config[\"lora_rank\"] = self.lora_rank\n            config[\"lora_alpha\"] = self.lora_alpha\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config):\n        config = config.copy()\n        config[\"quantization_config\"] = (\n            serialization_lib.deserialize_keras_object(\n                config.get(\"quantization_config\", None)\n            )\n        )\n        return super().from_config(config)\n\n    def _quantization_mode_error(self, mode):\n        return NotImplementedError(\n            \"Invalid quantization mode. Expected one of ('int8', 'int4'). \"\n            f\"Received: quantization_mode={mode}\"\n        )\n\n    @property\n    def variable_serialization_spec(self):\n        \"\"\"Returns a dict mapping quantization modes to variable names in order.\n\n        This spec is used by `save_own_variables` and `load_own_variables` to\n        determine the correct ordering of variables during serialization for\n        each quantization mode. 
`None` means no quantization.\n        \"\"\"\n        return {\n            None: [\n                \"embeddings\",\n            ],\n            \"int8\": [\n                \"embeddings\",\n                \"embeddings_scale\",\n            ],\n            \"int4\": [\n                \"embeddings\",\n                \"embeddings_scale\",\n                \"embeddings_zero\",\n                \"g_idx\",\n            ],\n        }\n\n    def quantized_build(self, embeddings_shape, mode, config=None):\n        if mode == \"int8\":\n            self._int8_build(embeddings_shape, config)\n        elif mode == \"int4\":\n            self._int4_build(embeddings_shape, config)\n        else:\n            raise self._quantization_mode_error(mode)\n        self._is_quantized = True\n\n    def _int8_build(self, embeddings_shape, config=None):\n        self._embeddings = self.add_weight(\n            name=\"embeddings\",\n            shape=embeddings_shape,\n            initializer=\"zeros\",\n            dtype=\"int8\",\n            trainable=False,\n        )\n        # We choose to reduce the axis of `output_dim` because, typically,\n        # `input_dim` is larger than `output_dim`. 
This reduces quantization\n        # error.\n        self.embeddings_scale = self.add_weight(\n            name=\"embeddings_scale\",\n            shape=(self.input_dim,),\n            initializer=\"ones\",\n            trainable=False,\n        )\n\n    def _int4_build(self, embeddings_shape, config=None):\n        \"\"\"Build variables for int4 quantization.\n\n        Args:\n            embeddings_shape: Original shape `(input_dim, output_dim)`.\n            config: Optional quantization config specifying block_size.\n        \"\"\"\n        input_dim, output_dim = embeddings_shape\n        packed_rows = (output_dim + 1) // 2\n\n        # Embeddings are stored packed: each int8 byte contains two\n        # int4 values.\n        self._embeddings = self.add_weight(\n            name=\"embeddings\",\n            shape=(input_dim, packed_rows),\n            initializer=\"zeros\",\n            dtype=\"int8\",\n            trainable=False,\n        )\n\n        block_size = get_block_size_for_layer(self, config)\n        self._int4_block_size = block_size\n\n        if block_size is None or block_size == -1:\n            scale_shape = (self.input_dim,)\n        else:\n            n_groups = math.ceil(output_dim / block_size)\n            scale_shape = (self.input_dim, n_groups)\n\n        self.embeddings_scale = self.add_weight(\n            name=\"embeddings_scale\",\n            shape=scale_shape,\n            initializer=\"ones\",\n            trainable=False,\n        )\n\n        # Sub-channel quantization uses asymmetric quantization with\n        # zero point\n        if block_size is not None and block_size > 0:\n            self.embeddings_zero = self.add_weight(\n                name=\"embeddings_zero\",\n                shape=scale_shape,\n                initializer=\"zeros\",\n                dtype=\"int8\",\n                trainable=False,\n            )\n            self.g_idx = self.add_weight(\n                name=\"g_idx\",\n                
shape=(output_dim,),\n                initializer=\"zeros\",\n                dtype=\"float32\",\n                trainable=False,\n            )\n            self.g_idx.assign(\n                ops.floor_divide(\n                    ops.arange(output_dim, dtype=\"float32\"), block_size\n                )\n            )\n\n        self._orig_output_dim = output_dim\n\n    def _int8_call(self, inputs, training=None):\n        # We cannot update quantized self._embeddings, so the custom gradient is\n        # not needed\n        if backend.standardize_dtype(inputs.dtype) not in (\"int32\", \"int64\"):\n            inputs = ops.cast(inputs, \"int32\")\n        embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)\n        outputs = ops.take(self._embeddings, inputs, axis=0)\n        # De-scale outputs\n        outputs = ops.divide(\n            ops.cast(outputs, dtype=self.compute_dtype),\n            ops.expand_dims(embeddings_scale, axis=-1),\n        )\n        if self.lora_enabled:\n            lora_outputs = ops.take(self.lora_embeddings_a, inputs, axis=0)\n            lora_outputs = ops.matmul(lora_outputs, self.lora_embeddings_b)\n            outputs = ops.add(\n                outputs, (self.lora_alpha / self.lora_rank) * lora_outputs\n            )\n        return outputs\n\n    def _int4_call(self, inputs, training=None):\n        \"\"\"Forward pass for int4 quantized Embedding layer.\"\"\"\n        if backend.standardize_dtype(inputs.dtype) not in (\"int32\", \"int64\"):\n            inputs = ops.cast(inputs, \"int32\")\n\n        unpacked_embeddings = quantizers.unpack_int4(\n            self._embeddings, self._orig_output_dim, axis=-1\n        )\n        outputs = ops.take(unpacked_embeddings, inputs, axis=0)\n\n        block_size = getattr(self, \"_int4_block_size\", None)\n\n        if block_size is None or block_size == -1:\n            embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)\n            outputs = ops.divide(\n  
              ops.cast(outputs, dtype=self.compute_dtype),\n                ops.expand_dims(embeddings_scale, axis=-1),\n            )\n        else:\n            # Sub-channel: look up scale/zero for each input token,\n            # then dequantize using g_idx to expand groups\n            embeddings_scale = ops.take(self.embeddings_scale, inputs, axis=0)\n            embeddings_zero = ops.take(self.embeddings_zero, inputs, axis=0)\n\n            # Scale/zero are [batch..., n_groups], g_idx is [output_dim]\n            outputs = dequantize_with_sz_map(\n                ops.cast(outputs, dtype=self.compute_dtype),\n                embeddings_scale,\n                embeddings_zero,\n                self.g_idx,\n                group_axis=-1,\n            )\n\n        if self.lora_enabled:\n            lora_outputs = ops.take(self.lora_embeddings_a, inputs, axis=0)\n            lora_outputs = ops.matmul(lora_outputs, self.lora_embeddings_b)\n            outputs = ops.add(\n                outputs, (self.lora_alpha / self.lora_rank) * lora_outputs\n            )\n        return outputs\n\n    def quantize(self, mode=None, type_check=True, config=None):\n        # Prevent quantization of the subclasses.\n        if type_check and (type(self) is not Embedding):\n            raise self._not_implemented_error(self.quantize)\n\n        self.quantization_config = config\n\n        embeddings_shape = (self.input_dim, self.output_dim)\n        if mode == \"int8\":\n            # Quantize `self._embeddings` to int8 and compute corresponding\n            # scale.\n            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(\n                self.quantization_config,\n                quantizers.AbsMaxQuantizer(axis=-1),\n            )\n            embeddings_value, embeddings_scale = weight_quantizer(\n                self._embeddings, to_numpy=True\n            )\n            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)\n            del 
self._embeddings\n            self.quantized_build(\n                embeddings_shape, mode, self.quantization_config\n            )\n            self._embeddings.assign(embeddings_value)\n            self.embeddings_scale.assign(embeddings_scale)\n        elif mode == \"int4\":\n            from keras.src.quantizers.quantization_config import (\n                Int4QuantizationConfig,\n            )\n\n            block_size = None\n            if isinstance(self.quantization_config, Int4QuantizationConfig):\n                block_size = self.quantization_config.block_size\n\n            use_grouped = block_size is not None and block_size != -1\n\n            if not use_grouped:\n                # Per-channel quantization\n                weight_quantizer = (\n                    QuantizationConfig.weight_quantizer_or_default(\n                        self.quantization_config,\n                        quantizers.AbsMaxQuantizer(\n                            axis=-1,\n                            value_range=(-8, 7),\n                            output_dtype=\"int8\",\n                        ),\n                    )\n                )\n                embeddings_value, embeddings_scale = weight_quantizer(\n                    self._embeddings, to_numpy=True\n                )\n                embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)\n            else:\n                # Sub-channel quantization with asymmetric zero point\n                input_dim, output_dim = ops.shape(self._embeddings)\n                # Transpose to put output_dim first for grouped quantization\n                embeddings_t = ops.transpose(self._embeddings)\n\n                embeddings_value_t, scale_t, zero_t = (\n                    quantizers.abs_max_quantize_grouped_with_zero_point(\n                        embeddings_t,\n                        block_size=block_size,\n                        value_range=(-8, 7),\n                        dtype=\"int8\",\n                  
      to_numpy=True,\n                    )\n                )\n                # Transpose back to (input_dim, output_dim) layout\n                embeddings_value = ops.transpose(embeddings_value_t)\n                embeddings_scale = ops.transpose(scale_t)\n                embeddings_zero = ops.transpose(zero_t)\n\n            packed_embeddings_value, _, _ = quantizers.pack_int4(\n                embeddings_value, axis=-1\n            )\n            del self._embeddings\n            self.quantized_build(\n                embeddings_shape, mode, self.quantization_config\n            )\n            self._embeddings.assign(packed_embeddings_value)\n            self.embeddings_scale.assign(embeddings_scale)\n            if use_grouped:\n                self.embeddings_zero.assign(embeddings_zero)\n        else:\n            raise self._quantization_mode_error(mode)\n\n        # Set new dtype policy.\n        if self.dtype_policy.quantization_mode is None:\n            policy_name = mode\n            if mode == \"int4\":\n                # Include block_size in policy name for sub-channel quantization\n                block_size = get_block_size_for_layer(self, config)\n                block_size_value = -1 if block_size is None else block_size\n                policy_name = f\"int4/{block_size_value}\"\n            policy = dtype_policies.get(\n                f\"{policy_name}_from_{self.dtype_policy.name}\"\n            )\n            self.dtype_policy = policy\n\n    def _get_embeddings_with_merged_lora(self):\n        \"\"\"Returns the embeddings with LoRA matrices merged, for serialization.\n\n        This method is called by `save_own_variables` to produce a single\n        embeddings tensor that includes the adaptations from LoRA. This is\n        useful for deploying the model or for continuing training after\n        permanently applying the LoRA update.\n\n        If the layer is quantized (`int8` or `int4`), the process is:\n        1. 
Dequantize the base embeddings to float.\n        2. Compute the LoRA delta (`lora_embeddings_a @ lora_embeddings_b`) and\n            add it to the dequantized embeddings.\n        3. Re-quantize the merged result back to the original quantized\n            type (`int8` or packed `int4`), calculating a new scale factor.\n\n        If the layer is not quantized, this method returns the result of the\n        `embeddings` property (which computes the merge in floating-point) and a\n        scale of `None`.\n\n        If LoRA is not enabled, it returns the original embeddings and scale\n        without modification.\n\n        Returns:\n            A tuple `(embeddings_value, embeddings_scale, embeddings_zero)`:\n                `embeddings_value`: The merged embeddings. A quantized tensor if\n                    quantization is active, otherwise a high precision tensor.\n                `embeddings_scale`: The quantization scale for the merged\n                    embeddings. This is `None` if the layer is not quantized.\n                `embeddings_zero`: The zero point for sub-channel quantization.\n                    This is `None` for per-channel quantization modes.\n        \"\"\"\n        if self.dtype_policy.quantization_mode in (None, \"gptq\", \"awq\"):\n            return self.embeddings, None, None\n\n        embeddings_value = self._embeddings\n        embeddings_scale = self.embeddings_scale\n        embeddings_zero = getattr(self, \"embeddings_zero\", None)\n\n        if not self.lora_enabled:\n            return embeddings_value, embeddings_scale, embeddings_zero\n\n        block_size = getattr(self, \"_int4_block_size\", None)\n\n        # Dequantize embeddings to float.\n        if self.quantization_mode == \"int4\":\n            unpacked_embeddings = quantizers.unpack_int4(\n                embeddings_value, self._orig_output_dim, axis=-1\n            )\n            if block_size is None or block_size == -1:\n                # Per-channel 
dequantization\n                float_embeddings = ops.divide(\n                    ops.cast(unpacked_embeddings, self.compute_dtype),\n                    ops.expand_dims(embeddings_scale, axis=-1),\n                )\n            else:\n                # Sub-channel: grouped dequantization using shared utility\n                float_embeddings = dequantize_with_sz_map(\n                    ops.cast(unpacked_embeddings, self.compute_dtype),\n                    embeddings_scale,\n                    self.embeddings_zero,\n                    self.g_idx,\n                    group_axis=-1,\n                )\n            quant_range = (-8, 7)\n        elif self.quantization_mode == \"int8\":\n            float_embeddings = ops.divide(\n                ops.cast(embeddings_value, self.compute_dtype),\n                ops.expand_dims(embeddings_scale, axis=-1),\n            )\n            quant_range = (-127, 127)\n        else:\n            raise ValueError(\n                f\"Unsupported quantization mode: {self.quantization_mode}\"\n            )\n\n        # Merge LoRA weights in float domain.\n        lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(\n            self.lora_embeddings_a, self.lora_embeddings_b\n        )\n        merged_float_embeddings = ops.add(float_embeddings, lora_delta)\n\n        # Requantize.\n        if self.quantization_mode == \"int4\":\n            if block_size is None or block_size == -1:\n                # Per-channel re-quantization\n                requantized_embeddings, new_scale = quantizers.abs_max_quantize(\n                    merged_float_embeddings,\n                    axis=-1,\n                    value_range=quant_range,\n                    dtype=\"int8\",\n                    to_numpy=True,\n                )\n                new_scale = ops.squeeze(new_scale, axis=-1)\n                embeddings_zero = None\n            else:\n                # Grouped re-quantization (asymmetric with zero point)\n       
         merged_np = merged_float_embeddings\n                # Transpose to (output_dim, input_dim) for grouped quantization\n                merged_t = ops.transpose(merged_np)\n\n                requantized_t, scale_t, zero_t = (\n                    quantizers.abs_max_quantize_grouped_with_zero_point(\n                        merged_t,\n                        block_size=block_size,\n                        value_range=quant_range,\n                        dtype=\"int8\",\n                        to_numpy=True,\n                    )\n                )\n                # Transpose back\n                requantized_embeddings = ops.transpose(requantized_t)\n                new_scale = ops.transpose(scale_t)\n                embeddings_zero = ops.transpose(zero_t)\n\n            # Pack for int4\n            embeddings_value, _, _ = quantizers.pack_int4(\n                requantized_embeddings, axis=-1\n            )\n            embeddings_scale = new_scale\n        else:\n            # int8 re-quantization\n            requantized_embeddings, embeddings_scale = (\n                quantizers.abs_max_quantize(\n                    merged_float_embeddings,\n                    axis=-1,\n                    value_range=quant_range,\n                    dtype=\"int8\",\n                    to_numpy=True,\n                )\n            )\n            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)\n            embeddings_value = requantized_embeddings\n            embeddings_zero = None\n        return embeddings_value, embeddings_scale, embeddings_zero\n"
  },
  {
    "path": "keras/src/layers/core/embedding_test.py",
    "content": "import math\nimport os\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import export\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import quantizers\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.quantizers.quantization_config import Int4QuantizationConfig\nfrom keras.src.quantizers.quantization_config import Int8QuantizationConfig\nfrom keras.src.quantizers.quantizers import AbsMaxQuantizer\nfrom keras.src.testing import test_case\n\n\nclass EmbeddingTest(test_case.TestCase):\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\", {\"axis\": -1}),\n        (\n            \"int4\",\n            \"int4\",\n            {\"axis\": -1, \"value_range\": (-8, 7), \"output_dtype\": \"int8\"},\n        ),\n        (\"int8_custom\", \"int8\", {\"axis\": -1}),\n    )\n    def test_embedding_quantize_config(self, mode, weight_quantizer_args):\n        \"\"\"Test Embedding quantization with QuantizationConfig.\"\"\"\n        layer = layers.Embedding(input_dim=10, output_dim=6)\n        layer.build((None,))\n\n        weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args)\n        if mode == \"int8\":\n            config = Int8QuantizationConfig(\n                weight_quantizer=weight_quantizer, activation_quantizer=None\n            )\n        elif mode == \"int4\":\n            # Custom quantizers require per-channel mode (block_size=None)\n            config = Int4QuantizationConfig(\n                weight_quantizer=weight_quantizer,\n                activation_quantizer=None,\n                block_size=None,\n            )\n\n        layer.quantize(mode, config=config)\n\n        # Verify weights are quantized\n        self.assertEqual(\n            backend.standardize_dtype(layer._embeddings.dtype), \"int8\"\n        )\n        
self.assertTrue(hasattr(layer, \"embeddings_scale\"))\n\n        # Verify call works\n        x = np.random.randint(0, 10, size=(2, 3))\n        y = layer(x)\n        self.assertEqual(y.shape, (2, 3, 6))\n\n    @pytest.mark.requires_trainable_backend\n    def test_embedding_basics(self):\n        self.run_layer_test(\n            layers.Embedding,\n            {\"input_dim\": 4, \"output_dim\": 3},\n            input_shape=(2,),\n            input_dtype=\"int32\",\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=1,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n        self.run_layer_test(\n            layers.Embedding,\n            {\"input_dim\": 5, \"output_dim\": 4, \"mask_zero\": True},\n            input_shape=(2, 3),\n            input_dtype=\"int64\",\n            expected_output_shape=(2, 3, 4),\n            expected_num_trainable_weights=1,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=\"Backend does not support sparse tensors.\",\n    )\n    def test_sparse(self):\n        self.run_layer_test(\n            layers.Embedding,\n            {\"input_dim\": 5, \"output_dim\": 4},\n            input_shape=(2, 3),\n            input_dtype=\"int32\",\n            input_sparse=True,\n            expected_output_shape=(2, 3, 4),\n            expected_num_trainable_weights=1,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_RAGGED_TENSORS,\n        reason=\"Backend does not support ragged 
tensors.\",\n    )\n    def test_ragged(self):\n        self.run_layer_test(\n            layers.Embedding,\n            {\"input_dim\": 5, \"output_dim\": 4},\n            input_shape=(2, 3),\n            input_dtype=\"int32\",\n            input_ragged=True,\n            expected_output_shape=(2, None, 4),\n            expected_output_ragged=True,\n            expected_num_trainable_weights=1,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            # run_training_check=False,\n        )\n\n    def test_correctness(self):\n        layer = layers.Embedding(input_dim=3, output_dim=2)\n        layer.build()\n        layer.embeddings.assign(np.array([[0.0, 0.0], [2.0, 2.0], [3.0, 3.0]]))\n        out = layer(np.array([2, 1, 0]))\n        self.assertAllClose(out, np.array([[3.0, 3.0], [2.0, 2.0], [0.0, 0.0]]))\n\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=\"Backend does not support sparse tensors.\",\n    )\n    def test_correctness_sparse(self):\n        layer = layers.Embedding(input_dim=3, output_dim=2)\n        layer.build()\n        layer.embeddings.assign(np.array([[0.0, 0.0], [2.0, 2.0], [3.0, 3.0]]))\n\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            x = tf.SparseTensor([[0, 0], [1, 2]], [2, 1], (2, 3))\n        elif backend.backend() == \"jax\":\n            import jax.experimental.sparse as jax_sparse\n\n            x = jax_sparse.BCOO(([2, 1], [[0, 0], [1, 2]]), shape=(2, 3))\n        else:\n            self.fail(f\"Sparse is unsupported with backend {backend.backend()}\")\n\n        self.assertAllClose(\n            layer(x),\n            np.array(\n                [\n                    [[3.0, 3.0], [0.0, 0.0], [0.0, 0.0]],\n                    [[0.0, 0.0], [0.0, 0.0], [2.0, 2.0]],\n                ]\n            ),\n        )\n\n    def 
test_masking(self):\n        layer = layers.Embedding(input_dim=3, output_dim=2, mask_zero=True)\n        layer.build()\n        out = layer.compute_mask(np.array(([2, 1, 0])))\n        self.assertAllClose(out, np.array([True, True, False]))\n\n    def test_compute_mask_no_masking(self):\n        layer = layers.Embedding(input_dim=3, output_dim=2, mask_zero=False)\n        input_data = np.array([2, 1, 0])\n        mask = layer.compute_mask(input_data)\n        self.assertIsNone(mask)\n\n    def test_embedding_constraints(self):\n        layer = layers.Embedding(3, 2, embeddings_constraint=\"non_neg\")\n        layer.build((None, 2))\n        self.assertIsInstance(layer.embeddings.constraint, constraints.NonNeg)\n\n    def test_weights_constructor_arg(self):\n        layer = layers.Embedding(3, 4, weights=np.ones((3, 4)))\n        self.assertAllClose(layer.embeddings.numpy(), np.ones((3, 4)))\n        layer = layers.Embedding(3, 4, weights=[np.ones((3, 4))])\n        self.assertAllClose(layer.embeddings.numpy(), np.ones((3, 4)))\n\n    @pytest.mark.requires_trainable_backend\n    def test_enable_lora(self):\n        layer = layers.Embedding(10, 16)\n        layer.build()\n        layer.enable_lora(4)\n        self.assertLen(layer.trainable_weights, 2)\n        self.assertLen(layer.non_trainable_weights, 1)\n        if backend.backend() == \"torch\":\n            self.assertLen(layer.torch_params, 3)\n        # Try eager call\n        x = np.random.randint(0, 9, size=(64, 3))\n        y = np.random.random((64, 3, 16))\n        _ = layer(x[:2])\n\n        init_lora_a_embeddings_value = layer.lora_embeddings_a.numpy()\n        init_lora_b_embeddings_value = layer.lora_embeddings_b.numpy()\n\n        # Try calling fit()\n        model = models.Sequential(\n            [\n                layer,\n            ]\n        )\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        model.fit(x, y)\n\n        final_lora_a_embeddings_value = 
layer.lora_embeddings_a.numpy()\n        final_lora_b_embeddings_value = layer.lora_embeddings_b.numpy()\n        diff_a = np.max(\n            np.abs(init_lora_a_embeddings_value - final_lora_a_embeddings_value)\n        )\n        diff_b = np.max(\n            np.abs(init_lora_b_embeddings_value - final_lora_b_embeddings_value)\n        )\n        self.assertGreater(diff_a, 0.0)\n        self.assertGreater(diff_b, 0.0)\n\n        # Try saving and reloading the model\n        temp_filepath = os.path.join(self.get_temp_dir(), \"lora_model.keras\")\n        model.save(temp_filepath)\n\n        new_model = saving.load_model(temp_filepath)\n        self.assertTrue(new_model.layers[0].lora_enabled)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n        # Try saving and reloading the model's weights only\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"lora_model.weights.h5\"\n        )\n        model.save_weights(temp_filepath)\n\n        # Load the file into a fresh, non-lora model\n        new_model = models.Sequential(\n            [\n                layers.Input((3,), dtype=\"int32\"),\n                layers.Embedding(10, 16),\n            ]\n        )\n        new_model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n        # Try loading a normal checkpoint into a lora model\n        new_model.save_weights(temp_filepath)\n        model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n    @pytest.mark.requires_trainable_backend\n    def test_enable_lora_with_alpha(self):\n        # Create an `Embedding` layer without specifying `lora_rank`\n        layer = layers.Embedding(input_dim=3, output_dim=2)\n        layer.build((None,))  # Build the layer\n\n        # Set the base embeddings to known values.\n        base_emb = np.array(\n            [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], dtype=np.float32\n        )\n    
    layer.embeddings.assign(base_emb)\n\n        # Enable LoRA with a custom alpha: `rank`=2, `lora_alpha`=3.0.\n        layer.enable_lora(2, lora_alpha=3.0)\n        self.assertEqual(layer.lora_rank, 2)\n        self.assertEqual(layer.lora_alpha, 3.0)\n\n        # Manually assign known values to lora weights.\n        a_val = np.array([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]], dtype=np.float32)\n        b_val = np.array([[0.5, 0.5], [0.6, 0.6]], dtype=np.float32)\n        layer.lora_embeddings_a.assign(a_val)\n        layer.lora_embeddings_b.assign(b_val)\n\n        # Compute the expected delta.\n        # Scaling factor: (3.0 / 2) = 1.5\n        effective_delta = 1.5 * np.matmul(a_val, b_val)\n        expected_embeddings = base_emb + effective_delta\n\n        # Verify that the effective embeddings match expectation.\n        actual_embeddings = ops.convert_to_numpy(layer.embeddings)\n        self.assertAllClose(\n            actual_embeddings, expected_embeddings, tpu_atol=1e-3, tpu_rtol=1e-3\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_lora_rank_argument(self):\n        self.run_layer_test(\n            layers.Embedding,\n            init_kwargs={\"input_dim\": 5, \"output_dim\": 4, \"lora_rank\": 2},\n            input_shape=(2, 3),\n            input_dtype=\"int32\",\n            expected_output_shape=(2, 3, 4),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=1,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    def test_enable_lora_with_embeddings_constraint(self):\n        layer = layers.Embedding(\n            input_dim=10, output_dim=16, embeddings_constraint=\"max_norm\"\n        )\n        with self.assertRaisesRegex(\n            ValueError, \"incompatible with embedding constraints\"\n        ):\n            layer.enable_lora(rank=2)\n\n    def test_enable_lora_when_already_enabled(self):\n        
layer = layers.Embedding(input_dim=10, output_dim=16)\n        layer.build()\n        layer.enable_lora(rank=2)\n        with self.assertRaisesRegex(ValueError, \"lora is already enabled\"):\n            layer.enable_lora(rank=2)\n\n    # Test quantization-related methods.\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\"),\n        (\"int4\", \"int4\"),\n    )\n    def test_quantize_int(self, mode):\n        layer = layers.Embedding(10, 16)\n        layer.build()\n        x = np.random.randint(0, 9, size=(64, 3))\n        y_float = layer(x)\n        layer.quantize(mode)\n\n        # Verify the dtype of the weights.\n        # The embeddings's dtype is int8, despite the int4 quantization, because\n        # we pack the int4 values into int8.\n        self.assertEqual(\n            backend.standardize_dtype(layer._embeddings.dtype), \"int8\"\n        )\n        self.assertEqual(\n            backend.standardize_dtype(layer.embeddings_scale.dtype),\n            layer.variable_dtype,\n        )\n\n        # Verify the unpacked embeddings for int4 quantization.\n        if mode == \"int4\":\n            self.assertAllClose(\n                layer.embeddings,\n                quantizers.unpack_int4(\n                    layer._embeddings, layer.output_dim, axis=-1\n                ),\n            )\n\n        # Verify the correctness of the outputs.\n        y_quantized = layer(x)\n        mse = ops.mean(ops.square(y_float - y_quantized))\n        self.assertLess(mse, 1e-3)  # A weak correctness test\n\n        # Check model save / load round-trip.\n        model = models.Sequential([layer])\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"quantized_model.keras\"\n        )\n        model.save(temp_filepath)\n        new_model = saving.load_model(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n        # Check weights-only save / load round-trip.\n        temp_filepath = os.path.join(\n   
         self.get_temp_dir(), \"quantized_model.weights.h5\"\n        )\n        model.save_weights(temp_filepath)\n        new_model = models.Sequential([layers.Embedding(10, 16)])\n        new_model.build((None, 3))\n        new_model.quantize(mode)\n        new_model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\"),\n        (\"int4\", \"int4\"),\n    )\n    def test_quantize_on_unbuilt_layer(self, mode):\n        layer = layers.Embedding(10, 16)\n        with self.assertRaisesRegex(\n            ValueError, \"Cannot quantize a layer that isn't yet built.\"\n        ):\n            layer.quantize(mode)\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\"),\n        (\"int4\", \"int4\"),\n    )\n    def test_quantize_on_subclass(self, mode):\n        class MyEmbedding(layers.Embedding):\n            pass\n\n        layer = MyEmbedding(10, 16)\n        layer.build()\n        with self.assertRaises(NotImplementedError):\n            layer.quantize(mode)\n\n        layer.quantize(mode, type_check=False)  # No error\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\"),\n        (\"int4\", \"int4\"),\n    )\n    def test_quantize_when_already_quantized(self, mode):\n        layer = layers.Embedding(10, 16)\n        layer.build()\n        layer.quantize(mode)\n        for m in (\"int8\", \"int4\"):\n            with self.assertRaisesRegex(\n                ValueError, \"is already quantized with dtype_policy=\"\n            ):\n                layer.quantize(m)\n\n        layer = layers.Embedding(10, 16, dtype=f\"{mode}_from_float32\")\n        layer.build()\n        for m in (\"int8\", \"int4\"):\n            with self.assertRaisesRegex(\n                ValueError, \"is already quantized with dtype_policy=\"\n            ):\n                layer.quantize(m)\n\n    @parameterized.named_parameters(\n        (\"int8\", 
\"int8_from_float32\", 2),\n        (\"int4\", \"int4_from_float32\", 4),  # embeddings + scale + zero\n    )\n    def test_quantize_by_setting_dtype_policy(\n        self, policy, expected_num_variables\n    ):\n        layer = layers.Embedding(10, 16)\n        layer.build()\n        layer.dtype_policy = policy\n        self.assertLen(layer.variables, expected_num_variables)\n\n    @parameterized.named_parameters(\n        (\"int7\", \"int7\"),\n        (\"float7\", \"float7\"),\n    )\n    def test_quantize_invalid_mode(self, mode):\n        layer = layers.Embedding(10, 16)\n        layer.build()\n        x = np.random.randint(0, 9, size=(1, 3))\n        # dtype_policy should not be altered by failed quantization\n        original_dtype_policy = layer.dtype_policy\n\n        # Test quantize\n        with self.assertRaisesRegex(ValueError, \"Invalid quantization mode.\"):\n            layer.quantize(mode)\n        self.assertEqual(layer.dtype_policy, original_dtype_policy)\n\n        # Test quantized_build\n        with self.assertRaisesRegex(\n            NotImplementedError, \"Invalid quantization mode.\"\n        ):\n            layer.quantized_build((None, 2), mode)\n        self.assertEqual(layer.dtype_policy, original_dtype_policy)\n\n        # Test quantized_call\n        with self.assertRaisesRegex(\n            NotImplementedError, \"Invalid quantization mode.\"\n        ):\n            # Explicitly set quantization_mode\n            layer._dtype_policy._quantization_mode = mode\n            layer.quantized_call(x)\n        self.assertEqual(layer.dtype_policy, original_dtype_policy)\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8_from_mixed_bfloat16\", 0, 2),\n        (\n            \"int4\",\n            \"int4_from_mixed_bfloat16\",\n            0,\n            2,\n        ),  # per-channel (no zero point)\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_quantize_dtype_argument(\n        self, dtype, 
num_trainable_weights, num_non_trainable_weights\n    ):\n        self.run_layer_test(\n            layers.Embedding,\n            {\"input_dim\": 4, \"output_dim\": 3, \"dtype\": dtype},\n            input_shape=(2,),\n            input_dtype=\"int32\",\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=num_trainable_weights,\n            expected_num_non_trainable_weights=num_non_trainable_weights,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n        self.run_layer_test(\n            layers.Embedding,\n            {\n                \"input_dim\": 5,\n                \"output_dim\": 4,\n                \"mask_zero\": True,\n                \"dtype\": dtype,\n            },\n            input_shape=(2, 3),\n            input_dtype=\"int64\",\n            expected_output_shape=(2, 3, 4),\n            expected_num_trainable_weights=num_trainable_weights,\n            expected_num_non_trainable_weights=num_non_trainable_weights,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\", 2, 2, 4),\n        (\"int4\", \"int4\", 2, 4, 6),  # +2 for embeddings_zero + g_idx\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_quantize_lora_integration(\n        self,\n        mode,\n        num_trainable_weights,\n        num_non_trainable_weights,\n        num_torch_params,\n    ):\n        layer = layers.Embedding(10, 16)\n        layer.build()\n        layer.enable_lora(4)\n        layer.quantize(mode)\n        self.assertLen(layer.trainable_weights, num_trainable_weights)\n        self.assertLen(layer.non_trainable_weights, num_non_trainable_weights)\n        if backend.backend() == \"torch\":\n            self.assertLen(layer.torch_params, num_torch_params)\n\n        # Try calling fit()\n        
init_lora_a_embeddings_value = layer.lora_embeddings_a.numpy()\n        init_lora_b_embeddings_value = layer.lora_embeddings_b.numpy()\n        x = np.random.randint(0, 9, size=(64, 3))\n        y = np.random.random((64, 3, 16))\n        model = models.Sequential([layer])\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        model.fit(x, y)\n\n        final_lora_a_embeddings_value = layer.lora_embeddings_a.numpy()\n        final_lora_b_embeddings_value = layer.lora_embeddings_b.numpy()\n        diff_a = np.max(\n            np.abs(init_lora_a_embeddings_value - final_lora_a_embeddings_value)\n        )\n        diff_b = np.max(\n            np.abs(init_lora_b_embeddings_value - final_lora_b_embeddings_value)\n        )\n        self.assertGreater(diff_a, 0.0)\n        self.assertGreater(diff_b, 0.0)\n\n        # Try saving and reloading the model\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"quantized_lora_model.keras\"\n        )\n        model.save(temp_filepath)\n        new_model = saving.load_model(temp_filepath)\n        self.assertTrue(new_model.layers[0].lora_enabled)\n        self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5)\n\n        # Try saving and reloading the model's weights only\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"quantized_lora_model.weights.h5\"\n        )\n        model.save_weights(temp_filepath)\n        new_model = models.Sequential(\n            [layers.Input((3,), dtype=\"int32\"), layers.Embedding(10, 16)]\n        )\n        new_model.quantize(mode)\n        new_model.load_weights(temp_filepath)\n        self.assertFalse(new_model.layers[0].lora_enabled)\n        self.assertAllClose(model.predict(x), new_model.predict(x), atol=0.5)\n\n        # Try loading a normal checkpoint into a lora model\n        new_model.save_weights(temp_filepath)\n        model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), 
new_model.predict(x), atol=0.5)\n\n        # Test export and TFSMLayer reloading when using tensorflow backend\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n            ref_input = tf.random.normal((32, 3))\n            ref_output = model(ref_input)\n            model.export(temp_filepath, format=\"tf_saved_model\")\n            reloaded_layer = export.TFSMLayer(temp_filepath)\n            self.assertAllClose(\n                reloaded_layer(ref_input), ref_output, atol=1e-7\n            )\n            self.assertLen(reloaded_layer.weights, len(model.weights))\n            self.assertLen(\n                reloaded_layer.trainable_weights, len(model.trainable_weights)\n            )\n            self.assertLen(\n                reloaded_layer.non_trainable_weights,\n                len(model.non_trainable_weights),\n            )\n\n    def test_legacy_load_own_variables(self):\n        # In previous versions, `load_own_variables` accepted a store with\n        # numeric keys.\n        float32_store = {\n            \"0\": np.random.random((10, 16)).astype(\"float32\"),\n        }\n        int8_store = {\n            \"0\": np.random.randint(-128, 127, size=(10, 16), dtype=\"int8\"),\n            \"1\": np.random.random((10,)).astype(\"float32\"),\n        }\n        int4_store = {\n            \"0\": np.random.randint(-128, 127, size=(10, 8), dtype=\"int8\"),\n            \"1\": np.random.random((10,)).astype(\"float32\"),\n        }\n\n        # Test float32 layer.\n        layer = layers.Embedding(10, 16)\n        layer.build()\n        layer.load_own_variables(float32_store)\n        self.assertAllClose(layer._embeddings, float32_store[\"0\"])\n\n        # Test int8-quantized layer.\n        layer = layers.Embedding(10, 16, dtype=\"int8_from_float32\")\n        layer.build()\n        layer.load_own_variables(int8_store)\n        
self.assertAllClose(layer._embeddings, int8_store[\"0\"])\n        self.assertAllClose(layer.embeddings_scale, int8_store[\"1\"])\n\n        # Test int4-quantized layer.\n        layer = layers.Embedding(10, 16, dtype=\"int4_from_float32\")\n        layer.build()\n        layer.load_own_variables(int4_store)\n        self.assertAllClose(layer._embeddings, int4_store[\"0\"])\n        self.assertAllClose(layer.embeddings_scale, int4_store[\"1\"])\n\n    def test_embedding_int8_custom_quantizer(self):\n        \"\"\"\n        Test custom quantizer serialization for embedding layer with\n        int8 quantization.\n        \"\"\"\n        # Setup\n        weight_range = (-50, 50)\n        config = Int8QuantizationConfig(\n            weight_quantizer=AbsMaxQuantizer(axis=-1, value_range=weight_range),\n        )\n\n        # Build & Quantize\n        layer = layers.Embedding(input_dim=100, output_dim=16)\n        layer.build(None)\n        layer.quantize(\"int8\", config=config)\n\n        # Serialize & Deserialize\n        serialized = layer.get_config()\n        new_layer = layers.Embedding.from_config(serialized)\n\n        # Verify\n        self.assertIsInstance(\n            new_layer.quantization_config, Int8QuantizationConfig\n        )\n        quantizer = new_layer.quantization_config.weight_quantizer\n        self.assertIsInstance(quantizer, AbsMaxQuantizer)\n        self.assertEqual(quantizer.axis, (-1,))\n        self.assertAllEqual(quantizer.value_range, weight_range)\n\n    @parameterized.named_parameters(\n        (\"grouped_block_64\", 64),\n        (\"grouped_block_128\", 128),\n        (\"per_channel_none\", None),\n        (\"per_channel_neg1\", -1),\n    )\n    def test_int4_quantization_block_size(self, block_size):\n        \"\"\"Test int4 quantization with different block_size configurations.\"\"\"\n\n        input_dim, output_dim = 100, 256\n        layer = layers.Embedding(input_dim=input_dim, output_dim=output_dim)\n        layer.build()\n\n   
     x = np.random.randint(0, input_dim, size=(4, 8))\n        y_float = layer(x)\n\n        # Create config with specified block_size\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Verify block_size is stored\n        self.assertEqual(layer._int4_block_size, block_size)\n\n        # Verify embeddings_scale shape\n        if block_size is None or block_size == -1:\n            # Per-channel: one scale per vocabulary item\n            expected_scale_shape = (input_dim,)\n        else:\n            # Sub-channel: n_groups scales per vocabulary item\n            n_groups = math.ceil(output_dim / block_size)\n            expected_scale_shape = (input_dim, n_groups)\n\n        self.assertEqual(layer.embeddings_scale.shape, expected_scale_shape)\n\n        # Verify outputs are reasonable\n        y_quantized = layer(x)\n        mse = ops.mean(ops.square(y_float - y_quantized))\n        self.assertLess(mse, 1e-3)\n\n    @parameterized.named_parameters(\n        (\"grouped_block_64\", 64),\n        (\"grouped_block_128\", 128),\n        (\"per_channel_none\", None),\n    )\n    def test_int4_block_size_serialization(self, block_size):\n        \"\"\"Test that block_size is preserved through serialization.\"\"\"\n\n        input_dim, output_dim = 50, 128\n        layer = layers.Embedding(input_dim=input_dim, output_dim=output_dim)\n        layer.build()\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Get output before serialization\n        x = np.random.randint(0, input_dim, size=(2, 8))\n        y_before = layer(x)\n\n        # Save and load model to test full serialization roundtrip\n        model = models.Sequential([layer])\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"int4_block_size_embedding_model.keras\"\n        )\n        model.save(temp_filepath)\n        loaded_model = 
saving.load_model(temp_filepath)\n\n        # Verify block_size is preserved\n        loaded_layer = loaded_model.layers[0]\n        self.assertIsInstance(\n            loaded_layer.quantization_config, Int4QuantizationConfig\n        )\n        self.assertEqual(\n            loaded_layer.quantization_config.block_size, block_size\n        )\n\n        # Verify outputs match after deserialization\n        y_after = loaded_model(x)\n        self.assertAllClose(y_before, y_after)\n\n    @parameterized.named_parameters(\n        (\"grouped_block_64\", 64),\n        (\"per_channel\", None),\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_int4_block_size_with_lora(self, block_size):\n        \"\"\"Test int4 quantization with LoRA and different block_size.\"\"\"\n        input_dim, output_dim = 50, 128\n        layer = layers.Embedding(input_dim=input_dim, output_dim=output_dim)\n        layer.build()\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n        layer.enable_lora(rank=4)\n\n        x = np.random.randint(0, input_dim, size=(4, 8))\n\n        # Should run without error\n        y = layer(x)\n        self.assertEqual(y.shape, (4, 8, output_dim))\n\n    def test_int4_grouped_vs_perchannel_scale_shapes(self):\n        \"\"\"Test that grouped and per-channel have different scale shapes.\"\"\"\n\n        input_dim, output_dim = 100, 256\n        block_size = 64\n\n        # Per-channel layer\n        layer_pc = layers.Embedding(input_dim=input_dim, output_dim=output_dim)\n        layer_pc.build()\n        config_pc = Int4QuantizationConfig(block_size=None)\n        layer_pc.quantize(\"int4\", config=config_pc)\n\n        # Grouped layer\n        layer_grouped = layers.Embedding(\n            input_dim=input_dim, output_dim=output_dim\n        )\n        layer_grouped.build()\n        config_grouped = Int4QuantizationConfig(block_size=block_size)\n        
layer_grouped.quantize(\"int4\", config=config_grouped)\n\n        # Verify different scale shapes\n        self.assertEqual(layer_pc.embeddings_scale.shape, (input_dim,))\n        n_groups = math.ceil(output_dim / block_size)\n        self.assertEqual(\n            layer_grouped.embeddings_scale.shape, (input_dim, n_groups)\n        )\n\n    @parameterized.named_parameters(\n        (\"grouped_block_4\", 4),\n        (\"grouped_block_8\", 8),\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_subchannel_g_idx_created(self, block_size):\n        \"\"\"Test that g_idx is created for sub-channel int4 quantization.\"\"\"\n        input_dim, output_dim = 10, 16\n        layer = layers.Embedding(input_dim=input_dim, output_dim=output_dim)\n        layer.build()\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Verify g_idx is created\n        self.assertTrue(hasattr(layer, \"g_idx\"))\n\n        # Verify g_idx shape (output_dim for embedding)\n        self.assertEqual(layer.g_idx.shape, (output_dim,))\n\n        # Verify g_idx values (should map each column to its group)\n        expected_g_idx = np.arange(output_dim) // block_size\n        self.assertAllClose(layer.g_idx, expected_g_idx)\n\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_perchannel_no_g_idx(self):\n        \"\"\"Test that per-channel int4 does NOT create g_idx.\"\"\"\n        layer = layers.Embedding(input_dim=10, output_dim=16)\n        layer.build()\n\n        config = Int4QuantizationConfig(block_size=None)  # Per-channel\n        layer.quantize(\"int4\", config=config)\n\n        # Verify g_idx is NOT created for per-channel\n        self.assertFalse(hasattr(layer, \"g_idx\"))\n\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), 
reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_subchannel_g_idx_serialization(self):\n        \"\"\"Test that g_idx is properly serialized and deserialized.\"\"\"\n        input_dim, output_dim = 10, 16\n        block_size = 8\n\n        layer = layers.Embedding(input_dim=input_dim, output_dim=output_dim)\n        layer.build()\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        x = np.array([[1, 2, 3], [4, 5, 6]], dtype=\"int32\")\n        y_before = layer(x)\n        g_idx_before = ops.convert_to_numpy(layer.g_idx)\n\n        # Save and load\n        model = models.Sequential([layer])\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"embedding_int4_g_idx_model.keras\"\n        )\n        model.save(temp_filepath)\n        loaded_model = saving.load_model(temp_filepath)\n\n        # Verify g_idx is preserved\n        loaded_layer = loaded_model.layers[0]\n        self.assertTrue(hasattr(loaded_layer, \"g_idx\"))\n        self.assertAllClose(loaded_layer.g_idx, g_idx_before)\n\n        # Verify outputs match\n        y_after = loaded_model(x)\n        self.assertAllClose(y_before, y_after)\n"
  },
  {
    "path": "keras/src/layers/core/identity.py",
    "content": "from keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.Identity\")\nclass Identity(Layer):\n    \"\"\"Identity layer.\n\n    This layer should be used as a placeholder when no operation is to be\n    performed. The layer just returns its `inputs` argument as output.\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.supports_masking = True\n\n        self._build_at_init()\n\n    def call(self, inputs):\n        return inputs\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def compute_output_spec(self, inputs):\n        return tree.map_structure(\n            lambda x: KerasTensor(x.shape, dtype=x.dtype, sparse=x.sparse),\n            inputs,\n        )\n"
  },
  {
    "path": "keras/src/layers/core/identity_test.py",
    "content": "import pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass IdentityTest(testing.TestCase):\n    @parameterized.named_parameters(\n        [\n            {\"testcase_name\": \"dense\", \"sparse\": False},\n            {\"testcase_name\": \"sparse\", \"sparse\": True},\n        ]\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_identity_basics(self, sparse):\n        if sparse and not backend.SUPPORTS_SPARSE_TENSORS:\n            pytest.skip(\"Backend does not support sparse tensors.\")\n        self.run_layer_test(\n            layers.Identity,\n            init_kwargs={},\n            input_shape=(2, 3),\n            input_sparse=sparse,\n            expected_output_shape=(2, 3),\n            expected_output_sparse=sparse,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            run_training_check=not sparse,\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n"
  },
  {
    "path": "keras/src/layers/core/input_layer.py",
    "content": "import warnings\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.ops.node import Node\n\n\n@keras_export(\"keras.layers.InputLayer\")\nclass InputLayer(Layer):\n    def __init__(\n        self,\n        shape=None,\n        batch_size=None,\n        dtype=None,\n        sparse=None,\n        ragged=None,\n        batch_shape=None,\n        input_tensor=None,\n        optional=False,\n        name=None,\n        **kwargs,\n    ):\n        super().__init__(name=name)\n\n        if \"input_shape\" in kwargs:\n            warnings.warn(\n                \"Argument `input_shape` is deprecated. Use `shape` instead.\"\n            )\n            shape = kwargs.pop(\"input_shape\")\n        if \"batch_input_shape\" in kwargs:\n            batch_shape = kwargs.pop(\"batch_input_shape\")\n\n        if input_tensor is not None:\n            if not isinstance(input_tensor, backend.KerasTensor):\n                raise ValueError(\n                    \"Argument `input_tensor` must be a KerasTensor. 
\"\n                    f\"Received invalid type: input_tensor={input_tensor} \"\n                    f\"(of type {type(input_tensor)})\"\n                )\n            if batch_size is not None:\n                if (\n                    len(input_tensor.shape) < 1\n                    or input_tensor.shape[0] != batch_size\n                ):\n                    raise ValueError(\n                        \"When providing the `input_tensor` argument, you \"\n                        \"cannot provide an incompatible `batch_size` argument.\"\n                    )\n            if shape is not None:\n                if (\n                    len(shape) != len(input_tensor.shape) - 1\n                    or shape != input_tensor.shape[1:]\n                ):\n                    raise ValueError(\n                        \"When providing the `input_tensor` argument, you \"\n                        \"cannot provide an incompatible `shape` argument.\"\n                    )\n            if batch_shape is not None and batch_shape != input_tensor.shape:\n                raise ValueError(\n                    \"When providing the `input_tensor` argument, you \"\n                    \"cannot provide an incompatible `batch_shape` argument.\"\n                )\n            if dtype is not None and input_tensor.dtype != dtype:\n                raise ValueError(\n                    \"When providing the `input_tensor` argument, you \"\n                    \"cannot provide an incompatible `dtype` argument.\"\n                )\n            if sparse is not None and input_tensor.sparse != sparse:\n                raise ValueError(\n                    \"When providing the `input_tensor` argument, you \"\n                    \"cannot provide an incompatible `sparse` argument.\"\n                )\n            batch_shape = input_tensor.shape\n            dtype = input_tensor.dtype\n            sparse = input_tensor.sparse\n        else:\n            if shape is not None and 
batch_shape is not None:\n                raise ValueError(\n                    \"You cannot pass both `shape` and `batch_shape` at the \"\n                    \"same time.\"\n                )\n            if batch_size is not None and batch_shape is not None:\n                raise ValueError(\n                    \"You cannot pass both `batch_size` and `batch_shape` \"\n                    \"at the same time.\"\n                )\n            if shape is None and batch_shape is None:\n                raise ValueError(\"You must pass a `shape` argument.\")\n\n            if shape is not None:\n                shape = backend.standardize_shape(shape)\n                batch_shape = (batch_size,) + shape\n\n        self._batch_shape = backend.standardize_shape(batch_shape)\n        self._dtype = backend.standardize_dtype(dtype)\n        self.sparse = bool(sparse)\n        if self.sparse and not backend.SUPPORTS_SPARSE_TENSORS:\n            raise ValueError(\n                f\"`sparse=True` is not supported with the {backend.backend()} \"\n                \"backend\"\n            )\n        self.ragged = bool(ragged)\n        if self.ragged and not backend.SUPPORTS_RAGGED_TENSORS:\n            raise ValueError(\n                f\"`ragged=True` is not supported with the {backend.backend()} \"\n                \"backend\"\n            )\n\n        if input_tensor is None:\n            input_tensor = backend.KerasTensor(\n                shape=batch_shape,\n                dtype=dtype,\n                sparse=sparse,\n                ragged=ragged,\n                name=name,\n            )\n        self._input_tensor = input_tensor\n        Node(operation=self, call_args=(), call_kwargs={}, outputs=input_tensor)\n        self.built = True\n        self.optional = optional\n\n    def call(self):\n        return\n\n    @property\n    def batch_shape(self):\n        return self._batch_shape\n\n    @property\n    def dtype(self):\n        return self._dtype\n\n    def 
get_config(self):\n        return {\n            \"batch_shape\": self.batch_shape,\n            \"dtype\": self.dtype,\n            \"sparse\": self.sparse,\n            \"ragged\": self.ragged,\n            \"name\": self.name,\n            \"optional\": self.optional,\n        }\n\n\n@keras_export([\"keras.layers.Input\", \"keras.Input\"])\ndef Input(\n    shape=None,\n    batch_size=None,\n    dtype=None,\n    sparse=None,\n    ragged=None,\n    batch_shape=None,\n    name=None,\n    tensor=None,\n    optional=False,\n):\n    \"\"\"Used to instantiate a Keras tensor.\n\n    A Keras tensor is a symbolic tensor-like object, which we augment with\n    certain attributes that allow us to build a Keras model just by knowing the\n    inputs and outputs of the model.\n\n    For instance, if `a`, `b` and `c` are Keras tensors,\n    it becomes possible to do:\n    `model = Model(input=[a, b], output=c)`\n\n    Args:\n        shape: A shape tuple (tuple of integers or `None` objects),\n            not including the batch size.\n            For instance, `shape=(32,)` indicates that the expected input\n            will be batches of 32-dimensional vectors. Elements of this tuple\n            can be `None`; `None` elements represent dimensions where the shape\n            is not known and may vary (e.g. sequence length).\n        batch_size: Optional static batch size (integer).\n        dtype: The data type expected by the input, as a string\n            (e.g. `\"float32\"`, `\"int32\"`...)\n        sparse: A boolean specifying whether the expected input will be sparse\n            tensors. Note that, if `sparse` is `False`, sparse tensors can still\n            be passed into the input - they will be densified with a default\n            value of 0. This feature is only supported with the TensorFlow and\n            the JAX backends. Defaults to `False`.\n        ragged: A boolean specifying whether the expected input will be ragged\n            tensors. 
Note that, if `ragged` is `False`, ragged tensors can still\n            be passed into the input - they will be densified with a default\n            value of 0. This feature is only supported with the TensorFlow\n            backend. Defaults to `False`.\n        batch_shape: Optional shape tuple (tuple of integers or `None` objects),\n            including the batch size.\n        name: Optional name string for the layer.\n            Should be unique in a model (do not reuse the same name twice).\n            It will be autogenerated if it isn't provided.\n        tensor: Optional existing tensor to wrap into the `Input` layer.\n            If set, the layer will use this tensor rather\n            than creating a new placeholder tensor.\n        optional: Boolean, whether the input is optional or not.\n            An optional input can accept `None` values.\n\n    Returns:\n      A Keras tensor.\n\n    Example:\n\n    ```python\n    # This is a logistic regression in Keras\n    x = Input(shape=(32,))\n    y = Dense(16, activation='softmax')(x)\n    model = Model(x, y)\n    ```\n    \"\"\"\n    layer = InputLayer(\n        shape=shape,\n        batch_size=batch_size,\n        dtype=dtype,\n        sparse=sparse,\n        ragged=ragged,\n        batch_shape=batch_shape,\n        name=name,\n        input_tensor=tensor,\n        optional=optional,\n    )\n    return layer.output\n"
  },
  {
    "path": "keras/src/layers/core/input_layer_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.backend import KerasTensor\nfrom keras.src.layers import InputLayer\n\n\nclass InputLayerTest(testing.TestCase):\n    # Testing happy path for layer without input tensor\n    @parameterized.named_parameters(\n        [\n            {\"testcase_name\": \"dense\"},\n            {\"testcase_name\": \"sparse\", \"sparse\": True},\n            {\"testcase_name\": \"ragged\", \"ragged\": True},\n        ]\n    )\n    def test_input_basic(self, sparse=False, ragged=False):\n        input_shape = (2, 3)\n        batch_size = 4\n        dtype = \"float32\"\n        ndim = len(tuple((batch_size,) + input_shape))\n\n        init_kwargs = {\n            \"shape\": input_shape,\n            \"batch_size\": batch_size,\n            \"dtype\": dtype,\n            \"sparse\": sparse,\n            \"ragged\": ragged,\n        }\n\n        if sparse and not backend.SUPPORTS_SPARSE_TENSORS:\n            with self.assertRaisesRegex(\n                ValueError, \"`sparse=True` is not supported\"\n            ):\n                InputLayer(**init_kwargs)\n            return\n        if ragged and not backend.SUPPORTS_RAGGED_TENSORS:\n            with self.assertRaisesRegex(\n                ValueError, \"`ragged=True` is not supported\"\n            ):\n                InputLayer(**init_kwargs)\n            return\n\n        values = InputLayer(**init_kwargs)\n\n        self.assertEqual(values.dtype, dtype)\n        self.assertEqual(values.batch_shape[0], batch_size)\n        self.assertEqual(values.batch_shape[1:], input_shape)\n        self.assertEqual(values.sparse, sparse)\n        self.assertEqual(values.ragged, ragged)\n        self.assertEqual(values.trainable, True)\n        self.assertIsInstance(values.output, KerasTensor)\n        self.assertEqual(values.output.ndim, ndim)\n        self.assertEqual(values.output.dtype, 
dtype)\n        self.assertEqual(values.output.sparse, sparse)\n        self.assertEqual(values.output.ragged, ragged)\n\n    # Testing shape is not None and batch_shape is not None condition\n    def test_input_error1(self):\n        input_shape = (2, 3)\n\n        with self.assertRaisesRegex(\n            ValueError, \"cannot pass both `shape` and `batch_shape`\"\n        ):\n            InputLayer(shape=input_shape, batch_shape=input_shape)\n\n    # Testing batch_size is not None and batch_shape is not None\n    def test_input_error2(self):\n        input_shape = (2, 3)\n        batch_size = 4\n\n        with self.assertRaisesRegex(\n            ValueError, \"cannot pass both `batch_size` and `batch_shape`\"\n        ):\n            InputLayer(batch_size=batch_size, batch_shape=input_shape)\n\n    # Testing shape is None and batch_shape is None\n    def test_input_error3(self):\n        with self.assertRaisesRegex(ValueError, \"pass a `shape` argument.\"):\n            InputLayer(shape=None, batch_shape=None)\n\n    # Testing Input tensor is not Keras tensor\n    def test_input_tensor_error(self):\n        input_shape = (2, 3)\n        batch_size = 4\n        input_tensor = np.zeros(input_shape)\n\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `input_tensor` must be a KerasTensor\"\n        ):\n            InputLayer(\n                shape=input_shape,\n                batch_size=batch_size,\n                input_tensor=input_tensor,\n            )\n\n    # Testing happy path for layer with input tensor\n    def testing_input_tensor(self):\n        input_shape = (2, 3)\n        dtype = \"float32\"\n        input_tensor = KerasTensor(shape=input_shape, dtype=dtype)\n\n        layer = InputLayer(\n            input_tensor=input_tensor,\n        )\n\n        self.assertEqual(layer.dtype, dtype)\n        self.assertEqual(layer.batch_shape, (2, 3))\n        self.assertEqual(layer.trainable, True)\n        
self.assertIsInstance(layer.output, KerasTensor)\n        self.assertEqual(layer.output, input_tensor)\n        self.assertEqual(layer.output.ndim, input_tensor.ndim)\n        self.assertEqual(layer.output.dtype, dtype)\n\n    def test_input_shape_deprecated(self):\n        input_shape = (2, 3)\n        batch_size = 4\n        dtype = \"float32\"\n\n        with self.assertWarnsRegex(\n            UserWarning,\n            \"Argument `input_shape` is deprecated. Use `shape` instead.\",\n        ):\n            layer = InputLayer(\n                input_shape=input_shape, batch_size=batch_size, dtype=dtype\n            )\n\n        self.assertEqual(layer.batch_shape[0], batch_size)\n        self.assertEqual(layer.batch_shape[1:], input_shape)\n        self.assertEqual(layer.dtype, dtype)\n        self.assertIsInstance(layer.output, KerasTensor)\n\n    def test_call_method(self):\n        layer = InputLayer(shape=(32,))\n        output = layer.call()\n        self.assertIsNone(output)\n\n    def test_numpy_shape(self):\n        # non-python int type shapes should be ok\n        InputLayer(shape=(np.int64(32),))\n\n    def test_invalid_arg_combinations(self):\n        input_tensor = KerasTensor(shape=(2, 3), dtype=\"float32\")\n\n        with self.assertRaisesRegex(\n            ValueError, \"cannot provide an incompatible `shape`\"\n        ):\n            _ = InputLayer(\n                shape=(2, 4),\n                input_tensor=input_tensor,\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"cannot provide an incompatible `batch_shape`\"\n        ):\n            _ = InputLayer(\n                batch_shape=(2, 4),\n                input_tensor=input_tensor,\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"cannot provide an incompatible `batch_size`\"\n        ):\n            _ = InputLayer(\n                batch_size=5,\n                input_tensor=input_tensor,\n            )\n        with 
self.assertRaisesRegex(\n            ValueError, \"cannot provide an incompatible `dtype`\"\n        ):\n            _ = InputLayer(\n                dtype=\"float16\",\n                input_tensor=input_tensor,\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"cannot provide an incompatible `sparse`\"\n        ):\n            _ = InputLayer(\n                sparse=True,\n                input_tensor=input_tensor,\n            )\n\n        # This works\n        _ = InputLayer(\n            shape=(3,),\n            batch_size=2,\n            sparse=False,\n            dtype=\"float32\",\n            input_tensor=input_tensor,\n        )\n"
  },
  {
    "path": "keras/src/layers/core/lambda_layer.py",
    "content": "import inspect\nimport types\n\nfrom keras.src import backend\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils import python_utils\n\n\n@keras_export(\"keras.layers.Lambda\")\nclass Lambda(Layer):\n    \"\"\"Wraps arbitrary expressions as a `Layer` object.\n\n    The `Lambda` layer exists so that arbitrary expressions can be used\n    as a `Layer` when constructing Sequential\n    and Functional API models. `Lambda` layers are best suited for simple\n    operations or quick experimentation. For more advanced use cases,\n    prefer writing new subclasses of `Layer`.\n\n    WARNING: `Lambda` layers have (de)serialization limitations!\n\n    The main reason to subclass `Layer` instead of using a\n    `Lambda` layer is saving and inspecting a model. `Lambda` layers\n    are saved by serializing the Python bytecode, which is fundamentally\n    non-portable and potentially unsafe.\n    They should only be loaded in the same environment where\n    they were saved. Subclassed layers can be saved in a more portable way\n    by overriding their `get_config()` method. Models that rely on\n    subclassed Layers are also often easier to visualize and reason about.\n\n    Example:\n\n    ```python\n    # add a x -> x^2 layer\n    model.add(Lambda(lambda x: x ** 2))\n    ```\n\n    Args:\n        function: The function to be evaluated. Takes input tensor as first\n            argument.\n        output_shape: Expected output shape from function. This argument\n            can usually be inferred if not explicitly provided.\n            Can be a tuple or function. 
If a tuple, it only specifies\n            the first dimension onward; sample dimension is assumed\n            either the same as the input:\n            `output_shape = (input_shape[0], ) + output_shape` or,\n            the input is `None` and the sample dimension is also `None`:\n            `output_shape = (None, ) + output_shape`.\n            If a function, it specifies the\n            entire shape as a function of the input shape:\n            `output_shape = f(input_shape)`.\n        mask: Either None (indicating no masking) or a callable with the same\n            signature as the `compute_mask` layer method, or a tensor\n            that will be returned as output mask regardless\n            of what the input is.\n        arguments: Optional dictionary of keyword arguments to be passed to the\n            function.\n    \"\"\"\n\n    def __init__(\n        self, function, output_shape=None, mask=None, arguments=None, **kwargs\n    ):\n        super().__init__(**kwargs)\n\n        self.arguments = arguments or {}\n        self.function = function\n\n        if mask is not None:\n            self.supports_masking = True\n        else:\n            self.supports_masking = False\n        self.mask = mask\n        self._output_shape = output_shape\n\n        # Warning on every invocation will be quite irksome in Eager mode.\n        self._already_warned = False\n\n        function_args = inspect.getfullargspec(function).args\n        self._fn_expects_training_arg = \"training\" in function_args\n        self._fn_expects_mask_arg = \"mask\" in function_args\n\n    def compute_output_shape(self, input_shape):\n        if self._output_shape is None:\n            # Leverage backend shape inference\n            try:\n                inputs = tree.map_shape_structure(\n                    lambda x: backend.KerasTensor(x, dtype=self.compute_dtype),\n                    input_shape,\n                )\n                output_spec = 
backend.compute_output_spec(self.call, inputs)\n                return tree.map_structure(lambda x: x.shape, output_spec)\n            except:\n                raise NotImplementedError(\n                    \"We could not automatically infer the shape of \"\n                    \"the Lambda's output. Please specify the `output_shape` \"\n                    \"argument for this Lambda layer.\"\n                )\n\n        if callable(self._output_shape):\n            return self._output_shape(input_shape)\n\n        # Output shapes are passed directly and don't include batch dimension.\n        batch_size = tree.flatten(input_shape)[0]\n\n        def _add_batch(shape):\n            return (batch_size,) + shape\n\n        return tree.map_shape_structure(_add_batch, self._output_shape)\n\n    def call(self, inputs, mask=None, training=None):\n        # We must copy for thread safety,\n        # but it only needs to be a shallow copy.\n        kwargs = {k: v for k, v in self.arguments.items()}\n        if self._fn_expects_mask_arg:\n            kwargs[\"mask\"] = mask\n        if self._fn_expects_training_arg:\n            kwargs[\"training\"] = training\n        return self.function(inputs, **kwargs)\n\n    def compute_mask(self, inputs, mask=None):\n        if callable(self.mask):\n            return self.mask(inputs, mask)\n        return self.mask\n\n    def get_config(self):\n        config = {\n            \"function\": self._serialize_function_to_config(self.function),\n        }\n        if self._output_shape is not None:\n            if callable(self._output_shape):\n                output_shape = self._serialize_function_to_config(\n                    self._output_shape\n                )\n            else:\n                output_shape = self._output_shape\n            config[\"output_shape\"] = output_shape\n        if self.mask is not None:\n            if callable(self.mask):\n                mask = self._serialize_function_to_config(self.mask)\n       
     else:\n                mask = serialization_lib.serialize_keras_object(self.mask)\n            config[\"mask\"] = mask\n        config[\"arguments\"] = serialization_lib.serialize_keras_object(\n            self.arguments\n        )\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    def _serialize_function_to_config(self, fn):\n        if isinstance(fn, types.LambdaType) and fn.__name__ == \"<lambda>\":\n            code, defaults, closure = python_utils.func_dump(fn)\n            return {\n                \"class_name\": \"__lambda__\",\n                \"config\": {\n                    \"code\": code,\n                    \"defaults\": defaults,\n                    \"closure\": closure,\n                },\n            }\n        elif callable(fn):\n            return serialization_lib.serialize_keras_object(fn)\n        raise ValueError(\n            \"Invalid input type for serialization. \"\n            f\"Received: {fn} of type {type(fn)}.\"\n        )\n\n    @staticmethod\n    def _raise_for_lambda_deserialization(safe_mode):\n        if safe_mode:\n            raise ValueError(\n                \"Requested the deserialization of a `Lambda` layer whose \"\n                \"`function` is a Python lambda. This carries a potential risk \"\n                \"of arbitrary code execution and thus it is disallowed by \"\n                \"default. 
If you trust the source of the artifact, you can \"\n                \"override this error by passing `safe_mode=False` to the \"\n                \"loading function, or calling \"\n                \"`keras.config.enable_unsafe_deserialization()`.\"\n            )\n\n    @classmethod\n    def from_config(cls, config, custom_objects=None, safe_mode=None):\n        safe_mode = safe_mode or serialization_lib.in_safe_mode()\n        fn_config = config[\"function\"]\n        if (\n            isinstance(fn_config, dict)\n            and \"class_name\" in fn_config\n            and fn_config[\"class_name\"] == \"__lambda__\"\n        ):\n            cls._raise_for_lambda_deserialization(safe_mode)\n            inner_config = fn_config[\"config\"]\n            fn = python_utils.func_load(\n                inner_config[\"code\"],\n                defaults=inner_config[\"defaults\"],\n                closure=inner_config[\"closure\"],\n            )\n            config[\"function\"] = fn\n        else:\n            config[\"function\"] = serialization_lib.deserialize_keras_object(\n                fn_config, custom_objects=custom_objects\n            )\n        if \"output_shape\" in config:\n            fn_config = config[\"output_shape\"]\n            if (\n                isinstance(fn_config, dict)\n                and \"class_name\" in fn_config\n                and fn_config[\"class_name\"] == \"__lambda__\"\n            ):\n                cls._raise_for_lambda_deserialization(safe_mode)\n                inner_config = fn_config[\"config\"]\n                fn = python_utils.func_load(\n                    inner_config[\"code\"],\n                    defaults=inner_config[\"defaults\"],\n                    closure=inner_config[\"closure\"],\n                )\n                config[\"output_shape\"] = fn\n            else:\n                output_shape = serialization_lib.deserialize_keras_object(\n                    fn_config, custom_objects=custom_objects\n       
         )\n                if isinstance(output_shape, list) and all(\n                    isinstance(e, (int, type(None))) for e in output_shape\n                ):\n                    output_shape = tuple(output_shape)\n                config[\"output_shape\"] = output_shape\n\n        if \"arguments\" in config:\n            config[\"arguments\"] = serialization_lib.deserialize_keras_object(\n                config[\"arguments\"], custom_objects=custom_objects\n            )\n        return cls(**config)\n"
  },
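The `output_shape` resolution logic in `Lambda.compute_output_shape` above can be sketched in plain Python (no Keras dependency; `resolve_output_shape` is a hypothetical helper name, and this sketch only handles flat tuple shapes, not nested structures):

```python
# Sketch of how Lambda resolves an explicit `output_shape`: a tuple omits
# the batch dimension and gets it prepended from the input shape, while a
# callable receives the full input shape and returns the full output shape.

def resolve_output_shape(input_shape, output_shape):
    """Mimics Lambda.compute_output_shape for an explicit `output_shape`."""
    if callable(output_shape):
        # A callable specifies the entire shape, batch dimension included.
        return output_shape(input_shape)
    # A tuple specifies only the non-batch dimensions; reuse the input's
    # batch dimension (which may be None).
    return (input_shape[0],) + tuple(output_shape)


print(resolve_output_shape((2, 3), (6,)))         # (2, 6)
print(resolve_output_shape((None, 3), (6,)))      # (None, 6)
print(resolve_output_shape((2, 3), lambda s: s))  # (2, 3)
```

This mirrors why the docstring distinguishes the two forms: tuples are batch-agnostic, callables are authoritative for the whole shape.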
  {
    "path": "keras/src/layers/core/lambda_layer_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass LambdaTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_lambda_basics(self):\n        self.run_layer_test(\n            layers.Lambda,\n            init_kwargs={\n                \"function\": ops.square,\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            custom_objects={\"square\": ops.square},\n        )\n        self.run_layer_test(\n            layers.Lambda,\n            init_kwargs={\"function\": ops.square, \"mask\": ops.ones((2, 3))},\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 4),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n            custom_objects={\"square\": ops.square},\n        )\n\n        def stacker(x):\n            return ops.concatenate([x, x], axis=1)\n\n        self.run_layer_test(\n            layers.Lambda,\n            init_kwargs={\"function\": stacker, \"output_shape\": (6,)},\n            input_shape=(2, 3),\n            expected_output_shape=(2, 6),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            custom_objects={\"stacker\": stacker},\n        )\n\n        def stacker_shape(s):\n            return (s[0], s[1] * 2)\n\n        self.run_layer_test(\n            layers.Lambda,\n            init_kwargs={\n           
     \"function\": stacker,\n                \"output_shape\": stacker_shape,\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 6),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            custom_objects={\"stacker\": stacker, \"stacker_shape\": stacker_shape},\n        )\n\n    def test_correctness(self):\n        layer = layers.Lambda(lambda x: x**2)\n        output = layer(2 * np.ones((2, 3)))\n        self.assertAllClose(4 * np.ones((2, 3)), output)\n\n        # Test serialization roundtrip\n        config = layer.get_config()\n        layer = layers.Lambda.from_config(config, safe_mode=False)\n        output = layer(2 * np.ones((2, 3)))\n        self.assertAllClose(4 * np.ones((2, 3)), output)\n\n    def test_correctness_lambda_shape(self):\n        layer = layers.Lambda(lambda x: x**2, output_shape=lambda x: x)\n        output = layer(2 * np.ones((2, 3)))\n        self.assertAllClose(4 * np.ones((2, 3)), output)\n\n        # Test serialization roundtrip\n        config = layer.get_config()\n        layer = layers.Lambda.from_config(config, safe_mode=False)\n        output = layer(2 * np.ones((2, 3)))\n        self.assertAllClose(4 * np.ones((2, 3)), output)\n"
  },
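The docstring's warning about bytecode (de)serialization can be made concrete with a minimal sketch. The real implementation is `keras.src.utils.python_utils.func_dump`/`func_load`; this stripped-down version (hypothetical `dump_lambda`/`load_lambda` names) handles closure-free lambdas only:

```python
# Minimal sketch of serializing a lambda "by bytecode". marshal output is
# CPython-version-specific, which is why such artifacts are non-portable
# and unsafe to load from untrusted sources (hence safe_mode).
import marshal
import types


def dump_lambda(fn):
    # Marshal the raw code object of the function.
    return marshal.dumps(fn.__code__)


def load_lambda(payload):
    # Executing a deserialized code object is arbitrary code execution,
    # which is what `safe_mode=True` guards against.
    code = marshal.loads(payload)
    return types.FunctionType(code, globals())


square = lambda x: x**2
restored = load_lambda(dump_lambda(square))
print(restored(4))  # 16
```

The round trip works within one interpreter, but a payload marshalled under a different Python version may fail to load or misbehave, matching the "only load in the same environment" caveat above.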
  {
    "path": "keras/src/layers/core/masking.py",
"content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.saving.serialization_lib import deserialize_keras_object\n\n\n@keras_export(\"keras.layers.Masking\")\nclass Masking(Layer):\n    \"\"\"Masks a sequence by using a mask value to skip timesteps.\n\n    For each timestep in the input tensor (dimension #1 in the tensor),\n    if all values in the input tensor at that timestep\n    are equal to `mask_value`, then the timestep will be masked (skipped)\n    in all downstream layers (as long as they support masking).\n\n    If any downstream layer does not support masking yet receives such\n    an input mask, an exception will be raised.\n\n    Example:\n\n    Consider a NumPy data array `x` of shape `(samples, timesteps, features)`,\n    to be fed to an LSTM layer. You want to mask timesteps #3 and #5 because\n    you lack data for these timesteps. You can:\n\n    - Set `x[:, 3, :] = 0.` and `x[:, 5, :] = 0.`\n    - Insert a `Masking` layer with `mask_value=0.` before the LSTM layer:\n\n    ```python\n    samples, timesteps, features = 32, 10, 8\n    inputs = np.random.random([samples, timesteps, features]).astype(np.float32)\n    inputs[:, 3, :] = 0.\n    inputs[:, 5, :] = 0.\n\n    model = keras.models.Sequential()\n    model.add(keras.layers.Masking(mask_value=0.0))\n    model.add(keras.layers.LSTM(32))\n    output = model(inputs)\n    # Timesteps 3 and 5 will be skipped in the LSTM computation.\n    ```\n\n    Note: in the Keras masking convention, a masked timestep is denoted by\n    a mask value of `False`, while a non-masked (i.e. 
usable) timestep\n    is denoted by a mask value of `True`.\n    \"\"\"\n\n    def __init__(self, mask_value=0.0, **kwargs):\n        super().__init__(**kwargs)\n        # `mask_value` can be a serialized tensor, hence verify it\n        if isinstance(mask_value, dict) and mask_value.get(\"config\", None):\n            mask_value = deserialize_keras_object(mask_value)\n        self.mask_value = mask_value\n        self.supports_masking = True\n\n        self._build_at_init()\n\n    def compute_mask(self, inputs, mask=None):\n        return ops.any(ops.not_equal(inputs, self.mask_value), axis=-1)\n\n    def call(self, inputs):\n        boolean_mask = ops.any(\n            ops.not_equal(inputs, self.mask_value), axis=-1, keepdims=True\n        )\n        # Set masked outputs to 0\n        outputs = inputs * backend.cast(boolean_mask, dtype=inputs.dtype)\n        # Compute the mask and outputs simultaneously.\n        backend.set_keras_mask(outputs, mask=ops.squeeze(boolean_mask, axis=-1))\n        return outputs\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\"mask_value\": self.mask_value}\n        return {**base_config, **config}\n"
  },
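The mask computation in `Masking.compute_mask` and `Masking.call` above reduces to a single NumPy expression; a sketch using the same example data as the tests that follow:

```python
# NumPy equivalent of ops.any(ops.not_equal(inputs, mask_value), axis=-1):
# a timestep is masked (False) only when *all* of its features equal
# mask_value.
import numpy as np

mask_value = 0.0
x = np.array(
    [
        [[0.0, 0.0], [1.0, 2.0], [0.0, 0.0]],
        [[2.0, 2.0], [0.0, 0.0], [2.0, 1.0]],
    ]
)

mask = np.any(x != mask_value, axis=-1)
print(mask.tolist())  # [[False, True, False], [True, False, True]]

# Equivalent of call(): zero out the masked timesteps.
keep = np.any(x != mask_value, axis=-1, keepdims=True)
outputs = x * keep.astype(x.dtype)
```

Note the reduction is over the last (feature) axis, so a timestep with any nonzero feature survives.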
  {
    "path": "keras/src/layers/core/masking_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\n\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.saving import load_model\n\n\nclass MaskingTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_masking_basics(self):\n        self.run_layer_test(\n            layers.Masking,\n            init_kwargs={\"mask_value\": 0.0},\n            input_shape=(2, 3, 2),\n            expected_output_shape=(2, 3, 2),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_masking_correctness(self):\n        x = np.array(\n            [\n                [[0.0, 0.0], [1.0, 2.0], [0.0, 0.0]],\n                [[2.0, 2.0], [0.0, 0.0], [2.0, 1.0]],\n            ]\n        )\n        expected_mask = [[False, True, False], [True, False, True]]\n\n        layer = layers.Masking(mask_value=0.0)\n        self.assertAllClose(layer.compute_mask(x), expected_mask)\n\n        test_obj = self\n\n        class TestLayer(layers.Layer):\n            def __init__(self, **kwargs):\n                super().__init__(**kwargs)\n                self.supports_masking = True\n\n            def compute_output_shape(self, input_shape):\n                return input_shape\n\n            def call(self, inputs, mask=None):\n                test_obj.assertIsNotNone(mask)\n                test_obj.assertAllClose(mask, expected_mask)\n                return inputs\n\n        model = models.Sequential(\n            [\n                layers.Masking(mask_value=0.0),\n                TestLayer(),\n            ]\n        )\n        model(x)\n\n    @pytest.mark.requires_trainable_backend\n    def 
test_masking_with_tensor(self):\n        model = models.Sequential(\n            [\n                layers.Masking(mask_value=ops.convert_to_tensor([0.0])),\n                layers.LSTM(1),\n            ]\n        )\n        x = np.array(\n            [\n                [[0.0, 0.0], [1.0, 2.0], [0.0, 0.0]],\n                [[2.0, 2.0], [0.0, 0.0], [2.0, 1.0]],\n            ]\n        )\n        model(x)\n        temp_filepath = os.path.join(self.get_temp_dir(), \"model.keras\")\n        model.save(temp_filepath)\n        reload_model = load_model(temp_filepath)\n        reload_model(x)\n"
  },
  {
    "path": "keras/src/layers/core/reversible_embedding.py",
    "content": "import copy\nimport math\n\nfrom keras.src import dtype_policies\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import quantizers\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend import set_keras_mask\nfrom keras.src.quantizers.quantization_config import QuantizationConfig\nfrom keras.src.quantizers.quantization_config import get_block_size_for_layer\nfrom keras.src.quantizers.quantizers import dequantize_with_sz_map\n\n\n@keras_export(\"keras.layers.ReversibleEmbedding\")\nclass ReversibleEmbedding(layers.Embedding):\n    \"\"\"An embedding layer which can project backwards to the input dim.\n\n    This layer is an extension of `keras.layers.Embedding` for language models.\n    This layer can be called \"in reverse\" with `reverse=True`, in which case the\n    layer will linearly project from `output_dim` back to `input_dim`.\n\n    By default, the reverse projection will use the transpose of the\n    `embeddings` weights to project to `input_dim` (weights are \"tied\"). If\n    `tie_weights=False`, the model will use a separate, trainable variable for\n    reverse projection.\n\n    This layer has no bias terms.\n\n    Args:\n        input_dim: Integer. Size of the vocabulary,\n            i.e. maximum integer index + 1.\n        output_dim: Integer. 
Dimension of the dense embedding.\n        tie_weights: Boolean, whether or not the matrix for embedding and\n            the matrix for the `reverse` projection should share the same\n            weights.\n        embeddings_initializer: Initializer for the `embeddings`\n            matrix (see `keras.initializers`).\n        embeddings_regularizer: Regularizer function applied to\n            the `embeddings` matrix (see `keras.regularizers`).\n        embeddings_constraint: Constraint function applied to\n            the `embeddings` matrix (see `keras.constraints`).\n        mask_zero: Boolean, whether or not the input value 0 is a special\n            \"padding\" value that should be masked out.\n        reverse_dtype: The dtype for the reverse projection computation.\n            Defaults to the `compute_dtype` of the layer.\n        logit_soft_cap: If `logit_soft_cap` is set and `reverse=True`, the\n            output logits will be scaled by\n            `tanh(logits / logit_soft_cap) * logit_soft_cap`. This narrows the\n            range of output logits and can improve training.\n        **kwargs: other keyword arguments passed to `keras.layers.Embedding`,\n            including `name`, `trainable`, `dtype` etc.\n\n    Call arguments:\n        inputs: The tensor inputs to the layer.\n        reverse: Boolean. If `True` the layer will perform a linear projection\n            from `output_dim` to `input_dim`, instead of a normal embedding\n            call. 
Defaults to `False`.\n\n    Example:\n    ```python\n    batch_size = 16\n    vocab_size = 100\n    hidden_dim = 32\n    seq_length = 50\n\n    # Generate random inputs.\n    token_ids = np.random.randint(vocab_size, size=(batch_size, seq_length))\n\n    embedding = keras.layers.ReversibleEmbedding(vocab_size, hidden_dim)\n    # Embed tokens to shape `(batch_size, seq_length, hidden_dim)`.\n    hidden_states = embedding(token_ids)\n    # Project hidden states to shape `(batch_size, seq_length, vocab_size)`.\n    logits = embedding(hidden_states, reverse=True)\n    ```\n\n    References:\n    - [Vaswani et al., 2017](https://arxiv.org/abs/1706.03762)\n    - [Press and Wolf, 2016](https://arxiv.org/abs/1608.05859)\n    \"\"\"\n\n    def __init__(\n        self,\n        input_dim,\n        output_dim,\n        tie_weights=True,\n        embeddings_initializer=\"uniform\",\n        embeddings_regularizer=None,\n        embeddings_constraint=None,\n        mask_zero=False,\n        reverse_dtype=None,\n        logit_soft_cap=None,\n        **kwargs,\n    ):\n        super().__init__(\n            input_dim,\n            output_dim,\n            embeddings_initializer=embeddings_initializer,\n            embeddings_regularizer=embeddings_regularizer,\n            embeddings_constraint=embeddings_constraint,\n            mask_zero=mask_zero,\n            **kwargs,\n        )\n        self.tie_weights = tie_weights\n        self.reverse_dtype = reverse_dtype\n        self.logit_soft_cap = logit_soft_cap\n\n    def build(self, inputs_shape=None):\n        super().build(inputs_shape)\n        if not self.tie_weights and self.quantization_mode not in (\n            \"int8\",\n            \"int4\",\n        ):\n            self.reverse_embeddings = self.add_weight(\n                shape=(self.output_dim, self.input_dim),\n                initializer=self.embeddings_initializer,\n                name=\"reverse_embeddings\",\n                trainable=True,\n            )\n\n   
 def call(self, inputs, reverse=False):\n        if not reverse:\n            result = super().call(inputs)\n            mask = super().compute_mask(inputs)\n            if mask is not None:\n                set_keras_mask(result, mask)\n            return result\n        else:\n            if self.tie_weights:\n                kernel = ops.transpose(self.embeddings)\n            else:\n                kernel = self.reverse_embeddings\n            if self.reverse_dtype is not None:\n                inputs = ops.cast(inputs, self.reverse_dtype)\n                kernel = ops.cast(kernel, self.reverse_dtype)\n            logits = ops.matmul(inputs, kernel)\n            # Optionally soft-cap logits.\n            if self.logit_soft_cap is not None:\n                soft_cap = self.logit_soft_cap\n                logits = ops.multiply(\n                    ops.tanh(ops.divide(logits, soft_cap)), soft_cap\n                )\n            return logits\n\n    def compute_mask(self, inputs, mask=None):\n        # Disable masking from super class, masking is done directly in call.\n        return None\n\n    def compute_output_shape(self, input_shape, reverse=False):\n        output_shape = list(input_shape)\n        if reverse:\n            output_shape[-1] = self.input_dim\n        else:\n            output_shape += [self.output_dim]\n        return output_shape\n\n    def compute_output_spec(self, inputs, reverse=False):\n        output_shape = list(inputs.shape)\n        if reverse:\n            output_shape[-1] = self.input_dim\n        else:\n            output_shape += [self.output_dim]\n        return KerasTensor(output_shape, dtype=self.compute_dtype)\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"tie_weights\": self.tie_weights,\n                \"reverse_dtype\": self.reverse_dtype,\n                \"logit_soft_cap\": self.logit_soft_cap,\n            }\n        )\n        return 
config\n\n    @property\n    def variable_serialization_spec(self):\n        # Avoid modifying the parent's spec.\n        _spec = copy.deepcopy(super().variable_serialization_spec)\n        if not self.tie_weights:\n            for mode, variable_spec in _spec.items():\n                variable_spec.append(\"reverse_embeddings\")\n                if mode in (\"int4\", \"int8\"):\n                    variable_spec.append(\"reverse_embeddings_scale\")\n                if mode == \"int4\":\n                    # reverse_embeddings_zero only exists for sub-channel\n                    variable_spec.append(\"reverse_embeddings_zero\")\n        return _spec\n\n    def quantized_build(self, embeddings_shape, mode, config=None):\n        if mode == \"int8\":\n            self._int8_build(embeddings_shape, config)\n        elif mode == \"int4\":\n            self._int4_build(embeddings_shape, config)\n        else:\n            raise self._quantization_mode_error(mode)\n        self._is_quantized = True\n\n    def _int8_build(self, embeddings_shape, config=None):\n        if embeddings_shape is None:\n            embeddings_shape = (self.input_dim, self.output_dim)\n        super()._int8_build(embeddings_shape=embeddings_shape)\n\n        self.inputs_quantizer = (\n            QuantizationConfig.activation_quantizer_or_default(\n                config, quantizers.AbsMaxQuantizer(axis=-1)\n            )\n        )\n        if not self.tie_weights:\n            self.reverse_embeddings = self.add_weight(\n                name=\"reverse_embeddings\",\n                shape=(self.output_dim, self.input_dim),\n                initializer=\"zeros\",\n                dtype=\"int8\",\n                trainable=False,\n            )\n            self.reverse_embeddings_scale = self.add_weight(\n                name=\"reverse_embeddings_scale\",\n                shape=(self.input_dim,),\n                initializer=\"ones\",\n                trainable=False,\n            )\n\n    def 
_int4_build(self, embeddings_shape, config=None):\n        if embeddings_shape is None:\n            embeddings_shape = (self.input_dim, self.output_dim)\n        super()._int4_build(embeddings_shape=embeddings_shape, config=config)\n\n        self.inputs_quantizer = (\n            QuantizationConfig.activation_quantizer_or_default(\n                config, quantizers.AbsMaxQuantizer(axis=-1)\n            )\n        )\n        if not self.tie_weights:\n            packed_rows = (self.output_dim + 1) // 2  # ceil for odd dims\n            self.reverse_embeddings = self.add_weight(\n                name=\"reverse_embeddings\",\n                shape=(packed_rows, self.input_dim),\n                initializer=\"zeros\",\n                dtype=\"int8\",\n                trainable=False,\n            )\n\n            # Determine block_size from config or dtype_policy\n            block_size = get_block_size_for_layer(self, config)\n\n            if block_size is None or block_size == -1:\n                # Per-channel: one scale per output unit (input_dim)\n                reverse_scale_shape = (self.input_dim,)\n            else:\n                # Grouped: scale per group along output_dim (axis=0)\n                n_groups = math.ceil(self.output_dim / block_size)\n                reverse_scale_shape = (n_groups, self.input_dim)\n\n            self.reverse_embeddings_scale = self.add_weight(\n                name=\"reverse_embeddings_scale\",\n                shape=reverse_scale_shape,\n                initializer=\"ones\",\n                trainable=False,\n            )\n\n            # Zero point for asymmetric grouped quantization\n            if block_size is not None and block_size != -1:\n                self.reverse_embeddings_zero = self.add_weight(\n                    name=\"reverse_embeddings_zero\",\n                    shape=reverse_scale_shape,\n                    initializer=\"zeros\",\n                    trainable=False,\n                )\n\n    
def _int8_call(self, inputs, reverse=False):\n        if not reverse:\n            return super()._int8_call(inputs)\n        else:\n            if self.tie_weights:\n                kernel = ops.transpose(self._embeddings)\n                scale = ops.transpose(self.embeddings_scale)\n            else:\n                kernel = self.reverse_embeddings\n                scale = self.reverse_embeddings_scale\n            if self.inputs_quantizer:\n                inputs, inputs_scale = self.inputs_quantizer(inputs)\n            else:\n                inputs_scale = ops.ones((1,), dtype=self.compute_dtype)\n            logits = ops.matmul(inputs, kernel)\n            # De-scale outputs\n            logits = ops.cast(logits, self.compute_dtype)\n            logits = ops.divide(logits, ops.multiply(inputs_scale, scale))\n            # Optionally soft-cap logits.\n            if self.logit_soft_cap is not None:\n                soft_cap = self.logit_soft_cap\n                logits = ops.multiply(\n                    ops.tanh(ops.divide(logits, soft_cap)), soft_cap\n                )\n            return logits\n\n    def _int4_call(self, inputs, reverse=False):\n        if not reverse:\n            return super()._int4_call(inputs)\n        else:\n            block_size = getattr(self, \"_int4_block_size\", None)\n\n            if self.tie_weights:\n                embeddings = ops.transpose(self._embeddings)\n                scale = self.embeddings_scale\n                # For tied weights, scale shape is (input_dim,) or\n                # (input_dim, n_groups). 
For per-channel, transpose scale.\n                if block_size is None or block_size == -1:\n                    scale = ops.transpose(scale)\n            else:\n                embeddings = self.reverse_embeddings\n                scale = self.reverse_embeddings_scale\n\n            unpacked_embeddings = quantizers.unpack_int4(\n                embeddings, self.output_dim, axis=0\n            )\n\n            if self.inputs_quantizer:\n                inputs, inputs_scale = self.inputs_quantizer(inputs)\n            else:\n                inputs_scale = ops.ones((1,), dtype=self.compute_dtype)\n\n            if block_size is None or block_size == -1:\n                # Per-channel: do matmul then dequantize\n                logits = ops.matmul(inputs, unpacked_embeddings)\n                logits = ops.cast(logits, self.compute_dtype)\n                logits = ops.divide(logits, ops.multiply(inputs_scale, scale))\n            elif self.tie_weights:\n                # Sub-channel with asymmetric quantization (tied weights)\n                # Must dequantize embeddings before matmul for correctness\n                # unpacked_embeddings shape: (output_dim, input_dim)\n                # scale shape: (input_dim, n_groups)\n                # embeddings_zero shape: (input_dim, n_groups)\n                # g_idx shape: (output_dim,)\n\n                # Transpose scale/zero for dequantization:\n                # [input_dim, n_groups] -> [n_groups, input_dim]\n                scale_t = ops.transpose(scale)\n                zero_t = ops.transpose(self.embeddings_zero)\n\n                float_embeddings = dequantize_with_sz_map(\n                    ops.cast(unpacked_embeddings, self.compute_dtype),\n                    scale_t,\n                    zero_t,\n                    self.g_idx,\n                    group_axis=0,\n                )\n\n                # inputs shape: (batch, output_dim)\n                # float_embeddings shape: (output_dim, input_dim)\n         
       logits = ops.matmul(inputs, float_embeddings)\n                logits = ops.divide(logits, inputs_scale)\n            else:\n                # Untied weights with asymmetric grouped quantization\n                # Must dequantize embeddings before matmul for correctness\n                # unpacked_embeddings shape: (output_dim, input_dim)\n                # scale shape: (n_groups, input_dim)\n                # reverse_embeddings_zero shape: (n_groups, input_dim)\n                # g_idx shape: (output_dim,) - reuse from forward pass\n\n                float_embeddings = dequantize_with_sz_map(\n                    ops.cast(unpacked_embeddings, self.compute_dtype),\n                    scale,\n                    self.reverse_embeddings_zero,\n                    self.g_idx,\n                    group_axis=0,\n                )\n\n                # inputs shape: (batch, output_dim)\n                # float_embeddings shape: (output_dim, input_dim)\n                logits = ops.matmul(inputs, float_embeddings)\n                logits = ops.divide(logits, inputs_scale)\n\n            # Optionally soft-cap logits.\n            if self.logit_soft_cap is not None:\n                soft_cap = self.logit_soft_cap\n                logits = ops.multiply(\n                    ops.tanh(ops.divide(logits, soft_cap)), soft_cap\n                )\n            return logits\n\n    def quantize(self, mode=None, type_check=True, config=None):\n        if type_check and type(self) is not ReversibleEmbedding:\n            raise self._not_implemented_error(self.quantize)\n\n        self.quantization_config = config\n\n        embeddings_shape = (self.input_dim, self.output_dim)\n        if mode == \"int8\":\n            # Quantize `self._embeddings` to int8 and compute corresponding\n            # scale.\n            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(\n                self.quantization_config, quantizers.AbsMaxQuantizer(axis=-1)\n            )\n   
         embeddings_value, embeddings_scale = weight_quantizer(\n                self._embeddings, to_numpy=True\n            )\n            embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)\n            del self._embeddings\n            if not self.tie_weights:\n                reverse_weight_quantizer = (\n                    QuantizationConfig.weight_quantizer_or_default(\n                        self.quantization_config,\n                        quantizers.AbsMaxQuantizer(axis=0),\n                    )\n                )\n                reverse_embeddings_value, reverse_embeddings_scale = (\n                    reverse_weight_quantizer(\n                        self.reverse_embeddings, to_numpy=True\n                    )\n                )\n                reverse_embeddings_scale = ops.squeeze(\n                    reverse_embeddings_scale, axis=0\n                )\n                del self.reverse_embeddings\n            self.quantized_build(\n                embeddings_shape, mode, self.quantization_config\n            )\n            self._embeddings.assign(embeddings_value)\n            self.embeddings_scale.assign(embeddings_scale)\n            if not self.tie_weights:\n                self.reverse_embeddings.assign(reverse_embeddings_value)\n                self.reverse_embeddings_scale.assign(reverse_embeddings_scale)\n        elif mode == \"int4\":\n            from keras.src.quantizers.quantization_config import (\n                Int4QuantizationConfig,\n            )\n\n            block_size = None\n            if isinstance(self.quantization_config, Int4QuantizationConfig):\n                block_size = self.quantization_config.block_size\n\n            use_grouped = block_size is not None and block_size != -1\n\n            # Quantize forward embeddings\n            if not use_grouped:\n                # Per-channel quantization\n                weight_quantizer = (\n                    QuantizationConfig.weight_quantizer_or_default(\n 
                       self.quantization_config,\n                        quantizers.AbsMaxQuantizer(\n                            axis=-1,\n                            value_range=(-8, 7),\n                            output_dtype=\"int8\",\n                        ),\n                    )\n                )\n                embeddings_value, embeddings_scale = weight_quantizer(\n                    self._embeddings, to_numpy=True\n                )\n                embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)\n            else:\n                # Sub-channel quantization with asymmetric zero point\n                embeddings_t = ops.transpose(self._embeddings)\n                embeddings_value_t, scale_t, zero_t = (\n                    quantizers.abs_max_quantize_grouped_with_zero_point(\n                        embeddings_t,\n                        block_size=block_size,\n                        value_range=(-8, 7),\n                        dtype=\"int8\",\n                        to_numpy=True,\n                    )\n                )\n                # Transpose back to (input_dim, output_dim) layout\n                embeddings_value = ops.transpose(embeddings_value_t)\n                embeddings_scale = ops.transpose(scale_t)\n                embeddings_zero = ops.transpose(zero_t)\n\n            packed_embeddings_value, _, _ = quantizers.pack_int4(\n                embeddings_value, axis=-1\n            )\n            del self._embeddings\n\n            # Quantize reverse embeddings if not tied\n            if not self.tie_weights:\n                if not use_grouped:\n                    reverse_weight_quantizer = (\n                        QuantizationConfig.weight_quantizer_or_default(\n                            self.quantization_config,\n                            quantizers.AbsMaxQuantizer(\n                                axis=0,\n                                value_range=(-8, 7),\n                                
output_dtype=\"int8\",\n                            ),\n                        )\n                    )\n                    reverse_embeddings_value, reverse_embeddings_scale = (\n                        reverse_weight_quantizer(\n                            self.reverse_embeddings, to_numpy=True\n                        )\n                    )\n                    reverse_embeddings_scale = ops.squeeze(\n                        reverse_embeddings_scale, axis=0\n                    )\n                else:\n                    reverse_value, reverse_scale, reverse_zero = (\n                        quantizers.abs_max_quantize_grouped_with_zero_point(\n                            self.reverse_embeddings,\n                            block_size=block_size,\n                            value_range=(-8, 7),\n                            dtype=\"int8\",\n                            to_numpy=True,\n                        )\n                    )\n                    reverse_embeddings_value = reverse_value\n                    reverse_embeddings_scale = reverse_scale\n                    reverse_embeddings_zero = reverse_zero\n\n                packed_reverse_embeddings_value, _, _ = quantizers.pack_int4(\n                    reverse_embeddings_value, axis=0\n                )\n                del self.reverse_embeddings\n\n            self.quantized_build(\n                embeddings_shape, mode, self.quantization_config\n            )\n            self._embeddings.assign(packed_embeddings_value)\n            self.embeddings_scale.assign(embeddings_scale)\n            if use_grouped:\n                self.embeddings_zero.assign(embeddings_zero)\n            if not self.tie_weights:\n                self.reverse_embeddings.assign(packed_reverse_embeddings_value)\n                self.reverse_embeddings_scale.assign(reverse_embeddings_scale)\n                if use_grouped:\n                    self.reverse_embeddings_zero.assign(reverse_embeddings_zero)\n        
else:\n            raise self._quantization_mode_error(mode)\n\n        # Set new dtype policy.\n        if self.dtype_policy.quantization_mode is None:\n            policy_name = mode\n            if mode == \"int4\":\n                # Include block_size in policy name for sub-channel quantization\n                block_size = get_block_size_for_layer(self, config)\n                block_size_value = -1 if block_size is None else block_size\n                policy_name = f\"int4/{block_size_value}\"\n            policy = dtype_policies.get(\n                f\"{policy_name}_from_{self.dtype_policy.name}\"\n            )\n            self.dtype_policy = policy\n"
  },
  {
    "path": "keras/src/layers/core/reversible_embedding_test.py",
    "content": "import math\nimport os\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.quantizers.quantization_config import Int4QuantizationConfig\nfrom keras.src.quantizers.quantization_config import Int8QuantizationConfig\nfrom keras.src.quantizers.quantizers import AbsMaxQuantizer\nfrom keras.src.testing import test_case\nfrom keras.src.testing.test_utils import named_product\n\n\nclass ReversibleEmbeddingTest(test_case.TestCase):\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\", {\"axis\": -1}, {\"axis\": -1}),\n        (\n            \"int4\",\n            \"int4\",\n            {\"axis\": -1, \"value_range\": (-8, 7), \"output_dtype\": \"int8\"},\n            {\"axis\": -1},\n        ),\n        (\"int8_weight_only\", \"int8\", {\"axis\": -1}, None),\n    )\n    def test_reversible_embedding_quantize(\n        self, mode, weight_quantizer_args, activation_quantizer_args\n    ):\n        \"\"\"Test ReversibleEmbedding quantization with QuantizationConfig.\"\"\"\n        layer = layers.ReversibleEmbedding(\n            input_dim=10, output_dim=6, tie_weights=True\n        )\n        layer.build((None,))\n\n        weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args)\n        if activation_quantizer_args is not None:\n            activation_quantizer = AbsMaxQuantizer(**activation_quantizer_args)\n        else:\n            activation_quantizer = None\n\n        if mode == \"int8\":\n            config = Int8QuantizationConfig(\n                weight_quantizer=weight_quantizer,\n                activation_quantizer=activation_quantizer,\n            )\n        elif mode == \"int4\":\n            # Custom quantizers require per-channel mode (block_size=None)\n            config = Int4QuantizationConfig(\n           
     weight_quantizer=weight_quantizer,\n                activation_quantizer=activation_quantizer,\n                block_size=None,\n            )\n\n        layer.quantize(mode, config=config)\n\n        if activation_quantizer_args is not None:\n            # Verify inputs_quantizer is set correctly\n            self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer)\n        else:\n            # Verify inputs_quantizer is None\n            self.assertIsNone(layer.inputs_quantizer)\n\n        # Verify reverse call works\n        x = np.random.random((2, 6)).astype(\"float32\")\n        y = layer(x, reverse=True)\n        self.assertEqual(y.shape, (2, 10))\n\n    @parameterized.named_parameters(\n        (\"tie_weights\", True),\n        (\"untie_weights\", False),\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_reversible_embedding_basics(self, tie_weights):\n        self.run_layer_test(\n            layers.ReversibleEmbedding,\n            init_kwargs={\n                \"input_dim\": 100,\n                \"output_dim\": 32,\n                \"tie_weights\": tie_weights,\n                \"embeddings_initializer\": \"HeNormal\",\n                \"logit_soft_cap\": 50,\n            },\n            input_data=np.random.randint(low=0, high=100, size=(4, 10)),\n            expected_output_shape=(4, 10, 32),\n            expected_num_trainable_weights=1 if tie_weights else 2,\n        )\n\n    @parameterized.named_parameters(\n        (\"tie_weights\", True),\n        (\"untie_weights\", False),\n    )\n    def test_saving(self, tie_weights):\n        input_data = np.random.randint(low=0, high=100, size=(4, 10))\n        model = models.Sequential(\n            [\n                layers.ReversibleEmbedding(\n                    input_dim=100,\n                    output_dim=32,\n                    tie_weights=tie_weights,\n                )\n            ]\n        )\n        path = os.path.join(self.get_temp_dir(), \"model.keras\")\n      
  model_output = model(input_data)\n        model.save(path)\n        restored_model = saving.load_model(path)\n        restored_output = restored_model(input_data)\n        self.assertAllClose(model_output, restored_output)\n\n    def test_correctness(self):\n        layer = layers.ReversibleEmbedding(input_dim=3, output_dim=2)\n        layer.build()\n        layer.embeddings.assign(np.array([[0.0, 0.0], [2.0, 2.0], [3.0, 3.0]]))\n        out = layer(np.array(([2, 1, 0])))\n        self.assertAllClose(out, np.array([[3.0, 3.0], [2.0, 2.0], [0.0, 0.0]]))\n\n        layer = layers.ReversibleEmbedding(input_dim=3, output_dim=2)\n        layer.build()\n        layer.embeddings.assign(np.array([[0.0, 0.0], [2.0, 2.0], [3.0, 3.0]]))\n        out = layer(np.array(([[1.0, 1.0]])), reverse=True)\n        self.assertAllClose(out, np.array([[0.0, 4.0, 6.0]]))\n\n        layer = layers.ReversibleEmbedding(\n            input_dim=3, output_dim=2, logit_soft_cap=5\n        )\n        layer.build()\n        layer.embeddings.assign(np.array([[0.0, 0.0], [2.0, 2.0], [3.0, 3.0]]))\n        out = layer(np.array(([[1.0, 1.0]])), reverse=True)\n        self.assertAllClose(out, np.array([[0.0, 3.320184, 4.168273]]))\n\n    def test_reverse_dtype(self):\n        embedding = layers.ReversibleEmbedding(100, 16, reverse_dtype=\"float32\")\n        input_data = ops.ones(shape=(4, 10, 16))\n        output_data = embedding(input_data, reverse=True)\n        self.assertEqual(output_data.shape, (4, 10, 100))\n        self.assertDType(output_data, \"float32\")\n\n        if backend.backend() == \"torch\":\n            import torch\n\n            if not torch.cuda.is_available():\n                self.skipTest(\"Torch CPU does not support float16\")\n\n        embedding = layers.ReversibleEmbedding(100, 16, reverse_dtype=\"float16\")\n        input_data = ops.ones(shape=(4, 10, 16))\n        output_data = embedding(input_data, reverse=True)\n        self.assertEqual(output_data.shape, (4, 10, 
100))\n        self.assertDType(output_data, \"float16\")\n\n    @parameterized.named_parameters(\n        named_product(mode=(\"int4\", \"int8\"), tie_weights=(False, True))\n    )\n    def test_quantize_int(self, mode, tie_weights):\n        layer = layers.ReversibleEmbedding(10, 16, tie_weights=tie_weights)\n        layer.build()\n        x = np.random.randint(0, 9, size=(64, 3))\n        x_reverse = np.random.uniform(size=(64, 16)).astype(\"float32\")\n        y_float = layer(x)\n        y_reverse_float = layer(x_reverse, reverse=True)\n        layer.quantize(mode)\n\n        # Verify the dtype of the weights.\n        if not tie_weights:\n            # The reverse_embeddings's dtype is int8, despite the int4\n            # quantization, because we pack the int4 values into int8.\n            self.assertDType(layer.reverse_embeddings, \"int8\")\n            self.assertDType(\n                layer.reverse_embeddings_scale, layer.variable_dtype\n            )\n\n        # Verify the correctness of the outputs.\n        y_quantized = layer(x)\n        y_reverse_quantized = layer(x_reverse, reverse=True)\n        mse = ops.mean(ops.square(y_float - y_quantized))\n        mse_reverse = ops.mean(\n            ops.square(y_reverse_float - y_reverse_quantized)\n        )\n        self.assertLess(mse, 1e-3)  # A weak correctness test\n        self.assertLess(mse_reverse, 1e-3)  # A weak correctness test\n\n        # Check model save / load round-trip.\n        model = models.Sequential([layer])\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"quantized_model.keras\"\n        )\n        model.save(temp_filepath)\n        new_model = saving.load_model(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n        # Check weights-only save / load round-trip.\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"quantized_model.weights.h5\"\n        )\n        model.save_weights(temp_filepath)\n 
       new_model = models.Sequential(\n            [layers.ReversibleEmbedding(10, 16, tie_weights=tie_weights)]\n        )\n        new_model.build((None, 3))\n        new_model.quantize(mode)\n        new_model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(x), new_model.predict(x))\n\n    @parameterized.named_parameters(\n        (\"int8_tie_weights\", \"int8_from_mixed_bfloat16\", True, 0, 2),\n        (\"int8_untie_weights\", \"int8_from_mixed_bfloat16\", False, 0, 4),\n        (\"int4_tie_weights\", \"int4_from_mixed_bfloat16\", True, 0, 2),\n        (\"int4_untie_weights\", \"int4_from_mixed_bfloat16\", False, 0, 4),\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_quantize_dtype_argument(\n        self,\n        dtype,\n        tie_weights,\n        num_trainable_weights,\n        num_non_trainable_weights,\n    ):\n        self.run_layer_test(\n            layers.ReversibleEmbedding,\n            init_kwargs={\n                \"input_dim\": 100,\n                \"output_dim\": 32,\n                \"tie_weights\": tie_weights,\n                \"embeddings_initializer\": \"HeNormal\",\n                \"dtype\": dtype,\n            },\n            input_data=np.random.randint(low=0, high=100, size=(4, 10)),\n            expected_output_shape=(4, 10, 32),\n            expected_num_trainable_weights=num_trainable_weights,\n            expected_num_non_trainable_weights=num_non_trainable_weights,\n            expected_num_non_trainable_variables=num_non_trainable_weights,\n        )\n\n    def test_reversible_embedding_int8_custom_quantizer(self):\n        \"\"\"\n        Test custom quantizer serialization for reversible embedding layer with\n        int8 quantization.\n        \"\"\"\n        # Setup\n        weight_range = (-20, 20)\n        config = Int8QuantizationConfig(\n            weight_quantizer=AbsMaxQuantizer(axis=-1, value_range=weight_range),\n        )\n\n        # Build & Quantize\n        layer = 
layers.ReversibleEmbedding(input_dim=100, output_dim=16)\n        layer.build(None)\n        layer.quantize(\"int8\", config=config)\n\n        # Serialize & Deserialize\n        serialized = layer.get_config()\n        new_layer = layers.ReversibleEmbedding.from_config(serialized)\n\n        # Verify\n        self.assertIsInstance(\n            new_layer.quantization_config, Int8QuantizationConfig\n        )\n        quantizer = new_layer.quantization_config.weight_quantizer\n        self.assertIsInstance(quantizer, AbsMaxQuantizer)\n        self.assertAllEqual(quantizer.value_range, weight_range)\n\n    def test_masking(self):\n        layer = layers.ReversibleEmbedding(3, 2, mask_zero=True)\n        layer.build()\n\n        out = layer(np.array(([2, 1, 0])))\n        mask = backend.get_keras_mask(out)\n        self.assertAllClose(mask, np.array([True, True, False]))\n\n        out = layer(np.array(([[1.0, 2.0], [0.0, 0.0]])), reverse=True)\n        mask = backend.get_keras_mask(out)\n        self.assertIsNone(mask)\n\n    @parameterized.named_parameters(\n        named_product(\n            block_size=(64, 128, None, -1),\n            tie_weights=(True, False),\n        )\n    )\n    def test_int4_quantization_block_size(self, block_size, tie_weights):\n        \"\"\"Test int4 quantization with different block_size configurations.\"\"\"\n\n        input_dim, output_dim = 100, 256\n        layer = layers.ReversibleEmbedding(\n            input_dim=input_dim, output_dim=output_dim, tie_weights=tie_weights\n        )\n        layer.build()\n\n        x = np.random.randint(0, input_dim, size=(4, 8))\n        x_reverse = np.random.random((4, output_dim)).astype(\"float32\")\n        y_float = layer(x)\n        y_reverse_float = layer(x_reverse, reverse=True)\n\n        # Create config with specified block_size\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Verify block_size is stored\n     
   self.assertEqual(layer._int4_block_size, block_size)\n\n        # Verify embeddings_scale shape\n        if block_size is None or block_size == -1:\n            expected_scale_shape = (input_dim,)\n        else:\n            n_groups = math.ceil(output_dim / block_size)\n            expected_scale_shape = (input_dim, n_groups)\n\n        self.assertEqual(layer.embeddings_scale.shape, expected_scale_shape)\n\n        # Verify reverse_embeddings_scale shape if not tied\n        if not tie_weights:\n            if block_size is None or block_size == -1:\n                expected_reverse_scale_shape = (input_dim,)\n            else:\n                n_groups = math.ceil(output_dim / block_size)\n                expected_reverse_scale_shape = (n_groups, input_dim)\n\n            self.assertEqual(\n                layer.reverse_embeddings_scale.shape,\n                expected_reverse_scale_shape,\n            )\n\n        # Verify outputs are reasonable\n        y_quantized = layer(x)\n        y_reverse_quantized = layer(x_reverse, reverse=True)\n        mse = ops.mean(ops.square(y_float - y_quantized))\n        mse_reverse = ops.mean(\n            ops.square(y_reverse_float - y_reverse_quantized)\n        )\n        self.assertLess(mse, 1e-3)\n        self.assertLess(mse_reverse, 1e-2)\n\n    @parameterized.named_parameters(\n        named_product(\n            block_size=(64, 128, None),\n            tie_weights=(True, False),\n        )\n    )\n    def test_int4_block_size_serialization(self, block_size, tie_weights):\n        \"\"\"Test that block_size is preserved through serialization.\"\"\"\n        input_dim, output_dim = 50, 128\n        layer = layers.ReversibleEmbedding(\n            input_dim=input_dim, output_dim=output_dim, tie_weights=tie_weights\n        )\n        layer.build()\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Get output before serialization\n        x = 
np.random.randint(0, input_dim, size=(2, 8))\n        y_before = layer(x)\n\n        # Save and load model to test full serialization roundtrip\n        model = models.Sequential([layer])\n        temp_filepath = os.path.join(\n            self.get_temp_dir(),\n            f\"int4_block_size_rev_emb_model_{tie_weights}.keras\",\n        )\n        model.save(temp_filepath)\n        loaded_model = saving.load_model(temp_filepath)\n\n        # Verify block_size is preserved\n        loaded_layer = loaded_model.layers[0]\n        self.assertIsInstance(\n            loaded_layer.quantization_config, Int4QuantizationConfig\n        )\n        self.assertEqual(\n            loaded_layer.quantization_config.block_size, block_size\n        )\n\n        # Verify reverse_embeddings_zero is preserved for untied grouped\n        if not tie_weights and block_size is not None:\n            self.assertTrue(hasattr(loaded_layer, \"reverse_embeddings_zero\"))\n            self.assertAllClose(\n                loaded_layer.reverse_embeddings_zero,\n                layer.reverse_embeddings_zero,\n            )\n\n        # Verify outputs match after deserialization\n        y_after = loaded_model(x)\n        self.assertAllClose(y_before, y_after)\n\n    @parameterized.named_parameters(\n        (\"tie_grouped\", True, 64),\n        (\"tie_perchannel\", True, None),\n        (\"untie_grouped\", False, 64),\n        (\"untie_perchannel\", False, None),\n    )\n    def test_int4_grouped_vs_perchannel_scale_shapes(\n        self, tie_weights, block_size\n    ):\n        \"\"\"Test that grouped and per-channel have different scale shapes.\"\"\"\n\n        input_dim, output_dim = 100, 256\n\n        layer = layers.ReversibleEmbedding(\n            input_dim=input_dim, output_dim=output_dim, tie_weights=tie_weights\n        )\n        layer.build()\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        if block_size is 
None or block_size == -1:\n            # Per-channel\n            expected_scale_shape = (input_dim,)\n            expected_reverse_scale_shape = (input_dim,)\n        else:\n            # Grouped\n            n_groups = math.ceil(output_dim / block_size)\n            expected_scale_shape = (input_dim, n_groups)\n            expected_reverse_scale_shape = (n_groups, input_dim)\n\n        self.assertEqual(layer.embeddings_scale.shape, expected_scale_shape)\n\n        if not tie_weights:\n            self.assertEqual(\n                layer.reverse_embeddings_scale.shape,\n                expected_reverse_scale_shape,\n            )\n            # Check reverse_embeddings_zero shape for grouped quantization\n            if block_size is not None and block_size != -1:\n                self.assertTrue(hasattr(layer, \"reverse_embeddings_zero\"))\n                self.assertEqual(\n                    layer.reverse_embeddings_zero.shape,\n                    expected_reverse_scale_shape,\n                )\n            else:\n                self.assertFalse(hasattr(layer, \"reverse_embeddings_zero\"))\n\n    @parameterized.named_parameters(\n        (\"grouped_block_4\", 4),\n        (\"grouped_block_8\", 8),\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_subchannel_g_idx_created(self, block_size):\n        \"\"\"Test that g_idx is created for sub-channel int4 quantization.\"\"\"\n        input_dim, output_dim = 10, 16\n        layer = layers.ReversibleEmbedding(\n            input_dim=input_dim, output_dim=output_dim\n        )\n        layer.build()\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        # Verify g_idx is created\n        self.assertTrue(hasattr(layer, \"g_idx\"))\n\n        # Verify g_idx shape (output_dim for embedding)\n        self.assertEqual(layer.g_idx.shape, (output_dim,))\n\n        # 
Verify g_idx values (should map each column to its group)\n        expected_g_idx = np.arange(output_dim) // block_size\n        self.assertAllClose(layer.g_idx, expected_g_idx)\n\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_perchannel_no_g_idx(self):\n        \"\"\"Test that per-channel int4 does NOT create g_idx.\"\"\"\n        layer = layers.ReversibleEmbedding(input_dim=10, output_dim=16)\n        layer.build()\n\n        config = Int4QuantizationConfig(block_size=None)  # Per-channel\n        layer.quantize(\"int4\", config=config)\n\n        # Verify g_idx is NOT created for per-channel\n        self.assertFalse(hasattr(layer, \"g_idx\"))\n\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_int4_subchannel_g_idx_serialization(self):\n        \"\"\"Test that g_idx is properly serialized and deserialized.\"\"\"\n        input_dim, output_dim = 10, 16\n        block_size = 8\n\n        layer = layers.ReversibleEmbedding(\n            input_dim=input_dim, output_dim=output_dim\n        )\n        layer.build()\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        layer.quantize(\"int4\", config=config)\n\n        x = np.array([[1, 2, 3], [4, 5, 6]], dtype=\"int32\")\n        y_before = layer(x)\n        g_idx_before = ops.convert_to_numpy(layer.g_idx)\n\n        # Save and load\n        model = models.Sequential([layer])\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"rev_embedding_int4_g_idx_model.keras\"\n        )\n        model.save(temp_filepath)\n        loaded_model = saving.load_model(temp_filepath)\n\n        # Verify g_idx is preserved\n        loaded_layer = loaded_model.layers[0]\n        self.assertTrue(hasattr(loaded_layer, \"g_idx\"))\n        self.assertAllClose(loaded_layer.g_idx, g_idx_before)\n\n        # Verify outputs match\n        
y_after = loaded_model(x)\n        self.assertAllClose(y_before, y_after)\n"
  },
  {
    "path": "keras/src/layers/core/wrapper.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.layers.Wrapper\")\nclass Wrapper(Layer):\n    \"\"\"Abstract wrapper base class.\n\n    Wrappers take another layer and augment it in various ways.\n    Do not use this class as a layer, it is only an abstract base class.\n    Two usable wrappers are the `TimeDistributed` and `Bidirectional` layers.\n\n    Args:\n        layer: The layer to be wrapped.\n    \"\"\"\n\n    def __init__(self, layer, **kwargs):\n        if not isinstance(layer, Layer):\n            raise ValueError(\n                f\"Layer {layer} supplied to Wrapper isn't \"\n                \"a supported layer type. Please \"\n                \"ensure wrapped layer is a valid Keras layer.\"\n            )\n        super().__init__(**kwargs)\n        self.layer = layer\n\n    def build(self, input_shape=None):\n        if not self.layer.built:\n            self.layer.build(input_shape)\n            self.layer.built = True\n\n    def get_config(self):\n        config = {\"layer\": serialization_lib.serialize_keras_object(self.layer)}\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config, custom_objects=None):\n        layer = serialization_lib.deserialize_keras_object(\n            config.pop(\"layer\"),\n            custom_objects=custom_objects,\n        )\n        return cls(layer, **config)\n"
  },
  {
    "path": "keras/src/layers/core/wrapper_test.py",
    "content": "import pytest\n\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass ExampleWrapper(layers.Wrapper):\n    \"\"\"Simple Wrapper subclass.\"\"\"\n\n    def call(self, inputs, **kwargs):\n        return ops.cast(self.layer(inputs, **kwargs), self.compute_dtype)\n\n\nclass WrapperTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_wrapper_basics(self):\n        self.run_layer_test(\n            ExampleWrapper,\n            init_kwargs={\n                \"layer\": layers.Dense(2),\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 2),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n        self.run_layer_test(\n            ExampleWrapper,\n            init_kwargs={\n                \"layer\": layers.Dense(2, activity_regularizer=\"l2\"),\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 2),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=1,\n            supports_masking=False,\n        )\n        self.run_layer_test(\n            ExampleWrapper,\n            init_kwargs={\n                \"layer\": layers.Dense(2),\n                \"activity_regularizer\": \"l2\",\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 2),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=1,\n            supports_masking=False,\n        )\n        self.run_layer_test(\n            ExampleWrapper,\n            init_kwargs={\n                \"layer\": 
layers.BatchNormalization(),\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=2,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    def test_wrapper_invalid_layer(self):\n        invalid_layer = \"This is not a valid Keras layer.\"\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Layer .* supplied to Wrapper isn't a supported layer type. \"\n            \"Please ensure wrapped layer is a valid Keras layer.\",\n        ):\n            layers.Wrapper(invalid_layer)\n"
  },
  {
    "path": "keras/src/layers/input_spec.py",
    "content": "from keras.src import backend\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\n\n\n@keras_export([\"keras.InputSpec\", \"keras.layers.InputSpec\"])\nclass InputSpec:\n    \"\"\"Specifies the rank, dtype and shape of every input to a layer.\n\n    Layers can expose (if appropriate) an `input_spec` attribute:\n    an instance of `InputSpec`, or a nested structure of `InputSpec` instances\n    (one per input tensor). These objects enable the layer to run input\n    compatibility checks for input structure, input rank, input shape, and\n    input dtype for the first argument of `Layer.__call__`.\n\n    A `None` entry in a shape is compatible with any dimension.\n\n    Args:\n        dtype: Expected dtype of the input.\n        shape: Shape tuple, expected shape of the input\n            (may include `None` for dynamic axes).\n            Includes the batch size.\n        ndim: Integer, expected rank of the input.\n        max_ndim: Integer, maximum rank of the input.\n        min_ndim: Integer, minimum rank of the input.\n        axes: Dictionary mapping integer axes to\n            a specific dimension value.\n        allow_last_axis_squeeze: If `True`, allow inputs of rank N+1 as long\n            as the last axis of the input is 1, as well as inputs of rank N-1\n            as long as the last axis of the spec is 1.\n        name: Expected key corresponding to this input when passing data as\n            a dictionary.\n        optional: Boolean, whether the input is optional or not.\n            An optional input can accept `None` values.\n\n    Example:\n\n    ```python\n    class MyLayer(Layer):\n        def __init__(self):\n            super().__init__()\n            # The layer will accept inputs with\n            # shape (*, 28, 28) & (*, 28, 28, 1)\n            # and raise an appropriate error message otherwise.\n            self.input_spec = InputSpec(\n                shape=(None, 28, 28, 1),\n                
allow_last_axis_squeeze=True)\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        dtype=None,\n        shape=None,\n        ndim=None,\n        max_ndim=None,\n        min_ndim=None,\n        axes=None,\n        allow_last_axis_squeeze=False,\n        name=None,\n        optional=False,\n    ):\n        self.dtype = (\n            backend.standardize_dtype(dtype) if dtype is not None else None\n        )\n        if shape is not None:\n            self.shape = backend.standardize_shape(shape)\n            self.ndim = len(shape)\n        else:\n            self.ndim = ndim\n            self.shape = None\n        self.max_ndim = max_ndim\n        self.min_ndim = min_ndim\n        self.name = name\n        self.optional = optional\n        self.allow_last_axis_squeeze = allow_last_axis_squeeze\n        try:\n            axes = axes or {}\n            self.axes = {int(k): axes[k] for k in axes}\n        except (ValueError, TypeError):\n            raise TypeError(\n                \"Argument `axes` must be a dict with integer keys. 
\"\n                f\"Received: axes={axes}\"\n            )\n\n        if self.axes and (self.ndim is not None or self.max_ndim is not None):\n            max_dim = (self.ndim if self.ndim else self.max_ndim) - 1\n            max_axis = max(self.axes)\n            if max_axis > max_dim:\n                raise ValueError(\n                    \"Axis {} is greater than the maximum \"\n                    \"allowed value: {}\".format(max_axis, max_dim)\n                )\n\n    def __repr__(self):\n        spec = [\n            (f\"dtype={str(self.dtype)}\") if self.dtype else \"\",\n            (f\"shape={str(self.shape)}\") if self.shape else \"\",\n            (f\"ndim={str(self.ndim)}\") if self.ndim else \"\",\n            (f\"max_ndim={str(self.max_ndim)}\") if self.max_ndim else \"\",\n            (f\"min_ndim={str(self.min_ndim)}\") if self.min_ndim else \"\",\n            (f\"axes={str(self.axes)}\") if self.axes else \"\",\n        ]\n        return f\"InputSpec({', '.join(x for x in spec if x)})\"\n\n    def get_config(self):\n        return {\n            \"dtype\": self.dtype,\n            \"shape\": self.shape,\n            \"ndim\": self.ndim,\n            \"max_ndim\": self.max_ndim,\n            \"min_ndim\": self.min_ndim,\n            \"axes\": self.axes,\n            \"optional\": self.optional,\n        }\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n\n\ndef assert_input_compatibility(input_spec, inputs, layer_name):\n    \"\"\"Checks compatibility between the layer and provided inputs.\n\n    This checks that the tensor(s) `inputs` verify the input assumptions\n    of a layer (if any). 
If not, a clear and actionable exception gets raised.\n\n    Args:\n        input_spec: An InputSpec instance, list of InputSpec instances, a nested\n            structure of InputSpec instances, or None.\n        inputs: Input tensor, list of input tensors, or a nested structure of\n            input tensors.\n        layer_name: String, name of the layer (for error message formatting).\n\n    Raises:\n        ValueError: in case of mismatch between\n            the provided inputs and the expectations of the layer.\n    \"\"\"\n    if not input_spec:\n        return\n\n    input_spec = tree.flatten(input_spec)\n    if isinstance(inputs, dict):\n        # Flatten `inputs` by reference order if input spec names are provided\n        names = [spec.name for spec in input_spec]\n        if all(names):\n            list_inputs = []\n            for name in names:\n                if name not in inputs:\n                    raise ValueError(\n                        f'Missing data for input \"{name}\". '\n                        \"You passed a data dictionary with keys \"\n                        f\"{list(inputs.keys())}. \"\n                        f\"Expected the following keys: {names}\"\n                    )\n                list_inputs.append(inputs[name])\n            inputs = list_inputs\n\n    inputs = tree.flatten(inputs)\n    if len(inputs) != len(input_spec):\n        raise ValueError(\n            f'Layer \"{layer_name}\" expects {len(input_spec)} input(s),'\n            f\" but it received {len(inputs)} input tensors. \"\n            f\"Inputs received: {inputs}\"\n        )\n    for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):\n        if spec is None:\n            continue\n        if x is None and spec.optional:\n            continue\n\n        # Having a shape/dtype is the only commonality of the various\n        # tensor-like objects that may be passed. 
The most common kind of\n        # invalid type we are guarding for is a Layer instance (Functional API),\n        # which does not have a `shape` attribute.\n        if not hasattr(x, \"shape\"):\n            raise ValueError(\n                f\"Inputs to a layer should be tensors. Got '{x}' \"\n                f\"(of type {type(x)}) as input for layer '{layer_name}'.\"\n            )\n\n        shape = backend.standardize_shape(x.shape)\n        ndim = len(shape)\n        # Check ndim.\n        if spec.ndim is not None and not spec.allow_last_axis_squeeze:\n            if ndim != spec.ndim:\n                raise ValueError(\n                    f\"Input {input_index} with name '{spec.name}' of layer \"\n                    f\"'{layer_name}' is incompatible with the layer: \"\n                    f\"expected ndim={spec.ndim}, found ndim={ndim}. \"\n                    f\"Full shape received: {shape}\"\n                )\n        if spec.max_ndim is not None:\n            if ndim is not None and ndim > spec.max_ndim:\n                raise ValueError(\n                    f\"Input {input_index} with name '{spec.name}' of layer \"\n                    f\"'{layer_name}' is incompatible with the layer: \"\n                    f\"expected max_ndim={spec.max_ndim}, \"\n                    f\"found ndim={ndim}\"\n                )\n        if spec.min_ndim is not None:\n            if ndim is not None and ndim < spec.min_ndim:\n                raise ValueError(\n                    f\"Input {input_index} with name '{spec.name}' of layer \"\n                    f\"'{layer_name}' is incompatible with the layer: \"\n                    f\"expected min_ndim={spec.min_ndim}, \"\n                    f\"found ndim={ndim}. 
\"\n                    f\"Full shape received: {shape}\"\n                )\n        # Check dtype.\n        if spec.dtype is not None:\n            dtype = backend.standardize_dtype(x.dtype)\n            if dtype != spec.dtype:\n                raise ValueError(\n                    f\"Input {input_index} with name '{spec.name}' of layer \"\n                    f\"'{layer_name}' is incompatible with the layer: \"\n                    f\"expected dtype={spec.dtype}, \"\n                    f\"found dtype={dtype}\"\n                )\n\n        # Check specific shape axes.\n        if spec.axes:\n            for axis, value in spec.axes.items():\n                if value is not None and shape[axis] not in {\n                    value,\n                    None,\n                }:\n                    raise ValueError(\n                        f\"Input {input_index} with name '{spec.name}' of layer \"\n                        f\"'{layer_name}' is incompatible with the layer: \"\n                        f\"expected axis {axis} of input shape to have value \"\n                        f\"{value}, but received input with shape {shape}\"\n                    )\n        # Check shape.\n        if spec.shape is not None:\n            spec_shape = spec.shape\n            if spec.allow_last_axis_squeeze:\n                if shape and shape[-1] == 1:\n                    shape = shape[:-1]\n                if spec_shape and spec_shape[-1] == 1:\n                    spec_shape = spec_shape[:-1]\n            for spec_dim, dim in zip(spec_shape, shape):\n                if spec_dim is not None and dim is not None:\n                    if spec_dim != dim:\n                        raise ValueError(\n                            f\"Input {input_index} with name '{spec.name}' of \"\n                            f\"layer '{layer_name}' is incompatible with the \"\n                            f\"layer: expected shape={spec.shape}, found \"\n                            
f\"shape={shape}\"\n                        )\n"
  },
  {
    "path": "keras/src/layers/layer.py",
    "content": "\"\"\"Layer is an Operation with state.\n\nTakes care of:\n\n- Weights / variables (and tracking thereof)\n- deferred build\n- trainable argument value inference\n- masking\n- autocasting\n\nAnd some more magic:\n\n- add_loss\n- metric tracking\n- RNG seed tracking\n- activity regularization\n\"\"\"\n\nimport collections\nimport functools\nimport inspect\nimport math\nimport warnings\nfrom functools import wraps\n\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import dtype_policies\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src import tree\nfrom keras.src import utils\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend.common import global_state\nfrom keras.src.backend.common import remat\nfrom keras.src.backend.common.keras_tensor import any_symbolic_tensors\nfrom keras.src.backend.common.name_scope import current_path\nfrom keras.src.backend.common.remat import get_current_remat_mode\nfrom keras.src.backend.common.symbolic_scope import in_symbolic_scope\nfrom keras.src.backend.config import is_nnx_enabled\nfrom keras.src.distribution import distribution_lib\nfrom keras.src.dtype_policies import DTypePolicyMap\nfrom keras.src.layers import input_spec\nfrom keras.src.metrics.metric import Metric\nfrom keras.src.ops.node import Node\nfrom keras.src.ops.operation import Operation\nfrom keras.src.quantizers.quantization_config import validate_and_resolve_config\nfrom keras.src.utils import python_utils\nfrom keras.src.utils import summary_utils\nfrom keras.src.utils import traceback_utils\nfrom keras.src.utils import tracking\n\nif backend.backend() == \"tensorflow\":\n    from keras.src.backend.tensorflow.layer import TFLayer as BackendLayer\nelif backend.backend() == \"jax\":\n    from keras.src.backend.jax.layer import JaxLayer as BackendLayer\nelif backend.backend() == \"torch\":\n    
from keras.src.backend.torch.layer import TorchLayer as BackendLayer\nelif backend.backend() == \"numpy\":\n    from keras.src.backend.numpy.layer import NumpyLayer as BackendLayer\nelif backend.backend() == \"openvino\":\n    from keras.src.backend.openvino.layer import OpenvinoLayer as BackendLayer\nelse:\n    raise RuntimeError(\n        f\"Backend '{backend.backend()}' must implement a layer mixin class.\"\n    )\n\n\n@keras_export([\"keras.Layer\", \"keras.layers.Layer\"])\nclass Layer(BackendLayer, Operation):\n    \"\"\"This is the class from which all layers inherit.\n\n    A layer is a callable object that takes as input one or more tensors and\n    that outputs one or more tensors. It involves *computation*, defined\n    in the `call()` method, and a *state* (weight variables). State can be\n    created:\n\n    * in `__init__()`, for instance via `self.add_weight()`;\n    * in the optional `build()` method, which is invoked by the first\n      `__call__()` to the layer, and supplies the shape(s) of the input(s),\n      which may not have been known at initialization time.\n\n    Layers are recursively composable: If you assign a Layer instance as an\n    attribute of another Layer, the outer layer will start tracking the weights\n    created by the inner layer. Nested layers should be instantiated in the\n    `__init__()` method or `build()` method.\n\n    Users will just instantiate a layer and then treat it as a callable.\n\n    Args:\n        trainable: Boolean, whether the layer's variables should be trainable.\n        name: String name of the layer.\n        dtype: The dtype of the layer's computations and weights. Can also be a\n            `keras.DTypePolicy`, which allows the computation and weight dtype\n            to differ. Defaults to `None`. 
`None` means to use\n            `keras.config.dtype_policy()`, which is a `float32` policy unless\n            set to a different value (via `keras.config.set_dtype_policy()`).\n\n    Attributes:\n        name: The name of the layer (string).\n        dtype: Dtype of the layer's weights. Alias of `layer.variable_dtype`.\n        variable_dtype: Dtype of the layer's weights.\n        compute_dtype: The dtype of the layer's computations.\n            Layers automatically cast inputs to this dtype, which causes\n            the computations and output to also be in this dtype.\n            When mixed precision is used with a\n            `keras.DTypePolicy`, this will be different\n            than `variable_dtype`.\n        trainable_weights: List of variables to be included in backprop.\n        non_trainable_weights: List of variables that should not be\n            included in backprop.\n        weights: The concatenation of the lists trainable_weights and\n            non_trainable_weights (in this order).\n        trainable: Whether the layer should be trained (boolean), i.e.\n            whether its potentially-trainable weights should be returned\n            as part of `layer.trainable_weights`.\n        input_spec: Optional (list of) `InputSpec` object(s) specifying the\n            constraints on inputs that can be accepted by the layer.\n\n    We recommend that descendants of `Layer` implement the following methods:\n\n    * `__init__()`: Defines custom layer attributes, and creates layer weights\n        that do not depend on input shapes, using `add_weight()`,\n        or other state.\n    * `build(self, input_shape)`: This method can be used to create weights that\n        depend on the shape(s) of the input(s), using `add_weight()`, or other\n        state. 
`__call__()` will automatically build the layer\n        (if it has not been built yet) by calling `build()`.\n    * `call(self, *args, **kwargs)`: Called in `__call__` after making\n        sure `build()` has been called. `call()` performs the logic of applying\n        the layer to the input arguments.\n        Two reserved keyword arguments you can optionally use in `call()` are:\n            1. `training` (boolean, whether the call is in inference mode or\n                training mode).\n            2. `mask` (boolean tensor encoding masked timesteps in the input,\n                used e.g. in RNN layers).\n        A typical signature for this method is `call(self, inputs)`, and users\n        can optionally add `training` and `mask` if the layer needs them.\n    * `get_config(self)`: Returns a dictionary containing the configuration\n        used to initialize this layer. If the keys differ from the arguments\n        in `__init__()`, then override `from_config(self)` as well.\n        This method is used when saving\n        the layer or a model that contains this layer.\n\n    Examples:\n\n    Here's a basic example: a layer with two variables, `w` and `b`,\n    that returns `y = w . 
x + b`.\n    It shows how to implement `build()` and `call()`.\n    Variables set as attributes of a layer are tracked as weights\n    of the layers (in `layer.weights`).\n\n    ```python\n    class SimpleDense(Layer):\n        def __init__(self, units=32):\n            super().__init__()\n            self.units = units\n\n        # Create the state of the layer (weights)\n        def build(self, input_shape):\n            self.kernel = self.add_weight(\n                shape=(input_shape[-1], self.units),\n                initializer=\"glorot_uniform\",\n                trainable=True,\n                name=\"kernel\",\n            )\n            self.bias = self.add_weight(\n                shape=(self.units,),\n                initializer=\"zeros\",\n                trainable=True,\n                name=\"bias\",\n            )\n\n        # Defines the computation\n        def call(self, inputs):\n            return ops.matmul(inputs, self.kernel) + self.bias\n\n    # Instantiates the layer.\n    linear_layer = SimpleDense(4)\n\n    # This will also call `build(input_shape)` and create the weights.\n    y = linear_layer(ops.ones((2, 2)))\n    assert len(linear_layer.weights) == 2\n\n    # These weights are trainable, so they're listed in `trainable_weights`:\n    assert len(linear_layer.trainable_weights) == 2\n    ```\n\n    Besides trainable weights, updated via backpropagation during training,\n    layers can also have non-trainable weights. These weights are meant to\n    be updated manually during `call()`. 
Here's an example layer that computes\n    the running sum of its inputs:\n\n    ```python\n    class ComputeSum(Layer):\n\n      def __init__(self, input_dim):\n          super(ComputeSum, self).__init__()\n          # Create a non-trainable weight.\n          self.total = self.add_weight(\n            shape=(),\n            initializer=\"zeros\",\n            trainable=False,\n            name=\"total\",\n          )\n\n      def call(self, inputs):\n          self.total.assign(self.total + ops.sum(inputs))\n          return self.total\n\n    my_sum = ComputeSum(2)\n    x = ops.ones((2, 2))\n    y = my_sum(x)\n\n    assert my_sum.weights == [my_sum.total]\n    assert my_sum.non_trainable_weights == [my_sum.total]\n    assert my_sum.trainable_weights == []\n    ```\n    \"\"\"\n\n    def __new__(cls, *args, **kwargs):\n        obj = super().__new__(cls, *args, **kwargs)\n        # Wrap the user-provided `build` method in the `build_wrapper`\n        # to add name scope support and serialization support.\n        original_build_method = obj.build\n\n        @wraps(original_build_method)\n        def build_wrapper(*args, **kwargs):\n            with obj._open_name_scope():\n                obj._path = current_path()\n                original_build_method(*args, **kwargs)\n            # Record build config.\n            signature = inspect.signature(original_build_method)\n            obj._build_shapes_dict = signature.bind(*args, **kwargs).arguments\n            # Set built, post build actions, and lock state.\n            obj.built = True\n            obj._post_build()\n            obj._lock_state()\n\n        obj.build = build_wrapper\n\n        # Wrap the user-provided `quantize` method in the `quantize_wrapper`\n        # to add tracker support.\n        original_quantize_method = obj.quantize\n\n        @wraps(original_quantize_method)\n        def quantize_wrapper(mode=None, config=None, **kwargs):\n            config = validate_and_resolve_config(mode, 
config)\n            mode = config.mode\n            obj._check_quantize_args(mode, obj.compute_dtype)\n            obj._tracker.unlock()\n            try:\n                original_quantize_method(mode=mode, config=config, **kwargs)\n            except Exception:\n                raise\n            finally:\n                obj._tracker.lock()\n\n        obj.quantize = quantize_wrapper\n\n        return obj\n\n    def __init__(\n        self,\n        *,\n        activity_regularizer=None,\n        trainable=True,\n        dtype=None,\n        autocast=True,\n        name=None,\n        **kwargs,\n    ):\n        BackendLayer.__init__(self)\n        self._lock = False\n        Operation.__init__(self, name=name)\n        self._dtype_policy = dtype_policies.get(dtype)\n        self.activity_regularizer = regularizers.get(activity_regularizer)\n        input_dim_arg = kwargs.pop(\"input_dim\", None)\n        if input_dim_arg is not None:\n            input_shape_arg = (input_dim_arg,)\n        else:\n            input_shape_arg = kwargs.pop(\"input_shape\", None)\n        if input_shape_arg is not None:\n            warnings.warn(\n                \"Do not pass an `input_shape`/`input_dim` argument to \"\n                \"a layer. 
When using Sequential models, \"\n                \"prefer using an `Input(shape)` object as the \"\n                \"first layer in the model instead.\",\n                stacklevel=2,\n            )\n            self._input_shape_arg = input_shape_arg\n        if kwargs:\n            raise ValueError(\n                \"Unrecognized keyword arguments \"\n                f\"passed to {self.__class__.__name__}: {kwargs}\"\n            )\n\n        self._path = None  # Will be determined in `build_wrapper`\n        self.built = False\n        self.autocast = autocast\n        self._input_spec = None\n        self._called = False\n        self.supports_jit = True\n\n        self._trainable = trainable\n        self._losses = []\n        self._loss_ids = set()\n        self._losses_override = []\n\n        self._call_signature = inspect.signature(self.call)\n        self.call_signature_parameters = [\n            p.name for p in self._call_signature.parameters.values()\n        ]\n        self._call_has_training_arg = (\n            \"training\" in self.call_signature_parameters\n        )\n        self._call_has_mask_arg = \"mask\" in self.call_signature_parameters\n\n        # 1. collect names that should be auto-propagated\n        self._call_context_args = {\"training\"}\n\n        # 2. 
remember which of them exist in *this* call signature\n        self._call_has_context_arg = {\n            arg: (arg in self.call_signature_parameters)\n            for arg in self._call_context_args\n        }\n\n        self._supports_masking = not utils.is_default(self.compute_mask)\n        # Whether to automatically convert (+ auto-cast) inputs to `call()`.\n        self._convert_input_args = True\n        # Whether to allow non-tensors as positional arguments in `call()`.\n        self._allow_non_tensor_positional_args = False\n        # Dict of shapes that were used to call `build()`.\n        self._build_shapes_dict = None\n        # Parent path\n        self._parent_path = None\n        self._remat_mode = get_current_remat_mode()\n        self._initialize_tracker()\n\n    @tracking.no_automatic_dependency_tracking\n    def _initialize_tracker(self):\n        if hasattr(self, \"_tracker\"):\n            return\n\n        trainable_variables = []\n        non_trainable_variables = []\n        layers = []\n        metrics = []\n        seed_generators = []\n        self._tracker = tracking.Tracker(\n            {\n                \"trainable_variables\": (\n                    lambda x: isinstance(x, backend.Variable) and x.trainable,\n                    trainable_variables,\n                ),\n                \"non_trainable_variables\": (\n                    lambda x: (\n                        isinstance(x, backend.Variable) and not x.trainable\n                    ),\n                    non_trainable_variables,\n                ),\n                \"metrics\": (lambda x: isinstance(x, Metric), metrics),\n                \"layers\": (\n                    lambda x: (\n                        isinstance(x, Layer) and not isinstance(x, Metric)\n                    ),\n                    layers,\n                ),\n                \"seed_generators\": (\n                    lambda x: isinstance(x, backend.random.SeedGenerator),\n                    
seed_generators,\n                ),\n            },\n            exclusions={\"non_trainable_variables\": [\"trainable_variables\"]},\n        )\n        if backend.backend() == \"tensorflow\":\n            # Remove attribute tracking for lists (TF-specific attribute)\n            _self_setattr_tracking = getattr(\n                self, \"_self_setattr_tracking\", True\n            )\n            self._self_setattr_tracking = False\n\n        self._trainable_variables = trainable_variables\n        self._non_trainable_variables = non_trainable_variables\n        self._layers = layers\n        self._metrics = metrics\n        self._seed_generators = seed_generators\n\n        if backend.backend() == \"tensorflow\":\n            # Reset attribute tracking (TF-specific)\n            self._self_setattr_tracking = _self_setattr_tracking\n\n    def _build_at_init(self):\n        \"\"\"Build the layer at `Layer.__init__`.\n\n        We can only safely mark the layer as `built=True` in `Layer.__init__` if\n        `build` is not overridden. 
Otherwise, it might cause the subclasses to\n        ignore the user's `build`.\n        \"\"\"\n        if utils.is_default(self.build):\n            self.built = True\n            self._post_build()\n            self._lock_state()\n\n    @property\n    def path(self):\n        \"\"\"The path of the layer.\n\n        If the layer has not been built yet, it will be `None`.\n        \"\"\"\n        return self._path\n\n    @property\n    def input_spec(self):\n        return self._input_spec\n\n    @input_spec.setter\n    def input_spec(self, value):\n        self._input_spec = value\n\n    @utils.default\n    def build(self, input_shape):\n        self._check_super_called()\n        if utils.is_default(self.build) and might_have_unbuilt_state(self):\n            warnings.warn(\n                f\"`build()` was called on layer '{self.name}', however \"\n                \"the layer does not have a `build()` method implemented \"\n                \"and it looks like it has unbuilt state. This will cause \"\n                \"the layer to be marked as built, despite not being \"\n                \"actually built, which may cause failures down the line. \"\n                \"Make sure to implement a proper `build()` method.\"\n            )\n        self.built = True\n\n    def _lock_state(self):\n        \"\"\"Prevent further state updates, called automatically in `build()`.\"\"\"\n        if not self._tracker.locked:\n            self._tracker.lock(\n                msg=(\n                    \"You cannot add new elements of state \"\n                    \"(variables or sub-layers) \"\n                    \"to a layer that is already built. 
All state \"\n                    \"must be created in the `__init__()` method or \"\n                    \"in the `build()` method.\"\n                )\n            )\n\n    def get_build_config(self):\n        \"\"\"Returns a dictionary with the layer's input shape.\n\n        This method returns a config dict that can be used by\n        `build_from_config(config)` to create all states (e.g. Variables and\n        Lookup tables) needed by the layer.\n\n        By default, the config only contains the input shape that the layer\n        was built with. If you're writing a custom layer that creates state in\n        an unusual way, you should override this method to make sure this state\n        is already created when Keras attempts to load its value upon model\n        loading.\n\n        Returns:\n            A dict containing the input shape associated with the layer.\n        \"\"\"\n        if self._build_shapes_dict is not None:\n            if len(self._build_shapes_dict) == 1:\n                return {\n                    \"input_shape\": tuple(self._build_shapes_dict.values())[0],\n                }\n            else:\n                return {\"shapes_dict\": self._build_shapes_dict}\n\n    def build_from_config(self, config):\n        \"\"\"Builds the layer's states with the supplied config dict.\n\n        By default, this method calls the `build(config[\"input_shape\"])` method,\n        which creates weights based on the layer's input shape in the supplied\n        config. 
If your config contains other information needed to load the\n        layer's state, you should override this method.\n\n        Args:\n            config: Dict containing the input shape associated with this layer.\n        \"\"\"\n        if config:\n            if \"input_shape\" in config:\n                self.build(config[\"input_shape\"])\n            elif \"shapes_dict\" in config:\n                self.build(**config[\"shapes_dict\"])\n\n    def _obj_type(self):\n        return \"Layer\"\n\n    def add_variable(\n        self,\n        shape,\n        initializer,\n        dtype=None,\n        trainable=True,\n        autocast=True,\n        regularizer=None,\n        constraint=None,\n        name=None,\n    ):\n        \"\"\"Add a weight variable to the layer.\n\n        Alias of `add_weight()`.\n        \"\"\"\n        return self.add_weight(\n            shape=shape,\n            initializer=initializer,\n            dtype=dtype,\n            trainable=trainable,\n            autocast=autocast,\n            regularizer=regularizer,\n            constraint=constraint,\n            name=name,\n        )\n\n    def add_weight(\n        self,\n        *args,\n        shape=None,\n        initializer=None,\n        dtype=None,\n        trainable=True,\n        autocast=True,\n        regularizer=None,\n        constraint=None,\n        aggregation=\"none\",\n        overwrite_with_gradient=False,\n        name=None,\n    ):\n        \"\"\"Add a weight variable to the layer.\n\n        Args:\n            shape: Shape tuple for the variable. Must be fully-defined\n                (no `None` entries). Defaults to `()` (scalar) if unspecified.\n            initializer: Initializer object to use to populate the initial\n                variable value, or string name of a built-in initializer\n                (e.g. `\"random_normal\"`). 
If unspecified, defaults to\n                `\"glorot_uniform\"` for floating-point variables and to `\"zeros\"`\n                for all other types (e.g. int, bool).\n            dtype: Dtype of the variable to create, e.g. `\"float32\"`. If\n                unspecified, defaults to the layer's variable dtype\n                (which itself defaults to `\"float32\"` if unspecified).\n            trainable: Boolean, whether the variable should be trainable via\n                backprop or whether its updates are managed manually. Defaults\n                to `True`.\n            autocast: Boolean, whether to autocast the layer's variables when\n                accessing them. Defaults to `True`.\n            regularizer: Regularizer object to call to apply penalty on the\n                weight. These penalties are summed into the loss function\n                during optimization. Defaults to `None`.\n            constraint: Constraint object to call on the variable after any\n                optimizer update, or string name of a built-in constraint.\n                Defaults to `None`.\n            aggregation: Optional string, one of `None`, `\"none\"`, `\"mean\"`,\n                `\"sum\"` or `\"only_first_replica\"`. Annotates the variable with\n                the type of multi-replica aggregation to be used for this\n                variable when writing custom data parallel training loops.\n                Defaults to `\"none\"`.\n            overwrite_with_gradient: Boolean, whether to overwrite the variable\n                with the computed gradient. This is useful for float8 training.\n                Defaults to `False`.\n            name: String name of the variable. 
Useful for debugging purposes.\n        \"\"\"\n        self._check_super_called()\n        if args:\n            # `args` is only kept to detect the legacy Keras 2 call style\n            # (`add_weight(shape, initializer, dtype, ...)`) and raise a clear\n            # error for positional `name`.\n            if len(args) > 3:\n                raise TypeError(\n                    \"add_weight() takes at most 3 positional arguments \"\n                    f\"but {len(args)} were given.\"\n                )\n            shape_arg = args[0]\n            if isinstance(shape_arg, str):\n                raise ValueError(\n                    \"`name` must be passed as a keyword argument. \"\n                    f\"Received: add_weight('{shape_arg}', ...). \"\n                    f\"Use: add_weight(shape=..., name='{shape_arg}').\"\n                )\n            if shape is not None:\n                raise ValueError(\n                    \"`shape` was passed both positionally and as \"\n                    \"a keyword argument.\"\n                )\n            shape = shape_arg\n            if len(args) > 1:\n                if initializer is not None:\n                    raise ValueError(\n                        \"`initializer` was passed both positionally and \"\n                        \"as a keyword argument.\"\n                    )\n                initializer = args[1]\n            if len(args) > 2:\n                if dtype is not None:\n                    raise ValueError(\n                        \"`dtype` was passed both positionally and as a \"\n                        \"keyword argument.\"\n                    )\n                dtype = args[2]\n        if shape is None:\n            shape = ()\n        if dtype is not None:\n            dtype = backend.standardize_dtype(dtype)\n        else:\n            dtype = self.variable_dtype\n        if initializer is None:\n            if \"float\" in dtype:\n                initializer = 
\"glorot_uniform\"\n            else:\n                initializer = \"zeros\"\n        initializer = initializers.get(initializer)\n        with backend.name_scope(self.name, caller=self):\n            variable = backend.Variable(\n                initializer=initializer,\n                shape=shape,\n                dtype=dtype,\n                trainable=trainable,\n                autocast=autocast,\n                aggregation=aggregation,\n                name=name,\n            )\n        # Will be added to layer.losses\n        variable.regularizer = regularizers.get(regularizer)\n        variable.constraint = constraints.get(constraint)\n        variable.overwrite_with_gradient = overwrite_with_gradient\n        self._track_variable(variable)\n        return variable\n\n    @property\n    def trainable(self):\n        \"\"\"Settable boolean, whether this layer should be trainable or not.\"\"\"\n        return self._trainable\n\n    @trainable.setter\n    def trainable(self, value):\n        \"\"\"Sets trainable attribute for the layer and its sublayers.\n\n        When this value is changed during training (e.g. 
with a\n        `Callback`) you need to call the parent\n        `Model.make_train_function` with `force=True` in order to\n        recompile the training graph.\n\n        Args:\n            value: Boolean with the desired state for the layer's trainable\n                attribute.\n        \"\"\"\n        value = bool(value)\n        self._trainable = value\n        for v in self._trainable_variables:\n            v.trainable = value\n        for layer in self._layers:\n            layer.trainable = value\n\n    @property\n    def variables(self):\n        \"\"\"List of all layer state, including random seeds.\n\n        This extends `layer.weights` to include all state used by the layer\n        including `SeedGenerator`s.\n\n        Note that metrics variables are not included here; use\n        `metrics_variables` to visit all the metric variables.\n        \"\"\"\n        # Return all `Variables` associated with the layer, including random\n        # seeds but excluding metrics variables. Also deduplicate them.\n        variables = []\n        seen_ids = set()\n        for v in self._trainable_variables + self._non_trainable_variables:\n            if id(v) not in seen_ids:\n                variables.append(v)\n                seen_ids.add(id(v))\n        for sg in self._seed_generators:\n            variables.append(sg.state)\n        for layer in self._layers:\n            for v in layer.variables:\n                if id(v) not in seen_ids:\n                    variables.append(v)\n                    seen_ids.add(id(v))\n        return variables\n\n    @property\n    def trainable_variables(self):\n        \"\"\"List of all trainable layer state.\n\n        This is equivalent to `layer.trainable_weights`.\n        \"\"\"\n        if not self.trainable:\n            return []\n        return [v for v in self.variables if v.trainable]\n\n    @property\n    def non_trainable_variables(self):\n        \"\"\"List of all non-trainable layer state.\n\n        This extends 
`layer.non_trainable_weights` to include all state used by\n        the layer including state for metrics and `SeedGenerator`s.\n        \"\"\"\n        if not self.trainable:\n            return self.variables\n        return [v for v in self.variables if not v.trainable]\n\n    @property\n    def weights(self):\n        \"\"\"List of all weight variables of the layer.\n\n        Unlike `layer.variables`, this excludes metric state and random seeds.\n        \"\"\"\n        # Return only `Variables` directly owned by layers and sub-layers.\n        # Also deduplicate them.\n        weights = []\n        seen_ids = set()\n        for w in self._trainable_variables + self._non_trainable_variables:\n            if id(w) not in seen_ids:\n                weights.append(w)\n                seen_ids.add(id(w))\n        for layer in self._layers:\n            for w in layer.weights:\n                if id(w) not in seen_ids:\n                    weights.append(w)\n                    seen_ids.add(id(w))\n        return weights\n\n    @property\n    def trainable_weights(self):\n        \"\"\"List of all trainable weight variables of the layer.\n\n        These are the weights that get updated by the optimizer during training.\n        \"\"\"\n        if not self.trainable:\n            return []\n        return [v for v in self.weights if v.trainable]\n\n    @property\n    def non_trainable_weights(self):\n        \"\"\"List of all non-trainable weight variables of the layer.\n\n        These are the weights that should not be updated by the optimizer during\n        training. 
Unlike `layer.non_trainable_variables`, this excludes metric\n        state and random seeds.\n        \"\"\"\n        if not self.trainable:\n            return self.weights\n        return [v for v in self.weights if not v.trainable]\n\n    @property\n    def metrics(self):\n        \"\"\"List of all metrics.\"\"\"\n        metrics = list(self._metrics)\n        for layer in self._layers:\n            metrics.extend(layer.metrics)\n        return metrics\n\n    @property\n    def metrics_variables(self):\n        \"\"\"List of all metric variables.\"\"\"\n        vars = []\n        for metric in self.metrics:\n            vars.extend(metric.variables)\n        return vars\n\n    def get_weights(self):\n        \"\"\"Return the values of `layer.weights` as a list of NumPy arrays.\"\"\"\n        return [v.numpy() for v in self.weights]\n\n    def set_weights(self, weights):\n        \"\"\"Sets the values of `layer.weights` from a list of NumPy arrays.\"\"\"\n        layer_weights = self.weights\n        if len(layer_weights) != len(weights):\n            raise ValueError(\n                f\"You called `set_weights(weights)` on layer '{self.name}' \"\n                f\"with a weight list of length {len(weights)}, but the layer \"\n                f\"was expecting {len(layer_weights)} weights.\"\n            )\n        for variable, value in zip(layer_weights, weights):\n            if variable.shape != value.shape:\n                raise ValueError(\n                    f\"Layer {self.name} weight shape {variable.shape} \"\n                    \"is not compatible with provided weight \"\n                    f\"shape {value.shape}.\"\n                )\n            variable.assign(value)\n\n    @property\n    def dtype_policy(self):\n        return self._dtype_policy\n\n    @dtype_policy.setter\n    def dtype_policy(self, value):\n        policy = dtype_policies.get(value)\n        if isinstance(self._dtype_policy, DTypePolicyMap) and self.path:\n            if 
self.path in self._dtype_policy:\n                del self._dtype_policy[self.path]\n            self._dtype_policy[self.path] = policy\n        else:\n            self._dtype_policy = policy\n        if policy.quantization_mode is not None:\n            if self.built and not getattr(self, \"_is_quantized\", False):\n                if policy.quantization_mode == \"gptq\":\n                    raise ValueError(\n                        \"Implicitly enabling GPTQ quantization by setting \"\n                        f\"`dtype_policy` to '{value}' is not supported. \"\n                        \"GPTQ requires a calibration dataset and a \"\n                        \"`GPTQConfig` object.\\n\\n\"\n                        \"Please use the `.quantize('gptq', config=...)` method \"\n                        \"on the layer or model instead.\"\n                    )\n                self.quantize(policy.quantization_mode)\n\n    @property\n    def dtype(self):\n        \"\"\"Alias of `layer.variable_dtype`.\"\"\"\n        return self.variable_dtype\n\n    @property\n    def compute_dtype(self):\n        \"\"\"The dtype of the computations performed by the layer.\"\"\"\n        if isinstance(self._dtype_policy, DTypePolicyMap) and self.path:\n            policy = self._dtype_policy[self.path]\n        else:\n            policy = self._dtype_policy\n        return policy.compute_dtype\n\n    @property\n    def variable_dtype(self):\n        \"\"\"The dtype of the state (weights) of the layer.\"\"\"\n        if isinstance(self._dtype_policy, DTypePolicyMap) and self.path:\n            policy = self._dtype_policy[self.path]\n        else:\n            policy = self._dtype_policy\n        return policy.variable_dtype\n\n    @property\n    def quantization_mode(self):\n        \"\"\"The quantization mode of this layer, `None` if not quantized.\"\"\"\n        if isinstance(self._dtype_policy, DTypePolicyMap) and self.path:\n            policy = self._dtype_policy[self.path]\n        
else:\n            policy = self._dtype_policy\n        return policy.quantization_mode\n\n    @property\n    def input_dtype(self):\n        \"\"\"The dtype layer inputs should be converted to.\"\"\"\n        return self.compute_dtype\n\n    @property\n    def supports_masking(self):\n        \"\"\"Whether this layer supports computing a mask using `compute_mask`.\"\"\"\n        return self._supports_masking\n\n    @supports_masking.setter\n    def supports_masking(self, value):\n        self._supports_masking = value\n\n    @utils.default\n    def compute_mask(self, inputs, previous_mask):\n        return previous_mask\n\n    def symbolic_call(self, *args, **kwargs):\n        # Node is created at the end of `__call__` instead of `symbolic_call`.\n        return self.compute_output_spec(*args, **kwargs)\n\n    @traceback_utils.filter_traceback\n    def __call__(self, *args, **kwargs):\n        self._check_super_called()\n        self._called = True\n\n        original_args = args\n        original_kwargs = kwargs\n\n        #############################################################\n        # 1. 
Convert any array arguments to tensors of correct dtype.\n        def maybe_convert(x):\n            # Prevent _keras_mask from disappearing\n            mask = backend.get_keras_mask(x)\n            y = self.dtype_policy.convert_input(\n                x, self.autocast, self.input_dtype\n            )\n            if mask is not None:\n                backend.set_keras_mask(y, mask)\n            return y\n\n        # Used to avoid expensive `tree` operations in the most common case.\n        if (\n            kwargs\n            or len(args) != 1\n            or not is_backend_tensor_or_symbolic(args[0], allow_none=False)\n            or backend.standardize_dtype(args[0].dtype) != self.input_dtype\n        ) and self._convert_input_args:\n            args = tree.map_structure(maybe_convert, args)\n            kwargs = tree.map_structure(maybe_convert, kwargs)\n\n        ##########################################################\n        # 2. Enforce that only tensors can be passed positionally.\n        if not self._allow_non_tensor_positional_args:\n            for arg in tree.flatten(args):\n                if not is_backend_tensor_or_symbolic(arg, allow_none=True):\n                    raise ValueError(\n                        \"Only input tensors may be passed as \"\n                        \"positional arguments. The following argument value \"\n                        f\"should be passed as a keyword argument: {arg} \"\n                        f\"(of type {type(arg)})\"\n                    )\n\n        # Caches info about `call()` signature, args, kwargs.\n        call_spec = CallSpec(\n            self._call_signature, self._call_context_args, args, kwargs\n        )\n\n        ############################################\n        # 3. Check input spec for 1st positional arg.\n        # TODO: consider extending this to all args and kwargs.\n        self._assert_input_compatibility(call_spec.first_arg)\n\n        ################\n        # 4. 
Call build\n        with self._open_name_scope():\n            self._maybe_build(call_spec)\n\n        ##########################\n        # 5. Infer training value\n        # Training phase for `Layer.call` is set via (in order of priority):\n        # (1) The `training` argument passed to this `Layer.call`, if not None\n        # (2) The `training` argument of an outer `Layer.call`.\n        # (3) Any non-None default value for `training` in the call signature\n        # (4) False (treating the layer as if it's in inference)\n\n        # Maintains info about the `Layer.call` stack\n        # across nested calls.\n        call_context = self._get_call_context()\n\n        for context_arg in self._call_context_args:\n            self._resolve_and_populate_arg(\n                context_arg, call_spec, call_context, kwargs\n            )\n\n        ##############################\n        # 6. Populate mask argument(s)\n        if len(call_spec.tensor_arguments_dict) == 1:\n            if (\n                \"mask\" in call_spec.argument_names\n                and call_spec.arguments_dict[\"mask\"] is None\n            ):\n                arg_name = list(call_spec.tensor_arguments_dict.keys())[0]\n                only_tensor_arg = call_spec.tensor_arguments_dict[arg_name]\n                mask = tree.map_structure(\n                    backend.get_keras_mask,\n                    only_tensor_arg,\n                )\n                kwargs[\"mask\"] = mask\n        elif len(call_spec.tensor_arguments_dict) > 1:\n            for k, v in call_spec.tensor_arguments_dict.items():\n                expected_mask_arg_name = f\"{k}_mask\"\n                if expected_mask_arg_name in call_spec.argument_names:\n                    if call_spec.arguments_dict[expected_mask_arg_name] is None:\n                        mask = tree.map_structure(backend.get_keras_mask, v)\n                        kwargs[expected_mask_arg_name] = mask\n\n        # We need to cache the `previous_mask` 
before `__call__` because the\n        # mask might be removed during the call, as in `MultiHeadAttention`.\n        if \"mask\" in kwargs and kwargs[\"mask\"] is not None:\n            # Case 1: Mask was explicitly passed or auto-populated in step 6.\n            previous_mask = kwargs[\"mask\"]\n        else:\n            # Case 2: Fallback to the mask attached to the first input tensor.\n            previous_mask = tree.map_structure(\n                backend.get_keras_mask, call_spec.first_arg\n            )\n\n        ####################\n        # 7. Call the layer.\n        try:\n            with self._open_name_scope():\n                current_scope = backend.get_autocast_scope()\n                new_scope = None\n                if current_scope is not None:\n                    # Clear or update the current scope if necessary.\n                    if not self.autocast:\n                        new_scope = backend.AutocastScope(None)\n                    elif not backend.is_float_dtype(self.compute_dtype):\n                        # Some preprocessing layers might have a non-float\n                        # dtype; we should not autocast in this case.\n                        new_scope = backend.AutocastScope(None)\n                    elif current_scope.dtype != self.compute_dtype:\n                        new_scope = backend.AutocastScope(self.compute_dtype)\n                elif self.compute_dtype != self.variable_dtype:\n                    # Enter a new scope if our dtypes are \"mixed\".\n                    new_scope = backend.AutocastScope(self.compute_dtype)\n                if new_scope is not None:\n                    with new_scope:\n                        outputs = super().__call__(*args, **kwargs)\n                else:\n                    outputs = super().__call__(*args, **kwargs)\n                # Change the layout for the layer output if needed.\n                # This is useful to re-lay out intermediate tensors in the model\n        
        # to achieve the optimal performance.\n                distribution = distribution_lib.distribution()\n                if distribution is not None:\n                    current_layer_path = current_path()\n                    current_layer_path += \"/output\"\n                    layout = distribution.get_tensor_layout(current_layer_path)\n                    if layout:\n                        outputs = distribution_lib.distribute_tensor(\n                            outputs, layout\n                        )\n\n                self.built = True\n                # Record activity regularizer loss.\n                if self.activity_regularizer is not None:\n                    for output in tree.flatten(outputs):\n                        if backend.is_tensor(output):\n                            loss = self.activity_regularizer(output)\n                            if output.ndim > 0:\n                                # Normalize by batch size to ensure consistent\n                                # regularization strength across batch sizes\n                                batch_size = ops.cast(\n                                    ops.shape(output)[0], dtype=loss.dtype\n                                )\n                                loss = ops.divide_no_nan(loss, batch_size)\n                            self.add_loss(loss)\n\n            # Set `previous_mask` on outputs if available. It is provided only\n            # for the first positional input arg and its mask.\n            # TODO: consider extending this to all args and kwargs.\n            if self.supports_masking:\n                self._set_mask_metadata(\n                    call_spec.first_arg, outputs, previous_mask\n                )\n            elif any(m is not None for m in tree.flatten(previous_mask)):\n                warnings.warn(\n                    f\"Layer '{self.name}' (of type {self.__class__.__name__}) \"\n                    \"was passed an input with a mask attached to it. 
\"\n                    \"However, this layer does not support masking and will \"\n                    \"therefore destroy the mask information. Downstream \"\n                    \"layers will not see the mask.\"\n                )\n        finally:\n            # Destroy call context if we created it\n            self._maybe_reset_call_context()\n\n        ################################################\n        # 8. Add a node in the graph for symbolic calls.\n        if any_symbolic_tensors(original_args, original_kwargs):\n            Node(\n                operation=self,\n                call_args=original_args,\n                call_kwargs=original_kwargs,\n                outputs=outputs,\n            )\n\n        return outputs\n\n    def call(self, *args, **kwargs):\n        raise self._not_implemented_error(self.call)\n\n    def _resolve_and_populate_arg(\n        self, arg_name, call_spec, call_context, kwargs\n    ):\n        # 1) user explicitly passed it?\n        if arg_name in call_spec.user_arguments_dict:\n            value = call_spec.user_arguments_dict[arg_name]\n        # 2) else: inherited from outer layer call?\n        elif call_context.get_value(arg_name) is not None:\n            value = call_context.get_value(arg_name)\n        # 3) else: default from the call() signature\n        else:\n            value = call_spec.arguments_dict.get(arg_name, None)\n\n        # stash it for downstream layers\n        call_context.set_value(arg_name, value)\n\n        # only inject it if this layer actually accepts it and it's not None\n        if (\n            self._call_has_context_arg.get(arg_name, False)\n            and value is not None\n        ):\n            kwargs[arg_name] = value\n\n    @traceback_utils.filter_traceback\n    def stateless_call(\n        self,\n        trainable_variables,\n        non_trainable_variables,\n        *args,\n        return_losses=False,\n        **kwargs,\n    ):\n        \"\"\"Call the layer without any 
side effects.\n\n        Args:\n            trainable_variables: List of trainable variables of the model.\n            non_trainable_variables: List of non-trainable variables of the\n                model.\n            *args: Positional arguments to be passed to `call()`.\n            return_losses: If `True`, `stateless_call()` will return the list of\n                losses created during `call()` as part of its return values.\n            **kwargs: Keyword arguments to be passed to `call()`.\n\n        Returns:\n            A tuple. By default, returns `(outputs, non_trainable_variables)`.\n                If `return_losses=True`, then returns\n                `(outputs, non_trainable_variables, losses)`.\n\n        Note: `non_trainable_variables` include not only non-trainable weights\n        such as `BatchNormalization` statistics, but also RNG seed state\n        (if any random operations, such as dropout, are part of the layer),\n        and `Metric` state (if there are any metrics attached to the layer).\n        These are all elements of the layer's state.\n\n        Example:\n\n        ```python\n        model = ...\n        data = ...\n        trainable_variables = model.trainable_variables\n        non_trainable_variables = model.non_trainable_variables\n        # Call the model with zero side effects\n        outputs, non_trainable_variables = model.stateless_call(\n            trainable_variables,\n            non_trainable_variables,\n            data,\n        )\n        # Attach the updated state to the model\n        # (until you do this, the model is still in its pre-call state).\n        for ref_var, value in zip(\n            model.non_trainable_variables, non_trainable_variables\n        ):\n            ref_var.assign(value)\n        ```\n        \"\"\"\n        self._check_super_called()\n        if not self.built:\n            raise ValueError(\n                f\"To call stateless_call, {self.__class__.__name__} must be \"\n      
          \"built (i.e. its variables must have been already created). \"\n                \"You can build it by calling it on some data.\"\n            )\n        if len(trainable_variables) != len(self.trainable_variables):\n            raise ValueError(\n                \"Argument `trainable_variables` must be a list of tensors \"\n                \"corresponding 1:1 to \"\n                f\"{self.__class__.__name__}().trainable_variables. \"\n                f\"Received list with length {len(trainable_variables)}, \"\n                f\"but expected {len(self.trainable_variables)} variables.\"\n            )\n        if len(non_trainable_variables) != len(self.non_trainable_variables):\n            raise ValueError(\n                \"Argument `non_trainable_variables` must be a list of tensors \"\n                \"corresponding 1:1 to \"\n                f\"{self.__class__.__name__}().non_trainable_variables. \"\n                f\"Received list with length {len(non_trainable_variables)}, \"\n                f\"but expected {len(self.non_trainable_variables)} variables.\"\n            )\n\n        # Gather variable mapping\n        trainable_mapping = zip(self.trainable_variables, trainable_variables)\n        non_trainable_mapping = zip(\n            self.non_trainable_variables, non_trainable_variables\n        )\n        mapping = list(trainable_mapping) + list(non_trainable_mapping)\n\n        # Call in stateless scope\n        losses = None\n        with backend.StatelessScope(\n            state_mapping=mapping, collect_losses=return_losses\n        ) as scope:\n            if self.dtype_policy.quantization_mode is not None:\n                if self._remat_mode is not None:\n                    outputs = self.rematerialized_call(\n                        self.quantized_call, *args, **kwargs\n                    )(*args, **kwargs)\n                else:\n                    outputs = self.quantized_call(*args, **kwargs)\n            elif 
self._remat_mode is not None:\n                outputs = self.rematerialized_call(self.call, *args, **kwargs)(\n                    *args, **kwargs\n                )\n            else:\n                outputs = self.call(*args, **kwargs)\n            if return_losses:\n                losses = self.losses\n\n        # Gather updated non-trainable variables\n        non_trainable_variables = []\n        for v in self.non_trainable_variables:\n            new_v = scope.get_current_value(v)\n            non_trainable_variables.append(new_v)\n\n        if return_losses:\n            return outputs, non_trainable_variables, losses\n        return outputs, non_trainable_variables\n\n    def compute_output_spec(self, *args, **kwargs):\n        if utils.is_default(self.compute_output_shape):\n            return super().compute_output_spec(*args, **kwargs)\n        else:\n            # Use compute_output_shape() to return the right output spec\n            call_spec = CallSpec(\n                self._call_signature, self._call_context_args, args, kwargs\n            )\n            shapes_dict = get_shapes_dict(call_spec)\n            shapes_dict = update_shapes_dict_for_target_fn(\n                self.compute_output_shape,\n                shapes_dict=shapes_dict,\n                call_spec=call_spec,\n                class_name=self.__class__.__name__,\n            )\n            output_shape = self.compute_output_shape(**shapes_dict)\n\n            if (\n                isinstance(output_shape, list)\n                and output_shape\n                and isinstance(output_shape[0], (int, type(None)))\n            ):\n                output_shape = tuple(output_shape)\n            if not isinstance(output_shape, (list, tuple, dict)):\n                try:\n                    output_shape = tuple(output_shape)\n                except:\n                    raise ValueError(\n                        \"Method `compute_output_shape()` of layer \"\n                        
f\"{self.__class__.__name__} is returning \"\n                        \"a type that cannot be interpreted as a shape. \"\n                        \"It should return a shape tuple. \"\n                        f\"Received: {output_shape}\"\n                    )\n            if (\n                isinstance(output_shape, tuple)\n                and output_shape\n                and isinstance(output_shape[0], (int, type(None)))\n            ):\n                return KerasTensor(output_shape, dtype=self.compute_dtype)\n            # Case: nested. Could be a tuple/list of shapes, or a dict of\n            # shapes. Could be deeply nested.\n            return tree.map_shape_structure(\n                lambda s: KerasTensor(s, dtype=self.compute_dtype), output_shape\n            )\n\n    @utils.default\n    def compute_output_shape(self, *args, **kwargs):\n        raise self._not_implemented_error(\n            self.compute_output_shape,\n            \"Should implement `def compute_output_shape(self, input_shape)`.\",\n        )\n\n    def add_loss(self, loss):\n        \"\"\"Can be called inside of the `call()` method to add a scalar loss.\n\n        Example:\n\n        ```python\n        class MyLayer(Layer):\n            ...\n            def call(self, x):\n                self.add_loss(ops.sum(x))\n                return x\n        ```\n        \"\"\"\n        # Eager only.\n        losses = tree.flatten(loss)\n        for x in losses:\n            if not backend.is_tensor(x):\n                raise ValueError(\n                    \"`add_loss()` can only be called from inside `build()` or \"\n                    f\"`call()`, on a tensor input. 
Received invalid value: {x}\"\n                )\n        if backend.in_stateless_scope():\n            scope = backend.get_stateless_scope()\n            if scope.collect_losses:\n                for x in losses:\n                    scope.add_loss(x)\n                    self._loss_ids.add(id(x))\n        else:\n            self._losses.extend(losses)\n\n    def _get_own_losses(self):\n        if backend.in_stateless_scope():\n            losses = []\n            scope = backend.get_stateless_scope()\n            for loss in scope.losses:\n                if id(loss) in self._loss_ids:\n                    losses.append(loss)\n            return losses\n        else:\n            return self._losses[:]\n\n    def _get_regularization_losses(self):\n        weight_regularization_losses = []\n        for variable in self.trainable_weights:\n            if variable.regularizer is None:\n                continue\n            if backend.in_stateless_scope() and not in_symbolic_scope():\n                # If in symbolic scope, we might get `None` from\n                # `get_current_value` in `backend.compute_output_spec`. 
So we\n                # assign `variable` instead.\n                v = backend.get_stateless_scope().get_current_value(variable)\n            else:\n                v = variable\n            weight_regularization_losses.append(variable.regularizer(v))\n        return weight_regularization_losses\n\n    @property\n    def losses(self):\n        \"\"\"List of scalar losses from `add_loss`, regularizers and sublayers.\"\"\"\n        if self._losses_override:\n            return self._losses_override\n        losses = self._get_own_losses()\n        for layer in self._flatten_layers(include_self=False):\n            losses.extend(layer._get_own_losses())\n        weight_regularization_losses = self._get_regularization_losses()\n        losses.extend(weight_regularization_losses)\n        return losses\n\n    def _clear_losses(self):\n        if backend.in_stateless_scope():\n            scope = backend.get_stateless_scope()\n            if scope.collect_losses:\n                # Filter by identity (id) rather than using list.remove(),\n                # which compares by value. Value comparison on JAX tracers\n                # during JIT causes TracerBoolConversionError.\n                scope.losses[:] = [\n                    x for x in scope.losses if id(x) not in self._loss_ids\n                ]\n        self._losses.clear()\n        self._loss_ids.clear()\n        for layer in self._layers:\n            layer._clear_losses()\n\n    # Quantization-related (int8 and float8) methods\n\n    def quantized_build(self, input_shape, mode):\n        raise self._not_implemented_error(self.quantized_build)\n\n    def quantize(self, mode=None, type_check=True, config=None):\n        raise self._not_implemented_error(self.quantize)\n\n    def _check_quantize_args(self, mode, compute_dtype):\n        if not self.built:\n            raise ValueError(\n                \"Cannot quantize a layer that isn't yet built. 
\"\n                f\"Layer '{self.name}' (of type '{self.__class__.__name__}') \"\n                \"is not built yet.\"\n            )\n        if getattr(self, \"_is_quantized\", False):\n            raise ValueError(\n                f\"Layer '{self.name}' is already quantized with \"\n                f\"dtype_policy='{self.dtype_policy.name}'. \"\n                f\"Received: mode={mode}\"\n            )\n        if mode not in dtype_policies.QUANTIZATION_MODES:\n            raise ValueError(\n                \"Invalid quantization mode. \"\n                f\"Expected one of {dtype_policies.QUANTIZATION_MODES}. \"\n                f\"Received: mode={mode}\"\n            )\n        if mode == \"int8\" and compute_dtype == \"float16\":\n            raise ValueError(\n                f\"Quantization mode='{mode}' doesn't work well with \"\n                \"compute_dtype='float16'. Consider loading model/layer with \"\n                \"another dtype policy such as 'mixed_bfloat16' or \"\n                \"'mixed_float16' before calling `quantize()`.\"\n            )\n\n    def quantized_call(self, *args, **kwargs):\n        current_remat_mode = get_current_remat_mode()\n\n        if (\n            current_remat_mode != self._remat_mode\n            and current_remat_mode is not None\n        ):\n            warnings.warn(\n                f\"The RematScope at call time ({current_remat_mode}) differs \"\n                f\"from the one set during layer initialization \"\n                f\"({self._remat_mode}). 
\"\n                f\"Restoring the correct rematerialization mode \"\n                f\"{self._remat_mode} for this layer.\"\n            )\n        if self.quantization_mode == \"int8\":\n            return self._int8_call(*args, **kwargs)\n        elif self.quantization_mode == \"float8\":\n            return self._float8_call(*args, **kwargs)\n        elif self.quantization_mode == \"int4\":\n            return self._int4_call(*args, **kwargs)\n        elif self.quantization_mode == \"gptq\":\n            return self._gptq_call(*args, **kwargs)\n        elif self.quantization_mode == \"awq\":\n            return self._awq_call(*args, **kwargs)\n        else:\n            raise self._quantization_mode_error(self.quantization_mode)\n\n    def _int4_call(self, *args, **kwargs):\n        raise self._not_implemented_error(self._int4_call)\n\n    def _int8_call(self, *args, **kwargs):\n        raise self._not_implemented_error(self._int8_call)\n\n    def _float8_call(self, *args, **kwargs):\n        raise self._not_implemented_error(self._float8_call)\n\n    def _gptq_call(self, *args, **kwargs):\n        raise self._not_implemented_error(self._gptq_call)\n\n    def _awq_call(self, *args, **kwargs):\n        raise self._not_implemented_error(self._awq_call)\n\n    def _not_implemented_error(self, attr, msg=None):\n        if callable(attr):\n            attr_name = attr.__name__\n            attr_type = \"method\"\n        else:\n            attr_name = str(attr)\n            attr_type = \"attribute\"\n        msg = f\" {msg}\" if msg is not None else \"\"\n        return NotImplementedError(\n            f\"Layer {self.__class__.__name__} does not have a `{attr_name}` \"\n            f\"{attr_type} implemented.{msg}\"\n        )\n\n    def _quantization_mode_error(self, mode):\n        return NotImplementedError(\n            \"Invalid quantization mode. Expected one of \"\n            f\"{dtype_policies.QUANTIZATION_MODES}. 
\"\n            f\"Received: quantization_mode={mode}\"\n        )\n\n    def save_own_variables(self, store):\n        \"\"\"Saves the state of the layer.\n\n        You can override this method to take full control of how the state of\n        the layer is saved upon calling `model.save()`.\n\n        Args:\n            store: Dict where the state of the model will be saved.\n        \"\"\"\n        all_vars = self._trainable_variables + self._non_trainable_variables\n        for i, v in enumerate(all_vars):\n            store[f\"{i}\"] = v\n\n    def _check_load_own_variables(self, store):\n        all_vars = self._trainable_variables + self._non_trainable_variables\n        if len(store.keys()) != len(all_vars):\n            if len(all_vars) == 0 and not self.built:\n                raise ValueError(\n                    f\"Layer '{self.name}' was never built \"\n                    \"and thus it doesn't have any variables. \"\n                    f\"However the weights file lists {len(store.keys())} \"\n                    \"variables for this layer.\\n\"\n                    \"In most cases, this error indicates that either:\\n\\n\"\n                    \"1. The layer is owned by a parent layer that \"\n                    \"implements a `build()` method, but calling the \"\n                    \"parent's `build()` method did NOT create the state of \"\n                    f\"the child layer '{self.name}'. A `build()` method \"\n                    \"must create ALL state for the layer, including \"\n                    \"the state of any children layers.\\n\\n\"\n                    \"2. You need to implement \"\n                    \"the `def build_from_config(self, config)` method \"\n                    f\"on layer '{self.name}', to specify how to rebuild \"\n                    \"it during loading. 
\"\n                    \"In this case, you might also want to implement the \"\n                    \"method that generates the build config at saving time, \"\n                    \"`def get_build_config(self)`. \"\n                    \"The method `build_from_config()` is meant \"\n                    \"to create the state \"\n                    \"of the layer (i.e. its variables) upon deserialization.\",\n                )\n            raise ValueError(\n                f\"Layer '{self.name}' expected {len(all_vars)} variables, \"\n                \"but received \"\n                f\"{len(store.keys())} variables during loading. \"\n                f\"Expected: {[v.name for v in all_vars]}\"\n            )\n\n    def load_own_variables(self, store):\n        \"\"\"Loads the state of the layer.\n\n        You can override this method to take full control of how the state of\n        the layer is loaded upon calling `keras.models.load_model()`.\n\n        Args:\n            store: Dict from which the state of the model will be loaded.\n        \"\"\"\n        self._check_load_own_variables(store)\n        all_vars = self._trainable_variables + self._non_trainable_variables\n        for i, v in enumerate(all_vars):\n            v.assign(store[f\"{i}\"])\n\n    def _track_variable(self, variable):\n        if variable.trainable:\n            self._tracker.add_to_store(\"trainable_variables\", variable)\n        else:\n            self._tracker.add_to_store(\"non_trainable_variables\", variable)\n        if not self.trainable:\n            variable.trainable = False\n        self._post_track_variable(variable)\n\n    def _untrack_variable(self, variable):\n        previous_lock_state = self._tracker.locked\n        self._tracker.unlock()\n        self._tracker.untrack(variable)\n        if previous_lock_state is True:\n            self._tracker.lock()\n        self._post_untrack_variable(variable)\n\n    def add_metric(self, *args, **kwargs):\n        # 
Permanently disabled\n        raise NotImplementedError(\n            \"Layer `add_metric()` method is deprecated. \"\n            \"Add your metric in `Model.compile(metrics=[...])`, \"\n            \"or create metric trackers in init() or build() \"\n            \"when subclassing the layer or model, then call \"\n            \"`metric.update_state()` whenever necessary.\"\n        )\n\n    def count_params(self):\n        \"\"\"Count the total number of scalars composing the weights.\n\n        Returns:\n            An integer count.\n        \"\"\"\n        if not self.built:\n            raise ValueError(\n                \"You tried to call `count_params` \"\n                f\"on layer '{self.name}', \"\n                \"but the layer isn't built. \"\n                \"You can build it manually via: \"\n                f\"`layer.build(input_shape)`.\"\n            )\n        return summary_utils.count_params(self.weights)\n\n    def _maybe_build(self, call_spec):\n        if self.built:\n            return\n\n        shapes_dict = get_shapes_dict(call_spec)\n        first_shape = next(iter(shapes_dict.values()), None)\n\n        # If the layer has a build method, call it with our input shapes.\n        if not utils.is_default(self.build):\n            shapes_dict = update_shapes_dict_for_target_fn(\n                self.build,\n                shapes_dict=shapes_dict,\n                call_spec=call_spec,\n                class_name=self.__class__.__name__,\n            )\n            self.build(**shapes_dict)\n            # Check input spec again (after build, since self.input_spec\n            # may have been updated)\n            self._assert_input_compatibility(call_spec.first_arg)\n            return\n\n        # Otherwise, attempt to build the layer by calling it on symbolic input.\n        if might_have_unbuilt_state(self):\n            try:\n                backend.compute_output_spec(\n                    self.call, **call_spec.arguments_dict\n      
          )\n            except Exception as e:\n                if call_spec.eager:\n                    # Will let the actual eager call do state-building\n                    return\n                warnings.warn(\n                    f\"Layer '{self.name}' looks like it has unbuilt state, but \"\n                    \"Keras is not able to trace the layer `call()` in order to \"\n                    \"build it automatically. Possible causes:\\n\"\n                    \"1. The `call()` method of your layer may be crashing. Try \"\n                    \"to `__call__()` the layer eagerly on some test input \"\n                    \"first to see if it works. \"\n                    \"E.g. `x = np.random.random((3, 4)); y = layer(x)`\\n\"\n                    \"2. If the `call()` method is correct, then you may need \"\n                    \"to implement the `def build(self, input_shape)` method on \"\n                    \"your layer. It should create all variables used by the \"\n                    \"layer (e.g. 
by calling `layer.build()` on all its \"\n                    \"children layers).\\n\"\n                    f\"Exception encountered: ''{e}''\"\n                )\n        self.build(first_shape)\n\n    def _build_by_run_for_single_pos_arg(self, input_shape):\n        # Case: all inputs are in the first arg (possibly nested).\n        input_tensors = tree.map_shape_structure(\n            lambda s: backend.KerasTensor(s), input_shape\n        )\n        try:\n            backend.compute_output_spec(self.call, input_tensors)\n            return True\n        except:\n            return False\n\n    def _build_by_run_for_kwargs(self, shapes_dict):\n        # Case: inputs were recorded as multiple keyword arguments.\n        if all(is_shape_tuple(s) for s in shapes_dict.values()):\n            # Case: all input keyword arguments were plain tensors.\n            input_tensors = {\n                # We strip the `_shape` suffix to recover kwarg names.\n                utils.removesuffix(k, \"_shape\"): backend.KerasTensor(shape)\n                for k, shape in shapes_dict.items()\n            }\n            try:\n                backend.compute_output_spec(self.call, **input_tensors)\n                return True\n            except:\n                return False\n        else:\n            # Not supported: nested input keyword arguments.\n            return False\n\n    def __repr__(self):\n        return (\n            f\"<{self.__class__.__name__} name={self.name}, built={self.built}>\"\n        )\n\n    def __str__(self):\n        return self.__repr__()\n\n    def __setattr__(self, name, value):\n        # Track Variables, Layers, Metrics, SeedGenerators.\n        name, value = self._setattr_hook(name, value)\n        if name != \"_tracker\":\n            if not hasattr(self, \"_tracker\"):\n                self._initialize_tracker()\n            value = self._tracker.track(value)\n\n        # NNX-specific bypass for `_called` and `built` attributes\n        # 
bypass nnx.Module.__setattr__ which cannot be called while tracing\n        if (\n            backend.backend() == \"jax\"\n            and is_nnx_enabled()\n            and (name == \"_called\" or name == \"built\")\n        ):\n            object.__setattr__(self, name, value)\n            return\n\n        super().__setattr__(name, value)\n\n    def __delattr__(self, name):\n        obj = getattr(self, name)\n        if isinstance(obj, backend.Variable):\n            import gc\n\n            # It will take a short amount of time for the corresponding buffer\n            # to be actually removed from the device.\n            # https://stackoverflow.com/a/74631949\n            self._untrack_variable(obj)\n            super().__delattr__(name)\n            gc.collect()\n        else:\n            super().__delattr__(name)\n\n    def _check_super_called(self):\n        if getattr(self, \"_lock\", True):\n            raise RuntimeError(\n                f\"In layer '{self.__class__.__name__}', you forgot to call \"\n                \"`super().__init__()` as the first statement \"\n                \"in the `__init__()` method. 
Go add it!\"\n            )\n\n    def _assert_input_compatibility(self, arg_0):\n        if self.input_spec:\n            try:\n                input_spec.assert_input_compatibility(\n                    self.input_spec, arg_0, layer_name=self.name\n                )\n            except SystemError:\n                if backend.backend() == \"torch\":\n                    # TODO: The torch backend failed the ONNX CI with the error:\n                    # SystemError: <method '__int__' of 'torch._C.TensorBase'\n                    # objects> returned a result with an exception set\n                    # As a workaround, we are skipping this for now.\n                    pass\n                else:\n                    raise\n\n    def _get_call_context(self):\n        \"\"\"Returns currently active `CallContext`.\"\"\"\n        layer_call_ctx = global_state.get_global_attribute(\"current_call_ctx\")\n        if layer_call_ctx is None:\n            # Enter new call context.\n            layer_call_ctx = CallContext(entry_layer=self)\n            global_state.set_global_attribute(\n                \"current_call_ctx\", layer_call_ctx\n            )\n            self._clear_losses()\n        return layer_call_ctx\n\n    def _maybe_reset_call_context(self):\n        layer_call_ctx = global_state.get_global_attribute(\"current_call_ctx\")\n        if layer_call_ctx is None or layer_call_ctx.entry_layer == self:\n            global_state.set_global_attribute(\"current_call_ctx\", None)\n\n    def _flatten_layers(self, include_self=True, recursive=True):\n        layers = []\n        if include_self:\n            layers.append(self)\n        seen_object_ids = set()\n        deque = collections.deque(self._layers)\n        while deque:\n            layer = deque.popleft()\n            if id(layer) in seen_object_ids:\n                continue\n            seen_object_ids.add(id(layer))\n            layers.append(layer)\n            # Introspect recursively through 
sublayers.\n            if recursive:\n                deque.extendleft(layer._layers)\n        return layers\n\n    def _set_mask_metadata(self, inputs, outputs, previous_mask):\n        flat_outputs = tree.flatten(outputs)\n\n        mask_already_computed = all(\n            backend.get_keras_mask(x) is not None for x in flat_outputs\n        )\n        if mask_already_computed:\n            return\n\n        output_masks = self.compute_mask(inputs, previous_mask)\n        if output_masks is None:\n            return\n\n        flat_masks = tree.flatten(output_masks)\n        for tensor, mask in zip(flat_outputs, flat_masks):\n            if backend.get_keras_mask(tensor) is None and mask is not None:\n                if backend.backend() == \"numpy\":\n                    warnings.warn(\n                        \"The NumPy backend does not support masking at this \"\n                        \"time. Masks will be ignored.\"\n                    )\n                else:\n                    backend.set_keras_mask(tensor, mask)\n\n    @python_utils.default\n    def get_config(self):\n        self._check_super_called()\n        base_config = super().get_config()\n        config = {\n            \"trainable\": self.trainable,\n            \"dtype\": dtype_policies.serialize(self.dtype_policy),\n        }\n        if self.activity_regularizer is not None:\n            config[\"activity_regularizer\"] = regularizers.serialize(\n                self.activity_regularizer\n            )\n        return {**base_config, **config}\n\n    def _open_name_scope(self):\n        from keras.src.utils import jax_utils  # avoid circular imports\n\n        if self._parent_path is None:\n            # Avoid mutating _parent_path during a JAX trace if it's part of\n            # nnx.Object state and the object was created at a different trace\n            # level. 
We check if we are in NNX mode and if we are in a JAX\n            # trace.\n            if not (is_nnx_enabled() and jax_utils.is_in_jax_tracing_scope()):\n                self._parent_path = current_path()\n\n        return backend.name_scope(self.name, caller=self)\n\n    def rematerialized_call(self, layer_call, *args, **kwargs):\n        \"\"\"Enable rematerialization dynamically for layer's call method.\n\n        Args:\n            layer_call: The original `call` method of a layer.\n\n        Returns:\n            Rematerialized layer's `call` method.\n        \"\"\"\n\n        def compute_size(x):\n            return (\n                math.prod([d or 1 for d in x.shape])\n                if isinstance(x, KerasTensor)\n                else 0\n            )\n\n        # Full rematerialization\n        if self._remat_mode.mode == \"full\":\n            return remat.remat(layer_call)\n\n        # Apply rematerialization to specific layers\n        elif self._remat_mode.mode == \"list_of_layers\" and (\n            self.name in self._remat_mode.layer_names\n        ):\n            return remat.remat(layer_call)\n\n        # Apply rematerialization based on output size threshold\n        elif self._remat_mode.mode == \"larger_than\":\n            output_spec = self.compute_output_spec(*args, **kwargs)\n            output_size = sum(\n                tree.flatten(tree.map_structure(compute_size, output_spec))\n            )\n            if (\n                output_size\n                and output_size > self._remat_mode.output_size_threshold\n            ):\n                return remat.remat(layer_call)\n        elif self._remat_mode.mode == \"activations\":\n            has_activation = (\n                hasattr(self, \"activation\") and self.activation is not None\n            )\n            if has_activation:\n\n                @functools.wraps(layer_call)\n                def rematerialized_activation_call_wrapper(*args, **kwargs):\n                    
original_activation = self.activation\n                    self.activation = remat.remat(original_activation)\n                    try:\n                        return layer_call(*args, **kwargs)\n                    finally:\n                        self.activation = original_activation\n\n                return rematerialized_activation_call_wrapper\n        return layer_call\n\n    def _register_call_context_args(self, *names):\n        \"\"\"Registers call-context args for this layer.\n\n        If this layer declares a `call()` method that accepts\n        one or more of the given args, those args will be\n        automatically injected into the call signature of this\n        layer. This layer will also propagate the args to any\n        nested sublayers that are called from within this layer.\n\n        If this layer doesn't declare a `call()` method that\n        accepts one or more of the given args, these args will\n        simply be propagated to any nested sublayers without\n        being injected into the call signature of this layer.\n        This is useful for propagating custom arguments\n        from top-level layers/models to sublayers.\n\n        Example:\n        ```\n        class Inner(layers.Layer):\n\n            def __init__(self):\n                super().__init__()\n                # Register `foo_mode` as a call-context arg\n                self._register_call_context_args(\"foo_mode\")\n\n            def call(self, x, foo_mode=False):\n                # If foo_mode=True add 1, otherwise add 0\n                add_val = ops.where(foo_mode, 1.0, 0.0)\n                return x + add_val\n\n        class Outer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.inner = Inner()\n\n            def call(self, x):\n                # We don't explicitly pass foo_mode here—Base Layer.__call__\n                # should inject it into `self.inner`\n                return self.inner(x)\n\n        
sample_input = np.array([[1.0], [2.0]])\n\n        # Sequential model\n        seq = models.Sequential([Outer()])\n\n        # Tell the Sequential model to propagate foo_mode down\n        # the call-stack\n        seq._register_call_context_args(\"foo_mode\")\n\n        # foo_mode=True -> input + 1\n        out_true = seq(sample_input, foo_mode=True)\n        ```\n        \"\"\"\n        if self._called:\n            raise RuntimeError(\n                \"Cannot add call-context args after the layer has been called.\"\n            )\n        self._call_context_args = self._call_context_args | set(names)\n\n        self._call_has_context_arg.update(\n            {arg: (arg in self.call_signature_parameters) for arg in names}\n        )\n\n\ndef is_backend_tensor_or_symbolic(x, allow_none=False):\n    if allow_none and x is None:\n        return True\n    return backend.is_tensor(x) or isinstance(x, backend.KerasTensor)\n\n\nclass CallSpec:\n    def __init__(self, signature, call_context_args, args, kwargs):\n        # Strip out user-supplied call-context args that this layer’s `call()`\n        # does not accept (otherwise `signature.bind` would raise).\n        # This includes built-in args like `training`, and user-defined args.\n        call_args = {\n            context_arg: kwargs.pop(context_arg)\n            for context_arg in call_context_args\n            if context_arg in kwargs and context_arg not in signature.parameters\n        }\n\n        bound_args = signature.bind(*args, **kwargs)\n\n        # Combine the two dicts.\n        self.user_arguments_dict = {**call_args, **bound_args.arguments}\n\n        bound_args.apply_defaults()\n        arg_dict = {}\n        arg_names = []\n        tensor_arg_dict = {}\n        tensor_args = []\n        tensor_arg_names = []\n        nested_tensor_arg_names = []\n        for name, value in bound_args.arguments.items():\n            arg_dict[name] = value\n            arg_names.append(name)\n            if 
is_backend_tensor_or_symbolic(value):\n                tensor_args.append(value)\n                tensor_arg_names.append(name)\n                tensor_arg_dict[name] = value\n            elif tree.is_nested(value) and len(value) > 0:\n                flat_values = tree.flatten(value)\n                if all(\n                    is_backend_tensor_or_symbolic(x, allow_none=True)\n                    for x in flat_values\n                ):\n                    tensor_args.append(value)\n                    tensor_arg_names.append(name)\n                    tensor_arg_dict[name] = value\n                    nested_tensor_arg_names.append(name)\n                elif any(is_backend_tensor_or_symbolic(x) for x in flat_values):\n                    raise ValueError(\n                        \"In a nested call() argument, \"\n                        \"you cannot mix tensors and non-tensors. \"\n                        \"Received invalid mixed argument: \"\n                        f\"{name}={value}\"\n                    )\n        self.arguments_dict = arg_dict\n        self.argument_names = arg_names\n        self.tensor_arguments_dict = tensor_arg_dict\n        self.tensor_arguments_names = tensor_arg_names\n        self.nested_tensor_argument_names = nested_tensor_arg_names\n        self.first_arg = arg_dict[arg_names[0]]\n        if all(\n            backend.is_tensor(x) for x in self.tensor_arguments_dict.values()\n        ):\n            self.eager = True\n        else:\n            self.eager = False\n\n\ndef get_arguments_dict(fn, args, kwargs):\n    \"\"\"Return a dict mapping argument names to their values.\"\"\"\n    sig = inspect.signature(fn)\n    bound_args = sig.bind(*args, **kwargs)\n    arg_dict = {}\n    for name, value in bound_args.arguments.items():\n        arg_dict[name] = value\n    return arg_dict\n\n\ndef get_shapes_dict(call_spec):\n    \"\"\"Convert the call() arguments dict into a dict of input shape arguments.\n\n    Example:\n\n    ```\n    
>>> get_shapes_dict(call_spec)\n    {\"input_a_shape\": (2, 3)}\n    ```\n    \"\"\"\n\n    def standardize_shape_or_none(x):\n        return None if x is None else backend.standardize_shape(x.shape)\n\n    shapes_dict = {}\n    for k, v in call_spec.tensor_arguments_dict.items():\n        if k == \"mask\":\n            # Do not include mask tensors in shapes dict\n            continue\n        if k == \"kwargs\" or k == \"args\":\n            # Do not include catch-alls in shapes dict\n            continue\n        if k in call_spec.nested_tensor_argument_names:\n            shapes_dict[f\"{k}_shape\"] = tree.map_structure(\n                standardize_shape_or_none, v\n            )\n        else:\n            shapes_dict[f\"{k}_shape\"] = standardize_shape_or_none(v)\n    return shapes_dict\n\n\ndef update_shapes_dict_for_target_fn(\n    target_fn,\n    shapes_dict,\n    call_spec,\n    class_name,\n):\n    \"\"\"Updates a `shapes_dict` for `build()` or `compute_output_shape()`.\n\n    This function will align a dictionary of the shapes of all tensors\n    passed to `call`, with the signatures of `build()` or\n    `compute_output_shape()`.\n\n    The alignment is as follows:\n\n    - If `build()` or `compute_output_shape()` accept only one argument,\n        forward the shape of the first positional argument from call without\n        checking any argument names.\n    - If `build()` or `compute_output_shape()` accept multiple arguments,\n        enforce that all argument names match a call argument name, e.g.\n        `foo_shape` would match call argument `foo`.\n\n    Returns:\n        An updated `shapes_dict` that can be used to invoke\n        `target_fn(**shapes_dict)`.\n    \"\"\"\n    if utils.is_default(target_fn):\n        return None\n    sig = inspect.signature(target_fn)\n    expected_names = []\n    for name, param in sig.parameters.items():\n        if param.kind in (\n            param.POSITIONAL_OR_KEYWORD,\n            param.POSITIONAL_ONLY,\n      
      param.KEYWORD_ONLY,\n        ):\n            expected_names.append(name)\n\n    # Single arg: don't check names, pass first shape.\n    if len(expected_names) == 1:\n        key = expected_names[0]\n        values = tuple(shapes_dict.values())\n        if values:\n            input_shape = values[0]\n        else:\n            input_shape = None\n        return {key: input_shape}\n\n    # Multiple args: check that all names line up.\n    kwargs = {}\n    for name in expected_names:\n        method_name = target_fn.__name__\n        error_preamble = (\n            f\"For a `{method_name}()` method with more than one argument, all \"\n            \"arguments should have a `_shape` suffix and match an argument \"\n            f\"from `call()`. E.g. `{method_name}(self, foo_shape, bar_shape)` \"\n        )\n        if not name.endswith(\"_shape\"):\n            raise ValueError(\n                f\"{error_preamble} For layer '{class_name}', \"\n                f\"Received `{method_name}()` argument \"\n                f\"`{name}`, which does not end in `_shape`.\"\n            )\n        expected_call_arg = utils.removesuffix(name, \"_shape\")\n        if expected_call_arg not in call_spec.arguments_dict:\n            raise ValueError(\n                f\"{error_preamble} For layer '{class_name}', \"\n                f\"received `{method_name}()` argument \"\n                f\"`{name}`, but `call()` does not have argument \"\n                f\"`{expected_call_arg}`.\"\n            )\n        if name in shapes_dict:\n            kwargs[name] = shapes_dict[name]\n\n    return kwargs\n\n\nclass CallContext:\n    def __init__(self, entry_layer):\n        self.entry_layer = entry_layer\n\n    def get_value(self, arg_name, default=None):\n        \"\"\"Get the context value for `arg_name`, or `default` if unset.\"\"\"\n        return getattr(self, arg_name, default)\n\n    def set_value(self, arg_name, value):\n        \"\"\"Set `arg_name` = `value` on this context 
object.\"\"\"\n        setattr(self, arg_name, value)\n\n\ndef is_shape_tuple(s):\n    return isinstance(s, (list, tuple)) and all(\n        d is None or isinstance(d, int) for d in s\n    )\n\n\ndef might_have_unbuilt_state(layer):\n    return any(not lr.built for lr in layer._layers)\n"
  },
  {
    "path": "keras/src/layers/layer_test.py",
    "content": "import pickle\nfrom unittest import mock\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import Input\nfrom keras.src import backend\nfrom keras.src import dtype_policies\nfrom keras.src import layers\nfrom keras.src import metrics\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.backend.common import global_state\nfrom keras.src.backend.common.remat import RematScope\nfrom keras.src.models import Model\n\n\nclass MockRemat:\n    \"\"\"Mock remat by returning a wrapper Mock calling the original function\"\"\"\n\n    def __init__(self):\n        self.rematted_functions = {}\n\n    def __call__(self, func):\n        if func in self.rematted_functions:\n            return self.rematted_functions[func]\n\n        wrapped_func = mock.Mock(wraps=func)\n        self.rematted_functions[func] = wrapped_func\n        return wrapped_func\n\n\nclass LayerTest(testing.TestCase):\n    def test_compute_output_spec(self):\n        # Test that implementing compute_output_shape\n        # is enough to make compute_output_spec work.\n\n        # Case: single output\n        class TestLayer(layers.Layer):\n            def call(self, x):\n                raise RuntimeError(\"Should never be called.\")\n\n            def compute_output_shape(self, input_shape):\n                return input_shape\n\n        layer = TestLayer()\n        self.assertEqual(\n            layer.compute_output_spec(backend.KerasTensor((2, 3))).shape, (2, 3)\n        )\n\n        # Case: tuple output\n        class TestLayer(layers.Layer):\n            def call(self, x):\n                raise RuntimeError(\"Should never be called.\")\n\n            def compute_output_shape(self, input_shape):\n                return (input_shape, input_shape)\n\n        layer = TestLayer()\n        out = layer.compute_output_spec(backend.KerasTensor((2, 3)))\n        self.assertIsInstance(out, tuple)\n        
self.assertEqual(len(out), 2)\n        self.assertEqual(out[0].shape, (2, 3))\n        self.assertEqual(out[1].shape, (2, 3))\n\n        # Case: list output\n        class TestLayer(layers.Layer):\n            def call(self, x):\n                raise RuntimeError(\"Should never be called.\")\n\n            def compute_output_shape(self, input_shape):\n                return [input_shape, input_shape]\n\n        layer = TestLayer()\n        out = layer.compute_output_spec(backend.KerasTensor((2, 3)))\n        self.assertIsInstance(out, list)\n        self.assertEqual(len(out), 2)\n        self.assertEqual(out[0].shape, (2, 3))\n        self.assertEqual(out[1].shape, (2, 3))\n\n        # Case: dict output\n        class TestLayer(layers.Layer):\n            def call(self, x):\n                raise RuntimeError(\"Should never be called.\")\n\n            def compute_output_shape(self, input_shape):\n                return {\"1\": input_shape, \"2\": input_shape}\n\n        layer = TestLayer()\n        out = layer.compute_output_spec(backend.KerasTensor((2, 3)))\n        self.assertIsInstance(out, dict)\n        self.assertEqual(len(out), 2)\n        self.assertEqual(out[\"1\"].shape, (2, 3))\n        self.assertEqual(out[\"2\"].shape, (2, 3))\n\n        # Case: nested tuple output\n        class TestLayer(layers.Layer):\n            def call(self, x):\n                raise RuntimeError(\"Should never be called.\")\n\n            def compute_output_shape(self, input_shape):\n                return (\n                    input_shape,\n                    (input_shape, input_shape),\n                    (input_shape, input_shape),\n                )\n\n        layer = TestLayer()\n        out = layer.compute_output_spec(backend.KerasTensor((2, 3)))\n        self.assertIsInstance(out, tuple)\n        self.assertEqual(len(out), 3)\n        self.assertEqual(out[0].shape, (2, 3))\n        self.assertIsInstance(out[1], tuple)\n        self.assertEqual(len(out[1]), 2)\n     
   self.assertEqual(out[1][0].shape, (2, 3))\n        self.assertEqual(out[1][1].shape, (2, 3))\n        self.assertIsInstance(out[2], tuple)\n        self.assertEqual(len(out[2]), 2)\n        self.assertEqual(out[2][0].shape, (2, 3))\n        self.assertEqual(out[2][1].shape, (2, 3))\n\n        # Case: nested dict output\n        class TestLayer(layers.Layer):\n            def call(self, x):\n                raise RuntimeError(\"Should never be called.\")\n\n            def compute_output_shape(self, input_shape):\n                return {\n                    \"1\": input_shape,\n                    \"2\": {\"11\": input_shape, \"22\": input_shape},\n                }\n\n        layer = TestLayer()\n        out = layer.compute_output_spec(backend.KerasTensor((2, 3)))\n        self.assertIsInstance(out, dict)\n        self.assertEqual(len(out), 2)\n        self.assertEqual(out[\"1\"].shape, (2, 3))\n        self.assertIsInstance(out[\"2\"], dict)\n        self.assertEqual(len(out[\"2\"]), 2)\n        self.assertEqual(out[\"2\"][\"11\"].shape, (2, 3))\n        self.assertEqual(out[\"2\"][\"22\"].shape, (2, 3))\n\n    def test_positional_arg_error(self):\n        class SomeLayer(layers.Layer):\n            def call(self, x, bool_arg):\n                if bool_arg:\n                    return x\n                return x + 1\n\n        x = backend.KerasTensor(shape=(2, 3), name=\"x\")\n        with self.assertRaisesRegex(\n            ValueError, \"Only input tensors may be passed as\"\n        ):\n            SomeLayer()(x, True)\n\n        # This works\n        SomeLayer()(x, bool_arg=True)\n\n    @parameterized.named_parameters(\n        (\"call\", \"call\", None),\n        (\"compute_output_shape\", \"compute_output_shape\", None),\n        (\n            \"quantized_build\",\n            \"quantized_build\",\n            {\"input_shape\": None, \"mode\": None},\n        ),\n        (\"quantize\", \"quantize\", {\"mode\": \"int8\"}),\n        (\"_int8_call\", 
\"_int8_call\", None),\n        (\"_float8_call\", \"_float8_call\", None),\n    )\n    def test_not_implemented_error(self, method, args):\n        layer = layers.Layer()\n        layer.built = True\n\n        with self.assertRaisesRegex(\n            NotImplementedError,\n            f\"does not have a `{method}` method implemented.\",\n        ):\n            if isinstance(args, dict):\n                getattr(layer, method)(**args)\n            else:\n                getattr(layer, method)(args)\n\n    def test_layer_with_remat(self):\n        \"\"\"Test rematerialization on a simple layer.\"\"\"\n        # Create a mock to track calls to remat\n        mock_remat = MockRemat()\n        with mock.patch(\n            \"keras.src.backend.common.remat.remat\", wraps=mock_remat\n        ):\n\n            class SomeLayer(layers.Layer):\n                def call(self, x):\n                    return x + 1\n\n            input_tensor = backend.random.uniform((2, 4))\n            layer = SomeLayer()\n            # Case 1: Without rematerialization\n            output_no_remat = layer(input_tensor)\n\n            # Case 2: With rematerialization\n            with RematScope(mode=\"full\"):\n                layer = SomeLayer()\n                output_with_remat = layer(input_tensor)\n\n            # Assert outputs are the same\n            self.assertAllClose(output_no_remat, output_with_remat)\n\n        # Ensure remat was applied in the second case\n        self.assertLen(mock_remat.rematted_functions, 1)\n        next(iter(mock_remat.rematted_functions.values())).assert_called()\n\n    def test_quantized_layer_with_remat(self):\n        \"\"\"Test rematerialization on a quantized layer.\"\"\"\n        mock_remat = MockRemat()\n        with mock.patch(\n            \"keras.src.backend.common.remat.remat\", wraps=mock_remat\n        ):\n            input_tensor = backend.random.uniform((2, 4))\n\n            # Case 2: With rematerialization\n            with 
RematScope(mode=\"full\"):\n                layer = layers.Dense(3)\n                layer.build((2, 4))\n                layer.quantize(\"float8\")\n                layer(input_tensor)\n\n        # Ensure remat was applied\n        self.assertLen(mock_remat.rematted_functions, 1)\n        next(iter(mock_remat.rematted_functions.values())).assert_called()\n\n    def test_gptq_quantization_by_setting_dtype(self):\n        \"\"\"Tests error being raised when dtype is set to GPTQ.\"\"\"\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Implicitly enabling GPTQ quantization.*is not supported\",\n        ):\n            layer = layers.Dense(3)\n            layer.build((2, 4))\n            layer.dtype_policy = \"gptq/4/-1_from_float32\"\n\n    @pytest.mark.skipif(\n        backend.backend() in (\"openvino\", \"numpy\"),\n        reason=\"remat not supported on OpenVino and Numpy\",\n    )\n    def test_functional_model_with_remat(self):\n        mock_remat = MockRemat()\n        with mock.patch(\n            \"keras.src.backend.common.remat.remat\", wraps=mock_remat\n        ):\n            # Define model inputs\n            inputs = Input(shape=(32, 32, 3))\n\n            # just one layer in remat scope\n            with RematScope(mode=\"activations\"):\n                layer = layers.Dense(64, activation=\"relu\")\n                output = layer(layers.Flatten()(inputs))\n\n            # Build the functional model\n            model = Model(inputs=inputs, outputs=output)\n\n            # Compile the model\n            model.compile(optimizer=\"adam\", loss=\"mse\")\n\n            # Generate dummy data for testing\n            x_train = np.random.random((10, 32, 32, 3)).astype(np.float32)\n            y_train = np.random.random((10, 64)).astype(np.float32)\n\n            # Run training to ensure `RematScope` is applied correctly\n            model.fit(x_train, y_train, epochs=1, batch_size=2, verbose=0)\n\n        
self.assertLen(mock_remat.rematted_functions, 1)\n        next(iter(mock_remat.rematted_functions.values())).assert_called()\n\n    def test_remat_wrapper_list_of_layers(self):\n        \"\"\"Test rematerialization using list_of_layers mode.\"\"\"\n        mock_remat = MockRemat()\n        with mock.patch(\n            \"keras.src.backend.common.remat.remat\", wraps=mock_remat\n        ):\n\n            class TestLayer(layers.Layer):\n                def call(self, x):\n                    return x + 1\n\n            class OtherLayer(layers.Layer):\n                def call(self, x):\n                    return x * 2\n\n            remat_layers = [\"test_layer\"]\n            input_tensor = backend.random.uniform((4, 4))\n\n            with RematScope(mode=\"list_of_layers\", layer_names=remat_layers):\n                test_layer = TestLayer(name=\"test_layer\")\n                other_layer = OtherLayer(name=\"other_layer\")\n                output_test = test_layer(input_tensor)\n                output_other = other_layer(input_tensor)\n\n            self.assertAllClose(output_test, input_tensor + 1)\n            self.assertAllClose(output_other, input_tensor * 2)\n\n        # Ensure remat was applied to the correct layer\n        self.assertLen(mock_remat.rematted_functions, 1)\n        next(iter(mock_remat.rematted_functions.values())).assert_called()\n\n    def test_remat_larger_than_mode(self):\n        \"\"\"Test rematerialization using larger_than mode.\"\"\"\n        mock_remat = MockRemat()\n        with mock.patch(\n            \"keras.src.backend.common.remat.remat\", wraps=mock_remat\n        ):\n\n            class TestLayer(layers.Layer):\n                def compute_output_shape(self, input_shape):\n                    return input_shape\n\n                def call(self, x):\n                    return x + 1\n\n            input_tensor = backend.random.uniform((100, 100))  # Large tensor\n\n            with RematScope(mode=\"larger_than\", 
output_size_threshold=5000):\n                layer = TestLayer()\n                output = layer(input_tensor)\n\n            self.assertAllClose(output, input_tensor + 1)\n\n        # Ensure remat was applied\n        self.assertLen(mock_remat.rematted_functions, 1)\n        next(iter(mock_remat.rematted_functions.values())).assert_called()\n\n    def test_remat_larger_than_mode_high_threshold(self):\n        \"\"\"Test rematerialization using larger_than mode.\"\"\"\n        mock_remat = MockRemat()\n        with mock.patch(\n            \"keras.src.backend.common.remat.remat\", wraps=mock_remat\n        ):\n\n            class TestLayer(layers.Layer):\n                def compute_output_shape(self, input_shape):\n                    return input_shape\n\n                def call(self, x):\n                    return x + 1\n\n            input_tensor = backend.random.uniform((100, 100))  # Large tensor\n\n            with RematScope(mode=\"larger_than\", output_size_threshold=50000):\n                layer = TestLayer()\n                output = layer(input_tensor)\n\n            self.assertAllClose(output, input_tensor + 1)\n\n        # Ensure remat was not applied\n        self.assertLen(mock_remat.rematted_functions, 0)\n\n    def test_rng_seed_tracking(self):\n        class RNGLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.seed_gen = backend.random.SeedGenerator(seed=1337)\n\n            def call(self, x):\n                return x * backend.random.normal(x.shape, seed=self.seed_gen)\n\n        layer = RNGLayer()\n        self.assertEqual(layer.variables, [layer.seed_gen.state])\n        self.assertAllClose(layer.variables[0], [1337, 0])\n        layer(np.ones((3, 4)))\n        self.assertAllClose(layer.variables[0], [1337, 1])\n\n        # Test tracking in list attributes.\n        class RNGListLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n          
      self.seed_gens = []\n                self.seed_gens.append(backend.random.SeedGenerator(seed=1))\n                self.seed_gens.append(backend.random.SeedGenerator(seed=10))\n\n            def call(self, x):\n                x = x * backend.random.normal(x.shape, seed=self.seed_gens[0])\n                x = x * backend.random.normal(x.shape, seed=self.seed_gens[1])\n                return x\n\n        layer = RNGListLayer()\n        self.assertEqual(\n            layer.variables,\n            [layer.seed_gens[0].state, layer.seed_gens[1].state],\n        )\n        self.assertAllClose(layer.variables[0], [1, 0])\n        self.assertAllClose(layer.variables[1], [10, 0])\n        layer(np.ones((3, 4)))\n        self.assertAllClose(layer.variables[0], [1, 1])\n        self.assertAllClose(layer.variables[1], [10, 1])\n\n    def test_layer_tracking(self):\n        class LayerWithDenseLayers(layers.Layer):\n            def __init__(self, units):\n                super().__init__()\n                self.dense1 = layers.Dense(units)\n                self.layer_dict = {\n                    \"dense2\": layers.Dense(units),\n                }\n                self.layer_list = [layers.Dense(units)]\n                self.units = units\n                self.seed_generator = backend.random.SeedGenerator(seed=1)\n\n            def build(self, input_shape):\n                self.layer_list.append(layers.Dense(self.units))\n\n            def call(self, x):\n                x = self.dense1(x)\n                x = self.layer_dict[\"dense2\"](x)\n                x = self.layer_list[0](x)\n                x = self.layer_list[1](x)\n                return x\n\n        class ParentLayer(layers.Layer):\n            def __init__(self, inner_layer):\n                super().__init__()\n                self.inner_layer = inner_layer\n\n            def call(self, x):\n                return self.inner_layer(x)\n\n        layer = LayerWithDenseLayers(3)\n        layer.build((1, 3))\n   
     self.assertLen(layer._layers, 4)\n        layer(np.zeros((1, 3)))\n        self.assertLen(layer.variables, 9)\n        self.assertLen(layer.weights, 8)\n\n        layer = ParentLayer(LayerWithDenseLayers(3))\n        self.assertLen(layer._layers, 1)\n        layer(np.zeros((1, 3)))\n        self.assertLen(layer.variables, 9)\n        self.assertLen(layer.weights, 8)\n\n        layer = ParentLayer(ParentLayer(LayerWithDenseLayers(3)))\n        self.assertLen(layer._layers, 1)\n        layer(np.zeros((1, 3)))\n        self.assertLen(layer.variables, 9)\n        self.assertLen(layer.weights, 8)\n\n    def test_metric_tracking(self):\n        class LayerWithMetric(layers.Layer):\n            def __init__(self, units):\n                super().__init__()\n                self.dense = layers.Dense(units)\n                self.metric = metrics.MeanSquaredError(name=\"my_metric\")\n\n            def build(self, input_shape):\n                self.dense.build(input_shape)\n\n            def call(self, x):\n                return self.dense(x)\n\n        class ParentLayerWithMetric(layers.Layer):\n            def __init__(self, inner_layer):\n                super().__init__()\n                self.inner_layer = inner_layer\n                self.metric = metrics.MeanSquaredError(name=\"my_metric\")\n\n            def build(self, input_shape):\n                self.inner_layer.build(input_shape)\n\n            def call(self, x):\n                return self.inner_layer(x)\n\n        layer = LayerWithMetric(3)\n        layer.build((1, 3))\n\n        self.assertLen(layer.metrics, 1)\n        self.assertLen(layer.metrics_variables, 2)\n        self.assertLen(layer.trainable_variables, 2)\n        self.assertLen(layer.non_trainable_variables, 0)\n\n        layer = ParentLayerWithMetric(LayerWithMetric(3))\n        layer.build((1, 3))\n\n        self.assertLen(layer.metrics, 2)\n        self.assertLen(layer.metrics_variables, 4)\n        
self.assertLen(layer.trainable_variables, 2)\n        self.assertLen(layer.non_trainable_variables, 0)\n\n        layer = ParentLayerWithMetric(ParentLayerWithMetric(LayerWithMetric(3)))\n        layer.build((1, 3))\n\n        self.assertLen(layer.metrics, 3)\n        self.assertLen(layer.metrics_variables, 6)\n        self.assertLen(layer.trainable_variables, 2)\n        self.assertLen(layer.non_trainable_variables, 0)\n\n    def test_build_on_call(self):\n        class LayerWithUnbuiltState(layers.Layer):\n            def __init__(self, units):\n                super().__init__()\n                self.dense1 = layers.Dense(units)\n\n            def call(self, x):\n                return self.dense1(x)\n\n        layer = LayerWithUnbuiltState(2)\n        layer(backend.KerasTensor((3, 4)))\n        self.assertLen(layer.weights, 2)\n\n        class KwargsLayerWithUnbuiltState(layers.Layer):\n            def __init__(self, units):\n                super().__init__()\n                self.dense1 = layers.Dense(units)\n                self.dense2 = layers.Dense(units)\n\n            def call(self, x1, x2):\n                return self.dense1(x1) + self.dense2(x2)\n\n        layer = KwargsLayerWithUnbuiltState(2)\n        layer(backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4)))\n        self.assertLen(layer.weights, 4)\n\n        layer = KwargsLayerWithUnbuiltState(2)\n        layer(x1=backend.KerasTensor((3, 4)), x2=backend.KerasTensor((3, 4)))\n        self.assertLen(layer.weights, 4)\n\n        class DictLayerWithUnbuiltState(layers.Layer):\n            def __init__(self, units):\n                super().__init__()\n                self.dense = layers.Dense(units)\n\n            def call(self, xs):\n                result = self.dense(xs[\"x1\"])\n                if xs.get(\"x2\", None) is not None:\n                    result += self.dense(xs[\"x2\"])\n                return result\n\n        layer = DictLayerWithUnbuiltState(2)\n        layer(\n            
{\n                \"x1\": backend.KerasTensor((3, 4)),\n                \"x2\": backend.KerasTensor((3, 4)),\n            }\n        )\n        self.assertLen(layer.weights, 2)\n\n        layer = DictLayerWithUnbuiltState(2)\n        layer({\"x1\": backend.KerasTensor((3, 4)), \"x2\": None})\n        self.assertLen(layer.weights, 2)\n\n        class ListLayerWithUnbuiltState(layers.Layer):\n            def __init__(self, units):\n                super().__init__()\n                self.dense = layers.Dense(units)\n\n            def call(self, xs):\n                result = self.dense(xs[0])\n                if xs[1] is not None:\n                    result += self.dense(xs[1])\n                return result\n\n        layer = ListLayerWithUnbuiltState(2)\n        layer([backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))])\n        self.assertLen(layer.weights, 2)\n\n        layer = ListLayerWithUnbuiltState(2)\n        layer([backend.KerasTensor((3, 4)), None])\n        self.assertLen(layer.weights, 2)\n\n    def test_activity_regularization(self):\n        class ActivityRegularizer(layers.Layer):\n            def call(self, x):\n                return x\n\n        layer = ActivityRegularizer(activity_regularizer=\"l1\")\n        layer(np.ones((1,)))\n        self.assertLen(layer.losses, 1)\n        self.assertAllClose(layer.losses[0], 0.01)\n\n        # losses are reset upon call\n        layer(np.ones((1,)))\n        self.assertLen(layer.losses, 1)\n        self.assertAllClose(layer.losses[0], 0.01)\n\n        # KerasTensors are no op\n        layer = ActivityRegularizer(activity_regularizer=\"l1\")\n        layer(layers.Input(batch_shape=(2, 2)))\n        self.assertLen(layer.losses, 0)\n\n    @parameterized.named_parameters(\n        (\"batch_size_0\", 0),\n        (\"batch_size_1\", 1),\n        (\"batch_size_5\", 5),\n        (\"batch_size_10\", 10),\n    )\n    def test_activity_regularization_batch_normalization(self, batch_size):\n        class 
SimpleLayer(layers.Layer):\n            def call(self, x):\n                return x\n\n        layer = SimpleLayer(activity_regularizer=\"l2\")\n        layer(ops.ones((batch_size, 5)) * 2.0)\n        self.assertLen(layer.losses, 1)\n        expected_loss = 0.0 if batch_size == 0 else 0.2\n        self.assertAllClose(layer.losses[0], expected_loss)\n\n    @pytest.mark.requires_trainable_backend\n    def test_add_loss(self):\n        class LossLayer(layers.Layer):\n            def call(self, x):\n                self.add_loss(ops.sum(x))\n                return x\n\n        layer = LossLayer()\n        layer(np.ones((1,)))\n        self.assertLen(layer.losses, 1)\n        self.assertAllClose(layer.losses[0], 1.0)\n\n        # losses are reset upon call\n        layer = LossLayer()\n        layer(np.ones((1,)))\n        self.assertLen(layer.losses, 1)\n        self.assertAllClose(layer.losses[0], 1.0)\n\n        # It works inside a model\n        model = models.Sequential([layer])\n        model(np.ones((1,)))\n        self.assertLen(model.losses, 1)\n        self.assertAllClose(model.losses[0], 1.0)\n\n        # It works recursively in nested models\n        model = models.Sequential([model])\n        model(np.ones((1,)))\n        self.assertLen(model.losses, 1)\n        self.assertAllClose(model.losses[0], 1.0)\n\n    def test_training_arg_value_resolution(self):\n        # Check that even if `training` is not passed\n        # to an inner layer, the outer value gets propagated\n        # in __call__.\n        class TrainingLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.dp = layers.Dropout(0.9)\n\n            def call(self, x, training=False):\n                return self.dp(x)\n\n        layer = TrainingLayer()\n        x = np.ones((4, 4))\n        y = layer(x)\n        self.assertEqual(ops.min(y), 1)\n        y = layer(x, training=True)\n        self.assertEqual(ops.min(y), 0)\n\n        # Check 
that it still works one level deeper.\n        class WrappedTrainingLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.dp = TrainingLayer()\n\n            def call(self, x, training=False):\n                return self.dp(x)\n\n        layer = WrappedTrainingLayer()\n        x = np.ones((4, 4))\n        y = layer(x)\n        self.assertEqual(ops.min(y), 1)\n        y = layer(x, training=True)\n        self.assertEqual(ops.min(y), 0)\n\n        # Check that if `training` is passed\n        # to an inner layer in call(), the explicitly\n        # passed value is what the layer sees.\n        class TrainingLayerExplicit(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.dp = layers.Dropout(0.9)\n\n            def call(self, x, training=False):\n                return self.dp(x, training=True)\n\n        layer = TrainingLayerExplicit()\n        x = np.ones((4, 4))\n        y = layer(x, training=False)\n        self.assertEqual(ops.min(y), 0)\n\n        # Test that layer interruption does not cause\n        # the call context to linger\n        class BadLayer(layers.Layer):\n            def call(self, x, training=False):\n                raise RuntimeError(\"oops!\")\n\n        x = np.ones((4, 4))\n        layer = BadLayer()\n        try:\n            # training=True will be recorded\n            # in the call context\n            layer(x, training=True)\n        except RuntimeError:\n            pass\n        layer = TrainingLayer()\n        # But this layer call should not see it\n        y = layer(x)\n        self.assertEqual(ops.min(y), 1)\n\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"Some torch ops not implemented for float16 on CPU.\",\n    )\n    def test_mixed_precision(self):\n        x = np.ones((4, 4))\n\n        layer = layers.Dense(2, dtype=\"float16\")\n        y = layer(x)\n        
self.assertEqual(layer.compute_dtype, \"float16\")\n        self.assertEqual(layer.variable_dtype, \"float16\")\n        self.assertDType(y, \"float16\")\n\n        layer = layers.Dense(2, dtype=\"mixed_float16\")\n        y = layer(x)\n        self.assertEqual(layer.compute_dtype, \"float16\")\n        self.assertEqual(layer.variable_dtype, \"float32\")\n        self.assertDType(y, \"float16\")\n        self.assertEqual(layer.kernel.dtype, \"float32\")\n\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"Some torch ops not implemented for float16 on CPU.\",\n    )\n    def test_autocast(self):\n        assertDType = self.assertDType\n\n        # A layer with a int dtype (some preprocessing layers do this).\n        class InnerLayerOne(layers.Layer):\n            def __init__(self):\n                super().__init__(dtype=\"int\")\n                self.v = self.add_weight(\n                    shape=(),\n                    initializer=\"ones\",\n                    trainable=True,\n                    dtype=\"float32\",\n                )\n                self._build_at_init()\n\n            def call(self, x):\n                # Should not autocast.\n                assertDType(self.v, \"float32\")\n                return ops.add(ops.cast(x, \"float32\"), self.v)\n\n        # A layer that is explicitly full precision.\n        class InnerLayerTwo(layers.Layer):\n            def __init__(self):\n                super().__init__(dtype=\"float32\")\n                self.v = self.add_weight(\n                    shape=(),\n                    initializer=\"ones\",\n                    trainable=True,\n                )\n                self._build_at_init()\n\n            def call(self, x):\n                # Should not autocast.\n                assertDType(self.v, \"float32\")\n                return ops.add(x, self.v)\n\n        # A layer that is explicitly mixed precision but with autocast=False\n        # weight.\n        class 
InnerLayerThree(layers.Layer):\n            def __init__(self):\n                super().__init__(dtype=\"mixed_float16\")\n                self.v = self.add_weight(\n                    shape=(),\n                    initializer=\"ones\",\n                    trainable=True,\n                    autocast=False,\n                )\n                self._build_at_init()\n\n            def call(self, x):\n                # Should not autocast `self.v`.\n                assertDType(self.v, \"float32\")\n                return ops.add(x, self.v)\n\n        # A layer that is explicitly mixed precision with inner layers.\n        class MixedPrecisionLayer(layers.Layer):\n            def __init__(self):\n                super().__init__(dtype=\"mixed_float16\")\n                self.v = self.add_weight(\n                    shape=(),\n                    initializer=\"ones\",\n                    trainable=True,\n                )\n                self.inner_one = InnerLayerOne()\n                self.inner_two = InnerLayerTwo()\n                self.inner_three = InnerLayerThree()\n                self._build_at_init()\n\n            def call(self, x):\n                # Should autocast.\n                assertDType(self.v, \"float16\")\n                return self.inner_three(\n                    self.inner_two(self.inner_one(ops.add(x, self.v)))\n                )\n\n        layer = MixedPrecisionLayer()\n        y = layer(np.array(0.0))\n        self.assertEqual(y, 4.0)\n\n    def test_autocast_with_np_array(self):\n        assertDType = self.assertDType\n\n        class CustomLayer(layers.Layer):\n            def __init__(self, **kwargs):\n                super().__init__(**kwargs)\n\n            def call(self, x):\n                # Here are the assertions.\n                assertDType(x[0], \"float32\")  # Cast to compute_dtype\n                assertDType(x[1], \"int32\")  # Untouched\n\n        x = [np.zeros(1, dtype=\"float64\"), np.zeros(1, dtype=\"int32\")]\n 
       CustomLayer()(x)\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\", reason=\"masking not supported with numpy\"\n    )\n    def test_keras_mask_with_autocast(self):\n        test_obj = self\n\n        class CustomLayer(layers.Layer):\n            def __init__(self, **kwargs):\n                super().__init__(**kwargs)\n                self.supports_masking = True\n\n            def call(self, x, mask=None):\n                test_obj.assertIsNotNone(mask)\n                test_obj.assertDType(x, \"float16\")\n                return x\n\n        x = ops.zeros((1, 2), dtype=\"float32\")\n        mask = ops.array([True, False])\n        backend.set_keras_mask(x, mask)\n        y = CustomLayer(dtype=\"float16\")(x)\n        self.assertAllEqual(\n            mask,\n            backend.get_keras_mask(y),\n            \"Masking is not propagated by Autocast\",\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\", reason=\"masking not supported with numpy\"\n    )\n    def test_end_to_end_masking(self):\n        # Check that masking survives compilation\n        model = models.Sequential(\n            [\n                layers.Embedding(\n                    2, 2, mask_zero=True, embeddings_initializer=\"ones\"\n                ),\n            ]\n        )\n        model.compile(loss=\"mse\")\n        targets = np.array([[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [1.0, 1.0]]])\n        loss = model.evaluate(np.array([[1, 0, 0, 1]]), targets, verbose=0)\n        self.assertAllClose(loss, 0.0)\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\", reason=\"masking not supported with numpy\"\n    )\n    def test_masking(self):\n        test_obj = self\n\n        class BasicMaskedLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.supports_masking = True\n\n            def call(self, x, mask=None):\n                test_obj.assertIsNotNone(mask)\n              
  return x\n\n        layer = BasicMaskedLayer()\n        x = backend.numpy.ones((4, 4))\n        mask = backend.numpy.ones((4,))\n        backend.set_keras_mask(x, mask)\n        layer(x)\n\n        layer(backend.numpy.ones((4, 4)), mask=backend.numpy.ones((4,)))\n\n        class NestedInputMaskedLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.supports_masking = True\n\n            def call(self, x, mask=None):\n                test_obj.assertIsInstance(x, list)\n                test_obj.assertLen(x, 2)\n                test_obj.assertIsInstance(mask, list)\n                test_obj.assertLen(mask, 2)\n                return x\n\n        layer = NestedInputMaskedLayer()\n        x1 = backend.numpy.ones((4, 4))\n        mask1 = backend.numpy.ones((4,))\n        backend.set_keras_mask(x1, mask1)\n        x2 = backend.numpy.ones((4, 4))\n        mask2 = backend.numpy.ones((4,))\n        backend.set_keras_mask(x2, mask2)\n        layer([x1, x2])\n\n        layer(\n            [backend.numpy.ones((4, 4)), backend.numpy.ones((4, 4))],\n            mask=[backend.numpy.ones((4,)), backend.numpy.ones((4,))],\n        )\n\n        class PositionalInputsMaskedLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.supports_masking = True\n\n            def call(self, x1, x2, x1_mask=None, x2_mask=None):\n                test_obj.assertIsNotNone(x1_mask)\n                test_obj.assertIsNotNone(x2_mask)\n                return x1 + x2\n\n        layer = PositionalInputsMaskedLayer()\n        layer(x1, x2)\n        layer(x1=x1, x2=x2)\n\n        class PositionalNestedInputsMaskedLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.supports_masking = True\n\n            def call(self, x1, x2, x1_mask=None, x2_mask=None):\n                test_obj.assertIsInstance(x1, tuple)\n                
test_obj.assertIsNotNone(x1_mask)\n                test_obj.assertIsNotNone(x2_mask)\n                test_obj.assertIsInstance(x1_mask, tuple)\n                return x1[0] + x1[1] + x2\n\n        layer = PositionalNestedInputsMaskedLayer()\n        x1_1 = backend.numpy.ones((4, 4))\n        mask1 = backend.numpy.ones((4,))\n        backend.set_keras_mask(x1_1, mask1)\n        x1_2 = backend.numpy.ones((4, 4))\n        mask2 = backend.numpy.ones((4,))\n        backend.set_keras_mask(x1_2, mask2)\n        x2 = backend.numpy.ones((4, 4))\n        mask2 = backend.numpy.ones((4,))\n        backend.set_keras_mask(x2, mask2)\n        layer((x1_1, x1_2), x2)\n        layer(x1=(x1_1, x1_2), x2=x2)\n\n        class MaskUnsetDuringCallLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.supports_masking = True\n\n            def call(self, x, mask=None):\n                test_obj.assertIsNotNone(mask)\n                backend.set_keras_mask(x, None)  # Unset mask\n                return x\n\n        layer = MaskUnsetDuringCallLayer()\n        x = backend.numpy.ones((4, 4))\n        mask = backend.numpy.ones((4,))\n        backend.set_keras_mask(x, mask)\n        y = layer(x)\n        self.assertAllClose(y._keras_mask, mask)\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\", reason=\"masking not supported with numpy\"\n    )\n    def test_masking_with_explicit_kwarg_propagation(self):\n        \"\"\"This test validates that an explicit `mask` kwarg is correctly\n        used to compute the output mask.\n        \"\"\"\n\n        class PassthroughMaskLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.supports_masking = True\n\n            def call(self, x, mask=None):\n                # The layer itself can use the mask.\n                self.used_mask = mask is not None\n                return x\n\n        layer = PassthroughMaskLayer()\n  
      # Create an input tensor WITHOUT an attached mask.\n        x = backend.numpy.ones((4, 4))\n        self.assertIsNone(getattr(x, \"_keras_mask\", None))\n\n        # Create a mask to be passed explicitly.\n        explicit_mask = backend.numpy.array([True, True, False, False])\n\n        # Call the layer, passing the mask as a keyword argument.\n        y = layer(x, mask=explicit_mask)\n\n        # Assert that the layer's internal call received the mask.\n        self.assertTrue(layer.used_mask)\n\n        # Assert that the output tensor 'y' now has the explicit mask attached\n        # for propagation to the next layer.\n        self.assertAllClose(backend.get_keras_mask(y), explicit_mask)\n\n    def test_stateless_call(self):\n        class TestLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self._seed_generator = backend.random.SeedGenerator(1337)\n                self.ntw = self.add_weight(\n                    shape=(),\n                    initializer=\"zeros\",\n                    trainable=False,\n                )\n                self.tw = self.add_weight(\n                    shape=(),\n                    initializer=\"zeros\",\n                    trainable=True,\n                    regularizer=\"l1\",\n                )\n                self._build_at_init()\n\n            def call(self, x):\n                x = backend.convert_to_tensor(x, dtype=\"float32\")\n                self.add_loss(ops.sum(x))\n                self.ntw.assign(ops.sum(x))\n                x = x + backend.random.normal(\n                    shape=(), seed=self._seed_generator\n                )\n                return ops.add(x, ops.add(self.tw, self.ntw))\n\n        data = np.random.random((3, 4))\n        layer = TestLayer()\n        out = layer(data)\n        layer1 = TestLayer()\n        out1 = layer1(data)\n        # Check that the layer is in fact deterministic\n        self.assertAllClose(out, out1)\n\n    
    # Test stateless_call correctness\n        layer2 = TestLayer()\n        trainable_variables = layer2.trainable_variables\n        non_trainable_variables = layer2.non_trainable_variables\n        out2, non_trainable_variables = layer2.stateless_call(\n            trainable_variables, non_trainable_variables, data\n        )\n        self.assertAllClose(out1, out2)\n        self.assertEqual(\n            len(layer1.non_trainable_variables), len(non_trainable_variables)\n        )\n        for ref_v, v in zip(\n            layer1.non_trainable_variables, non_trainable_variables\n        ):\n            self.assertAllClose(ref_v, v)\n\n        # Test with loss collection\n        layer3 = TestLayer()\n        trainable_variables = layer3.trainable_variables\n        non_trainable_variables = layer3.non_trainable_variables\n        out3, non_trainable_variables, losses = layer3.stateless_call(\n            trainable_variables,\n            non_trainable_variables,\n            data,\n            return_losses=True,\n        )\n        self.assertAllClose(out1, out3)\n        for ref_v, v in zip(\n            layer1.non_trainable_variables, non_trainable_variables\n        ):\n            self.assertAllClose(ref_v, v)\n        self.assertLen(losses, 2)\n        for ref_loss, loss in zip(layer1.losses, losses):\n            self.assertAllClose(ref_loss, loss)\n\n    def test_trainable_setting(self):\n        class NonTrainableWeightsLayer(layers.Layer):\n            def build(self, _):\n                self.w1 = self.add_weight(\n                    shape=(),\n                    initializer=\"ones\",\n                    trainable=True,\n                )\n                self.w2 = self.add_weight(\n                    shape=(),\n                    initializer=\"ones\",\n                    trainable=False,\n                )\n                self.seed = backend.random.SeedGenerator(123)\n\n            def call(self, inputs):\n                return inputs\n\n     
   class NestedNonTrainableWeightsLayer(layers.Layer):\n            def build(self, _):\n                self.w1 = self.add_weight(\n                    shape=(),\n                    initializer=\"ones\",\n                    trainable=True,\n                )\n                self.w2 = self.add_weight(\n                    shape=(),\n                    initializer=\"ones\",\n                    trainable=False,\n                )\n                self.nested = NonTrainableWeightsLayer()\n                self.nested.build(None)\n\n            def call(self, inputs):\n                return inputs\n\n        layer = NestedNonTrainableWeightsLayer()\n        layer.build(None)\n        self.assertEqual(len(layer.trainable_weights), 2)\n        self.assertEqual(len(layer.trainable_variables), 2)\n        self.assertEqual(len(layer.non_trainable_weights), 2)\n        self.assertEqual(len(layer.non_trainable_variables), 3)\n\n        layer.trainable = False\n        self.assertEqual(len(layer.trainable_weights), 0)\n        self.assertEqual(len(layer.trainable_variables), 0)\n        self.assertEqual(len(layer.non_trainable_weights), 4)\n        self.assertEqual(len(layer.non_trainable_variables), 5)\n        self.assertFalse(layer.w1.trainable)\n        self.assertFalse(layer.nested.w1.trainable)\n\n        layer.trainable = True\n        self.assertEqual(len(layer.trainable_weights), 2)\n        self.assertEqual(len(layer.trainable_variables), 2)\n        self.assertEqual(len(layer.non_trainable_weights), 2)\n        self.assertEqual(len(layer.non_trainable_variables), 3)\n        self.assertTrue(layer.w1.trainable)\n        self.assertTrue(layer.nested.w1.trainable)\n\n        layer = NestedNonTrainableWeightsLayer(trainable=False)\n        layer.build(None)\n        self.assertEqual(len(layer.trainable_weights), 0)\n        self.assertEqual(len(layer.trainable_variables), 0)\n        self.assertEqual(len(layer.non_trainable_weights), 4)\n        
self.assertEqual(len(layer.non_trainable_variables), 5)\n\n        layer.trainable = True\n        self.assertEqual(len(layer.trainable_weights), 2)\n        self.assertEqual(len(layer.trainable_variables), 2)\n        self.assertEqual(len(layer.non_trainable_weights), 2)\n        self.assertEqual(len(layer.non_trainable_variables), 3)\n\n    def test_build_signature_errors(self):\n        class NoShapeSuffix(layers.Layer):\n            def build(self, foo_shape, bar):\n                self.built = True\n\n            def call(self, foo, bar):\n                return foo + bar\n\n        class NonMatchingArgument(layers.Layer):\n            def build(self, foo_shape, baz_shape):\n                self.built = True\n\n            def call(self, foo, bar):\n                return foo[:, 0] + bar[:, 0]\n\n        class MatchingArguments(layers.Layer):\n            def build(self, bar_shape, foo_shape):\n                self.foo_shape = foo_shape\n                self.bar_shape = bar_shape\n\n            def call(self, foo, bar):\n                return foo[:, 0] + bar[:, 0]\n\n        class SubsetArguments(layers.Layer):\n            def build(self, baz_shape, foo_shape):\n                self.foo_shape = foo_shape\n                self.baz_shape = baz_shape\n\n            def call(self, foo, bar=None, baz=None):\n                return foo[:, 0] + bar[:, 0] + baz[:, 0]\n\n        class SingleArgument(layers.Layer):\n            def build(self, anything_whatsoever):\n                self.foo_shape = anything_whatsoever\n\n            def call(self, foo, bar):\n                return foo[:, 0] + bar[:, 0]\n\n        foo = backend.numpy.ones((4, 1))\n        bar = backend.numpy.ones((4, 2))\n        baz = backend.numpy.ones((4, 3))\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"argument `bar`, which does not end in `_shape`\",\n        ):\n            layer = NoShapeSuffix()\n            layer(foo, bar)\n\n        with 
self.assertRaisesRegex(\n            ValueError,\n            r\"`baz_shape`, but `call\\(\\)` does not have argument `baz`\",\n        ):\n            layer = NonMatchingArgument()\n            layer(foo, bar)\n\n        # Align by name when build and call arguments match.\n        layer = MatchingArguments()\n        layer(foo, bar)\n        self.assertEqual(layer.foo_shape, foo.shape)\n        self.assertEqual(layer.bar_shape, bar.shape)\n\n        # Align by name when build supports a subset of call arguments.\n        layer = SubsetArguments()\n        layer(foo, bar, baz)\n        self.assertEqual(layer.foo_shape, foo.shape)\n        self.assertEqual(layer.baz_shape, baz.shape)\n\n        # When build has only one argument, match the first call argument.\n        layer = SingleArgument()\n        layer(foo, bar)\n        self.assertEqual(layer.foo_shape, foo.shape)\n\n    def test_training_arg_not_specified(self):\n        class NoTrainingSpecified(layers.Layer):\n            def __init__(self):\n                super().__init__()\n\n            def build(self, input_shape):\n                self.activation = layers.Activation(\"linear\")\n\n            def call(self, inputs):\n                return self.activation(inputs)\n\n        layer = NoTrainingSpecified()\n        inputs = ops.random.uniform(shape=(1, 100, 100, 3))\n        layer(inputs, training=True)\n\n    def test_tracker_locking(self):\n        class BadLayer(layers.Layer):\n            def call(self, x):\n                self.w = self.add_weight(initializer=\"zeros\", shape=())\n                return x\n\n        layer = BadLayer()\n        with self.assertRaisesRegex(\n            ValueError,\n            \"cannot add new elements of state\",\n        ):\n            layer(np.random.random((3, 2)))\n\n    def test_init_after_state_tracking(self):\n        class MyLayer(layers.Layer):\n            def __init__(self):\n                self.some_attr = True\n                self.w = 
backend.Variable(np.random.random((2,)))\n                super().__init__()\n\n        layer = MyLayer()\n        self.assertEqual(len(layer.weights), 1)\n\n    def test_add_weight_defaults(self):\n        class MyLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.w1 = self.add_weight()\n                self.w2 = self.add_weight(dtype=\"int32\", trainable=False)\n                self.w3 = self.add_weight(dtype=\"bool\", trainable=False)\n                self.w4 = self.add_weight(\n                    dtype=\"int32\", shape=(2, 2), trainable=False\n                )\n                self.w5 = self.add_weight(initializer=\"ones\", shape=(2, 2))\n\n        layer = MyLayer()\n        self.assertEqual(layer.w1.shape, ())\n        self.assertEqual(layer.w1.dtype, \"float32\")\n\n        self.assertEqual(layer.w2.shape, ())\n        self.assertEqual(layer.w2.dtype, \"int32\")\n        self.assertAllClose(backend.convert_to_numpy(layer.w2), 0)\n\n        self.assertEqual(layer.w3.shape, ())\n        self.assertEqual(layer.w3.dtype, \"bool\")\n        self.assertAllClose(backend.convert_to_numpy(layer.w3), False)\n\n        self.assertEqual(layer.w4.shape, (2, 2))\n        self.assertEqual(layer.w4.dtype, \"int32\")\n        self.assertAllClose(\n            backend.convert_to_numpy(layer.w4), np.zeros((2, 2))\n        )\n\n        self.assertEqual(layer.w5.shape, (2, 2))\n        self.assertEqual(layer.w5.dtype, \"float32\")\n        self.assertAllClose(backend.convert_to_numpy(layer.w5), np.ones((2, 2)))\n\n    def test_add_weight_string_as_first_positional_arg(self):\n        \"\"\"Test that passing a string as first positional arg to add_weight\n        raises a clear error guiding users to use name= keyword.\"\"\"\n\n        # Case 1: String as only positional arg (e.g. 
add_weight(\"matrix\"))\n        class MyLayer1(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.w = self.add_weight(\"my_weight\")\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"name.*keyword argument\",\n        ):\n            MyLayer1()\n\n        # Case 2: String positional + shape kwarg — the exact bug from\n        # https://github.com/keras-team/keras/issues/22265\n        # In Keras 2 this was valid: add_weight(\"matrix\", shape=(3, 4))\n        class MyLayer2(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.w = self.add_weight(\n                    \"matrix\", shape=(3, 4), initializer=\"zeros\"\n                )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"name.*keyword argument\",\n        ):\n            MyLayer2()\n\n        # Case 3: shape passed both positionally and as keyword\n        class MyLayer3(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.w = self.add_weight((3, 4), shape=(3, 4))\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`shape` was passed both positionally and as a keyword argument\",\n        ):\n            MyLayer3()\n\n        # Case 4: positional shape / initializer / dtype must remain valid.\n        class MyLayer4(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.w = self.add_weight((3, 4), \"zeros\", \"float32\")\n\n        layer = MyLayer4()\n        self.assertEqual(layer.w.shape, (3, 4))\n        self.assertEqual(layer.w.dtype, \"float32\")\n        self.assertAllClose(backend.convert_to_numpy(layer.w), np.zeros((3, 4)))\n\n        # Case 5: too many positional arguments\n        class MyLayer5(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                
self.w = self.add_weight((3, 4), \"zeros\", \"float32\", \"name\")\n\n        with self.assertRaisesRegex(\n            TypeError,\n            \"takes at most 3 positional arguments\",\n        ):\n            MyLayer5()\n\n    def test_remove_weight(self):\n        class MyLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.w = self.add_weight()\n\n            def custom_remove_w(self):\n                self.w = self._untrack_variable(self.w)\n\n            def custom_change_dtype(self):\n                self.w = self._untrack_variable(self.w)\n                self.w = self.add_weight(\n                    initializer=\"zeros\", dtype=\"int8\", trainable=False\n                )\n\n        layer = MyLayer()\n        self.assertEqual(len(layer.weights), 1)\n        layer.custom_remove_w()\n        self.assertEqual(len(layer.weights), 0)\n        self.assertEqual(layer.w, None)\n\n        layer = MyLayer()\n        self.assertEqual(layer.w.dtype, \"float32\")\n        self.assertEqual(layer.w.trainable, True)\n        layer.custom_change_dtype()\n        self.assertEqual(layer.w.dtype, \"int8\")\n        self.assertEqual(layer.w.trainable, False)\n\n    def test_trainable_init_arg(self):\n        inputs = layers.Input(shape=(1,))\n        layer = layers.Dense(2, trainable=False)\n        outputs = layer(inputs)\n        model = models.Model(inputs, outputs)\n\n        self.assertFalse(layer.trainable)\n        self.assertLen(layer._trainable_variables, 2)\n        self.assertLen(layer._non_trainable_variables, 0)\n        self.assertLen(layer.trainable_weights, 0)\n        self.assertLen(model.trainable_weights, 0)\n        self.assertLen(model.non_trainable_weights, 2)\n\n        layer.trainable = True\n        self.assertTrue(layer.trainable)\n        self.assertLen(layer._trainable_variables, 2)\n        self.assertLen(layer._non_trainable_variables, 0)\n        self.assertLen(layer.trainable_weights, 
2)\n        self.assertLen(model.trainable_weights, 2)\n        self.assertLen(model.non_trainable_weights, 0)\n\n    def test_dtype_policy_setter(self):\n        layer = layers.Dense(2)\n        # Set by string\n        layer.dtype_policy = \"mixed_bfloat16\"\n        self.assertEqual(layer.dtype_policy.name, \"mixed_bfloat16\")\n        self.assertEqual(layer.dtype_policy.compute_dtype, \"bfloat16\")\n        self.assertEqual(layer.dtype_policy.variable_dtype, \"float32\")\n        # Set by DTypePolicy\n        layer.dtype_policy = dtype_policies.DTypePolicy(\"mixed_float16\")\n        self.assertEqual(layer.dtype_policy.name, \"mixed_float16\")\n        self.assertEqual(layer.dtype_policy.compute_dtype, \"float16\")\n        self.assertEqual(layer.dtype_policy.variable_dtype, \"float32\")\n        # Set with DTypePolicyMap\n        dtype_policy_map = dtype_policies.DTypePolicyMap()\n        layer = layers.Dense(2, dtype=dtype_policy_map)\n        layer.build([None, 1])\n        layer.dtype_policy = \"mixed_bfloat16\"\n        self.assertIsInstance(\n            layer._dtype_policy, dtype_policies.DTypePolicyMap\n        )\n        self.assertEqual(\n            layer._dtype_policy[layer.path],\n            dtype_policies.DTypePolicy(\"mixed_bfloat16\"),\n        )\n\n    def test_pickle_layer(self):\n        layer = layers.Dense(2)\n        reloaded = pickle.loads(pickle.dumps(layer))\n        self.assertEqual(layer.get_config(), reloaded.get_config())\n\n    def test_serialize_dtype(self):\n        assertIsNone = self.assertIsNone\n        assertIsNotNone = self.assertIsNotNone\n\n        class AssertionDense(layers.Dense):\n            def __init__(self, *args, **kwargs):\n                dtype = kwargs[\"dtype\"]\n                if isinstance(dtype, str):\n                    # `dtype` is a plain string, it should be the `name` from a\n                    # `DTypePolicy`\n                    dtype = dtype_policies.get(dtype)\n                    
assertIsNone(dtype.quantization_mode)\n                else:\n                    # `dtype` is a DTypePolicy instance, it should be an\n                    # instance of `QuantizedDTypePolicy`\n                    assertIsNotNone(dtype.quantization_mode)\n                super().__init__(*args, **kwargs)\n\n        # Test floating dtype serialization\n        layer = layers.Dense(2, dtype=\"bfloat16\")\n        config = layer.get_config()\n        self.assertIn(\"dtype\", config)\n        self.assertEqual(\n            config[\"dtype\"],\n            dtype_policies.serialize(dtype_policies.DTypePolicy(\"bfloat16\")),\n        )\n        AssertionDense.from_config(config)  # Assertion inside\n\n        # Test quantized dtype serialization\n        layer = layers.Dense(2, dtype=\"int8_from_bfloat16\")\n        config = layer.get_config()\n        self.assertIn(\"dtype\", config)\n        self.assertEqual(\n            config[\"dtype\"],\n            dtype_policies.serialize(dtype_policies.get(\"int8_from_bfloat16\")),\n        )\n        AssertionDense.from_config(config)  # Assertion inside\n\n    def test_serialize_activity_regularizer(self):\n        layer = layers.Dense(2, activity_regularizer=\"l2\")\n        config = layer.get_config()\n        self.assertIn(\"activity_regularizer\", config)\n        new_layer = layers.Dense.from_config(config)\n        self.assertEqual(\n            new_layer.activity_regularizer.__class__.__name__, \"L2\"\n        )\n\n        layer = layers.Dense(2)\n        config = layer.get_config()\n        self.assertNotIn(\"activity_regularizer\", config)\n\n    def test_custom_layer_add_weight_in_init_name(self):\n        class TrainingLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.inner = InnerLayer()\n\n        class InnerLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.var = self.add_weight(\n              
      shape=(1,),\n                    name=\"inner\",\n                )\n                self.inner = InnerInnerLayer()\n\n        class InnerInnerLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.var = self.add_weight(\n                    shape=(1,),\n                    name=\"inner\",\n                )\n\n        layer = TrainingLayer()\n        layer.build(None)\n        self.assertEqual(len(layer.variables), 2)\n        variable_paths = set(v.path for v in layer.variables)\n        self.assertTrue(\"inner_layer/inner\" in variable_paths)\n        self.assertTrue(\"inner_inner_layer/inner\" in variable_paths)\n        if backend.backend() == \"torch\":\n            parameter_names = set(\n                param_name.replace(\"_torch_params.\", \"\")\n                for param_name, _ in layer.named_parameters()\n            )\n            self.assertSetEqual(variable_paths, parameter_names)\n\n    def test_custom_layer_add_weight_in_build_name(self):\n        class TrainingLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.inner = InnerLayer()\n\n            def call(self, input):\n                return self.inner(input)\n\n        class InnerLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.inner = InnerInnerLayer()\n\n            def build(self, _):\n                self.var = self.add_weight(\n                    shape=(1,),\n                    name=\"inner\",\n                )\n\n            def call(self, input):\n                return self.var + self.inner(input)\n\n        class InnerInnerLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n\n            def build(self, _):\n                self.var = self.add_weight(\n                    shape=(1,),\n                    name=\"inner\",\n                )\n\n            def 
call(self, input):\n                return self.var + input\n\n        layer = TrainingLayer()\n        output = layer(\n            backend.KerasTensor(\n                (4, 1),\n            )\n        )\n        self.assertEqual(output.shape, (4, 1))\n        self.assertEqual(len(layer.variables), 2)\n        variable_paths = set(v.path for v in layer.variables)\n        self.assertTrue(\"training_layer/inner_layer/inner\" in variable_paths)\n        self.assertTrue(\n            \"training_layer/inner_layer/inner_inner_layer/inner\"\n            in variable_paths\n        )\n        if backend.backend() == \"torch\":\n            parameter_names = set(\n                param_name.replace(\"_torch_params.\", \"\")\n                for param_name, _ in layer.named_parameters()\n            )\n            self.assertSetEqual(variable_paths, parameter_names)\n\n    def test_layer_variable_tracking_correct(self):\n        class TrainingLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.post_build_modify_layer = PostBuildModifyLayer()\n\n            def call(self, input):\n                return self.post_build_modify_layer(input)\n\n        class PostBuildModifyLayer(layers.Layer):\n            def call(self, input):\n                return self.var + input\n\n            def build(self, _):\n                self.var = self.add_weight(\n                    shape=(2,),\n                    name=\"var\",\n                )\n\n            def post_build_add(self):\n                self._tracker.unlock()\n                self.additional_var = self.add_weight(\n                    shape=(2,),\n                    name=\"var2\",\n                )\n                self._tracker.lock()\n\n            def post_build_remove(self):\n                self._tracker.unlock()\n                self._untrack_variable(self.var)\n                del self.var\n                self._tracker.lock()\n\n        layer = 
TrainingLayer()\n        output = layer(backend.KerasTensor((4, 2)))\n\n        self.assertEqual(output.shape, (4, 2))\n        self.assertEqual(len(layer.variables), 1)\n        self.assertEqual(\n            layer.variables[0].path,\n            \"training_layer/post_build_modify_layer/var\",\n        )\n        if backend.backend() == \"torch\":\n            parameter_names = [pname for pname, _ in layer.named_parameters()]\n            self.assertEqual(len(parameter_names), 1)\n            self.assertEqual(\n                parameter_names[0],\n                \"_torch_params.training_layer/post_build_modify_layer/var\",\n            )\n\n        layer.post_build_modify_layer.post_build_add()\n        self.assertEqual(len(layer.variables), 2)\n        self.assertEqual(\n            layer.variables[0].path,\n            \"training_layer/post_build_modify_layer/var\",\n        )\n        self.assertEqual(\n            layer.variables[1].path,\n            \"training_layer/post_build_modify_layer/var2\",\n        )\n        if backend.backend() == \"torch\":\n            # TODO (haohuanw, fchollet): Needs further discussion on how to\n            # properly manage torch params. 
Post build modification cannot\n            # propagate to parent torch params.\n            parameter_names = [pname for pname, _ in layer.named_parameters()]\n            # Below check should have 2 parameters instead of 1.\n            self.assertEqual(len(parameter_names), 1)\n            self.assertEqual(\n                parameter_names[0],\n                \"_torch_params.training_layer/post_build_modify_layer/var\",\n            )\n\n            parameter_names = [\n                pname\n                for pname, _ in layer.post_build_modify_layer.named_parameters()\n            ]\n            self.assertEqual(len(parameter_names), 2)\n            self.assertEqual(\n                parameter_names[0],\n                \"_torch_params.training_layer/post_build_modify_layer/var\",\n            )\n            self.assertEqual(\n                parameter_names[1],\n                \"_torch_params.training_layer/post_build_modify_layer/var2\",\n            )\n\n        layer.post_build_modify_layer.post_build_remove()\n        self.assertEqual(len(layer.variables), 1)\n        self.assertEqual(\n            layer.variables[0].path,\n            \"training_layer/post_build_modify_layer/var2\",\n        )\n        if backend.backend() == \"torch\":\n            # TODO (haohuanw, fchollet): Needs further discussion on how to\n            # properly manage torch params. 
Post build modification cannot\n            # propagate to parent torch params.\n            parameter_names = [pname for pname, _ in layer.named_parameters()]\n            # Below check should have 1 parameter instead of 2; torch_params\n            # in the parent layer is wrong.\n            self.assertEqual(len(parameter_names), 2)\n            self.assertEqual(\n                parameter_names[0],\n                \"post_build_modify_layer._torch_params.training_layer/\"\n                \"post_build_modify_layer/var2\",\n            )\n            self.assertEqual(\n                parameter_names[1],\n                \"_torch_params.training_layer/post_build_modify_layer/var\",\n            )\n\n            parameter_names = [\n                pname\n                for pname, _ in layer.post_build_modify_layer.named_parameters()\n            ]\n            self.assertEqual(len(parameter_names), 1)\n            self.assertEqual(\n                parameter_names[0],\n                \"_torch_params.training_layer/post_build_modify_layer/var2\",\n            )\n\n    @pytest.mark.skipif(backend.backend() != \"torch\", reason=\"Torch only test.\")\n    def test_torch_params_create_deterministic(self):\n        class MyLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.w1 = self.add_weight()\n                self.w2 = self.add_weight(dtype=\"int32\", trainable=False)\n                self.w3 = self.add_weight(dtype=\"bool\", trainable=False)\n                self.w4 = self.add_weight(\n                    dtype=\"int32\", shape=(2, 2), trainable=False\n                )\n                self.w5 = self.add_weight(initializer=\"ones\", shape=(2, 2))\n\n        layer1 = MyLayer()\n        layer1.build(None)\n        layer1_names = list(pname for pname, _ in layer1.named_parameters())\n        global_state.clear_session()\n        layer2 = MyLayer()\n        layer2.build(None)\n        layer2_names = 
list(pname for pname, _ in layer2.named_parameters())\n        self.assertListEqual(layer1_names, layer2_names)\n\n    def test_complex_dtype_support(self):\n        class MyDenseLayer(layers.Layer):\n            def __init__(self, num_outputs):\n                super().__init__()\n                self.num_outputs = num_outputs\n\n            def build(self, input_shape):\n                self.kernel = self.add_weight(\n                    shape=[int(input_shape[-1]), self.num_outputs],\n                )\n\n            def call(self, inputs):\n                kernel = ops.cast(self.kernel, \"complex64\")\n                return ops.matmul(inputs, kernel)\n\n        inputs = ops.zeros([10, 5], dtype=\"complex64\")\n        layer = MyDenseLayer(10)\n        output = layer(inputs)\n        self.assertAllEqual(output.shape, (10, 10))\n\n    def test_call_context_args_with_custom_layers(self):\n        class Inner(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self._register_call_context_args(\"foo_mode\")\n\n            def call(self, x, foo_mode=None):\n                return x + (1 if foo_mode else 0)\n\n        class Outer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self._register_call_context_args(\"foo_mode\")\n                self.inner = Inner()\n\n            def call(self, x):\n                # Outer doesn't even need to re-inject explicitly:\n                # our base class will propagate foo_mode automatically.\n                return self.inner(x)\n\n        layer = Outer()\n        self.assertEqual(int(layer(np.array(0), foo_mode=True)), 1)\n        self.assertEqual(int(layer(np.array(0))), 0)\n\n    def test_register_call_context_arguments(self):\n        \"\"\"Validate that registering call-context args works as expected.\"\"\"\n\n        class MyLayer(layers.Layer):\n            def call(self, x):\n                return 
x\n\n        layer = MyLayer()\n\n        layer._register_call_context_args(\"foo_mode\")\n\n        self.assertCountEqual(\n            layer._call_context_args, (\"foo_mode\", \"training\")\n        )\n\n    def test_register_call_context_arguments_after_call(self):\n        \"\"\"Validate that registering call-context args after the layer has\n        been called raises an error.\"\"\"\n\n        class MyLayer(layers.Layer):\n            def call(self, x):\n                return x\n\n        layer = MyLayer()\n        layer(np.array(0))\n        with self.assertRaisesRegex(\n            RuntimeError,\n            \"Cannot add call-context args after the layer has been called.\",\n        ):\n            layer._register_call_context_args(\"foo_mode\")\n\n    def test_context_args_with_triple_nesting_and_priority(self):\n        \"\"\"Validate that call-context args are propagated correctly\n        through multiple layers, and that the most specific value is used\n        when multiple values are passed down the call-stack.\n        \"\"\"\n\n        class Inner(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self._register_call_context_args(\"foo_mode\")\n\n            def call(self, x, foo_mode=None):\n                return x + (1 if foo_mode else 0)\n\n        class Middle(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.inner = Inner()\n\n            def call(self, x):\n                return self.inner(x)\n\n        class Outer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.middle = Middle()\n\n            def call(self, x):\n                # Outer does not pass foo_mode explicitly; the value passed\n                # in the call to Outer propagates through Middle to Inner.\n                return self.middle(x)\n\n        layer = Outer()\n        
layer._register_call_context_args(\"foo_mode\")\n\n        # The value of foo_mode is set to True in the call to Outer,\n        # so it should automatically propagate to Inner through Middle.\n        self.assertEqual(int(layer(np.array(0), foo_mode=True)), 1)\n        self.assertEqual(int(layer(np.array(0))), 0)\n\n    def test_context_arg_propagation_without_declaration(self):\n        \"\"\"Validate that layer does not resolve a propagated arg if it is not\n        declared as a call-context arg in the layer itself.\"\"\"\n\n        class Inner(layers.Layer):\n            def call(self, x, foo_mode=None):\n                return x + (1 if foo_mode else 0)\n\n        class Wrapper(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.inner = Inner()\n\n            def call(self, x):\n                return self.inner(x)\n\n        layer = Wrapper()\n        layer._register_call_context_args(\"foo_mode\")\n\n        # The value of foo_mode is set to True in the call to Wrapper,\n        # However, it is not declared as a call-context arg in Inner,\n        # so it should not resolve to True inside Inner (and instead\n        # default to False).\n        self.assertEqual(int(layer(np.array(0), foo_mode=True)), 0)\n\n    def test_call_context_args_with_func_seq_models_as_layers(self):\n        \"\"\"Validate that call-context args are propagated correctly\n        through functional and sequential models when used as layers.\n        \"\"\"\n\n        class Inner(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self._register_call_context_args(\"foo_mode\")\n\n            def call(self, x, foo_mode=False):\n                # If foo_mode=True add 1, otherwise add 0\n                add_val = ops.where(foo_mode, 1.0, 0.0)\n                return x + add_val\n\n        class Outer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n 
               self.inner = Inner()\n\n            def call(self, x):\n                # We don't explicitly pass foo_mode here; the base\n                # Layer.__call__ should inject it into `self.inner`.\n                return self.inner(x)\n\n        sample_input = np.array([[1.0], [2.0]])\n\n        # Sequential model\n        seq = models.Sequential([layers.Identity(), Outer()])\n        # Tell the Sequential model to propagate foo_mode down\n        # the call-stack\n        seq._register_call_context_args(\"foo_mode\")\n\n        # foo_mode=True -> input + 1\n        out_true = seq(sample_input, foo_mode=True)\n        self.assertAllClose(out_true, sample_input + 1.0)\n\n        # foo_mode omitted -> foo_mode defaults to False -> no change\n        out_false = seq(sample_input)\n        self.assertAllClose(out_false, sample_input)\n\n        # Functional model\n        inp = Input(shape=(1,))\n        out = layers.Identity()(inp)\n        out = Outer()(out)\n        model = models.Model(inp, out)\n        # Tell the Functional model to propagate foo_mode down\n        # the call-stack\n        model._register_call_context_args(\"foo_mode\")\n\n        # foo_mode=True -> input + 1\n        y1 = model(sample_input, foo_mode=True)\n        self.assertAllClose(y1, sample_input + 1.0)\n\n        # foo_mode omitted -> foo_mode defaults to False -> no change\n        y2 = model(sample_input)\n        self.assertAllClose(y2, sample_input)\n\n    def test_layer_build_with_attention_mask_arg(self):\n        test = self\n\n        class CustomLayer(layers.Layer):\n            def call(self, inputs, attention_mask=None):\n                if attention_mask is not None:\n                    return inputs * ops.cast(attention_mask, inputs.dtype)\n                return inputs\n\n            def build(self, inputs_shape, attention_mask_shape=None):\n                test.assertIsNotNone(attention_mask_shape)\n                self.built = True\n\n        layer = CustomLayer()\n       
 x = np.ones((2, 3), dtype=\"float32\")\n        mask = np.ones((2, 1), dtype=\"float32\")\n        y = layer(x, attention_mask=mask)\n        self.assertEqual(y.shape, (2, 3))\n"
  },
  {
    "path": "keras/src/layers/merging/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/layers/merging/add.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.merging.base_merge import Merge\n\n\n@keras_export(\"keras.layers.Add\")\nclass Add(Merge):\n    \"\"\"Performs elementwise addition operation.\n\n    It takes as input a list of tensors, all of the same shape,\n    and returns a single tensor (also of the same shape).\n\n    Examples:\n\n    >>> input_shape = (2, 3, 4)\n    >>> x1 = np.random.rand(*input_shape)\n    >>> x2 = np.random.rand(*input_shape)\n    >>> y = keras.layers.Add()([x1, x2])\n\n    Usage in a Keras model:\n\n    >>> input1 = keras.layers.Input(shape=(16,))\n    >>> x1 = keras.layers.Dense(8, activation='relu')(input1)\n    >>> input2 = keras.layers.Input(shape=(32,))\n    >>> x2 = keras.layers.Dense(8, activation='relu')(input2)\n    >>> # equivalent to `added = keras.layers.add([x1, x2])`\n    >>> added = keras.layers.Add()([x1, x2])\n    >>> out = keras.layers.Dense(4)(added)\n    >>> model = keras.models.Model(inputs=[input1, input2], outputs=out)\n\n    \"\"\"\n\n    def _merge_function(self, inputs):\n        output = inputs[0]\n        for i in range(1, len(inputs)):\n            output = ops.add(output, inputs[i])\n        return output\n\n\n@keras_export(\"keras.layers.add\")\ndef add(inputs, **kwargs):\n    \"\"\"Functional interface to the `keras.layers.Add` layer.\n\n    Args:\n        inputs: A list of input tensors with the same shape.\n        **kwargs: Standard layer keyword arguments.\n\n    Returns:\n        A tensor as the sum of the inputs. 
It has the same shape as the inputs.\n\n    Examples:\n\n    >>> input_shape = (2, 3, 4)\n    >>> x1 = np.random.rand(*input_shape)\n    >>> x2 = np.random.rand(*input_shape)\n    >>> y = keras.layers.add([x1, x2])\n\n    Usage in a Keras model:\n\n    >>> input1 = keras.layers.Input(shape=(16,))\n    >>> x1 = keras.layers.Dense(8, activation='relu')(input1)\n    >>> input2 = keras.layers.Input(shape=(32,))\n    >>> x2 = keras.layers.Dense(8, activation='relu')(input2)\n    >>> added = keras.layers.add([x1, x2])\n    >>> out = keras.layers.Dense(4)(added)\n    >>> model = keras.models.Model(inputs=[input1, input2], outputs=out)\n\n    \"\"\"\n    return Add(**kwargs)(inputs)\n"
  },
  {
    "path": "keras/src/layers/merging/average.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.merging.base_merge import Merge\n\n\n@keras_export(\"keras.layers.Average\")\nclass Average(Merge):\n    \"\"\"Averages a list of inputs element-wise.\n\n    It takes as input a list of tensors, all of the same shape,\n    and returns a single tensor (also of the same shape).\n\n    Examples:\n\n    >>> input_shape = (2, 3, 4)\n    >>> x1 = np.random.rand(*input_shape)\n    >>> x2 = np.random.rand(*input_shape)\n    >>> y = keras.layers.Average()([x1, x2])\n\n    Usage in a Keras model:\n\n    >>> input1 = keras.layers.Input(shape=(16,))\n    >>> x1 = keras.layers.Dense(8, activation='relu')(input1)\n    >>> input2 = keras.layers.Input(shape=(32,))\n    >>> x2 = keras.layers.Dense(8, activation='relu')(input2)\n    >>> # equivalent to `y = keras.layers.average([x1, x2])`\n    >>> y = keras.layers.Average()([x1, x2])\n    >>> out = keras.layers.Dense(4)(y)\n    >>> model = keras.models.Model(inputs=[input1, input2], outputs=out)\n\n    \"\"\"\n\n    def _merge_function(self, inputs):\n        output = inputs[0]\n        for i in range(1, len(inputs)):\n            output = ops.add(output, inputs[i])\n        return output / len(inputs)\n\n\n@keras_export(\"keras.layers.average\")\ndef average(inputs, **kwargs):\n    \"\"\"Functional interface to the `keras.layers.Average` layer.\n\n    Args:\n        inputs: A list of input tensors, all of the same shape.\n        **kwargs: Standard layer keyword arguments.\n\n    Returns:\n        A tensor as the element-wise average of the inputs with the same\n        shape as the inputs.\n\n    Examples:\n\n    >>> input_shape = (2, 3, 4)\n    >>> x1 = np.random.rand(*input_shape)\n    >>> x2 = np.random.rand(*input_shape)\n    >>> y = keras.layers.average([x1, x2])\n\n    Usage in a Keras model:\n\n    >>> input1 = keras.layers.Input(shape=(16,))\n    >>> x1 = keras.layers.Dense(8, activation='relu')(input1)\n    
>>> input2 = keras.layers.Input(shape=(32,))\n    >>> x2 = keras.layers.Dense(8, activation='relu')(input2)\n    >>> y = keras.layers.average([x1, x2])\n    >>> out = keras.layers.Dense(4)(y)\n    >>> model = keras.models.Model(inputs=[input1, input2], outputs=out)\n\n    \"\"\"\n    return Average(**kwargs)(inputs)\n"
  },
  {
    "path": "keras/src/layers/merging/base_merge.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.layers.layer import Layer\n\n\nclass Merge(Layer):\n    \"\"\"Generic merge layer for elementwise merge functions.\n\n    Used to implement `Sum`, `Average`, etc.\n\n    Args:\n        **kwargs: standard layer keyword arguments.\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.supports_masking = True\n\n    def _merge_function(self, inputs):\n        raise NotImplementedError\n\n    def _apply_merge_op_and_or_mask(self, op_fn, inputs):\n        \"\"\"Merge a set of inputs by applying `op_fn` and ORing the masks.\n\n        We use this for `Minimum` and `Maximum` as it handles the fact that\n        there is no identity element. If applicable, the mask obtained by ORing\n        all masks is set on the output.\n\n        Args:\n            op_fn: binary operation to apply to tensor pair.\n            inputs: array of tensors to apply operation on.\n        \"\"\"\n        output = None\n        output_mask = None\n\n        for x in inputs:\n            mask = backend.get_keras_mask(x)\n            if mask is not None:\n                mask = ops.broadcast_to(ops.expand_dims(mask, -1), ops.shape(x))\n            if output is None:\n                output = x\n                output_mask = mask\n                continue\n            if mask is not None:\n                x = ops.where(mask, x, output)\n            if output_mask is not None:\n                output = ops.where(output_mask, output, x)\n            if mask is not None and output_mask is not None:\n                output_mask = ops.logical_or(output_mask, mask)\n            else:\n                output_mask = None\n            output = op_fn(output, x)\n\n        if output_mask is not None:\n            output_mask = ops.any(output_mask, axis=-1, keepdims=False)\n            backend.set_keras_mask(output, 
output_mask)\n        return output\n\n    def _compute_elemwise_op_output_shape(self, shape1, shape2):\n        \"\"\"Computes the shape of the resultant of an elementwise operation.\n\n        Args:\n            shape1: Tuple or None. Shape of the first tensor\n            shape2: Tuple or None. Shape of the second tensor\n\n        Returns:\n            Expected output shape when an element-wise operation is\n            carried out on 2 tensors with shapes shape1 and shape2.\n            tuple or None.\n\n        Raises:\n            ValueError: If shape1 and shape2 are not compatible for\n                element-wise operations.\n        \"\"\"\n\n        if None in [shape1, shape2]:\n            return None\n        elif len(shape1) < len(shape2):\n            return self._compute_elemwise_op_output_shape(shape2, shape1)\n        elif not shape2:\n            return shape1\n        output_shape = list(shape1[: -len(shape2)])\n        for i, j in zip(shape1[-len(shape2) :], shape2):\n            if i is None or j is None:\n                output_shape.append(None)\n            elif i == 1:\n                output_shape.append(j)\n            elif j == 1:\n                output_shape.append(i)\n            else:\n                if i != j:\n                    raise ValueError(\n                        \"Inputs have incompatible shapes. \"\n                        f\"Received shapes {shape1} and {shape2}\"\n                    )\n                output_shape.append(i)\n        return tuple(output_shape)\n\n    def build(self, input_shape):\n        # Used purely for shape validation.\n        if not isinstance(input_shape[0], (tuple, list)):\n            raise ValueError(\n                \"A merge layer should be called on a list of inputs. 
\"\n                f\"Received: input_shape={input_shape} (not a list of shapes)\"\n            )\n        if len(input_shape) < 1:\n            raise ValueError(\n                \"A merge layer should be called \"\n                \"on a list of at least 1 input. \"\n                f\"Received {len(input_shape)} inputs. \"\n                f\"Full input_shape received: {input_shape}\"\n            )\n\n        batch_sizes = {s[0] for s in input_shape if s} - {None}\n        if len(batch_sizes) > 1:\n            raise ValueError(\n                \"Cannot merge tensors with different batch sizes. \"\n                f\"Received tensors with shapes {input_shape}\"\n            )\n\n        if input_shape[0] is None:\n            output_shape = None\n        else:\n            output_shape = input_shape[0][1:]\n\n        for i in range(1, len(input_shape)):\n            if input_shape[i] is None:\n                shape = None\n            else:\n                shape = input_shape[i][1:]\n            output_shape = self._compute_elemwise_op_output_shape(\n                output_shape, shape\n            )\n\n        # If the inputs have different ranks, we have to reshape them\n        # to make them broadcastable.\n        if None not in input_shape and len(set(map(len, input_shape))) == 1:\n            self._reshape_required = False\n        else:\n            self._reshape_required = True\n\n    def call(self, inputs):\n        if not isinstance(inputs, (list, tuple)):\n            raise ValueError(\n                \"A merge layer should be called on a list of inputs. 
\"\n                f\"Received: inputs={inputs} (not a list of tensors)\"\n            )\n        if self._reshape_required:\n            reshaped_inputs = []\n            input_ndims = list(map(ops.ndim, inputs))\n            if None not in input_ndims:\n                # If ranks of all inputs are available,\n                # we simply expand each of them at axis=1\n                # until all of them have the same rank.\n                max_ndim = max(input_ndims)\n                for x in inputs:\n                    x_ndim = ops.ndim(x)\n                    for _ in range(max_ndim - x_ndim):\n                        x = ops.expand_dims(x, axis=1)\n                    reshaped_inputs.append(x)\n                return self._merge_function(reshaped_inputs)\n            else:\n                # Transpose all inputs so that batch size is the last dimension.\n                # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... ,\n                # batch_size)\n                transposed = False\n                for x in inputs:\n                    x_ndim = ops.ndim(x)\n\n                    if x_ndim is None:\n                        x_shape = ops.shape(x)\n                        batch_size = x_shape[0]\n\n                        new_shape = backend.concatenate(\n                            [x_shape[1:], ops.expand_dims(batch_size, axis=-1)]\n                        )\n                        x_transposed = ops.reshape(\n                            x,\n                            ops.stack(\n                                [batch_size, ops.prod(x_shape[1:])],\n                                axis=0,\n                            ),\n                        )\n                        x_transposed = ops.transpose(x_transposed, perm=(1, 0))\n                        x_transposed = ops.reshape(x_transposed, new_shape)\n\n                        reshaped_inputs.append(x_transposed)\n                        transposed = True\n\n                    elif x_ndim > 1:\n     
                   dims = list(range(1, x_ndim)) + [0]\n                        reshaped_inputs.append(ops.transpose(x, perm=dims))\n                        transposed = True\n                    else:\n                        # We don't transpose inputs if they are 1D vectors or\n                        # scalars.\n                        reshaped_inputs.append(x)\n\n                y = self._merge_function(reshaped_inputs)\n                y_ndim = ops.ndim(y)\n\n                if transposed:\n                    # If inputs have been transposed, we have to transpose the\n                    # output too.\n                    if y_ndim is None:\n                        y_shape = ops.shape(y)\n                        y_ndim = ops.shape(y_shape)[0]\n                        batch_size = y_shape[y_ndim - 1]\n                        new_shape = ops.concatenate(\n                            [\n                                ops.expand_dims(batch_size, axis=-1),\n                                y_shape[: y_ndim - 1],\n                            ]\n                        )\n                        y = ops.reshape(y, (-1, batch_size))\n                        y = ops.transpose(y, perm=(1, 0))\n                        y = ops.reshape(y, new_shape)\n                    elif y_ndim > 1:\n                        dims = [y_ndim - 1] + list(range(y_ndim - 1))\n                        y = ops.transpose(y, perm=dims)\n                return y\n        else:\n            return self._merge_function(inputs)\n\n    def compute_output_shape(self, input_shape):\n        if input_shape[0] is None:\n            output_shape = None\n        else:\n            output_shape = input_shape[0][1:]\n\n        for i in range(1, len(input_shape)):\n            if input_shape[i] is None:\n                shape = None\n            else:\n                shape = input_shape[i][1:]\n            output_shape = self._compute_elemwise_op_output_shape(\n          
      output_shape, shape\n            )\n        batch_sizes = {s[0] for s in input_shape if s is not None} - {None}\n        if len(batch_sizes) == 1:\n            output_shape = (list(batch_sizes)[0],) + output_shape\n        else:\n            output_shape = (None,) + output_shape\n        return output_shape\n\n    def compute_output_spec(self, inputs):\n        output_shape = self.compute_output_shape([x.shape for x in inputs])\n        output_sparse = all(x.sparse for x in inputs)\n        return KerasTensor(\n            output_shape, dtype=self.compute_dtype, sparse=output_sparse\n        )\n\n    def compute_mask(self, inputs, mask=None):\n        if mask is None:\n            return None\n        if not isinstance(mask, (tuple, list)):\n            raise ValueError(f\"`mask` should be a list. Received: mask={mask}\")\n        if not isinstance(inputs, (tuple, list)):\n            raise ValueError(\n                f\"`inputs` should be a list. Received: inputs={inputs}\"\n            )\n        if len(mask) != len(inputs):\n            raise ValueError(\n                \"The lists `inputs` and `mask` should have the same length. \"\n                f\"Received: inputs={inputs} of length {len(inputs)}, and \"\n                f\"mask={mask} of length {len(mask)}\"\n            )\n        # Default implementation does an OR between the masks, which works\n        # for `Add`, `Subtract`, `Average`, `Maximum`, `Minimum`, `Multiply`.\n        if any(m is None for m in mask):\n            return None\n        output_mask = mask[0]\n        for m in mask[1:]:\n            output_mask = ops.logical_or(output_mask, m)\n        return output_mask\n\n    def get_config(self):\n        return super().get_config()\n"
  },
  {
    "path": "keras/src/layers/merging/concatenate.py",
    "content": "import copy\n\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.merging.base_merge import Merge\n\n\n@keras_export(\"keras.layers.Concatenate\")\nclass Concatenate(Merge):\n    \"\"\"Concatenates a list of inputs.\n\n    It takes as input a list of tensors, all of the same shape except\n    for the concatenation axis, and returns a single tensor that is the\n    concatenation of all inputs.\n\n    Examples:\n\n    >>> x = np.arange(20).reshape(2, 2, 5)\n    >>> y = np.arange(20, 30).reshape(2, 1, 5)\n    >>> keras.layers.Concatenate(axis=1)([x, y])\n\n    Usage in a Keras model:\n\n    >>> x1 = keras.layers.Dense(8)(np.arange(10).reshape(5, 2))\n    >>> x2 = keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))\n    >>> y = keras.layers.Concatenate()([x1, x2])\n\n    Args:\n        axis: Axis along which to concatenate.\n        **kwargs: Standard layer keyword arguments.\n\n    Returns:\n        A tensor, the concatenation of the inputs alongside axis `axis`.\n    \"\"\"\n\n    def __init__(self, axis=-1, **kwargs):\n        super().__init__(**kwargs)\n        self.axis = axis\n        self.supports_masking = True\n        self._reshape_required = False\n\n    def build(self, input_shape):\n        # Used purely for shape validation.\n        if len(input_shape) < 1 or not isinstance(\n            input_shape[0], (tuple, list)\n        ):\n            raise ValueError(\n                \"A `Concatenate` layer should be called on a list of \"\n                f\"at least 1 input. 
Received: input_shape={input_shape}\"\n            )\n        if all(shape is None for shape in input_shape):\n            return\n\n        reduced_inputs_shapes = [list(shape) for shape in input_shape]\n        reduced_inputs_shapes_copy = copy.copy(reduced_inputs_shapes)\n        shape_set = set()\n        for i in range(len(reduced_inputs_shapes_copy)):\n            # Convert self.axis to positive axis for each input\n            # in case self.axis is a negative number\n            concat_axis = self.axis % len(reduced_inputs_shapes_copy[i])\n            # Skip batch axis.\n            for axis, axis_value in enumerate(\n                reduced_inputs_shapes_copy[i][1:], start=1\n            ):\n                # Remove squeezable axes (axes with value of 1)\n                # if not in the axis that will be used for concatenation\n                # otherwise leave it.\n                # This approach allows building the layer,\n                # but if tensor shapes are not the same when\n                # calling, an exception will be raised.\n                if axis != concat_axis and axis_value == 1:\n                    del reduced_inputs_shapes[i][axis]\n\n            if len(reduced_inputs_shapes[i]) > self.axis:\n                del reduced_inputs_shapes[i][self.axis]\n            shape_set.add(tuple(reduced_inputs_shapes[i]))\n\n        if len(shape_set) != 1:\n            err_msg = (\n                \"A `Concatenate` layer requires inputs with matching shapes \"\n                \"except for the concatenation axis. 
\"\n                f\"Received: input_shape={input_shape}\"\n            )\n            # Make sure all the shapes have same ranks.\n            ranks = set(len(shape) for shape in shape_set)\n            if len(ranks) != 1:\n                raise ValueError(err_msg)\n            # Get the only rank for the set.\n            (rank,) = ranks\n            for axis in range(rank):\n                # Skip the Nones in the shape since they are dynamic, also the\n                # axis for concat has been removed above.\n                unique_dims = set(\n                    shape[axis]\n                    for shape in shape_set\n                    if shape[axis] is not None\n                )\n                if len(unique_dims) > 1:\n                    raise ValueError(err_msg)\n\n    def _merge_function(self, inputs):\n        return ops.concatenate(inputs, axis=self.axis)\n\n    def compute_output_shape(self, input_shape):\n        if (not isinstance(input_shape, (tuple, list))) or (\n            not isinstance(input_shape[0], (tuple, list))\n        ):\n            raise ValueError(\n                \"A `Concatenate` layer should be called on a list of inputs. \"\n                f\"Received: input_shape={input_shape}\"\n            )\n        input_shapes = input_shape\n        output_shape = list(input_shapes[0])\n\n        for shape in input_shapes[1:]:\n            if output_shape[self.axis] is None or shape[self.axis] is None:\n                output_shape[self.axis] = None\n                break\n            output_shape[self.axis] += shape[self.axis]\n        return tuple(output_shape)\n\n    def compute_mask(self, inputs, mask=None):\n        if mask is None:\n            return None\n        if not isinstance(mask, (tuple, list)):\n            raise ValueError(f\"`mask` should be a list. Received mask={mask}\")\n        if not isinstance(inputs, (tuple, list)):\n            raise ValueError(\n                f\"`inputs` should be a list. 
Received: inputs={inputs}\"\n            )\n        if len(mask) != len(inputs):\n            raise ValueError(\n                \"The lists `inputs` and `mask` should have the same length. \"\n                f\"Received: inputs={inputs} of length {len(inputs)}, and \"\n                f\"mask={mask} of length {len(mask)}\"\n            )\n        if all(m is None for m in mask):\n            return None\n        # Make a list of masks while making sure\n        # the dimensionality of each mask\n        # is the same as the corresponding input.\n        masks = []\n        for input_i, mask_i in zip(inputs, mask):\n            if mask_i is None:\n                # Input is unmasked. Append all 1s to masks.\n                masks.append(ops.ones_like(input_i, dtype=\"bool\"))\n            elif mask_i.ndim < input_i.ndim:\n                # Broadcast mask shape to match in a way where we capture the\n                # input as a symbolic input in the op graph.\n                mask_i = ops.logical_or(\n                    ops.expand_dims(mask_i, axis=-1),\n                    ops.zeros_like(input_i, dtype=\"bool\"),\n                )\n                masks.append(mask_i)\n            else:\n                masks.append(mask_i)\n        concatenated = ops.concatenate(masks, axis=self.axis)\n        return ops.any(concatenated, axis=-1, keepdims=False)\n\n    def get_config(self):\n        config = {\"axis\": self.axis}\n        base_config = super().get_config()\n        return dict(list(base_config.items()) + list(config.items()))\n\n\n@keras_export(\"keras.layers.concatenate\")\ndef concatenate(inputs, axis=-1, **kwargs):\n    \"\"\"Functional interface to the `Concatenate` layer.\n\n    Args:\n        inputs: A list of input tensors.\n        axis: Concatenation axis.\n        **kwargs: Standard layer keyword arguments.\n\n    Returns:\n        A tensor, the concatenation of the inputs alongside axis `axis`.\n    \"\"\"\n    return Concatenate(axis=axis, 
**kwargs)(inputs)\n"
  },
  {
    "path": "keras/src/layers/merging/dot.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.merging.base_merge import Merge\nfrom keras.src.utils.numerical_utils import normalize\n\n\ndef batch_dot(x, y, axes=None):\n    \"\"\"Batchwise dot product.\n\n    `batch_dot` is used to compute the dot product of `x` and `y` when\n    `x` and `y` are data in batch, i.e. in a shape of `(batch_size, :)`.\n    `batch_dot` results in a tensor or variable with fewer dimensions\n    than the input. If the number of dimensions is reduced to 1,\n    we use `expand_dims` to make sure that ndim is at least 2.\n\n    Shape inference:\n\n    Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.\n    If `axes` is (1, 2), to find the output shape of resultant tensor,\n    loop through each dimension in `x`'s shape and `y`'s shape:\n\n    * `x.shape[0]` : 100 : append to output shape\n    * `x.shape[1]` : 20 : do not append to output shape, dimension 1 of\n        `x` has been summed over. (`dot_axes[0]` = 1)\n    * `y.shape[0]` : 100 : do not append to output shape, always ignore\n        first dimension of `y`\n    * `y.shape[1]` : 30 : append to output shape\n    * `y.shape[2]` : 20 : do not append to output shape, dimension 2 of\n        `y` has been summed over.\n        (`dot_axes[1]` = 2) `output_shape` = `(100, 30)`\n\n    Example:\n\n    >>> x_batch = np.ones(shape=(32, 20, 1))\n    >>> y_batch = np.ones(shape=(32, 30, 20))\n    >>> xy_batch_dot = batch_dot(x_batch, y_batch, axes=(1, 2))\n\n    Args:\n        x: Keras tensor or variable with `ndim >= 2`.\n        y: Keras tensor or variable with `ndim >= 2`.\n        axes: Tuple or list of integers with target dimensions, or single\n            integer. 
The sizes of `x.shape[axes[0]]` and `y.shape[axes[1]]`\n            should be equal.\n            Note that axis `0` (the batch axis) cannot be included.\n\n    Returns:\n        A tensor with shape equal to the concatenation of `x`'s shape\n        (less the dimension that was summed over) and `y`'s shape (less the\n        batch dimension and the dimension that was summed over). If the final\n        rank is 1, we reshape it to `(batch_size, 1)`.\n    \"\"\"\n\n    x_shape = x.shape\n    y_shape = y.shape\n\n    x_ndim = len(x_shape)\n    y_ndim = len(y_shape)\n\n    if x_ndim < 2 or y_ndim < 2:\n        raise ValueError(\n            f\"Cannot do batch_dot on inputs \"\n            f\"with rank < 2. \"\n            f\"Received inputs with shapes \"\n            f\"{x_shape} and {y_shape}.\"\n        )\n\n    x_batch_size = x_shape[0]\n    y_batch_size = y_shape[0]\n\n    if x_batch_size is not None and y_batch_size is not None:\n        if x_batch_size != y_batch_size:\n            raise ValueError(\n                f\"Cannot do batch_dot on inputs \"\n                f\"with different batch sizes. \"\n                f\"Received inputs with shapes \"\n                f\"{x_shape} and {y_shape}.\"\n            )\n    if isinstance(axes, int):\n        axes = [axes, axes]\n\n    if axes is None:\n        if y_ndim == 2:\n            axes = [x_ndim - 1, y_ndim - 1]\n        else:\n            axes = [x_ndim - 1, y_ndim - 2]\n\n    if any(isinstance(a, (list, tuple)) for a in axes):\n        raise ValueError(\n            f\"Multiple target dimensions are not supported. 
\"\n            f\"Expected: None, int, (int, int), \"\n            f\"Provided: {axes} \"\n        )\n\n    # if tuple, convert to list.\n    axes = list(axes)\n\n    # convert negative indices.\n    if axes[0] < 0:\n        axes[0] += x_ndim\n    if axes[1] < 0:\n        axes[1] += y_ndim\n\n    # sanity checks\n    if 0 in axes:\n        raise ValueError(\n            \"Cannot perform batch_dot over axis 0. \"\n            \"If your inputs are not batched, \"\n            \"add a dummy batch dimension to your \"\n            \"inputs using keras.ops.expand_dims(x, 0)\"\n        )\n    a0, a1 = axes\n    d1 = x_shape[a0]\n    d2 = y_shape[a1]\n\n    if d1 is not None and d2 is not None and d1 != d2:\n        raise ValueError(\n            f\"Cannot do batch_dot on inputs with shapes \"\n            f\"{x_shape} and {y_shape} with axes={axes}. \"\n            f\"x.shape[{axes[0]}] != y.shape[{axes[1]}] ({d1} != {d2}).\"\n        )\n\n    # backup ndims. Need them later.\n    orig_x_ndim = x_ndim\n    orig_y_ndim = y_ndim\n\n    # if rank is 2, expand to 3.\n    if x_ndim == 2:\n        x = ops.expand_dims(x, 1)\n        a0 += 1\n        x_ndim += 1\n    if y_ndim == 2:\n        y = ops.expand_dims(y, 2)\n        y_ndim += 1\n\n    # bring x's dimension to be reduced to last axis.\n    if a0 != x_ndim - 1:\n        pattern = list(range(x_ndim))\n        for i in range(a0, x_ndim - 1):\n            pattern[i] = pattern[i + 1]\n        pattern[-1] = a0\n        x = ops.transpose(x, pattern)\n\n    # bring y's dimension to be reduced to axis 1.\n    if a1 != 1:\n        pattern = list(range(y_ndim))\n        for i in range(a1, 1, -1):\n            pattern[i] = pattern[i - 1]\n        pattern[1] = a1\n        y = ops.transpose(y, pattern)\n\n    # normalize both inputs to rank 3.\n    if x_ndim > 3:\n        # squash middle dimensions of x.\n        x_shape = ops.shape(x)\n        x_mid_dims = x_shape[1:-1]\n        x_squashed_shape = (x_shape[0], -1, x_shape[-1])\n    
    x = ops.reshape(x, x_squashed_shape)\n        x_squashed = True\n    else:\n        x_squashed = False\n\n    if y_ndim > 3:\n        # squash trailing dimensions of y.\n        y_shape = ops.shape(y)\n        y_trail_dims = y_shape[2:]\n        y_squashed_shape = (y_shape[0], y_shape[1], -1)\n        y = ops.reshape(y, y_squashed_shape)\n        y_squashed = True\n    else:\n        y_squashed = False\n\n    result = ops.matmul(x, y)\n\n    # if inputs were squashed, we have to reshape the matmul output.\n    output_shape = ops.shape(result)\n    do_reshape = False\n\n    if x_squashed:\n        output_shape = output_shape[:1] + x_mid_dims + output_shape[-1:]\n        do_reshape = True\n\n    if y_squashed:\n        output_shape = output_shape[:-1] + y_trail_dims\n        do_reshape = True\n\n    if do_reshape:\n        result = ops.reshape(result, output_shape)\n\n    # if the inputs were originally rank 2, we remove the added 1 dim.\n    if orig_x_ndim == 2:\n        result = ops.squeeze(result, 1)\n    elif orig_y_ndim == 2:\n        result = ops.squeeze(result, -1)\n\n    return result\n\n\n@keras_export(\"keras.layers.Dot\")\nclass Dot(Merge):\n    \"\"\"Computes the dot product between samples in two tensors.\n\n    It takes a list of exactly two inputs, and the axes (one\n    per input) along which the dot product is to be performed.\n\n    Let's say `x` and `y` are the two input tensors with shapes\n    `(2, 3, 5)` and `(2, 10, 3)`. The batch dimension should be\n    of the same size for both inputs, and `axes` should correspond\n    to the dimensions that have the same size in the corresponding\n    inputs. e.g. 
with `axes=(1, 2)`, the dot product of `x`, and `y`\n    will result in a tensor with shape `(2, 5, 10)`\n\n    Example:\n\n    >>> x = np.arange(10).reshape(1, 5, 2)\n    >>> y = np.arange(10, 20).reshape(1, 2, 5)\n    >>> keras.layers.Dot(axes=(1, 2))([x, y])\n\n    Usage in a Keras model:\n\n    >>> x1 = keras.layers.Dense(8)(np.arange(10).reshape(5, 2))\n    >>> x2 = keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))\n    >>> y = keras.layers.Dot(axes=1)([x1, x2])\n\n    Args:\n        axes: Integer or tuple of integers, axis or axes along which to\n            take the dot product. If a tuple, should be two integers\n            corresponding to the desired axis from the first input and the\n            desired axis from the second input, respectively. Note that the\n            size of the two selected axes must match, and that\n            axis `0` (the batch axis) cannot be included.\n        normalize: Whether to L2-normalize samples along the dot product axis\n            before taking the dot product. If set to `True`, then\n            the output of the dot product is the cosine proximity\n            between the two samples.\n        **kwargs: Standard layer keyword arguments.\n\n    Returns:\n        A tensor, the dot product of the samples from the inputs.\n    \"\"\"\n\n    def __init__(self, axes, normalize=False, **kwargs):\n        super().__init__(**kwargs)\n        if not isinstance(axes, int):\n            if not isinstance(axes, (list, tuple)):\n                raise TypeError(\n                    f\"Invalid type for argument `axes`: it should be \"\n                    f\"a list or an int. Received: axes={axes}\"\n                )\n            if len(axes) != 2:\n                raise ValueError(\n                    f\"Invalid format for argument `axes`: it should contain \"\n                    f\"two elements. 
Received: axes={axes}\"\n                )\n            if not isinstance(axes[0], int) or not isinstance(axes[1], int):\n                raise ValueError(\n                    f\"Invalid format for argument `axes`: list elements should \"\n                    f\"be integers. Received: axes={axes}\"\n                )\n        self.axes = axes\n        self.normalize = normalize\n        self.supports_masking = True\n        self._reshape_required = False\n\n    def build(self, input_shape):\n        # Used purely for shape validation.\n        if (\n            not isinstance(input_shape[0], (tuple, list))\n            or len(input_shape) != 2\n        ):\n            raise ValueError(\n                f\"A `Dot` layer should be called on a list of 2 inputs. \"\n                f\"Received: input_shape={input_shape}\"\n            )\n        shape1 = input_shape[0]\n        shape2 = input_shape[1]\n        if shape1 is None or shape2 is None:\n            return\n        if isinstance(self.axes, int):\n            if self.axes < 0:\n                axes = [self.axes % len(shape1), self.axes % len(shape2)]\n            else:\n                axes = [self.axes] * 2\n        else:\n            axes = self.axes\n        if shape1[axes[0]] != shape2[axes[1]]:\n            raise ValueError(\n                f\"Incompatible input shapes: \"\n                f\"axis values {shape1[axes[0]]} (at axis {axes[0]}) != \"\n                f\"{shape2[axes[1]]} (at axis {axes[1]}). \"\n                f\"Full input shapes: {shape1}, {shape2}\"\n            )\n\n    def _merge_function(self, inputs):\n        if len(inputs) != 2:\n            raise ValueError(\n                f\"A `Dot` layer should be called on exactly 2 inputs. 
\"\n                f\"Received: inputs={inputs}\"\n            )\n        x1 = inputs[0]\n        x2 = inputs[1]\n\n        if isinstance(self.axes, int):\n            if self.axes < 0:\n                axes = [\n                    self.axes % len(x1.shape),\n                    self.axes % len(x2.shape),\n                ]\n            else:\n                axes = [self.axes] * 2\n        else:\n            axes = []\n            for i in range(len(self.axes)):\n                if self.axes[i] < 0:\n                    axes.append(self.axes[i] % len(inputs[i].shape))\n                else:\n                    axes.append(self.axes[i])\n\n        if self.normalize:\n            x1 = normalize(x1, axis=axes[0])\n            x2 = normalize(x2, axis=axes[1])\n        output = batch_dot(x1, x2, axes)\n        return output\n\n    def compute_output_shape(self, input_shape):\n        if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2:\n            raise ValueError(\n                f\"A `Dot` layer should be called on a list of 2 inputs. 
\"\n                f\"Received: input_shape={input_shape}\"\n            )\n        shape1 = list(input_shape[0])\n        shape2 = list(input_shape[1])\n        if isinstance(self.axes, int):\n            if self.axes < 0:\n                axes = [self.axes % len(shape1), self.axes % len(shape2)]\n            else:\n                axes = [self.axes] * 2\n        else:\n            axes = self.axes\n        shape1.pop(axes[0])\n        shape2.pop(axes[1])\n        shape2.pop(0)\n        output_shape = shape1 + shape2\n        if len(output_shape) == 1:\n            output_shape += [1]\n        return tuple(output_shape)\n\n    def compute_mask(self, inputs, mask=None):\n        return None\n\n    def get_config(self):\n        config = {\n            \"axes\": self.axes,\n            \"normalize\": self.normalize,\n        }\n        base_config = super().get_config()\n        return dict(list(base_config.items()) + list(config.items()))\n\n\n@keras_export(\"keras.layers.dot\")\ndef dot(inputs, axes=-1, **kwargs):\n    \"\"\"Functional interface to the `Dot` layer.\n\n    Args:\n        inputs: A list of input tensors (at least 2).\n        axes: Integer or tuple of integers,\n            axis or axes along which to take the dot product.\n            Note that axis `0` (the batch axis) cannot be included.\n        normalize: Whether to L2-normalize samples along the\n            dot product axis before taking the dot product.\n            If set to `True`, then the output of the dot product\n            is the cosine proximity between the two samples.\n        **kwargs: Standard layer keyword arguments.\n\n    Returns:\n        A tensor, the dot product of the samples from the inputs.\n    \"\"\"\n    return Dot(axes=axes, **kwargs)(inputs)\n"
  },
  {
    "path": "keras/src/layers/merging/maximum.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.merging.base_merge import Merge\n\n\n@keras_export(\"keras.layers.Maximum\")\nclass Maximum(Merge):\n    \"\"\"Computes element-wise maximum on a list of inputs.\n\n    It takes as input a list of tensors, all of the same shape,\n    and returns a single tensor (also of the same shape).\n\n    Examples:\n\n    >>> input_shape = (2, 3, 4)\n    >>> x1 = np.random.rand(*input_shape)\n    >>> x2 = np.random.rand(*input_shape)\n    >>> y = keras.layers.Maximum()([x1, x2])\n\n    Usage in a Keras model:\n\n    >>> input1 = keras.layers.Input(shape=(16,))\n    >>> x1 = keras.layers.Dense(8, activation='relu')(input1)\n    >>> input2 = keras.layers.Input(shape=(32,))\n    >>> x2 = keras.layers.Dense(8, activation='relu')(input2)\n    >>> # equivalent to `y = keras.layers.maximum([x1, x2])`\n    >>> y = keras.layers.Maximum()([x1, x2])\n    >>> out = keras.layers.Dense(4)(y)\n    >>> model = keras.models.Model(inputs=[input1, input2], outputs=out)\n\n    \"\"\"\n\n    def _merge_function(self, inputs):\n        return self._apply_merge_op_and_or_mask(ops.maximum, inputs)\n\n\n@keras_export(\"keras.layers.maximum\")\ndef maximum(inputs, **kwargs):\n    \"\"\"Functional interface to the `keras.layers.Maximum` layer.\n\n    Args:\n        inputs: A list of input tensors , all of the same shape.\n        **kwargs: Standard layer keyword arguments.\n\n    Returns:\n        A tensor as the element-wise product of the inputs with the same\n        shape as the inputs.\n\n    Examples:\n\n    >>> input_shape = (2, 3, 4)\n    >>> x1 = np.random.rand(*input_shape)\n    >>> x2 = np.random.rand(*input_shape)\n    >>> y = keras.layers.maximum([x1, x2])\n\n    Usage in a Keras model:\n\n    >>> input1 = keras.layers.Input(shape=(16,))\n    >>> x1 = keras.layers.Dense(8, activation='relu')(input1)\n    >>> input2 = keras.layers.Input(shape=(32,))\n    >>> x2 = 
keras.layers.Dense(8, activation='relu')(input2)\n    >>> y = keras.layers.maximum([x1, x2])\n    >>> out = keras.layers.Dense(4)(y)\n    >>> model = keras.models.Model(inputs=[input1, input2], outputs=out)\n\n    \"\"\"\n    return Maximum(**kwargs)(inputs)\n"
  },
  {
    "path": "keras/src/layers/merging/merging_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\n\n\ndef np_dot(a, b, axes):\n    if isinstance(axes, int):\n        axes = (axes, axes)\n    axes = [axis if axis < 0 else axis - 1 for axis in axes]\n    res = np.stack([np.tensordot(a[i], b[i], axes) for i in range(a.shape[0])])\n    if len(res.shape) == 1:\n        res = np.expand_dims(res, axis=1)\n    return res\n\n\nTEST_PARAMETERS = [\n    {\n        \"testcase_name\": \"add\",\n        \"layer_class\": layers.Add,\n        \"np_op\": np.add,\n    },\n    {\n        \"testcase_name\": \"subtract\",\n        \"layer_class\": layers.Subtract,\n        \"np_op\": np.subtract,\n    },\n    {\n        \"testcase_name\": \"minimum\",\n        \"layer_class\": layers.Minimum,\n        \"np_op\": np.minimum,\n    },\n    {\n        \"testcase_name\": \"maximum\",\n        \"layer_class\": layers.Maximum,\n        \"np_op\": np.maximum,\n    },\n    {\n        \"testcase_name\": \"multiply\",\n        \"layer_class\": layers.Multiply,\n        \"np_op\": np.multiply,\n    },\n    {\n        \"testcase_name\": \"average\",\n        \"layer_class\": layers.Average,\n        \"np_op\": lambda a, b: np.multiply(np.add(a, b), 0.5),\n    },\n    {\n        \"testcase_name\": \"concat\",\n        \"layer_class\": layers.Concatenate,\n        \"np_op\": lambda a, b, **kwargs: np.concatenate((a, b), **kwargs),\n        \"init_kwargs\": {\"axis\": -1},\n        \"expected_output_shape\": (2, 4, 10),\n    },\n    {\n        \"testcase_name\": \"dot_2d\",\n        \"layer_class\": layers.Dot,\n        \"np_op\": np_dot,\n        \"init_kwargs\": {\"axes\": -1},\n        \"input_shape\": (2, 4),\n        \"expected_output_shape\": (2, 1),\n        \"skip_mask_test\": True,\n    },\n    {\n        \"testcase_name\": \"dot_3d\",\n        
\"layer_class\": layers.Dot,\n        \"np_op\": np_dot,\n        \"init_kwargs\": {\"axes\": -1},\n        \"expected_output_shape\": (2, 4, 4),\n        \"skip_mask_test\": True,\n    },\n]\n\n\n@pytest.mark.requires_trainable_backend\nclass MergingLayersTest(testing.TestCase):\n    @parameterized.named_parameters(TEST_PARAMETERS)\n    def test_basic(\n        self,\n        layer_class,\n        init_kwargs={},\n        input_shape=(2, 4, 5),\n        expected_output_shape=(2, 4, 5),\n        **kwargs,\n    ):\n        self.run_layer_test(\n            layer_class,\n            init_kwargs=init_kwargs,\n            input_shape=(input_shape, input_shape),\n            expected_output_shape=expected_output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n\n    @parameterized.named_parameters(TEST_PARAMETERS)\n    def test_correctness_static(\n        self,\n        layer_class,\n        np_op,\n        init_kwargs={},\n        input_shape=(2, 4, 5),\n        expected_output_shape=(2, 4, 5),\n        skip_mask_test=False,\n    ):\n        batch_size = input_shape[0]\n        shape = input_shape[1:]\n        x1 = np.random.rand(*input_shape)\n        x2 = np.random.rand(*input_shape)\n        x3 = np_op(x1, x2, **init_kwargs)\n\n        input_1 = layers.Input(shape=shape, batch_size=batch_size)\n        input_2 = layers.Input(shape=shape, batch_size=batch_size)\n        layer = layer_class(**init_kwargs)\n        out = layer([input_1, input_2])\n        model = models.Model([input_1, input_2], out)\n        res = model([x1, x2])\n\n        self.assertEqual(res.shape, expected_output_shape)\n        self.assertAllClose(res, x3, atol=1e-4, tpu_atol=1e-2, tpu_rtol=1e-2)\n        self.assertIsNone(layer.compute_mask([input_1, input_2], [None, None]))\n        
self.assertIsNone(layer.compute_mask([x1, x2], [None, None]))\n        if not skip_mask_test:\n            mask1 = np.ones(input_shape[:-1], dtype=np.bool_)\n            mask2 = np.ones(input_shape[:-1], dtype=np.bool_)\n            self.assertTrue(\n                np.all(\n                    backend.convert_to_numpy(\n                        layer.compute_mask([x1, x2], [mask1, mask2])\n                    )\n                )\n            )\n\n    @parameterized.named_parameters(TEST_PARAMETERS)\n    def test_correctness_dynamic(\n        self,\n        layer_class,\n        np_op,\n        init_kwargs={},\n        input_shape=(2, 4, 5),\n        expected_output_shape=(2, 4, 5),\n        skip_mask_test=False,\n    ):\n        shape = input_shape[1:]\n        x1 = np.random.rand(*input_shape)\n        x2 = np.random.rand(*input_shape)\n        x3 = np_op(x1, x2, **init_kwargs)\n\n        input_1 = layers.Input(shape=shape)\n        input_2 = layers.Input(shape=shape)\n        layer = layer_class(**init_kwargs)\n        out = layer([input_1, input_2])\n        model = models.Model([input_1, input_2], out)\n        res = model([x1, x2])\n\n        self.assertEqual(res.shape, expected_output_shape)\n        self.assertAllClose(res, x3, atol=1e-4, tpu_atol=1e-2, tpu_rtol=1e-2)\n        self.assertIsNone(layer.compute_mask([input_1, input_2], [None, None]))\n        if not skip_mask_test:\n            self.assertTrue(\n                np.all(\n                    backend.convert_to_numpy(\n                        layer.compute_mask(\n                            [input_1, input_2],\n                            [backend.Variable(x1), backend.Variable(x2)],\n                        )\n                    )\n                )\n            )\n\n    @parameterized.named_parameters(TEST_PARAMETERS)\n    def test_errors(\n        self,\n        layer_class,\n        init_kwargs={},\n        input_shape=(2, 4, 5),\n        skip_mask_test=False,\n        **kwargs,\n    ):\n    
    if skip_mask_test:\n            pytest.skip(\"Masking not supported\")\n\n        batch_size = input_shape[0]\n        shape = input_shape[1:]\n        x1 = np.random.rand(batch_size, *shape)\n\n        input_1 = layers.Input(shape=shape, batch_size=batch_size)\n        input_2 = layers.Input(shape=shape, batch_size=batch_size)\n        layer = layer_class(**init_kwargs)\n\n        with self.assertRaisesRegex(ValueError, \"`mask` should be a list.\"):\n            layer.compute_mask([input_1, input_2], x1)\n\n        with self.assertRaisesRegex(ValueError, \"`inputs` should be a list.\"):\n            layer.compute_mask(input_1, [None, None])\n\n        with self.assertRaisesRegex(\n            ValueError, \" should have the same length.\"\n        ):\n            layer.compute_mask([input_1, input_2], [None])\n\n    def test_subtract_layer_inputs_length_errors(self):\n        shape = (4, 5)\n        input_1 = layers.Input(shape=shape)\n        input_2 = layers.Input(shape=shape)\n        input_3 = layers.Input(shape=shape)\n\n        with self.assertRaisesRegex(\n            ValueError, \"layer should be called on exactly 2 inputs\"\n        ):\n            layers.Subtract()([input_1, input_2, input_3])\n        with self.assertRaisesRegex(\n            ValueError, \"layer should be called on exactly 2 inputs\"\n        ):\n            layers.Subtract()([input_1])\n\n    def test_dot_higher_dim(self):\n        a_shape = (1, 3, 2)\n        b_shape = (1, 1, 2, 3)\n        # Test symbolic call\n        a = layers.Input(batch_shape=a_shape)\n        b = layers.Input(batch_shape=b_shape)\n        c = layers.Dot(axes=(-2, -1))([a, b])\n        self.assertEqual(c.shape, (1, 2, 1, 2))\n        a = np.random.random(a_shape)\n        b = np.random.random(b_shape)\n        c = layers.Dot(axes=(-2, -1))([a, b])\n        self.assertEqual(backend.standardize_shape(c.shape), (1, 2, 1, 2))\n\n    def test_add_with_mask(self):\n       
 mask = layers.Masking()\n        x1 = mask(backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]]))\n        x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]])\n\n        output = layers.Add()([x1, x2])\n        self.assertAllClose(output, [[[0, 0], [1, 2], [1, 2], [6, 8]]])\n        self.assertIsNone(getattr(output, \"_keras_mask\", None))\n\n        x2 = mask(x2)\n        output = layers.Add()([x1, x2])\n        self.assertAllClose(output, [[[0, 0], [1, 2], [1, 2], [6, 8]]])\n        self.assertAllClose(output._keras_mask, [[0, 1, 1, 1]])\n\n    def test_subtract_with_mask(self):\n        mask = layers.Masking()\n        x1 = mask(backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]]))\n        x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]])\n\n        output = layers.Subtract()([x1, x2])\n        self.assertAllClose(output, [[[0, 0], [1, 2], [-1, -2], [0, 0]]])\n        self.assertIsNone(getattr(output, \"_keras_mask\", None))\n\n        x2 = mask(x2)\n        output = layers.Subtract()([x1, x2])\n        self.assertAllClose(output, [[[0, 0], [1, 2], [-1, -2], [0, 0]]])\n        self.assertAllClose(output._keras_mask, [[0, 1, 1, 1]])\n\n    def test_average_with_mask(self):\n        mask = layers.Masking()\n        x1 = mask(backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]]))\n        x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]])\n\n        output = layers.Average()([x1, x2])\n        self.assertAllClose(output, [[[0, 0], [0.5, 1], [0.5, 1], [3, 4]]])\n        self.assertIsNone(getattr(output, \"_keras_mask\", None))\n\n        x2 = mask(x2)\n        output = layers.Average()([x1, x2])\n        self.assertAllClose(output, [[[0, 0], [0.5, 1], [0.5, 1], [3, 4]]])\n        self.assertAllClose(output._keras_mask, [[0, 1, 1, 1]])\n\n    def test_multiply_with_mask(self):\n        mask = layers.Masking()\n        x1 = mask(backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]]))\n   
     x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]])\n\n        output = layers.Multiply()([x1, x2])\n        self.assertAllClose(output, [[[0, 0], [0, 0], [1, 2], [9, 16]]])\n        self.assertIsNone(getattr(output, \"_keras_mask\", None))\n\n        x2 = mask(x2)\n        output = layers.Multiply()([x1, x2])\n        self.assertAllClose(output, [[[0, 0], [1, 2], [1, 2], [9, 16]]])\n        self.assertAllClose(output._keras_mask, [[0, 1, 1, 1]])\n\n    def test_maximum_with_mask(self):\n        mask = layers.Masking()\n        x1 = mask(\n            backend.convert_to_tensor([[[0, 0], [-1, -2], [0, 0], [-3, -4]]])\n        )\n        x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [-1, -2], [-3, -4]]])\n\n        output = layers.Maximum()([x1, x2])\n        self.assertAllClose(output, [[[0, 0], [0, 0], [-1, -2], [-3, -4]]])\n        self.assertIsNone(getattr(output, \"_keras_mask\", None))\n\n        x2 = mask(x2)\n        output = layers.Maximum()([x1, x2])\n        self.assertAllClose(output, [[[0, 0], [-1, -2], [-1, -2], [-3, -4]]])\n        self.assertAllClose(output._keras_mask, [[0, 1, 1, 1]])\n\n    def test_minimum_with_mask(self):\n        mask = layers.Masking()\n        x1 = mask(backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]]))\n        x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]])\n\n        output = layers.Minimum()([x1, x2])\n        self.assertAllClose(output, [[[0, 0], [0, 0], [1, 2], [3, 4]]])\n        self.assertIsNone(getattr(output, \"_keras_mask\", None))\n\n        x2 = mask(x2)\n        output = layers.Minimum()([x1, x2])\n        self.assertAllClose(output, [[[0, 0], [1, 2], [1, 2], [3, 4]]])\n        self.assertAllClose(output._keras_mask, [[0, 1, 1, 1]])\n\n    def test_concatenate_with_mask(self):\n        mask = layers.Masking()\n        x1 = mask(backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]]))\n        x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 
4]]])\n\n        output = layers.Concatenate(axis=1)([x1, x2])\n        self.assertAllClose(\n            output,\n            [[[0, 0], [1, 2], [0, 0], [3, 4], [0, 0], [0, 0], [1, 2], [3, 4]]],\n        )\n        self.assertAllClose(output._keras_mask, [[0, 1, 0, 1, 1, 1, 1, 1]])\n\n        output = layers.Concatenate(axis=2)([x1, x2])\n        self.assertAllClose(\n            output,\n            [[[0, 0, 0, 0], [1, 2, 0, 0], [0, 0, 1, 2], [3, 4, 3, 4]]],\n        )\n        self.assertAllClose(output._keras_mask, [[1, 1, 1, 1]])\n\n    def test_concatenate_with_mask_symbolic(self):\n        input1 = layers.Input((4, 2))\n        input2 = layers.Input((4, 2))\n        mask = layers.Masking()\n        output = layers.Concatenate(axis=1)([mask(input1), input2])\n        model = models.Model(\n            inputs=[input1, input2], outputs=output._keras_mask\n        )\n        x1 = backend.convert_to_tensor([[[0, 0], [1, 2], [0, 0], [3, 4]]])\n        x2 = backend.convert_to_tensor([[[0, 0], [0, 0], [1, 2], [3, 4]]])\n        self.assertAllClose(model([x1, x2]), [[0, 1, 0, 1, 1, 1, 1, 1]])\n\n    def test_concatenate_errors(self):\n        # This should work\n        x1 = np.ones((1, 1, 1, 1, 5))\n        x2 = np.ones((1, 1, 1, 1, 4))\n        out = layers.Concatenate(axis=-1)([x1, x2])\n        self.assertEqual(ops.shape(out), (1, 1, 1, 1, 9))\n\n        # This won't\n        x1 = np.ones((1, 2, 1, 1, 5))\n        x2 = np.ones((1, 1, 1, 1, 4))\n        with self.assertRaisesRegex(\n            ValueError,\n            (\n                \"requires inputs with matching shapes \"\n                \"except for the concatenation axis\"\n            ),\n        ):\n            out = layers.Concatenate(axis=-1)([x1, x2])\n        x1 = np.ones((1, 2, 1, 2, 1))\n        x2 = np.ones((1, 1, 1, 3, 1))\n        with self.assertRaisesRegex(\n            ValueError,\n            (\n                \"requires inputs with matching shapes \"\n                \"except for the 
concatenation axis\"\n            ),\n        ):\n            out = layers.Concatenate(axis=1)([x1, x2])\n\n    @parameterized.named_parameters(TEST_PARAMETERS)\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=\"Backend does not support sparse tensors.\",\n    )\n    def test_sparse(\n        self,\n        layer_class,\n        np_op,\n        init_kwargs={},\n        input_shape=(2, 4, 5),\n        expected_output_shape=(2, 4, 5),\n        **kwargs,\n    ):\n        self.run_layer_test(\n            layer_class,\n            init_kwargs=init_kwargs,\n            input_shape=[input_shape, input_shape],\n            input_sparse=True,\n            expected_output_shape=expected_output_shape,\n            expected_output_sparse=True,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n            run_training_check=False,\n            run_mixed_precision_check=False,\n        )\n\n        layer = layer_class(**init_kwargs)\n\n        # Merging a sparse tensor with a dense tensor, or a dense tensor with a\n        # sparse tensor produces a dense tensor\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            x1 = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 3))\n            x3 = tf.SparseTensor([[0, 0], [1, 1]], [4.0, 5.0], (2, 3))\n        elif backend.backend() == \"jax\":\n            import jax.experimental.sparse as jax_sparse\n\n            # Use n_batch of 1 to be compatible with all ops.\n            x1 = jax_sparse.BCOO(([[1.0, 2.0]], [[[0], [2]]]), shape=(2, 3))\n            x3 = jax_sparse.BCOO(([[4.0, 5.0]], [[[0], [1]]]), shape=(2, 3))\n        else:\n            self.fail(f\"Sparse is unsupported with backend {backend.backend()}\")\n\n        x1_np = backend.convert_to_numpy(x1)\n        x2 = 
np.random.rand(2, 3)\n        self.assertAllClose(layer([x1, x2]), np_op(x1_np, x2, **init_kwargs))\n        self.assertAllClose(layer([x2, x1]), np_op(x2, x1_np, **init_kwargs))\n\n        # Merging a sparse tensor with a sparse tensor produces a sparse tensor\n        x3_np = backend.convert_to_numpy(x3)\n\n        self.assertSparse(layer([x1, x3]))\n        self.assertAllClose(layer([x1, x3]), np_op(x1_np, x3_np, **init_kwargs))\n"
  },
  {
    "path": "keras/src/layers/merging/minimum.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.merging.base_merge import Merge\n\n\n@keras_export(\"keras.layers.Minimum\")\nclass Minimum(Merge):\n    \"\"\"Computes elementwise minimum on a list of inputs.\n\n    It takes as input a list of tensors, all of the same shape,\n    and returns a single tensor (also of the same shape).\n\n    Examples:\n\n    >>> input_shape = (2, 3, 4)\n    >>> x1 = np.random.rand(*input_shape)\n    >>> x2 = np.random.rand(*input_shape)\n    >>> y = keras.layers.Minimum()([x1, x2])\n\n    Usage in a Keras model:\n\n    >>> input1 = keras.layers.Input(shape=(16,))\n    >>> x1 = keras.layers.Dense(8, activation='relu')(input1)\n    >>> input2 = keras.layers.Input(shape=(32,))\n    >>> x2 = keras.layers.Dense(8, activation='relu')(input2)\n    >>> # equivalent to `y = keras.layers.minimum([x1, x2])`\n    >>> y = keras.layers.Minimum()([x1, x2])\n    >>> out = keras.layers.Dense(4)(y)\n    >>> model = keras.models.Model(inputs=[input1, input2], outputs=out)\n\n    \"\"\"\n\n    def _merge_function(self, inputs):\n        return self._apply_merge_op_and_or_mask(ops.minimum, inputs)\n\n\n@keras_export(\"keras.layers.minimum\")\ndef minimum(inputs, **kwargs):\n    \"\"\"Functional interface to the `keras.layers.Minimum` layer.\n\n    Args:\n        inputs: A list of input tensors , all of the same shape.\n        **kwargs: Standard layer keyword arguments.\n\n    Returns:\n        A tensor as the elementwise product of the inputs with the same\n        shape as the inputs.\n\n    Examples:\n\n    >>> input_shape = (2, 3, 4)\n    >>> x1 = np.random.rand(*input_shape)\n    >>> x2 = np.random.rand(*input_shape)\n    >>> y = keras.layers.minimum([x1, x2])\n\n    Usage in a Keras model:\n\n    >>> input1 = keras.layers.Input(shape=(16,))\n    >>> x1 = keras.layers.Dense(8, activation='relu')(input1)\n    >>> input2 = keras.layers.Input(shape=(32,))\n    >>> x2 = 
keras.layers.Dense(8, activation='relu')(input2)\n    >>> y = keras.layers.minimum([x1, x2])\n    >>> out = keras.layers.Dense(4)(y)\n    >>> model = keras.models.Model(inputs=[input1, input2], outputs=out)\n\n    \"\"\"\n    return Minimum(**kwargs)(inputs)\n"
  },
  {
    "path": "keras/src/layers/merging/multiply.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.merging.base_merge import Merge\n\n\n@keras_export(\"keras.layers.Multiply\")\nclass Multiply(Merge):\n    \"\"\"Performs elementwise multiplication.\n\n    It takes as input a list of tensors, all of the same shape,\n    and returns a single tensor (also of the same shape).\n\n    Examples:\n\n    >>> input_shape = (2, 3, 4)\n    >>> x1 = np.random.rand(*input_shape)\n    >>> x2 = np.random.rand(*input_shape)\n    >>> y = keras.layers.Multiply()([x1, x2])\n\n    Usage in a Keras model:\n\n    >>> input1 = keras.layers.Input(shape=(16,))\n    >>> x1 = keras.layers.Dense(8, activation='relu')(input1)\n    >>> input2 = keras.layers.Input(shape=(32,))\n    >>> x2 = keras.layers.Dense(8, activation='relu')(input2)\n    >>> # equivalent to `y = keras.layers.multiply([x1, x2])`\n    >>> y = keras.layers.Multiply()([x1, x2])\n    >>> out = keras.layers.Dense(4)(y)\n    >>> model = keras.models.Model(inputs=[input1, input2], outputs=out)\n\n    \"\"\"\n\n    def _merge_function(self, inputs):\n        masks = [backend.get_keras_mask(x) for x in inputs]\n        has_output_mask = all(mask is not None for mask in masks)\n        output = None\n        output_mask = None\n\n        for x, mask in zip(inputs, masks):\n            if mask is not None:\n                mask = ops.broadcast_to(ops.expand_dims(mask, -1), ops.shape(x))\n                # Replace 0s with 1s outside of mask.\n                x = ops.where(mask, x, ops.cast(1, x.dtype))\n                if has_output_mask:\n                    output_mask = (\n                        mask\n                        if output_mask is None\n                        else ops.logical_or(output_mask, mask)\n                    )\n            output = x if output is None else ops.multiply(output, x)\n\n        if has_output_mask:\n            # Replace 1s with 0s outside of mask per 
standard masking rules.\n            output = ops.where(output_mask, output, ops.cast(0, output.dtype))\n            output_mask = ops.any(output_mask, axis=-1, keepdims=False)\n            backend.set_keras_mask(output, output_mask)\n        return output\n\n\n@keras_export(\"keras.layers.multiply\")\ndef multiply(inputs, **kwargs):\n    \"\"\"Functional interface to the `keras.layers.Multiply` layer.\n\n    Args:\n        inputs: A list of input tensors, all of the same shape.\n        **kwargs: Standard layer keyword arguments.\n\n    Returns:\n        A tensor as the elementwise product of the inputs with the same\n        shape as the inputs.\n\n    Examples:\n\n    >>> input_shape = (2, 3, 4)\n    >>> x1 = np.random.rand(*input_shape)\n    >>> x2 = np.random.rand(*input_shape)\n    >>> y = keras.layers.multiply([x1, x2])\n\n    Usage in a Keras model:\n\n    >>> input1 = keras.layers.Input(shape=(16,))\n    >>> x1 = keras.layers.Dense(8, activation='relu')(input1)\n    >>> input2 = keras.layers.Input(shape=(32,))\n    >>> x2 = keras.layers.Dense(8, activation='relu')(input2)\n    >>> y = keras.layers.multiply([x1, x2])\n    >>> out = keras.layers.Dense(4)(y)\n    >>> model = keras.models.Model(inputs=[input1, input2], outputs=out)\n\n    \"\"\"\n    return Multiply(**kwargs)(inputs)\n"
  },
  {
    "path": "keras/src/layers/merging/subtract.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.merging.base_merge import Merge\n\n\n@keras_export(\"keras.layers.Subtract\")\nclass Subtract(Merge):\n    \"\"\"Performs elementwise subtraction.\n\n    It takes as input a list of exactly two tensors of the same\n    shape, and returns a single tensor (inputs[0] - inputs[1]),\n    also of the same shape.\n\n    Examples:\n\n    >>> input_shape = (2, 3, 4)\n    >>> x1 = np.random.rand(*input_shape)\n    >>> x2 = np.random.rand(*input_shape)\n    >>> y = keras.layers.Subtract()([x1, x2])\n\n    Usage in a Keras model:\n\n    >>> input1 = keras.layers.Input(shape=(16,))\n    >>> x1 = keras.layers.Dense(8, activation='relu')(input1)\n    >>> input2 = keras.layers.Input(shape=(32,))\n    >>> x2 = keras.layers.Dense(8, activation='relu')(input2)\n    >>> # equivalent to `subtracted = keras.layers.subtract([x1, x2])`\n    >>> subtracted = keras.layers.Subtract()([x1, x2])\n    >>> out = keras.layers.Dense(4)(subtracted)\n    >>> model = keras.models.Model(inputs=[input1, input2], outputs=out)\n\n    \"\"\"\n\n    def build(self, input_shape):\n        super().build(input_shape)\n        if len(input_shape) != 2:\n            raise ValueError(\n                \"A `Subtract` layer should be called on exactly 2 inputs. \"\n                f\"Received: input_shape={input_shape}\"\n            )\n\n    def _merge_function(self, inputs):\n        if len(inputs) != 2:\n            raise ValueError(\n                \"A `Subtract` layer should be called on exactly 2 inputs. 
\"\n                f\"Received: inputs={inputs}\"\n            )\n        return ops.subtract(inputs[0], inputs[1])\n\n\n@keras_export(\"keras.layers.subtract\")\ndef subtract(inputs, **kwargs):\n    \"\"\"Functional interface to the `keras.layers.Subtract` layer.\n\n    Args:\n        inputs: A list of input tensors of size 2, each tensor of\n            the same shape.\n        **kwargs: Standard layer keyword arguments.\n\n    Returns:\n        A tensor as the difference of the inputs. It has the same shape\n        as the inputs.\n\n    Examples:\n\n    >>> input_shape = (2, 3, 4)\n    >>> x1 = np.random.rand(*input_shape)\n    >>> x2 = np.random.rand(*input_shape)\n    >>> y = keras.layers.subtract([x1, x2])\n\n    Usage in a Keras model:\n\n    >>> input1 = keras.layers.Input(shape=(16,))\n    >>> x1 = keras.layers.Dense(8, activation='relu')(input1)\n    >>> input2 = keras.layers.Input(shape=(32,))\n    >>> x2 = keras.layers.Dense(8, activation='relu')(input2)\n    >>> subtracted = keras.layers.subtract([x1, x2])\n    >>> out = keras.layers.Dense(4)(subtracted)\n    >>> model = keras.models.Model(inputs=[input1, input2], outputs=out)\n\n    \"\"\"\n    return Subtract(**kwargs)(inputs)\n"
  },
  {
    "path": "keras/src/layers/normalization/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/layers/normalization/batch_normalization.py",
    "content": "from keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\n\n\n# TODO(abheesht17): Move this to utils?\ndef _clone_initializer(initializer):\n    \"\"\"Clones an initializer to ensure a new seed.\n\n    As of tensorflow 2.10, we need to clone user-passed initializers when\n    invoking them twice to avoid creating the same randomized initialization.\n\n    Args:\n        initializer: The initializer to clone.\n\n    Returns:\n        A cloned initializer if it is clonable, otherwise the original one.\n    \"\"\"\n    if isinstance(initializer, initializers.Initializer):\n        config = initializer.get_config()\n        return initializer.__class__.from_config(config)\n    # If we get a string or dict, just return as we cannot and should not clone.\n    return initializer\n\n\n@keras_export(\"keras.layers.BatchNormalization\")\nclass BatchNormalization(Layer):\n    \"\"\"Layer that normalizes its inputs.\n\n    Batch normalization applies a transformation that maintains the mean output\n    close to 0 and the output standard deviation close to 1.\n\n    Importantly, batch normalization works differently during training and\n    during inference.\n\n    **During training** (i.e. when using `fit()` or when calling the layer/model\n    with the argument `training=True`), the layer normalizes its output using\n    the mean and standard deviation of the current batch of inputs. 
That is to\n    say, for each channel being normalized, the layer returns\n    `gamma * (batch - mean(batch)) / sqrt(var(batch) + epsilon) + beta`, where:\n\n    - `epsilon` is a small constant (configurable as part of the constructor\n    arguments)\n    - `gamma` is a learned scaling factor (initialized as 1), which\n    can be disabled by passing `scale=False` to the constructor.\n    - `beta` is a learned offset factor (initialized as 0), which\n    can be disabled by passing `center=False` to the constructor.\n\n    **During inference** (i.e. when using `evaluate()` or `predict()` or when\n    calling the layer/model with the argument `training=False`, which is the\n    default), the layer normalizes its output using a moving average of the\n    mean and standard deviation of the batches it has seen during training. That\n    is to say, it returns\n    `gamma * (batch - self.moving_mean) / sqrt(self.moving_var+epsilon) + beta`.\n\n    `self.moving_mean` and `self.moving_var` are non-trainable variables that\n    are updated each time the layer is called in training mode, as such:\n\n    - `moving_mean = moving_mean * momentum + mean(batch) * (1 - momentum)`\n    - `moving_var = moving_var * momentum + var(batch) * (1 - momentum)`\n\n    As such, the layer will only normalize its inputs during inference\n    *after having been trained on data that has similar statistics as the\n    inference data*.\n\n    Args:\n        axis: Integer, the axis that should be normalized\n            (typically the features axis). For instance, after a `Conv2D` layer\n            with `data_format=\"channels_first\"`, use `axis=1`.\n        momentum: Momentum for the moving average.\n        epsilon: Small float added to variance to avoid dividing by zero.\n        center: If `True`, add offset of `beta` to normalized tensor.\n            If `False`, `beta` is ignored.\n        scale: If `True`, multiply by `gamma`. 
If `False`, `gamma` is not used.\n            When the next layer is linear this can be disabled\n            since the scaling will be done by the next layer.\n        beta_initializer: Initializer for the beta weight.\n        gamma_initializer: Initializer for the gamma weight.\n        moving_mean_initializer: Initializer for the moving mean.\n        moving_variance_initializer: Initializer for the moving variance.\n        beta_regularizer: Optional regularizer for the beta weight.\n        gamma_regularizer: Optional regularizer for the gamma weight.\n        beta_constraint: Optional constraint for the beta weight.\n        gamma_constraint: Optional constraint for the gamma weight.\n        synchronized: Only applicable with the TensorFlow backend.\n            If `True`, synchronizes the global batch statistics (mean and\n            variance) for the layer across all devices at each training step\n            in a distributed training strategy.\n            If `False`, each replica uses its own local batch statistics.\n        renorm: Whether to use\n            [Batch Renormalization](https://arxiv.org/abs/1702.03275). This\n            adds extra variables during training. The inference is the same\n            for either value of this parameter.\n        renorm_clipping: Dictionary, valid only if `renorm = True`.\n            Maps optional keys `\"rmax\"`, `\"rmin\"`, `\"dmax\"` to floats used to\n            clip the renorm correction. The correction `(r, d)` is used as\n            `corrected_value = normalized_value * r + d`, with `r` clipped to\n            `[rmin, rmax]`, and `d` to `[-dmax, dmax]`. Missing `rmax`, `rmin`,\n            `dmax` are set to `inf`, `0`, `inf`, respectively.\n        renorm_momentum: Momentum used to update the moving means and standard\n            deviations with renorm. Valid only if `renorm= True`. 
Unlike\n            `momentum`, this affects training and should be neither too small\n            (which would add noise) nor too large (which would give stale\n            estimates). Note that `momentum` is still applied to get the means\n            and variances for inference.\n        **kwargs: Base layer keyword arguments (e.g. `name` and `dtype`).\n\n    Call arguments:\n        inputs: Input tensor (of any rank).\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode.\n            - `training=True`: The layer will normalize its inputs using\n            the mean and variance of the current batch of inputs.\n            - `training=False`: The layer will normalize its inputs using\n            the mean and variance of its moving statistics, learned during\n            training.\n        mask: Binary tensor of shape broadcastable to `inputs` tensor, with\n            `True` values indicating the positions for which mean and variance\n            should be computed. Masked elements of the current inputs are not\n            taken into account for mean and variance computation during\n            training. Any prior unmasked element values will be taken into\n            account until their momentum expires.\n\n    Reference:\n\n    - [Ioffe and Szegedy, 2015](https://arxiv.org/abs/1502.03167).\n\n    **About setting `layer.trainable = False` on a `BatchNormalization` layer:**\n\n    The meaning of setting `layer.trainable = False` is to freeze the layer,\n    i.e. its internal state will not change during training:\n    its trainable weights will not be updated\n    during `fit()` or `train_on_batch()`, and its state updates will not be run.\n\n    Usually, this does not necessarily mean that the layer is run in inference\n    mode (which is normally controlled by the `training` argument that can\n    be passed when calling a layer). 
\"Frozen state\" and \"inference mode\"\n    are two separate concepts.\n\n    However, in the case of the `BatchNormalization` layer, **setting\n    `trainable = False` on the layer means that the layer will be\n    subsequently run in inference mode** (meaning that it will use\n    the moving mean and the moving variance to normalize the current batch,\n    rather than using the mean and variance of the current batch).\n\n    Note that:\n\n    - Setting `trainable` on a model containing other layers will recursively\n        set the `trainable` value of all inner layers.\n    - If the value of the `trainable` attribute is changed after calling\n        `compile()` on a model, the new value doesn't take effect for this model\n        until `compile()` is called again.\n    \"\"\"\n\n    def __init__(\n        self,\n        axis=-1,\n        momentum=0.99,\n        epsilon=1e-3,\n        center=True,\n        scale=True,\n        beta_initializer=\"zeros\",\n        gamma_initializer=\"ones\",\n        moving_mean_initializer=\"zeros\",\n        moving_variance_initializer=\"ones\",\n        beta_regularizer=None,\n        gamma_regularizer=None,\n        beta_constraint=None,\n        gamma_constraint=None,\n        renorm=False,\n        renorm_clipping=None,\n        renorm_momentum=0.99,\n        synchronized=False,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.axis = int(axis)\n\n        if synchronized and backend.backend() != \"tensorflow\":\n            raise ValueError(\n                \"Argument synchronized=True is only supported \"\n                \"with the TensorFlow backend.\"\n            )\n        self.synchronized = synchronized\n\n        self.momentum = float(momentum)\n        self.epsilon = float(epsilon)\n        self.center = center\n        self.scale = scale\n        self.beta_initializer = initializers.get(beta_initializer)\n        self.gamma_initializer = initializers.get(gamma_initializer)\n        
self.moving_mean_initializer = initializers.get(moving_mean_initializer)\n        self.moving_variance_initializer = initializers.get(\n            moving_variance_initializer\n        )\n        self.beta_regularizer = regularizers.get(beta_regularizer)\n        self.gamma_regularizer = regularizers.get(gamma_regularizer)\n        self.beta_constraint = constraints.get(beta_constraint)\n        self.gamma_constraint = constraints.get(gamma_constraint)\n        self.supports_masking = True\n\n        self.renorm = renorm\n        if renorm:\n            renorm_clipping = renorm_clipping or {}\n            keys = [\"rmax\", \"rmin\", \"dmax\"]\n            if set(renorm_clipping) - set(keys):\n                raise ValueError(\n                    \"Received invalid keys for `renorm_clipping` argument: \"\n                    f\"{renorm_clipping}. Supported values: {keys}.\"\n                )\n            rmax = renorm_clipping.get(\"rmax\")\n            rmin = renorm_clipping.get(\"rmin\")\n            dmax = renorm_clipping.get(\"dmax\")\n\n            if rmax is not None and rmin is not None and rmax < rmin:\n                raise ValueError(\n                    \"rmax should be greater than rmin in the `renorm_clipping` \"\n                    f\"argument. Received: rmax={rmax}, rmin={rmin}.\"\n                )\n            if dmax is not None and dmax < 0:\n                raise ValueError(\n                    \"dmax should be non-negative in the `renorm_clipping` \"\n                    f\"argument. 
Received: dmax={dmax}.\"\"\"\n                )\n\n        self.renorm_clipping = renorm_clipping\n        self.renorm_momentum = renorm_momentum\n\n        self.gamma = None\n        self.beta = None\n        self.moving_mean = None\n        self.moving_variance = None\n        self._reduction_axes = None\n\n    def build(self, input_shape):\n        shape = (input_shape[self.axis],)\n        if self.scale:\n            self.gamma = self.add_weight(\n                shape=shape,\n                name=\"gamma\",\n                initializer=self.gamma_initializer,\n                regularizer=self.gamma_regularizer,\n                constraint=self.gamma_constraint,\n                trainable=True,\n                autocast=False,\n            )\n        if self.center:\n            self.beta = self.add_weight(\n                shape=shape,\n                name=\"beta\",\n                initializer=self.beta_initializer,\n                regularizer=self.beta_regularizer,\n                constraint=self.beta_constraint,\n                trainable=True,\n                autocast=False,\n            )\n        self.moving_mean = self.add_weight(\n            shape=shape,\n            name=\"moving_mean\",\n            initializer=self.moving_mean_initializer,\n            trainable=False,\n            autocast=False,\n        )\n        self.moving_variance = self.add_weight(\n            shape=shape,\n            name=\"moving_variance\",\n            initializer=self.moving_variance_initializer,\n            trainable=False,\n            autocast=False,\n        )\n\n        if self.renorm:\n            # In batch renormalization we track the inference moving stddev\n            # instead of the moving variance to more closely align with the\n            # paper. 
The stddev is initialized as sqrt of the variance\n            # initializer.\n            def moving_stddev_initializer(shape, dtype=None):\n                cloned = _clone_initializer(self.moving_variance_initializer)\n                return ops.sqrt(cloned(shape, dtype=dtype))\n\n            self.moving_stddev = self.add_weight(\n                shape=shape,\n                name=\"moving_stddev\",\n                initializer=moving_stddev_initializer,\n                trainable=False,\n                autocast=False,\n            )\n            # Create variables to maintain the moving mean and standard\n            # deviation. These are used in training and thus are different\n            # from the moving averages above.\n            self.renorm_mean = self.add_weight(\n                shape=shape,\n                name=\"renorm_mean\",\n                initializer=_clone_initializer(self.moving_mean_initializer),\n                trainable=False,\n                autocast=False,\n            )\n            self.renorm_stddev = self.add_weight(\n                shape=shape,\n                name=\"renorm_stddev\",\n                initializer=moving_stddev_initializer,\n                trainable=False,\n                autocast=False,\n            )\n\n        self.input_spec = InputSpec(\n            ndim=len(input_shape), axes={self.axis: input_shape[self.axis]}\n        )\n\n        reduction_axes = list(range(len(input_shape)))\n        del reduction_axes[self.axis]\n        self._reduction_axes = reduction_axes\n\n    def compute_output_shape(self, input_shape):\n        if isinstance(self.axis, int):\n            axes = [self.axis]\n        else:\n            axes = self.axis\n\n        for axis in axes:\n            if axis >= len(input_shape) or axis < -len(input_shape):\n                raise ValueError(\n                    f\"Axis {axis} is out of bounds for \"\n                    f\"input shape {input_shape}. 
\"\n                    f\"Received: axis={self.axis}\"\n                )\n        return input_shape\n\n    def call(self, inputs, training=None, mask=None):\n        # Check if the mask has one less dimension than the inputs.\n        if mask is not None:\n            if len(mask.shape) != len(inputs.shape) - 1:\n                # Raise a value error\n                raise ValueError(\n                    \"The mask provided should be one dimension less \"\n                    \"than the inputs. Received: \"\n                    f\"mask.shape={mask.shape}, inputs.shape={inputs.shape}\"\n                )\n\n        compute_dtype = backend.result_type(inputs.dtype, \"float32\")\n        # BN is prone to overflow with float16/bfloat16 inputs, so we upcast to\n        # float32 for the subsequent computations.\n        inputs = ops.cast(inputs, compute_dtype)\n\n        moving_mean = ops.cast(self.moving_mean, inputs.dtype)\n        moving_variance = ops.cast(self.moving_variance, inputs.dtype)\n\n        if self.scale:\n            gamma = ops.cast(self.gamma, inputs.dtype)\n        else:\n            gamma = None\n\n        if self.center:\n            beta = ops.cast(self.beta, inputs.dtype)\n        else:\n            beta = None\n\n        if training and self.trainable:\n            mean, variance = self._moments(inputs, mask)\n\n            if self.renorm:\n                # Compute renorm corrections (r and d).\n                (\n                    r,\n                    d,\n                    mean,\n                    variance,\n                ) = self._renorm_correction_and_moments(mean, variance)\n\n                # x = x * gamma + beta without renorm, and\n                # (x * r + d) * gamma + beta = x * (r * gamma) + (d * gamma +\n                # beta) with renorm.\n                gamma, beta = self._compose_transforms(\n                    r, d, gamma, beta, inputs.dtype\n                )\n\n                # Update moving statistics.\n   
             self._update_renorm_statistics(mean, variance)\n            else:\n                self.moving_mean.assign(\n                    moving_mean * self.momentum + mean * (1.0 - self.momentum)\n                )\n                self.moving_variance.assign(\n                    moving_variance * self.momentum\n                    + variance * (1.0 - self.momentum)\n                )\n        else:\n            mean = moving_mean\n            variance = moving_variance\n\n        outputs = ops.batch_normalization(\n            x=inputs,\n            mean=mean,\n            variance=variance,\n            axis=self.axis,\n            offset=beta,\n            scale=gamma,\n            epsilon=self.epsilon,\n        )\n        return ops.cast(outputs, self.compute_dtype)\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"axis\": self.axis,\n            \"momentum\": self.momentum,\n            \"epsilon\": self.epsilon,\n            \"center\": self.center,\n            \"scale\": self.scale,\n            \"beta_initializer\": initializers.serialize(self.beta_initializer),\n            \"gamma_initializer\": initializers.serialize(self.gamma_initializer),\n            \"moving_mean_initializer\": initializers.serialize(\n                self.moving_mean_initializer\n            ),\n            \"moving_variance_initializer\": initializers.serialize(\n                self.moving_variance_initializer\n            ),\n            \"beta_regularizer\": regularizers.serialize(self.beta_regularizer),\n            \"gamma_regularizer\": regularizers.serialize(self.gamma_regularizer),\n            \"beta_constraint\": constraints.serialize(self.beta_constraint),\n            \"gamma_constraint\": constraints.serialize(self.gamma_constraint),\n            \"synchronized\": self.synchronized,\n            \"renorm\": self.renorm,\n            \"renorm_clipping\": self.renorm_clipping,\n            
\"renorm_momentum\": self.renorm_momentum,\n        }\n        return {**base_config, **config}\n\n    def _moments(self, inputs, mask):\n        if mask is None:\n            return ops.moments(\n                inputs,\n                axes=self._reduction_axes,\n                synchronized=self.synchronized,\n            )\n\n        mask_weights = ops.cast(mask, inputs.dtype)\n        mask_weights_broadcasted = ops.expand_dims(mask_weights, axis=-1)\n        broadcasted_mask = ops.broadcast_to(\n            mask_weights_broadcasted, ops.shape(inputs)\n        )\n        weighted_inputs = broadcasted_mask * inputs\n\n        weighted_input_sum = ops.sum(\n            weighted_inputs,\n            self._reduction_axes,\n            keepdims=True,\n        )\n        sum_of_weights = ops.sum(\n            broadcasted_mask,\n            self._reduction_axes,\n            keepdims=True,\n        )\n        mean = weighted_input_sum / (sum_of_weights + backend.epsilon())\n\n        difference = weighted_inputs - mean\n        squared_difference = ops.square(difference)\n        weighted_distsq = ops.sum(\n            broadcasted_mask * squared_difference,\n            self._reduction_axes,\n            keepdims=True,\n        )\n        variance = weighted_distsq / (sum_of_weights + backend.epsilon())\n\n        return ops.squeeze(mean), ops.squeeze(variance)\n\n    def _renorm_correction_and_moments(self, mean, variance):\n        \"\"\"Computes the correction for batch renormalization.\n\n        This method computes the r and d correction factors.\n\n        Args:\n            mean: The mean of the current batch.\n            variance: The variance of the current batch.\n\n        Returns:\n            A tuple (r, d, mean, variance) where r and d are the correction\n            factors, and mean/variance are passed through unchanged.\n        \"\"\"\n        stddev = ops.sqrt(variance + self.epsilon)\n\n        # Get the renorm moving statistics.\n        
renorm_mean = ops.cast(self.renorm_mean, mean.dtype)\n        # Avoid divide by zero early on in training.\n        renorm_stddev = ops.maximum(\n            ops.cast(self.renorm_stddev, mean.dtype),\n            ops.sqrt(ops.cast(self.epsilon, mean.dtype)),\n        )\n\n        # Compute the corrections for batch renorm.\n        r = ops.divide(stddev, renorm_stddev)\n        d = ops.divide(ops.subtract(mean, renorm_mean), renorm_stddev)\n\n        # Apply clipping.\n        rmin = self.renorm_clipping.get(\"rmin\")\n        rmax = self.renorm_clipping.get(\"rmax\")\n        dmax = self.renorm_clipping.get(\"dmax\")\n\n        if rmin is not None:\n            r = ops.maximum(r, rmin)\n        if rmax is not None:\n            r = ops.minimum(r, rmax)\n        if dmax is not None:\n            d = ops.clip(d, -dmax, dmax)\n\n        return r, d, mean, variance\n\n    def _compose_transforms(self, r, d, gamma, beta, dtype):\n        \"\"\"Composes the renorm correction with gamma and beta.\n\n        When training with renorm, the normalized values (x) are transformed\n        as: (x * r + d) * gamma + beta = x * (r * gamma) + (d * gamma + beta).\n        This method computes the effective scale and offset.\n\n        Args:\n            r: The r correction factor.\n            d: The d correction factor.\n            gamma: The gamma (scale) parameter, or None.\n            beta: The beta (offset) parameter, or None.\n            dtype: The dtype for the output.\n\n        Returns:\n            A tuple (effective_gamma, effective_beta).\n        \"\"\"\n        r = ops.stop_gradient(r)\n        d = ops.stop_gradient(d)\n\n        if gamma is not None:\n            effective_gamma = ops.multiply(r, gamma)\n            effective_beta = ops.multiply(d, gamma)\n        else:\n            effective_gamma = ops.cast(r, dtype)\n            effective_beta = ops.cast(d, dtype)\n\n        if beta is not None:\n            effective_beta = ops.add(effective_beta, beta)\n\n   
     return effective_gamma, effective_beta\n\n    def _update_renorm_statistics(self, mean, variance):\n        \"\"\"Updates the renorm and moving statistics.\n        Args:\n            mean: The mean of the current batch.\n            variance: The variance of the current batch.\n        \"\"\"\n        stddev = ops.sqrt(variance + self.epsilon)\n\n        # Update renorm moving mean and stddev.\n        renorm_mean = ops.cast(self.renorm_mean, mean.dtype)\n        renorm_stddev = ops.cast(self.renorm_stddev, mean.dtype)\n\n        self.renorm_mean.assign(\n            renorm_mean * self.renorm_momentum\n            + mean * (1.0 - self.renorm_momentum)\n        )\n        self.renorm_stddev.assign(\n            renorm_stddev * self.renorm_momentum\n            + stddev * (1.0 - self.renorm_momentum)\n        )\n\n        moving_mean = ops.cast(self.moving_mean, mean.dtype)\n        moving_stddev = ops.cast(self.moving_stddev, mean.dtype)\n\n        self.moving_mean.assign(\n            moving_mean * self.momentum + mean * (1.0 - self.momentum)\n        )\n\n        new_moving_stddev = moving_stddev * self.momentum + stddev * (\n            1.0 - self.momentum\n        )\n        self.moving_stddev.assign(new_moving_stddev)\n\n        # Derive `moving_variance` from `moving_stddev`, applying ReLU in case\n        # floating point rounding causes it to go negative.\n        self.moving_variance.assign(\n            ops.relu(new_moving_stddev * new_moving_stddev - self.epsilon)\n        )\n"
  },
  {
    "path": "keras/src/layers/normalization/batch_normalization_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import random\nfrom keras.src import testing\nfrom keras.src.losses import MeanSquaredError\nfrom keras.src.models import Model\n\n\nclass BatchNormalizationTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_bn_basics(self):\n        # vector case\n        self.run_layer_test(\n            layers.BatchNormalization,\n            init_kwargs={\n                \"center\": True,\n                \"scale\": True,\n            },\n            call_kwargs={\"training\": True},\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=2,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.BatchNormalization,\n            init_kwargs={\n                \"center\": False,\n                \"scale\": False,\n            },\n            call_kwargs={\"training\": True},\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=2,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n        # image case, with regularizers\n        self.run_layer_test(\n            layers.BatchNormalization,\n            init_kwargs={\n                \"center\": True,\n                \"scale\": True,\n                \"beta_regularizer\": \"l2\",\n                \"gamma_regularizer\": \"l2\",\n            },\n            call_kwargs={\"training\": True},\n            input_shape=(2, 4, 4, 3),\n            
expected_output_shape=(2, 4, 4, 3),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=2,\n            expected_num_seed_generators=0,\n            expected_num_losses=2,  # we have 2 regularizers.\n            supports_masking=True,\n        )\n\n    @parameterized.product(\n        axis=(-1, 1),\n        input_shape=((5, 2, 3), (5, 3, 3, 2)),\n        moving_mean_initializer=(\"zeros\", \"ones\"),\n        moving_variance_initializer=(\"zeros\", \"ones\"),\n    )\n    def test_correctness(\n        self,\n        axis,\n        input_shape,\n        moving_mean_initializer,\n        moving_variance_initializer,\n    ):\n        # Training\n        layer = layers.BatchNormalization(\n            axis=axis,\n            momentum=0,\n            moving_mean_initializer=moving_mean_initializer,\n            moving_variance_initializer=moving_variance_initializer,\n        )\n        # Random data centered on 5.0, variance 10.0\n        x = np.random.normal(loc=5.0, scale=10.0, size=input_shape)\n        out = x\n        for _ in range(3):\n            out = layer(out, training=True)\n\n        # Assert the normalization is correct.\n        broadcast_shape = [1] * len(input_shape)\n        broadcast_shape[axis] = input_shape[axis]\n        out = backend.convert_to_numpy(out)\n        out = out - np.reshape(\n            backend.convert_to_numpy(layer.beta), broadcast_shape\n        )\n        out = out / np.reshape(\n            backend.convert_to_numpy(layer.gamma), broadcast_shape\n        )\n\n        reduction_axes = list(range(len(input_shape)))\n        del reduction_axes[axis]\n        reduction_axes = tuple(reduction_axes)\n        self.assertAllClose(np.mean(out, axis=reduction_axes), 0.0, atol=1e-3)\n        self.assertAllClose(np.std(out, axis=reduction_axes), 1.0, atol=1e-3)\n        self.assertAllClose(layer.moving_mean, 0.0, atol=1e-3)\n        self.assertAllClose(layer.moving_variance, 1.0, atol=1e-3)\n\n     
   # Inference done before training shouldn't match.\n        inference_out = layer(x, training=False)\n        training_out = layer(x, training=True)\n        self.assertNotAllClose(inference_out, training_out)\n\n        # Since momentum is zero, inference after training should match.\n        training_out = layer(x, training=True)\n        inference_out = layer(x, training=False)\n        self.assertAllClose(inference_out, training_out)\n\n        # Masked result with no training should not differ\n        x[:, 1, :] = 0.0\n        unmasked_out = layer(x, training=False)\n        masked = layers.Masking()(x)\n        masked_out = layer(masked, training=False)\n        self.assertAllClose(unmasked_out, masked_out)\n\n        # Masked result should differ from unmasked result\n        unmasked_out = layer(x, training=False)\n        x[:, 1, :] = 0.0\n        masked = layers.Masking()(x)\n        masked_out = layer(masked, training=True)\n        self.assertNotAllClose(unmasked_out, masked_out)\n\n    @parameterized.product(\n        synchronized=(\n            (False, True) if backend.backend() == \"tensorflow\" else (False,)\n        ),\n    )\n    def test_input_fully_masked(self, synchronized):\n        norm = layers.BatchNormalization(\n            scale=False,\n            center=False,\n            synchronized=synchronized,\n        )\n        x = np.zeros((4, 5))\n        mask = np.zeros((4,), dtype=np.float32)\n        y = norm(x, mask=mask, training=True)\n        self.assertAllClose(y, np.zeros_like(x, dtype=np.float32))\n\n    @parameterized.product(run_eagerly=(True, False), mask_value=(0.0, 0.1, 1))\n    @pytest.mark.requires_trainable_backend\n    def test_batchnorm_ignore_masked_values(self, run_eagerly, mask_value):\n        padded_data = np.array(\n            [\n                [\n                    [1, 5],\n                    [2, 5],\n                    [mask_value, mask_value],\n                    [mask_value, mask_value],\n                
]\n                for _ in range(10)\n            ],\n            dtype=\"float32\",\n        )\n\n        inputs = layers.Input((None, 2))\n        masked = layers.Masking(mask_value=mask_value)(inputs)\n        normed = layers.BatchNormalization(momentum=0.0)(masked)\n        model = Model(inputs, normed)\n        loss = MeanSquaredError()\n        model.compile(\n            \"rmsprop\",\n            loss=loss,\n            run_eagerly=run_eagerly,\n        )\n        model.fit(x=padded_data, y=padded_data, batch_size=10, epochs=5)\n        self.assertAllClose(model.layers[2].moving_mean, [1.5, 5.0])\n        self.assertAllClose(model.layers[2].moving_variance, [0.25, 0.0])\n\n    def test_trainable_behavior(self):\n        layer = layers.BatchNormalization(axis=-1, momentum=0.8, epsilon=1e-7)\n        layer.build((1, 4, 4, 3))\n        layer.trainable = False\n        self.assertEqual(len(layer.weights), 4)\n        self.assertEqual(len(layer.trainable_weights), 0)\n        self.assertEqual(len(layer.non_trainable_weights), 4)\n\n        # Random data centered on 5.0, variance 10.0\n        x = np.random.normal(loc=5.0, scale=10.0, size=(200, 4, 4, 3))\n\n        out = layer(x, training=True)\n        self.assertAllClose(out, x)\n\n        layer.trainable = True\n        self.assertEqual(len(layer.weights), 4)\n        self.assertEqual(len(layer.trainable_weights), 2)\n        self.assertEqual(len(layer.non_trainable_weights), 2)\n\n        for _ in range(10):\n            out = layer(x, training=True)\n\n        out = backend.convert_to_numpy(out)\n        out = out - np.reshape(\n            backend.convert_to_numpy(layer.beta), (1, 1, 1, 3)\n        )\n        out = out / np.reshape(\n            backend.convert_to_numpy(layer.gamma), (1, 1, 1, 3)\n        )\n\n        self.assertAllClose(np.mean(out, axis=(0, 1, 2)), 0.0, atol=1e-3)\n        self.assertAllClose(np.std(out, axis=(0, 1, 2)), 1.0, atol=1e-3)\n\n    def 
test_large_value_within_autocast_scope(self):\n        layer = layers.BatchNormalization()\n        layer.build((1, 4, 4, 3))\n        # Use 70000 to trigger overflow for float16\n        large_value = ops.full(layer.moving_variance.shape, 70000)\n        with backend.AutocastScope(\"float16\"):\n            layer.moving_variance.assign(large_value)\n            self.assertAllClose(layer.moving_variance.value, large_value)\n\n    def test_masked_broadcast_normalization(self):\n        input_shape = (1, 2, 3, 4)\n        mask_shape = (1, 2, 1)\n        x = ops.ones(input_shape)\n        mask = ops.ones(mask_shape)\n\n        layer = layers.BatchNormalization(axis=-1, momentum=0.0, epsilon=1e-3)\n\n        y = layer(x, training=True, mask=mask)\n\n        mean_y = ops.mean(y, axis=[0, 1, 2])\n\n        self.assertAllClose(mean_y, ops.zeros((4,)), atol=1e-6)\n        self.assertAllClose(y, ops.zeros_like(y), atol=1e-6)\n\n        self.assertAllClose(layer.moving_mean, ops.ones((4,)), atol=1e-6)\n        self.assertAllClose(layer.moving_variance, ops.zeros((4,)), atol=1e-6)\n\n    @pytest.mark.requires_trainable_backend\n    def test_renorm_basics(self):\n        # Test basic renorm functionality\n        self.run_layer_test(\n            layers.BatchNormalization,\n            init_kwargs={\n                \"center\": True,\n                \"scale\": True,\n                \"renorm\": True,\n            },\n            call_kwargs={\"training\": True},\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=2,\n            # moving_mean, moving_variance, moving_stddev, renorm_mean,\n            # renorm_stddev\n            expected_num_non_trainable_weights=5,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n        # Test renorm with clipping\n        self.run_layer_test(\n            layers.BatchNormalization,\n            
init_kwargs={\n                \"center\": True,\n                \"scale\": True,\n                \"renorm\": True,\n                \"renorm_clipping\": {\"rmax\": 3.0, \"rmin\": 0.3, \"dmax\": 5.0},\n            },\n            call_kwargs={\"training\": True},\n            input_shape=(2, 4, 4, 3),\n            expected_output_shape=(2, 4, 4, 3),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=5,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n\n    def test_renorm_invalid_clipping_keys(self):\n        with self.assertRaisesRegex(ValueError, \"Received invalid keys\"):\n            layers.BatchNormalization(\n                renorm=True, renorm_clipping={\"random_key\": 1.0}\n            )\n        with self.assertRaisesRegex(ValueError, \"rmax should be\"):\n            layers.BatchNormalization(\n                renorm=True, renorm_clipping={\"rmax\": 0.0, \"rmin\": 1.0}\n            )\n        with self.assertRaisesRegex(ValueError, \"dmax should be non-negative\"):\n            layers.BatchNormalization(\n                renorm=True, renorm_clipping={\"rmax\": 1.0, \"dmax\": -1.0}\n            )\n\n    def test_renorm_stddev_initializer(self):\n        # `moving_stddev` and `renorm_stddev` should be initialized as\n        # `sqrt` of `moving_variance_initializer`.\n        layer = layers.BatchNormalization(\n            renorm=True,\n            moving_variance_initializer=initializers.Constant(4.0),\n        )\n        layer.build((None, 5))\n\n        self.assertAllClose(layer.moving_stddev, np.full((5,), 2.0), atol=1e-6)\n        self.assertAllClose(layer.renorm_stddev, np.full((5,), 2.0), atol=1e-6)\n\n    def test_renorm_inference(self):\n        # At inference time, the behaviour of both with and without renorm\n        # should be the same.\n        bn = layers.BatchNormalization(renorm=False)\n        bn_renorm = 
layers.BatchNormalization(renorm=True)\n\n        bn.build((None, 10))\n        bn_renorm.build((None, 10))\n\n        # Copy the vars to renorm layer.\n        for attr in [\"gamma\", \"beta\", \"moving_mean\", \"moving_variance\"]:\n            getattr(bn, attr).assign(random.normal(shape=(10,)))\n            getattr(bn_renorm, attr).assign(getattr(bn, attr))\n\n        x = np.random.normal(size=(4, 10))\n        out = bn(x, training=False)\n        out_renorm = bn_renorm(x, training=False)\n\n        self.assertAllClose(out, out_renorm, atol=1e-5, rtol=1e-5)\n\n    @pytest.mark.requires_trainable_backend\n    def test_renorm_correctness(self):\n        epsilon = 1e-3\n        momentum = 0.9\n        renorm_momentum = 0.8\n\n        # Create layer\n        layer = layers.BatchNormalization(\n            axis=-1,\n            epsilon=epsilon,\n            momentum=momentum,\n            renorm=True,\n            renorm_momentum=renorm_momentum,\n        )\n        layer.build((None, 3))\n\n        # Assign initial values.\n        size = (3,)\n        init_moving_mean = np.random.normal(0.0, 1.0, size=size)\n        init_moving_var = np.abs(np.random.normal(1.0, 0.5, size=size))\n        init_moving_stddev = np.sqrt(init_moving_var)\n        init_renorm_mean = np.random.normal(0.0, 1.0, size=size)\n        init_renorm_stddev = np.abs(np.random.normal(1.0, 0.5, size=size))\n        init_gamma = np.random.normal(1.0, 0.1, size=size)\n        init_beta = np.random.normal(0.0, 0.1, size=size)\n\n        layer.moving_mean.assign(init_moving_mean)\n        layer.moving_variance.assign(init_moving_var)\n        layer.moving_stddev.assign(init_moving_stddev)\n        layer.renorm_mean.assign(init_renorm_mean)\n        layer.renorm_stddev.assign(init_renorm_stddev)\n        layer.gamma.assign(init_gamma)\n        layer.beta.assign(init_beta)\n\n        # Input data\n        x = np.array(\n            [[4.0, 6.0, 2.0], [8.0, -2.0, 5.0], [6.0, 4.0, 3.0]],\n            
dtype=\"float32\",\n        )\n\n        # Manually compute expected output.\n        # Normalise input.\n        batch_mean = np.mean(x, axis=0)\n        batch_var = np.var(x, axis=0)\n        batch_stddev = np.sqrt(batch_var + epsilon)\n        x_norm = (x - batch_mean) / batch_stddev\n\n        # Compute r, d, and then expected output.\n        r = batch_stddev / init_renorm_stddev\n        d = (batch_mean - init_renorm_mean) / init_renorm_stddev\n\n        expected_output = (x_norm * r + d) * init_gamma + init_beta\n        actual_output = layer(x, training=True)\n        self.assertAllClose(actual_output, expected_output, atol=1e-5)\n\n        # Verify moving statistics.\n        expected_renorm_mean = (\n            init_renorm_mean * renorm_momentum\n            + batch_mean * (1 - renorm_momentum)\n        )\n        self.assertAllClose(\n            layer.renorm_mean,\n            expected_renorm_mean,\n            atol=1e-5,\n        )\n        expected_renorm_stddev = (\n            init_renorm_stddev * renorm_momentum\n            + batch_stddev * (1 - renorm_momentum)\n        )\n        self.assertAllClose(\n            layer.renorm_stddev,\n            expected_renorm_stddev,\n            atol=1e-5,\n        )\n        expected_moving_mean = init_moving_mean * momentum + batch_mean * (\n            1 - momentum\n        )\n        self.assertAllClose(\n            layer.moving_mean,\n            expected_moving_mean,\n            atol=1e-5,\n        )\n        expected_moving_stddev = (\n            init_moving_stddev * momentum + batch_stddev * (1 - momentum)\n        )\n        self.assertAllClose(\n            layer.moving_stddev,\n            expected_moving_stddev,\n            atol=1e-5,\n        )\n        expected_moving_var = expected_moving_stddev**2 - epsilon\n        self.assertAllClose(\n            layer.moving_variance,\n            expected_moving_var,\n            atol=1e-5,\n        )\n\n    def test_serialization(self):\n        
layer = layers.BatchNormalization(\n            renorm=True,\n            renorm_clipping={\"rmax\": 3.0, \"rmin\": 0.3, \"dmax\": 5.0},\n            renorm_momentum=0.95,\n        )\n\n        config = layer.get_config()\n        new_layer = layers.BatchNormalization.from_config(config)\n        self.assertEqual(new_layer.get_config(), config)\n"
  },
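  The renorm tests above verify the Batch Renormalization training-time arithmetic by hand. As a standalone sketch of that same computation in plain NumPy (the function name and NumPy-only setting are illustrative, not part of Keras; `r` and `d` follow the manual computation in `test_renorm_correctness`):

  ```python
  import numpy as np


  def batch_renorm_forward(x, renorm_mean, renorm_stddev, gamma, beta, epsilon=1e-3):
      # Per-feature batch statistics, as computed in test_renorm_correctness.
      batch_mean = x.mean(axis=0)
      batch_stddev = np.sqrt(x.var(axis=0) + epsilon)
      x_norm = (x - batch_mean) / batch_stddev
      # Renorm correction: r rescales and d shifts the normalized values
      # toward the running (renorm) statistics; in the layer these are
      # treated as constants during backprop.
      r = batch_stddev / renorm_stddev
      d = (batch_mean - renorm_mean) / renorm_stddev
      return (x_norm * r + d) * gamma + beta
  ```

  When the running statistics equal the batch statistics, `r` is 1 and `d` is 0, so the output reduces to ordinary batch normalization, which is why `test_renorm_inference` expects identical inference behavior with and without renorm.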
  {
    "path": "keras/src/layers/normalization/group_normalization.py",
    "content": "from keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.GroupNormalization\")\nclass GroupNormalization(Layer):\n    \"\"\"Group normalization layer.\n\n    Group Normalization divides the channels into groups and computes\n    within each group the mean and variance for normalization.\n    Empirically, its accuracy is more stable than batch norm in a wide\n    range of small batch sizes, if learning rate is adjusted linearly\n    with batch sizes.\n\n    Relation to Layer Normalization:\n    If the number of groups is set to 1, then this operation becomes nearly\n    identical to Layer Normalization (see Layer Normalization docs for details).\n\n    Relation to Instance Normalization:\n    If the number of groups is set to the input dimension (number of groups is\n    equal to number of channels), then this operation becomes identical to\n    Instance Normalization. You can achieve this via `groups=-1`.\n\n    Args:\n        groups: Integer, the number of groups for Group Normalization. Can be in\n            the range `[1, N]` where N is the input dimension. The input\n            dimension must be divisible by the number of groups.\n            Defaults to 32.\n        axis: Integer or List/Tuple. The axis or axes to normalize across.\n            Typically, this is the features axis/axes. The left-out axes are\n            typically the batch axis/axes. -1 is the last dimension in the\n            input. Defaults to `-1`.\n        epsilon: Small float added to variance to avoid dividing by zero.\n            Defaults to 1e-3.\n        center: If `True`, add offset of `beta` to normalized tensor.\n            If `False`, `beta` is ignored. 
Defaults to `True`.\n        scale: If `True`, multiply by `gamma`. If `False`, `gamma` is not used.\n            When the next layer is linear (e.g. `relu`), this can be\n            disabled since the scaling will be done by the next layer.\n            Defaults to `True`.\n        beta_initializer: Initializer for the beta weight. Defaults to zeros.\n        gamma_initializer: Initializer for the gamma weight. Defaults to ones.\n        beta_regularizer: Optional regularizer for the beta weight. None by\n            default.\n        gamma_regularizer: Optional regularizer for the gamma weight. None by\n            default.\n        beta_constraint: Optional constraint for the beta weight.\n            None by default.\n        gamma_constraint: Optional constraint for the gamma weight. None by\n            default.\n        **kwargs: Base layer keyword arguments (e.g. `name` and `dtype`).\n\n    Input shape:\n        Arbitrary. Use the keyword argument `input_shape` (tuple of integers,\n        does not include the samples axis) when using this layer as the first\n        layer in a model.\n\n    Output shape:\n        Same shape as input.\n\n    Reference:\n\n    - [Yuxin Wu & Kaiming He, 2018](https://arxiv.org/abs/1803.08494)\n    \"\"\"\n\n    def __init__(\n        self,\n        groups=32,\n        axis=-1,\n        epsilon=1e-3,\n        center=True,\n        scale=True,\n        beta_initializer=\"zeros\",\n        gamma_initializer=\"ones\",\n        beta_regularizer=None,\n        gamma_regularizer=None,\n        beta_constraint=None,\n        gamma_constraint=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.supports_masking = True\n        self.groups = groups\n        self.axis = axis\n        self.epsilon = epsilon\n        self.center = center\n        self.scale = scale\n        self.beta_initializer = initializers.get(beta_initializer)\n        self.gamma_initializer = initializers.get(gamma_initializer)\n        self.beta_regularizer = regularizers.get(beta_regularizer)\n        self.gamma_regularizer = regularizers.get(gamma_regularizer)\n        self.beta_constraint = constraints.get(beta_constraint)\n        self.gamma_constraint = constraints.get(gamma_constraint)\n\n    def build(self, input_shape):\n        dim = input_shape[self.axis]\n\n        if dim is None:\n            raise ValueError(\n                f\"Axis {self.axis} of input tensor should have a defined \"\n                \"dimension but the layer received an input with shape \"\n                f\"{input_shape}.\"\n            )\n\n        if self.groups == -1:\n            self.groups = dim\n\n        if dim < self.groups:\n            raise ValueError(\n                f\"Number of groups ({self.groups}) cannot be more than the \"\n                f\"number of channels ({dim}).\"\n            )\n\n        if dim % self.groups != 0:\n            raise ValueError(\n                f\"Number of groups ({self.groups}) must be a multiple \"\n                f\"of the number of channels ({dim}).\"\n            )\n\n        self.input_spec = InputSpec(\n           
 ndim=len(input_shape), axes={self.axis: dim}\n        )\n\n        if self.scale:\n            self.gamma = self.add_weight(\n                shape=(dim,),\n                name=\"gamma\",\n                initializer=self.gamma_initializer,\n                regularizer=self.gamma_regularizer,\n                constraint=self.gamma_constraint,\n            )\n        else:\n            self.gamma = None\n\n        if self.center:\n            self.beta = self.add_weight(\n                shape=(dim,),\n                name=\"beta\",\n                initializer=self.beta_initializer,\n                regularizer=self.beta_regularizer,\n                constraint=self.beta_constraint,\n            )\n        else:\n            self.beta = None\n\n        super().build(input_shape)\n\n    def call(self, inputs):\n        reshaped_inputs = self._reshape_into_groups(inputs)\n        normalized_inputs = self._apply_normalization(\n            reshaped_inputs, inputs.shape\n        )\n        return ops.reshape(normalized_inputs, ops.shape(inputs))\n\n    def _reshape_into_groups(self, inputs):\n        input_shape = ops.shape(inputs)\n        group_shape = list(inputs.shape)\n        group_shape[0] = -1\n        for i, e in enumerate(group_shape[1:]):\n            if e is None:\n                group_shape[i + 1] = input_shape[i + 1]\n\n        group_shape[self.axis] = input_shape[self.axis] // self.groups\n        group_shape.insert(self.axis, self.groups)\n        reshaped_inputs = ops.reshape(inputs, group_shape)\n        return reshaped_inputs\n\n    def _apply_normalization(self, reshaped_inputs, input_shape):\n        inputs_dtype = reshaped_inputs.dtype\n        compute_dtype = backend.result_type(inputs_dtype, \"float32\")\n        # GN is prone to overflow with float16/bfloat16 inputs, so we upcast to\n        # float32 for the subsequent computations.\n        reshaped_inputs = ops.cast(reshaped_inputs, compute_dtype)\n\n        group_reduction_axes = 
list(range(1, len(reshaped_inputs.shape)))\n\n        axis = -2 if self.axis == -1 else self.axis - 1\n        group_reduction_axes.pop(axis)\n\n        broadcast_shape = self._create_broadcast_shape(input_shape)\n        mean, variance = ops.moments(\n            reshaped_inputs, axes=group_reduction_axes, keepdims=True\n        )\n\n        # Compute the batch normalization.\n        inv = ops.rsqrt(variance + self.epsilon)\n        if self.scale:\n            gamma = ops.reshape(self.gamma, broadcast_shape)\n            gamma = ops.cast(gamma, reshaped_inputs.dtype)\n            inv = inv * gamma\n\n        res = -mean * inv\n        if self.center:\n            beta = ops.reshape(self.beta, broadcast_shape)\n            beta = ops.cast(beta, reshaped_inputs.dtype)\n            res = res + beta\n\n        normalized_inputs = reshaped_inputs * inv + res\n        normalized_inputs = ops.cast(normalized_inputs, inputs_dtype)\n\n        return normalized_inputs\n\n    def _create_broadcast_shape(self, input_shape):\n        broadcast_shape = [1] * len(input_shape)\n        broadcast_shape[self.axis] = input_shape[self.axis] // self.groups\n        broadcast_shape.insert(self.axis, self.groups)\n        return broadcast_shape\n\n    def compute_output_shape(self, input_shape):\n        if isinstance(self.axis, int):\n            axes = [self.axis]\n        else:\n            axes = self.axis\n\n        for axis in axes:\n            if axis >= len(input_shape) or axis < -len(input_shape):\n                raise ValueError(\n                    f\"Axis {axis} is out of bounds for \"\n                    f\"input shape {input_shape}. 
\"\n                    f\"Received: axis={self.axis}\"\n                )\n        return input_shape\n\n    def get_config(self):\n        config = {\n            \"groups\": self.groups,\n            \"axis\": self.axis,\n            \"epsilon\": self.epsilon,\n            \"center\": self.center,\n            \"scale\": self.scale,\n            \"beta_initializer\": initializers.serialize(self.beta_initializer),\n            \"gamma_initializer\": initializers.serialize(self.gamma_initializer),\n            \"beta_regularizer\": regularizers.serialize(self.beta_regularizer),\n            \"gamma_regularizer\": regularizers.serialize(self.gamma_regularizer),\n            \"beta_constraint\": constraints.serialize(self.beta_constraint),\n            \"gamma_constraint\": constraints.serialize(self.gamma_constraint),\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
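  The `_reshape_into_groups` / `_apply_normalization` logic above can be sketched in plain NumPy for the channels-last (`axis=-1`) case (function name and NumPy-only setting are illustrative; `gamma`/`beta` omitted):

  ```python
  import numpy as np


  def group_norm_channels_last(x, groups, epsilon=1e-3):
      # Split the channel axis into (groups, channels_per_group), then
      # normalize over every axis except batch and group -- mirroring the
      # layer's reshape followed by ops.moments over the reduction axes.
      n, *spatial, c = x.shape
      g = x.reshape(n, *spatial, groups, c // groups)
      axes = tuple(range(1, g.ndim - 2)) + (g.ndim - 1,)
      mean = g.mean(axis=axes, keepdims=True)
      var = g.var(axis=axes, keepdims=True)
      g = (g - mean) / np.sqrt(var + epsilon)
      return g.reshape(x.shape)
  ```

  With `groups=1` this normalizes over all non-batch axes (close to layer normalization without `gamma`/`beta`); with `groups` equal to the channel count, each channel is normalized independently (instance normalization).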
  {
    "path": "keras/src/layers/normalization/group_normalization_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import constraints\nfrom keras.src import layers\nfrom keras.src import regularizers\nfrom keras.src import testing\n\n\nclass GroupNormalizationTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_groupnorm(self):\n        self.run_layer_test(\n            layers.GroupNormalization,\n            init_kwargs={\n                \"gamma_regularizer\": regularizers.L2(0.01),\n                \"beta_regularizer\": regularizers.L2(0.01),\n            },\n            input_shape=(3, 4, 32),\n            expected_output_shape=(3, 4, 32),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=2,\n            supports_masking=True,\n        )\n\n        self.run_layer_test(\n            layers.GroupNormalization,\n            init_kwargs={\n                \"groups\": 4,\n                \"gamma_constraint\": constraints.UnitNorm(),\n                \"beta_constraint\": constraints.UnitNorm(),\n            },\n            input_shape=(3, 4, 4),\n            expected_output_shape=(3, 4, 4),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n\n    def test_undefined_dim_error(self):\n        inputs = layers.Input(shape=(2, 2, 2, None))\n        layer = layers.GroupNormalization()\n        with self.assertRaisesRegex(\n            ValueError,\n            (\n                \"input tensor should have a defined dimension but the layer \"\n                \"received an input with shape\"\n            ),\n        ):\n            _ = layer(inputs)\n\n    def test_groups_bigger_than_dim_error(self):\n        inputs = np.ones(shape=(2, 2, 2, 4))\n        layer = 
layers.GroupNormalization(groups=5)\n        with self.assertRaisesRegex(\n            ValueError,\n            \"cannot be more than the number of channels\",\n        ):\n            _ = layer(inputs)\n\n    def test_groups_not_a_multiple_of_dim_error(self):\n        inputs = np.ones(shape=(2, 2, 2, 4))\n        layer = layers.GroupNormalization(groups=3)\n        with self.assertRaisesRegex(\n            ValueError,\n            \"must be a multiple of the number of channels\",\n        ):\n            _ = layer(inputs)\n\n    def test_groups_instance_norm(self):\n        # GroupNormalization with groups=-1 will become InstanceNormalization\n        instance_norm_layer_1 = layers.GroupNormalization(\n            groups=-1, axis=-1, scale=False, center=False\n        )\n        instance_norm_layer_2 = layers.GroupNormalization(\n            groups=4, axis=-1, scale=False, center=False\n        )\n        inputs = np.array([[[-1.0, 1.0, 0, 2.0], [1.0, 3.0, -4, -2.0]]])\n\n        outputs_1 = instance_norm_layer_1(inputs)\n        outputs_2 = instance_norm_layer_2(inputs)\n\n        self.assertAllClose(outputs_1, outputs_2)\n\n    def test_correctness_instance_norm(self):\n        instance_norm_layer = layers.GroupNormalization(\n            groups=4, axis=-1, scale=False, center=False\n        )\n\n        inputs = np.array([[[-1.0, 1.0, 0, 2.0], [1.0, 3.0, -4, -2.0]]])\n\n        expected_instance_norm_output = np.array(\n            [[[-1.0, -1.0, 1.0, 1.0], [1.0, 1.0, -1.0, -1.0]]]\n        )\n\n        self.assertAllClose(\n            instance_norm_layer(inputs),\n            expected_instance_norm_output,\n            atol=1e-3,\n        )\n\n    def test_correctness_1d(self):\n        layer_with_1_group = layers.GroupNormalization(\n            groups=1, axis=-1, scale=False, center=False\n        )\n        layer_with_2_groups = layers.GroupNormalization(\n            groups=2, axis=1, scale=False, center=False\n        )\n\n        inputs = 
np.array([[-1.0, -1.0, 1.0, 1.0, 2.0, 2.0, 0, -2.0]])\n\n        expected_output_1_group = np.array(\n            [[-0.898, -0.898, 0.539, 0.539, 1.257, 1.257, -0.180, -1.616]],\n        )\n        self.assertAllClose(\n            layer_with_1_group(inputs),\n            expected_output_1_group,\n            atol=1e-3,\n        )\n\n        expected_output_2_groups = np.array(\n            [[-1.0, -1.0, 1.0, 1.0, 0.904, 0.904, -0.301, -1.507]]\n        )\n        self.assertAllClose(\n            layer_with_2_groups(inputs),\n            expected_output_2_groups,\n            atol=1e-3,\n        )\n\n    def test_correctness_2d(self):\n        layer_with_1_group = layers.GroupNormalization(\n            groups=1, axis=-1, scale=False, center=False\n        )\n        layer_with_2_groups = layers.GroupNormalization(\n            groups=2, axis=2, scale=False, center=False\n        )\n\n        inputs = np.array([[[-1.0, -1.0, 2.0, 2.0], [1.0, 1.0, 0, -2.0]]])\n\n        expected_output_1_group = np.array(\n            [[[-0.898, -0.898, 1.257, 1.257], [0.539, 0.539, -0.180, -1.616]]]\n        )\n\n        self.assertAllClose(\n            layer_with_1_group(inputs),\n            expected_output_1_group,\n            atol=1e-3,\n        )\n\n        expected_output_2_groups = np.array(\n            [[[-1.0, -1.0, 0.904, 0.904], [1.0, 1.0, -0.301, -1.507]]]\n        )\n        self.assertAllClose(\n            layer_with_2_groups(inputs),\n            expected_output_2_groups,\n            atol=1e-3,\n        )\n\n    def test_broadcasting_2d_channels_first(self):\n        x = np.arange(16).reshape((1, 4, 2, 2)).astype(\"float32\")\n        x = layers.GroupNormalization(groups=2, axis=1)(x)\n        self.assertAllClose(\n            x,\n            np.array(\n                [\n                    [\n                        [[-1.5274, -1.0910], [-0.6546, -0.2182]],\n                        [[0.2182, 0.6546], [1.0910, 1.5274]],\n                        [[-1.5274, 
-1.0910], [-0.6546, -0.2182]],\n                        [[0.2182, 0.6546], [1.0910, 1.5274]],\n                    ]\n                ]\n            ),\n            atol=1e-3,\n        )\n"
  },
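  `test_correctness_instance_norm` above hand-computes the expected per-channel output. The same expectation can be reproduced with a small NumPy helper (the function name is illustrative): normalizing each channel over the non-batch, non-channel axes is exactly what `GroupNormalization` does when `groups` equals the channel count, i.e. `groups=-1`.

  ```python
  import numpy as np


  def instance_norm_channels_last(x, epsilon=1e-3):
      # Normalize each channel independently over the non-batch,
      # non-channel axes -- equivalent to GroupNormalization with
      # groups == number of channels (scale=False, center=False).
      axes = tuple(range(1, x.ndim - 1))
      mean = x.mean(axis=axes, keepdims=True)
      var = x.var(axis=axes, keepdims=True)
      return (x - mean) / np.sqrt(var + epsilon)
  ```

  Applied to the fixture from the test, this yields the `expected_instance_norm_output` array up to the `atol=1e-3` tolerance introduced by `epsilon`.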
  {
    "path": "keras/src/layers/normalization/layer_normalization.py",
    "content": "import warnings\n\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.LayerNormalization\")\nclass LayerNormalization(Layer):\n    \"\"\"Layer normalization layer (Ba et al., 2016).\n\n    Normalize the activations of the previous layer for each given example in a\n    batch independently, rather than across a batch like Batch Normalization.\n    i.e. applies a transformation that maintains the mean activation within each\n    example close to 0 and the activation standard deviation close to 1.\n\n    If `scale` or `center` are enabled, the layer will scale the normalized\n    outputs by broadcasting them with a trainable variable `gamma`, and center\n    the outputs by broadcasting with a trainable variable `beta`. `gamma` will\n    default to a ones tensor and `beta` will default to a zeros tensor, so that\n    centering and scaling are no-ops before training has begun.\n\n    So, with scaling and centering enabled the normalization equations\n    are as follows:\n\n    Let the intermediate activations for a mini-batch to be the `inputs`.\n\n    For each sample `x_i` in `inputs` with `k` features, we compute the mean and\n    variance of the sample:\n\n    ```python\n    mean_i = sum(x_i[j] for j in range(k)) / k\n    var_i = sum((x_i[j] - mean_i) ** 2 for j in range(k)) / k\n    ```\n\n    and then compute a normalized `x_i_normalized`, including a small factor\n    `epsilon` for numerical stability.\n\n    ```python\n    x_i_normalized = (x_i - mean_i) / sqrt(var_i + epsilon)\n    ```\n\n    And finally `x_i_normalized ` is linearly transformed by `gamma` and `beta`,\n    which are learned parameters:\n\n    ```python\n    output_i = x_i_normalized * gamma + beta\n    ```\n\n    `gamma` and `beta` will span the axes of `inputs` specified in 
`axis`, and\n    this part of the inputs' shape must be fully defined.\n\n    For example:\n\n    >>> layer = keras.layers.LayerNormalization(axis=[1, 2, 3])\n    >>> layer.build([5, 20, 30, 40])\n    >>> print(layer.beta.shape)\n    (20, 30, 40)\n    >>> print(layer.gamma.shape)\n    (20, 30, 40)\n\n    Note that other implementations of layer normalization may choose to define\n    `gamma` and `beta` over a separate set of axes from the axes being\n    normalized across. For example, Group Normalization\n    ([Wu et al. 2018](https://arxiv.org/abs/1803.08494)) with group size of 1\n    corresponds to a Layer Normalization that normalizes across height, width,\n    and channel and has `gamma` and `beta` span only the channel dimension.\n    So, this Layer Normalization implementation will not match a Group\n    Normalization layer with group size set to 1.\n\n    Args:\n        axis: Integer or List/Tuple. The axis or axes to normalize across.\n            Typically, this is the features axis/axes. The left-out axes are\n            typically the batch axis/axes. `-1` is the last dimension in the\n            input. Defaults to `-1`.\n        epsilon: Small float added to variance to avoid dividing by zero.\n            Defaults to 1e-3.\n        center: If True, add offset of `beta` to normalized tensor. If False,\n            `beta` is ignored. Defaults to `True`.\n        scale: If True, multiply by `gamma`. If False, `gamma` is not used.\n            When the next layer is linear (also e.g. `nn.relu`), this can be\n            disabled since the scaling will be done by the next layer.\n            Defaults to `True`.\n        beta_initializer: Initializer for the beta weight. Defaults to zeros.\n        gamma_initializer: Initializer for the gamma weight. 
Defaults to ones.\n        beta_regularizer: Optional regularizer for the beta weight.\n            None by default.\n        gamma_regularizer: Optional regularizer for the gamma weight.\n            None by default.\n        beta_constraint: Optional constraint for the beta weight.\n            None by default.\n        gamma_constraint: Optional constraint for the gamma weight.\n            None by default.\n        **kwargs: Base layer keyword arguments (e.g. `name` and `dtype`).\n\n\n    Reference:\n\n    - [Lei Ba et al., 2016](https://arxiv.org/abs/1607.06450).\n    \"\"\"\n\n    def __init__(\n        self,\n        axis=-1,\n        epsilon=1e-3,\n        center=True,\n        scale=True,\n        beta_initializer=\"zeros\",\n        gamma_initializer=\"ones\",\n        beta_regularizer=None,\n        gamma_regularizer=None,\n        beta_constraint=None,\n        gamma_constraint=None,\n        **kwargs,\n    ):\n        rms_scaling = kwargs.pop(\"rms_scaling\", False)\n        if rms_scaling:\n            warnings.warn(\n                \"You passed `rms_scaling=True`, which is deprecated. This \"\n                \"argument incorrectly scales the input by the variance, not \"\n                \"the root mean square. 
To correctly use RMS Normalization, \"\n                \"please use `keras.layers.RMSNormalization` instead.\"\n            )\n\n        super().__init__(**kwargs)\n        if isinstance(axis, (list, tuple)):\n            self.axis = list(axis)\n        elif isinstance(axis, int):\n            self.axis = axis\n        else:\n            raise TypeError(\n                \"Expected an int or a list/tuple of ints for the \"\n                \"argument 'axis', but received: %r\" % axis\n            )\n\n        self.epsilon = epsilon\n        self.center = center\n        self.scale = scale\n        self.rms_scaling = rms_scaling\n        self.beta_initializer = initializers.get(beta_initializer)\n        self.gamma_initializer = initializers.get(gamma_initializer)\n        self.beta_regularizer = regularizers.get(beta_regularizer)\n        self.gamma_regularizer = regularizers.get(gamma_regularizer)\n        self.beta_constraint = constraints.get(beta_constraint)\n        self.gamma_constraint = constraints.get(gamma_constraint)\n\n        self.supports_masking = True\n        self.autocast = False\n\n    def build(self, input_shape):\n        if isinstance(self.axis, (list, tuple)):\n            self.axis = sorted(self.axis)\n            shape = tuple(input_shape[dim] for dim in self.axis)\n        else:\n            shape = (input_shape[self.axis],)\n            self.axis = [self.axis]\n        if self.scale or self.rms_scaling:\n            self.gamma = self.add_weight(\n                name=\"gamma\",\n                shape=shape,\n                initializer=self.gamma_initializer,\n                regularizer=self.gamma_regularizer,\n                constraint=self.gamma_constraint,\n                trainable=True,\n                autocast=False,\n            )\n        else:\n            self.gamma = None\n\n        if self.center and not self.rms_scaling:\n            self.beta = self.add_weight(\n                name=\"beta\",\n                
shape=shape,\n                initializer=self.beta_initializer,\n                regularizer=self.beta_regularizer,\n                constraint=self.beta_constraint,\n                trainable=True,\n                autocast=False,\n            )\n        else:\n            self.beta = None\n\n    def call(self, inputs):\n        outputs = ops.layer_normalization(\n            inputs,\n            self.gamma,\n            self.beta,\n            self.axis,\n            self.epsilon,\n            rms_scaling=self.rms_scaling,\n        )\n        return ops.cast(outputs, self.compute_dtype)\n\n    def compute_output_shape(self, input_shape):\n        if isinstance(self.axis, int):\n            axes = [self.axis]\n        else:\n            axes = self.axis\n\n        for axis in axes:\n            if axis >= len(input_shape) or axis < -len(input_shape):\n                raise ValueError(\n                    f\"Axis {axis} is out of bounds for \"\n                    f\"input shape {input_shape}. \"\n                    f\"Received: axis={self.axis}\"\n                )\n        return input_shape\n\n    def get_config(self):\n        config = {\n            \"axis\": self.axis,\n            \"epsilon\": self.epsilon,\n            \"center\": self.center,\n            \"scale\": self.scale,\n            \"rms_scaling\": self.rms_scaling,\n            \"beta_initializer\": initializers.serialize(self.beta_initializer),\n            \"gamma_initializer\": initializers.serialize(self.gamma_initializer),\n            \"beta_regularizer\": regularizers.serialize(self.beta_regularizer),\n            \"gamma_regularizer\": regularizers.serialize(self.gamma_regularizer),\n            \"beta_constraint\": constraints.serialize(self.beta_constraint),\n            \"gamma_constraint\": constraints.serialize(self.gamma_constraint),\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/normalization/layer_normalization_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src import testing\n\n\nclass LayerNormalizationTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_ln_basics(self):\n        self.run_layer_test(\n            layers.LayerNormalization,\n            init_kwargs={\n                \"gamma_regularizer\": regularizers.L2(0.01),\n                \"beta_regularizer\": regularizers.L2(0.01),\n            },\n            input_shape=(3, 4, 2),\n            expected_output_shape=(3, 4, 2),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=2,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.LayerNormalization,\n            init_kwargs={\n                \"gamma_initializer\": \"ones\",\n                \"beta_initializer\": \"ones\",\n            },\n            input_shape=(3, 4, 2),\n            expected_output_shape=(3, 4, 2),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.LayerNormalization,\n            init_kwargs={\"scale\": False, \"center\": False},\n            input_shape=(3, 3),\n            expected_output_shape=(3, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.LayerNormalization,\n            init_kwargs={\"rms_scaling\": True},\n            input_shape=(3, 3),\n        
    expected_output_shape=(3, 3),\n            expected_num_trainable_weights=1,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.LayerNormalization,\n            init_kwargs={\"axis\": (-3, -2, -1)},\n            input_shape=(2, 8, 8, 3),\n            expected_output_shape=(2, 8, 8, 3),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.LayerNormalization,\n            init_kwargs={},\n            input_shape=(1, 0, 10),\n            expected_output_shape=(1, 0, 10),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n\n    def test_invalid_axis(self):\n        with self.assertRaisesRegex(\n            TypeError,\n            (\"Expected an int or a list/tuple of ints for the argument 'axis'\"),\n        ):\n            layers.LayerNormalization(axis={\"axis\": -1})\n\n    def test_correctness(self):\n        layer = layers.LayerNormalization(dtype=\"float32\")\n        layer.build(input_shape=(2, 2, 2))\n        inputs = np.random.normal(\n            loc=5.0, scale=10.0, size=(1000, 2, 2, 2)\n        ).astype(\"float32\")\n\n        out = layer(inputs)\n        out = ops.subtract(out, layer.beta)\n        out = ops.divide(out, layer.gamma)\n\n        self.assertAllClose(ops.mean(out), 0.0, atol=1e-1)\n        self.assertAllClose(ops.std(out), 1.0, atol=1e-1)\n\n    def test_output(self):\n        layer = layers.LayerNormalization(\n            dtype=\"float32\",\n            
beta_initializer=\"ones\",\n            gamma_initializer=\"ones\",\n        )\n        inputs = np.arange(5).astype(\"float32\")[None, :]\n        out = layer(inputs)\n        self.assertAllClose(out, [[-0.41386, 0.29307, 1.0, 1.70693, 2.41386]])\n\n    def test_output_with_rms_scaling(self):\n        layer = layers.LayerNormalization(\n            dtype=\"float32\",\n            rms_scaling=True,\n            gamma_initializer=\"ones\",\n        )\n        inputs = np.arange(5).astype(\"float32\")[None, :]\n        out = layer(inputs)\n        self.assertAllClose(out, [[0.0, 0.70693, 1.41386, 2.12079, 2.82772]])\n\n    def test_large_value_within_autocast_scope(self):\n        layer = layers.LayerNormalization()\n        layer.build((1, 4, 4, 3))\n        # Use 70000 to trigger overflow for float16\n        large_value = ops.full(layer.gamma.shape, 70000)\n        with backend.AutocastScope(\"float16\"):\n            layer.gamma.assign(large_value)\n            self.assertAllClose(layer.gamma.value, large_value)\n\n    def test_unsorted_axis(self):\n        x = np.random.randn(2, 3, 4).astype(\"float32\")\n        layer_sorted = layers.LayerNormalization(axis=[-2, -1])\n        layer_unsorted = layers.LayerNormalization(axis=[-1, -2])\n        out_sorted = layer_sorted(x)\n        out_unsorted = layer_unsorted(x)\n        self.assertEqual(out_sorted.shape, (2, 3, 4))\n        self.assertEqual(out_unsorted.shape, (2, 3, 4))\n        self.assertAllClose(out_sorted, out_unsorted)\n"
  },
  {
    "path": "keras/src/layers/normalization/rms_normalization.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.RMSNormalization\")\nclass RMSNormalization(Layer):\n    \"\"\"Root Mean Square (RMS) Normalization layer.\n\n    This layer normalizes the input tensor based on its RMS value.\n\n    The Keras layer performs the operation as described in\n    [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)\n    by Biao Zhang et al.\n\n    If `scale` is enabled, the layer will scale the normalized outputs via\n    a learnable scaling factor.\n\n    So, with scaling enabled, the normalization equation\n    is as follows:\n\n    Let the intermediate activations for a mini-batch be the `inputs`.\n\n    ```python\n    rms_normalization(x) = x * rsqrt(mean(square(x)) + epsilon) * scale\n    ```\n\n    For example:\n\n    >>> layer = keras.layers.RMSNormalization()\n    >>> layer.build([5, 20, 30, 10])\n    >>> print(layer.scale.shape)\n    (10,)\n    >>> layer(np.random.rand(1, 10)).numpy()\n    array([[0.35098287, 1.0495652 , 1.4645109 , 1.2944688 , 0.31124955,\n            1.2768592 , 1.184331  , 0.17474432, 0.49955517, 1.2428929 ]],\n        dtype=float32)\n\n    Args:\n        axis: int. The axis on which to perform the normalization.\n        epsilon: float. 
A small number to add to avoid division by zero.\n    \"\"\"\n\n    def __init__(self, axis=-1, epsilon=1e-6, **kwargs):\n        super().__init__(**kwargs)\n        self.axis = axis\n        self.epsilon = epsilon\n\n    def build(self, input_shape):\n        if isinstance(self.axis, (list, tuple)):\n            self.axis = sorted(self.axis)\n            shape = tuple(input_shape[dim] for dim in self.axis)\n        else:\n            shape = (input_shape[self.axis],)\n            self.axis = [self.axis]\n\n        self.scale = self.add_weight(\n            name=\"scale\", shape=shape, initializer=\"ones\"\n        )\n\n        self.built = True\n\n    def call(self, x):\n        \"\"\"Applies RMS normalization to the input tensor.\n\n        Args:\n            x: Input tensor of shape (batch_size, input_dim).\n\n        Returns:\n            The RMS-normalized tensor of the same shape (batch_size, input_dim),\n            scaled by the learned `scale` parameter.\n        \"\"\"\n        return ops.rms_normalization(\n            x, scale=self.scale, axis=self.axis, epsilon=self.epsilon\n        )\n\n    def compute_output_shape(self, input_shape):\n        if isinstance(self.axis, int):\n            axes = [self.axis]\n        else:\n            axes = self.axis\n\n        for axis in axes:\n            if axis >= len(input_shape) or axis < -len(input_shape):\n                raise ValueError(\n                    f\"Axis {axis} is out of bounds for \"\n                    f\"input shape {input_shape}. \"\n                    f\"Received: axis={self.axis}\"\n                )\n        return input_shape\n\n    def get_config(self):\n        config = {\n            \"axis\": self.axis,\n            \"epsilon\": self.epsilon,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/normalization/rms_normalization_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass RMSNormalizationTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_rms_basics(self):\n        self.run_layer_test(\n            layers.RMSNormalization,\n            init_kwargs={},\n            input_shape=(4, 2),\n            expected_output_shape=(4, 2),\n            expected_num_trainable_weights=1,\n            expected_num_seed_generators=0,\n        )\n        self.run_layer_test(\n            layers.RMSNormalization,\n            init_kwargs={\n                \"axis\": -1,\n            },\n            input_shape=(4, 2),\n            expected_output_shape=(4, 2),\n            expected_num_trainable_weights=1,\n            expected_num_seed_generators=0,\n        )\n\n    def test_correctness(self):\n        layer = layers.RMSNormalization()\n        layer.build(input_shape=(2, 2, 2))\n        inputs = np.random.normal(\n            loc=5.0, scale=10.0, size=(1000, 2, 2, 2)\n        ).astype(\"float32\")\n\n        inputs = ops.convert_to_tensor(inputs)\n\n        out = layer(inputs)\n        expected = ops.multiply(\n            ops.multiply(\n                inputs,\n                ops.rsqrt(ops.mean(ops.square(inputs), axis=-1, keepdims=True)),\n            ),\n            layer.scale,\n        )\n\n        self.assertAllClose(out, expected, atol=1e-1)\n\n    def test_output(self):\n        layer = layers.RMSNormalization()\n        inputs = np.arange(10).astype(\"float32\")[None, :]\n        out = layer(inputs)\n        self.assertAllClose(\n            out,\n            [\n                [\n                    0.0,\n                    0.18731716,\n                    0.37463433,\n                    0.5619515,\n                    0.74926865,\n                    0.9365858,\n                    1.123903,\n                    1.3112202,\n                    1.4985373,\n   
                 1.6858544,\n                ]\n            ],\n        )\n\n    def test_unsorted_axis(self):\n        x = np.random.randn(2, 3, 4).astype(\"float32\")\n        layer_sorted = layers.RMSNormalization(axis=[-2, -1])\n        layer_unsorted = layers.RMSNormalization(axis=[-1, -2])\n        out_sorted = layer_sorted(x)\n        out_unsorted = layer_unsorted(x)\n        self.assertEqual(out_sorted.shape, (2, 3, 4))\n        self.assertEqual(out_unsorted.shape, (2, 3, 4))\n        self.assertAllClose(out_sorted, out_unsorted)\n"
  },
  {
    "path": "keras/src/layers/normalization/spectral_normalization.py",
    "content": "from keras.src import initializers\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers import Wrapper\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.utils.numerical_utils import normalize\n\n\n@keras_export(\"keras.layers.SpectralNormalization\")\nclass SpectralNormalization(Wrapper):\n    \"\"\"Performs spectral normalization on the weights of a target layer.\n\n    This wrapper controls the Lipschitz constant of the weights of a layer by\n    constraining their spectral norm, which can stabilize the training of GANs.\n\n    Args:\n        layer: A `keras.layers.Layer` instance that\n            has either a `kernel` (e.g. `Conv2D`, `Dense`...)\n            or an `embeddings` attribute (`Embedding` layer).\n        power_iterations: int, the number of iterations during normalization.\n        **kwargs: Base wrapper keyword arguments.\n\n    Examples:\n\n    Wrap `keras.layers.Conv2D`:\n    >>> x = np.random.rand(1, 10, 10, 1)\n    >>> conv2d = SpectralNormalization(keras.layers.Conv2D(2, 2))\n    >>> y = conv2d(x)\n    >>> y.shape\n    (1, 9, 9, 2)\n\n    Wrap `keras.layers.Dense`:\n    >>> x = np.random.rand(1, 10, 10, 1)\n    >>> dense = SpectralNormalization(keras.layers.Dense(10))\n    >>> y = dense(x)\n    >>> y.shape\n    (1, 10, 10, 10)\n\n    Reference:\n\n    - [Spectral Normalization for GAN](https://arxiv.org/abs/1802.05957).\n    \"\"\"\n\n    def __init__(self, layer, power_iterations=1, **kwargs):\n        super().__init__(layer, **kwargs)\n        if power_iterations <= 0:\n            raise ValueError(\n                \"`power_iterations` should be greater than zero. 
Received: \"\n                f\"`power_iterations={power_iterations}`\"\n            )\n        self.power_iterations = power_iterations\n\n    def build(self, input_shape):\n        super().build(input_shape)\n        self.input_spec = InputSpec(min_ndim=1, axes={-1: input_shape[-1]})\n\n        if hasattr(self.layer, \"kernel\"):\n            self.kernel = self.layer.kernel\n        elif hasattr(self.layer, \"embeddings\"):\n            self.kernel = self.layer.embeddings\n        else:\n            raise ValueError(\n                f\"{type(self.layer).__name__} object has no attribute 'kernel' \"\n                \"nor 'embeddings'\"\n            )\n\n        self.kernel_shape = self.kernel.shape\n\n        self.vector_u = self.add_weight(\n            shape=(1, self.kernel_shape[-1]),\n            initializer=initializers.TruncatedNormal(stddev=0.02),\n            trainable=False,\n            name=\"vector_u\",\n            dtype=self.kernel.dtype,\n        )\n\n    def call(self, inputs, training=False):\n        if training:\n            new_vector_u, new_kernel = ops.cond(\n                ops.all(ops.equal(self.kernel.value, 0)),\n                lambda: (self.vector_u.value, self.kernel.value),\n                self.normalized_weights,\n            )\n            self.vector_u.assign(new_vector_u)\n            self.kernel.assign(new_kernel)\n\n        output = self.layer(inputs)\n        return ops.cast(output, inputs.dtype)\n\n    def compute_output_shape(self, input_shape):\n        return self.layer.compute_output_shape(input_shape)\n\n    def normalized_weights(self):\n        \"\"\"Generate spectral normalized weights.\n\n        This method returns the updated value for `self.kernel` with the\n        spectral normalized value, so that the layer is ready for `call()`.\n        \"\"\"\n\n        weights = ops.reshape(self.kernel, [-1, self.kernel_shape[-1]])\n        vector_u = self.vector_u.value\n\n        for _ in 
range(self.power_iterations):\n            vector_v = normalize(\n                ops.matmul(vector_u, ops.transpose(weights)), axis=None\n            )\n            vector_u = normalize(ops.matmul(vector_v, weights), axis=None)\n        vector_u = ops.stop_gradient(vector_u)\n        vector_v = ops.stop_gradient(vector_v)\n        sigma = ops.matmul(\n            ops.matmul(vector_v, weights), ops.transpose(vector_u)\n        )\n        kernel = ops.reshape(ops.divide(self.kernel, sigma), self.kernel_shape)\n        return ops.cast(vector_u, self.vector_u.dtype), ops.cast(\n            kernel, self.kernel.dtype\n        )\n\n    def get_config(self):\n        config = {\"power_iterations\": self.power_iterations}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/normalization/spectral_normalization_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\n\n\nclass SpectralNormalizationTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basic_spectralnorm(self):\n        self.run_layer_test(\n            layers.SpectralNormalization,\n            init_kwargs={\"layer\": layers.Dense(2)},\n            input_data=np.random.uniform(size=(10, 3, 4)),\n            expected_output_shape=(10, 3, 2),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=1,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n        self.run_layer_test(\n            layers.SpectralNormalization,\n            init_kwargs={\"layer\": layers.Embedding(10, 4)},\n            input_data=np.random.randint(10, size=(10,)).astype(\"float32\"),\n            expected_output_shape=(10, 4),\n            expected_num_trainable_weights=1,\n            expected_num_non_trainable_weights=1,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_spectralnorm_higher_dim(self):\n        self.run_layer_test(\n            layers.SpectralNormalization,\n            init_kwargs={\"layer\": layers.Dense(2)},\n            input_data=np.random.uniform(size=(10, 3, 4, 5)),\n            expected_output_shape=(10, 3, 4, 2),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=1,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    def test_invalid_power_iterations(self):\n        with self.assertRaisesRegex(\n            
ValueError, \"`power_iterations` should be greater than zero.\"\n        ):\n            layers.SpectralNormalization(layers.Dense(2), power_iterations=0)\n\n    def test_invalid_layer(self):\n        layer = layers.SpectralNormalization(layers.ReLU())\n        inputs = np.ones(shape=(4, 2))\n        with self.assertRaisesRegex(\n            ValueError, \"object has no attribute 'kernel' nor 'embeddings'\"\n        ):\n            layer(inputs)\n\n    def test_apply_layer(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            images = np.ones((1, 2, 2, 1))\n        else:\n            images = np.ones((1, 1, 2, 2))\n        sn_wrapper = layers.SpectralNormalization(\n            layers.Conv2D(\n                1, (2, 2), kernel_initializer=initializers.Constant(value=1)\n            ),\n            power_iterations=8,\n        )\n\n        result = sn_wrapper(images, training=False)\n        result_train = sn_wrapper(images, training=True)\n        expected_output = np.array([[[[4.0]]]], dtype=np.float32)\n        self.assertAllClose(result, expected_output)\n        # max eigen value of 2x2 matrix of ones is 2\n        self.assertAllClose(result_train, expected_output / 2)\n\n    @pytest.mark.requires_trainable_backend\n    def test_end_to_end(self):\n        sn_wrapper = layers.SpectralNormalization(\n            layers.Conv2D(\n                3, (2, 2), padding=\"same\", data_format=\"channels_last\"\n            ),\n            power_iterations=2,\n        )\n        model = models.Sequential([sn_wrapper])\n        model.compile(\"rmsprop\", loss=\"mse\")\n        x = np.random.random((4, 8, 8, 3))\n        y = np.random.random((4, 8, 8, 3))\n        model.fit(x, y)\n"
  },
  {
    "path": "keras/src/layers/normalization/unit_normalization.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.UnitNormalization\")\nclass UnitNormalization(Layer):\n    \"\"\"Unit normalization layer.\n\n    Normalize a batch of inputs so that each input in the batch has an L2 norm\n    equal to 1 (across the axes specified in `axis`).\n\n    Example:\n\n    >>> data = np.arange(6).reshape(2, 3)\n    >>> normalized_data = keras.layers.UnitNormalization()(data)\n    >>> np.sum(normalized_data[0, :] ** 2)\n    1.0\n\n    Args:\n        axis: Integer or list/tuple. The axis or axes to normalize across.\n            Typically, this is the features axis or axes. The left-out axes are\n            typically the batch axis or axes. `-1` is the last dimension\n            in the input. Defaults to `-1`.\n    \"\"\"\n\n    def __init__(self, axis=-1, **kwargs):\n        super().__init__(**kwargs)\n        if isinstance(axis, (list, tuple)):\n            self.axis = list(axis)\n        elif isinstance(axis, int):\n            self.axis = axis\n        else:\n            raise TypeError(\n                \"Invalid value for `axis` argument: \"\n                \"expected an int or a list/tuple of ints. 
\"\n                f\"Received: axis={axis}\"\n            )\n        self.supports_masking = True\n\n        self._build_at_init()\n\n    def call(self, inputs):\n        return ops.normalize(inputs, axis=self.axis, order=2, epsilon=1e-12)\n\n    def compute_output_shape(self, input_shape):\n        # Ensure axis is always treated as a list\n        if isinstance(self.axis, int):\n            axes = [self.axis]\n        else:\n            axes = self.axis\n\n        for axis in axes:\n            if axis >= len(input_shape) or axis < -len(input_shape):\n                raise ValueError(\n                    f\"Axis {self.axis} is out of bounds for \"\n                    f\"input shape {input_shape}.\"\n                )\n        return input_shape\n\n    def get_config(self):\n        config = super().get_config()\n        config.update({\"axis\": self.axis})\n        return config\n"
  },
  {
    "path": "keras/src/layers/normalization/unit_normalization_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\ndef squared_l2_norm(x):\n    x = backend.convert_to_numpy(x)\n    return np.sum(x**2)\n\n\nclass UnitNormalizationTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_un_basics(self):\n        self.run_layer_test(\n            layers.UnitNormalization,\n            init_kwargs={\"axis\": -1},\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n        self.run_layer_test(\n            layers.UnitNormalization,\n            init_kwargs={\"axis\": (1, 2)},\n            input_shape=(1, 3, 3),\n            expected_output_shape=(1, 3, 3),\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n\n    def test_invalid_axis(self):\n        with self.assertRaisesRegex(\n            TypeError,\n            (\n                \"Invalid value for `axis` argument: expected an int or a \"\n                \"list/tuple of ints.\"\n            ),\n        ):\n            layers.UnitNormalization(axis={\"axis\": -1})\n\n    def test_correctness(self):\n        layer = layers.UnitNormalization(axis=-1)\n        inputs = np.random.normal(size=(2, 3))\n        outputs = layer(inputs)\n        self.assertAllClose(squared_l2_norm(outputs[0, :]), 1.0)\n        self.assertAllClose(squared_l2_norm(outputs[1, :]), 1.0)\n\n        layer = layers.UnitNormalization(axis=(1, 2))\n        inputs = np.random.normal(size=(2, 3, 3))\n        outputs = layer(inputs)\n        self.assertAllClose(squared_l2_norm(outputs[0, :, :]), 1.0)\n        self.assertAllClose(squared_l2_norm(outputs[1, :, :]), 1.0)\n\n        layer = layers.UnitNormalization(axis=1)\n        inputs = np.random.normal(size=(2, 3, 2))\n        outputs = layer(inputs)\n        
self.assertAllClose(squared_l2_norm(outputs[0, :, 0]), 1.0)\n        self.assertAllClose(squared_l2_norm(outputs[1, :, 0]), 1.0)\n        self.assertAllClose(squared_l2_norm(outputs[0, :, 1]), 1.0)\n        self.assertAllClose(squared_l2_norm(outputs[1, :, 1]), 1.0)\n"
  },
  {
    "path": "keras/src/layers/pooling/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/layers/pooling/adaptive_average_pooling1d.py",
    "content": "\"\"\"Adaptive Average Pooling 1D layer.\"\"\"\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_adaptive_pooling import (\n    BaseAdaptiveAveragePooling,\n)\n\n\n@keras_export(\"keras.layers.AdaptiveAveragePooling1D\")\nclass AdaptiveAveragePooling1D(BaseAdaptiveAveragePooling):\n    \"\"\"Adaptive average pooling operation for 1D temporal or spatial data.\n\n    This layer applies an adaptive average pooling operation, which pools the\n    input such that the output has a target length specified by `output_size`,\n    regardless of the input length. The kernel size and stride are automatically\n    computed to achieve the target output size.\n\n    Args:\n        output_size: Integer specifying the target output length.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, length, channels)`.\n            `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, length)`.\n            Defaults to the value found in your Keras config file at\n            `~/.keras/keras.json`. 
If never set, `\"channels_last\"` is used.\n\n    Input shape:\n        - If `data_format=\"channels_last\"`: 3D tensor\n            `(batch_size, length, channels)`\n        - If `data_format=\"channels_first\"`: 3D tensor\n            `(batch_size, channels, length)`\n\n    Output shape:\n        - If `data_format=\"channels_last\"`:\n            `(batch_size, output_length, channels)`\n        - If `data_format=\"channels_first\"`:\n            `(batch_size, channels, output_length)`\n\n    Examples:\n        >>> import numpy as np\n        >>> input_seq = np.random.rand(1, 64, 3)\n        >>> layer = AdaptiveAveragePooling1D(output_size=32)\n        >>> output_seq = layer(input_seq)\n        >>> output_seq.shape\n        (1, 32, 3)\n    \"\"\"\n\n    def __init__(self, output_size, data_format=None, **kwargs):\n        if isinstance(output_size, int):\n            output_size = (output_size,)\n        elif isinstance(output_size, (tuple, list)):\n            if len(output_size) != 1:\n                raise ValueError(\n                    f\"For 1D input, `output_size` tuple must have length 1. \"\n                    f\"Received: {output_size}\"\n                )\n            output_size = tuple(output_size)\n        else:\n            raise TypeError(\n                f\"`output_size` must be an integer or tuple of 1 integer. \"\n                f\"Received: {output_size} of type {type(output_size)}\"\n            )\n\n        super().__init__(output_size, data_format, **kwargs)\n"
  },
  {
    "path": "keras/src/layers/pooling/adaptive_average_pooling2d.py",
    "content": "\"\"\"Adaptive Average Pooling 2D layer.\"\"\"\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_adaptive_pooling import (\n    BaseAdaptiveAveragePooling,\n)\n\n\n@keras_export(\"keras.layers.AdaptiveAveragePooling2D\")\nclass AdaptiveAveragePooling2D(BaseAdaptiveAveragePooling):\n    \"\"\"Adaptive average pooling operation for 2D spatial data.\n\n    This layer applies an adaptive average pooling operation, which pools the\n    input such that the output has a target spatial size specified by\n    `output_size`, regardless of the input spatial size. The kernel size\n    and stride are automatically computed to achieve the target output size.\n\n    Args:\n        output_size: Integer or tuple of 2 integers specifying the\n            target output size.\n            If an integer, the same value is used for both height and width.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)`.\n            `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`.\n            Defaults to the value found in your Keras config file at\n            `~/.keras/keras.json`. 
If never set, `\"channels_last\"` is used.\n\n    Input shape:\n        - If `data_format=\"channels_last\"`: 4D tensor\n            `(batch_size, height, width, channels)`\n        - If `data_format=\"channels_first\"`: 4D tensor\n            `(batch_size, channels, height, width)`\n\n    Output shape:\n        - If `data_format=\"channels_last\"`:\n            `(batch_size, output_height, output_width, channels)`\n        - If `data_format=\"channels_first\"`:\n            `(batch_size, channels, output_height, output_width)`\n\n    Examples:\n        >>> import numpy as np\n        >>> input_img = np.random.rand(1, 64, 64, 3)\n        >>> layer = AdaptiveAveragePooling2D(output_size=32)\n        >>> output_img = layer(input_img)\n        >>> output_img.shape\n        (1, 32, 32, 3)\n    \"\"\"\n\n    def __init__(self, output_size, data_format=None, **kwargs):\n        if isinstance(output_size, int):\n            output_size_tuple = (output_size, output_size)\n        elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:\n            output_size_tuple = tuple(output_size)\n        else:\n            raise TypeError(\n                f\"`output_size` must be an integer or (height, width) tuple. \"\n                f\"Received: {output_size} of type {type(output_size)}\"\n            )\n\n        super().__init__(output_size_tuple, data_format, **kwargs)\n"
  },
  {
    "path": "keras/src/layers/pooling/adaptive_average_pooling3d.py",
    "content": "\"\"\"Adaptive Average Pooling 3D layer.\"\"\"\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_adaptive_pooling import (\n    BaseAdaptiveAveragePooling,\n)\n\n\n@keras_export(\"keras.layers.AdaptiveAveragePooling3D\")\nclass AdaptiveAveragePooling3D(BaseAdaptiveAveragePooling):\n    \"\"\"Adaptive average pooling operation for 3D volumetric data.\n\n    This layer applies an adaptive average pooling operation, which pools the\n    input such that the output has a target spatial size specified by\n    `output_size`, regardless of the input spatial size. The kernel size\n    and stride are automatically computed to achieve the target output size.\n\n    Args:\n        output_size: Integer or tuple of 3 integers specifying the\n            target output size.\n            If an integer, the same value is used for depth, height, and width.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, depth, height, width, channels)`.\n            `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, depth, height, width)`.\n            Defaults to the value found in your Keras config file at\n            `~/.keras/keras.json`. 
If never set, `\"channels_last\"` is used.\n\n    Input shape:\n        - If `data_format=\"channels_last\"`: 5D tensor\n            `(batch_size, depth, height, width, channels)`\n        - If `data_format=\"channels_first\"`: 5D tensor\n            `(batch_size, channels, depth, height, width)`\n\n    Output shape:\n        - If `data_format=\"channels_last\"`:\n            `(batch_size, output_depth, output_height, output_width, channels)`\n        - If `data_format=\"channels_first\"`:\n            `(batch_size, channels, output_depth, output_height, output_width)`\n\n    Examples:\n        >>> import numpy as np\n        >>> input_vol = np.random.rand(1, 32, 32, 32, 3)\n        >>> layer = AdaptiveAveragePooling3D(output_size=16)\n        >>> output_vol = layer(input_vol)\n        >>> output_vol.shape\n        (1, 16, 16, 16, 3)\n    \"\"\"\n\n    def __init__(self, output_size, data_format=None, **kwargs):\n        if isinstance(output_size, int):\n            output_size_tuple = (output_size, output_size, output_size)\n        elif isinstance(output_size, (tuple, list)) and len(output_size) == 3:\n            output_size_tuple = tuple(output_size)\n        else:\n            raise TypeError(\n                f\"`output_size` must be an integer or \"\n                f\"(depth, height, width) tuple. \"\n                f\"Received: {output_size} of type {type(output_size)}\"\n            )\n\n        super().__init__(output_size_tuple, data_format, **kwargs)\n"
  },
  {
    "path": "keras/src/layers/pooling/adaptive_max_pooling1d.py",
    "content": "\"\"\"Adaptive Max Pooling 1D layer.\"\"\"\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_adaptive_pooling import (\n    BaseAdaptiveMaxPooling,\n)\n\n\n@keras_export(\"keras.layers.AdaptiveMaxPooling1D\")\nclass AdaptiveMaxPooling1D(BaseAdaptiveMaxPooling):\n    \"\"\"Adaptive max pooling operation for 1D temporal or spatial data.\n\n    This layer applies an adaptive max pooling operation, which pools the\n    input such that the output has a target length specified by `output_size`,\n    regardless of the input length. The kernel size and stride are automatically\n    computed to achieve the target output size.\n\n    Args:\n        output_size: Integer specifying the target output length.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, length, channels)`.\n            `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, length)`.\n            Defaults to the value found in your Keras config file at\n            `~/.keras/keras.json`. 
If never set, `\"channels_last\"` is used.\n\n    Input shape:\n        - If `data_format=\"channels_last\"`: 3D tensor\n            `(batch_size, length, channels)`\n        - If `data_format=\"channels_first\"`: 3D tensor\n            `(batch_size, channels, length)`\n\n    Output shape:\n        - If `data_format=\"channels_last\"`:\n            `(batch_size, output_length, channels)`\n        - If `data_format=\"channels_first\"`:\n            `(batch_size, channels, output_length)`\n\n    Examples:\n        >>> import numpy as np\n        >>> input_seq = np.random.rand(1, 64, 3)\n        >>> layer = AdaptiveMaxPooling1D(output_size=32)\n        >>> output_seq = layer(input_seq)\n        >>> output_seq.shape\n        (1, 32, 3)\n    \"\"\"\n\n    def __init__(self, output_size, data_format=None, **kwargs):\n        if isinstance(output_size, int):\n            output_size = (output_size,)\n        elif isinstance(output_size, (tuple, list)):\n            if len(output_size) != 1:\n                raise ValueError(\n                    f\"For 1D input, `output_size` tuple must have length 1. \"\n                    f\"Received: {output_size}\"\n                )\n            output_size = tuple(output_size)\n        else:\n            raise TypeError(\n                f\"`output_size` must be an integer or tuple of 1 integer. \"\n                f\"Received: {output_size} of type {type(output_size)}\"\n            )\n\n        super().__init__(output_size, data_format, **kwargs)\n"
  },
  {
    "path": "keras/src/layers/pooling/adaptive_max_pooling2d.py",
    "content": "\"\"\"Adaptive Max Pooling 2D layer.\"\"\"\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_adaptive_pooling import (\n    BaseAdaptiveMaxPooling,\n)\n\n\n@keras_export(\"keras.layers.AdaptiveMaxPooling2D\")\nclass AdaptiveMaxPooling2D(BaseAdaptiveMaxPooling):\n    \"\"\"Adaptive max pooling operation for 2D spatial data.\n\n    This layer applies an adaptive max pooling operation, which pools the\n    input such that the output has a target spatial size specified by\n    `output_size`, regardless of the input spatial size. The kernel size\n    and stride are automatically computed to achieve the target output size.\n\n    Args:\n        output_size: Integer or tuple of 2 integers specifying the\n            target output size.\n            If an integer, the same value is used for both height and width.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)`.\n            `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`.\n            Defaults to the value found in your Keras config file at\n            `~/.keras/keras.json`. 
If never set, `\"channels_last\"` is used.\n\n    Input shape:\n        - If `data_format=\"channels_last\"`: 4D tensor\n            `(batch_size, height, width, channels)`\n        - If `data_format=\"channels_first\"`: 4D tensor\n            `(batch_size, channels, height, width)`\n\n    Output shape:\n        - If `data_format=\"channels_last\"`:\n            `(batch_size, output_height, output_width, channels)`\n        - If `data_format=\"channels_first\"`:\n            `(batch_size, channels, output_height, output_width)`\n\n    Examples:\n        >>> import numpy as np\n        >>> input_img = np.random.rand(1, 64, 64, 3)\n        >>> layer = AdaptiveMaxPooling2D(output_size=32)\n        >>> output_img = layer(input_img)\n        >>> output_img.shape\n        (1, 32, 32, 3)\n    \"\"\"\n\n    def __init__(self, output_size, data_format=None, **kwargs):\n        if isinstance(output_size, int):\n            output_size_tuple = (output_size, output_size)\n        elif isinstance(output_size, (tuple, list)) and len(output_size) == 2:\n            output_size_tuple = tuple(output_size)\n        else:\n            raise TypeError(\n                f\"`output_size` must be an integer or (height, width) tuple. \"\n                f\"Received: {output_size} of type {type(output_size)}\"\n            )\n\n        super().__init__(output_size_tuple, data_format, **kwargs)\n"
  },
  {
    "path": "keras/src/layers/pooling/adaptive_max_pooling3d.py",
    "content": "\"\"\"Adaptive Max Pooling 3D layer.\"\"\"\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_adaptive_pooling import (\n    BaseAdaptiveMaxPooling,\n)\n\n\n@keras_export(\"keras.layers.AdaptiveMaxPooling3D\")\nclass AdaptiveMaxPooling3D(BaseAdaptiveMaxPooling):\n    \"\"\"Adaptive max pooling operation for 3D volumetric data.\n\n    This layer applies an adaptive max pooling operation, which pools the\n    input such that the output has a target spatial size specified by\n    `output_size`, regardless of the input spatial size. The kernel size\n    and stride are automatically computed to achieve the target output size.\n\n    Args:\n        output_size: Integer or tuple of 3 integers specifying the\n            target output size.\n            If an integer, the same value is used for depth, height, and width.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, depth, height, width, channels)`.\n            `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, depth, height, width)`.\n            Defaults to the value found in your Keras config file at\n            `~/.keras/keras.json`. 
If never set, `\"channels_last\"` is used.\n\n    Input shape:\n        - If `data_format=\"channels_last\"`: 5D tensor\n            `(batch_size, depth, height, width, channels)`\n        - If `data_format=\"channels_first\"`: 5D tensor\n            `(batch_size, channels, depth, height, width)`\n\n    Output shape:\n        - If `data_format=\"channels_last\"`:\n            `(batch_size, output_depth, output_height, output_width, channels)`\n        - If `data_format=\"channels_first\"`:\n            `(batch_size, channels, output_depth, output_height, output_width)`\n\n    Examples:\n        >>> import numpy as np\n        >>> input_vol = np.random.rand(1, 32, 32, 32, 3)\n        >>> layer = AdaptiveMaxPooling3D(output_size=16)\n        >>> output_vol = layer(input_vol)\n        >>> output_vol.shape\n        (1, 16, 16, 16, 3)\n    \"\"\"\n\n    def __init__(self, output_size, data_format=None, **kwargs):\n        if isinstance(output_size, int):\n            output_size_tuple = (output_size, output_size, output_size)\n        elif isinstance(output_size, (tuple, list)) and len(output_size) == 3:\n            output_size_tuple = tuple(output_size)\n        else:\n            raise TypeError(\n                f\"`output_size` must be an integer or \"\n                f\"(depth, height, width) tuple. \"\n                f\"Received: {output_size} of type {type(output_size)}\"\n            )\n\n        super().__init__(output_size_tuple, data_format, **kwargs)\n"
  },
  {
    "path": "keras/src/layers/pooling/adaptive_pooling1d_test.py",
    "content": "import numpy as np\n\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass AdaptivePooling1DLayerTest(testing.TestCase):\n    \"\"\"Tests for AdaptiveAveragePooling1D and AdaptiveMaxPooling1D.\"\"\"\n\n    def _run_layer_test(self, layer_class, x_np, output_size, data_format):\n        \"\"\"Helper: test layer output shape matches compute_output_shape().\"\"\"\n        layer = layer_class(output_size=output_size, data_format=data_format)\n        y = layer(x_np)\n        expected_shape = layer.compute_output_shape(x_np.shape)\n        self.assertEqual(y.shape, expected_shape)\n\n    def test_average_pooling_basic_shapes(self):\n        \"\"\"Test AdaptiveAveragePooling1D basic shape transformation.\"\"\"\n        shape = (2, 3, 8)  # N,C,L\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveAveragePooling1D,\n            x,\n            output_size=4,\n            data_format=\"channels_first\",\n        )\n\n    def test_max_pooling_basic_shapes(self):\n        \"\"\"Test AdaptiveMaxPooling1D basic shape transformation.\"\"\"\n        shape = (2, 3, 8)\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveMaxPooling1D,\n            x,\n            output_size=4,\n            data_format=\"channels_first\",\n        )\n\n    def test_average_pooling_channels_last(self):\n        \"\"\"Test AdaptiveAveragePooling1D with channels_last format.\"\"\"\n        shape = (2, 8, 3)  # N,L,C\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveAveragePooling1D,\n            x,\n            output_size=4,\n            data_format=\"channels_last\",\n        )\n\n    def test_max_pooling_channels_last(self):\n        \"\"\"Test AdaptiveMaxPooling1D with channels_last format.\"\"\"\n        shape = (2, 8, 3)\n        x = 
np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveMaxPooling1D,\n            x,\n            output_size=4,\n            data_format=\"channels_last\",\n        )\n\n    def test_average_pooling_compute_output_shape(self):\n        \"\"\"Test compute_output_shape() for AdaptiveAveragePooling1D.\"\"\"\n        layer = layers.AdaptiveAveragePooling1D(\n            output_size=16, data_format=\"channels_last\"\n        )\n        input_shape = (None, 64, 3)\n        output_shape = layer.compute_output_shape(input_shape)\n        self.assertEqual(output_shape, (None, 16, 3))\n\n    def test_max_pooling_compute_output_shape(self):\n        \"\"\"Test compute_output_shape() for AdaptiveMaxPooling1D.\"\"\"\n        layer = layers.AdaptiveMaxPooling1D(\n            output_size=16, data_format=\"channels_first\"\n        )\n        input_shape = (2, 3, 64)\n        output_shape = layer.compute_output_shape(input_shape)\n        self.assertEqual(output_shape, (2, 3, 16))\n\n    def test_average_pooling_get_config(self):\n        \"\"\"Test get_config() serialization for AdaptiveAveragePooling1D.\"\"\"\n        layer = layers.AdaptiveAveragePooling1D(\n            output_size=32, data_format=\"channels_first\"\n        )\n        config = layer.get_config()\n        self.assertEqual(config[\"output_size\"], (32,))\n        self.assertEqual(config[\"data_format\"], \"channels_first\")\n\n    def test_max_pooling_get_config(self):\n        \"\"\"Test get_config() serialization for AdaptiveMaxPooling1D.\"\"\"\n        layer = layers.AdaptiveMaxPooling1D(\n            output_size=32, data_format=\"channels_last\"\n        )\n        config = layer.get_config()\n        self.assertEqual(config[\"output_size\"], (32,))\n        self.assertEqual(config[\"data_format\"], \"channels_last\")\n\n    def test_average_pooling_numerical(self):\n        \"\"\"Test AdaptiveAveragePooling1D numerical correctness.\"\"\"\n        inputs = 
np.array([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]], dtype=\"float32\")\n        expected = np.array([[[2.0, 5.0]]], dtype=\"float32\")\n\n        layer = layers.AdaptiveAveragePooling1D(\n            output_size=2, data_format=\"channels_first\"\n        )\n\n        outputs = layer(inputs)\n        self.assertAllClose(outputs, expected, atol=1e-4)\n\n    def test_max_pooling_numerical(self):\n        \"\"\"Test AdaptiveMaxPooling1D numerical correctness.\"\"\"\n        inputs = np.array([[[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]], dtype=\"float32\")\n        expected = np.array([[[3.0, 6.0]]], dtype=\"float32\")\n\n        layer = layers.AdaptiveMaxPooling1D(\n            output_size=2, data_format=\"channels_first\"\n        )\n\n        outputs = layer(inputs)\n        self.assertAllClose(outputs, expected, atol=1e-4)\n"
  },
  {
    "path": "keras/src/layers/pooling/adaptive_pooling2d_test.py",
    "content": "import numpy as np\n\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass AdaptivePooling2DLayerTest(testing.TestCase):\n    \"\"\"Tests for AdaptiveAveragePooling2D and AdaptiveMaxPooling2D.\"\"\"\n\n    def _run_layer_test(self, layer_class, x_np, output_size, data_format):\n        \"\"\"Helper: test layer output shape matches compute_output_shape().\"\"\"\n        layer = layer_class(output_size=output_size, data_format=data_format)\n        y = layer(x_np)\n        expected_shape = layer.compute_output_shape(x_np.shape)\n        self.assertEqual(y.shape, expected_shape)\n\n    def test_average_pooling_basic_shapes(self):\n        \"\"\"Test AdaptiveAveragePooling2D basic shape transformation.\"\"\"\n        shape = (2, 3, 8, 8)  # N,C,H,W\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveAveragePooling2D,\n            x,\n            output_size=4,\n            data_format=\"channels_first\",\n        )\n\n    def test_max_pooling_basic_shapes(self):\n        \"\"\"Test AdaptiveMaxPooling2D basic shape transformation.\"\"\"\n        shape = (2, 3, 8, 8)\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveMaxPooling2D,\n            x,\n            output_size=4,\n            data_format=\"channels_first\",\n        )\n\n    def test_average_pooling_channels_last(self):\n        \"\"\"Test AdaptiveAveragePooling2D with channels_last format.\"\"\"\n        shape = (2, 8, 8, 3)  # N,H,W,C\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveAveragePooling2D,\n            x,\n            output_size=4,\n            data_format=\"channels_last\",\n        )\n\n    def test_max_pooling_channels_last(self):\n        \"\"\"Test AdaptiveMaxPooling2D with channels_last format.\"\"\"\n        shape = (2, 8, 8, 3)\n        x = 
np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveMaxPooling2D,\n            x,\n            output_size=4,\n            data_format=\"channels_last\",\n        )\n\n    def test_average_pooling_tuple_output_size(self):\n        \"\"\"Test AdaptiveAveragePooling2D with tuple output_size.\"\"\"\n        shape = (2, 8, 8, 3)\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveAveragePooling2D,\n            x,\n            output_size=(4, 4),\n            data_format=\"channels_last\",\n        )\n\n    def test_max_pooling_tuple_output_size(self):\n        \"\"\"Test AdaptiveMaxPooling2D with tuple output_size.\"\"\"\n        shape = (2, 8, 8, 3)\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveMaxPooling2D,\n            x,\n            output_size=(2, 4),\n            data_format=\"channels_last\",\n        )\n\n    def test_average_pooling_compute_output_shape(self):\n        \"\"\"Test compute_output_shape() for AdaptiveAveragePooling2D.\"\"\"\n        layer = layers.AdaptiveAveragePooling2D(\n            output_size=16, data_format=\"channels_last\"\n        )\n        input_shape = (None, 64, 64, 3)\n        output_shape = layer.compute_output_shape(input_shape)\n        self.assertEqual(output_shape, (None, 16, 16, 3))\n\n    def test_max_pooling_compute_output_shape(self):\n        \"\"\"Test compute_output_shape() for AdaptiveMaxPooling2D.\"\"\"\n        layer = layers.AdaptiveMaxPooling2D(\n            output_size=(8, 16), data_format=\"channels_first\"\n        )\n        input_shape = (2, 3, 64, 64)\n        output_shape = layer.compute_output_shape(input_shape)\n        self.assertEqual(output_shape, (2, 3, 8, 16))\n\n    def test_average_pooling_get_config(self):\n        \"\"\"Test get_config() serialization for AdaptiveAveragePooling2D.\"\"\"\n        layer = 
layers.AdaptiveAveragePooling2D(\n            output_size=32, data_format=\"channels_first\"\n        )\n        config = layer.get_config()\n        self.assertEqual(config[\"output_size\"], (32, 32))\n        self.assertEqual(config[\"data_format\"], \"channels_first\")\n\n    def test_max_pooling_get_config(self):\n        \"\"\"Test get_config() serialization for AdaptiveMaxPooling2D.\"\"\"\n        layer = layers.AdaptiveMaxPooling2D(\n            output_size=(8, 16), data_format=\"channels_last\"\n        )\n        config = layer.get_config()\n        self.assertEqual(config[\"output_size\"], (8, 16))\n        self.assertEqual(config[\"data_format\"], \"channels_last\")\n\n    def test_average_pooling2d_numerical(self):\n        \"\"\"Test AdaptiveAveragePooling2D numerical correctness.\"\"\"\n        inputs = np.array(\n            [\n                [\n                    [\n                        [1.0, 2.0, 3.0, 4.0],\n                        [5.0, 6.0, 7.0, 8.0],\n                        [9.0, 10.0, 11.0, 12.0],\n                        [13.0, 14.0, 15.0, 16.0],\n                    ]\n                ]\n            ],\n            dtype=\"float32\",\n        )\n        expected = np.array([[[[3.5, 5.5], [11.5, 13.5]]]], dtype=\"float32\")\n\n        layer = layers.AdaptiveAveragePooling2D(\n            output_size=2, data_format=\"channels_first\"\n        )\n        outputs = layer(inputs)\n        self.assertAllClose(outputs, expected, atol=1e-4)\n\n    def test_max_pooling2d_numerical(self):\n        \"\"\"Test AdaptiveMaxPooling2D numerical correctness.\"\"\"\n        inputs = np.array(\n            [\n                [\n                    [\n                        [1.0, 2.0, 3.0, 4.0],\n                        [5.0, 6.0, 7.0, 8.0],\n                        [9.0, 10.0, 11.0, 12.0],\n                        [13.0, 14.0, 15.0, 16.0],\n                    ]\n                ]\n            ],\n            dtype=\"float32\",\n        )\n        
expected = np.array([[[[6.0, 8.0], [14.0, 16.0]]]], dtype=\"float32\")\n\n        layer = layers.AdaptiveMaxPooling2D(\n            output_size=2, data_format=\"channels_first\"\n        )\n        outputs = layer(inputs)\n        self.assertAllClose(outputs, expected, atol=1e-4)\n"
  },
  {
    "path": "keras/src/layers/pooling/adaptive_pooling3d_test.py",
    "content": "import numpy as np\n\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass AdaptivePooling3DLayerTest(testing.TestCase):\n    \"\"\"Tests for AdaptiveAveragePooling3D and AdaptiveMaxPooling3D.\"\"\"\n\n    def _run_layer_test(self, layer_class, x_np, output_size, data_format):\n        \"\"\"Helper: test layer output shape matches compute_output_shape().\"\"\"\n        layer = layer_class(output_size=output_size, data_format=data_format)\n        y = layer(x_np)\n        expected_shape = layer.compute_output_shape(x_np.shape)\n        self.assertEqual(y.shape, expected_shape)\n\n    def test_average_pooling_basic_shapes(self):\n        \"\"\"Test AdaptiveAveragePooling3D basic shape transformation.\"\"\"\n        shape = (2, 3, 8, 8, 8)  # N,C,D,H,W\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveAveragePooling3D,\n            x,\n            output_size=4,\n            data_format=\"channels_first\",\n        )\n\n    def test_max_pooling_basic_shapes(self):\n        \"\"\"Test AdaptiveMaxPooling3D basic shape transformation.\"\"\"\n        shape = (2, 3, 8, 8, 8)\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveMaxPooling3D,\n            x,\n            output_size=4,\n            data_format=\"channels_first\",\n        )\n\n    def test_average_pooling_channels_last(self):\n        \"\"\"Test AdaptiveAveragePooling3D with channels_last format.\"\"\"\n        shape = (2, 8, 8, 8, 3)  # N,D,H,W,C\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveAveragePooling3D,\n            x,\n            output_size=4,\n            data_format=\"channels_last\",\n        )\n\n    def test_max_pooling_channels_last(self):\n        \"\"\"Test AdaptiveMaxPooling3D with channels_last format.\"\"\"\n        shape = (2, 8, 8, 8, 3)\n        x = 
np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveMaxPooling3D,\n            x,\n            output_size=4,\n            data_format=\"channels_last\",\n        )\n\n    def test_average_pooling_tuple_output_size(self):\n        \"\"\"Test AdaptiveAveragePooling3D with tuple output_size.\"\"\"\n        shape = (2, 8, 8, 8, 3)\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveAveragePooling3D,\n            x,\n            output_size=(4, 4, 4),\n            data_format=\"channels_last\",\n        )\n\n    def test_max_pooling_tuple_output_size(self):\n        \"\"\"Test AdaptiveMaxPooling3D with tuple output_size.\"\"\"\n        shape = (2, 8, 8, 8, 3)\n        x = np.random.randn(*shape).astype(\"float32\")\n        self._run_layer_test(\n            layers.AdaptiveMaxPooling3D,\n            x,\n            output_size=(2, 4, 4),\n            data_format=\"channels_last\",\n        )\n\n    def test_average_pooling_compute_output_shape(self):\n        \"\"\"Test compute_output_shape() for AdaptiveAveragePooling3D.\"\"\"\n        layer = layers.AdaptiveAveragePooling3D(\n            output_size=8, data_format=\"channels_last\"\n        )\n        input_shape = (None, 32, 32, 32, 3)\n        output_shape = layer.compute_output_shape(input_shape)\n        self.assertEqual(output_shape, (None, 8, 8, 8, 3))\n\n    def test_max_pooling_compute_output_shape(self):\n        \"\"\"Test compute_output_shape() for AdaptiveMaxPooling3D.\"\"\"\n        layer = layers.AdaptiveMaxPooling3D(\n            output_size=(4, 8, 8), data_format=\"channels_first\"\n        )\n        input_shape = (2, 3, 32, 32, 32)\n        output_shape = layer.compute_output_shape(input_shape)\n        self.assertEqual(output_shape, (2, 3, 4, 8, 8))\n\n    def test_average_pooling_get_config(self):\n        \"\"\"Test get_config() serialization for AdaptiveAveragePooling3D.\"\"\"\n     
   layer = layers.AdaptiveAveragePooling3D(\n            output_size=16, data_format=\"channels_first\"\n        )\n        config = layer.get_config()\n        self.assertEqual(config[\"output_size\"], (16, 16, 16))\n        self.assertEqual(config[\"data_format\"], \"channels_first\")\n\n    def test_max_pooling_get_config(self):\n        \"\"\"Test get_config() serialization for AdaptiveMaxPooling3D.\"\"\"\n        layer = layers.AdaptiveMaxPooling3D(\n            output_size=(8, 16, 16), data_format=\"channels_last\"\n        )\n        config = layer.get_config()\n        self.assertEqual(config[\"output_size\"], (8, 16, 16))\n        self.assertEqual(config[\"data_format\"], \"channels_last\")\n\n    def test_average_pooling3d_numerical(self):\n        \"\"\"Test AdaptiveAveragePooling3D numerical correctness.\"\"\"\n        inputs = np.array(\n            [[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]],\n            dtype=\"float32\",\n        )\n        expected = np.array(\n            [[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]],\n            dtype=\"float32\",\n        )\n\n        layer = layers.AdaptiveAveragePooling3D(\n            output_size=2, data_format=\"channels_first\"\n        )\n        outputs = layer(inputs)\n        self.assertAllClose(outputs, expected, atol=1e-4)\n\n    def test_max_pooling3d_numerical(self):\n        \"\"\"Test AdaptiveMaxPooling3D numerical correctness.\"\"\"\n        inputs = np.array(\n            [[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]],\n            dtype=\"float32\",\n        )\n        expected = np.array(\n            [[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]],\n            dtype=\"float32\",\n        )\n\n        layer = layers.AdaptiveMaxPooling3D(\n            output_size=2, data_format=\"channels_first\"\n        )\n        outputs = layer(inputs)\n        self.assertAllClose(outputs, expected, atol=1e-4)\n"
  },
  {
    "path": "keras/src/layers/pooling/average_pooling1d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_pooling import BasePooling\n\n\n@keras_export([\"keras.layers.AveragePooling1D\", \"keras.layers.AvgPool1D\"])\nclass AveragePooling1D(BasePooling):\n    \"\"\"Average pooling for temporal data.\n\n    Downsamples the input representation by taking the average value over the\n    window defined by `pool_size`. The window is shifted by `strides`.  The\n    resulting output when using \"valid\" padding option has a shape of:\n    `output_shape = (input_shape - pool_size + 1) / strides)`\n\n    The resulting output shape when using the \"same\" padding option is:\n    `output_shape = input_shape / strides`\n\n    Args:\n        pool_size: int, size of the max pooling window.\n        strides: int or None. Specifies how much the pooling window moves\n            for each pooling step. If None, it will default to `pool_size`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input such that output has the same\n            height/width dimension as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, steps, features)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, steps)`. 
It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        3D tensor with shape `(batch_size, steps, features)`.\n    - If `data_format=\"channels_first\"`:\n        3D tensor with shape `(batch_size, features, steps)`.\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        3D tensor with shape `(batch_size, downsampled_steps, features)`.\n    - If `data_format=\"channels_first\"`:\n        3D tensor with shape `(batch_size, features, downsampled_steps)`.\n\n    Examples:\n\n    `strides=1` and `padding=\"valid\"`:\n\n    >>> x = np.array([1., 2., 3., 4., 5.])\n    >>> x = np.reshape(x, [1, 5, 1])\n    >>> avg_pool_1d = keras.layers.AveragePooling1D(pool_size=2,\n    ...    strides=1, padding=\"valid\")\n    >>> avg_pool_1d(x)\n\n    `strides=2` and `padding=\"valid\"`:\n\n    >>> x = np.array([1., 2., 3., 4., 5.])\n    >>> x = np.reshape(x, [1, 5, 1])\n    >>> avg_pool_1d = keras.layers.AveragePooling1D(pool_size=2,\n    ...    strides=2, padding=\"valid\")\n    >>> avg_pool_1d(x)\n\n    `strides=1` and `padding=\"same\"`:\n\n    >>> x = np.array([1., 2., 3., 4., 5.])\n    >>> x = np.reshape(x, [1, 5, 1])\n    >>> avg_pool_1d = keras.layers.AveragePooling1D(pool_size=2,\n    ...    strides=1, padding=\"same\")\n    >>> avg_pool_1d(x)\n    \"\"\"\n\n    def __init__(\n        self,\n        pool_size,\n        strides=None,\n        padding=\"valid\",\n        data_format=None,\n        name=None,\n        **kwargs,\n    ):\n        super().__init__(\n            pool_size,\n            strides,\n            pool_dimensions=1,\n            pool_mode=\"average\",\n            padding=padding,\n            data_format=data_format,\n            name=name,\n            **kwargs,\n        )\n"
  },
  {
    "path": "keras/src/layers/pooling/average_pooling2d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_pooling import BasePooling\n\n\n@keras_export([\"keras.layers.AveragePooling2D\", \"keras.layers.AvgPool2D\"])\nclass AveragePooling2D(BasePooling):\n    \"\"\"Average pooling operation for 2D spatial data.\n\n    Downsamples the input along its spatial dimensions (height and width)\n    by taking the average value over an input window\n    (of size defined by `pool_size`) for each channel of the input.\n    The window is shifted by `strides` along each dimension.\n\n    The resulting output when using the `\"valid\"` padding option has a spatial\n    shape (number of rows or columns) of:\n    `output_shape = math.floor((input_shape - pool_size) / strides) + 1`\n    (when `input_shape >= pool_size`)\n\n    The resulting output shape when using the `\"same\"` padding option is:\n    `output_shape = input_shape`\n\n    Args:\n        pool_size: int or tuple of 2 integers, factors by which to downscale\n            (dim1, dim2). If only one integer is specified, the same\n            window length will be used for all dimensions.\n        strides: int or tuple of 2 integers, or None. Strides values. If None,\n            it will default to `pool_size`. If only one int is specified, the\n            same stride size will be used for all dimensions.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input such that output has the same\n            height/width dimension as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. 
`\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`. It defaults to the\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. If you never set it, then it will be\n            `\"channels_last\"`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        4D tensor with shape `(batch_size, height, width, channels)`.\n    - If `data_format=\"channels_first\"`:\n        4D tensor with shape `(batch_size, channels, height, width)`.\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        4D tensor with shape\n        `(batch_size, pooled_height, pooled_width, channels)`.\n    - If `data_format=\"channels_first\"`:\n        4D tensor with shape\n        `(batch_size, channels, pooled_height, pooled_width)`.\n\n    Examples:\n\n    `strides=(1, 1)` and `padding=\"valid\"`:\n\n    >>> x = np.array([[1., 2., 3.],\n    ...               [4., 5., 6.],\n    ...               [7., 8., 9.]])\n    >>> x = np.reshape(x, [1, 3, 3, 1])\n    >>> avg_pool_2d = keras.layers.AveragePooling2D(pool_size=(2, 2),\n    ...    strides=(1, 1), padding=\"valid\")\n    >>> avg_pool_2d(x)\n\n    `strides=(2, 2)` and `padding=\"valid\"`:\n\n    >>> x = np.array([[1., 2., 3., 4.],\n    ...              [5., 6., 7., 8.],\n    ...              [9., 10., 11., 12.]])\n    >>> x = np.reshape(x, [1, 3, 4, 1])\n    >>> avg_pool_2d = keras.layers.AveragePooling2D(pool_size=(2, 2),\n    ...    strides=(2, 2), padding=\"valid\")\n    >>> avg_pool_2d(x)\n\n    `stride=(1, 1)` and `padding=\"same\"`:\n\n    >>> x = np.array([[1., 2., 3.],\n    ...                  [4., 5., 6.],\n    ...                  [7., 8., 9.]])\n    >>> x = np.reshape(x, [1, 3, 3, 1])\n    >>> avg_pool_2d = keras.layers.AveragePooling2D(pool_size=(2, 2),\n    ...    
strides=(1, 1), padding=\"same\")\n    >>> avg_pool_2d(x)\n    \"\"\"\n\n    def __init__(\n        self,\n        pool_size,\n        strides=None,\n        padding=\"valid\",\n        data_format=None,\n        name=None,\n        **kwargs,\n    ):\n        super().__init__(\n            pool_size,\n            strides,\n            pool_dimensions=2,\n            pool_mode=\"average\",\n            padding=padding,\n            data_format=data_format,\n            name=name,\n            **kwargs,\n        )\n"
  },
  {
    "path": "keras/src/layers/pooling/average_pooling3d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_pooling import BasePooling\n\n\n@keras_export([\"keras.layers.AveragePooling3D\", \"keras.layers.AvgPool3D\"])\nclass AveragePooling3D(BasePooling):\n    \"\"\"Average pooling operation for 3D data (spatial or spatio-temporal).\n\n    Downsamples the input along its spatial dimensions (depth, height, and\n    width) by taking the average value over an input window (of size defined by\n    `pool_size`) for each channel of the input. The window is shifted by\n    `strides` along each dimension.\n\n    Args:\n        pool_size: int or tuple of 3 integers, factors by which to downscale\n            (dim1, dim2, dim3). If only one integer is specified, the same\n            window length will be used for all dimensions.\n        strides: int or tuple of 3 integers, or None. Strides values. If None,\n            it will default to `pool_size`. If only one int is specified, the\n            same stride size will be used for all dimensions.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input such that output has the same\n            height/width dimension as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape\n            `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` while\n            `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.\n            It defaults to the `image_data_format` value found in your Keras\n            config file at `~/.keras/keras.json`. 
If you never set it, then it\n            will be `\"channels_last\"`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        5D tensor with shape:\n        `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n    - If `data_format=\"channels_first\"`:\n        5D tensor with shape:\n        `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        5D tensor with shape:\n        `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)`\n    - If `data_format=\"channels_first\"`:\n        5D tensor with shape:\n        `(batch_size, channels, pooled_dim1, pooled_dim2, pooled_dim3)`\n\n    Example:\n\n    ```python\n    depth = 30\n    height = 30\n    width = 30\n    channels = 3\n\n    inputs = keras.layers.Input(shape=(depth, height, width, channels))\n    layer = keras.layers.AveragePooling3D(pool_size=3)\n    outputs = layer(inputs)  # Shape: (batch_size, 10, 10, 10, 3)\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        pool_size,\n        strides=None,\n        padding=\"valid\",\n        data_format=None,\n        name=None,\n        **kwargs,\n    ):\n        super().__init__(\n            pool_size,\n            strides,\n            pool_dimensions=3,\n            pool_mode=\"average\",\n            padding=padding,\n            data_format=data_format,\n            name=name,\n            **kwargs,\n        )\n"
  },
  {
    "path": "keras/src/layers/pooling/average_pooling_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom numpy.lib.stride_tricks import as_strided\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\ndef _same_padding(input_size, pool_size, stride):\n    if input_size % stride == 0:\n        return max(pool_size - stride, 0)\n    else:\n        return max(pool_size - (input_size % stride), 0)\n\n\ndef np_avgpool1d(x, pool_size, strides, padding, data_format):\n    if data_format == \"channels_first\":\n        x = x.swapaxes(1, 2)\n    if isinstance(pool_size, (tuple, list)):\n        pool_size = pool_size[0]\n    if isinstance(strides, (tuple, list)):\n        h_stride = strides[0]\n    else:\n        h_stride = strides\n\n    if padding == \"same\":\n        n_batch, h_x, ch_x = x.shape\n        pad_value = _same_padding(h_x, pool_size, h_stride)\n        npad = [(0, 0)] * x.ndim\n        npad[1] = (0, pad_value)\n        x = np.pad(x, pad_width=npad, mode=\"edge\")\n\n    n_batch, h_x, ch_x = x.shape\n    out_h = int((h_x - pool_size) / h_stride) + 1\n\n    stride_shape = (n_batch, out_h, ch_x, pool_size)\n    strides = (\n        x.strides[0],\n        h_stride * x.strides[1],\n        x.strides[2],\n        x.strides[1],\n    )\n    windows = as_strided(x, shape=stride_shape, strides=strides)\n    out = np.mean(windows, axis=(3,))\n    if data_format == \"channels_first\":\n        out = out.swapaxes(1, 2)\n    return out\n\n\ndef np_avgpool2d(x, pool_size, strides, padding, data_format):\n    if data_format == \"channels_first\":\n        x = x.transpose((0, 2, 3, 1))\n    if isinstance(pool_size, int):\n        pool_size = (pool_size, pool_size)\n    if isinstance(strides, int):\n        strides = (strides, strides)\n\n    h_pool_size, w_pool_size = pool_size\n    h_stride, w_stride = strides\n    if padding == \"same\":\n        n_batch, h_x, w_x, ch_x = x.shape\n        h_padding = _same_padding(h_x, h_pool_size, 
h_stride)\n        w_padding = _same_padding(w_x, w_pool_size, w_stride)\n        npad = [(0, 0)] * x.ndim\n        npad[1] = (0, h_padding)\n        npad[2] = (0, w_padding)\n        x = np.pad(x, pad_width=npad, mode=\"edge\")\n\n    n_batch, h_x, w_x, ch_x = x.shape\n    out_h = int((h_x - h_pool_size) / h_stride) + 1\n    out_w = int((w_x - w_pool_size) / w_stride) + 1\n\n    stride_shape = (n_batch, out_h, out_w, ch_x, *pool_size)\n    strides = (\n        x.strides[0],\n        h_stride * x.strides[1],\n        w_stride * x.strides[2],\n        x.strides[3],\n        x.strides[1],\n        x.strides[2],\n    )\n    windows = as_strided(x, shape=stride_shape, strides=strides)\n    out = np.mean(windows, axis=(4, 5))\n    if data_format == \"channels_first\":\n        out = out.transpose((0, 3, 1, 2))\n    return out\n\n\ndef np_avgpool3d(x, pool_size, strides, padding, data_format):\n    if data_format == \"channels_first\":\n        x = x.transpose((0, 2, 3, 4, 1))\n\n    if isinstance(pool_size, int):\n        pool_size = (pool_size, pool_size, pool_size)\n    if isinstance(strides, int):\n        strides = (strides, strides, strides)\n\n    h_pool_size, w_pool_size, d_pool_size = pool_size\n    h_stride, w_stride, d_stride = strides\n\n    if padding == \"same\":\n        n_batch, h_x, w_x, d_x, ch_x = x.shape\n        h_padding = _same_padding(h_x, h_pool_size, h_stride)\n        w_padding = _same_padding(w_x, w_pool_size, w_stride)\n        d_padding = _same_padding(d_x, d_pool_size, d_stride)\n        npad = [(0, 0)] * x.ndim\n        npad[1] = (0, h_padding)\n        npad[2] = (0, w_padding)\n        npad[3] = (0, d_padding)\n        x = np.pad(x, pad_width=npad, mode=\"symmetric\")\n\n    n_batch, h_x, w_x, d_x, ch_x = x.shape\n    out_h = int((h_x - h_pool_size) / h_stride) + 1\n    out_w = int((w_x - w_pool_size) / w_stride) + 1\n    out_d = int((d_x - d_pool_size) / d_stride) + 1\n\n    stride_shape = (n_batch, out_h, out_w, out_d, ch_x, 
*pool_size)\n    strides = (\n        x.strides[0],\n        h_stride * x.strides[1],\n        w_stride * x.strides[2],\n        d_stride * x.strides[3],\n        x.strides[4],\n        x.strides[1],\n        x.strides[2],\n        x.strides[3],\n    )\n    windows = as_strided(x, shape=stride_shape, strides=strides)\n    out = np.mean(windows, axis=(5, 6, 7))\n    if data_format == \"channels_first\":\n        out = out.transpose((0, 4, 1, 2, 3))\n    return out\n\n\n@pytest.mark.requires_trainable_backend\nclass AveragePoolingBasicTest(testing.TestCase):\n    @parameterized.parameters(\n        (2, 1, \"valid\", \"channels_last\", (3, 5, 4), (3, 4, 4)),\n        (2, 1, \"same\", \"channels_first\", (3, 5, 4), (3, 5, 4)),\n        ((2,), (2,), \"valid\", \"channels_last\", (3, 5, 4), (3, 2, 4)),\n    )\n    def test_average_pooling1d(\n        self,\n        pool_size,\n        strides,\n        padding,\n        data_format,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.AveragePooling1D,\n            init_kwargs={\n                \"pool_size\": pool_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            assert_built_after_instantiation=True,\n        )\n\n    @parameterized.parameters(\n        (2, 1, \"valid\", \"channels_last\", (3, 5, 5, 4), (3, 4, 4, 4)),\n        (2, 1, \"same\", \"channels_last\", (3, 5, 5, 4), (3, 5, 5, 4)),\n        (2, 1, \"valid\", \"channels_first\", (3, 5, 5, 4), (3, 5, 4, 3)),\n        (2, 1, \"same\", \"channels_first\", (3, 5, 5, 4), (3, 5, 5, 4)),\n        ((2, 3), (2, 2), \"valid\", \"channels_last\", (3, 5, 5, 4), 
(3, 2, 2, 4)),\n        ((2, 3), (2, 2), \"same\", \"channels_last\", (3, 5, 5, 4), (3, 3, 3, 4)),\n        ((2, 3), (3, 3), \"same\", \"channels_first\", (3, 5, 5, 4), (3, 5, 2, 2)),\n    )\n    def test_average_pooling2d(\n        self,\n        pool_size,\n        strides,\n        padding,\n        data_format,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.AveragePooling2D,\n            init_kwargs={\n                \"pool_size\": pool_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            assert_built_after_instantiation=True,\n        )\n\n    @parameterized.parameters(\n        (2, 1, \"valid\", \"channels_last\", (3, 5, 5, 5, 4), (3, 4, 4, 4, 4)),\n        (2, 1, \"same\", \"channels_first\", (3, 5, 5, 5, 4), (3, 5, 5, 5, 4)),\n        (\n            (2, 3, 2),\n            (2, 2, 1),\n            \"valid\",\n            \"channels_last\",\n            (3, 5, 5, 5, 4),\n            (3, 2, 2, 4, 4),\n        ),\n    )\n    def test_average_pooling3d(\n        self,\n        pool_size,\n        strides,\n        padding,\n        data_format,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.AveragePooling3D,\n            init_kwargs={\n                \"pool_size\": pool_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=0,\n            
expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            # Incomplete op support on tensorflow.\n            run_mixed_precision_check=False,\n            assert_built_after_instantiation=True,\n        )\n\n\nclass AveragePoolingCorrectnessTest(testing.TestCase):\n    @parameterized.parameters(\n        (2, 1, \"valid\", \"channels_last\"),\n        (2, 1, \"valid\", \"channels_first\"),\n        ((2,), (2,), \"valid\", \"channels_last\"),\n        ((2,), (2,), \"valid\", \"channels_first\"),\n    )\n    def test_average_pooling1d(self, pool_size, strides, padding, data_format):\n        inputs = np.arange(24, dtype=\"float32\").reshape((2, 3, 4))\n\n        layer = layers.AveragePooling1D(\n            pool_size=pool_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n        )\n        outputs = layer(inputs)\n        expected = np_avgpool1d(\n            inputs, pool_size, strides, padding, data_format\n        )\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.parameters(\n        (2, 1, \"same\", \"channels_last\"),\n        (2, 1, \"same\", \"channels_first\"),\n        ((2,), (2,), \"same\", \"channels_last\"),\n        ((2,), (2,), \"same\", \"channels_first\"),\n    )\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"Same padding in Torch backend produces different results.\",\n    )\n    def test_average_pooling1d_same_padding(\n        self, pool_size, strides, padding, data_format\n    ):\n        inputs = np.arange(24, dtype=\"float32\").reshape((2, 3, 4))\n\n        layer = layers.AveragePooling1D(\n            pool_size=pool_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n        )\n        outputs = layer(inputs)\n        expected = np_avgpool1d(\n            inputs, pool_size, strides, padding, data_format\n        )\n 
       self.assertAllClose(outputs, expected)\n\n    @parameterized.parameters(\n        (2, 1, \"valid\", \"channels_last\"),\n        ((2, 3), (2, 2), \"valid\", \"channels_last\"),\n    )\n    def test_average_pooling2d(self, pool_size, strides, padding, data_format):\n        inputs = np.arange(16, dtype=\"float32\").reshape((1, 4, 4, 1))\n        layer = layers.AveragePooling2D(\n            pool_size=pool_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n        )\n        outputs = layer(inputs)\n        expected = np_avgpool2d(\n            inputs, pool_size, strides, padding, data_format\n        )\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.parameters(\n        (2, (2, 1), \"same\", \"channels_last\"),\n        (2, (2, 1), \"same\", \"channels_first\"),\n        ((2, 2), (2, 2), \"same\", \"channels_last\"),\n        ((2, 2), (2, 2), \"same\", \"channels_first\"),\n    )\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"Same padding in Torch backend produces different results.\",\n    )\n    def test_average_pooling2d_same_padding(\n        self, pool_size, strides, padding, data_format\n    ):\n        inputs = np.arange(16, dtype=\"float32\").reshape((1, 4, 4, 1))\n        layer = layers.AveragePooling2D(\n            pool_size=pool_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n        )\n        outputs = layer(inputs)\n        expected = np_avgpool2d(\n            inputs, pool_size, strides, padding, data_format\n        )\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.parameters(\n        (2, 1, \"valid\", \"channels_last\"),\n        (2, 1, \"valid\", \"channels_first\"),\n        ((2, 3, 2), (2, 2, 1), \"valid\", \"channels_last\"),\n        ((2, 3, 2), (2, 2, 1), \"valid\", \"channels_first\"),\n    )\n    def test_average_pooling3d(self, pool_size, 
strides, padding, data_format):\n        inputs = np.arange(240, dtype=\"float32\").reshape((2, 3, 4, 5, 2))\n\n        layer = layers.AveragePooling3D(\n            pool_size=pool_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n        )\n        outputs = layer(inputs)\n        expected = np_avgpool3d(\n            inputs, pool_size, strides, padding, data_format\n        )\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.parameters(\n        (2, 1, \"same\", \"channels_last\"),\n        (2, 1, \"same\", \"channels_first\"),\n        ((2, 2, 2), (2, 2, 1), \"same\", \"channels_last\"),\n        ((2, 2, 2), (2, 2, 1), \"same\", \"channels_first\"),\n    )\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"Same padding in Torch backend produces different results.\",\n    )\n    def test_average_pooling3d_same_padding(\n        self, pool_size, strides, padding, data_format\n    ):\n        inputs = np.arange(240, dtype=\"float32\").reshape((2, 3, 4, 5, 2))\n\n        layer = layers.AveragePooling3D(\n            pool_size=pool_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n        )\n        outputs = layer(inputs)\n        expected = np_avgpool3d(\n            inputs, pool_size, strides, padding, data_format\n        )\n        self.assertAllClose(outputs, expected)\n"
  },
  {
    "path": "keras/src/layers/pooling/base_adaptive_pooling.py",
    "content": "\"\"\"Base classes for adaptive pooling layers.\"\"\"\n\nfrom keras.src import ops\nfrom keras.src.backend import config\nfrom keras.src.layers.layer import Layer\n\n\nclass BaseAdaptivePooling(Layer):\n    \"\"\"Base class shared by all adaptive pooling layers.\"\"\"\n\n    def __init__(self, output_size, data_format=None, **kwargs):\n        \"\"\"Initialize base adaptive pooling layer.\n\n        Args:\n            output_size: Normalized spatial output size as a tuple\n                (for example, (32,), (32, 32), or (32, 32, 32)).\n            data_format: Either \"channels_last\" or \"channels_first\".\n            **kwargs: Additional layer keyword arguments.\n        \"\"\"\n        super().__init__(**kwargs)\n        self.output_size = output_size\n        self.data_format = data_format or config.image_data_format()\n        if self.data_format not in {\"channels_first\", \"channels_last\"}:\n            raise ValueError(\n                f\"Invalid data_format: {self.data_format}. 
\"\n                \"Expected 'channels_first' or 'channels_last'.\"\n            )\n\n    def compute_output_shape(self, input_shape):\n        \"\"\"Return the output shape tensor after pooling.\"\"\"\n        batch_size = input_shape[0]\n        if self.data_format == \"channels_last\":\n            channels = input_shape[-1]\n            return (batch_size, *self.output_size, channels)\n        else:\n            channels = input_shape[1]\n            return (batch_size, channels, *self.output_size)\n\n    def get_config(self):\n        config_dict = {\n            \"output_size\": self.output_size,\n            \"data_format\": self.data_format,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config_dict}\n\n\nclass BaseAdaptiveAveragePooling(BaseAdaptivePooling):\n    \"\"\"Base class for adaptive average pooling in 1D, 2D, and 3D.\"\"\"\n\n    def call(self, inputs):\n        return ops.adaptive_average_pool(\n            inputs, output_size=self.output_size, data_format=self.data_format\n        )\n\n\nclass BaseAdaptiveMaxPooling(BaseAdaptivePooling):\n    \"\"\"Base class for adaptive max pooling in 1D, 2D, and 3D.\"\"\"\n\n    def call(self, inputs):\n        return ops.adaptive_max_pool(\n            inputs, output_size=self.output_size, data_format=self.data_format\n        )\n"
  },
  {
    "path": "keras/src/layers/pooling/base_global_pooling.py",
    "content": "from keras.src import backend\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\n\n\nclass BaseGlobalPooling(Layer):\n    \"\"\"Base global pooling layer.\"\"\"\n\n    def __init__(\n        self, pool_dimensions, data_format=None, keepdims=False, **kwargs\n    ):\n        super().__init__(**kwargs)\n\n        self.data_format = backend.standardize_data_format(data_format)\n        self.keepdims = keepdims\n        self.input_spec = InputSpec(ndim=pool_dimensions + 2)\n\n        self._build_at_init()\n\n    def call(self, inputs):\n        raise NotImplementedError\n\n    def compute_output_shape(self, input_shape):\n        num_spatial_dims = len(input_shape) - 2\n        if self.data_format == \"channels_last\":\n            if self.keepdims:\n                return (\n                    (input_shape[0],)\n                    + (1,) * num_spatial_dims\n                    + (input_shape[-1],)\n                )\n            else:\n                return (input_shape[0],) + (input_shape[-1],)\n        else:\n            if self.keepdims:\n                return (input_shape[0], input_shape[1]) + (\n                    1,\n                ) * num_spatial_dims\n            else:\n                return (input_shape[0], input_shape[1])\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"data_format\": self.data_format,\n                \"keepdims\": self.keepdims,\n            }\n        )\n        return config\n"
  },
  {
    "path": "keras/src/layers/pooling/base_pooling.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.ops.operation_utils import compute_pooling_output_shape\nfrom keras.src.utils import argument_validation\n\n\nclass BasePooling(Layer):\n    \"\"\"Base pooling layer.\"\"\"\n\n    def __init__(\n        self,\n        pool_size,\n        strides,\n        pool_dimensions,\n        pool_mode=\"max\",\n        padding=\"valid\",\n        data_format=None,\n        name=None,\n        **kwargs,\n    ):\n        super().__init__(name=name, **kwargs)\n\n        self.pool_size = argument_validation.standardize_tuple(\n            pool_size, pool_dimensions, \"pool_size\"\n        )\n        strides = pool_size if strides is None else strides\n        self.strides = argument_validation.standardize_tuple(\n            strides, pool_dimensions, \"strides\", allow_zero=True\n        )\n        self.pool_mode = pool_mode\n        self.padding = padding\n        self.data_format = backend.standardize_data_format(data_format)\n\n        self.input_spec = InputSpec(ndim=pool_dimensions + 2)\n\n        self._build_at_init()\n\n    def call(self, inputs):\n        if self.pool_mode == \"max\":\n            return ops.max_pool(\n                inputs,\n                pool_size=self.pool_size,\n                strides=self.strides,\n                padding=self.padding,\n                data_format=self.data_format,\n            )\n        elif self.pool_mode == \"average\":\n            return ops.average_pool(\n                inputs,\n                pool_size=self.pool_size,\n                strides=self.strides,\n                padding=self.padding,\n                data_format=self.data_format,\n            )\n        else:\n            raise ValueError(\n                \"`pool_mode` must be either 'max' or 'average'. 
Received: \"\n                f\"{self.pool_mode}.\"\n            )\n\n    def compute_output_shape(self, input_shape):\n        return compute_pooling_output_shape(\n            input_shape,\n            self.pool_size,\n            self.strides,\n            self.padding,\n            self.data_format,\n        )\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"pool_size\": self.pool_size,\n                \"padding\": self.padding,\n                \"strides\": self.strides,\n                \"data_format\": self.data_format,\n            }\n        )\n        return config\n"
  },
  {
    "path": "keras/src/layers/pooling/global_average_pooling1d.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling\n\n\n@keras_export(\n    [\n        \"keras.layers.GlobalAveragePooling1D\",\n        \"keras.layers.GlobalAvgPool1D\",\n    ]\n)\nclass GlobalAveragePooling1D(BaseGlobalPooling):\n    \"\"\"Global average pooling operation for temporal data.\n\n    Args:\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, steps, features)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, steps)`. It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        keepdims: A boolean, whether to keep the temporal dimension or not.\n            If `keepdims` is `False` (default), the rank of the tensor is\n            reduced for spatial dimensions. 
If `keepdims` is `True`, the\n            temporal dimension is retained with length 1.\n            The behavior is the same as for `tf.reduce_mean` or `np.mean`.\n\n    Call arguments:\n        inputs: A 3D tensor.\n        mask: Binary tensor of shape `(batch_size, steps)` indicating whether\n            a given step should be masked (excluded from the average).\n\n    Input shape:\n\n    - If `data_format='channels_last'`:\n        3D tensor with shape:\n        `(batch_size, steps, features)`\n    - If `data_format='channels_first'`:\n        3D tensor with shape:\n        `(batch_size, features, steps)`\n\n    Output shape:\n\n    - If `keepdims=False`:\n        2D tensor with shape `(batch_size, features)`.\n    - If `keepdims=True`:\n        - If `data_format=\"channels_last\"`:\n            3D tensor with shape `(batch_size, 1, features)`\n        - If `data_format=\"channels_first\"`:\n            3D tensor with shape `(batch_size, features, 1)`\n\n    Example:\n\n    >>> x = np.random.rand(2, 3, 4)\n    >>> y = keras.layers.GlobalAveragePooling1D()(x)\n    >>> y.shape\n    (2, 4)\n    \"\"\"\n\n    def __init__(self, data_format=None, keepdims=False, **kwargs):\n        super().__init__(\n            pool_dimensions=1,\n            data_format=data_format,\n            keepdims=keepdims,\n            **kwargs,\n        )\n        self.supports_masking = True\n\n    def call(self, inputs, mask=None):\n        steps_axis = 1 if self.data_format == \"channels_last\" else 2\n        if mask is not None:\n            mask = backend.cast(mask, inputs[0].dtype)\n            mask = ops.expand_dims(\n                mask, 2 if self.data_format == \"channels_last\" else 1\n            )\n            inputs *= mask\n            return ops.sum(\n                inputs, axis=steps_axis, keepdims=self.keepdims\n            ) / ops.sum(mask, axis=steps_axis, keepdims=self.keepdims)\n        else:\n            return ops.mean(inputs, axis=steps_axis, keepdims=self.keepdims)\n\n    def compute_mask(self, inputs, mask=None):\n        return None\n"
  },
  {
    "path": "keras/src/layers/pooling/global_average_pooling2d.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling\n\n\n@keras_export(\n    [\n        \"keras.layers.GlobalAveragePooling2D\",\n        \"keras.layers.GlobalAvgPool2D\",\n    ]\n)\nclass GlobalAveragePooling2D(BaseGlobalPooling):\n    \"\"\"Global average pooling operation for 2D data.\n\n    Args:\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, height, weight)`. It defaults to the\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. If you never set it, then it will be\n            `\"channels_last\"`.\n        keepdims: A boolean, whether to keep the temporal dimension or not.\n            If `keepdims` is `False` (default), the rank of the tensor is\n            reduced for spatial dimensions. 
If `keepdims` is `True`, the\n            spatial dimension are retained with length 1.\n            The behavior is the same as for `tf.reduce_mean` or `np.mean`.\n\n    Input shape:\n\n    - If `data_format='channels_last'`:\n        4D tensor with shape:\n        `(batch_size, height, width, channels)`\n    - If `data_format='channels_first'`:\n        4D tensor with shape:\n        `(batch_size, channels, height, width)`\n\n    Output shape:\n\n    - If `keepdims=False`:\n        2D tensor with shape `(batch_size, channels)`.\n    - If `keepdims=True`:\n        - If `data_format=\"channels_last\"`:\n            4D tensor with shape `(batch_size, 1, 1, channels)`\n        - If `data_format=\"channels_first\"`:\n            4D tensor with shape `(batch_size, channels, 1, 1)`\n\n    Example:\n\n    >>> x = np.random.rand(2, 4, 5, 3)\n    >>> y = keras.layers.GlobalAveragePooling2D()(x)\n    >>> y.shape\n    (2, 3)\n    \"\"\"\n\n    def __init__(self, data_format=None, keepdims=False, **kwargs):\n        super().__init__(\n            pool_dimensions=2,\n            data_format=data_format,\n            keepdims=keepdims,\n            **kwargs,\n        )\n\n    def call(self, inputs):\n        if self.data_format == \"channels_last\":\n            return ops.mean(inputs, axis=[1, 2], keepdims=self.keepdims)\n        return ops.mean(inputs, axis=[2, 3], keepdims=self.keepdims)\n"
  },
  {
    "path": "keras/src/layers/pooling/global_average_pooling3d.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling\n\n\n@keras_export(\n    [\n        \"keras.layers.GlobalAveragePooling3D\",\n        \"keras.layers.GlobalAvgPool3D\",\n    ]\n)\nclass GlobalAveragePooling3D(BaseGlobalPooling):\n    \"\"\"Global average pooling operation for 3D data.\n\n    Args:\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape\n            `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.\n            It defaults to the `image_data_format` value found in your Keras\n            config file at `~/.keras/keras.json`. If you never set it, then it\n            will be `\"channels_last\"`.\n        keepdims: A boolean, whether to keep the temporal dimension or not.\n            If `keepdims` is `False` (default), the rank of the tensor is\n            reduced for spatial dimensions. 
If `keepdims` is `True`, the\n            spatial dimension are retained with length 1.\n            The behavior is the same as for `tf.reduce_mean` or `np.mean`.\n\n    Input shape:\n\n    - If `data_format='channels_last'`:\n        5D tensor with shape:\n        `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n    - If `data_format='channels_first'`:\n        5D tensor with shape:\n        `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`\n\n    Output shape:\n\n    - If `keepdims=False`:\n        2D tensor with shape `(batch_size, channels)`.\n    - If `keepdims=True`:\n        - If `data_format=\"channels_last\"`:\n            5D tensor with shape `(batch_size, 1, 1, 1, channels)`\n        - If `data_format=\"channels_first\"`:\n            5D tensor with shape `(batch_size, channels, 1, 1, 1)`\n\n    Example:\n\n    >>> x = np.random.rand(2, 4, 5, 4, 3)\n    >>> y = keras.layers.GlobalAveragePooling3D()(x)\n    >>> y.shape\n    (2, 3)\n    \"\"\"\n\n    def __init__(self, data_format=None, keepdims=False, **kwargs):\n        super().__init__(\n            pool_dimensions=3,\n            data_format=data_format,\n            keepdims=keepdims,\n            **kwargs,\n        )\n\n    def call(self, inputs):\n        if self.data_format == \"channels_last\":\n            return ops.mean(inputs, axis=[1, 2, 3], keepdims=self.keepdims)\n        return ops.mean(inputs, axis=[2, 3, 4], keepdims=self.keepdims)\n"
  },
  {
    "path": "keras/src/layers/pooling/global_average_pooling_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import layers\nfrom keras.src import testing\n\n\n@pytest.mark.requires_trainable_backend\nclass GlobalAveragePoolingBasicTest(testing.TestCase):\n    @parameterized.parameters(\n        (\"channels_last\", False, (3, 5, 4), (3, 4)),\n        (\"channels_last\", True, (3, 5, 4), (3, 1, 4)),\n        (\"channels_first\", False, (3, 5, 4), (3, 5)),\n    )\n    def test_global_average_pooling1d(\n        self,\n        data_format,\n        keepdims,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.GlobalAveragePooling1D,\n            init_kwargs={\n                \"data_format\": data_format,\n                \"keepdims\": keepdims,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n\n    @parameterized.parameters(\n        (\"channels_last\", False, (3, 5, 6, 4), (3, 4)),\n        (\"channels_last\", True, (3, 5, 6, 4), (3, 1, 1, 4)),\n        (\"channels_first\", False, (3, 5, 6, 4), (3, 5)),\n    )\n    def test_global_average_pooling2d(\n        self,\n        data_format,\n        keepdims,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.GlobalAveragePooling2D,\n            init_kwargs={\n                \"data_format\": data_format,\n                \"keepdims\": keepdims,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            
assert_built_after_instantiation=True,\n        )\n\n    @parameterized.parameters(\n        (\"channels_last\", False, (3, 5, 6, 5, 4), (3, 4)),\n        (\"channels_last\", True, (3, 5, 6, 5, 4), (3, 1, 1, 1, 4)),\n        (\"channels_first\", False, (3, 5, 6, 5, 4), (3, 5)),\n    )\n    def test_global_average_pooling3d(\n        self,\n        data_format,\n        keepdims,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.GlobalAveragePooling3D,\n            init_kwargs={\n                \"data_format\": data_format,\n                \"keepdims\": keepdims,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            assert_built_after_instantiation=True,\n        )\n\n\nclass GlobalAveragePoolingCorrectnessTest(testing.TestCase):\n    @parameterized.parameters(\n        (\"channels_last\", False),\n        (\"channels_last\", True),\n        (\"channels_first\", False),\n        (\"channels_first\", True),\n    )\n    def test_global_average_pooling1d(self, data_format, keepdims):\n        def np_gap1d(x, data_format, keepdims, mask=None):\n            steps_axis = 1 if data_format == \"channels_last\" else 2\n            if mask is not None:\n                mask = np.expand_dims(\n                    mask, 2 if data_format == \"channels_last\" else 1\n                )\n                x *= mask\n                res = np.sum(x, axis=steps_axis) / np.sum(mask, axis=steps_axis)\n            else:\n                res = np.mean(x, axis=steps_axis)\n            if keepdims:\n                res = np.expand_dims(res, axis=steps_axis)\n            return res\n\n        inputs = np.arange(24, dtype=\"float32\").reshape((2, 3, 4))\n        layer = layers.GlobalAveragePooling1D(\n           
 data_format=data_format,\n            keepdims=keepdims,\n        )\n        outputs = layer(inputs)\n        expected = np_gap1d(inputs, data_format, keepdims)\n        self.assertAllClose(outputs, expected)\n\n        if data_format == \"channels_last\":\n            mask = np.array([[1, 1, 0], [0, 1, 0]], dtype=\"int32\")\n        else:\n            mask = np.array([[1, 1, 0, 0], [0, 1, 0, 1]], dtype=\"int32\")\n        outputs = layer(inputs, mask)\n        expected = np_gap1d(inputs, data_format, keepdims, mask)\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.parameters(\n        (\"channels_last\", False),\n        (\"channels_last\", True),\n        (\"channels_first\", False),\n        (\"channels_first\", True),\n    )\n    def test_global_average_pooling2d(self, data_format, keepdims):\n        def np_gap2d(x, data_format, keepdims):\n            steps_axis = [1, 2] if data_format == \"channels_last\" else [2, 3]\n            res = np.apply_over_axes(np.mean, x, steps_axis)\n            if not keepdims:\n                res = res.squeeze()\n            return res\n\n        inputs = np.arange(96, dtype=\"float32\").reshape((2, 3, 4, 4))\n        layer = layers.GlobalAveragePooling2D(\n            data_format=data_format,\n            keepdims=keepdims,\n        )\n        outputs = layer(inputs)\n        expected = np_gap2d(inputs, data_format, keepdims)\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.parameters(\n        (\"channels_last\", False),\n        (\"channels_last\", True),\n        (\"channels_first\", False),\n        (\"channels_first\", True),\n    )\n    def test_global_average_pooling3d(self, data_format, keepdims):\n        def np_gap3d(x, data_format, keepdims):\n            steps_axis = (\n                [1, 2, 3] if data_format == \"channels_last\" else [2, 3, 4]\n            )\n            res = np.apply_over_axes(np.mean, x, steps_axis)\n            if not keepdims:\n                res 
= res.squeeze()\n            return res\n\n        inputs = np.arange(360, dtype=\"float32\").reshape((2, 3, 3, 5, 4))\n        layer = layers.GlobalAveragePooling3D(\n            data_format=data_format,\n            keepdims=keepdims,\n        )\n        outputs = layer(inputs)\n        expected = np_gap3d(inputs, data_format, keepdims)\n        self.assertAllClose(outputs, expected)\n"
  },
  {
    "path": "keras/src/layers/pooling/global_max_pooling1d.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling\n\n\n@keras_export(\n    [\n        \"keras.layers.GlobalMaxPooling1D\",\n        \"keras.layers.GlobalMaxPool1D\",\n    ]\n)\nclass GlobalMaxPooling1D(BaseGlobalPooling):\n    \"\"\"Global max pooling operation for temporal data.\n\n    Args:\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, steps, features)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, steps)`. It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        keepdims: A boolean, whether to keep the temporal dimension or not.\n            If `keepdims` is `False` (default), the rank of the tensor is\n            reduced for spatial dimensions. 
If `keepdims` is `True`, the\n            temporal dimension is retained with length 1.\n            The behavior is the same as for `tf.reduce_max` or `np.max`.\n\n    Input shape:\n\n    - If `data_format='channels_last'`:\n        3D tensor with shape:\n        `(batch_size, steps, features)`\n    - If `data_format='channels_first'`:\n        3D tensor with shape:\n        `(batch_size, features, steps)`\n\n    Output shape:\n\n    - If `keepdims=False`:\n        2D tensor with shape `(batch_size, features)`.\n    - If `keepdims=True`:\n        - If `data_format=\"channels_last\"`:\n            3D tensor with shape `(batch_size, 1, features)`\n        - If `data_format=\"channels_first\"`:\n            3D tensor with shape `(batch_size, features, 1)`\n\n    Example:\n\n    >>> x = np.random.rand(2, 3, 4)\n    >>> y = keras.layers.GlobalMaxPooling1D()(x)\n    >>> y.shape\n    (2, 4)\n    \"\"\"\n\n    def __init__(self, data_format=None, keepdims=False, **kwargs):\n        super().__init__(\n            pool_dimensions=1,\n            data_format=data_format,\n            keepdims=keepdims,\n            **kwargs,\n        )\n\n    def call(self, inputs):\n        steps_axis = 1 if self.data_format == \"channels_last\" else 2\n        return ops.max(inputs, axis=steps_axis, keepdims=self.keepdims)\n"
  },
  {
    "path": "keras/src/layers/pooling/global_max_pooling2d.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling\n\n\n@keras_export(\n    [\n        \"keras.layers.GlobalMaxPooling2D\",\n        \"keras.layers.GlobalMaxPool2D\",\n    ]\n)\nclass GlobalMaxPooling2D(BaseGlobalPooling):\n    \"\"\"Global max pooling operation for 2D data.\n\n    Args:\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, height, weight)`. It defaults to the\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. If you never set it, then it will be\n            `\"channels_last\"`.\n        keepdims: A boolean, whether to keep the temporal dimension or not.\n            If `keepdims` is `False` (default), the rank of the tensor is\n            reduced for spatial dimensions. 
If `keepdims` is `True`, the\n            spatial dimension are retained with length 1.\n            The behavior is the same as for `tf.reduce_mean` or `np.mean`.\n\n    Input shape:\n\n    - If `data_format='channels_last'`:\n        4D tensor with shape:\n        `(batch_size, height, width, channels)`\n    - If `data_format='channels_first'`:\n        4D tensor with shape:\n        `(batch_size, channels, height, width)`\n\n    Output shape:\n\n    - If `keepdims=False`:\n        2D tensor with shape `(batch_size, channels)`.\n    - If `keepdims=True`:\n        - If `data_format=\"channels_last\"`:\n            4D tensor with shape `(batch_size, 1, 1, channels)`\n        - If `data_format=\"channels_first\"`:\n            4D tensor with shape `(batch_size, channels, 1, 1)`\n\n    Example:\n\n    >>> x = np.random.rand(2, 4, 5, 3)\n    >>> y = keras.layers.GlobalMaxPooling2D()(x)\n    >>> y.shape\n    (2, 3)\n    \"\"\"\n\n    def __init__(self, data_format=None, keepdims=False, **kwargs):\n        super().__init__(\n            pool_dimensions=2,\n            data_format=data_format,\n            keepdims=keepdims,\n            **kwargs,\n        )\n\n    def call(self, inputs):\n        if self.data_format == \"channels_last\":\n            return ops.max(inputs, axis=[1, 2], keepdims=self.keepdims)\n        return ops.max(inputs, axis=[2, 3], keepdims=self.keepdims)\n"
  },
  {
    "path": "keras/src/layers/pooling/global_max_pooling3d.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_global_pooling import BaseGlobalPooling\n\n\n@keras_export(\n    [\n        \"keras.layers.GlobalMaxPooling3D\",\n        \"keras.layers.GlobalMaxPool3D\",\n    ]\n)\nclass GlobalMaxPooling3D(BaseGlobalPooling):\n    \"\"\"Global max pooling operation for 3D data.\n\n    Args:\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape\n            `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.\n            It defaults to the `image_data_format` value found in your Keras\n            config file at `~/.keras/keras.json`. If you never set it, then it\n            will be `\"channels_last\"`.\n        keepdims: A boolean, whether to keep the temporal dimension or not.\n            If `keepdims` is `False` (default), the rank of the tensor is\n            reduced for spatial dimensions. 
If `keepdims` is `True`, the\n            spatial dimension are retained with length 1.\n            The behavior is the same as for `tf.reduce_mean` or `np.mean`.\n\n    Input shape:\n\n    - If `data_format='channels_last'`:\n        5D tensor with shape:\n        `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n    - If `data_format='channels_first'`:\n        5D tensor with shape:\n        `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`\n\n    Output shape:\n\n    - If `keepdims=False`:\n        2D tensor with shape `(batch_size, channels)`.\n    - If `keepdims=True`:\n        - If `data_format=\"channels_last\"`:\n            5D tensor with shape `(batch_size, 1, 1, 1, channels)`\n        - If `data_format=\"channels_first\"`:\n            5D tensor with shape `(batch_size, channels, 1, 1, 1)`\n\n    Example:\n\n    >>> x = np.random.rand(2, 4, 5, 4, 3)\n    >>> y = keras.layers.GlobalMaxPooling3D()(x)\n    >>> y.shape\n    (2, 3)\n    \"\"\"\n\n    def __init__(self, data_format=None, keepdims=False, **kwargs):\n        super().__init__(\n            pool_dimensions=3,\n            data_format=data_format,\n            keepdims=keepdims,\n            **kwargs,\n        )\n\n    def call(self, inputs):\n        if self.data_format == \"channels_last\":\n            return ops.max(inputs, axis=[1, 2, 3], keepdims=self.keepdims)\n        return ops.max(inputs, axis=[2, 3, 4], keepdims=self.keepdims)\n"
  },
  {
    "path": "keras/src/layers/pooling/global_max_pooling_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import layers\nfrom keras.src import testing\n\n\n@pytest.mark.requires_trainable_backend\nclass GlobalMaxPoolingBasicTest(testing.TestCase):\n    @parameterized.parameters(\n        (\"channels_last\", False, (3, 5, 4), (3, 4)),\n        (\"channels_last\", True, (3, 5, 4), (3, 1, 4)),\n        (\"channels_first\", False, (3, 5, 4), (3, 5)),\n    )\n    def test_global_max_pooling1d(\n        self,\n        data_format,\n        keepdims,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.GlobalMaxPooling1D,\n            init_kwargs={\n                \"data_format\": data_format,\n                \"keepdims\": keepdims,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            assert_built_after_instantiation=True,\n        )\n\n    @parameterized.parameters(\n        (\"channels_last\", False, (3, 5, 6, 4), (3, 4)),\n        (\"channels_last\", True, (3, 5, 6, 4), (3, 1, 1, 4)),\n        (\"channels_first\", False, (3, 5, 6, 4), (3, 5)),\n    )\n    def test_global_max_pooling2d(\n        self,\n        data_format,\n        keepdims,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.GlobalMaxPooling2D,\n            init_kwargs={\n                \"data_format\": data_format,\n                \"keepdims\": keepdims,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            
assert_built_after_instantiation=True,\n        )\n\n    @parameterized.parameters(\n        (\"channels_last\", False, (3, 5, 6, 5, 4), (3, 4)),\n        (\"channels_last\", True, (3, 5, 6, 5, 4), (3, 1, 1, 1, 4)),\n        (\"channels_first\", False, (3, 5, 6, 5, 4), (3, 5)),\n    )\n    def test_global_max_pooling3d(\n        self,\n        data_format,\n        keepdims,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.GlobalMaxPooling3D,\n            init_kwargs={\n                \"data_format\": data_format,\n                \"keepdims\": keepdims,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            assert_built_after_instantiation=True,\n        )\n\n\nclass GlobalMaxPoolingCorrectnessTest(testing.TestCase):\n    @parameterized.parameters(\n        (\"channels_last\", False),\n        (\"channels_last\", True),\n        (\"channels_first\", False),\n        (\"channels_first\", True),\n    )\n    def test_global_max_pooling1d(self, data_format, keepdims):\n        def np_global_max_pool1d(x, data_format, keepdims):\n            steps_axis = [1] if data_format == \"channels_last\" else [2]\n            res = np.apply_over_axes(np.max, x, steps_axis)\n            if not keepdims:\n                res = res.squeeze()\n            return res\n\n        inputs = np.arange(24, dtype=\"float32\").reshape((2, 3, 4))\n        layer = layers.GlobalMaxPooling1D(\n            data_format=data_format,\n            keepdims=keepdims,\n        )\n        outputs = layer(inputs)\n        expected = np_global_max_pool1d(inputs, data_format, keepdims)\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.parameters(\n        (\"channels_last\", False),\n        
(\"channels_last\", True),\n        (\"channels_first\", False),\n        (\"channels_first\", True),\n    )\n    def test_global_max_pooling2d(self, data_format, keepdims):\n        def np_global_max_pool2d(x, data_format, keepdims):\n            steps_axis = [1, 2] if data_format == \"channels_last\" else [2, 3]\n            res = np.apply_over_axes(np.max, x, steps_axis)\n            if not keepdims:\n                res = res.squeeze()\n            return res\n\n        inputs = np.arange(96, dtype=\"float32\").reshape((2, 3, 4, 4))\n        layer = layers.GlobalMaxPooling2D(\n            data_format=data_format,\n            keepdims=keepdims,\n        )\n        outputs = layer(inputs)\n        expected = np_global_max_pool2d(inputs, data_format, keepdims)\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.parameters(\n        (\"channels_last\", False),\n        (\"channels_last\", True),\n        (\"channels_first\", False),\n        (\"channels_first\", True),\n    )\n    def test_global_max_pooling3d(self, data_format, keepdims):\n        def np_global_max_pool3d(x, data_format, keepdims):\n            steps_axis = (\n                [1, 2, 3] if data_format == \"channels_last\" else [2, 3, 4]\n            )\n            res = np.apply_over_axes(np.max, x, steps_axis)\n            if not keepdims:\n                res = res.squeeze()\n            return res\n\n        inputs = np.arange(360, dtype=\"float32\").reshape((2, 3, 3, 5, 4))\n        layer = layers.GlobalMaxPooling3D(\n            data_format=data_format,\n            keepdims=keepdims,\n        )\n        outputs = layer(inputs)\n        expected = np_global_max_pool3d(inputs, data_format, keepdims)\n        self.assertAllClose(outputs, expected)\n"
  },
  {
    "path": "keras/src/layers/pooling/max_pooling1d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_pooling import BasePooling\n\n\n@keras_export([\"keras.layers.MaxPooling1D\", \"keras.layers.MaxPool1D\"])\nclass MaxPooling1D(BasePooling):\n    \"\"\"Max pooling operation for 1D temporal data.\n\n    Downsamples the input representation by taking the maximum value over a\n    spatial window of size `pool_size`. The window is shifted by `strides`.\n\n    The resulting output when using the `\"valid\"` padding option has a shape of:\n    `output_shape = (input_shape - pool_size + 1) / strides)`.\n\n    The resulting output shape when using the `\"same\"` padding option is:\n    `output_shape = input_shape / strides`\n\n    Args:\n        pool_size: int, size of the max pooling window.\n        strides: int or None. Specifies how much the pooling window moves\n            for each pooling step. If None, it will default to `pool_size`.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input such that output has the same\n            height/width dimension as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, steps, features)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, features, steps)`. 
It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        3D tensor with shape `(batch_size, steps, features)`.\n    - If `data_format=\"channels_first\"`:\n        3D tensor with shape `(batch_size, features, steps)`.\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        3D tensor with shape `(batch_size, downsampled_steps, features)`.\n    - If `data_format=\"channels_first\"`:\n        3D tensor with shape `(batch_size, features, downsampled_steps)`.\n\n    Examples:\n\n    `strides=1` and `padding=\"valid\"`:\n\n    >>> x = np.array([1., 2., 3., 4., 5.])\n    >>> x = np.reshape(x, [1, 5, 1])\n    >>> max_pool_1d = keras.layers.MaxPooling1D(pool_size=2,\n    ...    strides=1, padding=\"valid\")\n    >>> max_pool_1d(x)\n\n    `strides=2` and `padding=\"valid\"`:\n\n    >>> x = np.array([1., 2., 3., 4., 5.])\n    >>> x = np.reshape(x, [1, 5, 1])\n    >>> max_pool_1d = keras.layers.MaxPooling1D(pool_size=2,\n    ...    strides=2, padding=\"valid\")\n    >>> max_pool_1d(x)\n\n    `strides=1` and `padding=\"same\"`:\n\n    >>> x = np.array([1., 2., 3., 4., 5.])\n    >>> x = np.reshape(x, [1, 5, 1])\n    >>> max_pool_1d = keras.layers.MaxPooling1D(pool_size=2,\n    ...    strides=1, padding=\"same\")\n    >>> max_pool_1d(x)\n    \"\"\"\n\n    def __init__(\n        self,\n        pool_size=2,\n        strides=None,\n        padding=\"valid\",\n        data_format=None,\n        name=None,\n        **kwargs,\n    ):\n        super().__init__(\n            pool_size,\n            strides,\n            pool_dimensions=1,\n            pool_mode=\"max\",\n            padding=padding,\n            data_format=data_format,\n            name=name,\n            **kwargs,\n        )\n"
  },
  {
    "path": "keras/src/layers/pooling/max_pooling2d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_pooling import BasePooling\n\n\n@keras_export([\"keras.layers.MaxPooling2D\", \"keras.layers.MaxPool2D\"])\nclass MaxPooling2D(BasePooling):\n    \"\"\"Max pooling operation for 2D spatial data.\n\n    Downsamples the input along its spatial dimensions (height and width)\n    by taking the maximum value over an input window\n    (of size defined by `pool_size`) for each channel of the input.\n    The window is shifted by `strides` along each dimension.\n\n    The resulting output when using the `\"valid\"` padding option has a spatial\n    shape (number of rows or columns) of:\n    `output_shape = math.floor((input_shape - pool_size) / strides) + 1`\n    (when `input_shape >= pool_size`)\n\n    The resulting output shape when using the `\"same\"` padding option is:\n    `output_shape = math.floor((input_shape - 1) / strides) + 1`\n\n    Args:\n        pool_size: int or tuple of 2 integers, factors by which to downscale\n            (dim1, dim2). If only one integer is specified, the same\n            window length will be used for all dimensions.\n        strides: int or tuple of 2 integers, or None. Strides values. If None,\n            it will default to `pool_size`. If only one int is specified, the\n            same stride size will be used for all dimensions.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input such that output has the same\n            height/width dimension as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. 
`\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`. It defaults to the\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. If you never set it, then it will be\n            `\"channels_last\"`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        4D tensor with shape `(batch_size, height, width, channels)`.\n    - If `data_format=\"channels_first\"`:\n        4D tensor with shape `(batch_size, channels, height, width)`.\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        4D tensor with shape\n        `(batch_size, pooled_height, pooled_width, channels)`.\n    - If `data_format=\"channels_first\"`:\n        4D tensor with shape\n        `(batch_size, channels, pooled_height, pooled_width)`.\n\n    Examples:
\n\n    `strides=(1, 1)` and `padding=\"valid\"`:\n\n    >>> x = np.array([[1., 2., 3.],\n    ...               [4., 5., 6.],\n    ...               [7., 8., 9.]])\n    >>> x = np.reshape(x, [1, 3, 3, 1])\n    >>> max_pool_2d = keras.layers.MaxPooling2D(pool_size=(2, 2),\n    ...    strides=(1, 1), padding=\"valid\")\n    >>> max_pool_2d(x)
\n\n    `strides=(2, 2)` and `padding=\"valid\"`:\n\n    >>> x = np.array([[1., 2., 3., 4.],\n    ...               [5., 6., 7., 8.],\n    ...               [9., 10., 11., 12.]])\n    >>> x = np.reshape(x, [1, 3, 4, 1])\n    >>> max_pool_2d = keras.layers.MaxPooling2D(pool_size=(2, 2),\n    ...    strides=(2, 2), padding=\"valid\")\n    >>> max_pool_2d(x)
\n\n    `strides=(1, 1)` and `padding=\"same\"`:\n\n    >>> x = np.array([[1., 2., 3.],\n    ...               [4., 5., 6.],\n    ...               [7., 8., 9.]])\n    >>> x = np.reshape(x, [1, 3, 3, 1])\n    >>> max_pool_2d = keras.layers.MaxPooling2D(pool_size=(2, 2),\n    ...    strides=(1, 1), padding=\"same\")\n    >>> max_pool_2d(x)\n    \"\"\"
\n\n    def __init__(\n        self,\n        pool_size=(2, 2),\n        strides=None,\n        padding=\"valid\",\n        data_format=None,\n        name=None,\n        **kwargs,\n    ):\n        super().__init__(\n            pool_size,\n            strides,\n            pool_dimensions=2,\n            pool_mode=\"max\",\n            padding=padding,\n            data_format=data_format,\n            name=name,\n            **kwargs,\n        )\n"
  },
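The two output-shape formulas quoted in the `MaxPooling2D` docstring can be evaluated per spatial dimension. `pooled_size` below is a hypothetical checking helper, not a Keras API:

```python
import math

def pooled_size(input_size, pool_size, stride, padding):
    """Output length of one spatial dimension, per the docstring formulas."""
    if padding == "valid":
        # Assumes input_size >= pool_size.
        return math.floor((input_size - pool_size) / stride) + 1
    if padding == "same":
        return math.floor((input_size - 1) / stride) + 1
    raise ValueError(f"Unknown padding: {padding}")

# 3-wide input, pool 2, stride 1: "valid" shrinks to 2, "same" keeps 3.
print(pooled_size(3, 2, 1, "valid"), pooled_size(3, 2, 1, "same"))  # 2 3
# 4-wide input, pool 2, stride 2: both paddings give 2.
print(pooled_size(4, 2, 2, "valid"), pooled_size(4, 2, 2, "same"))  # 2 2
```

Applying it to each of height and width reproduces the docstring examples, e.g. a 3x3 input with `pool_size=(2, 2)`, `strides=(1, 1)` pools to 2x2 under `"valid"` and stays 3x3 under `"same"`.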
  {
    "path": "keras/src/layers/pooling/max_pooling3d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.pooling.base_pooling import BasePooling\n\n\n@keras_export([\"keras.layers.MaxPooling3D\", \"keras.layers.MaxPool3D\"])\nclass MaxPooling3D(BasePooling):\n    \"\"\"Max pooling operation for 3D data (spatial or spatio-temporal).\n\n    Downsamples the input along its spatial dimensions (depth, height, and\n    width) by taking the maximum value over an input window (of size defined by\n    `pool_size`) for each channel of the input. The window is shifted by\n    `strides` along each dimension.\n\n    Args:\n        pool_size: int or tuple of 3 integers, factors by which to downscale\n            (dim1, dim2, dim3). If only one integer is specified, the same\n            window length will be used for all dimensions.\n        strides: int or tuple of 3 integers, or None. Strides values. If None,\n            it will default to `pool_size`. If only one int is specified, the\n            same stride size will be used for all dimensions.\n        padding: string, either `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input such that output has the same\n            height/width dimension as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape\n            `(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)` while\n            `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.\n            It defaults to the `image_data_format` value found in your Keras\n            config file at `~/.keras/keras.json`. 
If you never set it, then it\n            will be `\"channels_last\"`.\n\n    Input shape:\n\n    - If `data_format=\"channels_last\"`:\n        5D tensor with shape:\n        `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n    - If `data_format=\"channels_first\"`:\n        5D tensor with shape:\n        `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`\n\n    Output shape:\n\n    - If `data_format=\"channels_last\"`:\n        5D tensor with shape:\n        `(batch_size, pooled_dim1, pooled_dim2, pooled_dim3, channels)`\n    - If `data_format=\"channels_first\"`:\n        5D tensor with shape:\n        `(batch_size, channels, pooled_dim1, pooled_dim2, pooled_dim3)`\n\n    Example:\n\n    ```python\n    depth = 30\n    height = 30\n    width = 30\n    channels = 3\n\n    inputs = keras.layers.Input(shape=(depth, height, width, channels))\n    layer = keras.layers.MaxPooling3D(pool_size=3)\n    outputs = layer(inputs)  # Shape: (batch_size, 10, 10, 10, 3)\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        pool_size=(2, 2, 2),\n        strides=None,\n        padding=\"valid\",\n        data_format=None,\n        name=None,\n        **kwargs,\n    ):\n        super().__init__(\n            pool_size,\n            strides,\n            pool_dimensions=3,\n            pool_mode=\"max\",\n            padding=padding,\n            data_format=data_format,\n            name=name,\n            **kwargs,\n        )\n"
  },
  {
    "path": "keras/src/layers/pooling/max_pooling_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom numpy.lib.stride_tricks import as_strided\n\nfrom keras.src import layers\nfrom keras.src import testing\n\n\ndef _same_padding(input_size, pool_size, stride):\n    if input_size % stride == 0:\n        return max(pool_size - stride, 0)\n    else:\n        return max(pool_size - (input_size % stride), 0)\n\n\ndef np_maxpool1d(x, pool_size, strides, padding, data_format):\n    if data_format == \"channels_first\":\n        x = x.swapaxes(1, 2)\n    if isinstance(pool_size, (tuple, list)):\n        pool_size = pool_size[0]\n    if isinstance(strides, (tuple, list)):\n        h_stride = strides[0]\n    else:\n        h_stride = strides\n\n    if padding == \"same\":\n        n_batch, h_x, ch_x = x.shape\n        pad_value = _same_padding(h_x, pool_size, h_stride)\n        npad = [(0, 0)] * x.ndim\n        npad[1] = (0, pad_value)\n        x = np.pad(x, pad_width=npad, mode=\"constant\", constant_values=-np.inf)\n\n    n_batch, h_x, ch_x = x.shape\n    out_h = int((h_x - pool_size) / h_stride) + 1\n\n    stride_shape = (n_batch, out_h, ch_x, pool_size)\n    strides = (\n        x.strides[0],\n        h_stride * x.strides[1],\n        x.strides[2],\n        x.strides[1],\n    )\n    windows = as_strided(x, shape=stride_shape, strides=strides)\n    out = np.max(windows, axis=(3,))\n    if data_format == \"channels_first\":\n        out = out.swapaxes(1, 2)\n    return out\n\n\ndef np_maxpool2d(x, pool_size, strides, padding, data_format):\n    if data_format == \"channels_first\":\n        x = x.transpose((0, 2, 3, 1))\n    if isinstance(pool_size, int):\n        pool_size = (pool_size, pool_size)\n    if isinstance(strides, int):\n        strides = (strides, strides)\n\n    h_pool_size, w_pool_size = pool_size\n    h_stride, w_stride = strides\n    if padding == \"same\":\n        n_batch, h_x, w_x, ch_x = x.shape\n        h_padding = _same_padding(h_x, h_pool_size, 
h_stride)\n        w_padding = _same_padding(w_x, w_pool_size, w_stride)\n        npad = [(0, 0)] * x.ndim\n        npad[1] = (0, h_padding)\n        npad[2] = (0, w_padding)\n        x = np.pad(x, pad_width=npad, mode=\"constant\", constant_values=-np.inf)\n\n    n_batch, h_x, w_x, ch_x = x.shape\n    out_h = int((h_x - h_pool_size) / h_stride) + 1\n    out_w = int((w_x - w_pool_size) / w_stride) + 1\n\n    stride_shape = (n_batch, out_h, out_w, ch_x, *pool_size)\n    strides = (\n        x.strides[0],\n        h_stride * x.strides[1],\n        w_stride * x.strides[2],\n        x.strides[3],\n        x.strides[1],\n        x.strides[2],\n    )\n    windows = as_strided(x, shape=stride_shape, strides=strides)\n    out = np.max(windows, axis=(4, 5))\n    if data_format == \"channels_first\":\n        out = out.transpose((0, 3, 1, 2))\n    return out\n\n\ndef np_maxpool3d(x, pool_size, strides, padding, data_format):\n    if data_format == \"channels_first\":\n        x = x.transpose((0, 2, 3, 4, 1))\n\n    if isinstance(pool_size, int):\n        pool_size = (pool_size, pool_size, pool_size)\n    if isinstance(strides, int):\n        strides = (strides, strides, strides)\n\n    h_pool_size, w_pool_size, d_pool_size = pool_size\n    h_stride, w_stride, d_stride = strides\n\n    if padding == \"same\":\n        n_batch, h_x, w_x, d_x, ch_x = x.shape\n        h_padding = _same_padding(h_x, h_pool_size, h_stride)\n        w_padding = _same_padding(w_x, w_pool_size, w_stride)\n        d_padding = _same_padding(d_x, d_pool_size, d_stride)\n        npad = [(0, 0)] * x.ndim\n        npad[1] = (0, h_padding)\n        npad[2] = (0, w_padding)\n        npad[3] = (0, d_padding)\n        x = np.pad(x, pad_width=npad, mode=\"constant\", constant_values=-np.inf)\n\n    n_batch, h_x, w_x, d_x, ch_x = x.shape\n    out_h = int((h_x - h_pool_size) / h_stride) + 1\n    out_w = int((w_x - w_pool_size) / w_stride) + 1\n    out_d = int((d_x - d_pool_size) / d_stride) + 1\n\n    
stride_shape = (n_batch, out_h, out_w, out_d, ch_x, *pool_size)\n    strides = (\n        x.strides[0],\n        h_stride * x.strides[1],\n        w_stride * x.strides[2],\n        d_stride * x.strides[3],\n        x.strides[4],\n        x.strides[1],\n        x.strides[2],\n        x.strides[3],\n    )\n    windows = as_strided(x, shape=stride_shape, strides=strides)\n    out = np.max(windows, axis=(5, 6, 7))\n    if data_format == \"channels_first\":\n        out = out.transpose((0, 4, 1, 2, 3))\n    return out\n\n\n@pytest.mark.requires_trainable_backend\nclass MaxPoolingBasicTest(testing.TestCase):\n    @parameterized.parameters(\n        (2, 1, \"valid\", \"channels_last\", (3, 5, 4), (3, 4, 4)),\n        (2, 1, \"same\", \"channels_first\", (3, 5, 4), (3, 5, 4)),\n        ((2,), (2,), \"valid\", \"channels_last\", (3, 5, 4), (3, 2, 4)),\n    )\n    def test_max_pooling1d(\n        self,\n        pool_size,\n        strides,\n        padding,\n        data_format,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.MaxPooling1D,\n            init_kwargs={\n                \"pool_size\": pool_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            assert_built_after_instantiation=True,\n        )\n\n    @parameterized.parameters(\n        (2, 1, \"valid\", \"channels_last\", (3, 5, 5, 4), (3, 4, 4, 4)),\n        (2, 1, \"same\", \"channels_first\", (3, 5, 5, 4), (3, 5, 5, 4)),\n        ((2, 3), (2, 2), \"valid\", \"channels_last\", (3, 5, 5, 4), (3, 2, 2, 4)),\n    )\n    def test_max_pooling2d(\n        self,\n        pool_size,\n        strides,\n        
padding,\n        data_format,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.MaxPooling2D,\n            init_kwargs={\n                \"pool_size\": pool_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            assert_built_after_instantiation=True,\n        )\n\n    @parameterized.parameters(\n        (2, 1, \"valid\", \"channels_last\", (3, 5, 5, 5, 4), (3, 4, 4, 4, 4)),\n        (2, 1, \"same\", \"channels_first\", (3, 5, 5, 5, 4), (3, 5, 5, 5, 4)),\n        (\n            (2, 3, 2),\n            (2, 2, 1),\n            \"valid\",\n            \"channels_last\",\n            (3, 5, 5, 5, 4),\n            (3, 2, 2, 4, 4),\n        ),\n    )\n    def test_max_pooling3d(\n        self,\n        pool_size,\n        strides,\n        padding,\n        data_format,\n        input_shape,\n        output_shape,\n    ):\n        self.run_layer_test(\n            layers.MaxPooling3D,\n            init_kwargs={\n                \"pool_size\": pool_size,\n                \"strides\": strides,\n                \"padding\": padding,\n                \"data_format\": data_format,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            # Incomplete op support on tensorflow.\n            run_mixed_precision_check=False,\n            assert_built_after_instantiation=True,\n        )\n\n\nclass MaxPoolingCorrectnessTest(testing.TestCase):\n    
@parameterized.parameters(\n        (2, 1, \"valid\", \"channels_last\"),\n        (2, 1, \"valid\", \"channels_first\"),\n        (2, 1, \"same\", \"channels_last\"),\n        (2, 1, \"same\", \"channels_first\"),\n        ((2,), (2,), \"valid\", \"channels_last\"),\n        ((2,), (2,), \"valid\", \"channels_first\"),\n    )\n    def test_max_pooling1d(self, pool_size, strides, padding, data_format):\n        inputs = np.arange(24, dtype=\"float32\").reshape((2, 3, 4))\n\n        layer = layers.MaxPooling1D(\n            pool_size=pool_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n        )\n        outputs = layer(inputs)\n        expected = np_maxpool1d(\n            inputs, pool_size, strides, padding, data_format\n        )\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.parameters(\n        (2, 1, \"valid\", \"channels_last\"),\n        (2, 1, \"valid\", \"channels_first\"),\n        ((2, 2), (2, 2), \"same\", \"channels_last\"),\n        ((2, 2), (2, 2), \"same\", \"channels_first\"),\n        ((2, 3), (3, 3), \"same\", \"channels_last\"),\n    )\n    def test_max_pooling2d(self, pool_size, strides, padding, data_format):\n        inputs = np.arange(100, dtype=\"float32\").reshape((1, 5, 5, 4))\n\n        layer = layers.MaxPooling2D(\n            pool_size=pool_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n        )\n        outputs = layer(inputs)\n        expected = np_maxpool2d(\n            inputs, pool_size, strides, padding, data_format\n        )\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.parameters(\n        (2, 1, \"valid\", \"channels_last\"),\n        (2, 1, \"same\", \"channels_first\"),\n        ((2, 3, 2), (2, 2, 1), \"valid\", \"channels_last\"),\n        ((2, 3, 2), (2, 2, 1), \"valid\", \"channels_first\"),\n    )\n    def test_max_pooling3d(self, pool_size, strides, padding, 
data_format):\n        inputs = np.arange(240, dtype=\"float32\").reshape((2, 3, 4, 5, 2))\n\n        layer = layers.MaxPooling3D(\n            pool_size=pool_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n        )\n        outputs = layer(inputs)\n        expected = np_maxpool3d(\n            inputs, pool_size, strides, padding, data_format\n        )\n        self.assertAllClose(outputs, expected)\n"
  },
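The `_same_padding` helper in the test file above picks the amount of right-padding so that `"same"` pooling yields `floor((input_size - 1) / stride) + 1` windows. A standalone sketch of that invariant (same formula, illustrative inputs):

```python
def same_padding(input_size, pool_size, stride):
    # Mirrors the _same_padding helper from the test file above.
    if input_size % stride == 0:
        return max(pool_size - stride, 0)
    return max(pool_size - (input_size % stride), 0)

for input_size, pool_size, stride in [(5, 2, 1), (5, 2, 2), (6, 3, 2)]:
    pad = same_padding(input_size, pool_size, stride)
    out_len = (input_size + pad - pool_size) // stride + 1
    same_len = (input_size - 1) // stride + 1  # documented "same" length
    assert out_len == same_len
    print(input_size, pool_size, stride, "-> pad", pad, "windows", out_len)
```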
  {
    "path": "keras/src/layers/preprocessing/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/layers/preprocessing/category_encoding.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.layers.preprocessing.data_layer import DataLayer\nfrom keras.src.utils import backend_utils\nfrom keras.src.utils import numerical_utils\n\n\n@keras_export(\"keras.layers.CategoryEncoding\")\nclass CategoryEncoding(DataLayer):\n    \"\"\"A preprocessing layer which encodes integer features.\n\n    This layer provides options for condensing data into a categorical encoding\n    when the total number of tokens is known in advance. It accepts integer\n    values as inputs, and it outputs a dense or sparse representation of those\n    inputs. For integer inputs where the total number of tokens is not known,\n    use `keras.layers.IntegerLookup` instead.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Examples:\n\n    **One-hot encoding data**\n\n    >>> layer = keras.layers.CategoryEncoding(\n    ...           num_tokens=4, output_mode=\"one_hot\")\n    >>> layer([3, 2, 0, 1])\n    array([[0., 0., 0., 1.],\n            [0., 0., 1., 0.],\n            [1., 0., 0., 0.],\n            [0., 1., 0., 0.]])
\n\n    **Multi-hot encoding data**\n\n    >>> layer = keras.layers.CategoryEncoding(\n    ...           num_tokens=4, output_mode=\"multi_hot\")\n    >>> layer([[0, 1], [0, 0], [1, 2], [3, 1]])\n    array([[1., 1., 0., 0.],\n            [1., 0., 0., 0.],\n            [0., 1., 1., 0.],\n            [0., 1., 0., 1.]])
\n\n    **Using weighted inputs in `\"count\"` mode**\n\n    >>> layer = keras.layers.CategoryEncoding(\n    ...           num_tokens=4, output_mode=\"count\")\n    >>> count_weights = np.array([[.1, .2], [.1, .1], [.2, .3], [.4, .2]])\n    >>> layer([[0, 1], [0, 0], [1, 2], [3, 1]], count_weights=count_weights)\n    array([[0.1, 0.2, 0. , 0. ],\n            [0.2, 0. , 0. , 0. ],\n            [0. , 0.2, 0.3, 0. ],\n            [0. , 0.2, 0. , 0.4]])
\n\n    Args:\n        num_tokens: The total number of tokens the layer should support. All\n            inputs to the layer must be integers in the range `0 <= value <\n            num_tokens`, or an error will be thrown.\n        output_mode: Specification for the output of the layer.\n            Values can be `\"one_hot\"`, `\"multi_hot\"` or `\"count\"`,\n            configuring the layer as follows:\n                - `\"one_hot\"`: Encodes each individual element in the input\n                    into an array of `num_tokens` size, containing a 1 at the\n                    element index. If the last dimension is size 1, will encode\n                    on that dimension. If the last dimension is not size 1,\n                    will append a new dimension for the encoded output.\n                - `\"multi_hot\"`: Encodes each sample in the input into a single\n                    array of `num_tokens` size, containing a 1 for each\n                    vocabulary term present in the sample. Treats the last\n                    dimension as the sample dimension: if the input shape is\n                    `(..., sample_length)`, the output shape will be\n                    `(..., num_tokens)`.\n                - `\"count\"`: Like `\"multi_hot\"`, but the int array contains a\n                    count of the number of times the token at that index\n                    appeared in the sample.\n            For all output modes, only outputs up to rank 2 are currently\n            supported.\n            Defaults to `\"multi_hot\"`.\n        sparse: Whether to return a sparse tensor; for backends that support\n            sparse tensors.\n\n    Call arguments:\n        inputs: A 1D or 2D tensor of integer inputs.\n        count_weights: A tensor in the same shape as `inputs` indicating the\n            weight for each sample value when summing up in `count` mode.\n            Not used in `\"multi_hot\"` or `\"one_hot\"` modes.\n    \"\"\"
\n\n    def __init__(\n        self, num_tokens=None, output_mode=\"multi_hot\", sparse=False, **kwargs\n    ):\n        super().__init__(**kwargs)\n\n        # Support deprecated names for output_modes.\n        if output_mode == \"binary\":\n            output_mode = \"multi_hot\"\n\n        # 'output_mode' must be one of (\"count\", \"one_hot\", \"multi_hot\")\n        if output_mode not in (\"count\", \"one_hot\", \"multi_hot\"):\n            raise ValueError(f\"Unknown arg for output_mode: {output_mode}\")\n\n        if num_tokens is None:\n            raise ValueError(\n                \"num_tokens must be set to use this layer. If the \"\n                \"number of tokens is not known beforehand, use the \"\n                \"IntegerLookup layer instead.\"\n            )\n        if num_tokens < 1:\n            raise ValueError(\n                f\"`num_tokens` must be >= 1. Received: num_tokens={num_tokens}.\"\n            )\n        self.num_tokens = num_tokens\n        self.output_mode = output_mode\n        self.sparse = sparse\n        self._allow_non_tensor_positional_args = True\n        self._convert_input_args = False\n\n    def _encode(self, inputs, count_weights=None):\n        inputs = self.backend.core.convert_to_tensor(inputs)\n        return numerical_utils.encode_categorical_inputs(\n            inputs,\n            output_mode=self.output_mode,\n            depth=self.num_tokens,\n            dtype=self.dtype,\n            sparse=self.sparse,\n            count_weights=count_weights,\n            backend_module=self.backend,\n        )
\n\n    def compute_output_shape(self, input_shape):\n        if (input_shape is not None) and (len(input_shape) == 0):\n            return (self.num_tokens,)\n        if self.output_mode == \"one_hot\":\n            if input_shape[-1] != 1:\n                return tuple(input_shape) + (self.num_tokens,)\n            elif len(input_shape) == 1:\n                return tuple(input_shape) + (self.num_tokens,)\n            else:\n                return tuple(input_shape[:-1]) + (self.num_tokens,)\n        return tuple(input_shape[:-1]) + (self.num_tokens,)\n\n    def compute_output_spec(self, inputs, count_weights=None):\n        output_shape = self.compute_output_shape(inputs.shape)\n        return KerasTensor(\n            output_shape, dtype=self.compute_dtype, sparse=self.sparse\n        )\n\n    def get_config(self):\n        config = {\n            \"num_tokens\": self.num_tokens,\n            \"output_mode\": self.output_mode,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    def call(self, inputs, count_weights=None):\n        if count_weights is not None:\n            if self.output_mode != \"count\":\n                raise ValueError(\n                    \"`count_weights` is not used when `output_mode` is not \"\n                    f\"`'count'`. Received `count_weights={count_weights}`.\"\n                )\n            count_weights = self.backend.convert_to_tensor(\n                count_weights, dtype=self.compute_dtype\n            )\n        outputs = self._encode(inputs, count_weights)\n        return backend_utils.convert_tf_tensor(outputs)\n"
  },
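The three `output_mode` behaviors documented in `CategoryEncoding` can be mimicked with plain NumPy over a batch of rank-1 samples. `encode` below is a hypothetical sketch of the documented semantics, not the layer's actual implementation:

```python
import numpy as np

def encode(inputs, num_tokens, mode="multi_hot", count_weights=None):
    """NumPy sketch of the one_hot / multi_hot / count output modes."""
    inputs = np.asarray(inputs)
    if mode == "one_hot":
        # Each element becomes a length-num_tokens one-hot row.
        return np.eye(num_tokens)[inputs]
    out = np.zeros((len(inputs), num_tokens))
    for row, idx in enumerate(inputs):
        if mode == "multi_hot":
            out[row, idx] = 1.0
        else:  # "count": sum (optionally weighted) occurrences per token.
            w = None if count_weights is None else count_weights[row]
            out[row] = np.bincount(idx, weights=w, minlength=num_tokens)
    return out

samples = [[0, 1], [0, 0], [1, 2], [3, 1]]
print(encode(samples, 4, mode="multi_hot"))
print(encode(samples, 4, mode="count",
             count_weights=[[0.1, 0.2], [0.1, 0.1], [0.2, 0.3], [0.4, 0.2]]))
```

With these inputs the sketch reproduces the multi-hot and weighted-count arrays shown in the layer's docstring examples.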
  {
    "path": "keras/src/layers/preprocessing/category_encoding_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\nTEST_CASES = [{\"testcase_name\": \"dense\", \"sparse\": False}]\nif backend.SUPPORTS_SPARSE_TENSORS:\n    TEST_CASES += [{\"testcase_name\": \"sparse\", \"sparse\": True}]\n\n\nclass CategoryEncodingTest(testing.TestCase):\n    @parameterized.named_parameters(TEST_CASES)\n    def test_count_output(self, sparse):\n        input_array = np.array([1, 2, 3, 1])\n        expected_output = np.array([0, 2, 1, 1, 0, 0])\n\n        num_tokens = 6\n        expected_output_shape = (num_tokens,)\n\n        layer = layers.CategoryEncoding(\n            num_tokens=num_tokens, output_mode=\"count\", sparse=sparse\n        )\n        int_data = layer(input_array)\n        self.assertEqual(expected_output_shape, int_data.shape)\n        self.assertAllClose(int_data, expected_output)\n        self.assertSparse(int_data, sparse)\n\n        # Test symbolic call.\n        output = layer(\n            layers.Input(batch_shape=input_array.shape, dtype=\"int32\")\n        )\n        self.assertEqual(expected_output_shape, output.shape)\n        self.assertEqual(\"float32\", output.dtype)\n        self.assertSparse(output, sparse)\n\n    @parameterized.named_parameters(TEST_CASES)\n    def test_count_weighted_output(self, sparse):\n        input_array = np.array([[0, 1], [0, 0], [1, 2], [3, 1]])\n        count_weights = np.array(\n            [[0.1, 0.2], [0.1, 0.1], [0.2, 0.3], [0.4, 0.2]]\n        )\n        expected_output = np.array(\n            [\n                [0.1, 0.2, 0.0, 0.0, 0.0, 0.0],\n                [0.2, 0.0, 0.0, 0.0, 0.0, 0.0],\n                [0.0, 0.2, 0.3, 0.0, 0.0, 0.0],\n                [0.0, 0.2, 0.0, 0.4, 0.0, 0.0],\n            ]\n        )\n\n        num_tokens = 6\n        expected_output_shape = (input_array.shape[0], num_tokens)\n\n        layer = 
layers.CategoryEncoding(\n            num_tokens=num_tokens, output_mode=\"count\", sparse=sparse\n        )\n        int_data = layer(input_array, count_weights=count_weights)\n        self.assertEqual(expected_output_shape, int_data.shape)\n        self.assertAllClose(int_data, expected_output)\n        self.assertSparse(int_data, sparse)\n\n        # Test symbolic call.\n        output = layer(\n            layers.Input(batch_shape=input_array.shape, dtype=\"int32\"),\n            count_weights=layers.Input(\n                batch_shape=input_array.shape, dtype=\"float32\"\n            ),\n        )\n        self.assertEqual(expected_output_shape, output.shape)\n        self.assertEqual(\"float32\", output.dtype)\n        self.assertSparse(output, sparse)\n\n    @parameterized.named_parameters(TEST_CASES)\n    def test_batched_count_output(self, sparse):\n        input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])\n        expected_output = np.array([[0, 2, 1, 1, 0, 0], [2, 1, 0, 1, 0, 0]])\n\n        num_tokens = 6\n        expected_output_shape = (2, num_tokens)\n\n        layer = layers.CategoryEncoding(\n            num_tokens=num_tokens, output_mode=\"count\", sparse=sparse\n        )\n        int_data = layer(input_array)\n        self.assertEqual(expected_output_shape, int_data.shape)\n        self.assertAllClose(int_data, expected_output)\n        self.assertSparse(int_data, sparse)\n\n        # Test symbolic call.\n        output = layer(\n            layers.Input(batch_shape=input_array.shape, dtype=\"int32\")\n        )\n        self.assertEqual(expected_output_shape, output.shape)\n        self.assertEqual(\"float32\", output.dtype)\n        self.assertSparse(output, sparse)\n\n    @parameterized.named_parameters(TEST_CASES)\n    def test_multi_hot(self, sparse):\n        input_data = np.array([3, 2, 0, 1])\n        expected_output = np.array([1, 1, 1, 1, 0, 0])\n        num_tokens = 6\n        expected_output_shape = (num_tokens,)\n\n        # Test 
call on layer directly.\n        layer = layers.CategoryEncoding(\n            num_tokens=num_tokens, output_mode=\"multi_hot\", sparse=sparse\n        )\n        output_data = layer(input_data)\n        self.assertAllClose(expected_output, output_data)\n        self.assertEqual(expected_output_shape, output_data.shape)\n        self.assertSparse(output_data, sparse)\n\n        # Test symbolic call.\n        output = layer(\n            layers.Input(batch_shape=input_data.shape, dtype=\"int32\")\n        )\n        self.assertEqual(expected_output_shape, output.shape)\n        self.assertEqual(\"float32\", output.dtype)\n        self.assertSparse(output, sparse)\n\n    @parameterized.named_parameters(TEST_CASES)\n    def test_batched_multi_hot(self, sparse):\n        input_data = np.array([[3, 2, 0, 1], [3, 2, 0, 1]])\n        expected_output = np.array([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 0, 0]])\n        num_tokens = 6\n        expected_output_shape = (input_data.shape[0], num_tokens)\n\n        # Test call on layer directly.\n        layer = layers.CategoryEncoding(\n            num_tokens=num_tokens, output_mode=\"multi_hot\", sparse=sparse\n        )\n        output_data = layer(input_data)\n        self.assertAllClose(expected_output, output_data)\n        self.assertEqual(expected_output_shape, output_data.shape)\n        self.assertSparse(output_data, sparse)\n\n        # Test symbolic call.\n        output = layer(\n            layers.Input(batch_shape=input_data.shape, dtype=\"int32\")\n        )\n        self.assertEqual(expected_output_shape, output.shape)\n        self.assertEqual(\"float32\", output.dtype)\n        self.assertSparse(output, sparse)\n\n        # Test compute_output_shape\n        input_data = np.array((4))\n        layer = layers.CategoryEncoding(\n            num_tokens=num_tokens, output_mode=\"multi_hot\", sparse=sparse\n        )\n        self.assertEqual(\n            layer(input_data).shape,\n            
layer.compute_output_shape(input_data.shape),\n        )\n\n    @parameterized.named_parameters(TEST_CASES)\n    def test_one_hot(self, sparse):\n        input_data = np.array([3, 2, 0, 1])\n        expected_output = np.array(\n            [\n                [0, 0, 0, 1, 0, 0],\n                [0, 0, 1, 0, 0, 0],\n                [1, 0, 0, 0, 0, 0],\n                [0, 1, 0, 0, 0, 0],\n            ]\n        )\n        num_tokens = 6\n        expected_output_shape = (input_data.shape[0], num_tokens)\n\n        # Test call on layer directly.\n        layer = layers.CategoryEncoding(\n            num_tokens=num_tokens, output_mode=\"one_hot\", sparse=sparse\n        )\n        output_data = layer(input_data)\n        self.assertAllClose(expected_output, output_data)\n        self.assertEqual(expected_output_shape, output_data.shape)\n        self.assertSparse(output_data, sparse)\n\n        # Test symbolic call.\n        output = layer(\n            layers.Input(batch_shape=input_data.shape, dtype=\"int32\")\n        )\n        self.assertEqual(expected_output_shape, output.shape)\n        self.assertEqual(\"float32\", output.dtype)\n        self.assertSparse(output, sparse)\n\n        # Test compute_output_shape\n        layer = layers.CategoryEncoding(\n            num_tokens=num_tokens, output_mode=\"one_hot\", sparse=sparse\n        )\n        self.assertEqual(\n            layer(input_data).shape,\n            layer.compute_output_shape(input_data.shape),\n        )\n\n        # Test compute_output_shape with 1 extra dimension\n        input_data = np.array([[3], [2], [0], [1]])\n        layer = layers.CategoryEncoding(\n            num_tokens=num_tokens, output_mode=\"one_hot\", sparse=sparse\n        )\n        self.assertEqual(\n            layer(input_data).shape,\n            layer.compute_output_shape(input_data.shape),\n        )\n\n        input_data = np.array((4,))\n        layer = layers.CategoryEncoding(\n            num_tokens=num_tokens, 
output_mode=\"one_hot\", sparse=sparse\n        )\n        self.assertEqual(\n            layer(input_data).shape,\n            layer.compute_output_shape(input_data.shape),\n        )\n\n    @parameterized.named_parameters(TEST_CASES)\n    def test_batched_one_hot(self, sparse):\n        input_data = np.array([[3, 2, 0, 1], [3, 2, 0, 1]])\n        expected_output = np.array(\n            [\n                [\n                    [0, 0, 0, 1, 0, 0],\n                    [0, 0, 1, 0, 0, 0],\n                    [1, 0, 0, 0, 0, 0],\n                    [0, 1, 0, 0, 0, 0],\n                ],\n                [\n                    [0, 0, 0, 1, 0, 0],\n                    [0, 0, 1, 0, 0, 0],\n                    [1, 0, 0, 0, 0, 0],\n                    [0, 1, 0, 0, 0, 0],\n                ],\n            ]\n        )\n        num_tokens = 6\n        expected_output_shape = input_data.shape[0:2] + (num_tokens,)\n\n        # Test call on layer directly.\n        layer = layers.CategoryEncoding(\n            num_tokens=num_tokens, output_mode=\"one_hot\", sparse=sparse\n        )\n        output_data = layer(input_data)\n        self.assertAllClose(expected_output, output_data)\n        self.assertEqual(expected_output_shape, output_data.shape)\n        self.assertSparse(output_data, sparse)\n\n        # Test symbolic call.\n        output = layer(\n            layers.Input(batch_shape=input_data.shape, dtype=\"int32\")\n        )\n        self.assertEqual(expected_output_shape, output.shape)\n        self.assertEqual(\"float32\", output.dtype)\n        self.assertSparse(output, sparse)\n\n    def test_tf_data_compatibility(self):\n        layer = layers.CategoryEncoding(\n            num_tokens=4, output_mode=\"one_hot\", dtype=\"int32\"\n        )\n        input_data = np.array([3, 2, 0, 1])\n        expected_output = np.array(\n            [\n                [0, 0, 0, 1],\n                [0, 0, 1, 0],\n                [1, 0, 0, 0],\n                [0, 1, 0, 0],\n    
        ]\n        )\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(4).map(layer)\n        for output in ds.take(1):\n            output = output.numpy()\n        self.assertAllClose(output, expected_output)\n\n    def test_category_encoding_without_num_tokens(self):\n        with self.assertRaisesRegex(\n            ValueError, r\"num_tokens must be set to use this layer\"\n        ):\n            layers.CategoryEncoding(output_mode=\"multi_hot\")\n\n    def test_category_encoding_with_invalid_num_tokens(self):\n        with self.assertRaisesRegex(ValueError, r\"`num_tokens` must be >= 1\"):\n            layers.CategoryEncoding(num_tokens=0, output_mode=\"multi_hot\")\n\n        with self.assertRaisesRegex(ValueError, r\"`num_tokens` must be >= 1\"):\n            layers.CategoryEncoding(num_tokens=-1, output_mode=\"multi_hot\")\n\n    def test_category_encoding_with_unnecessary_count_weights(self):\n        layer = layers.CategoryEncoding(num_tokens=4, output_mode=\"multi_hot\")\n        input_data = np.array([0, 1, 2, 3])\n        count_weights = np.array([0.1, 0.2, 0.3, 0.4])\n        with self.assertRaisesRegex(\n            ValueError, r\"`count_weights` is not used when `output_mode`\"\n        ):\n            layer(input_data, count_weights=count_weights)\n\n    def test_invalid_output_mode_raises_error(self):\n        with self.assertRaisesRegex(\n            ValueError, r\"Unknown arg for output_mode: invalid_mode\"\n        ):\n            layers.CategoryEncoding(num_tokens=4, output_mode=\"invalid_mode\")\n\n    def test_encode_one_hot_single_sample(self):\n        layer = layers.CategoryEncoding(num_tokens=4, output_mode=\"one_hot\")\n        input_array = np.array([1, 2, 3, 1])\n        expected_output = np.array(\n            [\n                [0, 1, 0, 0],\n                [0, 0, 1, 0],\n                [0, 0, 0, 1],\n                [0, 1, 0, 0],\n            ]\n        )\n        output = layer._encode(input_array)\n        
self.assertAllClose(expected_output, output)\n\n    def test_encode_one_hot_batched_samples(self):\n        layer = layers.CategoryEncoding(num_tokens=4, output_mode=\"one_hot\")\n        input_array = np.array([[3, 2, 0, 1], [3, 2, 0, 1]])\n        expected_output = np.array(\n            [\n                [[0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0]],\n                [[0, 0, 0, 1], [0, 0, 1, 0], [1, 0, 0, 0], [0, 1, 0, 0]],\n            ]\n        )\n        output = layer._encode(input_array)\n        self.assertAllClose(expected_output, output)\n\n    def test_count_single_sample(self):\n        layer = layers.CategoryEncoding(num_tokens=4, output_mode=\"count\")\n        input_array = np.array([1, 2, 3, 1])\n        expected_output = np.array([0, 2, 1, 1])\n        output = layer(input_array)\n        self.assertAllClose(expected_output, output)\n\n    def test_count_batched_samples(self):\n        layer = layers.CategoryEncoding(num_tokens=4, output_mode=\"count\")\n        input_array = np.array([[1, 2, 3, 1], [0, 3, 1, 0]])\n        expected_output = np.array([[0, 2, 1, 1], [2, 1, 0, 1]])\n        output = layer(input_array)\n        self.assertAllClose(expected_output, output)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/data_layer.py",
"content": "import keras.src.backend\nfrom keras.src import tree\nfrom keras.src.layers.layer import Layer\nfrom keras.src.random.seed_generator import SeedGenerator\nfrom keras.src.utils import backend_utils\nfrom keras.src.utils import jax_utils\nfrom keras.src.utils import tracking\n\n\nclass DataLayer(Layer):\n    \"\"\"Layer designed for safe use in a `tf.data` or `grain` pipeline.\n\n    This layer overrides the `__call__` method to ensure that the correct\n    backend is used and that computation is performed on the CPU.\n\n    The `call()` method in subclasses should use `self.backend` ops. If\n    randomness is needed, define both `seed` and `generator` in `__init__` and\n    retrieve the running seed using `self._get_seed_generator()`. If the layer\n    has weights in `__init__` or `build()`, use `convert_weight()` to ensure\n    they are in the correct backend.\n\n    **Note:** This layer and its subclasses only support a single input tensor.\n\n    Examples:\n\n    **Custom `DataLayer` subclass:**\n\n    ```python\n    from keras.src import ops\n    from keras.src.layers.preprocessing.data_layer import DataLayer\n    from keras.src.random import SeedGenerator\n\n\n    class BiasedRandomRGBToHSVLayer(DataLayer):\n        def __init__(self, seed=None, **kwargs):\n            super().__init__(**kwargs)\n            self.probability_bias = ops.convert_to_tensor(0.01)\n            self.seed = seed\n            self.generator = SeedGenerator(seed)\n\n        def call(self, inputs):\n            images_shape = self.backend.shape(inputs)\n            batch_size = 1 if len(images_shape) == 3 else images_shape[0]\n            seed = self._get_seed_generator(self.backend._backend)\n\n            probability = self.backend.random.uniform(\n                shape=(batch_size,),\n                minval=0.0,\n                maxval=1.0,\n                seed=seed,\n            )\n            probability = self.backend.numpy.add(\n                probability, 
self.convert_weight(self.probability_bias)\n            )\n            hsv_images = self.backend.image.rgb_to_hsv(inputs)\n            return self.backend.numpy.where(\n                probability[:, None, None, None] > 0.5,\n                hsv_images,\n                inputs,\n            )\n\n        def compute_output_shape(self, input_shape):\n            return input_shape\n    ```\n\n    **Using as a regular Keras layer:**\n\n    ```python\n    import numpy as np\n\n    x = np.random.uniform(size=(1, 16, 16, 3)).astype(\"float32\")\n    print(BiasedRandomRGBToHSVLayer()(x).shape)  # (1, 16, 16, 3)\n    ```\n\n    **Using in a `tf.data` pipeline:**\n\n    ```python\n    import tensorflow as tf\n\n    tf_ds = tf.data.Dataset.from_tensors(x)\n    tf_ds = tf_ds.map(BiasedRandomRGBToHSVLayer())\n    print([x.shape for x in tf_ds])  # [(1, 16, 16, 3)]\n    ```\n\n    **Using in a `grain` pipeline:**\n\n    ```python\n    import grain\n\n    grain_ds = grain.MapDataset.source([x])\n    grain_ds = grain_ds.map(BiasedRandomRGBToHSVLayer())\n    print([x.shape for x in grain_ds])  # [(1, 16, 16, 3)]\n    ```\n    \"\"\"\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.backend = backend_utils.DynamicBackend()\n        self._allow_non_tensor_positional_args = True\n\n    def __call__(self, inputs, **kwargs):\n        sample_input = tree.flatten(inputs)[0]\n        if (\n            not isinstance(sample_input, keras.KerasTensor)\n            and backend_utils.in_tf_graph()\n            and not jax_utils.is_in_jax_tracing_scope(sample_input)\n        ):\n            # We're in a TF graph, e.g. 
a tf.data pipeline.\n            self.backend.set_backend(\"tensorflow\")\n            inputs = tree.map_structure(\n                lambda x: self.backend.convert_to_tensor(\n                    x, dtype=self.compute_dtype\n                ),\n                inputs,\n            )\n            switch_convert_input_args = False\n            if self._convert_input_args:\n                self._convert_input_args = False\n                switch_convert_input_args = True\n            try:\n                outputs = super().__call__(inputs, **kwargs)\n            finally:\n                self.backend.reset()\n                if switch_convert_input_args:\n                    self._convert_input_args = True\n            return outputs\n        elif (\n            not isinstance(sample_input, keras.KerasTensor)\n            and backend_utils.in_grain_data_pipeline()\n        ):\n            # We're in a Grain data pipeline. Force computation and data\n            # placement to CPU.\n            with keras.src.backend.device_scope(\"cpu\"):\n                return super().__call__(inputs, **kwargs)\n        else:\n            return super().__call__(inputs, **kwargs)\n\n    @tracking.no_automatic_dependency_tracking\n    def _get_seed_generator(self, backend=None):\n        if not hasattr(self, \"seed\") or not hasattr(self, \"generator\"):\n            raise ValueError(\n                \"The `seed` and `generator` variables must be set in the \"\n                \"`__init__` method before calling `_get_seed_generator()`.\"\n            )\n        if backend is None or backend == keras.backend.backend():\n            return self.generator\n        if not hasattr(self, \"_backend_generators\"):\n            self._backend_generators = {}\n        if backend in self._backend_generators:\n            return self._backend_generators[backend]\n        seed_generator = SeedGenerator(self.seed, backend=self.backend)\n        self._backend_generators[backend] = seed_generator\n  
      return seed_generator\n\n    def convert_weight(self, weight):\n        \"\"\"Convert the weight if it is from a different backend.\"\"\"\n        if self.backend.name == keras.backend.backend():\n            return weight\n        else:\n            weight = keras.ops.convert_to_numpy(weight)\n            return self.backend.convert_to_tensor(weight)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/data_layer_test.py",
    "content": "import grain\nimport numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.layers.preprocessing.data_layer import DataLayer\nfrom keras.src.random import SeedGenerator\n\n\nclass RandomRGBToHSVLayer(DataLayer):\n    def __init__(self, data_format=None, seed=None, **kwargs):\n        super().__init__(**kwargs)\n        self.data_format = backend.standardize_data_format(data_format)\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n    def call(self, inputs):\n        images_shape = self.backend.shape(inputs)\n        batch_size = 1 if len(images_shape) == 3 else images_shape[0]\n        seed = self._get_seed_generator(self.backend._backend)\n\n        probability = self.backend.random.uniform(\n            shape=(batch_size,),\n            minval=0.0,\n            maxval=1.0,\n            seed=seed,\n        )\n        hsv_images = self.backend.image.rgb_to_hsv(\n            inputs, data_format=self.data_format\n        )\n        return self.backend.numpy.where(\n            probability[:, None, None, None] > 0.5, hsv_images, inputs\n        )\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n\nclass DataLayerTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            RandomRGBToHSVLayer,\n            init_kwargs={\n                \"seed\": 1337,\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(1, 2, 2, 3),\n            supports_masking=False,\n            expected_output_shape=(1, 2, 2, 3),\n        )\n\n        self.run_layer_test(\n            RandomRGBToHSVLayer,\n            init_kwargs={\n                \"seed\": 1337,\n                \"data_format\": \"channels_first\",\n            },\n            input_shape=(1, 3, 2, 2),\n            supports_masking=False,\n        
    expected_output_shape=(1, 3, 2, 2),\n        )\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3)).astype(\"float32\")\n        else:\n            input_data = np.random.random((2, 3, 8, 8)).astype(\"float32\")\n        layer = RandomRGBToHSVLayer(data_format=data_format, seed=1337)\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            self.assertDType(output, \"float32\")\n            self.assertEqual(list(output.shape), list(input_data.shape))\n\n    def test_grain_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3)).astype(\"float32\")\n        else:\n            input_data = np.random.random((2, 3, 8, 8)).astype(\"float32\")\n        layer = RandomRGBToHSVLayer(data_format=data_format, seed=1337)\n\n        ds = grain.MapDataset.source(input_data).batch(2).map(layer)\n        for output in ds[:1]:\n            self.assertDType(output, \"float32\")\n            self.assertEqual(list(output.shape), list(input_data.shape))\n"
  },
  {
    "path": "keras/src/layers/preprocessing/discretization.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.data_layer import DataLayer\nfrom keras.src.utils import argument_validation\nfrom keras.src.utils import numerical_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\n@keras_export(\"keras.layers.Discretization\")\nclass Discretization(DataLayer):\n    \"\"\"A preprocessing layer which buckets continuous features by ranges.\n\n    This layer will place each element of its input data into one of several\n    contiguous ranges and output an integer index indicating which range each\n    element was placed in.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Input shape:\n        Any array of dimension 2 or higher.\n\n    Output shape:\n        Same as input shape.\n\n    Arguments:\n        bin_boundaries: A list of bin boundaries.\n            The leftmost and rightmost bins\n            will always extend to `-inf` and `inf`,\n            so `bin_boundaries=[0., 1., 2.]`\n            generates bins `(-inf, 0.)`, `[0., 1.)`, `[1., 2.)`,\n            and `[2., +inf)`.\n            If this option is set, `adapt()` should not be called.\n        num_bins: The integer number of bins to compute.\n            If this option is set, `bin_boundaries` should not be set and\n            `adapt()` should be called to learn the bin boundaries.\n        epsilon: Error tolerance, typically a small fraction\n            close to zero (e.g. 0.01). 
Higher values of epsilon increase\n            the quantile approximation error, and hence result in more\n            unequal buckets, but can improve performance\n            and reduce resource consumption.\n        output_mode: Specification for the output of the layer.\n            Values can be `\"int\"`, `\"one_hot\"`, `\"multi_hot\"`, or\n            `\"count\"` configuring the layer as follows:\n            - `\"int\"`: Return the discretized bin indices directly.\n            - `\"one_hot\"`: Encodes each individual element in the\n                input into an array the same size as `num_bins`,\n                containing a 1 at the input's bin\n                index. If the last dimension is size 1, will encode on that\n                dimension.  If the last dimension is not size 1,\n                will append a new dimension for the encoded output.\n            - `\"multi_hot\"`: Encodes each sample in the input into a\n                single array the same size as `num_bins`,\n                containing a 1 for each bin\n                index present in the sample.\n                Treats the last dimension as the sample\n                dimension; if input shape is `(..., sample_length)`,\n                output shape will be `(..., num_bins)`.\n            - `\"count\"`: As `\"multi_hot\"`, but the int array contains\n                a count of the number of times the bin index appeared\n                in the sample.\n            Defaults to `\"int\"`.\n        sparse: Boolean. Only applicable to `\"one_hot\"`, `\"multi_hot\"`,\n            and `\"count\"` output modes. Only supported with TensorFlow\n            backend. If `True`, returns a `SparseTensor` instead of\n            a dense `Tensor`. 
Defaults to `False`.\n\n    Examples:\n\n    Discretize float values based on provided buckets.\n    >>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])\n    >>> layer = Discretization(bin_boundaries=[0., 1., 2.])\n    >>> layer(input)\n    array([[0, 2, 3, 1],\n           [1, 3, 2, 1]])\n\n    Discretize float values based on a number of buckets to compute.\n    >>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])\n    >>> layer = Discretization(num_bins=4, epsilon=0.01)\n    >>> layer.adapt(input)\n    >>> layer(input)\n    array([[0, 2, 3, 2],\n           [1, 3, 3, 1]])\n    \"\"\"\n\n    def __init__(\n        self,\n        bin_boundaries=None,\n        num_bins=None,\n        epsilon=0.01,\n        output_mode=\"int\",\n        sparse=False,\n        dtype=None,\n        name=None,\n    ):\n        super().__init__(name=name, dtype=dtype)\n\n        if sparse and not backend.SUPPORTS_SPARSE_TENSORS:\n            raise ValueError(\n                f\"`sparse=True` cannot be used with backend {backend.backend()}\"\n            )\n        if sparse and output_mode == \"int\":\n            raise ValueError(\n                \"`sparse=True` may only be used if `output_mode` is \"\n                \"`'one_hot'`, `'multi_hot'`, or `'count'`. \"\n                f\"Received: sparse={sparse} and \"\n                f\"output_mode={output_mode}\"\n            )\n\n        argument_validation.validate_string_arg(\n            output_mode,\n            allowable_strings=(\n                \"int\",\n                \"one_hot\",\n                \"multi_hot\",\n                \"count\",\n            ),\n            caller_name=self.__class__.__name__,\n            arg_name=\"output_mode\",\n        )\n\n        if num_bins is not None and num_bins < 0:\n            raise ValueError(\n                \"`num_bins` must be greater than or equal to 0. 
\"\n                f\"Received: `num_bins={num_bins}`\"\n            )\n        if num_bins is not None and bin_boundaries is not None:\n            raise ValueError(\n                \"Both `num_bins` and `bin_boundaries` should not be set. \"\n                f\"Received: `num_bins={num_bins}` and \"\n                f\"`bin_boundaries={bin_boundaries}`\"\n            )\n        if num_bins is None and bin_boundaries is None:\n            raise ValueError(\n                \"You need to set either `num_bins` or `bin_boundaries`.\"\n            )\n\n        self.bin_boundaries = bin_boundaries\n        self.num_bins = num_bins\n        self.epsilon = epsilon\n        self.output_mode = output_mode\n        self.sparse = sparse\n\n        if self.bin_boundaries:\n            self.summary = None\n        else:\n            self.summary = np.array([[], []], dtype=\"float32\")\n\n    @property\n    def input_dtype(self):\n        return backend.floatx()\n\n    @property\n    def output_dtype(self):\n        return self.compute_dtype if self.output_mode != \"int\" else \"int32\"\n\n    def adapt(self, data, steps=None):\n        \"\"\"Computes bin boundaries from quantiles in an input dataset.\n\n        Calling `adapt()` on a `Discretization` layer is an alternative to\n        passing in a `bin_boundaries` argument during construction. A\n        `Discretization` layer should always be either adapted over a dataset or\n        passed `bin_boundaries`.\n\n        During `adapt()`, the layer will estimate the quantile boundaries of the\n        input dataset. The number of quantiles can be controlled via the\n        `num_bins` argument, and the error tolerance for quantile boundaries can\n        be controlled via the `epsilon` argument.\n\n        Arguments:\n            data: The data to train on. 
It can be passed either as a\n                batched `tf.data.Dataset`,\n                or as a NumPy array.\n            steps: Integer or `None`.\n                Total number of steps (batches of samples) to process.\n                If `data` is a `tf.data.Dataset`, and `steps` is `None`,\n                `adapt()` will run until the input dataset is exhausted.\n                When passing an infinitely\n                repeating dataset, you must specify the `steps` argument. This\n                argument is not supported with array inputs or list inputs.\n        \"\"\"\n        if self.num_bins is None:\n            raise ValueError(\n                \"Cannot adapt a Discretization layer that has been initialized \"\n                \"with `bin_boundaries`, use `num_bins` instead.\"\n            )\n        self.reset_state()\n        if isinstance(data, tf.data.Dataset):\n            if steps is not None:\n                data = data.take(steps)\n            for batch in data:\n                self.update_state(batch)\n        else:\n            self.update_state(data)\n        self.finalize_state()\n\n    def update_state(self, data):\n        data = np.array(data).astype(\"float32\")\n        summary = summarize(data, self.epsilon)\n        self.summary = merge_summaries(summary, self.summary, self.epsilon)\n\n    def finalize_state(self):\n        if self.num_bins is None:\n            return\n        self.bin_boundaries = get_bin_boundaries(\n            self.summary, self.num_bins\n        ).tolist()\n\n    def reset_state(self):\n        if self.num_bins is None:\n            return\n        self.summary = np.array([[], []], dtype=\"float32\")\n\n    def compute_output_shape(self, input_shape):\n        if self.output_mode == \"int\":\n            return input_shape\n\n        # Calculate depth (number of bins)\n        depth = (\n            len(self.bin_boundaries) + 1\n            if self.bin_boundaries is not None\n            else 
self.num_bins\n        )\n\n        if self.output_mode == \"one_hot\":\n            # For one_hot mode, add depth dimension\n            # If last dimension is 1, replace it with depth, otherwise append\n            if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:\n                return tuple(input_shape[:-1]) + (depth,)\n            else:\n                return tuple(input_shape) + (depth,)\n        else:\n            if input_shape and len(input_shape) >= 2:\n                # Match to eager tensor, remove second and append depth\n                out_shape = (\n                    (input_shape[0],) + tuple(input_shape[2:]) + (depth,)\n                )\n                return out_shape\n            else:\n                return (depth,)\n\n    def compute_output_spec(self, inputs):\n        output_shape = self.compute_output_shape(inputs.shape)\n        return backend.KerasTensor(shape=output_shape, dtype=self.output_dtype)\n\n    def load_own_variables(self, store):\n        if len(store) == 1:\n            # Legacy format case\n            self.summary = store[\"0\"]\n        return\n\n    def call(self, inputs):\n        if self.bin_boundaries is None:\n            raise ValueError(\n                \"You need to either pass the `bin_boundaries` argument at \"\n                \"construction time or call `adapt(dataset)` before you can \"\n                \"start using the `Discretization` layer.\"\n            )\n\n        indices = self.backend.numpy.digitize(inputs, self.bin_boundaries)\n        return numerical_utils.encode_categorical_inputs(\n            indices,\n            output_mode=self.output_mode,\n            depth=len(self.bin_boundaries) + 1,\n            dtype=self.output_dtype,\n            sparse=self.sparse,\n            backend_module=self.backend,\n        )\n\n    def get_config(self):\n        return {\n            \"bin_boundaries\": self.bin_boundaries,\n            \"num_bins\": self.num_bins,\n            
\"epsilon\": self.epsilon,\n            \"output_mode\": self.output_mode,\n            \"sparse\": self.sparse,\n            \"name\": self.name,\n            \"dtype\": self.dtype,\n        }\n\n    @classmethod\n    def from_config(cls, config, custom_objects=None):\n        if (\n            config.get(\"bin_boundaries\", None) is not None\n            and config.get(\"num_bins\", None) is not None\n        ):\n            # After `adapt` was called, both `bin_boundaries` and `num_bins` are\n            # populated, but `__init__` won't let us create a new layer with\n            # both `bin_boundaries` and `num_bins`. We therefore apply\n            # `bin_boundaries` after creation.\n            config = config.copy()\n            bin_boundaries = config.pop(\"bin_boundaries\")\n            discretization = cls(**config)\n            discretization.bin_boundaries = bin_boundaries\n            return discretization\n        return cls(**config)\n\n\ndef summarize(values, epsilon):\n    \"\"\"Reduce a 1D sequence of values to a summary.\n\n    This algorithm is based on numpy.quantile but modified to allow for\n    intermediate steps between multiple data sets. It first finds the target\n    number of bins as the reciprocal of epsilon and then takes the individual\n    values spaced at appropriate intervals to arrive at that target.\n    The final step is to return the corresponding counts between those values.\n    If the target num_bins is larger than the size of values, the whole array is\n    returned (with weights of 1).\n\n    Args:\n        values: 1D `np.ndarray` to be summarized.\n        epsilon: A `'float32'` that determines the approximate desired\n        precision.\n\n    Returns:\n        A 2D `np.ndarray` that is a summary of the inputs. 
First column is the\n        interpolated partition values, the second is the weights (counts).\n    \"\"\"\n    values = np.reshape(values, [-1])\n    values = np.sort(values)\n    elements = np.size(values)\n    num_buckets = 1.0 / epsilon\n    increment = elements / num_buckets\n    start = increment\n    step = max(increment, 1)\n    boundaries = values[int(start) :: int(step)]\n    weights = np.ones_like(boundaries)\n    weights = weights * step\n    return np.stack([boundaries, weights])\n\n\ndef merge_summaries(prev_summary, next_summary, epsilon):\n    \"\"\"Weighted merge sort of summaries.\n\n    Given two summaries of distinct data, this function merges (and compresses)\n    them to stay within `epsilon` error tolerance.\n\n    Args:\n        prev_summary: 2D `np.ndarray` summary to be merged with `next_summary`.\n        next_summary: 2D `np.ndarray` summary to be merged with `prev_summary`.\n        epsilon: A float that determines the approximate desired precision.\n\n    Returns:\n        A 2-D `np.ndarray` that is a merged summary. First column is the\n        interpolated partition values, the second is the weights (counts).\n    \"\"\"\n    merged = np.concatenate((prev_summary, next_summary), axis=1)\n    merged = np.take(merged, np.argsort(merged[0]), axis=1)\n    return compress_summary(merged, epsilon)\n\n\ndef get_bin_boundaries(summary, num_bins):\n    return compress_summary(summary, 1.0 / num_bins)[0, :-1]\n\n\ndef compress_summary(summary, epsilon):\n    \"\"\"Compress a summary to within `epsilon` accuracy.\n\n    The compression step is needed to keep the summary sizes small after\n    merging, and also used to return the final target boundaries. It finds the\n    new bins based on interpolating cumulative weight percentages from the large\n    summary.  
Taking the difference of the cumulative weights from the previous\n    bin's cumulative weight will give the new weight for that bin.\n\n    Args:\n        summary: 2D `np.ndarray` summary to be compressed.\n        epsilon: A `'float32'` that determines the approximate desired\n        precision.\n\n    Returns:\n        A 2D `np.ndarray` that is a compressed summary. First column is the\n        interpolated partition values, the second is the weights (counts).\n    \"\"\"\n    if summary.shape[1] * epsilon < 1:\n        return summary\n\n    percents = epsilon + np.arange(0.0, 1.0, epsilon)\n    cum_weights = summary[1].cumsum()\n    cum_weight_percents = cum_weights / cum_weights[-1]\n    new_bins = np.interp(percents, cum_weight_percents, summary[0])\n    cum_weights = np.interp(percents, cum_weight_percents, cum_weights)\n    new_weights = cum_weights - np.concatenate(\n        (np.array([0]), cum_weights[:-1])\n    )\n    summary = np.stack((new_bins, new_weights))\n    return summary.astype(\"float32\")\n"
  },
  {
    "path": "keras/src/layers/preprocessing/discretization_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src.saving import saving_api\nfrom keras.src.testing.test_utils import named_product\n\n\nclass DiscretizationTest(testing.TestCase):\n    def test_discretization_basics(self):\n        self.run_layer_test(\n            layers.Discretization,\n            init_kwargs={\n                \"bin_boundaries\": [0.0, 0.5, 1.0],\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    def test_adapt_flow(self):\n        layer = layers.Discretization(num_bins=4)\n        layer.adapt(\n            np.random.random((32, 3)),\n        )\n        output = layer(np.array([[0.0, 0.1, 0.3]]))\n        self.assertTrue(output.dtype, \"int32\")\n\n    @parameterized.named_parameters(\n        named_product(\n            [\n                {\n                    \"testcase_name\": \"int\",\n                    \"output_mode\": \"int\",\n                    \"input_array\": [[-1.0, 0.0, 0.1, 0.8, 1.2]],\n                    \"expected_output\": [[0, 1, 1, 2, 3]],\n                },\n                {\n                    \"testcase_name\": \"one_hot_rank_1\",\n                    \"output_mode\": \"one_hot\",\n                    \"input_array\": [0.1, 0.8],\n                    \"expected_output\": [[0, 1, 0, 0], [0, 0, 1, 0]],\n                },\n                {\n                    \"testcase_name\": \"multi_hot_rank_2\",\n                    \"output_mode\": \"multi_hot\",\n                    
\"input_array\": [[0.1, 0.8]],\n                    \"expected_output\": [[0, 1, 1, 0]],\n                },\n                {\n                    \"testcase_name\": \"one_hot_rank_3\",\n                    \"output_mode\": \"one_hot\",\n                    \"input_array\": [[[0.15, 0.75], [0.85, 0.45]]],\n                    \"expected_output\": [\n                        [\n                            [[0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]],\n                            [[0.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 0.0]],\n                        ]\n                    ],\n                },\n                {\n                    \"testcase_name\": \"multi_hot_rank_3\",\n                    \"output_mode\": \"multi_hot\",\n                    \"input_array\": [[[0.15, 0.75], [0.85, 0.45]]],\n                    \"expected_output\": [\n                        [[0.0, 1.0, 1.0, 0.0], [0.0, 1.0, 1.0, 0.0]]\n                    ],\n                },\n                {\n                    \"testcase_name\": \"count\",\n                    \"output_mode\": \"count\",\n                    \"input_array\": [[0.1, 0.8, 0.9]],\n                    \"expected_output\": [[0, 1, 2, 0]],\n                },\n            ],\n            sparse=(\n                [True, False] if backend.SUPPORTS_SPARSE_TENSORS else [False]\n            ),\n        )\n    )\n    def test_correctness(\n        self, output_mode, input_array, expected_output, sparse\n    ):\n        if output_mode == \"int\" and sparse:\n            pytest.skip(\"sparse=True cannot be combined with output_mode=int\")\n\n        input_array = np.array(input_array)\n        expected_output = np.array(expected_output)\n\n        layer = layers.Discretization(\n            bin_boundaries=[0.0, 0.5, 1.0],\n            output_mode=output_mode,\n            sparse=sparse,\n        )\n        output = layer(input_array)\n        self.assertSparse(output, sparse)\n        self.assertTrue(backend.is_tensor(output))\n      
  self.assertAllClose(output, expected_output)\n\n    def test_tf_data_compatibility(self):\n        # With fixed bins\n        layer = layers.Discretization(\n            bin_boundaries=[0.0, 0.35, 0.5, 1.0], dtype=\"float32\"\n        )\n        x = np.array([[-1.0, 0.0, 0.1, 0.2, 0.4, 0.5, 1.0, 1.2, 0.98]])\n        self.assertAllClose(layer(x), np.array([[0, 1, 1, 1, 2, 3, 4, 4, 3]]))\n        ds = tf_data.Dataset.from_tensor_slices(x).batch(1).map(layer)\n        for output in ds.take(1):\n            output = output.numpy()\n        self.assertAllClose(output, np.array([[0, 1, 1, 1, 2, 3, 4, 4, 3]]))\n\n        # With adapt flow\n        layer = layers.Discretization(num_bins=4)\n        layer.adapt(\n            np.random.random((32, 3)),\n        )\n        x = np.array([[0.0, 0.1, 0.3]])\n        ds = tf_data.Dataset.from_tensor_slices(x).batch(1).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n\n    def test_serialization(self):\n        layer = layers.Discretization(num_bins=5)\n\n        # Serialization before `adapt` is called.\n        config = layer.get_config()\n        revived_layer = layers.Discretization.from_config(config)\n        self.assertEqual(config, revived_layer.get_config())\n\n        # Serialization after `adapt` is called but `num_bins` was not reached.\n        layer.adapt(np.array([0.0, 1.0, 5.0]))\n        config = layer.get_config()\n        revived_layer = layers.Discretization.from_config(config)\n        self.assertEqual(config, revived_layer.get_config())\n\n        # Serialization after `adapt` is called and `num_bins` is reached.\n        layer.adapt(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]))\n        config = layer.get_config()\n        revived_layer = layers.Discretization.from_config(config)\n        self.assertEqual(config, revived_layer.get_config())\n\n        # Serialization with `bin_boundaries`.\n        layer = layers.Discretization(bin_boundaries=[0.0, 0.35, 0.5, 1.0])\n        
config = layer.get_config()\n        revived_layer = layers.Discretization.from_config(config)\n        self.assertEqual(config, revived_layer.get_config())\n\n    def test_saving(self):\n        # With fixed bins\n        layer = layers.Discretization(bin_boundaries=[0.0, 0.35, 0.5, 1.0])\n        model = models.Sequential(\n            [\n                layers.Input((2,)),\n                layer,\n            ]\n        )\n        fpath = os.path.join(self.get_temp_dir(), \"model.keras\")\n        model.save(fpath)\n        model = saving_api.load_model(fpath)\n        x = np.array([[-1.0, 0.0, 0.1, 0.2, 0.4, 0.5, 1.0, 1.2, 0.98]])\n        self.assertAllClose(layer(x), np.array([[0, 1, 1, 1, 2, 3, 4, 4, 3]]))\n\n        # With adapt flow\n        layer = layers.Discretization(num_bins=4)\n        layer.adapt(\n            np.random.random((32, 3)),\n        )\n        ref_input = np.random.random((1, 2))\n        ref_output = layer(ref_input)\n        model = models.Sequential(\n            [\n                layers.Input((2,)),\n                layer,\n            ]\n        )\n        fpath = os.path.join(self.get_temp_dir(), \"model.keras\")\n        model.save(fpath)\n        model = saving_api.load_model(fpath)\n        self.assertAllClose(layer(ref_input), ref_output)\n\n    def test_init_num_bins_and_bin_boundaries_raises(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Both `num_bins` and `bin_boundaries`\"\n        ):\n            layers.Discretization(num_bins=3, bin_boundaries=[0.0, 1.0])\n\n        with self.assertRaisesRegex(\n            ValueError, \"either `num_bins` or `bin_boundaries`\"\n        ):\n            layers.Discretization()\n\n    def test_call_before_adapt_raises(self):\n        layer = layers.Discretization(num_bins=3)\n        with self.assertRaisesRegex(ValueError, \"You need .* call .*adapt\"):\n            layer([[0.1, 0.8, 0.9]])\n\n    def test_model_call_vs_predict_consistency(self):\n        
\"\"\"Test that model(input) and model.predict(input) produce consistent\n        outputs.\"\"\"\n        # Test with int output mode\n        layer = layers.Discretization(\n            bin_boundaries=[-0.5, 0, 0.1, 0.2, 3],\n            output_mode=\"int\",\n        )\n        x = np.array([[0.0, 0.15, 0.21, 0.3], [0.0, 0.17, 0.451, 7.8]])\n\n        # Create model\n        inputs = layers.Input(shape=(4,), dtype=\"float32\")\n        outputs = layer(inputs)\n        model = models.Model(inputs=inputs, outputs=outputs)\n\n        # Test both execution modes\n        model_call_output = model(x)\n        predict_output = model.predict(x)\n\n        # Check consistency\n        self.assertAllClose(model_call_output, predict_output)\n        self.assertEqual(\n            backend.standardize_dtype(model_call_output.dtype),\n            backend.standardize_dtype(predict_output.dtype),\n        )\n        self.assertTrue(backend.is_int_dtype(model_call_output.dtype))\n\n    @parameterized.named_parameters(\n        named_product(\n            [\n                {\n                    \"testcase_name\": \"int_mode\",\n                    \"output_mode\": \"int\",\n                    \"input_shape\": (3, 4),\n                    \"expected_shape\": (None, 3, 4),  # int mode - no change\n                },\n                {\n                    \"testcase_name\": \"one_hot_mode\",\n                    \"output_mode\": \"one_hot\",\n                    \"input_shape\": (3, 4),\n                    \"expected_shape\": (None, 3, 4, 5),  # one_hot - add dim\n                },\n                {\n                    \"testcase_name\": \"multi_hot_mode\",\n                    \"output_mode\": \"multi_hot\",\n                    \"input_shape\": (3, 4),\n                    \"expected_shape\": (None, 4, 5),  # multi_hot - replace\n                },\n                {\n                    \"testcase_name\": \"count_mode\",\n                    \"output_mode\": \"count\",\n   
                 \"input_shape\": (3, 4),\n                    \"expected_shape\": (None, 4, 5),  # count - replace\n                },\n            ]\n        )\n    )\n    def test_symbolic_tensor_output_shape(\n        self, output_mode, input_shape, expected_shape\n    ):\n        \"\"\"Test symbolic tensors have correct output shape for modes.\"\"\"\n        # Create layer with bin_boundaries that create 5 bins (4 boundaries)\n        layer = layers.Discretization(\n            bin_boundaries=[0.0, 1.0, 2.0, 5.0], output_mode=output_mode\n        )\n\n        # Create symbolic input and get output\n        symbolic_input = layers.Input(shape=input_shape)\n        symbolic_output = layer(symbolic_input)\n\n        # Verify symbolic output shape\n        self.assertEqual(symbolic_output.shape, expected_shape)\n\n        eager_input = np.random.uniform(0, 3, size=(2,) + input_shape)\n        eager_output = layer(eager_input)\n\n        # Verify batch dimension is preserved correctly\n        self.assertEqual(eager_output.shape[0], 2)  # Batch size preserved\n        self.assertEqual(\n            symbolic_output.shape[0], None\n        )  # Batch is None for symbolic\n\n        # Verify non-batch dimensions are identical\n        self.assertEqual(eager_output.shape[1:], symbolic_output.shape[1:])\n\n        # Verify total number of dimensions is the same\n        self.assertEqual(len(eager_output.shape), len(symbolic_output.shape))\n\n    @parameterized.named_parameters(\n        [\n            {\n                \"testcase_name\": \"int_mode\",\n                \"output_mode\": \"int\",\n                \"input_shape\": (None, 3, 4),\n                \"expected_shape\": (None, 3, 4),  # int mode - no change\n            },\n            {\n                \"testcase_name\": \"one_hot_mode\",\n                \"output_mode\": \"one_hot\",\n                \"input_shape\": (None, 3, 4),\n                \"expected_shape\": (None, 3, 4, 3),  # one_hot - add dim\n    
        },\n            {\n                \"testcase_name\": \"multi_hot_mode\",\n                \"output_mode\": \"multi_hot\",\n                \"input_shape\": (None, 3, 4),\n                \"expected_shape\": (None, 4, 3),  # multi_hot - replace\n            },\n            {\n                \"testcase_name\": \"count_mode\",\n                \"output_mode\": \"count\",\n                \"input_shape\": (None, 3, 4),\n                \"expected_shape\": (None, 4, 3),  # count - replace\n            },\n        ]\n    )\n    def test_compute_output_shape_modes(\n        self, output_mode, input_shape, expected_shape\n    ):\n        \"\"\"Test compute_output_shape with different output modes.\"\"\"\n        layer = layers.Discretization(\n            bin_boundaries=[0.0, 1.0], output_mode=output_mode\n        )\n        result_shape = layer.compute_output_shape(input_shape)\n        self.assertEqual(result_shape, expected_shape)\n\n    def test_compute_output_shape_edge_cases(self):\n        \"\"\"Test edge cases in compute_output_shape to improve coverage.\"\"\"\n\n        # Test edge case - last dimension is 1 with one_hot\n        layer_one_hot = layers.Discretization(\n            bin_boundaries=[0.0, 1.0], output_mode=\"one_hot\"\n        )\n\n        # Should replace last dimension of 1 with depth\n        shape = layer_one_hot.compute_output_shape((None, 5, 1))\n        expected = (None, 5, 3)  # 2 boundaries = 3 bins, replace last dim\n        self.assertEqual(shape, expected)\n\n        # Test empty input shape\n        shape = layer_one_hot.compute_output_shape(())\n        expected = (3,)  # Just depth\n        self.assertEqual(shape, expected)\n\n    def test_compute_output_spec_method(self):\n        \"\"\"Test compute_output_spec method directly.\"\"\"\n\n        layer = layers.Discretization(\n            bin_boundaries=[0.0, 1.0, 2.0], output_mode=\"one_hot\"\n        )\n\n        # Create a KerasTensor input\n        input_tensor = 
backend.KerasTensor(shape=(None, 3, 4), dtype=\"float32\")\n\n        # Test compute_output_spec\n        output_spec = layer.compute_output_spec(input_tensor)\n\n        # Verify shape and dtype\n        expected_shape = (None, 3, 4, 4)  # 3 boundaries = 4 bins\n        self.assertEqual(output_spec.shape, expected_shape)\n        self.assertEqual(output_spec.dtype, layer.output_dtype)\n\n    @parameterized.named_parameters(named_product(batch_size=[1, 3, 5, 10]))\n    def test_batch_dimension_consistency(self, batch_size):\n        \"\"\"Test that batch dimensions are handled consistently.\"\"\"\n\n        layer = layers.Discretization(\n            bin_boundaries=[0.0, 1.0, 2.0], output_mode=\"one_hot\"\n        )\n\n        # Test different batch sizes\n        input_shape = (4, 3)\n\n        # Create eager input with specific batch size\n        eager_input = np.random.uniform(0, 3, size=(batch_size,) + input_shape)\n        eager_output = layer(eager_input)\n\n        # Create symbolic input\n        symbolic_input = layers.Input(shape=input_shape)\n        symbolic_output = layer(symbolic_input)\n\n        # Verify batch dimension handling\n        self.assertEqual(\n            eager_output.shape[0],\n            batch_size,\n            f\"Eager batch size should be {batch_size}\",\n        )\n        self.assertEqual(\n            symbolic_output.shape[0],\n            None,\n            \"Symbolic batch size should be None\",\n        )\n\n        # Verify non-batch dimensions are identical\n        self.assertEqual(\n            eager_output.shape[1:],\n            symbolic_output.shape[1:],\n            \"Non-batch dimensions should be identical\",\n        )\n\n        # Verify expected output shape\n        expected_shape = input_shape + (4,)  # 3 boundaries = 4 bins\n        self.assertEqual(eager_output.shape[1:], expected_shape)\n        self.assertEqual(symbolic_output.shape[1:], expected_shape)\n\n    @parameterized.named_parameters(\n        
named_product(\n            [\n                {\n                    \"testcase_name\": \"int_mode_3_bins\",\n                    \"num_bins\": 3,\n                    \"output_mode\": \"int\",\n                    \"input_shape\": (None, 5, 4),\n                    \"expected_shape\": (None, 5, 4),  # int mode - no change\n                },\n                {\n                    \"testcase_name\": \"one_hot_mode_4_bins\",\n                    \"num_bins\": 4,\n                    \"output_mode\": \"one_hot\",\n                    \"input_shape\": (None, 3, 2),\n                    \"expected_shape\": (\n                        None,\n                        3,\n                        2,\n                        4,\n                    ),  # one_hot - add depth dimension\n                },\n                {\n                    \"testcase_name\": \"multi_hot_mode_5_bins\",\n                    \"num_bins\": 5,\n                    \"output_mode\": \"multi_hot\",\n                    \"input_shape\": (None, 6, 3),\n                    \"expected_shape\": (\n                        None,\n                        3,\n                        5,\n                    ),  # multi_hot - replace last dimension\n                },\n                {\n                    \"testcase_name\": \"count_mode_6_bins\",\n                    \"num_bins\": 6,\n                    \"output_mode\": \"count\",\n                    \"input_shape\": (None, 4, 2),\n                    \"expected_shape\": (\n                        None,\n                        2,\n                        6,\n                    ),  # count - replace last dimension\n                },\n            ]\n        )\n    )\n    def test_compute_output_shape_num_bins(\n        self, num_bins, output_mode, input_shape, expected_shape\n    ):\n        \"\"\"Test compute_output_shape with num_bins parameter.\"\"\"\n\n        layer = layers.Discretization(\n            num_bins=num_bins, output_mode=output_mode\n 
       )\n\n        # Test compute_output_shape directly\n        result_shape = layer.compute_output_shape(input_shape)\n        self.assertEqual(\n            result_shape,\n            expected_shape,\n            f\"Failed for num_bins={num_bins}, mode={output_mode}\",\n        )\n"
  },
  {
    "path": "keras/src/layers/preprocessing/feature_space.py",
    "content": "from keras.src import backend\nfrom keras.src import layers\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.layers.preprocessing.data_layer import DataLayer\nfrom keras.src.saving import saving_lib\nfrom keras.src.saving import serialization_lib\nfrom keras.src.saving.keras_saveable import KerasSaveable\nfrom keras.src.utils import backend_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\nfrom keras.src.utils.naming import auto_name\n\n\nclass Cross(KerasSaveable):\n    def __init__(self, feature_names, crossing_dim, output_mode=\"one_hot\"):\n        if output_mode not in {\"int\", \"one_hot\"}:\n            raise ValueError(\n                \"Invalid value for argument `output_mode`. \"\n                \"Expected one of {'int', 'one_hot'}. \"\n                f\"Received: output_mode={output_mode}\"\n            )\n        self.feature_names = tuple(feature_names)\n        self.crossing_dim = crossing_dim\n        self.output_mode = output_mode\n\n    def _obj_type(self):\n        return \"Cross\"\n\n    @property\n    def name(self):\n        return \"_X_\".join(self.feature_names)\n\n    def get_config(self):\n        return {\n            \"feature_names\": self.feature_names,\n            \"crossing_dim\": self.crossing_dim,\n            \"output_mode\": self.output_mode,\n        }\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n\n\nclass Feature(KerasSaveable):\n    def __init__(self, dtype, preprocessor, output_mode):\n        if output_mode not in {\"int\", \"one_hot\", \"float\"}:\n            raise ValueError(\n                \"Invalid value for argument `output_mode`. \"\n                \"Expected one of {'int', 'one_hot', 'float'}. 
\"\n                f\"Received: output_mode={output_mode}\"\n            )\n        self.dtype = dtype\n        if isinstance(preprocessor, dict):\n            preprocessor = serialization_lib.deserialize_keras_object(\n                preprocessor\n            )\n        self.preprocessor = preprocessor\n        self.output_mode = output_mode\n\n    def _obj_type(self):\n        return \"Feature\"\n\n    def get_config(self):\n        return {\n            \"dtype\": self.dtype,\n            \"preprocessor\": serialization_lib.serialize_keras_object(\n                self.preprocessor\n            ),\n            \"output_mode\": self.output_mode,\n        }\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n\n\n@keras_export(\"keras.utils.FeatureSpace\")\nclass FeatureSpace(Layer):\n    \"\"\"One-stop utility for preprocessing and encoding structured data.\n\n    Arguments:\n        feature_names: Dict mapping the names of your features to their\n            type specification, e.g. `{\"my_feature\": \"integer_categorical\"}`\n            or `{\"my_feature\": FeatureSpace.integer_categorical()}`.\n            For a complete list of all supported types, see\n            \"Available feature types\" paragraph below.\n        output_mode: One of `\"concat\"` or `\"dict\"`. In concat mode, all\n            features get concatenated together into a single vector.\n            In dict mode, the FeatureSpace returns a dict of individually\n            encoded features (with the same keys as the input dict keys).\n        crosses: List of features to be crossed together, e.g.\n            `crosses=[(\"feature_1\", \"feature_2\")]`. 
The features will be\n            \"crossed\" by hashing their combined value into\n            a fixed-length vector.\n        crossing_dim: Default vector size for hashing crossed features.\n            Defaults to `32`.\n        hashing_dim: Default vector size for hashing features of type\n            `\"integer_hashed\"` and `\"string_hashed\"`. Defaults to `32`.\n        num_discretization_bins: Default number of bins to be used for\n            discretizing features of type `\"float_discretized\"`.\n            Defaults to `32`.\n\n    **Available feature types:**\n\n    Note that all features can be referred to by their string name,\n    e.g. `\"integer_categorical\"`. When using the string name, the default\n    argument values are used.\n\n    ```python\n    # Plain float values.\n    FeatureSpace.float(name=None)\n\n    # Float values to be preprocessed via featurewise standardization\n    # (i.e. via a `keras.layers.Normalization` layer).\n    FeatureSpace.float_normalized(name=None)\n\n    # Float values to be preprocessed via linear rescaling\n    # (i.e. via a `keras.layers.Rescaling` layer).\n    FeatureSpace.float_rescaled(scale=1., offset=0., name=None)\n\n    # Float values to be discretized. By default, the discrete\n    # representation will then be one-hot encoded.\n    FeatureSpace.float_discretized(\n        num_bins, bin_boundaries=None, output_mode=\"one_hot\", name=None)\n\n    # Integer values to be indexed. By default, the discrete\n    # representation will then be one-hot encoded.\n    FeatureSpace.integer_categorical(\n        max_tokens=None, num_oov_indices=1, output_mode=\"one_hot\", name=None)\n\n    # String values to be indexed. 
By default, the discrete\n    # representation will then be one-hot encoded.\n    FeatureSpace.string_categorical(\n        max_tokens=None, num_oov_indices=1, output_mode=\"one_hot\", name=None)\n\n    # Integer values to be hashed into a fixed number of bins.\n    # By default, the discrete representation will then be one-hot encoded.\n    FeatureSpace.integer_hashed(num_bins, output_mode=\"one_hot\", name=None)\n\n    # String values to be hashed into a fixed number of bins.\n    # By default, the discrete representation will then be one-hot encoded.\n    FeatureSpace.string_hashed(num_bins, output_mode=\"one_hot\", name=None)\n    ```\n\n    Examples:\n\n    **Basic usage with a dict of input data:**\n\n    ```python\n    raw_data = {\n        \"float_values\": [0.0, 0.1, 0.2, 0.3],\n        \"string_values\": [\"zero\", \"one\", \"two\", \"three\"],\n        \"int_values\": [0, 1, 2, 3],\n    }\n    dataset = tf.data.Dataset.from_tensor_slices(raw_data)\n\n    feature_space = FeatureSpace(\n        features={\n            \"float_values\": \"float_normalized\",\n            \"string_values\": \"string_categorical\",\n            \"int_values\": \"integer_categorical\",\n        },\n        crosses=[(\"string_values\", \"int_values\")],\n        output_mode=\"concat\",\n    )\n    # Before you start using the FeatureSpace,\n    # you must `adapt()` it on some data.\n    feature_space.adapt(dataset)\n\n    # You can call the FeatureSpace on a dict of data (batched or unbatched).\n    output_vector = feature_space(raw_data)\n    ```\n\n    **Basic usage with `tf.data`:**\n\n    ```python\n    # Unlabeled data\n    preprocessed_ds = unlabeled_dataset.map(feature_space)\n\n    # Labeled data\n    preprocessed_ds = labeled_dataset.map(lambda x, y: (feature_space(x), y))\n    ```\n\n    **Basic usage with the Keras Functional API:**\n\n    ```python\n    # Retrieve a dict of Keras Input objects\n    inputs = feature_space.get_inputs()\n    # Retrieve the corresponding 
encoded Keras tensors\n    encoded_features = feature_space.get_encoded_features()\n    # Build a Functional model\n    outputs = keras.layers.Dense(1, activation=\"sigmoid\")(encoded_features)\n    model = keras.Model(inputs, outputs)\n    ```\n\n    **Customizing each feature or feature cross:**\n\n    ```python\n    feature_space = FeatureSpace(\n        features={\n            \"float_values\": FeatureSpace.float_normalized(),\n            \"string_values\": FeatureSpace.string_categorical(max_tokens=10),\n            \"int_values\": FeatureSpace.integer_categorical(max_tokens=10),\n        },\n        crosses=[\n            FeatureSpace.cross((\"string_values\", \"int_values\"), crossing_dim=32)\n        ],\n        output_mode=\"concat\",\n    )\n    ```\n\n    **Returning a dict of integer-encoded features:**\n\n    ```python\n    feature_space = FeatureSpace(\n        features={\n            \"string_values\": FeatureSpace.string_categorical(output_mode=\"int\"),\n            \"int_values\": FeatureSpace.integer_categorical(output_mode=\"int\"),\n        },\n        crosses=[\n            FeatureSpace.cross(\n                feature_names=(\"string_values\", \"int_values\"),\n                crossing_dim=32,\n                output_mode=\"int\",\n            )\n        ],\n        output_mode=\"dict\",\n    )\n    ```\n\n    **Specifying your own Keras preprocessing layer:**\n\n    ```python\n    # Let's say that one of the features is a short text paragraph that\n    # we want to encode as a vector (one vector per paragraph) via TF-IDF.\n    data = {\n        \"text\": [\"1st string\", \"2nd string\", \"3rd string\"],\n    }\n\n    # There's a Keras layer for this: TextVectorization.\n    custom_layer = layers.TextVectorization(output_mode=\"tf_idf\")\n\n    # We can use FeatureSpace.feature to create a custom feature\n    # that will use our preprocessing layer.\n    feature_space = FeatureSpace(\n        features={\n            \"text\": 
FeatureSpace.feature(\n                preprocessor=custom_layer, dtype=\"string\", output_mode=\"float\"\n            ),\n        },\n        output_mode=\"concat\",\n    )\n    feature_space.adapt(tf.data.Dataset.from_tensor_slices(data))\n    output_vector = feature_space(data)\n    ```\n\n    **Retrieving the underlying Keras preprocessing layers:**\n\n    ```python\n    # The preprocessing layer of each feature is available in `.preprocessors`.\n    preprocessing_layer = feature_space.preprocessors[\"feature1\"]\n\n    # The crossing layer of each feature cross is available in `.crossers`.\n    # It's an instance of keras.layers.HashedCrossing.\n    crossing_layer = feature_space.crossers[\"feature1_X_feature2\"]\n    ```\n\n    **Saving and reloading a FeatureSpace:**\n\n    ```python\n    feature_space.save(\"featurespace.keras\")\n    reloaded_feature_space = keras.models.load_model(\"featurespace.keras\")\n    ```\n    \"\"\"\n\n    @classmethod\n    def cross(cls, feature_names, crossing_dim, output_mode=\"one_hot\"):\n        return Cross(feature_names, crossing_dim, output_mode=output_mode)\n\n    @classmethod\n    def feature(cls, dtype, preprocessor, output_mode):\n        return Feature(dtype, preprocessor, output_mode)\n\n    @classmethod\n    def float(cls, name=None):\n        name = name or auto_name(\"float\")\n        preprocessor = TFDIdentity(dtype=\"float32\", name=f\"{name}_preprocessor\")\n        return Feature(\n            dtype=\"float32\", preprocessor=preprocessor, output_mode=\"float\"\n        )\n\n    @classmethod\n    def float_rescaled(cls, scale=1.0, offset=0.0, name=None):\n        name = name or auto_name(\"float_rescaled\")\n        preprocessor = layers.Rescaling(\n            scale=scale, offset=offset, name=f\"{name}_preprocessor\"\n        )\n        return Feature(\n            dtype=\"float32\", preprocessor=preprocessor, output_mode=\"float\"\n        )\n\n    @classmethod\n    def float_normalized(cls, name=None):\n  
      name = name or auto_name(\"float_normalized\")\n        preprocessor = layers.Normalization(\n            axis=-1, name=f\"{name}_preprocessor\"\n        )\n        return Feature(\n            dtype=\"float32\", preprocessor=preprocessor, output_mode=\"float\"\n        )\n\n    @classmethod\n    def float_discretized(\n        cls, num_bins, bin_boundaries=None, output_mode=\"one_hot\", name=None\n    ):\n        name = name or auto_name(\"float_discretized\")\n        preprocessor = layers.Discretization(\n            num_bins=num_bins,\n            bin_boundaries=bin_boundaries,\n            name=f\"{name}_preprocessor\",\n        )\n        return Feature(\n            dtype=\"float32\", preprocessor=preprocessor, output_mode=output_mode\n        )\n\n    @classmethod\n    def integer_categorical(\n        cls,\n        max_tokens=None,\n        num_oov_indices=1,\n        output_mode=\"one_hot\",\n        name=None,\n    ):\n        name = name or auto_name(\"integer_categorical\")\n        preprocessor = layers.IntegerLookup(\n            name=f\"{name}_preprocessor\",\n            max_tokens=max_tokens,\n            num_oov_indices=num_oov_indices,\n        )\n        return Feature(\n            dtype=\"int32\", preprocessor=preprocessor, output_mode=output_mode\n        )\n\n    @classmethod\n    def string_categorical(\n        cls,\n        max_tokens=None,\n        num_oov_indices=1,\n        output_mode=\"one_hot\",\n        name=None,\n    ):\n        name = name or auto_name(\"string_categorical\")\n        preprocessor = layers.StringLookup(\n            name=f\"{name}_preprocessor\",\n            max_tokens=max_tokens,\n            num_oov_indices=num_oov_indices,\n        )\n        return Feature(\n            dtype=\"string\", preprocessor=preprocessor, output_mode=output_mode\n        )\n\n    @classmethod\n    def string_hashed(cls, num_bins, output_mode=\"one_hot\", name=None):\n        name = name or auto_name(\"string_hashed\")\n      
  preprocessor = layers.Hashing(\n            name=f\"{name}_preprocessor\", num_bins=num_bins\n        )\n        return Feature(\n            dtype=\"string\", preprocessor=preprocessor, output_mode=output_mode\n        )\n\n    @classmethod\n    def integer_hashed(cls, num_bins, output_mode=\"one_hot\", name=None):\n        name = name or auto_name(\"integer_hashed\")\n        preprocessor = layers.Hashing(\n            name=f\"{name}_preprocessor\", num_bins=num_bins\n        )\n        return Feature(\n            dtype=\"int32\", preprocessor=preprocessor, output_mode=output_mode\n        )\n\n    def __init__(\n        self,\n        features,\n        output_mode=\"concat\",\n        crosses=None,\n        crossing_dim=32,\n        hashing_dim=32,\n        num_discretization_bins=32,\n        name=None,\n    ):\n        super().__init__(name=name)\n        if not features:\n            raise ValueError(\"The `features` argument cannot be None or empty.\")\n        self.crossing_dim = crossing_dim\n        self.hashing_dim = hashing_dim\n        self.num_discretization_bins = num_discretization_bins\n        self.features = {\n            name: self._standardize_feature(name, value)\n            for name, value in features.items()\n        }\n        self.crosses = []\n        if crosses:\n            feature_set = set(features.keys())\n            for cross in crosses:\n                if isinstance(cross, dict):\n                    cross = serialization_lib.deserialize_keras_object(cross)\n                if isinstance(cross, Cross):\n                    self.crosses.append(cross)\n                else:\n                    if not crossing_dim:\n                        raise ValueError(\n                            \"When specifying `crosses`, the argument \"\n                            \"`crossing_dim` \"\n                            \"(dimensionality of the crossing space) \"\n                            \"should be specified as well.\"\n               
         )\n                    for key in cross:\n                        if key not in feature_set:\n                            raise ValueError(\n                                \"All features referenced \"\n                                \"in the `crosses` argument \"\n                                \"should be present in the `features` dict. \"\n                                f\"Received unknown features: {cross}\"\n                            )\n                    self.crosses.append(Cross(cross, crossing_dim=crossing_dim))\n        self.crosses_by_name = {cross.name: cross for cross in self.crosses}\n\n        if output_mode not in {\"dict\", \"concat\"}:\n            raise ValueError(\n                \"Invalid value for argument `output_mode`. \"\n                \"Expected one of {'dict', 'concat'}. \"\n                f\"Received: output_mode={output_mode}\"\n            )\n        self.output_mode = output_mode\n\n        self.inputs = {\n            name: self._feature_to_input(name, value)\n            for name, value in self.features.items()\n        }\n        self.preprocessors = {\n            name: value.preprocessor for name, value in self.features.items()\n        }\n        self.encoded_features = None\n        self.crossers = {\n            cross.name: self._cross_to_crosser(cross) for cross in self.crosses\n        }\n        self.one_hot_encoders = {}\n        self._is_adapted = False\n        self.concat = None\n        self._preprocessed_features_names = None\n        self._crossed_features_names = None\n        self._sublayers_built = False\n\n    def _feature_to_input(self, name, feature):\n        return layers.Input(shape=(1,), dtype=feature.dtype, name=name)\n\n    def _standardize_feature(self, name, feature):\n        if isinstance(feature, Feature):\n            return feature\n\n        if isinstance(feature, dict):\n            return serialization_lib.deserialize_keras_object(feature)\n\n        if feature == \"float\":\n  
          return self.float(name=name)\n        elif feature == \"float_normalized\":\n            return self.float_normalized(name=name)\n        elif feature == \"float_rescaled\":\n            return self.float_rescaled(name=name)\n        elif feature == \"float_discretized\":\n            return self.float_discretized(\n                name=name, num_bins=self.num_discretization_bins\n            )\n        elif feature == \"integer_categorical\":\n            return self.integer_categorical(name=name)\n        elif feature == \"string_categorical\":\n            return self.string_categorical(name=name)\n        elif feature == \"integer_hashed\":\n            return self.integer_hashed(self.hashing_dim, name=name)\n        elif feature == \"string_hashed\":\n            return self.string_hashed(self.hashing_dim, name=name)\n        else:\n            raise ValueError(f\"Invalid feature type: {feature}\")\n\n    def _cross_to_crosser(self, cross):\n        return layers.HashedCrossing(cross.crossing_dim, name=cross.name)\n\n    def _list_adaptable_preprocessors(self):\n        adaptable_preprocessors = []\n        for name in self.features.keys():\n            preprocessor = self.preprocessors[name]\n            # Special case: a Normalization layer with preset mean/variance.\n            # Not adaptable.\n            if isinstance(preprocessor, layers.Normalization):\n                if preprocessor.input_mean is not None:\n                    continue\n            # Special case: a TextVectorization layer with provided vocabulary.\n            elif isinstance(preprocessor, layers.TextVectorization):\n                if preprocessor._has_input_vocabulary:\n                    continue\n            if hasattr(preprocessor, \"adapt\"):\n                adaptable_preprocessors.append(name)\n        return adaptable_preprocessors\n\n    def adapt(self, dataset):\n        if not isinstance(dataset, tf.data.Dataset):\n            if isinstance(dataset, dict):\n  
              dataset = tf.data.Dataset.from_tensor_slices(dataset)\n            else:\n                raise ValueError(\n                    \"`adapt()` can only be called on a tf.data.Dataset or a \"\n                    \"dict of arrays/lists. \"\n                    f\"Received instead: {dataset} (of type {type(dataset)})\"\n                )\n\n        for name in self._list_adaptable_preprocessors():\n            # Call adapt() on each individual adaptable layer.\n\n            # TODO: consider rewriting this to instead iterate on the\n            # dataset once, split each batch into individual features,\n            # and call the layer's `_adapt_function` on each batch\n            # to simulate the behavior of adapt() in a more performant fashion.\n\n            feature_dataset = dataset.map(lambda x: x[name])\n            preprocessor = self.preprocessors[name]\n            # TODO: consider adding an adapt progress bar.\n            # Sample 1 element to check the rank\n            x = next(iter(feature_dataset))\n            if len(x.shape) == 0:\n                # The dataset yields unbatched scalars; batch it.\n                feature_dataset = feature_dataset.batch(32)\n            if len(x.shape) in {0, 1}:\n                # If the rank is 1, add a dimension\n                # so we can reduce on axis=-1.\n                # Note: if rank was previously 0, it is now 1.\n                feature_dataset = feature_dataset.map(\n                    lambda x: tf.expand_dims(x, -1)\n                )\n            preprocessor.adapt(feature_dataset)\n        self._is_adapted = True\n        self.get_encoded_features()  # Finish building the layer\n        self.built = True\n        self._sublayers_built = True\n\n    def get_inputs(self):\n        self._check_if_built()\n        return self.inputs\n\n    def get_encoded_features(self):\n        self._check_if_adapted()\n\n        if self.encoded_features is None:\n            preprocessed_features = 
self._preprocess_features(self.inputs)\n            crossed_features = self._cross_features(preprocessed_features)\n            merged_features = self._merge_features(\n                preprocessed_features, crossed_features\n            )\n            self.encoded_features = merged_features\n        return self.encoded_features\n\n    def _preprocess_features(self, features):\n        return {\n            name: self.preprocessors[name](features[name])\n            for name in features.keys()\n        }\n\n    def _cross_features(self, features):\n        all_outputs = {}\n        for cross in self.crosses:\n            inputs = [features[name] for name in cross.feature_names]\n            outputs = self.crossers[cross.name](inputs)\n            all_outputs[cross.name] = outputs\n        return all_outputs\n\n    def _merge_features(self, preprocessed_features, crossed_features):\n        if not self._preprocessed_features_names:\n            self._preprocessed_features_names = sorted(\n                preprocessed_features.keys()\n            )\n            self._crossed_features_names = sorted(crossed_features.keys())\n\n        all_names = (\n            self._preprocessed_features_names + self._crossed_features_names\n        )\n        all_features = [\n            preprocessed_features[name]\n            for name in self._preprocessed_features_names\n        ] + [crossed_features[name] for name in self._crossed_features_names]\n\n        if self.output_mode == \"dict\":\n            output_dict = {}\n        else:\n            features_to_concat = []\n\n        if self._sublayers_built:\n            # Fast mode.\n            for name, feature in zip(all_names, all_features):\n                encoder = self.one_hot_encoders.get(name, None)\n                if encoder:\n                    feature = encoder(feature)\n                if self.output_mode == \"dict\":\n                    output_dict[name] = feature\n                else:\n                    
features_to_concat.append(feature)\n            if self.output_mode == \"dict\":\n                return output_dict\n            else:\n                return self.concat(features_to_concat)\n\n        # If the object isn't built,\n        # we create the encoder and concat layers below\n        all_specs = [\n            self.features[name] for name in self._preprocessed_features_names\n        ] + [\n            self.crosses_by_name[name] for name in self._crossed_features_names\n        ]\n\n        for name, feature, spec in zip(all_names, all_features, all_specs):\n            if tree.is_nested(feature):\n                dtype = tree.flatten(feature)[0].dtype\n            else:\n                dtype = feature.dtype\n            dtype = backend.standardize_dtype(dtype)\n\n            if spec.output_mode == \"one_hot\":\n                preprocessor = self.preprocessors.get(\n                    name\n                ) or self.crossers.get(name)\n\n                cardinality = None\n                if not dtype.startswith(\"int\"):\n                    raise ValueError(\n                        f\"Feature '{name}' has `output_mode='one_hot'`. \"\n                        \"Thus its preprocessor should return an integer dtype. 
\"\n                        f\"Instead it returns a {dtype} dtype.\"\n                    )\n\n                if isinstance(\n                    preprocessor, (layers.IntegerLookup, layers.StringLookup)\n                ):\n                    cardinality = preprocessor.vocabulary_size()\n                elif isinstance(preprocessor, layers.CategoryEncoding):\n                    cardinality = preprocessor.num_tokens\n                elif isinstance(preprocessor, layers.Discretization):\n                    cardinality = preprocessor.num_bins\n                elif isinstance(\n                    preprocessor, (layers.HashedCrossing, layers.Hashing)\n                ):\n                    cardinality = preprocessor.num_bins\n                else:\n                    raise ValueError(\n                        f\"Feature '{name}' has `output_mode='one_hot'`. \"\n                        \"However it isn't a standard feature and the \"\n                        \"dimensionality of its output space is not known, \"\n                        \"thus it cannot be one-hot encoded. \"\n                        \"Try using `output_mode='int'`.\"\n                    )\n                if cardinality is not None:\n                    encoder = layers.CategoryEncoding(\n                        num_tokens=cardinality, output_mode=\"multi_hot\"\n                    )\n                    self.one_hot_encoders[name] = encoder\n                    feature = encoder(feature)\n\n            if self.output_mode == \"concat\":\n                dtype = feature.dtype\n                if dtype.startswith(\"int\") or dtype == \"string\":\n                    raise ValueError(\n                        f\"Cannot concatenate features because feature '{name}' \"\n                        f\"has not been encoded (it has dtype {dtype}). 
\"\n                        \"Consider using `output_mode='dict'`.\"\n                    )\n                features_to_concat.append(feature)\n            else:\n                output_dict[name] = feature\n\n        if self.output_mode == \"concat\":\n            self.concat = TFDConcat(axis=-1)\n            return self.concat(features_to_concat)\n        else:\n            return output_dict\n\n    def _check_if_adapted(self):\n        if not self._is_adapted:\n            if not self._list_adaptable_preprocessors():\n                self._is_adapted = True\n            else:\n                raise ValueError(\n                    \"You need to call `.adapt(dataset)` on the FeatureSpace \"\n                    \"before you can start using it.\"\n                )\n\n    def _check_if_built(self):\n        if not self._sublayers_built:\n            self._check_if_adapted()\n            # Finishes building\n            self.get_encoded_features()\n            self._sublayers_built = True\n\n    def _convert_input(self, x):\n        if not isinstance(x, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)):\n            if not isinstance(x, (list, tuple, int, float)):\n                x = backend.convert_to_numpy(x)\n            x = tf.convert_to_tensor(x)\n        return x\n\n    def __call__(self, data):\n        self._check_if_built()\n        if not isinstance(data, dict):\n            raise ValueError(\n                \"A FeatureSpace can only be called with a dict. 
\"\n                f\"Received: data={data} (of type {type(data)})\"\n            )\n\n        # Many preprocessing layers support all backends but many do not.\n        # Switch to TF to make FeatureSpace work universally.\n        data = {key: self._convert_input(value) for key, value in data.items()}\n        rebatched = False\n        for name, x in data.items():\n            if len(x.shape) == 0:\n                data[name] = tf.reshape(x, (1, 1))\n                rebatched = True\n            elif len(x.shape) == 1:\n                data[name] = tf.expand_dims(x, -1)\n\n        with backend_utils.TFGraphScope():\n            # This scope is to make sure that inner DataLayers\n            # will not convert outputs back to backend-native --\n            # they should be TF tensors throughout\n            preprocessed_data = self._preprocess_features(data)\n            preprocessed_data = tree.map_structure(\n                lambda x: self._convert_input(x), preprocessed_data\n            )\n\n            crossed_data = self._cross_features(preprocessed_data)\n            crossed_data = tree.map_structure(\n                lambda x: self._convert_input(x), crossed_data\n            )\n\n            merged_data = self._merge_features(preprocessed_data, crossed_data)\n\n        if rebatched:\n            if self.output_mode == \"concat\":\n                if merged_data.shape[0] != 1:\n                    raise ValueError(\n                        \"Expected rebatched data to have batch size 1. 
\"\n                        f\"Received: shape={merged_data.shape}\"\n                    )\n                if (\n                    backend.backend() != \"tensorflow\"\n                    and not backend_utils.in_tf_graph()\n                ):\n                    merged_data = backend.convert_to_numpy(merged_data)\n                merged_data = tf.squeeze(merged_data, axis=0)\n            else:\n                for name, x in merged_data.items():\n                    if len(x.shape) == 2 and x.shape[0] == 1:\n                        merged_data[name] = tf.squeeze(x, axis=0)\n\n        if (\n            backend.backend() != \"tensorflow\"\n            and not backend_utils.in_tf_graph()\n        ):\n            merged_data = tree.map_structure(\n                lambda x: backend.convert_to_tensor(x, dtype=x.dtype),\n                merged_data,\n            )\n        return merged_data\n\n    def get_config(self):\n        return {\n            \"features\": serialization_lib.serialize_keras_object(self.features),\n            \"output_mode\": self.output_mode,\n            \"crosses\": serialization_lib.serialize_keras_object(self.crosses),\n            \"crossing_dim\": self.crossing_dim,\n            \"hashing_dim\": self.hashing_dim,\n            \"num_discretization_bins\": self.num_discretization_bins,\n        }\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n\n    def get_build_config(self):\n        return {\n            name: feature.preprocessor.get_build_config()\n            for name, feature in self.features.items()\n        }\n\n    def build_from_config(self, config):\n        for name in config.keys():\n            preprocessor = self.features[name].preprocessor\n            if not preprocessor.built:\n                preprocessor.build_from_config(config[name])\n        self._is_adapted = True\n\n    def save(self, filepath):\n        \"\"\"Save the `FeatureSpace` instance to a `.keras` file.\n\n        
You can reload it via `keras.models.load_model()`:\n\n        ```python\n        feature_space.save(\"featurespace.keras\")\n        reloaded_fs = keras.models.load_model(\"featurespace.keras\")\n        ```\n        \"\"\"\n        saving_lib.save_model(self, filepath)\n\n    def save_own_variables(self, store):\n        return\n\n    def load_own_variables(self, store):\n        return\n\n\nclass TFDConcat(DataLayer):\n    def __init__(self, axis, **kwargs):\n        super().__init__(**kwargs)\n        self.axis = axis\n\n    def call(self, xs):\n        return self.backend.numpy.concatenate(xs, axis=self.axis)\n\n\nclass TFDIdentity(DataLayer):\n    def call(self, x):\n        return x\n"
  },
  {
    "path": "keras/src/layers/preprocessing/feature_space_test.py",
    "content": "import os\n\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.layers.preprocessing import feature_space\nfrom keras.src.saving import saving_api\n\n\nclass FeatureSpaceTest(testing.TestCase):\n    def _get_train_data_dict(\n        self,\n        as_dataset=False,\n        as_tensors=False,\n        as_labeled_dataset=False,\n        include_strings=True,\n    ):\n        data = {\n            \"float_1\": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],\n            \"float_2\": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],\n            \"float_3\": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],\n            \"int_1\": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n            \"int_2\": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n            \"int_3\": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n        }\n        if include_strings:\n            data[\"string_1\"] = [\n                \"0\",\n                \"1\",\n                \"2\",\n                \"3\",\n                \"4\",\n                \"5\",\n                \"6\",\n                \"7\",\n                \"8\",\n                \"9\",\n            ]\n            data[\"string_2\"] = [\n                \"0\",\n                \"1\",\n                \"2\",\n                \"3\",\n                \"4\",\n                \"5\",\n                \"6\",\n                \"7\",\n                \"8\",\n                \"9\",\n            ]\n\n        if as_dataset:\n            return tf_data.Dataset.from_tensor_slices(data)\n        elif as_tensors:\n            return {\n                key: ops.convert_to_tensor(value) for key, value in data.items()\n            }\n        elif as_labeled_dataset:\n            labels = [0, 1, 0, 1, 0, 0, 1, 0, 1, 1]\n            return tf_data.Dataset.from_tensor_slices((data, labels))\n        
return data\n\n    def test_basic_usage_no_strings(self):\n        fs = feature_space.FeatureSpace(\n            features={\n                \"float_1\": \"float\",\n                \"float_2\": \"float_normalized\",\n                \"float_3\": \"float_discretized\",\n                \"int_1\": \"integer_categorical\",\n                \"int_2\": \"integer_hashed\",\n                \"int_3\": \"integer_categorical\",\n            },\n            crosses=[(\"int_1\", \"int_2\"), (\"int_2\", \"int_3\")],\n            output_mode=\"concat\",\n        )\n        # Test unbatched adapt\n        fs.adapt(\n            self._get_train_data_dict(as_dataset=True, include_strings=False)\n        )\n        # Test batched adapt\n        fs.adapt(\n            self._get_train_data_dict(\n                as_dataset=True, include_strings=False\n            ).batch(4)\n        )\n\n        # Test unbatched call on raw data\n        data = {\n            key: value[0]\n            for key, value in self._get_train_data_dict(\n                include_strings=False\n            ).items()\n        }\n        out = fs(data)\n        out_dim = 152\n        self.assertEqual(out.shape, (out_dim,))\n\n        # Test unbatched call on backend tensors\n        data = self._get_train_data_dict(as_tensors=True, include_strings=False)\n        data = {key: value[0] for key, value in data.items()}\n        out = fs(data)\n        self.assertEqual(out.shape, (out_dim,))\n\n        # Test batched call on raw data\n        out = fs(self._get_train_data_dict(include_strings=False))\n        self.assertEqual(out.shape, (10, out_dim))\n\n        # Test batched call on backend tensors\n        out = fs(\n            self._get_train_data_dict(as_tensors=True, include_strings=False)\n        )\n        self.assertEqual(out.shape, (10, out_dim))\n\n    def test_output_mode_dict_no_strings(self):\n        fs = feature_space.FeatureSpace(\n            features={\n                \"float_1\": 
\"float\",\n                \"float_2\": \"float_normalized\",\n                \"float_3\": \"float_discretized\",\n                \"int_1\": \"integer_categorical\",\n                \"int_2\": \"integer_hashed\",\n                \"int_3\": \"integer_categorical\",\n            },\n            crosses=[(\"int_1\", \"int_2\")],\n            output_mode=\"dict\",\n        )\n        fs.adapt(\n            self._get_train_data_dict(as_dataset=True, include_strings=False)\n        )\n\n        # Test unbatched call on raw data\n        data = {\n            key: value[0]\n            for key, value in self._get_train_data_dict(\n                include_strings=False\n            ).items()\n        }\n        out = fs(data)\n        self.assertIsInstance(out, dict)\n        self.assertLen(out, 7)\n        self.assertEqual(out[\"int_2\"].shape, (32,))\n        self.assertEqual(out[\"int_1_X_int_2\"].shape, (32,))\n\n        # Test batched call on raw data\n        out = fs(self._get_train_data_dict(include_strings=False))\n        self.assertIsInstance(out, dict)\n        self.assertLen(out, 7)\n        self.assertEqual(out[\"int_2\"].shape, (10, 32))\n\n        # Test batched call on backend tensors\n        out = fs(\n            self._get_train_data_dict(as_tensors=True, include_strings=False)\n        )\n        self.assertIsInstance(out, dict)\n        self.assertLen(out, 7)\n        self.assertEqual(out[\"int_2\"].shape, (10, 32))\n\n    def test_output_mode_dict_of_ints_no_strings(self):\n        cls = feature_space.FeatureSpace\n        fs = feature_space.FeatureSpace(\n            features={\n                \"float_1\": \"float\",\n                \"float_2\": \"float_normalized\",\n                \"float_3\": \"float_discretized\",\n                \"int_1\": cls.integer_categorical(output_mode=\"int\"),\n                \"int_2\": cls.integer_hashed(num_bins=32, output_mode=\"int\"),\n                \"int_3\": 
cls.integer_categorical(output_mode=\"int\"),\n            },\n            crosses=[\n                cls.cross(\n                    (\"int_1\", \"int_2\"), output_mode=\"int\", crossing_dim=32\n                ),\n            ],\n            output_mode=\"dict\",\n        )\n        fs.adapt(\n            self._get_train_data_dict(as_dataset=True, include_strings=False)\n        )\n        data = {\n            key: value[0]\n            for key, value in self._get_train_data_dict(\n                include_strings=False\n            ).items()\n        }\n        out = fs(data)\n        self.assertIsInstance(out, dict)\n        self.assertLen(out, 7)\n        self.assertEqual(out[\"int_2\"].shape, (1,))\n        self.assertTrue(\n            backend.standardize_dtype(out[\"int_2\"].dtype).startswith(\"int\")\n        )\n        self.assertEqual(out[\"int_1_X_int_2\"].shape, (1,))\n        self.assertTrue(\n            backend.standardize_dtype(out[\"int_1_X_int_2\"].dtype).startswith(\n                \"int\"\n            )\n        )\n\n    def test_basic_usage(self):\n        fs = feature_space.FeatureSpace(\n            features={\n                \"float_1\": \"float\",\n                \"float_2\": \"float_normalized\",\n                \"float_3\": \"float_discretized\",\n                \"string_1\": \"string_categorical\",\n                \"string_2\": \"string_hashed\",\n                \"int_1\": \"integer_categorical\",\n                \"int_2\": \"integer_hashed\",\n                \"int_3\": \"integer_categorical\",\n            },\n            crosses=[(\"float_3\", \"string_1\"), (\"string_2\", \"int_2\")],\n            output_mode=\"concat\",\n        )\n        # Test unbatched adapt\n        fs.adapt(self._get_train_data_dict(as_dataset=True))\n        # Test batched adapt\n        fs.adapt(self._get_train_data_dict(as_dataset=True).batch(4))\n\n        # Test unbatched call on raw data\n        data = {\n            key: value[0] for key, 
value in self._get_train_data_dict().items()\n        }\n        out = fs(data)\n        out_dim = 195\n        self.assertEqual(out.shape, (out_dim,))\n\n        # Test unbatched call on tensors\n        if backend.backend() == \"tensorflow\":\n            data = self._get_train_data_dict(as_tensors=True)\n            data = {key: value[0] for key, value in data.items()}\n            out = fs(data)\n            self.assertEqual(out.shape, (out_dim,))\n\n        # Test batched call on raw data\n        out = fs(self._get_train_data_dict())\n        self.assertEqual(out.shape, (10, out_dim))\n\n        # Test batched call on tensors\n        if backend.backend() == \"tensorflow\":\n            out = fs(self._get_train_data_dict(as_tensors=True))\n            self.assertEqual(out.shape, (10, out_dim))\n\n    def test_output_mode_dict(self):\n        fs = feature_space.FeatureSpace(\n            features={\n                \"float_1\": \"float\",\n                \"float_2\": \"float_normalized\",\n                \"float_3\": \"float_discretized\",\n                \"string_1\": \"string_categorical\",\n                \"string_2\": \"string_hashed\",\n                \"int_1\": \"integer_categorical\",\n                \"int_2\": \"integer_hashed\",\n                \"int_3\": \"integer_categorical\",\n            },\n            crosses=[(\"float_3\", \"string_1\"), (\"string_2\", \"int_2\")],\n            output_mode=\"dict\",\n        )\n        fs.adapt(self._get_train_data_dict(as_dataset=True))\n\n        # Test unbatched call on raw data\n        data = {\n            key: value[0] for key, value in self._get_train_data_dict().items()\n        }\n        out = fs(data)\n        self.assertIsInstance(out, dict)\n        self.assertLen(out, 10)\n        self.assertEqual(out[\"string_1\"].shape, (11,))\n        self.assertEqual(out[\"int_2\"].shape, (32,))\n        self.assertEqual(out[\"string_2_X_int_2\"].shape, (32,))\n\n        # Test batched call on raw 
data\n        out = fs(self._get_train_data_dict())\n        self.assertIsInstance(out, dict)\n        self.assertLen(out, 10)\n        self.assertEqual(out[\"string_1\"].shape, (10, 11))\n        self.assertEqual(out[\"int_2\"].shape, (10, 32))\n        self.assertEqual(out[\"string_2_X_int_2\"].shape, (10, 32))\n\n        # Test batched call on tensors\n        if backend.backend() == \"tensorflow\":\n            out = fs(self._get_train_data_dict(as_tensors=True))\n            self.assertIsInstance(out, dict)\n            self.assertLen(out, 10)\n            self.assertEqual(out[\"string_1\"].shape, (10, 11))\n            self.assertEqual(out[\"int_2\"].shape, (10, 32))\n            self.assertEqual(out[\"string_2_X_int_2\"].shape, (10, 32))\n\n    def test_output_mode_dict_of_ints(self):\n        cls = feature_space.FeatureSpace\n        fs = feature_space.FeatureSpace(\n            features={\n                \"float_1\": \"float\",\n                \"float_2\": \"float_normalized\",\n                \"float_3\": \"float_discretized\",\n                \"string_1\": cls.string_categorical(output_mode=\"int\"),\n                \"string_2\": cls.string_hashed(num_bins=32, output_mode=\"int\"),\n                \"int_1\": cls.integer_categorical(output_mode=\"int\"),\n                \"int_2\": cls.integer_hashed(num_bins=32, output_mode=\"int\"),\n                \"int_3\": cls.integer_categorical(output_mode=\"int\"),\n            },\n            crosses=[\n                cls.cross(\n                    (\"float_3\", \"string_1\"), output_mode=\"int\", crossing_dim=32\n                ),\n                cls.cross(\n                    (\"string_2\", \"int_2\"), output_mode=\"int\", crossing_dim=32\n                ),\n            ],\n            output_mode=\"dict\",\n        )\n        fs.adapt(self._get_train_data_dict(as_dataset=True))\n        data = {\n            key: value[0] for key, value in self._get_train_data_dict().items()\n        }\n        
out = fs(data)\n        self.assertIsInstance(out, dict)\n        self.assertLen(out, 10)\n        self.assertEqual(out[\"string_1\"].shape, (1,))\n        self.assertTrue(\n            backend.standardize_dtype(out[\"string_1\"].dtype).startswith(\"int\")\n        )\n        self.assertEqual(out[\"int_2\"].shape, (1,))\n        self.assertTrue(\n            backend.standardize_dtype(out[\"int_2\"].dtype).startswith(\"int\")\n        )\n        self.assertEqual(out[\"string_2_X_int_2\"].shape, (1,))\n        self.assertTrue(\n            backend.standardize_dtype(out[\"string_2_X_int_2\"].dtype).startswith(\n                \"int\"\n            )\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Requires string dtype.\"\n    )\n    def test_functional_api_sync_processing(self):\n        fs = feature_space.FeatureSpace(\n            features={\n                \"float_1\": \"float\",\n                \"float_2\": \"float_normalized\",\n                \"float_3\": \"float_discretized\",\n                \"string_1\": \"string_categorical\",\n                \"string_2\": \"string_hashed\",\n                \"int_1\": \"integer_categorical\",\n                \"int_2\": \"integer_hashed\",\n                \"int_3\": \"integer_categorical\",\n            },\n            crosses=[(\"float_3\", \"string_1\"), (\"string_2\", \"int_2\")],\n            output_mode=\"concat\",\n        )\n        fs.adapt(self._get_train_data_dict(as_dataset=True))\n        inputs = fs.get_inputs()\n        features = fs.get_encoded_features()\n        outputs = layers.Dense(1)(features)\n        model = models.Model(inputs=inputs, outputs=outputs)\n        model.compile(\"adam\", \"mse\")\n        ds = self._get_train_data_dict(as_labeled_dataset=True)\n        model.fit(ds.batch(4))\n        model.evaluate(ds.batch(4))\n        ds = self._get_train_data_dict(as_dataset=True)\n        model.predict(ds.batch(4))\n\n    
@pytest.mark.requires_trainable_backend\n    def test_tf_data_async_processing(self):\n        fs = feature_space.FeatureSpace(\n            features={\n                \"float_1\": \"float\",\n                \"float_2\": \"float_normalized\",\n                \"float_3\": \"float_discretized\",\n                \"int_1\": \"integer_categorical\",\n                \"int_2\": \"integer_hashed\",\n                \"int_3\": \"integer_categorical\",\n            },\n            crosses=[(\"float_3\", \"int_1\"), (\"int_1\", \"int_2\")],\n            output_mode=\"concat\",\n        )\n        fs.adapt(\n            self._get_train_data_dict(as_dataset=True, include_strings=False)\n        )\n        features = fs.get_encoded_features()\n        outputs = layers.Dense(1)(features)\n        model = models.Model(inputs=features, outputs=outputs)\n        model.compile(\"adam\", \"mse\")\n        ds = self._get_train_data_dict(\n            as_labeled_dataset=True, include_strings=False\n        )\n        # Try map before batch\n        ds = ds.map(lambda x, y: (fs(x), y))\n        model.fit(ds.batch(4))\n        # Try map after batch\n        ds = self._get_train_data_dict(\n            as_labeled_dataset=True, include_strings=False\n        )\n        ds = ds.batch(4)\n        ds = ds.map(lambda x, y: (fs(x), y))\n        model.evaluate(ds)\n        ds = self._get_train_data_dict(as_dataset=True, include_strings=False)\n        ds = ds.map(fs)\n        model.predict(ds.batch(4))\n\n    def test_advanced_usage(self):\n        cls = feature_space.FeatureSpace\n        fs = feature_space.FeatureSpace(\n            features={\n                \"float_1\": cls.float(),\n                \"float_2\": cls.float_normalized(),\n                \"float_3\": cls.float_discretized(num_bins=3),\n                \"string_1\": cls.string_categorical(max_tokens=5),\n                \"string_2\": cls.string_hashed(num_bins=32),\n                \"int_1\": cls.integer_categorical(\n     
               max_tokens=5, num_oov_indices=2\n                ),\n                \"int_2\": cls.integer_hashed(num_bins=32),\n                \"int_3\": cls.integer_categorical(max_tokens=5),\n            },\n            crosses=[\n                cls.cross((\"float_3\", \"string_1\"), crossing_dim=32),\n                cls.cross((\"string_2\", \"int_2\"), crossing_dim=32),\n            ],\n            output_mode=\"concat\",\n        )\n        fs.adapt(self._get_train_data_dict(as_dataset=True))\n        data = {\n            key: value[0] for key, value in self._get_train_data_dict().items()\n        }\n        out = fs(data)\n        self.assertEqual(out.shape, (148,))\n\n    def test_manual_kpl(self):\n        data = {\n            \"text\": [\"1st string\", \"2nd string\", \"3rd string\"],\n        }\n        cls = feature_space.FeatureSpace\n\n        # Test with a tf-idf TextVectorization layer\n        tv = layers.TextVectorization(output_mode=\"tf_idf\")\n        fs = feature_space.FeatureSpace(\n            features={\n                \"text\": cls.feature(\n                    preprocessor=tv, dtype=\"string\", output_mode=\"float\"\n                ),\n            },\n            output_mode=\"concat\",\n        )\n        fs.adapt(tf_data.Dataset.from_tensor_slices(data))\n        out = fs(data)\n        self.assertEqual(list(out.shape), [3, 5])\n\n    def test_no_adapt(self):\n        data = {\n            \"int_1\": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],\n            \"text_1\": [\n                \"This is\",\n                \"not just\",\n                \"an example\",\n                \"of random words.\",\n                \"these are\",\n                \"some words\",\n                \"in\",\n                \"a random\",\n                \"example.\",\n                \"Bye!\",\n            ],\n            \"float_1\": [\n                -1.2,\n                0.0,\n                2.4,\n                1.2,\n                15.0,\n             
   -100.0,\n                23.1,\n                3.12,\n                0.1,\n                -0.01,\n            ],\n        }\n        cls = feature_space.FeatureSpace\n        # Pre-defined vocabulary. No need to adapt.\n        tv_vocab = [\n            \"this\",\n            \"is\",\n            \"just\",\n            \"an\",\n            \"example\",\n            \"with\",\n            \"some\",\n            \"words\",\n        ]\n        tv_with_vocab = layers.TextVectorization(\n            vocabulary=tv_vocab, output_mode=\"int\", output_sequence_length=3\n        )\n\n        # Pre-defined mean and variance. No need to adapt.\n        mean, variance = 12.0, 5.0\n        normalization = layers.Normalization(mean=mean, variance=variance)\n\n        fs = feature_space.FeatureSpace(\n            {\n                \"int_1\": \"integer_hashed\",\n                \"text_1\": cls.feature(\n                    dtype=\"string\",\n                    preprocessor=tv_with_vocab,\n                    output_mode=\"int\",\n                ),\n                \"float_1\": cls.feature(\n                    dtype=\"float32\",\n                    preprocessor=normalization,\n                    output_mode=\"float\",\n                ),\n            },\n            output_mode=\"dict\",\n        )\n\n        out = fs(data)\n        float_out = ops.divide(\n            ops.convert_to_tensor(data[\"float_1\"]) - mean, ops.sqrt(variance)\n        )\n        float_out = ops.reshape(float_out, (10, -1))\n\n        self.assertEqual(tuple(out[\"int_1\"].shape), (10, 32))\n        self.assertEqual(tuple(out[\"text_1\"].shape), (10, 3))\n        self.assertAllClose(out[\"float_1\"], float_out, atol=1e-3)\n\n    @pytest.mark.skipif(\n        backend.backend() in (\"numpy\", \"torch\"),\n        reason=(\n            \"TODO: When using FeatureSpace as a Model in torch and numpy, \"\n            \"the error is large.\"\n        ),\n    )\n    def test_saving(self):\n        cls = 
feature_space.FeatureSpace\n        fs = feature_space.FeatureSpace(\n            features={\n                \"float_1\": cls.float(),\n                \"float_2\": cls.float_normalized(),\n                \"float_3\": cls.float_discretized(num_bins=3),\n                \"int_1\": cls.integer_categorical(\n                    max_tokens=5, num_oov_indices=2\n                ),\n                \"int_2\": cls.integer_hashed(num_bins=32),\n                \"int_3\": cls.integer_categorical(max_tokens=5),\n            },\n            crosses=[\n                cls.cross((\"float_3\", \"int_1\"), crossing_dim=32),\n                cls.cross((\"int_1\", \"int_2\"), crossing_dim=32),\n            ],\n            output_mode=\"concat\",\n        )\n        fs.adapt(\n            self._get_train_data_dict(as_dataset=True, include_strings=False)\n        )\n        data = {\n            key: value[0]\n            for key, value in self._get_train_data_dict(\n                include_strings=False\n            ).items()\n        }\n        ref_out = fs(data)\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"fs.keras\")\n        fs.save(temp_filepath)\n        fs = saving_api.load_model(temp_filepath)\n\n        # Save again immediately after loading to test idempotency\n        temp_filepath = os.path.join(self.get_temp_dir(), \"fs2.keras\")\n        fs.save(temp_filepath)\n\n        # Test correctness of the first saved FS\n        out = fs(data)\n        self.assertAllClose(out, ref_out)\n\n        inputs = fs.get_inputs()\n        outputs = fs.get_encoded_features()\n        model = models.Model(inputs=inputs, outputs=outputs)\n        ds = self._get_train_data_dict(as_dataset=True, include_strings=False)\n        out = model.predict(ds.batch(4))\n        self.assertAllClose(out[0], ref_out)\n\n        # Test correctness of the re-saved FS\n        fs = saving_api.load_model(temp_filepath)\n        out = fs(data)\n        self.assertAllClose(out, ref_out)\n\n 
   def test_errors(self):\n        # Test no features\n        with self.assertRaisesRegex(ValueError, \"cannot be None or empty\"):\n            feature_space.FeatureSpace(features={})\n        # Test no crossing dim\n        with self.assertRaisesRegex(ValueError, \"`crossing_dim`\"):\n            feature_space.FeatureSpace(\n                features={\n                    \"f1\": \"integer_categorical\",\n                    \"f2\": \"integer_categorical\",\n                },\n                crosses=[(\"f1\", \"f2\")],\n                crossing_dim=None,\n            )\n        # Test wrong cross feature name\n        with self.assertRaisesRegex(ValueError, \"should be present in \"):\n            feature_space.FeatureSpace(\n                features={\n                    \"f1\": \"integer_categorical\",\n                    \"f2\": \"integer_categorical\",\n                },\n                crosses=[(\"f1\", \"unknown\")],\n                crossing_dim=32,\n            )\n        # Test wrong output mode\n        with self.assertRaisesRegex(ValueError, \"for argument `output_mode`\"):\n            feature_space.FeatureSpace(\n                features={\n                    \"f1\": \"integer_categorical\",\n                    \"f2\": \"integer_categorical\",\n                },\n                output_mode=\"unknown\",\n            )\n        # Test call before adapt\n        with self.assertRaisesRegex(ValueError, \"You need to call `.adapt\"):\n            fs = feature_space.FeatureSpace(\n                features={\n                    \"f1\": \"integer_categorical\",\n                    \"f2\": \"integer_categorical\",\n                }\n            )\n            fs({\"f1\": [0], \"f2\": [0]})\n        # Test get_encoded_features before adapt\n        with self.assertRaisesRegex(ValueError, \"You need to call `.adapt\"):\n            fs = feature_space.FeatureSpace(\n                features={\n                    \"f1\": \"integer_categorical\",\n  
                  \"f2\": \"integer_categorical\",\n                }\n            )\n            fs.get_encoded_features()\n\n    def test_adapt_with_dict(self):\n        fs = feature_space.FeatureSpace(\n            features={\n                \"float_1\": \"float\",\n                \"float_2\": \"float_normalized\",\n                \"float_3\": \"float_discretized\",\n                \"string_1\": \"string_categorical\",\n                \"string_2\": \"string_hashed\",\n                \"int_1\": \"integer_categorical\",\n                \"int_2\": \"integer_hashed\",\n                \"int_3\": \"integer_categorical\",\n            },\n            crosses=[(\"float_3\", \"string_1\"), (\"string_2\", \"int_2\")],\n            output_mode=\"concat\",\n        )\n        # Adapt with dict\n        train_data = self._get_train_data_dict(as_dataset=False)\n        fs.adapt(train_data)\n\n        # Build checks\n        data = {key: value[0] for key, value in train_data.items()}\n        out = fs(data)\n        out_dim = 195\n        self.assertEqual(out.shape, (out_dim,))\n"
  },
  {
    "path": "keras/src/layers/preprocessing/hashed_crossing.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.utils import argument_validation\nfrom keras.src.utils import backend_utils\nfrom keras.src.utils import numerical_utils\nfrom keras.src.utils import tf_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\n@keras_export(\"keras.layers.HashedCrossing\")\nclass HashedCrossing(Layer):\n    \"\"\"A preprocessing layer which crosses features using the \"hashing trick\".\n\n    This layer performs crosses of categorical features using the \"hashing\n    trick\". Conceptually, the transformation can be thought of as:\n    `hash(concatenate(features)) % num_bins`.\n\n    This layer currently only performs crosses of scalar inputs and batches of\n    scalar inputs. Valid input shapes are `(batch_size, 1)`, `(batch_size,)` and\n    `()`.\n\n    **Note:** This layer wraps `tf.keras.layers.HashedCrossing`. It cannot\n    be used as part of the compiled computation graph of a model with\n    any backend other than TensorFlow.\n    It can however be used with any backend when running eagerly.\n    It can also always be used as part of an input preprocessing pipeline\n    with any backend (outside the model itself), which is how we recommend\n    to use this layer.\n\n    **Note:** This layer is safe to use inside a `tf.data` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        num_bins: Number of hash bins.\n        output_mode: Specification for the output of the layer. Values can be\n            `\"int\"`, or `\"one_hot\"` configuring the layer as follows:\n            - `\"int\"`: Return the integer bin indices directly.\n            - `\"one_hot\"`: Encodes each individual element in the input into an\n                array the same size as `num_bins`, containing a 1 at the input's\n                bin index. Defaults to `\"int\"`.\n        sparse: Boolean. 
Only applicable to `\"one_hot\"` mode and only valid\n            when using the TensorFlow backend. If `True`, returns\n            a `SparseTensor` instead of a dense `Tensor`. Defaults to `False`.\n        **kwargs: Keyword arguments to construct a layer.\n\n    Examples:\n\n    **Crossing two scalar features.**\n\n    >>> layer = keras.layers.HashedCrossing(\n    ...     num_bins=5)\n    >>> feat1 = np.array(['A', 'B', 'A', 'B', 'A'])\n    >>> feat2 = np.array([101, 101, 101, 102, 102])\n    >>> layer((feat1, feat2))\n    array([1, 4, 1, 1, 3])\n\n    **Crossing and one-hotting two scalar features.**\n\n    >>> layer = keras.layers.HashedCrossing(\n    ...     num_bins=5, output_mode='one_hot')\n    >>> feat1 = np.array(['A', 'B', 'A', 'B', 'A'])\n    >>> feat2 = np.array([101, 101, 101, 102, 102])\n    >>> layer((feat1, feat2))\n    array([[0., 1., 0., 0., 0.],\n            [0., 0., 0., 0., 1.],\n            [0., 1., 0., 0., 0.],\n            [0., 1., 0., 0., 0.],\n            [0., 0., 0., 1., 0.]], dtype=float32)\n    \"\"\"\n\n    def __init__(\n        self,\n        num_bins,\n        output_mode=\"int\",\n        sparse=False,\n        name=None,\n        dtype=None,\n        **kwargs,\n    ):\n        if not tf.available:\n            raise ImportError(\n                \"Layer HashedCrossing requires TensorFlow. 
\"\n                \"Install it via `pip install tensorflow`.\"\n            )\n\n        if output_mode == \"int\" and dtype is None:\n            dtype = \"int64\"\n\n        super().__init__(name=name, dtype=dtype)\n        if sparse and backend.backend() != \"tensorflow\":\n            raise ValueError(\n                \"`sparse=True` can only be used with the TensorFlow backend.\"\n            )\n\n        argument_validation.validate_string_arg(\n            output_mode,\n            allowable_strings=(\"int\", \"one_hot\"),\n            caller_name=self.__class__.__name__,\n            arg_name=\"output_mode\",\n        )\n\n        self.num_bins = num_bins\n        self.output_mode = output_mode\n        self.sparse = sparse\n        self._allow_non_tensor_positional_args = True\n        self._convert_input_args = False\n        self.supports_jit = False\n\n    def compute_output_shape(self, input_shape):\n        if (\n            not len(input_shape) == 2\n            or not isinstance(input_shape[0], tuple)\n            or not isinstance(input_shape[1], tuple)\n        ):\n            raise ValueError(\n                \"Expected as input a list/tuple of 2 tensors. \"\n                f\"Received input_shape={input_shape}\"\n            )\n        if input_shape[0][-1] != input_shape[1][-1]:\n            raise ValueError(\n                \"Expected the two input tensors to have identical shapes. 
\"\n                f\"Received input_shape={input_shape}\"\n            )\n\n        if not input_shape:\n            if self.output_mode == \"int\":\n                return ()\n            return (self.num_bins,)\n        if self.output_mode == \"int\":\n            return tuple(input_shape[0])\n\n        if self.output_mode == \"one_hot\" and input_shape[0][-1] != 1:\n            return tuple(input_shape[0]) + (self.num_bins,)\n\n        return tuple(input_shape[0])[:-1] + (self.num_bins,)\n\n    def call(self, inputs):\n        from keras.src.backend import tensorflow as tf_backend\n\n        self._check_at_least_two_inputs(inputs)\n        inputs = [tf_utils.ensure_tensor(x) for x in inputs]\n        self._check_input_shape_and_type(inputs)\n\n        # Uprank to rank 2 for the cross_hashed op.\n        first_shape = tuple(inputs[0].shape)\n        rank = len(first_shape)\n        if rank < 2:\n            inputs = [tf_backend.numpy.expand_dims(x, -1) for x in inputs]\n        if rank < 1:\n            inputs = [tf_backend.numpy.expand_dims(x, -1) for x in inputs]\n\n        # Perform the cross and convert to dense\n        outputs = tf.sparse.cross_hashed(inputs, self.num_bins)\n        outputs = tf.sparse.to_dense(outputs)\n\n        # tf.sparse.cross_hashed output shape will always have None dimensions.\n        # Re-apply the known static shape and downrank to match input rank.\n        if rank == 2:\n            outputs.set_shape(first_shape)\n        elif rank == 1:\n            outputs.set_shape(first_shape + (1,))\n            outputs = tf.squeeze(outputs, axis=1)\n        elif rank == 0:\n            outputs = tf.reshape(outputs, [])\n\n        # Encode outputs.\n        outputs = numerical_utils.encode_categorical_inputs(\n            outputs,\n            output_mode=self.output_mode,\n            depth=self.num_bins,\n            sparse=self.sparse,\n            dtype=self.compute_dtype,\n            backend_module=tf_backend,\n        )\n        
return backend_utils.convert_tf_tensor(outputs, dtype=self.dtype)\n\n    def get_config(self):\n        return {\n            \"num_bins\": self.num_bins,\n            \"output_mode\": self.output_mode,\n            \"sparse\": self.sparse,\n            \"name\": self.name,\n            \"dtype\": self.dtype,\n        }\n\n    def _check_at_least_two_inputs(self, inputs):\n        if not isinstance(inputs, (list, tuple)):\n            raise ValueError(\n                \"`HashedCrossing` should be called on a list or tuple of \"\n                f\"inputs. Received: inputs={inputs}\"\n            )\n        if len(inputs) < 2:\n            raise ValueError(\n                \"`HashedCrossing` should be called on at least two inputs. \"\n                f\"Received: inputs={inputs}\"\n            )\n\n    def _check_input_shape_and_type(self, inputs):\n        first_shape = tuple(inputs[0].shape)\n        rank = len(first_shape)\n        if rank > 2 or (rank == 2 and first_shape[-1] != 1):\n            raise ValueError(\n                \"All `HashedCrossing` inputs should have shape `()`, \"\n                \"`(batch_size)` or `(batch_size, 1)`. \"\n                f\"Received: inputs={inputs}\"\n            )\n        if not all(tuple(x.shape) == first_shape for x in inputs[1:]):\n            raise ValueError(\n                \"All `HashedCrossing` inputs should have equal shape. \"\n                f\"Received: inputs={inputs}\"\n            )\n        if any(\n            isinstance(x, (tf.RaggedTensor, tf.SparseTensor)) for x in inputs\n        ):\n            raise ValueError(\n                \"All `HashedCrossing` inputs should be dense tensors. 
\"\n                f\"Received: inputs={inputs}\"\n            )\n        if not all(\n            tf.as_dtype(x.dtype).is_integer or x.dtype == tf.string\n            for x in inputs\n        ):\n            raise ValueError(\n                \"All `HashedCrossing` inputs should have an integer or \"\n                f\"string dtype. Received: inputs={inputs}\"\n            )\n"
  },
  {
    "path": "keras/src/layers/preprocessing/hashed_crossing_test.py",
    "content": "import numpy as np\nimport pytest\nimport tensorflow as tf\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.testing.test_utils import named_product\n\n\nclass HashedCrossingTest(testing.TestCase):\n    def test_basics(self):\n        self.run_layer_test(\n            layers.HashedCrossing,\n            init_kwargs={\n                \"num_bins\": 3,\n                \"output_mode\": \"int\",\n            },\n            input_data=(np.array([1, 2]), np.array([4, 5])),\n            expected_output_shape=(2,),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            run_training_check=False,\n            # Incomplete op support on tensorflow.\n            run_mixed_precision_check=False,\n        )\n        self.run_layer_test(\n            layers.HashedCrossing,\n            init_kwargs={\"num_bins\": 4, \"output_mode\": \"one_hot\"},\n            input_data=(np.array([1, 2]), np.array([4, 5])),\n            expected_output_shape=(2, 4),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            run_training_check=False,\n            # Incomplete op support on tensorflow.\n            run_mixed_precision_check=False,\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            sparse=(\n                [True, False] if backend.backend() == \"tensorflow\" else [False]\n            )\n        )\n    )\n    def test_correctness(self, sparse):\n        layer = layers.HashedCrossing(num_bins=5)\n        feat1 = np.array([\"A\", \"B\", \"A\", \"B\", \"A\"])\n        feat2 = np.array([101, 101, 
101, 102, 102])\n        output = layer((feat1, feat2))\n        self.assertAllClose(tf.constant([1, 4, 1, 1, 3]), output)\n\n        layer = layers.HashedCrossing(\n            num_bins=5, output_mode=\"one_hot\", sparse=sparse\n        )\n        feat1 = np.array([\"A\", \"B\", \"A\", \"B\", \"A\"])\n        feat2 = np.array([101, 101, 101, 102, 102])\n        output = layer((feat1, feat2))\n        self.assertSparse(output, sparse)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.0, 1.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 1.0],\n                    [0.0, 1.0, 0.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 1.0, 0.0],\n                ]\n            ),\n            output,\n        )\n\n    def test_tf_data_compatibility(self):\n        layer = layers.HashedCrossing(num_bins=5)\n        feat1 = np.array([\"A\", \"B\", \"A\", \"B\", \"A\"])\n        feat2 = np.array([101, 101, 101, 102, 102])\n        ds = (\n            tf.data.Dataset.from_tensor_slices((feat1, feat2))\n            .batch(5)\n            .map(lambda x1, x2: layer((x1, x2)))\n        )\n        output = next(iter(ds)).numpy()\n        self.assertAllClose(np.array([1, 4, 1, 1, 3]), output)\n\n    def test_static_shape_preserved(self):\n        layer = layers.HashedCrossing(num_bins=5)\n\n        def call_layer(x1, x2):\n            result = layer((x1, x2))\n            self.assertEqual(result.shape, (5,))\n            return result\n\n        feat1 = np.array([\"A\", \"B\", \"A\", \"B\", \"A\"])\n        feat2 = np.array([101, 101, 101, 102, 102])\n        ds = (\n            tf.data.Dataset.from_tensor_slices((feat1, feat2))\n            .batch(5, drop_remainder=True)\n            .map(call_layer)\n        )\n        next(iter(ds))\n\n    def test_unsupported_shape_input_fails(self):\n        with self.assertRaisesRegex(ValueError, \"inputs should have shape\"):\n            
layers.HashedCrossing(num_bins=10)(\n                (np.array([[[1.0]]]), np.array([[[1.0]]]))\n            )\n\n    @pytest.mark.xfail\n    def test_cross_output_dtype(self):\n        input_1, input_2 = np.array([1]), np.array([1])\n\n        layer = layers.HashedCrossing(num_bins=2)\n        output_dtype = backend.standardize_dtype(\n            layer((input_1, input_2)).dtype\n        )\n        self.assertEqual(output_dtype, \"int64\")\n        layer = layers.HashedCrossing(num_bins=2, dtype=\"int32\")\n        output_dtype = backend.standardize_dtype(\n            layer((input_1, input_2)).dtype\n        )\n        self.assertEqual(output_dtype, \"int32\")\n        layer = layers.HashedCrossing(num_bins=2, output_mode=\"one_hot\")\n        output_dtype = backend.standardize_dtype(\n            layer((input_1, input_2)).dtype\n        )\n        self.assertEqual(output_dtype, \"float32\")\n        layer = layers.HashedCrossing(\n            num_bins=2, output_mode=\"one_hot\", dtype=\"float64\"\n        )\n        output_dtype = backend.standardize_dtype(\n            layer((input_1, input_2)).dtype\n        )\n        self.assertEqual(output_dtype, \"float64\")\n\n    def test_non_list_input_fails(self):\n        with self.assertRaisesRegex(ValueError, \"should be called on a list\"):\n            layers.HashedCrossing(num_bins=10)(np.array(1))\n\n    def test_single_input_fails(self):\n        with self.assertRaisesRegex(ValueError, \"at least two inputs\"):\n            layers.HashedCrossing(num_bins=10)([np.array(1)])\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Need sparse tensor support.\",\n    )\n    def test_sparse_input_fails(self):\n        with self.assertRaisesRegex(\n            ValueError, \"inputs should be dense tensors\"\n        ):\n            sparse_in = tf.sparse.from_dense(np.array([1]))\n            layers.HashedCrossing(num_bins=10)((sparse_in, sparse_in))\n\n    def 
test_float_input_fails(self):\n        with self.assertRaisesRegex(\n            ValueError, \"should have an integer or string\"\n        ):\n            layers.HashedCrossing(num_bins=10)(\n                (np.array([1.0]), np.array([1.0]))\n            )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Need string tensor support.\",\n    )\n    def test_tf_string(self):\n        layer = layers.HashedCrossing(num_bins=10)\n        feat1 = tf.constant(\"A\")\n        feat2 = tf.constant(101)\n        outputs = layer((feat1, feat2))\n        self.assertAllClose(outputs, 1)\n\n        layer = layers.HashedCrossing(num_bins=5, output_mode=\"one_hot\")\n        feat1 = tf.constant([\"A\", \"B\", \"A\", \"B\", \"A\"])\n        feat2 = tf.constant([101, 101, 101, 102, 102])\n        self.assertAllClose(\n            tf.constant(\n                [\n                    [0.0, 1.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 1.0],\n                    [0.0, 1.0, 0.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 1.0, 0.0],\n                ]\n            ),\n            layer((feat1, feat2)),\n        )\n\n        layer = layers.HashedCrossing(num_bins=5)\n        feat1 = tf.constant([\"A\", \"B\", \"A\", \"B\", \"A\"])\n        feat2 = tf.constant([101, 101, 101, 102, 102])\n        self.assertAllClose(tf.constant([1, 4, 1, 1, 3]), layer((feat1, feat2)))\n\n        layer = layers.HashedCrossing(\n            num_bins=5, output_mode=\"one_hot\", sparse=True\n        )\n        cloned_layer = layers.HashedCrossing.from_config(layer.get_config())\n        feat1 = tf.constant([[\"A\"], [\"B\"], [\"A\"], [\"B\"], [\"A\"]])\n        feat2 = tf.constant([[101], [101], [101], [102], [102]])\n        original_outputs = layer((feat1, feat2))\n        cloned_outputs = cloned_layer((feat1, feat2))\n        self.assertAllClose(\n            tf.sparse.to_dense(cloned_outputs),\n   
         tf.sparse.to_dense(original_outputs),\n        )\n"
  },
  {
    "path": "keras/src/layers/preprocessing/hashing.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.utils import backend_utils\nfrom keras.src.utils import numerical_utils\nfrom keras.src.utils import tf_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\n@keras_export(\"keras.layers.Hashing\")\nclass Hashing(Layer):\n    \"\"\"A preprocessing layer which hashes and bins categorical features.\n\n    This layer transforms categorical inputs to hashed output. It element-wise\n    converts a ints or strings to ints in a fixed range. The stable hash\n    function uses `tensorflow::ops::Fingerprint` to produce the same output\n    consistently across all platforms.\n\n    This layer uses [FarmHash64](https://github.com/google/farmhash) by default,\n    which provides a consistent hashed output across different platforms and is\n    stable across invocations, regardless of device and context, by mixing the\n    input bits thoroughly.\n\n    If you want to obfuscate the hashed output, you can also pass a random\n    `salt` argument in the constructor. In that case, the layer will use the\n    [SipHash64](https://github.com/google/highwayhash) hash function, with\n    the `salt` value serving as additional input to the hash function.\n\n    **Note:** This layer internally uses TensorFlow. 
It cannot\n    be used as part of the compiled computation graph of a model with\n    any backend other than TensorFlow.\n    It can however be used with any backend when running eagerly.\n    It can also always be used as part of an input preprocessing pipeline\n    with any backend (outside the model itself), which is how we recommend\n    to use this layer.\n\n    **Note:** This layer is safe to use inside a `tf.data` pipeline\n    (independently of which backend you're using).\n\n    **Example (FarmHash64)**\n\n    >>> layer = keras.layers.Hashing(num_bins=3)\n    >>> inp = [['A'], ['B'], ['C'], ['D'], ['E']]\n    >>> layer(inp)\n    array([[1],\n            [0],\n            [1],\n            [1],\n            [2]])\n\n    **Example (FarmHash64) with a mask value**\n\n    >>> layer = keras.layers.Hashing(num_bins=3, mask_value='')\n    >>> inp = [['A'], ['B'], [''], ['C'], ['D']]\n    >>> layer(inp)\n    array([[1],\n            [1],\n            [0],\n            [2],\n            [2]])\n\n    **Example (SipHash64)**\n\n    >>> layer = keras.layers.Hashing(num_bins=3, salt=[133, 137])\n    >>> inp = [['A'], ['B'], ['C'], ['D'], ['E']]\n    >>> layer(inp)\n    array([[1],\n            [2],\n            [1],\n            [0],\n            [2]])\n\n    **Example (SipHash64 with a single integer, same as `salt=[133, 133]`)**\n\n    >>> layer = keras.layers.Hashing(num_bins=3, salt=133)\n    >>> inp = [['A'], ['B'], ['C'], ['D'], ['E']]\n    >>> layer(inp)\n    array([[0],\n            [0],\n            [2],\n            [1],\n            [0]])\n\n    Args:\n        num_bins: Number of hash bins. Note that this includes the `mask_value`\n            bin, so the effective number of bins is `(num_bins - 1)`\n            if `mask_value` is set.\n        mask_value: A value that represents masked inputs, which are mapped to\n            index 0. `None` means no mask term will be added and the\n            hashing will start at index 0. 
Defaults to `None`.\n        salt: A single unsigned integer or None.\n            If passed, the hash function used will be SipHash64,\n            with these values used as an additional input\n            (known as a \"salt\" in cryptography).\n            These should be non-zero. If `None`, uses the FarmHash64 hash\n            function. It also supports a tuple/list of 2 unsigned\n            integer numbers; see the reference paper for details.\n            Defaults to `None`.\n        output_mode: Specification for the output of the layer. Values can be\n            `\"int\"`, `\"one_hot\"`, `\"multi_hot\"`, or\n            `\"count\"` configuring the layer as follows:\n            - `\"int\"`: Return the integer bin indices directly.\n            - `\"one_hot\"`: Encodes each individual element in the input into an\n                array the same size as `num_bins`, containing a 1\n                at the input's bin index. If the last dimension is size 1,\n                will encode on that dimension.\n                If the last dimension is not size 1, will append a new\n                dimension for the encoded output.\n            - `\"multi_hot\"`: Encodes each sample in the input into a\n                single array the same size as `num_bins`,\n                containing a 1 for each bin\n                index present in the sample. Treats the last dimension\n                as the sample dimension, if input shape is\n                `(..., sample_length)`, output shape will be\n                `(..., num_tokens)`.\n            - `\"count\"`: As `\"multi_hot\"`, but the int array contains a count of\n                the number of times the bin index appeared in the sample.\n            Defaults to `\"int\"`.\n        sparse: Boolean. Only applicable to `\"one_hot\"`, `\"multi_hot\"`,\n            and `\"count\"` output modes. Only supported with TensorFlow\n            backend. 
If `True`, returns a `SparseTensor` instead of\n            a dense `Tensor`. Defaults to `False`.\n        **kwargs: Keyword arguments to construct a layer.\n\n    Input shape:\n        A single string, a list of strings, or an `int32` or `int64` tensor\n        of shape `(batch_size, ...,)`.\n\n    Output shape:\n        An `int64` tensor of shape `(batch_size, ...)`.\n\n    Reference:\n\n    - [SipHash with salt](https://www.131002.net/siphash/siphash.pdf)\n    \"\"\"\n\n    def __init__(\n        self,\n        num_bins,\n        mask_value=None,\n        salt=None,\n        output_mode=\"int\",\n        sparse=False,\n        **kwargs,\n    ):\n        if not tf.available:\n            raise ImportError(\n                \"Layer Hashing requires TensorFlow. \"\n                \"Install it via `pip install tensorflow`.\"\n            )\n\n        # By default, output int64 when output_mode='int' and floats otherwise.\n        if \"dtype\" not in kwargs or kwargs[\"dtype\"] is None:\n            kwargs[\"dtype\"] = (\n                \"int64\" if output_mode == \"int\" else backend.floatx()\n            )\n\n        super().__init__(**kwargs)\n\n        if num_bins is None or num_bins <= 0:\n            raise ValueError(\n                \"The `num_bins` for `Hashing` cannot be `None` or \"\n                f\"non-positive values. Received: num_bins={num_bins}.\"\n            )\n\n        if output_mode == \"int\" and (\n            self.dtype_policy.name not in (\"int32\", \"int64\")\n        ):\n            raise ValueError(\n                'When `output_mode=\"int\"`, `dtype` should be an integer '\n                f\"type, 'int32' or 'int64'. 
Received: dtype={kwargs['dtype']}\"\n            )\n\n        # 'output_mode' must be one of (INT, ONE_HOT, MULTI_HOT, COUNT)\n        accepted_output_modes = (\"int\", \"one_hot\", \"multi_hot\", \"count\")\n        if output_mode not in accepted_output_modes:\n            raise ValueError(\n                \"Invalid value for argument `output_mode`. \"\n                f\"Expected one of {accepted_output_modes}. \"\n                f\"Received: output_mode={output_mode}\"\n            )\n\n        if sparse and output_mode == \"int\":\n            raise ValueError(\n                \"`sparse` may only be true if `output_mode` is \"\n                '`\"one_hot\"`, `\"multi_hot\"`, or `\"count\"`. '\n                f\"Received: sparse={sparse} and \"\n                f\"output_mode={output_mode}\"\n            )\n\n        self.num_bins = num_bins\n        self.mask_value = mask_value\n        self.strong_hash = True if salt is not None else False\n        self.output_mode = output_mode\n        self.sparse = sparse\n        self.salt = None\n        if salt is not None:\n            if isinstance(salt, (tuple, list)) and len(salt) == 2:\n                self.salt = list(salt)\n            elif isinstance(salt, int):\n                self.salt = [salt, salt]\n            else:\n                raise ValueError(\n                    \"The `salt` argument for `Hashing` can only be a tuple of \"\n                    \"size 2 integers, or a single integer. 
\"\n                    f\"Received: salt={salt}.\"\n                )\n        self._convert_input_args = False\n        self._allow_non_tensor_positional_args = True\n        self.supports_jit = False\n\n    def call(self, inputs):\n        from keras.src.backend import tensorflow as tf_backend\n\n        inputs = tf_utils.ensure_tensor(inputs)\n        if self.output_mode == \"one_hot\" and inputs.shape[-1] == 1:\n            # One hot only upranks if the final dimension is not 1.\n            inputs = tf_backend.numpy.squeeze(inputs, axis=-1)\n        if isinstance(inputs, tf.SparseTensor):\n            indices = tf.SparseTensor(\n                indices=inputs.indices,\n                values=self._hash_values_to_bins(inputs.values),\n                dense_shape=inputs.dense_shape,\n            )\n        else:\n            indices = self._hash_values_to_bins(inputs)\n        outputs = numerical_utils.encode_categorical_inputs(\n            indices,\n            output_mode=self.output_mode,\n            depth=self.num_bins,\n            sparse=self.sparse,\n            dtype=self.dtype,\n            backend_module=tf_backend,\n        )\n        return backend_utils.convert_tf_tensor(outputs)\n\n    def _hash_values_to_bins(self, values):\n        \"\"\"Converts a non-sparse tensor of values to bin indices.\"\"\"\n        hash_bins = self.num_bins\n        mask = None\n        # If mask_value is set, the zeroth bin is reserved for it.\n        if self.mask_value is not None and hash_bins > 1:\n            hash_bins -= 1\n            mask = tf.equal(values, self.mask_value)\n        # Convert all values to strings before hashing.\n        # Floats are first normalized to int64.\n        if values.dtype.is_floating:\n            values = tf.cast(values, dtype=\"int64\")\n        if values.dtype != tf.string:\n            values = tf.as_string(values)\n        # Hash the strings.\n        if self.strong_hash:\n            values = 
tf.strings.to_hash_bucket_strong(\n                values, hash_bins, name=\"hash\", key=self.salt\n            )\n        else:\n            values = tf.strings.to_hash_bucket_fast(\n                values, hash_bins, name=\"hash\"\n            )\n        if mask is not None:\n            values = tf.add(values, tf.ones_like(values))\n            values = tf.where(mask, tf.zeros_like(values), values)\n        return values\n\n    def compute_output_spec(self, inputs):\n        if self.output_mode == \"int\":\n            return backend.KerasTensor(shape=inputs.shape, dtype=self.dtype)\n        if len(inputs.shape) >= 1:\n            base_shape = tuple(inputs.shape)[:-1]\n        else:\n            base_shape = ()\n        return backend.KerasTensor(\n            shape=base_shape + (self.num_bins,), dtype=self.dtype\n        )\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"num_bins\": self.num_bins,\n                \"salt\": self.salt,\n                \"mask_value\": self.mask_value,\n                \"output_mode\": self.output_mode,\n                \"sparse\": self.sparse,\n            }\n        )\n        return config\n"
  },
  {
    "path": "keras/src/layers/preprocessing/hashing_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nimport tensorflow as tf\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src.saving import load_model\n\n\nclass ArrayLike:\n    def __init__(self, values):\n        self.values = values\n\n    def __array__(self):\n        return np.array(self.values)\n\n\n@pytest.mark.skipif(\n    backend.backend() == \"numpy\", reason=\"Broken with NumPy backend.\"\n)\nclass HashingTest(testing.TestCase):\n    def test_config(self):\n        layer = layers.Hashing(\n            num_bins=8,\n            output_mode=\"int\",\n        )\n        self.run_class_serialization_test(layer)\n\n    def test_correctness(self):\n        layer = layers.Hashing(num_bins=3)\n        inp = [[\"A\"], [\"B\"], [\"C\"], [\"D\"], [\"E\"]]\n        output = layer(inp)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([[1], [0], [1], [1], [2]]))\n\n        layer = layers.Hashing(num_bins=3, mask_value=\"\")\n        inp = [[\"A\"], [\"B\"], [\"\"], [\"C\"], [\"D\"]]\n        output = layer(inp)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([[1], [1], [0], [2], [2]]))\n\n        layer = layers.Hashing(num_bins=3, salt=[133, 137])\n        inp = [[\"A\"], [\"B\"], [\"C\"], [\"D\"], [\"E\"]]\n        output = layer(inp)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([[1], [2], [1], [0], [2]]))\n\n        layer = layers.Hashing(num_bins=3, salt=133)\n        inp = [[\"A\"], [\"B\"], [\"C\"], [\"D\"], [\"E\"]]\n        output = layer(inp)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([[0], [0], [2], [1], [0]]))\n\n    def test_tf_data_compatibility(self):\n        layer = layers.Hashing(num_bins=3)\n        inp = 
[[\"A\"], [\"B\"], [\"C\"], [\"D\"], [\"E\"]]\n        ds = tf.data.Dataset.from_tensor_slices(inp).batch(5).map(layer)\n        output = next(iter(ds)).numpy()\n        self.assertAllClose(output, np.array([[1], [0], [1], [1], [2]]))\n\n    @parameterized.named_parameters(\n        (\"list\", list),\n        (\"tuple\", tuple),\n        (\"numpy\", np.array),\n        (\"array_like\", ArrayLike),\n    )\n    def test_tensor_like_inputs(self, data_fn):\n        input_data = data_fn([0, 1, 2, 3, 4])\n        expected_output = [1, 0, 1, 0, 2]\n\n        layer = layers.Hashing(num_bins=3)\n        output_data = layer(input_data)\n        self.assertAllEqual(output_data, expected_output)\n\n    def test_hash_single_bin(self):\n        layer = layers.Hashing(num_bins=1)\n        inp = np.asarray([[\"A\"], [\"B\"], [\"C\"], [\"D\"], [\"E\"]])\n        output = layer(inp)\n        self.assertAllClose([[0], [0], [0], [0], [0]], output)\n\n    def test_hash_dense_input_farmhash(self):\n        layer = layers.Hashing(num_bins=2)\n        inp = np.asarray(\n            [[\"omar\"], [\"stringer\"], [\"marlo\"], [\"wire\"], [\"skywalker\"]]\n        )\n        output = layer(inp)\n        # Assert equal for hashed output that should be true on all platforms.\n        self.assertAllClose([[0], [0], [1], [0], [0]], output)\n\n    def test_hash_dense_input_mask_value_farmhash(self):\n        empty_mask_layer = layers.Hashing(num_bins=3, mask_value=\"\")\n        omar_mask_layer = layers.Hashing(num_bins=3, mask_value=\"omar\")\n        inp = np.asarray(\n            [[\"omar\"], [\"stringer\"], [\"marlo\"], [\"wire\"], [\"skywalker\"]]\n        )\n        empty_mask_output = empty_mask_layer(inp)\n        omar_mask_output = omar_mask_layer(inp)\n        # Outputs should be one more than test_hash_dense_input_farmhash (the\n        # zeroth bin is now reserved for masks).\n        self.assertAllClose([[1], [1], [2], [1], [1]], empty_mask_output)\n        # 'omar' should map to 0.\n 
       self.assertAllClose([[0], [1], [2], [1], [1]], omar_mask_output)\n\n    def test_hash_dense_list_input_farmhash(self):\n        layer = layers.Hashing(num_bins=2)\n        inp = [[\"omar\"], [\"stringer\"], [\"marlo\"], [\"wire\"], [\"skywalker\"]]\n        output = layer(inp)\n        # Assert equal for hashed output that should be true on all platforms.\n        self.assertAllClose([[0], [0], [1], [0], [0]], output)\n\n        inp = [\"omar\", \"stringer\", \"marlo\", \"wire\", \"skywalker\"]\n        output = layer(inp)\n        # Assert equal for hashed output that should be true on all platforms.\n        self.assertAllClose([0, 0, 1, 0, 0], output)\n\n    def test_hash_dense_int_input_farmhash(self):\n        layer = layers.Hashing(num_bins=3)\n        inp = np.asarray([[0], [1], [2], [3], [4]])\n        output = layer(inp)\n        # Assert equal for hashed output that should be true on all platforms.\n        self.assertAllClose([[1], [0], [1], [0], [2]], output)\n\n    def test_hash_dense_input_siphash(self):\n        layer = layers.Hashing(num_bins=2, salt=[133, 137])\n        inp = np.asarray(\n            [[\"omar\"], [\"stringer\"], [\"marlo\"], [\"wire\"], [\"skywalker\"]]\n        )\n        output = layer(inp)\n        # Assert equal for hashed output that should be true on all platforms.\n        # Note the result is different from FarmHash.\n        self.assertAllClose([[0], [1], [0], [1], [0]], output)\n\n        layer_2 = layers.Hashing(num_bins=2, salt=[211, 137])\n        output_2 = layer_2(inp)\n        # Note the result is different from (133, 137).\n        self.assertAllClose([[1], [0], [1], [0], [1]], output_2)\n\n    def test_hash_dense_int_input_siphash(self):\n        layer = layers.Hashing(num_bins=3, salt=[133, 137])\n        inp = np.asarray([[0], [1], [2], [3], [4]])\n        output = layer(inp)\n        # Assert equal for hashed output that should be true on all platforms.\n        self.assertAllClose([[1], [1], [2], [0], 
[1]], output)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Uses tf.SparseTensor.\"\n    )\n    def test_hash_sparse_input_farmhash(self):\n        layer = layers.Hashing(num_bins=2)\n        indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]\n        inp = tf.SparseTensor(\n            indices=indices,\n            values=[\"omar\", \"stringer\", \"marlo\", \"wire\", \"skywalker\"],\n            dense_shape=[3, 2],\n        )\n        output = layer(inp)\n        self.assertAllClose(indices, output.indices)\n        self.assertAllClose([0, 0, 1, 0, 0], output.values)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Uses tf.SparseTensor.\"\n    )\n    def test_hash_sparse_input_mask_value_farmhash(self):\n        empty_mask_layer = layers.Hashing(num_bins=3, mask_value=\"\")\n        omar_mask_layer = layers.Hashing(num_bins=3, mask_value=\"omar\")\n        indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]\n        inp = tf.SparseTensor(\n            indices=indices,\n            values=[\"omar\", \"stringer\", \"marlo\", \"wire\", \"skywalker\"],\n            dense_shape=[3, 2],\n        )\n        empty_mask_output = empty_mask_layer(inp)\n        omar_mask_output = omar_mask_layer(inp)\n        self.assertAllClose(indices, omar_mask_output.indices)\n        self.assertAllClose(indices, empty_mask_output.indices)\n        # Outputs should be one more than test_hash_sparse_input_farmhash (the\n        # zeroth bin is now reserved for masks).\n        self.assertAllClose([1, 1, 2, 1, 1], empty_mask_output.values)\n        # 'omar' should map to 0.\n        self.assertAllClose([0, 1, 2, 1, 1], omar_mask_output.values)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Uses tf.SparseTensor.\"\n    )\n    def test_hash_sparse_int_input_farmhash(self):\n        layer = layers.Hashing(num_bins=3)\n        indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]\n        inp 
= tf.SparseTensor(\n            indices=indices, values=[0, 1, 2, 3, 4], dense_shape=[3, 2]\n        )\n        output = layer(inp)\n        self.assertAllClose(indices, output.indices)\n        self.assertAllClose([1, 0, 1, 0, 2], output.values)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Uses tf.SparseTensor.\"\n    )\n    def test_hash_sparse_input_siphash(self):\n        layer = layers.Hashing(num_bins=2, salt=[133, 137])\n        indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]\n        inp = tf.SparseTensor(\n            indices=indices,\n            values=[\"omar\", \"stringer\", \"marlo\", \"wire\", \"skywalker\"],\n            dense_shape=[3, 2],\n        )\n        output = layer(inp)\n        self.assertAllClose(output.indices, indices)\n        # The result should be the same as test_hash_dense_input_siphash.\n        self.assertAllClose([0, 1, 0, 1, 0], output.values)\n\n        layer_2 = layers.Hashing(num_bins=2, salt=[211, 137])\n        output = layer_2(inp)\n        # The result should be the same as test_hash_dense_input_siphash.\n        self.assertAllClose([1, 0, 1, 0, 1], output.values)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Uses tf.SparseTensor.\"\n    )\n    def test_hash_sparse_int_input_siphash(self):\n        layer = layers.Hashing(num_bins=3, salt=[133, 137])\n        indices = [[0, 0], [1, 0], [1, 1], [2, 0], [2, 1]]\n        inp = tf.SparseTensor(\n            indices=indices, values=[0, 1, 2, 3, 4], dense_shape=[3, 2]\n        )\n        output = layer(inp)\n        self.assertAllClose(indices, output.indices)\n        self.assertAllClose([1, 1, 2, 0, 1], output.values)\n\n    def test_invalid_inputs(self):\n        with self.assertRaisesRegex(ValueError, \"cannot be `None`\"):\n            _ = layers.Hashing(num_bins=None)\n        with self.assertRaisesRegex(ValueError, \"cannot be `None`\"):\n            _ = layers.Hashing(num_bins=-1)\n        with 
self.assertRaisesRegex(\n            ValueError, \"can only be a tuple of size 2\"\n        ):\n            _ = layers.Hashing(num_bins=2, salt=\"string\")\n        with self.assertRaisesRegex(\n            ValueError, \"can only be a tuple of size 2\"\n        ):\n            _ = layers.Hashing(num_bins=2, salt=[1])\n        with self.assertRaisesRegex(\n            ValueError, \"can only be a tuple of size 2\"\n        ):\n            _ = layers.Hashing(num_bins=1, salt=[133, 137, 177])\n\n    def test_one_hot_output(self):\n        input_array = np.array([0, 1, 2, 3, 4])\n\n        expected_output = [\n            [0.0, 1.0, 0.0],\n            [1.0, 0.0, 0.0],\n            [0.0, 1.0, 0.0],\n            [1.0, 0.0, 0.0],\n            [0.0, 0.0, 1.0],\n        ]\n        expected_output_shape = [None, 3]\n\n        inputs = layers.Input(shape=(1,), dtype=\"int32\")\n        layer = layers.Hashing(num_bins=3, output_mode=\"one_hot\")\n        outputs = layer(inputs)\n        self.assertAllEqual(expected_output_shape, outputs.shape)\n\n        model = models.Model(inputs, outputs)\n        output_data = model(input_array)\n        self.assertAllClose(expected_output, output_data)\n\n    def test_multi_hot_output(self):\n        input_array = np.array([[0, 1, 2, 3, 4]])\n\n        expected_output = [[1.0, 1.0, 1.0]]\n        expected_output_shape = [None, 3]\n\n        inputs = layers.Input(shape=(None,), dtype=\"int32\")\n        layer = layers.Hashing(num_bins=3, output_mode=\"multi_hot\")\n        outputs = layer(inputs)\n        self.assertAllEqual(expected_output_shape, outputs.shape)\n\n        model = models.Model(inputs, outputs)\n        output_data = model(input_array)\n        self.assertAllClose(expected_output, output_data)\n\n    @parameterized.named_parameters(\n        (\n            \"1d_input\",\n            [0, 1, 2, 3, 4],\n            [2.0, 2.0, 1.0],\n            [3],\n        ),\n        (\n            \"2d_input\",\n            [[0, 1, 2, 3, 
4]],\n            [[2.0, 2.0, 1.0]],\n            [None, 3],\n        ),\n    )\n    def test_count_output(self, input_value, expected_output, output_shape):\n        input_array = np.array(input_value)\n        if input_array.ndim == 1:\n            symbolic_sample_shape = ()\n        elif input_array.ndim == 2:\n            symbolic_sample_shape = (None,)\n        else:\n            raise TypeError(\"Unknown `symbolic_sample_shape`\")\n        inputs = layers.Input(shape=symbolic_sample_shape, dtype=\"int32\")\n        layer = layers.Hashing(num_bins=3, output_mode=\"count\")\n        outputs = layer(inputs)\n        self.assertAllEqual(output_shape, outputs.shape)\n        output_data = layer(input_array)\n        self.assertAllEqual(expected_output, output_data)\n\n    @parameterized.named_parameters(\n        (\"int32\", \"int32\"),\n        (\"int64\", \"int64\"),\n    )\n    def test_int_output_dtype(self, dtype):\n        input_data = layers.Input(batch_size=16, shape=(4,), dtype=\"string\")\n        layer = layers.Hashing(num_bins=3, output_mode=\"int\", dtype=dtype)\n        output = layer(input_data)\n        self.assertEqual(output.dtype, dtype)\n\n    @parameterized.named_parameters(\n        (\"float32\", \"float32\"),\n        (\"float64\", \"float64\"),\n    )\n    def test_one_hot_output_dtype(self, dtype):\n        input_data = layers.Input(batch_size=16, shape=(1,), dtype=\"string\")\n        layer = layers.Hashing(num_bins=3, output_mode=\"one_hot\", dtype=dtype)\n        output = layer(input_data)\n        self.assertEqual(output.dtype, dtype)\n\n    def test_config_with_custom_name(self):\n        layer = layers.Hashing(num_bins=2, name=\"hashing\")\n        config = layer.get_config()\n        layer_1 = layers.Hashing.from_config(config)\n        self.assertEqual(layer_1.name, layer.name)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Uses string dtype.\"\n    )\n    def test_saving(self):\n        
input_data = np.array(\n            [\"omar\", \"stringer\", \"marlo\", \"wire\", \"skywalker\"]\n        )\n        inputs = layers.Input(shape=(), dtype=\"string\")\n        outputs = layers.Hashing(num_bins=100)(inputs)\n        model = models.Model(inputs=inputs, outputs=outputs)\n\n        original_output_data = model(input_data)\n\n        # Save the model to disk.\n        output_path = os.path.join(self.get_temp_dir(), \"keras_model.keras\")\n        model.save(output_path)\n        loaded_model = load_model(output_path)\n\n        # Ensure that the loaded model is unique (so that the save/load is real)\n        self.assertIsNot(model, loaded_model)\n\n        # Validate correctness of the new model.\n        new_output_data = loaded_model(input_data)\n        self.assertAllClose(new_output_data, original_output_data)\n\n    @parameterized.named_parameters(\n        (\n            \"list_input\",\n            [1, 2, 3],\n            [1, 1, 1],\n        ),\n        (\n            \"list_input_2d\",\n            [[1], [2], [3]],\n            [[1], [1], [1]],\n        ),\n        (\n            \"list_input_2d_multiple\",\n            [[1, 2], [2, 3], [3, 4]],\n            [[1, 1], [1, 1], [1, 1]],\n        ),\n        (\n            \"list_input_3d\",\n            [[[1], [2]], [[2], [3]], [[3], [4]]],\n            [[[1], [1]], [[1], [1]], [[1], [1]]],\n        ),\n    )\n    def test_hash_list_input(self, input_data, expected):\n        layer = layers.Hashing(num_bins=2)\n        out_data = layer(input_data)\n        self.assertAllEqual(\n            expected, backend.convert_to_numpy(out_data).tolist()\n        )\n\n    def test_hashing_invalid_num_bins(self):\n        # Test with `num_bins` set to None\n        with self.assertRaisesRegex(\n            ValueError,\n            \"The `num_bins` for `Hashing` cannot be `None` or non-positive\",\n        ):\n            layers.Hashing(num_bins=None)\n\n        # Test with `num_bins` set to 0\n        with 
self.assertRaisesRegex(\n            ValueError,\n            \"The `num_bins` for `Hashing` cannot be `None` or non-positive\",\n        ):\n            layers.Hashing(num_bins=0)\n\n    def test_hashing_invalid_output_mode(self):\n        # Test with an unsupported `output_mode`\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value for argument `output_mode`. Expected one of\",\n        ):\n            layers.Hashing(num_bins=3, output_mode=\"unsupported_mode\")\n\n    def test_hashing_invalid_dtype_for_int_mode(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            'When `output_mode=\"int\"`, `dtype` should be an integer type,',\n        ):\n            layers.Hashing(num_bins=3, output_mode=\"int\", dtype=\"float32\")\n\n    def test_hashing_sparse_with_int_mode(self):\n        # Test setting `sparse=True` with `output_mode='int'`\n        with self.assertRaisesRegex(\n            ValueError, \"`sparse` may only be true if `output_mode` is\"\n        ):\n            layers.Hashing(num_bins=3, output_mode=\"int\", sparse=True)\n\n\n# TODO: support tf.RaggedTensor.\n# def test_hash_ragged_string_input_farmhash(self):\n#     layer = layers.Hashing(num_bins=2)\n#     inp_data = tf.ragged.constant(\n#         [\n#             [\"omar\", \"stringer\", \"marlo\", \"wire\"],\n#             [\"marlo\", \"skywalker\", \"wire\"],\n#         ],\n#         dtype=\"string\",\n#     )\n#     out_data = layer(inp_data)\n#     # Same hashed output as test_hash_sparse_input_farmhash\n#     expected_output = [[0, 0, 1, 0], [1, 0, 0]]\n#     self.assertAllEqual(expected_output, out_data)\n\n#     inp_t = layers.Input(shape=(None,), ragged=True, dtype=\"string\")\n#     out_t = layer(inp_t)\n#     model = models.Model(inputs=inp_t, outputs=out_t)\n#     self.assertAllClose(out_data, model.predict(inp_data))\n\n# TODO: support tf.RaggedTensor.\n# def test_hash_ragged_input_mask_value(self):\n#     empty_mask_layer = 
layers.Hashing(num_bins=3, mask_value=\"\")\n#     omar_mask_layer = layers.Hashing(num_bins=3, mask_value=\"omar\")\n#     inp_data = tf.ragged.constant(\n#         [\n#             [\"omar\", \"stringer\", \"marlo\", \"wire\"],\n#             [\"marlo\", \"skywalker\", \"wire\"],\n#         ],\n#         dtype=\"string\",\n#     )\n#     empty_mask_output = empty_mask_layer(inp_data)\n#     omar_mask_output = omar_mask_layer(inp_data)\n#     # Outputs should be one more than test_hash_ragged_string_input_farmhash\n#     # (the zeroth bin is now reserved for masks).\n#     expected_output = [[1, 1, 2, 1], [2, 1, 1]]\n#     self.assertAllClose(expected_output[0], empty_mask_output[1])\n#     self.assertAllClose(expected_output[1], empty_mask_output[2])\n#     # 'omar' should map to 0.\n#     expected_output = [[0, 1, 2, 1], [2, 1, 1]]\n#     self.assertAllClose(expected_output[0], omar_mask_output[0])\n#     self.assertAllClose(expected_output[1], omar_mask_output[1])\n\n# TODO: support tf.RaggedTensor.\n# def test_hash_ragged_int_input_farmhash(self):\n#     layer = layers.Hashing(num_bins=3)\n#     inp_data = tf.ragged.constant([[0, 1, 3, 4], [2, 1, 0]], dtype=\"int64\")\n#     out_data = layer(inp_data)\n#     # Same hashed output as test_hash_sparse_input_farmhash\n#     expected_output = [[1, 0, 0, 2], [1, 0, 1]]\n#     self.assertAllEqual(expected_output[0], out_data[0])\n#     self.assertAllEqual(expected_output[1], out_data[1])\n#     inp_t = layers.Input(shape=(None,), ragged=True, dtype=\"int64\")\n#     out_t = layer(inp_t)\n#     model = models.Model(inputs=inp_t, outputs=out_t)\n#     self.assertAllClose(out_data, model.predict(inp_data))\n\n# TODO: support tf.RaggedTensor.\n# def test_hash_ragged_string_input_siphash(self):\n#     layer = layers.Hashing(num_bins=2, salt=[133, 137])\n#     inp_data = tf.ragged.constant(\n#         [\n#             [\"omar\", \"stringer\", \"marlo\", \"wire\"],\n#             [\"marlo\", \"skywalker\", \"wire\"],\n#     
    ],\n#         dtype=\"string\",\n#     )\n#     out_data = layer(inp_data)\n#     # Same hashed output as test_hash_dense_input_siphash\n#     expected_output = [[0, 1, 0, 1], [0, 0, 1]]\n#     self.assertAllEqual(expected_output, out_data)\n\n#     inp_t = layers.Input(shape=(None,), ragged=True, dtype=\"string\")\n#     out_t = layer(inp_t)\n#     model = models.Model(inputs=inp_t, outputs=out_t)\n#     self.assertAllClose(out_data, model.predict(inp_data))\n\n#     layer_2 = layers.Hashing(num_bins=2, salt=[211, 137])\n#     out_data = layer_2(inp_data)\n#     expected_output = [[1, 0, 1, 0], [1, 1, 0]]\n#     self.assertAllEqual(expected_output, out_data)\n\n#     out_t = layer_2(inp_t)\n#     model = models.Model(inputs=inp_t, outputs=out_t)\n#     self.assertAllClose(out_data, model.predict(inp_data))\n\n# TODO: support tf.RaggedTensor.\n# def test_hash_ragged_int_input_siphash(self):\n#     layer = layers.Hashing(num_bins=3, salt=[133, 137])\n#     inp_data = tf.ragged.constant([[0, 1, 3, 4], [2, 1, 0]], dtype=\"int64\")\n#     out_data = layer(inp_data)\n#     # Same hashed output as test_hash_sparse_input_farmhash\n#     expected_output = [[1, 1, 0, 1], [2, 1, 1]]\n#     self.assertAllEqual(expected_output, out_data)\n\n#     inp_t = layers.Input(shape=(None,), ragged=True, dtype=\"int64\")\n#     out_t = layer(inp_t)\n#     model = models.Model(inputs=inp_t, outputs=out_t)\n#     self.assertAllClose(out_data, model.predict(inp_data))\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/aug_mix.py",
    "content": "import random as py_random\n\nimport keras.src.layers as layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_transform_example,\n)\nfrom keras.src.random import SeedGenerator\nfrom keras.src.utils import backend_utils\n\nAUGMENT_LAYERS_ALL = [\n    \"random_shear\",\n    \"random_translation\",\n    \"random_rotation\",\n    \"random_posterization\",\n    \"solarization\",\n    \"auto_contrast\",\n    \"equalization\",\n    \"random_brightness\",\n    \"random_color_degeneration\",\n    \"random_contrast\",\n    \"random_sharpness\",\n]\n\nAUGMENT_LAYERS = [\n    \"random_shear\",\n    \"random_translation\",\n    \"random_rotation\",\n    \"random_posterization\",\n    \"solarization\",\n    \"auto_contrast\",\n    \"equalization\",\n]\n\n\n@keras_export(\"keras.layers.AugMix\")\nclass AugMix(BaseImagePreprocessingLayer):\n    \"\"\"Performs the AugMix data augmentation technique.\n\n    AugMix aims to produce images with variety while preserving the image\n    semantics and local statistics. 
During the augmentation process,\n    the same augmentation is applied across all images in the batch\n    in `num_chains` different ways, with each chain consisting of\n    `chain_depth` augmentations.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    References:\n        - [AugMix paper](https://arxiv.org/pdf/1912.02781)\n        - [Official Code](https://github.com/google-research/augmix)\n\n    Args:\n        value_range: the range of values the incoming images will have.\n            Represented as a two-number tuple written `(low, high)`.\n            This is typically either `(0, 1)` or `(0, 255)` depending\n            on how your preprocessing pipeline is set up.\n        num_chains: an integer representing the number of different chains to\n            be mixed, defaults to 3.\n        chain_depth: an integer representing the maximum number of\n            transformations to be applied in each chain. The actual number\n            of transformations in each chain will be sampled randomly\n            from the range `[0, chain_depth]`. Defaults to 3.\n        factor: The strength of the augmentation as a normalized value\n            between 0 and 1. Default is 0.3.\n        alpha: a float value used as the concentration parameter for the\n            Beta and Dirichlet distributions, defaults to 1.0.\n        all_ops: Use all operations (including random_brightness,\n            random_color_degeneration, random_contrast and random_sharpness).\n            Default is True.\n        interpolation: The interpolation method to use for resizing operations.\n            Options include `\"nearest\"`, `\"bilinear\"`. Default is `\"bilinear\"`.\n        seed: Integer. 
Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_transform_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_BOUNDS = (0, 1)\n\n    def __init__(\n        self,\n        value_range=(0, 255),\n        num_chains=3,\n        chain_depth=3,\n        factor=0.3,\n        alpha=1.0,\n        all_ops=True,\n        interpolation=\"bilinear\",\n        seed=None,\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n\n        self.value_range = value_range\n        self.num_chains = num_chains\n        self.chain_depth = chain_depth\n        self._set_factor(factor)\n        self.alpha = alpha\n        self.all_ops = all_ops\n        self.interpolation = interpolation\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n        if self.all_ops:\n            self._augment_layers = AUGMENT_LAYERS_ALL\n        else:\n            self._augment_layers = AUGMENT_LAYERS\n\n        self.random_shear = layers.RandomShear(\n            x_factor=self.factor,\n            y_factor=self.factor,\n            interpolation=interpolation,\n            seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.random_translation = layers.RandomTranslation(\n            height_factor=self.factor,\n            width_factor=self.factor,\n            interpolation=interpolation,\n            seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.random_rotation = layers.RandomRotation(\n            factor=self.factor,\n            interpolation=interpolation,\n            seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.solarization = layers.Solarization(\n            addition_factor=self.factor,\n            threshold_factor=self.factor,\n            value_range=self.value_range,\n            seed=self.seed,\n      
      data_format=data_format,\n            **kwargs,\n        )\n\n        self.random_posterization = layers.RandomPosterization(\n            factor=max(1, int(8 * self.factor[1])),\n            value_range=self.value_range,\n            seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.auto_contrast = layers.AutoContrast(\n            value_range=self.value_range, data_format=data_format, **kwargs\n        )\n\n        self.equalization = layers.Equalization(\n            value_range=self.value_range, data_format=data_format, **kwargs\n        )\n\n        if self.all_ops:\n            self.random_brightness = layers.RandomBrightness(\n                factor=self.factor,\n                value_range=self.value_range,\n                seed=self.seed,\n                data_format=data_format,\n                **kwargs,\n            )\n\n            self.random_color_degeneration = layers.RandomColorDegeneration(\n                factor=self.factor,\n                value_range=self.value_range,\n                seed=self.seed,\n                data_format=data_format,\n                **kwargs,\n            )\n\n            self.random_contrast = layers.RandomContrast(\n                factor=self.factor,\n                value_range=self.value_range,\n                seed=self.seed,\n                data_format=data_format,\n                **kwargs,\n            )\n\n            self.random_sharpness = layers.RandomSharpness(\n                factor=self.factor,\n                value_range=self.value_range,\n                seed=self.seed,\n                data_format=data_format,\n                **kwargs,\n            )\n\n    def build(self, input_shape):\n        for layer_name in self._augment_layers:\n            augmentation_layer = getattr(self, layer_name)\n            augmentation_layer.build(input_shape)\n\n    def _sample_from_dirichlet(self, shape, alpha, seed):\n        gamma_sample = 
self.backend.random.gamma(\n            shape=shape,\n            alpha=alpha,\n            seed=seed,\n        )\n        return gamma_sample / self.backend.numpy.sum(\n            gamma_sample, axis=-1, keepdims=True\n        )\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n\n        if backend_utils.in_tf_graph():\n            self.backend.set_backend(\"tensorflow\")\n\n            for layer_name in self._augment_layers:\n                augmentation_layer = getattr(self, layer_name)\n                augmentation_layer.backend.set_backend(\"tensorflow\")\n\n        seed = seed or self._get_seed_generator(self.backend._backend)\n\n        chain_mixing_weights = self._sample_from_dirichlet(\n            [self.num_chains], self.alpha, seed\n        )\n        weight_sample = self.backend.random.beta(\n            shape=(),\n            alpha=self.alpha,\n            beta=self.alpha,\n            seed=seed,\n        )\n\n        chain_transforms = []\n        for _ in range(self.num_chains):\n            depth_transforms = []\n            for _ in range(self.chain_depth):\n                layer_name = py_random.choice(self._augment_layers + [None])\n                if layer_name is None:\n                    continue\n                augmentation_layer = getattr(self, layer_name)\n                depth_transforms.append(\n                    {\n                        \"layer_name\": layer_name,\n                        \"transformation\": (\n                            augmentation_layer.get_random_transformation(\n                                data,\n                                seed=self._get_seed_generator(\n                                    self.backend._backend\n                                ),\n                            )\n                        ),\n                    }\n                )\n            chain_transforms.append(depth_transforms)\n\n        
transformation = {\n            \"chain_mixing_weights\": chain_mixing_weights,\n            \"weight_sample\": weight_sample,\n            \"chain_transforms\": chain_transforms,\n        }\n\n        return transformation\n\n    def transform_images(self, images, transformation, training=True):\n        if training:\n            images = self.backend.cast(images, self.compute_dtype)\n\n            chain_mixing_weights = self.backend.cast(\n                transformation[\"chain_mixing_weights\"], dtype=self.compute_dtype\n            )\n            weight_sample = self.backend.cast(\n                transformation[\"weight_sample\"], dtype=self.compute_dtype\n            )\n            chain_transforms = transformation[\"chain_transforms\"]\n\n            aug_images = self.backend.numpy.zeros_like(images)\n            for idx, chain_transform in enumerate(chain_transforms):\n                copied_images = self.backend.numpy.copy(images)\n                for depth_transform in chain_transform:\n                    layer_name = depth_transform[\"layer_name\"]\n                    layer_transform = depth_transform[\"transformation\"]\n\n                    augmentation_layer = getattr(self, layer_name)\n                    copied_images = augmentation_layer.transform_images(\n                        copied_images, layer_transform\n                    )\n                aug_images += copied_images * chain_mixing_weights[idx]\n            images = weight_sample * images + (1 - weight_sample) * aug_images\n\n            images = self.backend.numpy.clip(\n                images, self.value_range[0], self.value_range[1]\n            )\n\n        images = self.backend.cast(images, self.compute_dtype)\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        return 
bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        if training:\n            chain_mixing_weights = self.backend.cast(\n                transformation[\"chain_mixing_weights\"], dtype=self.compute_dtype\n            )\n            weight_sample = self.backend.cast(\n                transformation[\"weight_sample\"], dtype=self.compute_dtype\n            )\n            chain_transforms = transformation[\"chain_transforms\"]\n\n            aug_masks = self.backend.numpy.zeros_like(segmentation_masks)\n            for idx, chain_transform in enumerate(chain_transforms):\n                copied_masks = self.backend.numpy.copy(segmentation_masks)\n                for depth_transform in chain_transform:\n                    layer_name = depth_transform[\"layer_name\"]\n                    layer_transform = depth_transform[\"transformation\"]\n\n                    augmentation_layer = getattr(self, layer_name)\n                    copied_masks = (\n                        augmentation_layer.transform_segmentation_masks(\n                            copied_masks, layer_transform\n                        )\n                    )\n                aug_masks += copied_masks * chain_mixing_weights[idx]\n            segmentation_masks = (\n                weight_sample * segmentation_masks\n                + (1 - weight_sample) * aug_masks\n            )\n        return segmentation_masks\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = {\n            \"value_range\": self.value_range,\n            \"num_chains\": self.num_chains,\n            \"chain_depth\": self.chain_depth,\n            \"factor\": self.factor,\n            \"alpha\": self.alpha,\n            \"all_ops\": self.all_ops,\n            \"interpolation\": self.interpolation,\n            \"seed\": self.seed,\n        }\n        base_config = 
super().get_config()\n        return {**base_config, **config}\n\n\nAugMix.__doc__ = AugMix.__doc__.replace(\n    \"{{base_image_preprocessing_transform_example}}\",\n    base_image_preprocessing_transform_example.replace(\"{LayerName}\", \"AugMix\"),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/aug_mix_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandAugmentTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.AugMix,\n            init_kwargs={\n                \"value_range\": (0, 255),\n                \"num_chains\": 2,\n                \"chain_depth\": 2,\n                \"factor\": 1,\n                \"alpha\": 1.0,\n                \"all_ops\": True,\n                \"interpolation\": \"nearest\",\n                \"seed\": 43,\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_aug_mix_inference(self):\n        seed = 3481\n        layer = layers.AugMix()\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_random_augment_randomness(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n\n        layer = layers.AugMix(\n            num_chains=11, all_ops=True, data_format=data_format\n        )\n        augmented_image = layer(input_data)\n\n        self.assertNotAllClose(\n            backend.convert_to_numpy(augmented_image), input_data\n        )\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 
8))\n        layer = layers.AugMix(data_format=data_format)\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/auto_contrast.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.ops.core import _saturate_cast\n\n\n@keras_export(\"keras.layers.AutoContrast\")\nclass AutoContrast(BaseImagePreprocessingLayer):\n    \"\"\"Performs the auto-contrast operation on an image.\n\n    Auto contrast stretches the values of an image across the entire available\n    `value_range`. This makes differences between pixels more obvious. An\n    example of this is if an image only has values `[0, 1]` out of the range\n    `[0, 255]`, auto contrast will change the `1` values to be `255`.\n\n    This layer is active at both training and inference time.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        value_range: Range of values the incoming images will have.\n            Represented as a two number tuple written `(low, high)`.\n            This is typically either `(0, 1)` or `(0, 255)` depending\n            on how your preprocessing pipeline is set up.\n            Defaults to `(0, 255)`.\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _VALUE_RANGE_VALIDATION_ERROR = (\n        \"The `value_range` argument should be a list of two numbers. 
\"\n    )\n\n    def __init__(\n        self,\n        value_range=(0, 255),\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self._set_value_range(value_range)\n\n    def _set_value_range(self, value_range):\n        if not isinstance(value_range, (tuple, list)):\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        if len(value_range) != 2:\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        self.value_range = sorted(value_range)\n\n    def transform_images(self, images, transformation=None, training=True):\n        original_images = images\n        images = self._transform_value_range(\n            images,\n            original_range=self.value_range,\n            target_range=(0, 255),\n            dtype=self.compute_dtype,\n        )\n\n        images = self.backend.cast(images, self.compute_dtype)\n        low = self.backend.numpy.min(images, axis=(1, 2), keepdims=True)\n        high = self.backend.numpy.max(images, axis=(1, 2), keepdims=True)\n        scale = 255.0 / (high - low)\n        offset = -low * scale\n\n        images = images * scale + offset\n        results = self.backend.numpy.clip(images, 0.0, 255.0)\n        results = self._transform_value_range(\n            results,\n            original_range=(0, 255),\n            target_range=self.value_range,\n            dtype=self.compute_dtype,\n        )\n        # don't process NaN channels\n        results = self.backend.numpy.where(\n            self.backend.numpy.isnan(results), original_images, results\n        )\n        if results.dtype == images.dtype:\n            return results\n        if backend.is_int_dtype(images.dtype):\n            results = self.backend.numpy.round(results)\n        return _saturate_cast(results, images.dtype, 
self.backend)\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def get_config(self):\n        config = super().get_config()\n        config.update({\"value_range\": self.value_range})\n        return config\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/auto_contrast_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass AutoContrastTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.AutoContrast,\n            init_kwargs={\n                \"value_range\": (20, 200),\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_constant_channels_dont_get_nanned(self):\n        img = np.array([1, 1], dtype=\"float32\")\n        img = np.expand_dims(img, axis=-1)\n        img = np.expand_dims(img, axis=-1)\n        img = np.expand_dims(img, axis=0)\n\n        layer = layers.AutoContrast(value_range=(0, 255))\n        ys = layer(img)\n\n        self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0))\n        self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 1.0))\n\n    def test_auto_contrast_expands_value_range(self):\n        img = np.array([0, 128], dtype=\"float32\")\n        img = np.expand_dims(img, axis=-1)\n        img = np.expand_dims(img, axis=-1)\n        img = np.expand_dims(img, axis=0)\n\n        layer = layers.AutoContrast(value_range=(0, 255))\n        ys = layer(img)\n\n        self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0))\n        self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 255.0))\n\n    def test_auto_contrast_different_values_per_channel(self):\n        img = np.array(\n            [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],\n            dtype=\"float32\",\n        )\n        img = np.expand_dims(img, axis=0)\n\n        layer = layers.AutoContrast(value_range=(0, 255))\n        ys = layer(img)\n\n        self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 0]) == 0.0))\n        self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 1]) == 0.0))\n\n        
self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 0]) == 255.0))\n        self.assertTrue(np.any(ops.convert_to_numpy(ys[0, ..., 1]) == 255.0))\n\n        self.assertAllClose(\n            ys,\n            [\n                [\n                    [[0.0, 0.0, 0.0], [85.0, 85.0, 85.0]],\n                    [[170.0, 170.0, 170.0], [255.0, 255.0, 255.0]],\n                ]\n            ],\n        )\n\n    def test_auto_contrast_expands_value_range_uint8(self):\n        img = np.array([0, 128], dtype=\"uint8\")\n        img = np.expand_dims(img, axis=-1)\n        img = np.expand_dims(img, axis=-1)\n        img = np.expand_dims(img, axis=0)\n\n        layer = layers.AutoContrast(value_range=(0, 255))\n        ys = layer(img)\n\n        self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 0.0))\n        self.assertTrue(np.any(ops.convert_to_numpy(ys[0]) == 255.0))\n\n    def test_auto_contrast_properly_converts_value_range(self):\n        img = np.array([0, 0.5], dtype=\"float32\")\n        img = np.expand_dims(img, axis=-1)\n        img = np.expand_dims(img, axis=-1)\n        img = np.expand_dims(img, axis=0)\n\n        layer = layers.AutoContrast(value_range=(0, 1))\n        ys = layer(img)\n        self.assertAllClose(\n            ops.convert_to_numpy(ys[0]), np.array([[[0.0]], [[1]]])\n        )\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/base_image_preprocessing_layer.py",
    "content": "import math\n\nfrom keras.src.backend import config as backend_config\nfrom keras.src.layers.preprocessing.data_layer import DataLayer\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.validation import (  # noqa: E501\n    densify_bounding_boxes,\n)\n\n\nclass BaseImagePreprocessingLayer(DataLayer):\n    _USE_BASE_FACTOR = True\n    _FACTOR_BOUNDS = (-1, 1)\n\n    def __init__(\n        self, factor=None, bounding_box_format=None, data_format=None, **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.bounding_box_format = bounding_box_format\n        self.data_format = backend_config.standardize_data_format(data_format)\n        if self._USE_BASE_FACTOR:\n            factor = factor or 0.0\n            self._set_factor(factor)\n        elif factor is not None:\n            raise ValueError(\n                f\"Layer {self.__class__.__name__} does not take \"\n                f\"a `factor` argument. Received: factor={factor}\"\n            )\n\n    def _set_factor(self, factor):\n        error_msg = (\n            \"The `factor` argument should be a number \"\n            \"(or a list of two numbers) \"\n            \"in the range \"\n            f\"[{self._FACTOR_BOUNDS[0]}, {self._FACTOR_BOUNDS[1]}]. 
\"\n            f\"Received: factor={factor}\"\n        )\n        if isinstance(factor, (tuple, list)):\n            if len(factor) != 2:\n                raise ValueError(error_msg)\n            if (\n                factor[0] > self._FACTOR_BOUNDS[1]\n                or factor[1] < self._FACTOR_BOUNDS[0]\n            ):\n                raise ValueError(error_msg)\n            lower, upper = sorted(factor)\n        elif isinstance(factor, (int, float)):\n            if (\n                factor < self._FACTOR_BOUNDS[0]\n                or factor > self._FACTOR_BOUNDS[1]\n            ):\n                raise ValueError(error_msg)\n            factor = abs(factor)\n            lower, upper = [max(-factor, self._FACTOR_BOUNDS[0]), factor]\n        else:\n            raise ValueError(error_msg)\n        self.factor = lower, upper\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        return None\n\n    def _transform_images(self, images, transformation, interpolation):\n        raise NotImplementedError()\n\n    def transform_images(self, images, transformation, training=True):\n        images = self.backend.cast(images, self.compute_dtype)\n        if training:\n            images = self._transform_images(\n                images, transformation, self.interpolation\n            )\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        raise NotImplementedError()\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        raise NotImplementedError()\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        if training:\n            segmentation_masks = self._transform_images(\n                segmentation_masks, transformation, \"nearest\"\n            )\n        return segmentation_masks\n\n    def transform_single_image(self, image, 
transformation, training=True):\n        images = self.backend.numpy.expand_dims(image, axis=0)\n        outputs = self.transform_images(\n            images, transformation=transformation, training=training\n        )\n        return self.backend.numpy.squeeze(outputs, axis=0)\n\n    def transform_single_label(self, label, transformation, training=True):\n        labels = self.backend.numpy.expand_dims(label, axis=0)\n        outputs = self.transform_labels(\n            labels, transformation=transformation, training=training\n        )\n        return self.backend.numpy.squeeze(outputs, axis=0)\n\n    def transform_single_bounding_box(\n        self,\n        bounding_box,\n        transformation,\n        training=True,\n    ):\n        bounding_boxes = self._format_single_input_bounding_box(bounding_box)\n        outputs = self.transform_bounding_boxes(\n            bounding_boxes,\n            transformation=transformation,\n            training=training,\n        )\n        bounding_box = self._format_single_output_bounding_box(outputs)\n        return bounding_box\n\n    def transform_single_segmentation_mask(\n        self, segmentation_mask, transformation, training=True\n    ):\n        segmentation_masks = self.backend.numpy.expand_dims(\n            segmentation_mask, axis=0\n        )\n        outputs = self.transform_segmentation_masks(\n            segmentation_masks, transformation=transformation, training=training\n        )\n        return self.backend.numpy.squeeze(outputs, axis=0)\n\n    def _is_batched(self, maybe_image_batch):\n        shape = self.backend.core.shape(maybe_image_batch)\n        if len(shape) == 3:\n            return False\n        if len(shape) == 4:\n            return True\n        raise ValueError(\n            \"Expected image tensor to have rank 3 (single image) \"\n            f\"or 4 (batch of images). 
Received: data.shape={shape}\"\n        )\n\n    def call(self, data, training=True):\n        transformation = self.get_random_transformation(data, training=training)\n        if isinstance(data, dict):\n            is_batched = self._is_batched(data[\"images\"])\n            if is_batched:\n                data[\"images\"] = self.transform_images(\n                    self.backend.convert_to_tensor(data[\"images\"]),\n                    transformation=transformation,\n                    training=training,\n                )\n            else:\n                data[\"images\"] = self.transform_single_image(\n                    self.backend.convert_to_tensor(data[\"images\"]),\n                    transformation=transformation,\n                    training=training,\n                )\n            if \"bounding_boxes\" in data:\n                if not self.bounding_box_format:\n                    raise ValueError(\n                        \"You passed an input with a 'bounding_boxes' key, \"\n                        \"but you didn't specify a bounding box format. \"\n                        \"Pass a `bounding_box_format` argument to your \"\n                        f\"{self.__class__.__name__} layer, e.g. 
\"\n                        \"`bounding_box_format='xyxy'`.\"\n                    )\n                bounding_boxes = densify_bounding_boxes(\n                    data[\"bounding_boxes\"],\n                    is_batched=is_batched,\n                    backend=self.backend,\n                )\n\n                if is_batched:\n                    data[\"bounding_boxes\"] = self.transform_bounding_boxes(\n                        bounding_boxes,\n                        transformation=transformation,\n                        training=training,\n                    )\n                else:\n                    data[\"bounding_boxes\"] = self.transform_single_bounding_box(\n                        bounding_boxes,\n                        transformation=transformation,\n                        training=training,\n                    )\n            if \"labels\" in data:\n                if is_batched:\n                    data[\"labels\"] = self.transform_labels(\n                        self.backend.convert_to_tensor(data[\"labels\"]),\n                        transformation=transformation,\n                        training=training,\n                    )\n                else:\n                    data[\"labels\"] = self.transform_single_label(\n                        self.backend.convert_to_tensor(data[\"labels\"]),\n                        transformation=transformation,\n                        training=training,\n                    )\n            if \"segmentation_masks\" in data:\n                if is_batched:\n                    data[\"segmentation_masks\"] = (\n                        self.transform_segmentation_masks(\n                            data[\"segmentation_masks\"],\n                            transformation=transformation,\n                            training=training,\n                        )\n                    )\n                else:\n                    data[\"segmentation_masks\"] = (\n                        
self.transform_single_segmentation_mask(\n                            data[\"segmentation_masks\"],\n                            transformation=transformation,\n                            training=training,\n                        )\n                    )\n            return data\n\n        # `data` is just images.\n        if self._is_batched(data):\n            return self.transform_images(\n                self.backend.convert_to_tensor(data),\n                transformation=transformation,\n                training=training,\n            )\n        return self.transform_single_image(\n            self.backend.convert_to_tensor(data),\n            transformation=transformation,\n            training=training,\n        )\n\n    def _format_single_input_bounding_box(self, bounding_box):\n        for key in bounding_box:\n            if key == \"labels\":\n                bounding_box[key] = self.backend.numpy.expand_dims(\n                    bounding_box[key], axis=0\n                )\n            if key == \"boxes\":\n                bounding_box[key] = self.backend.numpy.expand_dims(\n                    bounding_box[key], axis=0\n                )\n\n        return bounding_box\n\n    def _format_single_output_bounding_box(self, bounding_boxes):\n        for key in bounding_boxes:\n            if key == \"labels\":\n                bounding_boxes[key] = self.backend.numpy.squeeze(\n                    bounding_boxes[key], axis=0\n                )\n            if key == \"boxes\":\n                bounding_boxes[key] = self.backend.numpy.squeeze(\n                    bounding_boxes[key], axis=0\n                )\n\n        return bounding_boxes\n\n    def get_config(self):\n        config = super().get_config()\n        if self.bounding_box_format is not None:\n            config.update(\n                {\n                    \"bounding_box_format\": self.bounding_box_format,\n                }\n            )\n        return config\n\n    def 
_transform_value_range(\n        self, images, original_range, target_range, dtype=\"float32\"\n    ):\n        \"\"\"Convert input values from `original_range` to `target_range`.\n\n        This function is intended to be used in preprocessing layers that\n        rely upon color values. This allows us to assume internally that\n        the input tensor is always in the range `(0, 255)`.\n\n        Args:\n            images: the set of images to transform to the target range.\n            original_range: the value range to transform from.\n            target_range: the value range to transform to.\n            dtype: the dtype to compute the conversion with,\n                defaults to \"float32\".\n\n        Returns:\n            a new Tensor with values in the target range.\n\n        Example:\n\n        ```python\n        original_range = [0, 1]\n        target_range = [0, 255]\n        images = self._transform_value_range(\n            images,\n            original_range,\n            target_range\n        )\n        images = ops.minimum(images + 10, 255)\n        images = self._transform_value_range(\n            images,\n            target_range,\n            original_range\n        )\n        ```\n        \"\"\"\n        if (\n            original_range[0] == target_range[0]\n            and original_range[1] == target_range[1]\n        ):\n            return images\n\n        images = self.backend.cast(images, dtype=dtype)\n        original_min_value, original_max_value = self._unwrap_value_range(\n            original_range, dtype=dtype\n        )\n        target_min_value, target_max_value = self._unwrap_value_range(\n            target_range, dtype=dtype\n        )\n\n        # images in the [0, 1] scale\n        images = (images - original_min_value) / (\n            original_max_value - original_min_value\n        )\n\n        scale_factor = target_max_value - target_min_value\n        return (images * scale_factor) + 
target_min_value\n\n    def _unwrap_value_range(self, value_range, dtype=\"float32\"):\n        min_value, max_value = value_range\n        min_value = self.backend.cast(min_value, dtype=dtype)\n        max_value = self.backend.cast(max_value, dtype=dtype)\n        return min_value, max_value\n\n    def _compute_affine_matrix(\n        self,\n        center_x,\n        center_y,\n        angle,\n        translate_x,\n        translate_y,\n        scale,\n        shear_x,\n        shear_y,\n        height,\n        width,\n    ):\n        \"\"\"\n        #       Scaling          Shear           Rotation\n        #     [sx  0   0]    [1   shx  0]   [cos(θ)  -sin(θ)  0]\n        # M = [0   sy  0] *  [shy  1   0] * [sin(θ)   cos(θ)  0]\n        #     [0   0   1]    [0    0   1]   [0        0       1]\n\n        # a0 = sx * (cos(θ) + shx * sin(θ))\n        # a1 = sx * (-sin(θ) + shx * cos(θ))\n        # a2 = tx + cx - cx * a0 - cy * a1\n        # b0 = sy * (shy * cos(θ) + sin(θ))\n        # b1 = sy * (shy * -sin(θ) + cos(θ))\n        # b2 = ty + cy - cx * b0 - cy * b1\n        \"\"\"\n        ops = self.backend\n\n        degree_to_radian_factor = ops.convert_to_tensor(math.pi / 180.0)\n\n        angle = angle * degree_to_radian_factor\n        shear_x = shear_x * degree_to_radian_factor\n        shear_y = shear_y * degree_to_radian_factor\n\n        batch_size = ops.shape(angle)[0]\n        dtype = angle.dtype\n        width = ops.cast(width, dtype)\n        height = ops.cast(height, dtype)\n        cx = center_x * (width - 1)\n        cy = center_y * (height - 1)\n\n        cos_theta = ops.numpy.cos(angle)\n        sin_theta = ops.numpy.sin(angle)\n        shear_x = ops.numpy.tan(shear_x)\n        shear_y = ops.numpy.tan(shear_y)\n\n        a0 = scale * (cos_theta + shear_x * sin_theta)\n        a1 = scale * (-sin_theta + shear_x * cos_theta)\n        a2 = translate_x + cx - cx * a0 - cy * a1\n        b0 = scale * (shear_y * cos_theta + sin_theta)\n        b1 = scale 
* (shear_y * -sin_theta + cos_theta)\n        b2 = translate_y + cy - cx * b0 - cy * b1\n        affine_matrix = ops.numpy.concatenate(\n            [\n                a0[:, None],\n                a1[:, None],\n                a2[:, None],\n                b0[:, None],\n                b1[:, None],\n                b2[:, None],\n                ops.numpy.zeros((batch_size, 2)),\n            ],\n            axis=1,\n        )\n\n        return affine_matrix\n\n\nbase_image_preprocessing_transform_example = \"\"\"\n```python\nlayer = keras.layers.{LayerName}(bounding_box_format=\"xyxy\")\nimages = np.random.randint(0, 255, (4, 224, 224, 3), dtype=\"uint8\")\n\nbounding_boxes = {\n    \"boxes\": np.array([\n        [[10, 20, 100, 150], [50, 60, 200, 250]],\n        [[15, 25, 110, 160], [55, 65, 210, 260]],\n        [[20, 30, 120, 170], [60, 70, 220, 270]],\n        [[25, 35, 130, 180], [65, 75, 230, 280]],\n    ], dtype=\"float32\"),\n    \"labels\": np.array([[0, 1], [1, 2], [2, 3], [0, 3]], dtype=\"int32\")\n}\n\nlabels = keras.ops.one_hot(\n    np.array([0, 1, 2, 3]),\n    num_classes=4\n)\n\nsegmentation_masks = np.random.randint(0, 3, (4, 224, 224, 1), dtype=\"uint8\")\n\noutput = layer(\n    {\n        \"images\": images,\n        \"bounding_boxes\": bounding_boxes,\n        \"labels\": labels,\n        \"segmentation_masks\": segmentation_masks\n    },\n    training=True\n)\n```\n\"\"\"\n\nbase_image_preprocessing_color_example = \"\"\"\n```python\nlayer = keras.layers.{LayerName}(value_range=(0, 255))\nimages = np.random.randint(0, 255, (8, 224, 224, 3), dtype=\"uint8\")\n\nlabels = keras.ops.one_hot(\n    np.array([0, 1, 2, 0, 1, 2, 0, 1]),\n    num_classes=3\n)\n\nsegmentation_masks = np.random.randint(0, 3, (8, 224, 224, 1), dtype=\"uint8\")\n\noutput = layer(\n    {\n        \"images\": images,\n        \"labels\": labels,\n        \"segmentation_masks\": segmentation_masks\n    },\n    training=True\n)\n```\n\"\"\"\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/bounding_box.py",
    "content": "import math\n\nfrom keras.src.utils import backend_utils\n\nSUPPORTED_FORMATS = (\n    \"xyxy\",\n    \"yxyx\",\n    \"xywh\",\n    \"center_xywh\",\n    \"center_yxhw\",\n    \"rel_xyxy\",\n    \"rel_yxyx\",\n    \"rel_xywh\",\n    \"rel_center_xywh\",\n)\n\n\nclass BoundingBox:\n    def __init__(self):\n        self.backend = backend_utils.DynamicBackend()\n\n    def convert_format(\n        self,\n        boxes,\n        source,\n        target,\n        height=None,\n        width=None,\n        dtype=\"float32\",\n    ):\n        if isinstance(boxes, dict):\n            boxes[\"boxes\"] = self.convert_format(\n                boxes[\"boxes\"],\n                source=source,\n                target=target,\n                height=height,\n                width=width,\n                dtype=dtype,\n            )\n            return boxes\n\n        to_xyxy_converters = {\n            \"xyxy\": self._xyxy_to_xyxy,\n            \"yxyx\": self._yxyx_to_xyxy,\n            \"xywh\": self._xywh_to_xyxy,\n            \"center_xywh\": self._center_xywh_to_xyxy,\n            \"center_yxhw\": self._center_yxhw_to_xyxy,\n            \"rel_xyxy\": self._rel_xyxy_to_xyxy,\n            \"rel_yxyx\": self._rel_yxyx_to_xyxy,\n            \"rel_xywh\": self._rel_xywh_to_xyxy,\n            \"rel_center_xywh\": self._rel_center_xywh_to_xyxy,\n        }\n        from_xyxy_converters = {\n            \"xyxy\": self._xyxy_to_xyxy,\n            \"yxyx\": self._xyxy_to_yxyx,\n            \"xywh\": self._xyxy_to_xywh,\n            \"center_xywh\": self._xyxy_to_center_xywh,\n            \"center_yxhw\": self._xyxy_to_center_yxhw,\n            \"rel_xyxy\": self._xyxy_to_rel_xyxy,\n            \"rel_yxyx\": self._xyxy_to_rel_yxyx,\n            \"rel_xywh\": self._xyxy_to_rel_xywh,\n            \"rel_center_xywh\": self._xyxy_to_rel_center_xywh,\n        }\n\n        ops = self.backend\n        boxes_shape = ops.shape(boxes)\n        if boxes_shape[-1] != 4:\n            
raise ValueError(\n                \"`boxes` must be a tensor with the last dimension of 4. \"\n                f\"Received: boxes.shape={boxes_shape}\"\n            )\n        source = source.lower()\n        target = target.lower()\n        if source not in SUPPORTED_FORMATS or target not in SUPPORTED_FORMATS:\n            raise ValueError(\n                f\"Invalid source or target format. \"\n                f\"Supported formats: {SUPPORTED_FORMATS}\"\n            )\n\n        if (source.startswith(\"rel_\") or target.startswith(\"rel_\")) and (\n            width is None or height is None\n        ):\n            raise ValueError(\n                \"convert_format() must receive `height` and `width` when \"\n                \"transforming between relative and absolute formats. \"\n                f\"convert_format() received source=`{source}`, \"\n                f\"target=`{target}`, \"\n                f\"but height={height} and width={width}.\"\n            )\n        boxes = ops.cast(boxes, dtype)\n        if source == target:\n            return boxes\n        if width is not None:\n            width = ops.cast(width, dtype)\n        if height is not None:\n            height = ops.cast(height, dtype)\n\n        if source.startswith(\"rel_\") and target.startswith(\"rel_\"):\n            source = source.replace(\"rel_\", \"\", 1)\n            target = target.replace(\"rel_\", \"\", 1)\n        to_xyxy_converter = to_xyxy_converters[source]\n        from_xyxy_converter = from_xyxy_converters[target]\n        in_xyxy_boxes = to_xyxy_converter(boxes, height, width)\n        return from_xyxy_converter(in_xyxy_boxes, height, width)\n\n    def clip_to_image_size(\n        self,\n        bounding_boxes,\n        height=None,\n        width=None,\n        bounding_box_format=\"xyxy\",\n    ):\n        if bounding_box_format not in (\"xyxy\", \"rel_xyxy\"):\n            raise NotImplementedError\n        if bounding_box_format == \"xyxy\" and (height is None or width 
is None):\n            raise ValueError(\n                \"`height` and `width` must be set if \"\n                \"`bounding_box_format='xyxy'`.\"\n            )\n\n        ops = self.backend\n        boxes = bounding_boxes[\"boxes\"]\n        labels = bounding_boxes.get(\"labels\", None)\n        if width is not None:\n            width = ops.cast(width, boxes.dtype)\n        if height is not None:\n            height = ops.cast(height, boxes.dtype)\n\n        if bounding_box_format == \"xyxy\":\n            x1, y1, x2, y2 = ops.numpy.split(boxes, 4, axis=-1)\n            x1 = ops.numpy.clip(x1, 0, width)\n            y1 = ops.numpy.clip(y1, 0, height)\n            x2 = ops.numpy.clip(x2, 0, width)\n            y2 = ops.numpy.clip(y2, 0, height)\n            boxes = ops.numpy.concatenate([x1, y1, x2, y2], axis=-1)\n\n            if labels is not None:\n                areas = self._compute_area(boxes)\n                areas = ops.numpy.squeeze(areas, axis=-1)\n                labels = ops.numpy.where(areas > 0, labels, -1)\n        elif bounding_box_format == \"rel_xyxy\":\n            x1, y1, x2, y2 = ops.numpy.split(boxes, 4, axis=-1)\n            x1 = ops.numpy.clip(x1, 0.0, 1.0)\n            y1 = ops.numpy.clip(y1, 0.0, 1.0)\n            x2 = ops.numpy.clip(x2, 0.0, 1.0)\n            y2 = ops.numpy.clip(y2, 0.0, 1.0)\n            boxes = ops.numpy.concatenate([x1, y1, x2, y2], axis=-1)\n\n            if labels is not None:\n                areas = self._compute_area(boxes)\n                areas = ops.numpy.squeeze(areas, axis=-1)\n                labels = ops.numpy.where(areas > 0, labels, -1)\n\n        result = bounding_boxes.copy()\n        result[\"boxes\"] = boxes\n        if labels is not None:\n            result[\"labels\"] = labels\n        return result\n\n    def affine(\n        self,\n        boxes,\n        angle,\n        translate_x,\n        translate_y,\n        scale,\n        shear_x,\n        shear_y,\n        height,\n        width,\n        center_x=None,\n        
center_y=None,\n    ):\n        ops = self.backend\n\n        boxes_shape = ops.shape(boxes)\n        batch_size = boxes_shape[0]\n        n_boxes = boxes_shape[1]\n        if center_x is None:\n            center_x = 0.5\n        if center_y is None:\n            center_y = 0.5\n        matrix = self._compute_inverse_affine_matrix(\n            center_x,\n            center_y,\n            angle,\n            translate_x,\n            translate_y,\n            scale,\n            shear_x,\n            shear_y,\n            height,\n            width,\n        )\n        boxes = ops.cast(boxes, dtype=matrix.dtype)\n        transposed_matrix = ops.numpy.transpose(matrix[:, :2, :], [0, 2, 1])\n        points = boxes  # [B, N, 4]\n        points = ops.numpy.stack(\n            [\n                points[..., 0],\n                points[..., 1],\n                points[..., 2],\n                points[..., 1],\n                points[..., 2],\n                points[..., 3],\n                points[..., 0],\n                points[..., 3],\n            ],\n            axis=-1,\n        )\n        points = ops.numpy.reshape(points, [batch_size, n_boxes, 4, 2])\n        points = ops.numpy.concatenate(\n            [\n                points,\n                ops.numpy.ones([batch_size, n_boxes, 4, 1], points.dtype),\n            ],\n            axis=-1,\n        )\n        transformed_points = ops.numpy.einsum(\n            \"bnxy,byz->bnxz\", points, transposed_matrix\n        )\n        boxes_min = ops.numpy.amin(transformed_points, axis=2)\n        boxes_max = ops.numpy.amax(transformed_points, axis=2)\n        outputs = ops.numpy.concatenate([boxes_min, boxes_max], axis=-1)\n        return outputs\n\n    def crop(self, boxes, top, left, height, width):\n        ops = self.backend\n\n        x1, y1, x2, y2 = ops.numpy.split(boxes, 4, axis=-1)\n        x1 = x1 - left\n        y1 = y1 - top\n        x2 = x2 - left\n        y2 = y2 - top\n        x1 = ops.numpy.clip(x1, 0, 
width)\n        y1 = ops.numpy.clip(y1, 0, height)\n        x2 = ops.numpy.clip(x2, 0, width)\n        y2 = ops.numpy.clip(y2, 0, height)\n        outputs = ops.numpy.concatenate([x1, y1, x2, y2], axis=-1)\n        return outputs\n\n    def pad(self, boxes, top, left):\n        ops = self.backend\n\n        x1, y1, x2, y2 = ops.numpy.split(boxes, 4, axis=-1)\n        x1 = x1 + left\n        y1 = y1 + top\n        x2 = x2 + left\n        y2 = y2 + top\n        outputs = ops.numpy.concatenate([x1, y1, x2, y2], axis=-1)\n        return outputs\n\n    # Converters\n\n    def _xyxy_to_xyxy(self, boxes, height=None, width=None):\n        return boxes\n\n    def _yxyx_to_xyxy(self, boxes, height=None, width=None):\n        y1, x1, y2, x2 = self.backend.numpy.split(boxes, 4, axis=-1)\n        return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1)\n\n    def _xywh_to_xyxy(self, boxes, height=None, width=None):\n        x1, y1, w, h = self.backend.numpy.split(boxes, 4, axis=-1)\n        x2 = x1 + w\n        y2 = y1 + h\n        return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1)\n\n    def _center_xywh_to_xyxy(self, boxes, height=None, width=None):\n        ops = self.backend\n        cx, cy, w, h = ops.numpy.split(boxes, 4, axis=-1)\n        half_w = w / 2.0\n        half_h = h / 2.0\n        x1 = cx - half_w\n        y1 = cy - half_h\n        x2 = cx + half_w\n        y2 = cy + half_h\n        return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1)\n\n    def _center_yxhw_to_xyxy(self, boxes, height=None, width=None):\n        ops = self.backend\n        cy, cx, h, w = ops.numpy.split(boxes, 4, axis=-1)\n        half_w = w / 2.0\n        half_h = h / 2.0\n        x1 = cx - half_w\n        y1 = cy - half_h\n        x2 = cx + half_w\n        y2 = cy + half_h\n        return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1)\n\n    def _rel_xyxy_to_xyxy(self, boxes, height=None, width=None):\n        ops = self.backend\n        rel_x1, 
rel_y1, rel_x2, rel_y2 = ops.numpy.split(boxes, 4, axis=-1)\n        x1 = rel_x1 * width\n        y1 = rel_y1 * height\n        x2 = rel_x2 * width\n        y2 = rel_y2 * height\n        return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1)\n\n    def _rel_yxyx_to_xyxy(self, boxes, height=None, width=None):\n        ops = self.backend\n        rel_y1, rel_x1, rel_y2, rel_x2 = ops.numpy.split(boxes, 4, axis=-1)\n        x1 = rel_x1 * width\n        y1 = rel_y1 * height\n        x2 = rel_x2 * width\n        y2 = rel_y2 * height\n        return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1)\n\n    def _rel_xywh_to_xyxy(self, boxes, height=None, width=None):\n        ops = self.backend\n        rel_x1, rel_y1, rel_w, rel_h = ops.numpy.split(boxes, 4, axis=-1)\n        x1 = rel_x1 * width\n        y1 = rel_y1 * height\n        x2 = (rel_x1 + rel_w) * width\n        y2 = (rel_y1 + rel_h) * height\n        return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1)\n\n    def _rel_center_xywh_to_xyxy(self, boxes, height=None, width=None):\n        ops = self.backend\n        rel_cx, rel_cy, rel_w, rel_h = ops.numpy.split(boxes, 4, axis=-1)\n        half_rel_w = rel_w / 2.0\n        half_rel_h = rel_h / 2.0\n        # x coordinates scale by image width, y coordinates by image height.\n        x1 = (rel_cx - half_rel_w) * width\n        y1 = (rel_cy - half_rel_h) * height\n        x2 = (rel_cx + half_rel_w) * width\n        y2 = (rel_cy + half_rel_h) * height\n        return self.backend.numpy.concatenate([x1, y1, x2, y2], axis=-1)\n\n    def _xyxy_to_yxyx(self, boxes, height=None, width=None):\n        x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1)\n        return self.backend.numpy.concatenate([y1, x1, y2, x2], axis=-1)\n\n    def _xyxy_to_xywh(self, boxes, height=None, width=None):\n        x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1)\n        w = x2 - x1\n        h = y2 - y1\n        return self.backend.numpy.concatenate([x1, y1, w, h], axis=-1)\n\n    def _xyxy_to_center_xywh(self, boxes, 
height=None, width=None):\n        x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1)\n        cx = x1 + ((x2 - x1) / 2.0)\n        cy = y1 + ((y2 - y1) / 2.0)\n        w = x2 - x1\n        h = y2 - y1\n        return self.backend.numpy.concatenate([cx, cy, w, h], axis=-1)\n\n    def _xyxy_to_center_yxhw(self, boxes, height=None, width=None):\n        x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1)\n        cx = x1 + ((x2 - x1) / 2.0)\n        cy = y1 + ((y2 - y1) / 2.0)\n        w = x2 - x1\n        h = y2 - y1\n        return self.backend.numpy.concatenate([cy, cx, h, w], axis=-1)\n\n    def _xyxy_to_rel_xyxy(self, boxes, height=None, width=None):\n        x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1)\n        rel_x1 = self.backend.numpy.divide(x1, width)\n        rel_y1 = self.backend.numpy.divide(y1, height)\n        rel_x2 = self.backend.numpy.divide(x2, width)\n        rel_y2 = self.backend.numpy.divide(y2, height)\n        return self.backend.numpy.concatenate(\n            [rel_x1, rel_y1, rel_x2, rel_y2], axis=-1\n        )\n\n    def _xyxy_to_rel_yxyx(self, boxes, height=None, width=None):\n        x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1)\n        rel_x1 = self.backend.numpy.divide(x1, width)\n        rel_y1 = self.backend.numpy.divide(y1, height)\n        rel_x2 = self.backend.numpy.divide(x2, width)\n        rel_y2 = self.backend.numpy.divide(y2, height)\n        return self.backend.numpy.concatenate(\n            [rel_y1, rel_x1, rel_y2, rel_x2], axis=-1\n        )\n\n    def _xyxy_to_rel_xywh(self, boxes, height=None, width=None):\n        x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1)\n        rel_x1 = x1 / width\n        rel_y1 = y1 / height\n        rel_w = (x2 - x1) / width\n        rel_h = (y2 - y1) / height\n        return self.backend.numpy.concatenate(\n            [rel_x1, rel_y1, rel_w, rel_h], axis=-1\n        )\n\n    def _xyxy_to_rel_center_xywh(self, boxes, 
height=None, width=None):\n        x1, y1, x2, y2 = self.backend.numpy.split(boxes, 4, axis=-1)\n        rel_cx = (x1 + ((x2 - x1) / 2.0)) / width\n        rel_cy = (y1 + ((y2 - y1) / 2.0)) / height\n        rel_w = (x2 - x1) / width\n        rel_h = (y2 - y1) / height\n        return self.backend.numpy.concatenate(\n            [rel_cx, rel_cy, rel_w, rel_h], axis=-1\n        )\n\n    # Clip\n    def _compute_area(self, boxes, format=\"xyxy\"):\n        if format not in (\"xyxy\", \"rel_xyxy\"):\n            raise NotImplementedError\n\n        ops = self.backend\n        x1, y1, x2, y2 = ops.numpy.split(boxes, 4, axis=-1)\n        widths = x2 - x1\n        heights = y2 - y1\n        return widths * heights\n\n    def _compute_inverse_affine_matrix(\n        self,\n        center_x,\n        center_y,\n        angle,\n        translate_x,\n        translate_y,\n        scale,\n        shear_x,\n        shear_y,\n        height,\n        width,\n    ):\n        # Ref: TF._geometry._get_inverse_affine_matrix\n        ops = self.backend\n        batch_size = ops.shape(angle)[0]\n        dtype = angle.dtype\n\n        angle = -angle\n        shear_x = -shear_x\n        shear_y = -shear_y\n\n        cx = ops.numpy.multiply(center_x, (width - 1))\n        cy = ops.numpy.multiply(center_y, (height - 1))\n        rot = ops.numpy.multiply(angle, 1.0 / 180.0 * math.pi)\n        tx = ops.numpy.multiply(-translate_x, (width - 1))\n        ty = ops.numpy.multiply(-translate_y, (height - 1))\n        sx = ops.numpy.multiply(shear_x, 1.0 / 180.0 * math.pi)\n        sy = ops.numpy.multiply(shear_y, 1.0 / 180.0 * math.pi)\n\n        # Cached results\n        cos_sy = ops.numpy.cos(sy)\n        tan_sx = ops.numpy.tan(sx)\n        rot_minus_sy = rot - sy\n        cx_plus_tx = cx + tx\n        cy_plus_ty = cy + ty\n\n        # Rotate Scale Shear (RSS) without scaling\n        a = ops.numpy.cos(rot_minus_sy) / cos_sy\n        b = a * tan_sx + ops.numpy.sin(rot)\n        c = 
-ops.numpy.sin(rot_minus_sy) / cos_sy\n        d = ops.numpy.cos(rot) - c * tan_sx\n\n        # Inverted rotation matrix with scale and shear\n        # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1\n        a0 = ops.numpy.multiply(d, scale)\n        a1 = ops.numpy.multiply(-b, scale)\n        b0 = ops.numpy.multiply(-c, scale)\n        b1 = ops.numpy.multiply(a, scale)\n        a2 = cx - a0 * cx_plus_tx - a1 * cy_plus_ty\n        b2 = cy - b0 * cx_plus_tx - b1 * cy_plus_ty\n\n        # Shape of matrix: [[batch_size], ...] -> [batch_size, 6]\n        matrix = ops.numpy.stack(\n            [\n                a0,\n                a1,\n                a2,\n                b0,\n                b1,\n                b2,\n                ops.numpy.zeros([batch_size], dtype),\n                ops.numpy.zeros([batch_size], dtype),\n                ops.numpy.ones([batch_size], dtype),\n            ],\n            axis=-1,\n        )\n        matrix = ops.numpy.reshape(matrix, [batch_size, 3, 3])\n        return matrix\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/converters.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.bounding_box import (  # noqa: E501\n    BoundingBox,\n)\nfrom keras.src.utils import backend_utils\n\n\n@keras_export(\"keras.utils.bounding_boxes.convert_format\")\ndef convert_format(\n    boxes, source, target, height=None, width=None, dtype=\"float32\"\n):\n    \"\"\"Converts bounding boxes between formats.\n\n    Supported formats (case-insensitive):\n    `\"xyxy\"`: [left, top, right, bottom]\n    `\"yxyx\"`: [top, left, bottom, right]\n    `\"xywh\"`: [left, top, width, height]\n    `\"center_xywh\"`: [center_x, center_y, width, height]\n    `\"center_yxhw\"`: [center_y, center_x, height, width]\n    `\"rel_xyxy\"`, `\"rel_yxyx\"`, `\"rel_xywh\"`, `\"rel_center_xywh\"`:  Relative\n        versions of the above formats, where coordinates are normalized\n        to the range [0, 1] based on the image `height` and `width`.\n\n    Args:\n        boxes: Bounding boxes tensor/array or dictionary of `boxes` and\n            `labels`.\n        source: Source format string.\n        target: Target format string.\n        height: Image height (required for relative target format).\n        width: Image width (required for relative target format).\n        dtype: Data type for conversion (optional).\n\n    Returns:\n        Converted boxes.\n\n    Raises:\n        ValueError: For invalid formats, shapes, or missing dimensions.\n\n    Example:\n    ```python\n    boxes = np.array([[10, 20, 30, 40], [50, 60, 70, 80]])\n    # Convert from 'xyxy' to 'xywh' format\n    boxes_xywh = keras.utils.bounding_boxes.convert_format(\n        boxes, source='xyxy', target='xywh'\n    )  # Output: [[10. 20. 20. 20.], [50. 60. 20. 
20.]]\n\n    # Convert to relative 'rel_xyxy' format\n    boxes_rel_xyxy = keras.utils.bounding_boxes.convert_format(\n        boxes, source='xyxy', target='rel_xyxy', height=200, width=300\n    )\n    # Output: [[0.03333334 0.1        0.1        0.2       ]\n    #          [0.16666667 0.3        0.23333333 0.4       ]]\n    ```\n    \"\"\"\n    box_utils = BoundingBox()\n    # Switch to tensorflow backend if we are in tf.data pipe\n    if backend_utils.in_tf_graph():\n        box_utils.backend.set_backend(\"tensorflow\")\n    boxes = box_utils.convert_format(\n        boxes=boxes,\n        source=source,\n        target=target,\n        height=height,\n        width=width,\n        dtype=dtype,\n    )\n    # Switch back to original backend\n    box_utils.backend.reset()\n    return boxes\n\n\n@keras_export(\"keras.utils.bounding_boxes.clip_to_image_size\")\ndef clip_to_image_size(\n    bounding_boxes, height=None, width=None, bounding_box_format=\"xyxy\"\n):\n    \"\"\"Clips bounding boxes to be within the image dimensions.\n\n    Args:\n        bounding_boxes: A dictionary with 'boxes' shape `(N, 4)` or\n            `(batch, N, 4)` and 'labels' shape `(N,)` or `(batch, N,)`.\n        height: Image height.\n        width: Image width.\n        bounding_box_format: The format of the input bounding boxes. 
Defaults to\n            `\"xyxy\"`.\n\n    Returns:\n        Clipped bounding boxes.\n\n    Example:\n    ```python\n    boxes = {\"boxes\": np.array([[-10, -20, 150, 160], [50, 40, 70, 80]]),\n             \"labels\": np.array([0, 1])}\n    clipped_boxes = keras.utils.bounding_boxes.clip_to_image_size(\n        boxes, height=100, width=120,\n    )\n    # Output will have boxes clipped to the image boundaries, and labels\n    # potentially adjusted if the clipped area becomes zero\n    ```\n    \"\"\"\n\n    box_utils = BoundingBox()\n    # Switch to tensorflow backend if we are in tf.data pipe\n    if backend_utils.in_tf_graph():\n        box_utils.backend.set_backend(\"tensorflow\")\n    bounding_boxes = box_utils.clip_to_image_size(\n        bounding_boxes,\n        height=height,\n        width=width,\n        bounding_box_format=bounding_box_format,\n    )\n    # Switch back to original backend\n    box_utils.backend.reset()\n    return bounding_boxes\n\n\n@keras_export(\"keras.utils.bounding_boxes.affine_transform\")\ndef affine_transform(\n    boxes,\n    angle,\n    translate_x,\n    translate_y,\n    scale,\n    shear_x,\n    shear_y,\n    height,\n    width,\n    center_x=None,\n    center_y=None,\n    bounding_box_format=\"xyxy\",\n):\n    \"\"\"Applies an affine transformation to the bounding boxes.\n\n    The `height` and `width` parameters are used to normalize the\n    translation and scaling factors.\n\n    Args:\n        boxes: The bounding boxes to transform, a tensor/array of shape\n            `(N, 4)` or `(batch_size, N, 4)`.\n        angle: Rotation angle in degrees.\n        translate_x: Horizontal translation fraction.\n        translate_y: Vertical translation fraction.\n        scale: Scaling factor.\n        shear_x: Shear angle in x-direction (degrees).\n        shear_y: Shear angle in y-direction (degrees).\n        height: Height of the image/data.\n        width: Width of the image/data.\n        center_x:  x-coordinate of the 
transformation center (fraction).\n        center_y: y-coordinate of the transformation center (fraction).\n        bounding_box_format: The format of the input bounding boxes. Defaults to\n            `\"xyxy\"`.\n\n    Returns:\n        The transformed bounding boxes, a tensor/array with the same shape\n        as the input `boxes`.\n    \"\"\"\n    if bounding_box_format != \"xyxy\":\n        raise NotImplementedError\n    box_utils = BoundingBox()\n    # Switch to tensorflow backend if we are in tf.data pipe\n    if backend_utils.in_tf_graph():\n        box_utils.backend.set_backend(\"tensorflow\")\n\n    boxes = box_utils.affine(\n        boxes,\n        angle,\n        translate_x,\n        translate_y,\n        scale,\n        shear_x,\n        shear_y,\n        height,\n        width,\n        center_x=center_x,\n        center_y=center_y,\n    )\n    box_utils.backend.reset()\n    return boxes\n\n\n@keras_export(\"keras.utils.bounding_boxes.crop\")\ndef crop(boxes, top, left, height, width, bounding_box_format=\"xyxy\"):\n    \"\"\"Crops bounding boxes based on the given offsets and dimensions.\n\n    This function crops bounding boxes to a region defined by `top`, `left`,\n    `height`, and `width`. The boxes must be in `xyxy` format; they are\n    shifted into the crop's coordinate frame and clipped to the crop\n    dimensions.\n\n    Args:\n        boxes: The bounding boxes to crop. A NumPy array or tensor of shape\n            `(N, 4)` or `(batch_size, N, 4)`.\n        top: The vertical offset of the top-left corner of the cropping region.\n        left: The horizontal offset of the top-left corner of the cropping\n            region.\n        height: The height of the cropping region.\n        width: The width of the cropping region.\n        bounding_box_format: The format of the input bounding boxes. 
Defaults to\n            `\"xyxy\"`.\n\n    Returns:\n        The cropped bounding boxes.\n\n    Example:\n    ```python\n    boxes = np.array([[10, 20, 50, 60], [70, 80, 100, 120]])  # xyxy format\n    cropped_boxes = keras.utils.bounding_boxes.crop(\n        boxes, bounding_box_format=\"xyxy\", top=10, left=20, height=40, width=30\n    )  # Cropping a 30x40 region starting at (20, 10)\n    print(cropped_boxes)\n    # Expected output (boxes are shifted and clipped to the crop region):\n    # array([[ 0., 10., 30., 40.],\n    #        [30., 40., 30., 40.]])\n    ```\n    \"\"\"\n    if bounding_box_format != \"xyxy\":\n        raise NotImplementedError\n    box_utils = BoundingBox()\n    # Switch to tensorflow backend if we are in tf.data pipe\n    if backend_utils.in_tf_graph():\n        box_utils.backend.set_backend(\"tensorflow\")\n    outputs = box_utils.crop(boxes, top, left, height, width)\n    box_utils.backend.reset()\n    return outputs\n\n\n@keras_export(\"keras.utils.bounding_boxes.pad\")\ndef pad(boxes, top, left, height=None, width=None, bounding_box_format=\"xyxy\"):\n    \"\"\"Pads bounding boxes by adding top and left offsets.\n\n    This function shifts the bounding boxes by adding `left` to the x\n    coordinates and `top` to the y coordinates. The method assumes the\n    input bounding_box_format is `xyxy`.\n\n    Args:\n        boxes: Bounding boxes to pad. Shape `(N, 4)` or `(batch, N, 4)`.\n        top: Vertical padding to add.\n        left: Horizontal padding to add.\n        height: Image height. Defaults to None.\n        width: Image width. Defaults to None.\n        bounding_box_format: The format of the input bounding boxes. 
Defaults to\n            `\"xyxy\"`.\n\n    Returns:\n        Padded bounding boxes in the original format.\n    \"\"\"\n    if bounding_box_format != \"xyxy\":\n        raise NotImplementedError\n    box_utils = BoundingBox()\n    # Switch to tensorflow backend if we are in tf.data pipe\n    if backend_utils.in_tf_graph():\n        box_utils.backend.set_backend(\"tensorflow\")\n    outputs = box_utils.pad(boxes, top, left)\n    box_utils.backend.reset()\n    return outputs\n\n\n@keras_export(\"keras.utils.bounding_boxes.encode_box_to_deltas\")\ndef encode_box_to_deltas(\n    anchors,\n    boxes,\n    anchor_format,\n    box_format,\n    encoding_format=\"center_yxhw\",\n    variance=None,\n    image_shape=None,\n):\n    \"\"\"Encodes bounding boxes relative to anchors as deltas.\n\n    This function calculates the deltas that represent the difference between\n    bounding boxes and provided anchors. Deltas encode the offsets and scaling\n    factors to apply to anchors to obtain the target boxes.\n\n    Boxes and anchors are first converted to the specified `encoding_format`\n    (defaulting to `center_yxhw`) for consistent delta representation.\n\n    Args:\n        anchors: `Tensors`. Anchor boxes with shape `(N, 4)`, where N is the\n            number of anchors.\n        boxes: `Tensors`. Bounding boxes to encode. Boxes can be of shape\n            `(B, N, 4)` or `(N, 4)`.\n        anchor_format: str. The format of the input `anchors`\n            (e.g., \"xyxy\", \"xywh\", etc.).\n        box_format: str. The format of the input `boxes`\n            (e.g., \"xyxy\", \"xywh\", etc.).\n        encoding_format: str. The intermediate format to which boxes and anchors\n            are converted before delta calculation. Defaults to \"center_yxhw\".\n        variance: `List[float]`. A 4-element array/tensor representing variance\n            factors to scale the box deltas. If provided, the calculated deltas\n            are divided by the variance. 
Defaults to None.\n        image_shape: `Tuple[int]`. The shape of the image (height, width, 3).\n            When using relative bounding box format for `box_format` the\n            `image_shape` is used for normalization.\n\n    Returns:\n        Encoded box deltas. The return type matches the `encoding_format`.\n\n    Raises:\n        ValueError: If `variance` is not None and its length is not 4.\n        ValueError: If `encoding_format` is not `\"center_xywh\"` or\n            `\"center_yxhw\"`.\n\n    \"\"\"\n    if variance is not None:\n        variance = ops.convert_to_tensor(variance, \"float32\")\n        var_len = variance.shape[-1]\n\n        if var_len != 4:\n            raise ValueError(f\"`variance` must be length 4, got {variance}\")\n\n    if encoding_format not in [\"center_xywh\", \"center_yxhw\"]:\n        raise ValueError(\n            \"`encoding_format` should be one of 'center_xywh' or \"\n            f\"'center_yxhw', got {encoding_format}\"\n        )\n\n    if image_shape is None:\n        height, width = None, None\n    else:\n        height, width, _ = image_shape\n\n    encoded_anchors = convert_format(\n        anchors,\n        source=anchor_format,\n        target=encoding_format,\n        height=height,\n        width=width,\n    )\n    boxes = convert_format(\n        boxes,\n        source=box_format,\n        target=encoding_format,\n        height=height,\n        width=width,\n    )\n    anchor_dimensions = ops.maximum(encoded_anchors[..., 2:], backend.epsilon())\n    box_dimensions = ops.maximum(boxes[..., 2:], backend.epsilon())\n    # Anchors are unbatched; boxes can be either batched or unbatched.\n    boxes_delta = ops.concatenate(\n        [\n            (boxes[..., :2] - encoded_anchors[..., :2]) / anchor_dimensions,\n            ops.log(box_dimensions / anchor_dimensions),\n        ],\n        axis=-1,\n    )\n    if variance is not None:\n        boxes_delta /= variance\n    return 
boxes_delta\n\n\n@keras_export(\"keras.utils.bounding_boxes.decode_deltas_to_boxes\")\ndef decode_deltas_to_boxes(\n    anchors,\n    boxes_delta,\n    anchor_format,\n    box_format,\n    encoded_format=\"center_yxhw\",\n    variance=None,\n    image_shape=None,\n):\n    \"\"\"Converts bounding boxes from delta format to the specified `box_format`.\n\n    This function decodes bounding box deltas relative to anchors to obtain the\n    final bounding box coordinates. The boxes are encoded in a specific\n    `encoded_format` (center_yxhw by default) during the decoding process.\n    This allows flexibility in how the deltas are applied to the anchors.\n\n    Args:\n        anchors: Can be `Tensors` or `Dict[Tensors]` where keys are level\n            indices and values are corresponding anchor boxes.\n            The shape of the array/tensor should be `(N, 4)` where N is the\n            number of anchors.\n        boxes_delta: Can be `Tensors` or `Dict[Tensors]`. Bounding box deltas\n            must have the same type and structure as `anchors`. The\n            shape of the array/tensor can be `(N, 4)` or `(B, N, 4)` where N is\n            the number of boxes.\n        anchor_format: str. The format of the input `anchors`.\n            (e.g., `\"xyxy\"`, `\"xywh\"`, etc.)\n        box_format: str. The desired format for the output boxes.\n            (e.g., `\"xyxy\"`, `\"xywh\"`, etc.)\n        encoded_format: str. Raw output format from regression head. Defaults\n            to `\"center_yxhw\"`.\n        variance: `List[float]`. A 4-element array/tensor representing\n            variance factors to scale the box deltas. If provided, the deltas\n            are multiplied by the variance before being applied to the anchors.\n            Defaults to None.\n        image_shape: `Tuple[int]`. 
The shape of the image (height, width, 3).\n            When using relative bounding box format for `box_format` the\n            `image_shape` is used for normalization.\n\n    Returns:\n        Decoded box coordinates. The return type matches the `box_format`.\n\n    Raises:\n        ValueError: If `variance` is not None and its length is not 4.\n        ValueError: If `encoded_format` is not `\"center_xywh\"` or\n            `\"center_yxhw\"`.\n\n    \"\"\"\n    if variance is not None:\n        variance = ops.convert_to_tensor(variance, \"float32\")\n        var_len = variance.shape[-1]\n\n        if var_len != 4:\n            raise ValueError(f\"`variance` must be length 4, got {variance}\")\n\n    if encoded_format not in [\"center_xywh\", \"center_yxhw\"]:\n        raise ValueError(\n            f\"`encoded_format` should be 'center_xywh' or 'center_yxhw', \"\n            f\"but got '{encoded_format}'.\"\n        )\n\n    if image_shape is None:\n        height, width = None, None\n    else:\n        height, width, _ = image_shape\n\n    def decode_single_level(anchor, box_delta):\n        encoded_anchor = convert_format(\n            anchor,\n            source=anchor_format,\n            target=encoded_format,\n            height=height,\n            width=width,\n        )\n        if variance is not None:\n            box_delta = box_delta * variance\n        # Anchors are unbatched; boxes can be either batched or unbatched.\n        box = ops.concatenate(\n            [\n                box_delta[..., :2] * encoded_anchor[..., 2:]\n                + encoded_anchor[..., :2],\n                ops.exp(box_delta[..., 2:]) * encoded_anchor[..., 2:],\n            ],\n            axis=-1,\n        )\n        box = convert_format(\n            box,\n            source=encoded_format,\n            target=box_format,\n            height=height,\n            width=width,\n        )\n        return box\n\n    if isinstance(anchors, dict) and isinstance(boxes_delta, 
dict):\n        boxes = {}\n        for lvl, anchor in anchors.items():\n            boxes[lvl] = decode_single_level(anchor, boxes_delta[lvl])\n        return boxes\n    else:\n        return decode_single_level(anchors, boxes_delta)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/converters_test.py",
    "content": "import itertools\n\nimport numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    affine_transform,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    clip_to_image_size,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    convert_format,\n)\n\n\nclass ConvertersTest(testing.TestCase):\n    def setUp(self):\n        xyxy_box = np.array(\n            [[[10, 20, 110, 120], [20, 30, 120, 130]]], dtype=\"float32\"\n        )\n        yxyx_box = np.array(\n            [[[20, 10, 120, 110], [30, 20, 130, 120]]], dtype=\"float32\"\n        )\n        rel_xyxy_box = np.array(\n            [[[0.01, 0.02, 0.11, 0.12], [0.02, 0.03, 0.12, 0.13]]],\n            dtype=\"float32\",\n        )\n        rel_yxyx_box = np.array(\n            [[[0.02, 0.01, 0.12, 0.11], [0.03, 0.02, 0.13, 0.12]]],\n            dtype=\"float32\",\n        )\n        center_xywh_box = np.array(\n            [[[60, 70, 100, 100], [70, 80, 100, 100]]], dtype=\"float32\"\n        )\n        center_yxhw_box = np.array(\n            [[[70, 60, 100, 100], [80, 70, 100, 100]]], dtype=\"float32\"\n        )\n        xywh_box = np.array(\n            [[[10, 20, 100, 100], [20, 30, 100, 100]]], dtype=\"float32\"\n        )\n        rel_xywh_box = np.array(\n            [[[0.01, 0.02, 0.1, 0.1], [0.02, 0.03, 0.1, 0.1]]], dtype=\"float32\"\n        )\n\n        self.images = np.ones([2, 1000, 1000, 3], dtype=\"float32\")\n        self.height = 1000\n        self.width = 1000\n\n        self.boxes = {\n            \"xyxy\": xyxy_box,\n            \"center_xywh\": center_xywh_box,\n            \"rel_xywh\": rel_xywh_box,\n            \"xywh\": xywh_box,\n            \"rel_xyxy\": rel_xyxy_box,\n            
\"yxyx\": yxyx_box,\n            \"rel_yxyx\": rel_yxyx_box,\n            \"center_yxhw\": center_yxhw_box,\n        }\n\n    @parameterized.named_parameters(\n        *[\n            (f\"{source}_{target}\", source, target)\n            for (source, target) in itertools.permutations(\n                [\n                    \"xyxy\",\n                    \"yxyx\",\n                    \"xywh\",\n                    \"rel_xyxy\",\n                    \"rel_yxyx\",\n                    \"center_xywh\",\n                    \"center_yxhw\",\n                ],\n                2,\n            )\n        ]\n        + [(\"xyxy_xyxy\", \"xyxy\", \"xyxy\")]\n    )\n    def test_convert_all_formats(self, source, target):\n        source_box = self.boxes[source]\n        target_box = self.boxes[target]\n        self.assertAllClose(\n            convert_format(\n                source_box,\n                source=source,\n                target=target,\n                height=self.height,\n                width=self.width,\n            ),\n            target_box,\n        )\n\n    def test_convert_format_invalid_source(self):\n        boxes = self.boxes[\"xywh\"]\n        with self.assertRaises(ValueError):\n            convert_format(boxes, source=\"invalid\", target=\"xywh\")\n\n    def test_convert_format_invalid_target(self):\n        boxes = self.boxes[\"xyxy\"]\n        with self.assertRaises(ValueError):\n            convert_format(boxes, source=\"xyxy\", target=\"invalid\")\n\n    def test_convert_format_missing_dimensions(self):\n        boxes = self.boxes[\"xyxy\"]\n        with self.assertRaisesRegex(\n            ValueError, r\"must receive `height` and `width`\"\n        ):\n            convert_format(boxes, source=\"xyxy\", target=\"rel_xyxy\")\n\n    def test_clip_to_image_size(self):\n        boxes = {\n            \"boxes\": np.array([[0.0, 0.0, 1.5, 1.6], [0.5, 0.4, 0.7, 0.8]]),\n            \"labels\": np.array([0, 1]),\n        }\n\n        
expected_clipped = {\n            \"boxes\": np.array([[0.0, 0.0, 1.0, 1.0], [0.5, 0.4, 0.7, 0.8]]),\n            \"labels\": np.array([0, 1]),\n        }\n\n        clipped_boxes = clip_to_image_size(\n            boxes, bounding_box_format=\"rel_xyxy\"\n        )\n\n        self.assertAllEqual(clipped_boxes, expected_clipped)\n\n    def test_affine_identity(self):\n        # Test identity transform (no change)\n        batch_size = self.boxes[\"xyxy\"].shape[0]\n        transformed_boxes = affine_transform(\n            boxes=self.boxes[\"xyxy\"],\n            angle=np.zeros([batch_size], dtype=\"float32\"),\n            translate_x=np.zeros([batch_size], dtype=\"float32\"),\n            translate_y=np.zeros([batch_size], dtype=\"float32\"),\n            scale=np.ones([batch_size], dtype=\"float32\"),\n            shear_x=np.zeros([batch_size], dtype=\"float32\"),\n            shear_y=np.zeros([batch_size], dtype=\"float32\"),\n            height=self.height,\n            width=self.width,\n        )\n        transformed_boxes = ops.convert_to_numpy(transformed_boxes)\n        self.assertAllClose(self.boxes[\"xyxy\"], transformed_boxes)\n"
  },
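The fixtures in `converters_test.py` encode the same two boxes in eight formats. The relationship between the `"xyxy"` and `"xywh"` fixtures can be sketched with plain NumPy (an illustrative helper, not the Keras `convert_format` API):

```python
import numpy as np

def xyxy_to_xywh(boxes):
    # Illustrative conversion (not the Keras converter): turns corner
    # coordinates [left, top, right, bottom] into [left, top, width, height].
    x1, y1, x2, y2 = np.split(boxes, 4, axis=-1)
    return np.concatenate([x1, y1, x2 - x1, y2 - y1], axis=-1)

xyxy = np.array([[[10, 20, 110, 120], [20, 30, 120, 130]]], dtype="float32")
xywh = xyxy_to_xywh(xyxy)
# Matches the xywh_box fixture: [[[10, 20, 100, 100], [20, 30, 100, 100]]]
```

The relative (`rel_*`) variants divide these absolute values by the 1000x1000 image size used in `setUp`, which is why the `rel_xyxy` fixture is exactly the `xyxy` fixture divided by 1000.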
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/formats.py",
    "content": "class XYXY:\n    \"\"\"XYXY contains axis indices for the XYXY format.\n\n    All values in the XYXY format should be absolute pixel values.\n\n    The XYXY format consists of the following required indices:\n\n    - LEFT: left of the bounding box\n    - TOP: top of the bounding box\n    - RIGHT: right of the bounding box\n    - BOTTOM: bottom of the bounding box\n    \"\"\"\n\n    LEFT = 0\n    TOP = 1\n    RIGHT = 2\n    BOTTOM = 3\n\n\nclass REL_XYXY:\n    \"\"\"REL_XYXY contains axis indices for the REL_XYXY format.\n\n    REL_XYXY is like XYXY, but each value is relative to the width and height of\n    the origin image. Values are percentages of the origin images' width and\n    height respectively.\n\n    The REL_XYXY format consists of the following required indices:\n\n    - LEFT: left of the bounding box\n    - TOP: top of the bounding box\n    - RIGHT: right of the bounding box\n    - BOTTOM: bottom of the bounding box\n    \"\"\"\n\n    LEFT = 0\n    TOP = 1\n    RIGHT = 2\n    BOTTOM = 3\n\n\nclass CENTER_XYWH:\n    \"\"\"CENTER_XYWH contains axis indices for the CENTER_XYWH format.\n\n    All values in the CENTER_XYWH format should be absolute pixel values.\n\n    The CENTER_XYWH format consists of the following required indices:\n\n    - X: X coordinate of the center of the bounding box\n    - Y: Y coordinate of the center of the bounding box\n    - WIDTH: width of the bounding box\n    - HEIGHT: height of the bounding box\n    \"\"\"\n\n    X = 0\n    Y = 1\n    WIDTH = 2\n    HEIGHT = 3\n\n\nclass XYWH:\n    \"\"\"XYWH contains axis indices for the XYWH format.\n\n    All values in the XYWH format should be absolute pixel values.\n\n    The XYWH format consists of the following required indices:\n\n    - X: X coordinate of the left of the bounding box\n    - Y: Y coordinate of the top of the bounding box\n    - WIDTH: width of the bounding box\n    - HEIGHT: height of the bounding box\n    \"\"\"\n\n    X = 0\n    Y = 1\n    WIDTH = 
2\n    HEIGHT = 3\n\n\nclass REL_XYWH:\n    \"\"\"REL_XYWH contains axis indices for the REL_XYWH format.\n\n    REL_XYWH is like XYWH, but each value is relative to the width and height of\n    the origin image. Values are percentages of the origin images' width and\n    height respectively.\n\n    The REL_XYWH format consists of the following required indices:\n\n    - X: X coordinate of the left of the bounding box\n    - Y: Y coordinate of the top of the bounding box\n    - WIDTH: width of the bounding box\n    - HEIGHT: height of the bounding box\n    \"\"\"\n\n    X = 0\n    Y = 1\n    WIDTH = 2\n    HEIGHT = 3\n\n\nclass YXYX:\n    \"\"\"YXYX contains axis indices for the YXYX format.\n\n    All values in the YXYX format should be absolute pixel values.\n\n    The YXYX format consists of the following required indices:\n\n    - TOP: top of the bounding box\n    - LEFT: left of the bounding box\n    - BOTTOM: bottom of the bounding box\n    - RIGHT: right of the bounding box\n    \"\"\"\n\n    TOP = 0\n    LEFT = 1\n    BOTTOM = 2\n    RIGHT = 3\n\n\nclass REL_YXYX:\n    \"\"\"REL_YXYX contains axis indices for the REL_YXYX format.\n\n    REL_YXYX is like YXYX, but each value is relative to the width and height of\n    the origin image. Values are percentages of the origin images' width and\n    height respectively.\n\n    The REL_YXYX format consists of the following required indices:\n\n    - TOP: top of the bounding box\n    - LEFT: left of the bounding box\n    - BOTTOM: bottom of the bounding box\n    - RIGHT: right of the bounding box\n    \"\"\"\n\n    TOP = 0\n    LEFT = 1\n    BOTTOM = 2\n    RIGHT = 3\n"
  },
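The classes above exist so downstream code can index box tensors by name instead of magic numbers. A minimal sketch of that usage pattern (the class body mirrors `XYXY` from this file; the surrounding code is hypothetical):

```python
import numpy as np

class XYXY:
    # Same axis-index convention as formats.py.
    LEFT = 0
    TOP = 1
    RIGHT = 2
    BOTTOM = 3

boxes = np.array([[10.0, 20.0, 110.0, 120.0]])
# Named indices make the intent clearer than boxes[..., 2] - boxes[..., 0].
widths = boxes[..., XYXY.RIGHT] - boxes[..., XYXY.LEFT]
heights = boxes[..., XYXY.BOTTOM] - boxes[..., XYXY.TOP]
```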
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/iou.py",
    "content": "\"\"\"Contains functions to compute ious of bounding boxes.\"\"\"\n\nimport math\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes import (\n    converters,\n)\n\n\ndef _compute_area(box):\n    \"\"\"Computes area for bounding boxes\n\n    Args:\n      box: [N, 4] or [batch_size, N, 4] float Tensor, either batched\n        or unbatched boxes.\n    Returns:\n      a float Tensor of [N] or [batch_size, N]\n    \"\"\"\n    y_min, x_min, y_max, x_max = ops.split(box[..., :4], 4, axis=-1)\n    return ops.squeeze((y_max - y_min) * (x_max - x_min), axis=-1)\n\n\ndef _compute_intersection(boxes1, boxes2):\n    \"\"\"Computes intersection area between two sets of boxes.\n\n    Args:\n      boxes1: [N, 4] or [batch_size, N, 4] float Tensor boxes.\n      boxes2: [M, 4] or [batch_size, M, 4] float Tensor boxes.\n    Returns:\n      a [N, M] or [batch_size, N, M] float Tensor.\n    \"\"\"\n    y_min1, x_min1, y_max1, x_max1 = ops.split(boxes1[..., :4], 4, axis=-1)\n    y_min2, x_min2, y_max2, x_max2 = ops.split(boxes2[..., :4], 4, axis=-1)\n    boxes2_rank = len(boxes2.shape)\n    perm = [1, 0] if boxes2_rank == 2 else [0, 2, 1]\n    # [N, M] or [batch_size, N, M]\n    intersect_ymax = ops.minimum(y_max1, ops.transpose(y_max2, perm))\n    intersect_ymin = ops.maximum(y_min1, ops.transpose(y_min2, perm))\n    intersect_xmax = ops.minimum(x_max1, ops.transpose(x_max2, perm))\n    intersect_xmin = ops.maximum(x_min1, ops.transpose(x_min2, perm))\n\n    intersect_height = intersect_ymax - intersect_ymin\n    intersect_width = intersect_xmax - intersect_xmin\n    zeros_t = ops.cast(0, intersect_height.dtype)\n    intersect_height = ops.maximum(zeros_t, intersect_height)\n    intersect_width = ops.maximum(zeros_t, intersect_width)\n\n    return intersect_height * 
intersect_width\n\n\n@keras_export(\"keras.utils.bounding_boxes.compute_iou\")\ndef compute_iou(\n    boxes1,\n    boxes2,\n    bounding_box_format,\n    use_masking=False,\n    mask_val=-1,\n    image_shape=None,\n):\n    \"\"\"Computes a lookup table vector containing the IoUs for a given set of boxes.\n\n    The lookup vector is to be indexed by [`boxes1_index`,`boxes2_index`] if\n    boxes are unbatched and by [`batch`, `boxes1_index`,`boxes2_index`] if the\n    boxes are batched.\n\n    `boxes1` and `boxes2` can have different ranks. For example:\n    1) `boxes1`: [batch_size, M, 4], `boxes2`: [batch_size, N, 4] -> return\n        [batch_size, M, N].\n    2) `boxes1`: [batch_size, M, 4], `boxes2`: [N, 4] -> return\n        [batch_size, M, N]\n    3) `boxes1`: [M, 4], `boxes2`: [batch_size, N, 4] -> return\n        [batch_size, M, N]\n    4) `boxes1`: [M, 4], `boxes2`: [N, 4] -> return [M, N]\n\n    Args:\n        boxes1: a list of bounding boxes in 'corners' format. Can be batched or\n            unbatched.\n        boxes2: a list of bounding boxes in 'corners' format. Can be batched or\n            unbatched.\n        bounding_box_format: a case-insensitive string which is one of `\"xyxy\"`,\n            `\"rel_xyxy\"`, `\"xywh\"`, `\"center_xywh\"`, `\"yxyx\"`, `\"rel_yxyx\"`.\n            For detailed information on the supported formats, see the\n            [KerasCV bounding box\n            documentation](https://keras.io/api/keras_cv/bounding_box/formats/).\n        use_masking: whether masking will be applied. This will mask all\n            `boxes1` or `boxes2` that have values less than 0 in all their 4\n            dimensions. Defaults to `False`.\n        mask_val: int used to mask the returned IoUs when `use_masking` is\n            `True`. Defaults to `-1`.\n        image_shape: `Tuple[int]`. 
The shape of the image (height, width, 3).\n            When using a relative bounding box format for\n            `bounding_box_format`, the `image_shape` is used for normalization.\n\n    Returns:\n        iou_lookup_table: a vector containing the pairwise IoUs of boxes1 and\n            boxes2.\n    \"\"\"  # noqa: E501\n\n    boxes1_rank = len(ops.shape(boxes1))\n    boxes2_rank = len(ops.shape(boxes2))\n\n    if boxes1_rank not in [2, 3]:\n        raise ValueError(\n            \"compute_iou() expects boxes1 to be batched, or to be unbatched. \"\n            f\"Received len(boxes1.shape)={boxes1_rank}, \"\n            f\"len(boxes2.shape)={boxes2_rank}. Expected either \"\n            \"len(boxes1.shape)=2 or len(boxes1.shape)=3.\"\n        )\n    if boxes2_rank not in [2, 3]:\n        raise ValueError(\n            \"compute_iou() expects boxes2 to be batched, or to be unbatched. \"\n            f\"Received len(boxes1.shape)={boxes1_rank}, \"\n            f\"len(boxes2.shape)={boxes2_rank}. Expected either \"\n            \"len(boxes2.shape)=2 or len(boxes2.shape)=3.\"\n        )\n\n    target_format = \"yxyx\"\n    if \"rel\" in bounding_box_format and image_shape is None:\n        raise ValueError(\n            \"When using relative bounding box formats (e.g. 
`rel_yxyx`) \"\n            \"the `image_shape` argument must be provided. \"\n            f\"Received `image_shape`: {image_shape}\"\n        )\n\n    if image_shape is None:\n        height, width = None, None\n    else:\n        height, width, _ = image_shape\n\n    boxes1 = converters.convert_format(\n        boxes1,\n        source=bounding_box_format,\n        target=target_format,\n        height=height,\n        width=width,\n    )\n\n    boxes2 = converters.convert_format(\n        boxes2,\n        source=bounding_box_format,\n        target=target_format,\n        height=height,\n        width=width,\n    )\n\n    intersect_area = _compute_intersection(boxes1, boxes2)\n    boxes1_area = _compute_area(boxes1)\n    boxes2_area = _compute_area(boxes2)\n    boxes2_area_rank = len(boxes2_area.shape)\n    boxes2_axis = 1 if (boxes2_area_rank == 2) else 0\n    boxes1_area = ops.expand_dims(boxes1_area, axis=-1)\n    boxes2_area = ops.expand_dims(boxes2_area, axis=boxes2_axis)\n    union_area = boxes1_area + boxes2_area - intersect_area\n    res = ops.divide(intersect_area, union_area + backend.epsilon())\n\n    if boxes1_rank == 2:\n        perm = [1, 0]\n    else:\n        perm = [0, 2, 1]\n\n    if not use_masking:\n        return res\n\n    mask_val_t = ops.cast(mask_val, res.dtype) * ops.ones_like(res)\n    boxes1_mask = ops.less(ops.max(boxes1, axis=-1, keepdims=True), 0.0)\n    boxes2_mask = ops.less(ops.max(boxes2, axis=-1, keepdims=True), 0.0)\n    background_mask = ops.logical_or(\n        boxes1_mask, ops.transpose(boxes2_mask, perm)\n    )\n    iou_lookup_table = ops.where(background_mask, mask_val_t, res)\n    return iou_lookup_table\n\n\n@keras_export(\"keras.utils.bounding_boxes.compute_ciou\")\ndef compute_ciou(boxes1, boxes2, bounding_box_format, image_shape=None):\n    \"\"\"\n    Computes the Complete IoU (CIoU) between two bounding boxes or between\n    two batches of bounding boxes.\n\n    CIoU loss is an extension of GIoU loss, which further 
improves the IoU\n    optimization for object detection. CIoU loss not only penalizes the\n    bounding box coordinates but also considers the aspect ratio and center\n    distance of the boxes. The length of the last dimension should be 4 to\n    represent the bounding boxes.\n\n    Args:\n        boxes1 (tensor): tensor representing the first bounding box with\n            shape (..., 4).\n        boxes2 (tensor): tensor representing the second bounding box with\n            shape (..., 4).\n        bounding_box_format: a case-insensitive string (for example, \"xyxy\").\n            Each bounding box is defined by these 4 values. For detailed\n            information on the supported formats, see the [KerasCV bounding box\n            documentation](https://keras.io/api/keras_cv/bounding_box/formats/).\n        image_shape: `Tuple[int]`. The shape of the image (height, width, 3).\n            When using a relative bounding box format for\n            `bounding_box_format`, the `image_shape` is used for normalization.\n\n    Returns:\n        tensor: The CIoU distance between the two bounding boxes.\n    \"\"\"\n    target_format = \"xyxy\"\n    if \"rel\" in bounding_box_format and image_shape is None:\n        raise ValueError(\n            \"When using relative bounding box formats (e.g. 
`rel_yxyx`) \"\n            \"the `image_shape` argument must be provided. \"\n            f\"Received `image_shape`: {image_shape}\"\n        )\n\n    if image_shape is None:\n        height, width = None, None\n    else:\n        height, width, _ = image_shape\n\n    boxes1 = converters.convert_format(\n        boxes1,\n        source=bounding_box_format,\n        target=target_format,\n        height=height,\n        width=width,\n    )\n\n    boxes2 = converters.convert_format(\n        boxes2,\n        source=bounding_box_format,\n        target=target_format,\n        height=height,\n        width=width,\n    )\n\n    x_min1, y_min1, x_max1, y_max1 = ops.split(boxes1[..., :4], 4, axis=-1)\n    x_min2, y_min2, x_max2, y_max2 = ops.split(boxes2[..., :4], 4, axis=-1)\n\n    width_1 = x_max1 - x_min1\n    height_1 = y_max1 - y_min1 + keras.backend.epsilon()\n    width_2 = x_max2 - x_min2\n    height_2 = y_max2 - y_min2 + keras.backend.epsilon()\n\n    intersection_area = ops.maximum(\n        ops.minimum(x_max1, x_max2) - ops.maximum(x_min1, x_min2), 0\n    ) * ops.maximum(\n        ops.minimum(y_max1, y_max2) - ops.maximum(y_min1, y_min2), 0\n    )\n    union_area = (\n        width_1 * height_1\n        + width_2 * height_2\n        - intersection_area\n        + keras.backend.epsilon()\n    )\n    iou = ops.squeeze(\n        ops.divide(intersection_area, union_area + keras.backend.epsilon()),\n        axis=-1,\n    )\n\n    convex_width = ops.maximum(x_max1, x_max2) - ops.minimum(x_min1, x_min2)\n    convex_height = ops.maximum(y_max1, y_max2) - ops.minimum(y_min1, y_min2)\n    convex_diagonal_squared = ops.squeeze(\n        convex_width**2 + convex_height**2 + keras.backend.epsilon(),\n        axis=-1,\n    )\n    centers_distance_squared = ops.squeeze(\n        ((x_min1 + x_max1) / 2 - (x_min2 + x_max2) / 2) ** 2\n        + ((y_min1 + y_max1) / 2 - (y_min2 + y_max2) / 2) ** 2,\n        axis=-1,\n    )\n\n    v = ops.squeeze(\n        (4 / math.pi**2)\n        
* ops.power(\n            (ops.arctan(width_2 / height_2) - ops.arctan(width_1 / height_1)),\n            2,\n        ),\n        axis=-1,\n    )\n    alpha = v / (v - iou + (1 + keras.backend.epsilon()))\n\n    return iou - (\n        centers_distance_squared / convex_diagonal_squared + v * alpha\n    )\n"
  },
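The docstring of `compute_iou` describes an `[N, M]` lookup table built from broadcasted intersections and unions. The unbatched case can be re-derived in plain NumPy (a sketch of the math, not the Keras implementation):

```python
import numpy as np

def pairwise_iou(boxes1, boxes2):
    # Illustrative NumPy sketch (not the Keras implementation) of the
    # [N, M] IoU lookup table for unbatched "yxyx" boxes.
    y1a, x1a, y2a, x2a = np.split(boxes1, 4, axis=-1)  # each [N, 1]
    y1b, x1b, y2b, x2b = np.split(boxes2, 4, axis=-1)  # each [M, 1]
    # Broadcast [N, 1] against [1, M] to get all pairwise overlaps.
    inter_h = np.maximum(0.0, np.minimum(y2a, y2b.T) - np.maximum(y1a, y1b.T))
    inter_w = np.maximum(0.0, np.minimum(x2a, x2b.T) - np.maximum(x1a, x1b.T))
    inter = inter_h * inter_w                          # [N, M]
    area1 = (y2a - y1a) * (x2a - x1a)                  # [N, 1]
    area2 = (y2b - y1b) * (x2b - x1b)                  # [M, 1]
    return inter / (area1 + area2.T - inter)

# Same off-by-one-pixel pair used in iou_test.py: intersection is 99 * 99,
# each area is 10000, so IoU = 9801 / (2 * 10000 - 9801).
bb = np.array([[100.0, 101.0, 200.0, 201.0]])
bb_off = np.array([[101.0, 102.0, 201.0, 202.0]])
iou = pairwise_iou(bb, bb_off)  # [1, 1] table
```

Clamping the intersection height and width at zero, as the Keras code does with `ops.maximum(zeros_t, ...)`, is what makes disjoint boxes produce an IoU of exactly 0 rather than a negative value.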
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/iou_test.py",
    "content": "\"\"\"Tests for iou functions.\"\"\"\n\nimport numpy as np\n\nfrom keras.src import testing\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes import (\n    iou as iou_lib,\n)\n\n\nclass IoUTest(testing.TestCase):\n    def test_compute_single_iou(self):\n        bb1 = np.array([[100, 101, 200, 201]])\n        bb1_off_by_1 = np.array([[101, 102, 201, 202]])\n        # area of bb1 and bb1_off_by_1 are each 10000.\n        # intersection area is 99*99=9801\n        # iou=9801/(2*10000 - 9801)=0.96097656633\n        self.assertAllClose(\n            iou_lib.compute_iou(bb1, bb1_off_by_1, \"yxyx\")[0], [0.96097656633]\n        )\n\n    def test_compute_iou(self):\n        bb1 = [100, 101, 200, 201]\n        bb1_off_by_1_pred = [101, 102, 201, 202]\n        iou_bb1_bb1_off = 0.96097656633\n        top_left_bounding_box = [0, 2, 1, 3]\n        far_away_box = [1300, 1400, 1500, 1401]\n        another_far_away_pred = [1000, 1400, 1200, 1401]\n\n        # Rows represent predictions, columns ground truths\n        expected_result = np.array(\n            [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],\n            dtype=np.float32,\n        )\n\n        sample_y_true = np.array([bb1, top_left_bounding_box, far_away_box])\n        sample_y_pred = np.array(\n            [bb1_off_by_1_pred, top_left_bounding_box, another_far_away_pred],\n        )\n\n        result = iou_lib.compute_iou(sample_y_true, sample_y_pred, \"yxyx\")\n        self.assertAllClose(expected_result, result)\n\n    def test_batched_compute_iou(self):\n        bb1 = [100, 101, 200, 201]\n        bb1_off_by_1_pred = [101, 102, 201, 202]\n        iou_bb1_bb1_off = 0.96097656633\n        top_left_bounding_box = [0, 2, 1, 3]\n        far_away_box = [1300, 1400, 1500, 1401]\n        another_far_away_pred = [1000, 1400, 1200, 1401]\n\n        # Rows represent predictions, columns ground truths\n        expected_result = np.array(\n            [\n                
[[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],\n                [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],\n            ],\n        )\n\n        sample_y_true = np.array(\n            [\n                [bb1, top_left_bounding_box, far_away_box],\n                [bb1, top_left_bounding_box, far_away_box],\n            ],\n        )\n        sample_y_pred = np.array(\n            [\n                [\n                    bb1_off_by_1_pred,\n                    top_left_bounding_box,\n                    another_far_away_pred,\n                ],\n                [\n                    bb1_off_by_1_pred,\n                    top_left_bounding_box,\n                    another_far_away_pred,\n                ],\n            ],\n        )\n\n        result = iou_lib.compute_iou(sample_y_true, sample_y_pred, \"yxyx\")\n        self.assertAllClose(expected_result, result)\n\n    def test_batched_boxes1_unbatched_boxes2(self):\n        bb1 = [100, 101, 200, 201]\n        bb1_off_by_1_pred = [101, 102, 201, 202]\n        iou_bb1_bb1_off = 0.96097656633\n        top_left_bounding_box = [0, 2, 1, 3]\n        far_away_box = [1300, 1400, 1500, 1401]\n        another_far_away_pred = [1000, 1400, 1200, 1401]\n\n        # Rows represent predictions, columns ground truths\n        expected_result = np.array(\n            [\n                [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],\n                [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],\n            ],\n        )\n\n        sample_y_true = np.array(\n            [\n                [bb1, top_left_bounding_box, far_away_box],\n                [bb1, top_left_bounding_box, far_away_box],\n            ],\n        )\n        sample_y_pred = np.array(\n            [bb1_off_by_1_pred, top_left_bounding_box, another_far_away_pred],\n        )\n\n        result = iou_lib.compute_iou(sample_y_true, sample_y_pred, \"yxyx\")\n        
self.assertAllClose(expected_result, result)\n\n    def test_unbatched_boxes1_batched_boxes2(self):\n        bb1 = [100, 101, 200, 201]\n        bb1_off_by_1_pred = [101, 102, 201, 202]\n        iou_bb1_bb1_off = 0.96097656633\n        top_left_bounding_box = [0, 2, 1, 3]\n        far_away_box = [1300, 1400, 1500, 1401]\n        another_far_away_pred = [1000, 1400, 1200, 1401]\n\n        # Rows represent predictions, columns ground truths\n        expected_result = np.array(\n            [\n                [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],\n                [[iou_bb1_bb1_off, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],\n            ],\n        )\n\n        sample_y_true = np.array(\n            [\n                [bb1, top_left_bounding_box, far_away_box],\n            ],\n        )\n        sample_y_pred = np.array(\n            [\n                [\n                    bb1_off_by_1_pred,\n                    top_left_bounding_box,\n                    another_far_away_pred,\n                ],\n                [\n                    bb1_off_by_1_pred,\n                    top_left_bounding_box,\n                    another_far_away_pred,\n                ],\n            ],\n        )\n\n        result = iou_lib.compute_iou(sample_y_true, sample_y_pred, \"yxyx\")\n        self.assertAllClose(expected_result, result)\n\n\nclass CIoUTest(testing.TestCase):\n    def test_compute_single_ciou(self):\n        bb1 = np.array([[100, 101, 200, 201]])\n        bb2 = np.array([[101, 102, 201, 202]])\n        self.assertAllClose(\n            iou_lib.compute_ciou(bb1, bb2, \"yxyx\")[0], [0.96087853672]\n        )\n\n    def test_compute_ciou(self):\n        bb1 = np.array([100, 101, 200, 201])\n        bb2 = np.array([150, 150, 250, 250])\n        ciou_bb1_bb2 = 0.036492417\n\n        # non overlapping case\n        far_away_bb1 = np.array([1000, 1000, 1500, 1500])\n        far_away_bb2 = np.array([2000, 2000, 2500, 2500])\n        
ciou_far_away_bb1_bb2 = -0.44444444435\n\n        sample_y_true = np.array([bb1, far_away_bb1])\n        sample_y_pred = np.array([bb2, far_away_bb2])\n\n        result = iou_lib.compute_ciou(sample_y_true, sample_y_pred, \"yxyx\")\n        self.assertAllClose(ciou_bb1_bb2, result[0])\n        self.assertAllClose(ciou_far_away_bb1_bb2, result[1])\n\n    def test_batched_compute_ciou(self):\n        bb1 = np.array([100, 101, 200, 201])\n        bb2 = np.array([150, 150, 250, 250])\n        ciou_bb1_bb2 = 0.036492417\n\n        # non overlapping case\n        far_away_bb1 = np.array([1000, 1000, 1500, 1500])\n        far_away_bb2 = np.array([2000, 2000, 2500, 2500])\n        ciou_far_away_bb1_bb2 = -0.44444444435\n\n        sample_y_true = np.array([[bb1, far_away_bb1], [bb1, far_away_bb1]])\n        sample_y_pred = np.array([[bb2, far_away_bb2], [bb2, far_away_bb2]])\n\n        result = iou_lib.compute_ciou(sample_y_true, sample_y_pred, \"yxyx\")\n        self.assertAllClose(ciou_bb1_bb2, result[0][0])\n        self.assertAllClose(ciou_bb1_bb2, result[1][0])\n        self.assertAllClose(ciou_far_away_bb1_bb2, result[0][1])\n        self.assertAllClose(ciou_far_away_bb1_bb2, result[1][1])\n\n    def test_batched_boxes1_unbatched_boxes2(self):\n        bb1 = np.array([100, 101, 200, 201])\n        bb2 = np.array([150, 150, 250, 250])\n        ciou_bb1_bb2 = 0.036492417\n\n        # non overlapping case\n        far_away_bb1 = np.array([1000, 1000, 1500, 1500])\n        far_away_bb2 = np.array([2000, 2000, 2500, 2500])\n        ciou_far_away_bb1_bb2 = -0.44444444435\n\n        sample_y_true = np.array([[bb1, far_away_bb1], [bb1, far_away_bb1]])\n        sample_y_pred = np.array([bb2, far_away_bb2])\n\n        result = iou_lib.compute_ciou(sample_y_true, sample_y_pred, \"yxyx\")\n        self.assertAllClose(ciou_bb1_bb2, result[0][0])\n        self.assertAllClose(ciou_bb1_bb2, result[1][0])\n        self.assertAllClose(ciou_far_away_bb1_bb2, result[0][1])\n        
self.assertAllClose(ciou_far_away_bb1_bb2, result[1][1])\n\n    def test_unbatched_boxes1_batched_boxes2(self):\n        bb1 = np.array([100, 101, 200, 201])\n        bb2 = np.array([150, 150, 250, 250])\n        ciou_bb1_bb2 = 0.036492417\n\n        # non overlapping case\n        far_away_bb1 = np.array([1000, 1000, 1500, 1500])\n        far_away_bb2 = np.array([2000, 2000, 2500, 2500])\n        ciou_far_away_bb1_bb2 = -0.44444444435\n\n        sample_y_true = np.array([bb1, far_away_bb1])\n        sample_y_pred = np.array([[bb2, far_away_bb2], [bb2, far_away_bb2]])\n\n        result = iou_lib.compute_ciou(sample_y_true, sample_y_pred, \"yxyx\")\n        self.assertAllClose(ciou_bb1_bb2, result[0][0])\n        self.assertAllClose(ciou_bb1_bb2, result[1][0])\n        self.assertAllClose(ciou_far_away_bb1_bb2, result[0][1])\n        self.assertAllClose(ciou_far_away_bb1_bb2, result[1][1])\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation.py",
"content": "from keras.src import backend as current_backend\nfrom keras.src.utils import tf_utils\n\n\ndef _classes_shape(batched, classes_shape, max_boxes):\n    if max_boxes is None:\n        return None\n    if batched:\n        return [None, max_boxes] + classes_shape[2:]\n    return [max_boxes] + classes_shape[1:]\n\n\ndef _box_shape(batched, boxes_shape, max_boxes):\n    # ensure we don't drop the final axis in RaggedTensor mode\n    if max_boxes is None:\n        shape = list(boxes_shape)\n        shape[-1] = 4\n        return shape\n    if batched:\n        return [None, max_boxes, 4]\n    return [max_boxes, 4]\n\n\ndef densify_bounding_boxes(\n    bounding_boxes,\n    is_batched=False,\n    max_boxes=None,\n    boxes_default_value=0,\n    labels_default_value=-1,\n    backend=None,\n):\n    validate_bounding_boxes(bounding_boxes)\n    boxes = bounding_boxes[\"boxes\"]\n    labels = bounding_boxes[\"labels\"]\n    backend = backend or current_backend\n    if isinstance(boxes, list):\n        if boxes and isinstance(boxes[0], list):\n            if boxes[0] and isinstance(boxes[0][0], list):\n                # Batched case\n                if not isinstance(labels[0][0], int):\n                    raise ValueError(\n                        \"If providing `bounding_boxes['labels']` as a list, \"\n                        \"it should contain integer labels. 
Received: \"\n                        f\"bounding_boxes['labels']={labels}\"\n                    )\n                if max_boxes is None:\n                    max_boxes = max([len(b) for b in boxes])\n                new_boxes = []\n                new_labels = []\n                for b, l in zip(boxes, labels):\n                    if len(b) >= max_boxes:\n                        new_boxes.append(b[:max_boxes])\n                        new_labels.append(l[:max_boxes])\n                    else:\n                        num_boxes_to_add = max_boxes - len(b)\n                        added_boxes = [\n                            [\n                                boxes_default_value,\n                                boxes_default_value,\n                                boxes_default_value,\n                                boxes_default_value,\n                            ]\n                            for _ in range(num_boxes_to_add)\n                        ]\n                        new_boxes.append(b + added_boxes)\n                        new_labels.append(\n                            l\n                            + [\n                                labels_default_value\n                                for _ in range(num_boxes_to_add)\n                            ]\n                        )\n            else:\n                # Unbatched case\n                if max_boxes is None or len(boxes) >= max_boxes:\n                    new_boxes = boxes[:max_boxes]\n                    new_labels = labels[:max_boxes]\n                else:\n                    num_boxes_to_add = max_boxes - len(boxes)\n                    added_boxes = [\n                        [\n                            boxes_default_value,\n                            boxes_default_value,\n                            boxes_default_value,\n                            boxes_default_value,\n                        ]\n                        for _ in range(num_boxes_to_add)\n                    ]\n                    new_boxes = boxes + added_boxes\n                    new_labels = labels + [\n                        labels_default_value for _ in range(num_boxes_to_add)\n                    ]\n            return {\n                \"boxes\": backend.convert_to_tensor(new_boxes, dtype=\"float32\"),\n                \"labels\": backend.convert_to_tensor(new_labels, dtype=\"int32\"),\n            }\n\n    if tf_utils.is_ragged_tensor(boxes):\n        bounding_boxes[\"boxes\"] = bounding_boxes[\"boxes\"].to_tensor(\n            default_value=boxes_default_value,\n            shape=_box_shape(\n                is_batched, bounding_boxes[\"boxes\"].shape, max_boxes\n            ),\n        )\n        bounding_boxes[\"labels\"] = bounding_boxes[\"labels\"].to_tensor(\n            default_value=labels_default_value,\n            shape=_classes_shape(\n                is_batched, bounding_boxes[\"labels\"].shape, max_boxes\n            ),\n        )\n        return bounding_boxes\n\n    bounding_boxes[\"boxes\"] = backend.convert_to_tensor(boxes, dtype=\"float32\")\n    bounding_boxes[\"labels\"] = backend.convert_to_tensor(labels)\n    return bounding_boxes\n\n\ndef validate_bounding_boxes(bounding_boxes):\n    if (\n        not isinstance(bounding_boxes, dict)\n        or \"labels\" not in bounding_boxes\n        or \"boxes\" not in bounding_boxes\n    ):\n        raise ValueError(\n            \"Expected `bounding_boxes` argument to be a \"\n            \"dict with keys 'boxes' and 'labels'. 
Received: \"\n            f\"bounding_boxes={bounding_boxes}\"\n        )\n    boxes = bounding_boxes[\"boxes\"]\n    labels = bounding_boxes[\"labels\"]\n    if isinstance(boxes, list):\n        if not isinstance(labels, list):\n            raise ValueError(\n                \"If `bounding_boxes['boxes']` is a list, then \"\n                \"`bounding_boxes['labels']` must also be a list. \"\n                f\"Received: bounding_boxes['labels']={labels}\"\n            )\n        if len(boxes) != len(labels):\n            raise ValueError(\n                \"If `bounding_boxes['boxes']` and \"\n                \"`bounding_boxes['labels']` are both lists, \"\n                \"they must have the same length. Received: \"\n                f\"len(bounding_boxes['boxes'])={len(boxes)} and \"\n                f\"len(bounding_boxes['labels'])={len(labels)}.\"\n            )\n    elif tf_utils.is_ragged_tensor(boxes):\n        if not tf_utils.is_ragged_tensor(labels):\n            raise ValueError(\n                \"If `bounding_boxes['boxes']` is a Ragged tensor, \"\n                \"`bounding_boxes['labels']` must also be a \"\n                \"Ragged tensor. 
\"\n                f\"Received: bounding_boxes['labels']={labels}\"\n            )\n    else:\n        boxes_shape = current_backend.shape(boxes)\n        labels_shape = current_backend.shape(labels)\n        if len(boxes_shape) == 2:  # (boxes, 4)\n            if len(labels_shape) not in {1, 2}:\n                raise ValueError(\n                    \"Found \"\n                    f\"bounding_boxes['boxes'].shape={boxes_shape} \"\n                    \"and expected bounding_boxes['labels'] to have \"\n                    \"rank 1 or 2, but received: \"\n                    f\"bounding_boxes['labels'].shape={labels_shape} \"\n                )\n        elif len(boxes_shape) == 3:\n            if len(labels_shape) not in {2, 3}:\n                raise ValueError(\n                    \"Found \"\n                    f\"bounding_boxes['boxes'].shape={boxes_shape} \"\n                    \"and expected bounding_boxes['labels'] to have \"\n                    \"rank 2 or 3, but received: \"\n                    f\"bounding_boxes['labels'].shape={labels_shape} \"\n                )\n        else:\n            raise ValueError(\n                \"Expected `bounding_boxes['boxes']` \"\n                \"to have rank 2 or 3, with shape \"\n                \"(num_boxes, 4) or (batch_size, num_boxes, 4). \"\n                \"Received: \"\n                f\"bounding_boxes['boxes'].shape={boxes_shape}\"\n            )\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/bounding_boxes/validation_test.py",
    "content": "import pytest\nimport tensorflow as tf\n\nfrom keras.src import backend\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes import (\n    validation,\n)\nfrom keras.src.testing import test_case\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"tensorflow\",\n    reason=\"The test targets TensorFlow-specific ragged tensors.\",\n)\nclass DensifyBoundingBoxesTest(test_case.TestCase):\n    def test_densify_ragged_bounding_boxes_batched(self):\n        ragged_boxes = tf.ragged.constant(\n            [\n                [[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]],\n                [[0.5, 0.5, 0.6, 0.6]],\n            ],\n            dtype=tf.float32,\n        )\n        ragged_labels = tf.ragged.constant(\n            [\n                [0, 1],\n                [2],\n            ],\n            dtype=tf.int32,\n        )\n        bounding_boxes = {\"boxes\": ragged_boxes, \"labels\": ragged_labels}\n        max_boxes = 3\n        densified_data = validation.densify_bounding_boxes(\n            bounding_boxes.copy(), is_batched=True, max_boxes=max_boxes\n        )\n        densified_boxes = densified_data[\"boxes\"]\n        densified_labels = densified_data[\"labels\"]\n        self.assertEqual(densified_boxes.shape, (2, max_boxes, 4))\n        self.assertEqual(densified_labels.shape, (2, max_boxes))\n        expected_boxes = [\n            [[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4], [0.0, 0.0, 0.0, 0.0]],\n            [[0.5, 0.5, 0.6, 0.6], [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]],\n        ]\n        expected_labels = [\n            [0, 1, -1],\n            [2, -1, -1],\n        ]\n        self.assertAllClose(densified_boxes, expected_boxes)\n        self.assertAllEqual(densified_labels, expected_labels)\n\n    def test_densify_ragged_bounding_boxes_unbatched(self):\n        ragged_boxes = tf.ragged.constant(\n            [[0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4]],\n            dtype=tf.float32,\n        )\n        
ragged_labels = tf.ragged.constant([[0], [1]], dtype=tf.int32)\n        bounding_boxes = {\"boxes\": ragged_boxes, \"labels\": ragged_labels}\n        max_boxes = 4\n        densified_data = validation.densify_bounding_boxes(\n            bounding_boxes.copy(), is_batched=False, max_boxes=max_boxes\n        )\n        densified_boxes = densified_data[\"boxes\"]\n        densified_labels = densified_data[\"labels\"]\n\n        self.assertEqual(densified_boxes.shape, (max_boxes, 4))\n        self.assertEqual(densified_labels.shape, (max_boxes, 1))\n        expected_boxes = [\n            [0.1, 0.1, 0.2, 0.2],\n            [0.3, 0.3, 0.4, 0.4],\n            [0.0, 0.0, 0.0, 0.0],\n            [0.0, 0.0, 0.0, 0.0],\n        ]\n        expected_labels = [[0], [1], [-1], [-1]]\n        self.assertAllClose(densified_boxes, expected_boxes)\n        self.assertAllEqual(densified_labels, expected_labels)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/center_crop.py",
"content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    clip_to_image_size,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    convert_format,\n)\nfrom keras.src.utils import image_utils\n\n\n@keras_export(\"keras.layers.CenterCrop\")\nclass CenterCrop(BaseImagePreprocessingLayer):\n    \"\"\"A preprocessing layer which crops images.\n\n    This layer crops the central portion of the images to a target size. If an\n    image is smaller than the target size, it will be resized and cropped\n    so as to return the largest possible window in the image that matches\n    the target aspect ratio.\n\n    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`).\n\n    Input shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format,\n        or `(..., channels, height, width)`, in `\"channels_first\"` format.\n\n    Output shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., target_height, target_width, channels)`,\n        or `(..., channels, target_height, target_width)`,\n        in `\"channels_first\"` format.\n\n    If the input height/width is even and the target height/width is odd (or\n    inversely), the input image is left-padded by 1 pixel.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        height: Integer, the height of the output shape.\n        width: Integer, the width of the output shape.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the 
dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`. It defaults to the\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. If you never set it, then it will be\n            `\"channels_last\"`.\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n\n    def __init__(self, height, width, data_format=None, **kwargs):\n        super().__init__(data_format=data_format, **kwargs)\n        self.height = height\n        self.width = width\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n        shape = self.backend.core.shape(images)\n        return {\"input_shape\": shape}\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self, bounding_boxes, transformation, training=True\n    ):\n        def _get_height_width(input_shape):\n            if self.data_format == \"channels_first\":\n                input_height = input_shape[-2]\n                input_width = input_shape[-1]\n            else:\n                input_height = input_shape[-3]\n                input_width = input_shape[-2]\n            return input_height, input_width\n\n        def _get_clipped_bbox(bounding_boxes, h_end, h_start, w_end, w_start):\n            bboxes = bounding_boxes[\"boxes\"]\n            x1, y1, x2, y2 = self.backend.numpy.split(bboxes, 4, axis=-1)\n            x1 = self.backend.numpy.clip(x1, w_start, w_end) - w_start\n            y1 = self.backend.numpy.clip(y1, h_start, h_end) - h_start\n            x2 = self.backend.numpy.clip(x2, w_start, w_end) - w_start\n            y2 = self.backend.numpy.clip(y2, h_start, 
h_end) - h_start\n            bounding_boxes[\"boxes\"] = self.backend.numpy.concatenate(\n                [x1, y1, x2, y2], axis=-1\n            )\n            return bounding_boxes\n\n        input_shape = transformation[\"input_shape\"]\n\n        init_height, init_width = _get_height_width(input_shape)\n\n        bounding_boxes = convert_format(\n            bounding_boxes,\n            source=self.bounding_box_format,\n            target=\"xyxy\",\n            height=init_height,\n            width=init_width,\n        )\n\n        h_diff = init_height - self.height\n        w_diff = init_width - self.width\n\n        if h_diff >= 0 and w_diff >= 0:\n            h_start = int(h_diff / 2)\n            w_start = int(w_diff / 2)\n\n            h_end = h_start + self.height\n            w_end = w_start + self.width\n\n            bounding_boxes = _get_clipped_bbox(\n                bounding_boxes, h_end, h_start, w_end, w_start\n            )\n        else:\n            width = init_width\n            height = init_height\n            target_height = self.height\n            target_width = self.width\n\n            crop_height = int(float(width * target_height) / target_width)\n            crop_height = max(min(height, crop_height), 1)\n            crop_width = int(float(height * target_width) / target_height)\n            crop_width = max(min(width, crop_width), 1)\n            crop_box_hstart = int(float(height - crop_height) / 2)\n            crop_box_wstart = int(float(width - crop_width) / 2)\n\n            h_start = crop_box_hstart\n            w_start = crop_box_wstart\n\n            h_end = crop_box_hstart + crop_height\n            w_end = crop_box_wstart + crop_width\n            bounding_boxes = _get_clipped_bbox(\n                bounding_boxes, h_end, h_start, w_end, w_start\n            )\n\n            bounding_boxes = convert_format(\n                bounding_boxes,\n                source=\"xyxy\",\n                target=\"rel_xyxy\",\n           
     height=crop_height,\n                width=crop_width,\n            )\n\n            bounding_boxes = convert_format(\n                bounding_boxes,\n                source=\"rel_xyxy\",\n                target=\"xyxy\",\n                height=self.height,\n                width=self.width,\n            )\n\n        bounding_boxes = clip_to_image_size(\n            bounding_boxes=bounding_boxes,\n            height=self.height,\n            width=self.width,\n            bounding_box_format=\"xyxy\",\n        )\n\n        bounding_boxes = convert_format(\n            bounding_boxes,\n            source=\"xyxy\",\n            target=self.bounding_box_format,\n            height=self.height,\n            width=self.width,\n        )\n\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return self.transform_images(\n            segmentation_masks, transformation, training=training\n        )\n\n    def transform_images(self, images, transformation=None, training=True):\n        inputs = self.backend.cast(images, self.compute_dtype)\n        inputs_shape = self.backend.shape(inputs)\n\n        if self.data_format == \"channels_first\":\n            init_height = inputs_shape[-2]\n            init_width = inputs_shape[-1]\n        else:\n            init_height = inputs_shape[-3]\n            init_width = inputs_shape[-2]\n\n        # All these operations work both with ints (static sizes) and scalar\n        # tensors (dynamic sizes).\n        h_diff = init_height - self.height\n        w_diff = init_width - self.width\n\n        h_start = h_diff // 2\n        w_start = w_diff // 2\n\n        if (not isinstance(h_diff, int) or h_diff >= 0) and (\n            not isinstance(w_diff, int) or w_diff >= 0\n        ):\n            if len(inputs.shape) == 4:\n                if self.data_format == \"channels_first\":\n                    return inputs[\n              
          :,\n                        :,\n                        h_start : h_start + self.height,\n                        w_start : w_start + self.width,\n                    ]\n                return inputs[\n                    :,\n                    h_start : h_start + self.height,\n                    w_start : w_start + self.width,\n                    :,\n                ]\n            elif len(inputs.shape) == 3:\n                if self.data_format == \"channels_first\":\n                    return inputs[\n                        :,\n                        h_start : h_start + self.height,\n                        w_start : w_start + self.width,\n                    ]\n                return inputs[\n                    h_start : h_start + self.height,\n                    w_start : w_start + self.width,\n                    :,\n                ]\n        return image_utils.smart_resize(\n            inputs,\n            [self.height, self.width],\n            data_format=self.data_format,\n            backend_module=self.backend,\n        )\n\n    def compute_output_shape(self, input_shape):\n        input_shape = list(input_shape)\n        if isinstance(input_shape[0], (list, tuple)) or len(\n            input_shape\n        ) not in (3, 4):\n            raise ValueError(\n                \"`input_shape` must be a non-nested tuple or list \"\n                \"of rank-1 with size 3 (unbatched) or 4 (batched). 
\"\n            )\n        if len(input_shape) == 4:\n            if self.data_format == \"channels_last\":\n                input_shape[1] = self.height\n                input_shape[2] = self.width\n            else:\n                input_shape[2] = self.height\n                input_shape[3] = self.width\n        else:\n            if self.data_format == \"channels_last\":\n                input_shape[0] = self.height\n                input_shape[1] = self.width\n            else:\n                input_shape[1] = self.height\n                input_shape[2] = self.width\n        return tuple(input_shape)\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"height\": self.height,\n            \"width\": self.width,\n            \"data_format\": self.data_format,\n        }\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/center_crop_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\n\n\nclass CenterCropTest(testing.TestCase):\n    def np_center_crop(self, img, h_new, w_new, data_format=\"channels_last\"):\n        img = np.array(img)\n        if img.ndim == 4:\n            if data_format == \"channels_last\":\n                _, h, w = img.shape[:3]\n            else:\n                _, h, w = img.shape[1:]\n        else:\n            if data_format == \"channels_last\":\n                h, w = img.shape[:2]\n            else:\n                h, w = img.shape[1:]\n        h_start = (h - h_new) // 2\n        w_start = (w - w_new) // 2\n        if data_format == \"channels_last\":\n            return img[\n                ..., h_start : h_start + h_new, w_start : w_start + w_new, :\n            ]\n        else:\n            return img[\n                ..., h_start : h_start + h_new, w_start : w_start + w_new\n            ]\n\n    @pytest.mark.requires_trainable_backend\n    def test_center_crop_basics(self):\n        self.run_layer_test(\n            layers.CenterCrop,\n            init_kwargs={\n                \"height\": 6,\n                \"width\": 6,\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(2, 12, 12, 3),\n            expected_output_shape=(2, 6, 6, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n        self.run_layer_test(\n            layers.CenterCrop,\n            init_kwargs={\n                \"height\": 7,\n                \"width\": 7,\n                \"data_format\": \"channels_first\",\n            },\n            input_shape=(2, 3, 13, 
13),\n            expected_output_shape=(2, 3, 7, 7),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    @parameterized.parameters(\n        [\n            ((5, 7), \"channels_first\"),\n            ((5, 7), \"channels_last\"),\n            ((4, 9), \"channels_first\"),\n            ((9, 4), \"channels_last\"),\n        ]\n    )\n    def test_center_crop_correctness(self, size, data_format):\n        # batched case\n        if data_format == \"channels_first\":\n            img = np.random.random((2, 3, 9, 11))\n        else:\n            img = np.random.random((2, 9, 11, 3))\n        out = layers.CenterCrop(\n            size[0],\n            size[1],\n            data_format=data_format,\n        )(img)\n        if data_format == \"channels_first\":\n            img_transpose = np.transpose(img, (0, 2, 3, 1))\n\n            ref_out = np.transpose(\n                self.np_center_crop(img_transpose, size[0], size[1]),\n                (0, 3, 1, 2),\n            )\n        else:\n            ref_out = self.np_center_crop(img, size[0], size[1])\n        self.assertAllClose(ref_out, out)\n\n        # unbatched case\n        if data_format == \"channels_first\":\n            img = np.random.random((3, 9, 11))\n        else:\n            img = np.random.random((9, 11, 3))\n        out = layers.CenterCrop(\n            size[0],\n            size[1],\n            data_format=data_format,\n        )(img)\n        if data_format == \"channels_first\":\n            img_transpose = np.transpose(img, (1, 2, 0))\n            ref_out = np.transpose(\n                self.np_center_crop(\n                    img_transpose,\n                    size[0],\n                    size[1],\n                ),\n                (2, 0, 1),\n            )\n        else:\n            ref_out = 
self.np_center_crop(\n                img,\n                size[0],\n                size[1],\n            )\n        self.assertAllClose(ref_out, out)\n\n    @parameterized.parameters(\n        [\n            ((15, 10), \"channels_first\"),\n            ((10, 17), \"channels_last\"),\n        ]\n    )\n    def test_input_smaller_than_crop_box(self, size, data_format):\n        \"\"\"Output should equal resizing with crop_to_aspect ratio.\"\"\"\n        # batched case\n        if data_format == \"channels_first\":\n            img = np.random.random((2, 3, 9, 11))\n        else:\n            img = np.random.random((2, 9, 11, 3))\n        out = layers.CenterCrop(\n            size[0],\n            size[1],\n            data_format=data_format,\n        )(img)\n        ref_out = layers.Resizing(\n            size[0], size[1], data_format=data_format, crop_to_aspect_ratio=True\n        )(img)\n        self.assertAllClose(ref_out, out)\n\n        # unbatched case\n        if data_format == \"channels_first\":\n            img = np.random.random((3, 9, 11))\n        else:\n            img = np.random.random((9, 11, 3))\n        out = layers.CenterCrop(\n            size[0],\n            size[1],\n            data_format=data_format,\n        )(img)\n        ref_out = layers.Resizing(\n            size[0], size[1], data_format=data_format, crop_to_aspect_ratio=True\n        )(img)\n        self.assertAllClose(ref_out, out)\n\n    def test_tf_data_compatibility(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 10, 12, 3)\n            output_shape = (2, 8, 9, 3)\n        else:\n            input_shape = (2, 3, 10, 12)\n            output_shape = (2, 3, 8, 9)\n        layer = layers.CenterCrop(8, 9)\n        input_data = np.random.random(input_shape)\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        output = next(iter(ds)).numpy()\n        self.assertEqual(tuple(output.shape), 
output_shape)\n\n    # TODO\n    # def test_list_compatibility(self):\n    #     if backend.config.image_data_format() == \"channels_last\":\n    #         images = [\n    #             np.random.rand(10, 10, 3),\n    #             np.random.rand(10, 10, 3),\n    #         ]\n    #         output_shape = (2, 6, 5, 3)\n    #     else:\n    #         images = [\n    #             np.random.rand(3, 10, 10),\n    #             np.random.rand(3, 10, 10),\n    #         ]\n    #         output_shape = (2, 3, 6, 5)\n    #     output = layers.CenterCrop(height=6, width=5)(images)\n    #     ref_output = self.np_center_crop(\n    #         images, 6, 5, data_format=backend.config.image_data_format()\n    #     )\n    #     self.assertEqual(tuple(output.shape), output_shape)\n    #     self.assertAllClose(ref_output, output)\n\n    @parameterized.parameters(\n        [((5, 17), \"channels_last\"), ((5, 100), \"channels_last\")]\n    )\n    def test_image_stretch(self, size, data_format):\n        img = np.random.rand(2, 11, 3, 9)\n        out = layers.CenterCrop(\n            size[0],\n            size[1],\n            data_format=data_format,\n        )(img)\n        ref_out = layers.Resizing(\n            size[0], size[1], data_format=data_format, crop_to_aspect_ratio=True\n        )(img)\n        self.assertAllClose(ref_out, out)\n\n    @parameterized.named_parameters(\n        (\n            \"normal\",\n            5,\n            5,\n            [[1.0, 0.0, 3.0, 1.0], [5.0, 2.0, 5.0, 4.0]],\n        ),\n        (\n            \"with_stretch\",\n            20,\n            20,\n            [[5.0, 0.0, 10.0, 5.0], [15.0, 7.5, 20.0, 12.5]],\n        ),\n    )\n    def test_center_crop_bounding_boxes(self, height, width, expected_boxes):\n        if backend.config.image_data_format() == \"channels_last\":\n            image_shape = (10, 8, 3)\n        else:\n            image_shape = (3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = 
{\n            \"boxes\": np.array(\n                [\n                    [2, 1, 4, 3],\n                    [6, 4, 8, 6],\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n        center_crop_layer = layers.CenterCrop(\n            height=height,\n            width=width,\n            bounding_box_format=\"xyxy\",\n        )\n        output = center_crop_layer(input_data)\n        self.assertAllClose(output[\"bounding_boxes\"][\"boxes\"], expected_boxes)\n\n    @parameterized.named_parameters(\n        (\n            \"normal\",\n            5,\n            5,\n            [[1.0, 0.0, 3.0, 1.0], [5.0, 2.0, 5.0, 4.0]],\n        ),\n        (\n            \"with_stretch\",\n            20,\n            20,\n            [[5.0, 0.0, 10.0, 5.0], [15.0, 7.5, 20.0, 12.5]],\n        ),\n    )\n    def test_center_crop_tf_data_bounding_boxes(\n        self, height, width, expected_boxes\n    ):\n        if backend.config.image_data_format() == \"channels_last\":\n            image_shape = (1, 10, 8, 3)\n        else:\n            image_shape = (1, 3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [\n                        [2, 1, 4, 3],\n                        [6, 4, 8, 6],\n                    ]\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data)\n        center_crop_layer = layers.CenterCrop(\n            height=height,\n            width=width,\n            bounding_box_format=\"xyxy\",\n        )\n        ds = ds.map(center_crop_layer)\n        output = next(iter(ds))\n        expected_boxes = np.array(expected_boxes)\n        
self.assertAllClose(output[\"bounding_boxes\"][\"boxes\"], expected_boxes)\n\n    def test_dynamic_spatial_dims(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            large_input = (2, 25, 30, 3)\n            small_input = (2, 6, 7, 3)\n        else:\n            large_input = (2, 3, 25, 30)\n            small_input = (2, 3, 6, 7)\n\n        model = models.Sequential([layers.CenterCrop(10, 12)])\n\n        def generator():\n            yield (np.random.random(large_input).astype(\"float32\"),)\n            yield (np.random.random(small_input).astype(\"float32\"),)\n\n        model.predict(generator())\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/clahe.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\n\n\n@keras_export(\"keras.layers.ContrastLimitedAdaptiveHistogramEqualization\")\nclass ContrastLimitedAdaptiveHistogramEqualization(BaseImagePreprocessingLayer):\n    \"\"\"Contrast Limited Adaptive Histogram Equalization (CLAHE) layer.\n\n    CLAHE is a variant of Adaptive Histogram Equalization (AHE) which takes care\n    of over-amplification of the contrast. It operates on small regions in the\n    image, called tiles, rather than the entire image. The neighboring tiles are\n    then combined using bilinear interpolation to remove the artificial\n    boundaries. This algorithm can be applied to improve the contrast of an\n    image.\n\n    **Note:** This layer computes histograms using `self.backend.nn.one_hot`,\n    which can be highly memory-intensive. For large batch sizes or\n    high-resolution images, it may lead to high memory consumption or\n    out-of-memory errors.\n\n    Args:\n        value_range: Optional list/tuple of 2 floats specifying the lower\n            and upper limits of the input data values. Defaults to `(0, 255)`.\n        clip_limit: Float. Limits the noise amplification in near-constant\n            regions. Defaults to 4.0.\n        tile_grid_size: Tuple of 2 integers `(height, width)`.\n            Specifies the number of tiles to divide the image into.\n            Defaults to `(8, 8)`.\n        data_format: String, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`. 
It defaults to the\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. If you never set it, then it will be\n            `\"channels_last\"`.\n\n    Input shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format,\n        or `(..., channels, height, width)`, in `\"channels_first\"` format.\n\n    Output shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`,\n        or `(..., channels, height, width)`,\n        in `\"channels_first\"` format.\n\n    Example:\n\n    ```python\n    import keras\n    import numpy as np\n\n    # Create a CLAHE layer with default parameters\n    clahe = keras.layers.ContrastLimitedAdaptiveHistogramEqualization()\n\n    # Apply CLAHE to an image\n    # Image values should be in the range [0, 255] by default\n    input_image = np.random.randint(0, 256, (1, 256, 256, 3))\n    output_image = clahe(input_image)\n\n    # For normalized images in [0, 1]\n    clahe_normalized=keras.layers.ContrastLimitedAdaptiveHistogramEqualization(\n        value_range=(0.0, 1.0)\n    )\n    norm_image = np.random.rand(1, 256, 256, 3)\n    output_norm = clahe_normalized(norm_image)\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        value_range=(0, 255),\n        clip_limit=4.0,\n        tile_grid_size=(8, 8),\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.value_range = value_range\n        self.clip_limit = clip_limit\n        self.tile_grid_size = tile_grid_size\n        self.data_format = backend.standardize_data_format(data_format)\n\n    def transform_images(self, images, transformation=None, training=True):\n        if self.data_format == \"channels_first\":\n            if len(images.shape) == 4:\n                images = self.backend.numpy.transpose(images, (0, 2, 3, 1))\n            else:\n      
          images = self.backend.numpy.transpose(images, (1, 2, 0))\n\n        original_dtype = images.dtype\n\n        images = self._transform_value_range(\n            images, self.value_range, (0, 255), dtype=\"float32\"\n        )\n\n        images = self._clahe(images)\n\n        images = self._transform_value_range(\n            images, (0, 255), self.value_range, dtype=\"float32\"\n        )\n        images = self.backend.cast(images, original_dtype)\n\n        if self.data_format == \"channels_first\":\n            if len(images.shape) == 4:\n                images = self.backend.numpy.transpose(images, (0, 3, 1, 2))\n            else:\n                images = self.backend.numpy.transpose(images, (2, 0, 1))\n\n        return images\n\n    def _clahe(self, images):\n        unbatched = False\n        if len(images.shape) == 3:\n            images = self.backend.numpy.expand_dims(images, axis=0)\n            unbatched = True\n\n        shape = self.backend.core.shape(images)\n        batch_size = (\n            images.shape[0] if images.shape[0] is not None else shape[0]\n        )\n        height = images.shape[1] if images.shape[1] is not None else shape[1]\n        width = images.shape[2] if images.shape[2] is not None else shape[2]\n        channels = images.shape[3] if images.shape[3] is not None else shape[3]\n\n        grid_h, grid_w = self.tile_grid_size\n\n        tile_h = (height + grid_h - 1) // grid_h\n        tile_w = (width + grid_w - 1) // grid_w\n\n        pad_h = (tile_h * grid_h) - height\n        pad_w = (tile_w * grid_w) - width\n\n        if (\n            isinstance(pad_h, int)\n            and isinstance(pad_w, int)\n            and pad_h == 0\n            and pad_w == 0\n        ):\n            padded_images = images\n        else:\n            images_nchw = self.backend.numpy.transpose(images, (0, 3, 1, 2))\n\n            images_3d = self.backend.numpy.reshape(\n                images_nchw, (-1, height, width)\n            )\n        
    padded_3d = self.backend.numpy.pad(\n                images_3d, [[0, 0], [0, pad_h], [0, pad_w]], mode=\"symmetric\"\n            )\n            padded_nchw = self.backend.numpy.reshape(\n                padded_3d, (-1, channels, height + pad_h, width + pad_w)\n            )\n            padded_images = self.backend.numpy.transpose(\n                padded_nchw, (0, 2, 3, 1)\n            )\n\n        # Compute Histograms per tile\n        tiled = self.backend.numpy.reshape(\n            padded_images,\n            (batch_size, grid_h, tile_h, grid_w, tile_w, channels),\n        )\n        tiled = self.backend.numpy.transpose(tiled, (0, 1, 3, 5, 2, 4))\n\n        tiled_flat = self.backend.numpy.reshape(\n            tiled, (batch_size, grid_h, grid_w, channels, tile_h * tile_w)\n        )\n\n        tiled_int = self.backend.cast(tiled_flat, \"int32\")\n        tiled_int = self.backend.numpy.clip(tiled_int, 0, 255)\n\n        # Compute histograms via one_hot and sum\n        hists = self.backend.numpy.sum(\n            self.backend.nn.one_hot(tiled_int, 256), axis=-2\n        )\n\n        # Clip and redistribute\n        if self.clip_limit > 0:\n            limit = self.clip_limit * (tile_h * tile_w) / 256.0\n            limit = self.backend.cast(limit, hists.dtype)\n\n            clipped = self.backend.numpy.clip(hists, 0, limit)\n\n            excess = self.backend.numpy.sum(\n                hists - clipped, axis=-1, keepdims=True\n            )\n            redist = excess / 256.0\n            hists = clipped + redist\n\n        # Compute CDF\n        cdf = self.backend.numpy.cumsum(hists, axis=-1)\n        cdf_min = self.backend.numpy.min(cdf, axis=-1, keepdims=True)\n\n        numerator = (cdf - cdf_min) * 255.0\n        denominator = self.backend.cast(tile_h * tile_w, cdf.dtype) - cdf_min\n\n        denominator = self.backend.numpy.where(\n            denominator == 0,\n            self.backend.numpy.ones_like(denominator),\n            denominator,\n      
  )\n        cdf_norm = numerator / denominator\n        cdf_norm = self.backend.numpy.clip(cdf_norm, 0, 255)\n\n        # Interpolation\n\n        top = cdf_norm[:, 0:1, :, :, :]\n        bottom = cdf_norm[:, -1:, :, :, :]\n        cdf_padded = self.backend.numpy.concatenate(\n            [top, cdf_norm, bottom], axis=1\n        )\n\n        left = cdf_padded[:, :, 0:1, :, :]\n        right = cdf_padded[:, :, -1:, :, :]\n        cdf_padded = self.backend.numpy.concatenate(\n            [left, cdf_padded, right], axis=2\n        )\n\n        H_padded = tile_h * grid_h\n        W_padded = tile_w * grid_w\n\n        y_range = self.backend.numpy.arange(H_padded, dtype=\"float32\")\n        x_range = self.backend.numpy.arange(W_padded, dtype=\"float32\")\n\n        y_grid = (y_range - (tile_h / 2.0)) / tile_h\n        x_grid = (x_range - (tile_w / 2.0)) / tile_w\n\n        y_grid = y_grid + 1.0\n        x_grid = x_grid + 1.0\n\n        y0 = self.backend.numpy.floor(y_grid)\n        x0 = self.backend.numpy.floor(x_grid)\n        y1 = y0 + 1.0\n        x1 = x0 + 1.0\n\n        wy = y_grid - y0\n        wx = x_grid - x0\n\n        y0 = self.backend.numpy.clip(y0, 0, grid_h + 1)\n        y1 = self.backend.numpy.clip(y1, 0, grid_h + 1)\n        x0 = self.backend.numpy.clip(x0, 0, grid_w + 1)\n        x1 = self.backend.numpy.clip(x1, 0, grid_w + 1)\n\n        y0 = self.backend.cast(y0, \"int32\")\n        y1 = self.backend.cast(y1, \"int32\")\n        x0 = self.backend.cast(x0, \"int32\")\n        x1 = self.backend.cast(x1, \"int32\")\n\n        stride_c = 256\n        stride_x = stride_c * channels\n        stride_y = stride_x * (grid_w + 2)\n        stride_b = stride_y * (grid_h + 2)\n\n        cdf_flat = self.backend.numpy.reshape(cdf_padded, (-1,))\n\n        pixels = self.backend.cast(\n            self.backend.numpy.clip(padded_images, 0, 255), \"int32\"\n        )\n\n        b_idx = self.backend.numpy.arange(batch_size, dtype=\"int32\")[\n            :, None, None, 
None\n        ]\n\n        c_idx = self.backend.numpy.arange(channels, dtype=\"int32\")[\n            None, None, None, :\n        ]\n\n        y0_e = y0[None, :, None, None]\n        y1_e = y1[None, :, None, None]\n\n        x0_e = x0[None, None, :, None]\n        x1_e = x1[None, None, :, None]\n\n        wy_e = wy[None, :, None, None]\n        wx_e = wx[None, None, :, None]\n\n        base_idx = b_idx * stride_b + c_idx * stride_c + pixels\n\n        idx_nw = base_idx + y0_e * stride_y + x0_e * stride_x\n        val_nw = self.backend.numpy.take(cdf_flat, idx_nw)\n\n        idx_ne = base_idx + y0_e * stride_y + x1_e * stride_x\n        val_ne = self.backend.numpy.take(cdf_flat, idx_ne)\n\n        idx_sw = base_idx + y1_e * stride_y + x0_e * stride_x\n        val_sw = self.backend.numpy.take(cdf_flat, idx_sw)\n\n        idx_se = base_idx + y1_e * stride_y + x1_e * stride_x\n        val_se = self.backend.numpy.take(cdf_flat, idx_se)\n\n        top_interp = val_nw * (1.0 - wx_e) + val_ne * wx_e\n        bot_interp = val_sw * (1.0 - wx_e) + val_se * wx_e\n        result = top_interp * (1.0 - wy_e) + bot_interp * wy_e\n\n        result = result[:, :height, :width, :]\n\n        if unbatched:\n            result = self.backend.numpy.squeeze(result, axis=0)\n\n        return result\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"value_range\": self.value_range,\n                \"clip_limit\": self.clip_limit,\n                \"tile_grid_size\": self.tile_grid_size,\n                \"data_format\": self.data_format,\n            }\n        )\n        return config\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/clahe_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass ContrastLimitedAdaptiveHistogramEqualizationTest(testing.TestCase):\n    def assertAllInRange(self, array, min_val, max_val):\n        self.assertTrue(np.all(array >= min_val))\n        self.assertTrue(np.all(array <= max_val))\n\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.ContrastLimitedAdaptiveHistogramEqualization,\n            init_kwargs={\n                \"value_range\": (0, 255),\n                \"data_format\": \"channels_last\",\n                \"tile_grid_size\": (2, 2),\n            },\n            input_shape=(1, 4, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(1, 4, 4, 3),\n        )\n\n        self.run_layer_test(\n            layers.ContrastLimitedAdaptiveHistogramEqualization,\n            init_kwargs={\n                \"value_range\": (0, 255),\n                \"data_format\": \"channels_first\",\n                \"tile_grid_size\": (2, 2),\n            },\n            input_shape=(1, 3, 4, 4),\n            supports_masking=False,\n            expected_output_shape=(1, 3, 4, 4),\n        )\n\n    def test_clahe_identity(self):\n        xs = np.random.uniform(size=(2, 64, 64, 3), low=0, high=255).astype(\n            np.float32\n        )\n        layer = layers.ContrastLimitedAdaptiveHistogramEqualization(\n            value_range=(0, 255), data_format=\"channels_last\"\n        )\n        xs_out = layer(xs)\n        self.assertAllInRange(ops.convert_to_numpy(xs_out), 0, 255)\n        self.assertEqual(xs_out.shape, (2, 64, 64, 3))\n\n    @parameterized.named_parameters(\n        (\"float32\", np.float32), (\"int32\", np.int32), (\"int64\", np.int64)\n    )\n    def test_input_dtypes(self, dtype):\n        
xs = np.random.uniform(size=(2, 32, 32, 3), low=0, high=255).astype(\n            dtype\n        )\n        layer = layers.ContrastLimitedAdaptiveHistogramEqualization(\n            value_range=(0, 255), data_format=\"channels_last\"\n        )\n        xs_out = ops.convert_to_numpy(layer(xs))\n        self.assertAllInRange(xs_out, 0, 255)\n\n    @parameterized.named_parameters((\"0_255\", 0, 255), (\"0_1\", 0, 1))\n    def test_output_range(self, lower, upper):\n        xs = np.random.uniform(\n            size=(2, 32, 32, 3), low=lower, high=upper\n        ).astype(np.float32)\n        layer = layers.ContrastLimitedAdaptiveHistogramEqualization(\n            value_range=(lower, upper), data_format=\"channels_last\"\n        )\n        xs_out = ops.convert_to_numpy(layer(xs))\n        self.assertAllInRange(xs_out, lower, upper)\n\n    def test_grayscale_images(self):\n        xs = np.random.uniform(0, 255, size=(2, 32, 32, 1)).astype(np.float32)\n        layer = layers.ContrastLimitedAdaptiveHistogramEqualization(\n            value_range=(0, 255), data_format=\"channels_last\"\n        )\n        out = ops.convert_to_numpy(layer(xs))\n        self.assertEqual(out.shape[-1], 1)\n        self.assertAllInRange(out, 0, 255)\n\n    def test_tf_data_compatibility(self):\n        layer = layers.ContrastLimitedAdaptiveHistogramEqualization(\n            value_range=(0, 255), data_format=\"channels_last\"\n        )\n        input_data = np.random.random((2, 16, 16, 3)) * 255\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output_array = output.numpy()\n            self.assertAllInRange(output_array, 0, 255)\n            self.assertEqual(output_array.shape, (2, 16, 16, 3))\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/cut_mix.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_color_example,\n)\nfrom keras.src.random import SeedGenerator\n\n\n@keras_export(\"keras.layers.CutMix\")\nclass CutMix(BaseImagePreprocessingLayer):\n    \"\"\"CutMix data augmentation technique.\n\n    CutMix is a data augmentation method where patches are cut and pasted\n    between two images in the dataset, while the labels are also mixed\n    proportionally to the area of the patches.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    References:\n       - [CutMix paper]( https://arxiv.org/abs/1905.04899).\n\n    Args:\n        factor: A single float or a tuple of two floats between 0 and 1.\n            If a tuple of numbers is passed, a `factor` is sampled\n            between the two values.\n            If a single float is passed, a value between 0 and the passed\n            float is sampled. These values define the range from which the\n            mixing weight is sampled. A higher factor increases the variability\n            in patch sizes, leading to more diverse and larger mixed patches.\n            Defaults to 1.\n        seed: Integer. 
Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_color_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_BOUNDS = (0, 1)\n\n    def __init__(self, factor=1.0, seed=None, data_format=None, **kwargs):\n        super().__init__(data_format=data_format, **kwargs)\n        self._set_factor(factor)\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n        if self.data_format == \"channels_first\":\n            self.height_axis = -2\n            self.width_axis = -1\n            self.channel_axis = -3\n        else:\n            self.height_axis = -3\n            self.width_axis = -2\n            self.channel_axis = -1\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n\n        images_shape = self.backend.shape(images)\n        if len(images_shape) == 3:\n            return None\n\n        batch_size = images_shape[0]\n        image_height = images_shape[self.height_axis]\n        image_width = images_shape[self.width_axis]\n\n        seed = seed or self._get_seed_generator(self.backend._backend)\n\n        mix_weight = self._generate_mix_weight(batch_size, seed)\n        ratio = self.backend.numpy.sqrt(1.0 - mix_weight)\n\n        x0, x1 = self._compute_crop_bounds(batch_size, image_width, ratio, seed)\n        y0, y1 = self._compute_crop_bounds(\n            batch_size, image_height, ratio, seed\n        )\n\n        batch_masks, mix_weight = self._generate_batch_mask(\n            images_shape,\n            (x0, x1, y0, y1),\n        )\n\n        permutation_order = self.backend.random.shuffle(\n            self.backend.numpy.arange(0, batch_size, dtype=\"int32\"),\n            seed=seed,\n        )\n\n        return {\n            \"permutation_order\": permutation_order,\n            \"batch_masks\": 
batch_masks,\n            \"mix_weight\": mix_weight,\n        }\n\n    def _generate_batch_mask(self, images_shape, box_corners):\n        def _generate_grid_xy(image_height, image_width):\n            grid_y, grid_x = self.backend.numpy.meshgrid(\n                self.backend.numpy.arange(\n                    image_height, dtype=self.compute_dtype\n                ),\n                self.backend.numpy.arange(\n                    image_width, dtype=self.compute_dtype\n                ),\n                indexing=\"ij\",\n            )\n            if self.data_format == \"channels_last\":\n                grid_y = self.backend.cast(\n                    grid_y[None, :, :, None], dtype=self.compute_dtype\n                )\n                grid_x = self.backend.cast(\n                    grid_x[None, :, :, None], dtype=self.compute_dtype\n                )\n            else:\n                grid_y = self.backend.cast(\n                    grid_y[None, None, :, :], dtype=self.compute_dtype\n                )\n                grid_x = self.backend.cast(\n                    grid_x[None, None, :, :], dtype=self.compute_dtype\n                )\n            return grid_x, grid_y\n\n        image_height, image_width = (\n            images_shape[self.height_axis],\n            images_shape[self.width_axis],\n        )\n        grid_x, grid_y = _generate_grid_xy(image_height, image_width)\n\n        x0, x1, y0, y1 = box_corners\n\n        x0 = x0[:, None, None, None]\n        y0 = y0[:, None, None, None]\n        x1 = x1[:, None, None, None]\n        y1 = y1[:, None, None, None]\n\n        batch_masks = (\n            (grid_x >= x0) & (grid_x < x1) & (grid_y >= y0) & (grid_y < y1)\n        )\n        batch_masks = self.backend.numpy.repeat(\n            batch_masks, images_shape[self.channel_axis], axis=self.channel_axis\n        )\n        mix_weight = 1.0 - (x1 - x0) * (y1 - y0) / (image_width * image_height)\n        return batch_masks, mix_weight\n\n    def 
_compute_crop_bounds(self, batch_size, image_length, crop_ratio, seed):\n        crop_length = self.backend.cast(\n            crop_ratio * image_length, dtype=self.compute_dtype\n        )\n\n        start_pos = self.backend.random.uniform(\n            shape=[batch_size],\n            minval=0,\n            maxval=1,\n            dtype=self.compute_dtype,\n            seed=seed,\n        ) * (image_length - crop_length)\n\n        end_pos = start_pos + crop_length\n\n        return start_pos, end_pos\n\n    def _generate_mix_weight(self, batch_size, seed):\n        alpha = (\n            self.backend.random.uniform(\n                shape=(),\n                minval=self.factor[0],\n                maxval=self.factor[1],\n                dtype=self.compute_dtype,\n                seed=seed,\n            )\n            + 1e-6\n        )\n        mix_weight = self.backend.random.beta(\n            (batch_size,), alpha, alpha, seed=seed, dtype=self.compute_dtype\n        )\n        return mix_weight\n\n    def transform_images(self, images, transformation=None, training=True):\n        if training and transformation is not None:\n            images = self.backend.cast(images, self.compute_dtype)\n\n            permutation_order = transformation[\"permutation_order\"]\n            batch_masks = transformation[\"batch_masks\"]\n\n            images = self.backend.numpy.where(\n                batch_masks,\n                self.backend.numpy.take(images, permutation_order, axis=0),\n                images,\n            )\n        images = self.backend.cast(images, self.compute_dtype)\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        if training and transformation is not None:\n            permutation_order = transformation[\"permutation_order\"]\n            mix_weight = transformation[\"mix_weight\"]\n\n            cutout_labels = self.backend.numpy.take(\n                labels, permutation_order, axis=0\n       
     )\n            mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1])\n            labels = mix_weight * labels + (1.0 - mix_weight) * cutout_labels\n\n        return labels\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        raise NotImplementedError()\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return self.transform_images(\n            segmentation_masks, transformation, training\n        )\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = {\n            \"factor\": self.factor,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\nCutMix.__doc__ = CutMix.__doc__.replace(\n    \"{{base_image_preprocessing_color_example}}\",\n    base_image_preprocessing_color_example.replace(\"{LayerName}\", \"CutMix\"),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/cut_mix_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass CutMixTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.CutMix,\n            init_kwargs={\n                \"factor\": 1.0,\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n            # StatelessRandomGammaV3 is not supported on XLA_GPU_JIT\n            run_training_check=not testing.tensorflow_uses_gpu(),\n        )\n\n    def test_cut_mix_inference(self):\n        seed = 3481\n        layer = layers.CutMix()\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_cut_mix_basic(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            image1 = np.ones((2, 2, 1))\n            image2 = np.zeros((2, 2, 1))\n            inputs = np.asarray([image1, image2])\n            expected_output = np.array(\n                [\n                    [[[1.0], [1.0]], [[1.0], [1.0]]],\n                    [[[0.0], [0.0]], [[0.0], [0.0]]],\n                ]\n            )\n        else:\n            image1 = np.ones((1, 2, 2))\n            image2 = np.zeros((1, 2, 2))\n            inputs = np.asarray([image1, image2])\n            expected_output = np.asarray(\n                [\n                    [[[1.0, 1.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]],\n                    [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]],\n                ]\n            )\n\n        layer = layers.CutMix(data_format=data_format)\n\n        transformation = {\n            
\"batch_masks\": np.asarray(\n                [\n                    [[[False], [True]], [[False], [False]]],\n                    [[[False], [False]], [[True], [False]]],\n                ]\n            ),\n            \"mix_weight\": np.asarray([[[[0.7826548]]], [[[0.8133545]]]]),\n            \"permutation_order\": np.asarray([0, 1]),\n        }\n\n        output = layer.transform_images(inputs, transformation)\n\n        self.assertAllClose(expected_output, output)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.CutMix(data_format=data_format)\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/equalization.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\n\n\n@keras_export(\"keras.layers.Equalization\")\nclass Equalization(BaseImagePreprocessingLayer):\n    \"\"\"Preprocessing layer for histogram equalization on image channels.\n\n    Histogram equalization is a technique to adjust image intensities to\n    enhance contrast by effectively spreading out the most frequent\n    intensity values. This layer applies equalization on a channel-wise\n    basis, which can improve the visibility of details in images.\n\n    This layer works with both grayscale and color images, performing\n    equalization independently on each color channel. Note that equalization\n    is only applied during training; at inference time, inputs are passed\n    through unchanged.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        value_range: Optional list/tuple of 2 floats specifying the lower\n            and upper limits of the input data values. Defaults to `[0, 255]`.\n            If the input image has been scaled, use the appropriate range\n            (e.g., `[0.0, 1.0]`). The equalization will be scaled to this\n            range, and output values will be clipped accordingly.\n        bins: Integer specifying the number of histogram bins to use for\n            equalization. 
 Defaults to 256, which is suitable for 8-bit images.\n            Larger values can provide more granular intensity redistribution.\n\n    Input shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format,\n        or `(..., channels, height, width)`, in `\"channels_first\"` format.\n\n    Output shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format,\n        or `(..., channels, height, width)`, in `\"channels_first\"` format.\n\n    Example:\n\n    ```python\n    # Create an equalization layer for standard 8-bit images\n    equalizer = keras.layers.Equalization()\n\n    # An image with uneven intensity distribution\n    image = [...] # your input image\n\n    # Apply histogram equalization\n    equalized_image = equalizer(image)\n\n    # For images with custom value range\n    custom_equalizer = keras.layers.Equalization(\n        value_range=[0.0, 1.0],  # for normalized images\n        bins=128  # fewer bins for more subtle equalization\n    )\n    custom_equalized = custom_equalizer(normalized_image)\n    ```\n    \"\"\"\n\n    def __init__(\n        self, value_range=(0, 255), bins=256, data_format=None, **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.bins = bins\n        self._set_value_range(value_range)\n        self.data_format = backend.standardize_data_format(data_format)\n\n    def _set_value_range(self, value_range):\n        if not isinstance(value_range, (tuple, list)):\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        if len(value_range) != 2:\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        self.value_range = sorted(value_range)\n\n 
   def _custom_histogram_fixed_width(self, values, value_range, nbins):\n        values = self.backend.cast(values, \"float32\")\n        value_min, value_max = value_range\n        value_min = self.backend.cast(value_min, \"float32\")\n        value_max = self.backend.cast(value_max, \"float32\")\n\n        scaled = (values - value_min) * (nbins - 1) / (value_max - value_min)\n        indices = self.backend.cast(scaled, \"int32\")\n        indices = self.backend.numpy.clip(indices, 0, nbins - 1)\n        flat_indices = self.backend.numpy.reshape(indices, [-1])\n\n        if backend.backend() == \"jax\":\n            # for JAX bincount is never jittable because of output shape\n            histogram = self.backend.numpy.zeros(nbins, dtype=\"int32\")\n            for i in range(nbins):\n                matches = self.backend.cast(\n                    self.backend.numpy.equal(flat_indices, i), \"int32\"\n                )\n                bin_count = self.backend.numpy.sum(matches)\n                one_hot = self.backend.cast(\n                    self.backend.numpy.arange(nbins) == i, \"int32\"\n                )\n                histogram = histogram + (bin_count * one_hot)\n            return histogram\n        else:\n            # TensorFlow/PyTorch/NumPy implementation using bincount\n            return self.backend.numpy.bincount(\n                flat_indices,\n                minlength=nbins,\n            )\n\n    def _scale_values(self, values, source_range, target_range):\n        source_min, source_max = source_range\n        target_min, target_max = target_range\n        scale = (target_max - target_min) / (source_max - source_min)\n        offset = target_min - source_min * scale\n        return values * scale + offset\n\n    def _equalize_channel(self, channel, value_range):\n        if value_range != (0, 255):\n            channel = self._scale_values(channel, value_range, (0, 255))\n\n        hist = self._custom_histogram_fixed_width(\n            
channel, value_range=(0, 255), nbins=self.bins\n        )\n\n        nonzero_bins = self.backend.numpy.count_nonzero(hist)\n        equalized = self.backend.numpy.where(\n            nonzero_bins <= 1, channel, self._apply_equalization(channel, hist)\n        )\n\n        if value_range != (0, 255):\n            equalized = self._scale_values(equalized, (0, 255), value_range)\n\n        return equalized\n\n    def _apply_equalization(self, channel, hist):\n        cdf = self.backend.numpy.cumsum(hist)\n\n        if self.backend.name == \"jax\":\n            mask = cdf > 0\n            first_nonzero_idx = self.backend.numpy.argmax(mask)\n            cdf_min = self.backend.numpy.take(cdf, first_nonzero_idx)\n        else:\n            cdf_min = self.backend.numpy.take(\n                cdf, self.backend.numpy.nonzero(cdf)[0][0]\n            )\n\n        denominator = cdf[-1] - cdf_min\n        denominator = self.backend.numpy.where(\n            denominator == 0,\n            self.backend.numpy.ones_like(1, dtype=denominator.dtype),\n            denominator,\n        )\n\n        lookup_table = ((cdf - cdf_min) * 255) / denominator\n        lookup_table = self.backend.numpy.clip(\n            self.backend.numpy.round(lookup_table), 0, 255\n        )\n\n        scaled_channel = (channel / 255.0) * (self.bins - 1)\n        indices = self.backend.cast(\n            self.backend.numpy.clip(scaled_channel, 0, self.bins - 1), \"int32\"\n        )\n        return self.backend.numpy.take(lookup_table, indices)\n\n    def transform_images(self, images, transformation, training=True):\n        if training:\n            images = self.backend.cast(images, self.compute_dtype)\n\n            if self.data_format == \"channels_first\":\n                channels = []\n                for i in range(self.backend.core.shape(images)[-3]):\n                    channel = images[..., i, :, :]\n                    equalized = self._equalize_channel(\n                        channel, 
self.value_range\n                    )\n                    channels.append(equalized)\n                equalized_images = self.backend.numpy.stack(channels, axis=-3)\n            else:\n                channels = []\n                for i in range(self.backend.core.shape(images)[-1]):\n                    channel = images[..., i]\n                    equalized = self._equalize_channel(\n                        channel, self.value_range\n                    )\n                    channels.append(equalized)\n                equalized_images = self.backend.numpy.stack(channels, axis=-1)\n\n            return self.backend.cast(equalized_images, self.compute_dtype)\n        return images\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def compute_output_spec(self, inputs, **kwargs):\n        return inputs\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        return bounding_boxes\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def get_config(self):\n        config = super().get_config()\n        config.update({\"bins\": self.bins, \"value_range\": self.value_range})\n        return config\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/equalization_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass EqualizationTest(testing.TestCase):\n    def assertAllInRange(self, array, min_val, max_val):\n        self.assertTrue(np.all(array >= min_val))\n        self.assertTrue(np.all(array <= max_val))\n\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.Equalization,\n            init_kwargs={\n                \"value_range\": (0, 255),\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(1, 2, 2, 3),\n            supports_masking=False,\n            expected_output_shape=(1, 2, 2, 3),\n        )\n\n        self.run_layer_test(\n            layers.Equalization,\n            init_kwargs={\n                \"value_range\": (0, 255),\n                \"data_format\": \"channels_first\",\n            },\n            input_shape=(1, 3, 2, 2),\n            supports_masking=False,\n            expected_output_shape=(1, 3, 2, 2),\n        )\n\n    def test_equalizes_to_all_bins(self):\n        xs = np.random.uniform(size=(2, 512, 512, 3), low=0, high=255).astype(\n            np.float32\n        )\n        layer = layers.Equalization(value_range=(0, 255))\n        xs = layer(xs)\n\n        for i in range(0, 256):\n            self.assertTrue(np.any(ops.convert_to_numpy(xs) == i))\n\n    @parameterized.named_parameters(\n        (\"float32\", np.float32), (\"int32\", np.int32), (\"int64\", np.int64)\n    )\n    def test_input_dtypes(self, dtype):\n        xs = np.random.uniform(size=(2, 512, 512, 3), low=0, high=255).astype(\n            dtype\n        )\n        layer = layers.Equalization(value_range=(0, 255))\n        xs = ops.convert_to_numpy(layer(xs))\n\n        for i in range(0, 256):\n            self.assertTrue(np.any(xs == i))\n    
    self.assertAllInRange(xs, 0, 255)\n\n    @parameterized.named_parameters((\"0_255\", 0, 255), (\"0_1\", 0, 1))\n    def test_output_range(self, lower, upper):\n        xs = np.random.uniform(\n            size=(2, 512, 512, 3), low=lower, high=upper\n        ).astype(np.float32)\n        layer = layers.Equalization(value_range=(lower, upper))\n        xs = ops.convert_to_numpy(layer(xs))\n        self.assertAllInRange(xs, lower, upper)\n\n    def test_constant_regions(self):\n        xs = np.zeros((1, 64, 64, 3), dtype=np.float32)\n        xs[:, :21, :, :] = 50\n        xs[:, 21:42, :, :] = 100\n        xs[:, 42:, :, :] = 200\n\n        layer = layers.Equalization(value_range=(0, 255))\n        equalized = ops.convert_to_numpy(layer(xs))\n\n        self.assertTrue(len(np.unique(equalized)) >= 3)\n        self.assertAllInRange(equalized, 0, 255)\n\n    def test_grayscale_images(self):\n        xs_last = np.random.uniform(0, 255, size=(2, 64, 64, 1)).astype(\n            np.float32\n        )\n        layer_last = layers.Equalization(\n            value_range=(0, 255), data_format=\"channels_last\"\n        )\n        equalized_last = ops.convert_to_numpy(layer_last(xs_last))\n        self.assertEqual(equalized_last.shape[-1], 1)\n        self.assertAllInRange(equalized_last, 0, 255)\n\n        xs_first = np.random.uniform(0, 255, size=(2, 1, 64, 64)).astype(\n            np.float32\n        )\n        layer_first = layers.Equalization(\n            value_range=(0, 255), data_format=\"channels_first\"\n        )\n        equalized_first = ops.convert_to_numpy(layer_first(xs_first))\n        self.assertEqual(equalized_first.shape[1], 1)\n        self.assertAllInRange(equalized_first, 0, 255)\n\n    def test_single_color_image(self):\n        xs_last = np.full((1, 64, 64, 3), 128, dtype=np.float32)\n        layer_last = layers.Equalization(\n            value_range=(0, 255), data_format=\"channels_last\"\n        )\n        equalized_last = 
ops.convert_to_numpy(layer_last(xs_last))\n        self.assertAllClose(equalized_last, 128.0)\n\n        xs_first = np.full((1, 3, 64, 64), 128, dtype=np.float32)\n        layer_first = layers.Equalization(\n            value_range=(0, 255), data_format=\"channels_first\"\n        )\n        equalized_first = ops.convert_to_numpy(layer_first(xs_first))\n        self.assertAllClose(equalized_first, 128.0)\n\n    def test_different_bin_sizes(self):\n        xs = np.random.uniform(0, 255, size=(1, 64, 64, 3)).astype(np.float32)\n        bin_sizes = [16, 64, 128, 256]\n        for bins in bin_sizes:\n            layer = layers.Equalization(value_range=(0, 255), bins=bins)\n            equalized = ops.convert_to_numpy(layer(xs))\n            self.assertAllInRange(equalized, 0, 255)\n\n    def test_tf_data_compatibility(self):\n        layer = layers.Equalization(value_range=(0, 255))\n        input_data = np.random.random((2, 8, 8, 3)) * 255\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output_array = output.numpy()\n            self.assertAllInRange(output_array, 0, 255)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\n\n\n@keras_export(\"keras.layers.MaxNumBoundingBoxes\")\nclass MaxNumBoundingBoxes(BaseImagePreprocessingLayer):\n    \"\"\"Pad or truncate bounding boxes to a fixed maximum number.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        max_number: Desired output number of bounding boxes.\n        fill_value: The value used to pad the `boxes` and `labels` in\n            `bounding_boxes`. Defaults to `-1`.\n\n    Example:\n\n    ```python\n    max_boxes_layer = keras.layers.MaxNumBoundingBoxes(\n        max_number=10,\n        fill_value=-1\n    )\n\n    images = np.random.randint(0, 255, (1, 224, 224, 3), dtype=\"uint8\")\n\n    bounding_boxes = {\n        \"boxes\": np.array([\n            [[10, 20, 100, 150], [50, 60, 200, 250], [0, 0, 50, 50]],\n        ]),\n        \"labels\": np.array([[1, 2, 3]])\n    }\n\n    result = max_boxes_layer({\n        \"images\": images,\n        \"bounding_boxes\": bounding_boxes\n    })\n    ```\n    \"\"\"\n\n    def __init__(self, max_number, fill_value=-1, **kwargs):\n        super().__init__(**kwargs)\n        self.max_number = int(max_number)\n        self.fill_value = int(fill_value)\n\n    def transform_images(self, images, transformation=None, training=True):\n        return images\n\n    def transform_labels(self, labels, transformation=None, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self, bounding_boxes, transformation, training=True\n    ):\n        ops = self.backend\n        boxes = bounding_boxes[\"boxes\"]\n        labels = bounding_boxes[\"labels\"]\n        boxes_shape = ops.shape(boxes)\n        batch_size = boxes_shape[0]\n        num_boxes = boxes_shape[1]\n\n        # Get pad size\n        pad_size = ops.numpy.maximum(\n            ops.numpy.subtract(self.max_number, num_boxes), 0\n        )\n        boxes = boxes[:, : self.max_number, ...]\n        boxes = ops.numpy.pad(\n            boxes,\n            [[0, 0], [0, pad_size], [0, 0]],\n            constant_values=self.fill_value,\n        )\n        labels = labels[:, : self.max_number]\n        labels = ops.numpy.pad(\n            labels, [[0, 0], [0, pad_size]], constant_values=self.fill_value\n        )\n\n        # Ensure shape\n        boxes = ops.numpy.reshape(boxes, [batch_size, self.max_number, 4])\n        labels = ops.numpy.reshape(labels, [batch_size, self.max_number])\n\n        bounding_boxes = bounding_boxes.copy()\n        bounding_boxes[\"boxes\"] = boxes\n        bounding_boxes[\"labels\"] = labels\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation=None, training=True\n    ):\n        return self.transform_images(segmentation_masks)\n\n    def compute_output_shape(self, input_shape):\n        if isinstance(input_shape, dict) and \"bounding_boxes\" in input_shape:\n            input_keys = set(input_shape[\"bounding_boxes\"].keys())\n            extra_keys = input_keys - set((\"boxes\", \"labels\"))\n            if extra_keys:\n                raise KeyError(\n                    \"There are unsupported keys in `bounding_boxes`: \"\n                    f\"{list(extra_keys)}. \"\n                    \"Only `boxes` and `labels` are supported.\"\n                )\n\n            boxes_shape = list(input_shape[\"bounding_boxes\"][\"boxes\"])\n            boxes_shape[1] = self.max_number\n            labels_shape = list(input_shape[\"bounding_boxes\"][\"labels\"])\n            labels_shape[1] = self.max_number\n            input_shape[\"bounding_boxes\"][\"boxes\"] = boxes_shape\n            input_shape[\"bounding_boxes\"][\"labels\"] = labels_shape\n        return input_shape\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\"max_number\": self.max_number, \"fill_value\": self.fill_value}\n        )\n        return config\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/max_num_bounding_box_test.py",
    "content": "import numpy as np\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass MaxNumBoundingBoxesTest(testing.TestCase):\n    def test_max_num_bounding_boxes_basics(self):\n        self.run_layer_test(\n            layers.MaxNumBoundingBoxes,\n            init_kwargs={\n                \"max_number\": 40,\n                \"fill_value\": -1,\n            },\n            input_shape=(12, 12, 3),\n            expected_output_shape=(12, 12, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    def test_output_shapes(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            image_shape = (10, 8, 3)\n        else:\n            image_shape = (3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [2, 1, 4, 3],\n                    [6, 4, 8, 6],\n                ]\n            ),  # Example boxes (normalized)\n            \"labels\": np.array([1, 2]),  # Dummy labels\n        }\n        layer = layers.MaxNumBoundingBoxes(\n            max_number=40, bounding_box_format=\"xyxy\"\n        )\n\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n        output = layer(input_data)\n        self.assertAllEqual(output[\"bounding_boxes\"][\"boxes\"].shape, (40, 4))\n        self.assertAllEqual(output[\"bounding_boxes\"][\"labels\"].shape, (40,))\n\n    def test_output_shapes_with_tf_data(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            image_shape = (1, 10, 8, 3)\n        else:\n            image_shape = (1, 3, 10, 8)\n        input_image = 
np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [\n                        [2, 1, 4, 3],\n                        [6, 4, 8, 6],\n                    ]\n                ]\n            ),  # Example boxes (normalized)\n            \"labels\": np.array([[1, 2]]),  # Dummy labels\n        }\n        layer = layers.MaxNumBoundingBoxes(\n            max_number=40, bounding_box_format=\"xyxy\"\n        )\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n        ds = tf_data.Dataset.from_tensor_slices(input_data)\n        ds = ds.map(layer)\n        ds = ds.batch(1)\n        output = next(iter(ds))\n        self.assertAllEqual(output[\"bounding_boxes\"][\"boxes\"].shape, (1, 40, 4))\n        self.assertAllEqual(output[\"bounding_boxes\"][\"labels\"].shape, (1, 40))\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/mix_up.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.random import SeedGenerator\nfrom keras.src.utils import backend_utils\n\n\n@keras_export(\"keras.layers.MixUp\")\nclass MixUp(BaseImagePreprocessingLayer):\n    \"\"\"MixUp implements the MixUp data augmentation technique.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    References:\n        - [MixUp paper](https://arxiv.org/abs/1710.09412).\n        - [MixUp for Object Detection paper](https://arxiv.org/pdf/1902.04103).\n\n    Args:\n        alpha: Float between 0 and 1. Controls the blending strength.\n            Smaller values mean less mixing, while larger values allow\n            for more blending between images. Defaults to 0.2,\n            recommended for ImageNet1k classification.\n        seed: Integer. 
Used to create a random seed.\n\n    Example:\n    ```python\n    (images, labels), _ = keras.datasets.cifar10.load_data()\n    images, labels = images[:8], labels[:8]\n    labels = keras.ops.cast(keras.ops.one_hot(labels.flatten(), 10), \"float32\")\n    mix_up = keras.layers.MixUp(alpha=0.2)\n    output = mix_up({\"images\": images, \"labels\": labels})\n    ```\n    \"\"\"\n\n    def __init__(self, alpha=0.2, data_format=None, seed=None, **kwargs):\n        super().__init__(data_format=data_format, **kwargs)\n        self.alpha = alpha\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n\n        images_shape = self.backend.shape(images)\n\n        if len(images_shape) == 3:\n            batch_size = 1\n        else:\n            batch_size = self.backend.shape(images)[0]\n\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n\n        permutation_order = self.backend.random.shuffle(\n            self.backend.numpy.arange(0, batch_size, dtype=\"int64\"),\n            seed=seed,\n        )\n\n        mix_weight = self.backend.random.beta(\n            (batch_size,), self.alpha, self.alpha, seed=seed\n        )\n        return {\n            \"mix_weight\": mix_weight,\n            \"permutation_order\": permutation_order,\n        }\n\n    def transform_images(self, images, transformation=None, training=True):\n        def _mix_up_input(images, transformation):\n            images = self.backend.cast(images, self.compute_dtype)\n            mix_weight = transformation[\"mix_weight\"]\n            permutation_order = transformation[\"permutation_order\"]\n            mix_weight = self.backend.cast(\n                self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1]),\n                dtype=self.compute_dtype,\n 
           )\n            mix_up_images = self.backend.cast(\n                self.backend.numpy.take(images, permutation_order, axis=0),\n                dtype=self.compute_dtype,\n            )\n            images = mix_weight * images + (1.0 - mix_weight) * mix_up_images\n            return images\n\n        if training:\n            images = _mix_up_input(images, transformation)\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        def _mix_up_labels(labels, transformation):\n            mix_weight = transformation[\"mix_weight\"]\n            permutation_order = transformation[\"permutation_order\"]\n            labels_for_mix_up = self.backend.numpy.take(\n                labels, permutation_order, axis=0\n            )\n            mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1])\n            labels = (\n                mix_weight * labels + (1.0 - mix_weight) * labels_for_mix_up\n            )\n            return labels\n\n        if training:\n            labels = _mix_up_labels(labels, transformation)\n        return labels\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        def _mix_up_bounding_boxes(bounding_boxes, transformation):\n            if backend_utils.in_tf_graph():\n                self.backend.set_backend(\"tensorflow\")\n\n            permutation_order = transformation[\"permutation_order\"]\n            # Make sure we are on cpu for torch tensors.\n            permutation_order = ops.convert_to_numpy(permutation_order)\n\n            boxes, labels = bounding_boxes[\"boxes\"], bounding_boxes[\"labels\"]\n            boxes_for_mix_up = self.backend.numpy.take(\n                boxes, permutation_order, axis=0\n            )\n\n            labels_for_mix_up = self.backend.numpy.take(\n                labels, permutation_order, axis=0\n            )\n            boxes = 
self.backend.numpy.concatenate(\n                [boxes, boxes_for_mix_up], axis=1\n            )\n\n            labels = self.backend.numpy.concatenate(\n                [labels, labels_for_mix_up], axis=0\n            )\n\n            self.backend.reset()\n\n            return {\"boxes\": boxes, \"labels\": labels}\n\n        if training:\n            bounding_boxes = _mix_up_bounding_boxes(\n                bounding_boxes, transformation\n            )\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        def _mix_up_segmentation_masks(segmentation_masks, transformation):\n            mix_weight = transformation[\"mix_weight\"]\n            # Make sure we are on cpu for torch tensors.\n            mix_weight = ops.convert_to_numpy(mix_weight)\n            permutation_order = transformation[\"permutation_order\"]\n            mix_weight = self.backend.numpy.reshape(mix_weight, [-1, 1, 1, 1])\n            # Gather along the batch axis; without `axis=0` the masks would\n            # be flattened before indexing.\n            segmentation_masks_for_mix_up = self.backend.numpy.take(\n                segmentation_masks, permutation_order, axis=0\n            )\n            segmentation_masks = (\n                mix_weight * segmentation_masks\n                + (1.0 - mix_weight) * segmentation_masks_for_mix_up\n            )\n            return segmentation_masks\n\n        if training:\n            segmentation_masks = _mix_up_segmentation_masks(\n                segmentation_masks, transformation\n            )\n        return segmentation_masks\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = {\n            \"alpha\": self.alpha,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/mix_up_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.backend import convert_to_tensor\n\n\nclass MixUpTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.MixUp,\n            init_kwargs={\n                \"alpha\": 0.2,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n            # StatelessRandomGammaV3 is not supported on XLA_GPU_JIT\n            run_training_check=not testing.tensorflow_uses_gpu(),\n        )\n\n    def test_mix_up_inference(self):\n        seed = 3481\n        layer = layers.MixUp(alpha=0.2)\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_mix_up_basic_functionality(self):\n        image = np.random.random((64, 64, 3))\n        mix_up_layer = layers.MixUp(alpha=1)\n        transformation = {\"mix_weight\": 1, \"permutation_order\": [0]}\n        output = mix_up_layer.transform_images(\n            image, transformation=transformation\n        )[0]\n        self.assertAllClose(output, image)\n\n        image = np.random.random((4, 64, 64, 3))\n        mix_up_layer = layers.MixUp(alpha=0.2)\n        transformation = {\"mix_weight\": 0.2, \"permutation_order\": [1, 0, 2, 3]}\n        output = mix_up_layer.transform_images(\n            image, transformation=transformation\n        )\n        self.assertNotAllClose(output, image)\n        self.assertAllClose(output.shape, image.shape)\n\n    def test_mix_up_basic_functionality_channel_first(self):\n        image = np.random.random((3, 64, 64))\n        mix_up_layer = layers.MixUp(alpha=1)\n        transformation = 
{\"mix_weight\": 1, \"permutation_order\": [0]}\n        output = mix_up_layer.transform_images(\n            image, transformation=transformation\n        )[0]\n        self.assertAllClose(output, image)\n\n        image = np.random.random((4, 3, 64, 64))\n        mix_up_layer = layers.MixUp(alpha=0.2)\n        transformation = {\"mix_weight\": 0.2, \"permutation_order\": [1, 0, 2, 3]}\n        output = mix_up_layer.transform_images(\n            image, transformation=transformation\n        )\n        self.assertNotAllClose(output, image)\n        self.assertAllClose(output.shape, image.shape)\n\n    def test_tf_data_compatibility(self):\n        layer = layers.MixUp()\n        input_data = np.random.random((2, 8, 8, 3))\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n\n    def test_mix_up_bounding_boxes(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            image_shape = (10, 8, 3)\n        else:\n            image_shape = (3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [2, 1, 4, 3],\n                    [6, 4, 8, 6],\n                ]\n            ),\n            \"labels\": np.array([1, 2]),\n        }\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n\n        expected_boxes = [[2, 1, 4, 3, 6, 4, 8, 6], [6, 4, 8, 6, 2, 1, 4, 3]]\n\n        mix_up_layer = layers.MixUp(\n            data_format=data_format,\n            seed=42,\n            bounding_box_format=\"xyxy\",\n        )\n\n        transformation = {\n            \"mix_weight\": convert_to_tensor([0.5, 0.5]),\n            \"permutation_order\": convert_to_tensor([1, 0]),\n        }\n        output = mix_up_layer.transform_bounding_boxes(\n            
input_data[\"bounding_boxes\"],\n            transformation=transformation,\n            training=True,\n        )\n        self.assertAllClose(output[\"boxes\"], expected_boxes)\n\n    def test_mix_up_tf_data_bounding_boxes(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            image_shape = (1, 10, 8, 3)\n        else:\n            image_shape = (1, 3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [\n                        [2, 1, 4, 3],\n                        [6, 4, 8, 6],\n                    ]\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n        expected_boxes = [[2, 1, 4, 3, 6, 4, 8, 6], [6, 4, 8, 6, 2, 1, 4, 3]]\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data)\n        layer = layers.MixUp(\n            data_format=data_format,\n            seed=42,\n            bounding_box_format=\"xyxy\",\n        )\n\n        transformation = {\n            \"mix_weight\": convert_to_tensor([0.5, 0.5]),\n            \"permutation_order\": convert_to_tensor([1, 0]),\n        }\n        ds = ds.map(\n            lambda x: layer.transform_bounding_boxes(\n                x[\"bounding_boxes\"],\n                transformation=transformation,\n                training=True,\n            )\n        )\n\n        output = next(iter(ds))\n        expected_boxes = np.array(expected_boxes)\n        self.assertAllClose(output[\"boxes\"], expected_boxes)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/rand_augment.py",
    "content": "import keras.src.layers as layers\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_transform_example,\n)\nfrom keras.src.random import SeedGenerator\nfrom keras.src.utils import backend_utils\n\n\n@keras_export(\"keras.layers.RandAugment\")\nclass RandAugment(BaseImagePreprocessingLayer):\n    \"\"\"RandAugment performs the Rand Augment operation on input images.\n\n    This layer can be thought of as an all-in-one image augmentation layer. The\n    policy implemented by this layer has been benchmarked extensively and is\n    effective on a wide variety of datasets.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    References:\n        - [RandAugment](https://arxiv.org/abs/1909.13719)\n\n    Args:\n        value_range: The range of values the input image can take.\n            Default is `(0, 255)`. Typically, this would be `(0, 1)`\n            for normalized images or `(0, 255)` for raw images.\n        num_ops: The number of augmentation operations to apply sequentially\n            to each image. Default is 2.\n        factor: The strength of the augmentation as a normalized value\n            between 0 and 1. Default is 0.5.\n        interpolation: The interpolation method to use for resizing operations.\n            Options include `nearest`, `bilinear`. Default is `bilinear`.\n        seed: Integer. 
Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_transform_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_BOUNDS = (0, 1)\n\n    _AUGMENT_LAYERS = [\n        \"random_shear\",\n        \"random_translation\",\n        \"random_rotation\",\n        \"random_brightness\",\n        \"random_color_degeneration\",\n        \"random_contrast\",\n        \"random_sharpness\",\n        \"random_posterization\",\n        \"solarization\",\n        \"auto_contrast\",\n        \"equalization\",\n    ]\n\n    def __init__(\n        self,\n        value_range=(0, 255),\n        num_ops=2,\n        factor=0.5,\n        interpolation=\"bilinear\",\n        seed=None,\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n\n        self.value_range = value_range\n        self.num_ops = num_ops\n        self._set_factor(factor)\n        self.interpolation = interpolation\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n        self.random_shear = layers.RandomShear(\n            x_factor=self.factor,\n            y_factor=self.factor,\n            interpolation=interpolation,\n            seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.random_translation = layers.RandomTranslation(\n            height_factor=self.factor,\n            width_factor=self.factor,\n            interpolation=interpolation,\n            seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.random_rotation = layers.RandomRotation(\n            factor=self.factor,\n            interpolation=interpolation,\n            seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.random_brightness = layers.RandomBrightness(\n            factor=self.factor,\n            value_range=self.value_range,\n            
seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.random_color_degeneration = layers.RandomColorDegeneration(\n            factor=self.factor,\n            value_range=self.value_range,\n            seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.random_contrast = layers.RandomContrast(\n            factor=self.factor,\n            value_range=self.value_range,\n            seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.random_sharpness = layers.RandomSharpness(\n            factor=self.factor,\n            value_range=self.value_range,\n            seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.solarization = layers.Solarization(\n            addition_factor=self.factor,\n            threshold_factor=self.factor,\n            value_range=self.value_range,\n            seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.random_posterization = layers.RandomPosterization(\n            factor=max(1, int(8 * self.factor[1])),\n            value_range=self.value_range,\n            seed=self.seed,\n            data_format=data_format,\n            **kwargs,\n        )\n\n        self.auto_contrast = layers.AutoContrast(\n            value_range=self.value_range, data_format=data_format, **kwargs\n        )\n\n        self.equalization = layers.Equalization(\n            value_range=self.value_range, data_format=data_format, **kwargs\n        )\n\n    def build(self, input_shape):\n        for layer_name in self._AUGMENT_LAYERS:\n            augmentation_layer = getattr(self, layer_name)\n            augmentation_layer.build(input_shape)\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n\n        if 
backend_utils.in_tf_graph():\n            self.backend.set_backend(\"tensorflow\")\n\n            for layer_name in self._AUGMENT_LAYERS:\n                augmentation_layer = getattr(self, layer_name)\n                augmentation_layer.backend.set_backend(\"tensorflow\")\n\n        layer_idxes = self.backend.random.randint(\n            (self.num_ops,),\n            0,\n            len(self._AUGMENT_LAYERS),\n            seed=self._get_seed_generator(self.backend._backend),\n        )\n\n        transformation = {}\n        for layer_name in self._AUGMENT_LAYERS:\n            augmentation_layer = getattr(self, layer_name)\n            transformation[layer_name] = (\n                augmentation_layer.get_random_transformation(\n                    data,\n                    training=training,\n                    seed=self._get_seed_generator(self.backend._backend),\n                )\n            )\n\n        return {\n            \"transforms\": transformation,\n            \"layer_idxes\": layer_idxes,\n        }\n\n    def transform_images(self, images, transformation, training=True):\n        if training:\n            images = self.backend.cast(images, self.compute_dtype)\n\n            layer_idxes = transformation[\"layer_idxes\"]\n            transforms = transformation[\"transforms\"]\n            for i in range(self.num_ops):\n                for idx, (key, value) in enumerate(transforms.items()):\n                    augmentation_layer = getattr(self, key)\n                    images = self.backend.numpy.where(\n                        layer_idxes[i] == idx,\n                        augmentation_layer.transform_images(images, value),\n                        images,\n                    )\n\n        images = self.backend.cast(images, self.compute_dtype)\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n      
  transformation,\n        training=True,\n    ):\n        if training:\n            layer_idxes = transformation[\"layer_idxes\"]\n            transforms = transformation[\"transforms\"]\n            for idx, (key, value) in enumerate(transforms.items()):\n                augmentation_layer = getattr(self, key)\n\n                transformed_bounding_box = (\n                    augmentation_layer.transform_bounding_boxes(\n                        bounding_boxes.copy(), value\n                    )\n                )\n                for i in range(self.num_ops):\n                    bounding_boxes[\"boxes\"] = self.backend.numpy.where(\n                        layer_idxes[i] == idx,\n                        transformed_bounding_box[\"boxes\"],\n                        bounding_boxes[\"boxes\"],\n                    )\n\n                    bounding_boxes[\"labels\"] = self.backend.numpy.where(\n                        layer_idxes[i] == idx,\n                        transformed_bounding_box[\"labels\"],\n                        bounding_boxes[\"labels\"],\n                    )\n\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        if training:\n            layer_idxes = transformation[\"layer_idxes\"]\n            transforms = transformation[\"transforms\"]\n            for i in range(self.num_ops):\n                for idx, (key, value) in enumerate(transforms.items()):\n                    augmentation_layer = getattr(self, key)\n                    segmentation_masks = self.backend.numpy.where(\n                        layer_idxes[i] == idx,\n                        augmentation_layer.transform_segmentation_masks(\n                            segmentation_masks, value\n                        ),\n                        segmentation_masks,\n                    )\n        return segmentation_masks\n\n    def compute_output_shape(self, input_shape):\n        
return input_shape\n\n    def get_config(self):\n        config = {\n            \"value_range\": self.value_range,\n            \"num_ops\": self.num_ops,\n            \"factor\": self.factor,\n            \"interpolation\": self.interpolation,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\nRandAugment.__doc__ = RandAugment.__doc__.replace(\n    \"{{base_image_preprocessing_transform_example}}\",\n    base_image_preprocessing_transform_example.replace(\n        \"{LayerName}\", \"RandAugment\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/rand_augment_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandAugmentTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandAugment,\n            init_kwargs={\n                \"value_range\": (0, 255),\n                \"num_ops\": 2,\n                \"factor\": 1,\n                \"interpolation\": \"nearest\",\n                \"seed\": 1,\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_rand_augment_inference(self):\n        seed = 3481\n        layer = layers.RandAugment()\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_rand_augment_basic(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandAugment(data_format=data_format)\n\n        augmented_image = layer(input_data)\n        self.assertEqual(augmented_image.shape, input_data.shape)\n\n    def test_rand_augment_no_operations(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandAugment(num_ops=0, data_format=data_format)\n\n        augmented_image = layer(input_data)\n        self.assertAllClose(\n            
backend.convert_to_numpy(augmented_image), input_data\n        )\n\n    def test_random_augment_randomness(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n\n        layer = layers.RandAugment(num_ops=11, data_format=data_format)\n        augmented_image = layer(input_data)\n\n        self.assertNotAllClose(\n            backend.convert_to_numpy(augmented_image), input_data\n        )\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandAugment(data_format=data_format)\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n\n    def test_rand_augment_tf_data_bounding_boxes(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            image_shape = (1, 10, 8, 3)\n        else:\n            image_shape = (1, 3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [\n                        [2, 1, 4, 3],\n                        [6, 4, 8, 6],\n                    ]\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data)\n        layer = layers.RandAugment(\n            data_format=data_format,\n            seed=42,\n            bounding_box_format=\"xyxy\",\n        )\n        
ds.map(layer)\n\n    def test_graph_issue(self):\n        input_data = np.random.random((10, 8, 8, 3))\n        layer = layers.RandAugment()\n        ds = (\n            tf_data.Dataset.from_tensor_slices(input_data)\n            .batch(2)\n            .map(lambda x: layer.get_random_transformation(x))\n        )\n\n        key_list = []\n        for output in ds:\n            key_list.append(output[\"layer_idxes\"])\n\n        self.assertNotEqual(len(np.unique(key_list)), 1)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_brightness.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.random.seed_generator import SeedGenerator\n\n\n@keras_export(\"keras.layers.RandomBrightness\")\nclass RandomBrightness(BaseImagePreprocessingLayer):\n    \"\"\"A preprocessing layer which randomly adjusts brightness during training.\n\n    This layer will randomly increase/reduce the brightness for the input RGB\n    images. At inference time, the output will be identical to the input.\n    Call the layer with `training=True` to adjust the brightness of the input.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        factor: Float or a list/tuple of 2 floats between -1.0 and 1.0. The\n            factor is used to determine the lower bound and upper bound of the\n            brightness adjustment. A float value will be chosen randomly between\n            the limits. When -1.0 is chosen, the output image will be black, and\n            when 1.0 is chosen, the image will be fully white.\n            When only one float is provided, eg, 0.2,\n            then -0.2 will be used for lower bound and 0.2\n            will be used for upper bound.\n        value_range: Optional list/tuple of 2 floats\n            for the lower and upper limit\n            of the values of the input data.\n            To make no change, use `[0.0, 1.0]`, e.g., if the image input\n            has been scaled before this layer. Defaults to `[0.0, 255.0]`.\n            The brightness adjustment will be scaled to this range, and the\n            output values will be clipped to this range.\n        seed: optional integer, for fixed RNG behavior.\n\n    Inputs: 3D (HWC) or 4D (NHWC) tensor, with float or int dtype. Input pixel\n        values can be of any range (e.g. 
`[0., 1.)` or `[0, 255]`)\n\n    Output: 3D (HWC) or 4D (NHWC) tensor with brightness adjusted based on the\n        `factor`. By default, the layer will output floats.\n        The output value will be clipped to the range `[0, 255]`,\n        the valid range of RGB colors, and\n        rescaled based on the `value_range` if needed.\n\n    Example:\n\n    ```python\n    random_bright = keras.layers.RandomBrightness(factor=0.2)\n\n    # An image with shape [2, 2, 3]\n    image = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]\n\n    # Assume we randomly select the factor to be 0.1, then it will apply\n    # 0.1 * 255 to all the channels\n    output = random_bright(image, training=True)\n\n    # output will be a float tensor with 25.5 added to each channel.\n    >>> array([[[26.5, 27.5, 28.5],\n                [29.5, 30.5, 31.5]],\n               [[32.5, 33.5, 34.5],\n                [35.5, 36.5, 37.5]]],\n              shape=(2, 2, 3), dtype=float32)\n    ```\n    \"\"\"\n\n    _VALUE_RANGE_VALIDATION_ERROR = (\n        \"The `value_range` argument should be a list of two numbers. 
\"\n    )\n\n    def __init__(self, factor, value_range=(0, 255), seed=None, **kwargs):\n        super().__init__(factor=factor, **kwargs)\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n        self._set_value_range(value_range)\n\n    def _set_value_range(self, value_range):\n        if not isinstance(value_range, (tuple, list)):\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        if len(value_range) != 2:\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        self.value_range = sorted(value_range)\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n        images_shape = self.backend.shape(images)\n        rank = len(images_shape)\n        if rank == 3:\n            rgb_delta_shape = (1, 1, 1)\n        elif rank == 4:\n            # Keep only the batch dim. This will ensure to have same adjustment\n            # with in one image, but different across the images.\n            rgb_delta_shape = [images_shape[0], 1, 1, 1]\n        else:\n            raise ValueError(\n                \"Expected the input image to be rank 3 or 4. 
Received \"\n                f\"inputs.shape={images_shape}\"\n            )\n        if not training:\n            return {\"rgb_delta\": self.backend.numpy.zeros(rgb_delta_shape)}\n\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n        rgb_delta = self.backend.random.uniform(\n            minval=self.factor[0],\n            maxval=self.factor[1],\n            shape=rgb_delta_shape,\n            seed=seed,\n        )\n        rgb_delta = rgb_delta * (self.value_range[1] - self.value_range[0])\n        return {\"rgb_delta\": rgb_delta}\n\n    def transform_images(self, images, transformation, training=True):\n        if training:\n            rgb_delta = transformation[\"rgb_delta\"]\n            rgb_delta = self.backend.cast(rgb_delta, images.dtype)\n            images += rgb_delta\n            return self.backend.numpy.clip(\n                images, self.value_range[0], self.value_range[1]\n            )\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = {\n            \"factor\": self.factor,\n            \"value_range\": self.value_range,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_brightness_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomBrightnessTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomBrightness,\n            init_kwargs={\n                \"factor\": 0.75,\n                \"value_range\": (20, 200),\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_random_brightness_inference(self):\n        seed = 3481\n        layer = layers.RandomBrightness([0, 1.0])\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_correctness(self):\n        seed = 2390\n\n        # Always scale up, but randomly between 0 ~ 255\n        layer = layers.RandomBrightness([0.1, 1.0])\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = backend.convert_to_numpy(layer(inputs))\n        diff = output - inputs\n        diff = backend.convert_to_numpy(diff)\n        self.assertTrue(np.amin(diff) >= 0)\n        self.assertTrue(np.mean(diff) > 0)\n\n        # Always scale down, but randomly between 0 ~ 255\n        layer = layers.RandomBrightness([-1.0, -0.1])\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = backend.convert_to_numpy(layer(inputs))\n        diff = output - inputs\n        self.assertTrue(np.amax(diff) <= 0)\n        self.assertTrue(np.mean(diff) < 0)\n\n    def test_tf_data_compatibility(self):\n        layer = layers.RandomBrightness(factor=0.5, seed=1337)\n        input_data = 
np.random.random((2, 8, 8, 3))\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n\n    def test_value_range_incorrect_type(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"The `value_range` argument should be a list of two numbers.*\",\n        ):\n            layers.RandomBrightness(factor=0.1, value_range=\"incorrect_type\")\n\n    def test_value_range_incorrect_length(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"The `value_range` argument should be a list of two numbers.*\",\n        ):\n            layers.RandomBrightness(factor=0.1, value_range=[10])\n\n    def test_set_factor_incorrect_length(self):\n        layer = layers.RandomBrightness(factor=0.5)\n        with self.assertRaisesRegex(\n            ValueError, \"The `factor` argument should be a number.*\"\n        ):\n            layer._set_factor([0.1])  # Only one element in list\n\n    def test_set_factor_incorrect_type(self):\n        layer = layers.RandomBrightness(factor=0.5)\n        with self.assertRaisesRegex(\n            ValueError, \"The `factor` argument should be a number.*\"\n        ):\n            layer._set_factor(\n                \"invalid_type\"\n            )  # Passing a string instead of a number or a list/tuple of numbers\n\n    def test_factor_range_below_lower_bound(self):\n        with self.assertRaisesRegex(\n            ValueError, \"The `factor` argument should be a number.*\"\n        ):\n            # Passing a value less than -1.0\n            layers.RandomBrightness(factor=-1.1)\n\n    def test_factor_range_above_upper_bound(self):\n        with self.assertRaisesRegex(\n            ValueError, \"The `factor` argument should be a number.*\"\n        ):\n            # Passing a value more than 1.0\n            layers.RandomBrightness(factor=1.1)\n\n    def 
test_randomly_adjust_brightness_input_incorrect_rank(self):\n        layer = layers.RandomBrightness(factor=0.1)\n        wrong_rank_input = np.random.rand(10, 10)\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Expected the input image to be rank 3 or 4.\",\n        ):\n            layer(\n                wrong_rank_input, training=True\n            )  # Call the method that triggers the error\n\n    def test_dict_input(self):\n        layer = layers.RandomBrightness(factor=0.1, bounding_box_format=\"xyxy\")\n        data = {\n            \"images\": np.random.random((2, 4, 5, 3)),\n            \"labels\": np.random.random((2, 7)),\n            \"segmentation_masks\": np.random.random((2, 4, 5, 7)),\n            \"bounding_boxes\": {\n                \"boxes\": np.array([[1, 2, 2, 3]]),\n                \"labels\": np.array([0]),\n            },\n        }\n        transformed_data = layer(data)\n        self.assertEqual(\n            data[\"images\"].shape[:-1],\n            transformed_data[\"segmentation_masks\"].shape[:-1],\n        )\n        self.assertAllClose(data[\"labels\"], transformed_data[\"labels\"])\n        self.assertAllClose(\n            data[\"bounding_boxes\"][\"boxes\"],\n            transformed_data[\"bounding_boxes\"][\"boxes\"],\n        )\n        self.assertAllClose(\n            data[\"bounding_boxes\"][\"labels\"],\n            transformed_data[\"bounding_boxes\"][\"labels\"],\n        )\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_color_example,\n)\nfrom keras.src.random import SeedGenerator\n\n\n@keras_export(\"keras.layers.RandomColorDegeneration\")\nclass RandomColorDegeneration(BaseImagePreprocessingLayer):\n    \"\"\"Randomly performs the color degeneration operation on given images.\n\n    The sharpness operation first converts an image to gray scale, then back to\n    color. It then takes a weighted average between original image and the\n    degenerated image. This makes colors appear more dull.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        factor: A tuple of two floats or a single float.\n            `factor` controls the extent to which the\n            image sharpness is impacted. `factor=0.0` makes this layer perform a\n            no-op operation, while a value of 1.0 uses the degenerated result\n            entirely. Values between 0 and 1 result in linear interpolation\n            between the original image and the sharpened image.\n            Values should be between `0.0` and `1.0`. If a tuple is used, a\n            `factor` is sampled between the two values for every image\n            augmented. If a single float is used, a value between `0.0` and the\n            passed float is sampled. In order to ensure the value is always the\n            same, please pass a tuple with two identical floats: `(0.5, 0.5)`.\n        seed: Integer. 
Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_color_example}}\n    \"\"\"\n\n    _VALUE_RANGE_VALIDATION_ERROR = (\n        \"The `value_range` argument should be a list of two numbers. \"\n    )\n\n    def __init__(\n        self,\n        factor,\n        value_range=(0, 255),\n        data_format=None,\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self._set_factor(factor)\n        self._set_value_range(value_range)\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n    def _set_value_range(self, value_range):\n        if not isinstance(value_range, (tuple, list)):\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        if len(value_range) != 2:\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        self.value_range = sorted(value_range)\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n        images_shape = self.backend.shape(images)\n        rank = len(images_shape)\n        if rank == 3:\n            batch_size = 1\n        elif rank == 4:\n            batch_size = images_shape[0]\n        else:\n            raise ValueError(\n                \"Expected the input image to be rank 3 or 4. 
Received: \"\n                f\"inputs.shape={images_shape}\"\n            )\n\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n\n        factor = self.backend.random.uniform(\n            (batch_size, 1, 1, 1),\n            minval=self.factor[0],\n            maxval=self.factor[1],\n            seed=seed,\n        )\n        return {\"factor\": factor}\n\n    def transform_images(self, images, transformation=None, training=True):\n        if training:\n            images = self.backend.cast(images, self.compute_dtype)\n            factor = self.backend.cast(\n                transformation[\"factor\"], self.compute_dtype\n            )\n            degenerates = self.backend.image.rgb_to_grayscale(\n                images, data_format=self.data_format\n            )\n            images = images + factor * (degenerates - images)\n            images = self.backend.numpy.clip(\n                images, self.value_range[0], self.value_range[1]\n            )\n            images = self.backend.cast(images, self.compute_dtype)\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def transform_bounding_boxes(\n        self, bounding_boxes, transformation, training=True\n    ):\n        return bounding_boxes\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"factor\": self.factor,\n                \"value_range\": self.value_range,\n                \"seed\": self.seed,\n            }\n        )\n        return config\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n\nRandomColorDegeneration.__doc__ = RandomColorDegeneration.__doc__.replace(\n    
\"{{base_image_preprocessing_color_example}}\",\n    base_image_preprocessing_color_example.replace(\n        \"{LayerName}\", \"RandomColorDegeneration\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomColorDegenerationTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomColorDegeneration,\n            init_kwargs={\n                \"factor\": 0.75,\n                \"value_range\": (0, 1),\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_random_color_degeneration_value_range(self):\n        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)\n\n        layer = layers.RandomColorDegeneration(0.2, value_range=(0, 1))\n        adjusted_image = layer(image)\n\n        self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))\n        self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))\n\n    def test_random_color_degeneration_no_op(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.random.random((2, 8, 8, 3))\n        else:\n            inputs = np.random.random((2, 3, 8, 8))\n\n        layer = layers.RandomColorDegeneration((0.5, 0.5))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)\n\n    def test_random_color_degeneration_factor_zero(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.random.random((2, 8, 8, 3))\n        else:\n            inputs = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomColorDegeneration(factor=(0.0, 0.0))\n        result = layer(inputs)\n\n        self.assertAllClose(inputs, result, atol=1e-3, rtol=1e-5)\n\n    def 
test_random_color_degeneration_randomness(self):\n        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5]\n\n        layer = layers.RandomColorDegeneration(0.2)\n        adjusted_images = layer(image)\n\n        self.assertNotAllClose(adjusted_images, image)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomColorDegeneration(\n            factor=0.5, data_format=data_format, seed=1337\n        )\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_color_jitter.py",
    "content": "import keras.src.layers.preprocessing.image_preprocessing.random_brightness as random_brightness  # noqa: E501\nimport keras.src.layers.preprocessing.image_preprocessing.random_contrast as random_contrast  # noqa: E501\nimport keras.src.layers.preprocessing.image_preprocessing.random_hue as random_hue  # noqa: E501\nimport keras.src.layers.preprocessing.image_preprocessing.random_saturation as random_saturation  # noqa: E501\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_color_example,\n)\nfrom keras.src.random.seed_generator import SeedGenerator\nfrom keras.src.utils import backend_utils\n\n\n@keras_export(\"keras.layers.RandomColorJitter\")\nclass RandomColorJitter(BaseImagePreprocessingLayer):\n    \"\"\"RandomColorJitter class randomly apply brightness, contrast, saturation\n    and hue image processing operation sequentially and randomly on the\n    input.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        value_range: the range of values the incoming images will have.\n            Represented as a two number tuple written [low, high].\n            This is typically either `[0, 1]` or `[0, 255]` depending\n            on how your preprocessing pipeline is set up.\n        brightness_factor: Float or a list/tuple of 2 floats between -1.0\n            and 1.0. The factor is used to determine the lower bound and\n            upper bound of the brightness adjustment. A float value will\n            be chosen randomly between the limits. 
When -1.0 is chosen,\n            the output image will be black, and when 1.0 is chosen, the\n            image will be fully white. When only one float is provided,\n            e.g. 0.2, then -0.2 will be used for lower bound and 0.2 will\n            be used for upper bound.\n        contrast_factor: a positive float represented as fraction of value,\n            or a tuple of size 2 representing lower and upper bound. When\n            represented as a single float, lower = upper. The contrast\n            factor will be randomly picked between `[1.0 - lower, 1.0 +\n            upper]`. For any pixel x in the channel, the output will be\n            `(x - mean) * factor + mean` where `mean` is the mean value\n            of the channel.\n        saturation_factor: A tuple of two floats or a single float. `factor`\n            controls the extent to which the image saturation is impacted.\n            `factor=0.5` makes this layer perform a no-op operation.\n            `factor=0.0` makes the image fully grayscale. `factor=1.0`\n            makes the image fully saturated. Values should be between\n            `0.0` and `1.0`. If a tuple is used, a `factor` is sampled\n            between the two values for every image augmented. If a single\n            float is used, a value between `0.0` and the passed float is\n            sampled. To ensure the value is always the same, pass a tuple\n            with two identical floats: `(0.5, 0.5)`.\n        hue_factor: A single float or a tuple of two floats. `factor`\n            controls the extent to which the image hue is impacted.\n            `factor=0.0` makes this layer perform a no-op operation,\n            while a value of `1.0` performs the most aggressive hue\n            adjustment available. If a tuple is used, a `factor` is\n            sampled between the two values for every image augmented.\n            If a single float is used, a value between `0.0` and the\n            passed float is sampled. 
In order to ensure the value is\n            always the same, please pass a tuple with two identical\n            floats: `(0.5, 0.5)`.\n        seed: Integer. Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_color_example}}\n    \"\"\"\n\n    def __init__(\n        self,\n        value_range=(0, 255),\n        brightness_factor=None,\n        contrast_factor=None,\n        saturation_factor=None,\n        hue_factor=None,\n        seed=None,\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self.value_range = value_range\n        self.brightness_factor = brightness_factor\n        self.contrast_factor = contrast_factor\n        self.saturation_factor = saturation_factor\n        self.hue_factor = hue_factor\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n        self.random_brightness = None\n        self.random_contrast = None\n        self.random_saturation = None\n        self.random_hue = None\n\n        if self.brightness_factor is not None:\n            self.random_brightness = random_brightness.RandomBrightness(\n                factor=self.brightness_factor,\n                value_range=self.value_range,\n                seed=self.seed,\n            )\n\n        if self.contrast_factor is not None:\n            self.random_contrast = random_contrast.RandomContrast(\n                factor=self.contrast_factor,\n                value_range=self.value_range,\n                seed=self.seed,\n            )\n\n        if self.saturation_factor is not None:\n            self.random_saturation = random_saturation.RandomSaturation(\n                factor=self.saturation_factor,\n                value_range=self.value_range,\n                seed=self.seed,\n            )\n\n        if self.hue_factor is not None:\n            self.random_hue = random_hue.RandomHue(\n                factor=self.hue_factor,\n                
value_range=self.value_range,\n                seed=self.seed,\n            )\n\n    def build(self, input_shape):\n        if self.brightness_factor is not None:\n            self.random_brightness.build(input_shape)\n\n        if self.contrast_factor is not None:\n            self.random_contrast.build(input_shape)\n\n        if self.saturation_factor is not None:\n            self.random_saturation.build(input_shape)\n\n        if self.hue_factor is not None:\n            self.random_hue.build(input_shape)\n\n    def transform_images(self, images, transformation, training=True):\n        if training:\n            if backend_utils.in_tf_graph():\n                self.backend.set_backend(\"tensorflow\")\n            images = self.backend.cast(images, self.compute_dtype)\n            if self.brightness_factor is not None:\n                if backend_utils.in_tf_graph():\n                    self.random_brightness.backend.set_backend(\"tensorflow\")\n                transformation = (\n                    self.random_brightness.get_random_transformation(\n                        images,\n                        seed=self._get_seed_generator(self.backend._backend),\n                    )\n                )\n                images = self.random_brightness.transform_images(\n                    images, transformation\n                )\n            if self.contrast_factor is not None:\n                if backend_utils.in_tf_graph():\n                    self.random_contrast.backend.set_backend(\"tensorflow\")\n                transformation = self.random_contrast.get_random_transformation(\n                    images, seed=self._get_seed_generator(self.backend._backend)\n                )\n                transformation[\"contrast_factor\"] = self.backend.cast(\n                    transformation[\"contrast_factor\"], dtype=self.compute_dtype\n                )\n                images = self.random_contrast.transform_images(\n                    images, 
transformation\n                )\n            if self.saturation_factor is not None:\n                if backend_utils.in_tf_graph():\n                    self.random_saturation.backend.set_backend(\"tensorflow\")\n                transformation = (\n                    self.random_saturation.get_random_transformation(\n                        images,\n                        seed=self._get_seed_generator(self.backend._backend),\n                    )\n                )\n                images = self.random_saturation.transform_images(\n                    images, transformation\n                )\n            if self.hue_factor is not None:\n                if backend_utils.in_tf_graph():\n                    self.random_hue.backend.set_backend(\"tensorflow\")\n                transformation = self.random_hue.get_random_transformation(\n                    images, seed=self._get_seed_generator(self.backend._backend)\n                )\n                images = self.random_hue.transform_images(\n                    images, transformation\n                )\n            images = self.backend.cast(images, self.compute_dtype)\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = {\n            \"value_range\": self.value_range,\n            \"brightness_factor\": self.brightness_factor,\n            \"contrast_factor\": self.contrast_factor,\n            \"saturation_factor\": self.saturation_factor,\n            \"hue_factor\": self.hue_factor,\n            \"seed\": 
self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\nRandomColorJitter.__doc__ = RandomColorJitter.__doc__.replace(\n    \"{{base_image_preprocessing_color_example}}\",\n    base_image_preprocessing_color_example.replace(\n        \"{LayerName}\", \"RandomColorJitter\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_color_jitter_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomColorJitterTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomColorJitter,\n            init_kwargs={\n                \"value_range\": (20, 200),\n                \"brightness_factor\": 0.2,\n                \"contrast_factor\": 0.2,\n                \"saturation_factor\": 0.2,\n                \"hue_factor\": 0.2,\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_random_color_jitter_inference(self):\n        seed = 3481\n        layer = layers.RandomColorJitter(\n            value_range=(0, 1),\n            brightness_factor=0.1,\n            contrast_factor=0.2,\n            saturation_factor=0.9,\n            hue_factor=0.1,\n        )\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_brightness_only(self):\n        seed = 2390\n        np.random.seed(seed)\n\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.random.random((12, 8, 16, 3))\n        else:\n            inputs = np.random.random((12, 3, 8, 16))\n\n        layer = layers.RandomColorJitter(\n            brightness_factor=[0.5, 0.5], seed=seed\n        )\n        output = backend.convert_to_numpy(layer(inputs))\n\n        layer = layers.RandomBrightness(factor=[0.5, 0.5], seed=seed)\n        sub_output = backend.convert_to_numpy(layer(inputs))\n\n        self.assertAllClose(output, sub_output)\n\n    def 
test_saturation_only(self):\n        seed = 2390\n        np.random.seed(seed)\n\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.random.random((12, 8, 16, 3))\n        else:\n            inputs = np.random.random((12, 3, 8, 16))\n\n        layer = layers.RandomColorJitter(\n            saturation_factor=[0.5, 0.5], seed=seed\n        )\n        output = layer(inputs)\n\n        layer = layers.RandomSaturation(factor=[0.5, 0.5], seed=seed)\n        sub_output = layer(inputs)\n\n        self.assertAllClose(output, sub_output)\n\n    def test_hue_only(self):\n        seed = 2390\n        np.random.seed(seed)\n\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.random.random((12, 8, 16, 3))\n        else:\n            inputs = np.random.random((12, 3, 8, 16))\n\n        layer = layers.RandomColorJitter(hue_factor=[0.5, 0.5], seed=seed)\n        output = layer(inputs)\n\n        layer = layers.RandomHue(factor=[0.5, 0.5], seed=seed)\n        sub_output = layer(inputs)\n\n        self.assertAllClose(output, sub_output)\n\n    def test_contrast_only(self):\n        seed = 2390\n        np.random.seed(seed)\n\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.random.random((12, 8, 16, 3))\n        else:\n            inputs = np.random.random((12, 3, 8, 16))\n\n        layer = layers.RandomColorJitter(contrast_factor=[0.5, 0.5], seed=seed)\n        output = layer(inputs)\n\n        layer = layers.RandomContrast(factor=[0.5, 0.5], seed=seed)\n        sub_output = layer(inputs)\n\n        self.assertAllClose(output, sub_output)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        
else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomColorJitter(\n            value_range=(0, 1),\n            brightness_factor=0.1,\n            contrast_factor=0.2,\n            saturation_factor=0.9,\n            hue_factor=0.1,\n        )\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_contrast.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_color_example,\n)\nfrom keras.src.random.seed_generator import SeedGenerator\n\n\n@keras_export(\"keras.layers.RandomContrast\")\nclass RandomContrast(BaseImagePreprocessingLayer):\n    \"\"\"A preprocessing layer which randomly adjusts contrast during training.\n\n    This layer will randomly adjust the contrast of an image or images\n    by a random factor. Contrast is adjusted independently\n    for each channel of each image during training.\n\n    For each channel, this layer computes the mean of the image pixels in the\n    channel and then adjusts each component `x` of each pixel to\n    `(x - mean) * contrast_factor + mean`.\n\n    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and\n    in integer or floating point dtype.\n    By default, the layer will output floats.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Input shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format.\n\n    Output shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format.\n\n    Args:\n        factor: a positive float represented as fraction of value, or a tuple of\n            size 2 representing lower and upper bound.\n            When represented as a single float, lower = upper.\n            The contrast factor will be randomly picked between\n            `[1.0 - lower, 1.0 + upper]`. 
For any pixel x in the channel,\n            the output will be `(x - mean) * factor + mean`\n            where `mean` is the mean value of the channel.\n        value_range: the range of values the incoming images will have.\n            Represented as a two-number tuple written `[low, high]`. This is\n            typically either `[0, 1]` or `[0, 255]` depending on how your\n            preprocessing pipeline is set up.\n        seed: Integer. Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_color_example}}\n    \"\"\"\n\n    _FACTOR_BOUNDS = (0, 1)\n\n    def __init__(self, factor, value_range=(0, 255), seed=None, **kwargs):\n        super().__init__(**kwargs)\n        self._set_factor(factor)\n        self.value_range = value_range\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n        images_shape = self.backend.shape(images)\n        rank = len(images_shape)\n        if rank == 3:\n            factor_shape = (1, 1, 1)\n        elif rank == 4:\n            # Keep only the batch dim. This ensures the same adjustment\n            # within one image, but different adjustments across images.\n            factor_shape = [images_shape[0], 1, 1, 1]\n        else:\n            raise ValueError(\n
                \"Expected the input image to be rank 3 or 4. Received \"\n                f\"inputs.shape={images_shape}\"\n            )\n\n        if not training:\n            return {\"contrast_factor\": self.backend.numpy.zeros(factor_shape)}\n\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n\n        factor = self.backend.random.uniform(\n            shape=factor_shape,\n            minval=1.0 - self.factor[0],\n            maxval=1.0 + self.factor[1],\n            seed=seed,\n            dtype=self.compute_dtype,\n        )\n        return {\"contrast_factor\": factor}\n\n    def transform_images(self, images, transformation, training=True):\n        if training:\n            contrast_factor = transformation[\"contrast_factor\"]\n            outputs = self._adjust_contrast(images, contrast_factor)\n            outputs = self.backend.numpy.clip(\n                outputs, self.value_range[0], self.value_range[1]\n            )\n            outputs = self.backend.numpy.reshape(\n                outputs, self.backend.shape(images)\n            )\n            return outputs\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def _adjust_contrast(self, inputs, contrast_factor):\n        if self.data_format == \"channels_first\":\n            height_axis = -2\n            width_axis = -1\n        else:\n            height_axis = -3\n            width_axis = -2\n        # reduce mean on height\n        inp_mean = self.backend.numpy.mean(\n            inputs, axis=height_axis, keepdims=True\n        )\n        # reduce mean on width\n        inp_mean = self.backend.numpy.mean(\n            inp_mean, axis=width_axis, keepdims=True\n
        )\n\n        outputs = (inputs - inp_mean) * contrast_factor + inp_mean\n        return outputs\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = {\n            \"factor\": self.factor,\n            \"value_range\": self.value_range,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\nRandomContrast.__doc__ = RandomContrast.__doc__.replace(\n    \"{{base_image_preprocessing_color_example}}\",\n    base_image_preprocessing_color_example.replace(\n        \"{LayerName}\", \"RandomContrast\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_contrast_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomContrastTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomContrast,\n            init_kwargs={\n                \"factor\": 0.75,\n                \"value_range\": (0, 255),\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n        self.run_layer_test(\n            layers.RandomContrast,\n            init_kwargs={\n                \"factor\": 0.75,\n                \"value_range\": (0, 255),\n                \"seed\": 1,\n                \"data_format\": \"channels_first\",\n            },\n            input_shape=(8, 3, 4, 4),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 4),\n        )\n\n    def test_random_contrast_with_value_range_0_to_255(self):\n        seed = 9809\n        np.random.seed(seed)\n\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.random.random((12, 8, 16, 3))\n            height_axis = -3\n            width_axis = -2\n        else:\n            inputs = np.random.random((12, 3, 8, 16))\n            height_axis = -2\n            width_axis = -1\n\n        inputs = backend.convert_to_tensor(inputs, dtype=\"float32\")\n        layer = layers.RandomContrast(\n            factor=0.5, value_range=(0, 255), seed=seed\n        )\n        transformation = layer.get_random_transformation(inputs, training=True)\n        outputs = layer.transform_images(inputs, transformation, training=True)\n\n        # Actual contrast arithmetic\n        np.random.seed(seed)\n        factor = 
backend.convert_to_numpy(transformation[\"contrast_factor\"])\n        inputs = backend.convert_to_numpy(inputs)\n        inp_mean = np.mean(inputs, axis=height_axis, keepdims=True)\n        inp_mean = np.mean(inp_mean, axis=width_axis, keepdims=True)\n        actual_outputs = (inputs - inp_mean) * factor + inp_mean\n        outputs = backend.convert_to_numpy(outputs)\n        actual_outputs = np.clip(actual_outputs, 0, 255)\n\n        self.assertAllClose(outputs, actual_outputs)\n\n    def test_random_contrast_with_value_range_0_to_1(self):\n        seed = 9809\n        np.random.seed(seed)\n\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.random.random((12, 8, 16, 3))\n            height_axis = -3\n            width_axis = -2\n        else:\n            inputs = np.random.random((12, 3, 8, 16))\n            height_axis = -2\n            width_axis = -1\n\n        inputs = backend.convert_to_tensor(inputs, dtype=\"float32\")\n        layer = layers.RandomContrast(factor=0.5, value_range=(0, 1), seed=seed)\n        transformation = layer.get_random_transformation(inputs, training=True)\n        outputs = layer.transform_images(inputs, transformation, training=True)\n\n        # Actual contrast arithmetic\n        np.random.seed(seed)\n        factor = backend.convert_to_numpy(transformation[\"contrast_factor\"])\n        inputs = backend.convert_to_numpy(inputs)\n        inp_mean = np.mean(inputs, axis=height_axis, keepdims=True)\n        inp_mean = np.mean(inp_mean, axis=width_axis, keepdims=True)\n        actual_outputs = (inputs - inp_mean) * factor + inp_mean\n        outputs = backend.convert_to_numpy(outputs)\n        actual_outputs = np.clip(actual_outputs, 0, 1)\n\n        self.assertAllClose(outputs, actual_outputs)\n\n    def test_tf_data_compatibility(self):\n        layer = layers.RandomContrast(factor=0.5, seed=1337)\n        input_data = np.random.random((2, 8, 8, 3))\n   
     ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        next(iter(ds)).numpy()\n\n    def test_dict_input(self):\n        layer = layers.RandomContrast(factor=0.1, bounding_box_format=\"xyxy\")\n        data = {\n            \"images\": np.random.random((2, 4, 5, 3)),\n            \"labels\": np.random.random((2, 7)),\n            \"segmentation_masks\": np.random.random((2, 4, 5, 7)),\n            \"bounding_boxes\": {\n                \"boxes\": np.array([[1, 2, 2, 3]]),\n                \"labels\": np.array([0]),\n            },\n        }\n        transformed_data = layer(data)\n        self.assertEqual(\n            data[\"images\"].shape[:-1],\n            transformed_data[\"segmentation_masks\"].shape[:-1],\n        )\n        self.assertAllClose(data[\"labels\"], transformed_data[\"labels\"])\n        self.assertAllClose(\n            data[\"bounding_boxes\"][\"boxes\"],\n            transformed_data[\"bounding_boxes\"][\"boxes\"],\n        )\n        self.assertAllClose(\n            data[\"bounding_boxes\"][\"labels\"],\n            transformed_data[\"bounding_boxes\"][\"labels\"],\n        )\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_crop.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_transform_example,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    convert_format,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.validation import (  # noqa: E501\n    densify_bounding_boxes,\n)\nfrom keras.src.random.seed_generator import SeedGenerator\n\n\n@keras_export(\"keras.layers.RandomCrop\")\nclass RandomCrop(BaseImagePreprocessingLayer):\n    \"\"\"A preprocessing layer which randomly crops images during training.\n\n    During training, this layer will randomly choose a location to crop images\n    down to a target size. The layer will crop all the images in the same batch\n    to the same cropping location.\n\n    At inference time, and during training if an input image is smaller than the\n    target size, the input will be resized and cropped so as to return the\n    largest possible window in the image that matches the target aspect ratio.\n    If you need to apply random cropping at inference time, set `training` to\n    True when calling the layer.\n\n    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and\n    of integer or floating point dtype. 
By default, the layer will output\n    floats.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Input shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format.\n\n    Output shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., target_height, target_width, channels)`.\n\n    Args:\n        height: Integer, the height of the output shape.\n        width: Integer, the width of the output shape.\n        seed: Integer. Used to create a random seed.\n        **kwargs: Base layer keyword arguments, such as\n            `name` and `dtype`.\n\n    Example:\n\n    {{base_image_preprocessing_transform_example}}\n    \"\"\"\n\n    def __init__(\n        self, height, width, seed=None, data_format=None, name=None, **kwargs\n    ):\n        super().__init__(name=name, **kwargs)\n        self.height = height\n        self.width = width\n        self.seed = (\n            seed if seed is not None else backend.random.make_default_seed()\n        )\n        self.generator = SeedGenerator(seed)\n        self.data_format = backend.standardize_data_format(data_format)\n\n        if self.data_format == \"channels_first\":\n            self.height_axis = -2\n            self.width_axis = -1\n        elif self.data_format == \"channels_last\":\n            self.height_axis = -3\n            self.width_axis = -2\n\n        self.supports_masking = False\n        self.supports_jit = False\n        self._convert_input_args = False\n        self._allow_non_tensor_positional_args = True\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n\n        if isinstance(data, dict):\n            input_shape = self.backend.shape(data[\"images\"])\n        else:\n            input_shape = 
self.backend.shape(data)\n\n        input_height, input_width = (\n            input_shape[self.height_axis],\n            input_shape[self.width_axis],\n        )\n        if input_height is None or input_width is None:\n            raise ValueError(\n                \"RandomCrop requires the input to have a fully defined \"\n                f\"height and width. Received: images.shape={input_shape}\"\n            )\n\n        if training and input_height > self.height and input_width > self.width:\n            h_start = self.backend.cast(\n                self.backend.random.uniform(\n                    (),\n                    0,\n                    maxval=float(input_height - self.height + 1),\n                    seed=seed,\n                ),\n                \"int32\",\n            )\n            w_start = self.backend.cast(\n                self.backend.random.uniform(\n                    (),\n                    0,\n                    maxval=float(input_width - self.width + 1),\n                    seed=seed,\n                ),\n                \"int32\",\n            )\n        else:\n            crop_height = int(float(input_width * self.height) / self.width)\n            crop_height = max(min(input_height, crop_height), 1)\n            crop_width = int(float(input_height * self.width) / self.height)\n            crop_width = max(min(input_width, crop_width), 1)\n            h_start = int(float(input_height - crop_height) / 2)\n            w_start = int(float(input_width - crop_width) / 2)\n\n        return h_start, w_start\n\n    def transform_images(self, images, transformation, training=True):\n        if training:\n            images = self.backend.cast(images, self.compute_dtype)\n            crop_box_hstart, crop_box_wstart = transformation\n            crop_height = self.height\n            crop_width = self.width\n\n            if self.data_format == \"channels_last\":\n                if len(images.shape) == 4:\n                    images = 
images[\n                        :,\n                        crop_box_hstart : crop_box_hstart + crop_height,\n                        crop_box_wstart : crop_box_wstart + crop_width,\n                        :,\n                    ]\n                else:\n                    images = images[\n                        crop_box_hstart : crop_box_hstart + crop_height,\n                        crop_box_wstart : crop_box_wstart + crop_width,\n                        :,\n                    ]\n            else:\n                if len(images.shape) == 4:\n                    images = images[\n                        :,\n                        :,\n                        crop_box_hstart : crop_box_hstart + crop_height,\n                        crop_box_wstart : crop_box_wstart + crop_width,\n                    ]\n                else:\n                    images = images[\n                        :,\n                        crop_box_hstart : crop_box_hstart + crop_height,\n                        crop_box_wstart : crop_box_wstart + crop_width,\n                    ]\n\n            shape = self.backend.shape(images)\n            new_height = shape[self.height_axis]\n            new_width = shape[self.width_axis]\n            if (\n                not isinstance(new_height, int)\n                or not isinstance(new_width, int)\n                or new_height != self.height\n                or new_width != self.width\n            ):\n                # Resize images if size mismatch or\n                # if size mismatch cannot be determined\n                # (in the case of a TF dynamic shape).\n                images = self.backend.image.resize(\n                    images,\n                    size=(self.height, self.width),\n                    data_format=self.data_format,\n                )\n                # Resize may have upcasted the outputs\n                images = self.backend.cast(images, self.compute_dtype)\n        return images\n\n    def 
transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        \"\"\"\n        bounding_boxes = {\n            \"boxes\": (batch, num_boxes, 4),  # left-top-right-bottom (xyxy)\n            \"labels\": (batch, num_boxes, num_classes),\n        }\n        or\n        bounding_boxes = {\n            \"boxes\": (num_boxes, 4),\n            \"labels\": (num_boxes, num_classes),\n        }\n        \"\"\"\n\n        if training:\n            h_start, w_start = transformation\n            if not self.backend.is_tensor(bounding_boxes[\"boxes\"]):\n                bounding_boxes = densify_bounding_boxes(\n                    bounding_boxes, backend=self.backend\n                )\n            boxes = bounding_boxes[\"boxes\"]\n            # Convert to standard xyxy, as operations are done in xyxy by\n            # default.\n            boxes = convert_format(\n                boxes=boxes,\n                source=self.bounding_box_format,\n                target=\"xyxy\",\n                height=self.height,\n                width=self.width,\n            )\n            h_start = self.backend.cast(h_start, boxes.dtype)\n            w_start = self.backend.cast(w_start, boxes.dtype)\n            # Shift x coordinates by the horizontal crop offset (w_start) and\n            # y coordinates by the vertical crop offset (h_start).\n            if len(self.backend.shape(boxes)) == 3:\n                boxes = self.backend.numpy.stack(\n                    [\n                        self.backend.numpy.maximum(boxes[:, :, 0] - w_start, 0),\n                        self.backend.numpy.maximum(boxes[:, :, 1] - h_start, 0),\n                        self.backend.numpy.maximum(boxes[:, :, 2] - w_start, 0),\n                        self.backend.numpy.maximum(boxes[:, :, 3] - h_start, 0),\n                    ],\n                    axis=-1,\n                )\n            else:\n                boxes = self.backend.numpy.stack(\n
                    [\n                        self.backend.numpy.maximum(boxes[:, 0] - w_start, 0),\n                        self.backend.numpy.maximum(boxes[:, 1] - h_start, 0),\n                        self.backend.numpy.maximum(boxes[:, 2] - w_start, 0),\n                        self.backend.numpy.maximum(boxes[:, 3] - h_start, 0),\n                    ],\n                    axis=-1,\n                )\n\n            # Convert back to the user-defined bounding box format\n            boxes = convert_format(\n                boxes=boxes,\n                source=\"xyxy\",\n                target=self.bounding_box_format,\n                height=self.height,\n                width=self.width,\n            )\n\n            return {\n                \"boxes\": boxes,\n                \"labels\": bounding_boxes[\"labels\"],\n            }\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return self.transform_images(segmentation_masks, transformation)\n\n    def compute_output_shape(self, input_shape, *args, **kwargs):\n        input_shape = list(input_shape)\n        input_shape[self.height_axis] = self.height\n        input_shape[self.width_axis] = self.width\n        return tuple(input_shape)\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"height\": self.height,\n                \"width\": self.width,\n                \"seed\": self.seed,\n                \"data_format\": self.data_format,\n            }\n        )\n        return config\n\n\nRandomCrop.__doc__ = RandomCrop.__doc__.replace(\n    \"{{base_image_preprocessing_transform_example}}\",\n    base_image_preprocessing_transform_example.replace(\n        \"{LayerName}\", \"RandomCrop\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_crop_test.py",
    "content": "import numpy as np\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomCropTest(testing.TestCase):\n    def test_random_crop(self):\n        self.run_layer_test(\n            layers.RandomCrop,\n            init_kwargs={\n                \"height\": 2,\n                \"width\": 2,\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(1, 3, 4, 3),\n            supports_masking=False,\n            run_training_check=False,\n            expected_output_shape=(1, 2, 2, 3),\n        )\n        self.run_layer_test(\n            layers.RandomCrop,\n            init_kwargs={\n                \"height\": 2,\n                \"width\": 2,\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(3, 4, 3),\n            supports_masking=False,\n            run_training_check=False,\n            expected_output_shape=(2, 2, 3),\n        )\n        self.run_layer_test(\n            layers.RandomCrop,\n            init_kwargs={\n                \"height\": 2,\n                \"width\": 2,\n                \"data_format\": \"channels_first\",\n            },\n            input_shape=(1, 3, 3, 4),\n            supports_masking=False,\n            run_training_check=False,\n            expected_output_shape=(1, 3, 2, 2),\n        )\n        self.run_layer_test(\n            layers.RandomCrop,\n            init_kwargs={\n                \"height\": 2,\n                \"width\": 2,\n                \"data_format\": \"channels_first\",\n            },\n            input_shape=(3, 3, 4),\n            supports_masking=False,\n            run_training_check=False,\n            expected_output_shape=(3, 2, 2),\n        )\n\n    def test_random_crop_full(self):\n        np.random.seed(1337)\n        height, width = 8, 16\n        if backend.config.image_data_format() == \"channels_last\":\n        
    input_shape = (12, 8, 16, 3)\n        else:\n            input_shape = (12, 3, 8, 16)\n        inp = np.random.random(input_shape)\n        layer = layers.RandomCrop(height, width)\n        actual_output = layer(inp, training=False)\n        self.assertAllClose(inp, actual_output)\n\n    def test_random_crop_partial(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (12, 8, 16, 3)\n            output_shape = (12, 8, 8, 3)\n        else:\n            input_shape = (12, 3, 8, 16)\n            output_shape = (12, 3, 8, 8)\n        self.run_layer_test(\n            layers.RandomCrop,\n            init_kwargs={\n                \"height\": 8,\n                \"width\": 8,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    def test_predicting_with_longer_height(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (12, 8, 16, 3)\n            output_shape = (12, 10, 8, 3)\n        else:\n            input_shape = (12, 3, 8, 16)\n            output_shape = (12, 3, 10, 8)\n        self.run_layer_test(\n            layers.RandomCrop,\n            init_kwargs={\n                \"height\": 10,\n                \"width\": 8,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    def test_predicting_with_longer_width(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (12, 8, 16, 3)\n            output_shape = (12, 8, 18, 3)\n        else:\n            input_shape = (12, 3, 8, 16)\n            output_shape = (12, 3, 8, 18)\n        self.run_layer_test(\n            layers.RandomCrop,\n            init_kwargs={\n                \"height\": 
8,\n                \"width\": 18,\n            },\n            input_shape=input_shape,\n            expected_output_shape=output_shape,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    def test_tf_data_compatibility(self):\n        layer = layers.RandomCrop(8, 9)\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 10, 12, 3)\n            output_shape = (2, 8, 9, 3)\n        else:\n            input_shape = (2, 3, 10, 12)\n            output_shape = (2, 3, 8, 9)\n        input_data = np.random.random(input_shape)\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        output = next(iter(ds)).numpy()\n        self.assertEqual(tuple(output.shape), output_shape)\n\n    def test_dict_input(self):\n        layer = layers.RandomCrop(\n            3, 3, data_format=\"channels_last\", bounding_box_format=\"xyxy\"\n        )\n        data = {\n            \"images\": np.random.random((2, 4, 5, 3)),\n            \"labels\": np.random.random((2, 7)),\n            \"segmentation_masks\": np.random.random((2, 4, 5, 7)),\n            \"bounding_boxes\": {\n                \"boxes\": np.array([[1, 2, 2, 3]]),\n                \"labels\": np.array([0]),\n            },\n        }\n        transformed_data = layer(data)\n        self.assertEqual(\n            data[\"images\"].shape[:-1],\n            transformed_data[\"segmentation_masks\"].shape[:-1],\n        )\n        self.assertAllClose(data[\"labels\"], transformed_data[\"labels\"])\n        self.assertEqual(data[\"bounding_boxes\"][\"boxes\"].shape, (1, 4))\n        self.assertAllClose(\n            data[\"bounding_boxes\"][\"labels\"],\n            transformed_data[\"bounding_boxes\"][\"labels\"],\n        )\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_transform_example,\n)\nfrom keras.src.random.seed_generator import SeedGenerator\n\n\n@keras_export(\"keras.layers.RandomElasticTransform\")\nclass RandomElasticTransform(BaseImagePreprocessingLayer):\n    \"\"\"A preprocessing layer that applies random elastic transformations.\n\n    This layer distorts input images by applying elastic deformations,\n    simulating a physically realistic transformation. The magnitude of the\n    distortion is controlled by the `scale` parameter, while the `factor`\n    determines the probability of applying the transformation.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        factor: A single float or a tuple of two floats.\n            `factor` controls the probability of applying the transformation.\n            - `factor=0.0` ensures no erasing is applied.\n            - `factor=1.0` means erasing is always applied.\n            - If a tuple `(min, max)` is provided, a probability value\n              is sampled between `min` and `max` for each image.\n            - If a single float is provided, a probability is sampled\n              between `0.0` and the given float.\n            Default is 1.0.\n        scale: A float or a tuple of two floats defining the magnitude of\n            the distortion applied.\n            - If a tuple `(min, max)` is provided, a random scale value is\n              sampled within this range.\n            - If a single float is provided, a random scale value is sampled\n              between `0.0` and the given float.\n            Default is 1.0.\n 
       interpolation: Interpolation mode. Supported values: `\"nearest\"`,\n            `\"bilinear\"`.\n        fill_mode: Points outside the boundaries of the input are filled\n            according to the given mode. Available methods are `\"constant\"`,\n            `\"nearest\"`, `\"wrap\"` and `\"reflect\"`. Defaults to `\"constant\"`.\n            - `\"reflect\"`: `(d c b a | a b c d | d c b a)`\n                The input is extended by reflecting about the edge of the last\n                pixel.\n            - `\"constant\"`: `(k k k k | a b c d | k k k k)`\n                The input is extended by filling all values beyond\n                the edge with the same constant value k specified by\n                `fill_value`.\n            - `\"wrap\"`: `(a b c d | a b c d | a b c d)`\n                The input is extended by wrapping around to the opposite edge.\n            - `\"nearest\"`: `(a a a a | a b c d | d d d d)`\n                The input is extended by the nearest pixel.\n            Note that when using torch backend, `\"reflect\"` is redirected to\n            `\"mirror\"` `(c d c b | a b c d | c b a b)` because torch does not\n            support `\"reflect\"`.\n            Note that torch backend does not support `\"wrap\"`.\n        fill_value: a float represents the value to be filled outside the\n            boundaries when `fill_mode=\"constant\"`.\n        value_range: the range of values the incoming images will have.\n            Represented as a two-number tuple written `[low, high]`. This is\n            typically either `[0, 1]` or `[0, 255]` depending on how your\n            preprocessing pipeline is set up.\n        seed: Integer. 
Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_transform_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_BOUNDS = (0, 1)\n    _SUPPORTED_INTERPOLATION = (\"nearest\", \"bilinear\")\n    _SUPPORTED_FILL_MODES = {\n        \"constant\",\n        \"nearest\",\n        \"wrap\",\n        \"mirror\",\n        \"reflect\",\n    }\n\n    def __init__(\n        self,\n        factor=1.0,\n        scale=1.0,\n        interpolation=\"bilinear\",\n        fill_mode=\"reflect\",\n        fill_value=0.0,\n        value_range=(0, 255),\n        seed=None,\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self._set_factor(factor)\n        self.scale = self._set_factor_by_name(scale, \"scale\")\n        self.interpolation = interpolation\n        self.fill_mode = fill_mode\n        self.fill_value = fill_value\n        self.value_range = value_range\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n        if interpolation not in self._SUPPORTED_INTERPOLATION:\n            raise NotImplementedError(\n                f\"Unknown `interpolation` {interpolation}. Expected one of \"\n                f\"{self._SUPPORTED_INTERPOLATION}.\"\n            )\n\n        if fill_mode not in self._SUPPORTED_FILL_MODES:\n            raise NotImplementedError(\n                f\"Unknown `fill_mode` {fill_mode}. 
Expected one of \"\n                f\"{self._SUPPORTED_FILL_MODES}.\"\n            )\n\n        if self.data_format == \"channels_first\":\n            self.height_axis = -2\n            self.width_axis = -1\n            self.channel_axis = -3\n        else:\n            self.height_axis = -3\n            self.width_axis = -2\n            self.channel_axis = -1\n\n    def _set_factor_by_name(self, factor, name):\n        error_msg = (\n            f\"The `{name}` argument should be a number \"\n            \"(or a list of two numbers) \"\n            \"in the range \"\n            f\"[{self._FACTOR_BOUNDS[0]}, {self._FACTOR_BOUNDS[1]}]. \"\n            f\"Received: factor={factor}\"\n        )\n        if isinstance(factor, (tuple, list)):\n            if len(factor) != 2:\n                raise ValueError(error_msg)\n            if (\n                factor[0] > self._FACTOR_BOUNDS[1]\n                or factor[1] < self._FACTOR_BOUNDS[0]\n            ):\n                raise ValueError(error_msg)\n            lower, upper = sorted(factor)\n        elif isinstance(factor, (int, float)):\n            if (\n                factor < self._FACTOR_BOUNDS[0]\n                or factor > self._FACTOR_BOUNDS[1]\n            ):\n                raise ValueError(error_msg)\n            factor = abs(factor)\n            lower, upper = [max(-factor, self._FACTOR_BOUNDS[0]), factor]\n        else:\n            raise ValueError(error_msg)\n        return lower, upper\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n\n        if (self.scale[1] == 0) or (self.factor[1] == 0):\n            return None\n\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n\n        images_shape = self.backend.shape(images)\n        unbatched = len(images_shape) == 3\n        if unbatched:\n            batch_size = 1\n        else:\n            batch_size 
= images_shape[0]\n\n        seed = seed or self._get_seed_generator(self.backend._backend)\n\n        transformation_probability = self.backend.random.uniform(\n            shape=(batch_size,),\n            minval=self.factor[0],\n            maxval=self.factor[1],\n            seed=seed,\n        )\n\n        random_threshold = self.backend.random.uniform(\n            shape=(batch_size,),\n            minval=0.0,\n            maxval=1.0,\n            seed=seed,\n        )\n        apply_transform = random_threshold < transformation_probability\n\n        distortion_factor = self.backend.random.uniform(\n            shape=(),\n            minval=self.scale[0],\n            maxval=self.scale[1],\n            seed=seed,\n            dtype=self.compute_dtype,\n        )\n\n        return {\n            \"apply_transform\": apply_transform,\n            \"distortion_factor\": distortion_factor,\n            \"seed\": seed,\n        }\n\n    def get_elastic_transform_params(self, height, width, factor):\n        alpha_scale = 0.1 * factor\n        sigma_scale = 0.05 * factor\n\n        alpha = max(height, width) * alpha_scale\n        sigma = min(height, width) * sigma_scale\n\n        return alpha, sigma\n\n    def _transform_images(self, images, transformation, interpolation):\n        if transformation is None:\n            return images\n\n        apply_transform = transformation[\"apply_transform\"]\n        distortion_factor = transformation[\"distortion_factor\"]\n        seed = transformation[\"seed\"]\n\n        height, width = (\n            images.shape[self.height_axis],\n            images.shape[self.width_axis],\n        )\n\n        alpha, sigma = self.get_elastic_transform_params(\n            height, width, distortion_factor\n        )\n\n        transformed_images = self.backend.image.elastic_transform(\n            images,\n            alpha=alpha,\n            sigma=sigma,\n            interpolation=interpolation,\n            
fill_mode=self.fill_mode,\n            fill_value=self.fill_value,\n            seed=seed,\n            data_format=self.data_format,\n        )\n\n        apply_transform = (\n            apply_transform[:, None, None]\n            if len(images.shape) == 3\n            else apply_transform[:, None, None, None]\n        )\n\n        images = self.backend.numpy.where(\n            apply_transform,\n            transformed_images,\n            images,\n        )\n\n        images = self.backend.numpy.clip(\n            images, self.value_range[0], self.value_range[1]\n        )\n\n        images = self.backend.cast(images, self.compute_dtype)\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"factor\": self.factor,\n            \"scale\": self.scale,\n            \"interpolation\": self.interpolation,\n            \"fill_mode\": self.fill_mode,\n            \"fill_value\": self.fill_value,\n            \"value_range\": self.value_range,\n            \"seed\": self.seed,\n        }\n        return {**base_config, **config}\n\n\nRandomElasticTransform.__doc__ = RandomElasticTransform.__doc__.replace(\n    \"{{base_image_preprocessing_transform_example}}\",\n    base_image_preprocessing_transform_example.replace(\n        \"{LayerName}\", \"RandomElasticTransform\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_elastic_transform_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomElasticTransformTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomElasticTransform,\n            init_kwargs={\n                \"factor\": 1.0,\n                \"scale\": 0.5,\n                \"interpolation\": \"bilinear\",\n                \"fill_mode\": \"reflect\",\n                \"fill_value\": 0,\n                \"value_range\": (0, 255),\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n            run_training_check=False,\n        )\n\n    def test_random_elastic_transform_inference(self):\n        seed = 3481\n        layer = layers.RandomElasticTransform()\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_random_elastic_transform_no_op(self):\n        seed = 3481\n        layer = layers.RandomElasticTransform(factor=0)\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs)\n        self.assertAllClose(inputs, output)\n\n        layer = layers.RandomElasticTransform(scale=0)\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs)\n        self.assertAllClose(inputs, output)\n\n    def test_random_elastic_transform_basic(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.zeros((8, 8, 1))\n            inputs[3:5, 3:5, :] = 1.0\n        else:\n            
inputs = np.zeros((1, 8, 8))\n            inputs[:, 3:5, 3:5] = 1.0\n\n        layer = layers.RandomElasticTransform(data_format=data_format)\n\n        transformation = {\n            \"apply_transform\": np.array([True]),\n            \"distortion_factor\": np.float32(0.9109325),\n            \"seed\": 42,\n        }\n\n        output = layer.transform_images(inputs, transformation)\n\n        self.assertNotAllClose(inputs, output)\n        self.assertEqual(inputs.shape, output.shape)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomElasticTransform(data_format=data_format)\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output_numpy = output.numpy()\n            self.assertEqual(output_numpy.shape, input_data.shape)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_erasing.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_color_example,\n)\nfrom keras.src.random import SeedGenerator\n\n\n@keras_export(\"keras.layers.RandomErasing\")\nclass RandomErasing(BaseImagePreprocessingLayer):\n    \"\"\"Random Erasing data augmentation technique.\n\n    Random Erasing is a data augmentation method where random patches of\n    an image are erased (replaced by a constant value or noise)\n    during training to improve generalization.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    References:\n       - [Random Erasing paper](https://arxiv.org/abs/1708.04896).\n\n    Args:\n        factor: A single float or a tuple of two floats.\n            `factor` controls the probability of applying the transformation.\n            - `factor=0.0` ensures no erasing is applied.\n            - `factor=1.0` means erasing is always applied.\n            - If a tuple `(min, max)` is provided, a probability value\n              is sampled between `min` and `max` for each image.\n            - If a single float is provided, a probability is sampled\n              between `0.0` and the given float.\n            Default is 1.0.\n        scale: A tuple of two floats representing the aspect ratio range of\n            the erased patch. This defines the width-to-height ratio of\n            the patch to be erased. It can help control the rw shape of\n            the erased region. Default is (0.02, 0.33).\n        fill_value: A value to fill the erased region with. 
This can be set to\n            a constant value or `None` to sample a random value\n            from a normal distribution. Default is `None`.\n        value_range: the range of values the incoming images will have.\n            Represented as a two-number tuple written `[low, high]`. This is\n            typically either `[0, 1]` or `[0, 255]` depending on how your\n            preprocessing pipeline is set up.\n        seed: Integer. Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_color_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_BOUNDS = (0, 1)\n\n    def __init__(\n        self,\n        factor=1.0,\n        scale=(0.02, 0.33),\n        fill_value=None,\n        value_range=(0, 255),\n        seed=None,\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self._set_factor(factor)\n        self.scale = self._set_factor_by_name(scale, \"scale\")\n        self.fill_value = fill_value\n        self.value_range = value_range\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n        if self.data_format == \"channels_first\":\n            self.height_axis = -2\n            self.width_axis = -1\n            self.channel_axis = -3\n        else:\n            self.height_axis = -3\n            self.width_axis = -2\n            self.channel_axis = -1\n\n    def _set_factor_by_name(self, factor, name):\n        error_msg = (\n            f\"The `{name}` argument should be a number \"\n            \"(or a list of two numbers) \"\n            \"in the range \"\n            f\"[{self._FACTOR_BOUNDS[0]}, {self._FACTOR_BOUNDS[1]}]. 
\"\n            f\"Received: factor={factor}\"\n        )\n        if isinstance(factor, (tuple, list)):\n            if len(factor) != 2:\n                raise ValueError(error_msg)\n            if (\n                factor[0] > self._FACTOR_BOUNDS[1]\n                or factor[1] < self._FACTOR_BOUNDS[0]\n            ):\n                raise ValueError(error_msg)\n            lower, upper = sorted(factor)\n        elif isinstance(factor, (int, float)):\n            if (\n                factor < self._FACTOR_BOUNDS[0]\n                or factor > self._FACTOR_BOUNDS[1]\n            ):\n                raise ValueError(error_msg)\n            factor = abs(factor)\n            lower, upper = [max(-factor, self._FACTOR_BOUNDS[0]), factor]\n        else:\n            raise ValueError(error_msg)\n        return lower, upper\n\n    def _compute_crop_bounds(self, batch_size, image_length, crop_ratio, seed):\n        crop_length = self.backend.cast(\n            crop_ratio * image_length, dtype=self.compute_dtype\n        )\n\n        start_pos = self.backend.random.uniform(\n            shape=[batch_size],\n            minval=0,\n            maxval=1,\n            dtype=self.compute_dtype,\n            seed=seed,\n        ) * (image_length - crop_length)\n\n        end_pos = start_pos + crop_length\n\n        return start_pos, end_pos\n\n    def _generate_batch_mask(self, images_shape, box_corners):\n        def _generate_grid_xy(image_height, image_width):\n            grid_y, grid_x = self.backend.numpy.meshgrid(\n                self.backend.numpy.arange(\n                    image_height, dtype=self.compute_dtype\n                ),\n                self.backend.numpy.arange(\n                    image_width, dtype=self.compute_dtype\n                ),\n                indexing=\"ij\",\n            )\n            if self.data_format == \"channels_last\":\n                grid_y = self.backend.cast(\n                    grid_y[None, :, :, None], 
dtype=self.compute_dtype\n                )\n                grid_x = self.backend.cast(\n                    grid_x[None, :, :, None], dtype=self.compute_dtype\n                )\n            else:\n                grid_y = self.backend.cast(\n                    grid_y[None, None, :, :], dtype=self.compute_dtype\n                )\n                grid_x = self.backend.cast(\n                    grid_x[None, None, :, :], dtype=self.compute_dtype\n                )\n            return grid_x, grid_y\n\n        image_height, image_width = (\n            images_shape[self.height_axis],\n            images_shape[self.width_axis],\n        )\n        grid_x, grid_y = _generate_grid_xy(image_height, image_width)\n\n        x0, x1, y0, y1 = box_corners\n\n        x0 = x0[:, None, None, None]\n        y0 = y0[:, None, None, None]\n        x1 = x1[:, None, None, None]\n        y1 = y1[:, None, None, None]\n\n        batch_masks = (\n            (grid_x >= x0) & (grid_x < x1) & (grid_y >= y0) & (grid_y < y1)\n        )\n        batch_masks = self.backend.numpy.repeat(\n            batch_masks, images_shape[self.channel_axis], axis=self.channel_axis\n        )\n\n        return batch_masks\n\n    def _get_fill_value(self, images, images_shape, seed):\n        fill_value = self.fill_value\n        if fill_value is None:\n            fill_value = (\n                self.backend.random.normal(\n                    images_shape,\n                    dtype=self.compute_dtype,\n                    seed=seed,\n                )\n                * self.value_range[1]\n            )\n        else:\n            error_msg = (\n                \"The `fill_value` argument should be a number \"\n                \"(or a list of three numbers) \"\n            )\n            if isinstance(fill_value, (tuple, list)):\n                if len(fill_value) != 3:\n                    raise ValueError(error_msg)\n                fill_value = self.backend.numpy.full_like(\n                    
images, fill_value, dtype=self.compute_dtype\n                )\n            elif isinstance(fill_value, (int, float)):\n                fill_value = (\n                    self.backend.numpy.ones(\n                        images_shape, dtype=self.compute_dtype\n                    )\n                    * fill_value\n                )\n            else:\n                raise ValueError(error_msg)\n        fill_value = self.backend.numpy.clip(\n            fill_value, self.value_range[0], self.value_range[1]\n        )\n        return fill_value\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n\n        images_shape = self.backend.shape(images)\n        rank = len(images_shape)\n        if rank == 3:\n            batch_size = 1\n        elif rank == 4:\n            batch_size = images_shape[0]\n        else:\n            raise ValueError(\n                \"Expected the input image to be rank 3 or 4. 
Received \"\n                f\"inputs.shape={images_shape}\"\n            )\n\n        image_height = images_shape[self.height_axis]\n        image_width = images_shape[self.width_axis]\n\n        seed = seed or self._get_seed_generator(self.backend._backend)\n\n        mix_weight = self.backend.random.uniform(\n            shape=(batch_size, 2),\n            minval=self.scale[0],\n            maxval=self.scale[1],\n            dtype=self.compute_dtype,\n            seed=seed,\n        )\n\n        mix_weight = self.backend.numpy.sqrt(mix_weight)\n\n        x0, x1 = self._compute_crop_bounds(\n            batch_size, image_width, mix_weight[:, 0], seed\n        )\n        y0, y1 = self._compute_crop_bounds(\n            batch_size, image_height, mix_weight[:, 1], seed\n        )\n\n        batch_masks = self._generate_batch_mask(\n            images_shape,\n            (x0, x1, y0, y1),\n        )\n\n        erase_probability = self.backend.random.uniform(\n            shape=(batch_size,),\n            minval=self.factor[0],\n            maxval=self.factor[1],\n            seed=seed,\n        )\n\n        random_threshold = self.backend.random.uniform(\n            shape=(batch_size,),\n            minval=0.0,\n            maxval=1.0,\n            seed=seed,\n        )\n        apply_erasing = random_threshold < erase_probability\n\n        fill_value = self._get_fill_value(images, images_shape, seed)\n\n        return {\n            \"apply_erasing\": apply_erasing,\n            \"batch_masks\": batch_masks,\n            \"fill_value\": fill_value,\n        }\n\n    def transform_images(self, images, transformation=None, training=True):\n        if training:\n            images = self.backend.cast(images, self.compute_dtype)\n            batch_masks = transformation[\"batch_masks\"]\n            apply_erasing = transformation[\"apply_erasing\"]\n            fill_value = transformation[\"fill_value\"]\n\n            erased_images = self.backend.numpy.where(\n      
          batch_masks,\n                fill_value,\n                images,\n            )\n\n            images = self.backend.numpy.where(\n                apply_erasing[:, None, None, None],\n                erased_images,\n                images,\n            )\n\n        images = self.backend.cast(images, self.compute_dtype)\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = {\n            \"factor\": self.factor,\n            \"scale\": self.scale,\n            \"fill_value\": self.fill_value,\n            \"value_range\": self.value_range,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\nRandomErasing.__doc__ = RandomErasing.__doc__.replace(\n    \"{{base_image_preprocessing_color_example}}\",\n    base_image_preprocessing_color_example.replace(\n        \"{LayerName}\", \"RandomErasing\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_erasing_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomErasingTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomErasing,\n            init_kwargs={\n                \"factor\": 1.0,\n                \"scale\": 0.5,\n                \"fill_value\": 0,\n                \"value_range\": (0, 255),\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_random_erasing_inference(self):\n        seed = 3481\n        layer = layers.RandomErasing()\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_random_erasing_no_op(self):\n        seed = 3481\n        layer = layers.RandomErasing(factor=0)\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs)\n        self.assertAllClose(inputs, output)\n\n        layer = layers.RandomErasing(scale=0)\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs)\n        self.assertAllClose(inputs, output)\n\n    def test_random_erasing_basic(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.ones((2, 2, 1))\n            expected_output = np.array([[[[0.0], [1.0]], [[1.0], [1.0]]]])\n\n        else:\n            inputs = np.ones((1, 2, 2))\n\n            expected_output = np.array(\n                [[[[0.0, 0.0], [1.0, 1.0]], [[1.0, 1.0], [1.0, 1.0]]]]\n            )\n\n        
layer = layers.RandomErasing(data_format=data_format)\n\n        transformation = {\n            \"apply_erasing\": np.asarray([True]),\n            \"batch_masks\": np.asarray(\n                [[[[True], [False]], [[False], [False]]]]\n            ),\n            \"fill_value\": 0,\n        }\n\n        output = layer.transform_images(inputs, transformation)\n\n        self.assertAllClose(expected_output, output)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomErasing(data_format=data_format)\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_flip.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_transform_example,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    clip_to_image_size,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    convert_format,\n)\nfrom keras.src.random.seed_generator import SeedGenerator\nfrom keras.src.utils import backend_utils\n\nHORIZONTAL = \"horizontal\"\nVERTICAL = \"vertical\"\nHORIZONTAL_AND_VERTICAL = \"horizontal_and_vertical\"\n\n\n@keras_export(\"keras.layers.RandomFlip\")\nclass RandomFlip(BaseImagePreprocessingLayer):\n    \"\"\"A preprocessing layer which randomly flips images during training.\n\n    This layer will flip the images horizontally and or vertically based on the\n    `mode` attribute. During inference time, the output will be identical to\n    input. Call the layer with `training=True` to flip the input.\n    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and\n    of integer or floating point dtype.\n    By default, the layer will output floats.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Input shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format.\n\n    Output shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format.\n\n    Args:\n        mode: String indicating which flip mode to use. 
Can be `\"horizontal\"`,\n            `\"vertical\"`, or `\"horizontal_and_vertical\"`. `\"horizontal\"` is a\n            left-right flip and `\"vertical\"` is a top-bottom flip. Defaults to\n            `\"horizontal_and_vertical\"`\n        seed: Integer. Used to create a random seed.\n        **kwargs: Base layer keyword arguments, such as\n            `name` and `dtype`.\n\n    Example:\n\n    {{base_image_preprocessing_transform_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n\n    def __init__(\n        self,\n        mode=HORIZONTAL_AND_VERTICAL,\n        seed=None,\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n        self.mode = mode\n        self._convert_input_args = False\n        self._allow_non_tensor_positional_args = True\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n        shape = self.backend.core.shape(images)\n        if len(shape) == 3:\n            flips_shape = (1, 1, 1)\n        else:\n            flips_shape = (shape[0], 1, 1, 1)\n\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n\n        flips = self.backend.numpy.less_equal(\n            self.backend.random.uniform(shape=flips_shape, seed=seed), 0.5\n        )\n        return {\"flips\": flips, \"input_shape\": shape}\n\n    def transform_images(self, images, transformation, training=True):\n        images = self.backend.cast(images, self.compute_dtype)\n        if training:\n            return self._flip_inputs(images, transformation)\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n       
 self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        def _flip_boxes_horizontal(boxes):\n            x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1)\n            outputs = self.backend.numpy.concatenate(\n                [1 - x3, x2, 1 - x1, x4], axis=-1\n            )\n            return outputs\n\n        def _flip_boxes_vertical(boxes):\n            x1, x2, x3, x4 = self.backend.numpy.split(boxes, 4, axis=-1)\n            outputs = self.backend.numpy.concatenate(\n                [x1, 1 - x4, x3, 1 - x2], axis=-1\n            )\n            return outputs\n\n        def _transform_xyxy(boxes, box_flips):\n            bboxes = boxes[\"boxes\"]\n            if self.mode in {HORIZONTAL, HORIZONTAL_AND_VERTICAL}:\n                bboxes = self.backend.numpy.where(\n                    box_flips,\n                    _flip_boxes_horizontal(bboxes),\n                    bboxes,\n                )\n            if self.mode in {VERTICAL, HORIZONTAL_AND_VERTICAL}:\n                bboxes = self.backend.numpy.where(\n                    box_flips,\n                    _flip_boxes_vertical(bboxes),\n                    bboxes,\n                )\n            return bboxes\n\n        if training:\n            if backend_utils.in_tf_graph():\n                self.backend.set_backend(\"tensorflow\")\n\n            flips = self.backend.numpy.squeeze(transformation[\"flips\"], axis=-1)\n\n            if self.data_format == \"channels_first\":\n                height_axis = -2\n                width_axis = -1\n            else:\n                height_axis = -3\n                width_axis = -2\n\n            input_height, input_width = (\n                transformation[\"input_shape\"][height_axis],\n                transformation[\"input_shape\"][width_axis],\n            )\n\n            bounding_boxes = convert_format(\n                bounding_boxes,\n                source=self.bounding_box_format,\n                
target=\"rel_xyxy\",\n                height=input_height,\n                width=input_width,\n            )\n\n            bounding_boxes[\"boxes\"] = _transform_xyxy(bounding_boxes, flips)\n\n            bounding_boxes = clip_to_image_size(\n                bounding_boxes=bounding_boxes,\n                height=input_height,\n                width=input_width,\n                bounding_box_format=\"xyxy\",\n            )\n\n            bounding_boxes = convert_format(\n                bounding_boxes,\n                source=\"rel_xyxy\",\n                target=self.bounding_box_format,\n                height=input_height,\n                width=input_width,\n            )\n\n            self.backend.reset()\n\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return self.transform_images(\n            segmentation_masks, transformation, training=training\n        )\n\n    def _flip_inputs(self, inputs, transformation):\n        if transformation is None:\n            return inputs\n\n        flips = transformation[\"flips\"]\n        inputs_shape = self.backend.shape(inputs)\n        unbatched = len(inputs_shape) == 3\n        if unbatched:\n            inputs = self.backend.numpy.expand_dims(inputs, axis=0)\n\n        flipped_outputs = inputs\n        if self.data_format == \"channels_last\":\n            horizontal_axis = -2\n            vertical_axis = -3\n        else:\n            horizontal_axis = -1\n            vertical_axis = -2\n\n        if self.mode == HORIZONTAL or self.mode == HORIZONTAL_AND_VERTICAL:\n            flipped_outputs = self.backend.numpy.where(\n                flips,\n                self.backend.numpy.flip(flipped_outputs, axis=horizontal_axis),\n                flipped_outputs,\n            )\n        if self.mode == VERTICAL or self.mode == HORIZONTAL_AND_VERTICAL:\n            flipped_outputs = self.backend.numpy.where(\n       
         flips,\n                self.backend.numpy.flip(flipped_outputs, axis=vertical_axis),\n                flipped_outputs,\n            )\n        if unbatched:\n            flipped_outputs = self.backend.numpy.squeeze(\n                flipped_outputs, axis=0\n            )\n        return flipped_outputs\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"seed\": self.seed,\n                \"mode\": self.mode,\n                \"data_format\": self.data_format,\n            }\n        )\n        return config\n\n\nRandomFlip.__doc__ = RandomFlip.__doc__.replace(\n    \"{{base_image_preprocessing_transform_example}}\",\n    base_image_preprocessing_transform_example.replace(\n        \"{LayerName}\", \"RandomFlip\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_flip_test.py",
    "content": "import unittest.mock\n\nimport numpy as np\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src import utils\n\n\nclass MockedRandomFlip(layers.RandomFlip):\n    def call(self, inputs, training=True):\n        unbatched = len(inputs.shape) == 3\n        batch_size = 1 if unbatched else self.backend.shape(inputs)[0]\n        mocked_value = self.backend.numpy.full(\n            (batch_size, 1, 1, 1), 0.1, dtype=\"float32\"\n        )\n        with unittest.mock.patch.object(\n            self.backend.random,\n            \"uniform\",\n            return_value=mocked_value,\n        ):\n            out = super().call(inputs, training=training)\n        return out\n\n\nclass RandomFlipTest(testing.TestCase):\n    @parameterized.named_parameters(\n        (\"random_flip_horizontal\", \"horizontal\"),\n        (\"random_flip_vertical\", \"vertical\"),\n        (\"random_flip_both\", \"horizontal_and_vertical\"),\n    )\n    def test_random_flip(self, mode):\n        run_training_check = False if backend.backend() == \"numpy\" else True\n        self.run_layer_test(\n            layers.RandomFlip,\n            init_kwargs={\n                \"mode\": mode,\n            },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 4),\n            supports_masking=False,\n            run_training_check=run_training_check,\n        )\n\n    def test_random_flip_horizontal(self):\n        run_training_check = False if backend.backend() == \"numpy\" else True\n        utils.set_random_seed(0)\n        # Test 3D input: shape (1*2*3)\n        self.run_layer_test(\n            MockedRandomFlip,\n            init_kwargs={\n                \"mode\": \"horizontal\",\n                \"data_format\": \"channels_last\",\n                \"seed\": 42,\n            },\n            input_data=np.asarray([[[2, 3, 4], 
[5, 6, 7]]]),\n            expected_output=backend.convert_to_tensor([[[5, 6, 7], [2, 3, 4]]]),\n            supports_masking=False,\n            run_training_check=run_training_check,\n        )\n        # Test 4D input: shape (2*1*2*3)\n        self.run_layer_test(\n            MockedRandomFlip,\n            init_kwargs={\n                \"mode\": \"horizontal\",\n                \"data_format\": \"channels_last\",\n                \"seed\": 42,\n            },\n            input_data=np.asarray(\n                [\n                    [[[2, 3, 4], [5, 6, 7]]],\n                    [[[2, 3, 4], [5, 6, 7]]],\n                ]\n            ),\n            expected_output=backend.convert_to_tensor(\n                [\n                    [[[5, 6, 7], [2, 3, 4]]],\n                    [[[5, 6, 7], [2, 3, 4]]],\n                ]\n            ),\n            supports_masking=False,\n            run_training_check=run_training_check,\n        )\n\n    def test_random_flip_vertical(self):\n        run_training_check = False if backend.backend() == \"numpy\" else True\n        utils.set_random_seed(0)\n        # Test 3D input: shape (2*1*3)\n        self.run_layer_test(\n            MockedRandomFlip,\n            init_kwargs={\n                \"mode\": \"vertical\",\n                \"data_format\": \"channels_last\",\n                \"seed\": 42,\n            },\n            input_data=np.asarray([[[2, 3, 4]], [[5, 6, 7]]]),\n            expected_output=backend.convert_to_tensor(\n                [[[5, 6, 7]], [[2, 3, 4]]]\n            ),\n            supports_masking=False,\n            run_training_check=run_training_check,\n        )\n        # Test 4D input: shape (2*2*1*3)\n        self.run_layer_test(\n            MockedRandomFlip,\n            init_kwargs={\n                \"mode\": \"vertical\",\n                \"data_format\": \"channels_last\",\n                \"seed\": 42,\n            },\n            input_data=np.asarray(\n                [\n         
           [\n                        [[2, 3, 4]],\n                        [[5, 6, 7]],\n                    ],\n                    [\n                        [[2, 3, 4]],\n                        [[5, 6, 7]],\n                    ],\n                ]\n            ),\n            expected_output=backend.convert_to_tensor(\n                [\n                    [[[5, 6, 7]], [[2, 3, 4]]],\n                    [[[5, 6, 7]], [[2, 3, 4]]],\n                ]\n            ),\n            supports_masking=False,\n            run_training_check=run_training_check,\n        )\n\n    def test_tf_data_compatibility(self):\n        # Test 3D input: shape (2, 1, 3)\n        layer = layers.RandomFlip(\n            \"vertical\", data_format=\"channels_last\", seed=42\n        )\n        input_data = np.array([[[2, 3, 4]], [[5, 6, 7]]])\n        expected_output = np.array([[[5, 6, 7]], [[2, 3, 4]]])\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        output = next(iter(ds)).numpy()\n        self.assertAllClose(output, expected_output)\n        # Test 4D input: shape (2, 2, 1, 3)\n        layer = layers.RandomFlip(\n            \"vertical\", data_format=\"channels_last\", seed=42\n        )\n        input_data = np.array(\n            [\n                [\n                    [[2, 3, 4]],\n                    [[5, 6, 7]],\n                ],\n                [\n                    [[2, 3, 4]],\n                    [[5, 6, 7]],\n                ],\n            ]\n        )\n        expected_output = np.array(\n            [\n                [[[5, 6, 7]], [[2, 3, 4]]],\n                [[[5, 6, 7]], [[2, 3, 4]]],\n            ]\n        )\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        output = next(iter(ds)).numpy()\n        self.assertAllClose(output, expected_output)\n\n    @parameterized.named_parameters(\n        (\n            \"with_horizontal\",\n            \"horizontal\",\n            [[4, 
1, 6, 3], [0, 4, 2, 6]],\n        ),\n        (\n            \"with_vertical\",\n            \"vertical\",\n            [[2, 7, 4, 9], [6, 4, 8, 6]],\n        ),\n        (\n            \"with_horizontal_and_vertical\",\n            \"horizontal_and_vertical\",\n            [[4, 7, 6, 9], [0, 4, 2, 6]],\n        ),\n    )\n    def test_random_flip_bounding_boxes(self, mode, expected_boxes):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            image_shape = (10, 8, 3)\n        else:\n            image_shape = (3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [2, 1, 4, 3],\n                    [6, 4, 8, 6],\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n        random_flip_layer = layers.RandomFlip(\n            mode,\n            data_format=data_format,\n            seed=42,\n            bounding_box_format=\"xyxy\",\n        )\n\n        transformation = {\n            \"flips\": np.asarray([[True]]),\n            \"input_shape\": input_image.shape,\n        }\n        output = random_flip_layer.transform_bounding_boxes(\n            input_data[\"bounding_boxes\"],\n            transformation=transformation,\n            training=True,\n        )\n\n        self.assertAllClose(output[\"boxes\"], expected_boxes)\n\n    @parameterized.named_parameters(\n        (\n            \"with_horizontal\",\n            \"horizontal\",\n            [[4, 1, 6, 3], [0, 4, 2, 6]],\n        ),\n        (\n            \"with_vertical\",\n            \"vertical\",\n            [[2, 7, 4, 9], [6, 4, 8, 6]],\n        ),\n        (\n            \"with_horizontal_and_vertical\",\n            \"horizontal_and_vertical\",\n            [[4, 7, 6, 9], [0, 4, 2, 6]],\n        ),\n    )\n  
  def test_random_flip_tf_data_bounding_boxes(self, mode, expected_boxes):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            image_shape = (1, 10, 8, 3)\n        else:\n            image_shape = (1, 3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [\n                        [2, 1, 4, 3],\n                        [6, 4, 8, 6],\n                    ]\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data)\n        random_flip_layer = layers.RandomFlip(\n            mode,\n            data_format=data_format,\n            seed=42,\n            bounding_box_format=\"xyxy\",\n        )\n\n        transformation = {\n            \"flips\": np.asarray([[True]]),\n            \"input_shape\": input_image.shape,\n        }\n        ds = ds.map(\n            lambda x: random_flip_layer.transform_bounding_boxes(\n                x[\"bounding_boxes\"],\n                transformation=transformation,\n                training=True,\n            )\n        )\n\n        output = next(iter(ds))\n        expected_boxes = np.array(expected_boxes)\n        self.assertAllClose(output[\"boxes\"], expected_boxes)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_color_example,\n)\nfrom keras.src.random import SeedGenerator\n\n\n@keras_export(\"keras.layers.RandomGaussianBlur\")\nclass RandomGaussianBlur(BaseImagePreprocessingLayer):\n    \"\"\"Applies random Gaussian blur to images for data augmentation.\n\n    This layer performs a Gaussian blur operation on input images with a\n    randomly selected degree of blurring, controlled by the `factor` and\n    `sigma` arguments.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        factor: A single float or a tuple of two floats.\n            `factor` controls the extent to which the image hue is impacted.\n            `factor=0.0` makes this layer perform a no-op operation,\n            while a value of `1.0` performs the most aggressive\n            blurring available. If a tuple is used, a `factor` is\n            sampled between the two values for every image augmented. If a\n            single float is used, a value between `0.0` and the passed float is\n            sampled. Default is 1.0.\n        kernel_size: Integer. Size of the Gaussian kernel used for blurring.\n            Must be an odd integer. Default is 3.\n        sigma: Float or tuple of two floats. Standard deviation of the Gaussian\n            kernel. Controls the intensity of the blur. If a tuple is provided,\n            a value is sampled between the two for each image. Default is 1.0.\n        value_range: the range of values the incoming images will have.\n            Represented as a two-number tuple written `[low, high]`. 
This is\n            typically either `[0, 1]` or `[0, 255]` depending on how your\n            preprocessing pipeline is set up.\n        seed: Integer. Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_color_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_BOUNDS = (0, 1)\n\n    def __init__(\n        self,\n        factor=1.0,\n        kernel_size=3,\n        sigma=1.0,\n        value_range=(0, 255),\n        data_format=None,\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self._set_factor(factor)\n        self.kernel_size = self._set_kernel_size(kernel_size, \"kernel_size\")\n        self.sigma = self._set_factor_by_name(sigma, \"sigma\")\n        self.value_range = value_range\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n    def _set_kernel_size(self, factor, name):\n        error_msg = f\"{name} must be an odd number. Received: {name}={factor}\"\n        if isinstance(factor, (tuple, list)):\n            if len(factor) != 2:\n                error_msg = (\n                    f\"The `{name}` argument should be a number \"\n                    \"(or a list of two numbers) \"\n                    f\"Received: {name}={factor}\"\n                )\n                raise ValueError(error_msg)\n            if (factor[0] % 2 == 0) or (factor[1] % 2 == 0):\n                raise ValueError(error_msg)\n            lower, upper = factor\n        elif isinstance(factor, (int, float)):\n            if factor % 2 == 0:\n                raise ValueError(error_msg)\n            lower, upper = factor, factor\n        else:\n            raise ValueError(error_msg)\n\n        return lower, upper\n\n    def _set_factor_by_name(self, factor, name):\n        error_msg = (\n            f\"The `{name}` argument should be a number \"\n            \"(or a list of two numbers) \"\n            \"in the range \"\n            
f\"[{self._FACTOR_BOUNDS[0]}, {self._FACTOR_BOUNDS[1]}]. \"\n            f\"Received: factor={factor}\"\n        )\n        if isinstance(factor, (tuple, list)):\n            if len(factor) != 2:\n                raise ValueError(error_msg)\n            if (\n                factor[0] > self._FACTOR_BOUNDS[1]\n                or factor[1] < self._FACTOR_BOUNDS[0]\n            ):\n                raise ValueError(error_msg)\n            lower, upper = sorted(factor)\n        elif isinstance(factor, (int, float)):\n            if (\n                factor < self._FACTOR_BOUNDS[0]\n                or factor > self._FACTOR_BOUNDS[1]\n            ):\n                raise ValueError(error_msg)\n            factor = abs(factor)\n            lower, upper = [max(-factor, self._FACTOR_BOUNDS[0]), factor]\n        else:\n            raise ValueError(error_msg)\n        return lower, upper\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n\n        images_shape = self.backend.shape(images)\n        rank = len(images_shape)\n        if rank == 3:\n            batch_size = 1\n        elif rank == 4:\n            batch_size = images_shape[0]\n        else:\n            raise ValueError(\n                \"Expected the input image to be rank 3 or 4. 
Received \"\n                f\"inputs.shape={images_shape}\"\n            )\n\n        seed = seed or self._get_seed_generator(self.backend._backend)\n\n        blur_probability = self.backend.random.uniform(\n            shape=(batch_size,),\n            minval=self.factor[0],\n            maxval=self.factor[1],\n            seed=seed,\n        )\n\n        random_threshold = self.backend.random.uniform(\n            shape=(batch_size,),\n            minval=0.0,\n            maxval=1.0,\n            seed=seed,\n        )\n        should_apply_blur = random_threshold < blur_probability\n\n        blur_factor = (\n            self.backend.random.uniform(\n                shape=(2,),\n                minval=self.sigma[0],\n                maxval=self.sigma[1],\n                seed=seed,\n                dtype=self.compute_dtype,\n            )\n            + 1e-6\n        )\n\n        return {\n            \"should_apply_blur\": should_apply_blur,\n            \"blur_factor\": blur_factor,\n        }\n\n    def transform_images(self, images, transformation=None, training=True):\n        images = self.backend.cast(images, self.compute_dtype)\n        if training and transformation is not None:\n            blur_factor = transformation[\"blur_factor\"]\n            should_apply_blur = transformation[\"should_apply_blur\"]\n\n            blur_images = self.backend.image.gaussian_blur(\n                images,\n                kernel_size=self.kernel_size,\n                sigma=blur_factor,\n                data_format=self.data_format,\n            )\n\n            images = self.backend.numpy.where(\n                should_apply_blur[:, None, None, None],\n                blur_images,\n                images,\n            )\n\n            images = self.backend.numpy.clip(\n                images, self.value_range[0], self.value_range[1]\n            )\n\n            images = self.backend.cast(images, dtype=self.compute_dtype)\n\n        return images\n\n    def 
transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def transform_bounding_boxes(\n        self, bounding_boxes, transformation, training=True\n    ):\n        return bounding_boxes\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"factor\": self.factor,\n                \"kernel_size\": self.kernel_size,\n                \"sigma\": self.sigma,\n                \"value_range\": self.value_range,\n                \"seed\": self.seed,\n            }\n        )\n        return config\n\n\nRandomGaussianBlur.__doc__ = RandomGaussianBlur.__doc__.replace(\n    \"{{base_image_preprocessing_color_example}}\",\n    base_image_preprocessing_color_example.replace(\n        \"{LayerName}\", \"RandomGaussianBlur\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_gaussian_blur_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.backend import convert_to_tensor\n\n\nclass RandomGaussianBlurTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomGaussianBlur,\n            init_kwargs={\n                \"factor\": 1.0,\n                \"kernel_size\": 3,\n                \"sigma\": 0,\n                \"value_range\": (0, 255),\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_random_erasing_inference(self):\n        seed = 3481\n        layer = layers.RandomGaussianBlur()\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_random_erasing_no_op(self):\n        seed = 3481\n        layer = layers.RandomGaussianBlur(factor=0)\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs)\n        self.assertAllClose(inputs, output)\n\n    def test_random_erasing_basic(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.ones((1, 2, 2, 3))\n            expected_output = np.asarray(\n                [\n                    [\n                        [[0.7273, 0.7273, 0.7273], [0.7273, 0.7273, 0.7273]],\n                        [[0.7273, 0.7273, 0.7273], [0.7273, 0.7273, 0.7273]],\n                    ]\n                ]\n            )\n\n        else:\n            inputs = np.ones((1, 3, 2, 2))\n            expected_output = np.asarray(\n                [\n       
             [\n                        [[0.7273, 0.7273], [0.7273, 0.7273]],\n                        [[0.7273, 0.7273], [0.7273, 0.7273]],\n                        [[0.7273, 0.7273], [0.7273, 0.7273]],\n                    ]\n                ]\n            )\n\n        layer = layers.RandomGaussianBlur(data_format=data_format)\n\n        transformation = {\n            \"blur_factor\": convert_to_tensor([0.3732, 0.8654]),\n            \"should_apply_blur\": convert_to_tensor([True]),\n        }\n        output = layer.transform_images(inputs, transformation)\n\n        self.assertAllClose(\n            expected_output,\n            output,\n            atol=1e-4,\n            rtol=1e-4,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomGaussianBlur(data_format=data_format)\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_color_example,\n)\n\n\n@keras_export(\"keras.layers.RandomGrayscale\")\nclass RandomGrayscale(BaseImagePreprocessingLayer):\n    \"\"\"Preprocessing layer for random conversion of RGB images to grayscale.\n\n    This layer randomly converts input images to grayscale with a specified\n    factor. When applied, it maintains the original number of channels\n    but sets all channels to the same grayscale value. This can be useful\n    for data augmentation and training models to be robust to color\n    variations.\n\n    The conversion preserves the perceived luminance of the original color\n    image using standard RGB to grayscale conversion coefficients. Images\n    that are not selected for conversion remain unchanged.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        factor: Float between 0 and 1, specifying the factor of\n            converting each image to grayscale. Defaults to 0.5. A value of\n            1.0 means all images will be converted, while 0.0 means no images\n            will be converted.\n        data_format: String, one of `\"channels_last\"` (default) or\n            `\"channels_first\"`. 
The ordering of the dimensions in the inputs.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)` while `\"channels_first\"`\n            corresponds to inputs with shape\n            `(batch, channels, height, width)`.\n        seed: Integer. Used to create a random seed.\n\n    Input shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format,\n        or `(..., channels, height, width)`, in `\"channels_first\"` format.\n\n    Output shape:\n        Same as input shape. The output maintains the same number of channels\n        as the input, even for grayscale-converted images where all channels\n        will have the same value.\n\n    Example:\n\n    {{base_image_preprocessing_color_example}}\n    \"\"\"\n\n    def __init__(self, factor=0.5, data_format=None, seed=None, **kwargs):\n        super().__init__(**kwargs)\n        if factor < 0 or factor > 1:\n            raise ValueError(\n                f\"`factor` should be between 0 and 1. 
Received: factor={factor}\"\n            )\n        self.factor = factor\n        self.data_format = backend.standardize_data_format(data_format)\n        self.seed = seed\n        self.generator = self.backend.random.SeedGenerator(seed)\n\n    def get_random_transformation(self, images, training=True, seed=None):\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n        # Base case: Unbatched data\n        batch_size = 1\n        if len(images.shape) == 4:\n            # This is a batch of images (4D input)\n            batch_size = self.backend.core.shape(images)[0]\n\n        random_values = self.backend.random.uniform(\n            shape=(batch_size,),\n            minval=0,\n            maxval=1,\n            seed=seed,\n        )\n        should_apply = self.backend.numpy.expand_dims(\n            random_values < self.factor, axis=[1, 2, 3]\n        )\n        return should_apply\n\n    def transform_images(self, images, transformation, training=True):\n        if training:\n            should_apply = (\n                transformation\n                if transformation is not None\n                else self.get_random_transformation(images)\n            )\n\n            grayscale_images = self.backend.image.rgb_to_grayscale(\n                images, data_format=self.data_format\n            )\n            return self.backend.numpy.where(\n                should_apply, grayscale_images, images\n            )\n        return images\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def compute_output_spec(self, inputs, **kwargs):\n        return backend.KerasTensor(\n            inputs.shape, dtype=inputs.dtype, sparse=inputs.sparse\n        )\n\n    def transform_bounding_boxes(self, bounding_boxes, **kwargs):\n        return bounding_boxes\n\n    def transform_labels(self, labels, transformations=None, **kwargs):\n        return labels\n\n    def transform_segmentation_masks(\n   
     self, segmentation_masks, transformations=None, **kwargs\n    ):\n        return segmentation_masks\n\n    def get_config(self):\n        config = super().get_config()\n        config.update({\"factor\": self.factor})\n        return config\n\n\nRandomGrayscale.__doc__ = RandomGrayscale.__doc__.replace(\n    \"{{base_image_preprocessing_color_example}}\",\n    base_image_preprocessing_color_example.replace(\n        \"{LayerName}\", \"RandomGrayscale\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_grayscale_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass RandomGrayscaleTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomGrayscale,\n            init_kwargs={\n                \"factor\": 0.5,\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(1, 2, 2, 3),\n            supports_masking=False,\n            expected_output_shape=(1, 2, 2, 3),\n        )\n\n        self.run_layer_test(\n            layers.RandomGrayscale,\n            init_kwargs={\n                \"factor\": 0.5,\n                \"data_format\": \"channels_first\",\n            },\n            input_shape=(1, 3, 2, 2),\n            supports_masking=False,\n            expected_output_shape=(1, 3, 2, 2),\n        )\n\n    @parameterized.named_parameters(\n        (\"channels_last\", \"channels_last\"), (\"channels_first\", \"channels_first\")\n    )\n    def test_grayscale_conversion(self, data_format):\n        if data_format == \"channels_last\":\n            xs = np.random.uniform(0, 255, size=(2, 4, 4, 3)).astype(np.float32)\n            layer = layers.RandomGrayscale(factor=1.0, data_format=data_format)\n            transformed = ops.convert_to_numpy(layer(xs))\n            self.assertEqual(transformed.shape[-1], 3)\n            for img in transformed:\n                r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]\n                self.assertTrue(np.allclose(r, g) and np.allclose(g, b))\n        else:\n            xs = np.random.uniform(0, 255, size=(2, 3, 4, 4)).astype(np.float32)\n            layer = layers.RandomGrayscale(factor=1.0, data_format=data_format)\n            transformed = ops.convert_to_numpy(layer(xs))\n            
self.assertEqual(transformed.shape[1], 3)\n            for img in transformed:\n                r, g, b = img[0], img[1], img[2]\n                self.assertTrue(np.allclose(r, g) and np.allclose(g, b))\n\n    def test_invalid_factor(self):\n        with self.assertRaises(ValueError):\n            layers.RandomGrayscale(factor=-0.1)\n\n        with self.assertRaises(ValueError):\n            layers.RandomGrayscale(factor=1.1)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3)) * 255\n        else:\n            input_data = np.random.random((2, 3, 8, 8)) * 255\n\n        layer = layers.RandomGrayscale(factor=0.5, data_format=data_format)\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n\n        for output in ds.take(1):\n            output_array = output.numpy()\n            self.assertEqual(output_array.shape, input_data.shape)\n\n    def test_grayscale_with_single_color_image(self):\n        test_cases = [\n            # batched inputs\n            (np.full((1, 4, 4, 3), 128, dtype=np.float32), \"channels_last\"),\n            (np.full((1, 3, 4, 4), 128, dtype=np.float32), \"channels_first\"),\n            # unbatched inputs\n            (np.full((4, 4, 3), 128, dtype=np.float32), \"channels_last\"),\n            (np.full((3, 4, 4), 128, dtype=np.float32), \"channels_first\"),\n        ]\n\n        for xs, data_format in test_cases:\n            layer = layers.RandomGrayscale(factor=1.0, data_format=data_format)\n            transformed = ops.convert_to_numpy(layer(xs))\n\n            # Determine if the input was batched\n            is_batched = len(xs.shape) == 4\n\n            # If batched, select the first image from the batch for inspection.\n            # Otherwise, use the transformed image directly.\n            # `image_to_inspect` will always be a 3D tensor.\n           
 if is_batched:\n                image_to_inspect = transformed[0]\n            else:\n                image_to_inspect = transformed\n\n            if data_format == \"channels_last\":\n                # image_to_inspect has shape (H, W, C),\n                # get the first channel [:, :, 0]\n                channel_data = image_to_inspect[:, :, 0]\n            else:  # data_format == \"channels_first\"\n                # image_to_inspect has shape (C, H, W),\n                # get the first channel [0, :, :]\n                channel_data = image_to_inspect[0, :, :]\n\n            unique_vals = np.unique(channel_data)\n            self.assertEqual(len(unique_vals), 1)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_hue.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\n\n\n@keras_export(\"keras.layers.RandomHue\")\nclass RandomHue(BaseImagePreprocessingLayer):\n    \"\"\"Randomly adjusts the hue on given images.\n\n    This layer will randomly increase/reduce the hue for the input RGB\n    images.\n\n    The image hue is adjusted by converting the image(s) to HSV and rotating the\n    hue channel (H) by delta. The image is then converted back to RGB.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        factor: A single float or a tuple of two floats.\n            `factor` controls the extent to which the\n            image hue is impacted. `factor=0.0` makes this layer perform a\n            no-op operation, while a value of `1.0` performs the most aggressive\n            contrast adjustment available. If a tuple is used, a `factor` is\n            sampled between the two values for every image augmented. If a\n            single float is used, a value between `0.0` and the passed float is\n            sampled. In order to ensure the value is always the same, please\n            pass a tuple with two identical floats: `(0.5, 0.5)`.\n        value_range: the range of values the incoming images will have.\n            Represented as a two-number tuple written `[low, high]`. This is\n            typically either `[0, 1]` or `[0, 255]` depending on how your\n            preprocessing pipeline is set up.\n        seed: Integer. 
Used to create a random seed.\n\n    Example:\n\n    ```python\n    (images, labels), _ = keras.datasets.cifar10.load_data()\n    random_hue = keras.layers.RandomHue(factor=0.5, value_range=[0, 1])\n    images = keras.ops.cast(images, \"float32\")\n    augmented_images_batch = random_hue(images[:8])\n    ```\n    \"\"\"\n\n    _USE_BASE_FACTOR = True\n    _FACTOR_BOUNDS = (0, 1)\n\n    def __init__(\n        self,\n        factor,\n        value_range=(0, 255),\n        data_format=None,\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self._set_factor(factor)\n        self.value_range = value_range\n        self.seed = seed\n        self.generator = self.backend.random.SeedGenerator(seed)\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n        images_shape = self.backend.shape(images)\n        rank = len(images_shape)\n        if rank == 3:\n            batch_size = 1\n        elif rank == 4:\n            batch_size = images_shape[0]\n        else:\n            raise ValueError(\n                \"Expected the input image to be rank 3 or 4. 
Received \"\n                f\"inputs.shape={images_shape}\"\n            )\n\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n        invert = self.backend.random.uniform((batch_size,), seed=seed)\n\n        invert = self.backend.numpy.where(\n            invert > 0.5,\n            -self.backend.numpy.ones_like(invert),\n            self.backend.numpy.ones_like(invert),\n        )\n        factor = self.backend.random.uniform(\n            (batch_size,),\n            minval=self.factor[0],\n            maxval=self.factor[1],\n            seed=seed,\n        )\n        return {\"factor\": invert * factor * 0.5}\n\n    def transform_images(self, images, transformation=None, training=True):\n        def _apply_random_hue(images, transformation):\n            images = self.backend.cast(images, self.compute_dtype)\n            images = self._transform_value_range(\n                images, self.value_range, (0, 1)\n            )\n            adjust_factors = transformation[\"factor\"]\n            adjust_factors = self.backend.cast(adjust_factors, images.dtype)\n            adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)\n            adjust_factors = self.backend.numpy.expand_dims(adjust_factors, -1)\n            images = self.backend.image.rgb_to_hsv(\n                images, data_format=self.data_format\n            )\n            if self.data_format == \"channels_first\":\n                h_channel = images[:, 0, :, :] + adjust_factors\n                h_channel = self.backend.numpy.where(\n                    h_channel > 1.0, h_channel - 1.0, h_channel\n                )\n                h_channel = self.backend.numpy.where(\n                    h_channel < 0.0, h_channel + 1.0, h_channel\n                )\n                images = self.backend.numpy.stack(\n                    [h_channel, images[:, 1, :, :], images[:, 2, :, :]], axis=1\n                )\n            else:\n                
h_channel = images[..., 0] + adjust_factors\n                h_channel = self.backend.numpy.where(\n                    h_channel > 1.0, h_channel - 1.0, h_channel\n                )\n                h_channel = self.backend.numpy.where(\n                    h_channel < 0.0, h_channel + 1.0, h_channel\n                )\n                images = self.backend.numpy.stack(\n                    [h_channel, images[..., 1], images[..., 2]], axis=-1\n                )\n            images = self.backend.image.hsv_to_rgb(\n                images, data_format=self.data_format\n            )\n            images = self.backend.numpy.clip(images, 0, 1)\n            images = self._transform_value_range(\n                images, (0, 1), self.value_range\n            )\n            images = self.backend.cast(images, self.compute_dtype)\n            return images\n\n        if training:\n            images = _apply_random_hue(images, transformation)\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def transform_bounding_boxes(\n        self, bounding_boxes, transformation, training=True\n    ):\n        return bounding_boxes\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"factor\": self.factor,\n                \"value_range\": self.value_range,\n                \"seed\": self.seed,\n            }\n        )\n        return config\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_hue_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomHueTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomHue,\n            init_kwargs={\n                \"factor\": 0.75,\n                \"value_range\": (20, 200),\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_random_hue_inference(self):\n        seed = 3481\n        layer = layers.RandomHue(0.2, [0, 1.0])\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_random_hue_value_range_0_to_1(self):\n        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)\n\n        layer = layers.RandomHue(0.2, (0, 1))\n        adjusted_image = layer(image)\n\n        self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))\n        self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))\n\n    def test_random_hue_value_range_0_to_255(self):\n        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=255)\n\n        layer = layers.RandomHue(0.2, (0, 255))\n        adjusted_image = layer(image)\n\n        self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))\n        self.assertTrue(keras.ops.numpy.all(adjusted_image <= 255))\n\n    def test_random_hue_no_change_with_zero_factor(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = keras.random.randint((224, 224, 3), 0, 255)\n        else:\n            inputs = keras.random.randint((3, 224, 224), 0, 255)\n\n        
layer = layers.RandomHue(0, (0, 255), data_format=data_format)\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)\n\n    def test_random_hue_randomness(self):\n        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5]\n\n        layer = layers.RandomHue(0.2, (0, 255))\n        adjusted_images = layer(image)\n\n        self.assertNotAllClose(adjusted_images, image)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomHue(\n            factor=0.5, value_range=[0, 1], data_format=data_format, seed=1337\n        )\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_invert.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_color_example,\n)\n\n\n@keras_export(\"keras.layers.RandomInvert\")\nclass RandomInvert(BaseImagePreprocessingLayer):\n    \"\"\"Preprocessing layer for random inversion of image colors.\n\n    This layer randomly inverts the colors of input images with a specified\n    probability range. When applied, each image has a chance of having its\n    colors inverted, where the pixel values are transformed to their\n    complementary values. Images that are not selected for inversion\n    remain unchanged.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        factor: A single float or a tuple of two floats.\n            `factor` controls the probability of inverting the image colors.\n            If a tuple is provided, the value is sampled between the two values\n            for each image, where `factor[0]` is the minimum and `factor[1]` is\n            the maximum probability. If a single float is provided, a value\n            between `0.0` and the provided float is sampled.\n            Defaults to `(0, 1)`.\n        value_range: a tuple or a list of two elements. The first value\n            represents the lower bound for values in passed images, the second\n            represents the upper bound. Images passed to the layer should have\n            values within `value_range`. Defaults to `(0, 255)`.\n        seed: Integer. 
Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_color_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_BOUNDS = (0, 1)\n\n    def __init__(\n        self,\n        factor=1.0,\n        value_range=(0, 255),\n        seed=None,\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self._set_factor(factor)\n        self.value_range = value_range\n        self.seed = seed\n        self.generator = self.backend.random.SeedGenerator(seed)\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n\n        seed = seed or self._get_seed_generator(self.backend._backend)\n\n        images_shape = self.backend.shape(images)\n        rank = len(images_shape)\n        if rank == 3:\n            batch_size = 1\n        elif rank == 4:\n            batch_size = images_shape[0]\n        else:\n            raise ValueError(\n                \"Expected the input image to be rank 3 or 4. 
Received \"\n                f\"inputs.shape={images_shape}\"\n            )\n\n        invert_probability = self.backend.random.uniform(\n            shape=(batch_size,),\n            minval=self.factor[0],\n            maxval=self.factor[1],\n            seed=seed,\n        )\n\n        random_threshold = self.backend.random.uniform(\n            shape=(batch_size,),\n            minval=0,\n            maxval=1,\n            seed=seed,\n        )\n\n        apply_inversion = random_threshold < invert_probability\n        return {\"apply_inversion\": apply_inversion}\n\n    def transform_images(self, images, transformation, training=True):\n        if training:\n            images = self.backend.cast(images, self.compute_dtype)\n            apply_inversion = transformation[\"apply_inversion\"]\n            return self.backend.numpy.where(\n                apply_inversion[:, None, None, None],\n                self.value_range[1] - images,\n                images,\n            )\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = {\n            \"factor\": self.factor,\n            \"value_range\": self.value_range,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\nRandomInvert.__doc__ = RandomInvert.__doc__.replace(\n    \"{{base_image_preprocessing_color_example}}\",\n    base_image_preprocessing_color_example.replace(\n        \"{LayerName}\", \"RandomInvert\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_invert_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomInvertTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomInvert,\n            init_kwargs={\n                \"factor\": 0.75,\n                \"value_range\": (20, 200),\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_random_invert_inference(self):\n        seed = 3481\n        layer = layers.RandomInvert()\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_random_invert_no_op(self):\n        seed = 3481\n        layer = layers.RandomInvert(factor=0)\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs)\n        self.assertAllClose(inputs, output)\n\n    def test_random_invert_basic(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((1, 8, 8, 3))\n        else:\n            input_data = np.random.random((1, 3, 8, 8))\n        layer = layers.RandomInvert(\n            factor=(1, 1),\n            value_range=[0, 1],\n            data_format=data_format,\n            seed=1337,\n        )\n        output = layer(input_data)\n        self.assertAllClose(1 - input_data, output)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n         
   input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomInvert(\n            factor=0.5, value_range=[0, 1], data_format=data_format, seed=1337\n        )\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_perspective.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_transform_example,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    clip_to_image_size,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    convert_format,\n)\nfrom keras.src.random.seed_generator import SeedGenerator\nfrom keras.src.utils import backend_utils\n\n\n@keras_export(\"keras.layers.RandomPerspective\")\nclass RandomPerspective(BaseImagePreprocessingLayer):\n    \"\"\"A preprocessing layer that applies random perspective transformations.\n\n    This layer distorts the perspective of input images by shifting their\n    corner points, simulating a 3D-like transformation. 
The amount of distortion\n    is controlled by the `factor` and `scale` parameters.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        factor: A float or a tuple of two floats.\n            Represents the probability of applying the perspective\n            transformation to each image in the batch.\n            - `factor=0.0` ensures no transformation is applied.\n            - `factor=1.0` means the transformation is always applied.\n            - If a tuple `(min, max)` is provided, a probability is randomly\n              sampled between `min` and `max` for each image.\n            - If a single float is given, the probability is sampled between\n              `0.0` and the provided float.\n            Defaults to `1.0`.\n        scale: A float defining the relative amount of perspective shift.\n            Determines how much the image corners are displaced, affecting\n            the intensity of the perspective effect.\n        interpolation: Interpolation mode. Supported values: `\"nearest\"`,\n            `\"bilinear\"`.\n        fill_value: a float representing the value used to fill outside the\n            boundaries when `fill_mode=\"constant\"`.\n        seed: Integer. 
Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_transform_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_BOUNDS = (0, 1)\n    _SUPPORTED_INTERPOLATION = (\"nearest\", \"bilinear\")\n\n    def __init__(\n        self,\n        factor=1.0,\n        scale=1.0,\n        interpolation=\"bilinear\",\n        fill_value=0.0,\n        seed=None,\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self._set_factor(factor)\n        self.scale = scale\n        self.fill_value = fill_value\n        self.interpolation = interpolation\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n        self.supports_jit = False\n\n        if scale < 0.0 or scale > 1.0:\n            raise ValueError(\n                \"The `scale` argument should be a number \"\n                \"in the range \"\n                f\"[0,1]. \"\n                f\"Received: scale={scale}\"\n            )\n\n        if interpolation not in self._SUPPORTED_INTERPOLATION:\n            raise NotImplementedError(\n                f\"Unknown `interpolation` {interpolation}. 
Expected one of \"\n                f\"{self._SUPPORTED_INTERPOLATION}.\"\n            )\n\n        if self.data_format == \"channels_first\":\n            self.height_axis = -2\n            self.width_axis = -1\n            self.channel_axis = -3\n        else:\n            self.height_axis = -3\n            self.width_axis = -2\n            self.channel_axis = -1\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n\n        images_shape = self.backend.shape(images)\n        unbatched = len(images_shape) == 3\n        if unbatched:\n            batch_size = 1\n        else:\n            batch_size = images_shape[0]\n        height, width = (\n            images.shape[self.height_axis],\n            images.shape[self.width_axis],\n        )\n\n        seed = seed or self._get_seed_generator(self.backend._backend)\n\n        transformation_probability = self.backend.random.uniform(\n            shape=(batch_size,),\n            minval=self.factor[0],\n            maxval=self.factor[1],\n            seed=seed,\n        )\n\n        random_threshold = self.backend.random.uniform(\n            shape=(batch_size,),\n            minval=0.0,\n            maxval=1.0,\n            seed=seed,\n        )\n        apply_perspective = random_threshold < transformation_probability\n\n        perspective_factor = self.backend.random.uniform(\n            shape=(batch_size, 4, 2),\n            minval=-0.5 * self.scale,\n            maxval=0.5 * self.scale,\n            seed=seed,\n            dtype=self.compute_dtype,\n        )\n\n        start_points = self.backend.convert_to_tensor(\n            [\n                [\n                    [0.0, 0.0],\n                    [width - 1, 0.0],\n                    [0.0, height - 1],\n                    [width - 1, height - 1],\n               
 ]\n            ],\n            dtype=self.compute_dtype,\n        )\n\n        start_points = self.backend.numpy.repeat(\n            start_points, batch_size, axis=0\n        )\n        end_points = start_points + start_points * perspective_factor\n\n        return {\n            \"apply_perspective\": apply_perspective,\n            \"start_points\": start_points,\n            \"end_points\": end_points,\n            \"input_shape\": images_shape,\n        }\n\n    def _transform_images(self, images, transformation, interpolation):\n        if transformation is None:\n            return images\n\n        inputs_shape = self.backend.shape(images)\n        unbatched = len(inputs_shape) == 3\n        if unbatched:\n            images = self.backend.numpy.expand_dims(images, axis=0)\n\n        start_points = transformation[\"start_points\"]\n        end_points = transformation[\"end_points\"]\n\n        outputs = self.backend.image.perspective_transform(\n            images,\n            start_points,\n            end_points,\n            interpolation=interpolation,\n            fill_value=self.fill_value,\n            data_format=self.data_format,\n        )\n\n        apply_perspective = transformation[\"apply_perspective\"]\n        outputs = self.backend.numpy.where(\n            apply_perspective[:, None, None, None],\n            outputs,\n            images,\n        )\n\n        if unbatched:\n            outputs = self.backend.numpy.squeeze(outputs, axis=0)\n        return outputs\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        if training and transformation is not None:\n            if backend_utils.in_tf_graph():\n                self.backend.set_backend(\"tensorflow\")\n\n            input_height, input_width = (\n                transformation[\"input_shape\"][self.height_axis],\n                transformation[\"input_shape\"][self.width_axis],\n            
)\n\n            bounding_boxes = convert_format(\n                bounding_boxes,\n                source=self.bounding_box_format,\n                target=\"xyxy\",\n                height=input_height,\n                width=input_width,\n            )\n\n            boxes = bounding_boxes[\"boxes\"]\n            x0, y0, x1, y1 = self.backend.numpy.split(boxes, 4, axis=-1)\n\n            start_points = transformation[\"start_points\"]\n            end_points = transformation[\"end_points\"]\n            transform = self.backend.image.compute_homography_matrix(\n                start_points, end_points\n            )\n            transform = self.backend.numpy.expand_dims(transform, axis=1)\n            transform = self.backend.cast(transform, dtype=self.compute_dtype)\n\n            corners = [\n                self._get_transformed_coordinates(x, y, transform)\n                for x, y in [(x0, y0), (x1, y1), (x0, y1), (x1, y0)]\n            ]\n            x_corners, y_corners = zip(*corners)\n\n            xs = self.backend.numpy.stack(x_corners, axis=-1)\n            ys = self.backend.numpy.stack(y_corners, axis=-1)\n\n            min_x, max_x = (\n                self.backend.numpy.min(xs, axis=-1),\n                self.backend.numpy.max(xs, axis=-1),\n            )\n            min_y, max_y = (\n                self.backend.numpy.min(ys, axis=-1),\n                self.backend.numpy.max(ys, axis=-1),\n            )\n\n            min_x = self.backend.numpy.expand_dims(min_x, axis=-1)\n            max_x = self.backend.numpy.expand_dims(max_x, axis=-1)\n            min_y = self.backend.numpy.expand_dims(min_y, axis=-1)\n            max_y = self.backend.numpy.expand_dims(max_y, axis=-1)\n\n            boxes = self.backend.numpy.concatenate(\n                [min_x, min_y, max_x, max_y], axis=-1\n            )\n\n            apply_perspective = self.backend.core.convert_to_tensor(\n                transformation[\"apply_perspective\"], dtype=boxes.dtype\n      
      )\n\n            bounding_boxes[\"boxes\"] = self.backend.numpy.where(\n                apply_perspective[:, None, None],\n                boxes,\n                bounding_boxes[\"boxes\"],\n            )\n\n            bounding_boxes = clip_to_image_size(\n                bounding_boxes=bounding_boxes,\n                height=input_height,\n                width=input_width,\n                bounding_box_format=\"xyxy\",\n            )\n\n            self.backend.reset()\n\n        return bounding_boxes\n\n    def _get_transformed_coordinates(\n        self, x_coords, y_coords, transformation_matrix\n    ):\n        backend = self.backend\n\n        batch_size = backend.shape(transformation_matrix)[0]\n\n        homogeneous_transform = backend.numpy.concatenate(\n            [transformation_matrix, backend.numpy.ones((batch_size, 1, 1))],\n            axis=-1,\n        )\n        homogeneous_transform = backend.numpy.reshape(\n            homogeneous_transform, (batch_size, 3, 3)\n        )\n\n        inverse_transform = backend.linalg.inv(homogeneous_transform)\n\n        ones_column = backend.numpy.ones_like(x_coords)\n        homogeneous_coords = backend.numpy.concatenate(\n            [x_coords, y_coords, ones_column], axis=-1\n        )\n\n        homogeneous_coords = backend.numpy.moveaxis(homogeneous_coords, -1, -2)\n        transformed_coords = backend.numpy.matmul(\n            inverse_transform, homogeneous_coords\n        )\n        transformed_coords = backend.numpy.moveaxis(transformed_coords, -1, -2)\n\n        x_transformed = transformed_coords[..., 0] / transformed_coords[..., 2]\n        y_transformed = transformed_coords[..., 1] / transformed_coords[..., 2]\n\n        return x_transformed, y_transformed\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        base_config = 
super().get_config()\n        config = {\n            \"factor\": self.factor,\n            \"scale\": self.scale,\n            \"interpolation\": self.interpolation,\n            \"fill_value\": self.fill_value,\n            \"seed\": self.seed,\n        }\n        return {**base_config, **config}\n\n\nRandomPerspective.__doc__ = RandomPerspective.__doc__.replace(\n    \"{{base_image_preprocessing_transform_example}}\",\n    base_image_preprocessing_transform_example.replace(\n        \"{LayerName}\", \"RandomPerspective\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_perspective_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomPerspectiveTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomPerspective,\n            init_kwargs={\n                \"factor\": 1.0,\n                \"scale\": 0.5,\n                \"interpolation\": \"bilinear\",\n                \"fill_value\": 0,\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_random_perspective_inference(self):\n        seed = 3481\n        layer = layers.RandomPerspective()\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_random_perspective_no_op(self):\n        seed = 3481\n        layer = layers.RandomPerspective(factor=0)\n\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs)\n        self.assertAllClose(inputs, output)\n\n    def test_random_perspective_basic(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.ones((4, 4, 1))\n            expected_output = np.asarray(\n                [\n                    [[1.0], [1.0], [0.0], [0.0]],\n                    [[1.0], [1.0], [0.0], [0.0]],\n                    [[0.0], [0.0], [0.0], [0.0]],\n                    [[0.0], [0.0], [0.0], [0.0]],\n                ],\n            )\n\n        else:\n            inputs = np.ones((1, 4, 4))\n            expected_output = np.array(\n                [\n        
            [\n                        [1.0, 1.0, 0.0, 0.0],\n                        [1.0, 1.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0],\n                        [0.0, 0.0, 0.0, 0.0],\n                    ]\n                ]\n            )\n\n        layer = layers.RandomPerspective(data_format=data_format)\n\n        transformation = {\n            \"apply_perspective\": np.array([True]),\n            \"start_points\": np.array(\n                [[[0.0, 0.0], [3.0, 0.0], [0.0, 3.0], [3.0, 3.0]]]\n            ),\n            \"end_points\": np.array([[[0.0, 0.0], [1, 0.0], [0.0, 1], [1, 1]]]),\n            \"input_shape\": np.array((4, 4, 1)),\n        }\n        output = layer.transform_images(inputs, transformation)\n\n        self.assertAllClose(expected_output, output, atol=1e-4, rtol=1e-4)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomPerspective(data_format=data_format)\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n\n    @parameterized.named_parameters(\n        (\n            \"with_large_scale\",\n            [\n                [\n                    [0.0, 0.0],\n                    [8.151311, 0.0],\n                    [0.0, 12.695701],\n                    [9.2712054, 10.524198],\n                ]\n            ],\n            [\n                [\n                    [2.6490488, 1.1149256, 5.2026834, 3.6187303],\n                    [7.5547166, 4.2492595, 8.0, 6.869391],\n                ]\n            ],\n        ),\n        (\n            \"with_small_scale\",\n            [\n                [\n                    [0.0, 0.0],\n                    [4.151311, 0.0],\n              
      [0.0, 6.695701],\n                    [4.2712054, 7.524198],\n                ]\n            ],\n            [\n                [\n                    [1.095408, 0.7504317, 2.2761598, 2.3389952],\n                    [3.5416048, 3.2349987, 4.920989, 5.0568376],\n                ]\n            ],\n        ),\n    )\n    def test_random_perspective_bounding_boxes(\n        self, end_points, expected_boxes\n    ):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            image_shape = (10, 8, 3)\n        else:\n            image_shape = (3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [\n                        [2, 1, 4, 3],\n                        [6, 4, 8, 6],\n                    ]\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n        layer = layers.RandomPerspective(\n            data_format=data_format,\n            seed=42,\n            bounding_box_format=\"xyxy\",\n        )\n\n        transformation = {\n            \"apply_perspective\": np.array([True]),\n            \"end_points\": np.array(end_points),\n            \"input_shape\": np.array(image_shape),\n            \"start_points\": np.array(\n                [[[0.0, 0.0], [7.0, 0.0], [0.0, 9.0], [7.0, 9.0]]]\n            ),\n        }\n\n        output = layer.transform_bounding_boxes(\n            input_data[\"bounding_boxes\"],\n            transformation,\n        )\n\n        self.assertAllClose(\n            output[\"boxes\"],\n            expected_boxes,\n            atol=1e-3,\n            rtol=1e-3,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n    @parameterized.named_parameters(\n        (\n            \"with_large_scale\",\n            [\n                [\n  
                  [0.0, 0.0],\n                    [8.151311, 0.0],\n                    [0.0, 12.695701],\n                    [9.2712054, 10.524198],\n                ]\n            ],\n            [\n                [\n                    [2.6490488, 1.1149256, 5.2026834, 3.6187303],\n                    [7.5547166, 4.2492595, 8.0, 6.869391],\n                ]\n            ],\n        ),\n        (\n            \"with_small_scale\",\n            [\n                [\n                    [0.0, 0.0],\n                    [4.151311, 0.0],\n                    [0.0, 6.695701],\n                    [4.2712054, 7.524198],\n                ]\n            ],\n            [\n                [\n                    [1.095408, 0.7504317, 2.2761598, 2.3389952],\n                    [3.5416048, 3.2349987, 4.920989, 5.0568376],\n                ]\n            ],\n        ),\n    )\n    def test_random_perspective_tf_data_bounding_boxes(\n        self, end_points, expected_boxes\n    ):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            image_shape = (1, 10, 8, 3)\n        else:\n            image_shape = (1, 3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [\n                        [\n                            [2, 1, 4, 3],\n                            [6, 4, 8, 6],\n                        ]\n                    ]\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data)\n        layer = layers.RandomPerspective(\n            data_format=data_format,\n            seed=42,\n            bounding_box_format=\"xyxy\",\n        )\n\n        transformation = {\n            \"apply_perspective\": 
np.array([True]),\n            \"end_points\": np.array(end_points),\n            \"input_shape\": np.array(image_shape),\n            \"start_points\": np.array(\n                [[[0.0, 0.0], [7.0, 0.0], [0.0, 9.0], [7.0, 9.0]]]\n            ),\n        }\n\n        ds = ds.map(\n            lambda x: layer.transform_bounding_boxes(\n                x[\"bounding_boxes\"],\n                transformation=transformation,\n                training=True,\n            )\n        )\n\n        output = next(iter(ds))\n        expected_boxes = np.array(expected_boxes)\n        self.assertAllClose(\n            output[\"boxes\"], expected_boxes, atol=1e-3, rtol=1e-3\n        )\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_posterization.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_color_example,\n)\n\n\n@keras_export(\"keras.layers.RandomPosterization\")\nclass RandomPosterization(BaseImagePreprocessingLayer):\n    \"\"\"Reduces the number of bits for each color channel.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    References:\n    - [AutoAugment: Learning Augmentation Policies from Data](https://arxiv.org/abs/1805.09501)\n    - [RandAugment: Practical automated data augmentation with a reduced search space](https://arxiv.org/abs/1909.13719)\n\n    Args:\n        value_range: a tuple or a list of two elements. The first value\n            represents the lower bound for values in passed images, the second\n            represents the upper bound. Images passed to the layer should have\n            values within `value_range`. Defaults to `(0, 255)`.\n        factor: integer, the number of bits to keep for each channel. Must be a\n            value between 1 and 8.\n\n    Example:\n\n    {{base_image_preprocessing_color_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_BOUNDS = (1, 8)\n    _MAX_FACTOR = 8\n    _VALUE_RANGE_VALIDATION_ERROR = (\n        \"The `value_range` argument should be a list of two numbers. 
\"\n    )\n\n    def __init__(\n        self,\n        factor,\n        value_range=(0, 255),\n        data_format=None,\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self._set_factor(factor)\n        self._set_value_range(value_range)\n        self.seed = seed\n        self.generator = self.backend.random.SeedGenerator(seed)\n\n    def _set_value_range(self, value_range):\n        if not isinstance(value_range, (tuple, list)):\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        if len(value_range) != 2:\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        self.value_range = sorted(value_range)\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n        images_shape = self.backend.shape(images)\n        rank = len(images_shape)\n        if rank == 3:\n            batch_size = 1\n        elif rank == 4:\n            batch_size = images_shape[0]\n        else:\n            raise ValueError(\n                \"Expected the input image to be rank 3 or 4. 
Received: \"\n                f\"inputs.shape={images_shape}\"\n            )\n\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n\n        if self.factor[0] != self.factor[1]:\n            factor = self.backend.random.randint(\n                (batch_size,),\n                minval=self.factor[0],\n                maxval=self.factor[1],\n                seed=seed,\n                dtype=\"uint8\",\n            )\n        else:\n            factor = (\n                self.backend.numpy.ones((batch_size,), dtype=\"uint8\")\n                * self.factor[0]\n            )\n\n        shift_factor = self._MAX_FACTOR - factor\n        return {\"shift_factor\": shift_factor}\n\n    def transform_images(self, images, transformation=None, training=True):\n        if training:\n            shift_factor = transformation[\"shift_factor\"]\n\n            shift_factor = self.backend.numpy.reshape(\n                shift_factor, self.backend.shape(shift_factor) + (1, 1, 1)\n            )\n\n            images = self._transform_value_range(\n                images,\n                original_range=self.value_range,\n                target_range=(0, 255),\n                dtype=self.compute_dtype,\n            )\n\n            images = self.backend.cast(images, \"uint8\")\n            images = self.backend.numpy.bitwise_left_shift(\n                self.backend.numpy.bitwise_right_shift(images, shift_factor),\n                shift_factor,\n            )\n            images = self.backend.cast(images, self.compute_dtype)\n\n            images = self._transform_value_range(\n                images,\n                original_range=(0, 255),\n                target_range=self.value_range,\n                dtype=self.compute_dtype,\n            )\n\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, 
transformation, training=True\n    ):\n        return segmentation_masks\n\n    def transform_bounding_boxes(\n        self, bounding_boxes, transformation, training=True\n    ):\n        return bounding_boxes\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"factor\": self.factor,\n                \"value_range\": self.value_range,\n                \"seed\": self.seed,\n            }\n        )\n        return config\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n\nRandomPosterization.__doc__ = RandomPosterization.__doc__.replace(\n    \"{{base_image_preprocessing_color_example}}\",\n    base_image_preprocessing_color_example.replace(\n        \"{LayerName}\", \"RandomPosterization\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_posterization_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomPosterizationTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomPosterization,\n            init_kwargs={\n                \"factor\": 1,\n                \"value_range\": (20, 200),\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_random_posterization_inference(self):\n        seed = 3481\n        layer = layers.RandomPosterization(1, [0, 255])\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_random_posterization_basic(self):\n        seed = 3481\n        layer = layers.RandomPosterization(\n            1, [0, 255], data_format=\"channels_last\", seed=seed\n        )\n        np.random.seed(seed)\n        inputs = np.asarray(\n            [[[128.0, 235.0, 87.0], [12.0, 1.0, 23.0], [24.0, 18.0, 121.0]]]\n        )\n        output = layer(inputs)\n        expected_output = np.asarray(\n            [[[128.0, 128.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]\n        )\n        self.assertAllClose(expected_output, output)\n\n    def test_random_posterization_value_range_0_to_1(self):\n        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)\n\n        layer = layers.RandomPosterization(1, [0, 1.0])\n        adjusted_image = layer(image)\n\n        self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))\n        self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))\n\n    def test_random_posterization_value_range_0_to_255(self):\n        
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=255)\n\n        layer = layers.RandomPosterization(1, [0, 255])\n        adjusted_image = layer(image)\n\n        self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))\n        self.assertTrue(keras.ops.numpy.all(adjusted_image <= 255))\n\n    def test_random_posterization_randomness(self):\n        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)\n\n        layer = layers.RandomPosterization(1, [0, 255])\n        adjusted_images = layer(image)\n\n        self.assertNotAllClose(adjusted_images, image)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomPosterization(1, [0, 255])\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_rotation.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_transform_example,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes import (\n    converters,\n)\nfrom keras.src.random.seed_generator import SeedGenerator\n\n\n@keras_export(\"keras.layers.RandomRotation\")\nclass RandomRotation(BaseImagePreprocessingLayer):\n    \"\"\"A preprocessing layer which randomly rotates images during training.\n\n    This layer will apply random rotations to each image, filling empty space\n    according to `fill_mode`.\n\n    By default, random rotations are only applied during training.\n    At inference time, the layer does nothing. If you need to apply random\n    rotations at inference time, pass `training=True` when calling the layer.\n\n    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and\n    of integer or floating point dtype.\n    By default, the layer will output floats.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Input shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format\n\n    Output shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format\n\n    Args:\n        factor: a float represented as fraction of 2 Pi, or a tuple of size 2\n            representing lower and upper bound for rotating clockwise and\n            counter-clockwise. 
A positive value means rotating\n            counter-clockwise,\n            while a negative value means clockwise.\n            When represented as a single\n            float, this value is used for both the upper and lower bound.\n            For instance, `factor=(-0.2, 0.3)`\n            results in an output rotation by a random\n            amount in the range `[-20% * 360, 30% * 360]`.\n            `factor=0.2` results in an\n            output rotating by a random amount\n            in the range `[-20% * 360, 20% * 360]`.\n        fill_mode: Points outside the boundaries of the input are filled\n            according to the given mode\n            (one of `{\"constant\", \"reflect\", \"wrap\", \"nearest\"}`).\n            - *reflect*: `(d c b a | a b c d | d c b a)`\n                The input is extended by reflecting about\n                the edge of the last pixel.\n            - *constant*: `(k k k k | a b c d | k k k k)`\n                The input is extended by\n                filling all values beyond the edge with\n                the same constant value k = 0.\n            - *wrap*: `(a b c d | a b c d | a b c d)` The input is extended by\n                wrapping around to the opposite edge.\n            - *nearest*: `(a a a a | a b c d | d d d d)`\n                The input is extended by the nearest pixel.\n        interpolation: Interpolation mode. Supported values: `\"nearest\"`,\n            `\"bilinear\"`.\n        seed: Integer. Used to create a random seed.\n        fill_value: a float representing the value to be filled outside\n            the boundaries when `fill_mode=\"constant\"`.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. 
`\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`. It defaults to the\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. If you never set it, then it will be\n            `\"channels_last\"`.\n\n    Example:\n\n    {{base_image_preprocessing_transform_example}}\n    \"\"\"\n\n    _SUPPORTED_FILL_MODE = (\"reflect\", \"wrap\", \"constant\", \"nearest\")\n    _SUPPORTED_INTERPOLATION = (\"nearest\", \"bilinear\")\n\n    def __init__(\n        self,\n        factor,\n        fill_mode=\"reflect\",\n        interpolation=\"bilinear\",\n        seed=None,\n        fill_value=0.0,\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(factor=factor, data_format=data_format, **kwargs)\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n        self.fill_mode = fill_mode\n        self.interpolation = interpolation\n        self.fill_value = fill_value\n        self.supports_jit = False\n\n        if self.fill_mode not in self._SUPPORTED_FILL_MODE:\n            raise NotImplementedError(\n                f\"Unknown `fill_mode` {fill_mode}. Expected one of \"\n                f\"{self._SUPPORTED_FILL_MODE}.\"\n            )\n        if self.interpolation not in self._SUPPORTED_INTERPOLATION:\n            raise NotImplementedError(\n                f\"Unknown `interpolation` {interpolation}. 
Expected one of \"\n                f\"{self._SUPPORTED_INTERPOLATION}.\"\n            )\n\n    def _transform_images(self, images, transformation, interpolation):\n        return self.backend.image.affine_transform(\n            images=images,\n            transform=transformation[\"rotation_matrix\"],\n            interpolation=interpolation,\n            fill_mode=self.fill_mode,\n            fill_value=self.fill_value,\n            data_format=self.data_format,\n        )\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        if training:\n            ops = self.backend\n            boxes = bounding_boxes[\"boxes\"]\n            height = transformation[\"image_height\"]\n            width = transformation[\"image_width\"]\n            batch_size = transformation[\"batch_size\"]\n            boxes = converters.affine_transform(\n                boxes=boxes,\n                angle=transformation[\"angle\"],\n                translate_x=ops.numpy.zeros([batch_size]),\n                translate_y=ops.numpy.zeros([batch_size]),\n                scale=ops.numpy.ones([batch_size]),\n                shear_x=ops.numpy.zeros([batch_size]),\n                shear_y=ops.numpy.zeros([batch_size]),\n                height=height,\n                width=width,\n            )\n\n            bounding_boxes[\"boxes\"] = boxes\n            bounding_boxes = converters.clip_to_image_size(\n                bounding_boxes,\n                height=height,\n                width=width,\n                bounding_box_format=\"xyxy\",\n            )\n            bounding_boxes = converters.convert_format(\n                bounding_boxes,\n                source=\"xyxy\",\n                target=self.bounding_box_format,\n                height=height,\n                width=width,\n            )\n        
return bounding_boxes\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        ops = self.backend\n        if not training:\n            return None\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n        shape = ops.core.shape(images)\n        if len(shape) == 4:\n            batch_size = shape[0]\n            if self.data_format == \"channels_last\":\n                image_height = shape[1]\n                image_width = shape[2]\n            else:\n                image_height = shape[2]\n                image_width = shape[3]\n        else:\n            batch_size = 1\n            if self.data_format == \"channels_last\":\n                image_height = shape[0]\n                image_width = shape[1]\n            else:\n                image_height = shape[1]\n                image_width = shape[2]\n\n        if seed is None:\n            seed = self._get_seed_generator(ops._backend)\n        lower = self.factor[0] * 360.0\n        upper = self.factor[1] * 360.0\n        angle = ops.random.uniform(\n            shape=(batch_size,),\n            minval=lower,\n            maxval=upper,\n            seed=seed,\n        )\n        center_x, center_y = 0.5, 0.5\n        rotation_matrix = self._compute_affine_matrix(\n            center_x=center_x,\n            center_y=center_y,\n            angle=angle,\n            translate_x=ops.numpy.zeros([batch_size]),\n            translate_y=ops.numpy.zeros([batch_size]),\n            scale=ops.numpy.ones([batch_size]),\n            shear_x=ops.numpy.zeros([batch_size]),\n            shear_y=ops.numpy.zeros([batch_size]),\n            height=image_height,\n            width=image_width,\n        )\n        if len(shape) == 3:\n            rotation_matrix = self.backend.numpy.squeeze(\n                rotation_matrix, axis=0\n            )\n        return {\n            \"angle\": angle,\n            \"rotation_matrix\": 
rotation_matrix,\n            \"image_height\": image_height,\n            \"image_width\": image_width,\n            \"batch_size\": batch_size,\n        }\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = {\n            \"factor\": self.factor,\n            \"data_format\": self.data_format,\n            \"fill_mode\": self.fill_mode,\n            \"fill_value\": self.fill_value,\n            \"interpolation\": self.interpolation,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\nRandomRotation.__doc__ = RandomRotation.__doc__.replace(\n    \"{{base_image_preprocessing_transform_example}}\",\n    base_image_preprocessing_transform_example.replace(\n        \"{LayerName}\", \"RandomRotation\"\n    ),\n)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_rotation_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomRotationTest(testing.TestCase):\n    @parameterized.named_parameters(\n        (\"random_rotate_neg4\", -0.4),\n        (\"random_rotate_neg2\", -0.2),\n        (\"random_rotate_4\", 0.4),\n        (\"random_rotate_2\", 0.2),\n        (\"random_rotate_tuple\", (-0.2, 0.4)),\n    )\n    def test_random_rotation_shapes(self, factor):\n        self.run_layer_test(\n            layers.RandomRotation,\n            init_kwargs={\n                \"factor\": factor,\n            },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 4),\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    def test_random_rotation_correctness(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (1, 5, 5, 1)\n        else:\n            input_shape = (1, 1, 5, 5)\n        input_image = np.reshape(np.arange(0, 25), input_shape)\n        layer = layers.RandomRotation(factor=(0.5, 0.5))\n        actual_output = layer(input_image)\n        expected_output = np.asarray(\n            [\n                [24, 23, 22, 21, 20],\n                [19, 18, 17, 16, 15],\n                [14, 13, 12, 11, 10],\n                [9, 8, 7, 6, 5],\n                [4, 3, 2, 1, 0],\n            ]\n        ).reshape(input_shape)\n\n        self.assertAllClose(\n            backend.convert_to_tensor(expected_output), actual_output, atol=1e-5\n        )\n\n    def test_training_false(self):\n        input_image = np.reshape(np.arange(0, 25), (1, 5, 5, 1))\n        layer = layers.RandomRotation(factor=(0.5, 0.5))\n        actual_output = layer(input_image, training=False)\n        self.assertAllClose(actual_output, input_image)\n\n    def test_tf_data_compatibility(self):\n       
 if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (1, 5, 5, 1)\n        else:\n            input_shape = (1, 1, 5, 5)\n        input_image = np.reshape(np.arange(0, 25), input_shape)\n        layer = layers.RandomRotation(factor=(0.5, 0.5))\n\n        ds = tf_data.Dataset.from_tensor_slices(input_image).map(layer)\n        expected_output = np.asarray(\n            [\n                [24, 23, 22, 21, 20],\n                [19, 18, 17, 16, 15],\n                [14, 13, 12, 11, 10],\n                [9, 8, 7, 6, 5],\n                [4, 3, 2, 1, 0],\n            ]\n        ).reshape(input_shape[1:])\n        output = next(iter(ds)).numpy()\n        self.assertAllClose(expected_output, output)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_saturation.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.backend import epsilon\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.random import SeedGenerator\n\n\n@keras_export(\"keras.layers.RandomSaturation\")\nclass RandomSaturation(BaseImagePreprocessingLayer):\n    \"\"\"Randomly adjusts the saturation on given images.\n\n    This layer will randomly increase/reduce the saturation for the input RGB\n    images.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        factor: A tuple of two floats or a single float.\n            `factor` controls the extent to which the image saturation\n            is impacted. `factor=0.5` makes this layer perform a no-op\n            operation. `factor=0.0` makes the image fully grayscale.\n            `factor=1.0` makes the image fully saturated. Values should\n            be between `0.0` and `1.0`. If a tuple is used, a `factor`\n            is sampled between the two values for every image augmented.\n            If a single float is used, a value between `0.0` and the passed\n            float is sampled. To ensure the value is always the same,\n            pass a tuple with two identical floats: `(0.5, 0.5)`.\n        value_range: the range of values the incoming images will have.\n            Represented as a two-number tuple written `[low, high]`. This is\n            typically either `[0, 1]` or `[0, 255]` depending on how your\n            preprocessing pipeline is set up.\n        seed: Integer. 
Used to create a random seed.\n\n    Example:\n    ```python\n    (images, labels), _ = keras.datasets.cifar10.load_data()\n    images = images.astype(\"float32\")\n    random_saturation = keras.layers.RandomSaturation(factor=0.2)\n    augmented_images = random_saturation(images)\n    ```\n    \"\"\"\n\n    _VALUE_RANGE_VALIDATION_ERROR = (\n        \"The `value_range` argument should be a list of two numbers. \"\n    )\n\n    def __init__(\n        self,\n        factor,\n        value_range=(0, 255),\n        data_format=None,\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self._set_factor(factor)\n        self._set_value_range(value_range)\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n    def _set_value_range(self, value_range):\n        if not isinstance(value_range, (tuple, list)):\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        if len(value_range) != 2:\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        self.value_range = sorted(value_range)\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n        images_shape = self.backend.shape(images)\n        rank = len(images_shape)\n        if rank == 3:\n            batch_size = 1\n        elif rank == 4:\n            batch_size = images_shape[0]\n        else:\n            raise ValueError(\n                \"Expected the input image to be rank 3 or 4. 
Received: \"\n                f\"inputs.shape={images_shape}\"\n            )\n\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n\n        factor = self.backend.random.uniform(\n            (batch_size,),\n            minval=self.factor[0],\n            maxval=self.factor[1],\n            seed=seed,\n        )\n        factor = factor / (1 - factor + epsilon())\n        return {\"factor\": factor}\n\n    def transform_images(self, images, transformation=None, training=True):\n        if training:\n            adjust_factors = transformation[\"factor\"]\n            adjust_factors = self.backend.cast(\n                adjust_factors, self.compute_dtype\n            )\n            adjust_factors = self.backend.numpy.reshape(\n                adjust_factors, self.backend.shape(adjust_factors) + (1, 1)\n            )\n            images = self.backend.image.rgb_to_hsv(\n                images, data_format=self.data_format\n            )\n            if self.data_format == \"channels_first\":\n                s_channel = self.backend.numpy.multiply(\n                    images[:, 1, :, :], adjust_factors\n                )\n                s_channel = self.backend.numpy.clip(\n                    s_channel, self.value_range[0], self.value_range[1]\n                )\n                images = self.backend.numpy.stack(\n                    [images[:, 0, :, :], s_channel, images[:, 2, :, :]], axis=1\n                )\n            else:\n                s_channel = self.backend.numpy.multiply(\n                    images[..., 1], adjust_factors\n                )\n                s_channel = self.backend.numpy.clip(\n                    s_channel, self.value_range[0], self.value_range[1]\n                )\n                images = self.backend.numpy.stack(\n                    [images[..., 0], s_channel, images[..., 2]], axis=-1\n                )\n            images = self.backend.image.hsv_to_rgb(\n                images, 
data_format=self.data_format\n            )\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def transform_bounding_boxes(\n        self, bounding_boxes, transformation, training=True\n    ):\n        return bounding_boxes\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"factor\": self.factor,\n                \"value_range\": self.value_range,\n                \"seed\": self.seed,\n            }\n        )\n        return config\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n"
  },
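`RandomSaturation` remaps the sampled factor `f` into the multiplier `f / (1 - f + epsilon)` (so `f=0.0` gives grayscale, `f=0.5` is roughly a no-op, and `f -> 1.0` gives a very large multiplier) and scales the HSV saturation channel by it. A minimal per-pixel sketch of the same transform, using the stdlib `colorsys` as a stand-in for the backend `rgb_to_hsv`/`hsv_to_rgb` ops and assuming pixels in `[0, 1]` (the layer's default `value_range` is `(0, 255)`):

```python
import colorsys

def saturation_multiplier(raw_factor, eps=1e-7):
    # The layer's remapping: f / (1 - f + eps).
    # f=0.0 -> 0 (grayscale), f=0.5 -> ~1.0 (no-op),
    # f -> 1.0 -> huge multiplier (fully saturated after clipping).
    return raw_factor / (1.0 - raw_factor + eps)

def adjust_pixel_saturation(r, g, b, raw_factor):
    # Per-pixel version of transform_images: RGB -> HSV, scale and clip
    # the S channel, HSV -> RGB. Floats in [0, 1] assumed throughout.
    h, s, v = colorsys.rgb_to_hsv(r, g, b)
    s = min(max(s * saturation_multiplier(raw_factor), 0.0), 1.0)
    return colorsys.hsv_to_rgb(h, s, v)

# factor 0.0 collapses the pixel to gray: all channels equal V.
gray = adjust_pixel_saturation(0.8, 0.2, 0.4, 0.0)
# factor 0.5 is approximately a no-op.
noop = adjust_pixel_saturation(0.8, 0.2, 0.4, 0.5)
```

The layer itself vectorizes this over a batch and handles both `channels_first` and `channels_last`; the sketch only illustrates the factor remapping and the HSV round-trip.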
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_saturation_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomSaturationTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomSaturation,\n            init_kwargs={\n                \"factor\": 0.75,\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_random_saturation_value_range(self):\n        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)\n\n        layer = layers.RandomSaturation(0.2)\n        adjusted_image = layer(image)\n\n        self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))\n        self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))\n\n    def test_random_saturation_no_op(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.random.random((2, 8, 8, 3))\n        else:\n            inputs = np.random.random((2, 3, 8, 8))\n\n        layer = layers.RandomSaturation((0.5, 0.5))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)\n\n    def test_random_saturation_full_grayscale(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.random.random((2, 8, 8, 3))\n        else:\n            inputs = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomSaturation(factor=(0.0, 0.0))\n        result = layer(inputs)\n\n        if data_format == \"channels_last\":\n            self.assertAllClose(result[..., 0], result[..., 1])\n            self.assertAllClose(result[..., 1], result[..., 2])\n        else:\n       
     self.assertAllClose(result[:, 0, :, :], result[:, 1, :, :])\n            self.assertAllClose(result[:, 1, :, :], result[:, 2, :, :])\n\n    def test_random_saturation_full_saturation(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.random.random((2, 8, 8, 3))\n        else:\n            inputs = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomSaturation(factor=(1.0, 1.0))\n        result = layer(inputs)\n\n        hsv = backend.image.rgb_to_hsv(result)\n        s_channel = hsv[..., 1]\n\n        self.assertAllClose(\n            keras.ops.numpy.max(s_channel), layer.value_range[1]\n        )\n\n    def test_random_saturation_randomness(self):\n        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5]\n\n        layer = layers.RandomSaturation(0.2)\n        adjusted_images = layer(image)\n\n        self.assertNotAllClose(adjusted_images, image)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomSaturation(\n            factor=0.5, data_format=data_format, seed=1337\n        )\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_sharpness.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_color_example,\n)\nfrom keras.src.random import SeedGenerator\n\n\n@keras_export(\"keras.layers.RandomSharpness\")\nclass RandomSharpness(BaseImagePreprocessingLayer):\n    \"\"\"Randomly performs the sharpness operation on given images.\n\n    The sharpness operation first performs a blur, then blends between the\n    original image and the processed image. This operation adjusts the clarity\n    of the edges in an image, ranging from blurred to enhanced sharpness.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        factor: A tuple of two floats or a single float.\n            `factor` controls the extent to which the image sharpness\n            is impacted. `factor=0.0` results in a fully blurred image,\n            `factor=0.5` applies no operation (preserving the original image),\n            and `factor=1.0` enhances the sharpness beyond the original. Values\n            should be between `0.0` and `1.0`. If a tuple is used, a `factor`\n            is sampled between the two values for every image augmented.\n            If a single float is used, a value between `0.0` and the passed\n            float is sampled. To ensure the value is always the same,\n            pass a tuple with two identical floats: `(0.5, 0.5)`.\n        value_range: the range of values the incoming images will have.\n            Represented as a two-number tuple written `[low, high]`. This is\n            typically either `[0, 1]` or `[0, 255]` depending on how your\n            preprocessing pipeline is set up.\n        seed: Integer. 
Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_color_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_BOUNDS = (0, 1)\n\n    _VALUE_RANGE_VALIDATION_ERROR = (\n        \"The `value_range` argument should be a list of two numbers. \"\n    )\n\n    def __init__(\n        self,\n        factor,\n        value_range=(0, 255),\n        data_format=None,\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self._set_factor(factor)\n        self._set_value_range(value_range)\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n\n    def _set_value_range(self, value_range):\n        if not isinstance(value_range, (tuple, list)):\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        if len(value_range) != 2:\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        self.value_range = sorted(value_range)\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n        images_shape = self.backend.shape(images)\n        rank = len(images_shape)\n        if rank == 3:\n            batch_size = 1\n        elif rank == 4:\n            batch_size = images_shape[0]\n        else:\n            raise ValueError(\n                \"Expected the input image to be rank 3 or 4. 
Received: \"\n                f\"inputs.shape={images_shape}\"\n            )\n\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n\n        factor = self.backend.random.uniform(\n            (batch_size,),\n            minval=self.factor[0],\n            maxval=self.factor[1],\n            seed=seed,\n        )\n        return {\"factor\": factor}\n\n    def transform_images(self, images, transformation=None, training=True):\n        images = self.backend.cast(images, self.compute_dtype)\n        if training:\n            if self.data_format == \"channels_first\":\n                images = self.backend.numpy.swapaxes(images, -3, -1)\n\n            sharpness_factor = self.backend.cast(\n                transformation[\"factor\"] * 2, dtype=self.compute_dtype\n            )\n            sharpness_factor = self.backend.numpy.reshape(\n                sharpness_factor, (-1, 1, 1, 1)\n            )\n\n            num_channels = self.backend.shape(images)[-1]\n\n            a, b = 1.0 / 13.0, 5.0 / 13.0\n            kernel = self.backend.convert_to_tensor(\n                [[a, a, a], [a, b, a], [a, a, a]], dtype=self.compute_dtype\n            )\n            kernel = self.backend.numpy.reshape(kernel, (3, 3, 1, 1))\n            kernel = self.backend.numpy.tile(kernel, [1, 1, num_channels, 1])\n            kernel = self.backend.cast(kernel, self.compute_dtype)\n\n            smoothed_image = self.backend.nn.depthwise_conv(\n                images,\n                kernel,\n                strides=1,\n                padding=\"same\",\n                data_format=\"channels_last\",\n            )\n\n            smoothed_image = self.backend.cast(\n                smoothed_image, dtype=self.compute_dtype\n            )\n            images = images + (1.0 - sharpness_factor) * (\n                smoothed_image - images\n            )\n\n            images = self.backend.numpy.clip(\n                images, self.value_range[0], 
self.value_range[1]\n            )\n\n            if self.data_format == \"channels_first\":\n                images = self.backend.numpy.swapaxes(images, -3, -1)\n\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def transform_bounding_boxes(\n        self, bounding_boxes, transformation, training=True\n    ):\n        return bounding_boxes\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"factor\": self.factor,\n                \"value_range\": self.value_range,\n                \"seed\": self.seed,\n            }\n        )\n        return config\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n\nRandomSharpness.__doc__ = RandomSharpness.__doc__.replace(\n    \"{{base_image_preprocessing_color_example}}\",\n    base_image_preprocessing_color_example.replace(\n        \"{LayerName}\", \"RandomSharpness\"\n    ),\n)\n"
  },
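`RandomSharpness` blurs with a fixed 3x3 kernel (`[[1,1,1],[1,5,1],[1,1,1]] / 13`) and blends: the sampled factor `f` is doubled, and the output is `img + (1 - 2f) * (blur - img)`, so `f=0` is a full blur, `f=0.5` a no-op, and `f=1` a sharpening (`2*img - blur`), followed by a clip to `value_range`. A minimal single-channel NumPy sketch, assuming float pixels in `[0, 1]` and edge padding as a stand-in for the backend's `padding="same"` border handling:

```python
import numpy as np

# The layer's smoothing kernel: a=1/13 on the ring, b=5/13 in the center.
KERNEL = np.array([[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=np.float64) / 13.0

def blur3x3(img):
    # "same"-size convolution via explicit edge padding (border handling
    # is an assumption here; the real layer delegates to depthwise_conv).
    p = np.pad(img, 1, mode="edge")
    out = np.zeros_like(img)
    h, w = img.shape
    for dy in range(3):
        for dx in range(3):
            out += KERNEL[dy, dx] * p[dy:dy + h, dx:dx + w]
    return out

def random_sharpness(img, factor):
    # Blend between the blurred and original image, then clip.
    blurred = blur3x3(img)
    out = img + (1.0 - 2.0 * factor) * (blurred - img)
    return np.clip(out, 0.0, 1.0)

img = np.random.default_rng(0).random((8, 8))
assert np.allclose(random_sharpness(img, 0.5), img)            # no-op
assert np.allclose(random_sharpness(img, 0.0), blur3x3(img))   # full blur
```

Since the kernel weights sum to 1, the blurred image stays inside the input's value range, which is why `f=0.5` survives the final clip unchanged.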
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_sharpness_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RandomSharpnessTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomSharpness,\n            init_kwargs={\n                \"factor\": 0.75,\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_random_sharpness_value_range(self):\n        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)\n\n        layer = layers.RandomSharpness(0.2)\n        adjusted_image = layer(image)\n\n        self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))\n        self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))\n\n    def test_random_sharpness_no_op(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            inputs = np.random.random((2, 8, 8, 3))\n        else:\n            inputs = np.random.random((2, 3, 8, 8))\n\n        layer = layers.RandomSharpness((0.5, 0.5))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)\n\n    def test_random_sharpness_randomness(self):\n        image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5]\n\n        layer = layers.RandomSharpness(0.2)\n        adjusted_images = layer(image)\n\n        self.assertNotAllClose(adjusted_images, image)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = 
layers.RandomSharpness(\n            factor=0.5, data_format=data_format, seed=1337\n        )\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_shear.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_transform_example,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    clip_to_image_size,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    convert_format,\n)\nfrom keras.src.random.seed_generator import SeedGenerator\nfrom keras.src.utils import backend_utils\n\n\n@keras_export(\"keras.layers.RandomShear\")\nclass RandomShear(BaseImagePreprocessingLayer):\n    \"\"\"A preprocessing layer that randomly applies shear transformations to\n    images.\n\n    This layer shears the input images along the x-axis and/or y-axis by a\n    randomly selected factor within the specified range. The shear\n    transformation is applied to each image independently in a batch. Empty\n    regions created during the transformation are filled according to the\n    `fill_mode` and `fill_value` parameters.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        x_factor: A tuple of two floats. For each augmented image, a value\n            is sampled from the provided range. If a float is passed, the\n            range is interpreted as `(0, x_factor)`. Values represent a\n            percentage of the image to shear over. For example, 0.3 shears\n            pixels up to 30% of the way across the image. All provided values\n            should be positive.\n        y_factor: A tuple of two floats. For each augmented image, a value\n            is sampled from the provided range. 
If a float is passed, the\n            range is interpreted as `(0, y_factor)`. Values represent a\n            percentage of the image to shear over. For example, 0.3 shears\n            pixels up to 30% of the way across the image. All provided values\n            should be positive.\n        interpolation: Interpolation mode. Supported values: `\"nearest\"`,\n            `\"bilinear\"`.\n        fill_mode: Points outside the boundaries of the input are filled\n            according to the given mode. Available methods are `\"constant\"`,\n            `\"nearest\"`, `\"wrap\"` and `\"reflect\"`. Defaults to `\"constant\"`.\n            - `\"reflect\"`: `(d c b a | a b c d | d c b a)`\n                The input is extended by reflecting about the edge of the\n                last pixel.\n            - `\"constant\"`: `(k k k k | a b c d | k k k k)`\n                The input is extended by filling all values beyond the edge\n                with the same constant value `k` specified by `fill_value`.\n            - `\"wrap\"`: `(a b c d | a b c d | a b c d)`\n                The input is extended by wrapping around to the opposite edge.\n            - `\"nearest\"`: `(a a a a | a b c d | d d d d)`\n                The input is extended by the nearest pixel.\n            Note that when using torch backend, `\"reflect\"` is redirected to\n            `\"mirror\"` `(c d c b | a b c d | c b a b)` because torch does\n            not support `\"reflect\"`.\n            Note that torch backend does not support `\"wrap\"`.\n        fill_value: A float representing the value to be filled outside the\n            boundaries when `fill_mode=\"constant\"`.\n        seed: Integer. 
Used to create a random seed.\n\n    Example:\n\n    {{base_image_preprocessing_transform_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_BOUNDS = (0, 1)\n    _FACTOR_VALIDATION_ERROR = (\n        \"The `factor` argument should be a number (or a list of two numbers) \"\n        \"in the range [0, 1.0]. \"\n    )\n    _SUPPORTED_FILL_MODE = (\"reflect\", \"wrap\", \"constant\", \"nearest\")\n    _SUPPORTED_INTERPOLATION = (\"nearest\", \"bilinear\")\n\n    def __init__(\n        self,\n        x_factor=0.0,\n        y_factor=0.0,\n        interpolation=\"bilinear\",\n        fill_mode=\"reflect\",\n        fill_value=0.0,\n        data_format=None,\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self.x_factor = self._set_factor_with_name(x_factor, \"x_factor\")\n        self.y_factor = self._set_factor_with_name(y_factor, \"y_factor\")\n\n        if fill_mode not in self._SUPPORTED_FILL_MODE:\n            raise NotImplementedError(\n                f\"Unknown `fill_mode` {fill_mode}. Expected of one \"\n                f\"{self._SUPPORTED_FILL_MODE}.\"\n            )\n        if interpolation not in self._SUPPORTED_INTERPOLATION:\n            raise NotImplementedError(\n                f\"Unknown `interpolation` {interpolation}. 
Expected of one \"\n                f\"{self._SUPPORTED_INTERPOLATION}.\"\n            )\n\n        self.fill_mode = fill_mode\n        self.fill_value = fill_value\n        self.interpolation = interpolation\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n        self.supports_jit = False\n\n    def _set_factor_with_name(self, factor, factor_name):\n        if isinstance(factor, (tuple, list)):\n            if len(factor) != 2:\n                raise ValueError(\n                    self._FACTOR_VALIDATION_ERROR\n                    + f\"Received: {factor_name}={factor}\"\n                )\n            self._check_factor_range(factor[0])\n            self._check_factor_range(factor[1])\n            lower, upper = sorted(factor)\n        elif isinstance(factor, (int, float)):\n            self._check_factor_range(factor)\n            factor = abs(factor)\n            lower, upper = [-factor, factor]\n        else:\n            raise ValueError(\n                self._FACTOR_VALIDATION_ERROR\n                + f\"Received: {factor_name}={factor}\"\n            )\n        return lower, upper\n\n    def _check_factor_range(self, input_number):\n        if input_number > 1.0 or input_number < 0.0:\n            raise ValueError(\n                self._FACTOR_VALIDATION_ERROR\n                + f\"Received: input_number={input_number}\"\n            )\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n\n        images_shape = self.backend.shape(images)\n        if len(images_shape) == 3:\n            batch_size = 1\n        else:\n            batch_size = images_shape[0]\n\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n\n        invert = self.backend.random.uniform(\n            minval=0,\n            
maxval=1,\n            shape=[batch_size, 1],\n            seed=seed,\n            dtype=self.compute_dtype,\n        )\n        invert = self.backend.numpy.where(\n            invert > 0.5,\n            -self.backend.numpy.ones_like(invert),\n            self.backend.numpy.ones_like(invert),\n        )\n\n        shear_y = self.backend.random.uniform(\n            minval=self.y_factor[0],\n            maxval=self.y_factor[1],\n            shape=[batch_size, 1],\n            seed=seed,\n            dtype=self.compute_dtype,\n        )\n        shear_x = self.backend.random.uniform(\n            minval=self.x_factor[0],\n            maxval=self.x_factor[1],\n            shape=[batch_size, 1],\n            seed=seed,\n            dtype=self.compute_dtype,\n        )\n        shear_factor = (\n            self.backend.cast(\n                self.backend.numpy.concatenate([shear_x, shear_y], axis=1),\n                dtype=self.compute_dtype,\n            )\n            * invert\n        )\n        return {\"shear_factor\": shear_factor, \"input_shape\": images_shape}\n\n    def _transform_images(self, images, transformation, interpolation):\n        if transformation is None:\n            return images\n\n        inputs_shape = self.backend.shape(images)\n        unbatched = len(inputs_shape) == 3\n        if unbatched:\n            images = self.backend.numpy.expand_dims(images, axis=0)\n\n        shear_factor = transformation[\"shear_factor\"]\n        outputs = self.backend.image.affine_transform(\n            images,\n            transform=self._get_shear_matrix(shear_factor),\n            interpolation=interpolation,\n            fill_mode=self.fill_mode,\n            fill_value=self.fill_value,\n            data_format=self.data_format,\n        )\n\n        if unbatched:\n            outputs = self.backend.numpy.squeeze(outputs, axis=0)\n        return outputs\n\n    def _get_shear_matrix(self, shear_factors):\n        num_shear_factors = 
self.backend.shape(shear_factors)[0]\n\n        # The shear matrix looks like:\n        # [[1   s_x  0]\n        #  [s_y  1   0]\n        #  [0    0   1]]\n\n        return self.backend.numpy.stack(\n            [\n                self.backend.numpy.ones((num_shear_factors,)),\n                shear_factors[:, 0],\n                self.backend.numpy.zeros((num_shear_factors,)),\n                shear_factors[:, 1],\n                self.backend.numpy.ones((num_shear_factors,)),\n                self.backend.numpy.zeros((num_shear_factors,)),\n                self.backend.numpy.zeros((num_shear_factors,)),\n                self.backend.numpy.zeros((num_shear_factors,)),\n            ],\n            axis=1,\n        )\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def get_transformed_x_y(self, x, y, transform):\n        a0, a1, a2, b0, b1, b2, c0, c1 = self.backend.numpy.split(\n            transform, 8, axis=-1\n        )\n\n        k = c0 * x + c1 * y + 1\n        x_transformed = (a0 * x + a1 * y + a2) / k\n        y_transformed = (b0 * x + b1 * y + b2) / k\n        return x_transformed, y_transformed\n\n    def get_shifted_bbox(self, bounding_boxes, w_shift_factor, h_shift_factor):\n        bboxes = bounding_boxes[\"boxes\"]\n        x1, x2, x3, x4 = self.backend.numpy.split(bboxes, 4, axis=-1)\n\n        w_shift_factor = self.backend.convert_to_tensor(\n            w_shift_factor, dtype=x1.dtype\n        )\n        h_shift_factor = self.backend.convert_to_tensor(\n            h_shift_factor, dtype=x1.dtype\n        )\n\n        if len(bboxes.shape) == 3:\n            w_shift_factor = self.backend.numpy.expand_dims(w_shift_factor, -1)\n            h_shift_factor = self.backend.numpy.expand_dims(h_shift_factor, -1)\n\n        bounding_boxes[\"boxes\"] = self.backend.numpy.concatenate(\n            [\n                x1 - w_shift_factor,\n                x2 - h_shift_factor,\n                x3 - 
w_shift_factor,\n                x4 - h_shift_factor,\n            ],\n            axis=-1,\n        )\n        return bounding_boxes\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        def _get_height_width(transformation):\n            if self.data_format == \"channels_first\":\n                height_axis = -2\n                width_axis = -1\n            else:\n                height_axis = -3\n                width_axis = -2\n            input_height, input_width = (\n                transformation[\"input_shape\"][height_axis],\n                transformation[\"input_shape\"][width_axis],\n            )\n            return input_height, input_width\n\n        if training:\n            if backend_utils.in_tf_graph():\n                self.backend.set_backend(\"tensorflow\")\n\n            input_height, input_width = _get_height_width(transformation)\n\n            bounding_boxes = convert_format(\n                bounding_boxes,\n                source=self.bounding_box_format,\n                target=\"rel_xyxy\",\n                height=input_height,\n                width=input_width,\n                dtype=self.compute_dtype,\n            )\n\n            bounding_boxes = self._shear_bboxes(bounding_boxes, transformation)\n\n            bounding_boxes = clip_to_image_size(\n                bounding_boxes=bounding_boxes,\n                height=input_height,\n                width=input_width,\n                bounding_box_format=\"rel_xyxy\",\n            )\n\n            bounding_boxes = convert_format(\n                bounding_boxes,\n                source=\"rel_xyxy\",\n                target=self.bounding_box_format,\n                height=input_height,\n                width=input_width,\n                dtype=self.compute_dtype,\n            )\n\n            self.backend.reset()\n\n        return bounding_boxes\n\n    def _shear_bboxes(self, bounding_boxes, 
transformation):\n        shear_factor = self.backend.cast(\n            transformation[\"shear_factor\"], dtype=self.compute_dtype\n        )\n        shear_x_amount, shear_y_amount = self.backend.numpy.split(\n            shear_factor, 2, axis=-1\n        )\n\n        x1, y1, x2, y2 = self.backend.numpy.split(\n            bounding_boxes[\"boxes\"], 4, axis=-1\n        )\n        x1 = self.backend.numpy.squeeze(x1, axis=-1)\n        y1 = self.backend.numpy.squeeze(y1, axis=-1)\n        x2 = self.backend.numpy.squeeze(x2, axis=-1)\n        y2 = self.backend.numpy.squeeze(y2, axis=-1)\n\n        if shear_x_amount is not None:\n            x1_top = x1 - (shear_x_amount * y1)\n            x1_bottom = x1 - (shear_x_amount * y2)\n            x1 = self.backend.numpy.where(shear_x_amount < 0, x1_top, x1_bottom)\n\n            x2_top = x2 - (shear_x_amount * y1)\n            x2_bottom = x2 - (shear_x_amount * y2)\n            x2 = self.backend.numpy.where(shear_x_amount < 0, x2_bottom, x2_top)\n\n        if shear_y_amount is not None:\n            y1_left = y1 - (shear_y_amount * x1)\n            y1_right = y1 - (shear_y_amount * x2)\n            y1 = self.backend.numpy.where(shear_y_amount > 0, y1_right, y1_left)\n\n            y2_left = y2 - (shear_y_amount * x1)\n            y2_right = y2 - (shear_y_amount * x2)\n            y2 = self.backend.numpy.where(shear_y_amount > 0, y2_left, y2_right)\n\n        boxes = self.backend.numpy.concatenate(\n            [\n                self.backend.numpy.expand_dims(x1, axis=-1),\n                self.backend.numpy.expand_dims(y1, axis=-1),\n                self.backend.numpy.expand_dims(x2, axis=-1),\n                self.backend.numpy.expand_dims(y2, axis=-1),\n            ],\n            axis=-1,\n        )\n        bounding_boxes[\"boxes\"] = boxes\n\n        return bounding_boxes\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"x_factor\": self.x_factor,\n            
\"y_factor\": self.y_factor,\n            \"fill_mode\": self.fill_mode,\n            \"interpolation\": self.interpolation,\n            \"seed\": self.seed,\n            \"fill_value\": self.fill_value,\n            \"data_format\": self.data_format,\n        }\n        return {**base_config, **config}\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n\nRandomShear.__doc__ = RandomShear.__doc__.replace(\n    \"{{base_image_preprocessing_transform_example}}\",\n    base_image_preprocessing_transform_example.replace(\n        \"{LayerName}\", \"RandomShear\"\n    ),\n)\n"
  },
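The corner arithmetic in `transform_bounding_boxes` above is easy to check in isolation: each x-coordinate is shifted by `-shear_x * y` (and each y by `-shear_y * x`), and the `where` picks whichever sheared corner keeps `(x1, y1)` minimal and `(x2, y2)` maximal. A minimal NumPy sketch of the same update, before any clipping to the image (the function name is ours, not a Keras API):

```python
import numpy as np


def shear_xyxy_boxes(boxes, shear_x, shear_y):
    """Shear `xyxy` boxes the way the layer does. Standalone sketch,
    not the Keras implementation itself."""
    x1, y1, x2, y2 = (boxes[..., i] for i in range(4))

    # Horizontal shear: candidate x-coords come from the top (y1) and
    # bottom (y2) edges; keep the extremal one per corner.
    x1_top, x1_bot = x1 - shear_x * y1, x1 - shear_x * y2
    x2_top, x2_bot = x2 - shear_x * y1, x2 - shear_x * y2
    x1 = np.where(shear_x < 0, x1_top, x1_bot)
    x2 = np.where(shear_x < 0, x2_bot, x2_top)

    # Vertical shear: candidate y-coords come from the (already updated)
    # left (x1) and right (x2) edges, mirroring the layer's ordering.
    y1_l, y1_r = y1 - shear_y * x1, y1 - shear_y * x2
    y2_l, y2_r = y2 - shear_y * x1, y2 - shear_y * x2
    y1 = np.where(shear_y > 0, y1_r, y1_l)
    y2 = np.where(shear_y > 0, y2_l, y2_r)

    return np.stack([x1, y1, x2, y2], axis=-1)


boxes = np.array([[2.0, 1.0, 4.0, 3.0]])
print(shear_xyxy_boxes(boxes, shear_x=0.5, shear_y=0.0))
```

With `shear_x=0.5`, the box `[2, 1, 4, 3]` widens to `[0.5, 1, 3.5, 3]`: the bottom edge (larger y) drags `x1` further left than the top edge drags `x2`.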
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_shear_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.utils import backend_utils\n\n\nclass RandomShearTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.RandomShear,\n            init_kwargs={\n                \"x_factor\": (0.5, 1),\n                \"y_factor\": (0.5, 1),\n                \"interpolation\": \"bilinear\",\n                \"fill_mode\": \"reflect\",\n                \"data_format\": \"channels_last\",\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    def test_random_shear_inference(self):\n        seed = 3481\n        layer = layers.RandomShear(1, 1)\n        np.random.seed(seed)\n        inputs = np.random.randint(0, 255, size=(224, 224, 3))\n        output = layer(inputs, training=False)\n        self.assertAllClose(inputs, output)\n\n    def test_shear_pixel_level(self):\n        image = np.zeros((1, 5, 5, 3))\n        image[0, 1:4, 1:4, :] = 1.0\n        image[0, 2, 2, :] = [0.0, 1.0, 0.0]\n        image = keras.ops.convert_to_tensor(image, dtype=\"float32\")\n\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_first\":\n            image = keras.ops.transpose(image, (0, 3, 1, 2))\n\n        shear_layer = layers.RandomShear(\n            x_factor=(0.2, 0.3),\n            y_factor=(0.2, 0.3),\n            interpolation=\"bilinear\",\n            fill_mode=\"constant\",\n            fill_value=0.0,\n            seed=42,\n            data_format=data_format,\n        )\n\n        sheared_image = shear_layer(image)\n\n        if data_format == \"channels_first\":\n     
       sheared_image = keras.ops.transpose(sheared_image, (0, 2, 3, 1))\n\n        original_pixel = image[0, 2, 2, :]\n        sheared_pixel = sheared_image[0, 2, 2, :]\n        self.assertNotAllClose(original_pixel, sheared_pixel)\n\n    def test_tf_data_compatibility(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_data = np.random.random((2, 8, 8, 3))\n        else:\n            input_data = np.random.random((2, 3, 8, 8))\n        layer = layers.RandomShear(1, 1)\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output.numpy()\n\n    @parameterized.named_parameters(\n        (\n            \"with_x_shift\",\n            [[1.0, 0.0]],\n            [[[0.0, 1.0, 3.2, 3.0], [1.2, 4.0, 4.8, 6.0]]],\n        ),\n        (\n            \"with_y_shift\",\n            [[0.0, 1.0]],\n            [[[2.0, 0.0, 4.0, 0.5], [6.0, 0.0, 8.0, 0.0]]],\n        ),\n        (\n            \"with_xy_shift\",\n            [[1.0, 1.0]],\n            [[[0.0, 0.0, 3.2, 3.5], [1.2, 0.0, 4.8, 4.5]]],\n        ),\n    )\n    def test_random_shear_bounding_boxes(self, translation, expected_boxes):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            image_shape = (10, 8, 3)\n        else:\n            image_shape = (3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [2, 1, 4, 3],\n                    [6, 4, 8, 6],\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n        layer = layers.RandomShear(\n            x_factor=0.5,\n            y_factor=0.5,\n            data_format=data_format,\n            seed=42,\n            
bounding_box_format=\"xyxy\",\n        )\n\n        transformation = {\n            \"shear_factor\": backend_utils.convert_tf_tensor(\n                np.array(translation)\n            ),\n            \"input_shape\": image_shape,\n        }\n        output = layer.transform_bounding_boxes(\n            input_data[\"bounding_boxes\"],\n            transformation=transformation,\n            training=True,\n        )\n\n        self.assertAllClose(output[\"boxes\"], expected_boxes)\n\n    @parameterized.named_parameters(\n        (\n            \"with_x_shift\",\n            [[1.0, 0.0]],\n            [[[0.0, 1.0, 3.2, 3.0], [1.2, 4.0, 4.8, 6.0]]],\n        ),\n        (\n            \"with_y_shift\",\n            [[0.0, 1.0]],\n            [[[2.0, 0.0, 4.0, 0.5], [6.0, 0.0, 8.0, 0.0]]],\n        ),\n        (\n            \"with_xy_shift\",\n            [[1.0, 1.0]],\n            [[[0.0, 0.0, 3.2, 3.5], [1.2, 0.0, 4.8, 4.5]]],\n        ),\n    )\n    def test_random_shear_tf_data_bounding_boxes(\n        self, translation, expected_boxes\n    ):\n        data_format = backend.config.image_data_format()\n        if backend.config.image_data_format() == \"channels_last\":\n            image_shape = (1, 10, 8, 3)\n        else:\n            image_shape = (1, 3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [\n                        [2, 1, 4, 3],\n                        [6, 4, 8, 6],\n                    ]\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data)\n        layer = layers.RandomShear(\n            x_factor=0.5,\n            y_factor=0.5,\n            data_format=data_format,\n            seed=42,\n            bounding_box_format=\"xyxy\",\n        )\n\n      
  transformation = {\n            \"shear_factor\": np.array(translation),\n            \"input_shape\": image_shape,\n        }\n\n        ds = ds.map(\n            lambda x: layer.transform_bounding_boxes(\n                x[\"bounding_boxes\"],\n                transformation=transformation,\n                training=True,\n            )\n        )\n\n        output = next(iter(ds))\n        expected_boxes = np.array(expected_boxes)\n        self.assertAllClose(output[\"boxes\"], expected_boxes)\n"
  },
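The `RandomTranslation` layer that follows normalizes its `height_factor`/`width_factor` arguments via `_set_factor`: a single float `f` becomes the symmetric range `(-|f|, |f|)`, while a 2-tuple is range-checked and sorted. A self-contained sketch of that normalization (plain function rather than the layer method; error messages are ours):

```python
def set_factor(factor, name="factor"):
    """Normalize a translation factor to a (lower, upper) range in [-1, 1].
    Standalone sketch of RandomTranslation._set_factor, not the Keras code."""

    def check(value):
        if value < -1.0 or value > 1.0:
            raise ValueError(f"{name} must be in [-1.0, 1.0], got {value}")

    if isinstance(factor, (tuple, list)):
        if len(factor) != 2:
            raise ValueError(f"{name} must have two elements, got {factor}")
        for value in factor:
            check(value)
        lower, upper = sorted(factor)
    elif isinstance(factor, (int, float)):
        check(factor)
        lower, upper = -abs(factor), abs(factor)
    else:
        raise ValueError(f"Unsupported {name}: {factor!r}")
    return lower, upper


print(set_factor(0.2))          # (-0.2, 0.2)
print(set_factor((-0.5, 0.4)))  # (-0.5, 0.4)
```

This is why the test file below expects a `ValueError` for a 3-element list, a dict, or any value outside `[-1.0, 1.0]`.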
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_translation.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    base_image_preprocessing_transform_example,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    clip_to_image_size,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    convert_format,\n)\nfrom keras.src.random.seed_generator import SeedGenerator\nfrom keras.src.utils import backend_utils\n\n\n@keras_export(\"keras.layers.RandomTranslation\")\nclass RandomTranslation(BaseImagePreprocessingLayer):\n    \"\"\"A preprocessing layer which randomly translates images during training.\n\n    This layer will apply random translations to each image during training,\n    filling empty space according to `fill_mode`.\n\n    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and\n    of integer or floating point dtype. 
By default, the layer will output\n    floats.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Input shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format,\n        or `(..., channels, height, width)`, in `\"channels_first\"` format.\n\n    Output shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., target_height, target_width, channels)`,\n        or `(..., channels, target_height, target_width)`,\n        in `\"channels_first\"` format.\n\n    Args:\n        height_factor: a float represented as fraction of value, or a tuple of\n            size 2 representing lower and upper bound for shifting vertically. A\n            negative value means shifting image up, while a positive value means\n            shifting image down. When represented as a single positive float,\n            this value is used for both the upper and lower bound. For instance,\n            `height_factor=(-0.2, 0.3)` results in an output shifted by a random\n            amount in the range `[-20%, +30%]`. `height_factor=0.2` results in\n            an output height shifted by a random amount in the range\n            `[-20%, +20%]`.\n        width_factor: a float represented as fraction of value, or a tuple of\n            size 2 representing lower and upper bound for shifting horizontally.\n            A negative value means shifting image left, while a positive value\n            means shifting image right. When represented as a single positive\n            float, this value is used for both the upper and lower bound. For\n            instance, `width_factor=(-0.2, 0.3)` results in an output shifted\n            left by 20%, and shifted right by 30%. 
`width_factor=0.2` results\n            in an output shifted left or right by up to 20%.\n        fill_mode: Points outside the boundaries of the input are filled\n            according to the given mode. Available methods are `\"constant\"`,\n            `\"nearest\"`, `\"wrap\"` and `\"reflect\"`. Defaults to `\"reflect\"`.\n            - `\"reflect\"`: `(d c b a | a b c d | d c b a)`\n                The input is extended by reflecting about the edge of the last\n                pixel.\n            - `\"constant\"`: `(k k k k | a b c d | k k k k)`\n                The input is extended by filling all values beyond\n                the edge with the same constant value k specified by\n                `fill_value`.\n            - `\"wrap\"`: `(a b c d | a b c d | a b c d)`\n                The input is extended by wrapping around to the opposite edge.\n            - `\"nearest\"`: `(a a a a | a b c d | d d d d)`\n                The input is extended by the nearest pixel.\n            Note that when using the torch backend, `\"reflect\"` is redirected\n            to `\"mirror\"` `(c d c b | a b c d | c b a b)` because torch does\n            not support `\"reflect\"`.\n            Note that the torch backend does not support `\"wrap\"`.\n        interpolation: Interpolation mode. Supported values: `\"nearest\"`,\n            `\"bilinear\"`.\n        seed: Integer. Used to create a random seed.\n        fill_value: a float representing the value to be filled outside the\n            boundaries when `fill_mode=\"constant\"`.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`. 
It defaults to the\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. If you never set it, then it will be\n            `\"channels_last\"`.\n        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.\n\n    Example:\n\n    {{base_image_preprocessing_transform_example}}\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_VALIDATION_ERROR = (\n        \"The `factor` argument should be a number (or a list of two numbers) \"\n        \"in the range [-1.0, 1.0]. \"\n    )\n    _SUPPORTED_FILL_MODE = (\"reflect\", \"wrap\", \"constant\", \"nearest\")\n    _SUPPORTED_INTERPOLATION = (\"nearest\", \"bilinear\")\n\n    def __init__(\n        self,\n        height_factor,\n        width_factor,\n        fill_mode=\"reflect\",\n        interpolation=\"bilinear\",\n        seed=None,\n        fill_value=0.0,\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(data_format=data_format, **kwargs)\n        self.height_factor = height_factor\n        self.height_lower, self.height_upper = self._set_factor(\n            height_factor, \"height_factor\"\n        )\n        self.width_factor = width_factor\n        self.width_lower, self.width_upper = self._set_factor(\n            width_factor, \"width_factor\"\n        )\n\n        if fill_mode not in self._SUPPORTED_FILL_MODE:\n            raise NotImplementedError(\n                f\"Unknown `fill_mode` {fill_mode}. Expected one of \"\n                f\"{self._SUPPORTED_FILL_MODE}.\"\n            )\n        if interpolation not in self._SUPPORTED_INTERPOLATION:\n            raise NotImplementedError(\n                f\"Unknown `interpolation` {interpolation}. 
Expected one of \"\n                f\"{self._SUPPORTED_INTERPOLATION}.\"\n            )\n\n        self.fill_mode = fill_mode\n        self.fill_value = fill_value\n        self.interpolation = interpolation\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n        self.supports_jit = False\n\n    def _set_factor(self, factor, factor_name):\n        if isinstance(factor, (tuple, list)):\n            if len(factor) != 2:\n                raise ValueError(\n                    self._FACTOR_VALIDATION_ERROR\n                    + f\"Received: {factor_name}={factor}\"\n                )\n            self._check_factor_range(factor[0])\n            self._check_factor_range(factor[1])\n            lower, upper = sorted(factor)\n        elif isinstance(factor, (int, float)):\n            self._check_factor_range(factor)\n            factor = abs(factor)\n            lower, upper = [-factor, factor]\n        else:\n            raise ValueError(\n                self._FACTOR_VALIDATION_ERROR\n                + f\"Received: {factor_name}={factor}\"\n            )\n        return lower, upper\n\n    def _check_factor_range(self, input_number):\n        if input_number > 1.0 or input_number < -1.0:\n            raise ValueError(\n                self._FACTOR_VALIDATION_ERROR\n                + f\"Received: input_number={input_number}\"\n            )\n\n    def _transform_images(self, images, transformation, interpolation):\n        return self._translate_inputs(images, transformation, interpolation)\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def get_transformed_x_y(self, x, y, transform):\n        a0, a1, a2, b0, b1, b2, c0, c1 = self.backend.numpy.split(\n            transform, 8, axis=-1\n        )\n\n        k = c0 * x + c1 * y + 1\n        x_transformed = (a0 * x + a1 * y + a2) / k\n        y_transformed = (b0 * x + b1 * y + b2) / k\n        return x_transformed, y_transformed\n\n    def 
get_shifted_bbox(self, bounding_boxes, w_shift_factor, h_shift_factor):\n        bboxes = bounding_boxes[\"boxes\"]\n        x1, x2, x3, x4 = self.backend.numpy.split(bboxes, 4, axis=-1)\n\n        w_shift_factor = self.backend.convert_to_tensor(\n            w_shift_factor, dtype=x1.dtype\n        )\n        h_shift_factor = self.backend.convert_to_tensor(\n            h_shift_factor, dtype=x1.dtype\n        )\n\n        if len(bboxes.shape) == 3:\n            w_shift_factor = self.backend.numpy.expand_dims(w_shift_factor, -1)\n            h_shift_factor = self.backend.numpy.expand_dims(h_shift_factor, -1)\n\n        bounding_boxes[\"boxes\"] = self.backend.numpy.concatenate(\n            [\n                x1 - w_shift_factor,\n                x2 - h_shift_factor,\n                x3 - w_shift_factor,\n                x4 - h_shift_factor,\n            ],\n            axis=-1,\n        )\n        return bounding_boxes\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        if training:\n            if backend_utils.in_tf_graph():\n                self.backend.set_backend(\"tensorflow\")\n\n            if self.data_format == \"channels_first\":\n                height_axis = -2\n                width_axis = -1\n            else:\n                height_axis = -3\n                width_axis = -2\n\n            input_height, input_width = (\n                transformation[\"input_shape\"][height_axis],\n                transformation[\"input_shape\"][width_axis],\n            )\n\n            bounding_boxes = convert_format(\n                bounding_boxes,\n                source=self.bounding_box_format,\n                target=\"xyxy\",\n                height=input_height,\n                width=input_width,\n            )\n\n            translations = transformation[\"translations\"]\n            transform = self._get_translation_matrix(translations)\n\n            
w_shift_factor, h_shift_factor = self.get_transformed_x_y(\n                0, 0, transform\n            )\n            bounding_boxes = self.get_shifted_bbox(\n                bounding_boxes, w_shift_factor, h_shift_factor\n            )\n\n            bounding_boxes = clip_to_image_size(\n                bounding_boxes=bounding_boxes,\n                height=input_height,\n                width=input_width,\n                bounding_box_format=\"xyxy\",\n            )\n\n            bounding_boxes = convert_format(\n                bounding_boxes,\n                source=\"xyxy\",\n                target=self.bounding_box_format,\n                height=input_height,\n                width=input_width,\n            )\n\n            self.backend.reset()\n\n        return bounding_boxes\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n\n        images_shape = self.backend.shape(images)\n        unbatched = len(images_shape) == 3\n        if unbatched:\n            images_shape = self.backend.shape(images)\n            batch_size = 1\n        else:\n            batch_size = images_shape[0]\n        if self.data_format == \"channels_first\":\n            height = images_shape[-2]\n            width = images_shape[-1]\n        else:\n            height = images_shape[-3]\n            width = images_shape[-2]\n\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n\n        height_translate = self.backend.random.uniform(\n            minval=self.height_lower,\n            maxval=self.height_upper,\n            shape=[batch_size, 1],\n            seed=seed,\n        )\n        height_translate = self.backend.numpy.multiply(height_translate, height)\n        width_translate = self.backend.random.uniform(\n            
minval=self.width_lower,\n            maxval=self.width_upper,\n            shape=[batch_size, 1],\n            seed=seed,\n        )\n        width_translate = self.backend.numpy.multiply(width_translate, width)\n        translations = self.backend.cast(\n            self.backend.numpy.concatenate(\n                [width_translate, height_translate], axis=1\n            ),\n            dtype=\"float32\",\n        )\n        return {\"translations\": translations, \"input_shape\": images_shape}\n\n    def _translate_inputs(self, inputs, transformation, interpolation):\n        if transformation is None:\n            return inputs\n\n        inputs_shape = self.backend.shape(inputs)\n        unbatched = len(inputs_shape) == 3\n        if unbatched:\n            inputs = self.backend.numpy.expand_dims(inputs, axis=0)\n\n        translations = transformation[\"translations\"]\n        outputs = self.backend.image.affine_transform(\n            inputs,\n            transform=self._get_translation_matrix(translations),\n            interpolation=interpolation,\n            fill_mode=self.fill_mode,\n            fill_value=self.fill_value,\n            data_format=self.data_format,\n        )\n\n        if unbatched:\n            outputs = self.backend.numpy.squeeze(outputs, axis=0)\n        return outputs\n\n    def _get_translation_matrix(self, translations):\n        num_translations = self.backend.shape(translations)[0]\n        # The translation matrix looks like:\n        #     [[1 0 -dx]\n        #      [0 1 -dy]\n        #      [0 0 1]]\n        # where the last entry is implicit.\n        # translation matrices are always float32.\n        return self.backend.numpy.concatenate(\n            [\n                self.backend.numpy.ones((num_translations, 1)),\n                self.backend.numpy.zeros((num_translations, 1)),\n                -translations[:, 0:1],\n                self.backend.numpy.zeros((num_translations, 1)),\n                
self.backend.numpy.ones((num_translations, 1)),\n                -translations[:, 1:],\n                self.backend.numpy.zeros((num_translations, 2)),\n            ],\n            axis=1,\n        )\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"height_factor\": self.height_factor,\n            \"width_factor\": self.width_factor,\n            \"fill_mode\": self.fill_mode,\n            \"interpolation\": self.interpolation,\n            \"seed\": self.seed,\n            \"fill_value\": self.fill_value,\n            \"data_format\": self.data_format,\n        }\n        return {**base_config, **config}\n\n\nRandomTranslation.__doc__ = RandomTranslation.__doc__.replace(\n    \"{{base_image_preprocessing_transform_example}}\",\n    base_image_preprocessing_transform_example.replace(\n        \"{LayerName}\", \"RandomTranslation\"\n    ),\n)\n"
  },
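The comment in `_get_translation_matrix` above describes the matrix `[[1, 0, -dx], [0, 1, -dy], [0, 0, 1]]` flattened into the 8-parameter row `[a0, a1, a2, b0, b1, b2, c0, c1]` that `get_transformed_x_y` consumes. A standalone NumPy sketch of both ends (helper names are ours, not Keras APIs):

```python
import numpy as np


def translation_matrix(translations):
    """Flattened 8-parameter transforms for per-image shifts [dx, dy],
    mirroring the layer's [[1, 0, -dx], [0, 1, -dy], [0, 0, 1]] layout
    with the last entry implicit. Sketch, not the Keras code itself."""
    t = np.asarray(translations, dtype="float32")
    n = t.shape[0]
    ones, zeros = np.ones((n, 1)), np.zeros((n, 1))
    return np.concatenate(
        [ones, zeros, -t[:, 0:1], zeros, ones, -t[:, 1:2], np.zeros((n, 2))],
        axis=1,
    )


def transform_xy(x, y, transform):
    # Apply the projective transform to a point, as get_transformed_x_y does:
    # output coordinates map back to input coordinates.
    a0, a1, a2, b0, b1, b2, c0, c1 = np.split(transform, 8, axis=-1)
    k = c0 * x + c1 * y + 1
    return (a0 * x + a1 * y + a2) / k, (b0 * x + b1 * y + b2) / k


m = translation_matrix([[3.0, 2.0]])  # shift right 3 px, down 2 px
print(transform_xy(0.0, 0.0, m))      # origin maps back to (-3, -2)
```

This is exactly what `transform_bounding_boxes` exploits: evaluating the transform at `(0, 0)` yields `(-dx, -dy)`, which `get_shifted_bbox` then subtracts from the box coordinates so that boxes move with the image.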
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_translation_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.utils import backend_utils\n\n\nclass RandomTranslationTest(testing.TestCase):\n    @parameterized.named_parameters(\n        (\"random_translate_4_by_6\", 0.4, 0.6),\n        (\"random_translate_3_by_2\", 0.3, 0.2),\n        (\"random_translate_tuple_factor\", (-0.5, 0.4), (0.2, 0.3)),\n    )\n    def test_random_translation(self, height_factor, width_factor):\n        self.run_layer_test(\n            layers.RandomTranslation,\n            init_kwargs={\n                \"height_factor\": height_factor,\n                \"width_factor\": width_factor,\n            },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 4),\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    @parameterized.named_parameters(\n        (\"bad_len\", [0.1, 0.2, 0.3], 0.0),\n        (\"bad_type\", {\"dummy\": 0.3}, 0.0),\n        (\"exceed_range_single\", -1.1, 0.0),\n        (\"exceed_range_tuple\", (-1.1, 0.0), 0.0),\n    )\n    def test_random_translation_with_bad_factor(\n        self, height_factor, width_factor\n    ):\n        with self.assertRaises(ValueError):\n            self.run_layer_test(\n                layers.RandomTranslation,\n                init_kwargs={\n                    \"height_factor\": height_factor,\n                    \"width_factor\": width_factor,\n                },\n                input_shape=(2, 3, 4),\n                expected_output_shape=(2, 3, 4),\n                supports_masking=False,\n                run_training_check=False,\n            )\n\n    def test_random_translation_with_inference_mode(self):\n        input_data = np.random.random((1, 4, 4, 3))\n        expected_output = input_data\n        layer = layers.RandomTranslation(0.2, 0.1)\n        
output = layer(input_data, training=False)\n        self.assertAllClose(output, expected_output)\n\n    @parameterized.parameters([\"channels_first\", \"channels_last\"])\n    def test_random_translation_up_numeric_reflect(self, data_format):\n        input_image = np.arange(0, 25)\n        expected_output = np.asarray(\n            [\n                [5, 6, 7, 8, 9],\n                [10, 11, 12, 13, 14],\n                [15, 16, 17, 18, 19],\n                [20, 21, 22, 23, 24],\n                [20, 21, 22, 23, 24],\n            ]\n        )\n        if data_format == \"channels_last\":\n            input_image = np.reshape(input_image, (1, 5, 5, 1))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 5, 5, 1))\n            )\n        else:\n            input_image = np.reshape(input_image, (1, 1, 5, 5))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 1, 5, 5))\n            )\n        self.run_layer_test(\n            layers.RandomTranslation,\n            init_kwargs={\n                \"height_factor\": (-0.2, -0.2),\n                \"width_factor\": 0.0,\n                \"data_format\": data_format,\n            },\n            input_shape=None,\n            input_data=input_image,\n            expected_output=expected_output,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    @parameterized.parameters([\"channels_first\", \"channels_last\"])\n    def test_random_translation_up_numeric_constant(self, data_format):\n        input_image = np.arange(0, 25).astype(\"float32\")\n        # Shifting by -.2 * 5 = 1 pixel.\n        expected_output = np.asarray(\n            [\n                [5, 6, 7, 8, 9],\n                [10, 11, 12, 13, 14],\n                [15, 16, 17, 18, 19],\n                [20, 21, 22, 23, 24],\n                [0, 0, 0, 0, 0],\n            ]\n        )\n        if data_format == 
\"channels_last\":\n            input_image = np.reshape(input_image, (1, 5, 5, 1))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 5, 5, 1)), dtype=\"float32\"\n            )\n        else:\n            input_image = np.reshape(input_image, (1, 1, 5, 5))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 1, 5, 5)), dtype=\"float32\"\n            )\n        self.run_layer_test(\n            layers.RandomTranslation,\n            init_kwargs={\n                \"height_factor\": (-0.2, -0.2),\n                \"width_factor\": 0.0,\n                \"fill_mode\": \"constant\",\n                \"data_format\": data_format,\n            },\n            input_shape=None,\n            input_data=input_image,\n            expected_output=expected_output,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    @parameterized.parameters([\"channels_first\", \"channels_last\"])\n    def test_random_translation_down_numeric_reflect(self, data_format):\n        input_image = np.arange(0, 25)\n        # Shifting by .2 * 5 = 1 pixel.\n        expected_output = np.asarray(\n            [\n                [0, 1, 2, 3, 4],\n                [0, 1, 2, 3, 4],\n                [5, 6, 7, 8, 9],\n                [10, 11, 12, 13, 14],\n                [15, 16, 17, 18, 19],\n            ]\n        )\n        if data_format == \"channels_last\":\n            input_image = np.reshape(input_image, (1, 5, 5, 1))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 5, 5, 1))\n            )\n        else:\n            input_image = np.reshape(input_image, (1, 1, 5, 5))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 1, 5, 5))\n            )\n        self.run_layer_test(\n            layers.RandomTranslation,\n            
init_kwargs={\n                \"height_factor\": (0.2, 0.2),\n                \"width_factor\": 0.0,\n                \"data_format\": data_format,\n            },\n            input_shape=None,\n            input_data=input_image,\n            expected_output=expected_output,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    @parameterized.parameters([\"channels_first\", \"channels_last\"])\n    def test_random_translation_asymmetric_size_numeric_reflect(\n        self, data_format\n    ):\n        input_image = np.arange(0, 16)\n        # Shifting by .5 * 8 = 4 pixels.\n        expected_output = np.asarray(\n            [\n                [6, 7],\n                [4, 5],\n                [2, 3],\n                [0, 1],\n                [0, 1],\n                [2, 3],\n                [4, 5],\n                [6, 7],\n            ]\n        )\n        if data_format == \"channels_last\":\n            input_image = np.reshape(input_image, (1, 8, 2, 1))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 8, 2, 1))\n            )\n        else:\n            input_image = np.reshape(input_image, (1, 1, 8, 2))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 1, 8, 2))\n            )\n        self.run_layer_test(\n            layers.RandomTranslation,\n            init_kwargs={\n                \"height_factor\": (0.5, 0.5),\n                \"width_factor\": 0.0,\n                \"data_format\": data_format,\n            },\n            input_shape=None,\n            input_data=input_image,\n            expected_output=expected_output,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    @parameterized.parameters([\"channels_first\", \"channels_last\"])\n    def test_random_translation_down_numeric_constant(self, data_format):\n        input_image = np.arange(0, 25)\n     
   # Shifting by .2 * 5 = 1 pixel.\n        expected_output = np.asarray(\n            [\n                [0, 0, 0, 0, 0],\n                [0, 1, 2, 3, 4],\n                [5, 6, 7, 8, 9],\n                [10, 11, 12, 13, 14],\n                [15, 16, 17, 18, 19],\n            ]\n        )\n        if data_format == \"channels_last\":\n            input_image = np.reshape(input_image, (1, 5, 5, 1))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 5, 5, 1))\n            )\n        else:\n            input_image = np.reshape(input_image, (1, 1, 5, 5))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 1, 5, 5))\n            )\n        self.run_layer_test(\n            layers.RandomTranslation,\n            init_kwargs={\n                \"height_factor\": (0.2, 0.2),\n                \"width_factor\": 0.0,\n                \"fill_mode\": \"constant\",\n                \"fill_value\": 0.0,\n                \"data_format\": data_format,\n            },\n            input_shape=None,\n            input_data=input_image,\n            expected_output=expected_output,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    @parameterized.parameters([\"channels_first\", \"channels_last\"])\n    def test_random_translation_left_numeric_reflect(self, data_format):\n        input_image = np.arange(0, 25)\n        # Shifting by .2 * 5 = 1 pixel.\n        expected_output = np.asarray(\n            [\n                [1, 2, 3, 4, 4],\n                [6, 7, 8, 9, 9],\n                [11, 12, 13, 14, 14],\n                [16, 17, 18, 19, 19],\n                [21, 22, 23, 24, 24],\n            ]\n        )\n        if data_format == \"channels_last\":\n            input_image = np.reshape(input_image, (1, 5, 5, 1))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 5, 5, 
1))\n            )\n        else:\n            input_image = np.reshape(input_image, (1, 1, 5, 5))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 1, 5, 5))\n            )\n        self.run_layer_test(\n            layers.RandomTranslation,\n            init_kwargs={\n                \"height_factor\": 0.0,\n                \"width_factor\": (-0.2, -0.2),\n                \"data_format\": data_format,\n            },\n            input_shape=None,\n            input_data=input_image,\n            expected_output=expected_output,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    @parameterized.parameters([\"channels_first\", \"channels_last\"])\n    def test_random_translation_left_numeric_constant(self, data_format):\n        input_image = np.arange(0, 25)\n        # Shifting by .2 * 5 = 1 pixel.\n        expected_output = np.asarray(\n            [\n                [1, 2, 3, 4, 0],\n                [6, 7, 8, 9, 0],\n                [11, 12, 13, 14, 0],\n                [16, 17, 18, 19, 0],\n                [21, 22, 23, 24, 0],\n            ]\n        )\n        if data_format == \"channels_last\":\n            input_image = np.reshape(input_image, (1, 5, 5, 1))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 5, 5, 1))\n            )\n        else:\n            input_image = np.reshape(input_image, (1, 1, 5, 5))\n            expected_output = backend.convert_to_tensor(\n                np.reshape(expected_output, (1, 1, 5, 5))\n            )\n        self.run_layer_test(\n            layers.RandomTranslation,\n            init_kwargs={\n                \"height_factor\": 0.0,\n                \"width_factor\": (-0.2, -0.2),\n                \"fill_mode\": \"constant\",\n                \"fill_value\": 0.0,\n                \"data_format\": data_format,\n            },\n            input_shape=None,\n       
     input_data=input_image,\n            expected_output=expected_output,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    def test_tf_data_compatibility(self):\n        layer = layers.RandomTranslation(0.2, 0.1)\n        input_data = np.random.random((1, 4, 4, 3))\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(1).map(layer)\n        next(iter(ds)).numpy()\n\n    @parameterized.named_parameters(\n        (\n            \"with_positive_shift\",\n            [[1.0, 2.0]],\n            [[3.0, 3.0, 5.0, 5.0], [7.0, 6.0, 8.0, 8.0]],\n        ),\n        (\n            \"with_negative_shift\",\n            [[-1.0, -2.0]],\n            [[1.0, 0.0, 3.0, 1.0], [5.0, 2.0, 7.0, 4.0]],\n        ),\n    )\n    def test_random_translation_bounding_boxes(\n        self, translation, expected_boxes\n    ):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            image_shape = (10, 8, 3)\n        else:\n            image_shape = (3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [2, 1, 4, 3],\n                    [6, 4, 8, 6],\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n        random_translation_layer = layers.RandomTranslation(\n            height_factor=0.5,\n            width_factor=0.5,\n            data_format=data_format,\n            seed=42,\n            bounding_box_format=\"xyxy\",\n        )\n\n        transformation = {\n            \"translations\": backend_utils.convert_tf_tensor(\n                np.array(translation)\n            ),\n            \"input_shape\": image_shape,\n        }\n        output = random_translation_layer.transform_bounding_boxes(\n            input_data[\"bounding_boxes\"],\n            
transformation=transformation,\n            training=True,\n        )\n\n        self.assertAllClose(output[\"boxes\"], expected_boxes)\n\n    @parameterized.named_parameters(\n        (\n            \"with_positive_shift\",\n            [[1.0, 2.0]],\n            [[3.0, 3.0, 5.0, 5.0], [7.0, 6.0, 8.0, 8.0]],\n        ),\n        (\n            \"with_negative_shift\",\n            [[-1.0, -2.0]],\n            [[1.0, 0.0, 3.0, 1.0], [5.0, 2.0, 7.0, 4.0]],\n        ),\n    )\n    def test_random_translation_tf_data_bounding_boxes(\n        self, translation, expected_boxes\n    ):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            image_shape = (1, 10, 8, 3)\n        else:\n            image_shape = (1, 3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [\n                        [2, 1, 4, 3],\n                        [6, 4, 8, 6],\n                    ]\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data)\n        random_translation_layer = layers.RandomTranslation(\n            height_factor=0.5,\n            width_factor=0.5,\n            data_format=data_format,\n            seed=42,\n            bounding_box_format=\"xyxy\",\n        )\n\n        transformation = {\n            \"translations\": np.array(translation),\n            \"input_shape\": image_shape,\n        }\n\n        ds = ds.map(\n            lambda x: random_translation_layer.transform_bounding_boxes(\n                x[\"bounding_boxes\"],\n                transformation=transformation,\n                training=True,\n            )\n        )\n\n        output = next(iter(ds))\n        expected_boxes = 
np.array(expected_boxes)\n        self.assertAllClose(output[\"boxes\"], expected_boxes)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_zoom.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    clip_to_image_size,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    convert_format,\n)\nfrom keras.src.random.seed_generator import SeedGenerator\nfrom keras.src.utils import backend_utils\n\n\n@keras_export(\"keras.layers.RandomZoom\")\nclass RandomZoom(BaseImagePreprocessingLayer):\n    \"\"\"A preprocessing layer which randomly zooms images during training.\n\n    This layer will randomly zoom in or out on each axis of an image\n    independently, filling empty space according to `fill_mode`.\n\n    Input pixel values can be of any range (e.g. `[0., 1.)` or `[0, 255]`) and\n    of integer or floating point dtype.\n    By default, the layer will output floats.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Input shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format,\n        or `(..., channels, height, width)`, in `\"channels_first\"` format.\n\n    Output shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., target_height, target_width, channels)`,\n        or `(..., channels, target_height, target_width)`,\n        in `\"channels_first\"` format.\n\n    Args:\n        height_factor: a float represented as fraction of value, or a tuple of\n            size 2 representing lower and upper bound for zooming vertically.\n            When represented as a single float, this value is used for both the\n            upper and lower bound. 
A positive value means zooming out, while a\n            negative value means zooming in. For instance,\n            `height_factor=(0.2, 0.3)` results in an output zoomed out by a\n            random amount in the range `[+20%, +30%]`.\n            `height_factor=(-0.3, -0.2)` results in an output zoomed in by a\n            random amount in the range `[-30%, -20%]`.\n        width_factor: a float represented as fraction of value, or a tuple of\n            size 2 representing lower and upper bound for zooming horizontally.\n            When represented as a single float, this value is used for both the\n            upper and lower bound. For instance, `width_factor=(0.2, 0.3)`\n            results in an output zoomed out by between 20% and 30%.\n            `width_factor=(-0.3, -0.2)` results in an output zoomed in by\n            between 20% and 30%. `None` means that the width zoom factor will\n            match the height zoom factor, preserving the aspect ratio.\n            Defaults to `None`.\n        fill_mode: Points outside the boundaries of the input are filled\n            according to the given mode. Available methods are `\"constant\"`,\n            `\"nearest\"`, `\"wrap\"` and `\"reflect\"`. 
Defaults to `\"reflect\"`.\n            - `\"reflect\"`: `(d c b a | a b c d | d c b a)`\n                The input is extended by reflecting about the edge of the last\n                pixel.\n            - `\"constant\"`: `(k k k k | a b c d | k k k k)`\n                The input is extended by filling all values beyond\n                the edge with the same constant value k specified by\n                `fill_value`.\n            - `\"wrap\"`: `(a b c d | a b c d | a b c d)`\n                The input is extended by wrapping around to the opposite edge.\n            - `\"nearest\"`: `(a a a a | a b c d | d d d d)`\n                The input is extended by the nearest pixel.\n            Note that when using torch backend, `\"reflect\"` is redirected to\n            `\"mirror\"` `(c d c b | a b c d | c b a b)` because torch does not\n            support `\"reflect\"`.\n            Note that torch backend does not support `\"wrap\"`.\n        interpolation: Interpolation mode. Supported values: `\"nearest\"`,\n            `\"bilinear\"`.\n        seed: Integer. Used to create a random seed.\n        fill_value: a float that represents the value to be filled outside\n            the boundaries when `fill_mode=\"constant\"`.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`. It defaults to the\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. 
If you never set it, then it will be\n            `\"channels_last\"`.\n        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.\n\n    Example:\n\n    >>> input_img = np.random.random((32, 224, 224, 3))\n    >>> layer = keras.layers.RandomZoom(.5, .2)\n    >>> out_img = layer(input_img)\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _FACTOR_VALIDATION_ERROR = (\n        \"The `height_factor` and `width_factor` arguments \"\n        \"should be a number (or a list of two numbers) \"\n        \"in the range (-1.0, 1.0]. \"\n    )\n    _SUPPORTED_FILL_MODE = (\"reflect\", \"wrap\", \"constant\", \"nearest\")\n    _SUPPORTED_INTERPOLATION = (\"nearest\", \"bilinear\")\n\n    def __init__(\n        self,\n        height_factor,\n        width_factor=None,\n        fill_mode=\"reflect\",\n        interpolation=\"bilinear\",\n        seed=None,\n        fill_value=0.0,\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.height_factor = height_factor\n        self.height_lower, self.height_upper = self._set_factor(\n            height_factor, \"height_factor\"\n        )\n        self.width_factor = width_factor\n        if width_factor is not None:\n            self.width_lower, self.width_upper = self._set_factor(\n                width_factor, \"width_factor\"\n            )\n        if fill_mode not in self._SUPPORTED_FILL_MODE:\n            raise NotImplementedError(\n                f\"Unknown `fill_mode` {fill_mode}. Expected one of \"\n                f\"{self._SUPPORTED_FILL_MODE}.\"\n            )\n        if interpolation not in self._SUPPORTED_INTERPOLATION:\n            raise NotImplementedError(\n                f\"Unknown `interpolation` {interpolation}. 
Expected one of \"\n                f\"{self._SUPPORTED_INTERPOLATION}.\"\n            )\n\n        self.fill_mode = fill_mode\n        self.fill_value = fill_value\n        self.interpolation = interpolation\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n        self.data_format = backend.standardize_data_format(data_format)\n        self.supports_jit = False\n\n    def _set_factor(self, factor, factor_name):\n        if isinstance(factor, (tuple, list)):\n            if len(factor) != 2:\n                raise ValueError(\n                    self._FACTOR_VALIDATION_ERROR\n                    + f\"Received: {factor_name}={factor}\"\n                )\n            self._check_factor_range(factor[0])\n            self._check_factor_range(factor[1])\n            lower, upper = sorted(factor)\n        elif isinstance(factor, (int, float)):\n            self._check_factor_range(factor)\n            factor = abs(factor)\n            lower, upper = [-factor, factor]\n        else:\n            raise ValueError(\n                self._FACTOR_VALIDATION_ERROR\n                + f\"Received: {factor_name}={factor}\"\n            )\n        return lower, upper\n\n    def _check_factor_range(self, input_number):\n        if input_number > 1.0 or input_number <= -1.0:\n            raise ValueError(\n                self._FACTOR_VALIDATION_ERROR\n                + f\"Received: input_number={input_number}\"\n            )\n\n    def _transform_images(self, images, transformation, interpolation):\n        return self._zoom_inputs(images, transformation, interpolation)\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def get_transformed_x_y(self, x, y, transform):\n        a0, a1, a2, b0, b1, b2, c0, c1 = self.backend.numpy.split(\n            transform, 8, axis=-1\n        )\n\n        k = c0 * x + c1 * y + 1\n        x_transformed = (a0 * x + a1 * y + a2) / k\n        y_transformed = (b0 * x + b1 * 
y + b2) / k\n        return x_transformed, y_transformed\n\n    def get_clipped_bbox(self, bounding_boxes, h_end, h_start, w_end, w_start):\n        bboxes = bounding_boxes[\"boxes\"]\n        x1, y1, x2, y2 = self.backend.numpy.split(bboxes, 4, axis=-1)\n\n        if len(bboxes.shape) == 3:\n            h_end = self.backend.numpy.expand_dims(h_end, -1)\n            h_start = self.backend.numpy.expand_dims(h_start, -1)\n            w_end = self.backend.numpy.expand_dims(w_end, -1)\n            w_start = self.backend.numpy.expand_dims(w_start, -1)\n\n        x1 = self.backend.numpy.clip(x1, w_start, w_end) - w_start\n        y1 = self.backend.numpy.clip(y1, h_start, h_end) - h_start\n        x2 = self.backend.numpy.clip(x2, w_start, w_end) - w_start\n        y2 = self.backend.numpy.clip(y2, h_start, h_end) - h_start\n        bounding_boxes[\"boxes\"] = self.backend.numpy.concatenate(\n            [x1, y1, x2, y2], axis=-1\n        )\n        return bounding_boxes\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        if training:\n            if backend_utils.in_tf_graph():\n                self.backend.set_backend(\"tensorflow\")\n\n            width_zoom = transformation[\"width_zoom\"]\n            height_zoom = transformation[\"height_zoom\"]\n            inputs_shape = transformation[\"input_shape\"]\n\n            if self.data_format == \"channels_first\":\n                height = inputs_shape[-2]\n                width = inputs_shape[-1]\n            else:\n                height = inputs_shape[-3]\n                width = inputs_shape[-2]\n\n            bounding_boxes = convert_format(\n                bounding_boxes,\n                source=self.bounding_box_format,\n                target=\"xyxy\",\n                height=height,\n                width=width,\n            )\n\n            zooms = self.backend.cast(\n                self.backend.numpy.concatenate(\n     
               [width_zoom, height_zoom], axis=1\n                ),\n                dtype=\"float32\",\n            )\n            transform = self._get_zoom_matrix(zooms, height, width)\n\n            w_start, h_start = self.get_transformed_x_y(\n                0,\n                0,\n                transform,\n            )\n\n            w_end, h_end = self.get_transformed_x_y(\n                width,\n                height,\n                transform,\n            )\n\n            bounding_boxes = self.get_clipped_bbox(\n                bounding_boxes, h_end, h_start, w_end, w_start\n            )\n\n            height_transformed = h_end - h_start\n            width_transformed = w_end - w_start\n\n            height_transformed = self.backend.numpy.expand_dims(\n                height_transformed, -1\n            )\n            width_transformed = self.backend.numpy.expand_dims(\n                width_transformed, -1\n            )\n\n            bounding_boxes = convert_format(\n                bounding_boxes,\n                source=\"xyxy\",\n                target=\"rel_xyxy\",\n                height=height_transformed,\n                width=width_transformed,\n            )\n\n            bounding_boxes = clip_to_image_size(\n                bounding_boxes=bounding_boxes,\n                height=height_transformed,\n                width=width_transformed,\n                bounding_box_format=\"rel_xyxy\",\n            )\n\n            bounding_boxes = convert_format(\n                bounding_boxes,\n                source=\"rel_xyxy\",\n                target=self.bounding_box_format,\n                height=height,\n                width=width,\n            )\n\n            self.backend.reset()\n\n        return bounding_boxes\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n    
        images = data\n        images_shape = self.backend.shape(images)\n        if len(images_shape) == 4:\n            zoom_factor_shape = (images_shape[0], 1)\n        else:\n            zoom_factor_shape = (1, 1)\n\n        if not training:\n            return {\n                \"height_zoom\": self.backend.numpy.zeros(zoom_factor_shape),\n                \"width_zoom\": self.backend.numpy.zeros(zoom_factor_shape),\n            }\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n\n        height_zoom = self.backend.random.uniform(\n            minval=1.0 + self.height_lower,\n            maxval=1.0 + self.height_upper,\n            shape=zoom_factor_shape,\n            seed=seed,\n        )\n        if self.width_factor is not None:\n            width_zoom = self.backend.random.uniform(\n                minval=1.0 + self.width_lower,\n                maxval=1.0 + self.width_upper,\n                shape=zoom_factor_shape,\n                seed=seed,\n            )\n        else:\n            width_zoom = height_zoom\n        return {\n            \"height_zoom\": height_zoom,\n            \"width_zoom\": width_zoom,\n            \"input_shape\": images_shape,\n        }\n\n    def _zoom_inputs(self, inputs, transformation, interpolation):\n        if transformation is None:\n            return inputs\n\n        width_zoom = transformation[\"width_zoom\"]\n        height_zoom = transformation[\"height_zoom\"]\n        zooms = self.backend.cast(\n            self.backend.numpy.concatenate([width_zoom, height_zoom], axis=1),\n            dtype=\"float32\",\n        )\n\n        inputs_shape = self.backend.shape(inputs)\n        unbatched = len(inputs_shape) == 3\n        if unbatched:\n            inputs = self.backend.numpy.expand_dims(inputs, axis=0)\n            inputs_shape = self.backend.shape(inputs)\n        if self.data_format == \"channels_first\":\n            height = inputs_shape[-2]\n            width = 
inputs_shape[-1]\n        else:\n            height = inputs_shape[-3]\n            width = inputs_shape[-2]\n\n        outputs = self.backend.image.affine_transform(\n            inputs,\n            transform=self._get_zoom_matrix(zooms, height, width),\n            interpolation=interpolation,\n            fill_mode=self.fill_mode,\n            fill_value=self.fill_value,\n            data_format=self.data_format,\n        )\n\n        if unbatched:\n            outputs = self.backend.numpy.squeeze(outputs, axis=0)\n        return outputs\n\n    def _get_zoom_matrix(self, zooms, image_height, image_width):\n        num_zooms = self.backend.shape(zooms)[0]\n        # The zoom matrix looks like:\n        #     [[zx 0 0]\n        #      [0 zy 0]\n        #      [0 0 1]]\n        # where the last entry is implicit.\n        # zoom matrices are always float32.\n        x_offset = ((self.backend.cast(image_width, \"float32\") - 1.0) / 2.0) * (\n            1.0 - zooms[:, 0:1]\n        )\n        y_offset = (\n            (self.backend.cast(image_height, \"float32\") - 1.0) / 2.0\n        ) * (1.0 - zooms[:, 1:])\n        return self.backend.numpy.concatenate(\n            [\n                zooms[:, 0:1],\n                self.backend.numpy.zeros((num_zooms, 1)),\n                x_offset,\n                self.backend.numpy.zeros((num_zooms, 1)),\n                zooms[:, 1:],\n                y_offset,\n                self.backend.numpy.zeros((num_zooms, 2)),\n            ],\n            axis=1,\n        )\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"height_factor\": self.height_factor,\n            \"width_factor\": self.width_factor,\n            \"fill_mode\": self.fill_mode,\n            \"interpolation\": self.interpolation,\n            \"seed\": self.seed,\n            \"fill_value\": self.fill_value,\n            
\"data_format\": self.data_format,\n        }\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/random_zoom_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src.utils import backend_utils\n\n\nclass RandomZoomTest(testing.TestCase):\n    @parameterized.named_parameters(\n        (\"random_zoom_in_4_by_6\", -0.4, -0.6),\n        (\"random_zoom_in_2_by_3\", -0.2, -0.3),\n        (\"random_zoom_in_tuple_factor\", (-0.4, -0.5), (-0.2, -0.3)),\n        (\"random_zoom_out_4_by_6\", 0.4, 0.6),\n        (\"random_zoom_out_2_by_3\", 0.2, 0.3),\n        (\"random_zoom_out_tuple_factor\", (0.4, 0.5), (0.2, 0.3)),\n    )\n    def test_random_zoom(self, height_factor, width_factor):\n        self.run_layer_test(\n            layers.RandomZoom,\n            init_kwargs={\n                \"height_factor\": height_factor,\n                \"width_factor\": width_factor,\n            },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 4),\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    def test_random_zoom_out_correctness(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (1, 5, 5, 1)\n        else:\n            input_shape = (1, 1, 5, 5)\n        input_image = np.reshape(np.arange(0, 25), input_shape)\n        expected_output = np.asarray(\n            [\n                [0, 0, 0, 0, 0],\n                [0, 2.7, 4.5, 6.3, 0],\n                [0, 10.2, 12.0, 13.8, 0],\n                [0, 17.7, 19.5, 21.3, 0],\n                [0, 0, 0, 0, 0],\n            ]\n        )\n        expected_output = backend.convert_to_tensor(\n            np.reshape(expected_output, input_shape)\n        )\n        self.run_layer_test(\n            layers.RandomZoom,\n            init_kwargs={\n                \"height_factor\": (0.5, 0.5),\n                
\"width_factor\": (0.8, 0.8),\n                \"interpolation\": \"bilinear\",\n                \"fill_mode\": \"constant\",\n            },\n            input_shape=None,\n            input_data=input_image,\n            expected_output=expected_output,\n            supports_masking=False,\n            run_training_check=False,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n    def test_random_zoom_in_correctness(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (1, 5, 5, 1)\n        else:\n            input_shape = (1, 1, 5, 5)\n        input_image = np.reshape(np.arange(0, 25), input_shape)\n        expected_output = np.asarray(\n            [\n                [6.0, 6.5, 7.0, 7.5, 8.0],\n                [8.5, 9.0, 9.5, 10.0, 10.5],\n                [11.0, 11.5, 12.0, 12.5, 13.0],\n                [13.5, 14.0, 14.5, 15.0, 15.5],\n                [16.0, 16.5, 17.0, 17.5, 18.0],\n            ]\n        )\n        expected_output = backend.convert_to_tensor(\n            np.reshape(expected_output, input_shape)\n        )\n        self.run_layer_test(\n            layers.RandomZoom,\n            init_kwargs={\n                \"height_factor\": (-0.5, -0.5),\n                \"width_factor\": (-0.5, -0.5),\n                \"interpolation\": \"bilinear\",\n                \"fill_mode\": \"constant\",\n            },\n            input_shape=None,\n            input_data=input_image,\n            expected_output=expected_output,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    def test_tf_data_compatibility(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (1, 5, 5, 1)\n        else:\n            input_shape = (1, 1, 5, 5)\n        input_image = np.reshape(np.arange(0, 25), input_shape)\n        layer = layers.RandomZoom(\n            height_factor=(0.5, 0.5),\n            width_factor=(0.8, 
0.8),\n            interpolation=\"nearest\",\n            fill_mode=\"constant\",\n        )\n        ds = tf_data.Dataset.from_tensor_slices(input_image).batch(1).map(layer)\n        expected_output = np.asarray(\n            [\n                [0, 0, 0, 0, 0],\n                [0, 5, 7, 9, 0],\n                [0, 10, 12, 14, 0],\n                [0, 20, 22, 24, 0],\n                [0, 0, 0, 0, 0],\n            ]\n        ).reshape(input_shape)\n        output = next(iter(ds)).numpy()\n        self.assertAllClose(expected_output, output)\n\n    def test_dynamic_shape(self):\n        inputs = layers.Input((None, None, 3))\n        outputs = layers.RandomZoom(\n            height_factor=(0.5, 0.5),\n            width_factor=(0.8, 0.8),\n            interpolation=\"nearest\",\n            fill_mode=\"constant\",\n        )(inputs)\n        model = models.Model(inputs, outputs)\n        model.predict(np.random.random((1, 6, 6, 3)))\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\",\n        reason=\"The NumPy backend does not implement fit.\",\n    )\n    def test_connect_with_flatten(self):\n        model = models.Sequential(\n            [\n                layers.RandomZoom((-0.5, 0.0), (-0.5, 0.0)),\n                layers.Flatten(),\n                layers.Dense(1, activation=\"relu\"),\n            ],\n        )\n\n        model.compile(loss=\"mse\")\n        model.fit(np.random.random((2, 2, 2, 1)), y=np.random.random((2,)))\n\n    @parameterized.named_parameters(\n        (\n            \"with_zoom_in\",\n            [[[0.1]], [[0.1]]],\n            [[[0.0, 0.0, 8.0, 0.0], [8.0, 0.0, 8.0, 10.0]]],\n        ),\n        (\n            \"with_zoom_out\",\n            [[[1.9]], [[1.9]]],\n            [\n                [\n                    [2.710526, 2.657895, 3.763158, 3.710526],\n                    [4.815789, 4.236842, 5.868421, 5.289474],\n                ]\n            ],\n        ),\n    )\n    def test_random_zoom_bounding_boxes(self, 
zoom, expected_boxes):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            image_shape = (10, 8, 3)\n        else:\n            image_shape = (3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [2, 1, 4, 3],\n                    [6, 4, 8, 6],\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n        random_zoom_layer = layers.RandomZoom(\n            height_factor=(0.5, 0.5),\n            data_format=data_format,\n            seed=42,\n            bounding_box_format=\"xyxy\",\n        )\n\n        transformation = {\n            \"height_zoom\": backend_utils.convert_tf_tensor(np.array(zoom[0])),\n            \"width_zoom\": backend_utils.convert_tf_tensor(np.array(zoom[1])),\n            \"input_shape\": image_shape,\n        }\n        output = random_zoom_layer.transform_bounding_boxes(\n            input_data[\"bounding_boxes\"],\n            transformation=transformation,\n            training=True,\n        )\n\n        self.assertAllClose(output[\"boxes\"], expected_boxes)\n\n    @parameterized.named_parameters(\n        (\n            \"with_zoom_in\",\n            [[[0.1]], [[0.1]]],\n            [[[0.0, 0.0, 8.0, 0.0], [8.0, 0.0, 8.0, 10.0]]],\n        ),\n        (\n            \"with_zoom_out\",\n            [[[1.9]], [[1.9]]],\n            [\n                [\n                    [2.710526, 2.657895, 3.763158, 3.710526],\n                    [4.815789, 4.236842, 5.868421, 5.289474],\n                ]\n            ],\n        ),\n    )\n    def test_random_zoom_tf_data_bounding_boxes(self, zoom, expected_boxes):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            
image_shape = (1, 10, 8, 3)\n        else:\n            image_shape = (1, 3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [\n                        [2, 1, 4, 3],\n                        [6, 4, 8, 6],\n                    ]\n                ]\n            ),\n            \"labels\": np.array([[1, 2]]),\n        }\n\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data)\n        random_zoom_layer = layers.RandomZoom(\n            height_factor=0.5,\n            data_format=data_format,\n            seed=42,\n            bounding_box_format=\"xyxy\",\n        )\n\n        transformation = {\n            \"height_zoom\": np.array(zoom[0]),\n            \"width_zoom\": np.array(zoom[1]),\n            \"input_shape\": image_shape,\n        }\n\n        ds = ds.map(\n            lambda x: random_zoom_layer.transform_bounding_boxes(\n                x[\"bounding_boxes\"],\n                transformation=transformation,\n                training=True,\n            )\n        )\n\n        output = next(iter(ds))\n        expected_boxes = np.array(expected_boxes)\n        self.assertAllClose(output[\"boxes\"], expected_boxes)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/resizing.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    clip_to_image_size,\n)\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    convert_format,\n)\nfrom keras.src.ops.core import _saturate_cast\n\n\n@keras_export(\"keras.layers.Resizing\")\nclass Resizing(BaseImagePreprocessingLayer):\n    \"\"\"A preprocessing layer which resizes images.\n\n    This layer resizes an image input to a target height and width. The input\n    should be a 4D (batched) or 3D (unbatched) tensor in `\"channels_last\"`\n    or `\"channels_first\"` format. Input pixel values can be of any range\n    (e.g. `[0., 1.)` or `[0, 255]`).\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Input shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., height, width, channels)`, in `\"channels_last\"` format,\n        or `(..., channels, height, width)`, in `\"channels_first\"` format.\n\n    Output shape:\n        3D (unbatched) or 4D (batched) tensor with shape:\n        `(..., target_height, target_width, channels)`,\n        in `\"channels_last\"` format,\n        or `(..., channels, target_height, target_width)`,\n        in `\"channels_first\"` format.\n\n    Args:\n        height: Integer, the height of the output shape.\n        width: Integer, the width of the output shape.\n        interpolation: String, the interpolation method.\n            Supports `\"bilinear\"`, `\"nearest\"`, `\"bicubic\"`,\n            `\"lanczos3\"`, `\"lanczos5\"`. Defaults to `\"bilinear\"`.\n        crop_to_aspect_ratio: If `True`, resize the images without aspect\n            ratio distortion. 
When the original aspect ratio differs\n            from the target aspect ratio, the output image will be\n            cropped so as to return the\n            largest possible window in the image (of size `(height, width)`)\n            that matches the target aspect ratio. By default\n            (`crop_to_aspect_ratio=False`), aspect ratio may not be preserved.\n        pad_to_aspect_ratio: If `True`, pad the images without aspect\n            ratio distortion. When the original aspect ratio differs\n            from the target aspect ratio, the output image will be\n            evenly padded on the short side.\n        fill_mode: When using `pad_to_aspect_ratio=True`, padded areas\n            are filled according to the given mode. Only `\"constant\"` is\n            supported at this time\n            (fill with constant value, equal to `fill_value`).\n        fill_value: Float. Padding value to use when `pad_to_aspect_ratio=True`.\n        antialias: Boolean. Whether to use an antialiasing filter when\n            downsampling an image. Defaults to `False`.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`. It defaults to the\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. 
If you never set it, then it will be\n            `\"channels_last\"`.\n        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.\n\n    Example:\n\n    ```python\n    (x_train, y_train), _ = keras.datasets.cifar10.load_data()\n    image = x_train[0]\n    resizer = keras.layers.Resizing(128, 128)\n    resized_image = resizer(image)\n    print(\"original:\", image.shape, \"resized:\", resized_image.shape)\n    ```\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n\n    def __init__(\n        self,\n        height,\n        width,\n        interpolation=\"bilinear\",\n        crop_to_aspect_ratio=False,\n        pad_to_aspect_ratio=False,\n        fill_mode=\"constant\",\n        fill_value=0.0,\n        antialias=False,\n        data_format=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.height = height\n        self.width = width\n        self.interpolation = interpolation\n        self.data_format = backend.standardize_data_format(data_format)\n        self.crop_to_aspect_ratio = crop_to_aspect_ratio\n        self.pad_to_aspect_ratio = pad_to_aspect_ratio\n        self.fill_mode = fill_mode\n        self.fill_value = fill_value\n        self.antialias = bool(antialias)\n        if self.data_format == \"channels_first\":\n            self.height_axis = -2\n            self.width_axis = -1\n        elif self.data_format == \"channels_last\":\n            self.height_axis = -3\n            self.width_axis = -2\n\n    def _transform_images(self, images, transformation, interpolation):\n        size = (self.height, self.width)\n        resized = self.backend.image.resize(\n            images,\n            size=size,\n            interpolation=interpolation,\n            antialias=self.antialias,\n            data_format=self.data_format,\n            crop_to_aspect_ratio=self.crop_to_aspect_ratio,\n            pad_to_aspect_ratio=self.pad_to_aspect_ratio,\n            fill_mode=self.fill_mode,\n            
fill_value=self.fill_value,\n        )\n        if resized.dtype == images.dtype:\n            return resized\n        if backend.is_int_dtype(images.dtype):\n            resized = self.backend.numpy.round(resized)\n        return _saturate_cast(resized, images.dtype, self.backend)\n\n    def transform_images(self, images, transformation=None, training=True):\n        return self._transform_images(\n            images, transformation, self.interpolation\n        )\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation=None, training=True\n    ):\n        return self._transform_images(\n            segmentation_masks, transformation, \"nearest\"\n        )\n\n    def transform_labels(self, labels, transformation=None, training=True):\n        return labels\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if isinstance(data, dict):\n            input_shape = self.backend.shape(data[\"images\"])\n        else:\n            input_shape = self.backend.shape(data)\n\n        input_height, input_width = (\n            input_shape[self.height_axis],\n            input_shape[self.width_axis],\n        )\n\n        return input_height, input_width\n\n    def transform_bounding_boxes(\n        self,\n        bounding_boxes,\n        transformation,\n        training=True,\n    ):\n        ops = self.backend\n        input_height, input_width = transformation\n        mask_negative_1s = ops.numpy.all(bounding_boxes[\"boxes\"] == -1, axis=-1)\n        mask_zeros = ops.numpy.all(bounding_boxes[\"boxes\"] == 0, axis=-1)\n        boxes_mask = ops.numpy.logical_or(mask_negative_1s, mask_zeros)\n\n        bounding_boxes = convert_format(\n            bounding_boxes,\n            source=self.bounding_box_format,\n            target=\"xyxy\",\n            height=input_height,\n            width=input_width,\n        )\n\n        bounding_boxes[\"boxes\"] = self._transform_xyxy(\n            
bounding_boxes[\"boxes\"],\n            input_height=input_height,\n            input_width=input_width,\n        )\n\n        bounding_boxes = clip_to_image_size(\n            bounding_boxes=bounding_boxes,\n            height=self.height,\n            width=self.width,\n        )\n\n        bounding_boxes[\"boxes\"] = ops.numpy.where(\n            ops.numpy.expand_dims(boxes_mask, axis=-1),\n            ops.convert_to_tensor(\n                [0.0, 0.0, 0.0, 0.0], dtype=bounding_boxes[\"boxes\"].dtype\n            ),\n            bounding_boxes[\"boxes\"],\n        )\n\n        bounding_boxes = convert_format(\n            bounding_boxes,\n            source=\"xyxy\",\n            target=self.bounding_box_format,\n            height=self.height,\n            width=self.width,\n        )\n\n        return bounding_boxes\n\n    def _transform_xyxy(self, boxes, input_height, input_width):\n        ops = self.backend\n        input_height = ops.cast(input_height, dtype=boxes.dtype)\n        input_width = ops.cast(input_width, dtype=boxes.dtype)\n\n        if self.pad_to_aspect_ratio:\n            return self._transform_boxes_pad_to_aspect_ratio(\n                boxes, input_height, input_width\n            )\n        elif self.crop_to_aspect_ratio:\n            return self._transform_boxes_crop_to_aspect_ratio(\n                boxes, input_height, input_width\n            )\n        else:\n            return self._transform_boxes_stretch(\n                boxes, input_height, input_width\n            )\n\n    def _transform_boxes_pad_to_aspect_ratio(\n        self, boxes, input_height, input_width\n    ):\n        \"\"\"Transforms bounding boxes for padding to aspect ratio.\"\"\"\n        ops = self.backend\n        height_ratio = ops.cast(self.height / input_height, dtype=boxes.dtype)\n        width_ratio = ops.cast(self.width / input_width, dtype=boxes.dtype)\n        min_aspect_ratio = ops.numpy.minimum(height_ratio, width_ratio)\n        y_offset = (self.height 
- input_height * min_aspect_ratio) // 2\n        x_offset = (self.width - input_width * min_aspect_ratio) // 2\n        return ops.numpy.stack(\n            [\n                boxes[..., 0] * min_aspect_ratio + x_offset,\n                boxes[..., 1] * min_aspect_ratio + y_offset,\n                boxes[..., 2] * min_aspect_ratio + x_offset,\n                boxes[..., 3] * min_aspect_ratio + y_offset,\n            ],\n            axis=-1,\n        )\n\n    def _transform_boxes_crop_to_aspect_ratio(\n        self, boxes, input_height, input_width\n    ):\n        \"\"\"Transforms bounding boxes for cropping to aspect ratio.\"\"\"\n        ops = self.backend\n        source_aspect_ratio = input_width / input_height\n        target_aspect_ratio = self.width / self.height\n        new_width = ops.numpy.where(\n            source_aspect_ratio > target_aspect_ratio,\n            self.height * source_aspect_ratio,\n            self.width,\n        )\n        new_height = ops.numpy.where(\n            source_aspect_ratio > target_aspect_ratio,\n            self.height,\n            self.width / source_aspect_ratio,\n        )\n        scale_x = new_width / input_width\n        scale_y = new_height / input_height\n        crop_left = (new_width - self.width) // 2\n        crop_top = (new_height - self.height) // 2\n        return ops.numpy.stack(\n            [\n                boxes[..., 0] * scale_x - crop_left,\n                boxes[..., 1] * scale_y - crop_top,\n                boxes[..., 2] * scale_x - crop_left,\n                boxes[..., 3] * scale_y - crop_top,\n            ],\n            axis=-1,\n        )\n\n    def _transform_boxes_stretch(self, boxes, input_height, input_width):\n        \"\"\"Transforms bounding boxes by simple stretching.\"\"\"\n        ops = self.backend\n        height_ratio = ops.cast(self.height / input_height, dtype=boxes.dtype)\n        width_ratio = ops.cast(self.width / input_width, dtype=boxes.dtype)\n        return 
ops.numpy.stack(\n            [\n                boxes[..., 0] * width_ratio,\n                boxes[..., 1] * height_ratio,\n                boxes[..., 2] * width_ratio,\n                boxes[..., 3] * height_ratio,\n            ],\n            axis=-1,\n        )\n\n    def compute_output_shape(self, input_shape):\n        input_shape = list(input_shape)\n        if len(input_shape) == 4:\n            if self.data_format == \"channels_last\":\n                input_shape[1] = self.height\n                input_shape[2] = self.width\n            else:\n                input_shape[2] = self.height\n                input_shape[3] = self.width\n        else:\n            if self.data_format == \"channels_last\":\n                input_shape[0] = self.height\n                input_shape[1] = self.width\n            else:\n                input_shape[1] = self.height\n                input_shape[2] = self.width\n        return tuple(input_shape)\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"height\": self.height,\n            \"width\": self.width,\n            \"interpolation\": self.interpolation,\n            \"crop_to_aspect_ratio\": self.crop_to_aspect_ratio,\n            \"pad_to_aspect_ratio\": self.pad_to_aspect_ratio,\n            \"fill_mode\": self.fill_mode,\n            \"fill_value\": self.fill_value,\n            \"antialias\": self.antialias,\n            \"data_format\": self.data_format,\n        }\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/resizing_test.py",
    "content": "import grain\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import Sequential\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.testing.test_utils import named_product\n\n\nclass ResizingTest(testing.TestCase):\n    @parameterized.named_parameters(\n        named_product(\n            interpolation=[\"nearest\", \"bilinear\", \"bicubic\", \"lanczos5\"],\n            crop_pad=[(False, False), (True, False), (False, True)],\n            antialias=[False, True],\n            data_format=[\"channels_last\", \"channels_first\"],\n        )\n    )\n    def test_resizing_basics(\n        self,\n        interpolation,\n        crop_pad,\n        antialias,\n        data_format,\n    ):\n        if interpolation == \"lanczos5\" and backend.backend() == \"torch\":\n            self.skipTest(\"Torch does not support lanczos.\")\n\n        crop_to_aspect_ratio, pad_to_aspect_ratio = crop_pad\n        if data_format == \"channels_last\":\n            input_shape = (2, 12, 12, 3)\n            expected_output_shape = (2, 6, 6, 3)\n        else:\n            input_shape = (2, 3, 12, 12)\n            expected_output_shape = (2, 3, 6, 6)\n\n        self.run_layer_test(\n            layers.Resizing,\n            init_kwargs={\n                \"height\": 6,\n                \"width\": 6,\n                \"interpolation\": interpolation,\n                \"crop_to_aspect_ratio\": crop_to_aspect_ratio,\n                \"pad_to_aspect_ratio\": pad_to_aspect_ratio,\n                \"antialias\": antialias,\n                \"data_format\": data_format,\n            },\n            input_shape=input_shape,\n            expected_output_shape=expected_output_shape,\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            
expected_num_losses=0,\n            supports_masking=False,\n            run_training_check=False,\n        )\n\n    @parameterized.parameters([(\"channels_first\",), (\"channels_last\",)])\n    def test_down_sampling_numeric(self, data_format):\n        img = np.reshape(np.arange(0, 16), (1, 4, 4, 1)).astype(np.float32)\n        if data_format == \"channels_first\":\n            img = img.transpose(0, 3, 1, 2)\n        out = layers.Resizing(\n            height=2, width=2, interpolation=\"nearest\", data_format=data_format\n        )(img)\n        ref_out = (\n            np.asarray([[5, 7], [13, 15]])\n            .astype(np.float32)\n            .reshape((1, 2, 2, 1))\n        )\n        if data_format == \"channels_first\":\n            ref_out = ref_out.transpose(0, 3, 1, 2)\n        self.assertAllClose(ref_out, out)\n\n    @parameterized.parameters([(\"channels_first\",), (\"channels_last\",)])\n    def test_up_sampling_numeric(self, data_format):\n        img = np.reshape(np.arange(0, 4), (1, 2, 2, 1)).astype(np.float32)\n        if data_format == \"channels_first\":\n            img = img.transpose(0, 3, 1, 2)\n        out = layers.Resizing(\n            height=4,\n            width=4,\n            interpolation=\"nearest\",\n            data_format=data_format,\n        )(img)\n        ref_out = (\n            np.asarray([[0, 0, 1, 1], [0, 0, 1, 1], [2, 2, 3, 3], [2, 2, 3, 3]])\n            .astype(np.float32)\n            .reshape((1, 4, 4, 1))\n        )\n        if data_format == \"channels_first\":\n            ref_out = ref_out.transpose(0, 3, 1, 2)\n        self.assertAllClose(ref_out, out)\n\n    @parameterized.parameters([(\"channels_first\",), (\"channels_last\",)])\n    def test_crop_to_aspect_ratio(self, data_format):\n        img = np.reshape(np.arange(0, 16), (1, 4, 4, 1)).astype(\"float32\")\n        if data_format == \"channels_first\":\n            img = img.transpose(0, 3, 1, 2)\n        out = layers.Resizing(\n            height=4,\n      
      width=2,\n            interpolation=\"nearest\",\n            data_format=data_format,\n            crop_to_aspect_ratio=True,\n        )(img)\n        ref_out = (\n            np.asarray(\n                [\n                    [1, 2],\n                    [5, 6],\n                    [9, 10],\n                    [13, 14],\n                ]\n            )\n            .astype(\"float32\")\n            .reshape((1, 4, 2, 1))\n        )\n        if data_format == \"channels_first\":\n            ref_out = ref_out.transpose(0, 3, 1, 2)\n        self.assertAllClose(ref_out, out)\n\n    @parameterized.parameters([(\"channels_first\",), (\"channels_last\",)])\n    def test_unbatched_image(self, data_format):\n        img = np.reshape(np.arange(0, 16), (4, 4, 1)).astype(\"float32\")\n        if data_format == \"channels_first\":\n            img = img.transpose(2, 0, 1)\n        out = layers.Resizing(\n            2, 2, interpolation=\"nearest\", data_format=data_format\n        )(img)\n        ref_out = (\n            np.asarray(\n                [\n                    [5, 7],\n                    [13, 15],\n                ]\n            )\n            .astype(\"float32\")\n            .reshape((2, 2, 1))\n        )\n        if data_format == \"channels_first\":\n            ref_out = ref_out.transpose(2, 0, 1)\n        self.assertAllClose(ref_out, out)\n\n    def test_tf_data_compatibility(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 10, 12, 3)\n            output_shape = (2, 8, 9, 3)\n        else:\n            input_shape = (2, 3, 10, 12)\n            output_shape = (2, 3, 8, 9)\n        layer = layers.Resizing(8, 9)\n        input_data = np.random.random(input_shape)\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        output = next(iter(ds)).numpy()\n        self.assertEqual(tuple(output.shape), output_shape)\n\n    def test_grain_compatibility(self):\n       
 if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 10, 12, 3)\n            output_shape = (2, 8, 9, 3)\n        else:\n            input_shape = (2, 3, 10, 12)\n            output_shape = (2, 3, 8, 9)\n        layer = layers.Resizing(8, 9)\n        input_data = np.random.random(input_shape)\n        ds = (\n            grain.MapDataset.source(input_data)\n            .to_iter_dataset()\n            .batch(2)\n            .map(layer)\n        )\n        output = next(iter(ds))\n        output_np = backend.convert_to_numpy(output)\n\n        self.assertEqual(tuple(output_np.shape), output_shape)\n        self.assertTrue(backend.is_tensor(output))\n        # Ensure the device of the data is on CPU.\n        if backend.backend() == \"tensorflow\":\n            self.assertIn(\"CPU\", str(output.device))\n        elif backend.backend() == \"jax\":\n            self.assertIn(\"CPU\", str(output.device))\n        elif backend.backend() == \"torch\":\n            self.assertEqual(\"cpu\", str(output.device))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Sequential + tf.data only works with TF backend\",\n    )\n    def test_tf_data_compatibility_sequential(self):\n        # Test compatibility when wrapping in a Sequential\n        # https://github.com/keras-team/keras/issues/347\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 10, 12, 3)\n            output_shape = (2, 8, 9, 3)\n        else:\n            input_shape = (2, 3, 10, 12)\n            output_shape = (2, 3, 8, 9)\n        layer = layers.Resizing(8, 9)\n        input_data = np.random.random(input_shape)\n        ds = (\n            tf_data.Dataset.from_tensor_slices(input_data)\n            .batch(2)\n            .map(Sequential([layer]))\n        )\n        output = next(iter(ds)).numpy()\n        self.assertEqual(tuple(output.shape), output_shape)\n\n    
@parameterized.parameters(\n        [((15, 10), \"channels_last\"), ((15, 100), \"channels_last\")]\n    )\n    def test_data_stretch(self, size, data_format):\n        img = np.random.rand(1, 1, 4, 4)\n        output = layers.Resizing(\n            size[0], size[1], data_format=data_format, crop_to_aspect_ratio=True\n        )(img)\n        self.assertEqual(output.shape, (1, *size, 4))\n\n    @parameterized.named_parameters(\n        (\n            \"with_pad_to_aspect_ratio\",\n            True,\n            False,\n            [[6.0, 2.0, 10.0, 6.0], [14.0, 8.0, 18.0, 12.0]],\n        ),\n        (\n            \"with_crop_to_aspect_ratio\",\n            False,\n            True,\n            [[5.0, 0.5, 10.0, 5.5], [15.0, 8.0, 20.0, 13.0]],\n        ),\n        (\n            \"boxes_stretch\",\n            False,\n            False,\n            [[5.0, 2.0, 10.0, 6.0], [15.0, 8.0, 20.0, 12.0]],\n        ),\n    )\n    def test_resize_bounding_boxes(\n        self, pad_to_aspect_ratio, crop_to_aspect_ratio, expected_boxes\n    ):\n        if backend.config.image_data_format() == \"channels_last\":\n            image_shape = (10, 8, 3)\n        else:\n            image_shape = (3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [2, 1, 4, 3],\n                    [6, 4, 8, 6],\n                ]\n            ),  # Example boxes in \"xyxy\" pixel coordinates\n            \"labels\": np.array([[1, 2]]),  # Dummy labels\n        }\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n        resizing_layer = layers.Resizing(\n            height=20,\n            width=20,\n            pad_to_aspect_ratio=pad_to_aspect_ratio,\n            crop_to_aspect_ratio=crop_to_aspect_ratio,\n            bounding_box_format=\"xyxy\",\n        )\n        output = resizing_layer(input_data)\n        
self.assertAllClose(output[\"bounding_boxes\"][\"boxes\"], expected_boxes)\n\n    @parameterized.named_parameters(\n        (\n            \"with_pad_to_aspect_ratio\",\n            True,\n            False,\n            [[6.0, 2.0, 10.0, 6.0], [14.0, 8.0, 18.0, 12.0]],\n        ),\n        (\n            \"with_crop_to_aspect_ratio\",\n            False,\n            True,\n            [[5.0, 0.5, 10.0, 5.5], [15.0, 8.0, 20.0, 13.0]],\n        ),\n        (\n            \"boxes_stretch\",\n            False,\n            False,\n            [[5.0, 2.0, 10.0, 6.0], [15.0, 8.0, 20.0, 12.0]],\n        ),\n    )\n    def test_resize_tf_data_bounding_boxes(\n        self, pad_to_aspect_ratio, crop_to_aspect_ratio, expected_boxes\n    ):\n        if backend.config.image_data_format() == \"channels_last\":\n            image_shape = (1, 10, 8, 3)\n        else:\n            image_shape = (1, 3, 10, 8)\n        input_image = np.random.random(image_shape)\n        bounding_boxes = {\n            \"boxes\": np.array(\n                [\n                    [\n                        [2, 1, 4, 3],\n                        [6, 4, 8, 6],\n                    ]\n                ]\n            ),  # Example boxes in \"xyxy\" pixel coordinates\n            \"labels\": np.array([[1, 2]]),  # Dummy labels\n        }\n\n        input_data = {\"images\": input_image, \"bounding_boxes\": bounding_boxes}\n\n        ds = tf_data.Dataset.from_tensor_slices(input_data)\n        resizing_layer = layers.Resizing(\n            height=20,\n            width=20,\n            pad_to_aspect_ratio=pad_to_aspect_ratio,\n            crop_to_aspect_ratio=crop_to_aspect_ratio,\n            bounding_box_format=\"xyxy\",\n        )\n        ds = ds.map(resizing_layer)\n        output = next(iter(ds))\n        expected_boxes = np.array(expected_boxes)\n        self.assertAllClose(output[\"bounding_boxes\"][\"boxes\"], expected_boxes)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/solarization.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\nfrom keras.src.ops.core import _saturate_cast\nfrom keras.src.random.seed_generator import SeedGenerator\n\n\n@keras_export(\"keras.layers.Solarization\")\nclass Solarization(BaseImagePreprocessingLayer):\n    \"\"\"Applies `(max_value - pixel + min_value)` for each pixel in the image.\n\n    When created without a `threshold_factor`, the layer solarizes all pixel\n    values. When `threshold_factor` is specified, the layer only solarizes\n    pixels that are above the sampled threshold value.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        addition_factor: (Optional) A tuple of two floats or a single float,\n            between 0 and 1.\n            For each augmented image a value is\n            sampled from the provided range. If a float is passed, the range is\n            interpreted as `(0, addition_factor)`. If specified, this value\n            (times the value range of input images, e.g. 255), is\n            added to each pixel before solarization and thresholding.\n            Defaults to 0.0.\n        threshold_factor: (Optional) A tuple of two floats or a single float.\n            For each augmented image a value is\n            sampled from the provided range. If a float is passed, the range is\n            interpreted as `(0, threshold_factor)`. If specified, only pixel\n            values above this threshold will be solarized. Defaults to 0.0.\n        value_range: a tuple or a list of two elements. The first value\n            represents the lower bound for values in input images, the second\n            represents the upper bound. Images passed to the layer should have\n            values within `value_range`. 
Typical values to pass\n            are `(0, 255)` (RGB image) or `(0., 1.)` (scaled image).\n        seed: Integer. Used to create a random seed.\n        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.\n\n    Example:\n\n    ```python\n    (images, labels), _ = keras.datasets.cifar10.load_data()\n    print(images[0, 0, 0])\n    # [59 62 63]\n    # Note that images are Tensor with values in the range [0, 255]\n    solarization = Solarization(value_range=(0, 255))\n    images = solarization(images)\n    print(images[0, 0, 0])\n    # [196, 193, 192]\n    ```\n    \"\"\"\n\n    _USE_BASE_FACTOR = False\n    _VALUE_RANGE_VALIDATION_ERROR = (\n        \"The `value_range` argument should be a list of two numbers. \"\n    )\n    _FACTOR_VALIDATION_ERROR = (\n        \"The `addition_factor` and `threshold_factor` arguments \"\n        \"should be a number (or a list of two numbers) \"\n        \"in the range [0, 1]. \"\n    )\n\n    def __init__(\n        self,\n        addition_factor=0.0,\n        threshold_factor=0.0,\n        value_range=(0, 255),\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.seed = seed\n        self.generator = SeedGenerator(seed)\n        self.addition_factor = self._set_factor(\n            addition_factor, \"addition_factor\"\n        )\n        self.threshold_factor = self._set_factor(\n            threshold_factor, \"threshold_factor\"\n        )\n        self._set_value_range(value_range)\n\n    def _set_value_range(self, value_range):\n        if not isinstance(value_range, (tuple, list)):\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        if len(value_range) != 2:\n            raise ValueError(\n                self._VALUE_RANGE_VALIDATION_ERROR\n                + f\"Received: value_range={value_range}\"\n            )\n        self.value_range = 
sorted(value_range)\n\n    def _set_factor(self, factor, factor_name):\n        if isinstance(factor, (tuple, list)):\n            if len(factor) != 2:\n                raise ValueError(\n                    self._FACTOR_VALIDATION_ERROR\n                    + f\"Received: {factor_name}={factor}\"\n                )\n            self._check_factor_range(factor[0])\n            self._check_factor_range(factor[1])\n            lower, upper = sorted(factor)\n        elif isinstance(factor, (int, float)):\n            self._check_factor_range(factor)\n            lower, upper = [0, factor]\n        else:\n            raise ValueError(\n                self._FACTOR_VALIDATION_ERROR\n                + f\"Received: {factor_name}={factor}\"\n            )\n        return lower, upper\n\n    def _check_factor_range(self, input_number):\n        if input_number > 1.0 or input_number < 0:\n            raise ValueError(\n                self._FACTOR_VALIDATION_ERROR\n                + f\"Received: input_number={input_number}\"\n            )\n\n    def get_random_transformation(self, data, training=True, seed=None):\n        if not training:\n            return None\n\n        if isinstance(data, dict):\n            images = data[\"images\"]\n        else:\n            images = data\n        images_shape = self.backend.shape(images)\n        if len(images_shape) == 4:\n            factor_shape = (images_shape[0], 1, 1, 1)\n        else:\n            factor_shape = (1, 1, 1)\n\n        if seed is None:\n            seed = self._get_seed_generator(self.backend._backend)\n\n        return {\n            \"additions\": self.backend.random.uniform(\n                minval=self.addition_factor[0],\n                maxval=self.addition_factor[1] * 255,\n                shape=factor_shape,\n                seed=seed,\n                dtype=self.compute_dtype,\n            ),\n            \"thresholds\": self.backend.random.uniform(\n                minval=self.threshold_factor[0],\n   
             maxval=self.threshold_factor[1] * 255,\n                shape=factor_shape,\n                seed=seed,\n                dtype=self.compute_dtype,\n            ),\n        }\n\n    def transform_images(self, images, transformation, training=True):\n        images = self.backend.cast(images, self.compute_dtype)\n\n        if training:\n            if transformation is None:\n                return images\n\n            thresholds = transformation[\"thresholds\"]\n            additions = transformation[\"additions\"]\n            images = self._transform_value_range(\n                images,\n                original_range=self.value_range,\n                target_range=(0, 255),\n                dtype=self.compute_dtype,\n            )\n            results = images + additions\n            results = self.backend.numpy.clip(results, 0, 255)\n            results = self.backend.numpy.where(\n                results < thresholds, results, 255 - results\n            )\n            results = self._transform_value_range(\n                results,\n                original_range=(0, 255),\n                target_range=self.value_range,\n                dtype=self.compute_dtype,\n            )\n            if results.dtype == images.dtype:\n                return results\n            if backend.is_int_dtype(images.dtype):\n                results = self.backend.numpy.round(results)\n            return _saturate_cast(results, images.dtype, self.backend)\n        return images\n\n    def transform_labels(self, labels, transformation, training=True):\n        return labels\n\n    def transform_bounding_boxes(\n        self, bounding_boxes, transformation, training=True\n    ):\n        return bounding_boxes\n\n    def transform_segmentation_masks(\n        self, segmentation_masks, transformation, training=True\n    ):\n        return segmentation_masks\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            
\"value_range\": self.value_range,\n            \"addition_factor\": self.addition_factor,\n            \"threshold_factor\": self.threshold_factor,\n            \"seed\": self.seed,\n        }\n        return {**base_config, **config}\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n"
  },
  {
    "path": "keras/src/layers/preprocessing/image_preprocessing/solarization_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import random\nfrom keras.src import testing\n\n\nclass SolarizationTest(testing.TestCase):\n    def _test_input_output(self, layer, input_value, expected_value, dtype):\n        input = np.ones(shape=(2, 224, 224, 3), dtype=dtype) * input_value\n        expected_output = ops.clip(\n            (\n                np.ones(shape=(2, 224, 224, 3), dtype=layer.compute_dtype)\n                * expected_value\n            ),\n            0,\n            255,\n        )\n        output = layer(input)\n        self.assertAllClose(output, expected_output)\n\n    @pytest.mark.requires_trainable_backend\n    def test_layer(self):\n        self.run_layer_test(\n            layers.Solarization,\n            init_kwargs={\n                \"addition_factor\": 0.75,\n                \"value_range\": (20, 200),\n                \"threshold_factor\": (0, 1),\n                \"seed\": 1,\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n        )\n\n    @parameterized.named_parameters(\n        (\"0_255\", 0, 255),\n        (\"64_191\", 64, 191),\n        (\"127_128\", 127, 128),\n        (\"191_64\", 191, 64),\n        (\"255_0\", 255, 0),\n    )\n    def test_output_values(self, input_value, expected_value):\n        solarization = layers.Solarization(value_range=(0, 255))\n\n        self._test_input_output(\n            layer=solarization,\n            input_value=input_value,\n            expected_value=expected_value,\n            dtype=\"uint8\",\n        )\n\n    @parameterized.named_parameters(\n        (\"0_0\", 0, 0),\n        (\"191_64\", 191, 64),\n        (\"255_0\", 255, 0),\n    )\n    def test_only_values_above_threshold_are_solarized(\n        self, input_value, output_value\n    ):\n        solarization = 
layers.Solarization(\n            threshold_factor=(128.0 / 255.0, 128.0 / 255.0),\n            value_range=(0, 255),\n        )\n\n        self._test_input_output(\n            layer=solarization,\n            input_value=input_value,\n            expected_value=output_value,\n            dtype=\"uint8\",\n        )\n\n    def test_random_augmentation_applied_per_sample(self):\n        image = random.uniform((16, 16, 3), minval=0, maxval=255)\n        images = ops.stack([image, image])\n        layer = layers.Solarization(\n            value_range=(0, 255), threshold_factor=0.5, addition_factor=0.5\n        )\n        outputs = layer(images)\n        self.assertNotAllClose(outputs[0], outputs[1])\n"
  },
  {
    "path": "keras/src/layers/preprocessing/index_lookup.py",
    "content": "import collections\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.layers.layer import Layer\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils import argument_validation\nfrom keras.src.utils import numerical_utils\nfrom keras.src.utils import tf_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\nclass IndexLookup(Layer):\n    \"\"\"Maps values from a vocabulary to integer indices.\n\n    This layer translates a set of arbitrary hashables into an integer output\n    via a table-based lookup, with optional out-of-vocabulary handling. This is\n    the basis layer for both IntegerLookup and StringLookup; it holds the common\n    logic but is not intended to be exported as part of the Keras API.\n\n    Args:\n        max_tokens: The maximum size of the vocabulary for this layer.\n            If `None`, there is no cap on the size of the vocabulary.\n            Note that this size includes the OOV and mask tokens.\n        num_oov_indices: The number of out-of-vocabulary tokens to use.\n            If this value is more than 1, OOV inputs are hashed to determine\n            their OOV value. If this value is 0,\n            OOV inputs will cause an error when calling the layer.\n        mask_token: A token that represents masked inputs.\n            When `output_mode` is `\"int\"`,\n            the token is included in vocabulary and mapped to index 0.\n            In other output modes, the token will not appear in the vocabulary\n            and instances of the mask token in the input will be dropped.\n            If set to `None`, no mask term will be added.\n        oov_token: Only used when `invert` is `True`.\n            The token to return for OOV indices.\n        vocabulary: Optional. 
Either an array or a string path to a text file.\n            If passing an array, can pass a tuple, list, 1D numpy array,\n            or 1D tensor containing the vocabulary terms.\n            If passing a file path, the file should contain one line per term\n            in the vocabulary. If this argument is set,\n            there is no need to `adapt` the layer.\n        vocabulary_dtype: The dtype of the vocabulary terms.\n            For example, `\"int64\"` or `\"string\"`.\n        idf_weights: Only valid when `output_mode` is `\"tf_idf\"`.\n            A tuple, list, 1D numpy array, or 1D tensor of the same length\n            as the vocabulary, containing the floating point\n            inverse document frequency weights, which will be multiplied\n            by per sample term counts for the final TF-IDF\n            weight. If the `vocabulary` argument is set, and `output_mode`\n            is `\"tf_idf\"`, this argument must be supplied.\n        invert: Only valid when `output_mode` is `\"int\"`.\n            If `True`, this layer will map indices to vocabulary items\n            instead of mapping vocabulary items to indices.\n            Defaults to `False`.\n        output_mode: Specification for the output of the layer. Values can be\n            `\"int\"`, `\"one_hot\"`, `\"multi_hot\"`, `\"count\"`, or `\"tf_idf\"`\n            configuring the layer as follows:\n            - `\"int\"`: Return the raw integer indices of the input tokens.\n            - `\"one_hot\"`: Encodes each individual element in the input into an\n                array the same size as the vocabulary, containing a 1\n                at the element index. 
If the last dimension is size 1,\n                will encode on that dimension.\n                If the last dimension is not size 1,\n                will append a new dimension for the encoded output.\n            - `\"multi_hot\"`: Encodes each sample in the input into\n                a single array the same size as the vocabulary,\n                containing a 1 for each vocabulary term present in the sample.\n                Treats the last dimension as the sample dimension:\n                if input shape is `(..., sample_length)`, output shape will\n                be `(..., num_tokens)`.\n            - `\"count\"`: As `\"multi_hot\"`, but the int array contains a count\n                of the number of times the token at that index appeared\n                in the sample.\n            - `\"tf_idf\"`: As `\"multi_hot\"`, but the TF-IDF algorithm\n                is applied to find the value in each token slot.\n            Defaults to `\"int\"`.\n        pad_to_max_tokens: Only valid when `output_mode` is `\"multi_hot\"`,\n            `\"count\"`, or `\"tf_idf\"`. If `True`, the output will have its\n            feature axis padded to `max_tokens` even if the number\n            of unique tokens in the vocabulary is less than `max_tokens`,\n            resulting in a tensor of shape `(batch_size, max_tokens)`\n            regardless of vocabulary size. Defaults to `False`.\n        sparse: Boolean. Only applicable to `\"one_hot\"`, `\"multi_hot\"`,\n            `\"count\"` and `\"tf_idf\"` output modes.\n            If `True`, returns a `SparseTensor` instead of a dense `Tensor`.\n            Defaults to `False`.\n        oov_method: Only relevant when `num_oov_indices > 1` and the input\n            dtype is integer (i.e. for `IntegerLookup`). 
Controls how\n            Out-Of-Vocabulary (OOV) tokens are assigned to OOV buckets.\n            - `\"floormod\"` (default): uses `token % num_oov_indices`.\n              Preserves backwards compatibility but can produce severe bucket\n              imbalance when input IDs share a common factor with\n              `num_oov_indices` (e.g. all-even IDs with an even bucket count).\n            - `\"farmhash\"`: applies FarmHash64. Distributes OOV tokens\n              uniformly regardless of the arithmetic structure of the input IDs.\n            This parameter is ignored for string inputs, which always use\n            FarmHash64.\n    \"\"\"\n\n    def __init__(\n        self,\n        max_tokens,\n        num_oov_indices,\n        mask_token,\n        oov_token,\n        vocabulary_dtype,\n        vocabulary=None,\n        idf_weights=None,\n        invert=False,\n        output_mode=\"int\",\n        sparse=False,\n        pad_to_max_tokens=False,\n        oov_method=\"floormod\",\n        name=None,\n        **kwargs,\n    ):\n        # If max_tokens is set, the value must be greater than 1 - otherwise we\n        # are creating a 0-element vocab, which doesn't make sense.\n        if max_tokens is not None and max_tokens <= 1:\n            raise ValueError(\n                \"If set, `max_tokens` must be greater than 1. \"\n                f\"Received: max_tokens={max_tokens}\"\n            )\n\n        if pad_to_max_tokens and max_tokens is None:\n            raise ValueError(\n                \"If pad_to_max_tokens is True, must set `max_tokens`. \"\n                f\"Received: max_tokens={max_tokens}\"\n            )\n\n        if num_oov_indices < 0:\n            raise ValueError(\n                \"`num_oov_indices` must be greater than or equal to 0. 
\"\n                f\"Received: num_oov_indices={num_oov_indices}\"\n            )\n\n        argument_validation.validate_string_arg(\n            oov_method,\n            allowable_strings=(\"floormod\", \"farmhash\"),\n            caller_name=self.__class__.__name__,\n            arg_name=\"oov_method\",\n        )\n\n        # Support deprecated names for output_modes.\n        if output_mode == \"binary\":\n            output_mode = \"multi_hot\"\n        if output_mode == \"tf-idf\":\n            output_mode = \"tf_idf\"\n        argument_validation.validate_string_arg(\n            output_mode,\n            allowable_strings=(\n                \"int\",\n                \"one_hot\",\n                \"multi_hot\",\n                \"count\",\n                \"tf_idf\",\n            ),\n            caller_name=self.__class__.__name__,\n            arg_name=\"output_mode\",\n        )\n\n        if invert and output_mode != \"int\":\n            raise ValueError(\n                \"`output_mode` must be `'int'` when `invert` is true. \"\n                f\"Received: output_mode={output_mode}\"\n            )\n\n        if sparse and output_mode == \"int\":\n            raise ValueError(\n                \"`sparse` may only be true if `output_mode` is \"\n                \"`'one_hot'`, `'multi_hot'`, `'count'` or `'tf_idf'`. \"\n                f\"Received: sparse={sparse} and \"\n                f\"output_mode={output_mode}\"\n            )\n\n        if idf_weights is not None and output_mode != \"tf_idf\":\n            raise ValueError(\n                \"`idf_weights` should only be set if `output_mode` is \"\n                f\"`'tf_idf'`. 
Received: idf_weights={idf_weights} and \"\n                f\"output_mode={output_mode}\"\n            )\n\n        super().__init__(name=name)\n        self._convert_input_args = False\n        self._allow_non_tensor_positional_args = True\n        self.supports_jit = False\n\n        self.invert = invert\n        self.max_tokens = max_tokens\n        self.num_oov_indices = num_oov_indices\n        self.mask_token = mask_token\n        self.oov_token = oov_token\n        self.output_mode = output_mode\n        self.sparse = sparse\n        self.pad_to_max_tokens = pad_to_max_tokens\n        self.vocabulary_dtype = tf.as_dtype(vocabulary_dtype).name\n        self.oov_method = oov_method\n        self._frozen_vocab_size = kwargs.pop(\"vocabulary_size\", None)\n\n        # Remember original `vocabulary` as `input_vocabulary` for serialization\n        # via `get_config`. However, if `vocabulary` is a file path or a URL, we\n        # serialize the vocabulary as an asset and clear the original path/URL.\n        self.input_vocabulary = (\n            vocabulary if not isinstance(vocabulary, str) else None\n        )\n        self.input_idf_weights = idf_weights\n\n        # We set this hidden attr to\n        # persist the fact that we have a non-adaptable layer with a\n        # manually set vocab.\n        self._has_input_vocabulary = kwargs.pop(\n            \"has_input_vocabulary\", (vocabulary is not None)\n        )\n        kwargs.pop(\"trainable\", None)\n        kwargs.pop(\"dtype\", None)\n        if kwargs:\n            raise ValueError(f\"Unrecognized keyword argument(s): {kwargs}\")\n\n        if invert:\n            self._key_dtype = \"int64\"\n            self._value_dtype = self.vocabulary_dtype\n            mask_key = 0\n            mask_value = mask_token\n            self._default_value = self.oov_token\n        else:\n            self._key_dtype = self.vocabulary_dtype\n            self._value_dtype = \"int64\"\n            mask_key = 
mask_token\n            # Masks should map to 0 for int output and be dropped otherwise. Max\n            # ints will be dropped from the bincount op.\n            mask_value = (\n                0\n                if self.output_mode == \"int\"\n                else tf.as_dtype(self._value_dtype).max\n            )\n            if self.num_oov_indices == 0:\n                # If there are no OOV indices, we map OOV tokens to -1 and error\n                # out during call if we find a negative index.\n                self._default_value = -1\n            elif self.num_oov_indices == 1:\n                # If there is only one OOV index, we can set that index as the\n                # default value of the index_lookup table.\n                self._default_value = self._oov_start_index()\n            else:\n                # If we have multiple OOV values, we need to do a further\n                # hashing step; to make this easier, we set the OOV value to -1.\n                # (This lets us do a vectorized add and cast to boolean to\n                # determine locations where we need to do extra hashing.)\n                self._default_value = -1\n        if self.mask_token is not None:\n            self._mask_key = tf.convert_to_tensor(mask_key, self._key_dtype)\n            self._mask_value = tf.convert_to_tensor(\n                mask_value, self._value_dtype\n            )\n\n        if self.output_mode == \"tf_idf\":\n            if self._has_input_vocabulary and idf_weights is None:\n                raise ValueError(\n                    \"When specifying the `vocabulary` argument, \"\n                    \"in TF-IDF output mode, the `idf_weights` argument \"\n                    \"must also be provided.\"\n                )\n            if idf_weights is not None:\n                self.idf_weights = tf.Variable(\n                    idf_weights,\n                    dtype=backend.floatx(),\n                    trainable=False,\n                )\n           
     self.idf_weights_const = self.idf_weights.value()\n\n        if vocabulary is not None:\n            self.set_vocabulary(vocabulary, idf_weights)\n        else:\n            # When restoring from a keras SavedModel, the loading code will\n            # expect to find and restore a lookup_table attribute on the layer.\n            # This table needs to be uninitialized as a StaticHashTable cannot\n            # be initialized twice.\n            self.lookup_table = self._uninitialized_lookup_table()\n\n        # Only set up adapt state if we did not receive a vocab on construction.\n        if not self._has_input_vocabulary:\n            # Set adapt state.\n            self.token_counts = tf.lookup.experimental.MutableHashTable(\n                key_dtype=vocabulary_dtype,\n                value_dtype=\"int64\",\n                default_value=0,\n            )\n            if self.output_mode == \"tf_idf\":\n                self.token_document_counts = (\n                    tf.lookup.experimental.MutableHashTable(\n                        key_dtype=vocabulary_dtype,\n                        value_dtype=\"int64\",\n                        default_value=0,\n                    )\n                )\n                self.num_documents = tf.Variable(\n                    0, dtype=\"int64\", trainable=False\n                )\n\n    def get_vocabulary(self, include_special_tokens=True):\n        \"\"\"Returns the current vocabulary of the layer.\n\n        Args:\n            include_special_tokens: If `True`, the returned vocabulary\n                will include mask and OOV tokens,\n                and a term's index in the vocabulary\n                will equal the term's index when calling the layer.\n                If `False`, the returned vocabulary will not include\n                any mask or OOV tokens.\n        \"\"\"\n        # The lookup table data will not be sorted, so we will create an\n        # inverted lookup here, and use that to look up a range of 
indices\n        # [0, vocab_size).\n        if self.lookup_table.size() == 0:\n            vocab, indices = [], []\n        else:\n            keys, values = self.lookup_table.export()\n            vocab, indices = (values, keys) if self.invert else (keys, values)\n            vocab, indices = (\n                self._tensor_vocab_to_numpy(vocab),\n                indices.numpy(),\n            )\n        lookup = collections.defaultdict(\n            lambda: self.oov_token, zip(indices, vocab)\n        )\n        vocab = [lookup[x] for x in range(self.vocabulary_size())]\n        if self.mask_token is not None and self.output_mode == \"int\":\n            vocab[0] = self.mask_token\n        if not include_special_tokens:\n            vocab = vocab[self._token_start_index() :]\n        if self.vocabulary_dtype == \"string\":\n            return [\n                i.decode(\"utf-8\") if isinstance(i, bytes) else i for i in vocab\n            ]\n        else:\n            return vocab\n\n    def vocabulary_size(self):\n        \"\"\"Gets the current size of the layer's vocabulary.\n\n        Returns:\n          The integer size of the vocabulary, including optional mask and OOV\n          indices.\n        \"\"\"\n        if tf.executing_eagerly():\n            return (\n                int(self.lookup_table.size().numpy())\n                + self._token_start_index()\n            )\n        else:\n            return self.lookup_table.size() + self._token_start_index()\n\n    def get_config(self):\n        config = {\n            \"invert\": self.invert,\n            \"max_tokens\": self.max_tokens,\n            \"num_oov_indices\": self.num_oov_indices,\n            \"oov_token\": self.oov_token,\n            \"mask_token\": self.mask_token,\n            \"output_mode\": self.output_mode,\n            \"sparse\": self.sparse,\n            \"pad_to_max_tokens\": self.pad_to_max_tokens,\n            \"vocabulary_dtype\": self.vocabulary_dtype,\n            
\"idf_weights\": listify_tensors(self.input_idf_weights),\n            \"vocabulary\": listify_tensors(self.input_vocabulary),\n            \"vocabulary_size\": self._frozen_vocab_size,\n            \"oov_method\": self.oov_method,\n        }\n        base_config = super().get_config()\n        return dict(list(base_config.items()) + list(config.items()))\n\n    def _record_vocabulary_size(self):\n        self._ensure_vocab_size_unchanged()\n        with tf.init_scope():\n            self._frozen_vocab_size = self.vocabulary_size()\n\n    def set_vocabulary(self, vocabulary, idf_weights=None):\n        \"\"\"Sets vocabulary (and optionally document frequency) for this layer.\n\n        This method sets the vocabulary and idf weights for this layer directly,\n        instead of analyzing a dataset through `adapt`. It should be used\n        whenever the vocab (and optionally document frequency) information is\n        already known. If vocabulary data is already present in the layer, this\n        method will replace it.\n\n        Args:\n            vocabulary: Either an array or a string path to a text file.\n                If passing an array, can pass a tuple, list,\n                1D numpy array, or 1D tensor containing the vocabulary terms.\n                If passing a file path, the file should contain one line\n                per term in the vocabulary.\n            idf_weights: A tuple, list, 1D numpy array, or 1D tensor\n                of inverse document frequency weights of the same\n                length as the vocabulary. Must be set if `output_mode`\n                is `\"tf_idf\"`. 
Should not be set otherwise.\n        \"\"\"\n        if self.output_mode == \"tf_idf\":\n            if idf_weights is None:\n                raise ValueError(\n                    \"`idf_weights` must be set if output_mode is 'tf_idf'.\"\n                )\n        elif idf_weights is not None:\n            raise ValueError(\n                \"`idf_weights` should only be set if output_mode is \"\n                f\"`'tf_idf'`. Received: output_mode={self.output_mode} \"\n                f\"and idf_weights={idf_weights}\"\n            )\n\n        if isinstance(vocabulary, str):\n            if serialization_lib.in_safe_mode():\n                raise ValueError(\n                    \"Requested the loading of a vocabulary file outside of the \"\n                    \"model archive. This carries a potential risk of loading \"\n                    \"arbitrary and sensitive files and thus it is disallowed \"\n                    \"by default. If you trust the source of the artifact, you \"\n                    \"can override this error by passing `safe_mode=False` to \"\n                    \"the loading function, or calling \"\n                    \"`keras.config.enable_unsafe_deserialization()`. 
\"\n                    f\"Vocabulary file: '{vocabulary}'\"\n                )\n\n            if not tf.io.gfile.exists(vocabulary):\n                raise ValueError(\n                    f\"Vocabulary file {vocabulary} does not exist.\"\n                )\n            if self.output_mode == \"tf_idf\":\n                raise ValueError(\n                    \"output_mode `'tf_idf'` does not support loading a \"\n                    \"vocabulary from file.\"\n                )\n            self.lookup_table = self._lookup_table_from_file(vocabulary)\n            self._record_vocabulary_size()\n            return\n\n        if not tf.executing_eagerly() and (\n            tf.is_tensor(vocabulary) or tf.is_tensor(idf_weights)\n        ):\n            raise RuntimeError(\n                f\"Cannot set a tensor vocabulary on layer {self.name} \"\n                \"when not executing eagerly. \"\n                \"Create this layer or call `set_vocabulary()` \"\n                \"outside of any traced function.\"\n            )\n\n        # TODO(mattdangerw): for better performance we should rewrite this\n        # entire function to operate on tensors and convert vocabulary to a\n        # tensor here.\n        if tf.is_tensor(vocabulary):\n            vocabulary = self._tensor_vocab_to_numpy(vocabulary)\n        elif isinstance(vocabulary, (list, tuple)):\n            vocabulary = np.array(vocabulary)\n        if tf.is_tensor(idf_weights):\n            idf_weights = idf_weights.numpy()\n        elif isinstance(idf_weights, (list, tuple)):\n            idf_weights = np.array(idf_weights)\n\n        if vocabulary.size == 0:\n            raise ValueError(\n                \"Cannot set an empty vocabulary. 
\"\n                f\"Received: vocabulary={vocabulary}\"\n            )\n\n        oov_start = self._oov_start_index()\n        token_start = self._token_start_index()\n        special_tokens = [self.mask_token] * oov_start + [\n            self.oov_token\n        ] * self.num_oov_indices\n        found_special_tokens = np.array_equal(\n            special_tokens, vocabulary[:token_start]\n        )\n        if found_special_tokens:\n            tokens = vocabulary[token_start:]\n        else:\n            tokens = vocabulary\n\n        repeated_tokens = self._find_repeated_tokens(tokens)\n        if repeated_tokens:\n            raise ValueError(\n                \"The passed vocabulary has at least one repeated \"\n                \"term. Please uniquify your dataset. The repeated terms \"\n                f\"are: {repeated_tokens}\"\n            )\n\n        if self.mask_token is not None and self.mask_token in tokens:\n            mask_index = np.argwhere(vocabulary == self.mask_token)[-1]\n            raise ValueError(\n                \"Found reserved mask token at unexpected location in \"\n                \"`vocabulary`. Note that passed `vocabulary` does not need to \"\n                \"include the OOV and mask tokens. Either remove all mask and \"\n                \"OOV tokens, or include them only at the start of the \"\n                f\"vocabulary in precisely this order: {special_tokens}. \"\n                f\"Received: mask_token={self.mask_token} at \"\n                f\"vocabulary index {mask_index}\"\n            )\n        # Only error out for oov_token when invert=True. 
When invert=False,\n        # oov_token is unused during lookup.\n        if (\n            self.oov_token is not None\n            and self.invert\n            and self.oov_token in tokens\n        ):\n            oov_index = np.argwhere(vocabulary == self.oov_token)[-1]\n            raise ValueError(\n                \"Found reserved OOV token at unexpected location in \"\n                \"`vocabulary`. Note that passed `vocabulary` does not need to \"\n                \"include the OOV and mask tokens. Either remove all mask and \"\n                \"OOV tokens, or include them only at the start of the \"\n                f\"vocabulary in precisely this order: {special_tokens}. \"\n                f\"Received: oov_token={self.oov_token} at \"\n                f\"vocabulary index {oov_index}\"\n            )\n\n        new_vocab_size = token_start + len(tokens)\n        if self.max_tokens is not None and (new_vocab_size > self.max_tokens):\n            raise ValueError(\n                \"Attempted to set a vocabulary larger than the maximum vocab \"\n                f\"size. Received vocabulary size is {new_vocab_size}; \"\n                f\"`max_tokens` is {self.max_tokens}.\"\n            )\n        self.lookup_table = self._lookup_table_from_tokens(tokens)\n        self._record_vocabulary_size()\n\n        if self.output_mode == \"tf_idf\" and idf_weights is not None:\n            if len(vocabulary) != len(idf_weights):\n                raise ValueError(\n                    \"`idf_weights` must be the same length as vocabulary. \"\n                    f\"len(idf_weights) is {len(idf_weights)}; \"\n                    f\"len(vocabulary) is {len(vocabulary)}\"\n                )\n            idf_weights = self._convert_to_ndarray(idf_weights)\n            if idf_weights.ndim != 1:\n                raise ValueError(\n                    \"TF-IDF data must be a 1-dimensional array. 
\"\n                    f\"Received: type(idf_weights)={type(idf_weights)}\"\n                )\n\n            # If the passed vocabulary has no special tokens, we need to pad the\n            # front of idf_weights. We don't have real document frequencies for\n            # these tokens so we will use an average of all idf_weights passed\n            # in as a reasonable default.\n            if found_special_tokens:\n                front_padding = 0\n                front_padding_value = 0\n            else:\n                front_padding = token_start\n                front_padding_value = np.average(idf_weights)\n            # If pad_to_max_tokens is true, and max_tokens is greater than our\n            # total vocab size, we need to pad the back of idf_weights with\n            # zeros as well.\n            back_padding_value = 0\n            if self.pad_to_max_tokens and self.max_tokens is not None:\n                back_padding = (\n                    self.max_tokens - front_padding - len(idf_weights)\n                )\n            else:\n                back_padding = 0\n            weights = np.pad(\n                idf_weights,\n                (front_padding, back_padding),\n                \"constant\",\n                constant_values=(front_padding_value, back_padding_value),\n            )\n            weights = tf.convert_to_tensor(weights, dtype=backend.floatx())\n            self.idf_weights = tf.Variable(\n                weights,\n                trainable=False,\n            )\n            self.idf_weights_const = self.idf_weights.value()\n\n    def get_build_config(self):\n        return {}\n\n    def build_from_config(self, config):\n        self.build(None)\n\n    @property\n    def compute_dtype(self):\n        return self.vocabulary_dtype\n\n    @property\n    def variable_dtype(self):\n        return self.vocabulary_dtype\n\n    def compute_output_shape(self, input_shape):\n        if self.output_mode == \"int\":\n            return 
input_shape\n        depth = (\n            self.max_tokens\n            if self.pad_to_max_tokens\n            else self._frozen_vocab_size\n        )\n        input_shape = tuple(input_shape)\n        if self.output_mode == \"one_hot\":\n            # One-hot encodes each element: (batch, d1, ..., dN) -> (batch, d1,\n            # ..., dN, depth)\n            if len(input_shape) > 1 and input_shape[-1] == 1:\n                return input_shape[:-1] + (depth,)\n            return input_shape + (depth,)\n        # multi_hot, count, tf_idf: treat last dim as sample dim, output\n        # (batch, ..., depth)\n        return input_shape[:-1] + (depth,)\n\n    def compute_output_spec(self, inputs):\n        if self.output_mode == \"int\":\n            output_dtype = \"int64\"\n        else:\n            output_dtype = backend.floatx()\n        output_shape = self.compute_output_shape(inputs.shape)\n        return backend.KerasTensor(output_shape, dtype=output_dtype)\n\n    def adapt(self, data, steps=None):\n        self.reset_state()\n        if isinstance(data, tf.data.Dataset):\n            if steps is not None:\n                data = data.take(steps)\n            for batch in data:\n                self.update_state(batch)\n        else:\n            data = tf_utils.ensure_tensor(data, dtype=self.vocabulary_dtype)\n            if data.shape.rank == 1:\n                # A plain list of strings\n                # is treated as that many documents\n                data = tf.expand_dims(data, -1)\n            self.update_state(data)\n        self.finalize_state()\n\n    def update_state(self, data):\n        if self._has_input_vocabulary:\n            raise ValueError(\n                f\"Cannot adapt layer '{self.name}' after setting a static \"\n                \"vocabulary via `vocabulary` argument or \"\n                \"`set_vocabulary()` method.\"\n            )\n\n        data = tf_utils.ensure_tensor(data, dtype=self.vocabulary_dtype)\n        if 
data.shape.rank == 0:\n            data = tf.expand_dims(data, 0)\n        if data.shape.rank == 1:\n            # Expand dims on axis 0 for tf-idf. A 1-d tensor\n            # is a single document.\n            data = tf.expand_dims(data, 0)\n\n        tokens, counts = self._num_tokens(data)\n        self.token_counts.insert(\n            tokens, counts + self.token_counts.lookup(tokens)\n        )\n\n        if self.output_mode == \"tf_idf\":\n            # Dedupe each row of our dataset.\n            if isinstance(data, tf.RaggedTensor):\n                deduped_doc_data = tf.map_fn(lambda x: tf.unique(x)[0], data)\n            else:\n                deduped_doc_data = [tf.unique(x)[0] for x in data]\n                deduped_doc_data = tf.concat(deduped_doc_data, axis=0)\n            # Flatten and count tokens.\n            tokens, counts = self._num_tokens(deduped_doc_data)\n\n            self.token_document_counts.insert(\n                tokens, counts + self.token_document_counts.lookup(tokens)\n            )\n            if isinstance(data, tf.RaggedTensor):\n                self.num_documents.assign_add(data.nrows())\n            else:\n                self.num_documents.assign_add(\n                    tf.shape(data, out_type=\"int64\")[0]\n                )\n\n    def finalize_state(self):\n        if self._has_input_vocabulary or tf.equal(self.token_counts.size(), 0):\n            # Finalize idf_weights to a const for call even if we don't need to\n            # compute a new vocabulary.\n            if self.output_mode == \"tf_idf\":\n                self.idf_weights_const = self.idf_weights.value()\n            self._record_vocabulary_size()\n            return\n\n        # Remove special tokens from our counts.\n        if self.mask_token is not None:\n            self.token_counts.remove(\n                tf.convert_to_tensor([self.mask_token], self.vocabulary_dtype)\n            )\n        if self.oov_token is not None:\n            
self.token_counts.remove(\n                tf.convert_to_tensor([self.oov_token], self.vocabulary_dtype)\n            )\n\n        tokens, counts = self.token_counts.export()\n        # To keep vocabs deterministic, we sort our tokens by count and break\n        # ties by sorting the tokens themselves. Tensorflow has no ops for\n        # sorting strings, so we need to use numpy for the sort.\n        sorted_indices = np.lexsort((tokens.numpy(), counts.numpy()))[::-1]\n        token_start = self._token_start_index()\n        if self.max_tokens:\n            max_learned_tokens = self.max_tokens - token_start\n            sorted_indices = sorted_indices[:max_learned_tokens]\n        tokens = tf.gather(tokens, sorted_indices)\n        self.lookup_table = self._lookup_table_from_tokens(tokens)\n\n        if self.output_mode == \"tf_idf\":\n            token_document_counts = self.token_document_counts.lookup(tokens)\n            idf_weights = self._inverse_document_frequency(\n                token_document_counts, self.num_documents\n            )\n            idf_weights = tf.cast(idf_weights, backend.floatx())\n            # Pad the front of idf_weights with the average idf weight for OOV\n            # tokens.  
We cannot compute the real idf weight of OOV in a single\n            # pass.\n            idf_weights = tf.pad(\n                idf_weights,\n                [[self._token_start_index(), 0]],\n                constant_values=tf.reduce_mean(idf_weights),\n            )\n            if self.pad_to_max_tokens and self.max_tokens is not None:\n                # Pad the back of idf_weights with zeros.\n                idf_weights = tf.pad(\n                    idf_weights,\n                    [[0, self.max_tokens - tf.size(idf_weights)]],\n                    constant_values=0,\n                )\n            self.idf_weights = tf.Variable(\n                idf_weights,\n                dtype=backend.floatx(),\n                trainable=False,\n            )\n            self.idf_weights_const = self.idf_weights.value()\n\n        # We call this here to save memory, now that we've built our vocabulary,\n        # we don't want to keep every token we've seen in separate lookup\n        # tables.\n        self.reset_state()\n        self._record_vocabulary_size()\n\n    def reset_state(self):\n        if self._has_input_vocabulary:\n            return\n\n        self.token_counts.remove(self.token_counts.export()[0])\n        if self.output_mode == \"tf_idf\":\n            self.token_document_counts.remove(\n                self.token_document_counts.export()[0]\n            )\n            self.num_documents.assign(0)\n\n    def call(self, inputs):\n        from keras.src.backend import tensorflow as tf_backend\n\n        self._ensure_known_vocab_size()\n\n        inputs = tf_utils.ensure_tensor(inputs, dtype=self._key_dtype)\n        original_shape = inputs.shape\n        # Some ops will not handle scalar input, so uprank to rank 1.\n        if inputs.shape.rank == 0:\n            inputs = self._expand_dims(inputs, -1)\n\n        if isinstance(inputs, tf.SparseTensor):\n            lookups = tf.SparseTensor(\n                inputs.indices,\n                
self._lookup_dense(inputs.values),\n                inputs.dense_shape,\n            )\n        elif isinstance(inputs, tf.RaggedTensor):\n            lookups = tf.ragged.map_flat_values(self._lookup_dense, inputs)\n        else:\n            lookups = self._lookup_dense(inputs)\n\n        if self.output_mode == \"int\":\n            # If we received a scalar input, downrank back to a scalar.\n            if original_shape.rank == 0:\n                lookups = tf.squeeze(lookups, -1)\n            return lookups\n\n        depth = (\n            self.max_tokens\n            if self.pad_to_max_tokens\n            else self._frozen_vocab_size\n        )\n        idf_weights = (\n            self.idf_weights_const if self.output_mode == \"tf_idf\" else None\n        )\n        output = numerical_utils.encode_categorical_inputs(\n            lookups,\n            output_mode=(\n                \"count\" if self.output_mode == \"tf_idf\" else self.output_mode\n            ),\n            depth=depth,\n            dtype=self._value_dtype,\n            sparse=self.sparse,\n            backend_module=tf_backend,\n        )\n        if self.output_mode == \"tf_idf\":\n            if idf_weights is None:\n                raise ValueError(\n                    \"When `output_mode` is `'tf_idf'`, `idf_weights` must be \"\n                    \"provided.\"\n                )\n            output = tf_backend.numpy.multiply(\n                tf_backend.core.cast(output, idf_weights.dtype), idf_weights\n            )\n        return output\n\n    def _lookup_dense(self, inputs):\n        \"\"\"Lookup table values for a dense Tensor, handling masking and OOV.\"\"\"\n        # When executing eagerly and tracing keras.Input objects,\n        # do not call lookup.\n        # This is critical for restoring SavedModel, which will first trace\n        # layer.call and then attempt to restore the table. 
We need the table to\n        # be uninitialized for the restore to work, but calling the table\n        # uninitialized would error.\n        if tf.executing_eagerly() and backend.is_keras_tensor(inputs):\n            lookups = tf.zeros_like(inputs, dtype=self._value_dtype)\n        else:\n            lookups = self.lookup_table.lookup(inputs)\n\n        if self.mask_token is not None:\n            mask_locations = tf.equal(inputs, self._mask_key)\n            lookups = tf.where(mask_locations, self._mask_value, lookups)\n\n        if self.invert:\n            return lookups\n\n        lookup_checks = []\n\n        if self.num_oov_indices == 0:\n            # If we have zero oov indices, we need to check for oov inputs.\n            oov_indices = tf.where(tf.equal(lookups, -1))\n            oov_inputs = tf.gather_nd(inputs, oov_indices)\n            msg = tf.strings.format(\n                \"When `num_oov_indices=0` all inputs should be in vocabulary, \"\n                \"found OOV values {}, consider setting `num_oov_indices=1`.\",\n                (oov_inputs,),\n            )\n            assertion = tf.Assert(tf.equal(tf.size(oov_indices), 0), [msg])\n            lookup_checks.append(assertion)\n        elif self.num_oov_indices > 1:\n            if tf.as_dtype(self._key_dtype).is_integer:\n                if self.oov_method == \"farmhash\":\n                    # Cast int to string so we can apply FarmHash64\n                    oov_indices = tf.strings.to_hash_bucket_fast(\n                        tf.strings.as_string(inputs),\n                        num_buckets=self.num_oov_indices,\n                    )\n                else:\n                    # Default: backwards-compatible floormod behaviour.\n                    oov_indices = tf.math.floormod(inputs, self.num_oov_indices)\n            else:\n                oov_indices = tf.strings.to_hash_bucket_fast(\n                    inputs, num_buckets=self.num_oov_indices\n                )\n            
oov_indices = oov_indices + self._oov_start_index()\n            oov_locations = tf.equal(lookups, self._default_value)\n            lookups = tf.where(oov_locations, oov_indices, lookups)\n\n        with tf.control_dependencies(lookup_checks):\n            return tf.identity(lookups)\n\n    def save_own_variables(self, store):\n        if self.output_mode == \"tf_idf\":\n            store[\"idf_weights\"] = self.idf_weights_const.numpy()\n\n    def load_own_variables(self, store):\n        if self.output_mode == \"tf_idf\":\n            idf_weights = store[\"idf_weights\"]\n            if hasattr(self, \"idf_weights\"):\n                self.idf_weights.assign(idf_weights)\n            else:\n                self.idf_weights = tf.Variable(idf_weights, trainable=False)\n            self.idf_weights_const = self.idf_weights.value()\n\n    def save_assets(self, dir_path):\n        if self.input_vocabulary is not None:\n            # Vocab saved in config.\n            # TODO: consider unifying both paths.\n            return\n        vocabulary = self.get_vocabulary(include_special_tokens=True)\n        vocabulary_filepath = tf.io.gfile.join(dir_path, \"vocabulary.txt\")\n        with open(vocabulary_filepath, \"w\") as f:\n            f.write(\"\\n\".join([str(w) for w in vocabulary]))\n\n    def load_assets(self, dir_path):\n        if self.input_vocabulary is not None:\n            # Vocab saved in config.\n            # TODO: consider unifying both paths.\n            return\n        vocabulary_filepath = tf.io.gfile.join(dir_path, \"vocabulary.txt\")\n        with open(vocabulary_filepath, \"r\") as f:\n            lines = f.read().splitlines()\n            while lines and lines[-1] == \"\":\n                lines.pop()\n            if tf.as_dtype(self.vocabulary_dtype) == tf.string:\n                values = [str(line) for line in lines]\n            else:\n                values = [int(line) for line in lines]\n            if self.output_mode == \"tf_idf\":\n  
              idf_weights = self.idf_weights_const.numpy()\n                self.set_vocabulary(values, idf_weights=idf_weights)\n            else:\n                self.set_vocabulary(values)\n\n    def _uninitialized_lookup_table(self):\n        with tf.init_scope():\n            initializer = get_null_initializer(\n                self._key_dtype, self._value_dtype\n            )\n            return tf.lookup.StaticHashTable(initializer, self._default_value)\n\n    def _lookup_table_from_tokens(self, tokens):\n        with tf.init_scope():\n            token_start = self._token_start_index()\n            token_end = token_start + tf.size(tokens)\n            indices_dtype = (\n                self._key_dtype if self.invert else self._value_dtype\n            )\n            indices = tf.range(token_start, token_end, dtype=indices_dtype)\n            keys, values = (\n                (indices, tokens) if self.invert else (tokens, indices)\n            )\n            initializer = tf.lookup.KeyValueTensorInitializer(\n                keys, values, self._key_dtype, self._value_dtype\n            )\n            return tf.lookup.StaticHashTable(initializer, self._default_value)\n\n    def _lookup_table_from_file(self, filename):\n        if self.invert:\n            key_index = tf.lookup.TextFileIndex.LINE_NUMBER\n            value_index = tf.lookup.TextFileIndex.WHOLE_LINE\n        else:\n            key_index = tf.lookup.TextFileIndex.WHOLE_LINE\n            value_index = tf.lookup.TextFileIndex.LINE_NUMBER\n        with tf.init_scope():\n            initializer = tf.lookup.TextFileInitializer(\n                filename=filename,\n                key_dtype=self._key_dtype,\n                key_index=key_index,\n                value_dtype=self._value_dtype,\n                value_index=value_index,\n                value_index_offset=self._token_start_index(),\n            )\n            return tf.lookup.StaticHashTable(initializer, self._default_value)\n\n    def 
_convert_to_ndarray(self, x):\n        return np.array(x) if isinstance(x, (list, tuple)) else x\n\n    def _expand_dims(self, inputs, axis):\n        if isinstance(inputs, tf.SparseTensor):\n            return tf.sparse.expand_dims(inputs, axis)\n        else:\n            return tf.expand_dims(inputs, axis)\n\n    def _oov_start_index(self):\n        return (\n            1\n            if self.mask_token is not None and self.output_mode == \"int\"\n            else 0\n        )\n\n    def _token_start_index(self):\n        return self._oov_start_index() + self.num_oov_indices\n\n    def _ensure_known_vocab_size(self):\n        if self.output_mode == \"int\" or self.pad_to_max_tokens:\n            return\n        if self._frozen_vocab_size is None:\n            raise RuntimeError(\n                f\"When using `output_mode={self.output_mode}` \"\n                \"and `pad_to_max_tokens=False`, \"\n                \"you must set the layer's vocabulary before calling it. Either \"\n                \"pass a `vocabulary` argument to the layer, or call `adapt` \"\n                \"with some sample data.\"\n            )\n\n    def _ensure_vocab_size_unchanged(self):\n        if self.output_mode == \"int\" or self.pad_to_max_tokens:\n            return\n\n        with tf.init_scope():\n            new_vocab_size = self.vocabulary_size()\n\n        if (\n            self._frozen_vocab_size is not None\n            and new_vocab_size != self._frozen_vocab_size\n        ):\n            raise RuntimeError(\n                f\"When using `output_mode={self.output_mode}` \"\n                \"and `pad_to_max_tokens=False`, \"\n                \"the vocabulary size cannot be changed after the layer is \"\n                f\"called. 
Old vocab size is {self._frozen_vocab_size}, \"\n                f\"new vocab size is {new_vocab_size}\"\n            )\n\n    def _find_repeated_tokens(self, vocabulary):\n        \"\"\"Return all repeated tokens in a vocabulary.\"\"\"\n        vocabulary_set = set(vocabulary)\n        if len(vocabulary) != len(vocabulary_set):\n            return [\n                item\n                for item, count in collections.Counter(vocabulary).items()\n                if count > 1\n            ]\n        else:\n            return []\n\n    def _num_tokens(self, data):\n        \"\"\"Count the number of tokens in a ragged, sparse or dense tensor.\"\"\"\n        if isinstance(data, tf.SparseTensor):\n            flat_values = data.values\n        elif isinstance(data, tf.RaggedTensor):\n            flat_values = data.flat_values\n        else:\n            flat_values = tf.reshape(data, [-1])\n        tokens, _, counts = tf.unique_with_counts(flat_values, out_idx=\"int64\")\n        return tokens, counts\n\n    def _inverse_document_frequency(self, token_document_counts, num_documents):\n        \"\"\"Computes the inverse-document-frequency (IDF) component of \"tf_idf\".\n\n        Args:\n            token_document_counts: An array of the # of documents each token\n                appears in.\n            num_documents: An int representing the total number of documents.\n\n        Returns:\n            An array of \"inverse document frequency\" weights.\n        \"\"\"\n        return tf.math.log(1 + num_documents / (1 + token_document_counts))\n\n    # Override points for IntegerLookup and StringLookup.\n    def _tensor_vocab_to_numpy(self, vocabulary):\n        \"\"\"Converts a tensor vocabulary to a numpy vocabulary.\"\"\"\n        return vocabulary.numpy()\n\n\ndef get_null_initializer(key_dtype, value_dtype):\n    class NullInitializer(tf.lookup.KeyValueTensorInitializer):\n        \"\"\"A placeholder initializer for restoring from a SavedModel.\"\"\"\n\n        def 
__init__(self, key_dtype, value_dtype):\n            \"\"\"Construct a table initializer object.\n\n            Args:\n            key_dtype: Type of the table keys.\n            value_dtype: Type of the table values.\n            \"\"\"\n            self._key_dtype = key_dtype\n            self._value_dtype = value_dtype\n\n        @property\n        def key_dtype(self):\n            \"\"\"The expected table key dtype.\"\"\"\n            return self._key_dtype\n\n        @property\n        def value_dtype(self):\n            \"\"\"The expected table value dtype.\"\"\"\n            return self._value_dtype\n\n        def initialize(self, table):\n            \"\"\"Returns the table initialization op.\"\"\"\n            pass\n\n    return NullInitializer(key_dtype, value_dtype)\n\n\ndef listify_tensors(x):\n    \"\"\"Convert any tensors or numpy arrays to lists for config serialization.\"\"\"\n    if tf.is_tensor(x):\n        x = x.numpy()\n    if isinstance(x, np.ndarray):\n        x = x.tolist()\n    return x\n"
  },
  {
    "path": "keras/src/layers/preprocessing/index_lookup_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src.saving import saving_api\n\n\n@pytest.mark.skipif(\n    backend.backend() == \"numpy\", reason=\"Failing for numpy backend.\"\n)\nclass IndexLookupLayerTest(testing.TestCase):\n    def test_basics_string_vocab(self):\n        # Case: adapt + list inputs\n        adapt_data = [\"one\", \"one\", \"one\", \"two\", \"two\", \"three\"]\n        input_data = [\"one\", \"two\", \"four\"]\n        kwargs = {\n            \"max_tokens\": 7,\n            \"num_oov_indices\": 1,\n            \"mask_token\": \"\",\n            \"oov_token\": \"[OOV]\",\n            \"vocabulary_dtype\": \"string\",\n        }\n        layer = layers.IndexLookup(**kwargs)\n        layer.adapt(adapt_data)\n        self.assertEqual(\n            layer.get_vocabulary(), [\"\", \"[OOV]\", \"one\", \"two\", \"three\"]\n        )\n        self.assertEqual(\n            layer.get_vocabulary(include_special_tokens=False),\n            [\"one\", \"two\", \"three\"],\n        )\n        output = layer(input_data)\n        self.assertEqual(list(output), [2, 3, 1])\n        if backend.backend() != \"torch\":\n            self.run_class_serialization_test(layer)\n\n        # Case: numpy array input\n        output = layer(np.array(input_data))\n        self.assertEqual(list(output), [2, 3, 1])\n\n        # Case: fixed vocab + list inputs\n        vocabulary = [\"one\", \"two\", \"three\"]\n        layer = layers.IndexLookup(vocabulary=vocabulary, **kwargs)\n        self.assertEqual(\n            layer.get_vocabulary(), [\"\", \"[OOV]\", \"one\", \"two\", \"three\"]\n        )\n        self.assertEqual(\n            layer.get_vocabulary(include_special_tokens=False),\n            [\"one\", \"two\", \"three\"],\n        )\n        output = layer(input_data)\n        
self.assertEqual(list(output), [2, 3, 1])\n        if backend.backend() != \"torch\":\n            self.run_class_serialization_test(layer)\n\n        # Case: fixed vocab with special tokens + list inputs\n        vocabulary_with_special_tokens = [\"\", \"[OOV]\", \"one\", \"two\", \"three\"]\n        layer = layers.IndexLookup(\n            vocabulary=vocabulary_with_special_tokens, **kwargs\n        )\n        self.assertEqual(\n            layer.get_vocabulary(), [\"\", \"[OOV]\", \"one\", \"two\", \"three\"]\n        )\n        self.assertEqual(\n            layer.get_vocabulary(include_special_tokens=False),\n            [\"one\", \"two\", \"three\"],\n        )\n        output = layer(input_data)\n        self.assertEqual(list(output), [2, 3, 1])\n        if backend.backend() != \"torch\":\n            self.run_class_serialization_test(layer)\n\n        # Case: set vocabulary\n        layer = layers.IndexLookup(**kwargs)\n        layer.set_vocabulary(vocabulary)\n        self.assertEqual(\n            layer.get_vocabulary(), [\"\", \"[OOV]\", \"one\", \"two\", \"three\"]\n        )\n        self.assertEqual(\n            layer.get_vocabulary(include_special_tokens=False),\n            [\"one\", \"two\", \"three\"],\n        )\n        output = layer(input_data)\n        self.assertEqual(list(output), [2, 3, 1])\n        if backend.backend() != \"torch\":\n            self.run_class_serialization_test(layer)\n\n        # Case: set vocabulary (with special tokens)\n        layer = layers.IndexLookup(**kwargs)\n        layer.set_vocabulary(vocabulary_with_special_tokens)\n        self.assertEqual(\n            layer.get_vocabulary(), [\"\", \"[OOV]\", \"one\", \"two\", \"three\"]\n        )\n        self.assertEqual(\n            layer.get_vocabulary(include_special_tokens=False),\n            [\"one\", \"two\", \"three\"],\n        )\n        output = layer(input_data)\n        self.assertEqual(list(output), [2, 3, 1])\n        if backend.backend() != 
\"torch\":\n            self.run_class_serialization_test(layer)\n\n    def test_basics_integer_vocab(self):\n        # Case: adapt + list inputs\n        adapt_data = [1, 1, 1, 2, 2, 3]\n        input_data = [1, 2, 4]\n        kwargs = {\n            \"max_tokens\": 7,\n            \"num_oov_indices\": 1,\n            \"mask_token\": 0,\n            \"oov_token\": -1,\n            \"vocabulary_dtype\": \"int64\",\n        }\n        layer = layers.IndexLookup(**kwargs)\n        layer.adapt(adapt_data)\n        self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3])\n        self.assertEqual(\n            layer.get_vocabulary(include_special_tokens=False),\n            [1, 2, 3],\n        )\n        output = layer(input_data)\n        self.assertEqual(list(output), [2, 3, 1])\n        if backend.backend() != \"torch\":\n            self.run_class_serialization_test(layer)\n\n        # Case: numpy array input\n        output = layer(np.array(input_data))\n        self.assertEqual(list(output), [2, 3, 1])\n\n        # Case: fixed vocab + list inputs\n        vocabulary = [1, 2, 3]\n        layer = layers.IndexLookup(vocabulary=vocabulary, **kwargs)\n        self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3])\n        self.assertEqual(\n            layer.get_vocabulary(include_special_tokens=False),\n            [1, 2, 3],\n        )\n        output = layer(input_data)\n        self.assertEqual(list(output), [2, 3, 1])\n        if backend.backend() != \"torch\":\n            self.run_class_serialization_test(layer)\n\n        # Case: fixed vocab with special tokens + list inputs\n        vocabulary_with_special_tokens = [0, -1, 1, 2, 3]\n        layer = layers.IndexLookup(\n            vocabulary=vocabulary_with_special_tokens, **kwargs\n        )\n        self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3])\n        self.assertEqual(\n            layer.get_vocabulary(include_special_tokens=False),\n            [1, 2, 3],\n        )\n        output = 
layer(input_data)\n        self.assertEqual(list(output), [2, 3, 1])\n        if backend.backend() != \"torch\":\n            self.run_class_serialization_test(layer)\n\n        # Case: set vocabulary\n        layer = layers.IndexLookup(**kwargs)\n        layer.set_vocabulary(vocabulary)\n        self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3])\n        self.assertEqual(\n            layer.get_vocabulary(include_special_tokens=False),\n            [1, 2, 3],\n        )\n        output = layer(input_data)\n        self.assertEqual(list(output), [2, 3, 1])\n        if backend.backend() != \"torch\":\n            self.run_class_serialization_test(layer)\n\n        # Case: set vocabulary (with special tokens)\n        layer = layers.IndexLookup(**kwargs)\n        layer.set_vocabulary(vocabulary_with_special_tokens)\n        self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2, 3])\n        self.assertEqual(\n            layer.get_vocabulary(include_special_tokens=False),\n            [1, 2, 3],\n        )\n        output = layer(input_data)\n        self.assertEqual(list(output), [2, 3, 1])\n        if backend.backend() != \"torch\":\n            self.run_class_serialization_test(layer)\n\n    def test_max_tokens_adapt(self):\n        adapt_data = [1, 1, 1, 2, 2, 3]\n        input_data = [1, 2, 3, 4]\n        kwargs = {\n            \"max_tokens\": 4,\n            \"num_oov_indices\": 1,\n            \"mask_token\": 0,\n            \"oov_token\": -1,\n            \"vocabulary_dtype\": \"int64\",\n        }\n        layer = layers.IndexLookup(**kwargs)\n        layer.adapt(adapt_data)\n        self.assertEqual(layer.get_vocabulary(), [0, -1, 1, 2])\n        self.assertEqual(\n            layer.get_vocabulary(include_special_tokens=False),\n            [1, 2],\n        )\n        output = layer(input_data)\n        self.assertEqual(list(output), [2, 3, 1, 1])\n        if backend.backend() != \"torch\":\n            self.run_class_serialization_test(layer)\n\n  
  def test_pad_to_max_tokens(self):\n        vocabulary = [1, 2]\n        input_data = [1, 2]\n        kwargs = {\n            \"max_tokens\": 5,\n            \"num_oov_indices\": 1,\n            \"mask_token\": 0,\n            \"oov_token\": -1,\n            \"vocabulary_dtype\": \"int64\",\n            \"vocabulary\": vocabulary,\n            \"pad_to_max_tokens\": True,\n            \"output_mode\": \"multi_hot\",\n        }\n        layer = layers.IndexLookup(**kwargs)\n        output = layer(input_data)\n        self.assertAllClose(output, [0, 1, 1, 0, 0])\n        if backend.backend() != \"torch\":\n            self.run_class_serialization_test(layer)\n\n    def test_output_modes(self):\n        vocabulary = [\"one\", \"two\", \"three\"]\n        single_sample_input_data = [\"one\", \"two\", \"four\"]\n        batch_input_data = [[\"one\", \"two\", \"four\", \"two\"]]\n        kwargs = {\n            \"max_tokens\": 7,\n            \"num_oov_indices\": 1,\n            \"mask_token\": \"\",\n            \"oov_token\": \"[OOV]\",\n            \"vocabulary_dtype\": \"string\",\n            \"vocabulary\": vocabulary,\n        }\n\n        # int\n        kwargs[\"output_mode\"] = \"int\"\n        layer = layers.IndexLookup(**kwargs)\n        output = layer(single_sample_input_data)\n        self.assertAllClose(output, [2, 3, 1])\n        output = layer(batch_input_data)\n        self.assertAllClose(output, [[2, 3, 1, 3]])\n\n        # multi-hot\n        kwargs[\"output_mode\"] = \"multi_hot\"\n        layer = layers.IndexLookup(**kwargs)\n        output = layer(single_sample_input_data)\n        self.assertAllClose(output, [1, 1, 1, 0])\n        output = layer(batch_input_data)\n        self.assertAllClose(output, [[1, 1, 1, 0]])\n\n        # one-hot\n        kwargs[\"output_mode\"] = \"one_hot\"\n        layer = layers.IndexLookup(**kwargs)\n        output = layer(single_sample_input_data)\n        self.assertAllClose(output, [[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 
0, 0]])\n\n        # count\n        kwargs[\"output_mode\"] = \"count\"\n        layer = layers.IndexLookup(**kwargs)\n        output = layer(single_sample_input_data)\n        self.assertAllClose(output, [1, 1, 1, 0])\n        output = layer(batch_input_data)\n        self.assertAllClose(output, [[1, 1, 2, 0]])\n\n        # tf-idf\n        kwargs[\"output_mode\"] = \"tf_idf\"\n        kwargs[\"idf_weights\"] = np.array([0.1, 0.2, 0.3])\n        layer = layers.IndexLookup(**kwargs)\n        output = layer(single_sample_input_data)\n        self.assertAllClose(output, [0.2, 0.1, 0.2, 0.0])\n        output = layer(batch_input_data)\n        self.assertAllClose(output, [[0.2, 0.1, 0.4, 0.0]])\n\n    def test_one_hot_symbolic_output_shape_with_higher_rank_input(self):\n        \"\"\"Symbolic output shape for one_hot must preserve input dims + depth.\n\n        Regression test for gh-22336: StringLookup/IntegerLookup with\n        output_mode='one_hot' produced (None, depth) instead of\n        (None, d1, ..., dN, depth) for nested inputs.\n        \"\"\"\n        # IntegerLookup with one_hot and 3D input (batch, 2, 2)\n        layer = layers.IntegerLookup(\n            vocabulary=[1, 2, 3],\n            output_mode=\"one_hot\",\n        )\n        symbolic_input = layers.Input(shape=(2, 2), dtype=\"int32\")\n        symbolic_output = layer(symbolic_input)\n        # Expected: (None, 2, 2, vocab_size) where vocab_size = 4 (3 + OOV)\n        self.assertEqual(\n            tuple(symbolic_output.shape),\n            (None, 2, 2, 4),\n            msg=\"one_hot symbolic output shape must be input_shape + (depth,)\",\n        )\n        # Eager execution: same input shape -> same output shape\n        eager_input = np.array([[[1, 2], [3, 0]], [[1, 2], [3, 0]]])\n        eager_output = layer(eager_input)\n        self.assertEqual(eager_output.shape, (2, 2, 2, 4))\n        self.assertEqual(\n            tuple(symbolic_output.shape)[1:],\n            eager_output.shape[1:],\n    
        msg=\"Symbolic and eager output shapes must match (except batch)\",\n        )\n\n    def test_one_hot_compute_output_shape_multi_hot_consistency(self):\n        \"\"\"multi_hot/count/tf_idf last dim is sample in output shape.\"\"\"\n        kwargs = {\n            \"max_tokens\": 10,\n            \"num_oov_indices\": 1,\n            \"mask_token\": None,\n            \"oov_token\": \"[OOV]\",\n            \"vocabulary_dtype\": \"string\",\n            \"vocabulary\": [\"a\", \"b\", \"c\"],\n        }\n        # depth = vocab size (3) + OOV (1) = 4 when pad_to_max_tokens is False\n        depth = 4\n        # multi_hot: (batch, sample_len) -> (batch, depth)\n        layer_multi = layers.IndexLookup(**kwargs, output_mode=\"multi_hot\")\n        shape_multi = layer_multi.compute_output_shape((None, 5))\n        self.assertEqual(shape_multi, (None, depth))\n        # one_hot: (batch, d1, d2) -> (batch, d1, d2, depth)\n        layer_one = layers.IndexLookup(**kwargs, output_mode=\"one_hot\")\n        shape_one = layer_one.compute_output_shape((None, 2, 2))\n        self.assertEqual(shape_one, (None, 2, 2, depth))\n\n    def test_one_hot_compute_output_spec_preserves_input_dims(self):\n        \"\"\"compute_output_spec for one_hot must preserve all input dims.\"\"\"\n        layer = layers.IntegerLookup(\n            vocabulary=[1, 2, 3],\n            output_mode=\"one_hot\",\n        )\n        symbolic_input = layers.Input(shape=(3, 4), dtype=\"int32\")\n        output_spec = layer.compute_output_spec(symbolic_input)\n        self.assertEqual(output_spec.shape, (None, 3, 4, 4))\n        self.assertEqual(output_spec.dtype, backend.floatx())\n\n    def test_sparse_outputs(self):\n        # TODO\n        pass\n\n    def test_adapt_tf_idf(self):\n        # Case: unbatched data\n        adapt_data = [\"one\", \"one\", \"one\", \"two\", \"two\", \"three\"]\n        input_data = [\"one\", \"two\", \"four\"]\n        kwargs = {\n            \"max_tokens\": 7,\n        
    \"num_oov_indices\": 1,\n            \"mask_token\": \"\",\n            \"oov_token\": \"[OOV]\",\n            \"vocabulary_dtype\": \"string\",\n            \"output_mode\": \"tf_idf\",\n        }\n        layer = layers.IndexLookup(**kwargs)\n        layer.adapt(adapt_data)\n        output = layer(input_data)\n        # Document counts for one, two, three = [3, 2, 1]\n        idf_weights = np.log(1 + len(adapt_data) / (1 + np.array([3, 2, 1])))\n        self.assertAllClose(layer.idf_weights[1:], idf_weights)\n        self.assertAllClose(output, [1.1337324, 0.91629076, 1.0986123, 0.0])\n        # Case: batched data\n        adapt_data = [[\"one\", \"one\"], [\"one\", \"two\"], [\"two\", \"three\"]]\n        input_data = [[\"one\", \"two\"], [\"two\", \"four\"]]\n        kwargs = {\n            \"max_tokens\": 7,\n            \"num_oov_indices\": 1,\n            \"mask_token\": \"\",\n            \"oov_token\": \"[OOV]\",\n            \"vocabulary_dtype\": \"string\",\n            \"output_mode\": \"tf_idf\",\n        }\n        layer = layers.IndexLookup(**kwargs)\n        layer.adapt(adapt_data)\n        # Document counts for one, two, three = [2, 2, 1]\n        idf_weights = np.log(1 + len(adapt_data) / (1 + np.array([2, 2, 1])))\n        self.assertAllClose(layer.idf_weights[1:], idf_weights)\n        output = layer(input_data)\n        self.assertAllClose(\n            output,\n            [\n                [0.0, 0.6931472, 0.6931472, 0.0],\n                [0.76752836, 0.0, 0.6931472, 0.0],\n            ],\n        )\n\n    def test_invert(self):\n        vocabulary = [\"one\", \"two\", \"three\"]\n        single_sample_input_data = [2, 3, 1]\n        batch_input_data = [[2, 3, 1, 3]]\n        kwargs = {\n            \"max_tokens\": 7,\n            \"num_oov_indices\": 1,\n            \"mask_token\": \"\",\n            \"oov_token\": \"[OOV]\",\n            \"vocabulary_dtype\": \"string\",\n            \"vocabulary\": vocabulary,\n            
\"invert\": True,\n            \"output_mode\": \"int\",\n        }\n        layer = layers.IndexLookup(**kwargs)\n        output = layer(single_sample_input_data)\n        self.assertEqual(\n            [w.decode(\"utf-8\") for w in output.numpy()], [\"one\", \"two\", \"[OOV]\"]\n        )\n        output = layer(batch_input_data)\n        self.assertEqual(\n            [w.decode(\"utf-8\") for w in output.numpy()[0]],\n            [\"one\", \"two\", \"[OOV]\", \"two\"],\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Requires string input dtype\"\n    )\n    def test_saving(self):\n        # Test with adapt()\n        vocabulary = [\"one\", \"two\", \"three\"]\n        adapt_data = [\"one\", \"one\", \"one\", \"two\", \"two\", \"three\"]\n        batch_input_data = np.array([[\"one\", \"two\", \"four\"]])\n        kwargs = {\n            \"max_tokens\": 7,\n            \"num_oov_indices\": 1,\n            \"mask_token\": \"\",\n            \"oov_token\": \"[OOV]\",\n            \"vocabulary_dtype\": \"string\",\n            \"output_mode\": \"int\",\n        }\n        layer = layers.IndexLookup(**kwargs)\n        layer.adapt(adapt_data)\n        model = models.Sequential(\n            [\n                layers.Input(shape=(None,), dtype=\"string\"),\n                layer,\n            ]\n        )\n        output_1 = model(batch_input_data)\n        path = os.path.join(self.get_temp_dir(), \"model.keras\")\n        model.save(path)\n        model = saving_api.load_model(path)\n        output_2 = model(batch_input_data)\n        self.assertAllClose(output_1, output_2)\n\n        # Test when vocabulary is provided\n        kwargs[\"vocabulary\"] = vocabulary\n        layer = layers.IndexLookup(**kwargs)\n        model = models.Sequential(\n            [\n                layers.Input(shape=(None,), dtype=\"string\"),\n                layer,\n            ]\n        )\n        output_1 = model(batch_input_data)\n        
path = os.path.join(self.get_temp_dir(), \"model.keras\")\n        model.save(path)\n        model = saving_api.load_model(path)\n        output_2 = model(batch_input_data)\n        self.assertAllClose(output_1, output_2)\n\n    def test_adapt_with_tf_data(self):\n        # Case: adapt + list inputs\n        adapt_data = tf_data.Dataset.from_tensor_slices(\n            [\"one\", \"one\", \"one\", \"two\", \"two\", \"three\"]\n        ).batch(2)\n        input_data = [\"one\", \"two\", \"four\"]\n        kwargs = {\n            \"max_tokens\": 7,\n            \"num_oov_indices\": 1,\n            \"mask_token\": \"\",\n            \"oov_token\": \"[OOV]\",\n            \"vocabulary_dtype\": \"string\",\n        }\n        layer = layers.IndexLookup(**kwargs)\n        layer.adapt(adapt_data)\n        self.assertEqual(\n            layer.get_vocabulary(), [\"\", \"[OOV]\", \"one\", \"two\", \"three\"]\n        )\n        self.assertEqual(\n            layer.get_vocabulary(include_special_tokens=False),\n            [\"one\", \"two\", \"three\"],\n        )\n        output = layer(input_data)\n        self.assertEqual(list(output), [2, 3, 1])\n        if backend.backend() != \"torch\":\n            self.run_class_serialization_test(layer)\n\n    def test_max_tokens_less_than_two(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"If set, `max_tokens` must be greater than 1.\",\n        ):\n            layers.IndexLookup(\n                max_tokens=1,\n                num_oov_indices=1,\n                mask_token=None,\n                oov_token=None,\n                vocabulary_dtype=\"int64\",\n            )\n\n    def test_max_tokens_none_with_pad_to_max_tokens(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"If pad_to_max_tokens is True, must set `max_tokens`.\",\n        ):\n            layers.IndexLookup(\n                num_oov_indices=1,\n                max_tokens=None,\n                
mask_token=None,\n                oov_token=None,\n                vocabulary_dtype=\"int64\",\n                pad_to_max_tokens=True,\n            )\n\n    def test_negative_num_oov_indices(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`num_oov_indices` must be greater than or equal to 0.\",\n        ):\n            layers.IndexLookup(\n                max_tokens=10,\n                num_oov_indices=-1,\n                mask_token=None,\n                oov_token=None,\n                vocabulary_dtype=\"int64\",\n            )\n\n    def test_invert_with_non_int_output_mode(self):\n        with self.assertRaisesRegex(\n            ValueError, r\"`output_mode` must be `'int'` when `invert` is true.\"\n        ):\n            layers.IndexLookup(\n                num_oov_indices=1,\n                max_tokens=None,\n                mask_token=None,\n                oov_token=None,\n                vocabulary_dtype=\"string\",\n                invert=True,\n                output_mode=\"one_hot\",  # Invalid combination\n            )\n\n    def test_sparse_true_with_int_output_mode(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"`sparse` may only be true if `output_mode` is `'one_hot'`\",\n        ):\n            layers.IndexLookup(\n                num_oov_indices=1,\n                max_tokens=None,\n                mask_token=None,\n                oov_token=None,\n                vocabulary_dtype=\"string\",\n                sparse=True,\n                output_mode=\"int\",  # Invalid combination\n            )\n\n    def test_idf_weights_set_with_non_tfidf_output_mode(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"`idf_weights` should only be set if `output_mode` is `'tf_idf'`\",\n        ):\n            layers.IndexLookup(\n                num_oov_indices=1,\n                max_tokens=None,\n                mask_token=None,\n                
oov_token=None,\n                vocabulary_dtype=\"string\",\n                idf_weights=[\n                    0.5,\n                    0.1,\n                    0.3,\n                ],  # Should not be set for non-TF-IDF modes\n                output_mode=\"int\",\n            )\n\n    def test_unrecognized_kwargs(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Unrecognized keyword argument\"\n        ):\n            layers.IndexLookup(\n                num_oov_indices=1,\n                max_tokens=None,\n                mask_token=None,\n                oov_token=None,\n                vocabulary_dtype=\"string\",\n                output_mode=\"int\",\n                # This is an unrecognized argument\n                extra_arg=True,\n            )\n\n    def test_non_tf_idf_with_idf_weights(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`idf_weights` should only be set if `output_mode` is\",\n        ):\n            layers.IndexLookup(\n                num_oov_indices=1,\n                max_tokens=None,\n                mask_token=None,\n                oov_token=None,\n                vocabulary_dtype=\"string\",\n                output_mode=\"multi_hot\",\n                idf_weights=[\n                    0.5,\n                    0.1,\n                    0.3,\n                ],  # idf_weights not valid for multi_hot mode\n            )\n\n    def test_vocabulary_file_does_not_exist(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Vocabulary file path/to/missing_vocab.txt does not exist\",\n        ):\n            layers.IndexLookup(\n                num_oov_indices=1,\n                max_tokens=None,\n                mask_token=None,\n                oov_token=None,\n                vocabulary_dtype=\"string\",\n                output_mode=\"int\",\n                # Nonexistent file path\n                vocabulary=\"path/to/missing_vocab.txt\",\n    
        )\n\n    def test_repeated_tokens_in_vocabulary(self):\n        with self.assertRaisesRegex(\n            ValueError, \"The passed vocabulary has at least one repeated term.\"\n        ):\n            layers.IndexLookup(\n                num_oov_indices=1,\n                max_tokens=None,\n                mask_token=None,\n                oov_token=None,\n                vocabulary_dtype=\"string\",\n                vocabulary=[\"token\", \"token\", \"unique\"],\n            )\n\n    def test_mask_token_in_wrong_position(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Found reserved mask token at unexpected location in `vocabulary`.\",\n        ):\n            layers.IndexLookup(\n                num_oov_indices=1,\n                max_tokens=None,\n                mask_token=\"mask\",\n                oov_token=None,\n                vocabulary_dtype=\"string\",\n                vocabulary=[\n                    \"token\",\n                    \"mask\",\n                    \"unique\",\n                ],  # 'mask' should be at the start if included explicitly\n            )\n\n    def test_ensure_known_vocab_size_without_vocabulary(self):\n        kwargs = {\n            \"num_oov_indices\": 1,\n            # Assume empty string or some default token is valid.\n            \"mask_token\": \"\",\n            # Assume [OOV] or some default token is valid.\n            \"oov_token\": \"[OOV]\",\n            \"output_mode\": \"multi_hot\",\n            \"pad_to_max_tokens\": False,\n            \"vocabulary_dtype\": \"string\",\n            \"max_tokens\": None,\n        }\n        layer = layers.IndexLookup(**kwargs)\n\n        # Try calling the layer without setting the vocabulary.\n        with self.assertRaisesRegex(\n            RuntimeError, \"When using `output_mode=multi_hot` and\"\n        ):\n            input_data = [\"sample\", \"data\"]\n            layer(input_data)\n\n    def 
test_save_and_load_assets_string_vocab(self):\n        kwargs = {\n            \"max_tokens\": 10,\n            \"num_oov_indices\": 1,\n            \"mask_token\": \"<mask>\",\n            \"oov_token\": \"[OOV]\",\n            \"vocabulary_dtype\": \"string\",\n        }\n        layer = layers.IndexLookup(**kwargs)\n\n        vocabulary = [\"apple\", \"banana\", \"cherry\"]\n        layer.set_vocabulary(vocabulary)\n\n        vocab_before = layer.get_vocabulary(include_special_tokens=True)\n        vocab_before_no_special = layer.get_vocabulary(\n            include_special_tokens=False\n        )\n\n        sample_input = [\"apple\", \"banana\", \"unknown\"]\n        output_before = layer(sample_input).numpy()\n\n        tmpdir = self.get_temp_dir()\n\n        layer.save_assets(tmpdir)\n\n        layer2 = layers.IndexLookup(**kwargs)\n        layer2.load_assets(tmpdir)\n\n        vocab_after = layer2.get_vocabulary(include_special_tokens=True)\n        vocab_after_no_special = layer2.get_vocabulary(\n            include_special_tokens=False\n        )\n\n        self.assertEqual(vocab_before, vocab_after)\n        self.assertEqual(vocab_before_no_special, vocab_after_no_special)\n\n        output_after = layer2(sample_input).numpy()\n        np.testing.assert_array_equal(output_before, output_after)\n\n    def test_save_and_load_assets_with_multiple_oov_indices(self):\n        kwargs = {\n            \"max_tokens\": 10,\n            \"num_oov_indices\": 2,\n            \"mask_token\": \"<mask>\",\n            \"oov_token\": \"[OOV]\",\n            \"vocabulary_dtype\": \"string\",\n        }\n        layer = layers.IndexLookup(**kwargs)\n\n        vocabulary = [\"apple\", \"banana\"]\n        layer.set_vocabulary(vocabulary)\n\n        vocab_before = layer.get_vocabulary(include_special_tokens=True)\n\n        self.assertEqual(len(vocab_before), 5)\n        self.assertEqual(vocab_before[0], \"<mask>\")\n        self.assertEqual(vocab_before[1], \"[OOV]\")\n     
   self.assertEqual(vocab_before[2], \"[OOV]\")\n\n        tmpdir = self.get_temp_dir()\n\n        layer.save_assets(tmpdir)\n\n        layer2 = layers.IndexLookup(**kwargs)\n        layer2.load_assets(tmpdir)\n\n        vocab_after = layer2.get_vocabulary(include_special_tokens=True)\n        self.assertEqual(vocab_before, vocab_after)\n\n    def test_load_assets_handles_trailing_newlines(self):\n        kwargs = {\n            \"max_tokens\": 10,\n            \"num_oov_indices\": 1,\n            \"mask_token\": \"<mask>\",\n            \"oov_token\": \"[OOV]\",\n            \"vocabulary_dtype\": \"string\",\n        }\n        layer = layers.IndexLookup(**kwargs)\n\n        vocabulary = [\"apple\", \"banana\", \"cherry\"]\n        layer.set_vocabulary(vocabulary)\n        vocab_expected = layer.get_vocabulary(include_special_tokens=True)\n\n        tmpdir = self.get_temp_dir()\n\n        vocab_file = os.path.join(tmpdir, \"vocabulary.txt\")\n        with open(vocab_file, \"w\") as f:\n            f.write(\"<mask>\\n[OOV]\\napple\\nbanana\\ncherry\\n\")\n\n        layer2 = layers.IndexLookup(**kwargs)\n        layer2.load_assets(tmpdir)\n\n        vocab_loaded = layer2.get_vocabulary(include_special_tokens=True)\n        self.assertEqual(vocab_expected, vocab_loaded)\n\n    def test_oov_method_ignored_for_string_dtype(self):\n        vocabulary = [\"cat\", \"dog\", \"fish\"]\n        oov_data = [\"aaa\", \"bbb\", \"ccc\", \"ddd\", \"eee\", \"fff\"]\n        kwargs = {\n            \"max_tokens\": 10,\n            \"num_oov_indices\": 4,\n            \"mask_token\": \"\",\n            \"oov_token\": \"[OOV]\",\n            \"vocabulary_dtype\": \"string\",\n            \"vocabulary\": vocabulary,\n        }\n        layer_floormod = layers.IndexLookup(oov_method=\"floormod\", **kwargs)\n        layer_farmhash = layers.IndexLookup(oov_method=\"farmhash\", **kwargs)\n        out_floormod = backend.convert_to_numpy(layer_floormod(oov_data))\n        out_farmhash = 
backend.convert_to_numpy(layer_farmhash(oov_data))\n        self.assertAllClose(out_floormod, out_farmhash)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/integer_lookup.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.index_lookup import IndexLookup\nfrom keras.src.utils import backend_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\n@keras_export(\"keras.layers.IntegerLookup\")\nclass IntegerLookup(IndexLookup):\n    \"\"\"A preprocessing layer that maps integers to (possibly encoded) indices.\n\n    This layer maps a set of arbitrary integer input tokens into indexed integer\n    output via a table-based vocabulary lookup. The layer's output indices will\n    be contiguously arranged up to the maximum vocab size, even if the input\n    tokens are non-contiguous or unbounded. The layer supports multiple options\n    for encoding the output via `output_mode`, and has optional support for\n    out-of-vocabulary (OOV) tokens and masking.\n\n    The vocabulary for the layer must be either supplied on construction or\n    learned via `adapt()`. During `adapt()`, the layer will analyze a data set,\n    determine the frequency of individual integer tokens, and create a\n    vocabulary from them. If the vocabulary is capped in size, the most frequent\n    tokens will be used to create the vocabulary and all others will be treated\n    as OOV.\n\n    There are two possible output modes for the layer.  When `output_mode` is\n    `\"int\"`, input integers are converted to their index in the vocabulary (an\n    integer).  When `output_mode` is `\"multi_hot\"`, `\"count\"`, or `\"tf_idf\"`,\n    input integers are encoded into an array where each dimension corresponds to\n    an element in the vocabulary.\n\n    The vocabulary can optionally contain a mask token as well as an OOV token\n    (which can optionally occupy multiple indices in the vocabulary, as set\n    by `num_oov_indices`).\n    The position of these tokens in the vocabulary is fixed. 
When `output_mode`\n    is `\"int\"`, the vocabulary will begin with the mask token at index 0,\n    followed by OOV indices, followed by the rest of the vocabulary. When\n    `output_mode` is `\"multi_hot\"`, `\"count\"`, or `\"tf_idf\"` the vocabulary will\n    begin with OOV indices and instances of the mask token will be dropped.\n\n    **Note:** This layer uses TensorFlow internally. It cannot\n    be used as part of the compiled computation graph of a model with\n    any backend other than TensorFlow.\n    It can however be used with any backend when running eagerly.\n    It can also always be used as part of an input preprocessing pipeline\n    with any backend (outside the model itself), which is how we recommend\n    to use this layer.\n\n    **Note:** This layer is safe to use inside a `tf.data` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        max_tokens: Maximum size of the vocabulary for this layer. This should\n            only be specified when adapting the vocabulary or when setting\n            `pad_to_max_tokens=True`. If None, there is no cap on the size of\n            the vocabulary. Note that this size includes the OOV\n            and mask tokens. Defaults to `None`.\n        num_oov_indices: The number of out-of-vocabulary tokens to use.\n            If this value is more than 1, OOV inputs are hashed or modulated\n            to determine their OOV value (see `oov_method`).\n            If this value is 0, OOV inputs will cause an error when calling\n            the layer. Defaults to `1`.\n        mask_token: An integer token that represents masked inputs. When\n            `output_mode` is `\"int\"`, the token is included in vocabulary\n            and mapped to index 0. In other output modes,\n            the token will not appear in the vocabulary and instances\n            of the mask token in the input will be dropped.\n            If set to None, no mask term will be added. 
Defaults to `None`.\n        oov_token: Only used when `invert` is `True`. The token to return\n            for OOV indices. Defaults to `-1`.\n        vocabulary: Optional. Either an array of integers or a string path to a\n            text file. If passing an array, can pass a tuple, list,\n            1D NumPy array, or 1D tensor containing the integer vocabulary\n            terms. If passing a file path, the file should contain one line per\n            term in the vocabulary. If this argument is set,\n            there is no need to `adapt()` the layer.\n        vocabulary_dtype: The dtype of the vocabulary terms.\n            Only `vocabulary_dtype='int64'` is supported at this time.\n            Defaults to `\"int64\"`.\n        idf_weights: Only valid when `output_mode` is `\"tf_idf\"`.\n            A tuple, list, 1D NumPy array, or 1D tensor of the same length\n            as the vocabulary, containing the floating point inverse document\n            frequency weights, which will be multiplied by per sample term\n            counts for the final TF-IDF weight.\n            If the `vocabulary` argument is set, and `output_mode` is\n            `\"tf_idf\"`, this argument must be supplied.\n        invert: Only valid when `output_mode` is `\"int\"`.\n            If `True`, this layer will map indices to vocabulary items\n            instead of mapping vocabulary items to indices.\n            Defaults to `False`.\n        output_mode: Specification for the output of the layer. Values can be\n            `\"int\"`, `\"one_hot\"`, `\"multi_hot\"`, `\"count\"`, or `\"tf_idf\"`\n            configuring the layer as follows:\n            - `\"int\"`: Return the vocabulary indices of the input tokens.\n            - `\"one_hot\"`: Encodes each individual element in the input into an\n                array the same size as the vocabulary,\n                containing a 1 at the element index. 
If the last dimension\n                is size 1, will encode on that dimension.\n                If the last dimension is not size 1, will append a new\n                dimension for the encoded output.\n            - `\"multi_hot\"`: Encodes each sample in the input into a single\n                array the same size as the vocabulary,\n                containing a 1 for each vocabulary term present in the sample.\n                Treats the last dimension as the sample dimension,\n                if input shape is `(..., sample_length)`,\n                output shape will be `(..., num_tokens)`.\n            - `\"count\"`: As `\"multi_hot\"`, but the int array contains\n                a count of the number of times the token at that index\n                appeared in the sample.\n            - `\"tf_idf\"`: As `\"multi_hot\"`, but the TF-IDF algorithm is\n                applied to find the value in each token slot.\n            For `\"int\"` output, the output shape matches the input shape.\n            For `\"one_hot\"` output, the output shape is\n            `input_shape + (vocabulary_size,)`, where `input_shape` may\n            have arbitrary rank. For other output modes (`\"multi_hot\"`,\n            `\"count\"`, `\"tf_idf\"`), the output shape is `(batch_size,\n            vocabulary_size)`. Defaults to `\"int\"`.\n        pad_to_max_tokens: Only applicable when `output_mode` is `\"multi_hot\"`,\n            `\"count\"`, or `\"tf_idf\"`. If `True`, the output will have\n            its feature axis padded to `max_tokens` even if the number\n            of unique tokens in the vocabulary is less than `max_tokens`,\n            resulting in a tensor of shape `(batch_size, max_tokens)`\n            regardless of vocabulary size. Defaults to `False`.\n        sparse: Boolean. Only applicable to `\"multi_hot\"`, `\"count\"`, and\n            `\"tf_idf\"` output modes. Only supported with TensorFlow\n            backend. 
If `True`, returns a `SparseTensor`\n            instead of a dense `Tensor`. Defaults to `False`.\n        oov_method: Only relevant when `num_oov_indices > 1`. Controls how OOV\n            tokens are assigned to OOV buckets.\n            - `\"floormod\"` (default): uses `token % num_oov_indices`.\n              Preserves backwards compatibility but can produce severe bucket\n              imbalance when input IDs share a common factor with\n              `num_oov_indices` (e.g. all-even IDs with an even bucket count).\n            - `\"farmhash\"`: applies FarmHash64. Distributes OOV tokens\n            uniformly regardless of the arithmetic structure of the input IDs.\n            This parameter is ignored for string inputs, which always use\n            FarmHash64.\n\n    Examples:\n\n    **Creating a lookup layer with a known vocabulary**\n\n    This example creates a lookup layer with a pre-existing vocabulary.\n\n    >>> vocab = [12, 36, 1138, 42]\n    >>> data = np.array([[12, 1138, 42], [42, 1000, 36]])  # Note OOV tokens\n    >>> layer = IntegerLookup(vocabulary=vocab)\n    >>> layer(data)\n    array([[1, 3, 4],\n           [4, 0, 2]])\n\n    **Creating a lookup layer with an adapted vocabulary**\n\n    This example creates a lookup layer and generates the vocabulary by\n    analyzing the dataset.\n\n    >>> data = np.array([[12, 1138, 42], [42, 1000, 36]])\n    >>> layer = IntegerLookup()\n    >>> layer.adapt(data)\n    >>> layer.get_vocabulary()\n    [-1, 42, 1138, 1000, 36, 12]\n\n    Note that the OOV token -1 has been added to the vocabulary. 
The remaining\n    tokens are sorted by frequency (42, which has 2 occurrences, is first) then\n    by inverse sort order.\n\n    >>> data = np.array([[12, 1138, 42], [42, 1000, 36]])\n    >>> layer = IntegerLookup()\n    >>> layer.adapt(data)\n    >>> layer(data)\n    array([[5, 2, 1],\n           [1, 3, 4]])\n\n    **Lookups with multiple OOV indices**\n\n    This example demonstrates how to use a lookup layer with multiple OOV\n    indices.  When a layer is created with more than one OOV index, any OOV\n    tokens are hashed or modulated into the number of OOV buckets, distributing\n    OOV tokens in a deterministic fashion across the set. Use `oov_method` to\n    control whether `floormod` or FarmHash64 is used.\n\n    >>> vocab = [12, 36, 1138, 42]\n    >>> data = np.array([[12, 1138, 42], [37, 1000, 36]])\n    >>> layer = IntegerLookup(vocabulary=vocab, num_oov_indices=2)\n    >>> layer(data)\n    array([[2, 4, 5],\n           [1, 0, 3]])\n\n    Note that the output for OOV token 37 is 1, while the output for OOV token\n    1000 is 0. The in-vocab terms have their output index increased by 1 from\n    earlier examples (12 maps to 2, etc) in order to make space for the extra\n    OOV token.\n\n    **Uniform OOV distribution with FarmHash**\n\n    This example shows how `oov_method=\"farmhash\"` avoids the bucket imbalance\n    that `\"floormod\"` produces for arithmetically structured input IDs.\n\n    >>> vocab = [10, 20, 30]\n    >>> layer_floormod = IntegerLookup(\n    ...     vocabulary=vocab, num_oov_indices=4, oov_method=\"floormod\")\n    >>> layer_farmhash = IntegerLookup(\n    ...     
vocabulary=vocab, num_oov_indices=4, oov_method=\"farmhash\")\n    >>> oov_values = np.array([100, 300, 700, 1100, 1700, 2000])\n    >>> layer_floormod(oov_values)\n    tf.Tensor([0 0 0 0 0 0], shape=(6,), dtype=int64) # All map to index 0\n    >>> layer_farmhash(oov_values)\n    tf.Tensor([3 3 2 2 0 1], shape=(6,), dtype=int64) # Spread across indices\n\n    **One-hot output**\n\n    Configure the layer with `output_mode='one_hot'`. Note that the first\n    `num_oov_indices` dimensions in the one_hot encoding represent OOV values.\n\n    >>> vocab = [12, 36, 1138, 42]\n    >>> data = np.array([12, 36, 1138, 42, 7])  # Note OOV tokens\n    >>> layer = IntegerLookup(vocabulary=vocab, output_mode='one_hot')\n    >>> layer(data)\n    array([[0., 1., 0., 0., 0.],\n            [0., 0., 1., 0., 0.],\n            [0., 0., 0., 1., 0.],\n            [0., 0., 0., 0., 1.],\n            [1., 0., 0., 0., 0.]], dtype=float32)\n\n    **Multi-hot output**\n\n    Configure the layer with `output_mode='multi_hot'`. Note that the first\n    `num_oov_indices` dimensions in the multi_hot encoding represent OOV\n    tokens.\n\n    >>> vocab = [12, 36, 1138, 42]\n    >>> data = np.array([[12, 1138, 42, 42],\n    ...                  [42,    7, 36,  7]])  # Note OOV tokens\n    >>> layer = IntegerLookup(vocabulary=vocab, output_mode='multi_hot')\n    >>> layer(data)\n    array([[0., 1., 0., 1., 1.],\n           [1., 0., 1., 0., 1.]], dtype=float32)\n\n    **Token count output**\n\n    Configure the layer with `output_mode='count'`. As with multi_hot output,\n    the first `num_oov_indices` dimensions in the output represent OOV tokens.\n\n    >>> vocab = [12, 36, 1138, 42]\n    >>> data = np.array([[12, 1138, 42, 42],\n    ...                  
[42,    7, 36,  7]])  # Note OOV tokens\n    >>> layer = IntegerLookup(vocabulary=vocab, output_mode='count')\n    >>> layer(data)\n    array([[0., 1., 0., 1., 2.],\n           [2., 0., 1., 0., 1.]], dtype=float32)\n\n    **TF-IDF output**\n\n    Configure the layer with `output_mode='tf_idf'`. As with multi_hot output,\n    the first `num_oov_indices` dimensions in the output represent OOV tokens.\n\n    Each token bin will output `token_count * idf_weight`, where the idf weights\n    are the inverse document frequency weights per token. These should be\n    provided along with the vocabulary. Note that the `idf_weight` for OOV\n    tokens will default to the average of all idf weights passed in.\n\n    >>> vocab = [12, 36, 1138, 42]\n    >>> idf_weights = [0.25, 0.75, 0.6, 0.4]\n    >>> data = np.array([[12, 1138, 42, 42],\n    ...                  [42,    7, 36,  7]])  # Note OOV tokens\n    >>> layer = IntegerLookup(\n    ...     output_mode='tf_idf', vocabulary=vocab, idf_weights=idf_weights)\n    >>> layer(data)\n    array([[0.  , 0.25, 0.  , 0.6 , 0.8 ],\n            [1.0 , 0.  , 0.75, 0.  , 0.4 ]], dtype=float32)\n\n    To specify the idf weights for oov tokens, you will need to pass the entire\n    vocabulary including the leading oov token.\n\n    >>> vocab = [-1, 12, 36, 1138, 42]\n    >>> idf_weights = [0.9, 0.25, 0.75, 0.6, 0.4]\n    >>> data = np.array([[12, 1138, 42, 42],\n    ...                  [42,    7, 36,  7]])  # Note OOV tokens\n    >>> layer = IntegerLookup(\n    ...     output_mode='tf_idf', vocabulary=vocab, idf_weights=idf_weights)\n    >>> layer(data)\n    array([[0.  , 0.25, 0.  , 0.6 , 0.8 ],\n            [1.8 , 0.  , 0.75, 0.  
, 0.4 ]], dtype=float32)\n\n    When adapting the layer in `\"tf_idf\"` mode, each input sample will\n    be considered a document, and IDF weight per token will be\n    calculated as:\n    `log(1 + num_documents / (1 + token_document_count))`.\n\n    **Inverse lookup**\n\n    This example demonstrates how to map indices to tokens using this layer.\n    (You can also use `adapt()` with `invert=True`, but for simplicity we'll\n    pass the vocab in this example.)\n\n    >>> vocab = [12, 36, 1138, 42]\n    >>> data = np.array([[1, 3, 4], [4, 0, 2]])\n    >>> layer = IntegerLookup(vocabulary=vocab, invert=True)\n    >>> layer(data)\n    array([[  12, 1138,   42],\n           [  42,   -1,   36]])\n\n    Note that the first index corresponds to the oov token by default.\n\n    **Forward and inverse lookup pairs**\n\n    This example demonstrates how to use the vocabulary of a standard lookup\n    layer to create an inverse lookup layer.\n\n    >>> vocab = [12, 36, 1138, 42]\n    >>> data = np.array([[12, 1138, 42], [42, 1000, 36]])\n    >>> layer = IntegerLookup(vocabulary=vocab)\n    >>> i_layer = IntegerLookup(\n    ...     vocabulary=layer.get_vocabulary(), invert=True)\n    >>> int_data = layer(data)\n    >>> i_layer(int_data)\n    array([[  12, 1138,   42],\n           [  42,   -1,   36]])\n\n    In this example, the input token 1000 resulted in an output of -1, since\n    1000 was not in the vocabulary - it got represented as an OOV, and all OOV\n    tokens are returned as -1 in the inverse layer. 
Also, note that for the\n    inverse to work, you must have already set the forward layer vocabulary\n    either directly or via `adapt()` before calling `get_vocabulary()`.\n    \"\"\"\n\n    def __init__(\n        self,\n        max_tokens=None,\n        num_oov_indices=1,\n        mask_token=None,\n        oov_token=-1,\n        vocabulary=None,\n        vocabulary_dtype=\"int64\",\n        idf_weights=None,\n        invert=False,\n        output_mode=\"int\",\n        sparse=False,\n        pad_to_max_tokens=False,\n        oov_method=\"floormod\",\n        name=None,\n        **kwargs,\n    ):\n        if not tf.available:\n            raise ImportError(\n                \"Layer IntegerLookup requires TensorFlow. \"\n                \"Install it via `pip install tensorflow`.\"\n            )\n        if max_tokens is not None and max_tokens <= 1:\n            raise ValueError(\n                \"If `max_tokens` is set for `IntegerLookup`, it must be \"\n                f\"greater than 1. Received: max_tokens={max_tokens}\"\n            )\n        if num_oov_indices < 0:\n            raise ValueError(\n                \"The value of the `num_oov_indices` argument for \"\n                \"`IntegerLookup` must be >= 0. Received: num_oov_indices=\"\n                f\"{num_oov_indices}\"\n            )\n        if sparse and backend.backend() != \"tensorflow\":\n            raise ValueError(\n                \"`sparse=True` can only be used with the TensorFlow backend.\"\n            )\n        if vocabulary_dtype != \"int64\":\n            raise ValueError(\n                \"Only `vocabulary_dtype='int64'` is supported \"\n                \"at this time. 
Received: \"\n                f\"vocabulary_dtype={vocabulary_dtype}\"\n            )\n        super().__init__(\n            max_tokens=max_tokens,\n            num_oov_indices=num_oov_indices,\n            mask_token=mask_token,\n            oov_token=oov_token,\n            vocabulary=vocabulary,\n            vocabulary_dtype=vocabulary_dtype,\n            idf_weights=idf_weights,\n            invert=invert,\n            output_mode=output_mode,\n            sparse=sparse,\n            pad_to_max_tokens=pad_to_max_tokens,\n            oov_method=oov_method,\n            name=name,\n            **kwargs,\n        )\n        self._convert_input_args = False\n        self._allow_non_tensor_positional_args = True\n        self.supports_jit = False\n\n    def adapt(self, data, steps=None):\n        \"\"\"Computes a vocabulary of integer terms from tokens in a dataset.\n\n        Calling `adapt()` on an `IntegerLookup` layer is an alternative to\n        passing in a precomputed vocabulary on construction via the\n        `vocabulary` argument.  An `IntegerLookup` layer should always be either\n        adapted over a dataset or supplied with a vocabulary.\n\n        During `adapt()`, the layer will build a vocabulary of all integer\n        tokens seen in the dataset, sorted by occurrence count, with ties broken\n        by sort order of the tokens (high to low). At the end of `adapt()`, if\n        `max_tokens` is set, the vocabulary will be truncated to `max_tokens`\n        size. For example, adapting a layer with `max_tokens=1000` will compute\n        the 1000 most frequent tokens occurring in the input dataset. If\n        `output_mode='tf_idf'`, `adapt()` will also learn the document\n        frequencies of each token in the input dataset.\n\n        Arguments:\n            data: The data to train on. 
It can be passed either as a\n                batched `tf.data.Dataset`, as a list of integers,\n                or as a NumPy array.\n            steps: Integer or `None`.\n                Total number of steps (batches of samples) to process.\n                If `data` is a `tf.data.Dataset`, and `steps` is `None`,\n                `adapt()` will run until the input dataset is exhausted.\n                When passing an infinitely\n                repeating dataset, you must specify the `steps` argument. This\n                argument is not supported with array inputs or list inputs.\n        \"\"\"\n        super().adapt(data, steps=steps)\n\n    def get_config(self):\n        config = super().get_config()\n        if config[\"oov_token\"] is not None:\n            config[\"oov_token\"] = int(config[\"oov_token\"])\n        if config[\"mask_token\"] is not None:\n            config[\"mask_token\"] = int(config[\"mask_token\"])\n        if config[\"vocabulary\"] is not None:\n            config[\"vocabulary\"] = [int(v) for v in config[\"vocabulary\"]]\n        return config\n\n    def call(self, inputs):\n        if not isinstance(\n            inputs, (tf.Tensor, tf.RaggedTensor, np.ndarray, list, tuple)\n        ):\n            inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))\n        outputs = super().call(inputs)\n        return backend_utils.convert_tf_tensor(outputs)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/integer_lookup_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass IntegerLookupTest(testing.TestCase):\n    def test_config(self):\n        layer = layers.IntegerLookup(\n            output_mode=\"int\",\n            vocabulary=[1, 2, 3],\n            oov_token=1,\n            mask_token=0,\n        )\n        self.run_class_serialization_test(layer)\n\n    def test_adapt_flow(self):\n        adapt_data = [1, 1, 1, 2, 2, 3]\n        single_sample_input_data = [1, 2, 4]\n        batch_input_data = [[1, 2, 4], [2, 3, 5]]\n\n        # int mode\n        layer = layers.IntegerLookup(\n            output_mode=\"int\",\n        )\n        layer.adapt(adapt_data)\n        output = layer(single_sample_input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([1, 2, 0]))\n        output = layer(batch_input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([[1, 2, 0], [2, 3, 0]]))\n\n        # one_hot mode\n        layer = layers.IntegerLookup(\n            output_mode=\"one_hot\",\n        )\n        layer.adapt(adapt_data)\n        output = layer(single_sample_input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(\n            output, np.array([[0, 1, 0, 0], [0, 0, 1, 0], [1, 0, 0, 0]])\n        )\n\n        # multi_hot mode\n        layer = layers.IntegerLookup(\n            output_mode=\"multi_hot\",\n        )\n        layer.adapt(adapt_data)\n        output = layer(single_sample_input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([1, 1, 1, 0]))\n\n        # tf_idf mode\n        layer = layers.IntegerLookup(\n            output_mode=\"tf_idf\",\n        )\n        layer.adapt(adapt_data)\n        output = 
layer(single_sample_input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(\n            output, np.array([1.133732, 0.916291, 1.098612, 0.0])\n        )\n\n        # count mode\n        layer = layers.IntegerLookup(\n            output_mode=\"count\",\n        )\n        layer.adapt(adapt_data)\n        output = layer([1, 2, 3, 4, 1, 2, 1])\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([1, 3, 2, 1]))\n\n    def test_fixed_vocabulary(self):\n        layer = layers.IntegerLookup(\n            output_mode=\"int\",\n            vocabulary=[1, 2, 3, 4],\n        )\n        input_data = [2, 3, 4, 5]\n        output = layer(input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([2, 3, 4, 0]))\n\n    def test_set_vocabulary(self):\n        layer = layers.IntegerLookup(\n            output_mode=\"int\",\n        )\n        layer.set_vocabulary([1, 2, 3, 4])\n        input_data = [2, 3, 4, 5]\n        output = layer(input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([2, 3, 4, 0]))\n\n    def test_tf_data_compatibility(self):\n        layer = layers.IntegerLookup(\n            output_mode=\"int\",\n            vocabulary=[1, 2, 3, 4],\n        )\n        input_data = [2, 3, 4, 5]\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(4).map(layer)\n        output = next(iter(ds)).numpy()\n        self.assertAllClose(output, np.array([2, 3, 4, 0]))\n\n    def test_one_hot_output_with_higher_rank_input(self):\n        input_data = np.array([[1, 2], [3, 0]])\n        vocabulary = [1, 2, 3]\n        layer = layers.IntegerLookup(\n            vocabulary=vocabulary, output_mode=\"one_hot\"\n        )\n        output_data = layer(input_data)\n        self.assertEqual(output_data.shape, (2, 2, 4))\n        expected_output = np.array(\n            [\n                [[0, 
1, 0, 0], [0, 0, 1, 0]],\n                [[0, 0, 0, 1], [1, 0, 0, 0]],\n            ]\n        )\n        self.assertAllClose(output_data, expected_output)\n        output_data_3d = layer(np.expand_dims(input_data, axis=0))\n        self.assertEqual(output_data_3d.shape, (1, 2, 2, 4))\n        self.assertAllClose(\n            output_data_3d, np.expand_dims(expected_output, axis=0)\n        )\n\n    def test_multi_hot_output_shape(self):\n        input_data = np.array([[1, 2], [3, 0]])\n        vocabulary = [1, 2, 3]\n        layer = layers.IntegerLookup(\n            vocabulary=vocabulary, output_mode=\"multi_hot\"\n        )\n        output_data = layer(input_data)\n        self.assertEqual(output_data.shape, (2, 4))\n\n    def test_count_output_shape(self):\n        input_data = np.array([[1, 2], [3, 0]])\n        vocabulary = [1, 2, 3]\n        layer = layers.IntegerLookup(vocabulary=vocabulary, output_mode=\"count\")\n        output_data = layer(input_data)\n        self.assertEqual(output_data.shape, (2, 4))\n\n    def test_tf_idf_output_shape(self):\n        input_data = np.array([[1, 2], [3, 0]])\n        vocabulary = [1, 2, 3]\n        idf_weights = [1.0, 1.0, 1.0]\n        layer = layers.IntegerLookup(\n            vocabulary=vocabulary,\n            idf_weights=idf_weights,\n            output_mode=\"tf_idf\",\n        )\n        output_data = layer(input_data)\n        self.assertEqual(output_data.shape, (2, 4))\n\n    def test_max_tokens(self):\n        layer = layers.IntegerLookup(output_mode=\"int\", max_tokens=4)\n        layer.adapt([1, 2, 3, 4, 5, 6, 1, 1, 2, 2])\n        vocab = layer.get_vocabulary()\n        self.assertEqual(len(vocab), 4)\n\n    def test_mask_token(self):\n        layer = layers.IntegerLookup(\n            output_mode=\"int\",\n            vocabulary=[1, 2, 3],\n            mask_token=0,\n        )\n        output = layer([0, 1, 2, 3])\n        self.assertAllClose(output, np.array([0, 2, 3, 4]))\n\n    def 
test_invert(self):\n        layer = layers.IntegerLookup(\n            vocabulary=[10, 20, 30],\n            invert=True,\n        )\n        output = layer([1, 2, 3, 0])\n        self.assertAllClose(output, np.array([10, 20, 30, -1]))\n\n    def test_pad_to_max_tokens(self):\n        layer = layers.IntegerLookup(\n            vocabulary=[1, 2],\n            output_mode=\"multi_hot\",\n            max_tokens=5,\n            pad_to_max_tokens=True,\n        )\n        output = layer([1, 2])\n        self.assertEqual(output.shape[-1], 5)\n\n    def test_num_oov_indices(self):\n        layer = layers.IntegerLookup(\n            vocabulary=[1, 2, 3],\n            num_oov_indices=2,\n            output_mode=\"int\",\n        )\n        output = layer([1, 2, 3, 999, 1000])\n        self.assertAllClose(output[:3], np.array([2, 3, 4]))\n        self.assertTrue(\n            all(o in [0, 1] for o in backend.convert_to_numpy(output[3:]))\n        )\n\n    def test_get_vocabulary(self):\n        layer = layers.IntegerLookup(output_mode=\"int\")\n        layer.adapt([5, 5, 5, 10, 10, 15])\n        vocab = layer.get_vocabulary()\n        self.assertEqual(vocab[0], -1)\n        self.assertEqual(vocab[1], 5)\n\n    def test_invalid_max_tokens(self):\n        with self.assertRaises(ValueError):\n            layers.IntegerLookup(max_tokens=1)\n\n    def test_invalid_num_oov_indices(self):\n        with self.assertRaises(ValueError):\n            layers.IntegerLookup(num_oov_indices=-1)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"sparse=True only supported on TensorFlow\",\n    )\n    def test_sparse_output(self):\n        layer = layers.IntegerLookup(\n            vocabulary=[1, 2, 3],\n            output_mode=\"multi_hot\",\n            sparse=True,\n        )\n        output = layer([1, 2])\n        self.assertTrue(hasattr(output, \"indices\"))  # SparseTensor check\n\n    def test_invalid_vocabulary_dtype(self):\n        with 
self.assertRaises(ValueError):\n            layers.IntegerLookup(vocabulary_dtype=\"int32\")\n\n    def test_num_oov_indices_zero(self):\n        layer = layers.IntegerLookup(\n            vocabulary=[1, 2, 3],\n            num_oov_indices=0,\n            output_mode=\"int\",\n        )\n        output = layer([1, 2, 3])\n        self.assertAllClose(output, np.array([0, 1, 2]))\n\n    def test_adapt_with_steps(self):\n        layer = layers.IntegerLookup(output_mode=\"int\")\n        ds = tf_data.Dataset.from_tensor_slices([1, 2, 3, 1, 1]).batch(2)\n        layer.adapt(ds, steps=2)\n        vocab = layer.get_vocabulary()\n        self.assertIn(1, vocab)\n\n    def test_vocabulary_from_file(self):\n        tmp_dir = self.get_temp_dir()\n        vocab_file = os.path.join(tmp_dir, \"vocab.txt\")\n        with open(vocab_file, \"w\") as f:\n            f.write(\"10\\n20\\n30\\n\")\n        layer = layers.IntegerLookup(\n            vocabulary=vocab_file,\n            output_mode=\"int\",\n        )\n        output = layer([10, 20, 30, 999])\n        self.assertAllClose(output, np.array([1, 2, 3, 0]))\n\n    def test_oov_method_farmhash(self):\n        vocab = [12, 36, 1138, 42]\n        layer = layers.IntegerLookup(\n            vocabulary=vocab, num_oov_indices=2, oov_method=\"farmhash\"\n        )\n        data = np.array([12, 36, 1138, 42, 100, 200])\n        output = layer(data)\n        # In-vocab tokens should map correctly (offset by num_oov_indices=2)\n        self.assertAllClose(output[:4], np.array([2, 3, 4, 5]))\n        # OOV tokens should land in [0, num_oov_indices)\n        oov_output = backend.convert_to_numpy(output[4:])\n        self.assertTrue(all(o in [0, 1] for o in oov_output))\n\n    def test_oov_method_invalid_value(self):\n        with self.assertRaises(ValueError):\n            layers.IntegerLookup(\n                vocabulary=[1, 2, 3],\n                num_oov_indices=2,\n                oov_method=\"invalid_method\",\n            )\n\n    
def test_oov_method_ignored_when_single_oov_index(self):\n        # oov_method has no effect when num_oov_indices=1\n        layer_floormod = layers.IntegerLookup(\n            vocabulary=[1, 2, 3], num_oov_indices=1, oov_method=\"floormod\"\n        )\n        layer_farmhash = layers.IntegerLookup(\n            vocabulary=[1, 2, 3], num_oov_indices=1, oov_method=\"farmhash\"\n        )\n        oov_values = [99, 100, 101]\n        out_floormod = backend.convert_to_numpy(layer_floormod(oov_values))\n        out_farmhash = backend.convert_to_numpy(layer_farmhash(oov_values))\n        self.assertAllClose(out_floormod, out_farmhash)\n\n    def test_oov_method_farmhash_output_is_correct(self):\n        # Expected values computed once via:\n        # tf.strings.to_hash_bucket_fast(\n        #     tf.strings.as_string([100, 200, 300, 400]), num_buckets=4\n        # )\n        # FarmHash64 is deterministic\n        layer = layers.IntegerLookup(\n            vocabulary=[10, 20, 30],\n            num_oov_indices=4,\n            oov_method=\"farmhash\",\n        )\n        output = backend.convert_to_numpy(layer([100, 200, 300, 400]))\n        self.assertAllClose(output, [3, 1, 3, 1])\n"
  },
  {
    "path": "keras/src/layers/preprocessing/mel_spectrogram.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.data_layer import DataLayer\n\n# mel spectrum constants.\n_MEL_BREAK_FREQUENCY_HERTZ = 700.0\n_MEL_HIGH_FREQUENCY_Q = 1127.0\n\n\n@keras_export(\"keras.layers.MelSpectrogram\")\nclass MelSpectrogram(DataLayer):\n    \"\"\"A preprocessing layer to convert raw audio signals to Mel spectrograms.\n\n    This layer takes `float32`/`float64` single or batched audio signals as\n    input and computes the Mel spectrogram using Short-Time Fourier Transform\n    and Mel scaling. The input should be a 1D (unbatched) or 2D (batched) tensor\n    representing audio signals. The output will be a 2D or 3D tensor\n    representing Mel spectrograms.\n\n    A spectrogram is an image-like representation that shows the frequency\n    spectrum of a signal over time. It uses the x-axis to represent time, the\n    y-axis to represent frequency, and each pixel to represent intensity.\n    Mel spectrograms are a special type of spectrogram that use the mel scale,\n    which approximates how humans perceive sound. They are commonly used in\n    speech and music processing tasks like speech recognition, speaker\n    identification, and music genre classification.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    References:\n    - [Spectrogram](https://en.wikipedia.org/wiki/Spectrogram),\n    - [Mel scale](https://en.wikipedia.org/wiki/Mel_scale).\n\n    Args:\n        fft_length: Integer, size of the FFT window.\n        sequence_stride: Integer, number of samples between successive STFT\n            columns.\n        sequence_length: Integer, size of the window used for applying\n            `window` to each audio frame. If `None`, defaults to `fft_length`.\n        window: String, name of the window function to use. Available values\n            are `\"hann\"` and `\"hamming\"`. 
If `window` is a tensor, it will be\n            used directly as the window and its length must be\n            `sequence_length`. If `window` is `None`, no windowing is\n            used. Defaults to `\"hann\"`.\n        sampling_rate: Integer, sample rate of the input signal.\n        num_mel_bins: Integer, number of mel bins to generate.\n        min_freq: Float, minimum frequency of the mel bins.\n        max_freq: Float, maximum frequency of the mel bins.\n            If `None`, defaults to `sampling_rate / 2`.\n        power_to_db: If True, convert the power spectrogram to decibels.\n        top_db: Float, minimum negative cut-off `max(10 * log10(S)) - top_db`.\n        mag_exp: Float, exponent for the magnitude spectrogram.\n            1 for magnitude, 2 for power, etc. Default is 2.\n        ref_power: Float, the power is scaled relative to it\n            `10 * log10(S / ref_power)`.\n        min_power: Float, minimum value for power and `ref_power`.\n\n    Examples:\n\n    **Unbatched audio signal**\n\n    >>> layer = keras.layers.MelSpectrogram(num_mel_bins=64,\n    ...                                     sampling_rate=8000,\n    ...                                     sequence_stride=256,\n    ...                                     fft_length=2048)\n    >>> layer(keras.random.uniform(shape=(16000,))).shape\n    (64, 63)\n\n    **Batched audio signal**\n\n    >>> layer = keras.layers.MelSpectrogram(num_mel_bins=80,\n    ...                                     sampling_rate=8000,\n    ...                                     sequence_stride=128,\n    ...                                     
fft_length=2048)\n    >>> layer(keras.random.uniform(shape=(2, 16000))).shape\n    (2, 80, 126)\n\n    Input shape:\n        1D (unbatched) or 2D (batched) tensor with shape: `(..., samples)`.\n\n    Output shape:\n        2D (unbatched) or 3D (batched) tensor with\n        shape: `(..., num_mel_bins, time)`.\n\n    \"\"\"\n\n    def __init__(\n        self,\n        fft_length=2048,\n        sequence_stride=512,\n        sequence_length=None,\n        window=\"hann\",\n        sampling_rate=16000,\n        num_mel_bins=128,\n        min_freq=20.0,\n        max_freq=None,\n        power_to_db=True,\n        top_db=80.0,\n        mag_exp=2.0,\n        min_power=1e-10,\n        ref_power=1.0,\n        **kwargs,\n    ):\n        self.fft_length = fft_length\n        self.sequence_stride = sequence_stride\n        self.sequence_length = sequence_length or fft_length\n        self.window = window\n        self.sampling_rate = sampling_rate\n        self.num_mel_bins = num_mel_bins\n        self.min_freq = min_freq\n        self.max_freq = max_freq or int(sampling_rate / 2)\n        self.power_to_db = power_to_db\n        self.top_db = top_db\n        self.mag_exp = mag_exp\n        self.min_power = min_power\n        self.ref_power = ref_power\n        super().__init__(**kwargs)\n\n    def call(self, inputs):\n        dtype = (\n            \"float32\"\n            if self.compute_dtype not in [\"float32\", \"float64\"]\n            else self.compute_dtype\n        )  # jax and tf support only \"float32\" and \"float64\" in stft\n        inputs = self.backend.convert_to_tensor(inputs, dtype=dtype)\n        outputs = self._spectrogram(inputs)\n        outputs = self._melscale(outputs)\n        if self.power_to_db:\n            outputs = self._dbscale(outputs)\n        # swap time & freq axis to have shape of (..., num_mel_bins, time)\n        outputs = self.backend.numpy.swapaxes(outputs, -1, -2)\n        outputs = self.backend.cast(outputs, self.compute_dtype)\n
        return outputs\n\n    def _spectrogram(self, inputs):\n        real, imag = self.backend.math.stft(\n            inputs,\n            sequence_length=self.sequence_length,\n            sequence_stride=self.sequence_stride,\n            fft_length=self.fft_length,\n            window=self.window,\n            center=True,\n        )\n        # abs of complex = sqrt(real^2 + imag^2)\n        spec = self.backend.numpy.sqrt(\n            self.backend.numpy.add(\n                self.backend.numpy.square(real), self.backend.numpy.square(imag)\n            )\n        )\n        spec = self.backend.numpy.power(spec, self.mag_exp)\n        return spec\n\n    def _melscale(self, inputs):\n        matrix = self.linear_to_mel_weight_matrix(\n            num_mel_bins=self.num_mel_bins,\n            num_spectrogram_bins=self.backend.shape(inputs)[-1],\n            sampling_rate=self.sampling_rate,\n            lower_edge_hertz=self.min_freq,\n            upper_edge_hertz=self.max_freq,\n        )\n        return self.backend.numpy.tensordot(inputs, matrix, axes=1)\n\n    def _dbscale(self, inputs):\n        log_spec = 10.0 * (\n            self.backend.numpy.log10(\n                self.backend.numpy.maximum(inputs, self.min_power)\n            )\n        )\n        ref_value = self.backend.numpy.abs(\n            self.backend.convert_to_tensor(self.ref_power)\n        )\n        log_spec -= 10.0 * self.backend.numpy.log10(\n            self.backend.numpy.maximum(ref_value, self.min_power)\n        )\n        log_spec = self.backend.numpy.maximum(\n            log_spec, self.backend.numpy.max(log_spec) - self.top_db\n        )\n        return log_spec\n\n    def _hertz_to_mel(self, frequencies_hertz):\n        \"\"\"Converts frequencies in `frequencies_hertz` in Hertz to the\n        mel scale.\n\n        Args:\n            frequencies_hertz: A tensor of frequencies in Hertz.\n\n        Returns:\n            A tensor 
of the same shape and type of `frequencies_hertz`\n            containing frequencies in the mel scale.\n        \"\"\"\n        return _MEL_HIGH_FREQUENCY_Q * self.backend.numpy.log(\n            1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ)\n        )\n\n    def linear_to_mel_weight_matrix(\n        self,\n        num_mel_bins=20,\n        num_spectrogram_bins=129,\n        sampling_rate=8000,\n        lower_edge_hertz=125.0,\n        upper_edge_hertz=3800.0,\n        dtype=\"float32\",\n    ):\n        \"\"\"Returns a matrix to warp linear scale spectrograms to the mel scale.\n\n        Returns a weight matrix that can be used to re-weight a tensor\n        containing `num_spectrogram_bins` linearly sampled frequency information\n        from `[0, sampling_rate / 2]` into `num_mel_bins` frequency information\n        from `[lower_edge_hertz, upper_edge_hertz]` on the mel scale.\n\n        This function follows the [Hidden Markov Model Toolkit (HTK)](\n        http://htk.eng.cam.ac.uk/) convention, defining the mel scale in\n        terms of a frequency in hertz according to the following formula:\n\n        ```mel(f) = 2595 * log10( 1 + f/700)```\n\n        In the returned matrix, all the triangles (filterbanks) have a peak\n        value of 1.0.\n\n        For example, the returned matrix `A` can be used to right-multiply a\n        spectrogram `S` of shape `[frames, num_spectrogram_bins]` of linear\n        scale spectrum values (e.g. 
STFT magnitudes) to generate a\n        \"mel spectrogram\" `M` of shape `[frames, num_mel_bins]`.\n\n        ```\n        # `S` has shape [frames, num_spectrogram_bins]\n        # `M` has shape [frames, num_mel_bins]\n        M = keras.ops.matmul(S, A)\n        ```\n\n        The matrix can be used with `keras.ops.tensordot` to convert an\n        arbitrary rank `Tensor` of linear-scale spectral bins into the\n        mel scale.\n\n        ```\n        # S has shape [..., num_spectrogram_bins].\n        # M has shape [..., num_mel_bins].\n        M = keras.ops.tensordot(S, A, 1)\n        ```\n\n        References:\n        - [Mel scale (Wikipedia)](https://en.wikipedia.org/wiki/Mel_scale)\n\n        Args:\n            num_mel_bins: Python int. How many bands in the resulting\n                mel spectrum.\n            num_spectrogram_bins: An integer `Tensor`. How many bins there are\n                in the source spectrogram data, which is understood to be\n                `fft_size // 2 + 1`, i.e. the spectrogram only contains the\n                nonredundant FFT bins.\n            sampling_rate: An integer or float `Tensor`. Samples per second of\n                the input signal used to create the spectrogram. Used to figure\n                out the frequencies corresponding to each spectrogram bin,\n                which dictates how they are mapped into the mel scale.\n            lower_edge_hertz: Python float. Lower bound on the frequencies to be\n                included in the mel spectrum. This corresponds to the lower\n                edge of the lowest triangular band.\n            upper_edge_hertz: Python float. The desired top edge of the highest\n                frequency band.\n            dtype: The `DType` of the result matrix. 
Must be a floating point\n                type.\n\n        Returns:\n            A tensor of shape `[num_spectrogram_bins, num_mel_bins]`.\n        \"\"\"\n\n        # This function can be constant folded by graph optimization since\n        # there are no Tensor inputs.\n        sampling_rate = self.backend.cast(sampling_rate, dtype)\n        lower_edge_hertz = self.backend.convert_to_tensor(\n            lower_edge_hertz,\n            dtype,\n        )\n        upper_edge_hertz = self.backend.convert_to_tensor(\n            upper_edge_hertz,\n            dtype,\n        )\n        zero = self.backend.convert_to_tensor(0.0, dtype)\n\n        # HTK excludes the spectrogram DC bin.\n        bands_to_zero = 1\n        nyquist_hertz = sampling_rate / 2.0\n        linear_frequencies = self.backend.numpy.linspace(\n            zero, nyquist_hertz, num_spectrogram_bins\n        )[bands_to_zero:]\n        spectrogram_bins_mel = self.backend.numpy.expand_dims(\n            self._hertz_to_mel(linear_frequencies), 1\n        )\n\n        # Compute num_mel_bins triples of (lower_edge, center, upper_edge). 
The\n        # center of each band is the lower and upper edge of the adjacent bands.\n        # Accordingly, we divide [lower_edge_hertz, upper_edge_hertz] into\n        # num_mel_bins + 2 pieces.\n        band_edges_mel = self.backend.math.extract_sequences(\n            self.backend.numpy.linspace(\n                self._hertz_to_mel(lower_edge_hertz),\n                self._hertz_to_mel(upper_edge_hertz),\n                num_mel_bins + 2,\n            ),\n            sequence_length=3,\n            sequence_stride=1,\n        )\n\n        # Split the triples up and reshape them into [1, num_mel_bins] tensors.\n        lower_edge_mel, center_mel, upper_edge_mel = tuple(\n            self.backend.numpy.reshape(t, [1, num_mel_bins])\n            for t in self.backend.numpy.split(band_edges_mel, 3, axis=1)\n        )\n\n        # Calculate lower and upper slopes for every spectrogram bin.\n        # Line segments are linear in the mel domain, not Hertz.\n        lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / (\n            center_mel - lower_edge_mel\n        )\n        upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / (\n            upper_edge_mel - center_mel\n        )\n\n        # Intersect the line segments with each other and zero.\n        mel_weights_matrix = self.backend.numpy.maximum(\n            zero, self.backend.numpy.minimum(lower_slopes, upper_slopes)\n        )\n\n        # Re-add the zeroed lower bins we sliced out above.\n        return self.backend.numpy.pad(\n            mel_weights_matrix,\n            [[bands_to_zero, 0], [0, 0]],\n        )\n\n    def compute_output_shape(self, input_shape):\n        if len(input_shape) == 1:\n            output_shape = [\n                self.num_mel_bins,\n                (\n                    (input_shape[0] + self.sequence_stride + 1)\n                    // self.sequence_stride\n                    if input_shape[0] is not None\n                    else None\n                ),\n      
      ]\n        else:\n            output_shape = [\n                input_shape[0],\n                self.num_mel_bins,\n                (\n                    (input_shape[1] + self.sequence_stride + 1)\n                    // self.sequence_stride\n                    if input_shape[1] is not None\n                    else None\n                ),\n            ]\n        return output_shape\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"fft_length\": self.fft_length,\n                \"sequence_stride\": self.sequence_stride,\n                \"sequence_length\": self.sequence_length,\n                \"window\": self.window,\n                \"sampling_rate\": self.sampling_rate,\n                \"num_mel_bins\": self.num_mel_bins,\n                \"min_freq\": self.min_freq,\n                \"max_freq\": self.max_freq,\n                \"power_to_db\": self.power_to_db,\n                \"top_db\": self.top_db,\n                \"mag_exp\": self.mag_exp,\n                \"min_power\": self.min_power,\n                \"ref_power\": self.ref_power,\n            }\n        )\n        return config\n"
  },
  {
    "path": "keras/src/layers/preprocessing/mel_spectrogram_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass MelSpectrogramTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_mel_spectrogram_basics(self):\n        self.run_layer_test(\n            layers.MelSpectrogram,\n            init_kwargs={\n                \"num_mel_bins\": 80,\n                \"sampling_rate\": 8000,\n                \"sequence_stride\": 128,\n                \"fft_length\": 2048,\n            },\n            input_shape=(2, 16000),\n            expected_output_shape=(2, 80, 126),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n        self.run_layer_test(\n            layers.MelSpectrogram,\n            init_kwargs={\n                \"num_mel_bins\": 80,\n                \"sampling_rate\": 8000,\n                \"sequence_stride\": 128,\n                \"fft_length\": 2048,\n            },\n            input_shape=(16000,),\n            expected_output_shape=(80, 126),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    @parameterized.parameters(\n        [\n            ((2, 16000), 80, 128, 2048, 8000, False),\n            ((16000,), 80, 128, 2048, 8000, False),\n            ((2, 16001), 80, 128, 2048, 16000, False),\n            ((16001,), 80, 128, 2048, 8000, False),\n            ((2, 8000), 128, 64, 512, 32000, False),\n            ((8000,), 128, 64, 512, 32000, False),\n            ((2, 8000), 128, 64, 512, 32000, True),\n            ((8000,), 128, 64, 512, 32000, True),\n        ]\n    )\n    def 
test_output_shape(\n        self,\n        input_shape,\n        num_mel_bins,\n        sequence_stride,\n        fft_length,\n        sampling_rate,\n        all_zero,\n    ):\n        if all_zero:\n            audios = np.zeros(input_shape)\n        else:\n            audios = np.random.random(input_shape)\n        out = layers.MelSpectrogram(\n            num_mel_bins=num_mel_bins,\n            sequence_stride=sequence_stride,\n            fft_length=fft_length,\n            sampling_rate=sampling_rate,\n        )(audios)\n        if len(input_shape) == 1:\n            ref_shape = (\n                num_mel_bins,\n                (input_shape[0] + sequence_stride + 1) // sequence_stride,\n            )\n        else:\n            ref_shape = (\n                input_shape[0],\n                num_mel_bins,\n                (input_shape[1] + sequence_stride + 1) // sequence_stride,\n            )\n        self.assertEqual(tuple(out.shape), ref_shape)\n\n    def test_tf_data_compatibility(self):\n        input_shape = (2, 16000)\n        output_shape = (2, 80, 126)\n        layer = layers.MelSpectrogram(\n            num_mel_bins=80,\n            sampling_rate=8000,\n            sequence_stride=128,\n            fft_length=2048,\n        )\n        input_data = np.random.random(input_shape)\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output = output.numpy()\n        self.assertEqual(tuple(output.shape), output_shape)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/normalization.py",
    "content": "import itertools\nimport math\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.data_layer import DataLayer\nfrom keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\ndef _extract_batch(batch):\n    \"\"\"Return input from batch; handle (x, y) or (x, y, sample_weight).\"\"\"\n    if isinstance(batch, tuple):\n        return batch[0]\n    return batch\n\n\n@keras_export(\"keras.layers.Normalization\")\nclass Normalization(DataLayer):\n    \"\"\"A preprocessing layer that normalizes continuous features.\n\n    This layer will shift and scale inputs into a distribution centered around\n    0 with standard deviation 1. It accomplishes this by precomputing the mean\n    and variance of the data, and calling `(input - mean) / sqrt(var)` at\n    runtime.\n\n    The mean and variance values for the layer must be either supplied on\n    construction or learned via `adapt()`. `adapt()` will compute the mean and\n    variance of the data and store them as the layer's weights. `adapt()` should\n    be called before `fit()`, `evaluate()`, or `predict()`.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        axis: Integer, tuple of integers, or None. 
The axis or axes that should\n            have a separate mean and variance for each index in the shape.\n            For example, if shape is `(None, 5)` and `axis=1`, the layer will\n            track 5 separate mean and variance values for the last axis.\n            If `axis` is set to `None`, the layer will normalize\n            all elements in the input by a scalar mean and variance.\n            When `-1`, the last axis of the input is assumed to be a\n            feature dimension and is normalized per index.\n            Note that in the specific case of batched scalar inputs where\n            the only axis is the batch axis, the default will normalize\n            each index in the batch separately.\n            In this case, consider passing `axis=None`. Defaults to `-1`.\n        mean: The mean value(s) to use during normalization. The passed value(s)\n            will be broadcast to the shape of the kept axes above;\n            if the value(s) cannot be broadcast, an error will be raised when\n            this layer's `build()` method is called.\n            `mean` and `variance` must be specified together.\n        variance: The variance value(s) to use during normalization. 
The passed\n            value(s) will be broadcast to the shape of the kept axes above;\n            if the value(s) cannot be broadcast, an error will be raised when\n            this layer's `build()` method is called.\n            `mean` and `variance` must be specified together.\n        invert: If `True`, this layer will apply the inverse transformation\n            to its inputs: it would turn a normalized input back into its\n            original form.\n\n    Examples:\n\n    Calculate a global mean and variance by analyzing the dataset in `adapt()`.\n\n    >>> adapt_data = np.array([1., 2., 3., 4., 5.], dtype='float32')\n    >>> input_data = np.array([1., 2., 3.], dtype='float32')\n    >>> layer = keras.layers.Normalization(axis=None)\n    >>> layer.adapt(adapt_data)\n    >>> layer(input_data)\n    array([-1.4142135, -0.70710677, 0.], dtype=float32)\n\n    Calculate a mean and variance for each index on the last axis.\n\n    >>> adapt_data = np.array([[0., 7., 4.],\n    ...                        [2., 9., 6.],\n    ...                        [0., 7., 4.],\n    ...                        [2., 9., 6.]], dtype='float32')\n    >>> input_data = np.array([[0., 7., 4.]], dtype='float32')\n    >>> layer = keras.layers.Normalization(axis=-1)\n    >>> layer.adapt(adapt_data)\n    >>> layer(input_data)\n    array([[-1., -1., -1.]], dtype=float32)\n\n    Pass the mean and variance directly.\n\n    >>> input_data = np.array([[1.], [2.], [3.]], dtype='float32')\n    >>> layer = keras.layers.Normalization(mean=3., variance=2.)\n    >>> layer(input_data)\n    array([[-1.4142135 ],\n           [-0.70710677],\n           [ 0.        ]], dtype=float32)\n\n    Use the layer to de-normalize inputs (after adapting the layer).\n\n    >>> adapt_data = np.array([[0., 7., 4.],\n    ...                        [2., 9., 6.],\n    ...                        [0., 7., 4.],\n    ...                        [2., 9., 6.]], dtype='float32')\n
    >>> input_data = np.array([[1., 2., 3.]], dtype='float32')\n    >>> layer = keras.layers.Normalization(axis=-1, invert=True)\n    >>> layer.adapt(adapt_data)\n    >>> layer(input_data)\n    array([[2., 10., 8.]], dtype=float32)\n    \"\"\"\n\n    def __init__(\n        self, axis=-1, mean=None, variance=None, invert=False, **kwargs\n    ):\n        super().__init__(**kwargs)\n        # Standardize `axis` to a tuple.\n        if axis is None:\n            axis = ()\n        elif isinstance(axis, int):\n            axis = (axis,)\n        else:\n            axis = tuple(axis)\n        self.axis = axis\n\n        self.input_mean = mean\n        self.input_variance = variance\n        self.invert = invert\n        self.supports_masking = True\n        self._build_input_shape = None\n        self.mean = None\n\n        # Set `mean` and `variance` if passed.\n        if (mean is not None) != (variance is not None):\n            raise ValueError(\n                \"When setting values directly, both `mean` and `variance` \"\n                f\"must be set. Received: mean={mean} and variance={variance}\"\n            )\n        if mean is not None:\n            # Verify mean and variance have the same shape.\n            if np.shape(mean) != np.shape(variance):\n                raise ValueError(\n                    \"When setting values directly, `mean` and `variance` \"\n                    \"must have the same shape. Received: \"\n                    f\"mean shape {np.shape(mean)} and \"\n                    f\"variance shape {np.shape(variance)}\"\n                )\n            # Verify mean rank <= number of axes.\n            if len(np.shape(mean)) > len(self.axis):\n                raise ValueError(\n                    \"The rank of `mean` must be less than or equal to the \"\n                    f\"number of axes ({len(self.axis)}). 
Received: \"\n                    f\"mean shape {np.shape(mean)} for axis {self.axis}\"\n                )\n\n    def build(self, input_shape):\n        if input_shape is None:\n            return\n\n        ndim = len(input_shape)\n        self._build_input_shape = input_shape\n\n        if any(a < -ndim or a >= ndim for a in self.axis):\n            raise ValueError(\n                \"All `axis` values must be in the range [-ndim, ndim). \"\n                f\"Received inputs with ndim={ndim}, while axis={self.axis}\"\n            )\n\n        # Axes to be kept, replacing negative values with positive equivalents.\n        # Sorted to avoid transposing axes.\n        self._keep_axis = tuple(\n            sorted([d if d >= 0 else d + ndim for d in self.axis])\n        )\n        # All axes to be kept should have known shape.\n        for d in self._keep_axis:\n            if input_shape[d] is None:\n                raise ValueError(\n                    \"All `axis` values to be kept must have a known shape. 
\"\n                    f\"Received axis={self.axis}, \"\n                    f\"inputs.shape={input_shape}, \"\n                    f\"with unknown axis at index {d}\"\n                )\n        # Axes to be reduced.\n        self._reduce_axis = tuple(\n            d for d in range(ndim) if d not in self._keep_axis\n        )\n        # 1 if an axis should be reduced, 0 otherwise.\n        self._reduce_axis_mask = [\n            0 if d in self._keep_axis else 1 for d in range(ndim)\n        ]\n        # Broadcast any reduced axes.\n        self._broadcast_shape = [\n            input_shape[d] if d in self._keep_axis else 1 for d in range(ndim)\n        ]\n        mean_and_var_shape = tuple(input_shape[d] for d in self._keep_axis)\n        self._mean_and_var_shape = mean_and_var_shape\n\n        if self.input_mean is None:\n            self.adapt_mean = self.add_weight(\n                name=\"mean\",\n                shape=mean_and_var_shape,\n                initializer=\"zeros\",\n                trainable=False,\n            )\n            self.adapt_variance = self.add_weight(\n                name=\"variance\",\n                shape=mean_and_var_shape,\n                initializer=\"ones\",\n                trainable=False,\n            )\n            # For backwards compatibility with older saved models.\n            self.count = self.add_weight(\n                name=\"count\",\n                shape=(),\n                dtype=\"int\",\n                initializer=\"zeros\",\n                trainable=False,\n            )\n            self.built = True\n            self.finalize_state()\n\n        else:\n            mean = ops.convert_to_tensor(self.input_mean)\n            variance = ops.convert_to_tensor(self.input_variance)\n\n            if ops.ndim(mean) == 0:\n                # Case 1: Scalar mean/variance\n                mean = ops.broadcast_to(mean, self._broadcast_shape)\n                variance = ops.broadcast_to(variance, 
self._broadcast_shape)\n            else:\n                # Case 2: General broadcasting. Align mean/variance dims\n                # to the kept axes from right to left.\n                expanded_shape = [1] * ndim\n                mean_shape = ops.shape(mean)\n                mean_ndim = ops.ndim(mean)\n\n                # Map mean dimensions to the correct kept axes (right-to-left).\n                # This handles cases where mean has fewer dims than keep_axis.\n                for i in range(1, mean_ndim + 1):\n                    axis_idx = self._keep_axis[-i]\n                    expanded_shape[axis_idx] = mean_shape[-i]\n\n                mean = ops.reshape(mean, expanded_shape)\n                variance = ops.reshape(variance, expanded_shape)\n\n                # Broadcast to the full target shape.\n                mean = ops.broadcast_to(mean, self._broadcast_shape)\n                variance = ops.broadcast_to(variance, self._broadcast_shape)\n\n            self.mean = ops.cast(mean, dtype=self.compute_dtype)\n            self.variance = ops.cast(variance, dtype=self.compute_dtype)\n            self.built = True\n\n    def adapt(self, data):\n        \"\"\"Computes the mean and variance of values in a dataset.\n\n        Calling `adapt()` on a `Normalization` layer is an alternative to\n        passing in `mean` and `variance` arguments during layer construction. A\n        `Normalization` layer should always either be adapted over a dataset or\n        passed `mean` and `variance`.\n\n        During `adapt()`, the layer will compute a `mean` and `variance`\n        separately for each position in each axis specified by the `axis`\n        argument. To calculate a single `mean` and `variance` over the input\n        data, simply pass `axis=None` to the layer.\n\n        Args:\n            data: The data to train on. 
It can be passed as a NumPy array, a\n                backend-native eager tensor, a `tf.data.Dataset`, a\n                `keras.utils.PyDataset`, or an iterable of batches (e.g. a\n                list of arrays or a generator yielding batches). If a dataset\n                or iterable, *it must be batched*. Keras will assume that each\n                element is a batch, and if that assumption doesn't hold, the\n                mean and variance may be incorrectly computed.\n        \"\"\"\n        data_is_iterable = False\n        if isinstance(data, np.ndarray) or backend.is_tensor(data):\n            input_shape = data.shape\n        elif isinstance(data, tf.data.Dataset):\n\n            def get_input_shape(d):\n                element_spec = d.element_spec\n                x_spec = (\n                    element_spec[0]\n                    if isinstance(element_spec, tuple)\n                    else element_spec\n                )\n                return tuple(x_spec.shape)\n\n            input_shape = get_input_shape(data)\n            if len(input_shape) == 1:\n                data = data.batch(128)\n                input_shape = get_input_shape(data)\n        elif isinstance(data, PyDataset):\n            input_shape = _extract_batch(data[0]).shape\n        elif hasattr(data, \"__iter__\"):\n            data_is_iterable = True\n            # Consume first batch to infer input_shape; then chain it back for\n            # accumulation so we iterate over (first_batch, *rest).\n            data_iter = iter(data)\n            first_batch = next(data_iter, None)\n            if first_batch is None:\n                raise ValueError(\n                    \"adapt() received an empty iterable (no batches). \"\n                    \"Expected at least one batch. Pass a non-empty iterable \"\n                    \"of arrays or tensors, e.g. 
layer.adapt([x]) or \"\n                    \"layer.adapt(list_of_batches).\"\n                )\n            first_batch = _extract_batch(first_batch)\n            input_shape = getattr(first_batch, \"shape\", None)\n            if input_shape is None:\n                raise TypeError(\n                    \"adapt() expects an iterable that yields arrays or \"\n                    \"tensors with a `.shape` attribute (e.g. numpy arrays or \"\n                    \"backend tensors). Got an element of type \"\n                    f\"{type(first_batch).__name__}. Ensure each yielded \"\n                    \"element is array-like with a `.shape` attribute.\"\n                )\n            input_shape = tuple(input_shape)\n            data = itertools.chain([first_batch], data_iter)\n        else:\n            raise TypeError(\n                f\"Unsupported data type: {type(data)}. `adapt` supports \"\n                f\"`np.ndarray`, backend tensors, `tf.data.Dataset`, \"\n                f\"`keras.utils.PyDataset`, and iterables of batches (e.g. 
\"\n                f\"list, generator).\"\n            )\n\n        if not self.built:\n            self.build(input_shape)\n        else:\n            for d in self._keep_axis:\n                if input_shape[d] != self._build_input_shape[d]:\n                    raise ValueError(\n                        \"The layer was built with \"\n                        f\"input_shape={self._build_input_shape}, \"\n                        \"but adapt() is being called with data with \"\n                        f\"an incompatible shape, data.shape={input_shape}\"\n                    )\n\n        if isinstance(data, np.ndarray):\n            total_mean = np.mean(data, axis=self._reduce_axis)\n            total_var = np.var(data, axis=self._reduce_axis)\n        elif backend.is_tensor(data):\n            total_mean = ops.mean(data, axis=self._reduce_axis)\n            total_var = ops.var(data, axis=self._reduce_axis)\n        elif isinstance(data, (tf.data.Dataset, PyDataset)) or data_is_iterable:\n            total_mean = ops.zeros(self._mean_and_var_shape)\n            total_var = ops.zeros(self._mean_and_var_shape)\n            total_count = 0\n            for batch in data:\n                batch = _extract_batch(batch)\n                batch = backend.convert_to_tensor(\n                    batch, dtype=self.compute_dtype\n                )\n                for d in self._keep_axis:\n                    batch_dim = batch.shape[d]\n                    expected = self._build_input_shape[d]\n                    if (\n                        batch_dim is not None\n                        and expected is not None\n                        and batch_dim != expected\n                    ):\n                        raise ValueError(\n                            \"adapt() yielded a batch with incompatible \"\n                            \"shape. 
Expected \"\n                            f\"{self._build_input_shape}, got \"\n                            f\"{tuple(batch.shape)}.\"\n                        )\n                batch_mean = ops.mean(batch, axis=self._reduce_axis)\n                batch_var = ops.var(batch, axis=self._reduce_axis)\n                if self._reduce_axis:\n                    batch_reduce_shape = (\n                        batch.shape[d] for d in self._reduce_axis\n                    )\n                    batch_count = math.prod(batch_reduce_shape)\n                else:\n                    batch_count = 1\n\n                total_count += batch_count\n                batch_weight = float(batch_count) / total_count\n                existing_weight = 1.0 - batch_weight\n                new_total_mean = (\n                    total_mean * existing_weight + batch_mean * batch_weight\n                )\n                # The variance is computed using the lack-of-fit sum of squares\n                # formula (see\n                # https://en.wikipedia.org/wiki/Lack-of-fit_sum_of_squares).\n                total_var = (\n                    total_var + (total_mean - new_total_mean) ** 2\n                ) * existing_weight + (\n                    batch_var + (batch_mean - new_total_mean) ** 2\n                ) * batch_weight\n                total_mean = new_total_mean\n        else:\n            raise NotImplementedError(f\"Unsupported data type: {type(data)}\")\n\n        self.adapt_mean.assign(total_mean)\n        self.adapt_variance.assign(total_var)\n        self.finalize_state()\n\n    def finalize_state(self):\n        if self.input_mean is not None or not self.built:\n            return\n\n        # In the adapt case, we make constant tensors for mean and variance with\n        # proper broadcast shape and dtype each time `finalize_state` is called.\n        self.mean = ops.reshape(self.adapt_mean, self._broadcast_shape)\n        self.mean = ops.cast(self.mean, 
self.compute_dtype)\n        self.variance = ops.reshape(self.adapt_variance, self._broadcast_shape)\n        self.variance = ops.cast(self.variance, self.compute_dtype)\n\n    def call(self, inputs):\n        # This layer can be called in tf.data\n        # even with another backend after it has been adapted.\n        # However, adapt() itself must use backend-native logic.\n        if self.mean is None:\n            # May happen in tf.data when mean/var was passed explicitly\n            # but build() was never called.\n            raise ValueError(\n                \"You must call `.build(input_shape)` \"\n                \"on the layer before using it.\"\n            )\n        inputs = self.backend.core.convert_to_tensor(\n            inputs, dtype=self.compute_dtype\n        )\n        # Ensure the weights are in the correct backend. Without this, it is\n        # possible to cause breakage when using this layer in tf.data.\n        mean = self.convert_weight(self.mean)\n        variance = self.convert_weight(self.variance)\n        if self.invert:\n            return self.backend.numpy.add(\n                mean,\n                self.backend.numpy.multiply(\n                    inputs,\n                    self.backend.numpy.maximum(\n                        self.backend.numpy.sqrt(variance), backend.epsilon()\n                    ),\n                ),\n            )\n        else:\n            return self.backend.numpy.divide(\n                self.backend.numpy.subtract(inputs, mean),\n                self.backend.numpy.maximum(\n                    self.backend.numpy.sqrt(variance), backend.epsilon()\n                ),\n            )\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"axis\": self.axis,\n                \"invert\": self.invert,\n                \"mean\": np.array(self.input_mean).tolist(),\n                \"variance\": 
np.array(self.input_variance).tolist(),\n            }\n        )\n        return config\n\n    def load_own_variables(self, store):\n        super().load_own_variables(store)\n        # Ensure that we call finalize_state after variable loading.\n        self.finalize_state()\n\n    def get_build_config(self):\n        if self._build_input_shape:\n            return {\"input_shape\": self._build_input_shape}\n\n    def build_from_config(self, config):\n        if config:\n            self.build(config[\"input_shape\"])\n"
  },
  {
    "path": "keras/src/layers/preprocessing/normalization_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset\n\n\nclass NormalizationTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_normalization_basics(self):\n        self.run_layer_test(\n            layers.Normalization,\n            init_kwargs={\n                \"axis\": -1,\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=3,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.Normalization,\n            init_kwargs={\n                \"axis\": -1,\n                \"mean\": np.array([0.5, 0.2, -0.1]),\n                \"variance\": np.array([0.1, 0.2, 0.3]),\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.Normalization,\n            init_kwargs={\n                \"axis\": -1,\n                \"mean\": np.array([0.5, 0.2, -0.1]),\n                \"variance\": np.array([0.1, 0.2, 0.3]),\n                \"invert\": True,\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            
supports_masking=True,\n        )\n\n    @parameterized.parameters([(\"np\",), (\"tensor\",), (\"tf.data\",)])\n    def test_normalization_adapt(self, input_type):\n        x = np.random.random((32, 4))\n        if input_type == \"np\":\n            data = x\n        elif input_type == \"tensor\":\n            data = backend.convert_to_tensor(x)\n        elif input_type == \"tf.data\":\n            data = tf_data.Dataset.from_tensor_slices(x).batch(8)\n        else:\n            raise NotImplementedError(input_type)\n\n        layer = layers.Normalization()\n        layer.adapt(data)\n        self.assertTrue(layer.built)\n        output = layer(x)\n        output = backend.convert_to_numpy(output)\n        self.assertAllClose(np.var(output, axis=0), 1.0, atol=1e-5)\n        self.assertAllClose(np.mean(output, axis=0), 0.0, atol=1e-5)\n\n        # Test in high-dim and with tuple axis.\n        x = np.random.random((32, 4, 3, 5))\n        if input_type == \"np\":\n            data = x\n        elif input_type == \"tensor\":\n            data = backend.convert_to_tensor(x)\n        elif input_type == \"tf.data\":\n            data = tf_data.Dataset.from_tensor_slices(x).batch(8)\n\n        layer = layers.Normalization(axis=(1, 2))\n        layer.adapt(data)\n        self.assertTrue(layer.built)\n        output = layer(x)\n        output = backend.convert_to_numpy(output)\n        self.assertAllClose(np.var(output, axis=(0, 3)), 1.0, atol=1e-5)\n        self.assertAllClose(np.mean(output, axis=(0, 3)), 0.0, atol=1e-5)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"torch\",\n        reason=\"Test symbolic call for torch meta device.\",\n    )\n    def test_call_on_meta_device_after_built(self):\n        layer = layers.Normalization()\n        data = np.random.random((32, 4))\n        layer.adapt(data)\n        with backend.device(\"meta\"):\n            layer(data)\n\n    def test_normalization_with_mean_only_raises_error(self):\n        # Test error when 
only `mean` is provided\n        with self.assertRaisesRegex(\n            ValueError, \"both `mean` and `variance` must be set\"\n        ):\n            layers.Normalization(mean=0.5)\n\n    def test_normalization_with_variance_only_raises_error(self):\n        # Test error when only `variance` is provided\n        with self.assertRaisesRegex(\n            ValueError, \"both `mean` and `variance` must be set\"\n        ):\n            layers.Normalization(variance=0.1)\n\n    def test_normalization_axis_too_high(self):\n        with self.assertRaisesRegex(\n            ValueError, \"All `axis` values must be in the range\"\n        ):\n            layer = layers.Normalization(axis=3)\n            layer.build((2, 2))\n\n    def test_normalization_axis_too_low(self):\n        with self.assertRaisesRegex(\n            ValueError, \"All `axis` values must be in the range\"\n        ):\n            layer = layers.Normalization(axis=-4)\n            layer.build((2, 3, 4))\n\n    def test_normalization_unknown_axis_shape(self):\n        with self.assertRaisesRegex(ValueError, \"All `axis` values to be kept\"):\n            layer = layers.Normalization(axis=1)\n            layer.build((None, None))\n\n    def test_normalization_adapt_with_incompatible_shape(self):\n        layer = layers.Normalization(axis=-1)\n        initial_shape = (10, 5)\n        layer.build(initial_shape)\n        new_shape_data = np.random.random((10, 3))\n        with self.assertRaisesRegex(ValueError, \"an incompatible shape\"):\n            layer.adapt(new_shape_data)\n\n    def test_tf_data_compatibility(self):\n        x = np.random.random((32, 3))\n        ds = tf_data.Dataset.from_tensor_slices(x).batch(1)\n\n        # With built-in values\n        layer = layers.Normalization(\n            mean=[0.1, 0.2, 0.3], variance=[0.1, 0.2, 0.3], axis=-1\n        )\n        layer.build((None, 3))\n        for output in ds.map(layer).take(1):\n            output.numpy()\n\n        # With adapt flow\n 
       layer = layers.Normalization(axis=-1)\n        layer.adapt(\n            np.random.random((32, 3)),\n        )\n        for output in ds.map(layer).take(1):\n            output.numpy()\n\n    def test_normalization_with_scalar_mean_var(self):\n        input_data = np.array([[1, 2, 3]], dtype=\"float32\")\n        layer = layers.Normalization(mean=3.0, variance=2.0)\n        layer(input_data)\n\n    @parameterized.parameters([(\"x\",), (\"x_and_y\",), (\"x_y_and_weights\",)])\n    def test_adapt_pydataset_compat(self, pydataset_type):\n        import keras\n\n        class CustomDataset(PyDataset):\n            def __len__(self):\n                return 100\n\n            def __getitem__(self, idx):\n                x = np.random.rand(32, 32, 3)\n                y = np.random.randint(0, 10, size=(1,))\n                weights = np.random.randint(0, 10, size=(1,))\n                if pydataset_type == \"x\":\n                    return x\n                elif pydataset_type == \"x_and_y\":\n                    return x, y\n                elif pydataset_type == \"x_y_and_weights\":\n                    return x, y, weights\n                else:\n                    raise NotImplementedError(pydataset_type)\n\n        normalizer = keras.layers.Normalization()\n        normalizer.adapt(CustomDataset())\n        self.assertTrue(normalizer.built)\n        self.assertIsNotNone(normalizer.mean)\n        self.assertIsNotNone(normalizer.variance)\n        self.assertEqual(normalizer.mean.shape[-1], 3)\n        self.assertEqual(normalizer.variance.shape[-1], 3)\n        sample_input = np.random.rand(1, 32, 32, 3)\n        output = normalizer(sample_input)\n        self.assertEqual(output.shape, (1, 32, 32, 3))\n\n    def test_broadcast_non_scalar_middle_axis(self):\n        \"\"\"\n        Tests mean/variance that are not scalars and require\n        expanding dims on non-kept axes (the 'general case').\n        \"\"\"\n        # (Batch=2, Height=4, Width=5, 
Channels=3)\n        input_shape = (2, 4, 5, 3)\n        # We want to normalize only across the 'Width' (axis 2)\n        axis = 2\n        custom_mean = np.arange(1, 6, dtype=\"float32\")  # shape (5,)\n        custom_var = np.ones((5,), dtype=\"float32\")\n        layer = layers.Normalization(\n            axis=axis, mean=custom_mean, variance=custom_var\n        )\n        layer.build(input_shape)\n\n        # The expected broadcast shape should be (1, 1, 5, 1)\n        self.assertEqual(tuple(layer.mean.shape), (1, 1, 5, 1))\n        self.assertAllClose(layer.mean[0, 0, :, 0], custom_mean)\n\n    def test_broadcast_multiple_axes(self):\n        \"\"\"\n        Tests keeping multiple axes, e.g., (Height, Width) but not Channels.\n        \"\"\"\n        # Batch=None, Height=10, Width=13, Channels=3\n        input_shape = (None, 10, 13, 3)\n        axis = (1, 2)\n\n        custom_mean = np.zeros((10, 13), dtype=\"float32\")\n        custom_var = np.ones((10, 13), dtype=\"float32\")\n\n        layer = layers.Normalization(\n            axis=axis, mean=custom_mean, variance=custom_var\n        )\n        layer.build(input_shape)\n\n        # The expected broadcast shape should be (1, 10, 13, 1)\n        self.assertEqual(tuple(layer.mean.shape), (1, 10, 13, 1))\n\n    def test_broadcast_partial_keep_axis(self):\n        \"\"\"\n        Test mean has fewer dims than kept axes (right-to-left alignment).\n\n        This covers the case where axis=(1, 2) but mean is 1D, meaning it\n        should align with axis 2 and broadcast across axis 1.\n        \"\"\"\n        # Batch=2, H=7, W=5, C=3\n        input_shape = (2, 7, 5, 3)\n        axis = (1, 2)\n\n        custom_mean = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=\"float32\")\n        custom_var = np.ones((5,), dtype=\"float32\")\n\n        layer = layers.Normalization(\n            axis=axis, mean=custom_mean, variance=custom_var\n        )\n        layer.build(input_shape)\n        
self.assertEqual(tuple(layer.mean.shape), (1, 7, 5, 1))\n\n        # Verify alignment and broadcasting\n        expected_values = np.reshape(custom_mean, (1, 1, 5, 1))\n\n        self.assertAllClose(layer.mean[:, 0:1, :, :], expected_values)\n        self.assertAllClose(layer.mean[:, 6:7, :, :], expected_values)\n\n    def test_scalar_broadcast(self):\n        \"\"\"\n        Ensures the scalar case still broadcasts to the full rank.\n        \"\"\"\n        input_shape = (3, 7)  # (Batch=3, Features=7)\n        layer = layers.Normalization(axis=-1, mean=5.0, variance=1.0)\n        layer.build(input_shape)\n\n        # The expected broadcast shape should be (1, 7)\n        self.assertEqual(tuple(layer.mean.shape), (1, 7))\n        self.assertAllClose(layer.mean, [[5.0] * 7])\n\n    def test_adapt_list_of_batches(self):\n        x = np.random.random((32, 4)).astype(\"float32\")\n        batches = [x[:8], x[8:16], x[16:24], x[24:32]]\n        layer = layers.Normalization(axis=-1)\n        layer.adapt(batches)\n        self.assertTrue(layer.built)\n        output = layer(x)\n        output = backend.convert_to_numpy(output)\n        self.assertAllClose(np.var(output, axis=0), 1.0, atol=1e-5)\n        self.assertAllClose(np.mean(output, axis=0), 0.0, atol=1e-5)\n\n    def test_adapt_generator(self):\n        x = np.random.random((32, 4)).astype(\"float32\")\n\n        def batch_gen():\n            for i in range(0, 32, 8):\n                yield x[i : i + 8]\n\n        layer = layers.Normalization(axis=-1)\n        layer.adapt(batch_gen())\n        self.assertTrue(layer.built)\n        output = layer(x)\n        output = backend.convert_to_numpy(output)\n        self.assertAllClose(np.var(output, axis=0), 1.0, atol=1e-5)\n        self.assertAllClose(np.mean(output, axis=0), 0.0, atol=1e-5)\n\n    def test_adapt_iterable_same_result_as_ndarray(self):\n        x = np.random.random((64, 5)).astype(\"float32\")\n        list_of_batches = [x[i : i + 16] for i in range(0, 
64, 16)]\n        layer_list = layers.Normalization(axis=-1)\n        layer_list.adapt(list_of_batches)\n        layer_ndarray = layers.Normalization(axis=-1)\n        layer_ndarray.adapt(x)\n        out_list = layer_list(x[:10])\n        out_ndarray = layer_ndarray(x[:10])\n        out_list = backend.convert_to_numpy(out_list)\n        out_ndarray = backend.convert_to_numpy(out_ndarray)\n        self.assertAllClose(out_list, out_ndarray, atol=1e-5)\n\n    def test_adapt_iterable_with_tuples(self):\n        x = np.random.random((24, 3)).astype(\"float32\")\n        batches = [(x[i : i + 8], np.zeros(8)) for i in range(0, 24, 8)]\n        layer = layers.Normalization(axis=-1)\n        layer.adapt(batches)\n        self.assertTrue(layer.built)\n        output = layer(x)\n        output = backend.convert_to_numpy(output)\n        self.assertAllClose(np.var(output, axis=0), 1.0, atol=1e-5)\n        self.assertAllClose(np.mean(output, axis=0), 0.0, atol=1e-5)\n\n    def test_adapt_iterable_axis_none(self):\n        x = np.random.random((20, 2, 3)).astype(\"float32\")\n        batches = [x[i : i + 5] for i in range(0, 20, 5)]\n        layer = layers.Normalization(axis=None)\n        layer.adapt(batches)\n        self.assertTrue(layer.built)\n        output = layer(x)\n        output = backend.convert_to_numpy(output)\n        self.assertAllClose(np.var(output), 1.0, atol=1e-5)\n        self.assertAllClose(np.mean(output), 0.0, atol=1e-5)\n\n    def test_adapt_empty_iterable_raises(self):\n        layer = layers.Normalization(axis=-1)\n        with self.assertRaisesRegex(ValueError, \"empty iterable\"):\n            layer.adapt([])\n\n    def test_adapt_empty_generator_raises(self):\n        layer = layers.Normalization(axis=-1)\n\n        def empty_gen():\n            yield from ()\n\n        with self.assertRaisesRegex(ValueError, \"empty iterable\"):\n            layer.adapt(empty_gen())\n\n    def test_adapt_iterable_incompatible_shape_raises(self):\n        x1 = 
np.random.random((8, 4)).astype(\"float32\")\n        x2 = np.random.random((8, 6)).astype(\"float32\")\n\n        def bad_gen():\n            yield x1\n            yield x2\n\n        layer = layers.Normalization(axis=-1)\n        with self.assertRaisesRegex(ValueError, \"incompatible shape\"):\n            layer.adapt(bad_gen())\n\n    def test_adapt_iterable_batch_without_shape_raises(self):\n        layer = layers.Normalization(axis=-1)\n\n        def gen_no_shape():\n            yield 42\n\n        with self.assertRaisesRegex(TypeError, \"`.shape`\"):\n            layer.adapt(gen_no_shape())\n\n    def test_adapt_iterable_single_batch(self):\n        x = np.random.random((16, 4)).astype(\"float32\")\n        layer = layers.Normalization(axis=-1)\n        layer.adapt([x])\n        self.assertTrue(layer.built)\n        output = layer(x)\n        output = backend.convert_to_numpy(output)\n        self.assertAllClose(np.var(output, axis=0), 1.0, atol=1e-5)\n        self.assertAllClose(np.mean(output, axis=0), 0.0, atol=1e-5)\n\n    def test_adapt_iterable_high_dim_axis_tuple(self):\n        x = np.random.random((32, 4, 3, 5)).astype(\"float32\")\n        batches = [x[i : i + 8] for i in range(0, 32, 8)]\n        layer = layers.Normalization(axis=(1, 2))\n        layer.adapt(batches)\n        self.assertTrue(layer.built)\n        output = layer(x)\n        output = backend.convert_to_numpy(output)\n        self.assertAllClose(np.var(output, axis=(0, 3)), 1.0, atol=1e-5)\n        self.assertAllClose(np.mean(output, axis=(0, 3)), 0.0, atol=1e-5)\n\n    def test_adapt_iterator_of_batches(self):\n        x = np.random.random((24, 3)).astype(\"float32\")\n        list_of_batches = [x[i : i + 6] for i in range(0, 24, 6)]\n        layer = layers.Normalization(axis=-1)\n        layer.adapt(iter(list_of_batches))\n        self.assertTrue(layer.built)\n        output = layer(x)\n        output = backend.convert_to_numpy(output)\n        self.assertAllClose(np.var(output, 
axis=0), 1.0, atol=1e-5)\n        self.assertAllClose(np.mean(output, axis=0), 0.0, atol=1e-5)\n\n    @pytest.mark.requires_trainable_backend\n    def test_adapt_grain_dataset(self):\n        grain = pytest.importorskip(\"grain\")\n        x = np.random.random((24, 3)).astype(\"float32\")\n        ds = grain.MapDataset.source(x).to_iter_dataset().batch(8)\n        layer = layers.Normalization(axis=-1)\n        layer.adapt(ds)\n        self.assertTrue(layer.built)\n        output = layer(x)\n        output = backend.convert_to_numpy(output)\n        self.assertAllClose(np.var(output, axis=0), 1.0, atol=1e-5)\n        self.assertAllClose(np.mean(output, axis=0), 0.0, atol=1e-5)\n\n    def test_adapt_tf_dataset_with_labels(self):\n        \"\"\"Normalization.adapt should support supervised tf.data.Dataset.\"\"\"\n        import tensorflow as tf\n\n        x = np.ones((32, 3), dtype=\"float32\")\n        y = np.ones((32,), dtype=\"int32\")\n\n        dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(8)\n\n        layer = layers.Normalization()\n        layer.adapt(dataset)\n\n        mean = backend.convert_to_numpy(layer.mean).squeeze()\n        var = backend.convert_to_numpy(layer.variance).squeeze()\n\n        np.testing.assert_allclose(mean, np.ones(3))\n        np.testing.assert_allclose(var, np.zeros(3))\n"
  },
  {
    "path": "keras/src/layers/preprocessing/pipeline.py",
    "content": "from keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.layers.Pipeline\")\nclass Pipeline(Layer):\n    \"\"\"Applies a series of layers to an input.\n\n    This class is useful to build a preprocessing pipeline,\n    in particular an image data augmentation pipeline.\n    Compared to a `Sequential` model, `Pipeline` features\n    a few important differences:\n\n    - It's not a `Model`, just a plain layer.\n    - When the layers in the pipeline are compatible\n        with `tf.data`, the pipeline will also\n        remain `tf.data` compatible. That is to say,\n        the pipeline will not attempt to convert\n        its inputs to backend-native tensors\n        when in a tf.data context (unlike a `Sequential`\n        model).\n\n    Example:\n\n    ```python\n    from keras import layers\n    preprocessing_pipeline = layers.Pipeline([\n        layers.AutoContrast(),\n        layers.RandomZoom(0.2),\n        layers.RandomRotation(0.2),\n    ])\n\n    # `ds` is a tf.data.Dataset\n    preprocessed_ds = ds.map(\n        preprocessing_pipeline,\n        num_parallel_calls=4,\n    )\n    ```\n    \"\"\"\n\n    def __init__(self, layers, name=None):\n        super().__init__(name=name)\n        self._pipeline_layers = layers\n        self._convert_input_args = False\n        self._allow_non_tensor_positional_args = True\n\n    @property\n    def layers(self):\n        return self._pipeline_layers\n\n    def call(self, inputs, training=True, mask=None):\n        for layer in self._pipeline_layers:\n            kwargs = {}\n            if layer._call_has_mask_arg:\n                kwargs[\"mask\"] = mask\n            if layer._call_has_training_arg and training is not None:\n                kwargs[\"training\"] = training\n            outputs = layer(inputs, **kwargs)\n            inputs = outputs\n\n            def 
_get_mask_from_keras_tensor(kt):\n                return getattr(kt, \"_keras_mask\", None)\n\n            mask = tree.map_structure(_get_mask_from_keras_tensor, outputs)\n        return outputs\n\n    @classmethod\n    def from_config(cls, config):\n        config[\"layers\"] = [\n            serialization_lib.deserialize_keras_object(x)\n            for x in config[\"layers\"]\n        ]\n        return cls(**config)\n\n    def get_config(self):\n        config = {\n            \"layers\": serialization_lib.serialize_keras_object(\n                self._pipeline_layers\n            ),\n            \"name\": self.name,\n        }\n        return config\n"
  },
  {
    "path": "keras/src/layers/preprocessing/pipeline_test.py",
    "content": "import numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass CanaryLayer(layers.Layer):\n    def __init__(self):\n        super().__init__()\n        self.training = None\n        self.received_mask = False\n\n    def call(self, x, training=False, mask=None):\n        self.training = training\n        if mask is not None:\n            self.received_mask = True\n        return x\n\n    def compute_mask(self, x, mask=None):\n        return x\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n\nclass PipelineTest(testing.TestCase):\n    def test_basics(self):\n        run_training_check = False if backend.backend() == \"numpy\" else True\n        self.run_layer_test(\n            layers.Pipeline,\n            init_kwargs={\n                \"layers\": [layers.AutoContrast(), layers.RandomBrightness(0.1)],\n            },\n            input_shape=(8, 3, 4, 3),\n            supports_masking=False,\n            expected_output_shape=(8, 3, 4, 3),\n            run_mixed_precision_check=False,\n            run_training_check=run_training_check,\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\", reason=\"masking not working in numpy\"\n    )\n    def test_correctness(self):\n        pipeline = layers.Pipeline([CanaryLayer(), CanaryLayer()])\n        x = np.array([0])\n        mask = np.array([0])\n        pipeline(x, training=True, mask=mask)\n        self.assertTrue(pipeline.layers[0].training)\n        self.assertTrue(pipeline.layers[0].received_mask)\n        self.assertTrue(pipeline.layers[1].training)\n        self.assertTrue(pipeline.layers[1].received_mask)\n\n    def test_tf_data_compatibility(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 10, 12, 3)\n            output_shape = (2, 8, 9, 3)\n        else:\n          
  input_shape = (2, 3, 10, 12)\n            output_shape = (2, 3, 8, 9)\n        layer = layers.Pipeline(\n            [\n                layers.AutoContrast(),\n                layers.CenterCrop(8, 9),\n            ]\n        )\n        input_data = np.random.random(input_shape)\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output = output.numpy()\n        self.assertEqual(tuple(output.shape), output_shape)\n\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"Fails on CI, passes locally. TODO: debug\",\n    )\n    def test_from_config(self):\n        pipeline = layers.Pipeline(\n            [\n                layers.AutoContrast(),\n                layers.CenterCrop(8, 9),\n            ]\n        )\n        x = np.ones((2, 10, 12, 3))\n        output = pipeline(x)\n        restored = layers.Pipeline.from_config(pipeline.get_config())\n        restored_output = restored(x)\n        self.assertEqual(tuple(output.shape), (2, 8, 9, 3))\n        self.assertAllClose(output, restored_output)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/rescaling.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.data_layer import DataLayer\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.layers.Rescaling\")\nclass Rescaling(DataLayer):\n    \"\"\"A preprocessing layer which rescales input values to a new range.\n\n    This layer rescales every value of an input (often an image) by multiplying\n    by `scale` and adding `offset`.\n\n    For instance:\n\n    1. To rescale an input in the `[0, 255]` range\n    to be in the `[0, 1]` range, you would pass `scale=1./255`.\n\n    2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]` range,\n    you would pass `scale=1./127.5, offset=-1`.\n\n    The rescaling is applied both during training and inference. Inputs can be\n    of integer or floating point dtype, and by default the layer will output\n    floats.\n\n    **Note:** This layer is safe to use inside a `tf.data` or `grain` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        scale: Float, the scale to apply to the inputs.\n        offset: Float, the offset to apply to the inputs.\n        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.\n    \"\"\"\n\n    def __init__(self, scale, offset=0.0, **kwargs):\n        super().__init__(**kwargs)\n        self.scale = scale\n        self.offset = offset\n        self.supports_masking = True\n\n    def call(self, inputs):\n        dtype = self.compute_dtype\n        scale = self.backend.cast(self.scale, dtype)\n        offset = self.backend.cast(self.offset, dtype)\n        scale_shape = self.backend.core.shape(scale)\n        if (\n            len(scale_shape) > 0\n            and backend.image_data_format() == \"channels_first\"\n        ):\n            scale = self.backend.numpy.reshape(\n                scale, scale_shape + (1,) * (3 - len(scale_shape))\n            )\n        return self.backend.cast(inputs, 
dtype) * scale + offset\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                # `scale` and `offset` may be numpy arrays.\n                \"scale\": serialization_lib.serialize_keras_object(self.scale),\n                \"offset\": serialization_lib.serialize_keras_object(self.offset),\n            }\n        )\n        return config\n\n    @classmethod\n    def from_config(cls, config, custom_objects=None):\n        config = config.copy()\n        config[\"scale\"] = serialization_lib.deserialize_keras_object(\n            config[\"scale\"], custom_objects=custom_objects\n        )\n        config[\"offset\"] = serialization_lib.deserialize_keras_object(\n            config[\"offset\"], custom_objects=custom_objects\n        )\n        return cls(**config)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/rescaling_test.py",
    "content": "import grain\nimport numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass RescalingTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_rescaling_basics(self):\n        self.run_layer_test(\n            layers.Rescaling,\n            init_kwargs={\"scale\": 1.0 / 255, \"offset\": 0.5},\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_rescaling_dtypes(self):\n        # int scale\n        self.run_layer_test(\n            layers.Rescaling,\n            init_kwargs={\"scale\": 2, \"offset\": 0.5},\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n        # int offset\n        self.run_layer_test(\n            layers.Rescaling,\n            init_kwargs={\"scale\": 1.0, \"offset\": 2},\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n        # int inputs\n        self.run_layer_test(\n            layers.Rescaling,\n            init_kwargs={\"scale\": 1.0 / 255, \"offset\": 0.5},\n            input_shape=(2, 3),\n            input_dtype=\"int16\",\n            expected_output_shape=(2, 3),\n            
expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n\n    def test_rescaling_correctness(self):\n        layer = layers.Rescaling(scale=1.0 / 255, offset=0.5)\n        x = np.random.random((3, 10, 10, 3)) * 255\n        out = layer(x)\n        self.assertAllClose(out, x / 255 + 0.5)\n\n    def test_tf_data_compatibility(self):\n        layer = layers.Rescaling(scale=1.0 / 255, offset=0.5)\n        x = np.random.random((3, 10, 10, 3)) * 255\n        ds = tf_data.Dataset.from_tensor_slices(x).batch(3).map(layer)\n        next(iter(ds)).numpy()\n\n    def test_grain_compatibility(self):\n        layer = layers.Rescaling(scale=1.0 / 255, offset=0.5)\n        x = np.random.random((3, 10, 10, 3)) * 255\n        ds = grain.MapDataset.source(x).to_iter_dataset().batch(3).map(layer)\n        output = next(iter(ds))\n\n        self.assertTrue(backend.is_tensor(output))\n        # Ensure the device of the data is on CPU.\n        if backend.backend() == \"tensorflow\":\n            self.assertIn(\"CPU\", str(output.device))\n        elif backend.backend() == \"jax\":\n            self.assertIn(\"CPU\", str(output.device))\n        elif backend.backend() == \"torch\":\n            self.assertEqual(\"cpu\", str(output.device))\n\n    def test_rescaling_with_channels_first_and_vector_scale(self):\n        config = backend.image_data_format()\n        backend.set_image_data_format(\"channels_first\")\n        layer = layers.Rescaling(\n            scale=[1.0 / 255, 1.5 / 255, 2.0 / 255], offset=0.5\n        )\n        x = np.random.random((2, 3, 10, 10)) * 255\n        layer(x)\n        backend.set_image_data_format(config)\n\n    @pytest.mark.requires_trainable_backend\n    def test_numpy_args(self):\n        # https://github.com/keras-team/keras/issues/20072\n        self.run_layer_test(\n            
layers.Rescaling,\n            init_kwargs={\n                \"scale\": np.array(1.0 / 255.0),\n                \"offset\": np.array(0.5),\n            },\n            input_shape=(2, 3),\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=True,\n        )\n"
  },
  {
    "path": "keras/src/layers/preprocessing/stft_spectrogram.py",
    "content": "import math\nimport warnings\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils.module_utils import scipy\n\n\n@keras_export(\"keras.layers.STFTSpectrogram\")\nclass STFTSpectrogram(layers.Layer):\n    \"\"\"Layer to compute the Short-Time Fourier Transform (STFT) on a 1D signal.\n\n    A layer that computes Spectrograms of the input signal to produce\n    a spectrogram. This layers utilizes Short-Time Fourier Transform (STFT) by\n    The layer computes Spectrograms based on STFT by utilizing convolution\n    kernels, which allows parallelization on GPUs and trainable kernels for\n    fine-tuning support. This layer allows different modes of output\n    (e.g., log-scaled magnitude, phase, power spectral density, etc.) and\n    provides flexibility in windowing, padding, and scaling options for the\n    STFT calculation.\n\n    Examples:\n\n    Apply it as a non-trainable preprocessing layer on 3 audio tracks of\n    1 channel, 10 seconds and sampled at 16 kHz.\n\n    >>> layer = keras.layers.STFTSpectrogram(\n    ...     mode='log',\n    ...     frame_length=256,\n    ...     frame_step=128,   # 50% overlap\n    ...     fft_length=512,\n    ...     window=\"hann\",\n    ...     padding=\"valid\",\n    ...     trainable=False,  # non-trainable, preprocessing only\n    ... )\n    >>> layer(keras.random.uniform(shape=(3, 160000, 1))).shape\n    (3, 1249, 257)\n\n    Apply it as a trainable processing layer on 3 stereo audio tracks of\n    2 channels, 10 seconds and sampled at 16 kHz. This is initialized as the\n    non-trainable layer, but then can be trained jointly within a model.\n\n    >>> layer = keras.layers.STFTSpectrogram(\n    ...     mode='log',\n    ...     frame_length=256,\n    ...     frame_step=128,    # 50% overlap\n    ...     fft_length=512,\n    ...     
window=\"hamming\",  # hamming windowing function\n    ...     padding=\"same\",    # padding to preserve the time dimension\n    ...     trainable=True,    # trainable, this is the default in keras\n    ... )\n    >>> layer(keras.random.uniform(shape=(3, 160000, 2))).shape\n    (3, 1250, 514)\n\n    Similar to the last example, but add an extra dimension so the output is\n    an image to be used with image models. We apply this here on a signal of\n    3 input channels to output an image tensor, hence is directly applicable\n    with an image model.\n\n    >>> layer = keras.layers.STFTSpectrogram(\n    ...     mode='log',\n    ...     frame_length=256,\n    ...     frame_step=128,\n    ...     fft_length=512,\n    ...     padding=\"same\",\n    ...     expand_dims=True,  # this adds the extra dimension\n    ... )\n    >>> layer(keras.random.uniform(shape=(3, 160000, 3))).shape\n    (3, 1250, 257, 3)\n\n    Args:\n        mode: String, the output type of the spectrogram. Can be one of\n            `\"log\"`, `\"magnitude`\", `\"psd\"`, `\"real`\", `\"imag`\", `\"angle`\",\n            `\"stft`\". Defaults to `\"log`\".\n        frame_length: Integer, The length of each frame (window) for STFT in\n            samples. Defaults to 256.\n        frame_step: Integer, the step size (hop length) between\n            consecutive frames. If not provided, defaults to half the\n            frame_length. Defaults to `frame_length // 2`.\n        fft_length: Integer, the size of frequency bins used in the Fast-Fourier\n            Transform (FFT) to apply to each frame. Should be greater than or\n            equal to `frame_length`.  Recommended to be a power of two. Defaults\n            to the smallest power of two that is greater than or equal\n            to `frame_length`.\n        window: (String or array_like), the windowing function to apply to each\n            frame. 
Can be `\"hann\"` (default), `\"hamming\"`, or a custom window\n            provided as an array_like.\n        periodic: Boolean, if True, the window function will be treated as\n            periodic. Defaults to `False`.\n        scaling: String, type of scaling applied to the window. Can be\n            `\"density\"`, `\"spectrum\"`, or None. Defaults to `\"density\"`.\n        padding: String, padding strategy. Can be `\"valid\"` or `\"same\"`.\n            Defaults to `\"valid\"`.\n        expand_dims: Boolean, if True, will expand the output spectrograms\n            into two dimensions to be compatible with image models.\n            Defaults to `False`.\n        data_format: String, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`. 
Defaults to `\"channels_last\"`.\n\n    Raises:\n        ValueError: If an invalid value is provided for `\"mode\"`, `\"scaling\"`,\n            `\"padding\"`, or other input arguments.\n        TypeError: If the input data type is not one of `\"float16\"`,\n            `\"float32\"`, or `\"float64\"`.\n\n    Input shape:\n        A 3D tensor of shape `(batch_size, time_length, input_channels)`, if\n        `data_format==\"channels_last\"`, and of shape\n        `(batch_size, input_channels, time_length)` if\n        `data_format==\"channels_first\"`, where `time_length` is the length of\n        the input signal, and `input_channels` is the number of input channels.\n        The same kernels are applied to each channel independently.\n\n    Output shape:\n        If `data_format==\"channels_first\" and not expand_dims`, a 3D tensor:\n            `(batch_size, input_channels * freq_channels, new_time_length)`\n        If `data_format==\"channels_last\" and not expand_dims`, a 3D tensor:\n            `(batch_size, new_time_length, input_channels * freq_channels)`\n        If `data_format==\"channels_first\" and expand_dims`, a 4D tensor:\n            `(batch_size, input_channels, new_time_length, freq_channels)`\n        If `data_format==\"channels_last\" and expand_dims`, a 4D tensor:\n            `(batch_size, new_time_length, freq_channels, input_channels)`\n\n        where `new_time_length` depends on the padding, and `freq_channels` is\n        the number of FFT bins `(fft_length // 2 + 1)`.\n    \"\"\"\n\n    def __init__(\n        self,\n        mode=\"log\",\n        frame_length=256,\n        frame_step=None,\n        fft_length=None,\n        window=\"hann\",\n        periodic=False,\n        scaling=\"density\",\n        padding=\"valid\",\n        expand_dims=False,\n        data_format=None,\n        **kwargs,\n    ):\n        if frame_step is not None and (\n            frame_step > frame_length or frame_step < 1\n        ):\n            raise 
ValueError(\n                \"`frame_step` should be a positive integer not greater than \"\n                f\"`frame_length`. Received frame_step={frame_step}, \"\n                f\"frame_length={frame_length}\"\n            )\n\n        if fft_length is not None and fft_length < frame_length:\n            raise ValueError(\n                \"`fft_length` should be not less than `frame_length`. \"\n                f\"Received fft_length={fft_length}, frame_length={frame_length}\"\n            )\n\n        if fft_length is not None and (fft_length & -fft_length) != fft_length:\n            warnings.warn(\n                \"`fft_length` is recommended to be a power of two. \"\n                f\"Received fft_length={fft_length}\"\n            )\n\n        all_modes = [\"log\", \"magnitude\", \"psd\", \"real\", \"imag\", \"angle\", \"stft\"]\n\n        if mode not in all_modes:\n            raise ValueError(\n                \"Output mode is invalid, it must be one of \"\n                f\"{', '.join(all_modes)}. Received: mode={mode}\"\n            )\n\n        if scaling is not None and scaling not in [\"density\", \"spectrum\"]:\n            raise ValueError(\n                \"Scaling is invalid, it must be `None`, 'density' \"\n                f\"or 'spectrum'. Received scaling={scaling}\"\n            )\n\n        if padding not in [\"valid\", \"same\"]:\n            raise ValueError(\n                \"Padding is invalid, it should be 'valid', 'same'. 
\"\n                f\"Received: padding={padding}\"\n            )\n\n        if isinstance(window, str):\n            # throws an exception for invalid window function\n            scipy.signal.get_window(window, 1)\n\n        super().__init__(**kwargs)\n\n        self.mode = mode\n\n        self.frame_length = frame_length\n        self.frame_step = frame_step\n        self._frame_step = frame_step or self.frame_length // 2\n        self.fft_length = fft_length\n        self._fft_length = fft_length or (\n            2 ** int(math.ceil(math.log2(frame_length)))\n        )\n\n        self.window = window\n        self.periodic = periodic\n        self.scaling = scaling\n        self.padding = padding\n        self.expand_dims = expand_dims\n        self.data_format = backend.standardize_data_format(data_format)\n        self.input_spec = layers.input_spec.InputSpec(ndim=3)\n\n    def build(self, input_shape):\n        shape = (self.frame_length, 1, self._fft_length // 2 + 1)\n\n        if self.mode != \"imag\":\n            self.real_kernel = self.add_weight(\n                name=\"real_kernel\",\n                shape=shape,\n                initializer=initializers.STFT(\n                    \"real\", self.window, self.scaling, self.periodic\n                ),\n            )\n        if self.mode != \"real\":\n            self.imag_kernel = self.add_weight(\n                name=\"imag_kernel\",\n                shape=shape,\n                initializer=initializers.STFT(\n                    \"imag\", self.window, self.scaling, self.periodic\n                ),\n            )\n\n    def _adjust_shapes(self, outputs):\n        _, channels, freq_channels, time_seq = ops.shape(outputs)\n        batch_size = -1\n        if self.data_format == \"channels_last\":\n            if self.expand_dims:\n                outputs = ops.transpose(outputs, [0, 3, 2, 1])\n                # [batch_size, time_seq, freq_channels, input_channels]\n            else:\n              
  outputs = ops.reshape(\n                    outputs,\n                    [batch_size, channels * freq_channels, time_seq],\n                )\n                # [batch_size, input_channels * freq_channels, time_seq]\n                outputs = ops.transpose(outputs, [0, 2, 1])\n        else:\n            if self.expand_dims:\n                outputs = ops.transpose(outputs, [0, 1, 3, 2])\n                # [batch_size, channels, time_seq, freq_channels]\n            else:\n                outputs = ops.reshape(\n                    outputs,\n                    [batch_size, channels * freq_channels, time_seq],\n                )\n        return outputs\n\n    def _apply_conv(self, inputs, kernel):\n        if self.data_format == \"channels_last\":\n            _, time_seq, channels = ops.shape(inputs)\n            inputs = ops.transpose(inputs, [0, 2, 1])\n            inputs = ops.reshape(inputs, [-1, time_seq, 1])\n        else:\n            _, channels, time_seq = ops.shape(inputs)\n            inputs = ops.reshape(inputs, [-1, 1, time_seq])\n\n        outputs = ops.conv(\n            inputs,\n            ops.cast(kernel, backend.standardize_dtype(inputs.dtype)),\n            padding=self.padding,\n            strides=self._frame_step,\n            data_format=self.data_format,\n        )\n        batch_size = -1\n        if self.data_format == \"channels_last\":\n            _, time_seq, freq_channels = ops.shape(outputs)\n            outputs = ops.transpose(outputs, [0, 2, 1])\n            outputs = ops.reshape(\n                outputs,\n                [batch_size, channels, freq_channels, time_seq],\n            )\n        else:\n            _, freq_channels, time_seq = ops.shape(outputs)\n            outputs = ops.reshape(\n                outputs,\n                [batch_size, channels, freq_channels, time_seq],\n            )\n        return outputs\n\n    def call(self, inputs):\n        dtype = inputs.dtype\n        if backend.standardize_dtype(dtype) 
not in {\n            \"float16\",\n            \"float32\",\n            \"float64\",\n        }:\n            raise TypeError(\n                \"Invalid input type. Expected `float16`, `float32` or \"\n                f\"`float64`. Received: input type={dtype}\"\n            )\n\n        real_signal = None\n        imag_signal = None\n        power = None\n\n        if self.mode != \"imag\":\n            real_signal = self._apply_conv(inputs, self.real_kernel)\n        if self.mode != \"real\":\n            imag_signal = self._apply_conv(inputs, self.imag_kernel)\n\n        if self.mode == \"real\":\n            return self._adjust_shapes(real_signal)\n        elif self.mode == \"imag\":\n            return self._adjust_shapes(imag_signal)\n        elif self.mode == \"angle\":\n            return self._adjust_shapes(ops.arctan2(imag_signal, real_signal))\n        elif self.mode == \"stft\":\n            return self._adjust_shapes(\n                ops.concatenate([real_signal, imag_signal], axis=2)\n            )\n        else:\n            power = ops.square(real_signal) + ops.square(imag_signal)\n\n        if self.mode == \"psd\":\n            return self._adjust_shapes(\n                power\n                + ops.pad(\n                    power[:, :, 1:-1, :], [[0, 0], [0, 0], [1, 1], [0, 0]]\n                )\n            )\n        linear_stft = self._adjust_shapes(\n            ops.sqrt(ops.maximum(power, backend.epsilon()))\n        )\n\n        if self.mode == \"magnitude\":\n            return linear_stft\n        else:\n            return ops.log(ops.maximum(linear_stft, backend.epsilon()))\n\n    def compute_output_shape(self, input_shape):\n        if self.data_format == \"channels_last\":\n            channels = input_shape[-1]\n        else:\n            channels = input_shape[1]\n        freq_channels = self._fft_length // 2 + 1\n        if self.mode == \"stft\":\n            freq_channels *= 2\n        shape = 
ops.operation_utils.compute_conv_output_shape(\n            input_shape,\n            freq_channels * channels,\n            (self.frame_length,),\n            strides=self._frame_step,\n            padding=self.padding,\n            data_format=self.data_format,\n        )\n        if self.data_format == \"channels_last\":\n            batch_size, time_seq, _ = shape\n        else:\n            batch_size, _, time_seq = shape\n        if self.expand_dims:\n            if self.data_format == \"channels_last\":\n                return (batch_size, time_seq, freq_channels, channels)\n            else:\n                return (batch_size, channels, time_seq, freq_channels)\n        return shape\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"mode\": self.mode,\n                \"frame_length\": self.frame_length,\n                \"frame_step\": self.frame_step,\n                \"fft_length\": self.fft_length,\n                \"window\": self.window,\n                \"periodic\": self.periodic,\n                \"scaling\": self.scaling,\n                \"padding\": self.padding,\n                \"data_format\": self.data_format,\n                \"expand_dims\": self.expand_dims,\n            }\n        )\n        return config\n"
  },
  {
    "path": "keras/src/layers/preprocessing/stft_spectrogram_test.py",
    "content": "import numpy as np\nimport pytest\nimport scipy.signal\nimport tensorflow as tf\n\nfrom keras import Input\nfrom keras import Sequential\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass TestSpectrogram(testing.TestCase):\n    DTYPE = \"float32\"\n\n    @staticmethod\n    def _calc_spectrograms(\n        x, mode, scaling, window, periodic, frame_length, frame_step, fft_length\n    ):\n        data_format = backend.image_data_format()\n        input_shape = (None, 1) if data_format == \"channels_last\" else (1, None)\n\n        layer = Sequential(\n            [\n                Input(shape=input_shape, dtype=TestSpectrogram.DTYPE),\n                layers.STFTSpectrogram(\n                    mode=mode,\n                    frame_length=frame_length,\n                    frame_step=frame_step,\n                    fft_length=fft_length,\n                    window=window,\n                    scaling=scaling,\n                    periodic=periodic,\n                    dtype=TestSpectrogram.DTYPE,\n                ),\n            ]\n        )\n        if data_format == \"channels_first\":\n            y = layer.predict(np.transpose(x, [0, 2, 1]), verbose=0)\n            y = np.transpose(y, [0, 2, 1])\n        else:\n            y = layer.predict(x, verbose=0)\n\n        window_arr = scipy.signal.get_window(window, frame_length, periodic)\n        _, _, spec = scipy.signal.spectrogram(\n            x[..., 0].astype(TestSpectrogram.DTYPE),\n            window=window_arr.astype(TestSpectrogram.DTYPE),\n            nperseg=frame_length,\n            noverlap=frame_length - frame_step,\n            mode=mode,\n            scaling=scaling,\n            detrend=False,\n            nfft=fft_length,\n        )\n        y_true = np.transpose(spec, [0, 2, 1])\n        return y_true, y\n\n    @pytest.mark.requires_trainable_backend\n    def test_spectrogram_channels_broadcasting(self):\n        rnd = 
np.random.RandomState(41)\n        audio = rnd.uniform(-1, 1, size=(3, 16000, 7))\n\n        layer_last = Sequential(\n            [\n                Input(shape=(None, 7), dtype=self.DTYPE),\n                layers.STFTSpectrogram(\n                    mode=\"psd\", dtype=self.DTYPE, data_format=\"channels_last\"\n                ),\n            ]\n        )\n        layer_single = Sequential(\n            [\n                Input(shape=(None, 1), dtype=self.DTYPE),\n                layers.STFTSpectrogram(\n                    mode=\"psd\", dtype=self.DTYPE, data_format=\"channels_last\"\n                ),\n            ]\n        )\n\n        layer_expand = Sequential(\n            [\n                Input(shape=(None, 7), dtype=self.DTYPE),\n                layers.STFTSpectrogram(\n                    mode=\"psd\",\n                    dtype=self.DTYPE,\n                    data_format=\"channels_last\",\n                    expand_dims=True,\n                ),\n            ]\n        )\n\n        y_last = layer_last.predict(audio, verbose=0)\n        y_expanded = layer_expand.predict(audio, verbose=0)\n        y_singles = [\n            layer_single.predict(audio[..., i : i + 1], verbose=0)\n            for i in range(audio.shape[-1])\n        ]\n\n        self.assertAllClose(y_last, np.concatenate(y_singles, axis=-1))\n        self.assertAllClose(y_expanded, np.stack(y_singles, axis=-1))\n\n    @pytest.mark.skipif(\n        backend.backend() == \"tensorflow\",\n        reason=\"TF doesn't support channels_first\",\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_spectrogram_channels_first(self):\n        rnd = np.random.RandomState(41)\n        audio = rnd.uniform(-1, 1, size=(3, 16000, 7))\n\n        layer_first = Sequential(\n            [\n                Input(shape=(7, None), dtype=self.DTYPE),\n                layers.STFTSpectrogram(\n                    mode=\"psd\", dtype=self.DTYPE, data_format=\"channels_first\"\n                
),\n            ]\n        )\n        layer_last = Sequential(\n            [\n                Input(shape=(None, 7), dtype=self.DTYPE),\n                layers.STFTSpectrogram(\n                    mode=\"psd\", dtype=self.DTYPE, data_format=\"channels_last\"\n                ),\n            ]\n        )\n        layer_single = Sequential(\n            [\n                Input(shape=(None, 1), dtype=self.DTYPE),\n                layers.STFTSpectrogram(\n                    mode=\"psd\", dtype=self.DTYPE, data_format=\"channels_last\"\n                ),\n            ]\n        )\n        layer_expand = Sequential(\n            [\n                Input(shape=(7, None), dtype=self.DTYPE),\n                layers.STFTSpectrogram(\n                    mode=\"psd\",\n                    dtype=self.DTYPE,\n                    data_format=\"channels_first\",\n                    expand_dims=True,\n                ),\n            ]\n        )\n\n        y_singles = [\n            layer_single.predict(audio[..., i : i + 1], verbose=0)\n            for i in range(audio.shape[-1])\n        ]\n        y_expanded = layer_expand.predict(\n            np.transpose(audio, [0, 2, 1]), verbose=0\n        )\n        y_last = layer_last.predict(audio, verbose=0)\n        y_first = layer_first.predict(np.transpose(audio, [0, 2, 1]), verbose=0)\n        self.assertAllClose(np.transpose(y_first, [0, 2, 1]), y_last)\n        self.assertAllClose(y_expanded, np.stack(y_singles, axis=1))\n        self.assertAllClose(\n            y_first,\n            np.transpose(np.concatenate(y_singles, axis=-1), [0, 2, 1]),\n        )\n        self.run_layer_test(\n            layers.STFTSpectrogram,\n            init_kwargs={\n                \"frame_length\": 150,\n                \"frame_step\": 10,\n                \"fft_length\": 512,\n                \"trainable\": False,\n                \"padding\": \"same\",\n                \"expand_dims\": True,\n                \"data_format\": 
\"channels_first\",\n            },\n            input_shape=(2, 3, 160000),\n            expected_output_shape=(2, 3, 160000 // 10, 257),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=2,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_spectrogram_basics(self):\n        self.run_layer_test(\n            layers.STFTSpectrogram,\n            init_kwargs={\n                \"frame_length\": 500,\n                \"frame_step\": 25,\n                \"fft_length\": 1024,\n                \"mode\": \"stft\",\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(2, 16000, 1),\n            expected_output_shape=(2, 15500 // 25 + 1, 513 * 2),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n        self.run_layer_test(\n            layers.STFTSpectrogram,\n            init_kwargs={\n                \"frame_length\": 150,\n                \"frame_step\": 71,\n                \"fft_length\": 4096,\n                \"mode\": \"real\",\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(2, 160000, 1),\n            expected_output_shape=(2, 159850 // 71 + 1, 2049),\n            expected_num_trainable_weights=1,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n        self.run_layer_test(\n            layers.STFTSpectrogram,\n            init_kwargs={\n                \"frame_length\": 150,\n                \"frame_step\": 43,\n                \"fft_length\": 512,\n                \"mode\": 
\"imag\",\n                \"padding\": \"same\",\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(2, 160000, 1),\n            expected_output_shape=(2, 160000 // 43 + 1, 257),\n            expected_num_trainable_weights=1,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n        self.run_layer_test(\n            layers.STFTSpectrogram,\n            init_kwargs={\n                \"frame_length\": 150,\n                \"frame_step\": 10,\n                \"fft_length\": 512,\n                \"trainable\": False,\n                \"padding\": \"same\",\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(2, 160000, 3),\n            expected_output_shape=(2, 160000 // 10, 257 * 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=2,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n        self.run_layer_test(\n            layers.STFTSpectrogram,\n            init_kwargs={\n                \"frame_length\": 150,\n                \"frame_step\": 10,\n                \"fft_length\": 512,\n                \"trainable\": False,\n                \"padding\": \"same\",\n                \"expand_dims\": True,\n                \"data_format\": \"channels_last\",\n            },\n            input_shape=(2, 160000, 3),\n            expected_output_shape=(2, 160000 // 10, 257, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=2,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Backend does not support dynamic 
shapes\",\n    )\n    def test_spectrogram_dynamic_shape(self):\n        model = Sequential(\n            [\n                Input(shape=(None, 1), dtype=TestSpectrogram.DTYPE),\n                layers.STFTSpectrogram(\n                    frame_length=500,\n                    frame_step=25,\n                    fft_length=1024,\n                    mode=\"stft\",\n                    data_format=\"channels_last\",\n                ),\n            ]\n        )\n\n        def generator():\n            yield (np.random.random((2, 16000, 1)),)\n            yield (np.random.random((3, 8000, 1)),)\n\n        model.predict(generator())\n\n    @pytest.mark.requires_trainable_backend\n    def test_spectrogram_error(self):\n        rnd = np.random.RandomState(41)\n        x = rnd.uniform(low=-1, high=1, size=(4, 160000, 1)).astype(self.DTYPE)\n        names = [\n            \"scaling\",\n            \"window\",\n            \"periodic\",\n            \"frame_length\",\n            \"frame_step\",\n            \"fft_length\",\n        ]\n        for args in [\n            (\"density\", \"hann\", False, 512, 256, 1024),\n            (\"spectrum\", \"blackman\", True, 512, 32, 1024),\n            (\"spectrum\", \"hamming\", True, 256, 192, 512),\n            (\"spectrum\", \"tukey\", False, 512, 128, 512),\n            (\"density\", \"hamming\", True, 256, 256, 256),\n            (\"density\", \"hann\", True, 256, 128, 256),\n        ]:\n            init_args = dict(zip(names, args))\n\n            if testing.uses_tpu():\n                tol_kwargs = {\"atol\": 5e-2, \"rtol\": 1e-3}\n            else:\n                tol_kwargs = {\"atol\": 5e-4, \"rtol\": 1e-6}\n\n            init_args[\"mode\"] = \"magnitude\"\n            y_true, y = self._calc_spectrograms(x, **init_args)\n            self.assertEqual(np.shape(y_true), np.shape(y))\n            self.assertAllClose(y_true, y, **tol_kwargs)\n\n            init_args[\"mode\"] = \"psd\"\n            y_true, y = 
self._calc_spectrograms(x, **init_args)\n            self.assertEqual(np.shape(y_true), np.shape(y))\n            self.assertAllClose(y_true, y, **tol_kwargs)\n\n            init_args[\"mode\"] = \"angle\"\n            y_true, y = self._calc_spectrograms(x, **init_args)\n\n            mask = np.isclose(y, y_true, **tol_kwargs)\n            mask |= np.isclose(y + 2 * np.pi, y_true, **tol_kwargs)\n            mask |= np.isclose(y - 2 * np.pi, y_true, **tol_kwargs)\n            mask |= np.isclose(np.cos(y), np.cos(y_true), **tol_kwargs)\n            mask |= np.isclose(np.sin(y), np.sin(y_true), **tol_kwargs)\n\n            self.assertLess(np.mean(~mask), 2e-4)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Requires TF tensors for TF-data module.\",\n    )\n    def test_tf_data_compatibility(self):\n        input_shape = (2, 16000, 1)\n        output_shape = (2, 16000 // 128, 358)\n        layer = layers.STFTSpectrogram(\n            frame_length=256,\n            frame_step=128,\n            fft_length=715,\n            padding=\"same\",\n            scaling=None,\n        )\n        input_data = np.random.random(input_shape)\n        ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output = output.numpy()\n        self.assertEqual(tuple(output.shape), output_shape)\n\n    def test_exceptions(self):\n        with self.assertRaises(ValueError):\n            layers.STFTSpectrogram(\n                frame_length=256, frame_step=1024, fft_length=512\n            )\n        with self.assertRaises(ValueError):\n            layers.STFTSpectrogram(\n                frame_length=256, frame_step=0, fft_length=512\n            )\n        with self.assertRaises(ValueError):\n            layers.STFTSpectrogram(\n                frame_length=256, frame_step=32, fft_length=128\n            )\n        with self.assertRaises(ValueError):\n            
layers.STFTSpectrogram(padding=\"mypadding\")\n        with self.assertRaises(ValueError):\n            layers.STFTSpectrogram(scaling=\"l2\")\n        with self.assertRaises(ValueError):\n            layers.STFTSpectrogram(mode=\"spectrogram\")\n        with self.assertRaises(ValueError):\n            layers.STFTSpectrogram(window=\"unknowable\")\n        with self.assertRaises(ValueError):\n            layers.STFTSpectrogram(scaling=\"l2\")\n        with self.assertRaises(ValueError):\n            layers.STFTSpectrogram(padding=\"divide\")\n        with self.assertRaises(TypeError):\n            layers.STFTSpectrogram()(\n                np.random.randint(0, 255, size=(2, 16000, 1))\n            )\n"
  },
  {
    "path": "keras/src/layers/preprocessing/string_lookup.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.index_lookup import IndexLookup\nfrom keras.src.utils import backend_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\n\nif backend.backend() == \"torch\":\n    import torch\n\n\n@keras_export(\"keras.layers.StringLookup\")\nclass StringLookup(IndexLookup):\n    \"\"\"A preprocessing layer that maps strings to (possibly encoded) indices.\n\n    This layer translates a set of arbitrary strings into integer output via a\n    table-based vocabulary lookup. This layer will perform no splitting or\n    transformation of input strings. For a layer that can split and tokenize\n    natural language, see the `keras.layers.TextVectorization` layer.\n\n    The vocabulary for the layer must be either supplied on construction or\n    learned via `adapt()`. During `adapt()`, the layer will analyze a data set,\n    determine the frequency of individual strings tokens, and create a\n    vocabulary from them. If the vocabulary is capped in size, the most frequent\n    tokens will be used to create the vocabulary and all others will be treated\n    as out-of-vocabulary (OOV).\n\n    There are two possible output modes for the layer. When `output_mode` is\n    `\"int\"`, input strings are converted to their index in the vocabulary (an\n    integer).\n    When `output_mode` is `\"multi_hot\"`, `\"count\"`, or `\"tf_idf\"`, input strings\n    are encoded into an array where each dimension corresponds to an element in\n    the vocabulary.\n\n    The vocabulary can optionally contain a mask token as well as an OOV token\n    (which can optionally occupy multiple indices in the vocabulary, as set\n    by `num_oov_indices`).\n    The position of these tokens in the vocabulary is fixed. 
When `output_mode`\n    is `\"int\"`, the vocabulary will begin with the mask token (if set), followed\n    by OOV indices, followed by the rest of the vocabulary. When `output_mode`\n    is `\"multi_hot\"`, `\"count\"`, or `\"tf_idf\"` the vocabulary will begin with\n    OOV indices and instances of the mask token will be dropped.\n\n    **Note:** This layer uses TensorFlow internally. It cannot\n    be used as part of the compiled computation graph of a model with\n    any backend other than TensorFlow.\n    It can however be used with any backend when running eagerly.\n    It can also always be used as part of an input preprocessing pipeline\n    with any backend (outside the model itself), which is how we recommend\n    using this layer.\n\n    **Note:** This layer is safe to use inside a `tf.data` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        max_tokens: Maximum size of the vocabulary for this layer. This should\n            only be specified when adapting the vocabulary or when setting\n            `pad_to_max_tokens=True`. If None, there is no cap on the size of\n            the vocabulary. Note that this size includes the OOV\n            and mask tokens. Defaults to `None`.\n        num_oov_indices: The number of out-of-vocabulary tokens to use.\n            If this value is more than 1, OOV inputs are modulated to\n            determine their OOV value.\n            If this value is 0, OOV inputs will cause an error when calling\n            the layer. Defaults to `1`.\n        mask_token: A token that represents masked inputs. When `output_mode` is\n            `\"int\"`, the token is included in the vocabulary and mapped to index\n            0.\n            In other output modes, the token will not appear in the vocabulary\n            and instances of the mask token in the input will be dropped.\n            If set to `None`, no mask term will be added. 
Defaults to `None`.\n        oov_token: Only used when `invert` is True. The token to return for OOV\n            indices. Defaults to `\"[UNK]\"`.\n        vocabulary: Optional. Either an array of strings or a string path to a\n            text file. If passing an array, you can pass a tuple, list, 1D NumPy\n            array, or 1D tensor containing the string vocabulary terms.\n            If passing a file path, the file should contain one line per term in\n            the vocabulary. If this argument is set, there is no need to\n            `adapt()` the layer.\n        idf_weights: Only valid when `output_mode` is `\"tf_idf\"`.\n            A tuple, list, 1D NumPy array, or 1D tensor of the same length\n            as the vocabulary, containing the floating point inverse document\n            frequency weights, which will be multiplied by per sample term\n            counts for the final TF-IDF weight.\n            If the `vocabulary` argument is set and `output_mode` is `\"tf_idf\"`,\n            this argument must be supplied.\n        invert: Only valid when `output_mode` is `\"int\"`.\n            If `True`, this layer will map indices to vocabulary items\n            instead of mapping vocabulary items to indices.\n            Defaults to `False`.\n        output_mode: Specification for the output of the layer. Values can be\n            `\"int\"`, `\"one_hot\"`, `\"multi_hot\"`, `\"count\"`, or `\"tf_idf\"`\n            configuring the layer as follows:\n            - `\"int\"`: Return the vocabulary indices of the input tokens.\n            - `\"one_hot\"`: Encodes each individual element in the input into an\n                array the same size as the vocabulary,\n                containing a 1 at the element index. 
If the last dimension\n                is size 1, will encode on that dimension.\n                If the last dimension is not size 1, will append a new\n                dimension for the encoded output.\n            - `\"multi_hot\"`: Encodes each sample in the input into a single\n                array the same size as the vocabulary containing a 1 for each\n                vocabulary term present in the sample.\n                Treats the last dimension as the sample dimension, if the input\n                shape is `(..., sample_length)`, the output shape will be\n                `(..., num_tokens)`.\n            - `\"count\"`: As `\"multi_hot\"`, but the int array contains\n                a count of the number of times the token at that index\n                appeared in the sample.\n            - `\"tf_idf\"`: As `\"multi_hot\"`, but the TF-IDF algorithm is\n                applied to find the value in each token slot.\n            For `\"int\"` output, any shape of input and output is supported.\n            For all other output modes, currently only output up to rank 2\n            is supported. Defaults to `\"int\"`.\n        pad_to_max_tokens: Only applicable when `output_mode` is `\"multi_hot\"`,\n            `\"count\"`, or `\"tf_idf\"`. If `True`, the output will have\n            its feature axis padded to `max_tokens` even if the number\n            of unique tokens in the vocabulary is less than `max_tokens`,\n            resulting in a tensor of shape `(batch_size, max_tokens)`\n            regardless of vocabulary size. Defaults to `False`.\n        sparse: Boolean. Only applicable to `\"multi_hot\"`, `\"count\"`, and\n            `\"tf_idf\"` output modes. Only supported with TensorFlow\n            backend. If `True`, returns a `SparseTensor`\n            instead of a dense `Tensor`. Defaults to `False`.\n        encoding: Optional. The text encoding to use to interpret the input\n            strings. 
Defaults to `\"utf-8\"`.\n\n    Examples:\n\n    **Creating a lookup layer with a known vocabulary**\n\n    This example creates a lookup layer with a pre-existing vocabulary.\n\n    >>> vocab = [\"a\", \"b\", \"c\", \"d\"]\n    >>> data = [[\"a\", \"c\", \"d\"], [\"d\", \"z\", \"b\"]]\n    >>> layer = StringLookup(vocabulary=vocab)\n    >>> layer(data)\n    array([[1, 3, 4],\n           [4, 0, 2]])\n\n    **Creating a lookup layer with an adapted vocabulary**\n\n    This example creates a lookup layer and generates the vocabulary by\n    analyzing the dataset.\n\n    >>> data = [[\"a\", \"c\", \"d\"], [\"d\", \"z\", \"b\"]]\n    >>> layer = StringLookup()\n    >>> layer.adapt(data)\n    >>> layer.get_vocabulary()\n    ['[UNK]', 'd', 'z', 'c', 'b', 'a']\n\n    Note that the OOV token `\"[UNK]\"` has been added to the vocabulary.\n    The remaining tokens are sorted by frequency\n    (`\"d\"`, which has 2 occurrences, is first) then by inverse sort order.\n\n    >>> data = [[\"a\", \"c\", \"d\"], [\"d\", \"z\", \"b\"]]\n    >>> layer = StringLookup()\n    >>> layer.adapt(data)\n    >>> layer(data)\n    array([[5, 3, 1],\n           [1, 2, 4]])\n\n    **Lookups with multiple OOV indices**\n\n    This example demonstrates how to use a lookup layer with multiple OOV\n    indices.  When a layer is created with more than one OOV index, any OOV\n    values are hashed into the number of OOV buckets, distributing OOV values in\n    a deterministic fashion across the set.\n\n    >>> vocab = [\"a\", \"b\", \"c\", \"d\"]\n    >>> data = [[\"a\", \"c\", \"d\"], [\"m\", \"z\", \"b\"]]\n    >>> layer = StringLookup(vocabulary=vocab, num_oov_indices=2)\n    >>> layer(data)\n    array([[2, 4, 5],\n           [0, 1, 3]])\n\n    Note that the output for OOV value 'm' is 0, while the output for OOV value\n    `\"z\"` is 1. 
The in-vocab terms have their output index increased by 1 from\n    earlier examples (a maps to 2, etc) in order to make space for the extra OOV\n    value.\n\n    **One-hot output**\n\n    Configure the layer with `output_mode='one_hot'`. Note that the first\n    `num_oov_indices` dimensions in the one_hot encoding represent OOV values.\n\n    >>> vocab = [\"a\", \"b\", \"c\", \"d\"]\n    >>> data = [\"a\", \"b\", \"c\", \"d\", \"z\"]\n    >>> layer = StringLookup(vocabulary=vocab, output_mode='one_hot')\n    >>> layer(data)\n    array([[0., 1., 0., 0., 0.],\n           [0., 0., 1., 0., 0.],\n           [0., 0., 0., 1., 0.],\n           [0., 0., 0., 0., 1.],\n           [1., 0., 0., 0., 0.]], dtype=int64)\n\n    **Multi-hot output**\n\n    Configure the layer with `output_mode='multi_hot'`. Note that the first\n    `num_oov_indices` dimensions in the multi_hot encoding represent OOV values.\n\n    >>> vocab = [\"a\", \"b\", \"c\", \"d\"]\n    >>> data = [[\"a\", \"c\", \"d\", \"d\"], [\"d\", \"z\", \"b\", \"z\"]]\n    >>> layer = StringLookup(vocabulary=vocab, output_mode='multi_hot')\n    >>> layer(data)\n    array([[0., 1., 0., 1., 1.],\n           [1., 0., 1., 0., 1.]], dtype=int64)\n\n    **Token count output**\n\n    Configure the layer with `output_mode='count'`. As with multi_hot output,\n    the first `num_oov_indices` dimensions in the output represent OOV values.\n\n    >>> vocab = [\"a\", \"b\", \"c\", \"d\"]\n    >>> data = [[\"a\", \"c\", \"d\", \"d\"], [\"d\", \"z\", \"b\", \"z\"]]\n    >>> layer = StringLookup(vocabulary=vocab, output_mode='count')\n    >>> layer(data)\n    array([[0., 1., 0., 1., 2.],\n           [2., 0., 1., 0., 1.]], dtype=int64)\n\n    **TF-IDF output**\n\n    Configure the layer with `output_mode=\"tf_idf\"`. 
As with multi_hot output,\n    the first `num_oov_indices` dimensions in the output represent OOV values.\n\n    Each token bin will output `token_count * idf_weight`, where the idf weights\n    are the inverse document frequency weights per token. These should be\n    provided along with the vocabulary. Note that the `idf_weight` for OOV\n    values will default to the average of all idf weights passed in.\n\n    >>> vocab = [\"a\", \"b\", \"c\", \"d\"]\n    >>> idf_weights = [0.25, 0.75, 0.6, 0.4]\n    >>> data = [[\"a\", \"c\", \"d\", \"d\"], [\"d\", \"z\", \"b\", \"z\"]]\n    >>> layer = StringLookup(output_mode=\"tf_idf\")\n    >>> layer.set_vocabulary(vocab, idf_weights=idf_weights)\n    >>> layer(data)\n    array([[0.  , 0.25, 0.  , 0.6 , 0.8 ],\n           [1.0 , 0.  , 0.75, 0.  , 0.4 ]], dtype=float32)\n\n    To specify the idf weights for OOV values, you will need to pass the entire\n    vocabulary including the leading OOV token.\n\n    >>> vocab = [\"[UNK]\", \"a\", \"b\", \"c\", \"d\"]\n    >>> idf_weights = [0.9, 0.25, 0.75, 0.6, 0.4]\n    >>> data = [[\"a\", \"c\", \"d\", \"d\"], [\"d\", \"z\", \"b\", \"z\"]]\n    >>> layer = StringLookup(output_mode=\"tf_idf\")\n    >>> layer.set_vocabulary(vocab, idf_weights=idf_weights)\n    >>> layer(data)\n    array([[0.  , 0.25, 0.  , 0.6 , 0.8 ],\n           [1.8 , 0.  , 0.75, 0.  
, 0.4 ]], dtype=float32)\n\n    When adapting the layer in `\"tf_idf\"` mode, each input sample will be\n    considered a document, and IDF weight per token will be calculated as\n    `log(1 + num_documents / (1 + token_document_count))`.\n\n    **Inverse lookup**\n\n    This example demonstrates how to map indices to strings using this layer.\n    (You can also use `adapt()` with `invert=True`, but for simplicity we'll\n    pass the vocab in this example.)\n\n    >>> vocab = [\"a\", \"b\", \"c\", \"d\"]\n    >>> data = [[1, 3, 4], [4, 0, 2]]\n    >>> layer = StringLookup(vocabulary=vocab, invert=True)\n    >>> layer(data)\n    array([[b'a', b'c', b'd'],\n           [b'd', b'[UNK]', b'b']], dtype=object)\n\n    Note that the first index corresponds to the OOV token by default.\n\n\n    **Forward and inverse lookup pairs**\n\n    This example demonstrates how to use the vocabulary of a standard lookup\n    layer to create an inverse lookup layer.\n\n    >>> vocab = [\"a\", \"b\", \"c\", \"d\"]\n    >>> data = [[\"a\", \"c\", \"d\"], [\"d\", \"z\", \"b\"]]\n    >>> layer = StringLookup(vocabulary=vocab)\n    >>> i_layer = StringLookup(vocabulary=vocab, invert=True)\n    >>> int_data = layer(data)\n    >>> i_layer(int_data)\n    array([[b'a', b'c', b'd'],\n           [b'd', b'[UNK]', b'b']], dtype=object)\n\n    In this example, the input value `\"z\"` resulted in an output of `\"[UNK]\"`,\n    since it was not in the vocabulary - it got represented as an OOV, and all\n    OOV values are returned as `\"[UNK]\"` in the inverse layer. 
Also, note that\n    for the inverse to work, you must have already set the forward layer\n    vocabulary either directly or via `adapt()` before calling\n    `get_vocabulary()`.\n    \"\"\"\n\n    def __init__(\n        self,\n        max_tokens=None,\n        num_oov_indices=1,\n        mask_token=None,\n        oov_token=\"[UNK]\",\n        vocabulary=None,\n        idf_weights=None,\n        invert=False,\n        output_mode=\"int\",\n        pad_to_max_tokens=False,\n        sparse=False,\n        encoding=\"utf-8\",\n        name=None,\n        **kwargs,\n    ):\n        if not tf.available:\n            raise ImportError(\n                \"Layer StringLookup requires TensorFlow. \"\n                \"Install it via `pip install tensorflow`.\"\n            )\n        if sparse and backend.backend() != \"tensorflow\":\n            raise ValueError(\n                \"`sparse=True` can only be used with the TensorFlow backend.\"\n            )\n        self.encoding = encoding\n        super().__init__(\n            max_tokens=max_tokens,\n            num_oov_indices=num_oov_indices,\n            mask_token=mask_token,\n            oov_token=oov_token,\n            vocabulary=vocabulary,\n            idf_weights=idf_weights,\n            invert=invert,\n            output_mode=output_mode,\n            pad_to_max_tokens=pad_to_max_tokens,\n            sparse=sparse,\n            name=name,\n            vocabulary_dtype=\"string\",\n            **kwargs,\n        )\n        self._convert_input_args = False\n        self._allow_non_tensor_positional_args = True\n        self.supports_jit = False\n\n    def adapt(self, data, steps=None):\n        \"\"\"Computes a vocabulary of terms from tokens in a dataset.\n\n        Calling `adapt()` on a `StringLookup` layer is an alternative to passing\n        in a precomputed vocabulary on construction via the `vocabulary`\n        argument. 
A `StringLookup` layer should always be either adapted over a\n        dataset or supplied with a vocabulary.\n\n        During `adapt()`, the layer will build a vocabulary of all string tokens\n        seen in the dataset, sorted by occurrence count, with ties broken by\n        sort order of the tokens (high to low). At the end of `adapt()`, if\n        `max_tokens` is set, the vocabulary will be truncated to `max_tokens`\n        size. For example, adapting a layer with `max_tokens=1000` will compute\n        the 1000 most frequent tokens occurring in the input dataset. If\n        `output_mode='tf_idf'`, `adapt()` will also learn the document\n        frequencies of each token in the input dataset.\n\n        Arguments:\n            data: The data to train on. It can be passed either as a\n                batched `tf.data.Dataset`, as a list of strings,\n                or as a NumPy array.\n            steps: Integer or `None`.\n                Total number of steps (batches of samples) to process.\n                If `data` is a `tf.data.Dataset`, and `steps` is `None`,\n                `adapt()` will run until the input dataset is exhausted.\n                When passing an infinitely\n                repeating dataset, you must specify the `steps` argument. 
This\n                argument is not supported with array inputs or list inputs.\n        \"\"\"\n        super().adapt(data, steps=steps)\n\n    # Overridden methods from IndexLookup.\n    def _tensor_vocab_to_numpy(self, vocabulary):\n        vocabulary = vocabulary.numpy()\n        return np.array(\n            [tf.compat.as_text(x, self.encoding) for x in vocabulary]\n        )\n\n    def get_config(self):\n        config = {\"encoding\": self.encoding}\n        base_config = super().get_config()\n        # There is only one valid dtype for strings, so we don't expose this.\n        del base_config[\"vocabulary_dtype\"]\n        return {**base_config, **config}\n\n    def call(self, inputs):\n        is_torch_backend = backend.backend() == \"torch\"\n\n        # Handle input conversion\n        inputs_for_processing = inputs\n        was_tf_input = isinstance(\n            inputs, (tf.Tensor, tf.RaggedTensor, tf.SparseTensor)\n        )\n\n        if is_torch_backend and isinstance(inputs, torch.Tensor):\n            inputs_for_processing = tf.convert_to_tensor(\n                inputs.detach().cpu().numpy()\n            )\n        elif isinstance(inputs, (np.ndarray, list, tuple)):\n            inputs_for_processing = tf.convert_to_tensor(inputs)\n        elif not was_tf_input:\n            inputs_for_processing = tf.convert_to_tensor(\n                backend.convert_to_numpy(inputs)\n            )\n\n        output = super().call(inputs_for_processing)\n\n        # Handle torch backend output conversion\n        if is_torch_backend and isinstance(\n            inputs, (torch.Tensor, np.ndarray, list, tuple)\n        ):\n            numpy_outputs = output.numpy()\n            if self.invert:\n                return [n.decode(self.encoding) for n in numpy_outputs]\n            else:\n                return torch.from_numpy(numpy_outputs)\n\n        # other backends\n        if not was_tf_input:\n            output = backend_utils.convert_tf_tensor(output)\n\n 
       return output\n"
  },
  {
    "path": "keras/src/layers/preprocessing/string_lookup_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nfrom tensorflow import data as tf_data\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.ops import convert_to_tensor\n\n\nclass StringLookupTest(testing.TestCase):\n    # TODO: increase coverage. Most features aren't being tested.\n\n    def test_config(self):\n        layer = layers.StringLookup(\n            output_mode=\"int\",\n            vocabulary=[\"a\", \"b\", \"c\"],\n            oov_token=\"[OOV]\",\n            mask_token=\"[MASK]\",\n        )\n        self.run_class_serialization_test(layer)\n        self.assertEqual(layer.get_config()[\"vocabulary\"], [\"a\", \"b\", \"c\"])\n\n    def test_vocabulary_file(self):\n        temp_dir = self.get_temp_dir()\n        vocab_path = os.path.join(temp_dir, \"vocab.txt\")\n        with open(vocab_path, \"w\") as file:\n            file.write(\"a\\nb\\nc\\n\")\n\n        layer = layers.StringLookup(\n            output_mode=\"int\",\n            vocabulary=vocab_path,\n            oov_token=\"[OOV]\",\n            mask_token=\"[MASK]\",\n            name=\"index\",\n        )\n        self.assertEqual(\n            [str(v) for v in layer.get_vocabulary()],\n            [\"[MASK]\", \"[OOV]\", \"a\", \"b\", \"c\"],\n        )\n        self.assertIsNone(layer.get_config().get(\"vocabulary\", None))\n\n        # Make sure vocabulary comes from the archive, not the original file.\n        os.remove(vocab_path)\n\n        model = models.Sequential([layer])\n        model_path = os.path.join(temp_dir, \"test_model.keras\")\n        model.save(model_path)\n\n        reloaded_model = saving.load_model(model_path)\n        reloaded_layer = reloaded_model.get_layer(\"index\")\n        self.assertEqual(\n            [str(v) for v in reloaded_layer.get_vocabulary()],\n            [\"[MASK]\", \"[OOV]\", \"a\", \"b\", \"c\"],\n        )\n\n    
def test_adapt_flow(self):\n        layer = layers.StringLookup(\n            output_mode=\"int\",\n        )\n        layer.adapt([\"a\", \"a\", \"a\", \"b\", \"b\", \"c\"])\n        input_data = [\"b\", \"c\", \"d\"]\n        output = layer(input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([2, 3, 0]))\n\n    def test_fixed_vocabulary(self):\n        layer = layers.StringLookup(\n            output_mode=\"int\",\n            vocabulary=[\"a\", \"b\", \"c\"],\n        )\n        input_data = [\"b\", \"c\", \"d\"]\n        output = layer(input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([2, 3, 0]))\n\n    @pytest.mark.skipif(\n        not backend.backend() == \"tensorflow\", reason=\"Requires tf.SparseTensor\"\n    )\n    def test_sparse_inputs(self):\n        import tensorflow as tf\n\n        layer = layers.StringLookup(\n            output_mode=\"int\",\n            vocabulary=[\"a\", \"b\", \"c\"],\n        )\n        input_data = tf.SparseTensor(\n            indices=[[0, 0], [1, 1], [2, 2]],\n            values=[\"b\", \"c\", \"d\"],\n            dense_shape=(3, 3),\n        )\n        output = layer(input_data)\n        self.assertIsInstance(output, tf.SparseTensor)\n        self.assertAllClose(output, np.array([[2, 0, 0], [0, 3, 0], [0, 0, 0]]))\n        self.assertAllClose(output.values, np.array([2, 3, 0]))\n\n    def test_set_vocabulary(self):\n        layer = layers.StringLookup(\n            output_mode=\"int\",\n        )\n        layer.set_vocabulary([\"a\", \"b\", \"c\"])\n        input_data = [\"b\", \"c\", \"d\"]\n        output = layer(input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([2, 3, 0]))\n\n    def test_tf_data_compatibility(self):\n        layer = layers.StringLookup(\n            output_mode=\"int\",\n            vocabulary=[\"a\", \"b\", \"c\"],\n        )\n  
      input_data = [\"b\", \"c\", \"d\"]\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(3).map(layer)\n        output = next(iter(ds)).numpy()\n        self.assertAllClose(output, np.array([2, 3, 0]))\n\n    @pytest.mark.skipif(not backend.backend() == \"tensorflow\", reason=\"tf only\")\n    def test_tensor_as_vocab(self):\n        vocab = convert_to_tensor([\"a\", \"b\", \"c\", \"d\"])\n        data = [[\"a\", \"c\", \"d\"], [\"d\", \"z\", \"b\"]]\n        layer = layers.StringLookup(\n            vocabulary=vocab,\n        )\n        output = layer(data)\n        self.assertAllClose(output, np.array([[1, 3, 4], [4, 0, 2]]))\n\n    @pytest.mark.skipif(backend.backend() != \"torch\", reason=\"Only torch\")\n    def test_torch_backend_compatibility(self):\n        import torch\n\n        # Forward lookup: String -> number\n        forward_lookup = layers.StringLookup(\n            vocabulary=[\"a\", \"b\", \"c\"], oov_token=\"[OOV]\"\n        )\n        input_data_str = [\"a\", \"b\", \"[OOV]\", \"d\"]\n        output_numeric = forward_lookup(input_data_str)\n\n        # assert instance of output is torch.Tensor\n        self.assertIsInstance(output_numeric, torch.Tensor)\n        expected_numeric = torch.tensor([1, 2, 0, 0])\n        self.assertAllClose(output_numeric.cpu(), expected_numeric)\n\n        oov = \"[OOV]\"\n        # Inverse lookup: Number -> string\n        inverse_lookup = layers.StringLookup(\n            vocabulary=[\"a\", \"b\", \"c\"], oov_token=oov, invert=True\n        )\n        input_data_int = torch.tensor([1, 2, 0], dtype=torch.int64)\n        output_string = inverse_lookup(input_data_int)\n        # Assert that the output is a list\n        # See : https://docs.pytorch.org/text/stable/_modules/torchtext/vocab/vocab.html#Vocab.lookup_tokens\n        # The torch equivalent implementation of this returns a list of strings\n        self.assertIsInstance(output_string, list)\n        expected_string = [\"a\", \"b\", 
\"[OOV]\"]\n        self.assertEqual(output_string, expected_string)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"invert=True requires TensorFlow string tensors\",\n    )\n    def test_invert_lookup_basic(self):\n        layer = layers.StringLookup(\n            vocabulary=[\"a\", \"b\", \"c\"],\n            invert=True,\n        )\n        output = layer([1, 2, 0])\n        self.assertAllEqual(\n            backend.convert_to_numpy(output).astype(str),\n            [\"a\", \"b\", \"[UNK]\"],\n        )\n\n    def test_output_mode_count_shape(self):\n        layer = layers.StringLookup(\n            vocabulary=[\"a\", \"b\"],\n            output_mode=\"count\",\n        )\n        output = layer([\"a\", \"a\", \"a\", \"b\", \"b\"])\n        self.assertEqual(output.shape[-1], len(layer.get_vocabulary()))\n\n    def test_output_mode_multi_hot_binary(self):\n        layer = layers.StringLookup(\n            vocabulary=[\"a\", \"b\"],\n            output_mode=\"multi_hot\",\n        )\n        output = layer([\"a\", \"b\"])\n        self.assertAllClose(output, [0, 1, 1])\n\n    def test_mask_token_basic(self):\n        layer = layers.StringLookup(\n            vocabulary=[\"a\"],\n            mask_token=\"[MASK]\",\n        )\n        output = layer([\"[MASK]\", \"a\"])\n        self.assertEqual(int(output[0]), 0)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Requires tf.SparseTensor\",\n    )\n    def test_sparse_output_in_multi_hot(self):\n        import tensorflow as tf\n\n        layer = layers.StringLookup(\n            vocabulary=[\"a\", \"b\", \"c\"],\n            output_mode=\"multi_hot\",\n            sparse=True,\n        )\n        input_data = tf.ragged.constant([[\"a\", \"b\"], [\"c\", \"a\"]])\n        output = layer(input_data)\n\n        self.assertIsInstance(output, tf.SparseTensor)\n\n    def test_get_vocabulary_include_special_tokens_false(self):\n        
layer = layers.StringLookup(\n            vocabulary=[\"a\", \"b\", \"c\"],\n        )\n        vocab = layer.get_vocabulary(include_special_tokens=False)\n\n        self.assertEqual(vocab, [\"a\", \"b\", \"c\"])\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\",\n        reason=(\n            \"StringLookup symbolic string Input not supported on numpy backend.\"\n        ),\n    )\n    def test_one_hot_symbolic_output_shape_nested_input(self):\n        \"\"\"StringLookup one_hot symbolic shape matches eager for nested input.\n\n        Regression test for gh-22336: symbolic output was (None, max_tokens)\n        instead of (None, d1, d2, ..., max_tokens).\n        \"\"\"\n        layer = layers.StringLookup(\n            max_tokens=20,\n            num_oov_indices=4,\n            oov_token=\"[UNK]\",\n            vocabulary=[\"a\", \"b\", \"c\", \"d\", \"e\", \"f\", \"g\", \"h\", \"i\"],\n            output_mode=\"one_hot\",\n            pad_to_max_tokens=True,\n            sparse=False,\n        )\n        # Symbolic input shape (None, 2, 2) as in the issue\n        symbolic_input = layers.Input(shape=(2, 2), dtype=\"string\")\n        symbolic_output = layer(symbolic_input)\n        self.assertEqual(\n            tuple(symbolic_output.shape),\n            (None, 2, 2, 20),\n            msg=\"one_hot symbolic output must be (batch, d1, d2, max_tokens)\",\n        )\n        # Eager: (2, 2, 2) input -> (2, 2, 2, 20) output\n        input_data = np.array(\n            [[[\"a\", \"b\"], [\"c\", \"d\"]], [[\"a\", \"b\"], [\"c\", \"d\"]]]\n        )\n        eager_output = layer(input_data)\n        self.assertEqual(eager_output.shape, (2, 2, 2, 20))\n        self.assertEqual(\n            tuple(symbolic_output.shape)[1:],\n            eager_output.shape[1:],\n        )\n"
  },
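The layer docstrings note that `StringLookup` is safe inside a `tf.data` pipeline regardless of backend, and `test_tf_data_compatibility` above exercises exactly that. A minimal sketch, assuming TensorFlow and Keras 3 are installed:

```python
import keras
from tensorflow import data as tf_data

# Map a fixed-vocabulary StringLookup over a batched dataset.
layer = keras.layers.StringLookup(vocabulary=["a", "b", "c"])
ds = tf_data.Dataset.from_tensor_slices(["b", "c", "d"]).batch(3).map(layer)

# "b" -> 2, "c" -> 3, and the unseen "d" -> 0 (the OOV index).
batch = next(iter(ds)).numpy()
```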
  {
    "path": "keras/src/layers/preprocessing/text_vectorization.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.layers.preprocessing.index_lookup import listify_tensors\nfrom keras.src.layers.preprocessing.string_lookup import StringLookup\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils import argument_validation\nfrom keras.src.utils import backend_utils\nfrom keras.src.utils import tf_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\n@keras_export(\"keras.layers.TextVectorization\")\nclass TextVectorization(Layer):\n    \"\"\"A preprocessing layer which maps text features to integer sequences.\n\n    This layer has basic options for managing text in a Keras model. It\n    transforms a batch of strings (one example = one string) into either a list\n    of token indices (one example = 1D tensor of integer token indices) or a\n    dense representation (one example = 1D tensor of float values representing\n    data about the example's tokens). This layer is meant to handle natural\n    language inputs. To handle simple string inputs (categorical strings or\n    pre-tokenized strings) see `kers_core.layers.StringLookup`.\n\n    The vocabulary for the layer must be either supplied on construction or\n    learned via `adapt()`. When this layer is adapted, it will analyze the\n    dataset, determine the frequency of individual string values, and create a\n    vocabulary from them. This vocabulary can have unlimited size or be capped,\n    depending on the configuration options for this layer; if there are more\n    unique values in the input than the maximum vocabulary size, the most\n    frequent terms will be used to create the vocabulary.\n\n    The processing of each example contains the following steps:\n\n    1. Standardize each example (usually lowercasing + punctuation stripping)\n    2. Split each example into substrings (usually words)\n    3. 
Recombine substrings into tokens (usually ngrams)\n    4. Index tokens (associate a unique int value with each token)\n    5. Transform each example using this index, either into a vector of ints or\n       a dense float vector.\n\n    Some notes on passing callables to customize splitting and normalization for\n    this layer:\n\n    1. Any callable can be passed to this Layer, but if you want to serialize\n       this object you should only pass functions that are registered Keras\n       serializables (see `keras.saving.register_keras_serializable`\n       for more details).\n    2. When using a custom callable for `standardize`, the data received\n       by the callable will be exactly as passed to this layer. The callable\n       should return a tensor of the same shape as the input.\n    3. When using a custom callable for `split`, the data received by the\n       callable will have the 1st dimension squeezed out - instead of\n       `[[\"string to split\"], [\"another string to split\"]]`, the Callable will\n       see `[\"string to split\", \"another string to split\"]`.\n       The callable should return a `tf.Tensor` of dtype `string`\n       with the first dimension containing the split tokens -\n       in this example, we should see something like `[[\"string\", \"to\",\n       \"split\"], [\"another\", \"string\", \"to\", \"split\"]]`.\n\n    **Note:** This layer uses TensorFlow internally. It cannot\n    be used as part of the compiled computation graph of a model with\n    any backend other than TensorFlow.\n    It can however be used with any backend when running eagerly.\n    It can also always be used as part of an input preprocessing pipeline\n    with any backend (outside the model itself), which is how we recommend\n    to use this layer.\n\n    **Note:** This layer is safe to use inside a `tf.data` pipeline\n    (independently of which backend you're using).\n\n    Args:\n        max_tokens: Maximum size of the vocabulary for this layer. 
This should\n            only be specified when adapting a vocabulary or when setting\n            `pad_to_max_tokens=True`. Note that this vocabulary\n            contains 1 OOV token, so the effective number of tokens is\n            `(max_tokens - 1 - (1 if output_mode == \"int\" else 0))`.\n        standardize: Optional specification for standardization to apply to the\n            input text. Values can be:\n            - `None`: No standardization.\n            - `\"lower_and_strip_punctuation\"`: Text will be lowercased and all\n                punctuation removed.\n            - `\"lower\"`: Text will be lowercased.\n            - `\"strip_punctuation\"`: All punctuation will be removed.\n            - Callable: Inputs will be passed to the callable function,\n                which should be standardized and returned.\n        split: Optional specification for splitting the input text.\n            Values can be:\n            - `None`: No splitting.\n            - `\"whitespace\"`: Split on whitespace.\n            - `\"character\"`: Split on each unicode character.\n            - Callable: Standardized inputs will be passed to the callable\n                function, which should be split and returned.\n        ngrams: Optional specification for ngrams to create from the\n            possibly-split input text. Values can be `None`, an integer\n            or tuple of integers; passing an integer will create ngrams\n            up to that integer, and passing a tuple of integers will\n            create ngrams for the specified values in the tuple.\n            Passing `None` means that no ngrams will be created.\n        output_mode: Optional specification for the output of the layer.\n            Values can be `\"int\"`, `\"multi_hot\"`, `\"count\"` or `\"tf_idf\"`,\n            configuring the layer as follows:\n            - `\"int\"`: Outputs integer indices, one integer index per split\n                string token. 
When `output_mode == \"int\"`,\n                0 is reserved for masked locations;\n                this reduces the vocab size to `max_tokens - 2`\n                instead of `max_tokens - 1`.\n            - `\"multi_hot\"`: Outputs a single int array per batch, of either\n                vocab_size or max_tokens size, containing 1s in all elements\n                where the token mapped to that index exists at least\n                once in the batch item.\n            - `\"count\"`: Like `\"multi_hot\"`, but the int array contains\n                a count of the number of times the token at that index\n                appeared in the batch item.\n            - `\"tf_idf\"`: Like `\"multi_hot\"`, but the TF-IDF algorithm\n                is applied to find the value in each token slot.\n            For `\"int\"` output, any shape of input and output is supported.\n            For all other output modes, currently only rank 1 inputs\n            (and rank 2 outputs after splitting) are supported.\n        output_sequence_length: Only valid in INT mode. If set, the output will\n            have its time dimension padded or truncated to exactly\n            `output_sequence_length` values, resulting in a tensor of shape\n            `(batch_size, output_sequence_length)` regardless of how many tokens\n            resulted from the splitting step. Defaults to `None`. If `ragged`\n            is `True` then `output_sequence_length` may still truncate the\n            output.\n        pad_to_max_tokens: Only valid in  `\"multi_hot\"`, `\"count\"`,\n            and `\"tf_idf\"` modes. If `True`, the output will have\n            its feature axis padded to `max_tokens` even if the number\n            of unique tokens in the vocabulary is less than `max_tokens`,\n            resulting in a tensor of shape `(batch_size, max_tokens)`\n            regardless of vocabulary size. Defaults to `False`.\n        vocabulary: Optional. 
Either an array of strings or a string path to a\n            text file. If passing an array, can pass a tuple, list,\n            1D NumPy array, or 1D tensor containing the string vocabulary terms.\n            If passing a file path, the file should contain one line per term\n            in the vocabulary. If this argument is set,\n            there is no need to `adapt()` the layer.\n        idf_weights: Only valid when `output_mode` is `\"tf_idf\"`. A tuple, list,\n            1D NumPy array, or 1D tensor of the same length as the vocabulary,\n            containing the floating point inverse document frequency weights,\n            which will be multiplied by per sample term counts for\n            the final `tf_idf` weight. If the `vocabulary` argument is set,\n            and `output_mode` is `\"tf_idf\"`, this argument must be supplied.\n        ragged: Boolean. Only applicable to `\"int\"` output mode.\n            Only supported with TensorFlow backend.\n            If `True`, returns a `RaggedTensor` instead of a dense `Tensor`,\n            where each sequence may have a different length\n            after string splitting. Defaults to `False`.\n        sparse: Boolean. Only applicable to `\"multi_hot\"`, `\"count\"`, and\n            `\"tf_idf\"` output modes. Only supported with TensorFlow\n            backend. If `True`, returns a `SparseTensor`\n            instead of a dense `Tensor`. Defaults to `False`.\n        encoding: Optional. The text encoding to use to interpret the input\n            strings. Defaults to `\"utf-8\"`.\n\n    Examples:\n\n    This example instantiates a `TextVectorization` layer that lowercases text,\n    splits on whitespace, strips punctuation, and outputs integer vocab indices.\n\n    >>> max_tokens = 5000  # Maximum vocab size.\n    >>> max_len = 4  # Sequence length to pad the outputs to.\n    >>> # Create the layer.\n    >>> vectorize_layer = TextVectorization(\n    ...     max_tokens=max_tokens,\n    ...     
output_mode='int',\n    ...     output_sequence_length=max_len)\n\n    >>> # Now that the vocab layer has been created, call `adapt` on the\n    >>> # list of strings to create the vocabulary.\n    >>> vectorize_layer.adapt([\"foo bar\", \"bar baz\", \"baz bada boom\"])\n\n    >>> # Now, the layer can map strings to integers -- you can use an\n    >>> # embedding layer to map these integers to learned embeddings.\n    >>> input_data = [[\"foo qux bar\"], [\"qux baz\"]]\n    >>> vectorize_layer(input_data)\n    array([[4, 1, 3, 0],\n           [1, 2, 0, 0]])\n\n    This example instantiates a `TextVectorization` layer by passing a list\n    of vocabulary terms to the layer's `__init__()` method.\n\n    >>> vocab_data = [\"earth\", \"wind\", \"and\", \"fire\"]\n    >>> max_len = 4  # Sequence length to pad the outputs to.\n    >>> # Create the layer, passing the vocab directly. You can also pass the\n    >>> # vocabulary arg a path to a file containing one vocabulary word per\n    >>> # line.\n    >>> vectorize_layer = keras.layers.TextVectorization(\n    ...     max_tokens=max_tokens,\n    ...     output_mode='int',\n    ...     output_sequence_length=max_len,\n    ...     vocabulary=vocab_data)\n\n    >>> # Because we've passed the vocabulary directly, we don't need to adapt\n    >>> # the layer - the vocabulary is already set. 
The vocabulary contains the\n    >>> # padding token ('') and OOV token ('[UNK]')\n    >>> # as well as the passed tokens.\n    >>> vectorize_layer.get_vocabulary()\n    ['', '[UNK]', 'earth', 'wind', 'and', 'fire']\n    \"\"\"\n\n    def __init__(\n        self,\n        max_tokens=None,\n        standardize=\"lower_and_strip_punctuation\",\n        split=\"whitespace\",\n        ngrams=None,\n        output_mode=\"int\",\n        output_sequence_length=None,\n        pad_to_max_tokens=False,\n        vocabulary=None,\n        idf_weights=None,\n        sparse=False,\n        ragged=False,\n        encoding=\"utf-8\",\n        name=None,\n        **kwargs,\n    ):\n        if not tf.available:\n            raise ImportError(\n                \"Layer TextVectorization requires TensorFlow. \"\n                \"Install it via `pip install tensorflow`.\"\n            )\n        if sparse and backend.backend() != \"tensorflow\":\n            raise ValueError(\n                \"`sparse=True` can only be used with the TensorFlow backend.\"\n            )\n        if ragged and backend.backend() != \"tensorflow\":\n            raise ValueError(\n                \"`ragged=True` can only be used with the TensorFlow backend.\"\n            )\n\n        # 'standardize' must be one of\n        # (None, \"lower_and_strip_punctuation\", \"lower\", \"strip_punctuation\",\n        # callable)\n        argument_validation.validate_string_arg(\n            standardize,\n            allowable_strings=(\n                \"lower_and_strip_punctuation\",\n                \"lower\",\n                \"strip_punctuation\",\n            ),\n            caller_name=self.__class__.__name__,\n            arg_name=\"standardize\",\n            allow_none=True,\n            allow_callables=True,\n        )\n\n        # 'split' must be one of (None, \"whitespace\", \"character\", callable)\n        argument_validation.validate_string_arg(\n            split,\n            
allowable_strings=(\"whitespace\", \"character\"),\n            caller_name=self.__class__.__name__,\n            arg_name=\"split\",\n            allow_none=True,\n            allow_callables=True,\n        )\n\n        # Support deprecated names for output_modes.\n        if output_mode == \"binary\":\n            output_mode = \"multi_hot\"\n        if output_mode == \"tf-idf\":\n            output_mode = \"tf_idf\"\n        argument_validation.validate_string_arg(\n            output_mode,\n            allowable_strings=(\n                \"int\",\n                \"one_hot\",\n                \"multi_hot\",\n                \"count\",\n                \"tf_idf\",\n            ),\n            caller_name=self.__class__.__name__,\n            arg_name=\"output_mode\",\n        )\n\n        # 'ngrams' must be one of (None, int, tuple(int))\n        if not (\n            ngrams is None\n            or isinstance(ngrams, int)\n            or isinstance(ngrams, tuple)\n            and all(isinstance(item, int) for item in ngrams)\n        ):\n            raise ValueError(\n                \"`ngrams` must be None, an integer, or a tuple of \"\n                f\"integers. Received: ngrams={ngrams}\"\n            )\n\n        # 'output_sequence_length' must be one of (None, int) and is only\n        # set if output_mode is \"int\"\".\n        if output_mode == \"int\" and not (\n            isinstance(output_sequence_length, int)\n            or (output_sequence_length is None)\n        ):\n            raise ValueError(\n                \"`output_sequence_length` must be either None or an \"\n                \"integer when `output_mode` is 'int'. Received: \"\n                f\"output_sequence_length={output_sequence_length}\"\n            )\n\n        if output_mode != \"int\" and output_sequence_length is not None:\n            raise ValueError(\n                \"`output_sequence_length` must not be set if `output_mode` is \"\n                \"not 'int'. 
\"\n                f\"Received output_sequence_length={output_sequence_length}.\"\n            )\n\n        if ragged and output_mode != \"int\":\n            raise ValueError(\n                \"`ragged` must not be true if `output_mode` is \"\n                f\"`'int'`. Received: ragged={ragged} and \"\n                f\"output_mode={output_mode}\"\n            )\n\n        self._max_tokens = max_tokens\n        self._standardize = standardize\n        self._split = split\n        self._ngrams_arg = ngrams\n        if isinstance(ngrams, int):\n            self._ngrams = tuple(range(1, ngrams + 1))\n        else:\n            self._ngrams = ngrams\n        self._ragged = ragged\n\n        self._output_mode = output_mode\n        self._output_sequence_length = output_sequence_length\n        self._encoding = encoding\n\n        # We save this hidden option to persist the fact\n        # that we have a non-adaptable layer with a\n        # manually set vocab.\n        self._has_input_vocabulary = kwargs.pop(\n            \"has_input_vocabulary\", (vocabulary is not None)\n        )\n        vocabulary_size = kwargs.pop(\"vocabulary_size\", None)\n\n        super().__init__(name=name, **kwargs)\n\n        self._lookup_layer = StringLookup(\n            max_tokens=max_tokens,\n            vocabulary=vocabulary,\n            idf_weights=idf_weights,\n            pad_to_max_tokens=pad_to_max_tokens,\n            mask_token=\"\",\n            output_mode=output_mode,\n            sparse=sparse,\n            has_input_vocabulary=self._has_input_vocabulary,\n            encoding=encoding,\n            vocabulary_size=vocabulary_size,\n        )\n        self._convert_input_args = False\n        self._allow_non_tensor_positional_args = True\n        self.supports_jit = False\n\n    @property\n    def compute_dtype(self):\n        return \"string\"\n\n    @property\n    def variable_dtype(self):\n        return \"string\"\n\n    def build(self, input_shape=None):\n        
pass\n\n    def compute_output_shape(self, input_shape):\n        if self._output_mode == \"int\":\n            return (input_shape[0], self._output_sequence_length)\n        if self._split is None:\n            if len(input_shape) <= 1:\n                input_shape = tuple(input_shape) + (1,)\n        else:\n            input_shape = tuple(input_shape) + (None,)\n        return self._lookup_layer.compute_output_shape(input_shape)\n\n    def compute_output_spec(self, inputs):\n        output_shape = self.compute_output_shape(inputs.shape)\n        if self._output_mode == \"int\":\n            output_dtype = \"int64\"\n        else:\n            output_dtype = backend.floatx()\n        return backend.KerasTensor(output_shape, dtype=output_dtype)\n\n    def adapt(self, data, batch_size=None, steps=None):\n        \"\"\"Computes a vocabulary of string terms from tokens in a dataset.\n\n        Calling `adapt()` on a `TextVectorization` layer is an alternative to\n        passing in a precomputed vocabulary on construction via the `vocabulary`\n        argument. A `TextVectorization` layer should always be either adapted\n        over a dataset or supplied with a vocabulary.\n\n        During `adapt()`, the layer will build a vocabulary of all string tokens\n        seen in the dataset, sorted by occurrence count, with ties broken by\n        sort order of the tokens (high to low). At the end of `adapt()`, if\n        `max_tokens` is set, the vocabulary will be truncated to `max_tokens`\n        size. For example, adapting a layer with `max_tokens=1000` will compute\n        the 1000 most frequent tokens occurring in the input dataset. If\n        `output_mode='tf-idf'`, `adapt()` will also learn the document\n        frequencies of each token in the input dataset.\n\n        Arguments:\n            data: The data to train on. 
It can be passed either as a\n                batched `tf.data.Dataset`, as a list of strings,\n                or as a NumPy array.\n            steps: Integer or `None`.\n                Total number of steps (batches of samples) to process.\n                If `data` is a `tf.data.Dataset`, and `steps` is `None`,\n                `adapt()` will run until the input dataset is exhausted.\n                When passing an infinitely\n                repeating dataset, you must specify the `steps` argument. This\n                argument is not supported with array inputs or list inputs.\n        \"\"\"\n        self.reset_state()\n        if isinstance(data, tf.data.Dataset):\n            if steps is not None:\n                data = data.take(steps)\n            for batch in data:\n                self.update_state(batch)\n        else:\n            data = tf_utils.ensure_tensor(data, dtype=\"string\")\n            if data.shape.rank == 1:\n                # A plain list of strings\n                # is treated as as many documents\n                data = tf.expand_dims(data, -1)\n            self.update_state(data)\n        self.finalize_state()\n\n    def update_state(self, data):\n        self._lookup_layer.update_state(self._preprocess(data))\n\n    def finalize_state(self):\n        self._lookup_layer.finalize_state()\n\n    def reset_state(self):\n        self._lookup_layer.reset_state()\n\n    def get_vocabulary(self, include_special_tokens=True):\n        \"\"\"Returns the current vocabulary of the layer.\n\n        Args:\n            include_special_tokens: If `True`, the returned vocabulary\n                will include the padding and OOV tokens,\n                and a term's index in the vocabulary will equal\n                the term's index when calling the layer. 
If `False`, the\n                returned vocabulary will not include any padding\n                or OOV tokens.\n        \"\"\"\n        return self._lookup_layer.get_vocabulary(include_special_tokens)\n\n    def vocabulary_size(self):\n        \"\"\"Gets the current size of the layer's vocabulary.\n\n        Returns:\n            The integer size of the vocabulary, including optional\n            mask and OOV indices.\n        \"\"\"\n        return self._lookup_layer.vocabulary_size()\n\n    def get_config(self):\n        config = {\n            \"max_tokens\": self._lookup_layer.max_tokens,\n            \"standardize\": self._standardize,\n            \"split\": self._split,\n            \"ngrams\": self._ngrams_arg,\n            \"output_mode\": self._output_mode,\n            \"output_sequence_length\": self._output_sequence_length,\n            \"pad_to_max_tokens\": self._lookup_layer.pad_to_max_tokens,\n            \"sparse\": self._lookup_layer.sparse,\n            \"ragged\": self._ragged,\n            \"vocabulary\": listify_tensors(self._lookup_layer.input_vocabulary),\n            \"idf_weights\": listify_tensors(\n                self._lookup_layer.input_idf_weights\n            ),\n            \"encoding\": self._encoding,\n            \"vocabulary_size\": self.vocabulary_size(),\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config):\n        if not isinstance(config[\"standardize\"], str):\n            config[\"standardize\"] = serialization_lib.deserialize_keras_object(\n                config[\"standardize\"]\n            )\n        if not isinstance(config[\"split\"], str):\n            config[\"split\"] = serialization_lib.deserialize_keras_object(\n                config[\"split\"]\n            )\n\n        if isinstance(config[\"ngrams\"], list):\n            config[\"ngrams\"] = tuple(config[\"ngrams\"])\n\n        return cls(**config)\n\n    
def set_vocabulary(self, vocabulary, idf_weights=None):\n        \"\"\"Sets vocabulary (and optionally document frequency) for this layer.\n\n        This method sets the vocabulary and IDF weights for this layer directly,\n        instead of analyzing a dataset through `adapt()`. It should be used\n        whenever the vocab (and optionally document frequency) information is\n        already known. If vocabulary data is already present in the layer, this\n        method will replace it.\n\n        Args:\n            vocabulary: Either an array or a string path to a text file.\n                If passing an array, can pass a tuple, list, 1D NumPy array,\n                or 1D tensor containing the vocabulary terms.\n                If passing a file path, the file should contain one line\n                per term in the vocabulary.\n            idf_weights: A tuple, list, 1D NumPy array, or 1D tensor of inverse\n                document frequency weights with equal length to vocabulary.\n                Must be set if `output_mode` is `\"tf_idf\"`.\n                Should not be set otherwise.\n        \"\"\"\n        self._lookup_layer.set_vocabulary(vocabulary, idf_weights=idf_weights)\n\n    def _preprocess(self, inputs):\n        inputs = tf_utils.ensure_tensor(inputs, dtype=tf.string)\n        if self._standardize in (\"lower\", \"lower_and_strip_punctuation\"):\n            inputs = tf.strings.lower(inputs)\n        if self._standardize in (\n            \"strip_punctuation\",\n            \"lower_and_strip_punctuation\",\n        ):\n            inputs = tf.strings.regex_replace(\n                inputs, r'[!\"#$%&()\\*\\+,-\\./:;<=>?@\\[\\\\\\]^_`{|}~\\']', \"\"\n            )\n        if callable(self._standardize):\n            inputs = self._standardize(inputs)\n\n        if self._split is not None:\n            # If we are splitting, we validate that the 1st axis is of dimension\n            # 1 and so can be squeezed out. 
We do this here instead of after\n            # splitting for performance reasons - it's more expensive to squeeze\n            # a ragged tensor.\n            if inputs.shape.rank > 1:\n                if inputs.shape[-1] != 1:\n                    raise ValueError(\n                        \"When using `TextVectorization` to tokenize strings, \"\n                        \"the input rank must be 1 or the last shape dimension \"\n                        f\"must be 1. Received: inputs.shape={inputs.shape} \"\n                        f\"with rank={inputs.shape.rank}\"\n                    )\n                else:\n                    inputs = tf.squeeze(inputs, axis=-1)\n            if self._split == \"whitespace\":\n                # This treats multiple whitespaces as one whitespace, and strips\n                # leading and trailing whitespace.\n                inputs = tf.strings.split(inputs)\n            elif self._split == \"character\":\n                inputs = tf.strings.unicode_split(inputs, \"UTF-8\")\n            elif callable(self._split):\n                inputs = self._split(inputs)\n\n        # Note that 'inputs' here can be either ragged or dense depending on the\n        # configuration choices for this Layer. 
The strings.ngrams op, however,\n        # does support both ragged and dense inputs.\n        if self._ngrams is not None:\n            inputs = tf.strings.ngrams(\n                inputs, ngram_width=self._ngrams, separator=\" \"\n            )\n        return inputs\n\n    def call(self, inputs):\n        if not isinstance(\n            inputs, (tf.Tensor, tf.RaggedTensor, np.ndarray, list, tuple)\n        ):\n            inputs = tf.convert_to_tensor(backend.convert_to_numpy(inputs))\n\n        inputs = self._preprocess(inputs)\n\n        # If we're not doing any output processing, return right away.\n        if self._output_mode is None:\n            return backend_utils.convert_tf_tensor(inputs)\n\n        lookup_data = self._lookup_layer.call(inputs)\n\n        # For non-int output, we can return directly from the underlying layer.\n        if self._output_mode != \"int\":\n            return backend_utils.convert_tf_tensor(lookup_data)\n\n        # If we have a ragged tensor, we can pad during the conversion to dense.\n        if isinstance(lookup_data, tf.RaggedTensor) and not self._ragged:\n            shape = lookup_data.shape.as_list()\n            # If output sequence length is None, to_tensor will pad the last\n            # dimension to the bounding shape of the ragged dimension.\n            shape[-1] = self._output_sequence_length\n            outputs = lookup_data.to_tensor(default_value=0, shape=shape)\n        # If we have a dense tensor, we need to pad/trim directly.\n        elif self._output_sequence_length is not None:\n            # Maybe trim the output.\n            outputs = lookup_data[..., : self._output_sequence_length]\n\n            # Maybe pad the output. 
We need to be careful to use dynamic shape\n            # here as required_space_to_batch_paddings requires a fully known\n            # shape.\n            if not self._ragged:\n                shape = tf.shape(outputs)\n                padded_shape = tf.concat(\n                    (shape[:-1], [self._output_sequence_length]), 0\n                )\n                padding, _ = tf.required_space_to_batch_paddings(\n                    shape, padded_shape\n                )\n                outputs = tf.pad(outputs, padding)\n                # Because `tf.pad` used a dynamic shape, the output shape is\n                # dynamic. Apply the known static `_output_sequence_length`.\n                static_padded_shape = lookup_data.shape.as_list()\n                static_padded_shape[-1] = self._output_sequence_length\n                outputs.set_shape(static_padded_shape)\n        else:\n            outputs = lookup_data\n\n        return backend_utils.convert_tf_tensor(outputs)\n\n    def save_own_variables(self, store):\n        self._lookup_layer.save_own_variables(store)\n\n    def load_own_variables(self, store):\n        self._lookup_layer.load_own_variables(store)\n\n    def save_assets(self, dir_path):\n        self._lookup_layer.save_assets(dir_path)\n\n    def load_assets(self, dir_path):\n        self._lookup_layer.load_assets(dir_path)\n"
  },
  {
    "path": "keras/src/layers/preprocessing/text_vectorization_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nimport tensorflow as tf\nfrom absl.testing import parameterized\nfrom tensorflow import data as tf_data\n\nfrom keras.src import Sequential\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import saving\nfrom keras.src import testing\n\n\nclass TextVectorizationTest(testing.TestCase, parameterized.TestCase):\n    # TODO: increase coverage. Most features aren't being tested.\n\n    def test_config(self):\n        layer = layers.TextVectorization(\n            output_mode=\"int\",\n            vocabulary=[\"one\", \"two\"],\n            output_sequence_length=5,\n        )\n        self.run_class_serialization_test(layer)\n\n    def test_adapt_flow(self):\n        max_tokens = 5000\n        max_len = 4\n        layer = layers.TextVectorization(\n            max_tokens=max_tokens,\n            output_mode=\"int\",\n            output_sequence_length=max_len,\n        )\n        layer.adapt([\"foo bar\", \"bar baz\", \"baz bada boom\"])\n        input_data = [[\"foo qux bar\"], [\"qux baz\"]]\n        output = layer(input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([[4, 1, 3, 0], [1, 2, 0, 0]]))\n\n    def test_fixed_vocabulary(self):\n        max_tokens = 5000\n        max_len = 4\n        layer = layers.TextVectorization(\n            max_tokens=max_tokens,\n            output_mode=\"int\",\n            output_sequence_length=max_len,\n            vocabulary=[\"baz\", \"bar\", \"foo\"],\n        )\n        input_data = [[\"foo qux bar\"], [\"qux baz\"]]\n        output = layer(input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([[4, 1, 3, 0], [1, 2, 0, 0]]))\n\n    def test_set_vocabulary(self):\n        max_tokens = 5000\n        max_len = 4\n        layer = layers.TextVectorization(\n            max_tokens=max_tokens,\n           
 output_mode=\"int\",\n            output_sequence_length=max_len,\n        )\n        layer.set_vocabulary([\"baz\", \"bar\", \"foo\"])\n        input_data = [[\"foo qux bar\"], [\"qux baz\"]]\n        output = layer(input_data)\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllClose(output, np.array([[4, 1, 3, 0], [1, 2, 0, 0]]))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Requires string input dtype\"\n    )\n    def test_save_load_with_ngrams_flow(self):\n        input_data = np.array([\"foo bar\", \"bar baz\", \"baz bada boom\"])\n        model = Sequential(\n            [\n                layers.Input(dtype=\"string\", shape=(1,)),\n                layers.TextVectorization(ngrams=(1, 2)),\n            ]\n        )\n        model.layers[0].adapt(input_data)\n        output = model(input_data)\n        temp_filepath = os.path.join(self.get_temp_dir(), \"model.keras\")\n        model.save(temp_filepath)\n        model = saving.load_model(temp_filepath)\n        self.assertAllClose(output, model(input_data))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Requires string input dtype\"\n    )\n    def test_save_load_tf_idf_mode(self):\n        input_data = np.array([\"foo bar\", \"bar baz\", \"baz bada boom\"])\n        model = Sequential(\n            [\n                layers.Input(dtype=\"string\", shape=()),\n                layers.TextVectorization(max_tokens=100, output_mode=\"tf_idf\"),\n            ]\n        )\n        model.layers[0].adapt(input_data)\n        output = model(input_data)\n        temp_filepath = os.path.join(self.get_temp_dir(), \"model.keras\")\n        model.save(temp_filepath)\n        loaded_model = saving.load_model(temp_filepath)\n        self.assertAllClose(output, loaded_model(input_data))\n\n    def test_tf_data_compatibility(self):\n        max_tokens = 5000\n        max_len = 4\n        layer = layers.TextVectorization(\n          
  max_tokens=max_tokens,\n            output_mode=\"int\",\n            output_sequence_length=max_len,\n            vocabulary=[\"baz\", \"bar\", \"foo\"],\n        )\n        input_data = [[\"foo qux bar\"], [\"qux baz\"]]\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        output = next(iter(ds)).numpy()\n        self.assertAllClose(output, np.array([[4, 1, 3, 0], [1, 2, 0, 0]]))\n\n        # Test adapt flow\n        layer = layers.TextVectorization(\n            max_tokens=max_tokens,\n            output_mode=\"int\",\n            output_sequence_length=max_len,\n        )\n        layer.adapt(input_data)\n        ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        next(iter(ds)).numpy()\n\n    @parameterized.named_parameters(\n        [\n            (\"from_ragged\", \"whitespace\"),  # intermediate tensor is ragged\n            (\"from_dense\", None),  # intermediate tensor is dense\n        ]\n    )\n    def test_static_output_sequence_length(self, split):\n        max_tokens = 5000\n        max_len = 4\n        layer = layers.TextVectorization(\n            max_tokens=max_tokens,\n            output_mode=\"int\",\n            output_sequence_length=max_len,\n            split=split,\n            vocabulary=[\"baz\", \"bar\", \"foo\"],\n        )\n        if split:\n            input_data = [[\"foo qux bar\"], [\"qux baz\"]]\n        else:\n            input_data = [[\"foo\"], [\"baz\"]]\n\n        def call_layer(x):\n            result = layer(x)\n            self.assertEqual(result.shape, (None, 4))\n            return result\n\n        ds = (\n            tf_data.Dataset.from_tensor_slices(input_data)\n            .batch(2)\n            .map(call_layer)\n        )\n        next(iter(ds))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Requires string tensors.\"\n    )\n    def test_tf_as_first_sequential_layer(self):\n        layer = 
layers.TextVectorization(\n            max_tokens=10,\n            output_mode=\"int\",\n            output_sequence_length=3,\n        )\n        layer.set_vocabulary([\"baz\", \"bar\", \"foo\"])\n        model = models.Sequential(\n            [\n                layer,\n                layers.Embedding(5, 4),\n            ]\n        )\n        model(backend.convert_to_tensor([[\"foo qux bar\"], [\"qux baz\"]]))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Requires ragged tensors.\"\n    )\n    def test_ragged_tensor(self):\n        layer = layers.TextVectorization(\n            output_mode=\"int\",\n            vocabulary=[\"baz\", \"bar\", \"foo\"],\n            ragged=True,\n        )\n        input_data = [[\"foo qux bar\"], [\"qux baz\"], [\"foo\"]]\n        output = layer(input_data)\n        self.assertIsInstance(output, tf.RaggedTensor)\n        self.assertEqual(output.shape, (3, None))\n        self.assertEqual(output.to_list(), [[4, 1, 3], [1, 2], [4]])\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Requires ragged tensors.\"\n    )\n    def test_ragged_tensor_output_length(self):\n        layer = layers.TextVectorization(\n            output_mode=\"int\",\n            vocabulary=[\"baz\", \"bar\", \"foo\"],\n            ragged=True,\n            output_sequence_length=2,\n        )\n        input_data = [[\"foo qux bar\"], [\"qux baz\"], [\"foo\"]]\n        output = layer(input_data)\n        self.assertIsInstance(output, tf.RaggedTensor)\n        self.assertEqual(output.shape, (3, None))\n        self.assertEqual(output.to_list(), [[4, 1], [1, 2], [4]])\n\n    @pytest.mark.skipif(\n        backend.backend() == \"tensorflow\",\n        reason=\"Verify raises exception for non-TF backends\",\n    )\n    def test_raises_exception_ragged_tensor(self):\n        with self.assertRaises(ValueError):\n            _ = layers.TextVectorization(\n                output_mode=\"int\",\n        
        vocabulary=[\"baz\", \"bar\", \"foo\"],\n                ragged=True,\n            )\n\n    def test_multi_hot_output(self):\n        layer = layers.TextVectorization(\n            output_mode=\"multi_hot\", vocabulary=[\"foo\", \"bar\", \"baz\"]\n        )\n        input_data = [[\"foo bar\"], [\"baz foo foo\"]]\n        output = layer(input_data)\n\n        \"\"\"\n        First batch\n            Tokens present: [\"foo\", \"bar\"]\n            For each token in vocabulary:\n            foo (index 1): present -> 1\n            bar (index 2): present -> 1\n            baz (index 3): absent -> 0\n            Result: [0, 1, 1, 0]\n\n        Second batch\n            Tokens present: [\"baz\", \"foo\", \"foo\"]\n            For each token in vocabulary:\n            foo (index 1): present -> 1\n            bar (index 2): absent -> 0\n            baz (index 3): present -> 1\n            Result: [0, 1, 0, 1]\n        \"\"\"\n        self.assertAllClose(output, [[0, 1, 1, 0], [0, 1, 0, 1]])\n\n    def test_output_mode_count_output(self):\n        layer = layers.TextVectorization(\n            output_mode=\"count\", vocabulary=[\"foo\", \"bar\", \"baz\"]\n        )\n        output = layer([\"foo bar\", \"baz foo foo\"])\n        self.assertAllClose(output, [[0, 1, 1, 0], [0, 2, 0, 1]])\n\n    def test_output_mode_tf_idf_output(self):\n        layer = layers.TextVectorization(\n            output_mode=\"tf_idf\",\n            vocabulary=[\"foo\", \"bar\", \"baz\"],\n            idf_weights=[0.3, 0.5, 0.2],\n        )\n        output = layer([\"foo bar\", \"baz foo foo\"])\n        self.assertAllClose(\n            output, [[0.0, 0.3, 0.5, 0.0], [0.0, 0.6, 0.0, 0.2]]\n        )\n\n    def test_lower_and_strip_punctuation_standardization(self):\n        layer = layers.TextVectorization(\n            standardize=\"lower_and_strip_punctuation\",\n            vocabulary=[\"hello\", \"world\", \"this\", \"is\", \"nice\", \"test\"],\n        )\n        output = 
layer([\"Hello, World!. This is just a nice test!\"])\n        self.assertTrue(backend.is_tensor(output))\n\n        # test output sequence length, taking first batch.\n        self.assertEqual(len(output[0]), 8)\n\n        self.assertAllEqual(output, [[2, 3, 4, 5, 1, 1, 6, 7]])\n\n    def test_lower_standardization(self):\n        layer = layers.TextVectorization(\n            standardize=\"lower\",\n            vocabulary=[\n                \"hello,\",\n                \"hello\",\n                \"world\",\n                \"this\",\n                \"is\",\n                \"nice\",\n                \"test\",\n            ],\n        )\n        output = layer([\"Hello, World!. This is just a nice test!\"])\n        self.assertTrue(backend.is_tensor(output))\n        self.assertEqual(len(output[0]), 8)\n        \"\"\"\n        The input is lowercased and tokenized into words. The vocab is:\n        {0: '',\n        1: '[UNK]',\n        2: 'hello,',\n        3: 'hello',\n        4: 'world',\n        5: 'this',\n        6: 'is',\n        7: 'nice',\n        8: 'test'}\n        \"\"\"\n        self.assertAllEqual(output, [[2, 1, 5, 6, 1, 1, 7, 1]])\n\n    def test_char_splitting(self):\n        layer = layers.TextVectorization(\n            split=\"character\", vocabulary=list(\"abcde\"), output_mode=\"int\"\n        )\n        output = layer([\"abcf\"])\n        self.assertTrue(backend.is_tensor(output))\n        self.assertEqual(len(output[0]), 4)\n        self.assertAllEqual(output, [[2, 3, 4, 1]])\n\n    def test_custom_splitting(self):\n        def custom_split(text):\n            return tf.strings.split(text, sep=\"|\")\n\n        layer = layers.TextVectorization(\n            split=custom_split,\n            vocabulary=[\"foo\", \"bar\", \"foobar\"],\n            output_mode=\"int\",\n        )\n        output = layer([\"foo|bar\"])\n        self.assertTrue(backend.is_tensor(output))\n\n        # after custom split, the outputted index should be the last\n   
     # token in the vocab.\n        self.assertAllEqual(output, [[4]])\n\n    def test_strip_punctuation_standardization(self):\n        layer = layers.TextVectorization(\n            standardize=\"strip_punctuation\",\n            vocabulary=[\"Hello\", \"World\", \"Test\"],\n        )\n        output = layer([\"Hello, World! Test.\"])\n        self.assertTrue(backend.is_tensor(output))\n        # Case is preserved, punctuation stripped\n        self.assertAllEqual(output, [[2, 3, 4]])\n\n    def test_no_standardization(self):\n        layer = layers.TextVectorization(\n            standardize=None,\n            vocabulary=[\"Hello\", \"world\"],\n        )\n        # \"Hello\" matches, \"hello\" does not (case-sensitive)\n        output = layer([\"Hello hello\"])\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllEqual(output, [[2, 1]])\n\n    def test_custom_standardize_callable(self):\n        def custom_standardize(text):\n            return tf.strings.regex_replace(text, \"-\", \" \")\n\n        layer = layers.TextVectorization(\n            standardize=custom_standardize,\n            split=\"whitespace\",\n            vocabulary=[\"foo\", \"bar\"],\n        )\n        output = layer([\"foo-bar\"])\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllEqual(output, [[2, 3]])\n\n    def test_no_split(self):\n        layer = layers.TextVectorization(\n            split=None,\n            vocabulary=[\"foo\", \"bar\", \"baz\"],\n            output_mode=\"int\",\n        )\n        # Each element is looked up as a whole string (no splitting)\n        output = layer([[\"foo\"], [\"bar\"], [\"unknown\"]])\n        self.assertTrue(backend.is_tensor(output))\n        self.assertAllEqual(output, [[2], [3], [1]])\n\n    def test_ngrams_integer(self):\n        layer = layers.TextVectorization(\n            ngrams=2,\n            output_mode=\"int\",\n        )\n        layer.adapt([\"foo bar baz\"])\n        vocab = 
layer.get_vocabulary()\n        # ngrams=2 produces unigrams and bigrams\n        # Verify bigrams are in the vocabulary\n        self.assertIn(\"foo bar\", vocab)\n        self.assertIn(\"bar baz\", vocab)\n\n    def test_ngrams_tuple(self):\n        layer = layers.TextVectorization(\n            ngrams=(1, 3),\n            output_mode=\"int\",\n        )\n        layer.adapt([\"foo bar baz\"])\n        vocab = layer.get_vocabulary()\n        # Should have unigrams and trigrams but not bigrams\n        self.assertIn(\"foo\", vocab)\n        self.assertIn(\"foo bar baz\", vocab)\n        self.assertNotIn(\"foo bar\", vocab)\n\n    def test_max_tokens(self):\n        layer = layers.TextVectorization(\n            output_mode=\"int\",\n            max_tokens=5,\n        )\n        layer.adapt([\"a b c d e f g h i j\"])\n        vocab = layer.get_vocabulary()\n        self.assertEqual(len(vocab), 5)\n\n    def test_pad_to_max_tokens(self):\n        layer = layers.TextVectorization(\n            output_mode=\"multi_hot\",\n            vocabulary=[\"foo\", \"bar\"],\n            max_tokens=8,\n            pad_to_max_tokens=True,\n        )\n        output = layer([\"foo bar\"])\n        self.assertEqual(output.shape[-1], 8)\n\n    def test_get_vocabulary(self):\n        layer = layers.TextVectorization(\n            output_mode=\"int\",\n            vocabulary=[\"foo\", \"bar\", \"baz\"],\n        )\n        vocab = layer.get_vocabulary()\n        self.assertEqual(vocab[0], \"\")\n        self.assertEqual(vocab[1], \"[UNK]\")\n        self.assertIn(\"foo\", vocab)\n        self.assertIn(\"bar\", vocab)\n        self.assertIn(\"baz\", vocab)\n\n    def test_get_vocabulary_no_special_tokens(self):\n        layer = layers.TextVectorization(\n            output_mode=\"int\",\n            vocabulary=[\"foo\", \"bar\", \"baz\"],\n        )\n        vocab = layer.get_vocabulary(include_special_tokens=False)\n        self.assertNotIn(\"\", vocab)\n        
self.assertNotIn(\"[UNK]\", vocab)\n        self.assertEqual(len(vocab), 3)\n\n    def test_vocabulary_size(self):\n        layer = layers.TextVectorization(\n            output_mode=\"int\",\n            vocabulary=[\"foo\", \"bar\", \"baz\"],\n        )\n        # 3 vocab + mask + OOV = 5\n        self.assertEqual(layer.vocabulary_size(), 5)\n\n    def test_vocabulary_from_file(self):\n        tmp_dir = self.get_temp_dir()\n        vocab_file = os.path.join(tmp_dir, \"vocab.txt\")\n        with open(vocab_file, \"w\") as f:\n            f.write(\"foo\\nbar\\nbaz\\n\")\n        layer = layers.TextVectorization(\n            vocabulary=vocab_file,\n            output_mode=\"int\",\n        )\n        output = layer([\"foo bar baz unknown\"])\n        self.assertAllClose(output, np.array([[2, 3, 4, 1]]))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"one_hot output only supported on TensorFlow\",\n    )\n    def test_one_hot_output(self):\n        layer = layers.TextVectorization(\n            output_mode=\"one_hot\",\n            vocabulary=[\"foo\", \"bar\", \"baz\"],\n        )\n        output = layer([\"foo bar\"])\n        self.assertTrue(backend.is_tensor(output))\n        # one_hot on a split sentence produces shape (1, num_tokens, vocab_size)\n        self.assertEqual(output.shape[-1], 4)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"sparse=True only supported on TensorFlow\",\n    )\n    def test_sparse_output(self):\n        layer = layers.TextVectorization(\n            output_mode=\"multi_hot\",\n            vocabulary=[\"foo\", \"bar\", \"baz\"],\n            sparse=True,\n        )\n        output = layer([\"foo bar\"])\n        self.assertTrue(hasattr(output, \"indices\"))\n\n    def test_adapt_with_tf_dataset(self):\n        ds = tf_data.Dataset.from_tensor_slices(\n            [\"foo bar\", \"bar baz\", \"baz foo\"]\n        ).batch(2)\n        layer = 
layers.TextVectorization(output_mode=\"int\")\n        layer.adapt(ds)\n        vocab = layer.get_vocabulary()\n        self.assertIn(\"foo\", vocab)\n        self.assertIn(\"bar\", vocab)\n        self.assertIn(\"baz\", vocab)\n\n    def test_adapt_with_steps(self):\n        ds = tf_data.Dataset.from_tensor_slices(\n            [\"foo bar\", \"bar baz\", \"unique_word\"]\n        ).batch(1)\n        layer = layers.TextVectorization(output_mode=\"int\")\n        # Only process first 2 batches, so \"unique_word\" should not be adapted\n        layer.adapt(ds, steps=2)\n        vocab = layer.get_vocabulary()\n        self.assertIn(\"bar\", vocab)\n        self.assertNotIn(\"unique_word\", vocab)\n\n    def test_invalid_ngrams(self):\n        with self.assertRaises(ValueError):\n            layers.TextVectorization(ngrams=\"invalid\")\n\n    def test_output_sequence_length_with_non_int_mode(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`output_sequence_length` must not be set if `output_mode` is not\",\n        ):\n            layers.TextVectorization(\n                output_mode=\"multi_hot\",\n                output_sequence_length=5,\n            )\n"
  },
  {
    "path": "keras/src/layers/regularization/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/layers/regularization/activity_regularization.py",
    "content": "from keras.src import regularizers\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.ActivityRegularization\")\nclass ActivityRegularization(Layer):\n    \"\"\"Layer that applies an update to the cost based on input activity.\n\n    Args:\n        l1: L1 regularization factor (positive float).\n        l2: L2 regularization factor (positive float).\n\n    Input shape:\n        Arbitrary. Use the keyword argument `input_shape`\n        (tuple of integers, does not include the samples axis)\n        when using this layer as the first layer in a model.\n\n    Output shape:\n        Same shape as input.\n    \"\"\"\n\n    def __init__(self, l1=0.0, l2=0.0, **kwargs):\n        super().__init__(\n            activity_regularizer=regularizers.L1L2(l1=l1, l2=l2), **kwargs\n        )\n        self.supports_masking = True\n        self.l1 = l1\n        self.l2 = l2\n\n        self._build_at_init()\n\n    def call(self, inputs):\n        return inputs\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        base_config = super().get_config()\n        base_config.pop(\"activity_regularizer\", None)\n        config = {\"l1\": self.l1, \"l2\": self.l2}\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/regularization/activity_regularization_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import layers\nfrom keras.src.testing import test_case\n\n\nclass ActivityRegularizationTest(test_case.TestCase):\n    def test_correctness(self):\n        layer = layers.ActivityRegularization(l1=0.2, l2=0.3)\n        layer(2 * np.ones((1,)))\n        self.assertLen(layer.losses, 1)\n        self.assertAllClose(layer.losses[0], 4 * 0.3 + 2 * 0.2)\n\n    @pytest.mark.requires_trainable_backend\n    def test_activity_regularization_basics(self):\n        self.run_layer_test(\n            layers.ActivityRegularization,\n            {\"l1\": 0.1, \"l2\": 0.2},\n            input_shape=(2, 3),\n            input_dtype=\"float32\",\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=1,\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n"
  },
  {
    "path": "keras/src/layers/regularization/alpha_dropout.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.AlphaDropout\")\nclass AlphaDropout(Layer):\n    \"\"\"Applies Alpha Dropout to the input.\n\n    Alpha Dropout is a `Dropout` that keeps the mean and variance of\n    its inputs at their original values, in order to ensure the\n    self-normalizing property even after this dropout.\n    Alpha Dropout fits well with Scaled Exponential Linear Units (SELU) by\n    randomly setting activations to the negative saturation value.\n\n    Args:\n        rate: Float between 0 and 1. The multiplicative noise will have\n            standard deviation `sqrt(rate / (1 - rate))`.\n        noise_shape: 1D integer tensor representing the shape of the\n            binary alpha dropout mask that will be multiplied with the input.\n            For instance, if your inputs have shape\n            `(batch_size, timesteps, features)` and\n            you want the alpha dropout mask to be the same for all timesteps,\n            you can use `noise_shape=(batch_size, 1, features)`.\n        seed: A Python integer to use as random seed.\n\n    Call arguments:\n        inputs: Input tensor (of any rank).\n        training: Python boolean indicating whether the layer should behave in\n            training mode (adding alpha dropout) or in inference mode\n            (doing nothing).\n    \"\"\"\n\n    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):\n        super().__init__(**kwargs)\n        if not 0 <= rate <= 1:\n            raise ValueError(\n                \"Invalid value received for argument \"\n                \"`rate`. Expected a float value between 0 and 1. 
\"\n                f\"Received: rate={rate}\"\n            )\n        self.rate = rate\n        self.seed = seed\n        self.noise_shape = noise_shape\n        if rate > 0:\n            self.seed_generator = backend.random.SeedGenerator(seed)\n        self.supports_masking = True\n\n        self._build_at_init()\n\n    def call(self, inputs, training=False):\n        if training and self.rate > 0:\n            noise_shape = self._get_concrete_noise_shape(\n                inputs, self.noise_shape\n            )\n            alpha = 1.6732632423543772848170429916717\n            scale = 1.0507009873554804934193349852946\n            alpha_p = -alpha * scale\n\n            kept_idx = ops.greater_equal(\n                ops.random.uniform(noise_shape, seed=self.seed_generator),\n                self.rate,\n            )\n            kept_idx = ops.cast(kept_idx, inputs.dtype)\n\n            # Compute affine transformation parameters\n            a = ((1 - self.rate) * (1 + self.rate * alpha_p**2)) ** -0.5\n            b = -a * alpha_p * self.rate\n\n            # Apply mask\n            x = inputs * kept_idx + alpha_p * (1 - kept_idx)\n            return a * x + b\n\n        return inputs\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def _get_concrete_noise_shape(self, inputs, noise_shape):\n        if noise_shape is None:\n            return ops.shape(inputs)\n\n        concrete_inputs_shape = ops.shape(inputs)\n        concrete_noise_shape = []\n        for i, value in enumerate(noise_shape):\n            concrete_noise_shape.append(\n                concrete_inputs_shape[i] if value is None else value\n            )\n        return concrete_noise_shape\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"rate\": self.rate,\n            \"seed\": self.seed,\n            \"noise_shape\": self.noise_shape,\n        }\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/regularization/alpha_dropout_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass AlphaDropoutTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_alpha_dropout_basics(self):\n        self.run_layer_test(\n            layers.AlphaDropout,\n            init_kwargs={\n                \"rate\": 0.2,\n            },\n            input_shape=(2, 3),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=1,\n            expected_num_losses=0,\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n\n    def test_alpha_dropout_correctness(self):\n        inputs = np.ones((20, 500)).astype(\"float32\")\n        layer = layers.AlphaDropout(0.3, seed=1337)\n        outputs = layer(inputs, training=True)\n        self.assertAllClose(\n            np.std(backend.convert_to_numpy(outputs)), 1.0, atol=1e-1\n        )\n\n    def test_alpha_dropout_partial_noise_shape_dynamic(self):\n        inputs = np.ones((20, 5, 10))\n        layer = layers.AlphaDropout(0.5, noise_shape=(None, 1, None))\n        outputs = layer(inputs, training=True)\n        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])\n\n    def test_alpha_dropout_partial_noise_shape_static(self):\n        inputs = np.ones((20, 5, 10))\n        layer = layers.AlphaDropout(0.5, noise_shape=(20, 1, 10))\n        outputs = layer(inputs, training=True)\n        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])\n\n    def test_alpha_dropout_negative_rate(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value received for argument `rate`. 
\"\n            \"Expected a float value between 0 and 1.\",\n        ):\n            _ = layers.AlphaDropout(rate=-0.5)\n\n    def test_alpha_dropout_rate_greater_than_one(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value received for argument `rate`. \"\n            \"Expected a float value between 0 and 1.\",\n        ):\n            _ = layers.AlphaDropout(rate=1.5)\n"
  },
  {
    "path": "keras/src/layers/regularization/dropout.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.Dropout\")\nclass Dropout(Layer):\n    \"\"\"Applies dropout to the input.\n\n    The `Dropout` layer randomly sets input units to 0 with a frequency of\n    `rate` at each step during training time, which helps prevent overfitting.\n    Inputs not set to 0 are scaled up by `1 / (1 - rate)` such that the sum over\n    all inputs is unchanged.\n\n    Note that the `Dropout` layer only applies when `training` is set to `True`\n    in `call()`, such that no values are dropped during inference.\n    When using `model.fit`, `training` will be appropriately set to `True`\n    automatically. In other contexts, you can set the argument explicitly\n    to `True` when calling the layer.\n\n    (This is in contrast to setting `trainable=False` for a `Dropout` layer.\n    `trainable` does not affect the layer's behavior, as `Dropout` does\n    not have any variables/weights that can be frozen during training.)\n\n    Args:\n        rate: Float between 0 and 1. 
Fraction of the input units to drop.\n        noise_shape: 1D integer tensor representing the shape of the\n            binary dropout mask that will be multiplied with the input.\n            For instance, if your inputs have shape\n            `(batch_size, timesteps, features)` and\n            you want the dropout mask to be the same for all timesteps,\n            you can use `noise_shape=(batch_size, 1, features)`.\n        seed: A Python integer to use as random seed.\n\n    Call arguments:\n        inputs: Input tensor (of any rank).\n        training: Python boolean indicating whether the layer should behave in\n            training mode (adding dropout) or in inference mode (doing nothing).\n    \"\"\"\n\n    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):\n        super().__init__(**kwargs)\n        if not 0 <= rate <= 1:\n            raise ValueError(\n                f\"Invalid value received for argument \"\n                \"`rate`. Expected a float value between 0 and 1. \"\n                f\"Received: rate={rate}\"\n            )\n        self.rate = rate\n        self.seed = seed\n        self.noise_shape = self._validate_noise_shape(noise_shape)\n        if rate > 0:\n            self.seed_generator = backend.random.SeedGenerator(seed)\n        self.supports_masking = True\n\n        self._build_at_init()\n\n    def _validate_noise_shape(self, noise_shape):\n        if noise_shape is None:\n            return None\n\n        if isinstance(noise_shape, str):\n            raise ValueError(\n                f\"Invalid value received for argument `noise_shape`. \"\n                f\"Expected a tuple or list of integers. 
\"\n                f\"Received: noise_shape={noise_shape}\"\n            )\n\n        if not isinstance(noise_shape, tuple):\n            try:\n                noise_shape = tuple(noise_shape)\n            except TypeError:\n                raise ValueError(\n                    f\"Invalid value received for argument `noise_shape`. \"\n                    f\"Expected an iterable of integers \"\n                    f\"(e.g., a tuple or list). \"\n                    f\"Received: noise_shape={noise_shape}\"\n                )\n\n        for i, dim in enumerate(noise_shape):\n            if dim is not None:\n                if not isinstance(dim, int):\n                    raise ValueError(\n                        f\"Invalid value received for argument `noise_shape`. \"\n                        f\"Expected all elements to be integers or None. \"\n                        f\"Received element at index {i}: {dim} \"\n                        f\"(type: {type(dim).__name__})\"\n                    )\n\n                if dim <= 0:\n                    raise ValueError(\n                        f\"Invalid value received for argument `noise_shape`. \"\n                        f\"Expected all dimensions to be positive integers \"\n                        f\"or None. 
\"\n                        f\"Received negative or zero value at index {i}: {dim}\"\n                    )\n\n        return noise_shape\n\n    def call(self, inputs, training=False):\n        if training and self.rate > 0:\n            return backend.random.dropout(\n                inputs,\n                self.rate,\n                noise_shape=self.noise_shape,\n                seed=self.seed_generator,\n            )\n        return inputs\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"rate\": self.rate,\n            \"seed\": self.seed,\n            \"noise_shape\": self.noise_shape,\n        }\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/regularization/dropout_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass DropoutTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_dropout_basics(self):\n        self.run_layer_test(\n            layers.Dropout,\n            init_kwargs={\n                \"rate\": 0.2,\n            },\n            input_shape=(2, 3),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=1,\n            expected_num_losses=0,\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n\n    def test_dropout_rescaling(self):\n        inputs = np.ones((20, 500))\n        layer = layers.Dropout(0.5, seed=1337)\n        outputs = layer(inputs, training=True)\n        outputs = backend.convert_to_numpy(outputs)\n        self.assertAllClose(np.mean(outputs), 1.0, atol=0.02)\n        self.assertAllClose(np.max(outputs), 2.0)\n\n    def test_dropout_partial_noise_shape_dynamic(self):\n        inputs = np.ones((20, 5, 10))\n        layer = layers.Dropout(0.5, noise_shape=(None, 1, None))\n        outputs = layer(inputs, training=True)\n        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])\n\n    def test_dropout_partial_noise_shape_static(self):\n        inputs = np.ones((20, 5, 10))\n        layer = layers.Dropout(0.5, noise_shape=(20, 1, 10))\n        outputs = layer(inputs, training=True)\n        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])\n\n    def test_dropout_negative_rate(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value received for argument `rate`. 
\"\n            \"Expected a float value between 0 and 1.\",\n        ):\n            _ = layers.Dropout(rate=-0.5)\n\n    def test_dropout_rate_greater_than_one(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value received for argument `rate`. \"\n            \"Expected a float value between 0 and 1.\",\n        ):\n            _ = layers.Dropout(rate=1.5)\n\n    def test_validate_noise_shape_none(self):\n        layer = layers.Dropout(0.5, noise_shape=None)\n        self.assertIsNone(layer.noise_shape)\n\n    def test_validate_noise_shape_integer_tuple(self):\n        layer = layers.Dropout(0.5, noise_shape=(20, 1, 10))\n        self.assertEqual(layer.noise_shape, (20, 1, 10))\n\n    def test_validate_noise_shape_none_values(self):\n        layer = layers.Dropout(0.5, noise_shape=(None, 1, None))\n        self.assertEqual(layer.noise_shape, (None, 1, None))\n\n    def test_validate_noise_shape_cast_to_a_tuple(self):\n        layer = layers.Dropout(0.5, noise_shape=[20, 1, 10])\n        self.assertEqual(layer.noise_shape, (20, 1, 10))\n        self.assertIsInstance(layer.noise_shape, tuple)\n\n    def test_validate_noise_shape_non_iterable(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value received for argument `noise_shape`. \"\n            \"Expected a tuple or list of integers.\",\n        ):\n            layers.Dropout(0.5, noise_shape=\"Invalid\")\n\n    def test_validate_noise_shape_invalid_type(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value received for argument `noise_shape`. \"\n            \"Expected all elements to be integers or None.\",\n        ):\n            layers.Dropout(0.5, noise_shape=(20, 1.5, 10))\n\n    def test_validate_noise_shape_negative_value(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value received for argument `noise_shape`. 
\"\n            \"Expected all dimensions to be positive integers or None.\",\n        ):\n            layers.Dropout(0.5, noise_shape=(20, -1, 10))\n\n    def test_validate_noise_shape_zero_value(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid value received for argument `noise_shape`. \"\n            \"Expected all dimensions to be positive integers or None.\",\n        ):\n            layers.Dropout(0.5, noise_shape=(20, 0, 10))\n"
  },
  {
    "path": "keras/src/layers/regularization/gaussian_dropout.py",
    "content": "import math\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\n\n\n@keras_export(\"keras.layers.GaussianDropout\")\nclass GaussianDropout(layers.Layer):\n    \"\"\"Apply multiplicative 1-centered Gaussian noise.\n\n    As it is a regularization layer, it is only active at training time.\n\n    Args:\n        rate: Float, drop probability (as with `Dropout`).\n            The multiplicative noise will have\n            standard deviation `sqrt(rate / (1 - rate))`.\n        seed: Integer, optional random seed to enable deterministic behavior.\n\n    Call arguments:\n        inputs: Input tensor (of any rank).\n        training: Python boolean indicating whether the layer should behave in\n            training mode (adding dropout) or in inference mode (doing nothing).\n    \"\"\"\n\n    def __init__(self, rate, seed=None, **kwargs):\n        super().__init__(**kwargs)\n        if not 0 <= rate <= 1:\n            raise ValueError(\n                f\"Invalid value received for argument \"\n                \"`rate`. Expected a float value between 0 and 1. 
\"\n                f\"Received: rate={rate}\"\n            )\n        self.rate = rate\n        self.seed = seed\n        if rate > 0:\n            self.seed_generator = backend.random.SeedGenerator(seed)\n        self.supports_masking = True\n\n        self._build_at_init()\n\n    def call(self, inputs, training=False):\n        if training and self.rate > 0:\n            stddev = math.sqrt(self.rate / (1.0 - self.rate))\n            return inputs * backend.random.normal(\n                shape=ops.shape(inputs),\n                mean=1.0,\n                stddev=stddev,\n                dtype=self.compute_dtype,\n                seed=self.seed_generator,\n            )\n        return inputs\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"rate\": self.rate,\n            \"seed\": self.seed,\n        }\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/regularization/gaussian_dropout_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass GaussianDropoutTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_gaussian_dropout_basics(self):\n        self.run_layer_test(\n            layers.GaussianDropout,\n            init_kwargs={\n                \"rate\": 0.2,\n            },\n            input_shape=(2, 3),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=1,\n            expected_num_losses=0,\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n\n    def test_gaussian_dropout_correctness(self):\n        inputs = np.ones((20, 500))\n        layer = layers.GaussianDropout(0.3, seed=1337)\n        outputs = layer(inputs, training=True)\n        self.assertAllClose(\n            np.std(backend.convert_to_numpy(outputs)),\n            np.sqrt(0.3 / (1 - 0.3)),\n            atol=0.02,\n        )\n"
  },
  {
    "path": "keras/src/layers/regularization/gaussian_noise.py",
    "content": "from keras.src import backend\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\n\n\n@keras_export(\"keras.layers.GaussianNoise\")\nclass GaussianNoise(layers.Layer):\n    \"\"\"Apply additive zero-centered Gaussian noise.\n\n    This is useful to mitigate overfitting\n    (you could see it as a form of random data augmentation).\n    Gaussian Noise (GS) is a natural choice as corruption process\n    for real valued inputs.\n\n    As it is a regularization layer, it is only active at training time.\n\n    Args:\n        stddev: Float, standard deviation of the noise distribution.\n        seed: Integer, optional random seed to enable deterministic behavior.\n\n    Call arguments:\n        inputs: Input tensor (of any rank).\n        training: Python boolean indicating whether the layer should behave in\n            training mode (adding noise) or in inference mode (doing nothing).\n    \"\"\"\n\n    def __init__(self, stddev, seed=None, **kwargs):\n        super().__init__(**kwargs)\n        if not 0 <= stddev <= 1:\n            raise ValueError(\n                f\"Invalid value received for argument \"\n                \"`stddev`. Expected a float value between 0 and 1. 
\"\n                f\"Received: stddev={stddev}\"\n            )\n        self.stddev = stddev\n        self.seed = seed\n        if stddev > 0:\n            self.seed_generator = backend.random.SeedGenerator(seed)\n        self.supports_masking = True\n\n        self._build_at_init()\n\n    def call(self, inputs, training=False):\n        if training and self.stddev > 0:\n            return inputs + backend.random.normal(\n                shape=ops.shape(inputs),\n                mean=0.0,\n                stddev=self.stddev,\n                dtype=self.compute_dtype,\n                seed=self.seed_generator,\n            )\n        return inputs\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"stddev\": self.stddev,\n            \"seed\": self.seed,\n        }\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/regularization/gaussian_noise_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass GaussianNoiseTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_gaussian_noise_basics(self):\n        self.run_layer_test(\n            layers.GaussianNoise,\n            init_kwargs={\n                \"stddev\": 0.2,\n            },\n            input_shape=(2, 3),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(2, 3),\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=1,\n            expected_num_losses=0,\n            supports_masking=True,\n            assert_built_after_instantiation=True,\n        )\n\n    def test_gaussian_noise_correctness(self):\n        inputs = np.ones((20, 500))\n        layer = layers.GaussianNoise(0.3, seed=1337)\n        outputs = layer(inputs, training=True)\n        self.assertAllClose(\n            np.std(backend.convert_to_numpy(outputs)), 0.3, atol=0.02\n        )\n"
  },
  {
    "path": "keras/src/layers/regularization/spatial_dropout.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.regularization.dropout import Dropout\n\n\nclass BaseSpatialDropout(Dropout):\n    def __init__(self, rate, seed=None, name=None, dtype=None):\n        super().__init__(rate, seed=seed, name=name, dtype=dtype)\n\n    def call(self, inputs, training=False):\n        if training and self.rate > 0:\n            return backend.random.dropout(\n                inputs,\n                self.rate,\n                noise_shape=self._get_noise_shape(inputs),\n                seed=self.seed_generator,\n            )\n        return inputs\n\n    def get_config(self):\n        return {\n            \"rate\": self.rate,\n            \"seed\": self.seed,\n            \"name\": self.name,\n            \"dtype\": self.dtype,\n        }\n\n\n@keras_export(\"keras.layers.SpatialDropout1D\")\nclass SpatialDropout1D(BaseSpatialDropout):\n    \"\"\"Spatial 1D version of Dropout.\n\n    This layer performs the same function as Dropout, however, it drops\n    entire 1D feature maps instead of individual elements. If adjacent frames\n    within feature maps are strongly correlated (as is normally the case in\n    early convolution layers) then regular dropout will not regularize the\n    activations and will otherwise just result in an effective learning rate\n    decrease. In this case, `SpatialDropout1D` will help promote independence\n    between feature maps and should be used instead.\n\n    Args:\n        rate: Float between 0 and 1. 
Fraction of the input units to drop.\n\n    Call arguments:\n        inputs: A 3D tensor.\n        training: Python boolean indicating whether the layer\n            should behave in training mode (applying dropout)\n            or in inference mode (pass-through).\n\n    Input shape:\n        3D tensor with shape: `(samples, timesteps, channels)`\n\n    Output shape: Same as input.\n\n    Reference:\n\n    - [Tompson et al., 2014](https://arxiv.org/abs/1411.4280)\n    \"\"\"\n\n    def __init__(self, rate, seed=None, name=None, dtype=None):\n        super().__init__(rate, seed=seed, name=name, dtype=dtype)\n        self.input_spec = InputSpec(ndim=3)\n\n    def _get_noise_shape(self, inputs):\n        input_shape = ops.shape(inputs)\n        return (input_shape[0], 1, input_shape[2])\n\n\n@keras_export(\"keras.layers.SpatialDropout2D\")\nclass SpatialDropout2D(BaseSpatialDropout):\n    \"\"\"Spatial 2D version of Dropout.\n\n    This version performs the same function as Dropout, however, it drops\n    entire 2D feature maps instead of individual elements. If adjacent pixels\n    within feature maps are strongly correlated (as is normally the case in\n    early convolution layers) then regular dropout will not regularize the\n    activations and will otherwise just result in an effective learning rate\n    decrease. In this case, `SpatialDropout2D` will help promote independence\n    between feature maps and should be used instead.\n\n    Args:\n        rate: Float between 0 and 1. 
Fraction of the input units to drop.\n        data_format: `\"channels_first\"` or `\"channels_last\"`.\n            In `\"channels_first\"` mode, the channels dimension (the depth)\n            is at index 1, in `\"channels_last\"` mode is it at index 3.\n            It defaults to the `image_data_format` value found in your\n            Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n\n    Call arguments:\n        inputs: A 4D tensor.\n        training: Python boolean indicating whether the layer\n            should behave in training mode (applying dropout)\n            or in inference mode (pass-through).\n\n    Input shape:\n        4D tensor with shape: `(samples, channels, rows, cols)` if\n            data_format='channels_first'\n        or 4D tensor with shape: `(samples, rows, cols, channels)` if\n            data_format='channels_last'.\n\n    Output shape: Same as input.\n\n    Reference:\n\n    - [Tompson et al., 2014](https://arxiv.org/abs/1411.4280)\n    \"\"\"\n\n    def __init__(\n        self, rate, data_format=None, seed=None, name=None, dtype=None\n    ):\n        super().__init__(rate, seed=seed, name=name, dtype=dtype)\n        self.data_format = backend.standardize_data_format(data_format)\n        self.input_spec = InputSpec(ndim=4)\n\n    def _get_noise_shape(self, inputs):\n        input_shape = ops.shape(inputs)\n        if self.data_format == \"channels_first\":\n            return (input_shape[0], input_shape[1], 1, 1)\n        elif self.data_format == \"channels_last\":\n            return (input_shape[0], 1, 1, input_shape[3])\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"data_format\": self.data_format,\n        }\n        return {**base_config, **config}\n\n\n@keras_export(\"keras.layers.SpatialDropout3D\")\nclass SpatialDropout3D(BaseSpatialDropout):\n    \"\"\"Spatial 3D version of Dropout.\n\n    This 
version performs the same function as Dropout, however, it drops\n    entire 3D feature maps instead of individual elements. If adjacent voxels\n    within feature maps are strongly correlated (as is normally the case in\n    early convolution layers) then regular dropout will not regularize the\n    activations and will otherwise just result in an effective learning rate\n    decrease. In this case, SpatialDropout3D will help promote independence\n    between feature maps and should be used instead.\n\n    Args:\n        rate: Float between 0 and 1. Fraction of the input units to drop.\n        data_format: `\"channels_first\"` or `\"channels_last\"`.\n            In `\"channels_first\"` mode, the channels dimension (the depth)\n            is at index 1, in `\"channels_last\"` mode is it at index 4.\n            It defaults to the `image_data_format` value found in your\n            Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n\n    Call arguments:\n        inputs: A 5D tensor.\n        training: Python boolean indicating whether the layer\n                should behave in training mode (applying dropout)\n                or in inference mode (pass-through).\n\n    Input shape:\n        5D tensor with shape: `(samples, channels, dim1, dim2, dim3)` if\n            data_format='channels_first'\n        or 5D tensor with shape: `(samples, dim1, dim2, dim3, channels)` if\n            data_format='channels_last'.\n\n    Output shape: Same as input.\n\n    Reference:\n\n    - [Tompson et al., 2014](https://arxiv.org/abs/1411.4280)\n    \"\"\"\n\n    def __init__(\n        self, rate, data_format=None, seed=None, name=None, dtype=None\n    ):\n        super().__init__(rate, seed=seed, name=name, dtype=dtype)\n        self.data_format = backend.standardize_data_format(data_format)\n        self.input_spec = InputSpec(ndim=5)\n\n    def _get_noise_shape(self, inputs):\n        input_shape = 
ops.shape(inputs)\n        if self.data_format == \"channels_first\":\n            return (input_shape[0], input_shape[1], 1, 1, 1)\n        elif self.data_format == \"channels_last\":\n            return (input_shape[0], 1, 1, 1, input_shape[4])\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\n            \"data_format\": self.data_format,\n        }\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/regularization/spatial_dropout_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src.testing import test_case\n\n\nclass SpatialDropoutTest(test_case.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_spatial_dropout_1d(self):\n        self.run_layer_test(\n            layers.SpatialDropout1D,\n            init_kwargs={\"rate\": 0.5},\n            call_kwargs={\"training\": True},\n            input_shape=(2, 3, 4),\n            assert_built_after_instantiation=True,\n        )\n\n        self.run_layer_test(\n            layers.SpatialDropout1D,\n            init_kwargs={\"rate\": 0.5},\n            call_kwargs={\"training\": False},\n            input_shape=(2, 3, 4),\n            assert_built_after_instantiation=True,\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_spatial_dropout_2d(self):\n        self.run_layer_test(\n            layers.SpatialDropout2D,\n            init_kwargs={\"rate\": 0.5},\n            call_kwargs={\"training\": True},\n            input_shape=(2, 3, 4, 5),\n            assert_built_after_instantiation=True,\n        )\n\n        self.run_layer_test(\n            layers.SpatialDropout2D,\n            init_kwargs={\"rate\": 0.5, \"data_format\": \"channels_first\"},\n            call_kwargs={\"training\": True},\n            input_shape=(2, 3, 4, 5),\n            assert_built_after_instantiation=True,\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_spatial_dropout_3d(self):\n        self.run_layer_test(\n            layers.SpatialDropout3D,\n            init_kwargs={\"rate\": 0.5},\n            call_kwargs={\"training\": True},\n            input_shape=(2, 3, 4, 4, 5),\n            assert_built_after_instantiation=True,\n        )\n\n        self.run_layer_test(\n            layers.SpatialDropout3D,\n            init_kwargs={\"rate\": 0.5, \"data_format\": \"channels_first\"},\n            call_kwargs={\"training\": True},\n     
       input_shape=(2, 3, 4, 4, 5),\n            assert_built_after_instantiation=True,\n        )\n\n    def test_spatial_dropout_1D_dynamic(self):\n        inputs = layers.Input((3, 2))\n        layer = layers.SpatialDropout1D(0.5)\n        layer(inputs, training=True)\n\n    def test_spatial_dropout_1D_correctness(self):\n        inputs = np.ones((10, 3, 10))\n        layer = layers.SpatialDropout1D(0.5)\n        outputs = layer(inputs, training=True)\n        self.assertAllClose(outputs[:, 0, :], outputs[:, 1, :])\n\n    def test_spatial_dropout_2D_dynamic(self):\n        inputs = layers.Input((3, 2, 4))\n        layer = layers.SpatialDropout2D(0.5)\n        layer(inputs, training=True)\n\n    def test_spatial_dropout_2D_correctness(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            inputs = np.ones((10, 3, 3, 10))\n        else:\n            inputs = np.ones((10, 10, 3, 3))\n        layer = layers.SpatialDropout2D(0.5)\n        outputs = layer(inputs, training=True)\n        if backend.config.image_data_format() == \"channels_last\":\n            self.assertAllClose(outputs[:, 0, 0, :], outputs[:, 1, 1, :])\n        else:\n            self.assertAllClose(outputs[:, :, 0, 0], outputs[:, :, 1, 1])\n\n    def test_spatial_dropout_3D_dynamic(self):\n        inputs = layers.Input((3, 2, 4, 2))\n        layer = layers.SpatialDropout3D(0.5)\n        layer(inputs, training=True)\n\n    def test_spatial_dropout_3D_correctness(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            inputs = np.ones((10, 3, 3, 3, 10))\n        else:\n            inputs = np.ones((10, 10, 3, 3, 3))\n        layer = layers.SpatialDropout3D(0.5)\n        outputs = layer(inputs, training=True)\n        if backend.config.image_data_format() == \"channels_last\":\n            self.assertAllClose(outputs[:, 0, 0, 0, :], outputs[:, 1, 1, 1, :])\n        else:\n            self.assertAllClose(outputs[:, :, 0, 0, 0], outputs[:, 
:, 1, 1, 1])\n"
  },
  {
    "path": "keras/src/layers/reshaping/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/layers/reshaping/cropping1d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.utils import argument_validation\n\n\n@keras_export(\"keras.layers.Cropping1D\")\nclass Cropping1D(Layer):\n    \"\"\"Cropping layer for 1D input (e.g. temporal sequence).\n\n    It crops along the time dimension (axis 1).\n\n    Example:\n\n    >>> input_shape = (2, 3, 2)\n    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)\n    >>> x\n    [[[ 0  1]\n      [ 2  3]\n      [ 4  5]]\n     [[ 6  7]\n      [ 8  9]\n      [10 11]]]\n    >>> y = keras.layers.Cropping1D(cropping=1)(x)\n    >>> y\n    [[[2 3]]\n     [[8 9]]]\n\n    Args:\n        cropping: Int, or tuple of int (length 2).\n            - If int: how many units should be trimmed off at the beginning and\n              end of the cropping dimension (axis 1).\n            - If tuple of 2 ints: how many units should be trimmed off at the\n              beginning and end of the cropping dimension\n              (`(left_crop, right_crop)`).\n\n    Input shape:\n        3D tensor with shape `(batch_size, axis_to_crop, features)`\n\n    Output shape:\n        3D tensor with shape `(batch_size, cropped_axis, features)`\n    \"\"\"\n\n    def __init__(self, cropping=(1, 1), **kwargs):\n        super().__init__(**kwargs)\n        self.cropping = argument_validation.standardize_tuple(\n            cropping, 2, \"cropping\", allow_zero=True\n        )\n        self.input_spec = InputSpec(ndim=3)\n\n    def compute_output_shape(self, input_shape):\n        if input_shape[1] is not None:\n            length = input_shape[1] - self.cropping[0] - self.cropping[1]\n            if length <= 0:\n                raise ValueError(\n                    \"`cropping` parameter of `Cropping1D` layer must be \"\n                    \"smaller than the input length. Received: input_shape=\"\n                    f\"{input_shape}, cropping={self.cropping}\"\n                )\n        else:\n            length = None\n        return (input_shape[0], length, input_shape[2])\n\n    def call(self, inputs):\n        if (\n            inputs.shape[1] is not None\n            and sum(self.cropping) >= inputs.shape[1]\n        ):\n            raise ValueError(\n                \"`cropping` parameter of `Cropping1D` layer must be \"\n                \"smaller than the input length. Received: inputs.shape=\"\n                f\"{inputs.shape}, cropping={self.cropping}\"\n            )\n        if self.cropping[1] == 0:\n            return inputs[:, self.cropping[0] :, :]\n        else:\n            return inputs[:, self.cropping[0] : -self.cropping[1], :]\n\n    def get_config(self):\n        config = {\"cropping\": self.cropping}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/reshaping/cropping1d_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass Cropping1DTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_cropping_1d(self):\n        inputs = np.random.rand(3, 5, 7)\n\n        # Cropping with different values on the left and the right.\n        self.run_layer_test(\n            layers.Cropping1D,\n            init_kwargs={\"cropping\": (1, 2)},\n            input_data=inputs,\n            expected_output=ops.convert_to_tensor(inputs[:, 1:3, :]),\n        )\n        # Same cropping on the left and the right.\n        self.run_layer_test(\n            layers.Cropping1D,\n            init_kwargs={\"cropping\": (1, 1)},\n            input_data=inputs,\n            expected_output=ops.convert_to_tensor(inputs[:, 1:4, :]),\n        )\n        # Same cropping on the left and the right provided as an int.\n        self.run_layer_test(\n            layers.Cropping1D,\n            init_kwargs={\"cropping\": 1},\n            input_data=inputs,\n            expected_output=ops.convert_to_tensor(inputs[:, 1:4, :]),\n        )\n        # Cropping on the right only.\n        self.run_layer_test(\n            layers.Cropping1D,\n            init_kwargs={\"cropping\": (0, 1)},\n            input_data=inputs,\n            expected_output=ops.convert_to_tensor(inputs[:, 0:4, :]),\n        )\n        # Cropping on the left only.\n        self.run_layer_test(\n            layers.Cropping1D,\n            init_kwargs={\"cropping\": (1, 0)},\n            input_data=inputs,\n            expected_output=ops.convert_to_tensor(inputs[:, 1:5, :]),\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_cropping_1d_with_dynamic_spatial_dim(self):\n        input_layer = layers.Input(batch_shape=(1, None, 7))\n        cropped = layers.Cropping1D((1, 2))(input_layer)\n        self.assertEqual(cropped.shape, (1, None, 7))\n\n    def test_cropping_1d_errors_if_cropping_argument_invalid(self):\n        with self.assertRaises(ValueError):\n            layers.Cropping1D(cropping=(1,))\n        with self.assertRaises(ValueError):\n            layers.Cropping1D(cropping=(1, 2, 3))\n        with self.assertRaises(ValueError):\n            layers.Cropping1D(cropping=\"1\")\n\n    def test_cropping_1d_errors_if_cropping_more_than_available(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`cropping` parameter of `Cropping1D` layer must be smaller than\",\n        ):\n            input_layer = layers.Input(batch_shape=(3, 5, 7))\n            layers.Cropping1D(cropping=(2, 3))(input_layer)\n\n    def test_cropping_1d_error_on_excessive_cropping(self):\n        inputs = np.random.rand(3, 5, 7)\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`cropping` parameter of `Cropping1D` layer must be smaller than\",\n        ):\n            layer = layers.Cropping1D(cropping=(3, 3))\n            _ = layer(inputs)\n"
  },
  {
    "path": "keras/src/layers/reshaping/cropping2d.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.utils import argument_validation\n\n\n@keras_export(\"keras.layers.Cropping2D\")\nclass Cropping2D(Layer):\n    \"\"\"Cropping layer for 2D input (e.g. picture).\n\n    It crops along spatial dimensions, i.e. height and width.\n\n    Example:\n\n    >>> input_shape = (2, 28, 28, 3)\n    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)\n    >>> y = keras.layers.Cropping2D(cropping=((2, 2), (4, 4)))(x)\n    >>> y.shape\n    (2, 24, 20, 3)\n\n    Args:\n        cropping: Int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.\n            - If int: the same symmetric cropping is applied to height and\n              width.\n            - If tuple of 2 ints: interpreted as two different symmetric\n              cropping values for height and width:\n              `(symmetric_height_crop, symmetric_width_crop)`.\n            - If tuple of 2 tuples of 2 ints: interpreted as\n              `((top_crop, bottom_crop), (left_crop, right_crop))`.\n        data_format: A string, one of `\"channels_last\"` (default) or\n            `\"channels_first\"`. The ordering of the dimensions in the inputs.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch_size, height, width, channels)` while `\"channels_first\"`\n            corresponds to inputs with shape\n            `(batch_size, channels, height, width)`.\n            When unspecified, uses `image_data_format` value found in your Keras\n            config file at `~/.keras/keras.json` (if exists). Defaults to\n            `\"channels_last\"`.\n\n    Input shape:\n        4D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n          `(batch_size, height, width, channels)`\n        - If `data_format` is `\"channels_first\"`:\n          `(batch_size, channels, height, width)`\n\n    Output shape:\n        4D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n          `(batch_size, cropped_height, cropped_width, channels)`\n        - If `data_format` is `\"channels_first\"`:\n          `(batch_size, channels, cropped_height, cropped_width)`\n    \"\"\"\n\n    def __init__(self, cropping=((0, 0), (0, 0)), data_format=None, **kwargs):\n        super().__init__(**kwargs)\n        self.data_format = backend.standardize_data_format(data_format)\n        if isinstance(cropping, int):\n            if cropping < 0:\n                raise ValueError(\n                    \"`cropping` cannot be negative. \"\n                    f\"Received: cropping={cropping}.\"\n                )\n            self.cropping = ((cropping, cropping), (cropping, cropping))\n        elif hasattr(cropping, \"__len__\"):\n            if len(cropping) != 2:\n                raise ValueError(\n                    \"`cropping` should have two elements. \"\n                    f\"Received: cropping={cropping}.\"\n                )\n            height_cropping = argument_validation.standardize_tuple(\n                cropping[0], 2, \"1st entry of cropping\", allow_zero=True\n            )\n            width_cropping = argument_validation.standardize_tuple(\n                cropping[1], 2, \"2nd entry of cropping\", allow_zero=True\n            )\n            self.cropping = (height_cropping, width_cropping)\n        else:\n            raise ValueError(\n                \"`cropping` should be either an int, a tuple of 2 ints \"\n                \"(symmetric_height_crop, symmetric_width_crop), \"\n                \"or a tuple of 2 tuples of 2 ints \"\n                \"((top_crop, bottom_crop), (left_crop, right_crop)). \"\n                f\"Received: cropping={cropping}.\"\n            )\n        self.input_spec = InputSpec(ndim=4)\n\n    def compute_output_shape(self, input_shape):\n        if self.data_format == \"channels_first\":\n            if (\n                input_shape[2] is not None\n                and sum(self.cropping[0]) >= input_shape[2]\n            ) or (\n                input_shape[3] is not None\n                and sum(self.cropping[1]) >= input_shape[3]\n            ):\n                raise ValueError(\n                    \"Values in `cropping` argument should be smaller than the \"\n                    \"corresponding spatial dimension of the input. Received: \"\n                    f\"input_shape={input_shape}, cropping={self.cropping}\"\n                )\n            return (\n                input_shape[0],\n                input_shape[1],\n                (\n                    input_shape[2] - self.cropping[0][0] - self.cropping[0][1]\n                    if input_shape[2] is not None\n                    else None\n                ),\n                (\n                    input_shape[3] - self.cropping[1][0] - self.cropping[1][1]\n                    if input_shape[3] is not None\n                    else None\n                ),\n            )\n        else:\n            if (\n                input_shape[1] is not None\n                and sum(self.cropping[0]) >= input_shape[1]\n            ) or (\n                input_shape[2] is not None\n                and sum(self.cropping[1]) >= input_shape[2]\n            ):\n                raise ValueError(\n                    \"Values in `cropping` argument should be smaller than the \"\n                    \"corresponding spatial dimension of the input. Received: \"\n                    f\"input_shape={input_shape}, cropping={self.cropping}\"\n                )\n            return (\n                input_shape[0],\n                (\n                    input_shape[1] - self.cropping[0][0] - self.cropping[0][1]\n                    if input_shape[1] is not None\n                    else None\n                ),\n                (\n                    input_shape[2] - self.cropping[1][0] - self.cropping[1][1]\n                    if input_shape[2] is not None\n                    else None\n                ),\n                input_shape[3],\n            )\n\n    def call(self, inputs):\n        if self.data_format == \"channels_first\":\n            if (\n                inputs.shape[2] is not None\n                and sum(self.cropping[0]) >= inputs.shape[2]\n            ) or (\n                inputs.shape[3] is not None\n                and sum(self.cropping[1]) >= inputs.shape[3]\n            ):\n                raise ValueError(\n                    \"Values in `cropping` argument should be smaller than the \"\n                    \"corresponding spatial dimension of the input. Received: \"\n                    f\"inputs.shape={inputs.shape}, cropping={self.cropping}\"\n                )\n            if self.cropping[0][1] == self.cropping[1][1] == 0:\n                return inputs[\n                    :, :, self.cropping[0][0] :, self.cropping[1][0] :\n                ]\n            elif self.cropping[0][1] == 0:\n                return inputs[\n                    :,\n                    :,\n                    self.cropping[0][0] :,\n                    self.cropping[1][0] : -self.cropping[1][1],\n                ]\n            elif self.cropping[1][1] == 0:\n                return inputs[\n                    :,\n                    :,\n                    self.cropping[0][0] : -self.cropping[0][1],\n                    self.cropping[1][0] :,\n                ]\n            return inputs[\n                :,\n                :,\n                self.cropping[0][0] : -self.cropping[0][1],\n                self.cropping[1][0] : -self.cropping[1][1],\n            ]\n        else:\n            if (\n                inputs.shape[1] is not None\n                and sum(self.cropping[0]) >= inputs.shape[1]\n            ) or (\n                inputs.shape[2] is not None\n                and sum(self.cropping[1]) >= inputs.shape[2]\n            ):\n                raise ValueError(\n                    \"Values in `cropping` argument should be smaller than the \"\n                    \"corresponding spatial dimension of the input. Received: \"\n                    f\"inputs.shape={inputs.shape}, cropping={self.cropping}\"\n                )\n            if self.cropping[0][1] == self.cropping[1][1] == 0:\n                return inputs[\n                    :, self.cropping[0][0] :, self.cropping[1][0] :, :\n                ]\n            elif self.cropping[0][1] == 0:\n                return inputs[\n                    :,\n                    self.cropping[0][0] :,\n                    self.cropping[1][0] : -self.cropping[1][1],\n                    :,\n                ]\n            elif self.cropping[1][1] == 0:\n                return inputs[\n                    :,\n                    self.cropping[0][0] : -self.cropping[0][1],\n                    self.cropping[1][0] :,\n                    :,\n                ]\n            return inputs[\n                :,\n                self.cropping[0][0] : -self.cropping[0][1],\n                self.cropping[1][0] : -self.cropping[1][1],\n                :,\n            ]\n\n    def get_config(self):\n        config = {\"cropping\": self.cropping, \"data_format\": self.data_format}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/reshaping/cropping2d_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass Cropping2DTest(testing.TestCase):\n    @parameterized.product(\n        (\n            # different cropping values\n            {\"cropping\": ((1, 2), (3, 4)), \"expected_ranges\": ((1, 5), (3, 5))},\n            # same cropping values with 2 tuples\n            {\"cropping\": ((2, 2), (2, 2)), \"expected_ranges\": ((2, 5), (2, 7))},\n            # same cropping values with 1 tuple\n            {\"cropping\": (2, 2), \"expected_ranges\": ((2, 5), (2, 7))},\n            # same cropping values with an integer\n            {\"cropping\": 2, \"expected_ranges\": ((2, 5), (2, 7))},\n            # cropping right only in both dimensions\n            {\"cropping\": ((0, 2), (0, 4)), \"expected_ranges\": ((0, 5), (0, 5))},\n            # cropping left only in both dimensions\n            {\"cropping\": ((1, 0), (3, 0)), \"expected_ranges\": ((1, 7), (3, 9))},\n            # cropping left only in rows dimension\n            {\"cropping\": ((1, 0), (3, 4)), \"expected_ranges\": ((1, 7), (3, 5))},\n            # cropping left only in cols dimension\n            {\"cropping\": ((1, 2), (3, 0)), \"expected_ranges\": ((1, 5), (3, 9))},\n        ),\n        (\n            {\"data_format\": \"channels_first\"},\n            {\"data_format\": \"channels_last\"},\n        ),\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_cropping_2d(self, cropping, data_format, expected_ranges):\n        if data_format == \"channels_first\":\n            inputs = np.random.rand(3, 5, 7, 9)\n            expected_output = ops.convert_to_tensor(\n                inputs[\n                    :,\n                    :,\n                    expected_ranges[0][0] : expected_ranges[0][1],\n                    expected_ranges[1][0] : expected_ranges[1][1],\n                ]\n            )\n        else:\n            inputs = np.random.rand(3, 7, 9, 5)\n            expected_output = ops.convert_to_tensor(\n                inputs[\n                    :,\n                    expected_ranges[0][0] : expected_ranges[0][1],\n                    expected_ranges[1][0] : expected_ranges[1][1],\n                    :,\n                ]\n            )\n\n        self.run_layer_test(\n            layers.Cropping2D,\n            init_kwargs={\"cropping\": cropping, \"data_format\": data_format},\n            input_data=inputs,\n            expected_output=expected_output,\n        )\n\n    def test_cropping_2d_with_dynamic_spatial_dim(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_layer = layers.Input(batch_shape=(1, 7, None, 5))\n        else:\n            input_layer = layers.Input(batch_shape=(1, 5, 7, None))\n        cropped = layers.Cropping2D(((1, 2), (3, 4)))(input_layer)\n        if backend.config.image_data_format() == \"channels_last\":\n            self.assertEqual(cropped.shape, (1, 4, None, 5))\n        else:\n            self.assertEqual(cropped.shape, (1, 5, 4, None))\n\n    @parameterized.product(\n        (\n            {\"cropping\": ((3, 6), (0, 0))},\n            {\"cropping\": ((0, 0), (5, 4))},\n        ),\n        (\n            {\"data_format\": \"channels_first\"},\n            {\"data_format\": \"channels_last\"},\n        ),\n    )\n    def test_cropping_2d_errors_if_cropping_more_than_available(\n        self, cropping, data_format\n    ):\n        input_layer = layers.Input(batch_shape=(3, 7, 9, 5))\n        with self.assertRaises(ValueError):\n            layers.Cropping2D(cropping=cropping, data_format=data_format)(\n                input_layer\n            )\n\n    def test_cropping_2d_errors_if_cropping_argument_invalid(self):\n        with self.assertRaises(ValueError):\n            layers.Cropping2D(cropping=(1,))\n        with self.assertRaises(ValueError):\n            layers.Cropping2D(cropping=(1, 2, 3))\n        with self.assertRaises(ValueError):\n            layers.Cropping2D(cropping=\"1\")\n        with self.assertRaises(ValueError):\n            layers.Cropping2D(cropping=((1, 2), (3, 4, 5)))\n        with self.assertRaises(ValueError):\n            layers.Cropping2D(cropping=((1, 2), (3, -4)))\n        with self.assertRaises(ValueError):\n            layers.Cropping2D(cropping=((1, 2), \"3\"))\n\n    @parameterized.product(\n        (\n            {\"cropping\": ((4, 5), (0, 0)), \"input_shape\": (3, 8, 9, 5)},\n            {\"cropping\": ((0, 0), (5, 5)), \"input_shape\": (3, 8, 9, 5)},\n            {\"cropping\": ((6, 3), (0, 0)), \"input_shape\": (3, 8, 9, 5)},\n            {\"cropping\": ((0, 0), (7, 3)), \"input_shape\": (3, 8, 9, 5)},\n        ),\n        (\n            {\"data_format\": \"channels_first\"},\n            {\"data_format\": \"channels_last\"},\n        ),\n    )\n    def test_cropping_2d_error_on_excessive_cropping(\n        self, cropping, input_shape, data_format\n    ):\n        inputs = np.random.rand(*input_shape)\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Values in `cropping` argument should be smaller than the \"\n            \"corresponding spatial dimension of the input.\",\n        ):\n            layer = layers.Cropping2D(\n                cropping=cropping, data_format=data_format\n            )\n            _ = layer(inputs)\n"
  },
  {
    "path": "keras/src/layers/reshaping/cropping3d.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.utils import argument_validation\n\n\n@keras_export(\"keras.layers.Cropping3D\")\nclass Cropping3D(Layer):\n    \"\"\"Cropping layer for 3D data (e.g. spatial or spatio-temporal).\n\n    Example:\n\n    >>> input_shape = (2, 28, 28, 10, 3)\n    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)\n    >>> y = keras.layers.Cropping3D(cropping=(2, 4, 2))(x)\n    >>> y.shape\n    (2, 24, 20, 6, 3)\n\n    Args:\n        cropping: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.\n            - If int: the same symmetric cropping is applied to depth, height,\n              and width.\n            - If tuple of 3 ints: interpreted as three different symmetric\n              cropping values for depth, height, and width:\n              `(symmetric_dim1_crop, symmetric_dim2_crop, symmetric_dim3_crop)`.\n            - If tuple of 3 tuples of 2 ints: interpreted as\n              `((left_dim1_crop, right_dim1_crop), (left_dim2_crop,\n              right_dim2_crop), (left_dim3_crop, right_dim3_crop))`.\n        data_format: A string, one of `\"channels_last\"` (default) or\n            `\"channels_first\"`. The ordering of the dimensions in the inputs.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.\n            When unspecified, uses `image_data_format` value found in your Keras\n            config file at `~/.keras/keras.json` (if exists). Defaults to\n            `\"channels_last\"`.\n\n    Input shape:\n        5D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n          `(batch_size, first_axis_to_crop, second_axis_to_crop,\n          third_axis_to_crop, channels)`\n        - If `data_format` is `\"channels_first\"`:\n          `(batch_size, channels, first_axis_to_crop, second_axis_to_crop,\n          third_axis_to_crop)`\n\n    Output shape:\n        5D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n          `(batch_size, first_cropped_axis, second_cropped_axis,\n          third_cropped_axis, channels)`\n        - If `data_format` is `\"channels_first\"`:\n          `(batch_size, channels, first_cropped_axis, second_cropped_axis,\n          third_cropped_axis)`\n    \"\"\"\n\n    def __init__(\n        self, cropping=((1, 1), (1, 1), (1, 1)), data_format=None, **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.data_format = backend.standardize_data_format(data_format)\n        if isinstance(cropping, int):\n            if cropping < 0:\n                raise ValueError(\n                    \"`cropping` cannot be negative. \"\n                    f\"Received: cropping={cropping}.\"\n                )\n            self.cropping = (\n                (cropping, cropping),\n                (cropping, cropping),\n                (cropping, cropping),\n            )\n        elif hasattr(cropping, \"__len__\"):\n            if len(cropping) != 3:\n                raise ValueError(\n                    f\"`cropping` should have 3 elements. Received: {cropping}.\"\n                )\n            dim1_cropping = argument_validation.standardize_tuple(\n                cropping[0], 2, \"1st entry of cropping\", allow_zero=True\n            )\n            dim2_cropping = argument_validation.standardize_tuple(\n                cropping[1], 2, \"2nd entry of cropping\", allow_zero=True\n            )\n            dim3_cropping = argument_validation.standardize_tuple(\n                cropping[2], 2, \"3rd entry of cropping\", allow_zero=True\n            )\n            self.cropping = (dim1_cropping, dim2_cropping, dim3_cropping)\n        else:\n            raise ValueError(\n                \"`cropping` should be either an int, a tuple of 3 ints \"\n                \"(symmetric_dim1_crop, symmetric_dim2_crop, \"\n                \"symmetric_dim3_crop), \"\n                \"or a tuple of 3 tuples of 2 ints \"\n                \"((left_dim1_crop, right_dim1_crop),\"\n                \" (left_dim2_crop, right_dim2_crop),\"\n                \" (left_dim3_crop, right_dim3_crop)). \"\n                f\"Received: {cropping}.\"\n            )\n        self.input_spec = InputSpec(ndim=5)\n\n    def compute_output_shape(self, input_shape):\n        if self.data_format == \"channels_first\":\n            spatial_dims = list(input_shape[2:5])\n        else:\n            spatial_dims = list(input_shape[1:4])\n\n        for index in range(0, 3):\n            if spatial_dims[index] is None:\n                continue\n            spatial_dims[index] -= sum(self.cropping[index])\n            if spatial_dims[index] <= 0:\n                raise ValueError(\n                    \"Values in `cropping` argument should be smaller than the \"\n                    \"corresponding spatial dimension of the input. Received: \"\n                    f\"input_shape={input_shape}, cropping={self.cropping}\"\n                )\n\n        if self.data_format == \"channels_first\":\n            return (input_shape[0], input_shape[1], *spatial_dims)\n        else:\n            return (input_shape[0], *spatial_dims, input_shape[4])\n\n    def call(self, inputs):\n        if self.data_format == \"channels_first\":\n            spatial_dims = list(inputs.shape[2:5])\n        else:\n            spatial_dims = list(inputs.shape[1:4])\n\n        for index in range(0, 3):\n            if spatial_dims[index] is None:\n                continue\n            spatial_dims[index] -= sum(self.cropping[index])\n            if spatial_dims[index] <= 0:\n                raise ValueError(\n                    \"Values in `cropping` argument should be smaller than the \"\n                    \"corresponding spatial dimension of the input. Received: \"\n                    f\"inputs.shape={inputs.shape}, cropping={self.cropping}\"\n                )\n\n        if self.data_format == \"channels_first\":\n            if (\n                self.cropping[0][1]\n                == self.cropping[1][1]\n                == self.cropping[2][1]\n                == 0\n            ):\n                return inputs[\n                    :,\n                    :,\n                    self.cropping[0][0] :,\n                    self.cropping[1][0] :,\n                    self.cropping[2][0] :,\n                ]\n            elif self.cropping[0][1] == self.cropping[1][1] == 0:\n                return inputs[\n                    :,\n                    :,\n                    self.cropping[0][0] :,\n                    self.cropping[1][0] :,\n                    self.cropping[2][0] : -self.cropping[2][1],\n                ]\n            elif self.cropping[1][1] == self.cropping[2][1] == 0:\n                return inputs[\n                    :,\n                    :,\n                    self.cropping[0][0] : -self.cropping[0][1],\n                    self.cropping[1][0] :,\n                    self.cropping[2][0] :,\n                ]\n            elif self.cropping[0][1] == self.cropping[2][1] == 0:\n                return inputs[\n                    :,\n                    :,\n                    self.cropping[0][0] :,\n                    self.cropping[1][0] : -self.cropping[1][1],\n                    self.cropping[2][0] :,\n                ]\n            elif self.cropping[0][1] == 0:\n                return inputs[\n                    :,\n                    :,\n                    self.cropping[0][0] :,\n                    self.cropping[1][0] : -self.cropping[1][1],\n                    self.cropping[2][0] : -self.cropping[2][1],\n                ]\n            elif self.cropping[1][1] == 0:\n                return inputs[\n                    :,\n                    :,\n                    self.cropping[0][0] : -self.cropping[0][1],\n                    self.cropping[1][0] :,\n                    self.cropping[2][0] : -self.cropping[2][1],\n                ]\n            elif self.cropping[2][1] == 0:\n                return inputs[\n                    :,\n                    :,\n                    self.cropping[0][0] : -self.cropping[0][1],\n                    self.cropping[1][0] : -self.cropping[1][1],\n                    self.cropping[2][0] :,\n                ]\n            return inputs[\n                :,\n                :,\n                self.cropping[0][0] : -self.cropping[0][1],\n                self.cropping[1][0] : -self.cropping[1][1],\n                self.cropping[2][0] : -self.cropping[2][1],\n            ]\n        else:\n            if (\n                self.cropping[0][1]\n                == self.cropping[1][1]\n                == self.cropping[2][1]\n                == 0\n            ):\n                return inputs[\n                    :,\n                    self.cropping[0][0] :,\n                    self.cropping[1][0] :,\n                    self.cropping[2][0] :,\n                    :,\n                ]\n            elif self.cropping[0][1] == self.cropping[1][1] == 0:\n                return inputs[\n                    :,\n                    self.cropping[0][0] :,\n                    self.cropping[1][0] :,\n                    self.cropping[2][0] : -self.cropping[2][1],\n                    :,\n                ]\n            elif self.cropping[1][1] == self.cropping[2][1] == 0:\n                return inputs[\n                    :,\n                    self.cropping[0][0] : -self.cropping[0][1],\n                    self.cropping[1][0] :,\n                    self.cropping[2][0] :,\n                    :,\n                ]\n            elif self.cropping[0][1] == self.cropping[2][1] == 0:\n                return inputs[\n                    :,\n                    self.cropping[0][0] :,\n                    self.cropping[1][0] : -self.cropping[1][1],\n                    self.cropping[2][0] :,\n                    :,\n                ]\n            elif self.cropping[0][1] == 0:\n                return inputs[\n                    :,\n                    self.cropping[0][0] :,\n                    self.cropping[1][0] : -self.cropping[1][1],\n                    self.cropping[2][0] : -self.cropping[2][1],\n                    :,\n                ]\n            elif self.cropping[1][1] == 0:\n                return inputs[\n                    :,\n                    self.cropping[0][0] : -self.cropping[0][1],\n                    self.cropping[1][0] :,\n                    self.cropping[2][0] : -self.cropping[2][1],\n                    :,\n                ]\n            elif self.cropping[2][1] == 0:\n                return inputs[\n                    :,\n                    self.cropping[0][0] : -self.cropping[0][1],\n                    self.cropping[1][0] : -self.cropping[1][1],\n                    self.cropping[2][0] :,\n                    :,\n                ]\n            return inputs[\n                :,\n                self.cropping[0][0] : -self.cropping[0][1],\n                self.cropping[1][0] : -self.cropping[1][1],\n                self.cropping[2][0] : -self.cropping[2][1],\n                :,\n            ]\n\n    def get_config(self):\n        config = {\"cropping\": self.cropping, \"data_format\": self.data_format}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/reshaping/cropping3d_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass Cropping3DTest(testing.TestCase):\n    @parameterized.product(\n        (\n            {\"dim1_cropping\": (1, 2), \"dim1_expected\": (1, 5)},  # both\n            {\"dim1_cropping\": (0, 2), \"dim1_expected\": (0, 5)},  # left only\n            {\"dim1_cropping\": (1, 0), \"dim1_expected\": (1, 7)},  # right only\n        ),\n        (\n            {\"dim2_cropping\": (3, 4), \"dim2_expected\": (3, 5)},  # both\n            {\"dim2_cropping\": (0, 4), \"dim2_expected\": (0, 5)},  # left only\n            {\"dim2_cropping\": (3, 0), \"dim2_expected\": (3, 9)},  # right only\n        ),\n        (\n            {\"dim3_cropping\": (5, 6), \"dim3_expected\": (5, 7)},  # both\n            {\"dim3_cropping\": (0, 6), \"dim3_expected\": (0, 7)},  # left only\n            {\"dim3_cropping\": (5, 0), \"dim3_expected\": (5, 13)},  # right only\n        ),\n        (\n            {\"data_format\": \"channels_first\"},\n            {\"data_format\": \"channels_last\"},\n        ),\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_cropping_3d(\n        self,\n        dim1_cropping,\n        dim2_cropping,\n        dim3_cropping,\n        data_format,\n        dim1_expected,\n        dim2_expected,\n        dim3_expected,\n    ):\n        if data_format == \"channels_first\":\n            inputs = np.random.rand(3, 5, 7, 9, 13)\n            expected_output = ops.convert_to_tensor(\n                inputs[\n                    :,\n                    :,\n                    dim1_expected[0] : dim1_expected[1],\n                    dim2_expected[0] : dim2_expected[1],\n                    dim3_expected[0] : dim3_expected[1],\n                ]\n            )\n        else:\n            inputs = np.random.rand(3, 7, 9, 13, 5)\n            expected_output = ops.convert_to_tensor(\n                inputs[\n                    :,\n                    dim1_expected[0] : dim1_expected[1],\n                    dim2_expected[0] : dim2_expected[1],\n                    dim3_expected[0] : dim3_expected[1],\n                    :,\n                ]\n            )\n\n        cropping = (dim1_cropping, dim2_cropping, dim3_cropping)\n        self.run_layer_test(\n            layers.Cropping3D,\n            init_kwargs={\"cropping\": cropping, \"data_format\": data_format},\n            input_data=inputs,\n            expected_output=expected_output,\n        )\n\n    @parameterized.product(\n        (\n            # same cropping values with 3 tuples\n            {\n                \"cropping\": ((2, 2), (2, 2), (2, 2)),\n                \"expected\": ((2, 5), (2, 7), (2, 11)),\n            },\n            # same cropping values with 1 tuple\n            {\"cropping\": (2, 2, 2), \"expected\": ((2, 5), (2, 7), (2, 11))},\n            # same cropping values with an integer\n            {\"cropping\": 2, \"expected\": ((2, 5), (2, 7), (2, 11))},\n        ),\n        (\n            {\"data_format\": \"channels_first\"},\n            {\"data_format\": \"channels_last\"},\n        ),\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_cropping_3d_with_same_cropping(\n        self, cropping, data_format, expected\n    ):\n        if data_format == \"channels_first\":\n            inputs = np.random.rand(3, 5, 7, 9, 13)\n            expected_output = ops.convert_to_tensor(\n                inputs[\n                    :,\n                    :,\n                    expected[0][0] : expected[0][1],\n                    expected[1][0] : expected[1][1],\n                    expected[2][0] : expected[2][1],\n                ]\n            )\n        else:\n            inputs = np.random.rand(3, 7, 9, 13, 5)\n            expected_output = ops.convert_to_tensor(\n                inputs[\n                    :,\n                    expected[0][0] : expected[0][1],\n                    expected[1][0] : expected[1][1],\n                    expected[2][0] : expected[2][1],\n                    :,\n                ]\n            )\n\n        self.run_layer_test(\n            layers.Cropping3D,\n            init_kwargs={\"cropping\": cropping, \"data_format\": data_format},\n            input_data=inputs,\n            expected_output=expected_output,\n        )\n\n    def test_cropping_3d_with_dynamic_spatial_dim(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_layer = layers.Input(batch_shape=(1, 7, None, 13, 5))\n        else:\n            input_layer = layers.Input(batch_shape=(1, 5, 7, None, 13))\n        cropped = layers.Cropping3D(((1, 2), (3, 4), (5, 6)))(input_layer)\n        if backend.config.image_data_format() == \"channels_last\":\n            self.assertEqual(cropped.shape, (1, 4, None, 2, 5))\n        else:\n            self.assertEqual(cropped.shape, (1, 5, 4, None, 2))\n\n    @parameterized.product(\n        (\n            {\"cropping\": ((3, 6), (0, 0), (0, 0))},\n            {\"cropping\": ((0, 0), (5, 8), (0, 0))},\n            {\"cropping\": ((0, 0), (0, 0), (7, 6))},\n        ),\n        (\n            {\"data_format\": \"channels_first\"},\n            {\"data_format\": \"channels_last\"},\n        ),\n    )\n    def test_cropping_3d_errors_if_cropping_more_than_available(\n        self, cropping, data_format\n    ):\n        input_layer = layers.Input(batch_shape=(3, 7, 9, 13, 5))\n        with self.assertRaises(ValueError):\n            layers.Cropping3D(cropping=cropping, data_format=data_format)(\n                input_layer\n            )\n\n    def test_cropping_3d_errors_if_cropping_argument_invalid(self):\n        with self.assertRaises(ValueError):\n            layers.Cropping3D(cropping=(1,))\n        with self.assertRaises(ValueError):\n            layers.Cropping3D(cropping=(1, 2))\n        with 
self.assertRaises(ValueError):\n            layers.Cropping3D(cropping=(1, 2, 3, 4))\n        with self.assertRaises(ValueError):\n            layers.Cropping3D(cropping=\"1\")\n        with self.assertRaises(ValueError):\n            layers.Cropping3D(cropping=((1, 2), (3, 4), (5, 6, 7)))\n        with self.assertRaises(ValueError):\n            layers.Cropping3D(cropping=((1, 2), (3, 4), (5, -6)))\n        with self.assertRaises(ValueError):\n            layers.Cropping3D(cropping=((1, 2), (3, 4), \"5\"))\n\n    @parameterized.product(\n        (\n            {\"cropping\": ((8, 1), (1, 1), (1, 1))},\n            {\"cropping\": ((1, 1), (10, 1), (1, 1))},\n            {\"cropping\": ((1, 1), (1, 1), (14, 1))},\n        ),\n        (\n            {\"data_format\": \"channels_first\"},\n            {\"data_format\": \"channels_last\"},\n        ),\n    )\n    def test_cropping_3d_with_excessive_cropping(self, cropping, data_format):\n        if data_format == \"channels_first\":\n            shape = (3, 5, 7, 9, 13)\n            input_layer = layers.Input(batch_shape=shape)\n        else:\n            shape = (3, 7, 9, 13, 5)\n            input_layer = layers.Input(batch_shape=shape)\n\n        expected_error_msg = (\n            \"Values in `cropping` argument should be smaller than the\"\n        )\n\n        with self.assertRaisesRegex(ValueError, expected_error_msg):\n            layers.Cropping3D(cropping=cropping, data_format=data_format)(\n                input_layer\n            )\n"
  },
  {
    "path": "keras/src/layers/reshaping/flatten.py",
    "content": "import math\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.Flatten\")\nclass Flatten(Layer):\n    \"\"\"Flattens the input. Does not affect the batch size.\n\n    Note: If inputs are shaped `(batch,)` without a feature axis, then\n    flattening adds an extra channel dimension and output shape is `(batch, 1)`.\n\n    Args:\n        data_format: A string, one of `\"channels_last\"` (default) or\n            `\"channels_first\"`. The ordering of the dimensions in the inputs.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, ..., channels)` while `\"channels_first\"` corresponds to\n            inputs with shape `(batch, channels, ...)`.\n            When unspecified, uses `image_data_format` value found in your Keras\n            config file at `~/.keras/keras.json` (if exists). 
Defaults to\n            `\"channels_last\"`.\n\n    Example:\n\n    >>> x = keras.Input(shape=(10, 64))\n    >>> y = keras.layers.Flatten()(x)\n    >>> y.shape\n    (None, 640)\n    \"\"\"\n\n    def __init__(self, data_format=None, **kwargs):\n        super().__init__(**kwargs)\n        self.data_format = backend.standardize_data_format(data_format)\n        self.input_spec = InputSpec(min_ndim=1)\n        self._channels_first = self.data_format == \"channels_first\"\n\n    def call(self, inputs):\n        input_shape = ops.shape(inputs)\n        rank = len(input_shape)\n\n        if self._channels_first and rank > 1:\n            # Switch to channels-last format.\n            inputs = ops.transpose(inputs, axes=(0, *range(2, rank), 1))\n\n        non_batch_dims = input_shape[1:]\n        if len(non_batch_dims) == 0:\n            flattened_dim = 1\n        elif any(not isinstance(d, int) for d in non_batch_dims):\n            flattened_dim = -1\n        else:\n            flattened_dim = math.prod(non_batch_dims)\n\n        return ops.reshape(inputs, (input_shape[0], flattened_dim))\n\n    def compute_output_shape(self, input_shape):\n        non_batch_dims = input_shape[1:]\n        if len(non_batch_dims) == 0:\n            flattened_dim = 1\n        elif any(d is None for d in non_batch_dims):\n            # NB: we cannot use the shorter `None in non_batch_dims` here b/c\n            # torchdynamo errors when calling `__contains__` op with\n            # a constant (in this case `None`) operand since it assumes\n            # that the elements in the collection are also `ConstantVariable`s\n            # but tensor shapes can be `SymNodeVariable`s (e.g. 
`SymInt`)\n            flattened_dim = None\n        else:\n            flattened_dim = math.prod(non_batch_dims)\n        return (input_shape[0], flattened_dim)\n\n    def compute_output_spec(self, inputs):\n        output_shape = self.compute_output_shape(inputs.shape)\n        return KerasTensor(\n            shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse\n        )\n\n    def get_config(self):\n        config = {\"data_format\": self.data_format}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/reshaping/flatten_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom conftest import skip_if_backend\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass FlattenTest(testing.TestCase):\n    @parameterized.named_parameters(\n        [\n            {\"testcase_name\": \"dense\", \"sparse\": False},\n            {\"testcase_name\": \"sparse\", \"sparse\": True},\n        ]\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_flatten(self, sparse):\n        if sparse and not backend.SUPPORTS_SPARSE_TENSORS:\n            pytest.skip(\"Backend does not support sparse tensors.\")\n\n        inputs = np.random.random((10, 3, 5, 5)).astype(\"float32\")\n        # Make the ndarray relatively sparse\n        inputs = np.multiply(inputs, inputs >= 0.8)\n        expected_output_channels_last = ops.convert_to_tensor(\n            np.reshape(inputs, (-1, 5 * 5 * 3))\n        )\n        expected_output_channels_first = ops.convert_to_tensor(\n            np.reshape(np.transpose(inputs, (0, 2, 3, 1)), (-1, 5 * 5 * 3))\n        )\n        if sparse:\n            if backend.backend() == \"tensorflow\":\n                import tensorflow as tf\n\n                dense_to_sparse = tf.sparse.from_dense\n            elif backend.backend() == \"jax\":\n                import jax.experimental.sparse as jax_sparse\n\n                dense_to_sparse = jax_sparse.BCOO.fromdense\n            else:\n                self.fail(\n                    f\"Sparse is unsupported with backend {backend.backend()}\"\n                )\n            inputs = dense_to_sparse(inputs)\n            expected_output_channels_last = dense_to_sparse(\n                expected_output_channels_last\n            )\n            expected_output_channels_first = dense_to_sparse(\n                expected_output_channels_first\n            )\n\n        # Test default 
data_format and channels_last\n        self.run_layer_test(\n            layers.Flatten,\n            init_kwargs={},\n            input_data=inputs,\n            input_sparse=sparse,\n            expected_output=(\n                expected_output_channels_last\n                if backend.config.image_data_format() == \"channels_last\"\n                else expected_output_channels_first\n            ),\n            expected_output_sparse=sparse,\n            run_training_check=not sparse,\n        )\n        self.run_layer_test(\n            layers.Flatten,\n            init_kwargs={\"data_format\": \"channels_last\"},\n            input_data=inputs,\n            input_sparse=sparse,\n            expected_output=expected_output_channels_last,\n            expected_output_sparse=sparse,\n            run_training_check=not sparse,\n        )\n\n        # Test channels_first\n        self.run_layer_test(\n            layers.Flatten,\n            init_kwargs={\"data_format\": \"channels_first\"},\n            input_data=inputs,\n            input_sparse=sparse,\n            expected_output=expected_output_channels_first,\n            expected_output_sparse=sparse,\n            run_training_check=not sparse,\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_flatten_with_scalar_channels(self):\n        inputs = np.random.random((10,)).astype(\"float32\")\n        expected_output = ops.convert_to_tensor(np.expand_dims(inputs, -1))\n\n        # Test default data_format and channels_last\n        self.run_layer_test(\n            layers.Flatten,\n            init_kwargs={},\n            input_data=inputs,\n            expected_output=expected_output,\n        )\n        self.run_layer_test(\n            layers.Flatten,\n            init_kwargs={\"data_format\": \"channels_last\"},\n            input_data=inputs,\n            expected_output=expected_output,\n        )\n\n        # Test channels_first\n        self.run_layer_test(\n            
layers.Flatten,\n            init_kwargs={\"data_format\": \"channels_first\"},\n            input_data=inputs,\n            expected_output=expected_output,\n        )\n\n    def test_flatten_symbolic_with_dynamic_batch_size(self):\n        input_layer = layers.Input(batch_shape=(None, 2, 3))\n        flattened = layers.Flatten()(input_layer)\n        self.assertEqual(flattened.shape, (None, 2 * 3))\n\n    def test_flatten_symbolic_with_dynamic_dimension(self):\n        input_layer = layers.Input(batch_shape=(5, 2, None))\n        flattened = layers.Flatten()(input_layer)\n        self.assertEqual(flattened.shape, (5, None))\n\n    @skip_if_backend(\"openvino\", \"Dynamic dimensions not supported by OpenVINO\")\n    def test_flatten_with_dynamic_batch_size_and_dynamic_dimensions(self):\n        def generator():\n            yield (np.ones((3, 5, 7), dtype=\"float32\"),)\n            yield (np.ones((2, 7, 5), dtype=\"float32\"),)\n\n        model = models.Sequential([layers.Flatten()])\n        model.predict(generator())\n"
  },
  {
    "path": "keras/src/layers/reshaping/permute.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.Permute\")\nclass Permute(Layer):\n    \"\"\"Permutes the dimensions of the input according to a given pattern.\n\n    Useful e.g. connecting RNNs and convnets.\n\n    Args:\n        dims: Tuple of integers. Permutation pattern does not include the\n            batch dimension. Indexing starts at 1.\n            For instance, `(1, 3, 2)` permutes the second and third dimensions\n            of the input.\n\n    Input shape:\n        Arbitrary.\n\n    Output shape:\n        Same as the input shape, but with the dimensions re-ordered according\n        to the specified pattern.\n\n    Example:\n\n    >>> x = keras.Input(shape=(10, 64))\n    >>> y = keras.layers.Permute((2, 1))(x)\n    >>> y.shape\n    (None, 64, 10)\n    \"\"\"\n\n    def __init__(self, dims, **kwargs):\n        super().__init__(**kwargs)\n        self.dims = tuple(dims)\n        if sorted(dims) != list(range(1, len(dims) + 1)):\n            raise ValueError(\n                \"Invalid permutation argument `dims` for Permute Layer. \"\n                \"The set of indices in `dims` must be consecutive and start \"\n                f\"from 1. 
Received dims={dims}\"\n            )\n        self.input_spec = InputSpec(ndim=len(self.dims) + 1)\n\n    def compute_output_shape(self, input_shape):\n        output_shape = [input_shape[0]]\n        for dim in self.dims:\n            output_shape.append(input_shape[dim])\n        return tuple(output_shape)\n\n    def compute_output_spec(self, inputs):\n        output_shape = self.compute_output_shape(inputs.shape)\n        return KerasTensor(\n            shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse\n        )\n\n    def call(self, inputs):\n        return ops.transpose(inputs, axes=(0,) + self.dims)\n\n    def get_config(self):\n        config = {\"dims\": self.dims}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/reshaping/permute_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass PermuteTest(testing.TestCase):\n    @parameterized.named_parameters(\n        [\n            {\"testcase_name\": \"dense\", \"sparse\": False},\n            {\"testcase_name\": \"sparse\", \"sparse\": True},\n        ]\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_permute(self, sparse):\n        if sparse and not backend.SUPPORTS_SPARSE_TENSORS:\n            pytest.skip(\"Backend does not support sparse tensors.\")\n\n        inputs = np.random.random((10, 3, 5, 5)).astype(\"float32\")\n        # Make the ndarray relatively sparse\n        inputs = np.multiply(inputs, inputs >= 0.8)\n        expected_output = ops.convert_to_tensor(\n            np.transpose(inputs, axes=(0, 3, 1, 2))\n        )\n        if sparse:\n            if backend.backend() == \"tensorflow\":\n                import tensorflow as tf\n\n                inputs = tf.sparse.from_dense(inputs)\n                expected_output = tf.sparse.from_dense(expected_output)\n            elif backend.backend() == \"jax\":\n                import jax.experimental.sparse as jax_sparse\n\n                inputs = jax_sparse.BCOO.fromdense(inputs)\n                expected_output = jax_sparse.BCOO.fromdense(expected_output)\n            else:\n                self.fail(\n                    f\"Backend {backend.backend()} does not support sparse\"\n                )\n\n        self.run_layer_test(\n            layers.Permute,\n            init_kwargs={\"dims\": (3, 1, 2)},\n            input_data=inputs,\n            input_sparse=sparse,\n            expected_output=expected_output,\n            expected_output_sparse=sparse,\n            run_training_check=not sparse,\n        )\n\n    def test_permute_with_dynamic_batch_size(self):\n        input_layer = 
layers.Input(batch_shape=(None, 3, 5))\n        permuted = layers.Permute((2, 1))(input_layer)\n        self.assertEqual(permuted.shape, (None, 5, 3))\n\n    def test_permute_errors_on_invalid_starting_dims_index(self):\n        with self.assertRaisesRegex(\n            ValueError, r\"Invalid permutation .*dims.*\"\n        ):\n            self.run_layer_test(\n                layers.Permute,\n                init_kwargs={\"dims\": (0, 1, 2)},\n                input_shape=(3, 2, 4),\n            )\n\n    def test_permute_errors_on_invalid_set_of_dims_indices(self):\n        with self.assertRaisesRegex(\n            ValueError, r\"Invalid permutation .*dims.*\"\n        ):\n            self.run_layer_test(\n                layers.Permute,\n                init_kwargs={\"dims\": (1, 4, 2)},\n                input_shape=(3, 2, 4),\n            )\n"
  },
  {
    "path": "keras/src/layers/reshaping/repeat_vector.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.RepeatVector\")\nclass RepeatVector(Layer):\n    \"\"\"Repeats the input n times.\n\n    Example:\n\n    >>> x = keras.Input(shape=(32,))\n    >>> y = keras.layers.RepeatVector(3)(x)\n    >>> y.shape\n    (None, 3, 32)\n\n    Args:\n        n: Integer, repetition factor.\n\n    Input shape:\n        2D tensor with shape `(batch_size, features)`.\n\n    Output shape:\n        3D tensor with shape `(batch_size, n, features)`.\n    \"\"\"\n\n    def __init__(self, n, **kwargs):\n        super().__init__(**kwargs)\n        self.n = n\n        if not isinstance(n, int):\n            raise TypeError(\n                f\"Expected an integer value for `n`, got {type(n)}.\"\n            )\n        self.input_spec = InputSpec(ndim=2)\n\n    def compute_output_shape(self, input_shape):\n        return (input_shape[0], self.n, input_shape[1])\n\n    def call(self, inputs):\n        input_shape = ops.shape(inputs)\n        reshaped = ops.reshape(inputs, (input_shape[0], 1, input_shape[1]))\n        return ops.repeat(reshaped, self.n, axis=1)\n\n    def get_config(self):\n        config = {\"n\": self.n}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/reshaping/repeat_vector_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass FlattenTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_repeat_vector(self):\n        inputs = np.random.random((2, 5)).astype(\"float32\")\n        expected_output = ops.convert_to_tensor(\n            np.repeat(np.reshape(inputs, (2, 1, 5)), 3, axis=1)\n        )\n        self.run_layer_test(\n            layers.RepeatVector,\n            init_kwargs={\"n\": 3},\n            input_data=inputs,\n            expected_output=expected_output,\n        )\n\n    def test_repeat_vector_with_dynamic_batch_size(self):\n        input_layer = layers.Input(batch_shape=(None, 5))\n        repeated = layers.RepeatVector(n=3)(input_layer)\n        self.assertEqual(repeated.shape, (None, 3, 5))\n\n    def test_repeat_vector_with_dynamic_dimension(self):\n        input_layer = layers.Input(batch_shape=(2, None))\n        repeated = layers.RepeatVector(n=3)(input_layer)\n        self.assertEqual(repeated.shape, (2, 3, None))\n\n    def test_repeat_vector_with_invalid_n(self):\n        with self.assertRaisesRegex(\n            TypeError, \"Expected an integer value for `n`\"\n        ):\n            layers.RepeatVector(n=\"3\")\n\n        with self.assertRaisesRegex(\n            TypeError, \"Expected an integer value for `n`\"\n        ):\n            layers.RepeatVector(n=3.5)\n\n        with self.assertRaisesRegex(\n            TypeError, \"Expected an integer value for `n`\"\n        ):\n            layers.RepeatVector(n=[3])\n"
  },
  {
    "path": "keras/src/layers/reshaping/reshape.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.layers.layer import Layer\nfrom keras.src.ops import operation_utils\n\n\n@keras_export(\"keras.layers.Reshape\")\nclass Reshape(Layer):\n    \"\"\"Layer that reshapes inputs into the given shape.\n\n    Args:\n        target_shape: Target shape. Tuple of integers, does not include the\n            samples dimension (batch size). One element of the `target_shape`\n            can be -1 in which case the missing value is inferred from the\n            size of the array and remaining dimensions.\n\n    Input shape:\n        Arbitrary, but required to be compatible with `target_shape`.\n\n    Output shape:\n        `(batch_size, *target_shape)`\n\n    Example:\n\n    >>> x = keras.Input(shape=(12,))\n    >>> y = keras.layers.Reshape((3, 4))(x)\n    >>> y.shape\n    (None, 3, 4)\n\n    >>> # another example with shape inference using `-1` as dimension\n    >>> y = keras.layers.Reshape((-1, 2, 2))(x)\n    >>> y.shape\n    (None, 3, 2, 2)\n    \"\"\"\n\n    def __init__(self, target_shape, **kwargs):\n        super().__init__(**kwargs)\n        target_shape = tuple(target_shape)\n        # test validity of target_shape\n        if target_shape.count(-1) > 1:\n            raise ValueError(\n                \"The `target_shape` argument must not contain more than one \"\n                f\"`-1` value. 
Received: target_shape={target_shape}\"\n            )\n        self.target_shape = target_shape\n        self.built = True\n\n    def compute_output_shape(self, input_shape):\n        return (\n            input_shape[0],\n            *operation_utils.compute_reshape_output_shape(\n                input_shape[1:], self.target_shape, \"target_shape\"\n            ),\n        )\n\n    def compute_output_spec(self, inputs):\n        output_shape = self.compute_output_shape(inputs.shape)\n        return KerasTensor(\n            shape=output_shape, dtype=inputs.dtype, sparse=inputs.sparse\n        )\n\n    def call(self, inputs):\n        potentially_resolved_target_shape = (\n            operation_utils.compute_reshape_output_shape(\n                tuple(inputs.shape)[1:], self.target_shape, \"target_shape\"\n            )\n        )\n        potentially_resolved_target_shape = tuple(\n            -1 if d is None else d for d in potentially_resolved_target_shape\n        )\n        return ops.reshape(\n            inputs, (ops.shape(inputs)[0],) + potentially_resolved_target_shape\n        )\n\n    def get_config(self):\n        config = {\"target_shape\": self.target_shape}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/reshaping/reshape_test.py",
    "content": "import pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import Sequential\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import random\nfrom keras.src import testing\nfrom keras.src.backend.common.keras_tensor import KerasTensor\n\n\nclass ReshapeTest(testing.TestCase):\n    @parameterized.named_parameters(\n        [\n            {\"testcase_name\": \"dense\", \"sparse\": False},\n            {\"testcase_name\": \"sparse\", \"sparse\": True},\n        ]\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_reshape(self, sparse):\n        if sparse and not backend.SUPPORTS_SPARSE_TENSORS:\n            pytest.skip(\"Backend does not support sparse tensors.\")\n\n        self.run_layer_test(\n            layers.Reshape,\n            init_kwargs={\"target_shape\": (8, 1)},\n            input_shape=(3, 2, 4),\n            input_sparse=sparse,\n            expected_output_shape=(3, 8, 1),\n            expected_output_sparse=sparse,\n            run_training_check=not sparse,\n        )\n\n        self.run_layer_test(\n            layers.Reshape,\n            init_kwargs={\"target_shape\": (8,)},\n            input_shape=(3, 2, 4),\n            input_sparse=sparse,\n            expected_output_shape=(3, 8),\n            expected_output_sparse=sparse,\n            run_training_check=not sparse,\n        )\n\n        self.run_layer_test(\n            layers.Reshape,\n            init_kwargs={\"target_shape\": (2, 4)},\n            input_shape=(3, 8),\n            input_sparse=sparse,\n            expected_output_shape=(3, 2, 4),\n            expected_output_sparse=sparse,\n            run_training_check=not sparse,\n        )\n        self.run_layer_test(\n            layers.Reshape,\n            init_kwargs={\"target_shape\": (-1, 1)},\n            input_shape=(3, 2, 4),\n            input_sparse=sparse,\n            expected_output_shape=(3, 8, 1),\n            
expected_output_sparse=sparse,\n            run_training_check=not sparse,\n        )\n\n        self.run_layer_test(\n            layers.Reshape,\n            init_kwargs={\"target_shape\": (1, -1)},\n            input_shape=(3, 2, 4),\n            input_sparse=sparse,\n            expected_output_shape=(3, 1, 8),\n            expected_output_sparse=sparse,\n            run_training_check=not sparse,\n        )\n\n        self.run_layer_test(\n            layers.Reshape,\n            init_kwargs={\"target_shape\": (-1,)},\n            input_shape=(3, 2, 4),\n            input_sparse=sparse,\n            expected_output_shape=(3, 8),\n            expected_output_sparse=sparse,\n            run_training_check=not sparse,\n        )\n\n        self.run_layer_test(\n            layers.Reshape,\n            init_kwargs={\"target_shape\": (2, -1)},\n            input_shape=(3, 2, 4),\n            input_sparse=sparse,\n            expected_output_shape=(3, 2, 4),\n            expected_output_sparse=sparse,\n            run_training_check=not sparse,\n        )\n\n    def test_reshape_with_dynamic_batch_size(self):\n        input_layer = layers.Input(shape=(2, 4))\n        reshaped = layers.Reshape((8,))(input_layer)\n        self.assertEqual(reshaped.shape, (None, 8))\n\n    def test_reshape_with_dynamic_batch_size_and_minus_one(self):\n        input = KerasTensor((None, 6, 4))\n        layer = layers.Reshape((-1, 8))\n        reshaped = backend.compute_output_spec(layer.__call__, input)\n        self.assertEqual(reshaped.shape, (None, 3, 8))\n\n    def test_reshape_layer_with_varying_input_size_and_minus_one(self):\n        layer = layers.Reshape((-1, 8))\n        res = layer(ops.ones((1, 6, 4), dtype=\"float32\"))\n        self.assertEqual(res.shape, (1, 3, 8))\n        res = layer(ops.ones((1, 10, 4), dtype=\"float32\"))\n        self.assertEqual(res.shape, (1, 5, 8))\n\n    def test_reshape_with_dynamic_dim_and_minus_one(self):\n        input = KerasTensor((4, 6, 
None, 3))\n        layer = layers.Reshape((-1, 3))\n        reshaped = backend.compute_output_spec(layer.__call__, input)\n        self.assertEqual(reshaped.shape, (4, None, 3))\n\n    def test_reshape_sets_static_shape(self):\n        input_layer = layers.Input(batch_shape=(2, None))\n        reshaped = layers.Reshape((3, 5))(input_layer)\n        # Also make sure the batch dim is not lost after reshape.\n        self.assertEqual(reshaped.shape, (2, 3, 5))\n\n    @pytest.mark.requires_trainable_backend\n    def test_reshape_model_fit_with_varying_input_size_and_minus_one(self):\n        def generator():\n            yield (\n                ops.ones((1, 12, 2), dtype=\"float32\"),\n                ops.zeros((1, 3, 8), dtype=\"float32\"),\n            )\n            yield (\n                ops.ones((1, 20, 2), dtype=\"float32\"),\n                ops.zeros((1, 5, 8), dtype=\"float32\"),\n            )\n\n        layer = layers.Reshape((-1, 8))\n        model = Sequential([layer])\n        model.compile(loss=\"mean_squared_error\")\n        model.fit(generator(), steps_per_epoch=2, epochs=1)\n\n        # Also test inference with varying shapes to ensure -1 works dynamically\n        # Input: (batch, seq_len, 2), Output: (batch, seq_len*2/8, 8)\n        for seq_len, expected_first_dim in [(12, 3), (20, 5), (24, 6)]:\n            x_test = random.normal((1, seq_len, 2))\n            output = model(x_test)\n            # Verify output shape is correct\n            self.assertEqual(output.shape[1], expected_first_dim)\n            self.assertEqual(output.shape[2], 8)\n"
  },
  {
    "path": "keras/src/layers/reshaping/up_sampling1d.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.UpSampling1D\")\nclass UpSampling1D(Layer):\n    \"\"\"Upsampling layer for 1D inputs.\n\n    Repeats each temporal step `size` times along the time axis.\n\n    Example:\n\n    >>> input_shape = (2, 2, 3)\n    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)\n    >>> x\n    [[[ 0  1  2]\n      [ 3  4  5]]\n     [[ 6  7  8]\n      [ 9 10 11]]]\n    >>> y = keras.layers.UpSampling1D(size=2)(x)\n    >>> y\n    [[[ 0.  1.  2.]\n      [ 0.  1.  2.]\n      [ 3.  4.  5.]\n      [ 3.  4.  5.]]\n     [[ 6.  7.  8.]\n      [ 6.  7.  8.]\n      [ 9. 10. 11.]\n      [ 9. 10. 11.]]]\n\n    Args:\n        size: Integer. Upsampling factor.\n\n    Input shape:\n        3D tensor with shape: `(batch_size, steps, features)`.\n\n    Output shape:\n        3D tensor with shape: `(batch_size, upsampled_steps, features)`.\n    \"\"\"\n\n    def __init__(self, size=2, **kwargs):\n        super().__init__(**kwargs)\n        self.size = int(size)\n        self.input_spec = InputSpec(ndim=3)\n\n    def compute_output_shape(self, input_shape):\n        size = (\n            self.size * input_shape[1] if input_shape[1] is not None else None\n        )\n        return [input_shape[0], size, input_shape[2]]\n\n    def call(self, inputs):\n        return ops.repeat(x=inputs, repeats=self.size, axis=1)\n\n    def get_config(self):\n        config = {\"size\": self.size}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
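`UpSampling1D.call` is nothing more than a repeat along the temporal axis. A minimal NumPy sketch (the helper name `upsample_1d` is illustrative, with `np.repeat` standing in for `keras.ops.repeat`):

```python
import numpy as np

def upsample_1d(x, size=2):
    # Mirrors UpSampling1D.call: repeat each timestep `size` times
    # along axis 1 (the temporal axis of a (batch, steps, features) tensor).
    return np.repeat(x, repeats=size, axis=1)

x = np.arange(12).reshape((2, 2, 3))  # (batch, steps, features)
y = upsample_1d(x, size=2)
print(y.shape)  # (2, 4, 3)
```

Each original timestep appears `size` times in a row, matching the docstring example above.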
  {
    "path": "keras/src/layers/reshaping/up_sampling1d_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.backend.common.keras_tensor import KerasTensor\n\n\nclass UpSamplingTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_upsampling_1d(self):\n        self.run_layer_test(\n            layers.UpSampling1D,\n            init_kwargs={\"size\": 2},\n            input_shape=(3, 5, 4),\n            expected_output_shape=(3, 10, 4),\n            expected_output_dtype=\"float32\",\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n    def test_upsampling_1d_correctness(self):\n        self.assertAllClose(\n            layers.UpSampling1D(size=2)(np.arange(12).reshape((2, 2, 3))),\n            np.array(\n                [\n                    [\n                        [0.0, 1.0, 2.0],\n                        [0.0, 1.0, 2.0],\n                        [3.0, 4.0, 5.0],\n                        [3.0, 4.0, 5.0],\n                    ],\n                    [\n                        [6.0, 7.0, 8.0],\n                        [6.0, 7.0, 8.0],\n                        [9.0, 10.0, 11.0],\n                        [9.0, 10.0, 11.0],\n                    ],\n                ]\n            ),\n        )\n\n    def test_upsampling_1d_correctness_with_ones(self):\n        self.assertAllClose(\n            layers.UpSampling1D(size=3)(np.ones((2, 1, 5))), np.ones((2, 3, 5))\n        )\n\n    def test_upsampling_1d_with_dynamic_batch_size(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(layers.UpSampling1D(size=2)(x).shape, (None, 4, 3))\n        self.assertEqual(layers.UpSampling1D(size=4)(x).shape, (None, 8, 3))\n\n    def test_upsampling_1d_with_dynamic_shape(self):\n        y = KerasTensor([2, None, 3])\n        
self.assertEqual(layers.UpSampling1D(size=2)(y).shape, (2, None, 3))\n        self.assertEqual(layers.UpSampling1D(size=4)(y).shape, (2, None, 3))\n\n        z = KerasTensor([2, 3, None])\n        self.assertEqual(layers.UpSampling1D(size=2)(z).shape, (2, 6, None))\n        self.assertEqual(layers.UpSampling1D(size=4)(z).shape, (2, 12, None))\n"
  },
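The dynamic-shape tests above follow directly from how `compute_output_shape` propagates `None` dimensions: the temporal dim is scaled only when it is statically known. A plain-Python sketch of that rule (the function name is hypothetical):

```python
def upsampling1d_output_shape(input_shape, size=2):
    # Mirrors UpSampling1D.compute_output_shape: the steps dim is
    # multiplied by `size` when known and stays None (dynamic) when
    # unknown; batch and feature dims pass through unchanged.
    batch, steps, features = input_shape
    steps = size * steps if steps is not None else None
    return (batch, steps, features)

print(upsampling1d_output_shape((None, 2, 3), size=2))  # (None, 4, 3)
print(upsampling1d_output_shape((2, None, 3), size=4))  # (2, None, 3)
```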
  {
    "path": "keras/src/layers/reshaping/up_sampling2d.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.utils import argument_validation\n\n\n@keras_export(\"keras.layers.UpSampling2D\")\nclass UpSampling2D(Layer):\n    \"\"\"Upsampling layer for 2D inputs.\n\n    The implementation uses interpolative resizing, given the resize method\n    (specified by the `interpolation` argument). Use `interpolation=nearest`\n    to repeat the rows and columns of the data.\n\n    Example:\n\n    >>> input_shape = (2, 2, 1, 3)\n    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)\n    >>> print(x)\n    [[[[ 0  1  2]]\n      [[ 3  4  5]]]\n     [[[ 6  7  8]]\n      [[ 9 10 11]]]]\n    >>> y = keras.layers.UpSampling2D(size=(1, 2))(x)\n    >>> print(y)\n    [[[[ 0  1  2]\n       [ 0  1  2]]\n      [[ 3  4  5]\n       [ 3  4  5]]]\n     [[[ 6  7  8]\n       [ 6  7  8]]\n      [[ 9 10 11]\n       [ 9 10 11]]]]\n\n    Args:\n        size: Int, or tuple of 2 integers.\n            The upsampling factors for rows and columns.\n        data_format: A string,\n            one of `\"channels_last\"` (default) or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch_size, height, width, channels)` while `\"channels_first\"`\n            corresponds to inputs with shape\n            `(batch_size, channels, height, width)`.\n            When unspecified, uses\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json` (if exists) else `\"channels_last\"`.\n            Defaults to `\"channels_last\"`.\n        interpolation: A string, one of `\"bicubic\"`, `\"bilinear\"`, `\"lanczos3\"`,\n            `\"lanczos5\"`, `\"nearest\"`.\n\n    Input shape:\n        4D tensor with shape:\n        - If 
`data_format` is `\"channels_last\"`:\n            `(batch_size, rows, cols, channels)`\n        - If `data_format` is `\"channels_first\"`:\n            `(batch_size, channels, rows, cols)`\n\n    Output shape:\n        4D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n            `(batch_size, upsampled_rows, upsampled_cols, channels)`\n        - If `data_format` is `\"channels_first\"`:\n            `(batch_size, channels, upsampled_rows, upsampled_cols)`\n    \"\"\"\n\n    def __init__(\n        self, size=(2, 2), data_format=None, interpolation=\"nearest\", **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.data_format = backend.standardize_data_format(data_format)\n        self.size = argument_validation.standardize_tuple(size, 2, \"size\")\n        self.interpolation = interpolation.lower()\n        self.input_spec = InputSpec(ndim=4)\n\n    def compute_output_shape(self, input_shape):\n        if self.data_format == \"channels_first\":\n            height = (\n                self.size[0] * input_shape[2]\n                if input_shape[2] is not None\n                else None\n            )\n            width = (\n                self.size[1] * input_shape[3]\n                if input_shape[3] is not None\n                else None\n            )\n            return (input_shape[0], input_shape[1], height, width)\n        else:\n            height = (\n                self.size[0] * input_shape[1]\n                if input_shape[1] is not None\n                else None\n            )\n            width = (\n                self.size[1] * input_shape[2]\n                if input_shape[2] is not None\n                else None\n            )\n            return (input_shape[0], height, width, input_shape[3])\n\n    def call(self, inputs):\n        return self._resize_images(\n            inputs,\n            self.size[0],\n            self.size[1],\n            self.data_format,\n            
interpolation=self.interpolation,\n        )\n\n    def get_config(self):\n        config = {\n            \"size\": self.size,\n            \"data_format\": self.data_format,\n            \"interpolation\": self.interpolation,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    def _resize_images(\n        self,\n        x,\n        height_factor,\n        width_factor,\n        data_format,\n        interpolation=\"nearest\",\n    ):\n        \"\"\"Resizes the images contained in a 4D tensor.\n\n        Args:\n            x: Tensor or variable to resize.\n            height_factor: Positive integer.\n            width_factor: Positive integer.\n            data_format: One of `\"channels_first\"`, `\"channels_last\"`.\n            interpolation: A string, one of `\"bicubic\"`, `\"bilinear\"`,\n            `\"lanczos3\"`, `\"lanczos5\"`, or `\"nearest\"`.\n\n        Returns:\n            A tensor.\n        \"\"\"\n        if data_format not in {\"channels_last\", \"channels_first\"}:\n            raise ValueError(f\"Invalid `data_format` argument: {data_format}\")\n\n        if data_format == \"channels_first\":\n            x = ops.transpose(x, [0, 2, 3, 1])\n        # https://github.com/keras-team/keras/issues/294\n        # Use `ops.repeat` for `nearest` interpolation to enable XLA\n        if interpolation == \"nearest\":\n            x = ops.repeat(x, height_factor, axis=1)\n            x = ops.repeat(x, width_factor, axis=2)\n        else:\n            # multiply the height and width factor on each dim\n            # by hand (versus using element-wise multiplication\n            # by np.array([height_factor, width_factor]) then\n            # list-ifying the tensor by calling `.tolist()`)\n            # since when running under torchdynamo, `new_shape`\n            # will be traced as a symbolic variable (specifically\n            # a `FakeTensor`) which does not have a `tolist()` method.\n            shape 
= ops.shape(x)\n            new_shape = (\n                shape[1] * height_factor,\n                shape[2] * width_factor,\n            )\n            x = ops.image.resize(\n                x,\n                new_shape,\n                data_format=\"channels_last\",\n                interpolation=interpolation,\n            )\n        if data_format == \"channels_first\":\n            x = ops.transpose(x, [0, 3, 1, 2])\n\n        return x\n"
  },
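The `nearest` branch of `_resize_images` avoids `ops.image.resize` entirely: it transposes `channels_first` input to `channels_last`, repeats rows and columns, and transposes back. A NumPy sketch of that path (function name illustrative, `np.repeat`/`np.transpose` standing in for the `keras.ops` equivalents):

```python
import numpy as np

def resize_nearest(x, height_factor, width_factor, data_format="channels_last"):
    # Nearest-neighbor path of UpSampling2D._resize_images: move to
    # channels_last, repeat rows then columns, then move back.
    if data_format == "channels_first":
        x = np.transpose(x, [0, 2, 3, 1])
    x = np.repeat(x, height_factor, axis=1)
    x = np.repeat(x, width_factor, axis=2)
    if data_format == "channels_first":
        x = np.transpose(x, [0, 3, 1, 2])
    return x

x = np.random.rand(2, 3, 4, 5)  # NCHW input
y = resize_nearest(x, 2, 3, data_format="channels_first")
print(y.shape)  # (2, 3, 8, 15)
```

The transpose round-trip gives the same result as repeating directly on axes 2 and 3 of the NCHW tensor; using repeats instead of a resize is what keeps this branch XLA-compilable.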
  {
    "path": "keras/src/layers/reshaping/up_sampling2d_test.py",
    "content": "# flake8: noqa\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.backend import set_image_data_format\n\n\nclass UpSampling2dTest(testing.TestCase):\n    @classmethod\n    def setUpClass(cls):\n        cls.original_image_data_format = backend.image_data_format()\n\n    @classmethod\n    def tearDownClass(cls):\n        backend.set_image_data_format(cls.original_image_data_format)\n\n    @parameterized.product(\n        data_format=[\"channels_first\", \"channels_last\"],\n        length_row=[2],\n        length_col=[2, 3],\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_upsampling_2d(self, data_format, length_row, length_col):\n        num_samples = 2\n        stack_size = 2\n        input_num_row = 11\n        input_num_col = 12\n\n        if data_format == \"channels_first\":\n            inputs = np.random.rand(\n                num_samples, stack_size, input_num_row, input_num_col\n            )\n        else:\n            inputs = np.random.rand(\n                num_samples, input_num_row, input_num_col, stack_size\n            )\n\n        # basic test\n        self.run_layer_test(\n            layers.UpSampling2D,\n            init_kwargs={\"size\": (2, 2), \"data_format\": data_format},\n            input_shape=inputs.shape,\n        )\n\n        layer = layers.UpSampling2D(\n            size=(length_row, length_col),\n            data_format=data_format,\n        )\n        layer.build(inputs.shape)\n        np_output = layer(inputs=backend.Variable(inputs))\n        if data_format == \"channels_first\":\n            self.assertEqual(np_output.shape[2], length_row * input_num_row)\n            self.assertEqual(np_output.shape[3], length_col * input_num_col)\n        else:\n            self.assertEqual(np_output.shape[1], length_row * input_num_row)\n            
self.assertEqual(np_output.shape[2], length_col * input_num_col)\n\n        # compare with numpy\n        if data_format == \"channels_first\":\n            expected_out = np.repeat(inputs, length_row, axis=2)\n            expected_out = np.repeat(expected_out, length_col, axis=3)\n        else:\n            expected_out = np.repeat(inputs, length_row, axis=1)\n            expected_out = np.repeat(expected_out, length_col, axis=2)\n\n        self.assertAllClose(np_output, expected_out)\n\n    @parameterized.product(\n        data_format=[\"channels_first\", \"channels_last\"],\n        use_set_image_data_format=[True, False],\n        length_row=[2],\n        length_col=[2, 3],\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_upsampling_2d_bilinear(\n        self, data_format, use_set_image_data_format, length_row, length_col\n    ):\n        num_samples = 2\n        stack_size = 2\n        input_num_row = 11\n        input_num_col = 12\n\n        if use_set_image_data_format:\n            set_image_data_format(data_format)\n\n        if data_format == \"channels_first\":\n            inputs = np.random.rand(\n                num_samples, stack_size, input_num_row, input_num_col\n            )\n        else:\n            inputs = np.random.rand(\n                num_samples, input_num_row, input_num_col, stack_size\n            )\n\n        self.run_layer_test(\n            layers.UpSampling2D,\n            init_kwargs={\n                \"size\": (2, 2),\n                \"data_format\": data_format,\n                \"interpolation\": \"bilinear\",\n            },\n            input_shape=inputs.shape,\n        )\n\n        layer = layers.UpSampling2D(\n            size=(length_row, length_col),\n            data_format=data_format,\n            interpolation=\"bilinear\",\n        )\n        layer.build(inputs.shape)\n        np_output = layer(inputs=backend.Variable(inputs))\n        if data_format == \"channels_first\":\n            
self.assertEqual(np_output.shape[2], length_row * input_num_row)\n            self.assertEqual(np_output.shape[3], length_col * input_num_col)\n        else:\n            self.assertEqual(np_output.shape[1], length_row * input_num_row)\n            self.assertEqual(np_output.shape[2], length_col * input_num_col)\n\n    def test_upsampling_2d_correctness(self):\n        input_shape = (2, 2, 1, 3)\n        x = np.arange(np.prod(input_shape)).reshape(input_shape)\n        # fmt: off\n        expected_output = np.array(\n            [[[[ 0.,  1.,  2.],\n               [ 0.,  1.,  2.]],\n              [[ 3.,  4.,  5.],\n               [ 3.,  4.,  5.]]],\n             [[[ 6.,  7.,  8.],\n               [ 6.,  7.,  8.]],\n              [[ 9., 10., 11.],\n               [ 9., 10., 11.]]]]\n        )\n        # fmt: on\n        if backend.config.image_data_format() == \"channels_first\":\n            expected_output = expected_output.transpose((0, 3, 1, 2))\n            x = x.transpose((0, 3, 1, 2))\n        self.assertAllClose(\n            layers.UpSampling2D(size=(1, 2))(x), expected_output\n        )\n\n    def test_upsampling_2d_various_interpolation_methods(self):\n        input_shape = (2, 2, 1, 3)\n        x = np.arange(np.prod(input_shape)).reshape(input_shape)\n        for interpolation in [\"nearest\", \"bilinear\", \"bicubic\"]:\n            layers.UpSampling2D(size=(1, 2), interpolation=interpolation)(x)\n\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\", reason=\"Torch does not support lanczos.\"\n    )\n    def test_upsampling_2d_lanczos_interpolation_methods(self):\n        input_shape = (2, 2, 1, 3)\n        x = np.arange(np.prod(input_shape)).reshape(input_shape)\n        for interpolation in [\"lanczos3\", \"lanczos5\"]:\n            layers.UpSampling2D(size=(1, 2), interpolation=interpolation)(x)\n"
  },
  {
    "path": "keras/src/layers/reshaping/up_sampling3d.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.utils import argument_validation\n\n\n@keras_export(\"keras.layers.UpSampling3D\")\nclass UpSampling3D(Layer):\n    \"\"\"Upsampling layer for 3D inputs.\n\n    Repeats the 1st, 2nd and 3rd dimensions\n    of the data by `size[0]`, `size[1]` and `size[2]` respectively.\n\n    Example:\n\n    >>> input_shape = (2, 1, 2, 1, 3)\n    >>> x = np.ones(input_shape)\n    >>> y = keras.layers.UpSampling3D(size=(2, 2, 2))(x)\n    >>> y.shape\n    (2, 2, 4, 2, 3)\n\n    Args:\n        size: Int, or tuple of 3 integers.\n            The upsampling factors for dim1, dim2 and dim3.\n        data_format: A string,\n            one of `\"channels_last\"` (default) or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.\n            When unspecified, uses\n            `image_data_format` value found in your Keras config file at\n             `~/.keras/keras.json` (if exists) else `\"channels_last\"`.\n            Defaults to `\"channels_last\"`.\n\n    Input shape:\n        5D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n            `(batch_size, dim1, dim2, dim3, channels)`\n        - If `data_format` is `\"channels_first\"`:\n            `(batch_size, channels, dim1, dim2, dim3)`\n\n    Output shape:\n        5D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n            `(batch_size, upsampled_dim1, upsampled_dim2, upsampled_dim3,\n            channels)`\n        - If `data_format` 
is `\"channels_first\"`:\n            `(batch_size, channels, upsampled_dim1, upsampled_dim2,\n            upsampled_dim3)`\n    \"\"\"\n\n    def __init__(self, size=(2, 2, 2), data_format=None, **kwargs):\n        super().__init__(**kwargs)\n        self.data_format = backend.standardize_data_format(data_format)\n        self.size = argument_validation.standardize_tuple(size, 3, \"size\")\n        self.input_spec = InputSpec(ndim=5)\n\n    def compute_output_shape(self, input_shape):\n        if self.data_format == \"channels_first\":\n            dim1 = (\n                self.size[0] * input_shape[2]\n                if input_shape[2] is not None\n                else None\n            )\n            dim2 = (\n                self.size[1] * input_shape[3]\n                if input_shape[3] is not None\n                else None\n            )\n            dim3 = (\n                self.size[2] * input_shape[4]\n                if input_shape[4] is not None\n                else None\n            )\n            return (input_shape[0], input_shape[1], dim1, dim2, dim3)\n        else:\n            dim1 = (\n                self.size[0] * input_shape[1]\n                if input_shape[1] is not None\n                else None\n            )\n            dim2 = (\n                self.size[1] * input_shape[2]\n                if input_shape[2] is not None\n                else None\n            )\n            dim3 = (\n                self.size[2] * input_shape[3]\n                if input_shape[3] is not None\n                else None\n            )\n            return (input_shape[0], dim1, dim2, dim3, input_shape[4])\n\n    def call(self, inputs):\n        return self._resize_volumes(\n            inputs, self.size[0], self.size[1], self.size[2], self.data_format\n        )\n\n    def get_config(self):\n        config = {\"size\": self.size, \"data_format\": self.data_format}\n        base_config = super().get_config()\n        return {**base_config, 
**config}\n\n    def _resize_volumes(\n        self, x, depth_factor, height_factor, width_factor, data_format\n    ):\n        \"\"\"Resizes the volume contained in a 5D tensor.\n\n        Args:\n            x: Tensor or variable to resize.\n            depth_factor: Positive integer.\n            height_factor: Positive integer.\n            width_factor: Positive integer.\n            data_format: One of `\"channels_first\"`, `\"channels_last\"`.\n\n        Returns:\n            Resized tensor.\n        \"\"\"\n        if data_format == \"channels_first\":\n            output = ops.repeat(x, depth_factor, axis=2)\n            output = ops.repeat(output, height_factor, axis=3)\n            output = ops.repeat(output, width_factor, axis=4)\n            return output\n        elif data_format == \"channels_last\":\n            output = ops.repeat(x, depth_factor, axis=1)\n            output = ops.repeat(output, height_factor, axis=2)\n            output = ops.repeat(output, width_factor, axis=3)\n            return output\n        else:\n            raise ValueError(f\"Invalid data_format: {data_format}\")\n"
  },
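`_resize_volumes` is three chained repeats, with the spatial axes shifted by one position depending on `data_format`. A compact NumPy sketch (helper name illustrative):

```python
import numpy as np

def resize_volumes(x, depth_factor, height_factor, width_factor,
                   data_format="channels_last"):
    # UpSampling3D._resize_volumes in NumPy: repeat each spatial axis
    # by its factor; spatial axes start at 2 for channels_first, 1 for
    # channels_last.
    offset = 2 if data_format == "channels_first" else 1
    for axis, factor in enumerate((depth_factor, height_factor, width_factor)):
        x = np.repeat(x, factor, axis=axis + offset)
    return x

x = np.ones((2, 1, 2, 1, 3))  # (batch, dim1, dim2, dim3, channels)
print(resize_volumes(x, 2, 2, 2).shape)  # (2, 2, 4, 2, 3)
```

The printed shape matches the docstring example for `size=(2, 2, 2)`.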
  {
    "path": "keras/src/layers/reshaping/up_sampling3d_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass UpSampling3dTest(testing.TestCase):\n    @parameterized.product(\n        data_format=[\"channels_first\", \"channels_last\"],\n        length_dim1=[2, 3],\n        length_dim2=[2],\n        length_dim3=[3],\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_upsampling_3d(\n        self, data_format, length_dim1, length_dim2, length_dim3\n    ):\n        num_samples = 2\n        stack_size = 2\n        input_len_dim1 = 10\n        input_len_dim2 = 11\n        input_len_dim3 = 12\n\n        if data_format == \"channels_first\":\n            inputs = np.random.rand(\n                num_samples,\n                stack_size,\n                input_len_dim1,\n                input_len_dim2,\n                input_len_dim3,\n            )\n        else:\n            inputs = np.random.rand(\n                num_samples,\n                input_len_dim1,\n                input_len_dim2,\n                input_len_dim3,\n                stack_size,\n            )\n\n        # basic test\n        if data_format == \"channels_first\":\n            expected_output_shape = (2, 2, 20, 22, 24)\n        else:\n            expected_output_shape = (2, 20, 22, 24, 2)\n\n        self.run_layer_test(\n            layers.UpSampling3D,\n            init_kwargs={\"size\": (2, 2, 2), \"data_format\": data_format},\n            input_shape=inputs.shape,\n            expected_output_shape=expected_output_shape,\n            expected_output_dtype=\"float32\",\n            expected_num_trainable_weights=0,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            expected_num_losses=0,\n            supports_masking=False,\n        )\n\n        layer = layers.UpSampling3D(\n            size=(length_dim1, length_dim2, length_dim3),\n      
      data_format=data_format,\n        )\n        layer.build(inputs.shape)\n        np_output = layer(inputs=backend.Variable(inputs))\n        if data_format == \"channels_first\":\n            self.assertEqual(np_output.shape[2], length_dim1 * input_len_dim1)\n            self.assertEqual(np_output.shape[3], length_dim2 * input_len_dim2)\n            self.assertEqual(np_output.shape[4], length_dim3 * input_len_dim3)\n        else:  # tf\n            self.assertEqual(np_output.shape[1], length_dim1 * input_len_dim1)\n            self.assertEqual(np_output.shape[2], length_dim2 * input_len_dim2)\n            self.assertEqual(np_output.shape[3], length_dim3 * input_len_dim3)\n\n        # compare with numpy\n        if data_format == \"channels_first\":\n            expected_out = np.repeat(inputs, length_dim1, axis=2)\n            expected_out = np.repeat(expected_out, length_dim2, axis=3)\n            expected_out = np.repeat(expected_out, length_dim3, axis=4)\n        else:  # tf\n            expected_out = np.repeat(inputs, length_dim1, axis=1)\n            expected_out = np.repeat(expected_out, length_dim2, axis=2)\n            expected_out = np.repeat(expected_out, length_dim3, axis=3)\n\n        self.assertAllClose(np_output, expected_out)\n\n    def test_upsampling_3d_correctness(self):\n        input_shape = (2, 1, 2, 1, 3)\n        x = np.arange(np.prod(input_shape)).reshape(input_shape)\n        expected_output = np.array(\n            [\n                [\n                    [\n                        [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]],\n                        [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]],\n                        [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]],\n                        [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]],\n                    ],\n                    [\n                        [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]],\n                        [[0.0, 1.0, 2.0], [0.0, 1.0, 2.0]],\n                        [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]],\n          
              [[3.0, 4.0, 5.0], [3.0, 4.0, 5.0]],\n                    ],\n                ],\n                [\n                    [\n                        [[6.0, 7.0, 8.0], [6.0, 7.0, 8.0]],\n                        [[6.0, 7.0, 8.0], [6.0, 7.0, 8.0]],\n                        [[9.0, 10.0, 11.0], [9.0, 10.0, 11.0]],\n                        [[9.0, 10.0, 11.0], [9.0, 10.0, 11.0]],\n                    ],\n                    [\n                        [[6.0, 7.0, 8.0], [6.0, 7.0, 8.0]],\n                        [[6.0, 7.0, 8.0], [6.0, 7.0, 8.0]],\n                        [[9.0, 10.0, 11.0], [9.0, 10.0, 11.0]],\n                        [[9.0, 10.0, 11.0], [9.0, 10.0, 11.0]],\n                    ],\n                ],\n            ]\n        )\n        if backend.config.image_data_format() == \"channels_first\":\n            expected_output = expected_output.transpose((0, 4, 1, 2, 3))\n            x = x.transpose((0, 4, 1, 2, 3))\n        self.assertAllClose(\n            layers.UpSampling3D(size=(2, 2, 2))(x), expected_output\n        )\n"
  },
  {
    "path": "keras/src/layers/reshaping/zero_padding1d.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.utils import argument_validation\n\n\n@keras_export(\"keras.layers.ZeroPadding1D\")\nclass ZeroPadding1D(Layer):\n    \"\"\"Zero-padding layer for 1D input (e.g. temporal sequence).\n\n    Example:\n\n    >>> input_shape = (2, 2, 3)\n    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)\n    >>> x\n    [[[ 0  1  2]\n      [ 3  4  5]]\n     [[ 6  7  8]\n      [ 9 10 11]]]\n    >>> y = keras.layers.ZeroPadding1D(padding=2)(x)\n    >>> y\n    [[[ 0  0  0]\n      [ 0  0  0]\n      [ 0  1  2]\n      [ 3  4  5]\n      [ 0  0  0]\n      [ 0  0  0]]\n     [[ 0  0  0]\n      [ 0  0  0]\n      [ 6  7  8]\n      [ 9 10 11]\n      [ 0  0  0]\n      [ 0  0  0]]]\n\n    Args:\n        padding: Int, or tuple of int (length 2).\n            - If int: how many zeros to add at the beginning and end of\n              the padding dimension (axis 1).\n            - If tuple of 2 ints: how many zeros to add at the beginning and the\n              end of the padding dimension (`(left_pad, right_pad)`).\n        data_format: A string, one of `\"channels_last\"` (default) or\n            `\"channels_first\"`. The ordering of the dimensions in the inputs.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch_size, axis_to_pad, channels)` while `\"channels_first\"`\n            corresponds to inputs with shape\n            `(batch_size, channels, axis_to_pad)`.\n            When unspecified, uses `image_data_format` value found in your Keras\n            config file at `~/.keras/keras.json` (if exists). 
Defaults to\n            `\"channels_last\"`.\n\n    Input shape:\n        3D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n          `(batch_size, axis_to_pad, features)`\n        - If `data_format` is `\"channels_first\"`:\n          `(batch_size, features, axis_to_pad)`\n\n    Output shape:\n        3D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n          `(batch_size, padded_axis, features)`\n        - If `data_format` is `\"channels_first\"`:\n          `(batch_size, features, padded_axis)`\n    \"\"\"\n\n    def __init__(self, padding=1, data_format=None, **kwargs):\n        super().__init__(**kwargs)\n        self.data_format = backend.standardize_data_format(data_format)\n        self.padding = argument_validation.standardize_tuple(\n            padding, 2, \"padding\", allow_zero=True\n        )\n        self.input_spec = InputSpec(ndim=3)\n\n    def compute_output_shape(self, input_shape):\n        output_shape = list(input_shape)\n        padding_dim = 2 if self.data_format == \"channels_first\" else 1\n        if output_shape[padding_dim] is not None:\n            output_shape[padding_dim] += self.padding[0] + self.padding[1]\n        return tuple(output_shape)\n\n    def call(self, inputs):\n        if self.data_format == \"channels_first\":\n            all_dims_padding = ((0, 0), (0, 0), self.padding)\n        else:\n            all_dims_padding = ((0, 0), self.padding, (0, 0))\n        return ops.pad(inputs, all_dims_padding)\n\n    def get_config(self):\n        config = {\"padding\": self.padding, \"data_format\": self.data_format}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
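`ZeroPadding1D.call` builds a per-axis pad-width tuple and delegates to `ops.pad`: only the temporal axis gets nonzero padding. The same logic in NumPy (function name illustrative, `np.pad` standing in for `keras.ops.pad`):

```python
import numpy as np

def zero_pad_1d(x, padding=(1, 1), data_format="channels_last"):
    # Mirrors ZeroPadding1D.call: zero-pad only the axis_to_pad
    # dimension; batch and feature axes get (0, 0).
    left_pad, right_pad = padding
    if data_format == "channels_first":
        pad_width = ((0, 0), (0, 0), (left_pad, right_pad))
    else:
        pad_width = ((0, 0), (left_pad, right_pad), (0, 0))
    return np.pad(x, pad_width)  # constant mode pads with zeros

x = np.arange(6).reshape((1, 2, 3))
y = zero_pad_1d(x, padding=(1, 2))
print(y.shape)  # (1, 5, 3)
```

The padded length is `steps + left_pad + right_pad`, which is exactly what `compute_output_shape` reports when the axis is statically known.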
  {
    "path": "keras/src/layers/reshaping/zero_padding1d_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import dtype_policies\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass ZeroPadding1DTest(testing.TestCase):\n    @parameterized.parameters(\n        {\"data_format\": \"channels_first\"},\n        {\"data_format\": \"channels_last\"},\n    )\n    def test_zero_padding_1d(self, data_format):\n        inputs = np.random.rand(1, 2, 3)\n        outputs = layers.ZeroPadding1D(padding=(1, 2), data_format=data_format)(\n            inputs\n        )\n        if data_format == \"channels_last\":\n            for index in [0, -1, -2]:\n                self.assertAllClose(outputs[:, index, :], 0.0)\n            self.assertAllClose(outputs[:, 1:-2, :], inputs)\n        else:\n            for index in [0, -1, -2]:\n                self.assertAllClose(outputs[:, :, index], 0.0)\n            self.assertAllClose(outputs[:, :, 1:-2], inputs)\n\n    @parameterized.named_parameters((\"one_tuple\", (2, 2)), (\"one_int\", 2))\n    def test_zero_padding_1d_with_same_padding(self, padding):\n        inputs = np.random.rand(1, 2, 3)\n        outputs = layers.ZeroPadding1D(\n            padding=padding, data_format=\"channels_last\"\n        )(inputs)\n\n        for index in [0, 1, -1, -2]:\n            self.assertAllClose(outputs[:, index, :], 0.0)\n        self.assertAllClose(outputs[:, 2:-2, :], inputs)\n\n    def test_zero_padding_1d_with_dynamic_spatial_dim(self):\n        input_layer = layers.Input(batch_shape=(1, None, 3))\n        padded = layers.ZeroPadding1D((1, 2), data_format=\"channels_last\")(\n            input_layer\n        )\n        self.assertEqual(padded.shape, (1, None, 3))\n\n        input_layer = layers.Input(batch_shape=(1, 2, 3))\n        padded = layers.ZeroPadding1D((1, 2), data_format=\"channels_last\")(\n            input_layer\n        )\n        self.assertEqual(padded.shape, (1, 5, 3))\n\n    @parameterized.parameters(\n        {\"padding\": 
(1,)},\n        {\"padding\": (1, 2, 3)},\n        {\"padding\": \"1\"},\n    )\n    def test_zero_padding_1d_errors_if_padding_argument_invalid(self, padding):\n        with self.assertRaises(ValueError):\n            layers.ZeroPadding1D(padding)\n\n    @parameterized.parameters(\n        {\"data_format\": \"channels_first\"},\n        {\"data_format\": \"channels_last\"},\n    )\n    def test_zero_padding_1d_get_config(self, data_format):\n        layer = layers.ZeroPadding1D(padding=(1, 2), data_format=data_format)\n        expected_config = {\n            \"dtype\": dtype_policies.serialize(layer.dtype_policy),\n            \"data_format\": data_format,\n            \"name\": layer.name,\n            \"padding\": (1, 2),\n            \"trainable\": layer.trainable,\n        }\n        self.assertEqual(layer.get_config(), expected_config)\n"
  },
  {
    "path": "keras/src/layers/reshaping/zero_padding2d.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.utils import argument_validation\n\n\n@keras_export(\"keras.layers.ZeroPadding2D\")\nclass ZeroPadding2D(Layer):\n    \"\"\"Zero-padding layer for 2D input (e.g. picture).\n\n    This layer can add rows and columns of zeros at the top, bottom, left and\n    right side of an image tensor.\n\n    Example:\n\n    >>> input_shape = (1, 1, 2, 2)\n    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)\n    >>> x\n    [[[[0 1]\n       [2 3]]]]\n    >>> y = keras.layers.ZeroPadding2D(padding=1)(x)\n    >>> y\n    [[[[0 0]\n       [0 0]\n       [0 0]\n       [0 0]]\n      [[0 0]\n       [0 1]\n       [2 3]\n       [0 0]]\n      [[0 0]\n       [0 0]\n       [0 0]\n       [0 0]]]]\n\n    Args:\n        padding: Int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints.\n            - If int: the same symmetric padding is applied to height and width.\n            - If tuple of 2 ints: interpreted as two different symmetric padding\n              values for height and width:\n              `(symmetric_height_pad, symmetric_width_pad)`.\n            - If tuple of 2 tuples of 2 ints: interpreted as\n             `((top_pad, bottom_pad), (left_pad, right_pad))`.\n        data_format: A string, one of `\"channels_last\"` (default) or\n            `\"channels_first\"`. The ordering of the dimensions in the inputs.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch_size, height, width, channels)` while `\"channels_first\"`\n            corresponds to inputs with shape\n            `(batch_size, channels, height, width)`.\n            When unspecified, uses `image_data_format` value found in your Keras\n            config file at `~/.keras/keras.json` (if exists). 
Defaults to\n            `\"channels_last\"`.\n\n    Input shape:\n        4D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n          `(batch_size, height, width, channels)`\n        - If `data_format` is `\"channels_first\"`:\n          `(batch_size, channels, height, width)`\n\n    Output shape:\n        4D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n          `(batch_size, padded_height, padded_width, channels)`\n        - If `data_format` is `\"channels_first\"`:\n          `(batch_size, channels, padded_height, padded_width)`\n    \"\"\"\n\n    def __init__(self, padding=(1, 1), data_format=None, **kwargs):\n        super().__init__(**kwargs)\n        self.data_format = backend.standardize_data_format(data_format)\n        if isinstance(padding, int):\n            self.padding = ((padding, padding), (padding, padding))\n        elif hasattr(padding, \"__len__\"):\n            if len(padding) != 2:\n                raise ValueError(\n                    \"`padding` should have two elements. \"\n                    f\"Received: padding={padding}.\"\n                )\n            height_padding = argument_validation.standardize_tuple(\n                padding[0], 2, \"1st entry of padding\", allow_zero=True\n            )\n            width_padding = argument_validation.standardize_tuple(\n                padding[1], 2, \"2nd entry of padding\", allow_zero=True\n            )\n            self.padding = (height_padding, width_padding)\n        else:\n            raise ValueError(\n                \"`padding` should be either an int, a tuple of 2 ints \"\n                \"(symmetric_height_pad, symmetric_width_pad), \"\n                \"or a tuple of 2 tuples of 2 ints \"\n                \"((top_pad, bottom_pad), (left_pad, right_pad)).
 \"\n                f\"Received: padding={padding}.\"\n            )\n        self.input_spec = InputSpec(ndim=4)\n\n    def compute_output_shape(self, input_shape):\n        output_shape = list(input_shape)\n        spatial_dims_offset = 2 if self.data_format == \"channels_first\" else 1\n        for index in range(0, 2):\n            if output_shape[index + spatial_dims_offset] is not None:\n                output_shape[index + spatial_dims_offset] += (\n                    self.padding[index][0] + self.padding[index][1]\n                )\n        return tuple(output_shape)\n\n    def call(self, inputs):\n        if self.data_format == \"channels_first\":\n            all_dims_padding = ((0, 0), (0, 0), *self.padding)\n        else:\n            all_dims_padding = ((0, 0), *self.padding, (0, 0))\n        return ops.pad(inputs, all_dims_padding)\n\n    def get_config(self):\n        config = {\"padding\": self.padding, \"data_format\": self.data_format}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/reshaping/zero_padding2d_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import dtype_policies\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass ZeroPadding2DTest(testing.TestCase):\n    @parameterized.parameters(\n        {\"data_format\": \"channels_first\"},\n        {\"data_format\": \"channels_last\"},\n    )\n    def test_zero_padding_2d(self, data_format):\n        inputs = np.random.rand(1, 2, 3, 4)\n        outputs = layers.ZeroPadding2D(\n            padding=((1, 2), (3, 4)), data_format=data_format\n        )(inputs)\n\n        if data_format == \"channels_first\":\n            for index in [0, -1, -2]:\n                self.assertAllClose(outputs[:, :, index, :], 0.0)\n            for index in [0, 1, 2, -1, -2, -3, -4]:\n                self.assertAllClose(outputs[:, :, :, index], 0.0)\n            self.assertAllClose(outputs[:, :, 1:-2, 3:-4], inputs)\n        else:\n            for index in [0, -1, -2]:\n                self.assertAllClose(outputs[:, index, :, :], 0.0)\n            for index in [0, 1, 2, -1, -2, -3, -4]:\n                self.assertAllClose(outputs[:, :, index, :], 0.0)\n            self.assertAllClose(outputs[:, 1:-2, 3:-4, :], inputs)\n\n    @parameterized.product(\n        (\n            {\"padding\": ((2, 2), (2, 2))},  # 2 tuples\n            {\"padding\": (2, 2)},  # 1 tuple\n            {\"padding\": 2},  # 1 int\n        ),\n        (\n            {\"data_format\": \"channels_first\"},\n            {\"data_format\": \"channels_last\"},\n        ),\n    )\n    def test_zero_padding_2d_with_same_padding(self, padding, data_format):\n        inputs = np.random.rand(1, 2, 3, 4)\n        outputs = layers.ZeroPadding2D(\n            padding=padding, data_format=data_format\n        )(inputs)\n\n        if data_format == \"channels_first\":\n            for index in [0, 1, -1, -2]:\n                self.assertAllClose(outputs[:, :, index, :], 0.0)\n              
  self.assertAllClose(outputs[:, :, :, index], 0.0)\n            self.assertAllClose(outputs[:, :, 2:-2, 2:-2], inputs)\n        else:\n            for index in [0, 1, -1, -2]:\n                self.assertAllClose(outputs[:, index, :, :], 0.0)\n                self.assertAllClose(outputs[:, :, index, :], 0.0)\n            self.assertAllClose(outputs[:, 2:-2, 2:-2, :], inputs)\n\n    def test_zero_padding_2d_with_dynamic_spatial_dim(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_layer = layers.Input(batch_shape=(1, 2, None, 4))\n        else:\n            input_layer = layers.Input(batch_shape=(1, 4, 2, None))\n        padded = layers.ZeroPadding2D(((1, 2), (3, 4)))(input_layer)\n        if backend.config.image_data_format() == \"channels_last\":\n            self.assertEqual(padded.shape, (1, 5, None, 4))\n        else:\n            self.assertEqual(padded.shape, (1, 4, 5, None))\n\n    @parameterized.parameters(\n        {\"padding\": (1,)},\n        {\"padding\": (1, 2, 3)},\n        {\"padding\": \"1\"},\n        {\"padding\": ((1, 2), (3, 4, 5))},\n        {\"padding\": ((1, 2), (3, -4))},\n        {\"padding\": ((1, 2), \"3\")},\n    )\n    def test_zero_padding_2d_errors_if_padding_argument_invalid(self, padding):\n        with self.assertRaises(ValueError):\n            layers.ZeroPadding2D(padding)\n\n    @parameterized.parameters(\n        {\"data_format\": \"channels_first\"},\n        {\"data_format\": \"channels_last\"},\n    )\n    def test_zero_padding_2d_get_config(self, data_format):\n        layer = layers.ZeroPadding2D(padding=(1, 2), data_format=data_format)\n        expected_config = {\n            \"data_format\": data_format,\n            \"dtype\": dtype_policies.serialize(layer.dtype_policy),\n            \"name\": layer.name,\n            \"padding\": ((1, 1), (2, 2)),\n            \"trainable\": layer.trainable,\n        }\n        self.assertEqual(layer.get_config(), expected_config)\n"
  },
  {
    "path": "keras/src/layers/reshaping/zero_padding3d.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.utils import argument_validation\n\n\n@keras_export(\"keras.layers.ZeroPadding3D\")\nclass ZeroPadding3D(Layer):\n    \"\"\"Zero-padding layer for 3D data (spatial or spatio-temporal).\n\n    Example:\n\n    >>> input_shape = (1, 1, 2, 2, 3)\n    >>> x = np.arange(np.prod(input_shape)).reshape(input_shape)\n    >>> y = keras.layers.ZeroPadding3D(padding=2)(x)\n    >>> y.shape\n    (1, 5, 6, 6, 3)\n\n    Args:\n        padding: Int, or tuple of 3 ints, or tuple of 3 tuples of 2 ints.\n            - If int: the same symmetric padding is applied to depth, height,\n              and width.\n            - If tuple of 3 ints: interpreted as three different symmetric\n              padding values for depth, height, and width:\n              `(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad)`.\n            - If tuple of 3 tuples of 2 ints: interpreted as\n              `((left_dim1_pad, right_dim1_pad), (left_dim2_pad,\n              right_dim2_pad), (left_dim3_pad, right_dim3_pad))`.\n        data_format: A string, one of `\"channels_last\"` (default) or\n            `\"channels_first\"`. The ordering of the dimensions in the inputs.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.\n            When unspecified, uses `image_data_format` value found in your Keras\n            config file at `~/.keras/keras.json` (if exists). 
Defaults to\n            `\"channels_last\"`.\n\n    Input shape:\n        5D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n          `(batch_size, first_axis_to_pad, second_axis_to_pad,\n          third_axis_to_pad, depth)`\n        - If `data_format` is `\"channels_first\"`:\n          `(batch_size, depth, first_axis_to_pad, second_axis_to_pad,\n          third_axis_to_pad)`\n\n    Output shape:\n        5D tensor with shape:\n        - If `data_format` is `\"channels_last\"`:\n          `(batch_size, first_padded_axis, second_padded_axis,\n          third_padded_axis, depth)`\n        - If `data_format` is `\"channels_first\"`:\n          `(batch_size, depth, first_padded_axis, second_padded_axis,\n          third_padded_axis)`\n    \"\"\"\n\n    def __init__(\n        self, padding=((1, 1), (1, 1), (1, 1)), data_format=None, **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.data_format = backend.standardize_data_format(data_format)\n        if isinstance(padding, int):\n            self.padding = (\n                (padding, padding),\n                (padding, padding),\n                (padding, padding),\n            )\n        elif hasattr(padding, \"__len__\"):\n            if len(padding) != 3:\n                raise ValueError(\n                    f\"`padding` should have 3 elements.
 Received: {padding}.\"\n                )\n            dim1_padding = argument_validation.standardize_tuple(\n                padding[0], 2, \"1st entry of padding\", allow_zero=True\n            )\n            dim2_padding = argument_validation.standardize_tuple(\n                padding[1], 2, \"2nd entry of padding\", allow_zero=True\n            )\n            dim3_padding = argument_validation.standardize_tuple(\n                padding[2], 2, \"3rd entry of padding\", allow_zero=True\n            )\n            self.padding = (dim1_padding, dim2_padding, dim3_padding)\n        else:\n            raise ValueError(\n                \"`padding` should be either an int, a tuple of 3 ints \"\n                \"(symmetric_dim1_pad, symmetric_dim2_pad, symmetric_dim3_pad), \"\n                \"or a tuple of 3 tuples of 2 ints \"\n                \"((left_dim1_pad, right_dim1_pad),\"\n                \" (left_dim2_pad, right_dim2_pad),\"\n                \" (left_dim3_pad, right_dim3_pad)).
 \"\n                f\"Received: padding={padding}.\"\n            )\n        self.input_spec = InputSpec(ndim=5)\n\n    def compute_output_shape(self, input_shape):\n        output_shape = list(input_shape)\n        spatial_dims_offset = 2 if self.data_format == \"channels_first\" else 1\n        for index in range(0, 3):\n            if output_shape[index + spatial_dims_offset] is not None:\n                output_shape[index + spatial_dims_offset] += (\n                    self.padding[index][0] + self.padding[index][1]\n                )\n        return tuple(output_shape)\n\n    def call(self, inputs):\n        if self.data_format == \"channels_first\":\n            all_dims_padding = ((0, 0), (0, 0), *self.padding)\n        else:\n            all_dims_padding = ((0, 0), *self.padding, (0, 0))\n        return ops.pad(inputs, all_dims_padding)\n\n    def get_config(self):\n        config = {\"padding\": self.padding, \"data_format\": self.data_format}\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/layers/reshaping/zero_padding3d_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import dtype_policies\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass ZeroPadding3DTest(testing.TestCase):\n    @parameterized.parameters(\n        {\"data_format\": \"channels_first\"}, {\"data_format\": \"channels_last\"}\n    )\n    def test_zero_padding_3d(self, data_format):\n        inputs = np.random.rand(1, 2, 3, 4, 5)\n        outputs = layers.ZeroPadding3D(\n            padding=((1, 2), (3, 4), (0, 2)), data_format=data_format\n        )(inputs)\n\n        if data_format == \"channels_first\":\n            for index in [0, -1, -2]:\n                self.assertAllClose(outputs[:, :, index, :, :], 0.0)\n            for index in [0, 1, 2, -1, -2, -3, -4]:\n                self.assertAllClose(outputs[:, :, :, index, :], 0.0)\n            for index in [-1, -2]:\n                self.assertAllClose(outputs[:, :, :, :, index], 0.0)\n            self.assertAllClose(outputs[:, :, 1:-2, 3:-4, 0:-2], inputs)\n        else:\n            for index in [0, -1, -2]:\n                self.assertAllClose(outputs[:, index, :, :, :], 0.0)\n            for index in [0, 1, 2, -1, -2, -3, -4]:\n                self.assertAllClose(outputs[:, :, index, :, :], 0.0)\n            for index in [-1, -2]:\n                self.assertAllClose(outputs[:, :, :, index, :], 0.0)\n            self.assertAllClose(outputs[:, 1:-2, 3:-4, 0:-2, :], inputs)\n\n    @parameterized.product(\n        (\n            {\"padding\": ((2, 2), (2, 2), (2, 2))},  # 3 tuples\n            {\"padding\": (2, 2, 2)},  # 1 tuple\n            {\"padding\": 2},  # 1 int\n        ),\n        (\n            {\"data_format\": \"channels_first\"},\n            {\"data_format\": \"channels_last\"},\n        ),\n    )\n    def test_zero_padding_3d_with_same_padding(self, padding, data_format):\n        inputs = np.random.rand(1, 2, 3, 4, 5)\n        outputs = 
layers.ZeroPadding3D(\n            padding=padding, data_format=data_format\n        )(inputs)\n\n        if data_format == \"channels_first\":\n            for index in [0, 1, -1, -2]:\n                self.assertAllClose(outputs[:, :, index, :, :], 0.0)\n                self.assertAllClose(outputs[:, :, :, index, :], 0.0)\n                self.assertAllClose(outputs[:, :, :, :, index], 0.0)\n            self.assertAllClose(outputs[:, :, 2:-2, 2:-2, 2:-2], inputs)\n        else:\n            for index in [0, 1, -1, -2]:\n                self.assertAllClose(outputs[:, index, :, :, :], 0.0)\n                self.assertAllClose(outputs[:, :, index, :, :], 0.0)\n                self.assertAllClose(outputs[:, :, :, index, :], 0.0)\n            self.assertAllClose(outputs[:, 2:-2, 2:-2, 2:-2, :], inputs)\n\n    def test_zero_padding_3d_with_dynamic_spatial_dim(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_layer = layers.Input(batch_shape=(1, 2, None, 4, 5))\n        else:\n            input_layer = layers.Input(batch_shape=(1, 5, 2, None, 4))\n        padded = layers.ZeroPadding3D(((1, 2), (3, 4), (5, 6)))(input_layer)\n        if backend.config.image_data_format() == \"channels_last\":\n            self.assertEqual(padded.shape, (1, 5, None, 15, 5))\n        else:\n            self.assertEqual(padded.shape, (1, 5, 5, None, 15))\n\n    @parameterized.parameters(\n        {\"padding\": (1,)},\n        {\"padding\": (1, 2)},\n        {\"padding\": (1, 2, 3, 4)},\n        {\"padding\": \"1\"},\n        {\"padding\": ((1, 2), (3, 4), (5, 6, 7))},\n        {\"padding\": ((1, 2), (3, 4), (5, -6))},\n        {\"padding\": ((1, 2), (3, 4), \"5\")},\n    )\n    def test_zero_padding_3d_errors_if_padding_argument_invalid(self, padding):\n        with self.assertRaises(ValueError):\n            layers.ZeroPadding3D(padding=padding)\n\n    @parameterized.parameters(\n        {\"data_format\": \"channels_first\"},\n        
{\"data_format\": \"channels_last\"},\n    )\n    def test_zero_padding_3d_get_config(self, data_format):\n        layer = layers.ZeroPadding3D(padding=(1, 2, 3), data_format=data_format)\n        expected_config = {\n            \"data_format\": data_format,\n            \"dtype\": dtype_policies.serialize(layer.dtype_policy),\n            \"name\": layer.name,\n            \"padding\": ((1, 1), (2, 2), (3, 3)),\n            \"trainable\": layer.trainable,\n        }\n        self.assertEqual(layer.get_config(), expected_config)\n"
  },
  {
    "path": "keras/src/layers/rnn/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/layers/rnn/bidirectional.py",
    "content": "import copy\n\nfrom keras.src import ops\nfrom keras.src import utils\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.layers.Bidirectional\")\nclass Bidirectional(Layer):\n    \"\"\"Bidirectional wrapper for RNNs.\n\n    Args:\n        layer: `keras.layers.RNN` instance, such as\n            `keras.layers.LSTM` or `keras.layers.GRU`.\n            It could also be a `keras.layers.Layer` instance\n            that meets the following criteria:\n            1. Be a sequence-processing layer (accepts 3D+ inputs).\n            2. Have a `go_backwards`, `return_sequences` and `return_state`\n            attribute (with the same semantics as for the `RNN` class).\n            3. Have an `input_spec` attribute.\n            4. Implement serialization via `get_config()` and `from_config()`.\n            Note that the recommended way to create new RNN layers is to write a\n            custom RNN cell and use it with `keras.layers.RNN`, instead of\n            subclassing `keras.layers.Layer` directly.\n            When `return_sequences` is `True`, the output of the masked\n            timestep will be zero regardless of the layer's original\n            `zero_output_for_mask` value.\n        merge_mode: Mode by which outputs of the forward and backward RNNs\n            will be combined. One of `{\"sum\", \"mul\", \"concat\", \"ave\", None}`.\n            If `None`, the outputs will not be combined,\n            they will be returned as a list. 
Defaults to `\"concat\"`.\n        backward_layer: Optional `keras.layers.RNN`,\n            or `keras.layers.Layer` instance to be used to handle\n            backwards input processing.\n            If `backward_layer` is not provided, the layer instance passed\n            as the `layer` argument will be used to generate the backward layer\n            automatically.\n            Note that the provided `backward_layer` layer should have properties\n            matching those of the `layer` argument, in particular\n            it should have the same values for `stateful`, `return_state`,\n            `return_sequences`, etc. In addition, `backward_layer`\n            and `layer` should have different `go_backwards` argument values.\n            A `ValueError` will be raised if these requirements are not met.\n\n    Call arguments:\n        The call arguments for this layer are the same as those of the\n        wrapped RNN layer. Beware that when passing the `initial_state`\n        argument during the call of this layer, the first half in the\n        list of elements in the `initial_state` list will be passed to\n        the forward RNN call and the last half in the list of elements\n        will be passed to the backward RNN call.\n\n    Note: instantiating a `Bidirectional` layer from an existing RNN layer\n    instance will not reuse the weights state of the RNN layer instance -- the\n    `Bidirectional` layer will have freshly initialized weights.\n\n    Examples:\n\n    ```python\n    model = Sequential([\n        Input(shape=(5, 10)),\n        Bidirectional(LSTM(10, return_sequences=True)),\n        Bidirectional(LSTM(10)),\n        Dense(5, activation=\"softmax\"),\n    ])\n    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')\n\n    # With custom backward layer\n    forward_layer = LSTM(10, return_sequences=True)\n    backward_layer = LSTM(10, activation='relu', return_sequences=True,\n                          go_backwards=True)\n    
model = Sequential([\n        Input(shape=(5, 10)),\n        Bidirectional(forward_layer, backward_layer=backward_layer),\n        Dense(5, activation=\"softmax\"),\n    ])\n    model.compile(loss='categorical_crossentropy', optimizer='rmsprop')\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        layer,\n        merge_mode=\"concat\",\n        weights=None,\n        backward_layer=None,\n        **kwargs,\n    ):\n        if not isinstance(layer, Layer):\n            raise ValueError(\n                \"Please initialize `Bidirectional` layer with a \"\n                f\"`keras.layers.Layer` instance. Received: {layer}\"\n            )\n        if backward_layer is not None and not isinstance(backward_layer, Layer):\n            raise ValueError(\n                \"`backward_layer` needs to be a `keras.layers.Layer` \"\n                f\"instance. Received: {backward_layer}\"\n            )\n        if merge_mode not in [\"sum\", \"mul\", \"ave\", \"concat\", None]:\n            raise ValueError(\n                f\"Invalid merge mode. Received: {merge_mode}. 
\"\n                \"Merge mode should be one of \"\n                '{\"sum\", \"mul\", \"ave\", \"concat\", None}'\n            )\n        super().__init__(**kwargs)\n\n        # Recreate the forward layer from the original layer config, so that it\n        # will not carry over any state from the layer.\n        config = serialization_lib.serialize_keras_object(layer)\n        config[\"config\"][\"name\"] = (\n            f\"forward_{utils.removeprefix(layer.name, 'forward_')}\"\n        )\n        self.forward_layer = serialization_lib.deserialize_keras_object(config)\n\n        if backward_layer is None:\n            config = serialization_lib.serialize_keras_object(layer)\n            config[\"config\"][\"go_backwards\"] = True\n            config[\"config\"][\"name\"] = (\n                f\"backward_{utils.removeprefix(layer.name, 'backward_')}\"\n            )\n            self.backward_layer = serialization_lib.deserialize_keras_object(\n                config\n            )\n        else:\n            self.backward_layer = backward_layer\n        # Keep the use_cudnn attribute if defined (not serialized).\n        if hasattr(layer, \"use_cudnn\"):\n            self.forward_layer.use_cudnn = layer.use_cudnn\n            self.backward_layer.use_cudnn = layer.use_cudnn\n        self._verify_layer_config()\n\n        def force_zero_output_for_mask(layer):\n            # Force the zero_output_for_mask to be True if returning sequences.\n            if getattr(layer, \"zero_output_for_mask\", None) is not None:\n                layer.zero_output_for_mask = layer.return_sequences\n\n        force_zero_output_for_mask(self.forward_layer)\n        force_zero_output_for_mask(self.backward_layer)\n\n        self.merge_mode = merge_mode\n        if weights:\n            nw = len(weights)\n            self.forward_layer.initial_weights = weights[: nw // 2]\n            self.backward_layer.initial_weights = weights[nw // 2 :]\n        self.stateful = layer.stateful\n 
       self.return_sequences = layer.return_sequences\n        self.return_state = layer.return_state\n        self.supports_masking = True\n        self.input_spec = layer.input_spec\n\n    def _verify_layer_config(self):\n        \"\"\"Ensure the forward and backward layers have compatible properties.\"\"\"\n        if self.forward_layer.go_backwards == self.backward_layer.go_backwards:\n            raise ValueError(\n                \"Forward layer and backward layer should have different \"\n                \"`go_backwards` value. Received: \"\n                \"forward_layer.go_backwards=\"\n                f\"{self.forward_layer.go_backwards}, \"\n                \"backward_layer.go_backwards=\"\n                f\"{self.backward_layer.go_backwards}\"\n            )\n\n        common_attributes = (\"stateful\", \"return_sequences\", \"return_state\")\n        for a in common_attributes:\n            forward_value = getattr(self.forward_layer, a)\n            backward_value = getattr(self.backward_layer, a)\n            if forward_value != backward_value:\n                raise ValueError(\n                    \"Forward layer and backward layer are expected to have \"\n                    f'the same value for attribute \"{a}\", got '\n                    f'\"{forward_value}\" for forward layer and '\n                    f'\"{backward_value}\" for backward layer'\n                )\n\n    def compute_output_shape(self, sequences_shape, initial_state_shape=None):\n        output_shape = self.forward_layer.compute_output_shape(sequences_shape)\n\n        if self.return_state:\n            output_shape, state_shape = output_shape[0], output_shape[1:]\n\n        if self.merge_mode == \"concat\":\n            output_shape = list(output_shape)\n            output_shape[-1] *= 2\n            output_shape = tuple(output_shape)\n        elif self.merge_mode is None:\n            output_shape = [output_shape, output_shape]\n\n        if self.return_state:\n            if
 self.merge_mode is None:\n                return tuple(output_shape) + state_shape + state_shape\n            return (output_shape,) + state_shape + state_shape\n        return tuple(output_shape)\n\n    def call(\n        self,\n        sequences,\n        initial_state=None,\n        mask=None,\n        training=None,\n    ):\n        kwargs = {}\n        if self.forward_layer._call_has_training_arg:\n            kwargs[\"training\"] = training\n        if self.forward_layer._call_has_mask_arg:\n            kwargs[\"mask\"] = mask\n\n        if initial_state is not None:\n            # initial_states are not keras tensors, e.g. eager tensors from np\n            # arrays. They are only passed in from kwarg initial_state, and\n            # should be passed to forward/backward layer via kwarg\n            # initial_state as well.\n            forward_inputs, backward_inputs = sequences, sequences\n            half = len(initial_state) // 2\n            forward_state = initial_state[:half]\n            backward_state = initial_state[half:]\n        else:\n            forward_inputs, backward_inputs = sequences, sequences\n            forward_state, backward_state = None, None\n\n        y = self.forward_layer(\n            forward_inputs, initial_state=forward_state, **kwargs\n        )\n        y_rev = self.backward_layer(\n            backward_inputs, initial_state=backward_state, **kwargs\n        )\n\n        if self.return_state:\n            states = tuple(y[1:] + y_rev[1:])\n            y = y[0]\n            y_rev = y_rev[0]\n\n        y = ops.cast(y, self.compute_dtype)\n        y_rev = ops.cast(y_rev, self.compute_dtype)\n\n        if self.return_sequences:\n            y_rev = ops.flip(y_rev, axis=1)\n        if self.merge_mode == \"concat\":\n            output = ops.concatenate([y, y_rev], axis=-1)\n        elif self.merge_mode == \"sum\":\n            output = y + y_rev\n        elif self.merge_mode == \"ave\":\n            output = (y + y_rev) 
/ 2\n        elif self.merge_mode == \"mul\":\n            output = y * y_rev\n        elif self.merge_mode is None:\n            output = (y, y_rev)\n        else:\n            raise ValueError(\n                \"Unrecognized value for `merge_mode`. \"\n                f\"Received: {self.merge_mode}. \"\n                'Expected one of {\"concat\", \"sum\", \"ave\", \"mul\"}.'\n            )\n        if self.return_state:\n            if self.merge_mode is None:\n                return output + states\n            return (output,) + states\n        return output\n\n    def reset_states(self):\n        # Compatibility alias.\n        self.reset_state()\n\n    def reset_state(self):\n        if not self.stateful:\n            raise AttributeError(\"Layer must be stateful.\")\n        self.forward_layer.reset_state()\n        self.backward_layer.reset_state()\n\n    @property\n    def states(self):\n        if self.forward_layer.states and self.backward_layer.states:\n            return tuple(self.forward_layer.states + self.backward_layer.states)\n        return None\n\n    def build(self, sequences_shape, initial_state_shape=None):\n        if not self.forward_layer.built:\n            self.forward_layer.build(sequences_shape)\n        if not self.backward_layer.built:\n            self.backward_layer.build(sequences_shape)\n\n    def compute_mask(self, _, mask):\n        if isinstance(mask, list):\n            mask = mask[0]\n        if self.return_sequences:\n            if not self.merge_mode:\n                output_mask = (mask, mask)\n            else:\n                output_mask = mask\n        else:\n            output_mask = (None, None) if not self.merge_mode else None\n\n        if self.return_state and self.states is not None:\n            # One mask entry per state; `self.states` already covers both\n            # the forward and backward layers.\n            state_mask = tuple(None for _ in self.states)\n            if isinstance(output_mask, list):\n                return output_mask + list(state_mask)\n            return (output_mask,) + state_mask\n        return output_mask
\n\n    def get_config(self):\n        config = {\"merge_mode\": self.merge_mode}\n        config[\"layer\"] = serialization_lib.serialize_keras_object(\n            self.forward_layer\n        )\n        config[\"backward_layer\"] = serialization_lib.serialize_keras_object(\n            self.backward_layer\n        )\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config, custom_objects=None):\n        # Instead of updating the input, create a copy and use that.\n        config = copy.deepcopy(config)\n\n        config[\"layer\"] = serialization_lib.deserialize_keras_object(\n            config[\"layer\"], custom_objects=custom_objects\n        )\n        # Handle (optional) backward layer instantiation.\n        backward_layer_config = config.pop(\"backward_layer\", None)\n        if backward_layer_config is not None:\n            backward_layer = serialization_lib.deserialize_keras_object(\n                backward_layer_config, custom_objects=custom_objects\n            )\n            config[\"backward_layer\"] = backward_layer\n        # Instantiate the wrapper, adjust it and return it.\n        layer = cls(**config)\n        return layer\n"
  },
  {
    "path": "keras/src/layers/rnn/bidirectional_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass BidirectionalTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basics(self):\n        self.run_layer_test(\n            layers.Bidirectional,\n            init_kwargs={\"layer\": layers.SimpleRNN(4)},\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 8),\n            expected_num_trainable_weights=6,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.Bidirectional,\n            init_kwargs={\n                \"layer\": layers.SimpleRNN(4),\n                \"backward_layer\": layers.SimpleRNN(4, go_backwards=True),\n                \"merge_mode\": \"sum\",\n            },\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 4),\n            expected_num_trainable_weights=6,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n\n    def test_correctness(self):\n        sequence = np.arange(24).reshape((2, 3, 4)).astype(\"float32\")\n        forward_layer = layers.SimpleRNN(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        layer = layers.Bidirectional(\n            layer=forward_layer,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.39687276, 0.39687276, 0.10004295, 0.10004295],\n                    [0.7237238, 0.7237238, 0.53391594, 0.53391594],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.Bidirectional(layer=forward_layer, 
merge_mode=\"ave\")\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array([[0.24845785, 0.24845785], [0.6288199, 0.6288199]]),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.Bidirectional(layer=forward_layer, merge_mode=None)\n        output1, output2 = layer(sequence)\n        self.assertAllClose(\n            np.array([[0.39687276, 0.39687276], [0.7237238, 0.7237238]]),\n            output1,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            np.array([[0.10004295, 0.10004295], [0.53391594, 0.53391594]]),\n            output2,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        backward_layer = layers.SimpleRNN(\n            2,\n            kernel_initializer=initializers.Constant(0.03),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.01),\n            go_backwards=True,\n        )\n        layer = layers.Bidirectional(\n            layer=forward_layer, backward_layer=backward_layer, merge_mode=\"mul\"\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array([[0.08374989, 0.08374989], [0.6740834, 0.6740834]]),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        forward_layer = layers.GRU(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            return_sequences=True,\n        )\n        layer = layers.Bidirectional(layer=forward_layer, merge_mode=\"sum\")\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [\n                        [0.20937867, 0.20937867],\n                        [0.34462988, 
0.34462988],\n                        [0.40290534, 0.40290534],\n                    ],\n                    [\n                        [0.59829646, 0.59829646],\n                        [0.6734641, 0.6734641],\n                        [0.6479671, 0.6479671],\n                    ],\n                ]\n            ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_statefulness(self):\n        sequence = np.arange(24).reshape((2, 4, 3)).astype(\"float32\")\n        forward_layer = layers.LSTM(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            stateful=True,\n        )\n        layer = layers.Bidirectional(layer=forward_layer)\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.26234663, 0.26234663, 0.16959146, 0.16959146],\n                    [0.6137073, 0.6137073, 0.5381646, 0.5381646],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        layer.reset_state()\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.26234663, 0.26234663, 0.16959146, 0.16959146],\n                    [0.6137073, 0.6137073, 0.5381646, 0.5381646],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_pass_initial_state(self):\n        sequence = np.arange(24).reshape((2, 4, 3)).astype(\"float32\")\n        initial_state = [\n            np.arange(4).reshape((2, 2)).astype(\"float32\") * 1,\n            np.arange(4).reshape((2, 2)).astype(\"float32\") * 2,\n          
  np.arange(4).reshape((2, 2)).astype(\"float32\") * 3,\n            np.arange(4).reshape((2, 2)).astype(\"float32\") * 4,\n        ]\n        forward_layer = layers.LSTM(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        layer = layers.Bidirectional(\n            layer=forward_layer,\n        )\n        output = layer(sequence, initial_state=initial_state)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.20794602, 0.4577124, 0.14046375, 0.48191673],\n                    [0.6682636, 0.6711909, 0.60943645, 0.60950446],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_masking(self):\n        sequence = np.arange(24).reshape((2, 4, 3)).astype(\"float32\")\n        forward_layer = layers.GRU(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        layer = layers.Bidirectional(layer=forward_layer)\n        mask = np.array([[True, True, False, True], [True, False, False, True]])\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.19393763, 0.19393763, 0.11669192, 0.11669192],\n                    [0.30818558, 0.30818558, 0.28380975, 0.28380975],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_return_state(self):\n        sequence = np.arange(24).reshape((2, 4, 3)).astype(\"float32\")\n        forward_layer = layers.LSTM(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            
recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            return_state=True,\n        )\n        layer = layers.Bidirectional(layer=forward_layer)\n        output, h1, c1, h2, c2 = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.1990008, 0.1990008, 0.12659755, 0.12659755],\n                    [0.52335435, 0.52335435, 0.44717982, 0.44717982],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            np.array([[0.1990008, 0.1990008], [0.52335435, 0.52335435]]),\n            h1,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            np.array([[0.35567185, 0.35567185], [1.0492687, 1.0492687]]),\n            c1,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            np.array([[0.12659755, 0.12659755], [0.44717982, 0.44717982]]),\n            h2,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            np.array([[0.2501858, 0.2501858], [0.941473, 0.941473]]),\n            c2,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_output_shape(self):\n        x = np.array([[[101, 202], [303, 404]]])\n        for merge_mode in [\"ave\", \"concat\", \"mul\", \"sum\", None]:\n            sub_layer = layers.LSTM(2, return_state=True)\n            layer = layers.Bidirectional(sub_layer, merge_mode=merge_mode)\n            output = layer(x)\n            output_shape = layer.compute_output_shape(x.shape)\n            for out, shape in zip(output, output_shape):\n                self.assertEqual(out.shape, shape)\n\n        for merge_mode in [\"concat\", \"ave\", \"mul\", \"sum\"]:\n            sub_layer = 
layers.LSTM(2, return_state=False)\n            layer = layers.Bidirectional(sub_layer, merge_mode=merge_mode)\n            output = layer(x)\n            output_shape = layer.compute_output_shape(x.shape)\n            self.assertEqual(output.shape, output_shape)\n\n        # return_state=False & merge_mode=None\n        sub_layer = layers.LSTM(2, return_state=False)\n        layer = layers.Bidirectional(sub_layer, merge_mode=None)\n        output = layer(x)\n        output_shape = layer.compute_output_shape(x.shape)\n        for out, shape in zip(output, output_shape):\n            self.assertEqual(out.shape, shape)\n\n    def test_keeps_use_cudnn(self):\n        # keep use_cudnn if the layer has it\n        for rnn_class in [layers.GRU, layers.LSTM]:\n            for use_cudnn in [True, False, \"auto\"]:\n                rnn = rnn_class(1, use_cudnn=use_cudnn)\n                bidi = layers.Bidirectional(rnn)\n                self.assertEqual(bidi.forward_layer.use_cudnn, use_cudnn)\n                self.assertEqual(bidi.backward_layer.use_cudnn, use_cudnn)\n\n        # otherwise ignore it\n        rnn = layers.SimpleRNN(1)\n        bidi = layers.Bidirectional(rnn)\n        self.assertFalse(hasattr(bidi.forward_layer, \"use_cudnn\"))\n        self.assertFalse(hasattr(bidi.backward_layer, \"use_cudnn\"))\n"
  },
  {
    "path": "keras/src/layers/rnn/conv_lstm.py",
    "content": "from keras.src import activations\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src import tree\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell\nfrom keras.src.layers.rnn.rnn import RNN\nfrom keras.src.ops import operation_utils\nfrom keras.src.utils import argument_validation\n\n\nclass ConvLSTMCell(Layer, DropoutRNNCell):\n    \"\"\"Cell class for the ConvLSTM layer.\n\n    Args:\n        rank: Integer, rank of the convolution, e.g. \"2\" for 2D convolutions.\n        filters: Integer, the dimensionality of the output space\n            (i.e. the number of output filters in the convolution).\n        kernel_size: An integer or tuple/list of n integers, specifying the\n            dimensions of the convolution window.\n        strides: An integer or tuple/list of n integers, specifying the strides\n            of the convolution. Specifying any stride value != 1\n            is incompatible with specifying any `dilation_rate` value != 1.\n        padding: One of `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly\n            to the left/right or up/down of the input such that output\n            has the same height/width dimension as the input.\n        data_format: A string, one of `channels_last` (default) or\n            `channels_first`. 
When unspecified, uses\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json` (if exists) else 'channels_last'.\n            Defaults to `'channels_last'`.\n        dilation_rate: An integer or tuple/list of n integers, specifying the\n            dilation rate to use for dilated convolution.\n            Currently, specifying any `dilation_rate` value != 1 is\n            incompatible with specifying any `strides` value != 1.\n        activation: Activation function. If `None`, no activation is applied.\n        recurrent_activation: Activation function to use for the recurrent step.\n        use_bias: Boolean, (default `True`), whether the layer\n            should use a bias vector.\n        kernel_initializer: Initializer for the `kernel` weights matrix,\n            used for the linear transformation of the inputs. Default:\n            `\"glorot_uniform\"`.\n        recurrent_initializer: Initializer for the `recurrent_kernel`\n            weights matrix, used for the linear transformation of the recurrent\n            state. Default: `\"orthogonal\"`.\n        bias_initializer: Initializer for the bias vector. Default: `\"zeros\"`.\n        unit_forget_bias: Boolean (default `True`). If `True`,\n            add 1 to the bias of the forget gate at initialization.\n            Setting it to `True` will also force `bias_initializer=\"zeros\"`.\n            This is recommended in [Jozefowicz et al.](\n            https://github.com/mlresearch/v37/blob/gh-pages/jozefowicz15.pdf)\n        kernel_regularizer: Regularizer function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_regularizer: Regularizer function applied to the\n            `recurrent_kernel` weights matrix. 
Default: `None`.\n        bias_regularizer: Regularizer function applied to the bias vector.\n            Default: `None`.\n        activity_regularizer: Regularizer function applied to the output of the\n            layer (its \"activation\"). Default: `None`.\n        kernel_constraint: Constraint function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_constraint: Constraint function applied to the\n            `recurrent_kernel` weights matrix. Default: `None`.\n        bias_constraint: Constraint function applied to the bias vector.\n            Default: `None`.\n        dropout: Float between 0 and 1. Fraction of the units to drop for the\n            linear transformation of the inputs. Default: 0.\n        recurrent_dropout: Float between 0 and 1. Fraction of the units to drop\n            for the linear transformation of the recurrent state. Default: 0.\n        seed: Random seed for dropout.\n\n    Call arguments:\n        inputs: A (2+ `rank`)D tensor.\n        states:  List of state tensors corresponding to the previous timestep.\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode. 
Only relevant when `dropout` or\n            `recurrent_dropout` is used.\n    \"\"\"\n\n    def __init__(\n        self,\n        rank,\n        filters,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        activation=\"tanh\",\n        recurrent_activation=\"sigmoid\",\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        recurrent_initializer=\"orthogonal\",\n        bias_initializer=\"zeros\",\n        unit_forget_bias=True,\n        kernel_regularizer=None,\n        recurrent_regularizer=None,\n        bias_regularizer=None,\n        kernel_constraint=None,\n        recurrent_constraint=None,\n        bias_constraint=None,\n        dropout=0.0,\n        recurrent_dropout=0.0,\n        seed=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.seed = seed\n        self.seed_generator = backend.random.SeedGenerator(seed=seed)\n        self.rank = rank\n        if self.rank > 3:\n            raise ValueError(\n                f\"Rank {rank} convolutions are not currently \"\n                f\"implemented. 
Received: rank={rank}\"\n            )\n        self.filters = filters\n        self.kernel_size = argument_validation.standardize_tuple(\n            kernel_size, self.rank, \"kernel_size\"\n        )\n        self.strides = argument_validation.standardize_tuple(\n            strides, self.rank, \"strides\", allow_zero=True\n        )\n        self.padding = argument_validation.standardize_padding(padding)\n        self.data_format = backend.standardize_data_format(data_format)\n        self.dilation_rate = argument_validation.standardize_tuple(\n            dilation_rate, self.rank, \"dilation_rate\"\n        )\n        self.activation = activations.get(activation)\n        self.recurrent_activation = activations.get(recurrent_activation)\n        self.use_bias = use_bias\n\n        self.kernel_initializer = initializers.get(kernel_initializer)\n        self.recurrent_initializer = initializers.get(recurrent_initializer)\n        self.bias_initializer = initializers.get(bias_initializer)\n        self.unit_forget_bias = unit_forget_bias\n\n        self.kernel_regularizer = regularizers.get(kernel_regularizer)\n        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)\n        self.bias_regularizer = regularizers.get(bias_regularizer)\n\n        self.kernel_constraint = constraints.get(kernel_constraint)\n        self.recurrent_constraint = constraints.get(recurrent_constraint)\n        self.bias_constraint = constraints.get(bias_constraint)\n\n        self.dropout = min(1.0, max(0.0, dropout))\n        self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))\n        self.dropout_mask_count = 4\n        self.input_spec = InputSpec(ndim=rank + 2)\n        self.state_size = -1  # Custom, defined in methods\n\n    def build(self, inputs_shape, states_shape=None):\n        if self.data_format == \"channels_first\":\n            channel_axis = 1\n            self.spatial_dims = inputs_shape[2:]\n        else:\n            channel_axis = -1\n   
         self.spatial_dims = inputs_shape[1:-1]\n        if None in self.spatial_dims:\n            raise ValueError(\n                \"ConvLSTM layers only support static \"\n                \"input shapes for the spatial dimension. \"\n                f\"Received invalid input shape: input_shape={inputs_shape}\"\n            )\n        if inputs_shape[channel_axis] is None:\n            raise ValueError(\n                \"The channel dimension of the inputs (last axis) should be \"\n                \"defined. Found None. Full input shape received: \"\n                f\"input_shape={inputs_shape}\"\n            )\n        self.input_spec = InputSpec(\n            ndim=self.rank + 3, shape=(None,) + inputs_shape[1:]\n        )\n\n        input_dim = inputs_shape[channel_axis]\n        self.input_dim = input_dim\n        self.kernel_shape = self.kernel_size + (input_dim, self.filters * 4)\n        recurrent_kernel_shape = self.kernel_size + (\n            self.filters,\n            self.filters * 4,\n        )\n\n        self.kernel = self.add_weight(\n            shape=self.kernel_shape,\n            initializer=self.kernel_initializer,\n            name=\"kernel\",\n            regularizer=self.kernel_regularizer,\n            constraint=self.kernel_constraint,\n        )\n        self.recurrent_kernel = self.add_weight(\n            shape=recurrent_kernel_shape,\n            initializer=self.recurrent_initializer,\n            name=\"recurrent_kernel\",\n            regularizer=self.recurrent_regularizer,\n            constraint=self.recurrent_constraint,\n        )\n\n        if self.use_bias:\n            if self.unit_forget_bias:\n\n                def bias_initializer(_, *args, **kwargs):\n                    return ops.concatenate(\n                        [\n                            self.bias_initializer(\n                                (self.filters,), *args, **kwargs\n                            ),\n                            
initializers.get(\"ones\")(\n                                (self.filters,), *args, **kwargs\n                            ),\n                            self.bias_initializer(\n                                (self.filters * 2,), *args, **kwargs\n                            ),\n                        ]\n                    )\n\n            else:\n                bias_initializer = self.bias_initializer\n            self.bias = self.add_weight(\n                shape=(self.filters * 4,),\n                name=\"bias\",\n                initializer=bias_initializer,\n                regularizer=self.bias_regularizer,\n                constraint=self.bias_constraint,\n            )\n        else:\n            self.bias = None\n\n    def call(self, inputs, states, training=False):\n        h_tm1 = states[0]  # previous memory state\n        c_tm1 = states[1]  # previous carry state\n\n        if training and 0.0 < self.dropout < 1.0:\n            dp_mask = self.get_dropout_mask(inputs)\n            inputs_i = inputs * dp_mask[0]\n            inputs_f = inputs * dp_mask[1]\n            inputs_c = inputs * dp_mask[2]\n            inputs_o = inputs * dp_mask[3]\n        else:\n            inputs_i = inputs\n            inputs_f = inputs\n            inputs_c = inputs\n            inputs_o = inputs\n\n        if training and 0.0 < self.recurrent_dropout < 1.0:\n            rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)\n            h_tm1_i = h_tm1 * rec_dp_mask[0]\n            h_tm1_f = h_tm1 * rec_dp_mask[1]\n            h_tm1_c = h_tm1 * rec_dp_mask[2]\n            h_tm1_o = h_tm1 * rec_dp_mask[3]\n        else:\n            h_tm1_i = h_tm1\n            h_tm1_f = h_tm1\n            h_tm1_c = h_tm1\n            h_tm1_o = h_tm1\n\n        (kernel_i, kernel_f, kernel_c, kernel_o) = ops.split(\n            self.kernel, 4, axis=self.rank + 1\n        )\n        (\n            recurrent_kernel_i,\n            recurrent_kernel_f,\n            recurrent_kernel_c,\n      
      recurrent_kernel_o,\n        ) = ops.split(self.recurrent_kernel, 4, axis=self.rank + 1)\n\n        if self.use_bias:\n            bias_i, bias_f, bias_c, bias_o = ops.split(self.bias, 4)\n        else:\n            bias_i, bias_f, bias_c, bias_o = None, None, None, None\n\n        x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)\n        x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)\n        x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)\n        x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)\n\n        h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)\n        h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)\n        h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)\n        h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)\n\n        i = self.recurrent_activation(x_i + h_i)\n        f = self.recurrent_activation(x_f + h_f)\n        c = f * c_tm1 + i * self.activation(x_c + h_c)\n        o = self.recurrent_activation(x_o + h_o)\n        h = o * self.activation(c)\n        return h, [h, c]\n\n    def compute_output_shape(self, inputs_shape, states_shape=None):\n        conv_output_shape = operation_utils.compute_conv_output_shape(\n            inputs_shape,\n            self.filters,\n            self.kernel_size,\n            strides=self.strides,\n            padding=self.padding,\n            data_format=self.data_format,\n            dilation_rate=self.dilation_rate,\n        )\n        return conv_output_shape, [conv_output_shape, conv_output_shape]\n\n    def get_initial_state(self, batch_size=None):\n        if self.data_format == \"channels_last\":\n            input_shape = (batch_size,) + self.spatial_dims + (self.input_dim,)\n        else:\n            input_shape = (batch_size, self.input_dim) + self.spatial_dims\n        state_shape = self.compute_output_shape(input_shape)[0]\n        return [\n            
ops.zeros(state_shape, dtype=self.compute_dtype),\n            ops.zeros(state_shape, dtype=self.compute_dtype),\n        ]\n\n    def input_conv(self, x, w, b=None, padding=\"valid\"):\n        conv_out = ops.conv(\n            x,\n            w,\n            strides=self.strides,\n            padding=padding,\n            data_format=self.data_format,\n            dilation_rate=self.dilation_rate,\n        )\n        if b is not None:\n            if self.data_format == \"channels_last\":\n                bias_shape = (1,) * (self.rank + 1) + (self.filters,)\n            else:\n                bias_shape = (1, self.filters) + (1,) * self.rank\n            bias = ops.reshape(b, bias_shape)\n            conv_out += bias\n        return conv_out\n\n    def recurrent_conv(self, x, w):\n        strides = argument_validation.standardize_tuple(\n            1, self.rank, \"strides\", allow_zero=True\n        )\n        conv_out = ops.conv(\n            x, w, strides=strides, padding=\"same\", data_format=self.data_format\n        )\n        return conv_out\n\n    def get_config(self):\n        config = {\n            \"filters\": self.filters,\n            \"kernel_size\": self.kernel_size,\n            \"strides\": self.strides,\n            \"padding\": self.padding,\n            \"data_format\": self.data_format,\n            \"dilation_rate\": self.dilation_rate,\n            \"activation\": activations.serialize(self.activation),\n            \"recurrent_activation\": activations.serialize(\n                self.recurrent_activation\n            ),\n            \"use_bias\": self.use_bias,\n            \"kernel_initializer\": initializers.serialize(\n                self.kernel_initializer\n            ),\n            \"recurrent_initializer\": initializers.serialize(\n                self.recurrent_initializer\n            ),\n            \"bias_initializer\": initializers.serialize(self.bias_initializer),\n            \"unit_forget_bias\": 
self.unit_forget_bias,\n            \"kernel_regularizer\": regularizers.serialize(\n                self.kernel_regularizer\n            ),\n            \"recurrent_regularizer\": regularizers.serialize(\n                self.recurrent_regularizer\n            ),\n            \"bias_regularizer\": regularizers.serialize(self.bias_regularizer),\n            \"kernel_constraint\": constraints.serialize(self.kernel_constraint),\n            \"recurrent_constraint\": constraints.serialize(\n                self.recurrent_constraint\n            ),\n            \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            \"dropout\": self.dropout,\n            \"recurrent_dropout\": self.recurrent_dropout,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\nclass ConvLSTM(RNN):\n    \"\"\"Abstract N-D Convolutional LSTM layer (used as implementation base).\n\n    Similar to an LSTM layer, but the input transformations\n    and recurrent transformations are both convolutional.\n\n    Args:\n        rank: Integer, rank of the convolution, e.g. 2 for 2D convolutions.\n        filters: Integer, the dimensionality of the output space\n            (i.e. the number of output filters in the convolution).\n        kernel_size: An integer or tuple/list of n integers, specifying the\n            dimensions of the convolution window.\n        strides: An integer or tuple/list of n integers,\n            specifying the strides of the convolution.\n            Specifying any stride value != 1 is incompatible with specifying\n            any `dilation_rate` value != 1.\n        padding: One of `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. 
`\"same\"` results in padding evenly to\n            the left/right or up/down of the input such that output has the same\n            height/width dimension as the input.\n        data_format: A string,\n            one of `channels_last` (default) or `channels_first`.\n            The ordering of the dimensions in the inputs.\n            `channels_last` corresponds to inputs with shape\n            `(batch, time, ..., channels)`\n            while `channels_first` corresponds to\n            inputs with shape `(batch, time, channels, ...)`.\n            When unspecified, uses\n            `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json` (if exists) else 'channels_last'.\n            Defaults to `'channels_last'`.\n        dilation_rate: An integer or tuple/list of n integers, specifying\n            the dilation rate to use for dilated convolution.\n            Currently, specifying any `dilation_rate` value != 1 is\n            incompatible with specifying any `strides` value != 1.\n        activation: Activation function to use.\n            By default hyperbolic tangent activation function is applied\n            (`tanh(x)`).\n        recurrent_activation: Activation function to use\n            for the recurrent step.\n        use_bias: Boolean, whether the layer uses a bias vector.\n        kernel_initializer: Initializer for the `kernel` weights matrix,\n            used for the linear transformation of the inputs.\n        recurrent_initializer: Initializer for the `recurrent_kernel`\n            weights matrix,\n            used for the linear transformation of the recurrent state.\n        bias_initializer: Initializer for the bias vector.\n        unit_forget_bias: Boolean.\n            If True, add 1 to the bias of the forget gate at initialization.\n            Use in combination with `bias_initializer=\"zeros\"`.\n            This is recommended in [Jozefowicz et al., 2015](\n            
http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)\n        kernel_regularizer: Regularizer function applied to\n            the `kernel` weights matrix.\n        recurrent_regularizer: Regularizer function applied to\n            the `recurrent_kernel` weights matrix.\n        bias_regularizer: Regularizer function applied to the bias vector.\n        activity_regularizer: Regularizer function applied to the output of\n            the layer (its \"activation\").\n        kernel_constraint: Constraint function applied to\n            the `kernel` weights matrix.\n        recurrent_constraint: Constraint function applied to\n            the `recurrent_kernel` weights matrix.\n        bias_constraint: Constraint function applied to the bias vector.\n        dropout: Float between 0 and 1.\n            Fraction of the units to drop for\n            the linear transformation of the inputs.\n        recurrent_dropout: Float between 0 and 1.\n            Fraction of the units to drop for\n            the linear transformation of the recurrent state.\n        seed: Random seed for dropout.\n        return_sequences: Boolean. Whether to return the last output\n            in the output sequence, or the full sequence. (default False)\n        return_state: Boolean. Whether to return the last state\n            in addition to the output. (default False)\n        go_backwards: Boolean (default False).\n            If True, process the input sequence backwards.\n        stateful: Boolean (default False). 
If True, the last state\n            for each sample at index i in a batch will be used as initial\n            state for the sample of index i in the following batch.\n    \"\"\"\n\n    def __init__(\n        self,\n        rank,\n        filters,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        activation=\"tanh\",\n        recurrent_activation=\"sigmoid\",\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        recurrent_initializer=\"orthogonal\",\n        bias_initializer=\"zeros\",\n        unit_forget_bias=True,\n        kernel_regularizer=None,\n        recurrent_regularizer=None,\n        bias_regularizer=None,\n        kernel_constraint=None,\n        recurrent_constraint=None,\n        bias_constraint=None,\n        dropout=0.0,\n        recurrent_dropout=0.0,\n        seed=None,\n        return_sequences=False,\n        return_state=False,\n        go_backwards=False,\n        stateful=False,\n        **kwargs,\n    ):\n        cell = ConvLSTMCell(\n            rank=rank,\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            activation=activation,\n            recurrent_activation=recurrent_activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            recurrent_initializer=recurrent_initializer,\n            bias_initializer=bias_initializer,\n            unit_forget_bias=unit_forget_bias,\n            kernel_regularizer=kernel_regularizer,\n            recurrent_regularizer=recurrent_regularizer,\n            bias_regularizer=bias_regularizer,\n            kernel_constraint=kernel_constraint,\n            recurrent_constraint=recurrent_constraint,\n            bias_constraint=bias_constraint,\n            dropout=dropout,\n            
recurrent_dropout=recurrent_dropout,\n            seed=seed,\n            name=\"conv_lstm_cell\",\n            dtype=kwargs.get(\"dtype\"),\n        )\n        super().__init__(\n            cell,\n            return_sequences=return_sequences,\n            return_state=return_state,\n            go_backwards=go_backwards,\n            stateful=stateful,\n            **kwargs,\n        )\n        self.input_spec = InputSpec(ndim=rank + 3)\n\n    def call(self, sequences, initial_state=None, mask=None, training=False):\n        return super().call(\n            sequences, initial_state=initial_state, mask=mask, training=training\n        )\n\n    def compute_output_shape(self, sequences_shape, initial_state_shape=None):\n        batch_size = sequences_shape[0]\n        steps = sequences_shape[1]\n        step_shape = (batch_size,) + sequences_shape[2:]\n        state_shape = self.cell.compute_output_shape(step_shape)[0][1:]\n\n        if self.return_sequences:\n            output_shape = (\n                batch_size,\n                steps,\n            ) + state_shape\n        else:\n            output_shape = (batch_size,) + state_shape\n\n        if self.return_state:\n            batched_state_shape = (batch_size,) + state_shape\n            return output_shape, batched_state_shape, batched_state_shape\n        return output_shape\n\n    def compute_mask(self, _, mask):\n        mask = tree.flatten(mask)[0]\n        output_mask = mask if self.return_sequences else None\n        if self.return_state:\n            state_mask = [None, None]\n            return [output_mask] + state_mask\n        else:\n            return output_mask\n\n    @property\n    def filters(self):\n        return self.cell.filters\n\n    @property\n    def kernel_size(self):\n        return self.cell.kernel_size\n\n    @property\n    def strides(self):\n        return self.cell.strides\n\n    @property\n    def padding(self):\n        return self.cell.padding\n\n    @property\n    def 
data_format(self):\n        return self.cell.data_format\n\n    @property\n    def dilation_rate(self):\n        return self.cell.dilation_rate\n\n    @property\n    def activation(self):\n        return self.cell.activation\n\n    @property\n    def recurrent_activation(self):\n        return self.cell.recurrent_activation\n\n    @property\n    def use_bias(self):\n        return self.cell.use_bias\n\n    @property\n    def kernel_initializer(self):\n        return self.cell.kernel_initializer\n\n    @property\n    def recurrent_initializer(self):\n        return self.cell.recurrent_initializer\n\n    @property\n    def bias_initializer(self):\n        return self.cell.bias_initializer\n\n    @property\n    def unit_forget_bias(self):\n        return self.cell.unit_forget_bias\n\n    @property\n    def kernel_regularizer(self):\n        return self.cell.kernel_regularizer\n\n    @property\n    def recurrent_regularizer(self):\n        return self.cell.recurrent_regularizer\n\n    @property\n    def bias_regularizer(self):\n        return self.cell.bias_regularizer\n\n    @property\n    def kernel_constraint(self):\n        return self.cell.kernel_constraint\n\n    @property\n    def recurrent_constraint(self):\n        return self.cell.recurrent_constraint\n\n    @property\n    def bias_constraint(self):\n        return self.cell.bias_constraint\n\n    @property\n    def dropout(self):\n        return self.cell.dropout\n\n    @property\n    def recurrent_dropout(self):\n        return self.cell.recurrent_dropout\n\n    def get_config(self):\n        config = {\n            \"filters\": self.filters,\n            \"kernel_size\": self.kernel_size,\n            \"strides\": self.strides,\n            \"padding\": self.padding,\n            \"data_format\": self.data_format,\n            \"dilation_rate\": self.dilation_rate,\n            \"activation\": activations.serialize(self.activation),\n            \"recurrent_activation\": activations.serialize(\n            
    self.recurrent_activation\n            ),\n            \"use_bias\": self.use_bias,\n            \"kernel_initializer\": initializers.serialize(\n                self.kernel_initializer\n            ),\n            \"recurrent_initializer\": initializers.serialize(\n                self.recurrent_initializer\n            ),\n            \"bias_initializer\": initializers.serialize(self.bias_initializer),\n            \"unit_forget_bias\": self.unit_forget_bias,\n            \"kernel_regularizer\": regularizers.serialize(\n                self.kernel_regularizer\n            ),\n            \"recurrent_regularizer\": regularizers.serialize(\n                self.recurrent_regularizer\n            ),\n            \"bias_regularizer\": regularizers.serialize(self.bias_regularizer),\n            \"activity_regularizer\": regularizers.serialize(\n                self.activity_regularizer\n            ),\n            \"kernel_constraint\": constraints.serialize(self.kernel_constraint),\n            \"recurrent_constraint\": constraints.serialize(\n                self.recurrent_constraint\n            ),\n            \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            \"dropout\": self.dropout,\n            \"recurrent_dropout\": self.recurrent_dropout,\n            \"seed\": self.cell.seed,\n        }\n        base_config = super().get_config()\n        del base_config[\"cell\"]\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n"
  },
  {
    "path": "keras/src/layers/rnn/conv_lstm1d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.rnn.conv_lstm import ConvLSTM\n\n\n@keras_export(\"keras.layers.ConvLSTM1D\")\nclass ConvLSTM1D(ConvLSTM):\n    \"\"\"1D Convolutional LSTM.\n\n    Similar to an LSTM layer, but the input transformations\n    and recurrent transformations are both convolutional.\n\n    Args:\n        filters: int, the dimension of the output space (the number of filters\n            in the convolution).\n        kernel_size: int or tuple/list of 1 integer, specifying the size of\n            the convolution window.\n        strides: int or tuple/list of 1 integer, specifying the stride length\n            of the convolution. `strides > 1` is incompatible with\n            `dilation_rate > 1`.\n        padding: string, `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input such that output has the\n            same height/width dimension as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, time, rows, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, time, channels, rows)`. It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of 1 integer, specifying the dilation\n            rate to use for dilated convolution.\n        activation: Activation function to use. 
By default the hyperbolic tangent\n            activation function is applied (`tanh(x)`).\n        recurrent_activation: Activation function to use for the recurrent step.\n        use_bias: Boolean, whether the layer uses a bias vector.\n        kernel_initializer: Initializer for the `kernel` weights matrix,\n            used for the linear transformation of the inputs.\n        recurrent_initializer: Initializer for the `recurrent_kernel` weights\n            matrix, used for the linear transformation of the recurrent state.\n        bias_initializer: Initializer for the bias vector.\n        unit_forget_bias: Boolean. If `True`, add 1 to the bias of\n            the forget gate at initialization.\n            Use in combination with `bias_initializer=\"zeros\"`.\n            This is recommended in [Jozefowicz et al., 2015](\n            http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)\n        kernel_regularizer: Regularizer function applied to the `kernel` weights\n            matrix.\n        recurrent_regularizer: Regularizer function applied to the\n            `recurrent_kernel` weights matrix.\n        bias_regularizer: Regularizer function applied to the bias vector.\n        activity_regularizer: Regularizer function applied to the output of\n            the layer (its \"activation\").\n        kernel_constraint: Constraint function applied to the `kernel` weights\n            matrix.\n        recurrent_constraint: Constraint function applied to the\n            `recurrent_kernel` weights matrix.\n        bias_constraint: Constraint function applied to the bias vector.\n        dropout: Float between 0 and 1. Fraction of the units to drop for the\n            linear transformation of the inputs.\n        recurrent_dropout: Float between 0 and 1. Fraction of the units to drop\n            for the linear transformation of the recurrent state.\n        seed: Random seed for dropout.\n        return_sequences: Boolean. 
Whether to return the last output\n            in the output sequence, or the full sequence. Default: `False`.\n        return_state: Boolean. Whether to return the last state in addition\n            to the output. Default: `False`.\n        go_backwards: Boolean (default: `False`).\n            If `True`, process the input sequence backwards and return the\n            reversed sequence.\n        stateful: Boolean (default: `False`). If `True`, the last state\n            for each sample at index i in a batch will be used as initial\n            state for the sample of index i in the following batch.\n        unroll: Boolean (default: `False`).\n            If `True`, the network will be unrolled,\n            else a symbolic loop will be used.\n            Unrolling can speed up an RNN,\n            although it tends to be more memory-intensive.\n            Unrolling is only suitable for short sequences.\n\n\n    Call arguments:\n        inputs: A 4D tensor.\n        initial_state: List of initial state tensors to be passed to the first\n            call of the cell.\n        mask: Binary tensor of shape `(samples, timesteps)` indicating whether a\n            given timestep should be masked.\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode.\n            This is only relevant if `dropout` or `recurrent_dropout` are set.\n\n    Input shape:\n\n    - If `data_format=\"channels_first\"`:\n        4D tensor with shape: `(samples, time, channels, rows)`\n    - If `data_format=\"channels_last\"`:\n        4D tensor with shape: `(samples, time, rows, channels)`\n\n    Output shape:\n\n    - If `return_state`: a list of tensors. 
The first tensor is the output.\n        The remaining tensors are the last states,\n        each 3D tensor with shape: `(samples, filters, new_rows)` if\n        `data_format='channels_first'`\n        or shape: `(samples, new_rows, filters)` if\n        `data_format='channels_last'`.\n        `rows` values might have changed due to padding.\n    - If `return_sequences`: 4D tensor with shape: `(samples, timesteps,\n        filters, new_rows)` if `data_format='channels_first'`\n        or shape: `(samples, timesteps, new_rows, filters)` if\n        `data_format='channels_last'`.\n    - Else, 3D tensor with shape: `(samples, filters, new_rows)` if\n        `data_format='channels_first'`\n        or shape: `(samples, new_rows, filters)` if\n        `data_format='channels_last'`.\n\n    References:\n\n    - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1)\n        (the current implementation does not include the feedback loop on the\n        cell's output).\n    \"\"\"\n\n    def __init__(\n        self,\n        filters,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        activation=\"tanh\",\n        recurrent_activation=\"sigmoid\",\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        recurrent_initializer=\"orthogonal\",\n        bias_initializer=\"zeros\",\n        unit_forget_bias=True,\n        kernel_regularizer=None,\n        recurrent_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        recurrent_constraint=None,\n        bias_constraint=None,\n        dropout=0.0,\n        recurrent_dropout=0.0,\n        seed=None,\n        return_sequences=False,\n        return_state=False,\n        go_backwards=False,\n        stateful=False,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=1,\n            filters=filters,\n            kernel_size=kernel_size,\n      
      strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            activation=activation,\n            recurrent_activation=recurrent_activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            recurrent_initializer=recurrent_initializer,\n            bias_initializer=bias_initializer,\n            unit_forget_bias=unit_forget_bias,\n            kernel_regularizer=kernel_regularizer,\n            recurrent_regularizer=recurrent_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            kernel_constraint=kernel_constraint,\n            recurrent_constraint=recurrent_constraint,\n            bias_constraint=bias_constraint,\n            return_sequences=return_sequences,\n            return_state=return_state,\n            go_backwards=go_backwards,\n            stateful=stateful,\n            dropout=dropout,\n            recurrent_dropout=recurrent_dropout,\n            seed=seed,\n            **kwargs,\n        )\n"
  },
  {
    "path": "keras/src/layers/rnn/conv_lstm1d_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass ConvLSTM1DTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basics(self):\n        channels_last = backend.config.image_data_format() == \"channels_last\"\n        self.run_layer_test(\n            layers.ConvLSTM1D,\n            init_kwargs={\"filters\": 5, \"kernel_size\": 3, \"padding\": \"same\"},\n            input_shape=(3, 2, 4, 3) if channels_last else (3, 2, 3, 4),\n            expected_output_shape=(3, 4, 5) if channels_last else (3, 5, 4),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.ConvLSTM1D,\n            init_kwargs={\n                \"filters\": 5,\n                \"kernel_size\": 3,\n                \"padding\": \"valid\",\n                \"recurrent_dropout\": 0.5,\n            },\n            input_shape=(3, 2, 8, 3) if channels_last else (3, 2, 3, 8),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(3, 6, 5) if channels_last else (3, 5, 6),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.ConvLSTM1D,\n            init_kwargs={\n                \"filters\": 5,\n                \"kernel_size\": 3,\n                \"padding\": \"valid\",\n                \"return_sequences\": True,\n            },\n            input_shape=(3, 2, 8, 3) if channels_last else (3, 2, 3, 8),\n            expected_output_shape=(\n                (3, 2, 6, 5) if channels_last else (3, 2, 5, 6)\n            ),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            
supports_masking=True,\n        )\n\n    def test_correctness(self):\n        sequence = np.arange(120).reshape((2, 3, 4, 5)).astype(\"float32\") / 10\n        expected_output = np.array(\n            [\n                [[0.40807986, 0.40807986], [0.46421072, 0.46421072]],\n                [[0.80933154, 0.80933154], [0.8233646, 0.8233646]],\n            ]\n        )\n        if backend.config.image_data_format() == \"channels_first\":\n            sequence = sequence.transpose((0, 1, 3, 2))\n            expected_output = expected_output.transpose((0, 2, 1))\n        layer = layers.ConvLSTM1D(\n            filters=2,\n            kernel_size=3,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            expected_output,\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n"
  },
  {
    "path": "keras/src/layers/rnn/conv_lstm2d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.rnn.conv_lstm import ConvLSTM\n\n\n@keras_export(\"keras.layers.ConvLSTM2D\")\nclass ConvLSTM2D(ConvLSTM):\n    \"\"\"2D Convolutional LSTM.\n\n    Similar to an LSTM layer, but the input transformations\n    and recurrent transformations are both convolutional.\n\n    Args:\n        filters: int, the dimension of the output space (the number of filters\n            in the convolution).\n        kernel_size: int or tuple/list of 2 integers, specifying the size of the\n            convolution window.\n        strides: int or tuple/list of 2 integers, specifying the stride length\n            of the convolution. `strides > 1` is incompatible with\n            `dilation_rate > 1`.\n        padding: string, `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input such that output has the same\n            height/width dimension as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, time, rows, cols, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, time, channels, rows, cols)`. It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of 2 integers, specifying the dilation\n            rate to use for dilated convolution.\n        activation: Activation function to use. 
By default the hyperbolic tangent\n            activation function is applied (`tanh(x)`).\n        recurrent_activation: Activation function to use for the recurrent step.\n        use_bias: Boolean, whether the layer uses a bias vector.\n        kernel_initializer: Initializer for the `kernel` weights matrix,\n            used for the linear transformation of the inputs.\n        recurrent_initializer: Initializer for the `recurrent_kernel` weights\n            matrix, used for the linear transformation of the recurrent state.\n        bias_initializer: Initializer for the bias vector.\n        unit_forget_bias: Boolean. If `True`, add 1 to the bias of the forget\n            gate at initialization.\n            Use in combination with `bias_initializer=\"zeros\"`.\n            This is recommended in [Jozefowicz et al., 2015](\n            http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)\n        kernel_regularizer: Regularizer function applied to the `kernel` weights\n            matrix.\n        recurrent_regularizer: Regularizer function applied to the\n            `recurrent_kernel` weights matrix.\n        bias_regularizer: Regularizer function applied to the bias vector.\n        activity_regularizer: Regularizer function applied to the output of\n            the layer (its \"activation\").\n        kernel_constraint: Constraint function applied to the `kernel` weights\n            matrix.\n        recurrent_constraint: Constraint function applied to the\n            `recurrent_kernel` weights matrix.\n        bias_constraint: Constraint function applied to the bias vector.\n        dropout: Float between 0 and 1. Fraction of the units to drop for the\n            linear transformation of the inputs.\n        recurrent_dropout: Float between 0 and 1. Fraction of the units to drop\n            for the linear transformation of the recurrent state.\n        seed: Random seed for dropout.\n        return_sequences: Boolean. 
Whether to return the last output\n            in the output sequence, or the full sequence. Default: `False`.\n        return_state: Boolean. Whether to return the last state in addition\n            to the output. Default: `False`.\n        go_backwards: Boolean (default: `False`).\n            If `True`, process the input sequence backwards and return the\n            reversed sequence.\n        stateful: Boolean (default: `False`). If `True`, the last state\n            for each sample at index i in a batch will be used as initial\n            state for the sample of index i in the following batch.\n        unroll: Boolean (default: `False`).\n            If `True`, the network will be unrolled,\n            else a symbolic loop will be used.\n            Unrolling can speed up an RNN,\n            although it tends to be more memory-intensive.\n            Unrolling is only suitable for short sequences.\n\n\n    Call arguments:\n        inputs: A 5D tensor.\n        mask: Binary tensor of shape `(samples, timesteps)` indicating whether a\n            given timestep should be masked.\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode.\n            This is only relevant if `dropout` or `recurrent_dropout` are set.\n        initial_state: List of initial state tensors to be passed to the first\n            call of the cell.\n\n    Input shape:\n\n    - If `data_format='channels_first'`:\n        5D tensor with shape: `(samples, time, channels, rows, cols)`\n    - If `data_format='channels_last'`:\n        5D tensor with shape: `(samples, time, rows, cols, channels)`\n\n    Output shape:\n\n    - If `return_state`: a list of tensors. 
The first tensor is the output.\n        The remaining tensors are the last states,\n        each 4D tensor with shape: `(samples, filters, new_rows, new_cols)` if\n        `data_format='channels_first'`\n        or shape: `(samples, new_rows, new_cols, filters)` if\n        `data_format='channels_last'`. `rows` and `cols` values might have\n        changed due to padding.\n    - If `return_sequences`: 5D tensor with shape: `(samples, timesteps,\n        filters, new_rows, new_cols)` if `data_format='channels_first'`\n        or shape: `(samples, timesteps, new_rows, new_cols, filters)` if\n        `data_format='channels_last'`.\n    - Else, 4D tensor with shape: `(samples, filters, new_rows, new_cols)` if\n        `data_format='channels_first'`\n        or shape: `(samples, new_rows, new_cols, filters)` if\n        `data_format='channels_last'`.\n\n    References:\n\n    - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1)\n        (the current implementation does not include the feedback loop on the\n        cell's output).\n    \"\"\"\n\n    def __init__(\n        self,\n        filters,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        activation=\"tanh\",\n        recurrent_activation=\"sigmoid\",\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        recurrent_initializer=\"orthogonal\",\n        bias_initializer=\"zeros\",\n        unit_forget_bias=True,\n        kernel_regularizer=None,\n        recurrent_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        recurrent_constraint=None,\n        bias_constraint=None,\n        dropout=0.0,\n        recurrent_dropout=0.0,\n        seed=None,\n        return_sequences=False,\n        return_state=False,\n        go_backwards=False,\n        stateful=False,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=2,\n   
         filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            activation=activation,\n            recurrent_activation=recurrent_activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            recurrent_initializer=recurrent_initializer,\n            bias_initializer=bias_initializer,\n            unit_forget_bias=unit_forget_bias,\n            kernel_regularizer=kernel_regularizer,\n            recurrent_regularizer=recurrent_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            kernel_constraint=kernel_constraint,\n            recurrent_constraint=recurrent_constraint,\n            bias_constraint=bias_constraint,\n            return_sequences=return_sequences,\n            return_state=return_state,\n            go_backwards=go_backwards,\n            stateful=stateful,\n            dropout=dropout,\n            recurrent_dropout=recurrent_dropout,\n            seed=seed,\n            **kwargs,\n        )\n"
  },
  {
    "path": "keras/src/layers/rnn/conv_lstm2d_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass ConvLSTM2DTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basics(self):\n        channels_last = backend.config.image_data_format() == \"channels_last\"\n        self.run_layer_test(\n            layers.ConvLSTM2D,\n            init_kwargs={\"filters\": 5, \"kernel_size\": 3, \"padding\": \"same\"},\n            input_shape=(3, 2, 4, 4, 3) if channels_last else (3, 2, 3, 4, 4),\n            expected_output_shape=(\n                (3, 4, 4, 5) if channels_last else (3, 5, 4, 4)\n            ),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.ConvLSTM2D,\n            init_kwargs={\n                \"filters\": 5,\n                \"kernel_size\": 3,\n                \"padding\": \"valid\",\n                \"recurrent_dropout\": 0.5,\n            },\n            input_shape=(3, 2, 8, 8, 3) if channels_last else (3, 2, 3, 8, 8),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(\n                (3, 6, 6, 5) if channels_last else (3, 5, 6, 6)\n            ),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.ConvLSTM2D,\n            init_kwargs={\n                \"filters\": 5,\n                \"kernel_size\": 3,\n                \"padding\": \"valid\",\n                \"return_sequences\": True,\n            },\n            input_shape=(3, 2, 8, 8, 3) if channels_last else (3, 2, 3, 8, 8),\n            expected_output_shape=(\n                (3, 2, 6, 6, 5) if channels_last else (3, 2, 5, 6, 6)\n            ),\n            
expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n\n    def test_correctness(self):\n        sequence = (\n            np.arange(480).reshape((2, 3, 4, 4, 5)).astype(\"float32\") / 100\n        )\n        expected_output = np.array(\n            [\n                [\n                    [[0.48694518, 0.48694518], [0.50237733, 0.50237733]],\n                    [[0.5461202, 0.5461202], [0.5598283, 0.5598283]],\n                ],\n                [\n                    [[0.8661607, 0.8661607], [0.86909103, 0.86909103]],\n                    [[0.8774414, 0.8774414], [0.8800861, 0.8800861]],\n                ],\n            ]\n        )\n        if backend.config.image_data_format() == \"channels_first\":\n            sequence = sequence.transpose((0, 1, 4, 2, 3))\n            expected_output = expected_output.transpose((0, 3, 1, 2))\n        layer = layers.ConvLSTM2D(\n            filters=2,\n            kernel_size=3,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            expected_output,\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n"
  },
  {
    "path": "keras/src/layers/rnn/conv_lstm3d.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.layers.rnn.conv_lstm import ConvLSTM\n\n\n@keras_export(\"keras.layers.ConvLSTM3D\")\nclass ConvLSTM3D(ConvLSTM):\n    \"\"\"3D Convolutional LSTM.\n\n    Similar to an LSTM layer, but the input transformations\n    and recurrent transformations are both convolutional.\n\n    Args:\n        filters: int, the dimension of the output space (the number of filters\n            in the convolution).\n        kernel_size: int or tuple/list of 3 integers, specifying the size of the\n            convolution window.\n        strides: int or tuple/list of 3 integers, specifying the stride length\n            of the convolution. `strides > 1` is incompatible with\n            `dilation_rate > 1`.\n        padding: string, `\"valid\"` or `\"same\"` (case-insensitive).\n            `\"valid\"` means no padding. `\"same\"` results in padding evenly to\n            the left/right or up/down of the input such that output has the same\n            height/width dimension as the input.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, time, *spatial_dims, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, time, channels, *spatial_dims)`. It defaults to the `image_data_format`\n            value found in your Keras config file at `~/.keras/keras.json`.\n            If you never set it, then it will be `\"channels_last\"`.\n        dilation_rate: int or tuple/list of 3 integers, specifying the dilation\n            rate to use for dilated convolution.\n        activation: Activation function to use. 
By default the hyperbolic tangent\n            activation function is applied (`tanh(x)`).\n        recurrent_activation: Activation function to use for the recurrent step.\n        use_bias: Boolean, whether the layer uses a bias vector.\n        kernel_initializer: Initializer for the `kernel` weights matrix,\n            used for the linear transformation of the inputs.\n        recurrent_initializer: Initializer for the `recurrent_kernel` weights\n            matrix, used for the linear transformation of the recurrent state.\n        bias_initializer: Initializer for the bias vector.\n        unit_forget_bias: Boolean. If `True`, add 1 to the bias of the forget\n            gate at initialization.\n            Use in combination with `bias_initializer=\"zeros\"`.\n            This is recommended in [Jozefowicz et al., 2015](\n            http://www.jmlr.org/proceedings/papers/v37/jozefowicz15.pdf)\n        kernel_regularizer: Regularizer function applied to the `kernel` weights\n            matrix.\n        recurrent_regularizer: Regularizer function applied to the\n            `recurrent_kernel` weights matrix.\n        bias_regularizer: Regularizer function applied to the bias vector.\n        activity_regularizer: Regularizer function applied to the output of\n            the layer (its \"activation\").\n        kernel_constraint: Constraint function applied to the `kernel` weights\n            matrix.\n        recurrent_constraint: Constraint function applied to the\n            `recurrent_kernel` weights matrix.\n        bias_constraint: Constraint function applied to the bias vector.\n        dropout: Float between 0 and 1. Fraction of the units to drop for the\n            linear transformation of the inputs.\n        recurrent_dropout: Float between 0 and 1. Fraction of the units to drop\n            for the linear transformation of the recurrent state.\n        seed: Random seed for dropout.\n        return_sequences: Boolean. 
Whether to return the last output\n            in the output sequence, or the full sequence. Default: `False`.\n        return_state: Boolean. Whether to return the last state in addition\n            to the output. Default: `False`.\n        go_backwards: Boolean (default: `False`).\n            If `True`, process the input sequence backwards and return the\n            reversed sequence.\n        stateful: Boolean (default False). If `True`, the last state\n            for each sample at index i in a batch will be used as initial\n            state for the sample of index i in the following batch.\n        unroll: Boolean (default: `False`).\n            If `True`, the network will be unrolled,\n            else a symbolic loop will be used.\n            Unrolling can speed-up a RNN,\n            although it tends to be more memory-intensive.\n            Unrolling is only suitable for short sequences.\n\n\n    Call arguments:\n        inputs: A 6D tensor.\n        mask: Binary tensor of shape `(samples, timesteps)` indicating whether a\n            given timestep should be masked.\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode.\n            This is only relevant if `dropout` or `recurrent_dropout` are set.\n        initial_state: List of initial state tensors to be passed to the first\n            call of the cell.\n\n    Input shape:\n\n    - If `data_format='channels_first'`:\n        5D tensor with shape: `(samples, time, channels, *spatial_dims)`\n    - If `data_format='channels_last'`:\n        5D tensor with shape: `(samples, time, *spatial_dims, channels)`\n\n    Output shape:\n\n    - If `return_state`: a list of tensors. 
The first tensor is the output.\n        The remaining tensors are the last states,\n        each 4D tensor with shape: `(samples, filters, *spatial_dims)` if\n        `data_format='channels_first'`\n        or shape: `(samples, *spatial_dims, filters)` if\n        `data_format='channels_last'`.\n    - If `return_sequences`: 5D tensor with shape: `(samples, timesteps,\n        filters, *spatial_dims)` if data_format='channels_first'\n        or shape: `(samples, timesteps, *spatial_dims, filters)` if\n        `data_format='channels_last'`.\n    - Else, 4D tensor with shape: `(samples, filters, *spatial_dims)` if\n        `data_format='channels_first'`\n        or shape: `(samples, *spatial_dims, filters)` if\n        `data_format='channels_last'`.\n\n    References:\n\n    - [Shi et al., 2015](http://arxiv.org/abs/1506.04214v1)\n        (the current implementation does not include the feedback loop on the\n        cells output).\n    \"\"\"\n\n    def __init__(\n        self,\n        filters,\n        kernel_size,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        activation=\"tanh\",\n        recurrent_activation=\"sigmoid\",\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        recurrent_initializer=\"orthogonal\",\n        bias_initializer=\"zeros\",\n        unit_forget_bias=True,\n        kernel_regularizer=None,\n        recurrent_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        recurrent_constraint=None,\n        bias_constraint=None,\n        dropout=0.0,\n        recurrent_dropout=0.0,\n        seed=None,\n        return_sequences=False,\n        return_state=False,\n        go_backwards=False,\n        stateful=False,\n        **kwargs,\n    ):\n        super().__init__(\n            rank=3,\n            filters=filters,\n            kernel_size=kernel_size,\n            strides=strides,\n    
        padding=padding,\n            data_format=data_format,\n            dilation_rate=dilation_rate,\n            activation=activation,\n            recurrent_activation=recurrent_activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            recurrent_initializer=recurrent_initializer,\n            bias_initializer=bias_initializer,\n            unit_forget_bias=unit_forget_bias,\n            kernel_regularizer=kernel_regularizer,\n            recurrent_regularizer=recurrent_regularizer,\n            bias_regularizer=bias_regularizer,\n            activity_regularizer=activity_regularizer,\n            kernel_constraint=kernel_constraint,\n            recurrent_constraint=recurrent_constraint,\n            bias_constraint=bias_constraint,\n            return_sequences=return_sequences,\n            return_state=return_state,\n            go_backwards=go_backwards,\n            stateful=stateful,\n            dropout=dropout,\n            recurrent_dropout=recurrent_dropout,\n            seed=seed,\n            **kwargs,\n        )\n"
  },
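The docstring above gives the output shapes for `"valid"` and `"same"` padding. The per-axis arithmetic behind those shapes can be sketched in plain Python (a minimal sketch of the standard convolution output-size rule; `conv_output_length` is a hypothetical helper written for illustration, not the Keras-internal utility):

```python
def conv_output_length(input_length, kernel_size, padding, stride=1, dilation=1):
    """Output size of one spatial axis after a convolution.

    Mirrors the usual rule: "same" preserves the axis length (before
    striding), "valid" shrinks it by the effective kernel size minus one.
    """
    effective_k = kernel_size + (kernel_size - 1) * (dilation - 1)
    if padding == "same":
        length = input_length
    elif padding == "valid":
        length = input_length - effective_k + 1
    else:
        raise ValueError(f"Unknown padding: {padding}")
    # Ceiling division accounts for strides > 1.
    return (length + stride - 1) // stride
```

This reproduces the shapes exercised in `conv_lstm3d_test.py`: an axis of 8 with `kernel_size=3` and `padding="valid"` yields 6, while `padding="same"` keeps an axis of 4 at 4.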
  {
    "path": "keras/src/layers/rnn/conv_lstm3d_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass ConvLSTM1DTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basics(self):\n        channels_last = backend.config.image_data_format() == \"channels_last\"\n        self.run_layer_test(\n            layers.ConvLSTM3D,\n            init_kwargs={\"filters\": 5, \"kernel_size\": 3, \"padding\": \"same\"},\n            input_shape=(\n                (3, 2, 4, 4, 4, 3) if channels_last else (3, 2, 3, 4, 4, 4)\n            ),\n            expected_output_shape=(\n                (3, 4, 4, 4, 5) if channels_last else (3, 5, 4, 4, 4)\n            ),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.ConvLSTM3D,\n            init_kwargs={\n                \"filters\": 5,\n                \"kernel_size\": 3,\n                \"padding\": \"valid\",\n                \"recurrent_dropout\": 0.5,\n            },\n            input_shape=(\n                (3, 2, 8, 8, 8, 3) if channels_last else (3, 2, 3, 8, 8, 8)\n            ),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(\n                (3, 6, 6, 6, 5) if channels_last else (3, 5, 6, 6, 6)\n            ),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.ConvLSTM3D,\n            init_kwargs={\n                \"filters\": 5,\n                \"kernel_size\": 3,\n                \"padding\": \"valid\",\n                \"return_sequences\": True,\n            },\n            input_shape=(\n                (3, 2, 8, 8, 8, 3) if channels_last else (3, 2, 3, 8, 8, 8)\n            ),\n     
       expected_output_shape=(\n                (3, 2, 6, 6, 6, 5) if channels_last else (3, 2, 5, 6, 6, 6)\n            ),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n\n    def test_correctness(self):\n        sequence = (\n            np.arange(1920).reshape((2, 3, 4, 4, 4, 5)).astype(\"float32\") / 100\n        )\n        expected_output = np.array(\n            [\n                [\n                    [\n                        [[0.99149036, 0.99149036], [0.99180907, 0.99180907]],\n                        [[0.99258363, 0.99258363], [0.9927925, 0.9927925]],\n                    ],\n                    [\n                        [[0.99413764, 0.99413764], [0.99420583, 0.99420583]],\n                        [[0.9943788, 0.9943788], [0.9944278, 0.9944278]],\n                    ],\n                ],\n                [\n                    [\n                        [[0.9950547, 0.9950547], [0.9950547, 0.9950547]],\n                        [[0.9950547, 0.9950547], [0.9950547, 0.9950547]],\n                    ],\n                    [\n                        [[0.9950547, 0.9950547], [0.9950547, 0.9950547]],\n                        [[0.9950547, 0.9950547], [0.9950547, 0.9950547]],\n                    ],\n                ],\n            ]\n        )\n        if backend.config.image_data_format() == \"channels_first\":\n            sequence = sequence.transpose((0, 1, 5, 2, 3, 4))\n            expected_output = expected_output.transpose((0, 4, 1, 2, 3))\n        layer = layers.ConvLSTM3D(\n            filters=2,\n            kernel_size=3,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            expected_output,\n            output,\n            
tpu_atol=1e-4,\n            tpu_rtol=1e-4,\n        )\n"
  },
  {
    "path": "keras/src/layers/rnn/conv_lstm_test.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import testing\nfrom keras.src.layers.rnn.conv_lstm import ConvLSTM\nfrom keras.src.layers.rnn.conv_lstm import ConvLSTMCell\n\n\nclass ConvLSTMCellTest(testing.TestCase):\n    def test_correctness(self):\n        x = np.arange(150).reshape((2, 5, 5, 3)).astype(\"float32\") / 10\n        s1 = np.arange(200).reshape((2, 5, 5, 4)).astype(\"float32\") / 10\n        s2 = np.arange(200).reshape((2, 5, 5, 4)).astype(\"float32\") / 10\n\n        if backend.config.image_data_format() == \"channels_first\":\n            x = x.transpose((0, 3, 1, 2))\n            s1 = s1.transpose((0, 3, 1, 2))\n            s2 = s2.transpose((0, 3, 1, 2))\n        layer = ConvLSTMCell(\n            rank=2,\n            filters=4,\n            kernel_size=3,\n            padding=\"same\",\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n        )\n        output = layer(x, [s1, s2])\n        checksum_0 = np.sum(backend.convert_to_numpy(output[0]))\n        self.assertAllClose(checksum_0, 188.89502, tpu_atol=1e-4, tpu_rtol=1e-4)\n        checksum_1 = np.sum(backend.convert_to_numpy(output[1][0]))\n        self.assertAllClose(checksum_1, 188.89502, tpu_atol=1e-4, tpu_rtol=1e-4)\n        checksum_2 = np.sum(backend.convert_to_numpy(output[1][1]))\n        self.assertAllClose(checksum_2, 2170.444, tpu_atol=1e-4, tpu_rtol=1e-4)\n\n\nclass ConvLSTMTest(testing.TestCase):\n    def test_correctness(self):\n        x = np.arange(450).reshape((2, 3, 5, 5, 3)).astype(\"float32\") / 100\n        s1 = np.arange(200).reshape((2, 5, 5, 4)).astype(\"float32\") / 100\n        s2 = np.arange(200).reshape((2, 5, 5, 4)).astype(\"float32\") / 100\n\n        if backend.config.image_data_format() == \"channels_first\":\n            x = x.transpose((0, 1, 4, 2, 3))\n            s1 = s1.transpose((0, 3, 1, 2))\n            
s2 = s2.transpose((0, 3, 1, 2))\n        layer = ConvLSTM(\n            rank=2,\n            filters=4,\n            kernel_size=3,\n            padding=\"same\",\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n        )\n        output = layer(x, initial_state=[s1, s2])\n        output = backend.convert_to_numpy(output)\n        self.assertAllClose(\n            np.sum(output), 119.812454, tpu_atol=1e-3, tpu_rtol=1e-3\n        )\n"
  },
  {
    "path": "keras/src/layers/rnn/dropout_rnn_cell.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\n\n\nclass DropoutRNNCell:\n    \"\"\"Object that holds dropout-related functionality for RNN cells.\n\n    This class is not a standalone RNN cell. It suppose to be used with a RNN\n    cell by multiple inheritance. Any cell that mix with class should have\n    following fields:\n\n    - `dropout`: a float number in the range `[0, 1]`.\n        Dropout rate for the input tensor.\n    - `recurrent_dropout`: a float number in the range `[0, 1]`.\n        Dropout rate for the recurrent connections.\n    - `seed_generator`, an instance of `backend.random.SeedGenerator`.\n\n    This object will create and cache dropout masks, and reuse them for\n    all incoming steps, so that the same mask is used for every step.\n    \"\"\"\n\n    def _create_dropout_mask(self, step_input, dropout_rate):\n        count = getattr(self, \"dropout_mask_count\", None)\n        ones = ops.ones_like(step_input)\n        if count is None:\n            return backend.random.dropout(\n                ones, rate=dropout_rate, seed=self.seed_generator\n            )\n        else:\n            return [\n                backend.random.dropout(\n                    ones, rate=dropout_rate, seed=self.seed_generator\n                )\n                for _ in range(count)\n            ]\n\n    def get_dropout_mask(self, step_input):\n        if not hasattr(self, \"_dropout_mask\"):\n            self._dropout_mask = None\n        if self._dropout_mask is None and self.dropout > 0:\n            self._dropout_mask = self._create_dropout_mask(\n                step_input, self.dropout\n            )\n        return self._dropout_mask\n\n    def get_recurrent_dropout_mask(self, step_input):\n        if not hasattr(self, \"_recurrent_dropout_mask\"):\n            self._recurrent_dropout_mask = None\n        if self._recurrent_dropout_mask is None and self.recurrent_dropout > 0:\n            self._recurrent_dropout_mask = 
self._create_dropout_mask(\n                step_input, self.recurrent_dropout\n            )\n        return self._recurrent_dropout_mask\n\n    def reset_dropout_mask(self):\n        \"\"\"Reset the cached dropout mask if any.\n\n        The RNN layer invokes this in the `call()` method\n        so that the cached mask is cleared after calling `cell.call()`. The\n        mask should be cached across all timestep within the same batch, but\n        shouldn't be cached between batches.\n        \"\"\"\n        self._dropout_mask = None\n\n    def reset_recurrent_dropout_mask(self):\n        self._recurrent_dropout_mask = None\n"
  },
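`DropoutRNNCell` above creates one dropout mask per batch, reuses it at every timestep, and discards it when the layer resets it. The caching behavior can be sketched standalone in NumPy (a minimal sketch for illustration; `CachedDropout` is a hypothetical class, not the Keras API, and it uses inverted dropout so the kept units are rescaled):

```python
import numpy as np


class CachedDropout:
    """Create a dropout mask once, reuse it every step, reset per batch."""

    def __init__(self, rate, seed=0):
        self.rate = rate
        self.rng = np.random.default_rng(seed)
        self._mask = None

    def get_mask(self, step_input):
        if self._mask is None and self.rate > 0:
            keep = self.rng.random(step_input.shape) >= self.rate
            # Inverted dropout: scale kept units by 1/(1-rate) so the
            # expected activation is unchanged.
            self._mask = keep.astype(step_input.dtype) / (1.0 - self.rate)
        return self._mask

    def reset_mask(self):
        self._mask = None


x = np.ones((2, 4), dtype="float32")
dp = CachedDropout(rate=0.5, seed=42)
m1 = dp.get_mask(x)
m2 = dp.get_mask(x)  # same cached object: identical mask for a later timestep
dp.reset_mask()      # the next batch draws a fresh mask
m3 = dp.get_mask(x)
```

The point of the cache is exactly what the docstring states: within one batch every timestep sees the same mask (`m1 is m2`), while a reset between batches produces a new draw (`m3` is a different array).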
  {
    "path": "keras/src/layers/rnn/dropout_rnn_cell_test.py",
    "content": "import pytest\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell\n\n\nclass RNNCellWithDropout(layers.Layer, DropoutRNNCell):\n    def __init__(\n        self, units, dropout=0.5, recurrent_dropout=0.5, seed=None, **kwargs\n    ):\n        super().__init__(**kwargs)\n        self.seed = seed\n        self.seed_generator = backend.random.SeedGenerator(seed)\n        self.units = units\n        self.state_size = units\n        self.dropout = dropout\n        self.recurrent_dropout = recurrent_dropout\n\n    def build(self, input_shape):\n        self.kernel = self.add_weight(\n            shape=(input_shape[-1], self.units),\n            initializer=\"ones\",\n            name=\"kernel\",\n        )\n        self.recurrent_kernel = self.add_weight(\n            shape=(self.units, self.units),\n            initializer=\"ones\",\n            name=\"recurrent_kernel\",\n        )\n\n    def call(self, inputs, states, training=False):\n        if training:\n            dp_mask = self.get_dropout_mask(inputs)\n            inputs = inputs * dp_mask\n        prev_output = states[0]\n        h = ops.matmul(inputs, self.kernel)\n        if training:\n            rdp_mask = self.get_recurrent_dropout_mask(prev_output)\n            prev_output = prev_output * rdp_mask\n        output = h + ops.matmul(prev_output, self.recurrent_kernel)\n        return output, [output]\n\n\nclass DropoutRNNCellTest(testing.TestCase):\n    def test_seed_tracking(self):\n        cell = RNNCellWithDropout(3, seed=1337)\n        self.assertEqual(len(cell.non_trainable_variables), 1)\n        layer = layers.RNN(cell)\n        self.assertEqual(len(layer.non_trainable_variables), 1)\n\n    @pytest.mark.requires_trainable_backend\n    def test_basics(self):\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\"cell\": 
RNNCellWithDropout(5, seed=1337)},\n            input_shape=(3, 2, 4),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(3, 5),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_non_trainable_variables=1,\n            supports_masking=True,\n            run_mixed_precision_check=False,\n        )\n\n        # manually set dtype to mixed_float16 to run mixed precision check\n        run_mixed_precision_check = True\n        if backend.backend() == \"torch\":\n            import torch\n\n            run_mixed_precision_check = torch.cuda.is_available()\n        if run_mixed_precision_check:\n            self.run_layer_test(\n                layers.RNN,\n                init_kwargs={\n                    \"cell\": RNNCellWithDropout(\n                        5, seed=1337, dtype=\"mixed_float16\"\n                    ),\n                    \"dtype\": \"mixed_float16\",\n                },\n                input_shape=(3, 2, 4),\n                call_kwargs={\"training\": True},\n                expected_output_shape=(3, 5),\n                expected_num_trainable_weights=2,\n                expected_num_non_trainable_weights=0,\n                expected_num_non_trainable_variables=1,\n                supports_masking=True,\n                run_mixed_precision_check=False,\n            )\n"
  },
  {
    "path": "keras/src/layers/rnn/gru.py",
    "content": "from keras.src import activations\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell\nfrom keras.src.layers.rnn.rnn import RNN\n\n\n@keras_export(\"keras.layers.GRUCell\")\nclass GRUCell(Layer, DropoutRNNCell):\n    \"\"\"Cell class for the GRU layer.\n\n    This class processes one step within the whole time sequence input, whereas\n    `keras.layer.GRU` processes the whole sequence.\n\n    Args:\n        units: Positive integer, dimensionality of the output space.\n        activation: Activation function to use. Default: hyperbolic tangent\n            (`tanh`). If you pass None, no activation is applied\n            (ie. \"linear\" activation: `a(x) = x`).\n        recurrent_activation: Activation function to use for the recurrent step.\n            Default: sigmoid (`sigmoid`). If you pass `None`, no activation is\n            applied (ie. \"linear\" activation: `a(x) = x`).\n        use_bias: Boolean, (default `True`), whether the layer\n            should use a bias vector.\n        kernel_initializer: Initializer for the `kernel` weights matrix,\n            used for the linear transformation of the inputs. Default:\n            `\"glorot_uniform\"`.\n        recurrent_initializer: Initializer for the `recurrent_kernel`\n            weights matrix, used for the linear transformation\n            of the recurrent state. Default: `\"orthogonal\"`.\n        bias_initializer: Initializer for the bias vector. Default: `\"zeros\"`.\n        kernel_regularizer: Regularizer function applied to the `kernel` weights\n            matrix. 
Default: `None`.\n        recurrent_regularizer: Regularizer function applied to the\n            `recurrent_kernel` weights matrix. Default: `None`.\n        bias_regularizer: Regularizer function applied to the bias vector.\n            Default: `None`.\n        kernel_constraint: Constraint function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_constraint: Constraint function applied to the\n            `recurrent_kernel` weights matrix. Default: `None`.\n        bias_constraint: Constraint function applied to the bias vector.\n            Default: `None`.\n        dropout: Float between 0 and 1. Fraction of the units to drop for the\n            linear transformation of the inputs. Default: 0.\n        recurrent_dropout: Float between 0 and 1. Fraction of the units to drop\n            for the linear transformation of the recurrent state. Default: 0.\n        reset_after: GRU convention (whether to apply reset gate after or\n            before matrix multiplication). False = \"before\",\n            True = \"after\" (default and cuDNN compatible).\n        seed: Random seed for dropout.\n\n    Call arguments:\n        inputs: A 2D tensor, with shape `(batch, features)`.\n        states: A 2D tensor with shape `(batch, units)`, which is the state\n            from the previous time step.\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode. Only relevant when `dropout` or\n            `recurrent_dropout` is used.\n\n    Example:\n\n    >>> inputs = np.random.random((32, 10, 8))\n    >>> rnn = keras.layers.RNN(keras.layers.GRUCell(4))\n    >>> output = rnn(inputs)\n    >>> output.shape\n    (32, 4)\n    >>> rnn = keras.layers.RNN(\n    ...    keras.layers.GRUCell(4),\n    ...    return_sequences=True,\n    ...    
return_state=True)\n    >>> whole_sequence_output, final_state = rnn(inputs)\n    >>> whole_sequence_output.shape\n    (32, 10, 4)\n    >>> final_state.shape\n    (32, 4)\n    \"\"\"\n\n    def __init__(\n        self,\n        units,\n        activation=\"tanh\",\n        recurrent_activation=\"sigmoid\",\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        recurrent_initializer=\"orthogonal\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        recurrent_regularizer=None,\n        bias_regularizer=None,\n        kernel_constraint=None,\n        recurrent_constraint=None,\n        bias_constraint=None,\n        dropout=0.0,\n        recurrent_dropout=0.0,\n        reset_after=True,\n        seed=None,\n        **kwargs,\n    ):\n        if units <= 0:\n            raise ValueError(\n                \"Received an invalid value for argument `units`, \"\n                f\"expected a positive integer, got {units}.\"\n            )\n        implementation = kwargs.pop(\"implementation\", 2)\n        super().__init__(**kwargs)\n        self.implementation = implementation\n        self.units = units\n        self.activation = activations.get(activation)\n        self.recurrent_activation = activations.get(recurrent_activation)\n        self.use_bias = use_bias\n\n        self.kernel_initializer = initializers.get(kernel_initializer)\n        self.recurrent_initializer = initializers.get(recurrent_initializer)\n        self.bias_initializer = initializers.get(bias_initializer)\n\n        self.kernel_regularizer = regularizers.get(kernel_regularizer)\n        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)\n        self.bias_regularizer = regularizers.get(bias_regularizer)\n\n        self.kernel_constraint = constraints.get(kernel_constraint)\n        self.recurrent_constraint = constraints.get(recurrent_constraint)\n        self.bias_constraint = constraints.get(bias_constraint)\n\n        
self.dropout = min(1.0, max(0.0, dropout))\n        self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))\n        if self.recurrent_dropout != 0.0:\n            self.implementation = 1\n        if self.implementation == 1:\n            self.dropout_mask_count = 3\n        self.seed = seed\n        self.seed_generator = backend.random.SeedGenerator(seed=seed)\n\n        self.reset_after = reset_after\n        self.state_size = self.units\n        self.output_size = self.units\n\n    def build(self, input_shape):\n        super().build(input_shape)\n        input_dim = input_shape[-1]\n        self.kernel = self.add_weight(\n            shape=(input_dim, self.units * 3),\n            name=\"kernel\",\n            initializer=self.kernel_initializer,\n            regularizer=self.kernel_regularizer,\n            constraint=self.kernel_constraint,\n        )\n        self.recurrent_kernel = self.add_weight(\n            shape=(self.units, self.units * 3),\n            name=\"recurrent_kernel\",\n            initializer=self.recurrent_initializer,\n            regularizer=self.recurrent_regularizer,\n            constraint=self.recurrent_constraint,\n        )\n\n        if self.use_bias:\n            if not self.reset_after:\n                bias_shape = (3 * self.units,)\n            else:\n                # separate biases for input and recurrent kernels\n                # Note: the shape is intentionally different from CuDNNGRU\n                # biases `(2 * 3 * self.units,)`, so that we can distinguish the\n                # classes when loading and converting saved weights.\n                bias_shape = (2, 3 * self.units)\n            self.bias = self.add_weight(\n                shape=bias_shape,\n                name=\"bias\",\n                initializer=self.bias_initializer,\n                regularizer=self.bias_regularizer,\n                constraint=self.bias_constraint,\n            )\n        else:\n            self.bias = None\n\n    def 
call(self, inputs, states, training=False):\n        h_tm1 = (\n            states[0] if tree.is_nested(states) else states\n        )  # previous state\n\n        if self.use_bias:\n            if not self.reset_after:\n                input_bias, recurrent_bias = self.bias, None\n            else:\n                input_bias, recurrent_bias = (\n                    ops.squeeze(e, axis=0)\n                    for e in ops.split(self.bias, self.bias.shape[0], axis=0)\n                )\n\n        if self.implementation == 1:\n            if training and 0.0 < self.dropout < 1.0:\n                dp_mask = self.get_dropout_mask(inputs)\n                inputs_z = inputs * dp_mask[0]\n                inputs_r = inputs * dp_mask[1]\n                inputs_h = inputs * dp_mask[2]\n            else:\n                inputs_z = inputs\n                inputs_r = inputs\n                inputs_h = inputs\n\n            x_z = ops.matmul(inputs_z, self.kernel[:, : self.units])\n            x_r = ops.matmul(\n                inputs_r, self.kernel[:, self.units : self.units * 2]\n            )\n            x_h = ops.matmul(inputs_h, self.kernel[:, self.units * 2 :])\n\n            if self.use_bias:\n                x_z += input_bias[: self.units]\n                x_r += input_bias[self.units : self.units * 2]\n                x_h += input_bias[self.units * 2 :]\n\n            if training and 0.0 < self.recurrent_dropout < 1.0:\n                rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)\n                h_tm1_z = h_tm1 * rec_dp_mask[0]\n                h_tm1_r = h_tm1 * rec_dp_mask[1]\n                h_tm1_h = h_tm1 * rec_dp_mask[2]\n            else:\n                h_tm1_z = h_tm1\n                h_tm1_r = h_tm1\n                h_tm1_h = h_tm1\n\n            recurrent_z = ops.matmul(\n                h_tm1_z, self.recurrent_kernel[:, : self.units]\n            )\n            recurrent_r = ops.matmul(\n                h_tm1_r, self.recurrent_kernel[:, self.units 
: self.units * 2]\n            )\n            if self.reset_after and self.use_bias:\n                recurrent_z += recurrent_bias[: self.units]\n                recurrent_r += recurrent_bias[self.units : self.units * 2]\n\n            z = self.recurrent_activation(x_z + recurrent_z)\n            r = self.recurrent_activation(x_r + recurrent_r)\n\n            # reset gate applied after/before matrix multiplication\n            if self.reset_after:\n                recurrent_h = ops.matmul(\n                    h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]\n                )\n                if self.use_bias:\n                    recurrent_h += recurrent_bias[self.units * 2 :]\n                recurrent_h = r * recurrent_h\n            else:\n                recurrent_h = ops.matmul(\n                    r * h_tm1_h, self.recurrent_kernel[:, self.units * 2 :]\n                )\n\n            hh = self.activation(x_h + recurrent_h)\n        else:\n            if training and 0.0 < self.dropout < 1.0:\n                dp_mask = self.get_dropout_mask(inputs)\n                inputs = inputs * dp_mask\n\n            # inputs projected by all gate matrices at once\n            matrix_x = ops.matmul(inputs, self.kernel)\n            if self.use_bias:\n                # biases: bias_z_i, bias_r_i, bias_h_i\n                matrix_x = ops.add(matrix_x, input_bias)\n\n            x_z, x_r, x_h = ops.split(matrix_x, 3, axis=-1)\n\n            if self.reset_after:\n                # hidden state projected by all gate matrices at once\n                matrix_inner = ops.matmul(h_tm1, self.recurrent_kernel)\n                if self.use_bias:\n                    matrix_inner += recurrent_bias\n            else:\n                # hidden state projected separately for update/reset and new\n                matrix_inner = ops.matmul(\n                    h_tm1, self.recurrent_kernel[:, : 2 * self.units]\n                )\n\n            recurrent_z = matrix_inner[:, : 
self.units]\n            recurrent_r = matrix_inner[:, self.units : self.units * 2]\n            recurrent_h = matrix_inner[:, self.units * 2 :]\n\n            z = self.recurrent_activation(x_z + recurrent_z)\n            r = self.recurrent_activation(x_r + recurrent_r)\n\n            if self.reset_after:\n                recurrent_h = r * recurrent_h\n            else:\n                recurrent_h = ops.matmul(\n                    r * h_tm1, self.recurrent_kernel[:, 2 * self.units :]\n                )\n\n            hh = self.activation(x_h + recurrent_h)\n\n        # previous and candidate state mixed by update gate\n        h = z * h_tm1 + (1 - z) * hh\n        new_state = [h] if tree.is_nested(states) else h\n        return h, new_state\n\n    def get_config(self):\n        config = {\n            \"units\": self.units,\n            \"activation\": activations.serialize(self.activation),\n            \"recurrent_activation\": activations.serialize(\n                self.recurrent_activation\n            ),\n            \"use_bias\": self.use_bias,\n            \"kernel_initializer\": initializers.serialize(\n                self.kernel_initializer\n            ),\n            \"recurrent_initializer\": initializers.serialize(\n                self.recurrent_initializer\n            ),\n            \"bias_initializer\": initializers.serialize(self.bias_initializer),\n            \"kernel_regularizer\": regularizers.serialize(\n                self.kernel_regularizer\n            ),\n            \"recurrent_regularizer\": regularizers.serialize(\n                self.recurrent_regularizer\n            ),\n            \"bias_regularizer\": regularizers.serialize(self.bias_regularizer),\n            \"kernel_constraint\": constraints.serialize(self.kernel_constraint),\n            \"recurrent_constraint\": constraints.serialize(\n                self.recurrent_constraint\n            ),\n            \"bias_constraint\": 
constraints.serialize(self.bias_constraint),\n            \"dropout\": self.dropout,\n            \"recurrent_dropout\": self.recurrent_dropout,\n            \"reset_after\": self.reset_after,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    def get_initial_state(self, batch_size=None):\n        return [\n            ops.zeros((batch_size, self.state_size), dtype=self.compute_dtype)\n        ]\n\n\n@keras_export(\"keras.layers.GRU\")\nclass GRU(RNN):\n    \"\"\"Gated Recurrent Unit - Cho et al. 2014.\n\n    Based on available runtime hardware and constraints, this layer\n    will choose different implementations (cuDNN-based or backend-native)\n    to maximize the performance. If a GPU is available and all\n    the arguments to the layer meet the requirement of the cuDNN kernel\n    (see below for details), the layer will use a fast cuDNN implementation\n    when using the TensorFlow backend.\n\n    The requirements to use the cuDNN implementation are:\n\n    1. `activation` == `tanh`\n    2. `recurrent_activation` == `sigmoid`\n    3. `recurrent_dropout` == 0\n    4. `unroll` is `False`\n    5. `use_bias` is `True`\n    6. `reset_after` is `True`\n    7. Inputs, if masking is used, are strictly right-padded.\n    8. Eager execution is enabled in the outermost context.\n\n    There are two variants of the GRU implementation. The default one is based\n    on [v3](https://arxiv.org/abs/1406.1078v3) and has the reset gate applied\n    to the hidden state before matrix multiplication. The other one is based on\n    the [original](https://arxiv.org/abs/1406.1078v1) and has the order\n    reversed.\n\n    The second variant is compatible with CuDNNGRU (GPU-only) and allows\n    inference on CPU. Thus it has separate biases for `kernel` and\n    `recurrent_kernel`. 
To use this variant, set `reset_after=True` and\n    `recurrent_activation='sigmoid'`.\n\n    For example:\n\n    >>> inputs = np.random.random((32, 10, 8))\n    >>> gru = keras.layers.GRU(4)\n    >>> output = gru(inputs)\n    >>> output.shape\n    (32, 4)\n    >>> gru = keras.layers.GRU(4, return_sequences=True, return_state=True)\n    >>> whole_sequence_output, final_state = gru(inputs)\n    >>> whole_sequence_output.shape\n    (32, 10, 4)\n    >>> final_state.shape\n    (32, 4)\n\n    Args:\n        units: Positive integer, dimensionality of the output space.\n        activation: Activation function to use.\n            Default: hyperbolic tangent (`tanh`).\n            If you pass `None`, no activation is applied\n            (i.e. \"linear\" activation: `a(x) = x`).\n        recurrent_activation: Activation function to use\n            for the recurrent step.\n            Default: sigmoid (`sigmoid`).\n            If you pass `None`, no activation is applied\n            (i.e. \"linear\" activation: `a(x) = x`).\n        use_bias: Boolean (default `True`), whether the layer\n            should use a bias vector.\n        kernel_initializer: Initializer for the `kernel` weights matrix,\n            used for the linear transformation of the inputs. Default:\n            `\"glorot_uniform\"`.\n        recurrent_initializer: Initializer for the `recurrent_kernel`\n            weights matrix, used for the linear transformation of the recurrent\n            state. Default: `\"orthogonal\"`.\n        bias_initializer: Initializer for the bias vector. Default: `\"zeros\"`.\n        kernel_regularizer: Regularizer function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_regularizer: Regularizer function applied to the\n            `recurrent_kernel` weights matrix. 
Default: `None`.\n        bias_regularizer: Regularizer function applied to the bias vector.\n            Default: `None`.\n        activity_regularizer: Regularizer function applied to the output of the\n            layer (its \"activation\"). Default: `None`.\n        kernel_constraint: Constraint function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_constraint: Constraint function applied to the\n            `recurrent_kernel` weights matrix. Default: `None`.\n        bias_constraint: Constraint function applied to the bias vector.\n            Default: `None`.\n        dropout: Float between 0 and 1. Fraction of the units to drop for the\n            linear transformation of the inputs. Default: 0.\n        recurrent_dropout: Float between 0 and 1. Fraction of the units to drop\n            for the linear transformation of the recurrent state. Default: 0.\n        seed: Random seed for dropout.\n        return_sequences: Boolean. Whether to return the last output\n            in the output sequence, or the full sequence. Default: `False`.\n        return_state: Boolean. Whether to return the last state in addition\n            to the output. Default: `False`.\n        go_backwards: Boolean (default `False`).\n            If `True`, process the input sequence backwards and return the\n            reversed sequence.\n        stateful: Boolean (default: `False`). 
If `True`, the last state\n            for each sample at index i in a batch will be used as initial\n            state for the sample of index i in the following batch.\n        unroll: Boolean (default: `False`).\n            If `True`, the network will be unrolled,\n            else a symbolic loop will be used.\n            Unrolling can speed up an RNN,\n            although it tends to be more memory-intensive.\n            Unrolling is only suitable for short sequences.\n        reset_after: GRU convention (whether to apply reset gate after or\n            before matrix multiplication). `False` is `\"before\"`,\n            `True` is `\"after\"` (default and cuDNN compatible).\n        use_cudnn: Whether to use a cuDNN-backed implementation. `\"auto\"` will\n            attempt to use cuDNN when feasible, and will fall back to the\n            default implementation if not.\n\n    Call arguments:\n        inputs: A 3D tensor, with shape `(batch, timesteps, feature)`.\n        mask: Binary tensor of shape `(samples, timesteps)` indicating whether\n            a given timestep should be masked (optional).\n            An individual `True` entry indicates that the corresponding timestep\n            should be utilized, while a `False` entry indicates that the\n            corresponding timestep should be ignored. Defaults to `None`.\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode. This argument is passed to the\n            cell when calling it. This is only relevant if `dropout` or\n            `recurrent_dropout` is used (optional). Defaults to `None`.\n        initial_state: List of initial state tensors to be passed to the first\n            call of the cell (optional, `None` causes creation\n            of zero-filled initial state tensors). 
Defaults to `None`.\n    \"\"\"\n\n    def __init__(\n        self,\n        units,\n        activation=\"tanh\",\n        recurrent_activation=\"sigmoid\",\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        recurrent_initializer=\"orthogonal\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        recurrent_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        recurrent_constraint=None,\n        bias_constraint=None,\n        dropout=0.0,\n        recurrent_dropout=0.0,\n        seed=None,\n        return_sequences=False,\n        return_state=False,\n        go_backwards=False,\n        stateful=False,\n        unroll=False,\n        reset_after=True,\n        use_cudnn=\"auto\",\n        **kwargs,\n    ):\n        cell = GRUCell(\n            units,\n            activation=activation,\n            recurrent_activation=recurrent_activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            recurrent_initializer=recurrent_initializer,\n            bias_initializer=bias_initializer,\n            kernel_regularizer=kernel_regularizer,\n            recurrent_regularizer=recurrent_regularizer,\n            bias_regularizer=bias_regularizer,\n            kernel_constraint=kernel_constraint,\n            recurrent_constraint=recurrent_constraint,\n            bias_constraint=bias_constraint,\n            dropout=dropout,\n            recurrent_dropout=recurrent_dropout,\n            reset_after=reset_after,\n            dtype=kwargs.get(\"dtype\", None),\n            trainable=kwargs.get(\"trainable\", True),\n            name=\"gru_cell\",\n            seed=seed,\n            implementation=kwargs.pop(\"implementation\", 2),\n        )\n        super().__init__(\n            cell,\n            return_sequences=return_sequences,\n            return_state=return_state,\n            
go_backwards=go_backwards,\n            stateful=stateful,\n            unroll=unroll,\n            activity_regularizer=activity_regularizer,\n            **kwargs,\n        )\n        self.input_spec = InputSpec(ndim=3)\n        if use_cudnn not in (\"auto\", True, False):\n            raise ValueError(\n                \"Invalid value received for argument `use_cudnn`. \"\n                \"Expected one of {'auto', True, False}. \"\n                f\"Received: use_cudnn={use_cudnn}\"\n            )\n        self.use_cudnn = use_cudnn\n        if (\n            backend.backend() == \"tensorflow\"\n            and backend.cudnn_ok(\n                cell.activation,\n                cell.recurrent_activation,\n                self.unroll,\n                cell.use_bias,\n                reset_after=reset_after,\n            )\n            and use_cudnn in (True, \"auto\")\n        ):\n            self.supports_jit = False\n\n    def inner_loop(self, sequences, initial_state, mask, training=False):\n        if tree.is_nested(initial_state):\n            initial_state = initial_state[0]\n        if tree.is_nested(mask):\n            mask = mask[0]\n        if self.use_cudnn in (\"auto\", True):\n            if not self.recurrent_dropout:\n                try:\n                    if training and self.dropout:\n                        dp_mask = self.cell.get_dropout_mask(sequences[:, 0, :])\n                        dp_mask = ops.expand_dims(dp_mask, axis=1)\n                        dp_mask = ops.broadcast_to(\n                            dp_mask, ops.shape(sequences)\n                        )\n                        dp_sequences = sequences * dp_mask\n                    else:\n                        dp_sequences = sequences\n                    # Backends are allowed to (optionally) specify an optimized\n                    # implementation of the inner GRU loop. 
In the case of\n                    # TF for instance, it will leverage cuDNN when feasible, and\n                    # it will raise NotImplementedError otherwise.\n                    out = backend.gru(\n                        dp_sequences,\n                        initial_state,\n                        mask,\n                        kernel=self.cell.kernel,\n                        recurrent_kernel=self.cell.recurrent_kernel,\n                        bias=self.cell.bias,\n                        activation=self.cell.activation,\n                        recurrent_activation=self.cell.recurrent_activation,\n                        return_sequences=self.return_sequences,\n                        go_backwards=self.go_backwards,\n                        unroll=self.unroll,\n                        reset_after=self.cell.reset_after,\n                    )\n                    # We disable jit_compile for the model in this case,\n                    # since cuDNN ops aren't XLA compatible.\n                    if backend.backend() == \"tensorflow\":\n                        self.supports_jit = False\n                    return out\n                except NotImplementedError:\n                    pass\n        if self.use_cudnn is True:\n            raise ValueError(\n                \"use_cudnn=True was specified, \"\n                \"but cuDNN is not supported for this layer configuration \"\n                \"with this backend. 
Pass use_cudnn='auto' to fall back \"\n                \"to a non-cuDNN implementation.\"\n            )\n        return super().inner_loop(\n            sequences, initial_state, mask=mask, training=training\n        )\n\n    def call(self, sequences, initial_state=None, mask=None, training=False):\n        return super().call(\n            sequences, mask=mask, training=training, initial_state=initial_state\n        )\n\n    @property\n    def units(self):\n        return self.cell.units\n\n    @property\n    def activation(self):\n        return self.cell.activation\n\n    @property\n    def recurrent_activation(self):\n        return self.cell.recurrent_activation\n\n    @property\n    def use_bias(self):\n        return self.cell.use_bias\n\n    @property\n    def kernel_initializer(self):\n        return self.cell.kernel_initializer\n\n    @property\n    def recurrent_initializer(self):\n        return self.cell.recurrent_initializer\n\n    @property\n    def bias_initializer(self):\n        return self.cell.bias_initializer\n\n    @property\n    def kernel_regularizer(self):\n        return self.cell.kernel_regularizer\n\n    @property\n    def recurrent_regularizer(self):\n        return self.cell.recurrent_regularizer\n\n    @property\n    def bias_regularizer(self):\n        return self.cell.bias_regularizer\n\n    @property\n    def kernel_constraint(self):\n        return self.cell.kernel_constraint\n\n    @property\n    def recurrent_constraint(self):\n        return self.cell.recurrent_constraint\n\n    @property\n    def bias_constraint(self):\n        return self.cell.bias_constraint\n\n    @property\n    def dropout(self):\n        return self.cell.dropout\n\n    @property\n    def recurrent_dropout(self):\n        return self.cell.recurrent_dropout\n\n    @property\n    def reset_after(self):\n        return self.cell.reset_after\n\n    def get_config(self):\n        config = {\n            \"units\": self.units,\n            \"activation\": 
activations.serialize(self.activation),\n            \"recurrent_activation\": activations.serialize(\n                self.recurrent_activation\n            ),\n            \"use_bias\": self.use_bias,\n            \"kernel_initializer\": initializers.serialize(\n                self.kernel_initializer\n            ),\n            \"recurrent_initializer\": initializers.serialize(\n                self.recurrent_initializer\n            ),\n            \"bias_initializer\": initializers.serialize(self.bias_initializer),\n            \"kernel_regularizer\": regularizers.serialize(\n                self.kernel_regularizer\n            ),\n            \"recurrent_regularizer\": regularizers.serialize(\n                self.recurrent_regularizer\n            ),\n            \"bias_regularizer\": regularizers.serialize(self.bias_regularizer),\n            \"activity_regularizer\": regularizers.serialize(\n                self.activity_regularizer\n            ),\n            \"kernel_constraint\": constraints.serialize(self.kernel_constraint),\n            \"recurrent_constraint\": constraints.serialize(\n                self.recurrent_constraint\n            ),\n            \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            \"dropout\": self.dropout,\n            \"recurrent_dropout\": self.recurrent_dropout,\n            \"reset_after\": self.reset_after,\n            \"seed\": self.cell.seed,\n        }\n        base_config = super().get_config()\n        del base_config[\"cell\"]\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n"
  },
  {
    "path": "keras/src/layers/rnn/gru_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass GRUTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basics(self):\n        self.run_layer_test(\n            layers.GRU,\n            init_kwargs={\"units\": 3, \"dropout\": 0.5},\n            input_shape=(3, 2, 4),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(3, 3),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.GRU,\n            init_kwargs={\"units\": 3, \"dropout\": 0.5, \"recurrent_dropout\": 0.5},\n            input_shape=(3, 2, 4),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(3, 3),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.GRU,\n            init_kwargs={\n                \"units\": 3,\n                \"return_sequences\": True,\n                \"bias_regularizer\": \"l1\",\n                \"kernel_regularizer\": \"l2\",\n                \"recurrent_regularizer\": \"l2\",\n            },\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 2, 3),\n            expected_num_losses=3,\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n\n    @parameterized.parameters([1, 2])\n    def test_correctness(self, implementation):\n        sequence = np.arange(72).reshape((3, 6, 4)).astype(\"float32\")\n        layer = layers.GRU(\n            3,\n            kernel_initializer=initializers.Constant(0.01),\n            
recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.5217289, 0.5217289, 0.5217289],\n                    [0.6371659, 0.6371659, 0.6371659],\n                    [0.39384964, 0.39384964, 0.3938496],\n                ]\n            ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.GRU(\n            3,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            go_backwards=True,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.24406259, 0.24406259, 0.24406259],\n                    [0.611516, 0.611516, 0.611516],\n                    [0.3928808, 0.3928808, 0.3928808],\n                ]\n            ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.GRU(\n            3,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            unroll=True,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.5217289, 0.5217289, 0.5217289],\n                    [0.6371659, 0.6371659, 0.6371659],\n                    [0.39384964, 0.39384964, 0.3938496],\n                ]\n            ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        
)\n\n        layer = layers.GRU(\n            3,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            reset_after=False,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.51447755, 0.51447755, 0.51447755],\n                    [0.6426879, 0.6426879, 0.6426879],\n                    [0.40208298, 0.40208298, 0.40208298],\n                ]\n            ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.GRU(\n            3,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            use_bias=False,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.49988455, 0.49988455, 0.49988455],\n                    [0.64701194, 0.64701194, 0.64701194],\n                    [0.4103359, 0.4103359, 0.4103359],\n                ]\n            ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_statefulness(self):\n        sequence = np.arange(24).reshape((2, 3, 4)).astype(\"float32\")\n        layer = layers.GRU(\n            4,\n            stateful=True,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.29542392, 
0.29542392, 0.29542392, 0.29542392],\n                    [0.5885018, 0.5885018, 0.5885018, 0.5885018],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        layer.reset_state()\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.29542392, 0.29542392, 0.29542392, 0.29542392],\n                    [0.5885018, 0.5885018, 0.5885018, 0.5885018],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_pass_initial_state(self):\n        sequence = np.arange(24).reshape((2, 4, 3)).astype(\"float32\")\n        initial_state = np.arange(4).reshape((2, 2)).astype(\"float32\")\n        layer = layers.GRU(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        output = layer(sequence, initial_state=initial_state)\n        self.assertAllClose(\n            np.array([[0.23774096, 0.33508456], [0.83659905, 1.0227708]]),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.GRU(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            go_backwards=True,\n        )\n        output = layer(sequence, initial_state=initial_state)\n        self.assertAllClose(\n            np.array([[0.13486053, 0.23261218], [0.78257304, 0.9691353]]),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_pass_return_state(self):\n        sequence = np.arange(24).reshape((2, 4, 3)).astype(\"float32\")\n        
initial_state = np.arange(4).reshape((2, 2)).astype(\"float32\")\n\n        # Test with go_backwards=False\n        layer = layers.GRU(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            return_state=True,\n        )\n        output, state = layer(sequence, initial_state=initial_state)\n        self.assertAllClose(\n            np.array([[0.23774096, 0.33508456], [0.83659905, 1.0227708]]),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            output,\n            state,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        # Test with go_backwards=True\n        layer = layers.GRU(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            return_state=True,\n            go_backwards=True,\n        )\n        output, state = layer(sequence, initial_state=initial_state)\n        self.assertAllClose(\n            np.array([[0.13486053, 0.23261218], [0.78257304, 0.9691353]]),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            output,\n            state,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_masking(self):\n        sequence = np.arange(24).reshape((2, 4, 3)).astype(\"float32\")\n        mask = np.array([[True, True, False, True], [True, False, False, True]])\n        layer = layers.GRU(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            unroll=True,\n        )\n        output = 
layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array([[0.19393763, 0.19393763], [0.30818558, 0.30818558]]),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.GRU(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            return_sequences=True,\n        )\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.03606692, 0.03606692],\n                    [0.09497581, 0.09497581],\n                    [0.09497581, 0.09497581],\n                    [0.19393763, 0.19393763],\n                ],\n            ),\n            output[0],\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.16051409, 0.16051409],\n                    [0.16051409, 0.16051409],\n                    [0.16051409, 0.16051409],\n                    [0.30818558, 0.30818558],\n                ],\n            ),\n            output[1],\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.GRU(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            return_sequences=True,\n            zero_output_for_mask=True,\n        )\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.03606692, 0.03606692],\n                    [0.09497581, 0.09497581],\n                    [0.0, 0.0],\n                    [0.19393763, 0.19393763],\n                ],\n            ),\n            output[0],\n           
 tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.16051409, 0.16051409],\n                    [0.0, 0.0],\n                    [0.0, 0.0],\n                    [0.30818558, 0.30818558],\n                ],\n            ),\n            output[1],\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.GRU(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            go_backwards=True,\n        )\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array([[0.11669192, 0.11669192], [0.28380975, 0.28380975]]),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_legacy_implementation_argument(self):\n        sequence = np.arange(72).reshape((3, 6, 4)).astype(\"float32\")\n        layer = layers.GRU(\n            3,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        config = layer.get_config()\n        config[\"implementation\"] = 0  # Add legacy argument\n        layer = layers.GRU.from_config(config)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.5217289, 0.5217289, 0.5217289],\n                    [0.6371659, 0.6371659, 0.6371659],\n                    [0.39384964, 0.39384964, 0.3938496],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Test only applicable to fixing a bug with symbolic batch size 
\"\n        \"for TensorFlow backend.\",\n    )\n    def test_stateful_with_symbolic_batch_size(self):\n        layer = layers.GRU(5, stateful=True)\n\n        x_concrete = np.ones((2, 10, 10), dtype=np.float32)\n        _ = layer(x_concrete, training=True)\n        import tensorflow as tf\n\n        @tf.function(\n            input_signature=[\n                tf.TensorSpec(shape=(None, 10, 10), dtype=tf.float32)\n            ]\n        )\n        def f(x):\n            return layer(x, training=True)\n\n        y = f(x_concrete)\n        self.assertEqual(y.shape, (2, 5))\n"
  },
  {
    "path": "keras/src/layers/rnn/lstm.py",
    "content": "from keras.src import activations\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell\nfrom keras.src.layers.rnn.rnn import RNN\n\n\n@keras_export(\"keras.layers.LSTMCell\")\nclass LSTMCell(Layer, DropoutRNNCell):\n    \"\"\"Cell class for the LSTM layer.\n\n    This class processes one step within the whole time sequence input, whereas\n    `keras.layers.LSTM` processes the whole sequence.\n\n    Args:\n        units: Positive integer, dimensionality of the output space.\n        activation: Activation function to use. Default: hyperbolic tangent\n            (`tanh`). If you pass `None`, no activation is applied\n            (i.e. \"linear\" activation: `a(x) = x`).\n        recurrent_activation: Activation function to use for the recurrent step.\n            Default: sigmoid (`sigmoid`). If you pass `None`, no activation is\n            applied (i.e. \"linear\" activation: `a(x) = x`).\n        use_bias: Boolean (default `True`), whether the layer\n            should use a bias vector.\n        kernel_initializer: Initializer for the `kernel` weights matrix,\n            used for the linear transformation of the inputs. Default:\n            `\"glorot_uniform\"`.\n        recurrent_initializer: Initializer for the `recurrent_kernel`\n            weights matrix, used for the linear transformation\n            of the recurrent state. Default: `\"orthogonal\"`.\n        bias_initializer: Initializer for the bias vector. Default: `\"zeros\"`.\n        unit_forget_bias: Boolean (default `True`). 
If `True`,\n            add 1 to the bias of the forget gate at initialization.\n            Setting it to `True` will also force `bias_initializer=\"zeros\"`.\n            This is recommended in [Jozefowicz et al.](\n            https://github.com/mlresearch/v37/blob/gh-pages/jozefowicz15.pdf)\n        kernel_regularizer: Regularizer function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_regularizer: Regularizer function applied to the\n            `recurrent_kernel` weights matrix. Default: `None`.\n        bias_regularizer: Regularizer function applied to the bias vector.\n            Default: `None`.\n        kernel_constraint: Constraint function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_constraint: Constraint function applied to the\n            `recurrent_kernel` weights matrix. Default: `None`.\n        bias_constraint: Constraint function applied to the bias vector.\n            Default: `None`.\n        dropout: Float between 0 and 1. Fraction of the units to drop for the\n            linear transformation of the inputs. Default: 0.\n        recurrent_dropout: Float between 0 and 1. Fraction of the units to drop\n            for the linear transformation of the recurrent state. Default: 0.\n        seed: Random seed for dropout.\n\n    Call arguments:\n        inputs: A 2D tensor, with shape `(batch, features)`.\n        states: A 2D tensor with shape `(batch, units)`, which is the state\n            from the previous time step.\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode. Only relevant when `dropout` or\n            `recurrent_dropout` is used.\n\n    Example:\n\n    >>> inputs = np.random.random((32, 10, 8))\n    >>> rnn = keras.layers.RNN(keras.layers.LSTMCell(4))\n    >>> output = rnn(inputs)\n    >>> output.shape\n    (32, 4)\n    >>> rnn = keras.layers.RNN(\n    ...    
keras.layers.LSTMCell(4),\n    ...    return_sequences=True,\n    ...    return_state=True)\n    >>> whole_sequence_output, final_state = rnn(inputs)\n    >>> whole_sequence_output.shape\n    (32, 10, 4)\n    >>> final_state.shape\n    (32, 4)\n    \"\"\"\n\n    def __init__(\n        self,\n        units,\n        activation=\"tanh\",\n        recurrent_activation=\"sigmoid\",\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        recurrent_initializer=\"orthogonal\",\n        bias_initializer=\"zeros\",\n        unit_forget_bias=True,\n        kernel_regularizer=None,\n        recurrent_regularizer=None,\n        bias_regularizer=None,\n        kernel_constraint=None,\n        recurrent_constraint=None,\n        bias_constraint=None,\n        dropout=0.0,\n        recurrent_dropout=0.0,\n        seed=None,\n        **kwargs,\n    ):\n        if units <= 0:\n            raise ValueError(\n                \"Received an invalid value for argument `units`, \"\n                f\"expected a positive integer, got {units}.\"\n            )\n        implementation = kwargs.pop(\"implementation\", 2)\n        super().__init__(**kwargs)\n        self.implementation = implementation\n        self.units = units\n        self.activation = activations.get(activation)\n        self.recurrent_activation = activations.get(recurrent_activation)\n        self.use_bias = use_bias\n\n        self.kernel_initializer = initializers.get(kernel_initializer)\n        self.recurrent_initializer = initializers.get(recurrent_initializer)\n        self.bias_initializer = initializers.get(bias_initializer)\n\n        self.kernel_regularizer = regularizers.get(kernel_regularizer)\n        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)\n        self.bias_regularizer = regularizers.get(bias_regularizer)\n\n        self.kernel_constraint = constraints.get(kernel_constraint)\n        self.recurrent_constraint = constraints.get(recurrent_constraint)\n  
      self.bias_constraint = constraints.get(bias_constraint)\n\n        self.dropout = min(1.0, max(0.0, dropout))\n        self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))\n        if self.recurrent_dropout != 0.0:\n            self.implementation = 1\n        if self.implementation == 1:\n            self.dropout_mask_count = 4\n        self.seed = seed\n        self.seed_generator = backend.random.SeedGenerator(seed=seed)\n\n        self.unit_forget_bias = unit_forget_bias\n        self.state_size = [self.units, self.units]\n        self.output_size = self.units\n\n    def build(self, input_shape):\n        super().build(input_shape)\n        input_dim = input_shape[-1]\n        self.kernel = self.add_weight(\n            shape=(input_dim, self.units * 4),\n            name=\"kernel\",\n            initializer=self.kernel_initializer,\n            regularizer=self.kernel_regularizer,\n            constraint=self.kernel_constraint,\n        )\n        self.recurrent_kernel = self.add_weight(\n            shape=(self.units, self.units * 4),\n            name=\"recurrent_kernel\",\n            initializer=self.recurrent_initializer,\n            regularizer=self.recurrent_regularizer,\n            constraint=self.recurrent_constraint,\n        )\n\n        if self.use_bias:\n            if self.unit_forget_bias:\n\n                def bias_initializer(_, *args, **kwargs):\n                    return ops.concatenate(\n                        [\n                            self.bias_initializer(\n                                (self.units,), *args, **kwargs\n                            ),\n                            initializers.get(\"ones\")(\n                                (self.units,), *args, **kwargs\n                            ),\n                            self.bias_initializer(\n                                (self.units * 2,), *args, **kwargs\n                            ),\n                        ]\n                    )\n\n            
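# Note: with `unit_forget_bias`, the bias vector above is assembled as\n            # [zeros(units), ones(units), zeros(2 * units)], matching the\n            # (i, f, c, o) gate order used when splitting the kernel, so only\n            # the forget gate bias starts at 1.\n            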
else:\n                bias_initializer = self.bias_initializer\n            self.bias = self.add_weight(\n                shape=(self.units * 4,),\n                name=\"bias\",\n                initializer=bias_initializer,\n                regularizer=self.bias_regularizer,\n                constraint=self.bias_constraint,\n            )\n        else:\n            self.bias = None\n\n    def _compute_carry_and_output(self, x, h_tm1, c_tm1):\n        \"\"\"Computes carry and output using split kernels.\"\"\"\n        x_i, x_f, x_c, x_o = x\n        h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1\n        i = self.recurrent_activation(\n            x_i + ops.matmul(h_tm1_i, self.recurrent_kernel[:, : self.units])\n        )\n        f = self.recurrent_activation(\n            x_f\n            + ops.matmul(\n                h_tm1_f, self.recurrent_kernel[:, self.units : self.units * 2]\n            )\n        )\n        c = f * c_tm1 + i * self.activation(\n            x_c\n            + ops.matmul(\n                h_tm1_c,\n                self.recurrent_kernel[:, self.units * 2 : self.units * 3],\n            )\n        )\n        o = self.recurrent_activation(\n            x_o\n            + ops.matmul(h_tm1_o, self.recurrent_kernel[:, self.units * 3 :])\n        )\n        return c, o\n\n    def _compute_carry_and_output_fused(self, z, c_tm1):\n        \"\"\"Computes carry and output using fused kernels.\"\"\"\n        z0, z1, z2, z3 = z\n        i = self.recurrent_activation(z0)\n        f = self.recurrent_activation(z1)\n        c = f * c_tm1 + i * self.activation(z2)\n        o = self.recurrent_activation(z3)\n        return c, o\n\n    def call(self, inputs, states, training=False):\n        h_tm1 = states[0]  # previous memory state\n        c_tm1 = states[1]  # previous carry state\n\n        if self.implementation == 1:\n            if training and 0.0 < self.dropout < 1.0:\n                dp_mask = self.get_dropout_mask(inputs)\n                inputs_i 
= inputs * dp_mask[0]\n                inputs_f = inputs * dp_mask[1]\n                inputs_c = inputs * dp_mask[2]\n                inputs_o = inputs * dp_mask[3]\n            else:\n                inputs_i = inputs\n                inputs_f = inputs\n                inputs_c = inputs\n                inputs_o = inputs\n            k_i, k_f, k_c, k_o = ops.split(self.kernel, 4, axis=1)\n            x_i = ops.matmul(inputs_i, k_i)\n            x_f = ops.matmul(inputs_f, k_f)\n            x_c = ops.matmul(inputs_c, k_c)\n            x_o = ops.matmul(inputs_o, k_o)\n            if self.use_bias:\n                b_i, b_f, b_c, b_o = ops.split(self.bias, 4, axis=0)\n                x_i += b_i\n                x_f += b_f\n                x_c += b_c\n                x_o += b_o\n\n            if training and 0.0 < self.recurrent_dropout < 1.0:\n                rec_dp_mask = self.get_recurrent_dropout_mask(h_tm1)\n                h_tm1_i = h_tm1 * rec_dp_mask[0]\n                h_tm1_f = h_tm1 * rec_dp_mask[1]\n                h_tm1_c = h_tm1 * rec_dp_mask[2]\n                h_tm1_o = h_tm1 * rec_dp_mask[3]\n            else:\n                h_tm1_i = h_tm1\n                h_tm1_f = h_tm1\n                h_tm1_c = h_tm1\n                h_tm1_o = h_tm1\n            x = (x_i, x_f, x_c, x_o)\n            h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)\n            c, o = self._compute_carry_and_output(x, h_tm1, c_tm1)\n        else:\n            if training and 0.0 < self.dropout < 1.0:\n                dp_mask = self.get_dropout_mask(inputs)\n                inputs = inputs * dp_mask\n\n            z = ops.matmul(inputs, self.kernel)\n\n            z = ops.add(z, ops.matmul(h_tm1, self.recurrent_kernel))\n            if self.use_bias:\n                z = ops.add(z, self.bias)\n\n            z = ops.split(z, 4, axis=1)\n            c, o = self._compute_carry_and_output_fused(z, c_tm1)\n\n        h = o * self.activation(c)\n        return h, [h, c]\n\n    def 
get_config(self):\n        config = {\n            \"units\": self.units,\n            \"activation\": activations.serialize(self.activation),\n            \"recurrent_activation\": activations.serialize(\n                self.recurrent_activation\n            ),\n            \"use_bias\": self.use_bias,\n            \"unit_forget_bias\": self.unit_forget_bias,\n            \"kernel_initializer\": initializers.serialize(\n                self.kernel_initializer\n            ),\n            \"recurrent_initializer\": initializers.serialize(\n                self.recurrent_initializer\n            ),\n            \"bias_initializer\": initializers.serialize(self.bias_initializer),\n            \"kernel_regularizer\": regularizers.serialize(\n                self.kernel_regularizer\n            ),\n            \"recurrent_regularizer\": regularizers.serialize(\n                self.recurrent_regularizer\n            ),\n            \"bias_regularizer\": regularizers.serialize(self.bias_regularizer),\n            \"kernel_constraint\": constraints.serialize(self.kernel_constraint),\n            \"recurrent_constraint\": constraints.serialize(\n                self.recurrent_constraint\n            ),\n            \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            \"dropout\": self.dropout,\n            \"recurrent_dropout\": self.recurrent_dropout,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    def get_initial_state(self, batch_size=None):\n        return [\n            ops.zeros((batch_size, d), dtype=self.compute_dtype)\n            for d in self.state_size\n        ]\n\n\n@keras_export(\"keras.layers.LSTM\")\nclass LSTM(RNN):\n    \"\"\"Long Short-Term Memory layer - Hochreiter 1997.\n\n    Based on available runtime hardware and constraints, this layer\n    will choose different implementations (cuDNN-based or backend-native)\n    to maximize the 
performance. If a GPU is available and all\n    the arguments to the layer meet the requirements of the cuDNN kernel\n    (see below for details), the layer will use a fast cuDNN implementation\n    when using the TensorFlow backend.\n    The requirements to use the cuDNN implementation are:\n\n    1. `activation` == `tanh`\n    2. `recurrent_activation` == `sigmoid`\n    3. `recurrent_dropout` == 0\n    4. `unroll` is `False`\n    5. `use_bias` is `True`\n    6. Inputs, if masking is used, are strictly right-padded.\n    7. Eager execution is enabled in the outermost context.\n\n    For example:\n\n    >>> inputs = np.random.random((32, 10, 8))\n    >>> lstm = keras.layers.LSTM(4)\n    >>> output = lstm(inputs)\n    >>> output.shape\n    (32, 4)\n    >>> lstm = keras.layers.LSTM(\n    ...     4, return_sequences=True, return_state=True)\n    >>> whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)\n    >>> whole_seq_output.shape\n    (32, 10, 4)\n    >>> final_memory_state.shape\n    (32, 4)\n    >>> final_carry_state.shape\n    (32, 4)\n\n    Args:\n        units: Positive integer, dimensionality of the output space.\n        activation: Activation function to use.\n            Default: hyperbolic tangent (`tanh`).\n            If you pass `None`, no activation is applied\n            (i.e. \"linear\" activation: `a(x) = x`).\n        recurrent_activation: Activation function to use\n            for the recurrent step.\n            Default: sigmoid (`sigmoid`).\n            If you pass `None`, no activation is applied\n            (i.e. \"linear\" activation: `a(x) = x`).\n        use_bias: Boolean (default `True`), whether the layer\n            should use a bias vector.\n        kernel_initializer: Initializer for the `kernel` weights matrix,\n            used for the linear transformation of the inputs. 
Default:\n            `\"glorot_uniform\"`.\n        recurrent_initializer: Initializer for the `recurrent_kernel`\n            weights matrix, used for the linear transformation of the recurrent\n            state. Default: `\"orthogonal\"`.\n        bias_initializer: Initializer for the bias vector. Default: `\"zeros\"`.\n        unit_forget_bias: Boolean (default `True`). If `True`,\n            add 1 to the bias of the forget gate at initialization.\n            Setting it to `True` will also force `bias_initializer=\"zeros\"`.\n            This is recommended in [Jozefowicz et al.](\n            https://github.com/mlresearch/v37/blob/gh-pages/jozefowicz15.pdf)\n        kernel_regularizer: Regularizer function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_regularizer: Regularizer function applied to the\n            `recurrent_kernel` weights matrix. Default: `None`.\n        bias_regularizer: Regularizer function applied to the bias vector.\n            Default: `None`.\n        activity_regularizer: Regularizer function applied to the output of the\n            layer (its \"activation\"). Default: `None`.\n        kernel_constraint: Constraint function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_constraint: Constraint function applied to the\n            `recurrent_kernel` weights matrix. Default: `None`.\n        bias_constraint: Constraint function applied to the bias vector.\n            Default: `None`.\n        dropout: Float between 0 and 1. Fraction of the units to drop for the\n            linear transformation of the inputs. Default: 0.\n        recurrent_dropout: Float between 0 and 1. Fraction of the units to drop\n            for the linear transformation of the recurrent state. Default: 0.\n        seed: Random seed for dropout.\n        return_sequences: Boolean. Whether to return the last output\n            in the output sequence, or the full sequence. 
Default: `False`.\n        return_state: Boolean. Whether to return the last state in addition\n            to the output. Default: `False`.\n        go_backwards: Boolean (default: `False`).\n            If `True`, process the input sequence backwards and return the\n            reversed sequence.\n        stateful: Boolean (default: `False`). If `True`, the last state\n            for each sample at index i in a batch will be used as initial\n            state for the sample of index i in the following batch.\n        unroll: Boolean (default `False`).\n            If `True`, the network will be unrolled,\n            else a symbolic loop will be used.\n            Unrolling can speed up an RNN,\n            although it tends to be more memory-intensive.\n            Unrolling is only suitable for short sequences.\n        use_cudnn: Whether to use a cuDNN-backed implementation. `\"auto\"` will\n            attempt to use cuDNN when feasible, and will fall back to the\n            default implementation if not.\n\n    Call arguments:\n        inputs: A 3D tensor, with shape `(batch, timesteps, feature)`.\n        mask: Binary tensor of shape `(samples, timesteps)` indicating whether\n            a given timestep should be masked (optional).\n            An individual `True` entry indicates that the corresponding timestep\n            should be utilized, while a `False` entry indicates that the\n            corresponding timestep should be ignored. Defaults to `None`.\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode. This argument is passed to the\n            cell when calling it. This is only relevant if `dropout` or\n            `recurrent_dropout` is used (optional). Defaults to `None`.\n        initial_state: List of initial state tensors to be passed to the first\n            call of the cell (optional, `None` causes creation\n            of zero-filled initial state tensors). 
Defaults to `None`.\n    \"\"\"\n\n    def __init__(\n        self,\n        units,\n        activation=\"tanh\",\n        recurrent_activation=\"sigmoid\",\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        recurrent_initializer=\"orthogonal\",\n        bias_initializer=\"zeros\",\n        unit_forget_bias=True,\n        kernel_regularizer=None,\n        recurrent_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        recurrent_constraint=None,\n        bias_constraint=None,\n        dropout=0.0,\n        recurrent_dropout=0.0,\n        seed=None,\n        return_sequences=False,\n        return_state=False,\n        go_backwards=False,\n        stateful=False,\n        unroll=False,\n        use_cudnn=\"auto\",\n        **kwargs,\n    ):\n        cell = LSTMCell(\n            units,\n            activation=activation,\n            recurrent_activation=recurrent_activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            unit_forget_bias=unit_forget_bias,\n            recurrent_initializer=recurrent_initializer,\n            bias_initializer=bias_initializer,\n            kernel_regularizer=kernel_regularizer,\n            recurrent_regularizer=recurrent_regularizer,\n            bias_regularizer=bias_regularizer,\n            kernel_constraint=kernel_constraint,\n            recurrent_constraint=recurrent_constraint,\n            bias_constraint=bias_constraint,\n            dropout=dropout,\n            recurrent_dropout=recurrent_dropout,\n            dtype=kwargs.get(\"dtype\", None),\n            trainable=kwargs.get(\"trainable\", True),\n            name=\"lstm_cell\",\n            seed=seed,\n            implementation=kwargs.pop(\"implementation\", 2),\n        )\n        super().__init__(\n            cell,\n            return_sequences=return_sequences,\n            return_state=return_state,\n           
 go_backwards=go_backwards,\n            stateful=stateful,\n            unroll=unroll,\n            activity_regularizer=activity_regularizer,\n            **kwargs,\n        )\n        self.input_spec = InputSpec(ndim=3)\n        if use_cudnn not in (\"auto\", True, False):\n            raise ValueError(\n                \"Invalid value received for argument `use_cudnn`. \"\n                \"Expected one of {'auto', True, False}. \"\n                f\"Received: use_cudnn={use_cudnn}\"\n            )\n        self.use_cudnn = use_cudnn\n        if (\n            backend.backend() == \"tensorflow\"\n            and backend.cudnn_ok(\n                cell.activation,\n                cell.recurrent_activation,\n                self.unroll,\n                cell.use_bias,\n            )\n            and use_cudnn in (True, \"auto\")\n        ):\n            self.supports_jit = False\n\n    def inner_loop(self, sequences, initial_state, mask, training=False):\n        if tree.is_nested(mask):\n            mask = mask[0]\n\n        if self.use_cudnn in (\"auto\", True):\n            if not self.recurrent_dropout:\n                try:\n                    if training and self.dropout:\n                        dp_mask = self.cell.get_dropout_mask(sequences[:, 0, :])\n                        dp_mask = ops.expand_dims(dp_mask, axis=1)\n                        dp_mask = ops.broadcast_to(\n                            dp_mask, ops.shape(sequences)\n                        )\n                        dp_sequences = sequences * dp_mask\n                    else:\n                        dp_sequences = sequences\n\n                    # Backends are allowed to specify (optionally) optimized\n                    # implementation of the inner LSTM loop. 
In the case of\n                    # TF for instance, it will leverage cuDNN when feasible, and\n                    # it will raise NotImplementedError otherwise.\n                    out = backend.lstm(\n                        dp_sequences,\n                        initial_state[0],\n                        initial_state[1],\n                        mask,\n                        kernel=self.cell.kernel,\n                        recurrent_kernel=self.cell.recurrent_kernel,\n                        bias=self.cell.bias,\n                        activation=self.cell.activation,\n                        recurrent_activation=self.cell.recurrent_activation,\n                        return_sequences=self.return_sequences,\n                        go_backwards=self.go_backwards,\n                        unroll=self.unroll,\n                    )\n                    # We disable jit_compile for the model in this case,\n                    # since cuDNN ops aren't XLA compatible.\n                    if backend.backend() == \"tensorflow\":\n                        self.supports_jit = False\n                    return out\n                except NotImplementedError:\n                    pass\n        if self.use_cudnn is True:\n            raise ValueError(\n                \"use_cudnn=True was specified, \"\n                \"but cuDNN is not supported for this layer configuration \"\n                \"with this backend. 
Pass use_cudnn='auto' to fall back \"\n                \"to a non-cuDNN implementation.\"\n            )\n        return super().inner_loop(\n            sequences, initial_state, mask=mask, training=training\n        )\n\n    def call(self, sequences, initial_state=None, mask=None, training=False):\n        return super().call(\n            sequences, mask=mask, training=training, initial_state=initial_state\n        )\n\n    @property\n    def units(self):\n        return self.cell.units\n\n    @property\n    def activation(self):\n        return self.cell.activation\n\n    @property\n    def recurrent_activation(self):\n        return self.cell.recurrent_activation\n\n    @property\n    def use_bias(self):\n        return self.cell.use_bias\n\n    @property\n    def unit_forget_bias(self):\n        return self.cell.unit_forget_bias\n\n    @property\n    def kernel_initializer(self):\n        return self.cell.kernel_initializer\n\n    @property\n    def recurrent_initializer(self):\n        return self.cell.recurrent_initializer\n\n    @property\n    def bias_initializer(self):\n        return self.cell.bias_initializer\n\n    @property\n    def kernel_regularizer(self):\n        return self.cell.kernel_regularizer\n\n    @property\n    def recurrent_regularizer(self):\n        return self.cell.recurrent_regularizer\n\n    @property\n    def bias_regularizer(self):\n        return self.cell.bias_regularizer\n\n    @property\n    def kernel_constraint(self):\n        return self.cell.kernel_constraint\n\n    @property\n    def recurrent_constraint(self):\n        return self.cell.recurrent_constraint\n\n    @property\n    def bias_constraint(self):\n        return self.cell.bias_constraint\n\n    @property\n    def dropout(self):\n        return self.cell.dropout\n\n    @property\n    def recurrent_dropout(self):\n        return self.cell.recurrent_dropout\n\n    def get_config(self):\n        config = {\n            \"units\": self.units,\n            
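# Sub-objects (activations, initializers, regularizers, constraints)\n            # are stored in serialized form so the config stays\n            # JSON-serializable and `from_config()` can rebuild the layer.\n            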
\"activation\": activations.serialize(self.activation),\n            \"recurrent_activation\": activations.serialize(\n                self.recurrent_activation\n            ),\n            \"use_bias\": self.use_bias,\n            \"kernel_initializer\": initializers.serialize(\n                self.kernel_initializer\n            ),\n            \"recurrent_initializer\": initializers.serialize(\n                self.recurrent_initializer\n            ),\n            \"bias_initializer\": initializers.serialize(self.bias_initializer),\n            \"unit_forget_bias\": self.unit_forget_bias,\n            \"kernel_regularizer\": regularizers.serialize(\n                self.kernel_regularizer\n            ),\n            \"recurrent_regularizer\": regularizers.serialize(\n                self.recurrent_regularizer\n            ),\n            \"bias_regularizer\": regularizers.serialize(self.bias_regularizer),\n            \"activity_regularizer\": regularizers.serialize(\n                self.activity_regularizer\n            ),\n            \"kernel_constraint\": constraints.serialize(self.kernel_constraint),\n            \"recurrent_constraint\": constraints.serialize(\n                self.recurrent_constraint\n            ),\n            \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            \"dropout\": self.dropout,\n            \"recurrent_dropout\": self.recurrent_dropout,\n            \"seed\": self.cell.seed,\n        }\n        base_config = super().get_config()\n        del base_config[\"cell\"]\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n"
  },
  {
    "path": "keras/src/layers/rnn/lstm_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass LSTMTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basics(self):\n        self.run_layer_test(\n            layers.LSTM,\n            init_kwargs={\"units\": 3, \"dropout\": 0.5},\n            input_shape=(3, 2, 4),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(3, 3),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.LSTM,\n            init_kwargs={\"units\": 3, \"dropout\": 0.5, \"recurrent_dropout\": 0.5},\n            input_shape=(3, 2, 4),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(3, 3),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.LSTM,\n            init_kwargs={\n                \"units\": 3,\n                \"return_sequences\": True,\n                \"bias_regularizer\": \"l1\",\n                \"kernel_regularizer\": \"l2\",\n                \"recurrent_regularizer\": \"l2\",\n            },\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 2, 3),\n            expected_num_losses=3,\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n\n    @parameterized.parameters([1, 2])\n    def test_correctness(self, implementation):\n        sequence = np.arange(72).reshape((3, 6, 4)).astype(\"float32\")\n        layer = layers.LSTM(\n            3,\n            kernel_initializer=initializers.Constant(0.01),\n            
recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            implementation=implementation,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.6288687, 0.6288687, 0.6288687],\n                    [0.86899155, 0.86899155, 0.86899155],\n                    [0.9460773, 0.9460773, 0.9460773],\n                ]\n            ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.LSTM(\n            3,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            go_backwards=True,\n            implementation=implementation,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.35622165, 0.35622165, 0.35622165],\n                    [0.74789524, 0.74789524, 0.74789524],\n                    [0.8872726, 0.8872726, 0.8872726],\n                ]\n            ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.LSTM(\n            3,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            unroll=True,\n            implementation=implementation,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.6288687, 0.6288687, 0.6288687],\n                    [0.86899155, 0.86899155, 0.86899155],\n                    [0.9460773, 0.9460773, 0.9460773],\n                ]\n       
     ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.LSTM(\n            3,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            unit_forget_bias=False,\n            implementation=implementation,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.57019705, 0.57019705, 0.57019705],\n                    [0.8661914, 0.8661914, 0.8661914],\n                    [0.9459622, 0.9459622, 0.9459622],\n                ]\n            ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.LSTM(\n            3,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            use_bias=False,\n            implementation=implementation,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.54986924, 0.54986924, 0.54986924],\n                    [0.86226785, 0.86226785, 0.86226785],\n                    [0.9443936, 0.9443936, 0.9443936],\n                ]\n            ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_statefulness(self):\n        sequence = np.arange(24).reshape((2, 3, 4)).astype(\"float32\")\n        layer = layers.LSTM(\n            4,\n            stateful=True,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n        
    bias_initializer=initializers.Constant(0.03),\n        )\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.3124785, 0.3124785, 0.3124785, 0.3124785],\n                    [0.6863672, 0.6863672, 0.6863672, 0.6863672],\n                ]\n            ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        layer.reset_state()\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.3124785, 0.3124785, 0.3124785, 0.3124785],\n                    [0.6863672, 0.6863672, 0.6863672, 0.6863672],\n                ]\n            ),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_pass_initial_state(self):\n        sequence = np.arange(24).reshape((2, 4, 3)).astype(\"float32\")\n        initial_state = [\n            np.arange(4).reshape((2, 2)).astype(\"float32\"),\n            np.arange(4).reshape((2, 2)).astype(\"float32\"),\n        ]\n        layer = layers.LSTM(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        output = layer(sequence, initial_state=initial_state)\n        self.assertAllClose(\n            np.array([[0.20574439, 0.3558822], [0.64930826, 0.66276]]),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.LSTM(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            
bias_initializer=initializers.Constant(0.03),\n            go_backwards=True,\n        )\n        output = layer(sequence, initial_state=initial_state)\n        self.assertAllClose(\n            np.array([[0.13281618, 0.2790356], [0.5839337, 0.5992567]]),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_masking(self):\n        sequence = np.arange(24).reshape((2, 4, 3)).astype(\"float32\")\n        mask = np.array([[True, True, False, True], [True, False, False, True]])\n        layer = layers.LSTM(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            unroll=True,\n        )\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array([[0.1524914, 0.1524914], [0.35969394, 0.35969394]]),\n            output,\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.LSTM(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            return_sequences=True,\n        )\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.0158891, 0.0158891],\n                    [0.05552047, 0.05552047],\n                    [0.05552047, 0.05552047],\n                    [0.1524914, 0.1524914],\n                ],\n            ),\n            output[0],\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.14185596, 
0.14185596],\n                    [0.14185596, 0.14185596],\n                    [0.14185596, 0.14185596],\n                    [0.35969394, 0.35969394],\n                ],\n            ),\n            output[1],\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.LSTM(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            return_sequences=True,\n            zero_output_for_mask=True,\n        )\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.0158891, 0.0158891],\n                    [0.05552047, 0.05552047],\n                    [0.0, 0.0],\n                    [0.1524914, 0.1524914],\n                ],\n            ),\n            output[0],\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.14185596, 0.14185596],\n                    [0.0, 0.0],\n                    [0.0, 0.0],\n                    [0.35969394, 0.35969394],\n                ],\n            ),\n            output[1],\n            atol=1e-5,\n            rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.LSTM(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            go_backwards=True,\n        )\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array([[0.10056866, 0.10056866], [0.31006062, 0.31006062]]),\n            output,\n            atol=1e-5,\n          
  rtol=1e-5,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n"
  },
  {
    "path": "keras/src/layers/rnn/rnn.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell\nfrom keras.src.layers.rnn.stacked_rnn_cells import StackedRNNCells\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils import tracking\n\n\n@keras_export(\"keras.layers.RNN\")\nclass RNN(Layer):\n    \"\"\"Base class for recurrent layers.\n\n    Args:\n        cell: A RNN cell instance or a list of RNN cell instances.\n            A RNN cell is a class that has:\n            - A `call(input_at_t, states_at_t)` method, returning\n            `(output_at_t, states_at_t_plus_1)`. The call method of the\n            cell can also take the optional argument `constants`, see\n            section \"Note on passing external constants\" below.\n            - A `state_size` attribute. This can be a single integer\n            (single state) in which case it is the size of the recurrent\n            state. This can also be a list/tuple of integers\n            (one size per state).\n            - A `output_size` attribute, a single integer.\n            - A `get_initial_state(batch_size=None)`\n            method that creates a tensor meant to be fed to `call()` as the\n            initial state, if the user didn't specify any initial state\n            via other means. 
The returned initial state should have\n            shape `(batch_size, cell.state_size)`.\n            The cell might choose to create a tensor full of zeros,\n            or other values based on the cell's implementation.\n            `inputs` is the input tensor to the RNN layer, with shape\n            `(batch_size, timesteps, features)`.\n            If this method is not implemented\n            by the cell, the RNN layer will create a zero-filled tensor\n            with shape `(batch_size, cell.state_size)`.\n            In the case that `cell` is a list of RNN cell instances, the cells\n            will be stacked on top of each other in the RNN, resulting in an\n            efficient stacked RNN.\n        return_sequences: Boolean (default `False`). Whether to return the last\n            output in the output sequence, or the full sequence.\n        return_state: Boolean (default `False`).\n            Whether to return the last state in addition to the output.\n        go_backwards: Boolean (default `False`).\n            If `True`, process the input sequence backwards and return the\n            reversed sequence.\n        stateful: Boolean (default `False`). If `True`, the last state\n            for each sample at index `i` in a batch will be used as initial\n            state for the sample of index `i` in the following batch.\n        unroll: Boolean (default `False`).\n            If `True`, the network will be unrolled, else a symbolic loop\n            will be used. Unrolling can speed up an RNN, although it tends to\n            be more memory-intensive. 
Unrolling is only suitable for short sequences.\n        zero_output_for_mask: Boolean (default `False`).\n            Whether the output should use zeros for the masked timesteps.\n            Note that this field is only used when `return_sequences`\n            is `True` and `mask` is provided.\n            It can be useful if you want to reuse the raw output sequence of\n            the RNN without interference from the masked timesteps, e.g.,\n            merging bidirectional RNNs.\n\n    Call arguments:\n        sequences: A 3-D tensor with shape `(batch_size, timesteps, features)`.\n        initial_state: List of initial state tensors to be passed to the first\n            call of the cell.\n        mask: Binary tensor of shape `[batch_size, timesteps]`\n            indicating whether a given timestep should be masked.\n            An individual `True` entry indicates that the corresponding\n            timestep should be utilized, while a `False` entry indicates\n            that the corresponding timestep should be ignored.\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode. This argument is passed\n            to the cell when calling it.\n            This is for use with cells that use dropout.\n\n    Output shape:\n\n    - If `return_state`: a list of tensors. The first tensor is\n    the output. The remaining tensors are the last states,\n    each with shape `(batch_size, state_size)`, where `state_size` could\n    be a high-dimensional tensor shape.\n    - If `return_sequences`: 3D tensor with shape\n    `(batch_size, timesteps, output_size)`.\n\n    Masking:\n\n    This layer supports masking for input data with a variable number\n    of timesteps. 
To introduce masks to your data,\n    use a `keras.layers.Embedding` layer with the `mask_zero` parameter\n    set to `True`.\n\n    Note on using statefulness in RNNs:\n\n    You can set RNN layers to be 'stateful', which means that the states\n    computed for the samples in one batch will be reused as initial states\n    for the samples in the next batch. This assumes a one-to-one mapping\n    between samples in different successive batches.\n\n    To enable statefulness:\n\n    - Specify `stateful=True` in the layer constructor.\n    - Specify a fixed batch size for your model, by passing\n        `batch_size=...` to the `Input` layer(s) of your model.\n        Remember to also specify the same `batch_size=...` when\n        calling `fit()`, or otherwise use a generator-like\n        data source like a `keras.utils.PyDataset` or a\n        `tf.data.Dataset`.\n    - Specify `shuffle=False` when calling `fit()`, since your\n        batches are expected to be temporally ordered.\n\n    To reset the states of your model, call `.reset_state()` on either\n    a specific layer, or on your entire model.\n\n    Note on specifying the initial state of RNNs:\n\n    You can specify the initial state of RNN layers symbolically by\n    calling them with the keyword argument `initial_state`. The value of\n    `initial_state` should be a tensor or list of tensors representing\n    the initial state of the RNN layer.\n\n    You can specify the initial state of RNN layers numerically by\n    calling `reset_state()` with the keyword argument `states`. 
The value of\n    `states` should be a numpy array or list of numpy arrays representing\n    the initial state of the RNN layer.\n\n    Examples:\n\n    ```python\n    import keras\n    from keras import ops\n    from keras.layers import RNN\n\n    # First, let's define an RNN cell, as a layer subclass.\n    class MinimalRNNCell(keras.Layer):\n\n        def __init__(self, units, **kwargs):\n            super().__init__(**kwargs)\n            self.units = units\n            self.state_size = units\n\n        def build(self, input_shape):\n            self.kernel = self.add_weight(shape=(input_shape[-1], self.units),\n                                          initializer='uniform',\n                                          name='kernel')\n            self.recurrent_kernel = self.add_weight(\n                shape=(self.units, self.units),\n                initializer='uniform',\n                name='recurrent_kernel')\n\n        def call(self, inputs, states):\n            prev_output = states[0]\n            h = ops.matmul(inputs, self.kernel)\n            output = h + ops.matmul(prev_output, self.recurrent_kernel)\n            return output, [output]\n\n    # Let's use this cell in an RNN layer:\n\n    cell = MinimalRNNCell(32)\n    x = keras.Input((None, 5))\n    layer = RNN(cell)\n    y = layer(x)\n\n    # Here's how to use the cell to build a stacked RNN:\n\n    cells = [MinimalRNNCell(32), MinimalRNNCell(64)]\n    x = keras.Input((None, 5))\n    layer = RNN(cells)\n    y = layer(x)\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        cell,\n        return_sequences=False,\n        return_state=False,\n        go_backwards=False,\n        stateful=False,\n        unroll=False,\n        zero_output_for_mask=False,\n        **kwargs,\n    ):\n        if isinstance(cell, (list, tuple)):\n            cell = StackedRNNCells(cell)\n        if \"call\" not in dir(cell):\n            raise ValueError(\n                \"Argument `cell` should have a `call` method. 
\"\n                f\"Received: cell={cell}\"\n            )\n        if \"state_size\" not in dir(cell):\n            raise ValueError(\n                \"The RNN cell should have a `state_size` attribute \"\n                \"(single integer or list of integers, \"\n                \"one integer per RNN state). \"\n                f\"Received: cell={cell}\"\n            )\n        super().__init__(**kwargs)\n\n        # If True, the output for masked timestep will be zeros, whereas in the\n        # False case, output from previous timestep is returned for masked\n        # timestep.\n        self.zero_output_for_mask = zero_output_for_mask\n        self.cell = cell\n        self.return_sequences = return_sequences\n        self.return_state = return_state\n        self.go_backwards = go_backwards\n        self.stateful = stateful\n        self.unroll = unroll\n\n        self.supports_masking = True\n        self.input_spec = None\n        self.states = None\n        self._expected_batch_size = None\n\n        state_size = getattr(self.cell, \"state_size\", None)\n        if state_size is None:\n            raise ValueError(\n                \"state_size must be specified as property on the RNN cell.\"\n            )\n        if not isinstance(state_size, (list, tuple, int)):\n            raise ValueError(\n                \"state_size must be an integer, or a list/tuple of integers \"\n                \"(one for each state tensor).\"\n            )\n        if isinstance(state_size, int):\n            self.state_size = [state_size]\n            self.single_state = True\n        else:\n            self.state_size = list(state_size)\n            self.single_state = False\n\n    def compute_output_shape(self, sequences_shape, initial_state_shape=None):\n        batch_size = sequences_shape[0]\n        length = sequences_shape[1]\n        states_shape = []\n        for state_size in self.state_size:\n            if isinstance(state_size, int):\n                
states_shape.append((batch_size, state_size))\n            elif isinstance(state_size, (list, tuple)):\n                states_shape.append([(batch_size, s) for s in state_size])\n\n        output_size = getattr(self.cell, \"output_size\", None)\n        if output_size is None:\n            output_size = self.state_size[0]\n        if not isinstance(output_size, int):\n            raise ValueError(\"output_size must be an integer.\")\n        if self.return_sequences:\n            output_shape = (batch_size, length, output_size)\n        else:\n            output_shape = (batch_size, output_size)\n        if self.return_state:\n            return output_shape, *states_shape\n        return output_shape\n\n    def compute_mask(self, _, mask):\n        # Time step masks must be the same for each input.\n        # This is because the mask for an RNN is of size [batch, time_steps, 1],\n        # and specifies which time steps should be skipped, and a time step\n        # must be skipped for all inputs.\n        mask = tree.flatten(mask)[0]\n        output_mask = mask if self.return_sequences else None\n        if self.return_state:\n            state_mask = [None for _ in self.state_size]\n            return [output_mask] + state_mask\n        else:\n            return output_mask\n\n    def build(self, sequences_shape, initial_state_shape=None):\n        # Build cell (if layer).\n        step_input_shape = (sequences_shape[0],) + tuple(sequences_shape[2:])\n        if isinstance(self.cell, Layer) and not self.cell.built:\n            self.cell.build(step_input_shape)\n            self.cell.built = True\n        if self.stateful:\n            if self.states is not None:\n                self.reset_state()\n            else:\n                if sequences_shape[0] is None:\n                    raise ValueError(\n                        \"When using `stateful=True` in a RNN, the \"\n                        \"batch size must be static. 
Found dynamic \"\n                        f\"batch size: sequence.shape={sequences_shape}\"\n                    )\n                self._create_state_variables(sequences_shape[0])\n                self._expected_batch_size = ops.shape(\n                    tree.flatten(self.states)[0]\n                )[0]\n\n    @tracking.no_automatic_dependency_tracking\n    def _create_state_variables(self, batch_size):\n        with backend.name_scope(self.name, caller=self):\n            self.states = tree.map_structure(\n                lambda value: backend.Variable(\n                    value,\n                    trainable=False,\n                    dtype=self.variable_dtype,\n                    name=\"rnn_state\",\n                ),\n                self.get_initial_state(batch_size),\n            )\n\n    def get_initial_state(self, batch_size):\n        get_initial_state_fn = getattr(self.cell, \"get_initial_state\", None)\n        if get_initial_state_fn:\n            init_state = get_initial_state_fn(batch_size=batch_size)\n        else:\n            return [\n                ops.zeros((batch_size, d), dtype=self.cell.compute_dtype)\n                for d in self.state_size\n            ]\n\n        # RNN expect the states in a list, even if single state.\n        if not tree.is_nested(init_state):\n            init_state = [init_state]\n        # Force the state to be a list in case it is a namedtuple eg\n        # LSTMStateTuple.\n        return list(init_state)\n\n    def reset_states(self):\n        # Compatibility alias.\n        self.reset_state()\n\n    def reset_state(self):\n        if self.states is not None:\n            for v in self.states:\n                v.assign(ops.zeros_like(v.value))\n\n    def inner_loop(self, sequences, initial_state, mask, training=False):\n        cell_kwargs = {}\n        if isinstance(self.cell, Layer) and self.cell._call_has_training_arg:\n            cell_kwargs[\"training\"] = training\n\n        def step(inputs, 
states):\n            # Create new tensor copies when using PyTorch backend\n            # with stateful=True. This prevents in-place modifications\n            # that would otherwise break PyTorch's autograd functionality\n            # by modifying tensors needed for gradient computation.\n            if backend.backend() == \"torch\" and self.stateful:\n                states = tree.map_structure(ops.copy, states)\n            output, new_states = self.cell(inputs, states, **cell_kwargs)\n            if not tree.is_nested(new_states):\n                new_states = [new_states]\n            return output, new_states\n\n        if not tree.is_nested(initial_state):\n            initial_state = [initial_state]\n\n        return backend.rnn(\n            step,\n            sequences,\n            initial_state,\n            go_backwards=self.go_backwards,\n            mask=mask,\n            unroll=self.unroll,\n            input_length=sequences.shape[1],\n            zero_output_for_mask=self.zero_output_for_mask,\n            return_all_outputs=self.return_sequences,\n        )\n\n    def call(\n        self,\n        sequences,\n        initial_state=None,\n        mask=None,\n        training=False,\n    ):\n        timesteps = sequences.shape[1]\n        if self.unroll and timesteps is None:\n            raise ValueError(\n                \"Cannot unroll a RNN if the \"\n                \"time dimension is undefined. 
\\n\"\n                \"- If using a Sequential model, \"\n                \"specify the time dimension by passing \"\n                \"an `Input()` as your first layer.\\n\"\n                \"- If using the functional API, specify \"\n                \"the time dimension by passing a `shape` \"\n                \"or `batch_shape` argument to your `Input()`.\"\n            )\n\n        if initial_state is None:\n            if self.stateful:\n                initial_state = self.states\n            else:\n                initial_state = self.get_initial_state(\n                    batch_size=ops.shape(sequences)[0]\n                )\n        if self.stateful:\n            actual_batch_size = sequences.shape[0]\n            if (\n                self._expected_batch_size is not None\n                and actual_batch_size is not None\n                and actual_batch_size != self._expected_batch_size\n            ):\n                raise ValueError(\n                    f\"If an RNN is stateful, the batch size of the \"\n                    f\"input sequences must be the same as the batch \"\n                    f\"size of the initial state. \\n\"\n                    f\"- Expected batch size: {self._expected_batch_size}\\n\"\n                    f\"- Received batch size: {actual_batch_size}\"\n                )\n\n        # RNN expect the states in a list, even if single state.\n        if not tree.is_nested(initial_state):\n            initial_state = [initial_state]\n        initial_state = list(initial_state)\n\n        # Cast states to compute dtype.\n        # Note that states may be deeply nested\n        # (e.g. 
in the stacked cells case).\n        initial_state = tree.map_structure(\n            lambda x: backend.convert_to_tensor(\n                x, dtype=self.cell.compute_dtype\n            ),\n            initial_state,\n        )\n\n        # Prepopulate the dropout state so that the inner_loop is stateless;\n        # this is particularly important for the JAX backend.\n        self._maybe_config_dropout_masks(\n            self.cell, sequences[:, 0, :], initial_state\n        )\n\n        last_output, outputs, states = self.inner_loop(\n            sequences=sequences,\n            initial_state=initial_state,\n            mask=mask,\n            training=training,\n        )\n        last_output = ops.cast(last_output, self.compute_dtype)\n        outputs = ops.cast(outputs, self.compute_dtype)\n        states = tree.map_structure(\n            lambda x: ops.cast(x, dtype=self.compute_dtype), states\n        )\n        self._maybe_reset_dropout_masks(self.cell)\n\n        if self.stateful:\n            for self_state, state in zip(\n                tree.flatten(self.states), tree.flatten(states)\n            ):\n                self_state.assign(state)\n\n        if self.return_sequences:\n            output = outputs\n        else:\n            output = last_output\n\n        if self.return_state:\n            return output, *states\n        return output\n\n    def _maybe_config_dropout_masks(self, cell, input_sequence, input_state):\n        state = (\n            input_state[0]\n            if isinstance(input_state, (list, tuple))\n            else input_state\n        )\n        if isinstance(cell, DropoutRNNCell):\n            cell.get_dropout_mask(input_sequence)\n            cell.get_recurrent_dropout_mask(state)\n        if isinstance(cell, StackedRNNCells):\n            for c, s in zip(cell.cells, input_state):\n                self._maybe_config_dropout_masks(c, input_sequence, s)\n                # Replicate the behavior of `StackedRNNCells.call` to 
compute\n                # the inputs for the next cell.\n                s = list(s) if tree.is_nested(s) else [s]\n                cell_call_fn = c.__call__ if callable(c) else c.call\n                input_sequence, _ = cell_call_fn(input_sequence, s)\n\n    def _maybe_reset_dropout_masks(self, cell):\n        if isinstance(cell, DropoutRNNCell):\n            cell.reset_dropout_mask()\n            cell.reset_recurrent_dropout_mask()\n        if isinstance(cell, StackedRNNCells):\n            for c in cell.cells:\n                self._maybe_reset_dropout_masks(c)\n\n    def get_config(self):\n        config = {\n            \"return_sequences\": self.return_sequences,\n            \"return_state\": self.return_state,\n            \"go_backwards\": self.go_backwards,\n            \"stateful\": self.stateful,\n            \"unroll\": self.unroll,\n            \"zero_output_for_mask\": self.zero_output_for_mask,\n        }\n        config[\"cell\"] = serialization_lib.serialize_keras_object(self.cell)\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config, custom_objects=None):\n        cell = serialization_lib.deserialize_keras_object(\n            config.pop(\"cell\"), custom_objects=custom_objects\n        )\n        layer = cls(cell, **config)\n        return layer\n"
  },
  {
    "path": "keras/src/layers/rnn/rnn_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\n\n\nclass OneStateRNNCell(layers.Layer):\n    def __init__(self, units, state_size=None, **kwargs):\n        super().__init__(**kwargs)\n        self.units = units\n        self.state_size = state_size if state_size else units\n\n    def build(self, input_shape):\n        self.kernel = self.add_weight(\n            shape=(input_shape[-1], self.units),\n            initializer=\"ones\",\n            name=\"kernel\",\n        )\n        self.recurrent_kernel = self.add_weight(\n            shape=(self.units, self.units),\n            initializer=\"ones\",\n            name=\"recurrent_kernel\",\n        )\n\n    def call(self, inputs, states):\n        prev_output = states[0]\n        h = ops.matmul(inputs, self.kernel)\n        output = h + ops.matmul(prev_output, self.recurrent_kernel)\n        return output, [output]\n\n\nclass TwoStatesRNNCell(layers.Layer):\n    def __init__(self, units, state_size=None, **kwargs):\n        super().__init__(**kwargs)\n        self.units = units\n        self.state_size = state_size if state_size else [units, units]\n        self.output_size = units\n\n    def build(self, input_shape):\n        self.kernel = self.add_weight(\n            shape=(input_shape[-1], self.units),\n            initializer=\"ones\",\n            name=\"kernel\",\n        )\n        self.recurrent_kernel_1 = self.add_weight(\n            shape=(self.units, self.units),\n            initializer=\"ones\",\n            name=\"recurrent_kernel_1\",\n        )\n        self.recurrent_kernel_2 = self.add_weight(\n            shape=(self.units, self.units),\n            initializer=\"ones\",\n            name=\"recurrent_kernel_2\",\n        )\n\n    def call(self, inputs, states):\n        prev_1 = states[0]\n        prev_2 = states[0]\n        h = ops.matmul(inputs, self.kernel)\n        output_1 = h + ops.matmul(prev_1, 
self.recurrent_kernel_1)\n        output_2 = h + ops.matmul(prev_2, self.recurrent_kernel_2)\n        output = output_1 + output_2\n        return output, [output_1, output_2]\n\n\nclass RNNTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basics(self):\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\"cell\": OneStateRNNCell(5, state_size=5)},\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 5),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\"cell\": OneStateRNNCell(5, state_size=[5])},\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 5),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\"cell\": OneStateRNNCell(5, state_size=(5,))},\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 5),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\"cell\": OneStateRNNCell(5), \"return_sequences\": True},\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 2, 5),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\n                \"cell\": 
OneStateRNNCell(5),\n                \"go_backwards\": True,\n                \"unroll\": True,\n            },\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 5),\n            expected_num_trainable_weights=2,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\"cell\": TwoStatesRNNCell(5, state_size=[5, 5])},\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 5),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\"cell\": TwoStatesRNNCell(5, state_size=(5, 5))},\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 5),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\"cell\": TwoStatesRNNCell(5), \"return_sequences\": True},\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 2, 5),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            supports_masking=True,\n        )\n\n    def test_compute_output_shape_single_state(self):\n        sequence = np.ones((3, 4, 5))\n        layer = layers.RNN(OneStateRNNCell(8), return_sequences=False)\n        output_shape = layer.compute_output_shape(sequence.shape)\n        self.assertEqual(output_shape, (3, 8))\n\n        layer = layers.RNN(OneStateRNNCell(8), return_sequences=True)\n        output_shape = 
layer.compute_output_shape(sequence.shape)\n        self.assertEqual(output_shape, (3, 4, 8))\n\n        layer = layers.RNN(\n            OneStateRNNCell(8), return_sequences=False, return_state=True\n        )\n        output_shape = layer.compute_output_shape(sequence.shape)\n        self.assertEqual(output_shape[0], (3, 8))\n        self.assertEqual(output_shape[1], (3, 8))\n\n        layer = layers.RNN(\n            OneStateRNNCell(8), return_sequences=True, return_state=True\n        )\n        output_shape = layer.compute_output_shape(sequence.shape)\n        self.assertEqual(output_shape[0], (3, 4, 8))\n        self.assertEqual(output_shape[1], (3, 8))\n\n    def test_compute_output_shape_two_states(self):\n        sequence = np.ones((3, 4, 5))\n        layer = layers.RNN(TwoStatesRNNCell(8), return_sequences=False)\n        output_shape = layer.compute_output_shape(sequence.shape)\n        self.assertEqual(output_shape, (3, 8))\n\n        layer = layers.RNN(TwoStatesRNNCell(8), return_sequences=True)\n        output_shape = layer.compute_output_shape(sequence.shape)\n        self.assertEqual(output_shape, (3, 4, 8))\n\n        layer = layers.RNN(\n            TwoStatesRNNCell(8), return_sequences=False, return_state=True\n        )\n        output_shape = layer.compute_output_shape(sequence.shape)\n        self.assertEqual(output_shape[0], (3, 8))\n        self.assertEqual(output_shape[1], (3, 8))\n        self.assertEqual(output_shape[2], (3, 8))\n\n        layer = layers.RNN(\n            TwoStatesRNNCell(8), return_sequences=True, return_state=True\n        )\n        output_shape = layer.compute_output_shape(sequence.shape)\n        self.assertEqual(output_shape[0], (3, 4, 8))\n        self.assertEqual(output_shape[1], (3, 8))\n        self.assertEqual(output_shape[2], (3, 8))\n\n    def test_dynamic_shapes(self):\n        sequence_shape = (None, None, 3)\n        layer = layers.RNN(OneStateRNNCell(8), return_sequences=False)\n        output_shape = 
layer.compute_output_shape(sequence_shape)\n        self.assertEqual(output_shape, (None, 8))\n\n        layer = layers.RNN(OneStateRNNCell(8), return_sequences=True)\n        output_shape = layer.compute_output_shape(sequence_shape)\n        self.assertEqual(output_shape, (None, None, 8))\n\n        layer = layers.RNN(\n            OneStateRNNCell(8), return_sequences=False, return_state=True\n        )\n        output_shape = layer.compute_output_shape(sequence_shape)\n        self.assertEqual(output_shape[0], (None, 8))\n        self.assertEqual(output_shape[1], (None, 8))\n\n        layer = layers.RNN(\n            OneStateRNNCell(8), return_sequences=True, return_state=True\n        )\n        output_shape = layer.compute_output_shape(sequence_shape)\n        self.assertEqual(output_shape[0], (None, None, 8))\n        self.assertEqual(output_shape[1], (None, 8))\n\n        layer = layers.RNN(TwoStatesRNNCell(8), return_sequences=False)\n        output_shape = layer.compute_output_shape(sequence_shape)\n        self.assertEqual(output_shape, (None, 8))\n\n        layer = layers.RNN(TwoStatesRNNCell(8), return_sequences=True)\n        output_shape = layer.compute_output_shape(sequence_shape)\n        self.assertEqual(output_shape, (None, None, 8))\n\n        layer = layers.RNN(\n            TwoStatesRNNCell(8), return_sequences=False, return_state=True\n        )\n        output_shape = layer.compute_output_shape(sequence_shape)\n        self.assertEqual(output_shape[0], (None, 8))\n        self.assertEqual(output_shape[1], (None, 8))\n        self.assertEqual(output_shape[2], (None, 8))\n\n        layer = layers.RNN(\n            TwoStatesRNNCell(8), return_sequences=True, return_state=True\n        )\n        output_shape = layer.compute_output_shape(sequence_shape)\n        self.assertEqual(output_shape[0], (None, None, 8))\n        self.assertEqual(output_shape[1], (None, 8))\n        self.assertEqual(output_shape[2], (None, 8))\n\n    def 
test_forward_pass_single_state(self):\n        sequence = np.ones((1, 2, 3))\n        layer = layers.RNN(OneStateRNNCell(2), return_sequences=False)\n        output = layer(sequence)\n        self.assertAllClose(np.array([[9.0, 9.0]]), output)\n\n        layer = layers.RNN(OneStateRNNCell(2), return_sequences=True)\n        output = layer(sequence)\n        self.assertAllClose(np.array([[[3.0, 3.0], [9.0, 9.0]]]), output)\n\n        layer = layers.RNN(\n            OneStateRNNCell(2), return_sequences=False, return_state=True\n        )\n        output, state = layer(sequence)\n        self.assertAllClose(np.array([[9.0, 9.0]]), output)\n        self.assertAllClose(np.array([[9.0, 9.0]]), state)\n\n        layer = layers.RNN(\n            OneStateRNNCell(2), return_sequences=True, return_state=True\n        )\n        output, state = layer(sequence)\n        self.assertAllClose(np.array([[[3.0, 3.0], [9.0, 9.0]]]), output)\n        self.assertAllClose(np.array([[9.0, 9.0]]), state)\n\n    def test_forward_pass_two_states(self):\n        sequence = np.ones((1, 2, 3))\n        layer = layers.RNN(TwoStatesRNNCell(2), return_sequences=False)\n        output = layer(sequence)\n        self.assertAllClose(np.array([[18.0, 18.0]]), output)\n\n        layer = layers.RNN(TwoStatesRNNCell(2), return_sequences=True)\n        output = layer(sequence)\n        self.assertAllClose(np.array([[[6.0, 6.0], [18.0, 18.0]]]), output)\n\n        layer = layers.RNN(\n            TwoStatesRNNCell(2), return_sequences=False, return_state=True\n        )\n        output, state1, state2 = layer(sequence)\n        self.assertAllClose(np.array([[18.0, 18.0]]), output)\n        self.assertAllClose(np.array([[9.0, 9.0]]), state1)\n        self.assertAllClose(np.array([[9.0, 9.0]]), state2)\n\n        layer = layers.RNN(\n            TwoStatesRNNCell(2), return_sequences=True, return_state=True\n        )\n        output, state1, state2 = layer(sequence)\n        
self.assertAllClose(np.array([[[6.0, 6.0], [18.0, 18.0]]]), output)\n        self.assertAllClose(np.array([[9.0, 9.0]]), state1)\n        self.assertAllClose(np.array([[9.0, 9.0]]), state2)\n\n    def test_passing_initial_state_single_state(self):\n        sequence = np.ones((2, 3, 2))\n        state = np.ones((2, 2))\n        layer = layers.RNN(OneStateRNNCell(2), return_sequences=False)\n        output = layer(sequence, initial_state=state)\n        self.assertAllClose(np.array([[22.0, 22.0], [22.0, 22.0]]), output)\n\n        layer = layers.RNN(\n            OneStateRNNCell(2), return_sequences=False, return_state=True\n        )\n        output, state = layer(sequence, initial_state=state)\n        self.assertAllClose(np.array([[22.0, 22.0], [22.0, 22.0]]), output)\n        self.assertAllClose(np.array([[22.0, 22.0], [22.0, 22.0]]), state)\n\n    def test_passing_initial_state_two_states(self):\n        sequence = np.ones((2, 3, 2))\n        state = [np.ones((2, 2)), np.ones((2, 2))]\n        layer = layers.RNN(TwoStatesRNNCell(2), return_sequences=False)\n        output = layer(sequence, initial_state=state)\n        self.assertAllClose(np.array([[44.0, 44.0], [44.0, 44.0]]), output)\n\n        layer = layers.RNN(\n            TwoStatesRNNCell(2), return_sequences=False, return_state=True\n        )\n        output, state_1, state_2 = layer(sequence, initial_state=state)\n        self.assertAllClose(np.array([[44.0, 44.0], [44.0, 44.0]]), output)\n        self.assertAllClose(np.array([[22.0, 22.0], [22.0, 22.0]]), state_1)\n        self.assertAllClose(np.array([[22.0, 22.0], [22.0, 22.0]]), state_2)\n\n    def test_statefulness_single_state(self):\n        sequence = np.ones((1, 2, 3))\n        layer = layers.RNN(OneStateRNNCell(2), stateful=True)\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(np.array([[45.0, 45.0]]), output)\n\n        layer = layers.RNN(OneStateRNNCell(2), stateful=True, return_state=True)\n        
layer(sequence)\n        output, state = layer(sequence)\n        self.assertAllClose(np.array([[45.0, 45.0]]), output)\n        self.assertAllClose(np.array([[45.0, 45.0]]), state)\n\n    def test_statefulness_two_states(self):\n        sequence = np.ones((1, 2, 3))\n        layer = layers.RNN(TwoStatesRNNCell(2), stateful=True)\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(np.array([[90.0, 90.0]]), output)\n\n        layer = layers.RNN(\n            TwoStatesRNNCell(2), stateful=True, return_state=True\n        )\n        layer(sequence)\n        output, state_1, state_2 = layer(sequence)\n        self.assertAllClose(np.array([[90.0, 90.0]]), output)\n        self.assertAllClose(np.array([[45.0, 45.0]]), state_1)\n        self.assertAllClose(np.array([[45.0, 45.0]]), state_2)\n\n    def test_go_backwards(self):\n        sequence = np.arange(24).reshape((2, 3, 4)).astype(\"float32\")\n        layer = layers.RNN(OneStateRNNCell(2), go_backwards=True)\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array([[202.0, 202.0], [538.0, 538.0]]),\n            output,\n            tpu_atol=1e-4,\n            tpu_rtol=1e-4,\n        )\n\n        layer = layers.RNN(OneStateRNNCell(2), stateful=True, return_state=True)\n        layer(sequence)\n        output, state = layer(sequence)\n        self.assertAllClose(\n            np.array([[954.0, 954.0], [3978.0, 3978.0]]),\n            output,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n        self.assertAllClose(\n            np.array([[954.0, 954.0], [3978.0, 3978.0]]),\n            state,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n    def test_serialization(self):\n        layer = layers.RNN(TwoStatesRNNCell(2), return_sequences=False)\n        self.run_class_serialization_test(layer)\n\n        layer = layers.RNN(OneStateRNNCell(2), return_sequences=False)\n        
self.run_class_serialization_test(layer)\n\n    def test_stateful_batch_size_mismatch_raises(self):\n        from keras.src.models import Functional\n\n        batch_size = 4\n        timesteps = 5\n        features = 3\n\n        layer = layers.RNN(TwoStatesRNNCell(2), stateful=True)\n        inputs = layers.Input(\n            shape=(timesteps, features), batch_size=batch_size\n        )\n        model = Functional(inputs, layer(inputs))\n\n        # Call once with correct batch size\n        x = ops.random.uniform(shape=(batch_size, timesteps, features))\n        _ = model(x)\n\n        # Expect ValueError when called with incorrect batch size\n        with self.assertRaisesRegex(ValueError, \"batch size\"):\n            x_bad = ops.random.uniform(shape=(1, timesteps, features))\n            model(x_bad)\n\n    # TODO: test masking\n"
  },
  {
    "path": "keras/src/layers/rnn/simple_rnn.py",
    "content": "from keras.src import activations\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src import regularizers\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.layers.rnn.dropout_rnn_cell import DropoutRNNCell\nfrom keras.src.layers.rnn.rnn import RNN\n\n\n@keras_export(\"keras.layers.SimpleRNNCell\")\nclass SimpleRNNCell(Layer, DropoutRNNCell):\n    \"\"\"Cell class for SimpleRNN.\n\n    This class processes one step within the whole time sequence input, whereas\n    `keras.layer.SimpleRNN` processes the whole sequence.\n\n    Args:\n        units: Positive integer, dimensionality of the output space.\n        activation: Activation function to use.\n            Default: hyperbolic tangent (`tanh`).\n            If you pass `None`, no activation is applied\n            (ie. \"linear\" activation: `a(x) = x`).\n        use_bias: Boolean, (default `True`), whether the layer\n            should use a bias vector.\n        kernel_initializer: Initializer for the `kernel` weights matrix,\n            used for the linear transformation of the inputs. Default:\n            `\"glorot_uniform\"`.\n        recurrent_initializer: Initializer for the `recurrent_kernel`\n            weights matrix, used for the linear transformation\n            of the recurrent state. Default: `\"orthogonal\"`.\n        bias_initializer: Initializer for the bias vector. Default: `\"zeros\"`.\n        kernel_regularizer: Regularizer function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_regularizer: Regularizer function applied to the\n            `recurrent_kernel` weights matrix. 
Default: `None`.\n        bias_regularizer: Regularizer function applied to the bias vector.\n            Default: `None`.\n        kernel_constraint: Constraint function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_constraint: Constraint function applied to the\n            `recurrent_kernel` weights matrix. Default: `None`.\n        bias_constraint: Constraint function applied to the bias vector.\n            Default: `None`.\n        dropout: Float between 0 and 1. Fraction of the units to drop for the\n            linear transformation of the inputs. Default: 0.\n        recurrent_dropout: Float between 0 and 1. Fraction of the units to drop\n            for the linear transformation of the recurrent state. Default: 0.\n        seed: Random seed for dropout.\n\n    Call arguments:\n        sequence: A 2D tensor, with shape `(batch, features)`.\n        states: A 2D tensor with shape `(batch, units)`, which is the state\n            from the previous time step.\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode. 
Only relevant when `dropout` or\n            `recurrent_dropout` is used.\n\n    Example:\n\n    ```python\n    inputs = np.random.random([32, 10, 8]).astype(np.float32)\n    rnn = keras.layers.RNN(keras.layers.SimpleRNNCell(4))\n    output = rnn(inputs)  # The output has shape `(32, 4)`.\n    rnn = keras.layers.RNN(\n        keras.layers.SimpleRNNCell(4),\n        return_sequences=True,\n        return_state=True\n    )\n    # whole_sequence_output has shape `(32, 10, 4)`.\n    # final_state has shape `(32, 4)`.\n    whole_sequence_output, final_state = rnn(inputs)\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        units,\n        activation=\"tanh\",\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        recurrent_initializer=\"orthogonal\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        recurrent_regularizer=None,\n        bias_regularizer=None,\n        kernel_constraint=None,\n        recurrent_constraint=None,\n        bias_constraint=None,\n        dropout=0.0,\n        recurrent_dropout=0.0,\n        seed=None,\n        **kwargs,\n    ):\n        if units <= 0:\n            raise ValueError(\n                \"Received an invalid value for argument `units`, \"\n                f\"expected a positive integer, got {units}.\"\n            )\n        super().__init__(**kwargs)\n        self.seed = seed\n        self.seed_generator = backend.random.SeedGenerator(seed)\n\n        self.units = units\n        self.activation = activations.get(activation)\n        self.use_bias = use_bias\n\n        self.kernel_initializer = initializers.get(kernel_initializer)\n        self.recurrent_initializer = initializers.get(recurrent_initializer)\n        self.bias_initializer = initializers.get(bias_initializer)\n\n        self.kernel_regularizer = regularizers.get(kernel_regularizer)\n        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)\n        self.bias_regularizer = 
regularizers.get(bias_regularizer)\n\n        self.kernel_constraint = constraints.get(kernel_constraint)\n        self.recurrent_constraint = constraints.get(recurrent_constraint)\n        self.bias_constraint = constraints.get(bias_constraint)\n\n        self.dropout = min(1.0, max(0.0, dropout))\n        self.recurrent_dropout = min(1.0, max(0.0, recurrent_dropout))\n        self.state_size = self.units\n        self.output_size = self.units\n\n    def build(self, input_shape):\n        self.kernel = self.add_weight(\n            shape=(input_shape[-1], self.units),\n            name=\"kernel\",\n            initializer=self.kernel_initializer,\n            regularizer=self.kernel_regularizer,\n            constraint=self.kernel_constraint,\n        )\n        self.recurrent_kernel = self.add_weight(\n            shape=(self.units, self.units),\n            name=\"recurrent_kernel\",\n            initializer=self.recurrent_initializer,\n            regularizer=self.recurrent_regularizer,\n            constraint=self.recurrent_constraint,\n        )\n        if self.use_bias:\n            self.bias = self.add_weight(\n                shape=(self.units,),\n                name=\"bias\",\n                initializer=self.bias_initializer,\n                regularizer=self.bias_regularizer,\n                constraint=self.bias_constraint,\n            )\n        else:\n            self.bias = None\n\n    def call(self, sequence, states, training=False):\n        prev_output = states[0] if isinstance(states, (list, tuple)) else states\n        dp_mask = self.get_dropout_mask(sequence)\n        rec_dp_mask = self.get_recurrent_dropout_mask(prev_output)\n\n        if training and dp_mask is not None:\n            sequence = sequence * dp_mask\n        h = ops.matmul(sequence, self.kernel)\n        if self.bias is not None:\n            h = ops.add(h, self.bias)\n\n        if training and rec_dp_mask is not None:\n            prev_output = prev_output * rec_dp_mask\n   
     output = h + ops.matmul(prev_output, self.recurrent_kernel)\n        if self.activation is not None:\n            output = self.activation(output)\n\n        new_state = [output] if isinstance(states, (list, tuple)) else output\n        return output, new_state\n\n    def get_initial_state(self, batch_size=None):\n        return [\n            ops.zeros((batch_size, self.state_size), dtype=self.compute_dtype)\n        ]\n\n    def get_config(self):\n        config = {\n            \"units\": self.units,\n            \"activation\": activations.serialize(self.activation),\n            \"use_bias\": self.use_bias,\n            \"kernel_initializer\": initializers.serialize(\n                self.kernel_initializer\n            ),\n            \"recurrent_initializer\": initializers.serialize(\n                self.recurrent_initializer\n            ),\n            \"bias_initializer\": initializers.serialize(self.bias_initializer),\n            \"kernel_regularizer\": regularizers.serialize(\n                self.kernel_regularizer\n            ),\n            \"recurrent_regularizer\": regularizers.serialize(\n                self.recurrent_regularizer\n            ),\n            \"bias_regularizer\": regularizers.serialize(self.bias_regularizer),\n            \"kernel_constraint\": constraints.serialize(self.kernel_constraint),\n            \"recurrent_constraint\": constraints.serialize(\n                self.recurrent_constraint\n            ),\n            \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            \"dropout\": self.dropout,\n            \"recurrent_dropout\": self.recurrent_dropout,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\n@keras_export(\"keras.layers.SimpleRNN\")\nclass SimpleRNN(RNN):\n    \"\"\"Fully-connected RNN where the output is to be fed back as the new input.\n\n    Args:\n        units: Positive integer, 
dimensionality of the output space.\n        activation: Activation function to use.\n            Default: hyperbolic tangent (`tanh`).\n            If you pass None, no activation is applied\n            (ie. \"linear\" activation: `a(x) = x`).\n        use_bias: Boolean, (default `True`), whether the layer uses\n            a bias vector.\n        kernel_initializer: Initializer for the `kernel` weights matrix,\n            used for the linear transformation of the inputs. Default:\n            `\"glorot_uniform\"`.\n        recurrent_initializer: Initializer for the `recurrent_kernel`\n            weights matrix, used for the linear transformation of the recurrent\n            state.  Default: `\"orthogonal\"`.\n        bias_initializer: Initializer for the bias vector. Default: `\"zeros\"`.\n        kernel_regularizer: Regularizer function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_regularizer: Regularizer function applied to the\n            `recurrent_kernel` weights matrix. Default: `None`.\n        bias_regularizer: Regularizer function applied to the bias vector.\n            Default: `None`.\n        activity_regularizer: Regularizer function applied to the output of the\n            layer (its \"activation\"). Default: `None`.\n        kernel_constraint: Constraint function applied to the `kernel` weights\n            matrix. Default: `None`.\n        recurrent_constraint: Constraint function applied to the\n            `recurrent_kernel` weights matrix.  Default: `None`.\n        bias_constraint: Constraint function applied to the bias vector.\n            Default: `None`.\n        dropout: Float between 0 and 1.\n            Fraction of the units to drop for the linear transformation\n            of the inputs. Default: 0.\n        recurrent_dropout: Float between 0 and 1.\n            Fraction of the units to drop for the linear transformation of the\n            recurrent state. 
Default: 0.\n        return_sequences: Boolean. Whether to return the last output\n            in the output sequence, or the full sequence. Default: `False`.\n        return_state: Boolean. Whether to return the last state\n            in addition to the output. Default: `False`.\n        go_backwards: Boolean (default: `False`).\n            If `True`, process the input sequence backwards and return the\n            reversed sequence.\n        stateful: Boolean (default: `False`). If `True`, the last state\n            for each sample at index i in a batch will be used as the\n            initial state for the sample of index i in the following batch.\n        unroll: Boolean (default: `False`).\n            If `True`, the network will be unrolled,\n            else a symbolic loop will be used.\n            Unrolling can speed-up an RNN,\n            although it tends to be more memory-intensive.\n            Unrolling is only suitable for short sequences.\n\n    Call arguments:\n        sequence: A 3D tensor, with shape `[batch, timesteps, feature]`.\n        mask: Binary tensor of shape `[batch, timesteps]` indicating whether\n            a given timestep should be masked. 
An individual `True` entry\n            indicates that the corresponding timestep should be utilized,\n            while a `False` entry indicates that the corresponding timestep\n            should be ignored.\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode.\n            This argument is passed to the cell when calling it.\n            This is only relevant if `dropout` or `recurrent_dropout` is used.\n        initial_state: List of initial state tensors to be passed to the first\n            call of the cell.\n\n    Example:\n\n    ```python\n    inputs = np.random.random((32, 10, 8))\n    simple_rnn = keras.layers.SimpleRNN(4)\n    output = simple_rnn(inputs)  # The output has shape `(32, 4)`.\n    simple_rnn = keras.layers.SimpleRNN(\n        4, return_sequences=True, return_state=True\n    )\n    # whole_sequence_output has shape `(32, 10, 4)`.\n    # final_state has shape `(32, 4)`.\n    whole_sequence_output, final_state = simple_rnn(inputs)\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        units,\n        activation=\"tanh\",\n        use_bias=True,\n        kernel_initializer=\"glorot_uniform\",\n        recurrent_initializer=\"orthogonal\",\n        bias_initializer=\"zeros\",\n        kernel_regularizer=None,\n        recurrent_regularizer=None,\n        bias_regularizer=None,\n        activity_regularizer=None,\n        kernel_constraint=None,\n        recurrent_constraint=None,\n        bias_constraint=None,\n        dropout=0.0,\n        recurrent_dropout=0.0,\n        return_sequences=False,\n        return_state=False,\n        go_backwards=False,\n        stateful=False,\n        unroll=False,\n        seed=None,\n        **kwargs,\n    ):\n        cell = SimpleRNNCell(\n            units,\n            activation=activation,\n            use_bias=use_bias,\n            kernel_initializer=kernel_initializer,\n            
recurrent_initializer=recurrent_initializer,\n            bias_initializer=bias_initializer,\n            kernel_regularizer=kernel_regularizer,\n            recurrent_regularizer=recurrent_regularizer,\n            bias_regularizer=bias_regularizer,\n            kernel_constraint=kernel_constraint,\n            recurrent_constraint=recurrent_constraint,\n            bias_constraint=bias_constraint,\n            dropout=dropout,\n            recurrent_dropout=recurrent_dropout,\n            seed=seed,\n            dtype=kwargs.get(\"dtype\", None),\n            trainable=kwargs.get(\"trainable\", True),\n            name=\"simple_rnn_cell\",\n        )\n        super().__init__(\n            cell,\n            return_sequences=return_sequences,\n            return_state=return_state,\n            go_backwards=go_backwards,\n            stateful=stateful,\n            unroll=unroll,\n            **kwargs,\n        )\n        self.input_spec = [InputSpec(ndim=3)]\n\n    def call(self, sequences, initial_state=None, mask=None, training=False):\n        return super().call(\n            sequences, mask=mask, training=training, initial_state=initial_state\n        )\n\n    @property\n    def units(self):\n        return self.cell.units\n\n    @property\n    def activation(self):\n        return self.cell.activation\n\n    @property\n    def use_bias(self):\n        return self.cell.use_bias\n\n    @property\n    def kernel_initializer(self):\n        return self.cell.kernel_initializer\n\n    @property\n    def recurrent_initializer(self):\n        return self.cell.recurrent_initializer\n\n    @property\n    def bias_initializer(self):\n        return self.cell.bias_initializer\n\n    @property\n    def kernel_regularizer(self):\n        return self.cell.kernel_regularizer\n\n    @property\n    def recurrent_regularizer(self):\n        return self.cell.recurrent_regularizer\n\n    @property\n    def bias_regularizer(self):\n        return self.cell.bias_regularizer\n\n  
  @property\n    def kernel_constraint(self):\n        return self.cell.kernel_constraint\n\n    @property\n    def recurrent_constraint(self):\n        return self.cell.recurrent_constraint\n\n    @property\n    def bias_constraint(self):\n        return self.cell.bias_constraint\n\n    @property\n    def dropout(self):\n        return self.cell.dropout\n\n    @property\n    def recurrent_dropout(self):\n        return self.cell.recurrent_dropout\n\n    def get_config(self):\n        config = {\n            \"units\": self.units,\n            \"activation\": activations.serialize(self.activation),\n            \"use_bias\": self.use_bias,\n            \"kernel_initializer\": initializers.serialize(\n                self.kernel_initializer\n            ),\n            \"recurrent_initializer\": initializers.serialize(\n                self.recurrent_initializer\n            ),\n            \"bias_initializer\": initializers.serialize(self.bias_initializer),\n            \"kernel_regularizer\": regularizers.serialize(\n                self.kernel_regularizer\n            ),\n            \"recurrent_regularizer\": regularizers.serialize(\n                self.recurrent_regularizer\n            ),\n            \"bias_regularizer\": regularizers.serialize(self.bias_regularizer),\n            \"activity_regularizer\": regularizers.serialize(\n                self.activity_regularizer\n            ),\n            \"kernel_constraint\": constraints.serialize(self.kernel_constraint),\n            \"recurrent_constraint\": constraints.serialize(\n                self.recurrent_constraint\n            ),\n            \"bias_constraint\": constraints.serialize(self.bias_constraint),\n            \"dropout\": self.dropout,\n            \"recurrent_dropout\": self.recurrent_dropout,\n        }\n        base_config = super().get_config()\n        del base_config[\"cell\"]\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config):\n        
return cls(**config)\n"
  },
  {
    "path": "keras/src/layers/rnn/simple_rnn_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import testing\n\n\nclass SimpleRNNTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basics(self):\n        self.run_layer_test(\n            layers.SimpleRNN,\n            init_kwargs={\"units\": 3, \"dropout\": 0.5, \"recurrent_dropout\": 0.5},\n            input_shape=(3, 2, 4),\n            call_kwargs={\"training\": True},\n            expected_output_shape=(3, 3),\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            expected_num_non_trainable_variables=1,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.SimpleRNN,\n            init_kwargs={\n                \"units\": 3,\n                \"return_sequences\": True,\n                \"bias_regularizer\": \"l1\",\n                \"kernel_regularizer\": \"l2\",\n                \"recurrent_regularizer\": \"l2\",\n            },\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 2, 3),\n            expected_num_losses=3,\n            expected_num_trainable_weights=3,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n\n    def test_correctness(self):\n        sequence = np.arange(24).reshape((2, 3, 4)).astype(\"float32\")\n        layer = layers.SimpleRNN(\n            4,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.405432, 0.405432, 0.405432, 0.405432],\n                    [0.73605347, 0.73605347, 0.73605347, 0.73605347],\n                ]\n            ),\n            output,\n            
tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        layer = layers.SimpleRNN(\n            4,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            unroll=True,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.405432, 0.405432, 0.405432, 0.405432],\n                    [0.73605347, 0.73605347, 0.73605347, 0.73605347],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.SimpleRNN(\n            4,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            go_backwards=True,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.11144729, 0.11144729, 0.11144729, 0.11144729],\n                    [0.5528889, 0.5528889, 0.5528889, 0.5528889],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        layer = layers.SimpleRNN(\n            4,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            go_backwards=True,\n            unroll=True,\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.11144729, 0.11144729, 0.11144729, 0.11144729],\n                    [0.5528889, 0.5528889, 0.5528889, 0.5528889],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n   
     )\n\n    def test_statefulness(self):\n        sequence = np.arange(24).reshape((2, 3, 4)).astype(\"float32\")\n        layer = layers.SimpleRNN(\n            4,\n            stateful=True,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.40559256, 0.40559256, 0.40559256, 0.40559256],\n                    [0.7361247, 0.7361247, 0.7361247, 0.7361247],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        layer.reset_state()\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.40559256, 0.40559256, 0.40559256, 0.40559256],\n                    [0.7361247, 0.7361247, 0.7361247, 0.7361247],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_pass_initial_state(self):\n        sequence = np.arange(24).reshape((2, 4, 3)).astype(\"float32\")\n        initial_state = np.arange(8).reshape((2, 4)).astype(\"float32\")\n        layer = layers.SimpleRNN(\n            4,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n        )\n        output = layer(sequence, initial_state=initial_state)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.33621645, 0.33621645, 0.33621645, 0.33621645],\n                    [0.6262637, 0.6262637, 0.6262637, 0.6262637],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n    
        tpu_rtol=1e-3,\n        )\n\n        layer = layers.SimpleRNN(\n            4,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            go_backwards=True,\n        )\n        output = layer(sequence, initial_state=initial_state)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.07344437, 0.07344437, 0.07344437, 0.07344437],\n                    [0.43043602, 0.43043602, 0.43043602, 0.43043602],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n    def test_masking(self):\n        sequence = np.arange(24).reshape((2, 4, 3)).astype(\"float32\")\n        mask = np.array([[True, True, False, True], [True, False, False, True]])\n        layer = layers.SimpleRNN(\n            4,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            unroll=True,\n        )\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.32951632, 0.32951632, 0.32951632, 0.32951632],\n                    [0.61799484, 0.61799484, 0.61799484, 0.61799484],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.SimpleRNN(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            return_sequences=True,\n        )\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.0599281, 
0.0599281],\n                    [0.15122814, 0.15122814],\n                    [0.15122814, 0.15122814],\n                    [0.32394567, 0.32394567],\n                ],\n            ),\n            output[0],\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.3969304, 0.3969304],\n                    [0.3969304, 0.3969304],\n                    [0.3969304, 0.3969304],\n                    [0.608085, 0.608085],\n                ],\n            ),\n            output[1],\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.SimpleRNN(\n            2,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            return_sequences=True,\n            zero_output_for_mask=True,\n        )\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.0599281, 0.0599281],\n                    [0.15122814, 0.15122814],\n                    [0.0, 0.0],\n                    [0.32394567, 0.32394567],\n                ],\n            ),\n            output[0],\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.3969304, 0.3969304],\n                    [0.0, 0.0],\n                    [0.0, 0.0],\n                    [0.608085, 0.608085],\n                ],\n            ),\n            output[1],\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n\n        layer = layers.SimpleRNN(\n            4,\n            kernel_initializer=initializers.Constant(0.01),\n            recurrent_initializer=initializers.Constant(0.02),\n            bias_initializer=initializers.Constant(0.03),\n            
go_backwards=True,\n        )\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array(\n                [\n                    [0.07376196, 0.07376196, 0.07376196, 0.07376196],\n                    [0.43645123, 0.43645123, 0.43645123, 0.43645123],\n                ]\n            ),\n            output,\n            tpu_atol=1e-3,\n            tpu_rtol=1e-3,\n        )\n"
  },
  {
    "path": "keras/src/layers/rnn/stacked_rnn_cells.py",
    "content": "from keras.src import ops\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.layers.StackedRNNCells\")\nclass StackedRNNCells(Layer):\n    \"\"\"Wrapper allowing a stack of RNN cells to behave as a single cell.\n\n    Used to implement efficient stacked RNNs.\n\n    Args:\n      cells: List of RNN cell instances.\n\n    Example:\n\n    ```python\n    batch_size = 3\n    sentence_length = 5\n    num_features = 2\n    new_shape = (batch_size, sentence_length, num_features)\n    x = np.reshape(np.arange(30), new_shape)\n\n    rnn_cells = [keras.layers.LSTMCell(128) for _ in range(2)]\n    stacked_lstm = keras.layers.StackedRNNCells(rnn_cells)\n    lstm_layer = keras.layers.RNN(stacked_lstm)\n\n    result = lstm_layer(x)\n    ```\n    \"\"\"\n\n    def __init__(self, cells, **kwargs):\n        super().__init__(**kwargs)\n        for cell in cells:\n            if \"call\" not in dir(cell):\n                raise ValueError(\n                    \"All cells must have a `call` method. \"\n                    f\"Received cell without a `call` method: {cell}\"\n                )\n            if \"state_size\" not in dir(cell):\n                raise ValueError(\n                    \"All cells must have a `state_size` attribute. 
\"\n                    f\"Received cell without a `state_size`: {cell}\"\n                )\n        self.cells = cells\n\n    @property\n    def state_size(self):\n        return [c.state_size for c in self.cells]\n\n    @property\n    def output_size(self):\n        if getattr(self.cells[-1], \"output_size\", None) is not None:\n            return self.cells[-1].output_size\n        elif isinstance(self.cells[-1].state_size, (list, tuple)):\n            return self.cells[-1].state_size[0]\n        else:\n            return self.cells[-1].state_size\n\n    def get_initial_state(self, batch_size=None):\n        initial_states = []\n        for cell in self.cells:\n            get_initial_state_fn = getattr(cell, \"get_initial_state\", None)\n            if get_initial_state_fn:\n                initial_states.append(\n                    get_initial_state_fn(batch_size=batch_size)\n                )\n            else:\n                if isinstance(cell.state_size, int):\n                    initial_states.append(\n                        ops.zeros(\n                            (batch_size, cell.state_size),\n                            dtype=self.compute_dtype,\n                        )\n                    )\n                else:\n                    initial_states.append(\n                        [\n                            ops.zeros((batch_size, d), dtype=self.compute_dtype)\n                            for d in cell.state_size\n                        ]\n                    )\n        return initial_states\n\n    def call(self, inputs, states, training=False, **kwargs):\n        # Call the cells in order and store the returned states.\n        new_states = []\n        for cell, states in zip(self.cells, states):\n            state_is_list = tree.is_nested(states)\n            states = list(states) if tree.is_nested(states) else [states]\n            if isinstance(cell, Layer) and cell._call_has_training_arg:\n                kwargs[\"training\"] = 
training\n            else:\n                kwargs.pop(\"training\", None)\n            cell_call_fn = cell.__call__ if callable(cell) else cell.call\n            inputs, states = cell_call_fn(inputs, states, **kwargs)\n            if len(states) == 1 and not state_is_list:\n                states = states[0]\n            new_states.append(states)\n\n        if len(new_states) == 1:\n            new_states = new_states[0]\n        return inputs, new_states\n\n    def build(self, input_shape):\n        for cell in self.cells:\n            if isinstance(cell, Layer) and not cell.built:\n                cell.build(input_shape)\n                cell.built = True\n            if getattr(cell, \"output_size\", None) is not None:\n                output_dim = cell.output_size\n            elif isinstance(cell.state_size, (list, tuple)):\n                output_dim = cell.state_size[0]\n            else:\n                output_dim = cell.state_size\n            batch_size = tree.flatten(input_shape)[0]\n            input_shape = (batch_size, output_dim)\n\n    def get_config(self):\n        cells = []\n        for cell in self.cells:\n            cells.append(serialization_lib.serialize_keras_object(cell))\n        config = {\"cells\": cells}\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config, custom_objects=None):\n        cells = []\n        for cell_config in config.pop(\"cells\"):\n            cells.append(\n                serialization_lib.deserialize_keras_object(\n                    cell_config, custom_objects=custom_objects\n                )\n            )\n        return cls(cells, **config)\n"
  },
  {
    "path": "keras/src/layers/rnn/stacked_rnn_cells_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.layers.rnn.rnn_test import OneStateRNNCell\nfrom keras.src.layers.rnn.rnn_test import TwoStatesRNNCell\n\n\nclass StackedRNNTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basics(self):\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\n                \"cell\": [\n                    OneStateRNNCell(3),\n                    OneStateRNNCell(4),\n                    OneStateRNNCell(5),\n                ],\n            },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 5),\n            expected_num_trainable_weights=6,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            supports_masking=True,\n            custom_objects={\"OneStateRNNCell\": OneStateRNNCell},\n        )\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\n                \"cell\": [\n                    OneStateRNNCell(3),\n                    OneStateRNNCell(4),\n                    OneStateRNNCell(5),\n                ],\n                \"return_sequences\": True,\n            },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 5),\n            expected_num_trainable_weights=6,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            supports_masking=True,\n            custom_objects={\"OneStateRNNCell\": OneStateRNNCell},\n        )\n        # Two-state case.\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\n                \"cell\": [\n                    TwoStatesRNNCell(3),\n                    TwoStatesRNNCell(4),\n                    TwoStatesRNNCell(5),\n                ],\n            },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 5),\n        
    expected_num_trainable_weights=9,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            supports_masking=True,\n            custom_objects={\"TwoStatesRNNCell\": TwoStatesRNNCell},\n        )\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\n                \"cell\": [\n                    TwoStatesRNNCell(3),\n                    TwoStatesRNNCell(4),\n                    TwoStatesRNNCell(5),\n                ],\n                \"return_sequences\": True,\n            },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 5),\n            expected_num_trainable_weights=9,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=0,\n            supports_masking=True,\n            custom_objects={\"TwoStatesRNNCell\": TwoStatesRNNCell},\n        )\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\n                \"cell\": [\n                    layers.SimpleRNNCell(3, dropout=0.1, recurrent_dropout=0.1),\n                    layers.SimpleRNNCell(4, dropout=0.1, recurrent_dropout=0.1),\n                    layers.SimpleRNNCell(5, dropout=0.1, recurrent_dropout=0.1),\n                ],\n                \"return_sequences\": True,\n            },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 5),\n            expected_num_trainable_weights=9,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=3,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\n                \"cell\": [\n                    layers.GRUCell(3, dropout=0.1, recurrent_dropout=0.1),\n                    layers.GRUCell(4, dropout=0.1, recurrent_dropout=0.1),\n                    layers.GRUCell(5, dropout=0.1, recurrent_dropout=0.1),\n                ],\n                
\"return_sequences\": True,\n            },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 5),\n            expected_num_trainable_weights=9,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=3,\n            supports_masking=True,\n        )\n        self.run_layer_test(\n            layers.RNN,\n            init_kwargs={\n                \"cell\": [\n                    layers.LSTMCell(3, dropout=0.1, recurrent_dropout=0.1),\n                    layers.LSTMCell(4, dropout=0.1, recurrent_dropout=0.1),\n                    layers.LSTMCell(5, dropout=0.1, recurrent_dropout=0.1),\n                ],\n                \"return_sequences\": True,\n            },\n            input_shape=(2, 3, 4),\n            expected_output_shape=(2, 3, 5),\n            expected_num_trainable_weights=9,\n            expected_num_non_trainable_weights=0,\n            expected_num_seed_generators=3,\n            supports_masking=True,\n        )\n\n    def test_correctness_single_state_stack(self):\n        sequence = np.arange(24).reshape((2, 3, 4)).astype(\"float32\")\n        layer = layers.RNN([OneStateRNNCell(3), OneStateRNNCell(2)])\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array([[786.0, 786.0], [4386.0, 4386.0]]),\n            output,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n        layer = layers.RNN(\n            [OneStateRNNCell(3), OneStateRNNCell(2)], return_sequences=True\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [[18.0, 18.0], [156.0, 156.0], [786.0, 786.0]],\n                    [[162.0, 162.0], [1020.0, 1020.0], [4386.0, 4386.0]],\n                ]\n            ),\n            output,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n        layer = layers.RNN(\n            [OneStateRNNCell(3), OneStateRNNCell(2)], 
return_state=True\n        )\n        output, state_1, state_2 = layer(sequence)\n        self.assertAllClose(\n            np.array([[786.0, 786.0], [4386.0, 4386.0]]),\n            output,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n        self.assertAllClose(\n            np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]), state_1\n        )\n        self.assertAllClose(\n            np.array([[786.0, 786.0], [4386.0, 4386.0]]),\n            state_2,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n        layer = layers.RNN(\n            [OneStateRNNCell(3), OneStateRNNCell(2)],\n            return_sequences=True,\n            return_state=True,\n        )\n        output, state_1, state_2 = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [[18.0, 18.0], [156.0, 156.0], [786.0, 786.0]],\n                    [[162.0, 162.0], [1020.0, 1020.0], [4386.0, 4386.0]],\n                ]\n            ),\n            output,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n        self.assertAllClose(\n            np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]),\n            state_1,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n        self.assertAllClose(\n            np.array([[786.0, 786.0], [4386.0, 4386.0]]),\n            state_2,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n    def test_correctness_two_states_stack(self):\n        sequence = np.arange(24).reshape((2, 3, 4)).astype(\"float32\")\n        layer = layers.RNN([TwoStatesRNNCell(3), TwoStatesRNNCell(2)])\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array([[3144.0, 3144.0], [17544.0, 17544.0]]),\n            output,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n        layer = layers.RNN(\n            [TwoStatesRNNCell(3), TwoStatesRNNCell(2)], 
return_sequences=True\n        )\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [\n                    [[72.0, 72.0], [624.0, 624.0], [3144.0, 3144.0]],\n                    [[648.0, 648.0], [4080.0, 4080.0], [17544.0, 17544.0]],\n                ]\n            ),\n            output,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n        layer = layers.RNN(\n            [TwoStatesRNNCell(3), TwoStatesRNNCell(2)], return_state=True\n        )\n        output, state_1, state_2 = layer(sequence)\n\n        self.assertAllClose(\n            np.array([[3144.0, 3144.0], [17544.0, 17544.0]]),\n            output,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n        self.assertAllClose(\n            np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]),\n            state_1[0],\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n        self.assertAllClose(\n            np.array([[158.0, 158.0, 158.0], [782.0, 782.0, 782.0]]),\n            state_1[1],\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n        self.assertAllClose(\n            np.array([[1572.0, 1572.0], [8772.0, 8772.0]]),\n            state_2[0],\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n        self.assertAllClose(\n            np.array([[1572.0, 1572.0], [8772.0, 8772.0]]),\n            state_2[1],\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n    def test_statefullness_single_state_stack(self):\n        sequence = np.arange(24).reshape((2, 3, 4)).astype(\"float32\")\n        layer = layers.RNN(\n            [OneStateRNNCell(3), OneStateRNNCell(2)], stateful=True\n        )\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array([[34092.0, 34092.0], [173196.0, 173196.0]]),\n            output,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n 
   def test_statefullness_two_states_stack(self):\n        sequence = np.arange(24).reshape((2, 3, 4)).astype(\"float32\")\n        layer = layers.RNN(\n            [TwoStatesRNNCell(3), TwoStatesRNNCell(2)], stateful=True\n        )\n        layer(sequence)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array([[136368.0, 136368.0], [692784.0, 692784.0]]),\n            output,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n    def test_return_state_stacked_lstm_cell(self):\n        layer = layers.RNN(\n            [layers.LSTMCell(10), layers.LSTMCell(10)], return_state=True\n        )\n        out = layer(np.zeros((2, 3, 5)))\n        self.assertLen(out, 3)\n        self.assertEqual(out[0].shape, (2, 10))\n        self.assertEqual(out[1][0].shape, (2, 10))\n        self.assertEqual(out[1][1].shape, (2, 10))\n        self.assertEqual(out[2][0].shape, (2, 10))\n        self.assertEqual(out[2][1].shape, (2, 10))\n\n        shape = layer.compute_output_shape((2, 3, 5))\n        self.assertLen(shape, 3)\n        self.assertEqual(shape[0], (2, 10))\n        self.assertEqual(shape[1][0], (2, 10))\n        self.assertEqual(shape[1][1], (2, 10))\n        self.assertEqual(shape[2][0], (2, 10))\n        self.assertEqual(shape[2][1], (2, 10))\n\n    def test_stacked_lstm_cell_mask(self):\n        sequence = np.ones((2, 3, 4))\n        mask = np.array([[True, True, True], [True, True, False]])\n        cell_kwargs = dict(\n            units=1, kernel_initializer=\"ones\", recurrent_initializer=\"ones\"\n        )\n        rnn_cells = [layers.LSTMCell(**cell_kwargs) for _ in range(2)]\n        stacked_rnn = layers.RNN(rnn_cells)\n        output = stacked_rnn(sequence, mask=mask)\n        self.assertAllClose(np.array([[0.7793], [0.5998]]), output, atol=1e-4)\n"
  },
  {
    "path": "keras/src/layers/rnn/time_distributed.py",
    "content": "\"\"\"Wrapper layer to apply every temporal slice of an input.\"\"\"\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.core.wrapper import Wrapper\nfrom keras.src.layers.layer import Layer\n\n\n@keras_export(\"keras.layers.TimeDistributed\")\nclass TimeDistributed(Wrapper):\n    \"\"\"This wrapper allows to apply a layer to every temporal slice of an input.\n\n    Every input should be at least 3D, and the dimension of index one of the\n    first input will be considered to be the temporal dimension.\n\n    Consider a batch of 32 video samples, where each sample is a 128x128 RGB\n    image with `channels_last` data format, across 10 timesteps.\n    The batch input shape is `(32, 10, 128, 128, 3)`.\n\n    You can then use `TimeDistributed` to apply the same `Conv2D` layer to each\n    of the 10 timesteps, independently:\n\n    >>> inputs = layers.Input(shape=(10, 128, 128, 3), batch_size=32)\n    >>> conv_2d_layer = layers.Conv2D(64, (3, 3))\n    >>> outputs = layers.TimeDistributed(conv_2d_layer)(inputs)\n    >>> outputs.shape\n    (32, 10, 126, 126, 64)\n\n    Because `TimeDistributed` applies the same instance of `Conv2D` to each of\n    the timestamps, the same set of weights are used at each timestamp.\n\n    Args:\n        layer: a `keras.layers.Layer` instance.\n\n    Call arguments:\n        inputs: Input tensor of shape (batch, time, ...) or nested tensors,\n            and each of which has shape (batch, time, ...).\n        training: Python boolean indicating whether the layer should behave in\n            training mode or in inference mode. This argument is passed to the\n            wrapped layer (only if the layer supports this argument).\n        mask: Binary tensor of shape `(samples, timesteps)` indicating whether\n            a given timestep should be masked. 
This argument is passed to the\n            wrapped layer (only if the layer supports this argument).\n    \"\"\"\n\n    def __init__(self, layer, **kwargs):\n        if not isinstance(layer, Layer):\n            raise ValueError(\n                \"Please initialize `TimeDistributed` layer with a \"\n                f\"`keras.layers.Layer` instance. Received: {layer}\"\n            )\n        super().__init__(layer, **kwargs)\n        self.supports_masking = True\n\n    def _get_child_input_shape(self, input_shape):\n        if not isinstance(input_shape, (tuple, list)) or len(input_shape) < 3:\n            raise ValueError(\n                \"`TimeDistributed` Layer should be passed an `input_shape` \"\n                f\"with at least 3 dimensions, received: {input_shape}\"\n            )\n        return (input_shape[0], *input_shape[2:])\n\n    def compute_output_shape(self, input_shape):\n        child_input_shape = self._get_child_input_shape(input_shape)\n        child_output_shape = self.layer.compute_output_shape(child_input_shape)\n        return (child_output_shape[0], input_shape[1], *child_output_shape[1:])\n\n    def build(self, input_shape):\n        child_input_shape = self._get_child_input_shape(input_shape)\n        super().build(child_input_shape)\n\n    def call(self, inputs, training=None, mask=None):\n        # Validate mask shape using static shape info when available\n        if mask is not None:\n            mask_shape = mask.shape\n            input_shape = inputs.shape\n\n            # Check if mask has at least 2 dimensions (batch and timesteps)\n            if len(mask_shape) < 2:\n                raise ValueError(\n                    \"The `mask` passed to the `TimeDistributed` layer must be \"\n                    \"at least 2D (e.g., `(batch_size, timesteps)`), but it has \"\n                    f\"{len(mask_shape)} dimension(s) with shape {mask_shape}.\"\n                )\n\n            # Check batch size and timesteps dimensions 
match\n            batch_mismatch = (\n                input_shape[0] is not None\n                and mask_shape[0] is not None\n                and input_shape[0] != mask_shape[0]\n            )\n            time_mismatch = (\n                input_shape[1] is not None\n                and mask_shape[1] is not None\n                and input_shape[1] != mask_shape[1]\n            )\n\n            if batch_mismatch or time_mismatch:\n                raise ValueError(\n                    \"The `mask` passed to the `TimeDistributed` layer has a \"\n                    f\"shape {mask_shape} that is incompatible with the input \"\n                    f\"shape {input_shape}. The first two dimensions of the \"\n                    \"mask (batch size and timesteps) must match the input's \"\n                    \"first two dimensions. Expected mask shape prefix: \"\n                    f\"({input_shape[0]}, {input_shape[1]}).\"\n                )\n\n        input_shape = ops.shape(inputs)\n\n        def time_distributed_transpose(data):\n            \"\"\"Swaps the timestep and batch dimensions of a tensor.\"\"\"\n            axes = [1, 0, *range(2, len(data.shape))]\n            return ops.transpose(data, axes=axes)\n\n        inputs = time_distributed_transpose(inputs)\n        if mask is not None:\n            mask = time_distributed_transpose(mask)\n\n        def step_function(i):\n            kwargs = {}\n            if self.layer._call_has_mask_arg and mask is not None:\n                kwargs[\"mask\"] = mask[i]\n            if self.layer._call_has_training_arg:\n                kwargs[\"training\"] = training\n            return self.layer.call(inputs[i], **kwargs)\n\n        # Implementation #1: if the time axis is static, use a Python for loop.\n\n        if inputs.shape[0] is not None:\n            outputs = ops.stack(\n                [step_function(i) for i in range(inputs.shape[0])]\n            )\n            return time_distributed_transpose(outputs)\n\n 
       # Implementation #2: use backend.vectorized_map.\n\n        outputs = backend.vectorized_map(\n            step_function, ops.arange(input_shape[0])\n        )\n        return time_distributed_transpose(outputs)\n"
  },
  {
    "path": "keras/src/layers/rnn/time_distributed_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.models import Sequential\n\n\nclass TimeDistributedTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basics(self):\n        self.run_layer_test(\n            layers.TimeDistributed,\n            init_kwargs={\"layer\": layers.Dense(1, use_bias=False)},\n            input_shape=(3, 2, 4),\n            expected_output_shape=(3, 2, 1),\n            expected_num_trainable_weights=1,\n            expected_num_non_trainable_weights=0,\n            supports_masking=True,\n        )\n\n    def test_build(self):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (10, 128, 128, 3)\n            output_shape = (32, 10, 126, 126, 64)\n        else:\n            input_shape = (10, 3, 128, 128)\n            output_shape = (32, 10, 64, 126, 126)\n        inputs = layers.Input(shape=input_shape, batch_size=32)\n        conv_2d_layer = layers.Conv2D(64, (3, 3))\n        outputs = layers.TimeDistributed(conv_2d_layer)(inputs)\n        self.assertEqual(outputs.shape, output_shape)\n\n    def test_correctness(self):\n        sequence = np.arange(24).reshape((3, 2, 4)).astype(\"float32\")\n        layer = layers.Dense(\n            1,\n            kernel_initializer=initializers.Constant(0.01),\n            use_bias=False,\n        )\n        layer = layers.TimeDistributed(layer=layer)\n        output = layer(sequence)\n        self.assertAllClose(\n            np.array(\n                [[[0.06], [0.22]], [[0.38], [0.53999996]], [[0.7], [0.86]]]\n            ),\n            output,\n        )\n\n    def test_masking(self):\n        class MaskedDense(layers.Wrapper):\n            def __init__(self, units, **kwargs):\n                layer = layers.Dense(\n                    units,\n         
           kernel_initializer=initializers.Constant(0.01),\n                    use_bias=False,\n                )\n                super().__init__(layer, **kwargs)\n                self.supports_masking = True\n\n            def call(self, inputs, training=False, mask=None):\n                unmasked = self.layer.call(inputs)\n                if mask is None:\n                    return unmasked\n                else:\n                    return ops.transpose(\n                        ops.transpose(unmasked) * ops.cast(mask, inputs.dtype)\n                    )\n\n        sequence = np.arange(24).reshape((3, 2, 4)).astype(\"float32\")\n        layer = layers.TimeDistributed(layer=MaskedDense(1))\n        mask = np.array([[False, True], [True, False], [True, True]])\n        output = layer(sequence, mask=mask)\n        self.assertAllClose(\n            np.array([[[0], [0.22]], [[0.38], [0]], [[0.7], [0.86]]]),\n            output,\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_with_mask_zero(self):\n        model = Sequential(\n            [\n                layers.Input(shape=(20,)),\n                layers.Embedding(input_dim=10, output_dim=5, mask_zero=True),\n                layers.TimeDistributed(\n                    layers.Dense(units=5, activation=\"softmax\")\n                ),\n            ]\n        )\n        model.compile(\n            optimizer=\"adam\",\n            loss=\"sparse_categorical_crossentropy\",\n            metrics=[\"accuracy\"],\n        )\n        X_train = np.random.uniform(1, 10, size=(22, 20))\n        Y_train = np.random.randint(1, 2, size=(22, 20))\n\n        model.fit(X_train, Y_train, epochs=1, batch_size=16)\n\n    def test_mask_validation_with_mismatched_timesteps(self):\n        \"\"\"Test TimeDistributed raises ValueError for mask with wrong timesteps.\n\n        Regression test for: https://github.com/keras-team/keras/issues/22037\n        \"\"\"\n        batch = 4\n        timesteps = 5\n        
features = 3\n\n        td = layers.TimeDistributed(layers.Dense(units=5, activation=\"softmax\"))\n        inputs = np.zeros((batch, timesteps, features))\n\n        # Mask with correct timesteps should work\n        mask_valid = np.ones((batch, timesteps), dtype=bool)\n        output = td(inputs, mask=mask_valid)\n        self.assertEqual(output.shape, (batch, timesteps, 5))\n\n        # Mask with mismatched timesteps should raise ValueError\n        mask_timemismatch = np.ones((batch, timesteps + 1), dtype=bool)\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"The `mask` passed to the `TimeDistributed` layer has a shape.*\",\n        ):\n            td(inputs, mask=mask_timemismatch)\n\n    def test_mask_validation_with_mismatched_batch_size(self):\n        \"\"\"\n        Test TimeDistributed raises ValueError for mask with wrong batch size.\n\n        This tests the batch size validation in eager/non-TF backends.\n        \"\"\"\n        batch = 4\n        timesteps = 5\n        features = 3\n\n        td = layers.TimeDistributed(layers.Dense(units=5, activation=\"softmax\"))\n        inputs = np.zeros((batch, timesteps, features))\n\n        # Mask with mismatched batch size should raise ValueError\n        mask_batchmismatch = np.ones((batch + 1, timesteps), dtype=bool)\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"The `mask` passed to the `TimeDistributed` layer has a shape.*\",\n        ):\n            td(inputs, mask=mask_batchmismatch)\n\n    def test_mask_validation_with_correct_shape(self):\n        \"\"\"\n        Test TimeDistributed accepts mask with correct shape.\n\n        Tests that the validation passes when mask shape is\n        (batch_size, timesteps).\n        \"\"\"\n        batch = 4\n        timesteps = 5\n        features = 3\n\n        td = layers.TimeDistributed(layers.Dense(units=5, activation=\"softmax\"))\n        inputs = np.zeros((batch, timesteps, features))\n\n      
  # Mask with correct shape (batch_size, timesteps, extra_dims)\n        mask_correct = np.ones((batch, timesteps, 1), dtype=bool)\n        output = td(inputs, mask=mask_correct)\n        self.assertEqual(output.shape, (batch, timesteps, 5))\n\n    def test_mask_validation_with_3d_mask(self):\n        \"\"\"\n        Test TimeDistributed accepts 3D mask with correct leading dimensions.\n\n        This tests masks with additional dimensions beyond\n        (batch_size, timesteps).\n        \"\"\"\n        batch = 4\n        timesteps = 5\n        features = 3\n\n        td = layers.TimeDistributed(layers.Dense(units=5, activation=\"softmax\"))\n        inputs = np.zeros((batch, timesteps, features))\n\n        # 3D mask with correct batch and timestep dimensions\n        mask_3d = np.ones((batch, timesteps, 2), dtype=bool)\n        output = td(inputs, mask=mask_3d)\n        self.assertEqual(output.shape, (batch, timesteps, 5))\n\n        # 3D mask with mismatched timesteps should raise ValueError\n        mask_3d_mismatch = np.ones((batch, timesteps + 1, 2), dtype=bool)\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"The `mask` passed to the `TimeDistributed` layer has a shape.*\",\n        ):\n            td(inputs, mask=mask_3d_mismatch)\n\n    def test_mask_validation_with_none_mask(self):\n        \"\"\"Test TimeDistributed works correctly with None mask.\n\n        This tests that no validation error is raised when mask is None.\n        \"\"\"\n        batch = 4\n        timesteps = 5\n        features = 3\n\n        td = layers.TimeDistributed(layers.Dense(units=5, activation=\"softmax\"))\n        inputs = np.zeros((batch, timesteps, features))\n\n        # None mask should work without raising any error\n        output = td(inputs, mask=None)\n        self.assertEqual(output.shape, (batch, timesteps, 5))\n\n    def test_mask_validation_with_both_batch_and_timesteps_mismatched(self):\n        \"\"\"\n        Test 
TimeDistributed raises ValueError when both batch and timesteps\n        mismatch.\n\n        This ensures the validation catches cases where multiple dimensions\n        are wrong.\n        \"\"\"\n        batch = 4\n        timesteps = 5\n        features = 3\n\n        td = layers.TimeDistributed(layers.Dense(units=5, activation=\"softmax\"))\n        inputs = np.zeros((batch, timesteps, features))\n\n        # Mask with both batch and timesteps mismatched\n        mask_both_mismatch = np.ones((batch + 1, timesteps + 1), dtype=bool)\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"The `mask` passed to the `TimeDistributed` layer has a shape.*\",\n        ):\n            td(inputs, mask=mask_both_mismatch)\n"
  },
  {
    "path": "keras/src/legacy/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/legacy/backend.py",
    "content": "\"\"\"Legacy Keras 1/2 backend functions.\"\"\"\n\nimport itertools\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils.module_utils import tensorflow as tf\n\npy_any = any\npy_all = all\n\n\n@keras_export(\"keras._legacy.backend.abs\")\ndef abs(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.abs(x)\n\n\n@keras_export(\"keras._legacy.backend.all\")\ndef all(x, axis=None, keepdims=False):\n    \"\"\"DEPRECATED.\"\"\"\n    x = tf.cast(x, tf.bool)\n    return tf.reduce_all(x, axis, keepdims)\n\n\n@keras_export(\"keras._legacy.backend.any\")\ndef any(x, axis=None, keepdims=False):\n    \"\"\"DEPRECATED.\"\"\"\n    x = tf.cast(x, tf.bool)\n    return tf.reduce_any(x, axis, keepdims)\n\n\n@keras_export(\"keras._legacy.backend.argmax\")\ndef argmax(x, axis=-1):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.argmax(x, axis)\n\n\n@keras_export(\"keras._legacy.backend.argmin\")\ndef argmin(x, axis=-1):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.argmin(x, axis)\n\n\n@keras_export(\"keras._legacy.backend.arange\")\ndef arange(start, stop=None, step=1, dtype=\"int32\"):\n    \"\"\"DEPRECATED.\"\"\"\n    if stop is None and start < 0:\n        start = 0\n    result = tf.range(start, limit=stop, delta=step, name=\"arange\")\n    if dtype != \"int32\":\n        result = tf.cast(result, dtype)\n    return result\n\n\n@keras_export(\"keras._legacy.backend.batch_dot\")\ndef batch_dot(x, y, axes=None):\n    \"\"\"DEPRECATED.\"\"\"\n    x_shape = x.shape\n    y_shape = y.shape\n\n    x_ndim = len(x_shape)\n    y_ndim = len(y_shape)\n\n    if x_ndim < 2 or y_ndim < 2:\n        raise ValueError(\n            \"Cannot do batch_dot on inputs \"\n            \"with rank < 2. 
\"\n            f\"Received inputs with tf.shapes {x_shape} and {y_shape}.\"\n        )\n\n    x_batch_size = x_shape[0]\n    y_batch_size = y_shape[0]\n\n    if x_batch_size is not None and y_batch_size is not None:\n        if x_batch_size != y_batch_size:\n            raise ValueError(\n                \"Cannot do batch_dot on inputs \"\n                \"with different batch sizes. \"\n                \"Received inputs with tf.shapes \"\n                f\"{x_shape} and {y_shape}.\"\n            )\n    if isinstance(axes, int):\n        axes = [axes, axes]\n\n    if axes is None:\n        if y_ndim == 2:\n            axes = [x_ndim - 1, y_ndim - 1]\n        else:\n            axes = [x_ndim - 1, y_ndim - 2]\n\n    if py_any(isinstance(a, (list, tuple)) for a in axes):\n        raise ValueError(\n            \"Multiple target dimensions are not supported. \"\n            \"Expected: None, int, (int, int), \"\n            f\"Provided: {axes}\"\n        )\n\n    # if tuple, convert to list.\n    axes = list(axes)\n\n    # convert negative indices.\n    if axes[0] < 0:\n        axes[0] += x_ndim\n    if axes[1] < 0:\n        axes[1] += y_ndim\n\n    # sanity checks\n    if 0 in axes:\n        raise ValueError(\n            \"Cannot perform batch_dot over axis 0. \"\n            \"If your inputs are not batched, \"\n            \"add a dummy batch dimension to your \"\n            \"inputs using K.expand_dims(x, 0)\"\n        )\n    a0, a1 = axes\n    d1 = x_shape[a0]\n    d2 = y_shape[a1]\n\n    if d1 is not None and d2 is not None and d1 != d2:\n        raise ValueError(\n            \"Cannot do batch_dot on inputs with tf.shapes \"\n            f\"{x_shape} and {y_shape} with axes={axes}. \"\n            \"x.shape[%d] != y.shape[%d] (%d != %d).\"\n            % (axes[0], axes[1], d1, d2)\n        )\n\n    # backup ndims. 
Need them later.\n    orig_x_ndim = x_ndim\n    orig_y_ndim = y_ndim\n\n    # if rank is 2, expand to 3.\n    if x_ndim == 2:\n        x = tf.expand_dims(x, 1)\n        a0 += 1\n        x_ndim += 1\n    if y_ndim == 2:\n        y = tf.expand_dims(y, 2)\n        y_ndim += 1\n\n    # bring x's dimension to be reduced to last axis.\n    if a0 != x_ndim - 1:\n        pattern = list(range(x_ndim))\n        for i in range(a0, x_ndim - 1):\n            pattern[i] = pattern[i + 1]\n        pattern[-1] = a0\n        x = tf.transpose(x, pattern)\n\n    # bring y's dimension to be reduced to axis 1.\n    if a1 != 1:\n        pattern = list(range(y_ndim))\n        for i in range(a1, 1, -1):\n            pattern[i] = pattern[i - 1]\n        pattern[1] = a1\n        y = tf.transpose(y, pattern)\n\n    # normalize both inputs to rank 3.\n    if x_ndim > 3:\n        # squash middle dimensions of x.\n        x_shape = tf.shape(x)\n        x_mid_dims = x_shape[1:-1]\n        x_squashed_shape = tf.stack([x_shape[0], -1, x_shape[-1]])\n        x = tf.reshape(x, x_squashed_shape)\n        x_squashed = True\n    else:\n        x_squashed = False\n\n    if y_ndim > 3:\n        # squash trailing dimensions of y.\n        y_shape = tf.shape(y)\n        y_trail_dims = y_shape[2:]\n        y_squashed_shape = tf.stack([y_shape[0], y_shape[1], -1])\n        y = tf.reshape(y, y_squashed_shape)\n        y_squashed = True\n    else:\n        y_squashed = False\n\n    result = tf.matmul(x, y)\n\n    # if inputs were squashed, we have to reshape the matmul output.\n    output_shape = tf.shape(result)\n    do_reshape = False\n\n    if x_squashed:\n        output_shape = tf.concat(\n            [output_shape[:1], x_mid_dims, output_shape[-1:]], 0\n        )\n        do_reshape = True\n\n    if y_squashed:\n        output_shape = tf.concat([output_shape[:-1], y_trail_dims], 0)\n        do_reshape = True\n\n    if do_reshape:\n        result = tf.reshape(result, output_shape)\n\n    # if the inputs 
were originally rank 2, we remove the added 1 dim.\n    if orig_x_ndim == 2:\n        result = tf.squeeze(result, 1)\n    elif orig_y_ndim == 2:\n        result = tf.squeeze(result, -1)\n\n    return result\n\n\n@keras_export(\"keras._legacy.backend.batch_flatten\")\ndef batch_flatten(x):\n    \"\"\"DEPRECATED.\"\"\"\n    x = tf.reshape(x, tf.stack([-1, prod(tf.shape(x)[1:])]))\n    return x\n\n\n@keras_export(\"keras._legacy.backend.batch_get_value\")\ndef batch_get_value(tensors):\n    \"\"\"DEPRECATED.\"\"\"\n    return [x.numpy() for x in tensors]\n\n\n@keras_export(\"keras._legacy.backend.batch_set_value\")\ndef batch_set_value(tuples):\n    \"\"\"DEPRECATED.\"\"\"\n    if tf.executing_eagerly() or tf.inside_function():\n        for x, value in tuples:\n            value = np.asarray(value, dtype=x.dtype.name)\n            x.assign(value)\n\n\n@keras_export(\"keras._legacy.backend.batch_normalization\")\ndef batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)\n\n\n@keras_export(\"keras._legacy.backend.bias_add\")\ndef bias_add(x, bias, data_format=None):\n    \"\"\"DEPRECATED.\"\"\"\n    if data_format is None:\n        data_format = backend.image_data_format()\n    if data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(f\"Unknown data_format: {data_format}\")\n    bias_shape = bias.shape\n    if len(bias_shape) != 1 and len(bias_shape) != ndim(x) - 1:\n        raise ValueError(\n            f\"Unexpected bias dimensions {len(bias_shape)}. 
\"\n            f\"Expected it to be 1 or {ndim(x) - 1} dimensions\"\n        )\n\n    if len(bias_shape) == 1:\n        if data_format == \"channels_first\":\n            return tf.nn.bias_add(x, bias, data_format=\"NCHW\")\n        return tf.nn.bias_add(x, bias, data_format=\"NHWC\")\n    if ndim(x) in (3, 4, 5):\n        if data_format == \"channels_first\":\n            bias_reshape_axis = (1, bias_shape[-1]) + bias_shape[:-1]\n            return x + reshape(bias, bias_reshape_axis)\n        return x + reshape(bias, (1,) + bias_shape)\n    return tf.nn.bias_add(x, bias)\n\n\n@keras_export(\"keras._legacy.backend.binary_crossentropy\")\ndef binary_crossentropy(target, output, from_logits=False):\n    \"\"\"DEPRECATED.\"\"\"\n    target = tf.convert_to_tensor(target)\n    output = tf.convert_to_tensor(output)\n\n    if from_logits:\n        return tf.nn.sigmoid_cross_entropy_with_logits(\n            labels=target, logits=output\n        )\n\n    epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype)\n    output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)\n\n    # Compute cross entropy from probabilities.\n    bce = target * tf.math.log(output + backend.epsilon())\n    bce += (1 - target) * tf.math.log(1 - output + backend.epsilon())\n    return -bce\n\n\n@keras_export(\"keras._legacy.backend.binary_focal_crossentropy\")\ndef binary_focal_crossentropy(\n    target,\n    output,\n    apply_class_balancing=False,\n    alpha=0.25,\n    gamma=2.0,\n    from_logits=False,\n):\n    \"\"\"DEPRECATED.\"\"\"\n    sigmoidal = tf.sigmoid(output) if from_logits else output\n\n    p_t = target * sigmoidal + (1 - target) * (1 - sigmoidal)\n\n    # Calculate focal factor\n    focal_factor = tf.pow(1.0 - p_t, gamma)\n\n    # Binary crossentropy\n    bce = binary_crossentropy(\n        target=target,\n        output=output,\n        from_logits=from_logits,\n    )\n    focal_bce = focal_factor * bce\n\n    if apply_class_balancing:\n        weight = target * 
alpha + (1 - target) * (1 - alpha)\n        focal_bce = weight * focal_bce\n\n    return focal_bce\n\n\n@keras_export(\"keras._legacy.backend.cast\")\ndef cast(x, dtype):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.cast(x, dtype)\n\n\n@keras_export(\"keras._legacy.backend.cast_to_floatx\")\ndef cast_to_floatx(x):\n    \"\"\"DEPRECATED.\"\"\"\n    if isinstance(x, (tf.Tensor, tf.Variable, tf.SparseTensor)):\n        return tf.cast(x, dtype=backend.floatx())\n    return np.asarray(x, dtype=backend.floatx())\n\n\n@keras_export(\"keras._legacy.backend.categorical_crossentropy\")\ndef categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    \"\"\"DEPRECATED.\"\"\"\n    target = tf.convert_to_tensor(target)\n    output = tf.convert_to_tensor(output)\n    target.shape.assert_is_compatible_with(output.shape)\n\n    if from_logits:\n        return tf.nn.softmax_cross_entropy_with_logits(\n            labels=target, logits=output, axis=axis\n        )\n\n    # Adjust the predictions so that the probability of\n    # each class for every sample adds up to 1\n    # This is needed to ensure that the cross entropy is\n    # computed correctly.\n    output = output / tf.reduce_sum(output, axis, True)\n\n    # Compute cross entropy from probabilities.\n    epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype)\n    output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)\n    return -tf.reduce_sum(target * tf.math.log(output), axis)\n\n\n@keras_export(\"keras._legacy.backend.categorical_focal_crossentropy\")\ndef categorical_focal_crossentropy(\n    target,\n    output,\n    alpha=0.25,\n    gamma=2.0,\n    from_logits=False,\n    axis=-1,\n):\n    \"\"\"DEPRECATED.\"\"\"\n    target = tf.convert_to_tensor(target)\n    output = tf.convert_to_tensor(output)\n    target.shape.assert_is_compatible_with(output.shape)\n\n    if from_logits:\n        output = tf.nn.softmax(output, axis=axis)\n\n    # Adjust the predictions so that the probability of\n    # 
each class for every sample adds up to 1\n    # This is needed to ensure that the cross entropy is\n    # computed correctly.\n    output = output / tf.reduce_sum(output, axis=axis, keepdims=True)\n\n    epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype)\n    output = tf.clip_by_value(output, epsilon_, 1.0 - epsilon_)\n\n    # Calculate cross entropy\n    cce = -target * tf.math.log(output)\n\n    # Calculate factors\n    modulating_factor = tf.pow(1.0 - output, gamma)\n    weighting_factor = tf.multiply(modulating_factor, alpha)\n\n    # Apply weighting factor\n    focal_cce = tf.multiply(weighting_factor, cce)\n    focal_cce = tf.reduce_sum(focal_cce, axis=axis)\n    return focal_cce\n\n\n@keras_export(\"keras._legacy.backend.clip\")\ndef clip(x, min_value, max_value):\n    \"\"\"DEPRECATED.\"\"\"\n    if isinstance(min_value, (int, float)) and isinstance(\n        max_value, (int, float)\n    ):\n        if max_value < min_value:\n            max_value = min_value\n    if min_value is None:\n        min_value = -np.inf\n    if max_value is None:\n        max_value = np.inf\n    return tf.clip_by_value(x, min_value, max_value)\n\n\n@keras_export(\"keras._legacy.backend.concatenate\")\ndef concatenate(tensors, axis=-1):\n    \"\"\"DEPRECATED.\"\"\"\n    if axis < 0:\n        rank = ndim(tensors[0])\n        if rank:\n            axis %= rank\n        else:\n            axis = 0\n\n    if py_all(is_sparse(x) for x in tensors):\n        return tf.compat.v1.sparse_concat(axis, tensors)\n    elif py_all(isinstance(x, tf.RaggedTensor) for x in tensors):\n        return tf.concat(tensors, axis)\n    else:\n        return tf.concat([to_dense(x) for x in tensors], axis)\n\n\n@keras_export(\"keras._legacy.backend.constant\")\ndef constant(value, dtype=None, shape=None, name=None):\n    \"\"\"DEPRECATED.\"\"\"\n    if dtype is None:\n        dtype = backend.floatx()\n\n    return tf.constant(value, dtype=dtype, shape=shape, name=name)\n\n\ndef 
_preprocess_conv1d_input(x, data_format):\n    tf_data_format = \"NWC\"  # to pass TF Conv2dNative operations\n    if data_format == \"channels_first\":\n        tf_data_format = \"NCW\"\n    return x, tf_data_format\n\n\ndef _preprocess_conv2d_input(x, data_format, force_transpose=False):\n    tf_data_format = \"NHWC\"\n    if data_format == \"channels_first\":\n        if force_transpose:\n            x = tf.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC\n        else:\n            tf_data_format = \"NCHW\"\n    return x, tf_data_format\n\n\ndef _preprocess_conv3d_input(x, data_format):\n    tf_data_format = \"NDHWC\"\n    if data_format == \"channels_first\":\n        tf_data_format = \"NCDHW\"\n    return x, tf_data_format\n\n\ndef _preprocess_padding(padding):\n    if padding == \"same\":\n        padding = \"SAME\"\n    elif padding == \"valid\":\n        padding = \"VALID\"\n    else:\n        raise ValueError(f\"Invalid padding: {padding}\")\n    return padding\n\n\n@keras_export(\"keras._legacy.backend.conv1d\")\ndef conv1d(\n    x, kernel, strides=1, padding=\"valid\", data_format=None, dilation_rate=1\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if data_format is None:\n        data_format = backend.image_data_format()\n    if data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(f\"Unknown data_format: {data_format}\")\n\n    kernel_shape = kernel.shape.as_list()\n    if padding == \"causal\":\n        # causal (dilated) convolution:\n        left_pad = dilation_rate * (kernel_shape[0] - 1)\n        x = temporal_padding(x, (left_pad, 0))\n        padding = \"valid\"\n    padding = _preprocess_padding(padding)\n\n    x, tf_data_format = _preprocess_conv1d_input(x, data_format)\n    x = tf.compat.v1.nn.convolution(\n        input=x,\n        filter=kernel,\n        dilation_rate=dilation_rate,\n        strides=strides,\n        padding=padding,\n        data_format=tf_data_format,\n    )\n    if data_format == \"channels_first\" and 
tf_data_format == \"NWC\":\n        x = tf.transpose(x, (0, 2, 1))  # NWC -> NCW\n    return x\n\n\n@keras_export(\"keras._legacy.backend.conv2d\")\ndef conv2d(\n    x,\n    kernel,\n    strides=(1, 1),\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=(1, 1),\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if data_format is None:\n        data_format = backend.image_data_format()\n    if data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(f\"Unknown data_format: {data_format}\")\n\n    x, tf_data_format = _preprocess_conv2d_input(x, data_format)\n    padding = _preprocess_padding(padding)\n    x = tf.compat.v1.nn.convolution(\n        input=x,\n        filter=kernel,\n        dilation_rate=dilation_rate,\n        strides=strides,\n        padding=padding,\n        data_format=tf_data_format,\n    )\n    if data_format == \"channels_first\" and tf_data_format == \"NHWC\":\n        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW\n    return x\n\n\n@keras_export(\"keras._legacy.backend.conv2d_transpose\")\ndef conv2d_transpose(\n    x,\n    kernel,\n    output_shape,\n    strides=(1, 1),\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=(1, 1),\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if data_format is None:\n        data_format = backend.image_data_format()\n    if data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(f\"Unknown data_format: {data_format}\")\n\n    # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.\n    if data_format == \"channels_first\" and dilation_rate != (1, 1):\n        force_transpose = True\n    else:\n        force_transpose = False\n\n    x, tf_data_format = _preprocess_conv2d_input(\n        x, data_format, force_transpose\n    )\n\n    if data_format == \"channels_first\" and tf_data_format == \"NHWC\":\n        output_shape = (\n            output_shape[0],\n            output_shape[2],\n            output_shape[3],\n            
output_shape[1],\n        )\n    if output_shape[0] is None:\n        output_shape = (tf.shape(x)[0],) + tuple(output_shape[1:])\n\n    if isinstance(output_shape, (tuple, list)):\n        output_shape = tf.stack(list(output_shape))\n\n    padding = _preprocess_padding(padding)\n    if tf_data_format == \"NHWC\":\n        strides = (1,) + strides + (1,)\n    else:\n        strides = (1, 1) + strides\n\n    if dilation_rate == (1, 1):\n        x = tf.compat.v1.nn.conv2d_transpose(\n            x,\n            kernel,\n            output_shape,\n            strides,\n            padding=padding,\n            data_format=tf_data_format,\n        )\n    else:\n        if dilation_rate[0] != dilation_rate[1]:\n            raise ValueError(\n                \"Expected the 2 dimensions of the `dilation_rate` argument \"\n                \"to be equal to each other. \"\n                f\"Received: dilation_rate={dilation_rate}\"\n            )\n        x = tf.nn.atrous_conv2d_transpose(\n            x, kernel, output_shape, rate=dilation_rate[0], padding=padding\n        )\n    if data_format == \"channels_first\" and tf_data_format == \"NHWC\":\n        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW\n    return x\n\n\n@keras_export(\"keras._legacy.backend.conv3d\")\ndef conv3d(\n    x,\n    kernel,\n    strides=(1, 1, 1),\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=(1, 1, 1),\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if data_format is None:\n        data_format = backend.image_data_format()\n    if data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(f\"Unknown data_format: {data_format}\")\n\n    x, tf_data_format = _preprocess_conv3d_input(x, data_format)\n    padding = _preprocess_padding(padding)\n    x = tf.compat.v1.nn.convolution(\n        input=x,\n        filter=kernel,\n        dilation_rate=dilation_rate,\n        strides=strides,\n        padding=padding,\n        data_format=tf_data_format,\n    )\n    if 
data_format == \"channels_first\" and tf_data_format == \"NDHWC\":\n        x = tf.transpose(x, (0, 4, 1, 2, 3))\n    return x\n\n\n@keras_export(\"keras._legacy.backend.cos\")\ndef cos(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.cos(x)\n\n\n@keras_export(\"keras._legacy.backend.count_params\")\ndef count_params(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return np.prod(x.shape.as_list())\n\n\n@keras_export(\"keras._legacy.backend.ctc_batch_cost\")\ndef ctc_batch_cost(y_true, y_pred, input_length, label_length):\n    \"\"\"DEPRECATED.\"\"\"\n    label_length = tf.cast(tf.squeeze(label_length, axis=-1), tf.int32)\n    input_length = tf.cast(tf.squeeze(input_length, axis=-1), tf.int32)\n    sparse_labels = tf.cast(\n        ctc_label_dense_to_sparse(y_true, label_length), tf.int32\n    )\n\n    y_pred = tf.math.log(\n        tf.transpose(y_pred, perm=[1, 0, 2]) + backend.epsilon()\n    )\n\n    return tf.expand_dims(\n        tf.compat.v1.nn.ctc_loss(\n            inputs=y_pred, labels=sparse_labels, sequence_length=input_length\n        ),\n        1,\n    )\n\n\n@keras_export(\"keras._legacy.backend.ctc_label_dense_to_sparse\")\ndef ctc_label_dense_to_sparse(labels, label_lengths):\n    \"\"\"DEPRECATED.\"\"\"\n    label_shape = tf.shape(labels)\n    num_batches_tns = tf.stack([label_shape[0]])\n    max_num_labels_tns = tf.stack([label_shape[1]])\n\n    def range_less_than(old_input, current_input):\n        return tf.expand_dims(tf.range(tf.shape(old_input)[1]), 0) < tf.fill(\n            max_num_labels_tns, current_input\n        )\n\n    init = tf.cast(tf.fill([1, label_shape[1]], 0), tf.bool)\n    dense_mask = tf.compat.v1.scan(\n        range_less_than, label_lengths, initializer=init, parallel_iterations=1\n    )\n    dense_mask = dense_mask[:, 0, :]\n\n    label_array = tf.reshape(\n        tf.tile(tf.range(0, label_shape[1]), num_batches_tns), label_shape\n    )\n    label_ind = tf.compat.v1.boolean_mask(label_array, dense_mask)\n\n    batch_array = 
tf.transpose(\n        tf.reshape(\n            tf.tile(tf.range(0, label_shape[0]), max_num_labels_tns),\n            reverse(label_shape, 0),\n        )\n    )\n    batch_ind = tf.compat.v1.boolean_mask(batch_array, dense_mask)\n    indices = tf.transpose(\n        tf.reshape(concatenate([batch_ind, label_ind], axis=0), [2, -1])\n    )\n\n    vals_sparse = tf.compat.v1.gather_nd(labels, indices)\n\n    return tf.SparseTensor(\n        tf.cast(indices, tf.int64), vals_sparse, tf.cast(label_shape, tf.int64)\n    )\n\n\n@keras_export(\"keras._legacy.backend.ctc_decode\")\ndef ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):\n    \"\"\"DEPRECATED.\"\"\"\n    input_shape = tf.shape(y_pred)\n    num_samples, num_steps = input_shape[0], input_shape[1]\n    y_pred = tf.math.log(\n        tf.transpose(y_pred, perm=[1, 0, 2]) + backend.epsilon()\n    )\n    input_length = tf.cast(input_length, tf.int32)\n\n    if greedy:\n        (decoded, log_prob) = tf.nn.ctc_greedy_decoder(\n            inputs=y_pred, sequence_length=input_length\n        )\n    else:\n        (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(\n            inputs=y_pred,\n            sequence_length=input_length,\n            beam_width=beam_width,\n            top_paths=top_paths,\n        )\n    decoded_dense = []\n    for st in decoded:\n        st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))\n        decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))\n    return (decoded_dense, log_prob)\n\n\n@keras_export(\"keras._legacy.backend.cumsum\")\ndef cumsum(x, axis=0):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.cumsum(x, axis=axis)\n\n\n@keras_export(\"keras._legacy.backend.cumprod\")\ndef cumprod(x, axis=0):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.math.cumprod(x, axis=axis)\n\n\n@keras_export(\"keras._legacy.backend.depthwise_conv2d\")\ndef depthwise_conv2d(\n    x,\n    depthwise_kernel,\n    strides=(1, 1),\n    
padding=\"valid\",\n    data_format=None,\n    dilation_rate=(1, 1),\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if data_format is None:\n        data_format = backend.image_data_format()\n    if data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(f\"Unknown data_format: {data_format}\")\n\n    x, tf_data_format = _preprocess_conv2d_input(x, data_format)\n    padding = _preprocess_padding(padding)\n    if tf_data_format == \"NHWC\":\n        strides = (1,) + strides + (1,)\n    else:\n        strides = (1, 1) + strides\n\n    x = tf.nn.depthwise_conv2d(\n        x,\n        depthwise_kernel,\n        strides=strides,\n        padding=padding,\n        dilations=dilation_rate,\n        data_format=tf_data_format,\n    )\n    if data_format == \"channels_first\" and tf_data_format == \"NHWC\":\n        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW\n    return x\n\n\n@keras_export(\"keras._legacy.backend.dot\")\ndef dot(x, y):\n    \"\"\"DEPRECATED.\"\"\"\n    if ndim(x) is not None and (ndim(x) > 2 or ndim(y) > 2):\n        x_shape = []\n        for i, s in zip(x.shape, tf.unstack(tf.shape(x))):\n            if i is not None:\n                x_shape.append(i)\n            else:\n                x_shape.append(s)\n        x_shape = tuple(x_shape)\n        y_shape = []\n        for i, s in zip(y.shape, tf.unstack(tf.shape(y))):\n            if i is not None:\n                y_shape.append(i)\n            else:\n                y_shape.append(s)\n        y_shape = tuple(y_shape)\n        y_permute_dim = list(range(ndim(y)))\n        y_permute_dim = [y_permute_dim.pop(-2)] + y_permute_dim\n        xt = tf.reshape(x, [-1, x_shape[-1]])\n        yt = tf.reshape(tf.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])\n        return tf.reshape(\n            tf.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:]\n        )\n    if is_sparse(x):\n        out = tf.sparse.sparse_dense_matmul(x, y)\n    else:\n        out = tf.matmul(x, 
y)\n    return out\n\n\n@keras_export(\"keras._legacy.backend.dropout\")\ndef dropout(x, level, noise_shape=None, seed=None):\n    \"\"\"DEPRECATED.\"\"\"\n    if seed is None:\n        seed = np.random.randint(10e6)\n    return tf.nn.dropout(x, rate=level, noise_shape=noise_shape, seed=seed)\n\n\n@keras_export(\"keras._legacy.backend.dtype\")\ndef dtype(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return x.dtype.base_dtype.name\n\n\n@keras_export(\"keras._legacy.backend.elu\")\ndef elu(x, alpha=1.0):\n    \"\"\"DEPRECATED.\"\"\"\n    res = tf.nn.elu(x)\n    if alpha == 1:\n        return res\n    else:\n        return tf.where(x > 0, res, alpha * res)\n\n\n@keras_export(\"keras._legacy.backend.equal\")\ndef equal(x, y):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.equal(x, y)\n\n\n@keras_export(\"keras._legacy.backend.eval\")\ndef eval(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return get_value(to_dense(x))\n\n\n@keras_export(\"keras._legacy.backend.exp\")\ndef exp(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.exp(x)\n\n\n@keras_export(\"keras._legacy.backend.expand_dims\")\ndef expand_dims(x, axis=-1):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.expand_dims(x, axis)\n\n\n@keras_export(\"keras._legacy.backend.eye\")\ndef eye(size, dtype=None, name=None):\n    \"\"\"DEPRECATED.\"\"\"\n    if dtype is None:\n        dtype = backend.floatx()\n    tf_dtype = tf.as_dtype(dtype)\n    return variable(tf.eye(size, dtype=tf_dtype), dtype, name)\n\n\n@keras_export(\"keras._legacy.backend.flatten\")\ndef flatten(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.reshape(x, [-1])\n\n\n@keras_export(\"keras._legacy.backend.foldl\")\ndef foldl(fn, elems, initializer=None, name=None):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.compat.v1.foldl(fn, elems, initializer=initializer, name=name)\n\n\n@keras_export(\"keras._legacy.backend.foldr\")\ndef foldr(fn, elems, initializer=None, name=None):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.compat.v1.foldr(fn, elems, initializer=initializer, 
name=name)\n\n\n@keras_export(\"keras._legacy.backend.gather\")\ndef gather(reference, indices):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.compat.v1.gather(reference, indices)\n\n\n@keras_export(\"keras._legacy.backend.get_value\")\ndef get_value(x):\n    \"\"\"DEPRECATED.\"\"\"\n    if not tf.is_tensor(x):\n        return x\n    if tf.executing_eagerly() or isinstance(x, tf.__internal__.EagerTensor):\n        return x.numpy()\n    if not getattr(x, \"_in_graph_mode\", True):\n        # This is a variable which was created in an eager context, but is being\n        # evaluated from a Graph.\n        with tf.__internal__.eager_context.eager_mode():\n            return x.numpy()\n    with tf.init_scope():\n        return x.numpy()\n\n\n@keras_export(\"keras._legacy.backend.gradients\")\ndef gradients(loss, variables):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.compat.v1.gradients(\n        loss, variables, colocate_gradients_with_ops=True\n    )\n\n\n@keras_export(\"keras._legacy.backend.greater\")\ndef greater(x, y):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.greater(x, y)\n\n\n@keras_export(\"keras._legacy.backend.greater_equal\")\ndef greater_equal(x, y):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.greater_equal(x, y)\n\n\n@keras_export(\"keras._legacy.backend.hard_sigmoid\")\ndef hard_sigmoid(x):\n    \"\"\"DEPRECATED.\"\"\"\n    point_two = tf.convert_to_tensor(0.2, dtype=x.dtype)\n    point_five = tf.convert_to_tensor(0.5, dtype=x.dtype)\n    x = tf.multiply(x, point_two)\n    x = tf.add(x, point_five)\n    x = tf.clip_by_value(x, 0.0, 1.0)\n    return x\n\n\n@keras_export(\"keras._legacy.backend.in_top_k\")\ndef in_top_k(predictions, targets, k):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.compat.v1.math.in_top_k(predictions, targets, k)\n\n\n@keras_export(\"keras._legacy.backend.int_shape\")\ndef int_shape(x):\n    \"\"\"DEPRECATED.\"\"\"\n    try:\n        shape = x.shape\n        if not isinstance(shape, tuple):\n            shape = 
tuple(shape.as_list())\n        return shape\n    except ValueError:\n        return None\n\n\n@keras_export(\"keras._legacy.backend.is_sparse\")\ndef is_sparse(tensor):\n    \"\"\"DEPRECATED.\"\"\"\n    spec = getattr(tensor, \"_type_spec\", None)\n    if spec is not None:\n        return isinstance(spec, tf.SparseTensorSpec)\n    return isinstance(tensor, tf.SparseTensor)\n\n\n@keras_export(\"keras._legacy.backend.l2_normalize\")\ndef l2_normalize(x, axis=None):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.linalg.l2_normalize(x, axis=axis)\n\n\n@keras_export(\"keras._legacy.backend.less\")\ndef less(x, y):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.less(x, y)\n\n\n@keras_export(\"keras._legacy.backend.less_equal\")\ndef less_equal(x, y):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.less_equal(x, y)\n\n\n@keras_export(\"keras._legacy.backend.log\")\ndef log(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.math.log(x)\n\n\n@keras_export(\"keras._legacy.backend.map_fn\")\ndef map_fn(fn, elems, name=None, dtype=None):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.compat.v1.map_fn(fn, elems, name=name, dtype=dtype)\n\n\n@keras_export(\"keras._legacy.backend.max\")\ndef max(x, axis=None, keepdims=False):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.reduce_max(x, axis, keepdims)\n\n\n@keras_export(\"keras._legacy.backend.maximum\")\ndef maximum(x, y):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.maximum(x, y)\n\n\n@keras_export(\"keras._legacy.backend.mean\")\ndef mean(x, axis=None, keepdims=False):\n    \"\"\"DEPRECATED.\"\"\"\n    if x.dtype.base_dtype == tf.bool:\n        x = tf.cast(x, backend.floatx())\n    return tf.reduce_mean(x, axis, keepdims)\n\n\n@keras_export(\"keras._legacy.backend.min\")\ndef min(x, axis=None, keepdims=False):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.reduce_min(x, axis, keepdims)\n\n\n@keras_export(\"keras._legacy.backend.minimum\")\ndef minimum(x, y):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.minimum(x, 
y)\n\n\n@keras_export(\"keras._legacy.backend.moving_average_update\")\ndef moving_average_update(x, value, momentum):\n    \"\"\"DEPRECATED.\"\"\"\n    momentum = tf.cast(momentum, x.dtype)\n    value = tf.cast(value, x.dtype)\n    return x.assign_sub((x - value) * (1 - momentum))\n\n\n@keras_export(\"keras._legacy.backend.name_scope\")\ndef name_scope(name):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.name_scope(name)\n\n\n@keras_export(\"keras._legacy.backend.ndim\")\ndef ndim(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return x.shape.rank\n\n\n@keras_export(\"keras._legacy.backend.not_equal\")\ndef not_equal(x, y):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.not_equal(x, y)\n\n\n@keras_export(\"keras._legacy.backend.one_hot\")\ndef one_hot(indices, num_classes):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.one_hot(indices, depth=num_classes, axis=-1)\n\n\n@keras_export(\"keras._legacy.backend.ones\")\ndef ones(shape, dtype=None, name=None):\n    \"\"\"DEPRECATED.\"\"\"\n    with tf.init_scope():\n        if dtype is None:\n            dtype = backend.floatx()\n        tf_dtype = tf.as_dtype(dtype)\n        v = tf.ones(shape=shape, dtype=tf_dtype, name=name)\n        if py_all(v.shape.as_list()):\n            return variable(v, dtype=dtype, name=name)\n        return v\n\n\n@keras_export(\"keras._legacy.backend.ones_like\")\ndef ones_like(x, dtype=None, name=None):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.ones_like(x, dtype=dtype, name=name)\n\n\n@keras_export(\"keras._legacy.backend.permute_dimensions\")\ndef permute_dimensions(x, pattern):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.transpose(x, perm=pattern)\n\n\n@keras_export(\"keras._legacy.backend.pool2d\")\ndef pool2d(\n    x,\n    pool_size,\n    strides=(1, 1),\n    padding=\"valid\",\n    data_format=None,\n    pool_mode=\"max\",\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if data_format is None:\n        data_format = backend.image_data_format()\n    if data_format not in {\"channels_first\", 
\"channels_last\"}:\n        raise ValueError(f\"Unknown data_format: {data_format}\")\n    if len(pool_size) != 2:\n        raise ValueError(\"`pool_size` must be a tuple of 2 integers.\")\n    if len(strides) != 2:\n        raise ValueError(\"`strides` must be a tuple of 2 integers.\")\n\n    x, tf_data_format = _preprocess_conv2d_input(x, data_format)\n    padding = _preprocess_padding(padding)\n    if tf_data_format == \"NHWC\":\n        strides = (1,) + strides + (1,)\n        pool_size = (1,) + pool_size + (1,)\n    else:\n        strides = (1, 1) + strides\n        pool_size = (1, 1) + pool_size\n\n    if pool_mode == \"max\":\n        x = tf.compat.v1.nn.max_pool(\n            x, pool_size, strides, padding=padding, data_format=tf_data_format\n        )\n    elif pool_mode == \"avg\":\n        x = tf.compat.v1.nn.avg_pool(\n            x, pool_size, strides, padding=padding, data_format=tf_data_format\n        )\n    else:\n        raise ValueError(f\"Invalid pooling mode: {str(pool_mode)}\")\n\n    if data_format == \"channels_first\" and tf_data_format == \"NHWC\":\n        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW\n    return x\n\n\n@keras_export(\"keras._legacy.backend.pool3d\")\ndef pool3d(\n    x,\n    pool_size,\n    strides=(1, 1, 1),\n    padding=\"valid\",\n    data_format=None,\n    pool_mode=\"max\",\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if data_format is None:\n        data_format = backend.image_data_format()\n    if data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(f\"Unknown data_format: {data_format}\")\n\n    x, tf_data_format = _preprocess_conv3d_input(x, data_format)\n    padding = _preprocess_padding(padding)\n    if tf_data_format == \"NDHWC\":\n        strides = (1,) + strides + (1,)\n        pool_size = (1,) + pool_size + (1,)\n    else:\n        strides = (1, 1) + strides\n        pool_size = (1, 1) + pool_size\n\n    if pool_mode == \"max\":\n        x = tf.nn.max_pool3d(\n            x, 
pool_size, strides, padding=padding, data_format=tf_data_format\n        )\n    elif pool_mode == \"avg\":\n        x = tf.nn.avg_pool3d(\n            x, pool_size, strides, padding=padding, data_format=tf_data_format\n        )\n    else:\n        raise ValueError(f\"Invalid pooling mode: {str(pool_mode)}\")\n\n    if data_format == \"channels_first\" and tf_data_format == \"NDHWC\":\n        x = tf.transpose(x, (0, 4, 1, 2, 3))\n    return x\n\n\n@keras_export(\"keras._legacy.backend.pow\")\ndef pow(x, a):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.pow(x, a)\n\n\n@keras_export(\"keras._legacy.backend.prod\")\ndef prod(x, axis=None, keepdims=False):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.reduce_prod(x, axis, keepdims)\n\n\n@keras_export(\"keras._legacy.backend.random_bernoulli\")\ndef random_bernoulli(shape, p=0.0, dtype=None, seed=None):\n    \"\"\"DEPRECATED.\"\"\"\n    if dtype is None:\n        dtype = backend.floatx()\n    if seed is None:\n        seed = np.random.randint(10e6)\n    return tf.where(\n        tf.random.uniform(shape, dtype=dtype, seed=seed) <= p,\n        tf.ones(shape, dtype=dtype),\n        tf.zeros(shape, dtype=dtype),\n    )\n\n\n@keras_export(\"keras._legacy.backend.random_normal\")\ndef random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    \"\"\"DEPRECATED.\"\"\"\n    if dtype is None:\n        dtype = backend.floatx()\n    if seed is None:\n        seed = np.random.randint(10e6)\n    return tf.random.normal(\n        shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed\n    )\n\n\n@keras_export(\"keras._legacy.backend.random_normal_variable\")\ndef random_normal_variable(\n    shape, mean, scale, dtype=None, name=None, seed=None\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if dtype is None:\n        dtype = backend.floatx()\n    tf_dtype = tf.as_dtype(dtype)\n    if seed is None:\n        # ensure that randomness is conditioned by the Numpy RNG\n        seed = np.random.randint(10e8)\n    value = 
tf.compat.v1.random_normal_initializer(\n        mean, scale, dtype=tf_dtype, seed=seed\n    )(shape)\n    return variable(value, dtype=dtype, name=name)\n\n\n@keras_export(\"keras._legacy.backend.random_uniform\")\ndef random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):\n    \"\"\"DEPRECATED.\"\"\"\n    if dtype is None:\n        dtype = backend.floatx()\n    if seed is None:\n        seed = np.random.randint(10e6)\n    return tf.random.uniform(\n        shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed\n    )\n\n\n@keras_export(\"keras._legacy.backend.random_uniform_variable\")\ndef random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):\n    \"\"\"DEPRECATED.\"\"\"\n    if dtype is None:\n        dtype = backend.floatx()\n    tf_dtype = tf.as_dtype(dtype)\n    if seed is None:\n        # ensure that randomness is conditioned by the Numpy RNG\n        seed = np.random.randint(10e8)\n    value = tf.compat.v1.random_uniform_initializer(\n        low, high, dtype=tf_dtype, seed=seed\n    )(shape)\n    return variable(value, dtype=dtype, name=name)\n\n\n@keras_export(\"keras._legacy.backend.reshape\")\ndef reshape(x, shape):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.reshape(x, shape)\n\n\n@keras_export(\"keras._legacy.backend.relu\")\ndef relu(x, alpha=0.0, max_value=None, threshold=0.0):\n    \"\"\"DEPRECATED.\"\"\"\n    # While x can be a tensor or variable, we also see cases where\n    # numpy arrays, lists, tuples are passed as well.\n    # lists, tuples do not have 'dtype' attribute.\n    dtype = getattr(x, \"dtype\", backend.floatx())\n    if alpha != 0.0:\n        if max_value is None and threshold == 0:\n            return tf.nn.leaky_relu(x, alpha=alpha)\n\n        if threshold != 0:\n            negative_part = tf.nn.relu(-x + threshold)\n        else:\n            negative_part = tf.nn.relu(-x)\n    else:\n        negative_part = 1\n\n    clip_max = max_value is not None\n\n    if threshold != 0:\n        
# computes x for x > threshold else 0\n        x = x * tf.cast(tf.greater(x, threshold), dtype=dtype)\n    elif max_value == 6:\n        # if no threshold, then can use nn.relu6 native TF op for performance\n        x = tf.nn.relu6(x)\n        clip_max = False\n    else:\n        x = tf.nn.relu(x)\n\n    if clip_max:\n        max_value = tf.convert_to_tensor(max_value, dtype=x.dtype)\n        zero = tf.convert_to_tensor(0, dtype=x.dtype)\n        x = tf.clip_by_value(x, zero, max_value)\n\n    if alpha != 0.0:\n        alpha = tf.convert_to_tensor(alpha, dtype=x.dtype)\n        x -= alpha * negative_part\n    return x\n\n\n@keras_export(\"keras._legacy.backend.repeat\")\ndef repeat(x, n):\n    \"\"\"DEPRECATED.\"\"\"\n    if ndim(x) != 2:\n        raise ValueError(\n            f\"Expected input `x` to have rank 2. Received: rank(x)={ndim(x)}\"\n        )\n    x = tf.expand_dims(x, 1)\n    pattern = tf.stack([1, n, 1])\n    return tf.tile(x, pattern)\n\n\n@keras_export(\"keras._legacy.backend.repeat_elements\")\ndef repeat_elements(x, rep, axis):\n    \"\"\"DEPRECATED.\"\"\"\n    x_shape = x.shape.as_list()\n    # For static axis\n    if x_shape[axis] is not None:\n        # slices along the repeat axis\n        splits = tf.split(value=x, num_or_size_splits=x_shape[axis], axis=axis)\n        # repeat each slice the given number of reps\n        x_rep = [s for s in splits for _ in range(rep)]\n        return concatenate(x_rep, axis)\n\n    # Here we use tf.tile to mimic behavior of np.repeat so that\n    # we can handle dynamic shapes (that include None).\n    # To do that, we need an auxiliary axis to repeat elements along\n    # it and then merge them along the desired axis.\n\n    # Repeating\n    auxiliary_axis = axis + 1\n    x_shape = tf.shape(x)\n    x_rep = tf.expand_dims(x, axis=auxiliary_axis)\n    reps = np.ones(len(x.shape) + 1)\n    reps[auxiliary_axis] = rep\n    x_rep = tf.tile(x_rep, reps)\n\n    # Merging\n    reps = np.delete(reps, 
auxiliary_axis)\n    reps[axis] = rep\n    reps = tf.constant(reps, dtype=\"int32\")\n    x_shape *= reps\n    x_rep = tf.reshape(x_rep, x_shape)\n\n    # Fix shape representation\n    x_shape = x.shape.as_list()\n    x_rep.set_shape(x_shape)\n    return x_rep\n\n\n@keras_export(\"keras._legacy.backend.resize_images\")\ndef resize_images(\n    x, height_factor, width_factor, data_format, interpolation=\"nearest\"\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if data_format == \"channels_first\":\n        rows, cols = 2, 3\n    elif data_format == \"channels_last\":\n        rows, cols = 1, 2\n    else:\n        raise ValueError(f\"Invalid `data_format` argument: {data_format}\")\n\n    new_shape = x.shape[rows : cols + 1]\n    if new_shape.is_fully_defined():\n        new_shape = tf.constant(new_shape.as_list(), dtype=\"int32\")\n    else:\n        new_shape = tf.shape(x)[rows : cols + 1]\n    new_shape *= tf.constant(\n        np.array([height_factor, width_factor], dtype=\"int32\")\n    )\n\n    if data_format == \"channels_first\":\n        x = permute_dimensions(x, [0, 2, 3, 1])\n    interpolations = {\n        \"area\": tf.image.ResizeMethod.AREA,\n        \"bicubic\": tf.image.ResizeMethod.BICUBIC,\n        \"bilinear\": tf.image.ResizeMethod.BILINEAR,\n        \"gaussian\": tf.image.ResizeMethod.GAUSSIAN,\n        \"lanczos3\": tf.image.ResizeMethod.LANCZOS3,\n        \"lanczos5\": tf.image.ResizeMethod.LANCZOS5,\n        \"mitchellcubic\": tf.image.ResizeMethod.MITCHELLCUBIC,\n        \"nearest\": tf.image.ResizeMethod.NEAREST_NEIGHBOR,\n    }\n    interpolations_list = '\"' + '\", \"'.join(interpolations.keys()) + '\"'\n    if interpolation in interpolations:\n        x = tf.image.resize(x, new_shape, method=interpolations[interpolation])\n    else:\n        raise ValueError(\n            \"`interpolation` argument should be one of: \"\n            f'{interpolations_list}. 
Received: \"{interpolation}\".'\n        )\n    if data_format == \"channels_first\":\n        x = permute_dimensions(x, [0, 3, 1, 2])\n\n    return x\n\n\n@keras_export(\"keras._legacy.backend.resize_volumes\")\ndef resize_volumes(x, depth_factor, height_factor, width_factor, data_format):\n    \"\"\"DEPRECATED.\"\"\"\n    if data_format == \"channels_first\":\n        output = repeat_elements(x, depth_factor, axis=2)\n        output = repeat_elements(output, height_factor, axis=3)\n        output = repeat_elements(output, width_factor, axis=4)\n        return output\n    elif data_format == \"channels_last\":\n        output = repeat_elements(x, depth_factor, axis=1)\n        output = repeat_elements(output, height_factor, axis=2)\n        output = repeat_elements(output, width_factor, axis=3)\n        return output\n    else:\n        raise ValueError(f\"Invalid data_format: {data_format}\")\n\n\n@keras_export(\"keras._legacy.backend.reverse\")\ndef reverse(x, axes):\n    \"\"\"DEPRECATED.\"\"\"\n    if isinstance(axes, int):\n        axes = [axes]\n    return tf.reverse(x, axes)\n\n\n@keras_export(\"keras._legacy.backend.rnn\")\ndef rnn(\n    step_function,\n    inputs,\n    initial_states,\n    go_backwards=False,\n    mask=None,\n    constants=None,\n    unroll=False,\n    input_length=None,\n    time_major=False,\n    zero_output_for_mask=False,\n    return_all_outputs=True,\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if not tf.__internal__.tf2.enabled():\n        return_all_outputs = True  # Not supported in TF1.\n\n    def swap_batch_timestep(input_t):\n        # Swap the batch and timestep dim for the incoming tensor.\n        axes = list(range(len(input_t.shape)))\n        axes[0], axes[1] = 1, 0\n        return tf.transpose(input_t, axes)\n\n    if not time_major:\n        inputs = tf.nest.map_structure(swap_batch_timestep, inputs)\n\n    flatted_inputs = tf.nest.flatten(inputs)\n    time_steps = flatted_inputs[0].shape[0]\n    batch = 
flatted_inputs[0].shape[1]\n    time_steps_t = tf.shape(flatted_inputs[0])[0]\n\n    for input_ in flatted_inputs:\n        input_.shape.with_rank_at_least(3)\n\n    if mask is not None:\n        if mask.dtype != tf.bool:\n            mask = tf.cast(mask, tf.bool)\n        if len(mask.shape) == 2:\n            mask = expand_dims(mask)\n        if not time_major:\n            mask = swap_batch_timestep(mask)\n\n    if constants is None:\n        constants = []\n\n    # tf.where needs its condition tensor to be the same shape as its two\n    # result tensors, but in our case the condition (mask) tensor is\n    # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.\n    # So we need to broadcast the mask to match the shape of inputs.\n    # That's what the tile call does: it just repeats the mask along its\n    # second dimension n times.\n    def _expand_mask(mask_t, input_t, fixed_dim=1):\n        if tf.nest.is_nested(mask_t):\n            raise ValueError(\n                f\"mask_t is expected to be a tensor, but got {mask_t}\"\n            )\n        if tf.nest.is_nested(input_t):\n            raise ValueError(\n                f\"input_t is expected to be a tensor, but got {input_t}\"\n            )\n        rank_diff = len(input_t.shape) - len(mask_t.shape)\n        for _ in range(rank_diff):\n            mask_t = tf.expand_dims(mask_t, -1)\n        multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]\n        return tf.tile(mask_t, multiples)\n\n    if unroll:\n        if not time_steps:\n            raise ValueError(\"Unrolling requires a fixed number of timesteps.\")\n        states = tuple(initial_states)\n        successive_states = []\n        successive_outputs = []\n\n        # Process the input tensors. The input tensor needs to be split on\n        # the time_step dim, and reversed if go_backwards is True. In the\n        # case of nested input, the input is flattened and then transformed\n        # individually.  
The result of this will be a tuple of lists; each item\n        # in the tuple is a list of tensors with shape (batch, feature).\n        def _process_single_input_t(input_t):\n            input_t = tf.unstack(input_t)  # unstack for time_step dim\n            if go_backwards:\n                input_t.reverse()\n            return input_t\n\n        if tf.nest.is_nested(inputs):\n            processed_input = tf.nest.map_structure(\n                _process_single_input_t, inputs\n            )\n        else:\n            processed_input = (_process_single_input_t(inputs),)\n\n        def _get_input_tensor(time):\n            inp = [t_[time] for t_ in processed_input]\n            return tf.nest.pack_sequence_as(inputs, inp)\n\n        if mask is not None:\n            mask_list = tf.unstack(mask)\n            if go_backwards:\n                mask_list.reverse()\n\n            for i in range(time_steps):\n                inp = _get_input_tensor(i)\n                mask_t = mask_list[i]\n                output, new_states = step_function(\n                    inp, tuple(states) + tuple(constants)\n                )\n                tiled_mask_t = _expand_mask(mask_t, output)\n\n                if not successive_outputs:\n                    prev_output = zeros_like(output)\n                else:\n                    prev_output = successive_outputs[-1]\n\n                output = tf.where(tiled_mask_t, output, prev_output)\n\n                flat_states = tf.nest.flatten(states)\n                flat_new_states = tf.nest.flatten(new_states)\n                tiled_mask_t = tuple(\n                    _expand_mask(mask_t, s) for s in flat_states\n                )\n                flat_final_states = tuple(\n                    tf.where(m, s, ps)\n                    for m, s, ps in zip(\n                        tiled_mask_t, flat_new_states, flat_states\n                    )\n                )\n                states = tf.nest.pack_sequence_as(states, 
flat_final_states)\n\n                if return_all_outputs:\n                    successive_outputs.append(output)\n                    successive_states.append(states)\n                else:\n                    successive_outputs = [output]\n                    successive_states = [states]\n            last_output = successive_outputs[-1]\n            new_states = successive_states[-1]\n            outputs = tf.stack(successive_outputs)\n\n            if zero_output_for_mask:\n                last_output = tf.where(\n                    _expand_mask(mask_list[-1], last_output),\n                    last_output,\n                    zeros_like(last_output),\n                )\n                outputs = tf.where(\n                    _expand_mask(mask, outputs, fixed_dim=2),\n                    outputs,\n                    zeros_like(outputs),\n                )\n\n        else:  # mask is None\n            for i in range(time_steps):\n                inp = _get_input_tensor(i)\n                output, states = step_function(\n                    inp, tuple(states) + tuple(constants)\n                )\n                if return_all_outputs:\n                    successive_outputs.append(output)\n                    successive_states.append(states)\n                else:\n                    successive_outputs = [output]\n                    successive_states = [states]\n            last_output = successive_outputs[-1]\n            new_states = successive_states[-1]\n            outputs = tf.stack(successive_outputs)\n\n    else:  # Unroll == False\n        states = tuple(initial_states)\n\n        # Create the input tensor arrays. If the input is a nested structure\n        # of tensors, it will be flattened first, and one tensor array will\n        # be created per flattened tensor.\n        input_ta = tuple(\n            tf.TensorArray(\n                dtype=inp.dtype,\n                size=time_steps_t,\n                tensor_array_name=f\"input_ta_{i}\",\n            
)\n            for i, inp in enumerate(flatted_inputs)\n        )\n        input_ta = tuple(\n            (\n                ta.unstack(input_)\n                if not go_backwards\n                else ta.unstack(reverse(input_, 0))\n            )\n            for ta, input_ in zip(input_ta, flatted_inputs)\n        )\n\n        # Get the time(0) input and compute the output for it; the output\n        # will be used to determine the dtype of the output tensor array.\n        # Don't read from input_ta because TensorArray's clear_after_read\n        # defaults to True.\n        input_time_zero = tf.nest.pack_sequence_as(\n            inputs, [inp[0] for inp in flatted_inputs]\n        )\n        # output_time_zero is used to determine the cell output shape and its\n        # dtype. The value is discarded.\n        output_time_zero, _ = step_function(\n            input_time_zero, tuple(initial_states) + tuple(constants)\n        )\n\n        output_ta_size = time_steps_t if return_all_outputs else 1\n        output_ta = tuple(\n            tf.TensorArray(\n                dtype=out.dtype,\n                size=output_ta_size,\n                element_shape=out.shape,\n                tensor_array_name=f\"output_ta_{i}\",\n            )\n            for i, out in enumerate(tf.nest.flatten(output_time_zero))\n        )\n\n        time = tf.constant(0, dtype=\"int32\", name=\"time\")\n\n        if input_length is None:\n            max_iterations = time_steps_t\n        else:\n            max_iterations = tf.reduce_max(input_length)\n\n        while_loop_kwargs = {\n            \"cond\": lambda time, *_: time < time_steps_t,\n            \"maximum_iterations\": max_iterations,\n            \"parallel_iterations\": 32,\n            \"swap_memory\": True,\n        }\n        if mask is not None:\n            if go_backwards:\n                mask = reverse(mask, 0)\n\n            mask_ta = tf.TensorArray(\n                dtype=tf.bool, size=time_steps_t, 
tensor_array_name=\"mask_ta\"\n            )\n            mask_ta = mask_ta.unstack(mask)\n\n            def masking_fn(time):\n                return mask_ta.read(time)\n\n            def compute_masked_output(mask_t, flat_out, flat_mask):\n                tiled_mask_t = tuple(\n                    _expand_mask(mask_t, o, fixed_dim=len(mask_t.shape))\n                    for o in flat_out\n                )\n                return tuple(\n                    tf.where(m, o, fm)\n                    for m, o, fm in zip(tiled_mask_t, flat_out, flat_mask)\n                )\n\n        elif isinstance(input_length, tf.Tensor):\n            if go_backwards:\n                max_len = tf.reduce_max(input_length, axis=0)\n                rev_input_length = tf.subtract(max_len - 1, input_length)\n\n                def masking_fn(time):\n                    return tf.less(rev_input_length, time)\n\n            else:\n\n                def masking_fn(time):\n                    return tf.greater(input_length, time)\n\n            def compute_masked_output(mask_t, flat_out, flat_mask):\n                return tuple(\n                    tf.compat.v1.where(mask_t, o, zo)\n                    for (o, zo) in zip(flat_out, flat_mask)\n                )\n\n        else:\n            masking_fn = None\n\n        if masking_fn is not None:\n            # Mask for the T output will be based on the output of T - 1. 
In the\n            # case T = 0, a zero filled tensor will be used.\n            flat_zero_output = tuple(\n                tf.zeros_like(o) for o in tf.nest.flatten(output_time_zero)\n            )\n\n            def _step(time, output_ta_t, prev_output, *states):\n                \"\"\"RNN step function.\n\n                Args:\n                    time: Current timestep value.\n                    output_ta_t: TensorArray.\n                    prev_output: tuple of outputs from time - 1.\n                    *states: List of states.\n\n                Returns:\n                    Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`\n                \"\"\"\n                current_input = tuple(ta.read(time) for ta in input_ta)\n                # maybe set shape.\n                current_input = tf.nest.pack_sequence_as(inputs, current_input)\n                mask_t = masking_fn(time)\n                output, new_states = step_function(\n                    current_input, tuple(states) + tuple(constants)\n                )\n                # mask output\n                flat_output = tf.nest.flatten(output)\n                flat_mask_output = (\n                    flat_zero_output\n                    if zero_output_for_mask\n                    else tf.nest.flatten(prev_output)\n                )\n                flat_new_output = compute_masked_output(\n                    mask_t, flat_output, flat_mask_output\n                )\n\n                # mask states\n                flat_state = tf.nest.flatten(states)\n                flat_new_state = tf.nest.flatten(new_states)\n                for state, new_state in zip(flat_state, flat_new_state):\n                    if isinstance(new_state, tf.Tensor):\n                        new_state.set_shape(state.shape)\n                flat_final_state = compute_masked_output(\n                    mask_t, flat_new_state, flat_state\n                )\n                new_states = tf.nest.pack_sequence_as(\n 
                   new_states, flat_final_state\n                )\n\n                ta_index_to_write = time if return_all_outputs else 0\n                output_ta_t = tuple(\n                    ta.write(ta_index_to_write, out)\n                    for ta, out in zip(output_ta_t, flat_new_output)\n                )\n\n                return (time + 1, output_ta_t, tuple(flat_new_output)) + tuple(\n                    new_states\n                )\n\n            final_outputs = tf.compat.v1.while_loop(\n                body=_step,\n                loop_vars=(time, output_ta, flat_zero_output) + states,\n                **while_loop_kwargs,\n            )\n            # Skip final_outputs[2] which is the output for final timestep.\n            new_states = final_outputs[3:]\n        else:\n\n            def _step(time, output_ta_t, *states):\n                \"\"\"RNN step function.\n\n                Args:\n                    time: Current timestep value.\n                    output_ta_t: TensorArray.\n                    *states: List of states.\n\n                Returns:\n                    Tuple: `(time + 1,output_ta_t) + tuple(new_states)`\n                \"\"\"\n                current_input = tuple(ta.read(time) for ta in input_ta)\n                current_input = tf.nest.pack_sequence_as(inputs, current_input)\n                output, new_states = step_function(\n                    current_input, tuple(states) + tuple(constants)\n                )\n                flat_state = tf.nest.flatten(states)\n                flat_new_state = tf.nest.flatten(new_states)\n                for state, new_state in zip(flat_state, flat_new_state):\n                    if isinstance(new_state, tf.Tensor):\n                        new_state.set_shape(state.shape)\n\n                flat_output = tf.nest.flatten(output)\n                ta_index_to_write = time if return_all_outputs else 0\n                output_ta_t = tuple(\n                    
ta.write(ta_index_to_write, out)\n                    for ta, out in zip(output_ta_t, flat_output)\n                )\n\n                new_states = tf.nest.pack_sequence_as(\n                    initial_states, flat_new_state\n                )\n                return (time + 1, output_ta_t) + tuple(new_states)\n\n            final_outputs = tf.compat.v1.while_loop(\n                body=_step,\n                loop_vars=(time, output_ta) + states,\n                **while_loop_kwargs,\n            )\n            new_states = final_outputs[2:]\n\n        output_ta = final_outputs[1]\n\n        outputs = tuple(o.stack() for o in output_ta)\n        last_output = tuple(o[-1] for o in outputs)\n\n        outputs = tf.nest.pack_sequence_as(output_time_zero, outputs)\n        last_output = tf.nest.pack_sequence_as(output_time_zero, last_output)\n\n    # static shape inference\n    def set_shape(output_):\n        if isinstance(output_, tf.Tensor):\n            shape = output_.shape.as_list()\n            if return_all_outputs:\n                shape[0] = time_steps\n            else:\n                shape[0] = 1\n            shape[1] = batch\n            output_.set_shape(shape)\n        return output_\n\n    outputs = tf.nest.map_structure(set_shape, outputs)\n\n    if not time_major:\n        outputs = tf.nest.map_structure(swap_batch_timestep, outputs)\n\n    return last_output, outputs, new_states\n\n\n@keras_export(\"keras._legacy.backend.round\")\ndef round(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.round(x)\n\n\n@keras_export(\"keras._legacy.backend.separable_conv2d\")\ndef separable_conv2d(\n    x,\n    depthwise_kernel,\n    pointwise_kernel,\n    strides=(1, 1),\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=(1, 1),\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if data_format is None:\n        data_format = backend.image_data_format()\n    if data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(f\"Unknown 
data_format: {data_format}\")\n    if len(strides) != 2:\n        raise ValueError(\"`strides` must be a tuple of 2 integers.\")\n\n    x, tf_data_format = _preprocess_conv2d_input(x, data_format)\n    padding = _preprocess_padding(padding)\n    if not isinstance(strides, tuple):\n        strides = tuple(strides)\n    if tf_data_format == \"NHWC\":\n        strides = (1,) + strides + (1,)\n    else:\n        strides = (1, 1) + strides\n\n    x = tf.nn.separable_conv2d(\n        x,\n        depthwise_kernel,\n        pointwise_kernel,\n        strides=strides,\n        padding=padding,\n        dilations=dilation_rate,\n        data_format=tf_data_format,\n    )\n    if data_format == \"channels_first\" and tf_data_format == \"NHWC\":\n        x = tf.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW\n    return x\n\n\n@keras_export(\"keras._legacy.backend.set_value\")\ndef set_value(x, value):\n    \"\"\"DEPRECATED.\"\"\"\n    value = np.asarray(value, dtype=x.dtype.name)\n    x.assign(value)\n\n\n@keras_export(\"keras._legacy.backend.shape\")\ndef shape(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.shape(x)\n\n\n@keras_export(\"keras._legacy.backend.sigmoid\")\ndef sigmoid(x):\n    \"\"\"DEPRECATED.\"\"\"\n    output = tf.sigmoid(x)\n    return output\n\n\n@keras_export(\"keras._legacy.backend.sign\")\ndef sign(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.sign(x)\n\n\n@keras_export(\"keras._legacy.backend.sin\")\ndef sin(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.sin(x)\n\n\n@keras_export(\"keras._legacy.backend.softmax\")\ndef softmax(x, axis=-1):\n    \"\"\"DEPRECATED.\"\"\"\n    if x.shape.rank <= 1:\n        raise ValueError(\n            f\"Cannot apply softmax to a tensor that is 1D. 
Received input: {x}\"\n        )\n\n    if isinstance(axis, int):\n        output = tf.nn.softmax(x, axis=axis)\n    else:\n        # nn.softmax does not support tuple axis.\n        numerator = tf.exp(x - tf.reduce_max(x, axis=axis, keepdims=True))\n        denominator = tf.reduce_sum(numerator, axis=axis, keepdims=True)\n        output = numerator / denominator\n\n    # Cache the logits to use for crossentropy loss.\n    output._keras_logits = x\n    return output\n\n\n@keras_export(\"keras._legacy.backend.softplus\")\ndef softplus(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.math.softplus(x)\n\n\n@keras_export(\"keras._legacy.backend.softsign\")\ndef softsign(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.math.softsign(x)\n\n\n@keras_export(\"keras._legacy.backend.sparse_categorical_crossentropy\")\ndef sparse_categorical_crossentropy(\n    target, output, from_logits=False, axis=-1, ignore_class=None\n):\n    \"\"\"DEPRECATED.\"\"\"\n    target = tf.convert_to_tensor(target)\n    output = tf.convert_to_tensor(output)\n\n    target = cast(target, \"int64\")\n\n    if not from_logits:\n        epsilon_ = tf.convert_to_tensor(backend.epsilon(), output.dtype)\n        output = tf.clip_by_value(output, epsilon_, 1 - epsilon_)\n        output = tf.math.log(output)\n\n    # Permute output so that the last axis contains the logits/probabilities.\n    if isinstance(output.shape, (tuple, list)):\n        output_rank = len(output.shape)\n    else:\n        output_rank = output.shape.ndims\n    if output_rank is not None:\n        axis %= output_rank\n        if axis != output_rank - 1:\n            permutation = list(\n                itertools.chain(\n                    range(axis), range(axis + 1, output_rank), [axis]\n                )\n            )\n            output = tf.transpose(output, perm=permutation)\n    elif axis != -1:\n        raise ValueError(\n            \"Cannot compute sparse categorical crossentropy with `axis={}` \"\n            \"on an output 
tensor with unknown rank\".format(axis)\n        )\n\n    # Try to adjust the shape so that rank of labels = rank of logits - 1.\n    output_shape = tf.shape(output)\n    target_rank = target.shape.ndims\n\n    update_shape = (\n        target_rank is not None\n        and output_rank is not None\n        and target_rank != output_rank - 1\n    )\n    if update_shape:\n        target = flatten(target)\n        output = tf.reshape(output, [-1, output_shape[-1]])\n\n    if ignore_class is not None:\n        valid_mask = tf.not_equal(target, cast(ignore_class, target.dtype))\n        target = target[valid_mask]\n        output = output[valid_mask]\n\n    res = tf.nn.sparse_softmax_cross_entropy_with_logits(\n        labels=target, logits=output\n    )\n\n    if ignore_class is not None:\n        res_shape = cast(output_shape[:-1], \"int64\")\n        valid_mask = tf.reshape(valid_mask, res_shape)\n        res = tf.scatter_nd(tf.where(valid_mask), res, res_shape)\n        res._keras_mask = valid_mask\n\n        return res\n\n    if update_shape and output_rank >= 3:\n        # If our output includes timesteps or\n        # spatial dimensions we need to reshape\n        res = tf.reshape(res, output_shape[:-1])\n\n    return res\n\n\n@keras_export(\"keras._legacy.backend.spatial_2d_padding\")\ndef spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):\n    \"\"\"DEPRECATED.\"\"\"\n    if len(padding) != 2 or len(padding[0]) != 2 or len(padding[1]) != 2:\n        raise ValueError(\n            \"Expected `padding` to be a tuple of 2 tuples of 2 integers. 
\"\n            f\"Received: padding={padding}\"\n        )\n    if data_format is None:\n        data_format = backend.image_data_format()\n    if data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(f\"Unknown data_format: {data_format}\")\n\n    if data_format == \"channels_first\":\n        pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])]\n    else:\n        pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]]\n    return tf.compat.v1.pad(x, pattern)\n\n\n@keras_export(\"keras._legacy.backend.spatial_3d_padding\")\ndef spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):\n    \"\"\"DEPRECATED.\"\"\"\n    if (\n        len(padding) != 3\n        or len(padding[0]) != 2\n        or len(padding[1]) != 2\n        or len(padding[2]) != 2\n    ):\n        raise ValueError(\n            \"Expected `padding` to be a tuple of 3 tuples of 2 integers. \"\n            f\"Received: padding={padding}\"\n        )\n    if data_format is None:\n        data_format = backend.image_data_format()\n    if data_format not in {\"channels_first\", \"channels_last\"}:\n        raise ValueError(f\"Unknown data_format: {data_format}\")\n\n    if data_format == \"channels_first\":\n        pattern = [\n            [0, 0],\n            [0, 0],\n            [padding[0][0], padding[0][1]],\n            [padding[1][0], padding[1][1]],\n            [padding[2][0], padding[2][1]],\n        ]\n    else:\n        pattern = [\n            [0, 0],\n            [padding[0][0], padding[0][1]],\n            [padding[1][0], padding[1][1]],\n            [padding[2][0], padding[2][1]],\n            [0, 0],\n        ]\n    return tf.compat.v1.pad(x, pattern)\n\n\n@keras_export(\"keras._legacy.backend.sqrt\")\ndef sqrt(x):\n    \"\"\"DEPRECATED.\"\"\"\n    zero = tf.convert_to_tensor(0.0, x.dtype)\n    x = tf.maximum(x, zero)\n    return tf.sqrt(x)\n\n\n@keras_export(\"keras._legacy.backend.square\")\ndef square(x):\n    
\"\"\"DEPRECATED.\"\"\"\n    return tf.square(x)\n\n\n@keras_export(\"keras._legacy.backend.squeeze\")\ndef squeeze(x, axis):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.squeeze(x, [axis])\n\n\n@keras_export(\"keras._legacy.backend.stack\")\ndef stack(x, axis=0):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.stack(x, axis=axis)\n\n\n@keras_export(\"keras._legacy.backend.std\")\ndef std(x, axis=None, keepdims=False):\n    \"\"\"DEPRECATED.\"\"\"\n    if x.dtype.base_dtype == tf.bool:\n        x = tf.cast(x, backend.floatx())\n    return tf.math.reduce_std(x, axis=axis, keepdims=keepdims)\n\n\n@keras_export(\"keras._legacy.backend.stop_gradient\")\ndef stop_gradient(variables):\n    \"\"\"DEPRECATED.\"\"\"\n    if isinstance(variables, (list, tuple)):\n        # Return a list rather than a lazy `map` object so the result can\n        # be indexed and iterated more than once under Python 3.\n        return [tf.stop_gradient(v) for v in variables]\n    return tf.stop_gradient(variables)\n\n\n@keras_export(\"keras._legacy.backend.sum\")\ndef sum(x, axis=None, keepdims=False):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.reduce_sum(x, axis, keepdims)\n\n\n@keras_export(\"keras._legacy.backend.switch\")\ndef switch(condition, then_expression, else_expression):\n    \"\"\"DEPRECATED.\"\"\"\n    if condition.dtype != tf.bool:\n        condition = tf.cast(condition, \"bool\")\n    cond_ndim = ndim(condition)\n    if not cond_ndim:\n        if not callable(then_expression):\n\n            def then_expression_fn():\n                return then_expression\n\n        else:\n            then_expression_fn = then_expression\n        if not callable(else_expression):\n\n            def else_expression_fn():\n                return else_expression\n\n        else:\n            else_expression_fn = else_expression\n        x = tf.compat.v1.cond(condition, then_expression_fn, else_expression_fn)\n    else:\n        # tf.where needs its condition tensor\n        # to be the same shape as its two\n        # result tensors\n        if callable(then_expression):\n            then_expression = then_expression()\n        if 
callable(else_expression):\n            else_expression = else_expression()\n        expr_ndim = ndim(then_expression)\n        if cond_ndim > expr_ndim:\n            raise ValueError(\n                \"Rank of `condition` should be less than or\"\n                \" equal to rank of `then_expression` and \"\n                \"`else_expression`. ndim(condition)=\"\n                f\"{cond_ndim}, ndim(then_expression)={expr_ndim}\"\n            )\n        if cond_ndim > 1:\n            ndim_diff = expr_ndim - cond_ndim\n            cond_shape = tf.concat(\n                [tf.shape(condition), [1] * ndim_diff], axis=0\n            )\n            condition = tf.reshape(condition, cond_shape)\n            expr_shape = tf.shape(then_expression)\n            shape_diff = expr_shape - cond_shape\n            tile_shape = tf.where(\n                shape_diff > 0, expr_shape, tf.ones_like(expr_shape)\n            )\n            condition = tf.tile(condition, tile_shape)\n        x = tf.where(condition, then_expression, else_expression)\n    return x\n\n\n@keras_export(\"keras._legacy.backend.tanh\")\ndef tanh(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.tanh(x)\n\n\n@keras_export(\"keras._legacy.backend.temporal_padding\")\ndef temporal_padding(x, padding=(1, 1)):\n    \"\"\"DEPRECATED.\"\"\"\n    if len(padding) != 2:\n        raise ValueError(\n            \"Expected `padding` to be a tuple of 2 integers. 
\"\n            f\"Received: padding={padding}\"\n        )\n    pattern = [[0, 0], [padding[0], padding[1]], [0, 0]]\n    return tf.compat.v1.pad(x, pattern)\n\n\n@keras_export(\"keras._legacy.backend.tile\")\ndef tile(x, n):\n    \"\"\"DEPRECATED.\"\"\"\n    if isinstance(n, int):\n        n = [n]\n    return tf.tile(x, n)\n\n\n@keras_export(\"keras._legacy.backend.to_dense\")\ndef to_dense(tensor):\n    \"\"\"DEPRECATED.\"\"\"\n    if is_sparse(tensor):\n        return tf.sparse.to_dense(tensor)\n    else:\n        return tensor\n\n\n@keras_export(\"keras._legacy.backend.transpose\")\ndef transpose(x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.transpose(x)\n\n\n@keras_export(\"keras._legacy.backend.truncated_normal\")\ndef truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    \"\"\"DEPRECATED.\"\"\"\n    if dtype is None:\n        dtype = backend.floatx()\n    if seed is None:\n        seed = np.random.randint(10e6)\n    return tf.random.truncated_normal(\n        shape, mean, stddev, dtype=dtype, seed=seed\n    )\n\n\n@keras_export(\"keras._legacy.backend.update\")\ndef update(x, new_x):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.compat.v1.assign(x, new_x)\n\n\n@keras_export(\"keras._legacy.backend.update_add\")\ndef update_add(x, increment):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.compat.v1.assign_add(x, increment)\n\n\n@keras_export(\"keras._legacy.backend.update_sub\")\ndef update_sub(x, decrement):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.compat.v1.assign_sub(x, decrement)\n\n\n@keras_export(\"keras._legacy.backend.var\")\ndef var(x, axis=None, keepdims=False):\n    \"\"\"DEPRECATED.\"\"\"\n    if x.dtype.base_dtype == tf.bool:\n        x = tf.cast(x, backend.floatx())\n    return tf.math.reduce_variance(x, axis=axis, keepdims=keepdims)\n\n\n@keras_export(\"keras._legacy.backend.variable\")\ndef variable(value, dtype=None, name=None, constraint=None):\n    \"\"\"DEPRECATED.\"\"\"\n    if dtype is None:\n        dtype = 
backend.floatx()\n    if hasattr(value, \"tocoo\"):\n        sparse_coo = value.tocoo()\n        indices = np.concatenate(\n            (\n                np.expand_dims(sparse_coo.row, 1),\n                np.expand_dims(sparse_coo.col, 1),\n            ),\n            1,\n        )\n        v = tf.SparseTensor(\n            indices=indices,\n            values=sparse_coo.data,\n            dense_shape=sparse_coo.shape,\n        )\n        v._keras_shape = sparse_coo.shape\n        return v\n    v = tf.Variable(\n        value, dtype=tf.as_dtype(dtype), name=name, constraint=constraint\n    )\n    return v\n\n\n@keras_export(\"keras._legacy.backend.zeros\")\ndef zeros(shape, dtype=None, name=None):\n    \"\"\"DEPRECATED.\"\"\"\n    with tf.init_scope():\n        if dtype is None:\n            dtype = backend.floatx()\n        tf_dtype = tf.as_dtype(dtype)\n        v = tf.zeros(shape=shape, dtype=tf_dtype, name=name)\n        if py_all(v.shape.as_list()):\n            return variable(v, dtype=dtype, name=name)\n        return v\n\n\n@keras_export(\"keras._legacy.backend.zeros_like\")\ndef zeros_like(x, dtype=None, name=None):\n    \"\"\"DEPRECATED.\"\"\"\n    return tf.zeros_like(x, dtype=dtype, name=name)\n"
  },
  {
    "path": "keras/src/legacy/layers.py",
    "content": "\"\"\"Legacy Keras 1/2 layers.\n\nAlphaDropout\nRandomHeight\nRandomWidth\nThresholdedReLU\n\"\"\"\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\n@keras_export(\"keras._legacy.layers.AlphaDropout\")\nclass AlphaDropout(Layer):\n    \"\"\"DEPRECATED.\"\"\"\n\n    def __init__(self, rate, noise_shape=None, seed=None, **kwargs):\n        super().__init__(**kwargs)\n        self.rate = rate\n        self.seed = seed\n        self.noise_shape = noise_shape\n        self.seed_generator = backend.random.SeedGenerator(seed)\n        self.supports_masking = True\n        self.built = True\n\n    def call(self, inputs, training=False):\n        if training and self.rate > 0:\n            alpha = 1.6732632423543772848170429916717\n            scale = 1.0507009873554804934193349852946\n            alpha_p = -alpha * scale\n\n            if self.noise_shape is None:\n                noise_shape = tf.shape(inputs)\n            else:\n                noise_shape = self.noise_shape\n            kept_idx = tf.greater_equal(\n                backend.random.uniform(noise_shape, seed=self.seed_generator),\n                self.rate,\n            )\n            kept_idx = tf.cast(kept_idx, inputs.dtype)\n\n            # Get affine transformation params\n            a = ((1 - self.rate) * (1 + self.rate * alpha_p**2)) ** -0.5\n            b = -a * alpha_p * self.rate\n\n            # Apply mask\n            x = inputs * kept_idx + alpha_p * (1 - kept_idx)\n\n            # Do affine transformation\n            return a * x + b\n        return inputs\n\n    def get_config(self):\n        config = {\"rate\": self.rate, \"seed\": self.seed}\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    def compute_output_shape(self, input_shape):\n        return 
input_shape\n\n\n@keras_export(\"keras._legacy.layers.RandomHeight\")\nclass RandomHeight(Layer):\n    \"\"\"DEPRECATED.\"\"\"\n\n    def __init__(self, factor, interpolation=\"bilinear\", seed=None, **kwargs):\n        super().__init__(**kwargs)\n        self.seed_generator = backend.random.SeedGenerator(seed)\n        self.factor = factor\n        if isinstance(factor, (tuple, list)):\n            self.height_lower = factor[0]\n            self.height_upper = factor[1]\n        else:\n            self.height_lower = -factor\n            self.height_upper = factor\n\n        if self.height_upper < self.height_lower:\n            raise ValueError(\n                \"`factor` argument cannot have an upper bound lesser than the \"\n                f\"lower bound. Received: factor={factor}\"\n            )\n        if self.height_lower < -1.0 or self.height_upper < -1.0:\n            raise ValueError(\n                \"`factor` argument must have values larger than -1. \"\n                f\"Received: factor={factor}\"\n            )\n        self.interpolation = interpolation\n        self.seed = seed\n\n    def call(self, inputs, training=True):\n        inputs = tf.convert_to_tensor(inputs, dtype=self.compute_dtype)\n\n        def random_height_inputs(inputs):\n            \"\"\"Inputs height-adjusted with random ops.\"\"\"\n            inputs_shape = tf.shape(inputs)\n            img_hd = tf.cast(inputs_shape[-3], tf.float32)\n            img_wd = inputs_shape[-2]\n            height_factor = backend.random.uniform(\n                shape=[],\n                minval=(1.0 + self.height_lower),\n                maxval=(1.0 + self.height_upper),\n                seed=self.seed_generator,\n            )\n            adjusted_height = tf.cast(height_factor * img_hd, tf.int32)\n            adjusted_size = tf.stack([adjusted_height, img_wd])\n            output = tf.image.resize(\n                images=inputs,\n                size=adjusted_size,\n                
method=self.interpolation,\n            )\n            # tf.resize will output float32 regardless of input type.\n            output = tf.cast(output, self.compute_dtype)\n            output_shape = inputs.shape.as_list()\n            output_shape[-3] = None\n            output.set_shape(output_shape)\n            return output\n\n        if training:\n            return random_height_inputs(inputs)\n        else:\n            return inputs\n\n    def compute_output_shape(self, input_shape):\n        input_shape = list(input_shape)\n        input_shape[-3] = None\n        return tuple(input_shape)\n\n    def get_config(self):\n        config = {\n            \"factor\": self.factor,\n            \"interpolation\": self.interpolation,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\n@keras_export(\"keras._legacy.layers.RandomWidth\")\nclass RandomWidth(Layer):\n    \"\"\"DEPRECATED.\"\"\"\n\n    def __init__(self, factor, interpolation=\"bilinear\", seed=None, **kwargs):\n        super().__init__(**kwargs)\n        self.seed_generator = backend.random.SeedGenerator(seed)\n        self.factor = factor\n        if isinstance(factor, (tuple, list)):\n            self.width_lower = factor[0]\n            self.width_upper = factor[1]\n        else:\n            self.width_lower = -factor\n            self.width_upper = factor\n        if self.width_upper < self.width_lower:\n            raise ValueError(\n                \"`factor` argument cannot have an upper bound less than the \"\n                f\"lower bound. Received: factor={factor}\"\n            )\n        if self.width_lower < -1.0 or self.width_upper < -1.0:\n            raise ValueError(\n                \"`factor` argument must have values larger than -1. 
\"\n                f\"Received: factor={factor}\"\n            )\n        self.interpolation = interpolation\n        self.seed = seed\n\n    def call(self, inputs, training=True):\n        inputs = tf.convert_to_tensor(inputs, dtype=self.compute_dtype)\n\n        def random_width_inputs(inputs):\n            \"\"\"Inputs width-adjusted with random ops.\"\"\"\n            inputs_shape = tf.shape(inputs)\n            img_hd = inputs_shape[-3]\n            img_wd = tf.cast(inputs_shape[-2], tf.float32)\n            width_factor = backend.random.uniform(\n                shape=[],\n                minval=(1.0 + self.width_lower),\n                maxval=(1.0 + self.width_upper),\n                seed=self.seed_generator,\n            )\n            adjusted_width = tf.cast(width_factor * img_wd, tf.int32)\n            adjusted_size = tf.stack([img_hd, adjusted_width])\n            output = tf.image.resize(\n                images=inputs,\n                size=adjusted_size,\n                method=self.interpolation,\n            )\n            # tf.resize will output float32 regardless of input type.\n            output = tf.cast(output, self.compute_dtype)\n            output_shape = inputs.shape.as_list()\n            output_shape[-2] = None\n            output.set_shape(output_shape)\n            return output\n\n        if training:\n            return random_width_inputs(inputs)\n        else:\n            return inputs\n\n    def compute_output_shape(self, input_shape):\n        input_shape = list(input_shape)\n        input_shape[-2] = None\n        return tuple(input_shape)\n\n    def get_config(self):\n        config = {\n            \"factor\": self.factor,\n            \"interpolation\": self.interpolation,\n            \"seed\": self.seed,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\n@keras_export(\"keras._legacy.layers.ThresholdedReLU\")\nclass ThresholdedReLU(Layer):\n    
\"\"\"DEPRECATED.\"\"\"\n\n    def __init__(self, theta=1.0, **kwargs):\n        super().__init__(**kwargs)\n        if theta is None:\n            raise ValueError(\n                \"Theta of a Thresholded ReLU layer cannot be None, expecting a \"\n                f\"float. Received: {theta}\"\n            )\n        if theta < 0:\n            raise ValueError(\n                \"The theta value of a Thresholded ReLU layer \"\n                f\"should be >=0. Received: {theta}\"\n            )\n        self.supports_masking = True\n        self.theta = tf.convert_to_tensor(theta, dtype=self.compute_dtype)\n\n    def call(self, inputs):\n        dtype = self.compute_dtype\n        return inputs * tf.cast(tf.greater(inputs, self.theta), dtype)\n\n    def get_config(self):\n        config = {\"theta\": float(self.theta)}\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    def compute_output_shape(self, input_shape):\n        return input_shape\n"
  },
  {
    "path": "keras/src/legacy/losses.py",
    "content": "from keras.src.api_export import keras_export\n\n\n@keras_export(\"keras._legacy.losses.Reduction\")\nclass Reduction:\n    AUTO = \"auto\"\n    NONE = \"none\"\n    SUM = \"sum\"\n    SUM_OVER_BATCH_SIZE = \"sum_over_batch_size\"\n\n    @classmethod\n    def all(cls):\n        return (cls.AUTO, cls.NONE, cls.SUM, cls.SUM_OVER_BATCH_SIZE)\n\n    @classmethod\n    def validate(cls, key):\n        if key not in cls.all():\n            raise ValueError(\n                f'Invalid Reduction Key: {key}. Expected keys are \"{cls.all()}\"'\n            )\n"
  },
  {
    "path": "keras/src/legacy/preprocessing/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/legacy/preprocessing/image.py",
    "content": "\"\"\"Deprecated image preprocessing APIs from Keras 1.\"\"\"\n\nimport collections\nimport multiprocessing\nimport os\nimport threading\nimport warnings\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset\nfrom keras.src.utils import image_utils\nfrom keras.src.utils import io_utils\nfrom keras.src.utils.module_utils import scipy\n\n\n@keras_export(\"keras._legacy.preprocessing.image.Iterator\")\nclass Iterator(PyDataset):\n    \"\"\"Base class for image data iterators.\n\n    DEPRECATED.\n\n    Every `Iterator` must implement the `_get_batches_of_transformed_samples`\n    method.\n\n    Args:\n        n: Integer, total number of samples in the dataset to loop over.\n        batch_size: Integer, size of a batch.\n        shuffle: Boolean, whether to shuffle the data between epochs.\n        seed: Random seeding for data shuffling.\n        **kwargs: Additional keyword arguments for the `PyDataset` base class,\n            such as `workers`, `use_multiprocessing`, and `max_queue_size`.\n    \"\"\"\n\n    white_list_formats = (\"png\", \"jpg\", \"jpeg\", \"bmp\", \"ppm\", \"tif\", \"tiff\")\n\n    def __init__(self, n, batch_size, shuffle, seed, **kwargs):\n        super().__init__(**kwargs)\n        self.n = n\n        self.batch_size = batch_size\n        self.seed = seed\n        self.shuffle = shuffle\n        self.batch_index = 0\n        self.total_batches_seen = 0\n        self.lock = threading.Lock()\n        self.index_array = None\n        self.index_generator = self._flow_index()\n\n    def _set_index_array(self):\n        self.index_array = np.arange(self.n)\n        if self.shuffle:\n            self.index_array = np.random.permutation(self.n)\n\n    def __getitem__(self, idx):\n        if idx >= len(self):\n            raise ValueError(\n                \"Asked to retrieve element {idx}, \"\n                \"but 
the Sequence \"\n                \"has length {length}\".format(idx=idx, length=len(self))\n            )\n        if self.seed is not None:\n            np.random.seed(self.seed + self.total_batches_seen)\n        self.total_batches_seen += 1\n        if self.index_array is None:\n            self._set_index_array()\n        index_array = self.index_array[\n            self.batch_size * idx : self.batch_size * (idx + 1)\n        ]\n        return self._get_batches_of_transformed_samples(index_array)\n\n    def __len__(self):\n        return (self.n + self.batch_size - 1) // self.batch_size  # round up\n\n    def on_epoch_end(self):\n        self._set_index_array()\n\n    def reset(self):\n        self.batch_index = 0\n\n    def _flow_index(self):\n        # Ensure self.batch_index is 0.\n        self.reset()\n        while 1:\n            if self.seed is not None:\n                np.random.seed(self.seed + self.total_batches_seen)\n            if self.batch_index == 0:\n                self._set_index_array()\n\n            if self.n == 0:\n                # Avoiding modulo by zero error\n                current_index = 0\n            else:\n                current_index = (self.batch_index * self.batch_size) % self.n\n            if self.n > current_index + self.batch_size:\n                self.batch_index += 1\n            else:\n                self.batch_index = 0\n            self.total_batches_seen += 1\n            yield self.index_array[\n                current_index : current_index + self.batch_size\n            ]\n\n    def __iter__(self):\n        # Needed if we want to do something like:\n        # for x, y in data_gen.flow(...):\n        return self\n\n    def __next__(self):\n        with self.lock:\n            index_array = next(self.index_generator)\n        # The transformation of images is not under thread lock\n        # so it can be done in parallel\n        return self._get_batches_of_transformed_samples(index_array)\n\n    def 
_get_batches_of_transformed_samples(self, index_array):\n        \"\"\"Gets a batch of transformed samples.\n\n        Args:\n            index_array: Array of sample indices to include in batch.\n        Returns:\n            A batch of transformed samples.\n        \"\"\"\n        raise NotImplementedError\n\n\ndef _iter_valid_files(directory, white_list_formats, follow_links):\n    \"\"\"Iterates on files with extension.\n\n    Args:\n        directory: Absolute path to the directory\n            containing files to be counted\n        white_list_formats: Set of strings containing allowed extensions for\n            the files to be counted.\n        follow_links: Boolean, follow symbolic links to subdirectories.\n    Yields:\n        Tuple of (root, filename) with extension in `white_list_formats`.\n    \"\"\"\n\n    def _recursive_list(subpath):\n        return sorted(\n            os.walk(subpath, followlinks=follow_links), key=lambda x: x[0]\n        )\n\n    for root, _, files in _recursive_list(directory):\n        for fname in sorted(files):\n            if fname.lower().endswith(\".tiff\"):\n                warnings.warn(\n                    'Using \".tiff\" files with multiple bands '\n                    \"will cause distortion. Please verify your output.\"\n                )\n            if fname.lower().endswith(white_list_formats):\n                yield root, fname\n\n\ndef _list_valid_filenames_in_directory(\n    directory, white_list_formats, split, class_indices, follow_links\n):\n    \"\"\"Lists paths of files in `subdir` with extensions in `white_list_formats`.\n\n    Args:\n        directory: absolute path to a directory containing the files to list.\n            The directory name is used as class label\n            and must be a key of `class_indices`.\n        white_list_formats: set of strings containing allowed extensions for\n            the files to be counted.\n        split: tuple of floats (e.g. 
`(0.2, 0.6)`) to only take into\n            account a certain fraction of files in each directory.\n            E.g.: `segment=(0.6, 1.0)` would only account for last 40 percent\n            of images in each directory.\n        class_indices: dictionary mapping a class name to its index.\n        follow_links: boolean, follow symbolic links to subdirectories.\n\n    Returns:\n         classes: a list of class indices\n         filenames: the path of valid files in `directory`, relative from\n             `directory`'s parent (e.g., if `directory` is \"dataset/class1\",\n            the filenames will be\n            `[\"class1/file1.jpg\", \"class1/file2.jpg\", ...]`).\n    \"\"\"\n    dirname = os.path.basename(directory)\n    if split:\n        all_files = list(\n            _iter_valid_files(directory, white_list_formats, follow_links)\n        )\n        num_files = len(all_files)\n        start, stop = int(split[0] * num_files), int(split[1] * num_files)\n        valid_files = all_files[start:stop]\n    else:\n        valid_files = _iter_valid_files(\n            directory, white_list_formats, follow_links\n        )\n    classes = []\n    filenames = []\n    for root, fname in valid_files:\n        classes.append(class_indices[dirname])\n        absolute_path = os.path.join(root, fname)\n        relative_path = os.path.join(\n            dirname, os.path.relpath(absolute_path, directory)\n        )\n        filenames.append(relative_path)\n\n    return classes, filenames\n\n\nclass BatchFromFilesMixin:\n    \"\"\"Adds methods related to getting batches from filenames.\n\n    It includes the logic to transform image files to batches.\n    \"\"\"\n\n    def set_processing_attrs(\n        self,\n        image_data_generator,\n        target_size,\n        color_mode,\n        data_format,\n        save_to_dir,\n        save_prefix,\n        save_format,\n        subset,\n        interpolation,\n        keep_aspect_ratio,\n    ):\n        \"\"\"Sets attributes 
to use later for processing files into a batch.\n\n        Args:\n            image_data_generator: Instance of `ImageDataGenerator`\n                to use for random transformations and normalization.\n            target_size: tuple of integers, dimensions to resize input images\n            to.\n            color_mode: One of `\"rgb\"`, `\"rgba\"`, `\"grayscale\"`.\n                Color mode to read images.\n            data_format: String, one of `channels_first`, `channels_last`.\n            save_to_dir: Optional directory where to save the pictures\n                being yielded, in a viewable format. This is useful\n                for visualizing the random transformations being\n                applied, for debugging purposes.\n            save_prefix: String prefix to use for saving sample\n                images (if `save_to_dir` is set).\n            save_format: Format to use for saving sample images\n                (if `save_to_dir` is set).\n            subset: Subset of data (`\"training\"` or `\"validation\"`) if\n                validation_split is set in ImageDataGenerator.\n            interpolation: Interpolation method used to resample the image if\n                the target size is different from that of the loaded image.\n                Supported methods are \"nearest\", \"bilinear\", and \"bicubic\". If\n                PIL version 1.1.3 or newer is installed, \"lanczos\" is also\n                supported. If PIL version 3.4.0 or newer is installed, \"box\" and\n                \"hamming\" are also supported. By default, \"nearest\" is used.\n            keep_aspect_ratio: Boolean, whether to resize images to a target\n                size without aspect ratio distortion. 
The image is cropped in\n                the center with target aspect ratio before resizing.\n        \"\"\"\n        self.image_data_generator = image_data_generator\n        self.target_size = tuple(target_size)\n        self.keep_aspect_ratio = keep_aspect_ratio\n        if color_mode not in {\"rgb\", \"rgba\", \"grayscale\"}:\n            raise ValueError(\n                f\"Invalid color mode: {color_mode}\"\n                '; expected \"rgb\", \"rgba\", or \"grayscale\".'\n            )\n        self.color_mode = color_mode\n        self.data_format = data_format\n        if self.color_mode == \"rgba\":\n            if self.data_format == \"channels_last\":\n                self.image_shape = self.target_size + (4,)\n            else:\n                self.image_shape = (4,) + self.target_size\n        elif self.color_mode == \"rgb\":\n            if self.data_format == \"channels_last\":\n                self.image_shape = self.target_size + (3,)\n            else:\n                self.image_shape = (3,) + self.target_size\n        else:\n            if self.data_format == \"channels_last\":\n                self.image_shape = self.target_size + (1,)\n            else:\n                self.image_shape = (1,) + self.target_size\n        self.save_to_dir = save_to_dir\n        self.save_prefix = save_prefix\n        self.save_format = save_format\n        self.interpolation = interpolation\n        if subset is not None:\n            validation_split = self.image_data_generator._validation_split\n            if subset == \"validation\":\n                split = (0, validation_split)\n            elif subset == \"training\":\n                split = (validation_split, 1)\n            else:\n                raise ValueError(\n                    f\"Invalid subset name: {subset};\"\n                    'expected \"training\" or \"validation\"'\n                )\n        else:\n            split = None\n        self.split = split\n        self.subset = 
subset\n\n    def _get_batches_of_transformed_samples(self, index_array):\n        \"\"\"Gets a batch of transformed samples.\n\n        Args:\n            index_array: Array of sample indices to include in batch.\n        Returns:\n            A batch of transformed samples.\n        \"\"\"\n        batch_x = np.zeros(\n            (len(index_array),) + self.image_shape, dtype=self.dtype\n        )\n        # build batch of image data\n        # self.filepaths is dynamic; it is better to call it once outside\n        # the loop\n        filepaths = self.filepaths\n        for i, j in enumerate(index_array):\n            img = image_utils.load_img(\n                filepaths[j],\n                color_mode=self.color_mode,\n                target_size=self.target_size,\n                interpolation=self.interpolation,\n                keep_aspect_ratio=self.keep_aspect_ratio,\n            )\n            x = image_utils.img_to_array(img, data_format=self.data_format)\n            # Pillow images expose `close()` and should be closed after\n            # `load_img`; legacy PIL images may not, hence the check.\n            if hasattr(img, \"close\"):\n                img.close()\n            if self.image_data_generator:\n                params = self.image_data_generator.get_random_transform(x.shape)\n                x = self.image_data_generator.apply_transform(x, params)\n                x = self.image_data_generator.standardize(x)\n            batch_x[i] = x\n        # optionally save augmented images to disk for debugging purposes\n        if self.save_to_dir:\n            for i, j in enumerate(index_array):\n                img = image_utils.array_to_img(\n                    batch_x[i], self.data_format, scale=True\n                )\n                fname = \"{prefix}_{index}_{hash}.{format}\".format(\n                    prefix=self.save_prefix,\n                    index=j,\n                    hash=np.random.randint(1e7),\n                    format=self.save_format,\n                )\n                
img.save(os.path.join(self.save_to_dir, fname))\n        # build batch of labels\n        if self.class_mode == \"input\":\n            batch_y = batch_x.copy()\n        elif self.class_mode in {\"binary\", \"sparse\"}:\n            batch_y = np.empty(len(batch_x), dtype=self.dtype)\n            for i, n_observation in enumerate(index_array):\n                batch_y[i] = self.classes[n_observation]\n        elif self.class_mode == \"categorical\":\n            batch_y = np.zeros(\n                (len(batch_x), len(self.class_indices)), dtype=self.dtype\n            )\n            for i, n_observation in enumerate(index_array):\n                batch_y[i, self.classes[n_observation]] = 1.0\n        elif self.class_mode == \"multi_output\":\n            batch_y = [output[index_array] for output in self.labels]\n        elif self.class_mode == \"raw\":\n            batch_y = self.labels[index_array]\n        else:\n            return batch_x\n        if self.sample_weight is None:\n            return batch_x, batch_y\n        else:\n            return batch_x, batch_y, self.sample_weight[index_array]\n\n    @property\n    def filepaths(self):\n        \"\"\"List of absolute paths to image files.\"\"\"\n        raise NotImplementedError(\n            \"`filepaths` property method has not \"\n            \"been implemented in {}.\".format(type(self).__name__)\n        )\n\n    @property\n    def labels(self):\n        \"\"\"Class labels of every observation.\"\"\"\n        raise NotImplementedError(\n            \"`labels` property method has not been implemented in {}.\".format(\n                type(self).__name__\n            )\n        )\n\n    @property\n    def sample_weight(self):\n        raise NotImplementedError(\n            \"`sample_weight` property method has not \"\n            \"been implemented in {}.\".format(type(self).__name__)\n        )\n\n\n@keras_export(\"keras._legacy.preprocessing.image.DirectoryIterator\")\nclass 
DirectoryIterator(BatchFromFilesMixin, Iterator):\n    \"\"\"Iterator capable of reading images from a directory on disk.\n\n    DEPRECATED.\n    \"\"\"\n\n    allowed_class_modes = {\"categorical\", \"binary\", \"sparse\", \"input\", None}\n\n    def __init__(\n        self,\n        directory,\n        image_data_generator,\n        target_size=(256, 256),\n        color_mode=\"rgb\",\n        classes=None,\n        class_mode=\"categorical\",\n        batch_size=32,\n        shuffle=True,\n        seed=None,\n        data_format=None,\n        save_to_dir=None,\n        save_prefix=\"\",\n        save_format=\"png\",\n        follow_links=False,\n        subset=None,\n        interpolation=\"nearest\",\n        keep_aspect_ratio=False,\n        dtype=None,\n    ):\n        if data_format is None:\n            data_format = backend.image_data_format()\n        if dtype is None:\n            dtype = backend.floatx()\n        super().set_processing_attrs(\n            image_data_generator,\n            target_size,\n            color_mode,\n            data_format,\n            save_to_dir,\n            save_prefix,\n            save_format,\n            subset,\n            interpolation,\n            keep_aspect_ratio,\n        )\n        self.directory = directory\n        self.classes = classes\n        if class_mode not in self.allowed_class_modes:\n            raise ValueError(\n                \"Invalid class_mode: {}; expected one of: {}\".format(\n                    class_mode, self.allowed_class_modes\n                )\n            )\n        self.class_mode = class_mode\n        self.dtype = dtype\n        # First, count the number of samples and classes.\n        self.samples = 0\n\n        if not classes:\n            classes = []\n            for subdir in sorted(os.listdir(directory)):\n                if os.path.isdir(os.path.join(directory, subdir)):\n                    classes.append(subdir)\n        self.num_classes = len(classes)\n        
self.class_indices = dict(zip(classes, range(len(classes))))\n\n        pool = multiprocessing.pool.ThreadPool()\n\n        # Second, build an index of the images\n        # in the different class subfolders.\n        results = []\n        self.filenames = []\n        i = 0\n        for dirpath in (os.path.join(directory, subdir) for subdir in classes):\n            results.append(\n                pool.apply_async(\n                    _list_valid_filenames_in_directory,\n                    (\n                        dirpath,\n                        self.white_list_formats,\n                        self.split,\n                        self.class_indices,\n                        follow_links,\n                    ),\n                )\n            )\n        classes_list = []\n        for res in results:\n            classes, filenames = res.get()\n            classes_list.append(classes)\n            self.filenames += filenames\n        self.samples = len(self.filenames)\n        self.classes = np.zeros((self.samples,), dtype=\"int32\")\n        for classes in classes_list:\n            self.classes[i : i + len(classes)] = classes\n            i += len(classes)\n\n        io_utils.print_msg(\n            f\"Found {self.samples} images belonging to \"\n            f\"{self.num_classes} classes.\"\n        )\n        pool.close()\n        pool.join()\n        self._filepaths = [\n            os.path.join(self.directory, fname) for fname in self.filenames\n        ]\n        super().__init__(self.samples, batch_size, shuffle, seed)\n\n    @property\n    def filepaths(self):\n        return self._filepaths\n\n    @property\n    def labels(self):\n        return self.classes\n\n    @property  # mixin needs this property to work\n    def sample_weight(self):\n        # no sample weights will be returned\n        return None\n\n\n@keras_export(\"keras._legacy.preprocessing.image.NumpyArrayIterator\")\nclass NumpyArrayIterator(Iterator):\n    \"\"\"Iterator yielding 
data from a Numpy array.\n\n    DEPRECATED.\n    \"\"\"\n\n    def __init__(\n        self,\n        x,\n        y,\n        image_data_generator,\n        batch_size=32,\n        shuffle=False,\n        sample_weight=None,\n        seed=None,\n        data_format=None,\n        save_to_dir=None,\n        save_prefix=\"\",\n        save_format=\"png\",\n        subset=None,\n        ignore_class_split=False,\n        dtype=None,\n    ):\n        if data_format is None:\n            data_format = backend.image_data_format()\n        if dtype is None:\n            dtype = backend.floatx()\n        self.dtype = dtype\n        if isinstance(x, tuple) or isinstance(x, list):\n            if not isinstance(x[1], list):\n                x_misc = [np.asarray(x[1])]\n            else:\n                x_misc = [np.asarray(xx) for xx in x[1]]\n            x = x[0]\n            for xx in x_misc:\n                if len(x) != len(xx):\n                    raise ValueError(\n                        \"All of the arrays in `x` \"\n                        \"should have the same length. \"\n                        \"Found a pair with: \"\n                        f\"len(x[0]) = {len(x)}, len(x[?]) = {len(xx)}\"\n                    )\n        else:\n            x_misc = []\n\n        if y is not None and len(x) != len(y):\n            raise ValueError(\n                \"`x` (images tensor) and `y` (labels) \"\n                \"should have the same length. \"\n                f\"Found: x.shape = {np.asarray(x).shape}, \"\n                f\"y.shape = {np.asarray(y).shape}\"\n            )\n        if sample_weight is not None and len(x) != len(sample_weight):\n            raise ValueError(\n                \"`x` (images tensor) and `sample_weight` \"\n                \"should have the same length. 
\"\n                f\"Found: x.shape = {np.asarray(x).shape}, \"\n                f\"sample_weight.shape = {np.asarray(sample_weight).shape}\"\n            )\n        if subset is not None:\n            if subset not in {\"training\", \"validation\"}:\n                raise ValueError(\n                    f\"Invalid subset name: {subset}\"\n                    '; expected \"training\" or \"validation\".'\n                )\n            split_idx = int(len(x) * image_data_generator._validation_split)\n\n            if (\n                y is not None\n                and not ignore_class_split\n                and not np.array_equal(\n                    np.unique(y[:split_idx]), np.unique(y[split_idx:])\n                )\n            ):\n                raise ValueError(\n                    \"Training and validation subsets \"\n                    \"have different number of classes after \"\n                    \"the split. If your numpy arrays are \"\n                    \"sorted by the label, you might want \"\n                    \"to shuffle them.\"\n                )\n\n            if subset == \"validation\":\n                x = x[:split_idx]\n                x_misc = [np.asarray(xx[:split_idx]) for xx in x_misc]\n                if y is not None:\n                    y = y[:split_idx]\n            else:\n                x = x[split_idx:]\n                x_misc = [np.asarray(xx[split_idx:]) for xx in x_misc]\n                if y is not None:\n                    y = y[split_idx:]\n\n        self.x = np.asarray(x, dtype=self.dtype)\n        self.x_misc = x_misc\n        if self.x.ndim != 4:\n            raise ValueError(\n                \"Input data in `NumpyArrayIterator` \"\n                \"should have rank 4. 
You passed an array \"\n                f\"with shape {self.x.shape}\"\n            )\n        channels_axis = 3 if data_format == \"channels_last\" else 1\n        if self.x.shape[channels_axis] not in {1, 3, 4}:\n            warnings.warn(\n                f\"NumpyArrayIterator is set to use the data format convention\"\n                f' \"{data_format}\" (channels on axis {channels_axis})'\n                \", i.e. expected either 1, 3, or 4 channels \"\n                f\"on axis {channels_axis}. \"\n                f\"However, it was passed an array with shape {self.x.shape}\"\n                f\" ({self.x.shape[channels_axis]} channels).\"\n            )\n        if y is not None:\n            self.y = np.asarray(y)\n        else:\n            self.y = None\n        if sample_weight is not None:\n            self.sample_weight = np.asarray(sample_weight)\n        else:\n            self.sample_weight = None\n        self.image_data_generator = image_data_generator\n        self.data_format = data_format\n        self.save_to_dir = save_to_dir\n        self.save_prefix = save_prefix\n        self.save_format = save_format\n        super().__init__(x.shape[0], batch_size, shuffle, seed)\n\n    def _get_batches_of_transformed_samples(self, index_array):\n        batch_x = np.zeros(\n            tuple([len(index_array)] + list(self.x.shape)[1:]), dtype=self.dtype\n        )\n        for i, j in enumerate(index_array):\n            x = self.x[j]\n            params = self.image_data_generator.get_random_transform(x.shape)\n            x = self.image_data_generator.apply_transform(\n                x.astype(self.dtype), params\n            )\n            x = self.image_data_generator.standardize(x)\n            batch_x[i] = x\n\n        if self.save_to_dir:\n            for i, j in enumerate(index_array):\n                img = image_utils.array_to_img(\n                    batch_x[i], self.data_format, scale=True\n                )\n                fname = 
\"{prefix}_{index}_{hash}.{format}\".format(\n                    prefix=self.save_prefix,\n                    index=j,\n                    hash=np.random.randint(1e4),\n                    format=self.save_format,\n                )\n                img.save(os.path.join(self.save_to_dir, fname))\n        batch_x_miscs = [xx[index_array] for xx in self.x_misc]\n        output = (batch_x if not batch_x_miscs else [batch_x] + batch_x_miscs,)\n        if self.y is None:\n            return output[0]\n        output += (self.y[index_array],)\n        if self.sample_weight is not None:\n            output += (self.sample_weight[index_array],)\n        return output\n\n\ndef validate_filename(filename, white_list_formats):\n    \"\"\"Check if a filename refers to a valid file.\n\n    Args:\n        filename: String, absolute path to a file\n        white_list_formats: Set, allowed file extensions\n    Returns:\n        A boolean value indicating if the filename is valid or not\n    \"\"\"\n    return filename.lower().endswith(white_list_formats) and os.path.isfile(\n        filename\n    )\n\n\nclass DataFrameIterator(BatchFromFilesMixin, Iterator):\n    \"\"\"Iterator capable of reading images from a directory as a dataframe.\"\"\"\n\n    allowed_class_modes = {\n        \"binary\",\n        \"categorical\",\n        \"input\",\n        \"multi_output\",\n        \"raw\",\n        \"sparse\",\n        None,\n    }\n\n    def __init__(\n        self,\n        dataframe,\n        directory=None,\n        image_data_generator=None,\n        x_col=\"filename\",\n        y_col=\"class\",\n        weight_col=None,\n        target_size=(256, 256),\n        color_mode=\"rgb\",\n        classes=None,\n        class_mode=\"categorical\",\n        batch_size=32,\n        shuffle=True,\n        seed=None,\n        data_format=\"channels_last\",\n        save_to_dir=None,\n        save_prefix=\"\",\n        save_format=\"png\",\n        subset=None,\n        
interpolation=\"nearest\",\n        keep_aspect_ratio=False,\n        dtype=\"float32\",\n        validate_filenames=True,\n    ):\n        super().set_processing_attrs(\n            image_data_generator,\n            target_size,\n            color_mode,\n            data_format,\n            save_to_dir,\n            save_prefix,\n            save_format,\n            subset,\n            interpolation,\n            keep_aspect_ratio,\n        )\n        df = dataframe.copy()\n        self.directory = directory or \"\"\n        self.class_mode = class_mode\n        self.dtype = dtype\n        # check that inputs match the required class_mode\n        self._check_params(df, x_col, y_col, weight_col, classes)\n        if (\n            validate_filenames\n        ):  # check which image files are valid and keep them\n            df = self._filter_valid_filepaths(df, x_col)\n        if class_mode not in [\"input\", \"multi_output\", \"raw\", None]:\n            df, classes = self._filter_classes(df, y_col, classes)\n            num_classes = len(classes)\n            # build an index of all the unique classes\n            self.class_indices = dict(zip(classes, range(len(classes))))\n        # retrieve only training or validation set\n        if self.split:\n            num_files = len(df)\n            start = int(self.split[0] * num_files)\n            stop = int(self.split[1] * num_files)\n            df = df.iloc[start:stop, :]\n        # get labels for each observation\n        if class_mode not in [\"input\", \"multi_output\", \"raw\", None]:\n            self.classes = self.get_classes(df, y_col)\n        self.filenames = df[x_col].tolist()\n        self._sample_weight = df[weight_col].values if weight_col else None\n\n        if class_mode == \"multi_output\":\n            self._targets = [np.array(df[col].tolist()) for col in y_col]\n        if class_mode == \"raw\":\n            self._targets = df[y_col].values\n        self.samples = len(self.filenames)\n   
     validated_string = (\n            \"validated\" if validate_filenames else \"non-validated\"\n        )\n        if class_mode in [\"input\", \"multi_output\", \"raw\", None]:\n            io_utils.print_msg(\n                f\"Found {self.samples} {validated_string} image filenames.\"\n            )\n        else:\n            io_utils.print_msg(\n                f\"Found {self.samples} {validated_string} image filenames \"\n                f\"belonging to {num_classes} classes.\"\n            )\n        self._filepaths = [\n            os.path.join(self.directory, fname) for fname in self.filenames\n        ]\n        super().__init__(self.samples, batch_size, shuffle, seed)\n\n    def _check_params(self, df, x_col, y_col, weight_col, classes):\n        # check class mode is one of the currently supported\n        if self.class_mode not in self.allowed_class_modes:\n            raise ValueError(\n                \"Invalid class_mode: {}; expected one of: {}\".format(\n                    self.class_mode, self.allowed_class_modes\n                )\n            )\n        # check that y_col has several column names if class_mode is\n        # multi_output\n        if (self.class_mode == \"multi_output\") and not isinstance(y_col, list):\n            raise TypeError(\n                'If class_mode=\"{}\", y_col must be a list. 
Received {}.'.format(\n                    self.class_mode, type(y_col).__name__\n                )\n            )\n        # check that filenames/filepaths column values are all strings\n        if not all(df[x_col].apply(lambda x: isinstance(x, str))):\n            raise TypeError(\n                f\"All values in column x_col={x_col} must be strings.\"\n            )\n        # check labels are string if class_mode is binary or sparse\n        if self.class_mode in {\"binary\", \"sparse\"}:\n            if not all(df[y_col].apply(lambda x: isinstance(x, str))):\n                raise TypeError(\n                    'If class_mode=\"{}\", y_col=\"{}\" column '\n                    \"values must be strings.\".format(self.class_mode, y_col)\n                )\n        # check that if binary there are only 2 different classes\n        if self.class_mode == \"binary\":\n            if classes:\n                classes = set(classes)\n                if len(classes) != 2:\n                    raise ValueError(\n                        'If class_mode=\"binary\" there must be 2 '\n                        \"classes. {} class/es were given.\".format(len(classes))\n                    )\n            elif df[y_col].nunique() != 2:\n                raise ValueError(\n                    'If class_mode=\"binary\" there must be 2 classes. 
'\n                    \"Found {} classes.\".format(df[y_col].nunique())\n                )\n        # check values are string, list or tuple if class_mode is categorical\n        if self.class_mode == \"categorical\":\n            types = (str, list, tuple)\n            if not all(df[y_col].apply(lambda x: isinstance(x, types))):\n                raise TypeError(\n                    'If class_mode=\"{}\", y_col=\"{}\" column '\n                    \"values must be type string, list or tuple.\".format(\n                        self.class_mode, y_col\n                    )\n                )\n        # raise warning if classes are given but will be unused\n        if classes and self.class_mode in {\n            \"input\",\n            \"multi_output\",\n            \"raw\",\n            None,\n        }:\n            warnings.warn(\n                '`classes` will be ignored given the class_mode=\"{}\"'.format(\n                    self.class_mode\n                )\n            )\n        # check that if weight column that the values are numerical\n        if weight_col and not issubclass(df[weight_col].dtype.type, np.number):\n            raise TypeError(f\"Column weight_col={weight_col} must be numeric.\")\n\n    def get_classes(self, df, y_col):\n        labels = []\n        for label in df[y_col]:\n            if isinstance(label, (list, tuple)):\n                labels.append([self.class_indices[lbl] for lbl in label])\n            else:\n                labels.append(self.class_indices[label])\n        return labels\n\n    @staticmethod\n    def _filter_classes(df, y_col, classes):\n        df = df.copy()\n\n        def remove_classes(labels, classes):\n            if isinstance(labels, (list, tuple)):\n                labels = [cls for cls in labels if cls in classes]\n                return labels or None\n            elif isinstance(labels, str):\n                return labels if labels in classes else None\n            else:\n                raise 
TypeError(\n                    \"Expected a string, list or tuple \"\n                    \"but found {} in the {} column.\".format(type(labels), y_col)\n                )\n\n        if classes:\n            # prepare for membership lookup\n            classes = list(collections.OrderedDict.fromkeys(classes).keys())\n            df[y_col] = df[y_col].apply(lambda x: remove_classes(x, classes))\n        else:\n            classes = set()\n            for v in df[y_col]:\n                if isinstance(v, (list, tuple)):\n                    classes.update(v)\n                else:\n                    classes.add(v)\n            classes = sorted(classes)\n        return df.dropna(subset=[y_col]), classes\n\n    def _filter_valid_filepaths(self, df, x_col):\n        \"\"\"Keep only dataframe rows with valid filenames.\n\n        Args:\n            df: Pandas dataframe containing filenames in a column\n            x_col: string, column in `df` that contains the filenames or\n                filepaths\n        Returns:\n            A filtered dataframe containing only the rows with valid\n            filenames.\n        \"\"\"\n        filepaths = df[x_col].map(\n            lambda fname: os.path.join(self.directory, fname)\n        )\n        mask = filepaths.apply(\n            validate_filename, args=(self.white_list_formats,)\n        )\n        n_invalid = (~mask).sum()\n        if n_invalid:\n            warnings.warn(\n                'Found {} invalid image filename(s) in x_col=\"{}\". 
'\n                \"These filename(s) will be ignored.\".format(n_invalid, x_col)\n            )\n        return df[mask]\n\n    @property\n    def filepaths(self):\n        return self._filepaths\n\n    @property\n    def labels(self):\n        if self.class_mode in {\"multi_output\", \"raw\"}:\n            return self._targets\n        else:\n            return self.classes\n\n    @property\n    def sample_weight(self):\n        return self._sample_weight\n\n\ndef flip_axis(x, axis):\n    x = np.asarray(x).swapaxes(axis, 0)\n    x = x[::-1, ...]\n    x = x.swapaxes(0, axis)\n    return x\n\n\n@keras_export(\"keras._legacy.preprocessing.image.ImageDataGenerator\")\nclass ImageDataGenerator:\n    \"\"\"DEPRECATED.\"\"\"\n\n    def __init__(\n        self,\n        featurewise_center=False,\n        samplewise_center=False,\n        featurewise_std_normalization=False,\n        samplewise_std_normalization=False,\n        zca_whitening=False,\n        zca_epsilon=1e-6,\n        rotation_range=0,\n        width_shift_range=0.0,\n        height_shift_range=0.0,\n        brightness_range=None,\n        shear_range=0.0,\n        zoom_range=0.0,\n        channel_shift_range=0.0,\n        fill_mode=\"nearest\",\n        cval=0.0,\n        horizontal_flip=False,\n        vertical_flip=False,\n        rescale=None,\n        preprocessing_function=None,\n        data_format=None,\n        validation_split=0.0,\n        interpolation_order=1,\n        dtype=None,\n    ):\n        if data_format is None:\n            data_format = backend.image_data_format()\n        if dtype is None:\n            dtype = backend.floatx()\n\n        self.featurewise_center = featurewise_center\n        self.samplewise_center = samplewise_center\n        self.featurewise_std_normalization = featurewise_std_normalization\n        self.samplewise_std_normalization = samplewise_std_normalization\n        self.zca_whitening = zca_whitening\n        self.zca_epsilon = zca_epsilon\n        
self.rotation_range = rotation_range\n        self.width_shift_range = width_shift_range\n        self.height_shift_range = height_shift_range\n        self.shear_range = shear_range\n        self.zoom_range = zoom_range\n        self.channel_shift_range = channel_shift_range\n        self.fill_mode = fill_mode\n        self.cval = cval\n        self.horizontal_flip = horizontal_flip\n        self.vertical_flip = vertical_flip\n        self.rescale = rescale\n        self.preprocessing_function = preprocessing_function\n        self.dtype = dtype\n        self.interpolation_order = interpolation_order\n\n        if data_format not in {\"channels_last\", \"channels_first\"}:\n            raise ValueError(\n                '`data_format` should be `\"channels_last\"` '\n                \"(channel after row and column) or \"\n                '`\"channels_first\"` (channel before row and column). '\n                f\"Received: {data_format}\"\n            )\n        self.data_format = data_format\n        if data_format == \"channels_first\":\n            self.channel_axis = 1\n            self.row_axis = 2\n            self.col_axis = 3\n        if data_format == \"channels_last\":\n            self.channel_axis = 3\n            self.row_axis = 1\n            self.col_axis = 2\n        if validation_split and not 0 < validation_split < 1:\n            raise ValueError(\n                \"`validation_split` must be strictly between 0 and 1. \"\n                f\"Received: {validation_split}\"\n            )\n        self._validation_split = validation_split\n\n        self.mean = None\n        self.std = None\n        self.zca_whitening_matrix = None\n\n        # A scalar zoom_range z is expanded below to the range\n        # [1 - z, 1 + z]; e.g. zoom_range=0.2 yields [0.8, 1.2].\n        if isinstance(zoom_range, (float, int)):\n            self.zoom_range = [1 - zoom_range, 1 + zoom_range]\n        elif len(zoom_range) == 2 and all(\n            isinstance(val, (float, int)) for val in zoom_range\n        ):\n            self.zoom_range = [zoom_range[0], zoom_range[1]]\n        else:\n            raise ValueError(\n                \"`zoom_range` should be a float or \"\n                \"a tuple or list of two floats. \"\n                f\"Received: {zoom_range}\"\n            )\n        if zca_whitening:\n            if not featurewise_center:\n                self.featurewise_center = True\n                warnings.warn(\n                    \"This ImageDataGenerator specifies \"\n                    \"`zca_whitening`, \"\n                    \"which overrides setting of \"\n                    \"`featurewise_std_normalization`.\"\n                )\n        if featurewise_std_normalization:\n            if not featurewise_center:\n                self.featurewise_center = True\n                warnings.warn(\n                    \"This ImageDataGenerator specifies \"\n                    \"`featurewise_std_normalization`, \"\n                    \"which overrides setting of \"\n                    \"`featurewise_center`.\"\n                )\n        if samplewise_std_normalization:\n            if not samplewise_center:\n                self.samplewise_center = True\n                
warnings.warn(\n                    \"This ImageDataGenerator specifies \"\n                    \"`samplewise_std_normalization`, \"\n                    \"which overrides setting of \"\n                    \"`samplewise_center`.\"\n                )\n        if brightness_range is not None:\n            if (\n                not isinstance(brightness_range, (tuple, list))\n                or len(brightness_range) != 2\n            ):\n                raise ValueError(\n                    \"`brightness_range` should be a tuple or list \"\n                    \"of two floats. \"\n                    f\"Received: {brightness_range}\"\n                )\n        self.brightness_range = brightness_range\n\n    def flow(\n        self,\n        x,\n        y=None,\n        batch_size=32,\n        shuffle=True,\n        sample_weight=None,\n        seed=None,\n        save_to_dir=None,\n        save_prefix=\"\",\n        save_format=\"png\",\n        ignore_class_split=False,\n        subset=None,\n    ):\n        return NumpyArrayIterator(\n            x,\n            y,\n            self,\n            batch_size=batch_size,\n            shuffle=shuffle,\n            sample_weight=sample_weight,\n            seed=seed,\n            data_format=self.data_format,\n            save_to_dir=save_to_dir,\n            save_prefix=save_prefix,\n            save_format=save_format,\n            ignore_class_split=ignore_class_split,\n            subset=subset,\n            dtype=self.dtype,\n        )\n\n    def flow_from_directory(\n        self,\n        directory,\n        target_size=(256, 256),\n        color_mode=\"rgb\",\n        classes=None,\n        class_mode=\"categorical\",\n        batch_size=32,\n        shuffle=True,\n        seed=None,\n        save_to_dir=None,\n        save_prefix=\"\",\n        save_format=\"png\",\n        follow_links=False,\n        subset=None,\n        interpolation=\"nearest\",\n        keep_aspect_ratio=False,\n    ):\n        return DirectoryIterator(\n          
  directory,\n            self,\n            target_size=target_size,\n            color_mode=color_mode,\n            keep_aspect_ratio=keep_aspect_ratio,\n            classes=classes,\n            class_mode=class_mode,\n            data_format=self.data_format,\n            batch_size=batch_size,\n            shuffle=shuffle,\n            seed=seed,\n            save_to_dir=save_to_dir,\n            save_prefix=save_prefix,\n            save_format=save_format,\n            follow_links=follow_links,\n            subset=subset,\n            interpolation=interpolation,\n            dtype=self.dtype,\n        )\n\n    def flow_from_dataframe(\n        self,\n        dataframe,\n        directory=None,\n        x_col=\"filename\",\n        y_col=\"class\",\n        weight_col=None,\n        target_size=(256, 256),\n        color_mode=\"rgb\",\n        classes=None,\n        class_mode=\"categorical\",\n        batch_size=32,\n        shuffle=True,\n        seed=None,\n        save_to_dir=None,\n        save_prefix=\"\",\n        save_format=\"png\",\n        subset=None,\n        interpolation=\"nearest\",\n        validate_filenames=True,\n        **kwargs,\n    ):\n        if \"has_ext\" in kwargs:\n            warnings.warn(\n                \"has_ext is deprecated, filenames in the dataframe have \"\n                \"to match the exact filenames on disk.\",\n                DeprecationWarning,\n            )\n        if \"sort\" in kwargs:\n            warnings.warn(\n                \"sort is deprecated, batches will be created in the \"\n                \"same order as the filenames provided if `shuffle` \"\n                \"is set to `False`.\",\n                DeprecationWarning,\n            )\n        if class_mode == \"other\":\n            warnings.warn(\n                '`class_mode=\"other\"` is deprecated, please use '\n                '`class_mode=\"raw\"`.',\n                DeprecationWarning,\n            )\n            class_mode = \"raw\"\n 
       if \"drop_duplicates\" in kwargs:\n            warnings.warn(\n                \"drop_duplicates is deprecated, you can drop duplicates \"\n                \"by using the pandas.DataFrame.drop_duplicates method.\",\n                DeprecationWarning,\n            )\n\n        return DataFrameIterator(\n            dataframe,\n            directory,\n            self,\n            x_col=x_col,\n            y_col=y_col,\n            weight_col=weight_col,\n            target_size=target_size,\n            color_mode=color_mode,\n            classes=classes,\n            class_mode=class_mode,\n            data_format=self.data_format,\n            batch_size=batch_size,\n            shuffle=shuffle,\n            seed=seed,\n            save_to_dir=save_to_dir,\n            save_prefix=save_prefix,\n            save_format=save_format,\n            subset=subset,\n            interpolation=interpolation,\n            validate_filenames=validate_filenames,\n            dtype=self.dtype,\n        )\n\n    def standardize(self, x):\n        \"\"\"Applies the normalization configuration in-place to a batch of\n        inputs.\n\n        `x` is changed in-place since the function is mainly used internally\n        to standardize images and feed them to your network. 
Creating a copy of `x`\n        instead would incur a significant performance cost.\n        To apply this method without changing the input in place,\n        call it on a copy beforehand:\n\n        standardize(np.copy(x))\n\n        Args:\n            x: Batch of inputs to be normalized.\n\n        Returns:\n            The inputs, normalized.\n        \"\"\"\n        if self.preprocessing_function:\n            x = self.preprocessing_function(x)\n        if self.rescale:\n            x *= self.rescale\n        if self.samplewise_center:\n            x -= np.mean(x, keepdims=True)\n        if self.samplewise_std_normalization:\n            x /= np.std(x, keepdims=True) + 1e-6\n\n        if self.featurewise_center:\n            if self.mean is not None:\n                x -= self.mean\n            else:\n                warnings.warn(\n                    \"This ImageDataGenerator specifies \"\n                    \"`featurewise_center`, but it hasn't \"\n                    \"been fit on any training data. Fit it \"\n                    \"first by calling `.fit(numpy_data)`.\"\n                )\n        if self.featurewise_std_normalization:\n            if self.std is not None:\n                x /= self.std + 1e-6\n            else:\n                warnings.warn(\n                    \"This ImageDataGenerator specifies \"\n                    \"`featurewise_std_normalization`, \"\n                    \"but it hasn't \"\n                    \"been fit on any training data. 
Fit it \"\n                    \"first by calling `.fit(numpy_data)`.\"\n                )\n        if self.zca_whitening:\n            if self.zca_whitening_matrix is not None:\n                flat_x = x.reshape(-1, np.prod(x.shape[-3:]))\n                white_x = flat_x @ self.zca_whitening_matrix\n                x = np.reshape(white_x, x.shape)\n            else:\n                warnings.warn(\n                    \"This ImageDataGenerator specifies \"\n                    \"`zca_whitening`, but it hasn't \"\n                    \"been fit on any training data. Fit it \"\n                    \"first by calling `.fit(numpy_data)`.\"\n                )\n        return x\n\n    def get_random_transform(self, img_shape, seed=None):\n        \"\"\"Generates random parameters for a transformation.\n\n        Args:\n            img_shape: Tuple of integers.\n                Shape of the image that is transformed.\n            seed: Random seed.\n\n        Returns:\n            A dictionary containing randomly chosen parameters describing the\n            transformation.\n        \"\"\"\n        img_row_axis = self.row_axis - 1\n        img_col_axis = self.col_axis - 1\n\n        if seed is not None:\n            np.random.seed(seed)\n\n        if self.rotation_range:\n            theta = np.random.uniform(-self.rotation_range, self.rotation_range)\n        else:\n            theta = 0\n\n        if self.height_shift_range:\n            try:  # 1-D array-like or int\n                tx = np.random.choice(self.height_shift_range)\n                tx *= np.random.choice([-1, 1])\n            except ValueError:  # floating point\n                tx = np.random.uniform(\n                    -self.height_shift_range, self.height_shift_range\n                )\n            if np.max(self.height_shift_range) < 1:\n                tx *= img_shape[img_row_axis]\n        else:\n            tx = 0\n\n        if self.width_shift_range:\n            try:  # 1-D array-like or 
int\n                ty = np.random.choice(self.width_shift_range)\n                ty *= np.random.choice([-1, 1])\n            except ValueError:  # floating point\n                ty = np.random.uniform(\n                    -self.width_shift_range, self.width_shift_range\n                )\n            if np.max(self.width_shift_range) < 1:\n                ty *= img_shape[img_col_axis]\n        else:\n            ty = 0\n\n        if self.shear_range:\n            shear = np.random.uniform(-self.shear_range, self.shear_range)\n        else:\n            shear = 0\n\n        if self.zoom_range[0] == 1 and self.zoom_range[1] == 1:\n            zx, zy = 1, 1\n        else:\n            zx, zy = np.random.uniform(\n                self.zoom_range[0], self.zoom_range[1], 2\n            )\n\n        flip_horizontal = (np.random.random() < 0.5) * self.horizontal_flip\n        flip_vertical = (np.random.random() < 0.5) * self.vertical_flip\n\n        channel_shift_intensity = None\n        if self.channel_shift_range != 0:\n            channel_shift_intensity = np.random.uniform(\n                -self.channel_shift_range, self.channel_shift_range\n            )\n\n        brightness = None\n        if self.brightness_range is not None:\n            brightness = np.random.uniform(\n                self.brightness_range[0], self.brightness_range[1]\n            )\n\n        transform_parameters = {\n            \"theta\": theta,\n            \"tx\": tx,\n            \"ty\": ty,\n            \"shear\": shear,\n            \"zx\": zx,\n            \"zy\": zy,\n            \"flip_horizontal\": flip_horizontal,\n            \"flip_vertical\": flip_vertical,\n            \"channel_shift_intensity\": channel_shift_intensity,\n            \"brightness\": brightness,\n        }\n\n        return transform_parameters\n\n    def apply_transform(self, x, transform_parameters):\n        \"\"\"Applies a transformation to an image according to given parameters.\n\n        Args:\n    
        x: 3D tensor, single image.\n            transform_parameters: Dictionary with string - parameter pairs\n                describing the transformation.\n                Currently, the following parameters\n                from the dictionary are used:\n                - `'theta'`: Float. Rotation angle in degrees.\n                - `'tx'`: Float. Shift in the x direction.\n                - `'ty'`: Float. Shift in the y direction.\n                - `'shear'`: Float. Shear angle in degrees.\n                - `'zx'`: Float. Zoom in the x direction.\n                - `'zy'`: Float. Zoom in the y direction.\n                - `'flip_horizontal'`: Boolean. Horizontal flip.\n                - `'flip_vertical'`: Boolean. Vertical flip.\n                - `'channel_shift_intensity'`: Float. Channel shift intensity.\n                - `'brightness'`: Float. Brightness shift intensity.\n\n        Returns:\n            A transformed version of the input (same shape).\n        \"\"\"\n        # x is a single image, so it doesn't have image number at index 0\n        img_row_axis = self.row_axis - 1\n        img_col_axis = self.col_axis - 1\n        img_channel_axis = self.channel_axis - 1\n\n        x = apply_affine_transform(\n            x,\n            transform_parameters.get(\"theta\", 0),\n            transform_parameters.get(\"tx\", 0),\n            transform_parameters.get(\"ty\", 0),\n            transform_parameters.get(\"shear\", 0),\n            transform_parameters.get(\"zx\", 1),\n            transform_parameters.get(\"zy\", 1),\n            row_axis=img_row_axis,\n            col_axis=img_col_axis,\n            channel_axis=img_channel_axis,\n            fill_mode=self.fill_mode,\n            cval=self.cval,\n            order=self.interpolation_order,\n        )\n\n        if transform_parameters.get(\"channel_shift_intensity\") is not None:\n            x = apply_channel_shift(\n                x,\n                
transform_parameters[\"channel_shift_intensity\"],\n                img_channel_axis,\n            )\n\n        if transform_parameters.get(\"flip_horizontal\", False):\n            x = flip_axis(x, img_col_axis)\n\n        if transform_parameters.get(\"flip_vertical\", False):\n            x = flip_axis(x, img_row_axis)\n\n        if transform_parameters.get(\"brightness\") is not None:\n            x = apply_brightness_shift(\n                x, transform_parameters[\"brightness\"], False\n            )\n\n        return x\n\n    def random_transform(self, x, seed=None):\n        \"\"\"Applies a random transformation to an image.\n\n        Args:\n            x: 3D tensor, single image.\n            seed: Random seed.\n\n        Returns:\n            A randomly transformed version of the input (same shape).\n        \"\"\"\n        params = self.get_random_transform(x.shape, seed)\n        return self.apply_transform(x, params)\n\n    def fit(self, x, augment=False, rounds=1, seed=None):\n        \"\"\"Fits the data generator to some sample data.\n\n        This computes the internal data stats related to the\n        data-dependent transformations, based on an array of sample data.\n\n        Only required if `featurewise_center` or\n        `featurewise_std_normalization` or `zca_whitening`\n        are set to `True`.\n\n        When `rescale` is set to a value, rescaling is applied to\n        sample data before computing the internal data stats.\n\n        Args:\n            x: Sample data. 
Should have rank 4.\n             In case of grayscale data,\n             the channels axis should have value 1, in case\n             of RGB data, it should have value 3, and in case\n             of RGBA data, it should have value 4.\n            augment: Boolean (default: False).\n                Whether to fit on randomly augmented samples.\n            rounds: Int (default: 1).\n                If using data augmentation (`augment=True`),\n                this is how many augmentation passes over the data to use.\n            seed: Int (default: None). Random seed.\n        \"\"\"\n        x = np.asarray(x, dtype=self.dtype)\n        if x.ndim != 4:\n            raise ValueError(\n                \"Input to `.fit()` should have rank 4. Got array with shape: \"\n                + str(x.shape)\n            )\n        if x.shape[self.channel_axis] not in {1, 3, 4}:\n            warnings.warn(\n                \"Expected input to be images (as Numpy array) \"\n                f'following the data format convention \"{self.data_format}'\n                f'\" (channels on axis {self.channel_axis})'\n                \", i.e. expected either 1, 3 or 4 channels on axis \"\n                f\"{self.channel_axis}. 
However, it was passed an array with\"\n                f\" shape {x.shape} ({x.shape[self.channel_axis]} channels).\"\n            )\n\n        if seed is not None:\n            np.random.seed(seed)\n\n        x = np.copy(x)\n        if self.rescale:\n            x *= self.rescale\n\n        if augment:\n            ax = np.zeros(\n                tuple([rounds * x.shape[0]] + list(x.shape)[1:]),\n                dtype=self.dtype,\n            )\n            for r in range(rounds):\n                for i in range(x.shape[0]):\n                    ax[i + r * x.shape[0]] = self.random_transform(x[i])\n            x = ax\n\n        if self.featurewise_center:\n            self.mean = np.mean(x, axis=(0, self.row_axis, self.col_axis))\n            broadcast_shape = [1, 1, 1]\n            broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]\n            self.mean = np.reshape(self.mean, broadcast_shape)\n            x -= self.mean\n\n        if self.featurewise_std_normalization:\n            self.std = np.std(x, axis=(0, self.row_axis, self.col_axis))\n            broadcast_shape = [1, 1, 1]\n            broadcast_shape[self.channel_axis - 1] = x.shape[self.channel_axis]\n            self.std = np.reshape(self.std, broadcast_shape)\n            x /= self.std + 1e-6\n\n        if self.zca_whitening:\n            n = len(x)\n            flat_x = np.reshape(x, (n, -1))\n\n            u, s, _ = np.linalg.svd(flat_x.T, full_matrices=False)\n            s_inv = np.sqrt(n) / (s + self.zca_epsilon)\n            self.zca_whitening_matrix = (u * s_inv).dot(u.T)\n\n\n@keras_export(\"keras._legacy.preprocessing.image.random_rotation\")\ndef random_rotation(\n    x,\n    rg,\n    row_axis=1,\n    col_axis=2,\n    channel_axis=0,\n    fill_mode=\"nearest\",\n    cval=0.0,\n    interpolation_order=1,\n):\n    \"\"\"DEPRECATED.\"\"\"\n    theta = np.random.uniform(-rg, rg)\n    x = apply_affine_transform(\n        x,\n        theta=theta,\n        row_axis=row_axis,\n  
      col_axis=col_axis,\n        channel_axis=channel_axis,\n        fill_mode=fill_mode,\n        cval=cval,\n        order=interpolation_order,\n    )\n    return x\n\n\n@keras_export(\"keras._legacy.preprocessing.image.random_shift\")\ndef random_shift(\n    x,\n    wrg,\n    hrg,\n    row_axis=1,\n    col_axis=2,\n    channel_axis=0,\n    fill_mode=\"nearest\",\n    cval=0.0,\n    interpolation_order=1,\n):\n    \"\"\"DEPRECATED.\"\"\"\n    h, w = x.shape[row_axis], x.shape[col_axis]\n    tx = np.random.uniform(-hrg, hrg) * h\n    ty = np.random.uniform(-wrg, wrg) * w\n    x = apply_affine_transform(\n        x,\n        tx=tx,\n        ty=ty,\n        row_axis=row_axis,\n        col_axis=col_axis,\n        channel_axis=channel_axis,\n        fill_mode=fill_mode,\n        cval=cval,\n        order=interpolation_order,\n    )\n    return x\n\n\n@keras_export(\"keras._legacy.preprocessing.image.random_shear\")\ndef random_shear(\n    x,\n    intensity,\n    row_axis=1,\n    col_axis=2,\n    channel_axis=0,\n    fill_mode=\"nearest\",\n    cval=0.0,\n    interpolation_order=1,\n):\n    \"\"\"DEPRECATED.\"\"\"\n    shear = np.random.uniform(-intensity, intensity)\n    x = apply_affine_transform(\n        x,\n        shear=shear,\n        row_axis=row_axis,\n        col_axis=col_axis,\n        channel_axis=channel_axis,\n        fill_mode=fill_mode,\n        cval=cval,\n        order=interpolation_order,\n    )\n    return x\n\n\n@keras_export(\"keras._legacy.preprocessing.image.random_zoom\")\ndef random_zoom(\n    x,\n    zoom_range,\n    row_axis=1,\n    col_axis=2,\n    channel_axis=0,\n    fill_mode=\"nearest\",\n    cval=0.0,\n    interpolation_order=1,\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if len(zoom_range) != 2:\n        raise ValueError(\n            \"`zoom_range` should be a tuple or list of two floats. 
\"\n            f\"Received: {zoom_range}\"\n        )\n\n    if zoom_range[0] == 1 and zoom_range[1] == 1:\n        zx, zy = 1, 1\n    else:\n        zx, zy = np.random.uniform(zoom_range[0], zoom_range[1], 2)\n    x = apply_affine_transform(\n        x,\n        zx=zx,\n        zy=zy,\n        row_axis=row_axis,\n        col_axis=col_axis,\n        channel_axis=channel_axis,\n        fill_mode=fill_mode,\n        cval=cval,\n        order=interpolation_order,\n    )\n    return x\n\n\n@keras_export(\"keras._legacy.preprocessing.image.apply_channel_shift\")\ndef apply_channel_shift(x, intensity, channel_axis=0):\n    \"\"\"Performs a channel shift.\n\n    DEPRECATED.\n\n    Args:\n        x: Input tensor. Must be 3D.\n        intensity: Transformation intensity.\n        channel_axis: Index of axis for channels in the input tensor.\n\n    Returns:\n        Numpy image tensor.\n    \"\"\"\n    x = np.rollaxis(x, channel_axis, 0)\n    min_x, max_x = np.min(x), np.max(x)\n    channel_images = [\n        np.clip(x_channel + intensity, min_x, max_x) for x_channel in x\n    ]\n    x = np.stack(channel_images, axis=0)\n    x = np.rollaxis(x, 0, channel_axis + 1)\n    return x\n\n\n@keras_export(\"keras._legacy.preprocessing.image.random_channel_shift\")\ndef random_channel_shift(x, intensity_range, channel_axis=0):\n    \"\"\"Performs a random channel shift.\n\n    DEPRECATED.\n\n    Args:\n        x: Input tensor. 
Must be 3D.\n        intensity_range: Transformation intensity.\n        channel_axis: Index of axis for channels in the input tensor.\n\n    Returns:\n        Numpy image tensor.\n    \"\"\"\n    intensity = np.random.uniform(-intensity_range, intensity_range)\n    return apply_channel_shift(x, intensity, channel_axis=channel_axis)\n\n\n@keras_export(\"keras._legacy.preprocessing.image.apply_brightness_shift\")\ndef apply_brightness_shift(x, brightness, scale=True):\n    \"\"\"Performs a brightness shift.\n\n    DEPRECATED.\n\n    Args:\n        x: Input tensor. Must be 3D.\n        brightness: Float. The new brightness value.\n        scale: Whether to rescale the image such that minimum and maximum values\n            are 0 and 255 respectively. Default: True.\n\n    Returns:\n        Numpy image tensor.\n\n    Raises:\n        ImportError: if PIL is not available.\n    \"\"\"\n    from PIL import ImageEnhance\n\n    x_min, x_max = np.min(x), np.max(x)\n    local_scale = (x_min < 0) or (x_max > 255)\n    x = image_utils.array_to_img(x, scale=local_scale or scale)\n    enhancer = ImageEnhance.Brightness(x)\n    x = enhancer.enhance(brightness)\n    x = image_utils.img_to_array(x)\n    if not scale and local_scale:\n        x = x / 255 * (x_max - x_min) + x_min\n    return x\n\n\n@keras_export(\"keras._legacy.preprocessing.image.random_brightness\")\ndef random_brightness(x, brightness_range, scale=True):\n    \"\"\"Performs a random brightness shift.\n\n    DEPRECATED.\n\n    Args:\n        x: Input tensor. Must be 3D.\n        brightness_range: Tuple of floats; brightness range.\n        scale: Whether to rescale the image such that minimum and maximum values\n            are 0 and 255 respectively. 
Default: True.\n\n    Returns:\n        Numpy image tensor.\n\n    Raises:\n        ValueError if `brightness_range` isn't a tuple.\n    \"\"\"\n    if len(brightness_range) != 2:\n        raise ValueError(\n            \"`brightness_range` should be a tuple or list of two floats. \"\n            f\"Received: {brightness_range}\"\n        )\n\n    u = np.random.uniform(brightness_range[0], brightness_range[1])\n    return apply_brightness_shift(x, u, scale)\n\n\ndef transform_matrix_offset_center(matrix, x, y):\n    o_x = float(x) / 2 - 0.5\n    o_y = float(y) / 2 - 0.5\n    offset_matrix = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])\n    reset_matrix = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])\n    transform_matrix = np.dot(np.dot(offset_matrix, matrix), reset_matrix)\n    return transform_matrix\n\n\n@keras_export(\"keras._legacy.preprocessing.image.apply_affine_transform\")\ndef apply_affine_transform(\n    x,\n    theta=0,\n    tx=0,\n    ty=0,\n    shear=0,\n    zx=1,\n    zy=1,\n    row_axis=1,\n    col_axis=2,\n    channel_axis=0,\n    fill_mode=\"nearest\",\n    cval=0.0,\n    order=1,\n):\n    \"\"\"Applies an affine transformation specified by the parameters given.\n\n    DEPRECATED.\n    \"\"\"\n    # Input sanity checks:\n    # 1. x must be a 2D image with one or more channels (i.e., a 3D tensor)\n    # 2. 
channels must be either first or last dimension\n    if np.unique([row_axis, col_axis, channel_axis]).size != 3:\n        raise ValueError(\n            \"'row_axis', 'col_axis', and 'channel_axis' must be distinct\"\n        )\n\n    # shall we support negative indices?\n    valid_indices = set([0, 1, 2])\n    actual_indices = set([row_axis, col_axis, channel_axis])\n    if actual_indices != valid_indices:\n        raise ValueError(\n            f\"Invalid axis indices: {actual_indices - valid_indices}\"\n        )\n\n    if x.ndim != 3:\n        raise ValueError(\"Input arrays must be multi-channel 2D images.\")\n    if channel_axis not in [0, 2]:\n        raise ValueError(\n            \"Channels must be either the first or the last dimension.\"\n        )\n\n    transform_matrix = None\n    if theta != 0:\n        theta = np.deg2rad(theta)\n        rotation_matrix = np.array(\n            [\n                [np.cos(theta), -np.sin(theta), 0],\n                [np.sin(theta), np.cos(theta), 0],\n                [0, 0, 1],\n            ]\n        )\n        transform_matrix = rotation_matrix\n\n    if tx != 0 or ty != 0:\n        shift_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])\n        if transform_matrix is None:\n            transform_matrix = shift_matrix\n        else:\n            transform_matrix = np.dot(transform_matrix, shift_matrix)\n\n    if shear != 0:\n        shear = np.deg2rad(shear)\n        shear_matrix = np.array(\n            [[1, -np.sin(shear), 0], [0, np.cos(shear), 0], [0, 0, 1]]\n        )\n        if transform_matrix is None:\n            transform_matrix = shear_matrix\n        else:\n            transform_matrix = np.dot(transform_matrix, shear_matrix)\n\n    if zx != 1 or zy != 1:\n        zoom_matrix = np.array([[zx, 0, 0], [0, zy, 0], [0, 0, 1]])\n        if transform_matrix is None:\n            transform_matrix = zoom_matrix\n        else:\n            transform_matrix = np.dot(transform_matrix, zoom_matrix)\n\n    if 
transform_matrix is not None:\n        h, w = x.shape[row_axis], x.shape[col_axis]\n        transform_matrix = transform_matrix_offset_center(\n            transform_matrix, h, w\n        )\n        x = np.rollaxis(x, channel_axis, 0)\n\n        # Matrix construction assumes that coordinates are x, y (in that order).\n        # However, regular numpy arrays use y,x (aka i,j) indexing.\n        # Possible solution is:\n        #   1. Swap the x and y axes.\n        #   2. Apply transform.\n        #   3. Swap the x and y axes again to restore image-like data ordering.\n        # Mathematically, it is equivalent to the following transformation:\n        # M' = PMP, where P is the permutation matrix, M is the original\n        # transformation matrix.\n        if col_axis > row_axis:\n            transform_matrix[:, [0, 1]] = transform_matrix[:, [1, 0]]\n            transform_matrix[[0, 1]] = transform_matrix[[1, 0]]\n        final_affine_matrix = transform_matrix[:2, :2]\n        final_offset = transform_matrix[:2, 2]\n\n        channel_images = [\n            scipy.ndimage.affine_transform(\n                x_channel,\n                final_affine_matrix,\n                final_offset,\n                order=order,\n                mode=fill_mode,\n                cval=cval,\n            )\n            for x_channel in x\n        ]\n        x = np.stack(channel_images, axis=0)\n        x = np.rollaxis(x, 0, channel_axis + 1)\n    return x\n"
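As a sanity check on the geometry above, the snippet below re-implements the `transform_matrix_offset_center` composition in plain NumPy (a standalone sketch, not the library code itself) and verifies that a rotation composed this way pivots around the image center, so the center coordinate is a fixed point of the transform.

```python
import numpy as np

def offset_center(matrix, x, y):
    # Same composition as transform_matrix_offset_center:
    # translate the image center to the origin, apply the transform,
    # then translate back (offset @ matrix @ reset).
    o_x = float(x) / 2 - 0.5
    o_y = float(y) / 2 - 0.5
    offset = np.array([[1, 0, o_x], [0, 1, o_y], [0, 0, 1]])
    reset = np.array([[1, 0, -o_x], [0, 1, -o_y], [0, 0, 1]])
    return offset @ matrix @ reset

theta = np.deg2rad(90)
rotation = np.array(
    [
        [np.cos(theta), -np.sin(theta), 0],
        [np.sin(theta), np.cos(theta), 0],
        [0, 0, 1],
    ]
)
m = offset_center(rotation, 5, 5)

# The center of a 5x5 image is (2, 2); rotating about it leaves it fixed.
center = np.array([2.0, 2.0, 1.0])
print(np.allclose(m @ center, center))  # → True
```

This is why `apply_affine_transform` only needs the raw rotation/shear/zoom matrices: the centering bookkeeping is factored into this single helper.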
  },
  {
    "path": "keras/src/legacy/preprocessing/sequence.py",
    "content": "\"\"\"Deprecated sequence preprocessing APIs from Keras 1.\"\"\"\n\nimport json\nimport random\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.trainers.data_adapters.py_dataset_adapter import PyDataset\n\n\n@keras_export(\"keras._legacy.preprocessing.sequence.TimeseriesGenerator\")\nclass TimeseriesGenerator(PyDataset):\n    \"\"\"Utility class for generating batches of temporal data.\n\n    DEPRECATED.\n\n    This class takes in a sequence of data-points gathered at\n    equal intervals, along with time series parameters such as\n    stride, length of history, etc., to produce batches for\n    training/validation.\n\n    Arguments:\n        data: Indexable generator (such as list or Numpy array)\n            containing consecutive data points (timesteps).\n            The data should be at 2D, and axis 0 is expected\n            to be the time dimension.\n        targets: Targets corresponding to timesteps in `data`.\n            It should have same length as `data`.\n        length: Length of the output sequences (in number of timesteps).\n        sampling_rate: Period between successive individual timesteps\n            within sequences. For rate `r`, timesteps\n            `data[i]`, `data[i-r]`, ... `data[i - length]`\n            are used for create a sample sequence.\n        stride: Period between successive output sequences.\n            For stride `s`, consecutive output samples would\n            be centered around `data[i]`, `data[i+s]`, `data[i+2*s]`, etc.\n        start_index: Data points earlier than `start_index` will not be used\n            in the output sequences. This is useful to reserve part of the\n            data for test or validation.\n        end_index: Data points later than `end_index` will not be used\n            in the output sequences. 
This is useful to reserve part of the\n            data for test or validation.\n        shuffle: Whether to shuffle output samples,\n            or instead draw them in chronological order.\n        reverse: Boolean: if `True`, timesteps in each output sample will be\n            in reverse chronological order.\n        batch_size: Number of timeseries samples in each batch\n            (except maybe the last one).\n        **kwargs: Additional keyword arguments for the `PyDataset` base class,\n            such as `workers`, `use_multiprocessing`, and `max_queue_size`.\n\n    Returns:\n        A PyDataset instance.\n    \"\"\"\n\n    def __init__(\n        self,\n        data,\n        targets,\n        length,\n        sampling_rate=1,\n        stride=1,\n        start_index=0,\n        end_index=None,\n        shuffle=False,\n        reverse=False,\n        batch_size=128,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        if len(data) != len(targets):\n            raise ValueError(\n                \"Data and targets have to be \"\n                f\"of the same length. 
Data length is {len(data)} \"\n                f\"while target length is {len(targets)}\"\n            )\n\n        self.data = data\n        self.targets = targets\n        self.length = length\n        self.sampling_rate = sampling_rate\n        self.stride = stride\n        self.start_index = start_index + length\n        if end_index is None:\n            end_index = len(data) - 1\n        self.end_index = end_index\n        self.shuffle = shuffle\n        self.reverse = reverse\n        self.batch_size = batch_size\n\n        if self.start_index > self.end_index:\n            raise ValueError(\n                f\"`start_index+length={self.start_index} \"\n                f\"> end_index={self.end_index}` \"\n                \"is disallowed, as no part of the sequence \"\n                \"would be left to be used as current step.\"\n            )\n\n    def __len__(self):\n        return (\n            self.end_index - self.start_index + self.batch_size * self.stride\n        ) // (self.batch_size * self.stride)\n\n    def __getitem__(self, index):\n        if self.shuffle:\n            rows = np.random.randint(\n                self.start_index, self.end_index + 1, size=self.batch_size\n            )\n        else:\n            i = self.start_index + self.batch_size * self.stride * index\n            rows = np.arange(\n                i,\n                min(i + self.batch_size * self.stride, self.end_index + 1),\n                self.stride,\n            )\n\n        samples = np.array(\n            [\n                self.data[row - self.length : row : self.sampling_rate]\n                for row in rows\n            ]\n        )\n        targets = np.array([self.targets[row] for row in rows])\n\n        if self.reverse:\n            return samples[:, ::-1, ...], targets\n        return samples, targets\n\n    def get_config(self):\n        \"\"\"Returns the TimeseriesGenerator configuration as Python dictionary.\n\n        Returns:\n            A Python 
dictionary with the TimeseriesGenerator configuration.\n        \"\"\"\n        data = self.data\n        if type(self.data).__module__ == np.__name__:\n            data = self.data.tolist()\n        try:\n            json_data = json.dumps(data)\n        except TypeError as e:\n            raise TypeError(f\"Data not JSON Serializable: {data}\") from e\n\n        targets = self.targets\n        if type(self.targets).__module__ == np.__name__:\n            targets = self.targets.tolist()\n        try:\n            json_targets = json.dumps(targets)\n        except TypeError as e:\n            raise TypeError(f\"Targets not JSON Serializable: {targets}\") from e\n\n        config = super().get_config()\n        config.update(\n            {\n                \"data\": json_data,\n                \"targets\": json_targets,\n                \"length\": self.length,\n                \"sampling_rate\": self.sampling_rate,\n                \"stride\": self.stride,\n                \"start_index\": self.start_index,\n                \"end_index\": self.end_index,\n                \"shuffle\": self.shuffle,\n                \"reverse\": self.reverse,\n                \"batch_size\": self.batch_size,\n            }\n        )\n        return config\n\n    def to_json(self, **kwargs):\n        \"\"\"Returns a JSON string containing the generator's configuration.\n\n        Args:\n            **kwargs: Additional keyword arguments to be passed\n                to `json.dumps()`.\n\n        Returns:\n            A JSON string containing the tokenizer configuration.\n        \"\"\"\n        config = self.get_config()\n        timeseries_generator_config = {\n            \"class_name\": self.__class__.__name__,\n            \"config\": config,\n        }\n        return json.dumps(timeseries_generator_config, **kwargs)\n\n\n@keras_export(\"keras._legacy.preprocessing.sequence.make_sampling_table\")\ndef make_sampling_table(size, sampling_factor=1e-5):\n    \"\"\"Generates a word 
rank-based probabilistic sampling table.\n\n    DEPRECATED.\n\n    Used for generating the `sampling_table` argument for `skipgrams`.\n    `sampling_table[i]` is the probability of sampling\n    the i-th most common word in a dataset\n    (more common words should be sampled less frequently, for balance).\n\n    The sampling probabilities are generated according\n    to the sampling distribution used in word2vec:\n\n    ```\n    p(word) = (min(1, sqrt(word_frequency / sampling_factor) /\n        (word_frequency / sampling_factor)))\n    ```\n\n    We assume that the word frequencies follow Zipf's law (s=1) to derive\n    a numerical approximation of frequency(rank):\n\n    `frequency(rank) ~ 1/(rank * (log(rank) + gamma) + 1/2 - 1/(12*rank))`\n    where `gamma` is the Euler-Mascheroni constant.\n\n    Args:\n        size: Int, number of possible words to sample.\n        sampling_factor: The sampling factor in the word2vec formula.\n\n    Returns:\n        A 1D Numpy array of length `size` where the ith entry\n        is the probability that a word of rank i should be sampled.\n    \"\"\"\n    gamma = 0.577\n    rank = np.arange(size)\n    rank[0] = 1\n    inv_fq = rank * (np.log(rank) + gamma) + 0.5 - 1.0 / (12.0 * rank)\n    f = sampling_factor * inv_fq\n\n    return np.minimum(1.0, f / np.sqrt(f))\n\n\n@keras_export(\"keras._legacy.preprocessing.sequence.skipgrams\")\ndef skipgrams(\n    sequence,\n    vocabulary_size,\n    window_size=4,\n    negative_samples=1.0,\n    shuffle=True,\n    categorical=False,\n    sampling_table=None,\n    seed=None,\n):\n    \"\"\"Generates skipgram word pairs.\n\n    DEPRECATED.\n\n    This function transforms a sequence of word indexes (list of integers)\n    into tuples of words of the form:\n\n    - (word, word in the same window), with label 1 (positive samples).\n    - (word, random word from the vocabulary), with label 0 (negative samples).\n\n    Read more about Skipgram in this gnomic paper by Mikolov et al.:\n    
[Efficient Estimation of Word Representations in\n    Vector Space](http://arxiv.org/pdf/1301.3781v3.pdf)\n\n    Args:\n        sequence: A word sequence (sentence), encoded as a list\n            of word indices (integers). If using a `sampling_table`,\n            word indices are expected to match the rank\n            of the words in a reference dataset (e.g. 10 would encode\n            the 10-th most frequently occurring token).\n            Note that index 0 is expected to be a non-word and will be skipped.\n        vocabulary_size: Int, maximum possible word index + 1\n        window_size: Int, size of sampling windows (technically half-window).\n            The window of a word `w_i` will be\n            `[i - window_size, i + window_size+1]`.\n        negative_samples: Float >= 0. 0 for no negative (i.e. random) samples.\n            1 for same number as positive samples.\n        shuffle: Whether to shuffle the word couples before returning them.\n        categorical: bool. if False, labels will be\n            integers (eg. `[0, 1, 1 .. ]`),\n            if `True`, labels will be categorical, e.g.\n            `[[1,0],[0,1],[0,1] .. 
]`.\n        sampling_table: 1D array of size `vocabulary_size` where the entry i\n            encodes the probability to sample a word of rank i.\n        seed: Random seed.\n\n    Returns:\n        couples, labels: where `couples` are int pairs and\n            `labels` are either 0 or 1.\n\n    Note:\n        By convention, index 0 in the vocabulary is\n        a non-word and will be skipped.\n    \"\"\"\n    couples = []\n    labels = []\n    for i, wi in enumerate(sequence):\n        if not wi:\n            continue\n        if sampling_table is not None:\n            if sampling_table[wi] < random.random():\n                continue\n\n        window_start = max(0, i - window_size)\n        window_end = min(len(sequence), i + window_size + 1)\n        for j in range(window_start, window_end):\n            if j != i:\n                wj = sequence[j]\n                if not wj:\n                    continue\n                couples.append([wi, wj])\n                if categorical:\n                    labels.append([0, 1])\n                else:\n                    labels.append(1)\n\n    if negative_samples > 0:\n        num_negative_samples = int(len(labels) * negative_samples)\n        words = [c[0] for c in couples]\n        random.shuffle(words)\n\n        couples += [\n            [words[i % len(words)], random.randint(1, vocabulary_size - 1)]\n            for i in range(num_negative_samples)\n        ]\n        if categorical:\n            labels += [[1, 0]] * num_negative_samples\n        else:\n            labels += [0] * num_negative_samples\n\n    if shuffle:\n        if seed is None:\n            # `random.randint` requires int bounds; `10e6` is a float.\n            seed = random.randint(0, int(10e6))\n        random.seed(seed)\n        random.shuffle(couples)\n        random.seed(seed)\n        random.shuffle(labels)\n\n    return couples, labels\n"
  },
  {
    "path": "keras/src/legacy/preprocessing/text.py",
    "content": "\"\"\"Deprecated text preprocessing APIs from Keras 1.\"\"\"\n\nimport collections\nimport hashlib\nimport json\nimport warnings\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\n\n\n@keras_export(\"keras._legacy.preprocessing.text.text_to_word_sequence\")\ndef text_to_word_sequence(\n    input_text,\n    filters='!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n',\n    lower=True,\n    split=\" \",\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if lower:\n        input_text = input_text.lower()\n\n    translate_dict = {c: split for c in filters}\n    translate_map = str.maketrans(translate_dict)\n    input_text = input_text.translate(translate_map)\n\n    seq = input_text.split(split)\n    return [i for i in seq if i]\n\n\n@keras_export(\"keras._legacy.preprocessing.text.one_hot\")\ndef one_hot(\n    input_text,\n    n,\n    filters='!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n',\n    lower=True,\n    split=\" \",\n    analyzer=None,\n):\n    \"\"\"DEPRECATED.\"\"\"\n    return hashing_trick(\n        input_text,\n        n,\n        hash_function=hash,\n        filters=filters,\n        lower=lower,\n        split=split,\n        analyzer=analyzer,\n    )\n\n\n@keras_export(\"keras._legacy.preprocessing.text.hashing_trick\")\ndef hashing_trick(\n    text,\n    n,\n    hash_function=None,\n    filters='!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n',\n    lower=True,\n    split=\" \",\n    analyzer=None,\n):\n    \"\"\"DEPRECATED.\"\"\"\n    if hash_function is None:\n        hash_function = hash\n    elif hash_function == \"md5\":\n\n        def hash_function(w):\n            return int(hashlib.md5(w.encode()).hexdigest(), 16)\n\n    if analyzer is None:\n        seq = text_to_word_sequence(\n            text, filters=filters, lower=lower, split=split\n        )\n    else:\n        seq = analyzer(text)\n\n    return [(hash_function(w) % (n - 1) + 1) for w in seq]\n\n\n@keras_export(\"keras._legacy.preprocessing.text.Tokenizer\")\nclass Tokenizer:\n   
 \"\"\"DEPRECATED.\"\"\"\n\n    def __init__(\n        self,\n        num_words=None,\n        filters='!\"#$%&()*+,-./:;<=>?@[\\\\]^_`{|}~\\t\\n',\n        lower=True,\n        split=\" \",\n        char_level=False,\n        oov_token=None,\n        analyzer=None,\n        **kwargs,\n    ):\n        # Legacy support\n        if \"nb_words\" in kwargs:\n            warnings.warn(\n                \"The `nb_words` argument in `Tokenizer` \"\n                \"has been renamed `num_words`.\"\n            )\n            num_words = kwargs.pop(\"nb_words\")\n        document_count = kwargs.pop(\"document_count\", 0)\n        if kwargs:\n            raise TypeError(f\"Unrecognized keyword arguments: {str(kwargs)}\")\n\n        self.word_counts = collections.OrderedDict()\n        self.word_docs = collections.defaultdict(int)\n        self.filters = filters\n        self.split = split\n        self.lower = lower\n        self.num_words = num_words\n        self.document_count = document_count\n        self.char_level = char_level\n        self.oov_token = oov_token\n        self.index_docs = collections.defaultdict(int)\n        self.word_index = {}\n        self.index_word = {}\n        self.analyzer = analyzer\n\n    def fit_on_texts(self, texts):\n        for text in texts:\n            self.document_count += 1\n            if self.char_level or isinstance(text, list):\n                if self.lower:\n                    if isinstance(text, list):\n                        text = [text_elem.lower() for text_elem in text]\n                    else:\n                        text = text.lower()\n                seq = text\n            else:\n                if self.analyzer is None:\n                    seq = text_to_word_sequence(\n                        text,\n                        filters=self.filters,\n                        lower=self.lower,\n                        split=self.split,\n                    )\n                else:\n                    seq = 
self.analyzer(text)\n            for w in seq:\n                if w in self.word_counts:\n                    self.word_counts[w] += 1\n                else:\n                    self.word_counts[w] = 1\n            for w in set(seq):\n                # In how many documents each word occurs\n                self.word_docs[w] += 1\n\n        wcounts = list(self.word_counts.items())\n        wcounts.sort(key=lambda x: x[1], reverse=True)\n        # forcing the oov_token to index 1 if it exists\n        if self.oov_token is None:\n            sorted_voc = []\n        else:\n            sorted_voc = [self.oov_token]\n        sorted_voc.extend(wc[0] for wc in wcounts)\n\n        # note that index 0 is reserved, never assigned to an existing word\n        self.word_index = dict(\n            zip(sorted_voc, list(range(1, len(sorted_voc) + 1)))\n        )\n\n        self.index_word = {c: w for w, c in self.word_index.items()}\n\n        for w, c in list(self.word_docs.items()):\n            self.index_docs[self.word_index[w]] = c\n\n    def fit_on_sequences(self, sequences):\n        self.document_count += len(sequences)\n        for seq in sequences:\n            seq = set(seq)\n            for i in seq:\n                self.index_docs[i] += 1\n\n    def texts_to_sequences(self, texts):\n        return list(self.texts_to_sequences_generator(texts))\n\n    def texts_to_sequences_generator(self, texts):\n        num_words = self.num_words\n        oov_token_index = self.word_index.get(self.oov_token)\n        for text in texts:\n            if self.char_level or isinstance(text, list):\n                if self.lower:\n                    if isinstance(text, list):\n                        text = [text_elem.lower() for text_elem in text]\n                    else:\n                        text = text.lower()\n                seq = text\n            else:\n                if self.analyzer is None:\n                    seq = text_to_word_sequence(\n                        
text,\n                        filters=self.filters,\n                        lower=self.lower,\n                        split=self.split,\n                    )\n                else:\n                    seq = self.analyzer(text)\n            vect = []\n            for w in seq:\n                i = self.word_index.get(w)\n                if i is not None:\n                    if num_words and i >= num_words:\n                        if oov_token_index is not None:\n                            vect.append(oov_token_index)\n                    else:\n                        vect.append(i)\n                elif self.oov_token is not None:\n                    vect.append(oov_token_index)\n            yield vect\n\n    def sequences_to_texts(self, sequences):\n        return list(self.sequences_to_texts_generator(sequences))\n\n    def sequences_to_texts_generator(self, sequences):\n        num_words = self.num_words\n        oov_token_index = self.word_index.get(self.oov_token)\n        for seq in sequences:\n            vect = []\n            for num in seq:\n                word = self.index_word.get(num)\n                if word is not None:\n                    if num_words and num >= num_words:\n                        if oov_token_index is not None:\n                            vect.append(self.index_word[oov_token_index])\n                    else:\n                        vect.append(word)\n                elif self.oov_token is not None:\n                    vect.append(self.index_word[oov_token_index])\n            vect = \" \".join(vect)\n            yield vect\n\n    def texts_to_matrix(self, texts, mode=\"binary\"):\n        sequences = self.texts_to_sequences(texts)\n        return self.sequences_to_matrix(sequences, mode=mode)\n\n    def sequences_to_matrix(self, sequences, mode=\"binary\"):\n        if not self.num_words:\n            if self.word_index:\n                num_words = len(self.word_index) + 1\n            else:\n                raise 
ValueError(\n                    \"Specify a dimension (`num_words` argument), \"\n                    \"or fit on some text data first.\"\n                )\n        else:\n            num_words = self.num_words\n\n        if mode == \"tfidf\" and not self.document_count:\n            raise ValueError(\n                \"Fit the Tokenizer on some data before using tfidf mode.\"\n            )\n\n        x = np.zeros((len(sequences), num_words))\n        for i, seq in enumerate(sequences):\n            if not seq:\n                continue\n            counts = collections.defaultdict(int)\n            for j in seq:\n                if j >= num_words:\n                    continue\n                counts[j] += 1\n            for j, c in list(counts.items()):\n                if mode == \"count\":\n                    x[i][j] = c\n                elif mode == \"freq\":\n                    x[i][j] = c / len(seq)\n                elif mode == \"binary\":\n                    x[i][j] = 1\n                elif mode == \"tfidf\":\n                    # Use weighting scheme 2 in\n                    # https://en.wikipedia.org/wiki/Tf%E2%80%93idf\n                    tf = 1 + np.log(c)\n                    idf = np.log(\n                        1\n                        + self.document_count / (1 + self.index_docs.get(j, 0))\n                    )\n                    x[i][j] = tf * idf\n                else:\n                    raise ValueError(f\"Unknown vectorization mode: {mode}\")\n        return x\n\n    def get_config(self):\n        json_word_counts = json.dumps(self.word_counts)\n        json_word_docs = json.dumps(self.word_docs)\n        json_index_docs = json.dumps(self.index_docs)\n        json_word_index = json.dumps(self.word_index)\n        json_index_word = json.dumps(self.index_word)\n\n        return {\n            \"num_words\": self.num_words,\n            \"filters\": self.filters,\n            \"lower\": self.lower,\n            \"split\": 
self.split,\n            \"char_level\": self.char_level,\n            \"oov_token\": self.oov_token,\n            \"document_count\": self.document_count,\n            \"word_counts\": json_word_counts,\n            \"word_docs\": json_word_docs,\n            \"index_docs\": json_index_docs,\n            \"index_word\": json_index_word,\n            \"word_index\": json_word_index,\n        }\n\n    def to_json(self, **kwargs):\n        config = self.get_config()\n        tokenizer_config = {\n            \"class_name\": self.__class__.__name__,\n            \"config\": config,\n        }\n        return json.dumps(tokenizer_config, **kwargs)\n\n\n@keras_export(\"keras._legacy.preprocessing.text.tokenizer_from_json\")\ndef tokenizer_from_json(json_string):\n    \"\"\"DEPRECATED.\"\"\"\n    tokenizer_config = json.loads(json_string)\n    config = tokenizer_config.get(\"config\")\n\n    word_counts = json.loads(config.pop(\"word_counts\"))\n    word_docs = json.loads(config.pop(\"word_docs\"))\n    index_docs = json.loads(config.pop(\"index_docs\"))\n    # Integer indexing gets converted to strings with json.dumps()\n    index_docs = {int(k): v for k, v in index_docs.items()}\n    index_word = json.loads(config.pop(\"index_word\"))\n    index_word = {int(k): v for k, v in index_word.items()}\n    word_index = json.loads(config.pop(\"word_index\"))\n\n    tokenizer = Tokenizer(**config)\n    tokenizer.word_counts = word_counts\n    tokenizer.word_docs = word_docs\n    tokenizer.index_docs = index_docs\n    tokenizer.word_index = word_index\n    tokenizer.index_word = index_word\n    return tokenizer\n"
  },
  {
    "path": "keras/src/legacy/saving/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/legacy/saving/json_utils.py",
    "content": "\"\"\"JSON utilities for legacy saving formats (h5 and SavedModel)\"\"\"\n\nimport collections\nimport enum\nimport functools\nimport json\n\nimport numpy as np\n\nfrom keras.src.legacy.saving import serialization\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n_EXTENSION_TYPE_SPEC = \"_EXTENSION_TYPE_SPEC\"\n\n\nclass Encoder(json.JSONEncoder):\n    \"\"\"JSON encoder and decoder that handles TensorShapes and tuples.\"\"\"\n\n    def default(self, obj):\n        \"\"\"Encodes objects for types that aren't handled by the default\n        encoder.\"\"\"\n        if tf.available and isinstance(obj, tf.TensorShape):\n            items = obj.as_list() if obj.rank is not None else None\n            return {\"class_name\": \"TensorShape\", \"items\": items}\n        return get_json_type(obj)\n\n    def encode(self, obj):\n        return super().encode(_encode_tuple(obj))\n\n\ndef _encode_tuple(x):\n    if isinstance(x, tuple):\n        return {\n            \"class_name\": \"__tuple__\",\n            \"items\": tuple(_encode_tuple(i) for i in x),\n        }\n    elif isinstance(x, list):\n        return [_encode_tuple(i) for i in x]\n    elif isinstance(x, dict):\n        return {key: _encode_tuple(value) for key, value in x.items()}\n    else:\n        return x\n\n\ndef decode(json_string):\n    return json.loads(json_string, object_hook=_decode_helper)\n\n\ndef decode_and_deserialize(\n    json_string, module_objects=None, custom_objects=None\n):\n    \"\"\"Decodes the JSON and deserializes any Keras objects found in the dict.\"\"\"\n    return json.loads(\n        json_string,\n        object_hook=functools.partial(\n            _decode_helper,\n            deserialize=True,\n            module_objects=module_objects,\n            custom_objects=custom_objects,\n        ),\n    )\n\n\ndef _decode_helper(\n    obj, deserialize=False, module_objects=None, custom_objects=None\n):\n    \"\"\"A 
decoding helper that is TF-object aware.\n\n    Args:\n      obj: A decoded dictionary that may represent an object.\n      deserialize: Boolean. When True, deserializes any Keras\n        objects found in `obj`. Defaults to `False`.\n      module_objects: A dictionary of built-in objects to look the name up in.\n        Generally, `module_objects` is provided by midlevel library\n        implementers.\n      custom_objects: A dictionary of custom objects to look the name up in.\n        Generally, `custom_objects` is provided by the end user.\n\n    Returns:\n      The decoded object.\n    \"\"\"\n    if isinstance(obj, dict) and \"class_name\" in obj:\n        if tf.available:\n            if obj[\"class_name\"] == \"TensorShape\":\n                return tf.TensorShape(obj[\"items\"])\n            elif obj[\"class_name\"] == \"TypeSpec\":\n                from tensorflow.python.framework import type_spec_registry\n\n                return type_spec_registry.lookup(obj[\"type_spec\"])._deserialize(\n                    _decode_helper(obj[\"serialized\"])\n                )\n            elif obj[\"class_name\"] == \"CompositeTensor\":\n                spec = obj[\"spec\"]\n                tensors = []\n                for dtype, tensor in obj[\"tensors\"]:\n                    tensors.append(\n                        tf.constant(tensor, dtype=tf.dtypes.as_dtype(dtype))\n                    )\n                return tf.nest.pack_sequence_as(\n                    _decode_helper(spec), tensors, expand_composites=True\n                )\n\n        if obj[\"class_name\"] == \"__tuple__\":\n            return tuple(_decode_helper(i) for i in obj[\"items\"])\n        elif obj[\"class_name\"] == \"__ellipsis__\":\n            return Ellipsis\n        elif deserialize and \"__passive_serialization__\" in obj:\n            # __passive_serialization__ is added by the JSON encoder when\n            # encoding an object that has a `get_config()` method.\n            try:\n     
           if (\n                    \"module\" not in obj\n                ):  # TODO(nkovela): Add TF SavedModel scope\n                    return serialization.deserialize_keras_object(\n                        obj,\n                        module_objects=module_objects,\n                        custom_objects=custom_objects,\n                    )\n                else:\n                    return serialization_lib.deserialize_keras_object(\n                        obj,\n                        module_objects=module_objects,\n                        custom_objects=custom_objects,\n                    )\n            except ValueError:\n                pass\n        elif obj[\"class_name\"] == \"__bytes__\":\n            return obj[\"value\"].encode(\"utf-8\")\n    return obj\n\n\ndef get_json_type(obj):\n    \"\"\"Serializes any object to a JSON-serializable structure.\n\n    Args:\n        obj: the object to serialize\n\n    Returns:\n        JSON-serializable structure representing `obj`.\n\n    Raises:\n        TypeError: if `obj` cannot be serialized.\n    \"\"\"\n    # if obj is a serializable Keras class instance\n    # e.g. optimizer, layer\n    if hasattr(obj, \"get_config\"):\n        # TODO(nkovela): Replace with legacy serialization\n        serialized = serialization.serialize_keras_object(obj)\n        serialized[\"__passive_serialization__\"] = True\n        return serialized\n\n    # if obj is any numpy type\n    if type(obj).__module__ == np.__name__:\n        if isinstance(obj, np.ndarray):\n            return obj.tolist()\n        else:\n            return obj.item()\n\n    # misc functions (e.g. 
loss function)\n    if callable(obj):\n        return obj.__name__\n\n    # if obj is a python 'type'\n    if type(obj).__name__ == type.__name__:\n        return obj.__name__\n\n    if tf.available and isinstance(obj, tf.compat.v1.Dimension):\n        return obj.value\n\n    if tf.available and isinstance(obj, tf.TensorShape):\n        return obj.as_list()\n\n    if tf.available and isinstance(obj, tf.DType):\n        return obj.name\n\n    if isinstance(obj, collections.abc.Mapping):\n        return dict(obj)\n\n    if obj is Ellipsis:\n        return {\"class_name\": \"__ellipsis__\"}\n\n    # if isinstance(obj, wrapt.ObjectProxy):\n    #     return obj.__wrapped__\n\n    if tf.available and isinstance(obj, tf.TypeSpec):\n        from tensorflow.python.framework import type_spec_registry\n\n        try:\n            type_spec_name = type_spec_registry.get_name(type(obj))\n            return {\n                \"class_name\": \"TypeSpec\",\n                \"type_spec\": type_spec_name,\n                \"serialized\": obj._serialize(),\n            }\n        except ValueError:\n            raise ValueError(\n                f\"Unable to serialize {obj} to JSON, because the TypeSpec \"\n                f\"class {type(obj)} has not been registered.\"\n            )\n    if tf.available and isinstance(obj, tf.__internal__.CompositeTensor):\n        spec = tf.type_spec_from_value(obj)\n        tensors = []\n        for tensor in tf.nest.flatten(obj, expand_composites=True):\n            tensors.append((tensor.dtype.name, tensor.numpy().tolist()))\n        return {\n            \"class_name\": \"CompositeTensor\",\n            \"spec\": get_json_type(spec),\n            \"tensors\": tensors,\n        }\n\n    if isinstance(obj, enum.Enum):\n        return obj.value\n\n    if isinstance(obj, bytes):\n        return {\"class_name\": \"__bytes__\", \"value\": obj.decode(\"utf-8\")}\n\n    raise TypeError(\n        f\"Unable to serialize {obj} to JSON. 
Unrecognized type {type(obj)}.\"\n    )\n"
  },
  {
    "path": "keras/src/legacy/saving/json_utils_test.py",
    "content": "import enum\n\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.legacy.saving import json_utils\n\nif backend.backend() == \"tensorflow\":\n    import tensorflow as tf\n\n\nclass JsonUtilsTestAllBackends(testing.TestCase):\n    def test_encode_decode_tuple(self):\n        metadata = {\"key1\": (3, 5), \"key2\": [(1, (3, 4)), (1,)]}\n        string = json_utils.Encoder().encode(metadata)\n        loaded = json_utils.decode(string)\n\n        self.assertEqual(set(loaded.keys()), {\"key1\", \"key2\"})\n        self.assertAllEqual(loaded[\"key1\"], (3, 5))\n        self.assertAllEqual(loaded[\"key2\"], [(1, (3, 4)), (1,)])\n\n    def test_encode_decode_enum(self):\n        class Enum(enum.Enum):\n            CLASS_A = \"a\"\n            CLASS_B = \"b\"\n\n        config = {\"key\": Enum.CLASS_A, \"key2\": Enum.CLASS_B}\n        string = json_utils.Encoder().encode(config)\n        loaded = json_utils.decode(string)\n        self.assertAllEqual({\"key\": \"a\", \"key2\": \"b\"}, loaded)\n\n    def test_encode_decode_bytes(self):\n        b_string = b\"abc\"\n        json_string = json_utils.Encoder().encode(b_string)\n        loaded = json_utils.decode(json_string)\n        self.assertAllEqual(b_string, loaded)\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"tensorflow\",\n    reason=\"These JSON serialization tests are specific to TF components.\",\n)\nclass JsonUtilsTestTF(testing.TestCase):\n    def test_encode_decode_tensor_shape(self):\n        metadata = {\n            \"key1\": tf.TensorShape(None),\n            \"key2\": [tf.TensorShape([None]), tf.TensorShape([3, None, 5])],\n        }\n        string = json_utils.Encoder().encode(metadata)\n        loaded = json_utils.decode(string)\n\n        self.assertEqual(set(loaded.keys()), {\"key1\", \"key2\"})\n        self.assertEqual(loaded[\"key1\"].rank, None)\n        self.assertAllEqual(loaded[\"key2\"][0].as_list(), [None])\n        
self.assertAllEqual(loaded[\"key2\"][1].as_list(), [3, None, 5])\n\n    def test_encode_decode_type_spec(self):\n        spec = tf.TensorSpec((1, 5), tf.float32)\n        string = json_utils.Encoder().encode(spec)\n        loaded = json_utils.decode(string)\n        self.assertEqual(spec, loaded)\n\n        invalid_type_spec = {\n            \"class_name\": \"TypeSpec\",\n            \"type_spec\": \"Invalid Type\",\n            \"serialized\": None,\n        }\n        string = json_utils.Encoder().encode(invalid_type_spec)\n        with self.assertRaisesRegex(\n            ValueError, \"No TypeSpec has been registered\"\n        ):\n            loaded = json_utils.decode(string)\n\n    def test_encode_decode_ragged_tensor(self):\n        x = tf.ragged.constant([[1.0, 2.0], [3.0]])\n        string = json_utils.Encoder().encode(x)\n        loaded = json_utils.decode(string)\n        self.assertAllClose(loaded.values, x.values)\n\n    def test_encode_decode_extension_type_tensor(self):\n        class MaskedTensor(tf.experimental.ExtensionType):\n            __name__ = \"MaskedTensor\"\n            values: tf.Tensor\n            mask: tf.Tensor\n\n        x = MaskedTensor(\n            values=[[1, 2, 3], [4, 5, 6]],\n            mask=[[True, True, False], [True, False, True]],\n        )\n        string = json_utils.Encoder().encode(x)\n        loaded = json_utils.decode(string)\n        self.assertAllClose(loaded.values, x.values)\n        self.assertAllClose(loaded.mask, x.mask)\n"
  },
  {
    "path": "keras/src/legacy/saving/legacy_h5_format.py",
    "content": "import json\nimport os\nimport warnings\n\nimport numpy as np\nfrom absl import logging\n\nfrom keras.src import backend\nfrom keras.src.backend.common import global_state\nfrom keras.src.legacy.saving import json_utils\nfrom keras.src.legacy.saving import saving_options\nfrom keras.src.legacy.saving import saving_utils\nfrom keras.src.saving import object_registration\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils import io_utils\n\ntry:\n    import h5py\nexcept ImportError:\n    h5py = None\n\n\nHDF5_OBJECT_HEADER_LIMIT = 64512\n\n\ndef save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):\n    if h5py is None:\n        raise ImportError(\n            \"`save_model()` using h5 format requires h5py. Could not \"\n            \"import h5py.\"\n        )\n\n    if not isinstance(filepath, h5py.File):\n        # If file exists and should not be overwritten.\n        if not overwrite and os.path.isfile(filepath):\n            proceed = io_utils.ask_to_proceed_with_overwrite(filepath)\n            if not proceed:\n                return\n\n        dirpath = os.path.dirname(filepath)\n        if dirpath and not os.path.exists(dirpath):\n            os.makedirs(dirpath, exist_ok=True)\n\n        f = h5py.File(filepath, mode=\"w\")\n        opened_new_file = True\n    else:\n        f = filepath\n        opened_new_file = False\n    try:\n        with saving_options.keras_option_scope(use_legacy_config=True):\n            model_metadata = saving_utils.model_metadata(\n                model, include_optimizer\n            )\n            for k, v in model_metadata.items():\n                if isinstance(v, (dict, list, tuple)):\n                    f.attrs[k] = json.dumps(\n                        v, default=json_utils.get_json_type\n                    ).encode(\"utf8\")\n                else:\n                    f.attrs[k] = v\n\n            model_weights_group = f.create_group(\"model_weights\")\n            
save_weights_to_hdf5_group(model_weights_group, model)\n\n            # TODO(b/128683857): Add integration tests between tf.keras and\n            # external Keras, to avoid breaking TF.js users.\n            if include_optimizer and hasattr(model, \"optimizer\"):\n                save_optimizer_weights_to_hdf5_group(f, model.optimizer)\n\n        f.flush()\n    finally:\n        if opened_new_file:\n            f.close()\n\n\ndef load_model_from_hdf5(\n    filepath, custom_objects=None, compile=True, safe_mode=True\n):\n    \"\"\"Loads a model saved via `save_model_to_hdf5`.\n\n    Args:\n        filepath: One of the following:\n            - String, path to the saved model\n            - `h5py.File` object from which to load the model\n        custom_objects: Optional dictionary mapping names\n            (strings) to custom classes or functions to be\n            considered during deserialization.\n        compile: Boolean, whether to compile the model\n            after loading.\n        safe_mode: Boolean, whether to disallow unsafe `lambda`\n            deserialization. When `safe_mode=False`, loading an\n            object has the potential to trigger arbitrary code\n            execution.\n\n    Returns:\n        A Keras model instance. If an optimizer was found\n        as part of the saved model, the model is already\n        compiled. Otherwise, the model is uncompiled and\n        a warning will be displayed. When `compile` is set\n        to `False`, the compilation is omitted without any\n        warning.\n\n    Raises:\n        ImportError: if h5py is not available.\n        ValueError: In case of an invalid savefile.\n    \"\"\"\n    if h5py is None:\n        raise ImportError(\n            \"`load_model()` using h5 format requires h5py. 
Could not \"\n            \"import h5py.\"\n        )\n\n    if not custom_objects:\n        custom_objects = {}\n\n    gco = object_registration.GLOBAL_CUSTOM_OBJECTS\n    tlco = global_state.get_global_attribute(\"custom_objects_scope_dict\", {})\n    custom_objects = {**custom_objects, **gco, **tlco}\n\n    opened_new_file = not isinstance(filepath, h5py.File)\n    if opened_new_file:\n        f = h5py.File(filepath, mode=\"r\")\n    else:\n        f = filepath\n\n    model = None\n    try:\n        # instantiate model\n        model_config = f.attrs.get(\"model_config\")\n        if model_config is None:\n            raise ValueError(\n                f\"No model config found in the file at {filepath}.\"\n            )\n        if hasattr(model_config, \"decode\"):\n            model_config = model_config.decode(\"utf-8\")\n        model_config = json_utils.decode(model_config)\n\n        legacy_scope = saving_options.keras_option_scope(use_legacy_config=True)\n        safe_mode_scope = serialization_lib.SafeModeScope(safe_mode)\n        with legacy_scope, safe_mode_scope:\n            model = saving_utils.model_from_config(\n                model_config, custom_objects=custom_objects\n            )\n\n            # set weights\n            load_weights_from_hdf5_group(f[\"model_weights\"], model)\n\n        if compile:\n            # instantiate optimizer\n            training_config = f.attrs.get(\"training_config\")\n            if hasattr(training_config, \"decode\"):\n                training_config = training_config.decode(\"utf-8\")\n            if training_config is None:\n                logging.warning(\n                    \"No training configuration found in the save file, so \"\n                    \"the model was *not* compiled. 
Compile it manually.\"\n                )\n                return model\n            training_config = json_utils.decode(training_config)\n\n            # Compile model.\n            model.compile(\n                **saving_utils.compile_args_from_training_config(\n                    training_config, custom_objects\n                )\n            )\n            saving_utils.try_build_compiled_arguments(model)\n\n            # Set optimizer weights.\n            if \"optimizer_weights\" in f:\n                try:\n                    from keras.src import optimizers\n\n                    if isinstance(model.optimizer, optimizers.Optimizer):\n                        model.optimizer.build(model._trainable_variables)\n                    else:\n                        model.optimizer._create_all_weights(\n                            model._trainable_variables\n                        )\n                except (NotImplementedError, AttributeError):\n                    logging.warning(\n                        \"Error when creating the weights of the optimizer, \"\n                        \"making it impossible to restore the saved optimizer \"\n                        \"state. As a result, your model is starting with \"\n                        \"a freshly initialized optimizer.\"\n                    )\n\n                optimizer_weight_values = (\n                    load_optimizer_weights_from_hdf5_group(f)\n                )\n                try:\n                    model.optimizer.set_weights(optimizer_weight_values)\n                except ValueError:\n                    logging.warning(\n                        \"Error in loading the saved optimizer \"\n                        \"state. 
As a result, your model is \"\n                        \"starting with a freshly initialized \"\n                        \"optimizer.\"\n                    )\n    finally:\n        if opened_new_file:\n            f.close()\n    return model\n\n\ndef save_weights_to_hdf5_group(f, model):\n    \"\"\"Saves the weights of a list of layers to a HDF5 group.\n\n    Args:\n        f: HDF5 group.\n        model: Model instance.\n    \"\"\"\n    from keras.src import __version__ as keras_version\n\n    save_attributes_to_hdf5_group(\n        f, \"layer_names\", [layer.name.encode(\"utf8\") for layer in model.layers]\n    )\n    f.attrs[\"backend\"] = backend.backend().encode(\"utf8\")\n    f.attrs[\"keras_version\"] = str(keras_version).encode(\"utf8\")\n\n    # Sort model layers by layer name to ensure that group names are strictly\n    # growing to avoid prefix issues.\n    for layer in sorted(model.layers, key=lambda x: x.name):\n        g = f.create_group(layer.name)\n        weights = _legacy_weights(layer)\n        save_subset_weights_to_hdf5_group(g, weights)\n    weights = list(\n        v\n        for v in model._trainable_variables + model._non_trainable_variables\n        if v in model.weights\n    )\n    g = f.create_group(\"top_level_model_weights\")\n    save_subset_weights_to_hdf5_group(g, weights)\n\n\ndef save_subset_weights_to_hdf5_group(f, weights):\n    \"\"\"Save top-level weights of a model to a HDF5 group.\n\n    Args:\n        f: HDF5 group.\n        weights: List of weight variables.\n    \"\"\"\n    weight_values = [backend.convert_to_numpy(w) for w in weights]\n    weight_names = [str(w.path).encode(\"utf8\") for w in weights]\n    save_attributes_to_hdf5_group(f, \"weight_names\", weight_names)\n    for name, val in zip(weight_names, weight_values):\n        param_dset = f.create_dataset(name, val.shape, dtype=val.dtype)\n        if not val.shape:\n            # scalar\n            param_dset[()] = val\n        else:\n            param_dset[:] = 
val\n\n\ndef save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer):\n    \"\"\"Saves optimizer weights of a optimizer to a HDF5 group.\n\n    Args:\n        hdf5_group: HDF5 group.\n        optimizer: optimizer instance.\n    \"\"\"\n    from keras.src import optimizers\n\n    if isinstance(optimizer, optimizers.Optimizer):\n        symbolic_weights = optimizer.variables\n    else:\n        symbolic_weights = getattr(optimizer, \"weights\")\n    if symbolic_weights:\n        weights_group = hdf5_group.create_group(\"optimizer_weights\")\n        weight_names = [str(w.path).encode(\"utf8\") for w in symbolic_weights]\n        save_attributes_to_hdf5_group(\n            weights_group, \"weight_names\", weight_names\n        )\n        weight_values = [backend.convert_to_numpy(w) for w in symbolic_weights]\n        for name, val in zip(weight_names, weight_values):\n            param_dset = weights_group.create_dataset(\n                name, val.shape, dtype=val.dtype\n            )\n            if not val.shape:\n                # scalar\n                param_dset[()] = val\n            else:\n                param_dset[:] = val\n\n\ndef save_attributes_to_hdf5_group(group, name, data):\n    \"\"\"Saves attributes (data) of the specified name into the HDF5 group.\n\n    This method deals with an inherent problem of HDF5 file which is not\n    able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes.\n\n    Args:\n        group: A pointer to a HDF5 group.\n        name: A name of the attributes to save.\n        data: Attributes data to store.\n\n    Raises:\n      RuntimeError: If any single attribute is too large to be saved.\n    \"\"\"\n    # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`\n    # because in that case even chunking the array would not make the saving\n    # possible.\n    bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]\n\n    # Expecting this to never be true.\n    if bad_attributes:\n     
   raise RuntimeError(\n            \"The following attributes cannot be saved to HDF5 file because \"\n            f\"they are larger than {HDF5_OBJECT_HEADER_LIMIT} \"\n            f\"bytes: {bad_attributes}\"\n        )\n\n    data_npy = np.asarray(data)\n\n    num_chunks = 1\n    chunked_data = np.array_split(data_npy, num_chunks)\n\n    # This will never loop forever thanks to the test above.\n    while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):\n        num_chunks += 1\n        chunked_data = np.array_split(data_npy, num_chunks)\n\n    if num_chunks > 1:\n        for chunk_id, chunk_data in enumerate(chunked_data):\n            group.attrs[\"%s%d\" % (name, chunk_id)] = chunk_data\n    else:\n        group.attrs[name] = data\n\n\ndef load_weights_from_hdf5_group(f, model, skip_mismatch=False):\n    \"\"\"Implements topological (order-based) weight loading.\n\n    Args:\n        f: A pointer to a HDF5 group.\n        model: Model instance.\n        skip_mismatch: Boolean, whether to skip loading of weights\n            where there is a mismatch in the shape of the weights,\n\n    Raises:\n        ValueError: in case of mismatch between provided layers\n            and weights file.\n    \"\"\"\n    if \"keras_version\" in f.attrs:\n        original_keras_version = f.attrs[\"keras_version\"]\n        if hasattr(original_keras_version, \"decode\"):\n            original_keras_version = original_keras_version.decode(\"utf8\")\n    else:\n        original_keras_version = \"1\"\n    if \"backend\" in f.attrs:\n        original_backend = f.attrs[\"backend\"]\n        if hasattr(original_backend, \"decode\"):\n            original_backend = original_backend.decode(\"utf8\")\n    else:\n        original_backend = None\n\n    filtered_layers = []\n    for layer in model.layers:\n        weights = _legacy_weights(layer)\n        if weights:\n            filtered_layers.append(layer)\n\n    layer_names = load_attributes_from_hdf5_group(f, 
\"layer_names\")\n    filtered_layer_names = []\n    for name in layer_names:\n        g = f[name]\n        weight_names = load_attributes_from_hdf5_group(g, \"weight_names\")\n        if weight_names:\n            filtered_layer_names.append(name)\n    layer_names = filtered_layer_names\n    if len(layer_names) != len(filtered_layers):\n        raise ValueError(\n            \"Layer count mismatch when loading weights from file. \"\n            f\"Model expected {len(filtered_layers)} layers, found \"\n            f\"{len(layer_names)} saved layers.\"\n        )\n\n    for k, name in enumerate(layer_names):\n        g = f[name]\n        layer = filtered_layers[k]\n        symbolic_weights = _legacy_weights(layer)\n        weight_values = load_subset_weights_from_hdf5_group(g)\n        if len(weight_values) != len(symbolic_weights):\n            raise ValueError(\n                f\"Weight count mismatch for layer #{k} (named {layer.name} in \"\n                f\"the current model, {name} in the save file). \"\n                f\"Layer expects {len(symbolic_weights)} weight(s). Received \"\n                f\"{len(weight_values)} saved weight(s)\"\n            )\n        _set_weights(\n            layer,\n            symbolic_weights,\n            weight_values,\n            skip_mismatch=skip_mismatch,\n            name=f\"layer #{k} (named {layer.name})\",\n        )\n\n    if \"top_level_model_weights\" in f:\n        symbolic_weights = list(\n            # model.weights\n            v\n            for v in model._trainable_variables + model._non_trainable_variables\n            if v in model.weights\n        )\n        weight_values = load_subset_weights_from_hdf5_group(\n            f[\"top_level_model_weights\"]\n        )\n        if len(weight_values) != len(symbolic_weights):\n            raise ValueError(\n                \"Weight count mismatch for top-level weights when loading \"\n                \"weights from file. 
\"\n                f\"Model expects {len(symbolic_weights)} top-level weight(s). \"\n                f\"Received {len(weight_values)} saved top-level weight(s)\"\n            )\n        _set_weights(\n            model,\n            symbolic_weights,\n            weight_values,\n            skip_mismatch=skip_mismatch,\n            name=\"top-level model\",\n        )\n\n\ndef _set_weights(\n    instance, symbolic_weights, weight_values, name, skip_mismatch=False\n):\n    \"\"\"Safely set weights into a model or a layer.\n\n    Args:\n        instance: Model or layer instance.\n        symbolic_weights: symbolic tensors representing\n                        the weights of the variables to load.\n        weight_values: values of the weights to load.\n        skip_mismatch: Boolean, whether to skip loading of weights\n            where there is a mismatch in the shape of the weights.\n        name: name used to identify the group.\n\n    Raises:\n        ValueError: in case of mismatch between provided\n            model/layer and weights.\n    \"\"\"\n    for i, weight_value in enumerate(weight_values):\n        expected_shape = symbolic_weights[i].shape\n        received_shape = weight_value.shape\n        if expected_shape != received_shape:\n            if skip_mismatch:\n                warnings.warn(\n                    f\"Skipping loading weights for {name} \"\n                    f\"due to mismatch in shape for \"\n                    f\"weight {symbolic_weights[i].path}. \"\n                    f\"Weight expects shape {expected_shape}. \"\n                    \"Received saved weight \"\n                    f\"with shape {received_shape}\",\n                    stacklevel=2,\n                )\n                continue\n            raise ValueError(\n                f\"Shape mismatch in {name} \"\n                f\"for weight {symbolic_weights[i].path}. \"\n                f\"Weight expects shape {expected_shape}. 
\"\n                \"Received saved weight \"\n                f\"with shape {received_shape}\"\n            )\n        symbolic_weights[i].assign(weight_value)\n\n    if hasattr(instance, \"finalize_state\") and symbolic_weights:\n        instance.finalize_state()\n\n\ndef load_weights_from_hdf5_group_by_name(f, model, skip_mismatch=False):\n    \"\"\"Implements name-based weight loading (instead of topological loading).\n\n    Layers that have no matching name are skipped.\n\n    Args:\n        f: A pointer to a HDF5 group.\n        model: Model instance.\n        skip_mismatch: Boolean, whether to skip loading of layers\n            where there is a mismatch in the number of weights,\n            or a mismatch in the shape of the weights.\n\n    Raises:\n        ValueError: in case of mismatch between provided layers\n            and weights file and skip_match=False.\n    \"\"\"\n    if \"keras_version\" in f.attrs:\n        original_keras_version = f.attrs[\"keras_version\"]\n        if hasattr(original_keras_version, \"decode\"):\n            original_keras_version = original_keras_version.decode(\"utf8\")\n    else:\n        original_keras_version = \"1\"\n    if \"backend\" in f.attrs:\n        original_backend = f.attrs[\"backend\"]\n        if hasattr(original_backend, \"decode\"):\n            original_backend = original_backend.decode(\"utf8\")\n    else:\n        original_backend = None\n\n    # New file format.\n    layer_names = load_attributes_from_hdf5_group(f, \"layer_names\")\n\n    # Reverse index of layer name to list of layers with name.\n    index = {}\n    for layer in model.layers:\n        if layer.name:\n            index.setdefault(layer.name, []).append(layer)\n\n    for k, name in enumerate(layer_names):\n        g = f[name]\n        weight_values = load_subset_weights_from_hdf5_group(g)\n        for layer in index.get(name, []):\n            symbolic_weights = _legacy_weights(layer)\n            if len(weight_values) != 
len(symbolic_weights):\n                if skip_mismatch:\n                    warnings.warn(\n                        f\"Skipping loading of weights for layer #{k} (named \"\n                        f\"{layer.name}) due to mismatch in number of weights. \"\n                        f\"Layer expects {len(symbolic_weights)} weight(s). \"\n                        f\"Received {len(weight_values)} saved weight(s)\",\n                        stacklevel=2,\n                    )\n                    continue\n                raise ValueError(\n                    f\"Weight count mismatch for layer #{k} \"\n                    f\"(named {layer.name}). \"\n                    f\"Layer expects {len(symbolic_weights)} weight(s). \"\n                    f\"Received {len(weight_values)} saved weight(s)\"\n                )\n            # Set values.\n            _set_weights(\n                layer,\n                symbolic_weights,\n                weight_values,\n                skip_mismatch=skip_mismatch,\n                name=f\"layer #{k} (named {layer.name})\",\n            )\n\n    if \"top_level_model_weights\" in f:\n        symbolic_weights = (\n            model._trainable_variables + model._non_trainable_variables\n        )\n        weight_values = load_subset_weights_from_hdf5_group(\n            f[\"top_level_model_weights\"]\n        )\n\n        if len(weight_values) != len(symbolic_weights):\n            if skip_mismatch:\n                warnings.warn(\n                    \"Skipping loading top-level weights for model due to \"\n                    \"mismatch in number of weights. \"\n                    f\"Model expects {len(symbolic_weights)} \"\n                    \"top-level weight(s). \"\n                    f\"Received {len(weight_values)} saved top-level weight(s)\",\n                    stacklevel=2,\n                )\n            else:\n                raise ValueError(\n                    \"Weight count mismatch for top-level weights of model. 
\"\n                    f\"Model expects {len(symbolic_weights)} \"\n                    \"top-level weight(s). \"\n                    f\"Received {len(weight_values)} saved top-level weight(s)\"\n                )\n        else:\n            _set_weights(\n                model,\n                symbolic_weights,\n                weight_values,\n                skip_mismatch=skip_mismatch,\n                name=\"top-level model\",\n            )\n\n\ndef load_subset_weights_from_hdf5_group(f):\n    \"\"\"Load layer weights of a model from hdf5.\n\n    Args:\n        f: A pointer to a HDF5 group.\n\n    Returns:\n        List of NumPy arrays of the weight values.\n\n    Raises:\n        ValueError: in case of mismatch between provided model\n            and weights file.\n    \"\"\"\n    weight_names = load_attributes_from_hdf5_group(f, \"weight_names\")\n    return [np.asarray(f[weight_name]) for weight_name in weight_names]\n\n\ndef load_optimizer_weights_from_hdf5_group(hdf5_group):\n    \"\"\"Load optimizer weights from a HDF5 group.\n\n    Args:\n        hdf5_group: A pointer to a HDF5 group.\n\n    Returns:\n        data: List of optimizer weight values.\n    \"\"\"\n    weights_group = hdf5_group[\"optimizer_weights\"]\n    optimizer_weight_names = load_attributes_from_hdf5_group(\n        weights_group, \"weight_names\"\n    )\n    return [\n        weights_group[weight_name] for weight_name in optimizer_weight_names\n    ]\n\n\ndef load_attributes_from_hdf5_group(group, name):\n    \"\"\"Loads attributes of the specified name from the HDF5 group.\n\n    This method deals with an inherent problem\n    of HDF5 file which is not able to store\n    data larger than HDF5_OBJECT_HEADER_LIMIT bytes.\n\n    Args:\n        group: A pointer to a HDF5 group.\n        name: A name of the attributes to load.\n\n    Returns:\n        data: Attributes data.\n    \"\"\"\n    if name in group.attrs:\n        data = [\n            n.decode(\"utf8\") if hasattr(n, 
\"decode\") else n\n            for n in group.attrs[name]\n        ]\n    else:\n        data = []\n        chunk_id = 0\n        while f\"{name}{chunk_id}\" in group.attrs:\n            data.extend(\n                [\n                    n.decode(\"utf8\") if hasattr(n, \"decode\") else n\n                    for n in group.attrs[f\"{name}{chunk_id}\"]\n                ]\n            )\n            chunk_id += 1\n    return data\n\n\ndef _legacy_weights(layer):\n    \"\"\"Legacy weight order converter.\n\n    For legacy reason, the layer.weights was in the order of\n    [self.trainable_weights + self.non_trainable_weights], and this order was\n    used for preserving the weights in h5 format. The new order of layer.weights\n    are the same as layer.get_weights() which is more intuitive for user. To\n    keep supporting the existing saved h5 file, this method should be used to\n    save/load weights.\n\n    Args:\n        layer: a `Model` or `Layer` instance.\n\n    Returns:\n        A list of variables with the legacy weight order.\n    \"\"\"\n    return layer.trainable_weights + layer.non_trainable_weights\n"
  },
  {
    "path": "keras/src/legacy/saving/legacy_h5_format_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\n\nimport keras\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.legacy.saving import legacy_h5_format\nfrom keras.src.saving import object_registration\nfrom keras.src.saving import serialization_lib\n\n# TODO: more thorough testing. Correctness depends\n# on exact weight ordering for each layer, so we need\n# to test across all types of layers.\n\ntry:\n    import tf_keras\nexcept ImportError:\n    tf_keras = None\n\n\ndef get_sequential_model(keras):\n    return keras.Sequential(\n        [\n            keras.layers.Input((3,), batch_size=2),\n            keras.layers.Dense(4, activation=\"relu\"),\n            keras.layers.BatchNormalization(\n                moving_mean_initializer=\"uniform\", gamma_initializer=\"uniform\"\n            ),\n            keras.layers.Dense(5, activation=\"softmax\"),\n        ]\n    )\n\n\ndef get_functional_model(keras):\n    inputs = keras.Input((3,), batch_size=2)\n    x = keras.layers.Dense(4, activation=\"relu\")(inputs)\n    residual = x\n    x = keras.layers.BatchNormalization(\n        moving_mean_initializer=\"uniform\", gamma_initializer=\"uniform\"\n    )(x)\n    x = keras.layers.Dense(4, activation=\"relu\")(x)\n    x = keras.layers.add([x, residual])\n    outputs = keras.layers.Dense(5, activation=\"softmax\")(x)\n    return keras.Model(inputs, outputs)\n\n\ndef get_subclassed_model(keras):\n    class MyModel(keras.Model):\n        def __init__(self, **kwargs):\n            super().__init__(**kwargs)\n            self.dense_1 = keras.layers.Dense(3, activation=\"relu\")\n            self.dense_2 = keras.layers.Dense(1, activation=\"sigmoid\")\n\n            # top_level_model_weights\n            self.bias = self.add_weight(\n                name=\"bias\",\n                shape=[1],\n                trainable=True,\n                initializer=keras.initializers.Zeros(),\n           
 )\n\n        def call(self, x):\n            x = self.dense_1(x)\n            x = self.dense_2(x)\n\n            # top_level_model_weights\n            x += ops.cast(self.bias, x.dtype)\n            return x\n\n    model = MyModel()\n    model(np.random.random((2, 3)))\n    return model\n\n\n@pytest.mark.requires_trainable_backend\n@pytest.mark.skipif(tf_keras is None, reason=\"Test requires tf_keras\")\nclass LegacyH5WeightsTest(testing.TestCase):\n    def _check_reloading_weights(self, ref_input, model, tf_keras_model):\n        ref_output = tf_keras_model(ref_input)\n        initial_weights = model.get_weights()\n        # Check weights only file\n        temp_filepath = os.path.join(self.get_temp_dir(), \"weights.h5\")\n        tf_keras_model.save_weights(temp_filepath)\n        model.load_weights(temp_filepath)\n        output = model(ref_input)\n        self.assertAllClose(ref_output, output, atol=1e-5)\n        model.set_weights(initial_weights)\n        model.load_weights(temp_filepath)\n        output = model(ref_input)\n        self.assertAllClose(ref_output, output, atol=1e-5)\n\n    def test_sequential_model_weights(self):\n        model = get_sequential_model(keras)\n        tf_keras_model = get_sequential_model(tf_keras)\n        ref_input = np.random.random((2, 3))\n        self._check_reloading_weights(ref_input, model, tf_keras_model)\n\n    def test_functional_model_weights(self):\n        model = get_functional_model(keras)\n        tf_keras_model = get_functional_model(tf_keras)\n        ref_input = np.random.random((2, 3))\n        self._check_reloading_weights(ref_input, model, tf_keras_model)\n\n    def test_subclassed_model_weights(self):\n        model = get_subclassed_model(keras)\n        tf_keras_model = get_subclassed_model(tf_keras)\n        ref_input = np.random.random((2, 3))\n        self._check_reloading_weights(ref_input, model, tf_keras_model)\n\n\n@pytest.mark.requires_trainable_backend\nclass 
LegacyH5WholeModelTest(testing.TestCase):\n    def _check_reloading_model(self, ref_input, model):\n        # Whole model file\n        ref_output = model(ref_input)\n        temp_filepath = os.path.join(self.get_temp_dir(), \"model.h5\")\n        legacy_h5_format.save_model_to_hdf5(model, temp_filepath)\n        loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath)\n        output = loaded(ref_input)\n        self.assertAllClose(ref_output, output, atol=1e-5)\n\n    def test_sequential_model(self):\n        model = get_sequential_model(keras)\n        ref_input = np.random.random((2, 3))\n        self._check_reloading_model(ref_input, model)\n\n    def test_functional_model(self):\n        model = get_functional_model(keras)\n        ref_input = np.random.random((2, 3))\n        self._check_reloading_model(ref_input, model)\n\n    def test_compiled_model_with_various_layers(self):\n        model = models.Sequential()\n        model.add(layers.Dense(2, input_shape=(3,)))\n        model.add(layers.RepeatVector(3))\n        model.add(layers.TimeDistributed(layers.Dense(3)))\n\n        model.compile(optimizer=\"rmsprop\", loss=\"mean_squared_error\")\n        ref_input = np.random.random((1, 3))\n        self._check_reloading_model(ref_input, model)\n\n    def test_saving_lambda(self):\n        mean = ops.random.uniform((4, 2, 3))\n        std = ops.abs(ops.random.uniform((4, 2, 3))) + 1e-5\n        inputs = layers.Input(shape=(4, 2, 3))\n        output = layers.Lambda(\n            lambda image, mu, std: (image - mu) / std,\n            arguments={\"mu\": mean, \"std\": std},\n        )(inputs)\n        model = models.Model(inputs, output)\n        model.compile(\n            loss=\"mean_squared_error\", optimizer=\"sgd\", metrics=[\"acc\"]\n        )\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"lambda_model.h5\")\n        legacy_h5_format.save_model_to_hdf5(model, temp_filepath)\n\n        with self.assertRaisesRegex(ValueError, \"arbitrary 
code execution\"):\n            legacy_h5_format.load_model_from_hdf5(temp_filepath)\n\n        loaded = legacy_h5_format.load_model_from_hdf5(\n            temp_filepath, safe_mode=False\n        )\n        self.assertAllClose(mean, loaded.layers[1].arguments[\"mu\"])\n        self.assertAllClose(std, loaded.layers[1].arguments[\"std\"])\n\n    def test_saving_include_optimizer_false(self):\n        model = models.Sequential()\n        model.add(layers.Dense(1))\n        model.compile(\"adam\", loss=\"mean_squared_error\")\n        x, y = np.ones((10, 10)), np.ones((10, 1))\n        model.fit(x, y)\n        ref_output = model(x)\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"model.h5\")\n        legacy_h5_format.save_model_to_hdf5(\n            model, temp_filepath, include_optimizer=False\n        )\n        loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath)\n        output = loaded(x)\n\n        # Assert that optimizer does not exist in loaded model\n        with self.assertRaises(AttributeError):\n            _ = loaded.optimizer\n\n        # Compare output\n        self.assertAllClose(ref_output, output, atol=1e-5)\n\n    def test_custom_sequential_registered_no_scope(self):\n        @object_registration.register_keras_serializable(package=\"my_package\")\n        class MyDense(layers.Dense):\n            def __init__(self, units, **kwargs):\n                super().__init__(units, **kwargs)\n\n        inputs = layers.Input(shape=[1])\n        custom_layer = MyDense(1)\n        model = models.Sequential(layers=[inputs, custom_layer])\n\n        ref_input = np.array([5])\n        self._check_reloading_model(ref_input, model)\n\n    def test_custom_functional_registered_no_scope(self):\n        @object_registration.register_keras_serializable(package=\"my_package\")\n        class MyDense(layers.Dense):\n            def __init__(self, units, **kwargs):\n                super().__init__(units, **kwargs)\n\n        inputs = 
layers.Input(shape=[1])\n        outputs = MyDense(1)(inputs)\n        model = models.Model(inputs, outputs)\n\n        ref_input = np.array([5])\n        self._check_reloading_model(ref_input, model)\n\n    def test_custom_function_from_custom_objects_no_registration(self):\n        from keras.src.saving import custom_object_scope\n\n        def custom_fn(x):\n            return x * 2.0\n\n        class MyDense(layers.Dense):\n            def call(self, inputs):\n                return custom_fn(super().call(inputs))\n\n        inputs = layers.Input(shape=[1])\n        outputs = MyDense(1)(inputs)\n        model = models.Model(inputs, outputs)\n\n        ref_input = np.array([[5.0]])\n        with custom_object_scope({\"MyDense\": MyDense, \"custom_fn\": custom_fn}):\n            self._check_reloading_model(ref_input, model)\n\n    def test_nested_layers(self):\n        class MyLayer(layers.Layer):\n            def __init__(self, sublayers, **kwargs):\n                super().__init__(**kwargs)\n                self.sublayers = sublayers\n\n            def call(self, x):\n                prev_input = x\n                for layer in self.sublayers:\n                    prev_input = layer(prev_input)\n                return prev_input\n\n            def get_config(self):\n                config = super().get_config()\n                config[\"sublayers\"] = serialization_lib.serialize_keras_object(\n                    self.sublayers\n                )\n                return config\n\n            @classmethod\n            def from_config(cls, config):\n                config[\"sublayers\"] = (\n                    serialization_lib.deserialize_keras_object(\n                        config[\"sublayers\"]\n                    )\n                )\n                return cls(**config)\n\n        @object_registration.register_keras_serializable(package=\"Foo\")\n        class RegisteredSubLayer(layers.Layer):\n            pass\n\n        layer = MyLayer(\n            
[\n                layers.Dense(2, name=\"MyDense\"),\n                RegisteredSubLayer(name=\"MySubLayer\"),\n            ]\n        )\n        model = models.Sequential([layer])\n        with self.subTest(\"test_JSON\"):\n            from keras.src.models.model import model_from_json\n\n            model_json = model.to_json()\n            self.assertIn(\"Foo>RegisteredSubLayer\", model_json)\n\n            loaded_model = model_from_json(\n                model_json, custom_objects={\"MyLayer\": MyLayer}\n            )\n            loaded_layer = loaded_model.layers[0]\n\n            self.assertIsInstance(loaded_layer.sublayers[0], layers.Dense)\n            self.assertEqual(loaded_layer.sublayers[0].name, \"MyDense\")\n            self.assertIsInstance(loaded_layer.sublayers[1], RegisteredSubLayer)\n            self.assertEqual(loaded_layer.sublayers[1].name, \"MySubLayer\")\n\n        with self.subTest(\"test_H5\"):\n            temp_filepath = os.path.join(self.get_temp_dir(), \"model.h5\")\n            legacy_h5_format.save_model_to_hdf5(model, temp_filepath)\n            loaded_model = legacy_h5_format.load_model_from_hdf5(\n                temp_filepath, custom_objects={\"MyLayer\": MyLayer}\n            )\n            loaded_layer = loaded_model.layers[0]\n\n            self.assertIsInstance(loaded_layer.sublayers[0], layers.Dense)\n            self.assertEqual(loaded_layer.sublayers[0].name, \"MyDense\")\n            self.assertIsInstance(loaded_layer.sublayers[1], RegisteredSubLayer)\n            self.assertEqual(loaded_layer.sublayers[1].name, \"MySubLayer\")\n\n    def test_model_loading_with_axis_arg(self):\n        input1 = layers.Input(shape=(1, 4), name=\"input1\")\n        input2 = layers.Input(shape=(1, 4), name=\"input2\")\n        concat1 = layers.Concatenate(axis=1)([input1, input2])\n        output = layers.Dense(1, activation=\"sigmoid\")(concat1)\n        model = models.Model(inputs=[input1, input2], outputs=output)\n        
model.compile(\n            optimizer=\"adam\", loss=\"binary_crossentropy\", metrics=[\"accuracy\"]\n        )\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"model_with_axis_arg.h5\"\n        )\n        legacy_h5_format.save_model_to_hdf5(model, temp_filepath)\n        legacy_h5_format.load_model_from_hdf5(temp_filepath)\n\n\n@pytest.mark.requires_trainable_backend\n@pytest.mark.skipif(tf_keras is None, reason=\"Test requires tf_keras\")\nclass LegacyH5BackwardsCompatTest(testing.TestCase):\n    def _check_reloading_model(self, ref_input, model, tf_keras_model):\n        # Whole model file\n        ref_output = tf_keras_model(ref_input)\n        temp_filepath = os.path.join(self.get_temp_dir(), \"model.h5\")\n        tf_keras_model.save(temp_filepath)\n        loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath)\n        output = loaded(ref_input)\n        self.assertAllClose(ref_output, output, atol=1e-5)\n\n    def test_sequential_model(self):\n        model = get_sequential_model(keras)\n        tf_keras_model = get_sequential_model(tf_keras)\n        ref_input = np.random.random((2, 3))\n        self._check_reloading_model(ref_input, model, tf_keras_model)\n\n    def test_functional_model(self):\n        tf_keras_model = get_functional_model(tf_keras)\n        model = get_functional_model(keras)\n        ref_input = np.random.random((2, 3))\n        self._check_reloading_model(ref_input, model, tf_keras_model)\n\n    def test_compiled_model_with_various_layers(self):\n        model = models.Sequential()\n        model.add(layers.Dense(2, input_shape=(3,)))\n        model.add(layers.RepeatVector(3))\n        model.add(layers.TimeDistributed(layers.Dense(3)))\n        model.compile(optimizer=\"rmsprop\", loss=\"mse\")\n\n        tf_keras_model = tf_keras.Sequential()\n        tf_keras_model.add(tf_keras.layers.Dense(2, input_shape=(3,)))\n        tf_keras_model.add(tf_keras.layers.RepeatVector(3))\n        
tf_keras_model.add(\n            tf_keras.layers.TimeDistributed(tf_keras.layers.Dense(3))\n        )\n        tf_keras_model.compile(optimizer=\"rmsprop\", loss=\"mean_squared_error\")\n\n        ref_input = np.random.random((1, 3))\n        self._check_reloading_model(ref_input, model, tf_keras_model)\n\n    def test_saving_lambda(self):\n        mean = np.random.random((4, 2, 3))\n        std = np.abs(np.random.random((4, 2, 3))) + 1e-5\n        inputs = tf_keras.layers.Input(shape=(4, 2, 3))\n        output = tf_keras.layers.Lambda(\n            lambda image, mu, std: (image - mu) / std,\n            arguments={\"mu\": mean, \"std\": std},\n            output_shape=inputs.shape,\n        )(inputs)\n        tf_keras_model = tf_keras.Model(inputs, output)\n        tf_keras_model.compile(\n            loss=\"mean_squared_error\", optimizer=\"sgd\", metrics=[\"acc\"]\n        )\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"lambda_model.h5\")\n        tf_keras_model.save(temp_filepath)\n\n        with self.assertRaisesRegex(ValueError, \"arbitrary code execution\"):\n            legacy_h5_format.load_model_from_hdf5(temp_filepath)\n\n        loaded = legacy_h5_format.load_model_from_hdf5(\n            temp_filepath, safe_mode=False\n        )\n        self.assertAllClose(mean, loaded.layers[1].arguments[\"mu\"])\n        self.assertAllClose(std, loaded.layers[1].arguments[\"std\"])\n\n    def test_saving_include_optimizer_false(self):\n        tf_keras_model = tf_keras.Sequential()\n        tf_keras_model.add(tf_keras.layers.Dense(1))\n        tf_keras_model.compile(\"adam\", loss=\"mse\")\n        x, y = np.ones((10, 10)), np.ones((10, 1))\n        tf_keras_model.fit(x, y)\n        ref_output = tf_keras_model(x)\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"model.h5\")\n        tf_keras_model.save(temp_filepath, include_optimizer=False)\n        loaded = legacy_h5_format.load_model_from_hdf5(temp_filepath)\n        output = 
loaded(x)\n\n        # Assert that optimizer does not exist in loaded model\n        with self.assertRaises(AttributeError):\n            _ = loaded.optimizer\n\n        # Compare output\n        self.assertAllClose(ref_output, output, atol=1e-5)\n\n    @pytest.mark.skipif(tf_keras is None, reason=\"Test requires tf_keras\")\n    def test_custom_sequential_registered_no_scope(self):\n        @tf_keras.saving.register_keras_serializable(package=\"my_package\")\n        class MyDense(tf_keras.layers.Dense):\n            def __init__(self, units, **kwargs):\n                super().__init__(units, **kwargs)\n\n        inputs = tf_keras.layers.Input(shape=[1])\n        custom_layer = MyDense(1)\n        tf_keras_model = tf_keras.Sequential(layers=[inputs, custom_layer])\n\n        # Re-implement and re-register in Keras 3\n        @object_registration.register_keras_serializable(package=\"my_package\")\n        class MyDense(layers.Dense):\n            def __init__(self, units, **kwargs):\n                super().__init__(units, **kwargs)\n\n        inputs = layers.Input(shape=[1])\n        custom_layer = MyDense(1)\n        model = models.Sequential(layers=[inputs, custom_layer])\n\n        ref_input = np.array([5])\n        self._check_reloading_model(ref_input, model, tf_keras_model)\n\n    @pytest.mark.skipif(tf_keras is None, reason=\"Test requires tf_keras\")\n    def test_custom_functional_registered_no_scope(self):\n        @tf_keras.saving.register_keras_serializable(package=\"my_package\")\n        class MyDense(tf_keras.layers.Dense):\n            def __init__(self, units, **kwargs):\n                super().__init__(units, **kwargs)\n\n        inputs = tf_keras.layers.Input(shape=[1])\n        outputs = MyDense(1)(inputs)\n        tf_keras_model = tf_keras.Model(inputs, outputs)\n\n        # Re-implement and re-register in Keras 3\n        @object_registration.register_keras_serializable(package=\"my_package\")\n        class MyDense(layers.Dense):\n        
    def __init__(self, units, **kwargs):\n                super().__init__(units, **kwargs)\n\n        inputs = layers.Input(shape=[1])\n        outputs = MyDense(1)(inputs)\n        model = models.Model(inputs, outputs)\n\n        ref_input = np.array([5])\n        self._check_reloading_model(ref_input, model, tf_keras_model)\n\n    def test_nested_layers(self):\n        class MyLayer(tf_keras.layers.Layer):\n            def __init__(self, sublayers, **kwargs):\n                super().__init__(**kwargs)\n                self.sublayers = sublayers\n\n            def call(self, x):\n                prev_input = x\n                for layer in self.sublayers:\n                    prev_input = layer(prev_input)\n                return prev_input\n\n            def get_config(self):\n                config = super().get_config()\n                config[\"sublayers\"] = tf_keras.saving.serialize_keras_object(\n                    self.sublayers\n                )\n                return config\n\n            @classmethod\n            def from_config(cls, config):\n                config[\"sublayers\"] = tf_keras.saving.deserialize_keras_object(\n                    config[\"sublayers\"]\n                )\n                return cls(**config)\n\n        @tf_keras.saving.register_keras_serializable(package=\"Foo\")\n        class RegisteredSubLayer(layers.Layer):\n            def call(self, x):\n                return x\n\n        layer = MyLayer(\n            [\n                tf_keras.layers.Dense(2, name=\"MyDense\"),\n                RegisteredSubLayer(name=\"MySubLayer\"),\n            ]\n        )\n        tf_keras_model = tf_keras.Sequential([layer])\n\n        x = np.random.random((4, 2))\n        ref_output = tf_keras_model(x)\n\n        # Save TF Keras model to H5 file\n        temp_filepath = os.path.join(self.get_temp_dir(), \"model.h5\")\n        tf_keras_model.save(temp_filepath)\n\n        # Re-implement in Keras 3\n        class MyLayer(layers.Layer):\n 
           def __init__(self, sublayers, **kwargs):\n                super().__init__(**kwargs)\n                self.sublayers = sublayers\n\n            def call(self, x):\n                prev_input = x\n                for layer in self.sublayers:\n                    prev_input = layer(prev_input)\n                return prev_input\n\n            def get_config(self):\n                config = super().get_config()\n                config[\"sublayers\"] = serialization_lib.serialize_keras_object(\n                    self.sublayers\n                )\n                return config\n\n            @classmethod\n            def from_config(cls, config):\n                config[\"sublayers\"] = (\n                    serialization_lib.deserialize_keras_object(\n                        config[\"sublayers\"]\n                    )\n                )\n                return cls(**config)\n\n        # Re-implement and re-register in Keras 3\n        @object_registration.register_keras_serializable(package=\"Foo\")\n        class RegisteredSubLayer(layers.Layer):\n            def call(self, x):\n                return x\n\n        # Load in Keras 3\n        loaded_model = legacy_h5_format.load_model_from_hdf5(\n            temp_filepath, custom_objects={\"MyLayer\": MyLayer}\n        )\n        loaded_layer = loaded_model.layers[0]\n        output = loaded_model(x)\n\n        # Ensure nested layer structure\n        self.assertIsInstance(loaded_layer.sublayers[0], layers.Dense)\n        self.assertEqual(loaded_layer.sublayers[0].name, \"MyDense\")\n        self.assertIsInstance(loaded_layer.sublayers[1], RegisteredSubLayer)\n        self.assertEqual(loaded_layer.sublayers[1].name, \"MySubLayer\")\n\n        # Compare output\n        self.assertAllClose(ref_output, output, atol=1e-5)\n\n\n@pytest.mark.requires_trainable_backend\nclass DirectoryCreationTest(testing.TestCase):\n    def test_directory_creation_on_save(self):\n        \"\"\"Test if directory is created on 
model save.\"\"\"\n        model = get_sequential_model(keras)\n        nested_dirpath = os.path.join(\n            self.get_temp_dir(), \"dir1\", \"dir2\", \"dir3\"\n        )\n        filepath = os.path.join(nested_dirpath, \"model.h5\")\n        self.assertFalse(os.path.exists(nested_dirpath))\n        legacy_h5_format.save_model_to_hdf5(model, filepath)\n        self.assertTrue(os.path.exists(nested_dirpath))\n        loaded_model = legacy_h5_format.load_model_from_hdf5(filepath)\n        self.assertEqual(model.to_json(), loaded_model.to_json())\n"
  },
  {
    "path": "keras/src/legacy/saving/saving_options.py",
    "content": "import contextlib\n\nfrom keras.src.backend.common import global_state\n\n\n@contextlib.contextmanager\ndef keras_option_scope(use_legacy_config=True):\n    use_legacy_config_prev_value = global_state.get_global_attribute(\n        \"use_legacy_config\", None\n    )\n    global_state.set_global_attribute(\"use_legacy_config\", use_legacy_config)\n    try:\n        yield\n    finally:\n        global_state.set_global_attribute(\n            \"use_legacy_config\", use_legacy_config_prev_value\n        )\n"
  },
  {
    "path": "keras/src/legacy/saving/saving_utils.py",
    "content": "import threading\n\nfrom absl import logging\n\nfrom keras.src import backend\nfrom keras.src import losses\nfrom keras.src import metrics as metrics_module\nfrom keras.src import tree\nfrom keras.src.legacy.saving import serialization\nfrom keras.src.saving import object_registration\n\nMODULE_OBJECTS = threading.local()\n\n# Legacy lambda arguments not found in Keras 3\nLAMBDA_DEP_ARGS = (\n    \"module\",\n    \"function_type\",\n    \"output_shape_type\",\n    \"output_shape_module\",\n)\n\n\ndef model_from_config(config, custom_objects=None):\n    \"\"\"Instantiates a Keras model from its config.\n\n    Args:\n        config: Configuration dictionary.\n        custom_objects: Optional dictionary mapping names\n            (strings) to custom classes or functions to be\n            considered during deserialization.\n\n    Returns:\n        A Keras model instance (uncompiled).\n\n    Raises:\n        TypeError: if `config` is not a dictionary.\n    \"\"\"\n    if isinstance(config, list):\n        raise TypeError(\n            \"`model_from_config` expects a dictionary, not a list. \"\n            f\"Received: config={config}. 
Did you mean to use \"\n            \"`Sequential.from_config(config)`?\"\n        )\n\n    global MODULE_OBJECTS\n\n    if not hasattr(MODULE_OBJECTS, \"ALL_OBJECTS\"):\n        from keras.src import layers\n        from keras.src import models\n\n        MODULE_OBJECTS.ALL_OBJECTS = layers.__dict__\n        MODULE_OBJECTS.ALL_OBJECTS[\"InputLayer\"] = layers.InputLayer\n        MODULE_OBJECTS.ALL_OBJECTS[\"Functional\"] = models.Functional\n        MODULE_OBJECTS.ALL_OBJECTS[\"Model\"] = models.Model\n        MODULE_OBJECTS.ALL_OBJECTS[\"Sequential\"] = models.Sequential\n\n    batch_input_shape = config[\"config\"].pop(\"batch_input_shape\", None)\n    if batch_input_shape is not None:\n        if config[\"class_name\"] == \"InputLayer\":\n            config[\"config\"][\"batch_shape\"] = batch_input_shape\n        else:\n            config[\"config\"][\"input_shape\"] = batch_input_shape\n\n    axis = config[\"config\"].pop(\"axis\", None)\n    if axis is not None:\n        if isinstance(axis, list) and len(axis) == 1:\n            config[\"config\"][\"axis\"] = int(axis[0])\n        elif isinstance(axis, (int, float)):\n            config[\"config\"][\"axis\"] = int(axis)\n\n    # Handle backwards compatibility for Keras lambdas\n    if config[\"class_name\"] == \"Lambda\":\n        for dep_arg in LAMBDA_DEP_ARGS:\n            _ = config[\"config\"].pop(dep_arg, None)\n        function_config = config[\"config\"][\"function\"]\n        if isinstance(function_config, list):\n            function_dict = {\"class_name\": \"__lambda__\", \"config\": {}}\n            function_dict[\"config\"][\"code\"] = function_config[0]\n            function_dict[\"config\"][\"defaults\"] = function_config[1]\n            function_dict[\"config\"][\"closure\"] = function_config[2]\n            config[\"config\"][\"function\"] = function_dict\n\n    return serialization.deserialize_keras_object(\n        config,\n        module_objects=MODULE_OBJECTS.ALL_OBJECTS,\n        
custom_objects=custom_objects,\n        printable_module_name=\"layer\",\n    )\n\n\ndef model_metadata(model, include_optimizer=True, require_config=True):\n    \"\"\"Returns a dictionary containing the model metadata.\"\"\"\n    from keras.src import __version__ as keras_version\n\n    model_config = {\"class_name\": model.__class__.__name__}\n    try:\n        model_config[\"config\"] = model.get_config()\n    except NotImplementedError as e:\n        if require_config:\n            raise e\n\n    metadata = dict(\n        keras_version=str(keras_version),\n        backend=backend.backend(),\n        model_config=model_config,\n    )\n    if getattr(model, \"optimizer\", False) and include_optimizer:\n        if model.compiled:\n            training_config = model._compile_config.config\n            training_config.pop(\"optimizer\", None)  # Handled separately.\n            metadata[\"training_config\"] = _serialize_nested_config(\n                training_config\n            )\n            optimizer_config = {\n                \"class_name\": object_registration.get_registered_name(\n                    model.optimizer.__class__\n                ),\n                \"config\": model.optimizer.get_config(),\n            }\n            metadata[\"training_config\"][\"optimizer_config\"] = optimizer_config\n    return metadata\n\n\ndef compile_args_from_training_config(training_config, custom_objects=None):\n    \"\"\"Return model.compile arguments from training config.\"\"\"\n    if custom_objects is None:\n        custom_objects = {}\n\n    with object_registration.CustomObjectScope(custom_objects):\n        from keras.src import optimizers\n\n        optimizer_config = training_config[\"optimizer_config\"]\n        optimizer = optimizers.deserialize(optimizer_config)\n        # Ensure backwards compatibility for optimizers in legacy H5 files\n        optimizer = _resolve_compile_arguments_compat(\n            optimizer, optimizer_config, optimizers\n        
)\n\n        # Recover losses.\n        loss = None\n        loss_config = training_config.get(\"loss\", None)\n        if loss_config is not None:\n            loss = _deserialize_nested_config(losses.deserialize, loss_config)\n            # Ensure backwards compatibility for losses in legacy H5 files\n            loss = _resolve_compile_arguments_compat(loss, loss_config, losses)\n\n        # Recover metrics.\n        metrics = None\n        metrics_config = training_config.get(\"metrics\", None)\n        if metrics_config is not None:\n            metrics = _deserialize_nested_config(\n                _deserialize_metric, metrics_config\n            )\n            # Ensure backwards compatibility for metrics in legacy H5 files\n            metrics = _resolve_compile_arguments_compat(\n                metrics, metrics_config, metrics_module\n            )\n\n        # Recover weighted metrics.\n        weighted_metrics = None\n        weighted_metrics_config = training_config.get(\"weighted_metrics\", None)\n        if weighted_metrics_config is not None:\n            weighted_metrics = _deserialize_nested_config(\n                _deserialize_metric, weighted_metrics_config\n            )\n\n        loss_weights = training_config[\"loss_weights\"]\n\n    return dict(\n        optimizer=optimizer,\n        loss=loss,\n        metrics=metrics,\n        weighted_metrics=weighted_metrics,\n        loss_weights=loss_weights,\n    )\n\n\ndef _serialize_nested_config(config):\n    \"\"\"Serializes a nested structure of Keras objects.\"\"\"\n\n    def _serialize_fn(obj):\n        if callable(obj):\n            return serialization.serialize_keras_object(obj)\n        return obj\n\n    return tree.map_structure(_serialize_fn, config)\n\n\ndef _deserialize_nested_config(deserialize_fn, config):\n    \"\"\"Deserializes arbitrary Keras `config` using `deserialize_fn`.\"\"\"\n\n    def _is_single_object(obj):\n        if isinstance(obj, dict) and \"class_name\" in obj:\n     
       return True  # Serialized Keras object.\n        if isinstance(obj, str):\n            return True  # Serialized function or string.\n        return False\n\n    if config is None:\n        return None\n    if _is_single_object(config):\n        return deserialize_fn(config)\n    elif isinstance(config, dict):\n        return {\n            k: _deserialize_nested_config(deserialize_fn, v)\n            for k, v in config.items()\n        }\n    elif isinstance(config, (tuple, list)):\n        return [\n            _deserialize_nested_config(deserialize_fn, obj) for obj in config\n        ]\n\n    raise ValueError(\n        \"Saved configuration not understood. Configuration should be a \"\n        f\"dictionary, string, tuple or list. Received: config={config}.\"\n    )\n\n\ndef _deserialize_metric(metric_config):\n    \"\"\"Deserialize metrics, leaving special strings untouched.\"\"\"\n    if metric_config in [\"accuracy\", \"acc\", \"crossentropy\", \"ce\"]:\n        # Do not deserialize accuracy and cross-entropy strings as we have\n        # special case handling for these in compile, based on model output\n        # shape.\n        return metric_config\n    return metrics_module.deserialize(metric_config)\n\n\ndef _resolve_compile_arguments_compat(obj, obj_config, module):\n    \"\"\"Resolves backwards compatibility issues with training config arguments.\n\n    This helper function accepts built-in Keras modules such as optimizers,\n    losses, and metrics to ensure an object being deserialized is compatible\n    with Keras 3 built-ins. 
For legacy H5 files saved within Keras 3,\n    this does nothing.\n    \"\"\"\n    if isinstance(obj, str) and obj not in module.ALL_OBJECTS_DICT:\n        obj = module.get(obj_config[\"config\"][\"name\"])\n    return obj\n\n\ndef try_build_compiled_arguments(model):\n    try:\n        if not model.compiled_loss.built:\n            model.compiled_loss.build(model.outputs)\n        if not model.compiled_metrics.built:\n            model.compiled_metrics.build(model.outputs, model.outputs)\n    except Exception:\n        logging.warning(\n            \"Compiled the loaded model, but the compiled metrics have \"\n            \"yet to be built. `model.compiled_metrics` will be empty \"\n            \"until you train or evaluate the model.\"\n        )\n"
  },
  {
    "path": "keras/src/legacy/saving/serialization.py",
    "content": "\"\"\"Legacy serialization logic for Keras models.\"\"\"\n\nimport contextlib\nimport inspect\nimport threading\nimport weakref\n\n# isort: off\nfrom keras.src.api_export import keras_export\nfrom keras.src.saving import object_registration\n\n# Flag that determines whether to skip the NotImplementedError when calling\n# get_config in custom models and layers. This is only enabled when saving to\n# SavedModel, when the config isn't required.\n_SKIP_FAILED_SERIALIZATION = False\n# If a layer does not have a defined config, then the returned config will be a\n# dictionary with the below key.\n_LAYER_UNDEFINED_CONFIG_KEY = \"layer was saved without config\"\n\n# Store a unique, per-object ID for shared objects.\n#\n# We store a unique ID for each object so that we may, at loading time,\n# re-create the network properly.  Without this ID, we would have no way of\n# determining whether a config is a description of a new object that\n# should be created or is merely a reference to an already-created object.\nSHARED_OBJECT_KEY = \"shared_object_id\"\n\nSHARED_OBJECT_DISABLED = threading.local()\nSHARED_OBJECT_LOADING = threading.local()\nSHARED_OBJECT_SAVING = threading.local()\n\n\n# Attributes on the threadlocal variable must be set per-thread, thus we\n# cannot initialize these globally. 
Instead, we have accessor functions with\n# default values.\ndef _shared_object_disabled():\n    \"\"\"Get whether shared object handling is disabled in a threadsafe manner.\"\"\"\n    return getattr(SHARED_OBJECT_DISABLED, \"disabled\", False)\n\n\ndef _shared_object_loading_scope():\n    \"\"\"Get the current shared object loading scope in a threadsafe manner.\"\"\"\n    return getattr(SHARED_OBJECT_LOADING, \"scope\", NoopLoadingScope())\n\n\ndef _shared_object_saving_scope():\n    \"\"\"Get the current shared object saving scope in a threadsafe manner.\"\"\"\n    return getattr(SHARED_OBJECT_SAVING, \"scope\", None)\n\n\nclass DisableSharedObjectScope:\n    \"\"\"A context manager for disabling handling of shared objects.\n\n    Disables shared object handling for both saving and loading.\n\n    Created primarily for use with `clone_model`, which does extra surgery that\n    is incompatible with shared objects.\n    \"\"\"\n\n    def __enter__(self):\n        SHARED_OBJECT_DISABLED.disabled = True\n        self._orig_loading_scope = _shared_object_loading_scope()\n        self._orig_saving_scope = _shared_object_saving_scope()\n\n    def __exit__(self, *args, **kwargs):\n        SHARED_OBJECT_DISABLED.disabled = False\n        SHARED_OBJECT_LOADING.scope = self._orig_loading_scope\n        SHARED_OBJECT_SAVING.scope = self._orig_saving_scope\n\n\nclass NoopLoadingScope:\n    \"\"\"The default shared object loading scope. It does nothing.\n\n    Created to simplify serialization code that doesn't care about shared\n    objects (e.g. when serializing a single object).\n    \"\"\"\n\n    def get(self, unused_object_id):\n        return None\n\n    def set(self, object_id, obj):\n        pass\n\n\nclass SharedObjectLoadingScope:\n    \"\"\"A context manager for keeping track of loaded objects.\n\n    During the deserialization process, we may come across objects that are\n    shared across multiple layers. 
In order to accurately restore the network\n    structure to its original state, `SharedObjectLoadingScope` allows us to\n    re-use shared objects rather than cloning them.\n    \"\"\"\n\n    def __enter__(self):\n        if _shared_object_disabled():\n            return NoopLoadingScope()\n\n        global SHARED_OBJECT_LOADING\n        SHARED_OBJECT_LOADING.scope = self\n        self._obj_ids_to_obj = {}\n        return self\n\n    def get(self, object_id):\n        \"\"\"Given a shared object ID, returns a previously instantiated object.\n\n        Args:\n          object_id: shared object ID to use when attempting to find\n            already-loaded object.\n\n        Returns:\n          The object, if we've seen this ID before. Else, `None`.\n        \"\"\"\n        # Explicitly check for `None` internally to make external calling code a\n        # bit cleaner.\n        if object_id is None:\n            return\n        return self._obj_ids_to_obj.get(object_id)\n\n    def set(self, object_id, obj):\n        \"\"\"Stores an instantiated object for future lookup and sharing.\"\"\"\n        if object_id is None:\n            return\n        self._obj_ids_to_obj[object_id] = obj\n\n    def __exit__(self, *args, **kwargs):\n        global SHARED_OBJECT_LOADING\n        SHARED_OBJECT_LOADING.scope = NoopLoadingScope()\n\n\nclass SharedObjectConfig(dict):\n    \"\"\"A configuration container that keeps track of references.\n\n    `SharedObjectConfig` will automatically attach a shared object ID to any\n    configs which are referenced more than once, allowing for proper shared\n    object reconstruction at load time.\n\n    In most cases, it would be more proper to subclass something like\n    `collections.UserDict` or `collections.Mapping` rather than `dict` directly.\n    Unfortunately, python's json encoder does not support `Mapping`s. 
This is\n    important functionality to retain, since we are dealing with serialization.\n\n    We should be safe to subclass `dict` here, since we aren't actually\n    overriding any core methods, only augmenting with a new one for reference\n    counting.\n    \"\"\"\n\n    def __init__(self, base_config, object_id, **kwargs):\n        self.ref_count = 1\n        self.object_id = object_id\n        super().__init__(base_config, **kwargs)\n\n    def increment_ref_count(self):\n        # As soon as we've seen the object more than once, we want to attach the\n        # shared object ID. This allows us to only attach the shared object ID\n        # when it's strictly necessary, making backwards compatibility breakage\n        # less likely.\n        if self.ref_count == 1:\n            self[SHARED_OBJECT_KEY] = self.object_id\n        self.ref_count += 1\n\n\nclass SharedObjectSavingScope:\n    \"\"\"Keeps track of shared object configs when serializing.\"\"\"\n\n    def __enter__(self):\n        if _shared_object_disabled():\n            return None\n\n        global SHARED_OBJECT_SAVING\n\n        # Serialization can happen at a number of layers for a number of\n        # reasons.  We may end up with a case where we're opening a saving scope\n        # within another saving scope. 
In that case, we'd like to use the\n        # outermost scope available and ignore inner scopes, since there is not\n        # (yet) a reasonable use case for having these nested and distinct.\n        if _shared_object_saving_scope() is not None:\n            self._passthrough = True\n            return _shared_object_saving_scope()\n        else:\n            self._passthrough = False\n\n        SHARED_OBJECT_SAVING.scope = self\n        self._shared_objects_config = weakref.WeakKeyDictionary()\n        self._next_id = 0\n        return self\n\n    def get_config(self, obj):\n        \"\"\"Gets a `SharedObjectConfig` if one has already been seen for `obj`.\n\n        Args:\n          obj: The object for which to retrieve the `SharedObjectConfig`.\n\n        Returns:\n          The SharedObjectConfig for a given object, if already seen. Else,\n            `None`.\n        \"\"\"\n        try:\n            shared_object_config = self._shared_objects_config[obj]\n        except (TypeError, KeyError):\n            # If the object is unhashable (e.g. a subclass of\n            # `AbstractBaseClass` that has not overridden `__hash__`), a\n            # `TypeError` will be thrown.  We'll just continue on without shared\n            # object support.\n            return None\n        shared_object_config.increment_ref_count()\n        return shared_object_config\n\n    def create_config(self, base_config, obj):\n        \"\"\"Create a new SharedObjectConfig for a given object.\"\"\"\n        shared_object_config = SharedObjectConfig(base_config, self._next_id)\n        self._next_id += 1\n        try:\n            self._shared_objects_config[obj] = shared_object_config\n        except TypeError:\n            # If the object is unhashable (e.g. a subclass of\n            # `AbstractBaseClass` that has not overridden `__hash__`), a\n            # `TypeError` will be thrown.  
We'll just continue on without shared\n            # object support.\n            pass\n        return shared_object_config\n\n    def __exit__(self, *args, **kwargs):\n        if not getattr(self, \"_passthrough\", False):\n            global SHARED_OBJECT_SAVING\n            SHARED_OBJECT_SAVING.scope = None\n\n\ndef serialize_keras_class_and_config(\n    cls_name, cls_config, obj=None, shared_object_id=None\n):\n    \"\"\"Returns the serialization of the class with the given config.\"\"\"\n    base_config = {\"class_name\": cls_name, \"config\": cls_config}\n\n    # We call `serialize_keras_class_and_config` for some branches of the load\n    # path. In that case, we may already have a shared object ID we'd like to\n    # retain.\n    if shared_object_id is not None:\n        base_config[SHARED_OBJECT_KEY] = shared_object_id\n\n    # If we have an active `SharedObjectSavingScope`, check whether we've\n    # already serialized this config. If so, just use that config. This will\n    # store an extra ID field in the config, allowing us to re-create the shared\n    # object relationship at load time.\n    if _shared_object_saving_scope() is not None and obj is not None:\n        shared_object_config = _shared_object_saving_scope().get_config(obj)\n        if shared_object_config is None:\n            return _shared_object_saving_scope().create_config(base_config, obj)\n        return shared_object_config\n\n    return base_config\n\n\n@contextlib.contextmanager\ndef skip_failed_serialization():\n    global _SKIP_FAILED_SERIALIZATION\n    prev = _SKIP_FAILED_SERIALIZATION\n    try:\n        _SKIP_FAILED_SERIALIZATION = True\n        yield\n    finally:\n        _SKIP_FAILED_SERIALIZATION = prev\n\n\n@keras_export(\n    [\n        \"keras.legacy.saving.serialize_keras_object\",\n        \"keras.utils.legacy.serialize_keras_object\",\n    ]\n)\ndef serialize_keras_object(instance):\n    \"\"\"Serialize a Keras object into a JSON-compatible representation.\n\n    Calls 
to `serialize_keras_object` while underneath the\n    `SharedObjectSavingScope` context manager will cause any objects re-used\n    across multiple layers to be saved with a special shared object ID. This\n    allows the network to be re-created properly during deserialization.\n\n    Args:\n      instance: The object to serialize.\n\n    Returns:\n      A dict-like, JSON-compatible representation of the object's config.\n    \"\"\"\n\n    # _, instance = tf.__internal__.decorator.unwrap(instance)\n    instance = inspect.unwrap(instance)\n    if instance is None:\n        return None\n\n    if hasattr(instance, \"get_config\"):\n        name = object_registration.get_registered_name(instance.__class__)\n        try:\n            config = instance.get_config()\n        except NotImplementedError as e:\n            if _SKIP_FAILED_SERIALIZATION:\n                return serialize_keras_class_and_config(\n                    name, {_LAYER_UNDEFINED_CONFIG_KEY: True}\n                )\n            raise e\n        serialization_config = {}\n        for key, item in config.items():\n            if isinstance(item, str):\n                serialization_config[key] = item\n                continue\n\n            # Any object of a different type needs to be converted to string or\n            # dict for serialization (e.g. 
custom functions, custom classes)\n            try:\n                serialized_item = serialize_keras_object(item)\n                if isinstance(serialized_item, dict) and not isinstance(\n                    item, dict\n                ):\n                    serialized_item[\"__passive_serialization__\"] = True\n                serialization_config[key] = serialized_item\n            except ValueError:\n                serialization_config[key] = item\n\n        name = object_registration.get_registered_name(instance.__class__)\n        return serialize_keras_class_and_config(\n            name, serialization_config, instance\n        )\n    if hasattr(instance, \"__name__\"):\n        return object_registration.get_registered_name(instance)\n    raise ValueError(\n        f\"Cannot serialize {instance} because it doesn't implement \"\n        \"`get_config()`.\"\n    )\n\n\ndef class_and_config_for_serialized_keras_object(\n    config,\n    module_objects=None,\n    custom_objects=None,\n    printable_module_name=\"object\",\n):\n    \"\"\"Returns the class name and config for a serialized keras object.\"\"\"\n\n    if (\n        not isinstance(config, dict)\n        or \"class_name\" not in config\n        or \"config\" not in config\n    ):\n        raise ValueError(\n            f\"Improper config format for {config}. \"\n            \"Expecting a Python dict containing `class_name` \"\n            \"and `config` as keys.\"\n        )\n\n    class_name = config[\"class_name\"]\n    cls = object_registration.get_registered_object(\n        class_name, custom_objects, module_objects\n    )\n    if cls is None:\n        raise ValueError(\n            f\"Unknown {printable_module_name}: '{class_name}'. \"\n            \"Please ensure you are using a `keras.utils.custom_object_scope` \"\n            \"and that this object is included in the scope. 
See \"\n            \"https://www.tensorflow.org/guide/keras/save_and_serialize\"\n            \"#registering_the_custom_object for details.\"\n        )\n\n    cls_config = config[\"config\"]\n    # Check if `cls_config` is a list. If it is a list, return the class and the\n    # associated class configs for recursive deserialization. This case will\n    # happen with old versions of the Sequential model (e.g. `keras_version` ==\n    # \"2.0.6\"), which is serialized in a different structure, for example\n    # \"{'class_name': 'Sequential',\n    #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}\".\n    if isinstance(cls_config, list):\n        return (cls, cls_config)\n\n    deserialized_objects = {}\n    for key, item in cls_config.items():\n        if key == \"name\":\n            # Assume that the value of 'name' is a string that should not be\n            # deserialized as a function. This avoids the corner case where\n            # cls_config['name'] has an identical name to a custom function and\n            # gets converted into that function.\n            deserialized_objects[key] = item\n        elif isinstance(item, dict) and \"__passive_serialization__\" in item:\n            deserialized_objects[key] = deserialize_keras_object(\n                item,\n                module_objects=module_objects,\n                custom_objects=custom_objects,\n                printable_module_name=\"config_item\",\n            )\n        # Also consider looking up functions in `module_objects`.\n        elif isinstance(item, str) and inspect.isfunction(\n            object_registration.get_registered_object(\n                item, custom_objects, module_objects\n            )\n        ):\n            # Handle custom functions here. When saving functions, we only save\n            # the function's name as a string. 
If we find a matching string in\n            # the custom objects during deserialization, we convert the string\n            # back to the original function.\n            # Note that a potential issue is that a string field could have a\n            # naming conflict with a custom function name, but this should be a\n            # rare case.  This issue does not occur if a string field has a\n            # naming conflict with a custom object, since the config of an\n            # object will always be a dict.\n            deserialized_objects[key] = (\n                object_registration.get_registered_object(\n                    item, custom_objects, module_objects\n                )\n            )\n    for key, item in deserialized_objects.items():\n        cls_config[key] = deserialized_objects[key]\n\n    return (cls, cls_config)\n\n\n@keras_export(\n    [\n        \"keras.legacy.saving.deserialize_keras_object\",\n        \"keras.utils.legacy.deserialize_keras_object\",\n    ]\n)\ndef deserialize_keras_object(\n    identifier,\n    module_objects=None,\n    custom_objects=None,\n    printable_module_name=\"object\",\n):\n    \"\"\"Turns the serialized form of a Keras object back into an actual object.\n\n    This function is for mid-level library implementers rather than end users.\n\n    Importantly, this utility requires you to provide the dict of\n    `module_objects` to use for looking up the object config; this is not\n    populated by default. 
If you need a deserialization utility that has\n    preexisting knowledge of built-in Keras objects, use e.g.\n    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,\n    etc.\n\n    Calling `deserialize_keras_object` while underneath the\n    `SharedObjectLoadingScope` context manager will cause any already-seen\n    shared objects to be returned as-is rather than creating a new object.\n\n    Args:\n      identifier: the serialized form of the object.\n      module_objects: A dictionary of built-in objects to look the name up in.\n        Generally, `module_objects` is provided by mid-level library\n        implementers.\n      custom_objects: A dictionary of custom objects to look the name up in.\n        Generally, `custom_objects` is provided by the end user.\n      printable_module_name: A human-readable string representing the type of\n        the object. Printed in case of exception.\n\n    Returns:\n      The deserialized object.\n\n    Example:\n\n    A mid-level library implementer might want to implement a utility for\n    retrieving an object from its config, as such:\n\n    ```python\n    def deserialize(config, custom_objects=None):\n       return deserialize_keras_object(\n         config,\n         module_objects=globals(),\n         custom_objects=custom_objects,\n         printable_module_name=\"MyObjectType\",\n       )\n    ```\n\n    This is how e.g. `keras.layers.deserialize()` is implemented.\n    \"\"\"\n\n    if identifier is None:\n        return None\n\n    if isinstance(identifier, dict):\n        # In this case we are dealing with a Keras config dictionary.\n        config = identifier\n        (cls, cls_config) = class_and_config_for_serialized_keras_object(\n            config, module_objects, custom_objects, printable_module_name\n        )\n\n        # If this object has already been loaded (i.e. 
it's shared between\n        # multiple objects), return the already-loaded object.\n        shared_object_id = config.get(SHARED_OBJECT_KEY)\n        shared_object = _shared_object_loading_scope().get(shared_object_id)\n        if shared_object is not None:\n            return shared_object\n\n        if hasattr(cls, \"from_config\"):\n            arg_spec = inspect.getfullargspec(cls.from_config)\n            custom_objects = custom_objects or {}\n\n            if \"custom_objects\" in arg_spec.args:\n                deserialized_obj = cls.from_config(\n                    cls_config,\n                    custom_objects={\n                        **object_registration.GLOBAL_CUSTOM_OBJECTS,\n                        **custom_objects,\n                    },\n                )\n            else:\n                with object_registration.CustomObjectScope(custom_objects):\n                    deserialized_obj = cls.from_config(cls_config)\n        else:\n            # Then `cls` may be a function returning a class.\n            # in this case by convention `config` holds\n            # the kwargs of the function.\n            custom_objects = custom_objects or {}\n            with object_registration.CustomObjectScope(custom_objects):\n                deserialized_obj = cls(**cls_config)\n\n        # Add object to shared objects, in case we find it referenced again.\n        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)\n\n        return deserialized_obj\n\n    elif isinstance(identifier, str):\n        object_name = identifier\n        if custom_objects and object_name in custom_objects:\n            obj = custom_objects.get(object_name)\n        elif (\n            object_name\n            in object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__\n        ):\n            obj = object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__[\n                object_name\n            ]\n        elif object_name in 
object_registration._GLOBAL_CUSTOM_OBJECTS:\n            obj = object_registration._GLOBAL_CUSTOM_OBJECTS[object_name]\n        else:\n            obj = module_objects.get(object_name)\n            if obj is None:\n                raise ValueError(\n                    f\"Unknown {printable_module_name}: '{object_name}'. \"\n                    \"Please ensure you are using a \"\n                    \"`keras.utils.custom_object_scope` \"\n                    \"and that this object is included in the scope. See \"\n                    \"https://www.tensorflow.org/guide/keras/save_and_serialize\"\n                    \"#registering_the_custom_object for details.\"\n                )\n\n        # Classes passed by name are instantiated with no args, functions are\n        # returned as-is.\n        if inspect.isclass(obj):\n            return obj()\n        return obj\n    elif inspect.isfunction(identifier):\n        # If a function has already been deserialized, return as is.\n        return identifier\n    else:\n        raise ValueError(\n            \"Could not interpret serialized \"\n            f\"{printable_module_name}: {identifier}\"\n        )\n\n\ndef validate_config(config):\n    \"\"\"Determines whether config appears to be a valid layer config.\"\"\"\n    return (\n        isinstance(config, dict) and _LAYER_UNDEFINED_CONFIG_KEY not in config\n    )\n\n\ndef is_default(method):\n    \"\"\"Check if a method is decorated with the `default` wrapper.\"\"\"\n    return getattr(method, \"_is_default\", False)\n"
  },
  {
    "path": "keras/src/losses/__init__.py",
    "content": "import inspect\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.losses.loss import Loss\nfrom keras.src.losses.losses import CTC\nfrom keras.src.losses.losses import BinaryCrossentropy\nfrom keras.src.losses.losses import BinaryFocalCrossentropy\nfrom keras.src.losses.losses import CategoricalCrossentropy\nfrom keras.src.losses.losses import CategoricalFocalCrossentropy\nfrom keras.src.losses.losses import CategoricalHinge\nfrom keras.src.losses.losses import Circle\nfrom keras.src.losses.losses import CosineSimilarity\nfrom keras.src.losses.losses import Dice\nfrom keras.src.losses.losses import Hinge\nfrom keras.src.losses.losses import Huber\nfrom keras.src.losses.losses import KLDivergence\nfrom keras.src.losses.losses import LogCosh\nfrom keras.src.losses.losses import LossFunctionWrapper\nfrom keras.src.losses.losses import MeanAbsoluteError\nfrom keras.src.losses.losses import MeanAbsolutePercentageError\nfrom keras.src.losses.losses import MeanSquaredError\nfrom keras.src.losses.losses import MeanSquaredLogarithmicError\nfrom keras.src.losses.losses import Poisson\nfrom keras.src.losses.losses import SparseCategoricalCrossentropy\nfrom keras.src.losses.losses import SquaredHinge\nfrom keras.src.losses.losses import Tversky\nfrom keras.src.losses.losses import binary_crossentropy\nfrom keras.src.losses.losses import binary_focal_crossentropy\nfrom keras.src.losses.losses import categorical_crossentropy\nfrom keras.src.losses.losses import categorical_focal_crossentropy\nfrom keras.src.losses.losses import categorical_hinge\nfrom keras.src.losses.losses import circle\nfrom keras.src.losses.losses import cosine_similarity\nfrom keras.src.losses.losses import ctc\nfrom keras.src.losses.losses import dice\nfrom keras.src.losses.losses import hinge\nfrom keras.src.losses.losses import huber\nfrom keras.src.losses.losses import kl_divergence\nfrom keras.src.losses.losses import log_cosh\nfrom keras.src.losses.losses import 
mean_absolute_error\nfrom keras.src.losses.losses import mean_absolute_percentage_error\nfrom keras.src.losses.losses import mean_squared_error\nfrom keras.src.losses.losses import mean_squared_logarithmic_error\nfrom keras.src.losses.losses import poisson\nfrom keras.src.losses.losses import sparse_categorical_crossentropy\nfrom keras.src.losses.losses import squared_hinge\nfrom keras.src.losses.losses import tversky\nfrom keras.src.saving import serialization_lib\n\nALL_OBJECTS = {\n    # Base\n    Loss,\n    LossFunctionWrapper,\n    # Probabilistic\n    KLDivergence,\n    Poisson,\n    BinaryCrossentropy,\n    BinaryFocalCrossentropy,\n    CategoricalCrossentropy,\n    CategoricalFocalCrossentropy,\n    SparseCategoricalCrossentropy,\n    # Regression\n    MeanSquaredError,\n    MeanAbsoluteError,\n    MeanAbsolutePercentageError,\n    MeanSquaredLogarithmicError,\n    CosineSimilarity,\n    LogCosh,\n    Huber,\n    # Hinge\n    Hinge,\n    SquaredHinge,\n    CategoricalHinge,\n    # Image segmentation\n    Dice,\n    Tversky,\n    # Similarity\n    Circle,\n    # Sequence\n    CTC,\n    # Probabilistic\n    kl_divergence,\n    poisson,\n    binary_crossentropy,\n    binary_focal_crossentropy,\n    categorical_crossentropy,\n    categorical_focal_crossentropy,\n    sparse_categorical_crossentropy,\n    # Regression\n    mean_squared_error,\n    mean_absolute_error,\n    mean_absolute_percentage_error,\n    mean_squared_logarithmic_error,\n    cosine_similarity,\n    log_cosh,\n    huber,\n    # Hinge\n    hinge,\n    squared_hinge,\n    categorical_hinge,\n    # Image segmentation\n    dice,\n    tversky,\n    # Similarity\n    circle,\n    # Sequence\n    ctc,\n}\n\nALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}\nALL_OBJECTS_DICT.update(\n    {\n        \"bce\": binary_crossentropy,\n        \"BCE\": binary_crossentropy,\n        \"kld\": kl_divergence,\n        \"KLD\": kl_divergence,\n        \"mae\": mean_absolute_error,\n        \"MAE\": 
mean_absolute_error,\n        \"mse\": mean_squared_error,\n        \"MSE\": mean_squared_error,\n        \"mape\": mean_absolute_percentage_error,\n        \"MAPE\": mean_absolute_percentage_error,\n        \"msle\": mean_squared_logarithmic_error,\n        \"MSLE\": mean_squared_logarithmic_error,\n    }\n)\n\n\n@keras_export(\"keras.losses.serialize\")\ndef serialize(loss):\n    \"\"\"Serializes loss function or `Loss` instance.\n\n    Args:\n        loss: A Keras `Loss` instance or a loss function.\n\n    Returns:\n        Loss configuration dictionary.\n    \"\"\"\n    return serialization_lib.serialize_keras_object(loss)\n\n\n@keras_export(\"keras.losses.deserialize\")\ndef deserialize(name, custom_objects=None):\n    \"\"\"Deserializes a serialized loss class/function instance.\n\n    Args:\n        name: Loss configuration.\n        custom_objects: Optional dictionary mapping names (strings) to custom\n            objects (classes and functions) to be considered during\n            deserialization.\n\n    Returns:\n        A Keras `Loss` instance or a loss function.\n    \"\"\"\n    return serialization_lib.deserialize_keras_object(\n        name,\n        module_objects=ALL_OBJECTS_DICT,\n        custom_objects=custom_objects,\n    )\n\n\n@keras_export(\"keras.losses.get\")\ndef get(identifier):\n    \"\"\"Retrieves a Keras loss as a `function`/`Loss` class instance.\n\n    The `identifier` may be the string name of a loss function or `Loss` class.\n\n    >>> loss = losses.get(\"categorical_crossentropy\")\n    >>> type(loss)\n    <class 'function'>\n    >>> loss = losses.get(\"CategoricalCrossentropy\")\n    >>> type(loss)\n    <class '...CategoricalCrossentropy'>\n\n    You can also specify `config` of the loss to this function by passing dict\n    containing `class_name` and `config` as an identifier. Also note that the\n    `class_name` must map to a `Loss` class\n\n    >>> identifier = {\"class_name\": \"CategoricalCrossentropy\",\n    ...             
  \"config\": {\"from_logits\": True}}\n    >>> loss = losses.get(identifier)\n    >>> type(loss)\n    <class '...CategoricalCrossentropy'>\n\n    Args:\n        identifier: A loss identifier. One of None or string name of a loss\n            function/class or loss configuration dictionary or a loss function\n            or a loss class instance.\n\n    Returns:\n        A Keras loss as a `function`/ `Loss` class instance.\n    \"\"\"\n    if identifier is None:\n        return None\n    if isinstance(identifier, dict):\n        obj = deserialize(identifier)\n    elif isinstance(identifier, str):\n        obj = ALL_OBJECTS_DICT.get(identifier, None)\n    else:\n        obj = identifier\n\n    if callable(obj):\n        if inspect.isclass(obj):\n            obj = obj()\n        return obj\n    else:\n        raise ValueError(f\"Could not interpret loss identifier: {identifier}\")\n"
  },
  {
    "path": "keras/src/losses/loss.py",
    "content": "from keras.src import backend\nfrom keras.src import dtype_policies\nfrom keras.src import ops\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.saving.keras_saveable import KerasSaveable\nfrom keras.src.utils.naming import auto_name\n\n\n@keras_export([\"keras.Loss\", \"keras.losses.Loss\"])\nclass Loss(KerasSaveable):\n    \"\"\"Loss base class.\n\n    This is the class to subclass in order to create new custom losses.\n\n    Args:\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). 
If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    To be implemented by subclasses:\n\n    * `call()`: Contains the logic for loss calculation using `y_true`,\n        `y_pred`.\n\n    Example subclass implementation:\n\n    ```python\n    class MeanSquaredError(Loss):\n        def call(self, y_true, y_pred):\n            return ops.mean(ops.square(y_pred - y_true), axis=-1)\n    ```\n    \"\"\"\n\n    def __init__(self, name=None, reduction=\"sum_over_batch_size\", dtype=None):\n        self.name = name or auto_name(self.__class__.__name__)\n        self.reduction = standardize_reduction(reduction)\n        self._dtype_policy = dtype_policies.get(dtype or backend.floatx())\n        self._dtype = self._dtype_policy.compute_dtype\n\n    @property\n    def dtype(self):\n        return self._dtype\n\n    def __call__(self, y_true, y_pred, sample_weight=None):\n        in_mask = backend.get_keras_mask(y_pred)\n\n        with ops.name_scope(self.name):\n            y_pred = tree.map_structure(\n                lambda x: ops.convert_to_tensor(x, dtype=self.dtype), y_pred\n            )\n            y_true = tree.map_structure(\n                lambda x: ops.convert_to_tensor(x, dtype=self.dtype), y_true\n            )\n\n            losses = self.call(y_true, y_pred)\n            out_mask = backend.get_keras_mask(losses)\n\n            if in_mask is not None and out_mask is not None:\n                mask = in_mask & out_mask\n            elif in_mask is not None:\n                mask = in_mask\n            elif out_mask is not None:\n                mask = out_mask\n            else:\n                mask = None\n\n            return reduce_weighted_values(\n                losses,\n                sample_weight=sample_weight,\n                mask=mask,\n                reduction=self.reduction,\n                dtype=self.dtype,\n            )\n\n    def call(self, y_true, y_pred):\n        raise 
NotImplementedError\n\n    def get_config(self):\n        return {\"name\": self.name, \"reduction\": self.reduction}\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n\n    def _obj_type(self):\n        return \"Loss\"\n\n\ndef standardize_reduction(reduction):\n    allowed = {\n        \"sum_over_batch_size\",\n        \"sum\",\n        None,\n        \"none\",\n        \"mean\",\n        \"mean_with_sample_weight\",\n    }\n    if reduction not in allowed:\n        raise ValueError(\n            \"Invalid value for argument `reduction`. \"\n            f\"Expected one of {allowed}. Received: \"\n            f\"reduction={reduction}\"\n        )\n    return reduction\n\n\ndef squeeze_or_expand_to_same_rank(x1, x2, expand_rank_1=True):\n    \"\"\"Squeeze/expand last dim if ranks differ from expected by exactly 1.\"\"\"\n    x1_rank = len(x1.shape)\n    x2_rank = len(x2.shape)\n    if x1_rank == x2_rank:\n        return x1, x2\n    if x1_rank == x2_rank + 1:\n        if x1.shape[-1] == 1:\n            if x2_rank == 1 and expand_rank_1:\n                x2 = ops.expand_dims(x2, axis=-1)\n            else:\n                x1 = ops.squeeze(x1, axis=-1)\n    if x2_rank == x1_rank + 1:\n        if x2.shape[-1] == 1:\n            if x1_rank == 1 and expand_rank_1:\n                x1 = ops.expand_dims(x1, axis=-1)\n            else:\n                x2 = ops.squeeze(x2, axis=-1)\n    return x1, x2\n\n\ndef reduce_values(values, sample_weight=None, reduction=\"sum_over_batch_size\"):\n    if (\n        reduction is None\n        or reduction == \"none\"\n        or tuple(values.shape) == ()\n        or tuple(values.shape) == (0,)\n    ):\n        return values\n    loss = ops.sum(values)\n    if reduction in (\"sum_over_batch_size\", \"mean\", \"mean_with_sample_weight\"):\n        if reduction == \"mean_with_sample_weight\" and sample_weight is not None:\n            divisor = ops.cast(ops.sum(sample_weight), loss.dtype)\n        
else:\n            divisor = ops.cast(\n                ops.prod(\n                    ops.convert_to_tensor(ops.shape(values), dtype=\"int32\")\n                ),\n                loss.dtype,\n            )\n        loss = ops.divide_no_nan(loss, divisor)\n        loss = scale_loss_for_distribution(loss)\n    return loss\n\n\ndef reduce_weighted_values(\n    values,\n    sample_weight=None,\n    mask=None,\n    reduction=\"sum_over_batch_size\",\n    dtype=None,\n):\n    reduction = standardize_reduction(reduction)\n\n    values = ops.convert_to_tensor(values, dtype=dtype)\n    if sample_weight is not None:\n        sample_weight = ops.convert_to_tensor(sample_weight, dtype=dtype)\n    if mask is not None:\n        mask = ops.convert_to_tensor(mask, dtype=dtype)\n\n    # Merge mask and sample weight into sample weight.\n    sample_weight = apply_mask(\n        sample_weight, mask, dtype=values.dtype, reduction=reduction\n    )\n\n    if sample_weight is not None:\n        sample_weight = ops.cast(sample_weight, values.dtype)\n        # Update dimensions of `sample_weight` to match `losses`.\n        values, sample_weight = squeeze_or_expand_to_same_rank(\n            values, sample_weight\n        )\n        values = values * sample_weight\n\n    # Apply reduction function to the individual weighted losses.\n    loss = reduce_values(values, sample_weight, reduction)\n    return loss\n\n\ndef apply_mask(sample_weight, mask, dtype, reduction):\n    \"\"\"Applies any mask on predictions to sample weights.\"\"\"\n    if mask is not None:\n        mask = ops.cast(mask, dtype=dtype)\n        if reduction in (\"mean\", \"sum_over_batch_size\"):\n            # Valid entries have weight `total/valid`, while invalid ones\n            # have 0. 
When summed over batch, they will be reduced to:\n            #\n            # mean(loss * sample_weight * total / valid)\n            #   = sum(loss * sample_weight * total / valid) / total\n            #   = sum(loss * sample_weight) / total * total / valid\n            #   = sum(loss * sample_weight) / valid\n            total = ops.cast(\n                ops.prod(ops.convert_to_tensor(ops.shape(mask), dtype=\"int32\")),\n                dtype,\n            )\n            valid = ops.sum(mask)  # May be 0!\n            mask *= ops.divide_no_nan(total, valid)\n\n        if sample_weight is not None:\n            sample_weight = ops.cast(sample_weight, dtype=dtype)\n            mask, sample_weight = squeeze_or_expand_to_same_rank(\n                mask, sample_weight\n            )\n            sample_weight *= mask\n        else:\n            sample_weight = mask\n    return sample_weight\n\n\ndef scale_loss_for_distribution(value):\n    \"\"\"Scales the given value by the number of replicas in the strategy.\n\n    Currently, this function is only effective when using the tensorflow backend\n    and `tf.distribute`.\n    \"\"\"\n    if backend.backend() == \"tensorflow\":\n        import tensorflow as tf\n\n        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync\n        if num_replicas > 1:\n            value = ops.multiply(\n                value, ops.cast(1.0 / num_replicas, value.dtype)\n            )\n    return value\n\n\ndef unscale_loss_for_distribution(value):\n    \"\"\"Unscales the given value by the number of replicas in the strategy.\n\n    Currently, this function is only effective when using the tensorflow backend\n    and `tf.distribute`.\n    \"\"\"\n    if backend.backend() == \"tensorflow\":\n        import tensorflow as tf\n\n        num_replicas = tf.distribute.get_strategy().num_replicas_in_sync\n        if num_replicas > 1:\n            value = ops.multiply(value, ops.cast(num_replicas, value.dtype))\n    return value\n"
  },
  {
    "path": "keras/src/losses/loss_test.py",
    "content": "import pickle\n\nimport numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import dtype_policies\nfrom keras.src import losses as losses_module\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.losses.loss import Loss\nfrom keras.src.losses.loss import squeeze_or_expand_to_same_rank\n\n\nclass ExampleLoss(Loss):\n    def call(self, y_true, y_pred):\n        return (y_true - y_pred) ** 2\n\n\nclass LossTest(testing.TestCase):\n    def setUp(self):\n        self._global_dtype_policy = dtype_policies.dtype_policy.dtype_policy()\n        self._floatx = backend.floatx()\n        return super().setUp()\n\n    def tearDown(self):\n        dtype_policies.dtype_policy.set_dtype_policy(self._global_dtype_policy)\n        backend.set_floatx(self._floatx)\n        return super().tearDown()\n\n    def test_squeeze_or_expand(self):\n        x1 = ops.ones((3,))\n        x2 = ops.ones((3, 1))\n        x1, x2 = squeeze_or_expand_to_same_rank(x1, x2)\n        self.assertEqual(ops.shape(x1), (3, 1))\n        self.assertEqual(ops.shape(x2), (3, 1))\n\n        x1 = ops.ones((3, 2))\n        x2 = ops.ones((3, 2, 1))\n        x1, x2 = squeeze_or_expand_to_same_rank(x1, x2)\n        self.assertEqual(ops.shape(x1), (3, 2))\n        self.assertEqual(ops.shape(x2), (3, 2))\n\n        x1 = ops.ones((3,))\n        x2 = ops.ones((3, 1))\n        x2, x1 = squeeze_or_expand_to_same_rank(x2, x1)\n        self.assertEqual(ops.shape(x1), (3, 1))\n        self.assertEqual(ops.shape(x2), (3, 1))\n\n        x1 = ops.ones((3, 2))\n        x2 = ops.ones((3, 2, 1))\n        x2, x1 = squeeze_or_expand_to_same_rank(x2, x1)\n        self.assertEqual(ops.shape(x1), (3, 2))\n        self.assertEqual(ops.shape(x2), (3, 2))\n\n    def test_reduction(self):\n        y_true = np.array([1.0, 0.0, 1.0, 0.0])\n        y_pred = np.array([0.1, 0.2, 0.3, 0.4])\n\n        # No reduction\n        loss_fn = ExampleLoss(reduction=None)\n        loss = 
loss_fn(y_true, y_pred)\n        self.assertEqual(backend.standardize_dtype(loss.dtype), \"float32\")\n        self.assertAllClose((y_true - y_pred) ** 2, loss)\n\n        # sum\n        loss_fn = ExampleLoss(reduction=\"sum\")\n        loss = loss_fn(y_true, y_pred)\n        self.assertEqual(backend.standardize_dtype(loss.dtype), \"float32\")\n        self.assertAllClose(np.sum((y_true - y_pred) ** 2), loss)\n\n        # sum_over_batch_size or mean\n        loss_fn = ExampleLoss(reduction=\"sum_over_batch_size\")\n        loss = loss_fn(y_true, y_pred)\n        self.assertEqual(backend.standardize_dtype(loss.dtype), \"float32\")\n        self.assertAllClose(np.sum((y_true - y_pred) ** 2) / 4, loss)\n\n        # bad reduction\n        with self.assertRaisesRegex(ValueError, \"Invalid value for argument\"):\n            ExampleLoss(reduction=\"abc\")\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\",\n        reason=\"Numpy backend does not support masking.\",\n    )\n    def test_mask(self):\n        mask = np.array([True, False, True, True])\n        y_true = np.array([1.0, 0.0, 1.0, 0.0])\n        y_pred = np.array([0.1, 0.2, 0.3, 0.4])\n\n        masked_y_true = np.array([1.0, 1.0, 0.0])\n        masked_y_pred = np.array([0.1, 0.3, 0.4])\n\n        mask = ops.convert_to_tensor(mask)\n        y_true = ops.convert_to_tensor(y_true)\n        y_pred = ops.convert_to_tensor(y_pred)\n        y_pred._keras_mask = mask\n\n        loss_fn = ExampleLoss()\n        loss = loss_fn(y_true, y_pred)\n        self.assertEqual(backend.standardize_dtype(loss.dtype), \"float32\")\n        self.assertAllClose(\n            np.sum((masked_y_true - masked_y_pred) ** 2) / 3, loss\n        )\n\n        # Test edge case where everything is masked.\n        mask = np.array([False, False, False, False])\n        y_pred._keras_mask = mask\n        loss = loss_fn(y_true, y_pred)\n        self.assertEqual(backend.standardize_dtype(loss.dtype), \"float32\")\n        
self.assertAllClose(loss, 0)  # No NaN.\n\n    def test_sample_weight(self):\n        sample_weight = np.array([0.4, 0.3, 0.2, 0.1])\n        y_true = np.array([1.0, 0.0, 1.0, 0.0])\n        y_pred = np.array([0.1, 0.2, 0.3, 0.4])\n\n        loss_fn = ExampleLoss()\n        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)\n        self.assertEqual(backend.standardize_dtype(loss.dtype), \"float32\")\n        self.assertAllClose(\n            np.sum(sample_weight * (y_true - y_pred) ** 2) / 4, loss\n        )\n\n        # Test edge case where every weight is 0.\n        sample_weight = np.array([0.0, 0.0, 0.0, 0.0])\n        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)\n        self.assertEqual(backend.standardize_dtype(loss.dtype), \"float32\")\n        self.assertAllClose(loss, 0)  # No NaN.\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\",\n        reason=\"Numpy backend does not support masking.\",\n    )\n    def test_mask_and_sample_weight(self):\n        sample_weight = np.array([0.4, 0.3, 0.2, 0.1])\n        y_true = np.array([1.0, 0.0, 1.0, 0.0])\n        y_pred = np.array([0.1, 0.2, 0.3, 0.4])\n        mask = np.array([True, False, True, True])\n\n        masked_sample_weight = np.array([0.4, 0.2, 0.1])\n        masked_y_true = np.array([1.0, 1.0, 0.0])\n        masked_y_pred = np.array([0.1, 0.3, 0.4])\n\n        mask = ops.convert_to_tensor(mask)\n        y_true = ops.convert_to_tensor(y_true)\n        y_pred = ops.convert_to_tensor(y_pred)\n        y_pred._keras_mask = mask\n\n        loss_fn = ExampleLoss()\n        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)\n        self.assertEqual(backend.standardize_dtype(loss.dtype), \"float32\")\n        self.assertAllClose(\n            np.sum(masked_sample_weight * (masked_y_true - masked_y_pred) ** 2)\n            / 3,\n            loss,\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\",\n        reason=\"Numpy backend 
does not support masking.\",\n    )\n    def test_mask_and_sample_weight_rank2(self):\n        # check loss of inputs with duplicate rows doesn't change\n        sample_weight = np.array([0.4, 0.3, 0.2, 0.1])\n        y_true = np.array([1.0, 0.0, 1.0, 0.0])\n        y_pred = np.array([0.1, 0.2, 0.3, 0.4])\n        mask = np.array([True, False, True, True])\n\n        mask = ops.convert_to_tensor(mask)\n        y_true = ops.convert_to_tensor(y_true)\n        y_pred = ops.convert_to_tensor(y_pred)\n        y_pred._keras_mask = mask\n\n        loss_fn = ExampleLoss()\n        rank1_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)\n\n        # duplicate rows\n        mask = ops.tile(ops.expand_dims(mask, axis=0), (2, 1))\n        y_true = ops.tile(ops.expand_dims(y_true, axis=0), (2, 1))\n        y_pred = ops.tile(ops.expand_dims(y_pred, axis=0), (2, 1))\n        sample_weight = ops.tile(ops.expand_dims(sample_weight, axis=0), (2, 1))\n        y_pred._keras_mask = mask\n        rank2_loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(rank1_loss, rank2_loss)\n\n    # @testing.parametrize(\n    #     \"uprank\", [\"mask\", \"sample_weight\", \"y_true\", \"y_pred\"])\n    # TODO: use parameterization decorator\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\",\n        reason=\"Numpy backend does not support masking.\",\n    )\n    def test_rank_adjustment(self):\n        for uprank in [\"mask\", \"sample_weight\", \"ys\"]:\n            sample_weight = np.array([0.4, 0.3, 0.2, 0.1])\n            y_true = np.array([1.0, 0.0, 1.0, 0.0])\n            y_pred = np.array([0.1, 0.2, 0.3, 0.4])\n            mask = np.array([True, False, True, True])\n\n            if uprank == \"mask\":\n                mask = np.expand_dims(mask, -1)\n            elif uprank == \"sample_weight\":\n                sample_weight = np.expand_dims(sample_weight, -1)\n            elif uprank == \"ys\":\n                y_true = 
np.expand_dims(y_true, -1)\n                y_pred = np.expand_dims(y_pred, -1)\n\n            masked_sample_weight = np.array([0.4, 0.2, 0.1])\n            masked_y_true = np.array([1.0, 1.0, 0.0])\n            masked_y_pred = np.array([0.1, 0.3, 0.4])\n\n            mask = ops.convert_to_tensor(mask)\n            y_true = ops.convert_to_tensor(y_true)\n            y_pred = ops.convert_to_tensor(y_pred)\n            y_pred._keras_mask = mask\n\n            loss_fn = ExampleLoss()\n            loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)\n            self.assertEqual(backend.standardize_dtype(loss.dtype), \"float32\")\n            self.assertAllClose(\n                np.sum(\n                    masked_sample_weight * (masked_y_true - masked_y_pred) ** 2\n                )\n                / 3,\n                loss,\n            )\n\n    def test_mixed_dtypes(self):\n        sample_weight = np.array([0.4, 0.3, 0.2, 0.1], dtype=\"float64\")\n        y_true = np.array([1.0, 0.0, 1.0, 0.0], dtype=\"int32\")\n        y_pred = np.array([0.1, 0.2, 0.3, 0.4], dtype=\"float32\")\n\n        loss_fn = ExampleLoss()\n        loss = loss_fn(y_true, y_pred, sample_weight=sample_weight)\n        self.assertEqual(backend.standardize_dtype(loss.dtype), \"float32\")\n        self.assertAllClose(\n            np.sum(sample_weight * (y_true - y_pred) ** 2) / 4,\n            loss,\n        )\n\n    def test_pickle(self):\n        loss = losses_module.get(\"mse\")\n        loss = pickle.loads(pickle.dumps(loss))\n        self.assertEqual(loss, losses_module.mean_squared_error)\n\n    def test_get_method(self):\n        loss = losses_module.get(\"mse\")\n        self.assertEqual(loss, losses_module.mean_squared_error)\n\n        loss = losses_module.get(None)\n        self.assertEqual(loss, None)\n\n        with self.assertRaises(ValueError):\n            losses_module.get(\"typo\")\n\n    def test_dtype_arg(self):\n        y_true = np.array([1.0, 0.0, 1.0, 0.0], 
dtype=\"float32\")\n        y_pred = np.array([0.1, 0.2, 0.3, 0.4], dtype=\"float32\")\n\n        # Note: we use float16 and not float64 to test this because\n        # JAX will map float64 to float32.\n        loss_fn = ExampleLoss(dtype=\"float16\")\n        loss = loss_fn(y_true, y_pred)\n        self.assertDType(loss, \"float16\")\n\n        # Test DTypePolicy for `dtype` argument\n        loss_fn = ExampleLoss(dtype=dtype_policies.DTypePolicy(\"mixed_float16\"))\n        loss = loss_fn(y_true, y_pred)\n        self.assertDType(loss, \"float16\")\n\n        # `dtype` setter should raise AttributeError\n        with self.assertRaises(AttributeError):\n            loss_fn.dtype = \"bfloat16\"\n\n    def test_default_dtype(self):\n        y_true = np.array([1.0, 0.0, 1.0, 0.0], dtype=\"float32\")\n        y_pred = np.array([0.1, 0.2, 0.3, 0.4], dtype=\"float32\")\n\n        # Defaults to `keras.config.floatx()` not global `dtype_policy`\n        dtype_policies.dtype_policy.set_dtype_policy(\"mixed_float16\")\n        loss_fn = ExampleLoss()\n        loss = loss_fn(y_true, y_pred)\n        self.assertDType(loss, \"float32\")\n\n        backend.set_floatx(\"float16\")\n        loss_fn = ExampleLoss()\n        loss = loss_fn(y_true, y_pred)\n        self.assertDType(loss, backend.floatx())\n"
  },
  {
    "path": "keras/src/losses/losses.py",
    "content": "import warnings\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.losses.loss import Loss\nfrom keras.src.losses.loss import squeeze_or_expand_to_same_rank\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils.numerical_utils import build_pos_neg_masks\nfrom keras.src.utils.numerical_utils import normalize\n\n\nclass LossFunctionWrapper(Loss):\n    def __init__(\n        self,\n        fn,\n        reduction=\"sum_over_batch_size\",\n        name=None,\n        dtype=None,\n        **kwargs,\n    ):\n        super().__init__(name=name, reduction=reduction, dtype=dtype)\n        self.fn = fn\n        self._fn_kwargs = kwargs\n\n    def call(self, y_true, y_pred):\n        y_true_y_pred = tree.map_structure(\n            squeeze_or_expand_to_same_rank, y_true, y_pred\n        )\n        y_true = tree.map_structure_up_to(y_true, lambda x: x[0], y_true_y_pred)\n        y_pred = tree.map_structure_up_to(y_pred, lambda x: x[1], y_true_y_pred)\n        return self.fn(y_true, y_pred, **self._fn_kwargs)\n\n    def get_config(self):\n        config = super().get_config()\n        config.update({\"fn\": serialization_lib.serialize_keras_object(self.fn)})\n        config.update(serialization_lib.serialize_keras_object(self._fn_kwargs))\n        return config\n\n    @classmethod\n    def from_config(cls, config):\n        if \"fn\" in config:\n            config = serialization_lib.deserialize_keras_object(config)\n        return cls(**config)\n\n    def __repr__(self):\n        return f\"<LossFunctionWrapper({self.fn}, kwargs={self._fn_kwargs})>\"\n\n\n@keras_export(\"keras.losses.MeanSquaredError\")\nclass MeanSquaredError(LossFunctionWrapper):\n    \"\"\"Computes the mean of squares of errors between labels and predictions.\n\n    Formula:\n\n    ```python\n    loss = mean(square(y_true - y_pred))\n    ```\n\n    Args:\n        reduction: Type 
of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Examples:\n\n    >>> y_true = keras.ops.array([1.0, 0.0, 1.0])\n    >>> y_pred = keras.ops.array([0.9, 0.1, 0.8])\n    >>> loss = keras.losses.MeanSquaredError()\n    >>> loss(y_true, y_pred)\n    0.02\n    \"\"\"\n\n    def __init__(\n        self,\n        reduction=\"sum_over_batch_size\",\n        name=\"mean_squared_error\",\n        dtype=None,\n    ):\n        super().__init__(\n            mean_squared_error, name=name, reduction=reduction, dtype=dtype\n        )\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.MeanAbsoluteError\")\nclass MeanAbsoluteError(LossFunctionWrapper):\n    \"\"\"Computes the mean of absolute difference between labels and predictions.\n\n    Formula:\n\n    ```python\n    loss = mean(abs(y_true - y_pred))\n    ```\n\n    Args:\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. 
Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Examples:\n\n    >>> y_true = keras.ops.array([1.0, 0.3, 1.0])\n    >>> y_pred = keras.ops.array([1.9, 0.3, 1.8])\n    >>> loss = keras.losses.MeanAbsoluteError()\n    >>> loss(y_true, y_pred)\n    0.5666667\n    \"\"\"\n\n    def __init__(\n        self,\n        reduction=\"sum_over_batch_size\",\n        name=\"mean_absolute_error\",\n        dtype=None,\n    ):\n        super().__init__(\n            mean_absolute_error, name=name, reduction=reduction, dtype=dtype\n        )\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.MeanAbsolutePercentageError\")\nclass MeanAbsolutePercentageError(LossFunctionWrapper):\n    \"\"\"Computes the mean absolute percentage error between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = 100 * mean(abs((y_true - y_pred) / y_true))\n    ```\n\n    Args:\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. 
Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Examples:\n\n    >>> y_true = keras.ops.array([100.0, 200.0, 300.0])\n    >>> y_pred = keras.ops.array([90.0, 210.0, 310.0])\n    >>> loss = keras.losses.MeanAbsolutePercentageError()\n    >>> loss(y_true, y_pred)\n    6.111111\n    \"\"\"\n\n    def __init__(\n        self,\n        reduction=\"sum_over_batch_size\",\n        name=\"mean_absolute_percentage_error\",\n        dtype=None,\n    ):\n        super().__init__(\n            mean_absolute_percentage_error,\n            name=name,\n            reduction=reduction,\n            dtype=dtype,\n        )\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.MeanSquaredLogarithmicError\")\nclass MeanSquaredLogarithmicError(LossFunctionWrapper):\n    \"\"\"Computes the mean squared logarithmic error between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = mean(square(log(y_true + 1) - log(y_pred + 1)))\n    ```\n\n    Args:\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. 
Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n    \"\"\"\n\n    def __init__(\n        self,\n        reduction=\"sum_over_batch_size\",\n        name=\"mean_squared_logarithmic_error\",\n        dtype=None,\n    ):\n        super().__init__(\n            mean_squared_logarithmic_error,\n            name=name,\n            reduction=reduction,\n            dtype=dtype,\n        )\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.CosineSimilarity\")\nclass CosineSimilarity(LossFunctionWrapper):\n    \"\"\"Computes the cosine similarity between `y_true` & `y_pred`.\n\n    Note that it is a number between -1 and 1. When it is a negative number\n    between -1 and 0, 0 indicates orthogonality and values closer to -1\n    indicate greater similarity. This makes it usable as a loss function in a\n    setting where you try to maximize the proximity between predictions and\n    targets. 
If either `y_true` or `y_pred` is a zero vector, cosine similarity\n    will be 0 regardless of the proximity between predictions and targets.\n\n    Formula:\n\n    ```python\n    loss = -sum(l2_norm(y_true) * l2_norm(y_pred))\n    ```\n\n    Args:\n        axis: The axis along which the cosine similarity is computed\n            (the features axis). Defaults to `-1`.\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). 
If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n    \"\"\"\n\n    def __init__(\n        self,\n        axis=-1,\n        reduction=\"sum_over_batch_size\",\n        name=\"cosine_similarity\",\n        dtype=None,\n    ):\n        super().__init__(\n            cosine_similarity,\n            name=name,\n            reduction=reduction,\n            dtype=dtype,\n            axis=axis,\n        )\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.Huber\")\nclass Huber(LossFunctionWrapper):\n    \"\"\"Computes the Huber loss between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    for x in error:\n        if abs(x) <= delta:\n            loss.append(0.5 * x^2)\n        elif abs(x) > delta:\n            loss.append(delta * abs(x) - 0.5 * delta^2)\n\n    loss = mean(loss, axis=-1)\n    ```\n    See: [Huber loss](https://en.wikipedia.org/wiki/Huber_loss).\n\n    Args:\n        delta: A float, the point where the Huber loss function changes from a\n            quadratic to linear.\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. 
`keras.backend.floatx()` is a\n            `\"float32\"` unless set to a different value\n            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n    \"\"\"\n\n    def __init__(\n        self,\n        delta=1.0,\n        reduction=\"sum_over_batch_size\",\n        name=\"huber_loss\",\n        dtype=None,\n    ):\n        super().__init__(\n            huber,\n            name=name,\n            reduction=reduction,\n            dtype=dtype,\n            delta=delta,\n        )\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.LogCosh\")\nclass LogCosh(LossFunctionWrapper):\n    \"\"\"Computes the logarithm of the hyperbolic cosine of the prediction error.\n\n    Formula:\n\n    ```python\n    error = y_pred - y_true\n    logcosh = mean(log((exp(error) + exp(-error))/2), axis=-1)\n    ```\n    where `error` is the prediction error `y_pred - y_true`.\n\n    Args:\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to a different value\n            (via `keras.backend.set_floatx()`). 
If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n    \"\"\"\n\n    def __init__(\n        self,\n        reduction=\"sum_over_batch_size\",\n        name=\"log_cosh\",\n        dtype=None,\n    ):\n        super().__init__(log_cosh, name=name, reduction=reduction, dtype=dtype)\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.Hinge\")\nclass Hinge(LossFunctionWrapper):\n    \"\"\"Computes the hinge loss between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = maximum(1 - y_true * y_pred, 0)\n    ```\n\n    `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are\n    provided we will convert them to -1 or 1.\n\n    Args:\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). 
If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n    \"\"\"\n\n    def __init__(\n        self,\n        reduction=\"sum_over_batch_size\",\n        name=\"hinge\",\n        dtype=None,\n    ):\n        super().__init__(hinge, name=name, reduction=reduction, dtype=dtype)\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.SquaredHinge\")\nclass SquaredHinge(LossFunctionWrapper):\n    \"\"\"Computes the squared hinge loss between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = square(maximum(1 - y_true * y_pred, 0))\n    ```\n\n    `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are\n    provided we will convert them to -1 or 1.\n\n    Args:\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). 
If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n    \"\"\"\n\n    def __init__(\n        self, reduction=\"sum_over_batch_size\", name=\"squared_hinge\", dtype=None\n    ):\n        super().__init__(\n            squared_hinge, name=name, reduction=reduction, dtype=dtype\n        )\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.CategoricalHinge\")\nclass CategoricalHinge(LossFunctionWrapper):\n    \"\"\"Computes the categorical hinge loss between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = maximum(neg - pos + 1, 0)\n    ```\n\n    where `neg=maximum((1-y_true)*y_pred)` and `pos=sum(y_true*y_pred)`\n\n    Args:\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). 
If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n    \"\"\"\n\n    def __init__(\n        self,\n        reduction=\"sum_over_batch_size\",\n        name=\"categorical_hinge\",\n        dtype=None,\n    ):\n        super().__init__(\n            categorical_hinge, name=name, reduction=reduction, dtype=dtype\n        )\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.KLDivergence\")\nclass KLDivergence(LossFunctionWrapper):\n    \"\"\"Computes Kullback-Leibler divergence loss between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = y_true * log(y_true / y_pred)\n    ```\n\n    `y_true` and `y_pred` are expected to be probability\n    distributions, with values between 0 and 1. They will get\n    clipped to the `[0, 1]` range.\n\n    Args:\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). 
If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n    \"\"\"\n\n    def __init__(\n        self, reduction=\"sum_over_batch_size\", name=\"kl_divergence\", dtype=None\n    ):\n        super().__init__(\n            kl_divergence, name=name, reduction=reduction, dtype=dtype\n        )\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.Poisson\")\nclass Poisson(LossFunctionWrapper):\n    \"\"\"Computes the Poisson loss between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = y_pred - y_true * log(y_pred)\n    ```\n\n    Args:\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). 
If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n    \"\"\"\n\n    def __init__(\n        self, reduction=\"sum_over_batch_size\", name=\"poisson\", dtype=None\n    ):\n        super().__init__(poisson, name=name, reduction=reduction, dtype=dtype)\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.BinaryCrossentropy\")\nclass BinaryCrossentropy(LossFunctionWrapper):\n    \"\"\"Computes the cross-entropy loss between true labels and predicted labels.\n\n    Use this cross-entropy loss for binary (0 or 1) classification applications.\n    The loss function requires the following inputs:\n\n    - `y_true` (true label): This is either 0 or 1.\n    - `y_pred` (predicted value): This is the model's prediction, i.e, a single\n        floating-point value which either represents a\n        [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]\n        when `from_logits=True`) or a probability (i.e, value in [0., 1.] when\n        `from_logits=False`).\n\n    Args:\n        from_logits: Whether to interpret `y_pred` as a tensor of\n            [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we\n            assume that `y_pred` is probabilities (i.e., values in [0, 1]).\n        label_smoothing: Float in range [0, 1]. When 0, no smoothing occurs.\n            When > 0, we compute the loss between the predicted labels\n            and a smoothed version of the true labels, where the smoothing\n            squeezes the labels towards 0.5. Larger values of\n            `label_smoothing` correspond to heavier smoothing.\n        axis: The axis along which to compute crossentropy (the features axis).\n            Defaults to `-1`.\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. 
Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to a different value\n            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Examples:\n\n    **Recommended Usage:** (set `from_logits=True`)\n\n    With `compile()` API:\n\n    ```python\n    model.compile(\n        loss=keras.losses.BinaryCrossentropy(from_logits=True),\n        ...\n    )\n    ```\n\n    As a standalone function:\n\n    >>> # Example 1: (batch_size = 1, number of samples = 4)\n    >>> y_true = np.array([0, 1, 0, 0])\n    >>> y_pred = np.array([-18.6, 0.51, 2.94, -12.8])\n    >>> bce = keras.losses.BinaryCrossentropy(from_logits=True)\n    >>> bce(y_true, y_pred)\n    0.8654\n\n    >>> # Example 2: (batch_size = 2, number of samples = 4)\n    >>> y_true = np.array([[0, 1], [0, 0]])\n    >>> y_pred = np.array([[-18.6, 0.51], [2.94, -12.8]])\n    >>> # Using default 'sum_over_batch_size' reduction type.\n    >>> bce = keras.losses.BinaryCrossentropy(from_logits=True)\n    >>> bce(y_true, y_pred)\n    0.8654\n    >>> # Using the 'sample_weight' argument\n    >>> bce(y_true, y_pred, sample_weight=[0.8, 0.2])\n    0.243\n    >>> # Using 'sum' reduction type.\n    >>> bce = keras.losses.BinaryCrossentropy(from_logits=True,\n    ...     
reduction=\"sum\")\n    >>> bce(y_true, y_pred)\n    1.730\n    >>> # Using 'none' reduction type.\n    >>> bce = keras.losses.BinaryCrossentropy(from_logits=True,\n    ...     reduction=None)\n    >>> bce(y_true, y_pred)\n    array([0.235, 1.496], dtype=float32)\n\n    **Default Usage:** (set `from_logits=False`)\n\n    >>> # Make the following updates to the above \"Recommended Usage\" section\n    >>> # 1. Set `from_logits=False`\n    >>> keras.losses.BinaryCrossentropy() # OR ...(from_logits=False)\n    >>> # 2. Update `y_pred` to use probabilities instead of logits\n    >>> y_pred = [0.6, 0.3, 0.2, 0.8] # OR [[0.6, 0.3], [0.2, 0.8]]\n    \"\"\"\n\n    def __init__(\n        self,\n        from_logits=False,\n        label_smoothing=0.0,\n        axis=-1,\n        reduction=\"sum_over_batch_size\",\n        name=\"binary_crossentropy\",\n        dtype=None,\n    ):\n        super().__init__(\n            binary_crossentropy,\n            name=name,\n            reduction=reduction,\n            dtype=dtype,\n            from_logits=from_logits,\n            label_smoothing=label_smoothing,\n            axis=axis,\n        )\n        self.from_logits = from_logits\n        self.label_smoothing = label_smoothing\n        self.axis = axis\n\n    def get_config(self):\n        config = Loss.get_config(self)\n        config.update(\n            {\n                \"from_logits\": self.from_logits,\n                \"label_smoothing\": self.label_smoothing,\n                \"axis\": self.axis,\n            }\n        )\n        return config\n\n\n@keras_export(\"keras.losses.BinaryFocalCrossentropy\")\nclass BinaryFocalCrossentropy(LossFunctionWrapper):\n    \"\"\"Computes focal cross-entropy loss between true labels and predictions.\n\n    Binary cross-entropy loss is often used for binary (0 or 1) classification\n    tasks. 
The loss function requires the following inputs:\n\n    - `y_true` (true label): This is either 0 or 1.\n    - `y_pred` (predicted value): This is the model's prediction, i.e, a single\n        floating-point value which either represents a\n        [logit](https://en.wikipedia.org/wiki/Logit), (i.e, value in [-inf, inf]\n        when `from_logits=True`) or a probability (i.e, value in `[0., 1.]` when\n        `from_logits=False`).\n\n    According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it\n    helps to apply a \"focal factor\" to down-weight easy examples and focus more\n    on hard examples. By default, the focal tensor is computed as follows:\n\n    `focal_factor = (1 - output) ** gamma` for class 1\n    `focal_factor = output ** gamma` for class 0\n    where `gamma` is a focusing parameter. When `gamma=0`, this function is\n    equivalent to the binary crossentropy loss.\n\n    Args:\n        apply_class_balancing: A bool, whether to apply weight balancing on the\n            binary classes 0 and 1.\n        alpha: A weight balancing factor for class 1, default is `0.25` as\n            mentioned in reference [Lin et al., 2018](\n            https://arxiv.org/pdf/1708.02002.pdf).  The weight for class 0 is\n            `1.0 - alpha`.\n        gamma: A focusing parameter used to compute the focal factor, default is\n            `2.0` as mentioned in the reference\n            [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf).\n        from_logits: Whether to interpret `y_pred` as a tensor of\n            [logit](https://en.wikipedia.org/wiki/Logit) values. By default, we\n            assume that `y_pred` are probabilities (i.e., values in `[0, 1]`).\n        label_smoothing: Float in `[0, 1]`. 
When `0`, no smoothing occurs.\n            When > `0`, we compute the loss between the predicted labels\n            and a smoothed version of the true labels, where the smoothing\n            squeezes the labels towards `0.5`.\n            Larger values of `label_smoothing` correspond to heavier smoothing.\n        axis: The axis along which to compute crossentropy (the features axis).\n            Defaults to `-1`.\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Examples:\n\n    With the `compile()` API:\n\n    ```python\n    model.compile(\n        loss=keras.losses.BinaryFocalCrossentropy(\n            gamma=2.0, from_logits=True),\n        ...\n    )\n    ```\n\n    As a standalone function:\n\n    >>> # Example 1: (batch_size = 1, number of samples = 4)\n    >>> y_true = np.array([0, 1, 0, 0])\n    >>> y_pred = np.array([-18.6, 0.51, 2.94, -12.8])\n    >>> loss = keras.losses.BinaryFocalCrossentropy(\n    ...    
gamma=2, from_logits=True)\n    >>> loss(y_true, y_pred)\n    0.691\n\n    >>> # Apply class weight\n    >>> loss = keras.losses.BinaryFocalCrossentropy(\n    ...     apply_class_balancing=True, gamma=2, from_logits=True)\n    >>> loss(y_true, y_pred)\n    0.51\n\n    >>> # Example 2: (batch_size = 2, number of samples = 4)\n    >>> y_true = np.array([[0, 1], [0, 0]])\n    >>> y_pred = np.array([[-18.6, 0.51], [2.94, -12.8]])\n    >>> # Using default 'auto'/'sum_over_batch_size' reduction type.\n    >>> loss = keras.losses.BinaryFocalCrossentropy(\n    ...     gamma=3, from_logits=True)\n    >>> loss(y_true, y_pred)\n    0.647\n\n    >>> # Apply class weight\n    >>> loss = keras.losses.BinaryFocalCrossentropy(\n    ...     apply_class_balancing=True, gamma=3, from_logits=True)\n    >>> loss(y_true, y_pred)\n    0.482\n\n    >>> # Using 'sample_weight' attribute with focal effect\n    >>> loss = keras.losses.BinaryFocalCrossentropy(\n    ...     gamma=3, from_logits=True)\n    >>> loss(y_true, y_pred, sample_weight=[0.8, 0.2])\n    0.133\n\n    >>> # Apply class weight\n    >>> loss = keras.losses.BinaryFocalCrossentropy(\n    ...     apply_class_balancing=True, gamma=3, from_logits=True)\n    >>> loss(y_true, y_pred, sample_weight=[0.8, 0.2])\n    0.097\n\n    >>> # Using 'sum' reduction type.\n    >>> loss = keras.losses.BinaryFocalCrossentropy(\n    ...     gamma=4, from_logits=True,\n    ...     reduction=\"sum\")\n    >>> loss(y_true, y_pred)\n    1.222\n\n    >>> # Apply class weight\n    >>> loss = keras.losses.BinaryFocalCrossentropy(\n    ...     apply_class_balancing=True, gamma=4, from_logits=True,\n    ...     reduction=\"sum\")\n    >>> loss(y_true, y_pred)\n    0.914\n\n    >>> # Using 'none' reduction type.\n    >>> loss = keras.losses.BinaryFocalCrossentropy(\n    ...     gamma=5, from_logits=True,\n    ...     
reduction=None)\n    >>> loss(y_true, y_pred)\n    array([0.0017, 1.1561], dtype=float32)\n\n    >>> # Apply class weight\n    >>> loss = keras.losses.BinaryFocalCrossentropy(\n    ...     apply_class_balancing=True, gamma=5, from_logits=True,\n    ...     reduction=None)\n    >>> loss(y_true, y_pred)\n    array([0.0004, 0.8670], dtype=float32)\n    \"\"\"\n\n    def __init__(\n        self,\n        apply_class_balancing=False,\n        alpha=0.25,\n        gamma=2.0,\n        from_logits=False,\n        label_smoothing=0.0,\n        axis=-1,\n        reduction=\"sum_over_batch_size\",\n        name=\"binary_focal_crossentropy\",\n        dtype=None,\n    ):\n        super().__init__(\n            binary_focal_crossentropy,\n            name=name,\n            reduction=reduction,\n            dtype=dtype,\n            apply_class_balancing=apply_class_balancing,\n            alpha=alpha,\n            gamma=gamma,\n            from_logits=from_logits,\n            label_smoothing=label_smoothing,\n            axis=axis,\n        )\n        self.from_logits = from_logits\n        self.label_smoothing = label_smoothing\n        self.axis = axis\n        self.apply_class_balancing = apply_class_balancing\n        self.alpha = alpha\n        self.gamma = gamma\n\n    def get_config(self):\n        config = Loss.get_config(self)\n        config.update(\n            {\n                \"from_logits\": self.from_logits,\n                \"label_smoothing\": self.label_smoothing,\n                \"axis\": self.axis,\n                \"apply_class_balancing\": self.apply_class_balancing,\n                \"alpha\": self.alpha,\n                \"gamma\": self.gamma,\n            }\n        )\n        return config\n\n\n@keras_export(\"keras.losses.CategoricalCrossentropy\")\nclass CategoricalCrossentropy(LossFunctionWrapper):\n    \"\"\"Computes the crossentropy loss between the labels and predictions.\n\n    Use this crossentropy loss function when there are two or more 
label\n    classes. We expect labels to be provided in a `one_hot` representation. If\n    you want to provide labels as integers, please use\n    `SparseCategoricalCrossentropy` loss. There should be `num_classes` floating\n    point values per feature, i.e., the shape of both `y_pred` and `y_true` is\n    `[batch_size, num_classes]`.\n\n    Args:\n        from_logits: Whether `y_pred` is expected to be a logits tensor. By\n            default, we assume that `y_pred` encodes a probability distribution.\n        label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,\n            meaning the confidence on label values is relaxed. For example, if\n            `0.1`, use `0.1 / num_classes` for non-target labels and\n            `0.9 + 0.1 / num_classes` for target labels.\n        axis: The axis along which to compute crossentropy (the features\n            axis). Defaults to `-1`.\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to a different value\n            (via `keras.backend.set_floatx()`). 
If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Examples:\n\n    Standalone usage:\n\n    >>> y_true = np.array([[0, 1, 0], [0, 0, 1]])\n    >>> y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])\n    >>> # Using 'auto'/'sum_over_batch_size' reduction type.\n    >>> cce = keras.losses.CategoricalCrossentropy()\n    >>> cce(y_true, y_pred)\n    1.177\n\n    >>> # Calling with 'sample_weight'.\n    >>> cce(y_true, y_pred, sample_weight=np.array([0.3, 0.7]))\n    0.814\n\n    >>> # Using 'sum' reduction type.\n    >>> cce = keras.losses.CategoricalCrossentropy(\n    ...     reduction=\"sum\")\n    >>> cce(y_true, y_pred)\n    2.354\n\n    >>> # Using 'none' reduction type.\n    >>> cce = keras.losses.CategoricalCrossentropy(\n    ...     reduction=None)\n    >>> cce(y_true, y_pred)\n    array([0.0513, 2.303], dtype=float32)\n\n    Usage with the `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss=keras.losses.CategoricalCrossentropy())\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        from_logits=False,\n        label_smoothing=0.0,\n        axis=-1,\n        reduction=\"sum_over_batch_size\",\n        name=\"categorical_crossentropy\",\n        dtype=None,\n    ):\n        super().__init__(\n            categorical_crossentropy,\n            name=name,\n            reduction=reduction,\n            dtype=dtype,\n            from_logits=from_logits,\n            label_smoothing=label_smoothing,\n            axis=axis,\n        )\n        self.from_logits = from_logits\n        self.label_smoothing = label_smoothing\n        self.axis = axis\n\n    def get_config(self):\n        config = Loss.get_config(self)\n        config.update(\n            {\n                \"from_logits\": self.from_logits,\n                \"label_smoothing\": self.label_smoothing,\n                \"axis\": self.axis,\n            }\n        )\n        return 
config\n\n\n@keras_export(\"keras.losses.CategoricalFocalCrossentropy\")\nclass CategoricalFocalCrossentropy(LossFunctionWrapper):\n    \"\"\"Computes the alpha-balanced focal crossentropy loss.\n\n    Use this crossentropy loss function when there are two or more label\n    classes and if you want to handle class imbalance without using\n    `class_weights`. We expect labels to be provided in a `one_hot`\n    representation.\n\n    According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it\n    helps to apply a focal factor to down-weight easy examples and focus more on\n    hard examples. The general formula for the focal loss (FL)\n    is as follows:\n\n    `FL(p_t) = -(1 - p_t) ** gamma * log(p_t)`\n\n    where `p_t` is defined as follows:\n    `p_t = output if y_true == 1, else 1 - output`\n\n    `(1 - p_t) ** gamma` is the `modulating_factor`, where `gamma` is a focusing\n    parameter. When `gamma` = 0, there is no focal effect on the cross entropy.\n    `gamma` reduces the importance given to simple examples in a smooth manner.\n\n    The authors use the alpha-balanced variant of focal loss (FL) in the paper:\n    `FL(p_t) = -alpha * (1 - p_t) ** gamma * log(p_t)`\n\n    where `alpha` is the weight factor for the classes. If `alpha` = 1, the\n    loss won't be able to handle class imbalance properly as all\n    classes will have the same weight. This can be a constant or a list of\n    constants. If alpha is a list, it must have the same length as the number\n    of classes.\n\n    The formula above can be generalized to:\n    `FL(p_t) = alpha * (1 - p_t) ** gamma * CrossEntropy(y_true, y_pred)`\n\n    where the minus sign comes from `CrossEntropy(y_true, y_pred)` (CE).\n\n    Extending this to the multi-class case is straightforward:\n    `FL(p_t) = alpha * (1 - p_t) ** gamma * CategoricalCE(y_true, y_pred)`\n\n    In the snippet below, there are `num_classes` floating-point values per\n    example. 
The shape of both `y_pred` and `y_true` is\n    `(batch_size, num_classes)`.\n\n    Args:\n        alpha: A weight balancing factor for all classes, default is `0.25` as\n            mentioned in the reference. It can be a list of floats or a scalar.\n            In the multi-class case, alpha may be set to the inverse class\n            frequency using `compute_class_weight` from `sklearn.utils`.\n        gamma: A focusing parameter, default is `2.0` as mentioned in the\n            reference. It helps to gradually reduce the importance given to\n            simple (easy) examples in a smooth manner.\n        from_logits: Whether `output` is expected to be a logits tensor. By\n            default, we consider that `output` encodes a probability\n            distribution.\n        label_smoothing: Float in [0, 1]. When > 0, label values are smoothed,\n            meaning the confidence on label values is relaxed. For example, if\n            `0.1`, use `0.1 / num_classes` for non-target labels and\n            `0.9 + 0.1 / num_classes` for target labels.\n        axis: The axis along which to compute crossentropy (the features\n            axis). Defaults to `-1`.\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. 
`keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Examples:\n\n    Standalone usage:\n\n    >>> y_true = [[0., 1., 0.], [0., 0., 1.]]\n    >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]\n    >>> # Using 'auto'/'sum_over_batch_size' reduction type.\n    >>> cce = keras.losses.CategoricalFocalCrossentropy()\n    >>> cce(y_true, y_pred)\n    0.23315276\n\n    >>> # Calling with 'sample_weight'.\n    >>> cce(y_true, y_pred, sample_weight=np.array([0.3, 0.7]))\n    0.1632\n\n    >>> # Using 'sum' reduction type.\n    >>> cce = keras.losses.CategoricalFocalCrossentropy(\n    ...     reduction=\"sum\")\n    >>> cce(y_true, y_pred)\n    0.46631\n\n    >>> # Using 'none' reduction type.\n    >>> cce = keras.losses.CategoricalFocalCrossentropy(\n    ...     reduction=None)\n    >>> cce(y_true, y_pred)\n    array([3.2058331e-05, 4.6627346e-01], dtype=float32)\n\n    Usage with the `compile()` API:\n\n    ```python\n    model.compile(optimizer='adam',\n                  loss=keras.losses.CategoricalFocalCrossentropy())\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        alpha=0.25,\n        gamma=2.0,\n        from_logits=False,\n        label_smoothing=0.0,\n        axis=-1,\n        reduction=\"sum_over_batch_size\",\n        name=\"categorical_focal_crossentropy\",\n        dtype=None,\n    ):\n        \"\"\"Initializes `CategoricalFocalCrossentropy` instance.\"\"\"\n        super().__init__(\n            categorical_focal_crossentropy,\n            name=name,\n            reduction=reduction,\n            dtype=dtype,\n            alpha=alpha,\n            gamma=gamma,\n            from_logits=from_logits,\n            label_smoothing=label_smoothing,\n            axis=axis,\n        )\n        self.from_logits = from_logits\n        self.label_smoothing = label_smoothing\n  
      self.axis = axis\n        self.alpha = alpha\n        self.gamma = gamma\n\n    def get_config(self):\n        config = Loss.get_config(self)\n        config.update(\n            {\n                \"from_logits\": self.from_logits,\n                \"label_smoothing\": self.label_smoothing,\n                \"axis\": self.axis,\n                \"alpha\": self.alpha,\n                \"gamma\": self.gamma,\n            }\n        )\n        return config\n\n\n@keras_export(\"keras.losses.SparseCategoricalCrossentropy\")\nclass SparseCategoricalCrossentropy(LossFunctionWrapper):\n    \"\"\"Computes the crossentropy loss between the labels and predictions.\n\n    Use this crossentropy loss function when there are two or more label\n    classes. We expect labels to be provided as integers. If you want to\n    provide labels using a `one-hot` representation, please use\n    `CategoricalCrossentropy` loss. There should be `num_classes` floating-point\n    values per feature for `y_pred` and a single floating-point value per\n    feature for `y_true`.\n\n    In the snippet below, there is a single floating-point value per example for\n    `y_true` and `num_classes` floating-point values per example for\n    `y_pred`. The shape of `y_true` is `[batch_size]` and the shape of `y_pred`\n    is `[batch_size, num_classes]`.\n\n    Args:\n        from_logits: Whether `y_pred` is expected to be a logits tensor. By\n            default, we assume that `y_pred` encodes a probability distribution.\n        ignore_class: Optional integer. The ID of a class to be ignored during\n            loss computation. This is useful, for example, in segmentation\n            problems featuring a \"void\" class (commonly -1 or 255) in\n            segmentation maps. By default (`ignore_class=None`), all classes\n            are considered.\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. 
`\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        axis: The axis along which to compute crossentropy (the features\n            axis). Defaults to `-1`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Examples:\n\n    >>> y_true = np.array([1, 2])\n    >>> y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])\n    >>> # Using 'auto'/'sum_over_batch_size' reduction type.\n    >>> scce = keras.losses.SparseCategoricalCrossentropy()\n    >>> scce(y_true, y_pred)\n    1.177\n\n    >>> # Calling with 'sample_weight'.\n    >>> scce(y_true, y_pred, sample_weight=np.array([0.3, 0.7]))\n    0.814\n\n    >>> # Using 'sum' reduction type.\n    >>> scce = keras.losses.SparseCategoricalCrossentropy(\n    ...     reduction=\"sum\")\n    >>> scce(y_true, y_pred)\n    2.354\n\n    >>> # Using 'none' reduction type.\n    >>> scce = keras.losses.SparseCategoricalCrossentropy(\n    ...     
reduction=None)\n    >>> scce(y_true, y_pred)\n    array([0.0513, 2.303], dtype=float32)\n\n    Usage with the `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss=keras.losses.SparseCategoricalCrossentropy())\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        from_logits=False,\n        ignore_class=None,\n        reduction=\"sum_over_batch_size\",\n        axis=-1,\n        name=\"sparse_categorical_crossentropy\",\n        dtype=None,\n    ):\n        super().__init__(\n            sparse_categorical_crossentropy,\n            name=name,\n            reduction=reduction,\n            dtype=dtype,\n            from_logits=from_logits,\n            ignore_class=ignore_class,\n            axis=axis,\n        )\n        self.from_logits = from_logits\n        self.ignore_class = ignore_class\n\n    def get_config(self):\n        config = Loss.get_config(self)\n        config.update(\n            {\n                \"from_logits\": self.from_logits,\n                \"ignore_class\": self.ignore_class,\n            }\n        )\n        return config\n\n\n@keras_export(\"keras.losses.CTC\")\nclass CTC(LossFunctionWrapper):\n    \"\"\"CTC (Connectionist Temporal Classification) loss.\n\n    Args:\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. 
Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n    \"\"\"\n\n    def __init__(self, reduction=\"sum_over_batch_size\", name=\"ctc\", dtype=None):\n        super().__init__(ctc, name=name, reduction=reduction, dtype=dtype)\n\n    def get_config(self):\n        return Loss.get_config(self)\n\n\n@keras_export(\"keras.losses.Dice\")\nclass Dice(LossFunctionWrapper):\n    \"\"\"Computes the Dice loss value between `y_true` and `y_pred`.\n\n    Formula:\n    ```python\n    loss = 1 - (2 * sum(y_true * y_pred)) / (sum(y_true) + sum(y_pred))\n    ```\n\n    Args:\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        axis: Tuple for which dimensions the loss is calculated. Defaults to\n            `None`.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). 
If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Returns:\n        Dice loss value.\n\n    Example:\n\n    >>> y_true = [[[[1.0], [1.0]], [[0.0], [0.0]]],\n    ...           [[[1.0], [1.0]], [[0.0], [0.0]]]]\n    >>> y_pred = [[[[0.0], [1.0]], [[0.0], [1.0]]],\n    ...           [[[0.4], [0.0]], [[0.0], [0.9]]]]\n    >>> axis = (1, 2, 3)\n    >>> loss = keras.losses.Dice(axis=axis, reduction=None)(y_true, y_pred)\n    >>> assert loss.shape == (2,)\n    >>> loss\n    array([0.5, 0.75757575], shape=(2,), dtype=float32)\n\n    >>> loss = keras.losses.Dice()(y_true, y_pred)\n    >>> assert loss.shape == ()\n    >>> loss\n    array(0.6164384, shape=(), dtype=float32)\n\n    >>> y_true = np.array(y_true)\n    >>> y_pred = np.array(y_pred)\n    >>> loss = keras.losses.Dice(axis=axis, reduction=None)(y_true, y_pred)\n    >>> assert loss.shape == (2,)\n    >>> loss\n    array([0.5, 0.75757575], shape=(2,), dtype=float32)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        reduction=\"sum_over_batch_size\",\n        name=\"dice\",\n        axis=None,\n        dtype=None,\n    ):\n        super().__init__(\n            dice, name=name, reduction=reduction, dtype=dtype, axis=axis\n        )\n        self.axis = axis\n\n    def get_config(self):\n        config = Loss.get_config(self)\n        config.update({\"axis\": self.axis})\n        return config\n\n\n@keras_export(\"keras.losses.Tversky\")\nclass Tversky(LossFunctionWrapper):\n    \"\"\"Computes the Tversky loss value between `y_true` and `y_pred`.\n\n    This loss function is weighted by the alpha and beta coefficients\n    that penalize false positives and false negatives.\n\n    With `alpha=0.5` and `beta=0.5`, the loss value becomes equivalent to\n    Dice Loss.\n\n    Args:\n        alpha: The coefficient controlling incidence of false positives.\n            Defaults to `0.5`.\n        beta: The coefficient controlling incidence of false negatives.\n        
    Defaults to `0.5`.\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Returns:\n        Tversky loss value.\n\n    Reference:\n\n    - [Salehi et al., 2017](https://arxiv.org/abs/1706.05721)\n    \"\"\"\n\n    def __init__(\n        self,\n        alpha=0.5,\n        beta=0.5,\n        reduction=\"sum_over_batch_size\",\n        name=\"tversky\",\n        axis=None,\n        dtype=None,\n    ):\n        super().__init__(\n            tversky,\n            name=name,\n            reduction=reduction,\n            dtype=dtype,\n            alpha=alpha,\n            beta=beta,\n            axis=axis,\n        )\n        self.alpha = alpha\n        self.beta = beta\n        self.axis = axis\n\n    def get_config(self):\n        config = Loss.get_config(self)\n        config.update(\n            {\"alpha\": self.alpha, \"beta\": self.beta, \"axis\": self.axis}\n        )\n        return config\n\n\n@keras_export(\"keras.losses.Circle\")\nclass Circle(LossFunctionWrapper):\n    \"\"\"Computes Circle 
Loss between integer labels and L2-normalized embeddings.\n\n    This is a metric learning loss designed to minimize within-class distance\n    and maximize between-class distance in a flexible manner by dynamically\n    adjusting the penalty strength based on the optimization status of each\n    similarity score.\n\n    To use Circle Loss effectively, the model should output embeddings without\n    an activation function (such as a `Dense` layer with `activation=None`)\n    followed by a `UnitNormalization` layer to ensure unit-norm embeddings.\n\n    Args:\n        gamma: Scaling factor that determines the largest scale of each\n            similarity score. Defaults to `80`.\n        margin: The relaxation factor. Below this distance, negatives are\n            up-weighted and positives are down-weighted. Similarly, above\n            this distance, negatives are down-weighted and positives are\n            up-weighted. Defaults to `0.4`.\n        remove_diagonal: Boolean, whether to remove self-similarities from the\n            positive mask. Defaults to `True`.\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to a different value\n            (via `keras.backend.set_floatx()`). 
If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Examples:\n\n    Usage with the `compile()` API:\n\n    ```python\n    model = models.Sequential([\n        keras.layers.Input(shape=(224, 224, 3)),\n        keras.layers.Conv2D(16, (3, 3), activation='relu'),\n        keras.layers.Flatten(),\n        keras.layers.Dense(64, activation=None),  # No activation\n        keras.layers.UnitNormalization()  # L2 normalization\n    ])\n\n    model.compile(optimizer=\"adam\", loss=keras.losses.Circle())\n    ```\n\n    Reference:\n    - [Yifan Sun et al., 2020](https://arxiv.org/abs/2002.10857)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        gamma=80.0,\n        margin=0.4,\n        remove_diagonal=True,\n        reduction=\"sum_over_batch_size\",\n        name=\"circle\",\n        dtype=None,\n    ):\n        super().__init__(\n            circle,\n            name=name,\n            reduction=reduction,\n            dtype=dtype,\n            gamma=gamma,\n            margin=margin,\n            remove_diagonal=remove_diagonal,\n        )\n        self.gamma = gamma\n        self.margin = margin\n        self.remove_diagonal = remove_diagonal\n\n    def get_config(self):\n        config = Loss.get_config(self)\n        config.update(\n            {\n                \"gamma\": self.gamma,\n                \"margin\": self.margin,\n                \"remove_diagonal\": self.remove_diagonal,\n            }\n        )\n        return config\n\n\n@keras_export(\"keras.losses.CategoricalGeneralizedCrossEntropy\")\nclass CategoricalGeneralizedCrossEntropy(LossFunctionWrapper):\n    \"\"\"Computes the Generalized Cross Entropy loss between `y_true` & `y_pred`.\n\n    Generalized Cross Entropy (GCE) is a noise-robust loss function\n    that provides better robustness against noisy labels than\n    standard cross entropy.\n    It generalizes both cross entropy and mean absolute error through\n    the parameter q, where 
values closer to 1 make the loss more robust\n    to noisy labels.\n\n    Formula:\n    ```python\n    loss = (1 - p**q) / q\n    ```\n    where `p` is the predicted probability for the true class and `q`\n    is the noise parameter.\n\n    Args:\n        q: Float in range `(0, 1)`. It is the noise parameter.\n           Controls the behavior of the loss:\n            - As `q` approaches 0: Behaves more like cross entropy\n            - As `q` approaches 1: Behaves more like mean absolute error\n           Defaults to `0.5`\n        reduction: Type of reduction to apply to the loss. In almost all cases\n            this should be `\"sum_over_batch_size\"`. Supported options are\n            `\"sum\"`, `\"sum_over_batch_size\"`, `\"mean\"`,\n            `\"mean_with_sample_weight\"` or `None`. `\"sum\"` sums the loss,\n            `\"sum_over_batch_size\"` and `\"mean\"` sum the loss and divide by the\n            sample size, and `\"mean_with_sample_weight\"` sums the loss and\n            divides by the sum of the sample weights. `\"none\"` and `None`\n            perform no aggregation. Defaults to `\"sum_over_batch_size\"`.\n        name: Optional name for the loss instance.\n        dtype: The dtype of the loss's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). 
If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Example:\n    ```python\n    y_true = np.array([0, 1, 0, 1])\n    y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.4, 0.6]])\n    keras.losses.CategoricalGeneralizedCrossEntropy()(y_true, y_pred)\n    ```\n\n    References:\n        - [Zhang, Sabuncu, 2018](https://arxiv.org/abs/1805.07836)\n          (\"Generalized Cross Entropy Loss for Training\n            Deep Neural Networks with Noisy Labels\")\n    \"\"\"\n\n    def __init__(\n        self,\n        q=0.5,\n        reduction=\"sum_over_batch_size\",\n        name=\"categorical_generalized_cross_entropy\",\n        dtype=None,\n    ):\n        if not 0 < q < 1:\n            raise ValueError(\"q must be in the interval (0, 1)\")\n        super().__init__(\n            categorical_generalized_cross_entropy,\n            name=name,\n            reduction=reduction,\n            dtype=dtype,\n            q=q,\n        )\n        self.q = q\n\n    def get_config(self):\n        config = Loss.get_config(self)\n        config.update(\n            {\n                \"q\": self.q,\n            }\n        )\n        return config\n\n\ndef convert_binary_labels_to_hinge(y_true):\n    \"\"\"Converts binary labels into -1/1 for hinge loss/metric calculation.\"\"\"\n    are_zeros = ops.equal(y_true, 0)\n    are_ones = ops.equal(y_true, 1)\n    is_binary = ops.all((ops.logical_or(are_zeros, are_ones)))\n\n    def _convert_binary_labels():\n        # Convert the binary labels to -1 or 1.\n        return 2.0 * y_true - 1.0\n\n    def _return_labels_unconverted():\n        # Returns the labels unchanged if they are non-binary\n        return y_true\n\n    updated_y_true = ops.cond(\n        is_binary, _convert_binary_labels, _return_labels_unconverted\n    )\n    return updated_y_true\n\n\n@keras_export(\n    [\n        \"keras.metrics.hinge\",\n        \"keras.losses.hinge\",\n    ]\n)\ndef hinge(y_true, y_pred):\n  
  \"\"\"Computes the hinge loss between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = mean(maximum(1 - y_true * y_pred, 0), axis=-1)\n    ```\n\n    Args:\n        y_true: The ground truth values. `y_true` values are expected to be -1\n            or 1. If binary (0 or 1) labels are provided they will be converted\n            to -1 or 1 with shape = `[batch_size, d0, .. dN]`.\n        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.\n\n    Returns:\n        Hinge loss values with shape = `[batch_size, d0, .. dN-1]`.\n\n    Example:\n\n    >>> y_true = np.random.choice([-1, 1], size=(2, 3))\n    >>> y_pred = np.random.random(size=(2, 3))\n    >>> loss = keras.losses.hinge(y_true, y_pred)\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.cast(y_true, dtype=y_pred.dtype)\n    y_true = ops.convert_to_tensor(y_true)\n    y_true = convert_binary_labels_to_hinge(y_true)\n    return ops.mean(ops.maximum(1.0 - y_true * y_pred, 0.0), axis=-1)\n\n\n@keras_export(\n    [\n        \"keras.metrics.squared_hinge\",\n        \"keras.losses.squared_hinge\",\n    ]\n)\ndef squared_hinge(y_true, y_pred):\n    \"\"\"Computes the squared hinge loss between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = mean(square(maximum(1 - y_true * y_pred, 0)), axis=-1)\n    ```\n\n    Args:\n        y_true: The ground truth values. `y_true` values are expected to be -1\n            or 1. If binary (0 or 1) labels are provided we will convert them\n            to -1 or 1 with shape = `[batch_size, d0, .. dN]`.\n        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.\n\n    Returns:\n        Squared hinge loss values with shape = `[batch_size, d0, .. 
dN-1]`.\n\n    Example:\n\n    >>> y_true = np.random.choice([-1, 1], size=(2, 3))\n    >>> y_pred = np.random.random(size=(2, 3))\n    >>> loss = keras.losses.squared_hinge(y_true, y_pred)\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.cast(y_true, y_pred.dtype)\n    y_true = convert_binary_labels_to_hinge(y_true)\n    return ops.mean(\n        ops.square(ops.maximum(1.0 - y_true * y_pred, 0.0)), axis=-1\n    )\n\n\n@keras_export(\n    [\n        \"keras.metrics.categorical_hinge\",\n        \"keras.losses.categorical_hinge\",\n    ]\n)\ndef categorical_hinge(y_true, y_pred):\n    \"\"\"Computes the categorical hinge loss between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = maximum(neg - pos + 1, 0)\n    ```\n\n    where `neg=maximum((1-y_true)*y_pred)` and `pos=sum(y_true*y_pred)`\n\n    Args:\n        y_true: The ground truth values. `y_true` values are expected to be\n            either `{-1, +1}` or `{0, 1}` (i.e. a one-hot-encoded tensor) with\n            shape = `[batch_size, d0, .. dN]`.\n        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.\n\n    Returns:\n        Categorical hinge loss values with shape = `[batch_size, d0, .. 
dN-1]`.\n\n    Example:\n\n    >>> y_true = np.random.randint(0, 3, size=(2,))\n    >>> y_true = np.eye(np.max(y_true) + 1)[y_true]\n    >>> y_pred = np.random.random(size=(2, 3))\n    >>> loss = keras.losses.categorical_hinge(y_true, y_pred)\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.cast(y_true, y_pred.dtype)\n    pos = ops.sum(y_true * y_pred, axis=-1)\n    neg = ops.max((1.0 - y_true) * y_pred, axis=-1)\n    zero = ops.cast(0.0, y_pred.dtype)\n    return ops.maximum(neg - pos + 1.0, zero)\n\n\n@keras_export(\n    [\n        \"keras.metrics.mean_squared_error\",\n        \"keras.losses.mean_squared_error\",\n        # Legacy aliases\n        \"keras._legacy.losses.mse\",\n        \"keras._legacy.losses.MSE\",\n        \"keras._legacy.metrics.mse\",\n        \"keras._legacy.metrics.MSE\",\n    ]\n)\ndef mean_squared_error(y_true, y_pred):\n    \"\"\"Computes the mean squared error between labels and predictions.\n\n    Formula:\n\n    ```python\n    loss = mean(square(y_true - y_pred), axis=-1)\n    ```\n\n    Example:\n\n    >>> y_true = np.random.randint(0, 2, size=(2, 3))\n    >>> y_pred = np.random.random(size=(2, 3))\n    >>> loss = keras.losses.mean_squared_error(y_true, y_pred)\n\n    Args:\n        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.\n        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.\n\n    Returns:\n        Mean squared error values with shape = `[batch_size, d0, .. 
dN-1]`.\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n    return ops.mean(ops.square(y_true - y_pred), axis=-1)\n\n\n@keras_export(\n    [\n        \"keras.metrics.mean_absolute_error\",\n        \"keras.losses.mean_absolute_error\",\n        # Legacy aliases\n        \"keras._legacy.losses.MAE\",\n        \"keras._legacy.losses.mae\",\n        \"keras._legacy.metrics.MAE\",\n        \"keras._legacy.metrics.mae\",\n    ]\n)\ndef mean_absolute_error(y_true, y_pred):\n    \"\"\"Computes the mean absolute error between labels and predictions.\n\n    ```python\n    loss = mean(abs(y_true - y_pred), axis=-1)\n    ```\n\n    Args:\n        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.\n        y_pred: The predicted values with shape = `[batch_size, d0, .. dN]`.\n\n    Returns:\n        Mean absolute error values with shape = `[batch_size, d0, .. 
dN-1]`.\n\n    Example:\n\n    >>> y_true = np.random.randint(0, 2, size=(2, 3))\n    >>> y_pred = np.random.random(size=(2, 3))\n    >>> loss = keras.losses.mean_absolute_error(y_true, y_pred)\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n    return ops.mean(ops.abs(y_true - y_pred), axis=-1)\n\n\n@keras_export(\n    [\n        \"keras.metrics.mean_absolute_percentage_error\",\n        \"keras.losses.mean_absolute_percentage_error\",\n        # Legacy aliases\n        \"keras._legacy.losses.mape\",\n        \"keras._legacy.losses.MAPE\",\n        \"keras._legacy.metrics.mape\",\n        \"keras._legacy.metrics.MAPE\",\n    ]\n)\ndef mean_absolute_percentage_error(y_true, y_pred):\n    \"\"\"Computes the mean absolute percentage error between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = 100 * mean(abs((y_true - y_pred) / y_true), axis=-1)\n    ```\n\n    Division by zero is prevented by dividing by `maximum(abs(y_true), epsilon)`\n    where `epsilon = keras.backend.epsilon()`\n    (defaults to `1e-7`).\n\n    Args:\n        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.\n        y_pred: The predicted values with shape = `[batch_size, d0, .. 
dN]`.\n\n    Returns:\n        Mean absolute percentage error values with shape = `[batch_size, d0, ..\n        dN-1]`.\n\n    Example:\n\n    >>> y_true = np.random.random(size=(2, 3))\n    >>> y_pred = np.random.random(size=(2, 3))\n    >>> loss = keras.losses.mean_absolute_percentage_error(y_true, y_pred)\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    epsilon = ops.convert_to_tensor(backend.epsilon(), dtype=y_pred.dtype)\n    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n    diff = ops.abs((y_true - y_pred) / ops.maximum(ops.abs(y_true), epsilon))\n    return 100.0 * ops.mean(diff, axis=-1)\n\n\n@keras_export(\n    [\n        \"keras.metrics.mean_squared_logarithmic_error\",\n        \"keras.losses.mean_squared_logarithmic_error\",\n        # Legacy aliases\n        \"keras._legacy.losses.msle\",\n        \"keras._legacy.losses.MSLE\",\n        \"keras._legacy.metrics.msle\",\n        \"keras._legacy.metrics.MSLE\",\n    ]\n)\ndef mean_squared_logarithmic_error(y_true, y_pred):\n    \"\"\"Computes the mean squared logarithmic error between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = mean(square(log(y_true + 1) - log(y_pred + 1)), axis=-1)\n    ```\n\n    Note that `y_pred` and `y_true` cannot be less than or equal to 0. Negative\n    values and 0 values will be replaced with `keras.backend.epsilon()`\n    (defaults to `1e-7`).\n\n    Args:\n        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.\n        y_pred: The predicted values with shape = `[batch_size, d0, .. 
dN]`.\n\n    Returns:\n        Mean squared logarithmic error values with shape = `[batch_size, d0, ..\n        dN-1]`.\n\n    Example:\n\n    >>> y_true = np.random.randint(0, 2, size=(2, 3))\n    >>> y_pred = np.random.random(size=(2, 3))\n    >>> loss = keras.losses.mean_squared_logarithmic_error(y_true, y_pred)\n    \"\"\"\n    epsilon = ops.convert_to_tensor(backend.epsilon())\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n    first_log = ops.log(ops.maximum(y_pred, epsilon) + 1.0)\n    second_log = ops.log(ops.maximum(y_true, epsilon) + 1.0)\n    return ops.mean(ops.square(first_log - second_log), axis=-1)\n\n\n@keras_export(\"keras.losses.cosine_similarity\")\ndef cosine_similarity(y_true, y_pred, axis=-1):\n    \"\"\"Computes the cosine similarity between labels and predictions.\n\n    Formula:\n    ```python\n    loss = -sum(l2_norm(y_true) * l2_norm(y_pred))\n    ```\n\n    Note that it is a number between -1 and 1. When it is a negative number\n    between -1 and 0, 0 indicates orthogonality and values closer to -1\n    indicate greater similarity. This makes it usable as a loss function in a\n    setting where you try to maximize the proximity between predictions and\n    targets. If either `y_true` or `y_pred` is a zero vector, cosine\n    similarity will be 0 regardless of the proximity between predictions\n    and targets.\n\n    Args:\n        y_true: Tensor of true targets.\n        y_pred: Tensor of predicted targets.\n        axis: Axis along which to determine similarity. 
Defaults to `-1`.\n\n    Returns:\n        Cosine similarity tensor.\n\n    Example:\n\n    >>> y_true = [[0., 1.], [1., 1.], [1., 1.]]\n    >>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]]\n    >>> loss = keras.losses.cosine_similarity(y_true, y_pred, axis=-1)\n    >>> loss\n    [-0., -0.99999994, 0.99999994]\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n    y_pred = normalize(y_pred, axis=axis)\n    y_true = normalize(y_true, axis=axis)\n    return -ops.sum(y_true * y_pred, axis=axis)\n\n\n@keras_export([\"keras.losses.huber\", \"keras.metrics.huber\"])\ndef huber(y_true, y_pred, delta=1.0):\n    \"\"\"Computes Huber loss value.\n\n    Formula:\n    ```python\n    for x in error:\n        if abs(x) <= delta:\n            loss.append(0.5 * x^2)\n        elif abs(x) > delta:\n            loss.append(delta * abs(x) - 0.5 * delta^2)\n\n    loss = mean(loss, axis=-1)\n    ```\n    See: [Huber loss](https://en.wikipedia.org/wiki/Huber_loss).\n\n    Example:\n\n    >>> y_true = [[0, 1], [0, 0]]\n    >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]\n    >>> loss = keras.losses.huber(y_true, y_pred)\n    >>> assert loss.shape == (2,)\n\n    Args:\n        y_true: tensor of true targets.\n        y_pred: tensor of predicted targets.\n        delta: A float, the point where the Huber loss function changes from a\n            quadratic to linear. 
Defaults to `1.0`.\n\n    Returns:\n        Tensor with one scalar loss entry per sample.\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n    delta = ops.convert_to_tensor(delta, dtype=y_pred.dtype)\n    error = ops.subtract(y_pred, y_true)\n    abs_error = ops.abs(error)\n    half = ops.convert_to_tensor(0.5, dtype=abs_error.dtype)\n    return ops.mean(\n        ops.where(\n            abs_error <= delta,\n            half * ops.square(error),\n            delta * abs_error - half * ops.square(delta),\n        ),\n        axis=-1,\n    )\n\n\n@keras_export(\n    [\n        \"keras.losses.log_cosh\",\n        \"keras.metrics.log_cosh\",\n        # Legacy aliases\n        \"keras._legacy.losses.logcosh\",\n        \"keras._legacy.metrics.logcosh\",\n    ]\n)\ndef log_cosh(y_true, y_pred):\n    \"\"\"Logarithm of the hyperbolic cosine of the prediction error.\n\n    Formula:\n    ```python\n    loss = mean(log(cosh(y_pred - y_true)), axis=-1)\n    ```\n\n    Note that `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small\n    `x` and to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works\n    mostly like the mean squared error, but will not be so strongly affected by\n    the occasional wildly incorrect prediction.\n\n    Example:\n\n    >>> y_true = [[0., 1.], [0., 0.]]\n    >>> y_pred = [[1., 1.], [0., 0.]]\n    >>> loss = keras.losses.log_cosh(y_true, y_pred)\n    >>> assert loss.shape == (2,)\n\n    Args:\n        y_true: Ground truth values with shape = `[batch_size, d0, .. dN]`.\n        y_pred: The predicted values with shape = `[batch_size, d0, .. 
dN-1]`.\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n    log2 = ops.convert_to_tensor(ops.log(2.0), dtype=y_pred.dtype)\n\n    def _logcosh(x):\n        return x + ops.softplus(x * -2.0) - log2\n\n    return ops.mean(_logcosh(y_pred - y_true), axis=-1)\n\n\n@keras_export(\n    [\n        \"keras.metrics.kl_divergence\",\n        \"keras.losses.kl_divergence\",\n        # Legacy aliases\n        \"keras._legacy.losses.KLD\",\n        \"keras._legacy.losses.kld\",\n        \"keras._legacy.losses.kullback_leibler_divergence\",\n        \"keras._legacy.metrics.KLD\",\n        \"keras._legacy.metrics.kld\",\n        \"keras._legacy.metrics.kullback_leibler_divergence\",\n    ]\n)\ndef kl_divergence(y_true, y_pred):\n    \"\"\"Computes Kullback-Leibler divergence loss between `y_true` & `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = y_true * log(y_true / y_pred)\n    ```\n\n    `y_true` and `y_pred` are expected to be probability\n    distributions, with values between 0 and 1. They will get\n    clipped to the `[0, 1]` range.\n\n    Args:\n        y_true: Tensor of true targets.\n        y_pred: Tensor of predicted targets.\n\n    Returns:\n        KL Divergence loss values with shape = `[batch_size, d0, .. dN-1]`.\n\n    Example:\n\n    >>> y_true = np.random.randint(0, 2, size=(2, 3)).astype(np.float32)\n    >>> y_pred = np.random.random(size=(2, 3))\n    >>> loss = keras.losses.kl_divergence(y_true, y_pred)\n    >>> assert loss.shape == (2,)\n    >>> y_true = ops.clip(y_true, 1e-7, 1)\n    >>> y_pred = ops.clip(y_pred, 1e-7, 1)\n    >>> assert np.array_equal(\n    ...     
loss, np.sum(y_true * np.log(y_true / y_pred), axis=-1))\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, y_pred.dtype)\n    y_true = ops.clip(y_true, backend.epsilon(), 1)\n    y_pred = ops.clip(y_pred, backend.epsilon(), 1)\n    return ops.sum(y_true * ops.log(y_true / y_pred), axis=-1)\n\n\n@keras_export(\n    [\n        \"keras.metrics.poisson\",\n        \"keras.losses.poisson\",\n    ]\n)\ndef poisson(y_true, y_pred):\n    \"\"\"Computes the Poisson loss between y_true and y_pred.\n\n    Formula:\n\n    ```python\n    loss = y_pred - y_true * log(y_pred)\n    ```\n\n    Args:\n        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.\n        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.\n\n    Returns:\n        Poisson loss values with shape = `[batch_size, d0, .. dN-1]`.\n\n    Example:\n\n    >>> y_true = np.random.randint(0, 2, size=(2, 3))\n    >>> y_pred = np.random.random(size=(2, 3))\n    >>> loss = keras.losses.poisson(y_true, y_pred)\n    >>> assert loss.shape == (2,)\n    >>> y_pred = y_pred + 1e-7\n    >>> assert np.allclose(\n    ...     loss, np.mean(y_pred - y_true * np.log(y_pred), axis=-1),\n    ...     atol=1e-5)\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    epsilon = ops.convert_to_tensor(backend.epsilon(), dtype=y_pred.dtype)\n    return ops.mean(y_pred - y_true * ops.log(y_pred + epsilon), axis=-1)\n\n\n@keras_export(\n    [\n        \"keras.metrics.categorical_crossentropy\",\n        \"keras.losses.categorical_crossentropy\",\n    ]\n)\ndef categorical_crossentropy(\n    y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1\n):\n    \"\"\"Computes the categorical crossentropy loss.\n\n    Args:\n        y_true: Tensor of one-hot true targets.\n        y_pred: Tensor of predicted targets.\n        from_logits: Whether `y_pred` is expected to be a logits tensor. 
By\n            default, we assume that `y_pred` encodes a probability distribution.\n        label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For\n            example, if `0.1`, use `0.1 / num_classes` for non-target labels\n            and `0.9 + 0.1 / num_classes` for target labels.\n        axis: Defaults to `-1`. The dimension along which the entropy is\n            computed.\n\n    Returns:\n        Categorical crossentropy loss value.\n\n    Example:\n\n    >>> y_true = [[0, 1, 0], [0, 0, 1]]\n    >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]\n    >>> loss = keras.losses.categorical_crossentropy(y_true, y_pred)\n    >>> assert loss.shape == (2,)\n    >>> loss\n    array([0.0513, 2.303], dtype=float32)\n    \"\"\"\n    if isinstance(axis, bool):\n        raise ValueError(\n            \"`axis` must be of type `int`. \"\n            f\"Received: axis={axis} of type {type(axis)}\"\n        )\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.cast(y_true, y_pred.dtype)\n\n    if y_pred.shape[-1] == 1:\n        warnings.warn(\n            \"In loss categorical_crossentropy, expected \"\n            \"y_pred.shape to be (batch_size, num_classes) \"\n            f\"with num_classes > 1. Received: y_pred.shape={y_pred.shape}. 
\"\n            \"Consider using 'binary_crossentropy' if you only have 2 classes.\",\n            SyntaxWarning,\n            stacklevel=2,\n        )\n\n    if label_smoothing:\n        num_classes = ops.cast(ops.shape(y_true)[-1], y_pred.dtype)\n        y_true = y_true * (1.0 - label_smoothing) + (\n            label_smoothing / num_classes\n        )\n\n    return ops.categorical_crossentropy(\n        y_true, y_pred, from_logits=from_logits, axis=axis\n    )\n\n\n@keras_export(\n    [\n        \"keras.metrics.categorical_focal_crossentropy\",\n        \"keras.losses.categorical_focal_crossentropy\",\n    ]\n)\ndef categorical_focal_crossentropy(\n    y_true,\n    y_pred,\n    alpha=0.25,\n    gamma=2.0,\n    from_logits=False,\n    label_smoothing=0.0,\n    axis=-1,\n):\n    \"\"\"Computes the categorical focal crossentropy loss.\n\n    Args:\n        y_true: Tensor of one-hot true targets.\n        y_pred: Tensor of predicted targets.\n        alpha: A weight balancing factor for all classes, default is `0.25` as\n            mentioned in the reference. It can be a list of floats or a scalar.\n            In the multi-class case, alpha may be set by inverse class\n            frequency by using `compute_class_weight` from `sklearn.utils`.\n        gamma: A focusing parameter, default is `2.0` as mentioned in the\n            reference. It helps to gradually reduce the importance given to\n            simple examples in a smooth manner. When `gamma` = 0, there is\n            no focal effect on the categorical crossentropy.\n        from_logits: Whether `y_pred` is expected to be a logits tensor. By\n            default, we assume that `y_pred` encodes a probability\n            distribution.\n        label_smoothing: Float in [0, 1]. If > `0` then smooth the labels. For\n            example, if `0.1`, use `0.1 / num_classes` for non-target labels\n            and `0.9 + 0.1 / num_classes` for target labels.\n        axis: Defaults to `-1`. 
The dimension along which the entropy is\n            computed.\n\n    Returns:\n        Categorical focal crossentropy loss value.\n\n    Example:\n\n    >>> y_true = [[0, 1, 0], [0, 0, 1]]\n    >>> y_pred = [[0.05, 0.9, 0.05], [0.1, 0.85, 0.05]]\n    >>> loss = keras.losses.categorical_focal_crossentropy(y_true, y_pred)\n    >>> assert loss.shape == (2,)\n    >>> loss\n    array([2.63401289e-04, 6.75912094e-01], dtype=float32)\n    \"\"\"\n    if isinstance(axis, bool):\n        raise ValueError(\n            \"`axis` must be of type `int`. \"\n            f\"Received: axis={axis} of type {type(axis)}\"\n        )\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.cast(y_true, y_pred.dtype)\n\n    if y_pred.shape[-1] == 1:\n        warnings.warn(\n            \"In loss categorical_focal_crossentropy, expected \"\n            \"y_pred.shape to be (batch_size, num_classes) \"\n            f\"with num_classes > 1. Received: y_pred.shape={y_pred.shape}. \"\n            \"Consider using 'binary_crossentropy' if you only have 2 classes.\",\n            SyntaxWarning,\n            stacklevel=2,\n        )\n\n    if label_smoothing:\n        num_classes = ops.cast(ops.shape(y_true)[-1], y_pred.dtype)\n        y_true = y_true * (1.0 - label_smoothing) + (\n            label_smoothing / num_classes\n        )\n\n    if from_logits:\n        y_pred = ops.softmax(y_pred, axis=axis)\n\n    # Adjust the predictions so that the probability of\n    # each class for every sample adds up to 1\n    # This is needed to ensure that the cross entropy is\n    # computed correctly.\n    output = y_pred / ops.sum(y_pred, axis=axis, keepdims=True)\n    output = ops.clip(output, backend.epsilon(), 1.0 - backend.epsilon())\n\n    # Calculate cross entropy\n    cce = -y_true * ops.log(output)\n\n    # Calculate factors\n    modulating_factor = ops.power(1.0 - output, gamma)\n    weighting_factor = ops.multiply(modulating_factor, alpha)\n\n    # Apply weighting factor\n    
focal_cce = ops.multiply(weighting_factor, cce)\n    focal_cce = ops.sum(focal_cce, axis=axis)\n    return focal_cce\n\n\n@keras_export(\n    [\n        \"keras.metrics.sparse_categorical_crossentropy\",\n        \"keras.losses.sparse_categorical_crossentropy\",\n    ]\n)\ndef sparse_categorical_crossentropy(\n    y_true, y_pred, from_logits=False, ignore_class=None, axis=-1\n):\n    \"\"\"Computes the sparse categorical crossentropy loss.\n\n    Args:\n        y_true: Ground truth values.\n        y_pred: The predicted values.\n        from_logits: Whether `y_pred` is expected to be a logits tensor. By\n            default, we assume that `y_pred` encodes a probability distribution.\n        ignore_class: Optional integer. The ID of a class to be ignored during\n            loss computation. This is useful, for example, in segmentation\n            problems featuring a \"void\" class (commonly -1 or 255) in\n            segmentation maps. By default (`ignore_class=None`), all classes are\n            considered.\n        axis: Defaults to `-1`. 
The dimension along which the entropy is\n            computed.\n\n    Returns:\n        Sparse categorical crossentropy loss value.\n\n    Examples:\n\n    >>> y_true = [1, 2]\n    >>> y_pred = [[0.05, 0.95, 0], [0.1, 0.8, 0.1]]\n    >>> loss = keras.losses.sparse_categorical_crossentropy(y_true, y_pred)\n    >>> assert loss.shape == (2,)\n    >>> loss\n    array([0.0513, 2.303], dtype=float32)\n    \"\"\"\n\n    if len(y_true.shape) == len(y_pred.shape) and y_true.shape[-1] == 1:\n        y_true = ops.squeeze(y_true, axis=-1)\n\n    if ignore_class is not None:\n        res_shape = ops.shape(y_pred)[:-1]\n        valid_mask = ops.not_equal(y_true, ops.cast(ignore_class, y_pred.dtype))\n        y_true = ops.multiply(y_true, ops.cast(valid_mask, y_true.dtype))\n        y_pred = ops.multiply(\n            y_pred,\n            ops.cast(ops.expand_dims(valid_mask, -1), y_pred.dtype),\n        )\n\n    res = ops.sparse_categorical_crossentropy(\n        y_true,\n        y_pred,\n        from_logits=from_logits,\n        axis=axis,\n    )\n\n    if ignore_class is not None:\n        valid_mask = ops.reshape(valid_mask, res_shape)\n        res = ops.where(valid_mask, res, 0.0)\n        backend.set_keras_mask(res, mask=valid_mask)\n\n    return res\n\n\n@keras_export(\n    [\n        \"keras.metrics.binary_crossentropy\",\n        \"keras.losses.binary_crossentropy\",\n    ]\n)\ndef binary_crossentropy(\n    y_true, y_pred, from_logits=False, label_smoothing=0.0, axis=-1\n):\n    \"\"\"Computes the binary crossentropy loss.\n\n    Args:\n        y_true: Ground truth values. shape = `[batch_size, d0, .. dN]`.\n        y_pred: The predicted values. shape = `[batch_size, d0, .. dN]`.\n        from_logits: Whether `y_pred` is expected to be a logits tensor. By\n            default, we assume that `y_pred` encodes a probability distribution.\n        label_smoothing: Float in `[0, 1]`. 
If > `0` then smooth the labels by\n            squeezing them towards 0.5, that is,\n            using `1. - 0.5 * label_smoothing` for the target class\n            and `0.5 * label_smoothing` for the non-target class.\n        axis: The axis along which the mean is computed. Defaults to `-1`.\n\n    Returns:\n        Binary crossentropy loss value. shape = `[batch_size, d0, .. dN-1]`.\n\n    Example:\n\n    >>> y_true = [[0, 1], [0, 0]]\n    >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]\n    >>> loss = keras.losses.binary_crossentropy(y_true, y_pred)\n    >>> assert loss.shape == (2,)\n    >>> loss\n    array([0.916 , 0.714], dtype=float32)\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.cast(y_true, y_pred.dtype)\n\n    if label_smoothing:\n        y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing\n\n    return ops.mean(\n        ops.binary_crossentropy(y_true, y_pred, from_logits=from_logits),\n        axis=axis,\n    )\n\n\n@keras_export(\n    [\n        \"keras.metrics.binary_focal_crossentropy\",\n        \"keras.losses.binary_focal_crossentropy\",\n    ]\n)\ndef binary_focal_crossentropy(\n    y_true,\n    y_pred,\n    apply_class_balancing=False,\n    alpha=0.25,\n    gamma=2.0,\n    from_logits=False,\n    label_smoothing=0.0,\n    axis=-1,\n):\n    \"\"\"Computes the binary focal crossentropy loss.\n\n    According to [Lin et al., 2018](https://arxiv.org/pdf/1708.02002.pdf), it\n    helps to apply a focal factor to down-weight easy examples and focus more on\n    hard examples. By default, the focal tensor is computed as follows:\n\n    `focal_factor = (1 - output) ** gamma` for class 1\n    `focal_factor = output ** gamma` for class 0\n    where `gamma` is a focusing parameter. 
When `gamma` = 0, there is no focal\n    effect on the binary crossentropy loss.\n\n    If `apply_class_balancing == True`, this function also takes into account a\n    weight balancing factor for the binary classes 0 and 1 as follows:\n\n    `weight = alpha` for class 1 (`target == 1`)\n    `weight = 1 - alpha` for class 0\n    where `alpha` is a float in the range of `[0, 1]`.\n\n    Args:\n        y_true: Ground truth values, of shape `(batch_size, d0, .. dN)`.\n        y_pred: The predicted values, of shape `(batch_size, d0, .. dN)`.\n        apply_class_balancing: A bool, whether to apply weight balancing on the\n            binary classes 0 and 1.\n        alpha: A weight balancing factor for class 1, default is `0.25` as\n            mentioned in the reference. The weight for class 0 is `1.0 - alpha`.\n        gamma: A focusing parameter, default is `2.0` as mentioned in the\n            reference.\n        from_logits: Whether `y_pred` is expected to be a logits tensor. By\n            default, we assume that `y_pred` encodes a probability distribution.\n        label_smoothing: Float in `[0, 1]`. If > `0` then smooth the labels by\n            squeezing them towards 0.5, that is,\n            using `1. - 0.5 * label_smoothing` for the target class\n            and `0.5 * label_smoothing` for the non-target class.\n        axis: The axis along which the mean is computed. Defaults to `-1`.\n\n    Returns:\n        Binary focal crossentropy loss value\n        with shape = `[batch_size, d0, .. dN-1]`.\n\n    Example:\n\n    >>> y_true = [[0, 1], [0, 0]]\n    >>> y_pred = [[0.6, 0.4], [0.4, 0.6]]\n    >>> # In this instance, the first sample in the second batch is the\n    >>> # 'easier' example.\n    >>> focal_loss = keras.losses.binary_focal_crossentropy(\n    ...        
y_true, y_pred, gamma=2)\n    >>> assert focal_loss.shape == (2,)\n    >>> focal_loss\n    array([0.330, 0.206], dtype=float32)\n    >>> # Compare with binary_crossentropy\n    >>> bce_loss = keras.losses.binary_crossentropy(\n    ...        y_true, y_pred)\n    >>> bce_loss\n    array([0.916, 0.714], dtype=float32)\n    >>> # Binary focal crossentropy loss attributes more importance to the\n    >>> # harder example which results in a higher loss for the first batch\n    >>> # when normalized by binary cross entropy loss\n    >>> focal_loss / bce_loss\n    array([0.360, 0.289], dtype=float32)\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.cast(y_true, y_pred.dtype)\n\n    if label_smoothing:\n        y_true = y_true * (1.0 - label_smoothing) + 0.5 * label_smoothing\n\n    if from_logits:\n        y_pred = ops.sigmoid(y_pred)\n\n    bce = ops.binary_crossentropy(\n        target=y_true,\n        output=y_pred,\n        from_logits=False,\n    )\n\n    # Calculate focal factor\n    p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)\n    focal_factor = ops.power(1.0 - p_t, gamma)\n\n    focal_bce = focal_factor * bce\n\n    if apply_class_balancing:\n        weight = y_true * alpha + (1 - y_true) * (1 - alpha)\n        focal_bce = weight * focal_bce\n\n    return ops.mean(focal_bce, axis=axis)\n\n\n@keras_export(\"keras.losses.ctc\")\ndef ctc(y_true, y_pred):\n    \"\"\"CTC (Connectionist Temporal Classification) loss.\n\n    Args:\n        y_true: A tensor of shape `(batch_size, max_length)` containing\n            the true labels in integer format. 
`0` always represents\n            the blank/mask index and should not be used for classes.\n        y_pred: A tensor of shape `(batch_size, max_length, num_classes)`\n            containing logits (the output of your model).\n            They should *not* be normalized via softmax.\n    \"\"\"\n    if len(ops.shape(y_true)) != 2:\n        raise ValueError(\n            \"Targets `y_true` are expected to be a tensor of shape \"\n            \"`(batch_size, max_length)` in integer format. \"\n            f\"Received: y_true.shape={ops.shape(y_true)}\"\n        )\n    if len(ops.shape(y_pred)) != 3:\n        raise ValueError(\n            \"Logits `y_pred` are expected to be a tensor of shape \"\n            \"`(batch_size, max_length, num_classes)`. \"\n            f\"Received: y_pred.shape={ops.shape(y_pred)}\"\n        )\n\n    mask_index = 0\n    batch_length = ops.shape(y_pred)[0]\n    input_length = ops.shape(y_pred)[1]\n    input_length = input_length * ops.ones((batch_length,), dtype=\"int32\")\n    label_length = ops.cast(\n        ops.sum(y_true != mask_index, axis=-1), dtype=\"int32\"\n    )\n\n    return ops.ctc_loss(\n        y_true, y_pred, label_length, input_length, mask_index=mask_index\n    )\n\n\n@keras_export(\"keras.losses.dice\")\ndef dice(y_true, y_pred, axis=None):\n    \"\"\"Computes the Dice loss value between `y_true` and `y_pred`.\n\n    Formula:\n    ```python\n    loss = 1 - (2 * sum(y_true * y_pred)) / (sum(y_true) + sum(y_pred))\n    ```\n\n    Args:\n        y_true: tensor of true targets.\n        y_pred: tensor of predicted targets.\n        axis: tuple for which dimensions the loss is calculated\n\n    Returns:\n        Dice loss value.\n\n    Example:\n\n    >>> y_true = [[[[1.0], [1.0]], [[0.0], [0.0]]],\n    ...           [[[1.0], [1.0]], [[0.0], [0.0]]]]\n    >>> y_pred = [[[[0.0], [1.0]], [[0.0], [1.0]]],\n    ...           
[[[0.4], [0.0]], [[0.0], [0.9]]]]\n    >>> axis = (1, 2, 3)\n    >>> loss = keras.losses.dice(y_true, y_pred, axis=axis)\n    >>> assert loss.shape == (2,)\n    >>> loss\n    array([0.5, 0.75757575], shape=(2,), dtype=float32)\n\n    >>> loss = keras.losses.dice(y_true, y_pred)\n    >>> assert loss.shape == ()\n    >>> loss\n    array(0.6164384, shape=(), dtype=float32)\n\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.cast(y_true, y_pred.dtype)\n\n    inputs = y_true\n    targets = y_pred\n\n    intersection = ops.sum(inputs * targets, axis=axis)\n    dice = ops.divide(\n        2.0 * intersection,\n        ops.sum(y_true, axis=axis)\n        + ops.sum(y_pred, axis=axis)\n        + backend.epsilon(),\n    )\n\n    return 1 - dice\n\n\n@keras_export(\"keras.losses.tversky\")\ndef tversky(y_true, y_pred, alpha=0.5, beta=0.5, axis=None):\n    \"\"\"Computes the Tversky loss value between `y_true` and `y_pred`.\n\n    This loss function is weighted by the alpha and beta coefficients\n    that penalize false positives and false negatives.\n\n    With `alpha=0.5` and `beta=0.5`, the loss value becomes equivalent to\n    Dice Loss.\n\n    Args:\n        y_true: tensor of true targets.\n        y_pred: tensor of predicted targets.\n        alpha: coefficient controlling incidence of false positives.\n        beta: coefficient controlling incidence of false negatives.\n        axis: tuple for which dimensions the loss is calculated.\n\n    Returns:\n        Tversky loss value.\n\n    Reference:\n\n    - [Salehi et al., 2017](https://arxiv.org/abs/1706.05721)\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.cast(y_true, y_pred.dtype)\n\n    inputs = y_true\n    targets = y_pred\n\n    intersection = ops.sum(inputs * targets, axis=axis)\n    fp = ops.sum((1 - targets) * inputs, axis=axis)\n    fn = ops.sum(targets * (1 - inputs), axis=axis)\n\n    tversky = ops.divide(\n        intersection,\n        intersection + fp * 
alpha + fn * beta + backend.epsilon(),\n    )\n\n    return 1 - tversky\n\n\n@keras_export(\"keras.losses.circle\")\ndef circle(\n    y_true,\n    y_pred,\n    ref_labels=None,\n    ref_embeddings=None,\n    remove_diagonal=True,\n    gamma=80,\n    margin=0.4,\n):\n    \"\"\"Computes the Circle loss.\n\n    It is designed to minimize within-class distances and maximize between-class\n    distances in L2 normalized embedding space.\n\n    Args:\n        y_true: Tensor with ground truth labels in integer format.\n        y_pred: Tensor with predicted L2 normalized embeddings.\n        ref_labels: Optional integer tensor with labels for reference\n            embeddings. If `None`, defaults to `y_true`.\n        ref_embeddings: Optional tensor with L2 normalized reference embeddings.\n            If `None`, defaults to `y_pred`.\n        remove_diagonal: Boolean, whether to remove self-similarities from\n            positive mask. Defaults to `True`.\n        gamma: Float, scaling factor for the loss. Defaults to `80`.\n        margin: Float, relaxation factor for the loss. 
Defaults to `0.4`.\n\n    Returns:\n        Circle loss value.\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.cast(y_true, \"int32\")\n    ref_embeddings = (\n        y_pred\n        if ref_embeddings is None\n        else ops.convert_to_tensor(ref_embeddings)\n    )\n    ref_labels = y_true if ref_labels is None else ops.cast(ref_labels, \"int32\")\n\n    optim_pos = margin\n    optim_neg = 1 + margin\n    delta_pos = margin\n    delta_neg = 1 - margin\n\n    pairwise_cosine_distances = 1 - ops.matmul(\n        y_pred, ops.transpose(ref_embeddings)\n    )\n\n    pairwise_cosine_distances = ops.maximum(pairwise_cosine_distances, 0.0)\n    positive_mask, negative_mask = build_pos_neg_masks(\n        y_true,\n        ref_labels,\n        remove_diagonal=remove_diagonal,\n    )\n    positive_mask = ops.cast(\n        positive_mask, dtype=pairwise_cosine_distances.dtype\n    )\n    negative_mask = ops.cast(\n        negative_mask, dtype=pairwise_cosine_distances.dtype\n    )\n\n    pos_weights = optim_pos + pairwise_cosine_distances\n    pos_weights = pos_weights * positive_mask\n    pos_weights = ops.maximum(pos_weights, 0.0)\n    neg_weights = optim_neg - pairwise_cosine_distances\n    neg_weights = neg_weights * negative_mask\n    neg_weights = ops.maximum(neg_weights, 0.0)\n\n    pos_dists = delta_pos - pairwise_cosine_distances\n    neg_dists = delta_neg - pairwise_cosine_distances\n\n    pos_wdists = -1 * gamma * pos_weights * pos_dists\n    neg_wdists = gamma * neg_weights * neg_dists\n\n    p_loss = ops.logsumexp(\n        ops.where(positive_mask, pos_wdists, float(\"-inf\")),\n        axis=1,\n    )\n    n_loss = ops.logsumexp(\n        ops.where(negative_mask, neg_wdists, float(\"-inf\")),\n        axis=1,\n    )\n\n    circle_loss = ops.softplus(p_loss + n_loss)\n    backend.set_keras_mask(circle_loss, circle_loss > 0)\n    return circle_loss\n\n\n@keras_export(\"keras.losses.categorical_generalized_cross_entropy\")\ndef 
categorical_generalized_cross_entropy(y_true, y_pred, q):\n    \"\"\"Computes the Generalized Cross Entropy loss.\n\n    Generalized Cross Entropy (GCE) is a loss function that is more robust\n    to noisy labels than standard cross entropy.\n    It generalizes both cross entropy and mean absolute error through\n    the parameter `q`, where values closer to 1 make the loss more robust\n    to noisy labels.\n\n    Formula:\n    ```python\n    loss = (1 - p**q) / q\n    ```\n    where `p` is the predicted probability for the true class and `q`\n    is the noise parameter.\n\n    Args:\n        y_true: Ground truth labels. Expected to contain *integer class indices*\n            with shape `[batch_size]` or `[batch_size, 1]`.\n        y_pred: The predicted class probabilities, with shape\n            `[batch_size, num_classes]`.\n        q: Float in range `(0, 1)`. It is the noise parameter.\n            Controls the behavior of the loss:\n            - As `q` approaches 0: Behaves more like cross entropy\n            - As `q` approaches 1: Behaves more like mean absolute error\n\n    Returns:\n        GCE loss values with shape `[batch_size]`.\n\n    References:\n        - [Zhang, Sabuncu, 2018](https://arxiv.org/abs/1805.07836)\n          (\"Generalized Cross Entropy Loss for Training\n            Deep Neural Networks with Noisy Labels\")\n    \"\"\"\n\n    # Convert y_true to integer type and one-hot encode\n    y_true_one_hot = ops.one_hot(\n        ops.cast(y_true, \"int32\"), num_classes=ops.shape(y_pred)[-1]\n    )\n    y_true_one_hot = ops.cast(y_true_one_hot, y_pred.dtype)\n    # Calculate the probability of the true class\n    p = ops.sum(y_pred * y_true_one_hot, axis=-1)\n\n    # Compute the GCE loss for q in (0,1)\n    gce_loss = (1 - ops.power(p, q)) / q\n\n    return gce_loss\n"
  },
  {
    "path": "keras/src/losses/losses_test.py",
    "content": "import re\n\nimport numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.losses import losses\n\n\nclass MeanSquaredErrorTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(losses.MeanSquaredError(name=\"mymse\"))\n\n    def test_base_function_reduction(self):\n        mse_fn = losses.mean_squared_error\n        y_true = np.array([4, 8, 12])\n        y_pred = np.array([[3], [0], [1]])\n        loss = mse_fn(y_true, y_pred)\n        self.assertEqual(backend.shape(loss), (3,))\n\n    def test_all_correct_unweighted(self):\n        mse_obj = losses.MeanSquaredError()\n        y_true = np.array([[4, 8, 12], [8, 1, 3]])\n        loss = mse_obj(y_true, y_true)\n        self.assertAlmostEqual(loss, 0.0)\n\n    def test_unweighted(self):\n        mse_obj = losses.MeanSquaredError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mse_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 49.5)\n\n    def test_scalar_weighted(self):\n        mse_obj = losses.MeanSquaredError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mse_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 113.85)\n\n    def test_sample_weighted(self):\n        mse_obj = losses.MeanSquaredError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        sample_weight = np.array([[1.2], [3.4]])\n        loss = mse_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 767.8 / 6)\n\n    def test_timestep_weighted(self):\n        mse_obj = losses.MeanSquaredError()\n        y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)\n        y_pred = np.asarray([4, 8, 12, 8, 1, 
3]).reshape(2, 3, 1)\n        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))\n        loss = mse_obj(\n            y_true,\n            y_pred,\n            sample_weight=sample_weight,\n        )\n        self.assertAlmostEqual(loss, 97.833336)\n\n    def test_zero_weighted(self):\n        mse_obj = losses.MeanSquaredError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mse_obj(y_true, y_pred, sample_weight=0)\n        self.assertAlmostEqual(loss, 0.0)\n\n    def test_no_reduction(self):\n        mse_obj = losses.MeanSquaredError(reduction=None)\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mse_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, [84.3333, 143.3666])\n\n    def test_sum_reduction(self):\n        mse_obj = losses.MeanSquaredError(reduction=\"sum\")\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mse_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 227.69998)\n\n    def test_mean_with_sample_weight_reduction(self):\n        mse_obj = losses.MeanSquaredError(reduction=\"mean_with_sample_weight\")\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        sample_weight = np.array([[1.2], [3.4]])\n        loss = mse_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(\n            loss, (110 / 3 * 1.2 + 187 / 3 * 3.4) / (1.2 + 3.4)\n        )\n\n    def test_dtype_arg(self):\n        mse_obj = losses.MeanSquaredError(dtype=\"bfloat16\")\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mse_obj(y_true, y_pred)\n        
self.assertDType(loss, \"bfloat16\")\n\n\nclass MeanAbsoluteErrorTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            losses.MeanAbsoluteError(name=\"myname\")\n        )\n\n    def test_all_correct_unweighted(self):\n        mae_obj = losses.MeanAbsoluteError()\n        y_true = np.array([[4, 8, 12], [8, 1, 3]])\n        loss = mae_obj(y_true, y_true)\n        self.assertAlmostEqual(loss, 0.0)\n\n    def test_unweighted(self):\n        mae_obj = losses.MeanAbsoluteError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mae_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 5.5)\n\n    def test_scalar_weighted(self):\n        mae_obj = losses.MeanAbsoluteError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mae_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 12.65)\n\n    def test_sample_weighted(self):\n        mae_obj = losses.MeanAbsoluteError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        sample_weight = np.array([[1.2], [3.4]])\n        loss = mae_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 81.4 / 6)\n\n    def test_timestep_weighted(self):\n        mae_obj = losses.MeanAbsoluteError()\n        y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)\n        y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)\n        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))\n        loss = mae_obj(\n            y_true,\n            y_pred,\n            sample_weight=sample_weight,\n        )\n        self.assertAlmostEqual(loss, 13.833333)\n\n    def test_zero_weighted(self):\n        mae_obj = losses.MeanAbsoluteError()\n        y_true = 
np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mae_obj(y_true, y_pred, sample_weight=0)\n        self.assertAlmostEqual(loss, 0.0)\n\n    def test_no_reduction(self):\n        mae_obj = losses.MeanAbsoluteError(reduction=None)\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mae_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, [10.7333, 14.5666])\n\n    def test_sum_reduction(self):\n        mae_obj = losses.MeanAbsoluteError(reduction=\"sum\")\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mae_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 25.29999)\n\n    def test_mean_with_sample_weight_reduction(self):\n        mae_obj = losses.MeanAbsoluteError(reduction=\"mean_with_sample_weight\")\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        sample_weight = np.array([[1.2], [3.4]])\n        loss = mae_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(\n            loss, (14 / 3 * 1.2 + 19 / 3 * 3.4) / (1.2 + 3.4)\n        )\n\n    def test_dtype_arg(self):\n        mae_obj = losses.MeanAbsoluteError(dtype=\"bfloat16\")\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mae_obj(y_true, y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass MeanAbsolutePercentageErrorTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            losses.MeanAbsolutePercentageError(name=\"mymape\")\n        )\n\n    def test_all_correct_unweighted(self):\n        mape_obj = losses.MeanAbsolutePercentageError()\n        y_true 
= np.array([[4, 8, 12], [8, 1, 3]])\n        loss = mape_obj(y_true, y_true)\n        self.assertAlmostEqual(loss, 0.0)\n\n    def test_unweighted(self):\n        mape_obj = losses.MeanAbsolutePercentageError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mape_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 211.8518, 3)\n\n    def test_scalar_weighted(self):\n        mape_obj = losses.MeanAbsolutePercentageError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mape_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 487.259, 3)\n\n    def test_sample_weighted(self):\n        mape_obj = losses.MeanAbsolutePercentageError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        sample_weight = np.array([[1.2], [3.4]])\n        loss = mape_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 422.8888, 3)\n\n    def test_timestep_weighted(self):\n        mape_obj = losses.MeanAbsolutePercentageError()\n        y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)\n        y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)\n        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))\n        loss = mape_obj(\n            y_true,\n            y_pred,\n            sample_weight=sample_weight,\n        )\n        self.assertAlmostEqual(loss, 694.4444)\n\n    def test_zero_weighted(self):\n        mape_obj = losses.MeanAbsolutePercentageError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mape_obj(y_true, y_pred, sample_weight=0)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n    def test_no_reduction(self):\n        mape_obj = 
losses.MeanAbsolutePercentageError(reduction=None)\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mape_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, [621.8518, 352.6666])\n\n    def test_mean_with_sample_weight_reduction(self):\n        mape_obj = losses.MeanAbsolutePercentageError(\n            reduction=\"mean_with_sample_weight\"\n        )\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        sample_weight = np.array([[1.2], [3.4]])\n        loss = mape_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 183.865)\n\n    def test_dtype_arg(self):\n        mape_obj = losses.MeanAbsolutePercentageError(dtype=\"bfloat16\")\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = mape_obj(y_true, y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass MeanSquaredLogarithmicErrorTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            losses.MeanSquaredLogarithmicError(name=\"mysloge\")\n        )\n\n    def test_unweighted(self):\n        msle_obj = losses.MeanSquaredLogarithmicError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = msle_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 1.4370, 3)\n\n    def test_scalar_weighted(self):\n        msle_obj = losses.MeanSquaredLogarithmicError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = msle_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 3.3051, 3)\n\n    def test_sample_weighted(self):\n        msle_obj = 
losses.MeanSquaredLogarithmicError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        sample_weight = np.array([[1.2], [3.4]])\n        loss = msle_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 3.7856, 3)\n\n    def test_timestep_weighted(self):\n        msle_obj = losses.MeanSquaredLogarithmicError()\n        y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)\n        y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)\n        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3))\n        loss = msle_obj(\n            y_true,\n            y_pred,\n            sample_weight=sample_weight,\n        )\n        self.assertAlmostEqual(loss, 2.647374)\n\n    def test_zero_weighted(self):\n        msle_obj = losses.MeanSquaredLogarithmicError()\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = msle_obj(y_true, y_pred, sample_weight=0)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n    def test_mean_with_sample_weight_reduction(self):\n        msle_obj = losses.MeanSquaredLogarithmicError(\n            reduction=\"mean_with_sample_weight\"\n        )\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        sample_weight = np.array([[1.2], [3.4]])\n        loss = msle_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 1.646)\n\n    def test_dtype_arg(self):\n        msle_obj = losses.MeanSquaredLogarithmicError(dtype=\"bfloat16\")\n        y_true = np.array([[1, 9, 2], [-5, -2, 6]])\n        y_pred = np.array([[4, 8, 12], [8, 1, 3]], dtype=\"float32\")\n        loss = msle_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass HingeTest(testing.TestCase):\n    def test_unweighted(self):\n 
       y_true = np.array([[0.0, 1.0], [0.0, 0.0]])\n        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])\n\n        # Reduction = \"sum_over_batch_size\"\n        hinge_obj = losses.Hinge(reduction=\"sum_over_batch_size\")\n        loss = hinge_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 1.3, 3)\n\n        # Reduction = \"sum\"\n        hinge_obj = losses.Hinge(reduction=\"sum\")\n        loss = hinge_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 2.6, 3)\n\n        # Reduction = None\n        hinge_obj = losses.Hinge(reduction=None)\n        loss = hinge_obj(y_true, y_pred)\n        self.assertAllClose(loss, [1.1, 1.5])\n\n        # Bad reduction\n        with self.assertRaisesRegex(ValueError, \"Invalid value for argument\"):\n            losses.Hinge(reduction=\"abc\")\n\n    def test_weighted(self):\n        y_true = np.array([[0.0, 1.0], [0.0, 0.0]])\n        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])\n        sample_weight = [1, 0]\n\n        # Reduction = \"sum_over_batch_size\"\n        hinge_obj = losses.Hinge(reduction=\"sum_over_batch_size\")\n        loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 0.55, 3)\n\n        # Reduction = \"sum\"\n        hinge_obj = losses.Hinge(reduction=\"sum\")\n        loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 1.1, 3)\n\n        # Reduction = None\n        hinge_obj = losses.Hinge(reduction=None)\n        loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, [1.1, 0.0])\n\n    def test_zero_weighted(self):\n        y_true = np.array([[0.0, 1.0], [0.0, 0.0]])\n        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])\n        sample_weight = 0.0\n\n        hinge_obj = losses.Hinge()\n        loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertEqual(loss, 0.0)\n\n    def test_dtype_arg(self):\n        hinge_obj = 
losses.Hinge(dtype=\"bfloat16\")\n        y_true = np.array([[0.0, 1.0], [0.0, 0.0]])\n        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])\n        loss = hinge_obj(y_true, y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass SquaredHingeTest(testing.TestCase):\n    def test_unweighted(self):\n        y_true = np.array([[0.0, 1.0], [0.0, 0.0]])\n        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])\n\n        # Reduction = \"sum_over_batch_size\"\n        hinge_obj = losses.SquaredHinge(reduction=\"sum_over_batch_size\")\n        loss = hinge_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 1.86, 3)\n\n        # Reduction = \"sum\"\n        hinge_obj = losses.SquaredHinge(reduction=\"sum\")\n        loss = hinge_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 3.72, 3)\n\n        # Reduction = None\n        hinge_obj = losses.SquaredHinge(reduction=None)\n        loss = hinge_obj(y_true, y_pred)\n        self.assertAllClose(loss, [1.46, 2.26])\n\n        # Bad reduction\n        with self.assertRaisesRegex(ValueError, \"Invalid value for argument\"):\n            losses.SquaredHinge(reduction=\"abc\")\n\n    def test_weighted(self):\n        y_true = np.array([[0.0, 1.0], [0.0, 0.0]])\n        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])\n        sample_weight = [1, 0]\n\n        # Reduction = \"sum_over_batch_size\"\n        hinge_obj = losses.SquaredHinge(reduction=\"sum_over_batch_size\")\n        loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 0.73, 3)\n\n        # Reduction = \"sum\"\n        hinge_obj = losses.SquaredHinge(reduction=\"sum\")\n        loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 1.46, 3)\n\n        # Reduction = None\n        hinge_obj = losses.SquaredHinge(reduction=None)\n        loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, [1.46, 0.0])\n\n    def 
test_zero_weighted(self):\n        y_true = np.array([[0.0, 1.0], [0.0, 0.0]])\n        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])\n        sample_weight = 0.0\n\n        hinge_obj = losses.SquaredHinge()\n        loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertEqual(loss, 0.0)\n\n    def test_dtype_arg(self):\n        hinge_obj = losses.SquaredHinge(dtype=\"bfloat16\")\n        y_true = np.array([[0.0, 1.0], [0.0, 0.0]])\n        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])\n        loss = hinge_obj(y_true, y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass CategoricalHingeTest(testing.TestCase):\n    def test_unweighted(self):\n        y_true = np.array([[0.0, 1.0], [0.0, 0.0]])\n        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])\n\n        # Reduction = \"sum_over_batch_size\"\n        hinge_obj = losses.CategoricalHinge(reduction=\"sum_over_batch_size\")\n        loss = hinge_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 1.4, 3)\n\n        # Reduction = \"sum\"\n        hinge_obj = losses.CategoricalHinge(reduction=\"sum\")\n        loss = hinge_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 2.8, 3)\n\n        # Reduction = None\n        hinge_obj = losses.CategoricalHinge(reduction=None)\n        loss = hinge_obj(y_true, y_pred)\n        self.assertAllClose(loss, [1.2, 1.6])\n\n        # Bad reduction\n        with self.assertRaisesRegex(ValueError, \"Invalid value for argument\"):\n            losses.CategoricalHinge(reduction=\"abc\")\n\n    def test_weighted(self):\n        y_true = np.array([[0.0, 1.0], [0.0, 0.0]])\n        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])\n        sample_weight = [1, 0]\n\n        # Reduction = \"sum_over_batch_size\"\n        hinge_obj = losses.CategoricalHinge(reduction=\"sum_over_batch_size\")\n        loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 0.6, 3)\n\n        # Reduction = \"sum\"\n        
hinge_obj = losses.CategoricalHinge(reduction=\"sum\")\n        loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 1.2, 3)\n\n        # Reduction = None\n        hinge_obj = losses.CategoricalHinge(reduction=None)\n        loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, [1.2, 0.0])\n\n    def test_zero_weighted(self):\n        y_true = np.array([[0.0, 1.0], [0.0, 0.0]])\n        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])\n        sample_weight = 0.0\n\n        hinge_obj = losses.CategoricalHinge()\n        loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertEqual(loss, 0.0)\n\n    def test_dtype_arg(self):\n        hinge_obj = losses.CategoricalHinge(dtype=\"bfloat16\")\n        y_true = np.array([[0.0, 1.0], [0.0, 0.0]])\n        y_pred = np.array([[0.6, 0.4], [0.4, 0.6]])\n        loss = hinge_obj(y_true, y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass CosineSimilarityTest(testing.TestCase):\n    def l2_norm(self, x, axis):\n        epsilon = 1e-12\n        square_sum = np.sum(np.square(x), axis=axis, keepdims=True)\n        x_inv_norm = 1 / np.sqrt(np.maximum(square_sum, epsilon))\n        return np.multiply(x, x_inv_norm)\n\n    def setup(self, axis=1):\n        self.np_y_true = np.asarray([[1, 9, 2], [-5, -2, 6]], dtype=np.float32)\n        self.np_y_pred = np.asarray([[4, 8, 12], [8, 1, 3]], dtype=np.float32)\n\n        y_true = self.l2_norm(self.np_y_true, axis)\n        y_pred = self.l2_norm(self.np_y_pred, axis)\n        self.expected_loss = np.sum(np.multiply(y_true, y_pred), axis=(axis,))\n\n        self.y_true = self.np_y_true\n        self.y_pred = self.np_y_pred\n\n    def test_config(self):\n        cosine_obj = losses.CosineSimilarity(\n            axis=2, reduction=\"sum\", name=\"cosine_loss\"\n        )\n        self.assertEqual(cosine_obj.name, \"cosine_loss\")\n        
self.assertEqual(cosine_obj.reduction, \"sum\")\n        config = cosine_obj.get_config()\n        self.assertEqual(config, {\"name\": \"cosine_loss\", \"reduction\": \"sum\"})\n\n    def test_unweighted(self):\n        self.setup()\n        cosine_obj = losses.CosineSimilarity()\n        loss = cosine_obj(self.y_true, self.y_pred)\n        expected_loss = -np.mean(self.expected_loss)\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_scalar_weighted(self):\n        self.setup()\n        cosine_obj = losses.CosineSimilarity()\n        sample_weight = 2.3\n        loss = cosine_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n        expected_loss = -np.mean(self.expected_loss * sample_weight)\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_sample_weighted(self):\n        self.setup()\n        cosine_obj = losses.CosineSimilarity()\n        sample_weight = np.asarray([1.2, 3.4])\n        loss = cosine_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n        expected_loss = -np.mean(self.expected_loss * sample_weight)\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_timestep_weighted(self):\n        self.setup()\n        cosine_obj = losses.CosineSimilarity()\n        np_y_true = self.np_y_true.reshape((2, 3, 1))\n        np_y_pred = self.np_y_pred.reshape((2, 3, 1))\n        sample_weight = np.asarray([3, 6, 5, 0, 4, 2]).reshape((2, 3))\n\n        y_true = self.l2_norm(np_y_true, 2)\n        y_pred = self.l2_norm(np_y_pred, 2)\n        expected_loss = np.sum(np.multiply(y_true, y_pred), axis=(2,))\n\n        y_true = np_y_true\n        y_pred = np_y_pred\n        loss = cosine_obj(y_true, y_pred, sample_weight=sample_weight)\n\n        expected_loss = -np.mean(expected_loss * sample_weight)\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_zero_weighted(self):\n        self.setup()\n        cosine_obj = losses.CosineSimilarity()\n        loss = 
cosine_obj(self.y_true, self.y_pred, sample_weight=0)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n    def test_axis(self):\n        self.setup(axis=1)\n        cosine_obj = losses.CosineSimilarity(axis=1)\n        loss = cosine_obj(self.y_true, self.y_pred)\n        expected_loss = -np.mean(self.expected_loss)\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_dtype_arg(self):\n        self.setup()\n        cosine_obj = losses.CosineSimilarity(dtype=\"bfloat16\")\n        loss = cosine_obj(self.y_true, self.y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass HuberLossTest(testing.TestCase):\n    def huber_loss(self, y_true, y_pred, delta=1.0):\n        error = y_pred - y_true\n        abs_error = np.abs(error)\n\n        quadratic = np.minimum(abs_error, delta)\n        linear = np.subtract(abs_error, quadratic)\n        return np.add(\n            np.multiply(0.5, np.multiply(quadratic, quadratic)),\n            np.multiply(delta, linear),\n        )\n\n    def setup(self, delta=1.0):\n        self.np_y_pred = np.array([[0.9, 0.2, 0.2], [0.8, 0.4, 0.6]])\n        self.np_y_true = np.array([[1.0, 0.0, 1.0], [1.0, 0.0, 0.0]])\n\n        self.batch_size = 6\n        self.expected_losses = self.huber_loss(\n            self.np_y_true, self.np_y_pred, delta\n        )\n\n        self.y_pred = self.np_y_pred\n        self.y_true = self.np_y_true\n\n    def test_config(self):\n        h_obj = losses.Huber(reduction=\"sum\", name=\"huber\")\n        self.assertEqual(h_obj.name, \"huber\")\n        self.assertEqual(h_obj.reduction, \"sum\")\n        config = h_obj.get_config()\n        self.assertEqual(config, {\"name\": \"huber\", \"reduction\": \"sum\"})\n\n    def test_all_correct(self):\n        self.setup()\n        h_obj = losses.Huber()\n        loss = h_obj(self.y_true, self.y_true)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n    def test_unweighted(self):\n        self.setup()\n        h_obj = losses.Huber()\n        
loss = h_obj(self.y_true, self.y_pred)\n        actual_loss = np.sum(self.expected_losses) / self.batch_size\n        self.assertAlmostEqual(loss, actual_loss, 3)\n\n    def test_scalar_weighted(self):\n        self.setup()\n        h_obj = losses.Huber()\n        sample_weight = 2.3\n        loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n        actual_loss = (\n            sample_weight * np.sum(self.expected_losses) / self.batch_size\n        )\n        self.assertAlmostEqual(loss, actual_loss, 3)\n\n        # Verify we get the same output when the same input is given\n        loss_2 = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, loss_2, 3)\n\n    def test_sample_weighted(self):\n        self.setup()\n        h_obj = losses.Huber()\n        sample_weight = np.array([[1.2], [3.4]])\n\n        loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n        actual_loss = np.multiply(\n            self.expected_losses,\n            np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)),\n        )\n        actual_loss = np.sum(actual_loss) / self.batch_size\n        self.assertAlmostEqual(loss, actual_loss, 3)\n\n    def test_timestep_weighted(self):\n        self.setup()\n        h_obj = losses.Huber()\n        y_pred = self.np_y_pred.reshape((2, 3, 1))\n        y_true = self.np_y_true.reshape((2, 3, 1))\n        expected_losses = self.huber_loss(y_true, y_pred)\n\n        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3, 1))\n        loss = h_obj(\n            y_true,\n            y_pred,\n            sample_weight=sample_weight,\n        )\n        actual_loss = np.multiply(expected_losses, sample_weight)\n        actual_loss = np.sum(actual_loss) / self.batch_size\n        self.assertAlmostEqual(loss, actual_loss, 3)\n\n    def test_zero_weighted(self):\n        self.setup()\n        h_obj = losses.Huber()\n        sample_weight = 0\n        loss = 
h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n    def test_non_default_delta(self):\n        self.setup(delta=0.8)\n        h_obj = losses.Huber(delta=0.8)\n        sample_weight = 2.3\n        loss = h_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n        actual_loss = (\n            sample_weight * np.sum(self.expected_losses) / self.batch_size\n        )\n        self.assertAlmostEqual(loss, actual_loss, 3)\n\n    def test_dtype_arg(self):\n        self.setup()\n        h_obj = losses.Huber(dtype=\"bfloat16\")\n        loss = h_obj(self.y_true, self.y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass LogCoshTest(testing.TestCase):\n    def setup(self):\n        y_true = np.asarray([[1, 9, 2], [-5, -2, 6]], dtype=np.float32)\n        y_pred = np.asarray([[4, 8, 12], [8, 1, 3]], dtype=np.float32)\n\n        self.batch_size = 6\n        error = y_pred - y_true\n        self.expected_losses = np.log((np.exp(error) + np.exp(-error)) / 2)\n\n        self.y_true = y_true\n        self.y_pred = y_pred\n\n    def test_config(self):\n        logcosh_obj = losses.LogCosh(reduction=\"sum\", name=\"logcosh_loss\")\n        self.assertEqual(logcosh_obj.name, \"logcosh_loss\")\n        self.assertEqual(logcosh_obj.reduction, \"sum\")\n        config = logcosh_obj.get_config()\n        self.assertEqual(config, {\"name\": \"logcosh_loss\", \"reduction\": \"sum\"})\n\n    def test_unweighted(self):\n        self.setup()\n        logcosh_obj = losses.LogCosh()\n\n        loss = logcosh_obj(self.y_true, self.y_pred)\n        expected_loss = np.sum(self.expected_losses) / self.batch_size\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_scalar_weighted(self):\n        self.setup()\n        logcosh_obj = losses.LogCosh()\n        sample_weight = 2.3\n\n        loss = logcosh_obj(\n            self.y_true, self.y_pred, sample_weight=sample_weight\n        )\n        
expected_loss = (\n            sample_weight * np.sum(self.expected_losses) / self.batch_size\n        )\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n        # Verify we get the same output when the same input is given\n        loss_2 = logcosh_obj(\n            self.y_true, self.y_pred, sample_weight=sample_weight\n        )\n        self.assertAlmostEqual(loss, loss_2, 3)\n\n    def test_sample_weighted(self):\n        self.setup()\n        logcosh_obj = losses.LogCosh()\n\n        sample_weight = np.asarray([1.2, 3.4])\n        loss = logcosh_obj(\n            self.y_true, self.y_pred, sample_weight=sample_weight\n        )\n\n        expected_loss = np.multiply(\n            self.expected_losses,\n            np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)),\n        )\n        expected_loss = np.sum(expected_loss) / self.batch_size\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_timestep_weighted(self):\n        self.setup()\n        logcosh_obj = losses.LogCosh()\n        y_true = np.asarray([1, 9, 2, -5, -2, 6]).reshape(2, 3, 1)\n        y_pred = np.asarray([4, 8, 12, 8, 1, 3]).reshape(2, 3, 1)\n        error = y_pred - y_true\n        expected_losses = np.log((np.exp(error) + np.exp(-error)) / 2)\n        sample_weight = np.array([3, 6, 5, 0, 4, 2]).reshape((2, 3, 1))\n\n        loss = logcosh_obj(\n            y_true,\n            y_pred,\n            sample_weight=sample_weight,\n        )\n        expected_loss = (\n            np.sum(expected_losses * sample_weight) / self.batch_size\n        )\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_zero_weighted(self):\n        self.setup()\n        logcosh_obj = losses.LogCosh()\n        sample_weight = 0\n        loss = logcosh_obj(\n            self.y_true, self.y_pred, sample_weight=sample_weight\n        )\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n    def test_dtype_arg(self):\n        self.setup()\n        logcosh_obj = 
losses.LogCosh(dtype=\"bfloat16\")\n        loss = logcosh_obj(self.y_true, self.y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass KLDivergenceTest(testing.TestCase):\n    def setup(self):\n        self.y_pred = np.asarray(\n            [0.4, 0.9, 0.12, 0.36, 0.3, 0.4], dtype=np.float32\n        ).reshape((2, 3))\n        self.y_true = np.asarray(\n            [0.5, 0.8, 0.12, 0.7, 0.43, 0.8], dtype=np.float32\n        ).reshape((2, 3))\n\n        self.batch_size = 2\n        self.expected_losses = np.multiply(\n            self.y_true, np.log(self.y_true / self.y_pred)\n        )\n\n    def test_config(self):\n        k_obj = losses.KLDivergence(reduction=\"sum\", name=\"kld\")\n        self.assertEqual(k_obj.name, \"kld\")\n        self.assertEqual(k_obj.reduction, \"sum\")\n\n    def test_unweighted(self):\n        self.setup()\n        k_obj = losses.KLDivergence()\n\n        loss = k_obj(self.y_true, self.y_pred)\n        expected_loss = np.sum(self.expected_losses) / self.batch_size\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_scalar_weighted(self):\n        self.setup()\n        k_obj = losses.KLDivergence()\n        sample_weight = 2.3\n\n        loss = k_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n        expected_loss = (\n            sample_weight * np.sum(self.expected_losses) / self.batch_size\n        )\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n        # Verify we get the same output when the same input is given\n        loss_2 = k_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, loss_2, 3)\n\n    def test_sample_weighted(self):\n        self.setup()\n        k_obj = losses.KLDivergence()\n        sample_weight = np.asarray([1.2, 3.4], dtype=np.float32).reshape((2, 1))\n        loss = k_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n\n        expected_loss = np.multiply(\n            self.expected_losses,\n            
np.asarray(\n                [1.2, 1.2, 1.2, 3.4, 3.4, 3.4], dtype=np.float32\n            ).reshape(2, 3),\n        )\n        expected_loss = np.sum(expected_loss) / self.batch_size\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_timestep_weighted(self):\n        self.setup()\n        k_obj = losses.KLDivergence()\n        y_true = self.y_true.reshape(2, 3, 1)\n        y_pred = self.y_pred.reshape(2, 3, 1)\n        sample_weight = np.asarray([3, 6, 5, 0, 4, 2]).reshape(2, 3)\n        expected_losses = np.sum(\n            np.multiply(y_true, np.log(y_true / y_pred)), axis=-1\n        )\n        loss = k_obj(y_true, y_pred, sample_weight=sample_weight)\n\n        num_timesteps = 3\n        expected_loss = np.sum(expected_losses * sample_weight) / (\n            self.batch_size * num_timesteps\n        )\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_zero_weighted(self):\n        self.setup()\n        k_obj = losses.KLDivergence()\n        loss = k_obj(self.y_true, self.y_pred, sample_weight=0)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n    def test_dtype_arg(self):\n        self.setup()\n        k_obj = losses.KLDivergence(dtype=\"bfloat16\")\n        loss = k_obj(self.y_true, self.y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass PoissonTest(testing.TestCase):\n    def setup(self):\n        self.y_pred = np.asarray([1, 9, 2, 5, 2, 6], dtype=np.float32).reshape(\n            (2, 3)\n        )\n        self.y_true = np.asarray([4, 8, 12, 8, 1, 3], dtype=np.float32).reshape(\n            (2, 3)\n        )\n\n        self.batch_size = 6\n        self.expected_losses = self.y_pred - np.multiply(\n            self.y_true, np.log(self.y_pred)\n        )\n\n    def test_config(self):\n        poisson_obj = losses.Poisson(reduction=\"sum\", name=\"poisson\")\n        self.assertEqual(poisson_obj.name, \"poisson\")\n        self.assertEqual(poisson_obj.reduction, \"sum\")\n\n    def 
test_unweighted(self):\n        self.setup()\n        poisson_obj = losses.Poisson()\n\n        loss = poisson_obj(self.y_true, self.y_pred)\n        expected_loss = np.sum(self.expected_losses) / self.batch_size\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_scalar_weighted(self):\n        self.setup()\n        poisson_obj = losses.Poisson()\n        sample_weight = 2.3\n        loss = poisson_obj(\n            self.y_true, self.y_pred, sample_weight=sample_weight\n        )\n        expected_loss = (\n            sample_weight * np.sum(self.expected_losses) / self.batch_size\n        )\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n        # Verify we get the same output when the same input is given\n        loss_2 = poisson_obj(\n            self.y_true, self.y_pred, sample_weight=sample_weight\n        )\n        self.assertAlmostEqual(loss, loss_2, 3)\n\n    def test_sample_weighted(self):\n        self.setup()\n        poisson_obj = losses.Poisson()\n\n        sample_weight = np.asarray([1.2, 3.4]).reshape((2, 1))\n        loss = poisson_obj(\n            self.y_true, self.y_pred, sample_weight=sample_weight\n        )\n\n        expected_loss = np.multiply(\n            self.expected_losses,\n            np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape((2, 3)),\n        )\n        expected_loss = np.sum(expected_loss) / self.batch_size\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_timestep_weighted(self):\n        self.setup()\n        poisson_obj = losses.Poisson()\n        y_true = self.y_true.reshape(2, 3, 1)\n        y_pred = self.y_pred.reshape(2, 3, 1)\n        sample_weight = np.asarray([3, 6, 5, 0, 4, 2]).reshape(2, 3, 1)\n        expected_losses = y_pred - np.multiply(y_true, np.log(y_pred))\n\n        loss = poisson_obj(\n            y_true,\n            y_pred,\n            sample_weight=np.asarray(sample_weight).reshape((2, 
3)),\n        )\n        expected_loss = (\n            np.sum(expected_losses * sample_weight) / self.batch_size\n        )\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_zero_weighted(self):\n        self.setup()\n        poisson_obj = losses.Poisson()\n        loss = poisson_obj(self.y_true, self.y_pred, sample_weight=0)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n    def test_dtype_arg(self):\n        self.setup()\n        poisson_obj = losses.Poisson(dtype=\"bfloat16\")\n        loss = poisson_obj(self.y_true, self.y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass BinaryCrossentropyTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            losses.BinaryCrossentropy(name=\"bce\", axis=-1)\n        )\n\n    def test_all_correct_unweighted(self):\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=\"float32\")\n        bce_obj = losses.BinaryCrossentropy()\n        loss = bce_obj(y_true, y_true)\n        self.assertAlmostEqual(loss, 0.0)\n\n        # Test with logits.\n        logits = np.array(\n            [\n                [10.0, -10.0, -10.0],\n                [-10.0, 10.0, -10.0],\n                [-10.0, -10.0, 10.0],\n            ]\n        )\n        bce_obj = losses.BinaryCrossentropy(from_logits=True)\n        loss = bce_obj(y_true, logits)\n        self.assertAlmostEqual(loss, 0.0)\n\n    def test_unweighted(self):\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=\"float32\")\n        y_pred = np.array(\n            [[0.9, 0.1, 0.2], [0.3, 0.8, 0.1], [0.1, 0.2, 0.7]], dtype=\"float32\"\n        )\n        bce_obj = losses.BinaryCrossentropy()\n        loss = bce_obj(y_true, y_pred)\n        self.assertAllClose(loss, 0.20046903)\n\n        y_true = np.array([1, 0, 1, 0]).reshape([2, 2])\n        y_pred = np.array([1, 1, 1, 0], dtype=np.float32).reshape([2, 2])\n        bce_obj = losses.BinaryCrossentropy()\n        loss = 
bce_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 3.98559)\n\n        # Test with logits.\n        y_true = np.array([[1, 0, 1], [0, 1, 1]])\n        logits = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]])\n        bce_obj = losses.BinaryCrossentropy(from_logits=True)\n        loss = bce_obj(y_true, logits)\n        self.assertAlmostEqual(loss, 3.3333)\n\n    def test_scalar_weighted(self):\n        bce_obj = losses.BinaryCrossentropy()\n        y_true = np.array([1, 0, 1, 0]).reshape([2, 2])\n        y_pred = np.array([1, 1, 1, 0], dtype=\"float32\").reshape([2, 2])\n        loss = bce_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 9.1668)\n\n        # Test with logits.\n        y_true = np.array([[1, 0, 1], [0, 1, 1]])\n        logits = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]])\n        bce_obj = losses.BinaryCrossentropy(from_logits=True)\n        loss = bce_obj(y_true, logits, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 7.666)\n\n    def test_sample_weighted(self):\n        bce_obj = losses.BinaryCrossentropy()\n        y_true = np.array([1, 0, 1, 0]).reshape([2, 2])\n        y_pred = np.array([1, 1, 1, 0], dtype=\"float32\").reshape([2, 2])\n        sample_weight = np.array([1.2, 3.4]).reshape((2, 1))\n        loss = bce_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 4.7827)\n\n        # Test with logits.\n        y_true = np.array([[1, 0, 1], [0, 1, 1]])\n        logits = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]])\n        weights = np.array([4, 3])\n        bce_obj = losses.BinaryCrossentropy(from_logits=True)\n        loss = bce_obj(y_true, logits, sample_weight=weights)\n        self.assertAlmostEqual(loss, 10.0)\n\n    def test_no_reduction(self):\n        y_true = np.array([[1, 0, 1], [0, 1, 1]])\n        logits = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]])\n        bce_obj = losses.BinaryCrossentropy(from_logits=True, 
reduction=None)\n        loss = bce_obj(y_true, logits)\n        self.assertAllClose(loss, [0.0, 6.666], atol=1e-3)\n\n    def test_label_smoothing(self):\n        logits = np.array([[10.0, -10.0, -10.0]])\n        y_true = np.array([[1, 0, 1]])\n        label_smoothing = 0.1\n        bce_obj = losses.BinaryCrossentropy(\n            from_logits=True, label_smoothing=label_smoothing\n        )\n        loss = bce_obj(y_true, logits)\n        expected_value = (10.0 + 5.0 * label_smoothing) / 3.0\n        self.assertAlmostEqual(loss, expected_value)\n\n    def test_shape_mismatch(self):\n        y_true = np.array([[0], [1], [2]])\n        y_pred = np.array(\n            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]]\n        )\n        cce_obj = losses.BinaryCrossentropy()\n        with self.assertRaisesRegex(ValueError, \"must have the same shape\"):\n            cce_obj(y_true, y_pred)\n\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"Torch doesn't support bfloat16 for BinaryCrossentropy\",\n    )\n    def test_dtype_arg(self):\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=\"float32\")\n        y_pred = np.array(\n            [[0.9, 0.1, 0.2], [0.3, 0.8, 0.1], [0.1, 0.2, 0.7]], dtype=\"float32\"\n        )\n        bce_obj = losses.BinaryCrossentropy(dtype=\"bfloat16\")\n        loss = bce_obj(y_true, y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass CategoricalCrossentropyTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            losses.CategoricalCrossentropy(name=\"cce\", axis=-1)\n        )\n\n    def test_all_correct_unweighted(self):\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=\"int64\")\n        y_pred = np.array(\n            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],\n            dtype=\"float32\",\n        )\n        cce_obj = losses.CategoricalCrossentropy()\n        loss = 
cce_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 0.0)\n\n        # Test with logits.\n        logits = np.array(\n            [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]\n        )\n        cce_obj = losses.CategoricalCrossentropy(from_logits=True)\n        loss = cce_obj(y_true, logits)\n        self.assertAlmostEqual(loss, 0.0)\n\n    def test_unweighted(self):\n        cce_obj = losses.CategoricalCrossentropy()\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])\n        y_pred = np.array(\n            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]],\n            dtype=\"float32\",\n        )\n        loss = cce_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 0.3239)\n\n        # Test with logits.\n        logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        cce_obj = losses.CategoricalCrossentropy(from_logits=True)\n        loss = cce_obj(y_true, logits)\n        self.assertAlmostEqual(loss, 0.0573)\n\n    def test_scalar_weighted(self):\n        cce_obj = losses.CategoricalCrossentropy()\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])\n        y_pred = np.array(\n            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]],\n            dtype=\"float32\",\n        )\n        loss = cce_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 0.7449)\n\n        # Test with logits.\n        logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        cce_obj = losses.CategoricalCrossentropy(from_logits=True)\n        loss = cce_obj(y_true, logits, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 0.1317)\n\n    def test_sample_weighted(self):\n        cce_obj = losses.CategoricalCrossentropy()\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])\n        y_pred = np.array(\n            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]],\n            dtype=\"float32\",\n        )\n        
sample_weight = np.array([[1.2], [3.4], [5.6]]).reshape((3, 1))\n        loss = cce_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 1.0696)\n\n        # Test with logits.\n        logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        cce_obj = losses.CategoricalCrossentropy(from_logits=True)\n        loss = cce_obj(y_true, logits, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 0.31829)\n\n    def test_no_reduction(self):\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])\n        logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        cce_obj = losses.CategoricalCrossentropy(\n            from_logits=True, reduction=None\n        )\n        loss = cce_obj(y_true, logits)\n        self.assertAllClose((0.001822, 0.000459, 0.169846), loss)\n\n    def test_label_smoothing(self):\n        logits = np.array([[100.0, -100.0, -100.0]])\n        y_true = np.array([[1, 0, 0]])\n        label_smoothing = 0.1\n        cce_obj = losses.CategoricalCrossentropy(\n            from_logits=True, label_smoothing=label_smoothing\n        )\n        loss = cce_obj(y_true, logits)\n        expected_value = 400.0 * label_smoothing / 3.0\n        self.assertAlmostEqual(loss, expected_value)\n\n    def test_label_smoothing_ndarray(self):\n        logits = np.asarray([[100.0, -100.0, -100.0]])\n        y_true = np.asarray([[1, 0, 0]])\n        label_smoothing = 0.1\n        cce_obj = losses.CategoricalCrossentropy(\n            from_logits=True, label_smoothing=label_smoothing\n        )\n        loss = cce_obj(y_true, logits)\n        expected_value = 400.0 * label_smoothing / 3.0\n        self.assertAlmostEqual(loss, expected_value)\n\n    def test_shape_mismatch(self):\n        y_true = np.array([[0], [1], [2]])\n        y_pred = np.array(\n            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]]\n        )\n\n        cce_obj = 
losses.CategoricalCrossentropy()\n        with self.assertRaisesRegex(ValueError, \"must have the same shape\"):\n            cce_obj(y_true, y_pred)\n\n    def test_dtype_arg(self):\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=\"int64\")\n        y_pred = np.array(\n            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],\n            dtype=\"float32\",\n        )\n        cce_obj = losses.CategoricalCrossentropy(dtype=\"bfloat16\")\n        loss = cce_obj(y_true, y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass SparseCategoricalCrossentropyTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            losses.SparseCategoricalCrossentropy(name=\"scce\")\n        )\n\n    def test_all_correct_unweighted(self):\n        y_true = np.array([[0], [1], [2]], dtype=\"int64\")\n        y_pred = np.array(\n            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],\n            dtype=\"float32\",\n        )\n        cce_obj = losses.SparseCategoricalCrossentropy()\n        loss = cce_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n        # Test with logits.\n        logits = np.array(\n            [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]\n        )\n        cce_obj = losses.SparseCategoricalCrossentropy(from_logits=True)\n        loss = cce_obj(y_true, logits)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n    def test_unweighted(self):\n        cce_obj = losses.SparseCategoricalCrossentropy()\n        y_true = np.array([0, 1, 2])\n        y_pred = np.array(\n            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]],\n            dtype=\"float32\",\n        )\n        loss = cce_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 0.3239, 3)\n\n        # Test with logits.\n        logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        cce_obj = 
losses.SparseCategoricalCrossentropy(from_logits=True)\n        loss = cce_obj(y_true, logits)\n        self.assertAlmostEqual(loss, 0.0573, 3)\n\n    def test_scalar_weighted(self):\n        cce_obj = losses.SparseCategoricalCrossentropy()\n        y_true = np.array([[0], [1], [2]])\n        y_pred = np.array(\n            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]],\n            dtype=\"float32\",\n        )\n        loss = cce_obj(y_true, y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 0.7449, 3)\n\n        # Test with logits.\n        logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        cce_obj = losses.SparseCategoricalCrossentropy(from_logits=True)\n        loss = cce_obj(y_true, logits, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 0.1317, 3)\n\n    def test_sample_weighted(self):\n        cce_obj = losses.SparseCategoricalCrossentropy()\n        y_true = np.array([[0], [1], [2]])\n        y_pred = np.array(\n            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]],\n            dtype=\"float32\",\n        )\n        sample_weight = np.array([[1.2], [3.4], [5.6]]).reshape((3, 1))\n        loss = cce_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 1.0696, 3)\n\n        # Test with logits.\n        logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        cce_obj = losses.SparseCategoricalCrossentropy(from_logits=True)\n        loss = cce_obj(y_true, logits, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 0.31829, 3)\n\n    def test_no_reduction(self):\n        y_true = np.array([[0], [1], [2]])\n        logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        cce_obj = losses.SparseCategoricalCrossentropy(\n            from_logits=True, reduction=None\n        )\n        loss = cce_obj(y_true, logits)\n        self.assertAllClose((0.001822, 0.000459, 0.169846), loss)\n\n    def 
test_ignore_class(self):\n        y_true = np.array([[-1, 2]])\n        logits = np.array([[[0.854, 0.698, 0.598], [0.088, 0.86, 0.018]]])\n        cce_obj = losses.SparseCategoricalCrossentropy(\n            from_logits=True, ignore_class=-1, reduction=None\n        )\n        loss = cce_obj(y_true, logits)\n        self.assertAllClose([[0.0, 1.480129]], loss)\n\n        y_true = np.array([[[-1], [2]]])\n        logits = np.array([[[0.854, 0.698, 0.598], [0.088, 0.86, 0.018]]])\n        cce_obj = losses.SparseCategoricalCrossentropy(\n            from_logits=True, ignore_class=-1, reduction=None\n        )\n        loss = cce_obj(y_true, logits)\n        self.assertAllClose([[0.0, 1.480129]], loss)\n\n    def test_binary_segmentation(self):\n        y_true = np.array(\n            [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]\n        )\n        y_pred = np.array(\n            [\n                [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]],\n                [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],\n                [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],\n                [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]],\n            ]\n        )\n        output = losses.SparseCategoricalCrossentropy()(y_true, y_pred)\n        self.assertAllClose(output, 0.0)\n\n        y_true = np.array(\n            [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]\n        )\n        y_pred = np.array(\n            [\n                [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.2, 0.8]],\n                [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],\n                [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],\n                [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.6, 0.4]],\n            ]\n        )\n        expected = np.array([-np.log(0.2), -np.log(0.4)])\n        output = losses.SparseCategoricalCrossentropy()(y_true, y_pred)\n        self.assertAllClose(output, expected.sum() / 16.0)  # 16 pixels\n\n    def 
test_binary_segmentation_different_axis(self):\n        y_true = np.array(\n            [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]\n        )\n        y_pred = np.array(\n            [\n                [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]],\n                [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],\n                [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],\n                [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]],\n            ]\n        )\n        y_pred_reshaped = np.moveaxis(y_pred, source=2, destination=0)\n        if backend.backend() == \"tensorflow\":\n            expected_message = (\n                \"Only axis=-1 is currently supported. Received: axis=0\"\n            )\n            escaped_message = re.escape(expected_message)\n\n            with pytest.raises(ValueError, match=escaped_message):\n                losses.SparseCategoricalCrossentropy(axis=0)(\n                    y_true, y_pred_reshaped\n                )\n        elif backend.backend() == \"jax\":\n            expected_message = (\n                \"Arguments `target` and `output` \"\n                \"must have the same shape up until\"\n                \" the last dimension: target.shape=(4, 4),\"\n                \" output.shape=(2, 4, 4)\"\n            )\n            escaped_message = re.escape(expected_message)\n\n            with pytest.raises(ValueError, match=escaped_message):\n                losses.SparseCategoricalCrossentropy(axis=0)(\n                    y_true, y_pred_reshaped\n                )\n        elif backend.backend() == \"torch\":\n            output = losses.SparseCategoricalCrossentropy(axis=0)(\n                y_true, y_pred_reshaped\n            )\n            self.assertAllClose(output, 0.0)\n\n        if backend.backend() == \"torch\":\n            y_true = np.array(\n                [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]\n            )\n            y_pred = np.array(\n                [\n     
               [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.2, 0.8]],\n                    [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],\n                    [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],\n                    [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.6, 0.4]],\n                ]\n            )\n            y_pred_reshaped = np.moveaxis(y_pred, source=2, destination=0)\n            expected = np.array([-np.log(0.2), -np.log(0.4)])\n            output = losses.SparseCategoricalCrossentropy(axis=0)(\n                y_true, y_pred_reshaped\n            )\n            self.assertAllClose(output, expected.sum() / 16.0)\n\n            y_true = np.array([y_true, y_true, y_true])\n            y_pred_reshaped = np.array(\n                [y_pred_reshaped, y_pred_reshaped, y_pred_reshaped]\n            )\n            output = losses.SparseCategoricalCrossentropy(axis=1)(\n                y_true, y_pred_reshaped\n            )\n            self.assertAllClose(output, expected.sum() / 16.0)\n\n    def test_multi_class_segmentation(self):\n        y_true = np.array(\n            [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]\n        )\n        y_pred = np.array(\n            [\n                [\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [0.0, 0.0, 1.0],\n                    [1.0, 0.0, 0.0],\n                ],\n                [\n                    [0.0, 1.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                ],\n                [\n                    [1.0, 0.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                ],\n                [\n                    [0.0, 1.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                ],\n            ]\n        
)\n        output = losses.SparseCategoricalCrossentropy()(y_true, y_pred)\n        self.assertAllClose(output, 0.0)\n\n        y_true = np.array(\n            [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]\n        )\n        y_pred = np.array(\n            [\n                [\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [0.0, 0.0, 1.0],\n                    [0.2, 0.0, 0.8],\n                ],\n                [\n                    [0.7, 0.3, 0.0],\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                ],\n                [\n                    [1.0, 0.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                ],\n                [\n                    [0.0, 1.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [0.5, 0.5, 0.0],\n                    [0.0, 1.0, 0.0],\n                ],\n            ]\n        )\n        expected = np.array(\n            [\n                -np.log(0.2),\n                -np.log(0.3),\n                -np.log(0.5),\n            ]\n        )\n        output = losses.SparseCategoricalCrossentropy()(y_true, y_pred)\n        self.assertAllClose(output, expected.sum() / 16.0)  # 16 pixels\n\n    def test_multi_class_segmentation_different_axis(self):\n        y_true = np.array(\n            [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]\n        )\n        y_pred = np.array(\n            [\n                [\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [0.0, 0.0, 1.0],\n                    [1.0, 0.0, 0.0],\n                ],\n                [\n                    [0.0, 1.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                ],\n                [\n            
        [1.0, 0.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                ],\n                [\n                    [0.0, 1.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                ],\n            ]\n        )\n        y_pred_reshaped = np.moveaxis(y_pred, source=2, destination=0)\n        if backend.backend() == \"tensorflow\":\n            expected_message = (\n                \"Only axis=-1 is currently supported. Received: axis=0\"\n            )\n            escaped_message = re.escape(expected_message)\n\n            with pytest.raises(ValueError, match=escaped_message):\n                losses.SparseCategoricalCrossentropy(axis=0)(\n                    y_true, y_pred_reshaped\n                )\n        elif backend.backend() == \"jax\":\n            expected_message = (\n                \"Arguments `target` and `output` \"\n                \"must have the same shape up until\"\n                \" the last dimension: target.shape=(4, 4),\"\n                \" output.shape=(3, 4, 4)\"\n            )\n            escaped_message = re.escape(expected_message)\n\n            with pytest.raises(ValueError, match=escaped_message):\n                losses.SparseCategoricalCrossentropy(axis=0)(\n                    y_true, y_pred_reshaped\n                )\n        elif backend.backend() == \"torch\":\n            output = losses.SparseCategoricalCrossentropy(axis=0)(\n                y_true, y_pred_reshaped\n            )\n            self.assertAllClose(output, 0.0)\n\n        if backend.backend() == \"torch\":\n            y_true = np.array(\n                [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]\n            )\n            y_pred = np.array(\n                [\n                    [\n                        [1.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0],\n               
         [0.0, 0.0, 1.0],\n                        [0.2, 0.0, 0.8],\n                    ],\n                    [\n                        [0.7, 0.3, 0.0],\n                        [1.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0],\n                        [1.0, 0.0, 0.0],\n                    ],\n                    [\n                        [1.0, 0.0, 0.0],\n                        [1.0, 0.0, 0.0],\n                        [0.0, 1.0, 0.0],\n                        [0.0, 1.0, 0.0],\n                    ],\n                    [\n                        [0.0, 1.0, 0.0],\n                        [0.0, 1.0, 0.0],\n                        [0.5, 0.5, 0.0],\n                        [0.0, 1.0, 0.0],\n                    ],\n                ]\n            )\n            expected = np.array(\n                [\n                    -np.log(0.2),\n                    -np.log(0.3),\n                    -np.log(0.5),\n                ]\n            )\n            y_pred_reshaped = np.moveaxis(y_pred, source=2, destination=0)\n            output = losses.SparseCategoricalCrossentropy(axis=0)(\n                y_true, y_pred_reshaped\n            )\n            self.assertAllClose(output, expected.sum() / 16.0)\n            y_true = np.array([y_true, y_true, y_true])\n            y_pred_reshaped = np.array(\n                [y_pred_reshaped, y_pred_reshaped, y_pred_reshaped]\n            )\n            output = losses.SparseCategoricalCrossentropy(axis=1)(\n                y_true, y_pred_reshaped\n            )\n            self.assertAllClose(output, expected.sum() / 16.0)\n\n    def test_dtype_arg(self):\n        y_true = np.array([[0], [1], [2]], dtype=\"int64\")\n        y_pred = np.array(\n            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],\n            dtype=\"float32\",\n        )\n        cce_obj = losses.SparseCategoricalCrossentropy(dtype=\"bfloat16\")\n        loss = cce_obj(y_true, y_pred)\n        self.assertDType(loss, 
\"bfloat16\")\n\n\nclass BinaryFocalCrossentropyTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            losses.BinaryFocalCrossentropy(name=\"bfce\")\n        )\n\n    def test_all_correct_unweighted(self):\n        y_true = np.array(\n            [\n                [1, 0, 0],\n                [0, 1, 0],\n                [0, 0, 1],\n            ],\n            dtype=\"float32\",\n        )\n        obj = losses.BinaryFocalCrossentropy(gamma=1.5)\n        loss = obj(y_true, y_true)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n        # Test with logits.\n        logits = np.array(\n            [\n                [100.0, -100.0, -100.0],\n                [-100.0, 100.0, -100.0],\n                [-100.0, -100.0, 100.0],\n            ]\n        )\n        obj = losses.BinaryFocalCrossentropy(gamma=2.0, from_logits=True)\n        loss = obj(y_true, logits)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n    def test_unweighted(self):\n        y_true = np.asarray([1, 0, 1, 0]).reshape([2, 2])\n        y_pred = np.asarray([0.9, 0.8, 0.7, 0.2], dtype=np.float32).reshape(\n            [2, 2]\n        )\n        obj = losses.BinaryFocalCrossentropy(gamma=2.0)\n        loss = obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 0.268, 3)\n\n        # Test with logits.\n        y_true = np.array([[1, 1, 0], [0, 1, 0]], dtype=\"float32\")\n        logits = np.array([[1.5, -2.7, 2.9], [-3.8, 1.2, -4.5]])\n        obj = losses.BinaryFocalCrossentropy(gamma=3.0, from_logits=True)\n        loss = obj(y_true, logits)\n        self.assertAlmostEqual(loss, 0.799, 3)\n\n    def test_scalar_weighted(self):\n        y_true = np.asarray([1, 0, 1, 0]).reshape([2, 2])\n        y_pred = np.asarray([0.9, 0.8, 0.7, 0.2], dtype=np.float32).reshape(\n            [2, 2]\n        )\n        obj = losses.BinaryFocalCrossentropy(gamma=2.0)\n        loss = obj(y_true, y_pred, sample_weight=1.23)\n        self.assertAlmostEqual(loss, 
0.3296, 3)\n\n        # Test with logits.\n        y_true = np.array([[1, 1, 0], [0, 1, 0]], dtype=\"float32\")\n        logits = np.array([[1.5, -2.7, 2.9], [-3.8, 1.2, -4.5]])\n        obj = losses.BinaryFocalCrossentropy(gamma=3.0, from_logits=True)\n        loss = obj(y_true, logits, sample_weight=3.21)\n        self.assertAlmostEqual(loss, 2.565, 3)\n\n    def test_sample_weighted(self):\n        y_true = np.asarray([1, 0, 1, 0]).reshape([2, 2])\n        y_pred = np.asarray([0.9, 0.8, 0.7, 0.2], dtype=np.float32).reshape(\n            [2, 2]\n        )\n        sample_weight = np.array([1.2, 3.4]).reshape((2, 1))\n        obj = losses.BinaryFocalCrossentropy(gamma=2.0)\n        loss = obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 0.34415, 3)\n\n        # Test with logits.\n        y_true = np.array([[1, 1, 0], [0, 1, 0]], dtype=\"float32\")\n        logits = np.array([[1.5, -2.7, 2.9], [-3.8, 1.2, -4.5]])\n        obj = losses.BinaryFocalCrossentropy(gamma=3.0, from_logits=True)\n        loss = obj(y_true, logits, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 0.95977, 3)\n\n    def test_no_reduction(self):\n        y_true = np.asarray([1, 0, 1, 0]).reshape([2, 2])\n        y_pred = np.asarray([0.9, 0.8, 0.7, 0.2], dtype=np.float32).reshape(\n            [2, 2]\n        )\n        obj = losses.BinaryFocalCrossentropy(\n            gamma=2.0,\n            reduction=None,\n        )\n        loss = obj(y_true, y_pred)\n        self.assertAllClose(loss, (0.515547, 0.020513))\n\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"Torch doesn't support bfloat16 for BinaryFocalCrossentropy\",\n    )\n    def test_dtype_arg(self):\n        y_true = np.asarray([1, 0, 1, 0]).reshape([2, 2])\n        y_pred = np.asarray([0.9, 0.8, 0.7, 0.2], dtype=np.float32).reshape(\n            [2, 2]\n        )\n        obj = losses.BinaryFocalCrossentropy(dtype=\"bfloat16\")\n        loss = 
obj(y_true, y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass CategoricalFocalCrossentropyTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            losses.CategoricalFocalCrossentropy(name=\"cfce\")\n        )\n\n    def test_all_correct_unweighted(self):\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=\"int64\")\n        y_pred = np.array(\n            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],\n            dtype=\"float32\",\n        )\n        cce_obj = losses.CategoricalFocalCrossentropy(alpha=0.25, gamma=2.0)\n        loss = cce_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n        # Test with logits.\n        logits = np.array(\n            [[10.0, 0.0, 0.0], [0.0, 10.0, 0.0], [0.0, 0.0, 10.0]]\n        )\n        cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True)\n        loss = cce_obj(y_true, logits)\n        self.assertAlmostEqual(loss, 0.0, 3)\n\n    def test_unweighted(self):\n        cce_obj = losses.CategoricalFocalCrossentropy()\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])\n        y_pred = np.array(\n            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]],\n            dtype=\"float32\",\n        )\n        loss = cce_obj(y_true, y_pred)\n        self.assertAlmostEqual(loss, 0.02059, 3)\n\n        # Test with logits.\n        logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True)\n        loss = cce_obj(y_true, logits)\n        self.assertAlmostEqual(loss, 0.000345, 3)\n\n    def test_scalar_weighted(self):\n        cce_obj = losses.CategoricalFocalCrossentropy()\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])\n        y_pred = np.array(\n            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]],\n            dtype=\"float32\",\n        )\n        loss = cce_obj(y_true, 
y_pred, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 0.047368, 3)\n\n        # Test with logits.\n        logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True)\n        loss = cce_obj(y_true, logits, sample_weight=2.3)\n        self.assertAlmostEqual(loss, 0.000794, 4)\n\n    def test_sample_weighted(self):\n        cce_obj = losses.CategoricalFocalCrossentropy()\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])\n        y_pred = np.array(\n            [[0.9, 0.05, 0.05], [0.5, 0.89, 0.6], [0.05, 0.01, 0.94]],\n            dtype=\"float32\",\n        )\n        sample_weight = np.array([[1.2], [3.4], [5.6]]).reshape((3, 1))\n        loss = cce_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 0.06987, 3)\n\n        # Test with logits.\n        logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        cce_obj = losses.CategoricalFocalCrossentropy(from_logits=True)\n        loss = cce_obj(y_true, logits, sample_weight=sample_weight)\n        self.assertAlmostEqual(loss, 0.001933, 3)\n\n    def test_no_reduction(self):\n        y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])\n        logits = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        cce_obj = losses.CategoricalFocalCrossentropy(\n            from_logits=True, reduction=None\n        )\n        loss = cce_obj(y_true, logits)\n        self.assertAllClose(\n            (1.5096224e-09, 2.4136547e-11, 1.0360638e-03),\n            loss,\n        )\n\n    def test_label_smoothing(self):\n        logits = np.array([[4.9, -0.5, 2.05]])\n        y_true = np.array([[1, 0, 0]])\n        label_smoothing = 0.1\n\n        cce_obj = losses.CategoricalFocalCrossentropy(\n            from_logits=True, label_smoothing=label_smoothing\n        )\n        loss = cce_obj(y_true, logits)\n\n        expected_value = 0.06685\n        
self.assertAlmostEqual(loss, expected_value, 3)\n\n    def test_dtype_arg(self):\n        logits = np.array([[4.9, -0.5, 2.05]])\n        y_true = np.array([[1, 0, 0]])\n        cce_obj = losses.CategoricalFocalCrossentropy(\n            from_logits=True, dtype=\"bfloat16\"\n        )\n        loss = cce_obj(y_true, logits)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass CTCTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(losses.CTC(name=\"myctc\"))\n\n    def test_correctness(self):\n        logits = (np.arange(24).reshape((2, 4, 3)).astype(\"float32\") - 12) / 100\n        y_true = np.array(([[1, 2, 1, 0], [1, 2, 0, 2]]))\n        output = losses.CTC()(y_true, logits)\n        self.assertAllClose(output, 2.448645, tpu_atol=1e-3, tpu_rtol=1e-3)\n\n    def test_dtype_arg(self):\n        logits = (np.arange(24).reshape((2, 4, 3)).astype(\"float32\") - 12) / 100\n        y_true = np.array(([[1, 2, 1, 0], [1, 2, 0, 2]]))\n        output = losses.CTC(dtype=\"bfloat16\")(y_true, logits)\n        self.assertDType(output, \"bfloat16\")\n\n\nclass DiceTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(losses.Dice(name=\"mydice\"))\n\n    def test_correctness(self):\n        y_true = np.array(([[1, 2], [1, 2]]))\n        y_pred = np.array(([[4, 1], [6, 1]]))\n        output = losses.Dice()(y_true, y_pred)\n        self.assertAllClose(output, -0.55555546)\n\n    def test_binary_segmentation(self):\n        y_true = np.array(\n            ([[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])\n        )\n        y_pred = np.array(\n            ([[0, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1], [1, 0, 1, 1]])\n        )\n        output = losses.Dice()(y_true, y_pred)\n        self.assertAllClose(output, 0.77777773)\n\n    def test_binary_segmentation_with_axis(self):\n        y_true = np.array(\n            [[[[1.0], [1.0]], [[0.0], [0.0]]], [[[1.0], [1.0]], [[0.0], [0.0]]]]\n       
 )\n        y_pred = np.array(\n            [[[[0.0], [1.0]], [[0.0], [1.0]]], [[[0.4], [0.0]], [[0.0], [0.9]]]]\n        )\n        output = losses.Dice(axis=(1, 2, 3), reduction=None)(y_true, y_pred)\n        self.assertAllClose(output, [0.5, 0.75757575])\n\n    def test_dtype_arg(self):\n        y_true = np.array(([[1, 2], [1, 2]]))\n        y_pred = np.array(([[4, 1], [6, 1]]))\n        output = losses.Dice(dtype=\"bfloat16\")(y_true, y_pred)\n        self.assertDType(output, \"bfloat16\")\n\n\nclass TverskyTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(losses.Tversky(name=\"mytversky\"))\n\n    def test_correctness(self):\n        y_true = np.array(([[1, 2], [1, 2]]))\n        y_pred = np.array(([[4, 1], [6, 1]]))\n        output = losses.Tversky()(y_true, y_pred)\n        self.assertAllClose(output, -0.55555546)\n\n    def test_correctness_custom_coefficients(self):\n        y_true = np.array(([[1, 2], [1, 2]]))\n        y_pred = np.array(([[4, 1], [6, 1]]))\n        output = losses.Tversky(alpha=0.2, beta=0.8)(y_true, y_pred)\n        self.assertAllClose(output, -0.29629636)\n\n    def test_binary_segmentation(self):\n        y_true = np.array(\n            ([[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])\n        )\n        y_pred = np.array(\n            ([[0, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1], [1, 0, 1, 1]])\n        )\n        output = losses.Tversky()(y_true, y_pred)\n        self.assertAllClose(output, 0.77777773)\n\n    def test_binary_segmentation_with_axis(self):\n        y_true = np.array(\n            [[[[1.0], [1.0]], [[0.0], [0.0]]], [[[1.0], [1.0]], [[0.0], [0.0]]]]\n        )\n        y_pred = np.array(\n            [[[[0.0], [1.0]], [[0.0], [1.0]]], [[[0.4], [0.0]], [[0.0], [0.9]]]]\n        )\n        output = losses.Tversky(axis=(1, 2, 3), reduction=None)(y_true, y_pred)\n        self.assertAllClose(output, [0.5, 0.75757575])\n\n    def 
test_binary_segmentation_custom_coefficients(self):\n        y_true = np.array(\n            ([[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]])\n        )\n        y_pred = np.array(\n            ([[0, 1, 0, 1], [1, 0, 1, 1], [0, 1, 0, 1], [1, 0, 1, 1]])\n        )\n        output = losses.Tversky(alpha=0.2, beta=0.8)(y_true, y_pred)\n        self.assertAllClose(output, 0.7916667)\n\n    def test_binary_segmentation_custom_coefficients_with_axis(self):\n        y_true = np.array(\n            [[[[1.0], [1.0]], [[0.0], [0.0]]], [[[1.0], [1.0]], [[0.0], [0.0]]]]\n        )\n        y_pred = np.array(\n            [[[[0.0], [1.0]], [[0.0], [1.0]]], [[[0.4], [0.0]], [[0.0], [0.9]]]]\n        )\n        output = losses.Tversky(\n            alpha=0.2, beta=0.8, axis=(1, 2, 3), reduction=None\n        )(y_true, y_pred)\n        self.assertAllClose(output, [0.5, 0.7222222])\n\n    def test_dtype_arg(self):\n        y_true = np.array(([[1, 2], [1, 2]]))\n        y_pred = np.array(([[4, 1], [6, 1]]))\n        output = losses.Tversky(dtype=\"bfloat16\")(y_true, y_pred)\n        self.assertDType(output, \"bfloat16\")\n\n\nclass CircleTest(testing.TestCase):\n    def setup(self):\n        super().setUp()\n        self.y_true = np.array([1, 1, 2, 2, 3])\n        self.y_pred = np.array(\n            [\n                [0.70014004, -0.42008403, 0.14002801, 0.56011203],\n                [0.17609018, 0.70436073, -0.52827054, 0.44022545],\n                [-0.34050261, 0.25537696, -0.68100522, 0.59587957],\n                [0.32163376, -0.75047877, 0.53605627, -0.21442251],\n                [0.51261459, -0.34174306, 0.17087153, 0.76892189],\n            ]\n        )\n        self.ref_labels = np.array([1, 1, 2, 2, 3, 4])\n        self.ref_embeddings = np.array(\n            [\n                [0.40824829, -0.54433105, 0.27216553, 0.68041382],\n                [0.76376261, 0.10910895, -0.54554473, 0.32732684],\n                [-0.74420841, 0.24806947, 0.49613894, 
-0.3721042],\n                [0.52981294, -0.13245324, 0.79471941, -0.26490647],\n                [0.54554473, -0.32732684, 0.10910895, 0.76376261],\n                [-0.27216553, 0.68041382, 0.40824829, -0.54433105],\n            ]\n        )\n\n    def test_config(self):\n        self.run_class_serialization_test(\n            losses.Circle(name=\"mycircle\", gamma=80.0, margin=0.4)\n        )\n\n    def test_correctness(self):\n        self.setup()\n        circle_loss = losses.Circle(gamma=80.0, margin=0.4)\n        loss = circle_loss(self.y_true, self.y_pred)\n        self.assertAlmostEqual(loss, 188.3883, tpu_decimal=0)\n\n        circle_loss = losses.Circle(gamma=256, margin=0.25)\n        loss = circle_loss(self.y_true, self.y_pred)\n        self.assertAlmostEqual(loss, 652.7617, tpu_decimal=0)\n\n        loss = losses.circle(\n            self.y_true,\n            self.y_pred,\n            ref_labels=self.ref_labels,\n            ref_embeddings=self.ref_embeddings,\n            gamma=80.0,\n            margin=0.4,\n            remove_diagonal=False,\n        )\n\n        self.assertAllClose(\n            loss,\n            (61.5844, 94.3465, 276.9344, 90.9873, 48.8963),\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n    def test_correctness_weighted(self):\n        self.setup()\n        sample_weight = np.array([2.0, 2.0, 1.0, 1.0, 0.5])\n        circle_loss = losses.Circle(gamma=80.0, margin=0.4)\n        loss = circle_loss(\n            self.y_true, self.y_pred, sample_weight=sample_weight\n        )\n        self.assertAlmostEqual(loss, 244.91918, tpu_decimal=0)\n\n    def test_no_reduction(self):\n        self.setup()\n        circle_loss = losses.Circle(gamma=80.0, margin=0.4, reduction=None)\n        loss = circle_loss(self.ref_labels, self.ref_embeddings)\n\n        self.assertAllClose(\n            loss,\n            [82.9116, 36.7942, 92.4590, 52.6798, 0.0, 0.0],\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n     
   )\n\n    def test_sum_reduction(self):\n        self.setup()\n        circle_loss = losses.Circle(gamma=80.0, margin=0.4, reduction=\"sum\")\n        loss = circle_loss(self.ref_labels, self.ref_embeddings)\n\n        self.assertAlmostEqual(loss, 264.845, tpu_decimal=0)\n\n    def test_mean_with_sample_weight_reduction(self):\n        self.setup()\n        sample_weight = np.array([2.0, 2.0, 1.0, 1.0, 0.5])\n        circle_loss = losses.Circle(\n            gamma=80.0, margin=0.4, reduction=\"mean_with_sample_weight\"\n        )\n        loss = circle_loss(\n            self.y_true, self.y_pred, sample_weight=sample_weight\n        )\n        self.assertAlmostEqual(loss, 163.27948, tpu_decimal=0)\n\n    def test_dtype_arg(self):\n        self.setup()\n        circle_loss = losses.Circle(dtype=\"bfloat16\")\n        loss = circle_loss(self.y_true, self.y_pred)\n        self.assertDType(loss, \"bfloat16\")\n\n\nclass CategoricalGeneralizedCrossEntropyTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            losses.CategoricalGeneralizedCrossEntropy(name=\"gce\")\n        )\n        self.run_class_serialization_test(\n            losses.CategoricalGeneralizedCrossEntropy(q=0.1, name=\"gce\")\n        )\n\n    def test_basic_correctness_for_binary(self):\n        y_true = np.array([0, 1, 0, 1])\n        y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.4, 0.6]])\n        # Calculate expected GCE loss manually\n        # For q=0.5:\n        # First sample (class 0): gce = (1 - 0.7^0.5) / 0.5\n        # Second sample (class 1): gce = (1 - 0.8^0.5) / 0.5\n        # Third sample (class 0): gce = (1 - 0.6^0.5) / 0.5\n        # Fourth sample (class 1): gce = (1 - 0.6^0.5) / 0.5\n        expected = np.array(\n            [\n                (1 - np.power(0.7, 0.5)) / 0.5,\n                (1 - np.power(0.8, 0.5)) / 0.5,\n                (1 - np.power(0.6, 0.5)) / 0.5,\n                (1 - np.power(0.6, 0.5)) / 
0.5,\n            ]\n        )\n        output = losses.CategoricalGeneralizedCrossEntropy()(y_true, y_pred)\n        self.assertAllClose(output, expected.sum() / len(expected))\n\n        expected_q_08 = np.array(\n            [\n                (1 - np.power(0.7, 0.8)) / 0.8,\n                (1 - np.power(0.8, 0.8)) / 0.8,\n                (1 - np.power(0.6, 0.8)) / 0.8,\n                (1 - np.power(0.6, 0.8)) / 0.8,\n            ]\n        )\n        output = losses.CategoricalGeneralizedCrossEntropy(q=0.8)(\n            y_true, y_pred\n        )\n        self.assertAllClose(output, expected_q_08.sum() / len(expected_q_08))\n\n    def test_basic_correctness_for_multi_class(self):\n        y_true = np.array([0, 1, 0, 1])\n        y_pred = np.array(\n            [[0.7, 0.3, 0.0], [0.2, 0.2, 0.6], [0.6, 0.4, 0.0], [0.2, 0.2, 0.6]]\n        )\n        # Calculate expected GCE loss manually\n        # For q=0.5:\n        # First sample (class 0): gce = (1 - 0.7^0.5) / 0.5\n        # Second sample (class 1): gce = (1 - 0.2^0.5) / 0.5\n        # Third sample (class 0): gce = (1 - 0.6^0.5) / 0.5\n        # Fourth sample (class 1): gce = (1 - 0.2^0.5) / 0.5\n        expected = np.array(\n            [\n                (1 - np.power(0.7, 0.5)) / 0.5,\n                (1 - np.power(0.2, 0.5)) / 0.5,\n                (1 - np.power(0.6, 0.5)) / 0.5,\n                (1 - np.power(0.2, 0.5)) / 0.5,\n            ]\n        )\n        output = losses.CategoricalGeneralizedCrossEntropy()(y_true, y_pred)\n        self.assertAllClose(output, expected.sum() / len(expected))\n\n        expected_q_08 = np.array(\n            [\n                (1 - np.power(0.7, 0.8)) / 0.8,\n                (1 - np.power(0.2, 0.8)) / 0.8,\n                (1 - np.power(0.6, 0.8)) / 0.8,\n                (1 - np.power(0.2, 0.8)) / 0.8,\n            ]\n        )\n        output = losses.CategoricalGeneralizedCrossEntropy(q=0.8)(\n            y_true, y_pred\n        )\n        
self.assertAllClose(output, expected_q_08.sum() / len(expected_q_08))\n\n    def test_binary_segmentation(self):\n        y_true = np.array(\n            [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]\n        )\n        y_pred = np.array(\n            [\n                [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]],\n                [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],\n                [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],\n                [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]],\n            ]\n        )\n        output = losses.CategoricalGeneralizedCrossEntropy(q=0.5)(\n            y_true, y_pred\n        )\n        self.assertAllClose(output, 0.0)\n\n        y_true = np.array(\n            [[0, 1, 1, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]\n        )\n        y_pred = np.array(\n            [\n                [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [0.2, 0.8]],\n                [[0.0, 1.0], [1.0, 0.0], [0.0, 1.0], [1.0, 0.0]],\n                [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]],\n                [[0.0, 1.0], [0.0, 1.0], [1.0, 0.0], [0.6, 0.4]],\n            ]\n        )\n        expected = np.array(\n            [\n                (1 - np.power(0.2, 0.5)) / 0.5,\n                (1 - np.power(0.4, 0.5)) / 0.5,\n            ]\n        )\n        output = losses.CategoricalGeneralizedCrossEntropy(q=0.5)(\n            y_true, y_pred\n        )\n        self.assertAllClose(output, expected.sum() / 16.0)  # 16 pixels\n\n    def test_multi_class_segmentation(self):\n        y_true = np.array(\n            [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]\n        )\n        y_pred = np.array(\n            [\n                [\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [0.0, 0.0, 1.0],\n                    [1.0, 0.0, 0.0],\n                ],\n                [\n                    [0.0, 1.0, 0.0],\n                    [1.0, 0.0, 0.0],\n        
            [0.0, 1.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                ],\n                [\n                    [1.0, 0.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                ],\n                [\n                    [0.0, 1.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                ],\n            ]\n        )\n        output = losses.CategoricalGeneralizedCrossEntropy(q=0.5)(\n            y_true, y_pred\n        )\n        self.assertAllClose(output, 0.0)\n\n        y_true = np.array(\n            [[0, 1, 2, 0], [1, 0, 1, 0], [0, 0, 1, 1], [1, 1, 0, 1]]\n        )\n        y_pred = np.array(\n            [\n                [\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [0.0, 0.0, 1.0],\n                    [0.2, 0.0, 0.8],\n                ],\n                [\n                    [1.0, 0.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                ],\n                [\n                    [1.0, 0.0, 0.0],\n                    [1.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                ],\n                [\n                    [0.0, 1.0, 0.0],\n                    [0.0, 1.0, 0.0],\n                    [0.5, 0.5, 0.0],\n                    [0.0, 1.0, 0.0],\n                ],\n            ]\n        )\n        expected = np.array(\n            [\n                (1 - np.power(0.2, 0.5)) / 0.5,\n                (1 - np.power(0.0, 0.5)) / 0.5,\n                (1 - np.power(0.5, 0.5)) / 0.5,\n            ]\n        )\n        output = losses.CategoricalGeneralizedCrossEntropy(q=0.5)(\n            y_true, y_pred\n        )\n        self.assertAllClose(output, expected.sum() / 16.0)  # 16 pixels\n\n    def 
test_dtype_arg(self):\n        y_true = np.array([0, 1, 0, 1])\n        y_pred = np.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.4, 0.6]])\n        output = losses.CategoricalGeneralizedCrossEntropy(dtype=\"bfloat16\")(\n            y_true, y_pred\n        )\n        self.assertDType(output, \"bfloat16\")\n"
  },
  {
    "path": "keras/src/metrics/__init__.py",
    "content": "import inspect\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.metrics.accuracy_metrics import Accuracy\nfrom keras.src.metrics.accuracy_metrics import BinaryAccuracy\nfrom keras.src.metrics.accuracy_metrics import CategoricalAccuracy\nfrom keras.src.metrics.accuracy_metrics import SparseCategoricalAccuracy\nfrom keras.src.metrics.accuracy_metrics import SparseTopKCategoricalAccuracy\nfrom keras.src.metrics.accuracy_metrics import TopKCategoricalAccuracy\nfrom keras.src.metrics.confusion_metrics import AUC\nfrom keras.src.metrics.confusion_metrics import FalseNegatives\nfrom keras.src.metrics.confusion_metrics import FalsePositives\nfrom keras.src.metrics.confusion_metrics import Precision\nfrom keras.src.metrics.confusion_metrics import PrecisionAtRecall\nfrom keras.src.metrics.confusion_metrics import Recall\nfrom keras.src.metrics.confusion_metrics import RecallAtPrecision\nfrom keras.src.metrics.confusion_metrics import SensitivityAtSpecificity\nfrom keras.src.metrics.confusion_metrics import SpecificityAtSensitivity\nfrom keras.src.metrics.confusion_metrics import TrueNegatives\nfrom keras.src.metrics.confusion_metrics import TruePositives\nfrom keras.src.metrics.correlation_metrics import ConcordanceCorrelation\nfrom keras.src.metrics.correlation_metrics import PearsonCorrelation\nfrom keras.src.metrics.f_score_metrics import F1Score\nfrom keras.src.metrics.f_score_metrics import FBetaScore\nfrom keras.src.metrics.hinge_metrics import CategoricalHinge\nfrom keras.src.metrics.hinge_metrics import Hinge\nfrom keras.src.metrics.hinge_metrics import SquaredHinge\nfrom keras.src.metrics.iou_metrics import BinaryIoU\nfrom keras.src.metrics.iou_metrics import IoU\nfrom keras.src.metrics.iou_metrics import MeanIoU\nfrom keras.src.metrics.iou_metrics import OneHotIoU\nfrom keras.src.metrics.iou_metrics import OneHotMeanIoU\nfrom keras.src.metrics.metric import Metric\nfrom keras.src.metrics.probabilistic_metrics import 
BinaryCrossentropy\nfrom keras.src.metrics.probabilistic_metrics import CategoricalCrossentropy\nfrom keras.src.metrics.probabilistic_metrics import KLDivergence\nfrom keras.src.metrics.probabilistic_metrics import Poisson\nfrom keras.src.metrics.probabilistic_metrics import (\n    SparseCategoricalCrossentropy,\n)\nfrom keras.src.metrics.reduction_metrics import Mean\nfrom keras.src.metrics.reduction_metrics import MeanMetricWrapper\nfrom keras.src.metrics.reduction_metrics import Sum\nfrom keras.src.metrics.regression_metrics import CosineSimilarity\nfrom keras.src.metrics.regression_metrics import LogCoshError\nfrom keras.src.metrics.regression_metrics import MeanAbsoluteError\nfrom keras.src.metrics.regression_metrics import MeanAbsolutePercentageError\nfrom keras.src.metrics.regression_metrics import MeanSquaredError\nfrom keras.src.metrics.regression_metrics import MeanSquaredLogarithmicError\nfrom keras.src.metrics.regression_metrics import R2Score\nfrom keras.src.metrics.regression_metrics import RootMeanSquaredError\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils.naming import to_snake_case\n\nALL_OBJECTS = {\n    # Base\n    Metric,\n    Mean,\n    Sum,\n    MeanMetricWrapper,\n    # Regression\n    MeanSquaredError,\n    RootMeanSquaredError,\n    MeanAbsoluteError,\n    MeanAbsolutePercentageError,\n    MeanSquaredLogarithmicError,\n    CosineSimilarity,\n    LogCoshError,\n    R2Score,\n    # Classification\n    AUC,\n    FalseNegatives,\n    FalsePositives,\n    Precision,\n    PrecisionAtRecall,\n    Recall,\n    RecallAtPrecision,\n    SensitivityAtSpecificity,\n    SpecificityAtSensitivity,\n    TrueNegatives,\n    TruePositives,\n    # Correlation\n    ConcordanceCorrelation,\n    PearsonCorrelation,\n    # Hinge\n    Hinge,\n    SquaredHinge,\n    CategoricalHinge,\n    # Probabilistic\n    KLDivergence,\n    Poisson,\n    BinaryCrossentropy,\n    CategoricalCrossentropy,\n    SparseCategoricalCrossentropy,\n    # Accuracy\n 
   Accuracy,\n    BinaryAccuracy,\n    CategoricalAccuracy,\n    SparseCategoricalAccuracy,\n    TopKCategoricalAccuracy,\n    SparseTopKCategoricalAccuracy,\n    # F-Score\n    F1Score,\n    FBetaScore,\n    # IoU\n    IoU,\n    BinaryIoU,\n    MeanIoU,\n    OneHotIoU,\n    OneHotMeanIoU,\n}\nALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}\nALL_OBJECTS_DICT.update(\n    {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS}\n)\n# TODO: Align with `tf.keras` and set the name attribute of metrics\n# with the key name. Currently it uses default name of class definitions.\nALL_OBJECTS_DICT.update(\n    {\n        \"bce\": BinaryCrossentropy,\n        \"BCE\": BinaryCrossentropy,\n        \"mse\": MeanSquaredError,\n        \"MSE\": MeanSquaredError,\n        \"mae\": MeanAbsoluteError,\n        \"MAE\": MeanAbsoluteError,\n        \"mape\": MeanAbsolutePercentageError,\n        \"MAPE\": MeanAbsolutePercentageError,\n        \"msle\": MeanSquaredLogarithmicError,\n        \"MSLE\": MeanSquaredLogarithmicError,\n    }\n)\n\n\n@keras_export(\"keras.metrics.serialize\")\ndef serialize(metric):\n    \"\"\"Serializes metric function or `Metric` instance.\n\n    Args:\n        metric: A Keras `Metric` instance or a metric function.\n\n    Returns:\n        Metric configuration dictionary.\n    \"\"\"\n    return serialization_lib.serialize_keras_object(metric)\n\n\n@keras_export(\"keras.metrics.deserialize\")\ndef deserialize(config, custom_objects=None):\n    \"\"\"Deserializes a serialized metric class/function instance.\n\n    Args:\n        config: Metric configuration.\n        custom_objects: Optional dictionary mapping names (strings)\n            to custom objects (classes and functions) to be\n            considered during deserialization.\n\n    Returns:\n        A Keras `Metric` instance or a metric function.\n    \"\"\"\n    return serialization_lib.deserialize_keras_object(\n        config,\n        module_objects=ALL_OBJECTS_DICT,\n        
custom_objects=custom_objects,\n    )\n\n\n@keras_export(\"keras.metrics.get\")\ndef get(identifier):\n    \"\"\"Retrieves a Keras metric as a `function`/`Metric` class instance.\n\n    The `identifier` may be the string name of a metric function or class.\n\n    >>> metric = metrics.get(\"categorical_crossentropy\")\n    >>> type(metric)\n    <class 'function'>\n    >>> metric = metrics.get(\"CategoricalCrossentropy\")\n    >>> type(metric)\n    <class '...metrics.CategoricalCrossentropy'>\n\n    You can also specify the `config` of the metric to this function by\n    passing a dict containing `class_name` and `config` as an identifier. Also\n    note that the `class_name` must map to a `Metric` class.\n\n    >>> identifier = {\"class_name\": \"CategoricalCrossentropy\",\n    ...               \"config\": {\"from_logits\": True}}\n    >>> metric = metrics.get(identifier)\n    >>> type(metric)\n    <class '...metrics.CategoricalCrossentropy'>\n\n    Args:\n        identifier: A metric identifier. One of `None`, the string name of a\n            metric function/class, a metric configuration dictionary, a metric\n            function, or a metric class instance.\n\n    Returns:\n        A Keras metric as a `function`/`Metric` class instance.\n    \"\"\"\n    if identifier is None:\n        return None\n    if isinstance(identifier, dict):\n        obj = deserialize(identifier)\n    elif isinstance(identifier, str):\n        obj = ALL_OBJECTS_DICT.get(identifier, None)\n    else:\n        obj = identifier\n    if callable(obj):\n        if inspect.isclass(obj):\n            obj = obj()\n        return obj\n    else:\n        raise ValueError(f\"Could not interpret metric identifier: {identifier}\")\n"
  },
  {
    "path": "keras/src/metrics/accuracy_metrics.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.losses.loss import squeeze_or_expand_to_same_rank\nfrom keras.src.metrics import reduction_metrics\n\n\ndef accuracy(y_true, y_pred):\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n    return ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx())\n\n\n@keras_export(\"keras.metrics.Accuracy\")\nclass Accuracy(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Calculates how often predictions equal labels.\n\n    This metric creates two local variables, `total` and `count` that are used\n    to compute the frequency with which `y_pred` matches `y_true`. This\n    frequency is ultimately returned as `binary accuracy`: an idempotent\n    operation that simply divides `total` by `count`.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Examples:\n\n    >>> m = keras.metrics.Accuracy()\n    >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]])\n    >>> m.result()\n    0.75\n\n    >>> m.reset_state()\n    >>> m.update_state([[1], [2], [3], [4]], [[0], [2], [3], [4]],\n    ...                
sample_weight=[1, 1, 0, 0])\n    >>> m.result()\n    0.5\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='binary_crossentropy',\n                  metrics=[keras.metrics.Accuracy()])\n    ```\n    \"\"\"\n\n    def __init__(self, name=\"accuracy\", dtype=None):\n        super().__init__(fn=accuracy, name=name, dtype=dtype)\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n@keras_export(\"keras.metrics.binary_accuracy\")\ndef binary_accuracy(y_true, y_pred, threshold=0.5):\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true)\n    threshold = ops.convert_to_tensor(threshold)\n    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n    y_pred = ops.cast(ops.greater(y_pred, threshold), y_true.dtype)\n    return ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx())\n\n\n@keras_export(\"keras.metrics.BinaryAccuracy\")\nclass BinaryAccuracy(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Calculates how often predictions match binary labels.\n\n    This metric creates two local variables, `total` and `count` that are used\n    to compute the frequency with which `y_pred` matches `y_true`. 
This\n    frequency is ultimately returned as `binary accuracy`: an idempotent\n    operation that simply divides `total` by `count`.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        threshold: (Optional) Float representing the threshold for deciding\n            whether prediction values are 1 or 0.\n\n    Example:\n\n    >>> m = keras.metrics.BinaryAccuracy()\n    >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]])\n    >>> m.result()\n    0.75\n\n    >>> m.reset_state()\n    >>> m.update_state([[1], [1], [0], [0]], [[0.98], [1], [0], [0.6]],\n    ...                sample_weight=[1, 0, 0, 1])\n    >>> m.result()\n    0.5\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='binary_crossentropy',\n                  metrics=[keras.metrics.BinaryAccuracy()])\n    ```\n    \"\"\"\n\n    def __init__(self, name=\"binary_accuracy\", dtype=None, threshold=0.5):\n        super().__init__(\n            fn=binary_accuracy, name=name, dtype=dtype, threshold=threshold\n        )\n        self.threshold = threshold\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n    def get_config(self):\n        return {\n            \"name\": self.name,\n            \"dtype\": self.dtype,\n            \"threshold\": self.threshold,\n        }\n\n\n@keras_export(\"keras.metrics.categorical_accuracy\")\ndef categorical_accuracy(y_true, y_pred):\n    y_true = ops.argmax(y_true, axis=-1)\n\n    reshape_matches = False\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n\n    y_true_org_shape = ops.shape(y_true)\n    y_pred_rank = len(y_pred.shape)\n    y_true_rank = len(y_true.shape)\n\n    # If the shape of y_true is 
(num_samples, 1), squeeze to (num_samples,)\n    if (\n        (y_true_rank is not None)\n        and (y_pred_rank is not None)\n        and (len(y_true.shape) == len(y_pred.shape))\n    ):\n        y_true = ops.squeeze(y_true, -1)\n        reshape_matches = True\n    y_pred = ops.argmax(y_pred, axis=-1)\n\n    # If the predicted output and actual output types don't match, force cast\n    # them to match.\n    if y_pred.dtype is not y_true.dtype:\n        y_pred = ops.cast(y_pred, dtype=y_true.dtype)\n    matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())\n    if reshape_matches:\n        matches = ops.reshape(matches, y_true_org_shape)\n    return matches\n\n\n@keras_export(\"keras.metrics.CategoricalAccuracy\")\nclass CategoricalAccuracy(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Calculates how often predictions match one-hot labels.\n\n    You can provide logits of classes as `y_pred`, since the argmax of\n    logits and probabilities is the same.\n\n    This metric creates two local variables, `total` and `count` that are used\n    to compute the frequency with which `y_pred` matches `y_true`. This\n    frequency is ultimately returned as `categorical accuracy`: an idempotent\n    operation that simply divides `total` by `count`.\n\n    `y_pred` and `y_true` should be passed in as vectors of probabilities,\n    rather than as labels. If necessary, use `ops.one_hot` to expand `y_true` as\n    a vector.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.CategoricalAccuracy()\n    >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],\n    ...                 [0.05, 0.95, 0]])\n    >>> m.result()\n    0.5\n\n    >>> m.reset_state()\n    >>> m.update_state([[0, 0, 1], [0, 1, 0]], [[0.1, 0.9, 0.8],\n    ...              
   [0.05, 0.95, 0]],\n    ...                sample_weight=[0.7, 0.3])\n    >>> m.result()\n    0.3\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='categorical_crossentropy',\n                  metrics=[keras.metrics.CategoricalAccuracy()])\n    ```\n    \"\"\"\n\n    def __init__(self, name=\"categorical_accuracy\", dtype=None):\n        super().__init__(fn=categorical_accuracy, name=name, dtype=dtype)\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n@keras_export(\"keras.metrics.sparse_categorical_accuracy\")\ndef sparse_categorical_accuracy(y_true, y_pred):\n    reshape_matches = False\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    y_true_org_shape = ops.shape(y_true)\n    y_pred_rank = len(y_pred.shape)\n    y_true_rank = len(y_true.shape)\n\n    # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)\n    if (\n        (y_true_rank is not None)\n        and (y_pred_rank is not None)\n        and (len(y_true.shape) == len(y_pred.shape))\n        and ops.shape(y_true)[-1] == 1\n    ):\n        y_true = ops.squeeze(y_true, -1)\n        reshape_matches = True\n    y_pred = ops.argmax(y_pred, axis=-1)\n\n    # If the predicted output and actual output types don't match, force cast\n    # them to match.\n    if y_pred.dtype is not y_true.dtype:\n        y_pred = ops.cast(y_pred, y_true.dtype)\n    matches = ops.cast(ops.equal(y_true, y_pred), backend.floatx())\n    if reshape_matches:\n        matches = ops.reshape(matches, y_true_org_shape)\n    # if shape is (num_samples, 1) squeeze\n    if len(matches.shape) > 1 and matches.shape[-1] == 1:\n        matches = ops.squeeze(matches, -1)\n    return matches\n\n\n@keras_export(\"keras.metrics.SparseCategoricalAccuracy\")\nclass 
SparseCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Calculates how often predictions match integer labels.\n\n    ```python\n    acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1)))\n    ```\n\n    You can provide logits of classes as `y_pred`, since the argmax of\n    logits and probabilities is the same.\n\n    This metric creates two local variables, `total` and `count` that are used\n    to compute the frequency with which `y_pred` matches `y_true`. This\n    frequency is ultimately returned as `sparse categorical accuracy`: an\n    idempotent operation that simply divides `total` by `count`.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.SparseCategoricalAccuracy()\n    >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]])\n    >>> m.result()\n    0.5\n\n    >>> m.reset_state()\n    >>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]],\n    ...                
sample_weight=[0.7, 0.3])\n    >>> m.result()\n    0.3\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='sparse_categorical_crossentropy',\n                  metrics=[keras.metrics.SparseCategoricalAccuracy()])\n    ```\n    \"\"\"\n\n    def __init__(self, name=\"sparse_categorical_accuracy\", dtype=None):\n        super().__init__(fn=sparse_categorical_accuracy, name=name, dtype=dtype)\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n@keras_export(\"keras.metrics.top_k_categorical_accuracy\")\ndef top_k_categorical_accuracy(y_true, y_pred, k=5):\n    reshape_matches = False\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    y_true = ops.argmax(y_true, axis=-1)\n    y_true_rank = len(y_true.shape)\n    y_pred_rank = len(y_pred.shape)\n    y_true_org_shape = ops.shape(y_true)\n\n    # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,)\n    if (y_true_rank is not None) and (y_pred_rank is not None):\n        if y_pred_rank > 2:\n            y_pred = ops.reshape(y_pred, [-1, y_pred.shape[-1]])\n        if y_true_rank > 1:\n            reshape_matches = True\n            y_true = ops.reshape(y_true, [-1])\n\n    matches = ops.cast(\n        ops.in_top_k(ops.cast(y_true, \"int32\"), y_pred, k=k),\n        dtype=backend.floatx(),\n    )\n\n    # returned matches is expected to have same shape as y_true input\n    if reshape_matches:\n        matches = ops.reshape(matches, y_true_org_shape)\n\n    return matches\n\n\n@keras_export(\"keras.metrics.TopKCategoricalAccuracy\")\nclass TopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes how often targets are in the top `K` predictions.\n\n    Args:\n        k: (Optional) Number of top elements to look at for 
computing accuracy.\n            Defaults to `5`.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.TopKCategoricalAccuracy(k=1)\n    >>> m.update_state([[0, 0, 1], [0, 1, 0]],\n    ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])\n    >>> m.result()\n    0.5\n\n    >>> m.reset_state()\n    >>> m.update_state([[0, 0, 1], [0, 1, 0]],\n    ...                [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],\n    ...                sample_weight=[0.7, 0.3])\n    >>> m.result()\n    0.3\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='categorical_crossentropy',\n                  metrics=[keras.metrics.TopKCategoricalAccuracy()])\n    ```\n    \"\"\"\n\n    def __init__(self, k=5, name=\"top_k_categorical_accuracy\", dtype=None):\n        super().__init__(\n            fn=top_k_categorical_accuracy,\n            name=name,\n            dtype=dtype,\n            k=k,\n        )\n        self.k = k\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype, \"k\": self.k}\n\n\n@keras_export(\"keras.metrics.sparse_top_k_categorical_accuracy\")\ndef sparse_top_k_categorical_accuracy(\n    y_true, y_pred, k=5, from_sorted_ids=False\n):\n    \"\"\"Computes how often integer targets are in the top `K` predictions.\n\n    Args:\n        y_true: A tensor of shape `(batch_size)` representing indices or IDs of\n            true categories.\n        y_pred: If `from_sorted_ids=False`, a tensor of shape\n            `(batch_size, num_categories)` containing the scores for each sample\n            for all possible categories. 
If `from_sorted_ids=True`, a tensor of\n            shape `(batch_size, N)` containing indices or IDs of the top `N`\n            categories in order from highest score to lowest score.\n        k: (Optional) Number of top elements to look at for computing accuracy.\n            Defaults to `5`.\n        from_sorted_ids: (Optional) Whether `y_pred` is sorted category IDs or\n            scores for all categories (the default).\n\n    Returns:\n        A tensor with the same shape as `y_true` containing ones where `y_true`\n        is in the top `k` and zeros elsewhere.\n    \"\"\"\n    reshape_matches = False\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true_dtype = y_pred.dtype if from_sorted_ids else \"int32\"\n    y_true = ops.convert_to_tensor(y_true, dtype=y_true_dtype)\n    y_true_rank = len(y_true.shape)\n    y_pred_rank = len(y_pred.shape)\n    y_true_org_shape = ops.shape(y_true)\n\n    # Flatten y_pred to (batch_size, num_samples) and y_true to (num_samples,)\n    if (y_true_rank is not None) and (y_pred_rank is not None):\n        if y_pred_rank > 2:\n            y_pred = ops.reshape(y_pred, [-1, y_pred.shape[-1]])\n        if y_true_rank > 1:\n            reshape_matches = True\n            y_true = ops.reshape(y_true, [-1])\n\n    if from_sorted_ids:\n        # By slicing the first k items, we assume they are sorted by score.\n        # Reduce with `any` to count multiple matches only once.\n        matches = ops.any(\n            ops.equal(ops.expand_dims(y_true, axis=1), y_pred[:, :k]), axis=1\n        )\n    else:\n        matches = ops.in_top_k(y_true, y_pred, k=k)\n\n    matches = ops.cast(matches, dtype=backend.floatx())\n\n    # returned matches is expected to have same shape as y_true input\n    if reshape_matches:\n        matches = ops.reshape(matches, y_true_org_shape)\n\n    return matches\n\n\n@keras_export(\"keras.metrics.SparseTopKCategoricalAccuracy\")\nclass SparseTopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):\n    
\"\"\"Computes how often integer targets are in the top `K` predictions.\n\n    By default, the arguments expected by `update_state()` are:\n    - `y_true`: a tensor of shape `(batch_size)` representing indices of true\n        categories.\n    - `y_pred`: a tensor of shape `(batch_size, num_categories)` containing the\n        scores for each sample for all possible categories.\n\n    With `from_sorted_ids=True`, the arguments expected by `update_state` are:\n    - `y_true`: a tensor of shape `(batch_size)` representing indices or IDs of\n        true categories.\n    - `y_pred`: a tensor of shape `(batch_size, N)` containing the indices or\n        IDs of the top `N` categories sorted in order from highest score to\n        lowest score. `N` must be greater or equal to `k`.\n\n    The `from_sorted_ids=True` option can be more efficient when the set of\n    categories is very large and the model has an optimized way to retrieve the\n    top ones either without scoring or without maintaining the scores for all\n    the possible categories.\n\n    Args:\n        k: (Optional) Number of top elements to look at for computing accuracy.\n            Defaults to `5`.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        from_sorted_ids: (Optional) When `False`, the default, the tensor passed\n            in `y_pred` contains the unsorted scores of all possible categories.\n            When `True`, `y_pred` contains a the indices or IDs for the top\n            categories.\n\n    Example:\n\n    >>> m = keras.metrics.SparseTopKCategoricalAccuracy(k=1)\n    >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]])\n    >>> m.result()\n    0.5\n\n    >>> m.reset_state()\n    >>> m.update_state([2, 1], [[0.1, 0.9, 0.8], [0.05, 0.95, 0]],\n    ...                sample_weight=[0.7, 0.3])\n    >>> m.result()\n    0.3\n\n    >>> m = keras.metrics.SparseTopKCategoricalAccuracy(k=1,\n    ...          
                                      from_sorted_ids=True)\n    >>> m.update_state([2, 1], [[1, 0, 3], [1, 2, 3]])\n    >>> m.result()\n    0.5\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='sparse_categorical_crossentropy',\n                  metrics=[keras.metrics.SparseTopKCategoricalAccuracy()])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        k=5,\n        name=\"sparse_top_k_categorical_accuracy\",\n        dtype=None,\n        from_sorted_ids=False,\n    ):\n        super().__init__(\n            fn=sparse_top_k_categorical_accuracy,\n            name=name,\n            dtype=dtype,\n            k=k,\n            from_sorted_ids=from_sorted_ids,\n        )\n        self.k = k\n        self.from_sorted_ids = from_sorted_ids\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n    def get_config(self):\n        config = {\"name\": self.name, \"dtype\": self.dtype, \"k\": self.k}\n        if self.from_sorted_ids:\n            config[\"from_sorted_ids\"] = True\n        return config\n"
  },
  {
    "path": "keras/src/metrics/accuracy_metrics_test.py",
    "content": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.metrics import accuracy_metrics\n\n\nclass AccuracyTest(testing.TestCase):\n    def test_config(self):\n        acc_obj = accuracy_metrics.Accuracy(name=\"accuracy\", dtype=\"float32\")\n        self.assertEqual(acc_obj.name, \"accuracy\")\n        self.assertEqual(len(acc_obj.variables), 2)\n        self.assertEqual(acc_obj._dtype, \"float32\")\n\n        # Test get_config\n        acc_obj_config = acc_obj.get_config()\n        self.assertEqual(acc_obj_config[\"name\"], \"accuracy\")\n        self.assertEqual(acc_obj_config[\"dtype\"], \"float32\")\n\n        # Check save and restore config\n        acc_obj2 = accuracy_metrics.Accuracy.from_config(acc_obj_config)\n        self.assertEqual(acc_obj2.name, \"accuracy\")\n        self.assertEqual(len(acc_obj2.variables), 2)\n        self.assertEqual(acc_obj2._dtype, \"float32\")\n\n    def test_unweighted(self):\n        acc_obj = accuracy_metrics.Accuracy(name=\"accuracy\", dtype=\"float32\")\n        y_true = np.array([[1], [2], [3], [4]])\n        y_pred = np.array([[0], [2], [3], [4]])\n        acc_obj.update_state(y_true, y_pred)\n        result = acc_obj.result()\n        self.assertAllClose(result, 0.75, atol=1e-3)\n\n    def test_weighted(self):\n        acc_obj = accuracy_metrics.Accuracy(name=\"accuracy\", dtype=\"float32\")\n        y_true = np.array([[1], [2], [3], [4]])\n        y_pred = np.array([[0], [2], [3], [4]])\n        sample_weight = np.array([1, 1, 0, 0])\n        acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)\n        result = acc_obj.result()\n        self.assertAllClose(result, 0.5, atol=1e-3)\n\n    def test_weighted_rank_1(self):\n        acc_obj = accuracy_metrics.Accuracy(name=\"accuracy\", dtype=\"float32\")\n        y_true = np.array([1, 2, 3, 4])\n        y_pred = np.array([0, 2, 3, 4])\n        sample_weight = np.array([1, 1, 0, 0])\n        acc_obj.update_state(y_true, y_pred, 
sample_weight=sample_weight)\n        result = acc_obj.result()\n        self.assertAllClose(result, 0.5, atol=1e-3)\n\n    def test_weighted_nd_weights(self):\n        acc_obj = accuracy_metrics.Accuracy(name=\"accuracy\", dtype=\"float32\")\n        y_true = np.array([[1, 2], [3, 4]])\n        y_pred = np.array([[0, 2], [3, 4]])\n        sample_weight = np.array([[1, 0], [0, 1]])\n        acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)\n        result = acc_obj.result()\n        self.assertAllClose(result, 0.5, atol=1e-3)\n\n    def test_weighted_nd_broadcast_weights(self):\n        acc_obj = accuracy_metrics.Accuracy(name=\"accuracy\", dtype=\"float32\")\n        y_true = np.array([[1, 2], [3, 4]])\n        y_pred = np.array([[0, 2], [3, 4]])\n        sample_weight = np.array([[1, 0]])\n        acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)\n        result = acc_obj.result()\n        self.assertAllClose(result, 0.5, atol=1e-3)\n\n\nclass BinaryAccuracyTest(testing.TestCase):\n    def test_config(self):\n        bin_acc_obj = accuracy_metrics.BinaryAccuracy(\n            name=\"binary_accuracy\", dtype=\"float32\"\n        )\n        self.assertEqual(bin_acc_obj.name, \"binary_accuracy\")\n        self.assertEqual(len(bin_acc_obj.variables), 2)\n        self.assertEqual(bin_acc_obj._dtype, \"float32\")\n\n        # Test get_config\n        bin_acc_obj_config = bin_acc_obj.get_config()\n        self.assertEqual(bin_acc_obj_config[\"name\"], \"binary_accuracy\")\n        self.assertEqual(bin_acc_obj_config[\"dtype\"], \"float32\")\n\n        # Check save and restore config\n        bin_acc_obj2 = accuracy_metrics.BinaryAccuracy.from_config(\n            bin_acc_obj_config\n        )\n        self.assertEqual(bin_acc_obj2.name, \"binary_accuracy\")\n        self.assertEqual(len(bin_acc_obj2.variables), 2)\n        self.assertEqual(bin_acc_obj2._dtype, \"float32\")\n\n    def test_unweighted(self):\n        bin_acc_obj = 
accuracy_metrics.BinaryAccuracy(\n            name=\"binary_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([[1], [1], [0], [0]])\n        y_pred = np.array([[0.98], [1], [0], [0.6]])\n        bin_acc_obj.update_state(y_true, y_pred)\n        result = bin_acc_obj.result()\n        self.assertAllClose(result, 0.75, atol=1e-3)\n\n        # Test broadcasting case\n        bin_acc_obj = accuracy_metrics.BinaryAccuracy(\n            name=\"binary_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([1, 1, 0, 0])\n        y_pred = np.array([[0.98], [1], [0], [0.6]])\n        bin_acc_obj.update_state(y_true, y_pred)\n        result = bin_acc_obj.result()\n        self.assertAllClose(result, 0.75, atol=1e-3)\n\n    def test_weighted(self):\n        bin_acc_obj = accuracy_metrics.BinaryAccuracy(\n            name=\"binary_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([[1], [1], [0], [0]])\n        y_pred = np.array([[0.98], [1], [0], [0.6]])\n        sample_weight = np.array([1, 0, 0, 1])\n        bin_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)\n        result = bin_acc_obj.result()\n        self.assertAllClose(result, 0.5, atol=1e-3)\n\n    def test_weighted_rank_1(self):\n        bin_acc_obj = accuracy_metrics.BinaryAccuracy(\n            name=\"binary_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([1, 1, 0, 0])\n        y_pred = np.array([0.98, 1, 0, 0.6])\n        sample_weight = np.array([1, 0, 0, 1])\n        bin_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)\n        result = bin_acc_obj.result()\n        self.assertAllClose(result, 0.5, atol=1e-3)\n\n    def test_weighted_nd_weights(self):\n        bin_acc_obj = accuracy_metrics.BinaryAccuracy(\n            name=\"binary_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([[1, 1], [0, 0]])\n        y_pred = np.array([[0.98, 1], [0, 0.6]])\n        sample_weight = np.array([[1, 0], 
[0, 1]])\n        bin_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)\n        result = bin_acc_obj.result()\n        self.assertAllClose(result, 0.5, atol=1e-3)\n\n    def test_weighted_nd_broadcast_weights(self):\n        bin_acc_obj = accuracy_metrics.BinaryAccuracy(\n            name=\"binary_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([[1, 1], [0, 0]])\n        y_pred = np.array([[0.98, 1], [0, 0.6]])\n        sample_weight = np.array([[1, 0]])\n        bin_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)\n        result = bin_acc_obj.result()\n        self.assertAllClose(result, 1.0, atol=1e-3)\n\n    def test_threshold(self):\n        bin_acc_obj_1 = accuracy_metrics.BinaryAccuracy(\n            name=\"binary_accuracy\", dtype=\"float32\", threshold=0.3\n        )\n        bin_acc_obj_2 = accuracy_metrics.BinaryAccuracy(\n            name=\"binary_accuracy\", dtype=\"float32\", threshold=0.9\n        )\n        y_true = np.array([[1], [1], [0], [0]])\n        y_pred = np.array([[0.98], [0.5], [0.1], [0.2]])\n\n        bin_acc_obj_1.update_state(y_true, y_pred)\n        bin_acc_obj_2.update_state(y_true, y_pred)\n        result_1 = bin_acc_obj_1.result()\n        result_2 = bin_acc_obj_2.result()\n\n        # Higher threshold must result in lower measured accuracy.\n        self.assertAllClose(result_1, 1.0)\n        self.assertAllClose(result_2, 0.75)\n\n\nclass CategoricalAccuracyTest(testing.TestCase):\n    def test_config(self):\n        cat_acc_obj = accuracy_metrics.CategoricalAccuracy(\n            name=\"categorical_accuracy\", dtype=\"float32\"\n        )\n        self.assertEqual(cat_acc_obj.name, \"categorical_accuracy\")\n        self.assertEqual(len(cat_acc_obj.variables), 2)\n        self.assertEqual(cat_acc_obj._dtype, \"float32\")\n\n        # Test get_config\n        cat_acc_obj_config = cat_acc_obj.get_config()\n        self.assertEqual(cat_acc_obj_config[\"name\"], 
\"categorical_accuracy\")\n        self.assertEqual(cat_acc_obj_config[\"dtype\"], \"float32\")\n\n        # Check save and restore config\n        cat_acc_obj2 = accuracy_metrics.CategoricalAccuracy.from_config(\n            cat_acc_obj_config\n        )\n        self.assertEqual(cat_acc_obj2.name, \"categorical_accuracy\")\n        self.assertEqual(len(cat_acc_obj2.variables), 2)\n        self.assertEqual(cat_acc_obj2._dtype, \"float32\")\n\n    def test_unweighted(self):\n        cat_acc_obj = accuracy_metrics.CategoricalAccuracy(\n            name=\"categorical_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([[0, 0, 1], [0, 1, 0]])\n        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])\n        cat_acc_obj.update_state(y_true, y_pred)\n        result = cat_acc_obj.result()\n        self.assertAllClose(result, 0.5, atol=1e-3)\n\n    def test_weighted(self):\n        cat_acc_obj = accuracy_metrics.CategoricalAccuracy(\n            name=\"categorical_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([[0, 0, 1], [0, 1, 0]])\n        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]])\n        sample_weight = np.array([0.7, 0.3])\n        cat_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)\n        result = cat_acc_obj.result()\n        self.assertAllClose(result, 0.3, atol=1e-3)\n\n\nclass SparseCategoricalAccuracyTest(testing.TestCase):\n    def test_config(self):\n        sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(\n            name=\"sparse_categorical_accuracy\", dtype=\"float32\"\n        )\n        self.assertEqual(sp_cat_acc_obj.name, \"sparse_categorical_accuracy\")\n        self.assertEqual(len(sp_cat_acc_obj.variables), 2)\n        self.assertEqual(sp_cat_acc_obj._dtype, \"float32\")\n\n        # Test get_config\n        sp_cat_acc_obj_config = sp_cat_acc_obj.get_config()\n        self.assertEqual(\n            sp_cat_acc_obj_config[\"name\"], 
\"sparse_categorical_accuracy\"\n        )\n        self.assertEqual(sp_cat_acc_obj_config[\"dtype\"], \"float32\")\n\n        # Check save and restore config\n        sp_cat_acc_obj2 = (\n            accuracy_metrics.SparseCategoricalAccuracy.from_config(\n                sp_cat_acc_obj_config\n            )\n        )\n        self.assertEqual(sp_cat_acc_obj2.name, \"sparse_categorical_accuracy\")\n        self.assertEqual(len(sp_cat_acc_obj2.variables), 2)\n        self.assertEqual(sp_cat_acc_obj2._dtype, \"float32\")\n\n    def test_unweighted(self):\n        sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(\n            name=\"sparse_categorical_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([[2], [1]])\n        y_pred = np.array([[0.1, 0.6, 0.3], [0.05, 0.95, 0]])\n        sp_cat_acc_obj.update_state(y_true, y_pred)\n        result = sp_cat_acc_obj.result()\n        self.assertAllClose(result, 0.5, atol=1e-3)\n\n    def test_weighted(self):\n        sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(\n            name=\"sparse_categorical_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([[2], [1]])\n        y_pred = np.array([[0.1, 0.6, 0.3], [0.05, 0.95, 0]])\n        sample_weight = np.array([0.7, 0.3])\n        sp_cat_acc_obj.update_state(y_true, y_pred, sample_weight=sample_weight)\n        result = sp_cat_acc_obj.result()\n        self.assertAllClose(result, 0.3, atol=1e-3)\n\n    def test_squeeze_y_true(self):\n        sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(\n            name=\"sparse_categorical_accuracy\", dtype=\"float32\"\n        )\n        # Scenario with 100% accuracy for simplicity.\n        # y_true is a 2D tensor with shape (3, 1) to test squeeze.\n        y_true = np.array([[0], [1], [2]])\n        y_pred = np.array(\n            [[0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9]]\n        )\n        sp_cat_acc_obj.update_state(y_true, y_pred)\n        
result = sp_cat_acc_obj.result()\n        self.assertAllClose(result, 1.0, atol=1e-4)\n\n    def test_cast_y_pred_dtype(self):\n        sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(\n            name=\"sparse_categorical_accuracy\", dtype=\"float32\"\n        )\n        # Scenario with 100% accuracy for simplicity.\n        # y_true is a 1D tensor with shape (2,) to test cast.\n        y_true = np.array([0, 1], dtype=np.int64)\n        y_pred = np.array([[0.9, 0.1], [0.1, 0.9]], dtype=np.float32)\n        sp_cat_acc_obj.update_state(y_true, y_pred)\n        result = sp_cat_acc_obj.result()\n        self.assertAllClose(result, 1.0, atol=1e-4)\n\n    def test_reshape_matches(self):\n        sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(\n            name=\"sparse_categorical_accuracy\", dtype=\"float32\"\n        )\n        # Scenario with 100% accuracy for simplicity.\n        # y_true is a 2D tensor with shape (2, 1) to test reshape.\n        y_true = np.array([[0], [0]], dtype=np.int64)\n        y_pred = np.array(\n            [[[0.9, 0.1, 0.0], [0.8, 0.15, 0.05]]], dtype=np.float32\n        )\n        sp_cat_acc_obj.update_state(y_true, y_pred)\n        result = sp_cat_acc_obj.result()\n        self.assertAllClose(result, np.array([1.0, 1.0]))\n\n    def test_squeeze_y_true_shape(self):\n        sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(\n            name=\"sparse_categorical_accuracy\", dtype=\"float32\"\n        )\n        # True labels are in the shape (num_samples, 1) should be squeezed.\n        y_true = np.array([[0], [1], [2]])\n        y_pred = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])\n        sp_cat_acc_obj.update_state(y_true, y_pred)\n        result = sp_cat_acc_obj.result()\n        self.assertAllClose(result, 1.0, atol=1e-4)\n\n    def test_cast_y_pred_to_match_y_true_dtype(self):\n        sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(\n            
name=\"sparse_categorical_accuracy\", dtype=\"float32\"\n        )\n        # True labels are integers, while predictions are floats.\n        y_true = np.array([0, 1, 2], dtype=np.int32)\n        y_pred = np.array(\n            [[0.9, 0.1, 0.0], [0.0, 0.9, 0.1], [0.1, 0.0, 0.9]],\n            dtype=np.float64,\n        )\n        sp_cat_acc_obj.update_state(y_true, y_pred)\n        result = sp_cat_acc_obj.result()\n        self.assertAllClose(result, 1.0, atol=1e-4)\n\n    def test_reshape_matches_to_original_y_true_shape(self):\n        sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(\n            name=\"sparse_categorical_accuracy\", dtype=\"float32\"\n        )\n        # True labels have an additional dimension that needs to be squeezed.\n        y_true = np.array([[0], [1]])\n        # Predictions must trigger a reshape of matches.\n        y_pred = np.array([[0.9, 0.1], [0.1, 0.9]])\n        sp_cat_acc_obj.update_state(y_true, y_pred)\n        result = sp_cat_acc_obj.result()\n        self.assertAllClose(result, 1.0, atol=1e-4)\n\n    def test_matching_shapes_without_squeeze(self):\n        sp_cat_acc_obj = accuracy_metrics.SparseCategoricalAccuracy(\n            name=\"sparse_categorical_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([2, 1, 0], dtype=np.int32)\n        y_pred = np.array(\n            [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]],\n            dtype=np.float32,\n        )\n        # No need to squeeze or reshape.\n        sp_cat_acc_obj.update_state(y_true, y_pred)\n        result = sp_cat_acc_obj.result()\n        self.assertAllClose(result, 1.0, atol=1e-4)\n\n\nclass TopKCategoricalAccuracyTest(testing.TestCase):\n    def test_config(self):\n        top_k_cat_acc_obj = accuracy_metrics.TopKCategoricalAccuracy(\n            k=1, name=\"top_k_categorical_accuracy\", dtype=\"float32\"\n        )\n        self.assertEqual(top_k_cat_acc_obj.name, \"top_k_categorical_accuracy\")\n        
self.assertEqual(len(top_k_cat_acc_obj.variables), 2)\n        self.assertEqual(top_k_cat_acc_obj._dtype, \"float32\")\n\n        # Test get_config\n        top_k_cat_acc_obj_config = top_k_cat_acc_obj.get_config()\n        self.assertEqual(\n            top_k_cat_acc_obj_config[\"name\"], \"top_k_categorical_accuracy\"\n        )\n        self.assertEqual(top_k_cat_acc_obj_config[\"dtype\"], \"float32\")\n        self.assertEqual(top_k_cat_acc_obj_config[\"k\"], 1)\n\n        # Check save and restore config\n        top_k_cat_acc_obj2 = (\n            accuracy_metrics.TopKCategoricalAccuracy.from_config(\n                top_k_cat_acc_obj_config\n            )\n        )\n        self.assertEqual(top_k_cat_acc_obj2.name, \"top_k_categorical_accuracy\")\n        self.assertEqual(len(top_k_cat_acc_obj2.variables), 2)\n        self.assertEqual(top_k_cat_acc_obj2._dtype, \"float32\")\n        self.assertEqual(top_k_cat_acc_obj2.k, 1)\n\n    def test_unweighted(self):\n        top_k_cat_acc_obj = accuracy_metrics.TopKCategoricalAccuracy(\n            k=1, name=\"top_k_categorical_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([[0, 0, 1], [0, 1, 0]])\n        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]], dtype=\"float32\")\n        top_k_cat_acc_obj.update_state(y_true, y_pred)\n        result = top_k_cat_acc_obj.result()\n        self.assertAllClose(result, 0.5, atol=1e-3)\n\n    def test_weighted(self):\n        top_k_cat_acc_obj = accuracy_metrics.TopKCategoricalAccuracy(\n            k=1, name=\"top_k_categorical_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([[0, 0, 1], [0, 1, 0]])\n        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]], dtype=\"float32\")\n        sample_weight = np.array([0.7, 0.3])\n        top_k_cat_acc_obj.update_state(\n            y_true, y_pred, sample_weight=sample_weight\n        )\n        result = top_k_cat_acc_obj.result()\n        self.assertAllClose(result, 0.3, 
atol=1e-3)\n\n\nclass SparseTopKCategoricalAccuracyTest(testing.TestCase):\n    def test_config(self):\n        sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(\n            k=1, name=\"sparse_top_k_categorical_accuracy\", dtype=\"float32\"\n        )\n        self.assertEqual(\n            sp_top_k_cat_acc_obj.name, \"sparse_top_k_categorical_accuracy\"\n        )\n        self.assertEqual(len(sp_top_k_cat_acc_obj.variables), 2)\n        self.assertEqual(sp_top_k_cat_acc_obj._dtype, \"float32\")\n\n        # Test get_config\n        sp_top_k_cat_acc_obj_config = sp_top_k_cat_acc_obj.get_config()\n        self.assertEqual(\n            sp_top_k_cat_acc_obj_config[\"name\"],\n            \"sparse_top_k_categorical_accuracy\",\n        )\n        self.assertEqual(sp_top_k_cat_acc_obj_config[\"dtype\"], \"float32\")\n        self.assertEqual(sp_top_k_cat_acc_obj_config[\"k\"], 1)\n\n        # Check save and restore config\n        sp_top_k_cat_acc_obj2 = (\n            accuracy_metrics.SparseTopKCategoricalAccuracy.from_config(\n                sp_top_k_cat_acc_obj_config\n            )\n        )\n        self.assertEqual(\n            sp_top_k_cat_acc_obj2.name, \"sparse_top_k_categorical_accuracy\"\n        )\n        self.assertEqual(len(sp_top_k_cat_acc_obj2.variables), 2)\n        self.assertEqual(sp_top_k_cat_acc_obj2._dtype, \"float32\")\n        self.assertEqual(sp_top_k_cat_acc_obj2.k, 1)\n        self.assertFalse(sp_top_k_cat_acc_obj2.from_sorted_ids)\n\n    def test_config_from_sorted_ids(self):\n        sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(\n            k=1,\n            name=\"sparse_top_k_categorical_accuracy\",\n            dtype=\"float32\",\n            from_sorted_ids=True,\n        )\n\n        # Test get_config\n        sp_top_k_cat_acc_obj_config = sp_top_k_cat_acc_obj.get_config()\n        self.assertTrue(sp_top_k_cat_acc_obj_config[\"from_sorted_ids\"])\n\n        # Check save and restore 
config\n        sp_top_k_cat_acc_obj2 = (\n            accuracy_metrics.SparseTopKCategoricalAccuracy.from_config(\n                sp_top_k_cat_acc_obj_config\n            )\n        )\n        self.assertTrue(sp_top_k_cat_acc_obj2.from_sorted_ids)\n\n    def test_unweighted(self):\n        sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(\n            k=1, name=\"sparse_top_k_categorical_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([2, 1])\n        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]], dtype=\"float32\")\n        sp_top_k_cat_acc_obj.update_state(y_true, y_pred)\n        result = sp_top_k_cat_acc_obj.result()\n        self.assertAllClose(result, 0.5, atol=1e-3)\n\n    def test_weighted(self):\n        sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(\n            k=1, name=\"sparse_top_k_categorical_accuracy\", dtype=\"float32\"\n        )\n        y_true = np.array([2, 1])\n        y_pred = np.array([[0.1, 0.9, 0.8], [0.05, 0.95, 0]], dtype=\"float32\")\n        sample_weight = np.array([0.7, 0.3])\n        sp_top_k_cat_acc_obj.update_state(\n            y_true, y_pred, sample_weight=sample_weight\n        )\n        result = sp_top_k_cat_acc_obj.result()\n        self.assertAllClose(result, 0.3, atol=1e-3)\n\n    def test_from_sorted_ids_unweighted(self):\n        sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(\n            k=1,\n            name=\"sparse_top_k_categorical_accuracy\",\n            dtype=\"float32\",\n            from_sorted_ids=True,\n        )\n        y_true = np.array([2, 1])\n        y_pred = np.array([[1, 0, 3], [1, 2, 3]])\n        sp_top_k_cat_acc_obj.update_state(y_true, y_pred)\n        result = sp_top_k_cat_acc_obj.result()\n        self.assertAllClose(result, 0.5, atol=1e-3)\n\n    def test_from_sorted_ids_weighted(self):\n        sp_top_k_cat_acc_obj = accuracy_metrics.SparseTopKCategoricalAccuracy(\n            k=1,\n            
name=\"sparse_top_k_categorical_accuracy\",\n            dtype=\"float32\",\n            from_sorted_ids=True,\n        )\n        y_true = np.array([2, 1])\n        y_pred = np.array([[1, 0, 3], [1, 2, 3]])\n        sample_weight = np.array([0.7, 0.3])\n        sp_top_k_cat_acc_obj.update_state(\n            y_true, y_pred, sample_weight=sample_weight\n        )\n        result = sp_top_k_cat_acc_obj.result()\n        self.assertAllClose(result, 0.3, atol=1e-3)\n"
  },
  {
    "path": "keras/src/metrics/confusion_metrics.py",
    "content": "import numpy as np\n\nfrom keras.src import activations\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.metrics import metrics_utils\nfrom keras.src.metrics.metric import Metric\nfrom keras.src.utils.python_utils import to_list\n\n\nclass _ConfusionMatrixConditionCount(Metric):\n    \"\"\"Calculates the number of the given confusion matrix condition.\n\n    Args:\n        confusion_matrix_cond: One of `metrics_utils.ConfusionMatrix`\n            conditions.\n        thresholds: (Optional) Defaults to `0.5`. A float value or a python list\n            / tuple of float threshold values in `[0, 1]`. A threshold is\n            compared with prediction values to determine the truth value of\n            predictions (i.e., above the threshold is `True`, below is `False`).\n            One metric value is generated for each threshold value.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n    \"\"\"\n\n    def __init__(\n        self, confusion_matrix_cond, thresholds=None, name=None, dtype=None\n    ):\n        super().__init__(name=name, dtype=dtype)\n        self._confusion_matrix_cond = confusion_matrix_cond\n        self.init_thresholds = thresholds\n        self.thresholds = metrics_utils.parse_init_thresholds(\n            thresholds, default_threshold=0.5\n        )\n        self._thresholds_distributed_evenly = (\n            metrics_utils.is_evenly_distributed_thresholds(self.thresholds)\n        )\n        self.accumulator = self.add_variable(\n            shape=(len(self.thresholds),),\n            initializer=initializers.Zeros(),\n            name=\"accumulator\",\n        )\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        \"\"\"Accumulates the metric statistics.\n\n        Args:\n            y_true: The ground truth values.\n           
 y_pred: The predicted values.\n            sample_weight: Optional weighting of each example. Defaults to `1`.\n                Can be a tensor whose rank is either 0, or the same rank as\n                `y_true`, and must be broadcastable to `y_true`.\n        \"\"\"\n        return metrics_utils.update_confusion_matrix_variables(\n            {self._confusion_matrix_cond: self.accumulator},\n            y_true,\n            y_pred,\n            thresholds=self.thresholds,\n            thresholds_distributed_evenly=self._thresholds_distributed_evenly,\n            sample_weight=sample_weight,\n        )\n\n    def result(self):\n        if len(self.thresholds) == 1:\n            result = self.accumulator[0]\n        else:\n            result = self.accumulator\n        return backend.convert_to_tensor(result)\n\n    def get_config(self):\n        config = {\"thresholds\": self.init_thresholds}\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\n@keras_export(\"keras.metrics.FalsePositives\")\nclass FalsePositives(_ConfusionMatrixConditionCount):\n    \"\"\"Calculates the number of false positives.\n\n    If `sample_weight` is given, calculates the sum of the weights of\n    false positives. This metric creates one local variable, `accumulator`\n    that is used to keep track of the number of false positives.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    Args:\n        thresholds: (Optional) Defaults to `0.5`. A float value, or a Python\n            list/tuple of float threshold values in `[0, 1]`. A threshold is\n            compared with prediction values to determine the truth value of\n            predictions (i.e., above the threshold is `True`, below is `False`).\n            If used with a loss function that sets `from_logits=True` (i.e. 
no\n            sigmoid applied to predictions), `thresholds` should be set to 0.\n            One metric value is generated for each threshold value.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Examples:\n\n    >>> m = keras.metrics.FalsePositives()\n    >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1])\n    >>> m.result()\n    2.0\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 1, 0, 0], [0, 0, 1, 1], sample_weight=[0, 0, 1, 0])\n    >>> m.result()\n    1.0\n    \"\"\"\n\n    def __init__(self, thresholds=None, name=None, dtype=None):\n        super().__init__(\n            confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_POSITIVES,\n            thresholds=thresholds,\n            name=name,\n            dtype=dtype,\n        )\n\n\n@keras_export(\"keras.metrics.FalseNegatives\")\nclass FalseNegatives(_ConfusionMatrixConditionCount):\n    \"\"\"Calculates the number of false negatives.\n\n    If `sample_weight` is given, calculates the sum of the weights of\n    false negatives. This metric creates one local variable, `accumulator`\n    that is used to keep track of the number of false negatives.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    Args:\n        thresholds: (Optional) Defaults to `0.5`. A float value, or a Python\n            list/tuple of float threshold values in `[0, 1]`. A threshold is\n            compared with prediction values to determine the truth value of\n            predictions (i.e., above the threshold is `True`, below is `False`).\n            If used with a loss function that sets `from_logits=True` (i.e. 
no\n            sigmoid applied to predictions), `thresholds` should be set to 0.\n            One metric value is generated for each threshold value.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.FalseNegatives()\n    >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0])\n    >>> m.result()\n    2.0\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 1, 1, 1], [0, 1, 0, 0], sample_weight=[0, 0, 1, 0])\n    >>> m.result()\n    1.0\n    \"\"\"\n\n    def __init__(self, thresholds=None, name=None, dtype=None):\n        super().__init__(\n            confusion_matrix_cond=metrics_utils.ConfusionMatrix.FALSE_NEGATIVES,\n            thresholds=thresholds,\n            name=name,\n            dtype=dtype,\n        )\n\n\n@keras_export(\"keras.metrics.TrueNegatives\")\nclass TrueNegatives(_ConfusionMatrixConditionCount):\n    \"\"\"Calculates the number of true negatives.\n\n    If `sample_weight` is given, calculates the sum of the weights of\n    true negatives. This metric creates one local variable, `accumulator`\n    that is used to keep track of the number of true negatives.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    Args:\n        thresholds: (Optional) Defaults to `0.5`. A float value, or a Python\n            list/tuple of float threshold values in `[0, 1]`. A threshold is\n            compared with prediction values to determine the truth value of\n            predictions (i.e., above the threshold is `True`, below is `False`).\n            If used with a loss function that sets `from_logits=True` (i.e. 
no\n            sigmoid applied to predictions), `thresholds` should be set to 0.\n            One metric value is generated for each threshold value.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.TrueNegatives()\n    >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0])\n    >>> m.result()\n    2.0\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 1, 0, 0], [1, 1, 0, 0], sample_weight=[0, 0, 1, 0])\n    >>> m.result()\n    1.0\n    \"\"\"\n\n    def __init__(self, thresholds=None, name=None, dtype=None):\n        super().__init__(\n            confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_NEGATIVES,\n            thresholds=thresholds,\n            name=name,\n            dtype=dtype,\n        )\n\n\n@keras_export(\"keras.metrics.TruePositives\")\nclass TruePositives(_ConfusionMatrixConditionCount):\n    \"\"\"Calculates the number of true positives.\n\n    If `sample_weight` is given, calculates the sum of the weights of\n    true positives. This metric creates one local variable, `true_positives`\n    that is used to keep track of the number of true positives.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    Args:\n        thresholds: (Optional) Defaults to `0.5`. A float value, or a Python\n            list/tuple of float threshold values in `[0, 1]`. A threshold is\n            compared with prediction values to determine the truth value of\n            predictions (i.e., above the threshold is `True`, below is `False`).\n            If used with a loss function that sets `from_logits=True` (i.e. 
no\n            sigmoid applied to predictions), `thresholds` should be set to 0.\n            One metric value is generated for each threshold value.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.TruePositives()\n    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])\n    >>> m.result()\n    2.0\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])\n    >>> m.result()\n    1.0\n    \"\"\"\n\n    def __init__(self, thresholds=None, name=None, dtype=None):\n        super().__init__(\n            confusion_matrix_cond=metrics_utils.ConfusionMatrix.TRUE_POSITIVES,\n            thresholds=thresholds,\n            name=name,\n            dtype=dtype,\n        )\n\n\n@keras_export(\"keras.metrics.Precision\")\nclass Precision(Metric):\n    \"\"\"Computes the precision of the predictions with respect to the labels.\n\n    The metric creates two local variables, `true_positives` and\n    `false_positives` that are used to compute the precision. 
This value is\n    ultimately returned as `precision`, an idempotent operation that simply\n    divides `true_positives` by the sum of `true_positives` and\n    `false_positives`.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    If `top_k` is set, we'll calculate precision as how often on average a class\n    among the top-k classes with the highest predicted values of a batch entry\n    is correct and can be found in the label for that entry.\n\n    If `class_id` is specified, we calculate precision by considering only the\n    entries in the batch for which `class_id` is above the threshold and/or in\n    the top-k highest predictions, and computing the fraction of them for which\n    `class_id` is indeed a correct label.\n\n    Args:\n        thresholds: (Optional) A float value, or a Python list/tuple of float\n            threshold values in `[0, 1]`. A threshold is compared with\n            prediction values to determine the truth value of predictions (i.e.,\n            above the threshold is `True`, below is `False`). If used with a\n            loss function that sets `from_logits=True` (i.e. no sigmoid applied\n            to predictions), `thresholds` should be set to 0. One metric value\n            is generated for each threshold value. If neither `thresholds` nor\n            `top_k` are set, the default is to calculate precision with\n            `thresholds=0.5`.\n        top_k: (Optional) Unset by default. 
An int value specifying the top-k\n            predictions to consider when calculating precision.\n        class_id: (Optional) Integer class ID for which we want binary metrics.\n            This must be in the half-open interval `[0, num_classes)`, where\n            `num_classes` is the last dimension of predictions.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.Precision()\n    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])\n    >>> m.result()\n    0.6666667\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])\n    >>> m.result()\n    1.0\n\n    >>> # With top_k=2, it will calculate precision over y_true[:2]\n    >>> # and y_pred[:2]\n    >>> m = keras.metrics.Precision(top_k=2)\n    >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])\n    >>> m.result()\n    0.0\n\n    >>> # With top_k=4, it will calculate precision over y_true[:4]\n    >>> # and y_pred[:4]\n    >>> m = keras.metrics.Precision(top_k=4)\n    >>> m.update_state([0, 0, 1, 1], [1, 1, 1, 1])\n    >>> m.result()\n    0.5\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='binary_crossentropy',\n                  metrics=[keras.metrics.Precision()])\n    ```\n\n    Usage with a loss with `from_logits=True`:\n\n    ```python\n    model.compile(optimizer='adam',\n                  loss=keras.losses.BinaryCrossentropy(from_logits=True),\n                  metrics=[keras.metrics.Precision(thresholds=0)])\n    ```\n    \"\"\"\n\n    def __init__(\n        self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None\n    ):\n        super().__init__(name=name, dtype=dtype)\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n        self.init_thresholds = thresholds\n        self.top_k = top_k\n        self.class_id = 
class_id\n\n        default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF\n        self.thresholds = metrics_utils.parse_init_thresholds(\n            thresholds, default_threshold=default_threshold\n        )\n        self._thresholds_distributed_evenly = (\n            metrics_utils.is_evenly_distributed_thresholds(self.thresholds)\n        )\n        self.true_positives = self.add_variable(\n            shape=(len(self.thresholds),),\n            initializer=initializers.Zeros(),\n            name=\"true_positives\",\n        )\n        self.false_positives = self.add_variable(\n            shape=(len(self.thresholds),),\n            initializer=initializers.Zeros(),\n            name=\"false_positives\",\n        )\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        \"\"\"Accumulates true positive and false positive statistics.\n\n        Args:\n            y_true: The ground truth values, with the same dimensions as\n                `y_pred`. Will be cast to `bool`.\n            y_pred: The predicted values. Each element must be in the range\n                `[0, 1]`.\n            sample_weight: Optional weighting of each example. 
Defaults to `1`.\n                Can be a tensor whose rank is either 0, or the same rank as\n                `y_true`, and must be broadcastable to `y_true`.\n        \"\"\"\n        metrics_utils.update_confusion_matrix_variables(\n            {\n                metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,  # noqa: E501\n                metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,  # noqa: E501\n            },\n            y_true,\n            y_pred,\n            thresholds=self.thresholds,\n            thresholds_distributed_evenly=self._thresholds_distributed_evenly,\n            top_k=self.top_k,\n            class_id=self.class_id,\n            sample_weight=sample_weight,\n        )\n\n    def result(self):\n        result = ops.divide_no_nan(\n            self.true_positives,\n            ops.add(self.true_positives, self.false_positives),\n        )\n        return result[0] if len(self.thresholds) == 1 else result\n\n    def reset_state(self):\n        num_thresholds = len(to_list(self.thresholds))\n        self.true_positives.assign(ops.zeros((num_thresholds,)))\n        self.false_positives.assign(ops.zeros((num_thresholds,)))\n\n    def get_config(self):\n        config = {\n            \"thresholds\": self.init_thresholds,\n            \"top_k\": self.top_k,\n            \"class_id\": self.class_id,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\n@keras_export(\"keras.metrics.Recall\")\nclass Recall(Metric):\n    \"\"\"Computes the recall of the predictions with respect to the labels.\n\n    This metric creates two local variables, `true_positives` and\n    `false_negatives`, that are used to compute the recall. 
This value is\n    ultimately returned as `recall`, an idempotent operation that simply divides\n    `true_positives` by the sum of `true_positives` and `false_negatives`.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    If `top_k` is set, recall will be computed as how often on average a class\n    among the labels of a batch entry is in the top-k predictions.\n\n    If `class_id` is specified, we calculate recall by considering only the\n    entries in the batch for which `class_id` is in the label, and computing the\n    fraction of them for which `class_id` is above the threshold and/or in the\n    top-k predictions.\n\n    Args:\n        thresholds: (Optional) A float value, or a Python list/tuple of float\n            threshold values in `[0, 1]`. A threshold is compared with\n            prediction values to determine the truth value of predictions (i.e.,\n            above the threshold is `True`, below is `False`). If used with a\n            loss function that sets `from_logits=True` (i.e. no sigmoid\n            applied to predictions), `thresholds` should be set to 0.\n            One metric value is generated for each threshold value.\n            If neither `thresholds` nor `top_k` are set,\n            the default is to calculate recall with `thresholds=0.5`.\n        top_k: (Optional) Unset by default. 
An int value specifying the top-k\n            predictions to consider when calculating recall.\n        class_id: (Optional) Integer class ID for which we want binary metrics.\n            This must be in the half-open interval `[0, num_classes)`, where\n            `num_classes` is the last dimension of predictions.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.Recall()\n    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1])\n    >>> m.result()\n    0.6666667\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 1, 1, 1], [1, 0, 1, 1], sample_weight=[0, 0, 1, 0])\n    >>> m.result()\n    1.0\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='binary_crossentropy',\n                  metrics=[keras.metrics.Recall()])\n    ```\n\n    Usage with a loss with `from_logits=True`:\n\n    ```python\n    model.compile(optimizer='adam',\n                  loss=keras.losses.BinaryCrossentropy(from_logits=True),\n                  metrics=[keras.metrics.Recall(thresholds=0)])\n    ```\n    \"\"\"\n\n    def __init__(\n        self, thresholds=None, top_k=None, class_id=None, name=None, dtype=None\n    ):\n        super().__init__(name=name, dtype=dtype)\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n        self.init_thresholds = thresholds\n        self.top_k = top_k\n        self.class_id = class_id\n\n        default_threshold = 0.5 if top_k is None else metrics_utils.NEG_INF\n        self.thresholds = metrics_utils.parse_init_thresholds(\n            thresholds, default_threshold=default_threshold\n        )\n        self._thresholds_distributed_evenly = (\n            metrics_utils.is_evenly_distributed_thresholds(self.thresholds)\n        )\n        self.true_positives = self.add_variable(\n            shape=(len(self.thresholds),),\n            
initializer=initializers.Zeros(),\n            name=\"true_positives\",\n        )\n        self.false_negatives = self.add_variable(\n            shape=(len(self.thresholds),),\n            initializer=initializers.Zeros(),\n            name=\"false_negatives\",\n        )\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        \"\"\"Accumulates true positive and false negative statistics.\n\n        Args:\n            y_true: The ground truth values, with the same dimensions as\n                `y_pred`. Will be cast to `bool`.\n            y_pred: The predicted values. Each element must be in the range\n                `[0, 1]`.\n            sample_weight: Optional weighting of each example. Defaults to `1`.\n                Can be a tensor whose rank is either 0, or the same rank as\n                `y_true`, and must be broadcastable to `y_true`.\n        \"\"\"\n        metrics_utils.update_confusion_matrix_variables(\n            {\n                metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,  # noqa: E501\n                metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,  # noqa: E501\n            },\n            y_true,\n            y_pred,\n            thresholds=self.thresholds,\n            thresholds_distributed_evenly=self._thresholds_distributed_evenly,\n            top_k=self.top_k,\n            class_id=self.class_id,\n            sample_weight=sample_weight,\n        )\n\n    def result(self):\n        result = ops.divide_no_nan(\n            self.true_positives,\n            ops.add(self.true_positives, self.false_negatives),\n        )\n        return result[0] if len(self.thresholds) == 1 else result\n\n    def reset_state(self):\n        num_thresholds = len(to_list(self.thresholds))\n        self.true_positives.assign(ops.zeros((num_thresholds,)))\n        self.false_negatives.assign(ops.zeros((num_thresholds,)))\n\n    def get_config(self):\n        config = {\n            
\"thresholds\": self.init_thresholds,\n            \"top_k\": self.top_k,\n            \"class_id\": self.class_id,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\nclass SensitivitySpecificityBase(Metric):\n    \"\"\"Abstract base class for computing sensitivity and specificity.\n\n    For additional information about specificity and sensitivity, see\n    [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).\n    \"\"\"\n\n    def __init__(\n        self, value, num_thresholds=200, class_id=None, name=None, dtype=None\n    ):\n        super().__init__(name=name, dtype=dtype)\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n        if num_thresholds <= 0:\n            raise ValueError(\n                \"Argument `num_thresholds` must be an integer > 0. \"\n                f\"Received: num_thresholds={num_thresholds}\"\n            )\n        self.value = value\n        self.class_id = class_id\n\n        # Compute `num_thresholds` thresholds in [0, 1]\n        if num_thresholds == 1:\n            self.thresholds = [0.5]\n            self._thresholds_distributed_evenly = False\n        else:\n            thresholds = [\n                (i + 1) * 1.0 / (num_thresholds - 1)\n                for i in range(num_thresholds - 2)\n            ]\n            self.thresholds = [0.0] + thresholds + [1.0]\n            self._thresholds_distributed_evenly = True\n\n        self.true_positives = self.add_variable(\n            shape=(len(self.thresholds),),\n            initializer=initializers.Zeros(),\n            name=\"true_positives\",\n        )\n        self.false_positives = self.add_variable(\n            shape=(len(self.thresholds),),\n            initializer=initializers.Zeros(),\n            name=\"false_positives\",\n        )\n        self.true_negatives = self.add_variable(\n            shape=(len(self.thresholds),),\n            
initializer=initializers.Zeros(),\n            name=\"true_negatives\",\n        )\n        self.false_negatives = self.add_variable(\n            shape=(len(self.thresholds),),\n            initializer=initializers.Zeros(),\n            name=\"false_negatives\",\n        )\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        \"\"\"Accumulates confusion matrix statistics.\n\n        Args:\n            y_true: The ground truth values.\n            y_pred: The predicted values.\n            sample_weight: Optional weighting of each example. Defaults to `1`.\n                Can be a tensor whose rank is either 0, or the same rank as\n                `y_true`, and must be broadcastable to `y_true`.\n        \"\"\"\n        metrics_utils.update_confusion_matrix_variables(\n            {\n                metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,  # noqa: E501\n                metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,  # noqa: E501\n                metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,  # noqa: E501\n                metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,  # noqa: E501\n            },\n            y_true,\n            y_pred,\n            thresholds=self.thresholds,\n            thresholds_distributed_evenly=self._thresholds_distributed_evenly,\n            class_id=self.class_id,\n            sample_weight=sample_weight,\n        )\n\n    def reset_state(self):\n        num_thresholds = len(self.thresholds)\n        self.true_positives.assign(ops.zeros((num_thresholds,)))\n        self.false_positives.assign(ops.zeros((num_thresholds,)))\n        self.true_negatives.assign(ops.zeros((num_thresholds,)))\n        self.false_negatives.assign(ops.zeros((num_thresholds,)))\n\n    def get_config(self):\n        config = {\"class_id\": self.class_id}\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    def 
_find_max_under_constraint(self, constrained, dependent, predicate):\n        \"\"\"Returns the maximum of `dependent` that satisfies the\n        constraint.\n\n        Args:\n            constrained: Over these values the constraint is specified. A rank-1\n                tensor.\n            dependent: From these values the maximum that satisfies the\n                constraint is selected. Values in this tensor and in\n                `constrained` are linked by having the same threshold at each\n                position, hence this tensor must have the same shape.\n            predicate: A binary boolean functor to be applied to arguments\n                `constrained` and `self.value`, e.g. `ops.greater`.\n\n        Returns:\n            The maximal value of `dependent`, or 0.0 if no value satisfies\n            the constraint.\n        \"\"\"\n        feasible = predicate(constrained, self.value)\n        # Mask values based on whether they satisfy the constraint and take max.\n        return ops.max(\n            ops.multiply(dependent, ops.cast(feasible, dependent.dtype)),\n            initial=0,\n        )\n\n\n@keras_export(\"keras.metrics.SensitivityAtSpecificity\")\nclass SensitivityAtSpecificity(SensitivitySpecificityBase):\n    \"\"\"Computes best sensitivity where specificity is >= specified value.\n\n    `Sensitivity` measures the proportion of actual positives that are correctly\n    identified as such `(tp / (tp + fn))`.\n    `Specificity` measures the proportion of actual negatives that are correctly\n    identified as such `(tn / (tn + fp))`.\n\n    This metric creates four local variables, `true_positives`,\n    `true_negatives`, `false_positives` and `false_negatives` that are used to\n    compute the sensitivity at the given specificity. 
The threshold for the\n    given specificity value is computed and used to evaluate the corresponding\n    sensitivity.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    If `class_id` is specified, we calculate the metric by considering only\n    the entries in the batch for which `class_id` is above the threshold\n    predictions, and computing the fraction of them for which `class_id` is\n    indeed a correct label.\n\n    For additional information about specificity and sensitivity, see\n    [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).\n\n    Args:\n        specificity: A scalar value in range `[0, 1]`.\n        num_thresholds: (Optional) Defaults to 200. The number of thresholds to\n            use for matching the given specificity.\n        class_id: (Optional) Integer class ID for which we want binary metrics.\n            This must be in the half-open interval `[0, num_classes)`, where\n            `num_classes` is the last dimension of predictions.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.SensitivityAtSpecificity(0.5)\n    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])\n    >>> m.result()\n    0.5\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],\n    ...                
sample_weight=[1, 1, 2, 2, 1])\n    >>> m.result()\n    0.333333\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='binary_crossentropy',\n        metrics=[keras.metrics.SensitivityAtSpecificity(specificity=0.5)])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        specificity,\n        num_thresholds=200,\n        class_id=None,\n        name=None,\n        dtype=None,\n    ):\n        if specificity < 0 or specificity > 1:\n            raise ValueError(\n                \"Argument `specificity` must be in the range [0, 1]. \"\n                f\"Received: specificity={specificity}\"\n            )\n        self.specificity = specificity\n        self.num_thresholds = num_thresholds\n        super().__init__(\n            specificity,\n            num_thresholds=num_thresholds,\n            class_id=class_id,\n            name=name,\n            dtype=dtype,\n        )\n\n    def result(self):\n        sensitivities = ops.divide_no_nan(\n            self.true_positives,\n            ops.add(self.true_positives, self.false_negatives),\n        )\n        specificities = ops.divide_no_nan(\n            self.true_negatives,\n            ops.add(self.true_negatives, self.false_positives),\n        )\n        return self._find_max_under_constraint(\n            specificities, sensitivities, ops.greater_equal\n        )\n\n    def get_config(self):\n        config = {\n            \"num_thresholds\": self.num_thresholds,\n            \"specificity\": self.specificity,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\n@keras_export(\"keras.metrics.SpecificityAtSensitivity\")\nclass SpecificityAtSensitivity(SensitivitySpecificityBase):\n    \"\"\"Computes best specificity where sensitivity is >= specified value.\n\n    `Sensitivity` measures the proportion of actual positives that are correctly\n    identified as such `(tp / (tp + fn))`.\n    
`Specificity` measures the proportion of actual negatives that are correctly\n    identified as such `(tn / (tn + fp))`.\n\n    This metric creates four local variables, `true_positives`,\n    `true_negatives`, `false_positives` and `false_negatives` that are used to\n    compute the specificity at the given sensitivity. The threshold for the\n    given sensitivity value is computed and used to evaluate the corresponding\n    specificity.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    If `class_id` is specified, we calculate the metric by considering only\n    the entries in the batch for which `class_id` is above the threshold\n    predictions, and computing the fraction of them for which `class_id` is\n    indeed a correct label.\n\n    For additional information about specificity and sensitivity, see\n    [the following](https://en.wikipedia.org/wiki/Sensitivity_and_specificity).\n\n    Args:\n        sensitivity: A scalar value in range `[0, 1]`.\n        num_thresholds: (Optional) Defaults to 200. The number of thresholds to\n            use for matching the given sensitivity.\n        class_id: (Optional) Integer class ID for which we want binary metrics.\n            This must be in the half-open interval `[0, num_classes)`, where\n            `num_classes` is the last dimension of predictions.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.SpecificityAtSensitivity(0.5)\n    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])\n    >>> m.result()\n    0.66666667\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],\n    ...                
sample_weight=[1, 1, 2, 2, 2])\n    >>> m.result()\n    0.5\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='binary_crossentropy',\n        metrics=[keras.metrics.SpecificityAtSensitivity(sensitivity=0.3)])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        sensitivity,\n        num_thresholds=200,\n        class_id=None,\n        name=None,\n        dtype=None,\n    ):\n        if sensitivity < 0 or sensitivity > 1:\n            raise ValueError(\n                \"Argument `sensitivity` must be in the range [0, 1]. \"\n                f\"Received: sensitivity={sensitivity}\"\n            )\n        self.sensitivity = sensitivity\n        self.num_thresholds = num_thresholds\n        super().__init__(\n            sensitivity,\n            num_thresholds=num_thresholds,\n            class_id=class_id,\n            name=name,\n            dtype=dtype,\n        )\n\n    def result(self):\n        sensitivities = ops.divide_no_nan(\n            self.true_positives,\n            ops.add(self.true_positives, self.false_negatives),\n        )\n        specificities = ops.divide_no_nan(\n            self.true_negatives,\n            ops.add(self.true_negatives, self.false_positives),\n        )\n        return self._find_max_under_constraint(\n            sensitivities, specificities, ops.greater_equal\n        )\n\n    def get_config(self):\n        config = {\n            \"num_thresholds\": self.num_thresholds,\n            \"sensitivity\": self.sensitivity,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\n@keras_export(\"keras.metrics.PrecisionAtRecall\")\nclass PrecisionAtRecall(SensitivitySpecificityBase):\n    \"\"\"Computes best precision where recall is >= specified value.\n\n    This metric creates four local variables, `true_positives`,\n    `true_negatives`, `false_positives` and `false_negatives` that are used to\n    compute 
the precision at the given recall. The threshold for the given\n    recall value is computed and used to evaluate the corresponding precision.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    If `class_id` is specified, we calculate precision by considering only the\n    entries in the batch for which `class_id` is above the threshold\n    predictions, and computing the fraction of them for which `class_id` is\n    indeed a correct label.\n\n    Args:\n        recall: A scalar value in range `[0, 1]`.\n        num_thresholds: (Optional) Defaults to 200. The number of thresholds to\n            use for matching the given recall.\n        class_id: (Optional) Integer class ID for which we want binary metrics.\n            This must be in the half-open interval `[0, num_classes)`, where\n            `num_classes` is the last dimension of predictions.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.PrecisionAtRecall(0.5)\n    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8])\n    >>> m.result()\n    0.5\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 0, 0, 1, 1], [0, 0.3, 0.8, 0.3, 0.8],\n    ...                sample_weight=[2, 2, 2, 1, 1])\n    >>> m.result()\n    0.33333333\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='binary_crossentropy',\n        metrics=[keras.metrics.PrecisionAtRecall(recall=0.8)])\n    ```\n    \"\"\"\n\n    def __init__(\n        self, recall, num_thresholds=200, class_id=None, name=None, dtype=None\n    ):\n        if recall < 0 or recall > 1:\n            raise ValueError(\n                \"Argument `recall` must be in the range [0, 1]. 
\"\n                f\"Received: recall={recall}\"\n            )\n        self.recall = recall\n        self.num_thresholds = num_thresholds\n        super().__init__(\n            value=recall,\n            num_thresholds=num_thresholds,\n            class_id=class_id,\n            name=name,\n            dtype=dtype,\n        )\n\n    def result(self):\n        recalls = ops.divide_no_nan(\n            self.true_positives,\n            ops.add(self.true_positives, self.false_negatives),\n        )\n        precisions = ops.divide_no_nan(\n            self.true_positives,\n            ops.add(self.true_positives, self.false_positives),\n        )\n        return self._find_max_under_constraint(\n            recalls, precisions, ops.greater_equal\n        )\n\n    def get_config(self):\n        config = {\"num_thresholds\": self.num_thresholds, \"recall\": self.recall}\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\n@keras_export(\"keras.metrics.RecallAtPrecision\")\nclass RecallAtPrecision(SensitivitySpecificityBase):\n    \"\"\"Computes best recall where precision is >= specified value.\n\n    For a given score-label distribution, the required precision might not\n    be achievable; in this case, 0.0 is returned as recall.\n\n    This metric creates four local variables, `true_positives`,\n    `true_negatives`, `false_positives` and `false_negatives` that are used to\n    compute the recall at the given precision. 
The threshold for the given\n    precision value is computed and used to evaluate the corresponding recall.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    If `class_id` is specified, we calculate the metric by considering only\n    the entries in the batch for which `class_id` is above the threshold\n    predictions, and computing the fraction of them for which `class_id` is\n    indeed a correct label.\n\n    Args:\n        precision: A scalar value in range `[0, 1]`.\n        num_thresholds: (Optional) Defaults to 200. The number of thresholds\n            to use for matching the given precision.\n        class_id: (Optional) Integer class ID for which we want binary metrics.\n            This must be in the half-open interval `[0, num_classes)`, where\n            `num_classes` is the last dimension of predictions.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.RecallAtPrecision(0.8)\n    >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])\n    >>> m.result()\n    0.5\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],\n    ...                sample_weight=[1, 0, 0, 1])\n    >>> m.result()\n    1.0\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='binary_crossentropy',\n        metrics=[keras.metrics.RecallAtPrecision(precision=0.8)])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        precision,\n        num_thresholds=200,\n        class_id=None,\n        name=None,\n        dtype=None,\n    ):\n        if precision < 0 or precision > 1:\n            raise ValueError(\n                \"Argument `precision` must be in the range [0, 1]. 
\"\n                f\"Received: precision={precision}\"\n            )\n        self.precision = precision\n        self.num_thresholds = num_thresholds\n        super().__init__(\n            value=precision,\n            num_thresholds=num_thresholds,\n            class_id=class_id,\n            name=name,\n            dtype=dtype,\n        )\n\n    def result(self):\n        recalls = ops.divide_no_nan(\n            self.true_positives,\n            ops.add(self.true_positives, self.false_negatives),\n        )\n        precisions = ops.divide_no_nan(\n            self.true_positives,\n            ops.add(self.true_positives, self.false_positives),\n        )\n        return self._find_max_under_constraint(\n            precisions, recalls, ops.greater_equal\n        )\n\n    def get_config(self):\n        config = {\n            \"num_thresholds\": self.num_thresholds,\n            \"precision\": self.precision,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\n@keras_export(\"keras.metrics.AUC\")\nclass AUC(Metric):\n    \"\"\"Approximates the AUC (Area under the curve) of the ROC or PR curves.\n\n    The AUC (Area under the curve) of the ROC (Receiver operating\n    characteristic; default) or PR (Precision Recall) curves is a quality\n    measure of binary classifiers. Unlike accuracy, and like cross-entropy\n    losses, ROC-AUC and PR-AUC evaluate all the operational points of a model.\n\n    This class approximates AUCs using a Riemann sum. During the metric\n    accumulation phase, predictions are accumulated within predefined buckets\n    by value. The AUC is then computed by interpolating per-bucket averages.\n    These buckets define the evaluated operational points.\n\n    This metric creates four local variables, `true_positives`,\n    `true_negatives`, `false_positives` and `false_negatives` that are used to\n    compute the AUC.  
To discretize the AUC curve, a linearly spaced set of\n    thresholds is used to compute pairs of recall and precision values. The area\n    under the ROC-curve is therefore computed using the height of the recall\n    values by the false positive rate, while the area under the PR-curve is\n    computed using the height of the precision values by the recall.\n\n    This value is ultimately returned as `auc`, an idempotent operation that\n    computes the area under a discretized curve of precision versus recall\n    values (computed using the aforementioned variables). The `num_thresholds`\n    variable controls the degree of discretization with larger numbers of\n    thresholds more closely approximating the true AUC. The quality of the\n    approximation may vary dramatically depending on `num_thresholds`. The\n    `thresholds` parameter can be used to manually specify thresholds which\n    split the predictions more evenly.\n\n    For the best approximation of the real AUC, `predictions` should be\n    distributed approximately uniformly in the range `[0, 1]` (if\n    `from_logits=False`). The quality of the AUC approximation may be poor if\n    this is not the case. Setting `summation_method` to 'minoring' or 'majoring'\n    can help quantify the error in the approximation by providing a lower or\n    upper bound estimate of the AUC.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    Args:\n        num_thresholds: (Optional) The number of thresholds to\n            use when discretizing the ROC curve. 
Values must be > 1.\n            Defaults to `200`.\n        curve: (Optional) Specifies the name of the curve to be computed,\n            `'ROC'` (default) or `'PR'` for the Precision-Recall-curve.\n        summation_method: (Optional) Specifies the [Riemann summation method](\n              https://en.wikipedia.org/wiki/Riemann_sum) used.\n              'interpolation' (default) applies mid-point summation scheme for\n              `ROC`.  For PR-AUC, interpolates (true/false) positives but not\n              the ratio that is precision (see Davis & Goadrich 2006 for\n              details); 'minoring' applies left summation for increasing\n              intervals and right summation for decreasing intervals; 'majoring'\n              does the opposite.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        thresholds: (Optional) A list of floating point values to use as the\n            thresholds for discretizing the curve. If set, the `num_thresholds`\n            parameter is ignored. Values should be in `[0, 1]`. Endpoint\n            thresholds equal to {`-epsilon`, `1+epsilon`} for a small positive\n            epsilon value will be automatically included with these to correctly\n            handle predictions equal to exactly 0 or 1.\n        multi_label: boolean indicating whether multilabel data should be\n            treated as such, wherein AUC is computed separately for each label\n            and then averaged across labels, or (when `False`) if the data\n            should be flattened into a single label before AUC computation. In\n            the latter case, when multilabel data is passed to AUC, each\n            label-prediction pair is treated as an individual data point. Should\n            be set to `False` for multi-class data.\n        num_labels: (Optional) The number of labels, used when `multi_label` is\n            True. 
If `num_labels` is not specified, then state variables get\n            created on the first call to `update_state`.\n        label_weights: (Optional) list, array, or tensor of non-negative weights\n            used to compute AUCs for multilabel data. When `multi_label` is\n            True, the weights are applied to the individual label AUCs when they\n            are averaged to produce the multi-label AUC. When it's False, they\n            are used to weight the individual label predictions in computing the\n            confusion matrix on the flattened data. Note that this is unlike\n            `class_weights` in that `class_weights` weights the example\n            depending on the value of its label, whereas `label_weights` depends\n            only on the index of that label before flattening; therefore\n            `label_weights` should not be used for multi-class data.\n        from_logits: boolean indicating whether the predictions (`y_pred` in\n            `update_state`) are probabilities or sigmoid logits. As a rule of\n            thumb, when using a keras loss, the `from_logits` constructor\n            argument of the loss should match the AUC `from_logits` constructor\n            argument.\n\n    Example:\n\n    >>> m = keras.metrics.AUC(num_thresholds=3)\n    >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])\n    >>> # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]\n    >>> # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]\n    >>> # tp_rate = recall = [1, 0.5, 0], fp_rate = [1, 0, 0]\n    >>> # auc = ((((1 + 0.5) / 2) * (1 - 0)) + (((0.5 + 0) / 2) * (0 - 0)))\n    >>> #     = 0.75\n    >>> m.result()\n    0.75\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9],\n    ...                
sample_weight=[1, 0, 0, 1])\n    >>> m.result()\n    1.0\n\n    Usage with `compile()` API:\n\n    ```python\n    # Reports the AUC of a model outputting a probability.\n    model.compile(optimizer='sgd',\n                  loss=keras.losses.BinaryCrossentropy(),\n                  metrics=[keras.metrics.AUC()])\n\n    # Reports the AUC of a model outputting a logit.\n    model.compile(optimizer='sgd',\n                  loss=keras.losses.BinaryCrossentropy(from_logits=True),\n                  metrics=[keras.metrics.AUC(from_logits=True)])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        num_thresholds=200,\n        curve=\"ROC\",\n        summation_method=\"interpolation\",\n        name=None,\n        dtype=None,\n        thresholds=None,\n        multi_label=False,\n        num_labels=None,\n        label_weights=None,\n        from_logits=False,\n    ):\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n        # Validate configurations.\n        if isinstance(curve, metrics_utils.AUCCurve) and curve not in list(\n            metrics_utils.AUCCurve\n        ):\n            raise ValueError(\n                f'Invalid `curve` argument value \"{curve}\". '\n                f\"Expected one of: {list(metrics_utils.AUCCurve)}\"\n            )\n        if isinstance(\n            summation_method, metrics_utils.AUCSummationMethod\n        ) and summation_method not in list(metrics_utils.AUCSummationMethod):\n            raise ValueError(\n                \"Invalid `summation_method` \"\n                f'argument value \"{summation_method}\". 
'\n                f\"Expected one of: {list(metrics_utils.AUCSummationMethod)}\"\n            )\n\n        # Update properties.\n        self._init_from_thresholds = thresholds is not None\n        if thresholds is not None:\n            # If specified, use the supplied thresholds.\n            self.num_thresholds = len(thresholds) + 2\n            thresholds = sorted(thresholds)\n            self._thresholds_distributed_evenly = (\n                metrics_utils.is_evenly_distributed_thresholds(\n                    np.array([0.0] + thresholds + [1.0])\n                )\n            )\n        else:\n            if num_thresholds <= 1:\n                raise ValueError(\n                    \"Argument `num_thresholds` must be an integer > 1. \"\n                    f\"Received: num_thresholds={num_thresholds}\"\n                )\n\n            # Otherwise, linearly interpolate (num_thresholds - 2) thresholds in\n            # (0, 1).\n            self.num_thresholds = num_thresholds\n            thresholds = [\n                (i + 1) * 1.0 / (num_thresholds - 1)\n                for i in range(num_thresholds - 2)\n            ]\n            self._thresholds_distributed_evenly = True\n\n        # Add an endpoint \"threshold\" below zero and above one for either\n        # threshold method to account for floating point imprecisions.\n        self._thresholds = np.array(\n            [0.0 - backend.epsilon()] + thresholds + [1.0 + backend.epsilon()]\n        )\n\n        if isinstance(curve, metrics_utils.AUCCurve):\n            self.curve = curve\n        else:\n            self.curve = metrics_utils.AUCCurve.from_str(curve)\n        if isinstance(summation_method, metrics_utils.AUCSummationMethod):\n            self.summation_method = summation_method\n        else:\n            self.summation_method = metrics_utils.AUCSummationMethod.from_str(\n                summation_method\n            )\n        super().__init__(name=name, dtype=dtype)\n\n        # Handle 
multilabel arguments.\n        self.multi_label = multi_label\n        self.num_labels = num_labels\n        if label_weights is not None:\n            label_weights = ops.array(label_weights, dtype=self.dtype)\n            self.label_weights = label_weights\n\n        else:\n            self.label_weights = None\n\n        self._from_logits = from_logits\n\n        self._built = False\n        if self.multi_label:\n            if num_labels:\n                shape = [None, num_labels]\n                self._build(shape)\n        else:\n            if num_labels:\n                raise ValueError(\n                    \"`num_labels` is needed only when `multi_label` is True.\"\n                )\n            self._build(None)\n\n    @property\n    def thresholds(self):\n        \"\"\"The thresholds used for evaluating AUC.\"\"\"\n        return list(self._thresholds)\n\n    def _build(self, shape):\n        \"\"\"Initialize TP, FP, TN, and FN tensors, given the shape of the\n        data.\"\"\"\n        if self.multi_label:\n            if len(shape) != 2:\n                raise ValueError(\n                    \"`y_pred` must have rank 2 when `multi_label=True`. \"\n                    f\"Found rank {len(shape)}. 
\"\n                    f\"Full shape received for `y_pred`: {shape}\"\n                )\n            self._num_labels = shape[1]\n            variable_shape = [self.num_thresholds, self._num_labels]\n        else:\n            variable_shape = [self.num_thresholds]\n\n        self._build_input_shape = shape\n        # Create metric variables\n        self.true_positives = self.add_variable(\n            shape=variable_shape,\n            initializer=initializers.Zeros(),\n            name=\"true_positives\",\n        )\n        self.false_positives = self.add_variable(\n            shape=variable_shape,\n            initializer=initializers.Zeros(),\n            name=\"false_positives\",\n        )\n        self.true_negatives = self.add_variable(\n            shape=variable_shape,\n            initializer=initializers.Zeros(),\n            name=\"true_negatives\",\n        )\n        self.false_negatives = self.add_variable(\n            shape=variable_shape,\n            initializer=initializers.Zeros(),\n            name=\"false_negatives\",\n        )\n\n        self._built = True\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        \"\"\"Accumulates confusion matrix statistics.\n\n        Args:\n            y_true: The ground truth values.\n            y_pred: The predicted values.\n            sample_weight: Optional weighting of each example. Can\n                be a tensor whose rank is either 0, or the same rank as\n                `y_true`, and must be broadcastable to `y_true`. Defaults to\n                `1`.\n        \"\"\"\n        if not self._built:\n            self._build(y_pred.shape)\n\n        # Only forward label_weights to update_confusion_matrix_variables when\n        # multi_label is False. 
Otherwise the averaging of individual label AUCs\n        # is handled in AUC.result\n        label_weights = None if self.multi_label else self.label_weights\n\n        if self._from_logits:\n            y_pred = activations.sigmoid(y_pred)\n\n        metrics_utils.update_confusion_matrix_variables(\n            {\n                metrics_utils.ConfusionMatrix.TRUE_POSITIVES: self.true_positives,  # noqa: E501\n                metrics_utils.ConfusionMatrix.TRUE_NEGATIVES: self.true_negatives,  # noqa: E501\n                metrics_utils.ConfusionMatrix.FALSE_POSITIVES: self.false_positives,  # noqa: E501\n                metrics_utils.ConfusionMatrix.FALSE_NEGATIVES: self.false_negatives,  # noqa: E501\n            },\n            y_true,\n            y_pred,\n            self._thresholds,\n            thresholds_distributed_evenly=self._thresholds_distributed_evenly,\n            sample_weight=sample_weight,\n            multi_label=self.multi_label,\n            label_weights=label_weights,\n        )\n\n    def interpolate_pr_auc(self):\n        \"\"\"Interpolation formula inspired by section 4 of Davis & Goadrich 2006.\n\n        https://www.biostat.wisc.edu/~page/rocpr.pdf\n\n        Note here we derive & use a closed formula not present in the paper\n        as follows:\n\n            Precision = TP / (TP + FP) = TP / P\n\n        Modeling all of TP (true positive), FP (false positive) and their sum\n        P = TP + FP (predicted positive) as varying linearly within each\n        interval [A, B] between successive thresholds, we get\n\n            Precision slope = dTP / dP\n                            = (TP_B - TP_A) / (P_B - P_A)\n                            = (TP - TP_A) / (P - P_A)\n            Precision = (TP_A + slope * (P - P_A)) / P\n\n        The area within the interval is (slope / total_pos_weight) times\n\n            int_A^B{Precision.dP} = int_A^B{(TP_A + slope * (P - P_A)) * dP / P}\n            int_A^B{Precision.dP} = int_A^B{slope * dP + 
intercept * dP / P}\n\n        where intercept = TP_A - slope * P_A = TP_B - slope * P_B, resulting in\n\n            int_A^B{Precision.dP} = TP_B - TP_A + intercept * log(P_B / P_A)\n\n        Bringing back the factor (slope / total_pos_weight) we'd put aside, we\n        get\n\n            slope * [dTP + intercept *  log(P_B / P_A)] / total_pos_weight\n\n        where dTP == TP_B - TP_A.\n\n        Note that when P_A == 0 the above calculation simplifies into\n\n            int_A^B{Precision.dTP} = int_A^B{slope * dTP}\n                                   = slope * (TP_B - TP_A)\n\n        which is really equivalent to imputing constant precision throughout the\n        first bucket having >0 true positives.\n\n        Returns:\n            pr_auc: an approximation of the area under the P-R curve.\n        \"\"\"\n\n        dtp = ops.subtract(\n            self.true_positives[: self.num_thresholds - 1],\n            self.true_positives[1:],\n        )\n        p = ops.add(self.true_positives, self.false_positives)\n        dp = ops.subtract(p[: self.num_thresholds - 1], p[1:])\n        prec_slope = ops.divide_no_nan(dtp, ops.maximum(dp, 0))\n        intercept = ops.subtract(\n            self.true_positives[1:], ops.multiply(prec_slope, p[1:])\n        )\n\n        safe_p_ratio = ops.where(\n            ops.logical_and(p[: self.num_thresholds - 1] > 0, p[1:] > 0),\n            ops.divide_no_nan(\n                p[: self.num_thresholds - 1], ops.maximum(p[1:], 0)\n            ),\n            ops.ones_like(p[1:]),\n        )\n\n        pr_auc_increment = ops.divide_no_nan(\n            ops.multiply(\n                prec_slope,\n                (ops.add(dtp, ops.multiply(intercept, ops.log(safe_p_ratio)))),\n            ),\n            ops.maximum(\n                ops.add(self.true_positives[1:], self.false_negatives[1:]), 0\n            ),\n        )\n\n        if self.multi_label:\n            by_label_auc = ops.sum(pr_auc_increment, axis=0)\n            if 
self.label_weights is None:\n                # Evenly weighted average of the label AUCs.\n                return ops.mean(by_label_auc)\n            else:\n                # Weighted average of the label AUCs.\n                return ops.divide_no_nan(\n                    ops.sum(ops.multiply(by_label_auc, self.label_weights)),\n                    ops.sum(self.label_weights),\n                )\n        else:\n            return ops.sum(pr_auc_increment)\n\n    def result(self):\n        if (\n            self.curve == metrics_utils.AUCCurve.PR\n            and self.summation_method\n            == metrics_utils.AUCSummationMethod.INTERPOLATION\n        ):\n            # This use case is different and is handled separately.\n            return self.interpolate_pr_auc()\n\n        # Set `x` and `y` values for the curves based on `curve` config.\n        recall = ops.divide_no_nan(\n            self.true_positives,\n            ops.add(self.true_positives, self.false_negatives),\n        )\n        if self.curve == metrics_utils.AUCCurve.ROC:\n            fp_rate = ops.divide_no_nan(\n                self.false_positives,\n                ops.add(self.false_positives, self.true_negatives),\n            )\n            x = fp_rate\n            y = recall\n        elif self.curve == metrics_utils.AUCCurve.PR:  # curve == 'PR'.\n            precision = ops.divide_no_nan(\n                self.true_positives,\n                ops.add(self.true_positives, self.false_positives),\n            )\n            x = recall\n            y = precision\n        else:  # curve == 'PRGAIN'.\n            # Due to the hyperbolic transform, this formula is less robust than\n            # ROC and PR values. In particular\n            # 1) Both measures diverge when there are no negative values;\n            # 2) Both measures diverge when there are no true positives;\n            # 3) Recall gain becomes negative when the recall is lower than the\n            #    label average (i.e. 
when more negative examples are\n            #    classified positive than real positives).\n            #\n            # We ignore case 1 as it is easily understood that metrics would be\n            # badly defined then. For case 2 we set recall_gain to 0 and\n            # precision_gain to 1. For case 3 we set recall_gain to 0. These\n            # fixes will result in an overestimation of the AUC for estimators\n            # that are anti-correlated with the label (at some threshold).\n\n            # The scaling factor $\\frac{P}{N}$ that is used for both gain\n            # values.\n            scaling_factor = ops.divide_no_nan(\n                ops.add(self.true_positives, self.false_negatives),\n                ops.add(self.true_negatives, self.false_positives),\n            )\n\n            recall_gain = 1.0 - scaling_factor * ops.divide_no_nan(\n                self.false_negatives, self.true_positives\n            )\n            precision_gain = 1.0 - scaling_factor * ops.divide_no_nan(\n                self.false_positives, self.true_positives\n            )\n            # Handle case 2.\n            recall_gain = ops.where(\n                ops.equal(self.true_positives, 0.0), 0.0, recall_gain\n            )\n            precision_gain = ops.where(\n                ops.equal(self.true_positives, 0.0), 1.0, precision_gain\n            )\n            # Handle case 3.\n            recall_gain = ops.maximum(recall_gain, 0.0)\n\n            x = recall_gain\n            y = precision_gain\n\n        # Find the rectangle heights based on `summation_method`.\n        if (\n            self.summation_method\n            == metrics_utils.AUCSummationMethod.INTERPOLATION\n        ):\n            # Note: the case ('PR', 'interpolation') has been handled above.\n            heights = ops.divide(\n                ops.add(y[: self.num_thresholds - 1], y[1:]), 2.0\n            )\n        elif self.summation_method == metrics_utils.AUCSummationMethod.MINORING:\n            heights = ops.minimum(y[: self.num_thresholds - 1], y[1:])\n        # self.summation_method == metrics_utils.AUCSummationMethod.MAJORING:\n        else:\n            heights = ops.maximum(y[: self.num_thresholds - 1], y[1:])\n\n        # Sum up the areas of all the rectangles.\n        riemann_terms = ops.multiply(\n            ops.subtract(x[: self.num_thresholds - 1], x[1:]), heights\n        )\n        if self.multi_label:\n            by_label_auc = ops.sum(riemann_terms, axis=0)\n\n            if self.label_weights is None:\n                # Unweighted average of the label AUCs.\n                return ops.mean(by_label_auc)\n            else:\n                # Weighted average of the label AUCs.\n                return ops.divide_no_nan(\n                    ops.sum(ops.multiply(by_label_auc, self.label_weights)),\n                    ops.sum(self.label_weights),\n                )\n        else:\n            return ops.sum(riemann_terms)\n\n    def reset_state(self):\n        if self._built:\n            if self.multi_label:\n                variable_shape = (self.num_thresholds, self._num_labels)\n            else:\n                variable_shape = (self.num_thresholds,)\n\n            self.true_positives.assign(ops.zeros(variable_shape))\n            self.false_positives.assign(ops.zeros(variable_shape))\n            self.true_negatives.assign(ops.zeros(variable_shape))\n            self.false_negatives.assign(ops.zeros(variable_shape))\n\n    def get_config(self):\n        label_weights = self.label_weights\n        config = {\n            \"num_thresholds\": self.num_thresholds,\n            \"curve\": self.curve.value,\n            \"summation_method\": self.summation_method.value,\n            \"multi_label\": self.multi_label,\n            \"num_labels\": self.num_labels,\n            \"label_weights\": label_weights,\n            \"from_logits\": self._from_logits,\n        }\n        # optimization to avoid serializing a large number of generated\n        # thresholds\n        if self._init_from_thresholds:\n            # We remove the endpoint thresholds as an inverse of how the\n            # thresholds were initialized. This ensures that a metric\n            # initialized from this config has the same thresholds.\n            config[\"thresholds\"] = self.thresholds[1:-1]\n        base_config = super().get_config()\n        return {**base_config, **config}\n"
  },
  {
    "path": "keras/src/metrics/confusion_metrics_test.py",
    "content": "import json\n\nimport numpy as np\nimport pytest\nfrom absl import logging\nfrom absl.testing import parameterized\n\nfrom keras.src import layers\nfrom keras.src import metrics\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.metrics import metrics_utils\n\n\nclass FalsePositivesTest(testing.TestCase):\n    def test_config(self):\n        fp_obj = metrics.FalsePositives(name=\"my_fp\", thresholds=[0.4, 0.9])\n        self.assertEqual(fp_obj.name, \"my_fp\")\n        self.assertLen(fp_obj.variables, 1)\n        self.assertEqual(fp_obj.thresholds, [0.4, 0.9])\n\n        # Check save and restore config\n        fp_obj2 = metrics.FalsePositives.from_config(fp_obj.get_config())\n        self.assertEqual(fp_obj2.name, \"my_fp\")\n        self.assertLen(fp_obj2.variables, 1)\n        self.assertEqual(fp_obj2.thresholds, [0.4, 0.9])\n\n    def test_unweighted(self):\n        fp_obj = metrics.FalsePositives()\n\n        y_true = np.array(\n            ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))\n        )\n        y_pred = np.array(\n            ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))\n        )\n\n        fp_obj.update_state(y_true, y_pred)\n        self.assertAllClose(7.0, fp_obj.result())\n\n    def test_weighted(self):\n        fp_obj = metrics.FalsePositives()\n        y_true = np.array(\n            ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))\n        )\n        y_pred = np.array(\n            ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))\n        )\n        sample_weight = np.array((1.0, 1.5, 2.0, 2.5))\n        result = fp_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(14.0, result)\n\n    def test_unweighted_with_thresholds(self):\n        fp_obj = metrics.FalsePositives(thresholds=[0.15, 0.5, 0.85])\n\n        y_pred = np.array(\n            (\n                
(0.9, 0.2, 0.8, 0.1),\n                (0.2, 0.9, 0.7, 0.6),\n                (0.1, 0.2, 0.4, 0.3),\n                (0, 1, 0.7, 0.3),\n            )\n        )\n        y_true = np.array(\n            ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1))\n        )\n\n        fp_obj.update_state(y_true, y_pred)\n        self.assertAllClose([7.0, 4.0, 2.0], fp_obj.result())\n\n    def test_weighted_with_thresholds(self):\n        fp_obj = metrics.FalsePositives(thresholds=[0.15, 0.5, 0.85])\n\n        y_pred = np.array(\n            (\n                (0.9, 0.2, 0.8, 0.1),\n                (0.2, 0.9, 0.7, 0.6),\n                (0.1, 0.2, 0.4, 0.3),\n                (0, 1, 0.7, 0.3),\n            )\n        )\n        y_true = np.array(\n            ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1))\n        )\n        sample_weight = (\n            (1.0, 2.0, 3.0, 5.0),\n            (7.0, 11.0, 13.0, 17.0),\n            (19.0, 23.0, 29.0, 31.0),\n            (5.0, 15.0, 10.0, 0),\n        )\n\n        result = fp_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose([125.0, 42.0, 12.0], result)\n\n    def test_threshold_limit(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"Threshold values must be in \\[0, 1\\]. Received: \\[-1, 2\\]\",\n        ):\n            metrics.FalsePositives(thresholds=[-1, 0.5, 2])\n\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"Threshold values must be in \\[0, 1\\]. 
Received: \\[None\\]\",\n        ):\n            metrics.FalsePositives(thresholds=[None])\n\n\nclass FalseNegativesTest(testing.TestCase):\n    def test_config(self):\n        fn_obj = metrics.FalseNegatives(name=\"my_fn\", thresholds=[0.4, 0.9])\n        self.assertEqual(fn_obj.name, \"my_fn\")\n        self.assertLen(fn_obj.variables, 1)\n        self.assertEqual(fn_obj.thresholds, [0.4, 0.9])\n\n        # Check save and restore config\n        fn_obj2 = metrics.FalseNegatives.from_config(fn_obj.get_config())\n        self.assertEqual(fn_obj2.name, \"my_fn\")\n        self.assertLen(fn_obj2.variables, 1)\n        self.assertEqual(fn_obj2.thresholds, [0.4, 0.9])\n\n    def test_unweighted(self):\n        fn_obj = metrics.FalseNegatives()\n\n        y_true = np.array(\n            ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))\n        )\n        y_pred = np.array(\n            ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))\n        )\n\n        fn_obj.update_state(y_true, y_pred)\n        self.assertAllClose(3.0, fn_obj.result())\n\n    def test_weighted(self):\n        fn_obj = metrics.FalseNegatives()\n        y_true = np.array(\n            ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))\n        )\n        y_pred = np.array(\n            ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))\n        )\n        sample_weight = np.array((1.0, 1.5, 2.0, 2.5))\n        result = fn_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(5.0, result)\n\n    def test_unweighted_with_thresholds(self):\n        fn_obj = metrics.FalseNegatives(thresholds=[0.15, 0.5, 0.85])\n\n        y_pred = np.array(\n            (\n                (0.9, 0.2, 0.8, 0.1),\n                (0.2, 0.9, 0.7, 0.6),\n                (0.1, 0.2, 0.4, 0.3),\n                (0, 1, 0.7, 0.3),\n            )\n        )\n        y_true = np.array(\n            ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 
0, 0), (1, 1, 1, 1))\n        )\n\n        fn_obj.update_state(y_true, y_pred)\n        self.assertAllClose([1.0, 4.0, 6.0], fn_obj.result())\n\n    def test_weighted_with_thresholds(self):\n        fn_obj = metrics.FalseNegatives(thresholds=[0.15, 0.5, 0.85])\n\n        y_pred = np.array(\n            (\n                (0.9, 0.2, 0.8, 0.1),\n                (0.2, 0.9, 0.7, 0.6),\n                (0.1, 0.2, 0.4, 0.3),\n                (0, 1, 0.7, 0.3),\n            )\n        )\n        y_true = np.array(\n            ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1))\n        )\n        sample_weight = ((3.0,), (5.0,), (7.0,), (4.0,))\n\n        result = fn_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose([4.0, 16.0, 23.0], result)\n\n    def test_threshold_limit(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"Threshold values must be in \\[0, 1\\]. Received: \\[-1, 2\\]\",\n        ):\n            metrics.FalseNegatives(thresholds=[-1, 0.5, 2])\n\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"Threshold values must be in \\[0, 1\\]. 
Received: \\[None\\]\",\n        ):\n            metrics.FalseNegatives(thresholds=[None])\n\n\nclass TrueNegativesTest(testing.TestCase):\n    def test_config(self):\n        tn_obj = metrics.TrueNegatives(name=\"my_tn\", thresholds=[0.4, 0.9])\n        self.assertEqual(tn_obj.name, \"my_tn\")\n        self.assertLen(tn_obj.variables, 1)\n        self.assertEqual(tn_obj.thresholds, [0.4, 0.9])\n\n        # Check save and restore config\n        tn_obj2 = metrics.TrueNegatives.from_config(tn_obj.get_config())\n        self.assertEqual(tn_obj2.name, \"my_tn\")\n        self.assertLen(tn_obj2.variables, 1)\n        self.assertEqual(tn_obj2.thresholds, [0.4, 0.9])\n\n    def test_unweighted(self):\n        tn_obj = metrics.TrueNegatives()\n\n        y_true = np.array(\n            ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))\n        )\n        y_pred = np.array(\n            ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))\n        )\n\n        tn_obj.update_state(y_true, y_pred)\n        self.assertAllClose(3.0, tn_obj.result())\n\n    def test_weighted(self):\n        tn_obj = metrics.TrueNegatives()\n        y_true = np.array(\n            ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))\n        )\n        y_pred = np.array(\n            ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))\n        )\n        sample_weight = np.array((1.0, 1.5, 2.0, 2.5))\n        result = tn_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(4.0, result)\n\n    def test_unweighted_with_thresholds(self):\n        tn_obj = metrics.TrueNegatives(thresholds=[0.15, 0.5, 0.85])\n\n        y_pred = np.array(\n            (\n                (0.9, 0.2, 0.8, 0.1),\n                (0.2, 0.9, 0.7, 0.6),\n                (0.1, 0.2, 0.4, 0.3),\n                (0, 1, 0.7, 0.3),\n            )\n        )\n        y_true = np.array(\n            ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), 
(1, 1, 1, 1))\n        )\n\n        tn_obj.update_state(y_true, y_pred)\n        self.assertAllClose([2.0, 5.0, 7.0], tn_obj.result())\n\n    def test_weighted_with_thresholds(self):\n        tn_obj = metrics.TrueNegatives(thresholds=[0.15, 0.5, 0.85])\n\n        y_pred = np.array(\n            (\n                (0.9, 0.2, 0.8, 0.1),\n                (0.2, 0.9, 0.7, 0.6),\n                (0.1, 0.2, 0.4, 0.3),\n                (0, 1, 0.7, 0.3),\n            )\n        )\n        y_true = np.array(\n            ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1))\n        )\n        sample_weight = ((0.0, 2.0, 3.0, 5.0),)\n\n        result = tn_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose([5.0, 15.0, 23.0], result)\n\n    def test_threshold_limit(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"Threshold values must be in \\[0, 1\\]. Received: \\[-1, 2\\]\",\n        ):\n            metrics.TrueNegatives(thresholds=[-1, 0.5, 2])\n\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"Threshold values must be in \\[0, 1\\]. 
Received: \\[None\\]\",\n        ):\n            metrics.TrueNegatives(thresholds=[None])\n\n\nclass TruePositiveTest(testing.TestCase):\n    def test_config(self):\n        tp_obj = metrics.TruePositives(name=\"my_tp\", thresholds=[0.4, 0.9])\n        self.assertEqual(tp_obj.name, \"my_tp\")\n        self.assertLen(tp_obj.variables, 1)\n        self.assertEqual(tp_obj.thresholds, [0.4, 0.9])\n\n        # Check save and restore config\n        tp_obj2 = metrics.TruePositives.from_config(tp_obj.get_config())\n        self.assertEqual(tp_obj2.name, \"my_tp\")\n        self.assertLen(tp_obj2.variables, 1)\n        self.assertEqual(tp_obj2.thresholds, [0.4, 0.9])\n\n    def test_unweighted(self):\n        tp_obj = metrics.TruePositives()\n\n        y_true = np.array(\n            ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))\n        )\n        y_pred = np.array(\n            ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))\n        )\n\n        tp_obj.update_state(y_true, y_pred)\n        self.assertAllClose(7.0, tp_obj.result())\n\n    def test_weighted(self):\n        tp_obj = metrics.TruePositives()\n        y_true = np.array(\n            ((0, 1, 0, 1, 0), (0, 0, 1, 1, 1), (1, 1, 1, 1, 0), (0, 0, 0, 0, 1))\n        )\n        y_pred = np.array(\n            ((0, 0, 1, 1, 0), (1, 1, 1, 1, 1), (0, 1, 0, 1, 0), (1, 1, 1, 1, 1))\n        )\n        sample_weight = np.array((1.0, 1.5, 2.0, 2.5))\n        result = tp_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(12.0, result)\n\n    def test_unweighted_with_thresholds(self):\n        tp_obj = metrics.TruePositives(thresholds=[0.15, 0.5, 0.85])\n\n        y_pred = np.array(\n            (\n                (0.9, 0.2, 0.8, 0.1),\n                (0.2, 0.9, 0.7, 0.6),\n                (0.1, 0.2, 0.4, 0.3),\n                (0, 1, 0.7, 0.3),\n            )\n        )\n        y_true = np.array(\n            ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), 
(1, 1, 1, 1))\n        )\n\n        tp_obj.update_state(y_true, y_pred)\n        self.assertAllClose([6.0, 3.0, 1.0], tp_obj.result())\n\n    def test_weighted_with_thresholds(self):\n        tp_obj = metrics.TruePositives(thresholds=[0.15, 0.5, 0.85])\n\n        y_pred = np.array(\n            (\n                (0.9, 0.2, 0.8, 0.1),\n                (0.2, 0.9, 0.7, 0.6),\n                (0.1, 0.2, 0.4, 0.3),\n                (0, 1, 0.7, 0.3),\n            )\n        )\n        y_true = np.array(\n            ((0, 1, 1, 0), (1, 0, 0, 0), (0, 0, 0, 0), (1, 1, 1, 1))\n        )\n        sample_weight = 37.0\n\n        result = tp_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose([222.0, 111.0, 37.0], result)\n\n    def test_threshold_limit(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"Threshold values must be in \\[0, 1\\]. Received: \\[-1, 2\\]\",\n        ):\n            metrics.TruePositives(thresholds=[-1, 0.5, 2])\n\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"Threshold values must be in \\[0, 1\\]. 
Received: \\[None\\]\",\n        ):\n            metrics.TruePositives(thresholds=[None])\n\n\nclass PrecisionTest(testing.TestCase):\n    def test_config(self):\n        p_obj = metrics.Precision(\n            name=\"my_precision\", thresholds=[0.4, 0.9], top_k=15, class_id=12\n        )\n        self.assertEqual(p_obj.name, \"my_precision\")\n        self.assertLen(p_obj.variables, 2)\n        self.assertEqual(\n            [v.name for v in p_obj.variables],\n            [\"true_positives\", \"false_positives\"],\n        )\n        self.assertEqual(p_obj.thresholds, [0.4, 0.9])\n        self.assertEqual(p_obj.top_k, 15)\n        self.assertEqual(p_obj.class_id, 12)\n\n        # Check save and restore config\n        p_obj2 = metrics.Precision.from_config(p_obj.get_config())\n        self.assertEqual(p_obj2.name, \"my_precision\")\n        self.assertLen(p_obj2.variables, 2)\n        self.assertEqual(p_obj2.thresholds, [0.4, 0.9])\n        self.assertEqual(p_obj2.top_k, 15)\n        self.assertEqual(p_obj2.class_id, 12)\n\n    def test_unweighted(self):\n        p_obj = metrics.Precision()\n        y_pred = np.array([1, 0, 1, 0])\n        y_true = np.array([0, 1, 1, 0])\n        result = p_obj(y_true, y_pred)\n        self.assertAlmostEqual(0.5, result)\n\n    def test_unweighted_all_incorrect(self):\n        p_obj = metrics.Precision(thresholds=[0.5])\n        inputs = np.random.randint(0, 2, size=(100, 1))\n        y_pred = np.array(inputs)\n        y_true = np.array(1 - inputs)\n        result = p_obj(y_true, y_pred)\n        self.assertAlmostEqual(0, result)\n\n    def test_weighted(self):\n        p_obj = metrics.Precision()\n        y_pred = np.array([[1, 0, 1, 0], [1, 0, 1, 0]])\n        y_true = np.array([[0, 1, 1, 0], [1, 0, 0, 1]])\n        result = p_obj(\n            y_true,\n            y_pred,\n            sample_weight=np.array([[1, 2, 3, 4], [4, 3, 2, 1]]),\n        )\n        weighted_tp = 3.0 + 4.0\n        weighted_positives = (1.0 + 3.0) + 
(4.0 + 2.0)\n        expected_precision = weighted_tp / weighted_positives\n        self.assertAlmostEqual(expected_precision, result)\n\n    def test_div_by_zero(self):\n        p_obj = metrics.Precision()\n        y_pred = np.array([0, 0, 0, 0])\n        y_true = np.array([0, 0, 0, 0])\n        result = p_obj(y_true, y_pred)\n        self.assertEqual(0, result)\n\n    def test_unweighted_with_threshold(self):\n        p_obj = metrics.Precision(thresholds=[0.5, 0.7])\n        y_pred = np.array([1, 0, 0.6, 0])\n        y_true = np.array([0, 1, 1, 0])\n        result = p_obj(y_true, y_pred)\n        self.assertAlmostEqual([0.5, 0.0], result, 0)\n\n    def test_weighted_with_threshold(self):\n        p_obj = metrics.Precision(thresholds=[0.5, 1.0])\n        y_true = np.array([[0, 1], [1, 0]])\n        y_pred = np.array([[1, 0], [0.6, 0]], dtype=\"float32\")\n        weights = np.array([[4, 0], [3, 1]], dtype=\"float32\")\n        result = p_obj(y_true, y_pred, sample_weight=weights)\n        weighted_tp = 0 + 3.0\n        weighted_positives = (0 + 3.0) + (4.0 + 0.0)\n        expected_precision = weighted_tp / weighted_positives\n        self.assertAlmostEqual([expected_precision, 0], result, 1e-3)\n\n    def test_multiple_updates(self):\n        p_obj = metrics.Precision(thresholds=[0.5, 1.0])\n        y_true = np.array([[0, 1], [1, 0]])\n        y_pred = np.array([[1, 0], [0.6, 0]], dtype=\"float32\")\n        weights = np.array([[4, 0], [3, 1]], dtype=\"float32\")\n        for _ in range(2):\n            p_obj.update_state(y_true, y_pred, sample_weight=weights)\n\n        weighted_tp = (0 + 3.0) + (0 + 3.0)\n        weighted_positives = ((0 + 3.0) + (4.0 + 0.0)) + (\n            (0 + 3.0) + (4.0 + 0.0)\n        )\n        expected_precision = weighted_tp / weighted_positives\n        self.assertAlmostEqual([expected_precision, 0], p_obj.result(), 1e-3)\n\n    def test_unweighted_top_k(self):\n        p_obj = metrics.Precision(top_k=3)\n        y_pred = 
np.array([0.2, 0.1, 0.5, 0, 0.2])\n        y_true = np.array([0, 1, 1, 0, 0])\n        result = p_obj(y_true, y_pred)\n        self.assertAlmostEqual(1.0 / 3, result)\n\n    def test_weighted_top_k(self):\n        p_obj = metrics.Precision(top_k=3)\n        y_pred1 = np.array([[0.2, 0.1, 0.4, 0, 0.2]])\n        y_true1 = np.array([[0, 1, 1, 0, 1]])\n        p_obj(y_true1, y_pred1, sample_weight=np.array([[1, 4, 2, 3, 5]]))\n\n        y_pred2 = np.array([0.2, 0.6, 0.4, 0.2, 0.2])\n        y_true2 = np.array([1, 0, 1, 1, 1])\n        result = p_obj(y_true2, y_pred2, sample_weight=np.array(3))\n\n        tp = (2 + 5) + (3 + 3)\n        predicted_positives = (1 + 2 + 5) + (3 + 3 + 3)\n        expected_precision = tp / predicted_positives\n        self.assertAlmostEqual(expected_precision, result)\n\n    def test_unweighted_class_id_should_throw_error_1d(self):\n        p_obj = metrics.Precision(class_id=2)\n\n        y_pred = np.array([0.2, 0.1, 0.6, 0, 0.2])\n        y_true = np.array([0, 1, 1, 0, 0])\n\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"When class_id is provided, y_pred must be a 2D array \"\n            r\"with shape \\(num_samples, num_classes\\), found shape:.*\",\n        ):\n            p_obj(y_true, y_pred)\n\n    def test_unweighted_class_id_multiclass(self):\n        p_obj = metrics.Precision(class_id=1)\n\n        y_pred = np.array(\n            [\n                [0.1, 0.2, 0.7],\n                [0.5, 0.3, 0.2],\n                [0.2, 0.6, 0.2],\n                [0.7, 0.2, 0.1],\n                [0.1, 0.1, 0.8],\n            ]\n        )\n\n        y_true = np.array(\n            [\n                [0.0, 0.0, 1.0],\n                [1.0, 0.0, 0.0],\n                [0.0, 1.0, 0.0],\n                [1.0, 0.0, 0.0],\n                [0.0, 0.0, 1.0],\n            ]\n        )\n\n        result = p_obj(y_true, y_pred)\n        self.assertAlmostEqual(1.0, result)\n        self.assertAlmostEqual(1.0, 
p_obj.true_positives)\n        self.assertAlmostEqual(0.0, p_obj.false_positives)\n\n    def test_unweighted_top_k_and_threshold(self):\n        p_obj = metrics.Precision(thresholds=0.7, top_k=2)\n\n        y_pred = np.array([0.2, 0.8, 0.6, 0, 0.2])\n        y_true = np.array([0, 1, 1, 0, 1])\n        result = p_obj(y_true, y_pred)\n        self.assertAlmostEqual(1, result)\n        self.assertAlmostEqual(1, p_obj.true_positives)\n        self.assertAlmostEqual(0, p_obj.false_positives)\n\n\nclass RecallTest(testing.TestCase):\n    def test_config(self):\n        r_obj = metrics.Recall(\n            name=\"my_recall\", thresholds=[0.4, 0.9], top_k=15, class_id=12\n        )\n        self.assertEqual(r_obj.name, \"my_recall\")\n        self.assertLen(r_obj.variables, 2)\n        self.assertEqual(\n            [v.name for v in r_obj.variables],\n            [\"true_positives\", \"false_negatives\"],\n        )\n        self.assertEqual(r_obj.thresholds, [0.4, 0.9])\n        self.assertEqual(r_obj.top_k, 15)\n        self.assertEqual(r_obj.class_id, 12)\n\n        # Check save and restore config\n        r_obj2 = metrics.Recall.from_config(r_obj.get_config())\n        self.assertEqual(r_obj2.name, \"my_recall\")\n        self.assertLen(r_obj2.variables, 2)\n        self.assertEqual(r_obj2.thresholds, [0.4, 0.9])\n        self.assertEqual(r_obj2.top_k, 15)\n        self.assertEqual(r_obj2.class_id, 12)\n\n    def test_unweighted(self):\n        r_obj = metrics.Recall()\n        y_pred = np.array([1, 0, 1, 0])\n        y_true = np.array([0, 1, 1, 0])\n        self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))\n\n    def test_unweighted_all_incorrect(self):\n        r_obj = metrics.Recall(thresholds=[0.5])\n        inputs = np.random.randint(0, 2, size=(100, 1))\n        y_pred = np.array(inputs)\n        y_true = np.array(1 - inputs)\n        self.assertAlmostEqual(0, r_obj(y_true, y_pred))\n\n    def test_weighted(self):\n        r_obj = metrics.Recall()\n        
y_pred = np.array([[1, 0, 1, 0], [0, 1, 0, 1]])\n        y_true = np.array([[0, 1, 1, 0], [1, 0, 0, 1]])\n        result = r_obj(\n            y_true,\n            y_pred,\n            sample_weight=np.array([[1, 2, 3, 4], [4, 3, 2, 1]]),\n        )\n        weighted_tp = 3.0 + 1.0\n        weighted_t = (2.0 + 3.0) + (4.0 + 1.0)\n        expected_recall = weighted_tp / weighted_t\n        self.assertAlmostEqual(expected_recall, result)\n\n    def test_div_by_zero(self):\n        r_obj = metrics.Recall()\n        y_pred = np.array([0, 0, 0, 0])\n        y_true = np.array([0, 0, 0, 0])\n        self.assertEqual(0, r_obj(y_true, y_pred))\n\n    def test_unweighted_with_threshold(self):\n        r_obj = metrics.Recall(thresholds=[0.5, 0.7])\n        y_pred = np.array([1, 0, 0.6, 0])\n        y_true = np.array([0, 1, 1, 0])\n        self.assertAllClose([0.5, 0.0], r_obj(y_true, y_pred), 0)\n\n    def test_weighted_with_threshold(self):\n        r_obj = metrics.Recall(thresholds=[0.5, 1.0])\n        y_true = np.array([[0, 1], [1, 0]])\n        y_pred = np.array([[1, 0], [0.6, 0]], dtype=\"float32\")\n        weights = np.array([[1, 4], [3, 2]], dtype=\"float32\")\n        result = r_obj(y_true, y_pred, sample_weight=weights)\n        weighted_tp = 0 + 3.0\n        weighted_positives = (0 + 3.0) + (4.0 + 0.0)\n        expected_recall = weighted_tp / weighted_positives\n        self.assertAllClose([expected_recall, 0], result, 1e-3)\n\n    def test_multiple_updates(self):\n        r_obj = metrics.Recall(thresholds=[0.5, 1.0])\n        y_true = np.array([[0, 1], [1, 0]])\n        y_pred = np.array([[1, 0], [0.6, 0]], dtype=\"float32\")\n        weights = np.array([[1, 4], [3, 2]], dtype=\"float32\")\n        for _ in range(2):\n            r_obj.update_state(y_true, y_pred, sample_weight=weights)\n\n        weighted_tp = (0 + 3.0) + (0 + 3.0)\n        weighted_positives = ((0 + 3.0) + (4.0 + 0.0)) + (\n            (0 + 3.0) + (4.0 + 0.0)\n        )\n        expected_recall 
= weighted_tp / weighted_positives\n        self.assertAllClose([expected_recall, 0], r_obj.result(), 1e-3)\n\n    def test_unweighted_top_k(self):\n        r_obj = metrics.Recall(top_k=3)\n        y_pred = np.array([0.2, 0.1, 0.5, 0, 0.2])\n        y_true = np.array([0, 1, 1, 0, 0])\n        self.assertAlmostEqual(0.5, r_obj(y_true, y_pred))\n\n    def test_weighted_top_k(self):\n        r_obj = metrics.Recall(top_k=3)\n        y_pred1 = np.array([[0.2, 0.1, 0.4, 0, 0.2]])\n        y_true1 = np.array([[0, 1, 1, 0, 1]])\n        r_obj(y_true1, y_pred1, sample_weight=np.array([[1, 4, 2, 3, 5]]))\n\n        y_pred2 = np.array([0.2, 0.6, 0.4, 0.2, 0.2])\n        y_true2 = np.array([1, 0, 1, 1, 1])\n        result = r_obj(y_true2, y_pred2, sample_weight=np.array(3))\n\n        tp = (2 + 5) + (3 + 3)\n        positives = (4 + 2 + 5) + (3 + 3 + 3 + 3)\n        expected_recall = tp / positives\n        self.assertAlmostEqual(expected_recall, result)\n\n    def test_unweighted_class_id_should_throw_error_1d(self):\n        r_obj = metrics.Recall(class_id=2)\n\n        y_pred = np.array([0.2, 0.1, 0.6, 0, 0.2])\n        y_true = np.array([0, 1, 1, 0, 0])\n\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"When class_id is provided, y_pred must be a 2D array \"\n            r\"with shape \\(num_samples, num_classes\\), found shape:.*\",\n        ):\n            r_obj(y_true, y_pred)\n\n    def test_unweighted_class_id_multiclass(self):\n        r_obj = metrics.Recall(class_id=1)\n\n        y_pred = np.array(\n            [\n                [0.1, 0.2, 0.7],\n                [0.5, 0.3, 0.2],\n                [0.2, 0.6, 0.2],\n                [0.7, 0.2, 0.1],\n                [0.1, 0.1, 0.8],\n            ]\n        )\n\n        y_true = np.array(\n            [\n                [0.0, 0.0, 1.0],\n                [1.0, 0.0, 0.0],\n                [0.0, 1.0, 0.0],\n                [1.0, 0.0, 0.0],\n                [0.0, 0.0, 1.0],\n            ]\n     
   )\n\n        result = r_obj(y_true, y_pred)\n        self.assertAlmostEqual(1.0, result)\n        self.assertAlmostEqual(1.0, r_obj.true_positives)\n        self.assertAlmostEqual(0.0, r_obj.false_negatives)\n\n    def test_unweighted_top_k_and_threshold(self):\n        r_obj = metrics.Recall(thresholds=0.7, top_k=2)\n\n        y_pred = np.array([0.2, 0.8, 0.6, 0, 0.2])\n        y_true = np.array([1, 1, 1, 0, 1])\n        self.assertAlmostEqual(0.25, r_obj(y_true, y_pred))\n        self.assertAlmostEqual(1, r_obj.true_positives)\n        self.assertAlmostEqual(3, r_obj.false_negatives)\n\n\nclass SensitivityAtSpecificityTest(testing.TestCase):\n    def test_config(self):\n        s_obj = metrics.SensitivityAtSpecificity(\n            0.4,\n            num_thresholds=100,\n            class_id=12,\n            name=\"sensitivity_at_specificity_1\",\n        )\n        self.assertEqual(s_obj.name, \"sensitivity_at_specificity_1\")\n        self.assertLen(s_obj.variables, 4)\n        self.assertEqual(s_obj.specificity, 0.4)\n        self.assertEqual(s_obj.num_thresholds, 100)\n        self.assertEqual(s_obj.class_id, 12)\n\n        # Check save and restore config\n        s_obj2 = metrics.SensitivityAtSpecificity.from_config(\n            s_obj.get_config()\n        )\n        self.assertEqual(s_obj2.name, \"sensitivity_at_specificity_1\")\n        self.assertLen(s_obj2.variables, 4)\n        self.assertEqual(s_obj2.specificity, 0.4)\n        self.assertEqual(s_obj2.num_thresholds, 100)\n        self.assertEqual(s_obj2.class_id, 12)\n\n    def test_unweighted_all_correct(self):\n        s_obj = metrics.SensitivityAtSpecificity(0.7)\n        inputs = np.random.randint(0, 2, size=(100, 1))\n        y_pred = np.array(inputs, dtype=\"float32\")\n        y_true = np.array(inputs)\n        self.assertAlmostEqual(1, s_obj(y_true, y_pred))\n\n    def test_unweighted_high_specificity(self):\n        s_obj = metrics.SensitivityAtSpecificity(0.8)\n        pred_values = [0.0, 
0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]\n        label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = np.array(label_values)\n\n        self.assertAlmostEqual(0.8, s_obj(y_true, y_pred))\n\n    def test_unweighted_low_specificity(self):\n        s_obj = metrics.SensitivityAtSpecificity(0.4)\n        pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]\n        label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = np.array(label_values)\n\n        self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))\n\n    def test_unweighted_class_id(self):\n        s_obj = metrics.SensitivityAtSpecificity(0.4, class_id=2)\n        pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]\n        label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]\n\n        y_pred = ops.transpose(np.array([pred_values] * 3))\n        y_true = ops.one_hot(np.array(label_values), num_classes=3)\n\n        self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))\n\n    @parameterized.parameters([\"bool\", \"int32\", \"float32\"])\n    def test_weighted(self, label_dtype):\n        s_obj = metrics.SensitivityAtSpecificity(0.4)\n        pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]\n        label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n        weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = ops.cast(label_values, dtype=label_dtype)\n        weights = np.array(weight_values)\n\n        result = s_obj(y_true, y_pred, sample_weight=weights)\n        self.assertAlmostEqual(0.675, result)\n\n    def test_invalid_specificity(self):\n        with self.assertRaisesRegex(\n            ValueError, r\"`specificity` must be in the range \\[0, 1\\].\"\n        ):\n            metrics.SensitivityAtSpecificity(-1)\n\n    def 
test_invalid_num_thresholds(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `num_thresholds` must be an integer > 0\"\n        ):\n            metrics.SensitivityAtSpecificity(0.4, num_thresholds=-1)\n\n    @pytest.mark.requires_trainable_backend\n    def test_handles_sas_metrics(self):\n        # Test for https://github.com/keras-team/keras/issues/19376\n        model = models.Sequential(\n            [\n                layers.Input((1,)),\n                layers.Dense(1),\n            ]\n        )\n        sas = metrics.SpecificityAtSensitivity(0.5, name=\"sas\")\n\n        model.compile(optimizer=\"adam\", loss=\"crossentropy\", metrics=[sas])\n        model.fit(np.ones((5, 1)), np.ones((5, 1)))\n\n\nclass SpecificityAtSensitivityTest(testing.TestCase):\n    def test_config(self):\n        s_obj = metrics.SpecificityAtSensitivity(\n            0.4,\n            num_thresholds=100,\n            class_id=12,\n            name=\"specificity_at_sensitivity_1\",\n        )\n        self.assertEqual(s_obj.name, \"specificity_at_sensitivity_1\")\n        self.assertLen(s_obj.variables, 4)\n        self.assertEqual(s_obj.sensitivity, 0.4)\n        self.assertEqual(s_obj.num_thresholds, 100)\n        self.assertEqual(s_obj.class_id, 12)\n\n        # Check save and restore config\n        s_obj2 = metrics.SpecificityAtSensitivity.from_config(\n            s_obj.get_config()\n        )\n        self.assertEqual(s_obj2.name, \"specificity_at_sensitivity_1\")\n        self.assertLen(s_obj2.variables, 4)\n        self.assertEqual(s_obj2.sensitivity, 0.4)\n        self.assertEqual(s_obj2.num_thresholds, 100)\n        self.assertEqual(s_obj2.class_id, 12)\n\n    def test_unweighted_all_correct(self):\n        s_obj = metrics.SpecificityAtSensitivity(0.7)\n        inputs = np.random.randint(0, 2, size=(100, 1))\n        y_pred = np.array(inputs, dtype=\"float32\")\n        y_true = np.array(inputs)\n\n        self.assertAlmostEqual(1, 
s_obj(y_true, y_pred))\n\n    def test_unweighted_high_sensitivity(self):\n        s_obj = metrics.SpecificityAtSensitivity(1.0)\n        pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]\n        label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = np.array(label_values)\n\n        self.assertAlmostEqual(0.2, s_obj(y_true, y_pred))\n\n    def test_unweighted_low_sensitivity(self):\n        s_obj = metrics.SpecificityAtSensitivity(0.4)\n        pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]\n        label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = np.array(label_values)\n\n        self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))\n\n    def test_unweighted_class_id(self):\n        s_obj = metrics.SpecificityAtSensitivity(0.4, class_id=2)\n        pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]\n        label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]\n\n        y_pred = ops.transpose(np.array([pred_values] * 3))\n        y_true = ops.one_hot(np.array(label_values), num_classes=3)\n\n        self.assertAlmostEqual(0.6, s_obj(y_true, y_pred))\n\n    @parameterized.parameters([\"bool\", \"int32\", \"float32\"])\n    def test_weighted(self, label_dtype):\n        s_obj = metrics.SpecificityAtSensitivity(0.4)\n        pred_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]\n        label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n        weight_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = ops.cast(label_values, dtype=label_dtype)\n        weights = np.array(weight_values)\n\n        result = s_obj(y_true, y_pred, sample_weight=weights)\n        self.assertAlmostEqual(0.4, result)\n\n    def test_invalid_sensitivity(self):\n        with self.assertRaisesRegex(\n            
ValueError, r\"`sensitivity` must be in the range \\[0, 1\\].\"\n        ):\n            metrics.SpecificityAtSensitivity(-1)\n\n    def test_invalid_num_thresholds(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `num_thresholds` must be an integer > 0\"\n        ):\n            metrics.SpecificityAtSensitivity(0.4, num_thresholds=-1)\n\n\nclass PrecisionAtRecallTest(testing.TestCase):\n    def test_config(self):\n        s_obj = metrics.PrecisionAtRecall(\n            0.4, num_thresholds=100, class_id=12, name=\"precision_at_recall_1\"\n        )\n        self.assertEqual(s_obj.name, \"precision_at_recall_1\")\n        self.assertLen(s_obj.variables, 4)\n        self.assertEqual(s_obj.recall, 0.4)\n        self.assertEqual(s_obj.num_thresholds, 100)\n        self.assertEqual(s_obj.class_id, 12)\n\n        # Check save and restore config\n        s_obj2 = metrics.PrecisionAtRecall.from_config(s_obj.get_config())\n        self.assertEqual(s_obj2.name, \"precision_at_recall_1\")\n        self.assertLen(s_obj2.variables, 4)\n        self.assertEqual(s_obj2.recall, 0.4)\n        self.assertEqual(s_obj2.num_thresholds, 100)\n        self.assertEqual(s_obj2.class_id, 12)\n\n    def test_unweighted_all_correct(self):\n        s_obj = metrics.PrecisionAtRecall(0.7)\n        inputs = np.random.randint(0, 2, size=(100, 1))\n        y_pred = np.array(inputs, dtype=\"float32\")\n        y_true = np.array(inputs)\n\n        self.assertAlmostEqual(1, s_obj(y_true, y_pred))\n\n    def test_unweighted_high_recall(self):\n        s_obj = metrics.PrecisionAtRecall(0.8)\n        pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]\n        label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = np.array(label_values)\n\n        # For 0.5 < decision threshold < 0.6.\n        self.assertAlmostEqual(2.0 / 3, s_obj(y_true, y_pred))\n\n    def 
test_unweighted_low_recall(self):\n        s_obj = metrics.PrecisionAtRecall(0.6)\n        pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]\n        label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = np.array(label_values)\n\n        # For 0.2 < decision threshold < 0.5.\n        self.assertAlmostEqual(0.75, s_obj(y_true, y_pred))\n\n    def test_unweighted_class_id(self):\n        s_obj = metrics.PrecisionAtRecall(0.6, class_id=2)\n        pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]\n        label_values = [0, 0, 0, 0, 0, 2, 2, 2, 2, 2]\n\n        y_pred = ops.transpose(np.array([pred_values] * 3))\n        y_true = ops.one_hot(np.array(label_values), num_classes=3)\n\n        # For 0.2 < decision threshold < 0.5.\n        self.assertAlmostEqual(0.75, s_obj(y_true, y_pred))\n\n    @parameterized.parameters([\"bool\", \"int32\", \"float32\"])\n    def test_weighted(self, label_dtype):\n        s_obj = metrics.PrecisionAtRecall(7.0 / 8)\n        pred_values = [0.0, 0.1, 0.2, 0.5, 0.6, 0.2, 0.5, 0.6, 0.8, 0.9]\n        label_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]\n        weight_values = [2, 1, 2, 1, 2, 1, 2, 2, 1, 2]\n\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = ops.cast(label_values, dtype=label_dtype)\n        weights = np.array(weight_values)\n\n        result = s_obj(y_true, y_pred, sample_weight=weights)\n        # For 0.0 < decision threshold < 0.2.\n        self.assertAlmostEqual(0.7, result)\n\n    def test_invalid_recall(self):\n        with self.assertRaisesRegex(\n            ValueError, r\"`recall` must be in the range \\[0, 1\\].\"\n        ):\n            metrics.PrecisionAtRecall(-1)\n\n    def test_invalid_num_thresholds(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `num_thresholds` must be an integer > 0\"\n        ):\n            metrics.PrecisionAtRecall(0.4, 
num_thresholds=-1)\n\n\nclass RecallAtPrecisionTest(testing.TestCase):\n    def test_config(self):\n        s_obj = metrics.RecallAtPrecision(\n            0.4, num_thresholds=100, class_id=12, name=\"recall_at_precision_1\"\n        )\n        self.assertEqual(s_obj.name, \"recall_at_precision_1\")\n        self.assertLen(s_obj.variables, 4)\n        self.assertEqual(s_obj.precision, 0.4)\n        self.assertEqual(s_obj.num_thresholds, 100)\n        self.assertEqual(s_obj.class_id, 12)\n\n        # Check save and restore config\n        s_obj2 = metrics.RecallAtPrecision.from_config(s_obj.get_config())\n        self.assertEqual(s_obj2.name, \"recall_at_precision_1\")\n        self.assertLen(s_obj2.variables, 4)\n        self.assertEqual(s_obj2.precision, 0.4)\n        self.assertEqual(s_obj2.num_thresholds, 100)\n        self.assertEqual(s_obj2.class_id, 12)\n\n    def test_unweighted_all_correct(self):\n        s_obj = metrics.RecallAtPrecision(0.7)\n        inputs = np.random.randint(0, 2, size=(100, 1))\n        y_pred = np.array(inputs, dtype=\"float32\")\n        y_true = np.array(inputs)\n\n        self.assertAlmostEqual(1, s_obj(y_true, y_pred))\n\n    def test_unweighted_high_precision(self):\n        s_obj = metrics.RecallAtPrecision(0.75)\n        pred_values = [\n            0.05,\n            0.1,\n            0.2,\n            0.3,\n            0.3,\n            0.35,\n            0.4,\n            0.45,\n            0.5,\n            0.6,\n            0.9,\n            0.95,\n        ]\n        label_values = [0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1]\n        # precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2,\n        # 1].\n        # recalls:    [1,   1,    5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,\n        # 1/6].\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = np.array(label_values)\n\n        # The precision 0.75 can be reached at thresholds 0.4<=t<0.45.\n        self.assertAlmostEqual(0.5, 
s_obj(y_true, y_pred))\n\n    def test_unweighted_low_precision(self):\n        s_obj = metrics.RecallAtPrecision(2.0 / 3)\n        pred_values = [\n            0.05,\n            0.1,\n            0.2,\n            0.3,\n            0.3,\n            0.35,\n            0.4,\n            0.45,\n            0.5,\n            0.6,\n            0.9,\n            0.95,\n        ]\n        label_values = [0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1]\n        # precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2,\n        # 1].\n        # recalls:    [1,   1,    5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,\n        # 1/6].\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = np.array(label_values)\n\n        # The precision 5/7 can be reached at thresholds 0.3<=t<0.35.\n        self.assertAlmostEqual(5.0 / 6, s_obj(y_true, y_pred))\n\n    def test_unweighted_class_id(self):\n        s_obj = metrics.RecallAtPrecision(2.0 / 3, class_id=2)\n        pred_values = [\n            0.05,\n            0.1,\n            0.2,\n            0.3,\n            0.3,\n            0.35,\n            0.4,\n            0.45,\n            0.5,\n            0.6,\n            0.9,\n            0.95,\n        ]\n        label_values = [0, 2, 0, 0, 0, 2, 2, 0, 2, 2, 0, 2]\n        # precisions: [1/2, 6/11, 1/2, 5/9, 5/8, 5/7, 2/3, 3/5, 3/5, 2/3, 1/2,\n        # 1].\n        # recalls:    [1,   1,    5/6, 5/6, 5/6, 5/6, 2/3, 1/2, 1/2, 1/3, 1/6,\n        # 1/6].\n        y_pred = ops.transpose(np.array([pred_values] * 3))\n        y_true = ops.one_hot(np.array(label_values), num_classes=3)\n\n        # The precision 5/7 can be reached at thresholds 0.3<=t<0.35.\n        self.assertAlmostEqual(5.0 / 6, s_obj(y_true, y_pred))\n\n    @parameterized.parameters([\"bool\", \"int32\", \"float32\"])\n    def test_weighted(self, label_dtype):\n        s_obj = metrics.RecallAtPrecision(0.75)\n        pred_values = [0.1, 0.2, 0.3, 0.5, 0.6, 0.9, 0.9]\n        label_values = 
[0, 1, 0, 0, 0, 1, 1]\n        weight_values = [1, 2, 1, 2, 1, 2, 1]\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = ops.cast(label_values, dtype=label_dtype)\n        weights = np.array(weight_values)\n\n        result = s_obj(y_true, y_pred, sample_weight=weights)\n        self.assertAlmostEqual(0.6, result)\n\n    def test_unachievable_precision(self):\n        s_obj = metrics.RecallAtPrecision(2.0 / 3)\n        pred_values = [0.1, 0.2, 0.3, 0.9]\n        label_values = [1, 1, 0, 0]\n        y_pred = np.array(pred_values, dtype=\"float32\")\n        y_true = np.array(label_values)\n\n        # The highest possible precision is 1/2 which is below the required\n        # value, expect 0 recall.\n        self.assertAlmostEqual(0, s_obj(y_true, y_pred))\n\n    def test_invalid_precision(self):\n        with self.assertRaisesRegex(\n            ValueError, r\"`precision` must be in the range \\[0, 1\\].\"\n        ):\n            metrics.RecallAtPrecision(-1)\n\n    def test_invalid_num_thresholds(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `num_thresholds` must be an integer > 0\"\n        ):\n            metrics.RecallAtPrecision(0.4, num_thresholds=-1)\n\n    @pytest.mark.requires_trainable_backend\n    def test_end_to_end(self):\n        # Test for https://github.com/keras-team/keras/issues/718\n        model = models.Sequential(\n            [\n                layers.Input((1,)),\n                layers.Dense(1),\n            ]\n        )\n        model.compile(\n            optimizer=\"rmsprop\", loss=\"mse\", metrics=[metrics.Precision()]\n        )\n        model.fit(np.ones((5, 1)), np.ones((5, 1)))\n\n\nclass AUCTest(testing.TestCase):\n    def setUp(self):\n        self.num_thresholds = 3\n        self.y_pred = np.array([0, 0.5, 0.3, 0.9], dtype=\"float32\")\n        self.y_pred_multi_label = np.array(\n            [[0.0, 0.4], [0.5, 0.7], [0.3, 0.2], [0.9, 0.3]], dtype=\"float32\"\n       
 )\n        epsilon = 1e-12\n        self.y_pred_logits = -ops.log(1.0 / (self.y_pred + epsilon) - 1.0)\n        self.y_true = np.array([0, 0, 1, 1])\n        self.y_true_multi_label = np.array([[0, 0], [1, 1], [1, 1], [1, 0]])\n        self.sample_weight = [1, 2, 3, 4]\n\n        # threshold values are [0 - 1e-7, 0.5, 1 + 1e-7]\n        # y_pred when threshold = 0 - 1e-7  : [1, 1, 1, 1]\n        # y_pred when threshold = 0.5       : [0, 0, 0, 1]\n        # y_pred when threshold = 1 + 1e-7  : [0, 0, 0, 0]\n\n        # without sample_weight:\n        # tp = np.sum([[0, 0, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]], axis=1)\n        # fp = np.sum([[1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], axis=1)\n        # fn = np.sum([[0, 0, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]], axis=1)\n        # tn = np.sum([[0, 0, 0, 0], [1, 1, 0, 0], [1, 1, 0, 0]], axis=1)\n\n        # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]\n\n        # with sample_weight:\n        # tp = np.sum([[0, 0, 3, 4], [0, 0, 0, 4], [0, 0, 0, 0]], axis=1)\n        # fp = np.sum([[1, 2, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]], axis=1)\n        # fn = np.sum([[0, 0, 0, 0], [0, 0, 3, 0], [0, 0, 3, 4]], axis=1)\n        # tn = np.sum([[0, 0, 0, 0], [1, 2, 0, 0], [1, 2, 0, 0]], axis=1)\n\n        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]\n\n    def test_config(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=100,\n            curve=\"PR\",\n            summation_method=\"majoring\",\n            name=\"auc_1\",\n            dtype=\"float64\",\n            multi_label=True,\n            num_labels=2,\n            from_logits=True,\n        )\n        auc_obj.update_state(self.y_true_multi_label, self.y_pred_multi_label)\n        self.assertEqual(auc_obj.name, \"auc_1\")\n        self.assertEqual(auc_obj._dtype, \"float64\")\n        self.assertLen(auc_obj.variables, 4)\n        self.assertEqual(auc_obj.num_thresholds, 100)\n        self.assertEqual(auc_obj.curve, 
metrics_utils.AUCCurve.PR)\n        self.assertEqual(\n            auc_obj.summation_method, metrics_utils.AUCSummationMethod.MAJORING\n        )\n        self.assertTrue(auc_obj.multi_label)\n        self.assertEqual(auc_obj.num_labels, 2)\n        self.assertTrue(auc_obj._from_logits)\n        old_config = auc_obj.get_config()\n        self.assertNotIn(\"thresholds\", old_config)\n        self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))\n\n        # Check save and restore config.\n        auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())\n        auc_obj2.update_state(self.y_true_multi_label, self.y_pred_multi_label)\n        self.assertEqual(auc_obj2.name, \"auc_1\")\n        self.assertLen(auc_obj2.variables, 4)\n        self.assertEqual(auc_obj2.num_thresholds, 100)\n        self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)\n        self.assertEqual(\n            auc_obj2.summation_method, metrics_utils.AUCSummationMethod.MAJORING\n        )\n        self.assertTrue(auc_obj2.multi_label)\n        self.assertEqual(auc_obj2.num_labels, 2)\n        self.assertTrue(auc_obj2._from_logits)\n        new_config = auc_obj2.get_config()\n        self.assertNotIn(\"thresholds\", new_config)\n        self.assertDictEqual(old_config, new_config)\n        self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)\n\n    def test_config_manual_thresholds(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=None,\n            curve=\"PR\",\n            summation_method=\"majoring\",\n            name=\"auc_1\",\n            thresholds=[0.3, 0.5],\n        )\n        auc_obj.update_state(self.y_true, self.y_pred)\n        self.assertEqual(auc_obj.name, \"auc_1\")\n        self.assertLen(auc_obj.variables, 4)\n        self.assertEqual(auc_obj.num_thresholds, 4)\n        self.assertAllClose(auc_obj.thresholds, [0.0, 0.3, 0.5, 1.0])\n        self.assertEqual(auc_obj.curve, metrics_utils.AUCCurve.PR)\n        self.assertEqual(\n 
           auc_obj.summation_method, metrics_utils.AUCSummationMethod.MAJORING\n        )\n        old_config = auc_obj.get_config()\n        self.assertDictEqual(old_config, json.loads(json.dumps(old_config)))\n\n        # Check save and restore config.\n        auc_obj2 = metrics.AUC.from_config(auc_obj.get_config())\n        auc_obj2.update_state(self.y_true, self.y_pred)\n        self.assertEqual(auc_obj2.name, \"auc_1\")\n        self.assertLen(auc_obj2.variables, 4)\n        self.assertEqual(auc_obj2.num_thresholds, 4)\n        self.assertEqual(auc_obj2.curve, metrics_utils.AUCCurve.PR)\n        self.assertEqual(\n            auc_obj2.summation_method, metrics_utils.AUCSummationMethod.MAJORING\n        )\n        new_config = auc_obj2.get_config()\n        self.assertDictEqual(old_config, new_config)\n        self.assertAllClose(auc_obj.thresholds, auc_obj2.thresholds)\n\n    def test_unweighted_all_correct(self):\n        auc_obj = metrics.AUC()\n        self.assertEqual(auc_obj(self.y_true, self.y_true), 1)\n\n    def test_unweighted(self):\n        auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)\n        result = auc_obj(self.y_true, self.y_pred)\n\n        # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]\n        # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]\n        # fp_rate = [2/2, 0, 0] = [1, 0, 0]\n        # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]\n        # widths = [(1 - 0), (0 - 0)] = [1, 0]\n        expected_result = 0.75 * 1 + 0.25 * 0\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_unweighted_from_logits(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds, from_logits=True\n        )\n        result = auc_obj(self.y_true, self.y_pred_logits)\n\n        # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]\n        # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]\n        # fp_rate = [2/2, 0, 0] = [1, 0, 0]\n        # heights = [(1 + 0.5)/2, 
(0.5 + 0)/2] = [0.75, 0.25]\n        # widths = [(1 - 0), (0 - 0)] = [1, 0]\n        expected_result = 0.75 * 1 + 0.25 * 0\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_manual_thresholds(self):\n        # Verify that when specified, thresholds are used instead of\n        # num_thresholds.\n        auc_obj = metrics.AUC(num_thresholds=2, thresholds=[0.5])\n        self.assertEqual(auc_obj.num_thresholds, 3)\n        self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0])\n        result = auc_obj(self.y_true, self.y_pred)\n\n        # tp = [2, 1, 0], fp = [2, 0, 0], fn = [0, 1, 2], tn = [0, 2, 2]\n        # recall = [2/2, 1/(1+1), 0] = [1, 0.5, 0]\n        # fp_rate = [2/2, 0, 0] = [1, 0, 0]\n        # heights = [(1 + 0.5)/2, (0.5 + 0)/2] = [0.75, 0.25]\n        # widths = [(1 - 0), (0 - 0)] = [1, 0]\n        expected_result = 0.75 * 1 + 0.25 * 0\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_weighted_roc_interpolation(self):\n        auc_obj = metrics.AUC(num_thresholds=self.num_thresholds)\n        result = auc_obj(\n            self.y_true, self.y_pred, sample_weight=self.sample_weight\n        )\n\n        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]\n        # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]\n        # fp_rate = [3/3, 0, 0] = [1, 0, 0]\n        # heights = [(1 + 0.571)/2, (0.571 + 0)/2] = [0.7855, 0.2855]\n        # widths = [(1 - 0), (0 - 0)] = [1, 0]\n        expected_result = 0.7855 * 1 + 0.2855 * 0\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_weighted_roc_majoring(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds, summation_method=\"majoring\"\n        )\n        result = auc_obj(\n            self.y_true, self.y_pred, sample_weight=self.sample_weight\n        )\n\n        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]\n        # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]\n        # 
fp_rate = [3/3, 0, 0] = [1, 0, 0]\n        # heights = [max(1, 0.571), max(0.571, 0)] = [1, 0.571]\n        # widths = [(1 - 0), (0 - 0)] = [1, 0]\n        expected_result = 1 * 1 + 0.571 * 0\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_weighted_roc_minoring(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds, summation_method=\"minoring\"\n        )\n        result = auc_obj(\n            self.y_true, self.y_pred, sample_weight=self.sample_weight\n        )\n\n        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]\n        # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]\n        # fp_rate = [3/3, 0, 0] = [1, 0, 0]\n        # heights = [min(1, 0.571), min(0.571, 0)] = [0.571, 0]\n        # widths = [(1 - 0), (0 - 0)] = [1, 0]\n        expected_result = 0.571 * 1 + 0 * 0\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_weighted_pr_majoring(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds,\n            curve=\"PR\",\n            summation_method=\"majoring\",\n        )\n        result = auc_obj(\n            self.y_true, self.y_pred, sample_weight=self.sample_weight\n        )\n\n        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]\n        # precision = [7/(7+3), 4/4, 0] = [0.7, 1, 0]\n        # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]\n        # heights = [max(0.7, 1), max(1, 0)] = [1, 1]\n        # widths = [(1 - 0.571), (0.571 - 0)] = [0.429, 0.571]\n        expected_result = 1 * 0.429 + 1 * 0.571\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_weighted_pr_minoring(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds,\n            curve=\"PR\",\n            summation_method=\"minoring\",\n        )\n        result = auc_obj(\n            self.y_true, self.y_pred, sample_weight=self.sample_weight\n        )\n\n        # tp = [7, 4, 0], 
fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]\n        # precision = [7/(7+3), 4/4, 0] = [0.7, 1, 0]\n        # recall = [7/7, 4/(4+3), 0] = [1, 0.571, 0]\n        # heights = [min(0.7, 1), min(1, 0)] = [0.7, 0]\n        # widths = [(1 - 0.571), (0.571 - 0)] = [0.429, 0.571]\n        expected_result = 0.7 * 0.429 + 0 * 0.571\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_weighted_pr_interpolation(self):\n        auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve=\"PR\")\n        result = auc_obj(\n            self.y_true, self.y_pred, sample_weight=self.sample_weight\n        )\n\n        # auc = (slope / Total Pos) * [dTP - intercept * log(Pb/Pa)]\n\n        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]\n        # P = tp + fp = [10, 4, 0]\n        # dTP = [7-4, 4-0] = [3, 4]\n        # dP = [10-4, 4-0] = [6, 4]\n        # slope = dTP/dP = [0.5, 1]\n        # intercept = TPa - (slope * Pa) = [(4 - 0.5*4), (0 - 1*0)] = [2, 0]\n        # (Pb/Pa) = (Pb/Pa) if Pb > 0 AND Pa > 0 else 1 = [10/4, 4/0] = [2.5, 1]\n        # auc * TotalPos = [(0.5 * (3 + 2 * log(2.5))), (1 * (4 + 0))]\n        #                = [2.416, 4]\n        # auc = [2.416, 4]/(tp[1:]+fn[1:])\n        expected_result = 2.416 / 7 + 4 / 7\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_weighted_pr_interpolation_negative_weights(self):\n        auc_obj = metrics.AUC(num_thresholds=self.num_thresholds, curve=\"PR\")\n        sample_weight = [-1, -2, -3, -4]\n        result = auc_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n\n        # Divisor in auc formula is max(tp[1:]+fn[1:], 0), which is all zeros\n        # because all the values in tp and fn are negative, so divide_no_nan\n        # will produce all zeros.\n        self.assertAllClose(result, 0.0, 1e-3)\n\n    def test_weighted_prgain_majoring(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds,\n            
curve=\"PRGAIN\",\n            summation_method=\"majoring\",\n        )\n        result = auc_obj(\n            self.y_true, self.y_pred, sample_weight=self.sample_weight\n        )\n\n        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]\n        # scaling_factor (P/N) = 7/3\n        # recall_gain = 1 - 7/3 [0/7, 3/4, 7/0] = [1, -3/4, -inf] -> [1, 0, 0]\n        # precision_gain = 1 - 7/3 [3/7, 0/4, 0/0] = [0, 1, NaN] -> [0, 1, 1]\n        # heights = [max(0, 1), max(1, 1)] = [1, 1]\n        # widths = [(1 - 0), (0 - 0)] = [1, 0]\n        expected_result = 1 * 1 + 0 * 1\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_weighted_prgain_minoring(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds,\n            curve=\"PRGAIN\",\n            summation_method=\"minoring\",\n        )\n        result = auc_obj(\n            self.y_true, self.y_pred, sample_weight=self.sample_weight\n        )\n\n        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]\n        # scaling_factor (P/N) = 7/3\n        # recall_gain = 1 - 7/3 [0/7, 3/4, 7/0] = [1, -3/4, -inf] -> [1, 0, 0]\n        # precision_gain = 1 - 7/3 [3/7, 0/4, 0/0] = [0, 1, NaN] -> [0, 1, 1]\n        # heights = [min(0, 1), min(1, 1)] = [0, 1]\n        # widths = [(1 - 0), (0 - 0)] = [1, 0]\n        expected_result = 1 * 0 + 0 * 1\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_weighted_prgain_interpolation(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds, curve=\"PRGAIN\"\n        )\n        result = auc_obj(\n            self.y_true, self.y_pred, sample_weight=self.sample_weight\n        )\n\n        # tp = [7, 4, 0], fp = [3, 0, 0], fn = [0, 3, 7], tn = [0, 3, 3]\n        # scaling_factor (P/N) = 7/3\n        # recall_gain = 1 - 7/3 [0/7, 3/4, 7/0] = [1, -3/4, -inf] -> [1, 0, 0]\n        # precision_gain = 1 - 7/3 [3/7, 0/4, 0/0] = [0, 1, NaN] -> [0, 1, 
1]\n        # heights = [(0+1)/2, (1+1)/2] = [0.5, 1]\n        # widths = [(1 - 0), (0 - 0)] = [1, 0]\n        expected_result = 1 * 0.5 + 0 * 1\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_prgain_interpolation(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds, curve=\"PRGAIN\"\n        )\n\n        y_true = np.array([0, 0, 0, 1, 0, 1, 0, 1, 1, 1])\n        y_pred = np.array([0.1, 0.2, 0.3, 0.3, 0.4, 0.4, 0.6, 0.6, 0.8, 0.9])\n        result = auc_obj(y_true, y_pred)\n\n        # tp = [5, 3, 0], fp = [5, 1, 0], fn = [0, 2, 5], tn = [0, 4, 4]\n        # scaling_factor (P/N) = 5/5 = 1\n        # recall_gain = 1 - [0/5, 2/3, 5/0] = [1, 1/3, -inf] -> [1, 1/3, 0]\n        # precision_gain = 1 - [5/5, 1/3, 0/0] = [0, 2/3, NaN] -> [0, 2/3, 1]\n        # heights = [(0+2/3)/2, (2/3+1)/2] = [0.333333, 0.833333]\n        # widths = [(1 - 1/3), (1/3 - 0)] = [0.666666, 0.333333]\n        expected_result = 0.666666 * 0.333333 + 0.333333 * 0.833333\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_invalid_num_thresholds(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `num_thresholds` must be an integer > 1\"\n        ):\n            metrics.AUC(num_thresholds=-1)\n\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `num_thresholds` must be an integer > 1.\"\n        ):\n            metrics.AUC(num_thresholds=1)\n\n    def test_invalid_curve(self):\n        with self.assertRaisesRegex(\n            ValueError, 'Invalid AUC curve value: \"Invalid\".'\n        ):\n            metrics.AUC(curve=\"Invalid\")\n\n    def test_invalid_summation_method(self):\n        with self.assertRaisesRegex(\n            ValueError, 'Invalid AUC summation method value: \"Invalid\".'\n        ):\n            metrics.AUC(summation_method=\"Invalid\")\n\n    def test_extra_dims(self):\n        try:\n            from scipy import special\n\n            
logits = special.expit(\n                -np.array(\n                    [\n                        [[-10.0, 10.0, -10.0], [10.0, -10.0, 10.0]],\n                        [[-12.0, 12.0, -12.0], [12.0, -12.0, 12.0]],\n                    ],\n                    dtype=np.float32,\n                )\n            )\n            labels = np.array(\n                [[[1, 0, 0], [1, 0, 0]], [[0, 1, 1], [0, 1, 1]]], dtype=np.int64\n            )\n            auc_obj = metrics.AUC()\n            result = auc_obj(labels, logits)\n            self.assertEqual(result, 0.5)\n        except ImportError as e:\n            logging.warning(f\"Cannot test special functions: {str(e)}\")\n\n\nclass MultiAUCTest(testing.TestCase):\n    def setUp(self):\n        self.num_thresholds = 5\n        self.y_pred = np.array(\n            [[0, 0.5, 0.3, 0.9], [0.1, 0.2, 0.3, 0.4]], dtype=\"float32\"\n        ).T\n\n        epsilon = 1e-12\n        self.y_pred_logits = -ops.log(1.0 / (self.y_pred + epsilon) - 1.0)\n\n        self.y_true_good = np.array([[0, 0, 1, 1], [0, 0, 1, 1]]).T\n        self.y_true_bad = np.array([[0, 0, 1, 1], [1, 1, 0, 0]]).T\n        self.sample_weight = [1, 2, 3, 4]\n\n        # threshold values are [0 - 1e-7, 0.25, 0.5, 0.75, 1 + 1e-7]\n        # y_pred when threshold = 0 - 1e-7   : [[1, 1, 1, 1], [1, 1, 1, 1]]\n        # y_pred when threshold = 0.25       : [[0, 1, 1, 1], [0, 0, 1, 1]]\n        # y_pred when threshold = 0.5        : [[0, 0, 0, 1], [0, 0, 0, 0]]\n        # y_pred when threshold = 0.75       : [[0, 0, 0, 1], [0, 0, 0, 0]]\n        # y_pred when threshold = 1 + 1e-7   : [[0, 0, 0, 0], [0, 0, 0, 0]]\n\n        # for y_true_good, over thresholds:\n        # tp = [[2, 2, 1, 1, 0], [2, 2, 0, 0, 0]]\n        # fp = [[2, 1, 0, 0 , 0], [2, 0, 0 ,0, 0]]\n        # fn = [[0, 0, 1, 1, 2], [0, 0, 2, 2, 2]]\n        # tn = [[0, 1, 2, 2, 2], [0, 2, 2, 2, 2]]\n\n        # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]\n        # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 
0]]\n\n        # for y_true_bad:\n        # tp = [[2, 2, 1, 1, 0], [2, 0, 0, 0, 0]]\n        # fp = [[2, 1, 0, 0 , 0], [2, 2, 0 ,0, 0]]\n        # fn = [[0, 0, 1, 1, 2], [0, 2, 2, 2, 2]]\n        # tn = [[0, 1, 2, 2, 2], [0, 0, 2, 2, 2]]\n\n        # tpr = [[1, 1, 0.5, 0.5, 0], [1, 0, 0, 0, 0]]\n        # fpr = [[1, 0.5, 0, 0, 0], [1, 1, 0, 0, 0]]\n\n        # for y_true_good with sample_weights:\n\n        # tp = [[7, 7, 4, 4, 0], [7, 7, 0, 0, 0]]\n        # fp = [[3, 2, 0, 0, 0], [3, 0, 0, 0, 0]]\n        # fn = [[0, 0, 3, 3, 7], [0, 0, 7, 7, 7]]\n        # tn = [[0, 1, 3, 3, 3], [0, 3, 3, 3, 3]]\n\n        # tpr = [[1, 1,    0.57, 0.57, 0], [1, 1, 0, 0, 0]]\n        # fpr = [[1, 0.67, 0,    0,    0], [1, 0, 0, 0, 0]]\n\n    def test_unweighted_all_correct(self):\n        auc_obj = metrics.AUC(multi_label=True)\n        result = auc_obj(self.y_true_good, self.y_true_good)\n        self.assertEqual(result, 1)\n\n    def test_unweighted_all_correct_flat(self):\n        auc_obj = metrics.AUC(multi_label=False)\n        result = auc_obj(self.y_true_good, self.y_true_good)\n        self.assertEqual(result, 1)\n\n    def test_unweighted(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds, multi_label=True\n        )\n        result = auc_obj(self.y_true_good, self.y_pred)\n\n        # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]\n        # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]\n        expected_result = (0.875 + 1.0) / 2.0\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_unweighted_from_logits(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds,\n            multi_label=True,\n            from_logits=True,\n        )\n        result = auc_obj(self.y_true_good, self.y_pred_logits)\n\n        # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]\n        # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]\n        expected_result = (0.875 + 1.0) / 2.0\n        
self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_sample_weight_flat(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds, multi_label=False\n        )\n        result = auc_obj(\n            self.y_true_good, self.y_pred, sample_weight=[1, 2, 3, 4]\n        )\n\n        # tpr = [1, 1, 0.2857, 0.2857, 0]\n        # fpr = [1, 0.3333, 0, 0, 0]\n        expected_result = 1.0 - (0.3333 * (1.0 - 0.2857) / 2.0)\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_full_sample_weight_flat(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds, multi_label=False\n        )\n        sw = np.arange(4 * 2)\n        sw = sw.reshape(4, 2)\n        result = auc_obj(self.y_true_good, self.y_pred, sample_weight=sw)\n\n        # tpr = [1, 1, 0.2727, 0.2727, 0]\n        # fpr = [1, 0.3333, 0, 0, 0]\n        expected_result = 1.0 - (0.3333 * (1.0 - 0.2727) / 2.0)\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_label_weights(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds,\n            multi_label=True,\n            label_weights=[0.75, 0.25],\n        )\n        result = auc_obj(self.y_true_good, self.y_pred)\n\n        # tpr = [[1, 1, 0.5, 0.5, 0], [1, 1, 0, 0, 0]]\n        # fpr = [[1, 0.5, 0, 0, 0], [1, 0, 0, 0, 0]]\n        expected_result = (0.875 * 0.75 + 1.0 * 0.25) / (0.75 + 0.25)\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_label_weights_flat(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds,\n            multi_label=False,\n            label_weights=[0.75, 0.25],\n        )\n        result = auc_obj(self.y_true_good, self.y_pred)\n\n        # tpr = [1, 1, 0.375, 0.375, 0]\n        # fpr = [1, 0.375, 0, 0, 0]\n        expected_result = 1.0 - ((1.0 - 0.375) * 0.375 / 2.0)\n        self.assertAllClose(result, expected_result, 1e-2)\n\n    
def test_unweighted_flat(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds, multi_label=False\n        )\n        result = auc_obj(self.y_true_good, self.y_pred)\n\n        # tp = [4, 4, 1, 1, 0]\n        # fp = [4, 1, 0, 0, 0]\n        # fn = [0, 0, 3, 3, 4]\n        # tn = [0, 3, 4, 4, 4]\n\n        # tpr = [1, 1, 0.25, 0.25, 0]\n        # fpr = [1, 0.25, 0, 0, 0]\n        expected_result = 1.0 - (3.0 / 32.0)\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_unweighted_flat_from_logits(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds,\n            multi_label=False,\n            from_logits=True,\n        )\n        result = auc_obj(self.y_true_good, self.y_pred_logits)\n\n        # tp = [4, 4, 1, 1, 0]\n        # fp = [4, 1, 0, 0, 0]\n        # fn = [0, 0, 3, 3, 4]\n        # tn = [0, 3, 4, 4, 4]\n\n        # tpr = [1, 1, 0.25, 0.25, 0]\n        # fpr = [1, 0.25, 0, 0, 0]\n        expected_result = 1.0 - (3.0 / 32.0)\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_manual_thresholds(self):\n        # Verify that when specified, thresholds are used instead of\n        # num_thresholds.\n        auc_obj = metrics.AUC(\n            num_thresholds=2, thresholds=[0.5], multi_label=True\n        )\n        self.assertEqual(auc_obj.num_thresholds, 3)\n        self.assertAllClose(auc_obj.thresholds, [0.0, 0.5, 1.0])\n        result = auc_obj(self.y_true_good, self.y_pred)\n\n        # tp = [[2, 1, 0], [2, 0, 0]]\n        # fp = [[2, 0, 0], [2, 0, 0]]\n        # fn = [[0, 1, 2], [0, 2, 2]]\n        # tn = [[0, 2, 2], [0, 2, 2]]\n\n        # tpr = [[1, 0.5, 0], [1, 0, 0]]\n        # fpr = [[1, 0, 0], [1, 0, 0]]\n\n        # auc by slice = [0.75, 0.5]\n        expected_result = (0.75 + 0.5) / 2.0\n\n        self.assertAllClose(result, expected_result, 1e-3)\n\n    def test_weighted_roc_interpolation(self):\n        auc_obj = metrics.AUC(\n          
  num_thresholds=self.num_thresholds, multi_label=True\n        )\n        result = auc_obj(\n            self.y_true_good, self.y_pred, sample_weight=self.sample_weight\n        )\n\n        # tpr = [[1, 1,    0.57, 0.57, 0], [1, 1, 0, 0, 0]]\n        # fpr = [[1, 0.67, 0,    0,    0], [1, 0, 0, 0, 0]]\n        expected_result = 1.0 - 0.5 * 0.43 * 0.67\n        self.assertAllClose(result, expected_result, 1e-1)\n\n    def test_pr_interpolation_unweighted(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds, curve=\"PR\", multi_label=True\n        )\n        good_result = auc_obj(self.y_true_good, self.y_pred)\n        with self.subTest(name=\"good\"):\n            # PR AUCs are 0.917 and 1.0 respectively\n            self.assertAllClose(good_result, (0.91667 + 1.0) / 2.0, 1e-1)\n        bad_result = auc_obj(self.y_true_bad, self.y_pred)\n        with self.subTest(name=\"bad\"):\n            # PR AUCs are 0.917 and 0.5 respectively\n            self.assertAllClose(bad_result, (0.91667 + 0.5) / 2.0, 1e-1)\n\n    def test_pr_interpolation(self):\n        auc_obj = metrics.AUC(\n            num_thresholds=self.num_thresholds, curve=\"PR\", multi_label=True\n        )\n        good_result = auc_obj(\n            self.y_true_good, self.y_pred, sample_weight=self.sample_weight\n        )\n        # PR AUCs are 0.939 and 1.0 respectively\n        self.assertAllClose(good_result, (0.939 + 1.0) / 2.0, 1e-1)\n\n    @pytest.mark.requires_trainable_backend\n    def test_keras_model_compiles(self):\n        inputs = layers.Input(shape=(10,), batch_size=1)\n        output = layers.Dense(3, activation=\"sigmoid\")(inputs)\n        model = models.Model(inputs=inputs, outputs=output)\n        model.compile(\n            optimizer=\"adam\",\n            loss=\"binary_crossentropy\",\n            metrics=[metrics.AUC(multi_label=True)],\n        )\n\n    def test_reset_state(self):\n        auc_obj = metrics.AUC(\n            
num_thresholds=self.num_thresholds, multi_label=True\n        )\n        auc_obj(self.y_true_good, self.y_pred)\n        auc_obj.reset_state()\n        self.assertAllClose(auc_obj.true_positives, np.zeros((5, 2)))\n"
  },
  {
    "path": "keras/src/metrics/correlation_metrics.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.losses.loss import squeeze_or_expand_to_same_rank\nfrom keras.src.metrics import reduction_metrics\n\n\n@keras_export(\"keras.metrics.pearson_correlation\")\ndef pearson_correlation(y_true, y_pred, axis=-1):\n    \"\"\"Computes the Pearson coefficient between labels and predictions.\n\n    Formula:\n\n    ```python\n    loss = mean(\n        (y_true - mean(y_true)) / std(y_true)\n        * (y_pred - mean(y_pred)) / std(y_pred)\n    )\n    ```\n\n    Args:\n        y_true: Tensor of true targets.\n        y_pred: Tensor of predicted targets.\n        axis: Axis along which to determine similarity. Defaults to `-1`.\n\n    Returns:\n        Pearson Correlation Coefficient tensor.\n\n    Example:\n\n    >>> y_true = [[0, 1, 0.5], [1, 1, 0.2]]\n    >>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]]\n    >>> loss = keras.metrics.pearson_correlation(\n    ...     y_true, y_pred, axis=-1\n    ... ).numpy()\n    [1.         
0.99339927]\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n\n    y_true_norm = y_true - ops.mean(y_true, axis=axis, keepdims=True)\n    y_pred_norm = y_pred - ops.mean(y_pred, axis=axis, keepdims=True)\n\n    y_true_norm = y_true_norm / ops.std(y_true_norm, axis=axis, keepdims=True)\n    y_pred_norm = y_pred_norm / ops.std(y_pred_norm, axis=axis, keepdims=True)\n\n    return ops.mean(y_true_norm * y_pred_norm, axis=axis)\n\n\n@keras_export(\"keras.metrics.concordance_correlation\")\ndef concordance_correlation(y_true, y_pred, axis=-1):\n    \"\"\"Computes the Concordance coefficient between labels and predictions.\n\n    Formula:\n\n    ```python\n    loss = mean(\n        2 * (y_true - mean(y_true)) * (y_pred - mean(y_pred)) / (\n            var(y_true) + var(y_pred) + square(mean(y_true) - mean(y_pred))\n        )\n    )\n    ```\n\n    Args:\n        y_true: Tensor of true targets.\n        y_pred: Tensor of predicted targets.\n        axis: Axis along which to determine similarity. Defaults to `-1`.\n\n    Returns:\n        Concordance Correlation Coefficient tensor.\n\n    Example:\n\n    >>> y_true = [[0, 1, 0.5], [1, 1, 0.2]]\n    >>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]]\n    >>> loss = keras.metrics.concordance_correlation(\n    ...     y_true, y_pred, axis=-1\n    ... 
).numpy()\n    [0.97560976 0.98765432]\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n\n    y_true_mean = ops.mean(y_true, axis=axis, keepdims=True)\n    y_pred_mean = ops.mean(y_pred, axis=axis, keepdims=True)\n\n    y_true_var = ops.var(y_true - y_true_mean, axis=axis, keepdims=True)\n    y_pred_var = ops.var(y_pred - y_pred_mean, axis=axis, keepdims=True)\n\n    covar = (y_true - y_true_mean) * (y_pred - y_pred_mean)\n    norm = y_true_var + y_pred_var + ops.square(y_true_mean - y_pred_mean)\n\n    return ops.mean(2 * covar / (norm + backend.epsilon()), axis=axis)\n\n\n@keras_export(\"keras.metrics.PearsonCorrelation\")\nclass PearsonCorrelation(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Calculates the Pearson Correlation Coefficient (PCC).\n\n    PCC measures the linear relationship between the true values (`y_true`) and\n    the predicted values (`y_pred`). The coefficient ranges from -1 to 1, where\n    a value of 1 implies a perfect positive linear correlation, 0 indicates no\n    linear correlation, and -1 indicates a perfect negative linear correlation.\n\n    This metric is widely used in regression tasks where the strength of the\n    linear relationship between predictions and true labels is an\n    important evaluation criterion.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        axis: (Optional) integer or tuple of integers of the axis/axes along\n            which to compute the metric. 
Defaults to `-1`.\n\n    Example:\n\n    >>> pcc = keras.metrics.PearsonCorrelation(axis=-1)\n    >>> y_true = [[0, 1, 0.5], [1, 1, 0.2]]\n    >>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]]\n    >>> pcc.update_state(y_true, y_pred)\n    >>> pcc.result()\n    0.9966996338993913\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='mean_squared_error',\n                  metrics=[keras.metrics.PearsonCorrelation()])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        name=\"pearson_correlation\",\n        dtype=None,\n        axis=-1,\n    ):\n        super().__init__(\n            fn=pearson_correlation,\n            name=name,\n            dtype=dtype,\n            axis=axis,\n        )\n        self.axis = axis\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n    def get_config(self):\n        return {\n            \"name\": self.name,\n            \"dtype\": self.dtype,\n            \"axis\": self.axis,\n        }\n\n\n@keras_export(\"keras.metrics.ConcordanceCorrelation\")\nclass ConcordanceCorrelation(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Calculates the Concordance Correlation Coefficient (CCC).\n\n    CCC evaluates the agreement between true values (`y_true`) and predicted\n    values (`y_pred`) by considering both precision and accuracy. The\n    coefficient ranges from -1 to 1, where a value of 1 indicates perfect\n    agreement.\n\n    This metric is useful in regression tasks where it is important to assess\n    how well the predictions match the true values, taking into account both\n    their correlation and proximity to the 45-degree line of perfect\n    concordance.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        axis: (Optional) integer or tuple of integers of the axis/axes along\n            which to compute the metric. 
Defaults to `-1`.\n\n    Example:\n\n    >>> ccc = keras.metrics.ConcordanceCorrelation(axis=-1)\n    >>> y_true = [[0, 1, 0.5], [1, 1, 0.2]]\n    >>> y_pred = [[0.1, 0.9, 0.5], [1, 0.9, 0.2]]\n    >>> ccc.update_state(y_true, y_pred)\n    >>> ccc.result()\n    0.9816320385426076\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='mean_squared_error',\n                  metrics=[keras.metrics.ConcordanceCorrelation()])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        name=\"concordance_correlation\",\n        dtype=None,\n        axis=-1,\n    ):\n        super().__init__(\n            fn=concordance_correlation,\n            name=name,\n            dtype=dtype,\n            axis=axis,\n        )\n        self.axis = axis\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n    def get_config(self):\n        return {\n            \"name\": self.name,\n            \"dtype\": self.dtype,\n            \"axis\": self.axis,\n        }\n"
  },
  {
    "path": "keras/src/metrics/correlation_metrics_test.py",
    "content": "import numpy as np\nfrom scipy.stats import pearsonr\n\nfrom keras.src import testing\nfrom keras.src.metrics import ConcordanceCorrelation\nfrom keras.src.metrics import PearsonCorrelation\nfrom keras.src.metrics import correlation_metrics\n\n\nclass CorrelationsTest(testing.TestCase):\n    def _get_data(self):\n        # Sample data for testing\n        y_true = np.array(\n            [[0, 1, 0.5], [1, 1, 0.2], [1, 1, 0.1], [0.1, 0.7, 0.0]],\n            dtype=\"float32\",\n        )\n        y_pred = np.array(\n            [[0.1, 0.9, 0.5], [1, 0.9, 0.2], [0.2, 0.8, 0], [0.3, 0.3, 0.9]],\n            dtype=\"float32\",\n        )\n\n        ccc_expected = np.array(\n            [0.97560976, 0.98765432, 0.46511628, -0.46376812]\n        )\n        # pcc_expected = np.array([1, 0.99339927, 0.69337525, -0.60999428])\n        pcc_expected = np.array(\n            [pearsonr(yt, yp).statistic for yt, yp in zip(y_true, y_pred)]\n        )\n        return y_true, y_pred, ccc_expected, pcc_expected\n\n    def test_pearson_function(self):\n        \"\"\"Test the functional API for Pearson Correlation Coefficient.\"\"\"\n        y_true, y_pred, _, pcc_expected = self._get_data()\n        result = correlation_metrics.pearson_correlation(\n            y_true, y_pred, axis=-1\n        )\n        self.assertAllClose(result, pcc_expected)\n\n    def test_concordance_function(self):\n        \"\"\"Test the functional API for Concordance Correlation Coefficient.\"\"\"\n        y_true, y_pred, ccc_expected, _ = self._get_data()\n        result = correlation_metrics.concordance_correlation(\n            y_true, y_pred, axis=-1\n        )\n        self.assertAllClose(result, ccc_expected)\n\n    def test_pearson_class(self):\n        \"\"\"Test the PearsonCorrelation metric class.\"\"\"\n        y_true, y_pred, _, pcc_expected = self._get_data()\n        m = PearsonCorrelation(axis=-1, dtype=\"float32\")\n        m.update_state(y_true[:2], y_pred[:2])\n        
self.assertAllClose(m.result(), np.mean(pcc_expected[:2]))\n        m.update_state(y_true[2:], y_pred[2:])\n        self.assertAllClose(m.result(), np.mean(pcc_expected))\n\n    def test_concordance_class(self):\n        \"\"\"Test the ConcordanceCorrelation metric class.\"\"\"\n        y_true, y_pred, ccc_expected, _ = self._get_data()\n        m = ConcordanceCorrelation(axis=-1, dtype=\"float32\")\n        m.update_state(y_true[:2], y_pred[:2])\n        self.assertAllClose(m.result(), np.mean(ccc_expected[:2]))\n        m.update_state(y_true[2:], y_pred[2:])\n        self.assertAllClose(m.result(), np.mean(ccc_expected))\n\n    def test_pearson_config(self):\n        \"\"\"Test the get_config method for PearsonCorrelation.\"\"\"\n        m = PearsonCorrelation(axis=-1, dtype=\"float16\")\n        config = m.get_config()\n        self.assertEqual(config[\"axis\"], -1)\n        self.assertEqual(config[\"dtype\"], \"float16\")\n        self.assertEqual(config[\"name\"], \"pearson_correlation\")\n\n    def test_concordance_config(self):\n        \"\"\"Test the get_config method for ConcordanceCorrelation.\"\"\"\n        m = ConcordanceCorrelation(axis=-1, dtype=\"float32\")\n        config = m.get_config()\n        self.assertEqual(config[\"axis\"], -1)\n        self.assertEqual(config[\"dtype\"], \"float32\")\n        self.assertEqual(config[\"name\"], \"concordance_correlation\")\n"
  },
  {
    "path": "keras/src/metrics/f_score_metrics.py",
    "content": "from keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.metrics.metric import Metric\n\n\n@keras_export(\"keras.metrics.FBetaScore\")\nclass FBetaScore(Metric):\n    \"\"\"Computes F-Beta score.\n\n    Formula:\n\n    ```python\n    b2 = beta ** 2\n    f_beta_score = (1 + b2) * (precision * recall) / (precision * b2 + recall)\n    ```\n    This is the weighted harmonic mean of precision and recall.\n    Its output range is `[0, 1]`. It works for both multi-class\n    and multi-label classification.\n\n    Args:\n        average: Type of averaging to be performed across per-class results\n            in the multi-class case.\n            Acceptable values are `None`, `\"micro\"`, `\"macro\"` and\n            `\"weighted\"`. Defaults to `None`.\n            If `None`, no averaging is performed and `result()` will return\n            the score for each class.\n            If `\"micro\"`, compute metrics globally by counting the total\n            true positives, false negatives and false positives.\n            If `\"macro\"`, compute metrics for each label,\n            and return their unweighted mean.\n            This does not take label imbalance into account.\n            If `\"weighted\"`, compute metrics for each label,\n            and return their average weighted by support\n            (the number of true instances for each label).\n            This alters `\"macro\"` to account for label imbalance.\n            It can result in a score that is not between precision and recall.\n        beta: Determines the weight given to recall\n            in the harmonic mean between precision and recall (see pseudocode\n            equation above). Defaults to `1`.\n        threshold: Elements of `y_pred` greater than `threshold` are\n            converted to be 1, and the rest 0. 
If `threshold` is\n            `None`, the argmax of `y_pred` is converted to 1, and the rest to 0.\n        name: Optional. String name of the metric instance.\n        dtype: Optional. Data type of the metric result.\n\n    Returns:\n        F-Beta Score: float.\n\n    Example:\n\n    >>> metric = keras.metrics.FBetaScore(beta=2.0, threshold=0.5)\n    >>> y_true = np.array([[1, 1, 1],\n    ...                    [1, 0, 0],\n    ...                    [1, 1, 0]], np.int32)\n    >>> y_pred = np.array([[0.2, 0.6, 0.7],\n    ...                    [0.2, 0.6, 0.6],\n    ...                    [0.6, 0.8, 0.0]], np.float32)\n    >>> metric.update_state(y_true, y_pred)\n    >>> result = metric.result()\n    >>> result\n    [0.3846154 , 0.90909094, 0.8333334 ]\n    \"\"\"\n\n    def __init__(\n        self,\n        average=None,\n        beta=1.0,\n        threshold=None,\n        name=\"fbeta_score\",\n        dtype=None,\n    ):\n        super().__init__(name=name, dtype=dtype)\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n        if average not in (None, \"micro\", \"macro\", \"weighted\"):\n            raise ValueError(\n                \"Invalid `average` argument value. Expected one of: \"\n                \"{None, 'micro', 'macro', 'weighted'}. \"\n                f\"Received: average={average}\"\n            )\n\n        if not isinstance(beta, float):\n            raise ValueError(\n                \"Invalid `beta` argument value. \"\n                \"It should be a Python float. \"\n                f\"Received: beta={beta} of type '{type(beta)}'\"\n            )\n        if beta <= 0.0:\n            raise ValueError(\n                \"Invalid `beta` argument value. \"\n                \"It should be > 0. 
\"\n                f\"Received: beta={beta}\"\n            )\n\n        if threshold is not None:\n            if not isinstance(threshold, float):\n                raise ValueError(\n                    \"Invalid `threshold` argument value. \"\n                    \"It should be a Python float. \"\n                    f\"Received: threshold={threshold} \"\n                    f\"of type '{type(threshold)}'\"\n                )\n            if threshold > 1.0 or threshold <= 0.0:\n                raise ValueError(\n                    \"Invalid `threshold` argument value. \"\n                    \"It should verify 0 < threshold <= 1. \"\n                    f\"Received: threshold={threshold}\"\n                )\n\n        self.average = average\n        self.beta = beta\n        self.threshold = threshold\n        self.axis = None\n        self._built = False\n\n        if self.average != \"micro\":\n            self.axis = 0\n\n    def _build(self, y_true_shape, y_pred_shape):\n        if len(y_pred_shape) != 2 or len(y_true_shape) != 2:\n            raise ValueError(\n                \"FBetaScore expects 2D inputs with shape \"\n                \"(batch_size, output_dim). Received input \"\n                f\"shapes: y_pred.shape={y_pred_shape} and \"\n                f\"y_true.shape={y_true_shape}.\"\n            )\n        if y_pred_shape[-1] is None or y_true_shape[-1] is None:\n            raise ValueError(\n                \"FBetaScore expects 2D inputs with shape \"\n                \"(batch_size, output_dim), with output_dim fully \"\n                \"defined (not None). 
Received input \"\n                f\"shapes: y_pred.shape={y_pred_shape} and \"\n                f\"y_true.shape={y_true_shape}.\"\n            )\n        num_classes = y_pred_shape[-1]\n        if self.average != \"micro\":\n            init_shape = (num_classes,)\n        else:\n            init_shape = ()\n\n        def _add_zeros_variable(name):\n            return self.add_variable(\n                name=name,\n                shape=init_shape,\n                initializer=initializers.Zeros(),\n                dtype=self.dtype,\n            )\n\n        self.true_positives = _add_zeros_variable(\"true_positives\")\n        self.false_positives = _add_zeros_variable(\"false_positives\")\n        self.false_negatives = _add_zeros_variable(\"false_negatives\")\n        self.intermediate_weights = _add_zeros_variable(\"intermediate_weights\")\n        self._built = True\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)\n        y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype)\n        if not self._built:\n            self._build(y_true.shape, y_pred.shape)\n\n        if self.threshold is None:\n            threshold = ops.max(y_pred, axis=-1, keepdims=True)\n            # make sure [0, 0, 0] doesn't become [1, 1, 1]\n            # Use abs(x) > eps, instead of x != 0 to check for zero\n            y_pred = ops.logical_and(\n                y_pred >= threshold, ops.abs(y_pred) > 1e-9\n            )\n        else:\n            y_pred = y_pred > self.threshold\n\n        y_pred = ops.cast(y_pred, dtype=self.dtype)\n        y_true = ops.cast(y_true, dtype=self.dtype)\n        if sample_weight is not None:\n            sample_weight = ops.convert_to_tensor(\n                sample_weight, dtype=self.dtype\n            )\n\n        def _weighted_sum(val, sample_weight):\n            if sample_weight is not None:\n                val = ops.multiply(val, 
ops.expand_dims(sample_weight, 1))\n            return ops.sum(val, axis=self.axis)\n\n        self.true_positives.assign(\n            self.true_positives + _weighted_sum(y_pred * y_true, sample_weight)\n        )\n        self.false_positives.assign(\n            self.false_positives\n            + _weighted_sum(y_pred * (1 - y_true), sample_weight)\n        )\n        self.false_negatives.assign(\n            self.false_negatives\n            + _weighted_sum((1 - y_pred) * y_true, sample_weight)\n        )\n        self.intermediate_weights.assign(\n            self.intermediate_weights + _weighted_sum(y_true, sample_weight)\n        )\n\n    def result(self):\n        precision = ops.divide(\n            self.true_positives,\n            self.true_positives + self.false_positives + backend.epsilon(),\n        )\n        recall = ops.divide(\n            self.true_positives,\n            self.true_positives + self.false_negatives + backend.epsilon(),\n        )\n\n        precision = ops.convert_to_tensor(precision, dtype=self.dtype)\n        recall = ops.convert_to_tensor(recall, dtype=self.dtype)\n\n        mul_value = precision * recall\n        add_value = ((self.beta**2) * precision) + recall\n        mean = ops.divide(mul_value, add_value + backend.epsilon())\n        f1_score = mean * (1 + (self.beta**2))\n\n        if self.average == \"weighted\":\n            weights = ops.divide(\n                self.intermediate_weights,\n                ops.sum(self.intermediate_weights) + backend.epsilon(),\n            )\n            f1_score = ops.sum(f1_score * weights)\n\n        elif self.average is not None:  # [micro, macro]\n            f1_score = ops.mean(f1_score)\n\n        return f1_score\n\n    def get_config(self):\n        \"\"\"Returns the serializable config of the metric.\"\"\"\n\n        config = {\n            \"name\": self.name,\n            \"dtype\": self.dtype,\n            \"average\": self.average,\n            \"beta\": self.beta,\n      
      \"threshold\": self.threshold,\n        }\n\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n    def reset_state(self):\n        for v in self.variables:\n            v.assign(ops.zeros(v.shape, dtype=v.dtype))\n\n\n@keras_export(\"keras.metrics.F1Score\")\nclass F1Score(FBetaScore):\n    r\"\"\"Computes F-1 Score.\n\n    Formula:\n\n    ```python\n    f1_score = 2 * (precision * recall) / (precision + recall)\n    ```\n    This is the harmonic mean of precision and recall.\n    Its output range is `[0, 1]`. It works for both multi-class\n    and multi-label classification.\n\n    Args:\n        average: Type of averaging to be performed on data.\n            Acceptable values are `None`, `\"micro\"`, `\"macro\"`\n            and `\"weighted\"`. Defaults to `None`.\n            If `None`, no averaging is performed and `result()` will return\n            the score for each class.\n            If `\"micro\"`, compute metrics globally by counting the total\n            true positives, false negatives and false positives.\n            If `\"macro\"`, compute metrics for each label,\n            and return their unweighted mean.\n            This does not take label imbalance into account.\n            If `\"weighted\"`, compute metrics for each label,\n            and return their average weighted by support\n            (the number of true instances for each label).\n            This alters `\"macro\"` to account for label imbalance.\n            It can result in an score that is not between precision and recall.\n        threshold: Elements of `y_pred` greater than `threshold` are\n            converted to be 1, and the rest 0. If `threshold` is\n            `None`, the argmax of `y_pred` is converted to 1, and the rest to 0.\n        name: Optional. String name of the metric instance.\n        dtype: Optional. 
Data type of the metric result.\n\n    Returns:\n        F-1 Score: float.\n\n    Example:\n\n    >>> metric = keras.metrics.F1Score(threshold=0.5)\n    >>> y_true = np.array([[1, 1, 1],\n    ...                    [1, 0, 0],\n    ...                    [1, 1, 0]], np.int32)\n    >>> y_pred = np.array([[0.2, 0.6, 0.7],\n    ...                    [0.2, 0.6, 0.6],\n    ...                    [0.6, 0.8, 0.0]], np.float32)\n    >>> metric.update_state(y_true, y_pred)\n    >>> result = metric.result()\n    >>> result\n    array([0.5      , 0.8      , 0.6666667], dtype=float32)\n    \"\"\"\n\n    def __init__(\n        self,\n        average=None,\n        threshold=None,\n        name=\"f1_score\",\n        dtype=None,\n    ):\n        super().__init__(\n            average=average,\n            beta=1.0,\n            threshold=threshold,\n            name=name,\n            dtype=dtype,\n        )\n\n    def get_config(self):\n        base_config = super().get_config()\n        del base_config[\"beta\"]\n        return base_config\n"
  },
  {
    "path": "keras/src/metrics/f_score_metrics_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import testing\nfrom keras.src.metrics import f_score_metrics\n\n\nclass FBetaScoreTest(testing.TestCase):\n    def _run_test(\n        self,\n        y_true,\n        y_pred,\n        sample_weights,\n        average,\n        beta,\n        threshold,\n        reference_result,\n    ):\n        fbeta = f_score_metrics.FBetaScore(\n            average, beta, threshold, dtype=\"float32\"\n        )\n        fbeta.update_state(y_true, y_pred, sample_weights)\n        result = fbeta.result()\n        self.assertAllClose(result, reference_result, atol=1e-6)\n\n    def test_config(self):\n        fbeta_obj = f_score_metrics.FBetaScore(\n            beta=0.5, threshold=0.3, average=None, dtype=\"float32\"\n        )\n        self.assertEqual(fbeta_obj.beta, 0.5)\n        self.assertEqual(fbeta_obj.average, None)\n        self.assertEqual(fbeta_obj.threshold, 0.3)\n        self.assertEqual(fbeta_obj.dtype, \"float32\")\n\n        # Check save and restore config\n        fbeta_obj2 = f_score_metrics.FBetaScore.from_config(\n            fbeta_obj.get_config()\n        )\n        self.assertEqual(fbeta_obj2.beta, 0.5)\n        self.assertEqual(fbeta_obj2.average, None)\n        self.assertEqual(fbeta_obj2.threshold, 0.3)\n        self.assertEqual(fbeta_obj2.dtype, \"float32\")\n\n    @parameterized.parameters(\n        (\"micro\", 0.5),\n        (\"micro\", 1.0),\n        (\"micro\", 2.0),\n        (\"macro\", 0.5),\n        (\"macro\", 1.0),\n        (\"macro\", 2.0),\n        (\"weighted\", 0.5),\n        (\"weighted\", 1.0),\n        (\"weighted\", 2.0),\n    )\n    def test_fbeta_perfect_score(self, average, beta):\n        y_true = [[1, 1, 1], [1, 0, 0], [1, 1, 0]]\n        y_pred = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]]\n        self._run_test(\n            y_true,\n            y_pred,\n            None,\n            average=average,\n            beta=beta,\n            
threshold=0.66,\n            reference_result=1.0,\n        )\n\n    @parameterized.parameters(\n        (\"micro\", 0.5),\n        (\"micro\", 1.0),\n        (\"micro\", 2.0),\n        (\"macro\", 0.5),\n        (\"macro\", 1.0),\n        (\"macro\", 2.0),\n        (\"weighted\", 0.5),\n        (\"weighted\", 1.0),\n        (\"weighted\", 2.0),\n    )\n    def test_fbeta_worst_score(self, average, beta):\n        y_true = [[0, 0, 0], [0, 1, 0], [0, 0, 1]]\n        y_pred = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]]\n        self._run_test(\n            y_true,\n            y_pred,\n            None,\n            average=average,\n            beta=beta,\n            threshold=0.66,\n            reference_result=0.0,\n        )\n\n    @parameterized.parameters(\n        # average, beta, result\n        (None, 0.5, [0.71428573, 0.5, 0.833334]),\n        (None, 1.0, [0.8, 0.5, 0.6666667]),\n        (None, 2.0, [0.9090904, 0.5, 0.555556]),\n        (\"micro\", 0.5, 0.6666667),\n        (\"micro\", 1.0, 0.6666667),\n        (\"micro\", 2.0, 0.6666667),\n        (\"macro\", 0.5, 0.6825397),\n        (\"macro\", 1.0, 0.6555555),\n        (\"macro\", 2.0, 0.6548822),\n        (\"weighted\", 0.5, 0.6825397),\n        (\"weighted\", 1.0, 0.6555555),\n        (\"weighted\", 2.0, 0.6548822),\n    )\n    def test_fbeta_random_score(self, average, beta, result):\n        y_pred = [[0.7, 0.7, 0.7], [1, 0, 0], [0.9, 0.8, 0]]\n        y_true = [[0, 0, 1], [1, 1, 0], [1, 1, 1]]\n        self._run_test(\n            y_true,\n            y_pred,\n            None,\n            average=average,\n            beta=beta,\n            threshold=0.66,\n            reference_result=result,\n        )\n\n    @parameterized.parameters(\n        # average, beta, result\n        (None, 0.5, [0.9090904, 0.555556, 1.0]),\n        (None, 1.0, [0.8, 0.6666667, 1.0]),\n        (None, 2.0, [0.71428573, 0.833334, 1.0]),\n        (\"micro\", 0.5, 0.833334),\n        (\"micro\", 1.0, 0.833334),\n        
(\"micro\", 2.0, 0.833334),\n        (\"macro\", 0.5, 0.821549),\n        (\"macro\", 1.0, 0.822222),\n        (\"macro\", 2.0, 0.849206),\n        (\"weighted\", 0.5, 0.880471),\n        (\"weighted\", 1.0, 0.844445),\n        (\"weighted\", 2.0, 0.829365),\n    )\n    def test_fbeta_random_score_none(self, average, beta, result):\n        y_true = [\n            [1, 0, 0],\n            [0, 1, 0],\n            [0, 0, 1],\n            [1, 0, 0],\n            [1, 0, 0],\n            [0, 0, 1],\n        ]\n        y_pred = [\n            [0.9, 0.1, 0],\n            [0.2, 0.6, 0.2],\n            [0, 0, 1],\n            [0.4, 0.3, 0.3],\n            [0, 0.9, 0.1],\n            [0, 0, 1],\n        ]\n        self._run_test(\n            y_true,\n            y_pred,\n            None,\n            average=average,\n            beta=beta,\n            threshold=None,\n            reference_result=result,\n        )\n\n    @parameterized.parameters(\n        # average, beta, sample_weights, result\n        (None, 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.909091, 0.555556, 1.0]),\n        (None, 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]),\n        (None, 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.9375, 0.714286, 1.0]),\n        (None, 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.8, 0.666667, 1.0]),\n        (None, 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]),\n        (None, 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.857143, 0.8, 1.0]),\n        (None, 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [0.714286, 0.833333, 1.0]),\n        (None, 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], [1.0, 0.0, 1.0]),\n        (None, 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], [0.789474, 0.909091, 1.0]),\n        (\"micro\", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333),\n        (\"micro\", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),\n        (\"micro\", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9),\n        (\"micro\", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333),\n        (\"micro\", 1.0, [1.0, 0.0, 1.0, 
1.0, 0.0, 1.0], 1.0),\n        (\"micro\", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9),\n        (\"micro\", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.833333),\n        (\"micro\", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),\n        (\"micro\", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.9),\n        (\"macro\", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.821549),\n        (\"macro\", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667),\n        (\"macro\", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.883929),\n        (\"macro\", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.822222),\n        (\"macro\", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667),\n        (\"macro\", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.885714),\n        (\"macro\", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.849206),\n        (\"macro\", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 0.666667),\n        (\"macro\", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.899522),\n        (\"weighted\", 0.5, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.880471),\n        (\"weighted\", 0.5, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),\n        (\"weighted\", 0.5, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.917857),\n        (\"weighted\", 1.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.844444),\n        (\"weighted\", 1.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),\n        (\"weighted\", 1.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.902857),\n        (\"weighted\", 2.0, [1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 0.829365),\n        (\"weighted\", 2.0, [1.0, 0.0, 1.0, 1.0, 0.0, 1.0], 1.0),\n        (\"weighted\", 2.0, [0.5, 1.0, 1.0, 1.0, 0.5, 1.0], 0.897608),\n    )\n    def test_fbeta_weighted_random_score_none(\n        self, average, beta, sample_weights, result\n    ):\n        y_true = [\n            [1, 0, 0],\n            [0, 1, 0],\n            [0, 0, 1],\n            [1, 0, 0],\n            [1, 0, 0],\n            [0, 0, 1],\n        ]\n        y_pred = [\n            [0.9, 0.1, 0],\n            [0.2, 0.6, 0.2],\n            [0, 0, 1],\n            [0.4, 0.3, 0.3],\n            [0, 0.9, 0.1],\n           
 [0, 0, 1],\n        ]\n        self._run_test(\n            y_true,\n            y_pred,\n            sample_weights,\n            average=average,\n            beta=beta,\n            threshold=None,\n            reference_result=result,\n        )\n\n    def test_invalid_average_raises_value_error(self):\n        expected_message = (\n            \"Invalid `average` argument value. Expected one of: \"\n            r\"\\{None, 'micro', 'macro', 'weighted'\\}. \"\n            \"Received: average=invalid_average\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_message):\n            f_score_metrics.FBetaScore(\n                average=\"invalid_average\",\n                beta=1.0,\n                threshold=None,\n                dtype=\"float32\",\n            )\n\n    def test_beta_integer_type_raises_value_error(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid `beta` argument value. It should be a Python float.\",\n        ):\n            f_score_metrics.FBetaScore(\n                average=\"macro\", beta=1, threshold=None, dtype=\"float32\"\n            )\n\n    def test_beta_string_type_raises_value_error(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid `beta` argument value. It should be a Python float.\",\n        ):\n            f_score_metrics.FBetaScore(\n                average=\"macro\", beta=\"1.0\", threshold=None, dtype=\"float32\"\n            )\n\n    def test_beta_none_type_raises_value_error(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid `beta` argument value. It should be a Python float.\",\n        ):\n            f_score_metrics.FBetaScore(\n                average=\"macro\", beta=None, threshold=None, dtype=\"float32\"\n            )\n\n    def test_beta_zero_raises_value_error(self):\n        expected_message = (\n            \"Invalid `beta` argument value. It should be > 0. 
\"\n            \"Received: beta=0.0\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_message):\n            f_score_metrics.FBetaScore(\n                average=\"macro\", beta=0.0, threshold=None, dtype=\"float32\"\n            )\n\n    def test_beta_negative_one_raises_value_error(self):\n        expected_message = (\n            \"Invalid `beta` argument value. It should be > 0. \"\n            \"Received: beta=-1.0\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_message):\n            f_score_metrics.FBetaScore(\n                average=\"macro\", beta=-1.0, threshold=None, dtype=\"float32\"\n            )\n\n    def test_beta_negative_half_raises_value_error(self):\n        expected_message = (\n            \"Invalid `beta` argument value. It should be > 0. \"\n            \"Received: beta=-0.5\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_message):\n            f_score_metrics.FBetaScore(\n                average=\"macro\", beta=-0.5, threshold=None, dtype=\"float32\"\n            )\n\n    def test_threshold_not_float_raises_value_error(self):\n        expected_message_pattern = (\n            \"Invalid `threshold` argument value. \"\n            \"It should be a Python float. \"\n            \"Received: threshold=1 of type '<class 'int'>'\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_message_pattern):\n            f_score_metrics.FBetaScore(\n                average=\"macro\", beta=1.0, threshold=1, dtype=\"float32\"\n            )\n\n    def test_threshold_string_raises_value_error(self):\n        expected_message_pattern = (\n            \"Invalid `threshold` argument value. \"\n            \"It should be a Python float. 
\"\n            \"Received: threshold=0.5 of type '<class 'str'>'\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_message_pattern):\n            f_score_metrics.FBetaScore(\n                average=\"macro\", beta=1.0, threshold=\"0.5\", dtype=\"float32\"\n            )\n\n    def test_threshold_above_one_raises_value_error(self):\n        expected_message = (\n            \"Invalid `threshold` argument value. \"\n            \"It should verify 0 < threshold <= 1. \"\n            \"Received: threshold=1.1\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_message):\n            f_score_metrics.FBetaScore(\n                average=\"macro\", beta=1.0, threshold=1.1, dtype=\"float32\"\n            )\n\n    def test_threshold_zero_raises_value_error(self):\n        expected_message = (\n            \"Invalid `threshold` argument value. \"\n            \"It should verify 0 < threshold <= 1. \"\n            \"Received: threshold=0.0\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_message):\n            f_score_metrics.FBetaScore(\n                average=\"macro\", beta=1.0, threshold=0.0, dtype=\"float32\"\n            )\n\n    def test_threshold_negative_raises_value_error(self):\n        expected_message = (\n            \"Invalid `threshold` argument value. \"\n            \"It should verify 0 < threshold <= 1. 
\"\n            \"Received: threshold=-0.5\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_message):\n            f_score_metrics.FBetaScore(\n                average=\"macro\", beta=1.0, threshold=-0.5, dtype=\"float32\"\n            )\n\n    def test_non_2d_input_shapes_raises_value_error(self):\n        fbeta = f_score_metrics.FBetaScore(beta=1.0, dtype=\"float32\")\n        y_true_shape = (2, 3, 4)\n        y_pred_shape = (2, 3, 4)\n        expected_error_message = (\n            \"FBetaScore expects 2D inputs with shape \"\n            r\"\\(batch_size, output_dim\\)\\. Received input \"\n            r\"shapes: y_pred\\.shape=\\(2, 3, 4\\) and \"\n            r\"y_true\\.shape=\\(2, 3, 4\\)\\.\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_error_message):\n            fbeta._build(y_true_shape, y_pred_shape)\n\n    def test_undefined_output_dim_raises_value_error(self):\n        fbeta = f_score_metrics.FBetaScore(beta=1.0, dtype=\"float32\")\n        y_true_shape = (2, None)\n        y_pred_shape = (2, None)\n        expected_error_message = (\n            \"FBetaScore expects 2D inputs with shape \"\n            r\"\\(batch_size, output_dim\\), with output_dim fully \"\n            r\"defined \\(not None\\)\\. 
Received input \"\n            r\"shapes: y_pred\\.shape=\\(2, None\\) and \"\n            r\"y_true\\.shape=\\(2, None\\)\\.\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_error_message):\n            fbeta._build(y_true_shape, y_pred_shape)\n\n\nclass F1ScoreTest(testing.TestCase):\n    def test_config(self):\n        f1_obj = f_score_metrics.F1Score(dtype=\"float32\")\n        config = f1_obj.get_config()\n        self.assertNotIn(\"beta\", config)\n\n        # Check save and restore config\n        f1_obj = f_score_metrics.F1Score.from_config(config)\n        self.assertEqual(f1_obj.average, None)\n        self.assertEqual(f1_obj.dtype, \"float32\")\n\n    def test_correctness(self):\n        f1 = f_score_metrics.F1Score()\n        fbeta = f_score_metrics.FBetaScore(beta=1.0)\n\n        y_true = np.array(\n            [\n                [1, 0, 0],\n                [0, 1, 0],\n                [0, 0, 1],\n                [1, 0, 0],\n                [1, 0, 0],\n                [0, 0, 1],\n            ]\n        )\n        y_pred = np.array(\n            [\n                [0.9, 0.1, 0],\n                [0.2, 0.6, 0.2],\n                [0, 0, 1],\n                [0.4, 0.3, 0.3],\n                [0, 0.9, 0.1],\n                [0, 0, 1],\n            ]\n        )\n\n        fbeta.update_state(y_true, y_pred)\n        f1.update_state(y_true, y_pred)\n        self.assertAllClose(fbeta.result(), f1.result(), atol=1e-6)\n"
  },
  {
    "path": "keras/src/metrics/hinge_metrics.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.losses.losses import categorical_hinge\nfrom keras.src.losses.losses import hinge\nfrom keras.src.losses.losses import squared_hinge\nfrom keras.src.metrics import reduction_metrics\n\n\n@keras_export(\"keras.metrics.Hinge\")\nclass Hinge(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes the hinge metric between `y_true` and `y_pred`.\n\n    `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are\n    provided we will convert them to -1 or 1.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Examples:\n\n    >>> m = keras.metrics.Hinge()\n    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])\n    >>> m.result()\n    1.3\n    >>> m.reset_state()\n    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],\n    ...                sample_weight=[1, 0])\n    >>> m.result()\n    1.1\n    \"\"\"\n\n    def __init__(self, name=\"hinge\", dtype=None):\n        super().__init__(fn=hinge, name=name, dtype=dtype)\n        # Metric should be minimized during optimization.\n        self._direction = \"down\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n@keras_export(\"keras.metrics.SquaredHinge\")\nclass SquaredHinge(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes the hinge metric between `y_true` and `y_pred`.\n\n    `y_true` values are expected to be -1 or 1. 
If binary (0 or 1) labels are\n    provided we will convert them to -1 or 1.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.SquaredHinge()\n    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])\n    >>> m.result()\n    1.86\n    >>> m.reset_state()\n    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],\n    ...                sample_weight=[1, 0])\n    >>> m.result()\n    1.46\n    \"\"\"\n\n    def __init__(self, name=\"squared_hinge\", dtype=None):\n        super().__init__(fn=squared_hinge, name=name, dtype=dtype)\n        # Metric should be minimized during optimization.\n        self._direction = \"down\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n@keras_export(\"keras.metrics.CategoricalHinge\")\nclass CategoricalHinge(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes the categorical hinge metric between `y_true` and `y_pred`.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.CategoricalHinge()\n    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])\n    >>> m.result()\n    1.4000001\n    >>> m.reset_state()\n    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],\n    ...                sample_weight=[1, 0])\n    >>> m.result()\n    1.2\n    \"\"\"\n\n    def __init__(self, name=\"categorical_hinge\", dtype=None):\n        super().__init__(fn=categorical_hinge, name=name, dtype=dtype)\n        # Metric should be minimized during optimization.\n        self._direction = \"down\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n"
  },
  {
    "path": "keras/src/metrics/hinge_metrics_test.py",
    "content": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.metrics import hinge_metrics\n\n\nclass HingeTest(testing.TestCase):\n    def test_config(self):\n        hinge_obj = hinge_metrics.Hinge(name=\"hinge\", dtype=\"int32\")\n        self.assertEqual(hinge_obj.name, \"hinge\")\n        self.assertEqual(hinge_obj._dtype, \"int32\")\n\n        # Check save and restore config\n        hinge_obj2 = hinge_metrics.Hinge.from_config(hinge_obj.get_config())\n        self.assertEqual(hinge_obj2.name, \"hinge\")\n        self.assertEqual(len(hinge_obj2.variables), 2)\n        self.assertEqual(hinge_obj2._dtype, \"int32\")\n\n    def test_unweighted(self):\n        hinge_obj = hinge_metrics.Hinge()\n        y_true = np.array([[0.0, 1.0, 0.0, 1.0], [0.0, 0.0, 1.0, 1.0]])\n        y_pred = np.array([[-0.3, 0.2, -0.1, 1.6], [-0.25, -1.0, 0.5, 0.6]])\n        hinge_obj.update_state(y_true, y_pred)\n        result = hinge_obj.result()\n        self.assertAllClose(0.506, result, atol=1e-3)\n\n    def test_weighted(self):\n        hinge_obj = hinge_metrics.Hinge()\n        y_true = np.array([[-1, 1, -1, 1], [-1, -1, 1, 1]])\n        y_pred = np.array([[-0.3, 0.2, -0.1, 1.6], [-0.25, -1.0, 0.5, 0.6]])\n        sample_weight = np.array([1.5, 2.0])\n        result = hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(0.493, result, atol=1e-3)\n\n\nclass SquaredHingeTest(testing.TestCase):\n    def test_config(self):\n        sq_hinge_obj = hinge_metrics.SquaredHinge(\n            name=\"squared_hinge\", dtype=\"int32\"\n        )\n        self.assertEqual(sq_hinge_obj.name, \"squared_hinge\")\n        self.assertEqual(sq_hinge_obj._dtype, \"int32\")\n\n        # Check save and restore config\n        sq_hinge_obj2 = hinge_metrics.SquaredHinge.from_config(\n            sq_hinge_obj.get_config()\n        )\n        self.assertEqual(sq_hinge_obj2.name, \"squared_hinge\")\n        self.assertEqual(len(sq_hinge_obj2.variables), 
2)\n        self.assertEqual(sq_hinge_obj2._dtype, \"int32\")\n\n    def test_unweighted(self):\n        sq_hinge_obj = hinge_metrics.SquaredHinge()\n        y_true = np.array([[0, 1, 0, 1], [0, 0, 1, 1]], dtype=\"float32\")\n        y_pred = np.array([[-0.3, 0.2, -0.1, 1.6], [-0.25, -1.0, 0.5, 0.6]])\n        sq_hinge_obj.update_state(y_true, y_pred)\n        result = sq_hinge_obj.result()\n        self.assertAllClose(0.364, result, atol=1e-3)\n\n    def test_weighted(self):\n        sq_hinge_obj = hinge_metrics.SquaredHinge()\n        y_true = np.array([[-1, 1, -1, 1], [-1, -1, 1, 1]], dtype=\"float32\")\n        y_pred = np.array([[-0.3, 0.2, -0.1, 1.6], [-0.25, -1.0, 0.5, 0.6]])\n        sample_weight = np.array([1.5, 2.0])\n        result = sq_hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(0.347, result, atol=1e-3)\n\n\nclass CategoricalHingeTest(testing.TestCase):\n    def test_config(self):\n        cat_hinge_obj = hinge_metrics.CategoricalHinge(\n            name=\"cat_hinge\", dtype=\"int32\"\n        )\n        self.assertEqual(cat_hinge_obj.name, \"cat_hinge\")\n        self.assertEqual(cat_hinge_obj._dtype, \"int32\")\n\n        # Check save and restore config\n        cat_hinge_obj2 = hinge_metrics.CategoricalHinge.from_config(\n            cat_hinge_obj.get_config()\n        )\n        self.assertEqual(cat_hinge_obj2.name, \"cat_hinge\")\n        self.assertEqual(len(cat_hinge_obj2.variables), 2)\n        self.assertEqual(cat_hinge_obj2._dtype, \"int32\")\n\n    def test_unweighted(self):\n        cat_hinge_obj = hinge_metrics.CategoricalHinge()\n        y_true = np.array(\n            (\n                (0, 1, 0, 1, 0),\n                (0, 0, 1, 1, 1),\n                (1, 1, 1, 1, 0),\n                (0, 0, 0, 0, 1),\n            ),\n            dtype=\"float32\",\n        )\n        y_pred = np.array(\n            (\n                (0, 0, 1, 1, 0),\n                (1, 1, 1, 1, 1),\n                (0, 1, 0, 
1, 0),\n                (1, 1, 1, 1, 1),\n            ),\n            dtype=\"float32\",\n        )\n        cat_hinge_obj.update_state(y_true, y_pred)\n        result = cat_hinge_obj.result()\n        self.assertAllClose(0.5, result, atol=1e-5)\n\n    def test_weighted(self):\n        cat_hinge_obj = hinge_metrics.CategoricalHinge()\n        y_true = np.array(\n            (\n                (0, 1, 0, 1, 0),\n                (0, 0, 1, 1, 1),\n                (1, 1, 1, 1, 0),\n                (0, 0, 0, 0, 1),\n            ),\n            dtype=\"float32\",\n        )\n        y_pred = np.array(\n            (\n                (0, 0, 1, 1, 0),\n                (1, 1, 1, 1, 1),\n                (0, 1, 0, 1, 0),\n                (1, 1, 1, 1, 1),\n            ),\n            dtype=\"float32\",\n        )\n        sample_weight = np.array((1.0, 1.5, 2.0, 2.5))\n        result = cat_hinge_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(0.5, result, atol=1e-5)\n"
  },
  {
    "path": "keras/src/metrics/iou_metrics.py",
    "content": "import warnings\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.metrics.metric import Metric\nfrom keras.src.metrics.metrics_utils import confusion_matrix\n\n\nclass _IoUBase(Metric):\n    \"\"\"Computes the confusion matrix for Intersection-Over-Union metrics.\n\n    Formula:\n\n    ```python\n    iou = true_positives / (true_positives + false_positives + false_negatives)\n    ```\n    Intersection-Over-Union is a common evaluation metric for semantic image\n    segmentation.\n\n    From IoUs of individual classes, the MeanIoU can be computed as the mean of\n    the individual IoUs.\n\n    To compute IoUs, the predictions are accumulated in a confusion matrix,\n    weighted by `sample_weight` and the metric is then calculated from it.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    Args:\n        num_classes: The possible number of labels the prediction task can have.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        ignore_class: Optional integer. The ID of a class to be ignored during\n            metric computation. This is useful, for example, in segmentation\n            problems featuring a \"void\" class (commonly -1 or 255) in\n            segmentation maps. By default (`ignore_class=None`), all classes are\n            considered.\n        sparse_y_true: Whether labels are encoded using integers or\n            dense floating point vectors. If `False`, the `argmax` function\n            is used to determine each sample's most likely associated label.\n        sparse_y_pred: Whether predictions are encoded using integers or\n            dense floating point vectors. 
If `False`, the `argmax` function\n            is used to determine each sample's most likely associated label.\n        axis: (Optional) The dimension containing the logits. Defaults to `-1`.\n    \"\"\"\n\n    def __init__(\n        self,\n        num_classes,\n        name=None,\n        dtype=None,\n        ignore_class=None,\n        sparse_y_true=True,\n        sparse_y_pred=True,\n        axis=-1,\n    ):\n        # defaulting to int to avoid issues with confusion matrix\n        super().__init__(name=name, dtype=dtype or \"int\")\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n        self.num_classes = num_classes\n        self.ignore_class = ignore_class\n        self.sparse_y_true = sparse_y_true\n        self.sparse_y_pred = sparse_y_pred\n        self.axis = axis\n\n        self.total_cm = self.add_variable(\n            name=\"total_confusion_matrix\",\n            shape=(num_classes, num_classes),\n            initializer=initializers.Zeros(),\n            dtype=self.dtype,\n        )\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        \"\"\"Accumulates the confusion matrix statistics.\n\n        Args:\n            y_true: The ground truth values.\n            y_pred: The predicted values.\n            sample_weight: Optional weighting of each example. Can\n                be a `Tensor` whose rank is either 0, or the same as `y_true`,\n                and must be broadcastable to `y_true`. 
Defaults to `1`.\n\n        Returns:\n            Update op.\n        \"\"\"\n\n        if not self.sparse_y_true:\n            y_true = ops.argmax(y_true, axis=self.axis)\n        if not self.sparse_y_pred:\n            y_pred = ops.argmax(y_pred, axis=self.axis)\n\n        y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)\n        y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype)\n\n        # Flatten the input if its rank > 1.\n        if len(y_pred.shape) > 1:\n            y_pred = ops.reshape(y_pred, [-1])\n\n        if len(y_true.shape) > 1:\n            y_true = ops.reshape(y_true, [-1])\n\n        if sample_weight is None:\n            sample_weight = 1\n        else:\n            if (\n                hasattr(sample_weight, \"dtype\")\n                and \"float\" in str(sample_weight.dtype)\n                and \"int\" in str(self.dtype)\n            ):\n                warnings.warn(\n                    \"You are passing `sample_weight` as `float`, but dtype \"\n                    \"is `int`. This may result in an incorrect weight due to \"\n                    \"type casting. Consider using integer weights.\"\n                )\n        sample_weight = ops.convert_to_tensor(sample_weight, dtype=self.dtype)\n\n        if len(sample_weight.shape) > 1:\n            sample_weight = ops.reshape(sample_weight, [-1])\n\n        sample_weight = ops.broadcast_to(sample_weight, ops.shape(y_true))\n\n        if self.ignore_class is not None:\n            ignore_class = ops.convert_to_tensor(\n                self.ignore_class, y_true.dtype\n            )\n            valid_mask = ops.not_equal(y_true, ignore_class)\n            y_true = y_true * ops.cast(valid_mask, y_true.dtype)\n            y_pred = y_pred * ops.cast(valid_mask, y_pred.dtype)\n            if sample_weight is not None:\n                sample_weight = sample_weight * ops.cast(\n                    valid_mask, sample_weight.dtype\n                )\n\n        y_pred = ops.cast(y_pred, 
dtype=self.dtype)\n        y_true = ops.cast(y_true, dtype=self.dtype)\n        sample_weight = ops.cast(sample_weight, dtype=self.dtype)\n\n        current_cm = confusion_matrix(\n            y_true,\n            y_pred,\n            self.num_classes,\n            weights=sample_weight,\n            dtype=self.dtype,\n        )\n\n        return self.total_cm.assign(self.total_cm + current_cm)\n\n    def reset_state(self):\n        self.total_cm.assign(\n            ops.zeros(self.total_cm.shape, dtype=self.total_cm.dtype)\n        )\n\n\n@keras_export(\"keras.metrics.IoU\")\nclass IoU(_IoUBase):\n    \"\"\"Computes the Intersection-Over-Union metric for specific target classes.\n\n    Formula:\n\n    ```python\n    iou = true_positives / (true_positives + false_positives + false_negatives)\n    ```\n    Intersection-Over-Union is a common evaluation metric for semantic image\n    segmentation.\n\n    To compute IoUs, the predictions are accumulated in a confusion matrix,\n    weighted by `sample_weight` and the metric is then calculated from it.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    Note, this class first computes IoUs for all individual classes, then\n    returns the mean of IoUs for the classes that are specified by\n    `target_class_ids`. If `target_class_ids` has only one id value, the IoU of\n    that specific class is returned.\n\n    Args:\n        num_classes: The possible number of labels the prediction task can have.\n        target_class_ids: A tuple or list of target class ids for which the\n            metric is returned. To compute IoU for a specific class, a list\n            (or tuple) of a single id value should be provided.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        ignore_class: Optional integer. The ID of a class to be ignored during\n            metric computation. 
This is useful, for example, in segmentation\n            problems featuring a \"void\" class (commonly -1 or 255) in\n            segmentation maps. By default (`ignore_class=None`), all classes are\n            considered.\n        sparse_y_true: Whether labels are encoded using integers or\n            dense floating point vectors. If `False`, the `argmax` function\n            is used to determine each sample's most likely associated label.\n        sparse_y_pred: Whether predictions are encoded using integers or\n            dense floating point vectors. If `False`, the `argmax` function\n            is used to determine each sample's most likely associated label.\n        axis: (Optional) The dimension containing the logits. Defaults to `-1`.\n\n    Examples:\n\n    >>> # cm = [[1, 1],\n    >>> #        [1, 1]]\n    >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]\n    >>> # iou = true_positives / (sum_row + sum_col - true_positives))\n    >>> # iou = [0.33, 0.33]\n    >>> m = keras.metrics.IoU(num_classes=2, target_class_ids=[0])\n    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])\n    >>> m.result()\n    0.33333334\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],\n    ...                
sample_weight=[0.3, 0.3, 0.3, 0.1])\n    >>> # cm = [[0.3, 0.3],\n    >>> #        [0.3, 0.1]]\n    >>> # sum_row = [0.6, 0.4], sum_col = [0.6, 0.4],\n    >>> # true_positives = [0.3, 0.1]\n    >>> # iou = [0.33, 0.14]\n    >>> m.result()\n    0.33333334\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.IoU(num_classes=2, target_class_ids=[0])])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        num_classes,\n        target_class_ids,\n        name=None,\n        dtype=None,\n        ignore_class=None,\n        sparse_y_true=True,\n        sparse_y_pred=True,\n        axis=-1,\n    ):\n        super().__init__(\n            name=name,\n            num_classes=num_classes,\n            ignore_class=ignore_class,\n            sparse_y_true=sparse_y_true,\n            sparse_y_pred=sparse_y_pred,\n            axis=axis,\n            dtype=dtype,\n        )\n        if max(target_class_ids) >= num_classes:\n            raise ValueError(\n                f\"Target class id {max(target_class_ids)} \"\n                \"is out of range, which is \"\n                f\"[{0}, {num_classes}).\"\n            )\n        self.target_class_ids = list(target_class_ids)\n\n    def result(self):\n        \"\"\"Compute the intersection-over-union via the confusion matrix.\"\"\"\n        sum_over_row = ops.cast(\n            ops.sum(self.total_cm, axis=0), dtype=self.dtype\n        )\n        sum_over_col = ops.cast(\n            ops.sum(self.total_cm, axis=1), dtype=self.dtype\n        )\n        true_positives = ops.cast(ops.diag(self.total_cm), dtype=self.dtype)\n\n        # sum_over_row + sum_over_col =\n        #     2 * true_positives + false_positives + false_negatives.\n        denominator = sum_over_row + sum_over_col - true_positives\n\n        target_class_ids = ops.convert_to_tensor(\n            self.target_class_ids, dtype=\"int32\"\n        )\n\n        # 
Only keep the target classes\n        true_positives = ops.take_along_axis(\n            true_positives, target_class_ids, axis=-1\n        )\n        denominator = ops.take_along_axis(\n            denominator, target_class_ids, axis=-1\n        )\n        denominator = ops.cast(denominator, dtype=\"float32\")\n\n        # If the denominator is 0, we need to ignore the class.\n        num_valid_entries = ops.sum(\n            ops.cast(ops.greater(denominator, 1e-9), dtype=\"float32\")\n        )\n\n        iou = ops.divide(true_positives, denominator + backend.epsilon())\n\n        return ops.divide(\n            ops.sum(iou, axis=self.axis), num_valid_entries + backend.epsilon()\n        )\n\n    def get_config(self):\n        config = {\n            \"num_classes\": self.num_classes,\n            \"target_class_ids\": self.target_class_ids,\n            \"ignore_class\": self.ignore_class,\n            \"sparse_y_true\": self.sparse_y_true,\n            \"sparse_y_pred\": self.sparse_y_pred,\n            \"axis\": self.axis,\n        }\n        base_config = super().get_config()\n        return dict(list(base_config.items()) + list(config.items()))\n\n\n@keras_export(\"keras.metrics.BinaryIoU\")\nclass BinaryIoU(IoU):\n    \"\"\"Computes the Intersection-Over-Union metric for class 0 and/or 1.\n\n    Formula:\n\n    ```python\n    iou = true_positives / (true_positives + false_positives + false_negatives)\n    ```\n    Intersection-Over-Union is a common evaluation metric for semantic image\n    segmentation.\n\n    To compute IoUs, the predictions are accumulated in a confusion matrix,\n    weighted by `sample_weight` and the metric is then calculated from it.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    This class can be used to compute IoUs for a binary classification task\n    where the predictions are provided as logits. 
First a `threshold` is applied\n    to the predicted values such that those that are below the `threshold` are\n    converted to class 0 and those that are above the `threshold` are converted\n    to class 1.\n\n    IoUs for classes 0 and 1 are then computed, and the mean of IoUs for the\n    classes specified by `target_class_ids` is returned.\n\n    Note: with `threshold=0`, this metric has the same behavior as `IoU`.\n\n    Args:\n        target_class_ids: A tuple or list of target class ids for which the\n            metric is returned. Options are `[0]`, `[1]`, or `[0, 1]`. With\n            `[0]` (or `[1]`), the IoU metric for class 0 (or class 1,\n            respectively) is returned. With `[0, 1]`, the mean of IoUs for the\n            two classes is returned.\n        threshold: A threshold that applies to the prediction logits to convert\n            them to either predicted class 0 if the logit is below `threshold`\n            or predicted class 1 if the logit is above `threshold`.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = keras.metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3)\n    >>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7])\n    >>> m.result()\n    0.33333334\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 1, 0, 1], [0.1, 0.2, 0.4, 0.7],\n    ...                
sample_weight=[0.2, 0.3, 0.4, 0.1])\n    >>> # cm = [[0.2, 0.4],\n    >>> #        [0.3, 0.1]]\n    >>> # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5],\n    >>> # true_positives = [0.2, 0.1]\n    >>> # iou = [0.222, 0.125]\n    >>> m.result()\n    0.17361112\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.BinaryIoU(\n            target_class_ids=[0],\n            threshold=0.5\n        )]\n    )\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        target_class_ids=(0, 1),\n        threshold=0.5,\n        name=None,\n        dtype=None,\n    ):\n        super().__init__(\n            num_classes=2,\n            target_class_ids=target_class_ids,\n            name=name,\n            dtype=dtype,\n        )\n        self.threshold = threshold\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        \"\"\"Accumulates the confusion matrix statistics.\n\n        Before the confusion matrix is updated, the predicted values are\n        thresholded to be:\n            0 for values that are smaller than the `threshold`\n            1 for values that are larger or equal to the `threshold`\n\n        Args:\n            y_true: The ground truth values.\n            y_pred: The predicted values.\n            sample_weight: Optional weighting of each example. Can\n                be a `Tensor` whose rank is either 0, or the same as `y_true`,\n                and must be broadcastable to `y_true`. 
Defaults to `1`.\n\n        Returns:\n            Update op.\n        \"\"\"\n        y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)\n        # Convert y_pred to float32, then threshold and cast to self.dtype.\n        y_pred = ops.convert_to_tensor(y_pred, dtype=\"float32\")\n        y_pred = ops.cast(y_pred >= self.threshold, self.dtype)\n        return super().update_state(y_true, y_pred, sample_weight)\n\n    def get_config(self):\n        return {\n            \"target_class_ids\": self.target_class_ids,\n            \"threshold\": self.threshold,\n            \"name\": self.name,\n            \"dtype\": self._dtype,\n        }\n\n\n@keras_export(\"keras.metrics.MeanIoU\")\nclass MeanIoU(IoU):\n    \"\"\"Computes the mean Intersection-Over-Union metric.\n\n    Formula:\n\n    ```python\n    iou = true_positives / (true_positives + false_positives + false_negatives)\n    ```\n    Intersection-Over-Union is a common evaluation metric for semantic image\n    segmentation.\n\n    To compute IoUs, the predictions are accumulated in a confusion matrix,\n    weighted by `sample_weight` and the metric is then calculated from it.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    Note that this class first computes IoUs for all individual classes, then\n    returns the mean of these values.\n\n    Args:\n        num_classes: The possible number of labels the prediction task can have.\n            This value must be provided, since a confusion matrix of dimension =\n            [num_classes, num_classes] will be allocated.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        ignore_class: Optional integer. The ID of a class to be ignored during\n            metric computation. This is useful, for example, in segmentation\n            problems featuring a \"void\" class (commonly -1 or 255) in\n            segmentation maps. 
By default (`ignore_class=None`), all classes are\n            considered.\n        sparse_y_true: Whether labels are encoded using integers or\n            dense floating point vectors. If `False`, the `argmax` function\n            is used to determine each sample's most likely associated label.\n        sparse_y_pred: Whether predictions are encoded using integers or\n            dense floating point vectors. If `False`, the `argmax` function\n            is used to determine each sample's most likely associated label.\n        axis: (Optional) The dimension containing the logits. Defaults to `-1`.\n\n\n    Example:\n\n    >>> # cm = [[1, 1],\n    >>> #        [1, 1]]\n    >>> # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]\n    >>> # iou = true_positives / (sum_row + sum_col - true_positives))\n    >>> # result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2 = 0.33\n    >>> m = keras.metrics.MeanIoU(num_classes=2)\n    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1])\n    >>> m.result()\n    0.33333334\n\n    >>> m.reset_state()\n    >>> m.update_state([0, 0, 1, 1], [0, 1, 0, 1],\n    ...                
sample_weight=[0.3, 0.3, 0.3, 0.1])\n    >>> m.result().numpy()\n    0.23809525\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.MeanIoU(num_classes=2)])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        num_classes,\n        name=None,\n        dtype=None,\n        ignore_class=None,\n        sparse_y_true=True,\n        sparse_y_pred=True,\n        axis=-1,\n    ):\n        target_class_ids = list(range(num_classes))\n        super().__init__(\n            name=name,\n            num_classes=num_classes,\n            target_class_ids=target_class_ids,\n            axis=axis,\n            dtype=dtype,\n            ignore_class=ignore_class,\n            sparse_y_true=sparse_y_true,\n            sparse_y_pred=sparse_y_pred,\n        )\n\n    def get_config(self):\n        return {\n            \"num_classes\": self.num_classes,\n            \"name\": self.name,\n            \"dtype\": self._dtype,\n            \"ignore_class\": self.ignore_class,\n            \"sparse_y_true\": self.sparse_y_true,\n            \"sparse_y_pred\": self.sparse_y_pred,\n            \"axis\": self.axis,\n        }\n\n\n@keras_export(\"keras.metrics.OneHotIoU\")\nclass OneHotIoU(IoU):\n    \"\"\"Computes the Intersection-Over-Union metric for one-hot encoded labels.\n\n    Formula:\n\n    ```python\n    iou = true_positives / (true_positives + false_positives + false_negatives)\n    ```\n    Intersection-Over-Union is a common evaluation metric for semantic image\n    segmentation.\n\n    To compute IoUs, the predictions are accumulated in a confusion matrix,\n    weighted by `sample_weight` and the metric is then calculated from it.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    This class can be used to compute IoU for multi-class classification tasks\n    where the labels are one-hot encoded (the last 
axis should have one\n    dimension per class). Note that the predictions should also have the same\n    shape. To compute the IoU, first the labels and predictions are converted\n    back into integer format by taking the argmax over the class axis. Then the\n    same computation steps as for the base `IoU` class apply.\n\n    Note, if there is only one channel in the labels and predictions, this class\n    is the same as class `IoU`. In this case, use `IoU` instead.\n\n    Also, make sure that `num_classes` is equal to the number of classes in the\n    data, to avoid a \"labels out of bound\" error when the confusion matrix is\n    computed.\n\n    Args:\n        num_classes: The possible number of labels the prediction task can have.\n        target_class_ids: A tuple or list of target class ids for which the\n            metric is returned. To compute IoU for a specific class, a list\n            (or tuple) of a single id value should be provided.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        ignore_class: Optional integer. The ID of a class to be ignored during\n            metric computation. This is useful, for example, in segmentation\n            problems featuring a \"void\" class (commonly -1 or 255) in\n            segmentation maps. By default (`ignore_class=None`), all classes are\n            considered.\n        sparse_y_pred: Whether predictions are encoded using integers or\n            dense floating point vectors. If `False`, the `argmax` function\n            is used to determine each sample's most likely associated label.\n        axis: (Optional) The dimension containing the logits. Defaults to `-1`.\n\n\n    Example:\n\n    >>> y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])\n    >>> y_pred = np.array([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1],\n    ...                       
[0.1, 0.4, 0.5]])\n    >>> sample_weight = [0.1, 0.2, 0.3, 0.4]\n    >>> m = keras.metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2])\n    >>> m.update_state(\n    ...     y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)\n    >>> # cm = [[0, 0, 0.2+0.4],\n    >>> #       [0.3, 0, 0],\n    >>> #       [0, 0, 0.1]]\n    >>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]\n    >>> # true_positives = [0, 0, 0.1]\n    >>> # single_iou = true_positives / (sum_row + sum_col - true_positives))\n    >>> # mean_iou = (0 / (0.3 + 0.6 - 0) + 0.1 / (0.7 + 0.1 - 0.1)) / 2\n    >>> m.result()\n    0.071\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.OneHotIoU(\n            num_classes=3,\n            target_class_ids=[1]\n        )]\n    )\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        num_classes,\n        target_class_ids,\n        name=None,\n        dtype=None,\n        ignore_class=None,\n        sparse_y_pred=False,\n        axis=-1,\n    ):\n        super().__init__(\n            num_classes=num_classes,\n            target_class_ids=target_class_ids,\n            name=name,\n            dtype=dtype,\n            ignore_class=ignore_class,\n            sparse_y_true=False,\n            sparse_y_pred=sparse_y_pred,\n            axis=axis,\n        )\n\n    def get_config(self):\n        return {\n            \"num_classes\": self.num_classes,\n            \"target_class_ids\": self.target_class_ids,\n            \"name\": self.name,\n            \"dtype\": self._dtype,\n            \"ignore_class\": self.ignore_class,\n            \"sparse_y_pred\": self.sparse_y_pred,\n            \"axis\": self.axis,\n        }\n\n\n@keras_export(\"keras.metrics.OneHotMeanIoU\")\nclass OneHotMeanIoU(MeanIoU):\n    \"\"\"Computes mean Intersection-Over-Union metric for one-hot encoded labels.\n\n    Formula:\n\n    ```python\n    iou = 
true_positives / (true_positives + false_positives + false_negatives)\n    ```\n    Intersection-Over-Union is a common evaluation metric for semantic image\n    segmentation.\n\n    To compute IoUs, the predictions are accumulated in a confusion matrix,\n    weighted by `sample_weight` and the metric is then calculated from it.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use `sample_weight` of 0 to mask values.\n\n    This class can be used to compute the mean IoU for multi-class\n    classification tasks where the labels are one-hot encoded (the last axis\n    should have one dimension per class). Note that the predictions should also\n    have the same shape. To compute the mean IoU, first the labels and\n    predictions are converted back into integer format by taking the argmax over\n    the class axis. Then the same computation steps as for the base `MeanIoU`\n    class apply.\n\n    Note, if there is only one channel in the labels and predictions, this class\n    is the same as class `MeanIoU`. In this case, use `MeanIoU` instead.\n\n    Also, make sure that `num_classes` is equal to the number of classes in the\n    data, to avoid a \"labels out of bound\" error when the confusion matrix is\n    computed.\n\n    Args:\n        num_classes: The possible number of labels the prediction task can have.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        ignore_class: Optional integer. The ID of a class to be ignored during\n            metric computation. This is useful, for example, in segmentation\n            problems featuring a \"void\" class (commonly -1 or 255) in\n            segmentation maps. By default (`ignore_class=None`), all classes are\n            considered.\n        sparse_y_pred: Whether predictions are encoded using natural numbers or\n            probability distribution vectors. 
If `False`, the `argmax`\n            function will be used to determine each sample's most likely\n            associated label.\n        axis: (Optional) The dimension containing the logits. Defaults to `-1`.\n\n\n    Example:\n\n    >>> y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])\n    >>> y_pred = np.array([[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1],\n    ...                       [0.1, 0.4, 0.5]])\n    >>> sample_weight = [0.1, 0.2, 0.3, 0.4]\n    >>> m = keras.metrics.OneHotMeanIoU(num_classes=3)\n    >>> m.update_state(\n    ...     y_true=y_true, y_pred=y_pred, sample_weight=sample_weight)\n    >>> # cm = [[0, 0, 0.2+0.4],\n    >>> #       [0.3, 0, 0],\n    >>> #       [0, 0, 0.1]]\n    >>> # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]\n    >>> # true_positives = [0, 0, 0.1]\n    >>> # single_iou = true_positives / (sum_row + sum_col - true_positives))\n    >>> # mean_iou = (0 + 0 + 0.1 / (0.7 + 0.1 - 0.1)) / 3\n    >>> m.result()\n    0.048\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.OneHotMeanIoU(num_classes=3)])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        num_classes,\n        name=None,\n        dtype=None,\n        ignore_class=None,\n        sparse_y_pred=False,\n        axis=-1,\n    ):\n        super().__init__(\n            num_classes=num_classes,\n            axis=axis,\n            name=name,\n            dtype=dtype,\n            ignore_class=ignore_class,\n            sparse_y_true=False,\n            sparse_y_pred=sparse_y_pred,\n        )\n\n    def get_config(self):\n        return {\n            \"num_classes\": self.num_classes,\n            \"name\": self.name,\n            \"dtype\": self._dtype,\n            \"ignore_class\": self.ignore_class,\n            \"sparse_y_pred\": self.sparse_y_pred,\n            \"axis\": self.axis,\n        }\n"
  },
  {
    "path": "keras/src/metrics/iou_metrics_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src.metrics import iou_metrics as metrics\nfrom keras.src.ops import convert_to_tensor\n\n\nclass IoUTest(testing.TestCase):\n    def test_config(self):\n        obj = metrics.IoU(\n            num_classes=2, target_class_ids=[1, 0], name=\"iou_class_1_0\"\n        )\n        self.assertEqual(obj.name, \"iou_class_1_0\")\n        self.assertEqual(obj.num_classes, 2)\n        self.assertEqual(obj.target_class_ids, [1, 0])\n\n        obj2 = metrics.IoU.from_config(obj.get_config())\n        self.assertEqual(obj2.name, \"iou_class_1_0\")\n        self.assertEqual(obj2.num_classes, 2)\n        self.assertEqual(obj2.target_class_ids, [1, 0])\n\n    def test_unweighted(self):\n        y_pred = [0, 1, 0, 1]\n        y_true = [0, 0, 1, 1]\n\n        obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1])\n\n        result = obj(y_true, y_pred)\n\n        # cm = [[1, 1],\n        #       [1, 1]]\n        # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_weighted(self):\n        y_pred = np.array([0, 1, 0, 1], dtype=np.float32)\n        y_true = np.array([0, 0, 1, 1])\n        sample_weight = np.array([0.2, 0.3, 0.4, 0.1])\n\n        obj = metrics.IoU(\n            num_classes=2, target_class_ids=[1, 0], dtype=\"float32\"\n        )\n\n        result = obj(y_true, y_pred, sample_weight=sample_weight)\n\n        # cm = [[0.2, 0.3],\n        #       [0.4, 0.1]]\n        # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2,\n        # 0.1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (\n            0.1 / (0.4 + 0.5 - 0.1) + 0.2 
/ (0.6 + 0.5 - 0.2)\n        ) / 2\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_multi_dim_input(self):\n        y_pred = np.array([[0, 1], [0, 1]], dtype=np.float32)\n        y_true = np.array([[0, 0], [1, 1]])\n        sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]])\n\n        obj = metrics.IoU(\n            num_classes=2, target_class_ids=[0, 1], dtype=\"float32\"\n        )\n\n        result = obj(y_true, y_pred, sample_weight=sample_weight)\n\n        # cm = [[0.2, 0.3],\n        #       [0.4, 0.1]]\n        # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2,\n        # 0.1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (\n            0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1)\n        ) / 2\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_zero_valid_entries(self):\n        obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1])\n        self.assertAllClose(obj.result(), 0, atol=1e-3)\n\n    def test_zero_and_non_zero_entries(self):\n        y_pred = np.array([1], dtype=np.float32)\n        y_true = np.array([1])\n\n        obj = metrics.IoU(num_classes=2, target_class_ids=[0, 1])\n        result = obj(y_true, y_pred)\n\n        # cm = [[0, 0],\n        #       [0, 1]]\n        # sum_row = [0, 1], sum_col = [0, 1], true_positives = [0, 1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (1 / (1 + 1 - 1)) / 1\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    @pytest.mark.requires_trainable_backend\n    def test_compilation(self):\n        m_obj = metrics.MeanIoU(num_classes=2, ignore_class=0)\n        model = models.Sequential(\n            [\n                layers.Dense(2, activation=\"softmax\"),\n            ]\n        )\n        model.compile(optimizer=\"rmsprop\", loss=\"mse\", metrics=[m_obj])\n        model.fit(np.array([[1.0, 1.0]]), 
np.array([[1.0, 0.0]]))\n\n\nclass BinaryIoUTest(testing.TestCase):\n    def test_config(self):\n        obj = metrics.BinaryIoU(\n            target_class_ids=[1, 0], threshold=0.1, name=\"iou_class_1_0\"\n        )\n        self.assertEqual(obj.name, \"iou_class_1_0\")\n        self.assertAlmostEqual(obj.threshold, 0.1)\n        self.assertEqual(obj.target_class_ids, [1, 0])\n\n        obj2 = metrics.BinaryIoU.from_config(obj.get_config())\n        self.assertEqual(obj2.name, \"iou_class_1_0\")\n        self.assertAlmostEqual(obj2.threshold, 0.1)\n        self.assertEqual(obj2.target_class_ids, [1, 0])\n\n    def test_different_thresholds_weighted(self):\n        y_true = [0, 1, 0, 1]\n        y_pred = [0.1, 0.2, 0.4, 0.7]\n\n        sample_weight = np.array([0.2, 0.3, 0.4, 0.1])\n        # with threshold = 0.3, y_pred will be converted to [0, 0, 1, 1]\n        # cm = [[0.2, 0.4],\n        #       [0.3, 0.1]]\n        # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2,\n        # 0.1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (\n            0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1)\n        ) / 2\n        obj = metrics.BinaryIoU(\n            target_class_ids=[0, 1], threshold=0.3, dtype=\"float32\"\n        )\n        result = obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n        sample_weight = np.array([0.1, 0.2, 0.4, 0.3])\n        # with threshold = 0.5, y_pred will be converted to [0, 0, 0, 1]\n        # cm = [[0.1+0.4, 0],\n        #       [0.2, 0.3]]\n        # sum_row = [0.5, 0.5], sum_col = [0.7, 0.3], true_positives = [0.5,\n        # 0.3]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (\n            0.5 / (0.5 + 0.7 - 0.5) + 0.3 / (0.5 + 0.3 - 0.3)\n        ) / 2\n        obj = metrics.BinaryIoU(\n            target_class_ids=[0, 1], threshold=0.5, 
dtype=\"float32\"\n        )\n        result = obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_different_thresholds_unweighted(self):\n        y_true = [0, 1, 0, 1]\n        y_pred = [0.1, 0.2, 0.4, 0.7]\n\n        # with threshold = 0.3, y_pred will be converted to [0, 0, 1, 1]\n        # cm = [[1, 1],\n        #       [1, 1]]\n        # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2\n        obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.3)\n        result = obj(y_true, y_pred)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n        # with threshold = 0.5, y_pred will be converted to [0, 0, 0, 1]\n        # cm = [[2, 0],\n        #       [1, 1]]\n        # sum_row = [2, 2], sum_col = [3, 1], true_positives = [2, 1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (2 / (2 + 3 - 2) + 1 / (2 + 1 - 1)) / 2\n        obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=0.5)\n        result = obj(y_true, y_pred)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_multi_dim_input(self):\n        y_true = np.array([[0, 1], [0, 1]], dtype=np.float32)\n        y_pred = np.array([[0.1, 0.7], [0.9, 0.3]])\n        threshold = 0.4  # y_pred will become [[0, 1], [1, 0]]\n        sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]])\n        # cm = [[0.2, 0.4],\n        #       [0.1, 0.3]]\n        # sum_row = [0.6, 0.4], sum_col = [0.3, 0.7], true_positives = [0.2,\n        # 0.3]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (\n            0.2 / (0.6 + 0.3 - 0.2) + 0.3 / (0.4 + 0.7 - 0.3)\n        ) / 2\n        obj = metrics.BinaryIoU(\n            target_class_ids=[0, 1], 
threshold=threshold, dtype=\"float32\"\n        )\n        result = obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_zero_valid_entries(self):\n        obj = metrics.BinaryIoU(target_class_ids=[0, 1])\n        self.assertAllClose(obj.result(), 0, atol=1e-3)\n\n    def test_zero_and_non_zero_entries(self):\n        y_pred = np.array([0.6], dtype=np.float32)\n        threshold = 0.5\n        y_true = np.array([1])\n\n        obj = metrics.BinaryIoU(target_class_ids=[0, 1], threshold=threshold)\n        result = obj(y_true, y_pred)\n\n        # cm = [[0, 0],\n        #       [0, 1]]\n        # sum_row = [0, 1], sum_col = [0, 1], true_positives = [0, 1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = 1 / (1 + 1 - 1)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n\nclass MeanIoUTest(testing.TestCase):\n    def test_config(self):\n        m_obj = metrics.MeanIoU(num_classes=2, name=\"mean_iou\")\n        self.assertEqual(m_obj.name, \"mean_iou\")\n        self.assertEqual(m_obj.num_classes, 2)\n\n        m_obj2 = metrics.MeanIoU.from_config(m_obj.get_config())\n        self.assertEqual(m_obj2.name, \"mean_iou\")\n        self.assertEqual(m_obj2.num_classes, 2)\n\n    def test_unweighted(self):\n        y_pred = [0, 1, 0, 1]\n        y_true = [0, 0, 1, 1]\n\n        m_obj = metrics.MeanIoU(num_classes=2)\n\n        result = m_obj(y_true, y_pred)\n\n        # cm = [[1, 1],\n        #       [1, 1]]\n        # sum_row = [2, 2], sum_col = [2, 2], true_positives = [1, 1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (1 / (2 + 2 - 1) + 1 / (2 + 2 - 1)) / 2\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_unweighted_ignore_class_255(self):\n        y_pred = [0, 1, 1, 1]\n        y_true = [0, 1, 2, 255]\n\n        m_obj = 
metrics.MeanIoU(num_classes=3, ignore_class=255)\n\n        result = m_obj(y_true, y_pred)\n\n        # cm = [[1, 0, 0],\n        #       [0, 1, 0],\n        #       [0, 1, 0]]\n        # sum_row = [1, 1, 1], sum_col = [1, 2, 0], true_positives = [1, 1, 0]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (\n            1 / (1 + 1 - 1) + 1 / (2 + 1 - 1) + 0 / (0 + 1 - 0)\n        ) / 3\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_unweighted_ignore_class_1(self):\n        y_pred = [0, 1, 1, 1]\n        y_true = [0, 1, 2, -1]\n\n        m_obj = metrics.MeanIoU(num_classes=3, ignore_class=-1)\n\n        result = m_obj(y_true, y_pred)\n\n        # cm = [[1, 0, 0],\n        #       [0, 1, 0],\n        #       [0, 1, 0]]\n        # sum_row = [1, 1, 1], sum_col = [1, 2, 0], true_positives = [1, 1, 0]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (\n            1 / (1 + 1 - 1) + 1 / (2 + 1 - 1) + 0 / (0 + 1 - 0)\n        ) / 3\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_weighted(self):\n        y_pred = np.array([0, 1, 0, 1], dtype=np.float32)\n        y_true = np.array([0, 0, 1, 1])\n        sample_weight = np.array([0.2, 0.3, 0.4, 0.1])\n\n        m_obj = metrics.MeanIoU(num_classes=2, dtype=\"float32\")\n\n        result = m_obj(y_true, y_pred, sample_weight=sample_weight)\n\n        # cm = [[0.2, 0.3],\n        #       [0.4, 0.1]]\n        # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2,\n        # 0.1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (\n            0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1)\n        ) / 2\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_weighted_ignore_class_1(self):\n        y_pred = np.array([0, 1, 0, 1], dtype=np.float32)\n        y_true = np.array([0, 
0, 1, -1])\n        sample_weight = np.array([0.2, 0.3, 0.4, 0.1])\n\n        m_obj = metrics.MeanIoU(num_classes=2, ignore_class=-1, dtype=\"float32\")\n\n        result = m_obj(y_true, y_pred, sample_weight=sample_weight)\n\n        # cm = [[0.2, 0.3],\n        #       [0.4, 0.0]]\n        # sum_row = [0.6, 0.3], sum_col = [0.5, 0.4], true_positives = [0.2,\n        # 0.0]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (\n            0.2 / (0.6 + 0.5 - 0.2) + 0.0 / (0.3 + 0.4 - 0.0)\n        ) / 2\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_multi_dim_input(self):\n        y_pred = np.array([[0, 1], [0, 1]], dtype=np.float32)\n        y_true = np.array([[0, 0], [1, 1]])\n        sample_weight = np.array([[0.2, 0.3], [0.4, 0.1]])\n\n        m_obj = metrics.MeanIoU(num_classes=2, dtype=\"float32\")\n\n        result = m_obj(y_true, y_pred, sample_weight=sample_weight)\n\n        # cm = [[0.2, 0.3],\n        #       [0.4, 0.1]]\n        # sum_row = [0.6, 0.4], sum_col = [0.5, 0.5], true_positives = [0.2,\n        # 0.1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (\n            0.2 / (0.6 + 0.5 - 0.2) + 0.1 / (0.4 + 0.5 - 0.1)\n        ) / 2\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_zero_valid_entries(self):\n        m_obj = metrics.MeanIoU(num_classes=2)\n        self.assertAllClose(m_obj.result(), 0, atol=1e-3)\n\n    def test_zero_and_non_zero_entries(self):\n        y_pred = np.array([1], dtype=np.float32)\n        y_true = np.array([1])\n\n        m_obj = metrics.MeanIoU(num_classes=2)\n        result = m_obj(y_true, y_pred)\n\n        # cm = [[0, 0],\n        #       [0, 1]]\n        # sum_row = [0, 1], sum_col = [0, 1], true_positives = [0, 1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (0 + 1 / (1 + 1 - 1)) / 1\n        
self.assertAllClose(result, expected_result, atol=1e-3)\n\n    @staticmethod\n    def _confusion_matrix(y_true, y_pred, num_classes):\n        \"\"\"\n        Creates a confusion matrix as a numpy array using vectorized operations.\n\n        Parameters:\n        - y_true: array-like, true class labels.\n        - y_pred: array-like, predicted class labels.\n        - num_classes: int, number of classes.\n\n        Returns:\n        - conf_matrix: np.ndarray, confusion matrix of shape (num_classes,\n                                                              num_classes).\n        \"\"\"\n        # Map pairs of (y_true, y_pred) to indices in the confusion matrix\n        indices = y_true * num_classes + y_pred\n        # Count occurrences of each index\n        conf_matrix = np.bincount(indices, minlength=num_classes * num_classes)\n        # Reshape the flat array into a 2D confusion matrix\n        conf_matrix = conf_matrix.reshape((num_classes, num_classes))\n        return conf_matrix\n\n    @staticmethod\n    def _get_big_chunk(dtype):\n        np.random.seed(14)\n        all_y_true = np.random.choice([0, 1, 2], size=(10, 530, 530))\n        # Generate random probabilities for each channel\n        random_probs = np.random.rand(10, 530, 530, 3)\n        # Normalize to ensure the last dimension sums to 1\n        all_y_pred = random_probs / random_probs.sum(axis=-1, keepdims=True)\n        # Convert predictions to class indices\n        all_y_pred_arg = np.argmax(all_y_pred, axis=-1)\n        mean_iou_metric = metrics.MeanIoU(num_classes=3, dtype=dtype)\n        conf_matrix_start_point = np.array(\n            [\n                [18729664, 18728760, 18731196],\n                [18727297, 18726105, 18728071],\n                [18727917, 18717835, 18723155],\n            ]\n        )\n        mean_iou_metric.total_cm = mean_iou_metric.add_variable(\n            name=\"total_confusion_matrix\",\n            shape=(3, 3),\n            
initializer=convert_to_tensor(conf_matrix_start_point),\n            dtype=dtype or \"int\",\n        )\n        mean_iou_metric.update_state(all_y_true, all_y_pred_arg)\n        tmp_true = np.reshape(all_y_true, -1)\n        tmp_pred = np.reshape(all_y_pred_arg, -1)\n        return (\n            all_y_true,\n            all_y_pred_arg,\n            mean_iou_metric,\n            tmp_true,\n            tmp_pred,\n            conf_matrix_start_point,\n        )\n\n    def test_big_chunk(self):\n        # Initialize with dtype=None, which will default to int\n        (\n            all_y_true,\n            all_y_pred_arg,\n            mean_iou_metric_all,\n            tmp_true,\n            tmp_pred,\n            conf_matrix_start_point,\n        ) = self._get_big_chunk(dtype=None)\n        conf_matrix_from_keras = np.array(mean_iou_metric_all.total_cm)\n        # Validate confusion matrices and results\n        conf_matrix_manual = (\n            self._confusion_matrix(tmp_true, tmp_pred, 3)\n            + conf_matrix_start_point\n        )\n        self.assertTrue(\n            np.array_equal(conf_matrix_from_keras, conf_matrix_manual),\n            msg=\"Confusion matrices do not match!\",\n        )\n        # Now the same with float32 dtype. Here the confusion matrices are\n        # expected to differ: the counts exceed 2**24, so float32\n        # accumulation loses integer precision.\n        (\n            all_y_true,\n            all_y_pred_arg,\n            mean_iou_metric_all,\n            tmp_true,\n            tmp_pred,\n            conf_matrix_start_point,\n        ) = self._get_big_chunk(dtype=\"float32\")\n        conf_matrix_from_keras = np.array(mean_iou_metric_all.total_cm)\n        # Validate confusion matrices and results\n        conf_matrix_manual = (\n            self._confusion_matrix(tmp_true, tmp_pred, 3)\n            + conf_matrix_start_point\n        )\n        self.assertFalse(\n            np.array_equal(conf_matrix_from_keras, conf_matrix_manual),\n            msg=\"Confusion matrices match, but they should not!\",\n        )\n\n    def test_user_warning_float_weight(self):\n        y_pred = [0, 1, 1, 1]\n        y_true = [0, 1, 1, 0]\n        m_obj = metrics.MeanIoU(num_classes=3)\n        with pytest.warns(Warning, match=r\"weight.*float.*int.*casting\"):\n            m_obj(y_true, y_pred, sample_weight=np.array([0.2, 0.3, 0.4, 0.1]))\n\n\nclass OneHotIoUTest(testing.TestCase):\n    def test_unweighted(self):\n        y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])\n        # y_true will be converted to [2, 0, 1, 0]\n        y_pred = np.array(\n            [[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1], [0.1, 0.4, 0.5]]\n        )\n        # y_pred will be converted to [2, 2, 0, 2]\n        # cm = [[0, 0, 2],\n        #       [1, 0, 0],\n        #       [0, 0, 1]]\n        # sum_row = [1, 0, 3], sum_col = [2, 1, 1], true_positives = [0, 0, 1]\n        # iou = true_positives / (sum_row + sum_col - true_positives)\n        expected_result = (0 / (1 + 2 - 0) + 1 / (3 + 1 - 1)) / 2\n        obj = metrics.OneHotIoU(num_classes=3, target_class_ids=[0, 2])\n        result = obj(y_true, y_pred)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_weighted(self):\n        y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])\n        # y_true 
will be converted to [2, 0, 1, 0]\n        y_pred = np.array(\n            [[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1], [0.1, 0.4, 0.5]]\n        )\n        # y_pred will be converted to [2, 2, 0, 2]\n        sample_weight = [0.1, 0.2, 0.3, 0.4]\n        # cm = [[0, 0, 0.2+0.4],\n        #       [0.3, 0, 0],\n        #       [0, 0, 0.1]]\n        # sum_row = [0.3, 0, 0.7], sum_col = [0.6, 0.3, 0.1]\n        # true_positives = [0, 0, 0.1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (0 / (0.3 + 0.6 - 0) + 0.1 / (0.7 + 0.1 - 0.1)) / 2\n        obj = metrics.OneHotIoU(\n            num_classes=3, target_class_ids=[0, 2], dtype=\"float32\"\n        )\n        result = obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n\nclass OneHotMeanIoUTest(testing.TestCase):\n    def test_unweighted(self):\n        y_true = np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0]])\n        # y_true will be converted to [2, 0, 1, 0]\n        y_pred = np.array(\n            [[0.2, 0.3, 0.5], [0.1, 0.2, 0.7], [0.5, 0.3, 0.1], [0.1, 0.4, 0.5]]\n        )\n        # y_pred will be converted to [2, 2, 0, 2]\n        # cm = [[0, 0, 2],\n        #       [1, 0, 0],\n        #       [0, 0, 1]\n        # sum_row = [1, 0, 3], sum_col = [2, 1, 1], true_positives = [0, 0, 1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (0 + 0 + 1 / (3 + 1 - 1)) / 3\n        obj = metrics.OneHotMeanIoU(num_classes=3)\n        result = obj(y_true, y_pred)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_weighted(self):\n        y_true = np.array(\n            [\n                [0, 0, 1],\n                [1, 0, 0],\n                [0, 1, 0],\n                [1, 0, 0],\n                [1, 0, 0],\n            ]\n        )\n        # y_true will be converted to [2, 0, 1, 0, 0]\n        y_pred = np.array(\n  
          [\n                [0.2, 0.3, 0.5],\n                [0.1, 0.2, 0.7],\n                [0.5, 0.3, 0.1],\n                [0.1, 0.4, 0.5],\n                [0.6, 0.2, 0.2],\n            ]\n        )\n        # y_pred will be converted to [2, 2, 0, 2, 0]\n        sample_weight = [0.1, 0.2, 0.3, 0.3, 0.1]\n        # cm = [[0.1, 0, 0.2+0.3],\n        #       [0.3, 0, 0],\n        #       [0, 0, 0.1]]\n        # sum_row = [0.4, 0, 0.6], sum_col = [0.6, 0.3, 0.1]\n        # true_positives = [0.1, 0, 0.1]\n        # iou = true_positives / (sum_row + sum_col - true_positives))\n        expected_result = (\n            0.1 / (0.4 + 0.6 - 0.1) + 0 + 0.1 / (0.6 + 0.1 - 0.1)\n        ) / 3\n        obj = metrics.OneHotMeanIoU(num_classes=3, dtype=\"float32\")\n        result = obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n        # Check same result with int weights\n        sample_weight_int = [1, 2, 3, 3, 1]\n        obj_int = metrics.OneHotMeanIoU(num_classes=3)\n        result_int = obj_int(y_true, y_pred, sample_weight=sample_weight_int)\n        self.assertAllClose(result_int, expected_result, atol=1e-3)\n"
  },
  {
    "path": "keras/src/metrics/metric.py",
    "content": "from keras.src import backend\nfrom keras.src import dtype_policies\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.saving.keras_saveable import KerasSaveable\nfrom keras.src.utils.naming import auto_name\nfrom keras.src.utils.tracking import Tracker\n\n\n@keras_export([\"keras.Metric\", \"keras.metrics.Metric\"])\nclass Metric(KerasSaveable):\n    \"\"\"Encapsulates metric logic and state.\n\n    Args:\n        name: Optional name for the metric instance.\n        dtype: The dtype of the metric's computations. Defaults to `None`, which\n            means using `keras.backend.floatx()`. `keras.backend.floatx()` is a\n            `\"float32\"` unless set to different value\n            (via `keras.backend.set_floatx()`). If a `keras.DTypePolicy` is\n            provided, then the `compute_dtype` will be utilized.\n\n    Example:\n\n    ```python\n    m = SomeMetric(...)\n    for input in ...:\n        m.update_state(input)\n    print('Final result: ', m.result())\n    ```\n\n    Usage with `compile()` API:\n\n    ```python\n    model = keras.Sequential()\n    model.add(keras.layers.Dense(64, activation='relu'))\n    model.add(keras.layers.Dense(64, activation='relu'))\n    model.add(keras.layers.Dense(10, activation='softmax'))\n\n    model.compile(optimizer=keras.optimizers.RMSprop(0.01),\n                  loss=keras.losses.CategoricalCrossentropy(),\n                  metrics=[keras.metrics.CategoricalAccuracy()])\n\n    data = np.random.random((1000, 32))\n    labels = np.random.random((1000, 10))\n\n    model.fit(data, labels, epochs=10)\n    ```\n\n    To be implemented by subclasses:\n\n    * `__init__()`: All state variables should be created in this method by\n      calling `self.add_variable()` like: `self.var = self.add_variable(...)`\n    * `update_state()`: Has all updates to the state variables like:\n      `self.var.assign(...)`.\n    * `result()`: Computes 
and returns a scalar value or a dict of scalar values\n      for the metric from the state variables.\n\n    Example subclass implementation:\n\n    ```python\n    class BinaryTruePositives(Metric):\n\n        def __init__(self, name='binary_true_positives', **kwargs):\n            super().__init__(name=name, **kwargs)\n            self.true_positives = self.add_variable(\n                shape=(),\n                initializer='zeros',\n                name='true_positives'\n            )\n\n        def update_state(self, y_true, y_pred, sample_weight=None):\n            y_true = ops.cast(y_true, \"bool\")\n            y_pred = ops.cast(y_pred, \"bool\")\n\n            values = ops.logical_and(\n                ops.equal(y_true, True), ops.equal(y_pred, True))\n            values = ops.cast(values, self.dtype)\n            if sample_weight is not None:\n                sample_weight = ops.cast(sample_weight, self.dtype)\n                sample_weight = ops.broadcast_to(\n                    sample_weight, ops.shape(values)\n                )\n                values = ops.multiply(values, sample_weight)\n            self.true_positives.assign(self.true_positives + ops.sum(values))\n\n        def result(self):\n            return self.true_positives\n    ```\n    \"\"\"\n\n    def __init__(self, dtype=None, name=None):\n        self.name = name or auto_name(self.__class__.__name__)\n        self._dtype_policy = dtype_policies.get(dtype or backend.floatx())\n        self._dtype = self._dtype_policy.compute_dtype\n        self._metrics = []\n        self._variables = []\n        self._tracker = Tracker(\n            {\n                \"variables\": (\n                    lambda x: isinstance(x, backend.Variable),\n                    self._variables,\n                ),\n                \"metrics\": (lambda x: isinstance(x, Metric), self._metrics),\n            }\n        )\n\n    def reset_state(self):\n        \"\"\"Reset all of the metric state variables.\n\n       
 This function is called between epochs/steps,\n        when a metric is evaluated during training.\n        \"\"\"\n        for v in self.variables:\n            v.assign(ops.zeros(v.shape, dtype=v.dtype))\n\n    def update_state(self, *args, **kwargs):\n        \"\"\"Accumulate statistics for the metric.\"\"\"\n        raise NotImplementedError\n\n    def stateless_update_state(self, metric_variables, *args, **kwargs):\n        if len(metric_variables) != len(self.variables):\n            raise ValueError(\n                \"Argument `metric_variables` must be a list of tensors \"\n                f\"corresponding 1:1 to {self.__class__.__name__}().variables. \"\n                f\"Received list with length {len(metric_variables)}, but \"\n                f\"expected {len(self.variables)} variables.\"\n            )\n        # Gather variable mapping\n        mapping = list(zip(self.variables, metric_variables))\n\n        # Call in stateless scope\n        with backend.StatelessScope(state_mapping=mapping) as scope:\n            self.update_state(*args, **kwargs)\n\n        # Gather updated variables\n        metric_variables = []\n        for v in self.variables:\n            new_v = scope.get_current_value(v)\n            if new_v is not None:\n                metric_variables.append(new_v)\n            else:\n                metric_variables.append(v)\n        return metric_variables\n\n    def result(self):\n        \"\"\"Compute the current metric value.\n\n        Returns:\n            A scalar tensor, or a dictionary of scalar tensors.\n        \"\"\"\n        raise NotImplementedError\n\n    def stateless_result(self, metric_variables):\n        if len(metric_variables) != len(self.variables):\n            raise ValueError(\n                \"Argument `metric_variables` must be a list of tensors \"\n                f\"corresponding 1:1 to {self.__class__.__name__}().variables. 
\"\n                f\"Received list with length {len(metric_variables)}, but \"\n                f\"expected {len(self.variables)} variables.\"\n            )\n        # Gather variable mapping\n        mapping = list(zip(self.variables, metric_variables))\n\n        # Call in stateless scope\n        with backend.StatelessScope(state_mapping=mapping):\n            res = self.result()\n        return res\n\n    def stateless_reset_state(self):\n        # Call in stateless scope\n        with backend.StatelessScope() as scope:\n            self.reset_state()\n\n        # Gather updated variables\n        metric_variables = []\n        for v in self.variables:\n            new_v = scope.get_current_value(v)\n            if new_v is not None:\n                metric_variables.append(new_v)\n            else:\n                metric_variables.append(v)\n        return metric_variables\n\n    @property\n    def dtype(self):\n        return self._dtype\n\n    def _obj_type(self):\n        return \"Metric\"\n\n    def add_variable(\n        self, shape, initializer, dtype=None, aggregation=\"sum\", name=None\n    ):\n        self._check_super_called()\n        with backend.name_scope(self.name.replace(\"/\", \">\"), caller=self):\n            initializer = initializers.get(initializer)\n            variable = backend.Variable(\n                initializer=initializer,\n                shape=shape,\n                dtype=dtype,\n                trainable=False,\n                aggregation=aggregation,\n                synchronization=\"on_read\",\n                name=name,\n            )\n        # Prevent double-tracking\n        self._tracker.add_to_store(\"variables\", variable)\n        return variable\n\n    def add_weight(self, shape=(), initializer=None, dtype=None, name=None):\n        # Backwards compatibility alias\n        return self.add_variable(\n            shape=shape, initializer=initializer, dtype=dtype, name=name\n        )\n\n    @property\n    def 
variables(self):\n        variables = list(self._variables)\n        for metric in self._metrics:\n            variables.extend(metric.variables)\n        return variables\n\n    def __call__(self, *args, **kwargs):\n        self._check_super_called()\n        self.update_state(*args, **kwargs)\n        return self.result()\n\n    def get_config(self):\n        \"\"\"Return the serializable config of the metric.\"\"\"\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n\n    def __setattr__(self, name, value):\n        # Track Variables, Layers, Metrics\n        if hasattr(self, \"_tracker\"):\n            value = self._tracker.track(value)\n        return super().__setattr__(name, value)\n\n    def _check_super_called(self):\n        if not hasattr(self, \"_tracker\"):\n            raise RuntimeError(\n                \"You forgot to call `super().__init__()` \"\n                \"in the `__init__()` method. Go add it!\"\n            )\n\n    def __repr__(self):\n        return f\"<{self.__class__.__name__} name={self.name}>\"\n\n    def __str__(self):\n        return self.__repr__()\n"
  },
  {
    "path": "keras/src/metrics/metric_test.py",
    "content": "import pickle\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import dtype_policies\nfrom keras.src import initializers\nfrom keras.src import metrics as metrics_module\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.metrics.metric import Metric\n\n\nclass ExampleMetric(Metric):\n    def __init__(self, name=\"mean_square_error\", dtype=None):\n        super().__init__(name=name, dtype=dtype)\n        self.sum = self.add_variable(\n            name=\"sum\", shape=(), initializer=initializers.Zeros()\n        )\n        self.total = self.add_variable(\n            name=\"total\",\n            shape=(),\n            initializer=initializers.Zeros(),\n            dtype=\"int32\",\n        )\n\n    def update_state(self, y_true, y_pred):\n        y_true = ops.convert_to_tensor(y_true, dtype=self.dtype)\n        y_pred = ops.convert_to_tensor(y_pred, dtype=self.dtype)\n        sum = ops.sum((y_true - y_pred) ** 2)\n        self.sum.assign(self.sum + sum)\n        batch_size = ops.shape(y_true)[0]\n        self.total.assign(self.total + batch_size)\n\n    def result(self):\n        _sum = ops.cast(self.sum, dtype=self.dtype)\n        _total = ops.cast(self.total, dtype=self.dtype)\n        _epsilon = ops.cast(backend.epsilon(), dtype=self.dtype)\n        return _sum / (_total + _epsilon)\n\n    def reset_state(self):\n        self.sum.assign(0)\n        self.total.assign(0)\n\n\nclass MetricTest(testing.TestCase):\n    def setUp(self):\n        self._global_dtype_policy = dtype_policies.dtype_policy.dtype_policy()\n        self._floatx = backend.floatx()\n        return super().setUp()\n\n    def tearDown(self):\n        dtype_policies.dtype_policy.set_dtype_policy(self._global_dtype_policy)\n        backend.set_floatx(self._floatx)\n        return super().tearDown()\n\n    def test_end_to_end_flow(self):\n        metric = ExampleMetric(name=\"mse\")\n        self.assertEqual(metric.name, \"mse\")\n        
self.assertEqual(len(metric.variables), 2)\n\n        num_samples = 20\n        y_true = np.random.random((num_samples, 3))\n        y_pred = np.random.random((num_samples, 3))\n        batch_size = 8\n        for b in range(0, num_samples // batch_size + 1):\n            y_true_batch = y_true[b * batch_size : (b + 1) * batch_size]\n            y_pred_batch = y_pred[b * batch_size : (b + 1) * batch_size]\n            metric.update_state(y_true_batch, y_pred_batch)\n\n        self.assertAllClose(metric.total, 20)\n        result = metric.result()\n        self.assertAllClose(\n            result, np.sum((y_true - y_pred) ** 2) / num_samples\n        )\n        metric.reset_state()\n        self.assertEqual(metric.result(), 0.0)\n\n    def test_stateless_update_state(self):\n        metric = ExampleMetric(name=\"mse\")\n        self.assertEqual(len(metric.variables), 2)\n        original_variable_values = (\n            metric.variables[0].numpy(),\n            metric.variables[1].numpy(),\n        )\n\n        num_samples = 20\n        y_true = np.random.random((num_samples, 3))\n        y_pred = np.random.random((num_samples, 3))\n        batch_size = 8\n        metric_variables = metric.variables\n        for b in range(0, num_samples // batch_size + 1):\n            y_true_batch = y_true[b * batch_size : (b + 1) * batch_size]\n            y_pred_batch = y_pred[b * batch_size : (b + 1) * batch_size]\n            metric_variables = metric.stateless_update_state(\n                metric_variables, y_true_batch, y_pred_batch\n            )\n\n        self.assertAllClose(metric.variables[0], original_variable_values[0])\n        self.assertAllClose(metric.variables[1], original_variable_values[1])\n        metric.variables[0].assign(metric_variables[0])\n        metric.variables[1].assign(metric_variables[1])\n        self.assertAllClose(metric.total, 20)\n        result = metric.result()\n        self.assertAllClose(\n            result, np.sum((y_true - y_pred) ** 
2) / num_samples\n        )\n\n        if backend.backend() == \"jax\":\n            # Check no side effects.\n            import jax\n\n            @jax.jit\n            def update(metric_variables, y_true_batch, y_pred_batch):\n                metric_variables = metric.stateless_update_state(\n                    metric_variables, y_true_batch, y_pred_batch\n                )\n\n            update(metric_variables, y_true_batch, y_pred_batch)\n\n    def test_stateless_result(self):\n        metric = ExampleMetric(name=\"mse\")\n        res = metric.stateless_result([ops.ones(()) * 12, ops.ones(()) * 3])\n        self.assertAllClose(res, 4.0)\n\n    def test_stateless_reset_state(self):\n        metric = ExampleMetric(name=\"mse\")\n        num_samples = 20\n        y_true = np.random.random((num_samples, 3))\n        y_pred = np.random.random((num_samples, 3))\n        metric.update_state(y_true, y_pred)\n        vars = metric.stateless_reset_state()\n        self.assertLen(vars, 2)\n        self.assertEqual(vars[0], 0)\n        self.assertEqual(vars[1], 0)\n\n    def test_variable_tracking(self):\n        # In list\n        metric = ExampleMetric(name=\"mse\")\n        metric.more_vars = [backend.Variable(0.0), backend.Variable(1.0)]\n        self.assertEqual(len(metric.variables), 4)\n\n        # In dict\n        metric = ExampleMetric(name=\"mse\")\n        metric.more_vars = {\n            \"a\": backend.Variable(0.0),\n            \"b\": backend.Variable(1.0),\n        }\n        self.assertEqual(len(metric.variables), 4)\n\n        # In nested structured\n        metric = ExampleMetric(name=\"mse\")\n        metric.more_vars = {\"a\": [backend.Variable(0.0), backend.Variable(1.0)]}\n        self.assertEqual(len(metric.variables), 4)\n\n    def test_submetric_tracking(self):\n        # Plain attr\n        metric = ExampleMetric(name=\"mse\")\n        metric.submetric = ExampleMetric(name=\"submse\")\n        self.assertEqual(len(metric.variables), 4)\n\n     
   # In list\n        metric = ExampleMetric(name=\"mse\")\n        metric.submetrics = [\n            ExampleMetric(name=\"submse1\"),\n            ExampleMetric(name=\"submse2\"),\n        ]\n        self.assertEqual(len(metric.variables), 6)\n\n        # In dict\n        metric = ExampleMetric(name=\"mse\")\n        metric.submetrics = {\n            \"1\": ExampleMetric(name=\"submse1\"),\n            \"2\": ExampleMetric(name=\"submse2\"),\n        }\n        self.assertEqual(len(metric.variables), 6)\n\n        # Two levels deep\n        metric = ExampleMetric(name=\"mse\")\n        metric.submetric = ExampleMetric(name=\"submse\")\n        metric.submetric.submetric = ExampleMetric(name=\"subsubmse\")\n        self.assertEqual(len(metric.variables), 6)\n\n    def test_serialization(self):\n        self.run_class_serialization_test(\n            ExampleMetric(name=\"mse\"),\n            custom_objects={\"ExampleMetric\": ExampleMetric},\n        )\n\n    def test_pickle(self):\n        metric = metrics_module.get(\"mse\")\n        reloaded = pickle.loads(pickle.dumps(metric))\n        self.assertIsInstance(reloaded, metrics_module.MeanSquaredError)\n\n    def test_get_method(self):\n        metric = metrics_module.get(\"mse\")\n        self.assertIsInstance(metric, metrics_module.MeanSquaredError)\n\n        metric = metrics_module.get(\"mean_squared_error\")\n        self.assertIsInstance(metric, metrics_module.MeanSquaredError)\n\n        metric = metrics_module.get(\"categorical_accuracy\")\n        self.assertIsInstance(metric, metrics_module.CategoricalAccuracy)\n\n        metric = metrics_module.get(None)\n        self.assertEqual(metric, None)\n\n        with self.assertRaises(ValueError):\n            metrics_module.get(\"typo\")\n\n    def test_dtype_arg(self):\n        metric = ExampleMetric(name=\"mse\", dtype=\"float16\")\n        self.assertEqual(metric.name, \"mse\")\n        self.assertEqual(len(metric.variables), 2)\n\n        num_samples = 
10\n        y_true = np.random.random((num_samples, 3))\n        y_pred = np.random.random((num_samples, 3))\n        metric.update_state(y_true, y_pred)\n        result = metric.result()\n        self.assertAllClose(\n            result, np.sum((y_true - y_pred) ** 2) / num_samples, atol=1e-3\n        )\n        self.assertDType(result, \"float16\")\n\n        # Test DTypePolicy for `dtype` argument\n        metric = ExampleMetric(\n            dtype=dtype_policies.DTypePolicy(\"mixed_float16\")\n        )\n        metric.update_state(y_true, y_pred)\n        metric.update_state(y_true, y_pred)\n        result = metric.result()\n        self.assertAllClose(\n            result, np.sum((y_true - y_pred) ** 2) / num_samples, atol=1e-3\n        )\n        self.assertDType(result, \"float16\")\n\n        # `dtype` setter should raise AttributeError\n        with self.assertRaises(AttributeError):\n            metric.dtype = \"bfloat16\"\n\n    def test_default_dtype(self):\n        y_true = np.random.random((10, 3))\n        y_pred = np.random.random((10, 3))\n\n        # Defaults to `keras.config.floatx()` not global `dtype_policy`\n        dtype_policies.dtype_policy.set_dtype_policy(\"mixed_float16\")\n        metric = ExampleMetric()\n        metric.update_state(y_true, y_pred)\n        result = metric.result()\n        self.assertDType(result, \"float32\")\n\n        backend.set_floatx(\"float16\")\n        metric = ExampleMetric()\n        metric.update_state(y_true, y_pred)\n        result = metric.result()\n        self.assertDType(result, backend.floatx())\n"
  },
  {
    "path": "keras/src/metrics/metrics_utils.py",
    "content": "from enum import Enum\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.losses.loss import squeeze_or_expand_to_same_rank\nfrom keras.src.utils.python_utils import to_list\n\nNEG_INF = -1e10\n\n\ndef assert_thresholds_range(thresholds):\n    if thresholds is not None:\n        invalid_thresholds = [\n            t for t in thresholds if t is None or t < 0 or t > 1\n        ]\n        if invalid_thresholds:\n            raise ValueError(\n                \"Threshold values must be in [0, 1]. \"\n                f\"Received: {invalid_thresholds}\"\n            )\n\n\ndef parse_init_thresholds(thresholds, default_threshold=0.5):\n    if thresholds is not None:\n        assert_thresholds_range(to_list(thresholds))\n    thresholds = to_list(\n        default_threshold if thresholds is None else thresholds\n    )\n    return thresholds\n\n\nclass ConfusionMatrix(Enum):\n    TRUE_POSITIVES = \"tp\"\n    FALSE_POSITIVES = \"fp\"\n    TRUE_NEGATIVES = \"tn\"\n    FALSE_NEGATIVES = \"fn\"\n\n\nclass AUCCurve(Enum):\n    \"\"\"Type of AUC Curve (ROC or PR).\"\"\"\n\n    ROC = \"ROC\"\n    PR = \"PR\"\n    PRGAIN = \"PRGAIN\"\n\n    @staticmethod\n    def from_str(key):\n        if key in (\"pr\", \"PR\"):\n            return AUCCurve.PR\n        elif key in (\"roc\", \"ROC\"):\n            return AUCCurve.ROC\n        elif key in (\"prgain\", \"PRGAIN\"):\n            return AUCCurve.PRGAIN\n        else:\n            raise ValueError(\n                f'Invalid AUC curve value: \"{key}\". '\n                'Expected values are [\"PR\", \"ROC\", \"PRGAIN\"]'\n            )\n\n\nclass AUCSummationMethod(Enum):\n    \"\"\"Type of AUC summation method.\n\n    https://en.wikipedia.org/wiki/Riemann_sum)\n\n    Contains the following values:\n    * 'interpolation': Applies mid-point summation scheme for `ROC` curve. 
For\n      `PR` curve, interpolates (true/false) positives but not the ratio that is\n      precision (see Davis & Goadrich 2006 for details).\n    * 'minoring': Applies left summation for increasing intervals and right\n      summation for decreasing intervals.\n    * 'majoring': Applies right summation for increasing intervals and left\n      summation for decreasing intervals.\n    \"\"\"\n\n    INTERPOLATION = \"interpolation\"\n    MAJORING = \"majoring\"\n    MINORING = \"minoring\"\n\n    @staticmethod\n    def from_str(key):\n        if key in (\"interpolation\", \"Interpolation\"):\n            return AUCSummationMethod.INTERPOLATION\n        elif key in (\"majoring\", \"Majoring\"):\n            return AUCSummationMethod.MAJORING\n        elif key in (\"minoring\", \"Minoring\"):\n            return AUCSummationMethod.MINORING\n        else:\n            raise ValueError(\n                f'Invalid AUC summation method value: \"{key}\". '\n                'Expected values are [\"interpolation\", \"majoring\", \"minoring\"]'\n            )\n\n\ndef _update_confusion_matrix_variables_optimized(\n    variables_to_update,\n    y_true,\n    y_pred,\n    thresholds,\n    multi_label=False,\n    sample_weights=None,\n    label_weights=None,\n    thresholds_with_epsilon=False,\n):\n    \"\"\"Update confusion matrix variables with a memory-efficient alternative.\n\n    Note that the thresholds need to be evenly distributed within the list,\n    i.e., the diff between consecutive elements is the same.\n\n    To compute TP/FP/TN/FN, we are measuring a binary classifier\n      C(t) = (predictions >= t)\n    at each threshold 't'. So we have\n      TP(t) = sum( C(t) * true_labels )\n      FP(t) = sum( C(t) * false_labels )\n\n    But, computing C(t) requires computation for each t. To make it fast,\n    observe that C(t) is a cumulative integral, and so if we have\n      thresholds = [t_0, ..., t_{n-1}];  t_0 < ... 
< t_{n-1}\n    where n = num_thresholds, and if we can compute the bucket function\n      B(i) = Sum( (predictions == t), t_i <= t < t_{i+1} )\n    then we get\n      C(t_i) = sum( B(j), j >= i )\n    which is the reversed cumulative sum in ops.cumsum().\n\n    We can compute B(i) efficiently by taking advantage of the fact that\n    our thresholds are evenly distributed, in that\n      width = 1.0 / (num_thresholds - 1)\n      thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0]\n    Given a prediction value p, we can map it to its bucket by\n      bucket_index(p) = floor( p * (num_thresholds - 1) )\n    so we can use ops.segment_sum() to update the buckets in one pass.\n\n    Consider the following example:\n    y_true = [0, 0, 1, 1]\n    y_pred = [0.1, 0.5, 0.3, 0.9]\n    thresholds = [0.0, 0.5, 1.0]\n    num_buckets = 2   # [0.0, 1.0], (1.0, 2.0]\n    bucket_index(y_pred) = ops.floor(y_pred * num_buckets)\n                         = ops.floor([0.2, 1.0, 0.6, 1.8])\n                         = [0, 0, 0, 1]\n    # The meaning of this bucket is that if any of the labels is true,\n    # then 1 will be added to the bucket at the corresponding index.\n    # E.g., if the label for 0.2 is true, then 1 will be added to bucket 0. If\n    # the label for 1.8 is true, then 1 will be added to bucket 1.\n    #\n    # Note the second item \"1.0\" is floored to 0, since the value needs to be\n    # strictly larger than the bucket lower bound.\n    # In the implementation, we use ops.ceil() - 1 to achieve this.\n    tp_bucket_value = ops.segment_sum(true_labels, bucket_indices,\n                                                   num_segments=num_thresholds)\n                    = [1, 1, 0]\n    # For [1, 1, 0] here, it means there is 1 true value contributed by bucket\n    # 0, and 1 value contributed by bucket 1. 
When we aggregate them\n    # together, the result becomes [a + b + c, b + c, c], since large thresholds\n    # will always contribute to the value for smaller thresholds.\n    true_positive = ops.cumsum(tp_bucket_value, reverse=True)\n                  = [2, 1, 0]\n\n    This implementation exhibits a run time and space complexity of O(T + N),\n    where T is the number of thresholds and N is the size of predictions.\n    Metrics that rely on the standard implementation instead exhibit a\n    complexity of O(T * N).\n\n    Args:\n        variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid\n            keys and corresponding variables to update as values.\n        y_true: A floating point `Tensor` whose shape matches `y_pred`. Will be\n            cast to `bool`.\n        y_pred: A floating point `Tensor` of arbitrary shape and whose values\n            are in the range `[0, 1]`.\n        thresholds: A sorted floating point `Tensor` with values in `[0, 1]`.\n            It needs to be evenly distributed (the diff between consecutive\n            elements needs to be the same).\n        multi_label: Optional boolean indicating whether multidimensional\n            prediction/labels should be treated as multilabel responses, or\n            flattened into a single label. When True, the values of\n            `variables_to_update` must have a second dimension equal to the\n            number of labels in y_true and y_pred, and those tensors must not be\n            RaggedTensors.\n        sample_weights: Optional `Tensor` whose rank is either 0, or the same\n            rank as `y_true`, and must be broadcastable to `y_true` (i.e., all\n            dimensions must be either `1`, or the same as the corresponding\n            `y_true` dimension).\n        label_weights: Optional tensor of non-negative weights for multilabel\n            data. The weights are applied when calculating TP, FP, FN, and TN\n            without explicit multilabel handling (i.e. 
when the data is to be\n            flattened).\n        thresholds_with_epsilon: Optional boolean indicating whether the leading\n            and trailing thresholds have any epsilon added for floating point\n            imprecision. It will change how we handle the leading and trailing\n            buckets.\n    \"\"\"\n    num_thresholds = ops.shape(thresholds)[0]\n\n    if sample_weights is None:\n        sample_weights = 1.0\n    else:\n        sample_weights = ops.broadcast_to(\n            ops.cast(sample_weights, dtype=y_pred.dtype), ops.shape(y_pred)\n        )\n        if not multi_label:\n            sample_weights = ops.reshape(sample_weights, [-1])\n    if label_weights is None:\n        label_weights = 1.0\n    else:\n        label_weights = ops.expand_dims(label_weights, 0)\n        label_weights = ops.broadcast_to(label_weights, ops.shape(y_pred))\n        if not multi_label:\n            label_weights = ops.reshape(label_weights, [-1])\n    weights = ops.cast(\n        ops.multiply(sample_weights, label_weights), y_true.dtype\n    )\n\n    # We shouldn't need this, but in case there are prediction values that are\n    # out of the range of [0.0, 1.0].\n    y_pred = ops.clip(y_pred, x_min=0.0, x_max=1.0)\n\n    y_true = ops.cast(ops.cast(y_true, \"bool\"), y_true.dtype)\n    if not multi_label:\n        y_true = ops.reshape(y_true, [-1])\n        y_pred = ops.reshape(y_pred, [-1])\n\n    true_labels = ops.multiply(y_true, weights)\n    false_labels = ops.multiply((1.0 - y_true), weights)\n\n    # Compute the bucket indices for each prediction value.\n    # Since the prediction value has to be strictly greater than the threshold,\n    # e.g., with buckets [0, 0.5], (0.5, 1], the value 0.5 belongs to the first\n    # bucket.\n    # We have to use math.ceil(val) - 1 for the bucket.\n    bucket_indices = (\n        ops.ceil(y_pred * (ops.cast(num_thresholds, dtype=y_pred.dtype) - 1))\n        - 1\n    )\n\n    if thresholds_with_epsilon:\n        # In this case, the first bucket 
should actually be taken into account since\n        # any prediction between [0.0, 1.0] should be larger than the first\n        # threshold. We change the bucket value from -1 to 0.\n        bucket_indices = ops.relu(bucket_indices)\n\n    bucket_indices = ops.cast(bucket_indices, \"int32\")\n\n    if multi_label:\n        # We need to run bucket segment sum for each of the label classes. In\n        # the multi_label case, the rank of the label is 2. We first transpose\n        # it so that the label dim becomes the first and we can run through them\n        # in parallel.\n        true_labels = ops.transpose(true_labels)\n        false_labels = ops.transpose(false_labels)\n        bucket_indices = ops.transpose(bucket_indices)\n\n        def gather_bucket(label_and_bucket_index):\n            label, bucket_index = (\n                label_and_bucket_index[0],\n                label_and_bucket_index[1],\n            )\n            return ops.segment_sum(\n                data=label,\n                segment_ids=bucket_index,\n                num_segments=num_thresholds,\n            )\n\n        tp_bucket_v = backend.vectorized_map(\n            gather_bucket,\n            (true_labels, bucket_indices),\n        )\n        fp_bucket_v = backend.vectorized_map(\n            gather_bucket, (false_labels, bucket_indices)\n        )\n        tp = ops.transpose(ops.flip(ops.cumsum(ops.flip(tp_bucket_v), axis=1)))\n        fp = ops.transpose(ops.flip(ops.cumsum(ops.flip(fp_bucket_v), axis=1)))\n    else:\n        tp_bucket_v = ops.segment_sum(\n            data=true_labels,\n            segment_ids=bucket_indices,\n            num_segments=num_thresholds,\n        )\n        fp_bucket_v = ops.segment_sum(\n            data=false_labels,\n            segment_ids=bucket_indices,\n            num_segments=num_thresholds,\n        )\n        tp = ops.flip(ops.cumsum(ops.flip(tp_bucket_v)))\n        fp = ops.flip(ops.cumsum(ops.flip(fp_bucket_v)))\n\n    # fn = sum(true_labels) 
- tp\n    # tn = sum(false_labels) - fp\n    if (\n        ConfusionMatrix.TRUE_NEGATIVES in variables_to_update\n        or ConfusionMatrix.FALSE_NEGATIVES in variables_to_update\n    ):\n        if multi_label:\n            total_true_labels = ops.sum(true_labels, axis=1)\n            total_false_labels = ops.sum(false_labels, axis=1)\n        else:\n            total_true_labels = ops.sum(true_labels)\n            total_false_labels = ops.sum(false_labels)\n\n    if ConfusionMatrix.TRUE_POSITIVES in variables_to_update:\n        variable = variables_to_update[ConfusionMatrix.TRUE_POSITIVES]\n        variable.assign(variable + tp)\n    if ConfusionMatrix.FALSE_POSITIVES in variables_to_update:\n        variable = variables_to_update[ConfusionMatrix.FALSE_POSITIVES]\n        variable.assign(variable + fp)\n    if ConfusionMatrix.TRUE_NEGATIVES in variables_to_update:\n        variable = variables_to_update[ConfusionMatrix.TRUE_NEGATIVES]\n        tn = total_false_labels - fp\n        variable.assign(variable + tn)\n    if ConfusionMatrix.FALSE_NEGATIVES in variables_to_update:\n        variable = variables_to_update[ConfusionMatrix.FALSE_NEGATIVES]\n        fn = total_true_labels - tp\n        variable.assign(variable + fn)\n\n\ndef is_evenly_distributed_thresholds(thresholds):\n    \"\"\"Check if the thresholds list is evenly distributed.\n\n    We could leverage evenly distributed thresholds to use less memory when\n    calculating metrics like AUC, where each individual threshold needs to be\n    evaluated.\n\n    Args:\n      thresholds: A python list or tuple, or 1D numpy array whose values are\n        in the range [0, 1].\n\n    Returns:\n      boolean, whether the values in the inputs are evenly distributed.\n    \"\"\"\n    # Check the list values and see if they are evenly distributed.\n    num_thresholds = len(thresholds)\n    if num_thresholds < 3:\n        return False\n    even_thresholds = np.arange(num_thresholds, dtype=np.float32) / (\n        num_thresholds - 
1\n    )\n    return np.allclose(thresholds, even_thresholds, atol=backend.epsilon())\n\n\ndef update_confusion_matrix_variables(\n    variables_to_update,\n    y_true,\n    y_pred,\n    thresholds,\n    top_k=None,\n    class_id=None,\n    sample_weight=None,\n    multi_label=False,\n    label_weights=None,\n    thresholds_distributed_evenly=False,\n):\n    \"\"\"Updates the given confusion matrix variables.\n\n    For every pair of values in y_true and y_pred:\n\n    true_positives: y_true == True and y_pred > thresholds\n    false_negatives: y_true == True and y_pred <= thresholds\n    true_negatives: y_true == False and y_pred <= thresholds\n    false_positives: y_true == False and y_pred > thresholds\n\n    The results will be weighted and added together. When multiple thresholds\n    are provided, we will repeat the same for every threshold.\n\n    For estimation of these metrics over a stream of data, the function updates\n    the given variables in place.\n\n    If `sample_weight` is `None`, weights default to 1.\n    Use weights of 0 to mask values.\n\n    Args:\n      variables_to_update: Dictionary with 'tp', 'fn', 'tn', 'fp' as valid keys\n        and corresponding variables to update as values.\n      y_true: A `Tensor` whose shape matches `y_pred`. 
Will be cast to `bool`.\n      y_pred: A floating point `Tensor` of arbitrary shape and whose values are\n        in the range `[0, 1]`.\n      thresholds: A float value, float tensor, python list, or tuple of float\n        thresholds in `[0, 1]`, or NEG_INF (used when top_k is set).\n      top_k: Optional int, indicates that the positive labels should be limited\n        to the top k predictions.\n      class_id: Optional int, limits the prediction and labels to the class\n        specified by this argument.\n      sample_weight: Optional `Tensor` whose rank is either 0, or the same rank\n        as `y_true`, and must be broadcastable to `y_true` (i.e., all dimensions\n        must be either `1`, or the same as the corresponding `y_true`\n        dimension).\n      multi_label: Optional boolean indicating whether multidimensional\n        prediction/labels should be treated as multilabel responses, or\n        flattened into a single label. When True, the values of\n        `variables_to_update` must have a second dimension equal to the number\n        of labels in y_true and y_pred, and those tensors must not be\n        RaggedTensors.\n      label_weights: (optional) tensor of non-negative weights for multilabel\n        data. The weights are applied when calculating TP, FP, FN, and TN\n        without explicit multilabel handling (i.e. when the data is to be\n        flattened).\n      thresholds_distributed_evenly: Boolean, whether the thresholds are evenly\n        distributed within the list. An optimized method will be used if this is\n        the case. 
See _update_confusion_matrix_variables_optimized() for more\n        details.\n\n    Raises:\n      ValueError: If `y_pred` and `y_true` have mismatched shapes, or if\n        `sample_weight` is not `None` and its shape doesn't match `y_pred`, or\n        if `variables_to_update` contains invalid keys.\n    \"\"\"\n    if multi_label and label_weights is not None:\n        raise ValueError(\n            \"`label_weights` for multilabel data should be handled \"\n            \"outside of `update_confusion_matrix_variables` when \"\n            \"`multi_label` is True.\"\n        )\n    if variables_to_update is None:\n        return\n    if not any(\n        key for key in variables_to_update if key in list(ConfusionMatrix)\n    ):\n        raise ValueError(\n            \"Please provide at least one valid confusion matrix \"\n            \"variable to update. Valid variable key options are: \"\n            f'\"{list(ConfusionMatrix)}\". '\n            f'Received: \"{variables_to_update.keys()}\"'\n        )\n\n    variable_dtype = list(variables_to_update.values())[0].dtype\n\n    y_true = ops.cast(y_true, dtype=variable_dtype)\n    y_pred = ops.cast(y_pred, dtype=variable_dtype)\n\n    if thresholds_distributed_evenly:\n        # Check whether the thresholds have any leading or trailing epsilon\n        # added for floating point imprecision. The leading and trailing\n        # thresholds will be handled a bit differently as corner cases. At this\n        # point, thresholds should be a list/array with more than 2 items,\n        # ranging between [0, 1]. 
See is_evenly_distributed_thresholds() for more\n        # details.\n        thresholds_with_epsilon = thresholds[0] < 0.0 or thresholds[-1] > 1.0\n\n    thresholds = ops.convert_to_tensor(thresholds, dtype=variable_dtype)\n    num_thresholds = ops.shape(thresholds)[0]\n\n    if multi_label:\n        one_thresh = ops.equal(\n            np.array(1, dtype=\"int32\"),\n            len(thresholds.shape),\n        )\n    else:\n        one_thresh = np.array(True, dtype=\"bool\")\n\n    invalid_keys = [\n        key for key in variables_to_update if key not in list(ConfusionMatrix)\n    ]\n    if invalid_keys:\n        raise ValueError(\n            f'Invalid keys: \"{invalid_keys}\". '\n            f'Valid variable key options are: \"{list(ConfusionMatrix)}\"'\n        )\n\n    y_pred, y_true = squeeze_or_expand_to_same_rank(y_pred, y_true)\n    if sample_weight is not None:\n        sample_weight = ops.expand_dims(\n            ops.cast(sample_weight, dtype=variable_dtype), axis=-1\n        )\n        _, sample_weight = squeeze_or_expand_to_same_rank(\n            y_true, sample_weight, expand_rank_1=False\n        )\n\n    if top_k is not None:\n        y_pred = _filter_top_k(y_pred, top_k)\n\n    if class_id is not None:\n        if len(y_pred.shape) == 1:\n            raise ValueError(\n                \"When class_id is provided, y_pred must be a 2D array \"\n                \"with shape (num_samples, num_classes), found shape: \"\n                f\"{y_pred.shape}\"\n            )\n\n        # Preserve dimension to match with sample_weight\n        y_true = y_true[..., class_id, None]\n        y_pred = y_pred[..., class_id, None]\n\n    if thresholds_distributed_evenly:\n        return _update_confusion_matrix_variables_optimized(\n            variables_to_update,\n            y_true,\n            y_pred,\n            thresholds,\n            multi_label=multi_label,\n            sample_weights=sample_weight,\n            label_weights=label_weights,\n            
thresholds_with_epsilon=thresholds_with_epsilon,\n        )\n\n    if None in y_pred.shape:\n        pred_shape = ops.shape(y_pred)\n        num_predictions = pred_shape[0]\n        if len(y_pred.shape) == 1:\n            num_labels = 1\n        else:\n            num_labels = ops.cast(\n                ops.prod(ops.array(pred_shape[1:]), axis=0), \"int32\"\n            )\n        thresh_label_tile = ops.where(one_thresh, num_labels, 1)\n    else:\n        pred_shape = ops.shape(y_pred)\n        num_predictions = pred_shape[0]\n        if len(y_pred.shape) == 1:\n            num_labels = 1\n        else:\n            num_labels = np.prod(pred_shape[1:], axis=0).astype(\"int32\")\n        thresh_label_tile = np.where(one_thresh, num_labels, 1)\n\n    # Reshape predictions and labels, adding a dim for thresholding.\n    if multi_label:\n        predictions_extra_dim = ops.expand_dims(y_pred, 0)\n        labels_extra_dim = ops.expand_dims(ops.cast(y_true, dtype=\"bool\"), 0)\n    else:\n        # Flatten predictions and labels when not multilabel.\n        predictions_extra_dim = ops.reshape(y_pred, [1, -1])\n        labels_extra_dim = ops.reshape(ops.cast(y_true, dtype=\"bool\"), [1, -1])\n\n    # Tile the thresholds for every prediction.\n    if multi_label:\n        thresh_pretile_shape = [num_thresholds, 1, -1]\n        thresh_tiles = [1, num_predictions, thresh_label_tile]\n        data_tiles = [num_thresholds, 1, 1]\n    else:\n        thresh_pretile_shape = [num_thresholds, -1]\n        thresh_tiles = [1, num_predictions * num_labels]\n        data_tiles = [num_thresholds, 1]\n\n    thresh_tiled = ops.tile(\n        ops.reshape(thresholds, thresh_pretile_shape), thresh_tiles\n    )\n\n    # Tile the predictions for every threshold.\n    preds_tiled = ops.tile(predictions_extra_dim, data_tiles)\n\n    # Compare predictions and threshold.\n    pred_is_pos = ops.greater(preds_tiled, thresh_tiled)\n\n    # Tile labels by number of thresholds\n    label_is_pos = 
ops.tile(labels_extra_dim, data_tiles)\n\n    if sample_weight is not None:\n        sample_weight = ops.broadcast_to(\n            ops.cast(sample_weight, dtype=y_pred.dtype), ops.shape(y_pred)\n        )\n        weights_tiled = ops.tile(\n            ops.reshape(sample_weight, thresh_tiles), data_tiles\n        )\n    else:\n        weights_tiled = None\n\n    if label_weights is not None and not multi_label:\n        label_weights = ops.expand_dims(label_weights, 0)\n        label_weights = ops.broadcast_to(label_weights, ops.shape(y_pred))\n        label_weights_tiled = ops.tile(\n            ops.reshape(label_weights, thresh_tiles), data_tiles\n        )\n        if weights_tiled is None:\n            weights_tiled = label_weights_tiled\n        else:\n            weights_tiled = ops.multiply(weights_tiled, label_weights_tiled)\n\n    def weighted_assign_add(label, pred, weights, var):\n        label_and_pred = ops.cast(ops.logical_and(label, pred), dtype=var.dtype)\n        if weights is not None:\n            label_and_pred *= ops.cast(weights, dtype=var.dtype)\n        var.assign(var + ops.sum(label_and_pred, 1))\n\n    loop_vars = {\n        ConfusionMatrix.TRUE_POSITIVES: (label_is_pos, pred_is_pos),\n    }\n    update_tn = ConfusionMatrix.TRUE_NEGATIVES in variables_to_update\n    update_fp = ConfusionMatrix.FALSE_POSITIVES in variables_to_update\n    update_fn = ConfusionMatrix.FALSE_NEGATIVES in variables_to_update\n\n    if update_fn or update_tn:\n        pred_is_neg = ops.logical_not(pred_is_pos)\n        loop_vars[ConfusionMatrix.FALSE_NEGATIVES] = (label_is_pos, pred_is_neg)\n\n    if update_fp or update_tn:\n        label_is_neg = ops.logical_not(label_is_pos)\n        loop_vars[ConfusionMatrix.FALSE_POSITIVES] = (label_is_neg, pred_is_pos)\n        if update_tn:\n            loop_vars[ConfusionMatrix.TRUE_NEGATIVES] = (\n                label_is_neg,\n                pred_is_neg,\n            )\n\n    for matrix_cond, (label, pred) in 
loop_vars.items():\n        if matrix_cond in variables_to_update:\n            weighted_assign_add(\n                label, pred, weights_tiled, variables_to_update[matrix_cond]\n            )\n\n\ndef _filter_top_k(x, k):\n    \"\"\"Filters top-k values in the last dim of x and sets the rest to NEG_INF.\n\n    Used for computing top-k prediction values in dense labels (which have the\n    same shape as predictions) for recall and precision top-k metrics.\n\n    Args:\n      x: tensor with any dimensions.\n      k: the number of values to keep.\n\n    Returns:\n      tensor with same shape and dtype as x.\n    \"\"\"\n    _, top_k_idx = ops.top_k(x, k)\n    top_k_mask = ops.sum(\n        ops.one_hot(top_k_idx, ops.shape(x)[-1], axis=-1), axis=-2\n    )\n    return x * top_k_mask + NEG_INF * (1 - top_k_mask)\n\n\ndef confusion_matrix(\n    labels,\n    predictions,\n    num_classes,\n    weights=None,\n    dtype=\"int32\",\n):\n    \"\"\"Computes the confusion matrix from predictions and labels.\n\n    The matrix columns represent the prediction labels and the rows represent\n    the real labels. The confusion matrix is always a 2-D array of shape\n    `(n, n)`, where `n` is the number of valid labels for a given classification\n    task. Both prediction and labels must be 1-D arrays of the same shape in\n    order for this function to work.\n\n    `num_classes` must be at least one plus the\n    maximum value in either predictions or labels. Class labels are expected to\n    start at 0. 
For example, if `num_classes` is 3, then the possible labels\n    would be `[0, 1, 2]`.\n\n    If `weights` is not `None`, then each prediction contributes its\n    corresponding weight to the total value of the confusion matrix cell.\n\n    For example:\n\n    ```python\n    keras.metrics.metrics_utils.confusion_matrix(\n        [1, 2, 4], [2, 2, 4], num_classes=5) ==>\n        [[0 0 0 0 0]\n         [0 0 1 0 0]\n         [0 0 1 0 0]\n         [0 0 0 0 0]\n         [0 0 0 0 1]]\n    ```\n\n    Note that the possible labels are assumed to be `[0, 1, 2, 3, 4]`,\n    resulting in a 5x5 confusion matrix.\n\n    Args:\n        labels: 1-D tensor of real labels for the classification task.\n        predictions: 1-D tensor of predictions for a given classification.\n        num_classes: The possible number of labels the classification\n            task can have.\n        weights: An optional tensor whose shape matches `predictions`.\n        dtype: Data type of the confusion matrix.\n\n    Returns:\n        A tensor of type `dtype` with shape `(n, n)` representing the confusion\n        matrix, where `n` is the number of possible labels in the classification\n        task.\n    \"\"\"\n    labels = ops.convert_to_tensor(labels, dtype)\n    predictions = ops.convert_to_tensor(predictions, dtype)\n    labels, predictions = squeeze_or_expand_to_same_rank(labels, predictions)\n\n    predictions = ops.cast(predictions, dtype)\n    labels = ops.cast(labels, dtype)\n\n    if weights is not None:\n        weights = ops.convert_to_tensor(weights, dtype)\n\n    indices = ops.stack([labels, predictions], axis=1)\n    values = ops.ones_like(predictions, dtype) if weights is None else weights\n    indices = ops.cast(indices, dtype=\"int64\")\n    values = ops.cast(values, dtype=dtype)\n    num_classes = int(num_classes)\n    confusion_matrix = ops.scatter(indices, values, (num_classes, num_classes))\n    return confusion_matrix\n"
  },
  {
    "path": "keras/src/metrics/probabilistic_metrics.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.losses.losses import binary_crossentropy\nfrom keras.src.losses.losses import categorical_crossentropy\nfrom keras.src.losses.losses import kl_divergence\nfrom keras.src.losses.losses import poisson\nfrom keras.src.losses.losses import sparse_categorical_crossentropy\nfrom keras.src.metrics import reduction_metrics\n\n\n@keras_export(\"keras.metrics.KLDivergence\")\nclass KLDivergence(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes Kullback-Leibler divergence metric between `y_true` and\n    `y_pred`.\n\n    Formula:\n\n    ```python\n    metric = y_true * log(y_true / y_pred)\n    ```\n\n    `y_true` and `y_pred` are expected to be probability\n    distributions, with values between 0 and 1. They will get\n    clipped to the `[0, 1]` range.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Examples:\n\n    >>> m = keras.metrics.KLDivergence()\n    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])\n    >>> m.result()\n    0.45814306\n\n    >>> m.reset_state()\n    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],\n    ...                
sample_weight=[1, 0])\n    >>> m.result()\n    0.9162892\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='mse',\n                  metrics=[keras.metrics.KLDivergence()])\n    ```\n    \"\"\"\n\n    def __init__(self, name=\"kl_divergence\", dtype=None):\n        super().__init__(fn=kl_divergence, name=name, dtype=dtype)\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n@keras_export(\"keras.metrics.Poisson\")\nclass Poisson(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes the Poisson metric between `y_true` and `y_pred`.\n\n    Formula:\n\n    ```python\n    metric = y_pred - y_true * log(y_pred)\n    ```\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Examples:\n\n    >>> m = keras.metrics.Poisson()\n    >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])\n    >>> m.result()\n    0.49999997\n\n    >>> m.reset_state()\n    >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],\n    ...                
sample_weight=[1, 0])\n    >>> m.result()\n    0.99999994\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='mse',\n                  metrics=[keras.metrics.Poisson()])\n    ```\n    \"\"\"\n\n    def __init__(self, name=\"poisson\", dtype=None):\n        super().__init__(fn=poisson, name=name, dtype=dtype)\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n@keras_export(\"keras.metrics.BinaryCrossentropy\")\nclass BinaryCrossentropy(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes the crossentropy metric between the labels and predictions.\n\n    This is the crossentropy metric class to be used when there are only two\n    label classes (0 and 1).\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        from_logits: (Optional) Whether output is expected\n            to be a logits tensor. By default, we consider\n            that output encodes a probability distribution.\n        label_smoothing: (Optional) Float in `[0, 1]`.\n            When > 0, label values are smoothed,\n            meaning the confidence on label values is relaxed,\n            e.g. `label_smoothing=0.2` means that we will use\n            a value of 0.1 for label \"0\" and 0.9 for label \"1\".\n\n    Examples:\n\n    >>> m = keras.metrics.BinaryCrossentropy()\n    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]])\n    >>> m.result()\n    0.81492424\n\n    >>> m.reset_state()\n    >>> m.update_state([[0, 1], [0, 0]], [[0.6, 0.4], [0.4, 0.6]],\n    ...                
sample_weight=[1, 0])\n    >>> m.result()\n    0.9162905\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.BinaryCrossentropy()])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        name=\"binary_crossentropy\",\n        dtype=None,\n        from_logits=False,\n        label_smoothing=0,\n    ):\n        super().__init__(\n            binary_crossentropy,\n            name,\n            dtype=dtype,\n            from_logits=from_logits,\n            label_smoothing=label_smoothing,\n        )\n        self.from_logits = from_logits\n        self.label_smoothing = label_smoothing\n        # Metric should be minimized during optimization.\n        self._direction = \"down\"\n\n    def get_config(self):\n        return {\n            \"name\": self.name,\n            \"dtype\": self.dtype,\n            \"from_logits\": self.from_logits,\n            \"label_smoothing\": self.label_smoothing,\n        }\n\n\n@keras_export(\"keras.metrics.CategoricalCrossentropy\")\nclass CategoricalCrossentropy(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes the crossentropy metric between the labels and predictions.\n\n    This is the crossentropy metric class to be used when there are multiple\n    label classes (2 or more). It assumes that labels are one-hot encoded,\n    e.g., when label values are `[2, 0, 1]`, then\n    `y_true` is `[[0, 0, 1], [1, 0, 0], [0, 1, 0]]`.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        from_logits: (Optional) Whether output is expected to be\n            a logits tensor. By default, we consider that output\n            encodes a probability distribution.\n        label_smoothing: (Optional) Float in `[0, 1]`.\n            When > 0, label values are smoothed, meaning the confidence\n            in label values is relaxed. e.g. 
`label_smoothing=0.2` means\n            that we will use a value of 0.1 for label\n            \"0\" and 0.9 for label \"1\".\n        axis: (Optional) Defaults to `-1`.\n            The dimension along which entropy is computed.\n\n    Examples:\n\n    >>> # EPSILON = 1e-7, y = y_true, y' = y_pred\n    >>> # y' = clip_ops.clip_by_value(output, EPSILON, 1. - EPSILON)\n    >>> # y' = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]\n    >>> # xent = -sum(y * log(y'), axis = -1)\n    >>> #      = -((log 0.95), (log 0.1))\n    >>> #      = [0.051, 2.302]\n    >>> # Reduced xent = (0.051 + 2.302) / 2\n    >>> m = keras.metrics.CategoricalCrossentropy()\n    >>> m.update_state([[0, 1, 0], [0, 0, 1]],\n    ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]])\n    >>> m.result()\n    1.1769392\n\n    >>> m.reset_state()\n    >>> m.update_state([[0, 1, 0], [0, 0, 1]],\n    ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],\n    ...                sample_weight=np.array([0.3, 0.7]))\n    >>> m.result()\n    1.6271976\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.CategoricalCrossentropy()])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        name=\"categorical_crossentropy\",\n        dtype=None,\n        from_logits=False,\n        label_smoothing=0,\n        axis=-1,\n    ):\n        super().__init__(\n            categorical_crossentropy,\n            name,\n            dtype=dtype,\n            from_logits=from_logits,\n            label_smoothing=label_smoothing,\n            axis=axis,\n        )\n        self.from_logits = from_logits\n        self.label_smoothing = label_smoothing\n        self.axis = axis\n        # Metric should be minimized during optimization.\n        self._direction = \"down\"\n\n    def get_config(self):\n        return {\n            \"name\": self.name,\n            \"dtype\": self.dtype,\n            \"from_logits\": 
self.from_logits,\n            \"label_smoothing\": self.label_smoothing,\n            \"axis\": self.axis,\n        }\n\n\n@keras_export(\"keras.metrics.SparseCategoricalCrossentropy\")\nclass SparseCategoricalCrossentropy(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes the crossentropy metric between the labels and predictions.\n\n    Use this crossentropy metric when there are two or more label classes.\n    It expects labels to be provided as integers. If you want to provide labels\n    that are one-hot encoded, please use the `CategoricalCrossentropy`\n    metric instead.\n\n    There should be `num_classes` floating point values per feature for `y_pred`\n    and a single floating point value per feature for `y_true`.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        from_logits: (Optional) Whether output is expected\n            to be a logits tensor. By default, we consider that output\n            encodes a probability distribution.\n        ignore_class: (Optional) Integer.\n            Class to ignore. If `None`, no class is ignored.\n        axis: (Optional) Defaults to `-1`.\n            The dimension along which entropy is computed.\n\n    Examples:\n\n    >>> # y_true = one_hot(y_true) = [[0, 1, 0], [0, 0, 1]]\n    >>> # logits = log(y_pred)\n    >>> # softmax = exp(logits) / sum(exp(logits), axis=-1)\n    >>> # softmax = [[0.05, 0.95, EPSILON], [0.1, 0.8, 0.1]]\n    >>> # xent = -sum(y * log(softmax), 1)\n    >>> # log(softmax) = [[-2.9957, -0.0513, -16.1181],\n    >>> #                [-2.3026, -0.2231, -2.3026]]\n    >>> # y_true * log(softmax) = [[0, -0.0513, 0], [0, 0, -2.3026]]\n    >>> # xent = [0.0513, 2.3026]\n    >>> # Reduced xent = (0.0513 + 2.3026) / 2\n    >>> m = keras.metrics.SparseCategoricalCrossentropy()\n    >>> m.update_state([1, 2],\n    ...                
[[0.05, 0.95, 0], [0.1, 0.8, 0.1]])\n    >>> m.result()\n    1.1769392\n\n    >>> m.reset_state()\n    >>> m.update_state([1, 2],\n    ...                [[0.05, 0.95, 0], [0.1, 0.8, 0.1]],\n    ...                sample_weight=np.array([0.3, 0.7]))\n    >>> m.result()\n    1.6271976\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.SparseCategoricalCrossentropy()])\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        name=\"sparse_categorical_crossentropy\",\n        dtype=None,\n        from_logits=False,\n        ignore_class=None,\n        axis=-1,\n    ):\n        super().__init__(\n            sparse_categorical_crossentropy,\n            name=name,\n            dtype=dtype,\n            from_logits=from_logits,\n            ignore_class=ignore_class,\n            axis=axis,\n        )\n        self.from_logits = from_logits\n        self.ignore_class = ignore_class\n        self.axis = axis\n        # Metric should be minimized during optimization.\n        self._direction = \"down\"\n\n    def get_config(self):\n        return {\n            \"name\": self.name,\n            \"dtype\": self.dtype,\n            \"from_logits\": self.from_logits,\n            \"ignore_class\": self.ignore_class,\n            \"axis\": self.axis,\n        }\n"
  },
  {
    "path": "keras/src/metrics/probabilistic_metrics_test.py",
    "content": "import numpy as np\n\nfrom keras.src import metrics\nfrom keras.src import testing\n\n\nclass KLDivergenceTest(testing.TestCase):\n    def setup(self):\n        self.y_pred = np.asarray(\n            [0.4, 0.9, 0.12, 0.36, 0.3, 0.4], dtype=np.float32\n        ).reshape((2, 3))\n        self.y_true = np.asarray(\n            [0.5, 0.8, 0.12, 0.7, 0.43, 0.8], dtype=np.float32\n        ).reshape((2, 3))\n\n        self.batch_size = 2\n        self.expected_results = np.multiply(\n            self.y_true, np.log(self.y_true / self.y_pred)\n        )\n\n    def test_config(self):\n        k_obj = metrics.KLDivergence(name=\"kld\", dtype=\"int32\")\n        self.assertEqual(k_obj.name, \"kld\")\n        self.assertEqual(k_obj._dtype, \"int32\")\n\n        k_obj2 = metrics.KLDivergence.from_config(k_obj.get_config())\n        self.assertEqual(k_obj2.name, \"kld\")\n        self.assertEqual(k_obj2._dtype, \"int32\")\n\n    def test_unweighted(self):\n        self.setup()\n        k_obj = metrics.KLDivergence()\n\n        k_obj.update_state(self.y_true, self.y_pred)\n        result = k_obj.result()\n        expected_result = np.sum(self.expected_results) / self.batch_size\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_weighted(self):\n        self.setup()\n        k_obj = metrics.KLDivergence()\n\n        sample_weight = np.asarray([1.2, 3.4], dtype=np.float32).reshape((2, 1))\n        result = k_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n\n        sample_weight = np.asarray(\n            [1.2, 1.2, 1.2, 3.4, 3.4, 3.4], dtype=np.float32\n        ).reshape((2, 3))\n        expected_result = np.multiply(self.expected_results, sample_weight)\n        expected_result = np.sum(expected_result) / (1.2 + 3.4)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n\nclass PoissonTest(testing.TestCase):\n    def setup(self):\n        self.y_pred = np.asarray([1, 9, 2, 5, 2, 6], 
dtype=np.float32).reshape(\n            (2, 3)\n        )\n        self.y_true = np.asarray([4, 8, 12, 8, 1, 3], dtype=np.float32).reshape(\n            (2, 3)\n        )\n        self.batch_size = 6\n        self.expected_results = self.y_pred - np.multiply(\n            self.y_true, np.log(self.y_pred)\n        )\n\n    def test_config(self):\n        self.run_class_serialization_test(metrics.Poisson(name=\"poisson\"))\n\n    def test_unweighted(self):\n        self.setup()\n        poisson_obj = metrics.Poisson()\n        poisson_obj.update_state(self.y_true, self.y_pred)\n\n        result = poisson_obj.result()\n        expected_result = np.sum(self.expected_results) / self.batch_size\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_weighted(self):\n        self.setup()\n        poisson_obj = metrics.Poisson()\n        sample_weight = np.asarray([1.2, 3.4], dtype=np.float32).reshape((2, 1))\n\n        result = poisson_obj(\n            self.y_true, self.y_pred, sample_weight=sample_weight\n        )\n        sample_weight = np.asarray(\n            [1.2, 1.2, 1.2, 3.4, 3.4, 3.4], dtype=np.float32\n        ).reshape((2, 3))\n        expected_result = np.multiply(self.expected_results, sample_weight)\n        expected_result = np.sum(expected_result) / np.sum(sample_weight)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n\nclass BinaryCrossentropyTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            metrics.BinaryCrossentropy(\n                name=\"bce\", dtype=\"int32\", label_smoothing=0.2\n            )\n        )\n\n    def test_unweighted(self):\n        bce_obj = metrics.BinaryCrossentropy()\n        y_true = np.array([1, 0, 1, 0]).reshape([2, 2])\n        y_pred = np.array([1, 1, 1, 0], dtype=np.float32).reshape([2, 2])\n        result = bce_obj(y_true, y_pred)\n        self.assertAllClose(result, 3.9855, atol=1e-3)\n\n    def 
test_unweighted_with_logits(self):\n        bce_obj = metrics.BinaryCrossentropy(from_logits=True)\n\n        y_true = np.array([[1, 0, 1], [0, 1, 1]])\n        y_pred = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]])\n        result = bce_obj(y_true, y_pred)\n        self.assertAllClose(result, 3.333, atol=1e-3)\n\n    def test_weighted(self):\n        bce_obj = metrics.BinaryCrossentropy()\n        y_true = np.array([1, 0, 1, 0]).reshape([2, 2])\n        y_pred = np.array([1, 1, 1, 0], dtype=np.float32).reshape([2, 2])\n        sample_weight = np.array([1.5, 2.0])\n        result = bce_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(result, 3.4162, atol=1e-3)\n\n    def test_weighted_from_logits(self):\n        bce_obj = metrics.BinaryCrossentropy(from_logits=True)\n        y_true = np.array([[1, 0, 1], [0, 1, 1]])\n        y_pred = np.array([[10.0, -10.0, 10.0], [10.0, 10.0, -10.0]])\n        sample_weight = np.array([2.0, 2.5])\n        result = bce_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(result, 3.7037, atol=1e-3)\n\n    def test_label_smoothing(self):\n        logits = np.array(((10.0, -10.0, -10.0)))\n        y_true = np.array(((1, 0, 1)))\n        label_smoothing = 0.1\n        bce_obj = metrics.BinaryCrossentropy(\n            from_logits=True, label_smoothing=label_smoothing\n        )\n        result = bce_obj(y_true, logits)\n        expected_value = (10.0 + 5.0 * label_smoothing) / 3.0\n        self.assertAllClose(expected_value, result, atol=1e-3)\n\n\nclass CategoricalCrossentropyTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            metrics.CategoricalCrossentropy(\n                name=\"cce\", dtype=\"int32\", label_smoothing=0.2\n            )\n        )\n\n    def test_unweighted(self):\n        cce_obj = metrics.CategoricalCrossentropy()\n        y_true = np.array([[0, 1, 0], [0, 0, 1]])\n        y_pred = 
np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])\n        result = cce_obj(y_true, y_pred)\n        self.assertAllClose(result, 1.176, atol=1e-3)\n\n    def test_unweighted_from_logits(self):\n        cce_obj = metrics.CategoricalCrossentropy(from_logits=True)\n\n        y_true = np.array([[0, 1, 0], [0, 0, 1]])\n        logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32)\n        result = cce_obj(y_true, logits)\n        self.assertAllClose(result, 3.5011, atol=1e-3)\n\n    def test_weighted(self):\n        cce_obj = metrics.CategoricalCrossentropy()\n\n        y_true = np.array([[0, 1, 0], [0, 0, 1]])\n        y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])\n        sample_weight = np.array([1.5, 2.0])\n        result = cce_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(result, 1.338, atol=1e-3)\n\n    def test_weighted_from_logits(self):\n        cce_obj = metrics.CategoricalCrossentropy(from_logits=True)\n\n        y_true = np.array([[0, 1, 0], [0, 0, 1]])\n        logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32)\n        sample_weight = np.array([1.5, 2.0])\n        result = cce_obj(y_true, logits, sample_weight=sample_weight)\n        self.assertAllClose(result, 4.0012, atol=1e-3)\n\n    def test_label_smoothing(self):\n        y_true = np.array([[0, 1, 0], [0, 0, 1]])\n        logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32)\n        label_smoothing = 0.1\n        cce_obj = metrics.CategoricalCrossentropy(\n            from_logits=True, label_smoothing=label_smoothing\n        )\n        loss = cce_obj(y_true, logits)\n        self.assertAllClose(loss, 3.667, atol=1e-3)\n\n\nclass SparseCategoricalCrossentropyTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            metrics.SparseCategoricalCrossentropy(name=\"scce\", dtype=\"int32\")\n        )\n\n    def test_unweighted(self):\n        scce_obj = metrics.SparseCategoricalCrossentropy()\n\n      
  y_true = np.array([1, 2])\n        y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])\n        result = scce_obj(y_true, y_pred)\n        self.assertAllClose(result, 1.176, atol=1e-3)\n\n    def test_unweighted_from_logits(self):\n        scce_obj = metrics.SparseCategoricalCrossentropy(from_logits=True)\n\n        y_true = np.array([1, 2])\n        logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32)\n        result = scce_obj(y_true, logits)\n        self.assertAllClose(result, 3.5011, atol=1e-3)\n\n    def test_weighted(self):\n        scce_obj = metrics.SparseCategoricalCrossentropy()\n\n        y_true = np.array([1, 2])\n        y_pred = np.array([[0.05, 0.95, 0], [0.1, 0.8, 0.1]])\n        sample_weight = np.array([1.5, 2.0])\n        result = scce_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(result, 1.338, atol=1e-3)\n\n    def test_weighted_from_logits(self):\n        scce_obj = metrics.SparseCategoricalCrossentropy(from_logits=True)\n\n        y_true = np.array([1, 2])\n        logits = np.array([[1, 9, 0], [1, 8, 1]], dtype=np.float32)\n        sample_weight = np.array([1.5, 2.0])\n        result = scce_obj(y_true, logits, sample_weight=sample_weight)\n        self.assertAllClose(result, 4.0012, atol=1e-3)\n\n    def test_ignore_class(self):\n        scce_obj = metrics.SparseCategoricalCrossentropy(ignore_class=0)\n        y_true = np.array([0, 1, 2])\n        y_pred = np.array([[0.8, 0.1, 0.1], [0.05, 0.95, 0], [0.1, 0.8, 0.1]])\n        result = scce_obj(y_true, y_pred)\n        self.assertAllClose(result, 0.78462, atol=1e-3)\n\n    def test_ignore_class_weighted(self):\n        scce_obj = metrics.SparseCategoricalCrossentropy(ignore_class=0)\n        y_true = np.array([0, 1, 2])\n        y_pred = np.array([[0.8, 0.1, 0.1], [0.05, 0.95, 0], [0.1, 0.8, 0.1]])\n        sample_weight = np.array([0.5, 1.5, 2.0])\n        result = scce_obj(y_true, y_pred, sample_weight=sample_weight)\n        
self.assertAllClose(result, 1.33774646, atol=1e-3)\n\n    def test_ignore_class_from_logits(self):\n        scce_obj = metrics.SparseCategoricalCrossentropy(\n            from_logits=True, ignore_class=0\n        )\n        y_true = np.array([0, 1, 2])\n        logits = np.array([[10, 1, 1], [1, 9, 0], [1, 8, 1]], dtype=np.float32)\n        result = scce_obj(y_true, logits)\n        self.assertAllClose(result, 2.33409, atol=1e-3)\n"
  },
  {
    "path": "keras/src/metrics/reduction_metrics.py",
    "content": "from keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import losses\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.metrics.metric import Metric\nfrom keras.src.saving import serialization_lib\n\n\ndef reduce_to_samplewise_values(values, sample_weight, reduce_fn, dtype):\n    dtype = dtype or backend.floatx()\n    mask = backend.get_keras_mask(values)\n    values = ops.cast(values, dtype=dtype)\n    if sample_weight is not None:\n        sample_weight = ops.convert_to_tensor(sample_weight, dtype=dtype)\n\n        if mask is not None:\n            sample_weight = losses.loss.apply_mask(\n                sample_weight, mask, dtype=dtype, reduction=\"sum\"\n            )\n        # Update dimensions of weights to match with values if possible.\n        values, sample_weight = losses.loss.squeeze_or_expand_to_same_rank(\n            values, sample_weight\n        )\n        # Reduce values to same ndim as weight array.\n        weight_ndim = len(sample_weight.shape)\n        values_ndim = len(values.shape)\n        if values_ndim > weight_ndim:\n            values = reduce_fn(\n                values, axis=list(range(weight_ndim, values_ndim))\n            )\n        # Broadcast sample_weight. 
It doesn't change the multiplication below\n        # but changes the sample_weight reduction applied later.\n        sample_weight = ops.broadcast_to(sample_weight, ops.shape(values))\n        values = values * sample_weight\n        if weight_ndim > 1:\n            sample_weight = reduce_fn(\n                sample_weight, axis=list(range(1, weight_ndim))\n            )\n\n    values_ndim = len(values.shape)\n    if values_ndim > 1:\n        values = reduce_fn(values, axis=list(range(1, values_ndim)))\n    return values, sample_weight\n\n\n@keras_export(\"keras.metrics.Sum\")\nclass Sum(Metric):\n    \"\"\"Compute the (weighted) sum of the given values.\n\n    For example, if `values` is `[1, 3, 5, 7]` then their sum is 16.\n    If `sample_weight` was specified as `[1, 1, 0, 0]` then the sum would be 4.\n\n    This metric creates one variable, `total`.\n    This is ultimately returned as the sum value.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = metrics.Sum()\n    >>> m.update_state([1, 3, 5, 7])\n    >>> m.result()\n    16.0\n\n    >>> m = metrics.Sum()\n    >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])\n    >>> m.result()\n    4.0\n    \"\"\"\n\n    def __init__(self, name=\"sum\", dtype=None):\n        super().__init__(name=name, dtype=dtype)\n        self.total = self.add_variable(\n            shape=(),\n            initializer=initializers.Zeros(),\n            dtype=self.dtype,\n            name=\"total\",\n        )\n\n    def update_state(self, values, sample_weight=None):\n        values, _ = reduce_to_samplewise_values(\n            values, sample_weight, reduce_fn=ops.sum, dtype=self.dtype\n        )\n        self.total.assign_add(ops.sum(values))\n\n    def reset_state(self):\n        self.total.assign(0)\n\n    def result(self):\n        return ops.cast(self.total, 
self.dtype)\n\n\n@keras_export(\"keras.metrics.Mean\")\nclass Mean(Metric):\n    \"\"\"Compute the (weighted) mean of the given values.\n\n    For example, if values is `[1, 3, 5, 7]` then the mean is 4.\n    If `sample_weight` was specified as `[1, 1, 0, 0]` then the mean would be 2.\n\n    This metric creates two variables, `total` and `count`.\n    The mean value returned is simply `total` divided by `count`.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n\n    >>> m = Mean()\n    >>> m.update_state([1, 3, 5, 7])\n    >>> m.result()\n    4.0\n\n    >>> m.reset_state()\n    >>> m.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])\n    >>> m.result()\n    2.0\n    \"\"\"\n\n    def __init__(self, name=\"mean\", dtype=None):\n        super().__init__(name=name, dtype=dtype)\n        self.total = self.add_variable(\n            shape=(),\n            initializer=initializers.Zeros(),\n            dtype=self.dtype,\n            name=\"total\",\n        )\n        self.count = self.add_variable(\n            shape=(),\n            initializer=initializers.Zeros(),\n            dtype=self.dtype,\n            name=\"count\",\n        )\n\n    def update_state(self, values, sample_weight=None):\n        values, sample_weight = reduce_to_samplewise_values(\n            values, sample_weight, reduce_fn=ops.mean, dtype=self.dtype\n        )\n        self.total.assign_add(ops.sum(values))\n        if sample_weight is not None:\n            num_samples = ops.sum(sample_weight)\n        elif len(values.shape) >= 1:\n            num_samples = ops.shape(values)[0]\n        else:\n            num_samples = 1\n        self.count.assign_add(ops.cast(num_samples, dtype=self.dtype))\n\n    def reset_state(self):\n        self.total.assign(0)\n        self.count.assign(0)\n\n    def result(self):\n        return ops.divide_no_nan(\n            self.total, ops.cast(self.count, 
dtype=self.dtype)\n        )\n\n\n@keras_export(\"keras.metrics.MeanMetricWrapper\")\nclass MeanMetricWrapper(Mean):\n    \"\"\"Wrap a stateless metric function with the `Mean` metric.\n\n    You could use this class to quickly build a mean metric from a function. The\n    function needs to have the signature `fn(y_true, y_pred)` and return a\n    per-sample loss array. `MeanMetricWrapper.result()` will return\n    the average metric value across all samples seen so far.\n\n    For example:\n\n    ```python\n    def mse(y_true, y_pred):\n        return (y_true - y_pred) ** 2\n\n    mse_metric = MeanMetricWrapper(fn=mse)\n    ```\n\n    Args:\n        fn: The metric function to wrap, with signature\n            `fn(y_true, y_pred, **kwargs)`.\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n        **kwargs: Keyword arguments to pass on to `fn`.\n    \"\"\"\n\n    def __init__(self, fn, name=None, dtype=None, **kwargs):\n        super().__init__(name=name, dtype=dtype)\n        self._fn = fn\n        self._fn_kwargs = kwargs\n\n        # If we are wrapping a Keras loss, register the metric's\n        # direction as \"down\" (needs to be minimized during training).\n        if (\n            self._fn in losses.ALL_OBJECTS\n            or hasattr(self._fn, \"__class__\")\n            and self._fn.__class__ in losses.ALL_OBJECTS\n        ):\n            self._direction = \"down\"\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        mask = backend.get_keras_mask(y_pred)\n        values = self._fn(y_true, y_pred, **self._fn_kwargs)\n        sample_weight = losses.loss.apply_mask(\n            sample_weight, mask, dtype=self.dtype, reduction=\"sum\"\n        )\n        return super().update_state(values, sample_weight=sample_weight)\n\n    def get_config(self):\n        base_config = super().get_config()\n        config = {\"fn\": 
serialization_lib.serialize_keras_object(self._fn)}\n        config.update(serialization_lib.serialize_keras_object(self._fn_kwargs))\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config):\n        if \"fn\" in config:\n            config = serialization_lib.deserialize_keras_object(config)\n        return cls(**config)\n"
  },
  {
    "path": "keras/src/metrics/reduction_metrics_test.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import metrics\nfrom keras.src import models\nfrom keras.src import testing\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.metrics import reduction_metrics\nfrom keras.src.saving import register_keras_serializable\n\n\nclass SumTest(testing.TestCase):\n    def test_config(self):\n        sum_obj = reduction_metrics.Sum(name=\"sum\", dtype=\"float32\")\n        self.assertEqual(sum_obj.name, \"sum\")\n        self.assertEqual(len(sum_obj.variables), 1)\n        self.assertEqual(sum_obj._dtype, \"float32\")\n\n        # Check save and restore config\n        sum_obj2 = reduction_metrics.Sum.from_config(sum_obj.get_config())\n        self.assertEqual(sum_obj2.name, \"sum\")\n        self.assertEqual(len(sum_obj2.variables), 1)\n        self.assertEqual(sum_obj2._dtype, \"float32\")\n\n    def test_unweighted(self):\n        sum_obj = reduction_metrics.Sum(name=\"sum\", dtype=\"float32\")\n        sum_obj.update_state([1, 3, 5, 7])\n        result = sum_obj.result()\n        self.assertAllClose(result, 16.0, atol=1e-3)\n\n    def test_weighted(self):\n        sum_obj = reduction_metrics.Sum(name=\"sum\", dtype=\"float32\")\n        sum_obj.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])\n        result = sum_obj.result()\n        self.assertAllClose(result, 4.0, atol=1e-3)\n\n    def test_weighted_nd(self):\n        sum_obj = reduction_metrics.Sum(name=\"sum\", dtype=\"float32\")\n        sum_obj.update_state([[1, 3], [5, 7]], sample_weight=[[1, 1], [1, 0]])\n        result = sum_obj.result()\n        self.assertAllClose(result, 9.0, atol=1e-3)\n\n    def test_weighted_nd_broadcast(self):\n        sum_obj = reduction_metrics.Sum(name=\"sum\", dtype=\"float32\")\n        sum_obj.update_state([[1, 3], [5, 7]], sample_weight=[[1, 0]])\n        result = sum_obj.result()\n        self.assertAllClose(result, 6.0, 
atol=1e-3)\n\n\nclass MeanTest(testing.TestCase):\n    def test_config(self):\n        mean_obj = reduction_metrics.Mean(name=\"mean\", dtype=\"float32\")\n        self.assertEqual(mean_obj.name, \"mean\")\n        self.assertEqual(len(mean_obj.variables), 2)\n        self.assertEqual(mean_obj._dtype, \"float32\")\n\n        # Check save and restore config\n        mean_obj2 = reduction_metrics.Mean.from_config(mean_obj.get_config())\n        self.assertEqual(mean_obj2.name, \"mean\")\n        self.assertEqual(len(mean_obj2.variables), 2)\n        self.assertEqual(mean_obj2._dtype, \"float32\")\n\n    def test_unweighted(self):\n        mean_obj = reduction_metrics.Mean(name=\"mean\", dtype=\"float32\")\n        mean_obj.update_state([1, 3, 5, 7])\n        result = mean_obj.result()\n        self.assertAllClose(result, 4.0, atol=1e-3)\n\n    def test_weighted(self):\n        mean_obj = reduction_metrics.Mean(name=\"mean\", dtype=\"float32\")\n        mean_obj.update_state([1, 3, 5, 7], sample_weight=[1, 1, 0, 0])\n        result = mean_obj.result()\n        self.assertAllClose(result, 2.0, atol=1e-3)\n\n    def test_weighted_negative_weights(self):\n        mean_obj = reduction_metrics.Mean(name=\"mean\", dtype=\"float32\")\n        mean_obj.update_state([1, 3, 5, 7], sample_weight=[-1, -1, 0, 0])\n        result = mean_obj.result()\n        self.assertAllClose(result, 2.0, atol=1e-3)\n\n    def test_weighted_nd(self):\n        mean_obj = reduction_metrics.Mean(name=\"mean\", dtype=\"float32\")\n        mean_obj.update_state([[1, 3], [5, 7]], sample_weight=[[1, 1], [1, 0]])\n        result = mean_obj.result()\n        self.assertAllClose(result, 3.0, atol=1e-3)\n\n    def test_weighted_nd_broadcast(self):\n        mean_obj = reduction_metrics.Mean(name=\"mean\", dtype=\"float32\")\n        mean_obj.update_state([[1, 3], [5, 7]], sample_weight=[[1, 0]])\n        result = mean_obj.result()\n        self.assertAllClose(result, 3.0, atol=1e-3)\n\n    def 
test_weighted_dynamic_shapes(self):\n        mean_obj = reduction_metrics.Mean(name=\"mean\", dtype=\"float32\")\n        result = backend.compute_output_spec(\n            mean_obj, KerasTensor((None, 2)), KerasTensor((None, 2))\n        )\n        self.assertAllEqual(result.shape, ())\n\n\n# How users would register a custom function or class to use with\n# MeanMetricWrapper.\n@register_keras_serializable(package=\"test\", name=\"mse\")\ndef mse(y_true, y_pred):\n    return (y_true - y_pred) ** 2\n\n\nclass MetricWrapperTest(testing.TestCase):\n    def test_config(self):\n        mse_obj = reduction_metrics.MeanMetricWrapper(\n            fn=mse, name=\"mse\", dtype=\"float32\"\n        )\n        self.assertEqual(mse_obj.name, \"mse\")\n        self.assertEqual(len(mse_obj.variables), 2)\n        self.assertEqual(mse_obj._dtype, \"float32\")\n        # Check save and restore config\n        mse_obj2 = reduction_metrics.MeanMetricWrapper.from_config(\n            mse_obj.get_config()\n        )\n        self.assertEqual(mse_obj2.name, \"mse\")\n        self.assertEqual(len(mse_obj2.variables), 2)\n        self.assertEqual(mse_obj2._dtype, \"float32\")\n        self.assertTrue(\"fn\" in mse_obj2.get_config())\n\n    def test_unweighted(self):\n        mse_obj = reduction_metrics.MeanMetricWrapper(\n            fn=mse, name=\"mse\", dtype=\"float32\"\n        )\n        y_true = np.array(\n            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]\n        )\n        y_pred = np.array(\n            [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]\n        )\n\n        mse_obj.update_state(y_true, y_pred)\n        result = mse_obj.result()\n        self.assertAllClose(0.5, result, atol=1e-5)\n\n    def test_weighted(self):\n        mse_obj = reduction_metrics.MeanMetricWrapper(\n            fn=mse, name=\"mse\", dtype=\"float32\"\n        )\n        y_true = np.array(\n            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 
1, 0], [0, 0, 0, 0, 1]]\n        )\n        y_pred = np.array(\n            [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]\n        )\n        sample_weight = np.array([1.0, 1.5, 2.0, 2.5])\n        result = mse_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(0.54285, result, atol=1e-5)\n\n    def test_weighted_broadcast(self):\n        mse_obj = reduction_metrics.MeanMetricWrapper(\n            fn=mse, name=\"mse\", dtype=\"float32\"\n        )\n        y_true = np.array(\n            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]\n        )\n        y_pred = np.array(\n            [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]\n        )\n        sample_weight = np.array([[1.0, 0.0, 0.5, 0.0, 1.0]])\n        result = mse_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(0.45, result, atol=1e-5)\n\n    def test_weighted_dynamic_shape(self):\n        mse_obj = reduction_metrics.MeanMetricWrapper(\n            fn=mse, name=\"mse\", dtype=\"float32\"\n        )\n        result = backend.compute_output_spec(\n            mse_obj,\n            KerasTensor((None, 5)),\n            KerasTensor((None, 5)),\n            KerasTensor((None, 5)),\n        )\n        self.assertAllEqual(result.shape, ())\n\n    def test_binary_accuracy_with_boolean_inputs(self):\n        inp = layers.Input(shape=(1,))\n        out = inp > 0.5\n        model = models.Model(inputs=inp, outputs=out)\n\n        x = np.random.rand(32, 1)\n        y = x > 0.5\n\n        res = model.predict(x)\n        metric = metrics.BinaryAccuracy()\n        metric.update_state(y, res)\n        result = metric.result()\n        self.assertEqual(result, 1.0)\n"
  },
  {
    "path": "keras/src/metrics/regression_metrics.py",
    "content": "import warnings\n\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.losses.loss import squeeze_or_expand_to_same_rank\nfrom keras.src.losses.losses import log_cosh\nfrom keras.src.losses.losses import mean_absolute_error\nfrom keras.src.losses.losses import mean_absolute_percentage_error\nfrom keras.src.losses.losses import mean_squared_error\nfrom keras.src.losses.losses import mean_squared_logarithmic_error\nfrom keras.src.metrics import reduction_metrics\nfrom keras.src.utils.numerical_utils import normalize\n\n\n@keras_export(\"keras.metrics.MeanSquaredError\")\nclass MeanSquaredError(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes the mean squared error between `y_true` and `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = mean(square(y_true - y_pred))\n    ```\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Example:\n    >>> m = keras.metrics.MeanSquaredError()\n    >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])\n    >>> m.result()\n    0.25\n    \"\"\"\n\n    def __init__(self, name=\"mean_squared_error\", dtype=None):\n        super().__init__(fn=mean_squared_error, name=name, dtype=dtype)\n        # Metric should be minimized during optimization.\n        self._direction = \"down\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n@keras_export(\"keras.metrics.MeanAbsoluteError\")\nclass MeanAbsoluteError(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes the mean absolute error between the labels and predictions.\n\n    Formula:\n\n    ```python\n    loss = mean(abs(y_true - y_pred))\n    ```\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Examples:\n\n    >>> m = keras.metrics.MeanAbsoluteError()\n    >>> 
m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])\n    >>> m.result()\n    0.25\n\n    >>> m.reset_state()\n    >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],\n    ...                sample_weight=[1, 0])\n    >>> m.result()\n    0.5\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.MeanAbsoluteError()])\n    ```\n    \"\"\"\n\n    def __init__(self, name=\"mean_absolute_error\", dtype=None):\n        super().__init__(mean_absolute_error, name, dtype=dtype)\n        # Metric should be minimized during optimization.\n        self._direction = \"down\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n@keras_export(\"keras.metrics.MeanAbsolutePercentageError\")\nclass MeanAbsolutePercentageError(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes mean absolute percentage error between `y_true` and `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = 100 * mean(abs((y_true - y_pred) / y_true))\n    ```\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Examples:\n    >>> m = keras.metrics.MeanAbsolutePercentageError()\n    >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])\n    >>> m.result()\n    250000000.0\n\n    >>> m.reset_state()\n    >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],\n    ...                
sample_weight=[1, 0])\n    >>> m.result()\n    500000000.0\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.MeanAbsolutePercentageError()])\n    ```\n    \"\"\"\n\n    def __init__(self, name=\"mean_absolute_percentage_error\", dtype=None):\n        super().__init__(mean_absolute_percentage_error, name, dtype=dtype)\n        # Metric should be minimized during optimization.\n        self._direction = \"down\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n@keras_export(\"keras.metrics.MeanSquaredLogarithmicError\")\nclass MeanSquaredLogarithmicError(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes mean squared logarithmic error between `y_true` and `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = mean(square(log(y_true + 1) - log(y_pred + 1)))\n    ```\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Examples:\n\n    >>> m = keras.metrics.MeanSquaredLogarithmicError()\n    >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])\n    >>> m.result()\n    0.12011322\n\n    >>> m.reset_state()\n    >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],\n    ...                
sample_weight=[1, 0])\n    >>> m.result()\n    0.24022643\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.MeanSquaredLogarithmicError()])\n    ```\n    \"\"\"\n\n    def __init__(self, name=\"mean_squared_logarithmic_error\", dtype=None):\n        super().__init__(mean_squared_logarithmic_error, name, dtype=dtype)\n        # Metric should be minimized during optimization.\n        self._direction = \"down\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n@keras_export(\"keras.metrics.RootMeanSquaredError\")\nclass RootMeanSquaredError(reduction_metrics.Mean):\n    \"\"\"Computes root mean squared error metric between `y_true` and `y_pred`.\n\n    Formula:\n\n    ```python\n    loss = sqrt(mean((y_pred - y_true) ** 2))\n    ```\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Examples:\n\n    >>> m = keras.metrics.RootMeanSquaredError()\n    >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])\n    >>> m.result()\n    0.5\n\n    >>> m.reset_state()\n    >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],\n    ...                
sample_weight=[1, 0])\n    >>> m.result()\n    0.70710677\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.RootMeanSquaredError()])\n    ```\n    \"\"\"\n\n    def __init__(self, name=\"root_mean_squared_error\", dtype=None):\n        super().__init__(name, dtype=dtype)\n        # Metric should be minimized during optimization.\n        self._direction = \"down\"\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        \"\"\"Accumulates root mean squared error statistics.\n\n        Args:\n            y_true: The ground truth values.\n            y_pred: The predicted values.\n            sample_weight: Optional weighting of each example. Can\n                be a `Tensor` whose rank is either 0, or the same rank as\n                `y_true`, and must be broadcastable to `y_true`.\n                Defaults to `1`.\n\n        Returns:\n            Update op.\n        \"\"\"\n        y_true = ops.convert_to_tensor(y_true, self._dtype)\n        y_pred = ops.convert_to_tensor(y_pred, self._dtype)\n        y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n        error_sq = ops.square(y_pred - y_true)\n        return super().update_state(error_sq, sample_weight=sample_weight)\n\n    def result(self):\n        return ops.sqrt(super().result())\n\n\n@keras_export(\"keras.metrics.CosineSimilarity\")\nclass CosineSimilarity(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes the cosine similarity between the labels and predictions.\n\n    Formula:\n\n    ```python\n    loss = sum(l2_norm(y_true) * l2_norm(y_pred))\n    ```\n    See: [Cosine Similarity](https://en.wikipedia.org/wiki/Cosine_similarity).\n    This metric keeps the average cosine similarity between `predictions` and\n    `labels` over a stream of data.\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the 
metric result.\n        axis: (Optional) Defaults to `-1`. The dimension along which the cosine\n            similarity is computed.\n\n    Examples:\n\n    >>> # l2_norm(y_true) = [[0., 1.], [1./1.414, 1./1.414]]\n    >>> # l2_norm(y_pred) = [[1., 0.], [1./1.414, 1./1.414]]\n    >>> # l2_norm(y_true) . l2_norm(y_pred) = [[0., 0.], [0.5, 0.5]]\n    >>> # result = mean(sum(l2_norm(y_true) . l2_norm(y_pred), axis=1))\n    >>> #        = ((0. + 0.) +  (0.5 + 0.5)) / 2\n    >>> m = keras.metrics.CosineSimilarity(axis=1)\n    >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]])\n    >>> m.result()\n    0.49999997\n\n    >>> m.reset_state()\n    >>> m.update_state([[0., 1.], [1., 1.]], [[1., 0.], [1., 1.]],\n    ...                sample_weight=[0.3, 0.7])\n    >>> m.result()\n    0.6999999\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(\n        optimizer='sgd',\n        loss='mse',\n        metrics=[keras.metrics.CosineSimilarity(axis=1)])\n    ```\n    \"\"\"\n\n    def __init__(self, name=\"cosine_similarity\", dtype=None, axis=-1):\n        super().__init__(cosine_similarity, name, dtype=dtype, axis=axis)\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n@keras_export(\"keras.metrics.LogCoshError\")\nclass LogCoshError(reduction_metrics.MeanMetricWrapper):\n    \"\"\"Computes the logarithm of the hyperbolic cosine of the prediction error.\n\n    Formula:\n\n    ```python\n    error = y_pred - y_true\n    logcosh = mean(log((exp(error) + exp(-error))/2), axis=-1)\n    ```\n\n    Args:\n        name: (Optional) string name of the metric instance.\n        dtype: (Optional) data type of the metric result.\n\n    Examples:\n\n    >>> m = keras.metrics.LogCoshError()\n    >>> m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]])\n    >>> m.result()\n    0.10844523\n\n    >>> m.reset_state()\n    >>> 
m.update_state([[0, 1], [0, 0]], [[1, 1], [0, 0]],\n    ...                sample_weight=[1, 0])\n    >>> m.result()\n    0.21689045\n\n    Usage with `compile()` API:\n\n    ```python\n    model.compile(optimizer='sgd',\n                  loss='mse',\n                  metrics=[keras.metrics.LogCoshError()])\n    ```\n    \"\"\"\n\n    def __init__(self, name=\"logcosh\", dtype=None):\n        super().__init__(log_cosh, name, dtype=dtype)\n        # Metric should be minimized during optimization.\n        self._direction = \"down\"\n\n    def get_config(self):\n        return {\"name\": self.name, \"dtype\": self.dtype}\n\n\n# Adapted from TF-Addons implementation (RSquare class).\n@keras_export(\"keras.metrics.R2Score\")\nclass R2Score(reduction_metrics.Metric):\n    \"\"\"Computes R2 score.\n\n    Formula:\n\n    ```python\n    sum_squares_residuals = sum((y_true - y_pred) ** 2)\n    sum_squares = sum((y_true - mean(y_true)) ** 2)\n    R2 = 1 - sum_squares_residuals / sum_squares\n    ```\n\n    This is also called the\n    [coefficient of determination](\n    https://en.wikipedia.org/wiki/Coefficient_of_determination).\n\n    It indicates how close the fitted regression line\n    is to ground-truth data.\n\n    - The highest score possible is 1.0. It indicates that the predictors\n        perfectly account for variation in the target.\n    - A score of 0.0 indicates that the predictors do not\n        account for variation in the target.\n    - It can also be negative if the model is worse than random.\n\n    This metric can also compute the \"Adjusted R2\" score.\n\n    Args:\n        class_aggregation: Specifies how to aggregate scores corresponding to\n            different output classes (or target dimensions),\n            i.e. 
different dimensions on the last axis of the predictions.\n            Equivalent to `multioutput` argument in Scikit-Learn.\n            Should be one of\n            `None` (no aggregation), `\"uniform_average\"`,\n            `\"variance_weighted_average\"`.\n        num_regressors: Number of independent regressors used\n            (\"Adjusted R2\" score). 0 is the standard R2 score.\n            Defaults to `0`.\n        name: Optional. string name of the metric instance.\n        dtype: Optional. data type of the metric result.\n\n    Example:\n\n    >>> y_true = np.array([[1], [4], [3]], dtype=np.float32)\n    >>> y_pred = np.array([[2], [4], [4]], dtype=np.float32)\n    >>> metric = keras.metrics.R2Score()\n    >>> metric.update_state(y_true, y_pred)\n    >>> result = metric.result()\n    >>> result\n    0.57142854\n    \"\"\"\n\n    def __init__(\n        self,\n        class_aggregation=\"uniform_average\",\n        num_regressors=0,\n        name=\"r2_score\",\n        dtype=None,\n    ):\n        super().__init__(name=name, dtype=dtype)\n        # Metric should be maximized during optimization.\n        self._direction = \"up\"\n\n        valid_class_aggregation_values = (\n            None,\n            \"uniform_average\",\n            \"variance_weighted_average\",\n        )\n        if class_aggregation not in valid_class_aggregation_values:\n            raise ValueError(\n                \"Invalid value for argument `class_aggregation`. Expected \"\n                f\"one of {valid_class_aggregation_values}. \"\n                f\"Received: class_aggregation={class_aggregation}\"\n            )\n        if num_regressors < 0:\n            raise ValueError(\n                \"Invalid value for argument `num_regressors`. \"\n                \"Expected a value >= 0. 
\"\n                f\"Received: num_regressors={num_regressors}\"\n            )\n        self.class_aggregation = class_aggregation\n        self.num_regressors = num_regressors\n        self.num_samples = self.add_variable(\n            shape=(),\n            initializer=initializers.Zeros(),\n            name=\"num_samples\",\n        )\n        self._built = False\n\n    def _build(self, y_true_shape, y_pred_shape):\n        if len(y_pred_shape) != 2 or len(y_true_shape) != 2:\n            raise ValueError(\n                \"R2Score expects 2D inputs with shape \"\n                \"(batch_size, output_dim). Received input \"\n                f\"shapes: y_pred.shape={y_pred_shape} and \"\n                f\"y_true.shape={y_true_shape}.\"\n            )\n        if y_pred_shape[-1] is None or y_true_shape[-1] is None:\n            raise ValueError(\n                \"R2Score expects 2D inputs with shape \"\n                \"(batch_size, output_dim), with output_dim fully \"\n                \"defined (not None). 
Received input \"\n                f\"shapes: y_pred.shape={y_pred_shape} and \"\n                f\"y_true.shape={y_true_shape}.\"\n            )\n        num_classes = y_pred_shape[-1]\n        self.squared_sum = self.add_variable(\n            name=\"squared_sum\",\n            shape=[num_classes],\n            initializer=initializers.Zeros(),\n        )\n        self.sum = self.add_variable(\n            name=\"sum\",\n            shape=[num_classes],\n            initializer=initializers.Zeros(),\n        )\n        self.total_mse = self.add_variable(\n            name=\"residual\",\n            shape=[num_classes],\n            initializer=initializers.Zeros(),\n        )\n        self.count = self.add_variable(\n            name=\"count\",\n            shape=[num_classes],\n            initializer=initializers.Zeros(),\n        )\n        self._built = True\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        \"\"\"Accumulates R2 score statistics.\n\n        Args:\n            y_true: The ground truth values.\n            y_pred: The predicted values.\n            sample_weight: Optional weighting of each example. 
Can\n                be a `Tensor` whose rank is either 0, or the same rank as\n                `y_true`, and must be broadcastable to `y_true`.\n                Defaults to `1`.\n\n        Returns:\n            Update op.\n        \"\"\"\n        y_true = ops.convert_to_tensor(y_true, dtype=self._dtype)\n        y_pred = ops.convert_to_tensor(y_pred, dtype=self._dtype)\n        y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n        if not self._built:\n            self._build(y_true.shape, y_pred.shape)\n\n        if sample_weight is None:\n            sample_weight = 1\n\n        sample_weight = ops.convert_to_tensor(sample_weight, dtype=self.dtype)\n\n        if len(sample_weight.shape) == 1:\n            # Make sure there's a features dimension\n            sample_weight = ops.expand_dims(sample_weight, axis=1)\n\n        sample_weight = ops.broadcast_to(sample_weight, ops.shape(y_true))\n\n        weighted_y_true = y_true * ops.cast(sample_weight, y_true.dtype)\n        self.sum.assign(self.sum + ops.sum(weighted_y_true, axis=0))\n        self.squared_sum.assign(\n            self.squared_sum + ops.sum(y_true * weighted_y_true, axis=0)\n        )\n        self.total_mse.assign(\n            self.total_mse\n            + ops.sum(\n                (y_true - y_pred) ** 2 * ops.cast(sample_weight, y_true.dtype),\n                axis=0,\n            )\n        )\n        self.count.assign(self.count + ops.sum(sample_weight, axis=0))\n        self.num_samples.assign(self.num_samples + ops.size(y_true))\n\n    def result(self):\n        mean = self.sum / self.count\n        total = self.squared_sum - self.sum * mean\n        raw_scores = 1 - (self.total_mse / total)\n        raw_scores = ops.where(ops.isinf(raw_scores), 0.0, raw_scores)\n\n        if self.class_aggregation == \"uniform_average\":\n            r2_score = ops.mean(raw_scores)\n        elif self.class_aggregation == \"variance_weighted_average\":\n            weighted_sum = 
ops.sum(total * raw_scores)\n            sum_of_weights = ops.sum(total)\n            r2_score = weighted_sum / sum_of_weights\n        else:\n            r2_score = raw_scores\n\n        if self.num_regressors != 0:\n            if self.num_regressors > self.num_samples - 1:\n                warnings.warn(\n                    \"More independent predictors than datapoints \"\n                    \"in adjusted R2 score. Falling back to standard R2 score.\",\n                    stacklevel=2,\n                )\n            elif self.num_regressors == self.num_samples - 1:\n                warnings.warn(\n                    \"Division by zero in Adjusted R2 score. \"\n                    \"Falling back to standard R2 score.\",\n                    stacklevel=2,\n                )\n            else:\n                n = ops.convert_to_tensor(self.num_samples, dtype=\"float32\")\n                p = ops.convert_to_tensor(self.num_regressors, dtype=\"float32\")\n                num = ops.multiply(\n                    ops.subtract(1.0, r2_score), ops.subtract(n, 1.0)\n                )\n                den = ops.subtract(ops.subtract(n, p), 1.0)\n                r2_score = ops.subtract(1.0, ops.divide(num, den))\n        return r2_score\n\n    def reset_state(self):\n        for v in self.variables:\n            v.assign(ops.zeros(v.shape, dtype=v.dtype))\n\n    def get_config(self):\n        config = {\n            \"name\": self.name,\n            \"dtype\": self.dtype,\n            \"class_aggregation\": self.class_aggregation,\n            \"num_regressors\": self.num_regressors,\n        }\n        base_config = super().get_config()\n        return {**base_config, **config}\n\n\ndef cosine_similarity(y_true, y_pred, axis=-1):\n    \"\"\"Computes the cosine similarity between labels and predictions.\n\n    Formula:\n\n    ```python\n    loss = sum(l2_norm(y_true) * l2_norm(y_pred))\n    ```\n\n    Args:\n        y_true: Tensor of true targets.\n        y_pred: 
Tensor of predicted targets.\n        axis: Axis along which to determine similarity. Defaults to `-1`.\n\n    Returns:\n        Cosine similarity tensor.\n\n    Example:\n\n    >>> y_true = [[0., 1.], [1., 1.], [1., 1.]]\n    >>> y_pred = [[1., 0.], [1., 1.], [-1., -1.]]\n    >>> loss = keras.losses.cosine_similarity(y_true, y_pred, axis=-1)\n    >>> loss\n    [0., 0.99999994, -0.99999994]\n    \"\"\"\n    y_pred = ops.convert_to_tensor(y_pred)\n    y_true = ops.convert_to_tensor(y_true, dtype=y_pred.dtype)\n    y_true, y_pred = squeeze_or_expand_to_same_rank(y_true, y_pred)\n    y_pred = normalize(y_pred, axis=axis)\n    y_true = normalize(y_true, axis=axis)\n    return ops.sum(y_true * y_pred, axis=axis)\n"
  },
  {
    "path": "keras/src/metrics/regression_metrics_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import testing\nfrom keras.src.metrics import regression_metrics as metrics\n\n\nclass MeanSquaredErrorTest(testing.TestCase):\n    def test_config(self):\n        # TODO\n        pass\n\n    def test_unweighted(self):\n        mse_obj = metrics.MeanSquaredError()\n        y_true = np.array(\n            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]\n        )\n        y_pred = np.array(\n            [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]\n        )\n\n        mse_obj.update_state(y_true, y_pred)\n        result = mse_obj.result()\n        self.assertAllClose(0.5, result, atol=1e-5)\n\n    def test_weighted(self):\n        mse_obj = metrics.MeanSquaredError()\n        y_true = np.array(\n            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]\n        )\n        y_pred = np.array(\n            [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]\n        )\n        sample_weight = np.array([1.0, 1.5, 2.0, 2.5])\n        result = mse_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(0.54285, result, atol=1e-5)\n\n\nclass CosineSimilarityTest(testing.TestCase):\n    def l2_norm(self, x, axis):\n        epsilon = 1e-12\n        square_sum = np.sum(np.square(x), axis=axis, keepdims=True)\n        x_inv_norm = 1 / np.sqrt(np.maximum(square_sum, epsilon))\n        return np.multiply(x, x_inv_norm)\n\n    def setup(self, axis=1):\n        self.np_y_true = np.asarray([[1, 9, 2], [-5, -2, 6]], dtype=np.float32)\n        self.np_y_pred = np.asarray([[4, 8, 12], [8, 1, 3]], dtype=np.float32)\n\n        y_true = self.l2_norm(self.np_y_true, axis)\n        y_pred = self.l2_norm(self.np_y_pred, axis)\n        self.expected_loss = np.sum(np.multiply(y_true, y_pred), axis=(axis,))\n\n        self.y_true = self.np_y_true\n        self.y_pred = self.np_y_pred\n\n    def 
test_config(self):\n        cosine_obj = metrics.CosineSimilarity(\n            axis=2, name=\"my_cos\", dtype=\"int32\"\n        )\n        self.assertEqual(cosine_obj.name, \"my_cos\")\n        self.assertEqual(cosine_obj.dtype, \"int32\")\n\n        # Check save and restore config\n        cosine_obj2 = metrics.CosineSimilarity.from_config(\n            cosine_obj.get_config()\n        )\n        self.assertEqual(cosine_obj2.name, \"my_cos\")\n        self.assertEqual(cosine_obj2._dtype, \"int32\")\n\n    def test_unweighted(self):\n        self.setup()\n        cosine_obj = metrics.CosineSimilarity()\n        loss = cosine_obj(self.y_true, self.y_pred)\n        expected_loss = np.mean(self.expected_loss)\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_weighted(self):\n        self.setup()\n        cosine_obj = metrics.CosineSimilarity()\n        sample_weight = np.asarray([1.2, 3.4])\n        loss = cosine_obj(self.y_true, self.y_pred, sample_weight=sample_weight)\n        expected_loss = np.sum(self.expected_loss * sample_weight) / np.sum(\n            sample_weight\n        )\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n    def test_axis(self):\n        self.setup(axis=1)\n        cosine_obj = metrics.CosineSimilarity(axis=1)\n        loss = cosine_obj(self.y_true, self.y_pred)\n        expected_loss = np.mean(self.expected_loss)\n        self.assertAlmostEqual(loss, expected_loss, 3)\n\n\nclass MeanAbsoluteErrorTest(testing.TestCase):\n    def test_config(self):\n        mae_obj = metrics.MeanAbsoluteError(name=\"my_mae\", dtype=\"int32\")\n        self.assertEqual(mae_obj.name, \"my_mae\")\n        self.assertEqual(mae_obj._dtype, \"int32\")\n\n        # Check save and restore config\n        mae_obj2 = metrics.MeanAbsoluteError.from_config(mae_obj.get_config())\n        self.assertEqual(mae_obj2.name, \"my_mae\")\n        self.assertEqual(mae_obj2._dtype, \"int32\")\n\n    def test_unweighted(self):\n        mae_obj = 
metrics.MeanAbsoluteError()\n        y_true = np.array(\n            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]\n        )\n        y_pred = np.array(\n            [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]\n        )\n\n        mae_obj.update_state(y_true, y_pred)\n        result = mae_obj.result()\n        self.assertAllClose(0.5, result, atol=1e-5)\n\n    def test_weighted(self):\n        mae_obj = metrics.MeanAbsoluteError()\n        y_true = np.array(\n            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]\n        )\n        y_pred = np.array(\n            [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]\n        )\n        sample_weight = np.array([1.0, 1.5, 2.0, 2.5])\n        result = mae_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(0.54285, result, atol=1e-5)\n\n\nclass MeanAbsolutePercentageErrorTest(testing.TestCase):\n    def test_config(self):\n        mape_obj = metrics.MeanAbsolutePercentageError(\n            name=\"my_mape\", dtype=\"int32\"\n        )\n        self.assertEqual(mape_obj.name, \"my_mape\")\n        self.assertEqual(mape_obj._dtype, \"int32\")\n\n        # Check save and restore config\n        mape_obj2 = metrics.MeanAbsolutePercentageError.from_config(\n            mape_obj.get_config()\n        )\n        self.assertEqual(mape_obj2.name, \"my_mape\")\n        self.assertEqual(mape_obj2._dtype, \"int32\")\n\n    def test_unweighted(self):\n        mape_obj = metrics.MeanAbsolutePercentageError()\n        y_true = np.array(\n            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]\n        )\n        y_pred = np.array(\n            [\n                [0, 0, 1, 1, 0],\n                [1, 1, 1, 1, 1],\n                [0, 1, 0, 1, 0],\n                [1, 1, 1, 1, 1],\n            ],\n            dtype=\"float32\",\n        )\n\n        result = mape_obj(y_true, y_pred)\n        
self.assertAllClose(35e7, result, atol=1e-5)\n\n    def test_weighted(self):\n        mape_obj = metrics.MeanAbsolutePercentageError()\n        y_true = np.array(\n            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]\n        )\n        y_pred = np.array(\n            [\n                [0, 0, 1, 1, 0],\n                [1, 1, 1, 1, 1],\n                [0, 1, 0, 1, 0],\n                [1, 1, 1, 1, 1],\n            ],\n            dtype=\"float32\",\n        )\n\n        sample_weight = np.array([1.0, 1.5, 2.0, 2.5])\n        result = mape_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(40e7, result, atol=1e-5)\n\n\nclass MeanSquaredLogarithmicErrorTest(testing.TestCase):\n    def test_config(self):\n        msle_obj = metrics.MeanSquaredLogarithmicError(\n            name=\"my_msle\", dtype=\"int32\"\n        )\n        self.assertEqual(msle_obj.name, \"my_msle\")\n        self.assertEqual(msle_obj._dtype, \"int32\")\n\n        # Check save and restore config\n        msle_obj2 = metrics.MeanSquaredLogarithmicError.from_config(\n            msle_obj.get_config()\n        )\n        self.assertEqual(msle_obj2.name, \"my_msle\")\n        self.assertEqual(msle_obj2._dtype, \"int32\")\n\n    def test_unweighted(self):\n        msle_obj = metrics.MeanSquaredLogarithmicError()\n        y_true = np.array(\n            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]\n        )\n        y_pred = np.array(\n            [[0, 0, 1, 1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]\n        )\n\n        msle_obj.update_state(y_true, y_pred)\n        result = msle_obj.result()\n        self.assertAllClose(0.24022, result, atol=1e-5)\n\n    def test_weighted(self):\n        msle_obj = metrics.MeanSquaredLogarithmicError()\n        y_true = np.array(\n            [[0, 1, 0, 1, 0], [0, 0, 1, 1, 1], [1, 1, 1, 1, 0], [0, 0, 0, 0, 1]]\n        )\n        y_pred = np.array(\n            [[0, 0, 1, 
1, 0], [1, 1, 1, 1, 1], [0, 1, 0, 1, 0], [1, 1, 1, 1, 1]]\n        )\n        sample_weight = np.array([1.0, 1.5, 2.0, 2.5])\n        result = msle_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(0.26082, result, atol=1e-5)\n\n\nclass RootMeanSquaredErrorTest(testing.TestCase):\n    def test_config(self):\n        rmse_obj = metrics.RootMeanSquaredError(name=\"rmse\", dtype=\"int32\")\n        self.assertEqual(rmse_obj.name, \"rmse\")\n        self.assertEqual(rmse_obj._dtype, \"int32\")\n\n        rmse_obj2 = metrics.RootMeanSquaredError.from_config(\n            rmse_obj.get_config()\n        )\n        self.assertEqual(rmse_obj2.name, \"rmse\")\n        self.assertEqual(rmse_obj2._dtype, \"int32\")\n\n    def test_unweighted(self):\n        rmse_obj = metrics.RootMeanSquaredError()\n        y_true = np.array([2, 4, 6])\n        y_pred = np.array([1, 3, 2])\n\n        rmse_obj.update_state(y_true, y_pred)\n        result = rmse_obj.result()\n        # error = [-1, -1, -4], square(error) = [1, 1, 16], mean = 18/3 = 6\n        self.assertAllClose(np.sqrt(6), result, atol=1e-3)\n\n    def test_weighted(self):\n        rmse_obj = metrics.RootMeanSquaredError()\n        y_true = np.array([2, 4, 6, 8])\n        y_pred = np.array([1, 3, 2, 3])\n        sample_weight = np.array([0, 1, 0, 1])\n        result = rmse_obj(y_true, y_pred, sample_weight=sample_weight)\n        self.assertAllClose(np.sqrt(13), result, atol=1e-3)\n\n\nclass LogCoshErrorTest(testing.TestCase):\n    def setup(self):\n        y_true = np.asarray([[1, 9, 2], [-5, -2, 6]], dtype=np.float32)\n        y_pred = np.asarray([[4, 8, 12], [8, 1, 3]], dtype=np.float32)\n\n        self.batch_size = 6\n        error = y_pred - y_true\n        self.expected_results = np.log((np.exp(error) + np.exp(-error)) / 2)\n\n        self.y_pred = y_pred\n        self.y_true = y_true\n\n    def test_config(self):\n     
   logcosh_obj = metrics.LogCoshError(name=\"logcosh\", dtype=\"int32\")\n        self.assertEqual(logcosh_obj.name, \"logcosh\")\n        self.assertEqual(logcosh_obj._dtype, \"int32\")\n\n    def test_unweighted(self):\n        self.setup()\n        logcosh_obj = metrics.LogCoshError()\n\n        logcosh_obj.update_state(self.y_true, self.y_pred)\n        result = logcosh_obj.result()\n        expected_result = np.sum(self.expected_results) / self.batch_size\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n    def test_weighted(self):\n        self.setup()\n        logcosh_obj = metrics.LogCoshError(dtype=\"float32\")\n        sample_weight = np.array([[1.2], [3.4]])\n        result = logcosh_obj(\n            self.y_true, self.y_pred, sample_weight=sample_weight\n        )\n\n        sample_weight = np.asarray([1.2, 1.2, 1.2, 3.4, 3.4, 3.4]).reshape(\n            (2, 3)\n        )\n        expected_result = np.multiply(self.expected_results, sample_weight)\n        expected_result = np.sum(expected_result) / np.sum(sample_weight)\n        self.assertAllClose(result, expected_result, atol=1e-3)\n\n\nclass R2ScoreTest(testing.TestCase):\n    def _run_test(\n        self,\n        y_true,\n        y_pred,\n        sample_weights,\n        class_aggregation,\n        num_regressors,\n        reference_result,\n    ):\n        r2 = metrics.R2Score(class_aggregation, num_regressors, dtype=\"float32\")\n        r2.update_state(y_true, y_pred, sample_weights)\n        result = r2.result()\n        self.assertAllClose(result, reference_result, atol=1e-6)\n\n    def test_config(self):\n        r2_obj = metrics.R2Score(\n            class_aggregation=None, num_regressors=2, dtype=\"float32\"\n        )\n        self.assertEqual(r2_obj.class_aggregation, None)\n        self.assertEqual(r2_obj.num_regressors, 2)\n        self.assertEqual(r2_obj.dtype, \"float32\")\n\n        # Check save and restore config\n        r2_obj2 = 
metrics.R2Score.from_config(r2_obj.get_config())\n        self.assertEqual(r2_obj2.class_aggregation, None)\n        self.assertEqual(r2_obj2.num_regressors, 2)\n        self.assertEqual(r2_obj2.dtype, \"float32\")\n\n    @parameterized.parameters(\n        # class_aggregation, num_regressors, result\n        (None, 0, [0.37, -1.295, 0.565]),\n        (\"uniform_average\", 0, -0.12),\n        (\"variance_weighted_average\", 0, -0.12),\n    )\n    def test_r2_sklearn_comparison(\n        self, class_aggregation, num_regressors, result\n    ):\n        y_true = [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]\n        y_pred = [[0.4, 0.5, 0.6], [0.1, 0.2, 0.3], [0.5, 0.8, 0.2]]\n        self._run_test(\n            y_true,\n            y_pred,\n            None,\n            class_aggregation=class_aggregation,\n            num_regressors=num_regressors,\n            reference_result=result,\n        )\n\n    @parameterized.parameters(\n        # class_aggregation, num_regressors, result\n        (None, 0, [0.17305559, -8.836666, -0.521]),\n        (None, 1, [0.054920673, -10.241904, -0.7382858]),\n        (None, 2, [-0.10259259, -12.115555, -1.0280001]),\n        (\"uniform_average\", 0, -3.0615367889404297),\n        (\"uniform_average\", 1, -3.641756534576416),\n        (\"uniform_average\", 2, -4.415382385253906),\n        (\"variance_weighted_average\", 0, -1.3710224628448486),\n        (\"variance_weighted_average\", 1, -1.7097399234771729),\n        (\"variance_weighted_average\", 2, -2.161363363265991),\n    )\n    def test_r2_tfa_comparison(self, class_aggregation, num_regressors, result):\n        y_true = [[0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]\n        y_pred = [[0.4, 0.9, 1.6], [0.1, 1.2, 0.6], [1.5, 0.8, 0.6]]\n        sample_weights = [0.8, 0.1, 0.4]\n        self._run_test(\n            y_true,\n            y_pred,\n            sample_weights,\n            class_aggregation=class_aggregation,\n            
num_regressors=num_regressors,\n            reference_result=result,\n        )\n\n    def test_errors(self):\n        # Bad class_aggregation value\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid value for argument `class_aggregation`\"\n        ):\n            metrics.R2Score(class_aggregation=\"wrong\")\n\n        # Bad num_regressors value\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid value for argument `num_regressors`\"\n        ):\n            metrics.R2Score(num_regressors=-1)\n\n        # Bad input shape\n        with self.assertRaisesRegex(ValueError, \"expects 2D inputs with shape\"):\n            r2 = metrics.R2Score()\n            r2.update_state([0.0, 1.0], [0.0, 1.0])\n"
  },
  {
    "path": "keras/src/models/__init__.py",
    "content": "from keras.src.models.functional import Functional\nfrom keras.src.models.model import Model\nfrom keras.src.models.sequential import Sequential\n"
  },
  {
    "path": "keras/src/models/cloning.py",
    "content": "from keras.src import backend\nfrom keras.src import tree\nfrom keras.src import utils\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers import Input\nfrom keras.src.layers import InputLayer\nfrom keras.src.models.functional import Functional\nfrom keras.src.models.functional import functional_like_constructor\nfrom keras.src.models.sequential import Sequential\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.models.clone_model\")\ndef clone_model(\n    model,\n    input_tensors=None,\n    clone_function=None,\n    call_function=None,\n    recursive=False,\n    **kwargs,\n):\n    \"\"\"Clone a Functional or Sequential `Model` instance.\n\n    Model cloning is similar to calling a model on new inputs,\n    except that it creates new layers (and thus new weights) instead\n    of sharing the weights of the existing layers.\n\n    Note that\n    `clone_model` will not preserve the uniqueness of shared objects within the\n    model (e.g. a single variable attached to two distinct layers will be\n    restored as two separate variables).\n\n    Args:\n        model: Instance of `Model`\n            (could be a Functional model or a Sequential model).\n        input_tensors: optional list of input tensors or InputLayer objects\n            to build the model upon. If not provided,\n            new `Input` objects will be created.\n        clone_function: Callable with signature `fn(layer)`\n            to be used to clone each layer in the target\n            model (except `Input` instances). It takes as argument the\n            layer instance to be cloned, and returns the corresponding layer\n            instance to be used in the model copy. 
If unspecified, this callable\n            defaults to the following serialization/deserialization function:\n            `lambda layer: layer.__class__.from_config(layer.get_config())`.\n            By passing a custom callable, you can customize your copy of the\n            model, e.g. by wrapping certain layers of interest (you might want\n            to replace all `LSTM` instances with equivalent\n            `Bidirectional(LSTM(...))` instances, for example).\n            Defaults to `None`.\n        call_function: Callable with signature\n            `fn(layer, *args, **kwargs)` to be used to call each\n            cloned layer and a set of inputs. It takes the layer instance,\n            the call arguments and keyword arguments, and returns the\n            call outputs. If unspecified, this callable defaults to\n            the regular `__call__()` method:\n            `def fn(layer, *args, **kwargs): return layer(*args, **kwargs)`.\n            By passing a custom callable, you can insert new layers before or\n            after a given layer. Note: this argument can only be used with\n            Functional models.\n        recursive: Boolean. Whether to recursively clone any Sequential\n            or Functional models encountered in the original\n            Sequential/Functional model. If `False`,\n            then inner models are cloned by calling `clone_function()`.\n            If `True`, then inner models are cloned by calling `clone_model()`\n            with the same `clone_function`, `call_function`, and `recursive`\n            arguments. Note that in this case, `call_function`\n            will not be propagated to any Sequential model\n            (since it is not applicable to Sequential models).\n\n    Returns:\n        An instance of `Model` reproducing the behavior\n        of the original model, on top of new inputs tensors,\n        using newly instantiated weights. 
The cloned model may behave\n        differently from the original model if a custom `clone_function`\n        or `call_function` modifies a layer or layer call.\n\n    Example:\n\n    ```python\n    # Create a test Sequential model.\n    model = keras.Sequential([\n        keras.layers.Input(shape=(728,)),\n        keras.layers.Dense(32, activation='relu'),\n        keras.layers.Dense(1, activation='sigmoid'),\n    ])\n    # Create a copy of the test model (with freshly initialized weights).\n    new_model = clone_model(model)\n    ```\n\n    Using a `clone_function` to make a model deterministic by setting the\n    random seed everywhere:\n\n    ```python\n    def clone_function(layer):\n        config = layer.get_config()\n        if \"seed\" in config:\n            config[\"seed\"] = 1337\n        return layer.__class__.from_config(config)\n\n    new_model = clone_model(model, clone_function=clone_function)\n    ```\n\n    Using a `call_function` to add a `Dropout` layer after each `Dense` layer\n    (without recreating new layers):\n\n    ```python\n    def call_function(layer, *args, **kwargs):\n        out = layer(*args, **kwargs)\n        if isinstance(layer, keras.layers.Dense):\n            out = keras.layers.Dropout(0.5)(out)\n        return out\n\n    new_model = clone_model(\n        model,\n        clone_function=lambda x: x,  # Reuse the same layers.\n        call_function=call_function,\n    )\n    ```\n\n    Note that subclassed models cannot be cloned by default,\n    since their internal layer structure is not known.\n    To achieve functionality equivalent to `clone_model`\n    in the case of a subclassed model, simply make sure\n    that the model class implements `get_config()`\n    (and optionally `from_config()`), and call:\n\n    ```python\n    new_model = model.__class__.from_config(model.get_config())\n    ```\n\n    In the case of a subclassed model, you cannot use a custom\n    `clone_function`.\n    \"\"\"\n    cache = 
kwargs.pop(\"cache\", None)\n    if kwargs:\n        raise ValueError(\n            f\"Unexpected keyword argument(s): {tuple(kwargs.keys())}\"\n        )\n\n    if isinstance(model, Sequential):\n        # Wrap clone_function to handle recursiveness and layer sharing.\n        clone_function = _wrap_clone_function(\n            clone_function,\n            call_function=call_function,\n            recursive=recursive,\n            cache=cache,\n        )\n        if call_function is not None:\n            raise ValueError(\n                \"`call_function` argument is not supported with Sequential \"\n                \"models. In a Sequential model, layers aren't called \"\n                \"at model-construction time (they're merely listed). \"\n                \"Use `call_function` with Functional models only. \"\n                \"Received model of \"\n                f\"type '{model.__class__.__name__}', with \"\n                f\"call_function={call_function}\"\n            )\n        return _clone_sequential_model(\n            model,\n            clone_function=clone_function,\n            input_tensors=input_tensors,\n        )\n    if isinstance(model, Functional):\n        # Wrap clone_function to handle recursiveness and layer sharing.\n        clone_function = _wrap_clone_function(\n            clone_function,\n            call_function=call_function,\n            recursive=recursive,\n            cache=cache,\n        )\n\n        # If the get_config() method is the same as a regular Functional\n        # model, we're safe to use _clone_functional_model (which relies\n        # on a Functional constructor). 
In the case where the get_config\n        # is custom, this may not necessarily work, but if clone_function\n        # or input_tensors are passed, we attempt it anyway\n        # in order to preserve backwards compatibility.\n        if utils.is_default(model.get_config) or (\n            clone_function or input_tensors\n        ):\n            return _clone_functional_model(\n                model,\n                clone_function=clone_function,\n                call_function=call_function,\n                input_tensors=input_tensors,\n            )\n\n    # Case of a custom model class\n    if clone_function or input_tensors:\n        raise ValueError(\n            \"Arguments `clone_function` and `input_tensors` \"\n            \"are only supported for Sequential models \"\n            \"or Functional models. Received model of \"\n            f\"type '{model.__class__.__name__}', with \"\n            f\"clone_function={clone_function} and \"\n            f\"input_tensors={input_tensors}\"\n        )\n    if call_function is not None:\n        raise ValueError(\n            \"Argument `call_function` is only supported \"\n            \"for Functional models. 
Received model of \"\n            f\"type '{model.__class__.__name__}', with \"\n            f\"call_function={call_function}\"\n        )\n    config = serialization_lib.serialize_keras_object(model)\n    return serialization_lib.deserialize_keras_object(\n        config, custom_objects={model.__class__.__name__: model.__class__}\n    )\n\n\ndef _wrap_clone_function(\n    clone_function, call_function=None, recursive=False, cache=None\n):\n    \"\"\"Wrapper to handle recursiveness and layer sharing.\"\"\"\n    if clone_function is None:\n\n        def _clone_layer(layer):\n            return layer.__class__.from_config(layer.get_config())\n\n        clone_function = _clone_layer\n\n    if cache is None:\n        cache = {}\n\n    def wrapped_clone_function(layer):\n        if id(layer) in cache:\n            return cache[id(layer)]\n        if recursive:\n            if isinstance(layer, Sequential):\n                # Note: Sequential doesn't support call_function.\n                clone = clone_model(\n                    layer,\n                    clone_function=clone_function,\n                    recursive=True,\n                    cache=cache,\n                )\n                cache[id(layer)] = clone\n                return clone\n            elif isinstance(layer, Functional):\n                clone = clone_model(\n                    layer,\n                    clone_function=clone_function,\n                    call_function=call_function,\n                    recursive=True,\n                    cache=cache,\n                )\n                cache[id(layer)] = clone\n                return clone\n        clone = clone_function(layer)\n        cache[id(layer)] = clone\n        return clone\n\n    return wrapped_clone_function\n\n\ndef _clone_sequential_model(model, clone_function, input_tensors=None):\n    \"\"\"Clone a `Sequential` model instance.\n\n    Model cloning is similar to calling a model on new inputs,\n    except that it creates new 
layers (and thus new weights) instead\n    of sharing the weights of the existing layers.\n\n    Args:\n        model: Instance of `Sequential`.\n        input_tensors: optional list of input tensors\n            to build the model upon. If not provided,\n            placeholders will be created.\n        clone_function: callable to be applied on non-input layers in the model.\n            By default, it clones the layer (without copying the weights).\n\n    Returns:\n        An instance of `Sequential` reproducing the behavior\n        of the original model, on top of new inputs tensors,\n        using newly instantiated weights.\n    \"\"\"\n\n    if not isinstance(model, Sequential):\n        raise ValueError(\n            \"Expected `model` argument \"\n            \"to be a `Sequential` model instance. \"\n            f\"Received: model={model}\"\n        )\n\n    if not callable(clone_function):\n        raise ValueError(\n            \"Expected `clone_function` argument to be a callable. 
\"\n            f\"Received: clone_function={clone_function}\"\n        )\n\n    new_layers = [clone_function(layer) for layer in model.layers]\n\n    if isinstance(model._layers[0], InputLayer):\n        ref_input_layer = model._layers[0]\n        input_name = ref_input_layer.name\n        input_batch_shape = ref_input_layer.batch_shape\n        input_dtype = ref_input_layer._dtype\n        input_optional = ref_input_layer.optional\n    else:\n        input_name = None\n        input_dtype = None\n        input_batch_shape = None\n        input_optional = False\n\n    if input_tensors is not None:\n        if isinstance(input_tensors, (list, tuple)):\n            if len(input_tensors) != 1:\n                raise ValueError(\n                    \"Argument `input_tensors` must contain a single tensor.\"\n                )\n            input_tensors = input_tensors[0]\n        if not isinstance(input_tensors, backend.KerasTensor):\n            raise ValueError(\n                \"Argument `input_tensors` must be a KerasTensor. 
\"\n                f\"Received invalid value: input_tensors={input_tensors}\"\n            )\n        inputs = Input(\n            tensor=input_tensors,\n            name=input_name,\n            optional=input_optional,\n        )\n        new_layers = [inputs] + new_layers\n    else:\n        if input_batch_shape is not None:\n            inputs = Input(\n                batch_shape=input_batch_shape,\n                dtype=input_dtype,\n                name=input_name,\n                optional=input_optional,\n            )\n            new_layers = [inputs] + new_layers\n    cloned_model = Sequential(\n        new_layers, name=model.name, trainable=model.trainable\n    )\n\n    # If model compiled already then set same to cloned model\n    if model.compiled:\n        compiled_config = model.get_compile_config()\n        cloned_model.compile_from_config(compiled_config)\n    return cloned_model\n\n\ndef _clone_functional_model(\n    model, clone_function, input_tensors=None, call_function=None\n):\n    \"\"\"Clone a `Functional` model instance.\n\n    Model cloning is similar to calling a model on new inputs,\n    except that it creates new layers (and thus new weights) instead\n    of sharing the weights of the existing layers.\n\n    Input layers are always cloned.\n\n    Args:\n        model: Instance of `Functional`.\n        input_tensors: optional list of input tensors\n            to build the model upon. If not provided,\n            placeholders will be created.\n        clone_function: callable to be applied on non-input layers in the model.\n            By default, it clones the layer (without copying the weights).\n\n    Returns:\n        An instance of `Functional` reproducing the behavior\n        of the original model, on top of new inputs tensors,\n        using newly instantiated weights.\n    \"\"\"\n\n    if not callable(clone_function):\n        raise ValueError(\n            \"Expected `clone_function` argument to be a callable. 
\"\n            f\"Received: clone_function={clone_function}\"\n        )\n\n    if not isinstance(model, Functional):\n        raise ValueError(\n            \"Expected `model` argument \"\n            f\"to be a Functional Model instance. Received: model={model}\"\n        )\n\n    if input_tensors is not None:\n        if not all(\n            isinstance(x, backend.KerasTensor)\n            for x in tree.flatten(input_tensors)\n        ):\n            raise ValueError(\n                \"All entries in `input_tensors` must be KerasTensors. \"\n                f\"Received invalid values: inputs_tensors={input_tensors}\"\n            )\n        try:\n            tree.assert_same_structure(input_tensors, model.input)\n        except ValueError as e:\n            raise ValueError(\n                \"`input_tensors` must have the same structure as model.input\"\n                f\"\\nReference structure: {model.input}\"\n                f\"\\nReceived structure: {input_tensors}\"\n            ) from e\n    else:\n        input_tensors = tree.map_structure(\n            lambda x: Input(batch_shape=x.shape, dtype=x.dtype, name=x.name),\n            model.input,\n        )\n\n    def operation_fn(layer):\n        new_layer = clone_function(layer)\n        return new_layer\n\n    output_tensors = model._run_through_graph(\n        input_tensors,\n        operation_fn=operation_fn,\n        call_fn=call_function,\n    )\n\n    if functional_like_constructor(model.__class__):\n        new_model = model.__class__(\n            input_tensors, output_tensors, name=model.name\n        )\n    else:\n        # This may be incorrect: the new model will end up having a different\n        # class than the original. 
However, various existing models rely\n        # on this behavior, so we keep it.\n        new_model = Functional(input_tensors, output_tensors, name=model.name)\n    if model.compiled:\n        compiled_config = model.get_compile_config()\n        new_model.compile_from_config(compiled_config)\n    return new_model\n"
  },
  {
    "path": "keras/src/models/cloning_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src import tree\nfrom keras.src.models.cloning import clone_model\n\n\ndef get_mlp_functional_model(shared_layers=False):\n    inputs = layers.Input(shape=(3,))\n    x = layers.Dense(2)(inputs)\n    if shared_layers:\n        layer = layers.Dense(2, name=\"shared\")\n        x = layer(x)\n        x = layer(x)\n    outputs = layers.Dense(2)(x)\n    model = models.Model(inputs, outputs)\n    return model\n\n\ndef get_nested_functional_model():\n    inputs = layers.Input(shape=(4,))\n    x = layers.Dense(3)(inputs)\n    mlp = get_mlp_functional_model()\n    x = mlp(x)\n    outputs = layers.Dense(2)(x)\n    model = models.Model(inputs, outputs)\n    return model\n\n\ndef get_nested_sequential_model():\n    model = models.Sequential()\n    model.add(layers.Dense(2))\n    model.add(get_sequential_model(explicit_input=False))\n    model.add(layers.Dense(2))\n    return model\n\n\ndef get_doubly_nested_functional_model():\n    \"\"\"Outer -> middle (Functional) -> inner (Functional), 3 nesting levels.\"\"\"\n    inputs = layers.Input(shape=(5,))\n    x = layers.Dense(4)(inputs)\n    middle = get_nested_functional_model()  # already contains an inner mlp\n    x = middle(x)\n    outputs = layers.Dense(2)(x)\n    return models.Model(inputs, outputs)\n\n\ndef get_doubly_nested_sequential_model():\n    \"\"\"Outer -> middle (Sequential) -> inner (Sequential), 3 nesting levels.\"\"\"\n    model = models.Sequential()\n    model.add(layers.Dense(3))\n    model.add(get_nested_sequential_model())\n    model.add(layers.Dense(2))\n    return model\n\n\ndef get_cnn_functional_model(shared_layers=False):\n    inputs = layers.Input(shape=(7, 3))\n    x = layers.Conv1D(2, 2, padding=\"same\")(inputs)\n    if shared_layers:\n        layer = layers.Conv1D(2, 2, 
padding=\"same\", name=\"shared\")\n        x = layer(x)\n        x = layer(x)\n    outputs = layers.Conv1D(2, 2, padding=\"same\")(x)\n    model = models.Model(inputs, outputs)\n    return model\n\n\ndef get_sequential_model(explicit_input=True):\n    model = models.Sequential()\n    if explicit_input:\n        model.add(layers.Input(shape=(3,)))\n    model.add(layers.Dense(2))\n    model.add(layers.Dense(2))\n    return model\n\n\ndef get_cnn_sequential_model(explicit_input=True):\n    model = models.Sequential()\n    if explicit_input:\n        model.add(layers.Input(shape=(7, 3)))\n    model.add(layers.Conv1D(2, 2, padding=\"same\"))\n    model.add(layers.Conv1D(2, 2, padding=\"same\"))\n    return model\n\n\ndef get_subclassed_model():\n    class ExampleModel(models.Model):\n        def __init__(self, **kwargs):\n            super().__init__(**kwargs)\n            self.d1 = layers.Dense(2)\n            self.d2 = layers.Dense(2)\n\n        def call(self, x):\n            return self.d2(self.d1(x))\n\n    return ExampleModel()\n\n\n@pytest.mark.requires_trainable_backend\nclass CloneModelTest(testing.TestCase):\n    def assert_models_equal(self, model1, model2, ref_input):\n        result1 = model1(ref_input)\n        result2 = model2(ref_input)\n        for r1, r2 in zip(tree.flatten(result1), tree.flatten(result2)):\n            self.assertAllClose(\n                ops.convert_to_numpy(r1), ops.convert_to_numpy(r2)\n            )\n\n    def assert_weights_equal(self, model1, model2):\n        for a, b in zip(model1.weights, model2.weights):\n            self.assertAllClose(a.numpy(), b.numpy())\n\n    @parameterized.named_parameters(\n        (\"mlp_functional\", get_mlp_functional_model),\n        (\"cnn_functional\", get_cnn_functional_model, True),\n        (\"sequential\", get_sequential_model),\n        (\n            \"deferred_sequential\",\n            lambda: get_sequential_model(explicit_input=False),\n        ),\n        (\"subclassed\", 
get_subclassed_model),\n    )\n    def test_cloning_correctness(self, model_fn, is_conv=False):\n        ref_input = np.random.random((2, 7, 3) if is_conv else (2, 3))\n        model = model_fn()\n        new_model = clone_model(model)\n        model(ref_input)  # Maybe needed to build the model\n        new_model(ref_input)  # Maybe needed to build the model\n        new_model.set_weights(model.get_weights())\n        self.assert_models_equal(model, new_model, ref_input)\n\n    @parameterized.named_parameters(\n        (\"mlp_functional\", get_mlp_functional_model),\n        (\"cnn_functional\", get_cnn_functional_model),\n        (\"sequential\", get_sequential_model),\n    )\n    def test_custom_clone_function(self, model_fn):\n        def clone_function(layer):\n            config = layer.get_config()\n            config[\"name\"] = f\"{config['name']}_custom\"\n            return layer.__class__.from_config(config)\n\n        model = model_fn()\n        new_model = clone_model(model, clone_function=clone_function)\n        for l1, l2 in zip(model.layers, new_model.layers):\n            if not isinstance(l1, layers.InputLayer):\n                self.assertEqual(l2.name, f\"{l1.name}_custom\")\n\n    @parameterized.named_parameters(\n        (\"cnn_functional\", get_cnn_functional_model),\n        (\"cnn_sequential\", get_cnn_sequential_model),\n        (\n            \"cnn_sequential_noinputlayer\",\n            lambda: get_cnn_sequential_model(explicit_input=False),\n        ),\n    )\n    def test_input_tensors(self, model_fn):\n        ref_input = np.random.random((2, 7, 3))\n        model = model_fn()\n        model(ref_input)  # Maybe needed to get model inputs if no Input layer\n        input_tensor = model.inputs[0]\n        new_model = clone_model(model, input_tensors=input_tensor)\n        tree.assert_same_structure(model.inputs, new_model.inputs)\n        tree.assert_same_structure(model.outputs, new_model.outputs)\n\n    def 
test_shared_layers_cloning(self):\n        model = get_mlp_functional_model(shared_layers=True)\n        new_model = clone_model(model)\n        self.assertLen(new_model.layers, 4)\n\n    def test_structured_io_cloning(self):\n        x = layers.Input((3,))\n        y = layers.Input((3,))\n        z1 = x + y\n        z2 = layers.Dense(5)(z1)\n        inputs = dict(x=x, y=y)\n        outputs = dict(z1=z1, z2=z2)\n        model0 = models.Model(inputs, outputs)\n\n        model = clone_model(model0)\n        tree.assert_same_structure(model.input, inputs)\n        tree.assert_same_structure(model.output, outputs)\n\n        model = clone_model(model0, input_tensors=inputs)\n        tree.assert_same_structure(model.input, inputs)\n        tree.assert_same_structure(model.output, outputs)\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`input_tensors` must have the same structure as model.input\",\n        ):\n            model = clone_model(model0, input_tensors=(x, y))\n\n    def test_call_fn(self):\n        model = get_mlp_functional_model(shared_layers=False)\n\n        def call_function(layer, *args, **kwargs):\n            out = layer(*args, **kwargs)\n            if isinstance(layer, layers.Dense):\n                out = layers.Dropout(0.5)(out)\n            return out\n\n        new_model = clone_model(\n            model,\n            clone_function=lambda x: x,  # Reuse the same layers.\n            call_function=call_function,\n        )\n        self.assertLen(model.layers, 3)\n        self.assertLen(new_model.layers, 5)\n        self.assertIsInstance(new_model.layers[2], layers.Dropout)\n        self.assertIsInstance(new_model.layers[4], layers.Dropout)\n        ref_input = np.random.random((2, 3))\n        self.assert_models_equal(model, new_model, ref_input)\n\n    def test_recursive(self):\n        model = get_nested_functional_model()\n\n        def call_function(layer, *args, **kwargs):\n            out = layer(*args, 
**kwargs)\n            if isinstance(layer, layers.Dense):\n                out = layers.Dropout(0.5)(out)\n            return out\n\n        new_model = clone_model(\n            model,\n            clone_function=lambda x: x,  # Reuse the same layers.\n            call_function=call_function,\n            recursive=True,\n        )\n        self.assertLen(model._flatten_layers(), 8)\n        self.assertLen(new_model._flatten_layers(), 12)\n        self.assertIsInstance(new_model.layers[3].layers[2], layers.Dropout)\n        self.assertIsInstance(new_model.layers[3].layers[4], layers.Dropout)\n        ref_input = np.random.random((2, 4))\n        self.assert_models_equal(model, new_model, ref_input)\n\n        # Sequential.\n        def clone_function(layer):\n            layer = layer.__class__.from_config(layer.get_config())\n            layer.flag = True\n            return layer\n\n        model = get_nested_sequential_model()\n        new_model = clone_model(\n            model,\n            clone_function=clone_function,\n            recursive=True,\n        )\n        ref_input = np.random.random((2, 3))\n        model(ref_input)  # Maybe needed to build the model\n        new_model(ref_input)  # Maybe needed to build the model\n        new_model.set_weights(model.get_weights())\n        self.assert_models_equal(model, new_model, ref_input)\n        for l1, l2 in zip(model._flatten_layers(), new_model._flatten_layers()):\n            if isinstance(l2, layers.Dense):\n                self.assertFalse(hasattr(l1, \"flag\"))\n                self.assertTrue(hasattr(l2, \"flag\"))\n\n    def test_recursive_multi_level(self):\n        # Functional: 3 nesting levels (outer -> middle -> inner).\n        # Before the fix, recursive=True was not forwarded in the\n        # recursive clone_model() calls, so only the first nesting\n        # level was entered. 
The inner-most Dense layers would be\n        # shared rather than cloned.\n        model = get_doubly_nested_functional_model()\n\n        def clone_function(layer):\n            layer = layer.__class__.from_config(layer.get_config())\n            layer.flag = True\n            return layer\n\n        new_model = clone_model(\n            model,\n            clone_function=clone_function,\n            recursive=True,\n        )\n        ref_input = np.random.random((2, 5))\n        model(ref_input)\n        new_model(ref_input)\n        new_model.set_weights(model.get_weights())\n        self.assert_models_equal(model, new_model, ref_input)\n\n        # Every Dense in the deepest level should have been cloned\n        # (i.e. clone_function was applied), not shared.\n        for l1, l2 in zip(model._flatten_layers(), new_model._flatten_layers()):\n            if isinstance(l2, layers.Dense):\n                self.assertFalse(hasattr(l1, \"flag\"))\n                self.assertTrue(hasattr(l2, \"flag\"))\n\n        # Sequential: 3 nesting levels.\n        model = get_doubly_nested_sequential_model()\n        new_model = clone_model(\n            model,\n            clone_function=clone_function,\n            recursive=True,\n        )\n        ref_input = np.random.random((2, 5))\n        model(ref_input)\n        new_model(ref_input)\n        new_model.set_weights(model.get_weights())\n        self.assert_models_equal(model, new_model, ref_input)\n\n        for l1, l2 in zip(model._flatten_layers(), new_model._flatten_layers()):\n            if isinstance(l2, layers.Dense):\n                self.assertFalse(hasattr(l1, \"flag\"))\n                self.assertTrue(hasattr(l2, \"flag\"))\n\n    def test_compiled_model_cloning(self):\n        model = models.Sequential()\n        model.add(layers.Input((3,)))\n        model.add(layers.Dense(5, activation=\"relu\"))\n        model.add(layers.Dense(1, activation=\"sigmoid\"))\n        model.compile(optimizer=\"adam\", 
loss=\"binary_crossentropy\")\n        cloned_model = clone_model(model)\n        self.assertEqual(model.compiled, cloned_model.compiled)\n"
  },
  {
    "path": "keras/src/models/functional.py",
    "content": "import copy\nimport inspect\nimport typing\nimport warnings\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import tree\nfrom keras.src.backend.common import global_state\nfrom keras.src.layers.core.input_layer import Input\nfrom keras.src.layers.core.input_layer import InputLayer\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.layers.layer import Layer\nfrom keras.src.legacy.saving import saving_utils\nfrom keras.src.legacy.saving import serialization as legacy_serialization\nfrom keras.src.models.model import Model\nfrom keras.src.ops.function import Function\nfrom keras.src.ops.function import _build_map\nfrom keras.src.ops.function import make_node_key\nfrom keras.src.ops.node import KerasHistory\nfrom keras.src.ops.node import Node\nfrom keras.src.ops.operation import Operation\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils import tracking\n\n\nclass Functional(Function, Model):\n    \"\"\"A `Functional` model is a `Model` defined as a directed graph of layers.\n\n    Three types of `Model` exist: subclassed `Model`, `Functional` model,\n    and `Sequential` (a special case of `Functional`).\n\n    A `Functional` model can be instantiated by passing two arguments to\n    `__init__()`. The first argument is the `keras.Input` objects\n    that represent the inputs to the model.\n    The second argument specifies the output tensors that represent\n    the outputs of this model. 
Both arguments can be a nested structure\n    of tensors.\n\n    Example:\n\n    ```\n    inputs = {'x1': keras.Input(shape=(10,), name='x1'),\n              'x2': keras.Input(shape=(1,), name='x2')}\n    t = keras.layers.Dense(1, activation='relu')(inputs['x1'])\n    outputs = keras.layers.Add()([t, inputs['x2']])\n    model = keras.Model(inputs, outputs)\n    ```\n\n    A `Functional` model constructed using the Functional API can also\n    include raw Keras 3 ops.\n\n    Example:\n\n    ```python\n    inputs = keras.Input(shape=(10,))\n    x = keras.layers.Dense(1)(inputs)\n    outputs = ops.nn.relu(x)\n    model = keras.Model(inputs, outputs)\n    ```\n\n    A new `Functional` model can also be created by using the\n    intermediate tensors. This enables you to quickly extract sub-components\n    of the model.\n\n    Example:\n\n    ```python\n    inputs = keras.Input(shape=(None, None, 3))\n    processed = keras.layers.RandomCrop(width=32, height=32)(inputs)\n    conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)\n    pooling = keras.layers.GlobalAveragePooling2D()(conv)\n    feature = keras.layers.Dense(10)(pooling)\n\n    full_model = keras.Model(inputs, feature)\n    backbone = keras.Model(processed, conv)\n    activations = keras.Model(conv, feature)\n    ```\n\n    Note that the `backbone` and `activations` models are not\n    created with `keras.Input` objects, but with the tensors\n    that are originated from `keras.Input` objects.\n    Under the hood, the layers and weights will\n    be shared across these models, so that user can train the `full_model`, and\n    use `backbone` or `activations` to do feature extraction.\n    The inputs and outputs of the model can be nested structures of tensors as\n    well, and the created models are standard `Functional` model that support\n    all the existing API.\n\n    Args:\n        inputs: List of input tensors (must be created via `keras.Input()`\n            or originated from 
`keras.Input()`).\n        outputs: List of output tensors.\n        name: String, optional. Name of the model.\n        trainable: Boolean, optional. If the model's variables should be\n            trainable.\n    \"\"\"\n\n    def __new__(cls, *args, **kwargs):\n        return typing.cast(cls, super().__new__(cls))\n\n    @tracking.no_automatic_dependency_tracking\n    def __init__(self, inputs, outputs, name=None, **kwargs):\n        if isinstance(inputs, dict):\n            for k, v in inputs.items():\n                if isinstance(v, backend.KerasTensor) and k != v.name:\n                    warnings.warn(\n                        \"When providing `inputs` as a dict, all keys in the \"\n                        \"dict must match the names of the corresponding \"\n                        f\"tensors. Received key '{k}' mapping to value {v} \"\n                        f\"which has name '{v.name}'. Change the tensor name to \"\n                        f\"'{k}' (via `Input(..., name='{k}')`)\"\n                    )\n\n        trainable = kwargs.pop(\"trainable\", None)\n        flat_inputs = tree.flatten(inputs)\n        flat_outputs = tree.flatten(outputs)\n        for x in flat_inputs:\n            if not isinstance(x, backend.KerasTensor):\n                raise ValueError(\n                    \"All `inputs` values must be KerasTensors. Received: \"\n                    f\"inputs={inputs} including invalid value {x} of \"\n                    f\"type {type(x)}\"\n                )\n        for x in flat_outputs:\n            if not isinstance(x, backend.KerasTensor):\n                raise ValueError(\n                    \"All `outputs` values must be KerasTensors. 
Received: \"\n                    f\"outputs={outputs} including invalid value {x} of \"\n                    f\"type {type(x)}\"\n                )\n\n        if not all(is_input_keras_tensor(t) for t in flat_inputs):\n            inputs, outputs = clone_graph_nodes(inputs, outputs)\n\n        Function.__init__(self, inputs, outputs, name=name)\n\n        if trainable is not None:\n            self.trainable = trainable\n\n        self._layers = self.layers\n        self.build(None)\n        # We will convert directly (to the correct dtype per input).\n        self._convert_input_args = False\n        self._allow_non_tensor_positional_args = True\n        output_layers = [x._keras_history[0] for x in self.outputs]\n        self.output_names = [x.name for x in output_layers]\n\n    def _lock_state(self):\n        # Unlike other layers, we allow Functional state to be mutable after\n        # build. E.g. to attach a layer to a model that is not part of the\n        # functional DAG.\n        pass\n\n    def _obj_type(self):\n        return \"Functional\"\n\n    @property\n    def layers(self):\n        layers = []\n        for operation in self._operations:\n            if isinstance(operation, Layer):\n                layers.append(operation)\n        return layers\n\n    @layers.setter\n    def layers(self, _):\n        raise AttributeError(\n            \"`Model.layers` attribute is reserved and should not be used. 
\"\n            \"Please use another name.\"\n        )\n\n    def call(self, inputs, training=None, mask=None, **kwargs):\n        # Add support for training, masking\n        inputs = self._standardize_inputs(inputs)\n        if mask is None:\n            masks = [None] * len(inputs)\n        else:\n            masks = tree.flatten(mask)\n            for x, mask in zip(inputs, masks):\n                if mask is not None:\n                    backend.set_keras_mask(x, mask)\n        outputs = self._run_through_graph(\n            inputs,\n            operation_fn=lambda op: operation_fn(\n                op, training=training, **kwargs\n            ),\n        )\n        return unpack_singleton(outputs)\n\n    def compute_output_spec(self, inputs, training=None, mask=None):\n        # From Function\n        return super().compute_output_spec(inputs)\n\n    def compute_output_shape(self, input_shape):\n        # From Function\n        return super().compute_output_shape(input_shape)\n\n    def build(self, input_shape):\n        self.built = True\n\n    @property\n    def input_shape(self):\n        input_shapes = tree.map_structure(lambda x: x.shape, self.inputs)\n        if isinstance(input_shapes, list) and len(input_shapes) == 1:\n            return input_shapes[0]\n        return input_shapes\n\n    @property\n    def output_shape(self):\n        output_shapes = tree.map_structure(lambda x: x.shape, self.outputs)\n        if isinstance(output_shapes, list) and len(output_shapes) == 1:\n            return output_shapes[0]\n        return output_shapes\n\n    def _assert_input_compatibility(self, *args):\n        return super(Model, self)._assert_input_compatibility(*args)\n\n    def _maybe_warn_inputs_struct_mismatch(self, inputs, raise_exception=False):\n        try:\n            # We first normalize to tuples before performing the check to\n            # suppress warnings when encountering mismatched tuples and lists.\n            
tree.assert_same_structure(\n                tree.lists_to_tuples(inputs),\n                tree.lists_to_tuples(self._inputs_struct),\n            )\n        except:\n            model_inputs_struct = tree.map_structure(\n                lambda x: x.name, self._inputs_struct\n            )\n            inputs_struct = tree.map_structure(\n                lambda x: f\"Tensor(shape={x.shape})\", inputs\n            )\n            msg = (\n                \"The structure of `inputs` doesn't match the expected \"\n                f\"structure.\\nExpected: {model_inputs_struct}\\n\"\n                f\"Received: inputs={inputs_struct}\"\n            )\n            if raise_exception:\n                raise ValueError(msg)\n            warnings.warn(msg)\n\n    def _convert_inputs_to_tensors(self, flat_inputs):\n        converted = []\n        for x, input in zip(flat_inputs, self._inputs):\n            if x is None:  # TODO: check if optional\n                converted.append(x)\n            else:\n                converted.append(\n                    ops.convert_to_tensor(\n                        x, dtype=input.dtype, sparse=input.sparse\n                    )\n                )\n        return converted\n\n    def _adjust_input_rank(self, flat_inputs):\n        adjusted = []\n        for i, x in enumerate(flat_inputs):\n            ref_shape = self._inputs[i].shape\n            if x is None:\n                adjusted.append(x)\n                continue\n            x_rank = len(x.shape)\n            ref_rank = len(ref_shape)\n            if x_rank == ref_rank:\n                adjusted.append(x)\n                continue\n            if x_rank == ref_rank + 1:\n                if x.shape[-1] == 1:\n                    adjusted.append(ops.squeeze(x, axis=-1))\n                    continue\n            if x_rank == ref_rank - 1:\n                if ref_shape[-1] == 1:\n                    adjusted.append(ops.expand_dims(x, axis=-1))\n                    continue\n    
        flat_paths_and_inputs = tree.flatten_with_path(self._inputs_struct)\n            path = \".\".join(str(p) for p in flat_paths_and_inputs[i][0])\n            raise ValueError(\n                f\"Invalid input shape for input {x} with name \"\n                f\"'{self._inputs[i].name}' and path '{path}'. Expected shape \"\n                f\"{ref_shape}, but input has incompatible shape {x.shape}\"\n            )\n        # Add back metadata.\n        for i in range(len(flat_inputs)):\n            if hasattr(flat_inputs[i], \"_keras_history\"):\n                adjusted[i]._keras_history = flat_inputs[i]._keras_history\n            mask = backend.get_keras_mask(flat_inputs[i])\n            if mask is not None:\n                backend.set_keras_mask(adjusted[i], mask)\n        return adjusted\n\n    def _standardize_inputs(self, inputs):\n        raise_exception = False\n        if (\n            isinstance(self._inputs_struct, list)\n            and len(self._inputs_struct) == 1\n            and ops.is_tensor(inputs)\n        ):\n            inputs = [inputs]\n        elif isinstance(inputs, dict) and not isinstance(\n            self._inputs_struct, dict\n        ):\n            # This is to avoid warning\n            # when we have reconcilable dict/list structs\n            if hasattr(self._inputs_struct, \"__len__\") and all(\n                isinstance(i, backend.KerasTensor) for i in self._inputs_struct\n            ):\n                expected_keys = set(i.name for i in self._inputs_struct)\n                keys = set(inputs.keys())\n                if expected_keys.issubset(keys):\n                    inputs = [inputs[i.name] for i in self._inputs_struct]\n                else:\n                    raise_exception = True\n            elif isinstance(self._inputs_struct, backend.KerasTensor):\n                if self._inputs_struct.name in inputs:\n                    inputs = [inputs[self._inputs_struct.name]]\n                else:\n               
     raise_exception = True\n            else:\n                raise_exception = True\n        if (\n            isinstance(self._inputs_struct, dict)\n            and not isinstance(inputs, dict)\n            and list(self._inputs_struct.keys())\n            != sorted(self._inputs_struct.keys())\n        ):\n            raise_exception = True\n        self._maybe_warn_inputs_struct_mismatch(\n            inputs, raise_exception=raise_exception\n        )\n\n        flat_inputs = tree.flatten(inputs)\n        flat_inputs = self._convert_inputs_to_tensors(flat_inputs)\n        return self._adjust_input_rank(flat_inputs)\n\n    @property\n    def input(self):\n        # For backwards compatibility,\n        # override `input` to retrieve the user-provided\n        # constructor inputs\n        return self._inputs_struct\n\n    @property\n    def output(self):\n        return self._outputs_struct\n\n    def add_loss(self, loss):\n        # Symbolic only. TODO\n        raise NotImplementedError\n\n    @property\n    def input_spec(self):\n        if hasattr(self, \"_manual_input_spec\"):\n            return self._manual_input_spec\n\n        def shape_with_no_batch_size(x):\n            x = list(x)\n            if x:\n                x[0] = None\n            return tuple(x)\n\n        def make_spec_for_tensor(x, name=None):\n            optional = False\n            if isinstance(x._keras_history[0], InputLayer):\n                if x._keras_history[0].optional:\n                    optional = True\n            return InputSpec(\n                shape=shape_with_no_batch_size(x.shape),\n                allow_last_axis_squeeze=True,\n                name=x._keras_history[0].name if name is None else name,\n                optional=optional,\n            )\n\n        if isinstance(self._inputs_struct, dict):\n            if all(\n                isinstance(x, backend.KerasTensor)\n                for x in self._inputs_struct.values()\n            ):\n                # 
Case where `_nested_inputs` is a plain dict of Inputs.\n                names = sorted(self._inputs_struct.keys())\n                return [\n                    make_spec_for_tensor(self._inputs_struct[name], name=name)\n                    for name in names\n                ]\n            return None  # Deeply nested dict: skip checks.\n        return [make_spec_for_tensor(x) for x in self.inputs]\n\n    @input_spec.setter\n    def input_spec(self, value):\n        self._manual_input_spec = value\n\n    def get_config(self):\n        if not functional_like_constructor(self.__class__):\n            # Subclassed networks are not serializable\n            # (unless serialization is implemented by\n            # the author of the subclassed network).\n            return Model.get_config(self)\n\n        config = {\n            \"name\": self.name,\n            \"trainable\": self.trainable,\n        }\n        # Build a map from a layer unique name (make_node_key)\n        # to the index of the nodes that are saved in the config.\n        # Only nodes in network_nodes are saved.\n        node_reindexing_map = {}\n        for operation in self.operations:\n            if issubclass(operation.__class__, Functional):\n                # Functional models start with a pre-existing node\n                # linking their input to output.\n                kept_nodes = 1\n            else:\n                kept_nodes = 0\n            for original_node_index, node in enumerate(\n                operation._inbound_nodes\n            ):\n                node_key = make_node_key(operation, original_node_index)\n                if node_key in self._nodes:\n                    # i.e. 
we mark it to be saved\n                    node_reindexing_map[node_key] = kept_nodes\n                    kept_nodes += 1\n\n        # serialize and save the layers in layer_configs\n        layer_configs = []\n        for operation in self.operations:  # From the earliest layers on.\n            filtered_inbound_nodes = []\n            for original_node_index, node in enumerate(\n                operation._inbound_nodes\n            ):\n                node_key = make_node_key(operation, original_node_index)\n                if node_key in self._nodes:\n                    # The node is relevant to the model:\n                    # add to filtered_inbound_nodes.\n                    node_data = serialize_node(node, own_nodes=self._nodes)\n                    if node_data is not None:\n                        filtered_inbound_nodes.append(node_data)\n\n            serialize_obj_fn = serialization_lib.serialize_keras_object\n            if global_state.get_global_attribute(\"use_legacy_config\", False):\n                # Legacy format serialization used for H5 and SavedModel\n                serialize_obj_fn = legacy_serialization.serialize_keras_object\n            layer_config = serialize_obj_fn(operation)\n            layer_config[\"name\"] = operation.name\n            layer_config[\"inbound_nodes\"] = filtered_inbound_nodes\n            layer_configs.append(layer_config)\n        config[\"layers\"] = layer_configs\n\n        # Gather info about inputs and outputs.\n        def get_tensor_config(tensor):\n            operation = tensor._keras_history[0]\n            node_index = tensor._keras_history[1]\n            tensor_index = tensor._keras_history[2]\n            node_key = make_node_key(operation, node_index)\n            if node_key not in self._nodes:\n                raise RuntimeError(\n                    f\"Internal error: could not find node key {node_key}.\"\n                )\n            new_node_index = node_reindexing_map[node_key]\n         
   return [operation.name, new_node_index, tensor_index]\n\n        def map_tensors(tensors):\n            return tree.map_structure(get_tensor_config, tensors)\n\n        config[\"input_layers\"] = map_tensors(self._inputs_struct)\n        config[\"output_layers\"] = map_tensors(self._outputs_struct)\n        return copy.deepcopy(config)\n\n\ndef functional_from_config(cls, config, custom_objects=None):\n    \"\"\"Instantiates a Functional model from its config (from `get_config()`).\n\n    Args:\n        cls: Class of the model, e.g. a custom subclass of `Model`.\n        config: Output of `get_config()` for the original model instance.\n        custom_objects: Optional dict of custom objects.\n\n    Returns:\n        An instance of `cls`.\n    \"\"\"\n    # Layer instances created during\n    # the graph reconstruction process\n    created_layers = {}\n\n    # Dictionary mapping layer instances to\n    # node data that specifies a layer call.\n    # It acts as a queue that maintains any unprocessed\n    # layer call until it becomes possible to process it\n    # (i.e. 
until the input tensors to the call all exist).\n    unprocessed_nodes = {}\n\n    def add_unprocessed_node(layer, node_data):\n        \"\"\"Add node to the layer's list of unprocessed nodes.\n\n        Args:\n            layer: layer object\n            node_data: Node data specifying layer call\n        \"\"\"\n        if layer not in unprocessed_nodes:\n            unprocessed_nodes[layer] = [node_data]\n        else:\n            unprocessed_nodes[layer].append(node_data)\n\n    def process_node(layer, node_data):\n        \"\"\"Reconstruct node by linking to inbound layers\n\n        Args:\n            layer: Layer to process\n            node_data: List of layer configs\n        \"\"\"\n        args, kwargs = deserialize_node(node_data, created_layers)\n        # Call layer on its inputs, thus creating the node\n        # and building the layer if needed.\n        layer(*args, **kwargs)\n\n    def process_layer(layer_data):\n        \"\"\"Deserializes a layer and indexes its inbound nodes.\n\n        Args:\n            layer_data: layer config dict.\n        \"\"\"\n        layer_name = layer_data[\"name\"]\n\n        # Instantiate layer.\n        if \"module\" not in layer_data:\n            # Legacy format deserialization (no \"module\" key)\n            # used for H5 and SavedModel formats\n            layer = saving_utils.model_from_config(\n                layer_data, custom_objects=custom_objects\n            )\n        else:\n            layer = serialization_lib.deserialize_keras_object(\n                layer_data, custom_objects=custom_objects\n            )\n        if not isinstance(layer, Operation):\n            raise ValueError(\n                \"Unexpected object from deserialization, expected a layer or \"\n                f\"operation, got a {type(layer)}\"\n            )\n        created_layers[layer_name] = layer\n\n        # Gather layer inputs.\n        inbound_nodes_data = layer_data[\"inbound_nodes\"]\n        for node_data in inbound_nodes_data:\n            # We 
don't process nodes (i.e. make layer calls)\n            # on the fly because the inbound node may not yet exist,\n            # in case of layer shared at different topological depths\n            # (e.g. a model such as A(B(A(B(x)))))\n            add_unprocessed_node(layer, node_data)\n\n    # Extract config used to instantiate Functional model from the config. The\n    # remaining config will be passed as keyword arguments to the Model\n    # constructor.\n    functional_config = {}\n    for key in [\"layers\", \"input_layers\", \"output_layers\"]:\n        functional_config[key] = config.pop(key)\n    for key in [\"name\", \"trainable\"]:\n        if key in config:\n            functional_config[key] = config.pop(key)\n        else:\n            functional_config[key] = None\n\n    # First, we create all layers and enqueue nodes to be processed\n    for layer_data in functional_config[\"layers\"]:\n        process_layer(layer_data)\n\n    # Then we process nodes in order of layer depth.\n    # Nodes that cannot yet be processed (if the inbound node\n    # does not yet exist) are re-enqueued, and the process\n    # is repeated until all nodes are processed.\n    while unprocessed_nodes:\n        for layer_data in functional_config[\"layers\"]:\n            layer = created_layers[layer_data[\"name\"]]\n\n            # Process all nodes in layer, if not yet processed\n            if layer in unprocessed_nodes:\n                node_data_list = unprocessed_nodes[layer]\n\n                # Process nodes in order\n                node_index = 0\n                while node_index < len(node_data_list):\n                    node_data = node_data_list[node_index]\n                    try:\n                        process_node(layer, node_data)\n\n                    # If the node does not have all inbound layers\n                    # available, stop processing and continue later\n                    except IndexError:\n                        break\n\n                 
   node_index += 1\n\n                # If not all nodes processed then store unprocessed nodes\n                if node_index < len(node_data_list):\n                    unprocessed_nodes[layer] = node_data_list[node_index:]\n                # If all nodes processed remove the layer\n                else:\n                    del unprocessed_nodes[layer]\n\n    # Create list of input and output tensors and return new class\n    name = functional_config[\"name\"]\n    trainable = functional_config[\"trainable\"]\n\n    def get_tensor(layer_name, node_index, tensor_index):\n        if layer_name not in created_layers:\n            raise RuntimeError(\n                f\"Internal error: could not find layer {layer_name}.\"\n            )\n        layer = created_layers[layer_name]\n        if isinstance(layer, Functional):\n            # Functional models start out with a built-in node.\n            node_index -= 1\n        layer_output_tensors = layer._inbound_nodes[node_index].output_tensors\n        return layer_output_tensors[tensor_index]\n\n    def map_tensors(tensors):\n        if (\n            isinstance(tensors, list)\n            and len(tensors) == 3\n            and isinstance(tensors[0], str)\n        ):\n            # Leaf\n            return get_tensor(*tensors)\n        if isinstance(tensors, dict):\n            return {k: map_tensors(v) for k, v in tensors.items()}\n        if isinstance(tensors, tuple):\n            return tuple([map_tensors(v) for v in tensors])\n        return [map_tensors(v) for v in tensors]\n\n    input_tensors = map_tensors(functional_config[\"input_layers\"])\n    output_tensors = map_tensors(functional_config[\"output_layers\"])\n\n    return cls(\n        inputs=input_tensors,\n        outputs=output_tensors,\n        name=name,\n        trainable=trainable,\n        **config,\n    )\n\n\ndef operation_fn(operation, **call_context_args):\n    \"\"\"Wraps each op to inject the call-context args.\"\"\"\n\n    def call(*args, 
**kwargs):\n        # Propagate all registered call-context args\n        for name, value in call_context_args.items():\n            if (\n                name in getattr(operation, \"_call_context_args\", {})\n                and value is not None\n            ):\n                kwargs[name] = value\n\n        return operation(*args, **kwargs)\n\n    return call\n\n\ndef functional_like_constructor(cls):\n    init_args = inspect.getfullargspec(cls.__init__).args[1:]\n    functional_init_args = inspect.getfullargspec(Functional.__init__).args[1:]\n    if init_args == functional_init_args:\n        return True\n    return False\n\n\ndef unpack_singleton(x):\n    if isinstance(x, (list, tuple)) and len(x) == 1:\n        return x[0]\n    return x\n\n\ndef serialize_node(node, own_nodes=()):\n    if not node.input_tensors:\n        # Does not need to be serialized.\n        return\n\n    def serialize_keras_tensor(x):\n        # Serialize KerasTensor while converting\n        # node indices to only include nodes relevant to `own_nodes`.\n        if isinstance(x, backend.KerasTensor):\n            operation, node_index, tensor_index = x._keras_history\n            irrelevant_node_count = 0\n            for i, node in enumerate(operation._inbound_nodes[:node_index]):\n                node_key = make_node_key(operation, i)\n                if node_key not in own_nodes:\n                    irrelevant_node_count += 1\n            x._keras_history = KerasHistory(\n                operation, node_index - irrelevant_node_count, tensor_index\n            )\n            serialized = serialization_lib.serialize_keras_object(x)\n            x._keras_history = KerasHistory(operation, node_index, tensor_index)\n            return serialized\n        return x\n\n    args = node.arguments.args\n    kwargs = node.arguments.kwargs\n\n    args = tree.map_structure(serialize_keras_tensor, args)\n    kwargs = tree.map_structure(serialize_keras_tensor, kwargs)\n    return {\n        
\"args\": serialization_lib.serialize_keras_object(args),\n        \"kwargs\": serialization_lib.serialize_keras_object(kwargs),\n    }\n\n\ndef deserialize_node(node_data, created_layers):\n    \"\"\"Return (args, kwargs) for calling the node layer.\"\"\"\n    if not node_data:\n        return [], {}\n\n    if isinstance(node_data, list):\n        # Legacy case.\n        input_tensors = []\n        for input_data in node_data:\n            inbound_layer_name = input_data[0]\n            inbound_node_index = input_data[1]\n            inbound_tensor_index = input_data[2]\n            if len(input_data) == 3:\n                kwargs = {}\n            elif len(input_data) == 4:\n                kwargs = input_data[3]\n            else:\n                raise ValueError(\n                    \"Cannot deserialize the model (invalid config data?)\"\n                )\n            inbound_layer = created_layers[inbound_layer_name]\n\n            # Raise an error if the corresponding layer node\n            # has not yet been created\n            if len(inbound_layer._inbound_nodes) <= inbound_node_index:\n                raise IndexError(\n                    \"Layer node index out of bounds.\\n\"\n                    f\"inbound_layer = {inbound_layer}\\n\"\n                    \"inbound_layer._inbound_nodes = \"\n                    f\"{inbound_layer._inbound_nodes}\\n\"\n                    f\"inbound_node_index = {inbound_node_index}\"\n                )\n            inbound_node = inbound_layer._inbound_nodes[inbound_node_index]\n            input_tensors.append(\n                inbound_node.output_tensors[inbound_tensor_index]\n            )\n        return [unpack_singleton(input_tensors)], kwargs\n\n    args = serialization_lib.deserialize_keras_object(node_data[\"args\"])\n    kwargs = serialization_lib.deserialize_keras_object(node_data[\"kwargs\"])\n\n    def convert_revived_tensor(x):\n        if isinstance(x, backend.KerasTensor):\n            history = 
x._pre_serialization_keras_history\n            if history is None:\n                return x\n            layer = created_layers.get(history[0], None)\n            if layer is None:\n                raise ValueError(f\"Unknown layer: {history[0]}\")\n            inbound_node_index = history[1]\n            inbound_tensor_index = history[2]\n            if len(layer._inbound_nodes) <= inbound_node_index:\n                raise IndexError(\n                    \"Layer node index out of bounds.\\n\"\n                    f\"inbound_layer = {layer}\\n\"\n                    f\"inbound_layer._inbound_nodes = {layer._inbound_nodes}\\n\"\n                    f\"inbound_node_index = {inbound_node_index}\"\n                )\n            inbound_node = layer._inbound_nodes[inbound_node_index]\n            return inbound_node.output_tensors[inbound_tensor_index]\n        return x\n\n    args = tree.map_structure(convert_revived_tensor, args)\n    kwargs = tree.map_structure(convert_revived_tensor, kwargs)\n    return args, kwargs\n\n\ndef is_input_keras_tensor(x):\n    (\n        operation,\n        node_index,\n        _,\n    ) = x._keras_history\n    node = operation._inbound_nodes[node_index]\n    return node.is_input\n\n\ndef clone_single_keras_tensor(x):\n    return backend.KerasTensor(\n        shape=x.shape, dtype=x.dtype, sparse=x.sparse, name=f\"{x.name}_clone\"\n    )\n\n\ndef clone_keras_tensors(tensors, kt_id_mapping):\n    def swap(x):\n        if not isinstance(x, backend.KerasTensor):\n            return x\n        if id(x) in kt_id_mapping:\n            return kt_id_mapping[id(x)]\n        new_x = clone_single_keras_tensor(x)\n        kt_id_mapping[id(x)] = new_x\n        return new_x\n\n    return tree.map_structure(swap, tensors)\n\n\ndef find_nodes_by_inputs_and_outputs(inputs, outputs):\n    nodes, _ = _build_map(inputs, outputs)\n    return nodes\n\n\ndef clone_graph_nodes(inputs, outputs):\n    \"\"\"Clone the `Node` between the inputs and output 
tensors.\n\n    This function is used to create a new functional model from any intermediate\n    Keras tensors. Cloning the nodes mimics the behavior of reconstructing\n    the functional graph network by re-executing all the `__call__()` methods.\n    The cloned nodes will be appended to the layers.\n\n    Note that a new `keras.Input` will be created for any items in the\n    `inputs`.\n\n    Args:\n        inputs: A nested structure of `KerasTensor` instances.\n        outputs: A nested structure of `KerasTensor` instances.\n\n    Returns:\n        A pair of inputs and outputs, with cloned `KerasTensor` instances.\n        They can be used to create a new functional model.\n    \"\"\"\n    nodes_to_clone = find_nodes_by_inputs_and_outputs(inputs, outputs)\n    cloned_inputs = []\n    cloned_outputs = []\n    # We not only need to create copies of Nodes (to mimic the calls), we also\n    # need to clone the Keras tensors to avoid overriding the _keras_history\n    # attached to them. The following dict tracks every Keras tensor we have\n    # cloned. The key is the ID of the original Keras tensor, and the value is\n    # the cloned Keras tensor instance.\n    kt_id_mapping = {}\n    op_id_mapping = {}\n\n    for kt_input in tree.flatten(inputs):\n        if is_input_keras_tensor(kt_input):\n            # For any existing Keras tensor from keras.Input, leave it as is.\n            cloned_inputs.append(kt_input)\n            kt_id_mapping[id(kt_input)] = kt_input\n        else:\n            # We need to create a new Keras tensor for any intermediate tensor\n            original_op = kt_input._keras_history.operation\n            optional = False\n            if isinstance(original_op, InputLayer):\n                optional = original_op.optional\n            cloned_input = Input(\n                batch_shape=kt_input.shape,\n                dtype=kt_input.dtype,\n                sparse=kt_input.sparse,\n                name=f\"{kt_input.name}CLONE\",\n        
        optional=optional,\n            )\n            cloned_inputs.append(cloned_input)\n            kt_id_mapping[id(kt_input)] = cloned_input\n            op_id_mapping[id(kt_input._keras_history[0])] = (\n                cloned_input._keras_history[0]\n            )\n    cloned_inputs = tree.pack_sequence_as(inputs, cloned_inputs)\n\n    for kt_output in tree.flatten(outputs):\n        cpy = clone_single_keras_tensor(kt_output)\n        # We reuse the _keras_history here, which contains the old information.\n        cpy._keras_history = kt_output._keras_history\n        cloned_outputs.append(cpy)\n        kt_id_mapping[id(kt_output)] = cpy\n    cloned_outputs = tree.pack_sequence_as(outputs, cloned_outputs)\n\n    for node in nodes_to_clone:\n        if id(node.operation) in op_id_mapping:\n            operation = op_id_mapping[id(node.operation)]\n        else:\n            operation = node.operation\n        # Clone any Keras tensor to avoid override of _keras_history\n        # Or reuse an existing Keras tensor if it has already been cloned.\n        output_copy = clone_keras_tensors(node.output_tensors, kt_id_mapping)\n        if not isinstance(operation, InputLayer):\n            call_args_copy = clone_keras_tensors(\n                node.arguments.args, kt_id_mapping\n            )\n            call_kwargs_copy = clone_keras_tensors(\n                node.arguments.kwargs, kt_id_mapping\n            )\n        else:\n            call_args_copy = ()\n            call_kwargs_copy = {}\n        # Creating new nodes based on the existing node information.  Node wires\n        # itself to inbound and outbound layers.  
The Node constructor actually\n        # updates this layer's self._inbound_nodes, sets _keras_history on the\n        # outputs, and adds itself to the `_outbound_nodes` of the layers that\n        # produced the inputs to this layer call.\n        Node(\n            operation,\n            call_args=call_args_copy,\n            call_kwargs=call_kwargs_copy,\n            outputs=output_copy,\n        )\n    return cloned_inputs, cloned_outputs\n"
  },
  {
    "path": "keras/src/models/functional_test.py",
    "content": "import os\nimport warnings\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import applications\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import ops\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.dtype_policies import dtype_policy\nfrom keras.src.layers.core.input_layer import Input\nfrom keras.src.layers.input_spec import InputSpec\nfrom keras.src.models import Functional\nfrom keras.src.models import Model\nfrom keras.src.models import Sequential\nfrom keras.src.models.model import model_from_json\n\n\nclass FunctionalTest(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_basic_flow_multi_input(self):\n        input_a = Input(shape=(3,), batch_size=2, name=\"input_a\")\n        input_b = Input(shape=(3,), batch_size=2, name=\"input_b\")\n        x = input_a + input_b\n        x = layers.Dense(5)(x)\n        outputs = layers.Dense(4)(x)\n        model = Functional([input_a, input_b], outputs, name=\"basic\")\n        model.summary()\n\n        self.assertEqual(model.name, \"basic\")\n        self.assertIsInstance(model, Functional)\n        self.assertIsInstance(model, Model)\n\n        # Eager call\n        in_val = [np.random.random((2, 3)), np.random.random((2, 3))]\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n        # Symbolic call\n        input_a_2 = Input(shape=(3,), batch_size=2, name=\"input_a_2\")\n        input_b_2 = Input(shape=(3,), batch_size=2, name=\"input_b_2\")\n        in_val = [input_a_2, input_b_2]\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n    @pytest.mark.requires_trainable_backend\n    def test_scalar_input(self):\n        input_a = Input(shape=(3,), batch_size=2, name=\"input_a\")\n        input_b = Input(shape=(), batch_size=2, 
name=\"input_b\")\n        outputs = input_a + input_b[:, None]\n        model = Functional([input_a, input_b], outputs)\n        model.summary()\n\n        in_val = [np.zeros((2, 3)), np.ones((2,))]\n        out_val = model(in_val)\n        self.assertAllClose(out_val, np.ones((2, 3)))\n\n    @pytest.mark.requires_trainable_backend\n    def test_mutable_state(self):\n        inputs = Input(shape=(3,), batch_size=2, name=\"input\")\n        x = layers.Dense(5)(inputs)\n        outputs = layers.Dense(5)(x)\n        model = Functional(inputs, outputs)\n        # Allow attaching state to a model that isn't directly part of the DAG.\n        # Most useful for functional subclasses.\n        model.extra_layer = layers.Dense(5)\n\n    @pytest.mark.requires_trainable_backend\n    def test_basic_flow_multi_output(self):\n        inputs = Input(shape=(3,), batch_size=2, name=\"input\")\n        x = layers.Dense(5)(inputs)\n        output_a = layers.Dense(4)(x)\n        output_b = layers.Dense(5)(x)\n        model = Functional(inputs, [output_a, output_b])\n\n        # Eager call\n        in_val = np.random.random((2, 3))\n        out_val = model(in_val)\n        self.assertIsInstance(out_val, list)\n        self.assertEqual(len(out_val), 2)\n        self.assertEqual(out_val[0].shape, (2, 4))\n        self.assertEqual(out_val[1].shape, (2, 5))\n\n        # Symbolic call\n        out_val = model(Input(shape=(3,), batch_size=2))\n        self.assertIsInstance(out_val, list)\n        self.assertEqual(len(out_val), 2)\n        self.assertEqual(out_val[0].shape, (2, 4))\n        self.assertEqual(out_val[1].shape, (2, 5))\n\n    @pytest.mark.requires_trainable_backend\n    def test_basic_flow_dict_io(self):\n        input_a = Input(shape=(3,), batch_size=2, name=\"a\")\n        input_b = Input(shape=(3,), batch_size=2, name=\"b\")\n        x = input_a + input_b\n        x = layers.Dense(5)(x)\n        outputs = layers.Dense(4)(x)\n\n        with self.assertRaisesRegex(\n           
 ValueError, \"All `inputs` values must be KerasTensors\"\n        ):\n            model = Functional({\"a\": \"input_a\", \"b\": input_b}, outputs)\n\n        with self.assertRaisesRegex(\n            ValueError, \"All `outputs` values must be KerasTensors\"\n        ):\n            model = Functional({\"a\": input_a, \"b\": input_b}, \"outputs\")\n\n        model = Functional({\"a\": input_a, \"b\": input_b}, outputs)\n\n        # Eager call\n        in_val = {\"a\": np.random.random((2, 3)), \"b\": np.random.random((2, 3))}\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n        # Symbolic call\n        input_a_2 = Input(shape=(3,), batch_size=2)\n        input_b_2 = Input(shape=(3,), batch_size=2)\n        in_val = {\"a\": input_a_2, \"b\": input_b_2}\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n    def test_basic_flow_as_a_submodel(self):\n        # Build submodel\n        submodel_inputs = Input([4])\n        submodel_outputs = layers.Flatten()(submodel_inputs)\n        submodel = Model(submodel_inputs, submodel_outputs)\n\n        inputs = Input((None, 4))\n        outputs = layers.TimeDistributed(submodel)(inputs)\n        model = Model(inputs=inputs, outputs=outputs)\n\n        x = np.random.random((2, 3, 4))\n        y = model(x)\n        self.assertEqual(y.shape, (2, 3, 4))\n\n    @pytest.mark.requires_trainable_backend\n    def test_named_input_dict_io(self):\n        # Single input\n        input_a = Input(shape=(3,), batch_size=2, name=\"a\")\n        x = layers.Dense(5)(input_a)\n        outputs = layers.Dense(4)(x)\n        model = Functional(input_a, outputs)\n\n        # Eager call\n        in_val = {\"a\": np.random.random((2, 3))}\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n        # Symbolic call\n        input_a_2 = Input(shape=(3,), batch_size=2)\n        in_val = {\"a\": input_a_2}\n        out_val = model(in_val)\n        
self.assertEqual(out_val.shape, (2, 4))\n\n        # ----\n        # Two inputs, input is list\n        input_a = Input(shape=(3,), batch_size=2, name=\"a\")\n        input_b = Input(shape=(4,), batch_size=2, name=\"b\")\n        a = layers.Dense(5)(input_a)\n        b = layers.Dense(5)(input_b)\n        x = layers.Concatenate()([a, b])\n        outputs = layers.Dense(4)(x)\n        model = Functional([input_a, input_b], outputs)\n\n        # Eager call\n        in_val = {\"a\": np.random.random((2, 3)), \"b\": np.random.random((2, 4))}\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n        # Symbolic call\n        input_a_2 = Input(shape=(3,), batch_size=2)\n        input_b_2 = Input(shape=(4,), batch_size=2)\n        in_val = {\"a\": input_a_2, \"b\": input_b_2}\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n        # ----\n        # Two inputs, input is dict\n        model = Functional({\"a\": input_a, \"b\": input_b}, outputs)\n\n        # Eager call\n        in_val = {\"a\": np.random.random((2, 3)), \"b\": np.random.random((2, 4))}\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n        # Symbolic call\n        input_a_2 = Input(shape=(3,), batch_size=2)\n        input_b_2 = Input(shape=(4,), batch_size=2)\n        in_val = {\"a\": input_a_2, \"b\": input_b_2}\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n        # ----\n        # Two inputs, input is dict with incorrect names\n        model = Functional({\"c\": input_a, \"d\": input_b}, outputs)\n\n        # Eager call\n        in_val = {\"c\": np.random.random((2, 3)), \"d\": np.random.random((2, 4))}\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n        # Symbolic call\n        input_a_2 = Input(shape=(3,), batch_size=2)\n        input_b_2 = Input(shape=(4,), batch_size=2)\n        in_val = {\"c\": input_a_2, \"d\": 
input_b_2}\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n        # Now we can't use the input names:\n        with self.assertRaises(ValueError):\n            in_val = {\n                \"a\": np.random.random((2, 3)),\n                \"b\": np.random.random((2, 4)),\n            }\n            out_val = model(in_val)\n\n    @pytest.mark.requires_trainable_backend\n    def test_input_dict_with_extra_field(self):\n        input_a = Input(shape=(3,), batch_size=2, name=\"a\")\n        x = input_a * 5\n        outputs = x + 2\n\n        model = Functional({\"a\": input_a}, outputs)\n\n        with pytest.warns() as record:\n            # Eager call\n            in_val = {\n                \"a\": np.random.random((2, 3)),\n                \"b\": np.random.random((2, 1)),\n            }\n            out_val = model(in_val)\n            self.assertEqual(out_val.shape, (2, 3))\n\n            # Symbolic call\n            input_a_2 = Input(shape=(3,), batch_size=2)\n            input_b_2 = Input(shape=(1,), batch_size=2)\n            in_val = {\"a\": input_a_2, \"b\": input_b_2}\n            out_val = model(in_val)\n            self.assertEqual(out_val.shape, (2, 3))\n        self.assertLen(record, 1)\n        self.assertStartsWith(\n            str(record[0].message),\n            r\"The structure of `inputs` doesn't match the expected structure\",\n        )\n\n    @parameterized.named_parameters(\n        (\"list\", list),\n        (\"tuple\", tuple),\n        (\"dict\", dict),\n    )\n    def test_restored_multi_output_type(self, out_type):\n        inputs = Input(shape=(3,), batch_size=2, name=\"input\")\n        x = layers.Dense(5)(inputs)\n        output_a = layers.Dense(4)(x)\n        output_b = layers.Dense(5)(x)\n        if out_type is dict:\n            outputs = {\"a\": output_a, \"b\": output_b}\n        else:\n            outputs = out_type([output_a, output_b])\n        model = Functional(inputs, outputs)\n        
model_restored = Functional.from_config(model.get_config())\n\n        # Eager call\n        in_val = np.random.random((2, 3))\n        out_val = model_restored(in_val)\n        self.assertIsInstance(out_val, out_type)\n\n        # Symbolic call\n        out_val = model_restored(Input(shape=(3,), batch_size=2))\n        self.assertIsInstance(out_val, out_type)\n\n    def test_restored_nested_input(self):\n        input_a = Input(shape=(3,), batch_size=2, name=\"input_a\")\n        x = layers.Dense(5)(input_a)\n        outputs = layers.Dense(4)(x)\n        model = Functional([[input_a]], outputs)\n\n        # Serialize and deserialize the model\n        json_config = model.to_json()\n        restored_json_config = model_from_json(json_config).to_json()\n\n        # Check that the serialized model is the same as the original\n        self.assertEqual(json_config, restored_json_config)\n\n    def test_functional_input_shape_and_type(self):\n        input = layers.Input((1024, 4))\n        conv = layers.Conv1D(32, 3)(input)\n        model = Functional(input, conv)\n\n        self.assertIsInstance(model.input, KerasTensor)\n        self.assertEqual(model.input_shape, (None, 1024, 4))\n\n    @pytest.mark.requires_trainable_backend\n    def test_layer_getters(self):\n        # Test mixing ops and layers\n        input_a = Input(shape=(3,), batch_size=2, name=\"input_a\")\n        input_b = Input(shape=(3,), batch_size=2, name=\"input_b\")\n        x = input_a + input_b\n        x = layers.Dense(5, name=\"dense_1\")(x)\n        outputs = layers.Dense(4, name=\"dense_2\")(x)\n        model = Functional([input_a, input_b], outputs)\n\n        self.assertEqual(len(model.layers), 4)\n        self.assertEqual(len(model._operations), 5)\n        self.assertEqual(model.get_layer(index=0).name, \"input_a\")\n        self.assertEqual(model.get_layer(index=1).name, \"input_b\")\n        self.assertEqual(model.get_layer(index=2).name, \"dense_1\")\n        
self.assertEqual(model.get_layer(index=3).name, \"dense_2\")\n        self.assertEqual(model.get_layer(name=\"dense_1\").name, \"dense_1\")\n\n    @pytest.mark.requires_trainable_backend\n    def test_training_arg(self):\n        test_obj = self\n\n        class Canary(layers.Layer):\n            def call(self, x, training=False):\n                test_obj.assertTrue(training)\n                return x\n\n            def compute_output_spec(self, x, training=False):\n                return backend.KerasTensor(x.shape, dtype=x.dtype)\n\n        inputs = Input(shape=(3,), batch_size=2)\n        outputs = Canary()(inputs)\n        model = Functional(inputs, outputs)\n        model(np.random.random((2, 3)), training=True)\n\n    def test_mask_arg(self):\n        # TODO\n        pass\n\n    @pytest.mark.requires_trainable_backend\n    def test_passing_inputs_by_name(self):\n        input_a = Input(shape=(3,), batch_size=2, name=\"input_a\")\n        input_b = Input(shape=(3,), batch_size=2, name=\"input_b\")\n        x = input_a + input_b\n        x = layers.Dense(5)(x)\n        outputs = layers.Dense(4)(x)\n        model = Functional([input_a, input_b], outputs)\n\n        # Eager call\n        in_val = {\n            \"input_a\": np.random.random((2, 3)),\n            \"input_b\": np.random.random((2, 3)),\n        }\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n        # Symbolic call\n        input_a_2 = Input(shape=(3,), batch_size=2, name=\"input_a_2\")\n        input_b_2 = Input(shape=(3,), batch_size=2, name=\"input_b_2\")\n        in_val = {\"input_a\": input_a_2, \"input_b\": input_b_2}\n        out_val = model(in_val)\n        self.assertEqual(out_val.shape, (2, 4))\n\n    @pytest.mark.requires_trainable_backend\n    def test_rank_standardization(self):\n        # Downranking\n        inputs = Input(shape=(3,), batch_size=2)\n        outputs = layers.Dense(3)(inputs)\n        model = Functional(inputs, outputs)\n        
out_val = model(np.random.random((2, 3, 1)))\n        self.assertEqual(out_val.shape, (2, 3))\n\n        # Upranking\n        inputs = Input(shape=(3, 1), batch_size=2)\n        outputs = layers.Dense(3)(inputs)\n        model = Functional(inputs, outputs)\n        out_val = model(np.random.random((2, 3)))\n        self.assertEqual(out_val.shape, (2, 3, 3))\n\n    @pytest.mark.requires_trainable_backend\n    def test_rank_standardization_failure(self):\n        # Simple input and rank too high\n        inputs = Input(shape=(3,), name=\"foo\")\n        outputs = layers.Dense(3)(inputs)\n        model = Functional(inputs, outputs)\n        with self.assertRaisesRegex(ValueError, \"name 'foo' .* path ''\"):\n            model(np.random.random((2, 3, 4)))\n\n        # Deeply nested input and rank too low\n        inputs = [{\"foo\": Input(shape=(3,), name=\"my_input\")}]\n        outputs = layers.Dense(3)(inputs[0][\"foo\"])\n        model = Functional(inputs, outputs)\n        with self.assertRaisesRegex(\n            ValueError, \"name 'my_input' .* path '0.foo'\"\n        ):\n            model(np.random.random(()))\n\n    @pytest.mark.requires_trainable_backend\n    def test_dtype_standardization(self):\n        float_input = Input(shape=(2,), dtype=\"float16\")\n        int_input = Input(shape=(2,), dtype=\"int32\")\n        float_output = float_input + 2\n        int_output = int_input + 2\n        model = Functional((float_input, int_input), (float_output, int_output))\n        float_data, int_data = model((np.ones((2, 2)), np.ones((2, 2))))\n\n        self.assertEqual(backend.standardize_dtype(float_data.dtype), \"float16\")\n        self.assertEqual(backend.standardize_dtype(int_data.dtype), \"int32\")\n\n    @pytest.mark.requires_trainable_backend\n    def test_serialization(self):\n        # Test basic model\n        inputs = Input(shape=(3,), batch_size=2)\n        outputs = layers.Dense(3)(inputs)\n        model = Functional(inputs, outputs, 
trainable=False)\n        self.run_class_serialization_test(model)\n\n        # Test multi-io model\n        input_a = Input(shape=(3,), batch_size=2, name=\"input_a\")\n        input_b = Input(shape=(3,), batch_size=2, name=\"input_b\")\n        xa = layers.Dense(5, name=\"middle_a\")(input_a)\n        xb = layers.Dense(5, name=\"middle_b\")(input_b)\n        output_a = layers.Dense(4, name=\"output_a\")(xa)\n        output_b = layers.Dense(4, name=\"output_b\")(xb)\n        model = Functional(\n            [input_a, input_b], [output_a, output_b], name=\"func\"\n        )\n        self.run_class_serialization_test(model)\n\n        # Test model that includes floating ops\n        input_a = Input(shape=(3,), batch_size=2, name=\"input_a\")\n        input_b = Input(shape=(3,), batch_size=2, name=\"input_b\")\n        x = input_a + input_b\n        x = layers.Dense(5, name=\"middle\")(x)\n        output_a = layers.Dense(4, name=\"output_a\")(x)\n        output_b = layers.Dense(4, name=\"output_b\")(x)\n        model = Functional(\n            [input_a, input_b], [output_a, output_b], name=\"func\"\n        )\n        self.run_class_serialization_test(model)\n\n        # Test model with dict i/o\n        input_a = Input(shape=(3,), batch_size=2, name=\"a\")\n        input_b = Input(shape=(3,), batch_size=2, name=\"b\")\n        x = input_a + input_b\n        x = layers.Dense(5)(x)\n        outputs = layers.Dense(4)(x)\n        model = Functional({\"a\": input_a, \"b\": input_b}, outputs)\n        self.run_class_serialization_test(model)\n\n        # Test model with unmodified input as output\n        input_a = Input(shape=(3,), batch_size=2, name=\"a\")\n        input_b = Input(shape=(3,), batch_size=2, name=\"b\")\n        output_a = input_a * 2\n        output_b = input_b\n        model = Functional(\n            {\"a\": input_a, \"b\": input_b}, {\"a\": output_a, \"b\": output_b}\n        )\n        self.run_class_serialization_test(model)\n\n        # Test model 
with unused input\n        input_a = Input(shape=(3,), batch_size=2, name=\"a\")\n        input_b = Input(shape=(3,), batch_size=2, name=\"b\")\n        output_a = input_a * 2\n        model = Functional({\"a\": input_a, \"b\": input_b}, output_a)\n        self.run_class_serialization_test(model)\n\n    @pytest.mark.requires_trainable_backend\n    def test_bad_input_spec(self):\n        # Single input\n        inputs = Input(shape=(4,))\n        outputs = layers.Dense(2)(inputs)\n        model = Functional(inputs, outputs)\n        with self.assertRaisesRegex(\n            ValueError, r\"expected shape=\\(None, 4\\), found shape=\\(2, 3\\)\"\n        ):\n            model(np.zeros((2, 3)))\n        with self.assertRaisesRegex(ValueError, \"expects 1 input\"):\n            model([np.zeros((2, 4)), np.zeros((2, 4))])\n\n        # List input\n        input_a = Input(shape=(4,), name=\"a\")\n        input_b = Input(shape=(4,), name=\"b\")\n        x = input_a + input_b\n        outputs = layers.Dense(2)(x)\n        model = Functional([input_a, input_b], outputs)\n        with self.assertRaisesRegex(ValueError, \"expects 2 input\"):\n            model(np.zeros((2, 3)))\n        with self.assertRaisesRegex(\n            ValueError, r\"expected shape=\\(None, 4\\), found shape=\\(2, 3\\)\"\n        ):\n            model([np.zeros((2, 3)), np.zeros((2, 4))])\n\n        # Dict input\n        model = Functional({\"a\": input_a, \"b\": input_b}, outputs)\n        with self.assertRaisesRegex(ValueError, \"expects 2 input\"):\n            model(np.zeros((2, 3)))\n        with self.assertRaisesRegex(\n            ValueError, r\"expected shape=\\(None, 4\\), found shape=\\(2, 3\\)\"\n        ):\n            model({\"a\": np.zeros((2, 3)), \"b\": np.zeros((2, 4))})\n\n    @pytest.mark.requires_trainable_backend\n    def test_manual_input_spec(self):\n        inputs = Input(shape=(None, 3))\n        outputs = layers.Dense(2)(inputs)\n        model = Functional(inputs, outputs)\n    
    model.input_spec = InputSpec(shape=(None, 4, 3))\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"expected shape=\\(None, 4, 3\\), found shape=\\(2, 3, 3\\)\",\n        ):\n            model(np.zeros((2, 3, 3)))\n        model(np.zeros((2, 4, 3)))\n\n    def test_functional_slicing(self):\n        inputs = Input(shape=(None, 2), name=\"input\")\n        x1 = layers.Dense(3, name=\"dense1\")(inputs)\n        x2 = layers.Dense(4, name=\"dense2\")(x1)\n        outputs = layers.Dense(5, name=\"dense3\")(x2)\n\n        full_model = Functional(inputs, outputs, name=\"full_model\")\n        self.assertLen(full_model.layers, 4)\n\n        partial_model_1 = Functional(x2, outputs, name=\"partial1\")\n        self.assertLen(partial_model_1.layers, 2)  # input_layer, dense3\n        self.assertIsInstance(partial_model_1.layers[0], layers.InputLayer)\n        self.assertEqual(partial_model_1.layers[1].name, \"dense3\")\n\n        partial_model_2 = Functional(x1, x2, name=\"partial2\")\n        self.assertLen(partial_model_2.layers, 2)  # input_layer, dense2\n        self.assertIsInstance(partial_model_2.layers[0], layers.InputLayer)\n        self.assertEqual(partial_model_2.layers[1].name, \"dense2\")\n\n        partial_model_3 = Functional(\n            full_model.get_layer(\"dense2\").input, outputs, name=\"partial3\"\n        )\n        self.assertLen(partial_model_3.layers, 3)  # input_layer, dense2, dense3\n        self.assertIsInstance(partial_model_3.layers[0], layers.InputLayer)\n        self.assertEqual(partial_model_3.layers[1].name, \"dense2\")\n        self.assertEqual(partial_model_3.layers[2].name, \"dense3\")\n\n        partial_model_4 = Functional(\n            full_model.get_layer(\"dense1\").input,\n            full_model.get_layer(\"dense2\").output,\n            name=\"partial4\",\n        )\n        self.assertLen(partial_model_4.layers, 3)  # input_layer, dense1, dense2\n        
self.assertIsInstance(partial_model_4.layers[0], layers.InputLayer)\n        self.assertEqual(partial_model_4.layers[1].name, \"dense1\")\n        self.assertEqual(partial_model_4.layers[2].name, \"dense2\")\n\n    def test_deeply_nested_model(self):\n        i1, i2, i3 = Input((1,)), Input((2,)), Input((3,))\n        o1, o2, o3 = (\n            layers.Dense(1)(i1),\n            layers.Dense(2)(i2),\n            layers.Dense(3)(i3),\n        )\n        model = Model(\n            {\"1\": i1, \"others\": {\"2\": i2, \"3\": i3}},\n            {\"1\": o1, \"others\": {\"2\": o2, \"3\": o3}},\n        )\n        out_eager = model(\n            {\n                \"1\": np.ones((2, 1)),\n                \"others\": {\"2\": np.ones((2, 2)), \"3\": np.ones((2, 3))},\n            }\n        )\n        out_symbolic = model(\n            {\n                \"1\": Input((1,), batch_size=2),\n                \"others\": {\n                    \"2\": Input((2,), batch_size=2),\n                    \"3\": Input((3,), batch_size=2),\n                },\n            }\n        )\n        for out in [out_eager, out_symbolic]:\n            self.assertIsInstance(out, dict)\n            self.assertEqual(set(out.keys()), {\"1\", \"others\"})\n            self.assertEqual(out[\"1\"].shape, (2, 1))\n            self.assertIsInstance(out[\"others\"], dict)\n            self.assertEqual(set(out[\"others\"].keys()), {\"2\", \"3\"})\n            self.assertEqual(out[\"others\"][\"2\"].shape, (2, 2))\n            self.assertEqual(out[\"others\"][\"3\"].shape, (2, 3))\n\n        # Test serialization boundaries\n        temp_filepath = os.path.join(self.get_temp_dir(), \"deeply_nested.keras\")\n        model.save(temp_filepath)\n        loaded_model = saving.load_model(temp_filepath)\n        new_out_eager = loaded_model(\n            {\n                \"1\": np.ones((2, 1)),\n                \"others\": {\"2\": np.ones((2, 2)), \"3\": np.ones((2, 3))},\n            }\n        )\n        
self.assertAllClose(out_eager[\"1\"], new_out_eager[\"1\"])\n        self.assertAllClose(\n            out_eager[\"others\"][\"2\"], new_out_eager[\"others\"][\"2\"]\n        )\n        self.assertAllClose(\n            out_eager[\"others\"][\"3\"], new_out_eager[\"others\"][\"3\"]\n        )\n\n    def test_optional_inputs(self):\n        class OptionalInputLayer(layers.Layer):\n            def call(self, x, y=None):\n                if y is not None:\n                    return x + y\n                return x\n\n            def compute_output_shape(self, x_shape):\n                return x_shape\n\n        i1 = Input((2,))\n        i2 = Input((2,), optional=True)\n        outputs = OptionalInputLayer()(i1, i2)\n        model = Model([i1, i2], outputs)\n\n        # Eager test\n        out = model([np.ones((2, 2)), None])\n        self.assertAllClose(out, np.ones((2, 2)))\n        # Note: it's not intended to work in symbolic mode (yet).\n\n    def test_optional_dict_inputs(self):\n        class OptionalInputLayer(layers.Layer):\n            def call(self, x, y=None):\n                if y is not None:\n                    return x + y\n                return x\n\n            def compute_output_shape(self, x_shape):\n                return x_shape\n\n        i1 = Input((2,), name=\"input1\")\n        i2 = Input((2,), name=\"input2\", optional=True)\n        outputs = OptionalInputLayer()(i1, i2)\n        model = Model({\"input1\": i1, \"input2\": i2}, outputs)\n\n        # Eager test\n        out = model({\"input1\": np.ones((2, 2)), \"input2\": None})\n        self.assertAllClose(out, np.ones((2, 2)))\n        # Note: it's not intended to work in symbolic mode (yet).\n\n    def test_unmodified_inputs(self):\n        i1 = Input((2,), name=\"input1\")\n        i2 = Input((2,), name=\"input2\")\n        o1 = i1 * 2\n        o2 = i2\n        model = Model(\n            {\"input1\": i1, \"input2\": i2}, {\"output1\": o1, \"output2\": o2}\n        )\n\n        # Eager 
call\n        out = model({\"input1\": np.ones((2, 2)), \"input2\": np.zeros((2, 2))})\n        self.assertAllClose(out[\"output1\"], np.ones((2, 2)) * 2)\n        self.assertAllClose(out[\"output2\"], np.zeros((2, 2)))\n\n        # Symbolic call\n        i1_symbolic = Input((2,))\n        i2_symbolic = Input((2,))\n        out_symbolic = model({\"input1\": i1_symbolic, \"input2\": i2_symbolic})\n        self.assertIsInstance(out_symbolic, dict)\n        self.assertEqual(out_symbolic[\"output1\"].shape, (None, 2))\n        self.assertEqual(out_symbolic[\"output2\"].shape, (None, 2))\n\n    def test_unused_inputs(self):\n        i1 = Input((2,), name=\"input1\")\n        i2 = Input((2,), name=\"input2\")\n        o1 = i1 * 2\n        model = Model({\"input1\": i1, \"input2\": i2}, o1)\n\n        # Eager call\n        out = model({\"input1\": np.ones((2, 2)), \"input2\": np.zeros((2, 2))})\n        self.assertAllClose(out, np.ones((2, 2)) * 2)\n\n        # Symbolic call\n        i1_symbolic = Input((2,))\n        i2_symbolic = Input((2,))\n        out_symbolic = model({\"input1\": i1_symbolic, \"input2\": i2_symbolic})\n        self.assertEqual(out_symbolic.shape, (None, 2))\n\n    def test_disconnected_output(self):\n        i1 = Input((2,), name=\"input1\")\n        i2 = Input((2,), name=\"input2\")\n        o1 = i1 * 2\n        o2 = i2 * 3\n        with self.assertRaisesRegex(\n            ValueError, \"Output with path `output2` is not connected\"\n        ):\n            Model(i1, {\"output1\": o1, \"output2\": o2})\n\n    def test_warning_for_mismatched_inputs_structure(self):\n        def is_input_warning(w):\n            return str(w.message).startswith(\n                \"The structure of `inputs` doesn't match the expected structure\"\n            )\n\n        i1 = Input((2,))\n        i2 = Input((2,))\n        outputs = layers.Add()([i1, i2])\n\n        model = Model({\"i1\": i1, \"i2\": i2}, outputs)\n        with pytest.warns() as warning_logs:\n         
   model.predict([np.ones((2, 2)), np.zeros((2, 2))], verbose=0)\n            self.assertLen(list(filter(is_input_warning, warning_logs)), 1)\n        # No warning for mismatched tuples and lists.\n        model = Model([i1, i2], outputs)\n        with warnings.catch_warnings(record=True) as warning_logs:\n            model.predict((np.ones((2, 2)), np.zeros((2, 2))), verbose=0)\n            self.assertLen(list(filter(is_input_warning, warning_logs)), 0)\n\n    def test_for_functional_in_sequential(self):\n        # Test for a v3.4.1 regression.\n        if backend.image_data_format() == \"channels_first\":\n            image_size = (3, 100, 100)\n        else:\n            image_size = (100, 100, 3)\n        base_model = applications.mobilenet.MobileNet(\n            include_top=False, weights=None\n        )\n        model = Sequential()\n        model.add(layers.Input(shape=image_size))\n        model.add(base_model)\n        model.add(layers.GlobalAveragePooling2D())\n        model.add(layers.Dense(7, activation=\"softmax\"))\n        config = model.get_config()\n        model = Sequential.from_config(config)\n\n    def test_add_loss(self):\n        # TODO\n        pass\n\n    def test_layers_setter(self):\n        inputs = Input(shape=(3,), batch_size=2, name=\"input\")\n        outputs = layers.Dense(5)(inputs)\n        model = Functional(inputs, outputs)\n        with self.assertRaisesRegex(\n            AttributeError, \"`Model.layers` attribute is reserved\"\n        ):\n            model.layers = [layers.Dense(4)]\n\n    @pytest.mark.requires_trainable_backend\n    def test_dict_input_to_list_model(self):\n        vocabulary_size = 100\n        num_tags = 10\n        num_departments = 3\n        num_samples = 128\n\n        title = layers.Input(shape=(vocabulary_size,), name=\"title\")\n        text_body = layers.Input(shape=(vocabulary_size,), name=\"text_body\")\n        tags = layers.Input(shape=(num_tags,), name=\"tags\")\n        features = 
layers.Concatenate()([title, text_body, tags])\n        features = layers.Dense(64, activation=\"relu\")(features)\n        priority = layers.Dense(1, activation=\"sigmoid\", name=\"priority\")(\n            features\n        )\n        department = layers.Dense(\n            num_departments, activation=\"softmax\", name=\"department\"\n        )(features)\n        model = Functional(\n            inputs=[title, text_body, tags], outputs=[priority, department]\n        )\n\n        title_data = np.random.randint(\n            0, 2, size=(num_samples, vocabulary_size)\n        )\n        text_body_data = np.random.randint(\n            0, 2, size=(num_samples, vocabulary_size)\n        )\n        tags_data = np.random.randint(0, 2, size=(num_samples, num_tags))\n        priority_data = np.random.random(size=(num_samples, 1))\n        department_data = np.random.randint(\n            0, 2, size=(num_samples, num_departments)\n        )\n\n        # List style fit\n        model.compile(\n            optimizer=\"adam\",\n            loss=[\"mean_squared_error\", \"categorical_crossentropy\"],\n            metrics=[[\"mean_absolute_error\"], [\"accuracy\"]],\n        )\n        model.fit(\n            [title_data, text_body_data, tags_data],\n            [priority_data, department_data],\n            epochs=1,\n        )\n        model.evaluate(\n            [title_data, text_body_data, tags_data],\n            [priority_data, department_data],\n        )\n        priority_preds, department_preds = model.predict(\n            [title_data, text_body_data, tags_data]\n        )\n\n        # Dict style fit\n        model.compile(\n            optimizer=\"adam\",\n            loss={\n                \"priority\": \"mean_squared_error\",\n                \"department\": \"categorical_crossentropy\",\n            },\n            metrics={\n                \"priority\": [\"mean_absolute_error\"],\n                \"department\": [\"accuracy\"],\n            },\n        )\n    
    model.fit(\n            {\n                \"title\": title_data,\n                \"text_body\": text_body_data,\n                \"tags\": tags_data,\n            },\n            {\"priority\": priority_data, \"department\": department_data},\n            epochs=1,\n        )\n        model.evaluate(\n            {\n                \"title\": title_data,\n                \"text_body\": text_body_data,\n                \"tags\": tags_data,\n            },\n            {\"priority\": priority_data, \"department\": department_data},\n        )\n        priority_preds, department_preds = model.predict(\n            {\n                \"title\": title_data,\n                \"text_body\": text_body_data,\n                \"tags\": tags_data,\n            }\n        )\n\n    def test_list_input_with_dict_build(self):\n        x1 = Input((10,), name=\"IT\")\n        x2 = Input((10,), name=\"IS\")\n        y = layers.subtract([x1, x2])\n        model = Model(inputs={\"IT\": x1, \"IS\": x2}, outputs=y)\n        x1 = ops.ones((1, 10))\n        x2 = ops.zeros((1, 10))\n        # Works\n        _ = model({\"IT\": x1, \"IS\": x2})\n        with self.assertRaisesRegex(\n            ValueError,\n            \"The structure of `inputs` doesn't match the expected structure\",\n        ):\n            model([x1, x2])\n\n    def test_functional_with_dtype_policy(self):\n        original_dtype_policy = dtype_policy.dtype_policy()\n        try:\n            dtype_policy.set_dtype_policy(\"mixed_float16\")\n\n            inputs = Input((10,), name=\"input\")\n            outputs = layers.Dense(5)(inputs)\n            model = Model(inputs=inputs, outputs=outputs)\n\n            # Verify that no cast node appears in the graph.\n            self.assertLen(model.operations, 2)\n            self.assertIsInstance(model.operations[0], layers.InputLayer)\n            self.assertIsInstance(model.operations[1], layers.Dense)\n        finally:\n            
dtype_policy.set_dtype_policy(original_dtype_policy)\n"
  },
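The `test_list_input_with_dict_build` case above asserts that a model built from a dict of inputs rejects a positional list. A minimal pure-Python sketch of that structure check (an illustration only, not the actual Keras tree-matching code):

```python
def check_input_structure(expected, got):
    # A dict-built model expects a dict with the same keys; anything else
    # (e.g. a plain list) is rejected, mirroring the error message the
    # test asserts on.
    if isinstance(expected, dict):
        if not (isinstance(got, dict) and set(got) == set(expected)):
            raise ValueError(
                "The structure of `inputs` doesn't match the expected"
                " structure"
            )

expected = {"IT": None, "IS": None}
check_input_structure(expected, {"IT": 1, "IS": 2})  # accepted
try:
    check_input_structure(expected, [1, 2])  # list where a dict is expected
except ValueError as e:
    print(type(e).__name__)  # ValueError
```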
  {
    "path": "keras/src/models/model.py",
    "content": "import inspect\nimport json\nimport typing\nimport warnings\nfrom collections.abc import Callable\n\nfrom keras.src import backend\nfrom keras.src import utils\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.layer import Layer\nfrom keras.src.models.variable_mapping import map_saveable_variables\nfrom keras.src.quantizers.awq_core import awq_quantize\nfrom keras.src.quantizers.gptq_core import gptq_quantize\nfrom keras.src.quantizers.utils import should_quantize_layer\nfrom keras.src.saving import saving_api\nfrom keras.src.trainers import trainer as base_trainer\nfrom keras.src.utils import summary_utils\nfrom keras.src.utils import traceback_utils\n\nif backend.backend() == \"tensorflow\":\n    from keras.src.backend.tensorflow.trainer import (\n        TensorFlowTrainer as Trainer,\n    )\nelif backend.backend() == \"jax\":\n    from keras.src.backend.jax.trainer import JAXTrainer as Trainer\nelif backend.backend() == \"torch\":\n    from keras.src.backend.torch.trainer import TorchTrainer as Trainer\nelif backend.backend() == \"numpy\":\n    from keras.src.backend.numpy.trainer import NumpyTrainer as Trainer\nelif backend.backend() == \"openvino\":\n    from keras.src.backend.openvino.trainer import OpenVINOTrainer as Trainer\nelse:\n    raise RuntimeError(\n        f\"Backend '{backend.backend()}' must implement the Trainer class.\"\n    )\n\n\n@keras_export([\"keras.Model\", \"keras.models.Model\"])\nclass Model(Trainer, base_trainer.Trainer, Layer):\n    \"\"\"A model grouping layers into an object with training/inference features.\n\n    There are three ways to instantiate a `Model`:\n\n    ## With the \"Functional API\"\n\n    You start from `Input`,\n    you chain layer calls to specify the model's forward pass,\n    and finally, you create your model from inputs and outputs:\n\n    ```python\n    inputs = keras.Input(shape=(37,))\n    x = keras.layers.Dense(32, activation=\"relu\")(inputs)\n    outputs = 
keras.layers.Dense(5, activation=\"softmax\")(x)\n    model = keras.Model(inputs=inputs, outputs=outputs)\n    ```\n\n    Note: Only dicts, lists, and tuples of input tensors are supported. Nested\n    inputs are not supported (e.g. lists of lists or dicts of dicts).\n\n    A new Functional API model can also be created by using the\n    intermediate tensors. This enables you to quickly extract sub-components\n    of the model.\n\n    Example:\n\n    ```python\n    inputs = keras.Input(shape=(None, None, 3))\n    processed = keras.layers.RandomCrop(width=128, height=128)(inputs)\n    conv = keras.layers.Conv2D(filters=32, kernel_size=3)(processed)\n    pooling = keras.layers.GlobalAveragePooling2D()(conv)\n    feature = keras.layers.Dense(10)(pooling)\n\n    full_model = keras.Model(inputs, feature)\n    backbone = keras.Model(processed, conv)\n    activations = keras.Model(conv, feature)\n    ```\n\n    Note that the `backbone` and `activations` models are not\n    created with `keras.Input` objects, but with the tensors that originate\n    from `keras.Input` objects. 
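The sub-model extraction example above relies on layer objects being shared between models. A toy, framework-free sketch of that sharing (plain Python objects standing in for layers; not Keras internals):

```python
class ToyDense:
    """Stand-in for a layer: its 'weight' is created once and reused
    wherever the same object is called."""

    def __init__(self, factor):
        self.factor = factor  # the shared state ("weights")

    def __call__(self, x):
        return x * self.factor

hidden = ToyDense(2)
head = ToyDense(3)

def full_model(x):  # analogous to keras.Model(inputs, feature)
    return head(hidden(x))

def backbone(x):  # analogous to keras.Model(processed, conv)
    return hidden(x)

assert full_model(1) == 6
hidden.factor = 5  # "training" updates the shared weight...
assert backbone(1) == 5  # ...and the backbone sees the update
```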
Under the hood, the layers and weights will\n    be shared across these models, so that the user can train the\n    `full_model`, and use `backbone` or `activations` to do feature extraction.\n    The inputs and outputs of the model can be nested structures of tensors as\n    well, and the created models are standard Functional API models that support\n    all the existing APIs.\n\n    ## By subclassing the `Model` class\n\n    In that case, you should define your\n    layers in `__init__()` and you should implement the model's forward pass\n    in `call()`.\n\n    ```python\n    class MyModel(keras.Model):\n        def __init__(self):\n            super().__init__()\n            self.dense1 = keras.layers.Dense(32, activation=\"relu\")\n            self.dense2 = keras.layers.Dense(5, activation=\"softmax\")\n\n        def call(self, inputs):\n            x = self.dense1(inputs)\n            return self.dense2(x)\n\n    model = MyModel()\n    ```\n\n    If you subclass `Model`, you can optionally have\n    a `training` argument (boolean) in `call()`, which you can use to specify\n    a different behavior in training and inference:\n\n    ```python\n    class MyModel(keras.Model):\n        def __init__(self):\n            super().__init__()\n            self.dense1 = keras.layers.Dense(32, activation=\"relu\")\n            self.dense2 = keras.layers.Dense(5, activation=\"softmax\")\n            self.dropout = keras.layers.Dropout(0.5)\n\n        def call(self, inputs, training=False):\n            x = self.dense1(inputs)\n            x = self.dropout(x, training=training)\n            return self.dense2(x)\n\n    model = MyModel()\n    ```\n\n    Once the model is created, you can configure the model with losses and\n    metrics with `model.compile()`, train the model with `model.fit()`, or use\n    the model to do prediction with `model.predict()`.\n\n    ## With the `Sequential` class\n\n    In addition, `keras.Sequential` is a special case of a model where\n    the model 
is purely a stack of single-input, single-output layers.\n\n    ```python\n    model = keras.Sequential([\n        keras.Input(shape=(None, None, 3)),\n        keras.layers.Conv2D(filters=32, kernel_size=3),\n    ])\n    ```\n    \"\"\"\n\n    def __new__(cls, *args, **kwargs):\n        # Signature detection for usage of `Model` as a `Functional`\n        if functional_init_arguments(args, kwargs) and cls == Model:\n            from keras.src.models.functional import Functional\n\n            return Functional.__new__(Functional, *args, **kwargs)\n        return typing.cast(cls, super().__new__(cls))\n\n    def __init__(self, *args, **kwargs):\n        Trainer.__init__(self)\n        from keras.src.models import functional\n\n        # Signature detection for usage of a `Model` subclass\n        # as a `Functional` subclass\n        if functional_init_arguments(args, kwargs):\n            inject_functional_model_class(self.__class__)\n            functional.Functional.__init__(self, *args, **kwargs)\n        else:\n            Layer.__init__(self, *args, **kwargs)\n\n    def call(self, *args, **kwargs):\n        raise NotImplementedError(\n            f\"Model {self.__class__.__name__} does not have a `call()` \"\n            \"method implemented.\"\n        )\n\n    @property\n    def layers(self):\n        return list(self._flatten_layers(include_self=False, recursive=False))\n\n    @layers.setter\n    def layers(self, _):\n        raise AttributeError(\n            \"`Model.layers` attribute is reserved and should not be used. 
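The `__new__` override above dispatches plain `Model(inputs=..., outputs=...)` calls to `Functional`. A stripped-down sketch of that signature-detection pattern, using toy classes rather than the real ones:

```python
class Model:
    def __new__(cls, *args, **kwargs):
        # Signature detection: constructing the *base* class with
        # inputs/outputs really means "build a Functional model".
        if cls is Model and "inputs" in kwargs and "outputs" in kwargs:
            return super().__new__(Functional)
        return super().__new__(cls)

class Functional(Model):
    def __init__(self, inputs=None, outputs=None):
        self.inputs, self.outputs = inputs, outputs

class MyModel(Model):
    # A subclass constructed without functional-style arguments
    # stays an instance of the subclass.
    pass

m = Model(inputs="x", outputs="y")
print(type(m).__name__)  # Functional
print(type(MyModel()).__name__)  # MyModel
```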
\"\n            \"Please use another name.\"\n        )\n\n    @traceback_utils.filter_traceback\n    def get_layer(self, name=None, index=None):\n        \"\"\"Retrieves a layer based on either its name (unique) or index.\n\n        If `name` and `index` are both provided, `index` will take precedence.\n        Indices are based on order of horizontal graph traversal (bottom-up).\n\n        Args:\n            name: String, name of layer.\n            index: Integer, index of layer.\n\n        Returns:\n            A layer instance.\n        \"\"\"\n        if index is not None and name is not None:\n            raise ValueError(\n                \"Provide only a layer name or a layer index. Received: \"\n                f\"index={index}, name={name}.\"\n            )\n        if index is not None:\n            if len(self.layers) <= index:\n                raise ValueError(\n                    f\"Was asked to retrieve layer at index {index}\"\n                    f\" but model only has {len(self.layers)}\"\n                    \" layers.\"\n                )\n            else:\n                return self.layers[index]\n\n        if name is not None:\n            for layer in self.layers:\n                if layer.name == name:\n                    return layer\n            raise ValueError(\n                f\"No such layer: {name}. Existing layers are: \"\n                f\"{list(layer.name for layer in self.layers)}.\"\n            )\n        raise ValueError(\n            \"Provide either a layer name or layer index at `get_layer`.\"\n        )\n\n    @traceback_utils.filter_traceback\n    def summary(\n        self,\n        line_length=None,\n        positions=None,\n        print_fn=None,\n        expand_nested=False,\n        show_trainable=False,\n        layer_range=None,\n    ):\n        \"\"\"Prints a string summary of the network.\n\n        Args:\n            line_length: Total length of printed lines\n                (e.g. 
set this to adapt the display to different\n                terminal window sizes).\n            positions: Relative or absolute positions of log elements\n                in each line. If not provided, becomes\n                `[0.3, 0.6, 0.70, 1.]`. Defaults to `None`.\n            print_fn: Print function to use. By default, prints to `stdout`.\n                If `stdout` doesn't work in your environment, change to `print`.\n                It will be called on each line of the summary.\n                You can set it to a custom function\n                in order to capture the string summary.\n            expand_nested: Whether to expand the nested models.\n                Defaults to `False`.\n            show_trainable: Whether to show if a layer is trainable.\n                Defaults to `False`.\n            layer_range: a list or tuple of 2 strings,\n                which is the starting layer name and ending layer name\n                (both inclusive) indicating the range of layers to be printed\n                in summary. It also accepts regex patterns instead of exact\n                names. 
In this case, the start predicate will be\n                the first element that matches `layer_range[0]`\n                and the end predicate will be the last element\n                that matches `layer_range[1]`.\n                By default `None` considers all layers of the model.\n\n        Raises:\n            ValueError: if `summary()` is called before the model is built.\n        \"\"\"\n        summary_utils.print_summary(\n            self,\n            line_length=line_length,\n            positions=positions,\n            print_fn=print_fn,\n            expand_nested=expand_nested,\n            show_trainable=show_trainable,\n            layer_range=layer_range,\n        )\n\n    @traceback_utils.filter_traceback\n    def save(self, filepath, overwrite=True, zipped=None, **kwargs):\n        \"\"\"Saves a model as a `.keras` file.\n\n        Note that `model.save()` is an alias for `keras.saving.save_model()`.\n\n        The saved `.keras` file contains:\n\n        - The model's configuration (architecture)\n        - The model's weights\n        - The model's optimizer's state (if any)\n\n        Thus models can be reinstantiated in the exact same state.\n\n        Args:\n            filepath: `str` or `pathlib.Path` object.\n                The path where to save the model. 
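The start/end semantics of `layer_range` described above can be sketched as a small lookup: first match of the start pattern, last match of the end pattern, both inclusive. This is illustrative only; the real resolution happens inside `summary_utils.print_summary`:

```python
import re

def resolve_layer_range(layer_names, layer_range):
    # Start index: FIRST name matching layer_range[0].
    # End index: LAST name matching layer_range[1]. Both inclusive.
    start_re, end_re = (re.compile(p) for p in layer_range)
    starts = [i for i, n in enumerate(layer_names) if start_re.match(n)]
    ends = [i for i, n in enumerate(layer_names) if end_re.match(n)]
    if not starts or not ends:
        raise ValueError(f"No layers match layer_range={layer_range}")
    return starts[0], ends[-1]

names = ["conv_1", "conv_2", "pool", "dense_1", "dense_2"]
print(resolve_layer_range(names, ["conv_2", "dense_.*"]))  # (1, 4)
```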
Must end in `.keras`\n                (unless saving the model as an unzipped directory\n                via `zipped=False`).\n            overwrite: Whether we should overwrite any existing model at\n                the target location, or instead ask the user via\n                an interactive prompt.\n            zipped: Whether to save the model as a zipped `.keras`\n                archive (default when saving locally), or as an\n                unzipped directory (default when saving on the\n                Hugging Face Hub).\n\n        Example:\n\n        ```python\n        model = keras.Sequential(\n            [\n                keras.layers.Dense(5, input_shape=(3,)),\n                keras.layers.Softmax(),\n            ],\n        )\n        model.save(\"model.keras\")\n        loaded_model = keras.saving.load_model(\"model.keras\")\n        x = keras.random.uniform((10, 3))\n        assert np.allclose(model.predict(x), loaded_model.predict(x))\n        ```\n        \"\"\"\n        return saving_api.save_model(\n            self, filepath, overwrite=overwrite, zipped=zipped, **kwargs\n        )\n\n    @traceback_utils.filter_traceback\n    def save_weights(self, filepath, overwrite=True, max_shard_size=None):\n        \"\"\"Saves all weights to a single file or sharded files.\n\n        By default, the weights will be saved in a single `.weights.h5` file.\n        If sharding is enabled (`max_shard_size` is not `None`), the weights\n        will be saved in multiple files, each with a size at most\n        `max_shard_size` (in GB). Additionally, a configuration file\n        `.weights.json` will contain the metadata for the sharded files.\n\n        The saved sharded files contain:\n\n        - `*.weights.json`: The configuration file containing 'metadata' and\n            'weight_map'.\n        - `*_xxxxxx.weights.h5`: The sharded files containing only the\n            weights.\n\n        Args:\n            filepath: `str` or `pathlib.Path` object. 
Path where the weights\n                will be saved.  When sharding, the filepath must end in\n                `.weights.json`. If `.weights.h5` is provided, it will be\n                overridden.\n            overwrite: Whether to overwrite any existing weights at the target\n                location or instead ask the user via an interactive prompt.\n            max_shard_size: `int` or `float`. Maximum size in GB for each\n                sharded file. If `None`, no sharding will be done. Defaults to\n                `None`.\n\n        Example:\n\n        ```python\n        # Instantiate an EfficientNetV2L model with about 454MB of weights.\n        model = keras.applications.EfficientNetV2L(weights=None)\n\n        # Save the weights in a single file.\n        model.save_weights(\"model.weights.h5\")\n\n        # Save the weights in sharded files. Using `max_shard_size=0.25` means\n        # each sharded file will be at most ~250MB.\n        model.save_weights(\"model.weights.json\", max_shard_size=0.25)\n\n        # Load the weights in a new model with the same architecture.\n        loaded_model = keras.applications.EfficientNetV2L(weights=None)\n        loaded_model.load_weights(\"model.weights.h5\")\n        x = keras.random.uniform((1, 480, 480, 3))\n        assert np.allclose(model.predict(x), loaded_model.predict(x))\n\n        # Load the sharded weights in a new model with the same architecture.\n        loaded_model = keras.applications.EfficientNetV2L(weights=None)\n        loaded_model.load_weights(\"model.weights.json\")\n        x = keras.random.uniform((1, 480, 480, 3))\n        assert np.allclose(model.predict(x), loaded_model.predict(x))\n        ```\n        \"\"\"\n        return saving_api.save_weights(\n            self, filepath, overwrite=overwrite, max_shard_size=max_shard_size\n        )\n\n    @traceback_utils.filter_traceback\n    def load_weights(self, filepath, skip_mismatch=False, **kwargs):\n        \"\"\"Load the weights from a 
single file or sharded files.\n\n        Weights are loaded based on the network's topology. This means the\n        architecture should be the same as when the weights were saved. Note\n        that layers that don't have weights are not taken into account in the\n        topological ordering, so adding or removing layers is fine as long as\n        they don't have weights.\n\n        **Partial weight loading**\n\n        If you have modified your model, for instance by adding a new layer\n        (with weights) or by changing the shape of the weights of a layer, you\n        can choose to ignore errors and continue loading by setting\n        `skip_mismatch=True`. In this case any layer with mismatching weights\n        will be skipped. A warning will be displayed for each skipped layer.\n\n        **Sharding**\n\n        When loading sharded weights, it is important to specify a `filepath`\n        that ends with `*.weights.json`, which is used as the configuration\n        file. Additionally, the sharded files `*_xxxxx.weights.h5` must be in\n        the same directory as the configuration file.\n\n        Args:\n            filepath: `str` or `pathlib.Path` object. Path from which the\n                weights will be loaded.  
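The `max_shard_size` bound on sharded weight files can be illustrated with a greedy grouping sketch. This is an assumption made for illustration: the actual sharding policy lives in `saving_api.save_weights` and may group weights differently; only the size bound is what the docstring promises:

```python
def plan_shards(weight_sizes_gb, max_shard_size):
    """Greedily group (name, size_in_gb) pairs into shards whose total
    size stays at or under `max_shard_size`, mirroring the documented
    bound. Illustrative only, not the real saver's algorithm."""
    shards, current, current_size = [], [], 0.0
    for name, size in weight_sizes_gb:
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = [], 0.0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)
    return shards

sizes = [("w1", 0.10), ("w2", 0.12), ("w3", 0.08), ("w4", 0.20)]
print(plan_shards(sizes, max_shard_size=0.25))
# → [['w1', 'w2'], ['w3'], ['w4']]
```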
When sharding, the filepath must end in\n                `.weights.json`.\n            skip_mismatch: Boolean, whether to skip loading of layers where\n                there is a mismatch in the number of weights, or a mismatch in\n                the shape of the weights.\n\n        Example:\n\n        ```python\n        # Load the weights in a single file.\n        model.load_weights(\"model.weights.h5\")\n\n        # Load the weights in sharded files.\n        model.load_weights(\"model.weights.json\")\n        ```\n        \"\"\"\n        saving_api.load_weights(\n            self,\n            filepath,\n            skip_mismatch=skip_mismatch,\n            **kwargs,\n        )\n\n    def get_quantization_layer_structure(self, mode=None):\n        \"\"\"Returns the quantization structure for the model.\n\n        This method is intended to be overridden by model authors to provide\n        topology information required for structure-aware quantization modes\n        like 'gptq'.\n\n        Args:\n            mode: The quantization mode.\n\n        Returns:\n            A dictionary describing the topology, e.g.:\n            `{'pre_block_layers': [list], 'sequential_blocks': [list]}`\n            or `None` if the mode does not require structure or is not\n            supported. `'pre_block_layers'` is a list of layers that\n            the inputs should be passed through, before being passed to\n            the sequential blocks. 
For example, inputs to an LLM must\n            first be passed through an embedding layer, followed by\n            the transformer.\n        \"\"\"\n        del mode  # Unused.\n        return None\n\n    def quantize(self, mode=None, config=None, filters=None, **kwargs):\n        \"\"\"Quantize the weights of the model.\n\n        Note that the model must be built before calling this method.\n        `quantize` will recursively call `quantize(...)` in all layers and\n        will skip any layer that doesn't implement the function.\n\n        This method can be called by passing a `mode` string, which uses the\n        default configuration for that mode. Alternatively, a `config` object\n        can be passed to customize the behavior of the quantization (e.g. to\n        use specific quantizers for weights or activations).\n\n        Args:\n            mode: The mode of the quantization. Supported modes are:\n                `\"int8\"`, `\"int4\"`, `\"float8\"`, `\"gptq\"`, `\"awq\"`. This\n                is optional if `config` is provided.\n            config: The configuration object specifying additional\n                quantization options. This argument allows you to configure\n                the weight and activation quantizers. It must be an instance\n                of `keras.quantizers.QuantizationConfig`.\n            filters: Optional filters to apply to the quantization. Can be a\n                regex string, a list of regex strings, or a callable. 
Only the\n                layers which match the filter conditions will be quantized.\n            **kwargs: Additional keyword arguments.\n\n        Example:\n\n        Quantize a model to int8 with default configuration:\n\n        ```python\n        # Build the model\n        model = keras.Sequential([\n            keras.Input(shape=(10,)),\n            keras.layers.Dense(10),\n        ])\n        model.build((None, 10))\n\n        # Quantize with default int8 config\n        model.quantize(\"int8\")\n        ```\n\n        Quantize a model to int8 with a custom configuration:\n\n        ```python\n        from keras.quantizers import Int8QuantizationConfig\n        from keras.quantizers import AbsMaxQuantizer\n\n        # Build the model\n        model = keras.Sequential([\n            keras.Input(shape=(10,)),\n            keras.layers.Dense(10),\n        ])\n        model.build((None, 10))\n\n        # Create a custom config\n        config = Int8QuantizationConfig(\n            weight_quantizer=AbsMaxQuantizer(\n                axis=0,\n                value_range=(-127, 127)\n            ),\n            activation_quantizer=AbsMaxQuantizer(\n                axis=-1,\n                value_range=(-127, 127)\n            ),\n        )\n\n        # Quantize with custom config\n        model.quantize(config=config)\n        ```\n        \"\"\"\n        # Validate inputs.\n        type_check = kwargs.pop(\"type_check\", True)\n        if kwargs:\n            raise ValueError(\n                \"Unrecognized keyword arguments \"\n                f\"passed to {self.__class__.__name__}: {kwargs}\"\n            )\n\n        if filters is not None:\n            if not isinstance(filters, (str, Callable, list, tuple)):\n                raise ValueError(\n                    \"The `filters` argument must be a regex string, a list of \"\n                    \"regex strings, or a callable. 
Received: \"\n                    f\"{type(filters)}\"\n                )\n\n        graph_modified = False\n        for layer in self._flatten_layers():\n            # Apply filters\n            if not should_quantize_layer(layer, filters):\n                continue\n\n            if len(list(layer._flatten_layers())) == 1:\n                try:\n                    layer.quantize(mode, type_check=type_check, config=config)\n                    graph_modified = True\n                except NotImplementedError as e:\n                    warnings.warn(str(e))\n                except AttributeError:\n                    pass\n\n        if mode in [\"gptq\", \"awq\"]:\n            # Resolve model structure.\n            # 1. If quantization_layer_structure is provided inside the config,\n            # use that.\n            structure = config.quantization_layer_structure\n            # 2. If no layer structure is provided in the config, try to fetch\n            # it using the `get_quantization_layer_structure` hook.\n            if structure is None:\n                structure = self.get_quantization_layer_structure(mode)\n\n            if structure is None:\n                raise ValueError(\n                    f\"For {mode=}, a valid quantization structure must be \"\n                    \"provided either via `config.quantization_layer_structure` \"\n                    \"or by overriding \"\n                    \"`model.get_quantization_layer_structure(mode)`. 
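The `filters` contract validated above (a regex string, a list of regex strings, or a callable) can be sketched in isolation. This is a hypothetical helper written for illustration, loosely mirroring what `should_quantize_layer` is described to do, not its actual implementation:

```python
import re

def matches_filters(layer_path, filters):
    # None means no filtering: every layer qualifies.
    if filters is None:
        return True
    # A callable acts as a predicate on the layer's path.
    if callable(filters):
        return bool(filters(layer_path))
    # A single regex string, or a list/tuple of them: any match wins.
    patterns = [filters] if isinstance(filters, str) else list(filters)
    return any(re.search(p, layer_path) for p in patterns)

print(matches_filters("decoder/dense_1", "dense"))  # True
print(matches_filters("decoder/conv_1", ["dense", "embed"]))  # False
```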
The \"\n                    \"structure should be a dictionary with keys \"\n                    \"'pre_block_layers' and 'sequential_blocks'.\"\n                )\n            if mode == \"gptq\":\n                gptq_quantize(config, structure, filters=filters)\n            elif mode == \"awq\":\n                awq_quantize(config, structure, filters=filters)\n\n        # If any layer was changed, we must rebuild the execution functions.\n        if graph_modified:\n            self.train_function = None\n            self.test_function = None\n            self.predict_function = None\n            self._post_quantize(mode, **kwargs)\n\n    def _post_quantize(self, mode, **kwargs):\n        if backend.backend() == \"torch\":\n            # We need to manually retrack `torch_params`.\n            # The reason is that after quantization, the removed variables are\n            # still referenced by `torch_params` and cannot be gc.\n            for layer in self._flatten_layers():\n                layer._track_variables()\n\n    def build_from_config(self, config):\n        if not config:\n            return\n        status = False\n        if \"input_shape\" in config:\n            # Case: all inputs are in the first arg (possibly nested).\n            if utils.is_default(self.build):\n                status = self._build_by_run_for_single_pos_arg(\n                    config[\"input_shape\"]\n                )\n            else:\n                try:\n                    self.build(config[\"input_shape\"])\n                    status = True\n                except:\n                    pass\n            self._build_shapes_dict = config\n\n        elif \"shapes_dict\" in config:\n            # Case: inputs were recorded as multiple keyword arguments.\n            if utils.is_default(self.build):\n                status = self._build_by_run_for_kwargs(config[\"shapes_dict\"])\n            else:\n                try:\n                    
self.build(**config[\"shapes_dict\"])\n                    status = True\n                except:\n                    pass\n            self._build_shapes_dict = config[\"shapes_dict\"]\n\n        if not status:\n            warnings.warn(\n                f\"Model '{self.name}' had a build config, but the model \"\n                \"cannot be built automatically in \"\n                \"`build_from_config(config)`. \"\n                \"You should implement \"\n                \"`def build_from_config(self, config)`, \"\n                \"and you might also want to implement the method \"\n                \" that generates the config at saving time, \"\n                \"`def get_build_config(self)`. \"\n                \"The method `build_from_config()` is meant to \"\n                \"create the state of the model (i.e. its variables) \"\n                \"upon deserialization.\",\n                stacklevel=2,\n            )\n\n    def to_json(self, **kwargs):\n        \"\"\"Returns a JSON string containing the network configuration.\n\n        To load a network from a JSON save file, use\n        `keras.models.model_from_json(json_string, custom_objects={...})`.\n\n        Args:\n            **kwargs: Additional keyword arguments to be passed to\n                `json.dumps()`.\n\n        Returns:\n            A JSON string.\n        \"\"\"\n        from keras.src.saving import serialization_lib\n\n        model_config = serialization_lib.serialize_keras_object(self)\n        return json.dumps(model_config, **kwargs)\n\n    def export(\n        self,\n        filepath,\n        format=\"tf_saved_model\",\n        verbose=None,\n        input_signature=None,\n        **kwargs,\n    ):\n        \"\"\"Export the model as an artifact for inference.\n\n        Args:\n            filepath: `str` or `pathlib.Path` object. The path to save the\n                artifact.\n            format: `str`. The export format. 
Supported values:\n                `\"tf_saved_model\"`, `\"onnx\"`, `\"openvino\"`, and `\"litert\"`.\n                Defaults to `\"tf_saved_model\"`.\n            verbose: `bool`. Whether to print a message during export. Defaults\n                to `None`, which uses the default value set by different\n                backends and formats.\n            input_signature: Optional. Specifies the shape and dtype of the\n                model inputs. Can be a structure of `keras.InputSpec`,\n                `tf.TensorSpec`, `backend.KerasTensor`, or backend tensor. If\n                not provided, it will be automatically computed. Defaults to\n                `None`.\n            **kwargs: Additional keyword arguments.\n                - `is_static`: Optional `bool`. Specific to the JAX backend and\n                    `format=\"tf_saved_model\"`. Indicates whether `fn` is static.\n                    Set to `False` if `fn` involves state updates (e.g., RNG\n                    seeds and counters).\n                - `jax2tf_kwargs`: Optional `dict`. Specific to the JAX backend\n                    and `format=\"tf_saved_model\"`. Arguments for\n                    `jax2tf.convert`. See the documentation for\n                    [`jax2tf.convert`](\n                        https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md).\n                    If `native_serialization` and `polymorphic_shapes` are not\n                    provided, they will be automatically computed.\n                - `opset_version`: Optional `int`. Specific to `format=\"onnx\"`.\n                    An integer value that specifies the ONNX opset version.\n                - LiteRT-specific options: Optional keyword arguments specific\n                    to `format=\"litert\"`. 
These are passed directly to the\n                    TensorFlow Lite converter and include options like\n                    `optimizations`, `representative_dataset`,\n                    `experimental_new_quantizer`, `allow_custom_ops`,\n                    `enable_select_tf_ops`, etc. See TensorFlow Lite\n                    documentation for all available options.\n\n        **Note:** This feature is currently supported only with the TensorFlow,\n        JAX, and Torch backends.\n\n        **Note:** Be aware that the exported artifact may contain information\n        from the local file system when using `format=\"onnx\"`, `verbose=True`\n        and the Torch backend.\n\n        Examples:\n\n        Here's how to export a TensorFlow SavedModel for inference.\n\n        ```python\n        # Export the model as a TensorFlow SavedModel artifact\n        model.export(\"path/to/location\", format=\"tf_saved_model\")\n\n        # Load the artifact in a different process/environment\n        reloaded_artifact = tf.saved_model.load(\"path/to/location\")\n        predictions = reloaded_artifact.serve(input_data)\n        ```\n\n        Here's how to export an ONNX model for inference.\n\n        ```python\n        # Export the model as an ONNX artifact\n        model.export(\"path/to/location\", format=\"onnx\")\n\n        # Load the artifact in a different process/environment\n        ort_session = onnxruntime.InferenceSession(\"path/to/location\")\n        ort_inputs = {\n            k.name: v for k, v in zip(ort_session.get_inputs(), input_data)\n        }\n        predictions = ort_session.run(None, ort_inputs)\n        ```\n\n        Here's how to export a LiteRT (TFLite) model for inference.\n\n        ```python\n        # Export the model as a LiteRT artifact\n        model.export(\"path/to/location\", format=\"litert\")\n\n        # Load the artifact in a different process/environment\n        interpreter = tf.lite.Interpreter(model_path=\"path/to/location\")\n        
interpreter.allocate_tensors()\n        interpreter.set_tensor(\n            interpreter.get_input_details()[0]['index'], input_data\n        )\n        interpreter.invoke()\n        output_data = interpreter.get_tensor(\n            interpreter.get_output_details()[0]['index']\n        )\n        ```\n        \"\"\"\n        from keras.src.export import export_litert\n        from keras.src.export import export_onnx\n        from keras.src.export import export_openvino\n        from keras.src.export import export_saved_model\n\n        available_formats = (\"tf_saved_model\", \"onnx\", \"openvino\", \"litert\")\n        if format not in available_formats:\n            raise ValueError(\n                f\"Unrecognized format={format}. Supported formats are: \"\n                f\"{list(available_formats)}.\"\n            )\n\n        # Check if LiteRT export is available (requires TensorFlow backend)\n        if format == \"litert\" and backend.backend() != \"tensorflow\":\n            raise ImportError(\"LiteRT export requires TensorFlow backend.\")\n\n        if format == \"tf_saved_model\":\n            export_saved_model(\n                self,\n                filepath,\n                verbose,\n                input_signature=input_signature,\n                **kwargs,\n            )\n        elif format == \"onnx\":\n            export_onnx(\n                self,\n                filepath,\n                verbose,\n                input_signature=input_signature,\n                **kwargs,\n            )\n        elif format == \"openvino\":\n            export_openvino(\n                self,\n                filepath,\n                verbose,\n                input_signature=input_signature,\n                **kwargs,\n            )\n        elif format == \"litert\":\n            export_litert(\n                self,\n                filepath,\n                input_signature=input_signature,\n                **kwargs,\n            )\n\n    
@classmethod\n    def from_config(cls, config, custom_objects=None):\n        from keras.src.models.functional import Functional\n\n        functional_config_keys = [\n            \"name\",\n            \"layers\",\n            \"input_layers\",\n            \"output_layers\",\n        ]\n        is_functional_config = all(\n            key in config for key in functional_config_keys\n        )\n        argspec = inspect.getfullargspec(cls.__init__)\n        functional_init_args = inspect.getfullargspec(Functional.__init__).args[\n            1:\n        ]\n        revivable_as_functional = (\n            cls in {Functional, Model}\n            or argspec.args[1:] == functional_init_args\n            or (argspec.varargs == \"args\" and argspec.varkw == \"kwargs\")\n        )\n        if is_functional_config and revivable_as_functional:\n            # Revive Functional model\n            # (but not Functional subclasses with a custom __init__)\n            from keras.src.models.functional import functional_from_config\n\n            return functional_from_config(\n                cls, config, custom_objects=custom_objects\n            )\n\n        # Either the model has a custom __init__, or the config\n        # does not contain all the information necessary to\n        # revive a Functional model. This happens when the user creates\n        # subclassed models where `get_config()` is returning\n        # insufficient information to be considered a Functional model.\n        # In this case, we fall back to provide all config into the\n        # constructor of the class.\n        try:\n            return cls(**config)\n        except TypeError as e:\n            raise TypeError(\n                \"Unable to revive model from config. 
When overriding \"\n                \"the `get_config()` method, make sure that the \"\n                \"returned config contains all items used as arguments \"\n                f\"in the constructor to {cls}, \"\n                \"which is the default behavior. \"\n                \"You can override this default behavior by defining a \"\n                \"`from_config(cls, config)` class method to specify \"\n                \"how to create an \"\n                f\"instance of {cls.__name__} from its config.\\n\\n\"\n                f\"Received config={config}\\n\\n\"\n                f\"Error encountered during deserialization: {e}\"\n            )\n\n    def _get_variable_map(self):\n        store = {}\n        map_saveable_variables(self, store=store, visited_saveables=set())\n        return store\n\n    def get_state_tree(self, value_format=\"backend_tensor\"):\n        \"\"\"Retrieves a tree-like structure of the model's variables.\n\n        This method allows retrieval of different model variables (trainable,\n        non-trainable, optimizer, and metrics). The variables are returned in a\n        nested dictionary format, where the keys correspond to the variable\n        names and the values are the nested representations of the variables.\n\n        Args:\n            value_format: One of `\"backend_tensor\"`, `\"numpy_array\"`.\n                The kind of array to return as the leaves of the nested\n                state tree.\n\n        Returns:\n            dict: A dictionary containing the nested representations of the\n                requested variables. 
The keys are the variable names, and the\n                values are the corresponding nested dictionaries.\n\n        Example:\n\n        ```python\n        model = keras.Sequential([\n            keras.Input(shape=(1,), name=\"my_input\"),\n            keras.layers.Dense(1, activation=\"sigmoid\", name=\"my_dense\"),\n        ], name=\"my_sequential\")\n        model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n        model.fit(np.array([[1.0]]), np.array([[1.0]]))\n        state_tree = model.get_state_tree()\n        ```\n\n        The `state_tree` dictionary returned looks like:\n\n        ```\n        {\n            'metrics_variables': {\n                'loss': {\n                    'count': ...,\n                    'total': ...,\n                },\n                'mean_absolute_error': {\n                    'count': ...,\n                    'total': ...,\n                }\n            },\n            'trainable_variables': {\n                'my_sequential': {\n                    'my_dense': {\n                        'bias': ...,\n                        'kernel': ...,\n                    }\n                }\n            },\n            'non_trainable_variables': {},\n            'optimizer_variables': {\n                'adam': {\n                    'iteration': ...,\n                    'learning_rate': ...,\n                    'my_sequential_my_dense_bias_momentum': ...,\n                    'my_sequential_my_dense_bias_velocity': ...,\n                    'my_sequential_my_dense_kernel_momentum': ...,\n                    'my_sequential_my_dense_kernel_velocity': ...,\n                }\n            }\n        }\n        ```\n        \"\"\"\n        variables = {}\n        
variables[\"trainable_variables\"] = self._create_nested_dict(\n            self.trainable_variables, value_format\n        )\n        variables[\"non_trainable_variables\"] = self._create_nested_dict(\n            self.non_trainable_variables, value_format\n        )\n        variables[\"optimizer_variables\"] = self._create_nested_dict(\n            self.optimizer.variables, value_format\n        )\n        variables[\"metrics_variables\"] = self._create_nested_dict(\n            self.metrics_variables, value_format\n        )\n        return variables\n\n    def _create_nested_dict(self, variables, value_format):\n        flat_dict = {}\n        for v in variables:\n            if v.path in flat_dict:\n                raise ValueError(\n                    \"The following variable path is found twice in the model: \"\n                    f\"'{v.path}'. `get_state_tree()` can only be called when \"\n                    \"all variable paths are unique. Make sure to give unique \"\n                    \"names to your layers (and other objects).\"\n                )\n            if value_format == \"backend_tensor\":\n                flat_dict[v.path] = v.value\n            elif value_format == \"numpy_array\":\n                flat_dict[v.path] = v.numpy()\n            else:\n                raise ValueError(\n                    \"Invalid `value_format` argument. Expected one of \"\n                    \"{'numpy_array', 'backend_tensor'}. 
Received: \"\n                    f\"value_format={value_format}\"\n                )\n\n        nested_dict = {}\n        for path, value in flat_dict.items():\n            parts = path.split(\"/\")\n            current_dict = nested_dict\n            for part in parts[:-1]:\n                if part not in current_dict:\n                    current_dict[part] = {}\n                current_dict = current_dict[part]\n            current_dict[parts[-1]] = value\n\n        return nested_dict\n\n    def set_state_tree(self, state_tree):\n        \"\"\"Assigns values to variables of the model.\n\n        This method takes a dictionary of nested variable values, which\n        represents the state tree of the model, and assigns them to the\n        corresponding variables of the model. The dictionary keys represent the\n        variable names (e.g., `'trainable_variables'`, `'optimizer_variables'`),\n        and the values are nested dictionaries containing the variable\n        paths and their corresponding values.\n\n        Args:\n            state_tree: A dictionary representing the state tree of the model.\n                The keys are the variable names, and the values are nested\n                dictionaries representing the variable paths and their values.\n        \"\"\"\n        for k, v in state_tree.items():\n            path_value_dict = self._flatten_nested_dict(v)\n            if k == \"trainable_variables\":\n                self._assign_variable_values(\n                    self.trainable_variables, path_value_dict\n                )\n            elif k == \"non_trainable_variables\":\n                self._assign_variable_values(\n                    self.non_trainable_variables, path_value_dict\n                )\n            elif k == \"optimizer_variables\":\n                if hasattr(self, \"optimizer\") and self.optimizer is not None:\n                    self._assign_variable_values(\n                        self.optimizer.variables, 
path_value_dict\n                    )\n            elif k == \"metrics_variables\":\n                if (\n                    hasattr(self, \"metrics_variables\")\n                    and self.metrics_variables\n                ):\n                    self._assign_variable_values(\n                        self.metrics_variables, path_value_dict\n                    )\n            else:\n                raise ValueError(f\"Unknown variable name: {k}\")\n\n    def _assign_variable_values(self, variables, path_value_dict):\n        for path, value in path_value_dict.items():\n            for variable in variables:\n                if variable.path == path:\n                    variable.assign(value)\n\n    def _flatten_nested_dict(self, nested_dict):\n        flat_dict = {}\n\n        def _flatten(current_dict, prefix=\"\"):\n            for key, value in current_dict.items():\n                if isinstance(value, dict):\n                    _flatten(value, f\"{prefix}{key}/\")\n                else:\n                    flat_dict[f\"{prefix}{key}\"] = value\n\n        _flatten(nested_dict)\n        return flat_dict\n\n\n@keras_export(\"keras.models.model_from_json\")\ndef model_from_json(json_string, custom_objects=None):\n    \"\"\"Parses a JSON model configuration string and returns a model instance.\n\n    Example:\n\n    >>> model = keras.Sequential([\n    ...     keras.layers.Dense(5, input_shape=(3,)),\n    ...     
keras.layers.Softmax()])\n    >>> config = model.to_json()\n    >>> loaded_model = keras.models.model_from_json(config)\n\n    Args:\n        json_string: JSON string encoding a model configuration.\n        custom_objects: Optional dictionary mapping names\n            (strings) to custom classes or functions to be\n            considered during deserialization.\n\n    Returns:\n        A Keras model instance (uncompiled).\n    \"\"\"\n    from keras.src.saving import serialization_lib\n\n    model_config = json.loads(json_string)\n    return serialization_lib.deserialize_keras_object(\n        model_config, custom_objects=custom_objects\n    )\n\n\ndef functional_init_arguments(args, kwargs):\n    return (\n        (len(args) == 2)\n        or (len(args) == 1 and \"outputs\" in kwargs)\n        or (\"inputs\" in kwargs and \"outputs\" in kwargs)\n    )\n\n\ndef inject_functional_model_class(cls):\n    \"\"\"Inject `Functional` into the hierarchy of this class if needed.\"\"\"\n    from keras.src.models import functional\n\n    if cls is Model:\n        return functional.Functional\n    # In case there is any multiple inheritance, we stop injecting the\n    # class if keras model is not in its class hierarchy.\n    if cls is object:\n        return object\n\n    cls.__bases__ = tuple(\n        inject_functional_model_class(base) for base in cls.__bases__\n    )\n    # Trigger any `__new__` class swapping that needed to happen on `Functional`\n    # but did not because functional was not in the class hierarchy.\n    cls.__new__(cls)\n\n    return cls\n"
  },
  {
    "path": "keras/src/models/model_test.py",
    "content": "import os\nimport pickle\nfrom collections import namedtuple\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import losses\nfrom keras.src import testing\nfrom keras.src import tree\nfrom keras.src.layers.core.input_layer import Input\nfrom keras.src.models.functional import Functional\nfrom keras.src.models.model import Model\nfrom keras.src.models.model import model_from_json\n\n\ndef _get_model():\n    input_a = Input(shape=(3,), batch_size=2, name=\"input_a\")\n    input_b = Input(shape=(3,), batch_size=2, name=\"input_b\")\n    x = input_a + input_b\n    x = layers.Dense(5)(x)\n    outputs = layers.Dense(4)(x)\n    model = Model([input_a, input_b], outputs)\n    return model\n\n\ndef _get_model_multi_outputs_list():\n    x = Input(shape=(3,), name=\"input_a\")\n    output_a = layers.Dense(1, name=\"output_a\")(x)\n    output_b = layers.Dense(1, name=\"output_b\", activation=\"sigmoid\")(x)\n    model = Model(x, [output_a, output_b])\n    return model\n\n\ndef _get_model_multi_outputs_list_no_output_names():\n    x = Input(shape=(3,), name=\"input_a\")\n    output_a = layers.Dense(1)(x)\n    output_b = layers.Dense(1, activation=\"sigmoid\")(x)\n    model = Model(x, [output_a, output_b])\n    return model\n\n\ndef _get_model_single_output():\n    x = Input(shape=(3,), name=\"input_a\")\n    output_a = layers.Dense(1, name=\"output_a\")(x)\n    model = Model(x, output_a)\n    return model\n\n\ndef _get_model_single_output_list():\n    x = Input(shape=(3,), name=\"input_a\")\n    output_a = layers.Dense(1, name=\"output_a\")(x)\n    model = Model(x, [output_a])\n    return model\n\n\ndef _get_model_single_output_dict():\n    x = Input(shape=(3,), name=\"input_a\")\n    output_a = layers.Dense(1, name=\"output_a\")(x)\n    model = Model(x, {\"output_a\": output_a})\n    return model\n\n\ndef _get_model_multi_outputs_dict():\n    x = 
Input(shape=(3,), name=\"input_a\")\n    output_a = layers.Dense(1, name=\"output_a\")(x)\n    output_b = layers.Dense(1, name=\"output_b\", activation=\"sigmoid\")(x)\n    model = Model(x, {\"output_a\": output_a, \"output_b\": output_b})\n    return model\n\n\ndef _get_model_multi_outputs_struct_list_like(_type):\n    x = Input(shape=(3,), name=\"x\")\n    y1 = layers.Dense(1, name=\"y1\", activation=\"sigmoid\")(x)\n    y2 = layers.Dense(1, name=\"y2\", activation=\"sigmoid\")(x)\n    model = Model(x, _type([y1, y2]))\n    return model\n\n\ndef _get_model_multi_outputs_struct_namedtuple():\n    Y = namedtuple(\"Y\", [\"y1\", \"y2\"])\n    x = Input(shape=(3,), name=\"x\")\n    y1 = layers.Dense(1, name=\"y1\", activation=\"sigmoid\")(x)\n    y2 = layers.Dense(1, name=\"y2\", activation=\"sigmoid\")(x)\n    model = Model(x, Y(y1, y2))\n    return model, Y\n\n\ndef _get_model_multi_outputs_struct_dict():\n    x = Input(shape=(3,), name=\"x\")\n    y1 = layers.Dense(1, name=\"y1\", activation=\"sigmoid\")(x)\n    y2 = layers.Dense(1, name=\"y2\", activation=\"sigmoid\")(x)\n    model = Model(x, {\"a\": y1, \"b\": y2})\n    return model\n\n\ndef _get_model_multi_outputs_struct():\n    x = Input(shape=(3,), name=\"x\")\n    y1 = layers.Dense(1, name=\"y1\", activation=\"sigmoid\")(x)\n    y2 = layers.Dense(1, name=\"y2\", activation=\"sigmoid\")(x)\n    y3 = layers.Dense(1, name=\"y3\", activation=\"sigmoid\")(x)\n    model = Model(\n        x,\n        {\n            \"a\": (y1, y2),\n            \"b\": {\"b1\": y1, \"b2\": y2},\n            \"c\": {\"c1\": (y1, y2), \"c2\": y2},\n            \"d\": y3,\n        },\n    )\n    return model\n\n\ndef _get_model_multi_outputs_dict_with_single_tensor():\n    x = Input(shape=(3,), name=\"input_a\")\n    output = layers.Dense(1, name=\"output_a\")(x)\n    model = Model(x, {\"output_a\": output, \"output_b\": output})\n    return model\n\n\ndef _get_model_with_custom_compute_loss():\n    class MyModel(Model):\n        def 
__init__(self):\n            inputs = Input(shape=(3,), name=\"inputs\")\n            outputs = layers.Dense(1, name=\"a\")(inputs)\n            super().__init__(inputs=inputs, outputs=outputs)\n\n        def compute_loss(self, x, y, y_pred, sample_weight=None, **kwargs):\n            y_pred = [y_pred, y_pred]  # To list\n            return super().compute_loss(\n                x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, **kwargs\n            )\n\n    model = MyModel()\n    return model\n\n\ndef _get_model_with_duplicate_variable_path():\n    class MyModel(Model):\n        def __init__(self):\n            super().__init__()\n            self.dense1 = layers.Dense(4, activation=\"relu\", name=\"layer1\")\n            self.dense2 = layers.Dense(4, activation=\"relu\", name=\"layer1\")\n            self.dense3 = layers.Dense(2)\n\n        def call(self, x):\n            x = self.dense1(x)\n            x = self.dense2(x)\n            return self.dense3(x)\n\n    model = MyModel()\n    x = np.random.random((1, 16))\n    model(x)\n    return model\n\n\ndef _get_model_optional_inputs():\n    class OptionalInputLayer(layers.Layer):\n        def __init__(self):\n            super().__init__()\n            self.dense = layers.Dense(2)\n\n        def call(self, x, o=None):\n            z = x if o is None else x + o\n            return self.dense(z)\n\n    x = Input((2,), name=\"x\")\n    o = Input((2,), name=\"o\", optional=True)\n    y = OptionalInputLayer()(x, o)\n    model = Model({\"x\": x, \"o\": o}, y)\n    return model\n\n\ndef _get_variable_value_by_path(variables, path):\n    for v in variables:\n        if v.path == path:\n            return v.value\n    raise ValueError(f\"No variable was found with path = {path}\")\n\n\n@pytest.mark.requires_trainable_backend\nclass ModelTest(testing.TestCase):\n    def test_functional_rerouting(self):\n        model = _get_model()\n        self.assertIsInstance(model, Functional)\n\n    def 
test_json_serialization(self):\n        model = _get_model()\n        json_string = model.to_json()\n        new_model = model_from_json(json_string)\n        self.assertEqual(json_string, new_model.to_json())\n\n    def test_tuple_input_model_subclass(self):\n        # https://github.com/keras-team/keras/issues/324\n\n        class MultiInputModel(Model):\n            def __init__(self, **kwargs):\n                super().__init__(**kwargs)\n                self.dense1 = layers.Dense(4)\n\n            def call(self, inputs):\n                a, b = inputs\n                r = self.dense1(a)\n                return layers.concatenate([r, b])\n\n        model = MultiInputModel()\n        x1 = np.random.rand(3, 3)\n        x2 = np.random.rand(3, 2)\n        out = model((x1, x2))\n        self.assertEqual(out.shape, (3, 6))\n\n    def test_reviving_functional_from_config_custom_layer(self):\n        class CustomDense(layers.Layer):\n            def __init__(self, units, **kwargs):\n                super().__init__(**kwargs)\n                self.dense = layers.Dense(units)\n\n            def call(self, x):\n                return self.dense(x)\n\n        inputs = layers.Input((4,))\n        outputs = CustomDense(10)(inputs)\n        model = Model(inputs, outputs)\n        config = model.get_config()\n\n        new_model = Model.from_config(\n            config, custom_objects={\"CustomDense\": CustomDense}\n        )\n        self.assertIsInstance(new_model, Functional)\n\n    def test_reviving_functional_from_config_custom_model(self):\n        class CustomModel(Model):\n            def __init__(self, *args, param=1, **kwargs):\n                super().__init__(*args, **kwargs)\n                self.param = param\n\n            def get_config(self):\n                base_config = super().get_config()\n                config = {\"param\": self.param}\n                return base_config | config\n\n        inputs = layers.Input((3,))\n        outputs = 
layers.Dense(5)(inputs)\n        model = CustomModel(inputs=inputs, outputs=outputs, param=3)\n\n        new_model = CustomModel.from_config(model.get_config())\n        self.assertEqual(new_model.param, 3)\n\n    @parameterized.named_parameters(\n        (\"single_output_1\", _get_model_single_output),\n        (\"single_output_2\", _get_model_single_output),\n        (\"single_output_3\", _get_model_single_output),\n        (\"single_output_4\", _get_model_single_output),\n        (\"single_list_output_1\", _get_model_single_output_list),\n        (\"single_list_output_2\", _get_model_single_output_list),\n        (\"single_list_output_3\", _get_model_single_output_list),\n        (\"single_list_output_4\", _get_model_single_output_list),\n    )\n    def test_functional_pickling(self, model_fn):\n        model = model_fn()\n        self.assertIsInstance(model, Functional)\n        model.compile()\n        x = np.random.rand(8, 3)\n\n        reloaded_pickle = pickle.loads(pickle.dumps(model))\n\n        pred_reloaded = reloaded_pickle.predict(x)\n        pred = model.predict(x)\n\n        self.assertAllClose(np.array(pred_reloaded), np.array(pred))\n\n    @parameterized.named_parameters(\n        (\"single_output_1\", _get_model_single_output, None),\n        (\"single_output_2\", _get_model_single_output, \"list\"),\n        (\"single_output_3\", _get_model_single_output, \"dict\"),\n        (\"single_output_4\", _get_model_single_output, \"dict_list\"),\n        (\"single_list_output_1\", _get_model_single_output_list, None),\n        (\"single_list_output_2\", _get_model_single_output_list, \"list\"),\n        (\"single_list_output_3\", _get_model_single_output_list, \"dict\"),\n        (\"single_list_output_4\", _get_model_single_output_list, \"dict_list\"),\n        (\"single_dict_output_1\", _get_model_single_output_dict, None),\n        (\"single_dict_output_2\", _get_model_single_output_dict, \"list\"),\n        (\"single_dict_output_3\", 
_get_model_single_output_dict, \"dict\"),\n        (\"single_dict_output_4\", _get_model_single_output_dict, \"dict_list\"),\n    )\n    def test_functional_single_output(self, model_fn, loss_type):\n        model = model_fn()\n        self.assertIsInstance(model, Functional)\n        loss = \"mean_squared_error\"\n        if loss_type == \"list\":\n            loss = [loss]\n        elif loss_type == \"dict\":\n            loss = {\"output_a\": loss}\n        elif loss_type == \"dict_list\":\n            loss = {\"output_a\": [loss]}\n        model.compile(\n            optimizer=\"sgd\",\n            loss=loss,\n            metrics={\n                \"output_a\": [\"mean_squared_error\", \"mean_absolute_error\"],\n            },\n            weighted_metrics={\n                \"output_a\": \"mean_squared_error\",\n            },\n        )\n        # Fit the model to make sure compile_metrics are built\n        x = np.random.rand(8, 3)\n        y = np.random.rand(8, 1)\n        hist = model.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=1,\n            verbose=0,\n        )\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted(\n            [\n                \"loss\",\n                \"mean_absolute_error\",\n                \"mean_squared_error\",\n                \"weighted_mean_squared_error\",\n            ]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_list_outputs_list_losses(self):\n        model = _get_model_multi_outputs_list()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss=[\"mean_squared_error\", \"binary_crossentropy\"],\n            metrics=[\n                \"mean_squared_error\",\n                [\"mean_squared_error\", \"accuracy\"],\n            ],\n   
         loss_weights=[0.1, 2],\n        )\n        # Fit the model to make sure compile_metrics are built\n        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted(\n            [\n                \"loss\",\n                \"output_a_loss\",\n                \"output_a_mean_squared_error\",\n                \"output_b_accuracy\",\n                \"output_b_loss\",\n                \"output_b_mean_squared_error\",\n            ]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_list_outputs_list_losses_abbr(self):\n        model = _get_model_multi_outputs_list()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss=[\"mse\", \"bce\"],\n            metrics=[\n                [\"bce\", \"mse\", \"mae\"],\n                [\"mse\", \"acc\"],\n            ],\n            loss_weights=[0.1, 2],\n        )\n        # Fit the model to make sure compile_metrics are built\n        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted(\n            [\n                \"loss\",\n                \"output_a_loss\",\n                \"output_a_bce\",\n                \"output_a_mae\",\n                \"output_a_mse\",\n                \"output_b_acc\",\n                \"output_b_loss\",\n                \"output_b_mse\",\n            ]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_list_outputs_nested_list_losses(self):\n        model = _get_model_multi_outputs_list()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        
model.compile(\n            optimizer=\"sgd\",\n            loss=[\"mean_squared_error\", [\"binary_crossentropy\"]],\n            metrics=[\n                \"mean_squared_error\",\n                [\"mean_squared_error\", \"accuracy\"],\n            ],\n            loss_weights=[0.1, 2],\n        )\n        # Fit the model to make sure compile_metrics are built\n        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted(\n            [\n                \"loss\",\n                \"output_a_loss\",\n                \"output_a_mean_squared_error\",\n                \"output_b_accuracy\",\n                \"output_b_loss\",\n                \"output_b_mean_squared_error\",\n            ]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_dict_outputs_dict_losses(self):\n        model = _get_model_multi_outputs_dict()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss={\n                \"output_a\": \"mean_squared_error\",\n                \"output_b\": [\"binary_crossentropy\"],\n            },\n            metrics={\n                \"output_a\": [\"mean_squared_error\"],\n                \"output_b\": [\"mean_squared_error\", \"accuracy\"],\n            },\n            weighted_metrics={\n                \"output_a\": [\"mean_squared_error\"],\n                \"output_b\": [\"mean_squared_error\", \"accuracy\"],\n            },\n        )\n        # Check dict outputs.\n        outputs = model.predict(x)\n        self.assertIsInstance(outputs, dict)\n        self.assertEqual(outputs[\"output_a\"].shape, (8, 1))\n        self.assertEqual(outputs[\"output_b\"].shape, (8, 1))\n        # Fit the model to make sure compile_metrics are built\n        hist = 
model.fit(\n            x,\n            {\"output_a\": y1, \"output_b\": y2},\n            batch_size=2,\n            epochs=1,\n            verbose=0,\n        )\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted(\n            [\n                \"loss\",\n                \"output_a_loss\",\n                \"output_a_mean_squared_error\",\n                \"output_a_weighted_mean_squared_error\",\n                \"output_b_accuracy\",\n                \"output_b_loss\",\n                \"output_b_mean_squared_error\",\n                \"output_b_weighted_accuracy\",\n                \"output_b_weighted_mean_squared_error\",\n            ]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_dict_outputs_dict_losses_with_undefined_loss(self):\n        model = _get_model_multi_outputs_dict()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss={\n                \"output_b\": [\"binary_crossentropy\"],\n            },\n            metrics={\n                \"output_b\": [\"mean_squared_error\", \"accuracy\"],\n            },\n            weighted_metrics={\n                \"output_b\": [\"mean_squared_error\", \"accuracy\"],\n            },\n        )\n        # Check dict outputs.\n        outputs = model.predict(x)\n        self.assertIsInstance(outputs, dict)\n        self.assertEqual(outputs[\"output_a\"].shape, (8, 1))\n        self.assertEqual(outputs[\"output_b\"].shape, (8, 1))\n        # Fit the model to make sure compile_metrics are built\n        hist = model.fit(\n            x,\n            {\"output_a\": y1, \"output_b\": y2},\n            batch_size=2,\n            epochs=1,\n            verbose=0,\n        )\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted(\n            [\n    
            \"loss\",\n                \"output_b_accuracy\",\n                \"output_b_mean_squared_error\",\n                \"output_b_weighted_accuracy\",\n                \"output_b_weighted_mean_squared_error\",\n            ]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_list_outputs_dict_losses_metrics(self):\n        model = _get_model_multi_outputs_list()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss={\n                \"output_a\": \"mean_squared_error\",\n                \"output_b\": \"binary_crossentropy\",\n            },\n            metrics={\n                \"output_a\": [\"mean_squared_error\"],\n                \"output_b\": [\"mean_squared_error\", \"accuracy\"],\n            },\n            weighted_metrics={\n                \"output_a\": [\"mean_squared_error\"],\n                \"output_b\": [\"mean_squared_error\", \"accuracy\"],\n            },\n        )\n        # Check list outputs.\n        outputs = model.predict(x)\n        self.assertIsInstance(outputs, list)\n        self.assertEqual(outputs[0].shape, (8, 1))\n        self.assertEqual(outputs[1].shape, (8, 1))\n        # Fit the model to make sure compile_metrics are built\n        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted(\n            [\n                \"loss\",\n                \"output_a_loss\",\n                \"output_a_mean_squared_error\",\n                \"output_a_weighted_mean_squared_error\",\n                \"output_b_accuracy\",\n                \"output_b_loss\",\n                \"output_b_mean_squared_error\",\n                \"output_b_weighted_accuracy\",\n                \"output_b_weighted_mean_squared_error\",\n           
 ]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_list_outputs_dict_losses_metrics_uniq_weighted(self):\n        model = _get_model_multi_outputs_list()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss={\n                \"output_a\": \"mean_squared_error\",\n                \"output_b\": \"binary_crossentropy\",\n            },\n            metrics={\n                \"output_a\": [\"mean_squared_error\"],\n                \"output_b\": [\"mean_squared_error\"],\n            },\n            weighted_metrics={\n                \"output_a\": [\"mean_squared_error\"],\n                \"output_b\": [\"accuracy\"],\n            },\n        )\n        # Fit the model to make sure compile_metrics are built\n        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n        hist_keys = sorted(hist.history.keys())\n        # `output_b_accuracy` doesn't have `weighted_` in metric name.\n        # When a metric is only in weighted metrics, it skips `weighted_`\n        # prefix. 
This behavior matches `tf.keras`.\n        ref_keys = sorted(\n            [\n                \"loss\",\n                \"output_a_loss\",\n                \"output_a_mean_squared_error\",\n                \"output_a_weighted_mean_squared_error\",\n                \"output_b_accuracy\",\n                \"output_b_loss\",\n                \"output_b_mean_squared_error\",\n            ]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_list_outputs_dict_losses_partial_metrics(self):\n        model = _get_model_multi_outputs_list()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss={\n                \"output_a\": \"mean_squared_error\",\n                \"output_b\": \"binary_crossentropy\",\n            },\n            metrics={\n                \"output_b\": [\"mean_squared_error\", \"accuracy\"],\n            },\n        )\n        # Fit the model to make sure compile_metrics are built\n        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted(\n            [\n                \"loss\",\n                \"output_a_loss\",\n                \"output_b_accuracy\",\n                \"output_b_loss\",\n                \"output_b_mean_squared_error\",\n            ]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_dict_outputs_with_single_tensor(self):\n        model = _get_model_multi_outputs_dict_with_single_tensor()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n\n        # `model` has 2 outputs, but there is actually only 1 output tensor.\n        self.assertLen(model.outputs, 2)\n        
model.compile(\n            optimizer=\"sgd\",\n            loss={\n                \"output_a\": \"mean_squared_error\",\n                \"output_b\": \"binary_crossentropy\",\n            },\n        )\n        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted([\"loss\", \"output_a_loss\", \"output_b_loss\"])\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_list_outputs_with_custom_compute_loss(self):\n        model = _get_model_with_custom_compute_loss()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n\n        # `model` has 1 output, but in `compute_loss` it is separated into 2.\n        self.assertLen(model.outputs, 1)\n        model.compile(\n            optimizer=\"sgd\", loss=[\"mean_squared_error\", \"binary_crossentropy\"]\n        )\n        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted(\n            [\"binary_crossentropy_loss\", \"loss\", \"mean_squared_error_loss\"]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_list_outputs_dict_losses_invalid_keys(self):\n        model = _get_model_multi_outputs_list()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss={\n                \"output_a\": \"mean_squared_error\",\n                \"output_c\": \"binary_crossentropy\",\n            },\n        )\n\n        # Fit the model to make sure compile_metrics are built\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Expected keys\",\n        ):\n            model.fit(x, (y1, y2), 
batch_size=2, epochs=1, verbose=0)\n\n    def test_functional_list_outputs_dict_losses_no_output_names(self):\n        model = _get_model_multi_outputs_list_no_output_names()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss={\"output_a\": \"mean_squared_error\"},\n        )\n        # Fit the model to make sure compile_metrics are built\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Expected keys\",\n        ):\n            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n\n    def test_functional_list_outputs_dict_metrics_invalid_keys(self):\n        model = _get_model_multi_outputs_list()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss={\n                \"output_a\": \"mean_squared_error\",\n                \"output_b\": \"binary_crossentropy\",\n            },\n            metrics={\n                \"output_c\": [\"mean_squared_error\", \"accuracy\"],\n            },\n        )\n        # Fit the model to make sure compile_metrics are built\n        with self.assertRaisesRegex(\n            ValueError, \"(?s)Invalid `metrics`.*output_c\"\n        ):\n            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n\n    def test_functional_dict_outputs_dict_losses_invalid_keys(self):\n        model = _get_model_multi_outputs_dict()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss={\n                \"output_a\": \"mean_squared_error\",\n                
\"output_c\": \"binary_crossentropy\",\n            },\n        )\n        # Fit the model to make sure compile_metrics are built\n        with self.assertRaisesRegex(\n            KeyError,\n            \"in the `loss` argument, can't be found \"\n            \"in either the model's output\",\n        ):\n            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n\n    def test_functional_dict_outputs_dict_metrics_invalid_keys(self):\n        model = _get_model_multi_outputs_dict()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss={\n                \"output_a\": \"mean_squared_error\",\n                \"output_b\": \"binary_crossentropy\",\n            },\n            metrics={\n                \"output_c\": [\"mean_squared_error\", \"accuracy\"],\n            },\n        )\n        # Fit the model to make sure compile_metrics are built\n        with self.assertRaisesRegex(\n            ValueError, \"(?s)Invalid `metrics`.*output_c\"\n        ):\n            model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n\n    def test_functional_list_outputs_invalid_nested_list_losses(self):\n        model = _get_model_multi_outputs_list()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.randint(0, 2, (8, 1))\n        model.compile(\n            optimizer=\"sgd\",\n            loss=[\n                \"mean_squared_error\",\n                [\"mean_squared_error\", \"binary_crossentropy\"],\n            ],\n        )\n        hist = model.fit(x, (y1, y2), batch_size=2, epochs=1, verbose=0)\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted([\"loss\", \"output_a_loss\", \"output_b_loss\"])\n        self.assertListEqual(hist_keys, ref_keys)\n\n    
@parameterized.named_parameters(\n        (\"int8\", \"int8\"),\n        (\"float8\", \"float8\"),\n    )\n    def test_quantize(self, mode):\n        model = _get_model()\n        x1 = np.random.rand(2, 3)\n        x2 = np.random.rand(2, 3)\n        model.quantize(mode)\n        _ = model((x1, x2))\n\n        for layer in model._flatten_layers():\n            if isinstance(layer, (layers.Dense, layers.EinsumDense)):\n                self.assertEqual(\n                    layer.dtype_policy.name, f\"{mode}_from_float32\"\n                )\n                self.assertEqual(layer.dtype_policy.quantization_mode, mode)\n        if mode == \"int8\":\n            self.assertLen(model.variables, 6)\n            if backend.backend() == \"torch\":\n                self.assertLen(list(model.named_parameters()), 6)\n        elif mode == \"float8\":\n            self.assertLen(model.variables, 16)\n            if backend.backend() == \"torch\":\n                self.assertLen(list(model.named_parameters()), 16)\n\n    @parameterized.named_parameters(\n        (\"regex_string\", \"dense_1\", [\"dense_1\"]),\n        (\"list_of_regex\", [\"dense_1\", \"output\"], [\"dense_1\", \"output\"]),\n        (\"callable\", lambda l: \"dense\" in l.name, [\"dense_1\", \"dense_2\"]),\n    )\n    def test_quantize_with_filters(self, filters, expected_quantized_layers):\n        mode = \"int8\"\n        inputs = layers.Input([3])\n        x = layers.Dense(32, name=\"dense_1\")(inputs)\n        x = layers.Dense(32, name=\"dense_2\")(x)\n        outputs = layers.Dense(32, name=\"output\")(x)\n        model = Model(inputs, outputs)\n\n        model.quantize(mode, filters=filters)\n\n        for layer in model._flatten_layers():\n            if layer.name in expected_quantized_layers:\n                self.assertEqual(\n                    layer.dtype_policy.name, f\"{mode}_from_float32\"\n                )\n            elif isinstance(layer, layers.Dense):\n                
self.assertNotEqual(\n                    layer.dtype_policy.name, f\"{mode}_from_float32\"\n                )\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\"),\n        (\"float8\", \"float8\"),\n    )\n    def test_quantize_unbuilt(self, mode):\n        class MyModel(Model):\n            def __init__(self):\n                super().__init__()\n                self.dense1 = layers.Dense(32, activation=\"relu\")\n                self.dense2 = layers.Dense(5, activation=\"softmax\")\n                self.dropout = layers.Dropout(0.5)\n\n            def call(self, inputs, training=False):\n                x = self.dense1(inputs)\n                x = self.dropout(x, training=training)\n                return self.dense2(x)\n\n        model = MyModel()\n        with self.assertRaisesRegex(\n            ValueError, \"Cannot quantize a layer that isn't yet built.\"\n        ):\n            model.quantize(mode)\n\n        x = np.random.rand(2, 3)\n        _ = model(x)\n        model.quantize(mode)\n\n    def test_quantize_invalid_args(self):\n        model = _get_model()\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid quantization mode. 
Expected one of\"\n        ):\n            model.quantize(\"abc\")\n\n        with self.assertRaisesRegex(\n            ValueError, \"Unrecognized keyword arguments\"\n        ):\n            model.quantize(\"int8\", unrecognized_kwargs=None)\n\n        with self.assertRaisesRegex(ValueError, \"Invalid quantization mode\"):\n            model.quantize(\"int7\")\n\n    @parameterized.named_parameters(\n        (\"int8\", \"int8\"),\n        (\"float8\", \"float8\"),\n    )\n    def test_quantize_nested_model(self, mode):\n        class NestedLayer(layers.Layer):\n            def __init__(self, units):\n                super().__init__()\n                self.dense = layers.Dense(units)\n\n            def call(self, x):\n                x = self.dense(x)\n                return x\n\n        class DoubleNestedLayer(layers.Layer):\n            def __init__(self, units):\n                super().__init__()\n                self.nested_dense1 = NestedLayer(units)\n                self.nested_dense2 = NestedLayer(units)\n                self.dense = layers.Dense(units)\n\n            def call(self, x):\n                x = self.nested_dense1(x)\n                x = self.nested_dense2(x)\n                x = self.dense(x)\n                return x\n\n        inputs = layers.Input([3])\n        outputs = DoubleNestedLayer(8)(inputs)\n        model = Model(inputs, outputs)\n        model.quantize(mode)\n\n        if mode == \"int8\":\n            kernel_count = 0\n            for weight in model.weights:\n                if weight.name == \"kernel\":\n                    kernel_count += 1\n                    self.assertEqual(\n                        backend.standardize_dtype(weight.dtype), \"int8\"\n                    )\n            self.assertEqual(kernel_count, 3)\n        if mode == \"float8\":\n            # kernel + bias + scale * 3 + amax_history * 3 == 8\n            self.assertEqual(len(model.weights), 3 * 8)\n\n    def test_get_state_tree(self):\n        model = 
_get_model_single_output()\n        model.compile(loss=\"mse\", optimizer=\"adam\")\n        state_tree = model.get_state_tree()\n        self.assertAllClose(\n            state_tree[\"trainable_variables\"][\"output_a\"][\"kernel\"],\n            _get_variable_value_by_path(\n                model.trainable_variables, \"output_a/kernel\"\n            ),\n        )\n        self.assertAllClose(\n            state_tree[\"trainable_variables\"][\"output_a\"][\"bias\"],\n            _get_variable_value_by_path(\n                model.trainable_variables, \"output_a/bias\"\n            ),\n        )\n        self.assertEqual(\n            state_tree[\"non_trainable_variables\"],\n            {},\n        )\n        self.assertEqual(\n            state_tree[\"metrics_variables\"][\"loss\"][\"count\"],\n            _get_variable_value_by_path(model.metrics_variables, \"loss/count\"),\n        )\n        self.assertEqual(\n            state_tree[\"metrics_variables\"][\"loss\"][\"total\"],\n            _get_variable_value_by_path(model.metrics_variables, \"loss/total\"),\n        )\n        self.assertEqual(\n            state_tree[\"optimizer_variables\"][\"adam\"][\"iteration\"],\n            _get_variable_value_by_path(\n                model.optimizer.variables, \"adam/iteration\"\n            ),\n        )\n        self.assertEqual(\n            state_tree[\"optimizer_variables\"][\"adam\"][\"learning_rate\"],\n            _get_variable_value_by_path(\n                model.optimizer.variables, \"adam/learning_rate\"\n            ),\n        )\n\n        # Test with numpy\n        state_tree = model.get_state_tree(value_format=\"numpy_array\")\n        self.assertIsInstance(\n            state_tree[\"trainable_variables\"][\"output_a\"][\"kernel\"], np.ndarray\n        )\n\n    def test_set_state_tree(self):\n        variables = {\n            \"optimizer_variables\": {\n                \"adam\": {\n                    \"iteration\": 0,\n                    
\"learning_rate\": 0.00001,\n                }\n            },\n            \"trainable_variables\": {\n                \"output_a\": {\n                    \"bias\": [0.5],\n                    \"kernel\": [[0.6], [0.7], [1.8]],\n                }\n            },\n        }\n\n        model = _get_model_single_output()\n        model.compile(optimizer=\"adam\")\n        model.set_state_tree(variables)\n\n        self.assertEqual(\n            variables[\"optimizer_variables\"][\"adam\"][\"iteration\"],\n            _get_variable_value_by_path(\n                model.optimizer.variables, \"adam/iteration\"\n            ),\n        )\n        self.assertEqual(\n            variables[\"optimizer_variables\"][\"adam\"][\"learning_rate\"],\n            _get_variable_value_by_path(\n                model.optimizer.variables, \"adam/learning_rate\"\n            ),\n        )\n        self.assertAllClose(\n            variables[\"trainable_variables\"][\"output_a\"][\"bias\"],\n            _get_variable_value_by_path(\n                model.trainable_variables, \"output_a/bias\"\n            ),\n        )\n        self.assertAllClose(\n            variables[\"trainable_variables\"][\"output_a\"][\"kernel\"],\n            _get_variable_value_by_path(\n                model.trainable_variables, \"output_a/kernel\"\n            ),\n        )\n\n    def test_get_state_tree_with_duplicate_path(self):\n        model = _get_model_with_duplicate_variable_path()\n        with self.assertRaisesRegex(\n            ValueError,\n            \"The following variable path is found twice in the model\",\n        ):\n            model.get_state_tree()\n\n    def test_layers_setter(self):\n        model = Model()\n        with self.assertRaisesRegex(\n            AttributeError, \"`Model.layers` attribute is reserved\"\n        ):\n            model.layers = [layers.Dense(4)]\n\n    def get_struct_loss(self, structure):\n        def loss_fn(y_true, y_pred):\n            
tree.assert_same_structure(structure, y_true)\n            tree.assert_same_structure(structure, y_pred)\n            tree.map_structure(\n                lambda spec, tensor: self.assertEqual(spec.ndim, tensor.ndim),\n                structure,\n                y_true,\n            )\n            tree.map_structure(\n                lambda spec, tensor: self.assertEqual(spec.ndim, tensor.ndim),\n                structure,\n                y_pred,\n            )\n            flat_y_pred = tree.flatten(y_pred)\n            flat_y_true = tree.flatten(y_true)\n            diff = 0\n            for y_p, y_t in zip(flat_y_pred, flat_y_true):\n                diff += losses.mean_absolute_error(y_t, y_p)\n            return diff\n\n        return loss_fn\n\n    @parameterized.product(\n        _type=[tuple, list], other_type=[list, tuple], weighted=[False, True]\n    )\n    def test_functional_struct_outputs_struct_losses(\n        self, _type, other_type, weighted\n    ):\n        model = _get_model_multi_outputs_struct_list_like(_type)\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.rand(8, 1)\n        y = _type([y1, y2])\n        loss = other_type(\n            [\n                self.get_struct_loss(model.output),\n                _type(\n                    [\n                        self.get_struct_loss(model.output[0]),\n                        self.get_struct_loss(model.output[1]),\n                    ]\n                ),\n            ]\n        )\n        if weighted:\n            loss_weights = tree.map_structure(lambda _: np.random.rand(), loss)\n        else:\n            loss_weights = None\n\n        model.compile(\n            optimizer=\"sgd\",\n            loss=loss,\n            loss_weights=loss_weights,\n        )\n\n        if _type is other_type:\n            with self.assertRaisesRegex(\n                ValueError, f\"[Ee]xpected.*{_type.__name__}\"\n    
        ):\n                model.fit(x, y, batch_size=2, epochs=1, verbose=0)\n        else:\n            # Check struct outputs.\n            outputs = model.predict(x)\n            self.assertIsInstance(outputs, _type)\n            # Fit the model to make sure compile_metrics are built\n            hist = model.fit(\n                x,\n                y,\n                batch_size=2,\n                epochs=1,\n                verbose=0,\n            )\n            hist_keys = sorted(hist.history.keys())\n            ref_keys = sorted(\n                [\n                    \"loss\",\n                    \"y1_loss\",\n                    \"y2_loss\",\n                    \"y1_y2_loss\",\n                ]\n            )\n            self.assertListEqual(hist_keys, ref_keys)\n\n    @parameterized.named_parameters((\"weighted\", True), (\"not_weighted\", False))\n    def test_functional_struct_outputs_dict_struct_losses(self, weighted):\n        model = _get_model_multi_outputs_struct_dict()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.rand(8, 1)\n\n        y = {\"a\": y1, \"b\": y2}\n        loss = [\n            self.get_struct_loss(model.output),\n            {\n                \"a\": self.get_struct_loss(model.output[\"a\"]),\n                \"b\": self.get_struct_loss(model.output[\"b\"]),\n            },\n        ]\n        if weighted:\n            loss_weights = tree.map_structure(lambda _: np.random.rand(), loss)\n        else:\n            loss_weights = None\n\n        model.compile(\n            optimizer=\"sgd\",\n            loss=loss,\n            loss_weights=loss_weights,\n        )\n        # Check dict outputs.\n        outputs = model.predict(x)\n        self.assertIsInstance(outputs, dict)\n\n        # Fit the model to make sure compile_metrics are built\n        hist = model.fit(\n            x,\n            y,\n            batch_size=2,\n      
      epochs=1,\n            verbose=0,\n        )\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted(\n            [\n                \"loss\",\n                \"a_loss\",\n                \"b_loss\",\n                \"a_b_loss\",\n            ]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_struct_outputs_namedtuple_struct_losses(self):\n        model, Y = _get_model_multi_outputs_struct_namedtuple()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.rand(8, 1)\n\n        y = Y(y1, y2)\n        model.compile(\n            optimizer=\"sgd\",\n            loss=[\n                self.get_struct_loss(model.output),\n                Y(\n                    self.get_struct_loss(model.output.y1),\n                    self.get_struct_loss(model.output.y2),\n                ),\n            ],\n        )\n        # Check namedtuple outputs.\n        outputs = model.predict(x)\n        self.assertIsInstance(outputs, tuple)\n\n        # Fit the model to make sure compile_metrics are built\n        hist = model.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=1,\n            verbose=0,\n        )\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted(\n            [\n                \"loss\",\n                \"y1_loss\",\n                \"y2_loss\",\n                \"y1_y2_loss\",\n            ]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    def test_functional_deeply_nested_outputs_struct_losses(self):\n        model = _get_model_multi_outputs_struct()\n        self.assertIsInstance(model, Functional)\n        x = np.random.rand(8, 3)\n        y1 = np.random.rand(8, 1)\n        y2 = np.random.rand(8, 1)\n        y3 = np.random.rand(8, 1)\n        y = {\n            \"a\": (y1, y2),\n            \"b\": {\"b1\": y1, \"b2\": y2},\n            
\"c\": {\"c1\": (y1, y2), \"c2\": y2},\n            \"d\": y3,\n        }\n        model.compile(\n            optimizer=\"sgd\",\n            loss={\n                \"a\": [\n                    self.get_struct_loss(model.output[\"a\"]),\n                    (None, self.get_struct_loss(model.output[\"a\"][1])),\n                ],\n                \"b\": [\n                    self.get_struct_loss(model.output[\"b\"]),\n                    {\"b1\": self.get_struct_loss(model.output[\"b\"][\"b1\"])},\n                ],\n                \"c\": [\n                    self.get_struct_loss(model.output[\"c\"]),\n                    {\"c1\": self.get_struct_loss(model.output[\"c\"][\"c1\"])},\n                ],\n                \"d\": self.get_struct_loss(model.output[\"d\"]),\n            },\n        )\n        # Check dict outputs.\n        outputs = model.predict(x)\n        self.assertIsInstance(outputs, dict)\n\n        # Fit the model to make sure compile_metrics are built\n        hist = model.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=1,\n            verbose=0,\n        )\n        hist_keys = sorted(hist.history.keys())\n        ref_keys = sorted(\n            [\n                \"a/y2_loss\",\n                \"a_loss\",\n                \"b/b1_loss\",\n                \"b_loss\",\n                \"c/c1_loss\",\n                \"c_loss\",\n                \"d_loss\",\n                \"loss\",\n            ]\n        )\n        self.assertListEqual(hist_keys, ref_keys)\n\n    @parameterized.named_parameters(\n        (\"optional_none\", True), (\"optional_tensor\", False)\n    )\n    def test_functional_optional_inputs(self, is_optional_none):\n        model = _get_model_optional_inputs()\n        x = np.ones((2, 2))\n        o = None if is_optional_none else np.ones((2, 2))\n        y_true = np.ones((2, 2))\n\n        model.compile(loss=\"mse\", optimizer=\"adam\")\n        model.fit(x={\"x\": x, \"o\": o}, 
y=y_true)\n        model.evaluate(x={\"x\": x, \"o\": o}, y=y_true)\n        model.predict(x={\"x\": x, \"o\": o})\n\n    @parameterized.named_parameters(\n        (\"optional_none\", True), (\"optional_tensor\", False)\n    )\n    def test_functional_optional_inputs_generator(self, is_optional_none):\n        model = _get_model_optional_inputs()\n        x = np.ones((2, 2))\n        o = None if is_optional_none else np.ones((2, 2))\n        y_true = np.ones((2, 2))\n\n        def data_generator(with_y=True):\n            for _ in range(4):\n                yield ({\"x\": x, \"o\": o},) + ((y_true,) if with_y else ())\n\n        model.compile(loss=\"mse\", optimizer=\"adam\")\n        model.fit(data_generator())\n        model.evaluate(data_generator())\n        model.predict(data_generator(with_y=False))\n\n    def test_export_error(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"exported_model\")\n        model = _get_model()\n\n        # Bad format\n        with self.assertRaisesRegex(ValueError, \"Unrecognized format=\"):\n            model.export(temp_filepath, format=\"bad_format\")\n\n        # Bad backend\n        if backend.backend() not in (\"tensorflow\", \"jax\", \"torch\"):\n            with self.assertRaisesRegex(\n                NotImplementedError,\n                (\n                    r\"`export_saved_model` only currently supports the \"\n                    r\"tensorflow, jax and torch backends.\"\n                ),\n            ):\n                model.export(temp_filepath, format=\"tf_saved_model\")\n"
  },
  {
    "path": "keras/src/models/sequential.py",
    "content": "import copy\nimport inspect\nimport typing\n\nfrom keras.src import backend\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\nfrom keras.src.backend.common import standardize_shape\nfrom keras.src.layers.core.input_layer import InputLayer\nfrom keras.src.layers.layer import Layer\nfrom keras.src.legacy.saving import saving_utils\nfrom keras.src.legacy.saving import serialization as legacy_serialization\nfrom keras.src.models.functional import Functional\nfrom keras.src.models.model import Model\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export([\"keras.Sequential\", \"keras.models.Sequential\"])\nclass Sequential(Model):\n    \"\"\"`Sequential` groups a linear stack of layers into a `Model`.\n\n    Examples:\n\n    ```python\n    model = keras.Sequential()\n    model.add(keras.Input(shape=(16,)))\n    model.add(keras.layers.Dense(8))\n\n    # Note that you can also omit the initial `Input`.\n    # In that case the model doesn't have any weights until the first call\n    # to a training/evaluation method (since it isn't yet built):\n    model = keras.Sequential()\n    model.add(keras.layers.Dense(8))\n    model.add(keras.layers.Dense(4))\n    # model.weights not created yet\n\n    # Whereas if you specify an `Input`, the model gets built\n    # continuously as you are adding layers:\n    model = keras.Sequential()\n    model.add(keras.Input(shape=(16,)))\n    model.add(keras.layers.Dense(8))\n    len(model.weights)  # Returns \"2\"\n\n    # When using the delayed-build pattern (no input shape specified), you can\n    # choose to manually build your model by calling\n    # `build(batch_input_shape)`:\n    model = keras.Sequential()\n    model.add(keras.layers.Dense(8))\n    model.add(keras.layers.Dense(4))\n    model.build((None, 16))\n    len(model.weights)  # Returns \"4\"\n\n    # Note that when using the delayed-build pattern (no input shape specified),\n 
   # the model gets built the first time you call `fit`, `eval`, or `predict`,\n    # or the first time you call the model on some input data.\n    model = keras.Sequential()\n    model.add(keras.layers.Dense(8))\n    model.add(keras.layers.Dense(1))\n    model.compile(optimizer='sgd', loss='mse')\n    # This builds the model for the first time:\n    model.fit(x, y, batch_size=32, epochs=10)\n    ```\n    \"\"\"\n\n    def __new__(cls, *args, **kwargs):\n        return typing.cast(cls, super().__new__(cls))\n\n    def __init__(self, layers=None, trainable=True, name=None):\n        super().__init__(trainable=trainable, name=name)\n        self._functional = None\n        self._layers = []\n        if layers:\n            for layer in layers:\n                self.add(layer, rebuild=False)\n            self._maybe_rebuild()\n\n    def add(self, layer, rebuild=True):\n        \"\"\"Adds a layer instance on top of the layer stack.\n\n        Args:\n            layer: layer instance.\n        \"\"\"\n        # Legacy case: if the first layer has an input_shape arg,\n        # use it to build an InputLayer.\n        if not self._layers:\n            if getattr(layer, \"_input_shape_arg\", None) is not None:\n                self.add(InputLayer(shape=layer._input_shape_arg))\n\n        # If we are passed a Keras tensor created by keras.Input(), we\n        # extract the input layer from its keras history and use that.\n        if hasattr(layer, \"_keras_history\"):\n            origin_layer = layer._keras_history[0]\n            if isinstance(origin_layer, InputLayer):\n                layer = origin_layer\n        if not isinstance(layer, Layer):\n            raise ValueError(\n                \"Only instances of `keras.Layer` can be \"\n                f\"added to a Sequential model. 
Received: {layer} \"\n                f\"(of type {type(layer)})\"\n            )\n        if not self._is_layer_name_unique(layer):\n            raise ValueError(\n                \"All layers added to a Sequential model \"\n                f\"should have unique names. Name '{layer.name}' is already \"\n                \"the name of a layer in this model. Update the `name` argument \"\n                \"to pass a unique name.\"\n            )\n        if (\n            isinstance(layer, InputLayer)\n            and self._layers\n            and isinstance(self._layers[0], InputLayer)\n        ):\n            raise ValueError(\n                f\"Sequential model '{self.name}' has already been configured \"\n                f\"to use input shape {self._layers[0].batch_shape}. You cannot \"\n                f\"add a different Input layer to it.\"\n            )\n\n        self._layers.append(layer)\n        if rebuild:\n            self._maybe_rebuild()\n        else:\n            self.built = False\n            self._functional = None\n\n    def pop(self, rebuild=True):\n        \"\"\"Removes the last layer in the model.\n\n        Args:\n            rebuild: `bool`. Whether to rebuild the model after removing\n            the layer. Defaults to `True`.\n\n        Returns:\n            layer: layer instance.\n        \"\"\"\n        layer = self._layers.pop()\n        self.built = False\n        self._functional = None\n        if rebuild:\n            self._maybe_rebuild()\n        return layer\n\n    def _maybe_rebuild(self):\n        self.built = False\n        self._functional = None\n        if isinstance(self._layers[0], InputLayer) and len(self._layers) > 1:\n            input_shape = self._layers[0].batch_shape\n            self.build(input_shape)\n        elif hasattr(self._layers[0], \"input_shape\") and len(self._layers) > 1:\n            # We can build the Sequential model if the first layer has the\n            # `input_shape` property. 
This is most commonly found in Functional\n            # models.\n            input_shape = self._layers[0].input_shape\n            self.build(input_shape)\n\n    def _lock_state(self):\n        # Unlike other layers, Sequential is mutable after build.\n        pass\n\n    def _obj_type(self):\n        return \"Sequential\"\n\n    def build(self, input_shape=None):\n        try:\n            input_shape = standardize_shape(input_shape)\n        except:\n            # Do not attempt to build if the model does not have a single\n            # input tensor.\n            return\n        if not self._layers:\n            raise ValueError(\n                f\"Sequential model {self.name} cannot be built because it has \"\n                \"no layers. Call `model.add(layer)`.\"\n            )\n        if isinstance(self._layers[0], InputLayer):\n            if self._layers[0].batch_shape != input_shape:\n                raise ValueError(\n                    f\"Sequential model '{self.name}' has already been \"\n                    \"configured to use input shape \"\n                    f\"{self._layers[0].batch_shape}. 
You cannot build it \"\n                    f\"with input_shape {input_shape}\"\n                )\n        else:\n            dtype = self._layers[0].compute_dtype\n            self._layers = [\n                InputLayer(batch_shape=input_shape, dtype=dtype)\n            ] + self._layers\n\n        # Build functional model\n        inputs = self._layers[0].output\n        x = inputs\n        for layer in self._layers[1:]:\n            try:\n                x = layer(x)\n            except NotImplementedError:\n                # Can happen if shape inference is not implemented.\n                # TODO: consider reverting inbound nodes on layers processed.\n                return\n            except TypeError as e:\n                signature = inspect.signature(layer.call)\n                positional_args = [\n                    param\n                    for param in signature.parameters.values()\n                    if param.kind\n                    in (\n                        inspect.Parameter.POSITIONAL_ONLY,\n                        inspect.Parameter.POSITIONAL_OR_KEYWORD,\n                    )\n                ]\n                required_positional_args = [\n                    param\n                    for param in positional_args\n                    if param.default == inspect.Parameter.empty\n                ]\n                if not positional_args:\n                    raise ValueError(\n                        \"Layers added to a Sequential model should \"\n                        \"have a single positional argument, the \"\n                        \"input tensor. 
Layer \"\n                        f\"{layer.__class__.__name__} has no \"\n                        \"positional arguments.\"\n                    )\n                if len(required_positional_args) > 1:\n                    raise ValueError(\n                        \"Layers added to a Sequential model can \"\n                        \"only have a single required positional \"\n                        \"argument, the input tensor. Layer \"\n                        f\"{layer.__class__.__name__} has multiple \"\n                        \"required positional arguments: \"\n                        f\"{required_positional_args}\"\n                    )\n                raise e\n        outputs = x\n        self._functional = Functional(inputs=inputs, outputs=outputs)\n\n    def call(self, inputs, training=None, mask=None, **kwargs):\n        if self._functional:\n            return self._functional.call(\n                inputs, training=training, mask=mask, **kwargs\n            )\n\n        # Fallback: Just apply the layer sequence.\n        # This typically happens if `inputs` is a nested struct.\n        for layer in self.layers:\n            # During each iteration, `inputs` are the inputs to `layer`, and\n            # `outputs` are the outputs of `layer` applied to `inputs`. 
At the\n            # end of each iteration `inputs` is set to `outputs` to prepare for\n            # the next layer.\n            layer_kwargs = {\n                k: kwargs[k]\n                # Only inject if this layer's signature actually has that arg.\n                for k in getattr(layer, \"_call_has_context_arg\", {})\n                if k in kwargs\n            }\n            if layer._call_has_mask_arg:\n                layer_kwargs[\"mask\"] = mask\n            if layer._call_has_training_arg and training is not None:\n                layer_kwargs[\"training\"] = training\n            outputs = layer(inputs, **layer_kwargs)\n            inputs = outputs\n\n            mask = tree.map_structure(backend.get_keras_mask, outputs)\n        return outputs\n\n    @property\n    def layers(self):\n        # Historically, `sequential.layers` only returns layers that were added\n        # via `add`, and omits the auto-generated `InputLayer` that comes at the\n        # bottom of the stack.\n        layers = self._layers\n        if layers and isinstance(layers[0], InputLayer):\n            return layers[1:]\n        return layers[:]\n\n    @layers.setter\n    def layers(self, _):\n        raise AttributeError(\n            \"`Sequential.layers` attribute is reserved and should not be used. 
\"\n            \"Use `add()` and `pop()` to change the layers in this model.\"\n        )\n\n    def compute_output_spec(self, inputs, training=None, mask=None, **kwargs):\n        if self._functional:\n            return self._functional.compute_output_spec(\n                inputs, training=training, mask=mask, **kwargs\n            )\n        # Direct application\n        for layer in self.layers:\n            outputs = layer.compute_output_spec(\n                inputs,\n                training=training,\n                **kwargs,\n            )  # Ignore mask\n            inputs = outputs\n        return outputs\n\n    def compute_output_shape(self, input_shape):\n        if self._functional:\n            return self._functional.compute_output_shape(input_shape)\n        # Direct application\n        for layer in self.layers:\n            output_shape = layer.compute_output_shape(input_shape)\n            input_shape = output_shape\n        return output_shape\n\n    @property\n    def input_shape(self):\n        if self._functional:\n            return self._functional.input_shape\n        raise AttributeError(\n            f\"Sequential model '{self.name}' has no defined input shape yet.\"\n        )\n\n    @property\n    def output_shape(self):\n        if self._functional:\n            return self._functional.output_shape\n        raise AttributeError(\n            f\"Sequential model '{self.name}' has no defined output shape yet.\"\n        )\n\n    @property\n    def inputs(self):\n        if self._functional:\n            return self._functional.inputs\n        raise AttributeError(\n            f\"Sequential model '{self.name}' has no defined inputs yet.\"\n        )\n\n    @property\n    def outputs(self):\n        if self._functional:\n            return self._functional.outputs\n        raise AttributeError(\n            f\"Sequential model '{self.name}' has no defined outputs yet.\"\n        )\n\n    @property\n    def input_dtype(self):\n        
# Sequential.__call__ will try to convert its inputs\n        # to the dtype expected by its input layer, if any.\n        layers = self._layers\n        if layers and isinstance(layers[0], InputLayer):\n            return layers[0].dtype\n        return super().input_dtype\n\n    def _is_layer_name_unique(self, layer):\n        for ref_layer in self._layers:\n            if layer.name == ref_layer.name and ref_layer is not layer:\n                return False\n        return True\n\n    def get_config(self):\n        serialize_fn = serialization_lib.serialize_keras_object\n        if global_state.get_global_attribute(\"use_legacy_config\", False):\n            # Legacy format serialization used for H5 and SavedModel formats\n            serialize_fn = legacy_serialization.serialize_keras_object\n        layer_configs = []\n        for layer in super().layers:\n            # `super().layers` include the InputLayer if available (it is\n            # filtered out of `self.layers`).\n            layer_configs.append(serialize_fn(layer))\n        config = Model.get_config(self)\n        config[\"name\"] = self.name\n        config[\"layers\"] = copy.deepcopy(layer_configs)\n        if self._functional is not None:\n            config[\"build_input_shape\"] = self._layers[0].batch_shape\n        return config\n\n    @classmethod\n    def from_config(cls, config, custom_objects=None):\n        if \"name\" in config:\n            name = config[\"name\"]\n            build_input_shape = config.get(\"build_input_shape\")\n            layer_configs = config[\"layers\"]\n        else:\n            name = None\n            layer_configs = config\n        model = cls(name=name)\n        for layer_config in layer_configs:\n            if \"module\" not in layer_config:\n                # Legacy format deserialization (no \"module\" key)\n                # used for H5 and SavedModel formats\n                layer = saving_utils.model_from_config(\n                    
layer_config,\n                    custom_objects=custom_objects,\n                )\n            else:\n                layer = serialization_lib.deserialize_keras_object(\n                    layer_config,\n                    custom_objects=custom_objects,\n                )\n            model.add(layer)\n        if (\n            not model._functional\n            and \"build_input_shape\" in locals()\n            and build_input_shape\n            and isinstance(build_input_shape, (tuple, list))\n        ):\n            model.build(build_input_shape)\n        return model\n"
  },
  {
    "path": "keras/src/models/sequential_test.py",
    "content": "import pickle\n\nimport numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.layers.core.input_layer import Input\nfrom keras.src.models.functional import Functional\nfrom keras.src.models.model import Model\nfrom keras.src.models.sequential import Sequential\n\n\n@pytest.mark.requires_trainable_backend\nclass SequentialTest(testing.TestCase):\n    def test_basic_flow_with_input(self):\n        model = Sequential(name=\"seq\")\n        model.add(Input(shape=(2,), batch_size=3))\n        model.add(layers.Dense(4))\n        model.add(layers.Dense(5))\n        model.summary()\n\n        self.assertEqual(len(model.layers), 2)\n        self.assertTrue(model.built)\n        self.assertEqual(len(model.weights), 4)\n\n        # Test eager call\n        x = np.random.random((3, 2))\n        y = model(x)\n\n        self.assertEqual(type(model._functional), Functional)\n        self.assertEqual(y.shape, (3, 5))\n\n        # Test symbolic call\n        x = backend.KerasTensor((3, 2))\n        y = model(x)\n        self.assertEqual(y.shape, (3, 5))\n\n        # Test `layers` constructor arg\n        model = Sequential(\n            layers=[\n                Input(shape=(2,), batch_size=3),\n                layers.Dense(4),\n                layers.Dense(5),\n            ]\n        )\n        self.assertEqual(len(model.layers), 2)\n        self.assertTrue(model.built)\n        self.assertEqual(len(model.weights), 4)\n\n        x = np.random.random((3, 2))\n        y = model(x)\n        self.assertEqual(y.shape, (3, 5))\n\n        # Test pop\n        model.pop()\n        self.assertEqual(len(model.layers), 1)\n        self.assertTrue(model.built)\n        self.assertEqual(len(model.weights), 2)\n\n        x = np.random.random((3, 2))\n        y = model(x)\n        self.assertEqual(y.shape, (3, 4))\n\n    def test_legacy_flow_with_input_shape(self):\n    
    model = Sequential(name=\"seq\")\n        model.add(layers.Dense(4, input_shape=(2,)))\n        model.add(layers.Dense(5))\n\n        self.assertEqual(len(model.layers), 2)\n        self.assertTrue(model.built)\n        self.assertEqual(len(model.weights), 4)\n        self.assertEqual(type(model._functional), Functional)\n\n        # Input_dim works too\n        model = Sequential(name=\"seq\")\n        model.add(layers.Dense(4, input_dim=2))\n        model.add(layers.Dense(5))\n\n        self.assertEqual(len(model.layers), 2)\n        self.assertTrue(model.built)\n        self.assertEqual(len(model.weights), 4)\n        self.assertEqual(type(model._functional), Functional)\n\n        # Subsequent input_shapes are ignored\n        model = Sequential(name=\"seq\")\n        model.add(layers.Dense(4, input_shape=(2,)))\n        model.add(layers.Dense(5, input_shape=(3, 4)))\n\n        self.assertEqual(len(model.layers), 2)\n        self.assertTrue(model.built)\n        self.assertEqual(len(model.weights), 4)\n        self.assertEqual(type(model._functional), Functional)\n\n    def test_basic_flow_deferred(self):\n        model = Sequential(name=\"seq\")\n        model.add(layers.Dense(4))\n        model.add(layers.Dense(5))\n        model.summary()\n\n        self.assertEqual(len(model.layers), 2)\n\n        # Test eager call\n        x = np.random.random((3, 2))\n        y = model(x)\n        self.assertTrue(model.built)\n        model.summary()\n\n        self.assertEqual(type(model._functional), Functional)\n        self.assertEqual(y.shape, (3, 5))\n\n        # Test symbolic call\n        x = backend.KerasTensor((3, 2))\n        y = model(x)\n        self.assertEqual(y.shape, (3, 5))\n\n        # Test `layers` constructor arg\n        model = Sequential(\n            layers=[\n                layers.Dense(4),\n                layers.Dense(5),\n            ]\n        )\n        x = np.random.random((3, 2))\n        y = model(x)\n        
self.assertEqual(y.shape, (3, 5))\n\n        # Test pop\n        model.pop()\n        self.assertEqual(len(model.layers), 1)\n        self.assertTrue(model.built)\n        self.assertEqual(len(model.weights), 2)\n\n        x = np.random.random((3, 2))\n        y = model(x)\n        self.assertEqual(y.shape, (3, 4))\n\n    def test_basic_flow_as_a_submodel(self):\n        # Build submodel\n        submodel = Sequential()\n        submodel.add(layers.Flatten())\n        self.assertFalse(submodel.built)\n\n        inputs = Input((None, 4))\n        outputs = layers.TimeDistributed(submodel)(inputs)\n        model = Model(inputs=inputs, outputs=outputs)\n\n        x = np.random.random((2, 3, 4))\n        y = model(x)\n        self.assertEqual(y.shape, (2, 3, 4))\n\n    def test_basic_flow_with_functional_model_as_first_layer(self):\n        # Build functional model\n        inputs = Input((16, 16, 3))\n        outputs = layers.Conv2D(4, 3, padding=\"same\")(inputs)\n        functional_model = Model(inputs=inputs, outputs=outputs)\n\n        model = Sequential(\n            [functional_model, layers.Flatten(), layers.Dense(1)]\n        )\n        model.summary()\n        self.assertEqual(len(model.layers), 3)\n        self.assertTrue(model.built)\n        for layer in model.layers:\n            self.assertTrue(layer.built)\n\n        # Test eager call\n        x = np.random.random((1, 16, 16, 3))\n        y = model(x)\n        self.assertEqual(type(model._functional), Functional)\n        self.assertEqual(tuple(y.shape), (1, 1))\n\n        # Test symbolic call\n        x = backend.KerasTensor((1, 16, 16, 3))\n        y = model(x)\n        self.assertEqual(y.shape, (1, 1))\n\n    def test_basic_flow_with_sequential_model_as_first_layer(self):\n        # Build sequential model\n        sequential_model = Sequential(\n            [Input((16, 16, 3)), layers.Conv2D(4, 3, padding=\"same\")]\n        )\n\n        model = Sequential(\n            [sequential_model, 
layers.Flatten(), layers.Dense(1)]\n        )\n        model.summary()\n        self.assertEqual(len(model.layers), 3)\n        self.assertTrue(model.built)\n        for layer in model.layers:\n            self.assertTrue(layer.built)\n\n        # Test eager call\n        x = np.random.random((1, 16, 16, 3))\n        y = model(x)\n        self.assertEqual(type(model._functional), Functional)\n        self.assertEqual(tuple(y.shape), (1, 1))\n\n        # Test symbolic call\n        x = backend.KerasTensor((1, 16, 16, 3))\n        y = model(x)\n        self.assertEqual(y.shape, (1, 1))\n\n    def test_dict_inputs(self):\n        test_obj = self\n\n        class DictLayer(layers.Layer):\n            def call(self, inputs):\n                test_obj.assertIsInstance(inputs, dict)\n                return inputs\n\n        model = Sequential([DictLayer()])\n        x = {\"a\": np.random.random((3, 2)), \"b\": np.random.random((3, 2))}\n        y = model(x)\n        self.assertEqual(type(y), dict)\n        model.summary()\n\n    def test_list_inputs(self):\n        test_obj = self\n\n        class ListLayer(layers.Layer):\n            def call(self, inputs):\n                test_obj.assertIsInstance(inputs, list)\n                return inputs\n\n        model = Sequential([ListLayer()])\n        x = [np.random.random((3, 2)), np.random.random((3, 2))]\n        y = model(x)\n        self.assertEqual(type(y), list)\n        model.summary()\n\n    def test_nested_sequential(self):\n        # https://github.com/keras-team/keras/issues/20203\n        model = Sequential()\n        model.add(Input(shape=(16,)))\n        Sequential([model])\n\n    def test_errors(self):\n        # Trying to pass 2 Inputs\n        model = Sequential()\n        model.add(Input(shape=(2,), batch_size=3))\n        with self.assertRaisesRegex(ValueError, \"already been configured\"):\n            model.add(Input(shape=(2,), batch_size=3))\n        with self.assertRaisesRegex(ValueError, \"already 
been configured\"):\n            model.add(layers.InputLayer(shape=(2,), batch_size=3))\n\n        # Same name 2x\n        model = Sequential()\n        model.add(layers.Dense(2, name=\"dense\"))\n        with self.assertRaisesRegex(ValueError, \"should have unique names\"):\n            model.add(layers.Dense(2, name=\"dense\"))\n\n        # No layers\n        model = Sequential()\n        x = np.random.random((3, 2))\n        with self.assertRaisesRegex(ValueError, \"no layers\"):\n            model(x)\n\n        # Build conflict\n        model = Sequential()\n        model.add(Input(shape=(2,), batch_size=3))\n        model.add(layers.Dense(2))\n        with self.assertRaisesRegex(ValueError, \"already been configured\"):\n            model.build((3, 4))\n        # But this works\n        model.build((3, 2))\n\n    def test_shape_inference_failure(self):\n        class DynamicLayer(layers.Layer):\n            def call(self, inputs):\n                return inputs + 1.0\n\n            def compute_output_spec(self, *args, **kwargs):\n                raise NotImplementedError\n\n        model = Sequential([DynamicLayer()])\n        x = np.random.random((3, 2))\n        y = model(x)\n        self.assertAllClose(y, x + 1)\n        model.summary()\n\n    def test_serialization(self):\n        test_obj = self\n\n        # Unbuilt deferred\n        model = Sequential(name=\"seq\")\n        model.add(layers.Dense(4))\n        model.add(layers.Dense(5))\n        revived = self.run_class_serialization_test(model)\n        self.assertLen(revived.layers, 2)\n\n        # Built deferred\n        model.build((2, 3))\n        revived = self.run_class_serialization_test(model)\n        self.assertLen(revived.layers, 2)\n\n        # Regular\n        model = Sequential(name=\"seq\")\n        model.add(Input(shape=(2,), batch_size=3))\n        model.add(layers.Dense(4))\n        model.add(layers.Dense(5))\n        model.add(layers.Dense(6))\n        revived = 
self.run_class_serialization_test(model)\n        self.assertLen(revived.layers, 3)\n\n        # Weird\n        class DictLayer(layers.Layer):\n            def call(self, inputs):\n                test_obj.assertIsInstance(inputs, dict)\n                return inputs\n\n        model = Sequential([DictLayer()])\n        revived = self.run_class_serialization_test(\n            model, custom_objects={\"DictLayer\": DictLayer}\n        )\n        self.assertLen(revived.layers, 1)\n\n    def test_serialization_with_lambda_layer(self):\n        # https://github.com/keras-team/keras/issues/20074\n        inputs = np.random.random(size=(1, 10, 4)).astype(\"float32\")\n        CONV_WIDTH = 3\n        model = Sequential([layers.Lambda(lambda x: x[:, -CONV_WIDTH:, :])])\n        outputs = model(inputs)\n\n        temp = self.get_temp_dir()\n        save_path = f\"{temp}/model.keras\"\n        model.save(save_path)\n        revived = saving.load_model(save_path, safe_mode=False)\n        revived_outputs = revived(inputs)\n        self.assertLen(revived.layers, 1)\n        self.assertAllClose(revived_outputs, outputs)\n\n    def test_functional_properties(self):\n        model = Sequential(name=\"seq\")\n        inputs = Input(shape=(2,))\n        model.add(inputs)\n        model.add(layers.Dense(4))\n\n        self.assertEqual(model.inputs, [inputs])\n        self.assertEqual(model.outputs, [model.layers[-1].output])\n        self.assertEqual(model.input_shape, (None, 2))\n        self.assertEqual(model.output_shape, (None, 4))\n\n    def test_pickleable(self):\n        model = Sequential(name=\"seq\")\n        model.add(layers.Dense(4))\n\n        result = pickle.loads(pickle.dumps(model))\n        self.assertLen(result.layers, 1)\n\n    def test_bad_layer(self):\n        model = Sequential(name=\"seq\")\n        with self.assertRaisesRegex(ValueError, \"Only instances of\"):\n            model.add({})\n\n        model = Sequential(name=\"seq\")\n\n        class 
BadLayer(layers.Layer):\n            def call(self, inputs, training):\n                return inputs\n\n        model.add(BadLayer())\n        with self.assertRaisesRegex(\n            ValueError, \"can only have a single.*positional\"\n        ):\n            model.build((None, 2))\n\n    def test_compute_output_shape(self):\n        layer = Sequential([layers.Dense(4), layers.Dense(8)])\n        output_shape = layer.compute_output_shape((1, 2))\n        self.assertEqual(output_shape, (1, 8))\n\n    def test_hasattr(self):\n        model = Sequential()\n        self.assertFalse(hasattr(model, \"input_shape\"))\n        self.assertFalse(hasattr(model, \"output_shape\"))\n        self.assertFalse(hasattr(model, \"inputs\"))\n        self.assertFalse(hasattr(model, \"outputs\"))\n\n        model = Sequential([layers.Input((4,)), layers.Dense(8)])\n        self.assertTrue(hasattr(model, \"input_shape\"))\n        self.assertTrue(hasattr(model, \"output_shape\"))\n        self.assertTrue(hasattr(model, \"inputs\"))\n        self.assertTrue(hasattr(model, \"outputs\"))\n\n    def test_layers_setter(self):\n        model = Sequential()\n        with self.assertRaisesRegex(\n            AttributeError, r\"Use `add\\(\\)` and `pop\\(\\)`\"\n        ):\n            model.layers = [layers.Dense(4)]\n"
  },
  {
    "path": "keras/src/models/variable_mapping.py",
    "content": "from keras.src.layers.layer import Layer\nfrom keras.src.metrics.metric import Metric\nfrom keras.src.optimizers.optimizer import Optimizer\nfrom keras.src.saving import saving_lib\nfrom keras.src.saving.keras_saveable import KerasSaveable\n\n\ndef map_saveable_variables(saveable, store, visited_saveables):\n    # If the saveable has already been seen, skip it.\n    if id(saveable) in visited_saveables:\n        return\n\n    visited_saveables.add(id(saveable))\n\n    variables = []\n    if isinstance(saveable, Layer):\n        variables = (\n            saveable._trainable_variables + saveable._non_trainable_variables\n        )\n    elif isinstance(saveable, Optimizer):\n        variables = saveable._variables\n    elif isinstance(saveable, Metric):\n        variables = saveable._variables\n    for v in variables:\n        if v.path in store:\n            raise ValueError(\n                \"The model contains two variables with a duplicate path: \"\n                f\"path='{v.path}' appears at least twice. \"\n                f\"This path is used for {v} and for {store[v.path]}. 
\"\n                \"In order to get a variable map, make sure to use \"\n                \"unique paths/names for each variable.\"\n            )\n        store[v.path] = v\n\n    # Recursively save state of children saveables (layers, optimizers, etc.)\n    for child_attr, child_obj in saving_lib._walk_saveable(saveable):\n        if isinstance(child_obj, KerasSaveable):\n            map_saveable_variables(\n                child_obj,\n                store,\n                visited_saveables=visited_saveables,\n            )\n        elif isinstance(child_obj, (list, dict, tuple, set)):\n            map_container_variables(\n                child_obj,\n                store,\n                visited_saveables=visited_saveables,\n            )\n\n\ndef map_container_variables(container, store, visited_saveables):\n    if isinstance(container, dict):\n        container = list(container.values())\n\n    for saveable in container:\n        if isinstance(saveable, KerasSaveable):\n            map_saveable_variables(\n                saveable,\n                store,\n                visited_saveables=visited_saveables,\n            )\n"
  },
  {
    "path": "keras/src/models/variable_mapping_test.py",
    "content": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.saving import saving_lib_test\n\n\nclass VariableMappingTest(testing.TestCase):\n    def test_basics(self):\n        model = saving_lib_test._get_basic_functional_model()\n        model.optimizer.build(model.trainable_variables)\n        variable_map = model._get_variable_map()\n\n        self.assertIn(\"first_dense/kernel\", variable_map)\n        self.assertIn(\"second_dense/bias\", variable_map)\n        self.assertIn(\"adam/learning_rate\", variable_map)\n\n        model = saving_lib_test._get_basic_sequential_model()\n        model.build((None, 1))\n        model.optimizer.build(model.trainable_variables)\n        variable_map = model._get_variable_map()\n        self.assertIn(\"sequential/dense_1/bias\", variable_map)\n        self.assertIn(\"adam/learning_rate\", variable_map)\n\n        model = saving_lib_test._get_subclassed_model()\n        model(np.ones((1, 1)))\n        model.optimizer.build(model.trainable_variables)\n        variable_map = model._get_variable_map()\n        self.assertIn(\"custom_model_x/my_dense_1/dense/kernel\", variable_map)\n        self.assertIn(\"custom_model_x/my_dense_1/my_dict_weight\", variable_map)\n        self.assertIn(\n            \"custom_model_x/my_dense_1/my_additional_weight\", variable_map\n        )\n        self.assertIn(\"adam/learning_rate\", variable_map)\n"
  },
  {
    "path": "keras/src/ops/__init__.py",
    "content": "# from keras.src.ops.numpy import Matmul, matmul\n# from keras.src.ops.numpy import Add, add\n# from keras.src.ops.numpy import Multiply, multiply\n\nfrom keras.src.backend import cast\nfrom keras.src.backend import cond\nfrom keras.src.backend import is_tensor\nfrom keras.src.backend import name_scope\nfrom keras.src.backend import random\nfrom keras.src.ops import image\nfrom keras.src.ops import operation_utils\nfrom keras.src.ops.core import *  # noqa: F403\nfrom keras.src.ops.linalg import *  # noqa: F403\nfrom keras.src.ops.math import *  # noqa: F403\nfrom keras.src.ops.nn import *  # noqa: F403\nfrom keras.src.ops.numpy import *  # noqa: F403\n"
  },
  {
    "path": "keras/src/ops/core.py",
    "content": "import ml_dtypes\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend import any_symbolic_tensors\nfrom keras.src.backend.common.backend_utils import slice_along_axis\nfrom keras.src.ops.operation import Operation\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils import traceback_utils\n\n\nclass Map(Operation):\n    def call(self, f, xs):\n        return backend.core.map(f, xs)\n\n    def compute_output_spec(self, f, xs):\n        x = tree.map_structure(lambda t: t[0], xs)\n        n = tree.flatten(xs)[0].shape[0]\n        y = backend.compute_output_spec(f, x)\n\n        def append_batch_axis(t):\n            return KerasTensor(\n                shape=(n,) + t.shape,\n                dtype=t.dtype,\n                sparse=t.sparse,\n                ragged=t.ragged,\n            )\n\n        y = tree.map_structure(append_batch_axis, y)\n        return y\n\n\n@keras_export(\"keras.ops.map\")\ndef map(f, xs):\n    \"\"\"Map a function over leading array axes.\n\n    Like Python’s builtin map, except inputs and outputs are in the form of\n    stacked arrays. 
Consider using the `vectorized_map()` transform instead,\n    unless you need to apply a function element by element for reduced memory\n    usage or heterogeneous computation with other control flow primitives.\n\n    When `xs` is an array type, the semantics of `map()` are given by this\n    Python implementation:\n\n    ```python\n    def map(f, xs):\n        return np.stack([f(x) for x in xs])\n    ```\n\n    Args:\n        f: Callable defines the function to apply element-wise over the first\n            axis or axes of `xs`.\n        xs: Values over which to map along the leading axis.\n\n    Returns:\n        Mapped values.\n\n    Examples:\n\n    >>> f = lambda x: x**2\n    >>> xs = keras.ops.arange(10)\n    >>> ys = keras.ops.map(f, xs)\n    >>> ys\n    [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]\n\n    >>> f = lambda x: {\"y1\": x**2, \"y2\": x * 10}  # Can have nested outputs\n    >>> ys = keras.ops.map(f, xs)\n    >>> ys[\"y1\"]\n    [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]\n    >>> ys[\"y2\"]\n    [0, 10, 20, 30, 40, 50, 60, 70, 80, 90]\n    \"\"\"\n    if any_symbolic_tensors((xs,)):\n        return Map().symbolic_call(f, xs)\n    return backend.core.map(f, xs)\n\n\nclass Scan(Operation):\n    def __init__(self, length=None, reverse=False, unroll=1, *, name=None):\n        super().__init__(name=name)\n        self.length = length\n        self.reverse = reverse\n        self.unroll = unroll\n\n    def call(self, f, init, xs=None):\n        return backend.core.scan(\n            f,\n            init,\n            xs,\n            length=self.length,\n            reverse=self.reverse,\n            unroll=self.unroll,\n        )\n\n    def compute_output_spec(self, f, init, xs=None):\n        if xs is None:\n            n = int(self.length)\n            x = None\n        else:\n            n = (\n                int(self.length)\n                if self.length is not None\n                else tree.flatten(xs)[0].shape[0]\n            )\n            x = xs[0]\n\n   
     carry, y = backend.compute_output_spec(f, init, x)\n        y = KerasTensor(shape=(n,) + y.shape, dtype=y.dtype, sparse=y.sparse)\n        return carry, y\n\n\n@keras_export(\"keras.ops.scan\")\ndef scan(f, init, xs=None, length=None, reverse=False, unroll=1):\n    \"\"\"Scan a function over leading array axes while carrying along state.\n\n    When the type of `xs` is an array type or `None`, and the type of `ys` is an\n    array type, the semantics of `scan()` are given roughly by this Python\n    implementation:\n\n    ```python\n    def scan(f, init, xs, length=None):\n        if xs is None:\n            xs = [None] * length\n        carry = init\n        ys = []\n        for x in xs:\n            carry, y = f(carry, x)\n            ys.append(y)\n        return carry, np.stack(ys)\n    ```\n\n    The loop-carried value `carry` (`init`) must hold a fixed shape and dtype\n    across all iterations.\n\n    In TensorFlow, `y` must match `carry` in shape and dtype. This is not\n    required in other backends.\n\n    Args:\n        f: Callable defines the logic for each loop iteration. This accepts two\n            arguments where the first is a value of the loop carry and the\n            second is a slice of `xs` along its leading axis.\n            This callable returns a pair where the first represents a new value\n            for the loop carry and the second represents a slice of the output.\n        init: The initial loop carry value. This can be a scalar, tensor, or any\n            nested structure. It must match the structure of the first element\n            returned by `f`.\n        xs: Optional value to scan along its leading axis. This can be a tensor\n            or any nested structure. 
If `xs` is not provided, you must specify
            `length` to define the number of loop iterations.
            Defaults to `None`.
        length: Optional integer specifying the number of loop iterations.
            If `length` is not provided, it defaults to the size of the
            leading axis of the arrays in `xs`. Defaults to `None`.
        reverse: Optional boolean specifying whether to run the scan iteration
            forward or in reverse, equivalent to reversing the leading axes of
            the arrays in both `xs` and in `ys`.
        unroll: Optional positive integer or boolean specifying how many scan
            iterations to unroll within a single iteration of a loop. If an
            integer is provided, it determines how many unrolled loop
            iterations to run within a single rolled iteration of the loop.
            If a boolean is provided, it will determine if the loop is
            completely unrolled (`unroll=True`) or left completely rolled
            (`unroll=False`). Note that unrolling is only supported by JAX
            and TensorFlow backends.

    Returns:
        A pair where the first element represents the final loop carry value
        and the second element represents the stacked outputs of `f` when
        scanned over the leading axis of the inputs.

    Examples:

    >>> sum_fn = lambda c, x: (c + x, c + x)
    >>> init = keras.ops.array(0)
    >>> xs = keras.ops.array([1, 2, 3, 4, 5])
    >>> carry, result = keras.ops.scan(sum_fn, init, xs)
    >>> carry
    15
    >>> result
    [1, 3, 6, 10, 15]
    """
    if any_symbolic_tensors((init, xs)):
        return Scan(
            length=length, reverse=reverse, unroll=unroll
        ).symbolic_call(f, init, xs)
    return backend.core.scan(
        f, init, xs, length, reverse=reverse, unroll=unroll
    )


class AssociativeScan(Operation):
    def __init__(self, reverse=False, axis=0, *, name=None):
        super().__init__(name=name)
        self.reverse = reverse
        self.axis = axis

    def call(self, f, elems):
        return backend.core.associative_scan(
            f, elems, reverse=self.reverse, axis=self.axis
        )

    def compute_output_spec(self, f, elems):
        elems_flat = tree.flatten(elems)
        lens = [elem.shape[self.axis] for elem in elems_flat]
        if len(set(lens)) != 1:
            raise ValueError(
                "Array inputs to associative_scan must have the same "
                "first dimension. (saw: {})".format(
                    [elem.shape for elem in elems_flat]
                )
            )

        x = tree.pack_sequence_as(
            elems,
            [slice_along_axis(x, 0, 1, axis=self.axis) for x in elems_flat],
        )
        y_spec = backend.compute_output_spec(f, x, x)

        def _restore_shape(x):
            return KerasTensor(
                shape=elems_flat[0].shape, dtype=x.dtype, sparse=x.sparse
            )

        y_spec = tree.map_structure(_restore_shape, y_spec)
        return y_spec


@keras_export("keras.ops.associative_scan")
def associative_scan(f, elems, reverse=False, axis=0):
    """Performs a scan with an associative binary operation, in parallel.

    This operation is similar to `scan`, with the key difference that
    `associative_scan` is a parallel implementation with potentially
    significant performance benefits, especially when jit compiled. The catch
    is that it can only be used when `f` is a binary associative operation
    (i.e. it must satisfy `f(a, f(b, c)) == f(f(a, b), c)`).

    For an introduction to associative scans, refer to this paper:
    Blelloch, Guy E. 1990.
    [Prefix Sums and Their Applications](
        https://www.cs.cmu.edu/~guyb/papers/Ble93.pdf).

    Args:
        f: A Python callable implementing an associative binary operation with
            signature `r = f(a, b)`. 
Function `f` must be associative, i.e.,\n            it must satisfy the equation\n            `f(a, f(b, c)) == f(f(a, b), c)`.\n            The inputs and result are (possibly nested Python tree structures\n            of) array(s) matching `elems`. Each array has a dimension in place\n            of the `axis` dimension. `f` should be applied elementwise over\n            the `axis` dimension.\n            The result `r` has the same shape (and structure) as the\n            two inputs `a` and `b`.\n        elems: A (possibly nested Python tree structure of) array(s), each with\n            an `axis` dimension of size `num_elems`.\n        reverse: A boolean stating if the scan should be reversed with respect\n            to the `axis` dimension.\n        axis: an integer identifying the axis over which the scan should occur.\n\n    Returns:\n        A (possibly nested Python tree structure of) array(s) of the same shape\n        and structure as `elems`, in which the `k`'th element of `axis` is\n        the result of recursively applying `f` to combine the first `k`\n        elements of `elems` along `axis`. 
For example, given
        `elems = [a, b, c, ...]`, the result would be
        `[a, f(a, b), f(f(a, b), c), ...]`.

    Examples:

    >>> sum_fn = lambda x, y: x + y
    >>> xs = keras.ops.arange(5)
    >>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0)
    >>> ys
    [0, 1, 3, 6, 10]

    >>> sum_fn = lambda x, y: [x[0] + y[0], x[1] + y[1], x[2] + y[2]]
    >>> xs = [keras.ops.array([1, 2]) for _ in range(3)]
    >>> ys = keras.ops.associative_scan(sum_fn, xs, axis=0)
    >>> ys
    [[1, 3], [1, 3], [1, 3]]
    """
    if any_symbolic_tensors((elems,)):
        return AssociativeScan(reverse=reverse, axis=axis).symbolic_call(
            f, elems
        )
    return backend.core.associative_scan(f, elems, reverse=reverse, axis=axis)


class Scatter(Operation):
    def __init__(self, shape, *, name=None):
        super().__init__(name=name)
        self.shape = shape

    def call(self, indices, values):
        return backend.core.scatter(indices, values, self.shape)

    def compute_output_spec(self, indices, values):
        return KerasTensor(self.shape, dtype=values.dtype)


@keras_export("keras.ops.scatter")
def scatter(indices, values, shape):
    """Returns a tensor of shape `shape` where `indices` are set to `values`.

    At a high level, this operation does `zeros[indices] = values` and
    returns the output. 
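As a rough NumPy sketch of these semantics (`scatter_numpy` is an
    illustrative helper, not part of Keras; it assumes `scatter_nd`-style
    behavior where duplicate indices accumulate):

```python
import numpy as np

# Illustrative helper (not part of Keras): mimics the scatter
# semantics described above using plain NumPy. Duplicate indices
# accumulate in this sketch (scatter_nd-style assumption).
def scatter_numpy(indices, values, shape):
    out = np.zeros(shape, dtype=np.asarray(values).dtype)
    for idx, value in zip(indices, np.asarray(values)):
        out[tuple(idx)] += value
    return out

result = scatter_numpy([[0, 1], [1, 1]], np.array([1.0, 1.0]), (2, 2))
# result:
# [[0. 1.]
#  [0. 1.]]
```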
It is equivalent to:

    ```python
    zeros = keras.ops.zeros(shape)
    output = keras.ops.scatter_update(zeros, indices, values)
    ```

    Args:
        indices: A tensor or list/tuple specifying
            indices for the values in `values`.
        values: A tensor, the values to be set at `indices`.
        shape: Shape of the output tensor.

    Example:

    >>> indices = [[0, 1], [1, 1]]
    >>> values = np.array([1., 1.])
    >>> keras.ops.scatter(indices, values, shape=(2, 2))
    array([[0., 1.],
           [0., 1.]])
    """
    if any_symbolic_tensors((indices, values)):
        return Scatter(shape=shape).symbolic_call(indices, values)
    return backend.core.scatter(indices, values, shape)


class ScatterUpdate(Operation):
    def __init__(self, reduction=None, *, name=None):
        super().__init__(name=name)
        self.reduction = reduction

    def call(self, inputs, indices, updates):
        return backend.core.scatter_update(
            inputs, indices, updates, reduction=self.reduction
        )

    def compute_output_spec(self, inputs, indices, updates):
        return KerasTensor(inputs.shape, dtype=inputs.dtype)


@keras_export("keras.ops.scatter_update")
def scatter_update(inputs, indices, updates, reduction=None):
    """Update inputs via updates at scattered (sparse) indices.

    At a high level, this operation does `inputs[indices] = updates`.
    Assuming `inputs` is a tensor of shape `(D0, D1, ..., Dn)`, there are two
    main usages of `scatter_update`.

    1. `indices` is a 2D tensor of shape `(num_updates, n)`, where
        `num_updates` is the number of updates to perform, and `updates` is a
        1D tensor of shape `(num_updates,)`. 
For example, if `inputs` is `zeros((4, 4, 4))`,
        and we want to update `inputs[1, 2, 3]` and `inputs[0, 1, 3]` to 1,
        then we can use:

    ```python
    inputs = np.zeros((4, 4, 4))
    indices = [[1, 2, 3], [0, 1, 3]]
    updates = np.array([1., 1.])
    inputs = keras.ops.scatter_update(inputs, indices, updates)
    ```

    2. `indices` is a 2D tensor of shape `(num_updates, k)`, where
        `num_updates` is the number of updates to perform, and `k` (`k < n`)
        is the size of each index in `indices`. `updates` is a tensor of
        shape `(num_updates,) + inputs.shape[k:]`. For example, if
        `inputs = np.zeros((4, 4, 4))`, and we want to update
        `inputs[1, 2, :]` and `inputs[2, 3, :]` to `[1, 1, 1, 1]`, then
        `indices` would have shape `(num_updates, 2)` (`k = 2`), and
        `updates` would have shape `(num_updates, 4)`
        (`inputs.shape[2:] == (4,)`). See the code below:

    ```python
    inputs = np.zeros((4, 4, 4))
    indices = [[1, 2], [2, 3]]
    updates = np.array([[1., 1., 1., 1.], [1., 1., 1., 1.]])
    inputs = keras.ops.scatter_update(inputs, indices, updates)
    ```

    Args:
        inputs: A tensor, the tensor to be updated.
        indices: A tensor or list/tuple of shape `(N, k)` with
            `k <= inputs.ndim`, specifying indices to update. `N` is the
            number of indices to update, and must be equal to the first
            dimension of `updates`.
        updates: A tensor, the new values to be put to `inputs` at `indices`.
        reduction: A string specifying the reduction operation to apply when
            multiple updates target the same index. 
Supported values are:\n            `None` (default): Updates replace existing values (last write wins).\n            `\"add\"`: Updates are added to existing values.\n            `\"max\"`: The maximum of updates and existing values is kept.\n            `\"min\"`: The minimum of updates and existing values is kept.\n            `\"mul\"`: Updates are multiplied with existing values.\n\n    Returns:\n        A tensor, has the same shape and dtype as `inputs`.\n\n    Example:\n\n    Using `reduction=\"add\"` to accumulate values at the same index:\n\n    >>> inputs = np.zeros((4,))\n    >>> indices = [[0], [0], [1]]\n    >>> updates = np.array([1., 1., 1.])\n    >>> keras.ops.scatter_update(inputs, indices, updates, reduction=\"add\")\n    array([2., 1., 0., 0.])\n    \"\"\"\n    if reduction is not None:\n        reduction = reduction.lower()\n        if reduction not in (\"add\", \"max\", \"min\", \"mul\"):\n            raise ValueError(\n                f\"Invalid reduction: {reduction}. \"\n                \"Supported values are: None, 'add', 'max', 'min', 'mul'.\"\n            )\n    if any_symbolic_tensors((inputs, indices, updates)):\n        return ScatterUpdate(reduction=reduction).symbolic_call(\n            inputs, indices, updates\n        )\n    return backend.core.scatter_update(\n        inputs, indices, updates, reduction=reduction\n    )\n\n\nclass Slice(Operation):\n    def __init__(self, shape, *, name=None):\n        super().__init__(name=name)\n        self.shape = shape\n\n    def call(self, inputs, start_indices):\n        return backend.core.slice(inputs, start_indices, self.shape)\n\n    def compute_output_spec(self, inputs, start_indices):\n        if len(self.shape) != len(inputs.shape):\n            raise ValueError(\n                \"The number of dimensions in `inputs` must match the number of \"\n                f\"dimensions in `shape`. 
Received inputs.shape={inputs.shape} \"\n                f\"and shape={self.shape}\"\n            )\n        if hasattr(start_indices, \"__len__\") and len(start_indices) != len(\n            inputs.shape\n        ):\n            raise ValueError(\n                \"The number of dimensions in `start_indices` must match the \"\n                \"number of dimensions in `inputs`. Received \"\n                f\"start_indices={start_indices} and inputs.shape={inputs.shape}\"\n            )\n\n        final_shape = []\n        for i, (input_dim, slice_dim) in enumerate(\n            zip(inputs.shape, self.shape)\n        ):\n            if slice_dim != -1:\n                final_shape.append(slice_dim)\n            elif isinstance(start_indices, KerasTensor) or input_dim is None:\n                final_shape.append(None)\n            else:\n                final_shape.append(input_dim - start_indices[i])\n        return KerasTensor(final_shape, dtype=inputs.dtype)\n\n\n@keras_export(\"keras.ops.slice\")\ndef slice(inputs, start_indices, shape):\n    \"\"\"Return a slice of an input tensor.\n\n    At a high level, this operation is an explicit replacement for array slicing\n    e.g. 
`inputs[start_indices: start_indices + shape]`.
    Unlike slicing via brackets, this operation will accept tensor start
    indices on all backends, which is useful when indices are dynamically
    computed via other tensor operations.

    ```python
    inputs = np.zeros((5, 5))
    start_indices = np.array([3, 3])
    shape = np.array([2, 2])
    inputs = keras.ops.slice(inputs, start_indices, shape)
    ```

    Args:
        inputs: A tensor, the tensor to be sliced.
        start_indices: A list/tuple of shape `(inputs.ndim,)`, specifying
            the starting indices of the slice.
        shape: The full shape of the returned slice.

    Returns:
        A tensor of shape `shape`, with the same dtype as `inputs`.
    """
    if any_symbolic_tensors((inputs, start_indices)):
        return Slice(shape=shape).symbolic_call(inputs, start_indices)
    return backend.core.slice(inputs, start_indices, shape)


class SliceUpdate(Operation):
    def call(self, inputs, start_indices, updates):
        return backend.core.slice_update(inputs, start_indices, updates)

    def compute_output_spec(self, inputs, start_indices, updates):
        return KerasTensor(inputs.shape, dtype=inputs.dtype)


@keras_export("keras.ops.slice_update")
def slice_update(inputs, start_indices, updates):
    """Update an input by slicing in a tensor of updated values.

    At a high level, this operation does
    `inputs[start_indices: start_indices + updates.shape] = updates`.
    Assuming `inputs` is a tensor of shape `(D0, D1, ..., Dn)`,
    `start_indices` must be a list/tuple of n integers, specifying the
    starting indices. `updates` must have the same rank as `inputs`, and the
    size of each dim must not exceed `Di - start_indices[i]`. 
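A rough NumPy sketch of these semantics (`slice_update_numpy` is an
    illustrative helper, not part of Keras):

```python
import numpy as np

# Illustrative helper (not part of Keras): reference semantics of
# slice_update using plain NumPy slicing on a copy of the input.
def slice_update_numpy(inputs, start_indices, updates):
    out = np.array(inputs, copy=True)
    region = tuple(
        slice(start, start + size)
        for start, size in zip(start_indices, updates.shape)
    )
    out[region] = updates
    return out
```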
For example, if we have 2D
    inputs `inputs = np.zeros((5, 5))`, and we want to update the
    intersection of the last 2 rows and last 2 columns to 1, i.e.,
    `inputs[3:, 3:] = np.ones((2, 2))`, then we can use the code below:

    ```python
    inputs = np.zeros((5, 5))
    start_indices = [3, 3]
    updates = np.ones((2, 2))
    inputs = keras.ops.slice_update(inputs, start_indices, updates)
    ```

    Args:
        inputs: A tensor, the tensor to be updated.
        start_indices: A list/tuple of shape `(inputs.ndim,)`, specifying
            the starting indices for updating.
        updates: A tensor, the new values to be put to `inputs` at
            `start_indices`. `updates` must have the same rank as `inputs`.

    Returns:
        A tensor with the same shape and dtype as `inputs`.
    """
    if any_symbolic_tensors((inputs, start_indices, updates)):
        return SliceUpdate().symbolic_call(inputs, start_indices, updates)
    return backend.core.slice_update(inputs, start_indices, updates)


class Switch(Operation):
    def call(self, index, branches, *operands):
        return backend.core.switch(index, branches, *operands)

    def compute_output_spec(self, index, branches, *operands):
        # We use first branch for output_spec
        spec = backend.compute_output_spec(branches[0], *operands)
        return spec


@keras_export("keras.ops.switch")
def switch(index, branches, *operands):
    """Apply exactly one of the `branches` given by `index`.

    If `index` is out of bounds, it is clamped to within bounds.

    The semantics of `switch` are given roughly by this Python implementation:

    ```python
    def switch(index, branches, *operands):
        index = clamp(0, index, len(branches) - 1)
        return branches[index](*operands)
    ```

    Args:
        index: An integer scalar indicating which branch function to apply.
        branches: A sequence of functions to be applied 
based on `index`.
        operands: Inputs to whichever branch is applied.

    Returns:
        The outputs of `branch(*operands)` for the branch that was selected
        based on `index`.

    Examples:

    >>> add_fn = lambda x, y: x + y
    >>> subtract_fn = lambda x, y: x - y
    >>> x = keras.ops.array(2.0)
    >>> y = keras.ops.array(0.5)
    >>> branches = [add_fn, subtract_fn]
    >>> keras.ops.switch(0, branches, x, y)
    2.5

    >>> keras.ops.switch(1, branches, x, y)
    1.5
    """
    if any_symbolic_tensors(operands):
        return Switch().symbolic_call(index, branches, *operands)
    return backend.core.switch(index, branches, *operands)


class WhileLoop(Operation):
    def __init__(self, cond, body, maximum_iterations=None, *, name=None):
        super().__init__(name=name)
        self.cond = cond
        self.body = body
        self.maximum_iterations = maximum_iterations

    def call(self, loop_vars):
        return backend.core.while_loop(
            self.cond,
            self.body,
            loop_vars,
            maximum_iterations=self.maximum_iterations,
        )

    def compute_output_spec(self, loop_vars):
        return tree.map_structure(
            lambda v: KerasTensor(v.shape, dtype=v.dtype), loop_vars
        )


@keras_export("keras.ops.while_loop")
def while_loop(
    cond,
    body,
    loop_vars,
    maximum_iterations=None,
):
    """While loop implementation.

    Args:
        cond: A callable that represents the termination condition of the
            loop. Must accept a `loop_vars`-like structure as an argument. If
            `loop_vars` is a tuple or list, each element of `loop_vars` will
            be passed positionally to the callable.
        body: A callable that represents the loop body. Must accept a
            `loop_vars`-like structure as an argument, and return an updated
            value with the same structure. 
If `loop_vars` is a tuple or list, each
            element of `loop_vars` will be passed positionally to the
            callable.
        loop_vars: An arbitrary nested structure of tensor state to persist
            across loop iterations.
        maximum_iterations: Optional maximum number of iterations of the
            while loop to run. If provided, the `cond` output is AND-ed with
            an additional condition ensuring the number of iterations executed
            is no greater than `maximum_iterations`.

    Returns:
        The final value of `loop_vars`, with the same structure, shapes and
        dtypes as the input `loop_vars`.

    Examples:

    >>> i = 0
    >>> cond = lambda i: i < 10
    >>> body = lambda i: i + 1
    >>> keras.ops.while_loop(cond, body, i)
    10

    >>> x, y = 0, 1
    >>> cond = lambda x, y: x < 10
    >>> body = lambda x, y: (x + 1, y + 1)
    >>> keras.ops.while_loop(cond, body, (x, y))
    10, 11
    """
    if any_symbolic_tensors((loop_vars,)):
        return WhileLoop(
            cond, body, maximum_iterations=maximum_iterations
        ).symbolic_call(loop_vars)
    return backend.core.while_loop(
        cond,
        body,
        loop_vars,
        maximum_iterations=maximum_iterations,
    )


class StopGradient(Operation):
    def call(self, variable):
        return backend.core.stop_gradient(variable)

    def compute_output_spec(self, variable):
        return KerasTensor(variable.shape, dtype=variable.dtype)


@keras_export("keras.ops.stop_gradient")
def stop_gradient(variable):
    """Stops gradient computation.

    Args:
        variable: A tensor variable for which the gradient
            computation is to be disabled.

    Returns:
        The variable with gradient computation disabled.

    Examples:

    >>> var = keras.backend.convert_to_tensor(
    ...     [1., 2., 3.],
    ...     dtype="float32"
    ... 
)\n    >>> var = keras.ops.stop_gradient(var)\n    \"\"\"\n    if any_symbolic_tensors((variable,)):\n        return StopGradient().symbolic_call(variable)\n    return backend.core.stop_gradient(variable)\n\n\nclass ForiLoop(Operation):\n    def __init__(self, lower, upper, body_fun, *, name=None):\n        super().__init__(name=name)\n        self.lower = lower\n        self.upper = upper\n        self.body_fun = body_fun\n\n    def call(self, init_val):\n        return backend.core.fori_loop(\n            self.lower,\n            self.upper,\n            self.body_fun,\n            init_val,\n        )\n\n    def compute_output_spec(self, init_val):\n        return KerasTensor(init_val.shape, dtype=init_val.dtype)\n\n\n@keras_export(\"keras.ops.fori_loop\")\ndef fori_loop(lower, upper, body_fun, init_val):\n    \"\"\"For loop implementation.\n\n    Args:\n        lower: The initial value of the loop variable.\n        upper: The upper bound of the loop variable.\n        body_fun: A callable that represents the loop body. Must take two\n            arguments: the loop variable and the loop state. 
The loop state
            should be updated and returned by this function.
        init_val: The initial value of the loop state.

    Returns:
        The final state after the loop.

    Example:

    >>> lower = 0
    >>> upper = 10
    >>> body_fun = lambda i, s: s + i
    >>> init_val = 0
    >>> keras.ops.fori_loop(lower, upper, body_fun, init_val)
    45
    """
    if any_symbolic_tensors((lower, upper, init_val)):
        return ForiLoop(lower, upper, body_fun).symbolic_call(init_val)
    return backend.core.fori_loop(lower, upper, body_fun, init_val)


class Unstack(Operation):
    def __init__(self, num=None, axis=0, *, name=None):
        super().__init__(name=name)
        self.num = num
        self.axis = axis

    def call(self, x):
        return backend.core.unstack(x, self.num, self.axis)

    def compute_output_spec(self, x):
        axis = self.axis
        if axis < 0:
            axis = len(x.shape) + axis
        output_shapes = x.shape[:axis] + x.shape[axis + 1 :]
        num = self.num
        if num is None:
            num = x.shape[axis]
        if num is None:
            raise ValueError(
                "Cannot infer argument `num` from shape "
                f"{x.shape}. Either provide a tensor with a "
                "concrete shape in the `axis` dimension or "
                "explicitly pass the `num` argument."
            )
        output = [
            KerasTensor(shape=output_shapes, dtype=x.dtype) for _ in range(num)
        ]
        return output


@keras_export("keras.ops.unstack")
def unstack(x, num=None, axis=0):
    """Unpacks the given dimension of a rank-R tensor into rank-(R-1) tensors.

    Args:
        x: The input tensor.
        num: The length of the dimension axis. 
Automatically inferred\n            if `None`.\n        axis: The axis along which to unpack.\n\n    Returns:\n        A list of tensors unpacked along the given axis.\n\n    Example:\n\n    >>> x = keras.ops.array([[1, 2], [3, 4]])\n    >>> keras.ops.unstack(x, axis=0)\n    [array([1, 2]), array([3, 4])]\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Unstack(num, axis).symbolic_call(x)\n    return backend.core.unstack(x, num=num, axis=axis)\n\n\n@keras_export(\"keras.ops.shape\")\ndef shape(x):\n    \"\"\"Gets the shape of the tensor input.\n\n    Note: On the TensorFlow backend, when `x` is a `tf.Tensor` with dynamic\n    shape, dimensions which are dynamic in the context of a compiled function\n    will have a `tf.Tensor` value instead of a static integer value.\n\n    Args:\n        x: A tensor. This function will try to access the `shape` attribute of\n            the input tensor.\n\n    Returns:\n        A tuple of integers or None values, indicating the shape of the input\n            tensor.\n\n    Example:\n\n    >>> x = keras.ops.zeros((8, 12))\n    >>> keras.ops.shape(x)\n    (8, 12)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return x.shape\n    return backend.core.shape(x)\n\n\n@keras_export(\"keras.ops.dtype\")\ndef dtype(x):\n    \"\"\"Return the dtype of the tensor input as a standardized string.\n\n    Note that due to the standardization, the dtype will not compare equal\n    to the backend-specific version of the dtype.\n\n    Args:\n        x: A tensor. This function will try to access the `dtype` attribute of\n            the input tensor.\n\n    Returns:\n        A string indicating the dtype of the input tensor, e.g. 
`\"float32\"`.\n\n    Example:\n\n    >>> x = keras.ops.zeros((8, 12))\n    >>> keras.ops.dtype(x)\n    'float32'\n\n    \"\"\"\n    return backend.standardize_dtype(x.dtype)\n\n\nclass Cast(Operation):\n    def __init__(self, dtype, *, name=None):\n        super().__init__(name=name)\n        self.dtype = dtype\n\n    def call(self, x):\n        return backend.core.cast(x, self.dtype)\n\n    def compute_output_spec(self, x):\n        return backend.KerasTensor(shape=x.shape, dtype=self.dtype)\n\n\n@keras_export(\"keras.ops.cast\")\ndef cast(x, dtype):\n    \"\"\"Cast a tensor to the desired dtype.\n\n    Args:\n        x: A tensor or variable.\n        dtype: The target type.\n\n    Returns:\n        A tensor of the specified `dtype`.\n\n    Example:\n\n    >>> x = keras.ops.arange(4)\n    >>> x = keras.ops.cast(x, dtype=\"float16\")\n    \"\"\"\n    dtype = backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x,)):\n        return Cast(dtype=dtype)(x)\n    return backend.core.cast(x, dtype)\n\n\nclass SaturateCast(Operation):\n    def __init__(self, dtype, *, name=None):\n        super().__init__(name=name)\n        self.dtype = dtype\n\n    def call(self, x):\n        return _saturate_cast(x, self.dtype)\n\n    def compute_output_spec(self, x):\n        return backend.KerasTensor(shape=x.shape, dtype=self.dtype)\n\n\n@keras_export(\"keras.ops.saturate_cast\")\ndef saturate_cast(x, dtype):\n    \"\"\"Performs a safe saturating cast to the desired dtype.\n\n    Saturating cast prevents data type overflow when casting to `dtype` with\n    smaller values range. 
E.g.\n    `ops.cast(ops.cast([-1, 256], \"float32\"), \"uint8\")` returns `[255, 0]`,\n    but `ops.saturate_cast(ops.cast([-1, 256], \"float32\"), \"uint8\")` returns\n    `[0, 255]`.\n\n    Args:\n        x: A tensor or variable.\n        dtype: The target type.\n\n    Returns:\n        A safely casted tensor of the specified `dtype`.\n\n    Example:\n\n    Image resizing with bicubic interpolation may produce values outside\n    original range.\n    >>> image2x2 = np.array([0, 1, 254, 255], dtype=\"uint8\").reshape(1, 2, 2, 1)\n    >>> image4x4 = tf.image.resize(image2x2, (4, 4), method=\"bicubic\")\n    >>> print(image4x4.numpy().squeeze())\n    >>> # [[-22.500004 -22.204624 -21.618908 -21.32353 ]\n    >>> #  [ 52.526054  52.82143   53.407146  53.70253 ]\n    >>> #  [201.29752  201.59288  202.17859  202.47395 ]\n    >>> #  [276.32355  276.61893  277.20465  277.50006 ]]\n\n    Casting this resized image back to `uint8` will cause overflow.\n    >>> image4x4_casted = ops.cast(image4x4, \"uint8\")\n    >>> print(image4x4_casted.numpy().squeeze())\n    >>> # [[234 234 235 235]\n    >>> #  [ 52  52  53  53]\n    >>> #  [201 201 202 202]\n    >>> #  [ 20  20  21  21]]\n\n    Saturate casting to `uint8` will clip values to `uint8` range before\n    casting and will not cause overflow.\n    >>> image4x4_saturate_casted = ops.saturate_cast(image4x4, \"uint8\")\n    >>> print(image4x4_saturate_casted.numpy().squeeze())\n    >>> # [[  0   0   0   0]\n    >>> #  [ 52  52  53  53]\n    >>> #  [201 201 202 202]\n    >>> #  [255 255 255 255]]\n\n    \"\"\"\n    dtype = backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x,)):\n        return SaturateCast(dtype=dtype)(x)\n    return _saturate_cast(x, dtype)\n\n\ndef _saturate_cast(x, dtype, backend_module=None):\n    backend_module = backend_module or backend\n\n    def get_dtype_min_max(dtype):\n        if \"bool\" == dtype:\n            dtype_min = 0\n            dtype_max = 1\n        elif \"int\" in dtype:\n     
       dtype_min = ml_dtypes.iinfo(dtype).min\n            dtype_max = ml_dtypes.iinfo(dtype).max\n        else:\n            dtype_min = ml_dtypes.finfo(dtype).min\n            dtype_max = ml_dtypes.finfo(dtype).max\n        return dtype_min, dtype_max\n\n    dtype = backend.standardize_dtype(dtype)\n    in_dtype = backend.standardize_dtype(x.dtype)\n    in_min, in_max = get_dtype_min_max(in_dtype)\n    out_min, out_max = get_dtype_min_max(dtype)\n\n    # The output min/max may not actually be representable in the\n    # in_dtype (e.g. casting float32 to uint32).  This can lead to undefined\n    # behavior when trying to cast a value outside the valid range of the\n    # target type. We work around this by nudging the min/max to fall within\n    # the valid output range. The catch is that we may actually saturate\n    # to a value less than the true saturation limit, but this is the best we\n    # can do in order to avoid UB without backend op.\n    min_limit = np.maximum(in_min, out_min).astype(in_dtype)\n    if min_limit < out_min:\n        min_limit = np.nextafter(min_limit, 0, dtype=in_dtype)\n    max_limit = np.minimum(in_max, out_max).astype(in_dtype)\n    if max_limit > out_max:\n        max_limit = np.nextafter(max_limit, 0, dtype=in_dtype)\n\n    # Unconditionally apply `clip` to fix `inf` behavior.\n    x = backend_module.numpy.clip(x, min_limit, max_limit)\n\n    return backend_module.cast(x, dtype)\n\n\nclass ConvertToTensor(Operation):\n    def __init__(self, dtype=None, sparse=None, ragged=None, *, name=None):\n        super().__init__(name=name)\n        self.dtype = dtype\n        self.sparse = sparse\n        self.ragged = ragged\n\n    def call(self, x):\n        return backend.core.convert_to_tensor(\n            x, dtype=self.dtype, sparse=self.sparse, ragged=self.ragged\n        )\n\n    def compute_output_spec(self, x):\n        dtype = (\n            backend.standardize_dtype(x.dtype)\n            if self.dtype is None\n            else 
self.dtype
        )
        sparse = (
            False if self.sparse is not None and not self.sparse else x.sparse
        )
        ragged = (
            False if self.ragged is not None and not self.ragged else x.ragged
        )
        return backend.KerasTensor(
            shape=x.shape, dtype=dtype, sparse=sparse, ragged=ragged
        )


@keras_export("keras.ops.convert_to_tensor")
def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
    """Convert a NumPy array or Python array to a tensor.

    Native tensors for the current backend are left unchanged unless the
    `dtype`, `sparse` or `ragged` arguments are set.

    Args:
        x: A NumPy array, Python array (can be nested) or a backend tensor.
        dtype: The target type. If `None`, the type of `x` is used.
        sparse: Whether to keep sparse tensors. `False` will cause sparse
            tensors to be densified. The default value of `None` means that
            sparse tensors are kept only if the backend supports them.
        ragged: Whether to keep ragged tensors. `False` will cause ragged
            tensors to be densified. 
The default value of `None` means that\n            ragged tensors are kept only if the backend supports them.\n\n    Returns:\n        A backend tensor of the specified `dtype` and sparseness.\n\n    Example:\n\n    >>> x = np.array([1, 2, 3])\n    >>> y = keras.ops.convert_to_tensor(x)\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x,)):\n        return ConvertToTensor(dtype=dtype, sparse=sparse, ragged=ragged)(x)\n    return backend.core.convert_to_tensor(\n        x, dtype=dtype, sparse=sparse, ragged=ragged\n    )\n\n\n@keras_export(\"keras.ops.convert_to_numpy\")\ndef convert_to_numpy(x):\n    \"\"\"Convert a tensor to a NumPy array.\n\n    Args:\n        x: A tensor.\n\n    Returns:\n        A NumPy array.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        # This will raise a `ValueError` defined in the `KerasTensor` class.\n        # We trigger it rather than duplicate it here.\n        return np.array(x)\n    return backend.convert_to_numpy(x)\n\n\nclass Cond(Operation):\n    @traceback_utils.filter_traceback\n    def __call__(self, *args, **kwargs):\n        def call_fn(*args, **kwargs):\n            if any_symbolic_tensors(args, kwargs):\n                return self.symbolic_call(*args, **kwargs)\n            else:\n                return self.call(*args, **kwargs)\n\n        if traceback_utils.is_traceback_filtering_enabled():\n            # Wrap self.call to provide helpful info in case of exception\n            call_fn = traceback_utils.inject_argument_info_in_traceback(\n                call_fn,\n                object_name=(f\"{self.__class__.__name__}.call()\"),\n            )\n            return call_fn(*args, **kwargs)\n\n        # Plain flow.\n        return call_fn(*args, **kwargs)\n\n    def call(self, pred, true_fn, false_fn):\n        return backend.core.cond(pred, true_fn, false_fn)\n\n    def compute_output_spec(self, pred, true_fn, false_fn):\n        true_fn_spec = 
backend.compute_output_spec(true_fn)\n        false_fn_spec = backend.compute_output_spec(false_fn)\n        if not self._check_output_spec(true_fn_spec, false_fn_spec):\n            raise ValueError(\n                \"`true_fn` and `false_fn` should return outputs \"\n                \"of the same kind (struct, dtype and shape). \"\n                f\"Got {true_fn_spec} and {false_fn_spec} instead.\"\n            )\n        return true_fn_spec\n\n    def _check_output_spec(self, true_fn_spec, false_fn_spec):\n        try:\n            tree.assert_same_structure(true_fn_spec, false_fn_spec)\n        except Exception:\n            return False\n\n        def check_leaf(t_spec, f_spec):\n            if t_spec is None or f_spec is None:\n                return t_spec is None and f_spec is None\n            return t_spec.shape == f_spec.shape and t_spec.dtype == f_spec.dtype\n\n        same = tree.map_structure(check_leaf, true_fn_spec, false_fn_spec)\n        return all(tree.flatten(same))\n\n\n@keras_export(\"keras.ops.cond\")\ndef cond(pred, true_fn, false_fn):\n    \"\"\"Conditionally applies `true_fn` or `false_fn`.\n\n    Args:\n        pred: Boolean scalar deciding which callable to run.\n        true_fn: Callable returning the output for the `pred == True` case.\n        false_fn: Callable returning the output for the `pred == False` case.\n\n    Returns:\n        The output of either `true_fn` or `false_fn` depending on `pred`.\n    \"\"\"\n    return Cond()(pred, true_fn, false_fn)\n\n\nclass VectorizedMap(Operation):\n    def __init__(self, function, *, name=None):\n        super().__init__(name=name)\n        self.function = function\n\n    def call(self, elements):\n        return backend.core.vectorized_map(self.function, elements)\n\n    def compute_output_spec(self, elements):\n        x = tree.map_structure(lambda t: t[0], elements)\n        n = tree.flatten(elements)[0].shape[0]\n        y = backend.compute_output_spec(self.function, x)\n\n        def append_batch_axis(t):\n            
return KerasTensor(\n                shape=(n,) + t.shape,\n                dtype=t.dtype,\n                sparse=t.sparse,\n                ragged=t.ragged,\n            )\n\n        y = tree.map_structure(append_batch_axis, y)\n        return y\n\n    def get_config(self):\n        config = super().get_config()\n        config.update({\"function\": self.function})\n        return config\n\n    @classmethod\n    def from_config(cls, config):\n        config = config.copy()\n        config[\"function\"] = serialization_lib.deserialize_keras_object(\n            config[\"function\"]\n        )\n        return cls(**config)\n\n\n@keras_export(\"keras.ops.vectorized_map\")\ndef vectorized_map(function, elements):\n    \"\"\"Parallel map of `function` on axis 0 of tensor(s) `elements`.\n\n    Schematically, `vectorized_map` implements the following,\n    in the case of a single tensor input `elements`:\n\n    ```python\n    def vectorized_map(function, elements):\n        outputs = []\n        for e in elements:\n            outputs.append(function(e))\n        return np.stack(outputs)\n    ```\n\n    In the case of an iterable of tensors `elements`,\n    it implements the following:\n\n    ```python\n    def vectorized_map(function, elements):\n        batch_size = elements[0].shape[0]\n        outputs = []\n        for index in range(batch_size):\n            outputs.append(function([e[index] for e in elements]))\n        return np.stack(outputs)\n    ```\n\n    In this case, `function` is expected to take as input\n    a single list of tensor arguments.\n    \"\"\"\n    if any_symbolic_tensors((elements,)):\n        return VectorizedMap(function)(elements)\n    return backend.core.vectorized_map(function, elements)\n\n\n@keras_export(\"keras.ops.is_tensor\")\ndef is_tensor(x):\n    \"\"\"Check whether the given object is a tensor.\n\n    Note: This checks for backend-specific tensors, so passing a TensorFlow\n    tensor would return `False` if your backend is 
PyTorch or JAX.\n\n    Args:\n        x: A variable.\n\n    Returns:\n        `True` if `x` is a tensor, otherwise `False`.\n    \"\"\"\n    return backend.core.is_tensor(x)\n\n\n@keras_export(\"keras.ops.custom_gradient\")\ndef custom_gradient(f):\n    \"\"\"Decorator to define a function with a custom gradient.\n\n    This decorator allows fine-grained control over the gradients of a sequence\n    of operations. This may be useful for multiple reasons, including providing\n    a more efficient or numerically stable gradient for a sequence of\n    operations.\n\n    Args:\n        f: Function `f(*args)` that returns a tuple\n            `(output, grad_fn)`, where:\n            - `args` is a sequence of (nested structures of) tensor inputs to\n                the function.\n            - `output` is a (nested structure of) tensor outputs of applying\n                operations in `f` to `args`.\n            - `grad_fn` is a function with the signature `grad_fn(*args,\n                upstream)` which returns a tuple of tensors the same size as\n                (flattened) `args`: the derivatives of tensors in `output` with\n                respect to the tensors in `args`. `upstream` is a tensor or\n                sequence of tensors holding the initial value gradients for each\n                tensor in `output`.\n\n    Returns:\n        A function `h(*args)` which returns the same value as\n        `f(*args)[0]` and whose gradient is determined by\n        `f(*args)[1]`.\n\n\n    Examples:\n\n    1. 
Backend-agnostic example.\n\n    ```python\n    @ops.custom_gradient\n    def log1pexp(x):\n        e = ops.exp(x)\n\n        def grad(*args, upstream=None):\n            if upstream is None:\n                (upstream,) = args\n            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))\n\n        return ops.log(1 + e), grad\n    ```\n\n    Note that the signature expected of the `grad` function depends on the\n    backend. With the JAX and TensorFlow backends, `grad(upstream)` with a\n    single positional argument is sufficient. With PyTorch, `grad` must\n    accept `*args` as well as the `upstream` keyword argument, e.g.\n    `def grad(*args, upstream)`. The example above accepts both, which\n    makes it compatible with all backends.\n\n    2. Here's a JAX & TensorFlow-specific example:\n\n    ```python\n    @ops.custom_gradient\n    def log1pexp(x):\n        e = ops.exp(x)\n        def grad(upstream):\n            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))\n        return ops.log(1 + e), grad\n    ```\n\n    3. Lastly, here's a PyTorch-specific example,\n    using `*args` & `upstream`:\n\n    ```python\n    @ops.custom_gradient\n    def log1pexp(x):\n        e = ops.exp(x)\n        def grad(*args, upstream):\n            return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))\n        return ops.log(1 + e), grad\n    ```\n    \"\"\"\n    return backend.core.custom_gradient(f)\n"
  },
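The min/max nudging performed by `saturate_cast` in `keras/src/ops/core.py` above can be illustrated with plain NumPy. This is a standalone sketch, not the Keras implementation: `saturate_cast_sketch` is a hypothetical name, `np.iinfo`/`np.finfo` stand in for `ml_dtypes`, and backend dispatch is omitted.

```python
import numpy as np


def saturate_cast_sketch(x, dtype):
    """Hypothetical, NumPy-only illustration of saturating cast."""

    def min_max(dt):
        dt = np.dtype(dt)
        info = np.iinfo(dt) if np.issubdtype(dt, np.integer) else np.finfo(dt)
        return info.min, info.max

    in_dtype = x.dtype
    in_min, in_max = min_max(in_dtype)
    out_min, out_max = min_max(dtype)

    # Clamp the saturation limits into the input dtype's range, then nudge
    # a limit toward zero if rounding pushed it outside the output range
    # (e.g. float32 cannot represent int32's max exactly).
    min_limit = np.maximum(in_min, out_min).astype(in_dtype)
    if min_limit < out_min:
        min_limit = np.nextafter(min_limit, 0, dtype=in_dtype)
    max_limit = np.minimum(in_max, out_max).astype(in_dtype)
    if max_limit > out_max:
        max_limit = np.nextafter(max_limit, 0, dtype=in_dtype)

    # Clipping before the cast also maps +/-inf onto the saturation limits.
    return np.clip(x, min_limit, max_limit).astype(dtype)
```

The real implementation additionally relies on `ml_dtypes` so that sub-byte and custom float types are covered, and dispatches `clip` and `cast` to the active backend.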
  {
    "path": "keras/src/ops/core_test.py",
    "content": "import operator\nfrom unittest.mock import Mock\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import losses\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src import testing\nfrom keras.src import tree\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.layers.core import input_layer\nfrom keras.src.ops import core\nfrom keras.src.saving import object_registration\nfrom keras.src.testing.test_utils import named_product\n\n\nclass CoreOpsDynamicShapeTest(testing.TestCase):\n    def test_associative_scan(self):\n        xs = (KerasTensor((5, None)), KerasTensor((5, None)))\n        ys = core.associative_scan(\n            f=lambda x, y: (x[0] + y[0], x[1] + y[1]), elems=xs, axis=0\n        )\n        self.assertEqual(ys[0].shape, (5, None))\n\n        # sum two tuples of unknown (but same) length at axis\n        def _fn(x, y):\n            return tuple([x[i] + y[i] for i in range(len(x))])\n\n        ys = core.associative_scan(f=_fn, elems=xs, axis=1)\n        self.assertEqual(ys[0].shape, (5, None))\n\n    def test_cast(self):\n        x = KerasTensor((3, 5, None), dtype=\"float32\")\n        self.assertEqual(core.cast(x, \"float16\").shape, (3, 5, None))\n\n    def test_convert_to_tensor(self):\n        x = KerasTensor((2, None))\n        self.assertEqual(core.convert_to_tensor(x).shape, (2, None))\n\n    def test_fori_loop(self):\n        def body_fun(i, x):\n            return x + i\n\n        initial_value = KerasTensor((3, 5, None))\n        self.assertEqual(\n            core.fori_loop(0, 10, body_fun, initial_value).shape, (3, 5, None)\n        )\n\n    def test_map(self):\n        def f(x):\n            return x**2\n\n        xs = KerasTensor((None, 5))\n        self.assertEqual(core.map(f, xs).shape, (None, 
5))\n\n        # Test nested output\n        def f2(x):\n            return {\"a\": x**2, \"b\": x * 10}\n\n        xs = KerasTensor((None, 5))\n        ys = core.map(f2, xs)\n        self.assertEqual(ys[\"a\"].shape, (None, 5))\n        self.assertEqual(ys[\"b\"].shape, (None, 5))\n\n        # Test nested input\n        def f3(x):\n            return x[0] + x[1]\n\n        xs = (KerasTensor((None, 5)), KerasTensor((None, 5)))\n        self.assertEqual(core.map(f3, xs).shape, (None, 5))\n\n    def test_saturate_cast(self):\n        x = KerasTensor((3, 5, None), dtype=\"float32\")\n        self.assertEqual(core.saturate_cast(x, \"float16\").shape, (3, 5, None))\n\n    def test_scan(self):\n        def f(carry, xs):\n            xs = xs + carry\n            return carry, carry\n\n        init = KerasTensor((None,))\n        xs = KerasTensor((6, None))\n        carry, result = core.scan(f, init, xs)\n        self.assertEqual(carry.shape, (None,))\n        self.assertEqual(result.shape, (6, None))\n\n        def f2(carry, _):\n            return carry, carry\n\n        carry, result = core.scan(f2, init, xs=None, length=3)\n        self.assertEqual(carry.shape, (None,))\n        self.assertEqual(result.shape, (3, None))\n\n    # Scatter doesn't support dynamic shape.\n\n    def test_scatter_update(self):\n        inputs = KerasTensor((4, None))\n        indices = KerasTensor((5, 2))\n        updates = KerasTensor((5,))\n        self.assertEqual(\n            core.scatter_update(inputs, indices, updates).shape, (4, None)\n        )\n\n    # Slice doesn't support dynamic shape.\n\n    def test_slice_update(self):\n        inputs = KerasTensor((4, None))\n        start_indices = KerasTensor((2,))\n        updates = KerasTensor((2, 2))\n        self.assertEqual(\n            core.slice_update(inputs, start_indices, updates).shape, (4, None)\n        )\n\n    def test_stop_gradient(self):\n        variable = KerasTensor(shape=(3, None), dtype=\"float32\")\n        
self.assertEqual(core.stop_gradient(variable).shape, (3, None))\n\n    def test_switch(self):\n        def fn(x, y):\n            return x[:, 0], y[0, :]\n\n        index = KerasTensor(())\n        x = KerasTensor((None, 2))\n        y = KerasTensor((5, None))\n        result = core.switch(index, [fn], x, y)\n        self.assertEqual(result[0].shape, (None,))\n        self.assertEqual(result[1].shape, (None,))\n\n    def test_vectorized_map(self):\n        def f(x):\n            return x**2\n\n        xs = KerasTensor((None, 5))\n        self.assertEqual(core.vectorized_map(f, xs).shape, (None, 5))\n\n        # Test nested output\n        def f2(x):\n            return {\"a\": x**2, \"b\": x * 10}\n\n        xs = KerasTensor((None, 5))\n        ys = core.vectorized_map(f2, xs)\n        self.assertEqual(ys[\"a\"].shape, (None, 5))\n        self.assertEqual(ys[\"b\"].shape, (None, 5))\n\n        # Test nested input\n        def f3(x):\n            return x[0] + x[1]\n\n        xs = (KerasTensor((None, 5)), KerasTensor((None, 5)))\n        self.assertEqual(core.vectorized_map(f3, xs).shape, (None, 5))\n\n    def test_while_loop(self):\n        def cond(args):\n            return tree.flatten(args)[0] < 10\n\n        def body(args):\n            return tree.map_structure(lambda x: x + 1, args)\n\n        loop_vars = KerasTensor((None,))\n        self.assertEqual(core.while_loop(cond, body, loop_vars).shape, (None,))\n\n    def test_unstack(self):\n        x = KerasTensor((2, None, None))\n        axis, num = 1, 3\n        out = core.unstack(x, num=num, axis=axis)\n        self.assertEqual(len(out), 3)\n        for o in out:\n            self.assertEqual(o.shape, (2, None))\n\n\nclass CoreOpsStaticShapeTest(testing.TestCase):\n    def test_associative_scan(self):\n        xs = (KerasTensor((5, 10)), KerasTensor((5, 10)))\n        ys = core.associative_scan(\n            f=lambda x, y: (x[0] + y[0], x[1] + y[1]), elems=xs, axis=0\n        )\n        
self.assertEqual(ys[0].shape, (5, 10))\n\n        # sum two tuples of unknown (but same) length at axis\n        def _fn(x, y):\n            return tuple([x[i] + y[i] for i in range(len(x))])\n\n        ys = core.associative_scan(f=_fn, elems=xs, axis=1)\n        self.assertEqual(ys[0].shape, (5, 10))\n\n    def test_cast(self):\n        x = KerasTensor((3, 5, 7), dtype=\"float32\")\n        self.assertEqual(core.cast(x, \"float16\").shape, (3, 5, 7))\n\n    def test_cond(self):\n        pred = KerasTensor((), dtype=\"bool\")\n        self.assertEqual(\n            ops.cond(\n                pred, lambda: ops.ones((1, 3)), lambda: ops.zeros((1, 3))\n            ).shape,\n            (1, 3),\n        )\n\n    def test_convert_to_tensor(self):\n        x = KerasTensor((2, 3))\n        out = core.convert_to_tensor(x)\n        self.assertEqual(out.shape, x.shape)\n        self.assertFalse(out.sparse)\n\n        out = core.convert_to_tensor(x, sparse=True)\n        self.assertFalse(out.sparse)\n\n        x = KerasTensor((2, 3), sparse=True)\n        out = core.convert_to_tensor(x)\n        self.assertTrue(out.sparse)\n\n        out = core.convert_to_tensor(x, sparse=True)\n        self.assertTrue(out.sparse)\n\n        out = core.convert_to_tensor(x, sparse=False)\n        self.assertFalse(out.sparse)\n\n    def test_fori_loop(self):\n        def body_fun(i, x):\n            return x + i\n\n        initial_value = KerasTensor((3, 5, 7))\n        result = core.fori_loop(0, 10, body_fun, initial_value)\n        self.assertEqual(result.shape, (3, 5, 7))\n\n    def test_map(self):\n        def f(x):\n            return x**2\n\n        xs = KerasTensor((6, 5))\n        ys = core.map(f, xs)\n        self.assertEqual(ys.shape, (6, 5))\n\n        # Test nested output\n        def f2(x):\n            return {\"a\": x**2, \"b\": x * 10}\n\n        xs = KerasTensor((6, 5))\n        ys = core.map(f2, xs)\n        self.assertEqual(ys[\"a\"].shape, (6, 5))\n        
self.assertEqual(ys[\"b\"].shape, (6, 5))\n\n        # Test nested input\n        def f3(x):\n            return x[0] + x[1]\n\n        xs = (KerasTensor((6, 5)), KerasTensor((6, 5)))\n        self.assertEqual(core.map(f3, xs).shape, (6, 5))\n\n    def test_saturate_cast(self):\n        x = KerasTensor((3, 5, 7), dtype=\"float32\")\n        self.assertEqual(core.saturate_cast(x, \"float16\").shape, (3, 5, 7))\n\n    def test_scan(self):\n        def f(carry, xs):\n            xs = xs + carry\n            return carry, carry\n\n        init = KerasTensor(())\n        xs = KerasTensor((6,))\n        carry, result = core.scan(f, init, xs)\n        self.assertEqual(carry.shape, ())\n        self.assertEqual(result.shape, (6,))\n\n        def f2(carry, _):\n            return carry, carry\n\n        carry, result = core.scan(f2, init, xs=None, length=3)\n        self.assertEqual(carry.shape, ())\n        self.assertEqual(result.shape, (3,))\n\n    def test_scatter(self):\n        indices = KerasTensor((5, 2))\n        values = KerasTensor((5,))\n        shape = (4, 4)\n        self.assertEqual(core.scatter(indices, values, shape).shape, (4, 4))\n\n    def test_scatter_update(self):\n        inputs = KerasTensor((4, 4))\n        indices = KerasTensor((5, 2))\n        updates = KerasTensor((5,))\n        self.assertEqual(\n            core.scatter_update(inputs, indices, updates).shape, (4, 4)\n        )\n\n        inputs = KerasTensor((4, 4, 4))\n        indices = KerasTensor((5, 2))\n        updates = KerasTensor((5, 4))\n        self.assertEqual(\n            core.scatter_update(inputs, indices, updates).shape, (4, 4, 4)\n        )\n\n    def test_slice(self):\n        inputs = KerasTensor(shape=(3, 3), dtype=\"float32\")\n        start_indices = KerasTensor(shape=(2,), dtype=\"int32\")\n        shape = (2, 2)\n        self.assertEqual(core.slice(inputs, start_indices, shape).shape, (2, 2))\n\n    def test_slice_negative_one_shape(self):\n        inputs = 
KerasTensor(shape=(3, 3), dtype=\"float32\")\n        start_indices = (1, 1)\n        shape = (-1, -1)\n        self.assertEqual(core.slice(inputs, start_indices, shape).shape, (2, 2))\n\n    def test_slice_negative_one_shape_tensor_indices(self):\n        inputs = KerasTensor(shape=(3, 3), dtype=\"float32\")\n        start_indices = KerasTensor(shape=(2,), dtype=\"int32\")\n        shape = (-1, -1)\n        self.assertEqual(\n            core.slice(inputs, start_indices, shape).shape, (None, None)\n        )\n\n    def test_slice_negative_one_shape_dynamic_input_shape(self):\n        inputs = KerasTensor(shape=(None, 3), dtype=\"float32\")\n        start_indices = (1, 1)\n        shape = (-1, -1)\n        self.assertEqual(\n            core.slice(inputs, start_indices, shape).shape, (None, 2)\n        )\n\n    def test_slice_invalid_inputs(self):\n        inputs = KerasTensor(shape=(3, 3), dtype=\"float32\")\n        start_indices = (1, 1)\n        shape = (2, 2, 2)\n        with self.assertRaisesRegex(\n            ValueError,\n            \"dimensions in `inputs` must match.* dimensions in `shape`\",\n        ):\n            core.slice(inputs, start_indices, shape)\n\n        start_indices = (1, 1, 1)\n        shape = (2, 2)\n        with self.assertRaisesRegex(\n            ValueError,\n            \"dimensions in `start_indices` must match.* dimensions in `inputs`\",\n        ):\n            core.slice(inputs, start_indices, shape)\n\n    def test_slice_update(self):\n        inputs = KerasTensor((4, 4))\n        start_indices = KerasTensor((2,))\n        updates = KerasTensor((2, 2))\n        self.assertEqual(\n            core.slice_update(inputs, start_indices, updates).shape, (4, 4)\n        )\n\n        inputs = KerasTensor((4, 4, 4))\n        start_indices = KerasTensor((3,))\n        updates = KerasTensor((2, 2, 2))\n        self.assertEqual(\n            core.slice_update(inputs, start_indices, updates).shape, (4, 4, 4)\n        )\n\n    def 
test_stop_gradient(self):\n        variable = KerasTensor(shape=(3, 3), dtype=\"float32\")\n        self.assertEqual(core.stop_gradient(variable).shape, (3, 3))\n\n    def test_switch(self):\n        def fn(x, y):\n            return x[:, 0], y[0, :]\n\n        index = KerasTensor(())\n        x = KerasTensor((5, 2))\n        y = KerasTensor((5, 2))\n        self.assertEqual(core.switch(index, [fn], x, y)[0].shape, (5,))\n        self.assertEqual(core.switch(index, [fn], x, y)[1].shape, (2,))\n\n    def test_vectorized_map(self):\n        def f(x):\n            return x**2\n\n        xs = KerasTensor((6, 5))\n        ys = core.vectorized_map(f, xs)\n        self.assertEqual(ys.shape, (6, 5))\n\n        # Test nested output\n        def f2(x):\n            return {\"a\": x**2, \"b\": x * 10}\n\n        xs = KerasTensor((6, 5))\n        ys = core.vectorized_map(f2, xs)\n        self.assertEqual(ys[\"a\"].shape, (6, 5))\n        self.assertEqual(ys[\"b\"].shape, (6, 5))\n\n        # Test nested input\n        def f3(x):\n            return x[0] + x[1]\n\n        xs = (KerasTensor((6, 5)), KerasTensor((6, 5)))\n        self.assertEqual(core.vectorized_map(f3, xs).shape, (6, 5))\n\n    def test_while_loop(self):\n        def cond(args):\n            return tree.flatten(args)[0] < 10\n\n        def body(args):\n            return tree.map_structure(lambda x: x + 1, args)\n\n        loop_vars = KerasTensor((10,))\n        self.assertEqual(core.while_loop(cond, body, loop_vars).shape, (10,))\n\n    def test_unstack(self):\n        x = KerasTensor((2, 3, 4))\n        axis = 1\n        out = core.unstack(x, axis=axis)\n        self.assertEqual(len(out), 3)\n        for o in out:\n            self.assertEqual(o.shape, (2, 4))\n\n\nclass CoreOpsCorrectnessTest(testing.TestCase):\n    def test_associative_scan(self):\n        # Test prefix sum\n        arr = np.arange(5)\n        result = core.associative_scan(f=operator.add, elems=arr)\n        self.assertAllEqual(result, [0, 
1, 3, 6, 10])\n        # Test reverse\n        result = core.associative_scan(f=operator.add, elems=arr, reverse=True)\n        self.assertAllEqual(result, [10, 10, 9, 7, 4])\n\n        # Test multiple dimensions, across different axes\n        batched_arr = np.stack([arr, arr + 1, arr + 2])\n        result = core.associative_scan(\n            f=operator.add, elems=batched_arr, axis=1\n        )\n        self.assertAllEqual(result[2], [2, 5, 9, 14, 20])\n        result = core.associative_scan(\n            f=operator.add, elems=batched_arr, axis=0\n        )\n        self.assertAllEqual(result[:, 0], [0, 1, 3])\n\n        # Test structured input\n        elems = {\n            \"a\": np.array([[0, 1, 2], [3, 4, 5]]),\n            \"b\": np.array([[6, 7, 8], [9, 10, 11]]),\n        }\n\n        def _dict_add(x, y):\n            return {\"a\": x[\"a\"] + y[\"b\"], \"b\": x[\"b\"] + y[\"b\"]}\n\n        ax0 = core.associative_scan(f=_dict_add, elems=elems, axis=0)\n        self.assertAllEqual(\n            ax0[\"b\"],\n            [[6, 7, 8], [15, 17, 19]],\n        )\n\n        # Test parallel scan op used in mamba\n        b, l, d, n = 1, 2, 3, 4\n        DB = np.random.rand(b, l, d, n)\n        DA = np.random.rand(b, l, d, n)\n\n        H_seq = np.zeros((b, d, n))\n        for i in range(l):\n            H_seq = DA[:, i] * H_seq + DB[:, i]\n\n        def scan_op(ci, cj):\n            a = cj[0] * ci[0]\n            b = cj[0] * ci[1] + cj[1]\n            return (a, b)\n\n        inputs = (DA.transpose(1, 0, 2, 3), DB.transpose(1, 0, 2, 3))\n        H_par = core.associative_scan(f=scan_op, elems=inputs)[-1][-1]\n\n        self.assertAllClose(H_seq, H_par)\n\n        # Test Operation call.\n        xs = np.arange(5, dtype=\"float32\")\n        self.assertAllClose(\n            core.AssociativeScan()(operator.add, xs), ops.cumsum(xs)\n        )\n\n    def test_cast(self):\n        x = ops.ones((2,), dtype=\"float32\")\n        y = ops.cast(x, \"float16\")\n        
self.assertIn(\"float16\", str(y.dtype))\n\n        x = ops.KerasTensor((2,), dtype=\"float32\")\n        y = ops.cast(x, \"float16\")\n        self.assertEqual(\"float16\", y.dtype)\n        self.assertEqual(x.shape, y.shape)\n        self.assertTrue(hasattr(y, \"_keras_history\"))\n\n        # Test Operation call.\n        x = ops.ones((2,), dtype=\"float32\")\n        self.assertDType(core.Cast(\"float16\")(x), \"float16\")\n\n    @parameterized.named_parameters(\n        (\"float8_e4m3fn\", \"float8_e4m3fn\"), (\"float8_e5m2\", \"float8_e5m2\")\n    )\n    def test_cast_float8(self, float8_dtype):\n        # Cast to float8 and cast back\n        x = ops.ones((2,), dtype=\"float32\")\n        y = ops.cast(x, float8_dtype)\n        self.assertIn(float8_dtype, str(y.dtype))\n        x = ops.cast(y, \"float32\")\n        self.assertIn(\"float32\", str(x.dtype))\n\n        x = ops.KerasTensor((2,), dtype=\"float32\")\n        y = ops.cast(x, float8_dtype)\n        self.assertEqual(float8_dtype, y.dtype)\n        self.assertEqual(x.shape, y.shape)\n        self.assertTrue(hasattr(y, \"_keras_history\"))\n        x = ops.cast(y, \"float32\")\n        self.assertEqual(\"float32\", x.dtype)\n        self.assertEqual(x.shape, y.shape)\n        self.assertTrue(hasattr(x, \"_keras_history\"))\n\n    def test_cond(self):\n        t = ops.cond(True, lambda: 0, lambda: 1)\n        self.assertEqual(t, 0)\n        f = ops.cond(False, lambda: 0, lambda: 1)\n        self.assertEqual(f, 1)\n        f = ops.cond(False, lambda: None, lambda: None)\n        self.assertEqual(f, None)\n\n        out = ops.cond(\n            ops.convert_to_tensor(True),\n            lambda: ops.ones((1, 3)),\n            lambda: ops.zeros((1, 3)),\n        )\n        self.assertAllClose(out, ops.ones((1, 3)))\n\n        out = ops.cond(\n            ops.convert_to_tensor(False),\n            lambda: ops.ones((3,)),\n            lambda: ops.zeros((3,)),\n        )\n        self.assertAllClose(out, 
ops.zeros((3,)))\n\n        with self.assertRaises(ValueError):\n            ops.cond(\n                KerasTensor((), dtype=\"bool\"),\n                lambda: ops.ones((3,)),\n                lambda: ops.zeros((4,)),\n            )\n\n    def test_convert_to_tensor(self):\n        x = np.ones((2,))\n        x = ops.convert_to_tensor(x)\n        x = ops.convert_to_numpy(x)\n        self.assertAllEqual(x, (1, 1))\n        self.assertIsInstance(x, np.ndarray)\n\n        # Empty lists should give an empty array.\n        x = ops.convert_to_tensor([])\n        np_x = ops.convert_to_numpy(x)\n        self.assertTrue(ops.is_tensor(x))\n        self.assertAllEqual(x, [])\n        self.assertIsInstance(np_x, np.ndarray)\n\n        # Partially converted.\n        x = ops.convert_to_tensor((1, ops.array(2), 3))\n        self.assertAllEqual(x, (1, 2, 3))\n\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=f\"{backend.backend()} backend doesn't support sparse tensors.\",\n    )\n    def test_convert_to_tensor_sparse(self):\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            x = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 3))\n        elif backend.backend() == \"jax\":\n            import jax.experimental.sparse as jax_sparse\n\n            x = jax_sparse.BCOO(([1.0, 2.0], [[0, 0], [1, 2]]), shape=(2, 3))\n        else:\n            self.fail(f\"Sparse is unsupported with backend {backend.backend()}\")\n\n        x_default = ops.convert_to_tensor(x)\n        self.assertSparse(x_default)\n        self.assertAllClose(x, x_default)\n        x_sparse = ops.convert_to_tensor(x, sparse=True)\n        self.assertSparse(x_sparse)\n        self.assertAllClose(x, x_sparse)\n        x_dense = ops.convert_to_tensor(x, sparse=False)\n        self.assertSparse(x_dense, False)\n        self.assertAllClose(x, x_dense)\n\n        x_numpy = ops.convert_to_numpy(x)\n        
self.assertIsInstance(x_numpy, np.ndarray)\n        self.assertAllClose(x_numpy, x_dense)\n\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_RAGGED_TENSORS,\n        reason=f\"{backend.backend()} backend doesn't support ragged tensors.\",\n    )\n    def test_convert_to_tensor_ragged(self):\n        import tensorflow as tf\n\n        x = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])\n\n        x_default = ops.convert_to_tensor(x)\n        self.assertIsInstance(x_default, tf.RaggedTensor)\n        self.assertAllClose(x, x_default)\n        x_ragged = ops.convert_to_tensor(x, ragged=True)\n        self.assertIsInstance(x_ragged, tf.RaggedTensor)\n        self.assertAllClose(x, x_ragged)\n        x_dense = ops.convert_to_tensor(x, ragged=False)\n        self.assertNotIsInstance(x_dense, tf.RaggedTensor)\n        self.assertAllClose(x, x_dense)\n\n        x_numpy = ops.convert_to_numpy(x)\n        self.assertIsInstance(x_numpy, np.ndarray)\n        self.assertAllClose(x_numpy, x_dense)\n\n    @pytest.mark.skipif(\n        backend.backend() not in (\"tensorflow\", \"jax\", \"torch\"),\n        reason=(\n            f\"{backend.backend()} backend doesn't support `custom_gradient`.\"\n        ),\n    )\n    @parameterized.named_parameters(named_product(use_variable=(False, True)))\n    def test_custom_gradient(self, use_variable):\n        # function to test custom_gradient on\n        @ops.custom_gradient\n        def log1pexp(x):\n            e = ops.exp(x)\n\n            def grad(*args, upstream=None):\n                if upstream is None:\n                    (upstream,) = args\n                return ops.multiply(upstream, 1.0 - 1.0 / ops.add(1, e))\n\n            return ops.log(1 + e), grad\n\n        def log1pexp_nan(x):\n            return ops.log(1 + ops.exp(x))\n\n        x = ops.convert_to_tensor(100.0)\n        if use_variable:\n\n            class Log1PExpLayer(layers.Layer):\n                def __init__(self):\n                    
super().__init__()\n                    self.v = backend.Variable(5.0, trainable=False)\n\n                def call(self, inputs):\n                    # The derivative of this layer is 1 with respect to inputs.\n                    # But on the side, we test passing a variable to a function\n                    # using @custom_gradient\n                    return log1pexp(self.v) + inputs\n\n            to_derive = Log1PExpLayer()\n        else:\n            to_derive = log1pexp\n\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            with tf.GradientTape() as tape1:\n                tape1.watch(x)\n                y = to_derive(x)\n            with tf.GradientTape() as tape2:\n                tape2.watch(x)\n                z = log1pexp_nan(x)\n            dy_dx = tape1.gradient(y, x)\n            dz_dx = tape2.gradient(z, x)\n            self.assertEqual(ops.convert_to_numpy(dy_dx), 1.0)\n        elif backend.backend() == \"jax\":\n            import jax\n\n            dy_dx = jax.grad(to_derive)(x)\n            dz_dx = jax.grad(log1pexp_nan)(x)\n            self.assertEqual(ops.convert_to_numpy(dy_dx), 1.0)\n            self.assertTrue(ops.isnan(dz_dx))\n        elif backend.backend() == \"torch\":\n            import torch\n\n            x = torch.tensor(100.0, requires_grad=True)\n            z = to_derive(x)\n            z.sum().backward()\n            self.assertEqual(ops.convert_to_numpy(x.grad), 1.0)\n\n    def test_dynamic_slice(self):\n        def cond(index, inputs, sum):\n            return index < 10\n\n        def body(index, inputs, sum):\n            sum = sum + core.slice(inputs, [index], [1])\n            index = index + 1\n            return index, inputs, sum\n\n        index, inputs, sum = 0, np.arange(10), np.array([0])\n        index, inputs, sum = core.while_loop(cond, body, (index, inputs, sum))\n        self.assertEqual(sum.shape, (1,))\n        self.assertAllClose(sum, [45])\n\n    def 
test_fori_loop(self):\n        def body_fun(i, x):\n            return x + i\n\n        initial_value = np.array(0)\n        result = core.fori_loop(0, 10, body_fun, initial_value)\n        self.assertAllClose(result, 45)\n\n        # Test Operation call.\n        self.assertAllClose(core.ForiLoop(0, 10, body_fun)(initial_value), 45)\n\n    def test_getitem(self):\n        np_tensor = np.arange(24).reshape(2, 3, 4)\n        tensor = ops.convert_to_tensor(np_tensor)\n\n        t = tensor[1]\n        n = np_tensor[1]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[1, 2, 3]\n        n = np_tensor[1, 2, 3]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[1:2]\n        n = np_tensor[1:2]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[1:2, 2:3, 3:4]\n        n = np_tensor[1:2, 2:3, 3:4]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[1:2, None]\n        n = np_tensor[1:2, None]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[1:2, 2:3, ...]\n        n = np_tensor[1:2, 2:3, ...]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[1:2, ..., 3:4]\n        n = np_tensor[1:2, ..., 3:4]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[None, ..., 3:4, None]\n        n = np_tensor[None, ..., 3:4, None]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[1:2:None]\n        n = np_tensor[1:2:None]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[:, 2]\n        n = np_tensor[:, 2]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[None]\n        n = np_tensor[None]\n        
self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[None, None]\n        n = np_tensor[None, None]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[...]\n        n = np_tensor[...]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[..., 1]\n        n = np_tensor[..., 1]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[..., 1, 2]\n        n = np_tensor[..., 1, 2]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[..., -1, 2]\n        n = np_tensor[..., -1, 2]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[..., -1:-2, 2]\n        n = np_tensor[..., -1:-2, 2]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[..., None, None]\n        n = np_tensor[..., None, None]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[None, ..., None]\n        n = np_tensor[None, ..., None]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[1, 2, None, ..., None]\n        n = np_tensor[1, 2, None, ..., None]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[None, ..., 1, 2]\n        n = np_tensor[None, ..., 1, 2]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[1, None, 2]\n        n = np_tensor[1, None, 2]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        index_tensor = ops.convert_to_tensor(np.array(1, dtype=np.int32))\n        t = tensor[index_tensor]\n        n = np_tensor[ops.convert_to_numpy(index_tensor)]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        
index_tensor = ops.convert_to_tensor(np.array(1, dtype=np.int32))\n        t = tensor[index_tensor, 2, None]\n        n = np_tensor[ops.convert_to_numpy(index_tensor), 2, None]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        index_tensor = ops.convert_to_tensor(np.array(-2, dtype=np.int32))\n        t = tensor[index_tensor, 1]\n        n = np_tensor[ops.convert_to_numpy(index_tensor), 1]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        index_tensor = ops.convert_to_tensor(np.array(-1, dtype=np.int32))\n        t = tensor[-2, index_tensor]\n        n = np_tensor[-2, ops.convert_to_numpy(index_tensor)]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        # Negative indexing\n        t = tensor[-1]\n        n = np_tensor[-1]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[1, -1, -2]\n        n = np_tensor[1, -1, -2]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        # Slicing with step\n        t = tensor[::2]\n        n = np_tensor[::2]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        # Mixed slices and integers\n        t = tensor[1, :, 1:4]\n        n = np_tensor[1, :, 1:4]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n        t = tensor[:, 1:2, 3]\n        n = np_tensor[:, 1:2, 3]\n        self.assertEqual(t.shape, n.shape)\n        self.assertAllClose(t, n)\n\n    def test_is_tensor(self):\n        np_x = np.array([[1, 2, 3], [3, 2, 1]])\n        x = backend.convert_to_tensor(np_x)\n        if backend.backend() != \"numpy\":\n            self.assertFalse(ops.is_tensor(np_x))\n        self.assertTrue(ops.is_tensor(x))\n        self.assertFalse(ops.is_tensor([1, 2, 3]))\n\n    def test_map(self):\n        def f(x):\n            return x**2\n\n        xs = np.arange(10)\n        
self.assertAllClose(ops.map(f, xs), xs**2)\n\n        # Test nested output\n        def f2(x):\n            return {\"a\": x**2, \"b\": x * 10}\n\n        xs = np.random.rand(2, 3, 4).astype(\"float32\")\n        outputs = ops.map(f2, xs)\n        self.assertAllClose(outputs[\"a\"], xs**2)\n        self.assertAllClose(outputs[\"b\"], xs * 10)\n\n        # Test with nested structures\n        def dict_input_fn(inputs):\n            x = inputs[\"x\"][:, 0]\n            y = inputs[\"y\"] + 1\n            return {\"x\": x, \"y\": y}\n\n        def list_input_fn(inputs):\n            return [x**2 for x in inputs]\n\n        xs = {\n            \"x\": ops.convert_to_tensor(\n                np.random.rand(4, 100, 3), dtype=\"float32\"\n            ),\n            \"y\": ops.convert_to_tensor(\n                np.random.randint(0, 10, size=(4, 1)), dtype=\"int32\"\n            ),\n        }\n        xs1 = [\n            ops.convert_to_tensor(np.random.rand(4, 100, 3), dtype=\"float32\"),\n            ops.convert_to_tensor(\n                np.random.randint(0, 10, size=(4, 1)), dtype=\"int32\"\n            ),\n        ]\n        ys = ops.map(dict_input_fn, xs)\n        self.assertEqual(ys[\"x\"].shape, (4, 100))\n        # Compare values element-wise; comparing `.all()` results would pass\n        # vacuously.\n        self.assertAllClose(ys[\"y\"], xs[\"y\"] + 1)\n        ys = ops.map(list_input_fn, xs1)\n        for x, y in zip(xs1, ys):\n            self.assertAllClose(\n                ops.convert_to_numpy(y), ops.convert_to_numpy(x) ** 2\n            )\n\n        # Test Operation call.\n        xs = np.arange(10)\n        self.assertAllClose(ops.Map()(f, xs), xs**2)\n\n    def test_saturate_cast(self):\n        x = ops.ones((2,), dtype=\"float32\")\n        y = ops.saturate_cast(x, \"float16\")\n        self.assertIn(\"float16\", str(y.dtype))\n\n        x = ops.KerasTensor((2,), dtype=\"float32\")\n        y = ops.saturate_cast(x, 
\"float16\")\n        self.assertEqual(\"float16\", y.dtype)\n        self.assertEqual(x.shape, y.shape)\n        self.assertTrue(hasattr(y, \"_keras_history\"))\n\n        # Test Operation call.\n        x = np.array([-256, 1.0, 257.0], dtype=\"float32\")\n        y = core.SaturateCast(\"uint8\")(x)\n        self.assertDType(y, \"uint8\")\n        # Check that the values are the same\n        self.assertAllClose(y, np.clip(x, 0, 255).astype(\"uint8\"))\n\n    def test_scan(self):\n        # Test cumsum\n        def cumsum(carry, xs):\n            carry = carry + xs\n            return carry, carry\n\n        init = np.array(0, dtype=\"float32\")\n        xs = np.array([1, 2, 3, 4, 10, 20], dtype=\"float32\")\n        carry, result = core.scan(cumsum, init, xs)\n        self.assertAllClose(carry, 40.0)\n        self.assertAllClose(result, ops.cumsum(xs))\n\n        # Test reverse=True\n        carry, result = core.scan(cumsum, init, xs, reverse=True)\n        self.assertAllClose(carry, 40.0)\n        self.assertAllClose(result, [40, 39, 37, 34, 30, 20])\n\n        # Test unroll\n        for unroll in (True, False, 2):\n            carry, result = core.scan(cumsum, init, xs, unroll=unroll)\n            self.assertAllClose(carry, 40.0)\n            self.assertAllClose(result, ops.cumsum(xs))\n\n        # Test xs is None\n        def fibonaccis(carry, _):\n            return (carry[1], carry[0] + carry[1]), None\n\n        init = (np.array(0, dtype=\"float32\"), np.array(1, dtype=\"float32\"))\n        carry, _ = core.scan(fibonaccis, init, length=6)\n        self.assertAllClose(carry, [8, 13])\n\n        # Test nested init\n        if backend.backend() != \"tensorflow\":\n            # tensorflow doesn't support arbitrary shape/dtype of the output of\n            # `f`. 
It must be the same as `init`.\n            def multiply_two(carry, _):\n                value1 = carry[\"value1\"]\n                value2 = carry[\"value2\"]\n                return (\n                    {\"value1\": value1 * 2, \"value2\": value2 * 2},\n                    value1 * 2 + value2 * 2,\n                )\n\n            init = {\"value1\": 2.0, \"value2\": 3.0}\n            carry, result = core.scan(multiply_two, init, length=3)\n            self.assertAllClose(carry[\"value1\"], 16)\n            self.assertAllClose(carry[\"value2\"], 24)\n            self.assertAllClose(result, [10, 20, 40])\n\n        # Test nested xs\n        def reduce_add(carry, xs):\n            value1 = xs[\"value1\"]\n            value2 = xs[\"value2\"]\n            return carry, value1 + value2\n\n        init = np.array(0, dtype=\"float32\")\n        xs = {\n            \"value1\": np.array([1, 2, 3], dtype=\"float32\"),\n            \"value2\": np.array([10, 20, 30], dtype=\"float32\"),\n        }\n        _, result = core.scan(reduce_add, init, xs)\n        self.assertAllClose(result, [11, 22, 33])\n\n        # Test Operation call.\n        init = np.array(0, dtype=\"float32\")\n        xs = np.array([1, 2, 3, 4, 10, 20], dtype=\"float32\")\n        carry, result = core.Scan()(cumsum, init, xs)\n        self.assertAllClose(carry, 40.0)\n        self.assertAllClose(result, ops.cumsum(xs))\n\n    def test_scatter(self):\n        # Test 1D\n        indices = np.array([[1], [3], [4], [7]])\n        values = np.array([9, 10, 11, 12])\n        self.assertAllClose(\n            core.scatter(indices, values, (8,)),\n            [0, 9, 0, 10, 11, 0, 0, 12],\n        )\n        # Test 2D\n        indices = np.array([[0, 1], [2, 0]])\n        values = np.array([5, 10])\n        self.assertAllClose(\n            core.scatter(indices, values, (3, 2)), [[0, 5], [0, 0], [10, 0]]\n        )\n        # Test 3D\n        indices = np.array([[1], [3]])\n        values = np.array(\n           
 [\n                [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],\n                [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],\n            ]\n        )\n        self.assertAllClose(\n            core.scatter(indices, values, (4, 4, 4)),\n            [\n                [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],\n                [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],\n                [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]],\n                [[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]],\n            ],\n        )\n        # Test slices\n        indices = np.array([[2], [4]])\n        values = np.array([[1, 2, 3], [4, 5, 6]])\n        self.assertAllClose(\n            core.scatter(indices, values, (6, 3)),\n            [[0, 0, 0], [0, 0, 0], [1, 2, 3], [0, 0, 0], [4, 5, 6], [0, 0, 0]],\n        )\n        # Duplicate indices\n        indices = np.array([[0], [0]])\n        values = np.array([1, 1])\n        self.assertAllClose(core.scatter(indices, values, (1,)), [2])\n\n        # Test Operation call.\n        indices = np.array([[1, 0], [0, 1]])\n        values = np.array([10, 20])\n        shape = (2, 2)\n        self.assertAllClose(\n            core.Scatter(shape)(indices, values), np.array([[0, 20], [10, 0]])\n        )\n\n    def test_scatter_update(self):\n        # Test 1D.\n        inputs = np.array([0, 0, 0, 0, 0, 0, 0, 0])\n        indices = [[1], [3], [4], [7]]\n        updates = np.array([9, 10, 11, 12])\n        self.assertAllClose(\n            core.scatter_update(inputs, indices, updates),\n            [0, 9, 0, 10, 11, 0, 0, 12],\n        )\n\n        # Test 2D.\n        inputs = np.array([[1, 1], [1, 1], [1, 1]])\n        indices = [[0, 1], [2, 0]]\n        updates = np.array([5, 10])\n        self.assertAllClose(\n            core.scatter_update(inputs, indices, updates),\n            [[1, 5], [1, 1], [10, 1]],\n        )\n\n        # Test updates has multiple 
dimension.\n        inputs = np.ones([4, 4, 4])\n        indices = [[1, 1], [2, 2]]\n        updates = np.array([[0, 1, 2, 3], [3, 2, 1, 0]], dtype=\"float32\")\n        outputs = core.scatter_update(inputs, indices, updates)\n        self.assertTrue(ops.is_tensor(outputs))\n        self.assertAllClose(outputs[1, 1, :], [0, 1, 2, 3])\n        self.assertAllClose(outputs[2, 2, :], [3, 2, 1, 0])\n\n        # Test Operation call.\n        inputs = np.array([[0, 0], [0, 0]])\n        indices = np.array([[1, 0], [0, 1]])\n        updates = np.array([10, 20])\n        self.assertAllClose(\n            core.ScatterUpdate()(inputs, indices, updates),\n            np.array([[0, 20], [10, 0]]),\n        )\n\n    def test_scatter_update_with_reduction(self):\n        # Test add reduction with duplicate indices\n        inputs = np.zeros((4,))\n        indices = [[0], [0], [1]]\n        updates = np.array([1.0, 2.0, 3.0])\n        result = core.scatter_update(inputs, indices, updates, reduction=\"add\")\n        self.assertAllClose(result, [3.0, 3.0, 0.0, 0.0])\n\n        # Test add reduction 2D\n        inputs = np.zeros((3, 3))\n        indices = [[0, 0], [1, 1], [0, 0]]\n        updates = np.array([1.0, 2.0, 3.0])\n        result = core.scatter_update(inputs, indices, updates, reduction=\"add\")\n        self.assertAllClose(result[0, 0], 4.0)\n        self.assertAllClose(result[1, 1], 2.0)\n\n        # Test max reduction with duplicates\n        inputs = np.zeros((4,))\n        indices = [[0], [0], [1]]\n        updates = np.array([3.0, 5.0, 2.0])\n        result = core.scatter_update(inputs, indices, updates, reduction=\"max\")\n        self.assertAllClose(result, [5.0, 2.0, 0.0, 0.0])\n\n        # Test min reduction\n        inputs = np.array([10.0, 10.0, 10.0, 10.0])\n        indices = [[0], [0], [1]]\n        updates = np.array([3.0, 5.0, 2.0])\n        result = core.scatter_update(inputs, indices, updates, reduction=\"min\")\n        self.assertAllClose(result, [3.0, 
2.0, 10.0, 10.0])\n\n        # Test mul reduction\n        inputs = np.array([2.0, 3.0, 1.0, 1.0])\n        indices = [[0], [0], [1]]\n        updates = np.array([3.0, 5.0, 2.0])\n        result = core.scatter_update(inputs, indices, updates, reduction=\"mul\")\n        self.assertAllClose(result, [30.0, 6.0, 1.0, 1.0])\n\n        # Test Operation call with reduction\n        inputs = np.zeros((4,))\n        indices = [[0], [0], [1]]\n        updates = np.array([1.0, 2.0, 3.0])\n        result = core.ScatterUpdate(reduction=\"add\")(inputs, indices, updates)\n        self.assertAllClose(result, [3.0, 3.0, 0.0, 0.0])\n\n    def test_shape(self):\n        x = ops.ones((2, 3, 7, 1))\n        self.assertEqual(core.shape(x).__class__, tuple)\n        self.assertAllEqual(core.shape(x), (2, 3, 7, 1))\n\n        x = KerasTensor((None, 3, None, 1))\n        self.assertEqual(core.shape(x).__class__, tuple)\n        self.assertAllEqual(core.shape(x), (None, 3, None, 1))\n\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=f\"{backend.backend()} backend doesn't support sparse tensors.\",\n    )\n    def test_shape_sparse(self):\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            x = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 3))\n        elif backend.backend() == \"jax\":\n            import jax.experimental.sparse as jax_sparse\n\n            x = jax_sparse.BCOO(([1.0, 2.0], [[0, 0], [1, 2]]), shape=(2, 3))\n        else:\n            self.fail(f\"Sparse is unsupported with backend {backend.backend()}\")\n\n        self.assertAllEqual(core.shape(x), (2, 3))\n\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_RAGGED_TENSORS,\n        reason=f\"{backend.backend()} backend doesn't support ragged tensors.\",\n    )\n    def test_shape_ragged(self):\n        import tensorflow as tf\n\n        x = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])\n        
self.assertAllEqual(core.shape(x), (5, None))\n\n        x = tf.RaggedTensor.from_row_lengths(tf.zeros([15, 2]), [4, 5, 6])\n        self.assertAllEqual(core.shape(x), (3, None, 2))\n\n    def test_slice(self):\n        # Test 1D.\n        inputs = np.arange(10)\n        start_indices = np.array([1])\n        shape = np.array([4])\n        self.assertAllClose(\n            core.slice(inputs, start_indices, shape),\n            [1, 2, 3, 4],\n        )\n\n        # Test 2D.\n        inputs = np.broadcast_to(np.arange(10), (4, 10))\n        start_indices = np.array([1, 1])\n        shape = np.array([2, 4])\n        self.assertAllClose(\n            core.slice(inputs, start_indices, shape),\n            [[1, 2, 3, 4], [1, 2, 3, 4]],\n        )\n\n        # Test N-D.\n        inputs = np.broadcast_to(np.arange(10), (4, 4, 4, 10))\n        start_indices = np.array([1, 1, 1, 1])\n        shape = np.array([1, 2, 3, 4])\n        outputs = core.slice(inputs, start_indices, shape)\n        expected = np.broadcast_to(np.arange(1, 5), (1, 2, 3, 4))\n        self.assertAllClose(outputs, expected)\n\n        # Test Operation call.\n        inputs = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n        start_indices = np.array([1, 1])\n        shape = (2, 2)\n        self.assertAllClose(\n            core.Slice(shape)(inputs, start_indices), np.array([[5, 6], [8, 9]])\n        )\n\n    def test_slice_update(self):\n        # Test 1D.\n        inputs = np.array([0, 0, 0, 0, 0, 0, 0, 0])\n        start_indices = np.array([1])\n        updates = np.array([9, 10, 11, 12])\n        self.assertAllClose(\n            core.slice_update(inputs, start_indices, updates),\n            [0, 9, 10, 11, 12, 0, 0, 0],\n        )\n\n        # Test 2D.\n        inputs = np.array([[1, 1], [1, 1], [1, 1]])\n        start_indices = [1, 0]\n        updates = np.array([[2, 2], [2, 2]])\n        self.assertAllClose(\n            core.slice_update(inputs, start_indices, updates),\n            [[1, 1], [2, 
2], [2, 2]],\n        )\n\n        # Test N-D.\n        inputs = np.ones([4, 4, 4, 4])\n        start_indices = [1, 1, 2, 2]\n        updates = np.zeros([2, 2, 2, 2])\n        outputs = core.slice_update(inputs, start_indices, updates)\n        self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2]))\n\n        # Test Operation call.\n        inputs = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])\n        start_indices = np.array([1, 1])\n        updates = np.array([[10, 11], [12, 13]])\n        self.assertAllClose(\n            core.SliceUpdate()(inputs, start_indices, updates),\n            np.array([[1, 2, 3], [4, 10, 11], [7, 12, 13]]),\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_stop_gradient(self):\n        class ExampleLayer(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.w = self.add_weight(shape=(1,), initializer=\"zeros\")\n                self.b = self.add_weight(shape=(1,), initializer=\"zeros\")\n\n            def call(self, x, training=False):\n                return ops.add(\n                    ops.multiply(x, ops.stop_gradient(self.w)), self.b\n                )\n\n        model = models.Sequential([ExampleLayer()])\n        model.compile(\n            optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()\n        )\n        rng = np.random.default_rng(0)\n        x = np.ones((2, 4), dtype=\"float32\")\n        y = rng.standard_normal((2, 4), dtype=\"float32\")\n        model.fit(x, y, epochs=1, batch_size=2)\n        self.assertEqual(model.layers[0].w.numpy(), 0.0)\n        self.assertNotEqual(model.layers[0].b.numpy(), 0.0)\n\n    def test_stop_gradient_no_fit(self):\n        x = ops.random.uniform(shape=(2, 4), dtype=\"float32\")\n        y = ops.stop_gradient(x)\n        self.assertAllClose(x, y)\n\n        # Functional.\n        a = layers.Input(shape=(2,))\n        b = layers.Dense(4, kernel_initializer=\"ones\", use_bias=False)(a)\n       
 c = layers.Dense(4, kernel_initializer=\"ones\", use_bias=False)(b)\n        d = ops.stop_gradient(b) + c\n        model = models.Model(inputs=a, outputs=d)\n        output = model(ops.convert_to_tensor([[1.0, 2.0]]))\n        self.assertAllClose(output, 15.0)\n\n        # Test Operation call.\n        variable = ops.convert_to_tensor(\n            np.array([1.0, 2.0, 3.0], dtype=\"float32\")\n        )\n        self.assertAllClose(core.StopGradient()(variable), variable)\n\n    def test_switch(self):\n        def fn1(x, y):\n            return x + y\n\n        def fn2(x, y):\n            return x - y\n\n        x = np.random.rand(2, 3, 4).astype(\"float32\")\n        y = np.random.rand(2, 3, 4).astype(\"float32\")\n        branches = [fn1, fn2]\n        self.assertAllClose(core.switch(0, branches, x, y), x + y)\n        self.assertAllClose(core.switch(1, branches, x, y), x - y)\n\n        # Test out-of-bound index\n        self.assertAllClose(core.switch(-100, branches, x, y), x + y)\n        self.assertAllClose(core.switch(100, branches, x, y), x - y)\n\n        # Test Operation call.\n        self.assertAllClose(core.Switch()(0, branches, x, y), x + y)\n        self.assertAllClose(core.Switch()(1, branches, x, y), x - y)\n\n    def test_vectorized_map(self):\n        def fn(x):\n            return x + 1\n\n        output = ops.vectorized_map(fn, ops.zeros((2, 3), dtype=\"float32\"))\n        self.assertAllClose(backend.convert_to_numpy(output), np.ones((2, 3)))\n\n        def fn(x):\n            return ops.stack([x, x])\n\n        output = ops.vectorized_map(fn, ops.zeros((2, 3), dtype=\"float32\"))\n        self.assertAllClose(\n            backend.convert_to_numpy(output), np.zeros((2, 2, 3))\n        )\n\n        # Case: multiple args\n        def fn(elems):\n            x, y = elems\n            return x + y\n\n        output = ops.vectorized_map(fn, [ops.ones((2, 3)), ops.ones((2, 3))])\n        self.assertAllClose(output, 2 * np.ones((2, 3)))\n\n    
@parameterized.named_parameters(\n        [\n            {\n                \"testcase_name\": \"scalar_data_with_max\",\n                \"loop_vars\": np.array(0),\n                \"expected_output\": np.array(5),\n                \"maximum_iterations\": 5,\n            },\n            {\n                \"testcase_name\": \"scalar_data_no_max\",\n                \"loop_vars\": np.array(0),\n                \"expected_output\": np.array(10),\n                \"maximum_iterations\": None,\n            },\n            {\n                \"testcase_name\": \"nested_data_with_max\",\n                \"loop_vars\": {\n                    \"a\": np.array(0),\n                    \"b\": (np.array(1), np.array(2)),\n                },\n                \"expected_output\": {\n                    \"a\": np.array(5),\n                    \"b\": (np.array(6), np.array(7)),\n                },\n                \"maximum_iterations\": 5,\n            },\n            {\n                \"testcase_name\": \"nested_data_no_max\",\n                \"loop_vars\": {\n                    \"a\": np.array(0),\n                    \"b\": (np.array(1), np.array(2)),\n                },\n                \"expected_output\": {\n                    \"a\": np.array(10),\n                    \"b\": (np.array(11), np.array(12)),\n                },\n                \"maximum_iterations\": None,\n            },\n        ]\n    )\n    def test_while_loop(self, loop_vars, expected_output, maximum_iterations):\n        def cond(args):\n            return tree.flatten(args)[0] < 10\n\n        def body(args):\n            return tree.map_structure(lambda x: x + 1, args)\n\n        output = core.while_loop(\n            cond, body, loop_vars, maximum_iterations=maximum_iterations\n        )\n        tree.map_structure(self.assertAllClose, output, expected_output)\n\n        # Test Operation call.\n        output = core.WhileLoop(\n            cond, body, maximum_iterations=maximum_iterations\n       
 )(loop_vars)\n        tree.map_structure(self.assertAllClose, output, expected_output)\n\n    @parameterized.named_parameters(\n        [\n            {\n                \"testcase_name\": \"with_max\",\n                \"state\": (np.array(0), np.array(1)),\n                \"output\": (np.array(5), np.array(6)),\n                \"maximum_iterations\": 5,\n            },\n            {\n                \"testcase_name\": \"no_max\",\n                \"state\": (np.array(0), np.array(1)),\n                \"output\": (np.array(10), np.array(11)),\n                \"maximum_iterations\": None,\n            },\n        ]\n    )\n    def test_while_loop_list_data(self, state, output, maximum_iterations):\n        def cond(*args):\n            return tree.flatten(args)[0] < 10\n\n        def body(*args):\n            return tree.map_structure(lambda x: x + 1, args)\n\n        state = core.while_loop(\n            cond, body, state, maximum_iterations=maximum_iterations\n        )\n        tree.map_structure(self.assertAllClose, state, output)\n\n    def test_unstack(self):\n        rng = np.random.default_rng(0)\n        x = rng.uniform(size=(2, 3, 4))\n        x_tensor = ops.convert_to_tensor(x)\n        axis = 1\n        out = ops.unstack(x_tensor, axis=axis)\n        out_ex = [x[:, i, :] for i in range(x.shape[axis])]\n        self.assertEqual(len(out), len(out_ex))\n        for o, o_e in zip(out, out_ex):\n            o = ops.convert_to_numpy(o)\n            self.assertAllClose(o, o_e)\n\n        # Test Operation call.\n        out = ops.Unstack(axis=axis)(x_tensor)\n        self.assertEqual(len(out), len(out_ex))\n        for o, o_e in zip(out, out_ex):\n            o = ops.convert_to_numpy(o)\n            self.assertAllClose(o, o_e)\n\n\nclass CoreOpsDtypeTest(testing.TestCase):\n    \"\"\"Test the dtype to verify that the behavior matches JAX.\"\"\"\n\n    ALL_DTYPES = [\n        x\n        for x in dtypes.ALLOWED_DTYPES\n        if x\n        not in (\n       
     \"string\",\n            \"complex64\",\n            \"complex128\",\n            # Remove 64-bit dtypes.\n            \"float64\",\n            \"uint64\",\n            \"int64\",\n        )\n        + dtypes.FLOAT8_TYPES  # Remove float8 dtypes for the following tests\n    ] + [None]\n    INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in (\"uint64\", \"int64\")]\n    FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in (\"float64\",)]\n\n    if backend.backend() == \"torch\":\n        ALL_DTYPES = [x for x in ALL_DTYPES if x not in (\"uint16\", \"uint32\")]\n        INT_DTYPES = [x for x in INT_DTYPES if x not in (\"uint16\", \"uint32\")]\n\n    @parameterized.named_parameters(\n        named_product(\n            dtype=[dtype for dtype in ALL_DTYPES if dtype is not None]\n        )\n    )\n    def test_cast(self, dtype):\n        x = np.ones((1,))\n\n        self.assertDType(core.cast(x, dtype), dtype)\n        self.assertDType(core.Cast(dtype).symbolic_call(x), dtype)\n\n    @parameterized.parameters(\n        ((), None, backend.floatx()),\n        ([], None, backend.floatx()),\n        (bool(0), None, \"bool\"),\n        (int(0), None, \"int32\"),\n        (float(0), None, backend.floatx()),\n        (1, \"bool\", \"bool\"),\n        (1.0, \"int32\", \"int32\"),\n        (1.0, \"float32\", \"float32\"),\n        ([False, True, False], None, \"bool\"),\n        ([1, 2, 3], None, \"int32\"),\n        ([1.0, 2.0, 3.0], None, backend.floatx()),\n        ([1, 2.0, 3], None, backend.floatx()),\n        ([[False], [True], [False]], None, \"bool\"),\n        ([[1], [2], [3]], None, \"int32\"),\n        ([[1], [2.0], [3]], None, backend.floatx()),\n        *[\n            (np.array(0, dtype=dtype), None, dtype)\n            for dtype in ALL_DTYPES\n            if dtype is not None\n        ],\n        *[\n            ([[1, 0, 1], [1, 1, 0]], dtype, dtype)\n            for dtype in ALL_DTYPES\n            if dtype is not None\n        ],\n    )\n    def 
test_convert_to_tensor(self, x, dtype, expected_dtype):\n        self.assertDType(ops.convert_to_tensor(x, dtype=dtype), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(\n            dtype=[dtype for dtype in ALL_DTYPES if dtype is not None]\n        )\n    )\n    def test_convert_to_tensor_with_tensor(self, dtype):\n        x = ops.convert_to_tensor(np.ones((2, 3), dtype=\"float32\"))\n\n        self.assertDType(ops.convert_to_tensor(x, dtype=dtype), dtype)\n\n    @parameterized.named_parameters(\n        named_product(\n            dtype=[dtype for dtype in ALL_DTYPES if dtype is not None]\n        )\n    )\n    def test_convert_to_tensor_with_variable(self, dtype):\n        x = backend.Variable(np.ones((2, 3), dtype=\"float32\"))\n\n        self.assertDType(ops.convert_to_tensor(x, dtype=dtype), dtype)\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_saturate_cast(self, dtype):\n        x = np.ones((1,))\n\n        self.assertDType(core.saturate_cast(x, dtype), dtype)\n        self.assertDType(core.SaturateCast(dtype).symbolic_call(x), dtype)\n\n\nclass CoreOpsBehaviorTests(testing.TestCase):\n    def test_associative_scan_invalid_arguments(self):\n        # varying dimension at scan axis\n        x = (np.array([1, 2]), np.array([3, 4]), np.array([5, 6, 7]))\n        with self.assertRaisesRegex(ValueError, \" first dimension\"):\n            core.associative_scan(lambda x, y: (x[0] + y[0], x[1] + y[1]), x)\n\n        # same error, symbolic\n        x = (\n            KerasTensor((None, 5)),\n            KerasTensor((None, 4)),\n        )\n        with self.assertRaisesRegex(ValueError, \" first dimension\"):\n            core.associative_scan(\n                lambda x, y: (x[0] + y[0], x[1] + y[1]), x, axis=1\n            )\n\n    def test_cond_check_output_spec(self):\n        mock_spec = Mock(dtype=\"float32\", shape=(2, 2))\n        mock_spec_different = Mock(dtype=\"int32\", shape=(3, 
3))\n\n        # List & tuple.\n        self.assertTrue(\n            core.Cond()._check_output_spec(\n                [mock_spec, mock_spec], [mock_spec, mock_spec]\n            )\n        )\n        self.assertTrue(\n            core.Cond()._check_output_spec([mock_spec], [mock_spec])\n        )\n        self.assertFalse(\n            core.Cond()._check_output_spec(\n                [mock_spec], [mock_spec, mock_spec_different]\n            )\n        )\n        self.assertTrue(\n            core.Cond()._check_output_spec((mock_spec,), (mock_spec,))\n        )\n        self.assertFalse(\n            core.Cond()._check_output_spec(\n                (mock_spec,), (mock_spec, mock_spec_different)\n            )\n        )\n\n        # Dict.\n        self.assertTrue(\n            core.Cond()._check_output_spec({\"a\": mock_spec}, {\"a\": mock_spec})\n        )\n        self.assertFalse(\n            core.Cond()._check_output_spec({\"a\": mock_spec}, {\"b\": mock_spec})\n        )\n        self.assertFalse(\n            core.Cond()._check_output_spec(\n                {\"a\": mock_spec}, {\"a\": mock_spec, \"b\": mock_spec}\n            )\n        )\n\n        # None.\n        self.assertTrue(core.Cond()._check_output_spec(None, None))\n        self.assertFalse(\n            core.Cond()._check_output_spec(\n                None, Mock(dtype=\"float32\", shape=(2, 2))\n            )\n        )\n        self.assertFalse(\n            core.Cond()._check_output_spec(\n                Mock(dtype=\"float32\", shape=(2, 2)), None\n            )\n        )\n\n        # KerasTensor.\n        mock_spec1 = KerasTensor(shape=(2, 2), dtype=\"float32\")\n        mock_spec2 = KerasTensor(shape=(2, 2), dtype=\"float32\")\n        self.assertTrue(core.Cond()._check_output_spec(mock_spec1, mock_spec2))\n\n    @pytest.mark.requires_trainable_backend\n    def test_cond_raw_bool_compile(self):\n        class ExampleLayer(layers.Layer):\n            def call(self, x, training=False):\n      
          return ops.cond(training, lambda: x, lambda: x * 2.0)\n\n        model = models.Sequential([ExampleLayer()])\n        model.compile(\n            optimizer=optimizers.SGD(), loss=losses.MeanSquaredError()\n        )\n        x = np.ones((2, 4), dtype=\"float32\")\n        y = np.zeros((2, 4), dtype=\"float32\")\n        model.evaluate(x, y, batch_size=2)\n\n    def test_convert_to_numpy(self):\n        x = ops.array([1, 2, 3], dtype=\"float32\")\n        y = ops.convert_to_numpy(x)\n        self.assertIsInstance(y, np.ndarray)\n        # Test assignment -- should not fail.\n        y[0] = 1.0\n\n        with self.assertRaises(ValueError):\n            ops.convert_to_numpy(KerasTensor((2,)))\n\n    def test_scan_invalid_arguments(self):\n        def cumsum(carry, xs):\n            carry = carry + xs\n            return carry, carry\n\n        init = np.array(0, dtype=\"float32\")\n        xs = np.array([1, 2, 3, 4, 10, 20], dtype=\"float32\")\n\n        # Test non-callable\n        with self.assertRaisesRegex(TypeError, \"should be a callable.\"):\n            core.scan(123, init, xs)\n\n        # Test bad unroll\n        with self.assertRaisesRegex(\n            ValueError, \"must be an positive integer or boolean.\"\n        ):\n            core.scan(cumsum, init, xs, unroll=-1)\n\n        # Test both xs and length are None\n        with self.assertRaisesRegex(ValueError, \"to scan over and\"):\n            core.scan(cumsum, init, xs=None, length=None)\n\n    def test_slice_compute_output_spec(self):\n        inputs = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=\"float32\")\n        start_indices = np.array([1, 1])\n        shape = (2, 2)\n        output_spec = core.Slice(shape).compute_output_spec(\n            inputs, start_indices\n        )\n        self.assertEqual(output_spec.shape, shape)\n        self.assertEqual(output_spec.dtype, inputs.dtype)\n\n    def test_stop_gradient_compute_output_spec(self):\n        variable = 
KerasTensor(shape=(3,), dtype=\"float32\")\n        stop_gradient = core.StopGradient()\n        output_spec = stop_gradient.compute_output_spec(variable)\n        self.assertEqual(output_spec.shape, variable.shape)\n        self.assertEqual(output_spec.dtype, variable.dtype)\n\n    def test_vectorized_map_serialization(self):\n        @object_registration.register_keras_serializable()\n        def f(x):\n            return x + x\n\n        inputs = input_layer.Input((10,), dtype=\"float32\")\n        outputs = core.vectorized_map(f, inputs)\n        model = models.Functional(inputs, outputs)\n        reloaded_model = model.from_config(model.get_config())\n        x = np.random.rand(5, 10).astype(\"float32\")\n        self.assertAllClose(model(x), reloaded_model(x))\n\n    def test_while_loop_output_spec(self):\n        # Define dummy cond and body functions\n        def cond(x):\n            return True\n\n        def body(x):\n            return (x,)\n\n        while_loop = core.WhileLoop(cond, body, maximum_iterations=None)\n        loop_vars = (KerasTensor(shape=(10,), dtype=\"float32\"),)\n        output_spec = while_loop.compute_output_spec(loop_vars)\n        self.assertEqual(output_spec[0].shape, loop_vars[0].shape)\n        self.assertEqual(output_spec[0].dtype, loop_vars[0].dtype)\n\n        # Test with KerasTensor.\n        loop_vars = (np.random.rand(5, 5), np.random.randint(10, size=(3, 7)))\n        keras_loop_vars = [\n            KerasTensor(v.shape, dtype=v.dtype) for v in loop_vars\n        ]\n        while_loop = core.WhileLoop(cond, body, maximum_iterations=None)\n        output_specs = while_loop.compute_output_spec(keras_loop_vars)\n        self.assertEqual(output_specs[0].shape, keras_loop_vars[0].shape)\n        self.assertEqual(output_specs[0].dtype, keras_loop_vars[0].dtype)\n        self.assertEqual(output_specs[1].shape, keras_loop_vars[1].shape)\n        self.assertEqual(output_specs[1].dtype, keras_loop_vars[1].dtype)\n\n    def 
test_unstack_unknown_axis_num(self):\n        x = KerasTensor((2, None, None))\n        axis = 1\n        with self.assertRaisesRegex(\n            ValueError, r\"Cannot infer argument `num` from shape\"\n        ):\n            core.unstack(x, axis=axis)\n"
  },
  {
    "path": "keras/src/ops/einops.py",
    "content": "import re\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend import any_symbolic_tensors\nfrom keras.src.ops.core import shape\nfrom keras.src.ops.numpy import prod\nfrom keras.src.ops.numpy import reshape\nfrom keras.src.ops.numpy import transpose\nfrom keras.src.ops.operation import Operation\n\n\ndef _create_axes_map(axes, input_shape, axes_lengths):\n    axes_map = {}\n\n    for axis, dim in zip(axes, input_shape):\n        # Check for grouped axes pattern, e.g., \"(h1 h)\"\n        grouped_axes = re.match(r\"\\(([\\w\\s]+)\\)\", axis)\n\n        if grouped_axes:\n            inner_axes = grouped_axes.group(1).split()\n            known_axes = [a for a in inner_axes if a in axes_lengths]\n            inferred_axes = [a for a in inner_axes if a not in axes_lengths]\n\n            if inferred_axes:\n                inferred_axis = inferred_axes[0]\n                known_product = prod([axes_lengths[a] for a in known_axes])\n                axes_lengths[inferred_axis] = dim // known_product\n\n            axes_map.update({a: axes_lengths[a] for a in inner_axes})\n        else:\n            axes_map[axis] = dim\n\n    return axes_map\n\n\ndef _create_grouped_axes(axes):\n    grouped_output_axes = []\n    for axis in axes:\n        grouped_axes = re.match(r\"\\(([\\w\\s]+)\\)\", axis)\n\n        if grouped_axes:\n            inner_axes = grouped_axes.group(1).split()\n            grouped_output_axes.append(inner_axes)\n        else:\n            grouped_output_axes.append([axis])\n\n    return grouped_output_axes\n\n\ndef _flatten_group(axes):\n    return [x for xs in axes for x in xs]\n\n\ndef _get_transpose_order(from_shape, to_shape):\n    flattened_from_shape = _flatten_group(_create_grouped_axes(from_shape))\n\n    return [flattened_from_shape.index(dim) for dim in to_shape]\n\n\ndef _compute_output_shape(axes_map, grouped_axes):\n    output_shape = []\n    for group in 
grouped_axes:\n        size = 1\n        for axis in group:\n            size *= axes_map[axis]\n        output_shape.append(size)\n\n    return tuple(output_shape)\n\n\ndef _compute_decomposed_shape(input_axes, axes_lengths, axes_map):\n    reshaped_input_axes = []\n    reshaped_sizes = []\n\n    for axis in input_axes:\n        if \"(\" in axis:  # Decomposed axis\n            inner_axes = re.findall(r\"\\w+\", axis)\n            sizes = [axes_lengths[a] for a in inner_axes]\n            reshaped_input_axes.extend(inner_axes)\n            reshaped_sizes.extend(sizes)\n        else:\n            reshaped_input_axes.append(axis)\n            reshaped_sizes.append(axes_map[axis])\n\n    return reshaped_sizes\n\n\nclass Rearrange(Operation):\n    def call(self, tensor, pattern, **axes_lengths):\n        return rearrange(tensor, pattern, **axes_lengths)\n\n    def compute_output_spec(self, tensor, pattern, **axes_lengths):\n        input_pattern, output_pattern = re.split(r\"\\s*->\\s*\", pattern)\n        input_axes = re.findall(r\"\\w+|\\(.*?\\)\", input_pattern)\n        output_axes = re.findall(r\"\\w+|\\(.*?\\)\", output_pattern)\n        input_shape = shape(tensor)\n\n        axes_map = _create_axes_map(input_axes, input_shape, axes_lengths)\n        grouped_output_axes = _create_grouped_axes(output_axes)\n        output_shape = _compute_output_shape(axes_map, grouped_output_axes)\n\n        return KerasTensor(shape=output_shape, dtype=tensor.dtype)\n\n\n@keras_export(\"keras.ops.rearrange\")\ndef rearrange(tensor, pattern, **axes_lengths):\n    \"\"\"Rearranges the axes of a Keras tensor according to a specified pattern,\n    einops-style.\n\n    Args:\n        tensor: Input Keras tensor.\n        pattern: String describing the rearrangement in einops notation.\n        **axes_lengths: Keyword arguments specifying lengths of axes\n            when axes decomposition is used.\n\n    Returns:\n        Tensor: A Keras tensor with rearranged axes.\n\n    Follows 
the logic of:\n\n    1. If decomposition is needed, reshape to match decomposed dimensions.\n    2. Permute known and inferred axes to match the form of the output.\n    3. Reshape to match the desired output shape.\n\n\n    Example Usage:\n\n    ```\n    >>> import numpy as np\n    >>> from keras.ops import rearrange\n    >>> images = np.random.rand(32, 30, 40, 3) # BHWC format\n\n    # Reordering to BCHW\n    >>> rearrange(images, 'b h w c -> b c h w').shape\n    TensorShape([32, 3, 30, 40])\n\n    # \"Merge\" along first axis - concat images from a batch\n    >>> rearrange(images, 'b h w c -> (b h) w c').shape\n    TensorShape([960, 40, 3])\n\n    # \"Merge\" along second axis - concat images horizontally\n    >>> rearrange(images, 'b h w c -> h (b w) c').shape\n    TensorShape([30, 1280, 3])\n\n    # Flatten images into a CHW vector\n    >>> rearrange(images, 'b h w c -> b (c h w)').shape\n    TensorShape([32, 3600])\n\n    # Decompose H and W axes into 4 smaller patches\n    >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape\n    TensorShape([128, 15, 20, 3])\n\n    # Space-to-depth decomposition of input axes\n    >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape\n    TensorShape([32, 15, 20, 12])\n    ```\n    \"\"\"  # noqa: E501\n\n    if any_symbolic_tensors((tensor,)):\n        return Rearrange().symbolic_call(tensor, pattern, **axes_lengths)\n\n    # Split the input and output patterns\n    input_pattern, output_pattern = re.split(r\"\\s*->\\s*\", pattern)\n    input_axes = re.findall(r\"\\w+|\\(.*?\\)\", input_pattern)\n    output_axes = re.findall(r\"\\w+|\\(.*?\\)\", output_pattern)\n    input_shape = shape(tensor)\n\n    # Create axes map, and flattened output group\n    axes_map = _create_axes_map(input_axes, input_shape, axes_lengths)\n    grouped_output_axes = _create_grouped_axes(output_axes)\n    flattened_output_axes = _flatten_group(grouped_output_axes)\n\n    # 1. 
Axes decomposition\n    decomposed_shapes = _compute_decomposed_shape(\n        input_axes, axes_lengths, axes_map\n    )\n    if decomposed_shapes != tensor.shape:\n        tensor = reshape(tensor, decomposed_shapes)\n\n    # 2. Transpose to match target shape\n    permute_order = _get_transpose_order(input_axes, flattened_output_axes)\n    tensor = transpose(tensor, permute_order)\n\n    # 3. Reshape to final target shape\n    output_shape = _compute_output_shape(axes_map, grouped_output_axes)\n    tensor = reshape(tensor, output_shape)\n\n    return tensor\n"
  },
  {
    "path": "keras/src/ops/einops_test.py",
    "content": "from conftest import skip_if_backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.backend.common import keras_tensor\nfrom keras.src.ops.einops import rearrange\n\n\nclass RearrangeTest(testing.TestCase):\n    def test_basic_rearrangement_symbolic(self):\n        x = keras_tensor.KerasTensor((2, 3, 4))\n        y = rearrange(x, \"b c h -> b h c\")\n        self.assertIsInstance(y, keras_tensor.KerasTensor)\n        self.assertEqual(y.shape, (2, 4, 3))\n\n    @skip_if_backend(\"openvino\", \"Test operation not supported by openvino\")\n    def test_basic_rearrangement(self):\n        x = ops.random.uniform((2, 3, 4))\n        y = rearrange(x, \"b c h -> b h c\")\n        self.assertEqual(y.shape, (2, 4, 3))\n        self.assertTrue(ops.all(ops.equal(y, ops.transpose(x, (0, 2, 1)))))\n\n    @skip_if_backend(\"openvino\", \"Test operation not supported by openvino\")\n    def test_output_composition(self):\n        x = ops.random.uniform((2, 4, 4, 3))\n        y = rearrange(x, \"b h w c -> (b h) w c\")\n        target_shape = (8, 4, 3)\n        self.assertEqual(y.shape, target_shape)\n        self.assertTrue(ops.all(ops.equal(y, ops.reshape(x, (8, 4, 3)))))\n\n    def test_basic_decomposition_and_rearrangement_symbolic(self):\n        x = keras_tensor.KerasTensor((6, 8))\n        y = rearrange(x, \"(h w) c -> h w c\", h=2, w=3)\n        self.assertIsInstance(y, keras_tensor.KerasTensor)\n        self.assertEqual(y.shape, (2, 3, 8))\n\n    def test_basic_decomposition_and_rearrangement(self):\n        x = ops.random.uniform((6, 8))\n        y = rearrange(x, \"(h w) c -> h w c\", h=2, w=3)\n        self.assertEqual(y.shape, (2, 3, 8))\n\n    @skip_if_backend(\"openvino\", \"Test operation not supported by openvino\")\n    def test_unchanged_shape(self):\n        x = ops.ones([2, 3, 4])\n        y = rearrange(x, \"b h c -> b h c\")\n        self.assertTrue(ops.all(ops.equal(y, x)))\n        self.assertTrue(x.shape, 
y.shape)\n\n    def test_unchanged_shape_symbolic(self):\n        x = keras_tensor.KerasTensor((2, 3, 4))\n        y = rearrange(x, \"b h c -> b h c\")\n        self.assertEqual(x.shape, y.shape)\n"
  },
  {
    "path": "keras/src/ops/function.py",
    "content": "import collections\n\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend.config import backend\nfrom keras.src.backend.config import is_nnx_enabled\nfrom keras.src.ops.operation import Operation\n\n\n@keras_export(\"keras.Function\")\nclass Function(Operation):\n    \"\"\"Class that encapsulates a computation graph of Keras operations.\n\n    You can use a `Function` to capture the computation graph linking\n    some input tensors to some output tensors, and reapply the same\n    computation on new inputs.\n\n    A `Function` is similar to a Functional Model, with the difference\n    that it is stateless (it does not track state variables)\n    and does not implement the `Layer` API.\n\n    Example:\n\n    ```python\n    input_1 = keras.KerasTensor(shape=(None, 2, 3))\n    input_2 = keras.KerasTensor(shape=(None, 2, 3))\n    x = input_1 + input_2\n    output = keras.ops.sigmoid(x)\n    fn = keras.Function(inputs=[input_1, input_2], outputs=output)\n\n    input_1_val = np.random.random((4, 2, 3))\n    input_2_val = np.random.random((4, 2, 3))\n    output_val = fn([input_1_val, input_2_val])\n    ```\n\n    Args:\n        inputs: `KerasTensor` instance or nested structured of\n            `KerasTensor` instances.\n        outputs: `KerasTensor` instance or nested structured of\n            `KerasTensor` instances. They should be computable\n            given only the values of `inputs`.\n        name: String. 
The name of the function.\n    \"\"\"\n\n    def __init__(self, inputs, outputs, name=None):\n        super().__init__(name=name)\n\n        if backend() == \"tensorflow\":\n            # Temporary workaround for\n            # https://github.com/keras-team/keras/issues/931\n            # This stops tensorflow from wrapping tf.function output in a\n            # _DictWrapper object.\n            _self_setattr_tracking = getattr(\n                self, \"_self_setattr_tracking\", True\n            )\n            self._self_setattr_tracking = False\n        self._inputs_struct = tree.map_structure(lambda x: x, inputs)\n        self._outputs_struct = tree.map_structure(lambda x: x, outputs)\n        self._inputs = tree.flatten(inputs)\n        self._outputs = tree.flatten(outputs)\n        if not self._inputs:\n            raise ValueError(\n                \"`inputs` argument cannot be empty. Received:\\n\"\n                f\"inputs={inputs}\\n\"\n                f\"outputs={outputs}\"\n            )\n        if not self._outputs:\n            raise ValueError(\n                \"`outputs` argument cannot be empty. 
Received:\\n\"\n                f\"inputs={inputs}\\n\"\n                f\"outputs={outputs}\"\n            )\n\n        if backend() == \"tensorflow\":\n            self._self_setattr_tracking = _self_setattr_tracking\n\n        (nodes, nodes_by_depth, operations, operations_by_depth) = map_graph(\n            self._inputs, self._outputs\n        )\n        self._nodes = nodes\n        self._nodes_by_depth = nodes_by_depth\n        self._operations = operations\n        self._operations_by_depth = operations_by_depth\n\n        # Run through graph to check all outputs are connected to the inputs.\n        def empty_op_outputs(op, *args, **kwargs):\n            return [None] * len(tree.flatten(op.output))\n\n        self._run_through_graph(\n            [None] * len(self._inputs), call_fn=empty_op_outputs\n        )\n\n        # Special handling for NNX to ensure consistent operation instance usage\n        if is_nnx_enabled():\n            self._setup_nnx_op_mapping()\n\n    @property\n    def operations(self):\n        return self._operations[:]\n\n    @property\n    def inputs(self):\n        \"\"\"Flat list of the symbolic inputs of the Function.\"\"\"\n        return self._inputs\n\n    @property\n    def outputs(self):\n        \"\"\"Flat list of the symbolic outputs of the Function.\"\"\"\n        return self._outputs\n\n    def _setup_nnx_op_mapping(self):\n        \"\"\"Setup operation mapping for NNX\"\"\"\n        # Create a mapping from operation id to operation instance\n        self._nnx_op_mapping = {}\n\n        # Assign the list of operations to a single attribute for NNX traversal\n        self.nnx_operations = self._operations[:]\n        for operation in self._operations:\n            # Map the operation id to this operation instance\n            self._nnx_op_mapping[id(operation)] = operation\n\n    def _get_operation_for_node(self, node):\n        \"\"\"Get the operation for a node, using NNX mapping if enabled.\"\"\"\n        operation = 
node.operation\n        if hasattr(self, \"_nnx_op_mapping\") and id(operation) in getattr(\n            self, \"_nnx_op_mapping\", {}\n        ):\n            return self._nnx_op_mapping[id(operation)]\n        return operation\n\n    def compute_output_spec(self, inputs):\n        self._assert_input_compatibility(inputs)\n        # Check if input shapes are identical to ref input shapes,\n        # if so take a shortcut.\n        shortcut = True\n        for x, x_ref in zip(tree.flatten(inputs), self._inputs):\n            if x.shape != x_ref.shape:\n                shortcut = False\n                break\n        if shortcut:\n            return tree.map_structure(\n                lambda x: KerasTensor(shape=x.shape, dtype=x.dtype),\n                self._outputs_struct,\n            )\n        # No luck; take the long road through the graph.\n        # Original Keras used a cache to avoid recomputing all this\n        # when known input shapes were seen again. Perhaps a good\n        # idea to bring that back.\n        return self._run_through_graph(\n            inputs, operation_fn=lambda op: op.compute_output_spec\n        )\n\n    def compute_output_shape(self, input_shape):\n        # Wrap `input_shape` into the structure of KerasTensor to utilize\n        # `compute_output_spec`.\n        input_shape_struct = tree.map_shape_structure(\n            lambda x: KerasTensor(shape=x), input_shape\n        )\n        # Ensure that dtype and sparse settings are the same as self._inputs,\n        # because we only care about the shape in this function.\n        for x, x_ref in zip(tree.flatten(input_shape_struct), self._inputs):\n            x._dtype = x_ref.dtype\n            x._sparse = x_ref.sparse\n        output_spec = self.compute_output_spec(input_shape_struct)\n        return tree.map_structure(lambda x: x.shape, output_spec)\n\n    def call(self, inputs):\n        \"\"\"Computes output tensors for new inputs.\"\"\"\n        
self._assert_input_compatibility(inputs)\n        return self._run_through_graph(inputs)\n\n    def _run_through_graph(\n        self, inputs, operation_fn=lambda op: op, call_fn=None\n    ):\n        \"\"\"Execute the graph.\n\n        At each node we compute outputs via\n        `operation_fn(node.operation)(*args, **kwargs)`.\n        \"\"\"\n        inputs = tree.flatten(inputs)\n\n        # Dictionary mapping reference tensors to computed tensors.\n        tensor_dict = {}\n        for x, y in zip(self.inputs, inputs):\n            tensor_dict[id(x)] = y\n\n        nodes_by_depth = self._nodes_by_depth\n        depth_keys = list(nodes_by_depth.keys())\n        depth_keys.sort(reverse=True)\n\n        for depth in depth_keys:\n            nodes = nodes_by_depth[depth]\n            for node in nodes:\n                if not node.operation or node.is_input:\n                    continue  # Input tensors already exist.\n\n                if any(id(x) not in tensor_dict for x in node.input_tensors):\n                    continue  # Node is not computable, try skipping.\n\n                args, kwargs = node.arguments.fill_in(tensor_dict)\n                if call_fn is not None:\n                    # Use call_fn if provided (e.g., for symbolic execution)\n                    op = operation_fn(node.operation)\n                    outputs = call_fn(op, *args, **kwargs)\n                else:\n                    # Use NNX operation mapping\n                    operation = self._get_operation_for_node(node)\n                    op = operation_fn(operation)\n                    outputs = op(*args, **kwargs)\n\n                # Update tensor_dict.\n                for x, y in zip(node.outputs, tree.flatten(outputs)):\n                    tensor_dict[id(x)] = y\n\n        output_tensors = []\n        for i, x in enumerate(self.outputs):\n            if id(x) not in tensor_dict:\n                path = tree.flatten_with_path(self._outputs_struct)[i][0]\n                
path = \".\".join(str(p) for p in path)\n                raise ValueError(\n                    f\"Output with path `{path}` is not connected to `inputs`\"\n                )\n            output_tensors.append(tensor_dict[id(x)])\n\n        return tree.pack_sequence_as(self._outputs_struct, output_tensors)\n\n    def _assert_input_compatibility(self, inputs):\n        try:\n            tree.assert_same_structure(inputs, self._inputs_struct)\n        except ValueError:\n            raise ValueError(\n                \"Function was called with an invalid input structure. \"\n                f\"Expected input structure: {self._inputs_struct}\\n\"\n                f\"Received input structure: {inputs}\"\n            )\n        for x, x_ref in zip(tree.flatten(inputs), self._inputs):\n            if len(x.shape) != len(x_ref.shape):\n                raise ValueError(\n                    f\"{self.__class__.__name__} was passed \"\n                    f\"incompatible inputs. For input '{x_ref.name}', \"\n                    f\"expected shape {x_ref.shape}, but received \"\n                    f\"instead a tensor with shape {x.shape}.\"\n                )\n            for dim, ref_dim in zip(x.shape, x_ref.shape):\n                if ref_dim is not None and dim is not None:\n                    if dim != ref_dim:\n                        raise ValueError(\n                            f\"{self.__class__.__name__} was passed \"\n                            f\"incompatible inputs. 
For input '{x_ref.name}', \"\n                            f\"expected shape {x_ref.shape}, but received \"\n                            f\"instead a tensor with shape {x.shape}.\"\n                        )\n\n\ndef make_node_key(op, node_index):\n    return f\"{id(op)}_ib-{node_index}\"\n\n\ndef map_graph(inputs, outputs):\n    \"\"\"Validates a graph's topology and gathers its operations and nodes.\n\n    Args:\n        inputs: List of input tensors.\n        outputs: List of output tensors.\n\n    Returns:\n        A tuple `(nodes, nodes_by_depth, operations, operations_by_depth)`.\n        - nodes: set of Node instances\n        - nodes_by_depth: dict mapping ints (depth) to lists of node instances.\n        - operations: list of Operation instances.\n        - operations_by_depth: dict mapping ints (depth) to lists of Operation\n            instances.\n    \"\"\"\n    # \"depth\" is number of operations between output Node and the Node.\n    # Nodes are ordered from inputs -> outputs.\n    nodes_in_decreasing_depth, operation_indices = _build_map(inputs, outputs)\n    nodes_in_graph = set(nodes_in_decreasing_depth)\n    network_nodes = {\n        make_node_key(node.operation, node.operation._inbound_nodes.index(node))\n        for node in nodes_in_decreasing_depth\n    }\n\n    nodes_depths = {}  # dict {node: depth value}\n    operations_depths = {}  # dict {operation: depth value}\n\n    for node in reversed(nodes_in_decreasing_depth):\n        # If the depth is not set, the node has no outbound nodes (depth 0).\n        depth = nodes_depths.setdefault(node, 0)\n\n        # Update the depth of the corresponding operation\n        previous_depth = operations_depths.get(node.operation, 0)\n        # If we've seen this operation before at a higher depth,\n        # we should use that depth instead of the node depth.\n        # This is necessary for shared operations that have inputs at different\n        # depth levels in the graph.\n        depth = max(depth, 
previous_depth)\n        operations_depths[node.operation] = depth\n        nodes_depths[node] = depth\n\n        # Update the depth of inbound nodes.\n        # The \"depth\" of a node is the max of the depths\n        # of all nodes it is connected to + 1.\n        # Only update nodes that are actually part of the graph.\n        for node_dep in node.parent_nodes:\n            if node_dep not in nodes_in_graph:\n                continue\n            previous_depth = nodes_depths.get(node_dep, 0)\n            nodes_depths[node_dep] = max(depth + 1, previous_depth)\n\n    # Handle inputs that are not connected to outputs.\n    # We do not error out here because the inputs may be used to compute losses\n    # and metrics.\n    for input_t in inputs:\n        input_operation = input_t._keras_history[0]\n        if input_operation and input_operation not in operations_depths:\n            node_index = input_t._keras_history.node_index\n            node = input_operation._inbound_nodes[node_index]\n            # Add InputLayer operations (unused inputs) unconditionally.\n            # Skip non-InputLayer operations, as they produce intermediate\n            # tensors used as Function inputs and are outside the graph.\n            if node.is_input:\n                operations_depths[input_operation] = 0\n                operation_indices[input_operation] = -1\n                nodes_depths[node] = 0\n                network_nodes.add(make_node_key(input_operation, node_index))\n\n    # Build a dict {depth: list of nodes with this depth}\n    nodes_by_depth = collections.defaultdict(list)\n    for node, depth in nodes_depths.items():\n        nodes_by_depth[depth].append(node)\n\n    # Build a dict {depth: list of operations with this depth}\n    operations_by_depth = collections.defaultdict(list)\n    for operation, depth in operations_depths.items():\n        operations_by_depth[depth].append(operation)\n\n    # Get sorted list of operation depths.\n    depth_keys = 
list(operations_by_depth.keys())\n    depth_keys.sort(reverse=True)\n\n    # Set self.operations ordered by depth.\n    operations = []\n    for depth in depth_keys:\n        operations_for_depth = operations_by_depth[depth]\n        # Network.operations needs to have a deterministic order:\n        # here we order them by traversal order.\n        operations_for_depth.sort(key=lambda x: operation_indices[x])\n        operations.extend(operations_for_depth)\n\n    # Get sorted list of node depths.\n    depth_keys = list(nodes_by_depth.keys())\n    depth_keys.sort(reverse=True)\n\n    # Check that all tensors required are computable.\n    # computable_tensors: all tensors in the graph\n    # that can be computed from the inputs provided.\n    computable_tensors = set()\n    for x in inputs:\n        computable_tensors.add(x)\n\n    operations_with_complete_input = []  # To provide a better error msg.\n    for depth in depth_keys:\n        for node in nodes_by_depth[depth]:\n            for x in tree.flatten(node.input_tensors):\n                if x not in computable_tensors:\n                    operation = node.operation\n                    raise ValueError(\n                        \"Graph disconnected: cannot find parent for \"\n                        f\"tensor {x} at operation '{operation}'. 
\"\n                        \"The following previous operations were accessed \"\n                        f\"without issue: {operations_with_complete_input}\"\n                    )\n                operations_with_complete_input.append(node.operation.name)\n\n            for x in tree.flatten(node.outputs):\n                computable_tensors.add(x)\n\n    # Ensure name unicity, which will be crucial for serialization\n    # (since serialized nodes refer to operations by their name).\n    all_names = [operation.name for operation in operations]\n    for name in all_names:\n        if all_names.count(name) != 1:\n            raise ValueError(\n                f'The name \"{name}\" is used {all_names.count(name)} '\n                \"times in the model. All operation names should be unique.\"\n            )\n    return network_nodes, nodes_by_depth, operations, operations_by_depth\n\n\ndef _build_map(inputs, outputs):\n    \"\"\"Topologically sort nodes in order from inputs to outputs.\n\n    It uses a depth-first search to topologically sort nodes that appear in the\n    _keras_history connectivity metadata of `outputs`.\n\n    Args:\n        outputs: the output tensors whose _keras_history metadata should be\n                walked. This may be an arbitrary nested structure.\n\n    Returns:\n        A tuple like (ordered_nodes, operation_to_first_traversal_index)\n        ordered_nodes: list of nodes appearing in the keras history,\n            topologically sorted from original inputs to the `outputs`.\n            (If outputs have different sets of ancestors, the inputs to one\n            output may appear after a different output).\n        operation_to_first_traversal_index:\n            A dict mapping operation to the traversal index in the DFS where it\n            is seen. 
Note: if an operation is shared by several nodes, the dict\n            will only store the index corresponding to the *first* time the\n            operation is seen.\n    \"\"\"\n    finished_nodes = set()\n    nodes_in_progress = set()\n    nodes_in_decreasing_depth = []  # nodes from inputs -> outputs.\n    operation_indices = {}  # operation -> index in traversal order.\n    for output in tree.flatten(outputs):\n        _build_map_helper(\n            inputs,\n            output,\n            finished_nodes,\n            nodes_in_progress,\n            nodes_in_decreasing_depth,\n            operation_indices,\n        )\n    return nodes_in_decreasing_depth, operation_indices\n\n\ndef _build_map_helper(\n    inputs,\n    tensor,\n    finished_nodes,\n    nodes_in_progress,\n    nodes_in_decreasing_depth,\n    operation_indices,\n):\n    \"\"\"Recursive helper for `_build_map`.\"\"\"\n    (\n        operation,\n        node_index,\n        _,\n    ) = tensor._keras_history\n    if not operation:\n        return\n\n    node = operation._inbound_nodes[node_index]\n\n    # Don't repeat work for shared subgraphs\n    if node in finished_nodes:\n        return\n\n    # If this tensor is one of the declared inputs and its producing\n    # operation is not an InputLayer, stop traversal here. 
The operation\n    # that produced this tensor is outside the Function's graph.\n    flat_inputs = tree.flatten(inputs)\n    if not node.is_input and tensor in flat_inputs:\n        finished_nodes.add(node)\n        return\n\n    # Prevent cycles.\n    if node in nodes_in_progress:\n        raise ValueError(\n            f\"Tensor {tensor} from operation '{operation.name}' is part of a \"\n            \"cycle.\"\n        )\n\n    # Store the traversal order for operation sorting.\n    if operation not in operation_indices:\n        operation_indices[operation] = len(operation_indices)\n\n    # Propagate to all previous tensors connected to this node.\n    nodes_in_progress.add(node)\n    if not node.is_input:\n        for input_tensor in node.input_tensors:\n            _build_map_helper(\n                inputs,\n                input_tensor,\n                finished_nodes,\n                nodes_in_progress,\n                nodes_in_decreasing_depth,\n                operation_indices,\n            )\n\n    finished_nodes.add(node)\n    nodes_in_progress.remove(node)\n    nodes_in_decreasing_depth.append(node)\n"
  },
  {
    "path": "keras/src/ops/function_test.py",
    "content": "import json\n\nimport numpy as np\n\nfrom keras.src import testing\nfrom keras.src.backend.common import keras_tensor\nfrom keras.src.layers import Dense\nfrom keras.src.layers import Input\nfrom keras.src.models import Model\nfrom keras.src.models import Sequential\nfrom keras.src.ops import function\nfrom keras.src.ops import numpy as knp\n\n\nclass FunctionTest(testing.TestCase):\n    def test_define_and_call(self):\n        x1 = keras_tensor.KerasTensor((2, 3))\n        x2 = keras_tensor.KerasTensor((2, 3))\n        x = knp.add(x1, x2)\n        y1 = x * 3\n        y2 = x**2\n        fn = function.Function(\n            inputs=[x1, x2], outputs=[y1, y2], name=\"test_function\"\n        )\n        self.assertEqual(fn.name, \"test_function\")\n\n        # Eager call\n        y_val = fn([np.ones((2, 3)), np.ones((2, 3))])\n        self.assertIsInstance(y_val, list)\n        self.assertAllClose(y_val[0], np.ones((2, 3)) * 6)\n        self.assertAllClose(y_val[1], np.ones((2, 3)) * 4)\n\n        # Symbolic call\n        x1_alt = keras_tensor.KerasTensor((2, 3))\n        x2_alt = keras_tensor.KerasTensor((2, 3))\n        y_val = fn([x1_alt, x2_alt])\n        self.assertIsInstance(y_val[0], keras_tensor.KerasTensor)\n        self.assertEqual(y_val[0].shape, (2, 3))\n        self.assertIsInstance(y_val[1], keras_tensor.KerasTensor)\n        self.assertEqual(y_val[1].shape, (2, 3))\n\n        # Recursion\n        fn = function.Function(inputs=[x1_alt, x2_alt], outputs=y_val)\n        y_val = fn([np.ones((2, 3)), np.ones((2, 3))])\n        self.assertIsInstance(y_val, list)\n        self.assertAllClose(y_val[0], np.ones((2, 3)) * 6)\n        self.assertAllClose(y_val[1], np.ones((2, 3)) * 4)\n\n    def test_dynamic_shape_inference(self):\n        x = keras_tensor.KerasTensor((None, 3))\n        y = x**2\n        fn = function.Function(x, y)\n\n        # Test with compute_output_spec\n        out = fn.compute_output_spec(keras_tensor.KerasTensor((4, 3)))\n  
      self.assertIsInstance(out, keras_tensor.KerasTensor)\n        self.assertEqual(out.shape, (4, 3))\n\n        # Test with compute_output_shape\n        out = fn.compute_output_shape((None, 3))\n        self.assertIsInstance(out, tuple)\n        self.assertEqual(out, (None, 3))\n\n        # Test with call\n        out = fn(keras_tensor.KerasTensor((4, 3)))\n        self.assertIsInstance(out, keras_tensor.KerasTensor)\n        self.assertEqual(out.shape, (4, 3))\n\n    def test_dict_io(self):\n        x1 = keras_tensor.KerasTensor((2, 3))\n        x2 = keras_tensor.KerasTensor((2, 3))\n        x = knp.add(x1, x2)\n        y1 = x * 3\n        y2 = x**2\n        fn = function.Function(\n            inputs={\"x1\": x1, \"x2\": x2}, outputs={\"y1\": y1, \"y2\": y2}\n        )\n\n        # Eager call\n        y_val = fn({\"x1\": np.ones((2, 3)), \"x2\": np.ones((2, 3))})\n        self.assertIsInstance(y_val, dict)\n        self.assertAllClose(y_val[\"y1\"], np.ones((2, 3)) * 6)\n        self.assertAllClose(y_val[\"y2\"], np.ones((2, 3)) * 4)\n\n        # Symbolic call\n        x1_alt = keras_tensor.KerasTensor((2, 3))\n        x2_alt = keras_tensor.KerasTensor((2, 3))\n        y_val = fn({\"x1\": x1_alt, \"x2\": x2_alt})\n        self.assertIsInstance(y_val[\"y1\"], keras_tensor.KerasTensor)\n        self.assertEqual(y_val[\"y1\"].shape, (2, 3))\n        self.assertIsInstance(y_val[\"y2\"], keras_tensor.KerasTensor)\n        self.assertEqual(y_val[\"y2\"].shape, (2, 3))\n\n    def test_invalid_inputs_error(self):\n        x1 = keras_tensor.KerasTensor((2, 3))\n        x2 = keras_tensor.KerasTensor((2, 3))\n        x = knp.add(x1, x2)\n        y1 = x * 3\n        y2 = x**2\n        fn = function.Function(\n            inputs=[x1, x2], outputs=[y1, y2], name=\"test_function\"\n        )\n        self.assertEqual(fn.name, \"test_function\")\n\n        # Bad structure\n        with self.assertRaisesRegex(ValueError, \"invalid input structure\"):\n            _ = 
fn(np.ones((2, 3)))\n\n        # Bad rank\n        with self.assertRaisesRegex(ValueError, \"incompatible inputs\"):\n            _ = fn([np.ones((2, 3, 3)), np.ones((2, 3))])\n\n        # Bad shape\n        with self.assertRaisesRegex(ValueError, \"incompatible inputs\"):\n            _ = fn([np.ones((4, 3)), np.ones((2, 3))])\n\n    def test_serialization(self):\n        inputs = Input(shape=(10,))\n        outputs = Dense(1)(inputs)\n        model = Model(inputs=inputs, outputs=outputs)\n\n        config = model.get_config()\n        new_model = Model.from_config(config)\n\n        self.assertEqual(\n            json.dumps(model.get_config()), json.dumps(new_model.get_config())\n        )\n\n    def test_function_with_empty_outputs(self):\n        x = keras_tensor.KerasTensor((None, 3))\n        with self.assertRaisesRegex(\n            ValueError, \"`outputs` argument cannot be empty\"\n        ):\n            _ = function.Function(inputs=x, outputs=[])\n\n    def test_function_with_empty_inputs(self):\n        x = keras_tensor.KerasTensor((None, 3))\n        with self.assertRaisesRegex(\n            ValueError, \"`inputs` argument cannot be empty\"\n        ):\n            _ = function.Function(inputs=[], outputs=x)\n\n    def test_function_with_unconnected_inputs(self):\n        model_1 = Sequential(\n            [\n                Input(shape=(6,)),\n                Dense(3, activation=\"sigmoid\"),\n            ]\n        )\n        model_2 = Sequential(\n            [\n                Input(shape=(3,)),\n                Dense(2, activation=\"sigmoid\"),\n            ],\n        )\n        with self.assertRaisesRegex(\n            ValueError, \"Output .* is not connected to `inputs`\"\n        ):\n            _ = Model(Input(shape=(6,)), model_2(model_1(Input(shape=(6,)))))\n\n        with self.assertRaisesRegex(\n            ValueError, \"Output .* is not connected to `inputs`\"\n        ):\n            _ = Model(model_1(Input(shape=(6,))), 
model_2(Input(shape=(3,))))\n\n    def test_function_with_intermediate_tensor_input(self):\n        \"\"\"Function with intermediate tensor as input should exclude\n        operations that produce those tensors.\"\"\"\n        x = Input(batch_shape=(), name=\"x\")\n        y = x**2\n        z = y + 1\n        fn = function.Function(y, z)\n\n        # The power operation produces `y` but is outside the graph\n        # boundary. Only the add operation should be included.\n        op_names = [op.name for op in fn.operations]\n        self.assertNotIn(\"power\", op_names)\n        self.assertIn(\"add\", op_names)\n\n        # Verify the function computes correctly (input is y, output is y+1)\n        result = fn(np.array(3.0))\n        self.assertAllClose(result, np.array(4.0))  # 3 + 1 = 4\n\n    def test_function_with_intermediate_tensor_input_chain(self):\n        \"\"\"Function with intermediate tensor input from a longer chain.\"\"\"\n        x = Input(batch_shape=(None, 3), name=\"x\")\n        a = x * 2\n        b = a + 1\n        c = b**2\n        fn = function.Function(b, c)\n\n        op_names = [op.name for op in fn.operations]\n        # Only the power op (producing c from b) should be in the graph.\n        # multiply and add that produce a and b are outside.\n        self.assertNotIn(\"multiply\", op_names)\n        self.assertNotIn(\"add\", op_names)\n        self.assertIn(\"power\", op_names)\n\n        result = fn(np.ones((2, 3)) * 3)\n        self.assertAllClose(result, np.ones((2, 3)) * 9)  # 3^2 = 9\n"
  },
  {
    "path": "keras/src/ops/image.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend import any_symbolic_tensors\nfrom keras.src.ops.operation import Operation\nfrom keras.src.ops.operation_utils import compute_conv_output_shape\n\n\nclass RGBToGrayscale(Operation):\n    def __init__(self, data_format=None, *, name=None):\n        super().__init__(name=name)\n        self.data_format = backend.standardize_data_format(data_format)\n\n    def call(self, images):\n        return backend.image.rgb_to_grayscale(\n            images, data_format=self.data_format\n        )\n\n    def compute_output_spec(self, images):\n        images_shape = list(images.shape)\n        if len(images_shape) not in (3, 4):\n            raise ValueError(\n                \"Invalid images rank: expected rank 3 (single image) \"\n                \"or rank 4 (batch of images). \"\n                f\"Received: images.shape={images_shape}\"\n            )\n        if self.data_format == \"channels_last\":\n            images_shape[-1] = 1\n        else:\n            images_shape[-3] = 1\n        return KerasTensor(shape=images_shape, dtype=images.dtype)\n\n\n@keras_export(\"keras.ops.image.rgb_to_grayscale\")\ndef rgb_to_grayscale(images, data_format=None):\n    \"\"\"Convert RGB images to grayscale.\n\n    This function converts RGB images to grayscale images. It supports both\n    3D and 4D tensors.\n\n    Args:\n        images: Input image or batch of images. 
Must be 3D or 4D.\n        data_format: A string specifying the data format of the input tensor.\n            It can be either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)`, while `\"channels_first\"`\n            corresponds to inputs with shape `(batch, channels, height, width)`.\n            If not specified, the value will default to\n            `keras.config.image_data_format`.\n\n    Returns:\n        Grayscale image or batch of grayscale images.\n\n    Examples:\n\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.random.random((2, 4, 4, 3))\n    >>> y = ops.image.rgb_to_grayscale(x)\n    >>> y.shape\n    (2, 4, 4, 1)\n\n    >>> x = np.random.random((4, 4, 3)) # Single RGB image\n    >>> y = ops.image.rgb_to_grayscale(x)\n    >>> y.shape\n    (4, 4, 1)\n\n    >>> x = np.random.random((2, 3, 4, 4))\n    >>> y = ops.image.rgb_to_grayscale(x, data_format=\"channels_first\")\n    >>> y.shape\n    (2, 1, 4, 4)\n    \"\"\"\n    if any_symbolic_tensors((images,)):\n        return RGBToGrayscale(data_format=data_format).symbolic_call(images)\n    return backend.image.rgb_to_grayscale(images, data_format=data_format)\n\n\nclass RGBToHSV(Operation):\n    def __init__(self, data_format=None, *, name=None):\n        super().__init__(name=name)\n        self.data_format = backend.standardize_data_format(data_format)\n\n    def call(self, images):\n        return backend.image.rgb_to_hsv(images, data_format=self.data_format)\n\n    def compute_output_spec(self, images):\n        images_shape = list(images.shape)\n        dtype = images.dtype\n        if len(images_shape) not in (3, 4):\n            raise ValueError(\n                \"Invalid images rank: expected rank 3 (single image) \"\n                \"or rank 4 (batch of images). 
\"\n                f\"Received: images.shape={images_shape}\"\n            )\n        if not backend.is_float_dtype(dtype):\n            raise ValueError(\n                \"Invalid images dtype: expected float dtype. \"\n                f\"Received: images.dtype={dtype}\"\n            )\n        return KerasTensor(shape=images_shape, dtype=images.dtype)\n\n\n@keras_export(\"keras.ops.image.rgb_to_hsv\")\ndef rgb_to_hsv(images, data_format=None):\n    \"\"\"Convert RGB images to HSV.\n\n    `images` must be of float dtype, and the output is only well defined if the\n    values in `images` are in `[0, 1]`.\n\n    All HSV values are in `[0, 1]`. A hue of `0` corresponds to pure red, `1/3`\n    is pure green, and `2/3` is pure blue.\n\n    Args:\n        images: Input image or batch of images. Must be 3D or 4D.\n        data_format: A string specifying the data format of the input tensor.\n            It can be either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)`, while `\"channels_first\"`\n            corresponds to inputs with shape `(batch, channels, height, width)`.\n            If not specified, the value will default to\n            `keras.config.image_data_format`.\n\n    Returns:\n        HSV image or batch of HSV images.\n\n    Examples:\n\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.random.random((2, 4, 4, 3))\n    >>> y = ops.image.rgb_to_hsv(x)\n    >>> y.shape\n    (2, 4, 4, 3)\n\n    >>> x = np.random.random((4, 4, 3)) # Single RGB image\n    >>> y = ops.image.rgb_to_hsv(x)\n    >>> y.shape\n    (4, 4, 3)\n\n    >>> x = np.random.random((2, 3, 4, 4))\n    >>> y = ops.image.rgb_to_hsv(x, data_format=\"channels_first\")\n    >>> y.shape\n    (2, 3, 4, 4)\n    \"\"\"\n    if any_symbolic_tensors((images,)):\n        return RGBToHSV(data_format=data_format).symbolic_call(images)\n    return 
backend.image.rgb_to_hsv(images, data_format=data_format)\n\n\nclass HSVToRGB(Operation):\n    def __init__(self, data_format=None, *, name=None):\n        super().__init__(name=name)\n        self.data_format = backend.standardize_data_format(data_format)\n\n    def call(self, images):\n        return backend.image.hsv_to_rgb(images, data_format=self.data_format)\n\n    def compute_output_spec(self, images):\n        images_shape = list(images.shape)\n        dtype = images.dtype\n        if len(images_shape) not in (3, 4):\n            raise ValueError(\n                \"Invalid images rank: expected rank 3 (single image) \"\n                \"or rank 4 (batch of images). \"\n                f\"Received: images.shape={images_shape}\"\n            )\n        if not backend.is_float_dtype(dtype):\n            raise ValueError(\n                \"Invalid images dtype: expected float dtype. \"\n                f\"Received: images.dtype={dtype}\"\n            )\n        return KerasTensor(shape=images_shape, dtype=images.dtype)\n\n\n@keras_export(\"keras.ops.image.hsv_to_rgb\")\ndef hsv_to_rgb(images, data_format=None):\n    \"\"\"Convert HSV images to RGB.\n\n    `images` must be of float dtype, and the output is only well defined if the\n    values in `images` are in `[0, 1]`.\n\n    Args:\n        images: Input image or batch of images. 
Must be 3D or 4D.\n        data_format: A string specifying the data format of the input tensor.\n            It can be either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)`, while `\"channels_first\"`\n            corresponds to inputs with shape `(batch, channels, height, width)`.\n            If not specified, the value will default to\n            `keras.config.image_data_format`.\n\n    Returns:\n        RGB image or batch of RGB images.\n\n    Examples:\n\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.random.random((2, 4, 4, 3))\n    >>> y = ops.image.hsv_to_rgb(x)\n    >>> y.shape\n    (2, 4, 4, 3)\n\n    >>> x = np.random.random((4, 4, 3)) # Single HSV image\n    >>> y = ops.image.hsv_to_rgb(x)\n    >>> y.shape\n    (4, 4, 3)\n\n    >>> x = np.random.random((2, 3, 4, 4))\n    >>> y = ops.image.hsv_to_rgb(x, data_format=\"channels_first\")\n    >>> y.shape\n    (2, 3, 4, 4)\n    \"\"\"\n    if any_symbolic_tensors((images,)):\n        return HSVToRGB(data_format=data_format).symbolic_call(images)\n    return backend.image.hsv_to_rgb(images, data_format=data_format)\n\n\nclass Resize(Operation):\n    def __init__(\n        self,\n        size,\n        interpolation=\"bilinear\",\n        antialias=False,\n        crop_to_aspect_ratio=False,\n        pad_to_aspect_ratio=False,\n        fill_mode=\"constant\",\n        fill_value=0.0,\n        data_format=None,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.size = tuple(size)\n        self.interpolation = interpolation\n        self.antialias = antialias\n        self.crop_to_aspect_ratio = crop_to_aspect_ratio\n        self.pad_to_aspect_ratio = pad_to_aspect_ratio\n        self.fill_mode = fill_mode\n        self.fill_value = fill_value\n        self.data_format = backend.standardize_data_format(data_format)\n\n    def 
call(self, images):\n        return _resize(\n            images,\n            self.size,\n            interpolation=self.interpolation,\n            antialias=self.antialias,\n            data_format=self.data_format,\n            crop_to_aspect_ratio=self.crop_to_aspect_ratio,\n            pad_to_aspect_ratio=self.pad_to_aspect_ratio,\n            fill_mode=self.fill_mode,\n            fill_value=self.fill_value,\n        )\n\n    def compute_output_spec(self, images):\n        images_shape = list(images.shape)\n        if len(images_shape) not in (3, 4):\n            raise ValueError(\n                \"Invalid images rank: expected rank 3 (single image) \"\n                \"or rank 4 (batch of images). Received input with shape: \"\n                f\"images.shape={images.shape}\"\n            )\n        if self.data_format == \"channels_last\":\n            height_axis, width_axis = -3, -2\n        else:\n            height_axis, width_axis = -2, -1\n        images_shape[height_axis] = self.size[0]\n        images_shape[width_axis] = self.size[1]\n        return KerasTensor(shape=images_shape, dtype=images.dtype)\n\n\n@keras_export(\"keras.ops.image.resize\")\ndef resize(\n    images,\n    size,\n    interpolation=\"bilinear\",\n    antialias=False,\n    crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n    fill_mode=\"constant\",\n    fill_value=0.0,\n    data_format=None,\n):\n    \"\"\"Resize images to size using the specified interpolation method.\n\n    Args:\n        images: Input image or batch of images. Must be 3D or 4D.\n        size: Size of output image in `(height, width)` format.\n        interpolation: Interpolation method. Available methods are `\"nearest\"`,\n            `\"bilinear\"`, and `\"bicubic\"`. Defaults to `\"bilinear\"`.\n        antialias: Whether to use an antialiasing filter when downsampling an\n            image. 
Defaults to `False`.\n        crop_to_aspect_ratio: If `True`, resize the images without aspect\n            ratio distortion. When the original aspect ratio differs\n            from the target aspect ratio, the output image will be\n            cropped so as to return the\n            largest possible window in the image (of size `(height, width)`)\n            that matches the target aspect ratio. By default\n            (`crop_to_aspect_ratio=False`), aspect ratio may not be preserved.\n        pad_to_aspect_ratio: If `True`, pad the images without aspect\n            ratio distortion. When the original aspect ratio differs\n            from the target aspect ratio, the output image will be\n            evenly padded on the short side.\n        fill_mode: When using `pad_to_aspect_ratio=True`, padded areas\n            are filled according to the given mode. Only `\"constant\"` is\n            supported at this time\n            (fill with constant value, equal to `fill_value`).\n        fill_value: Float. 
Padding value to use when `pad_to_aspect_ratio=True`.\n        data_format: A string specifying the data format of the input tensor.\n            It can be either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)`, while `\"channels_first\"`\n            corresponds to inputs with shape `(batch, channels, height, width)`.\n            If not specified, the value will default to\n            `keras.config.image_data_format`.\n\n    Returns:\n        Resized image or batch of images.\n\n    Examples:\n\n    >>> x = np.random.random((2, 4, 4, 3)) # batch of 2 RGB images\n    >>> y = keras.ops.image.resize(x, (2, 2))\n    >>> y.shape\n    (2, 2, 2, 3)\n\n    >>> x = np.random.random((4, 4, 3)) # single RGB image\n    >>> y = keras.ops.image.resize(x, (2, 2))\n    >>> y.shape\n    (2, 2, 3)\n\n    >>> x = np.random.random((2, 3, 4, 4)) # batch of 2 RGB images\n    >>> y = keras.ops.image.resize(x, (2, 2),\n    ...     data_format=\"channels_first\")\n    >>> y.shape\n    (2, 3, 2, 2)\n    \"\"\"\n    if len(size) != 2:\n        raise ValueError(\n            \"Expected `size` to be a tuple of 2 integers. \"\n            f\"Received: size={size}\"\n        )\n    if size[0] <= 0 or size[1] <= 0:\n        raise ValueError(\n            f\"`size` must have positive height and width. Received: size={size}\"\n        )\n    if len(images.shape) < 3 or len(images.shape) > 4:\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). 
Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n    if pad_to_aspect_ratio and crop_to_aspect_ratio:\n        raise ValueError(\n            \"Only one of `pad_to_aspect_ratio` & `crop_to_aspect_ratio` \"\n            \"can be `True`.\"\n        )\n    if any_symbolic_tensors((images,)):\n        return Resize(\n            size,\n            interpolation=interpolation,\n            antialias=antialias,\n            data_format=data_format,\n            crop_to_aspect_ratio=crop_to_aspect_ratio,\n            pad_to_aspect_ratio=pad_to_aspect_ratio,\n            fill_mode=fill_mode,\n            fill_value=fill_value,\n        ).symbolic_call(images)\n    return _resize(\n        images,\n        size,\n        interpolation=interpolation,\n        antialias=antialias,\n        crop_to_aspect_ratio=crop_to_aspect_ratio,\n        data_format=data_format,\n        pad_to_aspect_ratio=pad_to_aspect_ratio,\n        fill_mode=fill_mode,\n        fill_value=fill_value,\n    )\n\n\ndef _resize(\n    images,\n    size,\n    interpolation=\"bilinear\",\n    antialias=False,\n    crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n    fill_mode=\"constant\",\n    fill_value=0.0,\n    data_format=None,\n):\n    resized = backend.image.resize(\n        images,\n        size,\n        interpolation=interpolation,\n        antialias=antialias,\n        crop_to_aspect_ratio=crop_to_aspect_ratio,\n        data_format=data_format,\n        pad_to_aspect_ratio=pad_to_aspect_ratio,\n        fill_mode=fill_mode,\n        fill_value=fill_value,\n    )\n    if resized.dtype == images.dtype:\n        # Only `torch` backend will cast result to original dtype with\n        # correct rounding and without dtype overflow\n        return resized\n    if backend.is_int_dtype(images.dtype):\n        resized = ops.round(resized)\n    return ops.saturate_cast(resized, images.dtype)\n\n\nclass AffineTransform(Operation):\n    def __init__(\n        
self,\n        interpolation=\"bilinear\",\n        fill_mode=\"constant\",\n        fill_value=0,\n        data_format=None,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.interpolation = interpolation\n        self.fill_mode = fill_mode\n        self.fill_value = fill_value\n        self.data_format = backend.standardize_data_format(data_format)\n\n    def call(self, images, transform):\n        return backend.image.affine_transform(\n            images,\n            transform,\n            interpolation=self.interpolation,\n            fill_mode=self.fill_mode,\n            fill_value=self.fill_value,\n            data_format=self.data_format,\n        )\n\n    def compute_output_spec(self, images, transform):\n        if len(images.shape) not in (3, 4):\n            raise ValueError(\n                \"Invalid images rank: expected rank 3 (single image) \"\n                \"or rank 4 (batch of images). Received input with shape: \"\n                f\"images.shape={images.shape}\"\n            )\n        if len(transform.shape) not in (1, 2):\n            raise ValueError(\n                \"Invalid transform rank: expected rank 1 (single transform) \"\n                \"or rank 2 (batch of transforms). Received input with shape: \"\n                f\"transform.shape={transform.shape}\"\n            )\n        return KerasTensor(images.shape, dtype=images.dtype)\n\n\n@keras_export(\"keras.ops.image.affine_transform\")\ndef affine_transform(\n    images,\n    transform,\n    interpolation=\"bilinear\",\n    fill_mode=\"constant\",\n    fill_value=0,\n    data_format=None,\n):\n    \"\"\"Applies the given transform(s) to the image(s).\n\n    Args:\n        images: Input image or batch of images. Must be 3D or 4D.\n        transform: Projective transform matrix/matrices. A vector of length 8 or\n            tensor of size N x 8. 
If one row of transform is\n            `[a0, a1, a2, b0, b1, b2, c0, c1]`, then it maps the output point\n            `(x, y)` to a transformed input point\n            `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,\n            where `k = c0 x + c1 y + 1`. The transform is inverted compared to\n            the transform mapping input points to output points. Note that\n            gradients are not backpropagated into transformation parameters.\n            Note that `c0` and `c1` are only effective when using TensorFlow\n            backend and will be considered as `0` when using other backends.\n        interpolation: Interpolation method. Available methods are `\"nearest\"`,\n            and `\"bilinear\"`. Defaults to `\"bilinear\"`.\n        fill_mode: Points outside the boundaries of the input are filled\n            according to the given mode. Available methods are `\"constant\"`,\n            `\"nearest\"`, `\"wrap\"` and `\"reflect\"`. Defaults to `\"constant\"`.\n            - `\"reflect\"`: `(d c b a | a b c d | d c b a)`\n                The input is extended by reflecting about the edge of the last\n                pixel.\n            - `\"constant\"`: `(k k k k | a b c d | k k k k)`\n                The input is extended by filling all values beyond\n                the edge with the same constant value k specified by\n                `fill_value`.\n            - `\"wrap\"`: `(a b c d | a b c d | a b c d)`\n                The input is extended by wrapping around to the opposite edge.\n            - `\"nearest\"`: `(a a a a | a b c d | d d d d)`\n                The input is extended by the nearest pixel.\n        fill_value: Value used for points outside the boundaries of the input if\n            `fill_mode=\"constant\"`. 
Defaults to `0`.\n        data_format: A string specifying the data format of the input tensor.\n            It can be either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)`, while `\"channels_first\"`\n            corresponds to inputs with shape `(batch, channels, height, width)`.\n            If not specified, the value will default to\n            `keras.config.image_data_format`.\n\n    Returns:\n        Applied affine transform image or batch of images.\n\n    Examples:\n\n    >>> x = np.random.random((2, 64, 80, 3)) # batch of 2 RGB images\n    >>> transform = np.array(\n    ...     [\n    ...         [1.5, 0, -20, 0, 1.5, -16, 0, 0],  # zoom\n    ...         [1, 0, -20, 0, 1, -16, 0, 0],  # translation\n    ...     ]\n    ... )\n    >>> y = keras.ops.image.affine_transform(x, transform)\n    >>> y.shape\n    (2, 64, 80, 3)\n\n    >>> x = np.random.random((64, 80, 3)) # single RGB image\n    >>> transform = np.array([1.0, 0.5, -20, 0.5, 1.0, -16, 0, 0])  # shear\n    >>> y = keras.ops.image.affine_transform(x, transform)\n    >>> y.shape\n    (64, 80, 3)\n\n    >>> x = np.random.random((2, 3, 64, 80)) # batch of 2 RGB images\n    >>> transform = np.array(\n    ...     [\n    ...         [1.5, 0, -20, 0, 1.5, -16, 0, 0],  # zoom\n    ...         [1, 0, -20, 0, 1, -16, 0, 0],  # translation\n    ...     ]\n    ... )\n    >>> y = keras.ops.image.affine_transform(x, transform,\n    ...     
data_format=\"channels_first\")\n    >>> y.shape\n    (2, 3, 64, 80)\n    \"\"\"\n    if any_symbolic_tensors((images, transform)):\n        return AffineTransform(\n            interpolation=interpolation,\n            fill_mode=fill_mode,\n            fill_value=fill_value,\n            data_format=data_format,\n        ).symbolic_call(images, transform)\n    return backend.image.affine_transform(\n        images,\n        transform,\n        interpolation=interpolation,\n        fill_mode=fill_mode,\n        fill_value=fill_value,\n        data_format=data_format,\n    )\n\n\nclass ExtractPatches(Operation):\n    def __init__(\n        self,\n        size,\n        strides=None,\n        dilation_rate=1,\n        padding=\"valid\",\n        data_format=None,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        if isinstance(size, int):\n            size = (size, size)\n        self.size = size\n        self.is_3d = len(self.size) == 3\n        if strides is None:\n            strides = size\n        self.strides = strides\n        self.dilation_rate = dilation_rate\n        self.padding = padding\n        self.data_format = backend.standardize_data_format(data_format)\n\n    def call(self, images):\n        return _extract_patches(\n            images=images,\n            size=self.size,\n            strides=self.strides,\n            dilation_rate=self.dilation_rate,\n            padding=self.padding,\n            data_format=self.data_format,\n        )\n\n    def compute_output_spec(self, images):\n        images_shape = list(images.shape)\n        original_ndim = len(images_shape)\n        if self.data_format == \"channels_last\":\n            channels_in = images_shape[-1]\n        else:\n            channels_in = images_shape[-4] if self.is_3d else images_shape[-3]\n\n        if self.is_3d:\n            # 3D patch extraction\n            if original_ndim == 4:\n                images_shape = [1] + images_shape\n            
filters = self.size[0] * self.size[1] * self.size[2] * channels_in\n            kernel_size = (self.size[0], self.size[1], self.size[2])\n        else:\n            # 2D patch extraction\n            if original_ndim == 3:\n                images_shape = [1] + images_shape\n            filters = self.size[0] * self.size[1] * channels_in\n            kernel_size = (self.size[0], self.size[1])\n\n        out_shape = compute_conv_output_shape(\n            images_shape,\n            filters,\n            kernel_size,\n            strides=self.strides,\n            padding=self.padding,\n            data_format=self.data_format,\n            dilation_rate=self.dilation_rate,\n        )\n\n        if self.is_3d:\n            if original_ndim == 4:\n                out_shape = out_shape[1:]\n        else:\n            if original_ndim == 3:\n                out_shape = out_shape[1:]\n        return KerasTensor(shape=out_shape, dtype=images.dtype)\n\n    def get_config(self):\n        return {\n            \"size\": self.size,\n            \"strides\": self.strides,\n            \"dilation_rate\": self.dilation_rate,\n            \"padding\": self.padding,\n            \"data_format\": self.data_format,\n        }\n\n\n@keras_export(\"keras.ops.image.extract_patches\")\ndef extract_patches(\n    images,\n    size,\n    strides=None,\n    dilation_rate=1,\n    padding=\"valid\",\n    data_format=None,\n):\n    \"\"\"Extracts patches from the image(s) or volume(s).\n\n    This function supports both 2D and 3D patch extraction based on the\n    `size` argument length, similar to how `keras.ops.conv` handles\n    different dimensions.\n\n    Args:\n        images: Input image/volume or batch of images/volumes.\n            For 2D patches: 3D `(H, W, C)` or 4D `(N, H, W, C)`.\n            For 3D patches: 4D `(D, H, W, C)` or 5D `(N, D, H, W, C)`.\n        size: Patch size as int or tuple.\n            Length 2 tuple `(patch_height, patch_width)` or int for 2D patches.\n        
    Length 3 tuple `(patch_depth, patch_height, patch_width)` for\n            3D patches.\n        strides: Strides for patch extraction. If not specified, defaults\n            to `size` (non-overlapping patches).\n        dilation_rate: Dilation rate for patch extraction. Note that\n            `dilation_rate > 1` is not supported with `strides > 1`.\n        padding: The type of padding algorithm to use: `\"same\"` or `\"valid\"`.\n        data_format: A string specifying the data format of the input tensor.\n            It can be either `\"channels_last\"` or `\"channels_first\"`.\n            If not specified, defaults to `keras.config.image_data_format`.\n\n    Returns:\n        Extracted patches with shape depending on input and `size`:\n        - 2D patches: 3D (unbatched) or 4D (batched)\n        - 3D patches: 4D (unbatched) or 5D (batched)\n\n    Examples:\n\n    >>> # 2D patches from batch of images\n    >>> image = np.random.random(\n    ...     (2, 20, 20, 3)\n    ... ).astype(\"float32\")\n    >>> patches = keras.ops.image.extract_patches(image, (5, 5))\n    >>> patches.shape\n    (2, 4, 4, 75)\n\n    >>> # 2D patches from single image\n    >>> image = np.random.random((20, 20, 3)).astype(\"float32\")\n    >>> patches = keras.ops.image.extract_patches(image, (3, 3), (1, 1))\n    >>> patches.shape\n    (18, 18, 27)\n\n    >>> # 3D patches from batch of volumes\n    >>> volumes = np.random.random(\n    ...     (2, 10, 10, 10, 3)\n    ... 
).astype(\"float32\")\n    >>> patches = keras.ops.image.extract_patches(volumes, (3, 3, 3))\n    >>> patches.shape\n    (2, 3, 3, 3, 81)\n\n    >>> # 3D patches from single volume\n    >>> volume = np.random.random((10, 10, 10, 3)).astype(\"float32\")\n    >>> patches = keras.ops.image.extract_patches(volume, (3, 3, 3))\n    >>> patches.shape\n    (3, 3, 3, 81)\n    \"\"\"\n    # Validate size argument\n    if not isinstance(size, int):\n        if not isinstance(size, (tuple, list)):\n            raise TypeError(\n                \"Invalid `size` argument. Expected an int or a tuple. \"\n                f\"Received: size={size} of type {type(size).__name__}\"\n            )\n        if len(size) not in (2, 3):\n            raise ValueError(\n                \"Invalid `size` argument. Expected a tuple of length 2 or 3. \"\n                f\"Received: size={size} with length {len(size)}\"\n            )\n\n    # 2D patch extraction (default)\n    if any_symbolic_tensors((images,)):\n        return ExtractPatches(\n            size=size,\n            strides=strides,\n            dilation_rate=dilation_rate,\n            padding=padding,\n            data_format=data_format,\n        ).symbolic_call(images)\n\n    return _extract_patches(\n        images, size, strides, dilation_rate, padding, data_format=data_format\n    )\n\n\ndef _extract_patches(\n    images,\n    size,\n    strides=None,\n    dilation_rate=1,\n    padding=\"valid\",\n    data_format=None,\n):\n    if not isinstance(size, int) and len(size) == 3:\n        return _extract_patches_3d(\n            images, size, strides, dilation_rate, padding, data_format\n        )\n    return _extract_patches_2d(\n        images, size, strides, dilation_rate, padding, data_format\n    )\n\n\ndef _extract_patches_2d(\n    images,\n    size,\n    strides=None,\n    dilation_rate=1,\n    padding=\"valid\",\n    data_format=None,\n):\n    if isinstance(size, int):\n        patch_h = patch_w = size\n    elif 
len(size) == 2:\n        patch_h, patch_w = size[0], size[1]\n    else:\n        raise TypeError(\n            \"Invalid `size` argument. Expected an \"\n            f\"int or a tuple of length 2. Received: size={size}\"\n        )\n    data_format = backend.standardize_data_format(data_format)\n    if data_format == \"channels_last\":\n        channels_in = images.shape[-1]\n    elif data_format == \"channels_first\":\n        channels_in = images.shape[-3]\n    if not strides:\n        strides = size\n    out_dim = patch_h * patch_w * channels_in\n    kernel = backend.numpy.eye(out_dim, dtype=images.dtype)\n    kernel = backend.numpy.reshape(\n        kernel, (patch_h, patch_w, channels_in, out_dim)\n    )\n    _unbatched = False\n    if len(images.shape) == 3:\n        _unbatched = True\n        images = backend.numpy.expand_dims(images, axis=0)\n    patches = backend.nn.conv(\n        inputs=images,\n        kernel=kernel,\n        strides=strides,\n        padding=padding,\n        data_format=data_format,\n        dilation_rate=dilation_rate,\n    )\n    if _unbatched:\n        patches = backend.numpy.squeeze(patches, axis=0)\n    return patches\n\n\ndef _extract_patches_3d(\n    volumes,\n    size,\n    strides=None,\n    dilation_rate=1,\n    padding=\"valid\",\n    data_format=None,\n):\n    if isinstance(size, int):\n        patch_d = patch_h = patch_w = size\n    elif len(size) == 3:\n        patch_d, patch_h, patch_w = size\n    else:\n        raise TypeError(\n            \"Invalid `size` argument. Expected an \"\n            f\"int or a tuple of length 3. Received: size={size}\"\n        )\n    if strides is None:\n        strides = size\n    if isinstance(strides, int):\n        strides = (strides, strides, strides)\n    if len(strides) != 3:\n        raise ValueError(f\"Invalid `strides` argument. 
Got: {strides}\")\n    data_format = backend.standardize_data_format(data_format)\n    if data_format == \"channels_last\":\n        channels_in = volumes.shape[-1]\n    elif data_format == \"channels_first\":\n        channels_in = volumes.shape[-4]\n    out_dim = patch_d * patch_w * patch_h * channels_in\n    kernel = backend.numpy.eye(out_dim, dtype=volumes.dtype)\n    kernel = backend.numpy.reshape(\n        kernel, (patch_d, patch_h, patch_w, channels_in, out_dim)\n    )\n    _unbatched = False\n    if len(volumes.shape) == 4:\n        _unbatched = True\n        volumes = backend.numpy.expand_dims(volumes, axis=0)\n    patches = backend.nn.conv(\n        inputs=volumes,\n        kernel=kernel,\n        strides=strides,\n        padding=padding,\n        data_format=data_format,\n        dilation_rate=dilation_rate,\n    )\n    if _unbatched:\n        patches = backend.numpy.squeeze(patches, axis=0)\n    return patches\n\n\n@keras_export(\"keras.ops.image.extract_patches_3d\")\ndef extract_patches_3d(\n    volumes,\n    size,\n    strides=None,\n    dilation_rate=1,\n    padding=\"valid\",\n    data_format=None,\n):\n    \"\"\"Extracts patches from the volume(s).\n\n    Args:\n        volumes: Input volume or batch of volumes. Must be 4D or 5D.\n        size: Patch size int or tuple (patch_depth, patch_height, patch_width)\n        strides: strides along depth, height, and width. If not specified, or\n            if `None`, it defaults to the same value as `size`.\n        dilation_rate: This is the input stride, specifying how far two\n            consecutive patch samples are in the input. 
Note that using\n            `dilation_rate > 1` is not supported in conjunction with\n            `strides > 1` on the TensorFlow backend.\n        padding: The type of padding algorithm to use: `\"same\"` or `\"valid\"`.\n        data_format: A string specifying the data format of the input tensor.\n            It can be either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, depth, height, width, channels)`, while `\"channels_first\"`\n            corresponds to inputs with shape\n            `(batch, channels, depth, height, width)`. If not specified,\n             the value will default to `keras.config.image_data_format()`.\n\n    Returns:\n        Extracted patches 4D (if not batched) or 5D (if batched)\n\n    Examples:\n\n    >>> import numpy as np\n    >>> import keras\n    >>> # Batched case\n    >>> volumes = np.random.random(\n    ...     (2, 10, 10, 10, 3)\n    ... ).astype(\"float32\") # batch of 2 volumes\n    >>> patches = keras.ops.image.extract_patches_3d(volumes, (3, 3, 3))\n    >>> patches.shape\n    (2, 3, 3, 3, 81)\n    >>> # Unbatched case\n    >>> volume = np.random.random((10, 10, 10, 3)).astype(\"float32\") # 1 volume\n    >>> patches = keras.ops.image.extract_patches_3d(volume, (3, 3, 3))\n    >>> patches.shape\n    (3, 3, 3, 81)\n    \"\"\"\n    # Convert int to 3-tuple for 3D\n    if isinstance(size, int):\n        size = (size, size, size)\n    if any_symbolic_tensors((volumes,)):\n        return ExtractPatches(\n            size=size,\n            strides=strides,\n            dilation_rate=dilation_rate,\n            padding=padding,\n            data_format=data_format,\n        ).symbolic_call(volumes)\n\n    return _extract_patches_3d(\n        volumes, size, strides, dilation_rate, padding, data_format=data_format\n    )\n\n\nclass MapCoordinates(Operation):\n    def __init__(self, order, fill_mode=\"constant\", fill_value=0, *, name=None):\n      
  super().__init__(name=name)\n        self.order = order\n        self.fill_mode = fill_mode\n        self.fill_value = fill_value\n\n    def call(self, inputs, coordinates):\n        return backend.image.map_coordinates(\n            inputs,\n            coordinates,\n            order=self.order,\n            fill_mode=self.fill_mode,\n            fill_value=self.fill_value,\n        )\n\n    def compute_output_spec(self, inputs, coordinates):\n        if coordinates.shape[0] != len(inputs.shape):\n            raise ValueError(\n                \"First dim of `coordinates` must be the same as the rank of \"\n                \"`inputs`. \"\n                f\"Received inputs with shape: {inputs.shape} and coordinate \"\n                f\"leading dim of {coordinates.shape[0]}\"\n            )\n        if len(coordinates.shape) < 2:\n            raise ValueError(\n                \"Invalid coordinates rank: expected at least rank 2.\"\n                f\" Received input with shape: {coordinates.shape}\"\n            )\n        return KerasTensor(coordinates.shape[1:], dtype=inputs.dtype)\n\n\n@keras_export(\"keras.ops.image.map_coordinates\")\ndef map_coordinates(\n    inputs, coordinates, order, fill_mode=\"constant\", fill_value=0\n):\n    \"\"\"Map the input array to new coordinates by interpolation.\n\n    Note that interpolation near boundaries differs from the scipy function,\n    because we fixed an outstanding bug\n    [scipy/issues/2640](https://github.com/scipy/scipy/issues/2640).\n\n    Args:\n        inputs: The input array.\n        coordinates: The coordinates at which inputs is evaluated.\n        order: The order of the spline interpolation. The order must be `0` or\n            `1`. `0` indicates the nearest neighbor and `1` indicates the linear\n            interpolation.\n        fill_mode: Points outside the boundaries of the inputs are filled\n            according to the given mode. 
Available methods are `"constant"`,
            `"nearest"`, `"wrap"`, `"mirror"` and `"reflect"`. Defaults to
            `"constant"`.
            - `"constant"`: `(k k k k | a b c d | k k k k)`
                The input is extended by filling all values beyond
                the edge with the same constant value k specified by
                `fill_value`.
            - `"nearest"`: `(a a a a | a b c d | d d d d)`
                The input is extended by the nearest pixel.
            - `"wrap"`: `(a b c d | a b c d | a b c d)`
                The input is extended by wrapping around to the opposite edge.
            - `"mirror"`: `(c d c b | a b c d | c b a b)`
                The input is extended by mirroring about the edge.
            - `"reflect"`: `(d c b a | a b c d | d c b a)`
                The input is extended by reflecting about the edge of the last
                pixel.
        fill_value: Value used for points outside the boundaries of the input
            if `fill_mode="constant"`. 
Defaults to `0`.

    Returns:
        Tensor of interpolated values, with shape `coordinates.shape[1:]`.

    Example:

    >>> x = keras.ops.reshape(keras.ops.arange(9.0), (3, 3))
    >>> # Sample at (row, col) points (0.5, 0.5) and (1.0, 2.0).
    >>> coordinates = keras.ops.convert_to_tensor([[0.5, 1.0], [0.5, 2.0]])
    >>> y = keras.ops.image.map_coordinates(x, coordinates, order=1)
    >>> y.shape
    (2,)
    """
    if any_symbolic_tensors((inputs, coordinates)):
        return MapCoordinates(
            order,
            fill_mode,
            fill_value,
        ).symbolic_call(inputs, coordinates)
    return backend.image.map_coordinates(
        inputs,
        coordinates,
        order,
        fill_mode,
        fill_value,
    )


class PadImages(Operation):
    def __init__(
        self,
        top_padding=None,
        left_padding=None,
        bottom_padding=None,
        right_padding=None,
        target_height=None,
        target_width=None,
        data_format=None,
        *,
        name=None,
    ):
        super().__init__(name=name)
        self.top_padding = top_padding
        self.left_padding = left_padding
        self.bottom_padding = bottom_padding
        self.right_padding = right_padding
        self.target_height = target_height
        self.target_width = target_width
        self.data_format = backend.standardize_data_format(data_format)

    def call(self, images):
        return _pad_images(
            images,
            self.top_padding,
            self.left_padding,
            self.bottom_padding,
            self.right_padding,
            self.target_height,
            self.target_width,
            self.data_format,
        )

    def compute_output_spec(self, images):
        images_shape = list(images.shape)

        if self.data_format == "channels_last":
            height_axis, width_axis = -3, -2
        else:
            height_axis, width_axis = -2, -1
        height, width = images_shape[height_axis], images_shape[width_axis]

        target_height = self.target_height
        if target_height is None and height is not None:
            target_height = self.top_padding + height + 
self.bottom_padding\n        target_width = self.target_width\n        if target_width is None and width is not None:\n            target_width = self.left_padding + width + self.right_padding\n\n        images_shape[height_axis] = target_height\n        images_shape[width_axis] = target_width\n        return KerasTensor(shape=images_shape, dtype=images.dtype)\n\n\n@keras_export(\"keras.ops.image.pad_images\")\ndef pad_images(\n    images,\n    top_padding=None,\n    left_padding=None,\n    bottom_padding=None,\n    right_padding=None,\n    target_height=None,\n    target_width=None,\n    data_format=None,\n):\n    \"\"\"Pad `images` with zeros to the specified `height` and `width`.\n\n    Args:\n        images: Input image or batch of images. Must be 3D or 4D.\n        top_padding: Number of rows of zeros to add on top.\n        left_padding: Number of columns of zeros to add on the left.\n        bottom_padding: Number of rows of zeros to add at the bottom.\n        right_padding: Number of columns of zeros to add on the right.\n        target_height: Height of output images.\n        target_width: Width of output images.\n        data_format: A string specifying the data format of the input tensor.\n            It can be either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)`, while `\"channels_first\"`\n            corresponds to inputs with shape `(batch, channels, height, width)`.\n            If not specified, the value will default to\n            `keras.config.image_data_format`.\n\n    Returns:\n        Padded image or batch of images.\n\n    Example:\n\n    >>> images = np.random.random((15, 25, 3))\n    >>> padded_images = keras.ops.image.pad_images(\n    ...     images, 2, 3, target_height=20, target_width=30\n    ... 
)\n    >>> padded_images.shape\n    (20, 30, 3)\n\n    >>> batch_images = np.random.random((2, 15, 25, 3))\n    >>> padded_batch = keras.ops.image.pad_images(\n    ...     batch_images, 2, 3, target_height=20, target_width=30\n    ... )\n    >>> padded_batch.shape\n    (2, 20, 30, 3)\"\"\"\n\n    if any_symbolic_tensors((images,)):\n        return PadImages(\n            top_padding,\n            left_padding,\n            bottom_padding,\n            right_padding,\n            target_height,\n            target_width,\n            data_format,\n        ).symbolic_call(images)\n\n    return _pad_images(\n        images,\n        top_padding,\n        left_padding,\n        bottom_padding,\n        right_padding,\n        target_height,\n        target_width,\n        data_format,\n    )\n\n\ndef _pad_images(\n    images,\n    top_padding,\n    left_padding,\n    bottom_padding,\n    right_padding,\n    target_height,\n    target_width,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    images = backend.convert_to_tensor(images)\n    images_shape = ops.shape(images)\n\n    # Check\n    if len(images_shape) not in (3, 4):\n        raise ValueError(\n            f\"Invalid shape for argument `images`: \"\n            \"it must have rank 3 or 4. \"\n            f\"Received: images.shape={images_shape}\"\n        )\n    if [top_padding, bottom_padding, target_height].count(None) != 1:\n        raise ValueError(\n            \"Must specify exactly two of \"\n            \"top_padding, bottom_padding, target_height. \"\n            f\"Received: top_padding={top_padding}, \"\n            f\"bottom_padding={bottom_padding}, \"\n            f\"target_height={target_height}\"\n        )\n    if [left_padding, right_padding, target_width].count(None) != 1:\n        raise ValueError(\n            \"Must specify exactly two of \"\n            \"left_padding, right_padding, target_width. 
\"\n            f\"Received: left_padding={left_padding}, \"\n            f\"right_padding={right_padding}, \"\n            f\"target_width={target_width}\"\n        )\n\n    is_batch = False if len(images_shape) == 3 else True\n    if data_format == \"channels_last\":\n        height, width = images_shape[-3], images_shape[-2]\n    else:\n        height, width = images_shape[-2], images_shape[-1]\n\n    # Infer padding\n    if top_padding is None:\n        top_padding = target_height - bottom_padding - height\n    if bottom_padding is None:\n        bottom_padding = target_height - top_padding - height\n    if left_padding is None:\n        left_padding = target_width - right_padding - width\n    if right_padding is None:\n        right_padding = target_width - left_padding - width\n\n    if top_padding < 0:\n        raise ValueError(\n            f\"top_padding must be >= 0. Received: top_padding={top_padding}\"\n        )\n    if left_padding < 0:\n        raise ValueError(\n            f\"left_padding must be >= 0. Received: left_padding={left_padding}\"\n        )\n    if right_padding < 0:\n        raise ValueError(\n            \"right_padding must be >= 0. \"\n            f\"Received: right_padding={right_padding}\"\n        )\n    if bottom_padding < 0:\n        raise ValueError(\n            \"bottom_padding must be >= 0. 
\"\n            f\"Received: bottom_padding={bottom_padding}\"\n        )\n\n    # Compute pad_width\n    pad_width = [[top_padding, bottom_padding], [left_padding, right_padding]]\n    if data_format == \"channels_last\":\n        pad_width = pad_width + [[0, 0]]\n    else:\n        pad_width = [[0, 0]] + pad_width\n    if is_batch:\n        pad_width = [[0, 0]] + pad_width\n\n    padded_images = backend.numpy.pad(images, pad_width)\n    return padded_images\n\n\nclass CropImages(Operation):\n    def __init__(\n        self,\n        top_cropping=None,\n        left_cropping=None,\n        bottom_cropping=None,\n        right_cropping=None,\n        target_height=None,\n        target_width=None,\n        data_format=None,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.top_cropping = top_cropping\n        self.bottom_cropping = bottom_cropping\n        self.left_cropping = left_cropping\n        self.right_cropping = right_cropping\n        self.target_height = target_height\n        self.target_width = target_width\n        self.data_format = backend.standardize_data_format(data_format)\n\n    def call(self, images):\n        return _crop_images(\n            images,\n            self.top_cropping,\n            self.left_cropping,\n            self.bottom_cropping,\n            self.right_cropping,\n            self.target_height,\n            self.target_width,\n            self.data_format,\n        )\n\n    def compute_output_spec(self, images):\n        images_shape = list(images.shape)\n\n        if self.data_format == \"channels_last\":\n            height_axis, width_axis = -3, -2\n        else:\n            height_axis, width_axis = -2, -1\n        height, width = images_shape[height_axis], images_shape[width_axis]\n\n        if height is None and self.target_height is None:\n            raise ValueError(\n                \"When the height of the images is unknown, `target_height` \"\n                \"must be 
specified. "
                f"Received images.shape={images_shape} and "
                f"target_height={self.target_height}"
            )
        if width is None and self.target_width is None:
            raise ValueError(
                "When the width of the images is unknown, `target_width` "
                "must be specified. "
                f"Received images.shape={images_shape} and "
                f"target_width={self.target_width}"
            )

        target_height = self.target_height
        if target_height is None:
            target_height = height - self.top_cropping - self.bottom_cropping
        target_width = self.target_width
        if target_width is None:
            target_width = width - self.left_cropping - self.right_cropping

        images_shape[height_axis] = target_height
        images_shape[width_axis] = target_width
        return KerasTensor(shape=images_shape, dtype=images.dtype)


@keras_export("keras.ops.image.crop_images")
def crop_images(
    images,
    top_cropping=None,
    left_cropping=None,
    bottom_cropping=None,
    right_cropping=None,
    target_height=None,
    target_width=None,
    data_format=None,
):
    """Crop `images` to a specified `height` and `width`.

    Args:
        images: Input image or batch of images. 
Must be 3D or 4D.
        top_cropping: Number of rows to crop from the top.
        left_cropping: Number of columns to crop from the left.
        bottom_cropping: Number of rows to crop from the bottom.
        right_cropping: Number of columns to crop from the right.
        target_height: Height of the output images.
        target_width: Width of the output images.
        data_format: A string specifying the data format of the input tensor.
            It can be either `"channels_last"` or `"channels_first"`.
            `"channels_last"` corresponds to inputs with shape
            `(batch, height, width, channels)`, while `"channels_first"`
            corresponds to inputs with shape `(batch, channels, height, width)`.
            If not specified, the value will default to
            `keras.config.image_data_format`.

    Returns:
        Cropped image or batch of images.

    Example:

    >>> images = np.reshape(np.arange(1, 28, dtype="float32"), [3, 3, 3])
    >>> images[:,:,0] # print the first channel of the images
    array([[ 1.,  4.,  7.],
           [10., 13., 16.],
           [19., 22., 25.]], dtype=float32)
    >>> cropped_images = keras.ops.image.crop_images(
    ...     images, 0, 0, target_height=2, target_width=2
    ... )
    >>> cropped_images[:,:,0] # print the first channel of the cropped images
    array([[ 1.,  4.],
           [10., 13.]], dtype=float32)"""

    if any_symbolic_tensors((images,)):
        return CropImages(
            top_cropping,
            left_cropping,
            bottom_cropping,
            right_cropping,
            target_height,
            target_width,
            data_format,
        ).symbolic_call(images)

    return _crop_images(
        images,
        top_cropping,
        left_cropping,
        bottom_cropping,
        right_cropping,
        target_height,
        target_width,
        data_format,
    )


def _crop_images(
    images,
    top_cropping,
    
left_cropping,\n    bottom_cropping,\n    right_cropping,\n    target_height,\n    target_width,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n    images = backend.convert_to_tensor(images)\n    images_shape = ops.shape(images)\n\n    # Check\n    if len(images_shape) not in (3, 4):\n        raise ValueError(\n            f\"Invalid shape for argument `images`: \"\n            \"it must have rank 3 or 4. \"\n            f\"Received: images.shape={images_shape}\"\n        )\n    if [top_cropping, bottom_cropping, target_height].count(None) != 1:\n        raise ValueError(\n            \"Must specify exactly two of \"\n            \"top_cropping, bottom_cropping, target_height. \"\n            f\"Received: top_cropping={top_cropping}, \"\n            f\"bottom_cropping={bottom_cropping}, \"\n            f\"target_height={target_height}\"\n        )\n    if [left_cropping, right_cropping, target_width].count(None) != 1:\n        raise ValueError(\n            \"Must specify exactly two of \"\n            \"left_cropping, right_cropping, target_width. 
\"\n            f\"Received: left_cropping={left_cropping}, \"\n            f\"right_cropping={right_cropping}, \"\n            f\"target_width={target_width}\"\n        )\n\n    is_batch = False if len(images_shape) == 3 else True\n    if data_format == \"channels_last\":\n        height, width = images_shape[-3], images_shape[-2]\n        channels = images_shape[-1]\n    else:\n        height, width = images_shape[-2], images_shape[-1]\n        channels = images_shape[-3]\n\n    # Infer padding\n    if top_cropping is None:\n        top_cropping = height - target_height - bottom_cropping\n    if target_height is None:\n        target_height = height - bottom_cropping - top_cropping\n    if left_cropping is None:\n        left_cropping = width - target_width - right_cropping\n    if target_width is None:\n        target_width = width - right_cropping - left_cropping\n\n    if top_cropping < 0:\n        raise ValueError(\n            f\"top_cropping must be >= 0. Received: top_cropping={top_cropping}\"\n        )\n    if target_height < 0:\n        raise ValueError(\n            \"target_height must be >= 0. \"\n            f\"Received: target_height={target_height}\"\n        )\n    if left_cropping < 0:\n        raise ValueError(\n            \"left_cropping must be >= 0. \"\n            f\"Received: left_cropping={left_cropping}\"\n        )\n    if target_width < 0:\n        raise ValueError(\n            f\"target_width must be >= 0. 
Received: target_width={target_width}\"\n        )\n\n    # Compute start_indices and shape\n    start_indices = [top_cropping, left_cropping]\n    shape = [target_height, target_width]\n    if data_format == \"channels_last\":\n        start_indices = start_indices + [0]\n        shape = shape + [channels]\n    else:\n        start_indices = [0] + start_indices\n        shape = [channels] + shape\n    if is_batch:\n        batch_size = images_shape[0]\n        start_indices = [0] + start_indices\n        shape = [batch_size] + shape\n\n    cropped_images = ops.slice(images, start_indices, shape)\n    return cropped_images\n\n\nclass PerspectiveTransform(Operation):\n    def __init__(\n        self,\n        interpolation=\"bilinear\",\n        fill_value=0,\n        data_format=None,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.interpolation = interpolation\n        self.fill_value = fill_value\n        self.data_format = backend.standardize_data_format(data_format)\n\n    def call(self, images, start_points, end_points):\n        return backend.image.perspective_transform(\n            images,\n            start_points,\n            end_points,\n            interpolation=self.interpolation,\n            fill_value=self.fill_value,\n            data_format=self.data_format,\n        )\n\n    def compute_output_spec(self, images, start_points, end_points):\n        if len(images.shape) not in (3, 4):\n            raise ValueError(\n                \"Invalid images rank: expected rank 3 (single image) \"\n                \"or rank 4 (batch of images). Received input with shape: \"\n                f\"images.shape={images.shape}\"\n            )\n        if start_points.shape[-2:] != (4, 2) or start_points.ndim not in (2, 3):\n            raise ValueError(\n                \"Invalid start_points shape: expected (4,2) for a single image\"\n                f\" or (N,4,2) for a batch. 
Received shape: {start_points.shape}\"\n            )\n        if end_points.shape[-2:] != (4, 2) or end_points.ndim not in (2, 3):\n            raise ValueError(\n                \"Invalid end_points shape: expected (4,2) for a single image\"\n                f\" or (N,4,2) for a batch. Received shape: {end_points.shape}\"\n            )\n        if start_points.shape != end_points.shape:\n            raise ValueError(\n                \"start_points and end_points must have the same shape.\"\n                f\" Received start_points.shape={start_points.shape}, \"\n                f\"end_points.shape={end_points.shape}\"\n            )\n        return KerasTensor(images.shape, dtype=images.dtype)\n\n\n@keras_export(\"keras.ops.image.perspective_transform\")\ndef perspective_transform(\n    images,\n    start_points,\n    end_points,\n    interpolation=\"bilinear\",\n    fill_value=0,\n    data_format=None,\n):\n    \"\"\"Applies a perspective transformation to the image(s).\n\n    Args:\n        images: Input image or batch of images. Must be 3D or 4D.\n        start_points: A tensor of shape `(N, 4, 2)` or `(4, 2)`,\n            representing the source points in the original image\n            that define the transformation.\n        end_points: A tensor of shape `(N, 4, 2)` or `(4, 2)`,\n            representing the target points in the output image\n            after transformation.\n        interpolation: Interpolation method. Available methods are `\"nearest\"`,\n            and `\"bilinear\"`. Defaults to `\"bilinear\"`.\n        fill_value: Value used for points outside the boundaries of the input if\n            extrapolation is needed. 
Defaults to `0`.\n        data_format: A string specifying the data format of the input tensor.\n            It can be either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)`, while `\"channels_first\"`\n            corresponds to inputs with shape `(batch, channels, height, width)`.\n            If not specified, the value will default to\n            `keras.config.image_data_format`.\n\n    Returns:\n        Applied perspective transform image or batch of images.\n\n    Examples:\n\n    >>> x = np.random.random((2, 64, 80, 3))  # batch of 2 RGB images\n    >>> start_points = np.array(\n    ...     [\n    ...         [[0, 0], [0, 64], [80, 0], [80, 64]],\n    ...         [[0, 0], [0, 64], [80, 0], [80, 64]],\n    ...     ]\n    ... )\n    >>> end_points = np.array(\n    ...     [\n    ...         [[3, 5], [7, 64], [76, -10], [84, 61]],\n    ...         [[8, 10], [10, 61], [65, 3], [88, 43]],\n    ...     ]\n    ... )\n    >>> y = keras.ops.image.perspective_transform(x, start_points, end_points)\n    >>> y.shape\n    (2, 64, 80, 3)\n\n    >>> x = np.random.random((64, 80, 3))  # single RGB image\n    >>> start_points = np.array([[0, 0], [0, 64], [80, 0], [80, 64]])\n    >>> end_points = np.array([[3, 5], [7, 64], [76, -10], [84, 61]])\n    >>> y = keras.ops.image.perspective_transform(x, start_points, end_points)\n    >>> y.shape\n    (64, 80, 3)\n\n    >>> x = np.random.random((2, 3, 64, 80))  # batch of 2 RGB images\n    >>> start_points = np.array(\n    ...     [\n    ...         [[0, 0], [0, 64], [80, 0], [80, 64]],\n    ...         [[0, 0], [0, 64], [80, 0], [80, 64]],\n    ...     ]\n    ... )\n    >>> end_points = np.array(\n    ...     [\n    ...         [[3, 5], [7, 64], [76, -10], [84, 61]],\n    ...         [[8, 10], [10, 61], [65, 3], [88, 43]],\n    ...     ]\n    ... )\n    >>> y = keras.ops.image.perspective_transform(\n    ...     
x, start_points, end_points, data_format=\"channels_first\"\n    ... )\n    >>> y.shape\n    (2, 3, 64, 80)\n    \"\"\"\n    if any_symbolic_tensors((images, start_points, end_points)):\n        return PerspectiveTransform(\n            interpolation=interpolation,\n            fill_value=fill_value,\n            data_format=data_format,\n        ).symbolic_call(images, start_points, end_points)\n    return backend.image.perspective_transform(\n        images,\n        start_points,\n        end_points,\n        interpolation=interpolation,\n        fill_value=fill_value,\n        data_format=data_format,\n    )\n\n\nclass GaussianBlur(Operation):\n    def __init__(\n        self,\n        kernel_size=(3, 3),\n        sigma=(1.0, 1.0),\n        data_format=None,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.kernel_size = kernel_size\n        self.sigma = sigma\n        self.data_format = backend.standardize_data_format(data_format)\n\n    def call(self, images):\n        return backend.image.gaussian_blur(\n            images,\n            kernel_size=self.kernel_size,\n            sigma=self.sigma,\n            data_format=self.data_format,\n        )\n\n    def compute_output_spec(self, images):\n        if len(images.shape) not in (3, 4):\n            raise ValueError(\n                \"Invalid images rank: expected rank 3 (single image) \"\n                \"or rank 4 (batch of images). Received input with shape: \"\n                f\"images.shape={images.shape}\"\n            )\n        return KerasTensor(images.shape, dtype=images.dtype)\n\n\n@keras_export(\"keras.ops.image.gaussian_blur\")\ndef gaussian_blur(\n    images, kernel_size=(3, 3), sigma=(1.0, 1.0), data_format=None\n):\n    \"\"\"Applies a Gaussian blur to the image(s).\n\n    Args:\n        images: Input image or batch of images. 
Must be 3D or 4D.\n        kernel_size: A tuple of two integers, specifying the height and width\n            of the Gaussian kernel.\n        sigma: A tuple of two floats, specifying the standard deviation of\n            the Gaussian kernel along height and width.\n        data_format: A string specifying the data format of the input tensor.\n            It can be either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)`, while `\"channels_first\"`\n            corresponds to inputs with shape `(batch, channels, height, width)`.\n            If not specified, the value will default to\n            `keras.config.image_data_format`.\n\n    Returns:\n        Blurred image or batch of images.\n\n    Examples:\n\n    >>> x = np.random.random((2, 64, 80, 3))  # batch of 2 RGB images\n    >>> y = keras.ops.image.gaussian_blur(x)\n    >>> y.shape\n    (2, 64, 80, 3)\n\n    >>> x = np.random.random((64, 80, 3))  # single RGB image\n    >>> y = keras.ops.image.gaussian_blur(x)\n    >>> y.shape\n    (64, 80, 3)\n\n    >>> x = np.random.random((2, 3, 64, 80))  # batch of 2 RGB images\n    >>> y = keras.ops.image.gaussian_blur(\n    ...     
x, data_format=\"channels_first\")\n    >>> y.shape\n    (2, 3, 64, 80)\n    \"\"\"\n    if any_symbolic_tensors((images,)):\n        return GaussianBlur(\n            kernel_size=kernel_size,\n            sigma=sigma,\n            data_format=data_format,\n        ).symbolic_call(images)\n    return backend.image.gaussian_blur(\n        images,\n        kernel_size=kernel_size,\n        sigma=sigma,\n        data_format=data_format,\n    )\n\n\nclass ElasticTransform(Operation):\n    def __init__(\n        self,\n        alpha=20.0,\n        sigma=5.0,\n        interpolation=\"bilinear\",\n        fill_mode=\"reflect\",\n        fill_value=0.0,\n        seed=None,\n        data_format=None,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.alpha = alpha\n        self.sigma = sigma\n        self.interpolation = interpolation\n        self.fill_mode = fill_mode\n        self.fill_value = fill_value\n        self.seed = seed\n        self.data_format = backend.standardize_data_format(data_format)\n\n    def call(self, images):\n        return backend.image.elastic_transform(\n            images,\n            alpha=self.alpha,\n            sigma=self.sigma,\n            interpolation=self.interpolation,\n            fill_mode=self.fill_mode,\n            fill_value=self.fill_value,\n            seed=self.seed,\n            data_format=self.data_format,\n        )\n\n    def compute_output_spec(self, images):\n        if len(images.shape) not in (3, 4):\n            raise ValueError(\n                \"Invalid images rank: expected rank 3 (single image) \"\n                \"or rank 4 (batch of images). 
Received input with shape: \"\n                f\"images.shape={images.shape}\"\n            )\n        return KerasTensor(images.shape, dtype=images.dtype)\n\n\n@keras_export(\"keras.ops.image.elastic_transform\")\ndef elastic_transform(\n    images,\n    alpha=20.0,\n    sigma=5.0,\n    interpolation=\"bilinear\",\n    fill_mode=\"reflect\",\n    fill_value=0.0,\n    seed=None,\n    data_format=None,\n):\n    \"\"\"Applies elastic deformation to the image(s).\n\n    Args:\n        images: Input image or batch of images. Must be 3D or 4D.\n        alpha: Scaling factor that controls the intensity of the deformation.\n        sigma: Standard deviation of the Gaussian filter used for\n            smoothing the displacement fields.\n        interpolation: Interpolation method. Available methods are `\"nearest\"`,\n            and `\"bilinear\"`. Defaults to `\"bilinear\"`.\n        fill_mode: Points outside the boundaries of the input are filled\n            according to the given mode. Available methods are `\"constant\"`,\n            `\"nearest\"`, `\"wrap\"` and `\"reflect\"`. Defaults to `\"reflect\"`.\n            - `\"reflect\"`: `(d c b a | a b c d | d c b a)`\n                The input is extended by reflecting about the edge of the last\n                pixel.\n            - `\"constant\"`: `(k k k k | a b c d | k k k k)`\n                The input is extended by filling all values beyond\n                the edge with the same constant value k specified by\n                `fill_value`.\n            - `\"wrap\"`: `(a b c d | a b c d | a b c d)`\n                The input is extended by wrapping around to the opposite edge.\n            - `\"nearest\"`: `(a a a a | a b c d | d d d d)`\n                The input is extended by the nearest pixel.\n        fill_value: Value used for points outside the boundaries of the input if\n            `fill_mode=\"constant\"`. 
Defaults to `0`.\n        data_format: A string specifying the data format of the input tensor.\n            It can be either `\"channels_last\"` or `\"channels_first\"`.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)`, while `\"channels_first\"`\n            corresponds to inputs with shape `(batch, channels, height, width)`.\n            If not specified, the value will default to\n            `keras.config.image_data_format`.\n\n    Returns:\n        Transformed image or batch of images with elastic deformation.\n\n    Examples:\n\n    >>> x = np.random.random((2, 64, 80, 3))  # batch of 2 RGB images\n    >>> y = keras.ops.image.elastic_transform(x)\n    >>> y.shape\n    (2, 64, 80, 3)\n\n    >>> x = np.random.random((64, 80, 3))  # single RGB image\n    >>> y = keras.ops.image.elastic_transform(x)\n    >>> y.shape\n    (64, 80, 3)\n\n    >>> x = np.random.random((2, 3, 64, 80))  # batch of 2 RGB images\n    >>> y = keras.ops.image.elastic_transform(\n    ...     
x, data_format=\"channels_first\")\n    >>> y.shape\n    (2, 3, 64, 80)\n    \"\"\"\n    if any_symbolic_tensors((images,)):\n        return ElasticTransform(\n            alpha=alpha,\n            sigma=sigma,\n            interpolation=interpolation,\n            fill_mode=fill_mode,\n            fill_value=fill_value,\n            seed=seed,\n            data_format=data_format,\n        ).symbolic_call(images)\n    return backend.image.elastic_transform(\n        images,\n        alpha=alpha,\n        sigma=sigma,\n        interpolation=interpolation,\n        fill_mode=fill_mode,\n        fill_value=fill_value,\n        seed=seed,\n        data_format=data_format,\n    )\n\n\nclass ScaleAndTranslate(Operation):\n    def __init__(self, spatial_dims, method, antialias=True, *, name=None):\n        super().__init__(name=name)\n        self.spatial_dims = spatial_dims\n        self.method = method\n        self.antialias = antialias\n\n    def call(self, images, output_shape, scale, translation):\n        return backend.image.scale_and_translate(\n            images,\n            output_shape=output_shape,\n            scale=scale,\n            translation=translation,\n            spatial_dims=self.spatial_dims,\n            method=self.method,\n            antialias=self.antialias,\n        )\n\n    def compute_output_spec(self, images, output_shape, scale, translation):\n        return KerasTensor(output_shape, dtype=images.dtype)\n\n\n@keras_export(\"keras.ops.image.scale_and_translate\")\ndef scale_and_translate(\n    images,\n    output_shape,\n    scale,\n    translation,\n    spatial_dims,\n    method,\n    antialias=True,\n):\n    \"\"\"Apply a scale and translation to the images.\n\n    Generates a new image of `output_shape` by resampling from the input image\n    using the sampling method corresponding to method. 
For 2D images, this\n    operation transforms a location in the input images, (x, y), to a location\n    in the output image according to:\n\n    `(x * scale[1] + translation[1], y * scale[0] + translation[0])`.\n\n    (Note the inverse warp is used to generate the sample locations.) Assumes\n    half-centered pixels, i.e. the pixel at integer location row, col has\n    coordinates y, x = row + 0.5, col + 0.5, and similarly for other input image\n    dimensions.\n\n    If an output location (pixel) maps to an input sample location that is\n    outside the input boundaries, then the value for the output location will\n    be set to zero.\n\n    The `method` argument expects one of the following resize methods:\n\n    - `\"linear\"`, `\"bilinear\"`, `\"trilinear\"`, `\"triangle\"`: Linear\n        interpolation. If `antialias` is True, uses a triangular filter when\n        downsampling.\n    - `\"cubic\"`, `\"bicubic\"`, `\"tricubic\"`: Cubic interpolation, using the Keys\n        cubic kernel.\n    - `\"lanczos3\"`: Lanczos resampling, using a kernel of radius 3.\n    - `\"lanczos5\"`: Lanczos resampling, using a kernel of radius 5.\n\n    Args:\n        images: The input array.\n        output_shape: The output shape, as a sequence of integers with length\n            equal to the number of dimensions of `images`.\n        scale: A [K] array with the same number of dimensions as `images`,\n            containing the scale to apply in each dimension.\n        translation: A [K] array with the same number of dimensions as `images`,\n            containing the translation to apply in each dimension.\n        spatial_dims: A length K tuple specifying the spatial dimensions that\n            the passed `scale` and `translation` should be applied to.\n        method: A string specifying the resizing method to use. 
Available\n            methods are `\"linear\"`, `\"bilinear\"`, `\"trilinear\"`, `\"triangle\"`,\n            `\"cubic\"`, `\"bicubic\"`, `\"tricubic\"`, `\"lanczos3\"` and `\"lanczos5\"`.\n        antialias: Whether an antialiasing filter should be applied when\n            downsampling. Has no effect when upsampling. Defaults to `True`.\n\n    Returns:\n        The scaled and translated images.\n\n    Example:\n\n    >>> images = np.arange(9, dtype=\"float32\").reshape((3, 3))\n    >>> scale = np.array([2.0, 2.0]).astype(\"float32\")\n    >>> translation = -(scale / 2.0 - 0.5)\n    >>> resized_images = keras.ops.image.scale_and_translate(\n    ...     images, (5, 5), scale, translation, (0, 1), \"linear\"\n    ... )\n    >>> resized_images\n    array([[0.0 0.5 1.0 1.5 2.0]\n           [1.5 2.0 2.5 3.0 3.5]\n           [3.0 3.5 4.0 4.5 5.0]\n           [4.5 5.0 5.5 6.0 6.5]\n           [6.0 6.5 7.0 7.5 8.0]], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((images, scale, translation)):\n        return ScaleAndTranslate(spatial_dims, method, antialias).symbolic_call(\n            images, output_shape, scale, translation\n        )\n    return backend.image.scale_and_translate(\n        images,\n        output_shape,\n        scale,\n        translation,\n        spatial_dims,\n        method,\n        antialias,\n    )\n"
  },
  {
    "path": "keras/src/ops/image_test.py",
    "content": "import math\n\nimport jax\nimport numpy as np\nimport pytest\nimport scipy.ndimage\nimport tensorflow as tf\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.ops import image as kimage\nfrom keras.src.ops import numpy as knp\nfrom keras.src.ops import random as krandom\nfrom keras.src.testing.test_utils import named_product\n\n\nclass ImageOpsDynamicShapeTest(testing.TestCase):\n    def setUp(self):\n        # Defaults to channels_last\n        self.data_format = backend.image_data_format()\n        backend.set_image_data_format(\"channels_last\")\n        return super().setUp()\n\n    def tearDown(self):\n        backend.set_image_data_format(self.data_format)\n        return super().tearDown()\n\n    def test_rgb_to_grayscale(self):\n        # Test channels_last\n        x = KerasTensor([None, 20, 20, 3])\n        out = kimage.rgb_to_grayscale(x)\n        self.assertEqual(out.shape, (None, 20, 20, 1))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([None, 3, 20, 20])\n        out = kimage.rgb_to_grayscale(x)\n        self.assertEqual(out.shape, (None, 1, 20, 20))\n\n    def test_rgb_to_hsv(self):\n        # Test channels_last\n        x = KerasTensor([None, 20, 20, 3])\n        out = kimage.rgb_to_hsv(x)\n        self.assertEqual(out.shape, (None, 20, 20, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([None, 3, 20, 20])\n        out = kimage.rgb_to_hsv(x)\n        self.assertEqual(out.shape, (None, 3, 20, 20))\n\n    def test_hsv_to_rgb(self):\n        # Test channels_last\n        x = KerasTensor([None, 20, 20, 3])\n        out = kimage.hsv_to_rgb(x)\n        self.assertEqual(out.shape, (None, 20, 20, 3))\n\n        # Test 
channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([None, 3, 20, 20])\n        out = kimage.hsv_to_rgb(x)\n        self.assertEqual(out.shape, (None, 3, 20, 20))\n\n    def test_resize(self):\n        # Test channels_last\n        x = KerasTensor([None, 20, 20, 3])\n        out = kimage.resize(x, size=(15, 15))\n        self.assertEqual(out.shape, (None, 15, 15, 3))\n\n        x = KerasTensor([None, None, 3])\n        out = kimage.resize(x, size=(15, 15))\n        self.assertEqual(out.shape, (15, 15, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([None, 3, 20, 20])\n        out = kimage.resize(x, size=(15, 15))\n        self.assertEqual(out.shape, (None, 3, 15, 15))\n\n        x = KerasTensor([3, None, None])\n        out = kimage.resize(x, size=(15, 15))\n        self.assertEqual(out.shape, (3, 15, 15))\n\n    def test_affine_transform(self):\n        # Test channels_last\n        x = KerasTensor([None, 20, 20, 3])\n        transform = KerasTensor([None, 8])\n        out = kimage.affine_transform(x, transform)\n        self.assertEqual(out.shape, (None, 20, 20, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([None, 3, 20, 20])\n        transform = KerasTensor([None, 8])\n        out = kimage.affine_transform(x, transform)\n        self.assertEqual(out.shape, (None, 3, 20, 20))\n\n    def test_extract_patches(self):\n        # Test channels_last\n        x = KerasTensor([None, 20, 20, 3])\n        p_h, p_w = 5, 5\n        out = kimage.extract_patches(x, (p_h, p_w))\n        self.assertEqual(out.shape, (None, 4, 4, 75))\n        out = kimage.extract_patches(x, 5)\n        self.assertEqual(out.shape, (None, 4, 4, 75))\n        out = kimage.extract_patches(x, 5, strides=1)\n        self.assertEqual(out.shape, (None, 16, 16, 75))\n        out = kimage.extract_patches(x, 5, 
strides=(2, 3))\n        self.assertEqual(out.shape, (None, 8, 6, 75))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([None, 3, 20, 20])\n        p_h, p_w = 5, 5\n        out = kimage.extract_patches(x, (p_h, p_w))\n        self.assertEqual(out.shape, (None, 75, 4, 4))\n        out = kimage.extract_patches(x, 5)\n        self.assertEqual(out.shape, (None, 75, 4, 4))\n        out = kimage.extract_patches(x, 5, strides=1)\n        self.assertEqual(out.shape, (None, 75, 16, 16))\n        out = kimage.extract_patches(x, 5, strides=(2, 3))\n        self.assertEqual(out.shape, (None, 75, 8, 6))\n\n    def test_extract_patches_3d(self):\n        # Test channels_last\n        x = KerasTensor([None, 20, 20, 20, 3])\n        p_d, p_h, p_w = 5, 5, 5\n        out = kimage.extract_patches_3d(x, (p_d, p_h, p_w))\n        self.assertEqual(out.shape, (None, 4, 4, 4, 375))\n        out = kimage.extract_patches_3d(x, 5)\n        self.assertEqual(out.shape, (None, 4, 4, 4, 375))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([None, 3, 20, 20, 20])\n        p_d, p_h, p_w = 5, 5, 5\n        out = kimage.extract_patches_3d(x, (p_d, p_h, p_w))\n        self.assertEqual(out.shape, (None, 375, 4, 4, 4))\n        out = kimage.extract_patches_3d(x, 5)\n        self.assertEqual(out.shape, (None, 375, 4, 4, 4))\n\n    def test_map_coordinates(self):\n        input = KerasTensor([20, 20, None])\n        coordinates = KerasTensor([3, 15, 15, None])\n        out = kimage.map_coordinates(input, coordinates, 0)\n        self.assertEqual(out.shape, coordinates.shape[1:])\n\n    def test_pad_images(self):\n        # Test channels_last\n        x = KerasTensor([None, 15, 25, 3])\n        out = kimage.pad_images(x, 2, 3, target_height=20, target_width=30)\n        self.assertEqual(out.shape, (None, 20, 30, 3))\n\n        x = KerasTensor([None, None, 3])\n        
out = kimage.pad_images(x, 2, 3, target_height=20, target_width=30)\n        self.assertEqual(out.shape, (20, 30, 3))\n\n        # Test unknown shape\n        x = KerasTensor([None, None, 3])\n        out = kimage.pad_images(x, 2, 3, 2, 3)\n        self.assertEqual(out.shape, (None, None, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([None, 3, 15, 25])\n        out = kimage.pad_images(x, 2, 3, target_height=20, target_width=30)\n        self.assertEqual(out.shape, (None, 3, 20, 30))\n\n        x = KerasTensor([3, None, None])\n        out = kimage.pad_images(x, 2, 3, target_height=20, target_width=30)\n        self.assertEqual(out.shape, (3, 20, 30))\n\n    def test_crop_images(self):\n        # Test channels_last\n        x = KerasTensor([None, 15, 25, 3])\n        out = kimage.crop_images(x, 2, 3, target_height=10, target_width=20)\n        self.assertEqual(out.shape, (None, 10, 20, 3))\n\n        x = KerasTensor([None, None, 3])\n        out = kimage.crop_images(x, 2, 3, target_height=10, target_width=20)\n        self.assertEqual(out.shape, (10, 20, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([None, 3, 15, 25])\n        out = kimage.crop_images(x, 2, 3, target_height=10, target_width=20)\n        self.assertEqual(out.shape, (None, 3, 10, 20))\n\n        x = KerasTensor([3, None, None])\n        out = kimage.crop_images(x, 2, 3, target_height=10, target_width=20)\n        self.assertEqual(out.shape, (3, 10, 20))\n\n    def test_perspective_transform(self):\n        # Test channels_last\n        x = KerasTensor([None, 20, 20, 3])\n        start_points = KerasTensor([None, 4, 2])\n        end_points = KerasTensor([None, 4, 2])\n        out = kimage.perspective_transform(x, start_points, end_points)\n        self.assertEqual(out.shape, (None, 20, 20, 3))\n\n        # Test channels_first\n        
backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([None, 3, 20, 20])\n        start_points = KerasTensor([None, 4, 2])\n        end_points = KerasTensor([None, 4, 2])\n        out = kimage.perspective_transform(x, start_points, end_points)\n        self.assertEqual(out.shape, (None, 3, 20, 20))\n\n    def test_gaussian_blur(self):\n        # Test channels_last\n        x = KerasTensor([None, 20, 20, 3])\n        out = kimage.gaussian_blur(x)\n        self.assertEqual(out.shape, (None, 20, 20, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([None, 3, 20, 20])\n        out = kimage.gaussian_blur(x)\n        self.assertEqual(out.shape, (None, 3, 20, 20))\n\n    def test_elastic_transform(self):\n        # Test channels_last\n        x = KerasTensor([None, 20, 20, 3])\n        out = kimage.elastic_transform(x)\n        self.assertEqual(out.shape, (None, 20, 20, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([None, 3, 20, 20])\n        out = kimage.elastic_transform(x)\n        self.assertEqual(out.shape, (None, 3, 20, 20))\n\n    def test_scale_and_translate(self):\n        images = KerasTensor([None, 20, 20, 3])\n        output_shape = (None, 25, 25, 3)\n        scale = KerasTensor([2])\n        translation = KerasTensor([2])\n        out = kimage.scale_and_translate(\n            images,\n            output_shape=output_shape,\n            scale=scale,\n            translation=translation,\n            spatial_dims=(1, 2),\n            method=\"linear\",\n        )\n        self.assertEqual(out.shape, output_shape)\n\n\nclass ImageOpsStaticShapeTest(testing.TestCase):\n    def setUp(self):\n        # Defaults to channels_last\n        self.data_format = backend.image_data_format()\n        backend.set_image_data_format(\"channels_last\")\n        return super().setUp()\n\n    def tearDown(self):\n  
      backend.set_image_data_format(self.data_format)\n        return super().tearDown()\n\n    def test_rgb_to_grayscale(self):\n        # Test channels_last\n        x = KerasTensor([20, 20, 3])\n        out = kimage.rgb_to_grayscale(x)\n        self.assertEqual(out.shape, (20, 20, 1))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([3, 20, 20])\n        out = kimage.rgb_to_grayscale(x)\n        self.assertEqual(out.shape, (1, 20, 20))\n\n    def test_rgb_to_hsv(self):\n        # Test channels_last\n        x = KerasTensor([20, 20, 3])\n        out = kimage.rgb_to_hsv(x)\n        self.assertEqual(out.shape, (20, 20, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([3, 20, 20])\n        out = kimage.rgb_to_hsv(x)\n        self.assertEqual(out.shape, (3, 20, 20))\n\n    def test_hsv_to_rgb(self):\n        # Test channels_last\n        x = KerasTensor([20, 20, 3])\n        out = kimage.hsv_to_rgb(x)\n        self.assertEqual(out.shape, (20, 20, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([3, 20, 20])\n        out = kimage.hsv_to_rgb(x)\n        self.assertEqual(out.shape, (3, 20, 20))\n\n    def test_resize(self):\n        # Test channels_last\n        x = KerasTensor([20, 20, 3])\n        out = kimage.resize(x, size=(15, 15))\n        self.assertEqual(out.shape, (15, 15, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([3, 20, 20])\n        out = kimage.resize(x, size=(15, 15))\n        self.assertEqual(out.shape, (3, 15, 15))\n\n    def test_affine_transform(self):\n        # Test channels_last\n        x = KerasTensor([20, 20, 3])\n        transform = KerasTensor([8])\n        out = kimage.affine_transform(x, transform)\n        self.assertEqual(out.shape, (20, 20, 3))\n\n       
 # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([3, 20, 20])\n        transform = KerasTensor([8])\n        out = kimage.affine_transform(x, transform)\n        self.assertEqual(out.shape, (3, 20, 20))\n\n    def test_extract_patches(self):\n        # Test channels_last\n        x = KerasTensor([20, 20, 3])\n        p_h, p_w = 5, 5\n        out = kimage.extract_patches(x, (p_h, p_w))\n        self.assertEqual(out.shape, (4, 4, 75))\n        out = kimage.extract_patches(x, 5)\n        self.assertEqual(out.shape, (4, 4, 75))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([3, 20, 20])\n        p_h, p_w = 5, 5\n        out = kimage.extract_patches(x, (p_h, p_w))\n        self.assertEqual(out.shape, (75, 4, 4))\n        out = kimage.extract_patches(x, 5)\n        self.assertEqual(out.shape, (75, 4, 4))\n\n    def test_extract_patches_3d(self):\n        # Test channels_last\n        x = KerasTensor([20, 20, 20, 3])\n        p_d, p_h, p_w = 5, 5, 5\n        out = kimage.extract_patches_3d(x, (p_d, p_h, p_w))\n        self.assertEqual(out.shape, (4, 4, 4, 375))\n        out = kimage.extract_patches_3d(x, 5)\n        self.assertEqual(out.shape, (4, 4, 4, 375))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([3, 20, 20, 20])\n        p_d, p_h, p_w = 5, 5, 5\n        out = kimage.extract_patches_3d(x, (p_d, p_h, p_w))\n        self.assertEqual(out.shape, (375, 4, 4, 4))\n        out = kimage.extract_patches_3d(x, 5)\n        self.assertEqual(out.shape, (375, 4, 4, 4))\n\n    def test_map_coordinates(self):\n        input = KerasTensor([20, 20, 3])\n        coordinates = KerasTensor([3, 15, 15, 3])\n        out = kimage.map_coordinates(input, coordinates, 0)\n        self.assertEqual(out.shape, coordinates.shape[1:])\n\n    def test_map_coordinates_uint8(self):\n        image_uint8 
= tf.ones((1, 1, 3), dtype=tf.uint8)\n        coordinates = tf.convert_to_tensor([-1.0, 0.0, 0.0])[..., None, None]\n\n        if backend.backend() != \"tensorflow\":\n            pytest.skip(\"Skipping test because the backend is not TensorFlow.\")\n\n        out = kimage.map_coordinates(\n            image_uint8, coordinates, order=1, fill_mode=\"constant\"\n        )\n        self.assertEqual(out.shape, coordinates.shape[1:])\n\n    def test_map_coordinates_float32(self):\n        image_float32 = tf.ones((1, 1, 3), dtype=tf.float32)\n        coordinates = tf.convert_to_tensor([-1.0, 0.0, 0.0])[..., None, None]\n\n        if backend.backend() != \"tensorflow\":\n            pytest.skip(\"Skipping test because the backend is not TensorFlow.\")\n\n        out = kimage.map_coordinates(\n            image_float32, coordinates, order=1, fill_mode=\"constant\"\n        )\n        self.assertEqual(out.shape, coordinates.shape[1:])\n\n    def test_map_coordinates_nearest(self):\n        image_uint8 = tf.ones((1, 1, 3), dtype=tf.uint8)\n        coordinates = tf.convert_to_tensor([-1.0, 0.0, 0.0])[..., None, None]\n\n        if backend.backend() != \"tensorflow\":\n            pytest.skip(\"Skipping test because the backend is not TensorFlow.\")\n\n        out = kimage.map_coordinates(\n            image_uint8, coordinates, order=1, fill_mode=\"nearest\"\n        )\n        self.assertEqual(out.shape, coordinates.shape[1:])\n\n    def test_map_coordinates_manual_cast(self):\n        image_uint8 = tf.ones((1, 1, 3), dtype=tf.uint8)\n        coordinates = tf.convert_to_tensor([-1.0, 0.0, 0.0])[..., None, None]\n        image_uint8_casted = tf.cast(image_uint8, dtype=tf.float32)\n\n        if backend.backend() != \"tensorflow\":\n            pytest.skip(\"Skipping test because the backend is not TensorFlow.\")\n\n        out = tf.cast(\n            kimage.map_coordinates(\n                image_uint8_casted, coordinates, order=1, fill_mode=\"constant\"\n            ),\n       
     dtype=tf.uint8,\n        )\n        self.assertEqual(out.shape, coordinates.shape[1:])\n\n    def test_pad_images(self):\n        # Test channels_last\n        x = KerasTensor([15, 25, 3])\n        out = kimage.pad_images(x, 2, 3, target_height=20, target_width=30)\n        self.assertEqual(out.shape, (20, 30, 3))\n\n        x_batch = KerasTensor([2, 15, 25, 3])\n        out_batch = kimage.pad_images(\n            x_batch, 2, 3, target_height=20, target_width=30\n        )\n        self.assertEqual(out_batch.shape, (2, 20, 30, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([3, 15, 25])\n        out = kimage.pad_images(x, 2, 3, target_height=20, target_width=30)\n        self.assertEqual(out.shape, (3, 20, 30))\n\n        x_batch = KerasTensor([2, 3, 15, 25])\n        out_batch = kimage.pad_images(\n            x_batch, 2, 3, target_height=20, target_width=30\n        )\n        self.assertEqual(out_batch.shape, (2, 3, 20, 30))\n\n    def test_crop_images(self):\n        # Test channels_last\n        x = KerasTensor([15, 25, 3])\n        out = kimage.crop_images(x, 2, 3, target_height=10, target_width=20)\n        self.assertEqual(out.shape, (10, 20, 3))\n\n        x_batch = KerasTensor([2, 15, 25, 3])\n        out_batch = kimage.crop_images(\n            x_batch, 2, 3, target_height=10, target_width=20\n        )\n        self.assertEqual(out_batch.shape, (2, 10, 20, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([3, 15, 25])\n        out = kimage.crop_images(x, 2, 3, target_height=10, target_width=20)\n        self.assertEqual(out.shape, (3, 10, 20))\n\n        # Test channels_first and batched\n        x_batch = KerasTensor([2, 3, 15, 25])\n        out_batch = kimage.crop_images(\n            x_batch, 2, 3, target_height=10, target_width=20\n        )\n        self.assertEqual(out_batch.shape, (2, 3, 10, 20))\n\n 
   def test_perspective_transform(self):\n        # Test channels_last\n        x = KerasTensor([20, 20, 3])\n        start_points = KerasTensor([4, 2])\n        end_points = KerasTensor([4, 2])\n        out = kimage.perspective_transform(x, start_points, end_points)\n        self.assertEqual(out.shape, (20, 20, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([3, 20, 20])\n        start_points = KerasTensor([4, 2])\n        end_points = KerasTensor([4, 2])\n        out = kimage.perspective_transform(x, start_points, end_points)\n        self.assertEqual(out.shape, (3, 20, 20))\n\n    def test_gaussian_blur(self):\n        # Test channels_last\n        x = KerasTensor([20, 20, 3])\n        kernel_size = KerasTensor(\n            [\n                2,\n            ]\n        )\n        sigma = KerasTensor(\n            [\n                2,\n            ]\n        )\n        out = kimage.gaussian_blur(x, kernel_size, sigma)\n        self.assertEqual(out.shape, (20, 20, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([3, 20, 20])\n        kernel_size = KerasTensor(\n            [\n                2,\n            ]\n        )\n        sigma = KerasTensor(\n            [\n                2,\n            ]\n        )\n        out = kimage.gaussian_blur(x, kernel_size, sigma)\n        self.assertEqual(out.shape, (3, 20, 20))\n\n    def test_elastic_transform(self):\n        # Test channels_last\n        x = KerasTensor([20, 20, 3])\n        out = kimage.elastic_transform(x)\n        self.assertEqual(out.shape, (20, 20, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = KerasTensor([3, 20, 20])\n        out = kimage.elastic_transform(x)\n        self.assertEqual(out.shape, (3, 20, 20))\n\n    def test_scale_and_translate(self):\n        images = KerasTensor([20, 20, 3])\n  
      output_shape = (25, 25, 3)\n        scale = KerasTensor([2])\n        translation = KerasTensor([2])\n        out = kimage.scale_and_translate(\n            images,\n            output_shape=output_shape,\n            scale=scale,\n            translation=translation,\n            spatial_dims=(0, 1),\n            method=\"linear\",\n        )\n        self.assertEqual(out.shape, output_shape)\n\n\nAFFINE_TRANSFORM_INTERPOLATIONS = {  # map to order\n    \"nearest\": 0,\n    \"bilinear\": 1,\n}\n\n\ndef _compute_affine_transform_coordinates(image, transform):\n    image = image.copy()\n    transform = transform.copy()\n    need_squeeze = False\n    if len(image.shape) == 3:  # unbatched\n        need_squeeze = True\n        image = np.expand_dims(image, axis=0)\n        transform = np.expand_dims(transform, axis=0)\n    batch_size = image.shape[0]\n    # get indices\n    meshgrid = np.meshgrid(\n        *[np.arange(size) for size in image.shape[1:]], indexing=\"ij\"\n    )\n    indices = np.concatenate(\n        [np.expand_dims(x, axis=-1) for x in meshgrid], axis=-1\n    )\n    indices = np.tile(indices, (batch_size, 1, 1, 1, 1))\n    # swap the values\n    transform[:, 4], transform[:, 0] = (\n        transform[:, 0].copy(),\n        transform[:, 4].copy(),\n    )\n    transform[:, 5], transform[:, 2] = (\n        transform[:, 2].copy(),\n        transform[:, 5].copy(),\n    )\n    # deal with transform\n    transform = np.pad(transform, pad_width=[[0, 0], [0, 1]], constant_values=1)\n    transform = np.reshape(transform, (batch_size, 3, 3))\n    offset = np.pad(transform[:, 0:2, 2], pad_width=[[0, 0], [0, 1]])\n    transform[:, 0:2, 2] = 0\n    # transform the indices\n    coordinates = np.einsum(\"Bhwij, Bjk -> Bhwik\", indices, transform)\n    coordinates = np.moveaxis(coordinates, source=-1, destination=1)\n    coordinates += np.reshape(offset, (*offset.shape, 1, 1, 1))\n    if need_squeeze:\n        coordinates = np.squeeze(coordinates, axis=0)\n    
return coordinates\n\n\ndef _fixed_map_coordinates(\n    input, coordinates, order, fill_mode=\"constant\", fill_value=0.0\n):\n    # SciPy's implementation of map_coordinates handles boundaries incorrectly,\n    # unless mode='reflect'. For order=1, this only affects interpolation\n    # outside the bounds of the original array.\n    # https://github.com/scipy/scipy/issues/2640\n    padding = [\n        (\n            max(-np.floor(c.min()).astype(int) + 1, 0),\n            max(np.ceil(c.max()).astype(int) + 1 - size, 0),\n        )\n        for c, size in zip(coordinates, input.shape)\n    ]\n    shifted_coords = [c + p[0] for p, c in zip(padding, coordinates)]\n    pad_mode = {\n        \"nearest\": \"edge\",\n        \"mirror\": \"reflect\",\n        \"reflect\": \"symmetric\",\n    }.get(fill_mode, fill_mode)\n    if fill_mode == \"constant\":\n        padded = np.pad(\n            input, padding, mode=pad_mode, constant_values=fill_value\n        )\n    else:\n        padded = np.pad(input, padding, mode=pad_mode)\n    result = scipy.ndimage.map_coordinates(\n        padded, shifted_coords, order=order, mode=fill_mode, cval=fill_value\n    )\n    return result\n\n\ndef _perspective_transform_numpy(\n    images,\n    start_points,\n    end_points,\n    interpolation=\"bilinear\",\n    fill_value=0,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = np.expand_dims(images, axis=0)\n        need_squeeze = True\n\n    if len(start_points.shape) == 2:\n        start_points = np.expand_dims(start_points, axis=0)\n    if len(end_points.shape) == 2:\n        end_points = np.expand_dims(end_points, axis=0)\n\n    if data_format == \"channels_first\":\n        images = np.transpose(images, (0, 2, 3, 1))\n\n    batch_size, height, width, channels = images.shape\n\n    transforms = _compute_homography_matrix(start_points, end_points)\n\n    if 
len(transforms.shape) == 1:\n        transforms = np.expand_dims(transforms, axis=0)\n    if transforms.shape[0] == 1 and batch_size > 1:\n        transforms = np.tile(transforms, (batch_size, 1))\n\n    x, y = np.meshgrid(\n        np.arange(width, dtype=np.float32),\n        np.arange(height, dtype=np.float32),\n        indexing=\"xy\",\n    )\n\n    output = np.empty((batch_size, height, width, channels))\n\n    for i in range(batch_size):\n        a0, a1, a2, a3, a4, a5, a6, a7 = transforms[i]\n        denom = a6 * x + a7 * y + 1.0\n        x_in = (a0 * x + a1 * y + a2) / denom\n        y_in = (a3 * x + a4 * y + a5) / denom\n\n        coords = np.stack([y_in.ravel(), x_in.ravel()], axis=0)\n\n        mapped_channels = []\n        for channel in range(channels):\n            channel_img = images[i, :, :, channel]\n\n            mapped_channel = _fixed_map_coordinates(\n                channel_img,\n                coords,\n                order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                fill_mode=\"constant\",\n                fill_value=fill_value,\n            )\n            mapped_channels.append(mapped_channel.reshape(height, width))\n\n        output[i] = np.stack(mapped_channels, axis=-1)\n\n    if data_format == \"channels_first\":\n        output = np.transpose(output, (0, 3, 1, 2))\n    if need_squeeze:\n        output = np.squeeze(output, axis=0)\n\n    return output\n\n\ndef gaussian_blur_np(\n    images,\n    kernel_size,\n    sigma,\n    data_format=None,\n):\n    def _create_gaussian_kernel(kernel_size, sigma, num_channels, dtype):\n        def _get_gaussian_kernel1d(size, sigma):\n            x = np.arange(size, dtype=dtype) - (size - 1) / 2\n            kernel1d = np.exp(-0.5 * (x / sigma) ** 2)\n            return kernel1d / np.sum(kernel1d)\n\n        def _get_gaussian_kernel2d(size, sigma):\n            kernel1d_x = _get_gaussian_kernel1d(size[0], sigma[0])\n            kernel1d_y = _get_gaussian_kernel1d(size[1], 
sigma[1])\n            return np.outer(kernel1d_y, kernel1d_x)\n\n        kernel = _get_gaussian_kernel2d(kernel_size, sigma)\n        kernel = kernel[:, :, np.newaxis]\n        kernel = np.tile(kernel, (1, 1, num_channels))\n        return kernel.astype(dtype)\n\n    data_format = backend.standardize_data_format(data_format)\n\n    images = np.asarray(images)\n    input_dtype = images.dtype\n    kernel_size = np.asarray(kernel_size)\n\n    if len(images.shape) not in (3, 4):\n        raise ValueError(\n            \"Invalid images rank: expected rank 3 (single image) \"\n            \"or rank 4 (batch of images). Received input with shape: \"\n            f\"images.shape={images.shape}\"\n        )\n\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = np.expand_dims(images, axis=0)\n        need_squeeze = True\n\n    if data_format == \"channels_first\":\n        images = np.transpose(images, (0, 2, 3, 1))\n\n    num_channels = images.shape[-1]\n    kernel = _create_gaussian_kernel(\n        kernel_size, sigma, num_channels, input_dtype\n    )\n    batch_size, height, width, _ = images.shape\n\n    kernel_h, kernel_w = kernel.shape[0], kernel.shape[1]\n    pad_h = (kernel_h - 1) // 2\n    pad_h_after = kernel_h - 1 - pad_h\n    pad_w = (kernel_w - 1) // 2\n    pad_w_after = kernel_w - 1 - pad_w\n\n    padded_images = np.pad(\n        images,\n        (\n            (0, 0),\n            (pad_h, pad_h_after),\n            (pad_w, pad_w_after),\n            (0, 0),\n        ),\n        mode=\"constant\",\n    )\n\n    blurred_images = np.zeros_like(images)\n    kernel_reshaped = kernel.reshape((1, kernel_h, kernel_w, num_channels))\n\n    # Apply the depthwise Gaussian kernel as an explicit sliding window over\n    # each image in the batch.\n    for b in range(batch_size):\n        image_patch = padded_images[b : b + 1, :, :, :]\n        for i in range(height):\n            for j in range(width):\n                patch = image_patch[:, i : i + kernel_h, j : j + kernel_w, :]\n                blurred_images[b, i, j, :] = np.sum(\n                    patch * kernel_reshaped, axis=(1, 2)\n                )\n\n    if data_format == 
\"channels_first\":\n        blurred_images = np.transpose(blurred_images, (0, 3, 1, 2))\n    if need_squeeze:\n        blurred_images = np.squeeze(blurred_images, axis=0)\n\n    return blurred_images\n\n\ndef elastic_transform_np(\n    images,\n    alpha=20.0,\n    sigma=5.0,\n    interpolation=\"bilinear\",\n    fill_mode=\"reflect\",\n    fill_value=0.0,\n    seed=None,\n    data_format=None,\n):\n    data_format = backend.standardize_data_format(data_format)\n\n    images = np.asarray(images)\n    input_dtype = images.dtype\n\n    alpha = np.asarray(alpha, dtype=input_dtype)\n    sigma = np.asarray(sigma, dtype=input_dtype)\n\n    kernel_size = (int(6 * sigma) | 1, int(6 * sigma) | 1)\n\n    need_squeeze = False\n    if len(images.shape) == 3:\n        images = np.expand_dims(images, axis=0)\n        need_squeeze = True\n\n    if data_format == \"channels_last\":\n        batch_size, height, width, channels = images.shape\n        channel_axis = -1\n    else:\n        batch_size, channels, height, width = images.shape\n        channel_axis = 1\n\n    rng = np.random.default_rng([seed, 0])\n    dx = (\n        rng.normal(size=(batch_size, height, width), loc=0.0, scale=1.0).astype(\n            input_dtype\n        )\n        * sigma\n    )\n    dy = (\n        rng.normal(size=(batch_size, height, width), loc=0.0, scale=1.0).astype(\n            input_dtype\n        )\n        * sigma\n    )\n\n    dx = gaussian_blur_np(\n        np.expand_dims(dx, axis=channel_axis),\n        kernel_size=kernel_size,\n        sigma=(sigma, sigma),\n        data_format=data_format,\n    )\n    dy = gaussian_blur_np(\n        np.expand_dims(dy, axis=channel_axis),\n        kernel_size=kernel_size,\n        sigma=(sigma, sigma),\n        data_format=data_format,\n    )\n\n    dx = np.squeeze(dx)\n    dy = np.squeeze(dy)\n\n    x, y = np.meshgrid(np.arange(width), np.arange(height))\n    x, y = x[None, :, :], y[None, :, :]\n\n    distorted_x = x + alpha * dx\n    distorted_y = y + 
alpha * dy\n\n    transformed_images = np.zeros_like(images)\n\n    if data_format == \"channels_last\":\n        for i in range(channels):\n            transformed_images[..., i] = np.stack(\n                [\n                    _fixed_map_coordinates(\n                        images[b, ..., i],\n                        [distorted_y[b], distorted_x[b]],\n                        order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                        fill_mode=fill_mode,\n                        fill_value=fill_value,\n                    )\n                    for b in range(batch_size)\n                ]\n            )\n    else:\n        for i in range(channels):\n            transformed_images[:, i, :, :] = np.stack(\n                [\n                    _fixed_map_coordinates(\n                        images[b, i, ...],\n                        [distorted_y[b], distorted_x[b]],\n                        order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                        fill_mode=fill_mode,\n                        fill_value=fill_value,\n                    )\n                    for b in range(batch_size)\n                ]\n            )\n\n    if need_squeeze:\n        transformed_images = np.squeeze(transformed_images, axis=0)\n    transformed_images = transformed_images.astype(input_dtype)\n\n    return transformed_images\n\n\ndef _compute_homography_matrix(start_points, end_points):\n    start_x1, start_y1 = start_points[:, 0, 0], start_points[:, 0, 1]\n    start_x2, start_y2 = start_points[:, 1, 0], start_points[:, 1, 1]\n    start_x3, start_y3 = start_points[:, 2, 0], start_points[:, 2, 1]\n    start_x4, start_y4 = start_points[:, 3, 0], start_points[:, 3, 1]\n\n    end_x1, end_y1 = end_points[:, 0, 0], end_points[:, 0, 1]\n    end_x2, end_y2 = end_points[:, 1, 0], end_points[:, 1, 1]\n    end_x3, end_y3 = end_points[:, 2, 0], end_points[:, 2, 1]\n    end_x4, end_y4 = end_points[:, 3, 0], end_points[:, 3, 1]\n\n    
coefficient_matrix = np.stack(\n        [\n            np.stack(\n                [\n                    end_x1,\n                    end_y1,\n                    np.ones_like(end_x1),\n                    np.zeros_like(end_x1),\n                    np.zeros_like(end_x1),\n                    np.zeros_like(end_x1),\n                    -start_x1 * end_x1,\n                    -start_x1 * end_y1,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    np.zeros_like(end_x1),\n                    np.zeros_like(end_x1),\n                    np.zeros_like(end_x1),\n                    end_x1,\n                    end_y1,\n                    np.ones_like(end_x1),\n                    -start_y1 * end_x1,\n                    -start_y1 * end_y1,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    end_x2,\n                    end_y2,\n                    np.ones_like(end_x2),\n                    np.zeros_like(end_x2),\n                    np.zeros_like(end_x2),\n                    np.zeros_like(end_x2),\n                    -start_x2 * end_x2,\n                    -start_x2 * end_y2,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    np.zeros_like(end_x2),\n                    np.zeros_like(end_x2),\n                    np.zeros_like(end_x2),\n                    end_x2,\n                    end_y2,\n                    np.ones_like(end_x2),\n                    -start_y2 * end_x2,\n                    -start_y2 * end_y2,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    end_x3,\n                    end_y3,\n                    np.ones_like(end_x3),\n                    np.zeros_like(end_x3),\n                    np.zeros_like(end_x3),\n                    np.zeros_like(end_x3),\n      
              -start_x3 * end_x3,\n                    -start_x3 * end_y3,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    np.zeros_like(end_x3),\n                    np.zeros_like(end_x3),\n                    np.zeros_like(end_x3),\n                    end_x3,\n                    end_y3,\n                    np.ones_like(end_x3),\n                    -start_y3 * end_x3,\n                    -start_y3 * end_y3,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    end_x4,\n                    end_y4,\n                    np.ones_like(end_x4),\n                    np.zeros_like(end_x4),\n                    np.zeros_like(end_x4),\n                    np.zeros_like(end_x4),\n                    -start_x4 * end_x4,\n                    -start_x4 * end_y4,\n                ],\n                axis=-1,\n            ),\n            np.stack(\n                [\n                    np.zeros_like(end_x4),\n                    np.zeros_like(end_x4),\n                    np.zeros_like(end_x4),\n                    end_x4,\n                    end_y4,\n                    np.ones_like(end_x4),\n                    -start_y4 * end_x4,\n                    -start_y4 * end_y4,\n                ],\n                axis=-1,\n            ),\n        ],\n        axis=1,\n    )\n\n    target_vector = np.stack(\n        [\n            start_x1,\n            start_y1,\n            start_x2,\n            start_y2,\n            start_x3,\n            start_y3,\n            start_x4,\n            start_y4,\n        ],\n        axis=-1,\n    )\n    target_vector = np.expand_dims(target_vector, axis=-1)\n\n    homography_matrix = np.linalg.solve(coefficient_matrix, target_vector)\n    homography_matrix = np.reshape(homography_matrix, [-1, 8])\n\n    return homography_matrix\n\n\nclass ImageOpsCorrectnessTest(testing.TestCase):\n    def 
setUp(self):\n        # Defaults to channels_last\n        self.data_format = backend.image_data_format()\n        backend.set_image_data_format(\"channels_last\")\n        return super().setUp()\n\n    def tearDown(self):\n        backend.set_image_data_format(self.data_format)\n        return super().tearDown()\n\n    def test_rgb_to_grayscale(self):\n        # Test channels_last\n        x = np.random.random((50, 50, 3)).astype(\"float32\") * 255\n        out = kimage.rgb_to_grayscale(x)\n        ref_out = tf.image.rgb_to_grayscale(x)\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out.numpy(), out)\n\n        x = np.random.random((2, 50, 50, 3)).astype(\"float32\") * 255\n        out = kimage.rgb_to_grayscale(x)\n        ref_out = tf.image.rgb_to_grayscale(x)\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out.numpy(), out)\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = np.random.random((3, 50, 50)).astype(\"float32\") * 255\n        out = kimage.rgb_to_grayscale(x)\n        ref_out = tf.image.rgb_to_grayscale(np.transpose(x, [1, 2, 0]))\n        ref_out = tf.transpose(ref_out, [2, 0, 1])\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out.numpy(), out)\n\n        x = np.random.random((2, 3, 50, 50)).astype(\"float32\") * 255\n        out = kimage.rgb_to_grayscale(x)\n        ref_out = tf.image.rgb_to_grayscale(np.transpose(x, [0, 2, 3, 1]))\n        ref_out = tf.transpose(ref_out, [0, 3, 1, 2])\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out.numpy(), out)\n\n        # Test class\n        out = kimage.RGBToGrayscale()(x)\n        self.assertAllClose(ref_out.numpy(), out)\n\n    def test_rgb_to_hsv(self):\n        # Test channels_last\n        x = np.random.random((50, 50, 3)).astype(\"float32\")\n    
    out = kimage.rgb_to_hsv(x)\n        ref_out = tf.image.rgb_to_hsv(x)\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out.numpy(), out)\n\n        x = np.random.random((2, 50, 50, 3)).astype(\"float32\")\n        out = kimage.rgb_to_hsv(x)\n        ref_out = tf.image.rgb_to_hsv(x)\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out.numpy(), out)\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = np.random.random((3, 50, 50)).astype(\"float32\")\n        out = kimage.rgb_to_hsv(x)\n        ref_out = tf.image.rgb_to_hsv(np.transpose(x, [1, 2, 0]))\n        ref_out = tf.transpose(ref_out, [2, 0, 1])\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out.numpy(), out)\n\n        x = np.random.random((2, 3, 50, 50)).astype(\"float32\")\n        out = kimage.rgb_to_hsv(x)\n        ref_out = tf.image.rgb_to_hsv(np.transpose(x, [0, 2, 3, 1]))\n        ref_out = tf.transpose(ref_out, [0, 3, 1, 2])\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out.numpy(), out)\n\n        # Test class\n        out = kimage.RGBToHSV()(x)\n        self.assertAllClose(ref_out.numpy(), out)\n\n    def test_hsv_to_rgb(self):\n        # Test channels_last\n        x = np.random.random((50, 50, 3)).astype(\"float32\")\n        out = kimage.hsv_to_rgb(x)\n        ref_out = tf.image.hsv_to_rgb(x)\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out.numpy(), out)\n\n        x = np.random.random((2, 50, 50, 3)).astype(\"float32\")\n        out = kimage.hsv_to_rgb(x)\n        ref_out = tf.image.hsv_to_rgb(x)\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out.numpy(), out)\n\n        # Test channels_first\n        
backend.set_image_data_format(\"channels_first\")\n        x = np.random.random((3, 50, 50)).astype(\"float32\")\n        out = kimage.hsv_to_rgb(x)\n        ref_out = tf.image.hsv_to_rgb(np.transpose(x, [1, 2, 0]))\n        ref_out = tf.transpose(ref_out, [2, 0, 1])\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out.numpy(), out)\n\n        x = np.random.random((2, 3, 50, 50)).astype(\"float32\")\n        out = kimage.hsv_to_rgb(x)\n        ref_out = tf.image.hsv_to_rgb(np.transpose(x, [0, 2, 3, 1]))\n        ref_out = tf.transpose(ref_out, [0, 3, 1, 2])\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out.numpy(), out)\n\n        # Test class\n        out = kimage.HSVToRGB()(x)\n        self.assertAllClose(ref_out.numpy(), out)\n\n    @parameterized.named_parameters(\n        named_product(\n            interpolation=[\n                \"bilinear\",\n                \"nearest\",\n                \"lanczos3\",\n                \"lanczos5\",\n                \"bicubic\",\n            ],\n            antialias=[True, False],\n        )\n    )\n    def test_resize(self, interpolation, antialias):\n        if backend.backend() == \"torch\":\n            if \"lanczos\" in interpolation:\n                self.skipTest(\n                    \"Resizing with Lanczos interpolation is \"\n                    \"not supported by the PyTorch backend. \"\n                    f\"Received: interpolation={interpolation}.\"\n                )\n            if interpolation == \"bicubic\" and antialias is False:\n                self.skipTest(\n                    \"Resizing with Bicubic interpolation in \"\n                    \"PyTorch backend produces noise. Please \"\n                    \"turn on anti-aliasing. 
\"\n                    f\"Received: interpolation={interpolation}, \"\n                    f\"antialias={antialias}.\"\n                )\n        # Test channels_last\n        x = np.random.random((30, 30, 3)).astype(\"float32\") * 255\n        out = kimage.resize(\n            x,\n            size=(15, 15),\n            interpolation=interpolation,\n            antialias=antialias,\n        )\n        ref_out = tf.image.resize(\n            x,\n            size=(15, 15),\n            method=interpolation,\n            antialias=antialias,\n        )\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-4)\n\n        x = np.random.random((2, 30, 30, 3)).astype(\"float32\") * 255\n        out = kimage.resize(\n            x,\n            size=(15, 15),\n            interpolation=interpolation,\n            antialias=antialias,\n        )\n        ref_out = tf.image.resize(\n            x,\n            size=(15, 15),\n            method=interpolation,\n            antialias=antialias,\n        )\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-4)\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = np.random.random((3, 30, 30)).astype(\"float32\") * 255\n        out = kimage.resize(\n            x,\n            size=(15, 15),\n            interpolation=interpolation,\n            antialias=antialias,\n        )\n        ref_out = tf.image.resize(\n            np.transpose(x, [1, 2, 0]),\n            size=(15, 15),\n            method=interpolation,\n            antialias=antialias,\n        )\n        ref_out = tf.transpose(ref_out, [2, 0, 1])\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-4)\n\n        x = np.random.random((2, 3, 30, 30)).astype(\"float32\") * 255\n        out = kimage.resize(\n            x,\n      
      size=(15, 15),\n            interpolation=interpolation,\n            antialias=antialias,\n        )\n        ref_out = tf.image.resize(\n            np.transpose(x, [0, 2, 3, 1]),\n            size=(15, 15),\n            method=interpolation,\n            antialias=antialias,\n        )\n        ref_out = tf.transpose(ref_out, [0, 3, 1, 2])\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-4)\n\n        # Test class\n        out = kimage.Resize(\n            size=(15, 15),\n            interpolation=interpolation,\n            antialias=antialias,\n        )(x)\n        self.assertAllClose(ref_out, out, atol=1e-4)\n\n    def test_resize_uint8_round(self):\n        x = np.array([0, 1, 254, 255], dtype=\"uint8\").reshape(1, 2, 2, 1)\n        expected = np.array(\n            # OpenCV as gold standard.\n            # [\n            #     [0, 0, 1, 1],\n            #     [64, 64, 64, 65],\n            #     [191, 191, 191, 192],\n            #     [254, 254, 255, 255],\n            # ]\n            #\n            # Resize without `round` - differences in 8 points\n            # [\n            #     [0, 0, 0, 1],\n            #     [63, 63, 64, 64],\n            #     [190, 190, 191, 191],\n            #     [254, 254, 254, 255],\n            # ]\n            #\n            # Resize with `round` - differences in 2 points\n            [\n                [0, 0, 1, 1],\n                [64, 64, 64, 64],\n                [190, 191, 191, 192],\n                [254, 254, 255, 255],\n            ],\n            dtype=\"uint8\",\n        ).reshape(1, 4, 4, 1)\n        out = kimage.resize(\n            x,\n            size=(4, 4),\n            interpolation=\"bilinear\",\n            antialias=False,\n        )\n        self.assertEqual(tuple(out.shape), tuple(expected.shape))\n        self.assertEqual(backend.standardize_dtype(out.dtype), \"uint8\")\n        self.assertAllClose(out, expected, 
atol=1e-4)\n\n    def test_resize_uint8_round_saturate(self):\n        x = np.array([0, 1, 254, 255], dtype=\"uint8\").reshape(1, 2, 2, 1)\n        expected = np.array(\n            # OpenCV as gold standard. Same for `torch` backend.\n            (\n                [\n                    [0, 0, 0, 0],\n                    [57, 58, 58, 59],\n                    [196, 197, 197, 198],\n                    [255, 255, 255, 255],\n                ]\n                if \"torch\" == backend.backend()\n                # Resize without `round` and `saturate_cast` - differences in\n                # 16 points\n                # [\n                #     [234, 234, 235, 235],\n                #     [-5, -6, -5, -6],\n                #     [5, 4, 5, 4],\n                #     [-235, -235, -234, -234],\n                # ]\n                #\n                # Resize with `round` and `saturate_cast` - differences in\n                # 8 points\n                else [\n                    [0, 0, 0, 0],\n                    [53, 53, 53, 54],\n                    [201, 202, 202, 202],\n                    [255, 255, 255, 255],\n                ]\n            ),\n            dtype=\"uint8\",\n        ).reshape(1, 4, 4, 1)\n        out = kimage.resize(\n            x,\n            size=(4, 4),\n            interpolation=\"bicubic\",\n            antialias=False,\n        )\n        self.assertEqual(tuple(out.shape), tuple(expected.shape))\n        self.assertEqual(backend.standardize_dtype(out.dtype), \"uint8\")\n        self.assertAllClose(out, expected, atol=1e-4)\n\n    def test_resize_with_crop(self):\n        # Test channels_last\n        x = np.random.random((60, 50, 3)).astype(\"float32\") * 255\n        out = kimage.resize(x, size=(25, 25), crop_to_aspect_ratio=True)\n        self.assertEqual(out.shape, (25, 25, 3))\n\n        x = np.random.random((2, 50, 60, 3)).astype(\"float32\") * 255\n        out = kimage.resize(x, size=(25, 25), crop_to_aspect_ratio=True)\n        
self.assertEqual(out.shape, (2, 25, 25, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = np.random.random((3, 60, 50)).astype(\"float32\") * 255\n        out = kimage.resize(x, size=(25, 25), crop_to_aspect_ratio=True)\n        self.assertEqual(out.shape, (3, 25, 25))\n\n        x = np.random.random((2, 3, 50, 60)).astype(\"float32\") * 255\n        out = kimage.resize(x, size=(25, 25), crop_to_aspect_ratio=True)\n        self.assertEqual(out.shape, (2, 3, 25, 25))\n\n    @parameterized.named_parameters(named_product(fill_value=[1.0, 2.0]))\n    def test_resize_with_pad(self, fill_value):\n        # Test channels_last\n        x = np.random.random((60, 50, 3)).astype(\"float32\") * 255\n        out = kimage.resize(\n            x,\n            size=(25, 25),\n            pad_to_aspect_ratio=True,\n            fill_value=fill_value,\n        )\n        self.assertEqual(out.shape, (25, 25, 3))\n\n        x = np.random.random((2, 50, 60, 3)).astype(\"float32\") * 255\n        out = kimage.resize(\n            x, size=(25, 25), pad_to_aspect_ratio=True, fill_value=fill_value\n        )\n        self.assertEqual(out.shape, (2, 25, 25, 3))\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = np.random.random((3, 60, 50)).astype(\"float32\") * 255\n        out = kimage.resize(\n            x, size=(25, 25), pad_to_aspect_ratio=True, fill_value=fill_value\n        )\n        self.assertEqual(out.shape, (3, 25, 25))\n\n        x = np.random.random((2, 3, 50, 60)).astype(\"float32\") * 255\n        out = kimage.resize(\n            x, size=(25, 25), pad_to_aspect_ratio=True, fill_value=fill_value\n        )\n        self.assertEqual(out.shape, (2, 3, 25, 25))\n\n        x = np.ones((2, 3, 10, 10)) * 128\n        out = kimage.resize(\n            x, size=(4, 4), pad_to_aspect_ratio=True, fill_value=fill_value\n        )\n        self.assertEqual(out.shape, (2, 3, 4, 
4))\n        self.assertAllClose(out[:, 0, :, :], np.ones((2, 4, 4)) * 128)\n\n        x = np.ones((2, 3, 10, 8)) * 128\n        out = kimage.resize(\n            x, size=(4, 4), pad_to_aspect_ratio=True, fill_value=fill_value\n        )\n        self.assertEqual(out.shape, (2, 3, 4, 4))\n        self.assertAllClose(\n            out,\n            np.concatenate(\n                [\n                    np.ones((2, 3, 4, 1)) * 96.25,\n                    np.ones((2, 3, 4, 2)) * 128.0,\n                    np.ones((2, 3, 4, 1)) * 96.25,\n                ],\n                axis=3,\n            ),\n            atol=1.0,\n        )\n\n    @parameterized.named_parameters(\n        (\"zero_height\", (0, 10)),\n        (\"zero_width\", (10, 0)),\n        (\"zero_both\", (0, 0)),\n        (\"negative_height\", (-1, 10)),\n        (\"negative_width\", (10, -1)),\n    )\n    def test_resize_invalid_size_zero_or_negative(self, invalid_size):\n        \"\"\"Resize rejects zero or negative height/width.\"\"\"\n        x = np.random.random((10, 10, 3)).astype(\"float32\")\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`size` must have positive height and width\",\n        ):\n            kimage.resize(x, size=invalid_size)\n\n    @parameterized.named_parameters(\n        named_product(\n            interpolation=[\"bilinear\", \"nearest\"],\n            fill_mode=[\"constant\", \"nearest\", \"wrap\", \"mirror\", \"reflect\"],\n        )\n    )\n    def test_affine_transform(self, interpolation, fill_mode):\n        if backend.backend() == \"tensorflow\" and fill_mode == \"mirror\":\n            self.skipTest(\n                \"In tensorflow backend, applying affine_transform with \"\n                \"fill_mode=mirror is not supported\"\n            )\n        if backend.backend() == \"tensorflow\" and fill_mode == \"wrap\":\n            self.skipTest(\n                \"In tensorflow backend, the numerical results of applying \"\n                
\"affine_transform with fill_mode=wrap are inconsistent with \"\n                \"scipy\"\n            )\n        # TODO: `nearest` interpolation in jax and torch causes random index\n        # shifting, resulting in significant differences in output, which\n        # leads to test failures\n        if backend.backend() in (\"jax\", \"torch\") and interpolation == \"nearest\":\n            self.skipTest(\n                f\"In {backend.backend()} backend, \"\n                f\"interpolation={interpolation} causes index shifting and \"\n                \"leads to test failures\"\n            )\n\n        # Test channels_last\n        np.random.seed(42)\n        x = np.random.uniform(size=(50, 50, 3)).astype(\"float32\") * 255\n        transform = np.random.uniform(size=(6)).astype(\"float32\")\n        transform = np.pad(transform, (0, 2))  # makes c0, c1 always 0\n        out = kimage.affine_transform(\n            x, transform, interpolation=interpolation, fill_mode=fill_mode\n        )\n        coordinates = _compute_affine_transform_coordinates(x, transform)\n        ref_out = _fixed_map_coordinates(\n            x,\n            coordinates,\n            order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n            fill_mode=fill_mode,\n        )\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-2, tpu_atol=10, tpu_rtol=10)\n\n        x = np.random.uniform(size=(2, 50, 50, 3)).astype(\"float32\") * 255\n        transform = np.random.uniform(size=(2, 6)).astype(\"float32\")\n        transform = np.pad(transform, [(0, 0), (0, 2)])  # makes c0, c1 always 0\n        out = kimage.affine_transform(\n            x,\n            transform,\n            interpolation=interpolation,\n            fill_mode=fill_mode,\n        )\n        coordinates = _compute_affine_transform_coordinates(x, transform)\n        ref_out = np.stack(\n            [\n                _fixed_map_coordinates(\n                    x[i],\n 
                   coordinates[i],\n                    order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                    fill_mode=fill_mode,\n                )\n                for i in range(x.shape[0])\n            ],\n            axis=0,\n        )\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-2, tpu_atol=10, tpu_rtol=10)\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = np.random.uniform(size=(3, 50, 50)).astype(\"float32\") * 255\n        transform = np.random.uniform(size=(6)).astype(\"float32\")\n        transform = np.pad(transform, (0, 2))  # makes c0, c1 always 0\n        out = kimage.affine_transform(\n            x, transform, interpolation=interpolation, fill_mode=fill_mode\n        )\n        coordinates = _compute_affine_transform_coordinates(\n            np.transpose(x, [1, 2, 0]), transform\n        )\n        ref_out = _fixed_map_coordinates(\n            np.transpose(x, [1, 2, 0]),\n            coordinates,\n            order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n            fill_mode=fill_mode,\n        )\n        ref_out = np.transpose(ref_out, [2, 0, 1])\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-2, tpu_atol=1, tpu_rtol=1)\n\n        x = np.random.uniform(size=(2, 3, 50, 50)).astype(\"float32\") * 255\n        transform = np.random.uniform(size=(2, 6)).astype(\"float32\")\n        transform = np.pad(transform, [(0, 0), (0, 2)])  # makes c0, c1 always 0\n        out = kimage.affine_transform(\n            x,\n            transform,\n            interpolation=interpolation,\n            fill_mode=fill_mode,\n        )\n        coordinates = _compute_affine_transform_coordinates(\n            np.transpose(x, [0, 2, 3, 1]), transform\n        )\n        ref_out = np.stack(\n            [\n                
_fixed_map_coordinates(\n                    np.transpose(x[i], [1, 2, 0]),\n                    coordinates[i],\n                    order=AFFINE_TRANSFORM_INTERPOLATIONS[interpolation],\n                    fill_mode=fill_mode,\n                )\n                for i in range(x.shape[0])\n            ],\n            axis=0,\n        )\n        ref_out = np.transpose(ref_out, [0, 3, 1, 2])\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-2, tpu_atol=10, tpu_rtol=10)\n\n        # Test class\n        out = kimage.AffineTransform(\n            interpolation=interpolation, fill_mode=fill_mode\n        )(x, transform)\n        self.assertAllClose(ref_out, out, atol=1e-2, tpu_atol=10, tpu_rtol=10)\n\n    @parameterized.named_parameters(\n        named_product(\n            size=[(3, 3), (5, 5)],\n            strides=[None, (1, 1), (2, 2)],\n            dilation_rate=[1, 3],\n            padding=[\"valid\", \"same\"],\n        )\n    )\n    def test_extract_patches(self, size, strides, dilation_rate, padding):\n        patch_h, patch_w = size[0], size[1]\n        if strides is None:\n            strides_h, strides_w = patch_h, patch_w\n        else:\n            strides_h, strides_w = strides[0], strides[1]\n        # Note the parentheses: without them, `and` binds tighter than `or`\n        # and the skip condition no longer matches the message below.\n        if (\n            backend.backend() == \"tensorflow\"\n            and (strides_h > 1 or strides_w > 1)\n            and dilation_rate > 1\n        ):\n            pytest.skip(\"dilation_rate>1 with strides>1 not supported with TF\")\n\n        # Test channels_last\n        image = np.random.uniform(size=(1, 20, 20, 3)).astype(\"float32\")\n        patches_out = kimage.extract_patches(\n            image,\n            size=size,\n            strides=strides,\n            dilation_rate=dilation_rate,\n            padding=padding,\n        )\n        patches_ref = tf.image.extract_patches(\n            image,\n            sizes=(1, patch_h, patch_w, 1),\n            strides=(1, 
strides_h, strides_w, 1),\n            rates=(1, dilation_rate, dilation_rate, 1),\n            padding=padding.upper(),\n        )\n        self.assertEqual(tuple(patches_out.shape), tuple(patches_ref.shape))\n        self.assertAllClose(patches_ref, patches_out, atol=1e-2)\n\n        # Test channels_first\n        if backend.backend() == \"tensorflow\":\n            # tensorflow doesn't support channels_first in\n            # `kimage.extract_patches`\n            return\n        backend.set_image_data_format(\"channels_first\")\n        image = np.random.uniform(size=(1, 3, 20, 20)).astype(\"float32\")\n        patches_out = kimage.extract_patches(\n            image,\n            size=size,\n            strides=strides,\n            dilation_rate=dilation_rate,\n            padding=padding,\n        )\n        patches_ref = tf.image.extract_patches(\n            np.transpose(image, [0, 2, 3, 1]),\n            sizes=(1, patch_h, patch_w, 1),\n            strides=(1, strides_h, strides_w, 1),\n            rates=(1, dilation_rate, dilation_rate, 1),\n            padding=padding.upper(),\n        )\n        patches_ref = tf.transpose(patches_ref, [0, 3, 1, 2])\n        self.assertEqual(tuple(patches_out.shape), tuple(patches_ref.shape))\n        self.assertAllClose(patches_ref, patches_out, atol=1e-2)\n\n        # Test class\n        patches_out = kimage.ExtractPatches(\n            size=size,\n            strides=strides,\n            dilation_rate=dilation_rate,\n            padding=padding,\n        )(image)\n        self.assertAllClose(patches_ref, patches_out, atol=1e-2)\n\n    @parameterized.named_parameters(\n        named_product(\n            # (input_shape, coordinates_shape)\n            shape=[((5,), (7,)), ((3, 4, 5), (2, 3, 4))],\n            # TODO: scipy.ndimage.map_coordinates does not support float16\n            # TODO: torch cpu does not support round & floor for float16\n            dtype=[\"uint8\", \"int32\", \"float32\"],\n            
order=[0, 1],\n            fill_mode=[\"constant\", \"nearest\", \"wrap\", \"mirror\", \"reflect\"],\n        )\n    )\n    def test_map_coordinates(self, shape, dtype, order, fill_mode):\n        input_shape, coordinates_shape = shape\n        input = np.arange(math.prod(input_shape), dtype=dtype).reshape(\n            input_shape\n        )\n        coordinates_dtype = \"float32\" if \"int\" in dtype else dtype\n        coordinates = [\n            (size - 1)\n            * np.random.uniform(size=coordinates_shape).astype(\n                coordinates_dtype\n            )\n            for size in input_shape\n        ]\n        output = kimage.map_coordinates(input, coordinates, order, fill_mode)\n        expected = _fixed_map_coordinates(input, coordinates, order, fill_mode)\n        self.assertAllClose(output, expected)\n\n        # Test class\n        output = kimage.MapCoordinates(order, fill_mode)(input, coordinates)\n        self.assertAllClose(output, expected)\n\n    @parameterized.parameters(\n        [\n            (0, 0, 3, 3, None, None),\n            (1, 0, 4, 3, None, None),\n            (0, 1, 3, 4, None, None),\n            (0, 0, 4, 3, None, None),\n            (0, 0, 3, 4, None, None),\n            (0, 0, None, None, 0, 1),\n            (0, 0, None, None, 1, 0),\n            (1, 2, None, None, 3, 4),\n        ]\n    )\n    def test_pad_images(\n        self,\n        top_padding,\n        left_padding,\n        target_height,\n        target_width,\n        bottom_padding,\n        right_padding,\n    ):\n        # Test channels_last\n        image = np.random.uniform(size=(3, 3, 1)).astype(\"float32\")\n        _target_height = target_height  # For `tf.image.pad_to_bounding_box`\n        _target_width = target_width  # For `tf.image.pad_to_bounding_box`\n        if _target_height is None:\n            _target_height = image.shape[0] + top_padding + bottom_padding\n        if _target_width is None:\n            _target_width = image.shape[1] + 
left_padding + right_padding\n        padded_image = kimage.pad_images(\n            image,\n            top_padding,\n            left_padding,\n            bottom_padding,\n            right_padding,\n            target_height,\n            target_width,\n        )\n        ref_padded_image = tf.image.pad_to_bounding_box(\n            image, top_padding, left_padding, _target_height, _target_width\n        )\n        self.assertEqual(\n            tuple(padded_image.shape), tuple(ref_padded_image.shape)\n        )\n        self.assertAllClose(ref_padded_image, padded_image)\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        image = np.random.uniform(size=(1, 3, 3)).astype(\"float32\")\n        padded_image = kimage.pad_images(\n            image,\n            top_padding,\n            left_padding,\n            bottom_padding,\n            right_padding,\n            target_height,\n            target_width,\n        )\n        ref_padded_image = tf.image.pad_to_bounding_box(\n            np.transpose(image, [1, 2, 0]),\n            top_padding,\n            left_padding,\n            _target_height,\n            _target_width,\n        )\n        ref_padded_image = tf.transpose(ref_padded_image, [2, 0, 1])\n        self.assertEqual(\n            tuple(padded_image.shape), tuple(ref_padded_image.shape)\n        )\n        self.assertAllClose(ref_padded_image, padded_image)\n\n        # Test class\n        padded_image = kimage.PadImages(\n            top_padding,\n            left_padding,\n            bottom_padding,\n            right_padding,\n            target_height,\n            target_width,\n        )(image)\n        self.assertAllClose(ref_padded_image, padded_image)\n\n    @parameterized.parameters(\n        [\n            (0, 0, 3, 3, None, None),\n            (1, 0, 4, 3, None, None),\n            (0, 1, 3, 4, None, None),\n            (0, 0, 4, 3, None, None),\n            (0, 0, 3, 4, None, None),\n 
           (0, 0, None, None, 0, 1),\n            (0, 0, None, None, 1, 0),\n            (1, 2, None, None, 3, 4),\n        ]\n    )\n    def test_crop_images(\n        self,\n        top_cropping,\n        left_cropping,\n        target_height,\n        target_width,\n        bottom_cropping,\n        right_cropping,\n    ):\n        # Test channels_last\n        image = np.random.uniform(size=(10, 10, 1)).astype(\"float32\")\n        _target_height = target_height  # For `tf.image.crop_to_bounding_box`\n        _target_width = target_width  # For `tf.image.crop_to_bounding_box`\n        if _target_height is None:\n            _target_height = image.shape[0] - top_cropping - bottom_cropping\n        if _target_width is None:\n            _target_width = image.shape[1] - left_cropping - right_cropping\n        cropped_image = kimage.crop_images(\n            image,\n            top_cropping,\n            left_cropping,\n            bottom_cropping,\n            right_cropping,\n            target_height,\n            target_width,\n        )\n        ref_cropped_image = tf.image.crop_to_bounding_box(\n            image, top_cropping, left_cropping, _target_height, _target_width\n        )\n        self.assertEqual(\n            tuple(cropped_image.shape), tuple(ref_cropped_image.shape)\n        )\n        self.assertAllClose(ref_cropped_image, cropped_image)\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        image = np.random.uniform(size=(1, 10, 10)).astype(\"float32\")\n        cropped_image = kimage.crop_images(\n            image,\n            top_cropping,\n            left_cropping,\n            bottom_cropping,\n            right_cropping,\n            target_height,\n            target_width,\n        )\n        ref_cropped_image = tf.image.crop_to_bounding_box(\n            np.transpose(image, [1, 2, 0]),\n            top_cropping,\n            left_cropping,\n            _target_height,\n            
_target_width,\n        )\n        ref_cropped_image = tf.transpose(ref_cropped_image, [2, 0, 1])\n        self.assertEqual(\n            tuple(cropped_image.shape), tuple(ref_cropped_image.shape)\n        )\n        self.assertAllClose(ref_cropped_image, cropped_image)\n\n        # Test class\n        cropped_image = kimage.CropImages(\n            top_cropping,\n            left_cropping,\n            bottom_cropping,\n            right_cropping,\n            target_height,\n            target_width,\n        )(image)\n        self.assertAllClose(ref_cropped_image, cropped_image)\n\n    @parameterized.named_parameters(\n        named_product(\n            interpolation=[\"bilinear\", \"nearest\"],\n        )\n    )\n    def test_perspective_transform(self, interpolation):\n        # Test channels_last\n        np.random.seed(42)\n        x = np.random.uniform(size=(50, 50, 3)).astype(\"float32\")\n        start_points = np.random.uniform(size=(1, 4, 2)).astype(\"float32\")\n        end_points = np.random.uniform(size=(1, 4, 2)).astype(\"float32\")\n\n        out = kimage.perspective_transform(\n            x, start_points, end_points, interpolation=interpolation\n        )\n\n        ref_out = _perspective_transform_numpy(\n            x, start_points, end_points, interpolation=interpolation\n        )\n\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-2, rtol=1e-2)\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = np.random.uniform(size=(3, 50, 50)).astype(\"float32\")\n        start_points = np.random.uniform(size=(1, 4, 2)).astype(\"float32\")\n        end_points = np.random.uniform(size=(1, 4, 2)).astype(\"float32\")\n\n        out = kimage.perspective_transform(\n            x, start_points, end_points, interpolation=interpolation\n        )\n\n        ref_out = _perspective_transform_numpy(\n            x,\n            start_points,\n   
         end_points,\n            interpolation=interpolation,\n            data_format=\"channels_first\",\n        )\n\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-2, rtol=1e-2)\n\n    def test_gaussian_blur(self):\n        # Test channels_last\n        backend.set_image_data_format(\"channels_last\")\n        np.random.seed(42)\n        x = np.random.uniform(size=(50, 50, 3)).astype(\"float32\")\n        kernel_size = np.array([3, 3])\n        sigma = np.random.uniform(size=(2,)).astype(\"float32\")\n\n        out = kimage.gaussian_blur(\n            x,\n            kernel_size,\n            sigma,\n            data_format=\"channels_last\",\n        )\n\n        ref_out = gaussian_blur_np(\n            x,\n            kernel_size,\n            sigma,\n            data_format=\"channels_last\",\n        )\n\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-2, rtol=1e-2)\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = np.random.uniform(size=(3, 50, 50)).astype(\"float32\")\n        kernel_size = np.array([3, 3])\n        sigma = np.random.uniform(size=(2,)).astype(\"float32\")\n\n        out = kimage.gaussian_blur(\n            x,\n            kernel_size,\n            sigma,\n            data_format=\"channels_first\",\n        )\n\n        ref_out = gaussian_blur_np(\n            x,\n            kernel_size,\n            sigma,\n            data_format=\"channels_first\",\n        )\n\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-2, rtol=1e-2)\n\n    def test_gaussian_blur_even_kernel_size(self):\n        \"\"\"Test gaussian_blur with even kernel sizes\"\"\"\n        # This test is specific to the numpy backend fix\n        if backend.backend() != \"numpy\":\n            self.skipTest(\n             
   \"Test is specific to numpy backend, current backend: \"\n                f\"{backend.backend()}\"\n            )\n\n        backend.set_image_data_format(\"channels_last\")\n        np.random.seed(42)\n        x = np.random.uniform(size=(32, 32, 3)).astype(\"float32\")\n        kernel_size = np.array([4, 4])  # Even kernel size\n        sigma = np.array([0.8, 1.2]).astype(\"float32\")\n\n        out = kimage.gaussian_blur(\n            x,\n            kernel_size,\n            sigma,\n            data_format=\"channels_last\",\n        )\n\n        ref_out = gaussian_blur_np(\n            x,\n            kernel_size,\n            sigma,\n            data_format=\"channels_last\",\n        )\n\n        self.assertEqual(tuple(out.shape), (32, 32, 3))\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-2, rtol=1e-2)\n\n        # Test channels_first with different even kernel sizes\n        backend.set_image_data_format(\"channels_first\")\n        x = np.random.uniform(size=(3, 32, 32)).astype(\"float32\")\n        kernel_size = np.array([6, 4])  # Different even kernel sizes\n        sigma = np.array([1.0, 1.5]).astype(\"float32\")\n\n        out = kimage.gaussian_blur(\n            x,\n            kernel_size,\n            sigma,\n            data_format=\"channels_first\",\n        )\n\n        ref_out = gaussian_blur_np(\n            x,\n            kernel_size,\n            sigma,\n            data_format=\"channels_first\",\n        )\n\n        self.assertEqual(tuple(out.shape), (3, 32, 32))\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-2, rtol=1e-2)\n\n    def test_elastic_transform(self):\n        # Test channels_last\n        backend.set_image_data_format(\"channels_last\")\n        np.random.seed(42)\n        x = np.random.uniform(size=(50, 50, 3)).astype(\"float32\")\n        alpha, sigma, seed = 20.0, 5.0, 42\n\n      
  out = kimage.elastic_transform(\n            x,\n            alpha=alpha,\n            sigma=sigma,\n            seed=seed,\n            data_format=\"channels_last\",\n        )\n\n        ref_out = elastic_transform_np(\n            x,\n            alpha=alpha,\n            sigma=sigma,\n            seed=seed,\n            data_format=\"channels_last\",\n        )\n\n        out = backend.convert_to_numpy(out)\n\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(\n            np.mean(ref_out), np.mean(out), atol=1e-2, rtol=1e-2\n        )\n        self.assertAllClose(np.var(ref_out), np.var(out), atol=1e-2, rtol=1e-2)\n\n        # Test channels_first\n        backend.set_image_data_format(\"channels_first\")\n        x = np.random.uniform(size=(3, 50, 50)).astype(\"float32\")\n        alpha, sigma, seed = 20.0, 5.0, 42\n\n        ref_out = elastic_transform_np(\n            x,\n            alpha=alpha,\n            sigma=sigma,\n            seed=seed,\n            data_format=\"channels_first\",\n        )\n\n        out = kimage.elastic_transform(\n            x,\n            alpha=alpha,\n            sigma=sigma,\n            seed=seed,\n            data_format=\"channels_first\",\n        )\n        out = backend.convert_to_numpy(out)\n\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(\n            np.mean(ref_out), np.mean(out), atol=1e-2, rtol=1e-2\n        )\n        self.assertAllClose(np.var(ref_out), np.var(out), atol=1e-2, rtol=1e-2)\n\n    def test_map_coordinates_constant_padding(self):\n        input_img = tf.ones((2, 2), dtype=tf.uint8)\n        # one pixel outside of the input space around the edges\n        grid = tf.stack(\n            tf.meshgrid(\n                tf.range(-1, 3, dtype=tf.float32),\n                tf.range(-1, 3, dtype=tf.float32),\n                indexing=\"ij\",\n            ),\n            axis=0,\n        )\n        out = 
backend.convert_to_numpy(\n            kimage.map_coordinates(\n                input_img, grid, order=0, fill_mode=\"constant\", fill_value=0\n            )\n        )\n\n        # check for ones in the middle and zeros around the edges\n        self.assertTrue(np.all(out[:1] == 0))\n        self.assertTrue(np.all(out[-1:] == 0))\n        self.assertTrue(np.all(out[:, :1] == 0))\n        self.assertTrue(np.all(out[:, -1:] == 0))\n        self.assertTrue(np.all(out[1:3, 1:3] == 1))\n\n    @parameterized.named_parameters(\n        named_product(\n            method=[\"linear\", \"cubic\", \"lanczos3\", \"lanczos5\"],\n            antialias=[True, False],\n        )\n    )\n    def test_scale_and_translate(self, method, antialias):\n        images = np.random.random((30, 30, 3)).astype(\"float32\") * 255\n        scale = np.array([2.0, 2.0]).astype(\"float32\")\n        translation = -(scale / 2.0 - 0.5)\n        out = kimage.scale_and_translate(\n            images,\n            output_shape=(15, 15, 3),\n            scale=scale,\n            translation=translation,\n            spatial_dims=(0, 1),\n            method=method,\n            antialias=antialias,\n        )\n        ref_out = jax.image.scale_and_translate(\n            images,\n            shape=(15, 15, 3),\n            spatial_dims=(0, 1),\n            scale=scale,\n            translation=translation,\n            method=method,\n            antialias=antialias,\n        )\n        self.assertEqual(tuple(out.shape), tuple(ref_out.shape))\n        self.assertAllClose(ref_out, out, atol=1e-4)\n\n\nclass ImageOpsDtypeTest(testing.TestCase):\n    \"\"\"Test the dtype to verify that the behavior matches JAX.\"\"\"\n\n    INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in (\"uint64\", \"int64\")]\n    FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in (\"float64\",)]\n\n    if backend.backend() == \"torch\":\n        INT_DTYPES = [x for x in INT_DTYPES if x not in (\"uint16\", \"uint32\")]\n\n  
  def setUp(self):\n        # Defaults to channels_last\n        self.data_format = backend.image_data_format()\n        backend.set_image_data_format(\"channels_last\")\n        return super().setUp()\n\n    def tearDown(self):\n        backend.set_image_data_format(self.data_format)\n        return super().tearDown()\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_affine_transform(self, dtype):\n        images = knp.ones((50, 50, 3), dtype=dtype)\n        transform = knp.ones((8,), dtype=dtype)\n        expected_dtype = dtype\n\n        self.assertDType(\n            kimage.affine_transform(images, transform), expected_dtype\n        )\n        self.assertDType(\n            kimage.AffineTransform().symbolic_call(images, transform),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_crop_images(self, dtype):\n        images = knp.ones((10, 10, 3), dtype=dtype)\n        expected_dtype = dtype\n\n        self.assertDType(kimage.crop_images(images, 0, 0, 3, 3), expected_dtype)\n        self.assertDType(\n            kimage.CropImages(0, 0, 3, 3).symbolic_call(images), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_elastic_transform(self, dtype):\n        images = knp.ones((10, 10, 3), dtype=dtype)\n        expected_dtype = dtype\n\n        self.assertDType(kimage.elastic_transform(images), expected_dtype)\n        self.assertDType(\n            kimage.ElasticTransform().symbolic_call(images), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_extract_patches(self, dtype):\n        images = knp.ones((10, 10, 3), dtype=dtype)\n        expected_dtype = dtype\n\n        self.assertDType(kimage.extract_patches(images, (3, 3)), expected_dtype)\n        self.assertDType(\n            kimage.ExtractPatches((3, 
3)).symbolic_call(images), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_gaussian_blur(self, dtype):\n        images = knp.ones((10, 10, 3), dtype=dtype)\n        expected_dtype = dtype\n\n        self.assertDType(kimage.gaussian_blur(images), expected_dtype)\n        self.assertDType(\n            kimage.GaussianBlur().symbolic_call(images), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_hsv_to_rgb(self, dtype):\n        images = knp.ones((10, 10, 3), dtype=dtype)\n        expected_dtype = dtype\n\n        self.assertDType(kimage.hsv_to_rgb(images), expected_dtype)\n        self.assertDType(\n            kimage.HSVToRGB().symbolic_call(images), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_map_coordinates(self, dtype):\n        inputs = knp.ones((3, 4, 5), dtype=dtype)\n        coordinates = knp.stack([knp.ones((2, 3, 4), dtype=dtype)] * 3)\n        expected_dtype = dtype\n\n        self.assertDType(\n            kimage.map_coordinates(inputs, coordinates, 0), expected_dtype\n        )\n        self.assertDType(\n            kimage.MapCoordinates(0).symbolic_call(inputs, coordinates),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_pad_images(self, dtype):\n        images = knp.ones((10, 10, 3), dtype=dtype)\n        expected_dtype = dtype\n\n        self.assertDType(kimage.pad_images(images, 0, 0, 3, 3), expected_dtype)\n        self.assertDType(\n            kimage.PadImages(0, 0, 3, 3).symbolic_call(images), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_perspective_transform(self, dtype):\n        images = knp.ones((10, 10, 3), dtype=dtype)\n        start_points = krandom.uniform((1, 4, 2), dtype=dtype)\n        
end_points = krandom.uniform((1, 4, 2), dtype=dtype)\n        expected_dtype = dtype\n\n        self.assertDType(\n            kimage.perspective_transform(images, start_points, end_points),\n            expected_dtype,\n        )\n        self.assertDType(\n            kimage.PerspectiveTransform().symbolic_call(\n                images, start_points, end_points\n            ),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_resize(self, dtype):\n        images = knp.ones((10, 10, 3), dtype=dtype)\n        expected_dtype = dtype\n\n        self.assertDType(kimage.resize(images, (5, 5)), expected_dtype)\n        self.assertDType(\n            kimage.Resize((5, 5)).symbolic_call(images), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_rgb_to_grayscale(self, dtype):\n        images = knp.ones((10, 10, 3), dtype=dtype)\n        expected_dtype = dtype\n\n        self.assertDType(kimage.rgb_to_grayscale(images), expected_dtype)\n        self.assertDType(\n            kimage.RGBToGrayscale().symbolic_call(images), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_rgb_to_hsv(self, dtype):\n        images = knp.ones((10, 10, 3), dtype=dtype)\n        expected_dtype = dtype\n\n        self.assertDType(kimage.rgb_to_hsv(images), expected_dtype)\n        self.assertDType(\n            kimage.RGBToHSV().symbolic_call(images), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_scale_and_translate(self, dtype):\n        images = knp.ones((10, 10, 3), dtype=dtype)\n        scale = knp.ones((2,), dtype=dtype)\n        translation = knp.ones((2,), dtype=dtype)\n        expected_dtype = dtype\n\n        self.assertDType(\n            kimage.scale_and_translate(\n                images,\n                
output_shape=(15, 15, 3),\n                scale=scale,\n                translation=translation,\n                spatial_dims=(0, 1),\n                method=\"linear\",\n            ),\n            expected_dtype,\n        )\n        self.assertDType(\n            kimage.ScaleAndTranslate(\n                spatial_dims=(0, 1), method=\"linear\"\n            ).symbolic_call(images, (15, 15, 3), scale, translation),\n            expected_dtype,\n        )\n\n\nclass ImageOpsBehaviorTests(testing.TestCase):\n    def setUp(self):\n        # Defaults to channels_last\n        self.data_format = backend.image_data_format()\n        backend.set_image_data_format(\"channels_last\")\n        return super().setUp()\n\n    def tearDown(self):\n        backend.set_image_data_format(self.data_format)\n        return super().tearDown()\n\n    @parameterized.named_parameters(named_product(rank=[2, 5]))\n    def test_rgb_to_grayscale_invalid_rank(self, rank):\n        shape = [3] * rank\n        invalid_image = np.random.uniform(size=shape)\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid images rank: expected rank 3\",\n        ):\n            kimage.rgb_to_grayscale(invalid_image)\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid images rank: expected rank 3\",\n        ):\n            kimage.RGBToGrayscale()(invalid_image)\n        invalid_image = KerasTensor(shape=shape)\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Invalid images rank: expected rank 3\",\n        ):\n            kimage.rgb_to_grayscale(invalid_image)\n\n    @parameterized.named_parameters(named_product(rank=[2, 5]))\n    def test_rgb_to_hsv_invalid_rank(self, rank):\n        shape = [3] * rank\n        invalid_image = np.random.uniform(size=shape)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            
kimage.rgb_to_hsv(invalid_image)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.RGBToHSV()(invalid_image)\n        invalid_image = KerasTensor(shape=shape)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.rgb_to_hsv(invalid_image)\n\n    def test_rgb_to_hsv_invalid_dtype(self):\n        invalid_image = np.random.uniform(size=(10, 10, 3)).astype(\"int32\")\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images dtype: expected float dtype.\"\n        ):\n            kimage.rgb_to_hsv(invalid_image)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images dtype: expected float dtype.\"\n        ):\n            kimage.RGBToHSV()(invalid_image)\n        invalid_image = KerasTensor(shape=(10, 10, 3), dtype=\"int32\")\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images dtype: expected float dtype.\"\n        ):\n            kimage.rgb_to_hsv(invalid_image)\n\n    @parameterized.named_parameters(named_product(rank=[2, 5]))\n    def test_hsv_to_rgb_invalid_rank(self, rank):\n        shape = [3] * rank\n        invalid_image = np.random.uniform(size=shape)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.hsv_to_rgb(invalid_image)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.HSVToRGB()(invalid_image)\n        invalid_image = KerasTensor(shape=shape)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.hsv_to_rgb(invalid_image)\n\n    def test_hsv_to_rgb_invalid_dtype(self):\n        invalid_image = np.random.uniform(size=(10, 10, 3)).astype(\"int32\")\n        with 
self.assertRaisesRegex(\n            ValueError, \"Invalid images dtype: expected float dtype.\"\n        ):\n            kimage.hsv_to_rgb(invalid_image)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images dtype: expected float dtype.\"\n        ):\n            kimage.HSVToRGB()(invalid_image)\n        invalid_image = KerasTensor(shape=(10, 10, 3), dtype=\"int32\")\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images dtype: expected float dtype.\"\n        ):\n            kimage.hsv_to_rgb(invalid_image)\n\n    def test_resize_invalid_rank(self):\n        # Test rank=2\n        invalid_image = np.random.uniform(size=(10, 10))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.resize(invalid_image, (5, 5))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.Resize((5, 5))(invalid_image)\n\n        # Test rank=2, symbolic tensor\n        invalid_image = KerasTensor(shape=(10, 10))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.resize(invalid_image, (5, 5))\n\n    def test_affine_transform_invalid_images_rank(self):\n        # Test rank=2\n        invalid_image = np.random.uniform(size=(10, 10))\n        transform = np.random.uniform(size=(6,))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.affine_transform(invalid_image, transform)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.AffineTransform()(invalid_image, transform)\n\n        # Test rank=5\n        invalid_image = np.random.uniform(size=(2, 10, 10, 3, 1))\n        transform = np.random.uniform(size=(6,))\n        with 
self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.affine_transform(invalid_image, transform)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.AffineTransform()(invalid_image, transform)\n\n        # Test rank=2, symbolic tensor\n        invalid_image = KerasTensor(shape=(10, 10))\n        transform = KerasTensor(shape=(6,))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.affine_transform(invalid_image, transform)\n\n    def test_affine_transform_invalid_transform_rank(self):\n        # Test rank=3\n        images = np.random.uniform(size=(10, 10, 3))\n        invalid_transform = np.random.uniform(size=(2, 3, 2))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid transform rank: expected rank 1\"\n        ):\n            kimage.affine_transform(images, invalid_transform)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid transform rank: expected rank 1\"\n        ):\n            kimage.AffineTransform()(images, invalid_transform)\n\n        # Test rank=0\n        invalid_transform = np.random.uniform(size=())\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid transform rank: expected rank 1\"\n        ):\n            kimage.affine_transform(images, invalid_transform)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid transform rank: expected rank 1\"\n        ):\n            kimage.AffineTransform()(images, invalid_transform)\n\n        # Test rank=3, symbolic tensor\n        images = KerasTensor(shape=(10, 10, 3))\n        invalid_transform = KerasTensor(shape=(2, 3, 2))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid transform rank: expected rank 1\"\n        ):\n            kimage.affine_transform(images, 
invalid_transform)\n\n    def test_extract_patches_invalid_size(self):\n        size = \"5\"  # Invalid size type\n        image = np.random.uniform(size=(2, 20, 20, 3))\n        with self.assertRaisesRegex(TypeError, \"Expected an int or a tuple\"):\n            kimage.extract_patches(image, size)\n\n        size = (3, 3, 3, 3)  # Invalid size, too many dimensions\n        with self.assertRaisesRegex(\n            ValueError, \"Expected a tuple of length 2 or 3\"\n        ):\n            kimage.extract_patches(image, size)\n\n    def test_extract_patches_unified_3d(self):\n        # Test that extract_patches handles 3D volumes when size has 3 elements\n        # channels_last\n        volume = np.random.uniform(size=(2, 20, 20, 20, 3)).astype(\"float32\")\n        patches = kimage.extract_patches(volume, (5, 5, 5))\n        self.assertEqual(patches.shape, (2, 4, 4, 4, 375))\n\n        # unbatched\n        volume = np.random.uniform(size=(20, 20, 20, 3)).astype(\"float32\")\n        patches = kimage.extract_patches(volume, (5, 5, 5))\n        self.assertEqual(patches.shape, (4, 4, 4, 375))\n\n    def test_map_coordinates_invalid_coordinates_rank(self):\n        # Test mismatched dim of coordinates\n        image = np.random.uniform(size=(10, 10, 3))\n        coordinates = np.random.uniform(size=(2, 10, 10))\n        with self.assertRaisesRegex(\n            ValueError, \"must be the same as the rank of `inputs`\"\n        ):\n            kimage.map_coordinates(image, coordinates, 0)\n        with self.assertRaisesRegex(\n            ValueError, \"must be the same as the rank of `inputs`\"\n        ):\n            kimage.MapCoordinates(0)(image, coordinates)\n\n        # Test rank=1\n        coordinates = np.random.uniform(size=(3,))\n        with self.assertRaisesRegex(ValueError, \"expected at least rank 2\"):\n            kimage.map_coordinates(image, coordinates, 0)\n        with self.assertRaisesRegex(ValueError, \"expected at least rank 2\"):\n            
kimage.MapCoordinates(0)(image, coordinates)\n\n    def test_crop_images_unknown_shape(self):\n        # Test unknown height and target_height\n        x = KerasTensor([None, 10, 3])\n        with self.assertRaisesRegex(\n            ValueError, \"When the height of the images is unknown\"\n        ):\n            kimage.crop_images(x, 2, 3, 4, 5)\n\n        # Test unknown width and target_width\n        x = KerasTensor([10, None, 3])\n        with self.assertRaisesRegex(\n            ValueError, \"When the width of the images is unknown\"\n        ):\n            kimage.crop_images(x, 2, 3, 4, 5)\n\n    def test_perspective_transform_invalid_images_rank(self):\n        # Test rank=2\n        invalid_image = np.random.uniform(size=(10, 10))\n        start_points = np.random.uniform(size=(6,))\n        end_points = np.random.uniform(size=(6,))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.perspective_transform(\n                invalid_image, start_points, end_points\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.PerspectiveTransform()(\n                invalid_image, start_points, end_points\n            )\n\n        # Test rank=5\n        invalid_image = np.random.uniform(size=(2, 10, 10, 3, 1))\n        start_points = np.random.uniform(size=(6,))\n        end_points = np.random.uniform(size=(6,))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.perspective_transform(\n                invalid_image, start_points, end_points\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.PerspectiveTransform()(\n                invalid_image, start_points, end_points\n            )\n\n        # Test 
rank=2, symbolic tensor\n        invalid_image = KerasTensor(shape=(10, 10))\n        start_points = KerasTensor(shape=(6,))\n        end_points = np.random.uniform(size=(6,))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.perspective_transform(\n                invalid_image, start_points, end_points\n            )\n\n    def test_perspective_transform_invalid_points_rank(self):\n        # Test rank=3\n        images = np.random.uniform(size=(10, 10, 3))\n        start_points = np.random.uniform(size=(2, 2, 4, 2))\n        end_points = np.random.uniform(size=(2, 2, 4, 2))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid start_points shape: expected\"\n        ):\n            kimage.perspective_transform(images, start_points, end_points)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid start_points shape: expected\"\n        ):\n            kimage.PerspectiveTransform()(images, start_points, end_points)\n\n        # Test rank=0\n        start_points = np.random.uniform(size=())\n        end_points = np.random.uniform(size=())\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid start_points shape: expected\"\n        ):\n            kimage.perspective_transform(images, start_points, end_points)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid start_points shape: expected\"\n        ):\n            kimage.PerspectiveTransform()(images, start_points, end_points)\n\n        # Test rank=3, symbolic tensor\n        images = KerasTensor(shape=(10, 10, 3))\n        start_points = KerasTensor(shape=(2, 3, 2))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid start_points shape: expected\"\n        ):\n            kimage.perspective_transform(images, start_points, end_points)\n\n    def test_gaussian_blur_invalid_images_rank(self):\n        # Test rank=2\n        
invalid_image = np.random.uniform(size=(10, 10))\n        kernel_size = (3, 3)\n        sigma = (0.1, 0.1)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.gaussian_blur(\n                invalid_image, kernel_size=kernel_size, sigma=sigma\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.GaussianBlur(kernel_size=kernel_size, sigma=sigma)(\n                invalid_image\n            )\n\n        # Test rank=5\n        invalid_image = np.random.uniform(size=(2, 10, 10, 3, 1))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.gaussian_blur(\n                invalid_image, kernel_size=kernel_size, sigma=sigma\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.GaussianBlur(kernel_size=kernel_size, sigma=sigma)(\n                invalid_image\n            )\n\n        # Test rank=2, symbolic tensor\n        invalid_image = KerasTensor(shape=(10, 10))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.gaussian_blur(\n                invalid_image, kernel_size=kernel_size, sigma=sigma\n            )\n\n    def test_elastic_transform_invalid_images_rank(self):\n        # Test rank=2\n        invalid_image = np.random.uniform(size=(10, 10))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.elastic_transform(\n                invalid_image,\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.ElasticTransform()(invalid_image)\n\n        # Test 
rank=5\n        invalid_image = np.random.uniform(size=(2, 10, 10, 3, 1))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.elastic_transform(invalid_image)\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.ElasticTransform()(invalid_image)\n\n        # Test rank=2, symbolic tensor\n        invalid_image = KerasTensor(shape=(10, 10))\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid images rank: expected rank 3\"\n        ):\n            kimage.elastic_transform(invalid_image)\n\n\nclass ExtractPatches3DTest(testing.TestCase):\n    FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in (\"float64\",)]\n\n    def setUp(self):\n        backend.set_image_data_format(\"channels_last\")\n        return super().setUp()\n\n    @parameterized.named_parameters(\n        named_product(\n            dtype=FLOAT_DTYPES, data_format=[\"channels_last\", \"channels_first\"]\n        )\n    )\n    def test_extract_patches_3d_basic(self, dtype, data_format):\n        if data_format == \"channels_last\":\n            volume = np.ones((1, 96, 96, 96, 4), dtype=dtype)\n            expected_shape = (1, 24, 24, 24, 256)\n        else:\n            volume = np.ones((1, 4, 96, 96, 96), dtype=dtype)\n            expected_shape = (1, 256, 24, 24, 24)\n        patches = kimage.extract_patches_3d(\n            volume, size=(4, 4, 4), strides=(4, 4, 4), data_format=data_format\n        )\n\n        self.assertEqual(patches.shape, expected_shape)\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_extract_patches_3d_valid_padding(self, dtype):\n        volume = np.random.rand(2, 32, 32, 32, 3)\n        volume = volume.astype(dtype)\n        patches = kimage.extract_patches_3d(\n            volume, size=(8, 8, 8), strides=(8, 8, 8), padding=\"valid\"\n        )\n    
    self.assertEqual(patches.shape, (2, 4, 4, 4, 1536))\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_extract_patches_3d_same_padding(self, dtype):\n        volume = np.random.rand(1, 33, 33, 33, 1)\n        volume = volume.astype(dtype)\n        patches = kimage.extract_patches_3d(\n            volume, size=(4, 4, 4), strides=(4, 4, 4), padding=\"same\"\n        )\n        expected_patches = (33 + 3) // 4  # = 9\n        self.assertEqual(\n            patches.shape,\n            (1, expected_patches, expected_patches, expected_patches, 64),\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            dtype=FLOAT_DTYPES, data_format=[\"channels_last\", \"channels_first\"]\n        )\n    )\n    def test_extract_patches_3d_with_dilation(self, dtype, data_format):\n        # Shape input according to data_format\n        if data_format == \"channels_last\":\n            volume = np.random.rand(1, 64, 64, 64, 2).astype(dtype)\n        else:\n            volume = np.random.rand(1, 2, 64, 64, 64).astype(dtype)\n\n        if backend.backend() == \"tensorflow\":\n            # TensorFlow backend does not support dilation > 1 and strides > 1\n            with self.assertRaises(ValueError):\n                kimage.extract_patches_3d(\n                    volume,\n                    size=(3, 3, 3),\n                    strides=(8, 8, 8),\n                    dilation_rate=(2, 2, 2),\n                    data_format=data_format,\n                )\n        else:\n            # Runs without error; check shape\n            patches = kimage.extract_patches_3d(\n                volume,\n                size=(3, 3, 3),\n                strides=(8, 8, 8),\n                dilation_rate=(2, 2, 2),\n                data_format=data_format,\n            )\n            # eff_p = 3 + (3 - 1) * (2 - 1) = 5\n            # out = (64 - 5) // 8 + 1 = 8\n            expected_patches = 8\n            if data_format == 
\"channels_last\":\n                expected_shape = (\n                    1,\n                    expected_patches,\n                    expected_patches,\n                    expected_patches,\n                    54,  # 2*3*3*3\n                )\n            else:\n                expected_shape = (\n                    1,\n                    54,  # 2*3*3*3\n                    expected_patches,\n                    expected_patches,\n                    expected_patches,\n                )\n            self.assertEqual(patches.shape, expected_shape)\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_extract_patches_3d_overlapping(self, dtype):\n        volume = np.random.rand(1, 16, 16, 16, 1)\n        volume = volume.astype(dtype)\n        patches = kimage.extract_patches_3d(\n            volume, size=(4, 4, 4), strides=(2, 2, 2)\n        )\n        expected_patches = (16 - 4) // 2 + 1  # = 7\n        self.assertEqual(\n            patches.shape,\n            (1, expected_patches, expected_patches, expected_patches, 64),\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_extract_patches_3d_int_size(self, dtype):\n        volume = np.random.rand(1, 24, 24, 24, 2)\n        volume = volume.astype(dtype)\n        patches = kimage.extract_patches_3d(volume, size=6, strides=6)\n        self.assertEqual(patches.shape, (1, 4, 4, 4, 432))\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_extract_patches_3d_no_stride_provided(self, dtype):\n        volume = np.random.rand(1, 24, 24, 24, 2)\n        volume = volume.astype(dtype)\n        patches = kimage.extract_patches_3d(volume, size=6)\n        # should default to strides = size - same results as above test\n        self.assertEqual(patches.shape, (1, 4, 4, 4, 432))\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_extract_patches_3d_unbatched(self, 
dtype):\n        volume = np.random.rand(24, 24, 24, 2)\n        volume = volume.astype(dtype)\n        patches = kimage.extract_patches_3d(volume, size=6)\n        self.assertEqual(patches.shape, (4, 4, 4, 432))\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_extract_patches_3d_value_check(self, dtype):\n        if dtype == \"bfloat16\" and backend.backend() == \"openvino\":\n            self.skipTest(\n                \"OpenVINO's bfloat16 fails this test, \"\n                \"possibly due to precision. \"\n                \"Should be revisited.\"\n            )\n        volume = np.arange(8 * 8 * 8).reshape(1, 8, 8, 8, 1)\n        volume = volume.astype(dtype)\n        patches = kimage.extract_patches_3d(\n            volume, size=(2, 2, 2), strides=(2, 2, 2)\n        )\n        first_patch = patches[0, 0, 0, 0, :]\n        first_patch_np = backend.convert_to_numpy(first_patch)\n\n        expected = volume[0, 0:2, 0:2, 0:2, 0].flatten()\n        np.testing.assert_array_equal(first_patch_np, expected)\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_extract_patches_3d_invalid_size(self, dtype):\n        volume = np.random.rand(1, 32, 32, 32, 1).astype(dtype)\n        with self.assertRaises(TypeError):\n            kimage.extract_patches_3d(volume, size=(4, 4))\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_extract_patches_3d_invalid_strides(self, dtype):\n        volume = np.random.rand(1, 32, 32, 32, 1).astype(dtype)\n        with self.assertRaises(ValueError):\n            kimage.extract_patches_3d(volume, size=(4, 4, 4), strides=(2, 2))\n\n    @parameterized.named_parameters(\n        named_product(\n            dtype=FLOAT_DTYPES, data_format=[\"channels_last\", \"channels_first\"]\n        )\n    )\n    def test_extract_patches_3d_non_cubic(self, dtype, data_format):\n        if data_format == \"channels_last\":\n            volume = 
np.random.rand(1, 32, 32, 32, 3).astype(dtype)\n            expected_shape = (1, 16, 10, 8, 72)\n        else:\n            volume = np.random.rand(1, 3, 32, 32, 32).astype(dtype)\n            expected_shape = (1, 72, 16, 10, 8)\n        patches = kimage.extract_patches_3d(\n            volume, size=(2, 3, 4), strides=(2, 3, 4), data_format=data_format\n        )\n        self.assertEqual(patches.shape, expected_shape)\n"
  },
  {
    "path": "keras/src/ops/linalg.py",
    "content": "from keras.src import backend\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend import any_symbolic_tensors\nfrom keras.src.ops.operation import Operation\nfrom keras.src.ops.operation_utils import reduce_shape\n\n\nclass Cholesky(Operation):\n    def __init__(self, upper=False, *, name=None):\n        super().__init__(name=name)\n        self.upper = upper\n\n    def call(self, x):\n        return _cholesky(x, self.upper)\n\n    def compute_output_spec(self, x):\n        _assert_2d(x)\n        _assert_square(x)\n        return KerasTensor(x.shape, x.dtype)\n\n\n@keras_export([\"keras.ops.cholesky\", \"keras.ops.linalg.cholesky\"])\ndef cholesky(x, upper=False):\n    \"\"\"Computes the Cholesky decomposition of a positive semi-definite matrix.\n\n    Args:\n        x: Input tensor of shape `(..., M, M)`.\n        upper (bool): If True, returns the upper-triangular Cholesky factor.\n            If False (default), returns the lower-triangular Cholesky factor.\n\n    Returns:\n        A tensor of shape `(..., M, M)` representing the Cholesky factor of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Cholesky(upper=upper).symbolic_call(x)\n    return _cholesky(x, upper=upper)\n\n\ndef _cholesky(x, upper=False):\n    x = backend.convert_to_tensor(x)\n    _assert_2d(x)\n    _assert_square(x)\n    try:\n        return backend.linalg.cholesky(x, upper=upper)\n    except Exception as e:\n        raise ValueError(f\"Cholesky decomposition failed: {e}\")\n\n\nclass CholeskyInverse(Operation):\n    def __init__(self, upper=False, *, name=None):\n        super().__init__(name=name)\n        self.upper = upper\n\n    def call(self, x):\n        return _cholesky_inverse(x, self.upper)\n\n    def compute_output_spec(self, x):\n        _assert_2d(x)\n        _assert_square(x)\n        return KerasTensor(x.shape, x.dtype)\n\n\n@keras_export(\n    
[\"keras.ops.cholesky_inverse\", \"keras.ops.linalg.cholesky_inverse\"]\n)\ndef cholesky_inverse(x, upper=False):\n    \"\"\"Computes the inverse of a symmetric positive-definite matrix.\n\n    Args:\n        x: Input tensor of shape `(..., M, M)`.\n        upper (bool): Determines whether to use the upper- or lower-triangular\n            factor for the internal computation. Defaults to False.\n\n    Returns:\n        A tensor of shape `(..., M, M)` representing the inverse of `x`.\n\n    Raises:\n        ValueError: If `x` is not a symmetric positive-definite matrix.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return CholeskyInverse(upper=upper).symbolic_call(x)\n    return _cholesky_inverse(x, upper=upper)\n\n\ndef _cholesky_inverse(x, upper=False):\n    x = backend.convert_to_tensor(x)\n    _assert_2d(x)\n    _assert_square(x)\n    try:\n        return backend.linalg.cholesky_inverse(x, upper=upper)\n    except Exception as e:\n        raise ValueError(f\"Cholesky inverse failed: {e}\")\n\n\nclass Det(Operation):\n    def call(self, x):\n        return _det(x)\n\n    def compute_output_spec(self, x):\n        _assert_2d(x)\n        _assert_square(x)\n        return KerasTensor(x.shape[:-2], x.dtype)\n\n\n@keras_export([\"keras.ops.det\", \"keras.ops.linalg.det\"])\ndef det(x):\n    \"\"\"Computes the determinant of a square tensor.\n\n    Args:\n        x: Input tensor of shape `(..., M, M)`.\n\n    Returns:\n        A tensor of shape `(...,)` representing the determinant of `x`.\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Det().symbolic_call(x)\n    return _det(x)\n\n\ndef _det(x):\n    x = backend.convert_to_tensor(x)\n    _assert_2d(x)\n    _assert_square(x)\n    return backend.linalg.det(x)\n\n\nclass Eig(Operation):\n    def call(self, x):\n        return _eig(x)\n\n    def compute_output_spec(self, x):\n        _assert_square(x)\n        _assert_2d(x)\n        return (\n            KerasTensor(x.shape[:-1], x.dtype),\n   
         KerasTensor(x.shape, x.dtype),\n        )\n\n\n@keras_export([\"keras.ops.eig\", \"keras.ops.linalg.eig\"])\ndef eig(x):\n    \"\"\"Computes the eigenvalues and eigenvectors of a square matrix.\n\n    Args:\n        x: Input tensor of shape `(..., M, M)`.\n\n    Returns:\n        A tuple of two tensors: a tensor of shape `(..., M)` containing\n        eigenvalues and a tensor of shape `(..., M, M)` containing eigenvectors.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Eig().symbolic_call(x)\n    return _eig(x)\n\n\ndef _eig(x):\n    x = backend.convert_to_tensor(x)\n    _assert_square(x)\n    _assert_2d(x)\n    return backend.linalg.eig(x)\n\n\nclass Eigh(Operation):\n    def call(self, x):\n        return _eigh(x)\n\n    def compute_output_spec(self, x):\n        _assert_square(x)\n        _assert_2d(x)\n        return (\n            KerasTensor(x.shape[:-1], x.dtype),\n            KerasTensor(x.shape, x.dtype),\n        )\n\n\n@keras_export([\"keras.ops.eigh\", \"keras.ops.linalg.eigh\"])\ndef eigh(x):\n    \"\"\"Computes the eigenvalues and eigenvectors of a complex Hermitian matrix.\n\n    Args:\n        x: Input tensor of shape `(..., M, M)`.\n\n    Returns:\n        A tuple of two tensors: a tensor of shape `(..., M)` containing\n        eigenvalues and a tensor of shape `(..., M, M)` containing eigenvectors.\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Eigh().symbolic_call(x)\n    return _eigh(x)\n\n\ndef _eigh(x):\n    x = backend.convert_to_tensor(x)\n    _assert_square(x)\n    _assert_2d(x)\n    return backend.linalg.eigh(x)\n\n\nclass Inv(Operation):\n    def call(self, x):\n        return _inv(x)\n\n    def compute_output_spec(self, x):\n        _assert_2d(x)\n        _assert_square(x)\n        return KerasTensor(x.shape, x.dtype)\n\n\n@keras_export([\"keras.ops.inv\", \"keras.ops.linalg.inv\"])\ndef inv(x):\n    \"\"\"Computes the inverse of a square tensor.\n\n    Args:\n        x: Input tensor of shape `(..., 
M, M)`.\n\n    Returns:\n        A tensor of shape `(..., M, M)` representing the inverse of `x`.\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Inv().symbolic_call(x)\n    return _inv(x)\n\n\ndef _inv(x):\n    x = backend.convert_to_tensor(x)\n    _assert_2d(x)\n    _assert_square(x)\n    return backend.linalg.inv(x)\n\n\nclass LuFactor(Operation):\n    def call(self, x):\n        return _lu_factor(x)\n\n    def compute_output_spec(self, x):\n        _assert_2d(x)\n        batch_shape = x.shape[:-2]\n        m, n = x.shape[-2:]\n        k = min(m, n)\n        return (\n            KerasTensor(batch_shape + (m, n), x.dtype),\n            KerasTensor(batch_shape + (k,), x.dtype),\n        )\n\n\n@keras_export([\"keras.ops.lu_factor\", \"keras.ops.linalg.lu_factor\"])\ndef lu_factor(x):\n    \"\"\"Computes the lower-upper decomposition of a square matrix.\n\n    Args:\n        x: A tensor of shape `(..., M, M)`.\n\n    Returns:\n        A tuple of two tensors: a tensor of shape `(..., M, M)` containing the\n        lower and upper triangular matrices and a tensor of shape `(..., M)`\n        containing the pivots.\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return LuFactor().symbolic_call(x)\n    return _lu_factor(x)\n\n\ndef _lu_factor(x):\n    x = backend.convert_to_tensor(x)\n    _assert_2d(x)\n    if backend.backend() == \"tensorflow\":\n        try:\n            _assert_square(x)\n        except ValueError as e:\n            raise ValueError(\n                f\"LU decomposition failed: {e}. 
LU decomposition is only \"\n                \"supported for square matrices in Tensorflow.\"\n            )\n    return backend.linalg.lu_factor(x)\n\n\nclass Norm(Operation):\n    def __init__(self, ord=None, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        if isinstance(ord, str):\n            if ord not in (\"fro\", \"nuc\"):\n                raise ValueError(\n                    \"Invalid `ord` argument. \"\n                    \"Expected one of {'fro', 'nuc'} when using string. \"\n                    f\"Received: ord={ord}\"\n                )\n        if isinstance(axis, int):\n            axis = [axis]\n        self.ord = ord\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def compute_output_spec(self, x):\n        output_dtype = backend.standardize_dtype(x.dtype)\n        if \"int\" in output_dtype or output_dtype == \"bool\":\n            output_dtype = backend.floatx()\n        if self.axis is None:\n            axis = tuple(range(len(x.shape)))\n        else:\n            axis = self.axis\n        num_axes = len(axis)\n        if num_axes == 1 and isinstance(self.ord, str):\n            raise ValueError(\n                \"Invalid `ord` argument for vector norm. \"\n                f\"Received: ord={self.ord}\"\n            )\n        elif num_axes == 2 and self.ord not in (\n            None,\n            \"fro\",\n            \"nuc\",\n            float(\"inf\"),\n            float(\"-inf\"),\n            1,\n            -1,\n            2,\n            -2,\n        ):\n            raise ValueError(\n                \"Invalid `ord` argument for matrix norm. 
\"\n                f\"Received: ord={self.ord}\"\n            )\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=output_dtype,\n        )\n\n    def call(self, x):\n        x = backend.convert_to_tensor(x)\n        return backend.linalg.norm(\n            x, ord=self.ord, axis=self.axis, keepdims=self.keepdims\n        )\n\n\n@keras_export([\"keras.ops.norm\", \"keras.ops.linalg.norm\"])\ndef norm(x, ord=None, axis=None, keepdims=False):\n    \"\"\"Matrix or vector norm.\n\n    This function is able to return one of eight different matrix norms, or one\n    of an infinite number of vector norms (described below), depending on the\n    value of the `ord` parameter.\n\n    Args:\n        x: Input tensor.\n        ord: Order of the norm (see table under Notes). The default is `None`.\n        axis: If `axis` is an integer, it specifies the axis of `x` along which\n            to compute the vector norms. If `axis` is a 2-tuple, it specifies\n            the axes that hold 2-D matrices, and the matrix norms of these\n            matrices are computed.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one.\n\n    Note:\n        For values of `ord < 1`, the result is, strictly speaking, not a\n        mathematical 'norm', but it may still be useful for various numerical\n        purposes. The following norms can be calculated:\n        - For matrices:\n            - `ord=None`: Frobenius norm\n            - `ord=\"fro\"`: Frobenius norm\n            - `ord=\"nuc\"`: nuclear norm\n            - `ord=np.inf`: `max(sum(abs(x), axis=1))`\n            - `ord=-np.inf`: `min(sum(abs(x), axis=1))`\n            - `ord=0`: not supported\n            - `ord=1`: `max(sum(abs(x), axis=0))`\n            - `ord=-1`: `min(sum(abs(x), axis=0))`\n            - `ord=2`: 2-norm (largest sing. 
value)\n            - `ord=-2`: smallest singular value\n            - other: not supported\n        - For vectors:\n            - `ord=None`: 2-norm\n            - `ord=\"fro\"`: not supported\n            - `ord=\"nuc\"`: not supported\n            - `ord=np.inf`: `max(abs(x))`\n            - `ord=-np.inf`: `min(abs(x))`\n            - `ord=0`: `sum(x != 0)`\n            - `ord=1`: as below\n            - `ord=-1`: as below\n            - `ord=2`: as below\n            - `ord=-2`: as below\n            - other: `sum(abs(x)**ord)**(1./ord)`\n\n    Returns:\n        Norm of the matrix or vector(s).\n\n    Example:\n\n    >>> x = keras.ops.reshape(keras.ops.arange(9, dtype=\"float32\") - 4, (3, 3))\n    >>> keras.ops.linalg.norm(x)\n    7.7459664\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Norm(ord=ord, axis=axis, keepdims=keepdims).symbolic_call(x)\n    x = backend.convert_to_tensor(x)\n    return backend.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)\n\n\nclass Qr(Operation):\n    def __init__(self, mode=\"reduced\", *, name=None):\n        super().__init__(name=name)\n        if mode not in {\"reduced\", \"complete\"}:\n            raise ValueError(\n                \"`mode` argument value not supported. \"\n                \"Expected one of {'reduced', 'complete'}. \"\n                f\"Received: mode={mode}\"\n            )\n        self.mode = mode\n\n    def compute_output_spec(self, x):\n        if len(x.shape) < 2:\n            raise ValueError(\n                \"Input should have rank >= 2. Received: \"\n                f\"input.shape = {x.shape}\"\n            )\n        m = x.shape[-2]\n        n = x.shape[-1]\n        if m is None or n is None:\n            raise ValueError(\n                \"Input should have its last 2 dimensions \"\n                \"fully-defined. 
Received: \"\n                f\"input.shape = {x.shape}\"\n            )\n        k = min(m, n)\n        base = tuple(x.shape[:-2])\n        if self.mode == \"reduced\":\n            return (\n                KerasTensor(shape=base + (m, k), dtype=x.dtype),\n                KerasTensor(shape=base + (k, n), dtype=x.dtype),\n            )\n        # 'complete' mode.\n        return (\n            KerasTensor(shape=base + (m, m), dtype=x.dtype),\n            KerasTensor(shape=base + (m, n), dtype=x.dtype),\n        )\n\n    def call(self, x):\n        x = backend.convert_to_tensor(x)\n        return backend.linalg.qr(x, mode=self.mode)\n\n\n@keras_export([\"keras.ops.qr\", \"keras.ops.linalg.qr\"])\ndef qr(x, mode=\"reduced\"):\n    \"\"\"Computes the QR decomposition of a tensor.\n\n    Args:\n        x: Input tensor of shape `(..., M, N)`.\n        mode: A string specifying the mode of the QR decomposition.\n            - 'reduced': Returns the reduced QR decomposition. (default)\n            - 'complete': Returns the complete QR decomposition.\n\n    Returns:\n        A tuple containing two tensors. 
The first tensor of shape `(..., M, K)`\n        is the orthogonal matrix `q` and the second tensor of shape\n        `(..., K, N)` is the upper triangular matrix `r`, where `K = min(M, N)`\n        in `'reduced'` mode. In `'complete'` mode, `q` has shape `(..., M, M)`\n        and `r` has shape `(..., M, N)`.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([[1., 2.], [3., 4.], [5., 6.]])\n    >>> q, r = keras.ops.qr(x)\n    >>> print(q)\n    array([[-0.16903079  0.897085]\n           [-0.5070925   0.2760267 ]\n           [-0.8451542  -0.34503305]], shape=(3, 2), dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Qr(mode=mode).symbolic_call(x)\n    x = backend.convert_to_tensor(x)\n    return backend.linalg.qr(x, mode=mode)\n\n\nclass Solve(Operation):\n    def call(self, a, b):\n        return _solve(a, b)\n\n    def compute_output_spec(self, a, b):\n        _assert_2d(a)\n        _assert_square(a)\n        _assert_1d(b)\n        _assert_a_b_compat(a, b)\n        return KerasTensor(b.shape, b.dtype)\n\n\n@keras_export([\"keras.ops.solve\", \"keras.ops.linalg.solve\"])\ndef solve(a, b):\n    \"\"\"Solves a linear system of equations given by `a x = b`.\n\n    Args:\n        a: A tensor of shape `(..., M, M)` representing the coefficients matrix.\n        b: A tensor of shape `(..., M)` or `(..., M, N)` representing the\n            right-hand side or \"dependent variable\" matrix.\n\n    Returns:\n        A tensor of shape `(..., M)` or `(..., M, N)` representing the solution\n        of the linear system. 
Returned shape is identical to `b`.\n\n    \"\"\"\n    if any_symbolic_tensors((a, b)):\n        return Solve().symbolic_call(a, b)\n    return _solve(a, b)\n\n\ndef _solve(a, b):\n    a = backend.convert_to_tensor(a)\n    b = backend.convert_to_tensor(b)\n    _assert_2d(a)\n    _assert_square(a)\n    _assert_1d(b)\n    _assert_a_b_compat(a, b)\n    return backend.linalg.solve(a, b)\n\n\nclass SolveTriangular(Operation):\n    def __init__(self, lower=False, *, name=None):\n        super().__init__(name=name)\n        self.lower = lower\n\n    def call(self, a, b):\n        return _solve_triangular(a, b, self.lower)\n\n    def compute_output_spec(self, a, b):\n        _assert_2d(a)\n        _assert_square(a)\n        _assert_1d(b)\n        _assert_a_b_compat(a, b)\n        return KerasTensor(b.shape, b.dtype)\n\n\n@keras_export(\n    [\"keras.ops.solve_triangular\", \"keras.ops.linalg.solve_triangular\"]\n)\ndef solve_triangular(a, b, lower=False):\n    \"\"\"Solves a linear system `a x = b` where `a` is a triangular matrix.\n\n    Args:\n        a: A tensor of shape `(..., M, M)` representing the coefficients matrix.\n        b: A tensor of shape `(..., M)` or `(..., M, N)` representing the\n            right-hand side or \"dependent variable\" matrix.\n        lower: A boolean. If `True`, only the lower triangle of `a` is used;\n            otherwise only the upper triangle. Defaults to `False`.\n\n    Returns:\n        A tensor of shape `(..., M)` or `(..., M, N)` representing the solution\n        of the linear system. 
Returned shape is identical to `b`.\n\n    \"\"\"\n    if any_symbolic_tensors((a, b)):\n        return SolveTriangular(lower).symbolic_call(a, b)\n    return _solve_triangular(a, b, lower)\n\n\ndef _solve_triangular(a, b, lower=False):\n    a = backend.convert_to_tensor(a)\n    b = backend.convert_to_tensor(b)\n    _assert_2d(a)\n    _assert_square(a)\n    _assert_1d(b)\n    _assert_a_b_compat(a, b)\n    return backend.linalg.solve_triangular(a, b, lower)\n\n\nclass SVD(Operation):\n    def __init__(self, full_matrices=True, compute_uv=True, *, name=None):\n        super().__init__(name=name)\n        self.full_matrices = full_matrices\n        self.compute_uv = compute_uv\n\n    def call(self, x):\n        return _svd(x, self.full_matrices, self.compute_uv)\n\n    def compute_output_spec(self, x):\n        _assert_2d(x)\n        rows, columns = x.shape[-2:]\n        batches = x.shape[:-2]\n        s_shape = batches + (min(rows, columns),)\n        if self.full_matrices:\n            u_shape = batches + (rows, rows)\n            v_shape = batches + (columns, columns)\n        else:\n            u_shape = batches + (rows, min(rows, columns))\n            v_shape = batches + (min(rows, columns), columns)\n\n        if self.compute_uv:\n            return (\n                KerasTensor(u_shape, x.dtype),\n                KerasTensor(s_shape, x.dtype),\n                KerasTensor(v_shape, x.dtype),\n            )\n        return KerasTensor(s_shape, x.dtype)\n\n\n@keras_export([\"keras.ops.svd\", \"keras.ops.linalg.svd\"])\ndef svd(x, full_matrices=True, compute_uv=True):\n    \"\"\"Computes the singular value decomposition of a matrix.\n\n    Args:\n        x: Input tensor of shape `(..., M, N)`.\n        full_matrices: A boolean. If `True`, compute full-sized `u` and `v`.\n            If `False`, compute the reduced decomposition. Defaults to `True`.\n        compute_uv: A boolean. If `False`, return only the singular values.\n            Defaults to `True`.\n\n    Returns:\n        A tuple of three tensors: a tensor of shape `(..., M, M)` containing\n        the left singular vectors, a tensor of shape `(..., K)` with\n        `K = min(M, N)` containing the singular values and a tensor of shape\n        `(..., N, N)` containing the right 
singular vectors.\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return SVD(full_matrices, compute_uv).symbolic_call(x)\n    return _svd(x, full_matrices, compute_uv)\n\n\ndef _svd(x, full_matrices=True, compute_uv=True):\n    x = backend.convert_to_tensor(x)\n    _assert_2d(x)\n    return backend.linalg.svd(x, full_matrices, compute_uv)\n\n\nclass Lstsq(Operation):\n    def __init__(self, rcond=None, *, name=None):\n        super().__init__(name=name)\n        self.rcond = rcond\n\n    def call(self, a, b):\n        return backend.linalg.lstsq(a, b, rcond=self.rcond)\n\n    def compute_output_spec(self, a, b):\n        if len(a.shape) != 2:\n            raise ValueError(\n                f\"Expected a to have rank 2. Received: a.shape={a.shape}\"\n            )\n        if len(b.shape) not in (1, 2):\n            raise ValueError(\n                f\"Expected b to have rank 1 or 2. Received: b.shape={b.shape}\"\n            )\n        m, n = a.shape\n        if b.shape[0] != m:\n            raise ValueError(\n                \"Expected b.shape[0] to be equal to \"\n                \"a.shape[0]. Received: \"\n                f\"a.shape={a.shape}, b.shape={b.shape}\"\n            )\n        if len(b.shape) == 2:\n            k = b.shape[1]\n            x = KerasTensor((n, k), dtype=a.dtype)\n        else:\n            x = KerasTensor((n,), dtype=a.dtype)\n        return x\n\n\n@keras_export([\"keras.ops.lstsq\", \"keras.ops.linalg.lstsq\"])\ndef lstsq(a, b, rcond=None):\n    \"\"\"Return the least-squares solution to a linear matrix equation.\n\n    Computes the vector x that approximately solves the equation\n    `a @ x = b`. The equation may be under-, well-, or over-determined\n    (i.e., the number of linearly independent rows of a can be less than,\n    equal to, or greater than its number of linearly independent columns).\n    If a is square and of full rank, then `x` (but for round-off error)\n    is the exact solution of the equation. 
Else, `x` minimizes the\n    L2 norm of `b - a @ x`.\n\n    If there are multiple minimizing solutions,\n    the one with the smallest L2 norm is returned.\n\n    Args:\n        a: \"Coefficient\" matrix of shape `(M, N)`.\n        b: Ordinate or \"dependent variable\" values,\n            of shape `(M,)` or `(M, K)`.\n            If `b` is two-dimensional, the least-squares solution\n            is calculated for each of the K columns of `b`.\n        rcond: Cut-off ratio for small singular values of `a`.\n            For the purposes of rank determination,\n            singular values are treated as zero if they are\n            smaller than `rcond` times the largest\n            singular value of `a`.\n\n    Returns:\n        Tensor with shape `(N,)` or `(N, K)` containing\n        the least-squares solutions.\n\n    **NOTE:** The output differs from `numpy.linalg.lstsq`.\n    NumPy returns a tuple with four elements, the first of which\n    is the least-squares solution and the others\n    are essentially never used.\n    Keras only returns the first value. This is done both\n    to ensure consistency across backends (which cannot be achieved\n    for the other values) and to simplify the API.\n    \"\"\"\n    if any_symbolic_tensors((a, b)):\n        return Lstsq(rcond=rcond).symbolic_call(a, b)\n    return backend.linalg.lstsq(a, b, rcond=rcond)\n\n\ndef _assert_1d(*arrays):\n    for a in arrays:\n        if a.ndim < 1:\n            raise ValueError(\n                f\"Expected input to have rank >= 1. Received scalar input {a}.\"\n            )\n\n\ndef _assert_2d(*arrays):\n    for a in arrays:\n        if a.ndim < 2:\n            raise ValueError(\n                \"Expected input to have rank >= 2. 
\"\n                f\"Received input with shape {a.shape}.\"\n            )\n\n\ndef _assert_square(*arrays):\n    for a in arrays:\n        m, n = a.shape[-2:]\n        if m != n:\n            raise ValueError(\n                \"Expected a square matrix. \"\n                f\"Received non-square input with shape {a.shape}\"\n            )\n\n\ndef _assert_a_b_compat(a, b):\n    if a.ndim == b.ndim:\n        if a.shape[-2] != b.shape[-2]:\n            raise ValueError(\n                \"Incompatible shapes between `a` and `b`. \"\n                \"Expected `a.shape[-2] == b.shape[-2]`. \"\n                f\"Received: a.shape={a.shape}, b.shape={b.shape}\"\n            )\n    elif a.ndim == b.ndim - 1:\n        if a.shape[-1] != b.shape[-1]:\n            raise ValueError(\n                \"Incompatible shapes between `a` and `b`. \"\n                \"Expected `a.shape[-1] == b.shape[-1]`. \"\n                f\"Received: a.shape={a.shape}, b.shape={b.shape}\"\n            )\n\n\nclass JVP(Operation):\n    def __init__(self, has_aux=False, *, name=None):\n        super().__init__(name=name)\n        self.has_aux = has_aux\n\n    def call(self, fun, primals, tangents):\n        \"\"\"Computes the JVP of `fun` at `primals` along `tangents`.\n\n        Args:\n            fun: A callable that takes tensors (or nested structures) as input\n                 and returns a tensor (or nested structure) as output.\n            primals: Input tensors (or nested structures) at which the Jacobian\n                     of `fun` is evaluated.\n            tangents: Tensors (or nested structures) representing the direction\n                      vectors for the JVP. 
Must have the same structure as\n                      `primals`.\n\n        Returns:\n            If `has_aux` is False:\n                A tuple (primals_out, tangents_out) where:\n                - primals_out: Output of `fun(*primals)`\n                - tangents_out: JVP of `fun` at `primals` along `tangents`\n            If `has_aux` is True:\n                A tuple (primals_out, tangents_out, aux) where:\n                - aux: Auxiliary data returned by `fun`\n        \"\"\"\n        return backend.linalg.jvp(fun, primals, tangents, has_aux=self.has_aux)\n\n    def compute_output_spec(self, fun, primals, tangents):\n        # Infer primal output spec\n        if self.has_aux:\n            primals_out_spec, aux_spec = backend.compute_output_spec(\n                fun, *primals\n            )\n        else:\n            primals_out_spec = backend.compute_output_spec(fun, *primals)\n\n        # Tangents output should match primals output in structure and shape\n        tangents_out_spec = tree.map_structure(\n            lambda x: KerasTensor(x.shape, x.dtype), primals_out_spec\n        )\n\n        if self.has_aux:\n            return primals_out_spec, tangents_out_spec, aux_spec\n        return primals_out_spec, tangents_out_spec\n\n\n@keras_export([\"keras.ops.jvp\", \"keras.ops.linalg.jvp\"])\ndef jvp(fun, primals, tangents, has_aux=False):\n    \"\"\"Computes a (forward-mode) Jacobian-vector product of `fun`.\n\n    Args:\n        fun: Function to be differentiated. Its arguments should be arrays,\n            scalars, or standard Python containers of arrays or scalars. It\n            should return an array, scalar, or standard Python container of\n            arrays or scalars.\n        primals: The primal values at which the Jacobian of `fun` should be\n                evaluated. 
Should be either a tuple or a list of arguments,\n                and its length should be equal to the number of positional\n                parameters of `fun`.\n        tangents: The tangent vector for which the Jacobian-vector product\n                should be evaluated. Should be either a tuple or a list of\n                tangents, with the same tree structure and array shapes as\n                `primals`.\n        has_aux: Optional, bool. Indicates whether `fun` returns a pair where\n                the first element is considered the output of the mathematical\n                function to be differentiated and the second element is\n                auxiliary data. Default is False.\n\n    Returns:\n        If `has_aux` is False, returns a (`primals_out`, `tangents_out`) pair,\n        where `primals_out` is `fun(*primals)`, and `tangents_out` is the\n        Jacobian-vector product of `fun` evaluated at `primals` with\n        `tangents`. The `tangents_out` value has the same Python tree\n        structure and shapes as `primals_out`.\n\n        If `has_aux` is True, returns a (`primals_out`, `tangents_out`, `aux`)\n        tuple where `aux` is the auxiliary data returned by `fun`.\n\n    Example:\n\n    >>> from keras import ops\n    >>> a1, a2 = ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2)\n    >>> primals, tangents = ops.jvp(ops.sin, (a1,), (a2,))\n    >>> primals\n    0.09983342\n    >>> tangents\n    0.19900084\n    \"\"\"\n    if any_symbolic_tensors((primals, tangents)):\n        return JVP(has_aux=has_aux).symbolic_call(fun, primals, tangents)\n    return backend.linalg.jvp(fun, primals, tangents, has_aux=has_aux)\n"
  },
  {
    "path": "keras/src/ops/linalg_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.ops import linalg\nfrom keras.src.testing.test_utils import named_product\n\n\nclass LinalgOpsDynamicShapeTest(testing.TestCase):\n    def test_cholesky(self):\n        x = KerasTensor([None, 20, 20])\n        out = linalg.cholesky(x)\n        self.assertEqual(out.shape, (None, 20, 20))\n\n        x = KerasTensor([None, None, 20])\n        with self.assertRaises(ValueError):\n            linalg.cholesky(x)\n\n        x = KerasTensor([None, 20, 15])\n        with self.assertRaises(ValueError):\n            linalg.cholesky(x)\n\n    def test_cholesky_inverse(self):\n        x = KerasTensor([None, 20, 20])\n        out = linalg.cholesky_inverse(x)\n        self.assertEqual(out.shape, (None, 20, 20))\n\n        x = KerasTensor([None, None, 20])\n        with self.assertRaises(ValueError):\n            linalg.cholesky_inverse(x)\n\n        x = KerasTensor([None, 20, 15])\n        with self.assertRaises(ValueError):\n            linalg.cholesky_inverse(x)\n\n    def test_det(self):\n        x = KerasTensor([None, 20, 20])\n        out = linalg.det(x)\n        self.assertEqual(out.shape, (None,))\n\n        x = KerasTensor([None, None, 20])\n        with self.assertRaises(ValueError):\n            linalg.det(x)\n\n        x = KerasTensor([None, 20, 15])\n        with self.assertRaises(ValueError):\n            linalg.det(x)\n\n    def test_eig(self):\n        x = KerasTensor([None, 20, 20])\n        w, v = linalg.eig(x)\n        self.assertEqual(w.shape, (None, 20))\n        self.assertEqual(v.shape, (None, 20, 20))\n\n    def test_eigh(self):\n        x = KerasTensor([None, 20, 20])\n        w, v = linalg.eigh(x)\n        self.assertEqual(w.shape, (None, 20))\n        self.assertEqual(v.shape, (None, 20, 20))\n\n    
def test_inv(self):\n        x = KerasTensor([None, 20, 20])\n        out = linalg.inv(x)\n        self.assertEqual(out.shape, (None, 20, 20))\n\n        x = KerasTensor([None, None, 20])\n        with self.assertRaises(ValueError):\n            linalg.inv(x)\n\n        x = KerasTensor([None, 20, 15])\n        with self.assertRaises(ValueError):\n            linalg.inv(x)\n\n    def test_lu_factor(self):\n        x = KerasTensor([None, 4, 3])\n        lu, p = linalg.lu_factor(x)\n        self.assertEqual(lu.shape, (None, 4, 3))\n        self.assertEqual(p.shape, (None, 3))\n\n        x = KerasTensor([None, 2, 3])\n        lu, p = linalg.lu_factor(x)\n        self.assertEqual(lu.shape, (None, 2, 3))\n        self.assertEqual(p.shape, (None, 2))\n\n    def test_norm(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(linalg.norm(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(linalg.norm(x, axis=1).shape, (None, 3))\n        self.assertEqual(\n            linalg.norm(x, axis=1, keepdims=True).shape, (None, 1, 3)\n        )\n\n    def test_qr(self):\n        x = KerasTensor((None, 4, 3), dtype=\"float32\")\n        q, r = linalg.qr(x, mode=\"reduced\")\n        qref, rref = np.linalg.qr(np.ones((2, 4, 3)), mode=\"reduced\")\n        qref_shape = (None,) + qref.shape[1:]\n        rref_shape = (None,) + rref.shape[1:]\n        self.assertEqual(q.shape, qref_shape)\n        self.assertEqual(r.shape, rref_shape)\n\n        q, r = linalg.qr(x, mode=\"complete\")\n        qref, rref = np.linalg.qr(np.ones((2, 4, 3)), mode=\"complete\")\n        qref_shape = (None,) + qref.shape[1:]\n        rref_shape = (None,) + rref.shape[1:]\n        self.assertEqual(q.shape, qref_shape)\n        self.assertEqual(r.shape, rref_shape)\n\n    def test_qr_invalid_mode(self):\n        # backend agnostic error message\n        x = np.array([[1, 2], [3, 4]])\n        invalid_mode = \"invalid_mode\"\n        with self.assertRaisesRegex(\n        
    ValueError, \"Expected one of {'reduced', 'complete'}.\"\n        ):\n            linalg.qr(x, mode=invalid_mode)\n\n    def test_solve(self):\n        a = KerasTensor([None, 20, 20])\n        b = KerasTensor([None, 20, 5])\n        out = linalg.solve(a, b)\n        self.assertEqual(out.shape, (None, 20, 5))\n\n        a = KerasTensor([None, 20, 20])\n        b = KerasTensor([None, 20])\n        out = linalg.solve(a, b)\n        self.assertEqual(out.shape, (None, 20))\n\n        a = KerasTensor([None, None, 20])\n        b = KerasTensor([None, 20, 5])\n        with self.assertRaises(ValueError):\n            linalg.solve(a, b)\n\n        a = KerasTensor([None, 20, 15])\n        b = KerasTensor([None, 20, 5])\n        with self.assertRaises(ValueError):\n            linalg.solve(a, b)\n\n        a = KerasTensor([None, 20, 20])\n        b = KerasTensor([None, None, 5])\n        with self.assertRaises(ValueError):\n            linalg.solve(a, b)\n\n    def test_solve_triangular(self):\n        a = KerasTensor([None, 20, 20])\n        b = KerasTensor([None, 20, 5])\n        out = linalg.solve_triangular(a, b)\n        self.assertEqual(out.shape, (None, 20, 5))\n\n        a = KerasTensor([None, 20, 20])\n        b = KerasTensor([None, 20])\n        out = linalg.solve_triangular(a, b)\n        self.assertEqual(out.shape, (None, 20))\n\n        a = KerasTensor([None, 20, 20])\n        b = KerasTensor([None, 20, 5])\n        out = linalg.solve_triangular(a, b, lower=True)\n        self.assertEqual(out.shape, (None, 20, 5))\n\n        a = KerasTensor([None, 20, 20])\n        b = KerasTensor([None, 20])\n        out = linalg.solve_triangular(a, b, lower=True)\n        self.assertEqual(out.shape, (None, 20))\n\n        a = KerasTensor([None, 20, 15])\n        b = KerasTensor([None, 20, 5])\n        with self.assertRaises(ValueError):\n            linalg.solve_triangular(a, b)\n\n        a = KerasTensor([None, 20, 20])\n        b = KerasTensor([None, None, 5])\n        
with self.assertRaises(ValueError):\n            linalg.solve_triangular(a, b)\n\n    def test_svd(self):\n        x = KerasTensor((None, 3, 2))\n        u, s, v = linalg.svd(x)\n        self.assertEqual(u.shape, (None, 3, 3))\n        self.assertEqual(s.shape, (None, 2))\n        self.assertEqual(v.shape, (None, 2, 2))\n\n        u, s, v = linalg.svd(x, full_matrices=False)\n        self.assertEqual(u.shape, (None, 3, 2))\n        self.assertEqual(s.shape, (None, 2))\n        self.assertEqual(v.shape, (None, 2, 2))\n\n        s = linalg.svd(x, compute_uv=False)\n        self.assertEqual(s.shape, (None, 2))\n\n\nclass LinalgOpsStaticShapeTest(testing.TestCase):\n    def test_cholesky(self):\n        x = KerasTensor([4, 3, 3])\n        out = linalg.cholesky(x)\n        self.assertEqual(out.shape, (4, 3, 3))\n\n        x = KerasTensor([10, 20, 15])\n        with self.assertRaises(ValueError):\n            linalg.cholesky(x)\n\n    def test_cholesky_inverse(self):\n        x = KerasTensor([4, 3, 3])\n        out = linalg.cholesky_inverse(x)\n        self.assertEqual(out.shape, (4, 3, 3))\n\n        x = KerasTensor([10, 20, 15])\n        with self.assertRaises(ValueError):\n            linalg.cholesky_inverse(x)\n\n    def test_det(self):\n        x = KerasTensor([4, 3, 3])\n        out = linalg.det(x)\n        self.assertEqual(out.shape, (4,))\n\n        x = KerasTensor([10, 20, 15])\n        with self.assertRaises(ValueError):\n            linalg.det(x)\n\n    def test_eig(self):\n        x = KerasTensor([4, 3, 3])\n        w, v = linalg.eig(x)\n        self.assertEqual(w.shape, (4, 3))\n        self.assertEqual(v.shape, (4, 3, 3))\n\n        x = KerasTensor([10, 20, 15])\n        with self.assertRaises(ValueError):\n            linalg.eig(x)\n\n    def test_eigh(self):\n        x = KerasTensor([4, 3, 3])\n        w, v = linalg.eigh(x)\n        self.assertEqual(w.shape, (4, 3))\n        self.assertEqual(v.shape, (4, 3, 3))\n\n        x = KerasTensor([10, 20, 15])\n   
     with self.assertRaises(ValueError):\n            linalg.eigh(x)\n\n    def test_inv(self):\n        x = KerasTensor([4, 3, 3])\n        out = linalg.inv(x)\n        self.assertEqual(out.shape, (4, 3, 3))\n\n        x = KerasTensor([10, 20, 15])\n        with self.assertRaises(ValueError):\n            linalg.inv(x)\n\n    def test_lu_factor(self):\n        x = KerasTensor([10, 4, 3])\n        lu, p = linalg.lu_factor(x)\n        self.assertEqual(lu.shape, (10, 4, 3))\n        self.assertEqual(p.shape, (10, 3))\n\n        x = KerasTensor([10, 2, 3])\n        lu, p = linalg.lu_factor(x)\n        self.assertEqual(lu.shape, (10, 2, 3))\n        self.assertEqual(p.shape, (10, 2))\n\n    def test_norm(self):\n        x = KerasTensor((10, 3))\n        self.assertEqual(linalg.norm(x).shape, ())\n\n        x = KerasTensor((10, 3, 3))\n        self.assertEqual(linalg.norm(x, axis=1).shape, (10, 3))\n        self.assertEqual(\n            linalg.norm(x, axis=1, keepdims=True).shape, (10, 1, 3)\n        )\n\n    def test_qr(self):\n        x = KerasTensor((4, 3), dtype=\"float32\")\n        q, r = linalg.qr(x, mode=\"reduced\")\n        qref, rref = np.linalg.qr(np.ones((4, 3)), mode=\"reduced\")\n        self.assertEqual(q.shape, qref.shape)\n        self.assertEqual(r.shape, rref.shape)\n\n        q, r = linalg.qr(x, mode=\"complete\")\n        qref, rref = np.linalg.qr(np.ones((4, 3)), mode=\"complete\")\n        self.assertEqual(q.shape, qref.shape)\n        self.assertEqual(r.shape, rref.shape)\n\n        with self.assertRaises(ValueError):\n            linalg.qr(x, mode=\"invalid\")\n\n    def test_solve(self):\n        a = KerasTensor([4, 3, 3])\n        b = KerasTensor([4, 3, 5])\n        out = linalg.solve(a, b)\n        self.assertEqual(out.shape, (4, 3, 5))\n\n        a = KerasTensor([4, 3, 3])\n        b = KerasTensor([4, 3])\n        out = linalg.solve(a, b)\n        self.assertEqual(out.shape, (4, 3))\n\n        a = KerasTensor([10, 20, 15])\n        b = 
KerasTensor([10, 20, 5])\n        with self.assertRaises(ValueError):\n            linalg.solve(a, b)\n\n        a = KerasTensor([20, 20])\n        b = KerasTensor([])\n        with self.assertRaises(ValueError):\n            linalg.solve(a, b)\n\n    def test_solve_triangular(self):\n        a = KerasTensor([4, 3, 3])\n        b = KerasTensor([4, 3, 5])\n        out = linalg.solve_triangular(a, b)\n        self.assertEqual(out.shape, (4, 3, 5))\n\n        a = KerasTensor([4, 3, 3])\n        b = KerasTensor([4, 3])\n        out = linalg.solve_triangular(a, b)\n        self.assertEqual(out.shape, (4, 3))\n\n        a = KerasTensor([10, 20, 15])\n        b = KerasTensor([10, 20, 5])\n        with self.assertRaises(ValueError):\n            linalg.solve_triangular(a, b)\n\n    def test_svd(self):\n        x = KerasTensor((10, 3, 2))\n        u, s, v = linalg.svd(x)\n        self.assertEqual(u.shape, (10, 3, 3))\n        self.assertEqual(s.shape, (10, 2))\n        self.assertEqual(v.shape, (10, 2, 2))\n\n        u, s, v = linalg.svd(x, full_matrices=False)\n        self.assertEqual(u.shape, (10, 3, 2))\n        self.assertEqual(s.shape, (10, 2))\n        self.assertEqual(v.shape, (10, 2, 2))\n\n        s = linalg.svd(x, compute_uv=False)\n        self.assertEqual(s.shape, (10, 2))\n\n\nclass LinalgOpsCorrectnessTest(testing.TestCase):\n    def test_cholesky(self):\n        x_non_psd = np.random.rand(4, 3, 3).astype(\"float32\")\n        with self.assertRaises(ValueError):\n            linalg.cholesky(x_non_psd)\n\n        x = np.random.rand(4, 3, 3).astype(\"float32\")\n        x_psd = np.matmul(x, x.transpose((0, 2, 1))) + 1e-5 * np.eye(\n            3, dtype=\"float32\"\n        )\n\n        l_out = linalg.cholesky(x_psd, upper=False)\n        l_expected = np.linalg.cholesky(x_psd)\n        self.assertAllClose(l_out, l_expected, atol=1e-4)\n\n        u_out = linalg.cholesky(x_psd, upper=True)\n        u_expected = l_expected.transpose((0, 2, 1))\n        
self.assertAllClose(u_out, u_expected, atol=1e-4)\n\n    @parameterized.named_parameters(\n        {\"testcase_name\": \"lower\", \"upper\": False},\n        {\"testcase_name\": \"upper\", \"upper\": True},\n    )\n    def test_cholesky_inverse(self, upper):\n        A = np.array(\n            [\n                [4.0, 12.0, -16.0],\n                [12.0, 37.0, -43.0],\n                [-16.0, -43.0, 98.0],\n            ],\n            dtype=\"float32\",\n        )\n        if upper:\n            factor = np.linalg.cholesky(A, upper=True)\n        else:\n            factor = np.linalg.cholesky(A)\n\n        expected_inverse = np.array(\n            [\n                [49.36111, -13.555555, 2.111111],\n                [-13.555555, 3.777778, -0.555556],\n                [2.111111, -0.555556, 0.111111],\n            ],\n            dtype=\"float32\",\n        )\n\n        output_inverse = linalg.cholesky_inverse(factor, upper=upper)\n        self.assertAllClose(\n            output_inverse,\n            expected_inverse,\n            atol=1e-5,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n    def test_det(self):\n        x = np.random.rand(4, 3, 3)\n        out = linalg.det(x)\n        self.assertAllClose(out, np.linalg.det(x), atol=1e-5)\n\n        with self.assertRaises(ValueError):\n            x = np.random.rand(4, 3, 4)\n            linalg.det(x)\n\n    @pytest.mark.skipif(\n        testing.jax_uses_tpu(), reason=\"Unsupported on JAX with TPU\"\n    )\n    def test_eig(self):\n        x = np.random.rand(2, 3, 3)\n        x = x @ x.transpose((0, 2, 1))\n        w, v = map(ops.convert_to_numpy, linalg.eig(x))\n        x_reconstructed = (v * w[..., None, :]) @ v.transpose((0, 2, 1))\n        self.assertAllClose(x_reconstructed, x, atol=1e-4)\n\n    def test_eigh(self):\n        x = np.random.rand(2, 3, 3)\n        x = x @ x.transpose((0, 2, 1))\n        w, v = map(ops.convert_to_numpy, linalg.eigh(x))\n        x_reconstructed = (v * w[..., 
None, :]) @ v.transpose((0, 2, 1))\n        self.assertAllClose(x_reconstructed, x, atol=1e-4)\n\n    def test_inv(self):\n        x = np.random.rand(4, 3, 3)\n        x_inv = ops.convert_to_numpy(linalg.inv(x))\n        x_reconstructed = x @ x_inv\n        # high tolerance due to numerical instability\n        self.assertAllClose(\n            x_reconstructed, np.repeat(np.eye(3)[None], 4, 0), atol=1e-3\n        )\n\n    def test_lu_factor(self):\n        def _pivot_matrix(pivots, n):\n            p_matrix = np.eye(n)\n            for i, p in enumerate(pivots):\n                identity = np.eye(n, n)\n                q = identity[i, :].copy()\n                identity[i, :] = identity[p, :]\n                identity[p, :] = q\n                p_matrix = np.dot(p_matrix, identity)\n            return p_matrix\n\n        def _reconstruct(lu, pivots, m, n):\n            lower = np.tril(lu[:, : min(m, n)], -1) + np.eye(m, min(m, n))\n            upper = np.triu(lu[: min(m, n)])\n\n            # pivots are defined differently in tensorflow\n            # compared to the other backends\n            if backend.backend() == \"tensorflow\":\n                p_matrix = np.eye(m)[pivots]\n            else:\n                p_matrix = _pivot_matrix(pivots, m)\n            out = p_matrix @ lower @ upper\n            return out\n\n        m, n = 4, 4\n        x = np.random.rand(m, n)\n        lu, pivots = map(ops.convert_to_numpy, linalg.lu_factor(x))\n        x_reconstructed = _reconstruct(lu, pivots, m, n)\n        self.assertAllClose(x_reconstructed, x, atol=1e-5)\n\n        m, n = 4, 3\n        x = np.random.rand(m, n)\n        if backend.backend() == \"tensorflow\":\n            with self.assertRaises(ValueError):\n                linalg.lu_factor(x)\n        else:\n            lu, pivots = map(ops.convert_to_numpy, linalg.lu_factor(x))\n            x_reconstructed = _reconstruct(lu, pivots, m, n)\n            self.assertAllClose(x_reconstructed, x, atol=1e-5)\n\n        
# batched case\n        m, n = 3, 4\n        x = np.random.rand(2, m, n)\n        if backend.backend() == \"tensorflow\":\n            with self.assertRaises(ValueError):\n                linalg.lu_factor(x)\n        else:\n            lu, pivots = map(ops.convert_to_numpy, linalg.lu_factor(x))\n            for i in range(2):\n                self.assertAllClose(\n                    _reconstruct(lu[i], pivots[i], m, n), x[i], atol=1e-5\n                )\n\n    @parameterized.named_parameters(\n        named_product(\n            ndim=[1, 2],\n            ord=[None, \"fro\", \"nuc\", -np.inf, -2, -1, 0, 1, 2, np.inf, 3],\n            axis=[None, 1, -1, (0, 1)],\n            keepdims=[False, True],\n        )\n    )\n    def test_norm(self, ndim, ord, axis, keepdims):\n        if ndim == 1:\n            x = np.random.random((5,)).astype(\"float32\")\n        else:\n            x = np.random.random((5, 6)).astype(\"float32\")\n\n        vector_norm = (ndim == 1) or isinstance(axis, int)\n\n        axis_out_of_bounds = ndim == 1 and (\n            axis == 1 or isinstance(axis, tuple)\n        )\n        expected_error = None\n        # The cases in which an out of bounds axis triggers an IndexError\n        # on torch are complex.\n        if (\n            axis_out_of_bounds\n            and (not isinstance(axis, tuple) or ord is None)\n            and ord not in (\"fro\", \"nuc\")\n        ):\n            expected_error = IndexError\n        elif (\n            axis_out_of_bounds\n            or (vector_norm and isinstance(axis, tuple))  # inv. 
axis for vector\n            or (vector_norm and ord in (\"fro\", \"nuc\"))  # invalid ord for vector\n            or (not vector_norm and ord in (0, 3))  # invalid ord for matrix\n        ):\n            expected_error = RuntimeError\n\n        if expected_error is not None:\n            # Non-torch backends always throw a ValueError\n            expected_error = (\n                expected_error if backend.backend() == \"torch\" else ValueError\n            )\n            with self.assertRaises(expected_error):\n                linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)\n            return\n        output = linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)\n        expected_result = np.linalg.norm(\n            x, ord=ord, axis=axis, keepdims=keepdims\n        )\n        self.assertAllClose(output, expected_result, atol=1e-5)\n\n    def test_qr(self):\n        x = np.random.random((4, 5))\n        q, r = linalg.qr(x, mode=\"reduced\")\n        qref, rref = np.linalg.qr(x, mode=\"reduced\")\n        self.assertAllClose(qref, q)\n        self.assertAllClose(rref, r)\n\n        q, r = linalg.qr(x, mode=\"complete\")\n        qref, rref = np.linalg.qr(x, mode=\"complete\")\n        self.assertAllClose(qref, q)\n        self.assertAllClose(rref, r)\n\n    def test_solve(self):\n        x1 = np.array([[1, 2], [4, 5]], dtype=\"float32\")\n        x2 = np.array([[2, 4], [8, 10]], dtype=\"float32\")\n        output = linalg.solve(x1, x2)\n        expected_result = np.array([[2, 0], [0, 2]], dtype=\"float32\")\n        self.assertAllClose(output, expected_result)\n\n    def test_solve_triangular(self):\n        # 2d-case\n        x1 = np.array([[1, 2], [0, 5]], dtype=\"float32\")\n        x2 = np.array([2, 10], dtype=\"float32\")\n        output = linalg.solve_triangular(x1, x2, lower=True)\n        expected_result = np.array([2, 2], dtype=\"float32\")\n        self.assertAllClose(output, expected_result)\n\n        output = linalg.solve_triangular(x1, x2, 
lower=False)\n        expected_result = np.array([-2, 2], dtype=\"float32\")\n        self.assertAllClose(output, expected_result)\n\n        # batched case\n        x1 = np.array([[[1, 2], [0, 5]], [[1, 2], [0, 5]]], dtype=\"float32\")\n        x2 = np.array([[2, 10], [2, 10]], dtype=\"float32\")\n        output = linalg.solve_triangular(x1, x2, lower=True)\n        expected_result = np.array([[2, 2], [2, 2]], dtype=\"float32\")\n        self.assertAllClose(output, expected_result)\n\n    def test_svd(self):\n        x = np.random.rand(4, 30, 20).astype(\"float32\")\n        u, s, vh = linalg.svd(x)\n        x_reconstructed = (u[..., :, : s.shape[-1]] * s[..., None, :]) @ vh[\n            ..., : s.shape[-1], :\n        ]\n        # High tolerance due to numerical instability\n        self.assertAllClose(\n            x_reconstructed, x, atol=1e-3, tpu_atol=1e-2, tpu_rtol=1e-2\n        )\n\n        # Test `compute_uv=False`\n        s_no_uv = linalg.svd(x, compute_uv=False)\n        self.assertAllClose(\n            s_no_uv, s, atol=1e-5, rtol=1e-5, tpu_atol=1e-2, tpu_rtol=1e-2\n        )\n\n    @parameterized.named_parameters(\n        (\"b_rank_1\", 1, None),\n        (\"b_rank_2\", 2, None),\n        (\"rcond\", 1, 1e-3),\n    )\n    def test_lstsq(self, b_rank, rcond):\n        a = np.random.random((5, 7)).astype(\"float32\")\n        a_symb = backend.KerasTensor((5, 7))\n        if b_rank == 1:\n            b = np.random.random((5,)).astype(\"float32\")\n            b_symb = backend.KerasTensor((5,))\n        else:\n            b = np.random.random((5, 4)).astype(\"float32\")\n            b_symb = backend.KerasTensor((5, 4))\n        out = linalg.lstsq(a, b, rcond=rcond)\n        ref_out = np.linalg.lstsq(a, b, rcond=rcond)[0]\n        self.assertAllClose(\n            out, ref_out, atol=1e-5, tpu_atol=1e-4, tpu_rtol=1e-4\n        )\n\n        out_symb = linalg.lstsq(a_symb, b_symb)\n        self.assertEqual(out_symb.shape, out.shape)\n\n\nclass 
QrOpTest(testing.TestCase):\n    def test_qr_init_mode_reduced(self):\n        qr_op = linalg.Qr(mode=\"reduced\")\n        self.assertIsNotNone(qr_op)\n\n    def test_qr_init_mode_complete(self):\n        qr_op = linalg.Qr(mode=\"complete\")\n        self.assertIsNotNone(qr_op)\n\n    def test_qr_init_invalid_mode(self):\n        invalid_mode = \"invalid_mode\"\n        expected_error = (\n            r\"`mode` argument value not supported. \"\n            r\"Expected one of \\{'reduced', 'complete'\\}. \"\n            f\"Received: mode={invalid_mode}\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_error):\n            linalg.Qr(mode=invalid_mode)\n\n    def test_compute_output_spec_low_rank(self):\n        qr_op = linalg.Qr(mode=\"reduced\")\n        low_rank_input = np.random.rand(3)\n        with self.assertRaisesRegex(\n            ValueError, r\"Input should have rank >= 2. Received: .*\"\n        ):\n            qr_op.compute_output_spec(low_rank_input)\n\n    def test_compute_output_spec_undefined_dimensions(self):\n        qr_op = linalg.Qr(mode=\"reduced\")\n        undefined_dim_input = KerasTensor(shape=(None, 4), dtype=\"float32\")\n        with self.assertRaisesRegex(\n            ValueError,\n            r\"Input should have its last 2 dimensions \"\n            r\"fully-defined. 
Received: .*\",\n        ):\n            qr_op.compute_output_spec(undefined_dim_input)\n\n    def test_qr_call_mode_reduced(self):\n        qr_op = linalg.Qr(mode=\"reduced\")\n        test_input = np.random.rand(10, 10)\n        q, r = qr_op.call(test_input)\n        self.assertEqual(q.shape, (10, 10))\n        self.assertEqual(r.shape, (10, 10))\n\n    def test_qr_call_mode_complete(self):\n        qr_op = linalg.Qr(mode=\"complete\")\n        test_input = np.random.rand(10, 10)\n        q, r = qr_op.call(test_input)\n        self.assertEqual(q.shape, (10, 10))\n        self.assertEqual(r.shape, (10, 10))\n\n    def test_jvp(self):\n        if backend.backend() in [\"openvino\", \"numpy\"]:\n            pytest.skip(\"Backend does not support jvp operation\")\n        a1, a2 = ops.convert_to_tensor(0.1), ops.convert_to_tensor(0.2)\n        primals, tangents = linalg.jvp(backend.numpy.sin, (a1,), (a2,))\n        self.assertAllClose(primals, 0.0998, atol=1e-4)\n        self.assertAllClose(tangents, 0.1990, atol=1e-4)\n\n        def f(x):\n            return backend.numpy.sin(x), x**2\n\n        primals_out, tangents_out, aux = linalg.jvp(\n            f, (a1,), (a2,), has_aux=True\n        )\n        self.assertAllClose(primals_out, 0.0998, atol=1e-4)\n        self.assertAllClose(tangents_out, 0.1990, atol=1e-4)\n        self.assertAllClose(aux, 0.01, atol=1e-4)\n\n    def test_jvp_symbolic_has_aux_false(self):\n        primals = KerasTensor((None, 7))\n        tangents = KerasTensor((None, 7))\n\n        def fun(x):\n            # simple non-linear transformation\n            return ops.sin(x) + ops.cos(x)\n\n        primals_out, tangents_out = linalg.jvp(fun, (primals,), (tangents,))\n        # output shapes must match input shapes\n        self.assertEqual(primals_out.shape, primals.shape)\n        self.assertEqual(tangents_out.shape, tangents.shape)\n\n        \"\"\"Symbolic JVP test – has_aux=True.\"\"\"\n\n        def fun(x):\n            y = ops.exp(x)\n     
       aux = ops.mean(y, axis=-1, keepdims=True)  # auxiliary output\n            return y, aux\n\n        primals_out, tangents_out, aux = linalg.jvp(\n            fun, (primals,), (tangents,), has_aux=True\n        )\n        # main output shapes\n        self.assertEqual(primals_out.shape, primals.shape)\n        self.assertEqual(tangents_out.shape, tangents.shape)\n        # auxiliary shape: (batch, 1)\n        self.assertEqual(aux.shape, (None, 1))\n"
  },
  {
    "path": "keras/src/ops/math.py",
    "content": "\"\"\"Commonly used math operations not included in NumPy.\"\"\"\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend import any_symbolic_tensors\nfrom keras.src.ops.operation import Operation\nfrom keras.src.ops.operation_utils import reduce_shape\n\n\ndef _segment_reduce_validation(data, segment_ids):\n    data_shape = data.shape\n    segment_ids_shape = segment_ids.shape\n    if len(segment_ids_shape) > 1:\n        raise ValueError(\n            \"Argument `segment_ids` should be an 1-D vector, got shape: \"\n            f\"{len(segment_ids_shape)}. Consider either flatten input with \"\n            \"segment_ids.reshape((-1)) and \"\n            \"data.reshape((-1, ) + data.shape[len(segment_ids.shape):]) or \"\n            \"vectorize with vmap.\"\n        )\n    if (\n        segment_ids_shape[0] is not None\n        and data_shape[0] is not None\n        and segment_ids_shape[0] != data_shape[0]\n    ):\n        raise ValueError(\n            \"Argument `segment_ids` and `data` should have same leading \"\n            f\"dimension. Got {segment_ids_shape} v.s. 
\"\n            f\"{data_shape}.\"\n        )\n\n\nclass SegmentReduction(Operation):\n    def __init__(self, num_segments=None, sorted=False, *, name=None):\n        super().__init__(name=name)\n        self.num_segments = num_segments\n        self.sorted = sorted\n\n    def compute_output_spec(self, data, _):\n        output_shape = (self.num_segments,) + tuple(data.shape[1:])\n        return KerasTensor(shape=output_shape, dtype=data.dtype)\n\n\nclass SegmentSum(SegmentReduction):\n    def call(self, data, segment_ids):\n        _segment_reduce_validation(data, segment_ids)\n        return backend.math.segment_sum(\n            data,\n            segment_ids,\n            num_segments=self.num_segments,\n            sorted=self.sorted,\n        )\n\n\n@keras_export(\"keras.ops.segment_sum\")\ndef segment_sum(data, segment_ids, num_segments=None, sorted=False):\n    \"\"\"Computes the sum of segments in a tensor.\n\n    Args:\n        data: Input tensor.\n        segment_ids: An N-D tensor containing segment indices for each\n            element in `data`. The number of dimensions of `segment_ids`\n            should be smaller than or equal to the number of dimensions\n            of `data`.\n        num_segments: An integer representing the total number of\n            segments. 
If not specified, it is inferred from the maximum\n            value in `segment_ids`.\n        sorted: A boolean indicating whether `segment_ids` is sorted.\n            Defaults to `False`.\n\n    Returns:\n        A tensor containing the sum of segments, where each element\n        represents the sum of the corresponding segment in `data`.\n\n    Example:\n\n    >>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])\n    >>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])\n    >>> num_segments = 3\n    >>> keras.ops.segment_sum(data, segment_ids, num_segments)\n    array([3, 30, 300], dtype=int32)\n    \"\"\"\n    _segment_reduce_validation(data, segment_ids)\n    if any_symbolic_tensors((data,)):\n        return SegmentSum(num_segments, sorted).symbolic_call(data, segment_ids)\n    return backend.math.segment_sum(\n        data, segment_ids, num_segments=num_segments, sorted=sorted\n    )\n\n\nclass SegmentMax(SegmentReduction):\n    def call(self, data, segment_ids):\n        _segment_reduce_validation(data, segment_ids)\n        return backend.math.segment_max(\n            data,\n            segment_ids,\n            num_segments=self.num_segments,\n            sorted=self.sorted,\n        )\n\n\n@keras_export(\"keras.ops.segment_max\")\ndef segment_max(data, segment_ids, num_segments=None, sorted=False):\n    \"\"\"Computes the max of segments in a tensor.\n\n    Args:\n        data: Input tensor.\n        segment_ids: An N-D tensor containing segment indices for each\n            element in `data`. `data.shape[:len(segment_ids.shape)]` should\n            match `segment_ids.shape`.\n        num_segments: An integer representing the total number of\n            segments. 
If not specified, it is inferred from the maximum\n            value in `segment_ids`.\n        sorted: A boolean indicating whether `segment_ids` is sorted.\n            Defaults to `False`.\n\n    Returns:\n        A tensor containing the max of segments, where each element\n        represents the max of the corresponding segment in `data`.\n\n    Example:\n\n    >>> data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])\n    >>> segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])\n    >>> num_segments = 3\n    >>> keras.ops.segment_max(data, segment_ids, num_segments)\n    array([2, 20, 200], dtype=int32)\n    \"\"\"\n    _segment_reduce_validation(data, segment_ids)\n    if any_symbolic_tensors((data,)):\n        return SegmentMax(num_segments, sorted).symbolic_call(data, segment_ids)\n    return backend.math.segment_max(\n        data, segment_ids, num_segments=num_segments, sorted=sorted\n    )\n\n\nclass TopK(Operation):\n    def __init__(self, k, sorted=True, *, name=None):\n        super().__init__(name=name)\n        self.k = k\n        self.sorted = sorted\n\n    def compute_output_spec(self, x):\n        output_shape = list(x.shape)\n        output_shape[-1] = self.k\n        # Return a tuple (values, indices).\n        return (\n            KerasTensor(shape=output_shape, dtype=x.dtype),\n            KerasTensor(shape=output_shape, dtype=\"int32\"),\n        )\n\n    def call(self, x):\n        return backend.math.top_k(x, self.k, self.sorted)\n\n\n@keras_export(\"keras.ops.top_k\")\ndef top_k(x, k, sorted=True):\n    \"\"\"Finds the top-k values and their indices in a tensor.\n\n    Args:\n        x: Input tensor.\n        k: An integer representing the number of top elements to retrieve.\n        sorted: A boolean indicating whether to sort the output in\n        descending order. Defaults to `True`.\n\n    Returns:\n        A tuple containing two tensors. 
The first tensor contains the\n        top-k values, and the second tensor contains the indices of the\n        top-k values in the input tensor.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([5, 2, 7, 1, 9, 3])\n    >>> values, indices = top_k(x, k=3)\n    >>> print(values)\n    array([9 7 5], shape=(3,), dtype=int32)\n    >>> print(indices)\n    array([4 2 0], shape=(3,), dtype=int32)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return TopK(k, sorted).symbolic_call(x)\n    return backend.math.top_k(x, k, sorted)\n\n\nclass InTopK(Operation):\n    def __init__(self, k, *, name=None):\n        super().__init__(name=name)\n        self.k = k\n\n    def compute_output_spec(self, targets, predictions):\n        return KerasTensor(shape=targets.shape, dtype=\"bool\")\n\n    def call(self, targets, predictions):\n        return backend.math.in_top_k(targets, predictions, self.k)\n\n\n@keras_export(\"keras.ops.in_top_k\")\ndef in_top_k(targets, predictions, k):\n    \"\"\"Checks if the targets are in the top-k predictions.\n\n    Args:\n        targets: A tensor of true labels.\n        predictions: A tensor of predicted labels.\n        k: An integer representing the number of predictions to consider.\n\n    Returns:\n        A boolean tensor of the same shape as `targets`, where each element\n        indicates whether the corresponding target is in the top-k predictions.\n\n    Example:\n\n    >>> targets = keras.ops.convert_to_tensor([2, 5, 3])\n    >>> predictions = keras.ops.convert_to_tensor(\n    ... [[0.1, 0.4, 0.6, 0.9, 0.5],\n    ...  [0.1, 0.7, 0.9, 0.8, 0.3],\n    ...  
[0.1, 0.6, 0.9, 0.9, 0.5]])\n    >>> in_top_k(targets, predictions, k=3)\n    array([ True False  True], shape=(3,), dtype=bool)\n    \"\"\"\n    if any_symbolic_tensors((targets, predictions)):\n        return InTopK(k).symbolic_call(targets, predictions)\n    return backend.math.in_top_k(targets, predictions, k)\n\n\nclass Logsumexp(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def compute_output_spec(self, x):\n        output_shape = reduce_shape(x.shape, self.axis, self.keepdims)\n        return KerasTensor(shape=output_shape)\n\n    def call(self, x):\n        return backend.math.logsumexp(x, axis=self.axis, keepdims=self.keepdims)\n\n\n@keras_export(\"keras.ops.logsumexp\")\ndef logsumexp(x, axis=None, keepdims=False):\n    \"\"\"Computes the logarithm of sum of exponentials of elements in a tensor.\n\n    Args:\n        x: Input tensor.\n        axis: An integer or a tuple of integers specifying the axis/axes\n            along which to compute the sum. If `None`, the sum is computed\n            over all elements. Defaults to `None`.\n        keepdims: A boolean indicating whether to keep the dimensions of\n            the input tensor when computing the sum. 
Defaults to `False`.\n\n    Returns:\n        A tensor containing the logarithm of the sum of exponentials of\n        elements in `x`.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([1., 2., 3.])\n    >>> logsumexp(x)\n    3.407606\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Logsumexp(axis, keepdims).symbolic_call(x)\n    return backend.math.logsumexp(x, axis=axis, keepdims=keepdims)\n\n\nclass ExtractSequences(Operation):\n    def __init__(self, sequence_length, sequence_stride, *, name=None):\n        super().__init__(name=name)\n        self.sequence_length = sequence_length\n        self.sequence_stride = sequence_stride\n\n    def compute_output_spec(self, x):\n        if len(x.shape) < 1:\n            raise ValueError(\n                f\"Input should have rank >= 1. \"\n                f\"Received: input.shape = {x.shape}\"\n            )\n        if x.shape[-1] is not None:\n            num_sequences = (\n                1 + (x.shape[-1] - self.sequence_length) // self.sequence_stride\n            )\n        else:\n            num_sequences = None\n        new_shape = x.shape[:-1] + (num_sequences, self.sequence_length)\n        return KerasTensor(shape=new_shape, dtype=x.dtype)\n\n    def call(self, x):\n        return backend.math.extract_sequences(\n            x,\n            sequence_length=self.sequence_length,\n            sequence_stride=self.sequence_stride,\n        )\n\n\n@keras_export(\"keras.ops.extract_sequences\")\ndef extract_sequences(x, sequence_length, sequence_stride):\n    \"\"\"Expands the dimension of last axis into sequences of `sequence_length`.\n\n    Slides a window of size `sequence_length` over the last axis of the input\n    with a stride of `sequence_stride`, replacing the last axis with\n    `[num_sequences, sequence_length]` sequences.\n\n    If the dimension along the last axis is N, the number of sequences can be\n    computed by:\n\n    `num_sequences = 1 + (N - sequence_length) // 
sequence_stride`\n\n    Args:\n        x: Input tensor.\n        sequence_length: An integer representing the sequence length.\n        sequence_stride: An integer representing the sequence hop size.\n\n    Returns:\n        A tensor of sequences with shape [..., num_sequences, sequence_length].\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([1, 2, 3, 4, 5, 6])\n    >>> extract_sequences(x, 3, 2)\n    array([[1, 2, 3],\n       [3, 4, 5]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return ExtractSequences(sequence_length, sequence_stride).symbolic_call(\n            x\n        )\n    return backend.math.extract_sequences(x, sequence_length, sequence_stride)\n\n\nclass FFT(Operation):\n    def compute_output_spec(self, x):\n        if not isinstance(x, (tuple, list)) or len(x) != 2:\n            raise ValueError(\n                \"Input `x` should be a tuple of two tensors - real and \"\n                f\"imaginary. Received: x={x}\"\n            )\n\n        real, imag = x\n        # Both real and imaginary parts should have the same shape.\n        if real.shape != imag.shape:\n            raise ValueError(\n                \"Input `x` should be a tuple of two tensors - real and \"\n                \"imaginary. Both the real and imaginary parts should have the \"\n                f\"same shape. Received: x[0].shape = {real.shape}, \"\n                f\"x[1].shape = {imag.shape}\"\n            )\n\n        # We are calculating 1D FFT. Hence, rank >= 1.\n        if len(real.shape) < 1:\n            raise ValueError(\n                f\"Input should have rank >= 1. \"\n                f\"Received: input.shape = {real.shape}\"\n            )\n\n        # The axis along which we are calculating FFT should be fully-defined.\n        m = real.shape[-1]\n        if m is None:\n            raise ValueError(\n                f\"Input should have its last dimension fully-defined. 
\"\n                f\"Received: input.shape = {real.shape}\"\n            )\n\n        return (\n            KerasTensor(shape=real.shape, dtype=real.dtype),\n            KerasTensor(shape=imag.shape, dtype=imag.dtype),\n        )\n\n    def call(self, x):\n        return backend.math.fft(x)\n\n\n@keras_export(\"keras.ops.fft\")\ndef fft(x):\n    \"\"\"Computes the Fast Fourier Transform along last axis of input.\n\n    Args:\n        x: Tuple of the real and imaginary parts of the input tensor. Both\n            tensors in the tuple should be of floating type.\n\n    Returns:\n        A tuple containing two tensors - the real and imaginary parts of the\n        output tensor.\n\n    Example:\n\n    >>> x = (\n    ...     keras.ops.convert_to_tensor([1., 2.]),\n    ...     keras.ops.convert_to_tensor([0., 1.]),\n    ... )\n    >>> fft(x)\n    (array([ 3., -1.], dtype=float32), array([ 1., -1.], dtype=float32))\n    \"\"\"\n    if any_symbolic_tensors(x):\n        return FFT().symbolic_call(x)\n    return backend.math.fft(x)\n\n\nclass FFT2(Operation):\n    def compute_output_spec(self, x):\n        axes = (-2, -1)\n        if not isinstance(x, (tuple, list)) or len(x) != 2:\n            raise ValueError(\n                \"Input `x` should be a tuple of two tensors - real and \"\n                f\"imaginary. Received: x={x}\"\n            )\n\n        real, imag = x\n        # Both real and imaginary parts should have the same shape.\n        if real.shape != imag.shape:\n            raise ValueError(\n                \"Input `x` should be a tuple of two tensors - real and \"\n                \"imaginary. Both the real and imaginary parts should have the \"\n                f\"same shape. Received: x[0].shape = {real.shape}, \"\n                f\"x[1].shape = {imag.shape}\"\n            )\n        # We are calculating 2D FFT. Hence, rank >= 2.\n        if len(real.shape) < 2:\n            raise ValueError(\n                f\"Input should have rank >= 2. 
\"\n                f\"Received: input.shape = {real.shape}\"\n            )\n\n        # The axes along which we are calculating FFT should be fully-defined.\n        m = real.shape[axes[0]]\n        n = real.shape[axes[1]]\n        if m is None or n is None:\n            raise ValueError(\n                f\"Input should have its {axes} axes fully-defined. \"\n                f\"Received: input.shape = {real.shape}\"\n            )\n\n        return (\n            KerasTensor(shape=real.shape, dtype=real.dtype),\n            KerasTensor(shape=imag.shape, dtype=imag.dtype),\n        )\n\n    def call(self, x):\n        return backend.math.fft2(x)\n\n\n@keras_export(\"keras.ops.fft2\")\ndef fft2(x):\n    \"\"\"Computes the 2D Fast Fourier Transform along the last two axes of input.\n\n    Args:\n        x: Tuple of the real and imaginary parts of the input tensor. Both\n            tensors in the tuple should be of floating type.\n\n    Returns:\n        A tuple containing two tensors - the real and imaginary parts of the\n        output.\n\n    Example:\n\n    >>> x = (\n    ...     keras.ops.convert_to_tensor([[1., 2.], [2., 1.]]),\n    ...     keras.ops.convert_to_tensor([[0., 1.], [1., 0.]]),\n    ... )\n    >>> fft2(x)\n    (array([[ 6.,  0.],\n        [ 0., -2.]], dtype=float32), array([[ 2.,  0.],\n        [ 0., -2.]], dtype=float32))\n    \"\"\"\n    if any_symbolic_tensors(x):\n        return FFT2().symbolic_call(x)\n    return backend.math.fft2(x)\n\n\nclass IFFT2(Operation):\n    def compute_output_spec(self, x):\n        axes = (-2, -1)\n        if not isinstance(x, (tuple, list)) or len(x) != 2:\n            raise ValueError(\n                \"Input `x` should be a tuple of two tensors - real and \"\n                f\"imaginary. 
Received: x={x}\"\n            )\n\n        real, imag = x\n        # Both real and imaginary parts should have the same shape.\n        if real.shape != imag.shape:\n            raise ValueError(\n                \"Input `x` should be a tuple of two tensors - real and \"\n                \"imaginary. Both the real and imaginary parts should have the \"\n                f\"same shape. Received: x[0].shape = {real.shape}, \"\n                f\"x[1].shape = {imag.shape}\"\n            )\n        # We are calculating 2D IFFT. Hence, rank >= 2.\n        if len(real.shape) < 2:\n            raise ValueError(\n                f\"Input should have rank >= 2. \"\n                f\"Received: input.shape = {real.shape}\"\n            )\n\n        # The axes along which we are calculating IFFT should be fully-defined.\n        m = real.shape[axes[0]]\n        n = real.shape[axes[1]]\n        if m is None or n is None:\n            raise ValueError(\n                f\"Input should have its {axes} axes fully-defined. \"\n                f\"Received: input.shape = {real.shape}\"\n            )\n\n        return (\n            KerasTensor(shape=real.shape, dtype=real.dtype),\n            KerasTensor(shape=imag.shape, dtype=imag.dtype),\n        )\n\n    def call(self, x):\n        return backend.math.ifft2(x)\n\n\n@keras_export(\"keras.ops.ifft2\")\ndef ifft2(x):\n    \"\"\"Computes the 2D Inverse Fast Fourier Transform along the last two axes of\n        input.\n\n    Args:\n        x: Tuple of the real and imaginary parts of the input tensor. Both\n            tensors in the tuple should be of floating type.\n\n    Returns:\n        A tuple containing two tensors - the real and imaginary parts of the\n        output.\n\n    Example:\n\n    >>> x = (\n    ...     keras.ops.convert_to_tensor([[1., 2.], [2., 1.]]),\n    ...     keras.ops.convert_to_tensor([[0., 1.], [1., 0.]]),\n    ... 
)\n    >>> ifft2(x)\n    (array([[ 1.5,  0. ],\n        [ 0. , -0.5]], dtype=float32), array([[ 0.5,  0. ],\n        [ 0. , -0.5]], dtype=float32))\n    \"\"\"\n    if any_symbolic_tensors(x):\n        return IFFT2().symbolic_call(x)\n    return backend.math.ifft2(x)\n\n\nclass RFFT(Operation):\n    def __init__(self, fft_length=None, *, name=None):\n        super().__init__(name=name)\n        self.fft_length = fft_length\n\n    def compute_output_spec(self, x):\n        # We are calculating 1D RFFT. Hence, rank >= 1.\n        if len(x.shape) < 1:\n            raise ValueError(\n                f\"Input should have rank >= 1. \"\n                f\"Received: input.shape = {x.shape}\"\n            )\n\n        if self.fft_length is not None:\n            new_last_dimension = self.fft_length // 2 + 1\n        else:\n            if x.shape[-1] is not None:\n                new_last_dimension = x.shape[-1] // 2 + 1\n            else:\n                new_last_dimension = None\n        new_shape = x.shape[:-1] + (new_last_dimension,)\n\n        return (\n            KerasTensor(shape=new_shape, dtype=x.dtype),\n            KerasTensor(shape=new_shape, dtype=x.dtype),\n        )\n\n    def call(self, x):\n        return backend.math.rfft(x, fft_length=self.fft_length)\n\n\n@keras_export(\"keras.ops.rfft\")\ndef rfft(x, fft_length=None):\n    \"\"\"Real-valued Fast Fourier Transform along the last axis of the input.\n\n    Computes the 1D Discrete Fourier Transform of a real-valued signal over the\n    inner-most dimension of input.\n\n    Since the Discrete Fourier Transform of a real-valued signal is\n    Hermitian-symmetric, RFFT only returns the `fft_length / 2 + 1` unique\n    components of the FFT: the zero-frequency term, followed by the\n    `fft_length / 2` positive-frequency terms.\n\n    Along the axis RFFT is computed on, if `fft_length` is smaller than the\n    corresponding dimension of the input, the dimension is cropped. 
If it is\n    larger, the dimension is padded with zeros.\n\n    Args:\n        x: Input tensor.\n        fft_length: An integer representing the FFT length. If not\n            specified, it is inferred from the length of the last axis of `x`.\n            Defaults to `None`.\n\n    Returns:\n        A tuple containing two tensors - the real and imaginary parts of the\n        output.\n\n    Examples:\n\n    >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])\n    >>> rfft(x)\n    (array([10.0, -2.5, -2.5]), array([0.0, 3.4409548, 0.81229924]))\n\n    >>> rfft(x, 3)\n    (array([3.0, -1.5]), array([0.0, 0.8660254]))\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return RFFT(fft_length).symbolic_call(x)\n    return backend.math.rfft(x, fft_length)\n\n\nclass IRFFT(Operation):\n    def __init__(self, fft_length=None, *, name=None):\n        super().__init__(name=name)\n        self.fft_length = fft_length\n\n    def compute_output_spec(self, x):\n        if not isinstance(x, (tuple, list)) or len(x) != 2:\n            raise ValueError(\n                \"Input `x` should be a tuple of two tensors - real and \"\n                f\"imaginary. Received: x={x}\"\n            )\n        real, imag = x\n        # Both real and imaginary parts should have the same shape.\n        if real.shape != imag.shape:\n            raise ValueError(\n                \"Input `x` should be a tuple of two tensors - real and \"\n                \"imaginary. Both the real and imaginary parts should have the \"\n                f\"same shape. Received: x[0].shape = {real.shape}, \"\n                f\"x[1].shape = {imag.shape}\"\n            )\n        # We are calculating 1D IRFFT. Hence, rank >= 1.\n        if len(real.shape) < 1:\n            raise ValueError(\n                f\"Input should have rank >= 1. 
\"\n                f\"Received: input.shape = {real.shape}\"\n            )\n\n        if self.fft_length is not None:\n            new_last_dimension = self.fft_length\n        else:\n            if real.shape[-1] is not None:\n                new_last_dimension = 2 * (real.shape[-1] - 1)\n            else:\n                new_last_dimension = None\n        new_shape = real.shape[:-1] + (new_last_dimension,)\n        return KerasTensor(shape=new_shape, dtype=real.dtype)\n\n    def call(self, x):\n        return backend.math.irfft(x, fft_length=self.fft_length)\n\n\n@keras_export(\"keras.ops.irfft\")\ndef irfft(x, fft_length=None):\n    \"\"\"Inverse real-valued Fast Fourier transform along the last axis.\n\n    Computes the inverse 1D Discrete Fourier Transform of a real-valued signal\n    over the inner-most dimension of input.\n\n    The inner-most dimension of the input is assumed to be the result of RFFT:\n    the `fft_length / 2 + 1` unique components of the DFT of a real-valued\n    signal. If `fft_length` is not provided, it is computed from the size of the\n    inner-most dimension of the input `(fft_length = 2 * (inner - 1))`. If the\n    FFT length used to compute the input is odd, it must be provided, since\n    it cannot be inferred properly.\n\n    Along the axis IRFFT is computed on, if `fft_length / 2 + 1` is smaller than\n    the corresponding dimension of the input, the dimension is cropped. If it is\n    larger, the dimension is padded with zeros.\n\n    Args:\n        x: Tuple of the real and imaginary parts of the input tensor. Both\n            tensors in the tuple should be of floating type.\n        fft_length: An integer representing the FFT length. 
If not\n            specified, it is inferred from the length of the last axis of `x`.\n            Defaults to `None`.\n\n    Returns:\n        A tensor containing the inverse real-valued Fast Fourier Transform\n        along the last axis of `x`.\n\n    Examples:\n\n    >>> real = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])\n    >>> imag = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])\n    >>> irfft((real, imag))\n    array([0.66666667, -0.9106836, 0.24401694])\n\n    >>> irfft(rfft(real, 5), 5)\n    array([0.0, 1.0, 2.0, 3.0, 4.0])\n    \"\"\"\n    if any_symbolic_tensors(x):\n        return IRFFT(fft_length).symbolic_call(x)\n    return backend.math.irfft(x, fft_length)\n\n\nclass STFT(Operation):\n    def __init__(\n        self,\n        sequence_length,\n        sequence_stride,\n        fft_length,\n        window=\"hann\",\n        center=True,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.sequence_length = sequence_length\n        self.sequence_stride = sequence_stride\n        self.fft_length = fft_length\n        self.window = window\n        self.center = center\n\n    def compute_output_spec(self, x):\n        if x.shape[-1] is not None:\n            padded = 0 if self.center is False else (self.fft_length // 2) * 2\n            num_sequences = (\n                1\n                + (x.shape[-1] + padded - self.fft_length)\n                // self.sequence_stride\n            )\n        else:\n            num_sequences = None\n        new_shape = x.shape[:-1] + (num_sequences, self.fft_length // 2 + 1)\n        return (\n            KerasTensor(shape=new_shape, dtype=x.dtype),\n            KerasTensor(shape=new_shape, dtype=x.dtype),\n        )\n\n    def call(self, x):\n        return backend.math.stft(\n            x,\n            sequence_length=self.sequence_length,\n            sequence_stride=self.sequence_stride,\n            fft_length=self.fft_length,\n            
window=self.window,\n            center=self.center,\n        )\n\n\n@keras_export(\"keras.ops.stft\")\ndef stft(\n    x, sequence_length, sequence_stride, fft_length, window=\"hann\", center=True\n):\n    \"\"\"Short-Time Fourier Transform along the last axis of the input.\n\n    The STFT computes the Fourier transform of short overlapping windows of the\n    input. This gives the frequency components of the signal as they change\n    over time.\n\n    Args:\n        x: Input tensor.\n        sequence_length: An integer representing the sequence length.\n        sequence_stride: An integer representing the sequence hop size.\n        fft_length: An integer representing the size of the FFT to apply. If not\n            specified, uses the smallest power of 2 enclosing `sequence_length`.\n        window: A string, a tensor of the window or `None`. If `window` is a\n            string, available values are `\"hann\"` and `\"hamming\"`. If `window`\n            is a tensor, it will be used directly as the window and its length\n            must be `sequence_length`. If `window` is `None`, no windowing is\n            used. Defaults to `\"hann\"`.\n        center: Whether to pad `x` on both sides so that the t-th sequence is\n            centered at time `t * sequence_stride`. Otherwise, the t-th sequence\n            begins at time `t * sequence_stride`. Defaults to `True`.\n\n    Returns:\n        A tuple containing two tensors - the real and imaginary parts of the\n        STFT output.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])\n    >>> stft(x, 3, 2, 3)\n    (array([[0.75, -0.375],\n       [3.75, -1.875],\n       [5.25, -2.625]]), array([[0.0, 0.64951905],\n       [0.0, 0.64951905],\n       [0.0, -0.64951905]]))\n    \"\"\"\n    if not isinstance(sequence_stride, int) or sequence_stride <= 0:\n        raise ValueError(\n            \"`sequence_stride` must be a positive integer. 
\"\n            f\"Received: sequence_stride={sequence_stride}\"\n        )\n    if any_symbolic_tensors((x,)):\n        return STFT(\n            sequence_length=sequence_length,\n            sequence_stride=sequence_stride,\n            fft_length=fft_length,\n            window=window,\n            center=center,\n        ).symbolic_call(x)\n    return backend.math.stft(\n        x,\n        sequence_length=sequence_length,\n        sequence_stride=sequence_stride,\n        fft_length=fft_length,\n        window=window,\n        center=center,\n    )\n\n\nclass ISTFT(Operation):\n    def __init__(\n        self,\n        sequence_length,\n        sequence_stride,\n        fft_length,\n        length=None,\n        window=\"hann\",\n        center=True,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.sequence_length = sequence_length\n        self.sequence_stride = sequence_stride\n        self.fft_length = fft_length\n        self.length = length\n        self.window = window\n        self.center = center\n\n    def compute_output_spec(self, x):\n        if not isinstance(x, (tuple, list)) or len(x) != 2:\n            raise ValueError(\n                \"Input `x` should be a tuple of two tensors - real and \"\n                f\"imaginary. Received: x={x}\"\n            )\n        real, imag = x\n        # Both real and imaginary parts should have the same shape.\n        if real.shape != imag.shape:\n            raise ValueError(\n                \"Input `x` should be a tuple of two tensors - real and \"\n                \"imaginary. Both the real and imaginary parts should have the \"\n                f\"same shape. Received: x[0].shape = {real.shape}, \"\n                f\"x[1].shape = {imag.shape}\"\n            )\n        if len(real.shape) < 2:\n            raise ValueError(\n                f\"Input should have rank >= 2. 
\"\n                f\"Received: input.shape = {real.shape}\"\n            )\n        if real.shape[-2] is not None:\n            output_size = (\n                real.shape[-2] - 1\n            ) * self.sequence_stride + self.fft_length\n            if self.length is not None:\n                output_size = self.length\n            elif self.center:\n                output_size = output_size - (self.fft_length // 2) * 2\n        else:\n            output_size = None\n        new_shape = real.shape[:-2] + (output_size,)\n        return KerasTensor(shape=new_shape, dtype=real.dtype)\n\n    def call(self, x):\n        return backend.math.istft(\n            x,\n            sequence_length=self.sequence_length,\n            sequence_stride=self.sequence_stride,\n            fft_length=self.fft_length,\n            length=self.length,\n            window=self.window,\n            center=self.center,\n        )\n\n\n@keras_export(\"keras.ops.istft\")\ndef istft(\n    x,\n    sequence_length,\n    sequence_stride,\n    fft_length,\n    length=None,\n    window=\"hann\",\n    center=True,\n):\n    \"\"\"Inverse Short-Time Fourier Transform along the last axis of the input.\n\n    To reconstruct an original waveform, the parameters should be the same as\n    those used in `stft`.\n\n    Args:\n        x: Tuple of the real and imaginary parts of the input tensor. Both\n            tensors in the tuple should be of floating type.\n        sequence_length: An integer representing the sequence length.\n        sequence_stride: An integer representing the sequence hop size.\n        fft_length: An integer representing the size of the FFT that produced\n            `stft`. Should be of type `int32`.\n        length: An integer. If specified, the output is clipped to exactly\n            `length`. If not specified, no padding or clipping takes place.\n            Defaults to `None`.\n        window: A string, a tensor of the window or `None`. 
If `window` is a\n            string, available values are `\"hann\"` and `\"hamming\"`. If `window`\n            is a tensor, it will be used directly as the window and its length\n            must be `sequence_length`. If `window` is `None`, no windowing is\n            used. Defaults to `\"hann\"`.\n        center: Whether `x` was padded on both sides so that the t-th sequence\n            is centered at time `t * sequence_stride`. Defaults to `True`.\n\n    Returns:\n        A tensor containing the inverse Short-Time Fourier Transform along the\n        last axis of `x`.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0, 3.0, 4.0])\n    >>> istft(stft(x, 1, 1, 1), 1, 1, 1)\n    array([0.0, 1.0, 2.0, 3.0, 4.0])\n    \"\"\"\n    if not isinstance(sequence_stride, int) or sequence_stride <= 0:\n        raise ValueError(\n            \"`sequence_stride` must be a positive integer. \"\n            f\"Received: sequence_stride={sequence_stride}\"\n        )\n    if any_symbolic_tensors(x):\n        return ISTFT(\n            sequence_length=sequence_length,\n            sequence_stride=sequence_stride,\n            fft_length=fft_length,\n            length=length,\n            window=window,\n            center=center,\n        ).symbolic_call(x)\n    return backend.math.istft(\n        x,\n        sequence_length=sequence_length,\n        sequence_stride=sequence_stride,\n        fft_length=fft_length,\n        length=length,\n        window=window,\n        center=center,\n    )\n\n\nclass Rsqrt(Operation):\n    def call(self, x):\n        x = backend.convert_to_tensor(x)\n        return backend.math.rsqrt(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export(\"keras.ops.rsqrt\")\ndef rsqrt(x):\n    \"\"\"Computes the reciprocal of the square root of `x`, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same dtype as `x`.\n\n    Example:\n\n    >>> x = 
keras.ops.convert_to_tensor([1.0, 10.0, 100.0])\n    >>> keras.ops.rsqrt(x)\n    array([1.0, 0.31622776, 0.1], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Rsqrt().symbolic_call(x)\n    x = backend.convert_to_tensor(x)\n    return backend.math.rsqrt(x)\n\n\nclass Erf(Operation):\n    def compute_output_spec(self, x):\n        return KerasTensor(shape=x.shape, dtype=x.dtype)\n\n    def call(self, x):\n        return backend.math.erf(x)\n\n\n@keras_export(\"keras.ops.erf\")\ndef erf(x):\n    \"\"\"Computes the error function of `x`, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same dtype as `x`.\n\n    Example:\n\n    >>> x = np.array([-3.0, -2.0, -1.0, 0.0, 1.0])\n    >>> keras.ops.erf(x)\n    array([-0.99998 , -0.99532, -0.842701,  0.,  0.842701], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Erf().symbolic_call(x)\n    x = backend.convert_to_tensor(x)\n    return backend.math.erf(x)\n\n\nclass Erfinv(Operation):\n    def compute_output_spec(self, x):\n        return KerasTensor(shape=x.shape, dtype=x.dtype)\n\n    def call(self, x):\n        return backend.math.erfinv(x)\n\n\n@keras_export(\"keras.ops.erfinv\")\ndef erfinv(x):\n    \"\"\"Computes the inverse error function of `x`, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same dtype as `x`.\n\n    Example:\n\n    >>> x = np.array([-0.5, -0.2, -0.1, 0.0, 0.3])\n    >>> keras.ops.erfinv(x)\n    array([-0.47694, -0.17914, -0.08886,  0. 
,  0.27246], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Erfinv().symbolic_call(x)\n    x = backend.convert_to_tensor(x)\n    return backend.math.erfinv(x)\n\n\nclass Logdet(Operation):\n    def call(self, x):\n        return backend.math.logdet(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape[:-2], dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.logdet\"])\ndef logdet(x):\n    \"\"\"Computes log of the determinant of a Hermitian positive definite matrix.\n\n    Args:\n        x: Input matrix. It must be 2D and square.\n\n    Returns:\n        The natural log of the determinant of the matrix.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Logdet().symbolic_call(x)\n    return backend.math.logdet(x)\n\n\nclass ViewAsComplex(Operation):\n    def call(self, x):\n        x = backend.convert_to_tensor(x)\n        if len(x.shape) < 1 or x.shape[-1] != 2:\n            raise ValueError(\n                \"Input tensor's last dimension must be 2 (real and imaginary).\"\n            )\n        return x[..., 0] + 1j * x[..., 1]\n\n    def compute_output_spec(self, x):\n        return KerasTensor(shape=x.shape[:-1], dtype=\"complex64\")\n\n\nclass ViewAsReal(Operation):\n    def call(self, x):\n        x = backend.convert_to_tensor(x)\n        real_part = backend.numpy.real(x)\n        imag_part = backend.numpy.imag(x)\n        return backend.numpy.stack((real_part, imag_part), axis=-1)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(shape=x.shape + (2,), dtype=\"float32\")\n\n\n@keras_export(\"keras.ops.view_as_complex\")\ndef view_as_complex(x):\n    \"\"\"Converts a real tensor with shape `(..., 2)` to a complex tensor,\n    where the last dimension represents the real and imaginary components\n    of a complex tensor.\n\n    Args:\n        x: A real tensor with last dimension of size 2.\n\n    Returns:\n        A complex tensor with shape `x.shape[:-1]`.\n\n    Example:\n\n    
```\n    >>> import numpy as np\n    >>> from keras import ops\n\n    >>> real_imag = np.array([[1.0, 2.0], [3.0, 4.0]])\n    >>> complex_tensor = ops.view_as_complex(real_imag)\n    >>> complex_tensor\n    array([1.+2.j, 3.+4.j])\n    ```\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return ViewAsComplex().symbolic_call(x)\n\n    x = backend.convert_to_tensor(x)\n    if len(x.shape) < 1 or x.shape[-1] != 2:\n        raise ValueError(\n            \"Last dimension of input must be size 2 (real and imaginary). \"\n            f\"Received shape: {x.shape}\"\n        )\n    real_part = x[..., 0]\n    imag_part = x[..., 1]\n\n    return backend.cast(real_part, dtype=\"complex64\") + 1j * backend.cast(\n        imag_part, dtype=\"complex64\"\n    )\n\n\n@keras_export(\"keras.ops.view_as_real\")\ndef view_as_real(x):\n    \"\"\"Converts a complex tensor to a real tensor with shape `(..., 2)`,\n    where the last dimension represents the real and imaginary components.\n\n    Args:\n        x: A complex tensor.\n\n    Returns:\n        A real tensor where the last dimension contains the\n        real and imaginary parts.\n\n    Example:\n    ```\n    >>> import numpy as np\n    >>> from keras import ops\n\n    >>> complex_tensor = np.array([1 + 2j, 3 + 4j])\n    >>> real = ops.view_as_real(complex_tensor)\n    >>> real\n    array([[1., 2.],\n           [3., 4.]])\n    ```\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return ViewAsReal().symbolic_call(x)\n\n    x = backend.convert_to_tensor(x)\n    real_part = backend.numpy.real(x)\n    imag_part = backend.numpy.imag(x)\n    return backend.numpy.stack((real_part, imag_part), axis=-1)\n"
  },
  {
    "path": "keras/src/ops/math_test.py",
    "content": "import math\n\nimport jax.numpy as jnp\nimport numpy as np\nimport pytest\nimport scipy.signal\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common import standardize_dtype\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.ops import math as kmath\n\n\ndef _stft(\n    x, sequence_length, sequence_stride, fft_length, window=\"hann\", center=True\n):\n    # pure numpy version of stft that matches librosa's implementation\n    x = np.array(x)\n    ori_dtype = x.dtype\n\n    if center:\n        pad_width = [(0, 0) for _ in range(len(x.shape))]\n        pad_width[-1] = (fft_length // 2, fft_length // 2)\n        x = np.pad(x, pad_width, mode=\"reflect\")\n\n    l_pad = (fft_length - sequence_length) // 2\n    r_pad = fft_length - sequence_length - l_pad\n\n    if window is not None:\n        if isinstance(window, str):\n            window = scipy.signal.get_window(window, sequence_length)\n        win = np.array(window, dtype=x.dtype)\n        win = np.pad(win, [[l_pad, r_pad]])\n    else:\n        win = np.ones((sequence_length + l_pad + r_pad), dtype=x.dtype)\n\n    x = scipy.signal.stft(\n        x,\n        fs=1.0,\n        window=win,\n        nperseg=(sequence_length + l_pad + r_pad),\n        noverlap=(sequence_length + l_pad + r_pad - sequence_stride),\n        nfft=fft_length,\n        boundary=None,\n        padded=False,\n    )[-1]\n\n    # scale and swap to (..., num_sequences, fft_bins)\n    x = x / np.sqrt(1.0 / win.sum() ** 2)\n    x = np.swapaxes(x, -2, -1)\n    return np.real(x).astype(ori_dtype), np.imag(x).astype(ori_dtype)\n\n\ndef _istft(\n    x,\n    sequence_length,\n    sequence_stride,\n    fft_length,\n    length=None,\n    window=\"hann\",\n    center=True,\n):\n    # pure numpy version of istft that matches librosa's implementation\n    complex_input = x[0] + 1j * x[1]\n    
x = np.fft.irfft(\n        complex_input, n=fft_length, axis=-1, norm=\"backward\"\n    ).astype(x[0].dtype)\n\n    expected_output_len = fft_length + sequence_stride * (x.shape[-2] - 1)\n\n    if window is not None:\n        if isinstance(window, str):\n            win = np.array(\n                scipy.signal.get_window(window, sequence_length), dtype=x.dtype\n            )\n        else:\n            win = np.array(window, dtype=x.dtype)\n        l_pad = (fft_length - sequence_length) // 2\n        r_pad = fft_length - sequence_length - l_pad\n        win = np.pad(win, [[l_pad, r_pad]])\n\n        # square and sum\n        _sequence_length = sequence_length + l_pad + r_pad\n        denom = np.square(win)\n        overlaps = -(-_sequence_length // sequence_stride)\n        denom = np.pad(\n            denom, [(0, overlaps * sequence_stride - _sequence_length)]\n        )\n        denom = np.reshape(denom, [overlaps, sequence_stride])\n        denom = np.sum(denom, 0, keepdims=True)\n        denom = np.tile(denom, [overlaps, 1])\n        denom = np.reshape(denom, [overlaps * sequence_stride])\n        win = np.divide(win, denom[:_sequence_length])\n        x = np.multiply(x, win)\n\n    # overlap_sequences\n    def _overlap_sequences(x, sequence_stride):\n        *batch_shape, num_sequences, sequence_length = x.shape\n        flat_batchsize = math.prod(batch_shape)\n        x = np.reshape(x, (flat_batchsize, num_sequences, sequence_length))\n        output_size = sequence_stride * (num_sequences - 1) + sequence_length\n        nstep_per_segment = 1 + (sequence_length - 1) // sequence_stride\n        padded_segment_len = nstep_per_segment * sequence_stride\n        x = np.pad(\n            x, ((0, 0), (0, 0), (0, padded_segment_len - sequence_length))\n        )\n        x = np.reshape(\n            x,\n            (flat_batchsize, num_sequences, nstep_per_segment, sequence_stride),\n        )\n        x = x.transpose((0, 2, 1, 3))\n        x = np.pad(x, ((0, 0), 
(0, 0), (0, num_sequences), (0, 0)))\n        shrinked = x.shape[2] - 1\n        x = np.reshape(x, (flat_batchsize, -1))\n        x = x[:, : (nstep_per_segment * shrinked * sequence_stride)]\n        x = np.reshape(\n            x, (flat_batchsize, nstep_per_segment, shrinked * sequence_stride)\n        )\n        x = np.sum(x, axis=1)[:, :output_size]\n        return np.reshape(x, tuple(batch_shape) + (-1,))\n\n    x = _overlap_sequences(x, sequence_stride)\n\n    start = 0 if center is False else fft_length // 2\n    if length is not None:\n        end = start + length\n    elif center:\n        end = -(fft_length // 2)\n    else:\n        end = expected_output_len\n    return x[..., start:end]\n\n\ndef _sum_reduce(left, right):\n    return left + right\n\n\ndef _max_reduce(left, right):\n    return np.max(np.stack([left, right]), axis=0)\n\n\nclass MathOpsDynamicShapeTest(testing.TestCase):\n    @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)])\n    def test_segment_reduce(self, segment_reduce_op):\n        # 1D case\n        data = KerasTensor((None, 4), dtype=\"float32\")\n        segment_ids = KerasTensor((10,), dtype=\"int32\")\n        outputs = segment_reduce_op(data, segment_ids)\n        self.assertEqual(outputs.shape, (None, 4))\n\n        data = KerasTensor((None, 4), dtype=\"float32\")\n        segment_ids = KerasTensor((10,), dtype=\"int32\")\n        outputs = segment_reduce_op(data, segment_ids, num_segments=5)\n        self.assertEqual(outputs.shape, (5, 4))\n\n        data = KerasTensor((10,), dtype=\"float32\")\n        segment_ids = KerasTensor(\n            (10,),\n            dtype=\"int32\",\n        )\n        outputs = segment_reduce_op(data, segment_ids)\n        self.assertEqual(outputs.shape, (None,))\n\n    def test_top_k(self):\n        x = KerasTensor((None, 2, 3))\n        values, indices = kmath.top_k(x, k=1)\n        self.assertEqual(values.shape, (None, 2, 1))\n        self.assertEqual(indices.shape, (None, 
2, 1))\n\n    def test_in_top_k(self):\n        targets = KerasTensor((None,))\n        predictions = KerasTensor((None, 10))\n        self.assertEqual(\n            kmath.in_top_k(targets, predictions, k=1).shape, (None,)\n        )\n\n    def test_logsumexp(self):\n        x = KerasTensor((None, 2, 3), dtype=\"float32\")\n        self.assertEqual(kmath.logsumexp(x).shape, ())\n        self.assertEqual(kmath.logsumexp(x, axis=1).shape, (None, 3))\n        self.assertEqual(kmath.logsumexp(x, axis=(1, 2)).shape, (None,))\n        self.assertEqual(kmath.logsumexp(x, keepdims=True).shape, (1, 1, 1))\n\n    def test_extract_sequences(self):\n        # Defined dimension\n        x = KerasTensor((None, 32), dtype=\"float32\")\n        sequence_length = 3\n        sequence_stride = 2\n        outputs = kmath.extract_sequences(x, sequence_length, sequence_stride)\n        num_sequences = 1 + (x.shape[-1] - sequence_length) // sequence_stride\n        self.assertEqual(outputs.shape, (None, num_sequences, sequence_length))\n\n        # Undefined dimension\n        x = KerasTensor((None, None), dtype=\"float32\")\n        sequence_length = 3\n        sequence_stride = 2\n        outputs = kmath.extract_sequences(x, sequence_length, sequence_stride)\n        self.assertEqual(outputs.shape, (None, None, sequence_length))\n\n    def test_fft(self):\n        real = KerasTensor((None, 4, 3), dtype=\"float32\")\n        imag = KerasTensor((None, 4, 3), dtype=\"float32\")\n        real_output, imag_output = kmath.fft((real, imag))\n        ref = np.fft.fft(np.ones((2, 4, 3)))\n        ref_shape = (None,) + ref.shape[1:]\n        self.assertEqual(real_output.shape, ref_shape)\n        self.assertEqual(imag_output.shape, ref_shape)\n\n    def test_fft2(self):\n        real = KerasTensor((None, 4, 3), dtype=\"float32\")\n        imag = KerasTensor((None, 4, 3), dtype=\"float32\")\n        real_output, imag_output = kmath.fft2((real, imag))\n        ref = np.fft.fft2(np.ones((2, 4, 
3)))\n        ref_shape = (None,) + ref.shape[1:]\n        self.assertEqual(real_output.shape, ref_shape)\n        self.assertEqual(imag_output.shape, ref_shape)\n\n    def test_ifft2(self):\n        real = KerasTensor((None, 4, 3), dtype=\"float32\")\n        imag = KerasTensor((None, 4, 3), dtype=\"float32\")\n        real_output, imag_output = kmath.ifft2((real, imag))\n        ref = np.fft.ifft2(np.ones((2, 4, 3)))\n        ref_shape = (None,) + ref.shape[1:]\n        self.assertEqual(real_output.shape, ref_shape)\n        self.assertEqual(imag_output.shape, ref_shape)\n\n    @parameterized.parameters([(None,), (1,), (5,)])\n    def test_rfft(self, fft_length):\n        x = KerasTensor((None, 4, 3), dtype=\"float32\")\n        real_output, imag_output = kmath.rfft(x, fft_length=fft_length)\n        ref = np.fft.rfft(np.ones((2, 4, 3)), n=fft_length)\n        ref_shape = (None,) + ref.shape[1:]\n        self.assertEqual(real_output.shape, ref_shape)\n        self.assertEqual(imag_output.shape, ref_shape)\n\n    @parameterized.parameters([(None,), (1,), (5,)])\n    def test_irfft(self, fft_length):\n        real = KerasTensor((None, 4, 3), dtype=\"float32\")\n        imag = KerasTensor((None, 4, 3), dtype=\"float32\")\n        output = kmath.irfft((real, imag), fft_length=fft_length)\n        ref = np.fft.irfft(np.ones((2, 4, 3)), n=fft_length)\n        ref_shape = (None,) + ref.shape[1:]\n        self.assertEqual(output.shape, ref_shape)\n\n    def test_stft(self):\n        x = KerasTensor((None, 32), dtype=\"float32\")\n        sequence_length = 10\n        sequence_stride = 3\n        fft_length = 15\n        real_output, imag_output = kmath.stft(\n            x, sequence_length, sequence_stride, fft_length\n        )\n        real_ref, imag_ref = _stft(\n            np.ones((2, 32)), sequence_length, sequence_stride, fft_length\n        )\n        real_ref_shape = (None,) + real_ref.shape[1:]\n        imag_ref_shape = (None,) + imag_ref.shape[1:]\n        
self.assertEqual(real_output.shape, real_ref_shape)\n        self.assertEqual(imag_output.shape, imag_ref_shape)\n\n    def test_istft(self):\n        sequence_length = 10\n        sequence_stride = 3\n        fft_length = 15\n        real = KerasTensor((None, 32), dtype=\"float32\")\n        imag = KerasTensor((None, 32), dtype=\"float32\")\n        output = kmath.istft(\n            (real, imag), sequence_length, sequence_stride, fft_length\n        )\n        ref = _istft(\n            (np.ones((5, 32)), np.ones((5, 32))),\n            sequence_length,\n            sequence_stride,\n            fft_length,\n        )\n        ref_shape = (None,) + ref.shape[1:]\n        self.assertEqual(output.shape, ref_shape)\n\n    def test_rsqrt(self):\n        x = KerasTensor([None, 3])\n        self.assertEqual(kmath.rsqrt(x).shape, (None, 3))\n\n    def test_logdet(self):\n        x = KerasTensor((None, 3, 3))\n        out = kmath.logdet(x)\n        self.assertEqual(out.shape, (None,))\n\n\nclass MathOpsStaticShapeTest(testing.TestCase):\n    @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)])\n    @pytest.mark.skipif(\n        backend.backend() == \"jax\",\n        reason=\"JAX does not support `num_segments=None`.\",\n    )\n    def test_segment_reduce(self, segment_reduce_op):\n        # 1D case\n        data = KerasTensor((10, 4), dtype=\"float32\")\n        segment_ids = KerasTensor((10,), dtype=\"int32\")\n        outputs = segment_reduce_op(data, segment_ids)\n        self.assertEqual(outputs.shape, (None, 4))\n\n        data = KerasTensor((10,), dtype=\"float32\")\n        segment_ids = KerasTensor((10,), dtype=\"int32\")\n        outputs = segment_reduce_op(data, segment_ids)\n        self.assertEqual(outputs.shape, (None,))\n\n    @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)])\n    def test_segment_reduce_explicit_num_segments(self, segment_reduce_op):\n        # 1D case\n        data = KerasTensor((10, 4), 
dtype=\"float32\")\n        segment_ids = KerasTensor((10,), dtype=\"int32\")\n        outputs = segment_reduce_op(data, segment_ids, num_segments=5)\n        self.assertEqual(outputs.shape, (5, 4))\n\n        data = KerasTensor((6,), dtype=\"float32\")\n        segment_ids = KerasTensor(\n            (6,),\n            dtype=\"int32\",\n        )\n        outputs = segment_reduce_op(data, segment_ids, num_segments=5)\n        self.assertEqual(outputs.shape, (5,))\n\n    def test_topk(self):\n        x = KerasTensor((1, 2, 3))\n        values, indices = kmath.top_k(x, k=1)\n        self.assertEqual(values.shape, (1, 2, 1))\n        self.assertEqual(indices.shape, (1, 2, 1))\n\n    def test_in_top_k(self):\n        targets = KerasTensor((5,))\n        predictions = KerasTensor((5, 10))\n        self.assertEqual(kmath.in_top_k(targets, predictions, k=1).shape, (5,))\n\n    def test_logsumexp(self):\n        x = KerasTensor((1, 2, 3), dtype=\"float32\")\n        result = kmath.logsumexp(x)\n        self.assertEqual(result.shape, ())\n\n    def test_extract_sequences(self):\n        x = KerasTensor((10, 16), dtype=\"float32\")\n        sequence_length = 3\n        sequence_stride = 2\n        outputs = kmath.extract_sequences(x, sequence_length, sequence_stride)\n        num_sequences = 1 + (x.shape[-1] - sequence_length) // sequence_stride\n        self.assertEqual(outputs.shape, (10, num_sequences, sequence_length))\n\n    def test_fft(self):\n        real = KerasTensor((2, 4, 3), dtype=\"float32\")\n        imag = KerasTensor((2, 4, 3), dtype=\"float32\")\n        real_output, imag_output = kmath.fft((real, imag))\n        ref = np.fft.fft(np.ones((2, 4, 3)))\n        self.assertEqual(real_output.shape, ref.shape)\n        self.assertEqual(imag_output.shape, ref.shape)\n\n    def test_fft2(self):\n        real = KerasTensor((2, 4, 3), dtype=\"float32\")\n        imag = KerasTensor((2, 4, 3), dtype=\"float32\")\n        real_output, imag_output = kmath.fft2((real, 
imag))\n        ref = np.fft.fft2(np.ones((2, 4, 3)))\n        self.assertEqual(real_output.shape, ref.shape)\n        self.assertEqual(imag_output.shape, ref.shape)\n\n    def test_ifft2(self):\n        real = KerasTensor((2, 4, 3), dtype=\"float32\")\n        imag = KerasTensor((2, 4, 3), dtype=\"float32\")\n        real_output, imag_output = kmath.ifft2((real, imag))\n        ref = np.fft.ifft2(np.ones((2, 4, 3)))\n        self.assertEqual(real_output.shape, ref.shape)\n        self.assertEqual(imag_output.shape, ref.shape)\n\n    def test_rfft(self):\n        x = KerasTensor((2, 4, 3), dtype=\"float32\")\n        real_output, imag_output = kmath.rfft(x)\n        ref = np.fft.rfft(np.ones((2, 4, 3)))\n        self.assertEqual(real_output.shape, ref.shape)\n        self.assertEqual(imag_output.shape, ref.shape)\n\n    def test_irfft(self):\n        real = KerasTensor((2, 4, 3), dtype=\"float32\")\n        imag = KerasTensor((2, 4, 3), dtype=\"float32\")\n        output = kmath.irfft((real, imag))\n        ref = np.fft.irfft(np.ones((2, 4, 3)))\n        self.assertEqual(output.shape, ref.shape)\n\n    def test_rsqrt(self):\n        x = KerasTensor([4, 3], dtype=\"float32\")\n        self.assertEqual(kmath.rsqrt(x).shape, (4, 3))\n\n    def test_stft(self):\n        x = KerasTensor((2, 32), dtype=\"float32\")\n        sequence_length = 10\n        sequence_stride = 3\n        fft_length = 15\n        real_output, imag_output = kmath.stft(\n            x, sequence_length, sequence_stride, fft_length\n        )\n        real_ref, imag_ref = _stft(\n            np.ones((2, 32)), sequence_length, sequence_stride, fft_length\n        )\n        self.assertEqual(real_output.shape, real_ref.shape)\n        self.assertEqual(imag_output.shape, imag_ref.shape)\n\n    def test_istft(self):\n        # sequence_stride must <= x[0].shape[-1]\n        # sequence_stride must >= fft_length / num_sequences\n        sequence_length = 10\n        sequence_stride = 3\n        
fft_length = 15\n        num_sequences = fft_length // sequence_stride + 1\n        real = KerasTensor((num_sequences, 32), dtype=\"float32\")\n        imag = KerasTensor((num_sequences, 32), dtype=\"float32\")\n        output = kmath.istft(\n            (real, imag), sequence_length, sequence_stride, fft_length\n        )\n        ref = _istft(\n            (np.ones((num_sequences, 32)), np.ones((num_sequences, 32))),\n            sequence_length,\n            sequence_stride,\n            fft_length,\n        )\n        self.assertEqual(output.shape, ref.shape)\n\n    def test_logdet(self):\n        x = KerasTensor((3, 3))\n        out = kmath.logdet(x)\n        self.assertEqual(out.shape, ())\n\n        x = KerasTensor((2, 4, 3, 3))\n        out = kmath.logdet(x)\n        self.assertEqual(out.shape, (2, 4))\n\n\nclass MathOpsCorrectnessTest(testing.TestCase):\n    def run_segment_reduce_test(\n        self,\n        segment_reduce_op,\n        element_wise_reduce_method,\n        num_indices,\n        indices_high,\n        data_dims=tuple(),\n        num_segments=None,\n        add_neg1_to_indices=False,\n        sorted_indices=False,\n    ):\n        if num_segments is not None and indices_high >= num_segments:\n            raise ValueError(\"Indices high cannot be more than num segments\")\n        indices_dims = (num_indices,)\n        full_data_dims = indices_dims + data_dims\n        data = np.random.rand(*full_data_dims).astype(np.float32)\n        segment_ids = np.concatenate(\n            [\n                np.arange(indices_high),\n                np.random.randint(\n                    low=0,\n                    high=indices_high,\n                    size=(indices_dims[0] - indices_high),\n                ),\n            ]\n        ).astype(np.int32)\n        if sorted_indices:\n            segment_ids = np.sort(segment_ids, axis=-1)\n        if add_neg1_to_indices:\n            segment_ids[0] = -1\n        outputs = segment_reduce_op(\n            
data, segment_ids, num_segments, sorted=sorted_indices\n        )\n        if num_segments is None:\n            num_segments = np.max(segment_ids).item() + 1\n        expected_shape = (num_segments,) + data_dims\n        if segment_reduce_op == kmath.segment_max:\n            if backend.backend() == \"tensorflow\":\n                empty_fill_value = -np.finfo(np.float32).max\n            else:\n                empty_fill_value = -np.inf\n            expected = np.full(expected_shape, empty_fill_value)\n        else:\n            expected = np.zeros(expected_shape)\n\n        for idx in range(num_indices):\n            segment_id = segment_ids[idx]\n            if segment_id == -1:\n                continue\n            expected[segment_id] = element_wise_reduce_method(\n                expected[segment_id], data[idx]\n            )\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.product(\n        (\n            dict(\n                segment_reduce_op=kmath.segment_sum,\n                element_wise_reduce_method=_sum_reduce,\n            ),\n            dict(\n                segment_reduce_op=kmath.segment_max,\n                element_wise_reduce_method=_max_reduce,\n            ),\n        ),\n        sorted_indices=(True, False),\n    )\n    @pytest.mark.skipif(\n        backend.backend() == \"jax\",\n        reason=\"JAX does not support `num_segments=None`.\",\n    )\n    def test_segment_reduce(\n        self,\n        segment_reduce_op,\n        element_wise_reduce_method,\n        sorted_indices,\n    ):\n        # Test 1D case.\n        self.run_segment_reduce_test(\n            segment_reduce_op,\n            element_wise_reduce_method,\n            num_indices=9,\n            indices_high=3,\n            sorted_indices=sorted_indices,\n        )\n\n        # Test ND data case.\n        self.run_segment_reduce_test(\n            segment_reduce_op,\n            element_wise_reduce_method,\n            num_indices=9,\n            
indices_high=3,\n            data_dims=(\n                3,\n                3,\n            ),\n            sorted_indices=sorted_indices,\n        )\n\n    @parameterized.product(\n        (\n            dict(\n                segment_reduce_op=kmath.segment_sum,\n                element_wise_reduce_method=_sum_reduce,\n            ),\n            dict(\n                segment_reduce_op=kmath.segment_max,\n                element_wise_reduce_method=_max_reduce,\n            ),\n        ),\n        (\n            dict(\n                contains_neg1_in_indices=True,\n                sorted_indices=False,\n            ),\n            dict(\n                contains_neg1_in_indices=False,\n                sorted_indices=False,\n            ),\n            dict(\n                contains_neg1_in_indices=False,\n                sorted_indices=True,\n            ),\n        ),\n    )\n    def test_segment_reduce_explicit_num_segments(\n        self,\n        segment_reduce_op,\n        element_wise_reduce_method,\n        contains_neg1_in_indices,\n        sorted_indices,\n    ):\n        if backend.backend() == \"tensorflow\" and sorted_indices:\n            pytest.skip(\n                \"Num segments and sorted_indices=True doesn't work for \"\n                \"tensorflow.\"\n            )\n        # Test 1D case.\n        self.run_segment_reduce_test(\n            segment_reduce_op,\n            element_wise_reduce_method,\n            num_indices=9,\n            indices_high=3,\n            num_segments=4,\n            add_neg1_to_indices=contains_neg1_in_indices,\n            sorted_indices=sorted_indices,\n        )\n\n        # Test ND data case.\n        self.run_segment_reduce_test(\n            segment_reduce_op,\n            element_wise_reduce_method,\n            num_indices=9,\n            indices_high=3,\n            data_dims=(\n                3,\n                3,\n            ),\n            num_segments=4,\n            
add_neg1_to_indices=contains_neg1_in_indices,\n            sorted_indices=sorted_indices,\n        )\n\n    def test_top_k(self):\n        x = np.array([0, 4, 2, 1, 3, -1], dtype=np.float32)\n        values, indices = kmath.top_k(x, k=2)\n        self.assertAllClose(values, [4, 3])\n        self.assertAllClose(indices, [1, 4])\n\n        x = np.array([0, 4, 2, 1, 3, -1], dtype=np.float32)\n        values, indices = kmath.top_k(x, k=2, sorted=False)\n        # Any order ok when `sorted=False`.\n        self.assertEqual(set(backend.convert_to_numpy(values)), set([4, 3]))\n        self.assertEqual(set(backend.convert_to_numpy(indices)), set([1, 4]))\n\n        x = np.random.rand(5, 5)\n        outputs = kmath.top_k(x, k=2)\n\n        expected_values = np.zeros((5, 2))\n        expected_indices = np.zeros((5, 2), dtype=np.int32)\n\n        for i in range(x.shape[0]):\n            top_k_indices = np.argsort(x[i])[-2:][::-1]\n            expected_values[i] = x[i, top_k_indices]\n            expected_indices[i] = top_k_indices\n\n        self.assertAllClose(outputs[0], expected_values)\n        self.assertAllClose(outputs[1], expected_indices)\n\n    def test_in_top_k(self):\n        targets = np.array([1, 0, 2])\n        predictions = np.array(\n            [\n                [0.1, 0.9, 0.8, 0.8],\n                [0.05, 0.95, 0, 1],\n                [0.1, 0.8, 0.3, 1],\n            ]\n        )\n        self.assertAllEqual(\n            kmath.in_top_k(targets, predictions, k=1), [True, False, False]\n        )\n        self.assertAllEqual(\n            kmath.in_top_k(targets, predictions, k=2), [True, False, False]\n        )\n        self.assertAllEqual(\n            kmath.in_top_k(targets, predictions, k=3), [True, True, True]\n        )\n\n        # Test tie cases.\n        targets = np.array([1, 0, 2])\n        predictions = np.array(\n            [\n                [0.1, 0.9, 0.8, 0.8],\n                [0.95, 0.95, 0, 0.95],\n                [0.1, 0.8, 0.8, 
0.95],\n            ]\n        )\n        self.assertAllEqual(\n            kmath.in_top_k(targets, predictions, k=1), [True, True, False]\n        )\n        self.assertAllEqual(\n            kmath.in_top_k(targets, predictions, k=2), [True, True, True]\n        )\n        self.assertAllEqual(\n            kmath.in_top_k(targets, predictions, k=3), [True, True, True]\n        )\n\n        # Test `nan` in predictions\n        # https://github.com/keras-team/keras/issues/19995\n        targets = np.array([1, 0])\n        predictions = np.array([[0.1, np.nan, 0.5], [0.3, 0.2, 0.5]])\n        self.assertAllEqual(\n            kmath.in_top_k(targets, predictions, k=2), [False, True]\n        )\n\n    def test_logsumexp(self):\n        x = np.random.rand(5, 5)\n        outputs = kmath.logsumexp(x)\n        expected = np.log(np.sum(np.exp(x)))\n        self.assertAllClose(outputs, expected)\n\n        outputs = kmath.logsumexp(x, axis=1)\n        expected = np.log(np.sum(np.exp(x), axis=1))\n        self.assertAllClose(outputs, expected)\n\n    def test_extract_sequences(self):\n        # Test 1D case.\n        x = np.random.random((10,))\n        sequence_length = 3\n        sequence_stride = 2\n        output = kmath.extract_sequences(x, sequence_length, sequence_stride)\n\n        num_sequences = 1 + (x.shape[-1] - sequence_length) // sequence_stride\n        expected = np.zeros(shape=(num_sequences, sequence_length))\n        pos = 0\n        for i in range(num_sequences):\n            expected[i] = x[pos : pos + sequence_length]\n            pos += sequence_stride\n        self.assertAllClose(output, expected, tpu_atol=1e-2, tpu_rtol=1e-2)\n\n        # Test N-D case.\n        x = np.random.random((4, 8))\n        sequence_length = 3\n        sequence_stride = 2\n        output = kmath.extract_sequences(x, sequence_length, sequence_stride)\n\n        num_sequences = 1 + (x.shape[-1] - sequence_length) // sequence_stride\n        expected = np.zeros(shape=(4, 
num_sequences, sequence_length))\n        pos = 0\n        for i in range(num_sequences):\n            expected[:, i] = x[:, pos : pos + sequence_length]\n            pos += sequence_stride\n        self.assertAllClose(output, expected, tpu_atol=1e-2, tpu_rtol=1e-2)\n\n    def test_fft(self):\n        real = np.random.random((2, 4, 3))\n        imag = np.random.random((2, 4, 3))\n        complex_arr = real + 1j * imag\n\n        real_output, imag_output = kmath.fft((real, imag))\n        ref = np.fft.fft(complex_arr)\n        real_ref = np.real(ref)\n        imag_ref = np.imag(ref)\n        self.assertAllClose(real_ref, real_output)\n        self.assertAllClose(imag_ref, imag_output)\n\n    def test_fft2(self):\n        real = np.random.random((2, 4, 3))\n        imag = np.random.random((2, 4, 3))\n        complex_arr = real + 1j * imag\n\n        real_output, imag_output = kmath.fft2((real, imag))\n        ref = np.fft.fft2(complex_arr)\n        real_ref = np.real(ref)\n        imag_ref = np.imag(ref)\n        self.assertAllClose(real_ref, real_output)\n        self.assertAllClose(imag_ref, imag_output)\n\n    def test_ifft2(self):\n        real = np.random.random((2, 4, 3)).astype(np.float32)\n        imag = np.random.random((2, 4, 3)).astype(np.float32)\n        complex_arr = real + 1j * imag\n\n        real_output, imag_output = kmath.ifft2((real, imag))\n        ref = np.fft.ifft2(complex_arr)\n        real_ref = np.real(ref)\n        imag_ref = np.imag(ref)\n        self.assertAllClose(real_ref, real_output)\n        self.assertAllClose(imag_ref, imag_output)\n\n    @parameterized.parameters([(None,), (3,), (15,)])\n    def test_rfft(self, n):\n        # Test 1D.\n        x = np.random.random((10,))\n        real_output, imag_output = kmath.rfft(x, fft_length=n)\n        ref = np.fft.rfft(x, n=n)\n        real_ref = np.real(ref)\n        imag_ref = np.imag(ref)\n        self.assertAllClose(real_ref, real_output, atol=1e-5, rtol=1e-5)\n        
self.assertAllClose(imag_ref, imag_output, atol=1e-5, rtol=1e-5)\n\n        # Test N-D case.\n        x = np.random.random((2, 3, 10))\n        real_output, imag_output = kmath.rfft(x, fft_length=n)\n        ref = np.fft.rfft(x, n=n)\n        real_ref = np.real(ref)\n        imag_ref = np.imag(ref)\n        self.assertAllClose(real_ref, real_output, atol=1e-5, rtol=1e-5)\n        self.assertAllClose(imag_ref, imag_output, atol=1e-5, rtol=1e-5)\n\n    @parameterized.parameters([(None,), (3,), (15,)])\n    def test_irfft(self, n):\n        # Test 1D.\n        real = np.random.random((10,))\n        imag = np.random.random((10,))\n        complex_arr = real + 1j * imag\n        output = kmath.irfft((real, imag), fft_length=n)\n        ref = np.fft.irfft(complex_arr, n=n)\n        self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5)\n\n        # Test N-D case.\n        real = np.random.random((2, 3, 10))\n        imag = np.random.random((2, 3, 10))\n        complex_arr = real + 1j * imag\n        output = kmath.irfft((real, imag), fft_length=n)\n        ref = np.fft.irfft(complex_arr, n=n)\n        self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5)\n\n    @parameterized.parameters(\n        [\n            (32, 8, 32, \"hann\", True),\n            (8, 8, 16, \"hann\", True),\n            (4, 4, 7, \"hann\", True),\n            (32, 8, 32, \"hamming\", True),\n            (32, 8, 32, \"hann\", False),\n            (32, 8, 32, np.ones((32,)), True),\n            (32, 8, 32, None, True),\n        ]\n    )\n    def test_stft(\n        self, sequence_length, sequence_stride, fft_length, window, center\n    ):\n        # Test 1D case.\n        x = np.random.random((32,))\n        real_output, imag_output = kmath.stft(\n            x, sequence_length, sequence_stride, fft_length, window, center\n        )\n        real_ref, imag_ref = _stft(\n            x, sequence_length, sequence_stride, fft_length, window, center\n        )\n        self.assertAllClose(real_ref, 
real_output, atol=1e-5, rtol=1e-5)\n        self.assertAllClose(imag_ref, imag_output, atol=1e-5, rtol=1e-5)\n\n        # Test N-D case.\n        x = np.random.random((2, 3, 32))\n        real_output, imag_output = kmath.stft(\n            x, sequence_length, sequence_stride, fft_length, window, center\n        )\n        real_ref, imag_ref = _stft(\n            x, sequence_length, sequence_stride, fft_length, window, center\n        )\n        self.assertAllClose(real_ref, real_output, atol=1e-5, rtol=1e-5)\n        self.assertAllClose(imag_ref, imag_output, atol=1e-5, rtol=1e-5)\n\n    @parameterized.parameters(\n        [\n            (32, 8, 32, \"hann\", True),\n            (8, 8, 16, \"hann\", True),\n            (4, 4, 7, \"hann\", True),\n            (32, 8, 32, \"hamming\", True),\n            (8, 4, 8, \"hann\", False),\n            (32, 8, 32, np.ones((32,)), True),\n            (32, 8, 32, None, True),\n        ]\n    )\n    def test_istft(\n        self, sequence_length, sequence_stride, fft_length, window, center\n    ):\n        # sequence_stride must be <= x[0].shape[-1]\n        # sequence_stride must be >= fft_length / num_sequences\n        # Test 1D case.\n        x = np.random.random((256,))\n        real_x, imag_x = _stft(\n            x, sequence_length, sequence_stride, fft_length, window, center\n        )\n        output = kmath.istft(\n            (real_x, imag_x),\n            sequence_length,\n            sequence_stride,\n            fft_length,\n            window=window,\n            center=center,\n        )\n        ref = _istft(\n            (real_x, imag_x),\n            sequence_length,\n            sequence_stride,\n            fft_length,\n            window=window,\n            center=center,\n        )\n        if backend.backend() in (\"numpy\", \"jax\", \"torch\"):\n            # These backends have different implementations for the boundary\n            # of the output, so we need to truncate 5% before assertAllClose.\n            truncated_len = int(output.shape[-1] * 0.05)\n            output = output[..., truncated_len:-truncated_len]\n            ref = ref[..., truncated_len:-truncated_len]\n        # Nans are handled differently in different backends, so zero them out.\n        output = np.nan_to_num(backend.convert_to_numpy(output), nan=0.0)\n        ref = np.nan_to_num(ref, nan=0.0)\n        self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5)\n\n        # Test N-D case.\n        x = np.random.random((2, 3, 256))\n        real_x, imag_x = _stft(\n            x, sequence_length, sequence_stride, fft_length, window, center\n        )\n        output = kmath.istft(\n            (real_x, imag_x),\n            sequence_length,\n            sequence_stride,\n            fft_length,\n            window=window,\n            center=center,\n        )\n        ref = _istft(\n            (real_x, imag_x),\n            sequence_length,\n            sequence_stride,\n            fft_length,\n            window=window,\n            center=center,\n        )\n        if backend.backend() in (\"numpy\", \"jax\", \"torch\"):\n            # These backends have different implementations for the boundary\n            # of the output, so we need to truncate 5% before assertAllClose.\n            truncated_len = int(output.shape[-1] * 0.05)\n            output = output[..., truncated_len:-truncated_len]\n            ref = ref[..., truncated_len:-truncated_len]\n        # Nans are handled differently in different backends, so zero them out.\n        output = np.nan_to_num(backend.convert_to_numpy(output), nan=0.0)\n        ref = np.nan_to_num(ref, nan=0.0)\n        self.assertAllClose(output, ref, atol=1e-5, rtol=1e-5)\n\n    def test_rsqrt(self):\n        x = np.array([[1, 4, 9], [16, 25, 36]], dtype=\"float32\")\n        self.assertAllClose(kmath.rsqrt(x), 1 / np.sqrt(x))\n        self.assertAllClose(kmath.Rsqrt()(x), 1 / np.sqrt(x))\n\n    def test_erf_operation_basic(self):\n        # Sample values for 
testing\n        sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0])\n\n        # Expected output using scipy's implementation of the error function\n        expected_output = scipy.special.erf(sample_values)\n\n        # Output from the erf operation in keras_core\n        output_from_erf_op = kmath.erf(sample_values)\n\n        # Assert that the outputs are close\n        self.assertAllClose(expected_output, output_from_erf_op, atol=1e-4)\n\n    def test_erf_operation_dtype(self):\n        # Test for float32 and float64 data types\n        for dtype in (\"float32\", \"float64\"):\n            sample_values = np.array(\n                [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype\n            )\n            expected_output = scipy.special.erf(sample_values)\n            output_from_erf_op = kmath.erf(sample_values)\n            self.assertAllClose(expected_output, output_from_erf_op, atol=1e-4)\n\n    def test_erf_operation_edge_cases(self):\n        # Test for edge cases\n        edge_values = np.array([1e5, -1e5, 1e-5, -1e-5], dtype=np.float64)\n        expected_output = scipy.special.erf(edge_values)\n        output_from_edge_erf_op = kmath.erf(edge_values)\n        self.assertAllClose(expected_output, output_from_edge_erf_op, atol=1e-4)\n\n    def test_erfinv_operation_basic(self):\n        # Sample values for testing\n        sample_values = np.array([-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0])\n\n        # Expected output using scipy's inverse error function\n        expected_output = scipy.special.erfinv(sample_values)\n\n        # Output from the erfinv operation in keras_core\n        output_from_erfinv_op = kmath.erfinv(sample_values)\n\n        # Assert that the outputs are close\n        self.assertAllClose(expected_output, output_from_erfinv_op, atol=1e-4)\n\n    def test_erfinv_operation_dtype(self):\n        # Test for float32 and float64 data types\n        for dtype in (\"float32\", \"float64\"):\n            sample_values 
= np.array(\n                [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0], dtype=dtype\n            )\n            expected_output = scipy.special.erfinv(sample_values)\n            output_from_erfinv_op = kmath.erfinv(sample_values)\n            self.assertAllClose(\n                expected_output, output_from_erfinv_op, atol=1e-4\n            )\n\n    def test_erfinv_operation_edge_cases(self):\n        # Test for edge cases\n        edge_values = np.array([1e5, -1e5, 1e-5, -1e-5], dtype=np.float64)\n        expected_output = scipy.special.erfinv(edge_values)\n        output_from_edge_erfinv_op = kmath.erfinv(edge_values)\n        self.assertAllClose(\n            expected_output, output_from_edge_erfinv_op, atol=1e-4\n        )\n\n    def test_logdet(self):\n        x = np.array(\n            [\n                [4.42, -1.18, 0.06, 0.74],\n                [-1.18, 1.77, -0.84, -1.16],\n                [0.06, -0.84, 5.84, 0.55],\n                [0.74, -1.16, 0.55, 0.77],\n            ],\n            dtype=\"float32\",\n        )\n        out = kmath.logdet(x)\n        self.assertAllClose(out, -1.1178946, atol=1e-3)\n\n\nclass MathDtypeTest(testing.TestCase):\n    \"\"\"Test the floating dtype to verify that the behavior matches JAX.\"\"\"\n\n    ALL_DTYPES = [\n        x\n        for x in dtypes.ALLOWED_DTYPES\n        if x\n        not in (\n            \"string\",\n            \"complex64\",\n            \"complex128\",\n            # Remove 64-bit dtypes.\n            \"float64\",\n            \"uint64\",\n            \"int64\",\n        )\n        + dtypes.FLOAT8_TYPES  # Remove float8 dtypes for the following tests\n    ] + [None]\n    INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in (\"uint64\", \"int64\")]\n    FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in (\"float64\",)]\n\n    if backend.backend() == \"torch\":\n        ALL_DTYPES = [x for x in ALL_DTYPES if x not in (\"uint16\", \"uint32\")]\n        INT_DTYPES = [x for x in INT_DTYPES if x 
not in (\"uint16\", \"uint32\")]\n\n\nclass ExtractSequencesOpTest(testing.TestCase):\n    def test_extract_sequences_init_length_1_stride_1(self):\n        extract_op = kmath.ExtractSequences(\n            sequence_length=1, sequence_stride=1\n        )\n        self.assertIsNotNone(extract_op)\n        self.assertEqual(extract_op.sequence_length, 1)\n        self.assertEqual(extract_op.sequence_stride, 1)\n\n    def test_extract_sequences_init_length_5_stride_2(self):\n        extract_op = kmath.ExtractSequences(\n            sequence_length=5, sequence_stride=2\n        )\n        self.assertIsNotNone(extract_op)\n        self.assertEqual(extract_op.sequence_length, 5)\n        self.assertEqual(extract_op.sequence_stride, 2)\n\n    def test_compute_output_spec_low_rank(self):\n        extract_op = kmath.ExtractSequences(\n            sequence_length=5, sequence_stride=1\n        )\n        low_rank_input = np.array(42)\n        error_message = r\"Input should have rank >= 1. Received: .*\"\n        with self.assertRaisesRegex(ValueError, error_message):\n            extract_op.compute_output_spec(low_rank_input)\n\n    def test_extract_sequences_call(self):\n        sequence_length, sequence_stride = 5, 2\n        extract_op = kmath.ExtractSequences(sequence_length, sequence_stride)\n        test_input = np.random.rand(10, 20)\n        result = extract_op.call(test_input)\n\n        expected_shape = self.calculate_expected_shape(\n            test_input.shape, sequence_length, sequence_stride\n        )\n        self.assertEqual(result.shape, expected_shape)\n\n    def calculate_expected_shape(\n        self, input_shape, sequence_length, sequence_stride\n    ):\n        num_sequences = (\n            (input_shape[1] - sequence_length) // sequence_stride\n        ) + 1\n        return (input_shape[0], num_sequences, sequence_length)\n\n\nclass SegmentSumTest(testing.TestCase):\n    def test_segment_sum_call(self):\n        data = np.array([[1, 2, 3], [4, 5, 6], 
[7, 8, 9]], dtype=np.float32)\n        segment_ids = np.array([0, 0, 1], dtype=np.int32)\n        num_segments = 2\n        sorted_segments = False\n        segment_sum_op = kmath.SegmentSum(\n            num_segments=num_segments, sorted=sorted_segments\n        )\n        output = segment_sum_op.call(data, segment_ids)\n        expected_output = np.array([[5, 7, 9], [7, 8, 9]], dtype=np.float32)\n        self.assertAllClose(output, expected_output)\n\n\nclass SegmentMaxTest(testing.TestCase):\n    def test_segment_max_call(self):\n        data = np.array([[1, 4, 7], [2, 5, 8], [3, 6, 9]], dtype=np.float32)\n        segment_ids = np.array([0, 0, 1], dtype=np.int32)\n        num_segments = 2\n        sorted_segments = False\n        segment_max_op = kmath.SegmentMax(\n            num_segments=num_segments, sorted=sorted_segments\n        )\n        output = segment_max_op.call(data, segment_ids)\n        expected_output = np.array([[2, 5, 8], [3, 6, 9]], dtype=np.float32)\n        self.assertAllClose(output, expected_output)\n\n\nclass TopKTest(testing.TestCase):\n    def test_top_k_call_values(self):\n        data = np.array([[1, 3, 2], [4, 6, 5]], dtype=np.float32)\n        k = 2\n        sorted_flag = True\n        top_k_op = kmath.TopK(k=k, sorted=sorted_flag)\n        values, _ = top_k_op.call(data)\n        expected_values = np.array([[3, 2], [6, 5]], dtype=np.float32)\n        self.assertAllClose(values, expected_values)\n\n    def test_top_k_call_indices(self):\n        data = np.array([[1, 3, 2], [4, 6, 5]], dtype=np.float32)\n        k = 2\n        sorted_flag = True\n        top_k_op = kmath.TopK(k=k, sorted=sorted_flag)\n        _, indices = top_k_op.call(data)\n        expected_indices = np.array([[1, 2], [1, 2]], dtype=np.int32)\n        self.assertAllClose(indices, expected_indices)\n\n\nclass InTopKTest(testing.TestCase):\n    def test_in_top_k_call(self):\n        targets = np.array([2, 0, 1], dtype=np.int32)\n        predictions = np.array(\n      
      [[0.1, 0.2, 0.7], [1.0, 0.2, 0.3], [0.2, 0.6, 0.2]],\n            dtype=np.float32,\n        )\n        k = 2\n        in_top_k_op = kmath.InTopK(k=k)\n        output = in_top_k_op.call(targets, predictions)\n        expected_output = np.array([True, True, True], dtype=bool)\n        self.assertAllEqual(output, expected_output)\n\n\nclass LogsumexpTest(testing.TestCase):\n    def test_logsumexp_call(self):\n        x = np.array([[1, 2], [3, 4]], dtype=np.float32)\n        axis = 0\n        keepdims = True\n        logsumexp_op = kmath.Logsumexp(axis=axis, keepdims=keepdims)\n        output = logsumexp_op.call(x)\n        expected_output = np.log(\n            np.sum(np.exp(x), axis=axis, keepdims=keepdims)\n        )\n        self.assertAllClose(output, expected_output)\n\n\nclass FFTTest(testing.TestCase):\n    def test_fft_input_not_tuple_or_list(self):\n        fft_op = kmath.FFT()\n        with self.assertRaisesRegex(\n            ValueError, \"Input `x` should be a tuple of two tensors\"\n        ):\n            fft_op.compute_output_spec(np.array([1, 2, 3]))\n\n    def test_fft_input_parts_different_shapes(self):\n        fft_op = kmath.FFT()\n        real = np.array([1, 2, 3])\n        imag = np.array([1, 2])\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Both the real and imaginary parts should have the same shape\",\n        ):\n            fft_op.compute_output_spec((real, imag))\n\n    def test_fft_input_not_1d(self):\n        fft_op = kmath.FFT()\n        real = np.array(1)\n        imag = np.array(1)\n        with self.assertRaisesRegex(ValueError, \"Input should have rank >= 1\"):\n            fft_op.compute_output_spec((real, imag))\n\n    def test_fft_last_axis_not_fully_defined(self):\n        fft_op = kmath.FFT()\n        real = KerasTensor(shape=(None,), dtype=\"float32\")\n        imag = KerasTensor(shape=(None,), dtype=\"float32\")\n        with self.assertRaisesRegex(\n            ValueError, \"Input should 
have its last dimension fully-defined\"\n        ):\n            fft_op.compute_output_spec((real, imag))\n\n\nclass FFT2Test(testing.TestCase):\n    def test_fft2_correct_input(self):\n        fft2_op = kmath.FFT2()\n        real_part = np.random.rand(2, 3, 4)\n        imag_part = np.random.rand(2, 3, 4)\n        # This should not raise any errors\n        fft2_op.compute_output_spec((real_part, imag_part))\n\n    def test_fft2_incorrect_input_type(self):\n        fft2_op = kmath.FFT2()\n        incorrect_input = np.array([1, 2, 3])  # Not a tuple or list\n        with self.assertRaisesRegex(\n            ValueError, \"should be a tuple of two tensors\"\n        ):\n            fft2_op.compute_output_spec(incorrect_input)\n\n    def test_fft2_mismatched_shapes(self):\n        fft2_op = kmath.FFT2()\n        real_part = np.random.rand(2, 3, 4)\n        imag_part = np.random.rand(2, 3)  # Mismatched shape\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Both the real and imaginary parts should have the same shape\",\n        ):\n            fft2_op.compute_output_spec((real_part, imag_part))\n\n    def test_fft2_low_rank(self):\n        fft2_op = kmath.FFT2()\n        low_rank_input = np.random.rand(3)  # Rank of 1\n        with self.assertRaisesRegex(ValueError, \"Input should have rank >= 2\"):\n            fft2_op.compute_output_spec((low_rank_input, low_rank_input))\n\n    def test_fft2_undefined_dimensions(self):\n        fft2_op = kmath.FFT2()\n        real_part = KerasTensor(shape=(None, None, 3), dtype=\"float32\")\n        imag_part = KerasTensor(shape=(None, None, 3), dtype=\"float32\")\n        with self.assertRaisesRegex(\n            ValueError, \"Input should have its .* axes fully-defined\"\n        ):\n            fft2_op.compute_output_spec((real_part, imag_part))\n\n\nclass RFFTTest(testing.TestCase):\n    def test_rfft_low_rank_input(self):\n        rfft_op = kmath.RFFT()\n        low_rank_input = np.array(5)\n        
with self.assertRaisesRegex(ValueError, \"Input should have rank >= 1\"):\n            rfft_op.compute_output_spec(low_rank_input)\n\n    def test_rfft_defined_fft_length(self):\n        fft_length = 10\n        rfft_op = kmath.RFFT(fft_length=fft_length)\n        input_tensor = np.random.rand(3, 8)\n\n        expected_last_dimension = fft_length // 2 + 1\n        expected_shape = input_tensor.shape[:-1] + (expected_last_dimension,)\n\n        output_tensors = rfft_op.compute_output_spec(input_tensor)\n        for output_tensor in output_tensors:\n            self.assertEqual(output_tensor.shape, expected_shape)\n\n    def test_rfft_undefined_fft_length_defined_last_dim(self):\n        rfft_op = kmath.RFFT()\n        input_tensor = np.random.rand(3, 8)\n        expected_last_dimension = input_tensor.shape[-1] // 2 + 1\n        expected_shape = input_tensor.shape[:-1] + (\n            expected_last_dimension,\n        )\n        output_tensors = rfft_op.compute_output_spec(input_tensor)\n        for output_tensor in output_tensors:\n            self.assertEqual(output_tensor.shape, expected_shape)\n\n    def test_rfft_undefined_fft_length_undefined_last_dim(self):\n        rfft_op = kmath.RFFT()\n        input_tensor = KerasTensor(shape=(None, None), dtype=\"float32\")\n        expected_shape = input_tensor.shape[:-1] + (None,)\n        output_tensors = rfft_op.compute_output_spec(input_tensor)\n        for output_tensor in output_tensors:\n            self.assertEqual(output_tensor.shape, expected_shape)\n\n\nclass ISTFTTest(testing.TestCase):\n    def test_istft_incorrect_input_type(self):\n        istft_op = kmath.ISTFT(\n            sequence_length=5, sequence_stride=2, fft_length=10\n        )\n        incorrect_input = np.array([1, 2, 3])\n        with self.assertRaisesRegex(\n            ValueError, \"should be a tuple of two tensors\"\n        ):\n            istft_op.compute_output_spec(incorrect_input)\n\n    def test_istft_mismatched_shapes(self):\n        istft_op = kmath.ISTFT(\n            sequence_length=5, sequence_stride=2, fft_length=10\n        )\n        real_part = np.random.rand(2, 3, 4)\n        imag_part = np.random.rand(2, 3)\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Both the real and imaginary parts should have the same shape\",\n        ):\n            istft_op.compute_output_spec((real_part, imag_part))\n\n    def test_istft_low_rank_input(self):\n        istft_op = kmath.ISTFT(\n            sequence_length=5, sequence_stride=2, fft_length=10\n        )\n        low_rank_input = np.random.rand(3)\n        with self.assertRaisesRegex(ValueError, \"Input should have rank >= 2\"):\n            istft_op.compute_output_spec((low_rank_input, low_rank_input))\n\n    def test_input_not_tuple_or_list_raises_error(self):\n        irfft_op = kmath.IRFFT()\n        invalid_input = np.array([1, 2, 3])\n        with self.assertRaisesRegex(\n            ValueError, \"Input `x` should be a tuple of two tensors\"\n        ):\n            irfft_op.compute_output_spec(invalid_input)\n\n    def test_input_tuple_with_less_than_two_elements_raises_error(self):\n        irfft_op = kmath.IRFFT()\n        too_short_input = (np.array([1, 2, 3]),)\n        with self.assertRaisesRegex(\n            ValueError, \"Input `x` should be a tuple of two tensors\"\n        ):\n            irfft_op.compute_output_spec(too_short_input)\n\n    def test_input_tuple_with_more_than_two_elements_raises_error(self):\n        irfft_op = kmath.IRFFT()\n        too_long_input = (\n            np.array([1, 2, 3]),\n            np.array([4, 5, 6]),\n            np.array([7, 8, 9]),\n        )\n        with self.assertRaisesRegex(\n            ValueError, \"Input `x` should be a tuple of two tensors\"\n        ):\n            irfft_op.compute_output_spec(too_long_input)\n\n    def test_mismatched_shapes_input_validation(self):\n        irfft_op = kmath.IRFFT()\n\n       
 # Create real and imaginary parts with mismatched shapes\n        real_part = np.array([1, 2, 3])\n        imag_part = np.array([[1, 2], [3, 4]])\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Both the real and imaginary parts should have the same shape\",\n        ):\n            irfft_op.compute_output_spec((real_part, imag_part))\n\n    def test_insufficient_rank_input_validation(self):\n        irfft_op = kmath.IRFFT()\n\n        # Create real and imaginary parts with insufficient rank (0D)\n        real_part = np.array(1)\n        imag_part = np.array(1)\n\n        with self.assertRaisesRegex(ValueError, \"Input should have rank >= 1\"):\n            irfft_op.compute_output_spec((real_part, imag_part))\n\n    def test_with_specified_fft_length(self):\n        fft_length = 10\n        irfft_op = kmath.IRFFT(fft_length=fft_length)\n\n        real_part = np.random.rand(4, 8)\n        imag_part = np.random.rand(4, 8)\n\n        expected_shape = real_part.shape[:-1] + (fft_length,)\n        output_shape = irfft_op.compute_output_spec(\n            (real_part, imag_part)\n        ).shape\n\n        self.assertEqual(output_shape, expected_shape)\n\n    def test_inferred_fft_length_with_defined_last_dimension(self):\n        irfft_op = kmath.IRFFT()\n\n        real_part = np.random.rand(4, 8)\n        imag_part = np.random.rand(4, 8)\n\n        inferred_fft_length = 2 * (real_part.shape[-1] - 1)\n        expected_shape = real_part.shape[:-1] + (inferred_fft_length,)\n        output_shape = irfft_op.compute_output_spec(\n            (real_part, imag_part)\n        ).shape\n\n        self.assertEqual(output_shape, expected_shape)\n\n    def test_undefined_fft_length_and_last_dimension(self):\n        irfft_op = kmath.IRFFT()\n\n        real_part = KerasTensor(shape=(4, None), dtype=\"float32\")\n        imag_part = KerasTensor(shape=(4, None), dtype=\"float32\")\n\n        output_spec = irfft_op.compute_output_spec((real_part, 
imag_part))\n        expected_shape = real_part.shape[:-1] + (None,)\n\n        self.assertEqual(output_spec.shape, expected_shape)\n\n\nclass TestMathErrors(testing.TestCase):\n    @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)])\n    @pytest.mark.skipif(\n        backend.backend() != \"jax\", reason=\"Testing Jax errors only\"\n    )\n    def test_segment_reduce_no_num_segments(self, segment_reduce_op):\n        data = jnp.array([1, 2, 3, 4])\n        segment_ids = jnp.array([0, 0, 1, 1])\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Argument `num_segments` must be set when using the JAX backend.\",\n        ):\n            segment_reduce_op(data, segment_ids)\n\n    @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)])\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"Tensorflow error only\"\n    )\n    def test_segment_reduce_sort_and_num_segments(self, segment_reduce_op):\n        data = np.array([1, 2, 3, 4])\n        segment_ids = np.array([0, 0, 1, 1])\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Argument `num_segments` cannot be set when sorted is True when \"\n            \"using the tensorflow backend.\",\n        ):\n            segment_reduce_op(data, segment_ids, num_segments=2, sorted=True)\n\n    @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)])\n    def test_segment_reduce_multi_dim_segment_ids(self, segment_reduce_op):\n        data = np.array([1, 2, 3, 4])\n        segment_ids = np.array([0, 0, 1, 1]).reshape((2, 2))\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Argument `segment_ids` should be an 1-D vector,\",\n        ):\n            segment_reduce_op(data, segment_ids)\n\n    @parameterized.parameters([(kmath.segment_sum,), (kmath.segment_max,)])\n    def test_segment_reduce_leading_not_match(self, segment_reduce_op):\n        data = np.array([])\n        
segment_ids = np.array([0, 0, 1, 1])\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Argument `segment_ids` and `data` should have same leading \"\n            \"dimension.\",\n        ):\n            segment_reduce_op(data, segment_ids)\n\n        output_tensor = segment_reduce_op(\n            KerasTensor(shape=(None, 4)), KerasTensor(shape=(5,))\n        )\n        self.assertEqual(output_tensor.shape, (None, 4))\n\n        output_tensor = segment_reduce_op(\n            KerasTensor(shape=(5, 4)), KerasTensor(shape=(None,))\n        )\n        self.assertEqual(output_tensor.shape, (None, 4))\n\n        output_tensor = segment_reduce_op(\n            KerasTensor(shape=(None, 4)), KerasTensor(shape=(None,))\n        )\n        self.assertEqual(output_tensor.shape, (None, 4))\n\n    def test_stft_invalid_input_type(self):\n        # backend agnostic error message\n        x = np.array([1, 2, 3, 4])\n        sequence_length = 2\n        sequence_stride = 1\n        fft_length = 4\n        with self.assertRaisesRegex(TypeError, \"`float32` or `float64`\"):\n            kmath.stft(x, sequence_length, sequence_stride, fft_length)\n\n    def test_invalid_fft_length(self):\n        # backend agnostic error message\n        x = np.array([1.0, 2.0, 3.0, 4.0])\n        sequence_length = 4\n        sequence_stride = 1\n        fft_length = 2\n        with self.assertRaisesRegex(ValueError, \"`fft_length` must equal or\"):\n            kmath.stft(x, sequence_length, sequence_stride, fft_length)\n\n    def test_stft_invalid_window(self):\n        # backend agnostic error message\n        x = np.array([1.0, 2.0, 3.0, 4.0])\n        sequence_length = 2\n        sequence_stride = 1\n        fft_length = 4\n        window = \"invalid_window\"\n        with self.assertRaisesRegex(ValueError, \"If a string is passed to\"):\n            kmath.stft(\n                x, sequence_length, sequence_stride, fft_length, window=window\n            )\n\n    def 
test_stft_invalid_window_shape(self):\n        # backend agnostic error message\n        x = np.array([1.0, 2.0, 3.0, 4.0])\n        sequence_length = 2\n        sequence_stride = 1\n        fft_length = 4\n        window = np.ones((sequence_length + 1))\n        with self.assertRaisesRegex(ValueError, \"The shape of `window` must\"):\n            kmath.stft(\n                x, sequence_length, sequence_stride, fft_length, window=window\n            )\n\n    @parameterized.parameters([0, -5, 1.5])\n    def test_stft_invalid_sequence_stride(self, sequence_stride):\n        x = np.array([1.0, 2.0, 3.0, 4.0])\n        sequence_length = 2\n        fft_length = 4\n        with self.assertRaisesRegex(\n            ValueError, \"`sequence_stride` must be a positive integer\"\n        ):\n            kmath.stft(x, sequence_length, sequence_stride, fft_length)\n\n    @parameterized.parameters([0, -5, 1.5])\n    def test_istft_invalid_sequence_stride(self, sequence_stride):\n        x = (np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]]))\n        sequence_length = 2\n        fft_length = 4\n        with self.assertRaisesRegex(\n            ValueError, \"`sequence_stride` must be a positive integer\"\n        ):\n            kmath.istft(x, sequence_length, sequence_stride, fft_length)\n\n    def test_istft_invalid_window_shape_2D_inputs(self):\n        # backend agnostic error message\n        x = (np.array([[1.0, 2.0]]), np.array([[3.0, 4.0]]))\n        sequence_length = 2\n        sequence_stride = 1\n        fft_length = 4\n        incorrect_window = np.ones((sequence_length + 1,))\n        with self.assertRaisesRegex(\n            ValueError, \"The shape of `window` must be equal to\"\n        ):\n            kmath.istft(\n                x,\n                sequence_length,\n                sequence_stride,\n                fft_length,\n                window=incorrect_window,\n            )\n\n\n@pytest.mark.skipif(\n    backend.backend() == \"openvino\",\n    
reason=\"Complex dtype is not supported on OpenVINO backend.\",\n)\nclass ViewAsComplexRealTest(testing.TestCase):\n    def test_view_as_complex_basic(self):\n        real_imag = np.array([[1.0, 2.0], [3.0, 4.0]])\n        expected = np.array([1.0 + 2.0j, 3.0 + 4.0j], dtype=np.complex64)\n\n        result = kmath.view_as_complex(real_imag)\n\n        self.assertEqual(result.shape, expected.shape)\n        self.assertEqual(standardize_dtype(result.dtype), expected.dtype)\n        self.assertAllClose(result, expected)\n\n    def test_view_as_real_basic(self):\n        complex_tensor = np.array([1 + 2j, 3 + 4j], dtype=np.complex64)\n        expected = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)\n\n        result = kmath.view_as_real(complex_tensor)\n\n        self.assertEqual(result.shape, expected.shape)\n        self.assertEqual(standardize_dtype(result.dtype), expected.dtype)\n        self.assertAllClose(result, expected)\n\n    def test_view_as_complex_invalid_shape(self):\n        bad_input = np.array([1.0, 2.0, 3.0])  # Last dimension not size 2\n        with self.assertRaisesRegex(\n            ValueError, \"Last dimension of input must be size 2\"\n        ):\n            kmath.view_as_complex(bad_input)\n\n    def test_view_as_complex_symbolic_input(self):\n        x = KerasTensor(shape=(None, 2), dtype=\"float32\")\n        result = kmath.view_as_complex(x)\n\n        self.assertEqual(result.shape, (None,))\n        self.assertEqual(standardize_dtype(result.dtype), \"complex64\")\n\n    def test_view_as_real_symbolic_input(self):\n        x = KerasTensor(shape=(None,), dtype=\"complex64\")\n        result = kmath.view_as_real(x)\n\n        self.assertEqual(result.shape, (None, 2))\n        self.assertEqual(standardize_dtype(result.dtype), \"float32\")\n\n    def test_view_as_complex_multi_dimensional(self):\n        x = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32)\n        expected = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex64)\n\n     
   result = kmath.view_as_complex(x)\n\n        self.assertEqual(result.shape, expected.shape)\n        self.assertEqual(standardize_dtype(result.dtype), expected.dtype)\n        self.assertAllClose(result, expected)\n\n    def test_view_as_real_multi_dimensional(self):\n        x = np.array([[1 + 2j, 3 + 4j]], dtype=np.complex64)\n        expected = np.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=np.float32)\n\n        result = kmath.view_as_real(x)\n\n        self.assertEqual(result.shape, expected.shape)\n        self.assertEqual(standardize_dtype(result.dtype), expected.dtype)\n        self.assertAllClose(result, expected)\n"
  },
  {
    "path": "keras/src/ops/nn.py",
    "content": "\"\"\"Commonly-used neural network operations not included in NumPy.\"\"\"\n\nimport warnings\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend import any_symbolic_tensors\nfrom keras.src.backend import config\nfrom keras.src.backend import standardize_data_format\nfrom keras.src.backend.common.backend_utils import (\n    compute_conv_transpose_output_shape,\n)\nfrom keras.src.ops import operation_utils\nfrom keras.src.ops.operation import Operation\nfrom keras.src.ops.operation_utils import reduce_shape\nfrom keras.src.utils.python_utils import is_continuous_axis\n\n\nclass Relu(Operation):\n    def call(self, x):\n        return backend.nn.relu(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.relu\", \"keras.ops.nn.relu\"])\ndef relu(x):\n    \"\"\"Rectified linear unit activation function.\n\n    It is defined as `f(x) = max(0, x)`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x1 = keras.ops.convert_to_tensor([-1.0, 0.0, 1.0, 0.2])\n    >>> keras.ops.relu(x1)\n    array([0.0, 0.0, 1.0, 0.2], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Relu().symbolic_call(x)\n    return backend.nn.relu(x)\n\n\nclass Relu6(Operation):\n    def call(self, x):\n        return backend.nn.relu6(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.relu6\", \"keras.ops.nn.relu6\"])\ndef relu6(x):\n    \"\"\"Rectified linear unit activation function with upper bound of 6.\n\n    It is defined as `f(x) = np.clip(x, 0, 6)`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([-3.0, -2.0, 0.1, 0.2, 6.0, 
8.0])\n    >>> keras.ops.relu6(x)\n    array([0.0, 0.0, 0.1, 0.2, 6.0, 6.0], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Relu6().symbolic_call(x)\n    return backend.nn.relu6(x)\n\n\nclass Sigmoid(Operation):\n    def call(self, x):\n        return backend.nn.sigmoid(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.sigmoid\", \"keras.ops.nn.sigmoid\"])\ndef sigmoid(x):\n    \"\"\"Sigmoid activation function.\n\n    It is defined as `f(x) = 1 / (1 + exp(-x))`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([-6.0, 1.0, 0.0, 1.0, 6.0])\n    >>> keras.ops.sigmoid(x)\n    array([0.00247262, 0.7310586, 0.5, 0.7310586, 0.9975274], dtype=float32)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Sigmoid().symbolic_call(x)\n    return backend.nn.sigmoid(x)\n\n\nclass SparseSigmoid(Operation):\n    def call(self, x):\n        return backend.nn.sparse_sigmoid(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.sparse_sigmoid\", \"keras.ops.nn.sparse_sigmoid\"])\ndef sparse_sigmoid(x):\n    \"\"\"Sparse sigmoid activation function.\n\n    It is defined as\n\n    `f(x) = 0` for `x <= -1`,\n    `f(x) = 0.5 * (x + 1)` for `-1 < x < 1`,\n    `f(x) = 1` for `x >= 1`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([-6.0, 1.0, 0.0, 1.0, 6.0])\n    >>> keras.ops.sparse_sigmoid(x)\n    array([0. , 1. , 0.5, 1. , 1. 
], dtype=float32)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return SparseSigmoid().symbolic_call(x)\n    return backend.nn.sparse_sigmoid(x)\n\n\nclass Softplus(Operation):\n    def call(self, x):\n        return backend.nn.softplus(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.softplus\", \"keras.ops.nn.softplus\"])\ndef softplus(x):\n    \"\"\"Softplus activation function.\n\n    It is defined as `f(x) = log(exp(x) + 1)`, where `log` is the natural\n    logarithm and `exp` is the exponential function.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([-0.555, 0.0, 0.555])\n    >>> keras.ops.softplus(x)\n    array([0.45366603, 0.6931472, 1.008666], dtype=float32)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Softplus().symbolic_call(x)\n    return backend.nn.softplus(x)\n\n\nclass Softsign(Operation):\n    def call(self, x):\n        return backend.nn.softsign(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.softsign\", \"keras.ops.nn.softsign\"])\ndef softsign(x):\n    \"\"\"Softsign activation function.\n\n    It is defined as `f(x) = x / (abs(x) + 1)`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([-0.100, -10.0, 1.0, 0.0, 100.0])\n    >>> keras.ops.softsign(x)\n    Array([-0.09090909, -0.90909094, 0.5, 0.0, 0.990099], dtype=float32)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Softsign().symbolic_call(x)\n    return backend.nn.softsign(x)\n\n\nclass SoftShrink(Operation):\n    def __init__(self, threshold=0.5, *, name=None):\n        super().__init__(name=name)\n        self.threshold = threshold\n\n    def call(self, x):\n   
     return backend.nn.soft_shrink(x, self.threshold)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.soft_shrink\", \"keras.ops.nn.soft_shrink\"])\ndef soft_shrink(x, threshold=0.5):\n    \"\"\"Soft Shrink activation function.\n\n    It is defined as\n\n    `f(x) = x - threshold` if `x > threshold`,\n    `f(x) = x + threshold` if `x < -threshold`,\n    `f(x) = 0` otherwise.\n\n    Args:\n        x: Input tensor.\n        threshold: Threshold value. Defaults to 0.5.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1.0, 0.0, 1.0])\n    >>> x_soft_shrink = keras.ops.soft_shrink(x)\n    >>> print(x_soft_shrink)\n    array([-0.5  0.   0.5], shape=(3,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return SoftShrink(threshold).symbolic_call(x)\n    return backend.nn.soft_shrink(x, threshold)\n\n\nclass SparsePlus(Operation):\n    def call(self, x):\n        return backend.nn.sparse_plus(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.sparse_plus\", \"keras.ops.nn.sparse_plus\"])\ndef sparse_plus(x):\n    \"\"\"SparsePlus activation function.\n\n    It is defined as\n\n    `f(x) = 0` for `x <= -1`.\n    `f(x) = (1/4) * (x + 1)^2` for `-1 < x < 1`.\n    `f(x) = x` for `x >= 1`.\n\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1.0, 0.0, 1.0])\n    >>> x_sparse_plus = keras.ops.sparse_plus(x)\n    >>> print(x_sparse_plus)\n    Array([0.   0.25 1.  
], shape=(3,), dtype=float32)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return SparsePlus().symbolic_call(x)\n    return backend.nn.sparse_plus(x)\n\n\nclass Silu(Operation):\n    def call(self, x):\n        return backend.nn.silu(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.silu\",\n        \"keras.ops.nn.silu\",\n        \"keras.ops.swish\",\n        \"keras.ops.nn.swish\",\n    ]\n)\ndef silu(x):\n    \"\"\"Sigmoid Linear Unit (SiLU) activation function, also known as Swish.\n\n    The SiLU activation function is computed by the sigmoid function multiplied\n    by its input. It is defined as `f(x) = x * sigmoid(x)`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([-6.0, 1.0, 0.0, 1.0, 6.0])\n    >>> keras.ops.sigmoid(x)\n    array([0.00247262, 0.7310586, 0.5, 0.7310586, 0.9975274], dtype=float32)\n    >>> keras.ops.silu(x)\n    array([-0.0148357, 0.7310586, 0.0, 0.7310586, 5.9851646], dtype=float32)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Silu().symbolic_call(x)\n    return backend.nn.silu(x)\n\n\nclass Squareplus(Operation):\n    def __init__(self, b=4, *, name=None):\n        super().__init__(name=name)\n        self.b = b\n\n    def call(self, x):\n        return backend.nn.squareplus(x, self.b)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.squareplus\", \"keras.ops.nn.squareplus\"])\ndef squareplus(x, b=4):\n    \"\"\"Squareplus activation function.\n\n    The Squareplus activation function is defined as:\n\n    `f(x) = (x + sqrt(x^2 + b)) / 2`\n\n    Args:\n        x: Input tensor.\n        b: Smoothness parameter. 
Defaults to 4.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1.0, 0.0, 1.0])\n    >>> x_squareplus = keras.ops.squareplus(x)\n    >>> print(x_squareplus)\n    array([0.6180, 1.0000, 1.6180], dtype=float32)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Squareplus(b).symbolic_call(x)\n    return backend.nn.squareplus(x, b)\n\n\nclass LogSigmoid(Operation):\n    def call(self, x):\n        return backend.nn.log_sigmoid(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.log_sigmoid\",\n        \"keras.ops.nn.log_sigmoid\",\n    ]\n)\ndef log_sigmoid(x):\n    \"\"\"Logarithm of the sigmoid activation function.\n\n    It is defined as `f(x) = log(1 / (1 + exp(-x)))`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([-0.541391, 0.0, 0.50, 5.0])\n    >>> keras.ops.log_sigmoid(x)\n    array([-1.0000418, -0.6931472, -0.474077, -0.00671535], dtype=float32)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return LogSigmoid().symbolic_call(x)\n    return backend.nn.log_sigmoid(x)\n\n\nclass LeakyRelu(Operation):\n    def __init__(self, negative_slope=0.2, *, name=None):\n        super().__init__(name=name)\n        self.negative_slope = negative_slope\n\n    def call(self, x):\n        return backend.nn.leaky_relu(x, self.negative_slope)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.leaky_relu\", \"keras.ops.nn.leaky_relu\"])\ndef leaky_relu(x, negative_slope=0.2):\n    \"\"\"Leaky version of a Rectified Linear Unit activation function.\n\n    It allows a small gradient when the unit is not active. It is defined as:\n\n    `f(x) = negative_slope * x for x < 0` or `f(x) = x for x >= 0`.\n\n    Args:\n        x: Input 
tensor.\n        negative_slope: Slope of the activation function at x < 0.\n            Defaults to `0.2`.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1., 0., 1.])\n    >>> x_leaky_relu = keras.ops.leaky_relu(x)\n    >>> print(x_leaky_relu)\n    array([-0.2,  0. ,  1. ], shape=(3,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return LeakyRelu(negative_slope).symbolic_call(x)\n    return backend.nn.leaky_relu(x, negative_slope=negative_slope)\n\n\nclass HardSigmoid(Operation):\n    def call(self, x):\n        return backend.nn.hard_sigmoid(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.hard_sigmoid\",\n        \"keras.ops.nn.hard_sigmoid\",\n    ]\n)\ndef hard_sigmoid(x):\n    \"\"\"Hard sigmoid activation function.\n\n    It is defined as:\n\n    `0 if x < -2.5`, `1 if x > 2.5`, `(0.2 * x) + 0.5 if -2.5 <= x <= 2.5`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1., 0., 1.])\n    >>> x_hard_sigmoid = keras.ops.hard_sigmoid(x)\n    >>> print(x_hard_sigmoid)\n    array([0.3, 0.5, 0.7], shape=(3,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return HardSigmoid().symbolic_call(x)\n    return backend.nn.hard_sigmoid(x)\n\n\nclass HardSilu(Operation):\n    def call(self, x):\n        return backend.nn.hard_silu(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.hard_silu\",\n        \"keras.ops.nn.hard_silu\",\n        \"keras.ops.hard_swish\",\n        \"keras.ops.nn.hard_swish\",\n    ]\n)\ndef hard_silu(x):\n    \"\"\"Hard SiLU activation function, also known as Hard Swish.\n\n    It is defined as:\n\n    - `0` if `x < -3`\n    - `x` if `x > 3`\n    - `x * (x + 3) 
/ 6` if `-3 <= x <= 3`\n\n    It's a faster, piecewise linear approximation of the silu activation.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([-3.0, -1.0, 0.0, 1.0, 3.0])\n    >>> keras.ops.hard_silu(x)\n    array([-0.0, -0.3333333, 0.0, 0.6666667, 3.0], shape=(5,), dtype=float32)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return HardSilu().symbolic_call(x)\n    return backend.nn.hard_silu(x)\n\n\nclass Elu(Operation):\n    def __init__(self, alpha=1.0, *, name=None):\n        super().__init__(name=name)\n        self.alpha = alpha\n\n    def call(self, x):\n        return backend.nn.elu(x, alpha=self.alpha)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.elu\", \"keras.ops.nn.elu\"])\ndef elu(x, alpha=1.0):\n    \"\"\"Exponential Linear Unit activation function.\n\n    It is defined as:\n\n    `f(x) =  alpha * (exp(x) - 1.) for x < 0`, `f(x) = x for x >= 0`.\n\n    Args:\n        x: Input tensor.\n        alpha: A scalar, slope of negative section. Defaults to `1.0`.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1., 0., 1.])\n    >>> x_elu = keras.ops.elu(x)\n    >>> print(x_elu)\n    array([-0.63212055, 0., 1.], shape=(3,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Elu(alpha).symbolic_call(x)\n    return backend.nn.elu(x, alpha=alpha)\n\n\nclass Selu(Operation):\n    def call(self, x):\n        return backend.nn.selu(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.selu\", \"keras.ops.nn.selu\"])\ndef selu(x):\n    \"\"\"Scaled Exponential Linear Unit (SELU) activation function.\n\n    It is defined as:\n\n    `f(x) =  scale * alpha * (exp(x) - 1.) 
for x < 0`,\n    `f(x) = scale * x for x >= 0`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1., 0., 1.])\n    >>> x_selu = keras.ops.selu(x)\n    >>> print(x_selu)\n    array([-1.11133055, 0., 1.05070098], shape=(3,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Selu().symbolic_call(x)\n    return backend.nn.selu(x)\n\n\nclass Gelu(Operation):\n    def __init__(self, approximate=True, *, name=None):\n        super().__init__(name=name)\n        self.approximate = approximate\n\n    def call(self, x):\n        return backend.nn.gelu(x, self.approximate)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.gelu\", \"keras.ops.nn.gelu\"])\ndef gelu(x, approximate=True):\n    \"\"\"Gaussian Error Linear Unit (GELU) activation function.\n\n    If `approximate` is `True`, it is defined as:\n    `f(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`\n\n    Or if `approximate` is `False`, it is defined as:\n    `f(x) = x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2)))`,\n    where `P(X) ~ N(0, 1)`.\n\n    Args:\n        x: Input tensor.\n        approximate: Approximate version of GELU activation. 
Defaults to `True`.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1., 0., 1.])\n    >>> x_gelu = keras.ops.gelu(x)\n    >>> print(x_gelu)\n    array([-0.15865525, 0., 0.84134475], shape=(3,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Gelu(approximate).symbolic_call(x)\n    return backend.nn.gelu(x, approximate)\n\n\nclass Celu(Operation):\n    def __init__(self, alpha=1.0, *, name=None):\n        super().__init__(name=name)\n        self.alpha = alpha\n\n    def call(self, x):\n        return backend.nn.celu(x, self.alpha)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.celu\", \"keras.ops.nn.celu\"])\ndef celu(x, alpha=1.0):\n    \"\"\"Continuously-differentiable exponential linear unit.\n\n    It is defined as:\n\n    `f(x) =  alpha * (exp(x / alpha) - 1) for x < 0`, `f(x) = x for x >= 0`.\n\n    Args:\n        x: Input tensor.\n        alpha: the α value for the CELU formulation. Defaults to `1.0`.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1., 0., 1.])\n    >>> x_celu = keras.ops.celu(x)\n    >>> print(x_celu)\n    array([-0.63212056, 0. , 1. ], shape=(3,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Celu(alpha).symbolic_call(x)\n    return backend.nn.celu(x, alpha)\n\n\nclass Glu(Operation):\n    def __init__(self, axis=-1, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n\n    def call(self, x):\n        return backend.nn.glu(x, axis=self.axis)\n\n    def compute_output_spec(self, x):\n        output_shape = list(x.shape)\n        if output_shape[self.axis] is not None:\n            if output_shape[self.axis] % 2 != 0:\n                raise ValueError(\n                    \"axis size must be divisible by 2. 
\"\n                    f\"Received: x.shape={x.shape} with axis={self.axis}\"\n                )\n            output_shape[self.axis] = output_shape[self.axis] // 2\n        return KerasTensor(output_shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.glu\", \"keras.ops.nn.glu\"])\ndef glu(x, axis=-1):\n    \"\"\"Gated Linear Unit (GLU) activation function.\n\n    It is defined as:\n\n    `f(x) = a * sigmoid(b)`\n    where `x` is split into `a` and `b` along the given axis.\n\n    Args:\n        x: Input tensor.\n        axis: The axis along which to split the input tensor. Defaults to `-1`.\n\n    Returns:\n        A tensor with the same shape as half of the input.\n\n    Example:\n\n    >>> x = np.array([-1., 0., 1. , 1.])\n    >>> x_glu = keras.ops.glu(x)\n    >>> print(x_glu)\n    array([-0.73105858, 0. ], shape=(2,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Glu(axis).symbolic_call(x)\n    return backend.nn.glu(x, axis=axis)\n\n\nclass TanhShrink(Operation):\n    def call(self, x):\n        return backend.nn.tanh_shrink(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.tanh_shrink\", \"keras.ops.nn.tanh_shrink\"])\ndef tanh_shrink(x):\n    \"\"\"Applies the tanh shrink function element-wise.\n\n    It is defined as:\n\n    `f(x) = x - tanh(x)`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor of the same shape as `x`, where each element is\n        transformed according to the tanh shrink operation.\n\n    Example:\n\n    >>> x = np.array([ -1., 0., 1.])\n    >>> x_tanh_shrink = keras.ops.tanh_shrink(x)\n    >>> print(x_tanh_shrink)\n    array([-0.23840584  0.  
0.23840584], shape=(3,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return TanhShrink().symbolic_call(x)\n    return backend.nn.tanh_shrink(x)\n\n\nclass HardTanh(Operation):\n    def call(self, x):\n        return backend.nn.hard_tanh(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.hard_tanh\", \"keras.ops.nn.hard_tanh\"])\ndef hard_tanh(x):\n    \"\"\"Applies the HardTanh function element-wise.\n\n    It is defined as:\n\n    `f(x) = -1 for x < -1`, `f(x) = x for -1 <= x <= 1`, `f(x) = 1 for x > 1`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor of same shape as `x`\n        where values are clamped between -1 and 1.\n\n    Example:\n\n    >>> x = np.array([-2., -1., 0., 1., 2.])\n    >>> x_hard_tanh = keras.ops.hard_tanh(x)\n    >>> print(x_hard_tanh)\n    array([-1. -1.  0.  1.  1.], shape=(5,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return HardTanh().symbolic_call(x)\n    return backend.nn.hard_tanh(x)\n\n\nclass HardShrink(Operation):\n    def __init__(self, threshold=0.5, *, name=None):\n        super().__init__(name=name)\n        self.threshold = threshold\n\n    def call(self, x):\n        return backend.nn.hard_shrink(x, self.threshold)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.hard_shrink\", \"keras.ops.nn.hard_shrink\"])\ndef hard_shrink(x, threshold=0.5):\n    \"\"\"Hard Shrink activation function.\n\n    The Hard Shrink function is a thresholding operation defined as:\n\n    `f(x) = x` if `|x| > threshold`,\n    `f(x) = 0` otherwise.\n\n    Args:\n        x: Input tensor.\n        threshold: Threshold value. 
Defaults to 0.5.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-0.5, 0., 1.])\n    >>> x_hard_shrink = keras.ops.hard_shrink(x)\n    >>> print(x_hard_shrink)\n    array([0. 0. 1.], shape=(3,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return HardShrink(threshold).symbolic_call(x)\n    return backend.nn.hard_shrink(x, threshold)\n\n\nclass Threshold(Operation):\n    def __init__(self, threshold, default_value, *, name=None):\n        super().__init__(name=name)\n        self.threshold = threshold\n        self.default_value = default_value\n\n    def call(self, x):\n        return backend.nn.threshold(x, self.threshold, self.default_value)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.threshold\", \"keras.ops.nn.threshold\"])\ndef threshold(x, threshold, default_value):\n    \"\"\"Threshold activation function.\n\n    The function thresholds the input `x` as follows:\n    `f(x) = x` if `x > threshold`,\n    `f(x) = default_value` otherwise.\n\n    Args:\n        x: Input tensor.\n        threshold: The value that decides when to retain or replace x.\n        default_value: Value to assign when `x <= threshold`.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1.0, 0.0, 1.0, 2.0])\n    >>> x_threshold = keras.ops.threshold(x, 1, 0)\n    >>> print(x_threshold)\n    array([0., 0., 0., 2.], shape=(4,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Threshold(threshold, default_value).symbolic_call(x)\n    return backend.nn.threshold(x, threshold, default_value)\n\n\nclass Softmax(Operation):\n    def __init__(self, axis=-1, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n\n    def call(self, x):\n        return backend.nn.softmax(x, axis=self.axis)\n\n    def compute_output_spec(self, 
x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.softmax\", \"keras.ops.nn.softmax\"])\ndef softmax(x, axis=-1):\n    \"\"\"Softmax activation function.\n\n    The elements of the output vector lie within the range `(0, 1)`, and their\n    total sum is exactly 1 (excluding the floating point rounding error).\n\n    Each vector is processed independently. The `axis` argument specifies the\n    axis along which the function is applied within the input.\n\n    It is defined as:\n    `f(x) = exp(x) / sum(exp(x))`\n\n    Args:\n        x: Input tensor.\n        axis: Integer, axis along which the softmax is applied.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1., 0., 1.])\n    >>> x_softmax = keras.ops.softmax(x)\n    >>> print(x_softmax)\n    array([0.09003057, 0.24472847, 0.66524096], shape=(3,), dtype=float64)\n\n    \"\"\"\n    # Don't use `backend.shape` since TensorFlow returns\n    # symbolic tensors for unknown shape which can trigger\n    # an error in TensorFlow graph execution.\n    if isinstance(axis, int) and x.shape[axis] == 1:\n        warnings.warn(\n            f\"You are using a softmax over axis {axis} \"\n            f\"of a tensor of shape {x.shape}. This axis \"\n            \"has size 1. The softmax operation will always return \"\n            \"the value 1, which is likely not what you intended. 
\"\n            \"Did you mean to use a sigmoid instead?\"\n        )\n    if any_symbolic_tensors((x,)):\n        return Softmax(axis).symbolic_call(x)\n    if isinstance(axis, tuple):\n        axis_to_keep = [v for v in range(len(x.shape)) if v not in axis]\n\n        x_transposed = backend.numpy.transpose(x, axes=(*axis_to_keep, *axis))\n        x_reshaped = backend.numpy.reshape(\n            x_transposed, (*[x.shape[v] for v in axis_to_keep], -1)\n        )\n\n        x = backend.nn.softmax(x_reshaped, axis=-1)\n\n        x = backend.numpy.reshape(x, x_transposed.shape)\n        combined = [*axis_to_keep, *axis]\n        x = backend.numpy.transpose(\n            x,\n            axes=sorted(range(len(combined)), key=combined.__getitem__),\n        )\n        return x\n    else:\n        return backend.nn.softmax(x, axis=axis)\n\n\nclass LogSoftmax(Operation):\n    def __init__(self, axis=-1, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n\n    def call(self, x):\n        return backend.nn.log_softmax(x, axis=self.axis)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.log_softmax\",\n        \"keras.ops.nn.log_softmax\",\n    ]\n)\ndef log_softmax(x, axis=-1):\n    \"\"\"Log-softmax activation function.\n\n    It is defined as:\n    `f(x) = x - max(x) - log(sum(exp(x - max(x))))`\n\n    Args:\n        x: Input tensor.\n        axis: Integer, axis along which the log-softmax is applied.\n            Defaults to `-1`.\n\n    Returns:\n        A tensor with the same shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1., 0., 1.])\n    >>> x_log_softmax = keras.ops.log_softmax(x)\n    >>> print(x_log_softmax)\n    array([-2.40760596, -1.40760596, -0.40760596], shape=(3,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return LogSoftmax(axis).symbolic_call(x)\n    if isinstance(axis, tuple):\n        axis_to_keep = 
[v for v in range(len(x.shape)) if v not in axis]\n\n        x_transposed = backend.numpy.transpose(x, axes=(*axis_to_keep, *axis))\n        x_reshaped = backend.numpy.reshape(\n            x_transposed, (*[x.shape[v] for v in axis_to_keep], -1)\n        )\n\n        x = backend.nn.log_softmax(x_reshaped, axis=-1)\n\n        x = backend.numpy.reshape(x, x_transposed.shape)\n        combined = [*axis_to_keep, *axis]\n        x = backend.numpy.transpose(\n            x,\n            axes=sorted(range(len(combined)), key=combined.__getitem__),\n        )\n        return x\n    else:\n        return backend.nn.log_softmax(x, axis=axis)\n\n\nclass Sparsemax(Operation):\n    def __init__(self, axis=-1, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n\n    def call(self, x):\n        return backend.nn.sparsemax(x, axis=self.axis)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.sparsemax\", \"keras.ops.nn.sparsemax\"])\ndef sparsemax(x, axis=-1):\n    \"\"\"Sparsemax activation function.\n\n    For each batch `i`, and class `j`,\n    sparsemax activation function is defined as:\n\n    `sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0).`\n\n    Args:\n        x: Input tensor.\n        axis: `int`, axis along which the sparsemax operation is applied.\n\n    Returns:\n        A tensor, output of sparsemax transformation. 
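Both tuple-axis branches above (`softmax` and `log_softmax`) use the same trick: move the reduction axes to the end, flatten them into one, apply the single-axis op, then undo the reshape and transpose via the inverse permutation computed with `sorted(range(n), key=combined.__getitem__)`. A standalone NumPy sketch of that logic (illustrative only; `softmax_multi_axis` is a hypothetical helper, not the backend implementation):

```python
import numpy as np

def softmax_multi_axis(x, axis):
    """Softmax over a tuple of axes via transpose + reshape + inverse permutation."""
    keep = [d for d in range(x.ndim) if d not in axis]
    combined = [*keep, *axis]
    xt = np.transpose(x, combined)                      # reduction axes moved last
    flat = xt.reshape(*[x.shape[d] for d in keep], -1)  # flatten them into one axis
    e = np.exp(flat - flat.max(axis=-1, keepdims=True))
    out = (e / e.sum(axis=-1, keepdims=True)).reshape(xt.shape)
    inv = sorted(range(len(combined)), key=combined.__getitem__)
    return np.transpose(out, inv)                       # undo the permutation
```

Because the flattened softmax normalizes jointly over the selected axes, the result sums to 1 over exactly those axes of the original layout.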
Has the same type and\n        shape as `x`.\n\n    Example:\n\n    >>> x = np.array([-1., 0., 1.])\n    >>> x_sparsemax = keras.ops.sparsemax(x)\n    >>> print(x_sparsemax)\n    array([0., 0., 1.], shape=(3,), dtype=float64)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Sparsemax(axis).symbolic_call(x)\n    return backend.nn.sparsemax(x, axis=axis)\n\n\nclass MaxPool(Operation):\n    def __init__(\n        self,\n        pool_size,\n        strides=None,\n        padding=\"valid\",\n        data_format=None,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.pool_size = pool_size\n        self.strides = strides\n        self.padding = padding.lower()\n        self.data_format = data_format\n\n    def call(self, inputs):\n        return backend.nn.max_pool(\n            inputs,\n            self.pool_size,\n            self.strides,\n            self.padding,\n            self.data_format,\n        )\n\n    def compute_output_spec(self, inputs):\n        output_shape = operation_utils.compute_pooling_output_shape(\n            inputs.shape,\n            self.pool_size,\n            self.strides,\n            self.padding,\n            self.data_format,\n        )\n        return KerasTensor(output_shape, dtype=inputs.dtype)\n\n\n@keras_export([\"keras.ops.max_pool\", \"keras.ops.nn.max_pool\"])\ndef max_pool(\n    inputs,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n):\n    \"\"\"Max pooling operation.\n\n    Args:\n        inputs: Tensor of rank N+2. `inputs` has shape\n            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if\n            `data_format=\"channels_last\"`, or\n            `(batch_size, num_channels) + inputs_spatial_shape` if\n            `data_format=\"channels_first\"`. 
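The sparsemax threshold `τ(x)` above is chosen so that the surviving entries sum to 1, which is what produces exactly-zero probabilities. The standard algorithm sorts the inputs, finds the support size `k`, and derives `τ` from the cumulative sum. A self-contained NumPy sketch over the last axis (an illustration of the math, not the backend code; `sparsemax_1d` is a hypothetical name):

```python
import numpy as np

def sparsemax_1d(x):
    """Sparsemax over the last axis: max(x - tau, 0) with tau making outputs sum to 1."""
    z = np.sort(x, axis=-1)[..., ::-1]            # sorted descending
    k = np.arange(1, z.shape[-1] + 1)
    cssv = np.cumsum(z, axis=-1)
    support = 1 + k * z > cssv                    # entries inside the support set
    k_max = support.sum(axis=-1, keepdims=True)   # support size k(x)
    tau = (np.take_along_axis(cssv, k_max - 1, axis=-1) - 1) / k_max
    return np.maximum(x - tau, 0.0)
```

On the docstring's example `[-1., 0., 1.]` this reproduces the documented output `[0., 0., 1.]`: the support collapses to the single largest entry.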
Pooling happens over the spatial\n            dimensions only.\n        pool_size: int or tuple/list of integers of size\n            `len(inputs_spatial_shape)`, specifying the size of the pooling\n            window for each spatial dimension of the input tensor. If\n            `pool_size` is int, then every spatial dimension shares the same\n            `pool_size`.\n        strides: int or tuple/list of integers of size\n            `len(inputs_spatial_shape)`. The stride of the sliding window for\n            each spatial dimension of the input tensor. If `strides` is int,\n            then every spatial dimension shares the same `strides`.\n        padding: string, either `\"valid\"` or `\"same\"`. `\"valid\"` means no\n            padding is applied, and `\"same\"` results in padding evenly to the\n            left/right or up/down of the input such that output has the\n            same height/width dimension as the input when `strides=1`.\n        data_format: A string, either `\"channels_last\"` or `\"channels_first\"`.\n            `data_format` determines the ordering of the dimensions in the\n            inputs. 
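The `"valid"`/`"same"` semantics just described reduce to simple per-dimension output-length arithmetic, which `operation_utils.compute_pooling_output_shape` applies to each spatial axis. A small sketch of the standard formulas (illustrative; the helper in `operation_utils` remains the source of truth):

```python
import math

def pooled_length(n, pool_size, stride, padding):
    """Output length of one spatial dimension after pooling."""
    if padding == "valid":
        # window must fit entirely inside the unpadded input
        return (n - pool_size) // stride + 1
    elif padding == "same":
        # input is padded so the output only depends on n and the stride
        return math.ceil(n / stride)
    raise ValueError(f"Unknown padding: {padding}")
```

For example, a length-5 axis with `pool_size=2, stride=2` yields 2 under `"valid"` but 3 under `"same"`.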
If `data_format=\"channels_last\"`, `inputs` is of shape\n            `(batch_size, ..., channels)` while if\n            `data_format=\"channels_first\"`, `inputs` is of shape\n            `(batch_size, channels, ...)`.\n\n    Returns:\n        A tensor of rank N+2, the result of the max pooling operation.\n    \"\"\"\n    data_format = standardize_data_format(data_format)\n    padding = padding.lower()\n    if any_symbolic_tensors((inputs,)):\n        return MaxPool(\n            pool_size,\n            strides,\n            padding,\n            data_format,\n        ).symbolic_call(inputs)\n    return backend.nn.max_pool(inputs, pool_size, strides, padding, data_format)\n\n\nclass AdaptiveMaxPool(Operation):\n    \"\"\"Adaptive max pooling operation.\"\"\"\n\n    def __init__(self, output_size, data_format=None, *, name=None):\n        super().__init__(name=name)\n        self.output_size = output_size\n        self.data_format = data_format\n\n    def call(self, inputs):\n        return backend.nn.adaptive_max_pool(\n            inputs, output_size=self.output_size, data_format=self.data_format\n        )\n\n    def compute_output_spec(self, inputs):\n        if self.data_format == \"channels_last\":\n            spatial_dims = self.output_size\n            output_shape = (\n                inputs.shape[: -len(self.output_size)]\n                + spatial_dims\n                + (inputs.shape[-1],)\n            )\n        else:\n            spatial_dims = self.output_size\n            output_shape = (inputs.shape[0], inputs.shape[1]) + spatial_dims\n        return backend.KerasTensor(output_shape, dtype=inputs.dtype)\n\n\n@keras_export([\"keras.ops.adaptive_max_pool\", \"keras.ops.nn.adaptive_max_pool\"])\ndef adaptive_max_pool(\n    inputs,\n    output_size,\n    data_format=None,\n):\n    \"\"\"Adaptive max pooling operation.\n\n    Applies an adaptive max pooling operation that automatically computes the\n    kernel size and stride to pool the input to the 
specified `output_size`.\n    This operation is useful when you want a fixed output size regardless of\n    input size, commonly used in models like ResNet for global feature\n    extraction.\n\n    Args:\n        inputs: Tensor of rank 4. Input tensor of shape:\n            - If `data_format=\"channels_last\"`:\n                `(batch_size, height, width, channels)`.\n            - If `data_format=\"channels_first\"`:\n                `(batch_size, channels, height, width)`.\n        output_size: Integer or tuple/list of 2 integers, specifying the target\n            output spatial dimensions `(output_height, output_width)`. If a\n            single integer is provided, the same value is used for both\n            dimensions.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            Defaults to the value found in your Keras config file at\n            `~/.keras/keras.json`. If never set, defaults to `\"channels_last\"`.\n\n    Returns:\n        A tensor of rank 4 representing the adaptive max pooled result.\n\n    Example:\n\n    >>> x = np.random.rand(2, 64, 64, 3)\n    >>> y = keras.ops.adaptive_max_pool(x, output_size=(32, 32))\n    >>> y.shape\n    (2, 32, 32, 3)\n\n    >>> # Works with any input size\n    >>> x = np.random.rand(2, 100, 80, 3)\n    >>> y = keras.ops.adaptive_max_pool(x, output_size=7)\n    >>> y.shape\n    (2, 7, 7, 3)\n    \"\"\"\n    if data_format is None:\n        data_format = config.image_data_format()\n\n    if any_symbolic_tensors((inputs,)):\n        return AdaptiveMaxPool(output_size, data_format).symbolic_call(inputs)\n\n    return backend.nn.adaptive_max_pool(\n        inputs, output_size=output_size, data_format=data_format\n    )\n\n\nclass AveragePool(Operation):\n    def __init__(\n        self,\n        pool_size,\n        strides=None,\n        padding=\"valid\",\n        data_format=None,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.pool_size = 
pool_size\n        self.strides = strides\n        self.padding = padding.lower()\n        self.data_format = data_format\n\n    def call(self, inputs):\n        return backend.nn.average_pool(\n            inputs,\n            self.pool_size,\n            self.strides,\n            self.padding,\n            self.data_format,\n        )\n\n    def compute_output_spec(self, inputs):\n        output_shape = operation_utils.compute_pooling_output_shape(\n            inputs.shape,\n            self.pool_size,\n            self.strides,\n            self.padding,\n            self.data_format,\n        )\n        return KerasTensor(output_shape, dtype=inputs.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.average_pool\",\n        \"keras.ops.nn.average_pool\",\n    ]\n)\ndef average_pool(\n    inputs,\n    pool_size,\n    strides=None,\n    padding=\"valid\",\n    data_format=None,\n):\n    \"\"\"Average pooling operation.\n\n    Args:\n        inputs: Tensor of rank N+2. `inputs` has shape\n            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if\n            `data_format=\"channels_last\"`, or\n            `(batch_size, num_channels) + inputs_spatial_shape` if\n            `data_format=\"channels_first\"`. Pooling happens over the spatial\n            dimensions only.\n        pool_size: int or tuple/list of integers of size\n            `len(inputs_spatial_shape)`, specifying the size of the pooling\n            window for each spatial dimension of the input tensor. If\n            `pool_size` is int, then every spatial dimension shares the same\n            `pool_size`.\n        strides: int or tuple/list of integers of size\n            `len(inputs_spatial_shape)`. The stride of the sliding window for\n            each spatial dimension of the input tensor. If `strides` is int,\n            then every spatial dimension shares the same `strides`.\n        padding: string, either `\"valid\"` or `\"same\"`. 
`\"valid\"` means no\n            padding is applied, and `\"same\"` results in padding evenly to the\n            left/right or up/down of the input such that output has the\n            same height/width dimension as the input when `strides=1`.\n        data_format: A string, either `\"channels_last\"` or `\"channels_first\"`.\n            `data_format` determines the ordering of the dimensions in the\n            inputs. If `data_format=\"channels_last\"`, `inputs` is of shape\n            `(batch_size, ..., channels)` while if\n            `data_format=\"channels_first\"`, `inputs` is of shape\n            `(batch_size, channels, ...)`.\n\n    Returns:\n        A tensor of rank N+2, the result of the average pooling operation.\n    \"\"\"\n    data_format = standardize_data_format(data_format)\n    padding = padding.lower()\n    if any_symbolic_tensors((inputs,)):\n        return AveragePool(\n            pool_size,\n            strides,\n            padding,\n            data_format,\n        ).symbolic_call(inputs)\n    return backend.nn.average_pool(\n        inputs, pool_size, strides, padding, data_format\n    )\n\n\nclass AdaptiveAveragePool(Operation):\n    \"\"\"Adaptive average pooling operation.\"\"\"\n\n    def __init__(self, output_size, data_format=None, *, name=None):\n        super().__init__(name=name)\n        self.output_size = output_size\n        self.data_format = data_format\n\n    def call(self, inputs):\n        return backend.nn.adaptive_average_pool(\n            inputs, output_size=self.output_size, data_format=self.data_format\n        )\n\n    def compute_output_spec(self, inputs):\n        if self.data_format == \"channels_last\":\n            spatial_dims = self.output_size\n            output_shape = (\n                inputs.shape[: -len(self.output_size)]\n                + spatial_dims\n                + (inputs.shape[-1],)\n            )\n        else:\n            spatial_dims = self.output_size\n            output_shape = 
(inputs.shape[0], inputs.shape[1]) + spatial_dims\n        return backend.KerasTensor(output_shape, dtype=inputs.dtype)\n\n\n@keras_export(\n    [\"keras.ops.adaptive_average_pool\", \"keras.ops.nn.adaptive_average_pool\"]\n)\ndef adaptive_average_pool(\n    inputs,\n    output_size,\n    data_format=None,\n):\n    \"\"\"Adaptive average pooling operation.\n\n    Applies an adaptive average pooling operation that automatically\n    computes the kernel size and stride to pool the input to the\n    specified `output_size`. This operation is useful when you want a\n    fixed output size regardless of input size, commonly used in models\n    like ResNet for global feature extraction.\n\n    Args:\n        inputs: Tensor of rank 4. Input tensor of shape:\n            - If `data_format=\"channels_last\"`:\n                `(batch_size, height, width, channels)`.\n            - If `data_format=\"channels_first\"`:\n                `(batch_size, channels, height, width)`.\n        output_size: Integer or tuple/list of 2 integers, specifying the target\n            output spatial dimensions `(output_height, output_width)`. If a\n            single\n            integer is provided, the same value is used for both dimensions.\n        data_format: string, either `\"channels_last\"` or `\"channels_first\"`.\n            Defaults to the value found in your Keras config file at\n            `~/.keras/keras.json`. 
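Adaptive pooling derives a (possibly varying) window per output position instead of a fixed kernel and stride. A common formulation, also used by PyTorch-style adaptive pooling, takes `start = floor(i * n / out)` and `end = ceil((i + 1) * n / out)` for output index `i`. A 1-D NumPy sketch of that idea (hedged: `adaptive_avg_pool_1d` is an illustrative helper, not the backend implementation):

```python
import math
import numpy as np

def adaptive_avg_pool_1d(x, output_size):
    """Average-pool a 1-D array down to exactly `output_size` entries."""
    n = x.shape[0]
    out = np.empty(output_size)
    for i in range(output_size):
        start = math.floor(i * n / output_size)     # window start (inclusive)
        end = math.ceil((i + 1) * n / output_size)  # window end (exclusive)
        out[i] = x[start:end].mean()
    return out
```

Note the windows always cover the whole input and may overlap when `n` is not a multiple of `output_size`, which is why any input size maps to the requested output size.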
If never set, defaults to `\"channels_last\"`.\n\n    Returns:\n        A tensor of rank 4 representing the adaptive average pooled result.\n\n    Example:\n\n    >>> x = np.random.rand(2, 64, 64, 3)\n    >>> y = keras.ops.adaptive_average_pool(x, output_size=(32, 32))\n    >>> y.shape\n    (2, 32, 32, 3)\n\n    >>> # Works with any input size\n    >>> x = np.random.rand(2, 100, 80, 3)\n    >>> y = keras.ops.adaptive_average_pool(x, output_size=7)\n    >>> y.shape\n    (2, 7, 7, 3)\n    \"\"\"\n    if data_format is None:\n        data_format = config.image_data_format()\n\n    if any_symbolic_tensors((inputs,)):\n        return AdaptiveAveragePool(output_size, data_format).symbolic_call(\n            inputs\n        )\n\n    return backend.nn.adaptive_average_pool(\n        inputs, output_size=output_size, data_format=data_format\n    )\n\n\nclass Conv(Operation):\n    def __init__(\n        self,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.strides = strides\n        self.padding = padding.lower()\n        self.data_format = data_format\n        self.dilation_rate = dilation_rate\n\n    def call(self, inputs, kernel):\n        return backend.nn.conv(\n            inputs,\n            kernel,\n            strides=self.strides,\n            padding=self.padding,\n            data_format=self.data_format,\n            dilation_rate=self.dilation_rate,\n        )\n\n    def compute_output_spec(self, inputs, kernel):\n        output_shape = operation_utils.compute_conv_output_shape(\n            inputs.shape,\n            kernel.shape[-1],\n            kernel.shape[:-2],\n            self.strides,\n            self.padding,\n            self.data_format,\n            self.dilation_rate,\n        )\n        return KerasTensor(output_shape, dtype=inputs.dtype)\n\n\n@keras_export([\"keras.ops.conv\", 
\"keras.ops.nn.conv\"])\ndef conv(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    \"\"\"General N-D convolution.\n\n    This op supports 1D, 2D and 3D convolution.\n\n    Args:\n        inputs: Tensor of rank N+2. `inputs` has shape\n            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if\n            `data_format=\"channels_last\"`, or\n            `(batch_size, num_channels) + inputs_spatial_shape` if\n            `data_format=\"channels_first\"`.\n        kernel: Tensor of rank N+2. `kernel` has shape\n            `(kernel_spatial_shape, num_input_channels, num_output_channels)`.\n            `num_input_channels` should match the number of channels in\n            `inputs`.\n        strides: int or int tuple/list of `len(inputs_spatial_shape)`,\n            specifying the strides of the convolution along each spatial\n            dimension. If `strides` is int, then every spatial dimension shares\n            the same `strides`.\n        padding: string, either `\"valid\"` or `\"same\"`. `\"valid\"` means no\n            padding is applied, and `\"same\"` results in padding evenly to the\n            left/right or up/down of the input such that output has the\n            same height/width dimension as the input when `strides=1`.\n        data_format: A string, either `\"channels_last\"` or `\"channels_first\"`.\n            `data_format` determines the ordering of the dimensions in the\n            inputs. If `data_format=\"channels_last\"`, `inputs` is of shape\n            `(batch_size, ..., channels)` while if\n            `data_format=\"channels_first\"`, `inputs` is of shape\n            `(batch_size, channels, ...)`.\n        dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,\n            specifying the dilation rate to use for dilated convolution. 
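With `channels_last` layout and a kernel of shape `(kernel_size, in_channels, out_channels)`, a valid-padding 1-D convolution reduces to a sliding dot product over the window and input-channel axes. A direct NumPy sketch for one unbatched example (illustrative only; real backends dispatch to optimized kernels, and `conv1d_valid` is a hypothetical name):

```python
import numpy as np

def conv1d_valid(x, kernel, stride=1):
    """x: (length, in_ch); kernel: (k, in_ch, out_ch) -> (out_len, out_ch)."""
    k = kernel.shape[0]
    out_len = (x.shape[0] - k) // stride + 1
    out = np.empty((out_len, kernel.shape[-1]))
    for i in range(out_len):
        window = x[i * stride : i * stride + k]  # (k, in_ch) slice of the input
        # contract over the window and input-channel axes simultaneously
        out[i] = np.tensordot(window, kernel, axes=([0, 1], [0, 1]))
    return out
```

The `out_len` arithmetic matches the `"valid"` padding rule described above: the window must fit entirely inside the input.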
If\n            `dilation_rate` is int, then every spatial dimension shares\n            the same `dilation_rate`.\n\n    Returns:\n        A tensor of rank N+2, the result of the conv operation.\n    \"\"\"\n    data_format = standardize_data_format(data_format)\n    padding = padding.lower()\n    if any_symbolic_tensors((inputs,)):\n        return Conv(strides, padding, data_format, dilation_rate).symbolic_call(\n            inputs, kernel\n        )\n    return backend.nn.conv(\n        inputs, kernel, strides, padding, data_format, dilation_rate\n    )\n\n\nclass DepthwiseConv(Operation):\n    def __init__(\n        self,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.strides = strides\n        self.padding = padding.lower()\n        self.data_format = data_format\n        self.dilation_rate = dilation_rate\n\n    def call(self, inputs, kernel):\n        return backend.nn.depthwise_conv(\n            inputs,\n            kernel,\n            self.strides,\n            self.padding,\n            self.data_format,\n            self.dilation_rate,\n        )\n\n    def compute_output_spec(self, inputs, kernel):\n        output_shape = operation_utils.compute_conv_output_shape(\n            inputs.shape,\n            kernel.shape[-1] * kernel.shape[-2],\n            kernel.shape[:-2],\n            self.strides,\n            self.padding,\n            self.data_format,\n            self.dilation_rate,\n        )\n        return KerasTensor(output_shape, dtype=inputs.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.depthwise_conv\",\n        \"keras.ops.nn.depthwise_conv\",\n    ]\n)\ndef depthwise_conv(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    \"\"\"General N-D depthwise convolution.\n\n    This op supports 1D and 2D depthwise 
convolution.\n\n    Args:\n        inputs: Tensor of rank N+2. `inputs` has shape\n            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if\n            `data_format=\"channels_last\"`, or\n            `(batch_size, num_channels) + inputs_spatial_shape` if\n            `data_format=\"channels_first\"`.\n        kernel: Tensor of rank N+2. `kernel` has shape\n            [kernel_spatial_shape, num_input_channels, num_channels_multiplier],\n            `num_input_channels` should match the number of channels in\n            `inputs`.\n        strides: int or int tuple/list of `len(inputs_spatial_shape)`,\n            specifying the strides of the convolution along each spatial\n            dimension. If `strides` is int, then every spatial dimension shares\n            the same `strides`.\n        padding: string, either `\"valid\"` or `\"same\"`. `\"valid\"` means no\n            padding is applied, and `\"same\"` results in padding evenly to the\n            left/right or up/down of the input such that output has the\n            same height/width dimension as the input when `strides=1`.\n        data_format: A string, either `\"channels_last\"` or `\"channels_first\"`.\n            `data_format` determines the ordering of the dimensions in the\n            inputs. If `data_format=\"channels_last\"`, `inputs` is of shape\n            `(batch_size, ..., channels)` while if\n            `data_format=\"channels_first\"`, `inputs` is of shape\n            `(batch_size, channels, ...)`.\n        dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,\n            specifying the dilation rate to use for dilated convolution. 
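Depthwise convolution applies `num_channels_multiplier` independent filters to each input channel, so the output has `num_input_channels * num_channels_multiplier` channels — which is why `compute_output_spec` above passes `kernel.shape[-1] * kernel.shape[-2]` as the channel count. A 1-D NumPy sketch (illustrative; `depthwise_conv1d_valid` is a hypothetical name, and a real implementation would vectorize the loops):

```python
import numpy as np

def depthwise_conv1d_valid(x, kernel):
    """x: (length, in_ch); kernel: (k, in_ch, mult) -> (out_len, in_ch * mult)."""
    k, in_ch, mult = kernel.shape
    out_len = x.shape[0] - k + 1
    out = np.empty((out_len, in_ch * mult))
    for c in range(in_ch):
        for m in range(mult):
            for i in range(out_len):
                # each (channel, multiplier) pair has its own 1-D filter;
                # channels never mix, unlike a regular convolution
                out[i, c * mult + m] = x[i : i + k, c] @ kernel[:, c, m]
    return out
```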
If\n            `dilation_rate` is int, then every spatial dimension shares\n            the same `dilation_rate`.\n\n    Returns:\n        A tensor of rank N+2, the result of the depthwise conv operation.\n    \"\"\"\n    data_format = standardize_data_format(data_format)\n    padding = padding.lower()\n    if any_symbolic_tensors((inputs, kernel)):\n        return DepthwiseConv(\n            strides, padding, data_format, dilation_rate\n        ).symbolic_call(inputs, kernel)\n    return backend.nn.depthwise_conv(\n        inputs,\n        kernel,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n    )\n\n\nclass SeparableConv(Operation):\n    def __init__(\n        self,\n        strides=1,\n        padding=\"valid\",\n        data_format=None,\n        dilation_rate=1,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.strides = strides\n        self.padding = padding.lower()\n        self.data_format = data_format\n        self.dilation_rate = dilation_rate\n\n    def call(self, inputs, depthwise_kernel, pointwise_kernel):\n        return backend.nn.separable_conv(\n            inputs,\n            depthwise_kernel,\n            pointwise_kernel,\n            self.strides,\n            self.padding,\n            self.data_format,\n            self.dilation_rate,\n        )\n\n    def compute_output_spec(self, inputs, depthwise_kernel, pointwise_kernel):\n        output_shape = list(\n            depthwise_conv(\n                inputs,\n                depthwise_kernel,\n                self.strides,\n                self.padding,\n                self.data_format,\n                self.dilation_rate,\n            ).shape\n        )\n        if self.data_format == \"channels_last\":\n            output_shape[-1] = pointwise_kernel.shape[-1]\n        else:\n            output_shape[1] = pointwise_kernel.shape[-1]\n        return KerasTensor(output_shape, 
dtype=inputs.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.separable_conv\",\n        \"keras.ops.nn.separable_conv\",\n    ]\n)\ndef separable_conv(\n    inputs,\n    depthwise_kernel,\n    pointwise_kernel,\n    strides=1,\n    padding=\"valid\",\n    data_format=None,\n    dilation_rate=1,\n):\n    \"\"\"General N-D separable convolution.\n\n    This op supports 1D and 2D separable convolution. `separable_conv` is\n    a depthwise conv followed by a pointwise conv.\n\n    Args:\n        inputs: Tensor of rank N+2. `inputs` has shape\n            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if\n            `data_format=\"channels_last\"`, or\n            `(batch_size, num_channels) + inputs_spatial_shape` if\n            `data_format=\"channels_first\"`.\n        depthwise_kernel: Tensor of rank N+2. `depthwise_kernel` has shape\n            [kernel_spatial_shape, num_input_channels, num_channels_multiplier],\n            `num_input_channels` should match the number of channels in\n            `inputs`.\n        pointwise_kernel: Tensor of rank N+2. `pointwise_kernel` has shape\n            `(*ones_like(kernel_spatial_shape),\n            num_input_channels * num_channels_multiplier, num_output_channels)`.\n        strides: int or int tuple/list of `len(inputs_spatial_shape)`,\n            specifying the strides of the convolution along each spatial\n            dimension. If `strides` is int, then every spatial dimension shares\n            the same `strides`.\n        padding: string, either `\"valid\"` or `\"same\"`. `\"valid\"` means no\n            padding is applied, and `\"same\"` results in padding evenly to the\n            left/right or up/down of the input such that output has the\n            same height/width dimension as the input when `strides=1`.\n        data_format: A string, either `\"channels_last\"` or `\"channels_first\"`.\n            `data_format` determines the ordering of the dimensions in the\n            inputs. 
If `data_format=\"channels_last\"`, `inputs` is of shape\n            `(batch_size, ..., channels)` while if\n            `data_format=\"channels_first\"`, `inputs` is of shape\n            `(batch_size, channels, ...)`.\n        dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,\n            specifying the dilation rate to use for dilated convolution. If\n            `dilation_rate` is int, then every spatial dimension shares\n            the same `dilation_rate`.\n\n    Returns:\n        A tensor of rank N+2, the result of the separable conv operation.\n    \"\"\"\n    data_format = standardize_data_format(data_format)\n    padding = padding.lower()\n    if any_symbolic_tensors((inputs,)):\n        return SeparableConv(\n            strides,\n            padding,\n            data_format,\n            dilation_rate,\n        ).symbolic_call(inputs, depthwise_kernel, pointwise_kernel)\n    return backend.nn.separable_conv(\n        inputs,\n        depthwise_kernel,\n        pointwise_kernel,\n        strides,\n        padding,\n        data_format,\n        dilation_rate,\n    )\n\n\nclass ConvTranspose(Operation):\n    def __init__(\n        self,\n        strides=1,\n        padding=\"valid\",\n        output_padding=None,\n        data_format=None,\n        dilation_rate=1,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.strides = strides\n        self.output_padding = output_padding\n        self.padding = padding.lower()\n        self.data_format = data_format\n        self.dilation_rate = dilation_rate\n\n    def call(\n        self,\n        inputs,\n        kernel,\n    ):\n        return backend.nn.conv_transpose(\n            inputs,\n            kernel,\n            self.strides,\n            self.output_padding,\n            self.padding,\n            self.data_format,\n            self.dilation_rate,\n        )\n\n    def compute_output_spec(self, inputs, kernel):\n        kernel_size = 
kernel.shape[:-2]\n        filters = kernel.shape[-2]\n        output_shape = compute_conv_transpose_output_shape(\n            inputs.shape,\n            kernel_size,\n            filters,\n            self.strides,\n            self.padding,\n            self.output_padding,\n            self.data_format,\n            self.dilation_rate,\n        )\n        return KerasTensor(output_shape, dtype=inputs.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.conv_transpose\",\n        \"keras.ops.nn.conv_transpose\",\n    ]\n)\ndef conv_transpose(\n    inputs,\n    kernel,\n    strides=1,\n    padding=\"valid\",\n    output_padding=None,\n    data_format=None,\n    dilation_rate=1,\n):\n    \"\"\"General N-D convolution transpose.\n\n    Also known as de-convolution. This op supports 1D, 2D and 3D convolution.\n\n    Args:\n        inputs: Tensor of rank N+2. `inputs` has shape\n            `(batch_size,) + inputs_spatial_shape + (num_channels,)` if\n            `data_format=\"channels_last\"`, or\n            `(batch_size, num_channels) + inputs_spatial_shape` if\n            `data_format=\"channels_first\"`.\n        kernel: Tensor of rank N+2. `kernel` has shape\n            [kernel_spatial_shape, num_output_channels, num_input_channels],\n            `num_input_channels` should match the number of channels in\n            `inputs`.\n        strides: int or int tuple/list of `len(inputs_spatial_shape)`,\n            specifying the strides of the convolution along each spatial\n            dimension. If `strides` is int, then every spatial dimension shares\n            the same `strides`.\n        padding: string, either `\"valid\"` or `\"same\"`. 
`\"valid\"` means no\n            padding is applied, and `\"same\"` results in padding evenly to the\n            left/right or up/down of the input such that output has the\n            same height/width dimension as the input when `strides=1`.\n        output_padding: int or int tuple/list of `len(inputs_spatial_shape)`,\n            specifying the amount of padding along the height and width of\n            the output tensor. Can be a single integer to specify the same\n            value for all spatial dimensions. The amount of output padding\n            along a given dimension must be lower than the stride along that\n            same dimension. If set to `None` (default), the output shape is\n            inferred.\n        data_format: A string, either `\"channels_last\"` or `\"channels_first\"`.\n            `data_format` determines the ordering of the dimensions in the\n            inputs. If `data_format=\"channels_last\"`, `inputs` is of shape\n            `(batch_size, ..., channels)` while if\n            `data_format=\"channels_first\"`, `inputs` is of shape\n            `(batch_size, channels, ...)`.\n        dilation_rate: int or int tuple/list of `len(inputs_spatial_shape)`,\n            specifying the dilation rate to use for dilated convolution. 
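Per spatial dimension, the transposed-conv output length inverts the forward-conv arithmetic. With an explicit `output_padding`, a common formulation (hedged: `compute_conv_transpose_output_shape` is the source of truth, and `"same"` padding may be asymmetric in the backend) is `out = (in - 1) * stride + dilation * (kernel - 1) + 1 - 2 * pad + output_padding`, with `pad = 0` for `"valid"`:

```python
def conv_transpose_length(n, kernel, stride=1, dilation=1,
                          padding="valid", output_padding=0):
    """Output length of one spatial dim of a transposed convolution (a sketch)."""
    dilated = dilation * (kernel - 1) + 1             # effective kernel size
    pad = 0 if padding == "valid" else dilated // 2   # symmetric-padding estimate
    return (n - 1) * stride + dilated - 2 * pad + output_padding
```

As a consistency check, a `"valid"` forward conv with `kernel=3, stride=2` maps length 9 back to 4, and this sketch maps 4 forward to 9.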
If\n            `dilation_rate` is int, then every spatial dimension shares\n            the same `dilation_rate`.\n\n    Returns:\n        A tensor of rank N+2, the result of the conv operation.\n    \"\"\"\n    data_format = standardize_data_format(data_format)\n    padding = padding.lower()\n    if any_symbolic_tensors((inputs,)):\n        return ConvTranspose(\n            strides, padding, output_padding, data_format, dilation_rate\n        ).symbolic_call(inputs, kernel)\n    return backend.nn.conv_transpose(\n        inputs,\n        kernel,\n        strides,\n        padding,\n        output_padding,\n        data_format,\n        dilation_rate,\n    )\n\n\nclass OneHot(Operation):\n    def __init__(\n        self, num_classes, axis=-1, dtype=None, sparse=False, *, name=None\n    ):\n        super().__init__(name=name)\n        self.num_classes = num_classes\n        self.axis = axis\n        self.dtype = dtype\n        self.sparse = sparse\n\n    def call(self, x):\n        return backend.nn.one_hot(\n            x,\n            self.num_classes,\n            axis=self.axis,\n            dtype=self.dtype,\n            sparse=self.sparse,\n        )\n\n    def compute_output_spec(self, x):\n        x_shape = list(getattr(x, \"shape\", []))\n        if self.axis == -1:\n            x_shape.append(self.num_classes)\n        elif self.axis >= 0 and self.axis < len(x_shape):\n            x_shape.insert(self.axis, self.num_classes)\n        else:\n            raise ValueError(\n                f\"axis must be -1 or between [0, {len(x.shape)}), but \"\n                f\"received {self.axis}.\"\n            )\n        return KerasTensor(x_shape, dtype=self.dtype, sparse=self.sparse)\n\n\n@keras_export([\"keras.ops.one_hot\", \"keras.ops.nn.one_hot\"])\ndef one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):\n    \"\"\"Converts integer tensor `x` into a one-hot tensor.\n\n    The one-hot encoding is a representation where each integer value is\n    
converted into a binary vector with a length equal to `num_classes`,\n    and the index corresponding to the integer value is marked as 1, while\n    all other indices are marked as 0.\n\n    Args:\n        x: Integer tensor to be encoded. The shape can be\n            arbitrary, but the dtype should be integer.\n        num_classes: Number of classes for the one-hot encoding.\n        axis: Axis along which the encoding is performed.\n            `-1` represents the last axis. Defaults to `-1`.\n        dtype: (Optional) Data type of the output tensor. If not\n            provided, it defaults to the default data type of the backend.\n        sparse: Whether to return a sparse tensor; for backends that support\n            sparse tensors.\n\n    Returns:\n        Tensor: One-hot encoded tensor with the same shape as `x`\n        except for the specified `axis` dimension, which will have\n        a length of `num_classes`. The dtype of the output tensor\n        is determined by `dtype` or the default data type of the backend.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([1, 3, 2, 0])\n    >>> one_hot(x, num_classes=4)\n    array([[0. 1. 0. 0.]\n           [0. 0. 0. 1.]\n           [0. 0. 1. 0.]\n           [1. 0. 0. 
0.]], shape=(4, 4), dtype=float32)\n    \"\"\"\n    dtype = backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x,)):\n        return OneHot(\n            num_classes, axis=axis, dtype=dtype, sparse=sparse\n        ).symbolic_call(x)\n    return backend.nn.one_hot(\n        x,\n        num_classes,\n        axis=axis,\n        dtype=dtype or backend.floatx(),\n        sparse=sparse,\n    )\n\n\nclass BinaryCrossentropy(Operation):\n    def __init__(self, from_logits=False, *, name=None):\n        super().__init__(name=name)\n        self.from_logits = from_logits\n\n    def call(self, target, output):\n        return backend.nn.binary_crossentropy(\n            target, output, from_logits=self.from_logits\n        )\n\n    def compute_output_spec(self, target, output):\n        if target.shape != output.shape:\n            raise ValueError(\n                \"Arguments `target` and `output` must have the same shape. \"\n                \"Received: \"\n                f\"target.shape={target.shape}, output.shape={output.shape}\"\n            )\n        return KerasTensor(output.shape, dtype=output.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.binary_crossentropy\",\n        \"keras.ops.nn.binary_crossentropy\",\n    ]\n)\ndef binary_crossentropy(target, output, from_logits=False):\n    \"\"\"Computes binary cross-entropy loss between target and output tensor.\n\n    The binary cross-entropy loss is commonly used in binary\n    classification tasks where each input sample belongs to one\n    of the two classes. It measures the dissimilarity between the\n    target and output probabilities or logits.\n\n    Args:\n        target: The target tensor representing the true binary labels.\n            Its shape should match the shape of the `output` tensor.\n        output: The output tensor representing the predicted probabilities\n            or logits. 
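The one-hot mapping can be sketched in plain Python (a reference illustration for the default `axis=-1` case, not the backend implementation; the helper name is made up):

```python
def one_hot_ref(indices, num_classes):
    """Each integer index i becomes a vector with a 1.0 at position i."""
    return [
        [1.0 if j == i else 0.0 for j in range(num_classes)] for i in indices
    ]

encoded = one_hot_ref([1, 3, 2, 0], 4)
```

This reproduces the rows shown in the doctest above, one binary vector of length `num_classes` per input index.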
Its shape should match the shape of the\n            `target` tensor.\n        from_logits: (optional) Whether `output` is a tensor of logits or\n            probabilities.\n            Set it to `True` if `output` represents logits; otherwise,\n            set it to `False` if `output` represents probabilities.\n            Defaults to `False`.\n\n    Returns:\n        Tensor: The computed binary cross-entropy loss between\n        `target` and `output`.\n\n    Example:\n\n    >>> target = keras.ops.convert_to_tensor([0, 1, 1, 0])\n    >>> output = keras.ops.convert_to_tensor([0.1, 0.9, 0.8, 0.2])\n    >>> binary_crossentropy(target, output)\n    array([0.10536054 0.10536054 0.22314355 0.22314355],\n          shape=(4,), dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((target, output)):\n        return BinaryCrossentropy(from_logits=from_logits).symbolic_call(\n            target, output\n        )\n    return backend.nn.binary_crossentropy(\n        target, output, from_logits=from_logits\n    )\n\n\nclass CategoricalCrossentropy(Operation):\n    def __init__(self, from_logits=False, axis=-1, *, name=None):\n        super().__init__(name=name)\n        self.from_logits = from_logits\n        self.axis = axis\n\n    def call(self, target, output):\n        return backend.nn.categorical_crossentropy(\n            target, output, from_logits=self.from_logits, axis=self.axis\n        )\n\n    def compute_output_spec(self, target, output):\n        if target.shape != output.shape:\n            raise ValueError(\n                \"Arguments `target` and `output` must have the same shape. \"\n                \"Received: \"\n                f\"target.shape={target.shape}, output.shape={output.shape}\"\n            )\n        if len(target.shape) < 1:\n            raise ValueError(\n                \"Arguments `target` and `output` must be at least rank 1. 
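The probability form of binary cross-entropy can be sketched elementwise in plain Python (a reference illustration only; the real op is computed by the backend and additionally handles logits, clipping, and broadcasting; the helper name is made up):

```python
import math

def bce_ref(target, output):
    """Elementwise loss on probabilities: -(t*log(p) + (1-t)*log(1-p))."""
    return [
        -(t * math.log(p) + (1 - t) * math.log(1 - p))
        for t, p in zip(target, output)
    ]

losses = bce_ref([0, 1, 1, 0], [0.1, 0.9, 0.8, 0.2])
```

These values match the doctest above: `-log(0.9)` for the first two samples and `-log(0.8)` for the last two.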
\"\n                \"Received: \"\n                f\"target.shape={target.shape}, output.shape={output.shape}\"\n            )\n        return KerasTensor(output.shape[:-1], dtype=output.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.categorical_crossentropy\",\n        \"keras.ops.nn.categorical_crossentropy\",\n    ]\n)\ndef categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    \"\"\"Computes categorical cross-entropy loss between target and output tensor.\n\n    The categorical cross-entropy loss is commonly used in multi-class\n    classification tasks where each input sample can belong to one of\n    multiple classes. It measures the dissimilarity\n    between the target and output probabilities or logits.\n\n    Args:\n        target: The target tensor representing the true categorical labels,\n            typically one-hot encoded. Its shape should match the shape of\n            the `output` tensor.\n        output: The output tensor representing the predicted probabilities\n            or logits. Its shape should match the shape of the `target`\n            tensor.\n        from_logits: (optional) Whether `output` is a tensor of logits or\n            probabilities.\n            Set it to `True` if `output` represents logits; otherwise,\n            set it to `False` if `output` represents probabilities.\n            Defaults to `False`.\n        axis: (optional) The axis along which the categorical cross-entropy\n            is computed.\n            Defaults to `-1`, which corresponds to the last dimension of\n            the tensors.\n\n    Returns:\n        Tensor: The computed categorical cross-entropy loss between\n        `target` and `output`.\n\n    Example:\n\n    >>> target = keras.ops.convert_to_tensor(\n    ... [[1, 0, 0],\n    ...  [0, 1, 0],\n    ...  [0, 0, 1]])\n    >>> output = keras.ops.convert_to_tensor(\n    ... [[0.9, 0.05, 0.05],\n    ...  
[0.1, 0.8, 0.1],\n    ...  [0.2, 0.3, 0.5]])\n    >>> categorical_crossentropy(target, output)\n    array([0.10536054 0.22314355 0.6931472 ], shape=(3,), dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((target, output)):\n        return CategoricalCrossentropy(\n            from_logits=from_logits, axis=axis\n        ).symbolic_call(target, output)\n    return backend.nn.categorical_crossentropy(\n        target, output, from_logits=from_logits, axis=axis\n    )\n\n\nclass SparseCategoricalCrossentropy(Operation):\n    def __init__(self, from_logits=False, axis=-1, *, name=None):\n        super().__init__(name=name)\n        self.from_logits = from_logits\n        self.axis = axis\n\n    def call(self, target, output):\n        return backend.nn.sparse_categorical_crossentropy(\n            target, output, from_logits=self.from_logits, axis=self.axis\n        )\n\n    def compute_output_spec(self, target, output):\n        if len(output.shape) < 1:\n            raise ValueError(\n                \"Argument `output` must be at least rank 1. 
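On probabilities, the per-sample categorical cross-entropy is `-sum(t * log(p))` over the class axis. A plain-Python sketch of that reduction (a reference illustration only, not the backend implementation; the helper name is made up):

```python
import math

def cce_ref(target_row, output_row):
    """Per-sample loss on probabilities: -sum(t * log(p)) over classes."""
    return -sum(t * math.log(p) for t, p in zip(target_row, output_row))

targets = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
outputs = [[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.2, 0.3, 0.5]]
losses = [cce_ref(t, p) for t, p in zip(targets, outputs)]
```

With one-hot targets this collapses to `-log` of the probability assigned to the true class, which is why the doctest values are `-log(0.9)`, `-log(0.8)`, and `-log(0.5)`.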
\"\n                \"Received: \"\n                f\"output.shape={output.shape}\"\n            )\n        target_shape = target.shape\n        if len(target_shape) == len(output.shape) and target_shape[-1] == 1:\n            target_shape = target_shape[:-1]\n        if target_shape != output.shape[:-1]:\n            raise ValueError(\n                \"Arguments `target` and `output` must have the same shape \"\n                \"up until the last dimension: \"\n                f\"target.shape={target.shape}, output.shape={output.shape}\"\n            )\n        return KerasTensor(output.shape[:-1], dtype=output.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.sparse_categorical_crossentropy\",\n        \"keras.ops.nn.sparse_categorical_crossentropy\",\n    ]\n)\ndef sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):\n    \"\"\"Computes sparse categorical cross-entropy loss.\n\n    The sparse categorical cross-entropy loss is similar to categorical\n    cross-entropy, but it is used when the target tensor contains integer\n    class labels instead of one-hot encoded vectors. It measures the\n    dissimilarity between the target and output probabilities or logits.\n\n    Args:\n        target: The target tensor representing the true class labels as\n            integers. 
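With integer labels, the per-sample loss reduces to `-log` of the predicted probability at the true class index. A plain-Python sketch (a reference illustration only, not the backend implementation; the helper name is made up):

```python
import math

def sparse_cce_ref(targets, outputs):
    """Per-sample loss: -log of the predicted probability of the true class."""
    return [-math.log(row[t]) for t, row in zip(targets, outputs)]

losses = sparse_cce_ref(
    [0, 1, 2],
    [[0.9, 0.05, 0.05], [0.1, 0.8, 0.1], [0.2, 0.3, 0.5]],
)
```

This is the same computation as the one-hot categorical form, with the one-hot multiply replaced by direct indexing.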
Its shape should match the shape of the `output`\n            tensor except for the last dimension.\n        output: The output tensor representing the predicted probabilities\n            or logits.\n            Its shape should match the shape of the `target` tensor except\n            for the last dimension.\n        from_logits: (optional) Whether `output` is a tensor of logits\n            or probabilities.\n            Set it to `True` if `output` represents logits; otherwise,\n            set it to `False` if `output` represents probabilities.\n            Defaults to `False`.\n        axis: (optional) The axis along which the sparse categorical\n            cross-entropy is computed.\n            Defaults to `-1`, which corresponds to the last dimension\n            of the tensors.\n\n    Returns:\n        Tensor: The computed sparse categorical cross-entropy\n        loss between `target` and `output`.\n\n    Example:\n\n    >>> target = keras.ops.convert_to_tensor([0, 1, 2], dtype=\"int32\")\n    >>> output = keras.ops.convert_to_tensor(\n    ... [[0.9, 0.05, 0.05],\n    ...  [0.1, 0.8, 0.1],\n    ...  
[0.2, 0.3, 0.5]])\n    >>> sparse_categorical_crossentropy(target, output)\n    array([0.10536056 0.22314355 0.6931472 ], shape=(3,), dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((target, output)):\n        return SparseCategoricalCrossentropy(\n            from_logits=from_logits, axis=axis\n        ).symbolic_call(target, output)\n    return backend.nn.sparse_categorical_crossentropy(\n        target, output, from_logits=from_logits, axis=axis\n    )\n\n\nclass MultiHot(Operation):\n    def __init__(\n        self,\n        num_classes=None,\n        axis=-1,\n        dtype=None,\n        sparse=False,\n        *,\n        name=None,\n        **kwargs,\n    ):\n        if num_classes is None and \"num_tokens\" in kwargs:\n            num_classes = kwargs.pop(\"num_tokens\")\n        if num_classes is None:\n            raise ValueError(\"Argument `num_classes` must be specified.\")\n        super().__init__(name=name)\n        self.num_classes = num_classes\n        self.axis = axis\n        self.dtype = dtype or backend.floatx()\n        self.sparse = sparse\n\n    def call(self, inputs):\n        return backend.nn.multi_hot(\n            inputs,\n            num_classes=self.num_classes,\n            axis=self.axis,\n            dtype=self.dtype,\n            sparse=self.sparse,\n        )\n\n    def compute_output_spec(self, inputs):\n        x_shape = list(getattr(inputs, \"shape\", []))\n        if self.axis == -1:\n            x_shape.append(self.num_classes)\n        elif self.axis >= 0 and self.axis < len(x_shape):\n            x_shape.insert(self.axis, self.num_classes)\n        else:\n            raise ValueError(\n                f\"axis must be -1 or between [0, {len(inputs.shape)}), but \"\n                f\"received {self.axis}.\"\n            )\n\n        if len(x_shape) == 2:\n            x_shape = [x_shape[-1]]\n        else:\n            x_shape = [x_shape[0]] + x_shape[2:]\n\n        return KerasTensor(x_shape, dtype=inputs.dtype, 
sparse=self.sparse)\n\n\n@keras_export(\n    [\n        \"keras.ops.multi_hot\",\n        \"keras.ops.nn.multi_hot\",\n    ]\n)\ndef multi_hot(\n    inputs, num_classes=None, axis=-1, dtype=None, sparse=False, **kwargs\n):\n    \"\"\"Encodes integer labels as multi-hot vectors.\n\n    This function encodes integer labels as multi-hot vectors, where each label\n    is mapped to a binary value in the resulting vector.\n\n    Args:\n        inputs: Tensor of integer labels to be converted to multi-hot vectors.\n        num_classes: Integer, the total number of unique classes.\n        axis: (optional) Axis along which the multi-hot encoding should be\n            added. Defaults to `-1`, which corresponds to the last dimension.\n        dtype: (optional) The data type of the resulting tensor. Default\n            is backend's float type.\n        sparse: Whether to return a sparse tensor; for backends that support\n            sparse tensors.\n\n    Returns:\n        Tensor: The multi-hot encoded tensor.\n\n    Example:\n\n    >>> data = keras.ops.convert_to_tensor([0, 4])\n    >>> keras.ops.multi_hot(data, num_classes=5)\n    array([1.0, 0.0, 0.0, 0.0, 1.0], dtype=float32)\n\n    \"\"\"\n    if num_classes is None and \"num_tokens\" in kwargs:\n        num_classes = kwargs.pop(\"num_tokens\")\n    if num_classes is None:\n        raise ValueError(\"Argument `num_classes` must be specified.\")\n\n    dtype = backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((inputs,)):\n        return MultiHot(num_classes, axis, dtype, sparse).symbolic_call(inputs)\n\n    return backend.nn.multi_hot(inputs, num_classes, axis, dtype, sparse)\n\n\nclass Moments(Operation):\n    def __init__(self, axes, keepdims=False, synchronized=False, *, name=None):\n        super().__init__(name=name)\n        self.axes = axes\n        self.keepdims = keepdims\n        self.synchronized = synchronized\n\n    def call(self, x):\n        return backend.nn.moments(\n            x,\n         
   axes=self.axes,\n            keepdims=self.keepdims,\n            synchronized=self.synchronized,\n        )\n\n    def compute_output_spec(self, x):\n        return (\n            KerasTensor(\n                reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims),\n                dtype=x.dtype,\n            ),\n            KerasTensor(\n                reduce_shape(x.shape, axis=self.axes, keepdims=self.keepdims),\n                dtype=x.dtype,\n            ),\n        )\n\n\n@keras_export(\n    [\n        \"keras.ops.moments\",\n        \"keras.ops.nn.moments\",\n    ]\n)\ndef moments(x, axes, keepdims=False, synchronized=False):\n    \"\"\"Calculates the mean and variance of `x`.\n\n    The mean and variance are calculated by aggregating the contents of `x`\n    across `axes`. If `x` is 1-D and `axes = [0]` this is just the mean and\n    variance of a vector.\n\n    Args:\n        x: Input tensor.\n        axes: A list of axes which to compute mean and variance.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one.\n        synchronized: Only applicable with the TensorFlow backend.\n            If `True`, synchronizes the global batch statistics (mean and\n            variance) across all devices at each training step in a\n            distributed training strategy. 
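The reduction performed by `moments` over a single axis can be sketched in plain Python (a reference illustration of the math, not the backend implementation; the helper name is made up; the variance is the biased/population variance):

```python
def moments_ref(x):
    """Mean and (biased) variance of a 1-D list, i.e. `axes=[0]`."""
    mean = sum(x) / len(x)
    variance = sum((v - mean) ** 2 for v in x) / len(x)
    return mean, variance

mean, var = moments_ref([0.0, 1.0, 2.0, 3.0, 100.0])
```

This reproduces the doctest above: the mean is 21.2 and the variance is 1553.36 (both dominated by the outlier 100).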
If `False`, each replica uses its own\n            local batch statistics.\n\n    Returns:\n        A tuple containing two tensors - mean and variance.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([0, 1, 2, 3, 100], dtype=\"float32\")\n    >>> keras.ops.moments(x, axes=[0])\n    (array(21.2, dtype=float32), array(1553.3601, dtype=float32))\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Moments(axes, keepdims, synchronized=synchronized).symbolic_call(\n            x\n        )\n\n    return backend.nn.moments(x, axes, keepdims, synchronized=synchronized)\n\n\nclass BatchNorm(Operation):\n    def __init__(self, axis, epsilon=1e-3, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.epsilon = epsilon\n\n    def call(self, x, mean, variance, offset=None, scale=None):\n        return backend.nn.batch_normalization(\n            x,\n            mean,\n            variance,\n            axis=self.axis,\n            offset=offset,\n            scale=scale,\n            epsilon=self.epsilon,\n        )\n\n    def _check_shape(self, name, shape, expected_shape):\n        if shape != expected_shape:\n            raise ValueError(\n                f\"Argument `{name}` must be a vector of length \"\n                f\"`x.shape[axis]`. Expected: `{expected_shape}`. 
\"\n                f\"Received: `{shape}`.\"\n            )\n\n    def compute_output_spec(self, x, mean, variance, offset, scale):\n        shape = (x.shape[self.axis],)\n        self._check_shape(\"mean\", tuple(mean.shape), shape)\n        self._check_shape(\"variance\", tuple(variance.shape), shape)\n        if offset is not None:\n            self._check_shape(\"offset\", tuple(offset.shape), shape)\n        if scale is not None:\n            self._check_shape(\"scale\", tuple(scale.shape), shape)\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.batch_normalization\",\n        \"keras.ops.nn.batch_normalization\",\n    ]\n)\ndef batch_normalization(\n    x, mean, variance, axis, offset=None, scale=None, epsilon=1e-3\n):\n    \"\"\"Normalizes `x` by `mean` and `variance`.\n\n    This op is typically used by the batch normalization step in a neural\n    network. It normalizes the input tensor along the given axis.\n\n    Args:\n        x: Input tensor.\n        mean: A mean vector of the same length as the `axis` dimension of the\n            input tensor.\n        variance: A variance vector of the same length as the `axis` dimension\n            of the input tensor.\n        axis: Integer, the axis that should be normalized.\n        offset: An offset vector of the same length as the `axis` dimension of\n            the input tensor. If not `None`, `offset` is added to the normalized\n            tensor. Defaults to `None`.\n        scale: A scale vector of the same length as the `axis` dimension of the\n            input tensor. If not `None`, the normalized tensor is multiplied by\n            `scale`. Defaults to `None`.\n        epsilon: Small float added to variance to avoid dividing by zero.\n            Defaults to 1e-3.\n\n    Returns:\n        The normalized tensor.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor(\n    ...     [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]]\n    ... 
)\n    >>> keras.ops.batch_normalization(\n    ...     x,\n    ...     mean=[0.4, 0.5, 0.6],\n    ...     variance=[0.67, 0.67, 0.67],\n    ...     axis=-1\n    ... )\n    array([[-3.6624e-01, -3.6624e-01, -3.6624e-01],\n           [-4.6445e-09,  0.0000e+00, -1.8578e-08],\n           [ 3.6624e-01,  3.6624e-01,  3.6624e-01]])\n\n    \"\"\"\n    if any_symbolic_tensors((x, mean, variance, offset, scale)):\n        return BatchNorm(axis, epsilon).symbolic_call(\n            x, mean, variance, offset, scale\n        )\n\n    return backend.nn.batch_normalization(\n        x, mean, variance, axis, offset, scale, epsilon\n    )\n\n\nclass CTCLoss(Operation):\n    def __init__(self, mask_index=0, *, name=None):\n        super().__init__(name=name)\n        self.mask_index = mask_index\n\n    def call(self, target, output, target_length, output_length):\n        return backend.nn.ctc_loss(\n            target, output, target_length, output_length, self.mask_index\n        )\n\n    def _check_shape_first_dim(self, name1, shape1, name2, shape2):\n        if shape1[0] != shape2[0]:\n            raise ValueError(\n                f\"Arguments `{name1}` and `{name2}` must have the same \"\n                \"first dimension. 
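The normalization itself is `(x - mean) / sqrt(variance + epsilon)`, optionally followed by `scale` and `offset`. A plain-Python sketch for 2-D data with `axis=-1` (a reference illustration only, not the backend implementation; the helper name is made up):

```python
import math

def batch_norm_ref(x, mean, variance, offset=None, scale=None, epsilon=1e-3):
    """Per-feature normalization: (x - mean) / sqrt(var + eps),
    then optional scale and offset."""
    out = []
    for row in x:
        normed = [
            (v - m) / math.sqrt(var + epsilon)
            for v, m, var in zip(row, mean, variance)
        ]
        if scale is not None:
            normed = [v * s for v, s in zip(normed, scale)]
        if offset is not None:
            normed = [v + o for v, o in zip(normed, offset)]
        out.append(normed)
    return out

rows = batch_norm_ref(
    [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.7, 0.8, 0.9]],
    mean=[0.4, 0.5, 0.6],
    variance=[0.67, 0.67, 0.67],
)
```

The entries reproduce the doctest above: about -0.36624 for the first row, near zero for the middle row, and about 0.36624 for the last row.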
\"\n                f\"Received shapes: `{shape1}` and `{shape2}`.\"\n            )\n\n    def compute_output_spec(self, target, output, target_length, output_length):\n        self._check_shape_first_dim(\n            \"target\", target.shape, \"output\", output.shape\n        )\n        self._check_shape_first_dim(\n            \"target_length\", target_length.shape, \"target\", target.shape\n        )\n        self._check_shape_first_dim(\n            \"output_length\", output_length.shape, \"output\", output.shape\n        )\n        dtype = backend.result_type(output.dtype, \"float32\")\n        return KerasTensor((target.shape[0],), dtype=dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.ctc_loss\",\n        \"keras.ops.nn.ctc_loss\",\n    ]\n)\ndef ctc_loss(target, output, target_length, output_length, mask_index=0):\n    \"\"\"CTC (Connectionist Temporal Classification) loss.\n\n    Args:\n        target: A tensor of shape `(batch_size, max_length)` containing\n            the true labels in integer format.\n        output: A tensor of shape `(batch_size, max_length, num_classes)`\n            containing logits (the output of your model).\n        target_length: A tensor of shape `(batch_size,)` containing the\n            true label lengths.\n        output_length: A tensor of shape `(batch_size,)` containing the\n            output lengths.\n        mask_index: The index of the mask character in the vocabulary.\n            Defaults to `0`.\n    \"\"\"\n\n    if any_symbolic_tensors((target, output, target_length, output_length)):\n        return CTCLoss(mask_index).symbolic_call(\n            target, output, target_length, output_length\n        )\n    return backend.nn.ctc_loss(\n        target, output, target_length, output_length, mask_index\n    )\n\n\nclass CTCDecode(Operation):\n    def __init__(\n        self,\n        strategy=\"greedy\",\n        beam_width=100,\n        top_paths=1,\n        merge_repeated=True,\n        mask_index=0,\n    
    *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.strategy = strategy\n        self.beam_width = beam_width\n        self.top_paths = top_paths\n        self.merge_repeated = merge_repeated\n        self.mask_index = mask_index\n\n    def call(self, inputs, sequence_lengths):\n        return backend.nn.ctc_decode(\n            inputs,\n            sequence_lengths,\n            strategy=self.strategy,\n            beam_width=self.beam_width,\n            top_paths=self.top_paths,\n            merge_repeated=self.merge_repeated,\n            mask_index=self.mask_index,\n        )\n\n    def compute_output_spec(self, inputs, sequence_lengths):\n        inputs_shape = inputs.shape\n        if self.strategy == \"greedy\":\n            top_paths = 1\n        else:\n            top_paths = self.top_paths\n        dtype = backend.result_type(inputs.dtype, \"float32\")\n        return (\n            KerasTensor(\n                (top_paths, inputs_shape[0], inputs_shape[1]), dtype=\"int32\"\n            ),\n            KerasTensor((inputs_shape[0], top_paths), dtype=dtype),\n        )\n\n\n@keras_export(\n    [\n        \"keras.ops.ctc_decode\",\n        \"keras.ops.nn.ctc_decode\",\n    ]\n)\ndef ctc_decode(\n    inputs,\n    sequence_lengths,\n    strategy=\"greedy\",\n    beam_width=100,\n    top_paths=1,\n    merge_repeated=True,\n    mask_index=0,\n):\n    \"\"\"Decodes the output of a CTC model.\n\n    Args:\n        inputs: A tensor of shape `(batch_size, max_length, num_classes)`\n            containing the logits (the output of the model).\n            They should *not* be normalized via softmax.\n        sequence_lengths: A tensor of shape `(batch_size,)` containing the\n            sequence lengths for the batch.\n        strategy: A string for the decoding strategy. 
Supported values are\n            `\"greedy\"` and `\"beam_search\"`.\n        beam_width: An integer scalar beam width used in beam search.\n            Defaults to 100.\n        top_paths: An integer scalar, the number of top paths to return.\n            Defaults to 1.\n        merge_repeated: A boolean scalar, whether to merge repeated\n            labels in the output. Defaults to `True`.\n        mask_index: An integer scalar, the index of the mask character in\n            the vocabulary. Defaults to `0`.\n\n    Returns:\n        A tuple containing:\n        - The tensor representing the list of decoded sequences. If\n            `strategy=\"greedy\"`, the shape is `(1, batch_size, max_length)`. If\n            `strategy=\"beam_search\"`, the shape is\n            `(top_paths, batch_size, max_length)`. Note that `-1` indicates the\n            blank label.\n        - If `strategy=\"greedy\"`, a tensor of shape `(batch_size, 1)`\n            representing the negative of the sum of the probability logits for\n            each sequence. 
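The greedy strategy with `merge_repeated=True` amounts to: take the argmax class at each timestep, collapse consecutive repeats, then drop the mask/blank label. A plain-Python sketch for a single sequence (a reference illustration only, not the backend implementation; the helper name is made up):

```python
def ctc_greedy_decode_ref(logits, mask_index=0):
    """Greedy CTC decoding: argmax per timestep, collapse repeats,
    drop the mask/blank label."""
    best = [max(range(len(step)), key=step.__getitem__) for step in logits]
    collapsed = [c for i, c in enumerate(best) if i == 0 or c != best[i - 1]]
    return [c for c in collapsed if c != mask_index]

# Per-timestep class scores for one sequence (argmax path: 1, 1, 0, 2, 2, 3).
steps = [
    [0.1, 0.8, 0.05, 0.05],
    [0.1, 0.8, 0.05, 0.05],
    [0.9, 0.0, 0.05, 0.05],
    [0.1, 0.0, 0.8, 0.1],
    [0.1, 0.0, 0.8, 0.1],
    [0.1, 0.0, 0.1, 0.8],
]
```

Here the argmax path `1, 1, 0, 2, 2, 3` collapses to `1, 0, 2, 3` and then decodes to `[1, 2, 3]` after the blank (index 0) is removed.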
If `strategy=\"beam_search\"`, a tensor of shape\n            `(batch_size, top_paths)` representing the log probability for each\n            sequence.\n    \"\"\"\n\n    if any_symbolic_tensors((inputs, sequence_lengths)):\n        return CTCDecode(\n            strategy=strategy,\n            beam_width=beam_width,\n            top_paths=top_paths,\n            merge_repeated=merge_repeated,\n            mask_index=mask_index,\n        ).symbolic_call(inputs, sequence_lengths)\n    return backend.nn.ctc_decode(\n        inputs=inputs,\n        sequence_lengths=sequence_lengths,\n        strategy=strategy,\n        beam_width=beam_width,\n        top_paths=top_paths,\n        merge_repeated=merge_repeated,\n        mask_index=mask_index,\n    )\n\n\nclass Normalize(Operation):\n    def __init__(self, axis=-1, order=2, epsilon=None, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.order = order\n        self.epsilon = epsilon\n\n    def compute_output_spec(self, x):\n        return KerasTensor(shape=x.shape)\n\n    def call(self, x):\n        return _normalize(\n            x, axis=self.axis, order=self.order, epsilon=self.epsilon\n        )\n\n\n@keras_export(\n    [\n        \"keras.ops.normalize\",\n        \"keras.ops.nn.normalize\",\n    ]\n)\ndef normalize(x, axis=-1, order=2, epsilon=None):\n    \"\"\"Normalizes `x` over the specified axis.\n\n    It is defined as: `normalize(x) = x / max(norm(x), epsilon)`.\n\n    Args:\n        x: Input tensor.\n        axis: The axis or axes along which to perform normalization.\n            Defaults to `-1`.\n        order: The exponent value in the norm formulation.\n            Defaults to 2.\n        epsilon: A lower bound value for the norm.\n            Defaults to `backend.epsilon()`.\n\n    Returns:\n        The normalized array.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([[1, 2, 3], [4, 5, 6]])\n    >>> x_norm = keras.ops.normalize(x)\n    >>> 
print(x_norm)\n    array([[0.26726124 0.5345225  0.8017837 ]\n           [0.45584232 0.5698029  0.68376344]], shape=(2, 3), dtype=float32)\n\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Normalize(axis=axis, order=order, epsilon=epsilon).symbolic_call(\n            x\n        )\n    return _normalize(x, axis=axis, order=order, epsilon=epsilon)\n\n\ndef _normalize(x, axis=-1, order=2, epsilon=None):\n    if not isinstance(order, int) or not order >= 1:\n        raise ValueError(\n            f\"Argument `order` must be an int >= 1. Received: order={order}\"\n        )\n    x = backend.convert_to_tensor(x)\n    if len(x.shape) == 0:\n        x = backend.numpy.expand_dims(x, axis=0)\n    if epsilon is None:\n        epsilon = backend.epsilon()\n    if 2 == order:\n        # A special case: L2 normalization with `x * rsqrt(...)`\n        # instead of `x / sqrt(...)`\n        square_sum = backend.numpy.sum(\n            backend.numpy.square(x), axis=axis, keepdims=True\n        )\n        inv_norm = backend.math.rsqrt(square_sum)\n        inv_norm = backend.numpy.minimum(inv_norm, 1.0 / epsilon)\n        return x * inv_norm\n    norm = backend.linalg.norm(x, ord=order, axis=axis, keepdims=True)\n    denom = backend.numpy.maximum(norm, epsilon)\n    return backend.numpy.divide(x, denom)\n\n\nclass PSNR(Operation):\n    def __init__(\n        self,\n        max_val,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.max_val = max_val\n\n    def call(self, x1, x2):\n        return backend.nn.psnr(\n            x1=x1,\n            x2=x2,\n            max_val=self.max_val,\n        )\n\n    def compute_output_spec(self, x1, x2):\n        if len(x1.shape) != len(x2.shape):\n            raise ValueError(\"Inputs must have the same rank\")\n\n        return KerasTensor(shape=())\n\n\n@keras_export(\n    [\n        \"keras.ops.psnr\",\n        \"keras.ops.nn.psnr\",\n    ]\n)\ndef psnr(\n    x1,\n    x2,\n    
max_val,\n):\n    \"\"\"Peak Signal-to-Noise Ratio (PSNR) function.\n\n    This function computes the Peak Signal-to-Noise Ratio between two signals,\n    `x1` and `x2`. PSNR is a measure of the quality of a reconstructed signal.\n    The higher the PSNR, the closer the reconstructed signal is to the original\n    signal. Note that it can become negative when the signal power is\n    smaller than the noise power.\n\n    Args:\n        x1: The first input signal.\n        x2: The second input signal. Must have the same shape as `x1`.\n        max_val: The maximum possible value in the signals.\n\n    Returns:\n        float: The PSNR value between `x1` and `x2`.\n\n    Examples:\n\n    >>> x1 = keras.random.normal((2, 4, 4, 3))\n    >>> x2 = keras.random.normal((2, 4, 4, 3))\n    >>> max_val = 1.0\n    >>> keras.ops.nn.psnr(x1, x2, max_val)\n    -3.1697404\n    \"\"\"\n    if any_symbolic_tensors(\n        (\n            x1,\n            x2,\n        )\n    ):\n        return PSNR(\n            max_val,\n        ).symbolic_call(x1, x2)\n    return backend.nn.psnr(\n        x1,\n        x2,\n        max_val,\n    )\n\n\nclass DotProductAttention(Operation):\n    def __init__(\n        self,\n        is_causal=False,\n        flash_attention=None,\n        attn_logits_soft_cap=None,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.is_causal = is_causal\n        self.flash_attention = flash_attention\n        self.attn_logits_soft_cap = attn_logits_soft_cap\n\n    def call(\n        self,\n        query,\n        key,\n        value,\n        bias=None,\n        mask=None,\n        scale=None,\n    ):\n        return backend.nn.dot_product_attention(\n            query,\n            key,\n            value,\n            bias=bias,\n            mask=mask,\n            scale=scale,\n            is_causal=self.is_causal,\n            flash_attention=self.flash_attention,\n            
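PSNR is `20 * log10(max_val) - 10 * log10(mse)`, where `mse` is the mean squared error between the two signals. A plain-Python sketch for 1-D inputs (a reference illustration only, not the backend implementation; the helper name is made up):

```python
import math

def psnr_ref(x1, x2, max_val):
    """PSNR in dB: 20*log10(max_val) - 10*log10(mean squared error)."""
    mse = sum((a - b) ** 2 for a, b in zip(x1, x2)) / len(x1)
    return 20 * math.log10(max_val) - 10 * math.log10(mse)
```

For instance, `psnr_ref([0.0, 1.0], [0.5, 0.5], 1.0)` is about 6.02 dB (mse of 0.25); identical signals have an mse of zero and an unbounded PSNR, which is why this sketch assumes the inputs differ.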
attn_logits_soft_cap=self.attn_logits_soft_cap,\n        )\n\n    def compute_output_spec(\n        self,\n        query,\n        key,\n        value,\n        bias=None,\n        mask=None,\n        scale=None,\n    ):\n        dtype = backend.result_type(query.dtype, key.dtype, value.dtype)\n        return KerasTensor(query.shape, dtype=dtype)\n\n\n@keras_export(\n    [\"keras.ops.dot_product_attention\", \"keras.ops.nn.dot_product_attention\"]\n)\ndef dot_product_attention(\n    query,\n    key,\n    value,\n    bias=None,\n    mask=None,\n    scale=None,\n    is_causal=False,\n    flash_attention=None,\n    attn_logits_soft_cap=None,\n):\n    \"\"\"Scaled dot product attention function.\n\n    Computes the attention function on Q (`query`), K (`key`), and V (`value`):\n    `attention(Q, K, V) = softmax(Q * K / sqrt(d)) * V`. Here we define `logits`\n    as the output of `Q * K` and `probs` as the output of `softmax`.\n\n    Throughout this function, we utilize the following notation to represent the\n    shapes of the arrays:\n    - B: batch size\n    - S: length of the key/value\n    - T: length of the query\n    - N: number of attention heads\n    - H: dimensions of each attention head\n    - K: number of key/value heads\n    - G: number of groups, which equals `N // K`\n\n    Args:\n        query: The query array with the shape of `(B, T, N, H)`.\n        key: The key array with the shape of `(B, S, K, H)`. When `K` equals\n            `N`, multi-headed attention (MHA) is performed. Otherwise, grouped\n            query attention (GQA) is performed if `N` is a multiple of `K`, and\n            multi-query attention (MQA) is performed if `K == 1` (a special case\n            of GQA).\n        value: The value array with the same shape as `key`.\n        bias: Optional bias array to be added to logits. The shape must be\n            broadcastable to `(B, N, T, S)`.\n        mask: Optional mask array used to filter out logits. 
It is a boolean\n            mask where `True` indicates the element should take part in\n            attention. For an additive mask, users should pass it to `bias`. The\n            shape must be broadcastable to `(B, N, T, S)`.\n        scale: Optional scale for the logits. If `None`, the scale will be set\n            to `1.0 / sqrt(H)`.\n        is_causal: Whether to apply a causal mask.\n        flash_attention: Whether to use flash attention. If `None`, it will\n            attempt to use flash attention if the required conditions are met.\n            Typically, the inputs must be of float16 or bfloat16 dtype, and the\n            input layout requirements may vary depending on the backend.\n        attn_logits_soft_cap: The soft cap applied to the attention logits\n            before the softmax function is applied. This is only supported by\n            the JAX backend on TPU. Defaults to `None`.\n\n    Returns:\n        An array of the attention output with the same shape as `query`.\n\n    Example:\n\n    >>> query = keras.random.normal((2, 4, 8, 16))\n    >>> key = keras.random.normal((2, 6, 8, 16))\n    >>> value = keras.random.normal((2, 6, 8, 16))\n    >>> keras.ops.nn.dot_product_attention(query, key, value).shape\n    (2, 4, 8, 16)\n    \"\"\"\n    if attn_logits_soft_cap is not None:\n        if backend.backend() == \"jax\":\n            import jax\n\n            if jax.devices()[0].platform != \"tpu\":\n                raise ValueError(\n                    \"attn_logits_soft_cap is only supported for JAX on TPU. \"\n                    \"Set attn_logits_soft_cap=None when not using JAX on TPU.\"\n                )\n        else:\n            raise ValueError(\n                \"attn_logits_soft_cap is only supported for JAX on TPU. 
\"\n                \"Set attn_logits_soft_cap=None when not using JAX on TPU.\"\n            )\n\n    if any_symbolic_tensors((query, key, value)):\n        return DotProductAttention(\n            is_causal=is_causal,\n            flash_attention=flash_attention,\n            attn_logits_soft_cap=attn_logits_soft_cap,\n        ).symbolic_call(\n            query,\n            key,\n            value,\n            bias=bias,\n            mask=mask,\n            scale=scale,\n        )\n    return backend.nn.dot_product_attention(\n        query,\n        key,\n        value,\n        bias=bias,\n        mask=mask,\n        scale=scale,\n        is_causal=is_causal,\n        flash_attention=flash_attention,\n        attn_logits_soft_cap=attn_logits_soft_cap,\n    )\n\n\nclass RMSNorm(Operation):\n    def __init__(self, axis=-1, epsilon=None, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.epsilon = epsilon\n\n    def compute_output_spec(self, x, scale):\n        return KerasTensor(shape=x.shape, dtype=x.dtype)\n\n    def call(self, x, scale=None):\n        return _rms_normalization(\n            x, scale=scale, axis=self.axis, epsilon=self.epsilon\n        )\n\n\n@keras_export(\n    [\n        \"keras.ops.rms_normalization\",\n        \"keras.ops.nn.rms_normalization\",\n    ]\n)\ndef rms_normalization(x, scale=None, axis=-1, epsilon=None):\n    \"\"\"Performs Root Mean Square (RMS) normalization on `x`.\n\n    The Keras operation implements the operation as described in\n    [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)\n    by Biao Zhang et al.\n\n    The operation is different from LayerNormalization with RMS scaling.\n\n    It is defined as `rms_normalization(x) = x * rsqrt(mean(square(x))) * scale`\n\n    Args:\n        x: Input tensor.\n        scale: Optional scaling factor for the normalization.\n        axis: The axis or axes along which to perform normalization. 
Defaults\n            to `-1`.\n        epsilon: A lower bound value for the norm. Defaults to\n            `backend.epsilon()`.\n\n    Returns:\n        The normalized array.\n\n    Example:\n\n    >>> x = keras.random.normal((1, 10))\n    >>> keras.ops.rms_normalization(x)\n    array([[0.69384296, 0.94444374, 0.16551171, 0.05749961, 1.11008865,\n            0.52475186, 1.57686807, 1.69893307, 1.27292764, 0.30819128]])\n    \"\"\"\n    if any_symbolic_tensors((x, scale)):\n        return RMSNorm(axis=axis, epsilon=epsilon).symbolic_call(x, scale=scale)\n    return _rms_normalization(x, scale=scale, axis=axis, epsilon=epsilon)\n\n\ndef _rms_normalization(x, scale=None, axis=-1, epsilon=None):\n    if epsilon is None:\n        epsilon = backend.epsilon()\n    original_dtype = backend.standardize_dtype(x.dtype)\n    # Computes in at least float32 precision for stability in half precision\n    # training.\n    compute_dtype = backend.result_type(x.dtype, \"float32\")\n\n    x = backend.convert_to_tensor(x, dtype=compute_dtype)\n    if scale is not None:\n        scale = backend.convert_to_tensor(scale, x.dtype)\n\n    if isinstance(axis, (tuple, list)):\n        axis = sorted(axis)\n    if backend.backend() == \"torch\" and is_continuous_axis(axis):\n        import torch.nn.functional as F\n\n        if isinstance(axis, (tuple, list)):\n            normalized_shape = tuple(x.shape[dim] for dim in axis)\n        else:\n            normalized_shape = (x.shape[axis],)\n        outputs = F.rms_norm(x, normalized_shape, scale, epsilon)\n    else:\n        if len(x.shape) == 0:\n            x = backend.numpy.expand_dims(x, axis=0)\n        rrms = backend.math.rsqrt(\n            backend.numpy.mean(\n                backend.numpy.square(x), axis=axis, keepdims=True\n            )\n            + epsilon\n        )\n        outputs = backend.numpy.multiply(x, rrms)\n        if scale is not None:\n            outputs = backend.numpy.multiply(outputs, scale)\n    return 
backend.cast(outputs, original_dtype)\n\n\nclass LayerNorm(Operation):\n    def __init__(self, axis=-1, epsilon=None, rms_scaling=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.epsilon = epsilon\n        self.rms_scaling = rms_scaling\n\n    def compute_output_spec(self, x, gamma, beta):\n        return KerasTensor(shape=x.shape, dtype=x.dtype)\n\n    def call(self, x, gamma=None, beta=None):\n        return _layer_normalization(\n            x,\n            gamma=gamma,\n            beta=beta,\n            axis=self.axis,\n            epsilon=self.epsilon,\n            rms_scaling=self.rms_scaling,\n        )\n\n\n@keras_export(\n    [\n        \"keras.ops.layer_normalization\",\n        \"keras.ops.nn.layer_normalization\",\n    ]\n)\ndef layer_normalization(\n    x, gamma=None, beta=None, axis=-1, epsilon=None, **kwargs\n):\n    \"\"\"Layer normalization layer (Ba et al., 2016).\n\n    Normalize the activations of the previous layer for each given example in a\n    batch independently, rather than across a batch like Batch Normalization.\n    i.e. applies a transformation that maintains the mean activation within each\n    example close to 0 and the activation standard deviation close to 1.\n\n    Args:\n        x: Input tensor.\n        gamma: Optional scaling factor for the normalization.\n        beta: Optional add offset for the normalized tensor.\n        axis: The axis or axes along which to perform normalization. 
Defaults to\n            `-1`.\n        epsilon: A lower bound value for the norm.\n            Defaults to `backend.epsilon()`.\n\n    Returns:\n        The normalized array.\n\n    Example:\n\n    >>> x = keras.ops.arange(5, dtype=\"float32\")\n    >>> keras.ops.layer_normalization(x)\n    array([-1.4142135, -0.70710677, 0.0, 0.7071067, 1.4142135])\n    \"\"\"\n    rms_scaling = kwargs.pop(\"rms_scaling\", False)\n    if rms_scaling:\n        warnings.warn(\n            \"You passed `rms_scaling=True`, which is deprecated. This argument \"\n            \"incorrectly scales the input by the variance, not the root mean \"\n            \"square. To correctly use RMS Normalization, please use \"\n            \"`keras.ops.rms_normalization` / `keras.ops.nn.rms_normalization` \"\n            \"instead.\"\n        )\n\n    if any_symbolic_tensors((x, gamma, beta)):\n        return LayerNorm(\n            axis=axis, epsilon=epsilon, rms_scaling=rms_scaling\n        ).symbolic_call(x, gamma, beta)\n    return _layer_normalization(\n        x,\n        gamma=gamma,\n        beta=beta,\n        axis=axis,\n        epsilon=epsilon,\n        rms_scaling=rms_scaling,\n    )\n\n\ndef _layer_normalization(\n    x, gamma=None, beta=None, axis=-1, epsilon=None, rms_scaling=False\n):\n    if epsilon is None:\n        epsilon = backend.epsilon()\n    original_dtype = backend.standardize_dtype(x.dtype)\n    # Computes in at least float32 precision for stability in half precision\n    # training.\n    compute_dtype = backend.result_type(x.dtype, \"float32\")\n\n    x = backend.convert_to_tensor(x, dtype=compute_dtype)\n    if gamma is not None:\n        gamma = backend.convert_to_tensor(gamma, x.dtype)\n    if beta is not None:\n        beta = backend.convert_to_tensor(beta, x.dtype)\n\n    # Compute the axes along which to reduce the mean / variance\n    input_shape = x.shape\n    ndims = len(input_shape)\n\n    # Broadcasting only necessary for norm when the axis is not just\n    # 
the last dimension\n    broadcast_shape = [1] * ndims\n    if isinstance(axis, int):\n        axis = [axis]\n    axis = sorted(axis)\n    for dim in axis:\n        broadcast_shape[dim] = input_shape[dim]\n\n    def _broadcast(v):\n        if v is not None and len(v.shape) != ndims and axis != [ndims - 1]:\n            return backend.numpy.reshape(v, broadcast_shape)\n        return v\n\n    if rms_scaling:\n        variance = backend.numpy.var(x, axis=axis, keepdims=True)\n        inv = backend.math.rsqrt(variance + epsilon)\n        outputs = x * inv\n        if gamma is not None:\n            outputs = outputs * backend.cast(_broadcast(gamma), x.dtype)\n    elif backend.config.backend() == \"torch\" and is_continuous_axis(axis):\n        # When using the torch backend, use the kernel to improve performance.\n        import torch.nn.functional as F\n\n        normalized_shape = tuple(input_shape[dim] for dim in axis)\n        outputs = F.layer_norm(x, normalized_shape, gamma, beta, epsilon)\n    else:\n        # Calculate the mean & variance along `axis` (layer activations).\n        mean, variance = moments(x, axes=axis, keepdims=True)\n        gamma, beta = _broadcast(gamma), _broadcast(beta)\n        inv = backend.math.rsqrt(variance + epsilon)\n        if gamma is not None:\n            inv = inv * gamma\n\n        res = -mean * inv\n        if beta is not None:\n            res = res + beta\n\n        outputs = x * inv + res\n    return backend.cast(outputs, original_dtype)\n\n\nclass Polar(Operation):\n    def compute_output_spec(self, abs_, angle):\n        return KerasTensor(shape=abs_.shape)\n\n    def call(self, abs_, angle):\n        return _polar(abs_, angle)\n\n\n@keras_export([\"keras.ops.polar\", \"keras.ops.nn.polar\"])\ndef polar(abs_, angle):\n    \"\"\"Constructs a complex tensor whose elements are Cartesian\n    coordinates corresponding to the polar coordinates\n    with absolute value `abs_` and angle `angle`.\n\n    The operation is 
numerically equivalent to `torch.polar()`.\n    It is not equivalent to `scipy.linalg.polar()`, which computes\n    the polar decomposition of a matrix.\n\n    Given the magnitude (`abs_`) and angle (`angle`), this function computes the\n    corresponding complex number in the form of `real + imaginary * 1j`, where:\n    - `real = abs_ * cos(angle)`\n    - `imaginary = abs_ * sin(angle)`\n\n    Args:\n        abs_: The magnitude (absolute value) of the complex number.\n        angle: The angle (in radians) of the complex number.\n\n    Returns:\n        A complex number (or array of complex numbers) with the same shape as\n        `abs_` and `angle`.\n\n    Example:\n\n    >>> abs_ = keras.random.normal((1, 2))\n    >>> angle = keras.random.normal((1, 2))\n    >>> keras.ops.nn.polar(abs_, angle).shape\n    (1, 2)\n    >>> keras.ops.nn.polar(abs_, angle)\n    Array([[0.63185346-0.59370506j, 0.48960376-0.31677645j]], dtype=complex64)\n    \"\"\"\n    if any_symbolic_tensors((abs_, angle)):\n        return Polar().symbolic_call(abs_, angle)\n    return _polar(abs_, angle)\n\n\ndef _polar(abs_, angle):\n    \"\"\"Internal implementation of the polar function.\n\n    Args:\n        abs_: The magnitude (absolute value) of the complex number.\n        angle: The angle (in radians) of the complex number.\n\n    Returns:\n        A complex number (or array of complex numbers) with the same shape as\n        `abs_` and `angle`.\n    \"\"\"\n    abs_ = backend.convert_to_tensor(abs_)\n    angle = backend.convert_to_tensor(angle)\n\n    real = abs_ * backend.numpy.cos(angle)\n    imaginary = abs_ * backend.numpy.sin(angle)\n\n    result = backend.math._get_complex_tensor_from_tuple((real, imaginary))\n\n    return result\n\n\nclass Unfold(Operation):\n    def __init__(\n        self, kernel_size, dilation=1, padding=0, stride=1, *, name=None\n    ):\n        super().__init__(name=name)\n        self.kernel_size = kernel_size\n        self.dilation = dilation\n        self.padding = 
padding\n        self.stride = stride\n\n    def compute_output_spec(self, x):\n        N, C, H, W = x.shape\n\n        def _pair(x):\n            return (x, x) if isinstance(x, int) else x\n\n        kH, kW = _pair(self.kernel_size)\n        dH, dW = _pair(self.dilation)\n        pH, pW = _pair(self.padding)\n        sH, sW = _pair(self.stride)\n\n        def out_size(L, k, d, p, s):\n            return (L + 2 * p - d * (k - 1) - 1) // s + 1\n\n        outH = out_size(H, kH, dH, pH, sH)\n        outW = out_size(W, kW, dW, pW, sW)\n        return KerasTensor(shape=(N, C * kH * kW, outH * outW), dtype=x.dtype)\n\n    def call(self, x):\n        return _unfold(\n            x, self.kernel_size, self.dilation, self.padding, self.stride\n        )\n\n\n@keras_export([\"keras.ops.unfold\", \"keras.ops.nn.unfold\"])\ndef unfold(x, kernel_size, dilation=1, padding=0, stride=1):\n    \"\"\"Extract sliding local blocks from a 4-D input (batched image).\n\n    This operation is known as **im2col** when used with convolution.\n    It rearranges the image into overlapping or non-overlapping patches\n    and returns a tensor whose second axis contains the flattened\n    patches, while the last axis indexes the patch positions.\n\n    Args:\n        x: A 4-D tensor of shape `(N, C, H, W)` (**channels-first** format).\n        kernel_size: int or tuple of two ints, the size of the sliding window\n            `(kH, kW)`.  If a single int is given, it is used for both\n            dimensions.\n        dilation: int or tuple of two ints, the spacing between kernel points\n            (a.k.a. **dilation** or **atrous** convolution). Default: 1.\n        padding: int or tuple of two ints, the amount of zero-padding to apply\n            to both spatial dimensions. 
Default: 0.\n        stride: int or tuple of two ints, the step size of the sliding window.\n            Default: 1.\n\n    Returns:\n        A 3-D tensor of shape `(N, C * kH * kW, L)` where\n        `L = num_patches_H * num_patches_W` is the total number of patches\n        extracted.\n\n    Example:\n\n    >>> x = keras.ops.ones((1, 2, 4, 4))\n    >>> patches = keras.ops.unfold(x, kernel_size=2, stride=2)\n    >>> patches.shape\n    (1, 8, 4)\n\n    \"\"\"\n    input_shape = x.shape\n    ndims = len(input_shape)\n    if ndims != 4:\n        raise ValueError(\n            f\"Input must be a 4D tensor. Received: input.shape={input_shape}\"\n        )\n    if any_symbolic_tensors((x,)):\n        return Unfold(kernel_size, dilation, padding, stride).symbolic_call(x)\n    return _unfold(x, kernel_size, dilation, padding, stride)\n\n\ndef _unfold(x, kernel_size, dilation=1, padding=0, stride=1):\n    \"\"\"Internal implementation of unfold.\"\"\"\n    return backend.nn.unfold(\n        x,\n        kernel_size=kernel_size,\n        dilation=dilation,\n        padding=padding,\n        stride=stride,\n    )\n\n\nclass Fold(Operation):\n    def __init__(\n        self,\n        output_size,\n        kernel_size,\n        dilation=1,\n        padding=0,\n        stride=1,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.output_size = output_size\n        self.kernel_size = kernel_size\n        self.dilation = dilation\n        self.padding = padding\n        self.stride = stride\n\n    def compute_output_spec(self, x):\n        N, CKK, L = x.shape\n\n        def _pair(v):\n            return (v, v) if isinstance(v, int) else v\n\n        kH, kW = _pair(self.kernel_size)\n        oH, oW = _pair(self.output_size)\n\n        if CKK is not None and CKK % (kH * kW) != 0:\n            raise ValueError(\n                f\"The second dimension of the input ({CKK}) must be \"\n                f\"divisible by kernel_size product ({kH * 
kW}).\"\n            )\n\n        C = CKK // (kH * kW) if CKK is not None else None\n        return KerasTensor(shape=(N, C, oH, oW), dtype=x.dtype)\n\n    def call(self, x):\n        return backend.nn.fold(\n            x,\n            output_size=self.output_size,\n            kernel_size=self.kernel_size,\n            dilation=self.dilation,\n            padding=self.padding,\n            stride=self.stride,\n        )\n\n\n@keras_export([\"keras.ops.fold\", \"keras.ops.nn.fold\"])\ndef fold(x, output_size, kernel_size, dilation=1, padding=0, stride=1):\n    \"\"\"Combines an array of sliding local blocks into a large containing\n    tensor (reverses `unfold`).\n\n    This operation is known as **col2im** when used with convolution.\n    It takes a 3-D tensor of flattened patches and reconstructs a 4-D\n    image tensor by summing overlapping patches.\n\n    Args:\n        x: A 3-D tensor of shape `(N, C * kH * kW, L)` where `L` is\n            the total number of blocks.\n        output_size: int or tuple of two ints `(oH, oW)`, the spatial\n            shape of the output tensor.\n        kernel_size: int or tuple of two ints, the size of the sliding\n            window `(kH, kW)`.  If a single int is given, it is used\n            for both dimensions.\n        dilation: int or tuple of two ints, the spacing between kernel\n            points. Default: 1.\n        padding: int or tuple of two ints, the amount of zero-padding\n            that was applied to the input of `unfold`. Default: 0.\n        stride: int or tuple of two ints, the step size of the sliding\n            window. Default: 1.\n\n    Returns:\n        A 4-D tensor of shape `(N, C, oH, oW)`.\n\n    Example:\n\n    >>> x = keras.ops.ones((1, 2, 4, 4))\n    >>> patches = keras.ops.unfold(x, kernel_size=2, stride=2)\n    >>> patches.shape\n    (1, 8, 4)\n    >>> y = keras.ops.fold(patches, output_size=(4, 4),\n    ...     
kernel_size=2, stride=2)\n    >>> y.shape\n    (1, 2, 4, 4)\n\n    \"\"\"\n    input_shape = x.shape\n    ndims = len(input_shape)\n    if ndims != 3:\n        raise ValueError(\n            f\"Input must be a 3D tensor. Received: input.shape={input_shape}\"\n        )\n    if any_symbolic_tensors((x,)):\n        return Fold(\n            output_size, kernel_size, dilation, padding, stride\n        ).symbolic_call(x)\n    return backend.nn.fold(\n        x,\n        output_size=output_size,\n        kernel_size=kernel_size,\n        dilation=dilation,\n        padding=padding,\n        stride=stride,\n    )\n\n\nclass DepthToSpace(Operation):\n    def __init__(self, block_size, data_format=\"channels_last\", *, name=None):\n        super().__init__(name=name)\n        self.block_size = block_size\n        self.data_format = standardize_data_format(data_format)\n\n    def compute_output_spec(self, x):\n        if len(x.shape) != 4:\n            raise ValueError(\n                \"`depth_to_space` requires a 4D input tensor. 
\"\n                f\"Received: x.shape={x.shape}\"\n            )\n        if self.data_format == \"channels_last\":\n            b, h, w, c = x.shape\n        else:\n            b, c, h, w = x.shape\n\n        if c is not None and c % (self.block_size**2) != 0:\n            raise ValueError(\n                f\"The number of channels ({c}) must be divisible by \"\n                f\"block_size**2 ({self.block_size**2}).\"\n            )\n\n        new_c = c // (self.block_size**2) if c is not None else None\n        new_h = h * self.block_size if h is not None else None\n        new_w = w * self.block_size if w is not None else None\n\n        if self.data_format == \"channels_last\":\n            output_shape = (b, new_h, new_w, new_c)\n        else:\n            output_shape = (b, new_c, new_h, new_w)\n\n        return KerasTensor(output_shape, dtype=x.dtype)\n\n    def call(self, x):\n        return backend.nn.depth_to_space(\n            x, self.block_size, data_format=self.data_format\n        )\n\n\n@keras_export([\"keras.ops.depth_to_space\", \"keras.ops.nn.depth_to_space\"])\ndef depth_to_space(x, block_size, data_format=\"channels_last\"):\n    \"\"\"Rearranges data from depth into blocks of spatial data.\n\n    This operation is useful for resizing the activations between convolutions\n    (but keeping all data), e.g., instead of pooling. It is also useful for\n    training purely convolutional models.\n\n    Also known as pixel shuffle, this operation rearranges elements in a tensor\n    of shape `(N, H, W, C * r^2)` to `(N, H * r, W * r, C)` where `r` is the\n    `block_size` for `data_format=\"channels_last\"`, or from\n    `(N, C * r^2, H, W)` to `(N, C, H * r, W * r)` for\n    `data_format=\"channels_first\"`.\n\n    This is the reverse transformation of `space_to_depth`.\n\n    Args:\n        x: Input tensor. 
Must be 4D.\n        block_size: An integer specifying the size of the spatial block.\n            The depth (number of channels) must be divisible by\n            `block_size ** 2`.\n        data_format: A string specifying the data format of the input tensor.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)` while `\"channels_first\"`\n            corresponds to inputs with shape `(batch, channels, height, width)`.\n            Defaults to `\"channels_last\"`.\n\n    Returns:\n        A tensor with the same dtype as `x`, with shape\n        `(N, H * block_size, W * block_size, C // block_size ** 2)` for\n        `data_format=\"channels_last\"` or\n        `(N, C // block_size ** 2, H * block_size, W * block_size)` for\n        `data_format=\"channels_first\"`.\n\n    Example:\n\n    >>> x = keras.ops.reshape(keras.ops.arange(1 * 2 * 2 * 12), (1, 2, 2, 12))\n    >>> keras.ops.depth_to_space(x, block_size=2).shape\n    (1, 4, 4, 3)\n\n    >>> # channels_first example\n    >>> x = keras.ops.reshape(keras.ops.arange(1 * 12 * 2 * 2), (1, 12, 2, 2))\n    >>> keras.ops.depth_to_space(x, block_size=2,\n    ...                          data_format=\"channels_first\").shape\n    (1, 3, 4, 4)\n    \"\"\"\n    data_format = standardize_data_format(data_format)\n    if block_size < 2:\n        raise ValueError(\n            \"`block_size` must be at least 2. 
\"\n            f\"Received: block_size={block_size}\"\n        )\n    if any_symbolic_tensors((x,)):\n        return DepthToSpace(block_size, data_format=data_format).symbolic_call(\n            x\n        )\n    return backend.nn.depth_to_space(x, block_size, data_format=data_format)\n\n\nclass SpaceToDepth(Operation):\n    def __init__(self, block_size, data_format=\"channels_last\", *, name=None):\n        super().__init__(name=name)\n        self.block_size = block_size\n        self.data_format = standardize_data_format(data_format)\n\n    def compute_output_spec(self, x):\n        if len(x.shape) != 4:\n            raise ValueError(\n                \"`space_to_depth` requires a 4D input tensor. \"\n                f\"Received: x.shape={x.shape}\"\n            )\n        if self.data_format == \"channels_last\":\n            b, h, w, c = x.shape\n        else:\n            b, c, h, w = x.shape\n\n        if h is not None and h % self.block_size != 0:\n            raise ValueError(\n                f\"Height ({h}) must be divisible by block_size \"\n                f\"({self.block_size}).\"\n            )\n        if w is not None and w % self.block_size != 0:\n            raise ValueError(\n                f\"Width ({w}) must be divisible by block_size \"\n                f\"({self.block_size}).\"\n            )\n\n        new_c = c * (self.block_size**2) if c is not None else None\n        new_h = h // self.block_size if h is not None else None\n        new_w = w // self.block_size if w is not None else None\n\n        if self.data_format == \"channels_last\":\n            output_shape = (b, new_h, new_w, new_c)\n        else:\n            output_shape = (b, new_c, new_h, new_w)\n\n        return KerasTensor(output_shape, dtype=x.dtype)\n\n    def call(self, x):\n        return backend.nn.space_to_depth(\n            x, self.block_size, data_format=self.data_format\n        )\n\n\n@keras_export([\"keras.ops.space_to_depth\", 
\"keras.ops.nn.space_to_depth\"])\ndef space_to_depth(x, block_size, data_format=\"channels_last\"):\n    \"\"\"Rearranges blocks of spatial data into depth.\n\n    This operation is useful for resizing the activations between convolutions\n    (but keeping all data). It is also useful for training purely convolutional\n    models.\n\n    This operation rearranges elements in a tensor of shape\n    `(N, H * block_size, W * block_size, C)` to a tensor of shape\n    `(N, H, W, C * block_size ** 2)` (for `data_format=\"channels_last\"`)\n    or `(N, C, H * block_size, W * block_size)` to\n    `(N, C * block_size ** 2, H, W)` (for `data_format=\"channels_first\"`).\n\n    This is the reverse transformation of `depth_to_space`.\n\n    Args:\n        x: Input tensor. Must be 4D.\n        block_size: An integer specifying the size of the spatial block.\n            The height and width of the input must be divisible by\n            `block_size`.\n        data_format: A string specifying the data format of the input tensor.\n            `\"channels_last\"` corresponds to inputs with shape\n            `(batch, height, width, channels)` while `\"channels_first\"`\n            corresponds to inputs with shape `(batch, channels, height, width)`.\n            Defaults to `\"channels_last\"`.\n\n    Returns:\n        A tensor with the same dtype as `x`, with shape\n        `(N, H // block_size, W // block_size, C * block_size ** 2)` for\n        `data_format=\"channels_last\"` or\n        `(N, C * block_size ** 2, H // block_size, W // block_size)` for\n        `data_format=\"channels_first\"`.\n\n    Example:\n\n    >>> x = keras.ops.reshape(keras.ops.arange(1 * 4 * 4 * 3), (1, 4, 4, 3))\n    >>> keras.ops.space_to_depth(x, block_size=2).shape\n    (1, 2, 2, 12)\n\n    >>> # channels_first example\n    >>> x = keras.ops.reshape(keras.ops.arange(1 * 3 * 4 * 4), (1, 3, 4, 4))\n    >>> keras.ops.space_to_depth(x, block_size=2,\n    ...                          
data_format=\"channels_first\").shape\n    (1, 12, 2, 2)\n    \"\"\"\n    data_format = standardize_data_format(data_format)\n    if block_size < 2:\n        raise ValueError(\n            \"`block_size` must be at least 2. \"\n            f\"Received: block_size={block_size}\"\n        )\n    if any_symbolic_tensors((x,)):\n        return SpaceToDepth(block_size, data_format=data_format).symbolic_call(\n            x\n        )\n    return backend.nn.space_to_depth(x, block_size, data_format=data_format)\n"
  },
  {
    "path": "keras/src/ops/nn_test.py",
    "content": "import math\nfrom itertools import combinations\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import losses\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common import standardize_dtype\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.layers.convolutional.conv_test import np_conv1d\nfrom keras.src.layers.convolutional.conv_test import np_conv2d\nfrom keras.src.layers.convolutional.conv_test import np_conv3d\nfrom keras.src.layers.convolutional.conv_transpose_test import (\n    np_conv1d_transpose,\n)\nfrom keras.src.layers.convolutional.conv_transpose_test import (\n    np_conv2d_transpose,\n)\nfrom keras.src.layers.convolutional.depthwise_conv_test import (\n    np_depthwise_conv2d,\n)\nfrom keras.src.layers.pooling.average_pooling_test import np_avgpool1d\nfrom keras.src.layers.pooling.average_pooling_test import np_avgpool2d\nfrom keras.src.layers.pooling.max_pooling_test import np_maxpool1d\nfrom keras.src.layers.pooling.max_pooling_test import np_maxpool2d\nfrom keras.src.ops import nn as knn\nfrom keras.src.ops import numpy as knp\nfrom keras.src.testing.test_utils import named_product\n\n\ndef _dot_product_attention(\n    query, key, value, bias=None, mask=None, scale=None, is_causal=False\n):\n    # A pure and simplified numpy version of `dot_product_attention`\n    # Ref: jax.nn.dot_product_attention\n    # https://github.com/jax-ml/jax/blob/jax-v0.4.32/jax/_src/nn/functions.py#L828\n    # Not support `query_seq_lengths` and `key_value_seq_lengths` args\n\n    def _apply_masks(logits, mask, is_causal):\n        def _get_large_negative(dtype):\n            dtype = backend.standardize_dtype(dtype)\n            if dtype == \"float16\":\n                val = 65500.0\n            
else:\n                val = 3.38953e38\n            return np.asarray(val * -0.7, dtype=dtype)\n\n        def _get_causal_mask(query_length, key_length):\n            mask = np.tril(np.ones((query_length, key_length), dtype=np.bool_))\n            return mask[None, None, :, :]\n\n        if mask is None and not is_causal:\n            return logits\n        combined_mask = np.ones_like(logits, dtype=np.bool_)\n        if mask is not None:\n            combined_mask = np.logical_and(combined_mask, mask)\n        if is_causal:\n            T, S = logits.shape[2], logits.shape[3]\n            mask = _get_causal_mask(T, S)\n            combined_mask = np.logical_and(combined_mask, mask)\n        padded_logits = np.where(\n            combined_mask, logits, _get_large_negative(logits.dtype)\n        )\n        return padded_logits\n\n    def softmax(x, axis=None):\n        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))\n        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)\n\n    _, _, _, H = key.shape\n    scale = (1.0 / np.sqrt(H)) if scale is None else scale\n    logits = np.einsum(\"BTNH,BSNH->BNTS\", query, key)\n    logits *= np.array(scale, dtype=logits.dtype)\n    if bias is not None:\n        logits = (logits + bias).astype(logits.dtype)\n    padded_logits = _apply_masks(logits, mask, is_causal)\n    padded_logits = padded_logits.astype(np.float32)\n    probs = softmax(padded_logits, axis=-1).astype(key.dtype)\n    return np.einsum(\"BNTS,BSNH->BTNH\", probs, value)\n\n\nclass NNOpsDynamicShapeTest(testing.TestCase):\n    def test_relu(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.relu(x).shape, (None, 2, 3))\n\n    def test_relu6(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.relu6(x).shape, (None, 2, 3))\n\n    def test_sigmoid(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.sigmoid(x).shape, (None, 2, 3))\n\n    def test_sparse_sigmoid(self):\n        x 
= KerasTensor([None, 2, 3])\n        self.assertEqual(knn.sparse_sigmoid(x).shape, (None, 2, 3))\n\n    def test_softplus(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.softplus(x).shape, (None, 2, 3))\n\n    def test_softsign(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.softsign(x).shape, (None, 2, 3))\n\n    def test_silu(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.silu(x).shape, (None, 2, 3))\n\n    def test_log_sigmoid(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.log_sigmoid(x).shape, (None, 2, 3))\n\n    def test_leaky_relu(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.leaky_relu(x).shape, (None, 2, 3))\n\n    def test_hard_sigmoid(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.hard_sigmoid(x).shape, (None, 2, 3))\n\n    def test_hard_silu(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.hard_silu(x).shape, (None, 2, 3))\n\n    def test_elu(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.elu(x).shape, (None, 2, 3))\n\n    def test_selu(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.selu(x).shape, (None, 2, 3))\n\n    def test_gelu(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.gelu(x).shape, (None, 2, 3))\n\n    def test_celu(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.celu(x).shape, (None, 2, 3))\n\n    def test_glu(self):\n        x = KerasTensor([None, 2, 4])\n        self.assertEqual(knn.glu(x).shape, (None, 2, 2))\n\n    def test_tanh_shrink(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.tanh_shrink(x).shape, (None, 2, 3))\n\n    def test_hard_tanh(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.hard_tanh(x).shape, (None, 2, 3))\n\n    def test_hard_shrink(self):\n        x = 
KerasTensor([None, 2, 3])\n        self.assertEqual(knn.hard_shrink(x).shape, (None, 2, 3))\n\n    def test_threshold(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.threshold(x, 0, 0).shape, (None, 2, 3))\n\n    def test_squareplus(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.squareplus(x).shape, (None, 2, 3))\n\n    def test_soft_shrink(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.soft_shrink(x).shape, (None, 2, 3))\n\n    def test_sparse_plus(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.sparse_plus(x).shape, (None, 2, 3))\n\n    def test_softmax(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.softmax(x).shape, (None, 2, 3))\n        self.assertEqual(knn.softmax(x, axis=1).shape, (None, 2, 3))\n        self.assertEqual(knn.softmax(x, axis=-1).shape, (None, 2, 3))\n\n    def test_softmax_in_graph(self):\n        class SoftmaxLayer(keras.Layer):\n            def call(self, x):\n                return ops.softmax(x, axis=-1)\n\n        class Model(keras.Model):\n            def __init__(self):\n                x = keras.Input(shape=(None,))\n                y = SoftmaxLayer()(x)\n                super().__init__(inputs=x, outputs=y)\n\n        # Make sure Keras is able to compile the model graph\n        model = Model()\n        x = ops.array([[1.0, 2.0, 3.0, 4.0]])\n        model.predict(x)\n\n    def test_log_softmax(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.log_softmax(x).shape, (None, 2, 3))\n        self.assertEqual(knn.log_softmax(x, axis=1).shape, (None, 2, 3))\n        self.assertEqual(knn.log_softmax(x, axis=-1).shape, (None, 2, 3))\n\n    def test_sparsemax(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.sparsemax(x).shape, (None, 2, 3))\n\n    def test_max_pool(self):\n        data_format = backend.config.image_data_format()\n        if data_format 
== \"channels_last\":\n            input_shape = (None, 8, 3)\n        else:\n            input_shape = (None, 3, 8)\n        x = KerasTensor(input_shape)\n        self.assertEqual(\n            knn.max_pool(x, 2, 1).shape,\n            (None, 7, 3) if data_format == \"channels_last\" else (None, 3, 7),\n        )\n        self.assertEqual(\n            knn.max_pool(x, 2, 2, padding=\"same\").shape,\n            (None, 4, 3) if data_format == \"channels_last\" else (None, 3, 4),\n        )\n\n        if data_format == \"channels_last\":\n            input_shape = (None, 8, None, 3)\n        else:\n            input_shape = (None, 3, 8, None)\n        x = KerasTensor(input_shape)\n        self.assertEqual(\n            knn.max_pool(x, 2, 1).shape,\n            (\n                (None, 7, None, 3)\n                if data_format == \"channels_last\"\n                else (None, 3, 7, None)\n            ),\n        )\n        self.assertEqual(\n            knn.max_pool(x, 2, 2, padding=\"same\").shape,\n            (\n                (None, 4, None, 3)\n                if data_format == \"channels_last\"\n                else (None, 3, 4, None)\n            ),\n        )\n        self.assertEqual(\n            knn.max_pool(x, (2, 2), (2, 2), padding=\"same\").shape,\n            (\n                (None, 4, None, 3)\n                if data_format == \"channels_last\"\n                else (None, 3, 4, None)\n            ),\n        )\n\n    def test_average_pool(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_shape = (None, 8, 3)\n        else:\n            input_shape = (None, 3, 8)\n        x = KerasTensor(input_shape)\n        self.assertEqual(\n            knn.average_pool(x, 2, 1).shape,\n            (None, 7, 3) if data_format == \"channels_last\" else (None, 3, 7),\n        )\n        self.assertEqual(\n            knn.average_pool(x, 2, 2, padding=\"same\").shape,\n            (None, 4, 3) if data_format == \"channels_last\" else 
(None, 3, 4),\n        )\n\n        if data_format == \"channels_last\":\n            input_shape = (None, 8, None, 3)\n        else:\n            input_shape = (None, 3, 8, None)\n        x = KerasTensor(input_shape)\n        self.assertEqual(\n            knn.average_pool(x, 2, 1).shape,\n            (\n                (None, 7, None, 3)\n                if data_format == \"channels_last\"\n                else (None, 3, 7, None)\n            ),\n        )\n        self.assertEqual(\n            knn.average_pool(x, 2, 2, padding=\"same\").shape,\n            (\n                (None, 4, None, 3)\n                if data_format == \"channels_last\"\n                else (None, 3, 4, None)\n            ),\n        )\n        self.assertEqual(\n            knn.average_pool(x, (2, 2), (2, 2), padding=\"same\").shape,\n            (\n                (None, 4, None, 3)\n                if data_format == \"channels_last\"\n                else (None, 3, 4, None)\n            ),\n        )\n\n    def test_multi_hot(self):\n        x = KerasTensor([None, 3, 1])\n        self.assertEqual(knn.multi_hot(x, 5).shape, (None, 1, 5))\n        self.assertEqual(knn.multi_hot(x, 5, 1).shape, (None, 3, 1))\n        self.assertEqual(knn.multi_hot(x, 5, 2).shape, (None, 5, 1))\n        self.assertSparse(knn.multi_hot(x, 5, sparse=True))\n\n    @parameterized.named_parameters(\n        named_product(dtype=[\"float32\", \"int32\", \"bool\"], sparse=[False, True])\n    )\n    def test_multi_hot_dtype(self, dtype, sparse):\n        if sparse and not backend.SUPPORTS_SPARSE_TENSORS:\n            pytest.skip(\"Backend does not support sparse tensors\")\n\n        x = np.arange(5)\n        out = knn.multi_hot(x, 5, axis=0, dtype=dtype, sparse=sparse)\n        self.assertEqual(backend.standardize_dtype(out.dtype), dtype)\n        self.assertSparse(out, sparse)\n\n    def test_conv(self):\n        data_format = backend.config.image_data_format()\n        # Test 1D conv.\n        if data_format 
== \"channels_last\":\n            input_shape = (None, 20, 3)\n        else:\n            input_shape = (None, 3, 20)\n        inputs_1d = KerasTensor(input_shape)\n        kernel = KerasTensor([4, 3, 2])\n        for padding in [\"valid\", \"VALID\"]:\n            self.assertEqual(\n                knn.conv(inputs_1d, kernel, 1, padding=padding).shape,\n                (\n                    (None, 17, 2)\n                    if data_format == \"channels_last\"\n                    else (None, 2, 17)\n                ),\n            )\n        for padding in [\"same\", \"SAME\"]:\n            self.assertEqual(\n                knn.conv(inputs_1d, kernel, 1, padding=padding).shape,\n                (\n                    (None, 20, 2)\n                    if data_format == \"channels_last\"\n                    else (None, 2, 20)\n                ),\n            )\n        self.assertEqual(\n            knn.conv(inputs_1d, kernel, (2,), dilation_rate=2).shape,\n            (None, 7, 2) if data_format == \"channels_last\" else (None, 2, 7),\n        )\n\n        # Test 2D conv.\n        if data_format == \"channels_last\":\n            input_shape = (None, 10, None, 3)\n        else:\n            input_shape = (None, 3, 10, None)\n        inputs_2d = KerasTensor(input_shape)\n        kernel = KerasTensor([2, 2, 3, 2])\n        for padding in [\"valid\", \"VALID\"]:\n            self.assertEqual(\n                knn.conv(inputs_2d, kernel, 1, padding=padding).shape,\n                (\n                    (None, 9, None, 2)\n                    if data_format == \"channels_last\"\n                    else (None, 2, 9, None)\n                ),\n            )\n        for padding in [\"same\", \"SAME\"]:\n            self.assertEqual(\n                knn.conv(inputs_2d, kernel, 1, padding=padding).shape,\n                (\n                    (None, 10, None, 2)\n                    if data_format == \"channels_last\"\n                    else (None, 2, 10, 
None)\n                ),\n            )\n        self.assertEqual(\n            knn.conv(inputs_2d, kernel, (2, 1), dilation_rate=(2, 1)).shape,\n            (\n                (None, 4, None, 2)\n                if data_format == \"channels_last\"\n                else (None, 2, 4, None)\n            ),\n        )\n\n        # Test 2D conv - H, W specified\n        if data_format == \"channels_last\":\n            input_shape = (None, 10, 10, 3)\n        else:\n            input_shape = (None, 3, 10, 10)\n        inputs_2d = KerasTensor(input_shape)\n        kernel = KerasTensor([2, 2, 3, 2])\n        for padding in [\"valid\", \"VALID\"]:\n            self.assertEqual(\n                knn.conv(inputs_2d, kernel, 1, padding=padding).shape,\n                (\n                    (None, 9, 9, 2)\n                    if data_format == \"channels_last\"\n                    else (None, 2, 9, 9)\n                ),\n            )\n        for padding in [\"same\", \"SAME\"]:\n            self.assertEqual(\n                knn.conv(inputs_2d, kernel, 1, padding=padding).shape,\n                (\n                    (None, 10, 10, 2)\n                    if data_format == \"channels_last\"\n                    else (None, 2, 10, 10)\n                ),\n            )\n        self.assertEqual(\n            knn.conv(inputs_2d, kernel, (2, 1), dilation_rate=(2, 1)).shape,\n            (\n                (None, 4, 9, 2)\n                if data_format == \"channels_last\"\n                else (None, 2, 4, 9)\n            ),\n        )\n\n        # Test 3D conv.\n        if data_format == \"channels_last\":\n            input_shape = (None, 8, None, 8, 3)\n        else:\n            input_shape = (None, 3, 8, None, 8)\n        inputs_3d = KerasTensor(input_shape)\n        kernel = KerasTensor([3, 3, 3, 3, 2])\n        for padding in [\"valid\", \"VALID\"]:\n            self.assertEqual(\n                knn.conv(inputs_3d, kernel, 1, padding=padding).shape,\n            
    (\n                    (None, 6, None, 6, 2)\n                    if data_format == \"channels_last\"\n                    else (None, 2, 6, None, 6)\n                ),\n            )\n        for padding in [\"same\", \"SAME\"]:\n            self.assertEqual(\n                knn.conv(inputs_3d, kernel, (2, 1, 2), padding=padding).shape,\n                (\n                    (None, 4, None, 4, 2)\n                    if data_format == \"channels_last\"\n                    else (None, 2, 4, None, 4)\n                ),\n            )\n        self.assertEqual(\n            knn.conv(\n                inputs_3d, kernel, 1, padding=\"valid\", dilation_rate=(1, 2, 2)\n            ).shape,\n            (\n                (None, 6, None, 4, 2)\n                if data_format == \"channels_last\"\n                else (None, 2, 6, None, 4)\n            ),\n        )\n\n    def test_depthwise_conv(self):\n        data_format = backend.config.image_data_format()\n        # Test 1D depthwise conv.\n        if data_format == \"channels_last\":\n            input_shape = (None, 20, 3)\n        else:\n            input_shape = (None, 3, 20)\n        inputs_1d = KerasTensor(input_shape)\n        kernel = KerasTensor([4, 3, 1])\n        for padding in [\"valid\", \"VALID\"]:\n            self.assertEqual(\n                knn.depthwise_conv(inputs_1d, kernel, 1, padding=padding).shape,\n                (\n                    (None, 17, 3)\n                    if data_format == \"channels_last\"\n                    else (None, 3, 17)\n                ),\n            )\n        for padding in [\"same\", \"SAME\"]:\n            self.assertEqual(\n                knn.depthwise_conv(\n                    inputs_1d, kernel, (1,), padding=padding\n                ).shape,\n                (\n                    (None, 20, 3)\n                    if data_format == \"channels_last\"\n                    else (None, 3, 20)\n                ),\n            )\n        
self.assertEqual(\n            knn.depthwise_conv(inputs_1d, kernel, 2, dilation_rate=2).shape,\n            (None, 7, 3) if data_format == \"channels_last\" else (None, 3, 7),\n        )\n\n        # Test 2D depthwise conv.\n        if data_format == \"channels_last\":\n            input_shape = (None, 10, 10, 3)\n        else:\n            input_shape = (None, 3, 10, 10)\n        inputs_2d = KerasTensor(input_shape)\n        kernel = KerasTensor([2, 2, 3, 1])\n        for padding in [\"valid\", \"VALID\"]:\n            self.assertEqual(\n                knn.depthwise_conv(inputs_2d, kernel, 1, padding=padding).shape,\n                (\n                    (None, 9, 9, 3)\n                    if data_format == \"channels_last\"\n                    else (None, 3, 9, 9)\n                ),\n            )\n        for padding in [\"same\", \"SAME\"]:\n            self.assertEqual(\n                knn.depthwise_conv(\n                    inputs_2d, kernel, (1, 2), padding=padding\n                ).shape,\n                (\n                    (None, 10, 5, 3)\n                    if data_format == \"channels_last\"\n                    else (None, 3, 10, 5)\n                ),\n            )\n        self.assertEqual(\n            knn.depthwise_conv(inputs_2d, kernel, 2, dilation_rate=2).shape,\n            (\n                (None, 4, 4, 3)\n                if data_format == \"channels_last\"\n                else (None, 3, 4, 4)\n            ),\n        )\n        self.assertEqual(\n            knn.depthwise_conv(\n                inputs_2d, kernel, 2, dilation_rate=(2, 1)\n            ).shape,\n            (\n                (None, 4, 5, 3)\n                if data_format == \"channels_last\"\n                else (None, 3, 4, 5)\n            ),\n        )\n\n    def test_separable_conv(self):\n        data_format = backend.config.image_data_format()\n        # Test 1D separable conv.\n        if data_format == \"channels_last\":\n            input_shape = 
(None, 20, 3)\n        else:\n            input_shape = (None, 3, 20)\n        inputs_1d = KerasTensor(input_shape)\n        kernel = KerasTensor([4, 3, 2])\n        pointwise_kernel = KerasTensor([1, 6, 5])\n        self.assertEqual(\n            knn.separable_conv(\n                inputs_1d, kernel, pointwise_kernel, 1, padding=\"valid\"\n            ).shape,\n            (None, 17, 5) if data_format == \"channels_last\" else (None, 5, 17),\n        )\n        self.assertEqual(\n            knn.separable_conv(\n                inputs_1d, kernel, pointwise_kernel, 1, padding=\"same\"\n            ).shape,\n            (None, 20, 5) if data_format == \"channels_last\" else (None, 5, 20),\n        )\n        self.assertEqual(\n            knn.separable_conv(\n                inputs_1d, kernel, pointwise_kernel, 2, dilation_rate=2\n            ).shape,\n            (None, 7, 5) if data_format == \"channels_last\" else (None, 5, 7),\n        )\n\n        # Test 2D separable conv.\n        if data_format == \"channels_last\":\n            input_shape = (None, 10, 10, 3)\n        else:\n            input_shape = (None, 3, 10, 10)\n        inputs_2d = KerasTensor(input_shape)\n        kernel = KerasTensor([2, 2, 3, 2])\n        pointwise_kernel = KerasTensor([1, 1, 6, 5])\n        self.assertEqual(\n            knn.separable_conv(\n                inputs_2d, kernel, pointwise_kernel, 1, padding=\"valid\"\n            ).shape,\n            (\n                (None, 9, 9, 5)\n                if data_format == \"channels_last\"\n                else (None, 5, 9, 9)\n            ),\n        )\n        self.assertEqual(\n            knn.separable_conv(\n                inputs_2d, kernel, pointwise_kernel, (1, 2), padding=\"same\"\n            ).shape,\n            (\n                (None, 10, 5, 5)\n                if data_format == \"channels_last\"\n                else (None, 5, 10, 5)\n            ),\n        )\n        self.assertEqual(\n            
knn.separable_conv(\n                inputs_2d, kernel, pointwise_kernel, 2, dilation_rate=(2, 1)\n            ).shape,\n            (\n                (None, 4, 5, 5)\n                if data_format == \"channels_last\"\n                else (None, 5, 4, 5)\n            ),\n        )\n\n    def test_conv_transpose(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_shape = (None, 4, 3)\n        else:\n            input_shape = (None, 3, 4)\n        inputs_1d = KerasTensor(input_shape)\n        kernel = KerasTensor([2, 5, 3])\n        self.assertEqual(\n            knn.conv_transpose(inputs_1d, kernel, 2).shape,\n            (None, 8, 5) if data_format == \"channels_last\" else (None, 5, 8),\n        )\n        self.assertEqual(\n            knn.conv_transpose(inputs_1d, kernel, 2, padding=\"same\").shape,\n            (None, 8, 5) if data_format == \"channels_last\" else (None, 5, 8),\n        )\n        self.assertEqual(\n            knn.conv_transpose(\n                inputs_1d, kernel, 5, padding=\"valid\", output_padding=4\n            ).shape,\n            (None, 21, 5) if data_format == \"channels_last\" else (None, 5, 21),\n        )\n\n        if data_format == \"channels_last\":\n            input_shape = (None, 4, 4, 3)\n        else:\n            input_shape = (None, 3, 4, 4)\n        inputs_2d = KerasTensor(input_shape)\n        kernel = KerasTensor([2, 2, 5, 3])\n        self.assertEqual(\n            knn.conv_transpose(inputs_2d, kernel, 2).shape,\n            (\n                (None, 8, 8, 5)\n                if data_format == \"channels_last\"\n                else (None, 5, 8, 8)\n            ),\n        )\n        self.assertEqual(\n            knn.conv_transpose(inputs_2d, kernel, (2, 2), padding=\"same\").shape,\n            (\n                (None, 8, 8, 5)\n                if data_format == \"channels_last\"\n                else (None, 5, 8, 8)\n            ),\n 
       )\n        self.assertEqual(\n            knn.conv_transpose(\n                inputs_2d, kernel, (5, 5), padding=\"valid\", output_padding=4\n            ).shape,\n            (\n                (None, 21, 21, 5)\n                if data_format == \"channels_last\"\n                else (None, 5, 21, 21)\n            ),\n        )\n\n    def test_one_hot(self):\n        x = KerasTensor([None, 3, 1])\n        self.assertEqual(knn.one_hot(x, 5).shape, (None, 3, 1, 5))\n        self.assertEqual(knn.one_hot(x, 5, 1).shape, (None, 5, 3, 1))\n        self.assertEqual(knn.one_hot(x, 5, 2).shape, (None, 3, 5, 1))\n        self.assertSparse(knn.one_hot(x, 5, sparse=True))\n\n    @parameterized.named_parameters(\n        named_product(dtype=[\"float32\", \"int32\", \"bool\"], sparse=[False, True])\n    )\n    def test_one_hot_dtype(self, dtype, sparse):\n        if sparse and not backend.SUPPORTS_SPARSE_TENSORS:\n            pytest.skip(\"Backend does not support sparse tensors\")\n\n        x = np.arange(5)\n        out = knn.one_hot(x, 5, axis=0, dtype=dtype, sparse=sparse)\n        self.assertEqual(backend.standardize_dtype(out.dtype), dtype)\n        self.assertSparse(out, sparse)\n\n    def test_moments(self):\n        x = KerasTensor([None, 3, 4])\n        self.assertEqual(knn.moments(x, axes=[0])[0].shape, (3, 4))\n        self.assertEqual(knn.moments(x, axes=[0, 1])[0].shape, (4,))\n        self.assertEqual(\n            knn.moments(x, axes=[0, 1], keepdims=True)[0].shape, (1, 1, 4)\n        )\n\n        self.assertEqual(knn.moments(x, axes=[1])[0].shape, (None, 4))\n        self.assertEqual(knn.moments(x, axes=[1, 2])[0].shape, (None,))\n        self.assertEqual(\n            knn.moments(x, axes=[1, 2], keepdims=True)[0].shape, (None, 1, 1)\n        )\n\n    def test_batch_normalization(self):\n        x = KerasTensor([None, 3, 4])\n        mean = KerasTensor([4])\n        variance = KerasTensor([4])\n        self.assertEqual(\n            
knn.batch_normalization(x, mean, variance, axis=-1).shape,\n            (None, 3, 4),\n        )\n\n        x = KerasTensor([None, 3, 4, 5])\n        self.assertEqual(\n            knn.batch_normalization(x, mean, variance, axis=2).shape,\n            (None, 3, 4, 5),\n        )\n\n        mean = KerasTensor([3])\n        variance = KerasTensor([3])\n        self.assertEqual(\n            knn.batch_normalization(x, mean, variance, axis=1).shape,\n            (None, 3, 4, 5),\n        )\n\n        # Test wrong offset shape\n        self.assertRaisesRegex(\n            ValueError,\n            \"`offset` must be a vector of length\",\n            knn.batch_normalization,\n            KerasTensor([None, 3, 4, 5]),\n            KerasTensor([5]),\n            KerasTensor([5]),\n            axis=-1,\n            offset=KerasTensor([3]),\n            scale=KerasTensor([5]),\n        )\n\n        # Test wrong scale shape\n        self.assertRaisesRegex(\n            ValueError,\n            \"`scale` must be a vector of length\",\n            knn.batch_normalization,\n            KerasTensor([None, 3, 4, 5]),\n            KerasTensor([5]),\n            KerasTensor([5]),\n            axis=-1,\n            offset=KerasTensor([5]),\n            scale=KerasTensor([3]),\n        )\n\n    def test_ctc_decode(self):\n        # Test strategy=\"greedy\"\n        inputs = KerasTensor([None, 2, 3])\n        sequence_lengths = KerasTensor([None])\n        decoded, scores = knn.ctc_decode(inputs, sequence_lengths)\n        self.assertEqual(decoded.shape, (1, None, 2))\n        self.assertEqual(scores.shape, (None, 1))\n\n        # Test strategy=\"beam_search\"\n        inputs = KerasTensor([None, 2, 3])\n        sequence_lengths = KerasTensor([None])\n        decoded, scores = knn.ctc_decode(\n            inputs, sequence_lengths, strategy=\"beam_search\", top_paths=2\n        )\n        self.assertEqual(decoded.shape, (2, None, 2))\n        self.assertEqual(scores.shape, (None, 
2))\n\n    def test_normalize(self):\n        x = KerasTensor([None, 2, 3])\n        self.assertEqual(knn.normalize(x).shape, (None, 2, 3))\n\n    def test_psnr(self):\n        x1 = KerasTensor([None, 2, 3])\n        x2 = KerasTensor([None, 5, 6])\n        out = knn.psnr(x1, x2, max_val=224)\n        self.assertEqual(out.shape, ())\n\n    def test_dot_product_attention(self):\n        query = KerasTensor([None, None, 8, 16])\n        key = KerasTensor([None, None, 6, 16])\n        value = KerasTensor([None, None, 6, 16])\n        out = knn.dot_product_attention(query, key, value)\n        self.assertEqual(out.shape, query.shape)\n\n    def test_rms_normalization(self):\n        x = KerasTensor([None, 8, 16])\n        scale = KerasTensor([None, 8, 16])\n        out = knn.rms_normalization(x, scale)\n        self.assertEqual(out.shape, x.shape)\n\n    def test_layer_normalization(self):\n        x = KerasTensor([None, 8, 16])\n        gamma = KerasTensor([None, 16])\n        beta = KerasTensor([None, 16])\n        out = knn.layer_normalization(x, gamma, beta)\n        self.assertEqual(out.shape, x.shape)\n\n\nclass NNOpsStaticShapeTest(testing.TestCase):\n    def test_relu(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.relu(x).shape, (1, 2, 3))\n\n    def test_relu6(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.relu6(x).shape, (1, 2, 3))\n\n    def test_sigmoid(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.sigmoid(x).shape, (1, 2, 3))\n\n    def test_sparse_sigmoid(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.sparse_sigmoid(x).shape, (1, 2, 3))\n\n    def test_softplus(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.softplus(x).shape, (1, 2, 3))\n\n    def test_softsign(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.softsign(x).shape, (1, 2, 3))\n\n    def test_silu(self):\n        x = KerasTensor([1, 2, 3])\n  
      self.assertEqual(knn.silu(x).shape, (1, 2, 3))\n\n    def test_log_sigmoid(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.log_sigmoid(x).shape, (1, 2, 3))\n\n    def test_leaky_relu(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.leaky_relu(x).shape, (1, 2, 3))\n\n    def test_hard_sigmoid(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.hard_sigmoid(x).shape, (1, 2, 3))\n\n    def test_hard_silu(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.hard_silu(x).shape, (1, 2, 3))\n\n    def test_elu(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.elu(x).shape, (1, 2, 3))\n\n    def test_selu(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.selu(x).shape, (1, 2, 3))\n\n    def test_gelu(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.gelu(x).shape, (1, 2, 3))\n\n    def test_celu(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.celu(x).shape, (1, 2, 3))\n\n    def test_glu(self):\n        x = KerasTensor([1, 2, 4])\n        self.assertEqual(knn.glu(x).shape, (1, 2, 2))\n\n    def test_tanh_shrink(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.tanh_shrink(x).shape, (1, 2, 3))\n\n    def test_hard_tanh(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.hard_tanh(x).shape, (1, 2, 3))\n\n    def test_hard_shrink(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.hard_shrink(x).shape, (1, 2, 3))\n\n    def test_threshold(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.threshold(x, 0, 0).shape, (1, 2, 3))\n\n    def test_squareplus(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.squareplus(x).shape, (1, 2, 3))\n\n    def test_soft_shrink(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.soft_shrink(x).shape, (1, 2, 3))\n\n    def 
test_sparse_plus(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.sparse_plus(x).shape, (1, 2, 3))\n\n    def test_softmax(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.softmax(x).shape, (1, 2, 3))\n        self.assertEqual(knn.softmax(x, axis=1).shape, (1, 2, 3))\n        self.assertEqual(knn.softmax(x, axis=-1).shape, (1, 2, 3))\n\n    def test_log_softmax(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.log_softmax(x).shape, (1, 2, 3))\n        self.assertEqual(knn.log_softmax(x, axis=1).shape, (1, 2, 3))\n        self.assertEqual(knn.log_softmax(x, axis=-1).shape, (1, 2, 3))\n\n    def test_sparsemax(self):\n        x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.sparsemax(x).shape, (1, 2, 3))\n\n    def test_max_pool(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_shape = (1, 8, 3)\n        else:\n            input_shape = (1, 3, 8)\n        x = KerasTensor(input_shape)\n        self.assertEqual(\n            knn.max_pool(x, 2, 1).shape,\n            (1, 7, 3) if data_format == \"channels_last\" else (1, 3, 7),\n        )\n        self.assertEqual(\n            knn.max_pool(x, 2, 2, padding=\"same\").shape,\n            (1, 4, 3) if data_format == \"channels_last\" else (1, 3, 4),\n        )\n\n        if data_format == \"channels_last\":\n            input_shape = (1, 8, 8, 3)\n        else:\n            input_shape = (1, 3, 8, 8)\n        x = KerasTensor(input_shape)\n        self.assertEqual(\n            knn.max_pool(x, 2, 1).shape,\n            (1, 7, 7, 3) if data_format == \"channels_last\" else (1, 3, 7, 7),\n        )\n        self.assertEqual(\n            knn.max_pool(x, 2, 2, padding=\"same\").shape,\n            (1, 4, 4, 3) if data_format == \"channels_last\" else (1, 3, 4, 4),\n        )\n        self.assertEqual(\n            knn.max_pool(x, (2, 2), (2, 2), 
padding=\"same\").shape,\n            (1, 4, 4, 3) if data_format == \"channels_last\" else (1, 3, 4, 4),\n        )\n\n    def test_average_pool(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_shape = (1, 8, 3)\n        else:\n            input_shape = (1, 3, 8)\n        x = KerasTensor(input_shape)\n        self.assertEqual(\n            knn.average_pool(x, 2, 1).shape,\n            (1, 7, 3) if data_format == \"channels_last\" else (1, 3, 7),\n        )\n        self.assertEqual(\n            knn.average_pool(x, 2, 2, padding=\"same\").shape,\n            (1, 4, 3) if data_format == \"channels_last\" else (1, 3, 4),\n        )\n\n        if data_format == \"channels_last\":\n            input_shape = (1, 8, 8, 3)\n        else:\n            input_shape = (1, 3, 8, 8)\n        x = KerasTensor(input_shape)\n        self.assertEqual(\n            knn.average_pool(x, 2, 1).shape,\n            (1, 7, 7, 3) if data_format == \"channels_last\" else (1, 3, 7, 7),\n        )\n        self.assertEqual(\n            knn.average_pool(x, 2, 2, padding=\"same\").shape,\n            (1, 4, 4, 3) if data_format == \"channels_last\" else (1, 3, 4, 4),\n        )\n        self.assertEqual(\n            knn.average_pool(x, (2, 2), (2, 2), padding=\"same\").shape,\n            (1, 4, 4, 3) if data_format == \"channels_last\" else (1, 3, 4, 4),\n        )\n\n    def test_conv(self):\n        data_format = backend.config.image_data_format()\n        # Test 1D conv.\n        if data_format == \"channels_last\":\n            input_shape = (2, 20, 3)\n        else:\n            input_shape = (2, 3, 20)\n        inputs_1d = KerasTensor(input_shape)\n        kernel = KerasTensor([4, 3, 2])\n        self.assertEqual(\n            knn.conv(inputs_1d, kernel, 1, padding=\"valid\").shape,\n            (2, 17, 2) if data_format == \"channels_last\" else (2, 2, 17),\n        )\n        self.assertEqual(\n          
  knn.conv(inputs_1d, kernel, 1, padding=\"same\").shape,\n            (2, 20, 2) if data_format == \"channels_last\" else (2, 2, 20),\n        )\n        self.assertEqual(\n            knn.conv(inputs_1d, kernel, (2,), dilation_rate=2).shape,\n            (2, 7, 2) if data_format == \"channels_last\" else (2, 2, 7),\n        )\n\n        # Test 2D conv.\n        if data_format == \"channels_last\":\n            input_shape = (2, 10, 10, 3)\n        else:\n            input_shape = (2, 3, 10, 10)\n        inputs_2d = KerasTensor(input_shape)\n        kernel = KerasTensor([2, 2, 3, 2])\n        self.assertEqual(\n            knn.conv(inputs_2d, kernel, 1, padding=\"valid\").shape,\n            (2, 9, 9, 2) if data_format == \"channels_last\" else (2, 2, 9, 9),\n        )\n        self.assertEqual(\n            knn.conv(inputs_2d, kernel, 1, padding=\"same\").shape,\n            (\n                (2, 10, 10, 2)\n                if data_format == \"channels_last\"\n                else (2, 2, 10, 10)\n            ),\n        )\n        self.assertEqual(\n            knn.conv(inputs_2d, kernel, (2, 1), dilation_rate=(2, 1)).shape,\n            (2, 4, 9, 2) if data_format == \"channels_last\" else (2, 2, 4, 9),\n        )\n\n        # Test 3D conv.\n        if data_format == \"channels_last\":\n            input_shape = (2, 8, 8, 8, 3)\n        else:\n            input_shape = (2, 3, 8, 8, 8)\n        inputs_3d = KerasTensor(input_shape)\n        kernel = KerasTensor([3, 3, 3, 3, 2])\n        self.assertEqual(\n            knn.conv(inputs_3d, kernel, 1, padding=\"valid\").shape,\n            (\n                (2, 6, 6, 6, 2)\n                if data_format == \"channels_last\"\n                else (2, 2, 6, 6, 6)\n            ),\n        )\n        self.assertEqual(\n            knn.conv(inputs_3d, kernel, (2, 1, 2), padding=\"same\").shape,\n            (\n                (2, 4, 8, 4, 2)\n                if data_format == \"channels_last\"\n                else (2, 
2, 4, 8, 4)\n            ),\n        )\n        self.assertEqual(\n            knn.conv(\n                inputs_3d, kernel, 1, padding=\"valid\", dilation_rate=(1, 2, 2)\n            ).shape,\n            (\n                (2, 6, 4, 4, 2)\n                if data_format == \"channels_last\"\n                else (2, 2, 6, 4, 4)\n            ),\n        )\n\n    def test_depthwise_conv(self):\n        data_format = backend.config.image_data_format()\n        # Test 1D depthwise conv.\n        if data_format == \"channels_last\":\n            input_shape = (2, 20, 3)\n        else:\n            input_shape = (2, 3, 20)\n        inputs_1d = KerasTensor(input_shape)\n        kernel = KerasTensor([4, 3, 1])\n        self.assertEqual(\n            knn.depthwise_conv(inputs_1d, kernel, 1, padding=\"valid\").shape,\n            (2, 17, 3) if data_format == \"channels_last\" else (2, 3, 17),\n        )\n        self.assertEqual(\n            knn.depthwise_conv(inputs_1d, kernel, (1,), padding=\"same\").shape,\n            (2, 20, 3) if data_format == \"channels_last\" else (2, 3, 20),\n        )\n        self.assertEqual(\n            knn.depthwise_conv(inputs_1d, kernel, 2, dilation_rate=2).shape,\n            (2, 7, 3) if data_format == \"channels_last\" else (2, 3, 7),\n        )\n\n        # Test 2D depthwise conv.\n        if data_format == \"channels_last\":\n            input_shape = (2, 10, 10, 3)\n        else:\n            input_shape = (2, 3, 10, 10)\n        inputs_2d = KerasTensor(input_shape)\n        kernel = KerasTensor([2, 2, 3, 1])\n        self.assertEqual(\n            knn.depthwise_conv(inputs_2d, kernel, 1, padding=\"valid\").shape,\n            (2, 9, 9, 3) if data_format == \"channels_last\" else (2, 3, 9, 9),\n        )\n        self.assertEqual(\n            knn.depthwise_conv(inputs_2d, kernel, (1, 2), padding=\"same\").shape,\n            (2, 10, 5, 3) if data_format == \"channels_last\" else (2, 3, 10, 5),\n        )\n        
self.assertEqual(\n            knn.depthwise_conv(inputs_2d, kernel, 2, dilation_rate=2).shape,\n            (2, 4, 4, 3) if data_format == \"channels_last\" else (2, 3, 4, 4),\n        )\n        self.assertEqual(\n            knn.depthwise_conv(\n                inputs_2d, kernel, 2, dilation_rate=(2, 1)\n            ).shape,\n            (2, 4, 5, 3) if data_format == \"channels_last\" else (2, 3, 4, 5),\n        )\n\n    def test_separable_conv(self):\n        data_format = backend.config.image_data_format()\n        # Test 1D separable conv.\n        if data_format == \"channels_last\":\n            input_shape = (2, 20, 3)\n        else:\n            input_shape = (2, 3, 20)\n        inputs_1d = KerasTensor(input_shape)\n        kernel = KerasTensor([4, 3, 2])\n        pointwise_kernel = KerasTensor([1, 6, 5])\n        self.assertEqual(\n            knn.separable_conv(\n                inputs_1d, kernel, pointwise_kernel, 1, padding=\"valid\"\n            ).shape,\n            (2, 17, 5) if data_format == \"channels_last\" else (2, 5, 17),\n        )\n        self.assertEqual(\n            knn.separable_conv(\n                inputs_1d, kernel, pointwise_kernel, 1, padding=\"same\"\n            ).shape,\n            (2, 20, 5) if data_format == \"channels_last\" else (2, 5, 20),\n        )\n        self.assertEqual(\n            knn.separable_conv(\n                inputs_1d, kernel, pointwise_kernel, 2, dilation_rate=2\n            ).shape,\n            (2, 7, 5) if data_format == \"channels_last\" else (2, 5, 7),\n        )\n\n        # Test 2D separable conv.\n        if data_format == \"channels_last\":\n            input_shape = (2, 10, 10, 3)\n        else:\n            input_shape = (2, 3, 10, 10)\n        inputs_2d = KerasTensor(input_shape)\n        kernel = KerasTensor([2, 2, 3, 2])\n        pointwise_kernel = KerasTensor([1, 1, 6, 5])\n        self.assertEqual(\n            knn.separable_conv(\n                inputs_2d, kernel, pointwise_kernel, 1, 
padding=\"valid\"\n            ).shape,\n            (2, 9, 9, 5) if data_format == \"channels_last\" else (2, 5, 9, 9),\n        )\n        self.assertEqual(\n            knn.separable_conv(\n                inputs_2d, kernel, pointwise_kernel, (1, 2), padding=\"same\"\n            ).shape,\n            (2, 10, 5, 5) if data_format == \"channels_last\" else (2, 5, 10, 5),\n        )\n        self.assertEqual(\n            knn.separable_conv(\n                inputs_2d, kernel, pointwise_kernel, 2, dilation_rate=(2, 1)\n            ).shape,\n            (2, 4, 5, 5) if data_format == \"channels_last\" else (2, 5, 4, 5),\n        )\n\n    def test_conv_transpose(self):\n        data_format = backend.config.image_data_format()\n        if data_format == \"channels_last\":\n            input_shape = (2, 4, 3)\n        else:\n            input_shape = (2, 3, 4)\n        inputs_1d = KerasTensor(input_shape)\n        kernel = KerasTensor([2, 5, 3])\n        self.assertEqual(\n            knn.conv_transpose(inputs_1d, kernel, 2).shape,\n            (2, 8, 5) if data_format == \"channels_last\" else (2, 5, 8),\n        )\n        self.assertEqual(\n            knn.conv_transpose(inputs_1d, kernel, 2, padding=\"same\").shape,\n            (2, 8, 5) if data_format == \"channels_last\" else (2, 5, 8),\n        )\n        self.assertEqual(\n            knn.conv_transpose(\n                inputs_1d, kernel, 5, padding=\"valid\", output_padding=4\n            ).shape,\n            (2, 21, 5) if data_format == \"channels_last\" else (2, 5, 21),\n        )\n\n        if data_format == \"channels_last\":\n            input_shape = (2, 4, 4, 3)\n        else:\n            input_shape = (2, 3, 4, 4)\n        inputs_2d = KerasTensor(input_shape)\n        kernel = KerasTensor([2, 2, 5, 3])\n        self.assertEqual(\n            knn.conv_transpose(inputs_2d, kernel, 2).shape,\n            (2, 8, 8, 5) if data_format == \"channels_last\" else (2, 5, 8, 8),\n        )\n        
self.assertEqual(\n            knn.conv_transpose(inputs_2d, kernel, (2, 2), padding=\"same\").shape,\n            (2, 8, 8, 5) if data_format == \"channels_last\" else (2, 5, 8, 8),\n        )\n        self.assertEqual(\n            knn.conv_transpose(\n                inputs_2d, kernel, (5, 5), padding=\"valid\", output_padding=4\n            ).shape,\n            (\n                (2, 21, 21, 5)\n                if data_format == \"channels_last\"\n                else (2, 5, 21, 21)\n            ),\n        )\n\n    def test_batched_and_unbatched_inputs_multi_hot(self):\n        x = KerasTensor([2, 3, 1])\n        unbatched_input = KerasTensor(\n            [\n                5,\n            ]\n        )\n        self.assertEqual(knn.multi_hot(unbatched_input, 5, -1).shape, (5,))\n        self.assertEqual(knn.multi_hot(x, 5).shape, (2, 1, 5))\n        self.assertEqual(knn.multi_hot(x, 5, 1).shape, (2, 3, 1))\n        self.assertEqual(knn.multi_hot(x, 5, 2).shape, (2, 5, 1))\n\n    def test_one_hot(self):\n        x = KerasTensor([2, 3, 1])\n        self.assertEqual(knn.one_hot(x, 5).shape, (2, 3, 1, 5))\n        self.assertEqual(knn.one_hot(x, 5, 1).shape, (2, 5, 3, 1))\n        self.assertEqual(knn.one_hot(x, 5, 2).shape, (2, 3, 5, 1))\n        self.assertSparse(knn.one_hot(x, 5, sparse=True))\n\n    def test_binary_crossentropy(self):\n        x1 = KerasTensor([2, 3, 1])\n        x2 = KerasTensor([2, 3, 1])\n        self.assertEqual(knn.binary_crossentropy(x1, x2).shape, (2, 3, 1))\n\n    def test_categorical_crossentropy(self):\n        x1 = KerasTensor([2, 3, 4])\n        x2 = KerasTensor([2, 3, 4])\n        self.assertEqual(knn.categorical_crossentropy(x1, x2).shape, (2, 3))\n\n    def test_sparse_categorical_crossentropy(self):\n        x1 = KerasTensor([2, 3], dtype=\"int32\")\n        x2 = KerasTensor([2, 3, 4])\n        self.assertEqual(\n            knn.sparse_categorical_crossentropy(x1, x2).shape, (2, 3)\n        )\n\n    def test_moments(self):\n  
      x = KerasTensor([2, 3, 4])\n        self.assertEqual(knn.moments(x, axes=[0])[0].shape, (3, 4))\n        self.assertEqual(knn.moments(x, axes=[0, 1])[0].shape, (4,))\n        self.assertEqual(\n            knn.moments(x, axes=[0, 1], keepdims=True)[0].shape, (1, 1, 4)\n        )\n\n    def test_batch_normalization(self):\n        x = KerasTensor([10, 3, 4])\n        mean = KerasTensor([4])\n        variance = KerasTensor([4])\n        self.assertEqual(\n            knn.batch_normalization(x, mean, variance, axis=-1).shape,\n            (10, 3, 4),\n        )\n\n        x = KerasTensor([10, 3, 4, 5])\n        self.assertEqual(\n            knn.batch_normalization(x, mean, variance, axis=2).shape,\n            (10, 3, 4, 5),\n        )\n\n        mean = KerasTensor([3])\n        variance = KerasTensor([3])\n        self.assertEqual(\n            knn.batch_normalization(x, mean, variance, axis=1).shape,\n            (10, 3, 4, 5),\n        )\n\n    def test_ctc_loss(self):\n        x = KerasTensor([10, 3, 4])\n        y = KerasTensor([10, 3], dtype=\"int32\")\n        x_lengths = KerasTensor([10], dtype=\"int32\")\n        y_lengths = KerasTensor([10], dtype=\"int32\")\n        self.assertEqual(knn.ctc_loss(x, y, x_lengths, y_lengths).shape, (10,))\n\n    def test_ctc_decode(self):\n        # Test strategy=\"greedy\"\n        inputs = KerasTensor([10, 2, 3])\n        sequence_lengths = KerasTensor([10])\n        decoded, scores = knn.ctc_decode(inputs, sequence_lengths)\n        self.assertEqual(decoded.shape, (1, 10, 2))\n        self.assertEqual(scores.shape, (10, 1))\n\n        # Test strategy=\"beam_search\"\n        inputs = KerasTensor([10, 2, 3])\n        sequence_lengths = KerasTensor([10])\n        decoded, scores = knn.ctc_decode(\n            inputs, sequence_lengths, strategy=\"beam_search\", top_paths=2\n        )\n        self.assertEqual(decoded.shape, (2, 10, 2))\n        self.assertEqual(scores.shape, (10, 2))\n\n    def test_normalize(self):\n  
      x = KerasTensor([1, 2, 3])\n        self.assertEqual(knn.normalize(x).shape, (1, 2, 3))\n\n    def test_psnr(self):\n        x1 = KerasTensor([1, 2, 3])\n        x2 = KerasTensor([5, 6, 7])\n        out = knn.psnr(x1, x2, max_val=224)\n        self.assertEqual(out.shape, ())\n\n    def test_dot_product_attention(self):\n        query = KerasTensor([2, 3, 8, 16])\n        key = KerasTensor([2, 4, 6, 16])\n        value = KerasTensor([2, 4, 6, 16])\n        out = knn.dot_product_attention(query, key, value)\n        self.assertEqual(out.shape, query.shape)\n\n    def test_rms_normalization(self):\n        x = KerasTensor([2, 8, 16])\n        scale = KerasTensor([2, 8, 16])\n        self.assertEqual(knn.rms_normalization(x, scale).shape, x.shape)\n\n    def test_layer_normalization(self):\n        x = KerasTensor([2, 8, 16])\n        gamma = KerasTensor([2, 16])\n        beta = KerasTensor([2, 16])\n        self.assertEqual(knn.layer_normalization(x, gamma, beta).shape, x.shape)\n\n    def test_polar(self):\n        abs_ = KerasTensor([1, 2])\n        angle = KerasTensor([3, 4])\n        out = knn.polar(abs_, angle)\n        self.assertEqual(out.shape, abs_.shape)\n\n\nclass NNOpsCorrectnessTest(testing.TestCase):\n    @pytest.mark.skipif(not testing.jax_uses_tpu(), reason=\"JAX on TPU only\")\n    def test_dot_product_attention_inside_scan(self):\n        import jax\n        import jax.numpy as jnp\n\n        def attention_scan_body(carry, x):\n            query, key, value = x\n            # dot_product_attention expects 4D inputs (B, H, S, D)\n            query = jnp.expand_dims(query, axis=0)\n            key = jnp.expand_dims(key, axis=0)\n            value = jnp.expand_dims(value, axis=0)\n\n            # Use a mask to trigger the issue\n            mask = jnp.ones((1, 4, 8), dtype=\"bool\")\n            out = knn.dot_product_attention(query, key, value, mask=mask)\n\n            out = jnp.squeeze(out, axis=0)\n            return carry, out\n\n        
query = jnp.ones((2, 1, 4, 8))\n        key = jnp.ones((2, 1, 4, 8))\n        value = jnp.ones((2, 1, 4, 8))\n\n        # Scan over the first dimension\n        _, out = jax.lax.scan(attention_scan_body, None, (query, key, value))\n        self.assertEqual(out.shape, (2, 1, 4, 8))\n\n    def test_relu(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(knn.relu(x), [0, 0, 1, 2, 3])\n\n    def test_relu6(self):\n        x = np.array([-1, 0, 1, 2, 3, 4, 5, 6, 7], dtype=np.float32)\n        self.assertAllClose(knn.relu6(x), [0, 0, 1, 2, 3, 4, 5, 6, 6])\n\n    def test_sigmoid(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.sigmoid(x), [0.26894143, 0.5, 0.7310586, 0.880797, 0.95257413]\n        )\n\n    def test_sparse_sigmoid(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(knn.sparse_sigmoid(x), [0.0, 0.5, 1.0, 1.0, 1.0])\n\n    def test_softplus(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.softplus(x),\n            [0.31326166, 0.6931472, 1.3132616, 2.126928, 3.0485873],\n        )\n\n    def test_softsign(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(knn.softsign(x), [-0.5, 0, 0.5, 0.6666667, 0.75])\n\n    def test_silu(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.silu(x),\n            [-0.26894143, 0, 0.7310586, 1.7615942, 2.8577223],\n        )\n\n    def test_log_sigmoid(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.log_sigmoid(x),\n            [-1.3132616, -0.6931472, -0.31326166, -0.126928, -0.04858732],\n        )\n\n    def test_leaky_relu(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.leaky_relu(x),\n  
          [-0.2, 0, 1, 2, 3],\n        )\n\n    def test_hard_sigmoid(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.hard_sigmoid(x),\n            [0.33333334, 0.5, 0.6666667, 0.8333334, 1.0],\n        )\n\n    def test_hard_silu(self):\n        x = np.array([-3, -2, -1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.hard_silu(x),\n            [-0.0, -0.333333, -0.333333, 0.0, 0.6666667, 1.6666667, 3.0],\n        )\n\n    def test_elu(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.elu(x),\n            [-0.63212055, 0, 1, 2, 3],\n        )\n        self.assertAllClose(\n            knn.elu(x, alpha=0.5),\n            [-0.31606027, 0, 1, 2, 3],\n        )\n\n    def test_selu(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.selu(x),\n            [-1.1113307, 0.0, 1.050701, 2.101402, 3.152103],\n        )\n\n    def test_gelu(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.gelu(x),\n            [-0.15880796, 0.0, 0.841192, 1.9545977, 2.9963627],\n        )\n\n    def test_celu(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.celu(x),\n            [-0.63212055, 0.0, 1.0, 2.0, 3.0],\n        )\n\n    def test_glu(self):\n        x = np.array([-1, 0, 1, 2, 3, 4], dtype=np.float32)\n        self.assertAllClose(\n            knn.glu(x),\n            [-0.8807971, 0.0, 0.98201376],\n        )\n\n    def test_tanh_shrink(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.tanh_shrink(x),\n            [-0.238406, 0.0, 0.238406, 1.035972, 2.004945],\n        )\n\n    def test_hard_tanh(self):\n        x = np.array([-1, 0, 1, 2, 3], dtype=np.float32)\n        
self.assertAllClose(\n            knn.hard_tanh(x),\n            [-1.0, 0.0, 1.0, 1.0, 1.0],\n        )\n\n    def test_hard_shrink(self):\n        x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.hard_shrink(x),\n            [0.0, 0.0, 1.0, 2.0, 3.0],\n        )\n\n    def test_threshold(self):\n        x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.threshold(x, 0, 0),\n            [0.0, 0.0, 1.0, 2.0, 3.0],\n        )\n\n    def test_squareplus(self):\n        x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.squareplus(x),\n            [0.780776, 1.0, 1.618034, 2.414214, 3.302776],\n        )\n\n    def test_soft_shrink(self):\n        x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.soft_shrink(x),\n            [0.0, 0.0, 0.5, 1.5, 2.5],\n        )\n\n    def test_sparse_plus(self):\n        x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.sparse_plus(x),\n            [0.0625, 0.25, 1.0, 2.0, 3.0],\n        )\n\n    def test_softmax(self):\n        x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)\n        self.assertAllClose(\n            knn.softmax(x, axis=None),  # Reduce on all axes.\n            [[0.045015, 0.122364, 0.33262], [0.045015, 0.122364, 0.33262]],\n        )\n        self.assertAllClose(\n            knn.softmax(x, axis=0),\n            [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]],\n        )\n        self.assertAllClose(\n            knn.softmax(x, axis=-1),\n            [\n                [0.09003057, 0.24472848, 0.66524094],\n                [0.09003057, 0.24472848, 0.66524094],\n            ],\n        )\n        self.assertAllClose(\n            knn.softmax(x),  # Default axis should be -1.\n            [\n                [0.09003057, 0.24472848, 0.66524094],\n                [0.09003057, 0.24472848, 
0.66524094],\n            ],\n        )\n\n    def test_softmax_correctness_with_axis_tuple(self):\n        input = np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])\n        combination = combinations(range(3), 2)\n        for axis in list(combination):\n            result = keras.ops.nn.softmax(input, axis=axis)\n            normalized_sum_by_axis = np.sum(\n                ops.convert_to_numpy(result), axis=axis\n            )\n            self.assertAllClose(normalized_sum_by_axis, 1.0)\n\n    def test_log_softmax(self):\n        x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)\n        self.assertAllClose(\n            knn.log_softmax(x, axis=None),  # Reduce on all axes.\n            [\n                [-3.100753, -2.100753, -1.100753],\n                [-3.100753, -2.100753, -1.100753],\n            ],\n        )\n        self.assertAllClose(\n            knn.log_softmax(x, axis=0),\n            [\n                [-0.693147, -0.693147, -0.693147],\n                [-0.693147, -0.693147, -0.693147],\n            ],\n        )\n        self.assertAllClose(\n            knn.log_softmax(x, axis=-1),\n            [\n                [-2.407606, -1.407606, -0.407606],\n                [-2.407606, -1.407606, -0.407606],\n            ],\n        )\n        self.assertAllClose(\n            knn.log_softmax(x),  # Default axis should be -1.\n            [\n                [-2.407606, -1.407606, -0.407606],\n                [-2.407606, -1.407606, -0.407606],\n            ],\n        )\n\n    def test_log_softmax_correctness_with_axis_tuple(self):\n        input = np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])\n        combination = combinations(range(3), 2)\n        for axis in list(combination):\n            result = keras.ops.nn.log_softmax(input, axis=axis)\n            normalized_sum_by_axis = np.sum(\n                np.exp(ops.convert_to_numpy(result)), axis=axis\n            )\n            
self.assertAllClose(normalized_sum_by_axis, 1.0)\n\n    def test_polar_correctness(self):\n        abs_ = np.array([1, 2], dtype=\"float32\")\n        angle = np.array([2, 3], dtype=\"float32\")\n        out = knn.polar(abs_, angle)\n        self.assertAllClose(\n            out, [-0.41614684 + 0.9092974j, -1.979985 + 0.28224j], atol=1e-3\n        )\n\n    def test_sparsemax(self):\n        x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32)\n        self.assertAllClose(\n            knn.sparsemax(x),\n            [0.0, 0.0, 0.0, 0.0, 1.0],\n        )\n\n    def test_max_pool(self):\n        data_format = backend.config.image_data_format()\n        # Test 1D max pooling.\n        if data_format == \"channels_last\":\n            input_shape = (2, 20, 3)\n        else:\n            input_shape = (2, 3, 20)\n        x = np.arange(120, dtype=float).reshape(input_shape)\n        self.assertAllClose(\n            knn.max_pool(x, 2, 1, padding=\"valid\"),\n            np_maxpool1d(x, 2, 1, padding=\"valid\", data_format=data_format),\n        )\n        self.assertAllClose(\n            knn.max_pool(x, 2, 2, padding=\"same\"),\n            np_maxpool1d(x, 2, 2, padding=\"same\", data_format=data_format),\n        )\n\n        # Test 2D max pooling.\n        if data_format == \"channels_last\":\n            input_shape = (2, 10, 9, 3)\n        else:\n            input_shape = (2, 3, 10, 9)\n        x = np.arange(540, dtype=float).reshape(input_shape)\n        self.assertAllClose(\n            knn.max_pool(x, 2, 1, padding=\"valid\"),\n            np_maxpool2d(x, 2, 1, padding=\"valid\", data_format=data_format),\n        )\n        self.assertAllClose(\n            knn.max_pool(x, 2, (2, 1), padding=\"same\"),\n            np_maxpool2d(x, 2, (2, 1), padding=\"same\", data_format=data_format),\n        )\n\n    def test_average_pool_valid_padding(self):\n        data_format = backend.config.image_data_format()\n        # Test 1D average pooling.\n        if data_format == 
\"channels_last\":\n            input_shape = (2, 20, 3)\n        else:\n            input_shape = (2, 3, 20)\n        x = np.arange(120, dtype=float).reshape(input_shape)\n        self.assertAllClose(\n            knn.average_pool(x, 2, 1, padding=\"valid\"),\n            np_avgpool1d(x, 2, 1, padding=\"valid\", data_format=data_format),\n        )\n\n        # Test 2D average pooling.\n        if data_format == \"channels_last\":\n            input_shape = (2, 10, 9, 3)\n        else:\n            input_shape = (2, 3, 10, 9)\n        x = np.arange(540, dtype=float).reshape(input_shape)\n        self.assertAllClose(\n            knn.average_pool(x, 2, 1, padding=\"valid\"),\n            np_avgpool2d(x, 2, 1, padding=\"valid\", data_format=data_format),\n        )\n\n    def test_average_pool_same_padding(self):\n        data_format = backend.config.image_data_format()\n        # Test 1D average pooling.\n        if data_format == \"channels_last\":\n            input_shape = (2, 20, 3)\n        else:\n            input_shape = (2, 3, 20)\n        x = np.arange(120, dtype=float).reshape(input_shape)\n\n        self.assertAllClose(\n            knn.average_pool(x, 2, 2, padding=\"same\"),\n            np_avgpool1d(x, 2, 2, padding=\"same\", data_format=data_format),\n        )\n\n        # Test 2D average pooling.\n        if data_format == \"channels_last\":\n            input_shape = (2, 10, 9, 3)\n        else:\n            input_shape = (2, 3, 10, 9)\n        x = np.arange(540, dtype=float).reshape(input_shape)\n        self.assertAllClose(\n            knn.average_pool(x, 2, (2, 1), padding=\"same\"),\n            np_avgpool2d(x, 2, (2, 1), padding=\"same\", data_format=data_format),\n        )\n        # Test 2D average pooling with different pool size.\n        if data_format == \"channels_last\":\n            input_shape = (2, 10, 9, 3)\n        else:\n            input_shape = (2, 3, 10, 9)\n        x = np.arange(540, dtype=float).reshape(input_shape)\n     
   self.assertAllClose(\n            knn.average_pool(x, (2, 3), (3, 3), padding=\"same\"),\n            np_avgpool2d(\n                x, (2, 3), (3, 3), padding=\"same\", data_format=data_format\n            ),\n        )\n\n    @parameterized.product(\n        strides=(1, 2, 3),\n        padding=(\"valid\", \"same\"),\n        dilation_rate=(1, 2),\n    )\n    def test_conv_1d(self, strides, padding, dilation_rate):\n        if strides > 1 and dilation_rate > 1:\n            pytest.skip(\"Unsupported configuration\")\n\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 20, 3)\n        else:\n            input_shape = (2, 3, 20)\n        inputs_1d = np.arange(120, dtype=float).reshape(input_shape)\n        kernel = np.arange(24, dtype=float).reshape([4, 3, 2])\n\n        outputs = knn.conv(\n            inputs_1d,\n            kernel,\n            strides=strides,\n            padding=padding,\n            dilation_rate=dilation_rate,\n        )\n        expected = np_conv1d(\n            inputs_1d,\n            kernel,\n            bias_weights=np.zeros((2,)),\n            strides=strides,\n            padding=padding.lower(),\n            data_format=backend.config.image_data_format(),\n            dilation_rate=dilation_rate,\n            groups=1,\n        )\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.product(strides=(1, 2, (1, 2)), padding=(\"valid\", \"same\"))\n    def test_conv_2d(self, strides, padding):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 10, 10, 3)\n        else:\n            input_shape = (2, 3, 10, 10)\n        inputs_2d = np.arange(600, dtype=float).reshape(input_shape)\n        kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])\n\n        outputs = knn.conv(inputs_2d, kernel, strides, padding=padding)\n        expected = np_conv2d(\n            inputs_2d,\n            kernel,\n            
bias_weights=np.zeros((2,)),\n            strides=strides,\n            padding=padding,\n            data_format=backend.config.image_data_format(),\n            dilation_rate=1,\n            groups=1,\n        )\n        self.assertAllClose(outputs, expected, tpu_atol=1e-2, tpu_rtol=1e-2)\n\n    @parameterized.product(strides=(1, 2), dilation_rate=(1, (2, 1)))\n    def test_conv_2d_group_2(self, strides, dilation_rate):\n        if (\n            backend.backend() == \"tensorflow\"\n            and strides == 2\n            and dilation_rate == (2, 1)\n        ):\n            # This case is not supported by the TF backend.\n            return\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 10, 10, 4)\n        else:\n            input_shape = (2, 4, 10, 10)\n        inputs_2d = np.ones(input_shape)\n        kernel = np.ones([2, 2, 2, 6])\n        outputs = knn.conv(\n            inputs_2d,\n            kernel,\n            strides,\n            padding=\"same\",\n            dilation_rate=dilation_rate,\n        )\n        expected = np_conv2d(\n            inputs_2d,\n            kernel,\n            bias_weights=np.zeros((6,)),\n            strides=strides,\n            padding=\"same\",\n            data_format=backend.config.image_data_format(),\n            dilation_rate=dilation_rate,\n            groups=1,\n        )\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.product(\n        strides=(1, (1, 1, 1), 2),\n        padding=(\"valid\", \"same\"),\n        data_format=(\"channels_first\", \"channels_last\"),\n    )\n    def test_conv_3d(self, strides, padding, data_format):\n        if data_format == \"channels_last\":\n            input_shape = (2, 8, 8, 8, 3)\n        else:\n            input_shape = (2, 3, 8, 8, 8)\n        inputs_3d = np.arange(3072, dtype=float).reshape(input_shape)\n        kernel = np.arange(162, dtype=float).reshape([3, 3, 3, 3, 2])\n\n        outputs = 
knn.conv(\n            inputs_3d, kernel, strides, padding=padding, data_format=data_format\n        )\n        expected = np_conv3d(\n            inputs_3d,\n            kernel,\n            bias_weights=np.zeros((2,)),\n            strides=strides,\n            padding=padding,\n            data_format=data_format,\n            dilation_rate=1,\n            groups=1,\n        )\n        self.assertAllClose(\n            outputs,\n            expected,\n            rtol=1e-5,\n            atol=1e-5,\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n        # Test for tracing error on tensorflow backend.\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            @tf.function\n            def conv(x):\n                return knn.conv(\n                    x, kernel, strides, padding=padding, data_format=data_format\n                )\n\n            outputs = conv(inputs_3d)\n            self.assertAllClose(\n                outputs,\n                expected,\n                rtol=1e-5,\n                atol=1e-5,\n                tpu_atol=1e-2,\n                tpu_rtol=1e-2,\n            )\n\n    @parameterized.product(\n        strides=(1, (1, 1), (2, 2)),\n        padding=(\"valid\", \"same\"),\n        dilation_rate=(1, (2, 2)),\n    )\n    def test_depthwise_conv_2d(self, strides, padding, dilation_rate):\n        if (\n            backend.backend() == \"tensorflow\"\n            and strides == (2, 2)\n            and dilation_rate == (2, 2)\n        ):\n            # This case is not supported by the TF backend.\n            return\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 10, 10, 3)\n        else:\n            input_shape = (2, 3, 10, 10)\n        inputs_2d = np.arange(600, dtype=float).reshape(input_shape)\n        kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])\n\n        outputs = 
knn.depthwise_conv(\n            inputs_2d,\n            kernel,\n            strides,\n            padding=padding,\n            dilation_rate=dilation_rate,\n        )\n        expected = np_depthwise_conv2d(\n            inputs_2d,\n            kernel,\n            bias_weights=np.zeros((6,)),\n            strides=strides,\n            padding=padding,\n            data_format=backend.config.image_data_format(),\n            dilation_rate=dilation_rate,\n        )\n        self.assertAllClose(outputs, expected, tpu_atol=1e-2, tpu_rtol=1e-2)\n\n    @parameterized.product(\n        strides=(1, 2),\n        padding=(\"valid\", \"same\"),\n        dilation_rate=(1, (2, 2)),\n    )\n    def test_separable_conv_2d(self, strides, padding, dilation_rate):\n        if (\n            backend.backend() == \"tensorflow\"\n            and strides == 2\n            and dilation_rate == (2, 2)\n        ):\n            # This case is not supported by the TF backend.\n            return\n        # Test 2D conv.\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 10, 10, 3)\n        else:\n            input_shape = (2, 3, 10, 10)\n        inputs_2d = np.arange(600, dtype=float).reshape(input_shape)\n        depthwise_kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])\n        pointwise_kernel = np.arange(72, dtype=float).reshape([1, 1, 6, 12])\n\n        outputs = knn.separable_conv(\n            inputs_2d,\n            depthwise_kernel,\n            pointwise_kernel,\n            strides,\n            padding=padding,\n            dilation_rate=dilation_rate,\n        )\n        # Depthwise followed by pointwise conv\n        expected_depthwise = np_depthwise_conv2d(\n            inputs_2d,\n            depthwise_kernel,\n            np.zeros(6),\n            strides=strides,\n            padding=padding,\n            data_format=backend.config.image_data_format(),\n            dilation_rate=dilation_rate,\n        )\n       
 expected = np_conv2d(\n            expected_depthwise,\n            pointwise_kernel,\n            np.zeros(6 * 12),\n            strides=1,\n            padding=padding,\n            data_format=backend.config.image_data_format(),\n            dilation_rate=dilation_rate,\n            groups=1,\n        )\n        self.assertAllClose(outputs, expected, tpu_atol=1e-2, tpu_rtol=1e-2)\n\n    @parameterized.product(padding=(\"valid\", \"same\"))\n    def test_conv_transpose_1d(self, padding):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 4, 3)\n        else:\n            input_shape = (2, 3, 4)\n        inputs_1d = np.arange(24, dtype=float).reshape(input_shape)\n        kernel = np.arange(30, dtype=float).reshape([2, 5, 3])\n        outputs = knn.conv_transpose(inputs_1d, kernel, 2, padding=padding)\n        expected = np_conv1d_transpose(\n            inputs_1d,\n            kernel,\n            bias_weights=np.zeros(5),\n            strides=2,\n            output_padding=None,\n            padding=padding,\n            data_format=backend.config.image_data_format(),\n            dilation_rate=1,\n        )\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.product(strides=(2, (2, 2)), padding=(\"valid\", \"same\"))\n    def test_conv_transpose_2d(self, strides, padding):\n        if backend.config.image_data_format() == \"channels_last\":\n            input_shape = (2, 4, 4, 3)\n        else:\n            input_shape = (2, 3, 4, 4)\n        inputs_2d = np.arange(96, dtype=float).reshape(input_shape)\n        kernel = np.arange(60, dtype=float).reshape([2, 2, 5, 3])\n\n        outputs = knn.conv_transpose(\n            inputs_2d, kernel, strides, padding=padding\n        )\n        expected = np_conv2d_transpose(\n            inputs_2d,\n            kernel,\n            bias_weights=np.zeros(5),\n            strides=strides,\n            output_padding=None,\n            padding=padding,\n     
       data_format=backend.config.image_data_format(),\n            dilation_rate=1,\n        )\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.named_parameters(\n        [\n            {\"testcase_name\": \"dense\", \"sparse\": False},\n            {\"testcase_name\": \"sparse\", \"sparse\": True},\n        ]\n    )\n    def test_one_hot(self, sparse):\n        if sparse and not backend.SUPPORTS_SPARSE_TENSORS:\n            pytest.skip(\"Backend does not support sparse tensors\")\n        # Test 1D one-hot.\n        indices_1d = np.array([0, 1, 2, 3])\n        output_1d = knn.one_hot(indices_1d, 4, sparse=sparse)\n        self.assertAllClose(output_1d, np.eye(4)[indices_1d])\n        self.assertSparse(output_1d, sparse)\n        output_1d = knn.one_hot(indices_1d, 4, axis=0, sparse=sparse)\n        self.assertAllClose(output_1d, np.eye(4)[indices_1d])\n        self.assertSparse(output_1d, sparse)\n\n        # Test 1D list one-hot.\n        indices_1d = [0, 1, 2, 3]\n        output_1d = knn.one_hot(indices_1d, 4, sparse=sparse)\n        self.assertAllClose(output_1d, np.eye(4)[indices_1d])\n        self.assertSparse(output_1d, sparse)\n        output_1d = knn.one_hot(indices_1d, 4, axis=0, sparse=sparse)\n        self.assertAllClose(output_1d, np.eye(4)[indices_1d])\n        self.assertSparse(output_1d, sparse)\n\n        # Test 2D one-hot.\n        indices_2d = np.array([[0, 1], [2, 3]])\n        output_2d = knn.one_hot(indices_2d, 4, sparse=sparse)\n        self.assertAllClose(output_2d, np.eye(4)[indices_2d])\n        self.assertSparse(output_2d, sparse)\n        output_2d = knn.one_hot(indices_2d, 4, axis=2, sparse=sparse)\n        self.assertAllClose(output_2d, np.eye(4)[indices_2d])\n        self.assertSparse(output_2d, sparse)\n        output_2d = knn.one_hot(indices_2d, 4, axis=1, sparse=sparse)\n        self.assertAllClose(\n            output_2d, np.transpose(np.eye(4)[indices_2d], (0, 2, 1))\n        )\n        
self.assertSparse(output_2d, sparse)\n\n        # Test 1D one-hot with 1 extra dimension.\n        indices_1d = np.array([[0], [1], [2], [3]])\n        output_1d = knn.one_hot(indices_1d, 4, sparse=sparse)\n        self.assertAllClose(output_1d, np.eye(4)[indices_1d])\n        self.assertSparse(output_1d, sparse)\n        output_1d = knn.one_hot(indices_1d, 4, axis=0, sparse=sparse)\n        self.assertAllClose(output_1d, np.eye(4)[indices_1d].swapaxes(1, 2))\n        self.assertSparse(output_1d, sparse)\n\n        # Test 1D one-hot with negative inputs\n        indices_1d = np.array([0, -1, -1, 3])\n        output_1d = knn.one_hot(indices_1d, 4, sparse=sparse)\n        self.assertAllClose(\n            output_1d,\n            np.array(\n                [\n                    [1, 0, 0, 0],\n                    [0, 0, 0, 0],\n                    [0, 0, 0, 0],\n                    [0, 0, 0, 1],\n                ],\n                dtype=np.float32,\n            ),\n        )\n        self.assertSparse(output_1d, sparse)\n\n    def test_binary_crossentropy(self):\n        # Test with from_logits=False\n        target = np.array([[0.1], [0.9], [0.2], [1.0]])\n        output = np.array([[0.1], [0.2], [0.3], [0.4]])\n        result = knn.binary_crossentropy(target, output, from_logits=False)\n        self.assertAllClose(\n            result,\n            np.array([[0.32508277], [1.47080801], [0.52613434], [0.91629048]]),\n        )\n\n        # Test with from_logits=True\n        target = np.array([[0.1], [0.9], [0.2], [1.0]])\n        output = np.array([[0.1], [0.2], [0.3], [0.4]])\n        result = knn.binary_crossentropy(target, output, from_logits=True)\n        self.assertAllClose(\n            result,\n            np.array([[0.73439666], [0.61813887], [0.79435524], [0.51301525]]),\n        )\n\n        # Test with output clipping\n        target = np.array([[0.1], [0.9], [0.2], [1.0]])\n        output = np.array([[0.99], [-0.2], [0.9], [-0.4]])\n        result = 
knn.binary_crossentropy(target, output, from_logits=True)\n        self.assertAllClose(\n            result,\n            np.array([[1.206961], [0.778139], [1.061154], [0.913015]]),\n        )\n\n    def test_categorical_crossentropy(self):\n        target = np.array(\n            [\n                [0.33008796, 0.0391289, 0.9503603],\n                [0.80376694, 0.92363342, 0.19147756],\n            ]\n        )\n        output = np.array(\n            [\n                [0.23446431, 0.35822914, 0.06683268],\n                [0.3413979, 0.05420256, 0.81619654],\n            ]\n        )\n\n        # Test from_logits=False\n        result = knn.categorical_crossentropy(\n            target, output, from_logits=False, axis=-1\n        )\n        self.assertAllClose(result, np.array([2.54095299, 3.96374412]))\n\n        # Test axis\n        result = knn.categorical_crossentropy(\n            target, output, from_logits=False, axis=0\n        )\n        self.assertAllClose(\n            result, np.array([0.71683073, 1.87988172, 2.46810762])\n        )\n\n        # Test from_logits=True\n        result = knn.categorical_crossentropy(\n            target, output, from_logits=True, axis=-1\n        )\n        self.assertAllClose(result, np.array([1.59419954, 2.49880593]))\n\n        # Test with output clipping\n        output = np.array(\n            [\n                [1.23446431, -0.35822914, 1.06683268],\n                [0.3413979, -0.05420256, 0.81619654],\n            ]\n        )\n        result = knn.categorical_crossentropy(\n            target, output, from_logits=True, axis=-1\n        )\n        self.assertAllClose(result, np.array([1.16825923, 2.55436813]))\n\n    def test_sparse_categorical_crossentropy(self):\n        target = np.array([0, 1, 2])\n        output = np.array(\n            [[0.9, 0.05, 0.05], [0.05, 0.89, 0.06], [0.05, 0.01, 0.94]]\n        )\n        result = knn.sparse_categorical_crossentropy(target, output)\n        
self.assertAllClose(result, [0.105361, 0.116534, 0.061875])\n\n        output = np.array([[8.0, 1.0, 1.0], [0.0, 9.0, 1.0], [2.0, 3.0, 5.0]])\n        result = knn.sparse_categorical_crossentropy(\n            target, output, from_logits=True\n        )\n        self.assertAllClose(result, [0.001822, 0.000459, 0.169846])\n\n    @parameterized.named_parameters(\n        [\n            {\"testcase_name\": \"dense\", \"sparse\": False},\n            {\"testcase_name\": \"sparse\", \"sparse\": True},\n        ]\n    )\n    def test_multi_hot(self, sparse):\n        if sparse and not backend.SUPPORTS_SPARSE_TENSORS:\n            pytest.skip(\"Backend does not support sparse tensors\")\n\n        # Test 1D multi-hot.\n        indices_1d = np.array([0, 1, 2, 3])\n        expected_output_1d = np.array([1, 1, 1, 1])\n        output_1d = knn.multi_hot(indices_1d, 4, sparse=sparse)\n        self.assertAllClose(output_1d, expected_output_1d)\n        self.assertSparse(output_1d, sparse)\n\n        # Test 2D multi-hot.\n        indices_2d = np.array([[0, 1], [2, 3]])\n        expected_output_2d = np.array([[1, 1, 0, 0], [0, 0, 1, 1]])\n        output_2d = knn.multi_hot(indices_2d, 4, sparse=sparse)\n        self.assertAllClose(output_2d, expected_output_2d)\n        self.assertSparse(output_2d, sparse)\n\n        # Test 1D multi-hot with negative inputs\n        indices_1d = np.array([0, -1, -1, 3])\n        expected_output_1d = np.array([1, 0, 0, 1])\n        output_1d = knn.multi_hot(indices_1d, 4, sparse=sparse)\n        self.assertAllClose(output_1d, expected_output_1d)\n        self.assertSparse(output_1d, sparse)\n\n    def test_moments(self):\n        # Test 1D moments\n        x = np.array([0, 1, 2, 3, 4, 100, -200]).astype(np.float32)\n        mean, variance = knn.moments(x, axes=[0])\n        self.assertAllClose(mean, np.mean(x), atol=1e-5, rtol=1e-5)\n        self.assertAllClose(variance, np.var(x), atol=1e-5, rtol=1e-5)\n\n        # Test batch statistics for 4D 
moments (batch, height, width, channels)\n        x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32)\n        mean, variance = knn.moments(x, axes=[0])\n        self.assertAllClose(mean, np.mean(x, axis=0), atol=1e-5, rtol=1e-5)\n        self.assertAllClose(variance, np.var(x, axis=0), atol=1e-5, rtol=1e-5)\n\n        # Test global statistics for 4D moments (batch, height, width, channels)\n        x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32)\n        mean, variance = knn.moments(x, axes=[0, 1, 2])\n        expected_mean = np.mean(x, axis=(0, 1, 2))\n        expected_variance = np.var(x, axis=(0, 1, 2))\n        self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5)\n        self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5)\n\n        # Test keepdims\n        x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32)\n        mean, variance = knn.moments(x, axes=[0, 1, 2], keepdims=True)\n        expected_mean = np.mean(x, axis=(0, 1, 2), keepdims=True)\n        expected_variance = np.var(x, axis=(0, 1, 2), keepdims=True)\n        self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5)\n        self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5)\n\n        # Test float16, which causes overflow\n        x = np.array(\n            [-741.0, 353.2, 1099.0, -1807.0, 502.8, -83.4, 333.5, -130.9],\n            dtype=np.float16,\n        )\n        mean, variance = knn.moments(x, axes=[0])\n        expected_mean = np.mean(x.astype(np.float32)).astype(np.float16)\n        # The output variance is clipped to the max value of np.float16\n        # because it overflows.\n        expected_variance = np.finfo(np.float16).max\n        self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5)\n        self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"synchronized=True only 
implemented for TF backend\",\n    )\n    def test_moments_sync(self):\n        # Test batch statistics for 4D moments (batch, height, width, channels)\n        x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32)\n        mean, variance = knn.moments(x, axes=[0], synchronized=True)\n        self.assertAllClose(mean, np.mean(x, axis=0), atol=1e-5, rtol=1e-5)\n        self.assertAllClose(variance, np.var(x, axis=0), atol=1e-5, rtol=1e-5)\n\n        # Test global statistics for 4D moments (batch, height, width, channels)\n        x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32)\n        mean, variance = knn.moments(x, axes=[0, 1, 2], synchronized=True)\n        expected_mean = np.mean(x, axis=(0, 1, 2))\n        expected_variance = np.var(x, axis=(0, 1, 2))\n        self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5)\n        self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5)\n\n        # Test keepdims\n        x = np.random.uniform(size=(2, 28, 28, 3)).astype(np.float32)\n        mean, variance = knn.moments(\n            x, axes=[0, 1, 2], keepdims=True, synchronized=True\n        )\n        expected_mean = np.mean(x, axis=(0, 1, 2), keepdims=True)\n        expected_variance = np.var(x, axis=(0, 1, 2), keepdims=True)\n        self.assertAllClose(mean, expected_mean, atol=1e-5, rtol=1e-5)\n        self.assertAllClose(variance, expected_variance, atol=1e-5, rtol=1e-5)\n\n    @parameterized.product(dtype=[\"float16\", \"float32\"])\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"synchronized=True only implemented for TF backend\",\n    )\n    def test_moments_sync_with_distribution_strategy(self, dtype):\n        from tensorflow.python.eager import context\n\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        context._reset_context()\n\n        # Configure 2 logical CPUs for testing.\n        logical_cpus = tf.config.list_logical_devices(\"CPU\")\n        if 
len(logical_cpus) == 1:\n            context._reset_context()\n            tf.config.set_logical_device_configuration(\n                tf.config.list_physical_devices(\"CPU\")[0],\n                [\n                    tf.config.LogicalDeviceConfiguration(),\n                    tf.config.LogicalDeviceConfiguration(),\n                ],\n            )\n\n        @tf.function()\n        def test_on_moments(inputs):\n            return knn.moments(\n                inputs, axes=-1, keepdims=True, synchronized=True\n            )\n\n        # Test output of moments.\n        inputs = tf.constant([5.0, 9.0, 1.0, 3.0], dtype=dtype)\n        strategy = tf.distribute.MirroredStrategy([\"CPU:0\", \"CPU:1\"])\n        with strategy.scope():\n            mean, variance = strategy.run(test_on_moments, args=(inputs,))\n            self.assertEqual(mean.values[0], 4.5)\n            self.assertEqual(variance.values[0], 8.75)\n\n        context._reset_context()\n\n    def test_batch_normalization(self):\n        x = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])\n        mean = np.array([0.2, 0.3, 0.4])\n        variance = np.array([4.0, 16.0, 64.0])\n        output = knn.batch_normalization(\n            x,\n            mean,\n            variance,\n            axis=-1,\n            offset=np.array([5.0, 10.0, 15.0]),\n            scale=np.array([10.0, 20.0, 30.0]),\n            epsilon=1e-7,\n        )\n        expected_output = np.array([[4.5, 9.5, 14.625], [6.0, 11.0, 15.75]])\n        self.assertAllClose(output, expected_output)\n\n        output = knn.batch_normalization(\n            x,\n            mean,\n            variance,\n            axis=1,\n            epsilon=1e-7,\n        )\n        expected_output = np.array(\n            [[-0.05, -0.025, -0.0125], [0.1, 0.05, 0.025]]\n        )\n        self.assertAllClose(output, expected_output)\n\n        
output = knn.batch_normalization(\n            np.random.uniform(size=[2, 3, 3, 5]),\n            np.random.uniform(size=[5]),\n            np.random.uniform(size=[5]),\n            axis=3,\n            offset=np.random.uniform(size=[5]),\n            scale=np.random.uniform(size=[5]),\n        )\n        self.assertEqual(tuple(output.shape), (2, 3, 3, 5))\n\n    def test_ctc_loss(self):\n        labels = np.array([[1, 2, 1], [1, 2, 2]])\n        outputs = np.array(\n            [\n                [[0.4, 0.8, 0.4], [0.2, 0.8, 0.3], [0.9, 0.4, 0.5]],\n                [[0.4, 0.8, 0.4], [0.2, 0.3, 0.3], [0.4, 0.3, 0.2]],\n            ]\n        )\n\n        label_length = np.array([3, 2])\n        output_length = np.array([3, 2])\n\n        result = knn.ctc_loss(labels, outputs, label_length, output_length)\n        self.assertAllClose(\n            result,\n            np.array([3.4411672, 1.91680186]),\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n    def test_ctc_decode(self):\n        inputs = np.array(\n            [\n                [\n                    [0.1, 0.4, 0.2, 0.4],\n                    [0.3, -0.3, 0.4, 0.2],\n                    [0.3, 0.2, 0.4, 0.3],\n                ],\n                [\n                    [0.7, 0.4, 0.3, 0.2],\n                    [0.3, 0.3, 0.4, 0.1],\n                    [0.6, -0.1, 0.1, 0.5],\n                ],\n                [\n                    [0.1, 0.4, 0.2, 0.7],\n                    [0.3, 0.3, -0.2, 0.7],\n                    [0.3, 0.2, 0.4, 0.1],\n                ],\n            ]\n        )\n        labels = np.array([[1, 2, -1], [2, -1, -1], [3, -1, -1]])\n        score_labels = np.array([[-1.2], [-1.7], [-0.7]])\n        repeated_labels = np.array([[1, 2, 2], [2, -1, -1], [3, -1, -1]])\n\n        # Test strategy=\"greedy\" and merge_repeated=True\n        (decoded,), scores = knn.ctc_decode(\n            inputs,\n            sequence_lengths=[3, 3, 1],\n            strategy=\"greedy\",\n  
          mask_index=0,\n        )\n        self.assertAllClose(decoded, labels)\n        self.assertAllClose(scores, score_labels)\n\n        # Test strategy=\"greedy\" and merge_repeated=False\n        (decoded,), scores = knn.ctc_decode(\n            inputs,\n            sequence_lengths=[3, 3, 1],\n            strategy=\"greedy\",\n            merge_repeated=False,\n            mask_index=0,\n        )\n        self.assertAllClose(decoded, repeated_labels)\n        self.assertAllClose(scores, score_labels)\n\n        if backend.backend() == \"torch\":\n            self.skipTest(\"torch doesn't support 'beam_search' strategy\")\n\n        labels = np.array(\n            [\n                [[1, 2, -1], [2, -1, -1], [3, -1, -1]],\n                [[2, -1, -1], [3, -1, -1], [1, -1, -1]],\n            ]\n        )\n        score_labels = np.array(\n            [\n                [-2.426537, -2.435596],\n                [-2.127681, -2.182338],\n                [-1.063386, -1.363386],\n            ]\n        )\n        beam_width = 4\n        top_paths = 2\n\n        # Test strategy=\"beam_search\"\n        decoded, scores = knn.ctc_decode(\n            inputs,\n            sequence_lengths=[3, 3, 1],\n            strategy=\"beam_search\",\n            beam_width=beam_width,\n            top_paths=top_paths,\n            mask_index=0,\n        )\n        self.assertAllClose(decoded, labels)\n        self.assertAllClose(scores, score_labels)\n\n    def test_normalize(self):\n        x = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.float32)\n        self.assertAllClose(\n            knn.normalize(x, axis=None),\n            [\n                [0.18898225, 0.3779645, 0.56694674],\n                [0.18898225, 0.3779645, 0.56694674],\n            ],\n        )\n        self.assertAllClose(\n            knn.normalize(x, axis=0),\n            [\n                [0.70710677, 0.70710677, 0.70710677],\n                [0.70710677, 0.70710677, 0.70710677],\n            ],\n        
)\n        self.assertAllClose(\n            knn.normalize(x, axis=-1),\n            [\n                [0.26726124, 0.53452247, 0.8017837],\n                [0.26726124, 0.53452247, 0.8017837],\n            ],\n        )\n        self.assertAllClose(\n            knn.normalize(x, order=3),\n            [\n                [0.30285344, 0.6057069, 0.9085603],\n                [0.30285344, 0.6057069, 0.9085603],\n            ],\n        )\n\n        # linalg.norm(x, ...) < epsilon\n        x = np.array([[1e-6, 1e-8]], dtype=np.float32)\n        self.assertAllClose(\n            knn.normalize(x, axis=-1, order=2, epsilon=1e-5),\n            [[1e-1, 1e-3]],\n        )\n\n    def test_psnr(self):\n        x1 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])\n        x2 = np.array([[0.2, 0.2, 0.3], [0.4, 0.6, 0.6]])\n        max_val = 1.0\n        expected_psnr_1 = 20 * np.log10(max_val) - 10 * np.log10(\n            np.mean(np.square(x1 - x2))\n        )\n        psnr_1 = knn.psnr(x1, x2, max_val)\n        self.assertAlmostEqual(psnr_1, expected_psnr_1)\n\n        x3 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])\n        x4 = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])\n        max_val = 1.0\n        expected_psnr_2 = 20 * np.log10(max_val) - 10 * np.log10(\n            np.mean(np.square(x3 - x4))\n        )\n        psnr_2 = knn.psnr(x3, x4, max_val)\n        self.assertAlmostEqual(psnr_2, expected_psnr_2)\n\n    @parameterized.named_parameters(\n        named_product(\n            bias=(None, True),\n            scale=(None, 1.0),\n            mask_and_is_causal=((None, False), (True, False), (None, True)),\n            flash_attention=(None, True, False),\n        )\n    )\n    def test_dot_product_attention(\n        self, bias, scale, mask_and_is_causal, flash_attention\n    ):\n        mask, is_causal = mask_and_is_causal\n        query_shape = (2, 3, 4, 8)\n        key_shape = (2, 3, 4, 8)\n        bias_shape = (2, 4, 3, 3)\n        query = 
np.arange(math.prod(query_shape), dtype=float).reshape(\n            query_shape\n        )\n        key = np.arange(math.prod(key_shape), dtype=float).reshape(key_shape)\n        value = np.arange(math.prod(key_shape), dtype=float).reshape(key_shape)\n        if mask is not None:\n            mask = np.tril(np.ones((3, 3))).astype(\"bool\")\n            mask = mask[None, None, ...]\n            mask = np.tile(mask, (2, 4, 1, 1))\n        if bias is not None:\n            if backend.backend() == \"openvino\":\n                self.skipTest(\n                    \"openvino does not support `bias` with \"\n                    \"`dot_product_attention`\"\n                )\n            if backend.backend() == \"torch\" and mask is not None:\n                self.skipTest(\n                    \"torch does not support `mask` and `bias` with \"\n                    \"`dot_product_attention`\"\n                )\n            bias = np.arange(math.prod(bias_shape), dtype=float).reshape(\n                bias_shape\n            )\n\n        if flash_attention:\n            if backend.backend() in (\"tensorflow\", \"numpy\", \"openvino\"):\n                self.skipTest(\n                    \"Flash attention is not supported in tensorflow, numpy, \"\n                    \"and openvino backends.\"\n                )\n            elif backend.backend() == \"torch\":\n                import torch\n\n                if bias is not None:\n                    self.skipTest(\n                        \"Flash attention doesn't support `bias` in torch \"\n                        \"backend.\"\n                    )\n                if mask is not None:\n                    self.skipTest(\n                        \"Flash attention doesn't support `mask` in torch \"\n                        \"backend.\"\n                    )\n                if not torch.cuda.is_available():\n                    self.skipTest(\n                        \"Flash attention must be run on CUDA in 
torch backend.\"\n                    )\n                cuda_compute_capability = tuple(\n                    int(x) for x in torch.cuda.get_device_capability()\n                )\n                if cuda_compute_capability < (8, 0):\n                    self.skipTest(\n                        \"Flash attention must be run on CUDA compute \"\n                        \"capability >= 8.0 in torch backend.\"\n                    )\n            elif backend.backend() == \"jax\":\n                import jax\n                from jax._src import xla_bridge\n\n                if \"cuda\" not in xla_bridge.get_backend().platform_version:\n                    self.skipTest(\n                        \"Flash attention must be run on CUDA in jax backend.\"\n                    )\n                d, *_ = jax.local_devices(backend=\"gpu\")\n                cuda_compute_capability = tuple(\n                    int(x) for x in d.compute_capability.split(\".\")\n                )\n                if cuda_compute_capability < (8, 0):\n                    self.skipTest(\n                        \"Flash attention must be run on CUDA compute \"\n                        \"capability >= 8.0 in jax backend.\"\n                    )\n\n            # Flash attention only supports float16 and bfloat16. 
We multiply\n            # by 0.1 to avoid overflow.\n            query = (query * 0.1).astype(\"float16\")\n            key = (key * 0.1).astype(\"float16\")\n            value = (value * 0.1).astype(\"float16\")\n            if bias is not None:\n                bias = (bias * 0.1).astype(\"float16\")\n\n        outputs = knn.dot_product_attention(\n            query,\n            key,\n            value,\n            bias=bias,\n            mask=mask,\n            scale=scale,\n            is_causal=is_causal,\n            flash_attention=flash_attention,\n        )\n\n        expected = _dot_product_attention(\n            query,\n            key,\n            value,\n            bias=bias,\n            mask=mask,\n            scale=scale,\n            is_causal=is_causal,\n        )\n        self.assertAllClose(\n            outputs, expected, atol=1e-3 if flash_attention else 1e-6\n        )\n\n    @parameterized.named_parameters(named_product(scale=(1.0, 10.0)))\n    def test_rms_normalization(self, scale):\n        x = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=\"float32\")\n        scale = np.array([scale] * x.shape[-1], dtype=\"float32\")\n        expected_output = (\n            np.array([[0.46291, 0.92582, 1.38873], [0.78954, 0.98693, 1.18431]])\n            * scale\n        )\n\n        self.assertAllClose(\n            knn.rms_normalization(x, scale), expected_output, atol=1e-3\n        )\n        self.assertAllClose(knn.RMSNorm()(x, scale), expected_output, atol=1e-3)\n\n    def test_layer_normalization(self):\n        x = np.arange(5, dtype=\"float32\")\n        expected_output = np.array(\n            [-1.4142135, -0.70710677, 0.0, 0.7071067, 1.4142135]\n        )\n\n        self.assertAllClose(\n            knn.layer_normalization(x), expected_output, atol=1e-3\n        )\n        self.assertAllClose(knn.LayerNorm()(x), expected_output, atol=1e-3)\n\n\nclass NNOpsDtypeTest(testing.TestCase):\n    \"\"\"Test the floating dtype to verify that 
the behavior matches JAX.\"\"\"\n\n    FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in (\"float64\",)]\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_elu(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.elu(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.elu(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Elu().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_gelu(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n\n        # approximate = True\n        expected_dtype = standardize_dtype(jnn.gelu(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.gelu(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Gelu().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n        # approximate = False\n        expected_dtype = standardize_dtype(jnn.gelu(x_jax, False).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.gelu(x, False).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Gelu(False).symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_celu(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.celu(x_jax).dtype)\n\n        
self.assertEqual(\n            standardize_dtype(knn.celu(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Celu().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_tanh_shrink(self, dtype):\n        import torch\n        import torch.nn.functional as tnn\n\n        x = knp.ones((1), dtype=dtype)\n        x_torch = torch.ones(1, dtype=getattr(torch, dtype))\n        expected_dtype = standardize_dtype(tnn.tanhshrink(x_torch).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.tanh_shrink(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.TanhShrink().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_hard_tanh(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.hard_tanh(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.hard_tanh(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.HardTanh().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_hard_shrink(self, dtype):\n        import torch\n        import torch.nn.functional as tnn\n\n        x = knp.ones((1), dtype=dtype)\n        x_torch = torch.ones(1, dtype=getattr(torch, dtype))\n        expected_dtype = standardize_dtype(tnn.hardshrink(x_torch).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.hard_shrink(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            
standardize_dtype(knn.HardShrink().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_threshold(self, dtype):\n        import torch\n        import torch.nn.functional as tnn\n\n        x = knp.ones((1), dtype=dtype)\n        x_torch = torch.ones(1, dtype=getattr(torch, dtype))\n        expected_dtype = standardize_dtype(tnn.threshold(x_torch, 0, 0).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.threshold(x, 0, 0).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Threshold(0, 0).symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_soft_shrink(self, dtype):\n        import torch\n        import torch.nn.functional as tnn\n\n        x = knp.ones((1), dtype=dtype)\n        x_torch = torch.ones(1, dtype=getattr(torch, dtype))\n        expected_dtype = standardize_dtype(tnn.softshrink(x_torch).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.soft_shrink(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.SoftShrink().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_sparse_plus(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.sparse_plus(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.sparse_plus(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.SparsePlus().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_glu(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((2), dtype=dtype)\n        x_jax = jnp.ones((2), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.glu(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.glu(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Glu().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_squareplus(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((2), dtype=dtype)\n        x_jax = jnp.ones((2), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.squareplus(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.squareplus(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Squareplus().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_hard_sigmoid(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.hard_sigmoid(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.hard_sigmoid(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.HardSigmoid().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_hard_silu(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        
x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.hard_silu(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.hard_silu(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.HardSilu().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_leaky_relu(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.leaky_relu(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.leaky_relu(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.LeakyRelu().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_log_sigmoid(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.log_sigmoid(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.log_sigmoid(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.LogSigmoid().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_log_softmax(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((10,), dtype=dtype)\n        x_jax = jnp.ones((10,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.log_softmax(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.log_softmax(x).dtype),\n  
          expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.LogSoftmax().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_relu(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.relu(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.relu(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Relu().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_relu6(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.relu6(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.relu6(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Relu6().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_selu(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.selu(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.selu(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Selu().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def 
test_sigmoid(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.sigmoid(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.sigmoid(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Sigmoid().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_sparse_sigmoid(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.sparse_sigmoid(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.sparse_sigmoid(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.SparseSigmoid().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_silu(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.silu(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.silu(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Silu().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_softplus(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = 
standardize_dtype(jnn.softplus(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.softplus(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Softplus().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_softmax(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((10,), dtype=dtype)\n        x_jax = jnp.ones((10,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.softmax(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.softmax(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Softmax().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_softsign(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.soft_sign(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.softsign(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knn.Softsign().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_hard_tanh(self, dtype):\n        import jax.nn as jnn\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnn.hard_tanh(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knn.hard_tanh(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            
standardize_dtype(knn.HardTanh().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_ctc_loss(self, dtype):\n        labels = knp.array([[1, 2, 1]], dtype=\"int32\")\n        outputs = knp.array(\n            [[[0.4, 0.8, 0.4], [0.2, 0.8, 0.3], [0.9, 0.4, 0.5]]], dtype=dtype\n        )\n        label_length = knp.array([3])\n        output_length = knp.array([3])\n        expected_dtype = (\n            \"float32\" if dtype in (\"float16\", \"bfloat16\") else dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(\n                knn.ctc_loss(labels, outputs, label_length, output_length).dtype\n            ),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(\n                knn.CTCLoss()\n                .symbolic_call(labels, outputs, label_length, output_length)\n                .dtype\n            ),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_ctc_decode(self, dtype):\n        inputs = knp.array(\n            [[[0.4, 0.8, 0.4], [0.2, 0.8, 0.3], [0.9, 0.4, 0.5]]], dtype=dtype\n        )\n        sequence_length = knp.array([3])\n        expected_dtype = backend.result_type(dtype, \"float32\")\n\n        # Test strategy=\"greedy\"\n        decoded, scores = knn.ctc_decode(\n            inputs, sequence_length, strategy=\"greedy\"\n        )\n        self.assertEqual(standardize_dtype(decoded.dtype), \"int32\")\n        self.assertEqual(standardize_dtype(scores.dtype), expected_dtype)\n        decoded, scores = knn.CTCDecode(strategy=\"greedy\").symbolic_call(\n            inputs, sequence_length\n        )\n        self.assertEqual(standardize_dtype(decoded.dtype), \"int32\")\n        self.assertEqual(standardize_dtype(scores.dtype), expected_dtype)\n\n        if backend.backend() == \"torch\":\n            
self.skipTest(\"torch doesn't support 'beam_search' strategy\")\n\n        # Test strategy=\"beam_search\"\n        decoded, scores = knn.ctc_decode(\n            inputs, sequence_length, strategy=\"beam_search\"\n        )\n        self.assertEqual(standardize_dtype(decoded.dtype), \"int32\")\n        self.assertEqual(standardize_dtype(scores.dtype), expected_dtype)\n        decoded, scores = knn.CTCDecode(strategy=\"beam_search\").symbolic_call(\n            inputs, sequence_length\n        )\n        self.assertEqual(standardize_dtype(decoded.dtype), \"int32\")\n        self.assertEqual(standardize_dtype(scores.dtype), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(\n            dtypes=list(combinations(FLOAT_DTYPES, 2))\n            + [(dtype, dtype) for dtype in FLOAT_DTYPES]\n        )\n    )\n    def test_dot_product_attention(self, dtypes):\n        # TODO: Get expected output from jax if `jax.nn.dot_product_attention`\n        # is available.\n        query_dtype, key_value_dtype = dtypes\n        query = knp.ones((2, 3, 3, 8), dtype=query_dtype)\n        key = knp.ones((2, 3, 3, 8), dtype=key_value_dtype)\n        value = knp.ones((2, 3, 3, 8), dtype=key_value_dtype)\n        expected_dtype = backend.result_type(*dtypes)\n\n        self.assertDType(\n            knn.dot_product_attention(query, key, value), expected_dtype\n        )\n        self.assertDType(\n            knn.DotProductAttention().symbolic_call(query, key, value),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=combinations(FLOAT_DTYPES, 2))\n    )\n    def test_rms_normalization(self, dtypes):\n        input_dtype, weight_dtype = dtypes\n        inputs = knp.ones((2, 8), dtype=input_dtype)\n        scale = backend.Variable(knp.ones((8,), dtype=weight_dtype))\n        expected_dtype = input_dtype\n\n        self.assertDType(knn.rms_normalization(inputs, scale), expected_dtype)\n        
self.assertDType(\n            knn.RMSNorm().symbolic_call(inputs, scale), expected_dtype\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=combinations(FLOAT_DTYPES, 2))\n    )\n    def test_layer_normalization(self, dtypes):\n        input_dtype, weight_dtype = dtypes\n        inputs = knp.ones((2, 8), dtype=input_dtype)\n        gamma = backend.Variable(knp.ones((8,), dtype=weight_dtype))\n        beta = backend.Variable(knp.ones((8,), dtype=weight_dtype))\n        expected_dtype = input_dtype\n\n        self.assertDType(\n            knn.layer_normalization(inputs, gamma, beta), expected_dtype\n        )\n        self.assertDType(\n            knn.LayerNorm().symbolic_call(inputs, gamma, beta), expected_dtype\n        )\n\n\nclass NNOpsBehaviorTest(testing.TestCase):\n    def test_logit_recovery_binary_crossentropy(self):\n        layer = layers.Dense(\n            4, activation=\"sigmoid\", use_bias=False, kernel_initializer=\"ones\"\n        )\n        loss = losses.BinaryCrossentropy()\n        x = np.array([[1.4, 1.6, 0.8]])\n        y = np.array([[0.2, 0.6, 0.1, 0.3]])\n        loss_value = loss(y, layer(x))\n        self.assertAllClose(loss_value, 2.682124)\n\n        model = models.Sequential([layer])\n        model.compile(loss=\"binary_crossentropy\", optimizer=\"sgd\")\n        out = model.evaluate(x, y)\n        self.assertAllClose(out, 2.682124)\n\n    def test_softmax_on_axis_with_size_one_warns(self):\n        x = np.array([[1.0]])\n        # Applying softmax on the second axis, which has size 1\n        axis = 1\n\n        # Expected warning message\n        expected_warning_regex = (\n            r\"You are using a softmax over axis 1 \"\n            r\"of a tensor of shape \\(1, 1\\)\\. This axis \"\n            r\"has size 1\\. The softmax operation will always return \"\n            r\"the value 1, which is likely not what you intended\\. 
\"\n            r\"Did you mean to use a sigmoid instead\\?\"\n        )\n\n        with self.assertWarnsRegex(UserWarning, expected_warning_regex):\n            knn.softmax(x, axis)\n\n    def test_normalize_order_validation(self):\n        # Test with a non-integer order\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `order` must be an int >= 1\"\n        ):\n            knn.normalize(np.array([1, 2, 3]), order=\"a\")\n\n        # Test with a negative integer\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `order` must be an int >= 1\"\n        ):\n            knn.normalize(np.array([1, 2, 3]), order=-1)\n\n        # Test with zero\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `order` must be an int >= 1\"\n        ):\n            knn.normalize(np.array([1, 2, 3]), order=0)\n\n        # Test with a floating-point number\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `order` must be an int >= 1\"\n        ):\n            knn.normalize(np.array([1, 2, 3]), order=2.5)\n\n    def test_check_shape_first_dim_mismatch(self):\n        name1, shape1 = \"labels\", (2, 3)\n        name2, shape2 = \"logits\", (3, 4, 5)\n        ctc_loss_instance = knn.CTCLoss(mask_index=-1)\n        with self.assertRaisesRegex(\n            ValueError, \"must have the same first dimension\"\n        ):\n            ctc_loss_instance._check_shape_first_dim(\n                name1, shape1, name2, shape2\n            )\n\n    def test_invalid_strategy_ctc_decode(self):\n        inputs = np.array(\n            [\n                [\n                    [0.1, 0.4, 0.2, 0.4],\n                    [0.3, 0.3, 0.4, 0.2],\n                    [0.3, 0.2, 0.4, 0.3],\n                ]\n            ]\n        )\n        beam_width = 4\n        top_paths = 2\n        with self.assertRaisesRegex(ValueError, \"Invalid strategy\"):\n            knn.ctc_decode(\n                inputs,\n                
sequence_lengths=[3, 3, 1],\n                strategy=\"invalid\",\n                beam_width=beam_width,\n                top_paths=top_paths,\n            )\n\n    def test_layer_normalization_rms_scaling_warning(self):\n        x = np.arange(5, dtype=\"float32\")\n        with self.assertWarnsRegex(\n            UserWarning, r\"You passed `rms_scaling=True`, which is deprecated\"\n        ):\n            knn.layer_normalization(x, rms_scaling=True)\n\n    def test_unfold(self):\n        # test 1 kernel_size=2\n        x = ops.arange(8, dtype=\"float32\")\n        x = ops.reshape(x, [1, 1, 2, 4])\n        unfold_result = knn.unfold(x, 2)\n        expected_result = ops.convert_to_tensor(\n            [\n                [\n                    [0.0, 1.0, 2.0],\n                    [1.0, 2.0, 3.0],\n                    [4.0, 5.0, 6.0],\n                    [5.0, 6.0, 7.0],\n                ]\n            ]\n        )\n        self.assertAllClose(unfold_result, expected_result)\n\n        # test 2 kernel_size=[2,4]\n        x = ops.arange(16, dtype=\"float32\")\n        x = ops.reshape(x, [1, 1, 4, 4])\n        unfold_result = knn.unfold(x, [2, 4])\n        expected_result = ops.convert_to_tensor(\n            [\n                [\n                    [0.0, 4.0, 8.0],\n                    [1.0, 5.0, 9.0],\n                    [2.0, 6.0, 10.0],\n                    [3.0, 7.0, 11.0],\n                    [4.0, 8.0, 12.0],\n                    [5.0, 9.0, 13.0],\n                    [6.0, 10.0, 14.0],\n                    [7.0, 11.0, 15.0],\n                ]\n            ],\n            dtype=\"float32\",\n        )\n        self.assertAllClose(unfold_result, expected_result)\n\n        # test 3 kernel_size=[3,2],stride=[3,2]\n        x = ops.arange(12, dtype=\"float32\")\n        x = ops.reshape(x, [1, 1, 3, 4])\n        unfold_result = knn.unfold(x, [3, 2], stride=[3, 2])\n        expected_result = ops.convert_to_tensor(\n            [\n                [\n                    [0.0, 2.0],\n                    [1.0, 3.0],\n                    [4.0, 6.0],\n                    [5.0, 7.0],\n                    [8.0, 10.0],\n                    [9.0, 11.0],\n                ]\n            ]\n        )\n        self.assertAllClose(unfold_result, expected_result)\n\n        # test 4 kernel_size=2,dilation=2,stride=2\n        x = ops.arange(16, dtype=\"float32\")\n        x = ops.reshape(x, [1, 1, 4, 4])\n        unfold_result = knn.unfold(x, 2, 2, stride=2)\n        expected_result = ops.convert_to_tensor([0, 2, 8, 10], dtype=\"float32\")\n        expected_result = ops.reshape(expected_result, [1, 4, 1])\n        self.assertAllClose(unfold_result, expected_result)\n\n        # test 5 kernel_size=1,padding=1\n        x = ops.arange(4, dtype=\"float32\")\n        x = ops.reshape(x, [1, 1, 2, 2])\n        unfold_result = knn.unfold(x, 1, padding=1)\n        expected_result = ops.convert_to_tensor(\n            [\n                [\n                    [\n                        0.0,\n                        0.0,\n                        0.0,\n                        0.0,\n                        0.0,\n                        0.0,\n                        1.0,\n                        0.0,\n                        0.0,\n                        2.0,\n                        3.0,\n                        0.0,\n                        0.0,\n                        0.0,\n                        0.0,\n                        0.0,\n                    ]\n                ]\n            ]\n        )\n        self.assertAllClose(unfold_result, expected_result)\n\n        # test 6 multi channel and kernel_size=2\n        x = ops.arange(8, dtype=\"float32\")\n        x = ops.reshape(x, [1, 2, 2, 2])\n        unfold_result = knn.unfold(x, 2)\n        expected_result = ops.convert_to_tensor(\n            [[[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0]]]\n        )\n        self.assertAllClose(unfold_result, expected_result)\n\n        # test 7 multi channel and kernel_size=[2,3]\n        x = ops.arange(12, dtype=\"float32\")\n        x = ops.reshape(x, [1, 2, 2, 3])\n        unfold_result = knn.unfold(x, [2, 3])\n        expected_result = ops.convert_to_tensor(\n            [\n                [\n                    [0.0],\n                    [1.0],\n                    [2.0],\n                    [3.0],\n                    [4.0],\n                    [5.0],\n                    [6.0],\n                    [7.0],\n                    [8.0],\n                    [9.0],\n                    [10.0],\n                    [11.0],\n                ]\n            ]\n        )\n        self.assertAllClose(unfold_result, expected_result)\n\n        # test 8 multi channel and kernel_size=[2,3],stride=[2,3]\n        x = ops.arange(12, dtype=\"float32\")\n        x = ops.reshape(x, [1, 2, 2, 3])\n        unfold_result = knn.unfold(x, [2, 3], stride=[2, 3])\n        expected_result = ops.convert_to_tensor(\n            [\n                [\n                    [0.0],\n                    [1.0],\n                    [2.0],\n                    [3.0],\n                    [4.0],\n                    [5.0],\n                    [6.0],\n                    [7.0],\n                    [8.0],\n                    [9.0],\n                    [10.0],\n                    [11.0],\n                ]\n            ]\n        )\n        self.assertAllClose(unfold_result, expected_result)\n\n        # test 9 multi channel and kernel_size=2,dilation=2\n        x = ops.arange(32, dtype=\"float32\")\n        x = ops.reshape(x, [1, 2, 4, 4])\n        unfold_result = knn.unfold(x, 2, dilation=2)\n        expected_result = ops.convert_to_tensor(\n            [\n                [\n                    [0.0, 1.0, 4.0, 5.0],\n                    [2.0, 3.0, 6.0, 7.0],\n                    [8.0, 9.0, 12.0, 13.0],\n                    [10.0, 11.0, 14.0, 15.0],\n                    [16.0, 17.0, 20.0, 21.0],\n                    [18.0, 19.0, 22.0, 23.0],\n                    [24.0, 25.0, 28.0, 29.0],\n                    [26.0, 27.0, 30.0, 31.0],\n                ]\n            ]\n        )\n        self.assertAllClose(unfold_result, expected_result)\n\n        # test 10 multi channel and kernel_size=2,padding=1\n        x = ops.arange(8, dtype=\"float32\")\n        x = ops.reshape(x, [1, 2, 2, 2])\n        unfold_result = knn.unfold(x, 2, padding=1)\n        expected_result = ops.convert_to_tensor(\n            [\n                [\n                    [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 2.0, 3.0],\n                    [0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 2.0, 3.0, 0.0],\n                    [0.0, 0.0, 1.0, 0.0, 2.0, 3.0, 0.0, 0.0, 0.0],\n                    [0.0, 1.0, 0.0, 2.0, 3.0, 0.0, 0.0, 0.0, 0.0],\n                    [0.0, 0.0, 0.0, 0.0, 4.0, 5.0, 0.0, 6.0, 7.0],\n                    [0.0, 0.0, 0.0, 4.0, 5.0, 0.0, 6.0, 7.0, 0.0],\n                    [0.0, 4.0, 5.0, 0.0, 6.0, 7.0, 0.0, 0.0, 0.0],\n                    [4.0, 5.0, 0.0, 6.0, 7.0, 0.0, 0.0, 0.0, 0.0],\n                ]\n            ]\n        )\n        self.assertAllClose(unfold_result, expected_result)\n\n    def test_fold(self):\n        # test 1: non-overlapping roundtrip (stride == kernel_size)\n        x = ops.arange(16, dtype=\"float32\")\n        x = ops.reshape(x, [1, 1, 4, 4])\n        patches = knn.unfold(x, kernel_size=2, stride=2)\n        y = knn.fold(patches, output_size=(4, 4), kernel_size=2, stride=2)\n        self.assertAllClose(y, x)\n\n        # test 2: overlapping roundtrip — fold(unfold(x)) / fold(ones) == x\n        x = ops.arange(16, dtype=\"float32\")\n        x = ops.reshape(x, [1, 1, 4, 4])\n        patches = knn.unfold(x, kernel_size=3, stride=1)\n        folded = knn.fold(patches, output_size=(4, 4), kernel_size=3, stride=1)\n        ones = ops.ones_like(x)\n        ones_patches = knn.unfold(ones, kernel_size=3, stride=1)\n        divisor = knn.fold(\n            ones_patches, output_size=(4, 4), kernel_size=3, stride=1\n        )\n        
result = folded / divisor\n        self.assertAllClose(result, x)\n\n        # test 3: multi-channel non-overlapping roundtrip\n        x = ops.arange(32, dtype=\"float32\")\n        x = ops.reshape(x, [1, 2, 4, 4])\n        patches = knn.unfold(x, kernel_size=2, stride=2)\n        y = knn.fold(patches, output_size=(4, 4), kernel_size=2, stride=2)\n        self.assertAllClose(y, x)\n\n        # test 4: dilation + padding roundtrip\n        x = ops.arange(32, dtype=\"float32\")\n        x = ops.reshape(x, [1, 2, 4, 4])\n        patches = knn.unfold(x, kernel_size=2, dilation=2, padding=1)\n        y = knn.fold(\n            patches,\n            output_size=(4, 4),\n            kernel_size=2,\n            dilation=2,\n            padding=1,\n        )\n        ones = ops.ones_like(x)\n        ones_patches = knn.unfold(ones, kernel_size=2, dilation=2, padding=1)\n        divisor = knn.fold(\n            ones_patches,\n            output_size=(4, 4),\n            kernel_size=2,\n            dilation=2,\n            padding=1,\n        )\n        result = y / divisor\n        self.assertAllClose(result, x)\n\n        # test 5: explicit known values — single 1x1x2x2 non-overlap\n        x = ops.convert_to_tensor(\n            [[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]],\n            dtype=\"float32\",\n        )\n        # x shape: (1, 2, 4) — as if C=2, kernel=1x1, L=4\n        y = knn.fold(x, output_size=(2, 2), kernel_size=1, stride=1)\n        expected = ops.convert_to_tensor(\n            [[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]],\n            dtype=\"float32\",\n        )\n        self.assertAllClose(y, expected)\n\n        # test 6: input validation — must be 3D\n        x_bad = ops.ones((1, 2, 3, 4))\n        with self.assertRaisesRegex(ValueError, \"3D\"):\n            knn.fold(x_bad, output_size=(4, 4), kernel_size=2)\n\n    def test_fold_tuple_params(self):\n        \"\"\"Test fold with tuple parameters to cover _pair branches.\"\"\"\n        
# Non-square kernel and stride as tuples\n        x = ops.arange(48, dtype=\"float32\")\n        x = ops.reshape(x, [1, 1, 6, 8])\n        patches = knn.unfold(x, kernel_size=(2, 4), stride=(2, 4))\n        y = knn.fold(\n            patches,\n            output_size=(6, 8),\n            kernel_size=(2, 4),\n            stride=(2, 4),\n        )\n        self.assertAllClose(y, x)\n\n        # Asymmetric padding as tuple — only pad height\n        x = ops.arange(16, dtype=\"float32\")\n        x = ops.reshape(x, [1, 1, 4, 4])\n        patches = knn.unfold(\n            x, kernel_size=(3, 2), stride=(1, 2), padding=(1, 0)\n        )\n        folded = knn.fold(\n            patches,\n            output_size=(4, 4),\n            kernel_size=(3, 2),\n            stride=(1, 2),\n            padding=(1, 0),\n        )\n        ones = ops.ones_like(x)\n        ones_p = knn.unfold(\n            ones, kernel_size=(3, 2), stride=(1, 2), padding=(1, 0)\n        )\n        divisor = knn.fold(\n            ones_p,\n            output_size=(4, 4),\n            kernel_size=(3, 2),\n            stride=(1, 2),\n            padding=(1, 0),\n        )\n        result = folded / divisor\n        self.assertAllClose(result, x)\n\n    def test_fold_no_padding(self):\n        \"\"\"Test fold with padding=0 to cover the skip-padding branch.\"\"\"\n        x = ops.arange(36, dtype=\"float32\")\n        x = ops.reshape(x, [1, 1, 6, 6])\n        patches = knn.unfold(x, kernel_size=3, stride=3, padding=0)\n        y = knn.fold(\n            patches,\n            output_size=(6, 6),\n            kernel_size=3,\n            stride=3,\n            padding=0,\n        )\n        self.assertAllClose(y, x)\n\n    def test_fold_non_square_output(self):\n        \"\"\"Test fold with non-square spatial dimensions.\"\"\"\n        x = ops.arange(24, dtype=\"float32\")\n        x = ops.reshape(x, [1, 1, 4, 6])\n        patches = knn.unfold(x, kernel_size=2, stride=2)\n        y = knn.fold(patches, 
output_size=(4, 6), kernel_size=2, stride=2)\n        self.assertAllClose(y, x)\n\n    def test_fold_batch_and_channels(self):\n        \"\"\"Test fold with larger batch and channel counts.\"\"\"\n        x = np.random.normal(size=(4, 8, 6, 6)).astype(\"float32\")\n        x = ops.convert_to_tensor(x)\n        patches = knn.unfold(x, kernel_size=2, stride=2)\n        y = knn.fold(patches, output_size=(6, 6), kernel_size=2, stride=2)\n        self.assertAllClose(y, x, tpu_atol=1e-2, tpu_rtol=1e-2)\n\n    def test_fold_divisibility_validation(self):\n        \"\"\"Test fold raises on CKK not divisible by kernel product.\"\"\"\n        # CKK=5, kernel=2x2 -> 5 % 4 != 0 — raises reshape error\n        x_bad = ops.ones((1, 5, 4))\n        with self.assertRaises((ValueError, Exception)):\n            knn.fold(x_bad, output_size=(4, 4), kernel_size=2)\n\n    def test_depth_to_space(self):\n        # Test channels_last (default)\n        # Input: (1, 2, 2, 12) -> Output: (1, 4, 4, 3)\n        x = ops.arange(48, dtype=\"float32\")\n        x = ops.reshape(x, [1, 2, 2, 12])\n        result = knn.depth_to_space(x, block_size=2)\n        self.assertEqual(result.shape, (1, 4, 4, 3))\n\n        # Verify the transformation is correct\n        # The depth channel is rearranged into spatial blocks\n        # For block_size=2, channels are split into 2x2 blocks\n        expected = np.array(\n            [\n                [\n                    [[0, 1, 2], [3, 4, 5], [12, 13, 14], [15, 16, 17]],\n                    [[6, 7, 8], [9, 10, 11], [18, 19, 20], [21, 22, 23]],\n                    [[24, 25, 26], [27, 28, 29], [36, 37, 38], [39, 40, 41]],\n                    [[30, 31, 32], [33, 34, 35], [42, 43, 44], [45, 46, 47]],\n                ]\n            ],\n            dtype=\"float32\",\n        )\n        self.assertAllClose(result, expected)\n\n        # Test channels_first\n        # Input: (1, 12, 2, 2) -> Output: (1, 3, 4, 4)\n        x = ops.arange(48, dtype=\"float32\")\n  
      x = ops.reshape(x, [1, 12, 2, 2])\n        result = knn.depth_to_space(\n            x, block_size=2, data_format=\"channels_first\"\n        )\n        self.assertEqual(result.shape, (1, 3, 4, 4))\n\n        # Test with different block size\n        x = ops.arange(1 * 2 * 2 * 27, dtype=\"float32\")\n        x = ops.reshape(x, [1, 2, 2, 27])\n        result = knn.depth_to_space(x, block_size=3)\n        self.assertEqual(result.shape, (1, 6, 6, 3))\n\n    def test_space_to_depth(self):\n        # Test channels_last (default)\n        # Input: (1, 4, 4, 3) -> Output: (1, 2, 2, 12)\n        x = ops.arange(48, dtype=\"float32\")\n        x = ops.reshape(x, [1, 4, 4, 3])\n        result = knn.space_to_depth(x, block_size=2)\n        self.assertEqual(result.shape, (1, 2, 2, 12))\n\n        # Verify the transformation is correct\n        expected = np.array(\n            [\n                [\n                    [\n                        [0, 1, 2, 3, 4, 5, 12, 13, 14, 15, 16, 17],\n                        [6, 7, 8, 9, 10, 11, 18, 19, 20, 21, 22, 23],\n                    ],\n                    [\n                        [24, 25, 26, 27, 28, 29, 36, 37, 38, 39, 40, 41],\n                        [30, 31, 32, 33, 34, 35, 42, 43, 44, 45, 46, 47],\n                    ],\n                ]\n            ],\n            dtype=\"float32\",\n        )\n        self.assertAllClose(result, expected)\n\n        # Test channels_first\n        # Input: (1, 3, 4, 4) -> Output: (1, 12, 2, 2)\n        x = ops.arange(48, dtype=\"float32\")\n        x = ops.reshape(x, [1, 3, 4, 4])\n        result = knn.space_to_depth(\n            x, block_size=2, data_format=\"channels_first\"\n        )\n        self.assertEqual(result.shape, (1, 12, 2, 2))\n\n        # Test with different block size\n        x = ops.arange(1 * 6 * 6 * 3, dtype=\"float32\")\n        x = ops.reshape(x, [1, 6, 6, 3])\n        result = knn.space_to_depth(x, block_size=3)\n        self.assertEqual(result.shape, (1, 
2, 2, 27))\n\n    def test_depth_to_space_space_to_depth_roundtrip(self):\n        # depth_to_space followed by space_to_depth should be identity\n        x = ops.arange(48, dtype=\"float32\")\n        x = ops.reshape(x, [1, 2, 2, 12])\n        y = knn.depth_to_space(x, block_size=2)\n        z = knn.space_to_depth(y, block_size=2)\n        self.assertAllClose(x, z)\n\n        # space_to_depth followed by depth_to_space should be identity\n        x = ops.arange(48, dtype=\"float32\")\n        x = ops.reshape(x, [1, 4, 4, 3])\n        y = knn.space_to_depth(x, block_size=2)\n        z = knn.depth_to_space(y, block_size=2)\n        self.assertAllClose(x, z)\n\n        # Test with channels_first\n        x = ops.arange(48, dtype=\"float32\")\n        x = ops.reshape(x, [1, 12, 2, 2])\n        y = knn.depth_to_space(x, block_size=2, data_format=\"channels_first\")\n        z = knn.space_to_depth(y, block_size=2, data_format=\"channels_first\")\n        self.assertAllClose(x, z)\n\n    def test_depth_to_space_block_size_validation(self):\n        x = ops.arange(48, dtype=\"float32\")\n        x = ops.reshape(x, [1, 2, 2, 12])\n\n        # block_size must be at least 2\n        with self.assertRaisesRegex(\n            ValueError, \"`block_size` must be at least 2\"\n        ):\n            knn.depth_to_space(x, block_size=0)\n\n        with self.assertRaisesRegex(\n            ValueError, \"`block_size` must be at least 2\"\n        ):\n            knn.depth_to_space(x, block_size=1)\n\n        with self.assertRaisesRegex(\n            ValueError, \"`block_size` must be at least 2\"\n        ):\n            knn.depth_to_space(x, block_size=-1)\n\n    def test_space_to_depth_block_size_validation(self):\n        x = ops.arange(48, dtype=\"float32\")\n        x = ops.reshape(x, [1, 4, 4, 3])\n\n        # block_size must be at least 2\n        with self.assertRaisesRegex(\n            ValueError, \"`block_size` must be at least 2\"\n        ):\n            
knn.space_to_depth(x, block_size=0)\n\n        with self.assertRaisesRegex(\n            ValueError, \"`block_size` must be at least 2\"\n        ):\n            knn.space_to_depth(x, block_size=1)\n\n        with self.assertRaisesRegex(\n            ValueError, \"`block_size` must be at least 2\"\n        ):\n            knn.space_to_depth(x, block_size=-1)\n"
  },
  {
    "path": "keras/src/ops/node.py",
    "content": "import collections\n\nfrom keras.src import tree\nfrom keras.src.backend import KerasTensor\nfrom keras.src.ops.symbolic_arguments import SymbolicArguments\n\n\nclass Node:\n    \"\"\"A `Node` describes an operation `__call__()` event.\n\n    A Keras Function is a DAG with `Node` instances as nodes, and\n    `KerasTensor` instances as edges. Nodes aren't `Operation` instances,\n    because a single operation could be called multiple times, which would\n    result in graph cycles.\n\n    A `__call__()` event involves input tensors (and other input arguments),\n    the operation that was called, and the resulting output tensors.\n    A `Node` will include all this information.\n\n    Since a single `Operation` could be called multiple times,\n    the `Node` instances are stored on operations as a list.\n    Each time an operation is called, a node is added to `op._inbound_nodes`.\n    Each time the output of an operation is used by another operation,\n    a node is added to `op._outbound_nodes`.\n\n    Every `KerasTensor` instance has a `KerasHistory` object attached,\n    which tracks the `Node` that records the `__call__()` event that created\n    the tensor. 
By recursively walking through `Node` instances\n    via the `KerasHistory` metadata of `KerasTensor` instances, one can\n    retrieve the entire DAG of a Keras Function.\n\n    Args:\n        operation: The Operation that was called in the `op.__call__()`\n            event that this node represents.\n        call_args: The positional arguments the operation was called with.\n        call_kwargs: The keyword arguments the operation was called with.\n        outputs: The output tensors of the `op.__call__()` call.\n    \"\"\"\n\n    def __init__(\n        self, operation, call_args=None, call_kwargs=None, outputs=None\n    ):\n        self.operation = operation\n        self.arguments = SymbolicArguments(*call_args, **call_kwargs)\n        self.outputs = [] if outputs is None else tree.flatten(outputs)\n        for x in self.outputs:\n            if not isinstance(x, KerasTensor):\n                raise ValueError(\n                    \"All operation outputs must be tensors. \"\n                    f\"Operation {operation} returned a non-tensor. 
\"\n                    f\"Non-tensor received: {x}\"\n                )\n\n        zero_history = any(\n            not x.record_history for x in self.arguments.keras_tensors\n        )\n\n        # If inputs don't have metadata yet, add it.\n        if not zero_history:\n            for tensor in self.arguments.keras_tensors:\n                if not hasattr(tensor, \"_keras_history\"):\n                    tensor._keras_history = KerasHistory(\n                        operation=None, node_index=0, tensor_index=0\n                    )\n\n        # Wire up Node to Operations.\n        self.operation._inbound_nodes.append(self)\n        for kt in self.arguments.keras_tensors:\n            inbound_op = kt._keras_history.operation\n            if inbound_op is not None:  # It's a graph entry point.\n                inbound_op._outbound_nodes.append(self)\n\n        # Set metadata on outputs.\n        if not zero_history:\n            node_index = len(self.operation._inbound_nodes) - 1\n            for i, tensor in enumerate(self.outputs):\n                tensor._keras_history = KerasHistory(\n                    operation=operation, node_index=node_index, tensor_index=i\n                )\n\n        # Whether this is a root node.\n        self.is_input = not self.arguments.keras_tensors\n\n    def __repr__(self):\n        return f\"<Node operation={self.operation.name}, id={id(self)}>\"\n\n    @property\n    def input_tensors(self):\n        return self.arguments.keras_tensors\n\n    @property\n    def output_tensors(self):\n        return self.outputs\n\n    @property\n    def parent_nodes(self):\n        \"\"\"The parent `Node`s.\n\n        Returns:\n            all the `Node`s whose output this node immediately depends on.\n        \"\"\"\n        node_deps = []\n        for kt in self.arguments.keras_tensors:\n            op = kt._keras_history.operation\n            node_index = kt._keras_history.node_index\n            if op is not None:  # `None` for `Input` 
tensors.\n                node_deps.append(op._inbound_nodes[node_index])\n        return node_deps\n\n\nclass KerasHistory(\n    collections.namedtuple(\n        \"KerasHistory\", [\"operation\", \"node_index\", \"tensor_index\"]\n    )\n):\n    \"\"\"Tracks the Operation call that created a Tensor.\n\n    During construction of Keras Functions, this metadata is added to\n    each Tensor produced as the output of an Operation.\n    This allows Keras to track how each Tensor was produced, and\n    this information is later retraced by the `Function` class to\n    reconstruct the Operations graph.\n\n    Attributes:\n      operation: The Operation instance that produced the Tensor.\n      node_index: The specific call to the Operation that produced this Tensor.\n        Operations can be called multiple times in order to share weights. A new\n        node is created every time an Operation is called. The corresponding\n        node that represents the call event that produced the Tensor can be\n        found at `op._inbound_nodes[node_index]`.\n      tensor_index: The output index for this Tensor.\n        Always zero if the Operation that produced this Tensor\n        only has one output. Nested structures of\n        Tensors are deterministically assigned an index via `nest.flatten`.\n    \"\"\"\n\n    # Added to maintain memory and performance characteristics of `namedtuple`\n    # while subclassing.\n    __slots__ = ()\n\n\ndef is_keras_tensor(obj):\n    return hasattr(obj, \"_keras_history\")\n"
  },
  {
    "path": "keras/src/ops/node_test.py",
    "content": "import numpy as np\n\nfrom keras.src import Layer\nfrom keras.src import testing\nfrom keras.src.backend import KerasTensor\nfrom keras.src.ops.node import Node\n\n\nclass DummyLayer(Layer):\n    pass\n\n\nclass NodeTest(testing.TestCase):\n    # Testing a simple node and layer combination **a**\n    def test_simple_case(self):\n        shape = (2, 3, 4)\n        a = KerasTensor(shape=shape)\n        a_layer = DummyLayer()\n        node = Node(a_layer, outputs=a, call_args=(), call_kwargs={})\n\n        self.assertEqual(node.is_input, True)\n\n        self.assertEqual(node.output_tensors[0], a)\n        self.assertEqual(node.output_tensors[0].shape, shape)\n\n    # Testing a simple node connection with args and kwargs **a** --> **b**\n    def test_single_wired_layers(self):\n        shape = (2, 3, 4)\n        a = KerasTensor(shape=shape)\n        a_layer = DummyLayer()\n        node1 = Node(a_layer, outputs=a, call_args=(), call_kwargs={})\n\n        b = KerasTensor(shape=shape)\n        x = KerasTensor(shape=shape)\n        kwargs = {\"x\": x}\n        args = (a,)\n        b_layer = DummyLayer()\n        node2 = Node(b_layer, outputs=b, call_args=args, call_kwargs=kwargs)\n\n        self.assertEqual(node1.is_input, True)\n        self.assertEqual(node2.is_input, False)\n\n        self.assertEqual(node1.operation, a_layer)\n        self.assertEqual(node2.operation, b_layer)\n\n        self.assertEqual(node1.output_tensors[0], a)\n        self.assertEqual(node1.output_tensors[0].shape, shape)\n\n        self.assertEqual(a_layer._inbound_nodes[0], node1)\n        self.assertEqual(a_layer._outbound_nodes[0], node2)\n\n        self.assertEqual(b_layer._inbound_nodes[0], node2)\n        self.assertEqual(node2.parent_nodes[0], node1)\n\n        self.assertEqual(node2.input_tensors, [a, x])\n        self.assertEqual(node2.arguments.kwargs, kwargs)\n        self.assertEqual(node2.arguments.args, args)\n\n    # Testing when output tensor is not Keras 
Tensor\n    def test_output_tensor_error(self):\n        a = np.random.rand(2, 3, 4)\n        a_layer = DummyLayer()\n        with self.assertRaisesRegex(\n            ValueError, \"operation outputs must be tensors.\"\n        ):\n            Node(a_layer, outputs=a, call_args=(), call_kwargs={})\n"
  },
  {
    "path": "keras/src/ops/numpy.py",
    "content": "import builtins\nimport re\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend import any_symbolic_tensors\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common.backend_utils import canonicalize_axis\nfrom keras.src.backend.common.backend_utils import to_tuple_or_list\nfrom keras.src.ops import operation_utils\nfrom keras.src.ops.operation import Operation\nfrom keras.src.ops.operation_utils import broadcast_shapes\nfrom keras.src.ops.operation_utils import reduce_shape\n\n\nclass Rot90(Operation):\n    def __init__(self, k=1, axes=(0, 1), *, name=None):\n        super().__init__(name=name)\n        self.k = k\n        self.axes = axes\n\n    def call(self, array):\n        return backend.numpy.rot90(array, k=self.k, axes=self.axes)\n\n    def compute_output_spec(self, array):\n        array_shape = list(array.shape)\n        if len(array_shape) < 2:\n            raise ValueError(\n                \"Input array must have at least 2 dimensions. \"\n                f\"Received: array.shape={array_shape}\"\n            )\n        if len(self.axes) != 2 or self.axes[0] == self.axes[1]:\n            raise ValueError(\n                f\"Invalid axes: {self.axes}. 
\"\n                \"Axes must be a tuple of two different dimensions.\"\n            )\n        axis1, axis2 = self.axes\n        array_shape[axis1], array_shape[axis2] = (\n            array_shape[axis2],\n            array_shape[axis1],\n        )\n        return KerasTensor(shape=array_shape, dtype=array.dtype)\n\n\n@keras_export([\"keras.ops.rot90\", \"keras.ops.numpy.rot90\"])\ndef rot90(array, k=1, axes=(0, 1)):\n    \"\"\"Rotate an array by 90 degrees in the plane specified by axes.\n\n    This function rotates an array counterclockwise\n    by 90 degrees `k` times in the plane specified by `axes`.\n    Supports arrays of two or more dimensions.\n\n    Args:\n        array: Input array to rotate.\n        k: Number of times the array is rotated by 90 degrees.\n        axes: A tuple of two integers specifying the\n            plane of rotation (defaults to `(0, 1)`).\n\n    Returns:\n        Rotated array.\n\n    Examples:\n\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> m = np.array([[1, 2], [3, 4]])\n    >>> rotated = ops.rot90(m)\n    >>> rotated\n    array([[2, 4],\n           [1, 3]])\n\n    >>> m = np.arange(8).reshape((2, 2, 2))\n    >>> rotated = ops.rot90(m, k=1, axes=(1, 2))\n    >>> rotated\n    array([[[1, 3],\n            [0, 2]],\n           [[5, 7],\n            [4, 6]]])\n    \"\"\"\n    if any_symbolic_tensors((array,)):\n        return Rot90(k=k, axes=axes).symbolic_call(array)\n    return backend.numpy.rot90(array, k=k, axes=axes)\n\n\ndef shape_equal(shape1, shape2, axis=None, allow_none=True):\n    \"\"\"Check if two shapes are equal.\n\n    Args:\n        shape1: A list or tuple of integers for first shape to be compared.\n        shape2: A list or tuple of integers for second shape to be compared.\n        axis: An integer, list, or tuple of integers (optional):\n            Axes to ignore during comparison. 
Defaults to `None`.\n        allow_none (bool, optional): If `True`, allows `None` in a shape\n            to match any value in the corresponding position of the other shape.\n            Defaults to `True`.\n\n    Returns:\n        bool: `True` if shapes are considered equal based on the criteria,\n        `False` otherwise.\n\n    Examples:\n\n    >>> shape_equal((32, 64, 128), (32, 64, 128))\n    True\n    >>> shape_equal((32, 64, 128), (32, 64, 127))\n    False\n    >>> shape_equal((32, 64, None), (32, 64, 128), allow_none=True)\n    True\n    >>> shape_equal((32, 64, None), (32, 64, 128), allow_none=False)\n    False\n    >>> shape_equal((32, 64, 128), (32, 63, 128), axis=1)\n    True\n    >>> shape_equal((32, 64, 128), (32, 63, 127), axis=(1, 2))\n    True\n    >>> shape_equal((32, 64, 128), (32, 63, 127), axis=[1,2])\n    True\n    >>> shape_equal((32, 64), (32, 64, 128))\n    False\n    \"\"\"\n    if len(shape1) != len(shape2):\n        return False\n\n    shape1 = list(shape1)\n    shape2 = list(shape2)\n\n    if axis is not None:\n        if isinstance(axis, int):\n            axis = [axis]\n        for ax in axis:\n            shape1[ax] = -1\n            shape2[ax] = -1\n\n    if allow_none:\n        for i in range(len(shape1)):\n            if shape1[i] is None:\n                shape1[i] = shape2[i]\n            if shape2[i] is None:\n                shape2[i] = shape1[i]\n\n    return shape1 == shape2\n\n\nclass Absolute(Operation):\n    def call(self, x):\n        return backend.numpy.absolute(x)\n\n    def compute_output_spec(self, x):\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.absolute\", \"keras.ops.numpy.absolute\"])\ndef absolute(x):\n    \"\"\"Compute the absolute value element-wise.\n\n    `keras.ops.abs` is a shorthand for this function.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        An array containing the absolute 
value of each element in `x`.\n\n    Example:\n\n    >>> x = keras.ops.convert_to_tensor([-1.2, 1.2])\n    >>> keras.ops.absolute(x)\n    array([1.2, 1.2], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Absolute().symbolic_call(x)\n    return backend.numpy.absolute(x)\n\n\nclass Abs(Absolute):\n    pass\n\n\n@keras_export([\"keras.ops.abs\", \"keras.ops.numpy.abs\"])\ndef abs(x):\n    \"\"\"Shorthand for `keras.ops.absolute`.\"\"\"\n    return absolute(x)\n\n\nclass Add(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.add(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        output_dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        x1_sparse = getattr(x1, \"sparse\", False)\n        x2_sparse = getattr(x2, \"sparse\", False)\n        output_sparse = x1_sparse and x2_sparse\n        return KerasTensor(\n            output_shape, dtype=output_dtype, sparse=output_sparse\n        )\n\n\n@keras_export([\"keras.ops.add\", \"keras.ops.numpy.add\"])\ndef add(x1, x2):\n    \"\"\"Add arguments element-wise.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        The tensor containing the element-wise sum of `x1` and `x2`.\n\n    Examples:\n    >>> x1 = keras.ops.convert_to_tensor([1, 4])\n    >>> x2 = keras.ops.convert_to_tensor([5, 6])\n    >>> keras.ops.add(x1, x2)\n    array([6, 10], dtype=int32)\n\n    `keras.ops.add` also broadcasts shapes:\n    >>> x1 = keras.ops.convert_to_tensor(\n    ...     [[5, 4],\n    ...      [5, 6]]\n    ... 
)\n    >>> x2 = keras.ops.convert_to_tensor([5, 6])\n    >>> keras.ops.add(x1, x2)\n    array([[10 10]\n           [10 12]], shape=(2, 2), dtype=int32)\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Add().symbolic_call(x1, x2)\n    return backend.numpy.add(x1, x2)\n\n\nclass All(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            self.axis = [axis]\n        else:\n            self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.all(\n            x,\n            axis=self.axis,\n            keepdims=self.keepdims,\n        )\n\n    def compute_output_spec(self, x):\n        return KerasTensor(\n            reduce_shape(\n                x.shape,\n                axis=self.axis,\n                keepdims=self.keepdims,\n            ),\n            dtype=\"bool\",\n        )\n\n\n@keras_export([\"keras.ops.all\", \"keras.ops.numpy.all\"])\ndef all(x, axis=None, keepdims=False):\n    \"\"\"Test whether all array elements along a given axis evaluate to `True`.\n\n    Args:\n        x: Input tensor.\n        axis: An integer or tuple of integers that represent the axis along\n            which a logical AND reduction is performed. The default\n            (`axis=None`) is to perform a logical AND over all the dimensions\n            of the input array. `axis` may be negative, in which case it counts\n            from the last to the first axis.\n        keepdims: If `True`, axes which are reduced are left in the result as\n            dimensions with size one. With this option, the result will\n            broadcast correctly against the input array. 
Defaults to `False`.\n\n    Returns:\n        The tensor containing the logical AND reduction over the `axis`.\n\n    Examples:\n    >>> x = keras.ops.convert_to_tensor([True, False])\n    >>> keras.ops.all(x)\n    array(False, shape=(), dtype=bool)\n\n    >>> x = keras.ops.convert_to_tensor([[True, False], [True, True]])\n    >>> keras.ops.all(x, axis=0)\n    array([ True False], shape=(2,), dtype=bool)\n\n    `keepdims=True` outputs a tensor with dimensions reduced to one.\n    >>> x = keras.ops.convert_to_tensor([[True, False], [True, True]])\n    >>> keras.ops.all(x, keepdims=True)\n    array([[False]], shape=(1, 1), dtype=bool)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return All(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.all(x, axis=axis, keepdims=keepdims)\n\n\nclass AllClose(Operation):\n    def __init__(self, rtol=1e-05, atol=1e-08, equal_nan=False, *, name=None):\n        super().__init__(name=name)\n        self.rtol = rtol\n        self.atol = atol\n        self.equal_nan = equal_nan\n\n    def call(self, x1, x2):\n        return backend.numpy.allclose(\n            x1,\n            x2,\n            rtol=self.rtol,\n            atol=self.atol,\n            equal_nan=self.equal_nan,\n        )\n\n    def compute_output_spec(self, x1, x2):\n        return KerasTensor([], dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.allclose\", \"keras.ops.numpy.allclose\"])\ndef allclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):\n    \"\"\"Returns True if two arrays are element-wise equal within a tolerance.\n\n    The tolerance values are positive, typically very small numbers.  
The\n    relative difference (`rtol * abs(x2)`) and the absolute difference\n    `atol` are added together to compare against the absolute difference\n    between `x1` and `x2`.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n        rtol: The relative tolerance parameter.\n        atol: The absolute tolerance parameter.\n        equal_nan: Whether to compare NaNs as equal. If `True`, NaNs in\n            `x1` will be considered equal to NaNs in `x2` in the output array.\n\n    Returns:\n        True if the two arrays are equal within the given tolerance;\n        False otherwise.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return AllClose(\n            rtol=rtol, atol=atol, equal_nan=equal_nan\n        ).symbolic_call(x1, x2)\n    return backend.numpy.allclose(\n        x1, x2, rtol=rtol, atol=atol, equal_nan=equal_nan\n    )\n\n\nclass Angle(Operation):\n    def call(self, x):\n        return backend.numpy.angle(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.angle\", \"keras.ops.numpy.angle\"])\ndef angle(x):\n    \"\"\"Element-wise angle of a complex tensor.\n\n    Args:\n        x: Input tensor. Can be real or complex.\n\n    Returns:\n        Output tensor of same shape as `x`, 
containing the angle of each element\n        (in radians).\n\n    Example:\n    >>> x = keras.ops.convert_to_tensor([[1 + 3j, 2 - 5j], [4 - 3j, 3 + 2j]])\n    >>> keras.ops.angle(x)\n    array([[ 1.2490457, -1.19029  ],\n       [-0.6435011,  0.5880026]], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Angle().symbolic_call(x)\n    return backend.numpy.angle(x)\n\n\nclass Any(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            self.axis = [axis]\n        else:\n            self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.any(\n            x,\n            axis=self.axis,\n            keepdims=self.keepdims,\n        )\n\n    def compute_output_spec(self, x):\n        return KerasTensor(\n            reduce_shape(\n                x.shape,\n                axis=self.axis,\n                keepdims=self.keepdims,\n            ),\n            dtype=\"bool\",\n        )\n\n\n@keras_export([\"keras.ops.any\", \"keras.ops.numpy.any\"])\ndef any(x, axis=None, keepdims=False):\n    \"\"\"Test whether any array element along a given axis evaluates to `True`.\n\n    Args:\n        x: Input tensor.\n        axis: An integer or tuple of integers that represent the axis along\n            which a logical OR reduction is performed. The default\n            (`axis=None`) is to perform a logical OR over all the dimensions\n            of the input array. `axis` may be negative, in which case it counts\n            from the last to the first axis.\n        keepdims: If `True`, axes which are reduced are left in the result as\n            dimensions with size one. With this option, the result will\n            broadcast correctly against the input array. 
Defaults to `False`.\n\n    Returns:\n        The tensor containing the logical OR reduction over the `axis`.\n\n    Examples:\n    >>> x = keras.ops.convert_to_tensor([True, False])\n    >>> keras.ops.any(x)\n    array(True, shape=(), dtype=bool)\n\n    >>> x = keras.ops.convert_to_tensor([[True, False], [True, True]])\n    >>> keras.ops.any(x, axis=0)\n    array([ True  True], shape=(2,), dtype=bool)\n\n    `keepdims=True` outputs a tensor with dimensions reduced to one.\n    >>> x = keras.ops.convert_to_tensor([[True, False], [True, True]])\n    >>> keras.ops.any(x, keepdims=True)\n    array([[ True]], shape=(1, 1), dtype=bool)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Any(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.any(x, axis=axis, keepdims=keepdims)\n\n\nclass Amax(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            axis = [axis]\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.amax(\n            x,\n            axis=self.axis,\n            keepdims=self.keepdims,\n        )\n\n    def compute_output_spec(self, x):\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=x.dtype,\n        )\n\n\n@keras_export([\"keras.ops.amax\", \"keras.ops.numpy.amax\"])\ndef amax(x, axis=None, keepdims=False):\n    \"\"\"Returns the maximum of an array or maximum value along an axis.\n\n    Args:\n        x: Input tensor.\n        axis: Axis along which to compute the maximum.\n            By default (`axis=None`), find the maximum value in all the\n            dimensions of the input array.\n        keepdims: If `True`, axes which are reduced are left in the result as\n            dimensions that are broadcast to the size of the original\n            input tensor. 
Defaults to `False`.\n\n    Returns:\n        An array with the maximum value. If `axis=None`, the result is a scalar\n        value representing the maximum element in the entire array. If `axis` is\n        given, the result is an array with the maximum values along\n        the specified axis.\n\n    Examples:\n    >>> x = keras.ops.convert_to_tensor([[1, 3, 5], [2, 3, 6]])\n    >>> keras.ops.amax(x)\n    array(6, dtype=int32)\n\n    >>> x = keras.ops.convert_to_tensor([[1, 6, 8], [1, 5, 2]])\n    >>> keras.ops.amax(x, axis=0)\n    array([1, 6, 8], dtype=int32)\n\n    >>> x = keras.ops.convert_to_tensor([[1, 6, 8], [1, 5, 2]])\n    >>> keras.ops.amax(x, axis=1, keepdims=True)\n    array([[8], [5]], dtype=int32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Amax(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.amax(x, axis=axis, keepdims=keepdims)\n\n\nclass Amin(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            axis = [axis]\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.amin(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=x.dtype,\n        )\n\n\n@keras_export([\"keras.ops.amin\", \"keras.ops.numpy.amin\"])\ndef amin(x, axis=None, keepdims=False):\n    \"\"\"Returns the minimum of an array or minimum value along an axis.\n\n    Args:\n        x: Input tensor.\n        axis: Axis along which to compute the minimum.\n            By default (`axis=None`), find the minimum value in all the\n            dimensions of the input array.\n        keepdims: If `True`, axes which are reduced are left in the result as\n            dimensions that are broadcast to the size of the original\n        
    input tensor. Defaults to `False`.\n\n    Returns:\n        An array with the minimum value. If `axis=None`, the result is a scalar\n        value representing the minimum element in the entire array. If `axis` is\n        given, the result is an array with the minimum values along\n        the specified axis.\n\n    Examples:\n    >>> x = keras.ops.convert_to_tensor([1, 3, 5, 2, 3, 6])\n    >>> keras.ops.amin(x)\n    array(1, dtype=int32)\n\n    >>> x = keras.ops.convert_to_tensor([[1, 6, 8], [7, 5, 3]])\n    >>> keras.ops.amin(x, axis=0)\n    array([1, 5, 3], dtype=int32)\n\n    >>> x = keras.ops.convert_to_tensor([[1, 6, 8], [7, 5, 3]])\n    >>> keras.ops.amin(x, axis=1, keepdims=True)\n    array([[1], [3]], dtype=int32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Amin(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.amin(x, axis=axis, keepdims=keepdims)\n\n\nclass Append(Operation):\n    def __init__(self, axis=None, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n\n    def call(self, x1, x2):\n        return backend.numpy.append(x1, x2, axis=self.axis)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = x1.shape\n        x2_shape = x2.shape\n        dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        if self.axis is None:\n            if None in x1_shape or None in x2_shape:\n                output_shape = [None]\n            else:\n                output_shape = [int(np.prod(x1_shape) + np.prod(x2_shape))]\n            return KerasTensor(output_shape, dtype=dtype)\n\n        if not shape_equal(x1_shape, x2_shape, [self.axis]):\n            raise ValueError(\n                \"`append` requires inputs to have the same shape except the \"\n                f\"`axis={self.axis}`, but received shape {x1_shape} and \"\n                f\"{x2_shape}.\"\n            )\n\n        
output_shape = list(x1_shape)\n        output_shape[self.axis] = x1_shape[self.axis] + x2_shape[self.axis]\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.append\", \"keras.ops.numpy.append\"])\ndef append(\n    x1,\n    x2,\n    axis=None,\n):\n    \"\"\"Append tensor `x2` to the end of tensor `x1`.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n        axis: Axis along which tensor `x2` is appended to tensor `x1`.\n            If `None`, both tensors are flattened before use.\n\n    Returns:\n        A tensor with the values of `x2` appended to `x1`.\n\n    Examples:\n    >>> x1 = keras.ops.convert_to_tensor([1, 2, 3])\n    >>> x2 = keras.ops.convert_to_tensor([[4, 5, 6], [7, 8, 9]])\n    >>> keras.ops.append(x1, x2)\n    array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)\n\n    When `axis` is specified, `x1` and `x2` must have compatible shapes.\n    >>> x1 = keras.ops.convert_to_tensor([[1, 2, 3], [4, 5, 6]])\n    >>> x2 = keras.ops.convert_to_tensor([[7, 8, 9]])\n    >>> keras.ops.append(x1, x2, axis=0)\n    array([[1, 2, 3],\n            [4, 5, 6],\n            [7, 8, 9]], dtype=int32)\n    >>> x3 = keras.ops.convert_to_tensor([7, 8, 9])\n    >>> keras.ops.append(x1, x3, axis=0)\n    Traceback (most recent call last):\n        ...\n    TypeError: Cannot concatenate arrays with different numbers of\n    dimensions: got (2, 3), (3,).\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Append(axis=axis).symbolic_call(x1, x2)\n    return backend.numpy.append(x1, x2, axis=axis)\n\n\nclass Arange(Operation):\n    def __init__(self, dtype=None, *, name=None):\n        super().__init__(name=name)\n        self.dtype = dtype\n\n    def call(self, start, stop=None, step=None):\n        return backend.numpy.arange(start, stop, step=step, dtype=self.dtype)\n\n    def compute_output_spec(self, start, stop=None, step=None):\n        if stop is None:\n            start, stop = 0, start\n   
     if step is None:\n            step = 1\n        output_shape = [int(np.ceil((stop - start) / step))]\n        dtype = self.dtype\n        if dtype is None:\n            dtypes_to_resolve = [getattr(start, \"dtype\", type(start))]\n            if stop is not None:\n                dtypes_to_resolve.append(getattr(stop, \"dtype\", type(stop)))\n            if step is not None:\n                dtypes_to_resolve.append(getattr(step, \"dtype\", type(step)))\n            dtype = dtypes.result_type(*dtypes_to_resolve)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.arange\", \"keras.ops.numpy.arange\"])\ndef arange(start, stop=None, step=None, dtype=None):\n    \"\"\"Return evenly spaced values within a given interval.\n\n    `arange` can be called with a varying number of positional arguments:\n    * `arange(stop)`: Values are generated within the half-open interval\n        `[0, stop)` (in other words, the interval including start but excluding\n        stop).\n    * `arange(start, stop)`: Values are generated within the half-open interval\n        `[start, stop)`.\n    * `arange(start, stop, step)`: Values are generated within the half-open\n        interval `[start, stop)`, with spacing between values given by step.\n\n    Args:\n        start: Integer or real, representing the start of the interval. The\n            interval includes this value.\n        stop: Integer or real, representing the end of the interval. The\n            interval does not include this value, except in some cases where\n            `step` is not an integer and floating point round-off affects the\n            length of `out`. Defaults to `None`.\n        step: Integer or real, representing the spacing between values. For\n            any output `out`, this is the distance between two adjacent values,\n            `out[i+1] - out[i]`. The default step size is 1. 
If `step` is\n            specified as a positional argument, `start` must also be given.\n        dtype: The type of the output array. If `dtype` is not given, infer the\n            data type from the other input arguments.\n\n    Returns:\n        Tensor of evenly spaced values.\n        For floating point arguments, the length of the result is\n        `ceil((stop - start)/step)`. Because of floating point overflow, this\n        rule may result in the last element of `out` being greater than\n        `stop`.\n\n    Examples:\n    >>> keras.ops.arange(3)\n    array([0, 1, 2], dtype=int32)\n\n    >>> keras.ops.arange(3.0)\n    array([0., 1., 2.], dtype=float32)\n\n    >>> keras.ops.arange(3, 7)\n    array([3, 4, 5, 6], dtype=int32)\n\n    >>> keras.ops.arange(3, 7, 2)\n    array([3, 5], dtype=int32)\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((start, stop, step)):\n        return Arange(dtype=dtype).symbolic_call(start, stop, step=step)\n    return backend.numpy.arange(start, stop, step=step, dtype=dtype)\n\n\nclass Arccos(Operation):\n    def call(self, x):\n        return backend.numpy.arccos(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.arccos\", \"keras.ops.numpy.arccos\"])\ndef arccos(x):\n    \"\"\"Trigonometric inverse cosine, element-wise.\n\n    The inverse of `cos` so that, if `y = cos(x)`, then `x = arccos(y)`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Tensor of the angle of the ray intersecting the unit circle at the given\n        x-coordinate in radians `[0, pi]`.\n\n    Example:\n    >>> x = keras.ops.convert_to_tensor([1, -1])\n    >>> keras.ops.arccos(x)\n    array([0.0, 
3.1415927], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Arccos().symbolic_call(x)\n    return backend.numpy.arccos(x)\n\n\nclass Arccosh(Operation):\n    def call(self, x):\n        return backend.numpy.arccosh(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.arccosh\", \"keras.ops.numpy.arccosh\"])\ndef arccosh(x):\n    \"\"\"Inverse hyperbolic cosine, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor of the same shape as `x`.\n\n    Example:\n    >>> x = keras.ops.convert_to_tensor([10, 100])\n    >>> keras.ops.arccosh(x)\n    array([2.993223, 5.298292], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Arccosh().symbolic_call(x)\n    return backend.numpy.arccosh(x)\n\n\nclass Arcsin(Operation):\n    def call(self, x):\n        return backend.numpy.arcsin(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.arcsin\", \"keras.ops.numpy.arcsin\"])\ndef arcsin(x):\n    \"\"\"Inverse sine, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Tensor of the inverse sine of each element in `x`, in radians and in\n        the closed interval `[-pi/2, pi/2]`.\n\n    Example:\n    >>> x = keras.ops.convert_to_tensor([1, -1, 0])\n    >>> keras.ops.arcsin(x)\n    array([ 1.5707964, -1.5707964,  0.], 
dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Arcsin().symbolic_call(x)\n    return backend.numpy.arcsin(x)\n\n\nclass Arcsinh(Operation):\n    def call(self, x):\n        return backend.numpy.arcsinh(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.arcsinh\", \"keras.ops.numpy.arcsinh\"])\ndef arcsinh(x):\n    \"\"\"Inverse hyperbolic sine, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor of the same shape as `x`.\n\n    Example:\n    >>> x = keras.ops.convert_to_tensor([1, -1, 0])\n    >>> keras.ops.arcsinh(x)\n    array([0.88137364, -0.88137364, 0.0], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Arcsinh().symbolic_call(x)\n    return backend.numpy.arcsinh(x)\n\n\nclass Arctan(Operation):\n    def call(self, x):\n        return backend.numpy.arctan(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.arctan\", \"keras.ops.numpy.arctan\"])\ndef arctan(x):\n    \"\"\"Trigonometric inverse tangent, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Tensor of the inverse tangent of each element in `x`, in the interval\n        `[-pi/2, pi/2]`.\n\n    Example:\n    >>> x = keras.ops.convert_to_tensor([0, 1])\n    >>> 
keras.ops.arctan(x)\n    array([0., 0.7853982], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Arctan().symbolic_call(x)\n    return backend.numpy.arctan(x)\n\n\nclass Arctan2(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.arctan2(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        outputs_shape = broadcast_shapes(x1_shape, x2_shape)\n        x1_dtype = backend.standardize_dtype(\n            getattr(x1, \"dtype\", backend.floatx())\n        )\n        x2_dtype = backend.standardize_dtype(\n            getattr(x2, \"dtype\", backend.floatx())\n        )\n        dtype = dtypes.result_type(x1_dtype, x2_dtype, float)\n        return KerasTensor(outputs_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.arctan2\", \"keras.ops.numpy.arctan2\"])\ndef arctan2(x1, x2):\n    \"\"\"Element-wise arc tangent of `x1/x2` choosing the quadrant correctly.\n\n    The quadrant (i.e., branch) is chosen so that `arctan2(x1, x2)` is the\n    signed angle in radians between the ray ending at the origin and passing\n    through the point `(1, 0)`, and the ray ending at the origin and passing\n    through the point `(x2, x1)`. (Note the role reversal: the \"y-coordinate\"\n    is the first function parameter, the \"x-coordinate\" is the second.) 
By IEEE\n    convention, this function is defined for `x2 = +/-0` and for either or both\n    of `x1` and `x2` `= +/-inf`.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        Tensor of angles in radians, in the range `[-pi, pi]`.\n\n    Examples:\n    Consider four points in different quadrants:\n    >>> x = keras.ops.convert_to_tensor([-1, +1, +1, -1])\n    >>> y = keras.ops.convert_to_tensor([-1, -1, +1, +1])\n    >>> keras.ops.arctan2(y, x) * 180 / numpy.pi\n    array([-135., -45., 45., 135.], dtype=float32)\n\n    Note the order of the parameters. `arctan2` is defined also when x2=0 and\n    at several other points, obtaining values in the range `[-pi, pi]`:\n    >>> keras.ops.arctan2(\n    ...     keras.ops.array([1., -1.]),\n    ...     keras.ops.array([0., 0.]),\n    ... )\n    array([ 1.5707964, -1.5707964], dtype=float32)\n    >>> keras.ops.arctan2(\n    ...     keras.ops.array([0., 0., numpy.inf]),\n    ...     keras.ops.array([+0., -0., numpy.inf]),\n    ... 
)\n    array([0., 3.1415925, 0.7853982], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Arctan2().symbolic_call(x1, x2)\n    return backend.numpy.arctan2(x1, x2)\n\n\nclass Arctanh(Operation):\n    def call(self, x):\n        return backend.numpy.arctanh(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.arctanh\", \"keras.ops.numpy.arctanh\"])\ndef arctanh(x):\n    \"\"\"Inverse hyperbolic tangent, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor of the same shape as `x`.\n\n    Example:\n    >>> x = keras.ops.convert_to_tensor([0, -0.5])\n    >>> keras.ops.arctanh(x)\n    array([ 0.        
, -0.54930615], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Arctanh().symbolic_call(x)\n    return backend.numpy.arctanh(x)\n\n\nclass Argmax(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.argmax(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        if self.keepdims:\n            return KerasTensor(x.shape, dtype=\"int32\")\n        if self.axis is None:\n            return KerasTensor([], dtype=\"int32\")\n        return KerasTensor(\n            reduce_shape(x.shape, axis=[self.axis]), dtype=\"int32\"\n        )\n\n\n@keras_export([\"keras.ops.argmax\", \"keras.ops.numpy.argmax\"])\ndef argmax(x, axis=None, keepdims=False):\n    \"\"\"Returns the indices of the maximum values along an axis.\n\n    Args:\n        x: Input tensor.\n        axis: By default, the index is into the flattened tensor, otherwise\n            along the specified axis.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one. Defaults to `False`.\n\n    Returns:\n        Tensor of indices. 
It has the same shape as `x`, with the dimension\n        along `axis` removed.\n\n    Example:\n    >>> x = keras.ops.arange(6).reshape(2, 3) + 10\n    >>> x\n    array([[10, 11, 12],\n           [13, 14, 15]], dtype=int32)\n    >>> keras.ops.argmax(x)\n    array(5, dtype=int32)\n    >>> keras.ops.argmax(x, axis=0)\n    array([1, 1, 1], dtype=int32)\n    >>> keras.ops.argmax(x, axis=1)\n    array([2, 2], dtype=int32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Argmax(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.argmax(x, axis=axis, keepdims=keepdims)\n\n\nclass Argmin(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.argmin(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        if self.keepdims:\n            return KerasTensor(x.shape, dtype=\"int32\")\n        if self.axis is None:\n            return KerasTensor([], dtype=\"int32\")\n        return KerasTensor(\n            reduce_shape(x.shape, axis=[self.axis]), dtype=\"int32\"\n        )\n\n\n@keras_export([\"keras.ops.argmin\", \"keras.ops.numpy.argmin\"])\ndef argmin(x, axis=None, keepdims=False):\n    \"\"\"Returns the indices of the minimum values along an axis.\n\n    Args:\n        x: Input tensor.\n        axis: By default, the index is into the flattened tensor, otherwise\n            along the specified axis.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one. Defaults to `False`.\n\n    Returns:\n        Tensor of indices. 
It has the same shape as `x`, with the dimension\n        along `axis` removed.\n\n    Example:\n    >>> x = keras.ops.arange(6).reshape(2, 3) + 10\n    >>> x\n    array([[10, 11, 12],\n           [13, 14, 15]], dtype=int32)\n    >>> keras.ops.argmin(x)\n    array(0, dtype=int32)\n    >>> keras.ops.argmin(x, axis=0)\n    array([0, 0, 0], dtype=int32)\n    >>> keras.ops.argmin(x, axis=1)\n    array([0, 0], dtype=int32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Argmin(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.argmin(x, axis=axis, keepdims=keepdims)\n\n\nclass Argsort(Operation):\n    def __init__(self, axis=-1, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n\n    def call(self, x):\n        return backend.numpy.argsort(x, axis=self.axis)\n\n    def compute_output_spec(self, x):\n        if self.axis is None:\n            return KerasTensor([int(np.prod(x.shape))], dtype=\"int32\")\n        return KerasTensor(x.shape, dtype=\"int32\")\n\n\n@keras_export([\"keras.ops.argsort\", \"keras.ops.numpy.argsort\"])\ndef argsort(x, axis=-1):\n    \"\"\"Returns the indices that would sort a tensor.\n\n    Args:\n        x: Input tensor.\n        axis: Axis along which to sort. Defaults to `-1` (the last axis). 
If\n            `None`, the flattened tensor is used.\n\n    Returns:\n        Tensor of indices that sort `x` along the specified `axis`.\n\n    Examples:\n    One-dimensional array:\n    >>> x = keras.ops.array([3, 1, 2])\n    >>> keras.ops.argsort(x)\n    array([1, 2, 0], dtype=int32)\n\n    Two-dimensional array:\n    >>> x = keras.ops.array([[0, 3], [3, 2], [4, 5]])\n    >>> x\n    array([[0, 3],\n           [3, 2],\n           [4, 5]], dtype=int32)\n    >>> keras.ops.argsort(x, axis=0)\n    array([[0, 1],\n           [1, 0],\n           [2, 2]], dtype=int32)\n    >>> keras.ops.argsort(x, axis=1)\n    array([[0, 1],\n           [1, 0],\n           [0, 1]], dtype=int32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Argsort(axis=axis).symbolic_call(x)\n    return backend.numpy.argsort(x, axis=axis)\n\n\nclass Array(Operation):\n    def __init__(self, dtype=None, *, name=None):\n        super().__init__(name=name)\n        self.dtype = dtype\n\n    def call(self, x):\n        return backend.numpy.array(x, dtype=self.dtype)\n\n    def compute_output_spec(self, x):\n        dtype = (\n            backend.standardize_dtype(x.dtype)\n            if self.dtype is None\n            else self.dtype\n        )\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.array\", \"keras.ops.numpy.array\"])\ndef array(x, dtype=None):\n    \"\"\"Create a tensor.\n\n    Args:\n        x: Input tensor.\n        dtype: The desired data-type for the tensor.\n\n    Returns:\n        A tensor.\n\n    Examples:\n    >>> keras.ops.array([1, 2, 3])\n    array([1, 2, 3], dtype=int32)\n\n    >>> keras.ops.array([1, 2, 3], dtype=\"float32\")\n    array([1., 2., 3.], dtype=float32)\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x,)):\n        return Array(dtype=dtype).symbolic_call(x)\n    return backend.numpy.array(x, dtype=dtype)\n\n\nclass View(Operation):\n    def 
__init__(self, dtype=None, *, name=None):\n        super().__init__(name=name)\n        self.dtype = dtype\n\n    def call(self, x):\n        return backend.numpy.view(x, dtype=self.dtype)\n\n    def compute_output_spec(self, x):\n        old_dtype = backend.standardize_dtype(x.dtype)\n        new_dtype = backend.standardize_dtype(\n            self.dtype if self.dtype else x.dtype\n        )\n\n        old_itemsize = np.dtype(old_dtype).itemsize\n        new_itemsize = np.dtype(new_dtype).itemsize\n\n        if old_itemsize == new_itemsize:\n            return KerasTensor(x.shape, dtype=new_dtype)\n\n        if not x.shape:\n            raise ValueError(\n                \"Cannot view a scalar as a different dtype if item sizes \"\n                \"are different.\"\n            )\n\n        output_shape = list(x.shape)\n        if output_shape[-1] is not None:\n            if (output_shape[-1] * old_itemsize) % new_itemsize != 0:\n                raise ValueError(\n                    f\"Cannot view array of shape {x.shape} and dtype {x.dtype} \"\n                    f\"as dtype {new_dtype} because the total number of bytes \"\n                    \"is not divisible by the new itemsize.\"\n                )\n            output_shape[-1] = output_shape[-1] * old_itemsize // new_itemsize\n        return KerasTensor(tuple(output_shape), dtype=new_dtype)\n\n\n@keras_export([\"keras.ops.view\", \"keras.ops.numpy.view\"])\ndef view(x, dtype=None):\n    \"\"\"Create a new bitwise view of the same data with the specified dtype.\n\n    Args:\n        x: Input tensor.\n        dtype: Data-type descriptor of the returned view,\n            e.g., float32 or int16.\n\n    Returns:\n        View of a tensor with data type dtype.\n\n    Examples:\n    >>> x = keras.ops.array([1, 2, 3])\n    >>> x\n    array([1, 2, 3], dtype=int32)\n    >>> keras.ops.view(x, dtype=\"float32\")\n    array([1.0e-45, 3.0e-45, 4.0e-45], dtype=float32)\n    \"\"\"\n    dtype = None if dtype is None 
else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x,)):\n        return View(dtype=dtype).symbolic_call(x)\n    return backend.numpy.view(x, dtype=dtype)\n\n\nclass Average(Operation):\n    def __init__(self, axis=None, *, name=None):\n        super().__init__(name=name)\n        # np.average() does not support axis as tuple as declared by the\n        # docstring, it only supports int or None.\n        self.axis = axis\n\n    def call(self, x, weights=None):\n        return backend.numpy.average(x, weights=weights, axis=self.axis)\n\n    def compute_output_spec(self, x, weights=None):\n        dtypes_to_resolve = [getattr(x, \"dtype\", type(x)), float]\n        if weights is not None:\n            shape_match = shape_equal(x.shape, weights.shape, allow_none=True)\n            if self.axis is not None:\n                shape_match_on_axis = shape_equal(\n                    [x.shape[self.axis]], weights.shape, allow_none=True\n                )\n            dtypes_to_resolve.append(getattr(weights, \"dtype\", type(weights)))\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n        if self.axis is None:\n            if weights is None or shape_match:\n                return KerasTensor([], dtype=dtype)\n            else:\n                raise ValueError(\n                    \"`weights` must have the same shape as `x` when \"\n                    f\"`axis=None`, but received `weights.shape={weights.shape}`\"\n                    f\" and `x.shape={x.shape}`.\"\n                )\n\n        if weights is None or shape_match_on_axis or shape_match:\n            return KerasTensor(\n                reduce_shape(x.shape, axis=[self.axis]), dtype=dtype\n            )\n        else:\n            # `weights` can either be a 1D array of length `x.shape[axis]` or\n            # of the same shape as `x`.\n            raise ValueError(\n                \"`weights` must have the same size as `x` at \"\n                f\"`axis={self.axis}` but received 
\"\n                f\"`weights.shape={weights.shape}` while x.shape at \"\n                f\"`{self.axis}` is `{x.shape[self.axis]}`.\"\n            )\n\n\n@keras_export([\"keras.ops.average\", \"keras.ops.numpy.average\"])\ndef average(x, axis=None, weights=None):\n    \"\"\"Compute the weighted average along the specified axis.\n\n    Args:\n        x: Input tensor.\n        axis: Integer along which to average `x`. The default, `axis=None`,\n            will average over all of the elements of the input tensor. If\n            `axis` is negative, it counts from the last to the first axis.\n        weights: Tensor of weights associated with the values in `x`. Each\n            value in `x` contributes to the average according to its\n            associated weight. The weights array can either be 1-D (in which\n            case its length must be the size of `x` along the given axis) or\n            of the same shape as `x`. If `weights=None` (default), then all\n            data in `x` are assumed to have a weight equal to one.\n\n            The 1-D calculation is: `avg = sum(x * weights) / sum(weights)`.\n            The only constraint on `weights` is that `sum(weights)` must not\n            be 0.\n\n    Returns:\n        The average along the specified axis.\n\n    Examples:\n    >>> data = keras.ops.arange(1, 5)\n    >>> data\n    array([1, 2, 3, 4], dtype=int32)\n    >>> keras.ops.average(data)\n    array(2.5, dtype=float32)\n    >>> keras.ops.average(\n    ...     keras.ops.arange(1, 11),\n    ...     weights=keras.ops.arange(10, 0, -1)\n    ... )\n    array(4., dtype=float32)\n\n    >>> data = keras.ops.arange(6).reshape((3, 2))\n    >>> data\n    array([[0, 1],\n           [2, 3],\n           [4, 5]], dtype=int32)\n    >>> keras.ops.average(\n    ...     data,\n    ...     axis=1,\n    ...     weights=keras.ops.array([1./4, 3./4])\n    ... )\n    array([0.75, 2.75, 4.75], dtype=float32)\n    >>> keras.ops.average(\n    ...     data,\n    ...     
weights=keras.ops.array([1./4, 3./4])\n    ... )\n    Traceback (most recent call last):\n        ...\n    ValueError: Axis must be specified when shapes of a and weights differ.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Average(axis=axis).symbolic_call(x, weights=weights)\n    return backend.numpy.average(x, axis=axis, weights=weights)\n\n\nclass Bartlett(Operation):\n    def call(self, x):\n        return backend.numpy.bartlett(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=backend.floatx())\n\n\n@keras_export([\"keras.ops.bartlett\", \"keras.ops.numpy.bartlett\"])\ndef bartlett(x):\n    \"\"\"Bartlett window function.\n\n    The Bartlett window is a triangular window that rises then falls linearly.\n\n    Args:\n        x: Scalar or 1D Tensor. The window length.\n\n    Returns:\n        A 1D tensor containing the Bartlett window values.\n\n    Example:\n    >>> x = keras.ops.convert_to_tensor(5)\n    >>> keras.ops.bartlett(x)\n    array([0. , 0.5, 1. , 0.5, 0. ], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Bartlett().symbolic_call(x)\n    return backend.numpy.bartlett(x)\n\n\nclass Hamming(Operation):\n    def call(self, x):\n        return backend.numpy.hamming(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=backend.floatx())\n\n\n@keras_export([\"keras.ops.hamming\", \"keras.ops.numpy.hamming\"])\ndef hamming(x):\n    \"\"\"Hamming window function.\n\n    The Hamming window is defined as:\n    `w[n] = 0.54 - 0.46 * cos(2 * pi * n / (N - 1))` for `0 <= n <= N - 1`.\n\n    Args:\n        x: Scalar or 1D Tensor. The window length.\n\n    Returns:\n        A 1D tensor containing the Hamming window values.\n\n    Example:\n    >>> x = keras.ops.convert_to_tensor(5)\n    >>> keras.ops.hamming(x)\n    array([0.08, 0.54, 1.  
, 0.54, 0.08], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Hamming().symbolic_call(x)\n    return backend.numpy.hamming(x)\n\n\nclass Hanning(Operation):\n    def call(self, x):\n        return backend.numpy.hanning(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=backend.floatx())\n\n\n@keras_export([\"keras.ops.hanning\", \"keras.ops.numpy.hanning\"])\ndef hanning(x):\n    \"\"\"Hanning window function.\n\n    The Hanning window is defined as:\n    `w[n] = 0.5 - 0.5 * cos(2 * pi * n / (N - 1))` for `0 <= n <= N - 1`.\n\n    Args:\n        x: Scalar or 1D Tensor. The window length.\n\n    Returns:\n        A 1D tensor containing the Hanning window values.\n\n    Example:\n    >>> x = keras.ops.convert_to_tensor(5)\n    >>> keras.ops.hanning(x)\n    array([0. , 0.5, 1. , 0.5, 0. ], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Hanning().symbolic_call(x)\n    return backend.numpy.hanning(x)\n\n\nclass Heaviside(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.heaviside(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        dtype = dtypes.result_type(x1.dtype, x2.dtype)\n        if dtype in [\"int8\", \"int16\", \"int32\", \"uint8\", \"uint16\", \"uint32\"]:\n            dtype = backend.floatx()\n        elif dtype == \"int64\":\n            dtype = \"float64\"\n        return KerasTensor(broadcast_shapes(x1.shape, x2.shape), dtype=dtype)\n\n\n@keras_export([\"keras.ops.heaviside\", \"keras.ops.numpy.heaviside\"])\ndef heaviside(x1, x2):\n    \"\"\"Heaviside step function.\n\n    The Heaviside step function is defined as:\n    `heaviside(x1, x2) = 0 if x1 < 0, 1 if x1 > 0, x2 if x1 == 0`\n\n    Args:\n        x1: A tensor input.\n        x2: A scalar or tensor, the value to return when `x1 == 0`.\n\n    Returns:\n        A tensor with a shape determined by broadcasting `x1` and `x2`.\n\n    Example:\n    >>> x1 = 
keras.ops.convert_to_tensor([-2.0, 0.0, 3.0])\n    >>> x2 = 0.5\n    >>> keras.ops.heaviside(x1, x2)\n    array([0. , 0.5, 1. ], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Heaviside().symbolic_call(x1, x2)\n    return backend.numpy.heaviside(x1, x2)\n\n\nclass Kaiser(Operation):\n    def __init__(self, beta, *, name=None):\n        super().__init__(name=name)\n        self.beta = beta\n\n    def call(self, x):\n        return backend.numpy.kaiser(x, self.beta)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=backend.floatx())\n\n\n@keras_export([\"keras.ops.kaiser\", \"keras.ops.numpy.kaiser\"])\ndef kaiser(x, beta):\n    \"\"\"Kaiser window function.\n\n    The Kaiser window is defined as:\n    `w[n] = I0(beta * sqrt(1 - (2n / (N - 1) - 1)^2)) / I0(beta)`\n    where I0 is the modified zeroth-order Bessel function of the first kind.\n\n    Args:\n        x: Scalar or 1D Tensor. The window length.\n        beta: Float. 
Shape parameter for the Kaiser window.\n\n    Returns:\n        A 1D tensor containing the Kaiser window values.\n\n    Example:\n    >>> x = keras.ops.convert_to_tensor(5)\n    >>> keras.ops.kaiser(x, beta=14.0)\n    array([7.7268669e-06, 1.6493219e-01, 1.0000000e+00, 1.6493219e-01,\n       7.7268669e-06], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Kaiser(beta).symbolic_call(x)\n    return backend.numpy.kaiser(x, beta)\n\n\nclass Bincount(Operation):\n    def __init__(self, weights=None, minlength=0, sparse=False, *, name=None):\n        super().__init__(name=name)\n        self.weights = weights\n        self.minlength = minlength\n        self.sparse = sparse\n\n    def call(self, x):\n        return backend.numpy.bincount(\n            x,\n            weights=self.weights,\n            minlength=self.minlength,\n            sparse=self.sparse,\n        )\n\n    def compute_output_spec(self, x):\n        dtypes_to_resolve = [x.dtype]\n        if self.weights is not None:\n            weights = backend.convert_to_tensor(self.weights)\n            dtypes_to_resolve.append(weights.dtype)\n            dtype = dtypes.result_type(*dtypes_to_resolve)\n        else:\n            dtype = \"int32\"\n        x_sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(\n            list(x.shape[:-1]) + [None],\n            dtype=dtype,\n            sparse=x_sparse or self.sparse,\n        )\n\n\n@keras_export([\"keras.ops.bincount\", \"keras.ops.numpy.bincount\"])\ndef bincount(x, weights=None, minlength=0, sparse=False):\n    \"\"\"Count the number of occurrences of each value in a tensor of integers.\n\n    Args:\n        x: Input tensor.\n            It must be of dimension 1, and it must only contain non-negative\n            integer(s).\n        weights: Weight tensor.\n            It must have the same length as `x`. The default value is `None`.\n            If specified, `x` is weighted by it, i.e. 
if `n = x[i]`,\n            `out[n] += weights[i]` instead of the default behavior `out[n] += 1`.\n        minlength: An integer.\n            The default value is 0. If specified, there will be at least\n            this number of bins in the output tensor. If greater than\n            `max(x) + 1`, each value of the output at an index higher than\n            `max(x)` is set to 0.\n        sparse: Whether to return a sparse tensor; for backends that support\n            sparse tensors.\n\n    Returns:\n        1D tensor where each element gives the number of occurrence(s) of its\n        index value in `x`. Its length is the maximum between `max(x) + 1` and\n        `minlength`.\n\n    Examples:\n    >>> x = keras.ops.array([1, 2, 2, 3], dtype=\"uint8\")\n    >>> keras.ops.bincount(x)\n    array([0, 1, 2, 1], dtype=int32)\n    >>> weights = x / 2\n    >>> weights\n    array([0.5, 1., 1., 1.5], dtype=float64)\n    >>> keras.ops.bincount(x, weights=weights)\n    array([0., 0.5, 2., 1.5], dtype=float64)\n    >>> minlength = (keras.ops.max(x).numpy() + 1) + 2 # 6\n    >>> keras.ops.bincount(x, minlength=minlength)\n    array([0, 1, 2, 1, 0, 0], dtype=int32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Bincount(\n            weights=weights, minlength=minlength, sparse=sparse\n        ).symbolic_call(x)\n    return backend.numpy.bincount(\n        x, weights=weights, minlength=minlength, sparse=sparse\n    )\n\n\nclass BitwiseAnd(Operation):\n    def call(self, x, y):\n        return backend.numpy.bitwise_and(x, y)\n\n    def compute_output_spec(self, x, y):\n        dtype = dtypes.result_type(x.dtype, y.dtype)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.bitwise_and\", \"keras.ops.numpy.bitwise_and\"])\ndef bitwise_and(x, y):\n    \"\"\"Compute the bit-wise AND of two arrays element-wise.\n\n    Computes the bit-wise AND of the underlying binary representation of the\n    integers in the input arrays. 
This ufunc implements the C/Python operator\n    `&`.\n\n    Args:\n        x: Input integer tensor.\n        y: Input integer tensor.\n\n    Returns:\n        Result tensor.\n    \"\"\"\n    if any_symbolic_tensors((x, y)):\n        return BitwiseAnd().symbolic_call(x, y)\n    return backend.numpy.bitwise_and(x, y)\n\n\nclass BitwiseInvert(Operation):\n    def call(self, x):\n        return backend.numpy.bitwise_invert(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.bitwise_invert\", \"keras.ops.numpy.bitwise_invert\"])\ndef bitwise_invert(x):\n    \"\"\"Compute bit-wise inversion, or bit-wise NOT, element-wise.\n\n    Computes the bit-wise NOT of the underlying binary representation of the\n    integers in the input arrays. This ufunc implements the C/Python operator\n    `~`.\n\n    Args:\n        x: Input integer tensor.\n\n    Returns:\n        Result tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return BitwiseInvert().symbolic_call(x)\n    return backend.numpy.bitwise_invert(x)\n\n\nclass BitwiseNot(Operation):\n    def call(self, x):\n        return backend.numpy.bitwise_not(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.bitwise_not\", \"keras.ops.numpy.bitwise_not\"])\ndef bitwise_not(x):\n    \"\"\"Compute bit-wise inversion, or bit-wise NOT, element-wise.\n\n    Computes the bit-wise NOT of the underlying binary representation of the\n    integers in the input arrays. 
This ufunc implements the C/Python operator\n    `~`.\n\n    Args:\n        x: Input integer tensor.\n\n    Returns:\n        Result tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return BitwiseNot().symbolic_call(x)\n    return backend.numpy.bitwise_not(x)\n\n\nclass BitwiseOr(Operation):\n    def call(self, x, y):\n        return backend.numpy.bitwise_or(x, y)\n\n    def compute_output_spec(self, x, y):\n        dtype = dtypes.result_type(x.dtype, y.dtype)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.bitwise_or\", \"keras.ops.numpy.bitwise_or\"])\ndef bitwise_or(x, y):\n    \"\"\"Compute the bit-wise OR of two arrays element-wise.\n\n    Computes the bit-wise OR of the underlying binary representation of the\n    integers in the input arrays. This ufunc implements the C/Python operator\n    `|`.\n\n    Args:\n        x: Input integer tensor.\n        y: Input integer tensor.\n\n    Returns:\n        Result tensor.\n    \"\"\"\n    if any_symbolic_tensors((x, y)):\n        return BitwiseOr().symbolic_call(x, y)\n    return backend.numpy.bitwise_or(x, y)\n\n\nclass BitwiseXor(Operation):\n    def call(self, x, y):\n        return backend.numpy.bitwise_xor(x, y)\n\n    def compute_output_spec(self, x, y):\n        dtype = dtypes.result_type(x.dtype, y.dtype)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.bitwise_xor\", \"keras.ops.numpy.bitwise_xor\"])\ndef bitwise_xor(x, y):\n    \"\"\"Compute the bit-wise XOR of two arrays element-wise.\n\n    Computes the bit-wise XOR of the underlying binary representation of the\n    integers in the input arrays. 
This ufunc implements the C/Python operator\n    `^`.\n\n    Args:\n        x: Input integer tensor.\n        y: Input integer tensor.\n\n    Returns:\n        Result tensor.\n    \"\"\"\n    if any_symbolic_tensors((x, y)):\n        return BitwiseXor().symbolic_call(x, y)\n    return backend.numpy.bitwise_xor(x, y)\n\n\nclass BitwiseLeftShift(Operation):\n    def call(self, x, y):\n        return backend.numpy.bitwise_left_shift(x, y)\n\n    def compute_output_spec(self, x, y):\n        if isinstance(y, int):\n            dtype = x.dtype\n        else:\n            dtype = dtypes.result_type(x.dtype, y.dtype)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export(\n    [\"keras.ops.bitwise_left_shift\", \"keras.ops.numpy.bitwise_left_shift\"]\n)\ndef bitwise_left_shift(x, y):\n    \"\"\"Shift the bits of an integer to the left.\n\n    Bits are shifted to the left by appending `y` 0s at the right of `x`.\n    Since the internal representation of numbers is in binary format, this\n    operation is equivalent to multiplying `x` by `2**y`.\n\n    Args:\n        x: Input integer tensor.\n        y: Input integer tensor.\n\n    Returns:\n        Result tensor.\n    \"\"\"\n    if any_symbolic_tensors((x, y)):\n        return BitwiseLeftShift().symbolic_call(x, y)\n    return backend.numpy.bitwise_left_shift(x, y)\n\n\nclass LeftShift(Operation):\n    def call(self, x, y):\n        return backend.numpy.left_shift(x, y)\n\n    def compute_output_spec(self, x, y):\n        if isinstance(y, int):\n            dtype = x.dtype\n        else:\n            dtype = dtypes.result_type(x.dtype, y.dtype)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.left_shift\", \"keras.ops.numpy.left_shift\"])\ndef left_shift(x, y):\n    \"\"\"Shift the bits of an integer to the left.\n\n    Bits are shifted to the left by appending `y` 0s at the right of `x`.\n    Since the internal representation of numbers is in binary format, this\n    
operation is equivalent to multiplying `x` by `2**y`.\n\n    Args:\n        x: Input integer tensor.\n        y: Input integer tensor.\n\n    Returns:\n        Result tensor.\n    \"\"\"\n    if any_symbolic_tensors((x, y)):\n        return LeftShift().symbolic_call(x, y)\n    return backend.numpy.left_shift(x, y)\n\n\nclass BitwiseRightShift(Operation):\n    def call(self, x, y):\n        return backend.numpy.bitwise_right_shift(x, y)\n\n    def compute_output_spec(self, x, y):\n        if isinstance(y, int):\n            dtype = x.dtype\n        else:\n            dtype = dtypes.result_type(x.dtype, y.dtype)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export(\n    [\"keras.ops.bitwise_right_shift\", \"keras.ops.numpy.bitwise_right_shift\"]\n)\ndef bitwise_right_shift(x, y):\n    \"\"\"Shift the bits of an integer to the right.\n\n    Bits are shifted to the right `y`. Because the internal representation of\n    numbers is in binary format, this operation is equivalent to dividing `x` by\n    `2**y`.\n\n    Args:\n        x: Input integer tensor.\n        y: Input integer tensor.\n\n    Returns:\n        Result tensor.\n    \"\"\"\n    if any_symbolic_tensors((x, y)):\n        return BitwiseRightShift().symbolic_call(x, y)\n    return backend.numpy.bitwise_right_shift(x, y)\n\n\nclass RightShift(Operation):\n    def call(self, x, y):\n        return backend.numpy.right_shift(x, y)\n\n    def compute_output_spec(self, x, y):\n        if isinstance(y, int):\n            dtype = x.dtype\n        else:\n            dtype = dtypes.result_type(x.dtype, y.dtype)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.right_shift\", \"keras.ops.numpy.right_shift\"])\ndef right_shift(x, y):\n    \"\"\"Shift the bits of an integer to the right.\n\n    Bits are shifted to the right `y`. 
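The multiply-by-`2**y` / divide-by-`2**y` equivalences stated in the shift docstrings can be checked on plain Python ints (a sketch of the semantics, not the backend implementation):

```python
# Left shift multiplies by 2**y; right shift floor-divides by 2**y.
x, y = 37, 2
shifted_left = x << y   # same as x * 2**y
shifted_right = x >> y  # same as x // 2**y (for non-negative x)
assert shifted_left == x * (2**y)
assert shifted_right == x // (2**y)
print(shifted_left, shifted_right)  # 148 9
```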
Because the internal representation of\n    numbers is in binary format, this operation is equivalent to dividing `x` by\n    `2**y`.\n\n    Args:\n        x: Input integer tensor.\n        y: Input integer tensor.\n\n    Returns:\n        Result tensor.\n    \"\"\"\n    if any_symbolic_tensors((x, y)):\n        return RightShift().symbolic_call(x, y)\n    return backend.numpy.right_shift(x, y)\n\n\nclass Blackman(Operation):\n    def call(self, x):\n        return backend.numpy.blackman(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=backend.floatx())\n\n\n@keras_export([\"keras.ops.blackman\", \"keras.ops.numpy.blackman\"])\ndef blackman(x):\n    \"\"\"Blackman window function.\n    The Blackman window is a taper formed by using a weighted cosine.\n\n    Args:\n        x: Scalar or 1D Tensor. Window length.\n\n    Returns:\n        A 1D tensor containing the Blackman window values.\n\n    Example:\n    >>> x = keras.ops.convert_to_tensor(5)\n    >>> keras.ops.blackman(x)\n    array([-1.3877788e-17,  3.4000000e-01,  1.0000000e+00,  3.4000000e-01,\n           -1.3877788e-17], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Blackman().symbolic_call(x)\n    return backend.numpy.blackman(x)\n\n\nclass BroadcastTo(Operation):\n    def __init__(self, shape, *, name=None):\n        super().__init__(name=name)\n        self.shape = shape\n\n    def call(self, x):\n        return backend.numpy.broadcast_to(x, self.shape)\n\n    def compute_output_spec(self, x):\n        # Catch broadcasting errors for clear error messages.\n        broadcast_shapes(x.shape, self.shape)\n        return KerasTensor(self.shape, dtype=x.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.broadcast_to\",\n        \"keras.ops.numpy.broadcast_to\",\n    ]\n)\ndef broadcast_to(x, shape):\n    \"\"\"Broadcast a tensor to a new shape.\n\n    Args:\n        x: The tensor to broadcast.\n        shape: The shape of the desired 
tensor. A single integer `i` is\n            interpreted as `(i,)`.\n\n    Returns:\n        A tensor with the desired shape.\n\n    Examples:\n    >>> x = keras.ops.array([1, 2, 3])\n    >>> keras.ops.broadcast_to(x, (3, 3))\n    array([[1, 2, 3],\n           [1, 2, 3],\n           [1, 2, 3]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return BroadcastTo(shape=shape).symbolic_call(x)\n    return backend.numpy.broadcast_to(x, shape)\n\n\nclass Cbrt(Operation):\n    def call(self, x):\n        return backend.numpy.cbrt(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(x.dtype)\n        if dtype in [\n            \"bool\",\n            \"int8\",\n            \"int16\",\n            \"int32\",\n            \"uint8\",\n            \"uint16\",\n            \"uint32\",\n        ]:\n            dtype = backend.floatx()\n        elif dtype == \"int64\":\n            dtype = \"float64\"\n\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.cbrt\", \"keras.ops.numpy.cbrt\"])\ndef cbrt(x):\n    \"\"\"Computes the cube root of the input tensor, element-wise.\n\n    This operation returns the real-valued cube root of `x`, handling\n    negative numbers properly in the real domain.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A tensor containing the cube root of each element in `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Cbrt().symbolic_call(x)\n    return backend.numpy.cbrt(x)\n\n\nclass Ceil(Operation):\n    def call(self, x):\n        return backend.numpy.ceil(x)\n\n    def compute_output_spec(self, x):\n        if backend.standardize_dtype(x.dtype) == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(x.dtype, float)\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.ceil\", \"keras.ops.numpy.ceil\"])\ndef 
ceil(x):\n    \"\"\"Return the ceiling of the input, element-wise.\n\n    The ceil of the scalar `x` is the smallest integer `i`, such that\n    `i >= x`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        The ceiling of each element in `x`, with float dtype.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Ceil().symbolic_call(x)\n    return backend.numpy.ceil(x)\n\n\nclass Clip(Operation):\n    def __init__(self, x_min, x_max, *, name=None):\n        super().__init__(name=name)\n        self.x_min = x_min\n        self.x_max = x_max\n\n    def call(self, x):\n        return backend.numpy.clip(x, self.x_min, self.x_max)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(x.dtype)\n        if dtype == \"bool\":\n            dtype = \"int32\"\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.clip\", \"keras.ops.numpy.clip\"])\ndef clip(x, x_min, x_max):\n    \"\"\"Clip (limit) the values in a tensor.\n\n    Given an interval, values outside the interval are clipped to the\n    interval edges. 
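The clipping rule can be sketched for scalars in pure Python; `clip_scalar` is a hypothetical helper used only for illustration, not part of the Keras API:

```python
def clip_scalar(v, v_min, v_max):
    """Limit `v` to the closed interval [v_min, v_max]."""
    return min(max(v, v_min), v_max)

# Clip to [0, 1]: values below 0 become 0, values above 1 become 1.
clipped = [clip_scalar(v, 0.0, 1.0) for v in [-0.5, 0.3, 1.7]]
print(clipped)  # [0.0, 0.3, 1.0]
```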
For example, if an interval of `[0, 1]` is specified,\n    values smaller than 0 become 0, and values larger than 1 become 1.\n\n    Args:\n        x: Input tensor.\n        x_min: Minimum value.\n        x_max: Maximum value.\n\n    Returns:\n        The clipped tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Clip(x_min, x_max).symbolic_call(x)\n    return backend.numpy.clip(x, x_min, x_max)\n\n\nclass Concatenate(Operation):\n    def __init__(self, axis=0, *, name=None):\n        super().__init__(name=name)\n        if axis is None:\n            raise ValueError(\"`axis` cannot be None for `concatenate`.\")\n        self.axis = axis\n\n    def call(self, xs):\n        return backend.numpy.concatenate(xs, axis=self.axis)\n\n    def compute_output_spec(self, xs):\n        first_shape = xs[0].shape\n        total_size_on_axis = 0\n        all_sparse = True\n        dtypes_to_resolve = []\n        for x in xs:\n            if not shape_equal(\n                x.shape, first_shape, axis=[self.axis], allow_none=True\n            ):\n                raise ValueError(\n                    \"Every value in `xs` must have the same shape except on \"\n                    f\"the `axis` dim. 
But found element of shape {x.shape}, \"\n                    f\"which is different from the first element's \"\n                    f\"shape {first_shape}.\"\n                )\n            if total_size_on_axis is None or x.shape[self.axis] is None:\n                total_size_on_axis = None\n            else:\n                total_size_on_axis += x.shape[self.axis]\n            all_sparse = all_sparse and getattr(x, \"sparse\", False)\n            dtypes_to_resolve.append(getattr(x, \"dtype\", type(x)))\n        output_shape = list(first_shape)\n        output_shape[self.axis] = total_size_on_axis\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n        return KerasTensor(output_shape, dtype=dtype, sparse=all_sparse)\n\n\n@keras_export(\n    [\n        \"keras.ops.concatenate\",\n        \"keras.ops.numpy.concatenate\",\n    ]\n)\ndef concatenate(xs, axis=0):\n    \"\"\"Join a sequence of tensors along an existing axis.\n\n    Args:\n        xs: The sequence of tensors to concatenate.\n        axis: The axis along which the tensors will be joined. 
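The size bookkeeping in `Concatenate.compute_output_spec` above can be sketched in plain Python: sizes along the concatenation axis add up, and an unknown (`None`) size makes the total unknown. `concat_axis_size` is a hypothetical helper for illustration only:

```python
def concat_axis_size(sizes):
    """Sum sizes along the concat axis; None (unknown) poisons the total."""
    total = 0
    for s in sizes:
        if s is None or total is None:
            total = None
        else:
            total += s
    return total

print(concat_axis_size([2, 3, 4]))   # 9
print(concat_axis_size([2, None, 4]))  # None
```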
Defaults to `0`.\n\n    Returns:\n        The concatenated tensor.\n    \"\"\"\n    if any_symbolic_tensors(xs):\n        return Concatenate(axis=axis).symbolic_call(xs)\n    return backend.numpy.concatenate(xs, axis=axis)\n\n\nclass Conjugate(Operation):\n    def call(self, x):\n        return backend.numpy.conjugate(x)\n\n    def compute_output_spec(self, x):\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.conjugate\", \"keras.ops.numpy.conjugate\"])\ndef conjugate(x):\n    \"\"\"Returns the complex conjugate, element-wise.\n\n    The complex conjugate of a complex number is obtained by changing the sign\n    of its imaginary part.\n\n    `keras.ops.conj` is a shorthand for this function.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        The complex conjugate of each element in `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Conjugate().symbolic_call(x)\n    return backend.numpy.conjugate(x)\n\n\nclass Conj(Conjugate):\n    pass\n\n\n@keras_export([\"keras.ops.conj\", \"keras.ops.numpy.conj\"])\ndef conj(x):\n    \"\"\"Shorthand for `keras.ops.conjugate`.\"\"\"\n    return conjugate(x)\n\n\nclass Copy(Operation):\n    def call(self, x):\n        return backend.numpy.copy(x)\n\n    def compute_output_spec(self, x):\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.copy\", \"keras.ops.numpy.copy\"])\ndef copy(x):\n    \"\"\"Returns a copy of `x`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        A copy of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Copy().symbolic_call(x)\n    return backend.numpy.copy(x)\n\n\nclass Cos(Operation):\n    def call(self, x):\n        return backend.numpy.cos(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", 
backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.cos\", \"keras.ops.numpy.cos\"])\ndef cos(x):\n    \"\"\"Cosine, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        The corresponding cosine values.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Cos().symbolic_call(x)\n    return backend.numpy.cos(x)\n\n\nclass Cosh(Operation):\n    def call(self, x):\n        return backend.numpy.cosh(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.cosh\", \"keras.ops.numpy.cosh\"])\ndef cosh(x):\n    \"\"\"Hyperbolic cosine, element-wise.\n\n    Arguments:\n        x: Input tensor.\n\n    Returns:\n        Output tensor of same shape as `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Cosh().symbolic_call(x)\n    return backend.numpy.cosh(x)\n\n\nclass CountNonzero(Operation):\n    def __init__(self, axis=None, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            self.axis = (axis,)\n        else:\n            self.axis = axis\n\n    def call(self, x):\n        return backend.numpy.count_nonzero(x, axis=self.axis)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis),\n            dtype=\"int32\",\n        )\n\n\n@keras_export(\n    [\n        \"keras.ops.count_nonzero\",\n        \"keras.ops.numpy.count_nonzero\",\n    ]\n)\ndef count_nonzero(x, axis=None):\n    \"\"\"Counts the number of non-zero values in `x` along the 
given `axis`.\n\n    If no axis is specified then all non-zeros in the tensor are counted.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or tuple of axes along which to count the number of\n            non-zeros. Defaults to `None`.\n\n    Returns:\n        int or tensor of ints.\n\n    Examples:\n    >>> x = keras.ops.array([[0, 1, 7, 0], [3, 0, 2, 19]])\n    >>> keras.ops.count_nonzero(x)\n    5\n    >>> keras.ops.count_nonzero(x, axis=0)\n    array([1, 1, 2, 1], dtype=int64)\n    >>> keras.ops.count_nonzero(x, axis=1)\n    array([2, 3], dtype=int64)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return CountNonzero(axis=axis).symbolic_call(x)\n    return backend.numpy.count_nonzero(x, axis=axis)\n\n\nclass Cross(Operation):\n    def __init__(self, axisa=-1, axisb=-1, axisc=-1, axis=None, *, name=None):\n        super().__init__(name=name)\n        if axis is not None:\n            self.axisa = axis\n            self.axisb = axis\n            self.axisc = axis\n        else:\n            self.axisa = axisa\n            self.axisb = axisb\n            self.axisc = axisc\n\n    def call(self, x1, x2):\n        return backend.numpy.cross(x1, x2, self.axisa, self.axisb, self.axisc)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = list(x1.shape)\n        x2_shape = list(x2.shape)\n\n        x1_value_size = x1_shape[self.axisa]\n        x2_value_size = x2_shape[self.axisb]\n        del x1_shape[self.axisa]\n        del x2_shape[self.axisb]\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n\n        if x1_value_size is not None and x1_value_size not in (2, 3):\n            raise ValueError(\n                f\"`x1`'s dim on `axis={self.axisa}` must be either 2 or 3, but \"\n                f\"received: {x1_value_size}\"\n            )\n        if x2_value_size is not None and x2_value_size not in (2, 3):\n            raise ValueError(\n                f\"`x2`'s dim on `axis={self.axisb}` must be either 2 or 3, but \"\n             
   f\"received: {x2_value_size}\"\n            )\n\n        if x1_value_size == 3 or x2_value_size == 3:\n            value_size = [3]\n        else:\n            value_size = []\n\n        output_shape = (\n            output_shape[: self.axisc] + value_size + output_shape[self.axisc :]\n        )\n\n        dtype = dtypes.result_type(x1.dtype, x2.dtype)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.cross\", \"keras.ops.numpy.cross\"])\ndef cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):\n    \"\"\"Returns the cross product of two (arrays of) vectors.\n\n    The cross product of `x1` and `x2` in R^3 is a vector\n    perpendicular to both `x1` and `x2`. If `x1` and `x2` are arrays of\n    vectors, the vectors are defined by the last axis of `x1` and `x2`\n    by default, and these axes can have dimensions 2 or 3.\n\n    Where the dimension of either `x1` or `x2` is 2, the third component of\n    the input vector is assumed to be zero and the cross product calculated\n    accordingly.\n\n    In cases where both input vectors have dimension 2, the z-component of\n    the cross product is returned.\n\n    Args:\n        x1: Components of the first vector(s).\n        x2: Components of the second vector(s).\n        axisa: Axis of `x1` that defines the vector(s). Defaults to `-1`.\n        axisb: Axis of `x2` that defines the vector(s). Defaults to `-1`.\n        axisc: Axis of the result containing the cross product vector(s).\n            Ignored if both input vectors have dimension 2, as the return is\n            scalar. By default, the last axis.\n        axis: If defined, the axis of `x1`, `x2` and the result that\n            defines the vector(s) and cross product(s). Overrides `axisa`,\n            `axisb` and `axisc`.\n\n    Note:\n        Torch backend does not support two dimensional vectors, or the\n        arguments `axisa`, `axisb` and `axisc`. 
Use `axis` instead.\n\n    Returns:\n        Vector cross product(s).\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Cross(\n            axisa=axisa, axisb=axisb, axisc=axisc, axis=axis\n        ).symbolic_call(x1, x2)\n    return backend.numpy.cross(\n        x1,\n        x2,\n        axisa=axisa,\n        axisb=axisb,\n        axisc=axisc,\n        axis=axis,\n    )\n\n\nclass Cumprod(Operation):\n    def __init__(self, axis=None, dtype=None, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.dtype = dtype\n\n    def call(self, x):\n        return backend.numpy.cumprod(x, axis=self.axis, dtype=self.dtype)\n\n    def compute_output_spec(self, x):\n        if self.axis is None:\n            if None in x.shape:\n                output_shape = (None,)\n            else:\n                output_shape = (int(np.prod(x.shape)),)\n        else:\n            output_shape = x.shape\n        output_dtype = (\n            backend.standardize_dtype(x.dtype)\n            if self.dtype is None\n            else self.dtype\n        )\n        if output_dtype == \"bool\":\n            output_dtype = \"int32\"\n        return KerasTensor(output_shape, output_dtype)\n\n\n@keras_export([\"keras.ops.cumprod\", \"keras.ops.numpy.cumprod\"])\ndef cumprod(x, axis=None, dtype=None):\n    \"\"\"Return the cumulative product of elements along a given axis.\n\n    Args:\n        x: Input tensor.\n        axis: Axis along which the cumulative product is computed.\n            By default the input is flattened.\n        dtype: dtype of returned tensor. 
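For the flattened (`axis=None`) case, the cumulative product behaves like the standard library's `accumulate` with multiplication, a sketch of the semantics rather than the backend implementation:

```python
import operator
from itertools import accumulate

# Running product over a flat sequence, mirroring cumprod with axis=None.
values = [1, 2, 3, 4]
running_product = list(accumulate(values, operator.mul))
print(running_product)  # [1, 2, 6, 24]
```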
Defaults to x.dtype.\n\n    Returns:\n        Output tensor.\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x,)):\n        return Cumprod(axis=axis, dtype=dtype).symbolic_call(x)\n    return backend.numpy.cumprod(x, axis=axis, dtype=dtype)\n\n\nclass Cumsum(Operation):\n    def __init__(self, axis=None, dtype=None, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.dtype = dtype\n\n    def call(self, x):\n        return backend.numpy.cumsum(x, axis=self.axis, dtype=self.dtype)\n\n    def compute_output_spec(self, x):\n        if self.axis is None:\n            if None in x.shape:\n                output_shape = (None,)\n            else:\n                output_shape = (int(np.prod(x.shape)),)\n        else:\n            output_shape = x.shape\n        output_dtype = (\n            backend.standardize_dtype(x.dtype)\n            if self.dtype is None\n            else self.dtype\n        )\n        if output_dtype == \"bool\":\n            output_dtype = \"int32\"\n        return KerasTensor(output_shape, output_dtype)\n\n\n@keras_export([\"keras.ops.cumsum\", \"keras.ops.numpy.cumsum\"])\ndef cumsum(x, axis=None, dtype=None):\n    \"\"\"Returns the cumulative sum of elements along a given axis.\n\n    Args:\n        x: Input tensor.\n        axis: Axis along which the cumulative sum is computed.\n            By default the input is flattened.\n        dtype: dtype of returned tensor. 
Defaults to x.dtype.\n\n    Returns:\n        Output tensor.\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x,)):\n        return Cumsum(axis=axis, dtype=dtype).symbolic_call(x)\n    return backend.numpy.cumsum(x, axis=axis, dtype=dtype)\n\n\nclass Deg2rad(Operation):\n    def call(self, x):\n        return backend.numpy.deg2rad(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(x.dtype)\n        if dtype in [\"int64\", \"float64\"]:\n            dtype = \"float64\"\n        elif dtype not in [\"bfloat16\", \"float16\"]:\n            dtype = backend.floatx()\n        return KerasTensor(x.shape, dtype)\n\n\n@keras_export([\"keras.ops.deg2rad\", \"keras.ops.numpy.deg2rad\"])\ndef deg2rad(x):\n    \"\"\"Convert angles from degrees to radians.\n\n    The conversion is defined as:\n    `rad = deg * (π / 180)`\n\n    Args:\n        x: Input tensor of angles in degrees.\n\n    Returns:\n        A tensor containing angles converted to radians.\n\n    Examples:\n    >>> from keras import ops\n    >>> ops.deg2rad(180.0)\n    3.141592653589793\n    >>> ops.deg2rad([0.0, 90.0, 180.0])\n    array([0., 1.57079633, 3.14159265])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Deg2rad().symbolic_call(x)\n    return backend.numpy.deg2rad(x)\n\n\nclass Diag(Operation):\n    def __init__(self, k=0, *, name=None):\n        super().__init__(name=name)\n        self.k = k\n\n    def call(self, x):\n        return backend.numpy.diag(x, k=self.k)\n\n    def compute_output_spec(self, x):\n        x_shape = x.shape\n        if len(x_shape) == 1:\n            if x_shape[0] is None:\n                output_shape = [None, None]\n            else:\n                output_shape = [\n                    x_shape[0] + int(np.abs(self.k)),\n                    x_shape[0] + int(np.abs(self.k)),\n                ]\n        elif len(x_shape) == 2:\n            if None in x_shape:\n     
           output_shape = [None]\n            else:\n                shorter_side = np.minimum(x_shape[0], x_shape[1])\n                if self.k > 0:\n                    remaining = x_shape[1] - self.k\n                else:\n                    remaining = x_shape[0] + self.k\n                output_shape = [\n                    int(np.maximum(0, np.minimum(remaining, shorter_side)))\n                ]\n        else:\n            raise ValueError(\n                f\"`x` must be 1-D or 2-D, but received shape {x.shape}.\"\n            )\n        return KerasTensor(output_shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.diag\", \"keras.ops.numpy.diag\"])\ndef diag(x, k=0):\n    \"\"\"Extract a diagonal or construct a diagonal array.\n\n    Args:\n        x: Input tensor. If `x` is 2-D, returns the k-th diagonal of `x`.\n            If `x` is 1-D, return a 2-D tensor with `x` on the k-th diagonal.\n        k: The diagonal to consider. Defaults to `0`. Use `k > 0` for diagonals\n            above the main diagonal, and `k < 0` for diagonals below\n            the main diagonal.\n\n    Returns:\n        The extracted diagonal or constructed diagonal tensor.\n\n    Examples:\n    >>> from keras.src import ops\n    >>> x = ops.arange(9).reshape((3, 3))\n    >>> x\n    array([[0, 1, 2],\n           [3, 4, 5],\n           [6, 7, 8]])\n\n    >>> ops.diag(x)\n    array([0, 4, 8])\n    >>> ops.diag(x, k=1)\n    array([1, 5])\n    >>> ops.diag(x, k=-1)\n    array([3, 7])\n\n    >>> ops.diag(ops.diag(x))\n    array([[0, 0, 0],\n           [0, 4, 0],\n           [0, 0, 8]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Diag(k=k).symbolic_call(x)\n    return backend.numpy.diag(x, k=k)\n\n\nclass Diagflat(Operation):\n    def __init__(self, k=0, *, name=None):\n        super().__init__(name=name)\n        self.k = k\n\n    def call(self, x):\n        return backend.numpy.diagflat(x, k=self.k)\n\n    def compute_output_spec(self, x):\n        x_shape = 
x.shape\n\n        if len(x_shape) == 0:\n            flat_size = 1\n        elif len(x_shape) == 1:\n            flat_size = x_shape[0] if x_shape[0] is not None else None\n        else:\n            flat_size = None\n            for s in x_shape:\n                if s is None:\n                    flat_size = None\n                    break\n                elif flat_size is None:\n                    flat_size = s\n                else:\n                    flat_size *= s\n\n        if flat_size is None:\n            output_shape = [None, None]\n        else:\n            output_shape = [\n                flat_size + int(np.abs(self.k)),\n                flat_size + int(np.abs(self.k)),\n            ]\n\n        return KerasTensor(output_shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.diagflat\", \"keras.ops.numpy.diagflat\"])\ndef diagflat(x, k=0):\n    \"\"\"Create a two-dimensional array with the flattened input on\n       the k-th diagonal.\n\n    Args:\n        x: Input tensor to be flattened and placed on the diagonal.\n        k: The diagonal to place the flattened input. 
Defaults to `0`.\n           Use `k > 0` for diagonals above the main diagonal,\n           and `k < 0` for diagonals below the main diagonal.\n\n    Returns:\n        A 2-D tensor with the flattened input on the specified diagonal.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Diagflat(k=k).symbolic_call(x)\n    return backend.numpy.diagflat(x, k=k)\n\n\nclass Diagonal(Operation):\n    def __init__(self, offset=0, axis1=0, axis2=1, *, name=None):\n        super().__init__(name=name)\n        self.offset = offset\n        self.axis1 = axis1\n        self.axis2 = axis2\n\n    def call(self, x):\n        return backend.numpy.diagonal(\n            x,\n            offset=self.offset,\n            axis1=self.axis1,\n            axis2=self.axis2,\n        )\n\n    def compute_output_spec(self, x):\n        x_shape = list(x.shape)\n        if len(x_shape) < 2:\n            raise ValueError(\n                \"`diagonal` requires an array of at least two dimensions, but \"\n                f\"`x` is of shape {x.shape}.\"\n            )\n\n        shape_2d = [x_shape[self.axis1], x_shape[self.axis2]]\n        x_shape[self.axis1] = -1\n        x_shape[self.axis2] = -1\n        output_shape = list(filter((-1).__ne__, x_shape))\n        if None in shape_2d:\n            diag_shape = [None]\n        else:\n            shorter_side = np.minimum(shape_2d[0], shape_2d[1])\n            if self.offset > 0:\n                remaining = shape_2d[1] - self.offset\n            else:\n                remaining = shape_2d[0] + self.offset\n            diag_shape = [\n                int(np.maximum(0, np.minimum(remaining, shorter_side)))\n            ]\n        output_shape = output_shape + diag_shape\n        return KerasTensor(output_shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.diagonal\", \"keras.ops.numpy.diagonal\"])\ndef diagonal(x, offset=0, axis1=0, axis2=1):\n    \"\"\"Return specified diagonals.\n\n    If `x` is 2-D, returns the diagonal of `x` with 
the given offset, i.e., the\n    collection of elements of the form `x[i, i+offset]`.\n\n    If `x` has more than two dimensions, the axes specified by `axis1`\n    and `axis2` are used to determine the 2-D sub-array whose diagonal\n    is returned.\n\n    The shape of the resulting array can be determined by removing `axis1`\n    and `axis2` and appending an index to the right equal to the size of\n    the resulting diagonals.\n\n    Args:\n        x: Input tensor.\n        offset: Offset of the diagonal from the main diagonal.\n            Can be positive or negative. Defaults to `0` (main diagonal).\n        axis1: Axis to be used as the first axis of the 2-D sub-arrays.\n            Defaults to `0` (first axis).\n        axis2: Axis to be used as the second axis of the 2-D sub-arrays.\n            Defaults to `1` (second axis).\n\n    Returns:\n        Tensor of diagonals.\n\n    Examples:\n    >>> from keras.src import ops\n    >>> x = ops.arange(4).reshape((2, 2))\n    >>> x\n    array([[0, 1],\n           [2, 3]])\n    >>> x.diagonal()\n    array([0, 3])\n    >>> x.diagonal(1)\n    array([1])\n\n    >>> x = ops.arange(8).reshape((2, 2, 2))\n    >>> x\n    array([[[0, 1],\n            [2, 3]],\n           [[4, 5],\n            [6, 7]]])\n    >>> x.diagonal(0, 0, 1)\n    array([[0, 6],\n           [1, 7]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Diagonal(\n            offset=offset,\n            axis1=axis1,\n            axis2=axis2,\n        ).symbolic_call(x)\n    return backend.numpy.diagonal(\n        x,\n        offset=offset,\n        axis1=axis1,\n        axis2=axis2,\n    )\n\n\nclass Diff(Operation):\n    def __init__(self, n=1, axis=-1, *, name=None):\n        super().__init__(name=name)\n        self.n = n\n        self.axis = axis\n\n    def call(self, a):\n        return backend.numpy.diff(a, n=self.n, axis=self.axis)\n\n    def compute_output_spec(self, a):\n        shape = list(a.shape)\n        size = shape[self.axis]\n  
      if size is not None:\n            shape[self.axis] = builtins.max(size - self.n, 0)\n        return KerasTensor(shape, dtype=a.dtype)\n\n\n@keras_export([\"keras.ops.diff\", \"keras.ops.numpy.diff\"])\ndef diff(a, n=1, axis=-1):\n    \"\"\"Calculate the n-th discrete difference along the given axis.\n\n    The first difference is given by `out[i] = a[i+1] - a[i]` along\n    the given axis, higher differences are calculated by using `diff`\n    recursively.\n\n    Args:\n        a: Input tensor.\n        n: The number of times values are differenced. Defaults to `1`.\n        axis: Axis to compute discrete difference(s) along.\n            Defaults to `-1`.(last axis).\n\n    Returns:\n        Tensor of diagonals.\n\n    Examples:\n    >>> from keras.src import ops\n    >>> x = ops.convert_to_tensor([1, 2, 4, 7, 0])\n    >>> ops.diff(x)\n    array([ 1,  2,  3, -7])\n    >>> ops.diff(x, n=2)\n    array([  1,   1, -10])\n\n    >>> x = ops.convert_to_tensor([[1, 3, 6, 10], [0, 5, 6, 8]])\n    >>> ops.diff(x)\n    array([[2, 3, 4],\n           [5, 1, 2]])\n    >>> ops.diff(x, axis=0)\n    array([[-1,  2,  0, -2]])\n    \"\"\"\n    return Diff(n=n, axis=axis)(a)\n\n\nclass Digitize(Operation):\n    def call(self, x, bins):\n        return backend.numpy.digitize(x, bins)\n\n    def compute_output_spec(self, x, bins):\n        bins_shape = bins.shape\n        if len(bins_shape) > 1:\n            raise ValueError(\n                f\"`bins` must be a 1D array. Received: bins={bins} \"\n                f\"with shape bins.shape={bins_shape}\"\n            )\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=\"int32\", sparse=sparse)\n\n\n@keras_export([\"keras.ops.digitize\", \"keras.ops.numpy.digitize\"])\ndef digitize(x, bins):\n    \"\"\"Returns the indices of the bins to which each value in `x` belongs.\n\n    Args:\n        x: Input array to be binned.\n        bins: Array of bins. 
It has to be one-dimensional and monotonically\n            increasing.\n\n    Returns:\n        Output array of indices, of same shape as `x`.\n\n    Example:\n    >>> import numpy as np\n    >>> from keras.src import ops\n    >>> x = np.array([0.0, 1.0, 3.0, 1.6])\n    >>> bins = np.array([0.0, 3.0, 4.5, 7.0])\n    >>> ops.digitize(x, bins)\n    array([1, 1, 2, 1])\n    \"\"\"\n    if any_symbolic_tensors((x, bins)):\n        return Digitize().symbolic_call(x, bins)\n    return backend.numpy.digitize(x, bins)\n\n\nclass Dot(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.dot(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = list(getattr(x1, \"shape\", []))\n        x2_shape = list(getattr(x2, \"shape\", []))\n        dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        if x1_shape == [] or x2_shape == []:\n            return multiply(x1, x2)\n        if len(x1_shape) == 1 and len(x2_shape) == 1:\n            return KerasTensor([], dtype=dtype)\n        if len(x2_shape) == 1:\n            if x1_shape[-1] != x2_shape[0]:\n                raise ValueError(\n                    \"Shape must match on the last axis of `x1` and `x2` when \"\n                    \"`x1` is an N-D array while `x2` is 1-D, but received \"\n                    f\"`x1.shape={x1.shape}` and `x2.shape={x2.shape}`.\"\n                )\n            return KerasTensor(x1_shape[:-1], dtype=dtype)\n\n        if (\n            x1_shape[-1] is None\n            or x2_shape[-2] is None\n            or x1_shape[-1] == x2_shape[-2]\n        ):\n            del x1_shape[-1]\n            del x2_shape[-2]\n            return KerasTensor(x1_shape + x2_shape, dtype=dtype)\n\n        raise ValueError(\n            \"Shape must match on the last axis of `x1` and second last \"\n            \"axis of `x2` when `x1` is an N-D array while `x2` is M-D, but \"\n            f\"received `x1.shape={x1.shape}` and `x2.shape={x2.shape}`.\"\n   
     )\n\n\n@keras_export([\"keras.ops.dot\", \"keras.ops.numpy.dot\"])\ndef dot(x1, x2):\n    \"\"\"Dot product of two tensors.\n\n    - If both `x1` and `x2` are 1-D tensors, it is the inner product of\n      vectors (without complex conjugation).\n    - If both `x1` and `x2` are 2-D tensors, it is matrix multiplication.\n    - If either `x1` or `x2` is 0-D (scalar), it is equivalent to `x1 * x2`.\n    - If `x1` is an N-D tensor and `x2` is a 1-D tensor, it is a sum product\n      over the last axis of `x1` and `x2`.\n    - If `x1` is an N-D tensor and `x2` is an M-D tensor (where `M>=2`),\n      it is a sum product over the last axis of `x1` and the second-to-last\n      axis of `x2`: `dot(x1, x2)[i,j,k,m] = sum(x1[i,j,:] * x2[k,:,m])`.\n\n    Args:\n        x1: First argument.\n        x2: Second argument.\n\n    Note:\n        Torch backend does not accept 0-D tensors as arguments.\n\n    Returns:\n        Dot product of `x1` and `x2`.\n\n    Example:\n    >>> from keras.src import ops\n    >>> x1 = ops.convert_to_tensor([[1, 2], [3, 4]])\n    >>> x2 = ops.convert_to_tensor([[5, 6], [7, 8]])\n    >>> ops.dot(x1, x2)\n    array([[19, 22],\n           [43, 50]])\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Dot().symbolic_call(x1, x2)\n    return backend.numpy.dot(x1, x2)\n\n\nclass Dstack(Operation):\n    def call(self, xs):\n        return backend.numpy.dstack(xs)\n\n    def compute_output_spec(self, xs):\n        dtypes_to_resolve = []\n        out_shapes = []\n        for x in xs:\n            shape = list(x.shape)\n            if len(shape) == 0:\n                shape = [1, 1, 1]\n            elif len(shape) == 1:\n                shape = [1, shape[0], 1]\n            elif len(shape) == 2:\n                shape = shape + [1]\n            out_shapes.append(shape)\n            dtypes_to_resolve.append(getattr(x, \"dtype\", type(x)))\n\n        first_shape = out_shapes[0]\n        total_depth = 0\n        for shape in out_shapes:\n            if not shape_equal(shape, first_shape, axis=[2], allow_none=True):\n                raise ValueError(\n                    \"Every value in `xs` must have the same shape except on \"\n                    f\"the third axis. 
But found element of shape {shape}, \"\n                    f\"which is different from the first element's \"\n                    f\"shape {first_shape}.\"\n                )\n            if total_depth is None or shape[2] is None:\n                total_depth = None\n            else:\n                total_depth += shape[2]\n\n        output_shape = list(first_shape)\n        output_shape[2] = total_depth\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.dstack\", \"keras.ops.numpy.dstack\"])\ndef dstack(xs):\n    \"\"\"Stack tensors in sequence depth wise (along third axis).\n\n    This is equivalent to concatenation along the third axis after 2-D tensors\n    of shape `(M, N)` have been reshaped to `(M, N, 1)` and 1-D tensors of shape\n    `(N,)` have been reshaped to `(1, N, 1)`.\n\n    Args:\n        xs: Sequence of tensors.\n\n    Returns:\n        The tensor formed by stacking the given tensors.\n\n    Examples:\n    >>> import keras\n    >>> x = keras.ops.array([1, 2, 3])\n    >>> y = keras.ops.array([4, 5, 6])\n    >>> keras.ops.dstack([x, y])\n    array([[[1, 4],\n            [2, 5],\n            [3, 6]]])\n\n    >>> x = keras.ops.array([[1], [2], [3]])\n    >>> y = keras.ops.array([[4], [5], [6]])\n    >>> keras.ops.dstack([x, y])\n    array([[[1, 4]],\n           [[2, 5]],\n           [[3, 6]]])\n    \"\"\"\n    if any_symbolic_tensors((xs,)):\n        return Dstack().symbolic_call(xs)\n    return backend.numpy.dstack(xs)\n\n\nclass Einsum(Operation):\n    def __init__(self, subscripts, *, name=None):\n        super().__init__(name=name)\n        self.subscripts = subscripts\n\n    def call(self, *operands, **kwargs):\n        return backend.numpy.einsum(self.subscripts, *operands, **kwargs)\n\n    def compute_output_spec(self, *operands):\n        \"\"\"Compute the output shape of `einsum`.\n\n        The shape computation follows the steps below:\n        1. 
Find all letters in the input specs (left part of \"->\"), and\n            break them into two categories: letters appearing more than once\n            go to `reduced_dims`, otherwise go to `kept_dims`.\n        2. Adjust `reduced_dims` and `kept_dims` based on the output spec\n            (right part of \"->\"). The rule is if the letter appears in the\n            output spec, then move it to `kept_dims`, otherwise move it to\n            `reduced_dims`.\n        3. Compute the target output shape. If no output spec is set, then\n            the target output shape will be \"...{kept_dims}\", e.g., \"...ijk\",\n            else it will be the same as the output spec. \"...\" is a wildcard\n            that can match a shape of arbitrary length.\n        4. For each operand in `operands`, map the shape specified in the\n            input spec to the output target, e.g., if the operand is of shape\n            [2,3,4], the input spec is \"i...\" and the output target is\n            \"i...jk\", then 2 will go to index 0. For dims not represented by\n            any letter, insert them into the wildcard part. For each letter in\n            the output target not appearing in the input spec, the dim will be\n            1 for broadcasting. After this step, each operand should have a\n            target shape containing only numbers and `None`.\n        5. Broadcast all shapes computed in step 4, and the result is the\n            output shape.\n\n        Let's take an example to illustrate the steps above. Define:\n        ```python\n        x = KerasTensor([None, 3, 4])\n        y = KerasTensor([2, 4, 3])\n        z = knp.einsum(\"...ij, kji->...k\", x, y)\n        ```\n\n        1. `reduced_dims` is {\"i\", \"j\"}, `kept_dims` is {\"k\"}.\n        2. `reduced_dims` is still {\"i\", \"j\"}, and `kept_dims` is {\"k\"}.\n        3. Output target is \"...k\".\n        4. 
For `x`, the input spec is \"...ij\", and the output target is \"...k\".\n            \"i\" and \"j\" do not appear in the output target, so no replacement\n            happens, and [None] goes to wildcard. Afterwards, \"k\" is replaced\n            by 1, so we get shape [None, 1]. Applying the same logic to `y`, we\n            get shape [2].\n        5. Broadcast [None, 1] and [2], and we get [None, 2], which is the\n            output shape.\n        \"\"\"\n        split_subscripts = self.subscripts.split(\"->\")\n        if len(split_subscripts) > 2:\n            raise ValueError(\n                \"At most one '->' is supported in `einsum` subscripts, but \"\n                f\"received {self.subscripts}.\"\n            )\n        if len(split_subscripts) == 2:\n            subscripts = split_subscripts[0]\n            output_spec = split_subscripts[1]\n        else:\n            subscripts = self.subscripts\n            output_spec = None\n        input_specs = subscripts.split(\",\")\n        if len(input_specs) != len(operands):\n            raise ValueError(\n                f\"Number of operands ({len(operands)}) does not match the \"\n                f\"number of input specs ({len(input_specs)}) in `einsum`, \"\n                f\"received subscripts={self.subscripts}.\"\n            )\n        reduced_dims = set()\n        kept_dims = set()\n        for s in subscripts:\n            if not s.isalpha():\n                continue\n            if s not in reduced_dims and s not in kept_dims:\n                kept_dims.add(s)\n            elif s in kept_dims:\n                kept_dims.remove(s)\n                reduced_dims.add(s)\n\n        if output_spec is not None:\n            # The output spec changes the rule of kept_dims and reduced_dims.\n            # In short, dims appearing in the output spec will be kept, and\n            # dims not appearing in the output spec will be reduced.\n            kept_dims_copy = kept_dims.copy()\n            
reduced_dims_copy = reduced_dims.copy()\n            for dim in kept_dims:\n                if dim not in output_spec:\n                    kept_dims_copy.remove(dim)\n                    reduced_dims_copy.add(dim)\n            for dim in reduced_dims:\n                if dim in output_spec:\n                    reduced_dims_copy.remove(dim)\n                    kept_dims_copy.add(dim)\n            kept_dims = kept_dims_copy\n            reduced_dims = reduced_dims_copy\n\n        reduced_dims = sorted(reduced_dims)\n        kept_dims = sorted(kept_dims)\n\n        if output_spec is None:\n            target_broadcast_spec = f\"...{''.join(kept_dims)}\"\n        else:\n            target_broadcast_spec = output_spec\n\n        expanded_operands_shapes = []\n        for x, spec in zip(operands, input_specs):\n            x_shape = getattr(x, \"shape\", [])\n            x_shape = [-1 if size is None else size for size in x_shape]\n            split_spec = spec.split(\"...\")\n            expanded_shape = target_broadcast_spec\n            if len(split_spec) == 1:\n                # In this case, the input spec is just a string of letters,\n                # e.g., \"ijk\".\n                if len(x_shape) != len(split_spec[0]):\n                    raise ValueError(\n                        \"Number of dimensions in the subscript does not \"\n                        \"match the number of dimensions in the operand, \"\n                        f\"received subscript `{spec}` and operand of shape \"\n                        f\"{x_shape}.\"\n                    )\n                for size, s in zip(x_shape, split_spec[0]):\n                    # Replace the letter with the right shape.\n                    expanded_shape = expanded_shape.replace(s, f\"{str(size)} \")\n                expanded_shape = expanded_shape.replace(\"...\", \"\")\n            else:\n                # In this case, the input spec has \"...\", e.g., \"i...j\", \"i...\",\n                # or 
\"...j\".\n                for i in range(len(split_spec[0])):\n                    expanded_shape = expanded_shape.replace(\n                        split_spec[0][i], f\"{x_shape[i]} \"\n                    )\n                for i in range(len(split_spec[1])):\n                    expanded_shape = expanded_shape.replace(\n                        split_spec[1][-i - 1], f\"{x_shape[-i - 1]} \"\n                    )\n                # Shape matched by \"...\" will be inserted to the position of\n                # \"...\".\n                wildcard_shape_start_index = len(split_spec[0])\n                wildcard_shape_end_index = (\n                    len(x_shape)\n                    if len(split_spec[1]) == 0\n                    else -len(split_spec[1])\n                )\n                wildcard_shape = x_shape[\n                    wildcard_shape_start_index:wildcard_shape_end_index\n                ]\n                wildcard_shape_str = (\n                    f\"{' '.join([str(size) for size in wildcard_shape])} \"\n                )\n                expanded_shape = expanded_shape.replace(\n                    \"...\", wildcard_shape_str\n                )\n            # Replace all letters not yet handled with \"1\" for broadcasting.\n            expanded_shape = re.sub(\"[a-z]\", \"1 \", expanded_shape)\n            expanded_shape = expanded_shape.split()\n            expanded_shape = [\n                None if size == \"-1\" else int(size) for size in expanded_shape\n            ]\n            expanded_operands_shapes.append(expanded_shape)\n\n        output_shape = expanded_operands_shapes[0]\n        for shape in expanded_operands_shapes[1:]:\n            output_shape = broadcast_shapes(output_shape, shape)\n        dtypes_to_resolve = list(\n            set(\n                backend.standardize_dtype(getattr(x, \"dtype\", type(x)))\n                for x in operands\n            )\n        )\n        if len(dtypes_to_resolve) == 1 and 
dtypes_to_resolve[0] == \"int8\":\n            dtype = \"int32\"\n        else:\n            dtype = dtypes.result_type(*dtypes_to_resolve)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.einsum\", \"keras.ops.numpy.einsum\"])\ndef einsum(subscripts, *operands, **kwargs):\n    \"\"\"Evaluates the Einstein summation convention on the operands.\n\n    Args:\n        subscripts: Specifies the subscripts for summation as a\n            comma-separated list of subscript labels. An implicit (classical\n            Einstein summation) calculation is performed unless the explicit\n            indicator `->` is included as well as subscript labels of the\n            precise output form.\n        operands: The operands to compute the Einstein sum of.\n\n    Returns:\n        The calculation based on the Einstein summation convention.\n\n    Example:\n    >>> from keras.src import ops\n    >>> a = ops.arange(25).reshape(5, 5)\n    >>> b = ops.arange(5)\n    >>> c = ops.arange(6).reshape(2, 3)\n\n    Trace of a matrix:\n\n    >>> ops.einsum(\"ii\", a)\n    60\n    >>> ops.einsum(a, [0, 0])\n    60\n    >>> ops.trace(a)\n    60\n\n    Extract the diagonal:\n\n    >>> ops.einsum(\"ii -> i\", a)\n    array([ 0,  6, 12, 18, 24])\n    >>> ops.einsum(a, [0, 0], [0])\n    array([ 0,  6, 12, 18, 24])\n    >>> ops.diag(a)\n    array([ 0,  6, 12, 18, 24])\n\n    Sum over an axis:\n\n    >>> ops.einsum(\"ij -> i\", a)\n    array([ 10,  35,  60,  85, 110])\n    >>> ops.einsum(a, [0, 1], [0])\n    array([ 10,  35,  60,  85, 110])\n    >>> ops.sum(a, axis=1)\n    array([ 10,  35,  60,  85, 110])\n\n    For higher dimensional tensors, summing a single axis can be done\n    with ellipsis:\n\n    >>> ops.einsum(\"...j -> ...\", a)\n    array([ 10,  35,  60,  85, 110])\n    >>> ops.einsum(a, [..., 1], [...])\n    array([ 10,  35,  60,  85, 110])\n\n    Compute a matrix transpose or reorder any number of axes:\n\n    >>> ops.einsum(\"ji\", c)\n    array([[0, 
3],\n           [1, 4],\n           [2, 5]])\n    >>> ops.einsum(\"ij -> ji\", c)\n    array([[0, 3],\n           [1, 4],\n           [2, 5]])\n    >>> ops.einsum(c, [1, 0])\n    array([[0, 3],\n           [1, 4],\n           [2, 5]])\n    >>> ops.transpose(c)\n    array([[0, 3],\n           [1, 4],\n           [2, 5]])\n\n    Matrix vector multiplication:\n\n    >>> ops.einsum(\"ij, j\", a, b)\n    array([ 30,  80, 130, 180, 230])\n    >>> ops.einsum(a, [0, 1], b, [1])\n    array([ 30,  80, 130, 180, 230])\n    >>> ops.einsum(\"...j, j\", a, b)\n    array([ 30,  80, 130, 180, 230])\n    \"\"\"\n    if any_symbolic_tensors(operands):\n        return Einsum(subscripts).symbolic_call(*operands, **kwargs)\n    return backend.numpy.einsum(subscripts, *operands, **kwargs)\n\n\n@keras_export([\"keras.ops.empty\", \"keras.ops.numpy.empty\"])\ndef empty(shape, dtype=None):\n    \"\"\"Return a tensor of given shape and type filled with uninitialized data.\n\n    Args:\n        shape: Shape of the empty tensor.\n        dtype: Desired data type of the empty tensor.\n\n    Returns:\n        The empty tensor.\n    \"\"\"\n    return backend.numpy.empty(shape, dtype=dtype)\n\n\nclass EmptyLike(Operation):\n    def __init__(self, dtype=None, *, name=None):\n        super().__init__(name=name)\n        self.dtype = dtype\n\n    def call(self, x):\n        return backend.numpy.empty_like(x, dtype=self.dtype)\n\n    def compute_output_spec(self, x):\n        dtype = (\n            backend.standardize_dtype(x.dtype)\n            if self.dtype is None\n            else self.dtype\n        )\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.empty_like\", \"keras.ops.numpy.empty_like\"])\ndef empty_like(x, dtype=None):\n    \"\"\"Return a new uninitialized tensor with the same shape and dtype as `x`.\n\n    Args:\n        x: Input tensor to mimic shape and dtype.\n        dtype: Optional data type. 
If None, uses `x.dtype`.\n\n    Returns:\n        A tensor with the same shape and dtype as `x`, with arbitrary contents.\n\n    Example:\n    >>> from keras import ops\n    >>> x = ops.ones((2, 3), dtype=\"float32\")\n    >>> y = ops.empty_like(x)\n    >>> y.shape\n    (2, 3)\n    >>> y.dtype\n    dtype('float32')\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x,)):\n        return EmptyLike(dtype=dtype).symbolic_call(x)\n    return backend.numpy.empty_like(x, dtype=dtype)\n\n\nclass Equal(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.equal(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        return KerasTensor(output_shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.equal\", \"keras.ops.numpy.equal\"])\ndef equal(x1, x2):\n    \"\"\"Returns `(x1 == x2)` element-wise.\n\n    Args:\n        x1: Tensor to compare.\n        x2: Tensor to compare.\n\n    Returns:\n        Output tensor, element-wise comparison of `x1` and `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Equal().symbolic_call(x1, x2)\n    return backend.numpy.equal(x1, x2)\n\n\nclass Exp(Operation):\n    def call(self, x):\n        return backend.numpy.exp(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(x.dtype)\n        if \"int\" in dtype or dtype == \"bool\":\n            dtype = backend.floatx()\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.exp\", \"keras.ops.numpy.exp\"])\ndef exp(x):\n    \"\"\"Calculate the exponential of all elements in the input tensor.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor, element-wise exponential of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return 
Exp().symbolic_call(x)\n    return backend.numpy.exp(x)\n\n\nclass Exp2(Operation):\n    def call(self, x):\n        return backend.numpy.exp2(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(x.dtype)\n        if \"int\" in dtype or dtype == \"bool\":\n            dtype = backend.floatx()\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.exp2\", \"keras.ops.numpy.exp2\"])\ndef exp2(x):\n    \"\"\"Calculate the base-2 exponential of all elements in the input tensor.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor, element-wise base-2 exponential of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Exp2().symbolic_call(x)\n    return backend.numpy.exp2(x)\n\n\nclass ExpandDims(Operation):\n    def __init__(self, axis, *, name=None):\n        super().__init__(name=name)\n        if not isinstance(axis, (int, tuple, list)):\n            raise ValueError(\n                \"The `axis` argument to `expand_dims` should be an integer, \"\n                f\"tuple or list. 
Received axis={axis}\"\n            )\n        self.axis = axis\n\n    def call(self, x):\n        return backend.numpy.expand_dims(x, self.axis)\n\n    def compute_output_spec(self, x):\n        output_shape = operation_utils.compute_expand_dims_output_shape(\n            x.shape, self.axis\n        )\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(output_shape, dtype=x.dtype, sparse=sparse)\n\n\n@keras_export(\n    [\n        \"keras.ops.expand_dims\",\n        \"keras.ops.numpy.expand_dims\",\n    ]\n)\ndef expand_dims(x, axis):\n    \"\"\"Expand the shape of a tensor.\n\n    Insert a new axis at the `axis` position in the expanded tensor shape.\n\n    Args:\n        x: Input tensor.\n        axis: Position in the expanded axes where the new axis\n            (or axes) is placed.\n\n    Returns:\n        Output tensor with the number of dimensions increased.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return ExpandDims(axis=axis).symbolic_call(x)\n    return backend.numpy.expand_dims(x, axis)\n\n\nclass Expm1(Operation):\n    def call(self, x):\n        return backend.numpy.expm1(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(x.dtype)\n        if \"int\" in dtype or dtype == \"bool\":\n            dtype = backend.floatx()\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.expm1\", \"keras.ops.numpy.expm1\"])\ndef expm1(x):\n    \"\"\"Calculate `exp(x) - 1` for all elements in the tensor.\n\n    Args:\n        x: Input values.\n\n    Returns:\n        Output tensor, element-wise exponential minus one.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Expm1().symbolic_call(x)\n    return backend.numpy.expm1(x)\n\n\nclass Flip(Operation):\n    def __init__(self, axis=None, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n\n    def call(self, x):\n    
    return backend.numpy.flip(x, axis=self.axis)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.flip\", \"keras.ops.numpy.flip\"])\ndef flip(x, axis=None):\n    \"\"\"Reverse the order of elements in the tensor along the given axis.\n\n    The shape of the tensor is preserved, but the elements are reordered.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which to flip the tensor. The default,\n            `axis=None`, will flip over all of the axes of the input tensor.\n\n    Returns:\n        Output tensor with entries of `axis` reversed.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Flip(axis=axis).symbolic_call(x)\n    return backend.numpy.flip(x, axis=axis)\n\n\nclass Floor(Operation):\n    def call(self, x):\n        return backend.numpy.floor(x)\n\n    def compute_output_spec(self, x):\n        sparse = getattr(x, \"sparse\", False)\n        dtype = (\n            backend.floatx()\n            if backend.standardize_dtype(x.dtype) == \"int64\"\n            else dtypes.result_type(x.dtype, float)\n        )\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.floor\", \"keras.ops.numpy.floor\"])\ndef floor(x):\n    \"\"\"Return the floor of the input, element-wise.\n\n    The floor of the scalar `x` is the largest integer `i`, such that `i <= x`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor, element-wise floor of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Floor().symbolic_call(x)\n    return backend.numpy.floor(x)\n\n\nclass Full(Operation):\n    def __init__(self, shape, dtype=None, *, name=None):\n        super().__init__(name=name)\n        self.shape = shape\n        self.dtype = dtype\n\n    def call(self, fill_value):\n        return backend.numpy.full(self.shape, fill_value, dtype=self.dtype)\n\n    def compute_output_spec(self, 
fill_value):\n        dtype = backend.floatx() if self.dtype is None else self.dtype\n        return KerasTensor(self.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.full\", \"keras.ops.numpy.full\"])\ndef full(shape, fill_value, dtype=None):\n    \"\"\"Return a new tensor of given shape and type, filled with `fill_value`.\n\n    Args:\n        shape: Shape of the new tensor.\n        fill_value: Fill value.\n        dtype: Desired data type of the tensor.\n\n    Returns:\n        Output tensor.\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((fill_value,)):\n        return Full(shape=shape, dtype=dtype).symbolic_call(fill_value)\n    return backend.numpy.full(shape, fill_value, dtype=dtype)\n\n\nclass FullLike(Operation):\n    def __init__(self, dtype=None, *, name=None):\n        super().__init__(name=name)\n        self.dtype = dtype\n\n    def call(self, x, fill_value):\n        return backend.numpy.full_like(x, fill_value, dtype=self.dtype)\n\n    def compute_output_spec(self, x, fill_value):\n        dtype = (\n            backend.standardize_dtype(x.dtype)\n            if self.dtype is None\n            else self.dtype\n        )\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.full_like\", \"keras.ops.numpy.full_like\"])\ndef full_like(x, fill_value, dtype=None):\n    \"\"\"Return a full tensor with the same shape and type as the given tensor.\n\n    Args:\n        x: Input tensor.\n        fill_value: Fill value.\n        dtype: Overrides data type of the result.\n\n    Returns:\n        Tensor of `fill_value` with the same shape and type as `x`.\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x, fill_value)):\n        return FullLike(dtype=dtype).symbolic_call(x, fill_value)\n    return backend.numpy.full_like(x, fill_value, dtype=dtype)\n\n\nclass Gcd(Operation):\n    def call(self, x1, 
x2):\n        return backend.numpy.gcd(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n\n        x1_type = backend.standardize_dtype(getattr(x1, \"dtype\", type(x1)))\n        x2_type = backend.standardize_dtype(getattr(x2, \"dtype\", type(x2)))\n        dtype = dtypes.result_type(x1_type, x2_type)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.gcd\", \"keras.ops.numpy.gcd\"])\ndef gcd(x1, x2):\n    \"\"\"Greatest common divisor of `x1` and `x2`, element-wise.\n\n    Args:\n        x1: First input tensor (integer type).\n        x2: Second input tensor (integer type).\n\n    Returns:\n        Output tensor, element-wise greatest common divisor of `x1` and `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Gcd().symbolic_call(x1, x2)\n    return backend.numpy.gcd(x1, x2)\n\n\nclass GetItem(Operation):\n    def call(self, x, key):\n        if isinstance(key, list):\n            key = tuple(key)\n        return x[key]\n\n    def compute_output_spec(self, x, key):\n        remaining_shape = list(x.shape)\n        new_shape = []\n        if isinstance(key, int):\n            remaining_key = [key]\n        elif isinstance(key, tuple):\n            remaining_key = list(key)\n        elif isinstance(key, list):\n            remaining_key = key.copy()\n        else:\n            raise ValueError(\n                f\"Unsupported key type for array slice. Received: `{key}`\"\n            )\n        num_ellipses = remaining_key.count(Ellipsis)\n        if num_ellipses > 1:\n            raise ValueError(\n                f\"Slice should only have one ellipsis. 
Received: `{key}`\"\n            )\n        elif num_ellipses == 0:\n            # Add an implicit final ellipsis.\n            remaining_key.append(Ellipsis)\n        # Consume slice key element by element.\n        while True:\n            if not remaining_key:\n                break\n            subkey = remaining_key.pop(0)\n            # Check for `newaxis` and `Ellipsis`.\n            if subkey == Ellipsis:\n                # Keep as many dims as the remaining key needs, omitting\n                # `newaxis`.\n                needed = len(remaining_key) - remaining_key.count(np.newaxis)\n                consumed = len(remaining_shape) - needed\n                new_shape += remaining_shape[:consumed]\n                remaining_shape = remaining_shape[consumed:]\n                continue\n            # All frameworks follow numpy for newaxis. `np.newaxis == None`.\n            if subkey == np.newaxis:\n                new_shape.append(1)\n                continue\n            # At this point, we need to consume a new axis from the shape.\n            if not remaining_shape:\n                raise ValueError(\n                    f\"Array has shape {x.shape} but slice \"\n                    f\"has too many indices. 
Received: `{key}`\"\n                )\n            length = remaining_shape.pop(0)\n            if isinstance(subkey, int):\n                if length is not None:\n                    index = subkey if subkey >= 0 else subkey + length\n                    if index < 0 or index >= length:\n                        raise ValueError(\n                            f\"Array has shape {x.shape} but out-of-bounds \"\n                            f\"index {key} was requested.\"\n                        )\n            elif isinstance(subkey, slice):\n                if length is not None:\n                    # python3 friendly way to compute a slice length.\n                    new_length = len(range(*subkey.indices(length)))\n                    new_shape.append(new_length)\n                else:\n                    new_shape.append(length)\n            else:\n                raise ValueError(\n                    f\"Unsupported key type for array slice. Received: `{key}`\"\n                )\n        return KerasTensor(tuple(new_shape), dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.get_item\", \"keras.ops.numpy.get_item\"])\ndef get_item(x, key):\n    \"\"\"Return `x[key]`.\"\"\"\n    if any_symbolic_tensors((x,)):\n        return GetItem().symbolic_call(x, key)\n    return x[key]\n\n\nclass Geomspace(Operation):\n    def __init__(self, num=50, endpoint=True, dtype=None, axis=0, *, name=None):\n        super().__init__(name=name)\n        self.num = num\n        self.endpoint = endpoint\n        self.dtype = dtype\n        self.axis = axis\n\n    def call(self, start, stop):\n        return backend.numpy.geomspace(\n            start,\n            stop,\n            num=self.num,\n            endpoint=self.endpoint,\n            dtype=self.dtype,\n            axis=self.axis,\n        )\n\n    def compute_output_spec(self, start, stop):\n        start_shape = getattr(start, \"shape\", [])\n        stop_shape = getattr(stop, \"shape\", [])\n        output_shape = 
broadcast_shapes(start_shape, stop_shape)\n        axis = canonicalize_axis(self.axis, len(output_shape) + 1)\n        output_shape = list(output_shape)\n        output_shape.insert(axis, self.num)\n        dtype = (\n            self.dtype\n            if self.dtype is not None\n            else backend.standardize_dtype(getattr(start, \"dtype\", type(start)))\n        )\n        dtype = backend.result_type(dtype, float)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.geomspace\", \"keras.ops.numpy.geomspace\"])\ndef geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):\n    \"\"\"Returns numbers spaced evenly on a log scale (a geometric progression).\n\n    This is similar to `logspace`, but with endpoints specified directly\n    instead of as logarithms. Each output sample is a constant multiple of\n    the previous.\n\n    Args:\n        start: The starting value of the sequence.\n        stop: The final value of the sequence, unless `endpoint` is `False`.\n            In that case, `num + 1` values are spaced over the interval in\n            log-space, of which all but the last (a sequence of length `num`)\n            are returned.\n        num: Number of samples to generate. Defaults to `50`.\n        endpoint: If `True`, `stop` is the last sample. Otherwise, it is not\n            included. Defaults to `True`.\n        dtype: The type of the output tensor.\n        axis: The axis in the result to store the samples. 
Relevant only\n            if start or stop are array-like.\n\n    Note:\n        Torch backend does not support `axis` argument.\n\n    Returns:\n        A tensor of `num` samples, evenly spaced on a log scale.\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((start, stop)):\n        return Geomspace(num, endpoint, dtype, axis)(start, stop)\n    return backend.numpy.geomspace(\n        start,\n        stop,\n        num=num,\n        endpoint=endpoint,\n        dtype=dtype,\n        axis=axis,\n    )\n\n\nclass Greater(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.greater(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        return KerasTensor(output_shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.greater\", \"keras.ops.numpy.greater\"])\ndef greater(x1, x2):\n    \"\"\"Return the truth value of `x1 > x2` element-wise.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        Output tensor, element-wise comparison of `x1` and `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Greater().symbolic_call(x1, x2)\n    return backend.numpy.greater(x1, x2)\n\n\nclass GreaterEqual(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.greater_equal(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        return KerasTensor(output_shape, dtype=\"bool\")\n\n\n@keras_export(\n    [\n        \"keras.ops.greater_equal\",\n        \"keras.ops.numpy.greater_equal\",\n    ]\n)\ndef greater_equal(x1, x2):\n    \"\"\"Return the truth value of `x1 >= x2` element-wise.\n\n    Args:\n    
    x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        Output tensor, element-wise comparison of `x1` and `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return GreaterEqual().symbolic_call(x1, x2)\n    return backend.numpy.greater_equal(x1, x2)\n\n\nclass Hstack(Operation):\n    def call(self, xs):\n        return backend.numpy.hstack(xs)\n\n    def compute_output_spec(self, xs):\n        first_shape = xs[0].shape\n        total_size_on_axis = 0\n        dtypes_to_resolve = []\n        for x in xs:\n            if not shape_equal(x.shape, first_shape, axis=[1], allow_none=True):\n                raise ValueError(\n                    \"Every value in `xs` must have the same shape except on \"\n                    f\"the `axis` dim. But found element of shape {x.shape}, \"\n                    f\"which is different from the first element's \"\n                    f\"shape {first_shape}.\"\n                )\n            if total_size_on_axis is None or x.shape[1] is None:\n                total_size_on_axis = None\n            else:\n                total_size_on_axis += x.shape[1]\n            dtypes_to_resolve.append(getattr(x, \"dtype\", type(x)))\n        output_shape = list(first_shape)\n        output_shape[1] = total_size_on_axis\n        dtype = dtypes.result_type(*dtypes_to_resolve)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.hstack\", \"keras.ops.numpy.hstack\"])\ndef hstack(xs):\n    \"\"\"Stack tensors in sequence horizontally (column wise).\n\n    This is equivalent to concatenation along the first axis for 1-D tensors,\n    and along the second axis for all other tensors.\n\n    Args:\n        xs: Sequence of tensors.\n\n    Returns:\n        The tensor formed by stacking the given tensors.\n    \"\"\"\n    if any_symbolic_tensors((xs,)):\n        return Hstack().symbolic_call(xs)\n    return backend.numpy.hstack(xs)\n\n\nclass Hsplit(Operation):\n    def 
__init__(self, indices_or_sections, *, name=None):\n        super().__init__(name=name)\n        if not isinstance(indices_or_sections, int):\n            indices_or_sections = tuple(indices_or_sections)\n        self.indices_or_sections = indices_or_sections\n\n    def call(self, x):\n        return backend.numpy.hsplit(x, self.indices_or_sections)\n\n    def compute_output_spec(self, x):\n        if len(x.shape) < 1:\n            raise ValueError(\n                \"`hsplit` only works on arrays of at least 1 dimension. \"\n                f\"Received array with shape {x.shape}.\"\n            )\n\n        axis = 0 if len(x.shape) == 1 else 1\n        return _compute_split_output_spec(x, self.indices_or_sections, axis)\n\n\n@keras_export([\"keras.ops.hsplit\", \"keras.ops.numpy.hsplit\"])\ndef hsplit(x, indices_or_sections):\n    \"\"\"Split an array into multiple sub-arrays horizontally (column-wise).\n\n    Args:\n        x: Input tensor.\n        indices_or_sections: If an integer, N, the tensor will be split into N\n            equal sections along axis 1 (if ndim >= 2) or axis 0 (if ndim == 1).\n            If a 1-D array of sorted integers, the entries indicate indices at\n            which the tensor will be split along the axis.\n\n    Returns:\n        A list of sub-arrays.\n\n    Example:\n\n    >>> x = keras.ops.arange(16.0).reshape((4, 4))\n    >>> keras.ops.hsplit(x, 2)\n    [array([[ 0.,  1.],\n           [ 4.,  5.],\n           [ 8.,  9.],\n           [12., 13.]]),\n     array([[ 2.,  3.],\n           [ 6.,  7.],\n           [10., 11.],\n           [14., 15.]])]\n    >>> keras.ops.hsplit(x, [1, 3])\n    [array([[0.],\n        [4.],\n        [8.],\n        [12.]]),\n    array([[1., 2.],\n        [5., 6.],\n        [9., 10.],\n        [13., 14.]]),\n    array([[3.],\n        [7.],\n        [11.],\n        [15.]])]\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Hsplit(indices_or_sections).symbolic_call(x)\n    return 
backend.numpy.hsplit(x, indices_or_sections)\n\n\nclass Hypot(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.hypot(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        dtype = dtypes.result_type(x1.dtype, x2.dtype)\n        if dtype in [\"int8\", \"int16\", \"int32\", \"uint8\", \"uint16\", \"uint32\"]:\n            dtype = backend.floatx()\n        elif dtype == \"int64\":\n            dtype = \"float64\"\n        return KerasTensor(broadcast_shapes(x1.shape, x2.shape), dtype=dtype)\n\n\n@keras_export([\"keras.ops.hypot\", \"keras.ops.numpy.hypot\"])\ndef hypot(x1, x2):\n    \"\"\"Element-wise hypotenuse of right triangles with legs `x1` and `x2`.\n\n    This is equivalent to computing `sqrt(x1**2 + x2**2)` element-wise,\n    with shape determined by broadcasting.\n\n    Args:\n        x1: A tensor, representing the first leg of the right triangle.\n        x2: A tensor, representing the second leg of the right triangle.\n\n    Returns:\n        A tensor with a shape determined by broadcasting `x1` and `x2`.\n\n    Example:\n    >>> x1 = keras.ops.convert_to_tensor([3.0, 4.0, 5.0])\n    >>> x2 = keras.ops.convert_to_tensor([4.0, 3.0, 12.0])\n    >>> keras.ops.hypot(x1, x2)\n    array([5., 5., 13.], dtype=float32)\n\n    >>> x1 = keras.ops.convert_to_tensor([[1, 2], [3, 4]])\n    >>> x2 = keras.ops.convert_to_tensor([1, 1])\n    >>> keras.ops.hypot(x1, x2)\n    array([[1.41421356, 2.23606798],\n          [3.16227766, 4.12310563]], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Hypot().symbolic_call(x1, x2)\n    return backend.numpy.hypot(x1, x2)\n\n\n@keras_export([\"keras.ops.identity\", \"keras.ops.numpy.identity\"])\ndef identity(n, dtype=None):\n    \"\"\"Return the identity tensor.\n\n    The identity tensor is a square tensor with ones on the main diagonal and\n    zeros elsewhere.\n\n    Args:\n        n: Number of rows (and columns) in the `n x n` output tensor.\n        dtype: Data type of 
the output tensor.\n\n    Returns:\n        The identity tensor.\n    \"\"\"\n    return backend.numpy.identity(n, dtype=dtype)\n\n\nclass Imag(Operation):\n    def call(self, x):\n        return backend.numpy.imag(x)\n\n    def compute_output_spec(self, x):\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.imag\", \"keras.ops.numpy.imag\"])\ndef imag(x):\n    \"\"\"Return the imaginary part of the complex argument.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        The imaginary component of the complex argument.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Imag().symbolic_call(x)\n    return backend.numpy.imag(x)\n\n\nclass Isclose(Operation):\n    def __init__(self, equal_nan=False, *, name=None):\n        super().__init__(name=name)\n        self.equal_nan = equal_nan\n\n    def call(self, x1, x2, rtol=1e-5, atol=1e-8):\n        return backend.numpy.isclose(x1, x2, rtol, atol, self.equal_nan)\n\n    def compute_output_spec(self, x1, x2, rtol=1e-5, atol=1e-8):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        return KerasTensor(output_shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.isclose\", \"keras.ops.numpy.isclose\"])\ndef isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):\n    \"\"\"Return whether two tensors are element-wise almost equal.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n        rtol: Relative tolerance.\n        atol: Absolute tolerance.\n        equal_nan: If `True`, element-wise NaNs are considered equal.\n\n    Returns:\n        Output boolean tensor.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Isclose(equal_nan=equal_nan).symbolic_call(x1, x2, rtol, atol)\n    return backend.numpy.isclose(x1, x2, rtol, atol, equal_nan)\n\n\nclass 
Isfinite(Operation):\n    def call(self, x):\n        return backend.numpy.isfinite(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.isfinite\", \"keras.ops.numpy.isfinite\"])\ndef isfinite(x):\n    \"\"\"Return whether a tensor is finite, element-wise.\n\n    Real values are finite when they are not NaN, not positive infinity, and\n    not negative infinity. Complex values are finite when both their real\n    and imaginary parts are finite.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output boolean tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Isfinite().symbolic_call(x)\n    return backend.numpy.isfinite(x)\n\n\nclass IsIn(Operation):\n    def __init__(\n        self,\n        assume_unique=False,\n        invert=False,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.assume_unique = assume_unique\n        self.invert = invert\n\n    def call(self, x1, x2):\n        return backend.numpy.isin(\n            x1, x2, assume_unique=self.assume_unique, invert=self.invert\n        )\n\n    def compute_output_spec(self, x1, x2):\n        return KerasTensor(x1.shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.isin\", \"keras.ops.numpy.isin\"])\ndef isin(x1, x2, assume_unique=False, invert=False):\n    \"\"\"Test whether each element of `x1` is present in `x2`.\n\n    This operation performs element-wise checks to determine if each value\n    in `x1` is contained within `x2`. 
The result is a boolean tensor with\n    the same shape as `x1`, where each entry is `True` if the corresponding\n    element in `x1` is in `x2`, and `False` otherwise.\n\n    Args:\n        x1: Input tensor or array-like structure to test.\n        x2: Values against which each element of `x1` is tested.\n            Can be a tensor, list, or scalar.\n        assume_unique: Boolean, defaults to `False`.\n            If `True`, assumes both `x1` and `x2` contain only unique\n            elements. This can speed up the computation. If `False`,\n            duplicates will be handled correctly but may impact performance.\n        invert: Boolean, defaults to `False`.\n            If `True`, inverts the result. Entries will be `True`\n            where `x1` elements are not in `x2`.\n\n    Returns:\n        A boolean tensor of the same shape as `x1` indicating element-wise\n        membership in `x2`.\n\n    Example:\n    >>> from keras import ops\n    >>> x1 = ops.array([0, 1, 2, 5])\n    >>> x2 = ops.array([0, 2])\n    >>> ops.isin(x1, x2)\n    array([ True, False,  True, False])\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return IsIn(assume_unique=assume_unique, invert=invert).symbolic_call(\n            x1, x2\n        )\n    return backend.numpy.isin(\n        x1, x2, assume_unique=assume_unique, invert=invert\n    )\n\n\nclass Isinf(Operation):\n    def call(self, x):\n        return backend.numpy.isinf(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.isinf\", \"keras.ops.numpy.isinf\"])\ndef isinf(x):\n    \"\"\"Test element-wise for positive or negative infinity.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output boolean tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Isinf().symbolic_call(x)\n    return backend.numpy.isinf(x)\n\n\nclass Isnan(Operation):\n    def call(self, x):\n        return backend.numpy.isnan(x)\n\n    
def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.isnan\", \"keras.ops.numpy.isnan\"])\ndef isnan(x):\n    \"\"\"Test element-wise for NaN and return result as a boolean tensor.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output boolean tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Isnan().symbolic_call(x)\n    return backend.numpy.isnan(x)\n\n\nclass Isneginf(Operation):\n    def call(self, x):\n        return backend.numpy.isneginf(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.isneginf\", \"keras.ops.numpy.isneginf\"])\ndef isneginf(x):\n    \"\"\"Test element-wise for negative infinity.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output boolean tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Isneginf().symbolic_call(x)\n    return backend.numpy.isneginf(x)\n\n\nclass Isposinf(Operation):\n    def call(self, x):\n        return backend.numpy.isposinf(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.isposinf\", \"keras.ops.numpy.isposinf\"])\ndef isposinf(x):\n    \"\"\"Test element-wise for positive infinity.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output boolean tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Isposinf().symbolic_call(x)\n    return backend.numpy.isposinf(x)\n\n\nclass Isreal(Operation):\n    def call(self, x):\n        return backend.numpy.isreal(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.isreal\", \"keras.ops.numpy.isreal\"])\ndef isreal(x):\n    \"\"\"Test element-wise for real numbers.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output boolean tensor.\n\n    Example:\n    >>> 
from keras import ops\n    >>> x = ops.array([1+1j, 1+0j, 4.5, 3, 2, 2j], dtype=\"complex64\")\n    >>> ops.isreal(x)\n    array([False,  True,  True,  True,  True, False])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Isreal().symbolic_call(x)\n    return backend.numpy.isreal(x)\n\n\nclass Kron(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.kron(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n\n        def _mul_shape_dim(a, b):\n            if a is None or b is None:\n                return None\n            return a * b\n\n        output_shape = tuple(\n            _mul_shape_dim(a, b) for a, b in zip(x1_shape, x2_shape)\n        )\n\n        x1_type = backend.standardize_dtype(getattr(x1, \"dtype\", type(x1)))\n        x2_type = backend.standardize_dtype(getattr(x2, \"dtype\", type(x2)))\n        dtype = dtypes.result_type(x1_type, x2_type)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.kron\", \"keras.ops.numpy.kron\"])\ndef kron(x1, x2):\n    \"\"\"Kronecker product of `x1` and `x2`.\n\n    Computes the Kronecker product of two input tensors. 
If `x1` has shape\n    `(a0, a1, ..., an)` and `x2` has shape `(b0, b1, ..., bn)`, then the\n    output will have shape `(a0*b0, a1*b1, ..., an*bn)`.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        A tensor representing the Kronecker product of `x1` and `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Kron().symbolic_call(x1, x2)\n    return backend.numpy.kron(x1, x2)\n\n\nclass Lcm(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.lcm(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n\n        x1_type = backend.standardize_dtype(getattr(x1, \"dtype\", type(x1)))\n        x2_type = backend.standardize_dtype(getattr(x2, \"dtype\", type(x2)))\n        dtype = dtypes.result_type(x1_type, x2_type)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.lcm\", \"keras.ops.numpy.lcm\"])\ndef lcm(x1, x2):\n    \"\"\"Least common multiple of `x1` and `x2`, element-wise.\n\n    Args:\n        x1: First input tensor (integer type).\n        x2: Second input tensor (integer type).\n\n    Returns:\n        Output tensor, element-wise least common multiple of `x1` and `x2`.\n\n    Example:\n    >>> x1 = keras.ops.convert_to_tensor([2, 3, 4])\n    >>> x2 = keras.ops.convert_to_tensor([5, 6, 7])\n    >>> keras.ops.lcm(x1, x2)\n    array([10,  6, 28], dtype=int32)\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Lcm().symbolic_call(x1, x2)\n    return backend.numpy.lcm(x1, x2)\n\n\nclass Ldexp(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.ldexp(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, 
x2_shape)\n\n        x1_type = backend.standardize_dtype(getattr(x1, \"dtype\", type(x1)))\n        x2_type = backend.standardize_dtype(getattr(x2, \"dtype\", type(x2)))\n        dtype = dtypes.result_type(x1_type, x2_type, float)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.ldexp\", \"keras.ops.numpy.ldexp\"])\ndef ldexp(x1, x2):\n    \"\"\"Multiply `x1` by 2 raised to the power of `x2`, element-wise.\n\n    This function computes:\n        ldexp(x1, x2) = x1 * 2**x2\n\n    Args:\n        x1: Float input tensor.\n        x2: Integer exponent tensor.\n\n    Returns:\n        Output tensor.\n\n    Example:\n    >>> x1 = keras.ops.convert_to_tensor([0.75, 1.5])\n    >>> x2 = keras.ops.convert_to_tensor([1, 2])\n    >>> keras.ops.ldexp(x1, x2)\n    array([1.5, 6. ], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Ldexp().symbolic_call(x1, x2)\n    return backend.numpy.ldexp(x1, x2)\n\n\nclass Less(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.less(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        return KerasTensor(output_shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.less\", \"keras.ops.numpy.less\"])\ndef less(x1, x2):\n    \"\"\"Return the truth value of `x1 < x2` element-wise.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        Output tensor, element-wise comparison of `x1` and `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Less().symbolic_call(x1, x2)\n    return backend.numpy.less(x1, x2)\n\n\nclass LessEqual(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.less_equal(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = 
getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        return KerasTensor(output_shape, dtype=\"bool\")\n\n\n@keras_export(\n    [\n        \"keras.ops.less_equal\",\n        \"keras.ops.numpy.less_equal\",\n    ]\n)\ndef less_equal(x1, x2):\n    \"\"\"Return the truth value of `x1 <= x2` element-wise.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        Output tensor, element-wise comparison of `x1` and `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return LessEqual().symbolic_call(x1, x2)\n    return backend.numpy.less_equal(x1, x2)\n\n\nclass Linspace(Operation):\n    def __init__(\n        self,\n        num=50,\n        endpoint=True,\n        retstep=False,\n        dtype=None,\n        axis=0,\n        *,\n        name=None,\n    ):\n        super().__init__(name=name)\n        self.num = num\n        self.endpoint = endpoint\n        self.retstep = retstep\n        self.dtype = dtype\n        self.axis = axis\n\n    def call(self, start, stop):\n        return backend.numpy.linspace(\n            start,\n            stop,\n            num=self.num,\n            endpoint=self.endpoint,\n            retstep=self.retstep,\n            dtype=self.dtype,\n            axis=self.axis,\n        )\n\n    def compute_output_spec(self, start, stop):\n        start_shape = getattr(start, \"shape\", [])\n        stop_shape = getattr(stop, \"shape\", [])\n        output_shape = broadcast_shapes(start_shape, stop_shape)\n        output_shape = list(output_shape)\n        axis = canonicalize_axis(self.axis, len(output_shape) + 1)\n        output_shape.insert(axis, self.num)\n\n        dtype = (\n            self.dtype\n            if self.dtype is not None\n            else backend.standardize_dtype(getattr(start, \"dtype\", type(start)))\n        )\n        dtype = backend.result_type(dtype, float)\n        if self.retstep:\n            return 
(KerasTensor(output_shape, dtype=dtype), None)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.linspace\", \"keras.ops.numpy.linspace\"])\ndef linspace(\n    start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0\n):\n    \"\"\"Return evenly spaced numbers over a specified interval.\n\n    Returns `num` evenly spaced samples, calculated over the interval\n    `[start, stop]`.\n\n    The endpoint of the interval can optionally be excluded.\n\n    Args:\n        start: The starting value of the sequence.\n        stop: The end value of the sequence, unless `endpoint` is set to\n            `False`. In that case, the sequence consists of all but the last\n            of `num + 1` evenly spaced samples, so that `stop` is excluded.\n            Note that the step size changes when `endpoint` is `False`.\n        num: Number of samples to generate. Defaults to `50`. Must be\n            non-negative.\n        endpoint: If `True`, `stop` is the last sample. Otherwise, it is\n            not included. Defaults to `True`.\n        retstep: If `True`, return `(samples, step)`, where `step` is the\n            spacing between samples.\n        dtype: The type of the output tensor.\n        axis: The axis in the result to store the samples. Relevant only if\n            start or stop are array-like. 
Defaults to `0`.\n\n    Note:\n        Torch backend does not support `axis` argument.\n\n    Returns:\n        A tensor of evenly spaced numbers.\n        If `retstep` is `True`, returns `(samples, step)`.\n    \"\"\"\n    if any_symbolic_tensors((start, stop)):\n        return Linspace(num, endpoint, retstep, dtype, axis)(start, stop)\n    return backend.numpy.linspace(\n        start,\n        stop,\n        num=num,\n        endpoint=endpoint,\n        retstep=retstep,\n        dtype=dtype,\n        axis=axis,\n    )\n\n\nclass Log(Operation):\n    def call(self, x):\n        return backend.numpy.log(x)\n\n    def compute_output_spec(self, x):\n        dtype = (\n            backend.floatx()\n            if backend.standardize_dtype(x.dtype) == \"int64\"\n            else dtypes.result_type(x.dtype, float)\n        )\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.log\", \"keras.ops.numpy.log\"])\ndef log(x):\n    \"\"\"Natural logarithm, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor, element-wise natural logarithm of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Log().symbolic_call(x)\n    return backend.numpy.log(x)\n\n\nclass Log10(Operation):\n    def call(self, x):\n        return backend.numpy.log10(x)\n\n    def compute_output_spec(self, x):\n        dtype = (\n            backend.floatx()\n            if backend.standardize_dtype(x.dtype) == \"int64\"\n            else dtypes.result_type(x.dtype, float)\n        )\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.log10\", \"keras.ops.numpy.log10\"])\ndef log10(x):\n    \"\"\"Return the base 10 logarithm of the input tensor, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor, element-wise base 10 logarithm of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Log10().symbolic_call(x)\n    return 
backend.numpy.log10(x)\n\n\nclass Log1p(Operation):\n    def call(self, x):\n        return backend.numpy.log1p(x)\n\n    def compute_output_spec(self, x):\n        dtype = (\n            backend.floatx()\n            if backend.standardize_dtype(x.dtype) == \"int64\"\n            else dtypes.result_type(x.dtype, float)\n        )\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.log1p\", \"keras.ops.numpy.log1p\"])\ndef log1p(x):\n    \"\"\"Returns the natural logarithm of one plus `x`, element-wise.\n\n    Calculates `log(1 + x)`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor, element-wise natural logarithm of `1 + x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Log1p().symbolic_call(x)\n    return backend.numpy.log1p(x)\n\n\nclass Log2(Operation):\n    def call(self, x):\n        return backend.numpy.log2(x)\n\n    def compute_output_spec(self, x):\n        dtype = (\n            backend.floatx()\n            if backend.standardize_dtype(x.dtype) == \"int64\"\n            else dtypes.result_type(x.dtype, float)\n        )\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.log2\", \"keras.ops.numpy.log2\"])\ndef log2(x):\n    \"\"\"Base-2 logarithm of `x`, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor, element-wise base-2 logarithm of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Log2().symbolic_call(x)\n    return backend.numpy.log2(x)\n\n\nclass Logaddexp(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.logaddexp(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", 
type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n            float,\n        )\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.logaddexp\", \"keras.ops.numpy.logaddexp\"])\ndef logaddexp(x1, x2):\n    \"\"\"Logarithm of the sum of exponentiations of the inputs.\n\n    Calculates `log(exp(x1) + exp(x2))`.\n\n    Args:\n        x1: Input tensor.\n        x2: Input tensor.\n\n    Returns:\n        Output tensor, element-wise logarithm of the sum of exponentiations\n        of the inputs.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Logaddexp().symbolic_call(x1, x2)\n    return backend.numpy.logaddexp(x1, x2)\n\n\nclass Logaddexp2(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.logaddexp2(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n            float,\n        )\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.logaddexp2\", \"keras.ops.numpy.logaddexp2\"])\ndef logaddexp2(x1, x2):\n    \"\"\"Base-2 logarithm of the sum of exponentiations of the inputs.\n\n    Calculates `log2(2**x1 + 2**x2)`.\n\n    Args:\n        x1: Input tensor.\n        x2: Input tensor.\n\n    Returns:\n        Output tensor, element-wise log base 2 of the sum of 2**x1 and 2**x2.\n\n    Example:\n    >>> from keras import ops\n    >>> x1 = ops.array([1, 2, 3])\n    >>> x2 = ops.array([1, 2, 3])\n    >>> ops.logaddexp2(x1, x2)\n    array([2., 3., 4.], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Logaddexp2().symbolic_call(x1, x2)\n    return backend.numpy.logaddexp2(x1, x2)\n\n\nclass LogicalAnd(Operation):\n    def call(self, x1, x2):\n    
    return backend.numpy.logical_and(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        return KerasTensor(output_shape, dtype=\"bool\")\n\n\n@keras_export(\n    [\n        \"keras.ops.logical_and\",\n        \"keras.ops.numpy.logical_and\",\n    ]\n)\ndef logical_and(x1, x2):\n    \"\"\"Computes the element-wise logical AND of the given input tensors.\n\n    Zeros are treated as `False` and non-zeros are treated as `True`.\n\n    Args:\n        x1: Input tensor.\n        x2: Input tensor.\n\n    Returns:\n        Output tensor, element-wise logical AND of the inputs.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return LogicalAnd().symbolic_call(x1, x2)\n    return backend.numpy.logical_and(x1, x2)\n\n\nclass LogicalNot(Operation):\n    def call(self, x):\n        return backend.numpy.logical_not(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=\"bool\")\n\n\n@keras_export(\n    [\n        \"keras.ops.logical_not\",\n        \"keras.ops.numpy.logical_not\",\n    ]\n)\ndef logical_not(x):\n    \"\"\"Computes the element-wise NOT of the given input tensor.\n\n    Zeros are treated as `False` and non-zeros are treated as `True`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor, element-wise logical NOT of the input.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return LogicalNot().symbolic_call(x)\n    return backend.numpy.logical_not(x)\n\n\nclass LogicalOr(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.logical_or(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        return KerasTensor(output_shape, 
dtype=\"bool\")\n\n\n@keras_export(\n    [\n        \"keras.ops.logical_or\",\n        \"keras.ops.numpy.logical_or\",\n    ]\n)\ndef logical_or(x1, x2):\n    \"\"\"Computes the element-wise logical OR of the given input tensors.\n\n    Zeros are treated as `False` and non-zeros are treated as `True`.\n\n    Args:\n        x1: Input tensor.\n        x2: Input tensor.\n\n    Returns:\n        Output tensor, element-wise logical OR of the inputs.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return LogicalOr().symbolic_call(x1, x2)\n    return backend.numpy.logical_or(x1, x2)\n\n\nclass Logspace(Operation):\n    def __init__(\n        self, num=50, endpoint=True, base=10, dtype=None, axis=0, *, name=None\n    ):\n        super().__init__(name=name)\n        self.num = num\n        self.endpoint = endpoint\n        self.base = base\n        self.dtype = dtype\n        self.axis = axis\n\n    def call(self, start, stop):\n        return backend.numpy.logspace(\n            start,\n            stop,\n            num=self.num,\n            endpoint=self.endpoint,\n            base=self.base,\n            dtype=self.dtype,\n            axis=self.axis,\n        )\n\n    def compute_output_spec(self, start, stop):\n        start_shape = getattr(start, \"shape\", [])\n        stop_shape = getattr(stop, \"shape\", [])\n        output_shape = broadcast_shapes(start_shape, stop_shape)\n        output_shape = list(output_shape)\n        axis = canonicalize_axis(self.axis, len(output_shape) + 1)\n        output_shape.insert(axis, self.num)\n        dtype = (\n            self.dtype\n            if self.dtype is not None\n            else backend.standardize_dtype(getattr(start, \"dtype\", type(start)))\n        )\n        dtype = backend.result_type(dtype, float)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.logspace\", \"keras.ops.numpy.logspace\"])\ndef logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, 
axis=0):\n    \"\"\"Returns numbers spaced evenly on a log scale.\n\n    In linear space, the sequence starts at `base ** start` and ends with\n    `base ** stop` (see `endpoint` below).\n\n    Args:\n        start: The starting value of the sequence.\n        stop: The final value of the sequence, unless `endpoint` is `False`.\n            In that case, `num + 1` values are spaced over the interval in\n            log-space, of which all but the last (a sequence of length `num`)\n            are returned.\n        num: Number of samples to generate. Defaults to `50`.\n        endpoint: If `True`, `stop` is the last sample. Otherwise, it is not\n            included. Defaults to `True`.\n        base: The base of the log space. Defaults to `10`.\n        dtype: The type of the output tensor.\n        axis: The axis in the result to store the samples. Relevant only\n            if start or stop are array-like.\n\n    Note:\n        Torch backend does not support `axis` argument.\n\n    Returns:\n        A tensor of evenly spaced samples on a log scale.\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((start, stop)):\n        return Logspace(num, endpoint, base, dtype, axis)(start, stop)\n    return backend.numpy.logspace(\n        start,\n        stop,\n        num=num,\n        endpoint=endpoint,\n        base=base,\n        dtype=dtype,\n        axis=axis,\n    )\n\n\nclass Matmul(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.matmul(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = operation_utils.compute_matmul_output_shape(\n            x1_shape, x2_shape\n        )\n        x1_sparse = getattr(x1, \"sparse\", True)\n        x2_sparse = getattr(x2, \"sparse\", True)\n        output_sparse = x1_sparse and x2_sparse\n        x1_dtype = 
backend.standardize_dtype(getattr(x1, \"dtype\", type(x1)))\n        x2_dtype = backend.standardize_dtype(getattr(x2, \"dtype\", type(x2)))\n        if x1_dtype == \"int8\" and x2_dtype == \"int8\":\n            dtype = \"int32\"\n        else:\n            dtype = dtypes.result_type(x1_dtype, x2_dtype)\n        return KerasTensor(output_shape, dtype=dtype, sparse=output_sparse)\n\n\n@keras_export([\"keras.ops.matmul\", \"keras.ops.numpy.matmul\"])\ndef matmul(x1, x2):\n    \"\"\"Matrix product of two tensors.\n\n    - If both tensors are 1-dimensional, the dot product (scalar) is returned.\n    - If either tensor is N-D, N > 2, it is treated as a stack of matrices\n      residing in the last two indexes and broadcast accordingly.\n    - If the first tensor is 1-D, it is promoted to a matrix by prepending\n      a 1 to its dimensions. After matrix multiplication the prepended\n      1 is removed.\n    - If the second tensor is 1-D, it is promoted to a matrix by appending a 1\n      to its dimensions. 
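The promotion and broadcasting rules above follow NumPy `matmul` semantics. A minimal sketch in plain NumPy (illustrative only, not part of this module):

```python
import numpy as np

# 1-D @ 1-D: the dot product, a scalar.
v = np.array([1.0, 2.0, 3.0])
assert np.matmul(v, v) == 14.0

# 1-D first operand is promoted: (3,) @ (3, 2) -> (2,); the prepended
# 1 is removed from the result shape.
m = np.ones((3, 2))
assert np.matmul(v, m).shape == (2,)

# N-D operands are treated as stacks of matrices and broadcast over the
# leading dimensions: (4, 2, 3) @ (3, 5) -> (4, 2, 5).
a = np.ones((4, 2, 3))
b = np.ones((3, 5))
assert np.matmul(a, b).shape == (4, 2, 5)
```

The `int8` x `int8` -> `int32` accumulation special case in `compute_output_spec` is Keras-specific and not shown here.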
After matrix multiplication the appended 1 is removed.\n\n    Args:\n        x1: First tensor.\n        x2: Second tensor.\n\n    Returns:\n        Output tensor, matrix product of the inputs.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Matmul().symbolic_call(x1, x2)\n    return backend.numpy.matmul(x1, x2)\n\n\nclass Max(Operation):\n    def __init__(self, axis=None, keepdims=False, initial=None, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            self.axis = [axis]\n        else:\n            self.axis = axis\n        self.keepdims = keepdims\n        self.initial = initial\n\n    def call(self, x):\n        return backend.numpy.max(\n            x, axis=self.axis, keepdims=self.keepdims, initial=self.initial\n        )\n\n    def compute_output_spec(self, x):\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=x.dtype,\n        )\n\n\n@keras_export([\"keras.ops.max\", \"keras.ops.numpy.max\"])\ndef max(x, axis=None, keepdims=False, initial=None):\n    \"\"\"Return the maximum of a tensor or maximum along an axis.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which to operate. By default, flattened input\n            is used.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one. Defaults to `False`.\n        initial: The minimum value of an output element. 
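A note on `initial`, whose semantics follow `np.max`: the value is included in the reduction, which also makes reductions over empty slices well-defined. A minimal NumPy sketch (illustrative only):

```python
import numpy as np

x = np.array([[1, 5], [2, 3]])

# `initial` below the true maximum has no effect.
assert np.max(x, initial=4) == 5

# Per-row reduction: row [2, 3] is lifted to the floor value 4.
assert np.max(x, axis=1, initial=4).tolist() == [5, 4]

# An empty reduction would raise without `initial`; with it, the
# result is simply `initial`.
assert np.max(np.array([]), initial=0.0) == 0.0
```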
Defaults to `None`.\n\n    Returns:\n        Maximum of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Max(axis=axis, keepdims=keepdims, initial=initial).symbolic_call(\n            x\n        )\n    return backend.numpy.max(x, axis=axis, keepdims=keepdims, initial=initial)\n\n\nclass Maximum(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.maximum(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        output_dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        x1_sparse = getattr(x1, \"sparse\", False)\n        x2_sparse = getattr(x2, \"sparse\", False)\n        output_sparse = x1_sparse and x2_sparse\n        return KerasTensor(\n            output_shape, dtype=output_dtype, sparse=output_sparse\n        )\n\n\n@keras_export([\"keras.ops.maximum\", \"keras.ops.numpy.maximum\"])\ndef maximum(x1, x2):\n    \"\"\"Element-wise maximum of `x1` and `x2`.\n\n    Args:\n        x1: First tensor.\n        x2: Second tensor.\n\n    Returns:\n        Output tensor, element-wise maximum of `x1` and `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Maximum().symbolic_call(x1, x2)\n    return backend.numpy.maximum(x1, x2)\n\n\nclass Median(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            axis = [axis]\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.median(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        output_shape = reduce_shape(\n            x.shape, axis=self.axis, keepdims=self.keepdims\n        )\n        if 
backend.standardize_dtype(x.dtype) == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(x.dtype, float)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.median\", \"keras.ops.numpy.median\"])\ndef median(x, axis=None, keepdims=False):\n    \"\"\"Compute the median along the specified axis.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which the medians are computed. Defaults to\n            `axis=None` which is to compute the median(s) along a flattened\n            version of the array.\n        keepdims: If this is set to `True`, the axes which are reduced\n            are left in the result as dimensions with size one.\n\n    Returns:\n        The output tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Median(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.median(x, axis=axis, keepdims=keepdims)\n\n\nclass Meshgrid(Operation):\n    def __init__(self, indexing=\"xy\", *, name=None):\n        super().__init__(name=name)\n        if indexing not in (\"xy\", \"ij\"):\n            raise ValueError(\n                \"Valid values for `indexing` are 'xy' and 'ij', \"\n                f\"but received {indexing}.\"\n            )\n        self.indexing = indexing\n\n    def call(self, *x):\n        return backend.numpy.meshgrid(*x, indexing=self.indexing)\n\n    def compute_output_spec(self, *x):\n        output_shape = []\n        for xi in x:\n            if len(xi.shape) == 0:\n                size = 1\n            else:\n                if None in xi.shape:\n                    size = None\n                else:\n                    size = int(np.prod(xi.shape))\n            output_shape.append(size)\n        if self.indexing == \"ij\":\n            return [\n                KerasTensor(output_shape, dtype=xi.dtype) for _ in range(len(x))\n            ]\n        tmp = output_shape[0]\n        output_shape[0] = output_shape[1]\n        output_shape[1] = 
tmp\n        return [\n            KerasTensor(output_shape, dtype=xi.dtype) for _ in range(len(x))\n        ]\n\n\n@keras_export([\"keras.ops.meshgrid\", \"keras.ops.numpy.meshgrid\"])\ndef meshgrid(*x, indexing=\"xy\"):\n    \"\"\"Creates grids of coordinates from coordinate vectors.\n\n    Given `N` 1-D tensors `T0, T1, ..., TN-1` as inputs with corresponding\n    lengths `S0, S1, ..., SN-1`, this creates `N` N-dimensional tensors\n    `G0, G1, ..., GN-1` each with shape `(S0, ..., SN-1)` where the output\n    `Gi` is constructed by expanding `Ti` to the result shape.\n\n    Args:\n        x: 1-D tensors representing the coordinates of a grid.\n        indexing: `\"xy\"` or `\"ij\"`. `\"xy\"` is cartesian; `\"ij\"` is matrix\n            indexing of output. Defaults to `\"xy\"`.\n\n    Returns:\n        Sequence of N tensors.\n\n    Example:\n    >>> from keras.src import ops\n    >>> x = ops.array([1, 2, 3])\n    >>> y = ops.array([4, 5, 6])\n\n    >>> grid_x, grid_y = ops.meshgrid(x, y, indexing=\"ij\")\n    >>> grid_x\n    array([[1, 1, 1],\n           [2, 2, 2],\n           [3, 3, 3]])\n    >>> grid_y\n    array([[4, 5, 6],\n           [4, 5, 6],\n           [4, 5, 6]])\n    \"\"\"\n    if any_symbolic_tensors(x):\n        return Meshgrid(indexing=indexing).symbolic_call(*x)\n    return backend.numpy.meshgrid(*x, indexing=indexing)\n\n\nclass Min(Operation):\n    def __init__(self, axis=None, keepdims=False, initial=None, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            self.axis = [axis]\n        else:\n            self.axis = axis\n        self.keepdims = keepdims\n        self.initial = initial\n\n    def call(self, x):\n        return backend.numpy.min(\n            x, axis=self.axis, keepdims=self.keepdims, initial=self.initial\n        )\n\n    def compute_output_spec(self, x):\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            
dtype=x.dtype,\n        )\n\n\n@keras_export([\"keras.ops.min\", \"keras.ops.numpy.min\"])\ndef min(x, axis=None, keepdims=False, initial=None):\n    \"\"\"Return the minimum of a tensor or minimum along an axis.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which to operate. By default, flattened input\n            is used.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one. Defaults to `False`.\n        initial: The maximum value of an output element. Defaults to `None`.\n\n    Returns:\n        Minimum of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Min(axis=axis, keepdims=keepdims, initial=initial).symbolic_call(\n            x\n        )\n    return backend.numpy.min(x, axis=axis, keepdims=keepdims, initial=initial)\n\n\nclass Minimum(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.minimum(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        output_dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        x1_sparse = getattr(x1, \"sparse\", False)\n        x2_sparse = getattr(x2, \"sparse\", False)\n        output_sparse = x1_sparse and x2_sparse\n        return KerasTensor(\n            output_shape, dtype=output_dtype, sparse=output_sparse\n        )\n\n\n@keras_export([\"keras.ops.minimum\", \"keras.ops.numpy.minimum\"])\ndef minimum(x1, x2):\n    \"\"\"Element-wise minimum of `x1` and `x2`.\n\n    Args:\n        x1: First tensor.\n        x2: Second tensor.\n\n    Returns:\n        Output tensor, element-wise minimum of `x1` and `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Minimum().symbolic_call(x1, x2)\n    return 
backend.numpy.minimum(x1, x2)\n\n\nclass Mod(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.mod(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        output_dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        if output_dtype == \"bool\":\n            output_dtype = \"int32\"\n        return KerasTensor(output_shape, dtype=output_dtype)\n\n\n@keras_export([\"keras.ops.mod\", \"keras.ops.numpy.mod\"])\ndef mod(x1, x2):\n    \"\"\"Returns the element-wise remainder of division.\n\n    Args:\n        x1: First tensor.\n        x2: Second tensor.\n\n    Returns:\n        Output tensor, element-wise remainder of division.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Mod().symbolic_call(x1, x2)\n    return backend.numpy.mod(x1, x2)\n\n\nclass Fmod(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.fmod(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        output_dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        if output_dtype == \"bool\":\n            output_dtype = \"int32\"\n        return KerasTensor(output_shape, dtype=output_dtype)\n\n\n@keras_export([\"keras.ops.fmod\", \"keras.ops.numpy.fmod\"])\ndef fmod(x1, x2):\n    \"\"\"Returns the element-wise remainder of division with truncation.\n\n    Computes the remainder complementary to the `floor_divide` function,\n    equivalent to the C library function ``fmod``. The result has the same\n    sign as the dividend ``x1``. 
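The sign convention difference between `fmod` and `mod` described here follows NumPy; a minimal sketch (illustrative, using NumPy directly):

```python
import numpy as np

x1 = np.array([-7.0, 7.0])  # dividends
x2 = np.array([3.0, -3.0])  # divisors

# fmod truncates the quotient toward zero, so the remainder takes the
# sign of the dividend x1.
assert np.fmod(x1, x2).tolist() == [-1.0, 1.0]

# mod (np.remainder) floors the quotient, so the remainder takes the
# sign of the divisor x2.
assert np.mod(x1, x2).tolist() == [2.0, -2.0]
```

The two agree whenever `x1` and `x2` share a sign.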
This is different from `keras.ops.mod`\n    which has the same sign as the divisor ``x2``.\n\n    Args:\n        x1: First tensor, the dividend.\n        x2: Second tensor, the divisor.\n\n    Returns:\n        Output tensor, element-wise remainder with truncation.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Fmod().symbolic_call(x1, x2)\n    return backend.numpy.fmod(x1, x2)\n\n\nclass Moveaxis(Operation):\n    def __init__(self, source, destination, *, name=None):\n        super().__init__(name=name)\n        if isinstance(source, int):\n            self.source = [source]\n        else:\n            self.source = source\n        if isinstance(destination, int):\n            self.destination = [destination]\n        else:\n            self.destination = destination\n\n        if len(self.source) != len(self.destination):\n            raise ValueError(\n                \"`source` and `destination` arguments must have the same \"\n                f\"number of elements, but received `source={source}` and \"\n                f\"`destination={destination}`.\"\n            )\n\n    def call(self, x):\n        return backend.numpy.moveaxis(x, self.source, self.destination)\n\n    def compute_output_spec(self, x):\n        x_shape = list(x.shape)\n        output_shape = [-1 for _ in range(len(x.shape))]\n        for sc, dst in zip(self.source, self.destination):\n            output_shape[dst] = x_shape[sc]\n            x_shape[sc] = -1\n        i, j = 0, 0\n        while i < len(output_shape):\n            while i < len(output_shape) and output_shape[i] != -1:\n                # Find the first dim unset.\n                i += 1\n            while j < len(output_shape) and x_shape[j] == -1:\n                # Find the first dim not being passed.\n                j += 1\n            if i == len(output_shape):\n                break\n            output_shape[i] = x_shape[j]\n            i += 1\n            j += 1\n        return 
KerasTensor(output_shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.moveaxis\", \"keras.ops.numpy.moveaxis\"])\ndef moveaxis(x, source, destination):\n    \"\"\"Move axes of a tensor to new positions.\n\n    Other axes remain in their original order.\n\n    Args:\n        x: Tensor whose axes should be reordered.\n        source: Original positions of the axes to move. These must be unique.\n        destination: Destination positions for each of the original axes.\n            These must also be unique.\n\n    Returns:\n        Tensor with moved axes.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Moveaxis(source, destination).symbolic_call(x)\n    return backend.numpy.moveaxis(x, source=source, destination=destination)\n\n\nclass Nanargmax(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.nanargmax(\n            x, axis=self.axis, keepdims=self.keepdims\n        )\n\n    def compute_output_spec(self, x):\n        axis = [self.axis] if self.axis is not None else None\n        return KerasTensor(\n            reduce_shape(x.shape, axis=axis, keepdims=self.keepdims),\n            dtype=\"int32\",\n        )\n\n\n@keras_export([\"keras.ops.nanargmax\", \"keras.ops.numpy.nanargmax\"])\ndef nanargmax(x, axis=None, keepdims=False):\n    \"\"\"Returns the indices of the maximum values along an axis, ignoring NaNs.\n\n    Args:\n        x: Input tensor.\n        axis: By default, the index is into the flattened tensor, otherwise\n            along the specified axis.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one. Defaults to `False`.\n\n    Returns:\n        Tensor of indices. It has the same shape as `x`, with the dimension\n        along `axis` removed. 
NaN values are ignored when computing the maximum.\n\n    Examples:\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.array([[1.0, np.nan, 3.0],\n    ...               [np.nan, 2.0, 0.0]])\n\n    >>> ops.nanargmax(x)\n    array(2, dtype=int32)\n\n    >>> ops.nanargmax(x, axis=0)\n    array([0, 1, 0], dtype=int32)\n\n    >>> ops.nanargmax(x, axis=1)\n    array([2, 1], dtype=int32)\n\n    >>> ops.nanargmax(x, axis=1, keepdims=True)\n    array([[2],\n           [1]], dtype=int32)\n    \"\"\"\n\n    if any_symbolic_tensors((x,)):\n        return Nanargmax(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.nanargmax(x, axis=axis, keepdims=keepdims)\n\n\nclass Nanargmin(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.nanargmin(\n            x, axis=self.axis, keepdims=self.keepdims\n        )\n\n    def compute_output_spec(self, x):\n        axis = [self.axis] if self.axis is not None else None\n        return KerasTensor(\n            reduce_shape(x.shape, axis=axis, keepdims=self.keepdims),\n            dtype=\"int32\",\n        )\n\n\n@keras_export([\"keras.ops.nanargmin\", \"keras.ops.numpy.nanargmin\"])\ndef nanargmin(x, axis=None, keepdims=False):\n    \"\"\"Returns the indices of the minimum values along an axis, ignoring NaNs.\n\n    Args:\n        x: Input tensor.\n        axis: By default, the index is into the flattened tensor, otherwise\n            along the specified axis.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one. Defaults to `False`.\n\n    Returns:\n        Tensor of indices. It has the same shape as `x`, with the dimension\n        along `axis` removed. 
NaN values are ignored when computing the minimum.\n\n    Examples:\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.array([[1.0, np.nan, 3.0],\n    ...               [np.nan, 2.0, 0.0]])\n\n    >>> ops.nanargmin(x)\n    array(5, dtype=int32)\n\n    >>> ops.nanargmin(x, axis=0)\n    array([0, 1, 1], dtype=int32)\n\n    >>> ops.nanargmin(x, axis=1)\n    array([0, 2], dtype=int32)\n\n    >>> ops.nanargmin(x, axis=1, keepdims=True)\n    array([[0],\n           [2]], dtype=int32)\n    \"\"\"\n\n    if any_symbolic_tensors((x,)):\n        return Nanargmin(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.nanargmin(x, axis=axis, keepdims=keepdims)\n\n\nclass Nancumsum(Operation):\n    def __init__(self, axis=None, dtype=None, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.dtype = dtype\n\n    def call(self, x):\n        return backend.numpy.nancumsum(x, axis=self.axis, dtype=self.dtype)\n\n    def compute_output_spec(self, x):\n        if self.axis is None:\n            if None in x.shape:\n                output_shape = (None,)\n            else:\n                output_shape = (int(np.prod(x.shape)),)\n        else:\n            output_shape = x.shape\n\n        output_dtype = (\n            backend.standardize_dtype(x.dtype)\n            if self.dtype is None\n            else self.dtype\n        )\n\n        if output_dtype == \"bool\":\n            output_dtype = \"int32\"\n\n        return KerasTensor(output_shape, output_dtype)\n\n\n@keras_export([\"keras.ops.nancumsum\", \"keras.ops.numpy.nancumsum\"])\ndef nancumsum(x, axis=None, dtype=None):\n    \"\"\"Returns the cumulative sum of elements along a given axis,\n    treating NaNs as zero.\n\n    Args:\n        x: Input tensor.\n        axis: Axis along which the cumulative sum is computed.\n            By default the input is flattened.\n        dtype: dtype of returned tensor. 
Defaults to x.dtype.\n\n    Returns:\n        Output tensor.\n\n    Examples:\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.array([[1.0, np.nan, 3.0],\n    ...               [np.nan, 2.0, 1.0]])\n    >>> ops.nancumsum(x)\n    array([1., 1., 4., 4., 6., 7.])\n\n    >>> ops.nancumsum(x, axis=1)\n    array([[1., 1., 4.],\n           [0., 2., 3.]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Nancumsum(axis=axis, dtype=dtype).symbolic_call(x)\n    return backend.numpy.nancumsum(x, axis=axis, dtype=dtype)\n\n\nclass Nancumprod(Operation):\n    def __init__(self, axis=None, dtype=None, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.dtype = dtype\n\n    def call(self, x):\n        return backend.numpy.nancumprod(x, axis=self.axis, dtype=self.dtype)\n\n    def compute_output_spec(self, x):\n        if self.axis is None:\n            if None in x.shape:\n                output_shape = (None,)\n            else:\n                output_shape = (int(np.prod(x.shape)),)\n        else:\n            output_shape = x.shape\n\n        output_dtype = (\n            backend.standardize_dtype(x.dtype)\n            if self.dtype is None\n            else self.dtype\n        )\n\n        if output_dtype == \"bool\":\n            output_dtype = \"int32\"\n\n        return KerasTensor(output_shape, output_dtype)\n\n\n@keras_export([\"keras.ops.nancumprod\", \"keras.ops.numpy.nancumprod\"])\ndef nancumprod(x, axis=None, dtype=None):\n    \"\"\"Returns the cumulative product of elements along a given axis,\n    treating NaNs as one.\n\n    Args:\n        x: Input tensor.\n        axis: Axis along which the cumulative product is computed.\n            By default the input is flattened.\n        dtype: dtype of returned tensor. 
Defaults to x.dtype.\n\n    Returns:\n        Output tensor.\n\n    Examples:\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.array([[1.0, np.nan, 3.0],\n    ...               [np.nan, 2.0, 1.0]])\n    >>> ops.nancumprod(x)\n    array([1., 1., 3., 3., 6., 6.])\n\n    >>> ops.nancumprod(x, axis=1)\n    array([[1., 1., 3.],\n           [1., 2., 2.]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Nancumprod(axis=axis, dtype=dtype).symbolic_call(x)\n    return backend.numpy.nancumprod(x, axis=axis, dtype=dtype)\n\n\nclass Nanmax(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.nanmax(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        dtype = dtypes.result_type(getattr(x, \"dtype\", backend.floatx()))\n\n        if backend.backend() == \"torch\" and dtype == \"uint32\":\n            dtype = \"int32\"\n\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=dtype,\n        )\n\n\n@keras_export([\"keras.ops.nanmax\", \"keras.ops.numpy.nanmax\"])\ndef nanmax(x, axis=None, keepdims=False):\n    \"\"\"Maximum of a tensor over the given axes, ignoring NaNs.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which the maximum is computed.\n            The default is to compute the maximum of the flattened tensor.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one. Defaults\n            to `False`.\n\n    Returns:\n        Output tensor containing the maximum, with NaN values ignored. 
If all\n        values along a reduced axis are NaN, the result is NaN.\n\n    Examples:\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.array([[1.0, np.nan, 3.0],\n    ...               [np.nan, 2.0, 1.0]])\n    >>> ops.nanmax(x)\n    3.0\n\n    >>> ops.nanmax(x, axis=1)\n    array([3., 2.])\n\n    >>> ops.nanmax(x, axis=1, keepdims=True)\n    array([[3.],\n           [2.]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Nanmax(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.nanmax(x, axis=axis, keepdims=keepdims)\n\n\nclass Nanmean(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.nanmean(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        dtype = dtypes.result_type(x.dtype, float)\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=dtype,\n        )\n\n\n@keras_export([\"keras.ops.nanmean\", \"keras.ops.numpy.nanmean\"])\ndef nanmean(x, axis=None, keepdims=False):\n    \"\"\"Mean of a tensor over the given axes, ignoring NaNs.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which the mean is computed.\n            The default is to compute the mean of the flattened tensor.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one. Defaults\n            to `False`.\n\n    Returns:\n        Output tensor containing the mean, with NaN values ignored.\n        If all values along a reduced axis are NaN, the result is NaN.\n\n    Examples:\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.array([[1.0, np.nan, 3.0],\n    ...               
[np.nan, 2.0, 1.0]])\n    >>> ops.nanmean(x)\n    1.75\n\n    >>> ops.nanmean(x, axis=1)\n    array([2., 1.5])\n\n    >>> ops.nanmean(x, axis=1, keepdims=True)\n    array([[2. ],\n           [1.5]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Nanmean(axis=axis, keepdims=keepdims).symbolic_call(x)\n\n    return backend.numpy.nanmean(x, axis=axis, keepdims=keepdims)\n\n\nclass Nanmin(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.nanmin(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        dtype = dtypes.result_type(getattr(x, \"dtype\", backend.floatx()))\n\n        if backend.backend() == \"torch\" and dtype == \"uint32\":\n            dtype = \"int32\"\n\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=dtype,\n        )\n\n\n@keras_export([\"keras.ops.nanmin\", \"keras.ops.numpy.nanmin\"])\ndef nanmin(x, axis=None, keepdims=False):\n    \"\"\"Minimum of a tensor over the given axes, ignoring NaNs.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which the minimum is computed.\n            The default is to compute the minimum of the flattened tensor.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one.\n\n    Returns:\n        Output tensor containing the minimum, with NaN values ignored. If all\n        values along a reduced axis are NaN, the result is NaN.\n\n    Examples:\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.array([[1.0, np.nan, 3.0],\n    ...               
[np.nan, 2.0, 1.0]])\n    >>> ops.nanmin(x)\n    1.0\n\n    >>> ops.nanmin(x, axis=1)\n    array([1., 1.])\n\n    >>> ops.nanmin(x, axis=1, keepdims=True)\n    array([[1.],\n           [1.]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Nanmin(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.nanmin(x, axis=axis, keepdims=keepdims)\n\n\nclass Nanprod(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.nanprod(\n            x,\n            axis=self.axis,\n            keepdims=self.keepdims,\n        )\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(x.dtype)\n\n        if dtype == \"bool\":\n            dtype = \"int32\"\n        elif dtype in (\"int8\", \"int16\"):\n            dtype = \"int32\"\n        elif dtype in (\"uint8\", \"uint16\"):\n            dtype = \"uint32\"\n\n        if backend.backend() == \"torch\" and dtype == \"uint32\":\n            dtype = \"int32\"\n\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=dtype,\n        )\n\n\n@keras_export([\"keras.ops.nanprod\", \"keras.ops.numpy.nanprod\"])\ndef nanprod(x, axis=None, keepdims=False):\n    \"\"\"Product of a tensor over the given axes, ignoring NaNs.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which the product is computed. 
The default is\n            to compute the product of the flattened tensor.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one.\n\n    Returns:\n        Output tensor containing the product, with NaN values ignored.\n\n    Examples:\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.array([[1.0, np.nan, 3.0],\n    ...               [np.nan, 2.0, 1.0]])\n    >>> ops.nanprod(x)\n    6.0\n\n    >>> ops.nanprod(x, axis=1)\n    array([3., 2.])\n\n    >>> ops.nanprod(x, axis=1, keepdims=True)\n    array([[3.],\n           [2.]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Nanprod(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.nanprod(x, axis=axis, keepdims=keepdims)\n\n\nclass Nanstd(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.nanstd(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        output_dtype = backend.result_type(getattr(x, \"dtype\", type(x)), float)\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=output_dtype,\n        )\n\n\n@keras_export([\"keras.ops.nanstd\", \"keras.ops.numpy.nanstd\"])\ndef nanstd(x, axis=None, keepdims=False):\n    \"\"\"Compute the standard deviation along the specified axes, ignoring NaNs.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which the standard deviation is computed.\n            The default is to compute the std of the flattened tensor.\n        keepdims: If `True`, the axes which are reduced are left\n            in the result as dimensions with size one. 
Defaults to `False`.\n\n    Returns:\n        Output tensor containing the standard deviation ignoring NaNs.\n\n    Examples:\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.array([[1.0, np.nan, 3.0],\n    ...               [np.nan, 2.0, 1.0]])\n\n    >>> ops.nanstd(x)\n    0.8291562\n\n    >>> ops.nanstd(x, axis=1)\n    array([1. , 0.5])\n\n    >>> ops.nanstd(x, axis=1, keepdims=True)\n    array([[1. ],\n           [0.5]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Nanstd(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.nanstd(x, axis=axis, keepdims=keepdims)\n\n\nclass Nansum(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.nansum(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        dtype = dtypes.result_type(getattr(x, \"dtype\", backend.floatx()))\n\n        if dtype in (\"bool\", \"int8\", \"int16\"):\n            dtype = \"int32\"\n        elif dtype in (\"uint8\", \"uint16\"):\n            dtype = \"uint32\"\n\n        if backend.backend() == \"torch\" and dtype == \"uint32\":\n            dtype = \"int32\"\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=dtype,\n            sparse=sparse,\n        )\n\n\n@keras_export([\"keras.ops.nansum\", \"keras.ops.numpy.nansum\"])\ndef nansum(x, axis=None, keepdims=False):\n    \"\"\"Sum of a tensor over the given axes, ignoring NaNs.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which the sum is computed. 
The default is to\n            compute the sum of the flattened tensor.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one.\n\n    Returns:\n        Output tensor containing the sum, with NaN values ignored.\n\n    Examples:\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.array([[1.0, np.nan, 3.0],\n    ...               [np.nan, 2.0, 1.0]])\n    >>> ops.nansum(x)\n    7.0\n\n    >>> ops.nansum(x, axis=1)\n    array([4., 3.])\n\n    >>> ops.nansum(x, axis=1, keepdims=True)\n    array([[4.],\n           [3.]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Nansum(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.nansum(x, axis=axis, keepdims=keepdims)\n\n\nclass Nanvar(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.nanvar(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        output_dtype = backend.result_type(getattr(x, \"dtype\", type(x)), float)\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=output_dtype,\n        )\n\n\n@keras_export([\"keras.ops.nanvar\", \"keras.ops.numpy.nanvar\"])\ndef nanvar(x, axis=None, keepdims=False):\n    \"\"\"Compute the variance along the specified axes, ignoring NaNs.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which the variance is computed. The default\n            is to compute the variance of the flattened tensor.\n        keepdims: If `True`, the axes which are reduced are left\n            in the result as dimensions with size one. 
Defaults to `False`.\n\n    Returns:\n        Output tensor containing the variance ignoring NaNs.\n\n    Examples:\n    >>> import numpy as np\n    >>> from keras import ops\n    >>> x = np.array([[1.0, np.nan, 3.0],\n    ...               [np.nan, 2.0, 1.0]])\n\n    >>> ops.nanvar(x)\n    0.6875\n\n    >>> ops.nanvar(x, axis=1)\n    array([1.  , 0.25])\n\n    >>> ops.nanvar(x, axis=1, keepdims=True)\n    array([[1.  ],\n           [0.25]])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Nanvar(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.nanvar(x, axis=axis, keepdims=keepdims)\n\n\nclass NanToNum(Operation):\n    def __init__(self, nan=0.0, posinf=None, neginf=None, *, name=None):\n        super().__init__(name=name)\n        self.nan = nan\n        self.posinf = posinf\n        self.neginf = neginf\n\n    def call(self, x):\n        return backend.numpy.nan_to_num(\n            x, nan=self.nan, posinf=self.posinf, neginf=self.neginf\n        )\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.nan_to_num\",\n        \"keras.ops.numpy.nan_to_num\",\n    ]\n)\ndef nan_to_num(x, nan=0.0, posinf=None, neginf=None):\n    \"\"\"Replace NaN with zero and infinity with large finite numbers.\n\n    Args:\n        x: Input data.\n        nan: Optional float or int. 
Value to replace `NaN` entries with.\n        posinf: Optional float or int.\n            Value to replace positive infinity with.\n        neginf: Optional float or int.\n            Value to replace negative infinity with.\n\n    Returns:\n        `x`, with non-finite values replaced.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return NanToNum(nan=nan, posinf=posinf, neginf=neginf).symbolic_call(x)\n    return backend.numpy.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf)\n\n\nclass Ndim(Operation):\n    def call(self, x):\n        return backend.numpy.ndim(\n            x,\n        )\n\n    def compute_output_spec(self, x):\n        return KerasTensor([len(x.shape)])\n\n\n@keras_export([\"keras.ops.ndim\", \"keras.ops.numpy.ndim\"])\ndef ndim(x):\n    \"\"\"Return the number of dimensions of a tensor.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        The number of dimensions in `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Ndim().symbolic_call(x)\n    return backend.numpy.ndim(x)\n\n\nclass Nonzero(Operation):\n    def call(self, x):\n        return backend.numpy.nonzero(x)\n\n    def compute_output_spec(self, x):\n        return tuple(\n            [KerasTensor((None,), dtype=\"int32\") for _ in range(len(x.shape))]\n        )\n\n\n@keras_export([\"keras.ops.nonzero\", \"keras.ops.numpy.nonzero\"])\ndef nonzero(x):\n    \"\"\"Return the indices of the elements that are non-zero.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Indices of elements that are non-zero.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Nonzero().symbolic_call(x)\n    return backend.numpy.nonzero(x)\n\n\nclass NotEqual(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.not_equal(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, 
x2_shape)\n        return KerasTensor(output_shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.not_equal\", \"keras.ops.numpy.not_equal\"])\ndef not_equal(x1, x2):\n    \"\"\"Return `(x1 != x2)` element-wise.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        Output tensor, element-wise comparison of `x1` and `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return NotEqual().symbolic_call(x1, x2)\n    return backend.numpy.not_equal(x1, x2)\n\n\nclass OnesLike(Operation):\n    def __init__(self, dtype=None, *, name=None):\n        super().__init__(name=name)\n        self.dtype = dtype\n\n    def call(self, x):\n        return backend.numpy.ones_like(x, dtype=self.dtype)\n\n    def compute_output_spec(self, x):\n        dtype = (\n            backend.standardize_dtype(x.dtype)\n            if self.dtype is None\n            else self.dtype\n        )\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.ones_like\", \"keras.ops.numpy.ones_like\"])\ndef ones_like(x, dtype=None):\n    \"\"\"Return a tensor of ones with the same shape and type as `x`.\n\n    Args:\n        x: Input tensor.\n        dtype: Overrides the data type of the result.\n\n    Returns:\n        A tensor of ones with the same shape and type as `x`.\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x,)):\n        return OnesLike(dtype=dtype).symbolic_call(x)\n    return backend.numpy.ones_like(x, dtype=dtype)\n\n\nclass ZerosLike(Operation):\n    def __init__(self, dtype=None, *, name=None):\n        super().__init__(name=name)\n        self.dtype = dtype\n\n    def call(self, x):\n        return backend.numpy.zeros_like(x, dtype=self.dtype)\n\n    def compute_output_spec(self, x):\n        dtype = (\n            backend.standardize_dtype(x.dtype)\n         
   if self.dtype is None\n            else self.dtype\n        )\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export(\n    [\n        \"keras.ops.zeros_like\",\n        \"keras.ops.numpy.zeros_like\",\n    ]\n)\ndef zeros_like(x, dtype=None):\n    \"\"\"Return a tensor of zeros with the same shape and type as `x`.\n\n    Args:\n        x: Input tensor.\n        dtype: Overrides the data type of the result.\n\n    Returns:\n        A tensor of zeros with the same shape and type as `x`.\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x,)):\n        return ZerosLike(dtype=dtype).symbolic_call(x)\n    return backend.numpy.zeros_like(x, dtype=dtype)\n\n\nclass Outer(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.outer(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [1])\n        x2_shape = getattr(x2, \"shape\", [1])\n        if None in x1_shape:\n            x1_flatten_shape = None\n        else:\n            x1_flatten_shape = int(np.prod(x1_shape))\n        if None in x2_shape:\n            x2_flatten_shape = None\n        else:\n            x2_flatten_shape = int(np.prod(x2_shape))\n        output_shape = [x1_flatten_shape, x2_flatten_shape]\n        output_dtype = backend.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        return KerasTensor(output_shape, dtype=output_dtype)\n\n\n@keras_export([\"keras.ops.outer\", \"keras.ops.numpy.outer\"])\ndef outer(x1, x2):\n    \"\"\"Compute the outer product of two vectors.\n\n    Given two vectors `x1` and `x2`, the outer product is:\n\n    ```\n    out[i, j] = x1[i] * x2[j]\n    ```\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        Outer product of `x1` and `x2`.\n    \"\"\"\n    
if any_symbolic_tensors((x1, x2)):\n        return Outer().symbolic_call(x1, x2)\n    return backend.numpy.outer(x1, x2)\n\n\nclass Pad(Operation):\n    def __init__(self, pad_width, mode=\"constant\", *, name=None):\n        super().__init__(name=name)\n        self.pad_width = self._process_pad_width(pad_width)\n        self.mode = mode\n\n    def _process_pad_width(self, pad_width):\n        if isinstance(pad_width, int):\n            return ((pad_width, pad_width),)\n        if isinstance(pad_width, (tuple, list)) and isinstance(\n            pad_width[0], int\n        ):\n            return (pad_width,)\n        # Convert to a list so that length-1 entries can be expanded in\n        # place below (a tuple does not support item assignment).\n        pad_width = list(pad_width)\n        first_len = len(pad_width[0])\n        for i, pw in enumerate(pad_width):\n            if len(pw) != first_len:\n                raise ValueError(\n                    \"`pad_width` should be a list of tuples of length \"\n                    f\"1 or 2. Received: pad_width={pad_width}\"\n                )\n            if len(pw) == 1:\n                pad_width[i] = (pw[0], pw[0])\n        return pad_width\n\n    def call(self, x, constant_values=None):\n        if len(self.pad_width) > 1 and len(self.pad_width) != len(x.shape):\n            raise ValueError(\n                \"`pad_width` must have the same length as `x.shape`. 
\"\n                f\"Received: pad_width={self.pad_width} \"\n                f\"(of length {len(self.pad_width)}) and x.shape={x.shape} \"\n                f\"(of length {len(x.shape)})\"\n            )\n        return backend.numpy.pad(\n            x,\n            pad_width=self.pad_width,\n            mode=self.mode,\n            constant_values=constant_values,\n        )\n\n    def compute_output_spec(self, x, constant_values=None):\n        output_shape = list(x.shape)\n        if len(self.pad_width) == 1:\n            pad_width = [self.pad_width[0] for _ in range(len(output_shape))]\n        elif len(self.pad_width) == len(output_shape):\n            pad_width = self.pad_width\n        else:\n            raise ValueError(\n                \"`pad_width` must have the same length as `x.shape`. \"\n                f\"Received: pad_width={self.pad_width} \"\n                f\"(of length {len(self.pad_width)}) and x.shape={x.shape} \"\n                f\"(of length {len(x.shape)})\"\n            )\n\n        for i in range(len(output_shape)):\n            # Unknown (None) dimensions stay unknown after padding.\n            if output_shape[i] is not None:\n                output_shape[i] += pad_width[i][0] + pad_width[i][1]\n        return KerasTensor(output_shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.pad\", \"keras.ops.numpy.pad\"])\ndef pad(x, pad_width, mode=\"constant\", constant_values=None):\n    \"\"\"Pad a tensor.\n\n    Args:\n        x: Tensor to pad.\n        pad_width: Number of values padded to the edges of each axis.\n            `((before_1, after_1), ...(before_N, after_N))` unique pad\n            widths for each axis.\n            `((before, after),)` yields same before and after pad for\n            each axis.\n            `(pad,)` or `int` is a shortcut for `before = after = pad`\n            width for all axes.\n        mode: One of `\"constant\"`, `\"edge\"`, `\"linear_ramp\"`,\n            `\"maximum\"`, `\"mean\"`, `\"median\"`, 
`\"minimum\"`,\n            `\"reflect\"`, `\"symmetric\"`, `\"wrap\"`, `\"empty\"`,\n            `\"circular\"`. Defaults to `\"constant\"`.\n        constant_values: Value to pad with if `mode == \"constant\"`.\n            Defaults to `0`. A `ValueError` is raised if it is not `None`\n            and `mode != \"constant\"`.\n\n    Note:\n        Torch backend only supports modes `\"constant\"`, `\"reflect\"`,\n        `\"symmetric\"` and `\"circular\"`; `\"circular\"` mode is supported\n        by the Torch backend only.\n\n    Note:\n        TensorFlow backend only supports modes `\"constant\"`, `\"reflect\"`\n        and `\"symmetric\"`.\n\n    Returns:\n        Padded tensor.\n    \"\"\"\n    return Pad(pad_width, mode=mode)(x, constant_values=constant_values)\n\n\nclass Prod(Operation):\n    def __init__(self, axis=None, keepdims=False, dtype=None, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            self.axis = [axis]\n        else:\n            self.axis = axis\n        self.keepdims = keepdims\n        self.dtype = dtype\n\n    def call(self, x):\n        return backend.numpy.prod(\n            x,\n            axis=self.axis,\n            keepdims=self.keepdims,\n            dtype=self.dtype,\n        )\n\n    def compute_output_spec(self, x):\n        if self.dtype is not None:\n            dtype = self.dtype\n        else:\n            dtype = backend.standardize_dtype(x.dtype)\n            if dtype == \"bool\":\n                dtype = \"int32\"\n            elif dtype in (\"int8\", \"int16\"):\n                dtype = \"int32\"\n            elif dtype in (\"uint8\", \"uint16\"):\n                dtype = \"uint32\"\n        # TODO: torch doesn't support uint32\n        if backend.backend() == \"torch\" and dtype == \"uint32\":\n            dtype = \"int32\"\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=dtype,\n        
)\n\n\n@keras_export([\"keras.ops.prod\", \"keras.ops.numpy.prod\"])\ndef prod(x, axis=None, keepdims=False, dtype=None):\n    \"\"\"Return the product of tensor elements over a given axis.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which a product is performed. The default,\n            `axis=None`, will compute the product of all elements\n            in the input tensor.\n        keepdims: If this is set to `True`, the axes which are reduced\n            are left in the result as dimensions with size one.\n        dtype: Data type of the returned tensor.\n\n    Returns:\n        Product of elements of `x` over the given axis or axes.\n    \"\"\"\n    dtype = None if dtype is None else backend.standardize_dtype(dtype)\n    if any_symbolic_tensors((x,)):\n        return Prod(axis=axis, keepdims=keepdims, dtype=dtype).symbolic_call(x)\n    return backend.numpy.prod(x, axis=axis, keepdims=keepdims, dtype=dtype)\n\n\nclass Ptp(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.ptp(\n            x,\n            axis=self.axis,\n            keepdims=self.keepdims,\n        )\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(x.dtype)\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=dtype,\n        )\n\n\n@keras_export([\"keras.ops.ptp\", \"keras.ops.numpy.ptp\"])\ndef ptp(x, axis=None, keepdims=False):\n    \"\"\"Return the peak-to-peak (max - min) value of tensor elements\n    over a given axis.\n\n    The peak-to-peak value is defined as the difference between the\n    maximum and minimum values along the specified axis.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which the peak-to-peak value is computed.\n            The 
default, `axis=None`, will compute the peak-to-peak value\n            over all elements in the input tensor.\n        keepdims: If this is set to `True`, the axes which are reduced\n            are left in the result as dimensions with size one.\n\n    Returns:\n        A tensor containing the peak-to-peak values of `x` over the\n        given axis or axes.\n\n    Examples:\n    >>> x = keras.ops.array([[1., 3., 2.],\n    ...                      [4., 0., 5.]])\n\n    >>> # Peak-to-peak over all elements\n    >>> keras.ops.ptp(x)\n    5.0\n\n    >>> # Peak-to-peak along axis 1\n    >>> keras.ops.ptp(x, axis=1)\n    array([2., 5.], dtype=float32)\n\n    >>> # Peak-to-peak over multiple axes\n    >>> x = keras.ops.reshape(x, (1, 2, 3))\n    >>> keras.ops.ptp(x, axis=(1, 2))\n    array([5.], dtype=float32)\n\n    >>> # Keep reduced dimensions\n    >>> keras.ops.ptp(x, axis=2, keepdims=True)\n    array([[[2.],\n            [5.]]], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Ptp(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.ptp(x, axis=axis, keepdims=keepdims)\n\n\nclass Quantile(Operation):\n    def __init__(\n        self, axis=None, method=\"linear\", keepdims=False, *, name=None\n    ):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            axis = [axis]\n        self.axis = axis\n        self.method = method\n        self.keepdims = keepdims\n\n    def call(self, x, q):\n        return backend.numpy.quantile(\n            x, q, axis=self.axis, keepdims=self.keepdims\n        )\n\n    def compute_output_spec(self, x, q):\n        output_shape = reduce_shape(\n            x.shape, axis=self.axis, keepdims=self.keepdims\n        )\n        if hasattr(q, \"shape\"):\n            if len(q.shape) > 0:\n                output_shape = (q.shape[0],) + output_shape\n        if backend.standardize_dtype(x.dtype) == \"int64\":\n            dtype = backend.floatx()\n        else:\n      
      dtype = dtypes.result_type(x.dtype, float)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.quantile\", \"keras.ops.numpy.quantile\"])\ndef quantile(x, q, axis=None, method=\"linear\", keepdims=False):\n    \"\"\"Compute the q-th quantile(s) of the data along the specified axis.\n\n    Args:\n        x: Input tensor.\n        q: Probability or sequence of probabilities for the quantiles to\n            compute. Values must be between 0 and 1 inclusive.\n        axis: Axis or axes along which the quantiles are computed. Defaults to\n            `axis=None`, which is to compute the quantile(s) along a flattened\n            version of the array.\n        method: A string specifying the method to use for estimating the\n            quantile. Available methods are `\"linear\"`, `\"lower\"`, `\"higher\"`,\n            `\"midpoint\"`, and `\"nearest\"`. Defaults to `\"linear\"`.\n            If the desired quantile lies between two data points `i < j`:\n            - `\"linear\"`: `i + (j - i) * fraction`, where `fraction` is the\n                fractional part of the index surrounded by `i` and `j`.\n            - `\"lower\"`: `i`.\n            - `\"higher\"`: `j`.\n            - `\"midpoint\"`: `(i + j) / 2`\n            - `\"nearest\"`: `i` or `j`, whichever is nearest.\n        keepdims: If this is set to `True`, the axes which are reduced\n            are left in the result as dimensions with size one.\n\n    Returns:\n        The quantile(s). If `q` is a single probability and `axis=None`, then\n        the result is a scalar. If multiple probability levels are given, the\n        first axis of the result corresponds to the quantiles. 
The other axes\n        are the axes that remain after the reduction of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x, q)):\n        return Quantile(\n            axis=axis, method=method, keepdims=keepdims\n        ).symbolic_call(x, q)\n    return backend.numpy.quantile(\n        x, q, axis=axis, method=method, keepdims=keepdims\n    )\n\n\nclass Ravel(Operation):\n    def call(self, x):\n        return backend.numpy.ravel(x)\n\n    def compute_output_spec(self, x):\n        if None in x.shape:\n            output_shape = [\n                None,\n            ]\n        else:\n            output_shape = [int(np.prod(x.shape))]\n        return KerasTensor(output_shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.ravel\", \"keras.ops.numpy.ravel\"])\ndef ravel(x):\n    \"\"\"Return a contiguous flattened tensor.\n\n    A 1-D tensor, containing the elements of the input, is returned.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Ravel().symbolic_call(x)\n    return backend.numpy.ravel(x)\n\n\nclass UnravelIndex(Operation):\n    def __init__(self, shape, *, name=None):\n        super().__init__(name=name)\n        self.shape = shape\n\n    def call(self, indices):\n        return backend.numpy.unravel_index(indices, self.shape)\n\n    def compute_output_spec(self, indices):\n        if None in self.shape:\n            output_shapes = [[None] for _ in self.shape]\n        else:\n            if isinstance(indices, int):\n                output_shapes = [[1] for _ in self.shape]\n            elif hasattr(indices, \"shape\"):\n                output_shapes = [list(indices.shape) for _ in self.shape]\n            else:\n                try:\n                    indices_shape = np.shape(indices)\n                    output_shapes = [list(indices_shape) for _ in self.shape]\n                except Exception:\n                    output_shapes = [[None] for _ in 
self.shape]\n\n        return [\n            KerasTensor(shape, dtype=indices.dtype) for shape in output_shapes\n        ]\n\n\n@keras_export([\"keras.ops.unravel_index\", \"keras.ops.numpy.unravel_index\"])\ndef unravel_index(indices, shape):\n    \"\"\"Convert flat indices to coordinate arrays in a given array shape.\n\n    Args:\n        indices: An integer or array of integers representing flat indices.\n        shape: The shape of the array to unravel into.\n\n    Returns:\n        Tuple of arrays for each dimension with unraveled indices.\n\n    Example:\n    >>> indices = 5\n    >>> shape = (3, 3)\n    >>> unravel_index(indices, shape)\n    (1, 2)  # 5 is at row 1, column 2 in a 3x3 array\n    \"\"\"\n    if any_symbolic_tensors((indices,)):\n        return UnravelIndex(shape).symbolic_call(indices)\n\n    return backend.numpy.unravel_index(indices, shape)\n\n\nclass Real(Operation):\n    def call(self, x):\n        return backend.numpy.real(x)\n\n    def compute_output_spec(self, x):\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.real\", \"keras.ops.numpy.real\"])\ndef real(x):\n    \"\"\"Return the real part of the complex argument.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        The real component of the complex argument.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Real().symbolic_call(x)\n    return backend.numpy.real(x)\n\n\nclass Reciprocal(Operation):\n    def call(self, x):\n        return backend.numpy.reciprocal(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape)\n\n\n@keras_export(\n    [\n        \"keras.ops.reciprocal\",\n        \"keras.ops.numpy.reciprocal\",\n    ]\n)\ndef reciprocal(x):\n    \"\"\"Return the reciprocal of the argument, element-wise.\n\n    Calculates `1/x`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor, element-wise reciprocal of 
`x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Reciprocal().symbolic_call(x)\n    return backend.numpy.reciprocal(x)\n\n\nclass Repeat(Operation):\n    def __init__(self, repeats, axis=None, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n        self.repeats = repeats\n\n    def call(self, x):\n        return backend.numpy.repeat(x, self.repeats, axis=self.axis)\n\n    def compute_output_spec(self, x):\n        x_shape = list(x.shape)\n        repeats = self.repeats\n        if isinstance(repeats, int):\n            repeats = [repeats]\n        repeats_size = len(repeats)\n        broadcast = repeats_size == 1\n\n        if self.axis is None:\n            if None in x_shape:\n                return KerasTensor([None], dtype=x.dtype)\n\n            x_flatten_size = int(np.prod(x_shape))\n            if broadcast:\n                output_shape = [x_flatten_size * repeats[0]]\n            elif repeats_size != x_flatten_size:\n                raise ValueError(\n                    \"Size of `repeats` and \"\n                    \"dimensions of `x` after flattening should be compatible. \"\n                    f\"Received: {repeats_size} and {x_flatten_size}\"\n                )\n            else:\n                output_shape = [int(np.sum(repeats))]\n            return KerasTensor(output_shape, dtype=x.dtype)\n\n        size_on_ax = x_shape[self.axis]\n        if size_on_ax is None:\n            return KerasTensor(x_shape, dtype=x.dtype)\n\n        output_shape = x_shape\n        if broadcast:\n            output_shape[self.axis] = size_on_ax * repeats[0]\n        elif size_on_ax != repeats_size:\n            raise ValueError(\n                \"Size of `repeats` and \"\n                f\"dimensions of `axis {self.axis} of x` should be compatible. 
\"\n                f\"Received: {repeats_size} and {x_shape}\"\n            )\n        else:\n            output_shape[self.axis] = int(np.sum(repeats))\n        return KerasTensor(output_shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.repeat\", \"keras.ops.numpy.repeat\"])\ndef repeat(x, repeats, axis=None):\n    \"\"\"Repeat each element of a tensor after itself.\n\n    Args:\n        x: Input tensor.\n        repeats: The number of repetitions for each element.\n        axis: The axis along which to repeat values. By default, use\n            the flattened input array, and return a flat output array.\n\n    Returns:\n        Output tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Repeat(repeats, axis=axis).symbolic_call(x)\n    return backend.numpy.repeat(x, repeats, axis=axis)\n\n\nclass Reshape(Operation):\n    def __init__(self, newshape, *, name=None):\n        super().__init__(name=name)\n        self.newshape = newshape\n\n    def call(self, x):\n        return backend.numpy.reshape(x, self.newshape)\n\n    def compute_output_spec(self, x):\n        output_shape = operation_utils.compute_reshape_output_shape(\n            x.shape, self.newshape, \"newshape\"\n        )\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(output_shape, dtype=x.dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.reshape\", \"keras.ops.numpy.reshape\"])\ndef reshape(x, newshape):\n    \"\"\"Gives a new shape to a tensor without changing its data.\n\n    Args:\n        x: Input tensor.\n        newshape: The new shape should be compatible with the original shape.\n            One shape dimension can be -1, in which case the value is\n            inferred from the length of the array and remaining dimensions.\n\n    Returns:\n        The reshaped tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Reshape(newshape).symbolic_call(x)\n    return backend.numpy.reshape(x, newshape)\n\n\nclass 
Roll(Operation):\n    def __init__(self, shift, axis=None, *, name=None):\n        super().__init__(name=name)\n        self.shift = shift\n        self.axis = axis\n\n    def call(self, x):\n        return backend.numpy.roll(x, self.shift, self.axis)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.roll\", \"keras.ops.numpy.roll\"])\ndef roll(x, shift, axis=None):\n    \"\"\"Roll tensor elements along a given axis.\n\n    Elements that roll beyond the last position are re-introduced at the first.\n\n    Args:\n        x: Input tensor.\n        shift: The number of places by which elements are shifted.\n        axis: The axis along which elements are shifted. By default, the\n            array is flattened before shifting, after which the original\n            shape is restored.\n\n    Returns:\n        Output tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Roll(shift, axis=axis).symbolic_call(x)\n    return backend.numpy.roll(x, shift, axis=axis)\n\n\nclass Round(Operation):\n    def __init__(self, decimals=0, *, name=None):\n        super().__init__(name=name)\n        self.decimals = decimals\n\n    def call(self, x):\n        return backend.numpy.round(x, self.decimals)\n\n    def compute_output_spec(self, x):\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.round\", \"keras.ops.numpy.round\"])\ndef round(x, decimals=0):\n    \"\"\"Evenly round to the given number of decimals.\n\n    Args:\n        x: Input tensor.\n        decimals: Number of decimal places to round to. 
Defaults to `0`.\n\n    Returns:\n        Output tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Round(decimals).symbolic_call(x)\n    return backend.numpy.round(x, decimals)\n\n\nclass SearchSorted(Operation):\n    def __init__(self, side=\"left\", *, name=None):\n        super().__init__(name=name)\n        self.side = side\n\n    def call(self, sorted_sequence, values):\n        sorted_sequence = backend.convert_to_tensor(sorted_sequence)\n        values = backend.convert_to_tensor(values)\n        return backend.numpy.searchsorted(\n            sorted_sequence, values, side=self.side\n        )\n\n    def compute_output_spec(self, sorted_sequence, values):\n        if len(sorted_sequence.shape) != 1:\n            raise ValueError(\n                \"searchsorted only supports 1-D sorted sequences. Use \"\n                \"keras.ops.vectorized_map to extend to N-D sequences.\"\n            )\n        sequence_len = sorted_sequence.shape[0]\n        out_type = (\n            \"int32\"\n            if sequence_len is not None\n            and sequence_len <= np.iinfo(np.int32).max\n            else \"int64\"\n        )\n        return KerasTensor(values.shape, dtype=out_type)\n\n\n@keras_export([\"keras.ops.searchsorted\", \"keras.ops.numpy.searchsorted\"])\ndef searchsorted(sorted_sequence, values, side=\"left\"):\n    \"\"\"Perform a binary search, returning indices for insertion of `values`\n    into `sorted_sequence` that maintain the sorting order.\n\n    Args:\n        sorted_sequence: 1-D input tensor, sorted along the innermost\n            dimension.\n        values: N-D tensor of query insertion values.\n        side: 'left' or 'right', specifying the direction in which to insert\n            for the equality case (tie-breaker).\n\n    Returns:\n        Tensor of insertion indices of same shape as `values`.\n    \"\"\"\n    if any_symbolic_tensors((sorted_sequence, values)):\n        return 
SearchSorted(side=side).symbolic_call(sorted_sequence, values)\n\n    sorted_sequence = backend.convert_to_tensor(sorted_sequence)\n    values = backend.convert_to_tensor(values)\n    return backend.numpy.searchsorted(sorted_sequence, values, side=side)\n\n\nclass Sign(Operation):\n    def call(self, x):\n        return backend.numpy.sign(x)\n\n    def compute_output_spec(self, x):\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.sign\", \"keras.ops.numpy.sign\"])\ndef sign(x):\n    \"\"\"Returns a tensor with the signs of the elements of `x`.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor of same shape as `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Sign().symbolic_call(x)\n    return backend.numpy.sign(x)\n\n\nclass Signbit(Operation):\n    def call(self, x):\n        return backend.numpy.signbit(x)\n\n    def compute_output_spec(self, x):\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=\"bool\", sparse=sparse)\n\n\n@keras_export([\"keras.ops.signbit\", \"keras.ops.numpy.signbit\"])\ndef signbit(x):\n    \"\"\"Return the sign bit of the elements of `x`.\n\n    The output boolean tensor contains `True` where the sign of `x` is negative,\n    and `False` otherwise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output boolean tensor of same shape as `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Signbit().symbolic_call(x)\n    return backend.numpy.signbit(x)\n\n\nclass Sin(Operation):\n    def call(self, x):\n        return backend.numpy.sin(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        sparse = getattr(x, 
\"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.sin\", \"keras.ops.numpy.sin\"])\ndef sin(x):\n    \"\"\"Trigonometric sine, element-wise.\n\n    Arguments:\n        x: Input tensor.\n\n    Returns:\n        Output tensor of same shape as `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Sin().symbolic_call(x)\n    return backend.numpy.sin(x)\n\n\nclass Sinc(Operation):\n    def call(self, x):\n        return backend.numpy.sinc(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.sinc\", \"keras.ops.numpy.sinc\"])\ndef sinc(x):\n    \"\"\"Return the normalized sinc function.\n\n    The sinc function is equal to `sin(pi*x) / (pi*x)` for any argument\n    `x != 0`, and `sinc(0)` takes the limit value 1, making `sinc` not\n    just everywhere continuous but also infinitely differentiable.\n\n    Arguments:\n        x: Input tensor.\n\n    Returns:\n        Output tensor of same shape as `x`.\n\n    Examples:\n    >>> x = keras.ops.convert_to_tensor([0.0, 1.0, 2.0])\n    >>> keras.ops.sinc(x)\n    array([1., 0., 0.], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Sinc().symbolic_call(x)\n    return backend.numpy.sinc(x)\n\n\nclass Sinh(Operation):\n    def call(self, x):\n        return backend.numpy.sinh(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, 
dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.sinh\", \"keras.ops.numpy.sinh\"])\ndef sinh(x):\n    \"\"\"Hyperbolic sine, element-wise.\n\n    Arguments:\n        x: Input tensor.\n\n    Returns:\n        Output tensor of same shape as `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Sinh().symbolic_call(x)\n    return backend.numpy.sinh(x)\n\n\nclass Size(Operation):\n    def call(self, x):\n        return backend.numpy.size(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor([], dtype=\"int32\")\n\n\n@keras_export([\"keras.ops.size\", \"keras.ops.numpy.size\"])\ndef size(x):\n    \"\"\"Return the number of elements in a tensor.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Number of elements in `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Size().symbolic_call(x)\n    return backend.numpy.size(x)\n\n\nclass Sort(Operation):\n    def __init__(self, axis=-1, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n\n    def call(self, x):\n        return backend.numpy.sort(x, axis=self.axis)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, x.dtype)\n\n\n@keras_export([\"keras.ops.sort\", \"keras.ops.numpy.sort\"])\ndef sort(x, axis=-1):\n    \"\"\"Sorts the elements of `x` along a given axis in ascending order.\n\n    Args:\n        x: Input tensor.\n        axis: Axis along which to sort. If `None`, the tensor is flattened\n            before sorting. 
Defaults to `-1`; the last axis.\n\n    Returns:\n        Sorted tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Sort(axis=axis).symbolic_call(x)\n    return backend.numpy.sort(x, axis=axis)\n\n\ndef _compute_split_output_spec(x, indices_or_sections, axis):\n    x_shape = list(x.shape)\n    x_size_on_axis = x_shape[axis]\n    if isinstance(indices_or_sections, int):\n        if x_size_on_axis is None:\n            x_shape[axis] = None\n            return [\n                KerasTensor(x_shape, dtype=x.dtype)\n                for _ in range(indices_or_sections)\n            ]\n\n        if np.mod(x_size_on_axis, indices_or_sections) != 0:\n            raise ValueError(\n                \"`x` size on given `axis` must be divisible by \"\n                \"`indices_or_sections` when `indices_or_sections` is an \"\n                f\"int. But received {x_size_on_axis} and \"\n                f\"{indices_or_sections}.\"\n            )\n\n        size = x_size_on_axis // indices_or_sections\n        x_shape[axis] = size\n        return [\n            KerasTensor(x_shape, dtype=x.dtype)\n            for _ in range(indices_or_sections)\n        ]\n\n    all_indices = [0] + list(indices_or_sections) + [x_size_on_axis]\n    outputs = []\n\n    for i in range(len(all_indices) - 1):\n        start = all_indices[i]\n        end = all_indices[i + 1]\n        if start is None or end is None:\n            output_size = None\n        else:\n            output_size = end - start\n        output_shape = list(x_shape)\n        output_shape[axis] = output_size\n        outputs.append(KerasTensor(output_shape, dtype=x.dtype))\n\n    return outputs\n\n\nclass Split(Operation):\n    def __init__(self, indices_or_sections, axis=0, *, name=None):\n        super().__init__(name=name)\n        if not isinstance(indices_or_sections, int):\n            indices_or_sections = tuple(indices_or_sections)\n        self.indices_or_sections = indices_or_sections\n        
self.axis = axis\n\n    def call(self, x):\n        return backend.numpy.split(x, self.indices_or_sections, axis=self.axis)\n\n    def compute_output_spec(self, x):\n        return _compute_split_output_spec(\n            x, self.indices_or_sections, self.axis\n        )\n\n\n@keras_export([\"keras.ops.split\", \"keras.ops.numpy.split\"])\ndef split(x, indices_or_sections, axis=0):\n    \"\"\"Split a tensor into chunks.\n\n    Args:\n        x: Input tensor.\n        indices_or_sections: If an integer, N, the tensor will be split into N\n            equal sections along `axis`. If a 1-D array of sorted integers,\n            the entries indicate indices at which the tensor will be split\n            along `axis`.\n        axis: Axis along which to split. Defaults to `0`.\n\n    Note:\n        A split does not have to result in equal division when using\n        Torch backend.\n\n    Returns:\n        A list of tensors.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Split(indices_or_sections, axis=axis).symbolic_call(x)\n    return backend.numpy.split(x, indices_or_sections, axis=axis)\n\n\nclass Stack(Operation):\n    def __init__(self, axis=0, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n\n    def call(self, x):\n        return backend.numpy.stack(x, axis=self.axis)\n\n    def compute_output_spec(self, x):\n        first_shape = x[0].shape\n        dtypes_to_resolve = []\n        for a in x:\n            if not shape_equal(a.shape, first_shape, axis=[], allow_none=True):\n                raise ValueError(\n                    \"Every value in `x` must have the same shape. 
But found \"\n                    f\"element of shape {a.shape}, which is different from the \"\n                    f\"first element's shape {first_shape}.\"\n                )\n            dtypes_to_resolve.append(getattr(a, \"dtype\", type(a)))\n\n        size_on_axis = len(x)\n        output_shape = list(first_shape)\n        if self.axis == -1:\n            output_shape = output_shape + [size_on_axis]\n        elif self.axis >= 0:\n            output_shape.insert(self.axis, size_on_axis)\n        else:\n            output_shape.insert(self.axis + 1, size_on_axis)\n        output_dtype = dtypes.result_type(*dtypes_to_resolve)\n        return KerasTensor(output_shape, dtype=output_dtype)\n\n\n@keras_export([\"keras.ops.stack\", \"keras.ops.numpy.stack\"])\ndef stack(x, axis=0):\n    \"\"\"Join a sequence of tensors along a new axis.\n\n    The `axis` parameter specifies the index of the new axis in the\n    dimensions of the result.\n\n    Args:\n        x: A sequence of tensors.\n        axis: Axis along which to stack. 
Defaults to `0`.\n\n    Returns:\n        The stacked tensor.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Stack(axis=axis).symbolic_call(x)\n    return backend.numpy.stack(x, axis=axis)\n\n\nclass Std(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            self.axis = [axis]\n        else:\n            self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.std(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        output_dtype = backend.standardize_dtype(x.dtype)\n        if \"int\" in output_dtype or output_dtype == \"bool\":\n            output_dtype = backend.floatx()\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=output_dtype,\n        )\n\n\n@keras_export([\"keras.ops.std\", \"keras.ops.numpy.std\"])\ndef std(x, axis=None, keepdims=False):\n    \"\"\"Compute the standard deviation along the specified axis.\n\n    Args:\n        x: Input tensor.\n        axis: Axis along which to compute standard deviation.\n            Default is to compute the standard deviation of the\n            flattened tensor.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one.\n\n    Returns:\n        Output tensor containing the standard deviation values.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Std(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.std(x, axis=axis, keepdims=keepdims)\n\n\nclass Swapaxes(Operation):\n    def __init__(self, axis1, axis2, *, name=None):\n        super().__init__(name=name)\n\n        self.axis1 = axis1\n        self.axis2 = axis2\n\n    def call(self, x):\n        return backend.numpy.swapaxes(x, self.axis1, self.axis2)\n\n    def 
compute_output_spec(self, x):\n        x_shape = list(x.shape)\n        tmp = x_shape[self.axis1]\n        x_shape[self.axis1] = x_shape[self.axis2]\n        x_shape[self.axis2] = tmp\n        return KerasTensor(x_shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.swapaxes\", \"keras.ops.numpy.swapaxes\"])\ndef swapaxes(x, axis1, axis2):\n    \"\"\"Interchange two axes of a tensor.\n\n    Args:\n        x: Input tensor.\n        axis1: First axis.\n        axis2: Second axis.\n\n    Returns:\n        A tensor with the axes swapped.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Swapaxes(axis1, axis2).symbolic_call(x)\n    return backend.numpy.swapaxes(x, axis1=axis1, axis2=axis2)\n\n\nclass Take(Operation):\n    def __init__(self, axis=None, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n\n    def call(self, x, indices):\n        return backend.numpy.take(x, indices, axis=self.axis)\n\n    def compute_output_spec(self, x, indices):\n        x_shape = list(x.shape)\n        if isinstance(indices, KerasTensor):\n            indices_shape = list(indices.shape)\n            ragged = indices.ragged\n        else:\n            indices_shape = list(getattr(np.array(indices), \"shape\", []))\n            ragged = False\n        if self.axis is None:\n            return KerasTensor(indices_shape, dtype=x.dtype)\n\n        # make sure axis is non-negative\n        axis = len(x_shape) + self.axis if self.axis < 0 else self.axis\n        output_shape = x_shape[:axis] + indices_shape + x_shape[axis + 1 :]\n        return KerasTensor(output_shape, dtype=x.dtype, ragged=ragged)\n\n\n@keras_export([\"keras.ops.take\", \"keras.ops.numpy.take\"])\ndef take(x, indices, axis=None):\n    \"\"\"Take elements from a tensor along an axis.\n\n    Args:\n        x: Source tensor.\n        indices: The indices of the values to extract.\n        axis: The axis over which to select values. 
By default, the\n            flattened input tensor is used.\n\n    Returns:\n        The corresponding tensor of values.\n    \"\"\"\n    if any_symbolic_tensors((x, indices)):\n        return Take(axis=axis).symbolic_call(x, indices)\n    return backend.numpy.take(x, indices, axis=axis)\n\n\nclass TakeAlongAxis(Operation):\n    def __init__(self, axis=None, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n\n    def call(self, x, indices):\n        return backend.numpy.take_along_axis(x, indices, axis=self.axis)\n\n    def compute_output_spec(self, x, indices):\n        output_shape = operation_utils.compute_take_along_axis_output_shape(\n            x.shape, indices.shape, self.axis\n        )\n        return KerasTensor(output_shape, dtype=x.dtype)\n\n\n@keras_export(\n    [\n        \"keras.ops.take_along_axis\",\n        \"keras.ops.numpy.take_along_axis\",\n    ]\n)\ndef take_along_axis(x, indices, axis=None):\n    \"\"\"Select values from `x` at the 1-D `indices` along the given axis.\n\n    Args:\n        x: Source tensor.\n        indices: The indices of the values to extract.\n        axis: The axis over which to select values. 
By default, the flattened\n            input tensor is used.\n\n    Returns:\n        The corresponding tensor of values.\n    \"\"\"\n    if any_symbolic_tensors((x, indices)):\n        return TakeAlongAxis(axis=axis).symbolic_call(x, indices)\n    return backend.numpy.take_along_axis(x, indices, axis=axis)\n\n\nclass Tan(Operation):\n    def call(self, x):\n        return backend.numpy.tan(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.tan\", \"keras.ops.numpy.tan\"])\ndef tan(x):\n    \"\"\"Compute tangent, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor of same shape as `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Tan().symbolic_call(x)\n    return backend.numpy.tan(x)\n\n\nclass Tanh(Operation):\n    def call(self, x):\n        return backend.numpy.tanh(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = backend.floatx()\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.tanh\", \"keras.ops.numpy.tanh\"])\ndef tanh(x):\n    \"\"\"Hyperbolic tangent, element-wise.\n\n    Arguments:\n        x: Input tensor.\n\n    Returns:\n        Output tensor of same shape as `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Tanh().symbolic_call(x)\n    return backend.numpy.tanh(x)\n\n\nclass Tensordot(Operation):\n    def __init__(self, axes=2, 
*, name=None):\n        super().__init__(name=name)\n        self.axes = axes\n\n    def call(self, x1, x2):\n        return backend.numpy.tensordot(x1, x2, axes=self.axes)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = list(getattr(x1, \"shape\", []))\n        x2_shape = list(getattr(x2, \"shape\", []))\n        dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        if not isinstance(self.axes, int):\n            x1_select_shape = [x1_shape[ax] for ax in self.axes[0]]\n            x2_select_shape = [x2_shape[ax] for ax in self.axes[1]]\n            if not shape_equal(\n                x1_select_shape, x2_select_shape, allow_none=True\n            ):\n                raise ValueError(\n                    \"Shape mismatch on `x1[axes[0]]` and `x2[axes[1]]`, \"\n                    f\"received {x1_select_shape} and {x2_select_shape}.\"\n                )\n\n            for ax in self.axes[0]:\n                x1_shape[ax] = -1\n            for ax in self.axes[1]:\n                x2_shape[ax] = -1\n\n            x1_shape = list(filter((-1).__ne__, x1_shape))\n            x2_shape = list(filter((-1).__ne__, x2_shape))\n\n            output_shape = x1_shape + x2_shape\n            return KerasTensor(output_shape, dtype=dtype)\n\n        if self.axes <= 0:\n            output_shape = x1_shape + x2_shape\n        else:\n            output_shape = x1_shape[: -self.axes] + x2_shape[self.axes :]\n\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.tensordot\", \"keras.ops.numpy.tensordot\"])\ndef tensordot(x1, x2, axes=2):\n    \"\"\"Compute the tensor dot product along specified axes.\n\n    Args:\n        x1: First tensor.\n        x2: Second tensor.\n        axes: - If an integer, N, sum over the last N axes of `x1` and the\n                first N axes of `x2` in order. 
The sizes of the corresponding\n                axes must match.\n              - Or, a list of axes to be summed over, first sequence applying\n                to `x1`, second to `x2`. Both sequences must be of the\n                same length.\n\n    Returns:\n        The tensor dot product of the inputs.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Tensordot(axes=axes).symbolic_call(x1, x2)\n    return backend.numpy.tensordot(x1, x2, axes=axes)\n\n\nclass Tile(Operation):\n    def __init__(self, repeats, *, name=None):\n        super().__init__(name=name)\n        self.repeats = repeats\n\n    def call(self, x):\n        return backend.numpy.tile(x, self.repeats)\n\n    def compute_output_spec(self, x):\n        x_shape = list(x.shape)\n        repeats = self.repeats\n        if isinstance(repeats, int):\n            repeats = [repeats]\n        else:\n            repeats = list(repeats)\n\n        if len(x_shape) > len(repeats):\n            repeats = [1] * (len(x_shape) - len(repeats)) + repeats\n        else:\n            x_shape = [1] * (len(repeats) - len(x_shape)) + x_shape\n\n        output_shape = []\n        for x_size, repeat in zip(x_shape, repeats):\n            if isinstance(x_size, int):\n                output_shape.append(x_size * repeat)\n            else:\n                output_shape.append(None)\n        return KerasTensor(output_shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.tile\", \"keras.ops.numpy.tile\"])\ndef tile(x, repeats):\n    \"\"\"Repeat `x` the number of times given by `repeats`.\n\n    If `repeats` has length `d`, the result will have dimension of\n    `max(d, x.ndim)`.\n\n    If `x.ndim < d`, `x` is promoted to be d-dimensional by prepending\n    new axes.\n\n    If `x.ndim > d`, `repeats` is promoted to `x.ndim` by prepending 1's to it.\n\n    Args:\n        x: Input tensor.\n        repeats: The number of repetitions of `x` along each axis.\n\n    Returns:\n        The tiled output tensor.\n    
\"\"\"\n    if any_symbolic_tensors((x,)):\n        return Tile(\n            repeats,\n        ).symbolic_call(x)\n    return backend.numpy.tile(x, repeats)\n\n\nclass Trace(Operation):\n    def __init__(self, offset=0, axis1=0, axis2=1, *, name=None):\n        super().__init__(name=name)\n        self.offset = offset\n        self.axis1 = axis1\n        self.axis2 = axis2\n\n    def call(self, x):\n        return backend.numpy.trace(\n            x, offset=self.offset, axis1=self.axis1, axis2=self.axis2\n        )\n\n    def compute_output_spec(self, x):\n        x_shape = list(x.shape)\n        x_shape[self.axis1] = -1\n        x_shape[self.axis2] = -1\n        output_shape = list(filter((-1).__ne__, x_shape))\n        output_dtype = backend.standardize_dtype(x.dtype)\n        if output_dtype in (\"bool\", \"int8\", \"int16\"):\n            output_dtype = \"int32\"\n        elif output_dtype in (\"uint8\", \"uint16\"):\n            output_dtype = \"uint32\"\n        if output_dtype == \"uint32\" and backend.backend() == \"torch\":\n            # Torch backend doesn't support uint32 dtype.\n            output_dtype = \"int32\"\n        return KerasTensor(output_shape, dtype=output_dtype)\n\n\n@keras_export([\"keras.ops.trace\", \"keras.ops.numpy.trace\"])\ndef trace(x, offset=0, axis1=0, axis2=1):\n    \"\"\"Return the sum along diagonals of the tensor.\n\n    If `x` is 2-D, the sum along its diagonal with the given offset is\n    returned, i.e., the sum of elements `x[i, i+offset]` for all `i`.\n\n    If `x` has more than two dimensions, then the axes specified by `axis1`\n    and `axis2` are used to determine the 2-D sub-arrays whose traces are\n    returned.\n\n    The shape of the resulting tensor is the same as that of `x` with `axis1`\n    and `axis2` removed.\n\n    Args:\n        x: Input tensor.\n        offset: Offset of the diagonal from the main diagonal. Can be\n            both positive and negative. 
Defaults to `0`.\n        axis1: Axis to be used as the first axis of the 2-D sub-arrays.\n            Defaults to `0` (first axis).\n        axis2: Axis to be used as the second axis of the 2-D sub-arrays.\n            Defaults to `1` (second axis).\n\n    Returns:\n        If `x` is 2-D, the sum of the diagonal is returned. If `x` has\n        larger dimensions, then a tensor of sums along diagonals is\n        returned.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Trace(offset, axis1, axis2).symbolic_call(x)\n    return backend.numpy.trace(x, offset=offset, axis1=axis1, axis2=axis2)\n\n\n@keras_export([\"keras.ops.tri\", \"keras.ops.numpy.tri\"])\ndef tri(N, M=None, k=0, dtype=None):\n    \"\"\"Return a tensor with ones at and below a diagonal and zeros elsewhere.\n\n    Args:\n        N: Number of rows in the tensor.\n        M: Number of columns in the tensor.\n        k: The sub-diagonal at and below which the array is filled.\n            `k = 0` is the main diagonal, while `k < 0` is below it, and\n            `k > 0` is above. The default is 0.\n        dtype: Data type of the returned tensor. The default is \"float32\".\n\n    Returns:\n        Tensor with its lower triangle filled with ones and zeros elsewhere.\n        `T[i, j] == 1` for `j <= i + k`, 0 otherwise.\n    \"\"\"\n    return backend.numpy.tri(N, M=M, k=k, dtype=dtype)\n\n\nclass Tril(Operation):\n    def __init__(self, k=0, *, name=None):\n        super().__init__(name=name)\n        self.k = k\n\n    def call(self, x):\n        return backend.numpy.tril(x, k=self.k)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.tril\", \"keras.ops.numpy.tril\"])\ndef tril(x, k=0):\n    \"\"\"Return lower triangle of a tensor.\n\n    For tensors with `ndim` exceeding 2, `tril` will apply to the\n    final two axes.\n\n    Args:\n        x: Input tensor.\n        k: Diagonal above which to zero elements. 
Defaults to `0`, the\n            main diagonal. `k < 0` is below it, and `k > 0` is above it.\n\n    Returns:\n        Lower triangle of `x`, of same shape and data type as `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Tril(k=k).symbolic_call(x)\n    return backend.numpy.tril(x, k=k)\n\n\nclass Triu(Operation):\n    def __init__(self, k=0, *, name=None):\n        super().__init__(name=name)\n        self.k = k\n\n    def call(self, x):\n        return backend.numpy.triu(x, k=self.k)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.triu\", \"keras.ops.numpy.triu\"])\ndef triu(x, k=0):\n    \"\"\"Return upper triangle of a tensor.\n\n    For tensors with `ndim` exceeding 2, `triu` will apply to the\n    final two axes.\n\n    Args:\n        x: Input tensor.\n        k: Diagonal below which to zero elements. Defaults to `0`, the\n            main diagonal. `k < 0` is below it, and `k > 0` is above it.\n\n    Returns:\n        Upper triangle of `x`, of same shape and data type as `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Triu(k=k).symbolic_call(x)\n    return backend.numpy.triu(x, k=k)\n\n\nclass Trunc(Operation):\n    def call(self, x):\n        return backend.numpy.trunc(x)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.trunc\", \"keras.ops.numpy.trunc\"])\ndef trunc(x):\n    \"\"\"Return the truncated value of the input, element-wise.\n\n    The truncated value of the scalar `x` is the nearest integer `i` which is\n    closer to zero than `x` is. 
In short, the fractional part of the signed\n    number `x` is discarded.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        The truncated value of each element in `x`.\n\n    Example:\n    >>> x = ops.array([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0])\n    >>> ops.trunc(x)\n    array([-1.0, -1.0, -0.0, 0.0, 1.0, 1.0, 2.0])\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Trunc().symbolic_call(x)\n    return backend.numpy.trunc(x)\n\n\nclass Vdot(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.vdot(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        return KerasTensor([], dtype=dtype)\n\n\n@keras_export([\"keras.ops.vdot\", \"keras.ops.numpy.vdot\"])\ndef vdot(x1, x2):\n    \"\"\"Return the dot product of two vectors.\n\n    If the first argument is complex, the complex conjugate of the first\n    argument is used for the calculation of the dot product.\n\n    Multidimensional tensors are flattened before the dot product is taken.\n\n    Args:\n        x1: First input tensor. 
If complex, its complex conjugate is taken\n            before calculation of the dot product.\n        x2: Second input tensor.\n\n    Returns:\n        Output tensor.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Vdot().symbolic_call(x1, x2)\n    return backend.numpy.vdot(x1, x2)\n\n\nclass Inner(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.inner(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        return KerasTensor([], dtype=dtype)\n\n\n@keras_export([\"keras.ops.inner\", \"keras.ops.numpy.inner\"])\ndef inner(x1, x2):\n    \"\"\"Return the inner product of two tensors.\n\n    Ordinary inner product of vectors for 1-D tensors\n    (without complex conjugation), in higher dimensions\n    a sum product over the last axes.\n\n    Multidimensional arrays are treated as vectors by flattening\n    all but their last axes. The resulting dot product is performed\n    over their last axes.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor. The last dimension of `x1` and `x2`\n            must match.\n\n    Returns:\n        Output tensor. 
The shape of the output is determined by\n        broadcasting the shapes of `x1` and `x2` after removing\n        their last axes.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Inner().symbolic_call(x1, x2)\n    return backend.numpy.inner(x1, x2)\n\n\n@keras_export([\"keras.ops.vectorize\", \"keras.ops.numpy.vectorize\"])\ndef vectorize(pyfunc, *, excluded=None, signature=None):\n    \"\"\"Turn a function into a vectorized function.\n\n    Example:\n\n    ```python\n    def myfunc(a, b):\n        return a + b\n\n    vfunc = keras.ops.vectorize(myfunc)\n    y = vfunc([1, 2, 3, 4], 2)  # Returns Tensor([3, 4, 5, 6])\n    ```\n\n    Args:\n        pyfunc: Callable of a single tensor argument.\n        excluded: Optional set of integers representing\n            positional arguments for which the function\n            will not be vectorized.\n            These will be passed directly to `pyfunc` unmodified.\n        signature: Optional generalized universal function signature,\n            e.g., `\"(m,n),(n)->(m)\"` for vectorized\n            matrix-vector multiplication. If provided,\n            `pyfunc` will be called with (and expected to return)\n            arrays with shapes given by the size of corresponding\n            core dimensions. By default, `pyfunc` is assumed\n            to take scalar tensors as input and output.\n\n    Returns:\n        A new function that applies `pyfunc` to every element\n        of its input along axis 0 (the batch axis).\n    \"\"\"\n    if not callable(pyfunc):\n        raise ValueError(\n            \"Expected argument `pyfunc` to be a callable. 
\"\n            f\"Received: pyfunc={pyfunc}\"\n        )\n    return backend.numpy.vectorize(\n        pyfunc, excluded=excluded, signature=signature\n    )\n\n\nclass Vstack(Operation):\n    def call(self, xs):\n        return backend.numpy.vstack(xs)\n\n    def compute_output_spec(self, xs):\n        first_shape = xs[0].shape\n        total_size_on_axis = 0\n        dtypes_to_resolve = []\n        for x in xs:\n            if not shape_equal(x.shape, first_shape, axis=[0], allow_none=True):\n                raise ValueError(\n                    \"Every value in `xs` must have the same shape except on \"\n                    f\"the `axis` dim. But found element of shape {x.shape}, \"\n                    f\"which is different from the first element's \"\n                    f\"shape {first_shape}.\"\n                )\n            if total_size_on_axis is None or x.shape[0] is None:\n                total_size_on_axis = None\n            else:\n                total_size_on_axis += x.shape[0]\n            dtypes_to_resolve.append(getattr(x, \"dtype\", type(x)))\n        output_shape = list(first_shape)\n        output_shape[0] = total_size_on_axis\n        output_dtype = dtypes.result_type(*dtypes_to_resolve)\n        return KerasTensor(output_shape, output_dtype)\n\n\n@keras_export([\"keras.ops.vstack\", \"keras.ops.numpy.vstack\"])\ndef vstack(xs):\n    \"\"\"Stack tensors in sequence vertically (row wise).\n\n    Args:\n        xs: Sequence of tensors.\n\n    Returns:\n        Tensor formed by stacking the given tensors.\n    \"\"\"\n    if any_symbolic_tensors((xs,)):\n        return Vstack().symbolic_call(xs)\n    return backend.numpy.vstack(xs)\n\n\nclass Vsplit(Operation):\n    def __init__(self, indices_or_sections, *, name=None):\n        super().__init__(name=name)\n        if not isinstance(indices_or_sections, int):\n            indices_or_sections = tuple(indices_or_sections)\n        self.indices_or_sections = indices_or_sections\n\n    def 
call(self, x):\n        return backend.numpy.vsplit(x, self.indices_or_sections)\n\n    def compute_output_spec(self, x):\n        if len(x.shape) < 2:\n            raise ValueError(\n                \"`vsplit` only works on arrays of at least 2 dimensions. \"\n                f\"Received array with shape {x.shape}.\"\n            )\n        return _compute_split_output_spec(x, self.indices_or_sections, 0)\n\n\n@keras_export([\"keras.ops.vsplit\", \"keras.ops.numpy.vsplit\"])\ndef vsplit(x, indices_or_sections):\n    \"\"\"Split an array into multiple sub-arrays vertically (row-wise).\n\n    Args:\n        x: Input tensor.\n        indices_or_sections: If an integer, N, the tensor will be split into N\n            equal sections along axis 0. If a 1-D array of sorted integers,\n            the entries indicate indices at which the tensor will be split\n            along axis 0.\n\n    Returns:\n        A list of sub-arrays.\n\n    Example:\n\n    >>> x = keras.ops.arange(12).reshape((4, 3))\n    >>> keras.ops.vsplit(x, 2)\n    [array([[0, 1, 2],\n           [3, 4, 5]]),\n     array([[ 6,  7,  8],\n           [ 9, 10, 11]])]\n    >>> keras.ops.vsplit(x, [1, 3])\n    [array([[0, 1, 2]]),\n     array([[3, 4, 5],\n           [6, 7, 8]]),\n     array([[ 9, 10, 11]])]\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Vsplit(indices_or_sections).symbolic_call(x)\n    return backend.numpy.vsplit(x, indices_or_sections)\n\n\nclass Where(Operation):\n    def call(self, condition, x1=None, x2=None):\n        return backend.numpy.where(condition, x1, x2)\n\n    def compute_output_spec(self, condition, x1, x2):\n        condition_shape = getattr(condition, \"shape\", [])\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(condition_shape, x1_shape)\n        output_shape = broadcast_shapes(output_shape, x2_shape)\n        output_dtype = dtypes.result_type(\n            getattr(x1, 
\"dtype\", type(x1) if x1 is not None else \"int\"),\n            getattr(x2, \"dtype\", type(x2) if x2 is not None else \"int\"),\n        )\n        return KerasTensor(output_shape, dtype=output_dtype)\n\n\n@keras_export([\"keras.ops.where\", \"keras.ops.numpy.where\"])\ndef where(condition, x1=None, x2=None):\n    \"\"\"Return elements chosen from `x1` or `x2` depending on `condition`.\n\n    Args:\n        condition: Where `True`, yield `x1`, otherwise yield `x2`.\n        x1: Values from which to choose when `condition` is `True`.\n        x2: Values from which to choose when `condition` is `False`.\n\n    Returns:\n        A tensor with elements from `x1` where `condition` is `True`, and\n        elements from `x2` where `condition` is `False`.\n    \"\"\"\n    if (x1 is None and x2 is not None) or (x1 is not None and x2 is None):\n        raise ValueError(\n            \"`x1` and `x2` either both should be `None`\"\n            \" or both should have non-None value.\"\n        )\n    if any_symbolic_tensors((condition, x1, x2)):\n        return Where().symbolic_call(condition, x1, x2)\n    return backend.numpy.where(condition, x1, x2)\n\n\nclass Subtract(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.subtract(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        x1_sparse = getattr(x1, \"sparse\", False)\n        x2_sparse = getattr(x2, \"sparse\", False)\n        output_sparse = x1_sparse and x2_sparse\n        dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        return KerasTensor(output_shape, dtype=dtype, sparse=output_sparse)\n\n\n@keras_export([\"keras.ops.subtract\", \"keras.ops.numpy.subtract\"])\ndef subtract(x1, x2):\n    \"\"\"Subtract arguments element-wise.\n\n    Args:\n   
     x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        Output tensor, element-wise difference of `x1` and `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Subtract().symbolic_call(x1, x2)\n    return backend.numpy.subtract(x1, x2)\n\n\nclass Multiply(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.multiply(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        x1_sparse = getattr(x1, \"sparse\", True)\n        x2_sparse = getattr(x2, \"sparse\", True)\n        output_sparse = x1_sparse or x2_sparse\n        dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        return KerasTensor(output_shape, dtype=dtype, sparse=output_sparse)\n\n\n@keras_export([\"keras.ops.multiply\", \"keras.ops.numpy.multiply\"])\ndef multiply(x1, x2):\n    \"\"\"Multiply arguments element-wise.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        Output tensor, element-wise product of `x1` and `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Multiply().symbolic_call(x1, x2)\n    return backend.numpy.multiply(x1, x2)\n\n\nclass Divide(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.divide(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        output_dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n            float,\n        )\n        x1_sparse = getattr(x1, \"sparse\", False)\n        x2_sparse = getattr(x2, \"sparse\", False)\n        
output_sparse = x1_sparse and not x2_sparse\n        return KerasTensor(\n            output_shape, dtype=output_dtype, sparse=output_sparse\n        )\n\n\n@keras_export([\"keras.ops.divide\", \"keras.ops.numpy.divide\"])\ndef divide(x1, x2):\n    \"\"\"Divide arguments element-wise.\n\n    `keras.ops.true_divide` is an alias for this function.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        Output tensor, the quotient `x1/x2`, element-wise.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Divide().symbolic_call(x1, x2)\n    return backend.numpy.divide(x1, x2)\n\n\nclass DivideNoNan(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.divide_no_nan(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        output_dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n            float,\n        )\n        x1_sparse = getattr(x1, \"sparse\", False)\n        x2_sparse = getattr(x2, \"sparse\", False)\n        output_sparse = x1_sparse and not x2_sparse\n        return KerasTensor(\n            output_shape, dtype=output_dtype, sparse=output_sparse\n        )\n\n\n@keras_export([\"keras.ops.divide_no_nan\", \"keras.ops.numpy.divide_no_nan\"])\ndef divide_no_nan(x1, x2):\n    \"\"\"Safe element-wise division which returns 0 where the denominator is 0.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        The quotient `x1/x2`, element-wise, with zero where x2 is zero.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return DivideNoNan().symbolic_call(x1, x2)\n    return backend.numpy.divide_no_nan(x1, x2)\n\n\nclass TrueDivide(Operation):\n    def call(self, x1, x2):\n        return 
backend.numpy.true_divide(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        output_dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n            float,\n        )\n        x1_sparse = getattr(x1, \"sparse\", False)\n        x2_sparse = getattr(x2, \"sparse\", False)\n        output_sparse = x1_sparse and not x2_sparse\n        return KerasTensor(\n            output_shape, dtype=output_dtype, sparse=output_sparse\n        )\n\n\n@keras_export(\n    [\n        \"keras.ops.true_divide\",\n        \"keras.ops.numpy.true_divide\",\n    ]\n)\ndef true_divide(x1, x2):\n    \"\"\"Alias for `keras.ops.divide`.\"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return TrueDivide().symbolic_call(x1, x2)\n    return backend.numpy.true_divide(x1, x2)\n\n\nclass Power(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.power(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        output_dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)), getattr(x2, \"dtype\", type(x2))\n        )\n        return KerasTensor(output_shape, dtype=output_dtype)\n\n\n@keras_export([\"keras.ops.power\", \"keras.ops.numpy.power\"])\ndef power(x1, x2):\n    \"\"\"First tensor elements raised to powers from second tensor, element-wise.\n\n    Args:\n        x1: The bases.\n        x2: The exponents.\n\n    Returns:\n        Output tensor, the bases in `x1` raised to the exponents in `x2`.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Power().symbolic_call(x1, x2)\n    return backend.numpy.power(x1, x2)\n\n\nclass Negative(Operation):\n    
def call(self, x):\n        return backend.numpy.negative(x)\n\n    def compute_output_spec(self, x):\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=x.dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.negative\", \"keras.ops.numpy.negative\"])\ndef negative(x):\n    \"\"\"Numerical negative, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor, `y = -x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Negative().symbolic_call(x)\n    return backend.numpy.negative(x)\n\n\nclass Nextafter(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.nextafter(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n\n        x1_type = backend.standardize_dtype(getattr(x1, \"dtype\", type(x1)))\n        x2_type = backend.standardize_dtype(getattr(x2, \"dtype\", type(x2)))\n        dtype = dtypes.result_type(x1_type, x2_type, float)\n        return KerasTensor(output_shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.nextafter\", \"keras.ops.numpy.nextafter\"])\ndef nextafter(x1, x2):\n    \"\"\"\n    Return the next representable floating-point value after `x1` towards `x2`.\n\n    This function computes the next floating-point value\n    following `x1` in the direction of `x2`, element-wise.\n\n    Args:\n        x1: Input tensor whose values will be moved to the next\n            representable floating-point value.\n        x2: Input tensor indicating the direction toward which\n            `x1` is moved.\n\n    Returns:\n        Output tensor\n\n    Example:\n    >>> x1 = keras.ops.convert_to_tensor([1.0, 1.0])\n    >>> x2 = keras.ops.convert_to_tensor([2.0, 0.0])\n    >>> keras.ops.nextafter(x1, x2)\n    array([1.0000001, 0.99999994], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n  
      return Nextafter().symbolic_call(x1, x2)\n    return backend.numpy.nextafter(x1, x2)\n\n\nclass Square(Operation):\n    def call(self, x):\n        return backend.numpy.square(x)\n\n    def compute_output_spec(self, x):\n        sparse = getattr(x, \"sparse\", False)\n        dtype = backend.standardize_dtype(x.dtype)\n        if dtype == \"bool\":\n            dtype = \"int32\"\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.square\", \"keras.ops.numpy.square\"])\ndef square(x):\n    \"\"\"Return the element-wise square of the input.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor, the square of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Square().symbolic_call(x)\n    return backend.numpy.square(x)\n\n\nclass Sqrt(Operation):\n    def call(self, x):\n        x = backend.convert_to_tensor(x)\n        return backend.numpy.sqrt(x)\n\n    def compute_output_spec(self, x):\n        dtype = (\n            backend.floatx()\n            if backend.standardize_dtype(x.dtype) == \"int64\"\n            else dtypes.result_type(x.dtype, float)\n        )\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(x.shape, dtype=dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.sqrt\", \"keras.ops.numpy.sqrt\"])\ndef sqrt(x):\n    \"\"\"Return the non-negative square root of a tensor, element-wise.\n\n    Args:\n        x: Input tensor.\n\n    Returns:\n        Output tensor, the non-negative square root of `x`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Sqrt().symbolic_call(x)\n    x = backend.convert_to_tensor(x)\n    return backend.numpy.sqrt(x)\n\n\nclass Squeeze(Operation):\n    def __init__(self, axis=None, *, name=None):\n        super().__init__(name=name)\n        self.axis = axis\n\n    def call(self, x):\n        return backend.numpy.squeeze(x, axis=self.axis)\n\n    def compute_output_spec(self, x):\n        
input_shape = list(x.shape)\n        sparse = getattr(x, \"sparse\", False)\n        axis = to_tuple_or_list(self.axis)\n        if axis is None:\n            output_shape = list(filter((1).__ne__, input_shape))\n            return KerasTensor(output_shape, dtype=x.dtype, sparse=sparse)\n        else:\n            for a in axis:\n                if input_shape[a] != 1:\n                    raise ValueError(\n                        f\"Cannot squeeze axis {a}, because the dimension \"\n                        \"is not 1.\"\n                    )\n            axis = [canonicalize_axis(a, len(input_shape)) for a in axis]\n            for a in sorted(axis, reverse=True):\n                del input_shape[a]\n            return KerasTensor(input_shape, dtype=x.dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.squeeze\", \"keras.ops.numpy.squeeze\"])\ndef squeeze(x, axis=None):\n    \"\"\"Remove axes of length one from `x`.\n\n    Args:\n        x: Input tensor.\n        axis: Select a subset of the entries of length one in the shape.\n\n    Returns:\n        The input tensor with all or a subset of the dimensions of\n        length 1 removed.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Squeeze(axis=axis).symbolic_call(x)\n    return backend.numpy.squeeze(x, axis=axis)\n\n\nclass Transpose(Operation):\n    def __init__(self, axes=None, *, name=None):\n        super().__init__(name=name)\n        self.axes = axes\n\n    def call(self, x):\n        return backend.numpy.transpose(x, axes=self.axes)\n\n    def compute_output_spec(self, x):\n        output_shape = operation_utils.compute_transpose_output_shape(\n            x.shape, self.axes\n        )\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(output_shape, dtype=x.dtype, sparse=sparse)\n\n\n@keras_export([\"keras.ops.transpose\", \"keras.ops.numpy.transpose\"])\ndef transpose(x, axes=None):\n    \"\"\"Returns a tensor with `axes` transposed.\n\n    Args:\n        x: 
Input tensor.\n        axes: Sequence of integers. Permutation of the dimensions of `x`.\n            By default, the order of the axes are reversed.\n\n    Returns:\n        `x` with its axes permuted.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Transpose(axes=axes).symbolic_call(x)\n    return backend.numpy.transpose(x, axes=axes)\n\n\nclass Trapezoid(Operation):\n    def __init__(self, x=None, dx=1.0, axis=-1, *, name=None):\n        super().__init__(name=name)\n        self.x = x\n        self.dx = dx\n        self.axis = axis\n\n    def call(self, y):\n        return backend.numpy.trapezoid(y, x=self.x, dx=self.dx, axis=self.axis)\n\n    def compute_output_spec(self, y):\n        out_shape = list(y.shape)\n        if self.axis is not None and len(out_shape) > 0:\n            out_shape.pop(self.axis % len(out_shape))\n        dtype = backend.result_type(getattr(y, \"dtype\", type(y)), float)\n        return KerasTensor(tuple(out_shape), dtype=dtype)\n\n\n@keras_export([\"keras.ops.trapezoid\", \"keras.ops.numpy.trapezoid\"])\ndef trapezoid(y, x=None, dx=1.0, axis=-1):\n    \"\"\"Integrate along the given axis using the composite trapezoidal rule.\n\n    Args:\n        y: Input tensor.\n        x: Optional tensor specifying sample points corresponding to `y`.\n           If `None`, spacing is assumed to be `dx`.\n        dx: Spacing between sample points when `x` is `None`.\n        axis: Axis along which to integrate. 
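The composite trapezoidal rule being described can be sketched by hand; the helper below is an illustrative re-implementation in plain NumPy, not the backend code:

```python
import numpy as np

def trapezoid_1d(y, dx=1.0):
    # Composite trapezoidal rule: sum of 0.5 * (y[i] + y[i + 1]) * dx
    # over consecutive sample pairs (uniform spacing assumed).
    y = np.asarray(y, dtype=float)
    return float(np.sum(0.5 * (y[:-1] + y[1:]) * dx))

print(trapezoid_1d([1, 2, 3]))          # 4.0
print(trapezoid_1d([4, 5, 6]))          # 10.0
print(trapezoid_1d([0, 1, 0], dx=2.0))  # 2.0
```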
Default is the last axis.\n\n    Returns:\n        The approximate integral of `y` along the given axis.\n\n    Example:\n    >>> y = keras.ops.convert_to_tensor([[1, 2, 3], [4, 5, 6]])\n    >>> keras.ops.trapezoid(y, axis=1)\n    array([ 4., 10.], dtype=float32)\n    \"\"\"\n    if any_symbolic_tensors((y,)):\n        return Trapezoid(x=x, dx=dx, axis=axis).symbolic_call(y)\n    return backend.numpy.trapezoid(y, x=x, dx=dx, axis=axis)\n\n\nclass Mean(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            axis = [axis]\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.mean(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        ori_dtype = backend.standardize_dtype(x.dtype)\n        compute_dtype = dtypes.result_type(x.dtype, \"float32\")\n        if \"int\" in ori_dtype or ori_dtype == \"bool\":\n            result_dtype = compute_dtype\n        else:\n            result_dtype = ori_dtype\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=result_dtype,\n            sparse=sparse,\n        )\n\n\n@keras_export([\"keras.ops.mean\", \"keras.ops.numpy.mean\"])\ndef mean(x, axis=None, keepdims=False):\n    \"\"\"Compute the arithmetic mean along the specified axes.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which the means are computed. 
The default\n            is to compute the mean of the flattened tensor.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one.\n\n    Returns:\n        Output tensor containing the mean values.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Mean(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.mean(x, axis=axis, keepdims=keepdims)\n\n\nclass Vander(Operation):\n    def __init__(self, N=None, increasing=False, *, name=None):\n        super().__init__(name=name)\n        self.N = N\n        self.increasing = increasing\n\n    def call(self, x):\n        return backend.numpy.vander(x, self.N, self.increasing)\n\n    def compute_output_spec(self, x):\n        if self.N is None:\n            N = x.shape[0]\n        else:\n            N = self.N\n\n        out_shape = x.shape + (N,)\n        return KerasTensor(tuple(out_shape), dtype=x.dtype)\n\n\n@keras_export([\"keras.ops.vander\", \"keras.ops.numpy.vander\"])\ndef vander(x, N=None, increasing=False):\n    \"\"\"Generate a Vandermonde matrix.\n\n    Args:\n        x: 1D input tensor.\n        N: Number of columns. If `None`, `N` = `len(x)`.\n        increasing: Order of powers. If True, powers increase left to right.\n\n    Returns:\n        Output tensor, Vandermonde matrix of shape `(len(x), N)`.\n\n    Example:\n    >>> import numpy as np\n    >>> import keras\n    >>> x = np.array([1, 2, 3, 5])\n    >>> keras.ops.vander(x)\n    array([[  1,   1,   1,   1],\n           [  8,   4,   2,   1],\n           [ 27,   9,   3,   1],\n           [125,  25,   5,   1]])\n    \"\"\"\n\n    if len(x.shape) != 1:\n        raise ValueError(\n            \"Input tensor must be 1-dimensional. \"\n            f\"Received: input.shape={x.shape}\"\n        )\n\n    if N is not None:\n        if not isinstance(N, int):\n            raise TypeError(\n                f\"Argument `N` must be of type `int`. 
\"\n                f\"Received: N={N} of type {type(N)}\"\n            )\n\n        if N < 0:\n            raise ValueError(\n                f\"Argument 'N' must be nonnegative. Received: N={N}\"\n            )\n\n    if not isinstance(increasing, bool):\n        raise TypeError(\n            f\"Argument `increasing` must be of type `bool`. \"\n            f\"Received: increasing={increasing} of type {type(increasing)}\"\n        )\n\n    if any_symbolic_tensors((x,)):\n        return Vander(N=N, increasing=increasing).symbolic_call(x)\n    return backend.numpy.vander(x, N=N, increasing=increasing)\n\n\nclass Var(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            axis = [axis]\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.var(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        output_dtype = backend.result_type(getattr(x, \"dtype\", type(x)), float)\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=output_dtype,\n        )\n\n\n@keras_export([\"keras.ops.var\", \"keras.ops.numpy.var\"])\ndef var(x, axis=None, keepdims=False):\n    \"\"\"Compute the variance along the specified axes.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which the variance is computed. 
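As with `np.var`, the variance computed here is the population variance (squared deviations averaged over all reduced elements, no degrees-of-freedom correction); a small NumPy sketch:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0])
# Population variance (ddof=0): mean of squared deviations
# from the mean; here mean=2.5, so var = 5/4.
print(np.var(x))  # 1.25
print(np.var(x.reshape(2, 2), axis=0))  # [1. 1.]
```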
The default\n            is to compute the variance of the flattened tensor.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one.\n\n    Returns:\n        Output tensor containing the variance.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Var(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.var(x, axis=axis, keepdims=keepdims)\n\n\nclass Sum(Operation):\n    def __init__(self, axis=None, keepdims=False, *, name=None):\n        super().__init__(name=name)\n        if isinstance(axis, int):\n            axis = [axis]\n        self.axis = axis\n        self.keepdims = keepdims\n\n    def call(self, x):\n        return backend.numpy.sum(x, axis=self.axis, keepdims=self.keepdims)\n\n    def compute_output_spec(self, x):\n        dtype = dtypes.result_type(getattr(x, \"dtype\", backend.floatx()))\n        # follow jax's rule\n        if dtype in (\"bool\", \"int8\", \"int16\"):\n            dtype = \"int32\"\n        elif dtype in (\"uint8\", \"uint16\"):\n            dtype = \"uint32\"\n        # TODO: torch doesn't support uint32\n        if backend.backend() == \"torch\" and dtype == \"uint32\":\n            dtype = \"int32\"\n        sparse = getattr(x, \"sparse\", False)\n        return KerasTensor(\n            reduce_shape(x.shape, axis=self.axis, keepdims=self.keepdims),\n            dtype=dtype,\n            sparse=sparse,\n        )\n\n\n@keras_export([\"keras.ops.sum\", \"keras.ops.numpy.sum\"])\ndef sum(x, axis=None, keepdims=False):\n    \"\"\"Sum of a tensor over the given axes.\n\n    Args:\n        x: Input tensor.\n        axis: Axis or axes along which the sum is computed. 
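The dtype promotion in `Sum.compute_output_spec` follows JAX's reduction rule; the helper below is a hypothetical standalone sketch of that mapping (not a Keras API):

```python
def sum_result_dtype(dtype, backend_name="jax"):
    # Small integer and bool inputs are accumulated in 32-bit,
    # mirroring the JAX-style rule in the operation above.
    if dtype in ("bool", "int8", "int16"):
        dtype = "int32"
    elif dtype in ("uint8", "uint16"):
        dtype = "uint32"
    if backend_name == "torch" and dtype == "uint32":
        dtype = "int32"  # torch has no uint32
    return dtype

print(sum_result_dtype("bool"))             # int32
print(sum_result_dtype("uint16", "torch"))  # int32
print(sum_result_dtype("float16"))          # float16
```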
The default is to\n            compute the sum of the flattened tensor.\n        keepdims: If this is set to `True`, the axes which are reduced are left\n            in the result as dimensions with size one.\n\n    Returns:\n        Output tensor containing the sum.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Sum(axis=axis, keepdims=keepdims).symbolic_call(x)\n    return backend.numpy.sum(x, axis=axis, keepdims=keepdims)\n\n\n@keras_export([\"keras.ops.zeros\", \"keras.ops.numpy.zeros\"])\ndef zeros(shape, dtype=None):\n    \"\"\"Return a new tensor of given shape and type, filled with zeros.\n\n    Args:\n        shape: Shape of the new tensor.\n        dtype: Desired data type of the tensor.\n\n    Returns:\n        Tensor of zeros with the given shape and dtype.\n    \"\"\"\n    return backend.numpy.zeros(shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.ones\", \"keras.ops.numpy.ones\"])\ndef ones(shape, dtype=None):\n    \"\"\"Return a new tensor of given shape and type, filled with ones.\n\n    Args:\n        shape: Shape of the new tensor.\n        dtype: Desired data type of the tensor.\n\n    Returns:\n        Tensor of ones with the given shape and dtype.\n    \"\"\"\n    return backend.numpy.ones(shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.eye\", \"keras.ops.numpy.eye\"])\ndef eye(N, M=None, k=0, dtype=None):\n    \"\"\"Return a 2-D tensor with ones on the diagonal and zeros elsewhere.\n\n    Args:\n        N: Number of rows in the output.\n        M: Number of columns in the output. 
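A short NumPy sketch of the `k` and `M` semantics described for `eye` (`keras.ops.eye` mirrors `np.eye` here):

```python
import numpy as np

# k=1 puts the ones one diagonal above the main one;
# M defaults to N when omitted.
print(np.eye(3, k=1))
# [[0. 1. 0.]
#  [0. 0. 1.]
#  [0. 0. 0.]]
print(np.eye(2, 4).shape)  # (2, 4)
```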
If `None`, defaults to `N`.\n        k: Index of the diagonal: 0 (the default) refers to the main\n            diagonal, a positive value refers to an upper diagonal,\n            and a negative value to a lower diagonal.\n        dtype: Data type of the returned tensor.\n\n    Returns:\n        Tensor with ones on the k-th diagonal and zeros elsewhere.\n    \"\"\"\n\n    def is_floating_type(v):\n        return (\n            isinstance(v, float)\n            or getattr(v, \"dtype\", None) in dtypes.FLOAT_TYPES\n        )\n\n    if is_floating_type(N):\n        raise TypeError(\"Argument `N` must be an integer or an integer tensor.\")\n    if is_floating_type(M):\n        raise TypeError(\n            \"Argument `M` must be an integer, an integer tensor, or `None`.\"\n        )\n    return backend.numpy.eye(N, M=M, k=k, dtype=dtype)\n\n\nclass FloorDivide(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.floor_divide(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        output_dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        return KerasTensor(output_shape, dtype=output_dtype)\n\n\n@keras_export([\"keras.ops.floor_divide\", \"keras.ops.numpy.floor_divide\"])\ndef floor_divide(x1, x2):\n    \"\"\"Returns the largest integer smaller or equal to the division of inputs.\n\n    Args:\n        x1: Numerator.\n        x2: Denominator.\n\n    Returns:\n        Output tensor, `y = floor(x1/x2)`\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return FloorDivide().symbolic_call(x1, x2)\n    return backend.numpy.floor_divide(x1, x2)\n\n\nclass LogicalXor(Operation):\n    def call(self, x1, x2):\n        return backend.numpy.logical_xor(x1, x2)\n\n    def compute_output_spec(self, x1, x2):\n 
       x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        output_shape = broadcast_shapes(x1_shape, x2_shape)\n        return KerasTensor(output_shape, dtype=\"bool\")\n\n\n@keras_export([\"keras.ops.logical_xor\", \"keras.ops.numpy.logical_xor\"])\ndef logical_xor(x1, x2):\n    \"\"\"Compute the truth value of `x1 XOR x2`, element-wise.\n\n    Args:\n        x1: First input tensor.\n        x2: Second input tensor.\n\n    Returns:\n        Output boolean tensor.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return LogicalXor().symbolic_call(x1, x2)\n    return backend.numpy.logical_xor(x1, x2)\n\n\nclass Corrcoef(Operation):\n    def call(self, x):\n        return backend.numpy.corrcoef(x)\n\n    def compute_output_spec(self, x):\n        dtype = backend.standardize_dtype(getattr(x, \"dtype\", backend.floatx()))\n        if dtype == \"int64\":\n            dtype = \"float64\"\n        else:\n            dtype = dtypes.result_type(dtype, float)\n        return KerasTensor(x.shape, dtype=dtype)\n\n\n@keras_export([\"keras.ops.corrcoef\", \"keras.ops.numpy.corrcoef\"])\ndef corrcoef(x):\n    \"\"\"Compute the Pearson correlation coefficient matrix.\n\n    Args:\n        x: A 2D tensor of shape `(N, D)`, where N is the number of variables\n           and D is the number of observations.\n\n    Returns:\n        A tensor of shape `(N, N)` representing the correlation matrix.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Corrcoef().symbolic_call(x)\n    return backend.numpy.corrcoef(x)\n\n\nclass Correlate(Operation):\n    def __init__(self, mode=\"valid\", *, name=None):\n        super().__init__(name=name)\n        self.mode = mode\n\n    def call(self, x1, x2):\n        return backend.numpy.correlate(x1, x2, mode=self.mode)\n\n    def compute_output_spec(self, x1, x2):\n        x1_shape = getattr(x1, \"shape\", [])\n        x2_shape = getattr(x2, \"shape\", [])\n        if len(x1_shape) != 
1:\n            raise ValueError(\n                \"`x1` must be a 1-dimensional tensor, but received \"\n                f\"shape {x1_shape}\"\n            )\n        if len(x2_shape) != 1:\n            raise ValueError(\n                \"`x2` must be a 1-dimensional tensor, but received \"\n                f\"shape {x2_shape}\"\n            )\n        x1_len, x2_len = x1_shape[0], x2_shape[0]\n        output_shape = (\n            np.maximum(x1_len, x2_len) - np.minimum(x1_len, x2_len) + 1,\n        )\n        if self.mode == \"same\":\n            output_shape = (np.maximum(x1_len, x2_len),)\n        elif self.mode == \"full\":\n            output_shape = (x1_len + x2_len - 1,)\n        if self.mode not in (\"valid\", \"same\", \"full\"):\n            raise ValueError(\n                \"`mode` must be either `valid`, `same`, or `full`, but \"\n                f\"received: {self.mode}\"\n            )\n        output_dtype = dtypes.result_type(\n            getattr(x1, \"dtype\", type(x1)),\n            getattr(x2, \"dtype\", type(x2)),\n        )\n        if output_dtype == \"int64\":\n            output_dtype = \"float64\"\n        elif output_dtype not in [\"bfloat16\", \"float16\", \"float64\"]:\n            output_dtype = \"float32\"\n        return KerasTensor(output_shape, dtype=output_dtype)\n\n\n@keras_export([\"keras.ops.correlate\", \"keras.ops.numpy.correlate\"])\ndef correlate(x1, x2, mode=\"valid\"):\n    \"\"\"Compute the cross-correlation of two 1-dimensional tensors.\n\n    Args:\n        x1: First 1-dimensional input tensor of length M.\n        x2: Second 1-dimensional input tensor of length N.\n        mode: Either `valid`, `same` or `full`.\n            By default the mode is set to `valid`, which returns\n            an output of length max(M, N) - min(M, N) + 1.\n            `same` returns an output of length max(M, N).\n            `full` mode returns the cross-correlation at each point of\n            overlap, with an output length of 
N + M - 1.\n\n    Returns:\n        Output tensor, cross-correlation of `x1` and `x2`.\n\n    Notes:\n        Complex-valued inputs are currently not fully supported on the\n        TensorFlow and PyTorch backends. When complex tensors are passed,\n        they are cast to floating-point types and the imaginary component\n        is discarded.\n\n        This behavior is documented for clarity and may change in the\n        future. See discussion in issue #21617.\n    \"\"\"\n    if any_symbolic_tensors((x1, x2)):\n        return Correlate(mode=mode).symbolic_call(x1, x2)\n    return backend.numpy.correlate(x1, x2, mode=mode)\n\n\nclass Select(Operation):\n    def call(self, condlist, choicelist, default=0):\n        return backend.numpy.select(condlist, choicelist, default)\n\n    def compute_output_spec(self, condlist, choicelist, default=0):\n        first_element = choicelist[0]\n        return KerasTensor(first_element.shape, dtype=first_element.dtype)\n\n\n@keras_export([\"keras.ops.select\", \"keras.ops.numpy.select\"])\ndef select(condlist, choicelist, default=0):\n    \"\"\"Return elements from `choicelist`, based on conditions in `condlist`.\n\n    Args:\n        condlist: List of boolean tensors.\n            The list of conditions which determine from which array\n            in `choicelist` the output elements are taken.\n            When multiple conditions are satisfied,\n            the first one encountered in `condlist` is used.\n        choicelist: List of tensors.\n            The list of tensors from which the output elements are taken.\n            This list has to be of the same length as `condlist`.\n        default: Optional scalar value.\n            The element inserted in the output\n            when all conditions evaluate to `False`.\n\n    Returns:\n        Tensor where the output at position `m` is the `m`-th element\n        of the tensor in `choicelist` where the `m`-th element of the\n        corresponding tensor in `condlist` is 
`True`.\n\n    Example:\n\n    ```python\n    from keras import ops\n\n    x = ops.arange(6)\n    condlist = [x<3, x>3]\n    choicelist = [x, x**2]\n    ops.select(condlist, choicelist, 42)\n    # Returns: tensor([0,  1,  2, 42, 16, 25])\n    ```\n    \"\"\"\n    if not isinstance(condlist, (list, tuple)) or not isinstance(\n        choicelist, (list, tuple)\n    ):\n        raise ValueError(\n            \"condlist and choicelist must be lists or tuples. Received: \"\n            f\"type(condlist) = {type(condlist)}, \"\n            f\"type(choicelist) = {type(choicelist)}\"\n        )\n    condlist = list(condlist)\n    choicelist = list(choicelist)\n    if not condlist or not choicelist:\n        raise ValueError(\n            \"condlist and choicelist must not be empty. Received: \"\n            f\"condlist = {condlist}, \"\n            f\"choicelist = {choicelist}\"\n        )\n    if any_symbolic_tensors(condlist + choicelist + [default]):\n        return Select().symbolic_call(condlist, choicelist, default)\n    return backend.numpy.select(condlist, choicelist, default)\n\n\nclass Slogdet(Operation):\n    def call(self, x):\n        return backend.numpy.slogdet(x)\n\n    def compute_output_spec(self, x):\n        sign = KerasTensor((), dtype=x.dtype)\n        logabsdet = KerasTensor(x.shape[:-2], dtype=x.dtype)\n        return (sign, logabsdet)\n\n\n@keras_export([\"keras.ops.slogdet\", \"keras.ops.numpy.slogdet\"])\ndef slogdet(x):\n    \"\"\"Compute the sign and natural logarithm of the determinant of a matrix.\n\n    Args:\n        x: Input matrix. It must be 2D and square.\n\n    Returns:\n        A tuple `(sign, logabsdet)`. `sign` is a number representing\n        the sign of the determinant. 
For a real matrix, this is 1, 0, or -1.\n        For a complex matrix, this is a complex number with absolute value 1\n        (i.e., it is on the unit circle), or else 0.\n        `logabsdet` is the natural log of the absolute value of the determinant.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Slogdet().symbolic_call(x)\n    return backend.numpy.slogdet(x)\n\n\nclass Argpartition(Operation):\n    def __init__(self, kth, axis=-1, *, name=None):\n        super().__init__(name=name)\n        if not isinstance(kth, int):\n            raise ValueError(f\"kth must be an integer. Received: kth = {kth}\")\n        self.kth = kth\n        self.axis = axis\n\n    def call(self, x):\n        return backend.numpy.argpartition(x, kth=self.kth, axis=self.axis)\n\n    def compute_output_spec(self, x):\n        return KerasTensor(x.shape, dtype=\"int32\")\n\n\n@keras_export([\"keras.ops.argpartition\", \"keras.ops.numpy.argpartition\"])\ndef argpartition(x, kth, axis=-1):\n    \"\"\"Performs an indirect partition along the given axis.\n\n    It returns an array\n    of indices of the same shape as `x` that index data along the given axis\n    in partitioned order.\n\n    Args:\n        x: Array to sort.\n        kth: Element index to partition by.\n            The k-th element will be in its final sorted position and all\n            smaller elements will be moved before it and all larger elements\n            behind it. The order of all elements in the partitions is undefined.\n            `kth` must be a single integer; a sequence of indices is not\n            supported.\n        axis: Axis along which to sort. 
The default is -1 (the last axis).\n            If `None`, the flattened array is used.\n\n    Returns:\n        Array of indices that partition `x` along the specified `axis`.\n    \"\"\"\n    if any_symbolic_tensors((x,)):\n        return Argpartition(kth, axis).symbolic_call(x)\n    return backend.numpy.argpartition(x, kth, axis)\n\n\nclass Histogram(Operation):\n    def __init__(self, bins=10, range=None, *, name=None):\n        super().__init__(name=name)\n\n        if not isinstance(bins, int):\n            raise TypeError(\"bins must be of type `int`\")\n        if bins < 0:\n            raise ValueError(\"`bins` should be a non-negative integer\")\n\n        if range:\n            if not isinstance(range, tuple) or len(range) != 2:\n                raise ValueError(\"range must be a tuple of two elements\")\n\n            if range[1] < range[0]:\n                raise ValueError(\n                    \"The second element of range must be greater than the first\"\n                )\n\n        self.bins = bins\n        self.range = range\n\n    def call(self, x):\n        x = backend.convert_to_tensor(x)\n        if len(x.shape) > 1:\n            raise ValueError(\"Input tensor must be 1-dimensional\")\n        return backend.numpy.histogram(x, bins=self.bins, range=self.range)\n\n    def compute_output_spec(self, x):\n        return (\n            KerasTensor(shape=(self.bins,), dtype=x.dtype),\n            KerasTensor(shape=(self.bins + 1,), dtype=x.dtype),\n        )\n\n\n@keras_export([\"keras.ops.histogram\", \"keras.ops.numpy.histogram\"])\ndef histogram(x, bins=10, range=None):\n    \"\"\"Computes a histogram of the data tensor `x`.\n\n    Args:\n        x: Input tensor.\n        bins: An integer representing the number of histogram bins.\n            Defaults to 10.\n        range: A tuple representing the lower and upper range of the bins.\n            If not specified, it will use the min and max of `x`.\n\n    Returns:\n        A tuple containing:\n  
      - A tensor representing the counts of elements in each bin.\n        - A tensor representing the bin edges.\n\n    Example:\n    >>> input_tensor = np.random.rand(8)\n    >>> keras.ops.histogram(input_tensor)\n    (array([1, 1, 1, 0, 0, 1, 2, 1, 0, 1], dtype=int32),\n    array([0.0189519 , 0.10294958, 0.18694726, 0.27094494, 0.35494262,\n        0.43894029, 0.52293797, 0.60693565, 0.69093333, 0.77493101,\n        0.85892869]))\n    \"\"\"\n    if not isinstance(bins, int):\n        raise TypeError(\n            f\"Argument `bins` must be of type `int`. Received: bins={bins}\"\n        )\n    if bins < 0:\n        raise ValueError(\n            \"Argument `bins` should be a non-negative integer. \"\n            f\"Received: bins={bins}\"\n        )\n\n    if range:\n        if not isinstance(range, tuple) or len(range) != 2:\n            raise ValueError(\n                \"Argument `range` must be a tuple of two elements. \"\n                f\"Received: range={range}\"\n            )\n\n        if range[1] < range[0]:\n            raise ValueError(\n                \"The second element of `range` must be greater than the first. \"\n                f\"Received: range={range}\"\n            )\n\n    if any_symbolic_tensors((x,)):\n        return Histogram(bins=bins, range=range).symbolic_call(x)\n\n    x = backend.convert_to_tensor(x)\n    if len(x.shape) > 1:\n        raise ValueError(\n            \"Input tensor must be 1-dimensional. \"\n            f\"Received: input.shape={x.shape}\"\n        )\n    return backend.numpy.histogram(x, bins=bins, range=range)\n\n\nclass ArraySplit(Operation):\n    def __init__(self, indices_or_sections, axis=0, *, name=None):\n        super().__init__(name=name)\n\n        self.indices_or_sections = indices_or_sections\n        self.axis = axis\n\n    def call(self, x):\n        return backend.numpy.array_split(\n            x,\n            indices_or_sections=self.indices_or_sections,\n            axis=self.axis,\n        )\n\n    def compute_output_spec(self, x):\n        num_splits = self.indices_or_sections\n\n        axis = self.axis\n        if axis < 0:\n            axis += len(x.shape)\n\n        total_size = x.shape[axis]\n\n        if total_size is None:\n            output_specs = []\n            base_shape = list(x.shape)\n            base_shape[axis] = None\n            for _ in range(num_splits):\n                output_specs.append(\n                    KerasTensor(shape=tuple(base_shape), dtype=x.dtype)\n                )\n            return output_specs\n\n        split_size = total_size // num_splits\n        remainder = total_size % num_splits\n\n        output_specs = []\n        base_shape = list(x.shape)\n        for i in range(num_splits):\n            size = split_size + (1 if i < remainder else 0)\n            shape = base_shape.copy()\n            shape[axis] = size\n            output_specs.append(KerasTensor(shape=tuple(shape), dtype=x.dtype))\n\n        return output_specs\n\n\n@keras_export([\"keras.ops.array_split\", \"keras.ops.numpy.array_split\"])\ndef array_split(x, indices_or_sections, axis=0):\n    \"\"\"Splits an array into multiple sub-arrays (possibly unevenly).\n\n    This is similar to `keras.ops.split`, but it allows for\n    unequal splits. `indices_or_sections` must be an integer\n    that indicates the total number of sub-arrays to create.\n    If the tensor cannot be divided evenly, the first `remainder`\n    splits will have size `quotient + 1`, and the rest will\n    have size `quotient`.\n\n    Args:\n        x: Input tensor.\n        indices_or_sections: An integer indicating the number of\n            sub-arrays to create.\n        axis: The axis along which to split. Defaults to 0.\n\n    Returns:\n        A list of sub-tensors.\n\n    Example:\n    >>> x = keras.ops.arange(10)\n    >>> keras.ops.array_split(x, 3)\n    [array([0, 1, 2, 3], dtype=int32),\n     array([4, 5, 6], dtype=int32),\n     array([7, 8, 9], dtype=int32)]\n    \"\"\"\n    if not isinstance(indices_or_sections, int):\n        raise TypeError(\n            \"Argument `indices_or_sections` must be of type `int`. \"\n            f\"Received: indices_or_sections={indices_or_sections}\"\n        )\n\n    if indices_or_sections <= 0:\n        raise ValueError(\n            \"Argument `indices_or_sections` must be a positive integer. \"\n            f\"Received: indices_or_sections={indices_or_sections}\"\n        )\n\n    if not isinstance(axis, int):\n        raise TypeError(\n            f\"Argument `axis` must be of type `int`. Received: axis={axis}\"\n        )\n\n    if any_symbolic_tensors((x,)):\n        return ArraySplit(\n            indices_or_sections=indices_or_sections, axis=axis\n        ).symbolic_call(x)\n\n    return backend.numpy.array_split(\n        x, indices_or_sections=indices_or_sections, axis=axis\n    )\n"
  },
  {
    "path": "keras/src/ops/numpy_test.py",
    "content": "import functools\nimport itertools\nimport math\nimport warnings\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common import is_int_dtype\nfrom keras.src.backend.common import standardize_dtype\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.ops import numpy as knp\nfrom keras.src.testing.test_utils import named_product\n\n\nclass NumPyTestRot90(testing.TestCase):\n    def test_basic_rotation(self):\n        array = np.array([[1, 2, 3], [4, 5, 6]])\n        rotated = knp.rot90(array)\n        expected = np.rot90(array)\n        self.assertAllClose(rotated, expected)\n\n    @parameterized.named_parameters(\n        (\"k_0\", 0, [[1, 2], [3, 4]]),\n        (\"k_1\", 1, [[2, 4], [1, 3]]),\n        (\"k_2\", 2, [[4, 3], [2, 1]]),\n        (\"k_neg1\", -1, [[3, 1], [4, 2]]),\n        (\"k_5\", 5, [[2, 4], [1, 3]]),  # k=5 ≡ k=1 (mod 4)\n        (\"k_6\", 6, [[4, 3], [2, 1]]),  # k=6 ≡ k=2 (mod 4)\n    )\n    def test_k_parameter_variations(self, k, expected):\n        array = np.array([[1, 2], [3, 4]])\n        rotated = knp.rot90(array, k=k)\n        expected = np.array(expected)\n        self.assertAllClose(rotated, expected)\n\n    @parameterized.named_parameters(\n        (\"axes_0_1\", (0, 1)), (\"axes_1_2\", (1, 2)), (\"axes_0_2\", (0, 2))\n    )\n    def test_3d_operations(self, axes):\n        array_3d = np.arange(12).reshape(3, 2, 2)\n        rotated = knp.rot90(array_3d, axes=axes)\n        expected = np.rot90(array_3d, axes=axes)\n        self.assertAllClose(rotated, expected)\n\n    @parameterized.named_parameters(\n        (\"single_image\", np.random.random((4, 4, 3))),\n        (\"batch_images\", np.random.random((2, 4, 4, 3))),\n    )\n    def test_image_processing(self, array):\n        np.random.seed(0)\n 
       rotated = knp.rot90(array, axes=(0, 1))\n        expected = np.rot90(array, axes=(0, 1))\n        self.assertAllClose(rotated, expected)\n\n    @parameterized.named_parameters(\n        (\"single_row\", [[1, 2, 3]]),\n        (\"single_column\", [[1], [2], [3]]),\n        (\"negative_values\", [[-1, 0], [1, -2]]),\n    )\n    def test_edge_conditions(self, array):\n        numpy_array = np.array(array)\n        rotated = knp.rot90(numpy_array)\n        expected = np.rot90(numpy_array)\n        self.assertAllClose(rotated, expected)\n\n    @parameterized.named_parameters(\n        (\"1D_array\", np.array([1, 2, 3]), None),\n        (\"duplicate_axes\", np.array([[1, 2], [3, 4]]), (0, 0)),\n    )\n    def test_error_conditions(self, array, axes):\n        if axes is None:\n            with self.assertRaises(ValueError):\n                knp.rot90(array)\n        else:\n            with self.assertRaises(ValueError):\n                knp.rot90(array, axes=axes)\n\n\nclass NumpyTwoInputOpsDynamicShapeTest(testing.TestCase):\n    def test_add(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.add(x, y).shape, (2, 3))\n\n    def test_heaviside(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.heaviside(x, y).shape, (None, 3))\n\n    def test_hypot(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.hypot(x, y).shape, (None, 3))\n\n    def test_subtract(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.subtract(x, y).shape, (2, 3))\n\n    def test_multiply(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.multiply(x, y).shape, (2, 3))\n\n    def test_matmul(self):\n        x = KerasTensor((None, 3, 4))\n        y = KerasTensor((3, None, 4, 5))\n        
self.assertEqual(knp.matmul(x, y).shape, (3, None, 3, 5))\n\n    def test_power(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.power(x, y).shape, (2, 3))\n\n    def test_divide(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.divide(x, y).shape, (2, 3))\n\n    def test_divide_no_nan(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.divide_no_nan(x, y).shape, (2, 3))\n\n    def test_true_divide(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.true_divide(x, y).shape, (2, 3))\n\n    def test_append(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.append(x, y).shape, (None,))\n\n    def test_arctan2(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.arctan2(x, y).shape, (2, 3))\n\n    def test_bitwise_and(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.bitwise_and(x, y).shape, (None, 3))\n\n    def test_bitwise_or(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.bitwise_or(x, y).shape, (None, 3))\n\n    def test_bitwise_xor(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.bitwise_xor(x, y).shape, (None, 3))\n\n    def test_bitwise_left_shift(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.bitwise_left_shift(x, y).shape, (None, 3))\n\n    # left_shift is same as bitwise_left_shift\n\n    def test_bitwise_right_shift(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.bitwise_right_shift(x, y).shape, (None, 3))\n\n    # 
right_shift is same as bitwise_right_shift\n\n    def test_cross(self):\n        x1 = KerasTensor((2, 3, 3))\n        x2 = KerasTensor((1, 3, 2))\n        y = KerasTensor((None, 1, 2))\n        self.assertEqual(knp.cross(x1, y).shape, (2, 3, 3))\n        self.assertEqual(knp.cross(x2, y).shape, (None, 3))\n\n    def test_einsum(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((3, 4))\n        self.assertEqual(knp.einsum(\"ij,jk->ik\", x, y).shape, (None, 4))\n        self.assertEqual(knp.einsum(\"ij,jk->ikj\", x, y).shape, (None, 4, 3))\n        self.assertEqual(knp.einsum(\"ii\", x).shape, ())\n        self.assertEqual(knp.einsum(\",ij\", 5, x).shape, (None, 3))\n\n        x = KerasTensor((None, 3, 4))\n        y = KerasTensor((None, 4, 5))\n        z = KerasTensor((1, 1, 1, 9))\n        self.assertEqual(knp.einsum(\"ijk,jkl->li\", x, y).shape, (5, None))\n        self.assertEqual(knp.einsum(\"ijk,jkl->lij\", x, y).shape, (5, None, 3))\n        self.assertEqual(\n            knp.einsum(\"...,...j->...j\", x, y).shape, (None, 3, 4, 5)\n        )\n        self.assertEqual(\n            knp.einsum(\"i...,...j->i...j\", x, y).shape, (None, 3, 4, 5)\n        )\n        self.assertEqual(knp.einsum(\"i...,...j\", x, y).shape, (3, 4, None, 5))\n        self.assertEqual(\n            knp.einsum(\"i...,...j,...k\", x, y, z).shape, (1, 3, 4, None, 5, 9)\n        )\n        self.assertEqual(\n            knp.einsum(\"mij,ijk,...\", x, y, z).shape, (1, 1, 1, 9, 5, None)\n        )\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((None, 3))\n            y = KerasTensor((3, 4))\n            knp.einsum(\"ijk,jk->ik\", x, y)\n\n    def test_full_like(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.full_like(x, KerasTensor((1, 3))).shape, (None, 3))\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.full_like(x, 2).shape, (None, 3, 3))\n\n    def test_gcd(self):\n        x = 
KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.gcd(x, y).shape, (2, 3))\n\n    def test_geomspace(self):\n        start = KerasTensor((None, 3, 4))\n        stop = KerasTensor((2, 3, 4))\n        self.assertEqual(\n            knp.geomspace(start, stop, 10, axis=1).shape, (2, 10, 3, 4)\n        )\n\n        start = KerasTensor((None, 3))\n        stop = 2\n        self.assertEqual(\n            knp.geomspace(start, stop, 10, axis=1).shape, (None, 10, 3)\n        )\n\n    def test_greater(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.greater(x, y).shape, (2, 3))\n\n    def test_greater_equal(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.greater_equal(x, y).shape, (2, 3))\n\n    def test_allclose(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.allclose(x, y).shape, ())\n\n    def test_isclose(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.isclose(x, y).shape, (2, 3))\n\n    def test_isin(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.isin(x, y).shape, (None, 3))\n\n    def test_kron(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.kron(x, y).shape, (None, None))\n\n    def test_lcm(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.lcm(x, y).shape, (2, 3))\n\n    def test_ldexp(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((1, 3))\n        self.assertEqual(knp.ldexp(x, y).shape, (None, 3))\n\n    def test_less(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.less(x, y).shape, (2, 3))\n\n    def test_less_equal(self):\n       
 x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.less_equal(x, y).shape, (2, 3))\n\n    def test_linspace(self):\n        start = KerasTensor((None, 3, 4))\n        stop = KerasTensor((2, 3, 4))\n        self.assertEqual(\n            knp.linspace(start, stop, 10, axis=1).shape, (2, 10, 3, 4)\n        )\n\n        start = KerasTensor((None, 3))\n        stop = 2\n        self.assertEqual(\n            knp.linspace(start, stop, 10, axis=1).shape, (None, 10, 3)\n        )\n\n    def test_logical_and(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.logical_and(x, y).shape, (2, 3))\n\n    def test_logical_or(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.logical_or(x, y).shape, (2, 3))\n\n    def test_logspace(self):\n        start = KerasTensor((None, 3, 4))\n        stop = KerasTensor((2, 3, 4))\n        self.assertEqual(\n            knp.logspace(start, stop, 10, axis=1).shape, (2, 10, 3, 4)\n        )\n\n        start = KerasTensor((None, 3))\n        stop = 2\n        self.assertEqual(\n            knp.logspace(start, stop, 10, axis=1).shape, (None, 10, 3)\n        )\n\n    def test_maximum(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.maximum(x, y).shape, (2, 3))\n\n    def test_minimum(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.minimum(x, y).shape, (2, 3))\n\n    def test_mod(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.mod(x, y).shape, (2, 3))\n\n    def test_fmod(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.fmod(x, y).shape, (2, 3))\n\n    def test_nextafter(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((1, 3))\n        
self.assertEqual(knp.nextafter(x, y).shape, (None, 3))\n\n    def test_not_equal(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.not_equal(x, y).shape, (2, 3))\n\n    def test_outer(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.outer(x, y).shape, (None, None))\n\n    def test_quantile(self):\n        x = KerasTensor((None, 3))\n\n        # q as scalar\n        q = KerasTensor(())\n        self.assertEqual(knp.quantile(x, q).shape, ())\n\n        # q as 1D tensor\n        q = KerasTensor((2,))\n        self.assertEqual(knp.quantile(x, q).shape, (2,))\n        self.assertEqual(knp.quantile(x, q, axis=1).shape, (2, None))\n        self.assertEqual(\n            knp.quantile(x, q, axis=1, keepdims=True).shape,\n            (2, None, 1),\n        )\n\n    def test_searchsorted(self):\n        a = KerasTensor((None,))\n        v = KerasTensor((2, 3))\n\n        output = knp.searchsorted(a, v)\n        self.assertEqual(output.shape, v.shape)\n        self.assertEqual(output.dtype, \"int64\")\n\n    def test_take(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.take(x, 1).shape, ())\n        self.assertEqual(knp.take(x, [1, 2]).shape, (2,))\n        self.assertEqual(\n            knp.take(x, [[1, 2], [1, 2]], axis=1).shape, (None, 2, 2)\n        )\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.take(x, 1, axis=1).shape, (None, 3))\n        self.assertEqual(knp.take(x, [1, 2]).shape, (2,))\n        self.assertEqual(\n            knp.take(x, [[1, 2], [1, 2]], axis=1).shape, (None, 2, 2, 3)\n        )\n\n        # test with negative axis\n        self.assertEqual(knp.take(x, 1, axis=-2).shape, (None, 3))\n\n        # test with multi-dimensional indices\n        x = KerasTensor((None, 3, None, 5))\n        indices = KerasTensor((6, 7))\n        self.assertEqual(knp.take(x, indices, axis=2).shape, (None, 3, 6, 
7, 5))\n\n    def test_take_along_axis(self):\n        x = KerasTensor((None, 3))\n        indices = KerasTensor((1, 3))\n        self.assertEqual(knp.take_along_axis(x, indices, axis=0).shape, (1, 3))\n        self.assertEqual(\n            knp.take_along_axis(x, indices, axis=1).shape, (None, 3)\n        )\n\n        x = KerasTensor((None, 3, 3))\n        indices = KerasTensor((1, 3, None))\n        self.assertEqual(\n            knp.take_along_axis(x, indices, axis=1).shape, (None, 3, 3)\n        )\n\n    def test_tensordot(self):\n        x = KerasTensor((None, 3, 4))\n        y = KerasTensor((3, 4))\n        self.assertEqual(knp.tensordot(x, y, axes=1).shape, (None, 3, 4))\n        self.assertEqual(knp.tensordot(x, y, axes=[[0, 1], [1, 0]]).shape, (4,))\n\n    def test_vdot(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.vdot(x, y).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        y = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.vdot(x, y).shape, ())\n\n    def test_inner(self):\n        x = KerasTensor((None,))\n        y = KerasTensor((3,))\n        self.assertEqual(knp.inner(x, y).shape, ())\n\n    def test_where(self):\n        condition = KerasTensor((2, None, 1))\n        x = KerasTensor((None, 1))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.where(condition, x, y).shape, (2, None, 3))\n        self.assertEqual(knp.where(condition).shape, (2, None, 1))\n\n    def test_floor_divide(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.floor_divide(x, y).shape, (2, 3))\n\n    def test_xor(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((2, None))\n        self.assertEqual(knp.logical_xor(x, y).shape, (2, 3))\n\n    def test_shape_equal_basic_equality(self):\n        x = KerasTensor((3, 4)).shape\n        y = KerasTensor((3, 4)).shape\n        self.assertTrue(knp.shape_equal(x, 
y))\n        y = KerasTensor((3, 5)).shape\n        self.assertFalse(knp.shape_equal(x, y))\n\n    def test_shape_equal_allow_none(self):\n        x = KerasTensor((3, 4, None)).shape\n        y = KerasTensor((3, 4, 5)).shape\n        self.assertTrue(knp.shape_equal(x, y, allow_none=True))\n        self.assertFalse(knp.shape_equal(x, y, allow_none=False))\n\n    def test_shape_equal_different_shape_lengths(self):\n        x = KerasTensor((3, 4)).shape\n        y = KerasTensor((3, 4, 5)).shape\n        self.assertFalse(knp.shape_equal(x, y))\n\n    def test_shape_equal_ignore_axes(self):\n        x = KerasTensor((3, 4, 5)).shape\n        y = KerasTensor((3, 6, 5)).shape\n        self.assertTrue(knp.shape_equal(x, y, axis=1))\n        y = KerasTensor((3, 6, 7)).shape\n        self.assertTrue(knp.shape_equal(x, y, axis=(1, 2)))\n        self.assertFalse(knp.shape_equal(x, y, axis=1))\n\n    def test_shape_equal_only_none(self):\n        x = KerasTensor((None, None)).shape\n        y = KerasTensor((5, 6)).shape\n        self.assertTrue(knp.shape_equal(x, y, allow_none=True))\n\n    def test_shape_equal_axis_as_list(self):\n        x = KerasTensor((3, 4, 5)).shape\n        y = KerasTensor((3, 6, 5)).shape\n        self.assertTrue(knp.shape_equal(x, y, axis=[1]))\n\n    def test_shape_non_equal_with_negative_axis(self):\n        x = KerasTensor((3, 4, 5)).shape\n        y = KerasTensor((3, 4, 6)).shape\n        self.assertFalse(knp.shape_equal(x, y, axis=-2))\n\n    def test_shape_equal_with_negative_axis(self):\n        x = KerasTensor((3, 4, 5)).shape\n        y = KerasTensor((3, 4, 5)).shape\n        self.assertTrue(knp.shape_equal(x, y, axis=-1))\n\n    def test_shape_equal_zeros(self):\n        x = KerasTensor((0, 4)).shape\n        y = KerasTensor((0, 4)).shape\n        self.assertTrue(knp.shape_equal(x, y))\n        y = KerasTensor((0, 5)).shape\n        self.assertFalse(knp.shape_equal(x, y))\n\n    def test_broadcast_shapes_conversion_to_list(self):\n        
shape1 = KerasTensor((1, 2)).shape\n        shape2 = KerasTensor((3, 1)).shape\n        expected_output = [3, 2]\n        self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output)\n\n    def test_broadcast_shapes_shape1_longer_than_shape2(self):\n        shape1 = KerasTensor((5, 3, 2)).shape\n        shape2 = KerasTensor((1, 3)).shape\n        with self.assertRaisesRegex(ValueError, \"Cannot broadcast shape\"):\n            knp.broadcast_shapes(shape1, shape2)\n\n    def test_broadcast_shapes_shape2_longer_than_shape1(self):\n        shape1 = KerasTensor((5, 3)).shape\n        shape2 = KerasTensor((2, 5, 3)).shape\n        expected_output = [2, 5, 3]\n        self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output)\n\n    def test_broadcast_shapes_broadcasting_shape1_is_1(self):\n        shape1 = KerasTensor((1, 3)).shape\n        shape2 = KerasTensor((5, 1)).shape\n        expected_output = [5, 3]\n        self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output)\n\n    def test_broadcast_shapes_broadcasting_shape1_is_none(self):\n        shape1 = KerasTensor((None, 3)).shape\n        shape2 = KerasTensor((5, 1)).shape\n        expected_output = [5, 3]\n        self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output)\n\n        shape1 = KerasTensor((None, 3)).shape\n        shape2 = KerasTensor((5, 3)).shape\n        expected_output = [5, 3]\n        self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output)\n\n    def test_broadcast_shapes_broadcasting_shape2_conditions(self):\n        shape1 = KerasTensor((5, 3, 2)).shape\n        shape2 = KerasTensor((1, 3, 2)).shape\n        expected_output = [5, 3, 2]\n        self.assertEqual(knp.broadcast_shapes(shape1, shape2), expected_output)\n\n        shape1 = KerasTensor((5, 3, 2)).shape\n        shape2 = KerasTensor((1, None, 2)).shape\n        expected_output = [5, 3, 2]\n        self.assertEqual(knp.broadcast_shapes(shape1, shape2), 
expected_output)\n\n\nclass NumpyTwoInputOpsStaticShapeTest(testing.TestCase):\n    def test_add(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.add(x, y).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.add(x, y)\n\n    def test_heaviside(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.heaviside(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        y = KerasTensor((3,))\n        self.assertEqual(knp.heaviside(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        y = KerasTensor((1, 3))\n        self.assertEqual(knp.heaviside(x, y).shape, (2, 3))\n\n    def test_hypot(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.hypot(x, y).shape, (2, 3))\n\n    def test_subtract(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.subtract(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.subtract(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.subtract(x, y)\n\n    def test_multiply(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.multiply(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.multiply(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.multiply(x, y)\n\n    def test_matmul(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((3, 2))\n        self.assertEqual(knp.matmul(x, y).shape, (2, 2))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((3, 
4))\n            y = KerasTensor((2, 3, 4))\n            knp.matmul(x, y)\n\n    def test_matmul_sparse(self):\n        x = KerasTensor((2, 3), sparse=True)\n        y = KerasTensor((3, 2))\n        result = knp.matmul(x, y)\n        self.assertEqual(result.shape, (2, 2))\n\n        x = KerasTensor((2, 3))\n        y = KerasTensor((3, 2), sparse=True)\n        result = knp.matmul(x, y)\n        self.assertEqual(result.shape, (2, 2))\n\n        x = KerasTensor((2, 3), sparse=True)\n        y = KerasTensor((3, 2), sparse=True)\n        result = knp.matmul(x, y)\n        self.assertEqual(result.shape, (2, 2))\n        self.assertTrue(result.sparse)\n\n    def test_power(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.power(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.power(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.power(x, y)\n\n    def test_divide(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.divide(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.divide(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.divide(x, y)\n\n    def test_divide_no_nan(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.divide_no_nan(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.divide_no_nan(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.divide_no_nan(x, y)\n\n    def test_true_divide(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        
self.assertEqual(knp.true_divide(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.true_divide(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.true_divide(x, y)\n\n    def test_append(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.append(x, y).shape, (12,))\n\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.append(x, y, axis=0).shape, (4, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.append(x, y, axis=2)\n\n    def test_arctan2(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.arctan2(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.arctan2(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.arctan2(x, y)\n\n    def test_bitwise_and(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.bitwise_and(x, y).shape, (2, 3))\n\n    def test_bitwise_or(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.bitwise_or(x, y).shape, (2, 3))\n\n    def test_bitwise_xor(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.bitwise_xor(x, y).shape, (2, 3))\n\n    def test_bitwise_left_shift(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.bitwise_left_shift(x, y).shape, (2, 3))\n\n    # left_shift is same as bitwise_left_shift\n\n    def test_bitwise_right_shift(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n       
 self.assertEqual(knp.bitwise_right_shift(x, y).shape, (2, 3))\n\n    # right_shift is same as bitwise_right_shift\n\n    def test_cross(self):\n        x1 = KerasTensor((2, 3, 3))\n        x2 = KerasTensor((1, 3, 2))\n        y1 = KerasTensor((2, 3, 3))\n        y2 = KerasTensor((2, 3, 2))\n        self.assertEqual(knp.cross(x1, y1).shape, (2, 3, 3))\n        self.assertEqual(knp.cross(x2, y2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.cross(x, y)\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((4, 3, 3))\n            y = KerasTensor((2, 3, 3))\n            knp.cross(x, y)\n\n    def test_einsum(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((3, 4))\n        self.assertEqual(knp.einsum(\"ij,jk->ik\", x, y).shape, (2, 4))\n        self.assertEqual(knp.einsum(\"ij,jk->ikj\", x, y).shape, (2, 4, 3))\n        self.assertEqual(knp.einsum(\"ii\", x).shape, ())\n        self.assertEqual(knp.einsum(\",ij\", 5, x).shape, (2, 3))\n\n        x = KerasTensor((2, 3, 4))\n        y = KerasTensor((3, 4, 5))\n        z = KerasTensor((1, 1, 1, 9))\n        self.assertEqual(knp.einsum(\"ijk,jkl->li\", x, y).shape, (5, 2))\n        self.assertEqual(knp.einsum(\"ijk,jkl->lij\", x, y).shape, (5, 2, 3))\n        self.assertEqual(knp.einsum(\"...,...j->...j\", x, y).shape, (2, 3, 4, 5))\n        self.assertEqual(\n            knp.einsum(\"i...,...j->i...j\", x, y).shape, (2, 3, 4, 5)\n        )\n        self.assertEqual(knp.einsum(\"i...,...j\", x, y).shape, (3, 4, 2, 5))\n        self.assertEqual(knp.einsum(\"i...,...j\", x, y).shape, (3, 4, 2, 5))\n        self.assertEqual(\n            knp.einsum(\"i...,...j,...k\", x, y, z).shape, (1, 3, 4, 2, 5, 9)\n        )\n        self.assertEqual(\n            knp.einsum(\"mij,ijk,...\", x, y, z).shape, (1, 1, 1, 9, 5, 2)\n        )\n\n        with self.assertRaises(ValueError):\n        
    x = KerasTensor((2, 3))\n            y = KerasTensor((3, 4))\n            knp.einsum(\"ijk,jk->ik\", x, y)\n\n    def test_full_like(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.full_like(x, 2).shape, (2, 3))\n\n    def test_gcd(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.gcd(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.gcd(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.gcd(x, y)\n\n    def test_geomspace(self):\n        start = KerasTensor((2, 3, 4))\n        stop = KerasTensor((2, 3, 4))\n        self.assertEqual(knp.geomspace(start, stop, 10).shape, (10, 2, 3, 4))\n\n        with self.assertRaises(ValueError):\n            start = KerasTensor((2, 3))\n            stop = KerasTensor((2, 3, 4))\n            knp.geomspace(start, stop)\n\n    def test_greater(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.greater(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.greater(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.greater(x, y)\n\n    def test_greater_equal(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.greater_equal(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.greater_equal(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.greater_equal(x, y)\n\n    def test_allclose(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.allclose(x, y).shape, ())\n\n        x = 
KerasTensor((2, 3))\n        self.assertEqual(knp.allclose(x, 2).shape, ())\n\n    def test_isclose(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.isclose(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.isclose(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.isclose(x, y)\n\n    def test_isin(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((3, 3))\n        self.assertEqual(knp.isin(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.isin(x, 2).shape, (2, 3))\n\n    def test_kron(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.kron(x, y).shape, (4, 9))\n\n    def test_lcm(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.lcm(x, y).shape, (2, 3))\n\n    def test_ldexp(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.ldexp(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        y = KerasTensor((1, 3))\n        self.assertEqual(knp.ldexp(x, y).shape, (2, 3))\n\n    def test_less(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.less(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.less(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.less(x, y)\n\n    def test_less_equal(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.less_equal(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.less_equal(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n   
         x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.less_equal(x, y)\n\n    def test_linspace(self):\n        start = KerasTensor((2, 3, 4))\n        stop = KerasTensor((2, 3, 4))\n        self.assertEqual(knp.linspace(start, stop, 10).shape, (10, 2, 3, 4))\n\n        with self.assertRaises(ValueError):\n            start = KerasTensor((2, 3))\n            stop = KerasTensor((2, 3, 4))\n            knp.linspace(start, stop)\n\n    def test_logical_and(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.logical_and(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.logical_and(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.logical_and(x, y)\n\n    def test_logical_or(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.logical_or(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.logical_or(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.logical_or(x, y)\n\n    def test_logspace(self):\n        start = KerasTensor((2, 3, 4))\n        stop = KerasTensor((2, 3, 4))\n        self.assertEqual(knp.logspace(start, stop, 10).shape, (10, 2, 3, 4))\n\n        with self.assertRaises(ValueError):\n            start = KerasTensor((2, 3))\n            stop = KerasTensor((2, 3, 4))\n            knp.logspace(start, stop)\n\n    def test_maximum(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.maximum(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.maximum(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n  
          y = KerasTensor((2, 3, 4))\n            knp.maximum(x, y)\n\n    def test_minimum(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.minimum(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.minimum(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.minimum(x, y)\n\n    def test_mod(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.mod(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.mod(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.mod(x, y)\n\n    def test_fmod(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.fmod(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.fmod(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.fmod(x, y)\n\n    def test_nextafter(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.nextafter(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        y = KerasTensor((1, 3))\n        self.assertEqual(knp.nextafter(x, y).shape, (2, 3))\n\n    def test_not_equal(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.not_equal(x, y).shape, (2, 3))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.not_equal(x, 2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.not_equal(x, y)\n\n    def test_outer(self):\n        x = 
KerasTensor((3,))\n        y = KerasTensor((4,))\n        self.assertEqual(knp.outer(x, y).shape, (3, 4))\n\n        x = KerasTensor((2, 3))\n        y = KerasTensor((4, 5))\n        self.assertEqual(knp.outer(x, y).shape, (6, 20))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.outer(x, 2).shape, (6, 1))\n\n    def test_quantile(self):\n        x = KerasTensor((3, 3))\n\n        # q as scalar\n        q = KerasTensor(())\n        self.assertEqual(knp.quantile(x, q).shape, ())\n\n        # q as 1D tensor\n        q = KerasTensor((2,))\n        self.assertEqual(knp.quantile(x, q).shape, (2,))\n        self.assertEqual(knp.quantile(x, q, axis=1).shape, (2, 3))\n        self.assertEqual(\n            knp.quantile(x, q, axis=1, keepdims=True).shape,\n            (2, 3, 1),\n        )\n\n    def test_searchsorted(self):\n        a = KerasTensor((3,))\n        v = KerasTensor((2, 3))\n\n        self.assertEqual(knp.searchsorted(a, v).shape, v.shape)\n\n    def test_take(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.take(x, 1).shape, ())\n        self.assertEqual(knp.take(x, [1, 2]).shape, (2,))\n        self.assertEqual(knp.take(x, [[1, 2], [1, 2]], axis=1).shape, (2, 2, 2))\n\n        # test with multi-dimensional indices\n        x = KerasTensor((2, 3, 4, 5))\n        indices = KerasTensor((6, 7))\n        self.assertEqual(knp.take(x, indices, axis=2).shape, (2, 3, 6, 7, 5))\n\n    def test_take_along_axis(self):\n        x = KerasTensor((2, 3))\n        indices = KerasTensor((1, 3))\n        self.assertEqual(knp.take_along_axis(x, indices, axis=0).shape, (1, 3))\n        self.assertEqual(knp.take_along_axis(x, indices, axis=1).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            indices = KerasTensor((1, 4))\n            knp.take_along_axis(x, indices, axis=0)\n\n    def test_tensordot(self):\n        x = KerasTensor((2, 3, 3))\n        y = KerasTensor((3, 3, 4))\n      
  self.assertEqual(knp.tensordot(x, y, axes=1).shape, (2, 3, 3, 4))\n        self.assertEqual(knp.tensordot(x, y, axes=2).shape, (2, 4))\n        self.assertEqual(\n            knp.tensordot(x, y, axes=[[1, 2], [0, 1]]).shape, (2, 4)\n        )\n\n    def test_vdot(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.vdot(x, y).shape, ())\n\n    def test_inner(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.inner(x, y).shape, ())\n\n    def test_where(self):\n        condition = KerasTensor((2, 3))\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.where(condition, x, y).shape, (2, 3))\n        self.assertAllEqual(knp.where(condition).shape, (2, 3))\n\n    def test_floor_divide(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.floor_divide(x, y).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.floor_divide(x, y)\n\n    def test_xor(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.logical_xor(x, y).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((2, 3, 4))\n            knp.logical_xor(x, y)\n\n    def test_digitize(self):\n        x = KerasTensor((2, 3))\n        bins = KerasTensor((3,))\n        self.assertEqual(knp.digitize(x, bins).shape, (2, 3))\n        self.assertTrue(knp.digitize(x, bins).dtype == \"int32\")\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            bins = KerasTensor((2, 3, 4))\n            knp.digitize(x, bins)\n\n    def test_correlate_mode_valid(self):\n        x = KerasTensor((3,))\n        y = KerasTensor((3,))\n        self.assertEqual(knp.correlate(x, y).shape, 
(1,))\n        self.assertTrue(knp.correlate(x, y).dtype == \"float32\")\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((3,))\n            y = KerasTensor((3, 4))\n            knp.correlate(x, y)\n\n    def test_correlate_mode_same(self):\n        x = KerasTensor((3,))\n        y = KerasTensor((3,))\n        self.assertEqual(knp.correlate(x, y, mode=\"same\").shape, (3,))\n        self.assertTrue(knp.correlate(x, y, mode=\"same\").dtype == \"float32\")\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((3,))\n            y = KerasTensor((3, 4))\n            knp.correlate(x, y, mode=\"same\")\n\n    def test_correlate_mode_full(self):\n        x = KerasTensor((3,))\n        y = KerasTensor((3,))\n        self.assertEqual(knp.correlate(x, y, mode=\"full\").shape, (5,))\n        self.assertTrue(knp.correlate(x, y, mode=\"full\").dtype == \"float32\")\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((3,))\n            y = KerasTensor((3, 4))\n            knp.correlate(x, y, mode=\"full\")\n\n\nclass NumpyOneInputOpsDynamicShapeTest(testing.TestCase):\n    def test_mean(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.mean(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.mean(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.mean(x, axis=1, keepdims=True).shape, (None, 1, 3))\n\n    def test_all(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.all(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.all(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.all(x, axis=1, keepdims=True).shape, (None, 1, 3))\n\n    def test_any(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.any(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.any(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.any(x, axis=1, 
keepdims=True).shape, (None, 1, 3))\n\n    def test_trapezoid(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.trapezoid(x).shape, (None,))\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.trapezoid(x, axis=1).shape, (None, 3))\n\n    def test_vander(self):\n        x = KerasTensor((None,))\n        self.assertEqual(knp.vander(x).shape, (None, None))\n\n    def test_var(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.var(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.var(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.var(x, axis=1, keepdims=True).shape, (None, 1, 3))\n\n    def test_sum(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.sum(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.sum(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.sum(x, axis=1, keepdims=True).shape, (None, 1, 3))\n\n    def test_amax(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.amax(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.amax(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.amax(x, axis=1, keepdims=True).shape, (None, 1, 3))\n\n    def test_amin(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.amin(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.amin(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.amin(x, axis=1, keepdims=True).shape, (None, 1, 3))\n\n    def test_square(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.square(x).shape, (None, 3))\n\n    def test_negative(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.negative(x).shape, (None, 3))\n\n    def test_abs(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.abs(x).shape, (None, 3))\n\n    def test_absolute(self):\n      
  x = KerasTensor((None, 3))\n        self.assertEqual(knp.absolute(x).shape, (None, 3))\n\n    def test_squeeze(self):\n        x = KerasTensor((None, 1))\n        self.assertEqual(knp.squeeze(x).shape, (None,))\n        self.assertEqual(knp.squeeze(x, axis=1).shape, (None,))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((None, 1))\n            knp.squeeze(x, axis=0)\n\n        # Multiple axes\n        x = KerasTensor((None, 1, 1, 1))\n        self.assertEqual(knp.squeeze(x, (1, 2)).shape, (None, 1))\n        self.assertEqual(knp.squeeze(x, (-1, -2)).shape, (None, 1))\n        self.assertEqual(knp.squeeze(x, (1, 2, 3)).shape, (None,))\n        self.assertEqual(knp.squeeze(x, (-1, 1)).shape, (None, 1))\n\n    def test_transpose(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.transpose(x).shape, (3, None))\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.transpose(x, (2, 0, 1)).shape, (3, None, 3))\n\n    def test_arccos(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.arccos(x).shape, (None, 3))\n\n    def test_arccosh(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.arccosh(x).shape, (None, 3))\n\n    def test_arcsin(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.arcsin(x).shape, (None, 3))\n\n    def test_arcsinh(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.arcsinh(x).shape, (None, 3))\n\n    def test_arctan(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.arctan(x).shape, (None, 3))\n\n    def test_arctanh(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.arctanh(x).shape, (None, 3))\n\n    def test_argmax(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.argmax(x).shape, ())\n        self.assertEqual(knp.argmax(x, keepdims=True).shape, (None, 3))\n\n        x = KerasTensor((None, 3, 3))\n        
self.assertEqual(knp.argmax(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.argmax(x, keepdims=True).shape, (None, 3, 3))\n\n    @pytest.mark.skipif(\n        keras.config.backend() == \"openvino\" or testing.jax_uses_tpu(),\n        reason=\"OpenVINO and JAX TPU don't support this\",\n    )\n    def test_argmax_negative_zero(self):\n        input_data = np.array(\n            [-1.0, -0.0, 1.401298464324817e-45], dtype=np.float32\n        )\n        self.assertEqual(knp.argmax(input_data), 2)\n\n    @pytest.mark.skipif(\n        keras.config.backend() == \"openvino\" or testing.jax_uses_tpu(),\n        reason=\"OpenVINO and JAX TPU don't support this\",\n    )\n    def test_argmin_negative_zero(self):\n        input_data = np.array(\n            [\n                0.0,\n                1.1754943508222875e-38,\n                -1.401298464324817e-45,\n                0.0,\n                459367.0,\n            ],\n            dtype=np.float32,\n        )\n        self.assertEqual(knp.argmin(input_data), 2)\n\n    def test_argmin(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.argmin(x).shape, ())\n        self.assertEqual(knp.argmin(x, keepdims=True).shape, (None, 3))\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.argmin(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.argmin(x, keepdims=True).shape, (None, 3, 3))\n\n    def test_argsort(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.argsort(x).shape, (None, 3))\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.argsort(x, axis=1).shape, (None, 3, 3))\n\n    def test_array(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.array(x).shape, (None, 3))\n\n    def test_average(self):\n        x = KerasTensor((None, 3))\n        weights = KerasTensor((None, 3))\n        self.assertEqual(knp.average(x, weights=weights).shape, ())\n\n        x = KerasTensor((None, 3))\n        weights = 
KerasTensor((3,))\n        self.assertEqual(knp.average(x, axis=1, weights=weights).shape, (None,))\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.average(x, axis=1).shape, (None, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((None, 3, 3))\n            weights = KerasTensor((None, 4))\n            knp.average(x, weights=weights)\n\n    def test_bartlett(self):\n        x = np.random.randint(1, 100 + 1)\n        self.assertEqual(knp.bartlett(x).shape[0], x)\n\n    def test_blackman(self):\n        x = np.random.randint(1, 100 + 1)\n        self.assertEqual(knp.blackman(x).shape[0], x)\n\n    def test_hamming(self):\n        x = np.random.randint(1, 100 + 1)\n        self.assertEqual(knp.hamming(x).shape[0], x)\n\n    def test_hanning(self):\n        x = np.random.randint(1, 100 + 1)\n        self.assertEqual(knp.hanning(x).shape[0], x)\n\n    def test_kaiser(self):\n        x = np.random.randint(1, 100 + 1)\n        beta = float(np.random.randint(10, 20 + 1))\n        self.assertEqual(knp.kaiser(x, beta).shape[0], x)\n\n    def test_bitwise_invert(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.bitwise_invert(x).shape, (None, 3))\n\n    # bitwise_not is same as bitwise_invert\n\n    def test_broadcast_to(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.broadcast_to(x, (2, 3, 3)).shape, (2, 3, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((3, 3))\n            knp.broadcast_to(x, (2, 2, 3))\n\n    def test_cbrt(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.cbrt(x).shape, (None, 3))\n\n    def test_ceil(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.ceil(x).shape, (None, 3))\n\n    def test_clip(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.clip(x, 1, 2).shape, (None, 3))\n\n    def test_concatenate(self):\n        x = KerasTensor((None, 3))\n        y 
= KerasTensor((None, 3))\n        self.assertEqual(\n            knp.concatenate(\n                [x, y],\n            ).shape,\n            (None, 3),\n        )\n        self.assertEqual(knp.concatenate([x, y], axis=1).shape, (None, 6))\n\n        with self.assertRaises(ValueError):\n            self.assertEqual(knp.concatenate([x, y], axis=None).shape, (None,))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((None, 3, 5))\n            y = KerasTensor((None, 4, 6))\n            knp.concatenate([x, y], axis=1)\n\n    def test_concatenate_sparse(self):\n        x = KerasTensor((2, 3), sparse=True)\n        y = KerasTensor((2, 3))\n        result = knp.concatenate([x, y], axis=1)\n        self.assertEqual(result.shape, (2, 6))\n        self.assertFalse(result.sparse)\n\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3), sparse=True)\n        result = knp.concatenate([x, y], axis=1)\n        self.assertEqual(result.shape, (2, 6))\n        self.assertFalse(result.sparse)\n\n        x = KerasTensor((2, 3), sparse=True)\n        y = KerasTensor((2, 3), sparse=True)\n        result = knp.concatenate([x, y], axis=1)\n        self.assertEqual(result.shape, (2, 6))\n        self.assertTrue(result.sparse)\n\n    def test_conjugate(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.conjugate(x).shape, (None, 3))\n\n    def test_conj(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.conj(x).shape, (None, 3))\n\n    def test_copy(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.copy(x).shape, (None, 3))\n\n    def test_corrcoef(self):\n        x = KerasTensor((3, None))\n        self.assertEqual(knp.corrcoef(x).shape, (3, None))\n\n    def test_cos(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.cos(x).shape, (None, 3))\n\n    def test_cosh(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.cosh(x).shape, (None, 3))\n\n   
 def test_count_nonzero(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.count_nonzero(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.count_nonzero(x, axis=1).shape, (None, 3))\n\n    def test_cumprod(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.cumprod(x).shape, (None,))\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.cumprod(x, axis=1).shape, (None, 3, 3))\n\n    def test_cumsum(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.cumsum(x).shape, (None,))\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.cumsum(x, axis=1).shape, (None, 3, 3))\n\n    def test_deg2rad(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.deg2rad(x).shape, (None, 3))\n\n    def test_diag(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.diag(x).shape, (None,))\n        self.assertEqual(knp.diag(x, k=3).shape, (None,))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3, 4))\n            knp.diag(x)\n\n    def test_diagflat(self):\n        x = KerasTensor((3,))\n        self.assertEqual(knp.diagflat(x).shape, (3, 3))\n        self.assertEqual(knp.diagflat(x, k=1).shape, (4, 4))\n        self.assertEqual(knp.diagflat(x, k=-1).shape, (4, 4))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.diagflat(x).shape, (6, 6))\n        self.assertEqual(knp.diagflat(x, k=2).shape, (8, 8))\n\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.diagflat(x).shape, (None, None))\n\n    def test_diagonal(self):\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.diagonal(x).shape, (3, None))\n\n    def test_diff(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.diff(x).shape, (None, 2))\n        self.assertEqual(knp.diff(x, n=2).shape, (None, 1))\n        self.assertEqual(knp.diff(x, n=3).shape, (None, 0))\n     
   self.assertEqual(knp.diff(x, n=4).shape, (None, 0))\n\n        self.assertEqual(knp.diff(x, axis=0).shape, (None, 3))\n        self.assertEqual(knp.diff(x, n=2, axis=0).shape, (None, 3))\n\n    def test_dot(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((3, 2))\n        z = KerasTensor((None, None, 2))\n        self.assertEqual(knp.dot(x, y).shape, (None, 2))\n        self.assertEqual(knp.dot(x, 2).shape, (None, 3))\n        self.assertEqual(knp.dot(x, z).shape, (None, None, 2))\n\n        x = KerasTensor((None,))\n        y = KerasTensor((5,))\n        self.assertEqual(knp.dot(x, y).shape, ())\n\n    def test_empty_like(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.empty_like(x).shape, (None, 3))\n        self.assertEqual(knp.empty_like(x).dtype, x.dtype)\n\n    def test_exp(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.exp(x).shape, (None, 3))\n\n    def test_exp2(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.exp2(x).shape, (None, 3))\n\n    def test_expand_dims(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.expand_dims(x, -1).shape, (None, 3, 1))\n        self.assertEqual(knp.expand_dims(x, 0).shape, (1, None, 3))\n        self.assertEqual(knp.expand_dims(x, 1).shape, (None, 1, 3))\n        self.assertEqual(knp.expand_dims(x, -2).shape, (None, 1, 3))\n\n        # Multiple axes\n        self.assertEqual(knp.expand_dims(x, (1, 2)).shape, (None, 1, 1, 3))\n        self.assertEqual(knp.expand_dims(x, (-1, -2)).shape, (None, 3, 1, 1))\n        self.assertEqual(knp.expand_dims(x, (-1, 1)).shape, (None, 1, 3, 1))\n\n    def test_expm1(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.expm1(x).shape, (None, 3))\n\n    def test_flip(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.flip(x).shape, (None, 3))\n\n    def test_floor(self):\n        x = KerasTensor((None, 3))\n        
self.assertEqual(knp.floor(x).shape, (None, 3))\n\n    def test_get_item(self):\n        x = KerasTensor((None, 5, 16))\n        # Simple slice.\n        sliced = knp.get_item(x, 5)\n        self.assertEqual(sliced.shape, (5, 16))\n        # Ellipsis slice.\n        sliced = knp.get_item(x, np.s_[..., -1])\n        self.assertEqual(sliced.shape, (None, 5))\n        # `newaxis` slice.\n        sliced = knp.get_item(x, np.s_[:, np.newaxis, ...])\n        self.assertEqual(sliced.shape, (None, 1, 5, 16))\n        # Strided slice.\n        sliced = knp.get_item(x, np.s_[:5, 3:, 3:12:2])\n        self.assertEqual(sliced.shape, (None, 2, 5))\n        # Error states.\n        with self.assertRaises(ValueError):\n            sliced = knp.get_item(x, np.s_[:, 17, :])\n        with self.assertRaises(ValueError):\n            sliced = knp.get_item(x, np.s_[..., 5, ...])\n        with self.assertRaises(ValueError):\n            sliced = knp.get_item(x, np.s_[:, :, :, :])\n\n    def test_hstack(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.hstack([x, y]).shape, (None, 6))\n\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, None))\n        self.assertEqual(knp.hstack([x, y]).shape, (None, None))\n\n    def test_imag(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.imag(x).shape, (None, 3))\n\n    def test_isfinite(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.isfinite(x).shape, (None, 3))\n\n    def test_isinf(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.isinf(x).shape, (None, 3))\n\n    def test_isnan(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.isnan(x).shape, (None, 3))\n\n    def test_isneginf(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.isneginf(x).shape, (None, 3))\n\n    def test_isposinf(self):\n        x = KerasTensor((None, 3))\n        
self.assertEqual(knp.isposinf(x).shape, (None, 3))\n\n    def test_isreal(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.isreal(x).shape, (None, 3))\n\n    def test_log(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.log(x).shape, (None, 3))\n\n    def test_log10(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.log10(x).shape, (None, 3))\n\n    def test_log1p(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.log1p(x).shape, (None, 3))\n\n    def test_log2(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.log2(x).shape, (None, 3))\n\n    def test_logaddexp(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.logaddexp(x, x).shape, (None, 3))\n\n    def test_logaddexp2(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.logaddexp2(x, x).shape, (None, 3))\n\n    def test_logical_not(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.logical_not(x).shape, (None, 3))\n\n    def test_max(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.max(x).shape, ())\n\n    def test_median(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.median(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.median(x, axis=1).shape, (None, 3))\n        self.assertEqual(\n            knp.median(x, axis=1, keepdims=True).shape, (None, 1, 3)\n        )\n\n    def test_meshgrid(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.meshgrid(x, y)[0].shape, (None, None))\n        self.assertEqual(knp.meshgrid(x, y)[1].shape, (None, None))\n\n        with self.assertRaises(ValueError):\n            knp.meshgrid(x, y, indexing=\"kk\")\n\n    def test_moveaxis(self):\n        x = KerasTensor((None, 3, 4, 5))\n        self.assertEqual(knp.moveaxis(x, 0, -1).shape, (3, 4, 5, None))\n        
self.assertEqual(knp.moveaxis(x, -1, 0).shape, (5, None, 3, 4))\n        self.assertEqual(\n            knp.moveaxis(x, [0, 1], [-1, -2]).shape, (4, 5, 3, None)\n        )\n        self.assertEqual(knp.moveaxis(x, [0, 1], [1, 0]).shape, (3, None, 4, 5))\n        self.assertEqual(\n            knp.moveaxis(x, [0, 1], [-2, -1]).shape, (4, 5, None, 3)\n        )\n\n    def test_nanargmax(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.nanargmax(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.nanargmax(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.nanargmax(x, axis=None).shape, ())\n\n        x = KerasTensor((None, 2, 3, 4))\n        self.assertEqual(knp.nanargmax(x, axis=2).shape, (None, 2, 4))\n\n    def test_nanargmin(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.nanargmin(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.nanargmin(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.nanargmin(x, axis=None).shape, ())\n\n        x = KerasTensor((None, 2, 3, 4))\n        self.assertEqual(knp.nanargmin(x, axis=2).shape, (None, 2, 4))\n\n    def test_nancumsum(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.nancumsum(x).shape, (None,))\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.nancumsum(x, axis=1).shape, (None, 3, 3))\n        self.assertEqual(knp.nancumsum(x, axis=(1,)).shape, (None, 3, 3))\n        self.assertEqual(knp.nancumsum(x, axis=None).shape, (None,))\n\n        x = KerasTensor((None, 2, 3, 4))\n        self.assertEqual(knp.nancumsum(x, axis=2).shape, (None, 2, 3, 4))\n\n    def test_nancumprod(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.nancumprod(x).shape, (None,))\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.nancumprod(x, axis=1).shape, (None, 3, 3))\n        self.assertEqual(knp.nancumprod(x, 
axis=(1,)).shape, (None, 3, 3))\n        self.assertEqual(knp.nancumprod(x, axis=None).shape, (None,))\n\n        x = KerasTensor((None, 2, 3, 4))\n        self.assertEqual(knp.nancumprod(x, axis=2).shape, (None, 2, 3, 4))\n\n    def test_nanmax(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.nanmax(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.nanmax(x, axis=1).shape, (None, 3))\n        self.assertEqual(\n            knp.nanmax(x, axis=1, keepdims=True).shape, (None, 1, 3)\n        )\n\n        self.assertEqual(knp.nanmax(x, axis=(1,)).shape, (None, 3))\n\n        self.assertEqual(knp.nanmax(x, axis=(1, 2)).shape, (None,))\n        self.assertEqual(\n            knp.nanmax(x, axis=(1, 2), keepdims=True).shape, (None, 1, 1)\n        )\n\n        self.assertEqual(knp.nanmax(x, axis=()).shape, (None, 3, 3))\n\n        x4 = KerasTensor((None, 2, 3, 4))\n        self.assertEqual(knp.nanmax(x4, axis=2).shape, (None, 2, 4))\n        self.assertEqual(knp.nanmax(x4, axis=(1, 3)).shape, (None, 3))\n\n    def test_nanmean(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.nanmean(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.nanmean(x, axis=1).shape, (None, 3))\n        self.assertEqual(\n            knp.nanmean(x, axis=1, keepdims=True).shape, (None, 1, 3)\n        )\n\n        self.assertEqual(knp.nanmean(x, axis=(1,)).shape, (None, 3))\n\n        self.assertEqual(knp.nanmean(x, axis=(1, 2)).shape, (None,))\n        self.assertEqual(\n            knp.nanmean(x, axis=(1, 2), keepdims=True).shape, (None, 1, 1)\n        )\n\n        self.assertEqual(knp.nanmean(x, axis=()).shape, (None, 3, 3))\n\n        x4 = KerasTensor((None, 2, 3, 4))\n        self.assertEqual(knp.nanmean(x4, axis=2).shape, (None, 2, 4))\n        self.assertEqual(knp.nanmean(x4, axis=(1, 3)).shape, (None, 3))\n\n    def test_nanmin(self):\n        x = KerasTensor((None, 3))\n        
self.assertEqual(knp.nanmin(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.nanmin(x, axis=1).shape, (None, 3))\n        self.assertEqual(\n            knp.nanmin(x, axis=1, keepdims=True).shape, (None, 1, 3)\n        )\n\n        self.assertEqual(knp.nanmin(x, axis=(1,)).shape, (None, 3))\n\n        self.assertEqual(knp.nanmin(x, axis=(1, 2)).shape, (None,))\n        self.assertEqual(\n            knp.nanmin(x, axis=(1, 2), keepdims=True).shape, (None, 1, 1)\n        )\n\n        self.assertEqual(knp.nanmin(x, axis=()).shape, (None, 3, 3))\n\n        x4 = KerasTensor((None, 2, 3, 4))\n        self.assertEqual(knp.nanmin(x4, axis=2).shape, (None, 2, 4))\n        self.assertEqual(knp.nanmin(x4, axis=(1, 3)).shape, (None, 3))\n\n    def test_nanprod(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.nanprod(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.nanprod(x, axis=1).shape, (None, 3))\n        self.assertEqual(\n            knp.nanprod(x, axis=1, keepdims=True).shape, (None, 1, 3)\n        )\n\n        self.assertEqual(knp.nanprod(x, axis=(1,)).shape, (None, 3))\n\n        self.assertEqual(knp.nanprod(x, axis=(1, 2)).shape, (None,))\n        self.assertEqual(\n            knp.nanprod(x, axis=(1, 2), keepdims=True).shape, (None, 1, 1)\n        )\n\n        self.assertEqual(knp.nanprod(x, axis=()).shape, (None, 3, 3))\n\n        x4 = KerasTensor((None, 2, 3, 4))\n        self.assertEqual(knp.nanprod(x4, axis=2).shape, (None, 2, 4))\n        self.assertEqual(knp.nanprod(x4, axis=(1, 3)).shape, (None, 3))\n\n    def test_nanstd(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.nanstd(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.nanstd(x, axis=1).shape, (None, 3))\n        self.assertEqual(\n            knp.nanstd(x, axis=1, keepdims=True).shape, (None, 1, 3)\n        )\n\n        
self.assertEqual(knp.nanstd(x, axis=(1,)).shape, (None, 3))\n\n        self.assertEqual(knp.nanstd(x, axis=(1, 2)).shape, (None,))\n        self.assertEqual(\n            knp.nanstd(x, axis=(1, 2), keepdims=True).shape, (None, 1, 1)\n        )\n\n        self.assertEqual(knp.nanstd(x, axis=()).shape, (None, 3, 3))\n\n        x4 = KerasTensor((None, 2, 3, 4))\n        self.assertEqual(knp.nanstd(x4, axis=2).shape, (None, 2, 4))\n        self.assertEqual(knp.nanstd(x4, axis=(1, 3)).shape, (None, 3))\n\n    def test_nansum(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.nansum(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.nansum(x, axis=1).shape, (None, 3))\n        self.assertEqual(\n            knp.nansum(x, axis=1, keepdims=True).shape, (None, 1, 3)\n        )\n\n        self.assertEqual(knp.nansum(x, axis=(1,)).shape, (None, 3))\n\n        self.assertEqual(knp.nansum(x, axis=(1, 2)).shape, (None,))\n        self.assertEqual(\n            knp.nansum(x, axis=(1, 2), keepdims=True).shape, (None, 1, 1)\n        )\n\n        self.assertEqual(knp.nansum(x, axis=()).shape, (None, 3, 3))\n\n        x4 = KerasTensor((None, 2, 3, 4))\n        self.assertEqual(knp.nansum(x4, axis=2).shape, (None, 2, 4))\n        self.assertEqual(knp.nansum(x4, axis=(1, 3)).shape, (None, 3))\n\n    def test_nanvar(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.nanvar(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.nanvar(x, axis=1).shape, (None, 3))\n        self.assertEqual(\n            knp.nanvar(x, axis=1, keepdims=True).shape, (None, 1, 3)\n        )\n\n        self.assertEqual(knp.nanvar(x, axis=(1,)).shape, (None, 3))\n\n        self.assertEqual(knp.nanvar(x, axis=(1, 2)).shape, (None,))\n        self.assertEqual(\n            knp.nanvar(x, axis=(1, 2), keepdims=True).shape, (None, 1, 1)\n        )\n\n        self.assertEqual(knp.nanvar(x, axis=()).shape, 
(None, 3, 3))\n\n        x4 = KerasTensor((None, 2, 3, 4))\n        self.assertEqual(knp.nanvar(x4, axis=2).shape, (None, 2, 4))\n        self.assertEqual(knp.nanvar(x4, axis=(1, 3)).shape, (None, 3))\n\n    def test_ndim(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.ndim(x).shape, (2,))\n\n    def test_nonzero(self):\n        x = KerasTensor((None, 5, 6))\n        result = knp.nonzero(x)\n        self.assertLen(result, 3)\n        self.assertEqual(result[0].shape, (None,))\n        self.assertEqual(result[1].shape, (None,))\n        self.assertEqual(result[2].shape, (None,))\n\n    def test_ones_like(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.ones_like(x).shape, (None, 3))\n        self.assertEqual(knp.ones_like(x).dtype, x.dtype)\n\n    def test_zeros_like(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.zeros_like(x).shape, (None, 3))\n        self.assertEqual(knp.zeros_like(x).dtype, x.dtype)\n\n    def test_pad(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.pad(x, 1).shape, (None, 5))\n        self.assertEqual(knp.pad(x, (1, 2)).shape, (None, 6))\n        self.assertEqual(knp.pad(x, ((1, 2), (3, 4))).shape, (None, 10))\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.pad(x, 1).shape, (None, 5, 5))\n        self.assertEqual(knp.pad(x, (1, 2)).shape, (None, 6, 6))\n        self.assertEqual(\n            knp.pad(x, ((1, 2), (3, 4), (5, 6))).shape, (None, 10, 14)\n        )\n\n    def test_prod(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.prod(x).shape, ())\n        self.assertEqual(knp.prod(x, axis=0).shape, (3,))\n        self.assertEqual(knp.prod(x, axis=1, keepdims=True).shape, (None, 1))\n\n    def test_ptp(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.ptp(x).shape, ())\n        self.assertEqual(knp.ptp(x, axis=0).shape, (3,))\n        self.assertEqual(knp.ptp(x, axis=1, 
keepdims=True).shape, (None, 1))\n\n    def test_ravel(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.ravel(x).shape, (None,))\n\n    def test_unravel_index(self):\n        x = KerasTensor((None,))\n        indices = knp.unravel_index(x, (2, 3))\n        self.assertEqual(len(indices), 2)\n        self.assertEqual(indices[0].shape, (None,))\n        self.assertEqual(indices[1].shape, (None,))\n\n        x = KerasTensor((None, 4))\n        indices = knp.unravel_index(x, (3, 4))\n        self.assertEqual(len(indices), 2)\n        self.assertEqual(indices[0].shape, (None, 4))\n        self.assertEqual(indices[1].shape, (None, 4))\n\n        x = KerasTensor((None, 3, 2))\n        indices = knp.unravel_index(x, (5, 6, 4))\n        self.assertEqual(len(indices), 3)\n        self.assertEqual(indices[0].shape, (None, 3, 2))\n        self.assertEqual(indices[1].shape, (None, 3, 2))\n        self.assertEqual(indices[2].shape, (None, 3, 2))\n\n    def test_real(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.real(x).shape, (None, 3))\n\n    def test_reciprocal(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.reciprocal(x).shape, (None, 3))\n\n    def test_repeat(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.repeat(x, 2).shape, (None,))\n        self.assertEqual(knp.repeat(x, 3, axis=1).shape, (None, 9))\n        self.assertEqual(knp.repeat(x, [1, 2], axis=0).shape, (None, 3))\n        self.assertEqual(knp.repeat(x, 2, axis=0).shape, (None, 3))\n\n    def test_reshape(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.reshape(x, (3, 2)).shape, (3, 2))\n        self.assertEqual(knp.reshape(x, (3, -1)).shape, (3, None))\n\n    def test_roll(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.roll(x, 1).shape, (None, 3))\n        self.assertEqual(knp.roll(x, 1, axis=1).shape, (None, 3))\n        self.assertEqual(knp.roll(x, 1, 
axis=0).shape, (None, 3))\n\n    def test_round(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.round(x).shape, (None, 3))\n\n    def test_sign(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.sign(x).shape, (None, 3))\n\n    def test_signbit(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.signbit(x).shape, (None, 3))\n\n    def test_sin(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.sin(x).shape, (None, 3))\n\n    def test_sinc(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.sinc(x).shape, (None, 3))\n\n    def test_sinh(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.sinh(x).shape, (None, 3))\n\n    def test_size(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.size(x).shape, ())\n\n    def test_sort(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.sort(x).shape, (None, 3))\n        self.assertEqual(knp.sort(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.sort(x, axis=0).shape, (None, 3))\n\n    def test_split(self):\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.split(x, 2)[0].shape, (None, 3, 3))\n        self.assertEqual(knp.split(x, 3, axis=1)[0].shape, (None, 1, 3))\n        self.assertEqual(len(knp.split(x, [1, 3], axis=1)), 3)\n        self.assertEqual(knp.split(x, [1, 3], axis=1)[0].shape, (None, 1, 3))\n        self.assertEqual(knp.split(x, [1, 3], axis=1)[1].shape, (None, 2, 3))\n        self.assertEqual(knp.split(x, [1, 3], axis=1)[2].shape, (None, 0, 3))\n\n    def test_sqrt(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.sqrt(x).shape, (None, 3))\n\n    def test_stack(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.stack([x, y]).shape, (2, None, 3))\n        self.assertEqual(knp.stack([x, y], axis=-1).shape, (None, 3, 2))\n\n    def 
test_std(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.std(x).shape, ())\n\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.std(x, axis=1).shape, (None, 3))\n        self.assertEqual(knp.std(x, axis=1, keepdims=True).shape, (None, 1, 3))\n\n    def test_swapaxes(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.swapaxes(x, 0, 1).shape, (3, None))\n\n    def test_tan(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.tan(x).shape, (None, 3))\n\n    def test_tanh(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.tanh(x).shape, (None, 3))\n\n    def test_tile(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.tile(x, 2).shape, (None, 6))\n        self.assertEqual(knp.tile(x, [2]).shape, (None, 6))\n        self.assertEqual(knp.tile(x, [1, 2]).shape, (None, 6))\n        self.assertEqual(knp.tile(x, [2, 1, 2]).shape, (2, None, 6))\n\n        # Test with multi-dimensional input\n        x = KerasTensor((None, 3, 2, 2))\n        self.assertEqual(knp.tile(x, [1, 2, 1, 1]).shape, (None, 6, 2, 2))\n\n    def test_trace(self):\n        x = KerasTensor((None, 3, None, 5))\n        self.assertEqual(knp.trace(x).shape, (None, 5))\n        self.assertEqual(knp.trace(x, axis1=2, axis2=3).shape, (None, 3))\n\n    def test_tril(self):\n        x = KerasTensor((None, 3, None, 5))\n        self.assertEqual(knp.tril(x).shape, (None, 3, None, 5))\n        self.assertEqual(knp.tril(x, k=1).shape, (None, 3, None, 5))\n        self.assertEqual(knp.tril(x, k=-1).shape, (None, 3, None, 5))\n\n    def test_triu(self):\n        x = KerasTensor((None, 3, None, 5))\n        self.assertEqual(knp.triu(x).shape, (None, 3, None, 5))\n        self.assertEqual(knp.triu(x, k=1).shape, (None, 3, None, 5))\n        self.assertEqual(knp.triu(x, k=-1).shape, (None, 3, None, 5))\n\n    def test_trunc(self):\n        x = KerasTensor((None, 3, None, 5))\n        
self.assertEqual(knp.trunc(x).shape, (None, 3, None, 5))\n\n    def test_vstack(self):\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.vstack([x, y]).shape, (None, 3))\n\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, None))\n        self.assertEqual(knp.vstack([x, y]).shape, (None, 3))\n\n    def test_dstack(self):\n        x = KerasTensor((None,))\n        y = KerasTensor((None,))\n        self.assertEqual(knp.dstack([x, y]).shape, (1, None, 2))\n\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, 3))\n        self.assertEqual(knp.dstack([x, y]).shape, (None, 3, 2))\n\n        x = KerasTensor((None, 3))\n        y = KerasTensor((None, None))\n        self.assertEqual(knp.dstack([x, y]).shape, (None, 3, 2))\n\n    def test_hsplit(self):\n        x = KerasTensor((3, None, 3))\n        self.assertEqual(knp.hsplit(x, 2)[0].shape, (3, None, 3))\n        self.assertEqual(len(knp.hsplit(x, [1, 3])), 3)\n        self.assertEqual(knp.hsplit(x, [1, 3])[0].shape, (3, 1, 3))\n        self.assertEqual(knp.hsplit(x, [1, 3])[1].shape, (3, 2, 3))\n        self.assertEqual(knp.hsplit(x, [1, 3])[2].shape, (3, None, 3))\n\n        # test 1D case\n        x_1d = KerasTensor((None,))\n        self.assertEqual(knp.hsplit(x_1d, 2)[0].shape, (None,))\n\n        splits_1d = knp.hsplit(x_1d, [2, 5])\n        self.assertEqual(splits_1d[0].shape, (2,))\n        self.assertEqual(splits_1d[1].shape, (3,))\n        self.assertEqual(splits_1d[2].shape, (None,))\n\n    def test_vsplit(self):\n        x = KerasTensor((None, 3, 3))\n        self.assertEqual(knp.vsplit(x, 2)[0].shape, (None, 3, 3))\n        self.assertEqual(len(knp.vsplit(x, [1, 3])), 3)\n        self.assertEqual(knp.vsplit(x, [1, 3])[0].shape, (1, 3, 3))\n        self.assertEqual(knp.vsplit(x, [1, 3])[1].shape, (2, 3, 3))\n        self.assertEqual(knp.vsplit(x, [1, 3])[2].shape, (None, 3, 3))\n\n    def test_argpartition(self):\n        x = 
KerasTensor((None, 3))\n        self.assertEqual(knp.argpartition(x, 3).shape, (None, 3))\n        self.assertEqual(knp.argpartition(x, 1, axis=1).shape, (None, 3))\n\n        with self.assertRaises(ValueError):\n            knp.argpartition(x, (1, 3))\n\n    def test_angle(self):\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.angle(x).shape, (None, 3))\n\n    def test_view(self):\n        x = knp.array(KerasTensor((None, 3)), dtype=\"int32\")\n        self.assertEqual(knp.view(x, dtype=\"uint32\").shape, (None, 3))\n        self.assertEqual(knp.view(x, dtype=\"uint32\").dtype, \"uint32\")\n        x = knp.array(KerasTensor((None, 3)), dtype=\"int32\")\n        self.assertEqual(knp.view(x, dtype=\"int16\").shape, (None, 6))\n        self.assertEqual(knp.view(x, dtype=\"int16\").dtype, \"int16\")\n        x = knp.array(KerasTensor((None, 4)), dtype=\"int16\")\n        self.assertEqual(knp.view(x, dtype=\"int32\").shape, (None, 2))\n        self.assertEqual(knp.view(x, dtype=\"int32\").dtype, \"int32\")\n\n    def test_array_split(self):\n        x = KerasTensor((None, 4))\n        splits = knp.array_split(x, 2, axis=0)\n        self.assertEqual(len(splits), 2)\n        self.assertEqual(splits[0].shape, (None, 4))\n        self.assertEqual(splits[1].shape, (None, 4))\n\n\nclass NumpyOneInputOpsStaticShapeTest(testing.TestCase):\n    def test_mean(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.mean(x).shape, ())\n\n    def test_all(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.all(x).shape, ())\n\n    def test_any(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.any(x).shape, ())\n\n    def test_trapezoid(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.trapezoid(x).shape, (2,))\n\n    def test_vander(self):\n        x = KerasTensor((2,))\n        self.assertEqual(knp.vander(x).shape, (2, 2))\n\n    def test_var(self):\n        x = KerasTensor((2, 3))\n        
self.assertEqual(knp.var(x).shape, ())\n\n    def test_sum(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.sum(x).shape, ())\n\n    def test_amax(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.amax(x).shape, ())\n\n    def test_amin(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.amin(x).shape, ())\n\n    def test_square(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.square(x).shape, (2, 3))\n\n    def test_negative(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.negative(x).shape, (2, 3))\n\n    def test_abs(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.abs(x).shape, (2, 3))\n\n    def test_absolute(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.absolute(x).shape, (2, 3))\n\n    def test_squeeze(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.squeeze(x).shape, (2, 3))\n\n        x = KerasTensor((2, 1, 3))\n        self.assertEqual(knp.squeeze(x).shape, (2, 3))\n        self.assertEqual(knp.squeeze(x, axis=1).shape, (2, 3))\n        self.assertEqual(knp.squeeze(x, axis=-2).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            knp.squeeze(x, axis=0)\n\n        # Multiple axes\n        x = KerasTensor((2, 1, 1, 1))\n        self.assertEqual(knp.squeeze(x, (1, 2)).shape, (2, 1))\n        self.assertEqual(knp.squeeze(x, (-1, -2)).shape, (2, 1))\n        self.assertEqual(knp.squeeze(x, (1, 2, 3)).shape, (2,))\n        self.assertEqual(knp.squeeze(x, (-1, 1)).shape, (2, 1))\n\n    def test_transpose(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.transpose(x).shape, (3, 2))\n\n    def test_arccos(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.arccos(x).shape, (2, 3))\n\n    def test_arccosh(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.arccosh(x).shape, (2, 3))\n\n    def test_arcsin(self):\n        x = 
KerasTensor((2, 3))\n        self.assertEqual(knp.arcsin(x).shape, (2, 3))\n\n    def test_arcsinh(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.arcsinh(x).shape, (2, 3))\n\n    def test_arctan(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.arctan(x).shape, (2, 3))\n\n    def test_arctanh(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.arctanh(x).shape, (2, 3))\n\n    def test_argmax(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.argmax(x).shape, ())\n        self.assertEqual(knp.argmax(x, keepdims=True).shape, (2, 3))\n\n    def test_argmin(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.argmin(x).shape, ())\n        self.assertEqual(knp.argmin(x, keepdims=True).shape, (2, 3))\n\n    def test_argsort(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.argsort(x).shape, (2, 3))\n        self.assertEqual(knp.argsort(x, axis=None).shape, (6,))\n\n    def test_array(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.array(x).shape, (2, 3))\n\n    def test_average(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.average(x).shape, ())\n\n    def test_bitwise_invert(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.bitwise_invert(x).shape, (2, 3))\n\n    # bitwise_not is an alias of bitwise_invert\n\n    def test_broadcast_to(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.broadcast_to(x, (2, 2, 3)).shape, (2, 2, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((3, 3))\n            knp.broadcast_to(x, (2, 2, 3))\n\n    def test_cbrt(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.cbrt(x).shape, (2, 3))\n\n    def test_ceil(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.ceil(x).shape, (2, 3))\n\n    def test_clip(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.clip(x, 
1, 2).shape, (2, 3))\n\n    def test_concatenate(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.concatenate([x, y]).shape, (4, 3))\n        self.assertEqual(knp.concatenate([x, y], axis=1).shape, (2, 6))\n\n        with self.assertRaises(ValueError):\n            self.assertEqual(knp.concatenate([x, y], axis=None).shape, (None,))\n\n    def test_conjugate(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.conjugate(x).shape, (2, 3))\n\n    def test_conj(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.conj(x).shape, (2, 3))\n\n    def test_copy(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.copy(x).shape, (2, 3))\n\n    def test_cos(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.cos(x).shape, (2, 3))\n\n    def test_cosh(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.cosh(x).shape, (2, 3))\n\n    def test_count_nonzero(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.count_nonzero(x).shape, ())\n\n    def test_cumprod(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.cumprod(x).shape, (6,))\n\n    def test_cumsum(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.cumsum(x).shape, (6,))\n\n    def test_deg2rad(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.deg2rad(x).shape, (2, 3))\n\n    def test_diag(self):\n        x = KerasTensor((3,))\n        self.assertEqual(knp.diag(x).shape, (3, 3))\n        self.assertEqual(knp.diag(x, k=3).shape, (6, 6))\n        self.assertEqual(knp.diag(x, k=-2).shape, (5, 5))\n\n        x = KerasTensor((3, 5))\n        self.assertEqual(knp.diag(x).shape, (3,))\n        self.assertEqual(knp.diag(x, k=3).shape, (2,))\n        self.assertEqual(knp.diag(x, k=-2).shape, (1,))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3, 4))\n            knp.diag(x)\n\n    def 
test_diagflat(self):\n        x = KerasTensor((3,))\n        self.assertEqual(knp.diagflat(x).shape, (3, 3))\n        self.assertEqual(knp.diagflat(x, k=1).shape, (4, 4))\n        self.assertEqual(knp.diagflat(x, k=-1).shape, (4, 4))\n\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.diagflat(x).shape, (6, 6))\n        self.assertEqual(knp.diagflat(x, k=1).shape, (7, 7))\n        self.assertEqual(knp.diagflat(x, k=-1).shape, (7, 7))\n\n        x = KerasTensor((None, 3))\n        self.assertEqual(knp.diagflat(x).shape, (None, None))\n\n        x = KerasTensor(())\n        self.assertEqual(knp.diagflat(x).shape, (1, 1))\n\n    def test_diagonal(self):\n        x = KerasTensor((3, 3))\n        self.assertEqual(knp.diagonal(x).shape, (3,))\n        self.assertEqual(knp.diagonal(x, offset=1).shape, (2,))\n\n        x = KerasTensor((3, 5, 5))\n        self.assertEqual(knp.diagonal(x).shape, (5, 3))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((3,))\n            knp.diagonal(x)\n\n    def test_diff(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.diff(x).shape, (2, 2))\n        self.assertEqual(knp.diff(x, n=2).shape, (2, 1))\n        self.assertEqual(knp.diff(x, n=3).shape, (2, 0))\n        self.assertEqual(knp.diff(x, n=4).shape, (2, 0))\n\n        self.assertEqual(knp.diff(x, axis=0).shape, (1, 3))\n        self.assertEqual(knp.diff(x, n=2, axis=0).shape, (0, 3))\n        self.assertEqual(knp.diff(x, n=3, axis=0).shape, (0, 3))\n\n    def test_dot(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((3, 2))\n        z = KerasTensor((4, 3, 2))\n        self.assertEqual(knp.dot(x, y).shape, (2, 2))\n        self.assertEqual(knp.dot(x, 2).shape, (2, 3))\n        self.assertEqual(knp.dot(x, z).shape, (2, 4, 2))\n\n        x = KerasTensor((5,))\n        y = KerasTensor((5,))\n        self.assertEqual(knp.dot(x, y).shape, ())\n\n        with self.assertRaises(ValueError):\n            x = 
KerasTensor((2, 3))\n            y = KerasTensor((2, 3))\n            knp.dot(x, y)\n\n    def test_empty_like(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.empty_like(x).shape, (2, 3))\n        self.assertEqual(knp.empty_like(x).dtype, x.dtype)\n\n    def test_exp(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.exp(x).shape, (2, 3))\n\n    def test_exp2(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.exp2(x).shape, (2, 3))\n\n    def test_expand_dims(self):\n        x = KerasTensor((2, 3, 4))\n        self.assertEqual(knp.expand_dims(x, 0).shape, (1, 2, 3, 4))\n        self.assertEqual(knp.expand_dims(x, 1).shape, (2, 1, 3, 4))\n        self.assertEqual(knp.expand_dims(x, -2).shape, (2, 3, 1, 4))\n\n        # Multiple axes\n        self.assertEqual(knp.expand_dims(x, (1, 2)).shape, (2, 1, 1, 3, 4))\n        self.assertEqual(knp.expand_dims(x, (-1, -2)).shape, (2, 3, 4, 1, 1))\n        self.assertEqual(knp.expand_dims(x, (-1, 1)).shape, (2, 1, 3, 4, 1))\n\n    def test_expm1(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.expm1(x).shape, (2, 3))\n\n    def test_flip(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.flip(x).shape, (2, 3))\n\n    def test_floor(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.floor(x).shape, (2, 3))\n\n    def test_get_item(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.get_item(x, 1).shape, (3,))\n\n        x = KerasTensor((5, 3, 2))\n        self.assertEqual(knp.get_item(x, 3).shape, (3, 2))\n\n        x = KerasTensor(\n            [\n                2,\n            ]\n        )\n        self.assertEqual(knp.get_item(x, 0).shape, ())\n\n    def test_hstack(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.hstack([x, y]).shape, (2, 6))\n\n    def test_imag(self):\n        x = KerasTensor((2, 3))\n        
self.assertEqual(knp.imag(x).shape, (2, 3))\n\n    def test_isfinite(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.isfinite(x).shape, (2, 3))\n\n    def test_isinf(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.isinf(x).shape, (2, 3))\n\n    def test_isnan(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.isnan(x).shape, (2, 3))\n\n    def test_isneginf(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.isneginf(x).shape, (2, 3))\n\n    def test_isposinf(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.isposinf(x).shape, (2, 3))\n\n    def test_isreal(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.isreal(x).shape, (2, 3))\n\n    def test_log(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.log(x).shape, (2, 3))\n\n    def test_log10(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.log10(x).shape, (2, 3))\n\n    def test_log1p(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.log1p(x).shape, (2, 3))\n\n    def test_log2(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.log2(x).shape, (2, 3))\n\n    def test_logaddexp(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.logaddexp(x, x).shape, (2, 3))\n\n    def test_logaddexp2(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.logaddexp2(x, x).shape, (2, 3))\n\n    def test_logical_not(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.logical_not(x).shape, (2, 3))\n\n    def test_max(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.max(x).shape, ())\n\n    def test_median(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.median(x).shape, ())\n\n        x = KerasTensor((2, 3, 3))\n        self.assertEqual(knp.median(x, axis=1).shape, (2, 3))\n        self.assertEqual(knp.median(x, axis=1, keepdims=True).shape, (2, 
1, 3))\n\n    def test_meshgrid(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3, 4))\n        z = KerasTensor((2, 3, 4, 5))\n        self.assertEqual(knp.meshgrid(x, y)[0].shape, (24, 6))\n        self.assertEqual(knp.meshgrid(x, y)[1].shape, (24, 6))\n        self.assertEqual(knp.meshgrid(x, y, indexing=\"ij\")[0].shape, (6, 24))\n        self.assertEqual(\n            knp.meshgrid(x, y, z, indexing=\"ij\")[0].shape, (6, 24, 120)\n        )\n        with self.assertRaises(ValueError):\n            knp.meshgrid(x, y, indexing=\"kk\")\n\n    def test_moveaxis(self):\n        x = KerasTensor((2, 3, 4, 5))\n        self.assertEqual(knp.moveaxis(x, 0, -1).shape, (3, 4, 5, 2))\n        self.assertEqual(knp.moveaxis(x, -1, 0).shape, (5, 2, 3, 4))\n        self.assertEqual(knp.moveaxis(x, [0, 1], [-1, -2]).shape, (4, 5, 3, 2))\n        self.assertEqual(knp.moveaxis(x, [0, 1], [1, 0]).shape, (3, 2, 4, 5))\n        self.assertEqual(knp.moveaxis(x, [0, 1], [-2, -1]).shape, (4, 5, 2, 3))\n\n    def test_nanargmax(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.nanargmax(x).shape, ())\n        self.assertEqual(knp.nanargmax(x, axis=0).shape, (3,))\n        self.assertEqual(knp.nanargmax(x, axis=1).shape, (2,))\n\n    def test_nanargmin(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.nanargmin(x).shape, ())\n        self.assertEqual(knp.nanargmin(x, axis=0).shape, (3,))\n        self.assertEqual(knp.nanargmin(x, axis=1).shape, (2,))\n\n    def test_nancumsum(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.nancumsum(x).shape, (6,))\n        self.assertEqual(knp.nancumsum(x, axis=0).shape, (2, 3))\n        self.assertEqual(knp.nancumsum(x, axis=1).shape, (2, 3))\n        self.assertEqual(knp.nancumsum(x, axis=(1,)).shape, (2, 3))\n        self.assertEqual(knp.nancumsum(x, axis=()).shape, (2, 3))\n\n    def test_nancumprod(self):\n        x = KerasTensor((2, 3))\n        
self.assertEqual(knp.nancumprod(x).shape, (6,))\n        self.assertEqual(knp.nancumprod(x, axis=0).shape, (2, 3))\n        self.assertEqual(knp.nancumprod(x, axis=1).shape, (2, 3))\n        self.assertEqual(knp.nancumprod(x, axis=(1,)).shape, (2, 3))\n        self.assertEqual(knp.nancumprod(x, axis=()).shape, (2, 3))\n\n    def test_nanmax(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.nanmax(x).shape, ())\n        self.assertEqual(knp.nanmax(x, axis=0).shape, (3,))\n        self.assertEqual(knp.nanmax(x, axis=1).shape, (2,))\n        self.assertEqual(knp.nanmax(x, axis=1, keepdims=True).shape, (2, 1))\n\n    def test_nanmean(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.nanmean(x).shape, ())\n        self.assertEqual(knp.nanmean(x, axis=0).shape, (3,))\n        self.assertEqual(knp.nanmean(x, axis=1).shape, (2,))\n        self.assertEqual(knp.nanmean(x, axis=1, keepdims=True).shape, (2, 1))\n\n    def test_nanmin(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.nanmin(x).shape, ())\n        self.assertEqual(knp.nanmin(x, axis=0).shape, (3,))\n        self.assertEqual(knp.nanmin(x, axis=1).shape, (2,))\n        self.assertEqual(knp.nanmin(x, axis=1, keepdims=True).shape, (2, 1))\n\n    def test_nanprod_(self):\n        x = KerasTensor((2, 3))\n\n        self.assertEqual(knp.nanprod(x).shape, ())\n        self.assertEqual(knp.nanprod(x, axis=0).shape, (3,))\n        self.assertEqual(knp.nanprod(x, axis=1).shape, (2,))\n        self.assertEqual(knp.nanprod(x, axis=1, keepdims=True).shape, (2, 1))\n\n    def test_nanstd_(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.nanstd(x).shape, ())\n        self.assertEqual(knp.nanstd(x, axis=0).shape, (3,))\n        self.assertEqual(knp.nanstd(x, axis=1).shape, (2,))\n        self.assertEqual(knp.nanstd(x, axis=1, keepdims=True).shape, (2, 1))\n\n    def test_nansum_(self):\n        x = KerasTensor((2, 3))\n        
self.assertEqual(knp.nansum(x).shape, ())\n        self.assertEqual(knp.nansum(x, axis=0).shape, (3,))\n        self.assertEqual(knp.nansum(x, axis=1).shape, (2,))\n        self.assertEqual(knp.nansum(x, axis=1, keepdims=True).shape, (2, 1))\n\n    def test_nanvar_(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.nanvar(x).shape, ())\n        self.assertEqual(knp.nanvar(x, axis=0).shape, (3,))\n        self.assertEqual(knp.nanvar(x, axis=1).shape, (2,))\n        self.assertEqual(knp.nanvar(x, axis=1, keepdims=True).shape, (2, 1))\n\n    def test_ndim(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.ndim(x).shape, (2,))\n\n    def test_ones_like(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.ones_like(x).shape, (2, 3))\n        self.assertEqual(knp.ones_like(x).dtype, x.dtype)\n\n    def test_zeros_like(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.zeros_like(x).shape, (2, 3))\n        self.assertEqual(knp.zeros_like(x).dtype, x.dtype)\n\n    def test_pad(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.pad(x, 1).shape, (4, 5))\n        self.assertEqual(knp.pad(x, (1, 2)).shape, (5, 6))\n        self.assertEqual(knp.pad(x, ((1, 2), (3, 4))).shape, (5, 10))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            knp.pad(x, ((1, 2), (3, 4), (5, 6)))\n\n    def test_prod(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.prod(x).shape, ())\n        self.assertEqual(knp.prod(x, axis=0).shape, (3,))\n        self.assertEqual(knp.prod(x, axis=1).shape, (2,))\n\n    def test_ptp(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.ptp(x).shape, ())\n        self.assertEqual(knp.ptp(x, axis=0).shape, (3,))\n        self.assertEqual(knp.ptp(x, axis=1).shape, (2,))\n\n    def test_ravel(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.ravel(x).shape, (6,))\n\n    def 
test_unravel_index(self):\n        x = KerasTensor((6,))\n        indices = knp.unravel_index(x, (2, 3))\n        self.assertEqual(len(indices), 2)\n        self.assertEqual(indices[0].shape, (6,))\n        self.assertEqual(indices[1].shape, (6,))\n\n        x = KerasTensor((2, 3))\n        indices = knp.unravel_index(x, (3, 4))\n        self.assertEqual(len(indices), 2)\n        self.assertEqual(indices[0].shape, (2, 3))\n        self.assertEqual(indices[1].shape, (2, 3))\n\n    def test_real(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.real(x).shape, (2, 3))\n\n    def test_reciprocal(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.reciprocal(x).shape, (2, 3))\n\n    def test_repeat(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.repeat(x, 2).shape, (12,))\n        self.assertEqual(knp.repeat(x, [2]).shape, (12,))\n        self.assertEqual(knp.repeat(x, 3, axis=1).shape, (2, 9))\n        self.assertEqual(knp.repeat(x, [1, 2], axis=0).shape, (3, 3))\n\n        with self.assertRaises(ValueError):\n            knp.repeat(x, [1, 1])\n        with self.assertRaises(ValueError):\n            knp.repeat(x, [1, 1, 1], axis=0)\n\n    def test_reshape(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.reshape(x, (3, 2)).shape, (3, 2))\n        self.assertEqual(knp.reshape(x, (3, -1)).shape, (3, 2))\n        self.assertEqual(knp.reshape(x, (6,)).shape, (6,))\n        self.assertEqual(knp.reshape(x, (-1,)).shape, (6,))\n\n    def test_roll(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.roll(x, 1).shape, (2, 3))\n        self.assertEqual(knp.roll(x, 1, axis=1).shape, (2, 3))\n        self.assertEqual(knp.roll(x, 1, axis=0).shape, (2, 3))\n\n    def test_round(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.round(x).shape, (2, 3))\n\n    def test_sign(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.sign(x).shape, (2, 3))\n\n 
   def test_signbit(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.signbit(x).shape, (2, 3))\n\n    def test_sin(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.sin(x).shape, (2, 3))\n\n    def test_sinc(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.sinc(x).shape, (2, 3))\n\n    def test_sinh(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.sinh(x).shape, (2, 3))\n\n    def test_size(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.size(x).shape, ())\n\n    def test_sort(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.sort(x).shape, (2, 3))\n        self.assertEqual(knp.sort(x, axis=1).shape, (2, 3))\n        self.assertEqual(knp.sort(x, axis=0).shape, (2, 3))\n\n    def test_split(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(len(knp.split(x, 2)), 2)\n        self.assertEqual(knp.split(x, 2)[0].shape, (1, 3))\n        self.assertEqual(knp.split(x, 3, axis=1)[0].shape, (2, 1))\n        self.assertEqual(len(knp.split(x, [1, 3], axis=1)), 3)\n        self.assertEqual(knp.split(x, [1, 3], axis=1)[0].shape, (2, 1))\n        self.assertEqual(knp.split(x, [1, 3], axis=1)[1].shape, (2, 2))\n        self.assertEqual(knp.split(x, [1, 3], axis=1)[2].shape, (2, 0))\n\n        with self.assertRaises(ValueError):\n            knp.split(x, 2, axis=1)\n\n    def test_hsplit(self):\n        x = KerasTensor((3, 5))\n\n        splits = knp.hsplit(x, 5)\n        self.assertEqual(len(splits), 5)\n        for split in splits:\n            self.assertEqual(split.shape, (3, 1))\n\n        splits = knp.hsplit(x, [1, 3])\n        self.assertEqual(len(splits), 3)\n        self.assertEqual(splits[0].shape, (3, 1))\n        self.assertEqual(splits[1].shape, (3, 2))\n        self.assertEqual(splits[2].shape, (3, 2))\n\n        # test 1D case\n        x_1d = KerasTensor((10,))\n        splits = knp.hsplit(x_1d, 2)\n        
self.assertEqual(len(splits), 2)\n        for split in splits:\n            self.assertEqual(split.shape, (5,))\n\n        splits = knp.hsplit(x_1d, [2, 5])\n        self.assertEqual(len(splits), 3)\n        self.assertEqual(splits[0].shape, (2,))\n        self.assertEqual(splits[1].shape, (3,))\n        self.assertEqual(splits[2].shape, (5,))\n\n    def test_vsplit(self):\n        x = KerasTensor((5, 3))\n\n        splits = knp.vsplit(x, 5)\n        self.assertEqual(len(splits), 5)\n        for split in splits:\n            self.assertEqual(split.shape, (1, 3))\n\n        splits = knp.vsplit(x, [1, 3])\n        self.assertEqual(len(splits), 3)\n        self.assertEqual(splits[0].shape, (1, 3))\n        self.assertEqual(splits[1].shape, (2, 3))\n        self.assertEqual(splits[2].shape, (2, 3))\n\n    def test_sqrt(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.sqrt(x).shape, (2, 3))\n\n    def test_stack(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.stack([x, y]).shape, (2, 2, 3))\n        self.assertEqual(knp.stack([x, y], axis=-1).shape, (2, 3, 2))\n\n        with self.assertRaises(ValueError):\n            x = KerasTensor((2, 3))\n            y = KerasTensor((3, 3))\n            knp.stack([x, y])\n\n    def test_std(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.std(x).shape, ())\n\n    def test_swapaxes(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.swapaxes(x, 0, 1).shape, (3, 2))\n\n    def test_tan(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.tan(x).shape, (2, 3))\n\n    def test_tanh(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.tanh(x).shape, (2, 3))\n\n    def test_tile(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.tile(x, 2).shape, (2, 6))\n        self.assertEqual(knp.tile(x, [2]).shape, (2, 6))\n        self.assertEqual(knp.tile(x, [1, 2]).shape, (2, 6))\n 
       self.assertEqual(knp.tile(x, [2, 1, 2]).shape, (2, 2, 6))\n\n    def test_trace(self):\n        x = KerasTensor((2, 3, 4, 5))\n        self.assertEqual(knp.trace(x).shape, (4, 5))\n        self.assertEqual(knp.trace(x, axis1=2, axis2=3).shape, (2, 3))\n\n    def test_tril(self):\n        x = KerasTensor((2, 3, 4, 5))\n        self.assertEqual(knp.tril(x).shape, (2, 3, 4, 5))\n        self.assertEqual(knp.tril(x, k=1).shape, (2, 3, 4, 5))\n        self.assertEqual(knp.tril(x, k=-1).shape, (2, 3, 4, 5))\n\n    def test_triu(self):\n        x = KerasTensor((2, 3, 4, 5))\n        self.assertEqual(knp.triu(x).shape, (2, 3, 4, 5))\n        self.assertEqual(knp.triu(x, k=1).shape, (2, 3, 4, 5))\n        self.assertEqual(knp.triu(x, k=-1).shape, (2, 3, 4, 5))\n\n    def test_trunc(self):\n        x = KerasTensor((2, 3, 4, 5))\n        self.assertEqual(knp.trunc(x).shape, (2, 3, 4, 5))\n\n    def test_vstack(self):\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.vstack([x, y]).shape, (4, 3))\n\n    def test_dstack(self):\n        x = KerasTensor((3,))\n        y = KerasTensor((3,))\n        self.assertEqual(knp.dstack([x, y]).shape, (1, 3, 2))\n\n        x = KerasTensor((2, 3))\n        y = KerasTensor((2, 3))\n        self.assertEqual(knp.dstack([x, y]).shape, (2, 3, 2))\n\n        x = KerasTensor((2, 3, 4))\n        y = KerasTensor((2, 3, 5))\n        self.assertEqual(knp.dstack([x, y]).shape, (2, 3, 9))\n\n    def test_argpartition(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.argpartition(x, 3).shape, (2, 3))\n        self.assertEqual(knp.argpartition(x, 1, axis=1).shape, (2, 3))\n\n        with self.assertRaises(ValueError):\n            knp.argpartition(x, (1, 3))\n\n    def test_angle(self):\n        x = KerasTensor((2, 3))\n        self.assertEqual(knp.angle(x).shape, (2, 3))\n\n    def test_view(self):\n        x = knp.array(KerasTensor((2, 3)), dtype=\"int32\")\n        
self.assertEqual(knp.view(x, dtype=\"uint32\").shape, (2, 3))\n        self.assertEqual(knp.view(x, dtype=\"uint32\").dtype, \"uint32\")\n        x = knp.array(KerasTensor((2, 3)), dtype=\"int32\")\n        self.assertEqual(knp.view(x, dtype=\"int16\").shape, (2, 6))\n        self.assertEqual(knp.view(x, dtype=\"int16\").dtype, \"int16\")\n        x = knp.array(KerasTensor((2, 4)), dtype=\"int16\")\n        self.assertEqual(knp.view(x, dtype=\"int32\").shape, (2, 2))\n        self.assertEqual(knp.view(x, dtype=\"int32\").dtype, \"int32\")\n\n    def test_array_split(self):\n        x = KerasTensor((8, 4))\n        splits = knp.array_split(x, 3, axis=0)\n        self.assertEqual(len(splits), 3)\n        self.assertEqual(splits[0].shape, (3, 4))\n        self.assertEqual(splits[1].shape, (3, 4))\n        self.assertEqual(splits[2].shape, (2, 4))\n\n\nclass NumpyTwoInputOpsCorrectnessTest(testing.TestCase):\n    def test_add(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        z = np.array([[[1, 2, 3], [3, 2, 1]]])\n        self.assertAllClose(knp.add(x, y), np.add(x, y))\n        self.assertAllClose(knp.add(x, z), np.add(x, z))\n\n        self.assertAllClose(knp.Add()(x, y), np.add(x, y))\n        self.assertAllClose(knp.Add()(x, z), np.add(x, z))\n\n    def test_heaviside(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [6, 5, 4]])\n        self.assertAllClose(knp.heaviside(x, y), np.heaviside(x, y))\n        self.assertAllClose(knp.Heaviside()(x, y), np.heaviside(x, y))\n\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array(4)\n        self.assertAllClose(knp.heaviside(x, y), np.heaviside(x, y))\n        self.assertAllClose(knp.Heaviside()(x, y), np.heaviside(x, y))\n\n    def test_hypot(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [6, 5, 4]])\n        self.assertAllClose(knp.hypot(x, y), np.hypot(x, y))\n        
self.assertAllClose(knp.Hypot()(x, y), np.hypot(x, y))\n\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array(4)\n        self.assertAllClose(knp.hypot(x, y), np.hypot(x, y))\n        self.assertAllClose(knp.Hypot()(x, y), np.hypot(x, y))\n\n    def test_subtract(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        z = np.array([[[1, 2, 3], [3, 2, 1]]])\n        self.assertAllClose(knp.subtract(x, y), np.subtract(x, y))\n        self.assertAllClose(knp.subtract(x, z), np.subtract(x, z))\n\n        self.assertAllClose(knp.Subtract()(x, y), np.subtract(x, y))\n        self.assertAllClose(knp.Subtract()(x, z), np.subtract(x, z))\n\n    def test_multiply(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        z = np.array([[[1, 2, 3], [3, 2, 1]]])\n        self.assertAllClose(knp.multiply(x, y), np.multiply(x, y))\n        self.assertAllClose(knp.multiply(x, z), np.multiply(x, z))\n\n        self.assertAllClose(knp.Multiply()(x, y), np.multiply(x, y))\n        self.assertAllClose(knp.Multiply()(x, z), np.multiply(x, z))\n\n    def test_matmul(self):\n        x = np.ones([2, 3, 4, 5])\n        y = np.ones([2, 3, 5, 6])\n        z = np.ones([5, 6])\n        p = np.ones([4])\n        self.assertAllClose(knp.matmul(x, y), np.matmul(x, y))\n        self.assertAllClose(knp.matmul(x, z), np.matmul(x, z))\n        self.assertAllClose(knp.matmul(p, x), np.matmul(p, x))\n\n        self.assertAllClose(knp.Matmul()(x, y), np.matmul(x, y))\n        self.assertAllClose(knp.Matmul()(x, z), np.matmul(x, z))\n        self.assertAllClose(knp.Matmul()(p, x), np.matmul(p, x))\n\n    @parameterized.named_parameters(\n        named_product(\n            (\n                {\n                    \"testcase_name\": \"rank2\",\n                    \"x_shape\": (5, 3),\n                    \"y_shape\": (3, 4),\n                },\n                {\n                    
\"testcase_name\": \"rank3\",\n                    \"x_shape\": (2, 5, 3),\n                    \"y_shape\": (2, 3, 4),\n                },\n                {\n                    \"testcase_name\": \"rank4\",\n                    \"x_shape\": (2, 2, 5, 3),\n                    \"y_shape\": (2, 2, 3, 4),\n                },\n            ),\n            dtype=[\"float16\", \"float32\", \"float64\", \"int32\"],\n            x_sparse=[False, True],\n            y_sparse=[False, True],\n        )\n    )\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=\"Backend does not support sparse tensors.\",\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(), reason=\"Segfault on Tensorflow GPU\"\n    )\n    def test_matmul_sparse(self, dtype, x_shape, y_shape, x_sparse, y_sparse):\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            if x_sparse and y_sparse and dtype in (\"float16\", \"int32\"):\n                pytest.skip(\n                    f\"Sparse sparse matmul unsupported for {dtype}\"\n                    \" with TensorFlow backend\"\n                )\n\n            dense_to_sparse = tf.sparse.from_dense\n        elif backend.backend() == \"jax\":\n            import jax.experimental.sparse as jax_sparse\n\n            dense_to_sparse = functools.partial(\n                jax_sparse.BCOO.fromdense, n_batch=len(x_shape) - 2\n            )\n\n        rng = np.random.default_rng(0)\n\n        x = x_np = (4 * rng.standard_normal(x_shape)).astype(dtype)\n        if x_sparse:\n            x_np = np.multiply(x_np, rng.random(x_shape) < 0.7)\n            x = dense_to_sparse(x_np)\n\n        y = y_np = (4 * rng.standard_normal(y_shape)).astype(dtype)\n        if y_sparse:\n            y_np = np.multiply(y_np, rng.random(y_shape) < 0.7)\n            y = dense_to_sparse(y_np)\n\n        atol = 0.1 if dtype == \"float16\" else 1e-4\n        tpu_atol = 1 if dtype == 
\"float16\" else 1e-1\n        self.assertAllClose(\n            knp.matmul(x, y),\n            np.matmul(x_np, y_np),\n            atol=atol,\n            tpu_atol=tpu_atol,\n            tpu_rtol=tpu_atol,\n        )\n        self.assertSparse(knp.matmul(x, y), x_sparse and y_sparse)\n\n    def test_power(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        z = np.array([[[1, 2, 3], [3, 2, 1]]])\n        self.assertAllClose(knp.power(x, y), np.power(x, y))\n        self.assertAllClose(knp.power(x, z), np.power(x, z))\n\n        self.assertAllClose(knp.Power()(x, y), np.power(x, y))\n        self.assertAllClose(knp.Power()(x, z), np.power(x, z))\n\n    def test_divide(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        z = np.array([[[1, 2, 3], [3, 2, 1]]])\n        self.assertAllClose(knp.divide(x, y), np.divide(x, y))\n        self.assertAllClose(knp.divide(x, z), np.divide(x, z))\n\n        self.assertAllClose(knp.Divide()(x, y), np.divide(x, y))\n        self.assertAllClose(knp.Divide()(x, z), np.divide(x, z))\n\n    def test_divide_no_nan(self):\n        x = np.array(\n            [[2, 1, 0], [np.inf, -np.inf, np.nan], [np.inf, -np.inf, np.nan]]\n        )\n        y = np.array([[2, 0, 0], [0, 0, 0], [3, 2, 1]])\n        expected_result = np.array(\n            [[1, 0, 0], [0, 0, 0], [np.inf, -np.inf, np.nan]]\n        )\n        self.assertAllClose(knp.divide_no_nan(x, y), expected_result)\n        self.assertAllClose(knp.DivideNoNan()(x, y), expected_result)\n\n    def test_true_divide(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        z = np.array([[[1, 2, 3], [3, 2, 1]]])\n        self.assertAllClose(knp.true_divide(x, y), np.true_divide(x, y))\n        self.assertAllClose(knp.true_divide(x, z), np.true_divide(x, z))\n\n        self.assertAllClose(knp.TrueDivide()(x, y), np.true_divide(x, 
y))\n        self.assertAllClose(knp.TrueDivide()(x, z), np.true_divide(x, z))\n\n    def test_append(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        z = np.array([[[1, 2, 3], [3, 2, 1]], [[4, 5, 6], [3, 2, 1]]])\n        self.assertAllClose(knp.append(x, y), np.append(x, y))\n        self.assertAllClose(knp.append(x, y, axis=1), np.append(x, y, axis=1))\n        self.assertAllClose(knp.append(x, z), np.append(x, z))\n\n        self.assertAllClose(knp.Append()(x, y), np.append(x, y))\n        self.assertAllClose(knp.Append(axis=1)(x, y), np.append(x, y, axis=1))\n        self.assertAllClose(knp.Append()(x, z), np.append(x, z))\n\n    def test_arctan2(self):\n        x = np.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])\n        y = np.array([[4.0, 5.0, 6.0], [3.0, 2.0, 1.0]])\n        self.assertAllClose(knp.arctan2(x, y), np.arctan2(x, y))\n\n        self.assertAllClose(knp.Arctan2()(x, y), np.arctan2(x, y))\n\n        a = np.array([0.0, 0.0, 0.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0])\n        b = np.array([0.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 0.0, 0.0])\n\n        self.assertAllClose(knp.arctan2(a, b), np.arctan2(a, b))\n        self.assertAllClose(knp.Arctan2()(a, b), np.arctan2(a, b))\n\n        m = np.array([[3, 4], [7, 8]], dtype=np.int8)\n        n = np.array([[1, 2], [3, 4]], dtype=float)\n\n        self.assertAllClose(knp.arctan2(m, n), np.arctan2(m, n))\n        self.assertAllClose(knp.Arctan2()(m, n), np.arctan2(m, n))\n\n        x = np.array([1.0, 2.0, np.nan])\n        y = np.array([3.0, np.nan, 4.0])\n        self.assertAllClose(knp.arctan2(x, y), np.arctan2(x, y))\n        self.assertAllClose(knp.Arctan2()(x, y), np.arctan2(x, y))\n\n    def test_bitwise_and(self):\n        x = np.array([2, 5, 255])\n        y = np.array([3, 14, 16])\n        self.assertAllClose(knp.bitwise_and(x, y), np.bitwise_and(x, y))\n        self.assertAllClose(knp.BitwiseAnd()(x, y), np.bitwise_and(x, y))\n\n    def 
test_bitwise_or(self):\n        x = np.array([2, 5, 255])\n        y = np.array([3, 14, 16])\n        self.assertAllClose(knp.bitwise_or(x, y), np.bitwise_or(x, y))\n        self.assertAllClose(knp.BitwiseOr()(x, y), np.bitwise_or(x, y))\n\n    def test_bitwise_xor(self):\n        x = np.array([2, 5, 255])\n        y = np.array([3, 14, 16])\n        self.assertAllClose(knp.bitwise_xor(x, y), np.bitwise_xor(x, y))\n        self.assertAllClose(knp.BitwiseXor()(x, y), np.bitwise_xor(x, y))\n\n    def test_bitwise_left_shift(self):\n        x = np.array([50, 60, 70])\n        y = np.array([1, 2, 3])\n        self.assertAllClose(knp.bitwise_left_shift(x, y), np.left_shift(x, y))\n        self.assertAllClose(knp.BitwiseLeftShift()(x, y), np.left_shift(x, y))\n\n    # left_shift is same as bitwise_left_shift\n\n    def test_bitwise_right_shift(self):\n        x = np.array([5, 6, 7])\n        y = np.array([1, 2, 3])\n        self.assertAllClose(knp.bitwise_right_shift(x, y), np.right_shift(x, y))\n        self.assertAllClose(knp.BitwiseRightShift()(x, y), np.right_shift(x, y))\n\n    # right_shift is same as bitwise_right_shift\n\n    def test_cross(self):\n        x1 = np.ones([2, 1, 4, 3])\n        x2 = np.ones([2, 1, 4, 2])\n        y1 = np.ones([2, 1, 4, 3])\n        y2 = np.ones([1, 5, 4, 3])\n        y3 = np.ones([1, 5, 4, 2])\n        self.assertAllClose(knp.cross(x1, y1), np.cross(x1, y1))\n        self.assertAllClose(knp.cross(x1, y2), np.cross(x1, y2))\n        if backend.backend() != \"torch\":\n            # API divergence between `torch.cross` and `np.cross`\n            # `torch.cross` only allows dim 3, `np.cross` allows dim 2 or 3\n            self.assertAllClose(knp.cross(x1, y3), np.cross(x1, y3))\n            self.assertAllClose(knp.cross(x2, y3), np.cross(x2, y3))\n\n        self.assertAllClose(knp.Cross()(x1, y1), np.cross(x1, y1))\n        self.assertAllClose(knp.Cross()(x1, y2), np.cross(x1, y2))\n        if backend.backend() != \"torch\":\n          
  # API divergence between `torch.cross` and `np.cross`\n            # `torch.cross` only allows dim 3, `np.cross` allows dim 2 or 3\n            self.assertAllClose(knp.Cross()(x1, y3), np.cross(x1, y3))\n            self.assertAllClose(knp.Cross()(x2, y3), np.cross(x2, y3))\n\n        # Test axis is not None\n        self.assertAllClose(\n            knp.cross(x1, y1, axis=-1), np.cross(x1, y1, axis=-1)\n        )\n        self.assertAllClose(\n            knp.Cross(axis=-1)(x1, y1), np.cross(x1, y1, axis=-1)\n        )\n\n    def test_einsum(self):\n        x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n        y = np.arange(24).reshape([2, 4, 3]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(\"ijk,lkj->il\", x, y),\n            np.einsum(\"ijk,lkj->il\", x, y),\n        )\n        self.assertAllClose(\n            knp.einsum(\"ijk,ikj->i\", x, y),\n            np.einsum(\"ijk,ikj->i\", x, y),\n        )\n        self.assertAllClose(\n            knp.einsum(\"i...,j...k->...ijk\", x, y),\n            np.einsum(\"i..., j...k->...ijk\", x, y),\n        )\n        self.assertAllClose(knp.einsum(\",ijk\", 5, y), np.einsum(\",ijk\", 5, y))\n\n        self.assertAllClose(\n            knp.Einsum(\"ijk,lkj->il\")(x, y),\n            np.einsum(\"ijk,lkj->il\", x, y),\n        )\n        self.assertAllClose(\n            knp.Einsum(\"ijk,ikj->i\")(x, y),\n            np.einsum(\"ijk,ikj->i\", x, y),\n        )\n        self.assertAllClose(\n            knp.Einsum(\"i...,j...k->...ijk\")(x, y),\n            np.einsum(\"i...,j...k->...ijk\", x, y),\n        )\n        self.assertAllClose(knp.Einsum(\",ijk\")(5, y), np.einsum(\",ijk\", 5, y))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=f\"{backend.backend()} doesn't implement custom ops for einsum.\",\n    )\n    def test_einsum_custom_ops_for_tensorflow(self):\n        subscripts = \"a,b->ab\"\n        x = 
np.arange(2).reshape([2]).astype(\"float32\")\n        y = np.arange(3).reshape([3]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"ab,b->a\"\n        x = np.arange(6).reshape([2, 3]).astype(\"float32\")\n        y = np.arange(3).reshape([3]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"ab,bc->ac\"\n        x = np.arange(6).reshape([2, 3]).astype(\"float32\")\n        y = np.arange(12).reshape([3, 4]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"ab,cb->ac\"\n        x = np.arange(6).reshape([2, 3]).astype(\"float32\")\n        y = np.arange(12).reshape([4, 3]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abc,cd->abd\"\n        x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n        y = np.arange(20).reshape([4, 5]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abc,cde->abde\"\n        x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n        y = np.arange(120).reshape([4, 5, 6]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abc,dc->abd\"\n        x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n        y = np.arange(20).reshape([5, 4]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abc,dce->abde\"\n        x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n      
  y = np.arange(120).reshape([5, 4, 6]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abc,dec->abde\"\n        x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n        y = np.arange(120).reshape([5, 6, 4]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abcd,abde->abce\"\n        x = np.arange(120).reshape([2, 3, 4, 5]).astype(\"float32\")\n        y = np.arange(180).reshape([2, 3, 5, 6]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abcd,abed->abce\"\n        x = np.arange(120).reshape([2, 3, 4, 5]).astype(\"float32\")\n        y = np.arange(180).reshape([2, 3, 6, 5]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abcd,acbe->adbe\"\n        x = np.arange(120).reshape([2, 3, 4, 5]).astype(\"float32\")\n        y = np.arange(144).reshape([2, 4, 3, 6]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abcd,adbe->acbe\"\n        x = np.arange(120).reshape([2, 3, 4, 5]).astype(\"float32\")\n        y = np.arange(180).reshape([2, 5, 3, 6]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abcd,aecd->acbe\"\n        x = np.arange(120).reshape([2, 3, 4, 5]).astype(\"float32\")\n        y = np.arange(240).reshape([2, 6, 4, 5]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abcd,aecd->aceb\"\n        x = 
np.arange(120).reshape([2, 3, 4, 5]).astype(\"float32\")\n        y = np.arange(240).reshape([2, 6, 4, 5]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abcd,cde->abe\"\n        x = np.arange(120).reshape([2, 3, 4, 5]).astype(\"float32\")\n        y = np.arange(120).reshape([4, 5, 6]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abcd,ced->abe\"\n        x = np.arange(120).reshape([2, 3, 4, 5]).astype(\"float32\")\n        y = np.arange(120).reshape([4, 6, 5]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abcd,ecd->abe\"\n        x = np.arange(120).reshape([2, 3, 4, 5]).astype(\"float32\")\n        y = np.arange(120).reshape([6, 4, 5]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abcde,aebf->adbcf\"\n        x = np.arange(720).reshape([2, 3, 4, 5, 6]).astype(\"float32\")\n        y = np.arange(252).reshape([2, 6, 3, 7]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n        subscripts = \"abcde,afce->acdbf\"\n        x = np.arange(720).reshape([2, 3, 4, 5, 6]).astype(\"float32\")\n        y = np.arange(336).reshape([2, 7, 4, 6]).astype(\"float32\")\n        self.assertAllClose(\n            knp.einsum(subscripts, x, y), np.einsum(subscripts, x, y)\n        )\n\n    def test_full_like(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.full_like(x, 2), np.full_like(x, 2))\n        self.assertAllClose(\n            knp.full_like(x, 2, dtype=\"float32\"),\n            np.full_like(x, 2, dtype=\"float32\"),\n 
       )\n        self.assertAllClose(\n            knp.full_like(x, np.ones([2, 3])),\n            np.full_like(x, np.ones([2, 3])),\n        )\n\n        self.assertAllClose(knp.FullLike()(x, 2), np.full_like(x, 2))\n        self.assertAllClose(\n            knp.FullLike(dtype=\"float32\")(x, 2),\n            np.full_like(x, 2, dtype=\"float32\"),\n        )\n        self.assertAllClose(\n            knp.FullLike()(x, np.ones([2, 3])),\n            np.full_like(x, np.ones([2, 3])),\n        )\n\n    def test_gcd(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        self.assertAllClose(knp.gcd(x, y), np.gcd(x, y))\n        self.assertAllClose(knp.gcd(x, 2), np.gcd(x, 2))\n        self.assertAllClose(knp.gcd(2, x), np.gcd(2, x))\n\n        self.assertAllClose(knp.Gcd()(x, y), np.gcd(x, y))\n        self.assertAllClose(knp.Gcd()(x, 2), np.gcd(x, 2))\n        self.assertAllClose(knp.Gcd()(2, x), np.gcd(2, x))\n\n    def test_geomspace(self):\n        self.assertAllClose(knp.geomspace(1, 1000, 4), np.geomspace(1, 1000, 4))\n        self.assertAllClose(\n            knp.geomspace(1, 1000, 4, endpoint=False),\n            np.geomspace(1, 1000, 4, endpoint=False),\n        )\n        self.assertAllClose(\n            knp.Geomspace(num=4)(1, 1000), np.geomspace(1, 1000, 4)\n        )\n        self.assertAllClose(\n            knp.Geomspace(num=4, endpoint=False)(1, 1000),\n            np.geomspace(1, 1000, 4, endpoint=False),\n        )\n\n        start = np.array([1.0, 2.0, 3.0])\n        stop = np.array([1000.0, 2000.0, 3000.0])\n\n        self.assertAllClose(\n            knp.geomspace(start, stop, 4),\n            np.geomspace(start, stop, 4),\n            atol=1e-5,\n            rtol=1e-5,\n        )\n        self.assertAllClose(\n            knp.geomspace(start, stop, 4, endpoint=False),\n            np.geomspace(start, stop, 4, endpoint=False),\n            atol=1e-5,\n            rtol=1e-5,\n        )\n        
self.assertAllClose(\n            knp.Geomspace(num=4)(start, stop),\n            np.geomspace(start, stop, 4),\n            atol=1e-5,\n            rtol=1e-5,\n        )\n        self.assertAllClose(\n            knp.Geomspace(num=4, endpoint=False)(start, stop),\n            np.geomspace(start, stop, 4, endpoint=False),\n        )\n\n    def test_greater(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        self.assertAllClose(knp.greater(x, y), np.greater(x, y))\n        self.assertAllClose(knp.greater(x, 2), np.greater(x, 2))\n        self.assertAllClose(knp.greater(2, x), np.greater(2, x))\n\n        self.assertAllClose(knp.Greater()(x, y), np.greater(x, y))\n        self.assertAllClose(knp.Greater()(x, 2), np.greater(x, 2))\n        self.assertAllClose(knp.Greater()(2, x), np.greater(2, x))\n\n    def test_greater_equal(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        self.assertAllClose(\n            knp.greater_equal(x, y),\n            np.greater_equal(x, y),\n        )\n        self.assertAllClose(\n            knp.greater_equal(x, 2),\n            np.greater_equal(x, 2),\n        )\n        self.assertAllClose(\n            knp.greater_equal(2, x),\n            np.greater_equal(2, x),\n        )\n\n        self.assertAllClose(\n            knp.GreaterEqual()(x, y),\n            np.greater_equal(x, y),\n        )\n        self.assertAllClose(\n            knp.GreaterEqual()(x, 2),\n            np.greater_equal(x, 2),\n        )\n        self.assertAllClose(\n            knp.GreaterEqual()(2, x),\n            np.greater_equal(2, x),\n        )\n\n    def test_allclose(self):\n        x = np.array([1], dtype=\"int32\")\n        y = np.array([2], dtype=\"int32\")\n        self.assertAllClose(knp.allclose(x, y, rtol=0.1, atol=1e-8), False)\n\n        x = np.array([1.0], dtype=\"float32\")\n        y = np.array([1.0000001], dtype=\"float32\")\n      
  self.assertAllClose(knp.allclose(x, y, rtol=0.1, atol=1e-8), True)\n\n        # Test with NaNs\n        x_nan = np.array([np.nan, 1.0])\n        y_nan = np.array([np.nan, 1.0])\n        self.assertAllClose(knp.allclose(x_nan, y_nan), False)\n        self.assertAllClose(knp.allclose(x_nan, y_nan, equal_nan=True), True)\n\n    def test_isclose(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        self.assertAllClose(knp.isclose(x, y), np.isclose(x, y))\n        self.assertAllClose(knp.isclose(x, 2), np.isclose(x, 2))\n        self.assertAllClose(knp.isclose(2, x), np.isclose(2, x))\n\n        self.assertAllClose(knp.Isclose()(x, y), np.isclose(x, y))\n        self.assertAllClose(knp.Isclose()(x, 2), np.isclose(x, 2))\n        self.assertAllClose(knp.Isclose()(2, x), np.isclose(2, x))\n\n    def test_isin(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        self.assertAllClose(knp.isin(x, y), np.isin(x, y))\n        self.assertAllClose(knp.isin(x, 2), np.isin(x, 2))\n        self.assertAllClose(knp.isin(2, x), np.isin(2, x))\n\n        self.assertAllClose(\n            knp.isin(x, y, assume_unique=True),\n            np.isin(x, y, assume_unique=True),\n        )\n        self.assertAllClose(\n            knp.isin(x, 2, assume_unique=True),\n            np.isin(x, 2, assume_unique=True),\n        )\n        self.assertAllClose(\n            knp.isin(2, x, assume_unique=True),\n            np.isin(2, x, assume_unique=True),\n        )\n\n        self.assertAllClose(\n            knp.isin(x, y, invert=True), np.isin(x, y, invert=True)\n        )\n        self.assertAllClose(\n            knp.isin(x, 2, invert=True), np.isin(x, 2, invert=True)\n        )\n        self.assertAllClose(\n            knp.isin(2, x, invert=True), np.isin(2, x, invert=True)\n        )\n\n        self.assertAllClose(\n            knp.isin(x, y, assume_unique=True, invert=True),\n        
    np.isin(x, y, assume_unique=True, invert=True),\n        )\n        self.assertAllClose(\n            knp.isin(x, 2, assume_unique=True, invert=True),\n            np.isin(x, 2, assume_unique=True, invert=True),\n        )\n        self.assertAllClose(\n            knp.isin(2, x, assume_unique=True, invert=True),\n            np.isin(2, x, assume_unique=True, invert=True),\n        )\n\n        self.assertAllClose(knp.IsIn()(x, y), np.isin(x, y))\n        self.assertAllClose(knp.IsIn()(x, 2), np.isin(x, 2))\n        self.assertAllClose(knp.IsIn()(2, x), np.isin(2, x))\n\n        self.assertAllClose(\n            knp.IsIn(assume_unique=True)(x, y),\n            np.isin(x, y, assume_unique=True),\n        )\n        self.assertAllClose(\n            knp.IsIn(invert=True)(x, y),\n            np.isin(x, y, invert=True),\n        )\n        self.assertAllClose(\n            knp.IsIn(assume_unique=True, invert=True)(x, y),\n            np.isin(x, y, assume_unique=True, invert=True),\n        )\n\n    def test_kron(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        self.assertAllClose(knp.kron(x, y), np.kron(x, y))\n        self.assertAllClose(knp.Kron()(x, y), np.kron(x, y))\n\n    def test_lcm(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        self.assertAllClose(knp.lcm(x, y), np.lcm(x, y))\n        self.assertAllClose(knp.Lcm()(x, y), np.lcm(x, y))\n\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array(4)\n        self.assertAllClose(knp.lcm(x, y), np.lcm(x, y))\n        self.assertAllClose(knp.Lcm()(x, y), np.lcm(x, y))\n\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([4])\n        self.assertAllClose(knp.lcm(x, y), np.lcm(x, y))\n        self.assertAllClose(knp.Lcm()(x, y), np.lcm(x, y))\n\n    def test_ldexp(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        
self.assertAllClose(knp.ldexp(x, y), np.ldexp(x, y))\n        self.assertAllClose(knp.Ldexp()(x, y), np.ldexp(x, y))\n\n    def test_less(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        self.assertAllClose(knp.less(x, y), np.less(x, y))\n        self.assertAllClose(knp.less(x, 2), np.less(x, 2))\n        self.assertAllClose(knp.less(2, x), np.less(2, x))\n\n        self.assertAllClose(knp.Less()(x, y), np.less(x, y))\n        self.assertAllClose(knp.Less()(x, 2), np.less(x, 2))\n        self.assertAllClose(knp.Less()(2, x), np.less(2, x))\n\n    def test_less_equal(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        self.assertAllClose(knp.less_equal(x, y), np.less_equal(x, y))\n        self.assertAllClose(knp.less_equal(x, 2), np.less_equal(x, 2))\n        self.assertAllClose(knp.less_equal(2, x), np.less_equal(2, x))\n\n        self.assertAllClose(knp.LessEqual()(x, y), np.less_equal(x, y))\n        self.assertAllClose(knp.LessEqual()(x, 2), np.less_equal(x, 2))\n        self.assertAllClose(knp.LessEqual()(2, x), np.less_equal(2, x))\n\n    def test_linspace(self):\n        self.assertAllClose(knp.linspace(0, 10, 5), np.linspace(0, 10, 5))\n        self.assertAllClose(\n            knp.linspace(0, 10, 5, endpoint=False),\n            np.linspace(0, 10, 5, endpoint=False),\n        )\n        self.assertAllClose(knp.Linspace(num=5)(0, 10), np.linspace(0, 10, 5))\n        self.assertAllClose(\n            knp.Linspace(num=5, endpoint=False)(0, 10),\n            np.linspace(0, 10, 5, endpoint=False),\n        )\n        self.assertAllClose(\n            knp.Linspace(num=0, endpoint=False)(0, 10),\n            np.linspace(0, 10, 0, endpoint=False),\n        )\n\n        start = np.zeros([2, 3, 4])\n        stop = np.ones([2, 3, 4])\n        self.assertAllClose(\n            knp.linspace(start, stop, 5, retstep=True)[0],\n            np.linspace(start, 
stop, 5, retstep=True)[0],\n        )\n        self.assertAllClose(\n            knp.linspace(start, stop, 5, endpoint=False, retstep=True)[0],\n            np.linspace(start, stop, 5, endpoint=False, retstep=True)[0],\n        )\n        self.assertAllClose(\n            knp.linspace(\n                start, stop, 5, endpoint=False, retstep=True, dtype=\"int32\"\n            )[0],\n            np.linspace(\n                start, stop, 5, endpoint=False, retstep=True, dtype=\"int32\"\n            )[0],\n        )\n\n        self.assertAllClose(\n            knp.Linspace(5, retstep=True)(start, stop)[0],\n            np.linspace(start, stop, 5, retstep=True)[0],\n        )\n        self.assertAllClose(\n            knp.Linspace(5, endpoint=False, retstep=True)(start, stop)[0],\n            np.linspace(start, stop, 5, endpoint=False, retstep=True)[0],\n        )\n        self.assertAllClose(\n            knp.Linspace(5, endpoint=False, retstep=True, dtype=\"int32\")(\n                start, stop\n            )[0],\n            np.linspace(\n                start, stop, 5, endpoint=False, retstep=True, dtype=\"int32\"\n            )[0],\n        )\n\n        # Test `num` as a tensor\n        # https://github.com/keras-team/keras/issues/19772\n        self.assertAllClose(\n            knp.linspace(0, 10, backend.convert_to_tensor(5)),\n            np.linspace(0, 10, 5),\n        )\n        self.assertAllClose(\n            knp.linspace(0, 10, backend.convert_to_tensor(5), endpoint=False),\n            np.linspace(0, 10, 5, endpoint=False),\n        )\n\n    def test_logical_and(self):\n        x = np.array([[True, False], [True, True]])\n        y = np.array([[False, False], [True, False]])\n        self.assertAllClose(knp.logical_and(x, y), np.logical_and(x, y))\n        self.assertAllClose(knp.logical_and(x, True), np.logical_and(x, True))\n        self.assertAllClose(knp.logical_and(True, x), np.logical_and(True, x))\n\n        
self.assertAllClose(knp.LogicalAnd()(x, y), np.logical_and(x, y))\n        self.assertAllClose(knp.LogicalAnd()(x, True), np.logical_and(x, True))\n        self.assertAllClose(knp.LogicalAnd()(True, x), np.logical_and(True, x))\n\n    def test_logical_or(self):\n        x = np.array([[True, False], [True, True]])\n        y = np.array([[False, False], [True, False]])\n        self.assertAllClose(knp.logical_or(x, y), np.logical_or(x, y))\n        self.assertAllClose(knp.logical_or(x, True), np.logical_or(x, True))\n        self.assertAllClose(knp.logical_or(True, x), np.logical_or(True, x))\n\n        self.assertAllClose(knp.LogicalOr()(x, y), np.logical_or(x, y))\n        self.assertAllClose(knp.LogicalOr()(x, True), np.logical_or(x, True))\n        self.assertAllClose(knp.LogicalOr()(True, x), np.logical_or(True, x))\n\n    def test_logspace(self):\n        self.assertAllClose(\n            knp.logspace(0, 10, 5),\n            np.logspace(0, 10, 5),\n            tpu_atol=1e-4,\n            tpu_rtol=1e-4,\n        )\n        self.assertAllClose(\n            knp.logspace(0, 10, 5, endpoint=False),\n            np.logspace(0, 10, 5, endpoint=False),\n        )\n        self.assertAllClose(\n            knp.Logspace(num=5)(0, 10),\n            np.logspace(0, 10, 5),\n            tpu_atol=1e-4,\n            tpu_rtol=1e-4,\n        )\n        self.assertAllClose(\n            knp.Logspace(num=5, endpoint=False)(0, 10),\n            np.logspace(0, 10, 5, endpoint=False),\n        )\n\n        start = np.zeros([2, 3, 4])\n        stop = np.ones([2, 3, 4])\n\n        self.assertAllClose(\n            knp.logspace(start, stop, 5, base=10),\n            np.logspace(start, stop, 5, base=10),\n        )\n        self.assertAllClose(\n            knp.logspace(start, stop, 5, endpoint=False, base=10),\n            np.logspace(start, stop, 5, endpoint=False, base=10),\n        )\n\n        self.assertAllClose(\n            knp.Logspace(5, base=10)(start, stop),\n            
np.logspace(start, stop, 5, base=10),\n        )\n        self.assertAllClose(\n            knp.Logspace(5, endpoint=False, base=10)(start, stop),\n            np.logspace(start, stop, 5, endpoint=False, base=10),\n        )\n\n    def test_maximum(self):\n        x = np.array([[1, 2], [3, 4]])\n        y = np.array([[5, 6], [7, 8]])\n        self.assertAllClose(knp.maximum(x, y), np.maximum(x, y))\n        self.assertAllClose(knp.maximum(x, 1), np.maximum(x, 1))\n        self.assertAllClose(knp.maximum(1, x), np.maximum(1, x))\n\n        self.assertAllClose(knp.Maximum()(x, y), np.maximum(x, y))\n        self.assertAllClose(knp.Maximum()(x, 1), np.maximum(x, 1))\n        self.assertAllClose(knp.Maximum()(1, x), np.maximum(1, x))\n\n    def test_minimum(self):\n        x = np.array([[1, 2], [3, 4]])\n        y = np.array([[5, 6], [7, 8]])\n        self.assertAllClose(knp.minimum(x, y), np.minimum(x, y))\n        self.assertAllClose(knp.minimum(x, 1), np.minimum(x, 1))\n        self.assertAllClose(knp.minimum(1, x), np.minimum(1, x))\n\n        self.assertAllClose(knp.Minimum()(x, y), np.minimum(x, y))\n        self.assertAllClose(knp.Minimum()(x, 1), np.minimum(x, 1))\n        self.assertAllClose(knp.Minimum()(1, x), np.minimum(1, x))\n\n    def test_mod(self):\n        x = np.array([[1, 2], [3, 4]])\n        y = np.array([[5, 6], [7, 8]])\n        self.assertAllClose(knp.mod(x, y), np.mod(x, y))\n        self.assertAllClose(knp.mod(x, 1), np.mod(x, 1))\n        self.assertAllClose(knp.mod(1, x), np.mod(1, x))\n\n        self.assertAllClose(knp.Mod()(x, y), np.mod(x, y))\n        self.assertAllClose(knp.Mod()(x, 1), np.mod(x, 1))\n        self.assertAllClose(knp.Mod()(1, x), np.mod(1, x))\n\n    def test_fmod(self):\n        x = np.array([[-3, 7], [5, -2]])\n        y = np.array([[2, -3], [3, 4]])\n        self.assertAllClose(knp.fmod(x, y), np.fmod(x, y))\n        self.assertAllClose(knp.fmod(x, 2), np.fmod(x, 2))\n        self.assertAllClose(knp.fmod(1, x), 
np.fmod(1, x))\n\n        self.assertAllClose(knp.Fmod()(x, y), np.fmod(x, y))\n        self.assertAllClose(knp.Fmod()(x, 2), np.fmod(x, 2))\n        self.assertAllClose(knp.Fmod()(1, x), np.fmod(1, x))\n\n    def test_nextafter(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        self.assertAllClose(knp.nextafter(x, y), np.nextafter(x, y))\n        self.assertAllClose(knp.Nextafter()(x, y), np.nextafter(x, y))\n\n    def test_not_equal(self):\n        x = np.array([[1, 2], [3, 4]])\n        y = np.array([[5, 6], [7, 8]])\n        self.assertAllClose(knp.not_equal(x, y), np.not_equal(x, y))\n        self.assertAllClose(knp.not_equal(x, 1), np.not_equal(x, 1))\n        self.assertAllClose(knp.not_equal(1, x), np.not_equal(1, x))\n\n        self.assertAllClose(knp.NotEqual()(x, y), np.not_equal(x, y))\n        self.assertAllClose(knp.NotEqual()(x, 1), np.not_equal(x, 1))\n        self.assertAllClose(knp.NotEqual()(1, x), np.not_equal(1, x))\n\n    def test_outer(self):\n        x = np.array([1, 2, 3])\n        y = np.array([4, 5, 6])\n        self.assertAllClose(knp.outer(x, y), np.outer(x, y))\n        self.assertAllClose(knp.Outer()(x, y), np.outer(x, y))\n\n        x = np.ones([2, 3, 4])\n        y = np.ones([2, 3, 4, 5, 6])\n        self.assertAllClose(knp.outer(x, y), np.outer(x, y))\n        self.assertAllClose(knp.Outer()(x, y), np.outer(x, y))\n\n    def test_quantile(self):\n        x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n\n        # q as scalar\n        q = np.array(0.5, dtype=\"float32\")\n        self.assertAllClose(knp.quantile(x, q), np.quantile(x, q))\n        self.assertAllClose(\n            knp.quantile(x, q, keepdims=True), np.quantile(x, q, keepdims=True)\n        )\n\n        # q as 1D tensor\n        q = np.array([0.5, 1.0], dtype=\"float32\")\n        self.assertAllClose(knp.quantile(x, q), np.quantile(x, q))\n        self.assertAllClose(\n            knp.quantile(x, q, 
keepdims=True), np.quantile(x, q, keepdims=True)\n        )\n        self.assertAllClose(\n            knp.quantile(x, q, axis=1), np.quantile(x, q, axis=1)\n        )\n        self.assertAllClose(\n            knp.quantile(x, q, axis=1, keepdims=True),\n            np.quantile(x, q, axis=1, keepdims=True),\n        )\n\n        # multiple axes\n        self.assertAllClose(\n            knp.quantile(x, q, axis=(1, 2)), np.quantile(x, q, axis=(1, 2))\n        )\n\n        # test all supported methods\n        q = np.array([0.501, 1.0], dtype=\"float32\")\n        for method in [\"linear\", \"lower\", \"higher\", \"midpoint\", \"nearest\"]:\n            self.assertAllClose(\n                knp.quantile(x, q, method=method),\n                np.quantile(x, q, method=method),\n            )\n            self.assertAllClose(\n                knp.quantile(x, q, axis=1, method=method),\n                np.quantile(x, q, axis=1, method=method),\n            )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Only test tensorflow backend\",\n    )\n    def test_quantile_in_tf_function(self):\n        import tensorflow as tf\n\n        x = knp.array([[1, 2, 3], [4, 5, 6]])\n        q = [0.5]\n        expected_output = np.array([[2, 5]])\n\n        @tf.function\n        def run_quantile(x, q, axis):\n            return knp.quantile(x, q, axis=axis)\n\n        result = run_quantile(x, q, axis=1)\n        self.assertAllClose(result, expected_output)\n\n    def test_take(self):\n        x = np.arange(24).reshape([1, 2, 3, 4])\n        indices = np.array([0, 1])\n        self.assertAllClose(knp.take(x, indices), np.take(x, indices))\n        self.assertAllClose(knp.take(x, 0), np.take(x, 0))\n        self.assertAllClose(knp.take(x, 0, axis=1), np.take(x, 0, axis=1))\n\n        self.assertAllClose(knp.Take()(x, indices), np.take(x, indices))\n        self.assertAllClose(knp.Take()(x, 0), np.take(x, 0))\n        
self.assertAllClose(knp.Take(axis=1)(x, 0), np.take(x, 0, axis=1))\n\n        # Test with multi-dimensional indices\n        rng = np.random.default_rng(0)\n        x = rng.standard_normal((2, 3, 4, 5))\n        indices = rng.integers(0, 4, (6, 7))\n        self.assertAllClose(\n            knp.take(x, indices, axis=2), np.take(x, indices, axis=2)\n        )\n\n        # Test with negative axis\n        self.assertAllClose(\n            knp.take(x, indices, axis=-2), np.take(x, indices, axis=-2)\n        )\n\n        # Test with axis=None & x.ndim=2\n        x = np.array(([1, 2], [3, 4]))\n        indices = np.array([2, 3])\n        self.assertAllClose(\n            knp.take(x, indices, axis=None), np.take(x, indices, axis=None)\n        )\n\n        # Test with negative indices\n        x = rng.standard_normal((2, 3, 4, 5))\n        indices = rng.integers(-3, 0, (6, 7))\n        self.assertAllClose(\n            knp.take(x, indices, axis=2), np.take(x, indices, axis=2)\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            [\n                {\"testcase_name\": \"axis_none\", \"axis\": None},\n                {\"testcase_name\": \"axis_0\", \"axis\": 0},\n                {\"testcase_name\": \"axis_1\", \"axis\": 1},\n                {\"testcase_name\": \"axis_minus1\", \"axis\": -1},\n            ],\n            dtype=[\n                \"float16\",\n                \"float32\",\n                \"float64\",\n                \"uint8\",\n                \"int8\",\n                \"int16\",\n                \"int32\",\n            ],\n        )\n    )\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=\"Backend does not support sparse tensors.\",\n    )\n    def test_take_sparse(self, dtype, axis):\n        rng = np.random.default_rng(0)\n        x = (4 * rng.standard_normal((3, 4, 5))).astype(dtype)\n\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n       
     indices = tf.SparseTensor([[0, 0], [1, 2]], [-1, 2], (2, 3))\n        elif backend.backend() == \"jax\":\n            import jax.experimental.sparse as jax_sparse\n\n            indices = jax_sparse.BCOO(([-1, 2], [[0, 0], [1, 2]]), shape=(2, 3))\n\n        self.assertAllClose(\n            knp.take(x, indices, axis=axis),\n            np.take(x, backend.convert_to_numpy(indices), axis=axis),\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            [\n                {\"testcase_name\": \"axis_none\", \"axis\": None},\n                {\"testcase_name\": \"axis_0\", \"axis\": 0},\n                {\"testcase_name\": \"axis_1\", \"axis\": 1},\n                {\"testcase_name\": \"axis_minus1\", \"axis\": -1},\n            ],\n            dtype=[\n                \"float16\",\n                \"float32\",\n                \"float64\",\n                \"uint8\",\n                \"int8\",\n                \"int16\",\n                \"int32\",\n            ],\n        )\n    )\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_RAGGED_TENSORS,\n        reason=\"Backend does not support ragged tensors.\",\n    )\n    def test_take_ragged(self, dtype, axis):\n        rng = np.random.default_rng(0)\n        x = (4 * rng.standard_normal((3, 4, 5))).astype(dtype)\n\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            indices = tf.ragged.constant([[2], [0, -1, 1]])\n            mask = backend.convert_to_numpy(tf.ones_like(indices))\n\n        if axis == 0:\n            mask = np.expand_dims(mask, (2, 3))\n        elif axis == 1:\n            mask = np.expand_dims(mask, (2,))\n\n        self.assertAllClose(\n            knp.take(x, indices, axis=axis),\n            np.take(x, backend.convert_to_numpy(indices), axis=axis)\n            * mask.astype(dtype),\n        )\n\n    def test_take_along_axis(self):\n        x = np.arange(24).reshape([1, 2, 3, 4])\n        indices = np.ones([1, 4, 
1, 1], dtype=np.int32)\n        self.assertAllClose(\n            knp.take_along_axis(x, indices, axis=1),\n            np.take_along_axis(x, indices, axis=1),\n        )\n        self.assertAllClose(\n            knp.TakeAlongAxis(axis=1)(x, indices),\n            np.take_along_axis(x, indices, axis=1),\n        )\n\n        x = np.arange(12).reshape([1, 1, 3, 4])\n        indices = np.ones([1, 4, 1, 1], dtype=np.int32)\n        self.assertAllClose(\n            knp.take_along_axis(x, indices, axis=2),\n            np.take_along_axis(x, indices, axis=2),\n        )\n        self.assertAllClose(\n            knp.TakeAlongAxis(axis=2)(x, indices),\n            np.take_along_axis(x, indices, axis=2),\n        )\n\n        # Test with axis=None\n        x = np.arange(12).reshape([1, 1, 3, 4])\n        indices = np.array([1, 2, 3], dtype=np.int32)\n        self.assertAllClose(\n            knp.take_along_axis(x, indices, axis=None),\n            np.take_along_axis(x, indices, axis=None),\n        )\n        self.assertAllClose(\n            knp.TakeAlongAxis(axis=None)(x, indices),\n            np.take_along_axis(x, indices, axis=None),\n        )\n\n        # Test with negative indices\n        x = np.arange(12).reshape([1, 1, 3, 4])\n        indices = np.full([1, 4, 1, 1], -1, dtype=np.int32)\n        self.assertAllClose(\n            knp.take_along_axis(x, indices, axis=2),\n            np.take_along_axis(x, indices, axis=2),\n        )\n        self.assertAllClose(\n            knp.TakeAlongAxis(axis=2)(x, indices),\n            np.take_along_axis(x, indices, axis=2),\n        )\n\n    def test_tensordot(self):\n        x = np.arange(24).reshape([1, 2, 3, 4]).astype(\"float32\")\n        y = np.arange(24).reshape([3, 4, 1, 2]).astype(\"float32\")\n        self.assertAllClose(\n            knp.tensordot(x, y, axes=2), np.tensordot(x, y, axes=2)\n        )\n        self.assertAllClose(\n            knp.tensordot(x, y, axes=([0, 1], [2, 3])),\n            
np.tensordot(x, y, axes=([0, 1], [2, 3])),\n        )\n        self.assertAllClose(\n            knp.Tensordot(axes=2)(x, y),\n            np.tensordot(x, y, axes=2),\n        )\n        self.assertAllClose(\n            knp.Tensordot(axes=([0, 1], [2, 3]))(x, y),\n            np.tensordot(x, y, axes=([0, 1], [2, 3])),\n        )\n        self.assertAllClose(\n            knp.Tensordot(axes=[0, 2])(x, y),\n            np.tensordot(x, y, axes=[0, 2]),\n        )\n\n    def test_vdot(self):\n        x = np.array([1.0, 2.0, 3.0])\n        y = np.array([4.0, 5.0, 6.0])\n        self.assertAllClose(knp.vdot(x, y), np.vdot(x, y))\n        self.assertAllClose(knp.Vdot()(x, y), np.vdot(x, y))\n\n    def test_inner(self):\n        x = np.array([1.0, 2.0, 3.0])\n        y = np.array([4.0, 5.0, 6.0])\n        self.assertAllClose(knp.inner(x, y), np.inner(x, y))\n        self.assertAllClose(knp.Inner()(x, y), np.inner(x, y))\n\n    def test_where(self):\n        x = np.array([1, 2, 3])\n        y = np.array([4, 5, 6])\n        self.assertAllClose(knp.where(x > 1, x, y), np.where(x > 1, x, y))\n        self.assertAllClose(knp.Where()(x > 1, x, y), np.where(x > 1, x, y))\n        self.assertAllClose(knp.where(x > 1), np.where(x > 1))\n        self.assertAllClose(knp.Where()(x > 1), np.where(x > 1))\n\n        with self.assertRaisesRegex(\n            ValueError, \"`x1` and `x2` either both should be `None`\"\n        ):\n            knp.where(x > 1, x, None)\n\n    def test_digitize(self):\n        x = np.array([0.0, 1.0, 3.0, 1.6])\n        bins = np.array([0.0, 3.0, 4.5, 7.0])\n        self.assertAllClose(knp.digitize(x, bins), np.digitize(x, bins))\n        self.assertAllClose(knp.Digitize()(x, bins), np.digitize(x, bins))\n        self.assertTrue(\n            standardize_dtype(knp.digitize(x, bins).dtype) == \"int32\"\n        )\n        self.assertTrue(\n            standardize_dtype(knp.Digitize()(x, bins).dtype) == \"int32\"\n        )\n\n        x = np.array([0.2, 6.4, 
3.0, 1.6])\n        bins = np.array([0.0, 1.0, 2.5, 4.0, 10.0])\n        self.assertAllClose(knp.digitize(x, bins), np.digitize(x, bins))\n        self.assertAllClose(knp.Digitize()(x, bins), np.digitize(x, bins))\n        self.assertTrue(\n            standardize_dtype(knp.digitize(x, bins).dtype) == \"int32\"\n        )\n        self.assertTrue(\n            standardize_dtype(knp.Digitize()(x, bins).dtype) == \"int32\"\n        )\n\n        x = np.array([1, 4, 10, 15])\n        bins = np.array([4, 10, 14, 15])\n        self.assertAllClose(knp.digitize(x, bins), np.digitize(x, bins))\n        self.assertAllClose(knp.Digitize()(x, bins), np.digitize(x, bins))\n        self.assertTrue(\n            standardize_dtype(knp.digitize(x, bins).dtype) == \"int32\"\n        )\n        self.assertTrue(\n            standardize_dtype(knp.Digitize()(x, bins).dtype) == \"int32\"\n        )\n\n\nclass NumpyOneInputOpsCorrectnessTest(testing.TestCase):\n    def test_mean(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.mean(x), np.mean(x))\n        self.assertAllClose(knp.mean(x, axis=()), np.mean(x, axis=()))\n        self.assertAllClose(knp.mean(x, axis=1), np.mean(x, axis=1))\n        self.assertAllClose(knp.mean(x, axis=(1,)), np.mean(x, axis=(1,)))\n        self.assertAllClose(\n            knp.mean(x, axis=1, keepdims=True),\n            np.mean(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Mean()(x), np.mean(x))\n        self.assertAllClose(knp.Mean(axis=1)(x), np.mean(x, axis=1))\n        self.assertAllClose(\n            knp.Mean(axis=1, keepdims=True)(x),\n            np.mean(x, axis=1, keepdims=True),\n        )\n\n        # test overflow\n        x = np.array([65504, 65504, 65504], dtype=\"float16\")\n        self.assertAllClose(knp.mean(x), np.mean(x))\n\n    def test_array_split(self):\n        x = np.array([[1, 2, 3], [4, 5, 6]])\n\n        # Even split (axis=0)\n        knp_res1 = knp.array_split(x, 
2)\n        np_res1 = np.array_split(x, 2)\n        self.assertEqual(len(knp_res1), len(np_res1))\n        for k_arr, n_arr in zip(knp_res1, np_res1):\n            self.assertAllClose(k_arr, n_arr)\n\n        # Even split (axis=1)\n        knp_res2 = knp.array_split(x, 3, axis=1)\n        np_res2 = np.array_split(x, 3, axis=1)\n        self.assertEqual(len(knp_res2), len(np_res2))\n        for k_arr, n_arr in zip(knp_res2, np_res2):\n            self.assertAllClose(k_arr, n_arr)\n\n        # Uneven split (axis=1) - 3 columns into 2 sections\n        knp_res3 = knp.array_split(x, 2, axis=1)\n        np_res3 = np.array_split(x, 2, axis=1)\n        self.assertEqual(len(knp_res3), len(np_res3))\n        for k_arr, n_arr in zip(knp_res3, np_res3):\n            self.assertAllClose(k_arr, n_arr)\n\n    def test_all(self):\n        x = np.array([[True, False, True], [True, True, True]])\n        self.assertAllClose(knp.all(x), np.all(x))\n        self.assertAllClose(knp.all(x, axis=()), np.all(x, axis=()))\n        self.assertAllClose(knp.all(x, axis=1), np.all(x, axis=1))\n        self.assertAllClose(knp.all(x, axis=(1,)), np.all(x, axis=(1,)))\n        self.assertAllClose(\n            knp.all(x, axis=1, keepdims=True),\n            np.all(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.All()(x), np.all(x))\n        self.assertAllClose(knp.All(axis=1)(x), np.all(x, axis=1))\n        self.assertAllClose(\n            knp.All(axis=1, keepdims=True)(x),\n            np.all(x, axis=1, keepdims=True),\n        )\n\n    def test_any(self):\n        x = np.array([[True, False, True], [True, True, True]])\n        self.assertAllClose(knp.any(x), np.any(x))\n        self.assertAllClose(knp.any(x, axis=()), np.any(x, axis=()))\n        self.assertAllClose(knp.any(x, axis=1), np.any(x, axis=1))\n        self.assertAllClose(knp.any(x, axis=(1,)), np.any(x, axis=(1,)))\n        self.assertAllClose(\n            knp.any(x, axis=1, keepdims=True),\n            
np.any(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Any()(x), np.any(x))\n        self.assertAllClose(knp.Any(axis=1)(x), np.any(x, axis=1))\n        self.assertAllClose(\n            knp.Any(axis=1, keepdims=True)(x),\n            np.any(x, axis=1, keepdims=True),\n        )\n\n    def test_trapezoid(self):\n        y = np.random.random((3, 3, 3))\n        x = np.random.random((3, 3, 3))\n        dx = 2.0\n\n        self.assertAllClose(knp.trapezoid(y), np.trapezoid(y))\n        self.assertAllClose(knp.trapezoid(y, x=x), np.trapezoid(y, x=x))\n        self.assertAllClose(knp.trapezoid(y, dx=dx), np.trapezoid(y, dx=dx))\n        self.assertAllClose(\n            knp.trapezoid(y, x=x, axis=1),\n            np.trapezoid(y, x=x, axis=1),\n        )\n\n    def test_vander(self):\n        x = np.random.random((3,))\n        N = 6\n\n        self.assertAllClose(knp.vander(x), np.vander(x))\n        self.assertAllClose(knp.vander(x, N=N), np.vander(x, N=N))\n        self.assertAllClose(\n            knp.vander(x, N=N, increasing=True),\n            np.vander(x, N=N, increasing=True),\n        )\n\n        self.assertAllClose(knp.Vander().call(x), np.vander(x))\n        self.assertAllClose(knp.Vander(N=N).call(x), np.vander(x, N=N))\n        self.assertAllClose(\n            knp.Vander(N=N, increasing=True).call(x),\n            np.vander(x, N=N, increasing=True),\n        )\n        self.assertAllClose(\n            knp.Vander(N=N, increasing=False).call(x),\n            np.vander(x, N=N, increasing=False),\n        )\n\n    def test_var(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.var(x), np.var(x))\n        self.assertAllClose(knp.var(x, axis=()), np.var(x, axis=()))\n        self.assertAllClose(knp.var(x, axis=1), np.var(x, axis=1))\n        self.assertAllClose(knp.var(x, axis=(1,)), np.var(x, axis=(1,)))\n        self.assertAllClose(\n            knp.var(x, axis=1, keepdims=True),\n            
np.var(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Var()(x), np.var(x))\n        self.assertAllClose(knp.Var(axis=1)(x), np.var(x, axis=1))\n        self.assertAllClose(\n            knp.Var(axis=1, keepdims=True)(x),\n            np.var(x, axis=1, keepdims=True),\n        )\n\n    def test_sum(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.sum(x), np.sum(x))\n        self.assertAllClose(knp.sum(x, axis=()), np.sum(x, axis=()))\n        self.assertAllClose(knp.sum(x, axis=1), np.sum(x, axis=1))\n        self.assertAllClose(knp.sum(x, axis=(1,)), np.sum(x, axis=(1,)))\n        self.assertAllClose(\n            knp.sum(x, axis=1, keepdims=True),\n            np.sum(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Sum()(x), np.sum(x))\n        self.assertAllClose(knp.Sum(axis=1)(x), np.sum(x, axis=1))\n        self.assertAllClose(\n            knp.Sum(axis=1, keepdims=True)(x),\n            np.sum(x, axis=1, keepdims=True),\n        )\n\n    def test_amax(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.amax(x), np.amax(x))\n        self.assertAllClose(knp.amax(x, axis=()), np.amax(x, axis=()))\n        self.assertAllClose(knp.amax(x, axis=1), np.amax(x, axis=1))\n        self.assertAllClose(knp.amax(x, axis=(1,)), np.amax(x, axis=(1,)))\n        self.assertAllClose(\n            knp.amax(x, axis=1, keepdims=True),\n            np.amax(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Amax()(x), np.amax(x))\n        self.assertAllClose(knp.Amax(axis=1)(x), np.amax(x, axis=1))\n        self.assertAllClose(\n            knp.Amax(axis=1, keepdims=True)(x),\n            np.amax(x, axis=1, keepdims=True),\n        )\n\n    def test_amin(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.amin(x), np.amin(x))\n        self.assertAllClose(knp.amin(x, axis=()), np.amin(x, axis=()))\n        
self.assertAllClose(knp.amin(x, axis=1), np.amin(x, axis=1))\n        self.assertAllClose(knp.amin(x, axis=(1,)), np.amin(x, axis=(1,)))\n        self.assertAllClose(\n            knp.amin(x, axis=1, keepdims=True),\n            np.amin(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Amin()(x), np.amin(x))\n        self.assertAllClose(knp.Amin(axis=1)(x), np.amin(x, axis=1))\n        self.assertAllClose(\n            knp.Amin(axis=1, keepdims=True)(x),\n            np.amin(x, axis=1, keepdims=True),\n        )\n\n    def test_square(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.square(x), np.square(x))\n\n        self.assertAllClose(knp.Square()(x), np.square(x))\n\n    def test_negative(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.negative(x), np.negative(x))\n\n        self.assertAllClose(knp.Negative()(x), np.negative(x))\n\n    def test_abs(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.abs(x), np.abs(x))\n\n        self.assertAllClose(knp.Abs()(x), np.abs(x))\n\n    def test_absolute(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.absolute(x), np.absolute(x))\n\n        self.assertAllClose(knp.Absolute()(x), np.absolute(x))\n\n    def test_squeeze(self):\n        x = np.ones([1, 3, 1, 5])\n        self.assertAllClose(knp.squeeze(x), np.squeeze(x))\n        self.assertAllClose(knp.squeeze(x, axis=0), np.squeeze(x, axis=0))\n\n        self.assertAllClose(knp.Squeeze()(x), np.squeeze(x))\n        self.assertAllClose(knp.Squeeze(axis=0)(x), np.squeeze(x, axis=0))\n\n        # Multiple axes\n        x = np.ones([2, 1, 1, 1])\n        self.assertAllClose(knp.squeeze(x, (1, 2)), np.squeeze(x, (1, 2)))\n        self.assertAllClose(knp.squeeze(x, (-1, -2)), np.squeeze(x, (-1, -2)))\n        self.assertAllClose(knp.squeeze(x, (1, 2, 3)), np.squeeze(x, (1, 2, 3)))\n        
self.assertAllClose(knp.squeeze(x, (-1, 1)), np.squeeze(x, (-1, 1)))\n\n        self.assertAllClose(knp.Squeeze((1, 2))(x), np.squeeze(x, (1, 2)))\n        self.assertAllClose(knp.Squeeze((-1, -2))(x), np.squeeze(x, (-1, -2)))\n        self.assertAllClose(knp.Squeeze((1, 2, 3))(x), np.squeeze(x, (1, 2, 3)))\n        self.assertAllClose(knp.Squeeze((-1, 1))(x), np.squeeze(x, (-1, 1)))\n\n    def test_transpose(self):\n        x = np.ones([1, 2, 3, 4, 5])\n        self.assertAllClose(knp.transpose(x), np.transpose(x))\n        self.assertAllClose(\n            knp.transpose(x, axes=(1, 0, 3, 2, 4)),\n            np.transpose(x, axes=(1, 0, 3, 2, 4)),\n        )\n\n        self.assertAllClose(knp.Transpose()(x), np.transpose(x))\n        self.assertAllClose(\n            knp.Transpose(axes=(1, 0, 3, 2, 4))(x),\n            np.transpose(x, axes=(1, 0, 3, 2, 4)),\n        )\n\n    def test_arccos(self):\n        x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])\n        self.assertAllClose(knp.arccos(x), np.arccos(x))\n\n        self.assertAllClose(knp.Arccos()(x), np.arccos(x))\n\n    def test_arccosh(self):\n        x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])\n        self.assertAllClose(knp.arccosh(x), np.arccosh(x))\n\n        self.assertAllClose(knp.Arccosh()(x), np.arccosh(x))\n\n    def test_arcsin(self):\n        x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])\n        self.assertAllClose(knp.arcsin(x), np.arcsin(x))\n\n        self.assertAllClose(knp.Arcsin()(x), np.arcsin(x))\n\n    def test_arcsinh(self):\n        x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])\n        self.assertAllClose(knp.arcsinh(x), np.arcsinh(x))\n\n        self.assertAllClose(knp.Arcsinh()(x), np.arcsinh(x))\n\n    def test_arctan(self):\n        x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])\n        self.assertAllClose(knp.arctan(x), np.arctan(x))\n\n        self.assertAllClose(knp.Arctan()(x), np.arctan(x))\n\n    def test_arctanh(self):\n        x = np.array([[1, 0.5, -0.7], [0.9, 
0.2, -1]])\n        self.assertAllClose(knp.arctanh(x), np.arctanh(x))\n\n        self.assertAllClose(knp.Arctanh()(x), np.arctanh(x))\n\n    def test_argmax(self):\n        x = np.array([[1, 2, 3], [3, 2, 1], [4, 5, 6]])\n        self.assertAllClose(knp.argmax(x), np.argmax(x))\n        self.assertAllClose(knp.argmax(x, axis=1), np.argmax(x, axis=1))\n        self.assertAllClose(\n            knp.argmax(x, axis=1, keepdims=True),\n            np.argmax(x, axis=1, keepdims=True),\n        )\n        self.assertAllClose(\n            knp.argmax(x, keepdims=True), np.argmax(x, keepdims=True)\n        )\n\n        self.assertAllClose(knp.Argmax()(x), np.argmax(x))\n        self.assertAllClose(knp.Argmax(axis=1)(x), np.argmax(x, axis=1))\n        self.assertAllClose(\n            knp.Argmax(keepdims=True)(x), np.argmax(x, keepdims=True)\n        )\n\n    def test_argmin(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.argmin(x), np.argmin(x))\n        self.assertAllClose(knp.argmin(x, axis=1), np.argmin(x, axis=1))\n        self.assertAllClose(\n            knp.argmin(x, keepdims=True), np.argmin(x, keepdims=True)\n        )\n\n        self.assertAllClose(knp.Argmin()(x), np.argmin(x))\n        self.assertAllClose(knp.Argmin(axis=1)(x), np.argmin(x, axis=1))\n        self.assertAllClose(\n            knp.Argmin(keepdims=True)(x), np.argmin(x, keepdims=True)\n        )\n\n    def test_argsort(self):\n        x = np.array([[1, 2, 3], [4, 5, 6]])\n        self.assertAllClose(knp.argsort(x), np.argsort(x))\n        self.assertAllClose(knp.argsort(x, axis=1), np.argsort(x, axis=1))\n        self.assertAllClose(knp.argsort(x, axis=None), np.argsort(x, axis=None))\n\n        self.assertAllClose(knp.Argsort()(x), np.argsort(x))\n        self.assertAllClose(knp.Argsort(axis=1)(x), np.argsort(x, axis=1))\n        self.assertAllClose(knp.Argsort(axis=None)(x), np.argsort(x, 
axis=None))\n\n        x = np.array(1)  # rank == 0\n        self.assertAllClose(knp.argsort(x), np.argsort(x))\n        self.assertAllClose(knp.Argsort()(x), np.argsort(x))\n\n    def test_array(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.array(x), np.array(x))\n        self.assertAllClose(knp.Array()(x), np.array(x))\n        self.assertTrue(backend.is_tensor(knp.array(x)))\n        self.assertTrue(backend.is_tensor(knp.Array()(x)))\n\n        # Check dtype conversion.\n        x = [[1, 0, 1], [1, 1, 0]]\n        output = knp.array(x, dtype=\"int32\")\n        self.assertEqual(standardize_dtype(output.dtype), \"int32\")\n        x = [[1, 0, 1], [1, 1, 0]]\n        output = knp.array(x, dtype=\"float32\")\n        self.assertEqual(standardize_dtype(output.dtype), \"float32\")\n        x = [[1, 0, 1], [1, 1, 0]]\n        output = knp.array(x, dtype=\"bool\")\n        self.assertEqual(standardize_dtype(output.dtype), \"bool\")\n\n    def test_average(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        weights = np.ones([2, 3])\n        weights_1d = np.ones([3])\n        self.assertAllClose(knp.average(x), np.average(x))\n        self.assertAllClose(knp.average(x, axis=()), np.average(x, axis=()))\n        self.assertAllClose(knp.average(x, axis=1), np.average(x, axis=1))\n        self.assertAllClose(knp.average(x, axis=(1,)), np.average(x, axis=(1,)))\n        self.assertAllClose(\n            knp.average(x, axis=1, weights=weights),\n            np.average(x, axis=1, weights=weights),\n        )\n        self.assertAllClose(\n            knp.average(x, axis=1, weights=weights_1d),\n            np.average(x, axis=1, weights=weights_1d),\n        )\n\n        self.assertAllClose(knp.Average()(x), np.average(x))\n        self.assertAllClose(knp.Average(axis=1)(x), np.average(x, axis=1))\n        self.assertAllClose(\n            knp.Average(axis=1)(x, weights=weights),\n            np.average(x, axis=1, 
weights=weights),\n        )\n        self.assertAllClose(\n            knp.Average(axis=1)(x, weights=weights_1d),\n            np.average(x, axis=1, weights=weights_1d),\n        )\n\n    def test_bartlett(self):\n        x = np.random.randint(1, 100 + 1)\n        self.assertAllClose(knp.bartlett(x), np.bartlett(x))\n\n        self.assertAllClose(knp.Bartlett()(x), np.bartlett(x))\n\n    def test_blackman(self):\n        x = np.random.randint(1, 100 + 1)\n        self.assertAllClose(knp.blackman(x), np.blackman(x))\n\n        self.assertAllClose(knp.Blackman()(x), np.blackman(x))\n\n    def test_hamming(self):\n        x = np.random.randint(1, 100 + 1)\n        self.assertAllClose(knp.hamming(x), np.hamming(x))\n\n        self.assertAllClose(knp.Hamming()(x), np.hamming(x))\n\n    def test_hanning(self):\n        x = np.random.randint(1, 100 + 1)\n        self.assertAllClose(knp.hanning(x), np.hanning(x))\n\n        self.assertAllClose(knp.Hanning()(x), np.hanning(x))\n\n    def test_kaiser(self):\n        x = np.random.randint(1, 100 + 1)\n        beta = float(np.random.randint(10, 20 + 1))\n        self.assertAllClose(knp.kaiser(x, beta), np.kaiser(x, beta))\n\n        self.assertAllClose(knp.Kaiser(beta)(x), np.kaiser(x, beta))\n\n    @parameterized.named_parameters(\n        named_product(sparse_input=(False, True), sparse_arg=(False, True))\n    )\n    @pytest.mark.skipif(\n        testing.tensorflow_uses_gpu(),\n        reason=\"bincount not supported on TensorFlow GPU\",\n    )\n    def test_bincount(self, sparse_input, sparse_arg):\n        if (sparse_input or sparse_arg) and not backend.SUPPORTS_SPARSE_TENSORS:\n            pytest.skip(\"Backend does not support sparse tensors\")\n\n        x = x_np = np.array([1, 1, 2, 3, 2, 4, 4, 6])\n        weights = weights_np = np.array([0, 0, 3, 2, 1, 1, 4, 2])\n        if sparse_input:\n            indices = np.array([[1], [3], [5], [7], [9], [11], [13], [15]])\n\n            if backend.backend() == 
\"tensorflow\":\n                import tensorflow as tf\n\n                x = tf.SparseTensor(indices, x, (16,))\n                weights = tf.SparseTensor(indices, weights, (16,))\n            elif backend.backend() == \"jax\":\n                from jax.experimental import sparse as jax_sparse\n\n                x = jax_sparse.BCOO((x, indices), shape=(16,))\n                weights = jax_sparse.BCOO((weights, indices), shape=(16,))\n\n        minlength = 3\n        output = knp.bincount(\n            x, weights=weights, minlength=minlength, sparse=sparse_arg\n        )\n        self.assertAllClose(\n            output, np.bincount(x_np, weights=weights_np, minlength=minlength)\n        )\n        self.assertSparse(output, sparse_input or sparse_arg)\n        output = knp.Bincount(\n            weights=weights, minlength=minlength, sparse=sparse_arg\n        )(x)\n        self.assertAllClose(\n            output, np.bincount(x_np, weights=weights_np, minlength=minlength)\n        )\n        self.assertSparse(output, sparse_input or sparse_arg)\n\n        x = knp.expand_dims(x, 0)\n        weights = knp.expand_dims(weights, 0)\n\n        expected_output = np.array([[0, 0, 4, 2, 5, 0, 2]])\n        output = knp.bincount(\n            x, weights=weights, minlength=minlength, sparse=sparse_arg\n        )\n        self.assertAllClose(output, expected_output)\n        self.assertSparse(output, sparse_input or sparse_arg)\n        output = knp.Bincount(\n            weights=weights, minlength=minlength, sparse=sparse_arg\n        )(x)\n        self.assertAllClose(output, expected_output)\n        self.assertSparse(output, sparse_input or sparse_arg)\n\n        # test with weights=None\n        expected_output = np.array([[0, 2, 2, 1, 2, 0, 1]])\n        output = knp.Bincount(\n            weights=None, minlength=minlength, sparse=sparse_arg\n        )(x)\n        self.assertAllClose(output, expected_output)\n        self.assertSparse(output, sparse_input or 
sparse_arg)\n\n    def test_bitwise_invert(self):\n        x = np.array([2, 5, 255])\n        self.assertAllClose(knp.bitwise_invert(x), np.bitwise_not(x))\n        self.assertAllClose(knp.BitwiseInvert()(x), np.bitwise_not(x))\n\n    # bitwise_not is an alias of bitwise_invert, so it is covered above.\n\n    def test_broadcast_to(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(\n            knp.broadcast_to(x, [2, 2, 3]),\n            np.broadcast_to(x, [2, 2, 3]),\n        )\n\n        self.assertAllClose(\n            knp.BroadcastTo([2, 2, 3])(x),\n            np.broadcast_to(x, [2, 2, 3]),\n        )\n\n    def test_cbrt(self):\n        x = np.array([[-8, -1, 0], [1, 8, 27]], dtype=\"float32\")\n        ref_y = np.sign(x) * np.abs(x) ** (1.0 / 3.0)\n        y = knp.cbrt(x)\n        self.assertEqual(standardize_dtype(y.dtype), \"float32\")\n        self.assertAllClose(y, ref_y)\n\n        y = knp.Cbrt()(x)\n        self.assertEqual(standardize_dtype(y.dtype), \"float32\")\n        self.assertAllClose(y, ref_y)\n\n    def test_ceil(self):\n        x = np.array([[1.2, 2.1, -2.5], [2.4, -11.9, -5.5]])\n        self.assertAllClose(knp.ceil(x), np.ceil(x))\n        self.assertAllClose(knp.Ceil()(x), np.ceil(x))\n\n    def test_clip(self):\n        x = np.array([[1.2, 2.1, 0.5], [2.4, 11.9, 0.5]])\n        self.assertAllClose(knp.clip(x, 1, 2), np.clip(x, 1, 2))\n\n        self.assertAllClose(knp.Clip(0, 1)(x), np.clip(x, 0, 1))\n\n    def test_concatenate(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [6, 5, 4]])\n        z = np.array([[7, 8, 9], [9, 8, 7]])\n        self.assertAllClose(\n            knp.concatenate([x, y], axis=0),\n            np.concatenate([x, y], axis=0),\n        )\n        self.assertAllClose(\n            knp.concatenate([x, y, z], axis=0),\n            
np.concatenate([x, y, z], axis=0),\n        )\n        self.assertAllClose(\n            knp.concatenate([x, y], axis=1),\n            np.concatenate([x, y], axis=1),\n        )\n\n        self.assertAllClose(\n            knp.Concatenate(axis=0)([x, y]),\n            np.concatenate([x, y], axis=0),\n        )\n        self.assertAllClose(\n            knp.Concatenate(axis=0)([x, y, z]),\n            np.concatenate([x, y, z], axis=0),\n        )\n        self.assertAllClose(\n            knp.Concatenate(axis=1)([x, y]),\n            np.concatenate([x, y], axis=1),\n        )\n\n    def test_view(self):\n        x = np.array(1, dtype=\"int16\")\n        result = knp.view(x, dtype=\"float16\")\n        self.assertEqual(backend.standardize_dtype(result.dtype), \"float16\")\n\n        with self.assertRaises(Exception):\n            result = knp.view(x, dtype=\"int8\")\n\n        with self.assertRaises(Exception):\n            result = knp.view(x, dtype=\"int32\")\n\n        x = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=\"int16\")\n        result = knp.view(x, dtype=\"int16\")\n        self.assertEqual(backend.standardize_dtype(result.dtype), \"int16\")\n\n        self.assertEqual(\n            backend.standardize_dtype(knp.view(x, dtype=\"int16\").dtype), \"int16\"\n        )\n        self.assertAllClose(knp.view(x, dtype=\"int16\"), x.view(\"int16\"))\n\n        self.assertEqual(\n            backend.standardize_dtype(knp.view(x, dtype=\"float16\").dtype),\n            \"float16\",\n        )\n        self.assertAllClose(knp.view(x, dtype=\"float16\"), x.view(\"float16\"))\n\n        self.assertEqual(\n            backend.standardize_dtype(knp.view(x, dtype=\"int8\").dtype), \"int8\"\n        )\n        self.assertAllClose(knp.view(x, dtype=\"int8\"), x.view(\"int8\"))\n\n        self.assertEqual(\n            backend.standardize_dtype(knp.view(x, dtype=\"int32\").dtype), \"int32\"\n        )\n        self.assertAllClose(knp.view(x, dtype=\"int32\"), 
x.view(\"int32\"))\n\n    @parameterized.named_parameters(\n        [\n            {\"testcase_name\": \"axis_0\", \"axis\": 0},\n            {\"testcase_name\": \"axis_1\", \"axis\": 1},\n        ]\n    )\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=\"Backend does not support sparse tensors.\",\n    )\n    def test_concatenate_sparse(self, axis):\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            x = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 3))\n            y = tf.SparseTensor([[0, 0], [1, 1]], [4.0, 5.0], (2, 3))\n        elif backend.backend() == \"jax\":\n            import jax.experimental.sparse as jax_sparse\n\n            x = jax_sparse.BCOO(([1.0, 2.0], [[0, 0], [1, 2]]), shape=(2, 3))\n            y = jax_sparse.BCOO(([4.0, 5.0], [[0, 0], [1, 1]]), shape=(2, 3))\n\n        x_np = backend.convert_to_numpy(x)\n        y_np = backend.convert_to_numpy(y)\n        z = np.random.rand(2, 3).astype(\"float32\")\n\n        self.assertAllClose(\n            knp.concatenate([x, z], axis=axis),\n            np.concatenate([x_np, z], axis=axis),\n        )\n        self.assertAllClose(\n            knp.concatenate([z, x], axis=axis),\n            np.concatenate([z, x_np], axis=axis),\n        )\n        self.assertAllClose(\n            knp.concatenate([x, y], axis=axis),\n            np.concatenate([x_np, y_np], axis=axis),\n        )\n\n        self.assertAllClose(\n            knp.Concatenate(axis=axis)([x, z]),\n            np.concatenate([x_np, z], axis=axis),\n        )\n        self.assertAllClose(\n            knp.Concatenate(axis=axis)([z, x]),\n            np.concatenate([z, x_np], axis=axis),\n        )\n        self.assertAllClose(\n            knp.Concatenate(axis=axis)([x, y]),\n            np.concatenate([x_np, y_np], axis=axis),\n        )\n\n        self.assertSparse(knp.concatenate([x, y], axis=axis))\n        
self.assertSparse(knp.Concatenate(axis=axis)([x, y]))\n\n    def test_conjugate(self):\n        x = np.array([[1 + 2j, 2 + 3j], [3 + 4j, 4 + 5j]])\n        self.assertAllClose(knp.conjugate(x), np.conjugate(x))\n        self.assertAllClose(knp.Conjugate()(x), np.conjugate(x))\n\n    def test_conj(self):\n        x = np.array([[1 + 2j, 2 + 3j], [3 + 4j, 4 + 5j]])\n        self.assertAllClose(knp.conj(x), np.conj(x))\n        self.assertAllClose(knp.Conj()(x), np.conj(x))\n\n    def test_copy(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.copy(x), np.copy(x))\n        self.assertAllClose(knp.Copy()(x), np.copy(x))\n\n    def test_corrcoef(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.corrcoef(x), np.corrcoef(x))\n        self.assertAllClose(knp.Corrcoef()(x), np.corrcoef(x))\n\n    def test_cos(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.cos(x), np.cos(x))\n        self.assertAllClose(knp.Cos()(x), np.cos(x))\n\n    def test_cosh(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.cosh(x), np.cosh(x))\n        self.assertAllClose(knp.Cosh()(x), np.cosh(x))\n\n    def test_count_nonzero(self):\n        x = np.array([[0, 2, 3], [3, 2, 0]])\n        self.assertAllClose(knp.count_nonzero(x), np.count_nonzero(x))\n        self.assertAllClose(\n            knp.count_nonzero(x, axis=()), np.count_nonzero(x, axis=())\n        )\n        self.assertAllClose(\n            knp.count_nonzero(x, axis=1),\n            np.count_nonzero(x, axis=1),\n        )\n        self.assertAllClose(\n            knp.count_nonzero(x, axis=(1,)),\n            np.count_nonzero(x, axis=(1,)),\n        )\n\n        self.assertAllClose(\n            knp.CountNonzero()(x),\n            np.count_nonzero(x),\n        )\n        self.assertAllClose(\n            knp.CountNonzero(axis=1)(x),\n            np.count_nonzero(x, axis=1),\n        )\n\n    
@parameterized.product(\n        axis=[None, 0, 1, -1],\n        dtype=[None, \"int32\", \"float32\"],\n    )\n    def test_cumprod(self, axis, dtype):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(\n            knp.cumprod(x, axis=axis, dtype=dtype),\n            np.cumprod(x, axis=axis, dtype=dtype or x.dtype),\n        )\n        self.assertAllClose(\n            knp.Cumprod(axis=axis, dtype=dtype)(x),\n            np.cumprod(x, axis=axis, dtype=dtype or x.dtype),\n        )\n\n    @parameterized.product(\n        axis=[None, 0, 1, -1],\n        dtype=[None, \"int32\", \"float32\"],\n    )\n    def test_cumsum(self, axis, dtype):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(\n            knp.cumsum(x, axis=axis, dtype=dtype),\n            np.cumsum(x, axis=axis, dtype=dtype or x.dtype),\n        )\n        self.assertAllClose(\n            knp.Cumsum(axis=axis, dtype=dtype)(x),\n            np.cumsum(x, axis=axis, dtype=dtype or x.dtype),\n        )\n\n    def test_deg2rad(self):\n        x = np.random.uniform(-360, 360, size=(3, 3))\n        self.assertAllClose(knp.deg2rad(x), np.deg2rad(x))\n        self.assertAllClose(knp.Deg2rad()(x), np.deg2rad(x))\n\n    def test_diag(self):\n        x = np.array([1, 2, 3])\n        self.assertAllClose(knp.diag(x), np.diag(x))\n        self.assertAllClose(knp.diag(x, k=1), np.diag(x, k=1))\n        self.assertAllClose(knp.diag(x, k=-1), np.diag(x, k=-1))\n\n        self.assertAllClose(knp.Diag()(x), np.diag(x))\n        self.assertAllClose(knp.Diag(k=1)(x), np.diag(x, k=1))\n        self.assertAllClose(knp.Diag(k=-1)(x), np.diag(x, k=-1))\n\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.diag(x), np.diag(x))\n        self.assertAllClose(knp.diag(x, k=1), np.diag(x, k=1))\n        self.assertAllClose(knp.diag(x, k=-1), np.diag(x, k=-1))\n\n        self.assertAllClose(knp.Diag()(x), np.diag(x))\n        
self.assertAllClose(knp.Diag(k=1)(x), np.diag(x, k=1))\n        self.assertAllClose(knp.Diag(k=-1)(x), np.diag(x, k=-1))\n\n    def test_diagflat(self):\n        x = np.array([1, 2, 3])\n        self.assertAllClose(knp.diagflat(x), np.diagflat(x))\n        self.assertAllClose(knp.diagflat(x, k=1), np.diagflat(x, k=1))\n        self.assertAllClose(knp.diagflat(x, k=-1), np.diagflat(x, k=-1))\n\n        x = np.array([[1, 2], [3, 4]])\n        self.assertAllClose(knp.diagflat(x), np.diagflat(x))\n        self.assertAllClose(knp.diagflat(x, k=1), np.diagflat(x, k=1))\n        self.assertAllClose(knp.diagflat(x, k=-1), np.diagflat(x, k=-1))\n\n        x = np.array([1, 2, 3, 4])\n        self.assertAllClose(knp.diagflat(x), np.diagflat(x))\n        self.assertAllClose(knp.diagflat(x, k=2), np.diagflat(x, k=2))\n        self.assertAllClose(knp.diagflat(x, k=-2), np.diagflat(x, k=-2))\n\n        x_float = np.array([1.1, 2.2, 3.3])\n        self.assertAllClose(knp.diagflat(x_float), np.diagflat(x_float))\n\n        x = np.array([1, 2, 3])\n        self.assertAllClose(knp.Diagflat()(x), np.diagflat(x))\n        self.assertAllClose(knp.Diagflat(k=1)(x), np.diagflat(x, k=1))\n        self.assertAllClose(knp.Diagflat(k=-1)(x), np.diagflat(x, k=-1))\n\n    def test_diagonal(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.diagonal(x), np.diagonal(x))\n        self.assertAllClose(\n            knp.diagonal(x, offset=1),\n            np.diagonal(x, offset=1),\n        )\n        self.assertAllClose(\n            knp.diagonal(x, offset=-1), np.diagonal(x, offset=-1)\n        )\n\n        self.assertAllClose(knp.Diagonal()(x), np.diagonal(x))\n        self.assertAllClose(knp.Diagonal(offset=1)(x), np.diagonal(x, offset=1))\n        self.assertAllClose(\n            knp.Diagonal(offset=-1)(x), np.diagonal(x, offset=-1)\n        )\n\n        x = np.ones([2, 3, 4, 5])\n        self.assertAllClose(knp.diagonal(x), np.diagonal(x))\n        
self.assertAllClose(\n            knp.diagonal(x, offset=1, axis1=2, axis2=3),\n            np.diagonal(x, offset=1, axis1=2, axis2=3),\n        )\n        self.assertAllClose(\n            knp.diagonal(x, offset=-1, axis1=2, axis2=3),\n            np.diagonal(x, offset=-1, axis1=2, axis2=3),\n        )\n\n    def test_diff(self):\n        x = np.array([1, 2, 4, 7, 0])\n        self.assertAllClose(knp.diff(x), np.diff(x))\n        self.assertAllClose(knp.diff(x, n=2), np.diff(x, n=2))\n        self.assertAllClose(knp.diff(x, n=3), np.diff(x, n=3))\n\n        x = np.array([[1, 3, 6, 10], [0, 5, 6, 8]])\n        self.assertAllClose(knp.diff(x), np.diff(x))\n        self.assertAllClose(knp.diff(x, axis=0), np.diff(x, axis=0))\n        self.assertAllClose(knp.diff(x, n=2, axis=0), np.diff(x, n=2, axis=0))\n        self.assertAllClose(knp.diff(x, n=2, axis=1), np.diff(x, n=2, axis=1))\n\n        # Test n=0\n        x = np.array([1, 2, 4, 7, 0])\n        self.assertAllClose(knp.diff(x, n=0), np.diff(x, n=0))\n\n    def test_dot(self):\n        x = np.arange(24).reshape([2, 3, 4]).astype(\"float32\")\n        y = np.arange(12).reshape([4, 3]).astype(\"float32\")\n        z = np.arange(4).astype(\"float32\")\n        self.assertAllClose(knp.dot(x, y), np.dot(x, y))\n        self.assertAllClose(knp.dot(x, z), np.dot(x, z))\n        self.assertAllClose(knp.dot(x, 2), np.dot(x, 2))\n\n        self.assertAllClose(knp.Dot()(x, y), np.dot(x, y))\n        self.assertAllClose(knp.Dot()(x, z), np.dot(x, z))\n        self.assertAllClose(knp.Dot()(x, 2), np.dot(x, 2))\n\n    def test_exp(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.exp(x), np.exp(x))\n        self.assertAllClose(knp.Exp()(x), np.exp(x))\n\n    def test_exp2(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.exp2(x), np.exp2(x))\n        self.assertAllClose(knp.Exp2()(x), np.exp2(x))\n\n    def test_expand_dims(self):\n        x = np.ones([2, 
3, 4])\n        self.assertAllClose(knp.expand_dims(x, 0), np.expand_dims(x, 0))\n        self.assertAllClose(knp.expand_dims(x, 1), np.expand_dims(x, 1))\n        self.assertAllClose(knp.expand_dims(x, -2), np.expand_dims(x, -2))\n\n        self.assertAllClose(knp.ExpandDims(0)(x), np.expand_dims(x, 0))\n        self.assertAllClose(knp.ExpandDims(1)(x), np.expand_dims(x, 1))\n        self.assertAllClose(knp.ExpandDims(-2)(x), np.expand_dims(x, -2))\n\n        # Multiple axes\n        self.assertAllClose(\n            knp.expand_dims(x, (1, 2)), np.expand_dims(x, (1, 2))\n        )\n        self.assertAllClose(\n            knp.expand_dims(x, (-1, -2)), np.expand_dims(x, (-1, -2))\n        )\n        self.assertAllClose(\n            knp.expand_dims(x, (-1, 1)), np.expand_dims(x, (-1, 1))\n        )\n\n        self.assertAllClose(\n            knp.ExpandDims((1, 2))(x), np.expand_dims(x, (1, 2))\n        )\n        self.assertAllClose(\n            knp.ExpandDims((-1, -2))(x), np.expand_dims(x, (-1, -2))\n        )\n        self.assertAllClose(\n            knp.ExpandDims((-1, 1))(x), np.expand_dims(x, (-1, 1))\n        )\n\n    def test_expm1(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.expm1(x), np.expm1(x))\n        self.assertAllClose(knp.Expm1()(x), np.expm1(x))\n\n    def test_flip(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.flip(x), np.flip(x))\n        self.assertAllClose(knp.flip(x, 0), np.flip(x, 0))\n        self.assertAllClose(knp.flip(x, 1), np.flip(x, 1))\n\n        self.assertAllClose(knp.Flip()(x), np.flip(x))\n        self.assertAllClose(knp.Flip(0)(x), np.flip(x, 0))\n        self.assertAllClose(knp.Flip(1)(x), np.flip(x, 1))\n\n    def test_floor(self):\n        x = np.array([[1.1, 2.2, -3.3], [3.3, 2.2, -1.1]])\n        self.assertAllClose(knp.floor(x), np.floor(x))\n        self.assertAllClose(knp.Floor()(x), np.floor(x))\n\n    def test_hstack(self):\n        x 
= np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [6, 5, 4]])\n        self.assertAllClose(knp.hstack([x, y]), np.hstack([x, y]))\n        self.assertAllClose(knp.Hstack()([x, y]), np.hstack([x, y]))\n\n        x = np.ones([2, 3, 4])\n        y = np.ones([2, 5, 4])\n        self.assertAllClose(knp.hstack([x, y]), np.hstack([x, y]))\n        self.assertAllClose(knp.Hstack()([x, y]), np.hstack([x, y]))\n\n    def test_imag(self):\n        x = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [3 + 3j, 2 + 2j, 1 + 1j]])\n        self.assertAllClose(knp.imag(x), np.imag(x))\n        self.assertAllClose(knp.Imag()(x), np.imag(x))\n\n    def test_isfinite(self):\n        x = np.array([[1, 2, np.inf], [np.nan, np.nan, np.nan]])\n        self.assertAllClose(knp.isfinite(x), np.isfinite(x))\n        self.assertAllClose(knp.Isfinite()(x), np.isfinite(x))\n\n    def test_isinf(self):\n        x = np.array([[1, 2, np.inf], [np.nan, np.nan, np.nan]])\n        self.assertAllClose(knp.isinf(x), np.isinf(x))\n        self.assertAllClose(knp.Isinf()(x), np.isinf(x))\n\n    def test_isnan(self):\n        x = np.array([[1, 2, np.inf], [np.nan, np.nan, np.nan]])\n        self.assertAllClose(knp.isnan(x), np.isnan(x))\n        self.assertAllClose(knp.Isnan()(x), np.isnan(x))\n\n    def test_isneginf(self):\n        x = np.array(\n            [[1, 2, np.inf, -np.inf], [np.nan, np.nan, np.nan, np.nan]]\n        )\n        self.assertAllClose(knp.isneginf(x), np.isneginf(x))\n        self.assertAllClose(knp.Isneginf()(x), np.isneginf(x))\n\n    def test_isposinf(self):\n        x = np.array(\n            [[1, 2, np.inf, -np.inf], [np.nan, np.nan, np.nan, np.nan]]\n        )\n        self.assertAllClose(knp.isposinf(x), np.isposinf(x))\n        self.assertAllClose(knp.Isposinf()(x), np.isposinf(x))\n\n    def test_isreal(self):\n        x = np.array([1 + 1j, 1 + 0j, 4.5, 3, 2, 2j], dtype=complex)\n        self.assertAllClose(knp.isreal(x), np.isreal(x))\n        
self.assertAllClose(knp.Isreal()(x), np.isreal(x))\n\n        x = np.array([1.0, 2.0, 3.0])\n        self.assertAllClose(knp.isreal(x), np.isreal(x))\n        self.assertAllClose(knp.Isreal()(x), np.isreal(x))\n\n    def test_log(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.log(x), np.log(x))\n        self.assertAllClose(knp.Log()(x), np.log(x))\n\n    def test_log10(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.log10(x), np.log10(x))\n        self.assertAllClose(knp.Log10()(x), np.log10(x))\n\n    def test_log1p(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.log1p(x), np.log1p(x))\n        self.assertAllClose(knp.Log1p()(x), np.log1p(x))\n\n    def test_log2(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.log2(x), np.log2(x))\n        self.assertAllClose(knp.Log2()(x), np.log2(x))\n\n    def test_logaddexp(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.logaddexp(x, y), np.logaddexp(x, y))\n        self.assertAllClose(knp.Logaddexp()(x, y), np.logaddexp(x, y))\n\n    def test_logaddexp2(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.logaddexp2(x, y), np.logaddexp2(x, y))\n        self.assertAllClose(knp.Logaddexp2()(x, y), np.logaddexp2(x, y))\n\n    def test_logical_not(self):\n        x = np.array([[True, False], [False, True]])\n        self.assertAllClose(knp.logical_not(x), np.logical_not(x))\n        self.assertAllClose(knp.LogicalNot()(x), np.logical_not(x))\n\n    def test_max(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.max(x), np.max(x))\n        self.assertAllClose(knp.Max()(x), np.max(x))\n\n        self.assertAllClose(knp.max(x, 0), np.max(x, 0))\n        self.assertAllClose(knp.Max(0)(x), 
np.max(x, 0))\n\n        self.assertAllClose(knp.max(x, 1), np.max(x, 1))\n        self.assertAllClose(knp.Max(1)(x), np.max(x, 1))\n\n        # test max with initial\n        self.assertAllClose(knp.max(x, initial=4), 4)\n\n        # test empty tensor\n        x = np.array([[]])\n        self.assertAllClose(knp.max(x, initial=1), np.max(x, initial=1))\n        self.assertAllClose(\n            knp.max(x, initial=1, keepdims=True),\n            np.max(x, initial=1, keepdims=True),\n        )\n\n    def test_min(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.min(x), np.min(x))\n        self.assertAllClose(knp.Min()(x), np.min(x))\n\n        self.assertAllClose(knp.min(x, axis=(0, 1)), np.min(x, (0, 1)))\n        self.assertAllClose(knp.Min((0, 1))(x), np.min(x, (0, 1)))\n\n        self.assertAllClose(knp.min(x, axis=()), np.min(x, axis=()))\n        self.assertAllClose(knp.Min(())(x), np.min(x, axis=()))\n\n        self.assertAllClose(knp.min(x, 0), np.min(x, 0))\n        self.assertAllClose(knp.Min(0)(x), np.min(x, 0))\n\n        self.assertAllClose(knp.min(x, 1), np.min(x, 1))\n        self.assertAllClose(knp.Min(1)(x), np.min(x, 1))\n\n        # test min with initial\n        self.assertAllClose(knp.min(x, initial=0), 0)\n\n        # test empty tensor\n        x = np.array([[]])\n        self.assertAllClose(knp.min(x, initial=1), np.min(x, initial=1))\n        self.assertAllClose(\n            knp.min(x, initial=1, keepdims=True),\n            np.min(x, initial=1, keepdims=True),\n        )\n\n    def test_median(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]]).astype(\"float32\")\n        self.assertAllClose(knp.median(x), np.median(x))\n        self.assertAllClose(\n            knp.median(x, keepdims=True), np.median(x, keepdims=True)\n        )\n        self.assertAllClose(knp.median(x, axis=1), np.median(x, axis=1))\n        self.assertAllClose(knp.median(x, axis=(1,)), np.median(x, axis=(1,)))\n        
self.assertAllClose(\n            knp.median(x, axis=1, keepdims=True),\n            np.median(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Median()(x), np.median(x))\n        self.assertAllClose(knp.Median(axis=1)(x), np.median(x, axis=1))\n        self.assertAllClose(\n            knp.Median(axis=1, keepdims=True)(x),\n            np.median(x, axis=1, keepdims=True),\n        )\n\n    def test_meshgrid(self):\n        x = np.array([1, 2, 3])\n        y = np.array([4, 5, 6])\n        z = np.array([7, 8, 9])\n        self.assertAllClose(knp.meshgrid(x, y), np.meshgrid(x, y))\n        self.assertAllClose(knp.meshgrid(x, z), np.meshgrid(x, z))\n        self.assertAllClose(\n            knp.meshgrid(x, y, z, indexing=\"ij\"),\n            np.meshgrid(x, y, z, indexing=\"ij\"),\n        )\n        self.assertAllClose(knp.Meshgrid()(x, y), np.meshgrid(x, y))\n        self.assertAllClose(knp.Meshgrid()(x, z), np.meshgrid(x, z))\n        self.assertAllClose(\n            knp.Meshgrid(indexing=\"ij\")(x, y, z),\n            np.meshgrid(x, y, z, indexing=\"ij\"),\n        )\n\n        if backend.backend() == \"tensorflow\":\n            # Only exercise non-1D inputs on the TensorFlow backend;\n            # `jax.numpy.meshgrid` now requires 1D arguments.\n            x = np.ones([1, 2, 3])\n            y = np.ones([4, 5, 6, 6])\n            z = np.ones([7, 8])\n            self.assertAllClose(knp.meshgrid(x, y), np.meshgrid(x, y))\n            self.assertAllClose(knp.meshgrid(x, z), np.meshgrid(x, z))\n            self.assertAllClose(\n                knp.meshgrid(x, y, z, indexing=\"ij\"),\n                np.meshgrid(x, y, z, indexing=\"ij\"),\n            )\n            self.assertAllClose(knp.Meshgrid()(x, y), np.meshgrid(x, y))\n            self.assertAllClose(knp.Meshgrid()(x, z), np.meshgrid(x, z))\n            self.assertAllClose(\n                knp.Meshgrid(indexing=\"ij\")(x, y, z),\n                np.meshgrid(x, y, z, indexing=\"ij\"),\n            )\n\n    def test_moveaxis(self):\n        x = 
np.array([[[0, 1], [2, 3]], [[4, 5], [6, 7]]])\n        self.assertAllClose(knp.moveaxis(x, 0, -1), np.moveaxis(x, 0, -1))\n        self.assertAllClose(knp.moveaxis(x, -1, 0), np.moveaxis(x, -1, 0))\n        self.assertAllClose(\n            knp.moveaxis(x, (0, 1), (1, 0)),\n            np.moveaxis(x, (0, 1), (1, 0)),\n        )\n        self.assertAllClose(\n            knp.moveaxis(x, [0, 1, 2], [2, 0, 1]),\n            np.moveaxis(x, [0, 1, 2], [2, 0, 1]),\n        )\n        self.assertAllClose(knp.Moveaxis(-1, 0)(x), np.moveaxis(x, -1, 0))\n        self.assertAllClose(\n            knp.Moveaxis((0, 1), (1, 0))(x),\n            np.moveaxis(x, (0, 1), (1, 0)),\n        )\n\n        self.assertAllClose(\n            knp.Moveaxis([0, 1, 2], [2, 0, 1])(x),\n            np.moveaxis(x, [0, 1, 2], [2, 0, 1]),\n        )\n\n    def test_ndim(self):\n        x = np.array([1, 2, 3])\n        self.assertEqual(knp.ndim(x), np.ndim(x))\n        self.assertEqual(knp.Ndim()(x), np.ndim(x))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Only test tensorflow backend\",\n    )\n    def test_ndim_tf_ragged(self):\n        import tensorflow as tf\n\n        # Rank 2\n        ragged_2d = tf.ragged.constant([[1, 2, 3], [4]])\n        self.assertEqual(knp.ndim(ragged_2d), 2)\n        self.assertEqual(knp.Ndim()(ragged_2d), 2)\n        # Rank 0\n        ragged_scalar = tf.ragged.constant(1)\n        self.assertEqual(knp.ndim(ragged_scalar), 0)\n        self.assertEqual(knp.Ndim()(ragged_scalar), 0)\n        # Rank 3\n        ragged_3d = tf.ragged.constant([[[1], [2, 3]], [[4, 5, 6]]])\n        self.assertEqual(knp.ndim(ragged_3d), 3)\n        self.assertEqual(knp.Ndim()(ragged_3d), 3)\n\n    def test_nonzero(self):\n        x = np.array([[0, 0, 3], [3, 0, 0]])\n        self.assertAllClose(knp.nonzero(x), np.nonzero(x))\n        self.assertAllClose(knp.Nonzero()(x), np.nonzero(x))\n\n    def test_ones_like(self):\n        x = np.array([[1, 
2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.ones_like(x), np.ones_like(x))\n        self.assertAllClose(knp.OnesLike()(x), np.ones_like(x))\n\n    @parameterized.named_parameters(\n        named_product(\n            dtype=[\n                \"float16\",\n                \"float32\",\n                \"float64\",\n                \"uint8\",\n                \"int8\",\n                \"int16\",\n                \"int32\",\n            ],\n            mode=[\"constant\", \"reflect\", \"symmetric\"],\n            constant_values=[None, 0, 2],\n        )\n    )\n    def test_pad(self, dtype, mode, constant_values):\n        # 2D\n        x = np.ones([2, 3], dtype=dtype)\n        pad_width = ((1, 1), (1, 1))\n\n        if mode != \"constant\":\n            if constant_values is not None:\n                with self.assertRaisesRegex(\n                    ValueError,\n                    \"Argument `constant_values` can only be \"\n                    \"provided when `mode == 'constant'`\",\n                ):\n                    knp.pad(\n                        x, pad_width, mode=mode, constant_values=constant_values\n                    )\n                return\n            # constant_values is None\n            kwargs = {}\n        else:\n            # mode is constant\n            kwargs = {\"constant_values\": constant_values or 0}\n\n        self.assertAllClose(\n            knp.pad(x, pad_width, mode=mode, constant_values=constant_values),\n            np.pad(x, pad_width, mode=mode, **kwargs),\n        )\n        self.assertAllClose(\n            knp.Pad(pad_width, mode=mode)(x, constant_values=constant_values),\n            np.pad(x, pad_width, mode=mode, **kwargs),\n        )\n\n        # 5D (pad last 3D)\n        x = np.ones([2, 3, 4, 5, 6], dtype=dtype)\n        pad_width = ((0, 0), (0, 0), (2, 3), (1, 1), (1, 1))\n        self.assertAllClose(\n            knp.pad(x, pad_width, mode=mode, constant_values=constant_values),\n            np.pad(x, 
pad_width, mode=mode, **kwargs),\n        )\n        self.assertAllClose(\n            knp.Pad(pad_width, mode=mode)(x, constant_values=constant_values),\n            np.pad(x, pad_width, mode=mode, **kwargs),\n        )\n\n        # 5D (pad arbitrary dimensions)\n        if backend.backend() == \"torch\" and mode != \"constant\":\n            self.skipTest(\n                \"reflect and symmetric padding for arbitrary dimensions \"\n                \"are not supported by torch\"\n            )\n        x = np.ones([2, 3, 4, 5, 6], dtype=dtype)\n        pad_width = ((1, 1), (2, 1), (3, 2), (4, 3), (5, 4))\n        self.assertAllClose(\n            knp.pad(x, pad_width, mode=mode, constant_values=constant_values),\n            np.pad(x, pad_width, mode=mode, **kwargs),\n        )\n        self.assertAllClose(\n            knp.Pad(pad_width, mode=mode)(x, constant_values=constant_values),\n            np.pad(x, pad_width, mode=mode, **kwargs),\n        )\n\n    def test_prod(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.prod(x), np.prod(x))\n        self.assertAllClose(knp.prod(x, axis=()), np.prod(x, axis=()))\n        self.assertAllClose(knp.prod(x, axis=1), np.prod(x, axis=1))\n        self.assertAllClose(knp.prod(x, axis=(1,)), np.prod(x, axis=(1,)))\n        self.assertAllClose(\n            knp.prod(x, axis=1, keepdims=True),\n            np.prod(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Prod()(x), np.prod(x))\n        self.assertAllClose(knp.Prod(axis=1)(x), np.prod(x, axis=1))\n        self.assertAllClose(\n            knp.Prod(axis=1, keepdims=True)(x),\n            np.prod(x, axis=1, keepdims=True),\n        )\n\n    def test_ptp(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n\n        self.assertAllClose(knp.ptp(x), np.ptp(x))\n        self.assertAllClose(knp.ptp(x, axis=None), np.ptp(x, axis=None))\n\n        self.assertAllClose(knp.ptp(x, axis=0), np.ptp(x, axis=0))\n        
self.assertAllClose(knp.ptp(x, axis=1), np.ptp(x, axis=1))\n        self.assertAllClose(knp.ptp(x, axis=(1,)), np.ptp(x, axis=(1,)))\n\n        self.assertAllClose(knp.ptp(x, axis=()), np.ptp(x, axis=()))\n\n        self.assertAllClose(\n            knp.ptp(x, axis=1, keepdims=True),\n            np.ptp(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Ptp()(x), np.ptp(x))\n        self.assertAllClose(knp.Ptp(axis=1)(x), np.ptp(x, axis=1))\n        self.assertAllClose(knp.Ptp(axis=(0, 1))(x), np.ptp(x, axis=(0, 1)))\n        self.assertAllClose(\n            knp.Ptp(axis=1, keepdims=True)(x),\n            np.ptp(x, axis=1, keepdims=True),\n        )\n\n    def test_ravel(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.ravel(x), np.ravel(x))\n        self.assertAllClose(knp.Ravel()(x), np.ravel(x))\n\n    def test_unravel_index(self):\n        x = np.array([0, 1, 2, 3])\n        shape = (2, 2)\n        self.assertAllClose(\n            knp.unravel_index(x, shape), np.unravel_index(x, shape)\n        )\n\n        x = np.array([[0, 1], [2, 3]])\n        shape = (2, 2)\n        self.assertAllClose(\n            knp.unravel_index(x, shape), np.unravel_index(x, shape)\n        )\n\n    def test_real(self):\n        x = np.array([[1, 2, 3 - 3j], [3, 2, 1 + 5j]])\n        self.assertAllClose(knp.real(x), np.real(x))\n        self.assertAllClose(knp.Real()(x), np.real(x))\n\n    def test_reciprocal(self):\n        x = np.array([[1.0, 2.0, 3.0], [3.0, 2.0, 1.0]])\n        self.assertAllClose(knp.reciprocal(x), np.reciprocal(x))\n        self.assertAllClose(knp.Reciprocal()(x), np.reciprocal(x))\n\n    def test_repeat(self):\n        x = np.array([[1, 2], [3, 4]])\n        self.assertAllClose(knp.repeat(x, 2), np.repeat(x, 2))\n        self.assertAllClose(\n            knp.Repeat(np.array([2]))(x),\n            np.repeat(x, np.array([2])),\n        )\n        self.assertAllClose(knp.repeat(x, 3, axis=1), 
np.repeat(x, 3, axis=1))\n        self.assertAllClose(\n            knp.repeat(x, np.array([1, 2]), axis=-1),\n            np.repeat(x, np.array([1, 2]), axis=-1),\n        )\n        self.assertAllClose(knp.Repeat(2)(x), np.repeat(x, 2))\n        self.assertAllClose(knp.Repeat(3, axis=1)(x), np.repeat(x, 3, axis=1))\n        self.assertAllClose(\n            knp.Repeat(np.array([1, 2]), axis=0)(x),\n            np.repeat(x, np.array([1, 2]), axis=0),\n        )\n\n    def test_reshape(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.reshape(x, [3, 2]), np.reshape(x, [3, 2]))\n        self.assertAllClose(knp.Reshape([3, 2])(x), np.reshape(x, [3, 2]))\n        self.assertAllClose(knp.Reshape(-1)(x), np.reshape(x, -1))\n\n    def test_roll(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.roll(x, 1), np.roll(x, 1))\n        self.assertAllClose(knp.roll(x, 1, axis=1), np.roll(x, 1, axis=1))\n        self.assertAllClose(knp.roll(x, -1, axis=0), np.roll(x, -1, axis=0))\n        self.assertAllClose(knp.Roll(1)(x), np.roll(x, 1))\n        self.assertAllClose(knp.Roll(1, axis=1)(x), np.roll(x, 1, axis=1))\n        self.assertAllClose(knp.Roll(-1, axis=0)(x), np.roll(x, -1, axis=0))\n\n    def test_round(self):\n        x = np.array([[1.1, 2.5, 3.9], [3.2, 2.3, 1.8]])\n        self.assertAllClose(knp.round(x), np.round(x))\n        self.assertAllClose(knp.Round()(x), np.round(x))\n\n        # Test with decimals=1\n        self.assertAllClose(knp.round(x, decimals=1), np.round(x, decimals=1))\n        self.assertAllClose(knp.Round(decimals=1)(x), np.round(x, decimals=1))\n\n        # Test with integers\n        x = np.array([[1, 2, 3], [3, 2, 1]], dtype=\"int32\")\n        self.assertAllClose(knp.round(x, decimals=1), np.round(x, decimals=1))\n        self.assertAllClose(knp.Round(decimals=1)(x), np.round(x, decimals=1))\n\n        # Test with integers and decimals < 0\n        x = np.array([[123, 234, 
345], [345, 234, 123]], dtype=\"int32\")\n        self.assertAllClose(knp.round(x, decimals=-1), np.round(x, decimals=-1))\n        self.assertAllClose(knp.Round(decimals=-1)(x), np.round(x, decimals=-1))\n\n    def test_searchsorted(self):\n        a = np.array([1, 2, 2, 3, 4, 5, 5])\n        v = np.array([4, 3, 5, 1, 2])\n        expected = np.searchsorted(a, v).astype(\"int32\")\n        self.assertAllEqual(knp.searchsorted(a, v), expected)\n        self.assertAllEqual(knp.SearchSorted()(a, v), expected)\n\n    def test_sign(self):\n        x = np.array([[1, -2, 3], [-3, 2, -1]])\n        self.assertAllClose(knp.sign(x), np.sign(x))\n        self.assertAllClose(knp.Sign()(x), np.sign(x))\n\n    def test_signbit(self):\n        x = np.array([[0.0, -0.0, -1.1e-45], [1.1e-38, 2, -1]])\n        self.assertAllClose(knp.signbit(x), np.signbit(x))\n        self.assertAllClose(knp.Signbit()(x), np.signbit(x))\n\n    def test_sin(self):\n        x = np.array([[1, -2, 3], [-3, 2, -1]])\n        self.assertAllClose(knp.sin(x), np.sin(x))\n        self.assertAllClose(knp.Sin()(x), np.sin(x))\n\n    def test_sinc(self):\n        x = np.array([[0, 1, -1], [0.5, -0.5, 2]])\n        self.assertAllClose(knp.sinc(x), np.sinc(x))\n        self.assertAllClose(knp.Sinc()(x), np.sinc(x))\n\n    def test_sinh(self):\n        x = np.array([[1, -2, 3], [-3, 2, -1]])\n        self.assertAllClose(knp.sinh(x), np.sinh(x))\n        self.assertAllClose(knp.Sinh()(x), np.sinh(x))\n\n    def test_size(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.size(x), np.size(x))\n        self.assertAllClose(knp.Size()(x), np.size(x))\n\n    def test_sort(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.sort(x), np.sort(x))\n        self.assertAllClose(knp.Sort()(x), np.sort(x))\n        self.assertAllClose(knp.sort(x, axis=0), np.sort(x, axis=0))\n        self.assertAllClose(knp.Sort(axis=0)(x), np.sort(x, axis=0))\n\n    def 
test_split(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertIsInstance(knp.split(x, 2), list)\n        self.assertAllClose(knp.split(x, 2), np.split(x, 2))\n        self.assertAllClose(knp.Split(2)(x), np.split(x, 2))\n        self.assertAllClose(\n            knp.split(x, [1, 2], axis=1),\n            np.split(x, [1, 2], axis=1),\n        )\n        self.assertAllClose(\n            knp.Split([1, 2], axis=1)(x),\n            np.split(x, [1, 2], axis=1),\n        )\n\n        # test invalid indices_or_sections\n        with self.assertRaises(Exception):\n            knp.split(x, 3)\n\n        # test zero dimension\n        x = np.ones(shape=(0,))\n        self.assertEqual(len(knp.split(x, 2)), 2)\n        self.assertEqual(len(knp.Split(2)(x)), 2)\n\n        # test indices_or_sections as tensor\n        x = knp.array([[1, 2, 3], [3, 2, 1]])\n        indices_or_sections = knp.array([1, 2])\n        x_np = np.array([[1, 2, 3], [3, 2, 1]])\n        indices_or_sections_np = np.array([1, 2])\n        self.assertAllClose(\n            knp.split(x, indices_or_sections, axis=1),\n            np.split(x_np, indices_or_sections_np, axis=1),\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Only test tensorflow backend\",\n    )\n    def test_split_with_jit_in_tf(self):\n        import tensorflow as tf\n\n        x = knp.array([[1, 2, 3], [3, 2, 1]])\n        indices = knp.array([1, 2])\n        x_np = np.array([[1, 2, 3], [3, 2, 1]])\n        indices_np = np.array([1, 2])\n\n        @tf.function(jit_compile=True)\n        def fn(x, indices, axis):\n            return knp.split(x, indices, axis=axis)\n\n        self.assertAllClose(\n            fn(x, indices, axis=1),\n            np.split(x_np, indices_np, axis=1),\n        )\n\n    def test_hsplit(self):\n        x = np.arange(18).reshape((3, 6))\n\n        self.assertIsInstance(knp.hsplit(x, 3), list)\n        self.assertAllClose(knp.hsplit(x, 3), 
np.hsplit(x, 3))\n        self.assertAllClose(knp.Hsplit(3)(x), np.hsplit(x, 3))\n\n        indices = [1, 3, 5]\n\n        # Compare each split\n        for split_knp, split_np in zip(\n            knp.hsplit(x, indices), np.hsplit(x, indices)\n        ):\n            self.assertAllClose(split_knp, split_np)\n\n        for split_knp, split_np in zip(\n            knp.Hsplit(indices)(x), np.hsplit(x, indices)\n        ):\n            self.assertAllClose(split_knp, split_np)\n\n        with self.assertRaises(Exception):\n            knp.hsplit(x, 4)\n\n        x_kr = knp.array(x)\n        indices_kr = knp.array(indices)\n        indices_np = np.array(indices)\n\n        for split_knp, split_np in zip(\n            knp.hsplit(x_kr, indices_kr), np.hsplit(x, indices_np)\n        ):\n            self.assertAllClose(split_knp, split_np)\n\n        # Test 1D case\n        x_1d = np.arange(10)\n        indices_1d = [2, 5, 9]\n\n        self.assertIsInstance(knp.hsplit(x_1d, 2), list)\n        self.assertAllClose(knp.hsplit(x_1d, 2), np.hsplit(x_1d, 2))\n        self.assertAllClose(knp.Hsplit(2)(x_1d), np.hsplit(x_1d, 2))\n\n        for split_knp, split_np in zip(\n            knp.hsplit(x_1d, indices_1d), np.hsplit(x_1d, indices_1d)\n        ):\n            self.assertAllClose(split_knp, split_np)\n\n        for split_knp, split_np in zip(\n            knp.Hsplit(indices_1d)(x_1d), np.hsplit(x_1d, indices_1d)\n        ):\n            self.assertAllClose(split_knp, split_np)\n\n        with self.assertRaises(Exception):\n            knp.hsplit(x_1d, 3)\n\n        x_kr = knp.array(x_1d)\n        indices_kr = knp.array(indices_1d)\n        indices_np = np.array(indices_1d)\n\n        for split_knp, split_np in zip(\n            knp.hsplit(x_kr, indices_kr), np.hsplit(x_1d, indices_np)\n        ):\n            self.assertAllClose(split_knp, split_np)\n\n    def test_vsplit(self):\n        x = np.arange(18).reshape((6, 3))\n\n        self.assertIsInstance(knp.vsplit(x, 3), 
list)\n        self.assertAllClose(knp.vsplit(x, 3), np.vsplit(x, 3))\n        self.assertAllClose(knp.Vsplit(3)(x), np.vsplit(x, 3))\n\n        indices = [1, 3, 5]\n\n        # Compare each split\n        for split_knp, split_np in zip(\n            knp.vsplit(x, indices), np.vsplit(x, indices)\n        ):\n            self.assertAllClose(split_knp, split_np)\n\n        for split_knp, split_np in zip(\n            knp.Vsplit(indices)(x), np.vsplit(x, indices)\n        ):\n            self.assertAllClose(split_knp, split_np)\n\n        with self.assertRaises(Exception):\n            knp.vsplit(x, 4)\n\n        x_kr = knp.array(x)\n        indices_kr = knp.array(indices)\n        indices_np = np.array(indices)\n\n        for split_knp, split_np in zip(\n            knp.vsplit(x_kr, indices_kr), np.vsplit(x, indices_np)\n        ):\n            self.assertAllClose(split_knp, split_np)\n\n    def test_sqrt(self):\n        x = np.array([[1, 4, 9], [16, 25, 36]], dtype=\"float32\")\n        ref_y = np.sqrt(x)\n        y = knp.sqrt(x)\n        self.assertEqual(standardize_dtype(y.dtype), \"float32\")\n        self.assertAllClose(y, ref_y)\n        y = knp.Sqrt()(x)\n        self.assertEqual(standardize_dtype(y.dtype), \"float32\")\n        self.assertAllClose(y, ref_y)\n\n    def test_sqrt_int32(self):\n        x = np.array([[1, 4, 9], [16, 25, 36]], dtype=\"int32\")\n        ref_y = np.sqrt(x)\n        y = knp.sqrt(x)\n        self.assertEqual(standardize_dtype(y.dtype), \"float32\")\n        self.assertAllClose(y, ref_y)\n        y = knp.Sqrt()(x)\n        self.assertEqual(standardize_dtype(y.dtype), \"float32\")\n        self.assertAllClose(y, ref_y)\n\n    def test_stack(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [6, 5, 4]])\n        self.assertAllClose(knp.stack([x, y]), np.stack([x, y]))\n        self.assertAllClose(knp.stack([x, y], axis=1), np.stack([x, y], axis=1))\n        self.assertAllClose(knp.Stack()([x, y]), 
np.stack([x, y]))\n        self.assertAllClose(knp.Stack(axis=1)([x, y]), np.stack([x, y], axis=1))\n\n    def test_std(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.std(x), np.std(x))\n        self.assertAllClose(knp.std(x, axis=1), np.std(x, axis=1))\n        self.assertAllClose(\n            knp.std(x, axis=1, keepdims=True),\n            np.std(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Std()(x), np.std(x))\n        self.assertAllClose(knp.Std(axis=1)(x), np.std(x, axis=1))\n        self.assertAllClose(\n            knp.Std(axis=1, keepdims=True)(x),\n            np.std(x, axis=1, keepdims=True),\n        )\n\n    def test_swapaxes(self):\n        x = np.arange(24).reshape([1, 2, 3, 4])\n        self.assertAllClose(\n            knp.swapaxes(x, 0, 1),\n            np.swapaxes(x, 0, 1),\n        )\n        self.assertAllClose(\n            knp.Swapaxes(0, 1)(x),\n            np.swapaxes(x, 0, 1),\n        )\n\n    def test_tan(self):\n        x = np.array([[1, -2, 3], [-3, 2, -1]])\n        self.assertAllClose(knp.tan(x), np.tan(x))\n        self.assertAllClose(knp.Tan()(x), np.tan(x))\n\n    def test_tanh(self):\n        x = np.array([[1, -2, 3], [-3, 2, -1]])\n        self.assertAllClose(knp.tanh(x), np.tanh(x))\n        self.assertAllClose(knp.Tanh()(x), np.tanh(x))\n\n    def test_tile(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        self.assertAllClose(knp.tile(x, 2), np.tile(x, 2))\n        self.assertAllClose(knp.tile(x, [2, 3]), np.tile(x, [2, 3]))\n        self.assertAllClose(knp.Tile([2, 3])(x), np.tile(x, [2, 3]))\n\n        # If repeats.ndim > x.ndim\n        self.assertAllClose(knp.tile(x, [2, 3, 4]), np.tile(x, [2, 3, 4]))\n        self.assertAllClose(knp.Tile([2, 3, 4])(x), np.tile(x, [2, 3, 4]))\n\n        # If repeats.ndim < x.ndim\n        self.assertAllClose(knp.tile(x, [2]), np.tile(x, [2]))\n        self.assertAllClose(knp.Tile([2])(x), np.tile(x, [2]))\n\n    
def test_trace(self):\n        x = np.arange(24).reshape([1, 2, 3, 4])\n        self.assertAllClose(knp.trace(x), np.trace(x))\n        self.assertAllClose(\n            knp.trace(x, axis1=2, axis2=3),\n            np.trace(x, axis1=2, axis2=3),\n        )\n        self.assertAllClose(\n            knp.Trace(axis1=2, axis2=3)(x),\n            np.trace(x, axis1=2, axis2=3),\n        )\n\n    def test_tril(self):\n        x = np.arange(24).reshape([1, 2, 3, 4])\n        self.assertAllClose(knp.tril(x), np.tril(x))\n        self.assertAllClose(knp.tril(x, -1), np.tril(x, -1))\n        self.assertAllClose(knp.Tril(-1)(x), np.tril(x, -1))\n\n        x = np.ones([5, 5])\n        self.assertAllClose(knp.tril(x), np.tril(x))\n        self.assertAllClose(knp.tril(x, -1), np.tril(x, -1))\n        self.assertAllClose(knp.Tril(-1)(x), np.tril(x, -1))\n\n    def test_tril_in_layer(self):\n        # https://github.com/keras-team/keras/issues/18890\n        x = keras.Input((None, 3))\n        y1 = keras.layers.Lambda(\n            lambda x: keras.ops.tril(\n                keras.ops.ones((keras.ops.shape(x)[1], keras.ops.shape(x)[1]))\n            ),\n            output_shape=(None, None, 3),\n        )(x)\n        y2 = keras.layers.Lambda(\n            lambda x: keras.ops.tril(\n                keras.ops.ones((keras.ops.shape(x)[1], keras.ops.shape(x)[1])),\n                k=-1,\n            ),\n            output_shape=(None, None, 3),\n        )(x)\n        model = keras.Model(x, [y1, y2])\n\n        result = model(np.ones((1, 2, 3), \"float32\"))\n        self.assertAllClose(\n            result, [np.tril(np.ones((2, 2))), np.tril(np.ones((2, 2)), k=-1)]\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Only test tensorflow backend\",\n    )\n    def test_tril_with_jit_in_tf(self):\n        import tensorflow as tf\n\n        x = knp.reshape(knp.arange(24), [1, 2, 3, 4])\n        k = knp.array(0)\n        x_np = 
np.reshape(np.arange(24), [1, 2, 3, 4])\n        k_np = np.array(0)\n\n        @tf.function(jit_compile=True)\n        def fn(x, k):\n            return knp.tril(x, k=k)\n\n        self.assertAllClose(fn(x, k), np.tril(x_np, k_np))\n\n    def test_triu(self):\n        x = np.arange(24).reshape([1, 2, 3, 4])\n        self.assertAllClose(knp.triu(x), np.triu(x))\n        self.assertAllClose(knp.triu(x, -1), np.triu(x, -1))\n        self.assertAllClose(knp.Triu(-1)(x), np.triu(x, -1))\n\n        x = np.ones([5, 5])\n        self.assertAllClose(knp.triu(x), np.triu(x))\n        self.assertAllClose(knp.triu(x, -1), np.triu(x, -1))\n        self.assertAllClose(knp.Triu(-1)(x), np.triu(x, -1))\n\n    def test_triu_in_layer(self):\n        # https://github.com/keras-team/keras/issues/18890\n        x = keras.Input((None, 3))\n        y1 = keras.layers.Lambda(\n            lambda x: keras.ops.triu(\n                keras.ops.ones((keras.ops.shape(x)[1], keras.ops.shape(x)[1]))\n            ),\n            output_shape=(None, None, 3),\n        )(x)\n        y2 = keras.layers.Lambda(\n            lambda x: keras.ops.triu(\n                keras.ops.ones((keras.ops.shape(x)[1], keras.ops.shape(x)[1])),\n                k=-1,\n            ),\n            output_shape=(None, None, 3),\n        )(x)\n        model = keras.Model(x, [y1, y2])\n\n        result = model(np.ones((1, 2, 3), \"float32\"))\n        self.assertAllClose(\n            result, [np.triu(np.ones((2, 2))), np.triu(np.ones((2, 2)), k=-1)]\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"Only test tensorflow backend\",\n    )\n    def test_triu_with_jit_in_tf(self):\n        import tensorflow as tf\n\n        x = knp.reshape(knp.arange(24), [1, 2, 3, 4])\n        k = knp.array(0)\n        x_np = np.reshape(np.arange(24), [1, 2, 3, 4])\n        k_np = np.array(0)\n\n        @tf.function(jit_compile=True)\n        def fn(x, k):\n            return knp.triu(x, 
k=k)\n\n        self.assertAllClose(fn(x, k), np.triu(x_np, k_np))\n\n    def test_trunc(self):\n        x = np.array([-1.7, -2.5, -0.2, 0.2, 1.5, 1.7, 2.0])\n        self.assertAllClose(knp.trunc(x), np.trunc(x))\n        self.assertAllClose(knp.Trunc()(x), np.trunc(x))\n\n        x = np.array([-1, -2, -0, 0, 1, 1, 2], dtype=\"int32\")\n        self.assertAllClose(knp.trunc(x), np.trunc(x))\n        self.assertAllClose(knp.Trunc()(x), np.trunc(x))\n\n    def test_vstack(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [6, 5, 4]])\n        self.assertAllClose(knp.vstack([x, y]), np.vstack([x, y]))\n        self.assertAllClose(knp.Vstack()([x, y]), np.vstack([x, y]))\n\n    def test_dstack(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [6, 5, 4]])\n        self.assertAllClose(knp.dstack([x, y]), np.dstack([x, y]))\n        self.assertAllClose(knp.Dstack()([x, y]), np.dstack([x, y]))\n\n        x = np.array([1, 2, 3])\n        y = np.array([[4, 5, 6]])\n        self.assertAllClose(knp.dstack([x, y]), np.dstack([x, y]))\n\n        x = np.ones([2, 3, 4])\n        y = np.ones([2, 3, 5])\n        self.assertAllClose(knp.dstack([x, y]), np.dstack([x, y]))\n\n    def test_floor_divide(self):\n        x = np.array([[1, 2, 3], [3, 2, 1]])\n        y = np.array([[4, 5, 6], [3, 2, 1]])\n        z = np.array([[[1, 2, 3], [3, 2, 1]]])\n        self.assertAllClose(knp.floor_divide(x, y), np.floor_divide(x, y))\n        self.assertAllClose(knp.floor_divide(x, z), np.floor_divide(x, z))\n\n        self.assertAllClose(knp.FloorDivide()(x, y), np.floor_divide(x, y))\n        self.assertAllClose(knp.FloorDivide()(x, z), np.floor_divide(x, z))\n\n    def test_xor(self):\n        x = np.array([[True, False], [True, True]])\n        y = np.array([[False, False], [True, False]])\n        self.assertAllClose(knp.logical_xor(x, y), np.logical_xor(x, y))\n        self.assertAllClose(knp.logical_xor(x, True), 
np.logical_xor(x, True))\n        self.assertAllClose(knp.logical_xor(True, x), np.logical_xor(True, x))\n\n        self.assertAllClose(knp.LogicalXor()(x, y), np.logical_xor(x, y))\n        self.assertAllClose(knp.LogicalXor()(x, True), np.logical_xor(x, True))\n        self.assertAllClose(knp.LogicalXor()(True, x), np.logical_xor(True, x))\n\n    def test_correlate(self):\n        x = np.array([1, 2, 3])\n        y = np.array([0, 1, 0.5])\n        self.assertAllClose(knp.correlate(x, y), np.correlate(x, y))\n        self.assertAllClose(\n            knp.correlate(x, y, mode=\"same\"), np.correlate(x, y, mode=\"same\")\n        )\n        self.assertAllClose(\n            knp.correlate(x, y, mode=\"full\"), np.correlate(x, y, mode=\"full\")\n        )\n\n        self.assertAllClose(knp.Correlate()(x, y), np.correlate(x, y))\n        self.assertAllClose(\n            knp.Correlate(mode=\"same\")(x, y), np.correlate(x, y, mode=\"same\")\n        )\n        self.assertAllClose(\n            knp.Correlate(mode=\"full\")(x, y), np.correlate(x, y, mode=\"full\")\n        )\n\n    def test_correlate_different_size(self):\n        x = np.array([1, 3, 5])\n        y = np.array([7, 9])\n        self.assertAllClose(knp.correlate(x, y), np.correlate(x, y))\n        self.assertAllClose(\n            knp.correlate(x, y, mode=\"same\"), np.correlate(x, y, mode=\"same\")\n        )\n        self.assertAllClose(\n            knp.correlate(x, y, mode=\"full\"), np.correlate(x, y, mode=\"full\")\n        )\n\n        self.assertAllClose(knp.Correlate()(x, y), np.correlate(x, y))\n        self.assertAllClose(\n            knp.Correlate(mode=\"same\")(x, y), np.correlate(x, y, mode=\"same\")\n        )\n        self.assertAllClose(\n            knp.Correlate(mode=\"full\")(x, y), np.correlate(x, y, mode=\"full\")\n        )\n\n    def test_select(self):\n        x = np.arange(6)\n        condlist = [x < 3, x > 3]\n        choicelist = [x, x**2]\n        y = knp.select(condlist, 
choicelist, 42)\n        self.assertAllClose(y, [0, 1, 2, 42, 16, 25])\n\n        # Test with tuples\n        condlist = (x < 3, x > 3)\n        choicelist = (x, x**2)\n        y = knp.select(condlist, choicelist, 42)\n        self.assertAllClose(y, [0, 1, 2, 42, 16, 25])\n\n        # Test with symbolic tensors\n        x = backend.KerasTensor((6,))\n        condlist = [x < 3, x > 3]\n        choicelist = [x, x**2]\n        y = knp.select(condlist, choicelist, 42)\n        self.assertEqual(y.shape, (6,))\n\n    def test_slogdet(self):\n        x = np.ones((4, 4)) * 2.0\n        out = knp.slogdet(x)\n        self.assertAllClose(out[0], 0)\n\n        x = backend.KerasTensor((3, 3))\n        out = knp.slogdet(x)\n        self.assertEqual(out[0].shape, ())\n        self.assertEqual(out[1].shape, ())\n\n        x = backend.KerasTensor((2, 4, 3, 3))\n        out = knp.slogdet(x)\n        self.assertEqual(out[0].shape, ())\n        self.assertEqual(out[1].shape, (2, 4))\n\n    def test_nanargmax(self):\n        x = np.array([[1.0, np.nan, -np.inf], [np.nan, 2.0, -1.0]])\n\n        self.assertAllClose(knp.nanargmax(x), np.nanargmax(x))\n        self.assertAllClose(knp.nanargmax(x, axis=0), np.nanargmax(x, axis=0))\n        self.assertAllClose(knp.nanargmax(x, axis=1), np.nanargmax(x, axis=1))\n        self.assertAllClose(knp.Nanargmax()(x), np.nanargmax(x))\n        self.assertAllClose(knp.Nanargmax(axis=1)(x), np.nanargmax(x, axis=1))\n\n        x_3d = np.array(\n            [\n                [[1.0, np.nan], [2.0, 3.0]],\n                [[np.nan, 4.0], [5.0, np.nan]],\n            ]\n        )\n\n        self.assertAllClose(knp.nanargmax(x_3d), np.nanargmax(x_3d))\n        self.assertAllClose(\n            knp.nanargmax(x_3d, axis=0), np.nanargmax(x_3d, axis=0)\n        )\n        self.assertAllClose(\n            knp.nanargmax(x_3d, axis=1), np.nanargmax(x_3d, axis=1)\n        )\n\n        x_all_nan = np.array([[np.nan, np.nan], 
[np.nan, np.nan]])\n\n        self.assertEqual(knp.nanargmax(x_all_nan), -1)\n        self.assertAllClose(\n            knp.nanargmax(x_all_nan, axis=0), np.array([-1, -1])\n        )\n        self.assertAllClose(\n            knp.nanargmax(x_all_nan, axis=1), np.array([-1, -1])\n        )\n        self.assertAllClose(\n            knp.nanargmax(x_all_nan, axis=1, keepdims=True),\n            np.array([[-1], [-1]]),\n        )\n\n    def test_nanargmin(self):\n        x = np.array([[1.0, np.nan, np.inf], [np.nan, 2.0, -1.0]])\n\n        self.assertAllClose(knp.nanargmin(x), np.nanargmin(x))\n        self.assertAllClose(knp.nanargmin(x, axis=0), np.nanargmin(x, axis=0))\n        self.assertAllClose(knp.nanargmin(x, axis=1), np.nanargmin(x, axis=1))\n        self.assertAllClose(knp.Nanargmin()(x), np.nanargmin(x))\n        self.assertAllClose(knp.Nanargmin(axis=1)(x), np.nanargmin(x, axis=1))\n\n        x_3d = np.array(\n            [\n                [[1.0, np.nan], [2.0, 3.0]],\n                [[np.nan, 4.0], [5.0, np.nan]],\n            ]\n        )\n\n        self.assertAllClose(knp.nanargmin(x_3d), np.nanargmin(x_3d))\n        self.assertAllClose(\n            knp.nanargmin(x_3d, axis=0), np.nanargmin(x_3d, axis=0)\n        )\n        self.assertAllClose(\n            knp.nanargmin(x_3d, axis=1), np.nanargmin(x_3d, axis=1)\n        )\n\n        x_all_nan = np.array([[np.nan, np.nan], [np.nan, np.nan]])\n\n        self.assertEqual(knp.nanargmin(x_all_nan), -1)\n        self.assertAllClose(\n            knp.nanargmin(x_all_nan, axis=0), np.array([-1, -1])\n        )\n        self.assertAllClose(\n            knp.nanargmin(x_all_nan, axis=1), np.array([-1, -1])\n        )\n        self.assertAllClose(\n            knp.nanargmin(x_all_nan, axis=1, keepdims=True),\n            np.array([[-1], [-1]]),\n        )\n\n    def test_nancumsum(self):\n        x = np.array([[1.0, np.nan, 3.0], [np.nan, 2.0, -1.0]])\n\n        self.assertAllClose(knp.nancumsum(x), 
np.nancumsum(x))\n        self.assertAllClose(knp.nancumsum(x, axis=0), np.nancumsum(x, axis=0))\n        self.assertAllClose(knp.nancumsum(x, axis=1), np.nancumsum(x, axis=1))\n        self.assertAllClose(knp.Nancumsum()(x), np.nancumsum(x))\n        self.assertAllClose(knp.Nancumsum(axis=1)(x), np.nancumsum(x, axis=1))\n\n        x_3d = np.array(\n            [\n                [[1.0, np.nan], [2.0, 3.0]],\n                [[np.nan, 4.0], [5.0, np.nan]],\n            ]\n        )\n\n        self.assertAllClose(knp.nancumsum(x_3d), np.nancumsum(x_3d))\n        self.assertAllClose(\n            knp.nancumsum(x_3d, axis=0), np.nancumsum(x_3d, axis=0)\n        )\n        self.assertAllClose(\n            knp.nancumsum(x_3d, axis=1), np.nancumsum(x_3d, axis=1)\n        )\n\n        x_all_nan = np.array([[np.nan, np.nan], [np.nan, np.nan]])\n        self.assertAllClose(knp.nancumsum(x_all_nan), np.nancumsum(x_all_nan))\n        self.assertAllClose(\n            knp.nancumsum(x_all_nan, axis=1), np.nancumsum(x_all_nan, axis=1)\n        )\n\n    def test_nancumprod(self):\n        x = np.array([[1.0, np.nan, 3.0], [np.nan, 2.0, -1.0]])\n\n        self.assertAllClose(knp.nancumprod(x), np.nancumprod(x))\n        self.assertAllClose(knp.nancumprod(x, axis=0), np.nancumprod(x, axis=0))\n        self.assertAllClose(knp.nancumprod(x, axis=1), np.nancumprod(x, axis=1))\n        self.assertAllClose(knp.Nancumprod()(x), np.nancumprod(x))\n        self.assertAllClose(knp.Nancumprod(axis=1)(x), np.nancumprod(x, axis=1))\n\n        x_3d = np.array(\n            [\n                [[1.0, np.nan], [2.0, 3.0]],\n                [[np.nan, 4.0], [5.0, np.nan]],\n            ]\n        )\n\n        self.assertAllClose(knp.nancumprod(x_3d), np.nancumprod(x_3d))\n        self.assertAllClose(\n            knp.nancumprod(x_3d, axis=0), np.nancumprod(x_3d, axis=0)\n        )\n        self.assertAllClose(\n            knp.nancumprod(x_3d, axis=1), np.nancumprod(x_3d, axis=1)\n        )\n\n     
   x_all_nan = np.array([[np.nan, np.nan], [np.nan, np.nan]])\n        self.assertAllClose(knp.nancumprod(x_all_nan), np.nancumprod(x_all_nan))\n        self.assertAllClose(\n            knp.nancumprod(x_all_nan, axis=1),\n            np.nancumprod(x_all_nan, axis=1),\n        )\n\n    def test_nanmax(self):\n        x = np.array([[1.0, np.nan, 3.0], [np.nan, 2.0, -np.inf]])\n\n        self.assertAllClose(knp.nanmax(x), np.nanmax(x))\n        self.assertAllClose(knp.nanmax(x, axis=()), np.nanmax(x, axis=()))\n        self.assertAllClose(knp.nanmax(x, axis=1), np.nanmax(x, axis=1))\n        self.assertAllClose(knp.nanmax(x, axis=(1,)), np.nanmax(x, axis=(1,)))\n        self.assertAllClose(\n            knp.nanmax(x, axis=1, keepdims=True),\n            np.nanmax(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Nanmax()(x), np.nanmax(x))\n        self.assertAllClose(knp.Nanmax(axis=1)(x), np.nanmax(x, axis=1))\n        self.assertAllClose(\n            knp.Nanmax(axis=1, keepdims=True)(x),\n            np.nanmax(x, axis=1, keepdims=True),\n        )\n\n        x_all_nan = np.array([[np.nan, np.nan], [np.nan, np.nan]])\n        self.assertAllClose(knp.nanmax(x_all_nan), np.nanmax(x_all_nan))\n        self.assertAllClose(\n            knp.nanmax(x_all_nan, axis=1),\n            np.nanmax(x_all_nan, axis=1),\n        )\n\n        x_3d = np.array(\n            [\n                [[1.0, np.nan], [2.0, 3.0]],\n                [[np.nan, 4.0], [5.0, np.nan]],\n            ]\n        )\n        self.assertAllClose(knp.nanmax(x_3d), np.nanmax(x_3d))\n        self.assertAllClose(\n            knp.nanmax(x_3d, axis=(1, 2)),\n            np.nanmax(x_3d, axis=(1, 2)),\n        )\n\n    def test_nanmean(self):\n        x = np.array([[1.0, np.nan, 3.0, 4.0], [np.nan, 2.0, np.inf, -np.inf]])\n\n        self.assertAllClose(knp.nanmean(x), np.nanmean(x))\n        self.assertAllClose(knp.nanmean(x, axis=()), np.nanmean(x, axis=()))\n        
self.assertAllClose(knp.nanmean(x, axis=1), np.nanmean(x, axis=1))\n        self.assertAllClose(knp.nanmean(x, axis=(1,)), np.nanmean(x, axis=(1,)))\n        self.assertAllClose(\n            knp.nanmean(x, axis=1, keepdims=True),\n            np.nanmean(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Nanmean()(x), np.nanmean(x))\n        self.assertAllClose(knp.Nanmean(axis=1)(x), np.nanmean(x, axis=1))\n        self.assertAllClose(\n            knp.Nanmean(axis=1, keepdims=True)(x),\n            np.nanmean(x, axis=1, keepdims=True),\n        )\n\n        x_all_nan = np.array([[np.nan, np.nan], [np.nan, np.nan]])\n        self.assertAllClose(knp.nanmean(x_all_nan), np.nanmean(x_all_nan))\n        self.assertAllClose(\n            knp.nanmean(x_all_nan, axis=1),\n            np.nanmean(x_all_nan, axis=1),\n        )\n\n        x_3d = np.array(\n            [\n                [[1.0, np.nan], [2.0, 3.0]],\n                [[np.nan, 4.0], [5.0, np.nan]],\n            ]\n        )\n        self.assertAllClose(knp.nanmean(x_3d), np.nanmean(x_3d))\n        self.assertAllClose(\n            knp.nanmean(x_3d, axis=(1, 2)),\n            np.nanmean(x_3d, axis=(1, 2)),\n        )\n\n    def test_nanmin(self):\n        x = np.array([[1.0, np.nan, 3.0], [np.nan, 2.0, np.inf]])\n\n        self.assertAllClose(knp.nanmin(x), np.nanmin(x))\n        self.assertAllClose(knp.nanmin(x, axis=()), np.nanmin(x, axis=()))\n        self.assertAllClose(knp.nanmin(x, axis=1), np.nanmin(x, axis=1))\n        self.assertAllClose(knp.nanmin(x, axis=(1,)), np.nanmin(x, axis=(1,)))\n        self.assertAllClose(\n            knp.nanmin(x, axis=1, keepdims=True),\n            np.nanmin(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Nanmin()(x), np.nanmin(x))\n        self.assertAllClose(knp.Nanmin(axis=1)(x), np.nanmin(x, axis=1))\n        self.assertAllClose(\n            knp.Nanmin(axis=1, keepdims=True)(x),\n            np.nanmin(x, axis=1, 
keepdims=True),\n        )\n\n        x_all_nan = np.array([[np.nan, np.nan], [np.nan, np.nan]])\n        self.assertAllClose(knp.nanmin(x_all_nan), np.nanmin(x_all_nan))\n        self.assertAllClose(\n            knp.nanmin(x_all_nan, axis=1),\n            np.nanmin(x_all_nan, axis=1),\n        )\n\n        x_3d = np.array(\n            [\n                [[1.0, np.nan], [2.0, 3.0]],\n                [[np.nan, 4.0], [5.0, np.nan]],\n            ]\n        )\n        self.assertAllClose(knp.nanmin(x_3d), np.nanmin(x_3d))\n        self.assertAllClose(\n            knp.nanmin(x_3d, axis=(1, 2)),\n            np.nanmin(x_3d, axis=(1, 2)),\n        )\n\n    def test_nanprod(self):\n        x = np.array([[1.0, np.nan, 3.0], [np.nan, 2.0, 1.0]])\n\n        self.assertAllClose(knp.nanprod(x), np.nanprod(x))\n        self.assertAllClose(knp.nanprod(x, axis=()), np.nanprod(x, axis=()))\n        self.assertAllClose(knp.nanprod(x, axis=1), np.nanprod(x, axis=1))\n        self.assertAllClose(knp.nanprod(x, axis=(1,)), np.nanprod(x, axis=(1,)))\n        self.assertAllClose(\n            knp.nanprod(x, axis=1, keepdims=True),\n            np.nanprod(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Nanprod()(x), np.nanprod(x))\n        self.assertAllClose(knp.Nanprod(axis=1)(x), np.nanprod(x, axis=1))\n        self.assertAllClose(\n            knp.Nanprod(axis=1, keepdims=True)(x),\n            np.nanprod(x, axis=1, keepdims=True),\n        )\n\n        x_all_nan = np.array([[np.nan, np.nan], [np.nan, np.nan]])\n        self.assertAllClose(knp.nanprod(x_all_nan), np.nanprod(x_all_nan))\n        self.assertAllClose(\n            knp.nanprod(x_all_nan, axis=1),\n            np.nanprod(x_all_nan, axis=1),\n        )\n\n        x_3d = np.array(\n            [\n                [[1.0, np.nan], [2.0, 3.0]],\n                [[np.nan, 4.0], [5.0, np.nan]],\n            ]\n        )\n        self.assertAllClose(knp.nanprod(x_3d), np.nanprod(x_3d))\n        
self.assertAllClose(\n            knp.nanprod(x_3d, axis=(1, 2)),\n            np.nanprod(x_3d, axis=(1, 2)),\n        )\n\n    def test_nanstd(self):\n        x = np.array([[[1.0, np.nan, 3.0], [np.nan, 2.0, 1.0]]])\n\n        self.assertAllClose(knp.nanstd(x), np.nanstd(x))\n        self.assertAllClose(knp.nanstd(x, axis=()), np.nanstd(x, axis=()))\n        self.assertAllClose(knp.nanstd(x, axis=0), np.nanstd(x, axis=0))\n        self.assertAllClose(knp.nanstd(x, axis=1), np.nanstd(x, axis=1))\n        self.assertAllClose(knp.nanstd(x, axis=(1,)), np.nanstd(x, axis=(1,)))\n        self.assertAllClose(\n            knp.nanstd(x, axis=1, keepdims=True),\n            np.nanstd(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Nanstd()(x), np.nanstd(x))\n        self.assertAllClose(knp.Nanstd(axis=1)(x), np.nanstd(x, axis=1))\n        self.assertAllClose(\n            knp.Nanstd(axis=1, keepdims=True)(x),\n            np.nanstd(x, axis=1, keepdims=True),\n        )\n\n        x_all_nan = np.array([[np.nan, np.nan], [np.nan, np.nan]])\n        self.assertAllClose(knp.nanstd(x_all_nan), np.nanstd(x_all_nan))\n        self.assertAllClose(\n            knp.nanstd(x_all_nan, axis=1),\n            np.nanstd(x_all_nan, axis=1),\n        )\n\n        x_3d = np.array(\n            [\n                [[1.0, np.nan], [2.0, 3.0]],\n                [[np.nan, 4.0], [5.0, np.nan]],\n            ]\n        )\n        self.assertAllClose(knp.nanstd(x_3d), np.nanstd(x_3d))\n        self.assertAllClose(\n            knp.nanstd(x_3d, axis=(1, 2)),\n            np.nanstd(x_3d, axis=(1, 2)),\n        )\n\n    def test_nansum(self):\n        x = np.array([[1.0, np.nan, 3.0], [np.nan, 2.0, 1.0]])\n\n        self.assertAllClose(knp.nansum(x), np.nansum(x))\n        self.assertAllClose(knp.nansum(x, axis=()), np.nansum(x, axis=()))\n        self.assertAllClose(knp.nansum(x, axis=1), np.nansum(x, axis=1))\n        self.assertAllClose(knp.nansum(x, axis=(1,)), 
np.nansum(x, axis=(1,)))\n        self.assertAllClose(\n            knp.nansum(x, axis=1, keepdims=True),\n            np.nansum(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Nansum()(x), np.nansum(x))\n        self.assertAllClose(knp.Nansum(axis=1)(x), np.nansum(x, axis=1))\n        self.assertAllClose(\n            knp.Nansum(axis=1, keepdims=True)(x),\n            np.nansum(x, axis=1, keepdims=True),\n        )\n\n        x_all_nan = np.array([[np.nan, np.nan], [np.nan, np.nan]])\n        self.assertAllClose(knp.nansum(x_all_nan), np.nansum(x_all_nan))\n        self.assertAllClose(\n            knp.nansum(x_all_nan, axis=1),\n            np.nansum(x_all_nan, axis=1),\n        )\n\n        x_3d = np.array(\n            [\n                [[1.0, np.nan], [2.0, 3.0]],\n                [[np.nan, 4.0], [5.0, np.nan]],\n            ]\n        )\n        self.assertAllClose(knp.nansum(x_3d), np.nansum(x_3d))\n        self.assertAllClose(\n            knp.nansum(x_3d, axis=(1, 2)),\n            np.nansum(x_3d, axis=(1, 2)),\n        )\n\n    def test_nanvar(self):\n        x = np.array([[[1.0, np.nan, 3.0], [np.nan, 2.0, 1.0]]])\n\n        self.assertAllClose(knp.nanvar(x), np.nanvar(x))\n        self.assertAllClose(knp.nanvar(x, axis=()), np.nanvar(x, axis=()))\n        self.assertAllClose(knp.nanvar(x, axis=0), np.nanvar(x, axis=0))\n        self.assertAllClose(knp.nanvar(x, axis=1), np.nanvar(x, axis=1))\n        self.assertAllClose(knp.nanvar(x, axis=(1,)), np.nanvar(x, axis=(1,)))\n        self.assertAllClose(\n            knp.nanvar(x, axis=1, keepdims=True),\n            np.nanvar(x, axis=1, keepdims=True),\n        )\n\n        self.assertAllClose(knp.Nanvar()(x), np.nanvar(x))\n        self.assertAllClose(knp.Nanvar(axis=1)(x), np.nanvar(x, axis=1))\n        self.assertAllClose(\n            knp.Nanvar(axis=1, keepdims=True)(x),\n            np.nanvar(x, axis=1, keepdims=True),\n        )\n\n        x_all_nan = np.array([[np.nan, 
np.nan], [np.nan, np.nan]])\n        self.assertAllClose(knp.nanvar(x_all_nan), np.nanvar(x_all_nan))\n        self.assertAllClose(\n            knp.nanvar(x_all_nan, axis=1),\n            np.nanvar(x_all_nan, axis=1),\n        )\n\n        x_3d = np.array(\n            [\n                [[1.0, np.nan], [2.0, 3.0]],\n                [[np.nan, 4.0], [5.0, np.nan]],\n            ]\n        )\n        self.assertAllClose(knp.nanvar(x_3d), np.nanvar(x_3d))\n        self.assertAllClose(\n            knp.nanvar(x_3d, axis=(1, 2)),\n            np.nanvar(x_3d, axis=(1, 2)),\n        )\n\n    def test_nan_to_num(self):\n        x = knp.array([1.0, np.nan, np.inf, -np.inf])\n        self.assertAllClose(\n            knp.nan_to_num(x), [1.0, 0.0, 3.402823e38, -3.402823e38]\n        )\n        self.assertAllClose(\n            knp.NanToNum()(x), [1.0, 0.0, 3.402823e38, -3.402823e38]\n        )\n        self.assertAllClose(\n            knp.nan_to_num(x, nan=2, posinf=3, neginf=4), [1.0, 2.0, 3.0, 4.0]\n        )\n        self.assertAllClose(\n            knp.NanToNum(nan=2, posinf=3, neginf=4)(x), [1.0, 2.0, 3.0, 4.0]\n        )\n\n        x = backend.KerasTensor((3, 4))\n        self.assertEqual(\n            knp.NanToNum(nan=2, posinf=3, neginf=4)(x).shape, (3, 4)\n        )\n\n    def test_vectorize(self):\n        # Basic functionality\n        def myfunc(a, b):\n            return a + b\n\n        vfunc = knp.vectorize(myfunc)\n        y = vfunc([1, 2, 3, 4], 2)\n        self.assertAllClose(y, [3, 4, 5, 6])\n\n        # Test signature arg\n        vfunc = knp.vectorize(knp.trace, signature=\"(d,d)->()\")\n        out = vfunc(np.eye(4))\n        self.assertAllClose(\n            out, np.vectorize(np.trace, signature=\"(d,d)->()\")(np.eye(4))\n        )\n\n        vfunc = knp.vectorize(knp.diag, signature=\"(d,d)->(d)\")\n        out = vfunc(np.eye(4))\n        self.assertAllClose(\n            out, np.vectorize(np.diag, signature=\"(d,d)->(d)\")(np.eye(4))\n        )\n\n  
  def test_argpartition(self):\n        x = np.array([3, 4, 2, 1])\n        self.assertAllClose(knp.argpartition(x, 2), np.argpartition(x, 2))\n        self.assertAllClose(knp.Argpartition(2)(x), np.argpartition(x, 2))\n\n        x = np.array([[3, 4, 2], [1, 3, 4]])\n        self.assertAllClose(knp.argpartition(x, 1), np.argpartition(x, 1))\n        self.assertAllClose(knp.Argpartition(1)(x), np.argpartition(x, 1))\n\n        x = np.array([[[3, 4], [2, 3]], [[1, 2], [0, 1]]])\n        self.assertAllClose(knp.argpartition(x, 1), np.argpartition(x, 1))\n        self.assertAllClose(knp.Argpartition(1)(x), np.argpartition(x, 1))\n\n    def test_angle(self):\n        x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])\n        self.assertAllClose(knp.angle(x), np.angle(x))\n\n        self.assertAllClose(knp.Angle()(x), np.angle(x))\n\n\nclass NumpyArrayCreateOpsCorrectnessTest(testing.TestCase):\n    def test_ones(self):\n        self.assertAllClose(knp.ones([2, 3]), np.ones([2, 3]))\n\n    def test_zeros(self):\n        self.assertAllClose(knp.zeros([2, 3]), np.zeros([2, 3]))\n\n    def test_eye(self):\n        self.assertAllClose(knp.eye(3), np.eye(3))\n        self.assertAllClose(knp.eye(3, 4), np.eye(3, 4))\n        self.assertAllClose(knp.eye(3, 4, 1), np.eye(3, 4, 1))\n\n        # Test k >= N\n        self.assertAllClose(knp.eye(3, k=3), np.eye(3, k=3))\n\n        # Test k > 0 and N >= M\n        self.assertAllClose(knp.eye(3, k=1), np.eye(3, k=1))\n\n        # Test k > 0 and N < M and N + k > M\n        self.assertAllClose(knp.eye(3, 4, k=2), np.eye(3, 4, k=2))\n\n        # Test k < 0 and M >= N\n        self.assertAllClose(knp.eye(3, k=-1), np.eye(3, k=-1))\n\n        # Test k < 0 and M < N and M - k > N\n        self.assertAllClose(knp.eye(4, 3, k=-2), np.eye(4, 3, k=-2))\n\n    def test_eye_raises_error_with_floats(self):\n        with self.assertRaises(TypeError):\n            knp.eye(3.0)\n        with self.assertRaises(TypeError):\n            knp.eye(3.0, 
2.0)\n        with self.assertRaises(TypeError):\n            knp.eye(3, 2.0)\n        with self.assertRaises(TypeError):\n            v = knp.max(knp.arange(4.0))\n            knp.eye(v)\n        with self.assertRaises(TypeError):\n            knp.eye(knp.array(3, dtype=\"bfloat16\"))\n\n    def test_arange(self):\n        self.assertAllClose(knp.arange(3), np.arange(3))\n        self.assertAllClose(knp.arange(3, 7), np.arange(3, 7))\n        self.assertAllClose(knp.arange(3, 7, 2), np.arange(3, 7, 2))\n\n        self.assertAllClose(knp.Arange()(3), np.arange(3))\n        self.assertAllClose(knp.Arange()(3, 7), np.arange(3, 7))\n        self.assertAllClose(knp.Arange()(3, 7, 2), np.arange(3, 7, 2))\n\n        self.assertEqual(standardize_dtype(knp.arange(3).dtype), \"int32\")\n        with warnings.catch_warnings(record=True) as record:\n            knp.arange(3, dtype=\"int\")\n        self.assertEqual(len(record), 0)\n\n    def test_full(self):\n        self.assertAllClose(knp.full([2, 3], 0), np.full([2, 3], 0))\n        self.assertAllClose(knp.full([2, 3], 0.1), np.full([2, 3], 0.1))\n        self.assertAllClose(\n            knp.full([2, 3], np.array([1, 4, 5])),\n            np.full([2, 3], np.array([1, 4, 5])),\n        )\n\n        self.assertAllClose(knp.Full([2, 3])(0), np.full([2, 3], 0))\n        self.assertAllClose(knp.Full([2, 3])(0.1), np.full([2, 3], 0.1))\n        self.assertAllClose(\n            knp.Full([2, 3])(np.array([1, 4, 5])),\n            np.full([2, 3], np.array([1, 4, 5])),\n        )\n\n    def test_identity(self):\n        self.assertAllClose(knp.identity(3), np.identity(3))\n\n    def test_tri(self):\n        self.assertAllClose(knp.tri(3), np.tri(3))\n        self.assertAllClose(knp.tri(3, 4), np.tri(3, 4))\n        self.assertAllClose(knp.tri(3, 4, 1), np.tri(3, 4, 1))\n\n        # Test k < 0\n        self.assertAllClose(knp.tri(3, k=-1), np.tri(3, k=-1))\n\n        # Test -k-1 > N\n        self.assertAllClose(knp.tri(3, k=-5), 
np.tri(3, k=-5))\n\n        # Test k > M\n        self.assertAllClose(knp.tri(3, k=4), np.tri(3, k=4))\n\n\ndef create_sparse_tensor(x, indices_from=None, start=0, delta=2):\n    if indices_from is not None:\n        indices = indices_from.indices\n    else:\n        size = math.prod(x.shape)\n        flat_indices = np.arange(start, size, delta)\n        indices = np.stack(np.where(np.ones_like(x)), axis=1)[flat_indices]\n\n    if backend.backend() == \"tensorflow\":\n        import tensorflow as tf\n\n        return tf.SparseTensor(indices, tf.gather_nd(x, indices), x.shape)\n    elif backend.backend() == \"jax\":\n        import jax\n        import jax.experimental.sparse as jax_sparse\n\n        values = x[tuple(jax.numpy.moveaxis(indices, -1, 0))]\n        return jax_sparse.BCOO((values, indices), shape=x.shape)\n\n\ndef create_indexed_slices(x, indices_from=None, start=0, delta=2):\n    indices = np.arange(start, x.shape[0], delta)\n\n    if backend.backend() == \"tensorflow\":\n        import tensorflow as tf\n\n        if indices_from is not None:\n            indices = indices_from.indices\n        return tf.IndexedSlices(tf.gather(x, indices), indices, x.shape)\n    elif backend.backend() == \"jax\":\n        import jax\n        import jax.experimental.sparse as jax_sparse\n\n        if indices_from is not None:\n            indices = indices_from.indices\n        else:\n            indices = jax.numpy.expand_dims(indices, axis=1)\n        values = jax.numpy.take(x, jax.numpy.squeeze(indices, axis=1), axis=0)\n        return jax_sparse.BCOO((values, indices), shape=x.shape)\n\n\ndef get_sparseness_combinations(dense_to_sparse_fn):\n    x = np.array([[1, 2, 3], [3, 2, 1]])\n    y = np.array([[4, 5, 6], [3, 2, 1]])\n    scalar = backend.convert_to_tensor(2)\n    x_sp = dense_to_sparse_fn(x)\n    y_sp = dense_to_sparse_fn(y, indices_from=x_sp)\n    x_sp_sup = dense_to_sparse_fn(x, start=0, delta=1)\n    y_sp_dis = dense_to_sparse_fn(y, start=1)\n    y_sp_sup 
= dense_to_sparse_fn(y, start=0, delta=1)\n    x = backend.convert_to_tensor(x)\n    y = backend.convert_to_tensor(y)\n    return [\n        {\"testcase_name\": \"sparse_dense\", \"x\": x_sp, \"y\": y},\n        {\"testcase_name\": \"dense_sparse\", \"x\": x, \"y\": y_sp},\n        {\"testcase_name\": \"sparse_scalar\", \"x\": x_sp, \"y\": scalar},\n        {\"testcase_name\": \"scalar_sparse\", \"x\": scalar, \"y\": y_sp},\n        {\"testcase_name\": \"sparse_sparse_same\", \"x\": x_sp, \"y\": y_sp},\n        {\"testcase_name\": \"sparse_sparse_disjoint\", \"x\": x_sp, \"y\": y_sp_dis},\n        {\"testcase_name\": \"sparse_sparse_superset\", \"x\": x_sp, \"y\": y_sp_sup},\n        {\"testcase_name\": \"sparse_sparse_subset\", \"x\": x_sp_sup, \"y\": y_sp},\n    ]\n\n\ndef sparseness(x):\n    \"\"\"Categorize `x` as \"sparse\", \"slices\", \"scalar\" or \"dense\".\"\"\"\n    if isinstance(x, KerasTensor):\n        return \"sparse\" if x.sparse else \"dense\"\n    elif x.__class__.__name__ == \"BCOO\":\n        # A JAX BCOO with dense trailing dimensions is used to represent\n        # row slices, analogous to `tf.IndexedSlices`.\n        if x.n_dense > 0:\n            return \"slices\"\n        else:\n            return \"sparse\"\n    elif x.__class__.__name__ == \"SparseTensor\":\n        return \"sparse\"\n    elif x.__class__.__name__ == \"IndexedSlices\":\n        return \"slices\"\n    elif not hasattr(x, \"shape\") or not x.shape:\n        return \"scalar\"\n    else:\n        return \"dense\"\n\n\ndef union_sparseness(x1, x2):\n    \"\"\"Sparseness of ops that are nonzero where either input is nonzero.\"\"\"\n    x1_sparseness = sparseness(x1)\n    x2_sparseness = sparseness(x2)\n    if any(s in (\"scalar\", \"dense\") for s in (x1_sparseness, x2_sparseness)):\n        return \"dense\"\n    if x1_sparseness != x2_sparseness:\n        raise ValueError(f\"Illegal combination of operands: {x1} {x2}\")\n    return x1_sparseness\n\n\ndef intersection_sparseness(x1, x2):\n    \"\"\"Sparseness of ops that are nonzero only where both inputs are.\"\"\"\n    x1_sparseness = sparseness(x1)\n    x2_sparseness = sparseness(x2)\n    if x1_sparseness == \"scalar\":\n        return x2_sparseness\n    if x2_sparseness in (\"scalar\", \"dense\"):\n        return x1_sparseness\n    if x1_sparseness == \"dense\":\n        return x2_sparseness\n 
   if x1_sparseness != x2_sparseness:\n        raise ValueError(f\"Illegal combination of operands: {x1} {x2}\")\n    return x1_sparseness\n\n\ndef division_sparseness(x1, x2):\n    x1_sparseness = sparseness(x1)\n    x2_sparseness = sparseness(x2)\n    if x2_sparseness in (\"sparse\", \"slices\"):\n        return \"dense\"\n    return \"dense\" if x1_sparseness == \"scalar\" else x1_sparseness\n\n\ndef snake_to_pascal_case(name):\n    return \"\".join(w.capitalize() for w in name.split(\"_\"))\n\n\n@pytest.mark.skipif(\n    not backend.SUPPORTS_SPARSE_TENSORS,\n    reason=\"Backend does not support sparse tensors.\",\n)\nclass SparseTest(testing.TestCase):\n    DTYPES = [\"int32\", \"float32\"]\n    DENSIFYING_UNARY_OPS = [\n        \"arccos\",\n        \"arccosh\",\n        \"cos\",\n        \"cosh\",\n        \"exp\",\n        \"isfinite\",\n        \"log\",\n        \"log10\",\n        \"log2\",\n        \"reciprocal\",\n    ]\n    DENSIFYING_UNARY_OPS_TESTS = [\n        {\n            \"testcase_name\": op,\n            \"op_function\": getattr(knp, op),\n            \"op_class\": getattr(knp, op.capitalize()),\n            \"np_op\": getattr(np, op),\n        }\n        for op in DENSIFYING_UNARY_OPS\n    ]\n    ELEMENTWISE_UNARY_OPS = [\n        \"abs\",\n        \"absolute\",\n        \"arcsin\",\n        \"arcsinh\",\n        \"arctan\",\n        \"arctanh\",\n        \"ceil\",\n        \"conj\",\n        \"conjugate\",\n        \"copy\",\n        \"expm1\",\n        \"floor\",\n        \"imag\",\n        \"log1p\",\n        \"negative\",\n        \"real\",\n        \"round\",\n        \"sign\",\n        \"sin\",\n        \"sinh\",\n        \"sqrt\",\n        \"square\",\n        \"tan\",\n        \"tanh\",\n    ]\n    ELEMENTWISE_UNARY_OPS_TESTS = [\n        {\n            \"testcase_name\": op,\n            \"op_function\": getattr(knp, op),\n            \"op_class\": getattr(knp, snake_to_pascal_case(op)),\n            \"np_op\": getattr(np, op),\n      
  }\n        for op in ELEMENTWISE_UNARY_OPS\n    ]\n    OTHER_UNARY_OPS_ARGS = [\n        (\"digitize\", \"\", {}, {\"bins\": np.array([0.1, 0.2, 1.0])}, (4, 2, 3)),\n        (\"mean\", \"none\", {\"axis\": None}, {}, (4, 2, 3)),\n        (\"mean\", \"none_k\", {\"axis\": None, \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"mean\", \"empty\", {\"axis\": ()}, {}, (4, 2, 3)),\n        (\"mean\", \"empty_k\", {\"axis\": (), \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"mean\", \"0\", {\"axis\": 0}, {}, (4, 2, 3)),\n        (\"mean\", \"0_k\", {\"axis\": 0, \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"mean\", \"1\", {\"axis\": 1}, {}, (4, 2, 3)),\n        (\"mean\", \"1_k\", {\"axis\": 1, \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"mean\", \"01\", {\"axis\": (0, 1)}, {}, (4, 2, 3)),\n        (\"mean\", \"01_k\", {\"axis\": (0, 1), \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"mean\", \"12\", {\"axis\": (1, 2)}, {}, (4, 2, 3)),\n        (\"mean\", \"12_k\", {\"axis\": (1, 2), \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"mean\", \"all\", {\"axis\": (0, 1, 2)}, {}, (4, 2, 3)),\n        (\"mean\", \"all_k\", {\"axis\": (0, 1, 2), \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"sum\", \"none\", {\"axis\": None}, {}, (4, 2, 3)),\n        (\"sum\", \"none_k\", {\"axis\": None, \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"sum\", \"empty\", {\"axis\": ()}, {}, (4, 2, 3)),\n        (\"sum\", \"empty_k\", {\"axis\": (), \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"sum\", \"0\", {\"axis\": 0}, {}, (4, 2, 3)),\n        (\"sum\", \"0_k\", {\"axis\": 0, \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"sum\", \"1\", {\"axis\": 1}, {}, (4, 2, 3)),\n        (\"sum\", \"1_k\", {\"axis\": 1, \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"sum\", \"01\", {\"axis\": (0, 1)}, {}, (4, 2, 3)),\n        (\"sum\", \"01_k\", {\"axis\": (0, 1), \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"sum\", \"12\", {\"axis\": (1, 2)}, {}, (4, 2, 3)),\n        (\"sum\", 
\"12_k\", {\"axis\": (1, 2), \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"sum\", \"all\", {\"axis\": (0, 1, 2)}, {}, (4, 2, 3)),\n        (\"sum\", \"all_k\", {\"axis\": (0, 1, 2), \"keepdims\": True}, {}, (4, 2, 3)),\n        (\"expand_dims\", \"zero\", {\"axis\": 0}, {}, (2, 3)),\n        (\"expand_dims\", \"one\", {\"axis\": 1}, {}, (2, 3)),\n        (\"expand_dims\", \"minus_two\", {\"axis\": -2}, {}, (2, 3)),\n        (\"reshape\", \"basic\", {\"newshape\": (4, 3, 2)}, {}, (4, 2, 3)),\n        (\"reshape\", \"minus_one\", {\"newshape\": (4, 3, -1)}, {}, (4, 2, 3)),\n        (\"reshape\", \"fewer_dims\", {\"newshape\": (4, 6)}, {}, (4, 2, 3)),\n        (\"squeeze\", \"no_axis_no_op\", {}, {}, (2, 3)),\n        (\"squeeze\", \"one\", {\"axis\": 1}, {}, (2, 1, 3)),\n        (\"squeeze\", \"minus_two\", {\"axis\": -2}, {}, (2, 1, 3)),\n        (\"squeeze\", \"no_axis\", {}, {}, (2, 1, 3)),\n        (\"transpose\", \"no_axes\", {}, {}, (1, 2, 3, 4)),\n        (\"transpose\", \"axes\", {\"axes\": (0, 3, 2, 1)}, {}, (1, 2, 3, 4)),\n    ]\n    OTHER_UNARY_OPS_TESTS = [\n        {\n            \"testcase_name\": \"_\".join([op, testcase_name]),\n            \"op_function\": getattr(knp, op),\n            \"op_class\": getattr(knp, snake_to_pascal_case(op)),\n            \"np_op\": getattr(np, op),\n            \"init_kwargs\": init_kwargs,\n            \"op_kwargs\": op_kwargs,\n            \"input_shape\": input_shape,\n        }\n        for op, testcase_name, init_kwargs, op_kwargs, input_shape in (\n            OTHER_UNARY_OPS_ARGS\n        )\n    ]\n\n    BINARY_OPS = [\n        (\"add\", union_sparseness),\n        (\"subtract\", union_sparseness),\n        (\"maximum\", union_sparseness),\n        (\"minimum\", union_sparseness),\n        (\"multiply\", intersection_sparseness),\n        (\"divide\", division_sparseness),\n        (\"true_divide\", division_sparseness),\n    ]\n    BINARY_OPS_TESTS = [\n        {\n            \"testcase_name\": op,\n         
   \"op_function\": getattr(knp, op),\n            \"op_class\": getattr(knp, snake_to_pascal_case(op)),\n            \"np_op\": getattr(np, op),\n            \"op_sparseness\": op_sparseness,\n        }\n        for op, op_sparseness in BINARY_OPS\n    ]\n\n    def assertSameSparseness(self, x, y):\n        self.assertEqual(sparseness(x), sparseness(y))\n\n    def assertSparseness(self, x, expected_sparseness):\n        self.assertEqual(sparseness(x), expected_sparseness)\n\n    @parameterized.named_parameters(ELEMENTWISE_UNARY_OPS_TESTS)\n    def test_elementwise_unary_symbolic_static_shape(\n        self, op_function, op_class, np_op\n    ):\n        x = KerasTensor([2, 3], sparse=True)\n        self.assertEqual(op_function(x).shape, (2, 3))\n        self.assertTrue(op_function(x).sparse)\n        self.assertEqual(op_class()(x).shape, (2, 3))\n        self.assertTrue(op_class()(x).sparse)\n\n    @parameterized.named_parameters(ELEMENTWISE_UNARY_OPS_TESTS)\n    def test_elementwise_unary_symbolic_dynamic_shape(\n        self, op_function, op_class, np_op\n    ):\n        x = KerasTensor([None, 3], sparse=True)\n        self.assertEqual(op_function(x).shape, (None, 3))\n        self.assertTrue(op_function(x).sparse)\n        self.assertEqual(op_class()(x).shape, (None, 3))\n        self.assertTrue(op_class()(x).sparse)\n\n    @parameterized.named_parameters(OTHER_UNARY_OPS_TESTS)\n    def test_other_unary_symbolic_static_shape(\n        self, op_function, op_class, np_op, init_kwargs, op_kwargs, input_shape\n    ):\n        expected_shape = op_function(\n            KerasTensor(input_shape), **init_kwargs, **op_kwargs\n        ).shape\n        x = KerasTensor(input_shape, sparse=True)\n        self.assertEqual(\n            op_function(x, **init_kwargs, **op_kwargs).shape, expected_shape\n        )\n        self.assertTrue(op_function(x, **init_kwargs, **op_kwargs).sparse)\n        self.assertEqual(\n            op_class(**init_kwargs)(x, **op_kwargs).shape, 
expected_shape\n        )\n        self.assertTrue(op_class(**init_kwargs)(x, **op_kwargs).sparse)\n\n    @parameterized.named_parameters(OTHER_UNARY_OPS_TESTS)\n    def test_other_unary_symbolic_dynamic_shape(\n        self, op_function, op_class, np_op, init_kwargs, op_kwargs, input_shape\n    ):\n        input_shape = (None,) + input_shape[1:]\n        expected_shape = op_function(\n            KerasTensor(input_shape), **init_kwargs, **op_kwargs\n        ).shape\n        x = KerasTensor(input_shape, sparse=True)\n        self.assertEqual(\n            op_function(x, **init_kwargs, **op_kwargs).shape, expected_shape\n        )\n        self.assertTrue(op_function(x, **init_kwargs, **op_kwargs).sparse)\n        self.assertEqual(\n            op_class(**init_kwargs)(x, **op_kwargs).shape, expected_shape\n        )\n        self.assertTrue(op_class(**init_kwargs)(x, **op_kwargs).sparse)\n\n    @parameterized.named_parameters(DENSIFYING_UNARY_OPS_TESTS)\n    def test_densifying_unary_sparse_correctness(\n        self, op_function, op_class, np_op\n    ):\n        x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])\n        x = create_sparse_tensor(x)\n        x_np = backend.convert_to_numpy(x)\n\n        self.assertAllClose(op_function(x), np_op(x_np))\n        self.assertAllClose(op_class()(x), np_op(x_np))\n\n    @parameterized.named_parameters(DENSIFYING_UNARY_OPS_TESTS)\n    def test_densifying_unary_indexed_slices_correctness(\n        self, op_function, op_class, np_op\n    ):\n        x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])\n        x = create_indexed_slices(x)\n        x_np = backend.convert_to_numpy(x)\n\n        self.assertAllClose(op_function(x), np_op(x_np))\n        self.assertAllClose(op_class()(x), np_op(x_np))\n\n    @parameterized.named_parameters(ELEMENTWISE_UNARY_OPS_TESTS)\n    def test_elementwise_unary_sparse_correctness(\n        self, op_function, op_class, np_op\n    ):\n        if op_function.__name__ in (\"conj\", \"conjugate\", 
\"imag\", \"real\"):\n            x = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [3 + 3j, 2 + 2j, 1 + 1j]])\n        else:\n            x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])\n        x = create_sparse_tensor(x)\n        x_np = backend.convert_to_numpy(x)\n\n        self.assertAllClose(op_function(x), np_op(x_np))\n        self.assertSameSparseness(op_function(x), x)\n        self.assertAllClose(op_class()(x), np_op(x_np))\n        self.assertSameSparseness(op_class()(x), x)\n\n    @parameterized.named_parameters(ELEMENTWISE_UNARY_OPS_TESTS)\n    def test_elementwise_unary_indexed_slices_correctness(\n        self, op_function, op_class, np_op\n    ):\n        if op_function.__name__ in (\"conj\", \"conjugate\", \"imag\", \"real\"):\n            x = np.array([[1 + 1j, 2 + 2j, 3 + 3j], [3 + 3j, 2 + 2j, 1 + 1j]])\n        else:\n            x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])\n        x = create_indexed_slices(x)\n        x_np = backend.convert_to_numpy(x)\n\n        self.assertAllClose(op_function(x), np_op(x_np))\n        self.assertSameSparseness(op_function(x), x)\n        self.assertAllClose(op_class()(x), np_op(x_np))\n        self.assertSameSparseness(op_class()(x), x)\n\n    @parameterized.named_parameters(OTHER_UNARY_OPS_TESTS)\n    def test_other_unary_sparse_correctness(\n        self, op_function, op_class, np_op, init_kwargs, op_kwargs, input_shape\n    ):\n        x = np.random.random(input_shape)\n        if op_function is knp.mean:\n            x = create_indexed_slices(x)\n        else:\n            x = create_sparse_tensor(x)\n        x_np = backend.convert_to_numpy(x)\n\n        # `newshape` was renamed `shape` in Numpy.\n        np_init_kwargs = init_kwargs.copy()\n        if \"newshape\" in init_kwargs:\n            np_init_kwargs[\"shape\"] = np_init_kwargs.pop(\"newshape\")\n\n        self.assertAllClose(\n            op_function(x, **init_kwargs, **op_kwargs),\n            np_op(x_np, **np_init_kwargs, **op_kwargs),\n        )\n    
    self.assertAllClose(
            op_class(**init_kwargs)(x, **op_kwargs),
            np_op(x_np, **np_init_kwargs, **op_kwargs),
        )
        # Reduction operations have complex, backend-dependent rules about
        # when the result is sparse and when it is dense.
        if op_function is not knp.mean:
            self.assertSameSparseness(
                op_function(x, **init_kwargs, **op_kwargs), x
            )
            self.assertSameSparseness(
                op_class(**init_kwargs)(x, **op_kwargs), x
            )

    @parameterized.named_parameters(
        named_product(
            BINARY_OPS_TESTS, x_sparse=[True, False], y_sparse=[True, False]
        )
    )
    def test_binary_symbolic_static_shape(
        self, x_sparse, y_sparse, op_function, op_class, np_op, op_sparseness
    ):
        x = KerasTensor([2, 3], sparse=x_sparse)
        y = KerasTensor([2, 3], sparse=y_sparse)
        self.assertEqual(op_function(x, y).shape, (2, 3))
        self.assertSparseness(op_function(x, y), op_sparseness(x, y))
        self.assertEqual(op_class()(x, y).shape, (2, 3))
        self.assertSparseness(op_class()(x, y), op_sparseness(x, y))

    @parameterized.named_parameters(
        named_product(
            BINARY_OPS_TESTS, x_sparse=[True, False], y_sparse=[True, False]
        )
    )
    def test_binary_symbolic_dynamic_shape(
        self, x_sparse, y_sparse, op_function, op_class, np_op, op_sparseness
    ):
        x = KerasTensor([None, 3], sparse=x_sparse)
        y = KerasTensor([2, None], sparse=y_sparse)
        self.assertEqual(op_function(x, y).shape, (2, 3))
        self.assertSparseness(op_function(x, y), op_sparseness(x, y))
        self.assertEqual(op_class()(x, y).shape, (2, 3))
        self.assertSparseness(op_class()(x, y), op_sparseness(x, y))

    @parameterized.named_parameters(
        named_product(
            BINARY_OPS_TESTS,
            
get_sparseness_combinations(create_sparse_tensor),\n            dtype=DTYPES,\n        )\n    )\n    def test_binary_correctness_sparse_tensor(\n        self, x, y, op_function, op_class, np_op, op_sparseness, dtype\n    ):\n        x = backend.cast(x, dtype)\n        y = backend.cast(y, dtype)\n        expected_result = np_op(\n            backend.convert_to_numpy(x), backend.convert_to_numpy(y)\n        )\n\n        self.assertAllClose(op_function(x, y), expected_result)\n        self.assertSparseness(op_function(x, y), op_sparseness(x, y))\n        self.assertAllClose(op_class()(x, y), expected_result)\n        self.assertSparseness(op_class()(x, y), op_sparseness(x, y))\n\n    @parameterized.named_parameters(\n        named_product(\n            BINARY_OPS_TESTS,\n            get_sparseness_combinations(create_indexed_slices),\n            dtype=DTYPES,\n        )\n    )\n    def test_binary_correctness_indexed_slices(\n        self, x, y, op_function, op_class, np_op, op_sparseness, dtype\n    ):\n        x = backend.cast(x, dtype)\n        y = backend.cast(y, dtype)\n        expected_result = np_op(\n            backend.convert_to_numpy(x), backend.convert_to_numpy(y)\n        )\n\n        self.assertAllClose(op_function(x, y), expected_result)\n        self.assertSparseness(op_function(x, y), op_sparseness(x, y))\n        self.assertAllClose(op_class()(x, y), expected_result)\n        self.assertSparseness(op_class()(x, y), op_sparseness(x, y))\n\n    @parameterized.named_parameters(\n        named_product(\n            sparse_type=[\"sparse_tensor\", \"indexed_slices\"],\n            dtype=[\"int32\", \"float32\"],\n        )\n    )\n    def test_divide_with_zeros_nans(self, sparse_type, dtype):\n        x = backend.convert_to_tensor([[0, 2, 3], [3, 2, 1]], dtype=dtype)\n        if sparse_type == \"indexed_slices\":\n            x = create_indexed_slices(x, start=0, delta=2)\n        else:\n            x = create_sparse_tensor(x, start=0, delta=2)\n        
if dtype.startswith(\"int\"):\n            y = [[0, 0, 3], [0, 0, 1]]\n        else:\n            y = [[np.nan, np.nan, 3], [0, 0, 1]]\n        y = backend.convert_to_tensor(y, dtype=dtype)\n        expected_result = np.divide(\n            backend.convert_to_numpy(x), backend.convert_to_numpy(y)\n        )\n\n        self.assertAllClose(knp.divide(x, y), expected_result)\n        self.assertAllClose(knp.Divide()(x, y), expected_result)\n\n\nclass NumpyDtypeTest(testing.TestCase):\n    \"\"\"Test the dtype to verify that the behavior matches JAX.\"\"\"\n\n    ALL_DTYPES = [\n        x\n        for x in dtypes.ALLOWED_DTYPES\n        if x\n        not in (\n            \"string\",\n            \"complex64\",\n            \"complex128\",\n            # Remove 64-bit dtypes.\n            \"float64\",\n            \"uint64\",\n            \"int64\",\n        )\n        + dtypes.FLOAT8_TYPES  # Remove float8 dtypes for the following tests\n    ] + [None]\n    INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in (\"uint64\", \"int64\")]\n    FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in (\"float64\",)]\n\n    if backend.backend() == \"torch\":\n        ALL_DTYPES = [x for x in ALL_DTYPES if x not in (\"uint16\", \"uint32\")]\n        INT_DTYPES = [x for x in INT_DTYPES if x not in (\"uint16\", \"uint32\")]\n    elif backend.backend() == \"tensorflow\":\n        # TODO(hongyu): Re-enable uint32 tests once we determine how to handle\n        # dtypes.result_type(uint32, int*) -> int64 promotion.\n        # Since TF variables require int64 to be placed on the GPU, we\n        # exclusively enable the int64 dtype for TF. 
However, JAX does not\n        # natively support int64, which prevents us from comparing the dtypes.\n        ALL_DTYPES = [x for x in ALL_DTYPES if x not in (\"uint32\",)]\n        INT_DTYPES = [x for x in INT_DTYPES if x not in (\"uint32\",)]\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_add(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.add(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.add(x1, x2).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Add().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_array_split(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1, 2), dtype=dtype)\n        x_jax = jnp.ones((1, 2), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.split(x_jax, 2, -1)[0].dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.split(x, 2, -1)[0].dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_add_python_types(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n\n        # python int\n        expected_dtype = standardize_dtype(jnp.add(x_jax, 1).dtype)\n\n        self.assertDType(knp.add(x, 1), expected_dtype)\n        self.assertDType(knp.Add().symbolic_call(x, 1), expected_dtype)\n\n        # python float\n        expected_dtype = standardize_dtype(jnp.add(x_jax, 1.0).dtype)\n\n        
self.assertDType(knp.add(x, 1.0), expected_dtype)\n        self.assertDType(knp.Add().symbolic_call(x, 1.0), expected_dtype)\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_bartlett(self, dtype):\n        x = knp.ones((), dtype=dtype)\n        expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.bartlett(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Bartlett().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_blackman(self, dtype):\n        x = knp.ones((), dtype=dtype)\n        expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.blackman(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Blackman().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_hamming(self, dtype):\n        x = knp.ones((), dtype=dtype)\n        expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.hamming(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Hamming().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_hanning(self, dtype):\n        x = knp.ones((), dtype=dtype)\n        expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.hanning(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Hanning().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_kaiser(self, dtype):\n        x = 
knp.ones((), dtype=dtype)
        beta = knp.ones((), dtype=dtype)
        expected_dtype = backend.floatx()

        self.assertEqual(
            standardize_dtype(knp.kaiser(x, beta).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.Kaiser(beta).symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=INT_DTYPES))
    def test_bincount(self, dtype):
        import jax.numpy as jnp

        if backend.backend() == "tensorflow":
            import tensorflow as tf

            # `tf.test.is_gpu_available()` is deprecated; check the visible
            # physical devices instead.
            if tf.config.list_physical_devices("GPU"):
                self.skipTest("bincount does not work on TensorFlow GPU")

        x = np.array([1, 1, 2, 3, 2, 4, 4, 5], dtype=dtype)
        weights = np.array([0, 0, 3, 2, 1, 1, 4, 2], dtype=dtype)
        minlength = 3
        self.assertEqual(
            standardize_dtype(
                knp.bincount(x, weights=weights, minlength=minlength).dtype
            ),
            standardize_dtype(
                jnp.bincount(x, weights=weights, minlength=minlength).dtype
            ),
        )
        self.assertEqual(
            knp.Bincount(weights=weights, minlength=minlength)
            .symbolic_call(x)
            .dtype,
            standardize_dtype(
                jnp.bincount(x, weights=weights, minlength=minlength).dtype
            ),
        )

        # test float32 weights
        weights = np.array([0, 0, 3, 2, 1, 1, 4, 2], dtype="float32")
        self.assertEqual(
            standardize_dtype(knp.bincount(x, weights=weights).dtype),
            standardize_dtype(jnp.bincount(x, weights=weights).dtype),
        )
        self.assertEqual(
            knp.Bincount(weights=weights).symbolic_call(x).dtype,
            standardize_dtype(jnp.bincount(x, weights=weights).dtype),
        )

        # test float16 weights
        weights = np.array([0, 0, 3, 2, 1, 1, 4, 2], 
dtype=\"float16\")\n        self.assertEqual(\n            standardize_dtype(knp.bincount(x, weights=weights).dtype),\n            standardize_dtype(jnp.bincount(x, weights=weights).dtype),\n        )\n        self.assertEqual(\n            knp.Bincount(weights=weights).symbolic_call(x).dtype,\n            standardize_dtype(jnp.bincount(x, weights=weights).dtype),\n        )\n\n        # test weights=None\n        self.assertEqual(\n            standardize_dtype(knp.bincount(x).dtype),\n            standardize_dtype(jnp.bincount(x).dtype),\n        )\n        self.assertEqual(\n            knp.Bincount().symbolic_call(x).dtype,\n            standardize_dtype(jnp.bincount(x).dtype),\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_subtract(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        if dtype1 == \"bool\" and dtype2 == \"bool\":\n            self.skipTest(\"subtract does not support bool\")\n\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.subtract(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.subtract(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            knp.Subtract().symbolic_call(x1, x2).dtype, expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_subtract_python_types(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n\n        # python int\n        expected_dtype = standardize_dtype(jnp.subtract(x_jax, 1).dtype)\n\n        self.assertDType(knp.subtract(x, 1), expected_dtype)\n        self.assertDType(knp.Subtract().symbolic_call(x, 1), 
expected_dtype)\n\n        # python float\n        expected_dtype = standardize_dtype(jnp.subtract(x_jax, 1.0).dtype)\n\n        self.assertDType(knp.subtract(x, 1.0), expected_dtype)\n        self.assertDType(knp.Subtract().symbolic_call(x, 1.0), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(\n            dtypes=list(itertools.combinations(ALL_DTYPES, 2))\n            + [(\"int8\", \"int8\")]\n        )\n    )\n    def test_matmul(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        # The shape of the matrix needs to meet the requirements of\n        # torch._int_mm to test hardware-accelerated matmul\n        x1 = knp.ones((17, 16), dtype=dtype1)\n        x2 = knp.ones((16, 8), dtype=dtype2)\n        x1_jax = jnp.ones((17, 16), dtype=dtype1)\n        x2_jax = jnp.ones((16, 8), dtype=dtype2)\n        if dtype1 == \"int8\" and dtype2 == \"int8\":\n            preferred_element_type = \"int32\"\n        else:\n            preferred_element_type = None\n        expected_dtype = standardize_dtype(\n            jnp.matmul(\n                x1_jax, x2_jax, preferred_element_type=preferred_element_type\n            ).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.matmul(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            knp.Matmul().symbolic_call(x1, x2).dtype, expected_dtype\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_multiply(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.multiply(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.multiply(x1, 
x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            knp.Multiply().symbolic_call(x1, x2).dtype, expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_multiply_python_types(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n\n        # python int\n        expected_dtype = standardize_dtype(jnp.multiply(x_jax, 1).dtype)\n\n        self.assertDType(knp.multiply(x, 1), expected_dtype)\n        self.assertDType(knp.Multiply().symbolic_call(x, 1), expected_dtype)\n\n        # python float\n        expected_dtype = standardize_dtype(jnp.multiply(x_jax, 1.0).dtype)\n\n        self.assertDType(knp.multiply(x, 1.0), expected_dtype)\n        self.assertDType(knp.Multiply().symbolic_call(x, 1.0), expected_dtype)\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_mean(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.mean(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = \"float32\"\n\n        self.assertEqual(standardize_dtype(knp.mean(x).dtype), expected_dtype)\n        self.assertEqual(knp.Mean().symbolic_call(x).dtype, expected_dtype)\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_max(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.max(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.max(x).dtype), expected_dtype)\n        self.assertEqual(knp.Max().symbolic_call(x).dtype, expected_dtype)\n\n        # Test with initial\n        initial = 1\n        expected_dtype = standardize_dtype(\n            jnp.max(x_jax, 
initial=initial).dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.max(x, initial=initial).dtype), expected_dtype\n        )\n        self.assertEqual(\n            knp.Max(initial=initial).symbolic_call(x).dtype, expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_ones(self, dtype):\n        import jax.numpy as jnp\n\n        expected_dtype = standardize_dtype(jnp.ones([2, 3], dtype=dtype).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.ones([2, 3], dtype=dtype).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_zeros(self, dtype):\n        import jax.numpy as jnp\n\n        expected_dtype = standardize_dtype(jnp.zeros([2, 3], dtype=dtype).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.zeros([2, 3], dtype=dtype).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_absolute(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.absolute(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.absolute(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Absolute().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_all(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.all(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.all(x).dtype), expected_dtype)\n        self.assertEqual(\n            
standardize_dtype(knp.All().symbolic_call(x).dtype), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_amax(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.amax(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.amax(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Amax().symbolic_call(x).dtype), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_amin(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.amin(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.amin(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Amin().symbolic_call(x).dtype), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_any(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.any(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.any(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Any().symbolic_call(x).dtype), expected_dtype\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_append(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = 
standardize_dtype(jnp.append(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.append(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            knp.Append().symbolic_call(x1, x2).dtype, expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_argmax(self, dtype):\n        import jax.numpy as jnp\n\n        if dtype == \"bool\":\n            value = [[True, False, True], [False, True, False]]\n        else:\n            value = [[1, 2, 3], [3, 2, 1]]\n        x = knp.array(value, dtype=dtype)\n        x_jax = jnp.array(value, dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.argmax(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.argmax(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Argmax().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_argmin(self, dtype):\n        import jax.numpy as jnp\n\n        if dtype == \"bool\":\n            value = [[True, False, True], [False, True, False]]\n        else:\n            value = [[1, 2, 3], [3, 2, 1]]\n        x = knp.array(value, dtype=dtype)\n        x_jax = jnp.array(value, dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.argmin(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.argmin(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Argmin().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_argpartition(self, dtype):\n        import jax.numpy as jnp\n\n        if dtype == \"bool\":\n            self.skipTest(\"argpartition doesn't support bool dtype\")\n\n        x = knp.array([1, 2, 3], dtype=dtype)\n        x_jax = jnp.array([1, 2, 3], dtype=dtype)\n        expected_dtype 
= standardize_dtype(jnp.argpartition(x_jax, 1).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.argpartition(x, 1).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Argpartition(1).symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_argsort(self, dtype):\n        import jax.numpy as jnp\n\n        if dtype == \"bool\":\n            value = [[True, False, True], [False, True, False]]\n        else:\n            value = [[1, 2, 3], [4, 5, 6]]\n        x = knp.array(value, dtype=dtype)\n        x_jax = jnp.array(value, dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.argsort(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.argsort(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Argsort().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.parameters(\n        (10, None, None, None),  # stop\n        (2, 10, None, None),  # start, stop\n        (10, None, 2, None),  # stop, step\n        (0, 10, 2, None),  # start, stop, step\n        (0, 10, 0.5, None),\n        (10.0, None, 1, None),\n        (0, 10.0, 1, None),\n        (0.0, 10, 1, None),\n        (10, None, 1, \"float32\"),\n        (10, None, 1, \"int32\"),\n        (10, None, 1, \"int16\"),\n        (10, None, 1, \"float16\"),\n    )\n    def test_arange(self, start, stop, step, dtype):\n        import jax.numpy as jnp\n\n        expected_dtype = standardize_dtype(\n            jnp.arange(start, stop, step, dtype).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.arange(start, stop, step, dtype).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(\n                knp.Arange(dtype).symbolic_call(start, stop, step).dtype\n            ),\n    
        expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_arccos(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.arccos(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.arccos(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Arccos().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_arccosh(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.arccosh(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.arccosh(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Arccosh().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_arcsin(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.arcsin(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.arcsin(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Arcsin().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_arcsinh(self, dtype):\n        import jax.numpy as jnp\n\n     
   x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.arcsinh(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.arcsinh(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Arcsinh().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_arctan(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.arctan(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.arctan(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Arctan().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_arctan2(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.arctan2(x1_jax, x2_jax).dtype)\n        if dtype1 is not None and \"float\" not in dtype1:\n            if dtype2 is not None and \"float\" not in dtype2:\n                if \"int64\" in (dtype1, dtype2) or \"uint32\" in (dtype1, dtype2):\n                    expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.arctan2(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            
knp.Arctan2().symbolic_call(x1, x2).dtype, expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_arctanh(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.arctanh(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.arctanh(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Arctanh().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.parameters(\n        (bool(0), \"bool\"),\n        (int(0), \"int32\"),\n        (float(0), backend.floatx()),\n        ([False, True, False], \"bool\"),\n        ([1, 2, 3], \"int32\"),\n        ([1.0, 2.0, 3.0], backend.floatx()),\n        ([1, 2.0, 3], backend.floatx()),\n        ([[False], [True], [False]], \"bool\"),\n        ([[1], [2], [3]], \"int32\"),\n        ([[1], [2.0], [3]], backend.floatx()),\n        *[\n            (np.array(0, dtype=dtype), dtype)\n            for dtype in ALL_DTYPES\n            if dtype is not None\n        ],\n    )\n    def test_array(self, x, expected_dtype):\n        self.assertDType(knp.array(x), expected_dtype)\n        # TODO: support the assertion of knp.Array\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_average(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.average(x1_jax, weights=x2_jax).dtype\n        )\n        if dtype1 is not None and \"float\" not in 
dtype1:\n            if dtype2 is not None and \"float\" not in dtype2:\n                if \"int64\" in (dtype1, dtype2) or \"uint32\" in (dtype1, dtype2):\n                    expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.average(x1, weights=x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(\n                knp.Average().symbolic_call(x1, weights=x2).dtype\n            ),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(INT_DTYPES, 2))\n    )\n    def test_bitwise_and(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.bitwise_and(x1_jax, x2_jax).dtype\n        )\n\n        self.assertDType(knp.bitwise_and(x1, x2), expected_dtype)\n        self.assertDType(knp.BitwiseAnd().symbolic_call(x1, x2), expected_dtype)\n\n    @parameterized.named_parameters(named_product(dtype=INT_DTYPES))\n    def test_bitwise_invert(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.invert(x_jax).dtype)\n\n        self.assertDType(knp.bitwise_invert(x), expected_dtype)\n        self.assertDType(knp.BitwiseInvert().symbolic_call(x), expected_dtype)\n\n    # bitwise_not is same as bitwise_invert\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(INT_DTYPES, 2))\n    )\n    def test_bitwise_or(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = 
jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.bitwise_or(x1_jax, x2_jax).dtype)\n\n        self.assertDType(knp.bitwise_or(x1, x2), expected_dtype)\n        self.assertDType(knp.BitwiseOr().symbolic_call(x1, x2), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(INT_DTYPES, 2))\n    )\n    def test_bitwise_xor(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.bitwise_xor(x1_jax, x2_jax).dtype\n        )\n\n        self.assertDType(knp.bitwise_xor(x1, x2), expected_dtype)\n        self.assertDType(knp.BitwiseXor().symbolic_call(x1, x2), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.product(INT_DTYPES, INT_DTYPES + [None]))\n    )\n    def test_bitwise_left_shift(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2) if dtype2 else 1\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2) if dtype2 else 1\n        expected_dtype = standardize_dtype(jnp.left_shift(x1_jax, x2_jax).dtype)\n\n        self.assertDType(knp.bitwise_left_shift(x1, x2), expected_dtype)\n        self.assertDType(\n            knp.BitwiseLeftShift().symbolic_call(x1, x2), expected_dtype\n        )\n\n    # left_shift is same as bitwise_left_shift\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.product(INT_DTYPES, INT_DTYPES + [None]))\n    )\n    def test_bitwise_right_shift(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        
x2 = knp.ones((1,), dtype=dtype2) if dtype2 else 1\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2) if dtype2 else 1\n        expected_dtype = standardize_dtype(\n            jnp.right_shift(x1_jax, x2_jax).dtype\n        )\n\n        self.assertDType(knp.bitwise_right_shift(x1, x2), expected_dtype)\n        self.assertDType(\n            knp.BitwiseRightShift().symbolic_call(x1, x2), expected_dtype\n        )\n\n    # right_shift is same as bitwise_right_shift\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_broadcast_to(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((3,), dtype=dtype)\n        x_jax = jnp.ones((3,), dtype=dtype)\n        expected_dtype = standardize_dtype(\n            jnp.broadcast_to(x_jax, (3, 3)).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.broadcast_to(x, (3, 3)).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.BroadcastTo((3, 3)).symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_cbrt(self, dtype):\n        import jax.numpy as jnp\n\n        x1 = knp.ones((1,), dtype=dtype)\n        x1_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.cbrt(x1_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.cbrt(x1).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Cbrt().symbolic_call(x1).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_ceil(self, dtype):\n        import jax.numpy as jnp\n\n        if dtype is None:\n            dtype = backend.floatx()\n        if dtype == \"bool\":\n            value = [[True, False, True], [True, False, True]]\n        elif \"int\" in dtype:\n            value = [[1, 2, 2], 
[2, 11, 5]]\n        else:\n            value = [[1.2, 2.1, 2.5], [2.4, 11.9, 5.5]]\n        x = knp.array(value, dtype=dtype)\n        x_jax = jnp.array(value, dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.ceil(x_jax).dtype)\n        # Here, we follow NumPy's rule, not JAX's; ints are promoted to floats.\n        if dtype == \"bool\" or is_int_dtype(dtype):\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.ceil(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Ceil().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_clip(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.clip(x_jax, 1, 2).dtype)\n        if dtype == \"bool\":\n            expected_dtype = \"int32\"\n\n        self.assertEqual(\n            standardize_dtype(knp.clip(x, 1, 2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Clip(1, 2).symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_concatenate(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.concatenate([x1_jax, x2_jax]).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.concatenate([x1, x2]).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(\n                knp.Concatenate().symbolic_call([x1, x2]).dtype\n            ),\n            expected_dtype,\n  
      )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_cos(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.cos(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.cos(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Cos().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_cosh(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.cosh(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.cosh(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Cosh().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_copy(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.copy(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.copy(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Copy().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_corrcoef(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((2, 4), dtype=dtype)\n        x_jax = jnp.ones((2, 4), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.corrcoef(x_jax).dtype)\n\n   
     self.assertEqual(\n            standardize_dtype(knp.corrcoef(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Corrcoef().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_correlate(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((3,), dtype=dtype1)\n        x2 = knp.ones((3,), dtype=dtype2)\n        x1_jax = jnp.ones((3,), dtype=dtype1)\n        x2_jax = jnp.ones((3,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.correlate(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.correlate(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Correlate().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_count_nonzero(self, dtype):\n        x = knp.ones((1,), dtype=dtype)\n        expected_dtype = \"int32\"\n\n        self.assertEqual(\n            standardize_dtype(knp.count_nonzero(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.CountNonzero().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_cross(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1, 1, 3), dtype=dtype1)\n        x2 = knp.ones((1, 1, 3), dtype=dtype2)\n        x1_jax = jnp.ones((1, 1, 3), dtype=dtype1)\n        x2_jax = jnp.ones((1, 1, 3), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.cross(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            
standardize_dtype(knp.cross(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Cross().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_cumprod(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.cumprod(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.cumprod(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Cumprod().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_cumsum(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.cumsum(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.cumsum(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Cumsum().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_deg2rad(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.deg2rad(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.deg2rad(x).dtype), expected_dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.Deg2rad().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_diag(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), 
dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.diag(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.diag(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Diag().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_diagflat(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.diagflat(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.diagflat(x).dtype), expected_dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.Diagflat().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n        x_2d = knp.ones((1, 1), dtype=dtype)\n        x_jax_2d = jnp.ones((1, 1), dtype=dtype)\n        expected_dtype_2d = standardize_dtype(jnp.diagflat(x_jax_2d).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.diagflat(x_2d).dtype), expected_dtype_2d\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Diagflat().symbolic_call(x_2d).dtype),\n            expected_dtype_2d,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_diagonal(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1, 1, 1), dtype=dtype)\n        x_jax = jnp.ones((1, 1, 1), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.diagonal(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.diagonal(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Diagonal().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_diff(self, dtype):\n        import jax.numpy as jnp\n\n        x = 
knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.diff(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.diff(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Diff().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_digitize(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        bins = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        x_bins = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.digitize(x_jax, x_bins).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.digitize(x, bins).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Digitize().symbolic_call(x, bins).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_divide(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.divide(x1_jax, x2_jax).dtype)\n\n        self.assertDType(knp.divide(x1, x2), expected_dtype)\n        self.assertDType(knp.Divide().symbolic_call(x1, x2), expected_dtype)\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_divide_python_types(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n\n        # python int\n        expected_dtype = standardize_dtype(jnp.divide(x_jax, 1).dtype)\n\n        
self.assertDType(knp.divide(x, 1), expected_dtype)\n        self.assertDType(knp.Divide().symbolic_call(x, 1), expected_dtype)\n\n        # python float\n        expected_dtype = standardize_dtype(jnp.divide(x_jax, 1.0).dtype)\n\n        self.assertDType(knp.divide(x, 1.0), expected_dtype)\n        self.assertDType(knp.Divide().symbolic_call(x, 1.0), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_dot(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((2, 3, 4), dtype=dtype1)\n        x2 = knp.ones((4, 3), dtype=dtype2)\n        x1_jax = jnp.ones((2, 3, 4), dtype=dtype1)\n        x2_jax = jnp.ones((4, 3), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.dot(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.dot(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Dot().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_dstack(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1, 1), dtype=dtype1)\n        x2 = knp.ones((1, 1), dtype=dtype2)\n        x1_jax = jnp.ones((1, 1), dtype=dtype1)\n        x2_jax = jnp.ones((1, 1), dtype=dtype2)\n\n        expected_dtype = standardize_dtype(jnp.dstack([x1_jax, x2_jax]).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.dstack([x1, x2]).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Dstack().symbolic_call([x1, x2]).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            dtypes=list(itertools.combinations(ALL_DTYPES, 2))\n            + [(\"int8\", \"int8\")]\n        )\n    )\n    
def test_einsum(self, dtypes):\n        import jax.numpy as jnp\n\n        def get_input_shapes(subscripts):\n            x1_labels = subscripts.split(\",\")[0]\n            x2_labels = subscripts.split(\"->\")[0][len(x1_labels) + 1 :]\n            x1_shape = [1] * len(x1_labels)\n            x2_shape = [1] * len(x2_labels)\n            return x1_shape, x2_shape\n\n        dtype1, dtype2 = dtypes\n        subscripts = \"ijk,lkj->il\"\n        x1_shape, x2_shape = get_input_shapes(subscripts)\n        x1 = knp.ones(x1_shape, dtype=dtype1)\n        x2 = knp.ones(x2_shape, dtype=dtype2)\n        x1_jax = jnp.ones(x1_shape, dtype=dtype1)\n        x2_jax = jnp.ones(x2_shape, dtype=dtype2)\n        if dtype1 == \"int8\" and dtype2 == \"int8\":\n            preferred_element_type = \"int32\"\n        else:\n            preferred_element_type = None\n        expected_dtype = standardize_dtype(\n            jnp.einsum(\n                subscripts,\n                x1_jax,\n                x2_jax,\n                preferred_element_type=preferred_element_type,\n            ).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.einsum(subscripts, x1, x2).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(\n                knp.Einsum(subscripts).symbolic_call(x1, x2).dtype\n            ),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            dtypes=list(itertools.combinations(ALL_DTYPES, 2))\n            + [(\"int8\", \"int8\")]\n        )\n    )\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=f\"{backend.backend()} doesn't implement custom ops for einsum.\",\n    )\n    def test_einsum_custom_ops_for_tensorflow(self, dtypes):\n        import jax.numpy as jnp\n\n        def get_input_shapes(subscripts):\n            x1_labels = subscripts.split(\",\")[0]\n            x2_labels = 
subscripts.split(\"->\")[0][len(x1_labels) + 1 :]\n            x1_shape = [1] * len(x1_labels)\n            x2_shape = [1] * len(x2_labels)\n            return x1_shape, x2_shape\n\n        dtype1, dtype2 = dtypes\n        for subscripts in [\n            \"a,b->ab\",\n            \"ab,b->a\",\n            \"ab,bc->ac\",\n            \"ab,cb->ac\",\n            \"abc,cd->abd\",\n            \"abc,cde->abde\",\n            \"abc,dc->abd\",\n            \"abc,dce->abde\",\n            \"abc,dec->abde\",\n            \"abcd,abde->abce\",\n            \"abcd,abed->abce\",\n            \"abcd,acbe->adbe\",\n            \"abcd,adbe->acbe\",\n            \"abcd,aecd->acbe\",\n            \"abcd,aecd->aceb\",\n            \"abcd,cde->abe\",\n            \"abcd,ced->abe\",\n            \"abcd,ecd->abe\",\n            \"abcde,aebf->adbcf\",\n            \"abcde,afce->acdbf\",\n        ]:\n            x1_shape, x2_shape = get_input_shapes(subscripts)\n            x1 = knp.ones(x1_shape, dtype=dtype1)\n            x2 = knp.ones(x2_shape, dtype=dtype2)\n            x1_jax = jnp.ones(x1_shape, dtype=dtype1)\n            x2_jax = jnp.ones(x2_shape, dtype=dtype2)\n            if dtype1 == \"int8\" and dtype2 == \"int8\":\n                preferred_element_type = \"int32\"\n            else:\n                preferred_element_type = None\n            expected_dtype = standardize_dtype(\n                jnp.einsum(\n                    subscripts,\n                    x1_jax,\n                    x2_jax,\n                    preferred_element_type=preferred_element_type,\n                ).dtype\n            )\n\n            self.assertEqual(\n                standardize_dtype(knp.einsum(subscripts, x1, x2).dtype),\n                expected_dtype,\n            )\n            self.assertEqual(\n                standardize_dtype(\n                    knp.Einsum(subscripts).symbolic_call(x1, x2).dtype\n                ),\n                expected_dtype,\n            )\n\n    
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_empty(self, dtype):\n        import jax.numpy as jnp\n\n        expected_dtype = standardize_dtype(jnp.empty([2, 3], dtype=dtype).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.empty([2, 3], dtype=dtype).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_empty_like(self, dtype):\n        import jax.numpy as jnp\n\n        x_jax = jnp.empty([2, 3, 4], dtype=dtype)\n        x = knp.ones([2, 3, 4], dtype=dtype)\n        expected_dtype = standardize_dtype(\n            jnp.empty_like(x_jax, dtype=dtype).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.empty_like(x, dtype=dtype).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knp.EmptyLike().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_equal(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.equal(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.equal(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Equal().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_exp(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.exp(x_jax).dtype)\n  
      if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.exp(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Exp().symbolic_call(x).dtype), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_exp2(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.exp2(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.exp2(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Exp2().symbolic_call(x).dtype), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_expand_dims(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.expand_dims(x_jax, -1).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.expand_dims(x, -1).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.ExpandDims(-1).symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_expm1(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.expm1(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.expm1(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Expm1().symbolic_call(x).dtype),\n            
expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_eye(self, dtype):\n        import jax.numpy as jnp\n\n        expected_dtype = standardize_dtype(jnp.eye(3, dtype=dtype).dtype)\n        if dtype is None:\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.eye(3, dtype=dtype).dtype),\n            expected_dtype,\n        )\n\n        expected_dtype = standardize_dtype(jnp.eye(3, 4, 1, dtype=dtype).dtype)\n        if dtype is None:\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.eye(3, 4, k=1, dtype=dtype).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_flip(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.flip(x_jax, -1).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.flip(x, -1).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Flip(-1).symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_floor(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.floor(x_jax).dtype)\n        # Here, we follow NumPy's rule, not JAX's; ints are promoted to floats.\n        if dtype == \"bool\" or is_int_dtype(dtype):\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.floor(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Floor().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    
@parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_floor_divide(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.floor_divide(x1_jax, x2_jax).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.floor_divide(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.FloorDivide().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_floor_divide_python_types(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n\n        # python int\n        expected_dtype = standardize_dtype(jnp.floor_divide(x_jax, 1).dtype)\n\n        self.assertDType(knp.floor_divide(x, 1), expected_dtype)\n        self.assertDType(knp.FloorDivide().symbolic_call(x, 1), expected_dtype)\n\n        # python float\n        expected_dtype = standardize_dtype(jnp.floor_divide(x_jax, 1.0).dtype)\n\n        self.assertDType(knp.floor_divide(x, 1.0), expected_dtype)\n        self.assertDType(\n            knp.FloorDivide().symbolic_call(x, 1.0), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_full(self, dtype):\n        import jax.numpy as jnp\n\n        expected_dtype = standardize_dtype(jnp.full((), 0, dtype=dtype).dtype)\n        if dtype is None:\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.full((), 0, dtype=dtype).dtype),\n            expected_dtype,\n        )\n        
self.assertEqual(\n            standardize_dtype(knp.Full((), dtype=dtype).symbolic_call(0).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_full_like(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.full_like(x_jax, 0).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.full_like(x, 0).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.FullLike().symbolic_call(x, 0).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(INT_DTYPES, 2))\n    )\n    def test_gcd(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.gcd(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.gcd(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Gcd().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_greater(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.greater(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.greater(x1, x2).dtype), expected_dtype\n        
)\n        self.assertEqual(\n            standardize_dtype(knp.Greater().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_greater_equal(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.greater_equal(x1_jax, x2_jax).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.greater_equal(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.GreaterEqual().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_heaviside(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1, 1), dtype=dtype1)\n        x2 = knp.ones((1, 1), dtype=dtype2)\n        x1_jax = jnp.ones((1, 1), dtype=dtype1)\n        x2_jax = jnp.ones((1, 1), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.heaviside(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.heaviside(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Heaviside().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_hstack(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1, 1), dtype=dtype1)\n        x2 = knp.ones((1, 1), dtype=dtype2)\n        x1_jax = 
jnp.ones((1, 1), dtype=dtype1)\n        x2_jax = jnp.ones((1, 1), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.hstack([x1_jax, x2_jax]).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.hstack([x1, x2]).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Hstack().symbolic_call([x1, x2]).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_hypot(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1, 1), dtype=dtype1)\n        x2 = knp.ones((1, 1), dtype=dtype2)\n        x1_jax = jnp.ones((1, 1), dtype=dtype1)\n        x2_jax = jnp.ones((1, 1), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.hypot(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.hypot(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Hypot().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_identity(self, dtype):\n        import jax.numpy as jnp\n\n        expected_dtype = standardize_dtype(jnp.identity(3, dtype=dtype).dtype)\n        if dtype is None:\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.identity(3, dtype=dtype).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_isclose(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        
expected_dtype = standardize_dtype(jnp.isclose(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.isclose(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Isclose().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_isfinite(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.isfinite(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.isfinite(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Isfinite().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_isin(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.isin(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.isin(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.IsIn().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_isinf(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.isinf(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.isinf(x).dtype), expected_dtype)\n        self.assertEqual(\n           
 standardize_dtype(knp.Isinf().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_isnan(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.isnan(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.isnan(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Isnan().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_isneginf(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.isneginf(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.isneginf(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Isneginf().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_isposinf(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.isposinf(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.isposinf(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Isposinf().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_isreal(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.isreal(x_jax).dtype)\n\n        
self.assertEqual(standardize_dtype(knp.isreal(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Isreal().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(INT_DTYPES, 2))\n    )\n    def test_kron(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.kron(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.kron(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Kron().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(INT_DTYPES, 2))\n    )\n    def test_lcm(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.lcm(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.lcm(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Lcm().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=list(itertools.product(ALL_DTYPES, INT_DTYPES)))\n    )\n    def test_ldexp(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n      
  x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.ldexp(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.ldexp(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Ldexp().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_less(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.less(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.less(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Less().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_less_equal(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.less_equal(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.less_equal(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.LessEqual().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            start_and_stop=[\n                [0, 10],\n                [0.5, 10.5],\n                
[np.array([0, 1], \"int32\"), np.array([10, 20], \"int32\")],\n                [np.array([0, 1], \"float32\"), np.array([10, 20], \"float32\")],\n            ],\n            num=[0, 1, 5],\n            dtype=FLOAT_DTYPES + [None],\n        )\n    )\n    def test_linspace(self, start_and_stop, num, dtype):\n        import jax.numpy as jnp\n\n        start, stop = start_and_stop\n        expected_dtype = standardize_dtype(\n            jnp.linspace(start, stop, num, dtype=dtype).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(\n                knp.linspace(start, stop, num, dtype=dtype).dtype\n            ),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(\n                knp.Linspace(num, dtype=dtype).symbolic_call(start, stop).dtype\n            ),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_log(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((3, 3), dtype=dtype)\n        x_jax = jnp.ones((3, 3), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.log(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.log(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Log().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_log10(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((3, 3), dtype=dtype)\n        x_jax = jnp.ones((3, 3), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.log10(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.log10(x).dtype), expected_dtype)\n        self.assertEqual(\n            
standardize_dtype(knp.Log10().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_log1p(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((3, 3), dtype=dtype)\n        x_jax = jnp.ones((3, 3), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.log1p(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.log1p(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Log1p().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_log2(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((3, 3), dtype=dtype)\n        x_jax = jnp.ones((3, 3), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.log2(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.log2(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Log2().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_logaddexp(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((3, 3), dtype=dtype1)\n        x2 = knp.ones((3, 3), dtype=dtype2)\n        x1_jax = jnp.ones((3, 3), dtype=dtype1)\n        x2_jax = jnp.ones((3, 3), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.logaddexp(x1_jax, x2_jax).dtype)\n        # jnp.logaddexp will promote \"int64\" and \"uint32\" to \"float64\"\n        # force the promotion to `backend.floatx()`\n        if dtype1 is not None and \"float\" not in dtype1:\n            if dtype2 is not 
None and \"float\" not in dtype2:\n                if \"int64\" in (dtype1, dtype2) or \"uint32\" in (dtype1, dtype2):\n                    expected_dtype = backend.floatx()\n\n        self.assertEqual(\n            standardize_dtype(knp.logaddexp(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Logaddexp().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_logaddexp2(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((3, 3), dtype=dtype1)\n        x2 = knp.ones((3, 3), dtype=dtype2)\n        x1_jax = jnp.ones((3, 3), dtype=dtype1)\n        x2_jax = jnp.ones((3, 3), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.logaddexp2(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.logaddexp2(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Logaddexp2().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            start_and_stop=[\n                [0, 10],\n                [0.5, 10.5],\n                [np.array([0, 1], \"int32\"), np.array([10, 20], \"int32\")],\n                [np.array([0, 1], \"float32\"), np.array([10, 20], \"float32\")],\n            ],\n            num=[0, 1, 5],\n            dtype=FLOAT_DTYPES + [None],\n        )\n    )\n    def test_logspace(self, start_and_stop, num, dtype):\n        import jax.numpy as jnp\n\n        start, stop = start_and_stop\n        expected_dtype = standardize_dtype(\n            jnp.logspace(start, stop, num, dtype=dtype).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(\n                knp.logspace(start, stop, num, dtype=dtype).dtype\n            
),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(\n                knp.Logspace(num, dtype=dtype).symbolic_call(start, stop).dtype\n            ),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            start_and_stop=[\n                [1, 1000],\n                [0.5, 10.5],\n                [\n                    np.array([1, 2], \"float32\"),\n                    np.array([100, 200], \"float32\"),\n                ],\n            ],\n            num=[0, 1, 5],\n            dtype=FLOAT_DTYPES + [None],\n        )\n    )\n    def test_geomspace(self, start_and_stop, num, dtype):\n        import jax.numpy as jnp\n\n        start, stop = start_and_stop\n        expected_dtype = standardize_dtype(\n            jnp.geomspace(start, stop, num, dtype=dtype).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(\n                knp.geomspace(start, stop, num, dtype=dtype).dtype\n            ),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(\n                knp.Geomspace(num, dtype=dtype).symbolic_call(start, stop).dtype\n            ),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_logical_and(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.logical_and(x1_jax, x2_jax).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.logical_and(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.LogicalAnd().symbolic_call(x1, 
x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_logical_not(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.logical_not(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.logical_not(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.LogicalNot().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_logical_or(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.logical_or(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.logical_or(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.LogicalOr().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_logical_xor(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(\n            jnp.logical_xor(x1_jax, x2_jax).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.logical_xor(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n   
         standardize_dtype(knp.LogicalXor().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_maximum(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.maximum(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.maximum(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Maximum().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_maximum_python_types(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n\n        # python int\n        expected_dtype = standardize_dtype(jnp.maximum(x_jax, 1).dtype)\n\n        self.assertDType(knp.maximum(x, 1), expected_dtype)\n        self.assertDType(knp.Maximum().symbolic_call(x, 1), expected_dtype)\n\n        # python float\n        expected_dtype = standardize_dtype(jnp.maximum(x_jax, 1.0).dtype)\n\n        self.assertDType(knp.maximum(x, 1.0), expected_dtype)\n        self.assertDType(knp.Maximum().symbolic_call(x, 1.0), expected_dtype)\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_median(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((3, 3), dtype=dtype)\n        x_jax = jnp.ones((3, 3), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.median(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        
self.assertEqual(standardize_dtype(knp.median(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Median().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            standardize_dtype(knp.median(x, axis=1).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Median(axis=1).symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_meshgrid(self, dtype):\n        import jax.numpy as jnp\n\n        if dtype == \"bool\":\n            self.skipTest(\"meshgrid doesn't support bool dtype\")\n        elif dtype is None:\n            dtype = backend.floatx()\n        x = knp.array([1, 2, 3], dtype=dtype)\n        y = knp.array([4, 5, 6], dtype=dtype)\n        x_jax = jnp.array([1, 2, 3], dtype=dtype)\n        y_jax = jnp.array([4, 5, 6], dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.meshgrid(x_jax, y_jax)[0].dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.meshgrid(x, y)[0].dtype), expected_dtype\n        )\n        self.assertEqual(\n            knp.Meshgrid().symbolic_call(x, y)[0].dtype, expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_min(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.min(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.min(x).dtype), expected_dtype)\n        self.assertEqual(knp.Min().symbolic_call(x).dtype, expected_dtype)\n\n        # Test with initial\n        initial = 0\n        expected_dtype = standardize_dtype(\n            jnp.min(x_jax, initial=initial).dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.min(x, initial=initial).dtype), expected_dtype\n    
    )\n        self.assertEqual(\n            knp.Min(initial=initial).symbolic_call(x).dtype, expected_dtype\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_minimum(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.minimum(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.minimum(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Minimum().symbolic_call(x1, x2).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_minimum_python_types(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n\n        # python int\n        expected_dtype = standardize_dtype(jnp.minimum(x_jax, 1).dtype)\n\n        self.assertDType(knp.minimum(x, 1), expected_dtype)\n        self.assertDType(knp.Minimum().symbolic_call(x, 1), expected_dtype)\n\n        # python float\n        expected_dtype = standardize_dtype(jnp.minimum(x_jax, 1.0).dtype)\n\n        self.assertDType(knp.minimum(x, 1.0), expected_dtype)\n        self.assertDType(knp.Minimum().symbolic_call(x, 1.0), expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_mod(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((), dtype=dtype1)\n        x2 = knp.ones((), dtype=dtype2)\n        x1_jax = jnp.ones((), dtype=dtype1)\n        x2_jax = jnp.ones((), dtype=dtype2)\n        expected_dtype = 
standardize_dtype(jnp.mod(x1_jax, x2_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.mod(x1, x2).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.Mod().symbolic_call(x1, x2).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(
        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
    )
    def test_fmod(self, dtypes):
        import jax.numpy as jnp

        dtype1, dtype2 = dtypes
        x1 = knp.ones((), dtype=dtype1)
        x2 = knp.ones((), dtype=dtype2)
        x1_jax = jnp.ones((), dtype=dtype1)
        x2_jax = jnp.ones((), dtype=dtype2)
        expected_dtype = standardize_dtype(jnp.fmod(x1_jax, x2_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.fmod(x1, x2).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.Fmod().symbolic_call(x1, x2).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_moveaxis(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 1, 1), dtype=dtype)
        x_jax = jnp.ones((1, 1, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.moveaxis(x_jax, -2, -1).dtype)

        self.assertEqual(
            standardize_dtype(knp.moveaxis(x, -2, -1).dtype), expected_dtype
        )
        self.assertEqual(
            knp.Moveaxis(-2, -1).symbolic_call(x).dtype, expected_dtype
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nanargmax(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((3,), dtype=dtype)
        x_jax = jnp.ones((3,), dtype=dtype)

        expected_dtype = standardize_dtype(jnp.nanargmax(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.nanargmax(x).dtype),
            expected_dtype,
        )

        self.assertEqual(
            standardize_dtype(knp.Nanargmax().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nanargmin(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((3,), dtype=dtype)
        x_jax = jnp.ones((3,), dtype=dtype)

        expected_dtype = standardize_dtype(jnp.nanargmin(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.nanargmin(x).dtype),
            expected_dtype,
        )

        self.assertEqual(
            standardize_dtype(knp.Nanargmin().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nancumsum(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)

        expected_dtype = standardize_dtype(jnp.nancumsum(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.nancumsum(x).dtype),
            expected_dtype,
        )

        self.assertEqual(
            standardize_dtype(knp.Nancumsum().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nancumprod(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)

        expected_dtype = standardize_dtype(jnp.nancumprod(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.nancumprod(x).dtype),
            expected_dtype,
        )

        self.assertEqual(
            standardize_dtype(knp.Nancumprod().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nanmax(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.nanmax(x_jax).dtype)

        # TODO: torch doesn't support uint32
        if backend.backend() == "torch" and expected_dtype == "uint32":
            expected_dtype = "int32"

        self.assertEqual(standardize_dtype(knp.nanmax(x).dtype), expected_dtype)
        self.assertEqual(
            standardize_dtype(knp.Nanmax().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nanmean(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.nanmean(x_jax).dtype)

        # TODO: torch doesn't support uint32
        if backend.backend() == "torch" and expected_dtype == "uint32":
            expected_dtype = "int32"

        self.assertEqual(
            standardize_dtype(knp.nanmean(x).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.Nanmean().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nanmin(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.nanmin(x_jax).dtype)

        # TODO: torch doesn't support uint32
        if backend.backend() == "torch" and expected_dtype == "uint32":
            expected_dtype = "int32"

        self.assertEqual(standardize_dtype(knp.nanmin(x).dtype), expected_dtype)
        self.assertEqual(
            standardize_dtype(knp.Nanmin().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nanprod(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 1, 1), dtype=dtype)
        x_jax = jnp.ones((1, 1, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.nanprod(x_jax).dtype)

        # TODO: torch doesn't support uint32
        if backend.backend() == "torch" and expected_dtype == "uint32":
            expected_dtype = "int32"

        self.assertEqual(
            standardize_dtype(knp.nanprod(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Nanprod().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nanstd(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)

        expected_dtype = standardize_dtype(jnp.nanstd(x_jax).dtype)

        # TODO: torch doesn't support uint32
        if backend.backend() == "torch" and expected_dtype == "uint32":
            expected_dtype = "int32"

        self.assertEqual(
            standardize_dtype(knp.nanstd(x).dtype),
            expected_dtype,
        )

        self.assertEqual(
            standardize_dtype(knp.Nanstd().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nansum(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.nansum(x_jax).dtype)

        # TODO: torch doesn't support uint32
        if backend.backend() == "torch" and expected_dtype == "uint32":
            expected_dtype = "int32"

        self.assertEqual(standardize_dtype(knp.nansum(x).dtype), expected_dtype)
        self.assertEqual(
            standardize_dtype(knp.Nansum().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nanvar(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)

        expected_dtype = standardize_dtype(jnp.nanvar(x_jax).dtype)

        # TODO: torch doesn't support uint32
        if backend.backend() == "torch" and expected_dtype == "uint32":
            expected_dtype = "int32"

        self.assertEqual(standardize_dtype(knp.nanvar(x).dtype), expected_dtype)

        self.assertEqual(
            standardize_dtype(knp.Nanvar().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nan_to_num(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((), dtype=dtype)
        x_jax = jnp.ones((), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.nan_to_num(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.nan_to_num(x).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.NanToNum().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(
        named_product(dtypes=list(itertools.product(ALL_DTYPES, ALL_DTYPES)))
    )
    def test_nextafter(self, dtypes):
        import jax.numpy as jnp

        dtype1, dtype2 = dtypes
        x1 = knp.ones((), dtype=dtype1)
        x2 = knp.ones((), dtype=dtype2)
        x1_jax = jnp.ones((), dtype=dtype1)
        x2_jax = jnp.ones((), dtype=dtype2)
        expected_dtype = standardize_dtype(jnp.nextafter(x1_jax, x2_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.nextafter(x1, x2).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.Nextafter().symbolic_call(x1, x2).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_nonzero(self, dtype):
        import jax.numpy as jnp

        x = knp.zeros((1,), dtype=dtype)
        x_jax = jnp.zeros((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.nonzero(x_jax)[0].dtype)

        self.assertEqual(
            standardize_dtype(knp.nonzero(x)[0].dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.Nonzero().symbolic_call(x)[0].dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(
        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
    )
    def test_not_equal(self, dtypes):
        import jax.numpy as jnp

        dtype1, dtype2 = dtypes
        x1 = knp.ones((), dtype=dtype1)
        x2 = knp.ones((), dtype=dtype2)
        x1_jax = jnp.ones((), dtype=dtype1)
        x2_jax = jnp.ones((), dtype=dtype2)
        expected_dtype = standardize_dtype(jnp.not_equal(x1_jax, x2_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.not_equal(x1, x2).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.NotEqual().symbolic_call(x1, x2).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_ones_like(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((), dtype=dtype)
        x_jax = jnp.ones((), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.ones_like(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.ones_like(x).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.OnesLike().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(
        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
    )
    def test_outer(self, dtypes):
        import jax.numpy as jnp

        dtype1, dtype2 = dtypes
        x1 = knp.ones((1, 2), dtype=dtype1)
        x2 = knp.ones((3, 4), dtype=dtype2)
        x1_jax = jnp.ones((1, 2), dtype=dtype1)
        x2_jax = jnp.ones((3, 4), dtype=dtype2)
        expected_dtype = standardize_dtype(jnp.outer(x1_jax, x2_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.outer(x1, x2).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.Outer().symbolic_call(x1, x2).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_pad(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((2, 2, 2, 2), dtype=dtype)
        x_jax = jnp.ones((2, 2, 2, 2), dtype=dtype)
        pad_width = ((0, 0), (1, 1), (1, 1), (1, 1))

        for mode in ("constant", "symmetric", "reflect"):
            expected_dtype = standardize_dtype(
                jnp.pad(x_jax, pad_width, mode).dtype
            )

            self.assertEqual(
                standardize_dtype(knp.pad(x, pad_width, mode).dtype),
                expected_dtype,
            )
            self.assertEqual(
                standardize_dtype(
                    knp.Pad(pad_width, mode).symbolic_call(x).dtype
                ),
                expected_dtype,
            )

    @parameterized.named_parameters(
        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
    )
    def test_power(self, dtypes):
        import jax.numpy as jnp

        dtype1, dtype2 = dtypes
        x = knp.ones((1,), dtype=dtype1)
        power = knp.ones((1,), dtype=dtype2)
        x_jax = jnp.ones((1,), dtype=dtype1)
        power_jax = jnp.ones((1,), dtype=dtype2)
        expected_dtype = standardize_dtype(jnp.power(x_jax, power_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.power(x, power).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Power().symbolic_call(x, power).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_power_python_types(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)

        # python int
        expected_dtype = standardize_dtype(jnp.power(x_jax, 1).dtype)

        self.assertDType(knp.power(x, 1), expected_dtype)
        self.assertDType(knp.Power().symbolic_call(x, 1), expected_dtype)

        # python float
        expected_dtype = standardize_dtype(jnp.power(x_jax, 1.0).dtype)

        self.assertDType(knp.power(x, 1.0), expected_dtype)
        self.assertDType(knp.Power().symbolic_call(x, 1.0), expected_dtype)

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_prod(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 1, 1), dtype=dtype)
        x_jax = jnp.ones((1, 1, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.prod(x_jax).dtype)
        # TODO: torch doesn't support uint32
        if backend.backend() == "torch" and expected_dtype == "uint32":
            expected_dtype = "int32"

        self.assertEqual(
            standardize_dtype(knp.prod(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Prod().symbolic_call(x).dtype), expected_dtype
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_ptp(self, dtype):
        import jax.numpy as jnp

        if dtype == "bool":
            self.skipTest("ptp doesn't support bool dtype")

        x = knp.ones((1, 1, 1), dtype=dtype)
        x_jax = jnp.ones((1, 1, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.ptp(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.ptp(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Ptp().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_quantile(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((3,), dtype=dtype)
        x_jax = jnp.ones((3,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.quantile(x_jax, 0.5).dtype)
        if dtype == "int64":
            expected_dtype = backend.floatx()

        self.assertEqual(
            standardize_dtype(knp.quantile(x, 0.5).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Quantile().symbolic_call(x, 0.5).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_searchsorted(self, dtype):
        import jax.numpy as jnp

        if dtype == "bool":
            self.skipTest("searchsorted doesn't support bool dtype")

        a = knp.ones((3,), dtype=dtype)
        v = knp.ones((3,), dtype=dtype)

        a_jax = jnp.ones((3,), dtype=dtype)
        v_jax = jnp.ones((3,), dtype=dtype)

        expected_dtype = standardize_dtype(jnp.searchsorted(a_jax, v_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.searchsorted(a, v).dtype), expected_dtype
        )

        self.assertEqual(
            standardize_dtype(knp.SearchSorted().symbolic_call(a, v).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_ravel(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((), dtype=dtype)
        x_jax = jnp.ones((), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.ravel(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.ravel(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Ravel().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=INT_DTYPES))
    def test_unravel_index(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((3,), dtype=dtype)
        x_jax = jnp.ones((3,), dtype=dtype)

        indices = knp.array([2, 0], dtype=dtype)
        indices_jax = jnp.array([2, 0], dtype=dtype)

        unravel_result_knp = knp.unravel_index(indices, x.shape)
        unravel_result_jax = jnp.unravel_index(indices_jax, x_jax.shape)

        expected_dtype_knp = standardize_dtype(unravel_result_knp[0].dtype)
        expected_dtype_jax = standardize_dtype(unravel_result_jax[0].dtype)

        self.assertEqual(expected_dtype_knp, expected_dtype_jax)

        unravel_result_knp_symbolic = knp.UnravelIndex(x.shape).symbolic_call(
            indices
        )
        expected_dtype_symbolic = standardize_dtype(
            unravel_result_knp_symbolic[0].dtype
        )

        self.assertEqual(expected_dtype_symbolic, expected_dtype_jax)

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_repeat(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 1), dtype=dtype)
        x_jax = jnp.ones((1, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.repeat(x_jax, 2).dtype)

        self.assertEqual(
            standardize_dtype(knp.repeat(x, 2).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Repeat(2).symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_reshape(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 1), dtype=dtype)
        x_jax = jnp.ones((1, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.reshape(x_jax, [1]).dtype)

        self.assertEqual(
            standardize_dtype(knp.reshape(x, [1]).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Reshape([1]).symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_roll(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((5,), dtype=dtype)
        x_jax = jnp.ones((5,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.roll(x_jax, 2).dtype)

        self.assertEqual(
            standardize_dtype(knp.roll(x, 2).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Roll(2).symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_round(self, dtype):
        import jax.numpy as jnp

        if dtype == "bool":
            self.skipTest("round doesn't support bool dtype")
        x = knp.ones((), dtype=dtype)
        x_jax = jnp.ones((), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.round(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.round(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Round().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_sign(self, dtype):
        import jax.numpy as jnp

        if dtype == "bool":
            self.skipTest("sign doesn't support bool dtype")
        x = knp.ones((), dtype=dtype)
        x_jax = jnp.ones((), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.sign(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.sign(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Sign().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_signbit(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((), dtype=dtype)
        x_jax = jnp.ones((), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.signbit(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.signbit(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Signbit().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_sin(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.sin(x_jax).dtype)
        if dtype == "int64":
            expected_dtype = backend.floatx()

        self.assertEqual(standardize_dtype(knp.sin(x).dtype), expected_dtype)
        self.assertEqual(
            standardize_dtype(knp.Sin().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_sinc(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.sinc(x_jax).dtype)
        if dtype == "int64":
            expected_dtype = backend.floatx()

        self.assertEqual(standardize_dtype(knp.sinc(x).dtype), expected_dtype)
        self.assertEqual(
            standardize_dtype(knp.Sinc().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_sinh(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.sinh(x_jax).dtype)
        if dtype == "int64":
            expected_dtype = backend.floatx()

        self.assertEqual(standardize_dtype(knp.sinh(x).dtype), expected_dtype)
        self.assertEqual(
            standardize_dtype(knp.Sinh().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_sort(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((2,), dtype=dtype)
        x_jax = jnp.ones((2,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.sort(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.sort(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Sort().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_split(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 2), dtype=dtype)
        x_jax = jnp.ones((1, 2), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.split(x_jax, 2, -1)[0].dtype)

        self.assertEqual(
            standardize_dtype(knp.split(x, 2, -1)[0].dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Split(2, -1).symbolic_call(x)[0].dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_hsplit(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((2, 1), dtype=dtype)
        x_jax = jnp.ones((2, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.hsplit(x_jax, [1])[0].dtype)

        self.assertEqual(
            standardize_dtype(knp.hsplit(x, [1])[0].dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Hsplit([1]).symbolic_call(x)[0].dtype),
            expected_dtype,
        )

        # test 1d case
        x_1d = knp.ones((4,), dtype=dtype)
        x_1d_jax = jnp.ones((4,), dtype=dtype)
        expected_dtype_1d = standardize_dtype(
            jnp.hsplit(x_1d_jax, [2])[0].dtype
        )

        self.assertEqual(
            standardize_dtype(knp.hsplit(x_1d, [2])[0].dtype),
            expected_dtype_1d,
        )
        self.assertEqual(
            standardize_dtype(knp.Hsplit([2]).symbolic_call(x_1d)[0].dtype),
            expected_dtype_1d,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_vsplit(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 2), dtype=dtype)
        x_jax = jnp.ones((1, 2), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.vsplit(x_jax, [1])[0].dtype)

        self.assertEqual(
            standardize_dtype(knp.vsplit(x, [1])[0].dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Vsplit([1]).symbolic_call(x)[0].dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_sqrt(self, dtype):
        import jax.numpy as jnp

        x1 = knp.ones((1,), dtype=dtype)
        x1_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.sqrt(x1_jax).dtype)
        if dtype == "int64":
            expected_dtype = backend.floatx()

        self.assertEqual(standardize_dtype(knp.sqrt(x1).dtype), expected_dtype)
        self.assertEqual(
            standardize_dtype(knp.Sqrt().symbolic_call(x1).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_square(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((), dtype=dtype)
        x_jax = jnp.ones((), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.square(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.square(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Square().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_squeeze(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 1), dtype=dtype)
        x_jax = jnp.ones((1, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.squeeze(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.squeeze(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Squeeze().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(
        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
    )
    def test_stack(self, dtypes):
        import jax.numpy as jnp

        dtype1, dtype2 = dtypes
        x1 = knp.ones((1,), dtype=dtype1)
        x2 = knp.ones((1,), dtype=dtype2)
        x1_jax = jnp.ones((1,), dtype=dtype1)
        x2_jax = jnp.ones((1,), dtype=dtype2)
        expected_dtype = standardize_dtype(jnp.stack([x1_jax, x2_jax]).dtype)

        self.assertEqual(
            standardize_dtype(knp.stack([x1, x2]).dtype), expected_dtype
        )
        self.assertEqual(
            knp.Stack().symbolic_call([x1, x2]).dtype, expected_dtype
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_std(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((), dtype=dtype)
        x_jax = jnp.ones((), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.std(x_jax).dtype)
        if dtype == "int64":
            expected_dtype = backend.floatx()

        self.assertEqual(
            standardize_dtype(knp.std(x).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Std().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_sum(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.sum(x_jax).dtype)

        # TODO: torch doesn't support uint32
        if backend.backend() == "torch" and expected_dtype == "uint32":
            expected_dtype = "int32"

        self.assertEqual(standardize_dtype(knp.sum(x).dtype), expected_dtype)
        self.assertEqual(
            standardize_dtype(knp.Sum().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_swapaxes(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 1), dtype=dtype)
        x_jax = jnp.ones((1, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.swapaxes(x_jax, -1, -2).dtype)

        self.assertEqual(
            standardize_dtype(knp.swapaxes(x, -1, -2).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Swapaxes(-1, -2).symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_take(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.take(x_jax, 0).dtype)

        self.assertEqual(
            standardize_dtype(knp.take(x, 0).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(knp.Take().symbolic_call(x, 0).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(
        named_product(dtype=ALL_DTYPES, indices_dtype=INT_DTYPES)
    )
    def test_take_along_axis(self, dtype, indices_dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        indices = knp.zeros((1,), dtype=indices_dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        indices_jax = jnp.zeros((1,), dtype=indices_dtype)
        expected_dtype = standardize_dtype(
            jnp.take_along_axis(x_jax, indices_jax, 0).dtype
        )

        self.assertEqual(
            standardize_dtype(knp.take_along_axis(x, indices, 0).dtype),
            expected_dtype,
        )
        self.assertEqual(
            standardize_dtype(
                knp.TakeAlongAxis(0).symbolic_call(x, indices).dtype
            ),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_tan(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.tan(x_jax).dtype)
        if dtype == "int64":
            expected_dtype = backend.floatx()

        self.assertEqual(standardize_dtype(knp.tan(x).dtype), expected_dtype)
        self.assertEqual(
            standardize_dtype(knp.Tan().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_tanh(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.tanh(x_jax).dtype)
        if dtype == "int64":
            expected_dtype = backend.floatx()

        self.assertEqual(standardize_dtype(knp.tanh(x).dtype), expected_dtype)
        self.assertEqual(
            standardize_dtype(knp.Tanh().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(
        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
    )
    def test_tensordot(self, dtypes):
        import jax.numpy as jnp

        dtype1, dtype2 = dtypes
        x1 = knp.ones((1, 1), dtype=dtype1)
        x2 = knp.ones((1, 1), dtype=dtype2)
        x1_jax = jnp.ones((1, 1), dtype=dtype1)
        x2_jax = jnp.ones((1, 1), dtype=dtype2)
        expected_dtype = standardize_dtype(
            jnp.tensordot(x1_jax, x2_jax, 2).dtype
        )

        self.assertEqual(
            standardize_dtype(knp.tensordot(x1, x2, 2).dtype), expected_dtype
        )
        self.assertEqual(
            knp.Tensordot(2).symbolic_call(x1, x2).dtype, expected_dtype
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_tile(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1,), dtype=dtype)
        x_jax = jnp.ones((1,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.tile(x_jax, [1]).dtype)

        self.assertEqual(
            standardize_dtype(knp.tile(x, [1]).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.Tile([1]).symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_trace(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 1, 1), dtype=dtype)
        x_jax = jnp.ones((1, 1, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.trace(x_jax).dtype)
        # jnp.trace is buggy with bool, so we set expected_dtype to int32
        # for bool inputs.
        if dtype == "bool":
            expected_dtype = "int32"
        if dtype == "uint8" and backend.backend() == "torch":
            # jnp.trace promotes uint8 to uint32, which the torch backend
            # doesn't support.
            expected_dtype = "int32"

        self.assertDType(knp.trace(x), expected_dtype)
        self.assertDType(knp.Trace().symbolic_call(x), expected_dtype)

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_transpose(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 1), dtype=dtype)
        x_jax = jnp.ones((1, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.transpose(x_jax, [1, 0]).dtype)

        self.assertEqual(
            standardize_dtype(knp.transpose(x, [1, 0]).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.Transpose([1, 0]).symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_tri(self, dtype):
        import jax.numpy as jnp

        expected_dtype = standardize_dtype(jnp.tri(3, dtype=dtype).dtype)

        self.assertEqual(
            standardize_dtype(knp.tri(3, dtype=dtype).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_tril(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 1), dtype=dtype)
        x_jax = jnp.ones((1, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.tril(x_jax, 0).dtype)

        self.assertEqual(
            standardize_dtype(knp.tril(x, 0).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.Tril(0).symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_triu(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((1, 1), dtype=dtype)
        x_jax = jnp.ones((1, 1), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.triu(x_jax, 0).dtype)

        self.assertEqual(
            standardize_dtype(knp.triu(x, 0).dtype), expected_dtype
        )
        self.assertEqual(
            standardize_dtype(knp.Triu(0).symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(
        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
    )
    def test_true_divide(self, dtypes):
        import jax.numpy as jnp

        dtype1, dtype2 = dtypes
        x1 = knp.ones((1,), dtype=dtype1)
        x2 = knp.ones((1,), dtype=dtype2)
        x1_jax = jnp.ones((1,), dtype=dtype1)
        x2_jax = jnp.ones((1,), dtype=dtype2)
        expected_dtype = standardize_dtype(
            jnp.true_divide(x1_jax, x2_jax).dtype
        )

        self.assertDType(knp.true_divide(x1, x2), expected_dtype)
        self.assertDType(knp.TrueDivide().symbolic_call(x1, x2), expected_dtype)

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_trunc(self, dtype):
        x = knp.ones((1, 1), dtype=dtype)
        # TODO: jax <= 0.30.0 doesn't preserve the original dtype.
        expected_dtype = dtype or backend.floatx()

        self.assertEqual(standardize_dtype(knp.trunc(x).dtype), expected_dtype)
        self.assertEqual(
            standardize_dtype(knp.Trunc().symbolic_call(x).dtype),
            expected_dtype,
        )

    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))
    def test_trapezoid(self, dtype):
        import jax.numpy as jnp

        x = knp.ones((2,), dtype=dtype)
        x_jax = jnp.ones((2,), dtype=dtype)
        expected_dtype = standardize_dtype(jnp.trapezoid(x_jax).dtype)

        self.assertEqual(
            standardize_dtype(knp.trapezoid(x).dtype),
expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.Trapezoid().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_vander(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((2,), dtype=dtype)\n        x_jax = jnp.ones((2,), dtype=dtype)\n\n        if dtype == \"bool\":\n            self.skipTest(\"vander does not support bool\")\n\n        expected_dtype = standardize_dtype(jnp.vander(x_jax).dtype)\n\n        self.assertEqual(standardize_dtype(knp.vander(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Vander().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_var(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((2,), dtype=dtype)\n        x_jax = jnp.ones((2,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.var(x_jax).dtype)\n        if dtype == \"int64\":\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.var(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Var().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_vdot(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.vdot(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.vdot(x1, x2).dtype), expected_dtype\n        )\n        
self.assertEqual(knp.Vdot().symbolic_call(x1, x2).dtype, expected_dtype)\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_inner(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.inner(x1_jax, x2_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.inner(x1, x2).dtype), expected_dtype\n        )\n        self.assertEqual(\n            knp.Inner().symbolic_call(x1, x2).dtype, expected_dtype\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_vstack(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        x1 = knp.ones((1,), dtype=dtype1)\n        x2 = knp.ones((1,), dtype=dtype2)\n        x1_jax = jnp.ones((1,), dtype=dtype1)\n        x2_jax = jnp.ones((1,), dtype=dtype2)\n        expected_dtype = standardize_dtype(jnp.vstack([x1_jax, x2_jax]).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.vstack([x1, x2]).dtype), expected_dtype\n        )\n        self.assertEqual(\n            knp.Vstack().symbolic_call([x1, x2]).dtype, expected_dtype\n        )\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))\n    )\n    def test_where(self, dtypes):\n        import jax.numpy as jnp\n\n        dtype1, dtype2 = dtypes\n        condition = knp.ones((10,), dtype=\"bool\")\n        x1 = knp.ones((10,), dtype=dtype1)\n        x2 = knp.ones((10,), dtype=dtype2)\n        condition_jax = jnp.ones((10,), dtype=\"bool\")\n        x1_jax = jnp.ones((10,), dtype=dtype1)\n        x2_jax = jnp.ones((10,), dtype=dtype2)\n        
expected_dtype = standardize_dtype(\n            jnp.where(condition_jax, x1_jax, x2_jax).dtype\n        )\n\n        self.assertEqual(\n            standardize_dtype(knp.where(condition, x1, x2).dtype),\n            expected_dtype,\n        )\n        self.assertEqual(\n            knp.Where().symbolic_call(condition, x1, x2).dtype, expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_where_python_types(self, dtype):\n        import jax.numpy as jnp\n\n        condition = knp.ones((10,), dtype=\"bool\")\n        x = knp.ones((10,), dtype=dtype)\n        condition_jax = jnp.ones((10,), dtype=\"bool\")\n        x_jax = jnp.ones((10,), dtype=dtype)\n\n        # python int\n        expected_dtype = standardize_dtype(\n            jnp.where(condition_jax, x_jax, 1).dtype\n        )\n\n        self.assertDType(knp.where(condition, x, 1), expected_dtype)\n        self.assertDType(\n            knp.Where().symbolic_call(condition, x, 1), expected_dtype\n        )\n\n        # python float\n        expected_dtype = standardize_dtype(\n            jnp.where(condition_jax, x_jax, 1.0).dtype\n        )\n\n        self.assertDType(knp.where(condition, x, 1.0), expected_dtype)\n        self.assertDType(\n            knp.Where().symbolic_call(condition, x, 1.0), expected_dtype\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def test_zeros_like(self, dtype):\n        import jax.numpy as jnp\n\n        x = knp.ones((), dtype=dtype)\n        x_jax = jnp.ones((), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.zeros_like(x_jax).dtype)\n\n        self.assertEqual(\n            standardize_dtype(knp.zeros_like(x).dtype), expected_dtype\n        )\n        self.assertEqual(\n            standardize_dtype(knp.ZerosLike().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    @parameterized.named_parameters(named_product(dtype=ALL_DTYPES))\n    def 
test_angle(self, dtype):\n        if dtype == \"bfloat16\" and testing.torch_uses_gpu():\n            self.skipTest(\"Torch cuda does not support bfloat16\")\n\n        import jax.numpy as jnp\n\n        x = knp.ones((1,), dtype=dtype)\n        x_jax = jnp.ones((1,), dtype=dtype)\n        expected_dtype = standardize_dtype(jnp.angle(x_jax).dtype)\n        if dtype == \"bool\" or is_int_dtype(dtype):\n            expected_dtype = backend.floatx()\n\n        self.assertEqual(standardize_dtype(knp.angle(x).dtype), expected_dtype)\n        self.assertEqual(\n            standardize_dtype(knp.Angle().symbolic_call(x).dtype),\n            expected_dtype,\n        )\n\n    VIEW_DTYPES = [x for x in ALL_DTYPES if x != \"bool\" and x is not None]\n\n    @parameterized.named_parameters(\n        named_product(dtypes=itertools.combinations(VIEW_DTYPES, 2))\n    )\n    def test_view(self, dtypes):\n        import jax.numpy as jnp\n\n        input_dtype, output_dtype = dtypes\n        x = knp.ones((2, 8), dtype=input_dtype)\n        x_jax = jnp.ones((2, 8), dtype=input_dtype)\n\n        keras_output = knp.view(x, output_dtype)\n        symbolic_output = knp.View(output_dtype).symbolic_call(x)\n        expected_output = x_jax.view(output_dtype)\n        self.assertEqual(\n            standardize_dtype(keras_output.dtype),\n            standardize_dtype(expected_output.dtype),\n        )\n        self.assertEqual(\n            keras_output.shape,\n            expected_output.shape,\n        )\n        self.assertEqual(\n            standardize_dtype(symbolic_output.dtype),\n            standardize_dtype(expected_output.dtype),\n        )\n\n\n@pytest.mark.skipif(\n    testing.torch_uses_gpu(),\n    reason=\"histogram op not implemented for torch on gpu\",\n)\nclass HistogramTest(testing.TestCase):\n    def test_histogram_default_args(self):\n        hist_op = knp.histogram\n        input_tensor = np.random.rand(8)\n\n        # Expected output\n        expected_counts, 
expected_edges = np.histogram(input_tensor)\n\n        counts, edges = hist_op(input_tensor)\n\n        self.assertEqual(counts.shape, expected_counts.shape)\n        self.assertAllClose(counts, expected_counts)\n        self.assertEqual(edges.shape, expected_edges.shape)\n        self.assertAllClose(edges, expected_edges)\n\n    def test_histogram_custom_bins(self):\n        hist_op = knp.histogram\n        input_tensor = np.random.rand(8)\n        bins = 5\n\n        # Expected output\n        expected_counts, expected_edges = np.histogram(input_tensor, bins=bins)\n\n        counts, edges = hist_op(input_tensor, bins=bins)\n\n        self.assertEqual(counts.shape, expected_counts.shape)\n        self.assertAllClose(counts, expected_counts)\n        self.assertEqual(edges.shape, expected_edges.shape)\n        self.assertAllClose(edges, expected_edges)\n\n    def test_histogram_custom_range(self):\n        hist_op = knp.histogram\n        input_tensor = np.random.rand(10)\n        range_specified = (2, 8)\n\n        # Expected output\n        expected_counts, expected_edges = np.histogram(\n            input_tensor, range=range_specified\n        )\n\n        counts, edges = hist_op(input_tensor, range=range_specified)\n\n        self.assertEqual(counts.shape, expected_counts.shape)\n        self.assertAllClose(counts, expected_counts)\n        self.assertEqual(edges.shape, expected_edges.shape)\n        self.assertAllClose(edges, expected_edges)\n\n    def test_histogram_symbolic_input(self):\n        hist_op = knp.histogram\n        input_tensor = KerasTensor(shape=(None,), dtype=\"float32\")\n\n        counts, edges = hist_op(input_tensor)\n\n        self.assertEqual(counts.shape, (10,))\n        self.assertEqual(edges.shape, (11,))\n\n    def test_histogram_non_integer_bins_raises_error(self):\n        hist_op = knp.histogram\n        input_tensor = np.random.rand(8)\n\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `bins` should be a 
non-negative integer\"\n        ):\n            hist_op(input_tensor, bins=-5)\n\n    def test_histogram_range_validation(self):\n        hist_op = knp.histogram\n        input_tensor = np.random.rand(8)\n\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `range` must be a tuple of two elements\"\n        ):\n            hist_op(input_tensor, range=(1,))\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"The second element of `range` must be greater than the first\",\n        ):\n            hist_op(input_tensor, range=(5, 1))\n\n    def test_histogram_large_values(self):\n        hist_op = knp.histogram\n        input_tensor = np.array([1e10, 2e10, 3e10, 4e10, 5e10])\n\n        counts, edges = hist_op(input_tensor, bins=5)\n\n        expected_counts, expected_edges = np.histogram(input_tensor, bins=5)\n\n        self.assertAllClose(counts, expected_counts)\n        self.assertAllClose(edges, expected_edges)\n\n    def test_histogram_float_input(self):\n        hist_op = knp.histogram\n        input_tensor = np.random.rand(8)\n\n        counts, edges = hist_op(input_tensor, bins=5)\n\n        expected_counts, expected_edges = np.histogram(input_tensor, bins=5)\n\n        self.assertAllClose(counts, expected_counts)\n        self.assertAllClose(edges, expected_edges)\n\n    def test_histogram_high_dimensional_input(self):\n        hist_op = knp.histogram\n        input_tensor = np.random.rand(3, 4, 5)\n\n        with self.assertRaisesRegex(\n            ValueError, \"Input tensor must be 1-dimensional\"\n        ):\n            hist_op(input_tensor)\n\n    def test_histogram_values_on_edges(self):\n        hist_op = knp.histogram\n        input_tensor = np.array([0.0, 2.0, 4.0, 8.0, 10.0])\n        bins = 5\n\n        expected_counts, expected_edges = np.histogram(input_tensor, bins=bins)\n        counts, edges = hist_op(input_tensor, bins=bins)\n\n        self.assertAllClose(counts, expected_counts)\n        
self.assertAllClose(edges, expected_edges)\n\n    # TODO: Fix predict for NumPy.\n    @parameterized.named_parameters(\n        (\"jit_compile_false\", False),\n        (\"jit_compile_true\", True),\n    )\n    @pytest.mark.skipif(\n        backend.backend() == \"numpy\",\n        reason=(\n            \"`predict` errors out with 'autodetected range of [nan, nan] is \"\n            \"not finite' on the NumPy backend. To be fixed.\"\n        ),\n    )\n    def test_histogram_predict(self, jit_compile):\n        class HistogramLayer(keras.layers.Layer):\n            def call(self, x):\n                shape = ops.shape(x)\n\n                # Flatten, because the op does not work with >1-dim inputs.\n                x = ops.reshape(x, (shape[0] * shape[1],))\n                return knp.histogram(x, bins=5)\n\n        inputs = keras.Input(shape=(8,))\n        counts, edges = HistogramLayer()(inputs)\n        model = keras.Model(inputs, (counts, edges))\n        model.compile(jit_compile=jit_compile)\n\n        model.predict(np.random.randn(1, 8))\n\n\nclass TileTest(testing.TestCase):\n    def test_tile_shape_inference_in_layer(self):\n        class TileLayer(keras.layers.Layer):\n            def call(self, x):\n                repeats = [1, 2, 1, 1]\n                return knp.tile(x, repeats)\n\n        inputs = keras.Input(shape=(3, 2, 2))\n        output = TileLayer()(inputs)\n\n        self.assertEqual(output.shape, (None, 6, 2, 2))\n"
  },
  {
    "path": "keras/src/ops/operation.py",
    "content": "import inspect\nimport textwrap\n\nfrom keras.src import backend\nfrom keras.src import dtype_policies\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common.keras_tensor import any_symbolic_tensors\nfrom keras.src.backend.config import is_nnx_enabled\nfrom keras.src.ops.node import Node\nfrom keras.src.saving.keras_saveable import KerasSaveable\nfrom keras.src.utils import python_utils\nfrom keras.src.utils import traceback_utils\nfrom keras.src.utils.naming import auto_name\n\n\n@keras_export(\"keras.Operation\")\nclass Operation(KerasSaveable):\n    def __init__(self, name=None):\n        if name is None:\n            name = auto_name(self.__class__.__name__)\n        if not isinstance(name, str) or \"/\" in name:\n            raise ValueError(\n                \"Argument `name` must be a string and \"\n                f\"cannot contain character `/`. \"\n                f\"Received: name={name} (of type {type(name)})\"\n            )\n        self.name = name\n        self._inbound_nodes = []\n        self._outbound_nodes = []\n\n    @traceback_utils.filter_traceback\n    def __call__(self, *args, **kwargs):\n        if traceback_utils.is_traceback_filtering_enabled():\n            # Wrap self.call to provide helpful info in case of exception\n            if any_symbolic_tensors(args, kwargs):\n                call_fn = self.symbolic_call\n            else:\n                if getattr(self, \"_remat_mode\", None) is not None:\n                    if getattr(self, \"quantization_mode\", None) is not None:\n                        call_fn = self.rematerialized_call(\n                            self.quantized_call,\n                            *args,\n                            **kwargs,\n                        )\n                    else:\n                        call_fn = self.rematerialized_call(\n                            self.call, *args, **kwargs\n                        )\n              
  else:\n                    if getattr(self, \"quantization_mode\", None) is not None:\n                        call_fn = self.quantized_call\n                    else:\n                        call_fn = self.call\n            call_fn = traceback_utils.inject_argument_info_in_traceback(\n                call_fn,\n                object_name=(f\"{self.__class__.__name__}.call()\"),\n            )\n            return call_fn(*args, **kwargs)\n\n        # Plain flow.\n        if any_symbolic_tensors(args, kwargs):\n            return self.symbolic_call(*args, **kwargs)\n        elif getattr(self, \"_remat_mode\", None) is not None:\n            if getattr(self, \"quantization_mode\", None) is not None:\n                return self.rematerialized_call(\n                    self.quantized_call, *args, **kwargs\n                )(*args, **kwargs)\n            else:\n                return self.rematerialized_call(self.call, *args, **kwargs)(\n                    *args, **kwargs\n                )\n        else:\n            if getattr(self, \"quantization_mode\", None) is not None:\n                return self.quantized_call(*args, **kwargs)\n            else:\n                return self.call(*args, **kwargs)\n\n    def symbolic_call(self, *args, **kwargs):\n        # Perform shape/dtype inference.\n        outputs = self.compute_output_spec(*args, **kwargs)\n        # Record a new node in the operations graph.\n        # The Node wires itself to inbound and outbound ops.  
The\n        # Node constructor updates this op's self._inbound_nodes,\n        # sets _keras_history on the outputs, and adds itself to the\n        # `_outbound_nodes` of the ops that produced the inputs to this\n        # call.\n        Node(\n            operation=self, call_args=args, call_kwargs=kwargs, outputs=outputs\n        )\n        return outputs\n\n    def call(self, *args, **kwargs):\n        raise NotImplementedError\n\n    def quantized_call(self, *args, **kwargs):\n        raise NotImplementedError\n\n    def compute_output_spec(self, *args, **kwargs):\n        try:\n            return backend.compute_output_spec(self.call, *args, **kwargs)\n        except Exception as e:\n            new_e = e.__class__(\n                \"Could not automatically infer the output shape / dtype of \"\n                f\"'{self.name}' (of type {self.__class__.__name__}). \"\n                f\"Either the `{self.__class__.__name__}.call()` method \"\n                f\"is incorrect, or you need to implement the \"\n                f\"`{self.__class__.__name__}.compute_output_spec() / \"\n                \"compute_output_shape()` method. 
\"\n                f\"Error encountered:\\n\\n{e}\"\n            )\n            raise new_e.with_traceback(e.__traceback__) from None\n\n    def __new__(cls, *args, **kwargs):\n        \"\"\"We override __new__ to save serializable constructor arguments.\n\n        These arguments are used to auto-generate an object serialization\n        config, which enables user-created subclasses to be serializable\n        out of the box in most cases without forcing the user\n        to manually implement `get_config()`.\n        \"\"\"\n        instance = super(Operation, cls).__new__(cls)\n        if backend.backend() == \"jax\" and is_nnx_enabled():\n            from flax import nnx\n\n            try:\n                vars(instance)[\"_pytree__state\"] = nnx.pytreelib.PytreeState()\n            except AttributeError:\n                vars(instance)[\"_object__state\"] = nnx.object.ObjectState()\n\n        # Generate a config to be returned by default by `get_config()`.\n        auto_config = True\n\n        signature = inspect.signature(cls.__init__)\n        argspec = inspect.getfullargspec(cls.__init__)\n\n        try:\n            bound_parameters = signature.bind(None, *args, **kwargs)\n        except TypeError:\n            # Raised by signature.bind when the supplied args and kwargs\n            # do not match the signature.\n            auto_config = False\n\n        if auto_config and any(\n            [\n                param.kind == inspect.Parameter.POSITIONAL_ONLY\n                for name, param in signature.parameters.items()\n                if name != argspec.args[0]\n            ]\n        ):\n            # cls.__init__ takes positional only arguments, which\n            # cannot be restored via cls(**config)\n            auto_config = False\n            # Create variable to show appropriate warning in get_config.\n            instance._auto_config_error_args = True\n\n        if auto_config:\n            # Include default values in the config.\n       
     bound_parameters.apply_defaults()\n            # Extract all arguments as a dictionary.\n            kwargs = bound_parameters.arguments\n            # Expand variable kwargs argument.\n            kwargs |= kwargs.pop(argspec.varkw, {})\n            # Remove first positional argument, self.\n            kwargs.pop(argspec.args[0])\n            # Remove argument \"name\", as it is provided by get_config.\n            kwargs.pop(\"name\", None)\n            if argspec.varargs is not None:\n                # Varargs cannot be meaningfully converted to a dictionary.\n                varargs = kwargs.pop(argspec.varargs)\n                if len(varargs) > 0:\n                    auto_config = False\n                    # Store variable to show appropriate warning in get_config.\n                    instance._auto_config_error_args = True\n\n        # For safety, we only rely on auto-configs for a small set of\n        # serializable types.\n        supported_types = (str, int, float, bool, type(None))\n        try:\n            flat_arg_values = tree.flatten(kwargs)\n            for value in flat_arg_values:\n                if not isinstance(value, supported_types):\n                    auto_config = False\n                    break\n        except TypeError:\n            auto_config = False\n        try:\n            instance._lock = False\n            if auto_config:\n                from keras.src.saving import serialization_lib\n\n                instance._auto_config = serialization_lib.SerializableDict(\n                    **kwargs\n                )\n            else:\n                instance._auto_config = None\n            instance._lock = True\n        except RecursionError:\n            # Setting an instance attribute in __new__ has the potential\n            # to trigger an infinite recursion if a subclass overrides\n            # setattr in an unsafe way.\n            pass\n        return instance\n\n    @python_utils.default\n    def 
get_config(self):\n        \"\"\"Returns the config of the object.\n\n        An object config is a Python dictionary (serializable)\n        containing the information needed to re-instantiate it.\n        \"\"\"\n        config = {\n            \"name\": self.name,\n        }\n\n        if not python_utils.is_default(self.get_config):\n            # In this case the subclass implements get_config()\n            return config\n\n        # In this case the subclass doesn't implement get_config():\n        # Let's see if we can autogenerate it.\n        if getattr(self, \"_auto_config\", None) is not None:\n            config.update(self._auto_config.config)\n            init_params = inspect.signature(self.__init__).parameters\n            init_has_name = \"name\" in init_params\n            init_has_kwargs = (\n                \"kwargs\" in init_params\n                and init_params[\"kwargs\"].kind == inspect.Parameter.VAR_KEYWORD\n            )\n            if not init_has_name and not init_has_kwargs:\n                # We can't pass `name` back to `__init__`, remove it.\n                config.pop(\"name\", None)\n            return config\n        else:\n            example_str = \"\"\"\n            class CustomLayer(keras.layers.Layer):\n                def __init__(self, arg1, arg2, **kwargs):\n                    super().__init__(**kwargs)\n                    self.arg1 = arg1\n                    self.arg2 = arg2\n\n                def get_config(self):\n                    config = super().get_config()\n                    config.update({\n                        \"arg1\": self.arg1,\n                        \"arg2\": self.arg2,\n                    })\n                    return config\n            \"\"\"\n            if getattr(self, \"_auto_config_error_args\", False):\n                raise NotImplementedError(\n                    textwrap.dedent(\n                        f\"\"\"\n            Object {self.__class__.__name__} was created by 
passing\n            positional only or variadic positional arguments (e.g.,\n            `*args`) to `__init__()`, which is not supported by the\n            automatic config generation. Please remove all positional\n            only and variadic arguments from `__init__()`\n            or override `get_config()` and `from_config()` to make\n            the object serializable.\n\n            Example:\n\n            {example_str}\"\"\"\n                    )\n                )\n            else:\n                raise NotImplementedError(\n                    textwrap.dedent(\n                        f\"\"\"\n            Object {self.__class__.__name__} was created by passing\n            non-serializable argument values in `__init__()`,\n            and therefore the object must override `get_config()` in\n            order to be serializable. Please implement `get_config()`.\n\n            Example:\n\n            {example_str}\"\"\"\n                    )\n                )\n\n    @classmethod\n    def from_config(cls, config):\n        \"\"\"Creates an operation from its config.\n\n        This method is the reverse of `get_config`, capable of instantiating the\n        same operation from the config dictionary.\n\n        Note: If you override this method, you might receive a serialized dtype\n        config, which is a `dict`. You can deserialize it as follows:\n\n        ```python\n        if \"dtype\" in config and isinstance(config[\"dtype\"], dict):\n            policy = dtype_policies.deserialize(config[\"dtype\"])\n        ```\n\n        Args:\n            config: A Python dictionary, typically the output of `get_config`.\n\n        Returns:\n            An operation instance.\n        \"\"\"\n        # Explicitly deserialize dtype config if needed. 
This enables users to\n        # directly interact with the instance of `DTypePolicy`.\n        if \"dtype\" in config and isinstance(config[\"dtype\"], dict):\n            config = config.copy()\n            policy = dtype_policies.deserialize(config[\"dtype\"])\n            if (\n                not isinstance(policy, dtype_policies.DTypePolicyMap)\n                and policy.quantization_mode is None\n            ):\n                # For backward compatibility, we use a str (`name`) for\n                # `DTypePolicy`\n                policy = policy.name\n            config[\"dtype\"] = policy\n        try:\n            return cls(**config)\n        except Exception as e:\n            raise TypeError(\n                f\"Error when deserializing class '{cls.__name__}' using \"\n                f\"config={config}.\\n\\nException encountered: {e}\"\n            )\n\n    def __repr__(self):\n        return f\"<Operation name={self.name}>\"\n\n    @property\n    def input(self):\n        \"\"\"Retrieves the input tensor(s) of a symbolic operation.\n\n        Only returns the tensor(s) corresponding to the *first time*\n        the operation was called.\n\n        Returns:\n            Input tensor or list of input tensors.\n        \"\"\"\n        return self._get_node_attribute_at_index(0, \"input_tensors\", \"input\")\n\n    @property\n    def output(self):\n        \"\"\"Retrieves the output tensor(s) of a symbolic operation.\n\n        Only returns the tensor(s) corresponding to the *first time*\n        the operation was called.\n\n        Returns:\n            Output tensor or list of output tensors.\n        \"\"\"\n        return self._get_node_attribute_at_index(0, \"output_tensors\", \"output\")\n\n    def _get_node_attribute_at_index(self, node_index, attr, attr_name):\n        \"\"\"Private utility to retrieve an attribute (e.g. 
inputs) from a node.\n\n        This is used to implement the properties:\n        - output\n        - input\n\n        Args:\n            node_index: Integer index of the node from which\n                to retrieve the attribute.\n            attr: Exact node attribute name.\n            attr_name: Human-readable attribute name, for error messages.\n\n        Returns:\n            The operation's attribute `attr` at the node of index `node_index`.\n        \"\"\"\n        if not self._inbound_nodes:\n            raise AttributeError(\n                f\"The layer {self.name} has never been called \"\n                f\"and thus has no defined {attr_name}.\"\n            )\n        if not len(self._inbound_nodes) > node_index:\n            raise ValueError(\n                f\"Asked to get {attr_name} at node \"\n                f\"{node_index}, but the operation has only \"\n                f\"{len(self._inbound_nodes)} inbound nodes.\"\n            )\n        values = getattr(self._inbound_nodes[node_index], attr)\n        if isinstance(values, list) and len(values) == 1:\n            return values[0]\n        else:\n            return values\n\n    def _obj_type(self):\n        return \"Operation\"\n\n    # Hooks for backend layer classes\n    def _post_build(self):\n        \"\"\"Can be overridden for per backend post build actions.\"\"\"\n        pass\n\n    def _setattr_hook(self, name, value):\n        \"\"\"Can be overridden for per backend attribute setting actions.\"\"\"\n        return name, value\n\n    def _post_track_variable(self, variable):\n        \"\"\"Can be overridden for per backend post track actions.\"\"\"\n        pass\n\n    def _post_untrack_variable(self, variable):\n        \"\"\"Can be overridden for per backend post untrack actions.\"\"\"\n        pass\n"
  },
  {
    "path": "keras/src/ops/operation_test.py",
    "content": "import numpy as np\n\nfrom conftest import skip_if_backend\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.backend.common import keras_tensor\nfrom keras.src.ops import numpy as knp\nfrom keras.src.ops import operation\n\n\nclass OpWithMultipleInputs(operation.Operation):\n    def call(self, x, y, z=None):\n        # `z` has to be put first due to the order of operations issue with\n        # torch backend.\n        return 3 * z + x + 2 * y\n\n    def compute_output_spec(self, x, y, z=None):\n        return keras_tensor.KerasTensor(x.shape, x.dtype)\n\n\nclass OpWithMultipleOutputs(operation.Operation):\n    def call(self, x):\n        return (x, x + 1)\n\n    def compute_output_spec(self, x):\n        return (\n            keras_tensor.KerasTensor(x.shape, x.dtype),\n            keras_tensor.KerasTensor(x.shape, x.dtype),\n        )\n\n\nclass OpWithCustomConstructor(operation.Operation):\n    def __init__(self, alpha, *, beta=1.0, name=None):\n        super().__init__(name=name)\n        self.alpha = alpha\n        self.beta = beta\n\n    def call(self, x):\n        return self.alpha * x + self.beta\n\n    def compute_output_spec(self, x):\n        return keras_tensor.KerasTensor(x.shape, x.dtype)\n\n\nclass OpWithCustomConstructorNoName(operation.Operation):\n    def __init__(self, alpha, beta=1.0):\n        super().__init__()\n        self.alpha = alpha\n        self.beta = beta\n\n    def call(self, x):\n        return self.alpha * x + self.beta\n\n    def compute_output_spec(self, x):\n        return keras_tensor.KerasTensor(x.shape, x.dtype)\n\n\nclass OpWithKwargsInConstructor(operation.Operation):\n    def __init__(self, alpha, beta=1.0, **kwargs):\n        super().__init__(**kwargs)\n        self.alpha = alpha\n        self.beta = beta\n\n    def call(self, x):\n        return self.alpha * x + self.beta\n\n    def compute_output_spec(self, x):\n        return keras_tensor.KerasTensor(x.shape, 
x.dtype)\n\n\nclass OpWithArgsInConstructor(operation.Operation):\n    def __init__(self, alpha, *args, name=None):\n        super().__init__(name=name)\n        self.alpha = alpha\n\n    def call(self, x):\n        return self.alpha * x\n\n    def compute_output_spec(self, x):\n        return keras_tensor.KerasTensor(x.shape, x.dtype)\n\n\nclass OpWithCustomConstructorGetConfig(operation.Operation):\n    def __init__(self, alpha, *, name=None):\n        super().__init__(name=name)\n        self.alpha = alpha\n\n    def call(self, x):\n        return self.alpha * x\n\n    def compute_output_spec(self, x):\n        return keras_tensor.KerasTensor(x.shape, x.dtype)\n\n    def get_config(self):\n        return {**super().get_config(), \"alpha\": self.alpha}\n\n\nclass OpWithKwargsInConstructorGetConfig(operation.Operation):\n    def __init__(self, alpha, **kwargs):\n        super().__init__(**kwargs)\n        self.alpha = alpha\n\n    def call(self, x):\n        return self.alpha * x\n\n    def compute_output_spec(self, x):\n        return keras_tensor.KerasTensor(x.shape, x.dtype)\n\n    def get_config(self):\n        return {**super().get_config(), \"alpha\": self.alpha}\n\n\nclass OperationTest(testing.TestCase):\n    def test_symbolic_call(self):\n        x = keras_tensor.KerasTensor(shape=(2, 3), name=\"x\")\n        y = keras_tensor.KerasTensor(shape=(2, 3), name=\"y\")\n        z = keras_tensor.KerasTensor(shape=(2, 3), name=\"z\")\n\n        # Positional arguments\n        op = OpWithMultipleInputs(name=\"test_op\")\n        self.assertEqual(op.name, \"test_op\")\n        out = op(x, y, z)\n        self.assertIsInstance(out, keras_tensor.KerasTensor)\n        self.assertEqual(out.shape, (2, 3))\n        self.assertEqual(len(op._inbound_nodes), 1)\n        self.assertEqual(op.input, [x, y, z])\n        self.assertEqual(op.output, out)\n\n        # Keyword arguments\n        op = OpWithMultipleInputs(name=\"test_op\")\n        out = op(x=x, y=y, 
z=z)\n        self.assertIsInstance(out, keras_tensor.KerasTensor)\n        self.assertEqual(out.shape, (2, 3))\n        self.assertEqual(len(op._inbound_nodes), 1)\n        self.assertEqual(op.input, [x, y, z])\n        self.assertEqual(op.output, out)\n\n        # Mix\n        op = OpWithMultipleInputs(name=\"test_op\")\n        out = op(x, y=y, z=z)\n        self.assertIsInstance(out, keras_tensor.KerasTensor)\n        self.assertEqual(out.shape, (2, 3))\n        self.assertEqual(len(op._inbound_nodes), 1)\n        self.assertEqual(op.input, [x, y, z])\n        self.assertEqual(op.output, out)\n\n        # Test op reuse\n        prev_out = out\n        out = op(x, y=y, z=z)\n        self.assertIsInstance(out, keras_tensor.KerasTensor)\n        self.assertEqual(out.shape, (2, 3))\n        self.assertEqual(len(op._inbound_nodes), 2)\n        self.assertEqual(op.output, prev_out)\n\n        # Test multiple outputs\n        op = OpWithMultipleOutputs()\n        out = op(x)\n        self.assertIsInstance(out, tuple)\n        self.assertEqual(len(out), 2)\n        self.assertIsInstance(out[0], keras_tensor.KerasTensor)\n        self.assertIsInstance(out[1], keras_tensor.KerasTensor)\n        self.assertEqual(out[0].shape, (2, 3))\n        self.assertEqual(out[1].shape, (2, 3))\n        self.assertEqual(len(op._inbound_nodes), 1)\n        self.assertEqual(op.output, list(out))\n\n    def test_eager_call(self):\n        x = knp.ones((2, 3))\n        y = knp.ones((2, 3))\n        z = knp.ones((2, 3))\n        op = OpWithMultipleInputs(name=\"test_op\")\n        self.assertEqual(op.name, \"test_op\")\n\n        # Positional arguments\n        out = op(x, y, z)\n        self.assertTrue(backend.is_tensor(out))\n        self.assertAllClose(out, 6 * np.ones((2, 3)))\n\n        # Keyword arguments\n        out = op(x=x, y=y, z=z)\n        self.assertTrue(backend.is_tensor(out))\n        self.assertAllClose(out, 6 * np.ones((2, 3)))\n\n        # Mixed arguments\n        out = 
op(x, y=y, z=z)\n        self.assertTrue(backend.is_tensor(out))\n        self.assertAllClose(out, 6 * np.ones((2, 3)))\n\n        # Test multiple outputs\n        op = OpWithMultipleOutputs()\n        out = op(x)\n        self.assertEqual(len(out), 2)\n        self.assertTrue(backend.is_tensor(out[0]))\n        self.assertTrue(backend.is_tensor(out[1]))\n        self.assertAllClose(out[0], np.ones((2, 3)))\n        self.assertAllClose(out[1], np.ones((2, 3)) + 1)\n\n    def test_serialization_with_default_init_and_get_config(self):\n        # Explicit name passed in constructor is serialized and deserialized.\n        op = OpWithMultipleInputs(name=\"test_op\")\n        config = op.get_config()\n        self.assertEqual(config, {\"name\": \"test_op\"})\n        revived = OpWithMultipleInputs.from_config(config)\n        self.assertEqual(revived.get_config(), config)\n        self.assertEqual(revived.name, op.name)\n\n        # Auto generated name is serialized and deserialized.\n        op = OpWithMultipleInputs()\n        config = op.get_config()\n        self.assertEqual(config, {\"name\": op.name})\n        revived = OpWithMultipleInputs.from_config(config)\n        self.assertEqual(revived.get_config(), config)\n        self.assertEqual(revived.name, op.name)\n\n    def test_serialization_custom_constructor_with_name_auto_config(self):\n        # Explicit name passed in constructor is serialized and deserialized.\n        op = OpWithCustomConstructor(alpha=0.2, name=\"test_op\")\n        config = op.get_config()\n        self.assertEqual(config, {\"alpha\": 0.2, \"beta\": 1.0, \"name\": \"test_op\"})\n        revived = OpWithCustomConstructor.from_config(config)\n        self.assertEqual(revived.get_config(), config)\n        self.assertEqual(revived.name, op.name)\n\n        # Auto generated name is serialized and deserialized.\n        op = OpWithCustomConstructor(alpha=0.2, beta=0.0)\n        config = op.get_config()\n        self.assertEqual(config, 
{\"alpha\": 0.2, \"beta\": 0.0, \"name\": op.name})\n        revived = OpWithCustomConstructor.from_config(config)\n        self.assertEqual(revived.get_config(), config)\n        self.assertEqual(revived.name, op.name)\n\n    def test_serialization_custom_constructor_with_no_name_auto_config(self):\n        # Auto generated name is not serialized.\n        op = OpWithCustomConstructorNoName(alpha=0.2)\n        config = op.get_config()\n        self.assertEqual(config, {\"alpha\": 0.2, \"beta\": 1.0})\n        revived = OpWithCustomConstructorNoName.from_config(config)\n        self.assertEqual(revived.get_config(), config)\n\n    def test_serialization_custom_constructor_with_kwargs_auto_config(self):\n        # Explicit name passed in constructor is serialized and deserialized.\n        op = OpWithKwargsInConstructor(alpha=0.2, name=\"test_op\")\n        config = op.get_config()\n        self.assertEqual(config, {\"alpha\": 0.2, \"beta\": 1.0, \"name\": \"test_op\"})\n        revived = OpWithKwargsInConstructor.from_config(config)\n        self.assertEqual(revived.get_config(), config)\n        self.assertEqual(revived.name, op.name)\n\n        # Auto generated name is serialized and deserialized.\n        op = OpWithKwargsInConstructor(alpha=0.2, beta=0.0)\n        config = op.get_config()\n        self.assertEqual(config, {\"alpha\": 0.2, \"beta\": 0.0, \"name\": op.name})\n        revived = OpWithKwargsInConstructor.from_config(config)\n        self.assertEqual(revived.get_config(), config)\n        self.assertEqual(revived.name, op.name)\n\n    def test_failing_serialization_non_serializable_auto_config(\n        self,\n    ):\n        class NonSerializable:\n            pass\n\n        # Custom class cannot be automatically serialized.\n        op = OpWithCustomConstructor(alpha=NonSerializable(), name=\"test_op\")\n        with self.assertRaises(NotImplementedError):\n            _ = op.get_config()\n\n    def 
test_failing_serialization_custom_constructor_with_args_auto_config(\n        self,\n    ):\n        # Custom constructor with variadic args cannot be automatically\n        # serialized.\n        op = OpWithArgsInConstructor(0.2, \"a\", \"b\", \"c\", name=\"test_op\")\n        with self.assertRaises(NotImplementedError):\n            _ = op.get_config()\n\n    def test_serialization_custom_constructor_custom_get_config(self):\n        # Explicit name passed in constructor is serialized and deserialized.\n        op = OpWithCustomConstructorGetConfig(alpha=0.2, name=\"test_op\")\n        config = op.get_config()\n        self.assertEqual(config, {\"alpha\": 0.2, \"name\": \"test_op\"})\n        revived = OpWithCustomConstructorGetConfig.from_config(config)\n        self.assertEqual(revived.get_config(), config)\n        self.assertEqual(revived.name, op.name)\n\n        # Auto generated name is serialized and deserialized.\n        op = OpWithCustomConstructorGetConfig(alpha=0.2)\n        config = op.get_config()\n        self.assertEqual(config, {\"alpha\": 0.2, \"name\": op.name})\n        revived = OpWithCustomConstructorGetConfig.from_config(config)\n        self.assertEqual(revived.get_config(), config)\n        self.assertEqual(revived.name, op.name)\n\n    def test_serialization_custom_constructor_with_kwargs_custom_get_config(\n        self,\n    ):\n        # Explicit name passed in constructor is serialized and deserialized.\n        op = OpWithKwargsInConstructorGetConfig(alpha=0.2, name=\"test_op\")\n        config = op.get_config()\n        self.assertEqual(config, {\"alpha\": 0.2, \"name\": \"test_op\"})\n        revived = OpWithKwargsInConstructorGetConfig.from_config(config)\n        self.assertEqual(revived.get_config(), config)\n        self.assertEqual(revived.name, op.name)\n\n        # Auto generated name is serialized and deserialized.\n        op = OpWithKwargsInConstructorGetConfig(alpha=0.2)\n        config = op.get_config()\n        
self.assertEqual(config, {\"alpha\": 0.2, \"name\": op.name})\n        revived = OpWithKwargsInConstructorGetConfig.from_config(config)\n        self.assertEqual(revived.get_config(), config)\n        self.assertEqual(revived.name, op.name)\n\n    @skip_if_backend(\n        \"openvino\", \"Can not constant fold eltwise node by CPU plugin\"\n    )\n    def test_input_conversion(self):\n        x = np.ones((2,))\n        y = np.ones((2,))\n        z = knp.ones((2,))  # mix\n        if backend.backend() == \"torch\":\n            z = z.cpu()\n        op = OpWithMultipleInputs()\n        out = op(x, y, z)\n        self.assertTrue(backend.is_tensor(out))\n        self.assertAllClose(out, 6 * np.ones((2,)))\n\n    def test_valid_naming(self):\n        OpWithMultipleOutputs(name=\"test_op\")\n\n        with self.assertRaisesRegex(\n            ValueError, \"must be a string and cannot contain character `/`.\"\n        ):\n            OpWithMultipleOutputs(name=\"test/op\")\n"
  },
  {
    "path": "keras/src/ops/operation_utils.py",
    "content": "import math\n\nimport numpy as np\n\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common.backend_utils import canonicalize_axis\nfrom keras.src.backend.common.backend_utils import to_tuple_or_list\n\n\ndef broadcast_shapes(shape1, shape2):\n    \"\"\"Broadcast input shapes to a unified shape.\n\n    Convert to list for mutability.\n\n    Args:\n        shape1: A tuple or list of integers.\n        shape2: A tuple or list of integers.\n\n    Returns:\n        output_shape (list of integers or `None`): The broadcasted shape.\n\n    Example:\n    >>> broadcast_shapes((5, 3), (1, 3))\n    [5, 3]\n    \"\"\"\n    shape1 = list(shape1)\n    shape2 = list(shape2)\n    origin_shape1 = shape1\n    origin_shape2 = shape2\n\n    if len(shape1) > len(shape2):\n        shape2 = [1] * (len(shape1) - len(shape2)) + shape2\n    if len(shape1) < len(shape2):\n        shape1 = [1] * (len(shape2) - len(shape1)) + shape1\n    output_shape = list(shape1)\n    for i in range(len(shape1)):\n        if shape1[i] == 1:\n            output_shape[i] = shape2[i]\n        elif shape1[i] is None:\n            output_shape[i] = None if shape2[i] == 1 else shape2[i]\n        else:\n            if shape2[i] == 1 or shape2[i] is None or shape2[i] == shape1[i]:\n                output_shape[i] = shape1[i]\n            else:\n                raise ValueError(\n                    \"Cannot broadcast shape, the failure dim has value \"\n                    f\"{shape1[i]}, which cannot be broadcasted to {shape2[i]}. 
\"\n                    f\"Input shapes are: {origin_shape1} and {origin_shape2}.\"\n                )\n\n    return output_shape\n\n\ndef compute_expand_dims_output_shape(input_shape, axis):\n    \"\"\"Compute the output shape for the `expand_dims` operation.\n\n    Args:\n        input_shape: Input shape.\n        axis: int or sequence of ints for the axis to expand.\n\n    Returns:\n        Tuple of ints: The output shape after the `expand_dims` operation.\n    \"\"\"\n    input_shape = list(input_shape)\n    if axis is None:\n        axis = len(input_shape)\n    axis = to_tuple_or_list(axis)\n    out_ndim = len(axis) + len(input_shape)\n    axis = [canonicalize_axis(a, out_ndim) for a in axis]\n    shape_iter = iter(input_shape)\n    new_shape = [\n        1 if ax in axis else next(shape_iter) for ax in range(out_ndim)\n    ]\n    return tuple(new_shape)\n\n\ndef compute_pooling_output_shape(\n    input_shape,\n    pool_size,\n    strides,\n    padding=\"valid\",\n    data_format=\"channels_last\",\n):\n    \"\"\"Computes the output shape of pooling operations.\n\n    Args:\n        input_shape: Input shape. Must be a tuple of integers.\n        pool_size: Size of the pooling operation. Must be a tuple of integers.\n        strides: Stride of the pooling operation. Must be a tuple of integers.\n            Defaults to `pool_size`.\n        padding: Padding method. Available methods are `\"valid\"` or `\"same\"`.\n            Defaults to `\"valid\"`.\n        data_format: String, either `\"channels_last\"` or `\"channels_first\"`.\n            The ordering of the dimensions in the inputs. `\"channels_last\"`\n            corresponds to inputs with shape `(batch, height, width, channels)`\n            while `\"channels_first\"` corresponds to inputs with shape\n            `(batch, channels, height, width)`. 
Defaults to `\"channels_last\"`.\n\n    Returns:\n        Tuple of ints: The output shape of the pooling operation.\n\n    Examples:\n\n    # Basic usage with square pooling on a single image\n    >>> compute_pooling_output_shape((1, 4, 4, 1), (2, 2))\n    (1, 2, 2, 1)\n\n    # Strided pooling on a single image with strides different from pool_size\n    >>> compute_pooling_output_shape((1, 4, 4, 1), (2, 2), strides=(1, 1))\n    (1, 3, 3, 1)\n\n    # Pooling on a batch of images\n    >>> compute_pooling_output_shape((32, 4, 4, 3), (2, 2))\n    (32, 2, 2, 3)\n    \"\"\"\n    strides = pool_size if strides is None else strides\n    input_shape_origin = list(input_shape)\n    input_shape = np.array(input_shape)\n    if data_format == \"channels_last\":\n        spatial_shape = input_shape[1:-1]\n    else:\n        spatial_shape = input_shape[2:]\n    none_dims = []\n    for i in range(len(spatial_shape)):\n        if spatial_shape[i] is None:\n            # Set `None` shape to a manual value so that we can run numpy\n            # computation on `spatial_shape`.\n            spatial_shape[i] = -1\n            none_dims.append(i)\n    pool_size = np.array(pool_size)\n    if padding == \"valid\":\n        output_spatial_shape = (\n            np.floor((spatial_shape - pool_size) / strides) + 1\n        )\n        for i in range(len(output_spatial_shape)):\n            if i not in none_dims and output_spatial_shape[i] < 0:\n                raise ValueError(\n                    \"Computed output size would be negative. Received: \"\n                    f\"`inputs.shape={input_shape}` and `pool_size={pool_size}`.\"\n                )\n    elif padding == \"same\":\n        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1\n    else:\n        raise ValueError(\n            \"Argument `padding` must be either 'valid' or 'same'. 
Received: \"\n            f\"padding={padding}\"\n        )\n    output_spatial_shape = [int(i) for i in output_spatial_shape]\n    for i in none_dims:\n        output_spatial_shape[i] = None\n    output_spatial_shape = tuple(output_spatial_shape)\n    if data_format == \"channels_last\":\n        output_shape = (\n            (input_shape_origin[0],)\n            + output_spatial_shape\n            + (input_shape_origin[-1],)\n        )\n    else:\n        output_shape = (\n            input_shape_origin[0],\n            input_shape_origin[1],\n        ) + output_spatial_shape\n    return output_shape\n\n\ndef compute_conv_output_shape(\n    input_shape,\n    filters,\n    kernel_size,\n    strides=1,\n    padding=\"valid\",\n    data_format=\"channels_last\",\n    dilation_rate=1,\n):\n    \"\"\"Compute the output shape of conv ops.\"\"\"\n    if data_format == \"channels_last\":\n        spatial_shape = input_shape[1:-1]\n        kernel_shape = kernel_size + (input_shape[-1], filters)\n    else:\n        spatial_shape = input_shape[2:]\n        kernel_shape = kernel_size + (input_shape[1], filters)\n    if len(kernel_shape) != len(input_shape):\n        raise ValueError(\n            \"Kernel shape must have the same length as input, but received \"\n            f\"kernel of shape {kernel_shape} and \"\n            f\"input of shape {input_shape}.\"\n        )\n    if isinstance(dilation_rate, int):\n        dilation_rate = (dilation_rate,) * len(spatial_shape)\n    if isinstance(strides, int):\n        strides = (strides,) * len(spatial_shape)\n    if len(dilation_rate) != len(spatial_shape):\n        raise ValueError(\n            \"Dilation must be None, scalar or tuple/list of length of \"\n            \"inputs' spatial shape, but received \"\n            f\"`dilation_rate={dilation_rate}` and \"\n            f\"input of shape {input_shape}.\"\n        )\n    none_dims = []\n    spatial_shape = np.array(spatial_shape)\n    for i in 
range(len(spatial_shape)):\n        if spatial_shape[i] is None:\n            # Set `None` shape to a manual value so that we can run numpy\n            # computation on `spatial_shape`.\n            spatial_shape[i] = -1\n            none_dims.append(i)\n\n    kernel_spatial_shape = np.array(kernel_shape[:-2])\n    dilation_rate = np.array(dilation_rate)\n    if padding == \"valid\":\n        output_spatial_shape = (\n            np.floor(\n                (spatial_shape - dilation_rate * (kernel_spatial_shape - 1) - 1)\n                / strides\n            )\n            + 1\n        )\n        for i in range(len(output_spatial_shape)):\n            if i not in none_dims and output_spatial_shape[i] <= 0:\n                raise ValueError(\n                    \"Computed output size would be zero or negative. Received \"\n                    f\"`inputs shape={input_shape}`, \"\n                    f\"`kernel shape={kernel_shape}`, \"\n                    f\"`dilation_rate={dilation_rate}`, \"\n                    f\"`strides={strides}`, \"\n                    f\"`padding={padding}`.\"\n                )\n\n    elif padding in (\"same\", \"causal\"):\n        output_spatial_shape = np.floor((spatial_shape - 1) / strides) + 1\n    else:\n        raise ValueError(\n            \"`padding` must be either `'valid'`, `'same'` or `'causal'`. 
Received \"\n            f\"{padding}.\"\n        )\n    output_spatial_shape = [int(i) for i in output_spatial_shape]\n    for i in none_dims:\n        output_spatial_shape[i] = None\n    output_spatial_shape = tuple(output_spatial_shape)\n    if data_format == \"channels_last\":\n        output_shape = (\n            (input_shape[0],) + output_spatial_shape + (kernel_shape[-1],)\n        )\n    else:\n        output_shape = (input_shape[0], kernel_shape[-1]) + output_spatial_shape\n    return output_shape\n\n\ndef compute_matmul_output_shape(shape1, shape2):\n    \"\"\"Compute the output shape of a `matmul` operation.\n\n    Args:\n        shape1: Shape of the left operand.\n        shape2: Shape of the right operand.\n\n    Returns:\n        Tuple of ints: The output shape for the `matmul` operation.\n    \"\"\"\n    if len(shape1) == 1:\n        shape1 = (1, shape1[0])\n    if len(shape2) == 1:\n        shape2 = (shape2[0], 1)\n    if (\n        shape1[-1] is not None\n        and shape2[-2] is not None\n        and shape1[-1] != shape2[-2]\n    ):\n        raise ValueError(\n            \"Inner dimensions (`x1.shape[-1]` and `x2.shape[-2]`) must be \"\n            f\"equal, but received `x1.shape={shape1}` and \"\n            f\"`x2.shape={shape2}`.\"\n        )\n\n    leading_shape = broadcast_shapes(shape1[:-2], shape2[:-2])\n    last_2_dims_shape = [shape1[-2], shape2[-1]]\n    output_shape = leading_shape + last_2_dims_shape\n    if len(shape1) == 1:\n        del output_shape[-2]\n    if len(shape2) == 1:\n        del output_shape[-1]\n    return tuple(output_shape)\n\n\ndef compute_reshape_output_shape(input_shape, newshape, newshape_arg_name):\n    \"\"\"Converts `-1` in `newshape` to either an actual dimension or `None`.\n\n    This utility does not special case the 0th dimension (batch size).\n    \"\"\"\n    unknown_dim_count = newshape.count(-1)\n    if unknown_dim_count > 1:\n        raise ValueError(\n            \"There must be at most one unknown 
dimension (-1) in \"\n            f\"{newshape_arg_name}. Received: {newshape_arg_name}={newshape}.\"\n        )\n\n    # If there is a None in input_shape, we can't infer what the -1 is\n    if None in input_shape:\n        return tuple(dim if dim != -1 else None for dim in newshape)\n\n    input_size = math.prod(input_shape)\n    # If the `newshape` is fully defined, return it\n    if unknown_dim_count == 0:\n        if input_size != math.prod(newshape):\n            raise ValueError(\n                \"The total size of the tensor must be unchanged. Received: \"\n                f\"input_shape={input_shape}, {newshape_arg_name}={newshape}\"\n            )\n        return newshape\n\n    # We have one -1 in `newshape`, compute the actual value\n    known_output_size = 1\n    unknown_dim_index = None\n    for index, dim in enumerate(newshape):\n        if dim == -1:\n            unknown_dim_index = index\n        else:\n            known_output_size *= dim\n\n    if known_output_size == 0 or input_size % known_output_size != 0:\n        raise ValueError(\n            \"The total size of the tensor must be unchanged, however, the \"\n            \"input size cannot be divided by the specified dimensions in \"\n            f\"{newshape_arg_name}. 
Received: input_shape={input_shape}, \"\n            f\"{newshape_arg_name}={newshape}\"\n        )\n\n    output_shape = list(newshape)\n    output_shape[unknown_dim_index] = input_size // known_output_size\n    return tuple(output_shape)\n\n\ndef compute_transpose_output_shape(input_shape, axes):\n    \"\"\"Compute the output shape for the `transpose` operation.\n\n    Args:\n        input_shape: Input shape.\n        axes: Permutation of the dimensions for the `transpose` operation.\n\n    Returns:\n        Tuple of ints: The output shape after the `transpose` operation.\n    \"\"\"\n    input_shape = list(input_shape)\n    if axes is None:\n        return tuple(input_shape[::-1])\n\n    if len(axes) != len(input_shape):\n        raise ValueError(\n            \"axes must be a list of the same length as the input shape, \"\n            f\"expected {len(input_shape)}, but received {len(axes)}.\"\n        )\n    return tuple(input_shape[ax] for ax in axes)\n\n\ndef compute_take_along_axis_output_shape(input_shape, indices_shape, axis):\n    input_shape = list(input_shape)\n    indices_shape = list(indices_shape)\n    if axis is None:\n        input_shape = (\n            [None] if None in input_shape else [int(np.prod(input_shape))]\n        )\n\n    if len(input_shape) != len(indices_shape):\n        raise ValueError(\n            \"`x` and `indices` must have the same number of dimensions, \"\n            f\"but received shapes {input_shape} and {indices_shape}.\"\n        )\n\n    input_shape[axis] = indices_shape[axis]\n    output_shape = broadcast_shapes(input_shape, indices_shape)\n    return output_shape\n\n\ndef reduce_shape(shape, axis=None, keepdims=False):\n    shape = list(shape)\n    if axis is None:\n        if keepdims:\n            return tuple([1 for _ in shape])\n        else:\n            return tuple([])\n    elif isinstance(axis, int):\n        axis = (axis,)\n\n    axis = tuple(canonicalize_axis(a, len(shape)) for a in axis)\n\n    if 
keepdims:\n        for ax in axis:\n            shape[ax] = 1\n        return tuple(shape)\n    else:\n        for ax in sorted(axis, reverse=True):\n            del shape[ax]\n        return tuple(shape)\n\n\n@keras_export(\"keras.utils.get_source_inputs\")\ndef get_source_inputs(tensor):\n    \"\"\"Returns the list of input tensors necessary to compute `tensor`.\n\n    Output will always be a list of tensors\n    (potentially with 1 element).\n\n    Args:\n        tensor: The tensor to start from.\n\n    Returns:\n        List of input tensors.\n    \"\"\"\n    if not hasattr(tensor, \"_keras_history\"):\n        return tensor\n\n    operation, node_index, _ = tensor._keras_history\n    if not operation or not operation._inbound_nodes:\n        return [tensor]\n    else:\n        node = operation._inbound_nodes[node_index]\n        if node.is_input:\n            # Reached input node, stop recursion.\n            return tree.flatten(node.output_tensors)\n        else:\n            source_tensors = []\n            for tensor in node.input_tensors:\n                previous_sources = get_source_inputs(tensor)\n                # Avoid input redundancy.\n                for x in previous_sources:\n                    if all(x is not t for t in source_tensors):\n                        source_tensors.append(x)\n            return source_tensors\n"
  },
  {
    "path": "keras/src/ops/operation_utils_test.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.layers.core import input_layer\nfrom keras.src.ops import operation_utils\n\n\nclass OperationUtilsTest(testing.TestCase):\n    def test_get_source_inputs(self):\n        x1 = backend.KerasTensor(shape=(2,))\n        x2 = backend.KerasTensor(shape=(2,))\n        x = x1 + x2\n        x += 2\n        x = ops.square(x)\n        self.assertEqual(operation_utils.get_source_inputs(x), [x1, x2])\n\n    def test_get_source_inputs_return_input_tensor(self):\n        inputs = input_layer.Input(shape=(10,))\n        self.assertIs(operation_utils.get_source_inputs(inputs)[0], inputs)\n\n    def test_compute_expand_dims_output_shape(self):\n        input_shape = (2, 3, 4)\n        axis = -1\n        output_shape = operation_utils.compute_expand_dims_output_shape(\n            input_shape, axis\n        )\n        expected_output_shape = (2, 3, 4, 1)\n        self.assertEqual(output_shape, expected_output_shape)\n\n        input_shape = (2, 3, 4)\n        axis = (1, -1)\n        output_shape = operation_utils.compute_expand_dims_output_shape(\n            input_shape, axis\n        )\n        expected_output_shape = (2, 1, 3, 4, 1)\n        self.assertEqual(output_shape, expected_output_shape)\n\n    def test_compute_pooling_output_shape(self):\n        input_shape = (1, 4, 4, 1)\n        pool_size = (2, 2)\n        strides = (2, 2)\n        output_shape = operation_utils.compute_pooling_output_shape(\n            input_shape, pool_size, strides\n        )\n        expected_output_shape = (1, 2, 2, 1)\n        self.assertEqual(output_shape, expected_output_shape)\n\n    def test_compute_pooling_output_shape_with_none(self):\n        input_shape = (None, 4, 4, 1)\n        pool_size = (2, 2)\n        strides = (2, 2)\n        output_shape = operation_utils.compute_pooling_output_shape(\n            input_shape, pool_size, strides\n        )\n        
expected_output_shape = (None, 2, 2, 1)\n        self.assertEqual(output_shape, expected_output_shape)\n\n    def test_compute_pooling_output_shape_valid_padding(self):\n        input_shape = (1, 4, 4, 1)\n        pool_size = (2, 2)\n        strides = (2, 2)\n        output_shape = operation_utils.compute_pooling_output_shape(\n            input_shape, pool_size, strides, padding=\"valid\"\n        )\n        self.assertEqual(output_shape, (1, 2, 2, 1))\n\n    def test_compute_pooling_output_shape_channels_last(self):\n        input_shape = (1, 4, 4, 3)\n        pool_size = (2, 2)\n        strides = (2, 2)\n        output_shape = operation_utils.compute_pooling_output_shape(\n            input_shape,\n            pool_size,\n            strides,\n            padding=\"valid\",\n            data_format=\"channels_last\",\n        )\n        self.assertEqual(output_shape, (1, 2, 2, 3))\n\n    def test_compute_pooling_output_shape_same_padding_stride1(self):\n        input_shape = (1, 4, 4, 3)\n        pool_size = (2, 2)\n        strides = (1, 1)\n        output_shape = operation_utils.compute_pooling_output_shape(\n            input_shape,\n            pool_size,\n            strides,\n            padding=\"same\",\n            data_format=\"channels_last\",\n        )\n        self.assertEqual(output_shape, (1, 4, 4, 3))\n\n    def test_compute_conv_output_shape(self):\n        input_shape = (1, 4, 4, 1)\n        filters = 1\n        kernel_size = (3, 3)\n        strides = (1, 1)\n        output_shape = operation_utils.compute_conv_output_shape(\n            input_shape, filters, kernel_size, strides\n        )\n        expected_output_shape = (1, 2, 2, 1)\n        self.assertEqual(output_shape, expected_output_shape)\n\n    def test_compute_conv_output_shape_with_none(self):\n        input_shape = (None, 4, 4, 1)\n        kernel_size = (3, 3)\n        filters = 1\n        strides = (1, 1)\n        output_shape = operation_utils.compute_conv_output_shape(\n          
  input_shape, filters, kernel_size, strides\n        )\n        expected_output_shape = (None, 2, 2, 1)\n        self.assertEqual(output_shape, expected_output_shape)\n\n    def test_compute_conv_output_shape_valid_padding(self):\n        input_shape = (1, 4, 4, 1)\n        kernel_size = (3, 3)\n        filters = 1\n        strides = (2, 2)\n        output_shape = operation_utils.compute_conv_output_shape(\n            input_shape, filters, kernel_size, strides, padding=\"valid\"\n        )\n        self.assertEqual(output_shape, (1, 1, 1, 1))\n\n    def test_compute_conv_output_shape_channels_last(self):\n        input_shape = (1, 4, 4, 3)\n        kernel_size = (3, 3)\n        filters = 3\n        strides = (2, 2)\n        output_shape = operation_utils.compute_conv_output_shape(\n            input_shape,\n            filters,\n            kernel_size,\n            strides,\n            padding=\"valid\",\n            data_format=\"channels_last\",\n        )\n        self.assertEqual(output_shape, (1, 1, 1, 3))\n\n    def test_compute_conv_output_shape_same_padding_stride1(self):\n        input_shape = (1, 4, 4, 3)\n        kernel_size = (3, 3)\n        filters = 3\n        strides = (1, 1)\n        output_shape = operation_utils.compute_conv_output_shape(\n            input_shape,\n            filters,\n            kernel_size,\n            strides,\n            padding=\"same\",\n            data_format=\"channels_last\",\n        )\n        self.assertEqual(output_shape, (1, 4, 4, 3))\n\n    def test_compute_reshape_output_shape(self):\n        input_shape = (1, 4, 4, 1)\n        target_shape = (16, 1)\n        output_shape = operation_utils.compute_reshape_output_shape(\n            input_shape, newshape=target_shape, newshape_arg_name=\"New shape\"\n        )\n        self.assertEqual(output_shape, target_shape)\n\n    def test_reduce_shape_no_axes_no_keepdims(self):\n        input_shape = (1, 4, 4, 1)\n        output_shape = 
operation_utils.reduce_shape(input_shape)\n        expected_output_shape = ()\n        self.assertEqual(output_shape, expected_output_shape)\n\n    def test_reduce_shape_no_axes_with_keepdims(self):\n        input_shape = (1, 4, 4, 1)\n        output_shape = operation_utils.reduce_shape(input_shape, keepdims=True)\n        expected_output_shape = (1, 1, 1, 1)\n        self.assertEqual(output_shape, expected_output_shape)\n\n    def test_reduce_shape_single_axis_no_keepdims(self):\n        input_shape = (1, 4, 4, 1)\n        axes = [1]\n        output_shape = operation_utils.reduce_shape(input_shape, axes)\n        expected_output_shape = (1, 4, 1)\n        self.assertEqual(output_shape, expected_output_shape)\n\n    def test_reduce_shape_single_axis_with_keepdims(self):\n        input_shape = (1, 4, 4, 1)\n        axes = [1]\n        output_shape = operation_utils.reduce_shape(\n            input_shape, axes, keepdims=True\n        )\n        expected_output_shape = (1, 1, 4, 1)\n        self.assertEqual(output_shape, expected_output_shape)\n\n    def test_reduce_shape_multiple_axes_no_keepdims(self):\n        input_shape = (1, 4, 4, 1)\n        axes = [1, 2]\n        output_shape = operation_utils.reduce_shape(input_shape, axes)\n        expected_output_shape = (1, 1)\n        self.assertEqual(output_shape, expected_output_shape)\n\n    def test_reduce_shape_out_of_order_axes_no_keepdims(self):\n        input_shape = (1, 4, 4, 1)\n        axes = [2, 1]\n        output_shape = operation_utils.reduce_shape(input_shape, axes)\n        expected_output_shape = (1, 1)\n        self.assertEqual(output_shape, expected_output_shape)\n\n    def test_reduce_shape_negative_axes_no_keepdims(self):\n        input_shape = (1, 4, 4, 1)\n        axes = [-2, -3]\n        output_shape = operation_utils.reduce_shape(input_shape, axes)\n        expected_output_shape = (1, 1)\n        self.assertEqual(output_shape, expected_output_shape)\n"
  },
  {
    "path": "keras/src/ops/ops_test.py",
    "content": "import inspect\n\nfrom absl.testing import parameterized\n\ntry:\n    from keras.api import ops as api_ops_root\nexcept ImportError:\n    from keras import ops as api_ops_root\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.ops.operation import Operation\nfrom keras.src.testing.test_utils import named_product\nfrom keras.src.utils.naming import to_snake_case\n\nOPS_MODULES = (\"core\", \"image\", \"linalg\", \"math\", \"nn\", \"numpy\")\n\nSELF_PARAMETER = inspect.Parameter(\n    \"self\", inspect.Parameter.POSITIONAL_OR_KEYWORD\n)\nNAME_PARAMETER = inspect.Parameter(\n    \"name\", inspect.Parameter.KEYWORD_ONLY, default=None\n)\n\n# Parameters with these names are known to always be static (non-tensors).\nSTATIC_PARAMETER_NAMES = frozenset(\n    {\"axis\", \"axes\", \"dtype\", \"shape\", \"newshape\", \"sparse\", \"ragged\"}\n)\n\n\ndef op_functions_and_classes(ops_module):\n    \"\"\"Enumerate pairs of op function and op classes in a module.\n\n    Will return for instance `(ExpandDims, expand_dims)`, `(Sum, sum)`, ...\n\n    Args:\n        ops_module: the module to explore.\n\n    Returns:\n        iterable returning tuples with function and class pairs.\n    \"\"\"\n    # Go through all symbols.\n    for op_class_name in dir(ops_module):\n        op_class = getattr(ops_module, op_class_name)\n        # Find the ones that are classes that extend `Operation`.\n        if isinstance(op_class, type) and Operation in op_class.__mro__:\n            # Infer what the corresponding op function name should be.\n            op_function_name = to_snake_case(op_class_name)\n            # With some exceptions.\n            op_function_name = {\n                \"batch_norm\": \"batch_normalization\",\n                \"rms_norm\": \"rms_normalization\",\n                \"search_sorted\": \"searchsorted\",\n            }.get(op_function_name, op_function_name)\n            # Check if that function 
exists. Some classes are abstract super\n            # classes for multiple operations and should be ignored.\n            op_function = getattr(ops_module, op_function_name, None)\n            if op_function is not None:\n                # We have a pair, return it.\n                yield op_function, op_class\n\n\nclass OperationTest(testing.TestCase):\n    @parameterized.named_parameters(named_product(module_name=OPS_MODULES))\n    def test_class_function_consistency(self, module_name):\n        ops_module = getattr(ops, module_name)\n        if module_name in (\"core\", \"math\"):\n            # `core` and `math` are not exported as their own module.\n            api_ops_module = None\n        else:\n            api_ops_module = getattr(api_ops_root, module_name)\n\n        for op_function, op_class in op_functions_and_classes(ops_module):\n            name = op_function.__name__\n\n            # ==== Check exports ====\n            # - op should be exported as e.g. `keras.ops.numpy.sum`\n            # - op should also be exported as e.g. 
`keras.ops.sum`\n\n            if module_name != \"image\":\n                # `image` ops are not exported at the top-level.\n                self.assertIsNotNone(\n                    getattr(api_ops_root, name, None),\n                    f\"Not exported as `keras.ops.{name}`\",\n                )\n            if api_ops_module is not None:\n                # `core` and `math` are not exported as their own module.\n                self.assertIsNotNone(\n                    getattr(api_ops_module, name, None),\n                    f\"Not exported as `keras.ops.{module_name}.{name}`\",\n                )\n\n            # ==== Check handling of name in __init__ ====\n            # - op class `__init__` should have a `name` parameter at the end,\n            #   which should be keyword only and with a default value of `None`\n            # - op class `__init__` should call `super().__init__(name=name)`\n\n            if op_class.__init__ is Operation.__init__:\n                # `name` is not keyword only in `Operation`, use this instead.\n                class_init_signature = inspect.Signature(\n                    [SELF_PARAMETER, NAME_PARAMETER]\n                )\n            else:\n                class_init_signature = inspect.signature(op_class.__init__)\n\n                # Check call to super.\n                self.assertContainsSubsequence(\n                    inspect.getsource(op_class.__init__),\n                    \"super().__init__(name=name)\",\n                    f\"`{op_class.__name__}.__init__` is not calling \"\n                    \"`super().__init__(name=name)`\",\n                )\n\n            static_parameters = list(class_init_signature.parameters.values())\n            # Remove `self`.\n            static_parameters = static_parameters[1:]\n            name_index = -1\n            if static_parameters[-1].kind == inspect.Parameter.VAR_KEYWORD:\n                # When there is a `**kwargs`, `name` appears before.\n                
name_index = -2\n            # Verify `name` parameter is as expected.\n            self.assertEqual(\n                static_parameters[name_index],\n                NAME_PARAMETER,\n                f\"The last parameter of `{op_class.__name__}.__init__` \"\n                \"should be `name`, should be a keyword only, and should \"\n                \"have a default value of `None`\",\n            )\n            # Remove `name`, it's not part of the op signature.\n            static_parameters.pop(name_index)\n\n            # ==== Check static parameters ====\n            # Static parameters are declared in the class' `__init__`.\n            # Dynamic parameters are declared in the class' `call` method.\n            # - they should all appear in the op signature with the same name\n            # - they should have the same default value\n            # - they should appear in the same order and usually with the\n            #   dynamic parameters first, and the static parameters last.\n\n            dynamic_parameters = list(\n                inspect.signature(op_class.call).parameters.values()\n            )[1:]  # Remove self\n\n            op_signature = inspect.signature(op_function)\n\n            for p in dynamic_parameters + static_parameters:\n                # Check the same name appears in the op signature\n                self.assertIn(\n                    p.name,\n                    op_signature.parameters,\n                    f\"Op function `{name}` is missing a parameter that is in \"\n                    f\"op class `{op_class.__name__}`\",\n                )\n                # Check default values are the same\n                self.assertEqual(\n                    p.default,\n                    op_signature.parameters[p.name].default,\n                    f\"Default mismatch for parameter `{p.name}` between op \"\n                    f\"function `{name}` and op class `{op_class.__name__}`\",\n                )\n\n            
dynamic_parameter_names = [p.name for p in dynamic_parameters]\n            static_parameter_names = [p.name for p in static_parameters]\n\n            # Check for obvious mistakes in parameters that were made dynamic\n            # but should be static.\n            for p in dynamic_parameters:\n                self.assertNotIn(\n                    p.name,\n                    STATIC_PARAMETER_NAMES,\n                    f\"`{p.name}` should not be a dynamic parameter in op class \"\n                    f\"`{op_class.__name__}` based on its name.\",\n                )\n                self.assertNotIsInstance(\n                    p.default,\n                    (bool, str),\n                    f\"`{p.name}` should not be a dynamic parameter in op class \"\n                    f\"`{op_class.__name__}` based on default `{p.default}`.\",\n                )\n\n            # Check order of parameters.\n            if name in (\n                \"fori_loop\",\n                \"vectorized_map\",\n                \"while_loop\",\n                \"batch_normalization\",\n                \"dot_product_attention\",\n                \"average\",\n                \"einsum\",\n                \"full\",\n                \"pad\",\n            ):\n                # Loose case:\n                # order of parameters is preserved but they are interspersed.\n                op_dynamic_parameter_names = [\n                    name\n                    for name in op_signature.parameters.keys()\n                    if name in dynamic_parameter_names\n                ]\n                self.assertEqual(\n                    op_dynamic_parameter_names,\n                    dynamic_parameter_names,\n                    \"Inconsistent dynamic parameter order for op \"\n                    f\"function `{name}` and op class `{op_class.__name__}`\",\n                )\n                op_static_parameter_names = [\n                    name\n                    for name in op_signature.parameters.keys()\n                    if name in static_parameter_names\n                ]\n                self.assertEqual(\n                    op_static_parameter_names,\n                    static_parameter_names,\n                    \"Inconsistent static parameter order for op \"\n                    f\"function `{name}` and op class `{op_class.__name__}`\",\n                )\n            else:\n                # Strict case:\n                # dynamic parameters first and static parameters at the end.\n                self.assertEqual(\n                    list(op_signature.parameters.keys()),\n                    dynamic_parameter_names + static_parameter_names,\n                    \"Inconsistent static parameter position for op \"\n                    f\"function `{name}` and op class `{op_class.__name__}`\",\n                )\n\n            # ==== Check compute_output_spec is implemented ====\n            # - op class should override Operation's `compute_output_spec`\n            self.assertTrue(\n                hasattr(op_class, \"compute_output_spec\")\n                and op_class.compute_output_spec\n                is not Operation.compute_output_spec,\n                f\"Op class `{op_class.__name__}` should override \"\n                \"`compute_output_spec`\",\n            )\n\n    @parameterized.named_parameters(named_product(module_name=OPS_MODULES))\n    def test_backend_consistency(self, module_name):\n        ops_module = getattr(ops, module_name)\n        backend_ops_module = getattr(backend, module_name)\n\n        for op_function, _ in op_functions_and_classes(ops_module):\n            name = op_function.__name__\n\n            if hasattr(ops_module, f\"_{name}\"):\n                # For an op function `foo`, if there is a function named `_foo`,\n                # that means we have a backend independent implementation.\n                continue\n            if name in (\"view_as_complex\", \"view_as_real\", \"get_item\"):\n                # These ops have an inlined backend independent implementation.\n                continue\n\n            # ==== Check backend implementation ====\n            # - op should have an implementation in every backend\n            # - op implementation should have the same signature (same\n            #   parameters, same order, same defaults)\n\n            backend_op_function = getattr(backend_ops_module, name, None)\n\n            if backend.backend() == \"openvino\" and backend_op_function is None:\n                # OpenVINO is still missing a number of ops.\n                continue\n\n            self.assertIsNotNone(backend_op_function, f\"Missing op `{name}`\")\n\n            if name == \"multi_hot\":\n                # multi_hot has code to massage the input parameters before\n                # calling the backend implementation, so the signature is\n                # different on purpose.\n                continue\n\n            # Signature should match in every way.\n            self.assertEqual(\n                inspect.signature(backend_op_function),\n                inspect.signature(op_function),\n                f\"Signature mismatch for `{name}`\",\n            )\n"
  },
  {
    "path": "keras/src/ops/symbolic_arguments.py",
    "content": "from keras.src import tree\nfrom keras.src.backend import KerasTensor\n\n\nclass SymbolicArguments:\n    def __init__(self, *args, **kwargs):\n        self.args = tree.map_structure(lambda x: x, args)\n        self.kwargs = tree.map_structure(lambda x: x, kwargs)\n        self._flat_arguments = tree.flatten((self.args, self.kwargs))\n\n        # Used to avoid expensive `tree` operations in the most common case.\n        if (\n            not self.kwargs\n            and len(self.args) == 1\n            and isinstance(self.args[0], KerasTensor)\n        ):\n            self._single_positional_tensor = self.args[0]\n        else:\n            self._single_positional_tensor = None\n\n        self.keras_tensors = []\n        for arg in self._flat_arguments:\n            if isinstance(arg, KerasTensor):\n                self.keras_tensors.append(arg)\n\n    def convert(self, conversion_fn):\n        args = tree.map_structure(conversion_fn, self.args)\n        kwargs = tree.map_structure(conversion_fn, self.kwargs)\n        return args, kwargs\n\n    def fill_in(self, tensor_dict):\n        \"\"\"Maps KerasTensors to computed values using `tensor_dict`.\n\n        `tensor_dict` maps `KerasTensor` instances to their current values.\n        \"\"\"\n        if self._single_positional_tensor is not None:\n            # Performance optimization for most common case.\n            # Approx. 70x faster.\n            return (tensor_dict[id(self._single_positional_tensor)],), {}\n\n        def switch_fn(x):\n            if isinstance(x, KerasTensor):\n                return tensor_dict.get(id(x), None)\n            return x\n\n        return self.convert(switch_fn)\n"
  },
  {
    "path": "keras/src/ops/symbolic_arguments_test.py",
    "content": "from keras.src import testing\nfrom keras.src import tree\nfrom keras.src.backend import KerasTensor\nfrom keras.src.ops.symbolic_arguments import SymbolicArguments\n\n\nclass SymbolicArgumentsTest(testing.TestCase):\n    # Testing multiple args and empty kwargs\n    def test_args(self):\n        shape = (2, 3, 4)\n        a = KerasTensor(shape=shape)\n        b = KerasTensor(shape=shape)\n        args = SymbolicArguments(\n            (\n                a,\n                b,\n            ),\n            {},\n        )\n\n        self.assertEqual(args.keras_tensors, [a, b])\n        self.assertEqual(args._flat_arguments, [a, b])\n        self.assertEqual(args._single_positional_tensor, None)\n\n    # Testing single arg and single position tensor\n    def test_args_single_arg(self):\n        shape = (2, 3, 4)\n        a = KerasTensor(shape=shape)\n        args = SymbolicArguments((a))\n\n        self.assertEqual(args.keras_tensors, [a])\n        self.assertEqual(args._flat_arguments, [a])\n        self.assertEqual(len(args.kwargs), 0)\n        self.assertEqual(isinstance(args.args[0], KerasTensor), True)\n        self.assertEqual(args._single_positional_tensor, a)\n\n    # Testing kwargs\n    def test_kwargs(self):\n        shape = (2, 3, 4)\n        a = KerasTensor(shape=shape)\n        b = KerasTensor(shape=shape)\n        c = KerasTensor(shape=shape)\n        args = SymbolicArguments(\n            (\n                a,\n                b,\n            ),\n            {1: c},\n        )\n\n        self.assertEqual(args.keras_tensors, [a, b, c])\n        self.assertEqual(args._flat_arguments, [a, b, c])\n        self.assertEqual(args._single_positional_tensor, None)\n\n    # Testing conversion function with args and kwargs\n    def test_conversion_fn(self):\n        shape = (2, 3, 4)\n        a = KerasTensor(shape=shape)\n        b = KerasTensor(shape=shape)\n        c = KerasTensor(shape=shape)\n        sym_args = SymbolicArguments(\n            
(\n                a,\n                b,\n            ),\n            {1: c},\n        )\n\n        (value, _) = sym_args.convert(lambda x: x**2)\n        args1 = value[0][0]\n\n        self.assertIsInstance(args1, KerasTensor)\n\n        mapped_value = tree.map_structure(lambda x: x**2, a)\n        self.assertEqual(mapped_value.shape, args1.shape)\n        self.assertEqual(mapped_value.dtype, args1.dtype)\n\n    # Testing fill in function with single args only\n    def test_fill_in_single_arg(self):\n        shape = (2, 3, 4)\n        a = KerasTensor(shape=shape)\n\n        tensor_dict = {id(a): 3}\n        sym_args = SymbolicArguments((a))\n\n        # Call the method to be tested\n        result, _ = sym_args.fill_in(tensor_dict)\n\n        self.assertEqual(result, (3,))\n\n    # Testing fill in function with multiple args\n    def test_fill_in_multiple_arg(self):\n        shape = (2, 3, 4)\n        a = KerasTensor(shape=shape)\n        b = KerasTensor(shape=shape)\n\n        tensor_dict = {id(b): 2}\n        sym_args = SymbolicArguments((a, b))\n\n        # Call the method to be tested\n        result, _ = sym_args.fill_in(tensor_dict)\n        self.assertEqual(result, ((None, 2),))\n\n    # Testing fill in function for args and kwargs\n    def test_fill_in(self):\n        shape1 = (2, 3, 4)\n        shape2 = (3, 2, 4)\n        a = KerasTensor(shape=shape1)\n        b = KerasTensor(shape=shape2)\n        c = KerasTensor(shape=shape2)\n        dictionary = {id(a): 3, id(c): 2}\n        sym_args = SymbolicArguments(\n            (\n                a,\n                b,\n            ),\n            {\"1\": c},\n        )\n\n        (values, _) = sym_args.fill_in(dictionary)\n        self.assertEqual(values, ((3, None), {\"1\": 2}))\n"
  },
  {
    "path": "keras/src/optimizers/__init__.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.optimizers.adadelta import Adadelta\nfrom keras.src.optimizers.adafactor import Adafactor\nfrom keras.src.optimizers.adagrad import Adagrad\nfrom keras.src.optimizers.adam import Adam\nfrom keras.src.optimizers.adamax import Adamax\nfrom keras.src.optimizers.adamw import AdamW\nfrom keras.src.optimizers.ftrl import Ftrl\nfrom keras.src.optimizers.lion import Lion\nfrom keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer\nfrom keras.src.optimizers.muon import Muon\nfrom keras.src.optimizers.nadam import Nadam\nfrom keras.src.optimizers.optimizer import Optimizer\nfrom keras.src.optimizers.rmsprop import RMSprop\nfrom keras.src.optimizers.schedule_free_adamw import ScheduleFreeAdamW\nfrom keras.src.optimizers.sgd import SGD\nfrom keras.src.saving import serialization_lib\n\nALL_OBJECTS = {\n    Optimizer,\n    Adam,\n    SGD,\n    RMSprop,\n    Adadelta,\n    AdamW,\n    Adagrad,\n    Adamax,\n    Adafactor,\n    Muon,\n    Nadam,\n    Ftrl,\n    Lion,\n    LossScaleOptimizer,\n    ScheduleFreeAdamW,\n}\nALL_OBJECTS_DICT = {cls.__name__.lower(): cls for cls in ALL_OBJECTS}\n\n\n@keras_export(\"keras.optimizers.serialize\")\ndef serialize(optimizer):\n    \"\"\"Returns the optimizer configuration as a Python dict.\n\n    Args:\n        optimizer: An `Optimizer` instance to serialize.\n\n    Returns:\n        Python dict which contains the configuration of the optimizer.\n    \"\"\"\n    return serialization_lib.serialize_keras_object(optimizer)\n\n\n@keras_export(\"keras.optimizers.deserialize\")\ndef deserialize(config, custom_objects=None):\n    \"\"\"Returns a Keras optimizer object via its configuration.\n\n    Args:\n        config: Optimizer configuration dictionary.\n        custom_objects: Optional dictionary mapping names (strings) to custom\n            objects (classes and functions) to be considered during\n            deserialization.\n\n    Returns:\n        A 
Keras Optimizer instance.\n    \"\"\"\n    # Make deserialization case-insensitive for built-in optimizers.\n    if config[\"class_name\"].lower() in ALL_OBJECTS_DICT:\n        config[\"class_name\"] = config[\"class_name\"].lower()\n\n    return serialization_lib.deserialize_keras_object(\n        config,\n        module_objects=ALL_OBJECTS_DICT,\n        custom_objects=custom_objects,\n    )\n\n\n@keras_export(\"keras.optimizers.get\")\ndef get(identifier):\n    \"\"\"Retrieves a Keras Optimizer instance.\n\n    Args:\n        identifier: Optimizer identifier, one of:\n            - String: name of an optimizer\n            - Dictionary: configuration dictionary.\n            - Keras Optimizer instance (it will be returned unchanged).\n\n    Returns:\n        A Keras Optimizer instance.\n    \"\"\"\n    if identifier is None:\n        return None\n    elif isinstance(identifier, dict):\n        obj = deserialize(identifier)\n    elif isinstance(identifier, str):\n        config = {\"class_name\": identifier, \"config\": {}}\n        obj = deserialize(config)\n    else:\n        obj = identifier\n\n    if isinstance(obj, Optimizer):\n        return obj\n    raise ValueError(f\"Could not interpret optimizer identifier: {identifier}\")\n\n\n# We will add this temporarily so that tensorflow packages that depend on\n# estimators will continue to import (there are a large number). Note that\n# Keras 3 will not work with the estimators API.\n@keras_export(\n    [\n        \"keras.optimizers.legacy.Adagrad\",\n        \"keras.optimizers.legacy.Adam\",\n        \"keras.optimizers.legacy.Ftrl\",\n        \"keras.optimizers.legacy.RMSprop\",\n        \"keras.optimizers.legacy.SGD\",\n        \"keras.optimizers.legacy.Optimizer\",\n    ]\n)\nclass LegacyOptimizerWarning:\n    def __init__(self, *args, **kwargs):\n        raise ImportError(\n            \"`keras.optimizers.legacy` is not supported in Keras 3. 
When using \"\n            \"`tf.keras`, to continue using a `tf.keras.optimizers.legacy` \"\n            \"optimizer, you can install the `tf_keras` package (Keras 2) and \"\n            \"set the environment variable `TF_USE_LEGACY_KERAS=True` to \"\n            \"configure TensorFlow to use `tf_keras` when accessing `tf.keras`.\"\n        )\n"
  },
  {
    "path": "keras/src/optimizers/adadelta.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export([\"keras.optimizers.Adadelta\"])\nclass Adadelta(optimizer.Optimizer):\n    \"\"\"Optimizer that implements the Adadelta algorithm.\n\n    Adadelta optimization is a stochastic gradient descent method that is based\n    on adaptive learning rate per dimension to address two drawbacks:\n\n    - The continual decay of learning rates throughout training.\n    - The need for a manually selected global learning rate.\n\n    Adadelta is a more robust extension of Adagrad that adapts learning rates\n    based on a moving window of gradient updates, instead of accumulating all\n    past gradients. This way, Adadelta continues learning even when many updates\n    have been done. Compared to Adagrad, in the original version of Adadelta you\n    don't have to set an initial learning rate. In this version, the initial\n    learning rate can be set, as in most other Keras optimizers.\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.001`. Note that `Adadelta`\n            tends to benefit from higher initial learning rate values compared\n            to other optimizers. To match the exact form in the original paper,\n            use 1.0.\n        rho: A floating point value. The decay rate. 
Defaults to `0.95`.\n        epsilon: Small floating point value for maintaining numerical stability.\n        {{base_optimizer_keyword_args}}\n\n    Reference:\n\n    - [Zeiler, 2012](http://arxiv.org/abs/1212.5701)\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.001,\n        rho=0.95,\n        epsilon=1e-7,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"adadelta\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            name=name,\n            **kwargs,\n        )\n        self.rho = rho\n        self.epsilon = epsilon\n\n    def build(self, var_list):\n        if self.built:\n            return\n        super().build(var_list)\n        self._accumulated_grads, self._accumulated_delta_vars = (\n            self.add_optimizer_variables(\n                var_list, [\"accumulated_grad\", \"accumulated_delta_var\"]\n            )\n        )\n\n    def update_step(self, grad, variable, learning_rate):\n        \"\"\"Update step given gradient and the associated model variable.\"\"\"\n        lr = ops.cast(learning_rate, variable.dtype)\n        grad = ops.cast(grad, variable.dtype)\n\n        rho = self.rho\n        accumulated_grad = self._accumulated_grads[\n            self._get_variable_index(variable)\n        ]\n        accumulated_delta_var = 
self._accumulated_delta_vars[\n            self._get_variable_index(variable)\n        ]\n\n        def rms(x):\n            return ops.sqrt(ops.add(x, self.epsilon))\n\n        self.assign(\n            accumulated_grad,\n            ops.add(\n                rho * accumulated_grad, ops.multiply(1 - rho, ops.square(grad))\n            ),\n        )\n        delta_var = ops.negative(\n            ops.divide(\n                ops.multiply(rms(accumulated_delta_var), grad),\n                rms(accumulated_grad),\n            )\n        )\n        self.assign(\n            accumulated_delta_var,\n            ops.add(\n                ops.multiply(rho, accumulated_delta_var),\n                ops.multiply(1 - rho, ops.square(delta_var)),\n            ),\n        )\n        self.assign_add(variable, ops.multiply(lr, delta_var))\n\n    def get_config(self):\n        config = super().get_config()\n\n        config.update(\n            {\n                \"rho\": self.rho,\n                \"epsilon\": self.epsilon,\n            }\n        )\n        return config\n\n\nAdadelta.__doc__ = Adadelta.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
  {
    "path": "keras/src/optimizers/adadelta_test.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.optimizers.adadelta import Adadelta\n\n\nclass AdadeltaTest(testing.TestCase):\n    def test_config(self):\n        optimizer = Adadelta(\n            learning_rate=0.5,\n            rho=0.9,\n            epsilon=1e-5,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_single_step(self):\n        optimizer = Adadelta(learning_rate=0.5)\n        grads = ops.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(\n            vars, [0.9993, 1.9993, 2.9993, 3.9993], rtol=1e-4, atol=1e-4\n        )\n\n    def test_weight_decay(self):\n        grads, var1, var2, var3 = (\n            ops.zeros(()),\n            backend.Variable(2.0),\n            backend.Variable(2.0, name=\"exclude\"),\n            backend.Variable(2.0),\n        )\n        optimizer_1 = Adadelta(learning_rate=1.0, weight_decay=0.004)\n        optimizer_1.apply_gradients(zip([grads], [var1]))\n\n        optimizer_2 = Adadelta(learning_rate=1.0, weight_decay=0.004)\n        optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n        optimizer_3 = Adadelta(learning_rate=1.0, weight_decay=0.004)\n        optimizer_3.exclude_from_weight_decay(var_list=[var3])\n        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))\n\n        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)\n        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)\n        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)\n\n    def test_correctness_with_golden(self):\n        optimizer = Adadelta(learning_rate=1.0, rho=0.8, epsilon=1e-6)\n\n        x = backend.Variable(np.ones([10]))\n        grads = ops.arange(0.1, 1.1, 0.1)\n        
first_grads = ops.full((10,), 0.01)\n\n        golden = np.tile(\n            [[0.9978], [0.9947], [0.9915], [0.9882], [0.9849]], (1, 10)\n        )\n\n        optimizer.apply_gradients(zip([first_grads], [x]))\n        for i in range(5):\n            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            optimizer.apply_gradients(zip([grads], [x]))\n\n    def test_clip_norm(self):\n        optimizer = Adadelta(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = Adadelta(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n"
  },
  {
    "path": "keras/src/optimizers/adafactor.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export([\"keras.optimizers.Adafactor\"])\nclass Adafactor(optimizer.Optimizer):\n    \"\"\"Optimizer that implements the Adafactor algorithm.\n\n    Adafactor is commonly used in NLP tasks, and has the advantage\n    of requiring less memory because it only saves a factored summary of\n    previous gradients.\n\n    The default argument setup is based on the original paper (see reference).\n    When gradients are of dimension > 2, the Adafactor optimizer drops each of\n    the last 2 dimensions separately in its accumulator variables.\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.001`.\n        beta_2_decay: float, defaults to -0.8. The decay rate of `beta_2`.\n        epsilon_1: float, defaults to 1e-30. A small offset to keep the\n            denominator away from 0.\n        epsilon_2: float, defaults to 1e-3. A small offset to prevent the\n            learning rate from becoming too small over time.\n        clip_threshold: float, defaults to 1.0. Clipping threshold. This is\n            part of the Adafactor algorithm, independent of `clipnorm`,\n            `clipvalue`, and `global_clipnorm`.\n        relative_step: bool, defaults to `True`. If `learning_rate` is a\n            constant and `relative_step=True`, the learning rate will be\n            adjusted based on the current iteration. 
This is the default learning rate decay\n            in Adafactor.\n        {{base_optimizer_keyword_args}}\n\n    Reference:\n\n    - [Shazeer, Noam et al., 2018](https://arxiv.org/abs/1804.04235).\n\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.001,\n        beta_2_decay=-0.8,\n        epsilon_1=1e-30,\n        epsilon_2=1e-3,\n        clip_threshold=1.0,\n        relative_step=True,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"adafactor\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            name=name,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            **kwargs,\n        )\n        self.beta_2_decay = beta_2_decay\n        self.epsilon_1 = epsilon_1\n        self.epsilon_2 = epsilon_2\n        self.clip_threshold = clip_threshold\n        self.relative_step = relative_step\n\n    def build(self, var_list):\n        \"\"\"Initialize optimizer variables.\n\n        Adafactor optimizer has 3 types of variables: row sums of the squared\n        gradients (`r`), column sums of the squared gradients (`c`), and\n        velocities (`v`).\n\n        Args:\n            var_list: list of model variables to build Adafactor variables on.\n        \"\"\"\n        if self.built:\n            return\n        super().build(var_list)\n        self._r = []\n        self._c = []\n        self._v = []\n        for var in var_list:\n            if 
len(var.shape) < 2:\n                # Don't factor if variable is of dimension < 2, but we still\n                # need to create dummy variables as placeholder.\n                self._r.append(\n                    backend.Variable(0, name=var.name, trainable=False)\n                )\n                self._c.append(\n                    backend.Variable(0, name=var.name, trainable=False)\n                )\n            elif self._overwrite_variable_with_gradient(var):\n                self._r.append(None)\n                self._c.append(None)\n            else:\n                # Always factor the last 2 dimensions.\n                r_shape = var.shape[:-1]\n                c_shape = var.shape[:-2] + (var.shape[-1],)\n                self._r.append(\n                    self.add_variable(\n                        shape=r_shape,\n                        dtype=var.dtype,\n                        name=var.name,\n                    )\n                )\n                self._c.append(\n                    self.add_variable(\n                        shape=c_shape,\n                        dtype=var.dtype,\n                        name=var.name,\n                    )\n                )\n\n            if self._overwrite_variable_with_gradient(var):\n                self._v.append(None)\n            else:\n                self._v.append(\n                    self.add_variable_from_reference(\n                        reference_variable=var, name=\"velocity\"\n                    )\n                )\n\n    def _rms(self, x):\n        return ops.sqrt(ops.mean(ops.square(x)))\n\n    def update_step(self, gradient, variable, learning_rate):\n        \"\"\"Update step given gradient and the associated model variable.\"\"\"\n\n        lr = ops.cast(learning_rate, variable.dtype)\n        gradient = ops.cast(gradient, variable.dtype)\n        epsilon_2 = ops.cast(self.epsilon_2, variable.dtype)\n        one = ops.cast(1.0, variable.dtype)\n        local_step = 
ops.cast(self.iterations + 1, variable.dtype)\n        if not callable(self._learning_rate) and self.relative_step:\n            lr = ops.minimum(lr, 1 / ops.sqrt(local_step))\n\n        r = self._r[self._get_variable_index(variable)]\n        c = self._c[self._get_variable_index(variable)]\n        v = self._v[self._get_variable_index(variable)]\n\n        rho_t = ops.minimum(lr, 1 / ops.sqrt(local_step))\n        alpha_t = ops.maximum(epsilon_2, self._rms(variable)) * rho_t\n        regulated_grad_square = ops.add(ops.square(gradient), self.epsilon_1)\n        beta_2_t = ops.subtract(1, ops.power(local_step, self.beta_2_decay))\n\n        if len(variable.shape) >= 2:\n            # `r` deletes the last dimension of gradient, so it is of shape\n            # `gradient.shape[:-1]`.\n            self.assign(\n                r,\n                ops.add(\n                    ops.multiply(beta_2_t, r),\n                    ops.multiply(\n                        ops.subtract(1, beta_2_t),\n                        ops.mean(regulated_grad_square, axis=-1),\n                    ),\n                ),\n            )\n            # `c` deletes the second last dimension of gradient, so it is of\n            # shape `gradient.shape[:-2] + gradient.shape[-1]`.\n            self.assign(\n                c,\n                ops.add(\n                    ops.multiply(beta_2_t, c),\n                    ops.multiply(\n                        ops.subtract(1, beta_2_t),\n                        ops.mean(regulated_grad_square, axis=-2),\n                    ),\n                ),\n            )\n            self.assign(\n                v,\n                ops.multiply(\n                    ops.expand_dims(\n                        ops.divide(r, ops.mean(r, axis=-1, keepdims=True)),\n                        axis=-1,\n                    ),\n                    ops.expand_dims(c, -2),\n                ),\n            )\n        else:\n            self.assign(\n                v,\n      
          ops.add(\n                    ops.multiply(beta_2_t, v),\n                    ops.multiply(\n                        ops.subtract(1, beta_2_t), regulated_grad_square\n                    ),\n                ),\n            )\n\n        u_t = ops.divide(gradient, ops.sqrt(v))\n        u_t_hat = ops.divide(\n            u_t,\n            ops.maximum(one, ops.divide(self._rms(u_t), self.clip_threshold)),\n        )\n        self.assign_sub(variable, ops.multiply(alpha_t, u_t_hat))\n\n    def get_config(self):\n        config = super().get_config()\n\n        config.update(\n            {\n                \"beta_2_decay\": self.beta_2_decay,\n                \"epsilon_1\": self.epsilon_1,\n                \"epsilon_2\": self.epsilon_2,\n                \"clip_threshold\": self.clip_threshold,\n                \"relative_step\": self.relative_step,\n            }\n        )\n        return config\n\n\nAdafactor.__doc__ = Adafactor.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
  {
    "path": "keras/src/optimizers/adafactor_test.py",
    "content": "# flake8: noqa\n\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.optimizers.adafactor import Adafactor\n\n\nclass AdafactorTest(testing.TestCase):\n    def test_config(self):\n        optimizer = Adafactor(\n            learning_rate=0.5,\n            beta_2_decay=-0.65,\n            epsilon_1=1e-15,\n            epsilon_2=1e-4,\n            clip_threshold=0.9,\n            relative_step=False,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_single_step_1d(self):\n        optimizer = Adafactor(learning_rate=0.5)\n        grads = np.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(\n            vars, [-0.3693, 0.6307, 1.6307, 2.6307], rtol=1e-4, atol=1e-4\n        )\n\n    def test_single_step_2d(self):\n        optimizer = Adafactor(learning_rate=0.5)\n        grads = np.array([[1.0, 6.0], [7.0, 2.0]])\n        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(\n            vars, [[0.7007, -0.0081], [1.2492, 3.4407]], rtol=1e-4, atol=1e-4\n        )\n\n    def test_weight_decay(self):\n        grads, var1, var2, var3 = (\n            np.zeros(()),\n            backend.Variable(2.0),\n            backend.Variable(2.0, name=\"exclude\"),\n            backend.Variable(2.0),\n        )\n        optimizer_1 = Adafactor(learning_rate=1.0, weight_decay=0.004)\n        optimizer_1.apply_gradients(zip([grads], [var1]))\n\n        optimizer_2 = Adafactor(learning_rate=1.0, weight_decay=0.004)\n        optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n        optimizer_3 = Adafactor(learning_rate=1.0, weight_decay=0.004)\n        optimizer_3.exclude_from_weight_decay(var_list=[var3])\n  
      optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))\n\n        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)\n        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)\n        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)\n\n    def test_correctness_with_golden(self):\n        optimizer = Adafactor(\n            learning_rate=0.5,\n            beta_2_decay=-0.65,\n            epsilon_1=1e-15,\n            epsilon_2=1e-4,\n            clip_threshold=0.9,\n            relative_step=False,\n        )\n\n        x = backend.Variable(np.ones([10]))\n        grads = np.arange(0.1, 1.1, 0.1)\n        first_grads = np.full((10,), 0.01)\n\n        # fmt: off\n        golden = np.array(\n            [[0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55, 0.55],\n            [0.3031, 0.3026, 0.3025, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024, 0.3024],\n            [0.1671, 0.1665, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663, 0.1663],\n            [0.0923, 0.0916, 0.0915, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914, 0.0914],\n            [0.0554, 0.0548, 0.0546, 0.0546, 0.0546, 0.0546, 0.0546, 0.0545, 0.0545, 0.0545]]\n        )\n        # fmt: on\n\n        optimizer.apply_gradients(zip([first_grads], [x]))\n        for i in range(5):\n            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            optimizer.apply_gradients(zip([grads], [x]))\n\n    def test_clip_norm(self):\n        optimizer = Adafactor(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = Adafactor(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n"
  },
  {
    "path": "keras/src/optimizers/adagrad.py",
    "content": "from keras.src import initializers\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export([\"keras.optimizers.Adagrad\"])\nclass Adagrad(optimizer.Optimizer):\n    \"\"\"Optimizer that implements the Adagrad algorithm.\n\n    Adagrad is an optimizer with parameter-specific learning rates,\n    which are adapted relative to how frequently a parameter gets\n    updated during training. The more updates a parameter receives,\n    the smaller the updates.\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.001`. Note that `Adagrad`\n            tends to benefit from higher initial learning rate values compared\n            to other optimizers. To match the exact form in the original paper,\n            use `1.0`.\n        initial_accumulator_value: Floating point value. Starting value for the\n            accumulators (per-parameter squared-gradient sums). 
Must be non-negative.\n        epsilon: Small floating point value for maintaining numerical stability.\n        {{base_optimizer_keyword_args}}\n\n    Reference:\n\n    - [Duchi et al., 2011](\n        http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf).\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.001,\n        initial_accumulator_value=0.1,\n        epsilon=1e-7,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"adagrad\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            name=name,\n            **kwargs,\n        )\n        self.initial_accumulator_value = initial_accumulator_value\n        self.epsilon = epsilon\n\n    def build(self, var_list):\n        if self.built:\n            return\n        super().build(var_list)\n        initializer = initializers.Constant(self.initial_accumulator_value)\n        self._accumulators = self.add_optimizer_variables(\n            var_list, \"accumulator\", initializer=initializer\n        )\n\n    def update_step(self, gradient, variable, learning_rate):\n        \"\"\"Update step given gradient and the associated model variable.\"\"\"\n        lr = ops.cast(learning_rate, variable.dtype)\n        gradient = ops.cast(gradient, variable.dtype)\n\n        accumulator = 
self._accumulators[self._get_variable_index(variable)]\n\n        self.assign_add(accumulator, ops.square(gradient))\n        self.assign_sub(\n            variable,\n            ops.divide(\n                ops.multiply(lr, gradient),\n                ops.sqrt(ops.add(accumulator, self.epsilon)),\n            ),\n        )\n\n    def get_config(self):\n        config = super().get_config()\n\n        config.update(\n            {\n                \"initial_accumulator_value\": self.initial_accumulator_value,\n                \"epsilon\": self.epsilon,\n            }\n        )\n        return config\n\n\nAdagrad.__doc__ = Adagrad.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
  {
    "path": "keras/src/optimizers/adagrad_test.py",
    "content": "# flake8: noqa\n\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.optimizers.adagrad import Adagrad\n\n\nclass AdagradTest(testing.TestCase):\n    def test_config(self):\n        optimizer = Adagrad(\n            learning_rate=0.5,\n            initial_accumulator_value=0.2,\n            epsilon=1e-5,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_single_step(self):\n        optimizer = Adagrad(learning_rate=0.5)\n        grads = ops.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(\n            vars, [0.5233, 1.5007, 2.5005, 3.5061], rtol=1e-4, atol=1e-4\n        )\n\n    def test_weight_decay(self):\n        grads, var1, var2, var3 = (\n            ops.zeros(()),\n            backend.Variable(2.0),\n            backend.Variable(2.0, name=\"exclude\"),\n            backend.Variable(2.0),\n        )\n        optimizer_1 = Adagrad(learning_rate=1.0, weight_decay=0.004)\n        optimizer_1.apply_gradients(zip([grads], [var1]))\n\n        optimizer_2 = Adagrad(learning_rate=1.0, weight_decay=0.004)\n        optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n        optimizer_3 = Adagrad(learning_rate=1.0, weight_decay=0.004)\n        optimizer_3.exclude_from_weight_decay(var_list=[var3])\n        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))\n\n        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)\n        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)\n        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)\n\n    def test_correctness_with_golden(self):\n        optimizer = Adagrad(\n            learning_rate=0.2, initial_accumulator_value=0.3, epsilon=1e-6\n        )\n\n        x = 
backend.Variable(np.ones([10]))\n        grads = ops.arange(0.1, 1.1, 0.1)\n        first_grads = ops.full((10,), 0.01)\n\n        # fmt: off\n        golden = np.array(\n            [[0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963, 0.9963],\n            [0.9604, 0.9278, 0.9003, 0.8784, 0.8615, 0.8487, 0.8388, 0.8313, 0.8255, 0.8209],\n            [0.9251, 0.8629, 0.8137, 0.7768, 0.7497, 0.7298, 0.7151, 0.704, 0.6956, 0.6891],\n            [0.8903, 0.8012, 0.7342, 0.6862, 0.6521, 0.6277, 0.6099, 0.5967, 0.5867, 0.579],\n            [0.856, 0.7422, 0.6604, 0.6037, 0.5644, 0.5367, 0.5168, 0.5021, 0.491, 0.4825]]\n        )\n        # fmt: on\n\n        optimizer.apply_gradients(zip([first_grads], [x]))\n        for i in range(5):\n            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            optimizer.apply_gradients(zip([grads], [x]))\n\n    def test_clip_norm(self):\n        optimizer = Adagrad(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = Adagrad(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n"
  },
  {
    "path": "keras/src/optimizers/adam.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export([\"keras.optimizers.Adam\"])\nclass Adam(optimizer.Optimizer):\n    \"\"\"Optimizer that implements the Adam algorithm.\n\n    Adam optimization is a stochastic gradient descent method that is based on\n    adaptive estimation of first-order and second-order moments.\n\n    According to\n    [Kingma et al., 2014](http://arxiv.org/abs/1412.6980),\n    the method is \"*computationally\n    efficient, has little memory requirement, invariant to diagonal rescaling of\n    gradients, and is well suited for problems that are large in terms of\n    data/parameters*\".\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.001`.\n        beta_1: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use. The\n            exponential decay rate for the 1st moment estimates. Defaults to\n            `0.9`.\n        beta_2: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use. The\n            exponential decay rate for the 2nd moment estimates. Defaults to\n            `0.999`.\n        epsilon: A small constant for numerical stability. This epsilon is\n            \"epsilon hat\" in the Kingma and Ba paper (in the formula just before\n            Section 2.1), not the epsilon in Algorithm 1 of the paper. Defaults\n            to `1e-7`.\n        amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm\n            from the paper \"On the Convergence of Adam and beyond\". 
Defaults\n            to `False`.\n        {{base_optimizer_keyword_args}}\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.001,\n        beta_1=0.9,\n        beta_2=0.999,\n        epsilon=1e-7,\n        amsgrad=False,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"adam\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            name=name,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            **kwargs,\n        )\n        self.beta_1 = beta_1\n        self.beta_2 = beta_2\n        self.epsilon = epsilon\n        self.amsgrad = amsgrad\n\n    def build(self, var_list):\n        \"\"\"Initialize optimizer variables.\n\n        Adam optimizer has 3 types of variables: momentums, velocities and\n        velocity_hat (only set when amsgrad is applied).\n\n        Args:\n            var_list: list of model variables to build Adam variables on.\n        \"\"\"\n        if self.built:\n            return\n        super().build(var_list)\n        self._momentums, self._velocities = self.add_optimizer_variables(\n            var_list, [\"momentum\", \"velocity\"]\n        )\n\n        if self.amsgrad:\n            self._velocity_hats = self.add_optimizer_variables(\n                var_list, \"velocity_hat\"\n            )\n\n    def update_step(self, gradient, variable, learning_rate):\n        
\"\"\"Update step given gradient and the associated model variable.\"\"\"\n        lr = ops.cast(learning_rate, variable.dtype)\n        gradient = ops.cast(gradient, variable.dtype)\n        local_step = ops.cast(self.iterations + 1, variable.dtype)\n        beta_1_power = ops.power(\n            ops.cast(self.beta_1, variable.dtype), local_step\n        )\n        beta_2_power = ops.power(\n            ops.cast(self.beta_2, variable.dtype), local_step\n        )\n\n        m = self._momentums[self._get_variable_index(variable)]\n        v = self._velocities[self._get_variable_index(variable)]\n\n        alpha = lr * ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)\n\n        self.assign_add(\n            m, ops.multiply(ops.subtract(gradient, m), 1 - self.beta_1)\n        )\n        self.assign_add(\n            v,\n            ops.multiply(\n                ops.subtract(ops.square(gradient), v), 1 - self.beta_2\n            ),\n        )\n        if self.amsgrad:\n            v_hat = self._velocity_hats[self._get_variable_index(variable)]\n            self.assign(v_hat, ops.maximum(v_hat, v))\n            v = v_hat\n        self.assign_sub(\n            variable,\n            ops.divide(\n                ops.multiply(m, alpha), ops.add(ops.sqrt(v), self.epsilon)\n            ),\n        )\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"beta_1\": self.beta_1,\n                \"beta_2\": self.beta_2,\n                \"epsilon\": self.epsilon,\n                \"amsgrad\": self.amsgrad,\n            }\n        )\n        return config\n\n\nAdam.__doc__ = Adam.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
  {
    "path": "keras/src/optimizers/adam_test.py",
    "content": "import numpy as np\nimport pytest\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.optimizers.adam import Adam\n\n\nclass AdamTest(testing.TestCase):\n    def test_config(self):\n        optimizer = Adam(\n            learning_rate=0.5,\n            beta_1=0.5,\n            beta_2=0.67,\n            epsilon=1e-5,\n            amsgrad=True,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_single_step(self):\n        optimizer = Adam(learning_rate=0.5)\n        grads = ops.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)\n\n    def test_weight_decay(self):\n        grads, var1, var2, var3 = (\n            ops.zeros(()),\n            backend.Variable(2.0),\n            backend.Variable(2.0, name=\"exclude\"),\n            backend.Variable(2.0),\n        )\n        optimizer_1 = Adam(learning_rate=1.0, weight_decay=0.004)\n        optimizer_1.apply_gradients(zip([grads], [var1]))\n\n        optimizer_2 = Adam(learning_rate=1.0, weight_decay=0.004)\n        optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n        optimizer_3 = Adam(learning_rate=1.0, weight_decay=0.004)\n        optimizer_3.exclude_from_weight_decay(var_list=[var3])\n        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))\n\n        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)\n        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)\n        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)\n\n    def test_correctness_with_golden(self):\n        optimizer = Adam(amsgrad=True)\n\n        x = backend.Variable(np.ones([10], dtype=\"float32\"))\n        grads = ops.arange(0.1, 1.1, 0.1)\n        
first_grads = ops.full((10,), 0.01)\n\n        golden = np.tile(\n            [[0.999], [0.9982], [0.9974], [0.9965], [0.9955]], (1, 10)\n        )\n\n        optimizer.apply_gradients(zip([first_grads], [x]))\n        for i in range(5):\n            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            optimizer.apply_gradients(zip([grads], [x]))\n\n    def test_clip_norm(self):\n        optimizer = Adam(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = Adam(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n\n    @pytest.mark.requires_trainable_backend\n    def test_ema(self):\n        # TODO: test correctness\n        model = keras.Sequential([keras.layers.Dense(10)])\n        model.compile(optimizer=Adam(use_ema=True), loss=\"mse\")\n        x = keras.ops.zeros((1, 5))\n        y = keras.ops.zeros((1, 10))\n        model.fit(x, y)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"The IndexedSlices test can only run with TF backend.\",\n    )\n    def test_clipnorm_indexed_slices(self):\n        # https://github.com/keras-team/keras/issues/18985\n        model = keras.Sequential(\n            [\n                keras.layers.Embedding(10, 4),\n                keras.layers.Flatten(),\n                keras.layers.Dense(2),\n            ]\n        )\n        model.compile(optimizer=Adam(clipnorm=100), loss=\"mse\")\n        x = keras.ops.ones((8, 5))\n        y = keras.ops.zeros((8, 2))\n        model.fit(x, y, verbose=0)\n"
  },
  {
    "path": "keras/src/optimizers/adamax.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export([\"keras.optimizers.Adamax\"])\nclass Adamax(optimizer.Optimizer):\n    \"\"\"Optimizer that implements the Adamax algorithm.\n\n    Adamax, a variant of Adam based on the infinity norm, is a first-order\n    gradient-based optimization method. Due to its capability of adjusting the\n    learning rate based on data characteristics, it is well suited to learning\n    time-variant processes, e.g., speech data with dynamically changing noise\n    conditions. Default parameters follow those provided in the paper (see\n    references below).\n\n    Initialization:\n\n    ```python\n    m = 0  # Initialize the 1st moment vector\n    u = 0  # Initialize the exponentially weighted infinity norm\n    t = 0  # Initialize timestep\n    ```\n\n    The update rule for parameter `w` with gradient `g` is described at the end\n    of section 7.1 of the paper (see the reference section):\n\n    ```python\n    t += 1\n    m = beta1 * m + (1 - beta1) * g\n    u = max(beta2 * u, abs(g))\n    current_lr = learning_rate / (1 - beta1 ** t)\n    w = w - current_lr * m / (u + epsilon)\n    ```\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.001`.\n        beta_1: A float value or a constant float tensor. The exponential decay\n            rate for the 1st moment estimates.\n        beta_2: A float value or a constant float tensor. 
The exponential decay\n            rate for the exponentially weighted infinity norm.\n        epsilon: A small constant for numerical stability.\n        {{base_optimizer_keyword_args}}\n\n    Reference:\n\n    - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.001,\n        beta_1=0.9,\n        beta_2=0.999,\n        epsilon=1e-7,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"adamax\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            name=name,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            **kwargs,\n        )\n        self.beta_1 = beta_1\n        self.beta_2 = beta_2\n        self.epsilon = epsilon\n\n    def build(self, var_list):\n        \"\"\"Initialize optimizer variables.\n\n        Adamax optimizer has 2 types of variables: momentums (denoted as m)\n        and exponentially weighted infinity norms (denoted as u).\n\n        Args:\n            var_list: list of model variables to build Adamax variables on.\n        \"\"\"\n        if self.built:\n            return\n        super().build(var_list)\n        self._m, self._u = self.add_optimizer_variables(\n            var_list, [\"momentum\", \"norm\"]\n        )\n\n    def update_step(self, gradient, variable, learning_rate):\n        \"\"\"Update step given 
gradient and the associated model variable.\"\"\"\n        lr = ops.cast(learning_rate, variable.dtype)\n        gradient = ops.cast(gradient, variable.dtype)\n        local_step = ops.cast(self.iterations + 1, variable.dtype)\n        beta_1_power = ops.power(\n            ops.cast(self.beta_1, variable.dtype), local_step\n        )\n\n        m = self._m[self._get_variable_index(variable)]\n        u = self._u[self._get_variable_index(variable)]\n\n        self.assign_add(\n            m, ops.multiply(ops.subtract(gradient, m), (1 - self.beta_1))\n        )\n        self.assign(\n            u, ops.maximum(ops.multiply(self.beta_2, u), ops.abs(gradient))\n        )\n        self.assign_sub(\n            variable,\n            ops.divide(\n                ops.multiply(lr, m),\n                ops.multiply((1 - beta_1_power), ops.add(u, self.epsilon)),\n            ),\n        )\n\n    def get_config(self):\n        config = super().get_config()\n\n        config.update(\n            {\n                \"beta_1\": self.beta_1,\n                \"beta_2\": self.beta_2,\n                \"epsilon\": self.epsilon,\n            }\n        )\n        return config\n\n\nAdamax.__doc__ = Adamax.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
  {
    "path": "keras/src/optimizers/adamax_test.py",
    "content": "# flake8: noqa\n\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.optimizers.adamax import Adamax\n\n\nclass AdamaxTest(testing.TestCase):\n    def test_config(self):\n        optimizer = Adamax(\n            learning_rate=0.5,\n            beta_1=0.8,\n            beta_2=0.95,\n            epsilon=1e-5,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_single_step(self):\n        optimizer = Adamax(learning_rate=0.5)\n        grads = ops.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)\n\n    def test_weight_decay(self):\n        grads, var1, var2, var3 = (\n            ops.zeros(()),\n            backend.Variable(2.0),\n            backend.Variable(2.0, name=\"exclude\"),\n            backend.Variable(2.0),\n        )\n        optimizer_1 = Adamax(learning_rate=1.0, weight_decay=0.004)\n        optimizer_1.apply_gradients(zip([grads], [var1]))\n\n        optimizer_2 = Adamax(learning_rate=1.0, weight_decay=0.004)\n        optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n        optimizer_3 = Adamax(learning_rate=1.0, weight_decay=0.004)\n        optimizer_3.exclude_from_weight_decay(var_list=[var3])\n        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))\n\n        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)\n        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)\n        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)\n\n    def test_correctness_with_golden(self):\n        optimizer = Adamax(\n            learning_rate=0.2, beta_1=0.85, beta_2=0.95, epsilon=1e-6\n        )\n\n        x = backend.Variable(np.ones([10], dtype=\"float32\"))\n    
    grads = ops.arange(0.1, 1.1, 0.1)\n        first_grads = ops.full((10,), 0.01)\n\n        # fmt: off\n        golden = np.array(\n            [[0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8],\n            [0.6827, 0.6873, 0.6888, 0.6896, 0.6901, 0.6904, 0.6906, 0.6908, 0.6909, 0.691],\n            [0.5333, 0.5407, 0.5431, 0.5444, 0.5451, 0.5456, 0.546, 0.5462, 0.5464, 0.5466],\n            [0.368, 0.3773, 0.3804, 0.382, 0.3829, 0.3835, 0.384, 0.3843, 0.3846, 0.3848],\n            [0.1933, 0.204, 0.2076, 0.2094, 0.2105, 0.2112, 0.2117, 0.2121, 0.2124, 0.2126]]\n        )\n        # fmt: on\n\n        optimizer.apply_gradients(zip([first_grads], [x]))\n        for i in range(5):\n            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            optimizer.apply_gradients(zip([grads], [x]))\n\n    def test_clip_norm(self):\n        optimizer = Adamax(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = Adamax(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n"
  },
  {
    "path": "keras/src/optimizers/adamw.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.optimizers import adam\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export([\"keras.optimizers.AdamW\"])\nclass AdamW(adam.Adam):\n    \"\"\"Optimizer that implements the AdamW algorithm.\n\n    AdamW optimization is a stochastic gradient descent method that is based on\n    adaptive estimation of first-order and second-order moments with an added\n    method to decay weights per the techniques discussed in the paper,\n    'Decoupled Weight Decay Regularization' by\n    [Loshchilov, Hutter et al., 2019](https://arxiv.org/abs/1711.05101).\n\n    According to\n    [Kingma et al., 2014](http://arxiv.org/abs/1412.6980),\n    the underlying Adam method is \"*computationally\n    efficient, has little memory requirement, invariant to diagonal rescaling of\n    gradients, and is well suited for problems that are large in terms of\n    data/parameters*\".\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.001`.\n        beta_1: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use. The\n            exponential decay rate for the 1st moment estimates.\n            Defaults to `0.9`.\n        beta_2: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use. The\n            exponential decay rate for the 2nd moment estimates.\n            Defaults to `0.999`.\n        epsilon: A small constant for numerical stability. 
This epsilon is\n            \"epsilon hat\" in the Kingma and Ba paper (in the formula just\n            before Section 2.1), not the epsilon in Algorithm 1 of the paper.\n            Defaults to 1e-7.\n        amsgrad: Boolean. Whether to apply AMSGrad variant of this algorithm\n            from the paper \"On the Convergence of Adam and beyond\".\n            Defaults to `False`.\n        {{base_optimizer_keyword_args}}\n\n    References:\n\n    - [Loshchilov et al., 2019](https://arxiv.org/abs/1711.05101)\n    - [Kingma et al., 2014](http://arxiv.org/abs/1412.6980) for `adam`\n    - [Reddi et al., 2018](\n        https://openreview.net/pdf?id=ryQu7f-RZ) for `amsgrad`.\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.001,\n        weight_decay=0.004,\n        beta_1=0.9,\n        beta_2=0.999,\n        epsilon=1e-7,\n        amsgrad=False,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"adamw\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            beta_1=beta_1,\n            beta_2=beta_2,\n            epsilon=epsilon,\n            amsgrad=amsgrad,\n            name=name,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            **kwargs,\n        )\n\n        if self.weight_decay is None:\n            raise ValueError(\n                \"Argument `weight_decay` must be a float. 
Received: \"\n                \"weight_decay=None\"\n            )\n\n\nAdamW.__doc__ = AdamW.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
  {
    "path": "keras/src/optimizers/adamw_test.py",
    "content": "# flake8: noqa\n\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.optimizers.adamw import AdamW\n\n\nclass AdamWTest(testing.TestCase):\n    def test_config(self):\n        optimizer = AdamW(\n            learning_rate=0.5,\n            weight_decay=0.008,\n            beta_1=0.5,\n            beta_2=0.67,\n            epsilon=1e-5,\n            amsgrad=True,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_single_step(self):\n        optimizer = AdamW(learning_rate=0.5)\n        grads = ops.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(\n            vars, [0.4980, 1.4960, 2.494, 3.492], rtol=1e-4, atol=1e-4\n        )\n\n    def test_weight_decay(self):\n        grads, var1, var2, var3 = (\n            ops.zeros(()),\n            backend.Variable(2.0),\n            backend.Variable(2.0, name=\"exclude\"),\n            backend.Variable(2.0),\n        )\n        optimizer_1 = AdamW(learning_rate=1.0, weight_decay=0.004)\n        optimizer_1.apply_gradients(zip([grads], [var1]))\n\n        optimizer_2 = AdamW(learning_rate=1.0, weight_decay=0.004)\n        optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n        optimizer_3 = AdamW(learning_rate=1.0, weight_decay=0.004)\n        optimizer_3.exclude_from_weight_decay(var_list=[var3])\n        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))\n\n        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)\n        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)\n        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)\n\n    def test_weight_decay_is_none(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Argument 
`weight_decay` must be a float. \"\n            \"Received: weight_decay=None\",\n        ):\n            AdamW(learning_rate=1.0, weight_decay=None)\n\n    def test_correctness_with_golden(self):\n        optimizer = AdamW(learning_rate=1.0, weight_decay=0.5, epsilon=2)\n\n        x = backend.Variable(np.ones([10], dtype=\"float32\"))\n        grads = ops.arange(0.1, 1.1, 0.1)\n        first_grads = ops.full((10,), 0.01)\n\n        # fmt: off\n        golden = np.array(\n            [[0.4998, 0.4998, 0.4998, 0.4998, 0.4998, 0.4998, 0.4998, 0.4998, 0.4998, 0.4998],\n            [0.2486, 0.2475, 0.2463, 0.2451, 0.244, 0.2428, 0.2417, 0.2405, 0.2394, 0.2382],\n            [0.1223, 0.1198, 0.1174, 0.1149, 0.1124, 0.11, 0.1075, 0.1051, 0.1027, 0.1003],\n            [0.0586, 0.0549, 0.0512, 0.0475, 0.0439, 0.0402, 0.0366, 0.033, 0.0294, 0.0258],\n            [0.0263, 0.0215, 0.0167, 0.012, 0.0073, 0.0026, -0.0021, -0.0067, -0.0113, -0.0159]]\n        )\n        # fmt: on\n\n        optimizer.apply_gradients(zip([first_grads], [x]))\n        for i in range(5):\n            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            optimizer.apply_gradients(zip([grads], [x]))\n\n    def test_clip_norm(self):\n        optimizer = AdamW(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = AdamW(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n"
  },
  {
    "path": "keras/src/optimizers/base_optimizer.py",
    "content": "import re\nimport warnings\n\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src.optimizers.schedules import learning_rate_schedule\nfrom keras.src.saving import serialization_lib\nfrom keras.src.saving.keras_saveable import KerasSaveable\nfrom keras.src.utils import tracking\nfrom keras.src.utils.naming import auto_name\n\n\nclass BaseOptimizer(KerasSaveable):\n    \"\"\"Abstract optimizer base class.\n\n    If you intend to create your own optimization algorithm, please inherit from\n    this class and override the following methods:\n\n    - `build`: Create your optimizer-related variables, such as momentum\n        variables in the SGD optimizer.\n    - `update_step`: Implement your optimizer's variable updating logic.\n    - `get_config`: serialization of the optimizer.\n\n    Example:\n\n    ```python\n    class SGD(Optimizer):\n        def __init__(self, **kwargs):\n            super().__init__(**kwargs)\n            self.momentum = 0.9\n\n        def build(self, variables):\n            super().build(variables)\n            self.momentums = []\n            for variable in variables:\n                self.momentums.append(\n                    self.add_variable_from_reference(\n                        reference_variable=variable, name=\"momentum\"\n                    )\n                )\n\n        def update_step(self, gradient, variable, learning_rate):\n            learning_rate = ops.cast(learning_rate, variable.dtype)\n            gradient = ops.cast(gradient, variable.dtype)\n            m = self.momentums[self._get_variable_index(variable)]\n            self.assign(\n                m,\n                ops.subtract(\n                    ops.multiply(m, ops.cast(self.momentum, variable.dtype)),\n                    ops.multiply(gradient, learning_rate),\n                ),\n            )\n            self.assign_add(variable, m)\n\n        def get_config(self):\n            
config = super().get_config()\n            config.update(\n                {\n                    \"momentum\": self.momentum,\n                }\n            )\n            return config\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=None,\n        **kwargs,\n    ):\n        self._lock = False\n\n        if kwargs.pop(\"decay\", None) is not None:\n            warnings.warn(\n                \"Argument `decay` is no longer supported and will be ignored.\"\n            )\n        if kwargs:\n            raise ValueError(f\"Argument(s) not recognized: {kwargs}\")\n\n        if name is None:\n            name = auto_name(self.__class__.__name__)\n        self.name = name\n        self.weight_decay = weight_decay\n        self.clipnorm = clipnorm\n        self.global_clipnorm = global_clipnorm\n        self.clipvalue = clipvalue\n        self.use_ema = use_ema\n        self.loss_scale_factor = loss_scale_factor\n        self.gradient_accumulation_steps = gradient_accumulation_steps\n\n        if gradient_accumulation_steps:\n            if not gradient_accumulation_steps >= 2:\n                raise ValueError(\n                    \"`gradient_accumulation_steps` must be an integer >= 2. \"\n                    \"Received: gradient_accumulation_steps=\"\n                    f\"{gradient_accumulation_steps}\"\n                )\n\n        if use_ema:\n            # Verify the arguments related to EMA.\n            if ema_momentum > 1 or ema_momentum < 0:\n                raise ValueError(\n                    \"`ema_momentum` must be in the range [0, 1]. 
\"\n                    f\"Received: ema_momentum={ema_momentum}\"\n                )\n            if ema_overwrite_frequency and (\n                not isinstance(ema_overwrite_frequency, int)\n                or ema_overwrite_frequency < 1\n            ):\n                raise ValueError(\n                    \"`ema_overwrite_frequency` must be an integer >= 1 or \"\n                    \"None. Received: ema_overwrite_frequency=\"\n                    f\"{ema_overwrite_frequency}\"\n                )\n        self.ema_momentum = ema_momentum\n        self.ema_overwrite_frequency = ema_overwrite_frequency\n\n        clip_args_sum = sum(\n            a is not None for a in [clipnorm, clipvalue, global_clipnorm]\n        )\n        if clip_args_sum > 1:\n            raise ValueError(\n                \"Only one of `clipnorm`, `clipvalue` and `global_clipnorm` can \"\n                f\"be set. Received: clipnorm={clipnorm}, \"\n                f\"clipvalue={clipvalue}, global_clipnorm={global_clipnorm}\"\n            )\n        self.built = False\n\n        # Set up variable tracking.\n        self._variables = []\n        self._trainable_variables = []\n        self._tracker = tracking.Tracker(\n            {\n                \"variables\": (\n                    lambda x: isinstance(x, backend.Variable),\n                    self._variables,\n                ),\n            }\n        )\n        self._trainable_variables_indices = {}\n\n        # Create iteration variable\n        # Note: dtype=\"int\" will resolve to int32 in JAX\n        # (since int64 is disallowed in JAX) and to int64 in TF.\n        with backend.name_scope(self.name, caller=self):\n            iterations = backend.Variable(\n                0,\n                name=\"iteration\",\n                dtype=\"int\",\n                trainable=False,\n                aggregation=\"only_first_replica\",\n            )\n        self._track_variable(iterations)\n        self._iterations = 
iterations\n\n        # Create learning rate (schedule or variable)\n        if isinstance(\n            learning_rate, learning_rate_schedule.LearningRateSchedule\n        ):\n            self._learning_rate = learning_rate\n        elif callable(learning_rate):\n            self._learning_rate = learning_rate\n        else:\n            if not isinstance(learning_rate, float):\n                raise ValueError(\n                    \"Argument `learning_rate` should be float, or an instance \"\n                    \"of LearningRateSchedule, or a callable \"\n                    \"(that takes in the current iteration value \"\n                    \"and returns the corresponding learning rate value). \"\n                    f\"Received instead: learning_rate={learning_rate}\"\n                )\n            with backend.name_scope(self.name, caller=self):\n                learning_rate = backend.Variable(\n                    learning_rate,\n                    name=\"learning_rate\",\n                    dtype=backend.floatx(),\n                    trainable=False,\n                    aggregation=\"only_first_replica\",\n                )\n            self._track_variable(learning_rate)\n            self._learning_rate = learning_rate\n\n    @property\n    def iterations(self):\n        if self.gradient_accumulation_steps:\n            return ops.floor_divide(\n                self._iterations, self.gradient_accumulation_steps\n            )\n\n        return self._iterations\n\n    def _track_variable(self, variable):\n        self._tracker.add_to_store(\"variables\", variable)\n\n    def _overwrite_variable_with_gradient(self, variable):\n        return getattr(variable, \"overwrite_with_gradient\", False)\n\n    @tracking.no_automatic_dependency_tracking\n    def build(self, variables):\n        if self.use_ema:\n            self._model_variables_moving_average = self.add_optimizer_variables(\n                variables, \"average\"\n            )\n        if 
self.gradient_accumulation_steps:\n            self._accumulated_gradients = []\n        for i, variable in enumerate(variables):\n            self._trainable_variables_indices[self._var_key(variable)] = i\n            if self.gradient_accumulation_steps:\n                self._accumulated_gradients.append(\n                    self.add_variable_from_reference(\n                        variable,\n                        name=\"gradient_accumulator\",\n                    )\n                )\n        self._trainable_variables = variables[:]\n        self.built = True\n\n    def _var_key(self, variable):\n        # Helper function to get a stable ID and the variable instance mapping.\n        return id(variable)\n\n    @property\n    def variables(self):\n        return self._variables[:]\n\n    def _get_variable_index(self, variable):\n        return self._trainable_variables_indices[self._var_key(variable)]\n\n    def add_variable(\n        self,\n        shape,\n        initializer=\"zeros\",\n        dtype=None,\n        aggregation=\"none\",\n        layout=None,\n        name=None,\n    ):\n        \"\"\"Add a variable to the optimizer.\n\n        Args:\n            shape: Shape tuple for the variable. Must be fully-defined\n                (no `None` entries).\n            initializer: Initializer object to use to populate the initial\n                variable value, or string name of a built-in initializer\n                (e.g. `\"random_normal\"`). Defaults to `\"zeros\"`.\n            dtype: Dtype of the variable to create, e.g. `\"float32\"`. If\n                unspecified, defaults to the `keras.backend.floatx()`.\n            aggregation: Optional string, one of `None`, `\"none\"`, `\"mean\"`,\n                `\"sum\"` or `\"only_first_replica\"`. 
Annotates the variable with\n                the type of multi-replica aggregation to be used for this\n                variable when writing custom data parallel training loops.\n                Defaults to `\"none\"`.\n            layout: Optional tensor layout.  Defaults to `None`.\n            name: String name of the variable. Useful for debugging purposes.\n\n        Returns:\n            An optimizer variable, in the format of `keras.Variable`.\n        \"\"\"\n        self._check_super_called()\n        initializer = initializers.get(initializer)\n        with backend.name_scope(self.name, caller=self):\n            variable = backend.Variable(\n                initializer=initializer,\n                shape=shape,\n                dtype=dtype,\n                trainable=False,\n                aggregation=aggregation,\n                layout=layout,\n                name=name,\n            )\n        self._track_variable(variable)\n        return variable\n\n    def add_variable_from_reference(\n        self, reference_variable, name=None, initializer=\"zeros\"\n    ):\n        \"\"\"Add an optimizer variable from the model variable.\n\n        Create an optimizer variable based on the information of the model\n        variable. For example, in SGD with momentum, for each model variable, a\n        corresponding momentum variable is created of the same shape and dtype.\n\n        Args:\n            reference_variable: `keras.Variable`. The corresponding model\n                variable to the optimizer variable to be created.\n            name: Optional string. The name prefix of the optimizer variable to\n                be created. If not provided, it will be set to `\"var\"`. The\n                variable name will follow the pattern\n                `{reference_variable.name}_{name}`,\n                e.g., `dense_1_momentum`. 
Defaults to `None`.\n            initializer: Initializer object to use to populate the initial\n                variable value, or string name of a built-in initializer\n                (e.g. `\"random_normal\"`). If unspecified, defaults to\n                `\"zeros\"`.\n\n        Returns:\n            An optimizer variable, in the format of `keras.Variable`.\n        \"\"\"\n        name = name or \"var\"\n        if hasattr(reference_variable, \"path\"):\n            name = f\"{reference_variable.path.replace('/', '_')}_{name}\"\n        else:\n            sanitised_ref_name = (\n                str(reference_variable.name).replace(\"/\", \"_\").replace(\":\", \"_\")\n            )\n            name = f\"{sanitised_ref_name}_{name}\"\n        return self.add_variable(\n            shape=reference_variable.shape,\n            initializer=initializer,\n            dtype=reference_variable.dtype,\n            name=name,\n            layout=getattr(reference_variable, \"_layout\", None),\n        )\n\n    def add_optimizer_variables(\n        self, trainable_variables, name, initializer=\"zeros\"\n    ):\n        \"\"\"Add optimizer variables from the list of trainable model variables.\n\n        Create an optimizer variable based on the information of the supplied\n        model variables.  For example, in SGD with momentum, for each model\n        variable, a corresponding momentum variable is created of the same shape\n        and dtype.\n\n        Note that trainable variables with `v.overwrite_with_gradient == True`\n        will have `None` inserted into the output list, since the optimizer\n        variable will not be used anyway and creating it would be wasteful.\n\n        Args:\n            trainable_variables: `keras.Variable`, the corresponding model\n                variable to the optimizer variable to be created.\n            name: The name prefix(es) of the optimizer variable(s) to be\n                created. Can be a single string or list of strings.  
If a\n                list of strings, will create an optimizer variable for each\n                prefix.  The variable name will follow the pattern\n                `{trainable_variable.name}_{name}`, e.g.,\n                `dense_1_momentum`.\n            initializer: Initializer object(s) to use to populate the initial\n                variable value(s), or string name of a built-in initializer\n                (e.g. `\"random_normal\"`). If unspecified, defaults to\n                `\"zeros\"`.\n\n        Returns:\n            A list of optimizer variables, in the format of `keras.Variable`s.\n            If multiple names are provided, returns a tuple of lists.\n        \"\"\"\n        name_list = name\n        initializer_list = initializer\n        if isinstance(name, str):\n            # Single name/initializer.\n            name_list = [name]\n            initializer_list = [initializer]\n        else:\n            # Multiple names/initializers.\n            # If there is only one initializer, use it for all names.\n            if isinstance(initializer, str) or isinstance(\n                initializer, initializers.Initializer\n            ):\n                initializer_list = [initializer] * len(name_list)\n\n        if len(name_list) != len(initializer_list):\n            raise ValueError(\n                f\"The number of provided names must match the number of \"\n                f\"provided initializers.  
Received name='{name}', \"\n                f\"initializer='{initializer}'\"\n            )\n\n        # Build up lists of optimizer variables.\n        optimizer_variables = tuple([] for _ in name_list)\n        for variable in trainable_variables:\n            # Interleaves adding variables for backward-compatibility.\n            if not self._overwrite_variable_with_gradient(variable):\n                for i, (var_name, var_init) in enumerate(\n                    zip(name_list, initializer_list)\n                ):\n                    optimizer_variables[i].append(\n                        self.add_variable_from_reference(\n                            variable,\n                            name=var_name,\n                            initializer=var_init,\n                        )\n                    )\n            else:\n                for i in range(len(name_list)):\n                    optimizer_variables[i].append(None)\n\n        # If single input name, return the single list.\n        if isinstance(name, str):\n            return optimizer_variables[0]\n\n        return optimizer_variables\n\n    def _check_variables_are_known(self, variables):\n        for v in variables:\n            if self._var_key(v) not in self._trainable_variables_indices:\n                raise ValueError(\n                    f\"Unknown variable: {v}. This optimizer can only \"\n                    \"be called for the variables it was originally built with. 
\"\n                    \"When working with a new set of variables, you should \"\n                    \"recreate a new optimizer instance.\"\n                )\n\n    def assign(self, variable, value):\n        \"\"\"Assign a value to a variable.\n\n        This should be used in optimizers instead of `variable.assign(value)` to\n        support backend specific optimizations.\n        Note that the variable can be a model variable or an optimizer variable;\n        it can be a backend native variable or a Keras variable.\n\n        Args:\n            variable: The variable to update.\n            value: The value to add to the variable.\n        \"\"\"\n        variable.assign(value)\n\n    def assign_add(self, variable, value):\n        \"\"\"Add a value to a variable.\n\n        This should be used in optimizers instead of\n        `variable.assign_add(value)` to support backend specific optimizations.\n        Note that the variable can be a model variable or an optimizer variable;\n        it can be a backend native variable or a Keras variable.\n\n        Args:\n            variable: The variable to update.\n            value: The value to add to the variable.\n        \"\"\"\n        variable.assign_add(value)\n\n    def assign_sub(self, variable, value):\n        \"\"\"Subtract a value from a variable.\n\n        This should be used in optimizers instead of\n        `variable.assign_sub(value)` to support backend specific optimizations.\n        Note that the variable can be a model variable or an optimizer variable;\n        it can be a backend native variable or a Keras variable.\n\n        Args:\n            variable: The variable to update.\n            value: The value to add to the variable.\n        \"\"\"\n        variable.assign_sub(value)\n\n    def update_step(self, gradient, variable, learning_rate):\n        raise NotImplementedError\n\n    def apply_gradients(self, grads_and_vars):\n        grads, trainable_variables = zip(*grads_and_vars)\n  
      self.apply(grads, trainable_variables)\n        # Return iterations for compat with tf.keras.\n        return self._iterations\n\n    def apply(self, grads, trainable_variables=None):\n        \"\"\"Update trainable variables according to provided gradient values.\n\n        `grads` should be a list of gradient tensors\n        with 1:1 mapping to the list of variables the optimizer was built with.\n\n        `trainable_variables` can be provided\n        on the first call to build the optimizer.\n        \"\"\"\n        if len(grads) == 0:\n            # It is possible that the grad is empty. In this case,\n            # `apply_gradients` is a no-op.\n            return\n\n        if trainable_variables is None:\n            if not self.built:\n                raise ValueError(\n                    \"When passing `grads` without `variables`, the optimizer \"\n                    \"must already be built on a list of variables. \"\n                    \"Call `optimizer.build(trainable_variables)` first. \"\n                )\n            if len(grads) != len(self._trainable_variables_indices):\n                raise ValueError(\n                    \"When passing `grads` as a list of gradient tensors, the \"\n                    f\"gradients must match `optimizer.variables` one-to-one. 
\"\n                    f\"Received a list of {len(grads)} gradients, but the \"\n                    f\"optimizer is tracking {len(self._trainable_variables)} \"\n                    \"trainable variables.\"\n                )\n            trainable_variables = self._trainable_variables\n        else:\n            trainable_variables = list(trainable_variables)\n            # Optionally build optimizer.\n            if not self.built:\n                with backend.name_scope(self.name, caller=self):\n                    self.build(trainable_variables)\n                self.built = True\n            self._check_variables_are_known(trainable_variables)\n\n        with backend.name_scope(self.name, caller=self):\n            # Filter empty gradients.\n            grads, trainable_variables = self._filter_empty_gradients(\n                grads, trainable_variables\n            )\n\n            # Overwrite targeted variables directly with their gradients if\n            # their `overwrite_with_gradient` is set.\n            grads, trainable_variables = (\n                self._overwrite_variables_directly_with_gradients(\n                    grads, trainable_variables\n                )\n            )\n\n            if len(list(grads)) > 0:\n                # Unscale gradients.\n                scale = self.loss_scale_factor\n                if scale is not None:\n                    grads = [g if g is None else g / scale for g in grads]\n\n                # Apply gradient updates.\n                self._backend_apply_gradients(grads, trainable_variables)\n                # Apply variable constraints after applying gradients.\n                for variable in trainable_variables:\n                    if variable.constraint is not None:\n                        variable.assign(variable.constraint(variable))\n\n        # Update iteration counter.\n        self._iterations.assign_add(1)\n\n    def _backend_apply_gradients(self, grads, trainable_variables):\n        
\"\"\"Apply method that can be overridden by different backends.\n\n        JAX overrides it in order to deal with statelessness in gradient\n        accumulation and EMA handling.\n\n        The below implementation is intended to be generally backend-agnostic,\n        but may not work with all backends.\n\n        This method does 4 things:\n        - Call the optimizer's update_step() to update trainable variables\n            and optimizer variables.\n        - Update EMA variables, if EMA is configured.\n        - Update gradient accumulators, if gradient accumulation is configured.\n        - Update the iteration counter.\n        \"\"\"\n        if self.gradient_accumulation_steps:\n            is_update_step = (\n                self._iterations + 1\n            ) % self.gradient_accumulation_steps == 0\n            # `trainable_variables` might have been filtered in previous\n            # processing steps, so we need to ensure the correct mapping between\n            # `self._accumulated_gradients` and `trainable_variables`\n            acc_grads = [\n                self._accumulated_gradients[self._get_variable_index(v)]\n                for v in trainable_variables\n            ]\n\n            def _update_step_fn(grads, trainable_variables):\n                # Run update step with accumulated grads + reset accumulators\n                steps = self.gradient_accumulation_steps\n                grads = [\n                    (g + acc_g) / steps for g, acc_g in zip(grads, acc_grads)\n                ]\n\n                # Apply clipping and weight decay.\n                grads = self._clip_gradients(grads)\n                self._apply_weight_decay(trainable_variables)\n\n                self._backend_update_step(\n                    grads, trainable_variables, self.learning_rate\n                )\n                self._backend_reset_gradient_accumulators()\n\n            ops.cond(\n                is_update_step,\n                lambda: 
_update_step_fn(grads, trainable_variables),\n                lambda: self._backend_increment_gradient_accumulators(\n                    grads, acc_grads\n                ),\n            )\n        else:\n            # Apply clipping and weight decay.\n            grads = self._clip_gradients(grads)\n            self._apply_weight_decay(trainable_variables)\n\n            # Run update step.\n            self._backend_update_step(\n                grads, trainable_variables, self.learning_rate\n            )\n\n        if self.use_ema:\n            self._update_model_variables_moving_average(\n                self._trainable_variables\n            )\n            if self.ema_overwrite_frequency:\n                # Only when self.ema_overwrite_frequency is not None, we\n                # overwrite the model variables.\n                should_overwrite_model_vars = (\n                    self.iterations + 1\n                ) % self.ema_overwrite_frequency == 0\n                ops.cond(\n                    should_overwrite_model_vars,\n                    lambda: self._overwrite_model_variables_with_average_value(\n                        self._trainable_variables\n                    ),\n                    lambda: None,\n                )\n\n    def _backend_update_step(self, grads, trainable_variables, learning_rate):\n        \"\"\"Collective update_step that can be overridden by the backend.\n\n        It is overridden by torch for performance reasons, and\n        by TF to support tf.distribute.\n        \"\"\"\n        for grad, var in zip(grads, trainable_variables):\n            self.update_step(grad, var, learning_rate)\n\n    def _backend_reset_gradient_accumulators(self):\n        for g_acc in self._accumulated_gradients:\n            if g_acc is not None:\n                g_acc.assign(ops.zeros(g_acc.shape, dtype=g_acc.dtype))\n\n    def _backend_increment_gradient_accumulators(self, grads, acc_grads):\n        new_g_accs = [(g + acc_g) for g, acc_g in 
zip(grads, acc_grads)]\n        for n_g_acc, g_acc in zip(new_g_accs, acc_grads):\n            g_acc.assign(n_g_acc)\n\n    def stateless_apply(self, optimizer_variables, grads, trainable_variables):\n        \"\"\"Stateless version of `apply` that returns modified variables.\n\n        Args:\n            optimizer_variables: list of tensors containing the current values\n                for the optimizer variables. These are native tensors and not\n                `keras.Variable`s.\n            grads: list of gradients to apply.\n            trainable_variables: list of tensors containing the current values\n                for the model variables. These are native tensors and not\n                `keras.Variable`s.\n\n        Returns: A tuple containing two lists of tensors, the updated\n            `trainable_variables` and the updated `optimizer_variables`.\n        \"\"\"\n        self._check_super_called()\n\n        if not self.built:\n            raise ValueError(\n                f\"To call `stateless_apply`, {self.__class__.__name__} \"\n                \"must be built (i.e. its variables must have been created). \"\n                \"You can build it via `optimizer.build(trainable_variables)`.\"\n            )\n        if len(optimizer_variables) != len(self.variables):\n            raise ValueError(\n                \"Argument `optimizer_variables` must be a list of tensors \"\n                f\"corresponding 1:1 to {self.__class__.__name__}().variables. \"\n                f\"Received list with length {len(optimizer_variables)}, but \"\n                f\"expected {len(self.variables)} variables.\"\n            )\n        if len(trainable_variables) != len(self._trainable_variables):\n            raise ValueError(\n                \"Argument `trainable_variables` must be a list of tensors \"\n                \"corresponding 1:1 to the trainable variables list that \"\n                \"the optimizer was built with. 
Received \"\n                f\"len(trainable_variables) == {len(trainable_variables)} \"\n                \"whereas the optimizer was built with \"\n                f\"{len(self._trainable_variables)} variables.\"\n            )\n\n        # Gather variable mapping\n        mapping = list(\n            zip(self._trainable_variables, trainable_variables)\n        ) + list(zip(self.variables, optimizer_variables))\n\n        # Call in stateless scope\n        with backend.StatelessScope(state_mapping=mapping) as scope:\n            self.apply(grads)\n\n        # Gather updated variables\n        trainable_variables = [\n            scope.get_current_value(v) for v in self._trainable_variables\n        ]\n        optimizer_variables = [\n            scope.get_current_value(v) for v in self.variables\n        ]\n        return trainable_variables, optimizer_variables\n\n    def scale_loss(self, loss):\n        \"\"\"Scale the loss before computing gradients.\n\n        Scales the loss before gradients are computed in a `train_step`. 
This\n        is primarily useful during mixed precision training to prevent numeric\n        underflow.\n        \"\"\"\n        if self.loss_scale_factor is not None:\n            return loss * self.loss_scale_factor\n        return loss\n\n    @property\n    def learning_rate(self):\n        return self._get_current_learning_rate()\n\n    @learning_rate.setter\n    def learning_rate(self, learning_rate):\n        if isinstance(self._learning_rate, backend.Variable):\n            prev_lr_var = self._learning_rate\n        else:\n            prev_lr_var = None\n        if isinstance(\n            learning_rate, learning_rate_schedule.LearningRateSchedule\n        ):\n            self._learning_rate = learning_rate\n        elif callable(learning_rate):\n            self._learning_rate = learning_rate\n        else:\n            if isinstance(\n                self._learning_rate, learning_rate_schedule.LearningRateSchedule\n            ):\n                raise TypeError(\n                    \"This optimizer was created with a `LearningRateSchedule`\"\n                    \" object as its `learning_rate` constructor argument, \"\n                    \"hence its learning rate is not settable. If you need the\"\n                    \" learning rate to be settable, you should instantiate \"\n                    \"the optimizer with a float `learning_rate` argument.\"\n                )\n            self._learning_rate.assign(learning_rate)\n        if prev_lr_var is not None and not isinstance(\n            self._learning_rate, backend.Variable\n        ):\n            # Untrack learning rate variable\n            self._untrack_variable(prev_lr_var)\n\n    def set_weights(self, weights):\n        \"\"\"Set the weights of the optimizer.\"\"\"\n        if not self.built:\n            raise ValueError(\n                \"You are calling `set_weights()` on an optimizer that has not \"\n                \"yet been built. 
Please call \"\n                \"`optimizer.build(trainable_variables)` to create the \"\n                \"optimizer weights before calling `set_weights()`.\"\n            )\n        for variable, weight in zip(self._variables, weights):\n            if variable.shape != weight.shape:\n                raise ValueError(\n                    f\"Optimizer variable {self._var_key(variable)} has shape \"\n                    f\"{str(variable.shape)} not compatible with provided \"\n                    f\"weight shape {str(weight.shape)}.\"\n                )\n            variable.assign(weight)\n\n    def save_own_variables(self, store):\n        \"\"\"Get the state of this optimizer object.\"\"\"\n        for i, variable in enumerate(self.variables):\n            store[str(i)] = variable.numpy()\n\n    def load_own_variables(self, store):\n        \"\"\"Set the state of this optimizer object.\"\"\"\n        if len(store.keys()) != len(self.variables):\n            msg = (\n                f\"Skipping variable loading for optimizer '{self.name}', \"\n                f\"because it has {len(self.variables)} variables whereas \"\n                f\"the saved optimizer has {len(store.keys())} variables. 
\"\n            )\n            if len(self.variables) == 0:\n                msg += (\n                    \"This is likely because the optimizer has not been \"\n                    \"called/built yet.\"\n                )\n            warnings.warn(msg, stacklevel=2)\n            return\n        for i, variable in enumerate(self.variables):\n            variable.assign(store[str(i)])\n\n    def _get_current_learning_rate(self):\n        if isinstance(\n            self._learning_rate, learning_rate_schedule.LearningRateSchedule\n        ):\n            return self._learning_rate(self._iterations)\n        elif isinstance(self._learning_rate, backend.Variable):\n            return self._learning_rate\n        elif callable(self._learning_rate):\n            return self._learning_rate()\n        return self._learning_rate\n\n    def _overwrite_variables_directly_with_gradients(self, grads, vars):\n        \"\"\"Overwrite the variables directly by their gradients.\n\n        This method is designed for a special case where we want to overwrite\n        the variable directly with its computed gradient. 
For example, in float8\n        training, new `scale` and `amax_history` are computed as gradients, and\n        we want to overwrite them directly instead of following the typical\n        procedure such as gradient descent with a learning rate, gradient\n        clipping and weight decay.\n\n        After the update, the processed pairs will be filtered out.\n        \"\"\"\n        # Shortcut for `tf.Variable` because it doesn't have an\n        # `overwrite_with_gradient` attr.\n        if not any(self._overwrite_variable_with_gradient(v) for v in vars):\n            return grads, vars\n\n        # Shallow copies\n        filtered_grads = list(grads)\n        filtered_vars = list(vars)\n\n        # Iterate from right to left for safe popping\n        for i in range(len(filtered_grads) - 1, -1, -1):\n            g, v = filtered_grads[i], filtered_vars[i]\n            if self._overwrite_variable_with_gradient(v):\n                if self.gradient_accumulation_steps:\n                    # Use a stateless approach for JAX compatibility\n                    steps = self.gradient_accumulation_steps\n                    is_update_step = (self._iterations + 1) % steps == 0\n                    acc_g = self._accumulated_gradients[\n                        self._get_variable_index(v)\n                    ]\n                    # `ops.maximum` is used for gradient accumulation for\n                    # `overwrite_with_gradient=True` variables\n                    new_g_acc = ops.cond(\n                        is_update_step,\n                        lambda: ops.zeros(g.shape, dtype=g.dtype),\n                        lambda: ops.maximum(g, acc_g),\n                    )\n                    new_g = ops.cond(\n                        is_update_step,\n                        lambda: ops.maximum(g, acc_g),\n                        lambda: g,\n                    )\n                    new_v = ops.cond(\n                        is_update_step, lambda: new_g, lambda: 
v.value\n                    )\n                    v.assign(new_v)\n                    acc_g.assign(new_g_acc)\n                else:\n                    v.assign(g)\n                filtered_grads.pop(i)\n                filtered_vars.pop(i)\n        return filtered_grads, filtered_vars\n\n    def _filter_empty_gradients(self, grads, vars):\n        filtered_grads = list(grads)\n        filtered_vars = list(vars)\n        missing_grad_vars = []\n\n        # Iterate from right to left for safe popping\n        for i in range(len(filtered_grads) - 1, -1, -1):\n            if filtered_grads[i] is None:\n                filtered_grads.pop(i)\n                v = filtered_vars.pop(i)\n                try:\n                    missing_grad_vars.append(v.path)\n                except AttributeError:\n                    # `tf.Variable` doesn't have `path` attr.\n                    missing_grad_vars.append(v.name)\n\n        if not filtered_grads:\n            raise ValueError(\"No gradients provided for any variable.\")\n        if missing_grad_vars:\n            warnings.warn(\n                \"Gradients do not exist for variables \"\n                f\"{list(reversed(missing_grad_vars))} when minimizing the loss.\"\n                \" If using `model.compile()`, did you forget to provide a \"\n                \"`loss` argument?\"\n            )\n        return filtered_grads, filtered_vars\n\n    def _clip_gradients(self, grads):\n        if self.clipnorm and self.clipnorm > 0:\n            return [\n                self._clip_by_norm(g) if g is not None else g for g in grads\n            ]\n        elif self.global_clipnorm and self.global_clipnorm > 0:\n            return clip_by_global_norm(grads, self.global_clipnorm)\n        elif self.clipvalue and self.clipvalue > 0:\n            v = self.clipvalue\n            return [ops.clip(g, -v, v) if g is not None else g for g in grads]\n        else:\n            return grads\n\n    def 
exclude_from_weight_decay(self, var_list=None, var_names=None):\n        \"\"\"Exclude variables from weight decay.\n\n        This method must be called before the optimizer's `build` method is\n        called. You can pass specific variables to exclude, or a list of anchor\n        strings; if any of these strings appears in a variable's name, the\n        variable is excluded.\n\n        Args:\n            var_list: A list of `Variable`s to exclude from weight decay.\n            var_names: A list of strings. If any string in `var_names` appears\n                in the model variable's name, then this model variable is\n                excluded from weight decay. For example, `var_names=['bias']`\n                excludes all bias variables from weight decay.\n        \"\"\"\n        if hasattr(self, \"_built\") and self._built:\n            raise ValueError(\n                \"`exclude_from_weight_decay()` can only be configured before \"\n                \"the optimizer is built.\"\n            )\n\n        # Use a `set` for the ids of `var_list` to speed up the searching\n        if var_list:\n            self._exclude_from_weight_decay = set(\n                self._var_key(variable) for variable in var_list\n            )\n        else:\n            self._exclude_from_weight_decay = set()\n\n        # Precompile the pattern for `var_names` to speed up the searching\n        if var_names and len(var_names) > 0:\n            self._exclude_from_weight_decay_pattern = re.compile(\n                \"|\".join(set(var_names))\n            )\n        else:\n            self._exclude_from_weight_decay_pattern = None\n\n        # Reset cache\n        self._exclude_from_weight_decay_cache = dict()\n\n    def _use_weight_decay(self, variable):\n        variable_id = self._var_key(variable)\n\n        # Immediately return the value if `variable_id` hits the cache\n        if not hasattr(self, \"_exclude_from_weight_decay_cache\"):\n            
self._exclude_from_weight_decay_cache = dict()\n        if variable_id in self._exclude_from_weight_decay_cache:\n            return self._exclude_from_weight_decay_cache[variable_id]\n\n        # Determine whether the variable should apply weight decay or not\n        exclude_from_weight_decay = getattr(\n            self, \"_exclude_from_weight_decay\", set()\n        )\n        exclude_from_weight_decay_pattern = getattr(\n            self, \"_exclude_from_weight_decay_pattern\", None\n        )\n        if variable_id in exclude_from_weight_decay:\n            self._exclude_from_weight_decay_cache[variable_id] = False\n            return False\n        if exclude_from_weight_decay_pattern is not None:\n            if (\n                re.search(exclude_from_weight_decay_pattern, variable.name)\n                is not None\n            ):\n                self._exclude_from_weight_decay_cache[variable_id] = False\n                return False\n        self._exclude_from_weight_decay_cache[variable_id] = True\n        return True\n\n    def _apply_weight_decay(self, variables):\n        if self.weight_decay is None:\n            return\n        for variable in variables:\n            if self._use_weight_decay(variable):\n                lr = ops.cast(self.learning_rate, variable.dtype)\n                wd = ops.cast(self.weight_decay, variable.dtype)\n                variable.assign(variable - variable * wd * lr)\n\n    def _check_super_called(self):\n        if not hasattr(self, \"_lock\"):\n            raise RuntimeError(\n                f\"In optimizer '{self.__class__.__name__}', you forgot to call \"\n                \"`super().__init__()` as the first statement \"\n                \"in the `__init__()` method. 
\"\n                \"Go add it!\"\n            )\n\n    def _update_model_variables_moving_average(self, trainable_variables):\n        \"\"\"Update the stored moving average using the latest value.\"\"\"\n        if self.use_ema:\n            for var, average in zip(\n                trainable_variables, self._model_variables_moving_average\n            ):\n                if average is not None:\n                    not_first_step = ops.not_equal(self.iterations, 0)\n                    momentum = ops.multiply(\n                        ops.cast(not_first_step, var.dtype), self.ema_momentum\n                    )\n                    average.assign(\n                        ops.add(\n                            ops.multiply(momentum, average),\n                            ops.multiply(ops.subtract(1, momentum), var),\n                        )\n                    )\n\n    def _overwrite_model_variables_with_average_value(\n        self, trainable_variables\n    ):\n        \"\"\"Overwrite model variables with their moving averages.\"\"\"\n        if len(trainable_variables) != len(\n            self._model_variables_moving_average\n        ):\n            raise ValueError(\n                f\"The length of model variables ({len(trainable_variables)}) \"\n                \"to override does not match the length of model variables \"\n                \"stored in the optimizer \"\n                f\"({len(self._model_variables_moving_average)}). 
Please \"\n                \"check if the optimizer was called on your model.\"\n            )\n        for var, average_var in zip(\n            trainable_variables, self._model_variables_moving_average\n        ):\n            if average_var is not None:\n                var.assign(average_var)\n\n    def finalize_variable_values(self, var_list):\n        \"\"\"Set the final value of the model's trainable variables.\n\n        Sometimes there are some extra steps before ending the variable updates,\n        such as overwriting the model variables with their average values.\n\n        Args:\n          var_list: list of model variables.\n        \"\"\"\n        if self.use_ema:\n            # If the optimizer uses EMA, then when finalizing, we replace the\n            # model variable values with their moving averages stored inside\n            # the optimizer.\n            self._overwrite_model_variables_with_average_value(var_list)\n\n    def _obj_type(self):\n        return \"Optimizer\"\n\n    def get_config(self):\n        \"\"\"Returns the config of the optimizer.\n\n        An optimizer config is a Python dictionary (serializable)\n        containing the configuration of an optimizer.\n        The same optimizer can be reinstantiated later\n        (without any saved state) from this configuration.\n\n        Subclassed optimizers should override this method to include other\n        hyperparameters.\n\n        Returns:\n            Python dictionary.\n        \"\"\"\n\n        if isinstance(\n            self._learning_rate, learning_rate_schedule.LearningRateSchedule\n        ):\n            learning_rate = learning_rate_schedule.serialize(\n                self._learning_rate\n            )\n        elif isinstance(self._learning_rate, backend.Variable):\n            learning_rate = float(self._learning_rate.numpy())\n        elif ops.is_tensor(self._learning_rate):\n            learning_rate = float(self._learning_rate)\n        elif callable(self._learning_rate):\n        
    learning_rate = serialization_lib.serialize_keras_object(\n                self._learning_rate\n            )\n        else:\n            learning_rate = 0.5\n\n        config = {\n            \"name\": self.name,\n            \"learning_rate\": learning_rate,\n            \"weight_decay\": self.weight_decay,\n            \"clipnorm\": self.clipnorm,\n            \"global_clipnorm\": self.global_clipnorm,\n            \"clipvalue\": self.clipvalue,\n            \"use_ema\": self.use_ema,\n            \"ema_momentum\": self.ema_momentum,\n            \"ema_overwrite_frequency\": self.ema_overwrite_frequency,\n            \"loss_scale_factor\": self.loss_scale_factor,\n            \"gradient_accumulation_steps\": self.gradient_accumulation_steps,\n        }\n        return config\n\n    @classmethod\n    def from_config(cls, config, custom_objects=None):\n        \"\"\"Creates an optimizer from its config.\n\n        This method is the reverse of `get_config`, capable of instantiating the\n        same optimizer from the config dictionary.\n\n        Args:\n            config: A Python dictionary, typically the output of get_config.\n            custom_objects: A Python dictionary mapping names to additional\n              user-defined Python objects needed to recreate this optimizer.\n\n        Returns:\n            An optimizer instance.\n        \"\"\"\n        if \"learning_rate\" in config:\n            if isinstance(config[\"learning_rate\"], dict):\n                config[\"learning_rate\"] = (\n                    serialization_lib.deserialize_keras_object(\n                        config[\"learning_rate\"], custom_objects=custom_objects\n                    )\n                )\n        return cls(**config)\n\n    def __setattr__(self, name, value):\n        # Prevent users from attaching state to the\n        # layer before `super()` is called -- since that\n        # state would silently not be tracked.\n        if name != \"_lock\":\n            
self._check_super_called()\n        # Track Variables.\n        if hasattr(self, \"_tracker\"):\n            value = self._tracker.track(value)\n        return super().__setattr__(name, value)\n\n    def _clip_by_norm(self, values, axes=None):\n        # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm\n        l2sum = ops.sum(ops.square(values), axes, keepdims=True)\n        pred = l2sum > 0\n        # Two-tap tf.where trick to bypass NaN gradients\n        l2sum_safe = ops.where(pred, l2sum, ops.ones_like(l2sum))\n        l2norm = ops.where(pred, ops.sqrt(l2sum_safe), l2sum)\n        intermediate = ops.multiply(values, self.clipnorm)\n        values_clip = ops.convert_to_tensor(intermediate) / ops.maximum(\n            l2norm, self.clipnorm\n        )\n        return values_clip\n\n    def _untrack_variable(self, variable):\n        previous_lock_state = self._tracker.locked\n        self._tracker.unlock()\n        self._tracker.untrack(variable)\n        if previous_lock_state is True:\n            self._tracker.lock()\n\n\nbase_optimizer_keyword_args = \"\"\"name: String. The name to use\n            for momentum accumulator weights created by\n            the optimizer.\n        weight_decay: Float. If set, weight decay is applied.\n        clipnorm: Float. If set, the gradient of each weight is individually\n            clipped so that its norm is no higher than this value.\n        clipvalue: Float. If set, the gradient of each weight is clipped to be\n            no higher than this value.\n        global_clipnorm: Float. If set, the gradient of all weights is clipped\n            so that their global norm is no higher than this value.\n        use_ema: Boolean, defaults to `False`.\n            If `True`, exponential moving average\n            (EMA) is applied. 
EMA consists of computing an exponential moving\n            average of the weights of the model (as the weight values change\n            after each training batch), and periodically overwriting the\n            weights with their moving average.\n        ema_momentum: Float, defaults to 0.99. Only used if `use_ema=True`.\n            This is the momentum to use when computing\n            the EMA of the model's weights:\n            `new_average = ema_momentum * old_average + (1 - ema_momentum) *\n            current_variable_value`.\n        ema_overwrite_frequency: Int or None, defaults to None. Only used if\n            `use_ema=True`. Every `ema_overwrite_frequency` iterations, we\n            overwrite the model variables with their moving averages.\n            If None, the optimizer\n            does not overwrite model variables in the middle of training,\n            and you need to explicitly overwrite the variables\n            at the end of training by calling\n            `optimizer.finalize_variable_values()` (which updates the model\n            variables in-place). When using the built-in `fit()` training loop,\n            this happens automatically after the last epoch,\n            and you don't need to do anything.\n        loss_scale_factor: Float or `None`. If a float, the loss will be\n            multiplied by the scale factor before computing gradients, and the\n            gradients will be multiplied by the inverse of the scale factor\n            before updating variables. Useful for preventing underflow during\n            mixed precision training. Alternatively,\n            `keras.optimizers.LossScaleOptimizer` will\n            automatically set a loss scale factor.\n        gradient_accumulation_steps: Int or `None`. 
If an int, model & optimizer\n            variables will not be updated at every step; instead they will be\n            updated every `gradient_accumulation_steps` steps, using the average\n            value of the gradients since the last update. This is known as\n            \"gradient accumulation\". This can be useful\n            when your batch size is very small, in order to reduce gradient\n            noise at each update step. EMA frequency will look at \"accumulated\"\n            iterations value (optimizer steps // gradient_accumulation_steps).\n            Learning rate schedules will look at \"real\" iterations value\n            (optimizer steps).\n\"\"\"\n\n\ndef global_norm(value_list):\n    \"\"\"Computes the global norm of multiple tensors.\"\"\"\n    squared_norms = [\n        ops.sum(ops.square(v)) for v in value_list if v is not None\n    ]\n    squared_norm = ops.sum(ops.stack(squared_norms))\n    return ops.sqrt(squared_norm)\n\n\ndef clip_by_global_norm(value_list, clip_norm):\n    use_norm = global_norm(value_list)\n    # Calculate L2-norm, clip elements by ratio of clip_norm to L2-norm\n    scale_for_finite = clip_norm * ops.minimum(1.0 / use_norm, 1.0 / clip_norm)\n    # If use_norm is any finite number, this is a no-op. For inf/-inf/NaN,\n    # this will make scale NaN.\n    scale = scale_for_finite + (use_norm - use_norm)\n    return [v * scale if v is not None else v for v in value_list]\n"
  },
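The `clip_by_global_norm` helper above rescales every gradient by one shared factor so that their joint L2 norm never exceeds `clip_norm`, leaving gradients untouched when they are already within budget. A minimal pure-Python sketch of the same scheme (flat lists stand in for tensors; the function name mirrors the Keras helper, but this is an illustration, not the library code, and it assumes a nonzero global norm):

```python
import math

def clip_by_global_norm(grads, clip_norm):
    # Joint L2 norm across every gradient (flat lists stand in for tensors).
    global_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # scale is 1.0 when the norm is within budget, else clip_norm / global_norm.
    scale = clip_norm * min(1.0 / global_norm, 1.0 / clip_norm)
    return [[g * scale for g in grad] for grad in grads]
```

A gradient list with global norm 5 and `clip_norm=1` is scaled down by 1/5; the same list with `clip_norm=10` passes through unchanged.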
  {
    "path": "keras/src/optimizers/ftrl.py",
    "content": "from keras.src import initializers\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export([\"keras.optimizers.Ftrl\"])\nclass Ftrl(optimizer.Optimizer):\n    r\"\"\"Optimizer that implements the FTRL algorithm.\n\n    \"Follow The Regularized Leader\" (FTRL) is an optimization algorithm\n    developed at Google for click-through rate prediction in the early 2010s. It\n    is most suitable for shallow models with large and sparse feature spaces.\n    The algorithm is described by\n    [McMahan et al., 2013](https://research.google.com/pubs/archive/41159.pdf).\n    The Keras version has support for both online L2 regularization\n    (the L2 regularization described in the paper\n    above) and shrinkage-type L2 regularization\n    (which is the addition of an L2 penalty to the loss function).\n\n    Initialization:\n\n    ```python\n    n = 0\n    sigma = 0\n    z = 0\n    ```\n\n    Update rule for one variable `w`:\n\n    ```python\n    prev_n = n\n    n = n + g ** 2\n    sigma = (n ** -lr_power - prev_n ** -lr_power) / lr\n    z = z + g - sigma * w\n    if abs(z) < lambda_1:\n      w = 0\n    else:\n      w = (sgn(z) * lambda_1 - z) / ((beta + sqrt(n)) / alpha + lambda_2)\n    ```\n\n    Notation:\n\n    - `lr` is the learning rate\n    - `g` is the gradient for the variable\n    - `lambda_1` is the L1 regularization strength\n    - `lambda_2` is the L2 regularization strength\n    - `lr_power` is the power to scale n.\n\n    Check the documentation for the `l2_shrinkage_regularization_strength`\n    parameter for more details when shrinkage is enabled, in which case gradient\n    is replaced with a gradient with shrinkage.\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. 
Defaults to `0.001`.\n        learning_rate_power: A float value, must be less than or equal to zero.\n            Controls how the learning rate decreases during training. Use zero\n            for a fixed learning rate.\n        initial_accumulator_value: The starting value for accumulators. Only\n            zero or positive values are allowed.\n        l1_regularization_strength: A float value, must be greater than or equal\n            to zero. Defaults to `0.0`.\n        l2_regularization_strength: A float value, must be greater than or equal\n            to zero. Defaults to `0.0`.\n        l2_shrinkage_regularization_strength: A float value, must be greater\n            than or equal to zero. This differs from L2 above in that the L2\n            above is a stabilization penalty, whereas this L2 shrinkage is a\n            magnitude penalty. When input is sparse, shrinkage will only happen\n            on the active weights.\n        beta: A float value, representing the beta value from the paper.\n            Defaults to `0.0`.\n        {{base_optimizer_keyword_args}}\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.001,\n        learning_rate_power=-0.5,\n        initial_accumulator_value=0.1,\n        l1_regularization_strength=0.0,\n        l2_regularization_strength=0.0,\n        l2_shrinkage_regularization_strength=0.0,\n        beta=0.0,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"ftrl\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            name=name,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n     
       ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            **kwargs,\n        )\n\n        if initial_accumulator_value < 0.0:\n            raise ValueError(\n                \"`initial_accumulator_value` needs to be positive or zero. \"\n                \"Received: initial_accumulator_value=\"\n                f\"{initial_accumulator_value}.\"\n            )\n        if learning_rate_power > 0.0:\n            raise ValueError(\n                \"`learning_rate_power` needs to be negative or zero. Received: \"\n                f\"learning_rate_power={learning_rate_power}.\"\n            )\n        if l1_regularization_strength < 0.0:\n            raise ValueError(\n                \"`l1_regularization_strength` needs to be positive or zero. \"\n                \"Received: l1_regularization_strength=\"\n                f\"{l1_regularization_strength}.\"\n            )\n        if l2_regularization_strength < 0.0:\n            raise ValueError(\n                \"`l2_regularization_strength` needs to be positive or zero. \"\n                \"Received: l2_regularization_strength=\"\n                f\"{l2_regularization_strength}.\"\n            )\n        if l2_shrinkage_regularization_strength < 0.0:\n            raise ValueError(\n                \"`l2_shrinkage_regularization_strength` needs to be positive \"\n                \"or zero. 
Received: l2_shrinkage_regularization_strength\"\n                f\"={l2_shrinkage_regularization_strength}.\"\n            )\n\n        self.learning_rate_power = learning_rate_power\n        self.initial_accumulator_value = initial_accumulator_value\n        self.l1_regularization_strength = l1_regularization_strength\n        self.l2_regularization_strength = l2_regularization_strength\n        self.l2_shrinkage_regularization_strength = (\n            l2_shrinkage_regularization_strength\n        )\n        self.beta = beta\n\n    def build(self, var_list):\n        \"\"\"Initialize optimizer variables.\n\n        Args:\n            var_list: list of model variables to build Ftrl variables on.\n        \"\"\"\n        if self.built:\n            return\n        super().build(var_list)\n        accumulator_initializer = initializers.Constant(\n            self.initial_accumulator_value,\n        )\n        self._accumulators, self._linears = self.add_optimizer_variables(\n            var_list,\n            [\"accumulator\", \"linear\"],\n            initializer=[accumulator_initializer, \"zeros\"],\n        )\n\n    def update_step(self, gradient, variable, learning_rate):\n        \"\"\"Update step given gradient and the associated model variable.\"\"\"\n\n        lr = ops.cast(learning_rate, variable.dtype)\n        gradient = ops.cast(gradient, variable.dtype)\n\n        accum = self._accumulators[self._get_variable_index(variable)]\n        linear = self._linears[self._get_variable_index(variable)]\n\n        lr_power = self.learning_rate_power\n        l2_reg = self.l2_regularization_strength\n        l2_reg = l2_reg + self.beta / (2.0 * lr)\n\n        grad_to_use = ops.add(\n            gradient,\n            ops.multiply(\n                2 * self.l2_shrinkage_regularization_strength, variable\n            ),\n        )\n        new_accum = ops.add(accum, ops.square(gradient))\n        self.assign_add(\n            linear,\n            ops.subtract(\n    
            grad_to_use,\n                ops.multiply(\n                    ops.divide(\n                        ops.subtract(\n                            ops.power(new_accum, -lr_power),\n                            ops.power(accum, -lr_power),\n                        ),\n                        lr,\n                    ),\n                    variable,\n                ),\n            ),\n        )\n        quadratic = ops.add(\n            ops.divide(ops.power(new_accum, (-lr_power)), lr), 2 * l2_reg\n        )\n        linear_clipped = ops.clip(\n            linear,\n            -self.l1_regularization_strength,\n            self.l1_regularization_strength,\n        )\n        self.assign(\n            variable,\n            ops.divide(ops.subtract(linear_clipped, linear), quadratic),\n        )\n        self.assign(accum, new_accum)\n\n    def get_config(self):\n        config = super().get_config()\n\n        config.update(\n            {\n                \"learning_rate_power\": self.learning_rate_power,\n                \"initial_accumulator_value\": self.initial_accumulator_value,\n                \"l1_regularization_strength\": self.l1_regularization_strength,\n                \"l2_regularization_strength\": self.l2_regularization_strength,\n                \"l2_shrinkage_regularization_strength\": self.l2_shrinkage_regularization_strength,  # noqa: E501\n                \"beta\": self.beta,\n            }\n        )\n        return config\n\n\nFtrl.__doc__ = Ftrl.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
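The update rule quoted in the `Ftrl` docstring can be traced with a scalar, pure-Python sketch. This is a hypothetical helper for illustration only (not the library API): it fixes `lr_power = -0.5`, so `n ** -lr_power` becomes `sqrt(n)`, and omits the shrinkage term:

```python
import math

def ftrl_step(w, g, n, z, lr=0.5, lambda_1=0.0, lambda_2=0.0, beta=0.0):
    # Accumulate the squared gradient.
    prev_n = n
    n = n + g ** 2
    # sigma = (sqrt(n) - sqrt(prev_n)) / lr when lr_power == -0.5.
    sigma = (math.sqrt(n) - math.sqrt(prev_n)) / lr
    z = z + g - sigma * w
    if abs(z) < lambda_1:
        # L1 thresholding drives small coordinates exactly to zero.
        w = 0.0
    else:
        sgn_z = 1.0 if z > 0 else -1.0
        w = (sgn_z * lambda_1 - z) / ((beta + math.sqrt(n)) / lr + lambda_2)
    return w, n, z
```

With `lambda_1 = 0` the thresholding branch never fires, which matches FTRL degenerating to a per-coordinate adaptive gradient method when L1 regularization is disabled.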
  {
    "path": "keras/src/optimizers/ftrl_test.py",
    "content": "# flake8: noqa\n\n\nimport numpy as np\nfrom unittest import mock\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.optimizers.ftrl import Ftrl\n\n\nclass FtrlTest(testing.TestCase):\n    def test_config(self):\n        optimizer = Ftrl(\n            learning_rate=0.05,\n            learning_rate_power=-0.2,\n            initial_accumulator_value=0.4,\n            l1_regularization_strength=0.05,\n            l2_regularization_strength=0.15,\n            l2_shrinkage_regularization_strength=0.01,\n            beta=0.3,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_single_step(self):\n        optimizer = Ftrl(learning_rate=0.5)\n        grads = np.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(\n            vars, [0.2218, 1.3954, 2.3651, 2.8814], rtol=1e-4, atol=1e-4\n        )\n\n    def test_correctness_with_golden(self):\n        optimizer = Ftrl(\n            learning_rate=0.05,\n            learning_rate_power=-0.2,\n            initial_accumulator_value=0.4,\n            l1_regularization_strength=0.05,\n            l2_regularization_strength=0.15,\n            l2_shrinkage_regularization_strength=0.01,\n            beta=0.3,\n        )\n\n        x = backend.Variable(np.ones([10]))\n        grads = np.arange(0.1, 1.1, 0.1)\n        first_grads = np.full((10,), 0.01)\n\n        # fmt: off\n        golden = np.array(\n            [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],\n            [-0.0034, -0.0077, -0.0118, -0.0157, -0.0194, -0.023, -0.0263, -0.0294, -0.0325, -0.0354],\n            [-0.0078, -0.0162, -0.0242, -0.0317, -0.0387, -0.0454, -0.0516, -0.0575, -0.0631, -0.0685],\n            [-0.0121, -0.0246, -0.0363, -0.0472, -0.0573, -0.0668, -0.0757, -0.0842, -0.0922, -0.0999],\n            [-0.0164, -0.0328, -0.0481, -0.0623, -0.0753, -0.0875, 
-0.099, -0.1098, -0.1201, -0.1299]]\n        )\n        # fmt: on\n\n        optimizer.apply_gradients(zip([first_grads], [x]))\n        for i in range(5):\n            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            optimizer.apply_gradients(zip([grads], [x]))\n\n    def test_clip_norm(self):\n        optimizer = Ftrl(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = Ftrl(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n\n    def test_invalid_initial_accumulator_value(self):\n        invalid_value = -0.1\n        with self.assertRaisesRegex(\n            ValueError,\n            f\"^`initial_accumulator_value` needs to be positive or zero. Received: initial_accumulator_value={invalid_value}.$\",\n        ):\n            Ftrl(initial_accumulator_value=invalid_value)\n\n    def test_invalid_learning_rate_power(self):\n        invalid_value = 0.1\n        with self.assertRaisesRegex(\n            ValueError,\n            f\"^`learning_rate_power` needs to be negative or zero. Received: learning_rate_power={invalid_value}.$\",\n        ):\n            Ftrl(learning_rate_power=invalid_value)\n\n    def test_invalid_l1_regularization_strength(self):\n        invalid_value = -0.1\n        with self.assertRaisesRegex(\n            ValueError,\n            f\"^`l1_regularization_strength` needs to be positive or zero. 
Received: l1_regularization_strength={invalid_value}.$\",\n        ):\n            Ftrl(l1_regularization_strength=invalid_value)\n\n    def test_invalid_l2_regularization_strength(self):\n        invalid_value = -0.1\n        with self.assertRaisesRegex(\n            ValueError,\n            f\"^`l2_regularization_strength` needs to be positive or zero. Received: l2_regularization_strength={invalid_value}.$\",\n        ):\n            Ftrl(l2_regularization_strength=invalid_value)\n\n    def test_invalid_l2_shrinkage_regularization_strength(self):\n        invalid_value = -0.1\n        with self.assertRaisesRegex(\n            ValueError,\n            f\"^`l2_shrinkage_regularization_strength` needs to be positive or zero. Received: l2_shrinkage_regularization_strength={invalid_value}.$\",\n        ):\n            Ftrl(l2_shrinkage_regularization_strength=invalid_value)\n"
  },
  {
    "path": "keras/src/optimizers/lamb.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export(\"keras.optimizers.Lamb\")\nclass Lamb(optimizer.Optimizer):\n    \"\"\"Optimizer that implements the Lamb algorithm.\n\n    Lamb is a stochastic gradient descent method that\n    uses layer-wise adaptive moments to adjusts the\n    learning rate for each parameter based on the ratio of the\n    norm of the weight to the norm of the gradient\n    This helps to stabilize the training process and improves convergence\n    especially for large batch sizes.\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.001`.\n        beta_1: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use. The\n            exponential decay rate for the 1st moment estimates. Defaults to\n            `0.9`.\n        beta_2: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use. The\n            exponential decay rate for the 2nd moment estimates. 
Defaults to\n            `0.999`.\n        epsilon: A small constant for numerical stability.\n            Defaults to `1e-7`.\n        {{base_optimizer_keyword_args}}\n\n    References:\n        - [You et al., 2019](https://arxiv.org/pdf/1904.00962)\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.001,\n        beta_1=0.9,\n        beta_2=0.999,\n        epsilon=1e-7,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"lamb\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            name=name,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            **kwargs,\n        )\n        self.beta_1 = beta_1\n        self.beta_2 = beta_2\n        self.epsilon = epsilon\n\n    def build(self, var_list):\n        \"\"\"Initialize optimizer variables.\n\n        Lamb optimizer has 2 types of variables: momentums and velocities.\n\n        Args:\n            var_list: list of model variables to build Lamb variables on.\n        \"\"\"\n        if self.built:\n            return\n        super().build(var_list)\n        self._momentums, self._velocities = self.add_optimizer_variables(\n            var_list, [\"momentum\", \"velocity\"]\n        )\n\n    def update_step(self, gradient, variable, learning_rate):\n        \"\"\"Update step given gradient and the associated model variable.\"\"\"\n        lr = 
ops.cast(learning_rate, variable.dtype)\n        gradient = ops.cast(gradient, variable.dtype)\n        local_step = ops.cast(self.iterations + 1, variable.dtype)\n\n        beta_1_power = ops.power(\n            ops.cast(self.beta_1, variable.dtype), local_step\n        )\n        beta_2_power = ops.power(\n            ops.cast(self.beta_2, variable.dtype), local_step\n        )\n\n        m = self._momentums[self._get_variable_index(variable)]\n        v = self._velocities[self._get_variable_index(variable)]\n\n        self.assign_add(\n            m, ops.multiply(ops.subtract(gradient, m), 1 - self.beta_1)\n        )\n\n        self.assign_add(\n            v,\n            ops.multiply(\n                ops.subtract(ops.square(gradient), v), 1 - self.beta_2\n            ),\n        )\n\n        m_t_hat = ops.divide(m, (1.0 - beta_1_power))\n        v_sqrt = ops.add(\n            ops.sqrt(ops.divide(v, (1.0 - beta_2_power))), self.epsilon\n        )\n\n        update = ops.divide(m_t_hat, v_sqrt)\n        w_norm = ops.sqrt(ops.sum(ops.power(variable, 2)))\n        g_norm = ops.sqrt(ops.sum(ops.power(update, 2)))\n\n        # ratio = w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1\n        ratio = ops.where(\n            ops.greater(w_norm, 0),\n            ops.where(ops.greater(g_norm, 0), (w_norm / g_norm), 1.0),\n            1.0,\n        )\n\n        self.assign_sub(variable, ratio * lr * update)\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"beta_1\": self.beta_1,\n                \"beta_2\": self.beta_2,\n                \"epsilon\": self.epsilon,\n            }\n        )\n        return config\n\n\nLamb.__doc__ = Lamb.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
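Condensing `update_step` above: Lamb computes an Adam-style bias-corrected update, then rescales it per layer by the trust ratio `||w|| / ||update||`. A pure-Python sketch of one step for a single weight vector (an illustrative helper, not the library API; `t` is the 1-based step count):

```python
import math

def lamb_step(w, g, m, v, t, lr=0.001, beta_1=0.9, beta_2=0.999, eps=1e-7):
    # Adam-style first/second moment updates, element-wise.
    m = [mi + (gi - mi) * (1 - beta_1) for mi, gi in zip(m, g)]
    v = [vi + (gi * gi - vi) * (1 - beta_2) for vi, gi in zip(v, g)]
    # Bias correction uses the 1-based step count t.
    update = [
        (mi / (1 - beta_1 ** t)) / (math.sqrt(vi / (1 - beta_2 ** t)) + eps)
        for mi, vi in zip(m, v)
    ]
    # Layer-wise trust ratio ||w|| / ||update||, defaulting to 1 at zero norms.
    w_norm = math.sqrt(sum(x * x for x in w))
    u_norm = math.sqrt(sum(x * x for x in update))
    ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    w = [wi - ratio * lr * ui for wi, ui in zip(w, update)]
    return w, m, v
```

Because the trust ratio normalizes the update to the weight norm, the effective step size is decoupled from the raw gradient magnitude, which is what makes large-batch training more stable.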
  {
    "path": "keras/src/optimizers/lamb_test.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.optimizers.lamb import Lamb\n\n\nclass LambTest(testing.TestCase):\n    def test_config(self):\n        optimizer = Lamb(\n            learning_rate=0.5,\n            beta_1=0.5,\n            beta_2=0.67,\n            epsilon=1e-5,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_single_step(self):\n        optimizer = Lamb(learning_rate=0.5)\n        grads = ops.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(\n            vars, [-0.3693, 0.6306, 1.6306, 2.6306], rtol=1e-4, atol=1e-4\n        )\n\n    def test_weight_decay(self):\n        grads, var1, var2, var3 = (\n            ops.zeros(()),\n            backend.Variable(2.0),\n            backend.Variable(2.0, name=\"exclude\"),\n            backend.Variable(2.0),\n        )\n        optimizer_1 = Lamb(learning_rate=1.0, weight_decay=0.004)\n        optimizer_1.apply_gradients(zip([grads], [var1]))\n\n        optimizer_2 = Lamb(learning_rate=1.0, weight_decay=0.004)\n        optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n        optimizer_3 = Lamb(learning_rate=1.0, weight_decay=0.004)\n        optimizer_3.exclude_from_weight_decay(var_list=[var3])\n        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))\n\n        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)\n        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)\n        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)\n\n    def test_correctness_with_golden(self):\n        optimizer = Lamb()\n\n        x = backend.Variable(np.ones([10], dtype=\"float32\"))\n        grads = ops.arange(0.1, 1.1, 0.1)\n        first_grads = ops.full((10,), 
0.01)\n\n        golden = np.tile(\n            [[0.999], [0.9982], [0.9974], [0.9965], [0.9955]], (1, 10)\n        )\n\n        optimizer.apply_gradients(zip([first_grads], [x]))\n        for i in range(5):\n            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            optimizer.apply_gradients(zip([grads], [x]))\n\n    def test_clip_norm(self):\n        optimizer = Lamb(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = Lamb(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n"
  },
  {
    "path": "keras/src/optimizers/lion.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export([\"keras.optimizers.Lion\"])\nclass Lion(optimizer.Optimizer):\n    \"\"\"Optimizer that implements the Lion algorithm.\n\n    The Lion optimizer is a stochastic-gradient-descent method that uses the\n    sign operator to control the magnitude of the update, unlike other adaptive\n    optimizers such as Adam that rely on second-order moments. This makes\n    Lion more memory-efficient as it only keeps track of the momentum. According\n    to the authors (see reference), its performance gain over Adam grows with\n    the batch size. Because the update of Lion is produced through the sign\n    operation, resulting in a larger norm, a suitable learning rate for Lion is\n    typically 3-10x smaller than that for AdamW. The weight decay for Lion\n    should in turn be 3-10x larger than that for AdamW to maintain a\n    similar strength (lr * wd).\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.001`.\n        beta_1: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use. The\n            rate to combine the current gradient and the 1st moment estimate.\n            Defaults to `0.9`.\n        beta_2: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use. The\n            exponential decay rate for the 1st moment estimate. 
Defaults to\n            `0.99`.\n        {{base_optimizer_keyword_args}}\n\n    References:\n\n    - [Chen et al., 2023](http://arxiv.org/abs/2302.06675)\n    - [Authors' implementation](\n        http://github.com/google/automl/tree/master/lion)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.001,\n        beta_1=0.9,\n        beta_2=0.99,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"lion\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            name=name,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            **kwargs,\n        )\n        self.beta_1 = beta_1\n        self.beta_2 = beta_2\n        if beta_1 <= 0 or beta_1 > 1:\n            raise ValueError(\n                \"Argument `beta_1` must be in the [0, 1] range. Otherwise, the \"\n                f\"optimizer degenerates to SignSGD. 
Received: beta_1={beta_1}.\"\n            )\n\n    def build(self, var_list):\n        \"\"\"Initialize optimizer variables.\n\n        Lion optimizer has one variable `momentums`.\n\n        Args:\n            var_list: list of model variables to build Lion variables on.\n        \"\"\"\n        if self.built:\n            return\n        super().build(var_list)\n        self._momentums = self.add_optimizer_variables(var_list, \"momentum\")\n\n    def update_step(self, gradient, variable, learning_rate):\n        \"\"\"Update step given gradient and the associated model variable.\"\"\"\n        lr = ops.cast(learning_rate, variable.dtype)\n        gradient = ops.cast(gradient, variable.dtype)\n        beta_1 = ops.cast(self.beta_1, variable.dtype)\n        beta_2 = ops.cast(self.beta_2, variable.dtype)\n        m = self._momentums[self._get_variable_index(variable)]\n\n        self.assign_sub(\n            variable,\n            ops.multiply(\n                lr,\n                ops.sign(\n                    ops.add(\n                        ops.multiply(m, beta_1),\n                        ops.multiply(gradient, (1.0 - beta_1)),\n                    )\n                ),\n            ),\n        )\n        self.assign(\n            m,\n            ops.add(\n                ops.multiply(m, beta_2), ops.multiply(gradient, (1.0 - beta_2))\n            ),\n        )\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"beta_1\": self.beta_1,\n                \"beta_2\": self.beta_2,\n            }\n        )\n        return config\n\n\nLion.__doc__ = Lion.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
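The `update_step` above makes Lion's memory economy concrete: a single state tensor `m` is kept, and the applied update is `lr * sign(...)`, so every coordinate moves by at most `lr` per step (the larger-norm updates the docstring mentions when comparing against AdamW). A small pure-Python sketch, mirroring the math only (a hypothetical helper, not the library API):

```python
def lion_step(w, g, m, lr=0.001, beta_1=0.9, beta_2=0.99):
    def sign(x):
        return (x > 0) - (x < 0)

    # The applied update uses only the sign of an interpolation between the
    # momentum and the fresh gradient, so each coordinate moves by +/- lr or 0.
    w = [
        wi - lr * sign(beta_1 * mi + (1 - beta_1) * gi)
        for wi, mi, gi in zip(w, m, g)
    ]
    # The momentum itself decays with the slower rate beta_2.
    m = [beta_2 * mi + (1 - beta_2) * gi for mi, gi in zip(m, g)]
    return w, m
```

With zero initial momentum and all-positive gradients, every coordinate moves by exactly `lr`, which reproduces the expectation in `test_single_step` of `lion_test.py`.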
  {
    "path": "keras/src/optimizers/lion_test.py",
    "content": "import numpy as np\nimport pytest\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.optimizers.lion import Lion\n\n\nclass LionTest(testing.TestCase):\n    def test_invalid_beta_1(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Argument `beta_1` must be in the \\\\[0, 1\\\\] range. Otherwise, the \"\n            \"optimizer degenerates to SignSGD. Received: beta_1=-0.1.\",\n        ):\n            Lion(beta_1=-0.1)\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Argument `beta_1` must be in the \\\\[0, 1\\\\] range. Otherwise, the \"\n            \"optimizer degenerates to SignSGD. Received: beta_1=0.0.\",\n        ):\n            Lion(beta_1=0.0)\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Argument `beta_1` must be in the \\\\[0, 1\\\\] range. Otherwise, the \"\n            \"optimizer degenerates to SignSGD. 
Received: beta_1=1.1.\",\n        ):\n            Lion(beta_1=1.1)\n\n    def test_config(self):\n        optimizer = Lion(\n            learning_rate=0.5,\n            beta_1=0.5,\n            beta_2=0.67,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_single_step(self):\n        optimizer = Lion(learning_rate=0.5)\n        grads = ops.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)\n\n    def test_weight_decay(self):\n        grads, var1, var2, var3 = (\n            ops.zeros(()),\n            backend.Variable(2.0),\n            backend.Variable(2.0, name=\"exclude\"),\n            backend.Variable(2.0),\n        )\n        optimizer_1 = Lion(learning_rate=1.0, weight_decay=0.004)\n        optimizer_1.apply_gradients(zip([grads], [var1]))\n\n        optimizer_2 = Lion(learning_rate=1.0, weight_decay=0.004)\n        optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n        optimizer_3 = Lion(learning_rate=1.0, weight_decay=0.004)\n        optimizer_3.exclude_from_weight_decay(var_list=[var3])\n        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))\n\n        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)\n        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)\n        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)\n\n    def test_correctness_with_golden(self):\n        optimizer = Lion()\n\n        x = backend.Variable(np.ones([10]))\n        grads = ops.arange(0.1, 1.1, 0.1)\n        first_grads = ops.full((10,), 0.01)\n\n        golden = np.tile(\n            [[0.999], [0.998], [0.997], [0.996], [0.995]],\n            (1, 10),\n        )\n\n        optimizer.apply_gradients(zip([first_grads], [x]))\n        for i in range(5):\n         
   self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            optimizer.apply_gradients(zip([grads], [x]))\n\n    def test_clip_norm(self):\n        optimizer = Lion(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = Lion(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n\n    @pytest.mark.requires_trainable_backend\n    def test_ema(self):\n        # TODO: test correctness\n        model = keras.Sequential([keras.layers.Dense(10)])\n        model.compile(optimizer=Lion(use_ema=True), loss=\"mse\")\n        x = keras.ops.zeros((1, 5))\n        y = keras.ops.zeros((1, 10))\n        model.fit(x, y)\n"
  },
  {
    "path": "keras/src/optimizers/loss_scale_optimizer.py",
    "content": "from keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils import tracking\n\n\n@keras_export(\n    [\n        \"keras.optimizers.LossScaleOptimizer\",\n        \"keras.mixed_precision.LossScaleOptimizer\",\n    ]\n)\nclass LossScaleOptimizer(optimizer.Optimizer):\n    \"\"\"An optimizer that dynamically scales the loss to prevent underflow.\n\n    Loss scaling is a technique to prevent numeric underflow in intermediate\n    gradients when float16 is used. To prevent underflow, the loss is multiplied\n    (or \"scaled\") by a certain factor called the \"loss scale\", which causes\n    intermediate gradients to be scaled by the loss scale as well. The final\n    gradients are divided (or \"unscaled\") by the loss scale to bring them back\n    to their original value.\n\n    `LossScaleOptimizer` wraps another optimizer and applies dynamic loss\n    scaling to it. This loss scale is dynamically updated over time as follows:\n    - On any train step, if a nonfinite gradient is encountered, the loss scale\n      is halved, and the train step is skipped.\n    - If `dynamic_growth_steps` have occurred since the last time the loss scale\n      was updated, and no nonfinite gradients have occurred, the loss scale\n      is doubled.\n\n    Args:\n        inner_optimizer: The `keras.optimizers.Optimizer` instance to wrap.\n        initial_scale: Float. The initial loss scale. This scale will be updated\n            during training. It is recommended for this to be a very high\n            number, because a loss scale that is too high gets lowered far more\n            quickly than a loss scale that is too low gets raised.\n        dynamic_growth_steps: Int. How often to update the scale upwards. 
After\n            every `dynamic_growth_steps` steps with finite gradients, the\n            loss scale is doubled.\n        {{base_optimizer_keyword_args}}\n    \"\"\"\n\n    def __init__(\n        self,\n        inner_optimizer,\n        initial_scale=2.0**15,\n        dynamic_growth_steps=2000,\n        name=None,\n        **kwargs,\n    ):\n        if not kwargs.pop(\"dynamic\", True):\n            raise ValueError(\n                \"LossScaleOptimizer no longer supports `dynamic=False`. \"\n                \"Instead, simply set `loss_scale_factor` directly on the \"\n                \"`inner_optimizer`.\"\n            )\n\n        # Backwards compatibility code for deserialization.\n        # LossScaleOptimizer used to return all these parameters in `get_config`\n        # via `super().get_config()` even though they are all non-functional.\n        # We no longer let users set them, but we have to allow the default\n        # values to be passed during deserialization to support older models.\n        base_optimizer_defaults = {\n            \"weight_decay\": None,\n            \"clipnorm\": None,\n            \"global_clipnorm\": None,\n            \"clipvalue\": None,\n            \"use_ema\": False,\n            \"ema_momentum\": 0.99,\n            \"ema_overwrite_frequency\": None,\n            \"loss_scale_factor\": None,\n            \"gradient_accumulation_steps\": None,\n        }\n        for arg_name, default_value in base_optimizer_defaults.items():\n            if arg_name not in kwargs:\n                continue\n            arg_value = kwargs.pop(arg_name)\n            if (\n                default_value is None and arg_value is not None\n            ) or arg_value != default_value:\n                raise ValueError(\n                    f\"LossScaleOptimizer does not support `{arg_name}`. 
\"\n                    f\"Instead, set `{arg_name}` on the `inner_optimizer`.\"\n                )\n\n        if kwargs:\n            raise ValueError(\n                \"LossScaleOptimizer does not support arguments: \"\n                f\"`{'`, `'.join(kwargs.keys())}`.\"\n            )\n\n        super().__init__(learning_rate=0.0, name=name)\n        self.inner_optimizer = inner_optimizer\n        self.initial_scale = initial_scale\n        self.dynamic_growth_steps = dynamic_growth_steps\n        # Disable the inner optimizer's loss scaling, otherwise\n        # gradients will be scaled twice.\n        self.inner_optimizer.loss_scale_factor = None\n\n    @tracking.no_automatic_dependency_tracking\n    def build(self, var_list):\n        self.step_counter = self.add_variable(\n            shape=(),\n            dtype=\"int\",\n            initializer=initializers.Zeros(),\n            aggregation=\"none\",\n            name=\"step_counter\",\n        )\n        self.dynamic_scale = self.add_variable(\n            shape=(),\n            dtype=\"float32\",\n            initializer=initializers.Constant(self.initial_scale),\n            aggregation=\"none\",\n            name=\"dynamic_scale\",\n        )\n        self.inner_optimizer.build(var_list)\n        super().build(var_list)\n\n    @property\n    def variables(self):\n        return self._variables + self.inner_optimizer.variables\n\n    def stateless_apply(self, optimizer_variables, grads, trainable_variables):\n        if not self.built:\n            raise ValueError(\n                f\"To call `stateless_apply`, {self.__class__.__name__} \"\n                \"must be built (i.e. its variables must have been created). 
\"\n                \"You can build it via `optimizer.build(trainable_variables)`.\"\n            )\n        finite = self.check_finite(grads)\n        return ops.cond(\n            finite,\n            lambda: self._stateless_handle_finite_grads(\n                optimizer_variables, grads, trainable_variables\n            ),\n            lambda: self._stateless_handle_non_finite_grads(\n                optimizer_variables, trainable_variables\n            ),\n        )\n\n    def _stateless_handle_finite_grads(\n        self, optimizer_variables, grads, trainable_variables\n    ):\n        def upscale():\n            mapping = list(zip(self.variables, optimizer_variables))\n            with backend.StatelessScope(state_mapping=mapping) as scope:\n                self.step_counter.assign(0)\n                self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 2.0))\n            return [scope.get_current_value(v) for v in self._variables]\n\n        def increment():\n            mapping = list(zip(self.variables, optimizer_variables))\n            with backend.StatelessScope(state_mapping=mapping) as scope:\n                self.step_counter.assign_add(1)\n            return [scope.get_current_value(v) for v in self._variables]\n\n        mapping = list(zip(self.variables, optimizer_variables))\n        with backend.StatelessScope(state_mapping=mapping):\n            # Potentially upscale loss and reset counter.\n            own_variables = ops.cond(\n                ops.equal(self.step_counter, self.dynamic_growth_steps - 1),\n                upscale,\n                increment,\n            )\n\n            # Unscale gradients.\n            scale = self.dynamic_scale\n            unscaled_grads = [\n                g\n                if g is None or self._overwrite_variable_with_gradient(v)\n                else ops.divide(g, scale)\n                for g, v in zip(grads, self._trainable_variables)\n            ]\n            (\n                
new_trainable_variables,\n                new_inner_variables,\n            ) = self.inner_optimizer.stateless_apply(\n                self.inner_optimizer.variables,\n                unscaled_grads,\n                trainable_variables,\n            )\n\n        new_optimizer_variables = own_variables + new_inner_variables\n        return new_trainable_variables, new_optimizer_variables\n\n    def _stateless_handle_non_finite_grads(\n        self, optimizer_variables, trainable_variables\n    ):\n        mapping = list(zip(self.variables, optimizer_variables))\n        with backend.StatelessScope(state_mapping=mapping) as scope:\n            self.step_counter.assign(0)\n            self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 0.5))\n        new_optimizer_variables = []\n        for v in self.variables:\n            new_optimizer_variables.append(scope.get_current_value(v))\n        return trainable_variables, new_optimizer_variables\n\n    def apply(self, grads, trainable_variables=None):\n        # Optionally build optimizer.\n        if not self.built:\n            with backend.name_scope(self.name, caller=self):\n                self.build(trainable_variables)\n            self.built = True\n\n        if backend.backend() == \"tensorflow\":\n            self._tf_apply(grads, trainable_variables)\n        else:\n            self._common_apply(grads, trainable_variables)\n\n    def _stateful_handle_finite_grads(self, grads, trainable_variables):\n        scale = self.dynamic_scale\n        # Unscale gradients.\n        tvs = trainable_variables or self._trainable_variables\n        unscaled_grads = [\n            g\n            if g is None or self._overwrite_variable_with_gradient(v)\n            else ops.divide(g, scale)\n            for g, v in zip(grads, tvs)\n        ]\n        self.inner_optimizer.apply(\n            unscaled_grads, trainable_variables=trainable_variables\n        )\n\n        def upscale():\n            
self.step_counter.assign(0)\n            self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 2.0))\n\n        def increment():\n            self.step_counter.assign_add(1)\n\n        # Potentially upscale loss and reset counter.\n        ops.cond(\n            ops.equal(self.step_counter, self.dynamic_growth_steps - 1),\n            upscale,\n            increment,\n        )\n\n    def _stateful_handle_non_finite_grads(self):\n        # If any inf or nan in grads, downscale loss and reset counter.\n        self.step_counter.assign(0)\n        self.dynamic_scale.assign(ops.multiply(self.dynamic_scale, 0.5))\n\n    def _common_apply(self, grads, trainable_variables=None):\n        finite = self.check_finite(grads)\n        ops.cond(\n            finite,\n            lambda: self._stateful_handle_finite_grads(\n                grads, trainable_variables\n            ),\n            self._stateful_handle_non_finite_grads,\n        )\n\n    def _tf_apply(self, grads, trainable_variables=None):\n        \"\"\"Tensorflow specific logic for apply, which handles distribution.\"\"\"\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        if tf.distribute.in_cross_replica_context():\n            raise ValueError(\"apply() must be called in a replica context.\")\n\n        if tf.__internal__.distribute.strategy_supports_no_merge_call():\n            self._common_apply(grads, trainable_variables=trainable_variables)\n        else:\n\n            def _handle_cross_replica(distribution, grads, trainable_variables):\n                finite_per_replica = (\n                    distribution.extended.call_for_each_replica(\n                        self.check_finite, args=(grads,)\n                    )\n                )\n                # Each replica computed the same `finite` value, since\n                # `grads` is all-reduced across replicas. 
Arbitrarily take\n                # `finite` from the first replica.\n                finite = distribution.experimental_local_results(\n                    finite_per_replica\n                )[0]\n\n                def apply_fn():\n                    distribution.extended.call_for_each_replica(\n                        self._stateful_handle_finite_grads,\n                        args=(grads, trainable_variables),\n                    )\n\n                # Note: We must call this cond() in a cross-replica context.\n                # DistributionStrategy does not support having a cond in a\n                # replica context with a branch that calls `merge_call`, and\n                # self._optimizer.apply_gradients calls `merge_call`.\n                ops.cond(\n                    finite, apply_fn, self._stateful_handle_non_finite_grads\n                )\n\n            tf.distribute.get_replica_context().merge_call(\n                _handle_cross_replica, args=(grads, trainable_variables)\n            )\n\n    def check_finite(self, grads):\n        tensor_grads = [g for g in grads if g is not None]\n        finite_grads = [ops.all(ops.isfinite(g)) for g in tensor_grads]\n        return ops.all(ops.convert_to_tensor(finite_grads))\n\n    @property\n    def learning_rate(self):\n        return self.inner_optimizer.learning_rate\n\n    @learning_rate.setter\n    def learning_rate(self, learning_rate):\n        self.inner_optimizer.learning_rate = learning_rate\n\n    @property\n    def iterations(self):\n        return self.inner_optimizer.iterations\n\n    def scale_loss(self, loss):\n        scale = self.dynamic_scale if self.built else self.initial_scale\n        return ops.multiply(loss, scale)\n\n    def finalize_variable_values(self, var_list):\n        self.inner_optimizer.finalize_variable_values(var_list)\n\n    def get_config(self):\n        # Do not use super().get_config() as only \"name\" is supported.\n        inner_optimizer_config = 
serialization_lib.serialize_keras_object(\n            self.inner_optimizer\n        )\n        return {\n            \"name\": self.name,\n            \"inner_optimizer\": inner_optimizer_config,\n            \"initial_scale\": self.initial_scale,\n            \"dynamic_growth_steps\": self.dynamic_growth_steps,\n        }\n\n    @classmethod\n    def from_config(cls, config, custom_objects=None):\n        inner_optimizer = serialization_lib.deserialize_keras_object(\n            config.pop(\"inner_optimizer\"),\n            custom_objects=custom_objects,\n        )\n        return cls(inner_optimizer, **config)\n\n\nLossScaleOptimizer.__doc__ = LossScaleOptimizer.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
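The dynamic loss-scaling policy described in the `LossScaleOptimizer` docstring above (halve the scale and skip the step when any gradient is nonfinite; double it after `dynamic_growth_steps` consecutive finite steps) can be sketched in isolation. This is a plain-Python illustration of the update rule, not the Keras implementation; the class and method names are invented for the sketch:

```python
import math

class DynamicLossScaler:
    """Toy model of the dynamic loss-scale schedule (not Keras code)."""

    def __init__(self, initial_scale=2.0**15, growth_steps=2000):
        self.scale = initial_scale
        self.growth_steps = growth_steps
        self.counter = 0  # finite steps since the last scale change

    def scale_loss(self, loss):
        # Multiply the loss so intermediate float16 gradients don't underflow.
        return loss * self.scale

    def unscale_grads(self, grads):
        # Divide gradients back down before the inner optimizer applies them.
        return [g / self.scale for g in grads]

    def update(self, grads):
        """Returns True if the step should be applied, False if skipped."""
        if not all(math.isfinite(g) for g in grads):
            # Overflow: shrink the scale and skip this train step.
            self.scale /= 2.0
            self.counter = 0
            return False
        self.counter += 1
        if self.counter >= self.growth_steps:
            # A long run of finite steps: try a larger scale.
            self.scale *= 2.0
            self.counter = 0
        return True
```

This mirrors why the docstring recommends a very high `initial_scale`: a too-high scale is corrected within a few halving steps, while a too-low scale takes `dynamic_growth_steps` steps per doubling to recover.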
  {
    "path": "keras/src/optimizers/loss_scale_optimizer_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer\nfrom keras.src.optimizers.sgd import SGD\n\n\nclass LossScaleOptimizerTest(testing.TestCase):\n    def _skip_test_for_stateless(self, stateless):\n        if not stateless and backend.backend() == \"jax\":\n            self.skipTest(\n                \"LossScaleOptimizer must use stateless_apply with JAX.\"\n            )\n        if stateless and backend.backend() == \"tensorflow\":\n            self.skipTest(\n                \"stateless_apply is not supported with the TF backend.\"\n            )\n\n    def test_config(self):\n        inner_optimizer = SGD(\n            learning_rate=0.5,\n            momentum=0.06,\n            nesterov=True,\n            weight_decay=0.004,\n        )\n        optimizer = LossScaleOptimizer(inner_optimizer)\n        self.run_class_serialization_test(optimizer)\n\n    def test_apply_with_no_vars(self):\n        self._skip_test_for_stateless(False)\n\n        inner_optimizer = SGD(learning_rate=0.5)\n        optimizer = LossScaleOptimizer(inner_optimizer)\n        grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale]\n        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]\n        optimizer.build(vars)\n        optimizer.apply(grads)\n        self.assertAllClose(\n            vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4\n        )\n\n    @parameterized.named_parameters((\"stateless\", True), (\"stateful\", False))\n    def test_finite_step(self, stateless):\n        self._skip_test_for_stateless(stateless)\n\n        inner_optimizer = SGD(learning_rate=0.5)\n        optimizer = LossScaleOptimizer(inner_optimizer)\n        grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale]\n        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]\n        if stateless:\n  
          optimizer.build(vars)\n            vars, _ = optimizer.stateless_apply(\n                [v.value for v in optimizer.variables],\n                grads,\n                [v.value for v in vars],\n            )\n        else:\n            optimizer.apply(grads, vars)\n        self.assertAllClose(\n            vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4\n        )\n\n    @parameterized.named_parameters((\"stateless\", True), (\"stateful\", False))\n    def test_finite_step_with_inner_loss_scale(self, stateless):\n        self._skip_test_for_stateless(stateless)\n\n        # Ensure that the inner loss scale does not interfere with the update.\n        inner_optimizer = SGD(learning_rate=0.5, loss_scale_factor=100)\n        optimizer = LossScaleOptimizer(inner_optimizer)\n        grads = [ops.array([1.0, 6.0, 7.0, 2.0]) * optimizer.initial_scale]\n        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]\n        if stateless:\n            optimizer.build(vars)\n            vars, _ = optimizer.stateless_apply(\n                [v.value for v in optimizer.variables],\n                grads,\n                [v.value for v in vars],\n            )\n        else:\n            optimizer.apply(grads, vars)\n        self.assertAllClose(\n            vars, [[0.5, -1.0, -0.5, 3.0]], rtol=1e-4, atol=1e-4\n        )\n\n    @parameterized.named_parameters((\"stateless\", True), (\"stateful\", False))\n    def test_infinite_step(self, stateless):\n        self._skip_test_for_stateless(stateless)\n\n        inner_optimizer = SGD(learning_rate=0.5)\n        optimizer = LossScaleOptimizer(inner_optimizer)\n        grads = [ops.array([np.inf, np.inf, np.inf, np.inf])]\n        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]\n        if stateless:\n            optimizer.build(vars)\n            vars, _ = optimizer.stateless_apply(\n                [v.value for v in optimizer.variables],\n                grads,\n                [v.value for v in vars],\n            )\n  
      else:\n            optimizer.apply(grads, vars)\n        self.assertAllClose(vars, [[1.0, 2.0, 3.0, 4.0]], rtol=1e-4, atol=1e-4)\n\n    @parameterized.named_parameters((\"stateless\", True), (\"stateful\", False))\n    def test_finite_step_with_overwrite(self, stateless):\n        self._skip_test_for_stateless(stateless)\n\n        inner_optimizer = SGD(learning_rate=0.5)\n        optimizer = LossScaleOptimizer(inner_optimizer)\n        grads = [ops.array([1.0, 6.0, 7.0, 2.0])]\n        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]\n        vars[0].overwrite_with_gradient = True\n\n        if stateless:\n            optimizer.build(vars)\n            vars, _ = optimizer.stateless_apply(\n                [v.value for v in optimizer.variables],\n                grads,\n                [v.value for v in vars],\n            )\n        else:\n            optimizer.apply(grads, vars)\n        self.assertAllClose(vars, grads)\n\n    @parameterized.named_parameters((\"stateless\", True), (\"stateful\", False))\n    def test_downscaling(self, stateless):\n        self._skip_test_for_stateless(stateless)\n\n        inner_optimizer = SGD(learning_rate=0.5)\n        optimizer = LossScaleOptimizer(inner_optimizer, initial_scale=400.0)\n        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]\n        optimizer.build(vars)\n        opt_var_values = [v.value for v in optimizer.variables]\n        grads = [ops.array([np.inf, np.inf, np.inf, np.inf])]\n        for _ in range(4):\n            if stateless:\n                _, opt_var_values = optimizer.stateless_apply(\n                    opt_var_values, grads, [v.value for v in vars]\n                )\n                for ref_v, v in zip(optimizer.variables, opt_var_values):\n                    ref_v.assign(v)\n            else:\n                optimizer.apply(grads, vars)\n        self.assertAllClose(optimizer.scale_loss(1.0), 25.0)\n\n    @parameterized.named_parameters((\"stateless\", True), (\"stateful\", False))\n   
 def test_upscaling(self, stateless):\n        self._skip_test_for_stateless(stateless)\n\n        inner_optimizer = SGD(learning_rate=0.5)\n        optimizer = LossScaleOptimizer(\n            inner_optimizer,\n            initial_scale=2.0,\n            dynamic_growth_steps=2,\n        )\n        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]\n        optimizer.build(vars)\n        opt_var_values = [v.value for v in optimizer.variables]\n        grads = [ops.array([1.0, 6.0, 7.0, 2.0])]\n        for _ in range(8):\n            if stateless:\n                _, opt_var_values = optimizer.stateless_apply(\n                    opt_var_values, grads, [v.value for v in vars]\n                )\n                for ref_v, v in zip(optimizer.variables, opt_var_values):\n                    ref_v.assign(v)\n            else:\n                optimizer.apply(grads, vars)\n        self.assertAllClose(optimizer.scale_loss(1.0), 32.0)\n\n    @parameterized.named_parameters((\"stateless\", True), (\"stateful\", False))\n    def test_iterations_update(self, stateless):\n        self._skip_test_for_stateless(stateless)\n\n        inner_optimizer = SGD(learning_rate=0.5)\n        optimizer = LossScaleOptimizer(inner_optimizer)\n        vars = [backend.Variable([1.0, 2.0, 3.0, 4.0])]\n        optimizer.build(vars)\n        opt_var_values = [v.value for v in optimizer.variables]\n        grads = [ops.array([1.0, 6.0, 7.0, 2.0])]\n\n        self.assertEqual(optimizer.iterations.value, 0)\n\n        for i in range(3):\n            if stateless:\n                _, opt_var_values = optimizer.stateless_apply(\n                    opt_var_values, grads, [v.value for v in vars]\n                )\n                for ref_v, v in zip(optimizer.variables, opt_var_values):\n                    ref_v.assign(v)\n            else:\n                optimizer.apply(grads, vars)\n            self.assertEqual(optimizer.iterations.value, i + 1)\n\n    def test_serialization(self):\n        
inner_optimizer = SGD(learning_rate=0.5)\n        optimizer = LossScaleOptimizer(\n            inner_optimizer,\n            initial_scale=3.0,\n            dynamic_growth_steps=2,\n            name=\"test_opt\",\n        )\n        config = optimizer.get_config()\n        self.assertLen(config, 4)\n        self.assertEqual(config[\"name\"], \"test_opt\")\n        self.assertEqual(config[\"initial_scale\"], 3.0)\n        self.assertEqual(config[\"dynamic_growth_steps\"], 2)\n        self.assertIn(\"inner_optimizer\", config)\n        LossScaleOptimizer.from_config(config)\n\n    def test_init_dynamic_arg(self):\n        inner_optimizer = SGD(learning_rate=0.5)\n\n        # dynamic=True is supported\n        LossScaleOptimizer(inner_optimizer, dynamic=True)\n\n        # dynamic=False is not supported\n        with self.assertRaisesRegex(ValueError, \"set `loss_scale_factor`\"):\n            LossScaleOptimizer(inner_optimizer, dynamic=False)\n\n    def test_init_unsupported_arg(self):\n        inner_optimizer = SGD(learning_rate=0.5)\n        with self.assertRaisesRegex(ValueError, \"arguments: `foo`, `bar`\"):\n            LossScaleOptimizer(inner_optimizer, foo=True, bar=3)\n\n    @parameterized.named_parameters(\n        (\"weight_decay\", \"weight_decay\", 0.5),\n        (\"clipnorm\", \"clipnorm\", 0.5),\n        (\"global_clipnorm\", \"global_clipnorm\", 0.5),\n        (\"clipvalue\", \"clipvalue\", 0.5),\n        (\"use_ema\", \"use_ema\", True),\n        (\"ema_momentum\", \"ema_momentum\", 0.5),\n        (\"ema_overwrite_frequency\", \"ema_overwrite_frequency\", 2),\n        (\"loss_scale_factor\", \"loss_scale_factor\", 0.5),\n        (\"gradient_accumulation_steps\", \"gradient_accumulation_steps\", 2),\n    )\n    def test_init_base_optimizer_unsupported_args(self, arg_name, arg_value):\n        inner_optimizer = SGD(learning_rate=0.5)\n        with self.assertRaisesRegex(ValueError, \"on the `inner_optimizer`\"):\n            
LossScaleOptimizer(inner_optimizer, **{arg_name: arg_value})\n\n    def test_deserialization_backwards_compatibility(self):\n        # Test deserializing with a config that has all the unsupported\n        # arguments from the base optimizer (which are no longer serialized)\n        config = {\n            \"name\": \"loss_scale_optimizer\",\n            \"weight_decay\": None,\n            \"clipnorm\": None,\n            \"global_clipnorm\": None,\n            \"clipvalue\": None,\n            \"use_ema\": False,\n            \"ema_momentum\": 0.99,\n            \"ema_overwrite_frequency\": None,\n            \"loss_scale_factor\": None,\n            \"gradient_accumulation_steps\": None,\n            \"inner_optimizer\": {\n                \"module\": \"keras.optimizers\",\n                \"class_name\": \"SGD\",\n                \"config\": {\n                    \"name\": \"SGD\",\n                    \"learning_rate\": 0.5,\n                    \"weight_decay\": None,\n                    \"clipnorm\": None,\n                    \"global_clipnorm\": None,\n                    \"clipvalue\": None,\n                    \"use_ema\": False,\n                    \"ema_momentum\": 0.99,\n                    \"ema_overwrite_frequency\": None,\n                    \"loss_scale_factor\": None,\n                    \"gradient_accumulation_steps\": None,\n                    \"momentum\": 0.0,\n                    \"nesterov\": False,\n                },\n                \"registered_name\": None,\n            },\n            \"initial_scale\": 2.0,\n            \"dynamic_growth_steps\": 2,\n        }\n        LossScaleOptimizer.from_config(config)\n"
  },
  {
    "path": "keras/src/optimizers/muon.py",
    "content": "import re\n\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export([\"keras.optimizers.Muon\"])\nclass Muon(optimizer.Optimizer):\n    \"\"\"Optimizer that implements the Muon algorithm.\n\n    Note that this optimizer should not be used in the following layers:\n\n    1. Embedding layer\n    2. Final output fully connected layer\n    3. Any {0,1}-D variables\n\n    These should all be optimized using AdamW.\n\n    The Muon optimizer can use either the Muon update step or the\n    AdamW update step, based on the following:\n\n    - For any variable that isn't 2D, the AdamW step\n        will be used. This is not configurable.\n    - If the argument `exclude_embeddings` (defaults to `True`) is set\n        to `True`, the AdamW step will be used for embedding layers.\n    - For any variable with a name that matches an expression\n        listed in the argument `exclude_layers` (a list), the\n        AdamW step will be used.\n    - Any other variable uses the Muon step.\n\n    Typically, you only need to pass the name of your densely-connected\n    output layer to `exclude_layers`, e.g.\n    `exclude_layers=[\"output_dense\"]`.\n\n    References:\n        - [Original implementation](https://github.com/KellerJordan/Muon)\n        - [Liu et al, 2025](https://arxiv.org/abs/2502.16982)\n\n    Args:\n        learning_rate: A float,\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.001`.\n        adam_beta_1: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use.\n            The exponential decay rate for the 1st moment estimates. 
Defaults to\n            `0.9`.\n        adam_beta_2: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use.\n            The exponential decay rate for the 2nd moment estimates. Defaults to\n            `0.999`.\n        adam_weight_decay: Float. If set, weight decay is applied when using\n            the Adam optimizer.\n        epsilon: A small constant for numerical stability. This is\n            \"epsilon hat\" in the Kingma and Ba paper\n            (in the formula just before Section 2.1),\n            not the epsilon in Algorithm 1 of the paper.\n            It is used by the AdamW step. Defaults to `1e-7`.\n        exclude_layers: List of strings, keywords of layer names to exclude.\n            All layers with keywords in their path will use AdamW.\n        exclude_embeddings: Boolean. If `True`, embedding layers will use\n            AdamW.\n        muon_a: Float, parameter `a` of the Muon algorithm.\n            It is recommended to use the default value.\n        muon_b: Float, parameter `b` of the Muon algorithm.\n            It is recommended to use the default value.\n        muon_c: Float, parameter `c` of the Muon algorithm.\n            It is recommended to use the default value.\n        adam_lr_ratio: Float, the ratio of the learning rate when\n            using Adam to the main learning rate.\n            It is recommended to set it to 1.\n        momentum: Float, momentum used by internal SGD.\n        ns_steps: Integer, number of Newton-Schulz iterations to run.\n        nesterov: Boolean, whether to use Nesterov-style momentum.\n        {{base_optimizer_keyword_args}}\n        rms_rate: Float. A parameter from https://arxiv.org/abs/2502.16982\n            that can enhance the stability of Muon, allowing it to use the\n            same learning rate and weight decay as Adam. 
Defaults to `0.2`.\n            Set to `None` to disable this feature.\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.001,\n        adam_beta_1=0.9,\n        adam_beta_2=0.999,\n        adam_weight_decay=0.004,\n        epsilon=1e-7,\n        weight_decay=0.004,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"muon\",\n        exclude_layers=None,\n        exclude_embeddings=True,\n        muon_a=3.4445,\n        muon_b=-4.7750,\n        muon_c=2.0315,\n        adam_lr_ratio=1,\n        momentum=0.95,\n        ns_steps=5,\n        nesterov=True,\n        rms_rate=0.2,\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            name=name,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            **kwargs,\n        )\n        self.adam_beta_1 = adam_beta_1\n        self.adam_beta_2 = adam_beta_2\n        self.epsilon = epsilon\n        self.muon_a = muon_a\n        self.muon_b = muon_b\n        self.muon_c = muon_c\n        self.adam_lr_ratio = adam_lr_ratio\n        self.momentum = momentum\n        self.ns_steps = ns_steps\n        self.nesterov = nesterov\n        self.exclude_embeddings = exclude_embeddings\n        self.exclude_layers = exclude_layers or []\n        self.adam_weight_decay = adam_weight_decay\n        self.rms_rate = rms_rate\n\n    def _should_use_adamw(self, variable):\n        # it works well to just flatten 
their last 3 dimensions.\n        # Any {0,1}-D parameters should be optimized by Adam.\n        if len(variable.shape) != 2:\n            return True\n        if self.exclude_embeddings and \"embedding\" in variable.path.lower():\n            return True\n        for keyword in self.exclude_layers:\n            if re.search(keyword, variable.path):\n                return True\n        return False\n\n    def build(self, var_list):\n        \"\"\"Initialize optimizer variables.\n\n        Muon uses two types of variables: momentums (used by both the Muon\n        and AdamW update steps) and velocities (used only by the AdamW\n        step).\n\n        Args:\n            var_list: list of model variables to build optimizer variables\n                on.\n        \"\"\"\n        if self.built:\n            return\n        super().build(var_list)\n        # Momentums are for both Muon and Adam\n        self.momentums = [None] * len(var_list)\n        # Velocities are just for Adam\n        self.adam_velocities = [None] * len(var_list)\n\n        for var in var_list:\n            if not self._overwrite_variable_with_gradient(var):\n                self.momentums[self._get_variable_index(var)] = (\n                    self.add_variable_from_reference(\n                        reference_variable=var, name=\"momentum\"\n                    )\n                )\n                if self._should_use_adamw(var):\n                    self.adam_velocities[self._get_variable_index(var)] = (\n                        self.add_variable_from_reference(\n                            reference_variable=var, name=\"velocity\"\n                        )\n                    )\n\n    def update_step(self, gradient, variable, learning_rate):\n        variable_index = self._get_variable_index(variable)\n        m = self.momentums[variable_index]\n        v = self.adam_velocities[variable_index]\n\n        # The presence of the velocity tells us that this variable is for Adam\n        if v is not None:\n            # It should be noted 
that the learning rate is multiplied\n            # by `adam_lr_ratio` when using AdamW.\n            self._adamw_update_step(\n                gradient, variable, learning_rate * self.adam_lr_ratio, m, v\n            )\n        else:\n            self._muon_update_step(gradient, variable, learning_rate, m)\n\n    def _muon_update_step(self, gradient, variable, lr, m):\n        self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))\n        if self.nesterov:\n            g = ops.add(gradient, self.momentum * m)\n        else:\n            g = m\n        update = self.zeropower_via_newtonschulz5(g, self.ns_steps)\n\n        self.assign_sub(variable, self.lr_adjust(lr * update))\n\n    def _adamw_update_step(self, gradient, variable, learning_rate, m, v):\n        \"\"\"Update step given gradient and the associated model variable.\"\"\"\n        lr = ops.cast(learning_rate, variable.dtype)\n        gradient = ops.cast(gradient, variable.dtype)\n        local_step = ops.cast(self.iterations + 1, variable.dtype)\n        adam_beta_1_power = ops.power(\n            ops.cast(self.adam_beta_1, variable.dtype), local_step\n        )\n        adam_beta_2_power = ops.power(\n            ops.cast(self.adam_beta_2, variable.dtype), local_step\n        )\n\n        alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)\n\n        self.assign_add(\n            m, ops.multiply(ops.subtract(gradient, m), 1 - self.adam_beta_1)\n        )\n        self.assign_add(\n            v,\n            ops.multiply(\n                ops.subtract(ops.square(gradient), v), 1 - self.adam_beta_2\n            ),\n        )\n        self.assign_sub(\n            variable,\n            ops.divide(\n                ops.multiply(m, alpha), ops.add(ops.sqrt(v), self.epsilon)\n            ),\n        )\n\n    def transpose_last_axis(self, X):\n        shape = ops.shape(X)\n        temp_order = list(range(len(shape)))\n        temp_order[-2] = temp_order[-1]\n        temp_order[-1] = len(shape) - 2\n        X = ops.transpose(X, 
temp_order)\n        return X\n\n    def lr_adjust(self, x):\n        \"\"\"Adjust the update scale, following the Moonlight implementation.\n\n        This method enhances the stability of Muon, allowing it to use the same\n        learning rate and weight decay as Adam. For details, see\n        https://arxiv.org/abs/2502.16982.\n\n        For a 2D matrix, the update is scaled by `sqrt(max(n, m)) * rms_rate`,\n        where `n` and `m` are the dimensions of the matrix.\n        \"\"\"\n        if self.rms_rate is None:\n            return x\n        # Moonlight version:\n        # https://github.com/MoonshotAI/Moonlight/blob/master/examples/toy_train.py\n        return x * ops.sqrt(ops.maximum(x.shape[0], x.shape[1])) * self.rms_rate\n\n    def zeropower_via_newtonschulz5(self, x, steps: int):\n        \"\"\"Apply the Newton-Schulz iteration to orthogonalize the matrix G.\n\n        We select a quintic iteration that maximizes the slope at zero. This\n        approach helps minimize steps, even if the iteration doesn't fully\n        converge across the interval. The result isn't exactly UV^T (from the\n        SVD of G), but rather an approximation like US'V^T. Despite this\n        approximation, model performance remains unaffected compared to using\n        the exact UV^T from the SVD.\n        \"\"\"\n        shape = ops.shape(x)\n        if len(shape) < 2:\n            raise ValueError(\n                \"Expected gradient or momentum to have at least 2 dimensions. 
\"\n                f\"Received: shape={shape}\"\n            )\n\n        a, b, c = self.muon_a, self.muon_b, self.muon_c\n        if shape[-2] > shape[-1]:\n            x = self.transpose_last_axis(x)\n\n        # Ensure spectral norm is at most 1\n        x = x / (ops.norm(x, axis=(-2, -1), keepdims=True) + 1e-7)\n        # Perform the NS iterations\n        for _ in range(steps):\n            temp_a = x @ self.transpose_last_axis(x)\n            temp_b = b * temp_a + c * temp_a @ temp_a\n            x = a * x + temp_b @ x\n\n        if shape[-2] > shape[-1]:\n            x = self.transpose_last_axis(x)\n        return x\n\n    def _apply_weight_decay(self, variables):\n        for variable in variables:\n            if not self._use_weight_decay(variable):\n                continue\n            if self._should_use_adamw(variable):\n                weight_decay_value = self.adam_weight_decay\n            else:\n                weight_decay_value = self.weight_decay\n            if weight_decay_value is None:\n                continue\n            wd = ops.cast(weight_decay_value, variable.dtype)\n            lr = ops.cast(self.learning_rate, variable.dtype)\n            variable.assign(variable - variable * wd * lr)\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"adam_beta_1\": self.adam_beta_1,\n                \"adam_beta_2\": self.adam_beta_2,\n                \"epsilon\": self.epsilon,\n                \"exclude_layers\": self.exclude_layers,\n                \"muon_a\": self.muon_a,\n                \"muon_b\": self.muon_b,\n                \"muon_c\": self.muon_c,\n                \"adam_lr_ratio\": self.adam_lr_ratio,\n                \"momentum\": self.momentum,\n                \"ns_steps\": self.ns_steps,\n                \"nesterov\": self.nesterov,\n                \"exclude_embeddings\": self.exclude_embeddings,\n                \"adam_weight_decay\": 
self.adam_weight_decay,\n                \"rms_rate\": self.rms_rate,\n            }\n        )\n        return config\n"
  },
  {
    "path": "keras/src/optimizers/muon_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.layers import Dense\nfrom keras.src.layers import Embedding\nfrom keras.src.layers import Input\nfrom keras.src.models import Sequential\nfrom keras.src.optimizers.muon import Muon\n\n\nclass MuonTest(testing.TestCase):\n    def test_config(self):\n        optimizer = Muon(\n            learning_rate=0.5,\n            epsilon=1e-5,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_Newton_Schulz(self):\n        optimizer = Muon()\n        tensor_input = ops.array([[0.2499, 0.9105], [0.2655, 0.8824]])\n        except_output = ops.array([[-0.4422, 0.6457], [0.7285, 0.2968]])\n        output = optimizer.zeropower_via_newtonschulz5(tensor_input, 5)\n        self.assertAllClose(\n            output,\n            except_output,\n            rtol=1e-3,\n            atol=1e-3,\n            tpu_atol=1e-1,\n            tpu_rtol=1e-1,\n        )\n\n    def test_adamw_single_step(self):\n        optimizer = Muon()\n        grads = ops.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0], name=\"test_vars\")\n        optimizer.build([vars])\n        optimizer.update_step(grads, vars, 0.5)\n        self.assertAllClose(vars, [0.5, 1.5, 2.5, 3.5], rtol=1e-4, atol=1e-4)\n\n    def test_should_use_adamw(self):\n        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        optimizer = Muon(exclude_layers=[\"var\"])\n        self.assertTrue(optimizer._should_use_adamw(vars))\n        embedding = Embedding(2, 2)\n        embedding.build()\n        self.assertTrue(optimizer._should_use_adamw(embedding.weights[0]))\n        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        optimizer = Muon()\n        self.assertFalse(optimizer._should_use_adamw(vars))\n        dense = Dense(2)\n        dense.build([None, 2])\n        
self.assertFalse(optimizer._should_use_adamw(dense.weights[0]))\n\n    def test_muon_single_step(self):\n        optimizer = Muon(\n            learning_rate=0.5,\n            weight_decay=0,\n        )\n        grads = ops.array([[1.0, 6.0], [7.0, 2.0]])\n        vars = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        optimizer.build([vars])\n        optimizer.update_step(grads, vars, 0.5)\n        self.assertAllClose(\n            vars,\n            [[0.988775, 1.887053], [2.873428, 3.97035]],\n            rtol=1e-2,\n            atol=1e-2,\n        )\n\n    def test_clip_norm(self):\n        optimizer = Muon(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = Muon(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n\n    def test_muon_weight_decay(self):\n        variable = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        weight_decay = 0.01\n        expected_variable = variable - variable * weight_decay\n        optimizer = Muon(learning_rate=1.0, weight_decay=weight_decay)\n        optimizer._apply_weight_decay([variable])\n        self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4)\n\n    def test_adamw_weight_decay(self):\n        variable = backend.Variable(2.0)\n        weight_decay = 0.01\n        expected_variable = variable - variable * weight_decay\n        optimizer = Muon(learning_rate=1.0, adam_weight_decay=weight_decay)\n        optimizer._apply_weight_decay([variable])\n\n        self.assertAllClose(variable, expected_variable, rtol=1e-4, atol=1e-4)\n\n    def test_lr_adjust_none(self):\n        opt = Muon(rms_rate=None)\n        x = ops.ones((4, 4))\n        want = x\n        self.assertAllClose(opt.lr_adjust(x), want)\n\n 
   def test_lr_adjust_2d(self):\n        opt = Muon(rms_rate=0.2)\n        x = ops.ones((4, 2))\n        want = x * 0.2 * 2\n        self.assertAllClose(opt.lr_adjust(x), want)\n\n    @pytest.mark.requires_trainable_backend\n    def test_model_fit(self):\n        model = Sequential([Input((10,)), Dense(5), Dense(1, name=\"last\")])\n        x = ops.ones((1, 10))\n        y = ops.ones((1, 1))\n        optimizer = Muon(learning_rate=1e-3, exclude_layers=[\"last\"])\n        model.compile(optimizer=optimizer, loss=\"mse\")\n        model.fit(x, y)\n"
  },
  {
    "path": "keras/src/optimizers/nadam.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export([\"keras.optimizers.Nadam\"])\nclass Nadam(optimizer.Optimizer):\n    \"\"\"Optimizer that implements the Nadam algorithm.\n\n    Much like Adam is essentially RMSprop with momentum, Nadam is Adam with\n    Nesterov momentum.\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.001`.\n        beta_1: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use. The\n            exponential decay rate for the 1st moment estimates.\n            Defaults to `0.9`.\n        beta_2: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use. The\n            exponential decay rate for the 2nd moment estimates. Defaults to\n            `0.999`.\n        epsilon: A small constant for numerical stability. 
This epsilon is\n            \"epsilon hat\" in the Kingma and Ba paper (in the formula just before\n            Section 2.1), not the epsilon in Algorithm 1 of the paper.\n            Defaults to `1e-7`.\n        {{base_optimizer_keyword_args}}\n\n    Reference:\n\n    - [Dozat, 2015](http://cs229.stanford.edu/proj2015/054_report.pdf).\n\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.001,\n        beta_1=0.9,\n        beta_2=0.999,\n        epsilon=1e-7,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"nadam\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            name=name,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            **kwargs,\n        )\n        self.beta_1 = beta_1\n        self.beta_2 = beta_2\n        self.epsilon = epsilon\n\n    def build(self, var_list):\n        \"\"\"Initialize optimizer variables.\n\n        Nadam optimizer has 2 types of variables: momentums and velocities.\n\n        Args:\n            var_list: list of model variables to build Nadam variables on.\n        \"\"\"\n        if self.built:\n            return\n        if var_list:\n            dtype = var_list[0].dtype\n        else:\n            dtype = backend.floatx()\n        super().build(var_list)\n        self._momentums, self._velocities = self.add_optimizer_variables(\n            var_list, 
[\"momentum\", \"velocity\"]\n        )\n        self._u_product = backend.Variable(1.0, dtype=dtype)\n\n    def _backend_update_step(self, grads, trainable_variables, learning_rate):\n        dtype = self._u_product.dtype\n        self.assign(\n            self._u_product,\n            self._u_product\n            * self.beta_1\n            * (\n                1.0\n                - 0.5 * ops.power(0.96, ops.cast(self.iterations + 1, dtype))\n            ),\n        )\n        super()._backend_update_step(grads, trainable_variables, learning_rate)\n\n    def update_step(self, gradient, variable, learning_rate):\n        \"\"\"Update step given gradient and the associated model variable.\"\"\"\n        var_dtype = variable.dtype\n        lr = ops.cast(learning_rate, var_dtype)\n        gradient = ops.cast(gradient, var_dtype)\n\n        local_step = ops.cast(self.iterations + 1, var_dtype)\n        next_step = ops.cast(self.iterations + 2, var_dtype)\n        decay = ops.cast(0.96, var_dtype)\n        beta_1 = ops.cast(self.beta_1, var_dtype)\n        beta_2 = ops.cast(self.beta_2, var_dtype)\n        u_t = beta_1 * (1.0 - 0.5 * (ops.power(decay, local_step)))\n        u_t_1 = beta_1 * (1.0 - 0.5 * (ops.power(decay, next_step)))\n        u_product_t = ops.cast(self._u_product, var_dtype)\n\n        u_product_t_1 = u_product_t * u_t_1\n        beta_2_power = ops.power(beta_2, local_step)\n\n        m = self._momentums[self._get_variable_index(variable)]\n        v = self._velocities[self._get_variable_index(variable)]\n\n        self.assign_add(\n            m, ops.multiply(ops.subtract(gradient, m), (1 - beta_1))\n        )\n        self.assign_add(\n            v, ops.multiply(ops.subtract(ops.square(gradient), v), (1 - beta_2))\n        )\n        m_hat = ops.add(\n            ops.divide(ops.multiply(u_t_1, m), 1 - u_product_t_1),\n            ops.divide(ops.multiply(1 - u_t, gradient), 1 - u_product_t),\n        )\n        v_hat = ops.divide(v, (1 - 
beta_2_power))\n\n        self.assign_sub(\n            variable,\n            ops.divide(\n                ops.multiply(m_hat, lr), ops.add(ops.sqrt(v_hat), self.epsilon)\n            ),\n        )\n\n    def get_config(self):\n        config = super().get_config()\n\n        config.update(\n            {\n                \"beta_1\": self.beta_1,\n                \"beta_2\": self.beta_2,\n                \"epsilon\": self.epsilon,\n            }\n        )\n        return config\n\n\nNadam.__doc__ = Nadam.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
  {
    "path": "keras/src/optimizers/nadam_test.py",
    "content": "# flake8: noqa\n\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.optimizers.nadam import Nadam\n\n\nclass NadamTest(testing.TestCase):\n    def test_config(self):\n        optimizer = Nadam(\n            learning_rate=0.5,\n            beta_1=0.5,\n            beta_2=0.67,\n            epsilon=1e-5,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_build_with_empty_var_list(self):\n        optimizer = Nadam()\n        optimizer.build([])\n        self.assertEqual(optimizer._u_product.dtype, backend.floatx())\n\n    def test_single_step(self):\n        optimizer = Nadam(learning_rate=0.5)\n        grads = ops.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(\n            vars, [0.4686, 1.4686, 2.4686, 3.4686], rtol=1e-4, atol=1e-4\n        )\n\n    def test_weight_decay(self):\n        grads, var1, var2, var3 = (\n            ops.zeros(()),\n            backend.Variable(2.0),\n            backend.Variable(2.0, name=\"exclude\"),\n            backend.Variable(2.0),\n        )\n        optimizer_1 = Nadam(learning_rate=1.0, weight_decay=0.004)\n        optimizer_1.apply_gradients(zip([grads], [var1]))\n\n        optimizer_2 = Nadam(learning_rate=1.0, weight_decay=0.004)\n        optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n        optimizer_3 = Nadam(learning_rate=1.0, weight_decay=0.004)\n        optimizer_3.exclude_from_weight_decay(var_list=[var3])\n        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))\n\n        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)\n        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)\n        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)\n\n    def 
test_correctness_with_golden(self):\n        optimizer = Nadam(\n            learning_rate=0.5,\n            beta_1=0.5,\n            beta_2=0.67,\n            epsilon=1e-5,\n        )\n\n        x = backend.Variable(np.ones([10], dtype=\"float32\"))\n        grads = ops.arange(0.1, 1.1, 0.1)\n        first_grads = ops.full((10,), 0.01)\n\n        # fmt: off\n        golden = np.array(\n            [[0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281, 0.4281],\n            [-0.1738, -0.1731, -0.1726, -0.1723, -0.1721, -0.172, -0.1719, -0.1718, -0.1718, -0.1717],\n            [-0.7115, -0.7103, -0.7096, -0.7092, -0.709, -0.7088, -0.7086, -0.7085, -0.7085, -0.7084],\n            [-1.2335, -1.2322, -1.2313, -1.2309, -1.2306, -1.2304, -1.2302, -1.2301, -1.23, -1.2299],\n            [-1.7492, -1.7478, -1.7469, -1.7464, -1.7461, -1.7459, -1.7457, -1.7456, -1.7455, -1.7454]]\n        )\n        # fmt: on\n\n        optimizer.apply_gradients(zip([first_grads], [x]))\n        for i in range(5):\n            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            optimizer.apply_gradients(zip([grads], [x]))\n\n    def test_clip_norm(self):\n        optimizer = Nadam(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = Nadam(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n"
  },
  {
    "path": "keras/src/optimizers/optimizer.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import base_optimizer\n\nif backend.backend() == \"tensorflow\":\n    from keras.src.backend.tensorflow.optimizer import (\n        TFOptimizer as BackendOptimizer,\n    )\nelif backend.backend() == \"torch\":\n    from keras.src.backend.torch.optimizers import (\n        TorchOptimizer as BackendOptimizer,\n    )\nelif backend.backend() == \"jax\":\n    from keras.src.backend.jax.optimizer import JaxOptimizer as BackendOptimizer\nelse:\n\n    class BackendOptimizer(base_optimizer.BaseOptimizer):\n        pass\n\n\n@keras_export([\"keras.Optimizer\", \"keras.optimizers.Optimizer\"])\nclass Optimizer(BackendOptimizer, base_optimizer.BaseOptimizer):\n    pass\n\n\nOptimizer.__doc__ = base_optimizer.BaseOptimizer.__doc__\nbase_optimizer_keyword_args = base_optimizer.base_optimizer_keyword_args\n"
  },
  {
    "path": "keras/src/optimizers/optimizer_sparse_test.py",
    "content": "from unittest import mock\n\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src import testing\n\n\nclass ScatterUpdateOptimizer(optimizers.Optimizer):\n    def __init__(self):\n        super().__init__(learning_rate=0.001)\n\n    def build(self, variables):\n        if self.built:\n            return\n        super().build(variables)\n        self.momentums = [\n            self.add_variable_from_reference(v, name=\"momentum\")\n            for v in variables\n        ]\n\n    def update_step(self, grad, variable, learning_rate):\n        momentum = self.momentums[self._get_variable_index(variable)]\n        self.assign(momentum, ops.cast(grad, momentum.dtype))\n        self.assign(variable, ops.cast(grad, variable.dtype))\n\n\nTEST_CASES = [\n    {\n        \"testcase_name\": \"adadelta\",\n        \"optimizer_class\": optimizers.Adadelta,\n        \"expect_model_sparse_variable_updates\": True,\n    },\n    {\n        \"testcase_name\": \"adafactor\",\n        \"optimizer_class\": optimizers.Adafactor,\n        \"init_kwargs\": {\"clip_threshold\": 0.5},\n        \"expect_model_sparse_variable_updates\": True,\n    },\n    {\n        \"testcase_name\": \"adagrad\",\n        \"optimizer_class\": optimizers.Adagrad,\n        \"expect_model_sparse_variable_updates\": True,\n        \"expect_optimizer_sparse_variable_updates\": True,\n    },\n    {\n        \"testcase_name\": \"adam\",\n        \"optimizer_class\": optimizers.Adam,\n    },\n    {\n        \"testcase_name\": \"adam_amsgrad\",\n        \"optimizer_class\": optimizers.Adam,\n        \"init_kwargs\": {\"amsgrad\": True},\n    },\n    {\n        \"testcase_name\": \"adamax\",\n        \"optimizer_class\": optimizers.Adamax,\n    },\n    {\n        \"testcase_name\": \"adamw\",\n        \"optimizer_class\": optimizers.AdamW,\n    },\n    {\n        \"testcase_name\": 
\"adamw_amsgrad\",\n        \"optimizer_class\": optimizers.AdamW,\n        \"init_kwargs\": {\"amsgrad\": True},\n    },\n    {\n        \"testcase_name\": \"ftrl\",\n        \"optimizer_class\": optimizers.Ftrl,\n    },\n    {\n        \"testcase_name\": \"lion\",\n        \"optimizer_class\": optimizers.Lion,\n    },\n    {\n        \"testcase_name\": \"loss_scale_optimizer_sgd\",\n        \"optimizer_class\": lambda: optimizers.LossScaleOptimizer(\n            optimizers.SGD(learning_rate=0.5)\n        ),\n        \"expect_model_sparse_variable_updates\": True,\n    },\n    {\n        \"testcase_name\": \"nadam\",\n        \"optimizer_class\": optimizers.Nadam,\n    },\n    {\n        \"testcase_name\": \"rmsprop\",\n        \"optimizer_class\": optimizers.RMSprop,\n        \"expect_model_sparse_variable_updates\": True,\n    },\n    {\n        \"testcase_name\": \"rmsprop_momentum\",\n        \"optimizer_class\": optimizers.RMSprop,\n        \"init_kwargs\": {\"momentum\": 0.05},\n    },\n    {\n        \"testcase_name\": \"rmsprop_momentum_centered\",\n        \"optimizer_class\": optimizers.RMSprop,\n        \"init_kwargs\": {\"momentum\": 0.05, \"centered\": True},\n    },\n    {\n        \"testcase_name\": \"sgd\",\n        \"optimizer_class\": optimizers.SGD,\n        \"expect_model_sparse_variable_updates\": True,\n    },\n    {\n        \"testcase_name\": \"sgd_momentum\",\n        \"optimizer_class\": optimizers.SGD,\n        \"init_kwargs\": {\"momentum\": 0.05},\n    },\n    {\n        \"testcase_name\": \"sgd_momentum_nesterov\",\n        \"optimizer_class\": optimizers.SGD,\n        \"init_kwargs\": {\"momentum\": 0.05, \"nesterov\": True},\n    },\n    {\n        \"testcase_name\": \"scatter_update\",\n        \"optimizer_class\": ScatterUpdateOptimizer,\n        \"expect_model_sparse_variable_updates\": True,\n        \"expect_optimizer_sparse_variable_updates\": True,\n    },\n]\n\n\n@pytest.mark.skipif(\n    not 
backend.SUPPORTS_SPARSE_TENSORS,\n    reason=\"Backend does not support sparse tensors.\",\n)\nclass OptimizerSparseTest(testing.TestCase):\n    @parameterized.named_parameters(TEST_CASES)\n    def test_sparse_gradients(\n        self,\n        optimizer_class,\n        init_kwargs={},\n        expect_model_sparse_variable_updates=False,\n        expect_optimizer_sparse_variable_updates=False,\n    ):\n        # This test verifies that:\n        # - Optimizers use Keras ops everywhere instead of native operators\n        #   (e.g. `ops.add()` instead of `+`) where sparse gradients are handled\n        # - The used ops handle sparse gradients\n        # - Optimizers use `self.assign/assign_add/assign_sub` instead of\n        #   calling the method on the variable directly. Otherwise, the sparse\n        #   updates are densified before being applied.\n        # - For some optimizers, a sparse gradient actually results in a sparse\n        #   variable update as per `expect_model_sparse_variable_updates` and\n        #   `expect_optimizer_sparse_variable_updates`\n\n        model_variable = backend.Variable(initializer=\"ones\", shape=(5, 10))\n        optimizer = optimizer_class(**init_kwargs)\n\n        # Mocking \"tensorflow.Variable\" won't work as it gets substituted with\n        # the resource variable class.\n\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            grad = tf.IndexedSlices(0.5 * ops.ones((3, 10)), (0, 2, 4), (5, 10))\n            sparse_class = tf.IndexedSlices\n            variable_class = model_variable._value.__class__\n        elif backend.backend() == \"jax\":\n            import jax.experimental.sparse as jax_sparse\n\n            grad = jax_sparse.BCOO(\n                (0.5 * ops.ones((3, 10)), ((0,), (2,), (4,))), shape=(5, 10)\n            )\n            sparse_class = jax_sparse.JAXSparse\n            variable_class = model_variable.__class__\n        else:\n            
self.fail(f\"Sparse is unsupported with backend {backend.backend()}\")\n\n        optimizer_to_patch = (\n            optimizer.inner_optimizer\n            if isinstance(optimizer, optimizers.LossScaleOptimizer)\n            else optimizer\n        )\n\n        model_sparse_variable_updates = False\n        optimizer_sparse_variable_updates = False\n\n        def mock_optimizer_assign(variable, value):\n            nonlocal model_sparse_variable_updates\n            nonlocal optimizer_sparse_variable_updates\n            if isinstance(variable, backend.Variable):\n                variable = variable._value\n            if isinstance(value, sparse_class):\n                if variable is model_variable._value:\n                    model_sparse_variable_updates = True\n                elif any(variable is v._value for v in optimizer.variables):\n                    optimizer_sparse_variable_updates = True\n\n        def mock_variable_assign(variable, value):\n            # Make an exception for scalar variables\n            if len(variable.shape):\n                pytest.fail(\n                    \"Optimizer is calling `assign`, `assign_add` or \"\n                    \"`assign_sub` directly on a variable. 
Use \"\n                    \"`self.assign/assign_add/assign_sub(variable, value)` \"\n                    \"instead to support sparse updates.\"\n                )\n\n        # patch \"_apply_weight_decay\" to exclude this special case.\n        # patch the optimizer \"assign\" methods to detect sparse updates.\n        # patch the tf.Variable \"assign\" methods to detect direct assign calls.\n        with (\n            mock.patch.object(\n                optimizer_to_patch, \"_apply_weight_decay\", autospec=True\n            ),\n            mock.patch.object(\n                optimizer_to_patch, \"assign\", autospec=True\n            ) as optimizer_assign,\n            mock.patch.object(\n                optimizer_to_patch, \"assign_add\", autospec=True\n            ) as optimizer_assign_add,\n            mock.patch.object(\n                optimizer_to_patch, \"assign_sub\", autospec=True\n            ) as optimizer_assign_sub,\n            mock.patch.object(\n                variable_class, \"assign\", autospec=True\n            ) as variable_assign,\n            mock.patch.object(\n                variable_class, \"assign_add\", autospec=True\n            ) as variable_assign_add,\n            mock.patch.object(\n                variable_class, \"assign_sub\", autospec=True\n            ) as variable_assign_sub,\n        ):\n            optimizer_assign.side_effect = mock_optimizer_assign\n            optimizer_assign_add.side_effect = mock_optimizer_assign\n            optimizer_assign_sub.side_effect = mock_optimizer_assign\n            variable_assign.side_effect = mock_variable_assign\n            variable_assign_add.side_effect = mock_variable_assign\n            variable_assign_sub.side_effect = mock_variable_assign\n\n            optimizer.apply([grad], [model_variable])\n\n        self.assertEqual(\n            model_sparse_variable_updates, expect_model_sparse_variable_updates\n        )\n        self.assertEqual(\n            
optimizer_sparse_variable_updates,\n            expect_optimizer_sparse_variable_updates,\n        )\n\n    @parameterized.named_parameters(TEST_CASES)\n    def test_sparse_correctness(\n        self, optimizer_class, init_kwargs={}, **kwargs\n    ):\n        # This test verifies that applying a sparse gradient gives the same\n        # numerical results as the same dense gradient.\n\n        optimizer_sparse = optimizer_class(**init_kwargs)\n        optimizer_dense = optimizer_class(**init_kwargs)\n        var_sparse = backend.Variable(initializer=\"ones\", shape=(5, 3, 2))\n        var_dense = backend.Variable(initializer=\"ones\", shape=(5, 3, 2))\n        stateless = backend.backend() == \"jax\"\n        if stateless:\n            optimizer_sparse.build([var_sparse])\n            optimizer_dense.build([var_dense])\n\n        optimizer_sparse_vars = optimizer_sparse.variables\n        optimizer_dense_vars = optimizer_dense.variables\n        var_sparse_values = [var_sparse.value]\n        var_dense_values = [var_dense.value]\n\n        for i in range(5):\n            if backend.backend() == \"tensorflow\":\n                import tensorflow as tf\n\n                grad_sparse = tf.IndexedSlices(\n                    values=ops.ones((3, 3, 2)) * (10.0 - i),\n                    indices=(0, 2, 4),\n                    dense_shape=(5, 3, 2),\n                )\n            elif backend.backend() == \"jax\":\n                import jax.experimental.sparse as jax_sparse\n\n                grad_sparse = jax_sparse.BCOO(\n                    (ops.ones((3, 3, 2)) * (10.0 - i), ((0,), (2,), (4,))),\n                    shape=(5, 3, 2),\n                )\n            else:\n                self.fail(\n                    f\"Sparse is unsupported with backend {backend.backend()}\"\n                )\n\n            grad_dense = ops.convert_to_tensor(grad_sparse, sparse=False)\n            if stateless:\n                (\n                    var_sparse_values,\n           
         optimizer_sparse_vars,\n                ) = optimizer_sparse.stateless_apply(\n                    optimizer_sparse_vars, [grad_sparse], var_sparse_values\n                )\n                (\n                    var_dense_values,\n                    optimizer_dense_vars,\n                ) = optimizer_dense.stateless_apply(\n                    optimizer_dense_vars, [grad_dense], var_dense_values\n                )\n                self.assertAllClose(var_sparse_values[0], var_dense_values[0])\n\n            else:\n                optimizer_sparse.apply([grad_sparse], [var_sparse])\n                optimizer_dense.apply([grad_dense], [var_dense])\n                self.assertAllClose(var_sparse.value, var_dense.value)\n"
  },
  {
    "path": "keras/src/optimizers/optimizer_test.py",
    "content": "import os\nimport pickle\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import constraints\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import optimizers\nfrom keras.src import testing\n\n\nclass OptimizerTest(testing.TestCase):\n    def test_iterations_counter(self):\n        v = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])\n        optimizer = optimizers.Adam(learning_rate=1.0)\n        self.assertAllClose(optimizer.iterations, 0)\n        optimizer.apply_gradients([(grads, v)])\n        self.assertAllClose(optimizer.iterations, 1)\n        optimizer.apply_gradients([(grads, v)])\n        self.assertAllClose(optimizer.iterations, 2)\n\n    def test_empty_gradients(self):\n        # Test no valid gradient\n        v = backend.Variable([[3.0, 4.0], [5.0, 6.0]])\n        grads = None\n        optimizer = optimizers.SGD(learning_rate=1.0)\n        with self.assertRaisesRegex(\n            ValueError, \"No gradients provided for any variable.\"\n        ):\n            optimizer.apply_gradients([(grads, v)])\n\n        # Test filtering of empty gradients\n        v2 = backend.Variable([[3.0, 4.0], [5.0, 6.0]])\n        grads2 = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])\n        optimizer = optimizers.SGD(learning_rate=1.0)\n        with self.assertWarns(Warning):\n            optimizer.apply_gradients([(grads, v), (grads2, v2)])\n        self.assertAllClose(v, [[3.0, 4.0], [5.0, 6.0]])\n        self.assertAllClose(v2, [[2.0, 3.0], [4.0, 5.0]])\n\n    def test_clip_args(self):\n        optimizer = optimizers.SGD(learning_rate=1.0, clipnorm=0.1)\n        self.assertEqual(optimizer.clipnorm, 0.1)\n        optimizer = optimizers.SGD(learning_rate=1.0, clipvalue=0.1)\n        self.assertEqual(optimizer.clipvalue, 0.1)\n        optimizer = optimizers.SGD(learning_rate=1.0, 
global_clipnorm=0.1)\n        self.assertEqual(optimizer.global_clipnorm, 0.1)\n\n        # Test invalid arguments\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Only one of `clipnorm`, `clipvalue` and `global_clipnorm` can\",\n        ):\n            optimizers.SGD(\n                learning_rate=1.0,\n                clipnorm=0.1,\n                clipvalue=0.1,\n            )\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Only one of `clipnorm`, `clipvalue` and `global_clipnorm` can\",\n        ):\n            optimizers.SGD(\n                learning_rate=1.0,\n                clipnorm=0.1,\n                global_clipnorm=0.1,\n            )\n\n    def test_clip_norm(self):\n        optimizer = optimizers.SGD(clipnorm=1)\n        grad = backend.convert_to_tensor([100.0, 100.0])\n        clipped_grad = optimizer._clip_gradients([grad])\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = optimizers.SGD(clipvalue=1)\n        grad = backend.convert_to_tensor([100.0, 100.0])\n        clipped_grad = optimizer._clip_gradients([grad])\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n\n    def test_global_clip_norm(self):\n        optimizer = optimizers.SGD(global_clipnorm=1)\n        grad = np.array([50.0, 100.0], dtype=\"float32\")\n        global_norm = np.linalg.norm(grad)\n        clipped_grad = optimizer._clip_gradients(\n            [backend.convert_to_tensor(grad)]\n        )\n        self.assertAllClose(clipped_grad[0], grad / global_norm)\n\n    def test_ema(self):\n        v = backend.Variable([[3.0, 4.0], [5.0, 6.0]])\n        grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])\n        optimizer = optimizers.SGD(\n            learning_rate=1.0,\n            use_ema=True,\n            ema_momentum=0.9,\n            ema_overwrite_frequency=3,\n        )\n        optimizer.apply_gradients([(grads, v)])\n    
    self.assertAllClose(v, [[2.0, 3.0], [4.0, 5.0]])\n        self.assertAllClose(\n            optimizer._model_variables_moving_average[0],\n            [[2.0, 3.0], [4.0, 5.0]],  # initialized after first step\n        )\n        optimizer.apply_gradients([(grads, v)])\n        self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]])\n        self.assertAllClose(\n            optimizer._model_variables_moving_average[0],\n            [[1.9, 2.9], [3.9, 4.9]],\n        )\n        optimizer.apply_gradients([(grads, v)])\n        # Variables were overwritten with EMA\n        self.assertAllClose(v, [[1.71, 2.71], [3.71, 4.71]])\n        self.assertAllClose(\n            optimizer._model_variables_moving_average[0],\n            [[1.71, 2.71], [3.71, 4.71]],\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_ema_with_model_fit(self):\n        x_train = np.ones((1, 1)).astype(\"float32\")\n        y_train = np.zeros((1, 1)).astype(\"float32\")\n        optimizer = optimizers.SGD(\n            learning_rate=0.1, use_ema=True, ema_momentum=0.9\n        )\n        model = models.Sequential(\n            [layers.Dense(2, kernel_initializer=\"ones\", use_bias=False)]\n        )\n        model.compile(loss=\"mse\", optimizer=optimizer, run_eagerly=True)\n        model.fit(x_train, y_train, batch_size=1, epochs=2)\n        self.assertAllClose(\n            optimizer._model_variables_moving_average[0].numpy(),\n            [[0.891, 0.891]],\n            atol=1e-5,\n        )\n        self.assertAllClose(\n            model.trainable_variables[0].numpy(),\n            [[0.891, 0.891]],\n            atol=1e-5,\n        )\n\n    def test_constraints_are_applied(self):\n        v = backend.Variable(np.random.random((2, 2)) - 1.0)\n        v.constraint = constraints.NonNeg()\n        optimizer = optimizers.SGD(learning_rate=0.0001)\n        grad = backend.numpy.zeros((2, 2))\n        optimizer.apply_gradients([(grad, v)])\n        self.assertAlmostEqual(np.min(v), 
0.0)\n\n    def test_get_method(self):\n        obj = optimizers.get(\"sgd\")\n        self.assertIsInstance(obj, optimizers.SGD)\n        obj = optimizers.get(\"adamw\")\n        self.assertIsInstance(obj, optimizers.AdamW)\n\n        obj = optimizers.get(None)\n        self.assertEqual(obj, None)\n\n        with self.assertRaises(ValueError):\n            optimizers.get(\"typo\")\n\n    def test_static_loss_scaling(self):\n        v = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        grads = backend.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]]) * 1024.0\n        optimizer = optimizers.SGD(learning_rate=1.0, loss_scale_factor=1024.0)\n        optimizer.apply_gradients([(grads, v)])\n        self.assertEqual(optimizer.scale_loss(1.0), 1024.0)\n        self.assertAllClose(v, [[0.0, 0.0], [0.0, 0.0]])\n\n    def test_set_weights(self):\n        x = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        optimizer_1 = optimizers.Adam()\n        grads = backend.convert_to_tensor([[1.0, 2.0], [3.0, 4.0]])\n        optimizer_1.apply_gradients(zip([grads], [x]))\n        optimizer_2 = optimizers.Adam()\n        with self.assertRaisesRegex(ValueError, \"You are calling*\"):\n            optimizer_2.set_weights(optimizer_1.variables)\n        optimizer_2.build([x])\n        optimizer_2.set_weights(optimizer_1.variables)\n        for i in range(len(optimizer_1.variables)):\n            self.assertAllClose(\n                optimizer_1.variables[i],\n                optimizer_2.variables[i],\n            )\n\n    def test_gradient_accumulation(self):\n        v = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])\n        optimizer = optimizers.SGD(\n            learning_rate=1.0, gradient_accumulation_steps=3\n        )\n        self.assertEqual(optimizer.gradient_accumulation_steps, 3)\n\n        # Iteration 1\n        optimizer.apply_gradients([(grads, v)])\n        self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]])\n  
      self.assertAllClose(\n            optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]]\n        )\n        self.assertAllClose(optimizer._iterations, 1)\n        self.assertAllClose(optimizer.iterations, 0)\n\n        # Iteration 2\n        optimizer.apply_gradients([(grads, v)])\n        self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]])\n        self.assertAllClose(\n            optimizer._accumulated_gradients[0], [[2.0, 2.0], [2.0, 2.0]]\n        )\n        self.assertAllClose(optimizer._iterations, 2)\n        self.assertAllClose(optimizer.iterations, 0)\n\n        # Iteration 3\n        optimizer.apply_gradients([(grads, v)])\n        self.assertAllClose(v, [[0.0, 1.0], [2.0, 3.0]])\n        self.assertAllClose(\n            optimizer._accumulated_gradients[0], [[0.0, 0.0], [0.0, 0.0]]\n        )\n        self.assertAllClose(optimizer._iterations, 3)\n        self.assertAllClose(optimizer.iterations, 1)\n\n        # Iteration 4\n        optimizer.apply_gradients([(grads, v)])\n        self.assertAllClose(v, [[0.0, 1.0], [2.0, 3.0]])\n        self.assertAllClose(\n            optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]]\n        )\n        self.assertAllClose(optimizer._iterations, 4)\n        self.assertAllClose(optimizer.iterations, 1)\n\n    @pytest.mark.skipif(backend.backend() != \"tensorflow\", reason=\"Requires TF\")\n    def test_tf_checkpointing(self):\n        import tensorflow as tf\n\n        model = models.Sequential([layers.Dense(2)])\n        optimizer = optimizers.Adam()\n        x, y = np.random.random((1, 2)), np.random.random((1, 2))\n        model.compile(optimizer, \"mse\")\n        model.train_on_batch(x, y)\n        ref_pred = model.predict(x)\n\n        # Both model and optimizer are Trackables\n        checkpoint = tf.train.Checkpoint(model, optimizer=optimizer)\n        temp_filepath = os.path.join(self.get_temp_dir(), \"tf_ckpt\")\n        save_path = checkpoint.save(temp_filepath)\n\n        # Keep 
training the model (predictions now differ)\n        model.train_on_batch(x, y)\n        pred = model.predict(x)\n        self.assertNotAllClose(pred, ref_pred, atol=1e-3)\n\n        # Restore the model and check prediction correctness\n        checkpoint.restore(save_path)\n        pred = model.predict(x)\n        self.assertAllClose(pred, ref_pred, atol=1e-5)\n\n    def test_callable_learning_rate(self):\n        v = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])\n        optimizer = optimizers.SGD(learning_rate=lambda: 0.1)\n        self.assertAllClose(optimizer.iterations, 0)\n        optimizer.apply_gradients([(grads, v)])\n        self.assertAllClose(v, [[0.9, 1.9], [2.9, 3.9]])\n        self.assertAllClose(optimizer.iterations, 1)\n\n    def test_overwrite_with_gradient(self):\n        v = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        v.overwrite_with_gradient = True\n        v2 = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])\n        grads2 = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])\n\n        optimizer = optimizers.SGD(learning_rate=1.0)\n        optimizer.apply_gradients([(grads, v), (grads2, v2)])\n\n        # `v` is overwritten by its gradient but `v2` is updated normally\n        self.assertAllClose(v, [[1.0, 1.0], [1.0, 1.0]])\n        self.assertAllClose(v2, [[0.0, 1.0], [2.0, 3.0]])\n\n    def test_overwrite_with_gradient_with_gradient_accumulation(self):\n        v = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        v.overwrite_with_gradient = True\n        v2 = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        grad_ones = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])\n        grad_twos = backend.convert_to_tensor([[2.0, 2.0], [2.0, 2.0]])\n        optimizer = optimizers.SGD(\n            learning_rate=1.0, gradient_accumulation_steps=2\n        )\n\n        # Iteration 1\n        
optimizer.apply_gradients([(grad_ones, v), (grad_ones, v2)])\n        self.assertAllClose(optimizer._iterations, 1)\n        self.assertAllClose(optimizer.iterations, 0)\n        self.assertAllClose(v, [[1.0, 2.0], [3.0, 4.0]])\n        self.assertAllClose(v2, [[1.0, 2.0], [3.0, 4.0]])\n        self.assertAllClose(\n            optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]]\n        )\n        self.assertAllClose(\n            optimizer._accumulated_gradients[1], [[1.0, 1.0], [1.0, 1.0]]\n        )\n\n        # Iteration 2\n        optimizer.apply_gradients([(grad_twos, v), (grad_twos, v2)])\n        self.assertAllClose(optimizer._iterations, 2)\n        self.assertAllClose(optimizer.iterations, 1)\n        self.assertAllClose(v, [[2.0, 2.0], [2.0, 2.0]])\n        self.assertAllClose(v2, [[-0.5, 0.5], [1.5, 2.5]])\n        self.assertAllClose(\n            optimizer._accumulated_gradients[0], [[0.0, 0.0], [0.0, 0.0]]\n        )\n        self.assertAllClose(\n            optimizer._accumulated_gradients[1], [[0.0, 0.0], [0.0, 0.0]]\n        )\n\n        # Iteration 3\n        optimizer.apply_gradients([(grad_ones, v), (grad_ones, v2)])\n        self.assertAllClose(optimizer._iterations, 3)\n        self.assertAllClose(optimizer.iterations, 1)\n        self.assertAllClose(v, [[2.0, 2.0], [2.0, 2.0]])\n        self.assertAllClose(v2, [[-0.5, 0.5], [1.5, 2.5]])\n        self.assertAllClose(\n            optimizer._accumulated_gradients[0], [[1.0, 1.0], [1.0, 1.0]]\n        )\n        self.assertAllClose(\n            optimizer._accumulated_gradients[1], [[1.0, 1.0], [1.0, 1.0]]\n        )\n\n    @parameterized.parameters(\n        [\n            (\"adam\",),\n            (\"sgd\",),\n            (\"adamw\",),\n            (\"adagrad\",),\n            (\"rmsprop\",),\n            (\"adadelta\",),\n            (\"adamax\",),\n            (\"lion\",),\n            (\"nadam\",),\n            (\"ftrl\",),\n            (\"adafactor\",),\n        ]\n    )\n    
def test_gradient_accumulation_with_weight_decay(self, optimizer):\n        optimizer1 = optimizers.get(\n            {\"class_name\": optimizer, \"config\": {\"weight_decay\": 0.05}}\n        )\n        optimizer3 = optimizers.get(\n            {\n                \"class_name\": optimizer,\n                \"config\": {\n                    \"weight_decay\": 0.05,\n                    \"gradient_accumulation_steps\": 3,\n                },\n            }\n        )\n        variable1 = backend.Variable([[0.9], [0.5]])\n        variable3 = backend.Variable([[0.9], [0.5]])\n\n        for epoch in range(8):\n            grads3 = np.random.random([3, 2, 1]).astype(\"float32\")\n\n            grads1 = backend.convert_to_tensor(grads3.mean(axis=0))\n            optimizer1.apply_gradients([(grads1, variable1)])\n\n            for batch in range(3):\n                grads3_ = backend.convert_to_tensor(grads3[batch])\n                optimizer3.apply_gradients([(grads3_, variable3)])\n\n        self.assertAllClose(variable1, variable3)\n\n    def test_setting_lr_to_callable_untracks_lr_var(self):\n        adam = optimizers.Adam(learning_rate=0.001)\n        self.assertLen(adam.variables, 2)\n        adam.learning_rate = optimizers.schedules.PolynomialDecay(\n            adam.learning_rate, 4\n        )\n        self.assertLen(adam.variables, 1)\n\n    @parameterized.parameters(\n        [\n            (\"adam\",),\n            (\"sgd\",),\n            (\"adamw\",),\n            (\"adagrad\",),\n            (\"rmsprop\",),\n            (\"adadelta\",),\n            (\"adamax\",),\n            (\"lion\",),\n            (\"nadam\",),\n            (\"ftrl\",),\n            (\"adafactor\",),\n        ]\n    )\n    def test_pickleable_optimizers(self, optimizer):\n        optimizer = optimizers.get(optimizer)\n        reloaded = pickle.loads(pickle.dumps(optimizer))\n\n        self.assertEqual(optimizer.get_config(), reloaded.get_config())\n\n    @pytest.mark.skipif(\n        
backend.backend() != \"tensorflow\",\n        reason=\"The tf.Variable test can only run with TensorFlow backend.\",\n    )\n    def test_mixed_with_tf_variables(self):\n        import tensorflow as tf\n\n        v = backend.Variable([[1.0, 2.0], [3.0, 4.0]])\n        grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])\n        tf_v = tf.Variable([[1.0, 2.0], [3.0, 4.0]])\n        tf_grads = backend.convert_to_tensor([[1.0, 1.0], [1.0, 1.0]])\n        optimizer = optimizers.Adam(learning_rate=1.0)\n        optimizer.apply_gradients([(grads, v), (tf_grads, tf_v)])\n        self.assertAllClose(optimizer.iterations, 1)\n\n        # Test with no grads\n        with self.assertWarnsRegex(\n            UserWarning, \"Gradients do not exist for variables\"\n        ):\n            optimizer.apply_gradients([(grads, v), (None, tf_v)])\n            self.assertAllClose(optimizer.iterations, 2)\n"
  },
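The accumulation arithmetic exercised by `test_gradient_accumulation` above can be condensed into plain NumPy. This is a standalone sketch of the bookkeeping, not the Keras internals: gradients are summed into an accumulator, and the variable is only updated every `gradient_accumulation_steps` calls, using the mean of the accumulated gradients.

```python
import numpy as np

def sgd_with_accumulation(v, grads_seq, lr=1.0, accum_steps=3):
    # Sum incoming gradients; apply the mean only every `accum_steps`
    # calls, then reset the accumulator (mirrors the values asserted
    # in test_gradient_accumulation).
    acc = np.zeros_like(v)
    for step, g in enumerate(grads_seq, start=1):
        acc += g
        if step % accum_steps == 0:
            v = v - lr * (acc / accum_steps)
            acc = np.zeros_like(v)
    return v

v0 = np.array([[1.0, 2.0], [3.0, 4.0]])
ones = np.ones((2, 2))
# Three unit-gradient calls collapse into a single SGD step of size 1,
# matching the variable value asserted after iteration 3 in the test.
v3 = sgd_with_accumulation(v0, [ones, ones, ones])
```

This also explains the two counters in the test: `_iterations` counts every `apply_gradients` call, while `iterations` only advances when an actual update is applied.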
  {
    "path": "keras/src/optimizers/rmsprop.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export([\"keras.optimizers.RMSprop\"])\nclass RMSprop(optimizer.Optimizer):\n    \"\"\"Optimizer that implements the RMSprop algorithm.\n\n    The gist of RMSprop is to:\n\n    - Maintain a moving (discounted) average of the square of gradients\n    - Divide the gradient by the root of this average\n\n    This implementation of RMSprop uses plain momentum, not Nesterov momentum.\n\n    The centered version additionally maintains a moving average of the\n    gradients, and uses that average to estimate the variance.\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.001`.\n        rho: float, defaults to 0.9. Discounting factor for the old gradients.\n        momentum: float, defaults to 0.0. If not 0.0, the optimizer tracks the\n            momentum value, with a decay rate equal to `1 - momentum`.\n        epsilon: A small constant for numerical stability, added to the\n            denominator inside the square root. Defaults to `1e-7`.\n        centered: Boolean. If `True`, gradients are normalized by the estimated\n            variance of the gradient; if `False`, by the uncentered second\n            moment. Setting this to `True` may help with training, but is\n            slightly more expensive in terms of computation and memory. 
Defaults to `False`.\n        {{base_optimizer_keyword_args}}\n\n    Example:\n\n    >>> opt = keras.optimizers.RMSprop(learning_rate=0.1)\n    >>> var1 = keras.backend.Variable(10.0)\n    >>> loss = lambda: (var1 ** 2) / 2.0  # d(loss) / d(var1) = var1\n    >>> opt.minimize(loss, [var1])\n    >>> var1\n    9.683772\n\n    Reference:\n\n    - [Hinton, 2012](\n        http://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf)\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.001,\n        rho=0.9,\n        momentum=0.0,\n        epsilon=1e-7,\n        centered=False,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"rmsprop\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            name=name,\n            **kwargs,\n        )\n        self.rho = rho\n        self.momentum = momentum\n        self.epsilon = epsilon\n        self.centered = centered\n\n    def build(self, var_list):\n        if self.built:\n            return\n\n        super().build(var_list)\n\n        self._velocities = self.add_optimizer_variables(var_list, \"velocity\")\n\n        self._momentums = []\n        if self.momentum > 0:\n            self._momentums = self.add_optimizer_variables(var_list, \"momentum\")\n\n        self._average_gradients = []\n        if 
self.centered:\n            self._average_gradients = self.add_optimizer_variables(\n                var_list, \"average_gradient\"\n            )\n\n    def update_step(self, gradient, variable, learning_rate):\n        \"\"\"Update step given gradient and the associated model variable.\"\"\"\n        lr = ops.cast(learning_rate, variable.dtype)\n        gradient = ops.cast(gradient, variable.dtype)\n\n        velocity = self._velocities[self._get_variable_index(variable)]\n        momentum = None\n        if self.momentum > 0:\n            momentum = self._momentums[self._get_variable_index(variable)]\n        average_grad = None\n        if self.centered:\n            average_grad = self._average_gradients[\n                self._get_variable_index(variable)\n            ]\n\n        rho = self.rho\n\n        self.assign(\n            velocity,\n            ops.add(\n                ops.multiply(rho, velocity),\n                ops.multiply(1 - rho, ops.square(gradient)),\n            ),\n        )\n        if self.centered:\n            self.assign(\n                average_grad,\n                ops.add(\n                    ops.multiply(rho, average_grad),\n                    ops.multiply(1 - rho, gradient),\n                ),\n            )\n            denominator = velocity - ops.square(average_grad) + self.epsilon\n        else:\n            denominator = ops.add(velocity, self.epsilon)\n        increment = ops.divide(\n            ops.multiply(lr, gradient), ops.sqrt(denominator)\n        )\n        if self.momentum > 0:\n            self.assign(\n                momentum,\n                ops.add(ops.multiply(self.momentum, momentum), increment),\n            )\n            self.assign_sub(variable, momentum)\n        else:\n            self.assign_sub(variable, increment)\n\n    def get_config(self):\n        config = super().get_config()\n\n        config.update(\n            {\n                \"rho\": self.rho,\n                \"momentum\": 
self.momentum,\n                \"epsilon\": self.epsilon,\n                \"centered\": self.centered,\n            }\n        )\n        return config\n\n\nRMSprop.__doc__ = RMSprop.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
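The plain (non-centered, momentum-free) branch of `update_step` above reduces to a few lines of NumPy. A minimal sketch, assuming the default `rho=0.9` and `epsilon=1e-7`:

```python
import numpy as np

def rmsprop_step(var, grad, velocity, lr=0.5, rho=0.9, eps=1e-7):
    # Discounted average of squared gradients, then divide the gradient
    # by the root of this average; epsilon is added inside the sqrt,
    # matching `denominator = ops.add(velocity, self.epsilon)`.
    velocity = rho * velocity + (1.0 - rho) * grad**2
    var = var - lr * grad / np.sqrt(velocity + eps)
    return var, velocity

var = np.array([1.0, 2.0, 3.0, 4.0])
grad = np.array([1.0, 6.0, 7.0, 2.0])
var, velocity = rmsprop_step(var, grad, np.zeros_like(var))
```

On the first step the update is roughly `lr * g / sqrt((1 - rho) * g**2) = lr / sqrt(1 - rho)` regardless of the gradient's magnitude, which is why `test_single_step` in `rmsprop_test.py` expects every entry to shift by the same ~1.5811.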
  {
    "path": "keras/src/optimizers/rmsprop_test.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.optimizers.rmsprop import RMSprop\n\n\nclass RMSpropTest(testing.TestCase):\n    def test_config(self):\n        optimizer = RMSprop(\n            learning_rate=0.5,\n            rho=0.8,\n            momentum=0.05,\n            epsilon=1e-6,\n            centered=True,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_single_step(self):\n        optimizer = RMSprop(learning_rate=0.5)\n        grads = ops.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(\n            vars, [-0.5811, 0.4189, 1.4189, 2.4189], rtol=1e-4, atol=1e-4\n        )\n\n    def test_weight_decay(self):\n        grads, var1, var2, var3 = (\n            ops.zeros(()),\n            backend.Variable(2.0),\n            backend.Variable(2.0, name=\"exclude\"),\n            backend.Variable(2.0),\n        )\n        optimizer_1 = RMSprop(learning_rate=1.0, weight_decay=0.004)\n        optimizer_1.apply_gradients(zip([grads], [var1]))\n\n        optimizer_2 = RMSprop(learning_rate=1.0, weight_decay=0.004)\n        optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n        optimizer_3 = RMSprop(learning_rate=1.0, weight_decay=0.004)\n        optimizer_3.exclude_from_weight_decay(var_list=[var3])\n        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))\n\n        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)\n        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)\n        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)\n\n    def test_correctness_with_golden(self):\n        optimizer = RMSprop(centered=True)\n\n        x = backend.Variable(np.ones([10]))\n        grads = ops.arange(0.1, 1.1, 
0.1)\n        first_grads = ops.full((10,), 0.01)\n\n        golden = np.tile(\n            [[0.9967], [0.9933], [0.9908], [0.9885], [0.9864]], (1, 10)\n        )\n\n        optimizer.apply_gradients(zip([first_grads], [x]))\n        for i in range(5):\n            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            optimizer.apply_gradients(zip([grads], [x]))\n\n    def test_clip_norm(self):\n        optimizer = RMSprop(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = RMSprop(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n"
  },
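The clipping semantics asserted in the `test_clip_norm` and `test_clip_value` tests can be sketched independently of the optimizer. These are hypothetical helpers for illustration, not the private `_clip_gradients` implementation:

```python
import numpy as np

def clip_by_norm(g, clipnorm):
    # Rescale the whole tensor so its L2 norm is at most `clipnorm`.
    norm = np.linalg.norm(g)
    return g * (clipnorm / norm) if norm > clipnorm else g

def clip_by_value(g, clipvalue):
    # Clamp each element to the interval [-clipvalue, clipvalue].
    return np.clip(g, -clipvalue, clipvalue)

g = np.array([100.0, 100.0])
by_norm = clip_by_norm(g, 1.0)    # unit vector: [sqrt(2)/2, sqrt(2)/2]
by_value = clip_by_value(g, 1.0)  # [1.0, 1.0]
```

`global_clipnorm` differs from `clipnorm` only in that the norm is computed over all gradients jointly rather than per tensor, as `test_global_clip_norm` in `optimizer_test.py` checks.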
  {
    "path": "keras/src/optimizers/schedule_free_adamw.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export([\"keras.optimizers.ScheduleFreeAdamW\"])\nclass ScheduleFreeAdamW(optimizer.Optimizer):\n    \"\"\"Optimizer that implements the Schedule-Free AdamW algorithm.\n\n    Schedule-Free learning is a method that avoids the need for a learning rate\n    schedule by maintaining a combination of interpolation and averaging.\n    This approach eliminates the requirement to specify stopping time in advance\n    and typically matches or outperforms cosine and linear decay schedules.\n\n    The optimizer maintains two sets of variables internally:\n    - `momentum`: The sequence where gradient updates are applied\n    - `x`: The averaged sequence used for evaluation\n\n    During training, the model parameters are set to an interpolation between\n    `momentum` and `x`.\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.0025`.\n        beta_1: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use. The\n            exponential decay rate for the 1st moment estimates and controls\n            the interpolation between `momentum` and `x`. Defaults to `0.9`.\n        beta_2: A float value or a constant float tensor, or a callable\n            that takes no arguments and returns the actual value to use. 
The\n            exponential decay rate for the 2nd moment estimates.\n            Defaults to `0.999`.\n        epsilon: A small constant for numerical stability.\n            Defaults to `1e-8`.\n        warmup_steps: Number of warmup steps for learning rate warmup.\n            During warmup, the learning rate linearly increases from 0 to the\n            specified learning rate. Defaults to `0`.\n        {{base_optimizer_keyword_args}}\n\n    References:\n\n    - [Defazio et al., 2024](https://arxiv.org/abs/2405.15682)\n    - [Schedule-Free repository](\n        https://github.com/facebookresearch/schedule_free)\n\n    Example:\n\n    >>> optimizer = keras.optimizers.ScheduleFreeAdamW(learning_rate=0.0025)\n    >>> model.compile(optimizer=optimizer, loss=\"mse\")\n    >>> model.fit(x_train, y_train)\n\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.0025,\n        beta_1=0.9,\n        beta_2=0.999,\n        epsilon=1e-8,\n        warmup_steps=0,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"schedule_free_adamw\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            name=name,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            **kwargs,\n        )\n        self.beta_1 = beta_1\n        self.beta_2 = beta_2\n        self.epsilon = epsilon\n        self.warmup_steps = warmup_steps\n\n   
 def build(self, var_list):\n        \"\"\"Initialize optimizer variables.\n\n        ScheduleFreeAdamW optimizer has the following variables:\n        - `momentum`: Auxiliary variable where gradient updates are applied\n        - `velocity`: Exponential moving average of squared gradients (Adam)\n\n        Args:\n            var_list: list of model variables to build optimizer variables on.\n        \"\"\"\n        if self.built:\n            return\n        super().build(var_list)\n        self._momentums, self._velocities = self.add_optimizer_variables(\n            var_list, [\"momentum\", \"velocity\"]\n        )\n\n        # Initialize momentum to match the initial parameter values\n        for momentum, var in zip(self._momentums, var_list):\n            if momentum is not None:\n                self.assign(momentum, ops.copy(var))\n\n    def update_step(self, gradient, variable, learning_rate):\n        \"\"\"Update step given gradient and the associated model variable.\"\"\"\n        lr = ops.cast(learning_rate, variable.dtype)\n        gradient = ops.cast(gradient, variable.dtype)\n        local_step = ops.cast(self.iterations + 1, variable.dtype)\n\n        beta_1 = ops.cast(self.beta_1, variable.dtype)\n        beta_2 = ops.cast(self.beta_2, variable.dtype)\n        epsilon = ops.cast(self.epsilon, variable.dtype)\n\n        # Apply warmup\n        if self.warmup_steps > 0:\n            warmup_steps = ops.cast(self.warmup_steps, variable.dtype)\n            warmup_factor = ops.minimum(local_step / warmup_steps, 1.0)\n            lr = lr * warmup_factor\n\n        var_index = self._get_variable_index(variable)\n        momentum = self._momentums[var_index]\n        velocity = self._velocities[var_index]\n\n        # Store momentum_old before any updates\n        momentum_old = momentum.value\n\n        # Bias correction for Adam's second moment\n        bias_correction_2 = 1 - ops.power(beta_2, local_step)\n\n        # Update velocity (second moment 
estimate)\n        # velocity = beta_2 * velocity + (1 - beta_2) * gradient^2\n        self.assign_add(\n            velocity,\n            ops.multiply(\n                ops.subtract(ops.square(gradient), velocity), 1 - beta_2\n            ),\n        )\n\n        # Compute the denominator (RMSprop-style with bias correction)\n        denom = ops.add(ops.sqrt(velocity / bias_correction_2), epsilon)\n\n        # Update momentum: momentum = momentum - lr * gradient / denom\n        grad_scaled = ops.divide(ops.multiply(lr, gradient), denom)\n        self.assign_sub(momentum, grad_scaled)\n\n        # Compute weight for averaging: weight = 1 / step\n        weight = 1.0 / local_step\n\n        # Recover x_old from y_old and momentum_old\n        # x_old = (y_old - (1 - beta_1) * momentum_old) / beta_1\n        y_old = variable\n        x_old = ops.divide(\n            ops.subtract(y_old, ops.multiply(1 - beta_1, momentum_old)),\n            beta_1,\n        )\n\n        # x_new = lerp(x_old, momentum, weight)\n        # x_new = (1 - weight) * x_old + weight * momentum\n        x_new = ops.add(\n            ops.multiply(1 - weight, x_old), ops.multiply(weight, momentum)\n        )\n\n        # y_new = lerp(momentum, x_new, beta_1)\n        # y_new = (1 - beta_1) * momentum + beta_1 * x_new\n        y_new = ops.add(\n            ops.multiply(1 - beta_1, momentum), ops.multiply(beta_1, x_new)\n        )\n\n        self.assign(variable, y_new)\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"beta_1\": self.beta_1,\n                \"beta_2\": self.beta_2,\n                \"epsilon\": self.epsilon,\n                \"warmup_steps\": self.warmup_steps,\n            }\n        )\n        return config\n\n\nScheduleFreeAdamW.__doc__ = ScheduleFreeAdamW.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
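The interpolation/averaging bookkeeping in `update_step` above can be condensed into a NumPy sketch (no weight decay; `z` plays the role of the `momentum` slot, `x` the averaged sequence, `y` the stored parameters):

```python
import numpy as np

def schedule_free_step(y, z, velocity, grad, step, lr=0.0025,
                       beta_1=0.9, beta_2=0.999, eps=1e-8,
                       warmup_steps=0):
    # Linear warmup, as in the optimizer: ramp lr from 0 over
    # `warmup_steps` calls, then hold it constant.
    if warmup_steps > 0:
        lr = lr * min(step / warmup_steps, 1.0)
    z_old = z
    # Adam-style second moment with bias correction.
    velocity = velocity + (grad**2 - velocity) * (1.0 - beta_2)
    denom = np.sqrt(velocity / (1.0 - beta_2**step)) + eps
    z = z_old - lr * grad / denom
    # Recover the averaged sequence x from the stored params y,
    # fold in the new z, then re-interpolate to get the new y.
    x_old = (y - (1.0 - beta_1) * z_old) / beta_1
    weight = 1.0 / step
    x_new = (1.0 - weight) * x_old + weight * z
    y_new = (1.0 - beta_1) * z + beta_1 * x_new
    return y_new, z, velocity

y = z = np.array([1.0, 2.0, 3.0])  # build() copies the params into momentum
grad = np.array([0.1, 0.1, 0.1])
y, z, vel = schedule_free_step(y, z, np.zeros(3), grad, step=1)
# At step 1 the averaging weight is 1, so x collapses onto z and y == z.
```

At `step=1` the bias-corrected denominator equals `|grad|`, so each parameter moves by exactly `lr`; the averaging only starts to damp updates from the second step onward.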
  {
    "path": "keras/src/optimizers/schedule_free_adamw_test.py",
    "content": "import numpy as np\nimport pytest\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.optimizers.schedule_free_adamw import ScheduleFreeAdamW\n\n\nclass ScheduleFreeAdamWTest(testing.TestCase):\n    def test_config(self):\n        optimizer = ScheduleFreeAdamW(\n            learning_rate=0.005,\n            beta_1=0.95,\n            beta_2=0.99,\n            epsilon=1e-6,\n            warmup_steps=100,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_single_step(self):\n        optimizer = ScheduleFreeAdamW(learning_rate=0.5)\n        grads = ops.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        # After one step, the parameters should have changed\n        self.assertNotAllClose(vars, [1.0, 2.0, 3.0, 4.0], rtol=1e-4, atol=1e-4)\n\n    def test_weight_decay(self):\n        grads, var1, var2 = (\n            ops.zeros(()),\n            backend.Variable(2.0),\n            backend.Variable(2.0, name=\"exclude\"),\n        )\n        optimizer_1 = ScheduleFreeAdamW(learning_rate=1.0, weight_decay=0.004)\n        optimizer_1.apply_gradients(zip([grads], [var1]))\n\n        optimizer_2 = ScheduleFreeAdamW(learning_rate=1.0, weight_decay=0.004)\n        optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n        # var2 should be unchanged since it's excluded from weight decay\n        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)\n\n    def test_warmup(self):\n        \"\"\"Test that warmup affects the learning rate.\"\"\"\n        optimizer_no_warmup = ScheduleFreeAdamW(\n            learning_rate=0.5, warmup_steps=0\n        )\n        optimizer_with_warmup = ScheduleFreeAdamW(\n            learning_rate=0.5, warmup_steps=10\n        )\n\n        grads = 
ops.array([1.0, 1.0, 1.0])\n        var1 = backend.Variable([1.0, 2.0, 3.0])\n        var2 = backend.Variable([1.0, 2.0, 3.0])\n\n        # Apply single gradient step\n        optimizer_no_warmup.apply_gradients(zip([grads], [var1]))\n        optimizer_with_warmup.apply_gradients(zip([grads], [var2]))\n\n        # The optimizer with warmup should have made a smaller update\n        # because effective lr = lr * (step / warmup_steps) = 0.5 * 0.1 = 0.05\n        diff_no_warmup = np.abs(var1.numpy() - [1.0, 2.0, 3.0])\n        diff_with_warmup = np.abs(var2.numpy() - [1.0, 2.0, 3.0])\n\n        # With warmup, the update should be smaller\n        self.assertTrue(np.all(diff_with_warmup < diff_no_warmup))\n\n    def test_multiple_steps(self):\n        \"\"\"Test that the optimizer works over multiple steps.\"\"\"\n        optimizer = ScheduleFreeAdamW(learning_rate=0.01)\n        var = backend.Variable([1.0, 2.0, 3.0])\n\n        for _ in range(10):\n            grads = ops.array([0.1, 0.1, 0.1])\n            optimizer.apply_gradients(zip([grads], [var]))\n\n        # Parameters should have decreased\n        final_values = var.numpy()\n        self.assertTrue(np.all(final_values < [1.0, 2.0, 3.0]))\n\n    @pytest.mark.requires_trainable_backend\n    def test_with_model(self):\n        \"\"\"Test that the optimizer works with a Keras model.\"\"\"\n        model = keras.Sequential([keras.layers.Dense(10)])\n        optimizer = ScheduleFreeAdamW(learning_rate=0.01)\n        model.compile(optimizer=optimizer, loss=\"mse\")\n\n        x = keras.ops.ones((4, 5))\n        y = keras.ops.zeros((4, 10))\n\n        # Training\n        model.fit(x, y, epochs=2, verbose=0)\n\n        # Evaluation\n        loss = model.evaluate(x, y, verbose=0)\n        self.assertIsNotNone(loss)\n\n    def test_clip_norm(self):\n        optimizer = ScheduleFreeAdamW(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        
self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = ScheduleFreeAdamW(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n"
  },
  {
    "path": "keras/src/optimizers/schedules/__init__.py",
    "content": "from keras.src.optimizers.schedules.learning_rate_schedule import CosineDecay\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    CosineDecayRestarts,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    ExponentialDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    InverseTimeDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    PiecewiseConstantDecay,\n)\nfrom keras.src.optimizers.schedules.learning_rate_schedule import (\n    PolynomialDecay,\n)\n"
  },
  {
    "path": "keras/src/optimizers/schedules/learning_rate_schedule.py",
    "content": "\"\"\"Various learning rate schedule functions.\"\"\"\n\nimport math\n\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.optimizers.schedules.LearningRateSchedule\")\nclass LearningRateSchedule:\n    \"\"\"The learning rate schedule base class.\n\n    You can use a learning rate schedule to modulate how the learning rate\n    of your optimizer changes over time.\n\n    Several built-in learning rate schedules are available, such as\n    `keras.optimizers.schedules.ExponentialDecay` or\n    `keras.optimizers.schedules.PiecewiseConstantDecay`:\n\n    ```python\n    lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n        initial_learning_rate=1e-2,\n        decay_steps=10000,\n        decay_rate=0.9)\n    optimizer = keras.optimizers.SGD(learning_rate=lr_schedule)\n    ```\n\n    A `LearningRateSchedule` instance can be passed in as the `learning_rate`\n    argument of any optimizer.\n\n    To implement your own schedule object, you should implement the `__call__`\n    method, which takes a `step` argument (scalar integer tensor, the\n    current training step count).\n    Like for any other Keras object, you can also optionally\n    make your object serializable by implementing the `get_config`\n    and `from_config` methods.\n\n    Example:\n\n    ```python\n    class MyLRSchedule(keras.optimizers.schedules.LearningRateSchedule):\n\n        def __init__(self, initial_learning_rate):\n            self.initial_learning_rate = initial_learning_rate\n\n        def __call__(self, step):\n            return self.initial_learning_rate / (step + 1)\n\n    optimizer = keras.optimizers.SGD(learning_rate=MyLRSchedule(0.1))\n    ```\n    \"\"\"\n\n    def __call__(self, step):\n        raise NotImplementedError(\n            f\"Learning rate schedule '{self.__class__.__name__}' \"\n            \"must override `__call__(self, step)`.\"\n        )\n\n    
def get_config(self):\n        raise NotImplementedError(\n            f\"Learning rate schedule '{self.__class__.__name__}' \"\n            \"must override `get_config()` in order to be serializable.\"\n        )\n\n    @classmethod\n    def from_config(cls, config):\n        \"\"\"Instantiates a `LearningRateSchedule` from its config.\n\n        Args:\n            config: Output of `get_config()`.\n\n        Returns:\n            A `LearningRateSchedule` instance.\n        \"\"\"\n        return cls(**config)\n\n\n@keras_export(\"keras.optimizers.schedules.ExponentialDecay\")\nclass ExponentialDecay(LearningRateSchedule):\n    \"\"\"A `LearningRateSchedule` that uses an exponential decay schedule.\n\n    When training a model, it is often useful to lower the learning rate as\n    the training progresses. This schedule applies an exponential decay function\n    to an optimizer step, given a provided initial learning rate.\n\n    The schedule is a 1-arg callable that produces a decayed learning\n    rate when passed the current optimizer step. 
This can be useful for changing\n    the learning rate value across different invocations of optimizer functions.\n    It is computed as:\n\n    ```python\n    def decayed_learning_rate(step):\n        return initial_learning_rate * decay_rate ^ (step / decay_steps)\n    ```\n\n    If the argument `staircase` is `True`, then `step / decay_steps` is\n    an integer division and the decayed learning rate follows a\n    staircase function.\n\n    You can pass this schedule directly into a `keras.optimizers.Optimizer`\n    as the learning rate.\n    Example: When fitting a Keras model, decay every 100000 steps with a base\n    of 0.96:\n\n    ```python\n    initial_learning_rate = 0.1\n    lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n        initial_learning_rate,\n        decay_steps=100000,\n        decay_rate=0.96,\n        staircase=True)\n\n    model.compile(optimizer=keras.optimizers.SGD(learning_rate=lr_schedule),\n                  loss='sparse_categorical_crossentropy',\n                  metrics=['accuracy'])\n\n    model.fit(data, labels, epochs=5)\n    ```\n\n    The learning rate schedule is also serializable and deserializable using\n    `keras.optimizers.schedules.serialize` and\n    `keras.optimizers.schedules.deserialize`.\n\n    Args:\n        initial_learning_rate: A Python float. The initial learning rate.\n        decay_steps: A Python integer. Must be positive. See the decay\n            computation above.\n        decay_rate: A Python float. The decay rate.\n        staircase: Boolean.  If `True` decay the learning rate at discrete\n            intervals.\n        name: String.  Optional name of the operation.  
Defaults to\n            `\"ExponentialDecay\"`.\n\n    Returns:\n        A 1-arg callable learning rate schedule that takes the current optimizer\n        step and outputs the decayed learning rate, a scalar tensor of the\n        same type as `initial_learning_rate`.\n    \"\"\"\n\n    def __init__(\n        self,\n        initial_learning_rate,\n        decay_steps,\n        decay_rate,\n        staircase=False,\n        name=\"ExponentialDecay\",\n    ):\n        super().__init__()\n        self.initial_learning_rate = initial_learning_rate\n        self.decay_steps = decay_steps\n        self.decay_rate = decay_rate\n        self.staircase = staircase\n        self.name = name\n\n        if self.decay_steps <= 0:\n            raise ValueError(\n                \"Argument `decay_steps` must be > 0. \"\n                f\"Received: decay_steps={self.decay_steps}\"\n            )\n\n    def __call__(self, step):\n        with ops.name_scope(self.name):\n            initial_learning_rate = ops.convert_to_tensor(\n                self.initial_learning_rate\n            )\n            dtype = initial_learning_rate.dtype\n            decay_steps = ops.cast(self.decay_steps, dtype)\n            decay_rate = ops.cast(self.decay_rate, dtype)\n\n            global_step_recomp = ops.cast(step, dtype)\n            p = global_step_recomp / decay_steps\n            if self.staircase:\n                p = ops.floor(p)\n            return ops.multiply(initial_learning_rate, ops.power(decay_rate, p))\n\n    def get_config(self):\n        return {\n            \"initial_learning_rate\": self.initial_learning_rate,\n            \"decay_steps\": self.decay_steps,\n            \"decay_rate\": self.decay_rate,\n            \"staircase\": self.staircase,\n            \"name\": self.name,\n        }\n\n\n@keras_export(\"keras.optimizers.schedules.PiecewiseConstantDecay\")\nclass PiecewiseConstantDecay(LearningRateSchedule):\n    \"\"\"A `LearningRateSchedule` that uses a piecewise 
constant decay schedule.\n\n    The function returns a 1-arg callable to compute the piecewise constant\n    when passed the current optimizer step. This can be useful for changing the\n    learning rate value across different invocations of optimizer functions.\n\n    Example: use a learning rate that's 1.0 for the first 100001 steps, 0.5\n        for the next 10000 steps, and 0.1 for any additional steps.\n\n    ```python\n    step = ops.array(0)\n    boundaries = [100000, 110000]\n    values = [1.0, 0.5, 0.1]\n    learning_rate_fn = keras.optimizers.schedules.PiecewiseConstantDecay(\n        boundaries, values)\n\n    # Later, whenever we perform an optimization step, we pass in the step.\n    learning_rate = learning_rate_fn(step)\n    ```\n\n    You can pass this schedule directly into a `keras.optimizers.Optimizer`\n    as the learning rate. The learning rate schedule is also serializable and\n    deserializable using `keras.optimizers.schedules.serialize` and\n    `keras.optimizers.schedules.deserialize`.\n\n    Args:\n        boundaries: A list of Python numbers with strictly increasing\n            entries, and with all elements having the same type as the\n            optimizer step.\n        values: A list of Python numbers that specifies the values for the\n            intervals defined by `boundaries`. It should have one more\n            element than `boundaries`, and all elements should have the same\n            type.\n        name: A string. Optional name of the operation. 
Defaults to\n            `\"PiecewiseConstant\"`.\n\n    Returns:\n        A 1-arg callable learning rate schedule that takes the current optimizer\n        step and outputs the decayed learning rate, a scalar tensor of the\n        same type as the boundary tensors.\n\n        The output of the 1-arg function that takes the `step`\n        is `values[0]` when `step <= boundaries[0]`,\n        `values[1]` when `step > boundaries[0]` and `step <= boundaries[1]`,\n        ..., and `values[-1]` when `step > boundaries[-1]`.\n\n\n    Raises:\n        ValueError: if the number of elements in the `boundaries` and `values`\n        lists do not match.\n    \"\"\"\n\n    def __init__(self, boundaries, values, name=\"PiecewiseConstant\"):\n        super().__init__()\n\n        if len(boundaries) != len(values) - 1:\n            raise ValueError(\n                \"The length of boundaries should be 1 less than the length of \"\n                f\"values. Received: boundaries={boundaries} of length \"\n                f\"{len(boundaries)}, and values={values} \"\n                f\"of length {len(values)}.\"\n            )\n\n        self.boundaries = boundaries\n        self.values = values\n        self.name = name\n\n    def __call__(self, step):\n        with ops.name_scope(self.name):\n            boundaries = [ops.convert_to_tensor(x) for x in self.boundaries]\n            values = [ops.convert_to_tensor(x) for x in self.values]\n            step = ops.convert_to_tensor(step)\n\n            for i, b in enumerate(boundaries):\n                if b.dtype != step.dtype:\n                    # We cast the boundaries to have the same type as the step\n                    b = ops.cast(b, step.dtype)\n                    boundaries[i] = b\n\n            result_dtype = values[0].dtype\n            result_value = ops.array(0, dtype=result_dtype)\n\n            # For each range between boundaries, we check whether the step is\n            # within that range, cast the resulting 
boolean to a number,\n            # and multiply the result by the corresponding value for the range.\n            # Taking the sum of these yields a piecewise constant function.\n            step_less_than_first_boundary = ops.cast(\n                step <= boundaries[0], result_dtype\n            )\n            result_value += step_less_than_first_boundary * values[0]\n\n            step_greater_than_last_boundary = ops.cast(\n                step > boundaries[-1], result_dtype\n            )\n            result_value += step_greater_than_last_boundary * values[-1]\n\n            for low, high, value in zip(\n                boundaries[:-1], boundaries[1:], values[1:-1]\n            ):\n                # Select `value` for steps in the interval (low, high].\n                step_in_range = ops.cast(\n                    (step > low) & (step <= high), result_dtype\n                )\n                result_value += step_in_range * value\n\n            return result_value\n\n    def get_config(self):\n        return {\n            \"boundaries\": self.boundaries,\n            \"values\": self.values,\n            \"name\": self.name,\n        }\n\n\n@keras_export(\"keras.optimizers.schedules.PolynomialDecay\")\nclass PolynomialDecay(LearningRateSchedule):\n    \"\"\"A `LearningRateSchedule` that uses a polynomial decay schedule.\n\n    It is commonly observed that a monotonically decreasing learning rate, whose\n    degree of change is carefully chosen, results in a better performing model.\n    This schedule applies a polynomial decay function to an optimizer step,\n    given a provided `initial_learning_rate`, to reach an `end_learning_rate`\n    in the given `decay_steps`.\n\n    It requires a `step` value to compute the decayed learning rate. You\n    can just pass a backend variable that you increment at each training\n    step.\n\n    The schedule is a 1-arg callable that produces a decayed learning rate\n    when passed the current optimizer step. 
This can be useful for changing the\n    learning rate value across different invocations of optimizer functions.\n    It is computed as:\n\n    ```python\n    def decayed_learning_rate(step):\n        step = min(step, decay_steps)\n        return ((initial_learning_rate - end_learning_rate) *\n                (1 - step / decay_steps) ^ (power)\n               ) + end_learning_rate\n    ```\n\n    If `cycle` is True then a multiple of `decay_steps` is used, the first one\n    that is bigger than `step`.\n\n    ```python\n    def decayed_learning_rate(step):\n        decay_steps = decay_steps * ceil(step / decay_steps)\n        return ((initial_learning_rate - end_learning_rate) *\n                (1 - step / decay_steps) ^ (power)\n               ) + end_learning_rate\n    ```\n\n    You can pass this schedule directly into a `keras.optimizers.Optimizer`\n    as the learning rate.\n    Example: Fit a model while decaying from 0.1 to 0.01 in 10000 steps using\n    sqrt (i.e. power=0.5):\n\n    ```python\n    ...\n    starter_learning_rate = 0.1\n    end_learning_rate = 0.01\n    decay_steps = 10000\n    learning_rate_fn = keras.optimizers.schedules.PolynomialDecay(\n        starter_learning_rate,\n        decay_steps,\n        end_learning_rate,\n        power=0.5)\n\n    model.compile(optimizer=keras.optimizers.SGD(\n                      learning_rate=learning_rate_fn),\n                  loss='sparse_categorical_crossentropy',\n                  metrics=['accuracy'])\n\n    model.fit(data, labels, epochs=5)\n    ```\n\n    The learning rate schedule is also serializable and deserializable using\n    `keras.optimizers.schedules.serialize` and\n    `keras.optimizers.schedules.deserialize`.\n\n    Args:\n        initial_learning_rate: A Python float. The initial learning rate.\n        decay_steps: A Python integer. Must be positive. See the decay\n            computation above.\n        end_learning_rate: A Python float. 
The minimal end learning rate.\n        power: A Python float. The power of the polynomial. Defaults to\n            `1.0`.\n        cycle: A boolean, whether it should cycle beyond decay_steps.\n        name: String.  Optional name of the operation. Defaults to\n            `\"PolynomialDecay\"`.\n\n    Returns:\n        A 1-arg callable learning rate schedule that takes the current optimizer\n        step and outputs the decayed learning rate, a scalar tensor of the\n        same type as `initial_learning_rate`.\n    \"\"\"\n\n    def __init__(\n        self,\n        initial_learning_rate,\n        decay_steps,\n        end_learning_rate=0.0001,\n        power=1.0,\n        cycle=False,\n        name=\"PolynomialDecay\",\n    ):\n        super().__init__()\n\n        self.initial_learning_rate = initial_learning_rate\n        self.decay_steps = decay_steps\n        self.end_learning_rate = end_learning_rate\n        self.power = power\n        self.cycle = cycle\n        self.name = name\n\n        if self.decay_steps <= 0:\n            raise ValueError(\n                \"Argument `decay_steps` must be > 0. \"\n                f\"Received: decay_steps={self.decay_steps}\"\n            )\n\n    def __call__(self, step):\n        with ops.name_scope(self.name):\n            initial_learning_rate = ops.convert_to_tensor(\n                self.initial_learning_rate\n            )\n            dtype = initial_learning_rate.dtype\n            end_learning_rate = ops.cast(self.end_learning_rate, dtype)\n            power = ops.cast(self.power, dtype)\n\n            global_step_recomp = ops.cast(step, dtype)\n            decay_steps_recomp = ops.cast(self.decay_steps, dtype)\n            if self.cycle:\n                # Find the first multiple of decay_steps that is bigger than\n                # global_step. 
If global_step is zero set the multiplier to 1\n                multiplier = ops.where(\n                    ops.equal(global_step_recomp, 0),\n                    1.0,\n                    ops.ceil(global_step_recomp / self.decay_steps),\n                )\n                decay_steps_recomp = ops.multiply(\n                    decay_steps_recomp, multiplier\n                )\n            else:\n                # Make sure that the global_step used is not bigger than\n                # decay_steps.\n                global_step_recomp = ops.minimum(\n                    global_step_recomp, decay_steps_recomp\n                )\n\n            p = ops.divide(global_step_recomp, decay_steps_recomp)\n            return ops.add(\n                ops.multiply(\n                    initial_learning_rate - end_learning_rate,\n                    ops.power(1 - p, power),\n                ),\n                end_learning_rate,\n            )\n\n    def get_config(self):\n        return {\n            \"initial_learning_rate\": self.initial_learning_rate,\n            \"decay_steps\": self.decay_steps,\n            \"end_learning_rate\": self.end_learning_rate,\n            \"power\": self.power,\n            \"cycle\": self.cycle,\n            \"name\": self.name,\n        }\n\n\n@keras_export(\"keras.optimizers.schedules.InverseTimeDecay\")\nclass InverseTimeDecay(LearningRateSchedule):\n    \"\"\"A `LearningRateSchedule` that uses an inverse time decay schedule.\n\n    When training a model, it is often useful to lower the learning rate as\n    the training progresses. This schedule applies the inverse decay function\n    to an optimizer step, given a provided initial learning rate.\n    It requires a `step` value to compute the decayed learning rate. You can\n    just pass a backend variable that you increment at each training step.\n\n    The schedule is a 1-arg callable that produces a decayed learning\n    rate when passed the current optimizer step. 
This can be useful for changing\n    the learning rate value across different invocations of optimizer functions.\n    It is computed as:\n\n    ```python\n    def decayed_learning_rate(step):\n        return initial_learning_rate / (1 + decay_rate * step / decay_step)\n    ```\n\n    or, if `staircase` is `True`, as:\n\n    ```python\n    def decayed_learning_rate(step):\n        return initial_learning_rate / (\n            1 + decay_rate * floor(step / decay_step)\n        )\n    ```\n\n    You can pass this schedule directly into a `keras.optimizers.Optimizer`\n    as the learning rate.\n    Example: Fit a Keras model when decaying 1/t with a rate of 0.5:\n\n    ```python\n    ...\n    initial_learning_rate = 0.1\n    decay_steps = 1.0\n    decay_rate = 0.5\n    learning_rate_fn = keras.optimizers.schedules.InverseTimeDecay(\n        initial_learning_rate, decay_steps, decay_rate)\n\n    model.compile(optimizer=keras.optimizers.SGD(\n                      learning_rate=learning_rate_fn),\n                  loss='sparse_categorical_crossentropy',\n                  metrics=['accuracy'])\n\n    model.fit(data, labels, epochs=5)\n    ```\n\n    Args:\n        initial_learning_rate: A Python float. The initial learning rate.\n        decay_steps: How often to apply decay.\n        decay_rate: A Python number.  The decay rate.\n        staircase: Whether to apply decay in a discrete staircase, as\n            opposed to continuous, fashion.\n        name: String.  Optional name of the operation.  
Defaults to\n            `\"InverseTimeDecay\"`.\n\n    Returns:\n        A 1-arg callable learning rate schedule that takes the current optimizer\n        step and outputs the decayed learning rate, a scalar tensor of the\n        same type as `initial_learning_rate`.\n    \"\"\"\n\n    def __init__(\n        self,\n        initial_learning_rate,\n        decay_steps,\n        decay_rate,\n        staircase=False,\n        name=\"InverseTimeDecay\",\n    ):\n        super().__init__()\n\n        self.initial_learning_rate = initial_learning_rate\n        self.decay_steps = decay_steps\n        self.decay_rate = decay_rate\n        self.staircase = staircase\n        self.name = name\n\n        if self.decay_steps <= 0:\n            raise ValueError(\n                \"Argument `decay_steps` must be > 0. \"\n                f\"Received: decay_steps={self.decay_steps}\"\n            )\n\n    def __call__(self, step):\n        with ops.name_scope(self.name):\n            initial_learning_rate = ops.convert_to_tensor(\n                self.initial_learning_rate\n            )\n            dtype = initial_learning_rate.dtype\n            decay_steps = ops.cast(self.decay_steps, dtype)\n            decay_rate = ops.cast(self.decay_rate, dtype)\n\n            global_step_recomp = ops.cast(step, dtype)\n            p = global_step_recomp / decay_steps\n            if self.staircase:\n                p = ops.floor(p)\n            const = ops.cast(ops.array(1), dtype)\n            denom = ops.add(const, ops.multiply(decay_rate, p))\n            return ops.divide(initial_learning_rate, denom)\n\n    def get_config(self):\n        return {\n            \"initial_learning_rate\": self.initial_learning_rate,\n            \"decay_steps\": self.decay_steps,\n            \"decay_rate\": self.decay_rate,\n            \"staircase\": self.staircase,\n            \"name\": self.name,\n        }\n\n\n@keras_export(\"keras.optimizers.schedules.CosineDecay\")\nclass 
CosineDecay(LearningRateSchedule):\n    \"\"\"A `LearningRateSchedule` that uses a cosine decay with optional warmup.\n\n    See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),\n    SGDR: Stochastic Gradient Descent with Warm Restarts.\n\n    For the idea of a linear warmup of our learning rate,\n    see [Goyal et al.](https://arxiv.org/pdf/1706.02677.pdf).\n\n    When we begin training a model, we often want an initial increase in our\n    learning rate followed by a decay. If `warmup_target` is an int, this\n    schedule applies a linear increase per optimizer step to our learning rate\n    from `initial_learning_rate` to `warmup_target` for a duration of\n    `warmup_steps`. Afterwards, it applies a cosine decay function taking our\n    learning rate from `warmup_target` to `warmup_target * alpha` for a\n    duration of `decay_steps`. If `warmup_target` is None we skip warmup and\n    our decay will take our learning rate from `initial_learning_rate` to\n    `initial_learning_rate * alpha`.\n    It requires a `step` value to  compute the learning rate. You can\n    just pass a backend variable that you increment at each training step.\n\n    The schedule is a 1-arg callable that produces a warmup followed by a\n    decayed learning rate when passed the current optimizer step. 
This can be\n    useful for changing the learning rate value across different invocations of\n    optimizer functions.\n\n    Our warmup is computed as:\n\n    ```python\n    def warmup_learning_rate(step):\n        completed_fraction = step / warmup_steps\n        total_delta = warmup_target - initial_learning_rate\n        return completed_fraction * total_delta + initial_learning_rate\n    ```\n\n    And our decay is computed as:\n\n    ```python\n    if warmup_target is None:\n        initial_decay_lr = initial_learning_rate\n    else:\n        initial_decay_lr = warmup_target\n\n    def decayed_learning_rate(step):\n        step = min(step, decay_steps)\n        cosine_decay = 0.5 * (1 + cos(pi * step / decay_steps))\n        decayed = (1 - alpha) * cosine_decay + alpha\n        return initial_decay_lr * decayed\n    ```\n\n    Example usage without warmup:\n\n    ```python\n    decay_steps = 1000\n    initial_learning_rate = 0.1\n    lr_decayed_fn = keras.optimizers.schedules.CosineDecay(\n        initial_learning_rate, decay_steps)\n    ```\n\n    Example usage with warmup:\n\n    ```python\n    decay_steps = 1000\n    initial_learning_rate = 0\n    warmup_steps = 1000\n    target_learning_rate = 0.1\n    lr_warmup_decayed_fn = keras.optimizers.schedules.CosineDecay(\n        initial_learning_rate, decay_steps, warmup_target=target_learning_rate,\n        warmup_steps=warmup_steps\n    )\n    ```\n\n    You can pass this schedule directly into a `keras.optimizers.Optimizer`\n    as the learning rate. The learning rate schedule is also serializable and\n    deserializable using `keras.optimizers.schedules.serialize` and\n    `keras.optimizers.schedules.deserialize`.\n\n    Args:\n        initial_learning_rate: A Python float. The initial learning rate.\n        decay_steps: A Python int. Number of steps to decay over.\n        alpha: A Python float. 
Minimum learning rate value for decay as a\n            fraction of `warmup_target` or, if `warmup_target` is None,\n            `initial_learning_rate`.\n        name: String. Optional name of the operation.  Defaults to\n            `\"CosineDecay\"`.\n        warmup_target: A Python float. The target learning rate for our\n            warmup phase. Will be cast to the `initial_learning_rate`\n            datatype. Setting to `None` will skip warmup and begin the decay\n            phase from `initial_learning_rate`. Otherwise the scheduler will\n            warm up from `initial_learning_rate` to `warmup_target`.\n        warmup_steps: A Python int. Number of steps to warm up over.\n\n    Returns:\n        A 1-arg callable learning rate schedule that takes the current optimizer\n        step and outputs the decayed learning rate, a scalar tensor of the\n        same type as `initial_learning_rate`.\n    \"\"\"\n\n    def __init__(\n        self,\n        initial_learning_rate,\n        decay_steps,\n        alpha=0.0,\n        name=\"CosineDecay\",\n        warmup_target=None,\n        warmup_steps=0,\n    ):\n        super().__init__()\n\n        self.initial_learning_rate = initial_learning_rate\n        self.decay_steps = decay_steps\n        self.alpha = alpha\n        self.name = name\n        self.warmup_steps = warmup_steps\n        self.warmup_target = warmup_target\n\n        if self.decay_steps <= 0:\n            raise ValueError(\n                \"Argument `decay_steps` must be > 0. 
\"\n                f\"Received: decay_steps={self.decay_steps}\"\n            )\n\n    def _decay_function(self, step, decay_steps, decay_from_lr, dtype):\n        with ops.name_scope(self.name):\n            completed_fraction = ops.divide(step, decay_steps)\n            pi = ops.array(math.pi, dtype=dtype)\n            cosine_decayed = 0.5 * (\n                1.0 + ops.cos(ops.multiply(pi, completed_fraction))\n            )\n            decayed = (1 - self.alpha) * cosine_decayed + self.alpha\n            return ops.multiply(decay_from_lr, decayed)\n\n    def _warmup_function(\n        self, step, warmup_steps, warmup_target, initial_learning_rate\n    ):\n        with ops.name_scope(self.name):\n            completed_fraction = step / warmup_steps\n            total_step_delta = warmup_target - initial_learning_rate\n            return total_step_delta * completed_fraction + initial_learning_rate\n\n    def __call__(self, step):\n        with ops.name_scope(self.name):\n            initial_learning_rate = ops.convert_to_tensor(\n                self.initial_learning_rate\n            )\n            dtype = initial_learning_rate.dtype\n            decay_steps = ops.cast(self.decay_steps, dtype)\n            global_step_recomp = ops.cast(step, dtype)\n\n            if self.warmup_target is None:\n                global_step_recomp = ops.minimum(\n                    global_step_recomp, decay_steps\n                )\n                return self._decay_function(\n                    global_step_recomp,\n                    decay_steps,\n                    initial_learning_rate,\n                    dtype,\n                )\n\n            warmup_target = ops.cast(self.warmup_target, dtype)\n            warmup_steps = ops.cast(self.warmup_steps, dtype)\n\n            global_step_recomp = ops.minimum(\n                global_step_recomp, decay_steps + warmup_steps\n            )\n\n            return ops.cond(\n                global_step_recomp < warmup_steps,\n 
               lambda: self._warmup_function(\n                    global_step_recomp,\n                    warmup_steps,\n                    warmup_target,\n                    initial_learning_rate,\n                ),\n                lambda: self._decay_function(\n                    global_step_recomp - warmup_steps,\n                    decay_steps,\n                    warmup_target,\n                    dtype,\n                ),\n            )\n\n    def get_config(self):\n        return {\n            \"initial_learning_rate\": self.initial_learning_rate,\n            \"decay_steps\": self.decay_steps,\n            \"alpha\": self.alpha,\n            \"name\": self.name,\n            \"warmup_target\": self.warmup_target,\n            \"warmup_steps\": self.warmup_steps,\n        }\n\n\n@keras_export(\"keras.optimizers.schedules.CosineDecayRestarts\")\nclass CosineDecayRestarts(LearningRateSchedule):\n    \"\"\"A `LearningRateSchedule` that uses a cosine decay schedule with restarts.\n\n    See [Loshchilov & Hutter, ICLR2016](https://arxiv.org/abs/1608.03983),\n    SGDR: Stochastic Gradient Descent with Warm Restarts.\n\n    When training a model, it is often useful to lower the learning rate as\n    the training progresses. This schedule applies a cosine decay function with\n    restarts to an optimizer step, given a provided initial learning rate.\n    It requires a `step` value to compute the decayed learning rate. You can\n    just pass a backend variable that you increment at each training step.\n\n    The schedule is a 1-arg callable that produces a decayed learning\n    rate when passed the current optimizer step. This can be useful for changing\n    the learning rate value across different invocations of optimizer functions.\n\n    The learning rate multiplier first decays\n    from 1 to `alpha` for `first_decay_steps` steps. Then, a warm\n    restart is performed. 
Each new warm restart runs for `t_mul` times more\n    steps and with `m_mul` times initial learning rate as the new learning rate.\n\n    Example:\n    ```python\n    first_decay_steps = 1000\n    lr_decayed_fn = (\n        keras.optimizers.schedules.CosineDecayRestarts(\n            initial_learning_rate,\n            first_decay_steps))\n    ```\n\n    You can pass this schedule directly into a `keras.optimizers.Optimizer`\n    as the learning rate. The learning rate schedule is also serializable and\n    deserializable using `keras.optimizers.schedules.serialize` and\n    `keras.optimizers.schedules.deserialize`.\n\n    Args:\n        initial_learning_rate: A Python float. The initial learning rate.\n        first_decay_steps: A Python integer. Number of steps to decay over.\n        t_mul: A Python float. Used to derive the number of iterations in\n            the i-th period.\n        m_mul: A Python float. Used to derive the initial learning rate of\n            the i-th period.\n        alpha: A Python float. Minimum learning rate value as a fraction of\n            the `initial_learning_rate`.\n        name: String. Optional name of the operation. 
Defaults to\n            `\"SGDRDecay\"`.\n\n    Returns:\n        A 1-arg callable learning rate schedule that takes the current optimizer\n        step and outputs the decayed learning rate, a scalar tensor of the\n        same type as `initial_learning_rate`.\n    \"\"\"\n\n    def __init__(\n        self,\n        initial_learning_rate,\n        first_decay_steps,\n        t_mul=2.0,\n        m_mul=1.0,\n        alpha=0.0,\n        name=\"SGDRDecay\",\n    ):\n        super().__init__()\n\n        self.initial_learning_rate = initial_learning_rate\n        self.first_decay_steps = first_decay_steps\n        self._t_mul = t_mul\n        self._m_mul = m_mul\n        self.alpha = alpha\n        self.name = name\n\n        if self.first_decay_steps <= 0:\n            raise ValueError(\n                \"Argument `first_decay_steps` must be > 0. \"\n                f\"Received: first_decay_steps={self.first_decay_steps}\"\n            )\n\n    def __call__(self, step):\n        with ops.name_scope(self.name):\n            initial_learning_rate = ops.convert_to_tensor(\n                self.initial_learning_rate\n            )\n            dtype = initial_learning_rate.dtype\n            first_decay_steps = ops.cast(self.first_decay_steps, dtype)\n            alpha = ops.cast(self.alpha, dtype)\n            t_mul = ops.cast(self._t_mul, dtype)\n            m_mul = ops.cast(self._m_mul, dtype)\n\n            global_step_recomp = ops.cast(step, dtype)\n            completed_fraction = global_step_recomp / first_decay_steps\n\n            def compute_step(completed_fraction, geometric=False):\n                \"\"\"Helper for `cond` operation.\"\"\"\n                if geometric:\n                    # ops.log is sensitive to the precision of dtype, so we need\n                    # the additional casting\n                    i_restart = ops.floor(\n                        ops.log(\n                            ops.cast(\n                                1.0 - 
completed_fraction * (1.0 - t_mul), dtype\n                            )\n                        )\n                        / ops.log(t_mul)\n                    )\n\n                    sum_r = ops.divide(\n                        1.0 - ops.power(t_mul, i_restart), (1.0 - t_mul)\n                    )\n                    completed_fraction = ops.divide(\n                        ops.subtract(completed_fraction, sum_r),\n                        ops.power(t_mul, i_restart),\n                    )\n\n                else:\n                    i_restart = ops.floor(completed_fraction)\n                    completed_fraction -= i_restart\n\n                return i_restart, completed_fraction\n\n            i_restart, completed_fraction = ops.cond(\n                ops.equal(t_mul, 1.0),\n                lambda: compute_step(completed_fraction, geometric=False),\n                lambda: compute_step(completed_fraction, geometric=True),\n            )\n\n            m_fac = ops.power(m_mul, i_restart)\n            cosine_decayed = (\n                0.5\n                * m_fac\n                * (\n                    1.0\n                    + ops.cos(\n                        ops.multiply(\n                            ops.array(math.pi, dtype=dtype), completed_fraction\n                        )\n                    )\n                )\n            )\n            decayed = ops.add(ops.multiply((1 - alpha), cosine_decayed), alpha)\n\n            return ops.multiply(initial_learning_rate, decayed)\n\n    def get_config(self):\n        return {\n            \"initial_learning_rate\": self.initial_learning_rate,\n            \"first_decay_steps\": self.first_decay_steps,\n            \"t_mul\": self._t_mul,\n            \"m_mul\": self._m_mul,\n            \"alpha\": self.alpha,\n            \"name\": self.name,\n        }\n\n\n@keras_export(\"keras.optimizers.schedules.serialize\")\ndef serialize(learning_rate_schedule):\n    \"\"\"Serializes a `LearningRateSchedule` 
into a JSON-compatible dict.\n\n    Args:\n        learning_rate_schedule: The `LearningRateSchedule` object to serialize.\n\n    Returns:\n        A JSON-serializable dict representing the object's config.\n\n    Example:\n\n    >>> lr_schedule = keras.optimizers.schedules.ExponentialDecay(\n    ...     0.1, decay_steps=100000, decay_rate=0.96, staircase=True)\n    >>> keras.optimizers.schedules.serialize(lr_schedule)\n    {'module': 'keras.optimizers.schedules',\n    'class_name': 'ExponentialDecay', 'config': {...},\n    'registered_name': None}\n    \"\"\"\n    return serialization_lib.serialize_keras_object(learning_rate_schedule)\n\n\n@keras_export(\"keras.optimizers.schedules.deserialize\")\ndef deserialize(config, custom_objects=None):\n    \"\"\"Instantiates a `LearningRateSchedule` object from a serialized form.\n\n    Args:\n        config: The serialized form of the `LearningRateSchedule`. Dictionary of\n            the form {'class_name': str, 'config': dict}.\n        custom_objects: A dictionary mapping class names (or function names) of\n            custom (non-Keras) objects to class/functions.\n\n    Returns:\n        A `LearningRateSchedule` object.\n\n    Example:\n\n    ```python\n    # Configuration for PolynomialDecay\n    config = {\n        'class_name': 'PolynomialDecay',\n        'config': {'cycle': False,\n            'decay_steps': 10000,\n            'end_learning_rate': 0.01,\n            'initial_learning_rate': 0.1,\n            'name': None,\n            'power': 0.5\n        }\n    }\n    lr_schedule = keras.optimizers.schedules.deserialize(config)\n    ```\n    \"\"\"\n    return serialization_lib.deserialize_keras_object(\n        config,\n        module_objects=globals(),\n        custom_objects=custom_objects,\n        printable_module_name=\"decay\",\n    )\n"
  },
  {
    "path": "keras/src/optimizers/schedules/learning_rate_schedule_test.py",
    "content": "\"\"\"Tests for learning rate schedule API.\"\"\"\n\nimport math\n\nimport numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import optimizers\nfrom keras.src import testing\nfrom keras.src.models import Sequential\nfrom keras.src.optimizers import schedules\n\n\nclass TestFitLRSchedulesFlow(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_fit_lr_correctness(self):\n        model = Sequential(\n            [\n                layers.Dense(\n                    2, kernel_initializer=\"ones\", bias_initializer=\"ones\"\n                )\n            ]\n        )\n        optimizer = optimizers.Adam(\n            learning_rate=schedules.ExponentialDecay(\n                initial_learning_rate=0.05, decay_steps=1, decay_rate=0.9\n            )\n        )\n        self.assertEqual(len(optimizer.variables), 1)\n        self.assertEqual(optimizer.variables[0], 0)\n\n        model.compile(optimizer=optimizer, loss=\"mse\")\n        x = np.arange(32).reshape((16, 2))\n        y = np.arange(32).reshape((16, 2))\n        history = model.fit(x, y, epochs=3, batch_size=4, shuffle=False)\n        self.assertEqual(optimizer.variables[0], 4 * 3)\n        self.assertAllClose(\n            history.history[\"loss\"],\n            [230.79457092285156, 128.30319213867188, 79.33648681640625],\n            rtol=5e-5,\n            tpu_atol=5e-3,\n            tpu_rtol=5e-3,\n        )\n\n\nclass ExponentialDecayTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            schedules.ExponentialDecay(\n                initial_learning_rate=0.05,\n                decay_steps=10,\n                decay_rate=0.96,\n                staircase=True,\n                name=\"my_ed\",\n            )\n        )\n\n    def test_continuous(self):\n        step = 5\n        decayed_lr = schedules.ExponentialDecay(0.05, 10, 0.96)\n        expected = 0.05 * 
0.96 ** (5.0 / 10.0)\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_staircase(self):\n        step = backend.Variable(1.0)\n        decayed_lr = schedules.ExponentialDecay(0.1, 3, 0.96, staircase=True)\n\n        # No change to learning rate due to staircase\n        expected = 0.1\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n        expected = 0.1\n        step.assign(2)\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n        # Decayed learning rate\n        expected = 0.1 * 0.96 ** (100 // 3)\n        step.assign(100)\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_variables(self):\n        step = backend.Variable(1.0)\n        decayed_lr = schedules.ExponentialDecay(0.1, 3, 0.96, staircase=True)\n\n        # No change to learning rate\n        step.assign(1)\n        self.assertAllClose(decayed_lr(step), 0.1, 1e-6)\n        step.assign(2)\n        self.assertAllClose(decayed_lr(step), 0.1, 1e-6)\n        # Decayed learning rate\n        step.assign(100)\n        expected = 0.1 * 0.96 ** (100 // 3)\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n\nclass PiecewiseConstantDecayTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            schedules.PiecewiseConstantDecay(\n                boundaries=[10, 20], values=[1, 2, 3], name=\"my_pcd\"\n            )\n        )\n\n    def test_piecewise_values(self):\n        x = backend.Variable(-999.0)\n        decayed_lr = schedules.PiecewiseConstantDecay(\n            [100, 110, 120], [1.0, 0.1, 0.01, 0.001]\n        )\n\n        self.assertAllClose(decayed_lr(x), 1.0, 1e-6)\n        x.assign(100)\n        self.assertAllClose(decayed_lr(x), 1.0, 1e-6)\n        x.assign(105)\n        self.assertAllClose(decayed_lr(x), 0.1, 1e-6)\n        x.assign(110)\n        self.assertAllClose(decayed_lr(x), 0.1, 1e-6)\n        x.assign(120)\n        
self.assertAllClose(decayed_lr(x), 0.01, 1e-6)\n        x.assign(999)\n        self.assertAllClose(decayed_lr(x), 0.001, 1e-6)\n\n    def test_boundary_values(self):\n        # Test casting boundaries from int32 to int64.\n        x_int64 = backend.Variable(0, dtype=\"int64\", trainable=False)\n        boundaries, values = [1, 2, 3], [0.4, 0.5, 0.6, 0.7]\n        decayed_lr = schedules.PiecewiseConstantDecay(boundaries, values)\n\n        self.assertAllClose(decayed_lr(x_int64), 0.4, 1e-6)\n        x_int64.assign(1)\n        self.assertAllClose(decayed_lr(x_int64), 0.4, 1e-6)\n        x_int64.assign(2)\n        self.assertAllClose(decayed_lr(x_int64), 0.5, 1e-6)\n        x_int64.assign(3)\n        self.assertAllClose(decayed_lr(x_int64), 0.6, 1e-6)\n        x_int64.assign(4)\n        self.assertAllClose(decayed_lr(x_int64), 0.7, 1e-6)\n\n\nclass LinearDecayTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            schedules.PolynomialDecay(\n                initial_learning_rate=0.1,\n                decay_steps=100,\n                end_learning_rate=0.005,\n                power=1.0,\n                cycle=False,\n                name=\"my_ld\",\n            )\n        )\n\n    def test_halfway(self):\n        step = 5\n        lr = 0.05\n        end_lr = 0.0\n        decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr)\n        expected = lr * 0.5\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_end(self):\n        step = 10\n        lr = 0.05\n        end_lr = 0.001\n        decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr)\n        expected = end_lr\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_halfway_with_end(self):\n        step = 5\n        lr = 0.05\n        end_lr = 0.001\n        decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr)\n        expected = (lr + end_lr) * 0.5\n        self.assertAllClose(decayed_lr(step), expected, 
1e-6)\n\n    def test_beyond_end(self):\n        step = 15\n        lr = 0.05\n        end_lr = 0.001\n        decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr)\n        expected = end_lr\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_beyond_end_with_cycle(self):\n        step = 15\n        lr = 0.05\n        end_lr = 0.001\n        decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, cycle=True)\n        expected = (lr - end_lr) * 0.25 + end_lr\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n\nclass SqrtDecayTest(testing.TestCase):\n    def test_halfway(self):\n        step = 5\n        lr = 0.05\n        end_lr = 0.0\n        power = 0.5\n        decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, power=power)\n        expected = lr * 0.5**power\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_end(self):\n        step = 10\n        lr = 0.05\n        end_lr = 0.001\n        power = 0.5\n        decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, power=power)\n        expected = end_lr\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_halfway_with_end(self):\n        step = 5\n        lr = 0.05\n        end_lr = 0.001\n        power = 0.5\n        decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, power=power)\n        expected = (lr - end_lr) * 0.5**power + end_lr\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_beyond_end(self):\n        step = 15\n        lr = 0.05\n        end_lr = 0.001\n        power = 0.5\n        decayed_lr = schedules.PolynomialDecay(lr, 10, end_lr, power=power)\n        expected = end_lr\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_beyond_end_with_cycle(self):\n        step = 15\n        lr = 0.05\n        end_lr = 0.001\n        power = 0.5\n        decayed_lr = schedules.PolynomialDecay(\n            lr, 10, end_lr, power=power, cycle=True\n        
)\n        expected = (lr - end_lr) * 0.25**power + end_lr\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_begin_with_cycle(self):\n        lr = 0.001\n        decay_steps = 10\n        step = 0\n        decayed_lr = schedules.PolynomialDecay(lr, decay_steps, cycle=True)\n        expected = lr\n        self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n\nclass InverseTimeDecayTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            schedules.InverseTimeDecay(\n                initial_learning_rate=0.05,\n                decay_steps=10,\n                decay_rate=0.96,\n                staircase=True,\n                name=\"my_itd\",\n            )\n        )\n\n    def test_decay(self):\n        initial_lr = 0.1\n        k = 10\n        decay_rate = 0.96\n        step = backend.Variable(0.0)\n        decayed_lr = schedules.InverseTimeDecay(initial_lr, k, decay_rate)\n\n        for i in range(k + 1):\n            expected = initial_lr / (1 + i / k * decay_rate)\n            self.assertAllClose(decayed_lr(step), expected, 1e-6)\n            step.assign(step + 1)\n\n    def test_staircase(self):\n        initial_lr = 0.1\n        k = 10\n        decay_rate = 0.96\n        step = backend.Variable(0.0)\n        decayed_lr = schedules.InverseTimeDecay(\n            initial_lr, k, decay_rate, staircase=True\n        )\n\n        for i in range(k + 1):\n            expected = initial_lr / (1 + decay_rate * (i // k))\n            self.assertAllClose(decayed_lr(step), expected, 1e-6)\n            step.assign(step + 1)\n\n\nclass CosineDecayTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            schedules.CosineDecay(\n                initial_learning_rate=0.05,\n                decay_steps=10,\n                alpha=0.1,\n                warmup_target=0.2,\n                warmup_steps=2,\n                name=\"my_cd\",\n            
)\n        )\n\n    def np_cosine_decay(self, step, decay_steps, alpha=0.0):\n        step = min(step, decay_steps)\n        completed_fraction = step / decay_steps\n        decay = 0.5 * (1.0 + math.cos(math.pi * completed_fraction))\n        return (1.0 - alpha) * decay + alpha\n\n    def test_decay(self):\n        num_training_steps = 1000\n        initial_lr = 1.0\n        for step in range(0, 1500, 250):\n            decayed_lr = schedules.CosineDecay(initial_lr, num_training_steps)\n            expected = self.np_cosine_decay(step, num_training_steps)\n            self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def linear_warmup(self, step, warmup_steps, initial_lr, target_lr):\n        completed_fraction = step / warmup_steps\n        total_delta = target_lr - initial_lr\n        return completed_fraction * total_delta\n\n    def test_warmup(self):\n        warmup_steps = 1500\n        initial_lr = 0.0\n        target_lr = 10.0\n        for step in range(0, 1500, 250):\n            lr = schedules.CosineDecay(\n                initial_lr,\n                10,\n                warmup_target=target_lr,\n                warmup_steps=warmup_steps,\n            )\n            expected = self.linear_warmup(\n                step, warmup_steps, initial_lr, target_lr\n            )\n            self.assertAllClose(lr(step), expected)\n\n    def test_alpha(self):\n        num_training_steps = 1000\n        initial_lr = 1.0\n        alpha = 0.1\n        for step in range(0, 1500, 250):\n            decayed_lr = schedules.CosineDecay(\n                initial_lr, num_training_steps, alpha\n            )\n            expected = self.np_cosine_decay(step, num_training_steps, alpha)\n            self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_float64(self):\n        num_training_steps = 1000\n        initial_lr = np.float64(1.0)\n        for step in range(0, 1500, 250):\n            decayed_lr = schedules.CosineDecay(initial_lr, 
num_training_steps)\n            expected = self.np_cosine_decay(step, num_training_steps)\n            self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_warmup_decay(self):\n        warmup_steps = 2000\n        decay_steps = 1000\n        initial_lr = 0.0\n        target_lr = 10.0\n        for step in range(0, 3000, 250):\n            lr = schedules.CosineDecay(\n                initial_lr,\n                decay_steps,\n                warmup_target=target_lr,\n                warmup_steps=warmup_steps,\n            )\n            if step < warmup_steps + 1:\n                expected = self.linear_warmup(\n                    step, warmup_steps, initial_lr, target_lr\n                )\n            else:\n                expected = target_lr * self.np_cosine_decay(\n                    step - warmup_steps, decay_steps\n                )\n            self.assertAllClose(lr(step), expected)\n\n\nclass CosineDecayRestartsTest(testing.TestCase):\n    def test_config(self):\n        self.run_class_serialization_test(\n            schedules.CosineDecayRestarts(\n                initial_learning_rate=0.05,\n                first_decay_steps=10,\n                alpha=0.1,\n                t_mul=3.0,\n                m_mul=4.0,\n                name=\"my_cdr\",\n            )\n        )\n\n    def np_cosine_decay_restarts(\n        self, step, decay_steps, t_mul=2.0, m_mul=1.0, alpha=0.0\n    ):\n        fac = 1.0\n        while step >= decay_steps:\n            step -= decay_steps\n            decay_steps *= t_mul\n            fac *= m_mul\n\n        completed_fraction = step / decay_steps\n        decay = fac * 0.5 * (1.0 + math.cos(math.pi * completed_fraction))\n        return (1.0 - alpha) * decay + alpha\n\n    def test_decay(self):\n        num_training_steps = 1000\n        initial_lr = 1.0\n        for step in range(0, 1500, 250):\n            decayed_lr = schedules.CosineDecayRestarts(\n                initial_lr, num_training_steps\n        
    )\n            expected = self.np_cosine_decay_restarts(step, num_training_steps)\n            self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_float64(self):\n        num_training_steps = 1000\n        initial_lr = np.float64(1.0)\n        for step in range(0, 1500, 250):\n            decayed_lr = schedules.CosineDecayRestarts(\n                initial_lr, num_training_steps\n            )\n            expected = self.np_cosine_decay_restarts(step, num_training_steps)\n            self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_alpha(self):\n        num_training_steps = 1000\n        initial_lr = 1.0\n        alpha = 0.1\n        for step in range(0, 1500, 250):\n            decayed_lr = schedules.CosineDecayRestarts(\n                initial_lr, num_training_steps, alpha=alpha\n            )\n            expected = self.np_cosine_decay_restarts(\n                step, num_training_steps, alpha=alpha\n            )\n            self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_mmul(self):\n        num_training_steps = 1000\n        initial_lr = 1.0\n        m_mul = 0.9\n        for step in range(0, 1500, 250):\n            decayed_lr = schedules.CosineDecayRestarts(\n                initial_lr, num_training_steps, m_mul=m_mul\n            )\n            expected = self.np_cosine_decay_restarts(\n                step, num_training_steps, m_mul=m_mul\n            )\n            self.assertAllClose(decayed_lr(step), expected, 1e-6)\n\n    def test_tmul(self):\n        num_training_steps = 1000\n        initial_lr = 1.0\n        t_mul = 1.0\n        for step in range(0, 1500, 250):\n            decayed_lr = schedules.CosineDecayRestarts(\n                initial_lr, num_training_steps, t_mul=t_mul\n            )\n            expected = self.np_cosine_decay_restarts(\n                step, num_training_steps, t_mul=t_mul\n            )\n            self.assertAllClose(decayed_lr(step), expected, 1e-6)\n"
  },
  {
    "path": "keras/src/optimizers/sgd.py",
    "content": "from keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.optimizers import optimizer\n\n\n@keras_export(\"keras.optimizers.SGD\")\nclass SGD(optimizer.Optimizer):\n    \"\"\"Gradient descent (with momentum) optimizer.\n\n    Update rule for parameter `w` with gradient `g` when `momentum` is 0:\n\n    ```python\n    w = w - learning_rate * g\n    ```\n\n    Update rule when `momentum` is larger than 0:\n\n    ```python\n    velocity = momentum * velocity - learning_rate * g\n    w = w + velocity\n    ```\n\n    When `nesterov=True`, this rule becomes:\n\n    ```python\n    velocity = momentum * velocity - learning_rate * g\n    w = w + momentum * velocity - learning_rate * g\n    ```\n\n    Args:\n        learning_rate: A float, a\n            `keras.optimizers.schedules.LearningRateSchedule` instance, or\n            a callable that takes no arguments and returns the actual value to\n            use. The learning rate. Defaults to `0.01`.\n        momentum: float hyperparameter >= 0 that accelerates gradient descent in\n            the relevant direction and dampens oscillations. 0 is vanilla\n            gradient descent. Defaults to `0.0`.\n        nesterov: boolean. 
Whether to apply Nesterov momentum.\n            Defaults to `False`.\n        {{base_optimizer_keyword_args}}\n    \"\"\"\n\n    def __init__(\n        self,\n        learning_rate=0.01,\n        momentum=0.0,\n        nesterov=False,\n        weight_decay=None,\n        clipnorm=None,\n        clipvalue=None,\n        global_clipnorm=None,\n        use_ema=False,\n        ema_momentum=0.99,\n        ema_overwrite_frequency=None,\n        loss_scale_factor=None,\n        gradient_accumulation_steps=None,\n        name=\"SGD\",\n        **kwargs,\n    ):\n        super().__init__(\n            learning_rate=learning_rate,\n            name=name,\n            weight_decay=weight_decay,\n            clipnorm=clipnorm,\n            clipvalue=clipvalue,\n            global_clipnorm=global_clipnorm,\n            use_ema=use_ema,\n            ema_momentum=ema_momentum,\n            ema_overwrite_frequency=ema_overwrite_frequency,\n            loss_scale_factor=loss_scale_factor,\n            gradient_accumulation_steps=gradient_accumulation_steps,\n            **kwargs,\n        )\n        if not isinstance(momentum, float) or momentum < 0 or momentum > 1:\n            raise ValueError(\"`momentum` must be a float between [0, 1].\")\n        self.momentum = momentum\n        self.nesterov = nesterov\n\n    def build(self, variables):\n        \"\"\"Initialize optimizer variables.\n\n        SGD optimizer has one variable `momentums`, only set if `self.momentum`\n        is not 0.\n\n        Args:\n          var_list: list of model variables to build SGD variables on.\n        \"\"\"\n        if self.built:\n            return\n        super().build(variables)\n        self.momentums = []\n        if self.momentum != 0:\n            self.momentums = self.add_optimizer_variables(variables, \"momentum\")\n\n    def update_step(self, gradient, variable, learning_rate):\n        \"\"\"Update step given gradient and the associated model variable.\"\"\"\n        learning_rate = 
ops.cast(learning_rate, variable.dtype)\n        gradient = ops.cast(gradient, variable.dtype)\n        m = None\n        if self.momentum != 0:\n            m = self.momentums[self._get_variable_index(variable)]\n\n        if m is not None:\n            momentum = ops.cast(self.momentum, variable.dtype)\n            self.assign(\n                m,\n                ops.subtract(\n                    ops.multiply(m, momentum),\n                    ops.multiply(gradient, learning_rate),\n                ),\n            )\n            if self.nesterov:\n                self.assign_add(\n                    variable,\n                    ops.subtract(\n                        ops.multiply(m, momentum),\n                        ops.multiply(gradient, learning_rate),\n                    ),\n                )\n            else:\n                self.assign_add(variable, m)\n        else:\n            self.assign_sub(variable, ops.multiply(gradient, learning_rate))\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"momentum\": self.momentum,\n                \"nesterov\": self.nesterov,\n            }\n        )\n        return config\n\n\nSGD.__doc__ = SGD.__doc__.replace(\n    \"{{base_optimizer_keyword_args}}\", optimizer.base_optimizer_keyword_args\n)\n"
  },
  {
    "path": "keras/src/optimizers/sgd_test.py",
    "content": "# flake8: noqa\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.optimizers.sgd import SGD\n\n\nclass SGDTest(testing.TestCase):\n    def test_config(self):\n        optimizer = SGD(\n            learning_rate=0.5,\n            momentum=0.06,\n            nesterov=True,\n            weight_decay=0.004,\n        )\n        self.run_class_serialization_test(optimizer)\n\n    def test_single_step(self):\n        optimizer = SGD(learning_rate=0.5)\n        self.assertEqual(len(optimizer.variables), 2)\n        grads = ops.array([1.0, 6.0, 7.0, 2.0])\n        vars = backend.Variable([1.0, 2.0, 3.0, 4.0])\n        optimizer.build([vars])\n        optimizer.apply_gradients(zip([grads], [vars]))\n        self.assertAllClose(vars, [0.5, -1.0, -0.5, 3.0], rtol=1e-4, atol=1e-4)\n        self.assertEqual(len(optimizer.variables), 2)\n        self.assertEqual(optimizer.variables[0], 1)\n        self.assertEqual(optimizer.variables[1], 0.5)\n\n    def test_invalid_momentum(self):\n        with self.assertRaisesRegex(\n            ValueError, \"`momentum` must be a float between \\\\[0, 1\\\\].\"\n        ):\n            SGD(momentum=-1.0)\n\n        with self.assertRaisesRegex(\n            ValueError, \"`momentum` must be a float between \\\\[0, 1\\\\].\"\n        ):\n            SGD(momentum=2.0)\n\n    def test_weight_decay(self):\n        grads, var1, var2, var3 = (\n            ops.zeros(()),\n            backend.Variable(2.0),\n            backend.Variable(2.0, name=\"exclude\"),\n            backend.Variable(2.0),\n        )\n        optimizer_1 = SGD(learning_rate=1.0, weight_decay=0.004)\n        optimizer_1.apply_gradients(zip([grads], [var1]))\n\n        optimizer_2 = SGD(learning_rate=1.0, weight_decay=0.004)\n        optimizer_2.exclude_from_weight_decay(var_names=[\"exclude\"])\n        optimizer_2.apply_gradients(zip([grads, grads], [var1, var2]))\n\n        optimizer_3 
= SGD(learning_rate=1.0, weight_decay=0.004)\n        optimizer_3.exclude_from_weight_decay(var_list=[var3])\n        optimizer_3.apply_gradients(zip([grads, grads], [var1, var3]))\n\n        self.assertAlmostEqual(var1.numpy(), 1.9760959, decimal=6)\n        self.assertAlmostEqual(var2.numpy(), 2.0, decimal=6)\n        self.assertAlmostEqual(var3.numpy(), 2.0, decimal=6)\n\n    def test_correctness_with_golden(self):\n        optimizer = SGD(nesterov=True)\n\n        x = backend.Variable(np.ones([10]))\n        grads = ops.arange(0.1, 1.1, 0.1)\n        first_grads = ops.full((10,), 0.01)\n\n        # fmt: off\n        golden = np.array(\n            [[0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999, 0.9999,\n            0.9999, 0.9999], [0.9989, 0.9979, 0.9969, 0.9959, 0.9949, 0.9939,\n            0.9929, 0.9919, 0.9909, 0.9899], [0.9979, 0.9959, 0.9939, 0.9919,\n            0.9899, 0.9879, 0.9859, 0.9839, 0.9819, 0.9799], [0.9969, 0.9939,\n            0.9909, 0.9879, 0.9849, 0.9819, 0.9789, 0.9759, 0.9729, 0.9699],\n            [0.9959, 0.9919, 0.9879, 0.9839, 0.9799, 0.9759, 0.9719, 0.9679,\n            0.9639, 0.9599]]\n        )\n        # fmt: on\n\n        optimizer.apply_gradients(zip([first_grads], [x]))\n        for i in range(5):\n            self.assertAllClose(x, golden[i], rtol=5e-4, atol=5e-4)\n            optimizer.apply_gradients(zip([grads], [x]))\n\n    def test_clip_norm(self):\n        optimizer = SGD(clipnorm=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [2**0.5 / 2, 2**0.5 / 2])\n\n    def test_clip_value(self):\n        optimizer = SGD(clipvalue=1)\n        grad = [np.array([100.0, 100.0])]\n        clipped_grad = optimizer._clip_gradients(grad)\n        self.assertAllClose(clipped_grad[0], [1.0, 1.0])\n"
  },
  {
    "path": "keras/src/quantizers/__init__.py",
    "content": "import inspect\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.quantizers.awq_config import AWQConfig\nfrom keras.src.quantizers.quantization_config import Float8QuantizationConfig\nfrom keras.src.quantizers.quantization_config import Int4QuantizationConfig\nfrom keras.src.quantizers.quantization_config import Int8QuantizationConfig\nfrom keras.src.quantizers.quantization_config import QuantizationConfig\nfrom keras.src.quantizers.quantizers import AbsMaxQuantizer\nfrom keras.src.quantizers.quantizers import Quantizer\nfrom keras.src.quantizers.quantizers import abs_max_quantize\nfrom keras.src.quantizers.quantizers import (\n    abs_max_quantize_grouped_with_zero_point,\n)\nfrom keras.src.quantizers.quantizers import compute_float8_amax_history\nfrom keras.src.quantizers.quantizers import compute_float8_scale\nfrom keras.src.quantizers.quantizers import compute_quantization_parameters\nfrom keras.src.quantizers.quantizers import dequantize_with_sz_map\nfrom keras.src.quantizers.quantizers import fake_quant_with_min_max_vars\nfrom keras.src.quantizers.quantizers import pack_int4\nfrom keras.src.quantizers.quantizers import quantize_and_dequantize\nfrom keras.src.quantizers.quantizers import quantize_with_sz_map\nfrom keras.src.quantizers.quantizers import unpack_int4\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils.naming import to_snake_case\n\nALL_OBJECTS = {\n    Quantizer,\n    AbsMaxQuantizer,\n    QuantizationConfig,\n    Int8QuantizationConfig,\n    Int4QuantizationConfig,\n    Float8QuantizationConfig,\n    AWQConfig,\n}\nALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}\nALL_OBJECTS_DICT.update(\n    {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS}\n)\n\n\n@keras_export(\"keras.quantizers.serialize\")\ndef serialize(initializer):\n    return serialization_lib.serialize_keras_object(initializer)\n\n\n@keras_export(\"keras.quantizers.deserialize\")\ndef deserialize(config, 
custom_objects=None):\n    \"\"\"Return a Keras quantizer object via its config.\"\"\"\n    return serialization_lib.deserialize_keras_object(\n        config,\n        module_objects=ALL_OBJECTS_DICT,\n        custom_objects=custom_objects,\n    )\n\n\n@keras_export(\"keras.quantizers.get\")\ndef get(identifier, **kwargs):\n    \"\"\"Retrieve a Keras quantizer object via an identifier.\"\"\"\n    if identifier is None:\n        return None\n    if isinstance(identifier, dict):\n        obj = deserialize(identifier)\n    elif isinstance(identifier, str):\n        obj = ALL_OBJECTS_DICT.get(identifier, None)\n    else:\n        obj = identifier\n\n    if callable(obj):\n        if inspect.isclass(obj):\n            obj = obj(**kwargs)\n        return obj\n    else:\n        raise ValueError(\n            f\"Could not interpret quantizer identifier: {identifier}\"\n        )\n"
  },
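The module above registers every quantizer class under both its CamelCase name and a snake_case alias, which is what lets `keras.quantizers.get` accept either spelling. A minimal self-contained sketch of that resolution logic, using a simplified `to_snake_case` and a stand-in quantizer class rather than the real Keras implementations:

```python
import re


def to_snake_case(name):
    # Insert an underscore before each interior capital, then lowercase:
    # "AbsMaxQuantizer" -> "abs_max_quantizer" (simplified vs. Keras's helper).
    return re.sub(r"(?<!^)(?=[A-Z])", "_", name).lower()


class AbsMaxQuantizer:
    """Stand-in for the real class, for illustration only."""

    def __init__(self, value_range=(-127, 127)):
        self.value_range = value_range


# Register each class under both spellings, mirroring ALL_OBJECTS_DICT.
REGISTRY = {cls.__name__: cls for cls in (AbsMaxQuantizer,)}
REGISTRY.update({to_snake_case(cls.__name__): cls for cls in (AbsMaxQuantizer,)})


def get(identifier, **kwargs):
    # Resolve a string to a registered class and instantiate it;
    # pass through objects that are already quantizer instances.
    if identifier is None:
        return None
    obj = REGISTRY.get(identifier) if isinstance(identifier, str) else identifier
    if callable(obj):
        if isinstance(obj, type):
            obj = obj(**kwargs)
        return obj
    raise ValueError(f"Could not interpret quantizer identifier: {identifier}")
```

Either `get("AbsMaxQuantizer")` or `get("abs_max_quantizer")` resolves to a fresh `AbsMaxQuantizer()` here, matching the dual-key registry built in the real module.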
  {
    "path": "keras/src/quantizers/awq.py",
    "content": "\"\"\"AWQ (Activation-aware Weight Quantization) algorithm implementation.\n\nAWQ protects salient weights by finding optimal per-channel scales based on\nactivation magnitudes, then applies those scales before quantization.\n\nReference: https://arxiv.org/abs/2306.00978\n\"\"\"\n\nimport types\n\nfrom keras.src import ops\nfrom keras.src.layers import Dense\nfrom keras.src.layers import EinsumDense\nfrom keras.src.quantizers.quantizers import compute_quantization_parameters\nfrom keras.src.quantizers.quantizers import dequantize_with_sz_map\nfrom keras.src.quantizers.quantizers import dequantize_with_zero_point\nfrom keras.src.quantizers.quantizers import quantize_with_sz_map\nfrom keras.src.quantizers.quantizers import quantize_with_zero_point\n\n\ndef awq_search_optimal_scales(\n    weights,\n    activation_magnitudes,\n    *,\n    num_grid_points=20,\n    group_size=-1,\n):\n    \"\"\"Search for optimal AWQ scales using grid search.\n\n    The AWQ algorithm finds scaling factors that protect salient weights.\n    For each channel, we search for an optimal ratio in [0, 1] that minimizes\n    the activation-weighted quantization error.\n\n    The key insight: we MULTIPLY weights by scales before quantization to\n    expand salient weights. This ensures quantization noise is small relative\n    to the expanded weight magnitude. During inference, we divide by scales\n    to restore the original magnitude.\n\n    Scale formula: scales = x_max.pow(ratio).clamp(min=1e-4)\n    Loss function: Activation-weighted MSE (approximates output error)\n\n    Args:\n        weights: Weight tensor [out_features, in_features] (transposed kernel).\n        activation_magnitudes: Per-channel activation magnitudes [in_features].\n        num_grid_points: Number of grid search points. 
Defaults to 20.\n        group_size: Group size for quantization (-1 for per-channel).\n\n    Returns:\n        best_scales: Optimal per-channel scales [in_features].\n    \"\"\"\n    in_features = ops.shape(weights)[1]\n\n    # Compute per-channel activation magnitudes (x_max)\n    # activations should already be per-channel max magnitudes\n    x_max = ops.cast(activation_magnitudes, \"float32\")\n    # Avoid zero or very small values\n    x_max = ops.where(ops.less(x_max, 1e-8), ops.ones_like(x_max), x_max)\n\n    best_loss = None\n    best_scales = ops.ones((in_features,), dtype=\"float32\")\n\n    # Grid search over ratio values from 0 to 1\n    for i in range(num_grid_points + 1):\n        ratio = i / num_grid_points\n\n        # Compute scales: x_max^ratio (clipped to avoid numerical issues)\n        if ratio == 0:\n            scales = ops.ones_like(x_max)\n        else:\n            scales = ops.power(x_max, ratio)\n        scales = ops.maximum(scales, 1e-4)\n\n        # Normalize scales to avoid extreme values\n        scale_mean = ops.sqrt(ops.multiply(ops.max(scales), ops.min(scales)))\n        scale_mean = ops.maximum(scale_mean, 1e-8)\n        scales = ops.divide(scales, scale_mean)\n\n        # Apply scales to weights by MULTIPLYING (expand salient weights)\n        # weights_scaled: [out_features, in_features]\n        weights_scaled = ops.multiply(weights, scales)\n\n        if group_size == -1:\n            # Per-channel quantization (no grouping)\n            scale_q, zero_q, maxq = compute_quantization_parameters(\n                weights_scaled,\n                bits=4,\n                symmetric=False,\n                per_channel=True,\n                group_size=-1,\n                compute_dtype=\"float32\",\n            )\n\n            # Quantize and dequantize\n            quantized = quantize_with_zero_point(\n                weights_scaled, scale_q, zero_q, maxq\n            )\n            dequantized = 
dequantize_with_zero_point(quantized, scale_q, zero_q)\n        else:\n            # Grouped quantization - use proper per-row grouping\n            scale_q, zero_q, maxq = compute_quantization_parameters(\n                weights_scaled,\n                bits=4,\n                symmetric=False,\n                per_channel=True,\n                group_size=group_size,\n                compute_dtype=\"float32\",\n            )\n\n            # Compute group indices: maps each input feature to its group\n            g_idx = ops.cast(ops.arange(0, in_features) // group_size, \"int32\")\n\n            # Quantize and dequantize using group index mapping\n            quantized = quantize_with_sz_map(\n                weights_scaled, scale_q, zero_q, g_idx, maxq\n            )\n            dequantized = dequantize_with_sz_map(\n                quantized, scale_q, zero_q, g_idx\n            )\n\n        # Scale back down by DIVIDING to restore original magnitude\n        reconstructed = ops.divide(dequantized, scales)\n\n        # Compute activation-weighted MSE loss\n        # This approximates the output error: ||W*X - W_hat*X||^2\n        # by weighting each channel's error by x_max^2\n        weight_error = ops.square(ops.subtract(weights, reconstructed))\n        # Weight by activation magnitudes squared (broadcast over out_features)\n        weighted_error = ops.multiply(weight_error, ops.square(x_max))\n        loss = ops.mean(weighted_error)\n\n        # Track best\n        if best_loss is None:\n            best_loss = loss\n            best_scales = scales\n        else:\n            is_better = ops.less(loss, best_loss)\n            if is_better:\n                best_loss = loss\n                best_scales = scales\n\n    return best_scales\n\n\ndef awq_quantize_matrix(\n    weights_transpose,\n    activation_magnitudes,\n    *,\n    num_grid_points=20,\n    group_size=-1,\n):\n    \"\"\"Quantize a weight matrix using AWQ.\n\n    This function performs the 
complete AWQ quantization process:\n    1. Find optimal per-channel scales via grid search\n    2. Apply scales to weights\n    3. Compute quantization parameters\n    4. Quantize weights\n\n    Args:\n        weights_transpose: Weight matrix [out_features, in_features].\n        activation_magnitudes: Per-channel activation magnitudes [in_features].\n        num_grid_points: Number of grid search points.\n        group_size: Group size for quantization.\n\n    Returns:\n        quantized_weights: Quantized weights [out_features, in_features].\n        scales: Quantization scales [out_features, num_groups].\n        zeros: Zero points [out_features, num_groups].\n        awq_scales: AWQ per-channel scales [in_features].\n        g_idx: Group indices [in_features].\n    \"\"\"\n    in_features = ops.shape(weights_transpose)[1]\n\n    # Step 1: Find optimal AWQ scales via grid search\n    awq_scales = awq_search_optimal_scales(\n        weights_transpose,\n        activation_magnitudes,\n        num_grid_points=num_grid_points,\n        group_size=group_size,\n    )\n\n    # Step 2: Apply AWQ scales by MULTIPLYING (expand salient weights)\n    # weights_scaled: [out_features, in_features]\n    weights_scaled = ops.multiply(weights_transpose, awq_scales)\n\n    if group_size == -1:\n        # Per-channel quantization (no grouping)\n        scale_q, zero_q, maxq = compute_quantization_parameters(\n            weights_scaled,\n            bits=4,\n            symmetric=False,\n            per_channel=True,\n            group_size=-1,\n            compute_dtype=\"float32\",\n        )\n\n        # Quantize\n        quantized = quantize_with_zero_point(\n            weights_scaled, scale_q, zero_q, maxq\n        )\n\n        # Build group indices (all 0s for per-channel)\n        g_idx = ops.zeros((in_features,), dtype=\"float32\")\n    else:\n        # Grouped quantization - use proper per-row grouping\n        scale_q, zero_q, maxq = compute_quantization_parameters(\n   
         weights_scaled,\n            bits=4,\n            symmetric=False,\n            per_channel=True,\n            group_size=group_size,\n            compute_dtype=\"float32\",\n        )\n\n        # Compute group indices: maps each input feature to its group\n        g_idx = ops.cast(ops.arange(0, in_features) // group_size, \"int32\")\n\n        # Quantize using group index mapping\n        quantized = quantize_with_sz_map(\n            weights_scaled, scale_q, zero_q, g_idx, maxq\n        )\n\n        # Convert g_idx to float for storage\n        g_idx = ops.cast(g_idx, \"float32\")\n\n    return quantized, scale_q, zero_q, awq_scales, g_idx\n\n\nclass AWQ:\n    \"\"\"AWQ quantizer for a single layer.\n\n    This class accumulates activation statistics during calibration and\n    performs AWQ quantization on layer weights.\n\n    The AWQ algorithm works by:\n    1. Collecting per-channel maximum activation magnitudes\n    2. Using activation magnitudes to determine weight saliency\n    3. Finding optimal per-channel scales via grid search\n    4. 
Applying scales before quantization to protect salient weights\n\n    Args:\n        layer: The layer to quantize (Dense or EinsumDense).\n        config: AWQConfig instance with quantization parameters.\n    \"\"\"\n\n    def __init__(self, layer, config=None):\n        from keras.src.quantizers.awq_config import AWQConfig\n\n        self.original_layer = layer\n        self.config = config or AWQConfig(dataset=None, tokenizer=None)\n        self.num_samples = 0\n\n        # Handle Dense and EinsumDense layers\n        if isinstance(layer, Dense) or (\n            isinstance(layer, EinsumDense) and layer.kernel.ndim == 2\n        ):\n            self.kernel_shape = layer.kernel.shape\n            self.rows = self.kernel_shape[0]  # in_features\n            self.columns = self.kernel_shape[1]  # out_features\n            self.layer = layer\n        elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:\n            # Handle 3D EinsumDense layers (typically from attention blocks)\n            self.kernel_shape = layer.kernel.shape\n            shape = list(self.kernel_shape)\n            d_model_dim_index = shape.index(max(shape))\n\n            if d_model_dim_index == 0:  # QKV projection case\n                in_features, heads, head_dim = shape\n                self.rows = in_features\n                self.columns = heads * head_dim\n            elif d_model_dim_index in [1, 2]:  # Attention Output case\n                heads, head_dim, out_features = shape\n                self.rows = heads * head_dim\n                self.columns = out_features\n            else:\n                raise ValueError(\n                    f\"Cannot determine dimensions for EinsumDense kernel \"\n                    f\"shape {shape}\"\n                )\n\n            # Create a temporary object that holds a reshaped 2D version\n            self.layer = types.SimpleNamespace(\n                kernel=ops.reshape(layer.kernel, (self.rows, self.columns)),\n            )\n      
  else:\n            raise TypeError(f\"Unsupported layer type for AWQ: {type(layer)}\")\n\n        # Initialize activation magnitude accumulator (per-channel max)\n        self.activation_magnitudes = ops.zeros((self.rows,), dtype=\"float32\")\n\n    def update_activation_magnitudes(self, input_batch):\n        \"\"\"Update per-channel activation magnitude statistics.\n\n        This method tracks the maximum absolute activation value for each\n        input channel across all calibration batches.\n\n        Args:\n            input_batch: Input activations tensor [batch, ..., in_features].\n        \"\"\"\n        if input_batch is None:\n            raise ValueError(\"Input tensor cannot be None.\")\n        if ops.size(input_batch) == 0:\n            raise ValueError(\"Input tensor cannot be empty.\")\n\n        # Flatten to [batch_samples, in_features]\n        if len(input_batch.shape) > 2:\n            input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1]))\n\n        x = ops.cast(input_batch, \"float32\")\n\n        # Compute per-channel max absolute value for this batch\n        batch_max = ops.max(ops.abs(x), axis=0)\n\n        # Update running max\n        self.activation_magnitudes = ops.maximum(\n            self.activation_magnitudes, batch_max\n        )\n        self.num_samples = self.num_samples + int(ops.shape(x)[0])\n\n    def quantize_layer(self):\n        \"\"\"Perform AWQ quantization on the layer.\n\n        This method:\n        1. Runs the AWQ grid search to find optimal scales\n        2. Quantizes the layer weights\n        3. 
Updates the layer's quantized variables\n        \"\"\"\n        from keras.src import quantizers\n\n        weights_matrix = ops.transpose(self.layer.kernel)\n\n        # Perform AWQ quantization\n        quantized, scale, zero, awq_scales, g_idx = awq_quantize_matrix(\n            weights_matrix,\n            self.activation_magnitudes,\n            num_grid_points=self.config.num_grid_points,\n            group_size=self.config.group_size,\n        )\n\n        # Cast to uint8 for storage\n        # quantized is already [out_features, in_features]\n        quantized = ops.cast(quantized, \"uint8\")\n\n        # Pack to 4-bit along axis 0 (output features)\n        quantized_packed, _, _ = quantizers.pack_int4(\n            quantized, axis=0, dtype=\"uint8\"\n        )\n\n        # Assign to layer variables\n        del self.original_layer._kernel\n        self.original_layer.quantized_kernel.assign(quantized_packed)\n        self.original_layer.kernel_scale.assign(scale)\n        self.original_layer.kernel_zero.assign(zero)\n        self.original_layer.awq_scales.assign(awq_scales)\n        self.original_layer.g_idx.assign(g_idx)\n        self.original_layer.is_awq_calibrated = True\n\n    def free(self):\n        \"\"\"Free memory used by the quantizer.\"\"\"\n        del self.activation_magnitudes\n        del self.layer\n"
  },
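The grid search in `awq_search_optimal_scales` can be condensed into a toy NumPy re-implementation: scale up by `x_max**ratio`, round-trip through a quantizer, scale back down, and keep the ratio with the lowest activation-weighted error. The 4-bit quantizer below is a deliberately simplified per-row asymmetric round trip, not the library's `compute_quantization_parameters`:

```python
import numpy as np


def fake_quant_4bit(w):
    # Asymmetric per-output-row 4-bit quantize/dequantize round trip.
    w_min = w.min(axis=1, keepdims=True)
    w_max = w.max(axis=1, keepdims=True)
    scale = np.maximum((w_max - w_min) / 15.0, 1e-8)  # maxq = 2**4 - 1
    zero = np.round(-w_min / scale)
    q = np.clip(np.round(w / scale) + zero, 0, 15)
    return (q - zero) * scale


def awq_grid_search(weights, x_max, num_grid_points=20):
    # weights: [out_features, in_features]; x_max: [in_features]
    best_loss, best_scales = None, np.ones_like(x_max)
    for i in range(num_grid_points + 1):
        ratio = i / num_grid_points
        scales = np.maximum(x_max**ratio, 1e-4)
        # Normalize by the geometric mean of the extremes, as above.
        scales /= np.sqrt(scales.max() * scales.min())
        # Expand salient weights, round-trip quantize, shrink back.
        reconstructed = fake_quant_4bit(weights * scales) / scales
        # Activation-weighted MSE approximates the output error.
        loss = np.mean(((weights - reconstructed) ** 2) * x_max**2)
        if best_loss is None or loss < best_loss:
            best_loss, best_scales = loss, scales
    return best_scales


rng = np.random.default_rng(0)
W = rng.standard_normal((8, 16)).astype("float32")
x_max = np.abs(rng.standard_normal(16).astype("float32")) + 0.1
scales = awq_grid_search(W, x_max)
```

The returned `scales` vector has one positive entry per input channel; channels with large activation magnitudes get larger scales, so their weights are expanded before quantization and see proportionally less rounding noise.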
  {
    "path": "keras/src/quantizers/awq_config.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.quantizers.quantization_config import QuantizationConfig\n\n\n@keras_export(\"keras.quantizers.AWQConfig\")\nclass AWQConfig(QuantizationConfig):\n    \"\"\"Configuration class for AWQ (Activation-aware Weight Quantization).\n\n    AWQ is a post-training quantization method that identifies and protects\n    salient weights based on activation magnitudes. It applies per-channel\n    scaling before quantization to minimize accuracy loss.\n\n    Methodology:\n    1. Collects activation statistics from calibration data\n    2. Identifies salient weight channels based on activation magnitudes\n    3. Searches for optimal per-channel scaling factors via grid search\n    4. Applies scaling before quantization to protect important weights\n\n    References:\n    - Original AWQ paper: \"AWQ: Activation-aware Weight Quantization for\n      LLM Compression and Acceleration\" (https://arxiv.org/abs/2306.00978)\n    - Reference implementation: https://github.com/mit-han-lab/llm-awq\n\n    Args:\n        dataset: The calibration dataset. It can be an iterable that yields\n            strings or pre-tokenized numerical tensors (e.g., a list of\n            strings, a generator, or a NumPy array). This data is used to\n            analyze activation patterns.\n        tokenizer: A tokenizer instance (or a similar callable) that is used\n            to process the `dataset`.\n        weight_bits: The number of bits for weight quantization. AWQ presently\n            only supports 4-bit quantization. Defaults to 4.\n        num_samples: The number of calibration data samples to use from the\n            dataset. Defaults to 128.\n        sequence_length: The sequence length to use for each calibration\n            sample. Defaults to 512.\n        group_size: The size of weight groups to quantize together. 
A\n            `group_size` of -1 indicates per-channel quantization.\n            Defaults to 128.\n        num_grid_points: The number of grid search points for finding optimal\n            per-channel scales. Higher values may find better scales but\n            take longer. Defaults to 20.\n        quantization_layer_structure: A dictionary defining the model's\n            quantization structure. It should contain:\n            - \"pre_block_layers\": list of layers to run before the first\n              block (e.g., embedding layer).\n            - \"sequential_blocks\": list of transformer blocks to quantize\n              sequentially.\n            If not provided, the model must implement\n            `get_quantization_layer_structure`.\n\n    Example:\n    ```python\n    from keras.quantizers import AWQConfig\n\n    # Create configuration for 4-bit AWQ quantization\n    config = AWQConfig(\n        dataset=calibration_data,          # Your calibration dataset\n        tokenizer=your_tokenizer,          # Tokenizer for text data\n        num_samples=128,                   # Number of calibration samples\n        sequence_length=512,               # Sequence length for each sample\n        group_size=128,                    # Weight grouping for quantization\n        num_grid_points=20,                # Grid search points for scale search\n    )\n\n    # Apply quantization to your model\n    model.quantize(\"awq\", config=config)\n    ```\n\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset,\n        tokenizer,\n        *,\n        weight_bits: int = 4,\n        num_samples: int = 128,\n        sequence_length: int = 512,\n        group_size: int = 128,\n        num_grid_points: int = 20,\n        quantization_layer_structure: dict = None,\n    ):\n        super().__init__()\n        # AWQ only supports 4-bit quantization\n        if weight_bits != 4:\n            raise ValueError(\n                f\"AWQ only supports 4-bit quantization. 
\"\n                f\"Received weight_bits={weight_bits}.\"\n            )\n        if num_samples <= 0:\n            raise ValueError(\"num_samples must be a positive integer.\")\n        if sequence_length <= 0:\n            raise ValueError(\"sequence_length must be a positive integer.\")\n        if group_size < -1 or group_size == 0:\n            raise ValueError(\n                \"Invalid group_size. Supported values are -1 (per-channel) \"\n                f\"or a positive integer, but got {group_size}.\"\n            )\n        if num_grid_points <= 0:\n            raise ValueError(\"num_grid_points must be a positive integer.\")\n\n        self.dataset = dataset\n        self.tokenizer = tokenizer\n        self.weight_bits = weight_bits\n        self.num_samples = num_samples\n        self.sequence_length = sequence_length\n        self.group_size = group_size\n        self.num_grid_points = num_grid_points\n        self.quantization_layer_structure = quantization_layer_structure\n\n    @property\n    def mode(self):\n        return \"awq\"\n\n    def dtype_policy_string(self):\n        \"\"\"Returns the dtype policy string for this configuration.\n\n        Returns:\n            A string representing the dtype policy, e.g. 
\"awq/4/128\".\n        \"\"\"\n        return f\"awq/{self.weight_bits}/{self.group_size}\"\n\n    def get_config(self):\n        return {\n            # Dataset and Tokenizer are only required for one-time\n            # calibration and are not saved in the config.\n            \"dataset\": None,\n            \"tokenizer\": None,\n            \"weight_bits\": self.weight_bits,\n            \"num_samples\": self.num_samples,\n            \"sequence_length\": self.sequence_length,\n            \"group_size\": self.group_size,\n            \"num_grid_points\": self.num_grid_points,\n            \"quantization_layer_structure\": self.quantization_layer_structure,\n        }\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n"
  },
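The `dtype_policy_string` method encodes the configuration as `"awq/{weight_bits}/{group_size}"`. A hypothetical helper (not part of the Keras API; `AWQDTypePolicy` handles this internally) showing how such a string decomposes back into its fields:

```python
def parse_awq_policy_string(policy):
    # "awq/4/128" -> ("awq", 4, 128); a group_size of -1 means per-channel.
    mode, weight_bits, group_size = policy.split("/")
    if mode != "awq":
        raise ValueError(f"Not an AWQ policy string: {policy}")
    return mode, int(weight_bits), int(group_size)
```

For example, `parse_awq_policy_string("awq/4/-1")` yields the per-channel configuration that `AWQConfig(..., group_size=-1)` produces.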
  {
    "path": "keras/src/quantizers/awq_config_test.py",
    "content": "import pytest\n\nfrom keras.src import testing\nfrom keras.src.quantizers.awq_config import AWQConfig\n\n\n@pytest.mark.requires_trainable_backend\nclass AWQConfigTest(testing.TestCase):\n    \"\"\"Test AWQConfig validation and serialization.\"\"\"\n\n    class MockTokenizer:\n        \"\"\"Mock tokenizer for testing purposes.\"\"\"\n\n        def __init__(self):\n            pass\n\n    def test_config_defaults(self):\n        \"\"\"Test default configuration values.\"\"\"\n        config = AWQConfig(dataset=[\"test\"], tokenizer=self.MockTokenizer())\n        self.assertEqual(config.weight_bits, 4)\n        self.assertEqual(config.num_samples, 128)\n        self.assertEqual(config.sequence_length, 512)\n        self.assertEqual(config.group_size, 128)\n        self.assertEqual(config.num_grid_points, 20)\n        self.assertEqual(config.mode, \"awq\")\n\n    def test_config_custom_values(self):\n        \"\"\"Test custom configuration values.\"\"\"\n        config = AWQConfig(\n            dataset=[\"test\"],\n            tokenizer=self.MockTokenizer(),\n            num_samples=64,\n            sequence_length=256,\n            group_size=64,\n            num_grid_points=30,\n        )\n        self.assertEqual(config.num_samples, 64)\n        self.assertEqual(config.sequence_length, 256)\n        self.assertEqual(config.group_size, 64)\n        self.assertEqual(config.num_grid_points, 30)\n\n    def test_config_only_4bit(self):\n        \"\"\"Test that AWQ only supports 4-bit quantization.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"only supports 4-bit\"):\n            AWQConfig(\n                dataset=[\"test\"], tokenizer=self.MockTokenizer(), weight_bits=8\n            )\n\n    def test_config_invalid_num_samples(self):\n        \"\"\"Test invalid num_samples validation.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"num_samples must be\"):\n            AWQConfig(\n                dataset=[\"test\"], 
tokenizer=self.MockTokenizer(), num_samples=0\n            )\n\n    def test_config_invalid_sequence_length(self):\n        \"\"\"Test invalid sequence_length validation.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"sequence_length must be\"):\n            AWQConfig(\n                dataset=[\"test\"],\n                tokenizer=self.MockTokenizer(),\n                sequence_length=-1,\n            )\n\n    def test_config_invalid_group_size(self):\n        \"\"\"Test invalid group_size validation.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"Invalid group_size\"):\n            AWQConfig(\n                dataset=[\"test\"], tokenizer=self.MockTokenizer(), group_size=0\n            )\n\n    def test_config_invalid_num_grid_points(self):\n        \"\"\"Test invalid num_grid_points validation.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"num_grid_points must be\"):\n            AWQConfig(\n                dataset=[\"test\"],\n                tokenizer=self.MockTokenizer(),\n                num_grid_points=0,\n            )\n\n    def test_config_per_channel_group_size(self):\n        \"\"\"Test that -1 group_size is valid (per-channel).\"\"\"\n        config = AWQConfig(\n            dataset=[\"test\"], tokenizer=self.MockTokenizer(), group_size=-1\n        )\n        self.assertEqual(config.group_size, -1)\n\n    def test_config_serialization(self):\n        \"\"\"Test configuration serialization.\"\"\"\n        config = AWQConfig(\n            dataset=[\"test\"],\n            tokenizer=self.MockTokenizer(),\n            group_size=64,\n            num_grid_points=30,\n        )\n        cfg = config.get_config()\n        self.assertEqual(cfg[\"weight_bits\"], 4)\n        self.assertEqual(cfg[\"group_size\"], 64)\n        self.assertEqual(cfg[\"num_grid_points\"], 30)\n        # Dataset and tokenizer should not be serialized\n        self.assertIsNone(cfg[\"dataset\"])\n        self.assertIsNone(cfg[\"tokenizer\"])\n\n    def 
test_dtype_policy_string(self):\n        \"\"\"Test dtype policy string generation.\"\"\"\n        config = AWQConfig(\n            dataset=[\"test\"], tokenizer=self.MockTokenizer(), group_size=128\n        )\n        self.assertEqual(config.dtype_policy_string(), \"awq/4/128\")\n\n        config2 = AWQConfig(\n            dataset=[\"test\"], tokenizer=self.MockTokenizer(), group_size=-1\n        )\n        self.assertEqual(config2.dtype_policy_string(), \"awq/4/-1\")\n\n    def test_awq_config_serialization(self):\n        \"\"\"Test AWQConfig serialization and deserialization round-trip.\"\"\"\n        config = AWQConfig(\n            dataset=[\"test\"],\n            tokenizer=self.MockTokenizer(),\n            weight_bits=4,\n            num_samples=64,\n            sequence_length=256,\n            group_size=64,\n            num_grid_points=30,\n        )\n        serialized_config = config.get_config()\n        deserialized_config = AWQConfig.from_config(serialized_config)\n        # Compare the serializable fields (dataset/tokenizer are not serialized)\n        self.assertEqual(config.weight_bits, deserialized_config.weight_bits)\n        self.assertEqual(config.num_samples, deserialized_config.num_samples)\n        self.assertEqual(\n            config.sequence_length, deserialized_config.sequence_length\n        )\n        self.assertEqual(config.group_size, deserialized_config.group_size)\n        self.assertEqual(\n            config.num_grid_points, deserialized_config.num_grid_points\n        )\n"
  },
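The `group_size` field validated in the tests above controls how input features are bucketed into quantization groups. A NumPy sketch of that mapping and of the sz-map style dequantization it feeds (the shapes here follow the docstrings in `awq.py`; the real `dequantize_with_sz_map` lives in `keras.src.quantizers.quantizers`):

```python
import numpy as np

in_features, group_size = 8, 4
# Each input feature maps to its quantization group, exactly as g_idx is
# built in awq.py: arange(in_features) // group_size -> [0 0 0 0 1 1 1 1].
g_idx = np.arange(in_features) // group_size

# Per-(output-row, group) parameters: 2 output rows, 2 groups here.
scale = np.array([[0.5, 2.0], [1.0, 0.25]])  # [out_features, num_groups]
zero = np.array([[8.0, 8.0], [8.0, 8.0]])

q = np.full((2, in_features), 10.0)  # quantized 4-bit codes (as floats)

# Gathering columns by g_idx expands the per-group parameters to
# per-input-feature parameters, which is the core of the sz-map trick.
deq = (q - zero[:, g_idx]) * scale[:, g_idx]
```

With code 10 and zero point 8, every entry dequantizes to `2 * scale` of its group, so the first four columns of row 0 become 1.0 and the last four become 4.0.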
  {
    "path": "keras/src/quantizers/awq_core.py",
    "content": "\"\"\"AWQ core functionality for layer-wise quantization.\n\nThis module provides the orchestration logic for applying AWQ quantization\nto transformer models in a layer-by-layer fashion.\n\"\"\"\n\nfrom contextlib import contextmanager\n\nfrom absl import logging\n\nfrom keras.src import ops\nfrom keras.src import utils as keras_utils\nfrom keras.src.dtype_policies.dtype_policy import AWQDTypePolicy\nfrom keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap\nfrom keras.src.quantizers.awq import AWQ\nfrom keras.src.quantizers.awq_config import AWQConfig\nfrom keras.src.quantizers.gptq_core import find_layers_in_block\nfrom keras.src.quantizers.gptq_core import get_dataloader\nfrom keras.src.quantizers.utils import should_quantize_layer\n\n\n@contextmanager\ndef stream_activations(layers_map, awq_objects):\n    \"\"\"Context manager to capture activations for AWQ calibration.\n\n    Temporarily patches layer.call methods to capture activation statistics\n    for computing per-channel scaling factors.\n\n    Args:\n        layers_map: Dict[str, Layer]. Mapping from layer names to layers.\n        awq_objects: Dict[str, AWQ]. 
Mapping from names to AWQ instances.\n\n    Yields:\n        None: The patched state is active only within the `with` block.\n    \"\"\"\n    original_calls = {}\n\n    def create_hook(name, original_call_func):\n        def hook(*args, **kwargs):\n            inp = args[0] if args else kwargs[\"inputs\"]\n            num_features = awq_objects[name].rows\n            input_2d = ops.reshape(inp, (-1, num_features))\n            awq_objects[name].update_activation_magnitudes(input_2d)\n            return original_call_func(*args, **kwargs)\n\n        return hook\n\n    try:\n        for name, layer in layers_map.items():\n            original_calls[name] = layer.call\n            layer.call = create_hook(name, layer.call)\n        yield\n    finally:\n        for name, layer in layers_map.items():\n            layer.call = original_calls[name]\n\n\ndef apply_awq_layerwise(dataloader, config, structure, filters=None):\n    \"\"\"Apply AWQ quantization layer-by-layer to a Keras model.\n\n    This function processes the model sequentially, one block at a time:\n    1. Captures activation statistics through calibration data forward pass\n    2. Uses activation magnitudes to determine weight saliency\n    3. Finds optimal per-channel scales via grid search\n    4. 
Quantizes weights with AWQ scaling\n\n    Args:\n        dataloader: Calibration data as numpy array.\n        config: AWQConfig instance.\n        structure: Dict with 'pre_block_layers' and 'sequential_blocks'.\n        filters: Optional layer filters.\n    \"\"\"\n    num_samples = config.num_samples\n    logging.info(\"Starting AWQ quantization...\")\n\n    pre_layers = structure.get(\"pre_block_layers\", [])\n    transformer_blocks = structure.get(\"sequential_blocks\", [])\n\n    if not transformer_blocks:\n        raise ValueError(\n            \"No sequential blocks found in the structure to quantize.\"\n        )\n\n    # Process inputs through pre-block layers (e.g., embedding)\n    inputs = []\n    for batch in dataloader:\n        batch = ops.convert_to_tensor(batch, dtype=\"int32\")\n        for layer in pre_layers:\n            batch = layer(batch)\n        inputs.append(batch)\n\n    num_samples = min(num_samples, len(inputs))\n    progbar = keras_utils.Progbar(target=len(transformer_blocks))\n\n    for block_idx, block in enumerate(transformer_blocks):\n        logging.info(f\"Quantizing Block {block_idx}\")\n        sub_layers_map = find_layers_in_block(block)\n\n        # Apply filters\n        final_sub_layers_map = {}\n        for name, layer in sub_layers_map.items():\n            if not should_quantize_layer(layer, filters):\n                continue\n            final_sub_layers_map[name] = layer\n\n        sub_layers_map = final_sub_layers_map\n\n        if not sub_layers_map:\n            logging.info(\n                f\"  No quantizable layers found in block {block_idx}. 
Skipping.\"\n            )\n        else:\n            logging.info(f\"Found layers: {list(sub_layers_map.keys())}\")\n\n            # Create AWQ objects for each layer\n            awq_objects = {\n                name: AWQ(layer, config)\n                for name, layer in sub_layers_map.items()\n            }\n\n            # Capture activation statistics\n            with stream_activations(sub_layers_map, awq_objects):\n                for sample_idx in range(num_samples):\n                    current_input = inputs[sample_idx]\n                    if len(current_input.shape) == 2:\n                        current_input = ops.expand_dims(current_input, axis=0)\n                    _ = block(current_input)\n\n            # Quantize each layer\n            for name, awq_object in awq_objects.items():\n                logging.info(f\"Quantizing {name}...\")\n                awq_object.quantize_layer()\n                awq_object.free()\n\n            del awq_objects\n\n        # Generate inputs for next block\n        if block_idx < len(transformer_blocks) - 1:\n            logging.info(f\"Generating inputs for block {block_idx + 1}...\")\n            next_block_inputs = []\n            for sample_idx in range(num_samples):\n                current_input = inputs[sample_idx]\n                if len(current_input.shape) == 2:\n                    current_input = ops.expand_dims(current_input, axis=0)\n                output = block(current_input)[0]\n                next_block_inputs.append(output)\n            inputs = next_block_inputs\n\n        progbar.update(current=block_idx + 1)\n\n    logging.info(\"AWQ quantization complete.\")\n\n\ndef awq_quantize(config, quantization_layer_structure, filters=None):\n    \"\"\"Main entry point for AWQ quantization.\n\n    Args:\n        config: AWQConfig instance.\n        quantization_layer_structure: Model structure dictionary.\n        filters: Optional layer filters.\n    \"\"\"\n    if config.dataset is None or 
config.tokenizer is None:\n        raise ValueError(\n            \"AWQ quantization requires a dataset and tokenizer. \"\n            \"Please provide them in the AWQConfig.\"\n        )\n\n    if quantization_layer_structure is None:\n        raise ValueError(\n            \"For 'awq' mode, a valid quantization structure must be provided \"\n            \"either via `config.quantization_layer_structure` or by overriding \"\n            \"`model.get_quantization_layer_structure(mode)`. The structure \"\n            \"should be a dictionary with keys 'pre_block_layers' and \"\n            \"'sequential_blocks'.\"\n        )\n\n    # Load calibration data\n    dataloader = get_dataloader(\n        config.tokenizer,\n        config.sequence_length,\n        config.dataset,\n        num_samples=config.num_samples,\n    )\n\n    apply_awq_layerwise(\n        dataloader[: config.num_samples],\n        config,\n        quantization_layer_structure,\n        filters=filters,\n    )\n\n\ndef get_group_size_for_layer(layer, config):\n    \"\"\"Get group size from config or dtype policy.\n\n    Args:\n        layer: The layer to get group size for.\n        config: Optional AWQConfig instance.\n\n    Returns:\n        int: The group size for quantization.\n\n    Raises:\n        ValueError: If group size cannot be determined.\n    \"\"\"\n    if config and isinstance(config, AWQConfig):\n        return config.group_size\n    elif isinstance(layer.dtype_policy, AWQDTypePolicy):\n        return layer.dtype_policy.group_size\n    elif isinstance(layer.dtype_policy, DTypePolicyMap):\n        policy = layer.dtype_policy[layer.path]\n        if isinstance(policy, AWQDTypePolicy):\n            return policy.group_size\n    raise ValueError(\n        \"For AWQ quantization, group_size must be specified \"\n        \"through AWQConfig or AWQDTypePolicy.\"\n    )\n"
  },
  {
    "path": "keras/src/quantizers/awq_test.py",
    "content": "\"\"\"Tests for AWQ quantization.\"\"\"\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nimport keras\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.quantizers.awq import AWQ\nfrom keras.src.quantizers.awq import awq_quantize_matrix\nfrom keras.src.quantizers.awq import awq_search_optimal_scales\nfrom keras.src.quantizers.awq_config import AWQConfig\n\n# Shared RNG instance for reproducible tests\nRNG = np.random.default_rng(seed=42)\n\n\nclass MockTokenizer:\n    \"\"\"Simple tokenizer for testing.\"\"\"\n\n    def __init__(self, vocab_size=100, seq_len=64):\n        self.vocab_size = vocab_size\n        self.seq_len = seq_len\n\n    def tokenize(self, text):\n        # Simple character-based tokenization\n        tokens = [ord(c) % self.vocab_size for c in str(text)]\n        # Pad or truncate to seq_len\n        if len(tokens) < self.seq_len:\n            tokens = tokens + [0] * (self.seq_len - len(tokens))\n        else:\n            tokens = tokens[: self.seq_len]\n        return ops.array([tokens], dtype=\"int32\")\n\n    def __call__(self, text):\n        return self.tokenize(text)\n\n\n@pytest.mark.requires_trainable_backend\nclass AWQAlgorithmTest(testing.TestCase):\n    \"\"\"Test AWQ algorithm core functionality.\"\"\"\n\n    def test_scale_search_returns_valid_scales(self):\n        \"\"\"Test that scale search returns valid positive scales.\"\"\"\n        weights = RNG.standard_normal((32, 16)).astype(\"float32\")\n        activations = ops.abs(\n            ops.add(RNG.standard_normal((16,)).astype(\"float32\"), 0.1)\n        )\n\n        scales = awq_search_optimal_scales(\n            weights, activations, num_grid_points=10, group_size=-1\n        )\n\n        self.assertEqual(scales.shape, (16,))\n        # All scales should be positive\n        self.assertTrue(ops.all(ops.greater(scales, 0)))\n\n    def 
test_scale_search_with_zero_activations(self):\n        \"\"\"Test scale search handles near-zero activations.\"\"\"\n        weights = ops.array(RNG.standard_normal((32, 16)).astype(\"float32\"))\n        # Some activations are very small\n        activations = np.abs(RNG.standard_normal((16,)).astype(\"float32\"))\n        activations[:5] = 1e-10\n        activations = ops.array(activations)\n\n        scales = awq_search_optimal_scales(\n            weights, activations, num_grid_points=10, group_size=-1\n        )\n\n        # Should handle gracefully without NaN or Inf\n        self.assertFalse(ops.any(ops.isnan(scales)))\n        self.assertFalse(ops.any(ops.isinf(scales)))\n\n    def test_quantize_matrix_shapes(self):\n        \"\"\"Test that quantize_matrix returns correct shapes.\"\"\"\n        # weights_transpose has shape [out_features, in_features]\n        weights = ops.array(RNG.standard_normal((32, 16)).astype(\"float32\"))\n        activations = ops.add(\n            ops.abs(RNG.standard_normal((16,)).astype(\"float32\")), 0.1\n        )\n\n        quantized, scale, zero, awq_scales, g_idx = awq_quantize_matrix(\n            weights, activations, num_grid_points=10, group_size=-1\n        )\n\n        # Quantized shape: [out_features, in_features]\n        self.assertEqual(quantized.shape, (32, 16))\n        # Scale shape: [out_features, num_groups]\n        self.assertEqual(scale.shape, (32, 1))\n        # AWQ scales: per-channel for input features\n        self.assertEqual(awq_scales.shape, (16,))\n        # AWQ zero shape: [out_features, num_groups]\n        self.assertEqual(zero.shape, (32, 1))\n        # Group indices\n        self.assertEqual(g_idx.shape, (16,))\n\n    def test_quantize_matrix_with_grouping(self):\n        \"\"\"Test quantize_matrix with group size.\"\"\"\n        # Use dimensions divisible by group_size for cleaner test\n        weights = ops.array(RNG.standard_normal((64, 32)).astype(\"float32\"))\n        activations = 
ops.add(\n            ops.abs(RNG.standard_normal((32,)).astype(\"float32\")), 0.1\n        )\n\n        # Test grouped mode with group_size=8\n        quantized, scale, zero, awq_scales, g_idx = awq_quantize_matrix(\n            weights, activations, num_grid_points=5, group_size=8\n        )\n\n        # Quantized shape: [out_features, in_features]\n        self.assertEqual(quantized.shape, (64, 32))\n        # Scale shape: [out_features, num_groups]\n        self.assertEqual(scale.shape, (64, 4))  # 32 in_features / 8 group_size\n        # AWQ scales: per-channel for input features\n        self.assertEqual(awq_scales.shape, (32,))\n        # AWQ zero shape: [out_features, num_groups]\n        self.assertEqual(zero.shape, (64, 4))\n        # Group indices\n        self.assertEqual(g_idx.shape, (32,))\n\n        # Check g_idx values\n        self.assertEqual(ops.max(g_idx), 3)  # 4 groups: 0,1,2,3\n\n    def test_quantize_matrix_grouped_shapes(self):\n        \"\"\"Test awq_quantize_matrix with positive group_size.\n\n        This is a regression test for the InvalidArgumentError that occurred\n        when group_size != -1 due to shape mismatch in broadcasting.\n        \"\"\"\n        out_features = 768\n        in_features = 768\n        group_size = 128\n        n_groups = in_features // group_size  # 6 groups\n\n        weights = ops.array(\n            RNG.standard_normal((out_features, in_features)).astype(\"float32\")\n        )\n        activations = ops.array(\n            np.abs(RNG.standard_normal((in_features,)).astype(\"float32\")) + 0.1\n        )\n\n        quantized, scale, zero, awq_scales, g_idx = awq_quantize_matrix(\n            weights, activations, num_grid_points=5, group_size=group_size\n        )\n\n        # Quantized should match input shape\n        self.assertEqual(quantized.shape, (out_features, in_features))\n        # Scale should be [out_features, 
n_groups]\n        self.assertEqual(scale.shape, (out_features, n_groups))\n        # Zero should be [out_features, n_groups]\n        self.assertEqual(zero.shape, (out_features, n_groups))\n        # AWQ scales should be per-input-channel\n        self.assertEqual(awq_scales.shape, (in_features,))\n        # g_idx should be [in_features]\n        self.assertEqual(g_idx.shape, (in_features,))\n\n        # Verify g_idx values\n        expected_g_idx = ops.floor_divide(ops.arange(in_features), group_size)\n        self.assertAllEqual(g_idx, expected_g_idx)\n\n    def test_quantize_matrix_grouped_no_nan_inf(self):\n        \"\"\"Test grouped quantization produces no NaN or Inf values.\"\"\"\n        out_features = 256\n        in_features = 512\n        group_size = 64\n\n        weights = ops.array(\n            RNG.standard_normal((out_features, in_features)).astype(\"float32\")\n        )\n        activations = ops.add(\n            ops.abs(RNG.standard_normal((in_features,)).astype(\"float32\")), 0.1\n        )\n\n        quantized, scale, _, awq_scales, _ = awq_quantize_matrix(\n            weights, activations, num_grid_points=5, group_size=group_size\n        )\n\n        # Check for NaN/Inf in all outputs\n        self.assertFalse(ops.any(ops.isnan(quantized)))\n        self.assertFalse(ops.any(ops.isinf(quantized)))\n        self.assertFalse(ops.any(ops.isnan(scale)))\n        self.assertFalse(ops.any(ops.isinf(scale)))\n        self.assertFalse(ops.any(ops.isnan(awq_scales)))\n        self.assertFalse(ops.any(ops.isinf(awq_scales)))\n\n    def test_scale_search_grouped_quantization(self):\n        \"\"\"Test awq_search_optimal_scales with grouped quantization.\"\"\"\n        out_features = 128\n        in_features = 256\n        group_size = 32\n\n        weights = ops.array(\n            RNG.standard_normal((out_features, in_features)).astype(\"float32\")\n        )\n        activations = ops.add(\n            
ops.abs(RNG.standard_normal((in_features,)).astype(\"float32\")), 0.1\n        )\n\n        scales = awq_search_optimal_scales(\n            weights, activations, num_grid_points=5, group_size=group_size\n        )\n\n        # Scales should be [in_features]\n        self.assertEqual(scales.shape, (in_features,))\n        # All scales should be positive\n        self.assertTrue(ops.all(ops.greater(scales, 0)))\n        # No NaN or Inf\n        self.assertFalse(ops.any(ops.isnan(scales)))\n        self.assertFalse(ops.any(ops.isinf(scales)))\n\n    @parameterized.named_parameters(\n        (\"group_8\", 8),\n        (\"group_16\", 16),\n        (\"group_32\", 32),\n        (\"group_64\", 64),\n        (\"group_128\", 128),\n    )\n    def test_quantize_matrix_various_group_sizes(self, group_size):\n        \"\"\"Test awq_quantize_matrix with various group sizes.\"\"\"\n        out_features = 64\n        in_features = 128\n        n_groups = in_features // group_size\n\n        weights = ops.array(\n            RNG.standard_normal((out_features, in_features)).astype(\"float32\")\n        )\n        activations = ops.add(\n            ops.abs(RNG.standard_normal((in_features,)).astype(\"float32\")), 0.1\n        )\n\n        _, scale, zero, _, _ = awq_quantize_matrix(\n            weights, activations, num_grid_points=3, group_size=group_size\n        )\n\n        self.assertEqual(\n            scale.shape,\n            (out_features, n_groups),\n            f\"Failed for group_size={group_size}\",\n        )\n        self.assertEqual(\n            zero.shape,\n            (out_features, n_groups),\n            f\"Failed for group_size={group_size}\",\n        )\n\n\n@pytest.mark.requires_trainable_backend\nclass AWQLayerTest(testing.TestCase):\n    \"\"\"Test AWQ class for layer quantization.\"\"\"\n\n    def test_awq_on_dense_layer(self):\n        \"\"\"Test AWQ on a Dense layer.\"\"\"\n        layer = layers.Dense(32)\n        layer.build(input_shape=(None, 
16))\n\n        config = AWQConfig(\n            dataset=None,\n            tokenizer=None,\n            group_size=-1,\n            num_grid_points=10,\n        )\n\n        layer.quantize(config=config)\n        awq_obj = AWQ(layer, config)\n\n        # Simulate activation capture\n        calibration_data = RNG.standard_normal((64, 16)).astype(\"float32\")\n        awq_obj.update_activation_magnitudes(calibration_data)\n\n        self.assertEqual(awq_obj.num_samples, 64)\n        # Activation magnitudes should be non-negative\n        self.assertTrue(\n            ops.all(ops.greater_equal(awq_obj.activation_magnitudes, 0))\n        )\n\n    def test_awq_activation_accumulation(self):\n        \"\"\"Test that activation magnitudes accumulate correctly.\"\"\"\n        layer = layers.Dense(32)\n        layer.build(input_shape=(None, 16))\n\n        config = AWQConfig(\n            dataset=None, tokenizer=None, group_size=-1, num_grid_points=10\n        )\n        layer.quantize(config=config)\n        awq_obj = AWQ(layer, config)\n\n        # First batch\n        batch1 = ops.abs(RNG.standard_normal((10, 16)).astype(\"float32\"))\n        batch1_max = ops.max(batch1, axis=0)\n        awq_obj.update_activation_magnitudes(batch1)\n\n        # Second batch with higher values in some channels\n        batch2 = ops.add(\n            ops.abs(RNG.standard_normal((10, 16)).astype(\"float32\")), 1.0\n        )\n        batch2_max = ops.max(batch2, axis=0)\n        awq_obj.update_activation_magnitudes(batch2)\n\n        # Accumulated magnitudes should be element-wise max\n        expected_max = ops.maximum(batch1_max, batch2_max)\n        self.assertAllClose(\n            awq_obj.activation_magnitudes, expected_max, atol=1e-6\n        )\n\n    def test_awq_layer_variables_created(self):\n        \"\"\"Test that AWQ layer variables are properly created.\"\"\"\n        layer = layers.Dense(32)\n        layer.build(input_shape=(None, 16))\n\n        config = AWQConfig(\n       
     dataset=None, tokenizer=None, group_size=-1, num_grid_points=10\n        )\n        layer.quantize(config=config)\n\n        # Check that AWQ-specific variables exist\n        self.assertTrue(hasattr(layer, \"quantized_kernel\"))\n        self.assertTrue(hasattr(layer, \"kernel_scale\"))\n        self.assertTrue(hasattr(layer, \"kernel_zero\"))\n        self.assertTrue(hasattr(layer, \"awq_scales\"))\n        self.assertTrue(hasattr(layer, \"g_idx\"))\n        self.assertFalse(layer.is_awq_calibrated)\n\n\n@pytest.mark.requires_trainable_backend\nclass AWQIntegrationTest(testing.TestCase):\n    \"\"\"Integration tests for AWQ quantization.\"\"\"\n\n    def test_dense_layer_quantize_awq(self):\n        \"\"\"Test Dense layer can be quantized with AWQ.\"\"\"\n        layer = layers.Dense(64)\n        layer.build(input_shape=(None, 32))\n\n        config = AWQConfig(\n            dataset=None, tokenizer=None, group_size=16, num_grid_points=5\n        )\n        layer.quantize(config=config)\n\n        # Check layer is properly configured\n        self.assertEqual(layer.quantization_mode, \"awq\")\n        self.assertTrue(hasattr(layer, \"awq_scales\"))\n\n    def test_einsum_dense_layer_quantize_awq(self):\n        \"\"\"Test EinsumDense layer can be quantized with AWQ.\"\"\"\n        layer = layers.EinsumDense(\"ab,bc->ac\", output_shape=(64,))\n        layer.build(input_shape=(None, 32))\n\n        config = AWQConfig(\n            dataset=None, tokenizer=None, group_size=-1, num_grid_points=5\n        )\n        layer.quantize(config=config)\n\n        # Check layer is properly configured\n        self.assertEqual(layer.quantization_mode, \"awq\")\n        self.assertTrue(hasattr(layer, \"awq_scales\"))\n\n    def test_model_quantize_requires_structure(self):\n        \"\"\"Test model.quantize requires structure for AWQ.\"\"\"\n        model = models.Sequential([layers.Dense(10, input_shape=(5,))])\n        model.build()\n\n        config = AWQConfig(\n         
   dataset=[\"test data\"],\n            tokenizer=MockTokenizer(vocab_size=100, seq_len=5),\n        )\n\n        with self.assertRaisesRegex(ValueError, \"quantization structure\"):\n            model.quantize(config=config)\n\n\n# Constants for end-to-end tests\nVOCAB_SIZE = 1000\nSEQ_LEN = 128\nNUM_SAMPLES = 16\nNUM_CLASSES = 32\n\nCALIBRATION_TEXT = \"\"\"\nAWQ (Activation-aware Weight Quantization) is an efficient and accurate\nlow-bit weight quantization method for LLMs. AWQ is based on the observation\nthat weights are not equally important: protecting only 1% of salient weights\ncan greatly reduce quantization error. To find salient weights, AWQ looks at\nthe activation distribution, not weights. Salient weights are those that\ncorrespond to channels with larger activation magnitudes. AWQ then applies\nper-channel scaling to protect salient weights during quantization.\nThe key insight is that for a weight channel, if the corresponding activation\nchannel has large values, quantizing that weight channel will incur large\nerror. 
By scaling up salient weight channels before quantization and scaling\ndown during inference, AWQ can significantly reduce quantization error\nwhile maintaining the same effective computation.\n\"\"\"\n\n\ndef _mean_kl(p, q):\n    \"\"\"Compute mean KL divergence between two probability distributions.\"\"\"\n    eps = 1e-8\n    p = ops.clip(p, eps, 1.0)\n    q = ops.clip(q, eps, 1.0)\n    return ops.mean(\n        ops.sum(ops.multiply(p, ops.subtract(ops.log(p), ops.log(q))), axis=-1)\n    )\n\n\ndef _top1_match_rate(a_logits, b_logits):\n    \"\"\"Calculate top-1 match rate between two sets of logits.\"\"\"\n    return ops.mean(\n        ops.equal(ops.argmax(a_logits, axis=-1), ops.argmax(b_logits, axis=-1))\n    )\n\n\ndef _get_sequence_classifier():\n    \"\"\"Create a transformer-based sequence classifier for testing.\"\"\"\n    embed_dim = 32\n    num_heads = 4\n    ff_dim = 32\n\n    class SimpleTransformerBlock(layers.Layer):\n        def __init__(self, embed_dim, num_heads, ff_dim, **kwargs):\n            super().__init__(**kwargs)\n            self.att = layers.MultiHeadAttention(\n                num_heads=num_heads, key_dim=embed_dim // num_heads\n            )\n            self.ffn = models.Sequential(\n                [\n                    layers.Dense(ff_dim, activation=\"relu\"),\n                    layers.Dense(embed_dim),\n                ]\n            )\n            self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n            self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n\n        def call(self, inputs):\n            attention_output = self.att(inputs, inputs)\n            out1 = self.layernorm1(inputs + attention_output)\n            ffn_output = self.ffn(out1)\n            return self.layernorm2(out1 + ffn_output)\n\n    inputs = layers.Input(shape=(SEQ_LEN,), dtype=\"int32\")\n    x = layers.Embedding(VOCAB_SIZE, embed_dim)(inputs)\n    x = SimpleTransformerBlock(embed_dim, num_heads, ff_dim)(x)\n    x = 
layers.GlobalAveragePooling1D()(x)\n    outputs = layers.Dense(NUM_CLASSES)(x)\n    return models.Model(inputs, outputs)\n\n\ndef _char_tokenizer(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN):\n    \"\"\"Character-based tokenizer for testing.\"\"\"\n\n    def _pad_or_trim_1d(ids, length):\n        ids = ops.ravel(ops.array(ids, \"int64\"))\n        if len(ids) < length:\n            ids = ops.concatenate(\n                [ids, ops.zeros(length - len(ids), dtype=ids.dtype)]\n            )\n        else:\n            ids = ids[:length]\n        return ids\n\n    def _tok(x):\n        if isinstance(x, str):\n            ids = ops.convert_to_tensor(\n                np.fromiter((ord(c) % vocab_size for c in x), dtype=np.int64)\n            )\n        else:\n            ids = np.asarray(x, dtype=np.int64)\n        ids = _pad_or_trim_1d(ids, seq_len)\n        return ids[None, :]\n\n    _tok.tokenize = _tok\n    return _tok\n\n\ndef _string_dataset(\n    long_text, num_samples=NUM_SAMPLES, sequence_length=SEQ_LEN\n):\n    \"\"\"Yield string slices for calibration.\"\"\"\n    length = max(1, len(long_text) - sequence_length)\n    for _ in range(num_samples):\n        start = RNG.integers(0, length) if length > 1 else 0\n        yield long_text[start : start + sequence_length]\n\n\n@pytest.mark.requires_trainable_backend\nclass AWQAccuracyTest(testing.TestCase):\n    \"\"\"End-to-end accuracy preservation tests for AWQ quantization.\"\"\"\n\n    @parameterized.named_parameters(\n        (\"per_channel\", -1, 20, 0.5, 0.30),\n        (\"group_16\", 16, 10, 0.4, 0.40),\n    )\n    def test_awq_transformer_accuracy(\n        self, group_size, num_grid_points, min_top1, max_kl\n    ):\n        \"\"\"Test that AWQ quantization preserves model accuracy.\n\n        This test:\n        1. Creates a transformer-based sequence classifier\n        2. Gets baseline (full precision) predictions\n        3. Applies AWQ quantization with calibration data\n        4. 
Compares quantized predictions against baseline\n        5. Validates top-1 match rate and KL divergence bounds\n        \"\"\"\n        keras.utils.set_random_seed(123)\n\n        # Build calibration dataset\n        calibration_set = list(_string_dataset(CALIBRATION_TEXT, NUM_SAMPLES))\n        self.assertNotEmpty(calibration_set)\n\n        # Build model and tokenizer\n        model = _get_sequence_classifier()\n        tokenizer = _char_tokenizer(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN)\n\n        # Build eval batch from same distribution as calibration\n        batch_size = min(8, len(calibration_set))\n        eval_samples = [\n            calibration_set[RNG.integers(0, len(calibration_set))]\n            for _ in range(batch_size)\n        ]\n        x_eval = ops.concatenate([tokenizer(s) for s in eval_samples], axis=0)\n\n        # Get baseline predictions (full precision)\n        y_ref = model.predict(x_eval)\n\n        # Define layer structure for AWQ\n        embedding_layer = model.layers[1]\n        transformer_block = model.layers[2]\n\n        layer_structure = {\n            \"pre_block_layers\": [embedding_layer],\n            \"sequential_blocks\": [transformer_block],\n        }\n\n        # Configure AWQ\n        awq_config = AWQConfig(\n            dataset=calibration_set,\n            tokenizer=tokenizer,\n            num_samples=NUM_SAMPLES,\n            sequence_length=SEQ_LEN,\n            group_size=group_size,\n            num_grid_points=num_grid_points,\n            quantization_layer_structure=layer_structure,\n        )\n\n        # Quantize model with AWQ\n        model.quantize(config=awq_config)\n\n        # Get post-quantization predictions\n        y_q = model.predict(x_eval)\n\n        # Calculate accuracy metrics\n        top1_match = _top1_match_rate(y_ref, y_q)\n\n        p_ref = ops.softmax(y_ref)\n        p_q = ops.softmax(y_q)\n        kl = _mean_kl(p_ref, p_q)\n\n        # Validate accuracy preservation\n        
self.assertGreaterEqual(\n            float(top1_match),\n            min_top1,\n            f\"Top-1 agreement too low for group_size={group_size}: \"\n            f\"{float(top1_match):.3f}\",\n        )\n        self.assertLessEqual(\n            float(kl),\n            max_kl,\n            f\"KL divergence too high for group_size={group_size}: \"\n            f\"{float(kl):.3f}\",\n        )\n\n    @parameterized.named_parameters(\n        (\"per_channel\", -1, 0.35),\n        (\"group_16\", 16, 0.35),\n        (\"group_32\", 32, 0.35),\n        (\"group_64\", 64, 0.35),\n        (\"group_128\", 128, 0.35),\n    )\n    def test_awq_accuracy_various_group_sizes(\n        self, group_size, max_relative_mse\n    ):\n        \"\"\"Test AWQ accuracy across various group sizes.\n\n        Verifies that quantizing a single layer maintains reasonable\n        output reconstruction error and correct variable shapes.\n        \"\"\"\n        in_features = 128\n        out_features = 64\n\n        keras.utils.set_random_seed(42)\n\n        # Create fresh layer for each test\n        layer = layers.Dense(out_features)\n        layer.build(input_shape=(None, in_features))\n\n        # Create data\n        calibration_data = RNG.standard_normal((64, in_features)).astype(\n            \"float32\"\n        )\n        test_data = RNG.standard_normal((16, in_features)).astype(\"float32\")\n\n        # Get original output\n        original_output = layer(test_data)\n\n        # Configure and quantize\n        config = AWQConfig(\n            dataset=None,\n            tokenizer=None,\n            group_size=group_size,\n            num_grid_points=5,\n        )\n        layer.quantize(config=config)\n\n        awq_obj = AWQ(layer, config)\n        awq_obj.update_activation_magnitudes(calibration_data)\n        awq_obj.quantize_layer()\n\n        # Verify layer variables have correct shapes for grouped quantization\n        if group_size > 0:\n            n_groups = in_features // 
group_size\n            self.assertEqual(\n                layer.kernel_scale.shape,\n                (out_features, n_groups),\n                f\"kernel_scale shape mismatch for group_size={group_size}\",\n            )\n            self.assertEqual(\n                layer.kernel_zero.shape,\n                (out_features, n_groups),\n                f\"kernel_zero shape mismatch for group_size={group_size}\",\n            )\n\n        # Verify output\n        quantized_output = layer(test_data)\n\n        # Should have no NaN/Inf\n        self.assertFalse(\n            ops.any(ops.isnan(quantized_output)),\n            f\"NaN in output for group_size={group_size}\",\n        )\n        self.assertFalse(\n            ops.any(ops.isinf(quantized_output)),\n            f\"Inf in output for group_size={group_size}\",\n        )\n\n        # Should maintain reasonable accuracy\n        mse = ops.mean(\n            ops.power(ops.subtract(original_output, quantized_output), 2)\n        )\n        original_var = ops.var(original_output)\n        relative_mse = ops.divide(mse, ops.add(original_var, 1e-8))\n\n        self.assertLess(\n            relative_mse,\n            max_relative_mse,\n            f\"Accuracy too low for group_size={group_size}: \"\n            f\"relative_mse={float(relative_mse):.4f}\",\n        )\n\n        awq_obj.free()\n"
  },
  {
    "path": "keras/src/quantizers/gptq.py",
    "content": "import types\n\nfrom keras.src import ops\nfrom keras.src import quantizers\nfrom keras.src.layers import Dense\nfrom keras.src.layers import EinsumDense\nfrom keras.src.ops import linalg\nfrom keras.src.quantizers.gptq_config import GPTQConfig\nfrom keras.src.quantizers.quantizers import GPTQQuantizer\nfrom keras.src.quantizers.quantizers import compute_quantization_parameters\nfrom keras.src.quantizers.quantizers import dequantize_with_zero_point\nfrom keras.src.quantizers.quantizers import quantize_with_zero_point\n\n\ndef _stable_permutation(metric):\n    \"\"\"Return a stable permutation that sorts `metric` in descending order.\n    Uses an index-based jitter to break ties deterministically.\"\"\"\n    n = ops.shape(metric)[0]\n    idx = ops.arange(0, n, dtype=\"int32\")\n    # tiny jitter = (idx / n) * 1e-12 so it never flips a real strict ordering\n    jitter = ops.divide(ops.cast(idx, \"float32\"), ops.cast(n, \"float32\"))\n    metric_jittered = ops.add(metric, ops.multiply(jitter, 1e-12))\n    # argsort by negative to get descending\n    return ops.argsort(ops.negative(metric_jittered))\n\n\ndef gptq_quantize_matrix(\n    weights_transpose,\n    inv_hessian,\n    *,\n    blocksize=128,\n    group_size=-1,\n    activation_order=False,\n    order_metric=None,\n    compute_scale_zero=compute_quantization_parameters,\n):\n    \"\"\"\n    Implements the GPTQ error correction updates.\n\n    For a single column update (column j):\n        e = (w_j - q_j) / invH[j, j]\n        W[:, j+1:] -= e * invH[j, j+1:]\n    where:\n    - w_j is the original column,\n    - q_j is the quantized column,\n    - invH is the inverse Hessian,\n    - e is the propagated error term.\n\n    Across entire blocks:\n        W[:, future] -= E_block * invH[block, future]\n    where:\n    - E_block is the quantization error accumulated for the current block,\n    - invH[block, future] denotes the cross-block slice of the inverse Hessian,\n    - W[:, future] are the columns 
yet to be quantized.\n\n    Args:\n        weights_transpose: Transposed weight matrix [out_features, in_features]\n         to quantize.\n        inv_hessian: Inverse Hessian matrix [in_features, in_features] for\n         error propagation.\n        blocksize: Size of the blocks to process (default: 128).\n        group_size: Size of the groups for parameter reuse\n         (default: -1, no grouping).\n        activation_order: Whether to apply activation-order permutation\n         (default: False).\n        order_metric: Metric for ordering features\n         (default: None, uses 1 / diag(invH)).\n        compute_scale_zero: Function to compute scale and zero for\n         quantization.\n\n    Returns:\n        quantized_weights: Quantized weight matrix [out_features, in_features].\n        scale: float32. Scale parameters for quantization\n         [out_features, num_groups].\n        zero: Zero-point parameters for quantization [out_features, num_groups].\n        g_idx: int32. Group indices for each feature [in_features].\n    \"\"\"\n    in_features = ops.shape(weights_transpose)[1]\n\n    if activation_order:\n        # Use 1 / diag(inverse_hessian) as importance proxy by default.\n        if order_metric is None:\n            order_metric = ops.reciprocal(\n                ops.add(ops.diagonal(inv_hessian), 1e-12)\n            )\n        else:\n            # sanitize provided metric\n            order_metric = ops.cast(order_metric, \"float32\")\n            order_metric = ops.where(\n                ops.isfinite(order_metric),\n                order_metric,\n                ops.zeros_like(order_metric),\n            )\n        # Sort in descending order by importance\n        perm = _stable_permutation(order_metric)\n        inv_perm = ops.argsort(perm)\n\n        weights_transpose = ops.take(weights_transpose, perm, axis=1)\n        inv_hessian = ops.take(\n            ops.take(inv_hessian, perm, axis=0), perm, axis=1\n        )\n    else:\n        perm 
= inv_perm = None\n\n    # weights_buffer: [out_features, in_features]\n    weights_buffer = weights_transpose\n    # Buffer for the final quantized matrix: [out_features, in_features]\n    quantized_weights_buffer = ops.zeros_like(weights_transpose, dtype=\"int32\")\n\n    scale_chunks = []\n    zero_chunks = []\n\n    # Compute effective group size\n    effective_group = in_features if group_size == -1 else group_size\n\n    # Process features in blocks\n    for block_start in range(0, in_features, blocksize):\n        block_end = min(block_start + blocksize, in_features)\n        block_size = block_end - block_start\n\n        # Block views\n        # block_weights: [out_features, block_size]\n        block_weights = weights_buffer[:, block_start:block_end]\n        # block_error: [out_features, block_size]\n        block_error = ops.zeros_like(block_weights)\n        # block_inv_hessian: [block_size, block_size]\n        block_inv_hessian = inv_hessian[\n            block_start:block_end, block_start:block_end\n        ]\n\n        # Per-group cached params for reuse within the group\n        cached_scale = None\n        cached_zero = None\n        cached_maxq = None\n        cached_group_start = -1\n\n        for block_idx in range(block_size):\n            # Current global column index, represents the original column\n            # in the weight matrix\n            global_idx = block_start + block_idx\n            # weight_column: [out_features,]\n            weight_column = block_weights[:, block_idx]\n            # Group-wise parameter reuse (compute once per group)\n            if not effective_group == in_features:  # group_size != -1\n                # Determine the group start index for the current column\n                group_start = (global_idx // effective_group) * effective_group\n                if group_start != cached_group_start:\n                    # New group encountered, compute & cache params\n                    # for this group\n         
           group_end = min(group_start + effective_group, in_features)\n                    group_slice = weights_buffer[:, group_start:group_end]\n                    cached_scale, cached_zero, cached_maxq = compute_scale_zero(\n                        group_slice\n                    )\n                    # Store params once per group (in the order encountered).\n                    scale_chunks.append(cached_scale)\n                    zero_chunks.append(cached_zero)\n                    cached_group_start = group_start\n                scale, zero, maxq = cached_scale, cached_zero, cached_maxq\n            else:\n                # Single global group covering all columns.\n                if cached_scale is None:\n                    cached_scale, cached_zero, cached_maxq = compute_scale_zero(\n                        weights_buffer\n                    )\n                    scale_chunks.append(cached_scale)\n                    zero_chunks.append(cached_zero)\n                    cached_group_start = 0\n                scale, zero, maxq = cached_scale, cached_zero, cached_maxq\n\n            # Quantize column and store it.\n            # quantized_column: [out_features, 1]\n            quantized_column = quantize_with_zero_point(\n                ops.expand_dims(weight_column, 1), scale, zero, maxq\n            )\n\n            # Store quantized column in the buffer.\n            quantized_weights_buffer = ops.slice_update(\n                quantized_weights_buffer,\n                (0, global_idx),\n                ops.cast(quantized_column, \"int32\"),\n            )\n            # Dequantize column to compute error.\n            # dequantized_col: [out_features,]\n            dequantized_col = dequantize_with_zero_point(\n                quantized_column, scale, zero\n            )[:, 0]\n            # Error feedback for remaining columns within the block\n            # block_inv_hessian_diag: scalar\n            current_block_influence = 
block_inv_hessian[block_idx, block_idx]\n            # We divide by current_block_influence to get the\n            # correct scaling of the error term.\n            err = ops.divide(\n                ops.subtract(weight_column, dequantized_col),\n                current_block_influence,\n            )\n            # Record error for propagation to future blocks\n            block_error = ops.slice_update(\n                block_error, (0, block_idx), ops.expand_dims(err, 1)\n            )\n\n            # Update remaining columns in the current block\n            # (those before the current column have already been quantized)\n            # Propagate error to remaining columns in the block.\n            if block_idx < block_size - 1:\n                # update: [out_features, block_size - block_idx - 1]\n                update = ops.matmul(\n                    ops.expand_dims(err, 1),\n                    ops.expand_dims(\n                        block_inv_hessian[block_idx, block_idx + 1 :], 0\n                    ),\n                )\n                # tail is a view of the remaining columns in the block\n                # to be updated\n                # tail: [out_features, block_size - block_idx - 1]\n                tail = block_weights[:, block_idx + 1 :]\n                block_weights = ops.slice_update(\n                    block_weights,\n                    (0, block_idx + 1),\n                    ops.subtract(tail, update),\n                )\n\n        # Propagate block errors to future features (beyond the block)\n        if block_end < in_features:\n            # Total update for all future columns, based on the\n            # accumulated error in this block. 
This is calculated\n            # as the matrix product of the block_error and the\n            # relevant slice of the inverse Hessian.\n            # total_update: [out_features, in_features - block_end]\n            total_update = ops.matmul(\n                block_error, inv_hessian[block_start:block_end, block_end:]\n            )\n            # Update the remaining weights in the buffer. This is done\n            # by subtracting the total_update from the remaining columns.\n            weights_buffer = ops.concatenate(\n                [\n                    weights_buffer[:, :block_end],\n                    ops.subtract(weights_buffer[:, block_end:], total_update),\n                ],\n                axis=1,\n            )\n\n    # Build group indices for each (possibly permuted) column\n    # base_group = effective_group (int)\n    base_group = effective_group\n\n    # g_idx in permuted domain (integer group index per column)\n    g_idx = ops.arange(0, in_features, dtype=\"int32\")\n    g_idx = ops.floor_divide(g_idx, base_group)\n    g_idx = ops.cast(g_idx, \"float32\")\n\n    # Map group indices and quantized weights back to original column order\n    if activation_order:\n        g_idx = ops.take(g_idx, inv_perm, axis=0)\n        quantized_weights_buffer = ops.take(\n            quantized_weights_buffer, inv_perm, axis=1\n        )\n\n    # Concatenate recorded group params\n    if len(scale_chunks) == 0:\n        # Edge case: no groups recorded (empty input); fall back to whole matrix\n        s, z, _ = compute_scale_zero(weights_transpose)\n        scale = s\n        zero = z\n    else:\n        scale = ops.concatenate(scale_chunks, axis=1)\n        zero = ops.concatenate(zero_chunks, axis=1)\n\n    return quantized_weights_buffer, scale, zero, g_idx\n\n\nclass GPTQ:\n    def __init__(self, layer, config=GPTQConfig(tokenizer=None, dataset=None)):\n        self.original_layer = layer\n        self.num_samples = 0\n        self.config = config\n        self.quantizer = GPTQQuantizer(\n     
       config, compute_dtype=layer.variable_dtype\n        )\n\n        # Explicitly handle each supported layer type\n        if isinstance(layer, Dense) or (\n            isinstance(layer, EinsumDense) and layer.kernel.ndim == 2\n        ):\n            # For a standard Dense layer, the dimensions are straightforward.\n            self.kernel_shape = layer.kernel.shape\n            # rows: [input_features]\n            self.rows = self.kernel_shape[0]\n            # columns: [output_features]\n            self.columns = self.kernel_shape[1]\n            self.layer = layer\n\n        # Handle 3D EinsumDense layers (typically from attention blocks).\n        elif isinstance(layer, EinsumDense) and layer.kernel.ndim == 3:\n            # For EinsumDense, we determine the effective 2D dimensions.\n            self.kernel_shape = layer.kernel.shape\n            shape = list(self.kernel_shape)\n            d_model_dim_index = shape.index(max(shape))\n\n            if d_model_dim_index == 0:  # QKV projection case\n                in_features, heads, head_dim = shape\n                self.rows, self.columns = (\n                    in_features,\n                    ops.multiply(heads, head_dim),\n                )\n            elif d_model_dim_index in [1, 2]:  # Attention Output case\n                heads, head_dim, out_features = shape\n                self.rows, self.columns = (\n                    ops.multiply(heads, head_dim),\n                    out_features,\n                )\n\n            # Create a temporary object that holds a reshaped\n            # 2D version of the kernel.\n            self.layer = types.SimpleNamespace(\n                kernel=ops.reshape(layer.kernel, (self.rows, self.columns)),\n            )\n        else:\n            # Raise an error if the layer is not supported.\n            raise TypeError(f\"Unsupported layer type for GPTQ: {type(layer)}\")\n        self.hessian = ops.zeros((self.rows, self.rows), dtype=\"float32\")\n\n    def 
update_hessian_with_batch(self, input_batch):\n        \"\"\"\n        Updates the running average of the Hessian matrix with a new batch.\n\n        This method computes the Hessian matrix for a given batch of input\n        activations and updates the accumulated Hessian (`self.hessian`) using a\n        numerically stable running average. This allows the Hessian to be\n        computed over a large dataset without loading all samples into memory\n        at once.\n\n        The input tensor is first reshaped into a 2D matrix [num_samples,\n        num_features] before the Hessian is calculated.\n\n        Args:\n            input_batch: A 2D or higher-dimensional tensor of input activations\n                from a calibration batch.\n\n        Raises:\n            ValueError: If the feature dimension of the input tensor\n                `input_batch` does not match the dimensions of the\n                pre-initialized Hessian matrix `self.hessian`.\n        \"\"\"\n        if input_batch is None:\n            raise ValueError(\"Input tensor cannot be None.\")\n\n        if len(input_batch.shape) < 2:\n            raise ValueError(\n                \"Input tensor must have rank >= 2 \"\n                f\"(got rank {len(input_batch.shape)}).\"\n            )\n        if ops.size(input_batch) == 0:\n            raise ValueError(\"Input tensor cannot be empty.\")\n        if len(input_batch.shape) > 2:\n            # [batch, features]\n            input_batch = ops.reshape(input_batch, (-1, input_batch.shape[-1]))\n        x = ops.cast(input_batch, \"float32\")\n\n        num_new_samples = ops.shape(x)[0]\n        num_prev_samples = self.num_samples\n        total_samples = ops.add(num_prev_samples, num_new_samples)\n\n        if ops.shape(self.hessian)[0] != ops.shape(x)[-1]:\n            raise ValueError(\n                f\"Hessian dimensions ({ops.shape(self.hessian)[0]}) do not \"\n                f\"match input features ({ops.shape(x)[-1]}).\"\n            
)\n\n        # gram_matrix: [features, features]\n        gram_matrix = ops.matmul(ops.transpose(x), x)\n        # Ensures numerical stability and symmetry in case of large floating\n        # point activations.\n        gram_matrix = ops.divide(\n            ops.add(gram_matrix, ops.transpose(gram_matrix)), 2.0\n        )\n\n        # Decay previous mean and add current per-sample contribution\n        # (factor 2/N)\n        if self.num_samples > 0:\n            self.hessian = ops.multiply(\n                self.hessian, ops.divide(num_prev_samples, total_samples)\n            )\n\n        self.hessian = ops.add(\n            self.hessian,\n            ops.multiply(ops.divide(2.0, total_samples), gram_matrix),\n        )\n\n        self.num_samples = self.num_samples + num_new_samples\n\n    def quantize_and_correct_layer(\n        self,\n        blocksize=128,\n    ):\n        \"\"\"\n        Performs GPTQ quantization and correction on the layer's weights.\n\n        This method implements the core logic of the \"Optimal Brain\n        Quantization\" (OBQ) method, as applied by GPTQ, to quantize the\n        weights of a single layer. It iteratively quantizes blocks of weights\n        and corrects for the quantization error by updating the remaining\n        weights.\n\n        The algorithm follows these main steps:\n        1.  Initialization: It optionally reorders the weight columns based\n            on activation magnitudes (`activation_order=True`) to protect\n            more salient weights.\n        2.  Hessian Modification: The Hessian matrix, pre-computed from\n            calibration data, is dampened to ensure its invertibility and\n            stability.\n        3.  Iterative Quantization: The function iterates through the\n            weight columns in blocks (`blocksize`). In each iteration, it:\n            a. Quantizes one column.\n            b. Calculates the quantization error.\n            c. 
Updates the remaining weights in the *current* block by\n                distributing the error, using the inverse Hessian.\n        4.  Block-wise Correction: After a block is quantized, the total\n            error from that block is propagated to the *next* block of weights\n            to be processed.\n        5.  Finalization: The quantized weights are reordered back if\n            `activation_order` was used, and the layer's weights are updated.\n        This implementation is based on the official GPTQ paper and repository.\n        For more details, see:\n        - Paper: https://arxiv.org/abs/2210.17323\n        - Original Code: https://github.com/IST-DASLab/gptq\n\n\n        Args:\n            blocksize: (int, optional) The size of the weight block to process\n             at a time. Defaults to 128.\n        \"\"\"\n        weights_matrix = ops.transpose(self.layer.kernel)\n\n        # Dampen the Hessian for Stability\n        hessian_diagonal = ops.diagonal(self.hessian)\n        dead_diagonal = ops.equal(hessian_diagonal, 0.0)\n        hessian_diagonal = ops.where(dead_diagonal, 1.0, hessian_diagonal)\n        hessian_matrix = ops.add(\n            self.hessian,\n            ops.diag(\n                ops.where(dead_diagonal, 1.0, ops.zeros_like(hessian_diagonal))\n            ),\n        )\n\n        # Add dampening factor to the Hessian diagonal\n        damping_factor = ops.multiply(\n            self.config.hessian_damping, ops.mean(hessian_diagonal)\n        )\n        hessian_diagonal = ops.add(hessian_diagonal, damping_factor)\n        hessian_matrix = ops.add(\n            ops.subtract(\n                hessian_matrix, ops.diag(ops.diagonal(hessian_matrix))\n            ),\n            ops.diag(hessian_diagonal),\n        )\n\n        # Compute the inverse Hessian, which is used for error correction\n        inverse_hessian = linalg.inv(hessian_matrix)\n\n        quantized, scale, zero, g_idx = gptq_quantize_matrix(\n            
weights_matrix,\n            inv_hessian=inverse_hessian,\n            blocksize=blocksize,\n            group_size=self.config.group_size,\n            activation_order=self.config.activation_order,\n            order_metric=ops.diagonal(hessian_matrix),\n            compute_scale_zero=self.quantizer.find_params,\n        )\n        quantized = ops.cast(\n            quantized, self.original_layer.quantized_kernel.dtype\n        )\n\n        if self.config.weight_bits == 4:\n            # For 4-bit weights, we need to pack them into bytes\n            quantized, _, _ = quantizers.pack_int4(\n                quantized, axis=0, dtype=\"uint8\"\n            )\n\n        del self.original_layer._kernel\n        self.original_layer.quantized_kernel.assign(quantized)\n        self.original_layer.kernel_scale.assign(scale)\n        self.original_layer.kernel_zero.assign(zero)\n        self.original_layer.g_idx.assign(g_idx)\n        self.original_layer.is_gptq_calibrated = True\n\n    def free(self):\n        del self.hessian\n        del self.layer\n"
  },
  {
    "path": "keras/src/quantizers/gptq_config.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.quantizers.quantization_config import QuantizationConfig\n\n\n@keras_export(\"keras.quantizers.GPTQConfig\")\nclass GPTQConfig(QuantizationConfig):\n    \"\"\"Configuration class for the GPTQ post-training quantization\n    algorithm.\n\n    GPTQ is a post-training quantization method that quantizes neural network\n    weights to lower precision (e.g., 4-bit) while minimizing the impact on\n    model accuracy. It works by analyzing the Hessian matrix of the loss\n    function with respect to the weights and applying optimal quantization\n    that preserves the most important weight values.\n\n    **When to use GPTQ:**\n    - You want to reduce model size and memory usage\n    - You need faster inference on hardware that supports low-precision\n      operations\n    - You want to maintain model accuracy as much as possible\n    - You have a pre-trained model that you want to quantize without\n      retraining\n\n    **How it works:**\n    1. Uses calibration data to compute the Hessian matrix for each layer\n    2. Applies iterative quantization with error correction\n    3. Reorders weights based on activation importance (optional)\n    4. 
Quantizes weights while minimizing quantization error\n\n    **Example usage:**\n    ```python\n    from keras.quantizers import GPTQConfig\n    from keras import Model\n\n    # Create configuration for 4-bit quantization\n    config = GPTQConfig(\n        dataset=calibration_data,          # Your calibration dataset\n        tokenizer=your_tokenizer,          # Tokenizer for text data\n        weight_bits=4,                     # Quantize to 4 bits\n        num_samples=128,                   # Number of calibration samples\n        sequence_length=512,               # Sequence length for each sample\n        hessian_damping=0.01,             # Hessian stabilization factor\n        group_size=128,                    # Weight grouping for quantization\n        symmetric=False,                   # Use asymmetric quantization\n        activation_order=True              # Reorder weights by importance\n    )\n\n    # Apply quantization to your model\n    model = Model(...)  # Your pre-trained model\n    model.quantize(\"gptq\", config=config)\n\n    # The model now has quantized weights and can be used for inference\n    ```\n\n    **Benefits:**\n    - **Memory reduction**: 4-bit quantization reduces memory by ~8x compared\n      to float32\n    - **Faster inference**: Lower precision operations are faster on supported\n      hardware\n    - **Accuracy preservation**: Minimizes accuracy loss through optimal\n      quantization\n    - **No retraining required**: Works with pre-trained models\n\n    **Advanced usage examples:**\n\n    **Per-channel quantization (recommended for most cases):**\n    ```python\n    config = GPTQConfig(\n        dataset=calibration_data,\n        tokenizer=tokenizer,\n        weight_bits=4,\n        group_size=-1,  # -1 enables per-channel quantization\n        symmetric=False\n    )\n    ```\n\n    **Grouped quantization (for specific hardware requirements):**\n    ```python\n    config = GPTQConfig(\n        dataset=calibration_data,\n     
   tokenizer=tokenizer,\n        weight_bits=4,\n        group_size=64,  # 64 weights share the same scale factor\n        symmetric=True   # Use symmetric quantization\n    )\n    ```\n\n    **High-accuracy quantization with activation ordering:**\n    ```python\n    config = GPTQConfig(\n        dataset=calibration_data,\n        tokenizer=tokenizer,\n        weight_bits=4,\n        activation_order=True,  # Reorder weights by importance\n        hessian_damping=0.005,  # Lower damping for more precise\n        # quantization\n        num_samples=256          # More samples for better accuracy\n    )\n    ```\n\n    **References:**\n    - Original GPTQ paper: \"GPTQ: Accurate Post-Training Quantization\n      for Generative Pre-trained Transformers\"\n    - Implementation based on: https://github.com/IST-DASLab/gptq\n    - Suitable for: Transformer models, large language models, and other\n      deep neural networks\n\n    **Note:** The quality of quantization depends heavily on the calibration\n    dataset. Use representative data that covers the expected input\n    distribution for best results.\n\n    Args:\n        dataset: The calibration dataset. It can be an iterable that yields\n            strings or pre-tokenized numerical tensors (e.g., a list of\n            strings, a generator, or a NumPy array). This data is used to\n            analyze the model's activations.\n        tokenizer: A `keras_nlp.Tokenizer` instance (or a similar callable)\n            that is used to process the `dataset` if it contains strings.\n        weight_bits: (int, optional) The number of bits to quantize weights to.\n            Defaults to 4.\n        num_samples: (int, optional) The number of calibration data samples to\n            use from the dataset. Defaults to 128.\n        sequence_length: (int, optional) The sequence length to use for each\n            calibration sample. 
Defaults to 512.\n        hessian_damping: (float, optional) The fraction of the mean Hessian\n            diagonal added as damping to stabilize the inverse calculation.\n            Defaults to 0.01.\n        group_size: (int, optional) The size of weight groups to quantize\n            together. A `group_size` of -1 indicates per-channel quantization.\n            Defaults to 128.\n        symmetric: (bool, optional) If `True`, uses symmetric quantization.\n            If `False`, uses asymmetric quantization. Defaults to `False`.\n        activation_order: (bool, optional) If `True`, reorders weight columns\n            based on activation magnitude, which can improve quantization\n            accuracy. Defaults to `False`.\n        quantization_layer_structure: (dict, optional) A dictionary defining the\n            model's quantization structure. It should contain:\n            - \"pre_block_layers\": list of layers to run before the first block.\n            - \"sequential_blocks\": list of blocks to be quantized sequentially.\n            If not provided, the model must implement\n            `get_quantization_layer_structure`.\n    \"\"\"\n\n    def __init__(\n        self,\n        dataset,\n        tokenizer,\n        *,\n        weight_bits: int = 4,\n        num_samples: int = 128,\n        per_channel: bool = True,\n        sequence_length: int = 512,\n        hessian_damping: float = 0.01,\n        group_size: int = 128,\n        symmetric: bool = False,\n        activation_order: bool = False,\n        quantization_layer_structure: dict = None,\n    ):\n        super().__init__()\n        if weight_bits not in [2, 3, 4, 8]:\n            raise ValueError(\n                f\"Unsupported weight_bits {weight_bits}. 
\"\n                \"Supported values are 2, 3, 4, and 8.\"\n            )\n        if num_samples <= 0:\n            raise ValueError(\"num_samples must be a positive integer.\")\n        if sequence_length <= 0:\n            raise ValueError(\"sequence_length must be a positive integer.\")\n        if hessian_damping < 0 or hessian_damping > 1:\n            raise ValueError(\"hessian_damping must be between 0 and 1.\")\n        if group_size < -1 or group_size == 0:\n            raise ValueError(\n                \"Invalid group_size. Supported values are -1 (per-channel) \"\n                \"or a positive integer, \"\n                f\"but got {group_size}.\"\n            )\n        self.dataset = dataset\n        self.tokenizer = tokenizer\n        self.num_samples = num_samples\n        self.per_channel = per_channel\n        self.sequence_length = sequence_length\n        self.hessian_damping = hessian_damping\n        self.weight_bits = weight_bits\n        self.group_size = group_size\n        self.symmetric = symmetric\n        self.activation_order = activation_order\n        self.quantization_layer_structure = quantization_layer_structure\n\n    def get_config(self):\n        return {\n            # Dataset and Tokenizer are only required for a one-time\n            # calibration and are not saved in the config.\n            \"dataset\": None,\n            \"tokenizer\": None,\n            \"weight_bits\": self.weight_bits,\n            \"num_samples\": self.num_samples,\n            \"per_channel\": self.per_channel,\n            \"sequence_length\": self.sequence_length,\n            \"hessian_damping\": self.hessian_damping,\n            \"group_size\": self.group_size,\n            \"symmetric\": self.symmetric,\n            \"activation_order\": self.activation_order,\n            \"quantization_layer_structure\": self.quantization_layer_structure,\n        }\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n\n    @property\n    def mode(self):\n        return \"gptq\"\n\n    def dtype_policy_string(self):\n        \"\"\"Returns the dtype policy string for this configuration.\n\n        Returns:\n            A string representing the dtype policy, e.g. \"gptq/4/128\".\n        \"\"\"\n        return f\"gptq/{self.weight_bits}/{self.group_size}\"\n"
  },
  {
    "path": "keras/src/quantizers/gptq_config_test.py",
    "content": "from keras.src import testing\nfrom keras.src.quantizers.gptq_config import GPTQConfig\n\n\nclass TestGPTQConfig(testing.TestCase):\n    def test_invalid_weight_bits(self):\n        with self.assertRaisesRegex(ValueError, \"Unsupported weight_bits\"):\n            GPTQConfig(dataset=None, tokenizer=None, weight_bits=1)\n        with self.assertRaisesRegex(ValueError, \"Unsupported weight_bits\"):\n            GPTQConfig(dataset=None, tokenizer=None, weight_bits=5)\n\n    def test_invalid_num_samples(self):\n        with self.assertRaisesRegex(\n            ValueError, \"num_samples must be a positive\"\n        ):\n            GPTQConfig(dataset=None, tokenizer=None, num_samples=0)\n        with self.assertRaisesRegex(\n            ValueError, \"num_samples must be a positive\"\n        ):\n            GPTQConfig(dataset=None, tokenizer=None, num_samples=-1)\n\n    def test_invalid_sequence_length(self):\n        with self.assertRaisesRegex(\n            ValueError, \"sequence_length must be a positive\"\n        ):\n            GPTQConfig(dataset=None, tokenizer=None, sequence_length=0)\n        with self.assertRaisesRegex(\n            ValueError, \"sequence_length must be a positive\"\n        ):\n            GPTQConfig(dataset=None, tokenizer=None, sequence_length=-10)\n\n    def test_invalid_hessian_damping(self):\n        with self.assertRaisesRegex(\n            ValueError, \"hessian_damping must be between\"\n        ):\n            GPTQConfig(dataset=None, tokenizer=None, hessian_damping=-0.1)\n        with self.assertRaisesRegex(\n            ValueError, \"hessian_damping must be between\"\n        ):\n            GPTQConfig(dataset=None, tokenizer=None, hessian_damping=1.1)\n\n    def test_invalid_group_size(self):\n        with self.assertRaisesRegex(ValueError, \"Invalid group_size\"):\n            GPTQConfig(dataset=None, tokenizer=None, group_size=0)\n        with self.assertRaisesRegex(ValueError, \"Invalid group_size\"):\n           
 GPTQConfig(dataset=None, tokenizer=None, group_size=-2)\n\n    def test_dtype_policy_string(self):\n        config = GPTQConfig(\n            dataset=None, tokenizer=None, weight_bits=4, group_size=64\n        )\n        self.assertEqual(config.dtype_policy_string(), \"gptq/4/64\")\n\n    def test_gptq_config_serialization(self):\n        config = GPTQConfig(\n            dataset=None, tokenizer=None, weight_bits=4, group_size=64\n        )\n        serialized_config = config.get_config()\n        deserialized_config = GPTQConfig.from_config(serialized_config)\n        self.assertDictEqual(config.__dict__, deserialized_config.__dict__)\n"
  },
  {
    "path": "keras/src/quantizers/gptq_core.py",
    "content": "import math\nfrom contextlib import contextmanager\n\nimport numpy as np\nfrom absl import logging\n\nfrom keras.src import ops\nfrom keras.src import utils as keras_utils\nfrom keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy\nfrom keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap\nfrom keras.src.layers import Dense\nfrom keras.src.layers import EinsumDense\nfrom keras.src.quantizers.gptq import GPTQ\nfrom keras.src.quantizers.gptq_config import GPTQConfig\nfrom keras.src.quantizers.utils import should_quantize_layer\n\n\n@contextmanager\ndef stream_hessians(layers_map, gptq_objects):\n    \"\"\"\n    Temporarily monkey-patch each target layer's `call` method so\n    that input activations are streamed into the GPTQ instance's\n    running Hessian estimate at capture time.\n\n    On `__enter__`: For every (name, layer) in `layers_map`, replaces\n     `layer.call` with a wrapper that:\n     1) extracts the layer input from `*args`/`**kwargs`,\n     2) reshapes it to 2D `[-1, rows]` where\n      `rows = gptq_objects[name].rows`,\n     3) calls `gptq_objects[name].update_hessian_with_batch(x2d)`, and\n     4) delegates to the original `layer.call` and returns its\n      output.\n\n    On `__exit__`: All original `layer.call` methods are restored even if an\n     exception occurs.\n\n    * Space complexity: O(d**2) per layer (for the Hessian).\n    * No weights are modified; only GPTQ statistics are updated.\n\n    Args:\n        layers_map: Dict[str, Layer]. Mapping from logical layer names to\n         the Keras layers that should be patched during calibration. Keys must\n         match `gptq_objects`.\n        gptq_objects: Dict[str, GPTQ]. Mapping from names to GPTQ instances.\n\n    Yields:\n        None: The patched state is active only within the `with` block. After\n         exit, all layers are unpatched and safe to use normally.\n\n    Example:\n    ```python\n    >>> with stream_hessians(layers_map, gptq_objects):\n    ... 
    for sample in calibration_inputs:\n    ...         if len(sample.shape) == 2:\n    ...             sample = ops.expand_dims(sample, 0)\n    ...         _ = block(sample)   # hooks update Hessians on-the-fly\n    >>> # <- original layer.call methods restored here\n    ```\n    \"\"\"\n    original_calls = {}\n\n    def create_hook(name, original_call_func):\n        def hook(*args, **kwargs):\n            inp = args[0] if args else kwargs[\"inputs\"]\n            # Explicitly reshape the input tensor to be 2D, with the\n            # second dimension matching the number of input features\n            # expected by the layer's kernel.\n            # This correctly handles inputs of any dimensionality\n            # (e.g., 3D or 4D).\n            num_features = gptq_objects[name].rows\n            input_2d = ops.reshape(inp, (-1, num_features))\n            gptq_objects[name].update_hessian_with_batch(input_2d)\n            return original_call_func(*args, **kwargs)\n\n        return hook\n\n    try:\n        for name, layer in layers_map.items():\n            original_calls[name] = layer.call\n            layer.call = create_hook(name, layer.call)\n        yield\n    finally:\n        for name, layer in layers_map.items():\n            layer.call = original_calls[name]\n\n\ndef get_dataloader(\n    tokenizer,\n    sequence_length,\n    dataset,\n    num_samples=128,\n    *,\n    strategy=\"strided\",\n    seed=42,\n    stride=None,\n    eos_id=None,\n):\n    \"\"\"\n    Prepares and chunks the calibration dataloader, repeating short datasets.\n    All processing happens on the CPU.\n\n    Args:\n        tokenizer: The tokenizer to use for text splitting.\n        sequence_length: The length of each input sequence.\n        dataset: The dataset to sample from.\n        num_samples: The number of samples to generate.\n        strategy: The sampling strategy to use. Possible values are\n         1. \"strided\": Samples are taken at regular intervals.\n         2. 
\"linspace\": Samples are taken at evenly spaced intervals.\n         3. \"random\": Samples are taken at random positions.\n        seed: The random seed for reproducibility. Used only if\n         strategy=\"random\"\n        stride: The stride length for \"strided\" sampling.\n        eos_id: The end-of-sequence token ID.\n\n    Returns:\n        np.ndarray of shape (num_samples, 1, sequence_length), dtype int32.\n    \"\"\"\n    if not hasattr(dataset, \"__iter__\") or isinstance(dataset, (str, bytes)):\n        raise TypeError(\n            \"The `dataset` argument must be an iterable (e.g., a list of \"\n            \"strings, a generator, or a NumPy array). Got type: \"\n            f\"{type(dataset).__name__}. Please pass the loaded dataset \"\n            \"directly.\"\n        )\n\n    dataset_list = list(dataset)\n    if not dataset_list:\n        raise ValueError(\"Provided dataset is empty.\")\n\n    pieces = []\n    if isinstance(dataset_list[0], str):\n        for i, s in enumerate(dataset_list):\n            toks = ops.convert_to_numpy(tokenizer.tokenize(s)).reshape(-1)\n            pieces.append(toks)\n            # avoid windows that span document boundaries\n            if eos_id is not None and i < len(dataset_list) - 1:\n                pieces.append(np.array([eos_id], dtype=np.int32))\n    else:\n        for s in dataset_list:\n            toks = ops.convert_to_numpy(s).reshape(-1)\n            pieces.append(toks.astype(np.int32, copy=False))\n\n    all_tokens = (\n        pieces[0].astype(np.int32, copy=False)\n        if len(pieces) == 1\n        else np.concatenate(pieces, axis=0).astype(np.int32, copy=False)\n    )\n\n    required_tokens = num_samples * sequence_length\n    if all_tokens.size < required_tokens:\n        repeats = math.ceil(required_tokens / max(1, all_tokens.size))\n        all_tokens = np.tile(all_tokens, repeats)\n\n    max_start = all_tokens.size - sequence_length\n    if max_start < 0:\n        raise ValueError(\n      
      f\"Not enough tokens to form one sample of length {sequence_length} \"\n            f\"(have {all_tokens.size}).\"\n        )\n\n    # Choose deterministic, well-spread starts by default\n    if strategy == \"random\":\n        rng = np.random.default_rng(seed)\n        starts = rng.integers(\n            0, max_start + 1, size=num_samples, dtype=np.int64\n        )\n    elif strategy == \"linspace\":\n        # even coverage with no RNG\n        starts = np.linspace(0, max_start, num_samples, dtype=np.int64)\n    elif strategy == \"strided\":\n        # stride chosen to cover the space roughly uniformly\n        if stride is None:\n            stride = max(1, (max_start + 1) // num_samples)\n        # offset derived deterministically from seed (a seeded RNG is\n        # stable across processes, unlike `hash`, which is randomized)\n        offset = (\n            int(np.random.default_rng(seed).integers(0, max_start + 1))\n            if max_start > 0\n            else 0\n        )\n        starts = (offset + np.arange(num_samples, dtype=np.int64) * stride) % (\n            max_start + 1\n        )\n    else:\n        raise ValueError(f\"Unknown strategy: {strategy}\")\n\n    # Gather contiguous windows\n    # sliding_window_view avoids building a big index matrix\n    windows = np.lib.stride_tricks.sliding_window_view(\n        all_tokens, sequence_length\n    )\n    samples = windows[starts]  # (num_samples, sequence_length)\n    return samples.astype(np.int32)[:, None, :]\n\n\ndef find_layers_in_block(block):\n    \"\"\"\n    Finds all Dense and EinsumDense layers in a transformer block.\n\n    Args:\n        block: A Keras layer representing a transformer block.\n    Returns:\n        A dict mapping layer paths to the corresponding Dense or EinsumDense\n        layers.\n    \"\"\"\n    found_layers = {}\n    for sub_layer in block._flatten_layers():\n        if len(list(sub_layer._flatten_layers())) == 1:\n            if isinstance(sub_layer, (Dense, EinsumDense)):\n                found_layers[sub_layer.path] = sub_layer\n    return found_layers\n\n\ndef 
apply_gptq_layerwise(dataloader, config, structure, filters=None):\n    \"\"\"Applies GPTQ quantization layer-by-layer to a Keras model.\n\n    This function uses the provided `structure` to identify pre-quantization\n    layers and sequential blocks.\n\n    The core logic operates as follows:\n\n    1.  It processes the model sequentially, one block at a time. For each\n        block, it uses temporary hooks to capture the input activations of\n        each target layer during a forward pass with the calibration data.\n    2.  These captured activations are used to compute the Hessian matrix for\n        each layer's weights.\n    3.  The GPTQ algorithm is then applied to each layer to find the optimal\n        quantized weights that minimize the error introduced.\n    4.  The output activations from the current block are then used as the\n        input for the next block, ensuring that quantization errors are\n        accounted for throughout the model.\n\n    Args:\n        dataloader: An iterable providing calibration data.\n        config: A `GPTQConfig` object.\n        structure: A dictionary with keys \"pre_block_layers\" and\n            \"sequential_blocks\".\n        filters: Optional filters to exclude layers from quantization.\n\n    Raises:\n        ValueError: If the provided structure contains no sequential\n            blocks to quantize.\n    \"\"\"\n\n    num_samples = config.num_samples\n\n    logging.info(\"Starting model quantization...\")\n\n    pre_layers = structure.get(\"pre_block_layers\", [])\n    transformer_blocks = structure.get(\"sequential_blocks\", [])\n\n    if not transformer_blocks:\n        raise ValueError(\n            \"No sequential blocks found in the provided structure to quantize.\"\n        )\n\n    # Initial inputs are the outputs of the pre-block layers\n    inputs = []\n    for batch in dataloader:\n        batch = ops.convert_to_tensor(batch, 
dtype=\"int32\")\n        for layer in pre_layers:\n            batch = layer(batch)\n        inputs.append(batch)\n\n    num_samples = min(num_samples, len(inputs))\n\n    progbar = keras_utils.Progbar(target=len(transformer_blocks))\n\n    for block_idx, block in enumerate(transformer_blocks):\n        logging.info(f\"Quantizing Block {block_idx}\")\n        sub_layers_map = find_layers_in_block(block)\n\n        # Filter out layers that are not quantized with GPTQ\n        final_sub_layers_map = {}\n        for name, layer in sub_layers_map.items():\n            if not should_quantize_layer(layer, filters):\n                continue\n\n            final_sub_layers_map[name] = layer\n\n        sub_layers_map = final_sub_layers_map\n\n        if not sub_layers_map:\n            logging.info(\n                f\"  No quantizable layers found in block {block_idx}. Skipping.\"\n            )\n        else:\n            logging.info(f\"Found layers: {list(sub_layers_map.keys())}\")\n            gptq_objects = {\n                name: GPTQ(layer, config)\n                for name, layer in sub_layers_map.items()\n            }\n\n            with stream_hessians(sub_layers_map, gptq_objects):\n                for sample_idx in range(num_samples):\n                    current_input = inputs[sample_idx]\n                    if len(current_input.shape) == 2:\n                        current_input = ops.expand_dims(current_input, axis=0)\n                    _ = block(current_input)\n\n            for name, gptq_object in gptq_objects.items():\n                logging.info(f\"Quantizing {name}...\")\n                gptq_object.quantize_and_correct_layer()\n                gptq_object.free()\n\n            del gptq_objects\n\n        if block_idx < len(transformer_blocks) - 1:\n            logging.info(f\"Generating inputs for block {block_idx + 1}...\")\n            next_block_inputs = []\n            for sample_idx in range(num_samples):\n                current_input = 
inputs[sample_idx]\n                if len(current_input.shape) == 2:\n                    current_input = ops.expand_dims(current_input, axis=0)\n                output = block(current_input)[0]\n                next_block_inputs.append(output)\n            inputs = next_block_inputs\n        progbar.update(current=block_idx + 1)\n\n    logging.info(\"Quantization process complete.\")\n\n\ndef gptq_quantize(config, quantization_layer_structure, filters=None):\n    \"\"\"\n    Quantizes the model using GPTQ.\n\n    Args:\n        config: The GPTQ configuration.\n        quantization_layer_structure: A dictionary describing the model's layer\n        structure for quantization.\n        filters: Optional filters to exclude layers from quantization.\n    \"\"\"\n    if config.dataset is None or config.tokenizer is None:\n        raise ValueError(\n            \"GPTQ quantization requires a dataset and a tokenizer. \"\n            \"Please provide them in the `GPTQConfig`.\"\n        )\n\n    if quantization_layer_structure is None:\n        raise ValueError(\n            \"For 'gptq' mode, a valid quantization structure must be provided \"\n            \"either via `config.quantization_layer_structure` or by overriding \"\n            \"`model.get_quantization_layer_structure(mode)`. The structure \"\n            \"should be a dictionary with keys 'pre_block_layers' and \"\n            \"'sequential_blocks'.\"\n        )\n\n    # Load all data needed from the generator/source in a single call.\n    total_samples_to_request = config.num_samples\n    dataloader = get_dataloader(\n        config.tokenizer,\n        config.sequence_length,\n        config.dataset,\n        num_samples=total_samples_to_request,\n    )\n\n    # Split the materialized data. 
This works because dataloader\n    # is now a NumPy array, which can be sliced and reused.\n    calibration_dataloader = dataloader[: config.num_samples]\n\n    apply_gptq_layerwise(\n        calibration_dataloader,\n        config,\n        quantization_layer_structure,\n        filters=filters,\n    )\n\n\ndef get_group_size_for_layer(layer, config):\n    \"\"\"Determine the group size for GPTQ quantization.\n\n    The group size can be specified either through the `config` argument\n    or through the `dtype_policy` if it is of type `GPTQDTypePolicy`.\n\n    The config argument is usually available when quantizing the layer\n    via the `quantize` method. If the layer was deserialized from a\n    saved model, the group size should be specified in the `dtype_policy`.\n\n    Args:\n        layer: The layer for which to determine the group size.\n        config: An optional configuration object that may contain the\n            `group_size` attribute.\n    Returns:\n        int. The determined group size for GPTQ quantization.\n    Raises:\n        ValueError: If the group size is not specified in either the\n            `config` or the `dtype_policy`.\n    \"\"\"\n    if config and isinstance(config, GPTQConfig):\n        return config.group_size\n    elif isinstance(layer.dtype_policy, GPTQDTypePolicy):\n        return layer.dtype_policy.group_size\n    elif isinstance(layer.dtype_policy, DTypePolicyMap):\n        policy = layer.dtype_policy[layer.path]\n        if not isinstance(policy, GPTQDTypePolicy):\n            # This should never happen based on how we set the\n            # quantization mode, but we check just in case.\n            raise ValueError(\n                \"Expected a `dtype_policy` of type `GPTQDTypePolicy`. \"\n                f\"Got: {type(policy)}\"\n            )\n        return policy.group_size\n    else:\n        raise ValueError(\n            \"For GPTQ quantization, the group_size must be specified \"\n            \"either through a `dtype_policy` of type \"\n            \"`GPTQDTypePolicy` or the 
`config` argument.\"\n        )\n\n\ndef get_weight_bits_for_layer(layer, config):\n    \"\"\"Determine the number of weight bits for GPTQ quantization.\n\n    The number of weight bits can be specified either through the `config`\n    argument or through the `dtype_policy` if it is of type\n    `GPTQDTypePolicy`.\n\n    The config argument is usually available when quantizing the layer\n    via the `quantize` method. If the layer was deserialized from a\n    saved model, the weight bits should be specified in the `dtype_policy`.\n\n    Args:\n        layer: The layer for which to determine the weight bits.\n        config: An optional configuration object that may contain the\n            `weight_bits` attribute.\n    Returns:\n        int. The determined number of weight bits for GPTQ quantization.\n    Raises:\n        ValueError: If the weight bits are not specified in either the\n            `config` or the `dtype_policy`.\n    \"\"\"\n    if config and isinstance(config, GPTQConfig):\n        return config.weight_bits\n    elif isinstance(layer.dtype_policy, GPTQDTypePolicy):\n        return layer.dtype_policy.weight_bits\n    elif isinstance(layer.dtype_policy, DTypePolicyMap):\n        policy = layer.dtype_policy[layer.path]\n        if not isinstance(policy, GPTQDTypePolicy):\n            # This should never happen based on how we set the\n            # quantization mode, but we check just in case.\n            raise ValueError(\n                \"Expected a `dtype_policy` of type `GPTQDTypePolicy`. \"\n                f\"Got: {type(policy)}\"\n            )\n        return policy.weight_bits\n    else:\n        raise ValueError(\n            \"For GPTQ quantization, the weight_bits must be specified \"\n            \"either through a `dtype_policy` of type \"\n            \"`GPTQDTypePolicy` or the `config` argument.\"\n        )\n"
  },
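The start-index selection inside `get_dataloader` above can be sketched in isolation. The helper below is a simplified, hypothetical standalone function (`sample_start_indices` and its parameters are illustrative names): the real function also tokenizes, inserts EOS markers, and tiles short datasets, and its "strided" strategy additionally applies a seed-derived offset, which is omitted here.

```python
import numpy as np


def sample_start_indices(num_tokens, sequence_length, num_samples,
                         strategy="linspace", seed=42, stride=None):
    """Pick window start positions over a flattened token stream.

    Simplified sketch of the strategy logic in `get_dataloader`.
    """
    max_start = num_tokens - sequence_length
    if max_start < 0:
        raise ValueError("Not enough tokens to form one sample.")
    if strategy == "random":
        # Reproducible random starts for a given seed.
        rng = np.random.default_rng(seed)
        return rng.integers(0, max_start + 1, size=num_samples, dtype=np.int64)
    if strategy == "linspace":
        # Deterministic, evenly spread coverage with no RNG.
        return np.linspace(0, max_start, num_samples, dtype=np.int64)
    if strategy == "strided":
        # Stride chosen so the windows roughly tile the valid range.
        if stride is None:
            stride = max(1, (max_start + 1) // num_samples)
        steps = np.arange(num_samples, dtype=np.int64) * stride
        return steps % (max_start + 1)
    raise ValueError(f"Unknown strategy: {strategy}")
```

For example, with 100 tokens, `sequence_length=10`, and 5 samples, "linspace" yields starts `[0, 22, 45, 67, 90]`, so the sampled windows span the whole stream without any randomness.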
  {
    "path": "keras/src/quantizers/gptq_core_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.quantizers.gptq_config import GPTQConfig\nfrom keras.src.quantizers.gptq_core import get_dataloader\nfrom keras.src.quantizers.gptq_core import gptq_quantize\n\nVOCAB_SIZE = 100\n\n\nclass MockTokenizer:\n    \"\"\"A mock tokenizer that mimics the real API for testing.\"\"\"\n\n    def tokenize(self, text):\n        return [ord(c) % VOCAB_SIZE for c in \"\".join(text)]\n\n    def __call__(self, text):\n        return self.tokenize(text)\n\n\nclass EmptyBlock(layers.Layer):\n    \"\"\"A block that contains no quantizable layers.\"\"\"\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.ln = layers.LayerNormalization()\n\n    def call(self, inputs):\n        return self.ln(inputs)\n\n\nclass TransformerBlock(layers.Layer):\n    \"\"\"A toy transformer block with a quantizable Dense layer.\"\"\"\n\n    def __init__(self, **kwargs):\n        super().__init__(**kwargs)\n        self.dense = layers.Dense(128)\n\n    def call(self, inputs):\n        return self.dense(inputs)\n\n\ndef _get_model_with_backbone(\n    has_transformer_layers=True, embedding_name=\"embedding\"\n):\n    \"\"\"Creates a KerasHub-style model with a backbone.\"\"\"\n\n    class Backbone(layers.Layer):\n        def __init__(self, vocab_size, embedding_dim=128, **kwargs):\n            super().__init__(**kwargs)\n            # Use direct assignment\n            setattr(\n                self,\n                embedding_name,\n                layers.Embedding(vocab_size, embedding_dim),\n            )\n\n            # Keep track of layers in a list for the call method\n            self.transformer_layers = []\n            if has_transformer_layers:\n                self.transformer_layers.append(TransformerBlock())\n\n        def call(self, 
inputs):\n            x = getattr(self, embedding_name)(inputs)\n            for layer in self.transformer_layers:\n                x = layer(x)\n            return x\n\n    class Model(models.Model):\n        def __init__(self, vocab_size, **kwargs):\n            super().__init__(**kwargs)\n            # Pass configuration directly\n            self.backbone = Backbone(vocab_size=vocab_size)\n            self.classifier = layers.Dense(1, activation=\"sigmoid\")\n\n        def call(self, inputs):\n            x = self.backbone(inputs)\n            x = layers.GlobalAveragePooling1D()(x)\n            return self.classifier(x)\n\n    model = Model(vocab_size=VOCAB_SIZE)\n    rng = np.random.default_rng(seed=42)\n    # Embedding inputs must be integer token ids in [0, VOCAB_SIZE).\n    dummy_input = rng.integers(0, VOCAB_SIZE, size=(2, 64)).astype(np.int32)\n\n    _ = model(dummy_input)\n    return model\n\n\ndef build_all_tokens_strings(dataset, tokenizer, eos_id=None):\n    pieces = []\n    for i, s in enumerate(dataset):\n        toks = np.asarray(tokenizer.tokenize(s), dtype=np.int32).reshape(-1)\n        pieces.append(toks)\n        if eos_id is not None and i < len(dataset) - 1:\n            pieces.append(np.array([eos_id], dtype=np.int32))\n    return np.concatenate(pieces, axis=0).astype(np.int32, copy=False)\n\n\ndef sliding_windows(x, L):\n    return np.lib.stride_tricks.sliding_window_view(x, L)\n\n\n@pytest.mark.requires_trainable_backend\nclass TestGPTQCore(testing.TestCase):\n    @parameterized.named_parameters(\n        [(\"strided\", \"strided\"), (\"linspace\", \"linspace\"), (\"random\", \"random\")]\n    )\n    def test_shape_and_dtype_strings(self, strategy):\n        \"\"\"Test the shape and dtype of the output for string inputs.\"\"\"\n        tok = MockTokenizer()\n        dataset = [\"a b c d e f g\", \"h i j k\"]\n        seq_len, n = 5, 7\n\n        out = get_dataloader(\n            tok, seq_len, dataset, num_samples=n, strategy=strategy, seed=123\n        )\n        self.assertEqual(out.shape, (n, 1, 
seq_len))\n        self.assertEqual(out.dtype, np.int32)\n\n    @parameterized.named_parameters(\n        [(\"strided\", \"strided\"), (\"linspace\", \"linspace\"), (\"random\", \"random\")]\n    )\n    def test_shape_and_dtype_pretokenized(self, strategy):\n        \"\"\"Test the shape and dtype of the output for pre-tokenized inputs.\"\"\"\n        tok = MockTokenizer()\n        # Pre-tokenized inputs; mixed shapes (1, L) and (L,)\n        seqs = [\n            np.array([[1, 2, 3, 4]], dtype=np.int64),\n            np.array([5, 6], dtype=np.int64),\n        ]\n        seq_len, n = 3, 4\n\n        out = get_dataloader(\n            tok, seq_len, seqs, num_samples=n, strategy=strategy, seed=7\n        )\n        self.assertEqual(out.shape, (n, 1, seq_len))\n        self.assertEqual(out.dtype, np.int32)\n\n    def test_strided_is_deterministic_for_same_args(self):\n        tok = MockTokenizer()\n        dataset = [\"a b c d e\", \"f g h i j k\"]\n        out1 = get_dataloader(\n            tok, 4, dataset, num_samples=6, strategy=\"strided\", seed=99\n        )\n        out2 = get_dataloader(\n            tok, 4, dataset, num_samples=6, strategy=\"strided\", seed=99\n        )\n        self.assertTrue(ops.all(ops.equal(out1, out2)))\n\n    def test_random_reproducibility_by_seed(self):\n        tok = MockTokenizer()\n        dataset = [\"a b c d e\", \"f g h i j k\"]\n        a = get_dataloader(\n            tok, 4, dataset, num_samples=6, strategy=\"random\", seed=123\n        )\n        b = get_dataloader(\n            tok, 4, dataset, num_samples=6, strategy=\"random\", seed=123\n        )\n        c = get_dataloader(\n            tok, 4, dataset, num_samples=6, strategy=\"random\", seed=124\n        )\n        self.assertTrue(ops.all(ops.equal(a, b)))\n        self.assertFalse(ops.all(ops.equal(a, c)))\n\n    def test_linspace_windows_match_expected(self):\n        tok = MockTokenizer()\n        dataset = [\"aa bb cc dd\", \"ee ff 
gg\"]\n        seq_len, n = 3, 5\n        eos_id = None\n\n        all_tokens = build_all_tokens_strings(dataset, tok, eos_id=eos_id)\n        max_start = all_tokens.size - seq_len\n        expected_starts = np.linspace(0, max_start, n, dtype=np.int64)\n\n        expected = sliding_windows(all_tokens, seq_len)[expected_starts]\n        got = get_dataloader(\n            tok, seq_len, dataset, num_samples=n, strategy=\"linspace\"\n        )\n        self.assertTrue(\n            ops.all(ops.equal(got[:, 0, :], expected.astype(np.int32)))\n        )\n\n    def test_strided_override_respected(self):\n        \"\"\"Tests that strided windows are disjoint and cover the input.\"\"\"\n        tok = MockTokenizer()\n        # 20 tokens total\n        # with seq_len=4 and stride=4, we expect disjoint chunks\n        # in order (modulo offset)\n        dataset = [\" \".join([f\"t{i}\" for i in range(20)])]\n        seq_len, n, stride = 4, 5, 4\n\n        out = get_dataloader(\n            tok,\n            seq_len,\n            dataset,\n            num_samples=n,\n            strategy=\"strided\",\n            stride=stride,\n            seed=0,\n        )\n\n        # Validate that each sample is a contiguous run\n        # of length seq_len from the flattened stream\n        flat = build_all_tokens_strings(dataset, tok)\n        for s in out[:, 0, :]:\n            # Each window should appear as a slice in the flat stream\n            # (This is a soft check; exact start positions depend on offset.)\n            joined = \" \".join(map(str, s.tolist()))\n            self.assertIn(joined, \" \".join(map(str, flat.tolist())))\n\n    def test_eos_insertion_is_present_in_some_window_with_linspace(self):\n        tok = MockTokenizer()\n        dataset = [\"aa aa\", \"bb bb\"]  # len = 5 + 1(EOS) + 5 = 11\n        eos = 9999\n        seq_len = 3\n        n = 3\n\n        out = get_dataloader(\n            tok,\n            seq_len,\n            dataset,\n            
num_samples=n,\n            strategy=\"linspace\",\n            eos_id=eos,\n        )\n\n        # linspace starts -> [0, 4, 8]; the middle window [4:7]\n        # includes EOS at 5\n        windows = out[:, 0, :]\n        self.assertTrue(\n            np.any(np.any(windows == eos, axis=1)),\n            \"Expected EOS to appear in at least one sampled window with \"\n            \"linspace.\",\n        )\n\n    def test_get_dataloader_error_scenarios(self):\n        \"\"\"Tests error cases for get_dataloader.\"\"\"\n        with pytest.raises(ValueError, match=\"Provided dataset is empty\"):\n            get_dataloader(\n                tokenizer=MockTokenizer(),\n                sequence_length=10,\n                dataset=[],\n                num_samples=10,\n            )\n        with self.assertRaisesRegex(\n            TypeError,\n            \"The `dataset` argument must be an iterable.*Got type: str.*\"\n            \"Please pass the loaded dataset directly.\",\n        ):\n            get_dataloader(\n                tokenizer=MockTokenizer(),\n                sequence_length=10,\n                dataset=\"wikitext2\",\n                num_samples=10,\n            )\n\n    def test_apply_gptq_on_multi_block_model(self):\n        \"\"\"Tests quantization on a model with multiple blocks.\"\"\"\n        model = models.Sequential(\n            [\n                layers.Embedding(VOCAB_SIZE, 128),\n                TransformerBlock(),\n                TransformerBlock(),\n            ]\n        )\n        model.build(input_shape=(None, 10))\n\n        layer_structure = {\n            \"pre_block_layers\": [model.layers[0]],\n            \"sequential_blocks\": [model.layers[1], model.layers[2]],\n        }\n\n        config = GPTQConfig(\n            dataset=[\"test data\"],\n            tokenizer=MockTokenizer(),\n            group_size=32,\n            quantization_layer_structure=layer_structure,\n        )\n        model.quantize(\"gptq\", 
config=config)\n\n    @parameterized.named_parameters(\n        (\n            \"no_embedding_layer\",\n            models.Sequential([layers.Dense(10)]),\n            \"For 'gptq' mode, a valid quantization structure must be provided\",\n        ),\n        (\n            \"no_transformer_blocks\",\n            models.Sequential(\n                [layers.Embedding(VOCAB_SIZE, 10), layers.Dense(10)]\n            ),\n            \"For 'gptq' mode, a valid quantization structure must be provided\",\n        ),\n        (\n            \"backbone_no_layers\",\n            _get_model_with_backbone(has_transformer_layers=False),\n            \"For 'gptq' mode, a valid quantization structure must be provided\",\n        ),\n        (\n            \"backbone_no_embedding\",\n            _get_model_with_backbone(embedding_name=\"wrong_name\"),\n            \"For 'gptq' mode, a valid quantization structure must be provided\",\n        ),\n    )\n    def test_apply_gptq_with_unsupported_architectures(\n        self, model, error_message\n    ):\n        \"\"\"Tests that quantize fails correctly for various unsupported\n        model architectures.\"\"\"\n        if not model.built:\n            model.build(input_shape=(None, 10))\n\n        config = GPTQConfig(dataset=[\"test\"], tokenizer=MockTokenizer())\n        with self.assertRaisesRegex(ValueError, error_message):\n            # We pass None as structure to trigger the error\n            gptq_quantize(config, quantization_layer_structure=None)\n"
  },
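The streaming Hessian behavior exercised by `test_streaming_equals_big_batch` and `test_hessian_matches_closed_form` in `gptq_test.py` (below) can be sketched as a small accumulator that maintains H = (2/N) XᵀX across batches. `HessianAccumulator` is a hypothetical standalone class, not the real `GPTQ` API; the real `update_hessian_with_batch` additionally validates inputs and handles EinsumDense kernels.

```python
import numpy as np


class HessianAccumulator:
    """Maintain H = (2 / N) * X^T X over streamed batches.

    Sketch of the running Hessian estimate used for GPTQ calibration.
    """

    def __init__(self, num_features):
        self.hessian = np.zeros((num_features, num_features), dtype=np.float64)
        self.num_samples = 0

    def update(self, x):
        # Flatten higher-rank inputs to (batch, features), as the tests expect.
        x = np.asarray(x, dtype=np.float64).reshape(-1, self.hessian.shape[0])
        n_new = x.shape[0]
        n_total = self.num_samples + n_new
        # Rescale the running estimate, then add the new batch's contribution,
        # so the result equals (2 / n_total) * X_all^T X_all exactly.
        self.hessian *= self.num_samples / n_total
        self.hessian += (2.0 / n_total) * (x.T @ x)
        self.num_samples = n_total
```

Because each update rescales by `num_samples / n_total` before adding `(2 / n_total) * x.T @ x`, splitting the calibration data into batches of any size yields the same Hessian as a single one-shot update, which is exactly the invariant the streaming tests assert.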
  {
    "path": "keras/src/quantizers/gptq_test.py",
    "content": "from collections.abc import Callable\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.quantizers.gptq import GPTQ\nfrom keras.src.quantizers.gptq import _stable_permutation\nfrom keras.src.quantizers.gptq import gptq_quantize_matrix\nfrom keras.src.quantizers.gptq_config import GPTQConfig\nfrom keras.src.quantizers.quantization_config import QuantizationConfig\nfrom keras.src.quantizers.quantizers import dequantize_with_sz_map\nfrom keras.src.quantizers.quantizers import dequantize_with_zero_point\nfrom keras.src.quantizers.quantizers import quantize_with_zero_point\nfrom keras.src.testing.test_utils import named_product\n\nVOCAB_SIZE = 1000\nSEQ_LEN = 128\nNUM_SAMPLES = 16\nW_BITS = 4\nNUM_CLASSES = 32\n\nCALIBRATION_TEXT = r\"\"\"\nGPTQ (Generative Pre-trained Transformer Quantization) is an advanced \npost-training quantization (PTQ) algorithm designed to compress large \nlanguage models with minimal accuracy degradation. It addresses the \nchallenge of reducing model size from high-precision formats like \nFP16 to low-bit integers (e.g., INT4, INT3) without the need for\nexpensive retraining. The algorithm operates on a layer-by-layer basis, \ntreating the quantization of each weight matrix $W$ as a \nreconstruction problem. Its objective is to find a quantized weight \nmatrix $\\hat{W}$ that minimizes the mean squared error of the layer's \noutput, formulated as $\\arg\\min_{\\hat{W}} \\|WX - \\hat{W}X\\|_F^2$, \nwhere $X$ is a set of calibration inputs. GPTQ's primary innovation \nis its greedy, error-compensating quantization process, based on the \nOptimal Brain Quantizer (OBQ) framework. It quantizes weights one by \none (or in small groups). 
After quantizing a single weight $w_q$ to \nits discrete value $\\hat{w}_q$, it introduces a quantization error of \n$\\delta = w_q - \\hat{w}_q$. This error is then immediately compensated \nfor by updating all remaining, unquantized weights in the layer. \nThe update step is guided by second-order information, specifically \nthe inverse of the Hessian matrix ($\\mathbf{H}^{-1}$) of the layer's \nreconstruction loss. This inverse Hessian provides a measure of weight \nsaliency and inter-dependencies. The update applied to the remaining \nweights is calculated based on $\\delta$ and the corresponding entries \nin $\\mathbf{H}^{-1}$, effectively propagating the error to less \nsensitive weights. This sequential compensation minimizes the \ncumulative error across the entire layer, allowing GPTQ to maintain \nhigh model fidelity, as measured by perplexity, even at aggressive \nbit-rates.\n\"\"\"\n\n\ndef _get_test_layer(layer_type, kernel_shape):\n    if layer_type == \"Dense\":\n        layer = layers.Dense(units=kernel_shape[1])\n        layer.build(input_shape=(None, kernel_shape[0]))\n    elif layer_type == \"EinsumDense\":\n        output_shape = (kernel_shape[1], kernel_shape[2])\n        layer = layers.EinsumDense(\n            equation=\"...h,hio->...io\", output_shape=output_shape\n        )\n        layer.build(input_shape=(None, kernel_shape[0]))\n    else:\n        layer = layers.Layer()\n    return layer\n\n\n@pytest.mark.requires_trainable_backend\nclass GPTQTest(testing.TestCase):\n    def test_initialization_with_dense_layer(self):\n        mock_layer = _get_test_layer(\"Dense\", kernel_shape=(64, 128))\n\n        gptq_instance = GPTQ(mock_layer)\n        self.assertEqual(gptq_instance.rows, 64)\n        self.assertEqual(gptq_instance.columns, 128)\n        self.assertEqual(gptq_instance.hessian.shape, (64, 64))\n\n    def test_initialization_with_einsumdense_3d(self):\n        mock_layer = _get_test_layer(\"EinsumDense\", kernel_shape=(64, 4, 32))\n  
      gptq_instance = GPTQ(mock_layer)\n        self.assertEqual(gptq_instance.rows, 64)\n        self.assertEqual(gptq_instance.columns, 4 * 32)\n        self.assertEqual(gptq_instance.hessian.shape, (64, 64))\n\n    def test_update_hessian(self):\n        dense = _get_test_layer(\"Dense\", kernel_shape=(16, 32))\n        dense_gptq = GPTQ(dense)\n\n        rng = np.random.default_rng(seed=42)\n        batch1 = rng.standard_normal(size=(8, 16)).astype(\"float32\")\n\n        dense_gptq.update_hessian_with_batch(batch1)\n        self.assertEqual(dense_gptq.num_samples, 8)\n        H1 = dense_gptq.hessian\n\n        batch2 = rng.standard_normal(size=(4, 16)).astype(\"float32\")\n\n        dense_gptq.update_hessian_with_batch(batch2)\n        self.assertEqual(dense_gptq.num_samples, 12)\n\n        H2 = dense_gptq.hessian\n\n        self.assertNotAllClose(H1, H2)\n\n    def test_gptq_on_single_layer(self):\n        rng = np.random.default_rng(seed=42)\n        dense = _get_test_layer(\"Dense\", kernel_shape=(16, 32))\n\n        config = GPTQConfig(\n            dataset=None,\n            tokenizer=None,\n            weight_bits=4,\n            symmetric=False,\n            group_size=-1,\n        )\n\n        dense.quantize(\"gptq\", config=config)\n        dense_gptq = GPTQ(\n            dense,\n            config,\n        )\n\n        calibration_data = rng.standard_normal(size=(128, 16)).astype(\"float32\")\n\n        dense_gptq.update_hessian_with_batch(calibration_data)\n        dense_gptq.quantize_and_correct_layer()\n\n        self.assertEqual(backend.standardize_dtype(dense.kernel.dtype), \"uint8\")\n\n        dense_gptq.free()\n        self.assertIsNone(getattr(dense_gptq, \"hessian\", None))\n        self.assertIsNone(getattr(dense_gptq, \"layer\", None))\n\n    def test_unsupported_layer_error(self):\n        unsupported_layer = _get_test_layer(\"Unsupported\", kernel_shape=None)\n        with self.assertRaisesRegex(TypeError, \"Unsupported layer 
type\"):\n            GPTQ(unsupported_layer)\n\n    def test_update_hessian_invalid_input(self):\n        rng = np.random.default_rng(seed=42)\n        dense = _get_test_layer(\"Dense\", kernel_shape=(16, 32))\n        gptq_instance = GPTQ(dense)\n        with self.assertRaisesRegex(ValueError, \"cannot be None\"):\n            gptq_instance.update_hessian_with_batch(None)\n        with self.assertRaisesRegex(ValueError, \"cannot be empty\"):\n            gptq_instance.update_hessian_with_batch(np.empty((0, 16)))\n        with self.assertRaisesRegex(ValueError, \"match input features\"):\n            bad_input = rng.standard_normal(size=(8, 99))\n            gptq_instance.update_hessian_with_batch(bad_input)\n\n    def test_streaming_equals_big_batch(self):\n        \"\"\"Tests that streaming updates match big batch updates.\"\"\"\n        # dummy inputs\n        x = ops.array(np.random.randn(100, 7), \"float32\")\n\n        # One-shot hessian update\n        layer_1 = layers.Dense(5, use_bias=False)\n        layer_1.build(input_shape=(None, 7))\n\n        g1 = GPTQ(layer_1)\n        g1.update_hessian_with_batch(x)\n\n        # Streamed hessian update\n        layer_2 = layers.Dense(5, use_bias=False)\n        layer_2.build(input_shape=(None, 7))\n        g2 = GPTQ(layer_2)\n        g2.update_hessian_with_batch(x[:50])\n        g2.update_hessian_with_batch(x[50:])\n\n        # Both the one-shot and streamed hessian updates should match\n        self.assertAllClose(g1.hessian, g2.hessian, rtol=1e-6, atol=1e-6)\n\n    def test_hessian_matches_closed_form(self):\n        \"\"\"Tests that the Hessian matches the closed-form solution.\"\"\"\n        x = ops.array(np.random.randn(128, 7), \"float32\")\n        layer = layers.Dense(5, use_bias=False)\n        layer.build((None, 7))\n        g = GPTQ(layer)\n        g.update_hessian_with_batch(x)\n\n        expected = ops.multiply(\n            ops.divide(2.0, x.shape[0]), ops.matmul(ops.transpose(x), x)\n        )\n      
  self.assertAllClose(g.hessian, expected, rtol=1e-6, atol=1e-6)\n\n    def test_higher_rank_inputs_are_reshaped(self):\n        \"\"\"Tests that higher-rank inputs are reshaped correctly.\"\"\"\n        # x: [batch, time, feat]\n        x = ops.array(np.random.randn(10, 4, 7), \"float32\")\n        x_flat = ops.reshape(x, (-1, ops.shape(x)[-1]))\n\n        layer1 = layers.Dense(5, use_bias=False)\n        layer1.build((None, 7))\n        g1 = GPTQ(layer1)\n        g1.update_hessian_with_batch(x)\n\n        layer2 = layers.Dense(5, use_bias=False)\n        layer2.build((None, 7))\n        g2 = GPTQ(layer2)\n        g2.update_hessian_with_batch(x_flat)\n\n        self.assertAllClose(g1.hessian, g2.hessian, rtol=1e-6, atol=1e-6)\n\n    def test_raises_on_feature_mismatch(self):\n        x = ops.array(np.random.randn(8, 7), \"float32\")\n        layer = layers.Dense(5, use_bias=False)\n        layer.build((None, 6))  # wrong in_features\n        g = GPTQ(layer)\n\n        with self.assertRaisesRegex(ValueError, \"do not match input features\"):\n            g.update_hessian_with_batch(x)\n\n        with self.assertRaisesRegex(ValueError, \"cannot be None\"):\n            g.update_hessian_with_batch(None)\n        with self.assertRaisesRegex(ValueError, \"cannot be empty\"):\n            g.update_hessian_with_batch(\n                ops.array(np.empty((0, 7), dtype=\"float32\"))\n            )\n\n    def test_num_samples_accumulates_correctly(self):\n        \"\"\"Tests that the number of samples is accumulated correctly when\n        streaming updates are used.\"\"\"\n        x = ops.array(np.random.randn(64, 7), \"float32\")\n        layer = layers.Dense(5, use_bias=False)\n        layer.build((None, 7))\n        g = GPTQ(layer)\n\n        g.update_hessian_with_batch(x[:5])\n        g.update_hessian_with_batch(x[5:30])\n        g.update_hessian_with_batch(x[30:])\n\n        self.assertEqual(g.num_samples, 64)\n\n    def test_numeric_stability_large_values(self):\n    
    \"\"\"Tests numeric stability of hessian update with large input values.\"\"\"\n        x = ops.multiply(ops.array(np.random.randn(32, 7), \"float32\"), 1e6)\n        layer = layers.Dense(5, use_bias=False)\n        layer.build((None, 7))\n\n        g = GPTQ(layer)\n        g.update_hessian_with_batch(x)\n\n        # Should be finite and symmetric\n        self.assertTrue(ops.all(ops.isfinite(g.hessian)))\n        self.assertTrue(ops.all(ops.equal(g.hessian, ops.transpose(g.hessian))))\n\n    def test_einsumdense_2d_kernel_hessian_shape(self):\n        x = layers.Input((7,))\n        y = layers.EinsumDense(\"ab,bc->ac\", output_shape=(5,))(x)\n        model = keras.Model(x, y)\n        einsum_dense_layer = next(\n            l for l in model.layers if isinstance(l, layers.EinsumDense)\n        )\n\n        g = GPTQ(einsum_dense_layer)\n\n        # should infer rows==7\n        self.assertEqual(ops.shape(g.hessian), (7, 7))\n\n    def test_einsumdense_3d_kernel_streaming_equals_big_batch(self):\n        \"\"\"Tests that streaming updates to the Hessian are equivalent to a big\n        batch update.\"\"\"\n        # Construct a tiny attention-like einsum with 3D kernel\n        x = layers.Input((7,))\n        qkv = layers.EinsumDense(\"bf,fhk->bhk\", output_shape=(2, 3))(\n            x\n        )  # heads=2, head_dim=3\n        model = keras.Model(x, qkv)\n        einsum_dense_layer = next(\n            l for l in model.layers if isinstance(l, layers.EinsumDense)\n        )\n\n        x = ops.array(np.random.randn(50, 7), \"float32\")\n\n        g1 = GPTQ(einsum_dense_layer)\n        g1.update_hessian_with_batch(x)\n\n        g2 = GPTQ(einsum_dense_layer)\n        g2.update_hessian_with_batch(x[:20])\n        g2.update_hessian_with_batch(x[20:])\n\n        self.assertAllClose(g1.hessian, g2.hessian, rtol=1e-6, atol=1e-6)\n\n    def test_identity_inv_hessian_matches_direct_quantization(self):\n        \"\"\"Tests that the matrix quantization without error 
correction\n        matches the direct implementation.\"\"\"\n        in_features, out_features = 16, 8\n        weights = ops.reshape(\n            ops.linspace(\n                -0.9, 1.1, in_features * out_features, dtype=\"float32\"\n            ),\n            (in_features, out_features),\n        )\n        weights_transpose = ops.transpose(weights)\n\n        # inverse_hessian = identity: all off-diagonal elements are zero,\n        # so there is no cross-feature error correction and each column is\n        # quantized independently.\n        inverse_hessian = ops.eye(in_features, dtype=\"float32\")\n\n        quantized_weights, scale_map, zero_map, g_idx = gptq_quantize_matrix(\n            weights_transpose,\n            inverse_hessian,\n            blocksize=128,\n            group_size=1,  # per-column quantization\n            activation_order=False,\n            compute_scale_zero=_compute_scale_zero,\n        )\n\n        dequantized_weights = dequantize_with_sz_map(\n            quantized_weights, scale_map, zero_map, g_idx\n        )\n\n        # Compare function output with columnwise direct application\n        # of quantization.\n        out = ops.zeros_like(weights_transpose)\n        for j in range(ops.shape(weights_transpose)[1]):\n            column = weights_transpose[:, j : j + 1]\n            scale, zero, maxq = _compute_scale_zero(column)\n            quantized_col = quantize_with_zero_point(column, scale, zero, maxq)\n            dequantized = dequantize_with_zero_point(quantized_col, scale, zero)\n            out = ops.slice_update(\n                out, (0, j), ops.expand_dims(dequantized[:, 0], 1)\n            )\n\n        self.assertAllClose(dequantized_weights, out, atol=1e-6)\n\n    def test_activation_order_produces_equivalent_weights(self):\n        \"\"\"\n        Tests that quantizing with `activation_order=True` yields the same\n        final weights as `activation_order=False`, because the internal\n        
permutation should be undone.\n        \"\"\"\n        # Set up shared inputs and a non-trivial permutation.\n        in_features, out_features = 8, 6\n        initial_weights = ops.array(\n            np.random.randn(in_features, out_features), \"float32\"\n        )\n\n        # Generate a Hessian that creates a non-trivial permutation.\n        hessian_diag = ops.random.shuffle(\n            ops.linspace(10.0, 1.0, in_features, dtype=\"float32\")\n        )\n        hessian_matrix = ops.diag(hessian_diag)\n\n        # Sanity check: ensure the permutation is not the identity.\n        perm = _stable_permutation(hessian_diag)\n        self.assertFalse(ops.all(ops.equal(perm, ops.arange(in_features))))\n\n        def create_and_quantize(use_activation_order):\n            layer = layers.Dense(out_features, use_bias=False)\n            layer.build((None, in_features))\n            layer.set_weights([ops.copy(initial_weights)])\n\n            config = GPTQConfig(\n                dataset=None,\n                tokenizer=None,\n                group_size=-1,\n                activation_order=use_activation_order,\n            )\n            layer.quantize(\"gptq\", config=config)\n\n            quantizer = GPTQ(layer, config)\n            quantizer.hessian = hessian_matrix\n            quantizer.quantize_and_correct_layer()\n            return layer\n\n        # Quantize two layers, one with and one without activation ordering.\n        ordered_layer = create_and_quantize(use_activation_order=True)\n        unordered_layer = create_and_quantize(use_activation_order=False)\n\n        self.assertAllClose(\n            ordered_layer.get_weights()[0],\n            unordered_layer.get_weights()[0],\n            msg=\"Weights should be identical as the permutation is undone.\",\n        )\n\n\ndef _compute_scale_zero(x, **_):\n    # Per-column asymmetric int4 example\n    # scale = (max-min)/maxq, zero = round(-min/scale)\n    maxq = 15.0\n    xmin = ops.min(x, axis=0, 
keepdims=True)\n    xmax = ops.max(x, axis=0, keepdims=True)\n    scale = ops.divide(ops.subtract(xmax, xmin), ops.add(maxq, 1e-8))\n    zero = ops.round(ops.divide(ops.negative(xmin), ops.add(scale, 1e-8)))\n    return scale, zero, maxq\n\n\ndef _get_sequence_classifier():\n    \"\"\"Transformer-based sequence classifier\n\n    tokens -> Embedding -> Transformer -> GAP -> Dense(num_classes).\n    \"\"\"\n    embed_dim = 32\n    num_heads = 4\n    ff_dim = 32\n\n    class SimpleTransformerBlock(layers.Layer):\n        def __init__(self, embed_dim, num_heads, ff_dim, **kwargs):\n            super().__init__(**kwargs)\n\n            self.att = layers.MultiHeadAttention(\n                num_heads=num_heads, key_dim=embed_dim // num_heads\n            )\n            self.ffn = models.Sequential(\n                [\n                    layers.Dense(ff_dim, activation=\"relu\"),\n                    layers.Dense(embed_dim),\n                ]\n            )\n            self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n            self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n\n        def call(self, inputs):\n            attention_output = self.att(inputs, inputs)\n            out1 = self.layernorm1(inputs + attention_output)\n            ffn_output = self.ffn(out1)\n            return self.layernorm2(out1 + ffn_output)\n\n    inputs = layers.Input(shape=(SEQ_LEN,), dtype=\"int32\")\n    x = layers.Embedding(VOCAB_SIZE, embed_dim)(inputs)\n    x = SimpleTransformerBlock(embed_dim, num_heads, ff_dim)(x)\n    x = layers.GlobalAveragePooling1D()(x)\n    outputs = layers.Dense(NUM_CLASSES)(x)\n    return models.Model(inputs, outputs)\n\n\ndef _get_simple_model():\n    return models.Sequential([layers.Dense(10, input_shape=(5,))])\n\n\ndef _mean_kl(p, q):\n    # Add small epsilon for numerical stability\n    eps = 1e-8\n    p = ops.clip(p, eps, 1.0)\n    q = ops.clip(q, eps, 1.0)\n    # Compute KL divergence\n    # D_KL(P || Q) = sum(P * log(P / Q))\n  
  return ops.mean(\n        ops.sum(ops.multiply(p, ops.subtract(ops.log(p), ops.log(q))), axis=-1)\n    )\n\n\ndef _top1_match_rate(a_logits, b_logits):\n    \"\"\"Calculates the top-1 match rate between two sets of logits.\n\n    Formula: T = 1/N * sum(1{argmax(a_i) == argmax(b_i)})\n    \"\"\"\n    return ops.mean(\n        ops.equal(ops.argmax(a_logits, axis=-1), ops.argmax(b_logits, axis=-1))\n    )\n\n\nDATASETS = {\n    \"string_dataset\": lambda: _string_dataset(\n        CALIBRATION_TEXT, NUM_SAMPLES, SEQ_LEN\n    ),\n    \"token_dataset\": lambda: _token_dataset(NUM_SAMPLES, SEQ_LEN),\n}\n\nCONFIGS = {\n    \"default\": {},\n    \"per_channel\": {\"group_size\": -1, \"per_channel\": True},\n    \"act_order\": {\"activation_order\": True},\n    \"symmetric\": {\"symmetric\": True},\n    \"group_wise\": {\"group_size\": 8},\n    \"group_wise_act_order\": {\"group_size\": 8, \"activation_order\": True},\n    \"symmetric_act_order\": {\"symmetric\": True, \"activation_order\": True},\n    \"symmetric_per_channel\": {\"symmetric\": True, \"per_channel\": True},\n    \"group_wise_symmetric_8bit\": {\n        \"group_size\": 8,\n        \"symmetric\": True,\n        \"weight_bits\": 8,\n    },\n}\n\n\ndef _pad_or_trim_1d(ids, length):\n    \"\"\"Pads or trims a 1D array to a specified length.\"\"\"\n    ids = ops.ravel(ops.array(ids, \"int64\"))\n    if len(ids) < length:\n        ids = ops.concatenate(\n            [ids, ops.zeros(length - len(ids), dtype=ids.dtype)]\n        )\n    else:\n        ids = ids[:length]\n    return ids\n\n\ndef _char_tokenizer(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN):\n    \"\"\"Tokenizes strings to char-IDs or passes through int arrays;\n    outputs shape (1, seq_len).\"\"\"\n\n    def _tok(x):\n        if isinstance(x, str):\n            ids = ops.convert_to_tensor(\n                np.fromiter((ord(c) % vocab_size for c in x), dtype=np.int64)\n            )\n        else:\n            ids = np.asarray(x, dtype=np.int64)\n        
ids = _pad_or_trim_1d(ids, seq_len)\n        return ids[None, :]\n\n    _tok.tokenize = _tok\n    return _tok\n\n\ndef _string_dataset(\n    long_text, num_samples=NUM_SAMPLES, sequence_length=SEQ_LEN\n):\n    \"\"\"Yields string slices\"\"\"\n    rng = np.random.default_rng(seed=0)\n    L = max(1, len(long_text) - sequence_length)\n    for _ in range(num_samples):\n        start = rng.integers(0, L) if L > 1 else 0\n        yield long_text[start : start + sequence_length]\n\n\ndef _token_dataset(\n    num_samples=NUM_SAMPLES, sequence_length=SEQ_LEN, vocab_size=VOCAB_SIZE\n):\n    \"\"\"Yields tokenized samples.\"\"\"\n    rng = np.random.default_rng(seed=0)\n    for _ in range(num_samples):\n        yield rng.integers(\n            low=0, high=vocab_size, size=(1, sequence_length), dtype=np.int64\n        )\n\n\n@pytest.mark.requires_trainable_backend\n@pytest.mark.skipif(\n    backend.backend() == \"torch\",\n    reason=\"torch gives low accuracy on CI, but works well locally\",\n)\nclass TestModelQuantization(testing.TestCase):\n    @parameterized.named_parameters(\n        named_product(\n            [\n                {\"testcase_name\": dataset_id, \"dataset\": dataset}\n                for dataset_id, dataset in DATASETS.items()\n            ],\n            [\n                {\"testcase_name\": config_id, \"config\": config}\n                for config_id, config in CONFIGS.items()\n            ],\n        )\n    )\n    def test_quantize_gptq_combinations(self, dataset, config):\n        \"\"\"Tests GPTQ quantization on a tiny transformer classifier.\n\n        Validates classification performance of the quantized model\n        with respect to the full-precision baseline.\n        \"\"\"\n        rng = np.random.default_rng(seed=321)\n        keras.utils.set_random_seed(123)\n\n        # Build the calibration set.\n        calibration_set = list(\n            dataset() if isinstance(dataset, Callable) else dataset\n        )\n        
self.assertNotEmpty(calibration_set)\n\n        # Build classifier and tokenizer\n        model = _get_sequence_classifier()\n        tokenizer = _char_tokenizer(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN)\n\n        # Build an eval batch drawn from the SAME distribution as calibration\n        batch_size = min(8, len(calibration_set))\n        eval_samples = [\n            calibration_set[rng.integers(0, len(calibration_set))]\n            for _ in range(batch_size)\n        ]\n        x_eval = ops.concatenate([tokenizer(s) for s in eval_samples], axis=0)\n\n        # Baseline logits\n        y_ref = model.predict(x_eval)\n\n        embedding_layer = model.layers[1]\n        transformer_block = model.layers[2]\n\n        layer_structure = {\n            \"pre_block_layers\": [embedding_layer],\n            \"sequential_blocks\": [transformer_block],\n        }\n\n        base_cfg = dict(\n            dataset=calibration_set,\n            tokenizer=tokenizer,\n            weight_bits=W_BITS,\n            num_samples=NUM_SAMPLES,\n            sequence_length=SEQ_LEN,\n            group_size=32,\n            symmetric=False,\n            activation_order=False,\n            quantization_layer_structure=layer_structure,\n        )\n        gptq_cfg = GPTQConfig(**{**base_cfg, **config})\n\n        # Quantize\n        model.quantize(\"gptq\", config=gptq_cfg)\n\n        # Post-quant logits\n        y_q = model.predict(x_eval)\n\n        top1_match = _top1_match_rate(y_ref, y_q)\n\n        p_ref, p_q = ops.softmax(y_ref), ops.softmax(y_q)\n        kl = _mean_kl(p_ref, p_q)\n\n        self.assertGreaterEqual(\n            top1_match, 0.5, f\"Top-1 agreement too low: {top1_match:.3f}\"\n        )\n        self.assertLessEqual(kl, 0.30, f\"KL divergence too high: {kl:.3f}\")\n\n    @parameterized.named_parameters(\n        {\n            \"testcase_name\": \"gptq_with_invalid_config_type\",\n            \"mode\": \"gptq\",\n            \"config\": {\"weight_bits\": 4},\n       
     \"expected_exception\": ValueError,\n            \"error_msg\": \"Argument `config` must be an instance of \"\n            \"`QuantizationConfig`\",\n        },\n        {\n            \"testcase_name\": \"gptq_with_none_config\",\n            \"mode\": \"gptq\",\n            \"config\": None,\n            \"expected_exception\": ValueError,\n            \"error_msg\": \"For GPTQ, you must pass a `GPTQConfig` object \"\n            \"in the `config` argument.\",\n        },\n        {\n            \"testcase_name\": \"gptq_with_base_quantization_config\",\n            \"mode\": \"gptq\",\n            \"config\": QuantizationConfig(),\n            \"expected_exception\": NotImplementedError,\n            \"error_msg\": \"Do not instantiate QuantizationConfig directly.\",\n        },\n        {\n            \"testcase_name\": \"gptq_missing_structure\",\n            \"mode\": \"gptq\",\n            \"config\": GPTQConfig(dataset=[\"a\"], tokenizer=lambda x: x),\n            \"expected_exception\": ValueError,\n            \"error_msg\": \"For mode='gptq', a valid quantization structure\",\n        },\n    )\n    def test_quantize_scenarios(\n        self, mode, config, expected_exception, error_msg\n    ):\n        model = _get_simple_model()\n        with self.assertRaisesRegex(expected_exception, error_msg):\n            model.quantize(mode, config=config)\n\n    def test_gptq_filtering(self):\n        \"\"\"Tests that filters argument works for GPTQ.\"\"\"\n        model = _get_sequence_classifier()\n        tokenizer = _char_tokenizer(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN)\n\n        # Structure\n        embedding_layer = model.layers[1]\n        transformer_block = model.layers[2]\n        layer_structure = {\n            \"pre_block_layers\": [embedding_layer],\n            \"sequential_blocks\": [transformer_block],\n        }\n\n        config = GPTQConfig(\n            dataset=[np.zeros((1, SEQ_LEN), dtype=\"int32\")],\n            
tokenizer=tokenizer,\n            quantization_layer_structure=layer_structure,\n            weight_bits=4,\n            group_size=32,\n        )\n\n        target_layer = transformer_block.ffn.layers[0]\n\n        def filter_fn(layer):\n            return layer.name != target_layer.name\n\n        model.quantize(\"gptq\", config=config, filters=filter_fn)\n\n        # Check that target_layer is NOT quantized.\n        self.assertIsNone(getattr(target_layer, \"quantization_mode\", None))\n        self.assertFalse(hasattr(target_layer, \"quantized_kernel\"))\n\n        # Check that other dense layers ARE quantized.\n        other_dense = transformer_block.ffn.layers[1]\n        self.assertEqual(\n            getattr(other_dense, \"quantization_mode\", None), \"gptq\"\n        )\n        self.assertTrue(hasattr(other_dense, \"quantized_kernel\"))\n\n    def test_gptq_multi_filtering(self):\n        \"\"\"Tests that list of regex filters works for GPTQ.\"\"\"\n        model = _get_sequence_classifier()\n        tokenizer = _char_tokenizer(vocab_size=VOCAB_SIZE, seq_len=SEQ_LEN)\n\n        embedding_layer = model.layers[1]\n        transformer_block = model.layers[2]\n        layer_structure = {\n            \"pre_block_layers\": [embedding_layer],\n            \"sequential_blocks\": [transformer_block],\n        }\n\n        config = GPTQConfig(\n            dataset=[np.zeros((1, SEQ_LEN), dtype=\"int32\")],\n            tokenizer=tokenizer,\n            quantization_layer_structure=layer_structure,\n            weight_bits=4,\n            group_size=32,\n        )\n\n        layer0 = transformer_block.ffn.layers[0]\n        layer1 = transformer_block.ffn.layers[1]\n\n        # We want to quantize only layer0.\n        filters = [f\"^{layer0.name}$\"]\n\n        model.quantize(\"gptq\", config=config, filters=filters)\n\n        # Check that layer0 is quantized.\n        self.assertEqual(getattr(layer0, \"quantization_mode\", None), \"gptq\")\n        
self.assertTrue(hasattr(layer0, \"quantized_kernel\"))\n\n        # Check that layer1 is not quantized.\n        self.assertIsNone(getattr(layer1, \"quantization_mode\", None))\n        self.assertFalse(hasattr(layer1, \"quantized_kernel\"))\n"
  },
  {
    "path": "keras/src/quantizers/quantization_config.py",
    "content": "from keras.src.api_export import keras_export\nfrom keras.src.dtype_policies import QUANTIZATION_MODES\nfrom keras.src.saving import serialization_lib\n\n\n@keras_export(\"keras.quantizers.QuantizationConfig\")\nclass QuantizationConfig:\n    \"\"\"Base class for quantization configs.\n\n    Subclasses must implement the `mode` property and the `get_config` and\n    `from_config` class methods.\n\n    Args:\n        weight_quantizer: Quantizer for weights.\n        activation_quantizer: Quantizer for activations.\n    \"\"\"\n\n    def __init__(self, weight_quantizer=None, activation_quantizer=None):\n        self.weight_quantizer = weight_quantizer\n        self.activation_quantizer = activation_quantizer\n\n    @property\n    def mode(self):\n        raise NotImplementedError(\n            \"Subclasses must implement this property. Do not instantiate \"\n            \"QuantizationConfig directly.\"\n        )\n\n    def get_config(self):\n        return {\n            \"weight_quantizer\": serialization_lib.serialize_keras_object(\n                self.weight_quantizer\n            ),\n            \"activation_quantizer\": serialization_lib.serialize_keras_object(\n                self.activation_quantizer\n            ),\n        }\n\n    @classmethod\n    def from_config(cls, config):\n        weight_quantizer = serialization_lib.deserialize_keras_object(\n            config.get(\"weight_quantizer\")\n        )\n        activation_quantizer = serialization_lib.deserialize_keras_object(\n            config.get(\"activation_quantizer\")\n        )\n        return cls(\n            weight_quantizer=weight_quantizer,\n            activation_quantizer=activation_quantizer,\n        )\n\n    @staticmethod\n    def weight_quantizer_or_default(config, default):\n        if config is not None and config.weight_quantizer is not None:\n            return config.weight_quantizer\n        return default\n\n    @staticmethod\n    def 
activation_quantizer_or_default(config, default):\n        if config is not None:\n            return config.activation_quantizer\n        return default\n\n\n@keras_export(\"keras.quantizers.Int8QuantizationConfig\")\nclass Int8QuantizationConfig(QuantizationConfig):\n    \"\"\"Int8 quantization config.\n\n    Args:\n        weight_quantizer: Quantizer for weights.\n        activation_quantizer: Quantizer for activations. If \"default\", uses\n            AbsMaxQuantizer with axis=-1.\n    \"\"\"\n\n    def __init__(self, weight_quantizer=None, activation_quantizer=\"default\"):\n        from keras.src.quantizers.quantizers import AbsMaxQuantizer\n\n        if activation_quantizer == \"default\":\n            activation_quantizer = AbsMaxQuantizer()\n        super().__init__(weight_quantizer, activation_quantizer)\n        if self.weight_quantizer is not None:\n            if self.weight_quantizer.output_dtype != \"int8\":\n                raise ValueError(\n                    \"Int8QuantizationConfig requires a weight_quantizer \"\n                    \"with output_dtype='int8'. Received: \"\n                    f\"output_dtype={self.weight_quantizer.output_dtype}\"\n                )\n\n    @property\n    def mode(self):\n        return \"int8\"\n\n\n@keras_export(\"keras.quantizers.Int4QuantizationConfig\")\nclass Int4QuantizationConfig(QuantizationConfig):\n    \"\"\"Int4 quantization config.\n\n    Args:\n        weight_quantizer: Quantizer for weights.\n        activation_quantizer: Quantizer for activations. If \"default\", no\n            activation quantizer is used (weight-only quantization).\n        block_size: Size of groups along the input dimension for sub-channel\n            quantization. If a positive integer, uses sub-channel quantization\n            with `ceil(input_dim / block_size)` groups. 
If `None` or `-1`,\n            uses per-channel quantization (one scale per output channel).\n            Default: `128` (sub-channel with 128-element groups).\n    \"\"\"\n\n    def __init__(\n        self,\n        weight_quantizer=None,\n        activation_quantizer=\"default\",\n        block_size=128,\n    ):\n        if activation_quantizer == \"default\":\n            # Use weight-only quantization by default for int4\n            activation_quantizer = None\n        super().__init__(weight_quantizer, activation_quantizer)\n\n        # Validate block_size\n        if block_size is not None and block_size != -1 and block_size <= 0:\n            raise ValueError(\n                f\"block_size must be None, -1, or a positive integer. \"\n                f\"Received: block_size={block_size}\"\n            )\n        self.block_size = block_size\n\n        # Sub-channel quantization does not support custom quantizers\n        is_sub_channel = block_size is not None and block_size > 0\n        has_custom_quantizer = (\n            self.weight_quantizer is not None\n            or self.activation_quantizer is not None\n        )\n        if is_sub_channel and has_custom_quantizer:\n            raise ValueError(\n                \"Int4 sub-channel quantization (block_size > 0) does not \"\n                \"support custom quantizers. Either set block_size to None \"\n                \"or -1 for per-channel quantization, or remove the custom \"\n                f\"quantizer arguments. Received: block_size={block_size}\"\n            )\n\n        if self.weight_quantizer is not None:\n            if self.weight_quantizer.value_range != (-8, 7):\n                raise ValueError(\n                    \"Int4QuantizationConfig requires a weight_quantizer \"\n                    \"with value_range=(-8, 7). 
Received: \"\n                    f\"value_range={self.weight_quantizer.value_range}\"\n                )\n\n            if self.weight_quantizer.output_dtype != \"int8\":\n                raise ValueError(\n                    \"Int4QuantizationConfig requires a weight_quantizer \"\n                    \"with output_dtype='int8'. Received: \"\n                    f\"output_dtype={self.weight_quantizer.output_dtype}\"\n                )\n\n    @property\n    def mode(self):\n        return \"int4\"\n\n    def get_config(self):\n        config = super().get_config()\n        config[\"block_size\"] = self.block_size\n        return config\n\n    @classmethod\n    def from_config(cls, config):\n        weight_quantizer = serialization_lib.deserialize_keras_object(\n            config.get(\"weight_quantizer\")\n        )\n        activation_quantizer = serialization_lib.deserialize_keras_object(\n            config.get(\"activation_quantizer\")\n        )\n        # Default to None for backwards compatibility with models saved\n        # before block_size was introduced (those used per-channel mode)\n        block_size = config.get(\"block_size\", None)\n        return cls(\n            weight_quantizer=weight_quantizer,\n            activation_quantizer=activation_quantizer,\n            block_size=block_size,\n        )\n\n\n@keras_export(\"keras.quantizers.Float8QuantizationConfig\")\nclass Float8QuantizationConfig(QuantizationConfig):\n    \"\"\"FP8 quantization config.\n\n    FP8 mixed-precision training does not support user defined quantizers.\n    This config is only used to indicate that FP8 mixed-precision training\n    should be used.\n    \"\"\"\n\n    def __init__(self):\n        super().__init__(None, None)\n\n    @property\n    def mode(self):\n        return \"float8\"\n\n    def get_config(self):\n        return {}\n\n    @classmethod\n    def from_config(cls, config):\n        return cls()\n\n\ndef validate_and_resolve_config(mode, config):\n    
\"\"\"Validate and resolve quantization config.\n\n    This function validates the quantization config and resolves the mode.\n    If mode is not provided, it is inferred from the config.\n    If config is not provided, a default config is inferred from the mode.\n\n    Args:\n        mode: Quantization mode.\n        config: Quantization config.\n    \"\"\"\n    # 1. Backwards Compatibility: Handle string shortcuts.\n    if isinstance(config, str):\n        mode = config\n        config = None\n\n    _validate_mode(mode)\n\n    # 2. Resolve \"mode\" into a Config object.\n    if config is None:\n        if mode == \"int8\":\n            config = Int8QuantizationConfig()\n        elif mode == \"int4\":\n            config = Int4QuantizationConfig()\n        elif mode == \"float8\":\n            config = Float8QuantizationConfig()\n        elif mode == \"gptq\":\n            raise ValueError(\n                \"For GPTQ, you must pass a `GPTQConfig` object in the \"\n                \"`config` argument.\"\n            )\n        elif mode == \"awq\":\n            raise ValueError(\n                \"For AWQ, you must pass an `AWQConfig` object in the \"\n                \"`config` argument.\"\n            )\n        else:\n            if mode is not None:\n                raise ValueError(\n                    f\"Invalid quantization mode. Received: mode={mode}\"\n                )\n            raise ValueError(\n                \"You must provide either `mode` or `config` to `quantize`.\"\n            )\n    else:\n        if not isinstance(config, QuantizationConfig):\n            raise ValueError(\n                \"Argument `config` must be an instance of \"\n                \"`QuantizationConfig`. \"\n                f\"Received: config={config} (of type {type(config)})\"\n            )\n\n    # 3. 
Validation: Prevent contradictions.\n    if mode is not None and config.mode != mode:\n        raise ValueError(\n            f\"Contradictory arguments: mode='{mode}' but \"\n            f\"config.mode='{config.mode}'\"\n        )\n\n    # Ensure mode is consistent.\n    mode = config.mode\n\n    # Ensure the mode derived from the config is valid.\n    _validate_mode(mode)\n\n    if mode == \"gptq\":\n        from keras.src.quantizers.gptq_config import GPTQConfig\n\n        if not isinstance(config, GPTQConfig):\n            raise ValueError(\n                \"Mode 'gptq' requires a valid `config` argument of type \"\n                f\"`GPTQConfig`. Received: {type(config)}\"\n            )\n\n    if mode == \"awq\":\n        from keras.src.quantizers.awq_config import AWQConfig\n\n        if not isinstance(config, AWQConfig):\n            raise ValueError(\n                \"Mode 'awq' requires a valid `config` argument of type \"\n                f\"`AWQConfig`. Received: {type(config)}\"\n            )\n\n    return config\n\n\ndef _validate_mode(mode):\n    \"\"\"Validates quantization mode.\"\"\"\n    if mode is not None and mode not in QUANTIZATION_MODES:\n        raise ValueError(\n            \"Invalid quantization mode. \"\n            f\"Expected one of {QUANTIZATION_MODES}. Received: mode={mode}\"\n        )\n\n\ndef get_block_size_for_layer(layer, config):\n    \"\"\"Determine the block size for int4 quantization.\n\n    The block size can be specified either through the `config` argument\n    or through the `dtype_policy` if it is of type `Int4DTypePolicy`.\n\n    The config argument is usually available when quantizing the layer\n    via the `quantize` method. 
If the layer was deserialized from a\n    saved model, the block size should be specified in the `dtype_policy`.\n\n    Args:\n        layer: The layer being quantized.\n        config: An optional configuration object that may contain the\n            `block_size` attribute.\n    Returns:\n        int or None. The determined block size for int4 quantization.\n        Returns `None` or `-1` for per-channel quantization.\n    \"\"\"\n    from keras.src.dtype_policies.dtype_policy import Int4DTypePolicy\n    from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap\n\n    if config and isinstance(config, Int4QuantizationConfig):\n        return config.block_size\n    elif isinstance(layer.dtype_policy, Int4DTypePolicy):\n        block_size = layer.dtype_policy.block_size\n        # Convert -1 to None for consistency\n        return None if block_size == -1 else block_size\n    elif isinstance(layer.dtype_policy, DTypePolicyMap):\n        policy = layer.dtype_policy[layer.path]\n        if isinstance(policy, Int4DTypePolicy):\n            block_size = policy.block_size\n            return None if block_size == -1 else block_size\n        # Fall back to None for legacy QuantizedDTypePolicy\n        return None\n    else:\n        # For backwards compatibility with models that don't have\n        # Int4DTypePolicy (legacy per-channel mode)\n        return None\n"
  },
  {
    "path": "keras/src/quantizers/quantization_config_test.py",
    "content": "import os\n\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.quantizers.quantization_config import Int4QuantizationConfig\nfrom keras.src.quantizers.quantization_config import Int8QuantizationConfig\nfrom keras.src.quantizers.quantization_config import QuantizationConfig\nfrom keras.src.quantizers.quantization_config import validate_and_resolve_config\nfrom keras.src.quantizers.quantizers import AbsMaxQuantizer\n\n\nclass QuantizationConfigTest(testing.TestCase):\n    def test_base_quantization_config(self):\n        config = QuantizationConfig()\n        with self.assertRaises(NotImplementedError):\n            _ = config.mode\n\n    def test_int8_quantization_config_valid(self):\n        config = Int8QuantizationConfig()\n        self.assertEqual(config.mode, \"int8\")\n        self.assertIsNone(config.weight_quantizer)\n\n        # Valid weight quantizer\n        q = AbsMaxQuantizer(axis=0, value_range=(-127, 127))\n        config = Int8QuantizationConfig(weight_quantizer=q)\n        self.assertEqual(config.weight_quantizer, q)\n\n    def test_int8_quantization_config_invalid(self):\n        # Invalid value_range\n        with self.assertRaisesRegex(ValueError, \"value_range\"):\n            AbsMaxQuantizer(axis=0, value_range=(-256, 256))\n\n    def test_int4_quantization_config_valid(self):\n        config = Int4QuantizationConfig()\n        self.assertEqual(config.mode, \"int4\")\n        self.assertIsNone(config.weight_quantizer)\n\n        # Valid weight quantizer with per-channel mode\n        # (custom quantizers require block_size=None or -1)\n        q = AbsMaxQuantizer(axis=0, value_range=(-8, 7))\n        config = Int4QuantizationConfig(weight_quantizer=q, block_size=None)\n        self.assertEqual(config.weight_quantizer, q)\n\n    def test_int4_quantization_config_invalid(self):\n        # Invalid value_range\n        q = AbsMaxQuantizer(axis=0, 
value_range=(-127, 127))\n        with self.assertRaisesRegex(ValueError, \"value_range\"):\n            Int4QuantizationConfig(weight_quantizer=q, block_size=None)\n\n    def test_int4_quantization_config_subchannel_rejects_custom_quantizer(self):\n        # Sub-channel quantization does not support custom quantizers\n        weight_q = AbsMaxQuantizer(axis=0, value_range=(-8, 7))\n        activation_q = AbsMaxQuantizer(axis=-1)\n\n        # Default block_size=128 is sub-channel, should reject custom quantizer\n        with self.assertRaisesRegex(\n            ValueError, \"sub-channel quantization.*does not support\"\n        ):\n            Int4QuantizationConfig(weight_quantizer=weight_q)\n\n        # Explicit positive block_size should also reject weight quantizer\n        with self.assertRaisesRegex(\n            ValueError, \"sub-channel quantization.*does not support\"\n        ):\n            Int4QuantizationConfig(weight_quantizer=weight_q, block_size=64)\n\n        # Sub-channel should also reject activation quantizer\n        with self.assertRaisesRegex(\n            ValueError, \"sub-channel quantization.*does not support\"\n        ):\n            Int4QuantizationConfig(activation_quantizer=activation_q)\n\n        with self.assertRaisesRegex(\n            ValueError, \"sub-channel quantization.*does not support\"\n        ):\n            Int4QuantizationConfig(\n                activation_quantizer=activation_q, block_size=64\n            )\n\n        # Per-channel (block_size=None or -1) should accept custom quantizers\n        config = Int4QuantizationConfig(\n            weight_quantizer=weight_q, block_size=None\n        )\n        self.assertEqual(config.weight_quantizer, weight_q)\n\n        config = Int4QuantizationConfig(\n            weight_quantizer=weight_q, block_size=-1\n        )\n        self.assertEqual(config.weight_quantizer, weight_q)\n\n        config = Int4QuantizationConfig(\n            activation_quantizer=activation_q, 
block_size=None\n        )\n        self.assertEqual(config.activation_quantizer, activation_q)\n\n        config = Int4QuantizationConfig(\n            activation_quantizer=activation_q, block_size=-1\n        )\n        self.assertEqual(config.activation_quantizer, activation_q)\n\n    def test_quantization_config_serialization(self):\n        config = Int8QuantizationConfig(\n            weight_quantizer=AbsMaxQuantizer(axis=0),\n            activation_quantizer=AbsMaxQuantizer(axis=-1),\n        )\n        serialized = config.get_config()\n        deserialized = Int8QuantizationConfig.from_config(serialized)\n        self.assertIsInstance(deserialized, Int8QuantizationConfig)\n        self.assertIsInstance(deserialized.weight_quantizer, AbsMaxQuantizer)\n        self.assertIsInstance(\n            deserialized.activation_quantizer, AbsMaxQuantizer\n        )\n        self.assertEqual(deserialized.weight_quantizer.axis, (0,))\n        self.assertEqual(deserialized.activation_quantizer.axis, (-1,))\n\n    def test_validate_and_resolve_config(self):\n        # 1. String mode\n        config = validate_and_resolve_config(\"int8\", None)\n        self.assertIsInstance(config, Int8QuantizationConfig)\n        self.assertEqual(config.mode, \"int8\")\n\n        config = validate_and_resolve_config(\"int4\", None)\n        self.assertIsInstance(config, Int4QuantizationConfig)\n        self.assertEqual(config.mode, \"int4\")\n\n        # 2. Config object\n        config_in = Int8QuantizationConfig()\n        config_out = validate_and_resolve_config(None, config_in)\n        self.assertIs(config_out, config_in)\n\n        # 3. Mode + Config (matching)\n        config_in = Int8QuantizationConfig()\n        config_out = validate_and_resolve_config(\"int8\", config_in)\n        self.assertIs(config_out, config_in)\n\n        # 4. 
Mode + Config (mismatch)\n        config_in = Int8QuantizationConfig()\n        with self.assertRaisesRegex(ValueError, \"Contradictory arguments\"):\n            validate_and_resolve_config(\"int4\", config_in)\n\n        # 5. Invalid mode\n        with self.assertRaisesRegex(ValueError, \"Invalid quantization mode\"):\n            validate_and_resolve_config(\"invalid_mode\", None)\n\n        # 6. GPTQ without config\n        with self.assertRaisesRegex(ValueError, \"must pass a `GPTQConfig`\"):\n            validate_and_resolve_config(\"gptq\", None)\n\n        # 7. Contradictory config\n        with self.assertRaisesRegex(ValueError, \"Contradictory arguments\"):\n            validate_and_resolve_config(\"gptq\", Int8QuantizationConfig())\n\n        # 8. GPTQ with invalid config type (but correct mode)\n        class FakeGPTQConfig(QuantizationConfig):\n            @property\n            def mode(self):\n                return \"gptq\"\n\n        with self.assertRaisesRegex(ValueError, \"requires a valid `config`\"):\n            validate_and_resolve_config(\"gptq\", FakeGPTQConfig())\n\n    def test_int8_quantization_config_output_dtype_mismatch(self):\n        # Invalid output_dtype\n        q = AbsMaxQuantizer(\n            axis=0, value_range=(-127, 127), output_dtype=\"int16\"\n        )\n        with self.assertRaisesRegex(ValueError, \"output_dtype='int8'\"):\n            Int8QuantizationConfig(weight_quantizer=q)\n\n    def test_int4_quantization_config_output_dtype_mismatch(self):\n        # Invalid output_dtype (using per-channel mode to test output_dtype)\n        q = AbsMaxQuantizer(axis=0, value_range=(-8, 7), output_dtype=\"int16\")\n        with self.assertRaisesRegex(ValueError, \"output_dtype='int8'\"):\n            Int4QuantizationConfig(weight_quantizer=q, block_size=None)\n\n    def test_model_save_and_load(self):\n        \"\"\"\n        Test custom quantizer serialization for model save and load.\n        \"\"\"\n        # Setup\n        
weight_range = (-100, 100)\n        custom_quantizer = AbsMaxQuantizer(axis=0, value_range=weight_range)\n        config = Int8QuantizationConfig(\n            weight_quantizer=custom_quantizer,\n            activation_quantizer=None,\n        )\n\n        layer = layers.Dense(10)\n        layer.build((None, 5))\n        layer.quantize(\"int8\", config=config)\n\n        model = models.Sequential([layer])\n        model.build((None, 5))\n\n        # Save to temp file\n        filepath = os.path.join(self.get_temp_dir(), \"quantized_model.keras\")\n        model.save(filepath)\n\n        # Load back\n        loaded_model = saving.load_model(filepath)\n\n        # Verify\n        loaded_layer = loaded_model.layers[0]\n        self.assertIsInstance(\n            loaded_layer.quantization_config, Int8QuantizationConfig\n        )\n\n        quantizer = loaded_layer.quantization_config.weight_quantizer\n        self.assertIsInstance(quantizer, AbsMaxQuantizer)\n        self.assertEqual(quantizer.axis, (0,))\n        self.assertAllEqual(quantizer.value_range, weight_range)\n        self.assertIsNone(loaded_layer.quantization_config.activation_quantizer)\n        self.assertTrue(loaded_layer._is_quantized)\n\n    def test_awq_requires_config(self):\n        \"\"\"Test that AWQ mode requires a config.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"AWQConfig\"):\n            validate_and_resolve_config(\"awq\", None)\n\n    def test_awq_requires_correct_config_type(self):\n        \"\"\"Test that AWQ requires AWQConfig type.\"\"\"\n        # Int8QuantizationConfig has mode='int8', so passing mode='awq' raises\n        # a contradictory arguments error\n        with self.assertRaisesRegex(ValueError, \"Contradictory arguments\"):\n            validate_and_resolve_config(\"awq\", Int8QuantizationConfig())\n"
  },
  {
    "path": "keras/src/quantizers/quantizers.py",
    "content": "import math\n\nimport ml_dtypes\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import KerasTensor\nfrom keras.src.backend import any_symbolic_tensors\nfrom keras.src.backend.common.backend_utils import canonicalize_axis\nfrom keras.src.backend.common.backend_utils import standardize_axis_for_numpy\nfrom keras.src.ops.operation import Operation\nfrom keras.src.quantizers.gptq_config import GPTQConfig\n\n\"\"\"Int8-related classes and methods\"\"\"\n\n\n@keras_export([\"keras.Quantizer\", \"keras.quantizers.Quantizer\"])\nclass Quantizer:\n    def __init__(self, output_dtype=\"int8\"):\n        self.output_dtype = output_dtype\n\n    def __call__(self, x):\n        \"\"\"Compute a quantized output from an input tensor.\"\"\"\n        return x\n\n    @classmethod\n    def from_config(cls, config):\n        \"\"\"Creates a quantizer from its config.\n\n        This method is the reverse of `get_config`,\n        capable of instantiating the same quantizer from the config\n        dictionary.\n\n        This method is used by Keras `model_to_estimator`, saving and\n        loading models to HDF5 formats, Keras model cloning, some visualization\n        utilities, and exporting models to and from JSON.\n\n        Args:\n            config: A Python dictionary, typically the output of get_config.\n\n        Returns:\n            A quantizer instance.\n        \"\"\"\n        return cls(**config)\n\n    def get_config(self):\n        \"\"\"Returns the config of the quantizer.\n\n        A quantizer config is a Python dictionary (serializable)\n        containing all configuration parameters of the quantizer.\n        The same quantizer can be reinstantiated later\n        (without any saved state) from this configuration.\n\n        This method is optional if you are just training and executing models,\n        exporting to and from SavedModels, or using 
weight checkpoints.\n\n        This method is required for Keras `model_to_estimator`, saving and\n        loading models to HDF5 formats, Keras model cloning, some visualization\n        utilities, and exporting models to and from JSON.\n\n        Returns:\n            Python dictionary.\n        \"\"\"\n        raise NotImplementedError(f\"{self} does not implement get_config()\")\n\n\n@keras_export(\"keras.quantizers.abs_max_quantize\")\ndef abs_max_quantize(\n    inputs,\n    axis,\n    value_range=(-127, 127),\n    dtype=\"int8\",\n    epsilon=backend.epsilon(),\n    to_numpy=False,\n):\n    \"\"\"\n    Quantizes the input tensor using the absolute maximum quantization scheme.\n\n    Args:\n        inputs: Input tensor to quantize.\n        axis: Axis along which to compute the quantization range.\n        value_range: Tuple of the minimum and maximum values of the quantization\n            range.\n        dtype: Data type of the quantized output.\n        epsilon: Small value to avoid division by zero.\n        to_numpy: Whether to perform the quantization in numpy. This performs\n            the computation on the host CPU and can be useful for saving memory\n            on the device. 
If False, the computation is performed on the device.\n\n    Returns:\n        A tuple of the quantized tensor and the scale.\n    \"\"\"\n    if to_numpy:\n        # Save memory on the device using numpy\n        original_dtype = backend.standardize_dtype(inputs.dtype)\n        inputs = ops.convert_to_numpy(inputs)\n        axis = standardize_axis_for_numpy(axis)\n        scale = np.divide(\n            value_range[1],\n            np.add(np.max(np.abs(inputs), axis=axis, keepdims=True), epsilon),\n        )\n        outputs = np.multiply(inputs, scale)\n        outputs = np.clip(np.round(outputs), value_range[0], value_range[1])\n        outputs = outputs.astype(dtype)\n        return ops.convert_to_tensor(outputs), ops.convert_to_tensor(\n            scale, dtype=original_dtype\n        )\n\n    inputs = ops.convert_to_tensor(inputs)\n    scale = ops.divide(\n        value_range[1],\n        ops.add(ops.max(ops.abs(inputs), axis=axis, keepdims=True), epsilon),\n    )\n    scale = ops.cast(scale, backend.standardize_dtype(inputs.dtype))\n    outputs = ops.multiply(inputs, scale)\n    outputs = ops.clip(ops.round(outputs), value_range[0], value_range[1])\n    outputs = ops.cast(outputs, dtype)\n    return outputs, scale\n\n\n@keras_export(\"keras.quantizers.abs_max_quantize_grouped_with_zero_point\")\ndef abs_max_quantize_grouped_with_zero_point(\n    inputs,\n    block_size,\n    value_range=(-8, 7),\n    dtype=\"int8\",\n    epsilon=backend.epsilon(),\n    to_numpy=False,\n):\n    \"\"\"Quantizes a 2D tensor using grouped asymmetric quantization with\n    zero point.\n\n    Groups are formed along axis 0 (the input/contracting dimension).\n    Each group of `block_size` rows gets its own scale factor and zero point\n    per column. This is useful for weight distributions that are not centered\n    around zero.\n\n    Args:\n        inputs: Input tensor to quantize. 
Shape: `(input_dim, output_dim)`.\n        block_size: Number of elements per group along axis 0.\n        value_range: Tuple of `(min, max)` quantization range.\n        dtype: Data type of quantized output.\n        epsilon: Small value to avoid division by zero.\n        to_numpy: Whether to perform computation in numpy for memory\n            efficiency.\n\n    Returns:\n        A tuple `(quantized_tensor, scale, zero_point)` where:\n            - `quantized_tensor`: Same shape as inputs, dtype=`dtype`.\n            - `scale`: Shape `(n_groups, output_dim)` where\n              `n_groups = ceil(input_dim / block_size)`.\n            - `zero_point`: Shape `(n_groups, output_dim)`, dtype=`int8`.\n\n    Example:\n\n    ```python\n    >>> import numpy as np\n    >>> from keras.quantizers import abs_max_quantize_grouped_with_zero_point\n    >>> kernel = np.random.randn(512, 256).astype(\"float32\")\n    >>> quantized, scale, zero_point = abs_max_quantize_grouped_with_zero_point(\n    ...     kernel, block_size=128, value_range=(-8, 7)\n    ... 
)\n    >>> quantized.shape\n    (512, 256)\n    >>> scale.shape  # 512 / 128 = 4 groups\n    (4, 256)\n    >>> zero_point.shape\n    (4, 256)\n    ```\n    \"\"\"\n    if to_numpy:\n        return _abs_max_quantize_grouped_with_zero_point_numpy(\n            inputs, block_size, value_range, dtype, epsilon\n        )\n    return _abs_max_quantize_grouped_with_zero_point_tensor(\n        inputs, block_size, value_range, dtype, epsilon\n    )\n\n\ndef _abs_max_quantize_grouped_with_zero_point_numpy(\n    inputs, block_size, value_range, dtype, epsilon\n):\n    \"\"\"NumPy implementation of grouped asymmetric quantization.\n\n    Uses NumPy for computation to reduce GPU memory usage during\n    model quantization.\n    \"\"\"\n    original_dtype = backend.standardize_dtype(inputs.dtype)\n    inputs = ops.convert_to_numpy(inputs)\n\n    input_dim, output_dim = inputs.shape\n    n_groups = math.ceil(input_dim / block_size)\n    qmin, qmax = value_range\n\n    # Zero-pad rows so input_dim is divisible by block_size\n    padded_input_dim = n_groups * block_size\n    if padded_input_dim > input_dim:\n        padding = np.zeros(\n            (padded_input_dim - input_dim, output_dim), dtype=inputs.dtype\n        )\n        inputs_padded = np.concatenate([inputs, padding], axis=0)\n    else:\n        inputs_padded = inputs\n\n    inputs_reshaped = inputs_padded.reshape(n_groups, block_size, output_dim)\n\n    # Compute per-group min/max for asymmetric quantization\n    min_val = np.min(inputs_reshaped, axis=1, keepdims=True)\n    max_val = np.max(inputs_reshaped, axis=1, keepdims=True)\n\n    # Scale maps the [min, max] range to [qmin, qmax]\n    scale = np.divide(np.subtract(max_val, min_val) + epsilon, qmax - qmin)\n\n    # Zero point shifts the quantized range to include the original zero\n    zero_point = np.round(np.divide(-min_val, scale)) + qmin\n    zero_point = np.clip(zero_point, qmin, qmax)\n\n    # Quantize: q = round(input / scale) + zero_point\n    outputs = 
np.round(np.divide(inputs_reshaped, scale)) + zero_point\n    outputs = np.clip(outputs, qmin, qmax)\n    outputs = outputs.astype(dtype)\n\n    # Remove padding and squeeze to (n_groups, output_dim)\n    outputs = outputs.reshape(padded_input_dim, output_dim)[:input_dim, :]\n    scale = np.squeeze(scale, axis=1)\n    zero_point = np.squeeze(zero_point, axis=1).astype(\"int8\")\n\n    return (\n        ops.convert_to_tensor(outputs),\n        ops.convert_to_tensor(scale, dtype=original_dtype),\n        ops.convert_to_tensor(zero_point),\n    )\n\n\ndef _abs_max_quantize_grouped_with_zero_point_tensor(\n    inputs, block_size, value_range, dtype, epsilon\n):\n    \"\"\"Tensor backend implementation of grouped asymmetric quantization.\"\"\"\n    original_dtype = backend.standardize_dtype(inputs.dtype)\n    inputs = ops.convert_to_tensor(inputs)\n\n    input_shape = ops.shape(inputs)\n    input_dim = input_shape[0]\n    output_dim = input_shape[1]\n    qmin, qmax = value_range\n\n    # Infer bit-width from quantization range (e.g., [-8, 7] -> 4 bits)\n    num_levels = qmax - qmin + 1\n    bits = int(math.log2(num_levels))\n\n    n_groups = int(math.ceil(int(input_dim) / block_size))\n    padded_input_dim = n_groups * block_size\n\n    # Transpose to [out_features, in_features] for\n    # compute_quantization_parameters\n    inputs_t = ops.transpose(inputs)\n\n    # Compute scale and zero point using the unified quantization function\n    scale_t, zero_point_t, _ = compute_quantization_parameters(\n        inputs_t,\n        bits=bits,\n        symmetric=False,\n        per_channel=True,\n        group_size=block_size,\n        compute_dtype=original_dtype,\n        epsilon=epsilon,\n        signed=True,\n    )\n\n    # Transpose results back to (n_groups, output_dim)\n    scale = ops.transpose(scale_t)\n    zero_point = ops.transpose(zero_point_t)\n\n    # Zero-pad rows so input_dim is divisible by block_size\n    pad_size = padded_input_dim - int(input_dim)\n    if 
pad_size > 0:\n        padding = ops.zeros((pad_size, output_dim), dtype=inputs.dtype)\n        inputs_padded = ops.concatenate([inputs, padding], axis=0)\n    else:\n        inputs_padded = inputs\n\n    inputs_reshaped = ops.reshape(\n        inputs_padded, (n_groups, block_size, output_dim)\n    )\n\n    # Expand scale and zero_point for broadcasting across block_size\n    scale_expanded = ops.expand_dims(scale, axis=1)\n    zero_point_expanded = ops.expand_dims(zero_point, axis=1)\n\n    # Quantize: q = round(input / scale) + zero_point\n    outputs = ops.add(\n        ops.round(ops.divide(inputs_reshaped, scale_expanded)),\n        zero_point_expanded,\n    )\n    outputs = ops.clip(outputs, qmin, qmax)\n    outputs = ops.cast(outputs, dtype)\n\n    # Remove padding\n    outputs = ops.reshape(outputs, (padded_input_dim, output_dim))\n    outputs = outputs[:input_dim, :]\n\n    return outputs, scale, zero_point\n\n\n@keras_export(\"keras.quantizers.AbsMaxQuantizer\")\nclass AbsMaxQuantizer(Quantizer):\n    def __init__(\n        self,\n        axis=None,  # Deprecated, provide axis in __call__ instead.\n        value_range=(-127, 127),\n        epsilon=backend.epsilon(),\n        output_dtype=\"int8\",\n    ):\n        Quantizer.__init__(self, output_dtype=output_dtype)\n        if axis is not None:\n            if isinstance(axis, int):\n                axis = (axis,)\n            self.axis = tuple(axis)\n        else:\n            self.axis = None\n        self.value_range = value_range\n        self.epsilon = epsilon\n        if output_dtype == \"int8\":\n            if value_range[0] < -128 or value_range[1] > 127:\n                raise ValueError(\n                    f\"Quantizer with output_dtype='int8' requires value_range \"\n                    f\"to be within the interval [-128, 127]. 
Received: \"\n                    f\"value_range={value_range}\"\n                )\n\n    def __call__(self, x, axis=None, to_numpy=False):\n        \"\"\"\n        Quantizes the input tensor.\n\n        Args:\n            x: Input tensor to quantize.\n            axis: Axis along which to compute the quantization range. If None,\n                uses the axis specified in the constructor. If None and no axis\n                was specified in the constructor, defaults to -1.\n            to_numpy: Whether to perform the quantization in numpy. This\n                performs the computation on the host CPU and can be useful for\n                saving memory on the device. If False, the computation is\n                performed on the device.\n\n        Returns:\n            A tuple of the quantized tensor and the scale.\n        \"\"\"\n        if axis is None:\n            axis = self.axis\n        if axis is None:\n            # Default to -1 if no axis is specified\n            axis = -1\n        quantized_x, scale = abs_max_quantize(\n            x,\n            axis,\n            self.value_range,\n            self.output_dtype,\n            self.epsilon,\n            to_numpy,\n        )\n        return quantized_x, scale\n\n    def get_config(self):\n        config = {\n            \"value_range\": self.value_range,\n            \"epsilon\": self.epsilon,\n            \"output_dtype\": self.output_dtype,\n        }\n        if self.axis is not None:\n            config[\"axis\"] = self.axis\n        return config\n\n\ndef adjust_and_nudge(min_range, max_range, num_bits, narrow_range):\n    \"\"\"Adjusts and nudges the quantization range for better accuracy.\"\"\"\n    # Use higher precision for the computation.\n    compute_dtype = backend.result_type(min_range.dtype, \"float32\")\n    min_range = ops.cast(min_range, compute_dtype)\n    max_range = ops.cast(max_range, compute_dtype)\n\n    quant_max = (1 << num_bits) - 1\n    quant_min = 0 if not 
narrow_range else 1\n    diff_range = ops.subtract(max_range, min_range)\n\n    # Calculate the scale and ensure it's positive\n    scale = ops.divide(diff_range, quant_max - quant_min)\n\n    # Re-calculate the inverse to avoid loss of precision\n    inv_scale = ops.divide(quant_max - quant_min, diff_range)\n\n    # Calculate the zero point from the min range\n    zero_point_from_min = quant_min - ops.divide(min_range, scale)\n\n    # Ensure zero point is within valid range [0, quant_max]\n    zero_point = ops.clip(zero_point_from_min, quant_min, quant_max)\n\n    # Nudge zero point if it's very close to an integer\n    nudged_zero_point = ops.round(zero_point)\n\n    # Calculate nudged limits\n    nudged_min = ops.multiply(ops.subtract(quant_min, nudged_zero_point), scale)\n    nudged_max = ops.multiply(ops.subtract(quant_max, nudged_zero_point), scale)\n\n    return nudged_min, nudged_max, scale, inv_scale\n\n\nclass FakeQuantWithMinMaxVars(Operation):\n    def __init__(self, num_bits=8, narrow_range=False, axis=None):\n        super().__init__()\n        self.num_bits = num_bits\n        self.narrow_range = narrow_range\n        self.axis = axis\n\n    def call(self, inputs, min_vals, max_vals):\n        return fake_quant_with_min_max_vars(\n            inputs,\n            min_vals,\n            max_vals,\n            num_bits=self.num_bits,\n            narrow_range=self.narrow_range,\n            axis=self.axis,\n        )\n\n    def compute_output_spec(self, inputs, min_vals, max_vals):\n        return KerasTensor(inputs.shape, dtype=inputs.dtype)\n\n\n@keras_export(\"keras.quantizers.fake_quant_with_min_max_vars\")\ndef fake_quant_with_min_max_vars(\n    inputs,\n    min_vals,\n    max_vals,\n    num_bits=8,\n    narrow_range=False,\n    axis=None,\n):\n    \"\"\"Perform per-tensor or per-channel fake quantization.\n\n    `[min_vals, max_vals]` define the clamping range for the `inputs`.\n\n    The `inputs` are quantized into the quantization range:\n    - 
`[0, 2^num_bits - 1]` when `narrow_range=False`\n    - `[1, 2^num_bits - 1]` when `narrow_range=True`\n\n    After quantization, the values are dequantized and output as floats within\n    the `[min_vals, max_vals]` interval.\n\n    This operation supports gradient computation, allowing `min_vals` and\n    `max_vals` to be trained.\n\n    Args:\n        inputs: Input Keras tensor of float dtype.\n        min_vals: A global minimum scalar or a per-channel minimum tensor.\n        max_vals: A global maximum scalar or a per-channel maximum tensor.\n        num_bits: Quantization bit width (e.g., `8` for int8). Defaults to `8`.\n        narrow_range: Whether to use narrow quantization range. Defaults to\n            `False`.\n        axis: Axis along which to perform per-channel quantization. If `None`,\n            per-tensor quantization is performed. Defaults to `None`.\n\n    Returns:\n        Tensor: A Keras tensor with fake quantization applied.\n    \"\"\"\n    if any_symbolic_tensors((inputs,)):\n        return FakeQuantWithMinMaxVars(\n            num_bits=num_bits, narrow_range=narrow_range, axis=axis\n        ).symbolic_call(inputs, min_vals, max_vals)\n\n    inputs = ops.convert_to_tensor(inputs)\n    min_vals = ops.convert_to_tensor(min_vals)\n    max_vals = ops.convert_to_tensor(max_vals)\n    num_bits = int(num_bits)\n\n    if axis is not None:\n        axis = canonicalize_axis(axis, inputs.ndim)\n\n    # Shortcut for TensorFlow backend by using `tf.quantization.fake_quant_*`\n    # apis. 
This is necessary to be recognizable for the TFLite converter.\n    if backend.backend() == \"tensorflow\":\n        import tensorflow as tf\n\n        # `tf.quantization.fake_quant_*` only supports float32.\n        dtype = backend.standardize_dtype(inputs.dtype)\n        if axis is None:\n            outputs = tf.quantization.fake_quant_with_min_max_vars(\n                ops.cast(inputs, \"float32\"),\n                ops.cast(ops.reshape(min_vals, ()), \"float32\"),\n                ops.cast(ops.reshape(max_vals, ()), \"float32\"),\n                num_bits=num_bits,\n                narrow_range=narrow_range,\n            )\n            return ops.cast(outputs, dtype=dtype)\n        else:\n            # `tf.quantization.fake_quant_with_min_max_vars_per_channel` only\n            # supports the last channel for the per-channel quantization. We\n            # use `ops.swapaxes` for the pre- and post-processing.\n            last_axis = inputs.ndim - 1\n            inputs = ops.swapaxes(inputs, axis, last_axis)\n            outputs = tf.quantization.fake_quant_with_min_max_vars_per_channel(\n                ops.cast(inputs, \"float32\"),\n                ops.cast(min_vals, \"float32\"),\n                ops.cast(max_vals, \"float32\"),\n                num_bits=num_bits,\n                narrow_range=narrow_range,\n            )\n            outputs = ops.cast(outputs, dtype=dtype)\n            return ops.swapaxes(outputs, last_axis, axis)\n\n    @ops.custom_gradient\n    def _fake_quant_with_min_max_vars_per_channel(x, min_val, max_val):\n        dtype = backend.standardize_dtype(x.dtype)\n\n        # Calculate quantization parameters for all channels at once\n        nudged_min, nudged_max, scale, inv_scale = adjust_and_nudge(\n            min_val, max_val, num_bits, narrow_range\n        )\n\n        quant_zero = ops.floor(\n            ops.add(ops.multiply(-nudged_min, inv_scale), 0.5)\n        )\n        x_clamped = ops.clip(\n            ops.cast(x, 
nudged_min.dtype), nudged_min, nudged_max\n        )\n        x_clamped_shifted = ops.subtract(x_clamped, nudged_min)\n        result = ops.multiply(\n            ops.floor(\n                ops.add(\n                    ops.subtract(\n                        ops.multiply(x_clamped_shifted, inv_scale), quant_zero\n                    ),\n                    0.5,\n                )\n            ),\n            scale,\n        )\n        result = ops.cast(result, dtype=dtype)\n\n        # Create gradient mask for all channels\n        masks = ops.logical_and(\n            ops.greater_equal(x, nudged_min), ops.less_equal(x, nudged_max)\n        )\n\n        def grad(*args, upstream=None):\n            if upstream is None:\n                (upstream,) = args\n\n            # Gradient for x\n            dx = ops.where(masks, upstream, 0.0)\n            axes = [i for i in range(len(dx.shape)) if i != axis]\n\n            # Gradient for min_val\n            # When x is clipped to min, the gradient flows to min_val\n            min_mask = ops.less_equal(x, nudged_min)\n            grad_min = ops.where(min_mask, upstream, 0.0)\n            if axis is not None:\n                grad_min = ops.sum(grad_min, axis=axes)\n            else:\n                grad_min = ops.sum(grad_min)\n            grad_min = ops.reshape(grad_min, ops.shape(min_val))\n\n            # Gradient for max_val\n            # When x is clipped to max, the gradient flows to max_val\n            max_mask = ops.greater_equal(x, nudged_max)\n            grad_max = ops.where(max_mask, upstream, 0.0)\n            if axis is not None:\n                grad_max = ops.sum(grad_max, axis=axes)\n            else:\n                grad_max = ops.sum(grad_max)\n            grad_max = ops.reshape(grad_max, ops.shape(max_val))\n\n            return dx, grad_min, grad_max\n\n        return result, grad\n\n    return _fake_quant_with_min_max_vars_per_channel(inputs, min_vals, max_vals)\n\n\n\"\"\"Float8-related 
methods\"\"\"\n\n\n@keras_export(\"keras.quantizers.compute_float8_scale\")\ndef compute_float8_scale(amax, scale, dtype_max, margin=0):\n    # The algorithm for computing the new scale is sourced from\n    # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.update_fp8_metas\n    # wherein the `original_scale` corresponds to the reciprocal of the\n    # `scale` passed in this function.\n    scale = ops.reciprocal(scale)\n    sf = ops.divide(ops.divide(dtype_max, amax), 2**margin)\n    sf = ops.where(amax > 0.0, sf, scale)\n    sf = ops.where(ops.isfinite(amax), sf, scale)\n    return ops.reciprocal(sf)\n\n\n@keras_export(\"keras.quantizers.compute_float8_amax_history\")\ndef compute_float8_amax_history(x, amax_history):\n    amax_update = ops.cast(ops.max(ops.abs(x)), amax_history.dtype)\n    new_amax_history = ops.scatter_update(\n        ops.roll(amax_history, shift=-1),\n        [[0]],\n        ops.reshape(amax_update, [1]),\n    )\n    return new_amax_history\n\n\n@keras_export(\"keras.quantizers.quantize_and_dequantize\")\ndef quantize_and_dequantize(inputs, scale, quantized_dtype, compute_dtype):\n    # Quantize\n    quantized_dtype_max = ops.cast(\n        float(ml_dtypes.finfo(quantized_dtype).max), compute_dtype\n    )\n    x = ops.divide(inputs, ops.cast(scale, compute_dtype))\n    x = ops.clip(x, -quantized_dtype_max, quantized_dtype_max)\n    x = ops.cast(x, quantized_dtype)\n\n    # Dequantize\n    x = ops.multiply(ops.cast(x, compute_dtype), ops.cast(scale, compute_dtype))\n    return x\n\n\n@keras_export(\"keras.quantizers.pack_int4\")\ndef pack_int4(arr, axis=0, dtype=\"int8\"):\n    \"\"\"Pack an int4 tensor into an int8 tensor with packed nibbles.\n\n    The input values must already be int8 in the signed range `[-8, 7]` and\n    represent the desired int4 values. 
Packing is performed along the specified\n    axis (default is 0).\n\n    For every two consecutive rows, the **low nibble** of the output byte\n    stores the value from the first row, and the **high nibble** stores\n    the value from the second row.\n\n    Args:\n        arr: An `int8` or `uint8` tensor containing int4 values in the range\n            `[-8, 7]`.\n        axis: The axis along which to pack the tensor. Defaults to 0.\n        dtype: The data type of the input and packed tensor. Can be\n            `\"int8\"` or `\"uint8\"`. Defaults to `\"int8\"`.\n\n    Returns:\n        tuple: A tuple `(packed, packed_shape, orig_rows)` where `packed` is\n            the packed int8 tensor with int4 values stored in nibbles,\n            `packed_shape` is the shape of the packed tensor, and `orig_rows`\n            is the original (unpacked) row count prior to any padding that may\n            have been inserted when an odd number of rows is supplied.\n\n    Example:\n\n    ```python\n    >>> import numpy as np\n    >>> from keras.quantizers import pack_int4, unpack_int4\n\n    # Example with axis=0\n    # Original array has shape (3, 2)\n    >>> original_array = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8)\n\n    # Pack the array along axis 0. Since the length of axis 0 (3) is\n    # odd, it will be padded to a length of 4. 
The packed array will\n    # have a shape of (ceil(3/2), 2) = (2, 2).\n    >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=0)\n    >>> print(\"Packed array:\\n\", packed)\n    Packed array:\n    [[  45 -121]\n    [   1    0]]\n\n    # Now, unpack the array back to its original form\n    >>> unpacked = unpack_int4(packed, orig_len, axis=0)\n    >>> print(\"Unpacked array:\\n\", unpacked)\n    Unpacked array:\n    [[-3  7]\n    [ 2 -8]\n    [ 1  0]]\n    >>> np.allclose(original_array, unpacked)\n    True\n\n    # Example with axis=1\n    # Original array has shape (2, 3)\n    >>> original_array = np.array([[-3, 7, 2], [-8, 1, 0]], dtype=np.int8)\n\n    # Pack along axis 1. Length of axis 1 (3) is padded to 4.\n    # The new shape is (2, ceil(3/2)) = (2, 2).\n    >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=1)\n    >>> print(\"Packed array:\\n\", packed)\n    Packed array:\n    [[ 125   2]\n    [  24   0]]\n\n    # Unpack the array\n    >>> unpacked = unpack_int4(packed, orig_len, axis=1)\n    >>> print(\"Unpacked array:\\n\", unpacked)\n    Unpacked array:\n    [[-3  7  2]\n    [-8  1  0]]\n    >>> np.allclose(original_array, unpacked)\n    True\n    ```\n    \"\"\"\n    if dtype not in (\"int8\", \"uint8\"):\n        raise ValueError(\n            f\"Expected dtype to be 'int8' or 'uint8', but got '{dtype}'.\"\n        )\n    if backend.standardize_dtype(arr.dtype) != dtype:\n        raise TypeError(\n            f\"Expected {dtype} tensor for packing, got \"\n            f\"{backend.standardize_dtype(arr.dtype)}.\"\n        )\n\n    # Perform packing in numpy. Packing is only called during\n    # quantization (not inference), and numpy correctly handles int8\n    # overflow in bitwise operations. Some accelerators (e.g. 
TPU) may\n    # produce incorrect results for int8 left_shift that overflows, so\n    # using numpy avoids device-specific issues.\n    arr_np = ops.convert_to_numpy(arr)\n    np_dtype = np.dtype(dtype)\n\n    rank = len(arr_np.shape)\n    if axis < 0:\n        axis += rank\n\n    # Move the pack axis to the front for uniform handling.\n    arr_np = np.moveaxis(arr_np, axis, 0)\n\n    # Pad to even length along the front axis.\n    n = arr_np.shape[0]\n    if n % 2 == 1:\n        pad_shape = (1,) + arr_np.shape[1:]\n        arr_np = np.concatenate(\n            [arr_np, np.zeros(pad_shape, dtype=arr_np.dtype)], axis=0\n        )\n\n    # Group in pairs and pack nibbles.\n    low = arr_np[::2]\n    high = arr_np[1::2]\n\n    mask = np.array(0x0F, dtype=np_dtype)\n    low_u = np.bitwise_and(low.astype(np_dtype), mask)\n    high_u = np.bitwise_and(high.astype(np_dtype), mask)\n\n    packed_np = np.bitwise_or(\n        low_u, np.left_shift(high_u, np.array(4, dtype=np_dtype))\n    )\n    packed_np = packed_np.astype(np_dtype)\n\n    # Move the pack axis back to its original position.\n    packed_np = np.moveaxis(packed_np, 0, axis)\n\n    packed = ops.convert_to_tensor(packed_np)\n    return packed, tuple(packed_np.shape), n\n\n\n@keras_export(\"keras.quantizers.unpack_int4\")\ndef unpack_int4(packed, orig_len, axis=0, dtype=\"int8\"):\n    \"\"\"Unpack a packed int4 back to an int8 tensor in the range [-8, 7].\n\n    This function reverses the packing performed by `pack_int4`, restoring\n    the original int8 tensor (values in the range [-8, 7]) from a packed int8\n    tensor where each element contains two int4 values (one in the lower nibble,\n    one in the upper nibble).\n\n    The function restores the original axis order and removes any\n    padding that was added during packing.\n\n    Args:\n        packed: An int8 tensor containing packed int4 values along the\n            specified axis. 
Each int8 value encodes two int4 values.\n        orig_len: The original (unpadded) length of the axis that was\n            packed. This is used to remove any padding that may have\n            been added during packing to ensure an even number of rows.\n        axis: The axis along which the tensor was packed. Defaults to 0.\n        dtype: The data type of the input and unpacked tensor. Can be\n            `\"int8\"` or `\"uint8\"`. Defaults to `\"int8\"`.\n\n    Returns:\n        unpacked: An int8 tensor with the same shape as the original\n            (unpacked) tensor, with values in the range [-8, 7].\n\n    Example:\n\n    ```python\n    >>> import numpy as np\n    >>> from keras.quantizers import pack_int4, unpack_int4\n\n    # Example with axis=0\n    # Original array has shape (3, 2)\n    >>> original_array = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8)\n\n    # Pack the array along axis 0. Since the length of axis 0 (3) is\n    # odd, it will be padded to a length of 4. The packed array will\n    # have a shape of (ceil(3/2), 2) = (2, 2).\n    >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=0)\n    >>> print(\"Packed array:\\n\", packed)\n    Packed array:\n    [[  45 -121]\n    [   1    0]]\n\n    # Now, unpack the array back to its original form\n    >>> unpacked = unpack_int4(packed, orig_len, axis=0)\n    >>> print(\"Unpacked array:\\n\", unpacked)\n    Unpacked array:\n    [[-3  7]\n    [ 2 -8]\n    [ 1  0]]\n    >>> np.allclose(original_array, unpacked)\n    True\n\n    # Example with axis=1\n    # Original array has shape (2, 3)\n    >>> original_array = np.array([[-3, 7, 2], [-8, 1, 0]], dtype=np.int8)\n\n    # Pack along axis 1. 
Length of axis 1 (3) is padded to 4.\n    # The new shape is (2, ceil(3/2)) = (2, 2).\n    >>> packed, packed_shape, orig_len = pack_int4(original_array, axis=1)\n    >>> print(\"Packed array:\\n\", packed)\n    Packed array:\n    [[ 125   2]\n    [  24   0]]\n\n    # Unpack the array\n    >>> unpacked = unpack_int4(packed, orig_len, axis=1)\n    >>> print(\"Unpacked array:\\n\", unpacked)\n    Unpacked array:\n    [[-3  7  2]\n    [-8  1  0]]\n    >>> np.allclose(original_array, unpacked)\n    True\n    ```\n    \"\"\"\n    if dtype not in (\"int8\", \"uint8\"):\n        raise ValueError(\n            f\"Expected dtype to be 'int8' or 'uint8', but got '{dtype}'.\"\n        )\n\n    if backend.standardize_dtype(packed.dtype) not in (\"int8\", \"uint8\"):\n        raise TypeError(\n            f\"Expected int8 or uint8 tensor for unpacking, got {packed.dtype}\"\n        )\n\n    def to_signed(x):\n        \"\"\"Converts unpacked nibbles [0, 15] to signed int4 [-8, 7].\n\n        Uses a branchless XOR approach: (x ^ 8) - 8\n        This maps: 0->0, 1->1, ..., 7->7, 8->-8, 9->-7, ..., 15->-1\n        \"\"\"\n        dtype_x = backend.standardize_dtype(x.dtype)\n        eight = ops.cast(8, dtype_x)\n        return ops.subtract(ops.bitwise_xor(x, eight), eight)\n\n    rank = getattr(packed.shape, \"rank\", None) or len(packed.shape)\n    if axis < 0:\n        axis += rank\n\n    # Fast path for axis==0 (common case in Dense layers)\n    if axis == 0 and rank == 2:\n        mask = ops.array(0x0F, dtype=packed.dtype)\n        low_unpacked = ops.bitwise_and(packed, mask)\n        high_unpacked = ops.bitwise_and(ops.right_shift(packed, 4), mask)\n\n        if dtype == \"int8\":\n            low_unpacked = to_signed(low_unpacked)\n            high_unpacked = to_signed(high_unpacked)\n\n        low_final = ops.cast(low_unpacked, dtype)\n        high_final = ops.cast(high_unpacked, dtype)\n\n        # Interleave along axis 0 and reshape\n        stacked = ops.stack([low_final, 
high_final], axis=1)\n        unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(packed)[1:]))\n\n        # Remove padding and return\n        return unpacked[:orig_len, ...]\n\n    # General case\n    perm = [axis] + [i for i in range(rank) if i != axis]\n    inv_perm = [perm.index(i) for i in range(rank)]\n    transposed = ops.transpose(packed, perm)\n\n    # 1. Split nibbles.\n    mask = ops.array(0x0F, dtype=packed.dtype)\n    low = ops.bitwise_and(transposed, mask)\n    high = ops.bitwise_and(ops.right_shift(transposed, 4), mask)\n\n    # 2. Conditionally convert to signed.\n    if dtype == \"int8\":\n        low = to_signed(low)\n        high = to_signed(high)\n\n    low = ops.cast(low, dtype)\n    high = ops.cast(high, dtype)\n\n    # 3. Interleave and reshape.\n    stacked = ops.stack([low, high], axis=1)\n    unpacked = ops.reshape(stacked, (-1,) + tuple(ops.shape(transposed)[1:]))\n\n    # 4. Remove padding and restore original layout.\n    unpacked = unpacked[:orig_len, ...]\n    unpacked = ops.transpose(unpacked, inv_perm)\n\n    return unpacked\n\n\nclass GPTQQuantizer(Quantizer):\n    \"\"\"A class that handles the quantization of weights using GPTQ method.\n\n    This class provides methods to find quantization parameters (scale and zero)\n    for a given tensor and can be used to quantize weights in a GPTQ context.\n\n    Args:\n        weight_bits: (int) The number of bits to quantize to (e.g., 4).\n        per_channel: (bool) A flag indicating whether quantization is\n            applied per-channel (`True`) or per-tensor (`False`).\n            Defaults to `False`.\n        symmetric: (bool) A flag indicating whether symmetric (`True`) or\n            asymmetric (`False`) quantization is used. Defaults to `False`.\n        group_size: (int) The size of weight groups for quantization. 
A\n            value of -1 indicates that grouping is not used.\n            Defaults to -1.\n\n    These values are read from the `config` argument (a `GPTQConfig`)\n    passed to the constructor, not from individual keyword arguments.\n    \"\"\"\n\n    def __init__(\n        self,\n        config=None,\n        compute_dtype=\"float32\",\n    ):\n        Quantizer.__init__(self)\n        # Build the default config per instance to avoid sharing a mutable\n        # default argument across instances.\n        if config is None:\n            config = GPTQConfig(tokenizer=None, dataset=None)\n        self.weight_bits = config.weight_bits\n        self.per_channel = config.per_channel\n        self.symmetric = config.symmetric\n        self.group_size = config.group_size\n        self.compute_dtype = compute_dtype\n\n        # These are now determined later by `find_params`\n        self.scale = None\n        self.zero = None\n        self.maxq = None\n\n    def find_params(self, input_tensor):\n        \"\"\"Finds quantization parameters (scale and zero) for a given tensor.\"\"\"\n        self.scale, self.zero, self.maxq = compute_quantization_parameters(\n            input_tensor,\n            bits=self.weight_bits,\n            symmetric=self.symmetric,\n            per_channel=self.per_channel,\n            group_size=self.group_size,\n            compute_dtype=self.compute_dtype,\n        )\n        return self.scale, self.zero, self.maxq\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"weight_bits\": self.weight_bits,\n                \"per_channel\": self.per_channel,\n                \"symmetric\": self.symmetric,\n                \"group_size\": self.group_size,\n            }\n        )\n        return config\n\n    @classmethod\n    def from_config(cls, config):\n        gptq = GPTQConfig(\n            tokenizer=None,\n            dataset=None,\n            weight_bits=config[\"weight_bits\"],\n            per_channel=config[\"per_channel\"],\n            symmetric=config[\"symmetric\"],\n            group_size=config[\"group_size\"],\n        )\n        return cls(gptq)\n\n\ndef compute_quantization_parameters(\n    x,\n    *,\n    bits,\n    symmetric=False,\n    per_channel=False,\n    
group_size=-1,\n    compute_dtype=\"float32\",\n    epsilon=0.0,\n    signed=False,\n):\n    \"\"\"\n    Computes the scale and zero-point for quantizing weight tensors.\n\n    This function calculates the scale and zero-point required for quantizing\n    a given weight tensor `x` based on the specified parameters. It supports\n    grouped, per-channel, per-tensor, symmetric, and asymmetric quantization.\n\n    For grouped quantization (per_channel=True, group_size > 0), the output\n    shapes are [out_features, n_groups] where n_groups is the number of groups\n    along the in_features dimension.\n\n    Args:\n        x: KerasTensor. The weight tensor to quantize with shape\n            [out_features, in_features].\n        bits: int. The number of bits to quantize to (e.g., 4).\n        symmetric: bool. Whether to use symmetric quantization.\n        per_channel: bool. Whether to quantize per channel.\n        group_size: int. The group size for quantization. -1 means no grouping.\n        compute_dtype: str. The dtype for computation. Defaults to \"float32\".\n        epsilon: float. Small value added to (max - min) before computing\n            scale to avoid division by zero. Defaults to 0.0.\n        signed: bool. Whether to use signed quantization range. If True, uses\n            range [-2^(bits-1), 2^(bits-1)-1] (e.g., [-8, 7] for 4-bit).\n            If False, uses range [0, 2^bits-1] (e.g., [0, 15] for 4-bit).\n            Defaults to False.\n\n    Returns:\n        scale: KerasTensor. The scale tensor for quantization.\n        zero: KerasTensor. The zero tensor for quantization (int8 if signed,\n            uint8 if unsigned).\n        maxq: scalar. 
The maximum quantization value.\n    \"\"\"\n    # Input validation\n    if x is None:\n        raise ValueError(\"Input tensor 'x' cannot be None.\")\n    if len(x.shape) < 2:\n        raise ValueError(\n            \"Input weight tensor must have a rank of at \"\n            f\"least 2, but got rank {len(x.shape)}.\"\n        )\n    if ops.size(x) == 0:\n        raise ValueError(\"Input tensor 'x' cannot be empty.\")\n\n    out_features, in_features = x.shape[0], x.shape[1]\n\n    # Determine number of groups for quantization\n    if per_channel and group_size > 0:\n        n_groups = (in_features + group_size - 1) // group_size\n    else:\n        n_groups = 1\n\n    # Compute min/max values based on quantization mode\n    if n_groups > 1:\n        # Grouped quantization: output shape [out_features, n_groups]\n        remainder = in_features % group_size\n        if remainder != 0:\n            pad_size = group_size - remainder\n            x = ops.pad(x, [[0, 0], [0, pad_size]], constant_values=0.0)\n\n        x_grouped = ops.reshape(x, [out_features, n_groups, group_size])\n        min_values = ops.min(x_grouped, axis=2)\n        max_values = ops.max(x_grouped, axis=2)\n    else:\n        # Per-channel or per-tensor: compute stats along rows\n        reduction_shape = [out_features, -1] if per_channel else [1, -1]\n        x_reshaped = ops.reshape(x, reduction_shape)\n        min_values = ops.min(x_reshaped, axis=1)\n        max_values = ops.max(x_reshaped, axis=1)\n\n    # Symmetric quantization: make range symmetric around zero\n    if symmetric:\n        max_abs = ops.maximum(ops.abs(min_values), max_values)\n        min_values = ops.where(\n            ops.less(min_values, 0), ops.negative(max_abs), min_values\n        )\n        max_values = max_abs\n\n    # Ensure non-zero range to avoid division errors\n    zero_range = ops.equal(min_values, max_values)\n    min_values = ops.where(zero_range, ops.subtract(min_values, 1), min_values)\n    
max_values = ops.where(zero_range, ops.add(max_values, 1), max_values)\n\n    # Compute scale and zero-point\n    maxq = ops.cast(ops.subtract(ops.power(2, bits), 1), compute_dtype)\n    range_values = ops.subtract(max_values, min_values)\n    if epsilon > 0:\n        range_values = ops.add(range_values, epsilon)\n    scale = ops.divide(range_values, maxq)\n    scale = ops.where(ops.less_equal(scale, 0), 1e-8, scale)\n\n    # Compute zero-point based on signed/unsigned mode\n    if signed:\n        # For signed range [-2^(bits-1), 2^(bits-1)-1], e.g., [-8, 7] for 4-bit\n        qmin = -(2 ** (bits - 1))  # e.g., -8 for 4-bit\n        qmax_signed = 2 ** (bits - 1) - 1  # e.g., 7 for 4-bit\n        if symmetric:\n            zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2) + qmin)\n        else:\n            # zero_signed = round(-min / scale) + qmin\n            zero = ops.add(\n                ops.round(ops.divide(ops.negative(min_values), scale)), qmin\n            )\n        zero = ops.clip(zero, qmin, qmax_signed)\n    else:\n        # For unsigned range [0, 2^bits-1], e.g., [0, 15] for 4-bit\n        if symmetric:\n            zero = ops.full_like(scale, ops.divide(ops.add(maxq, 1), 2))\n        else:\n            zero = ops.round(ops.divide(ops.negative(min_values), scale))\n\n    # Reshape output to [out_features, n_groups] or [out_features, 1]\n    if n_groups > 1:\n        pass  # Already [out_features, n_groups]\n    elif per_channel:\n        scale = ops.reshape(scale, [-1, 1])\n        zero = ops.reshape(zero, [-1, 1])\n    else:\n        # Per-tensor: tile single value to [out_features, 1]\n        scale = ops.tile(ops.reshape(scale, (1, 1)), (out_features, 1))\n        zero = ops.tile(ops.reshape(zero, (1, 1)), (out_features, 1))\n\n    zero_dtype = \"int8\" if signed else \"uint8\"\n    return scale, ops.cast(zero, zero_dtype), maxq\n\n\ndef quantize_with_zero_point(input_tensor, scale, zero, maxq):\n    \"\"\"Quantize a float tensor into 
discrete levels [0, maxq] using\n    per-tensor/per-channel/grouped scaling.\n\n    The result keeps the floating-point dtype of the inputs/scales; its\n    values lie in `[0, maxq]`.\n\n    Args:\n        input_tensor: KerasTensor. The input tensor to quantize.\n        scale: KerasTensor. The scale tensor for quantization.\n        zero: KerasTensor. The zero tensor for quantization.\n        maxq: Scalar. The maximum quantization level (e.g., `2**bits - 1`).\n\n    Returns:\n        KerasTensor. The quantized tensor.\n    \"\"\"\n    # Guard against divide-by-zero\n    epsilon = ops.cast(1e-8, dtype=scale.dtype)\n    safe_scale = ops.where(ops.equal(scale, 0), epsilon, scale)\n\n    quantized_tensor = ops.round(\n        ops.add(\n            ops.divide(input_tensor, safe_scale), ops.cast(zero, scale.dtype)\n        )\n    )\n    quantized_tensor = ops.clip(quantized_tensor, 0, maxq)\n    return quantized_tensor\n\n\ndef dequantize_with_zero_point(input_tensor, scale, zero):\n    \"\"\"\n    Dequantizes a quantized tensor using the provided scale and zero tensors.\n\n    Args:\n        input_tensor: KerasTensor. The quantized tensor to dequantize.\n        scale: KerasTensor. The scale tensor for dequantization.\n        zero: KerasTensor. The zero tensor for dequantization.\n\n    Returns:\n        KerasTensor. The dequantized tensor.\n    \"\"\"\n    return ops.multiply(\n        scale, ops.subtract(input_tensor, ops.cast(zero, scale.dtype))\n    )\n\n\ndef quantize_with_sz_map(\n    weights_matrix, scale, zero, g_idx, maxq, group_axis=-1\n):\n    \"\"\"Quantize the weight matrix from group params.\n\n    This function uses the provided scale and zero tensors to quantize the\n    input weights_matrix according to the group indices. 
It maps each position\n    along group_axis of the weights_matrix to its corresponding group\n    parameters and performs the quantization operation.\n\n    Args:\n        weights_matrix: Tensor to quantize.\n        scale: Per-group scale tensor with n_groups along group_axis.\n        zero: Per-group zero-point tensor with n_groups along group_axis.\n        g_idx: 1D integer tensor of length equal to the size of\n            `weights_matrix` along the dimension being quantized. Each\n            element specifies which group index (0 to n_groups-1) that\n            position belongs to. For example, with 128 columns and\n            group_size=32, g_idx would be\n            `[0,0,...,0, 1,1,...,1, 2,2,...,2, 3,3,...,3]` (32 of each).\n        maxq: Scalar (float) representing the maximum integer quantization\n            level (e.g., 2^bits - 1).\n        group_axis: The axis in `scale` and `zero` along which to index\n            using `g_idx`. This determines which dimension of the\n            scale/zero tensors contains the per-group values. Default: -1\n            (last axis).\n\n    Returns:\n        A tensor with the same shape as `weights_matrix` containing the\n        quantized weights produced using the provided group parameters.\n    \"\"\"\n    groups = ops.cast(g_idx, \"int32\")\n    scale_cols = ops.take(scale, groups, axis=group_axis)\n    zero_cols = ops.take(zero, groups, axis=group_axis)\n\n    # Quantize elementwise with the mapped per-group parameters. The result\n    # stays in floating point; no integer cast happens here.\n    return quantize_with_zero_point(weights_matrix, scale_cols, zero_cols, maxq)\n\n\ndef dequantize_with_sz_map(weights_matrix, scale, zero, g_idx, group_axis=-1):\n    \"\"\"Rebuild a dequantized weight matrix from group params.\n\n    This function uses the provided scale and zero tensors to dequantize the\n    input weights_matrix according to the group indices. 
It maps each position\n    along group_axis of the weights_matrix to its corresponding group\n    parameters and performs the dequantization operation.\n\n    Args:\n        weights_matrix: Tensor to dequantize.\n        scale: Per-group scale tensor with n_groups along group_axis.\n        zero: Per-group zero-point tensor with n_groups along group_axis.\n        g_idx: 1D integer tensor of length equal to the size of\n            `weights_matrix` along the dimension being dequantized. Each\n            element specifies which group index (0 to n_groups-1) that\n            position belongs to. For example, with 128 columns and\n            group_size=32, g_idx would be\n            `[0,0,...,0, 1,1,...,1, 2,2,...,2, 3,3,...,3]` (32 of each).\n        group_axis: The axis in `scale` and `zero` along which to index\n            using `g_idx`. This determines which dimension of the\n            scale/zero tensors contains the per-group values. Default: -1\n            (last axis).\n\n    Returns:\n        A tensor with the same shape as `weights_matrix` containing the\n        dequantized weights produced using the provided group parameters.\n    \"\"\"\n    # Map group indices to scales and zeros\n    groups = ops.cast(g_idx, \"int32\")\n    scales_mapped = ops.take(scale, groups, axis=group_axis)\n    zeros_mapped = ops.take(zero, groups, axis=group_axis)\n    zeros_mapped = ops.cast(zeros_mapped, scales_mapped.dtype)\n\n    dequantized = ops.multiply(\n        ops.subtract(weights_matrix, zeros_mapped), scales_mapped\n    )\n\n    return dequantized\n"
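The affine scheme that `compute_quantization_parameters`, `quantize_with_zero_point`, and `dequantize_with_zero_point` implement can be sketched without the Keras ops API. A minimal NumPy version of the unsigned, per-tensor, asymmetric path (helper names here are illustrative, not part of the Keras API):

```python
import numpy as np

def affine_params(x, bits=4):
    # Per-tensor asymmetric parameters: maxq levels over the [min, max] range.
    maxq = 2**bits - 1
    min_v, max_v = float(x.min()), float(x.max())
    if min_v == max_v:  # widen a degenerate range, as the Keras code does
        min_v, max_v = min_v - 1.0, max_v + 1.0
    scale = (max_v - min_v) / maxq
    zero = np.round(-min_v / scale)
    return scale, zero, maxq

def quantize(x, scale, zero, maxq):
    # Round to the nearest level and clamp into [0, maxq].
    return np.clip(np.round(x / scale + zero), 0, maxq)

def dequantize(q, scale, zero):
    return scale * (q - zero)

rng = np.random.default_rng(0)
w = rng.uniform(-1.0, 1.0, size=(4, 8))
scale, zero, maxq = affine_params(w, bits=4)
q = quantize(w, scale, zero, maxq)
w_hat = dequantize(q, scale, zero)
# Round-trip error is bounded by half a quantization step.
assert np.max(np.abs(w - w_hat)) <= scale / 2 + 1e-9
```

The per-channel and grouped modes apply exactly this arithmetic, but with `scale`/`zero` computed per row or per group and broadcast against the weight matrix.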
  },
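As a companion to `pack_int4`/`unpack_int4` in the file above, the nibble layout and the branchless `(x ^ 8) - 8` sign restoration can be reproduced in a few lines of plain NumPy (axis 0 only; the helper names are illustrative, not part of the Keras API):

```python
import numpy as np

def pack_int4_np(arr):
    # Pad to an even number of rows, as pack_int4 does.
    n = arr.shape[0]
    if n % 2:
        arr = np.concatenate([arr, np.zeros((1,) + arr.shape[1:], arr.dtype)])
    # Low nibble <- even rows, high nibble <- odd rows.
    low = arr[::2].astype(np.uint8) & 0x0F
    high = arr[1::2].astype(np.uint8) & 0x0F
    return ((high << 4) | low).astype(np.int8), n

def to_signed(x):
    # Branchless nibble -> signed int4: maps 0..7 to itself, 8..15 to -8..-1.
    return ((x.astype(np.int8) ^ 8) - 8).astype(np.int8)

def unpack_int4_np(packed, orig_len):
    u = packed.astype(np.uint8)
    low, high = u & 0x0F, (u >> 4) & 0x0F
    # Interleave the two nibble streams back into rows, then drop padding.
    stacked = np.stack([to_signed(low), to_signed(high)], axis=1)
    return stacked.reshape((-1,) + packed.shape[1:])[:orig_len]

original = np.array([[-3, 7], [2, -8], [1, 0]], dtype=np.int8)
packed, orig_len = pack_int4_np(original)
assert packed.tolist() == [[45, -121], [1, 0]]  # matches the docstring example
assert np.array_equal(unpack_int4_np(packed, orig_len), original)
```

The Keras implementation generalizes this by moving an arbitrary pack axis to the front (`np.moveaxis`) and supporting a `uint8` variant that skips the sign restoration.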
  {
    "path": "keras/src/quantizers/quantizers_test.py",
    "content": "import math\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import quantizers\nfrom keras.src import random\nfrom keras.src import testing\nfrom keras.src.quantizers.quantizers import compute_quantization_parameters\nfrom keras.src.quantizers.quantizers import dequantize_with_sz_map\nfrom keras.src.quantizers.quantizers import dequantize_with_zero_point\nfrom keras.src.quantizers.quantizers import quantize_with_sz_map\nfrom keras.src.quantizers.quantizers import quantize_with_zero_point\nfrom keras.src.testing.test_utils import named_product\n\n\nclass QuantizersTest(testing.TestCase):\n    def test_get_method(self):\n        quantizer = quantizers.get(\"abs_max_quantizer\")\n        self.assertTrue(quantizer, quantizers.AbsMaxQuantizer)\n\n        quantizer = quantizers.get(None)\n        self.assertEqual(quantizer, None)\n\n        with self.assertRaises(ValueError):\n            quantizers.get(\"typo\")\n\n    def test_abs_max_quantizer(self):\n        values = random.uniform([3, 4, 5], minval=-1, maxval=1, dtype=\"float32\")\n        quantizer = quantizers.AbsMaxQuantizer()\n\n        # Test quantizing\n        quantized_values, scale = quantizer(values, axis=-1)\n        self.assertDType(quantized_values, \"int8\")\n        self.assertDType(scale, \"float32\")\n        self.assertEqual(tuple(quantized_values.shape), (3, 4, 5))\n        self.assertEqual(tuple(scale.shape), (3, 4, 1))\n        self.assertLessEqual(ops.max(quantized_values), 127)\n        self.assertGreaterEqual(ops.min(quantized_values), -127)\n\n        # Test dequantizing\n        dequantized_values = ops.divide(quantized_values, scale)\n        rmse = ops.sqrt(\n            ops.mean(ops.square(ops.subtract(values, dequantized_values)))\n        )\n        self.assertLess(rmse, 1e-1)  # loose assertion\n\n        # Test serialization\n        
self.run_class_serialization_test(quantizer)\n\n        # Test bfloat16 & float16 dtype\n        values = random.uniform(\n            [3, 4, 5], minval=-1, maxval=1, dtype=\"bfloat16\"\n        )\n        quantized_values, scale = quantizer(values, axis=-1)\n        self.assertDType(quantized_values, \"int8\")\n        self.assertDType(scale, \"bfloat16\")\n        values = random.uniform([3, 4, 5], minval=-1, maxval=1, dtype=\"float16\")\n        quantized_values, scale = quantizer(values, axis=-1)\n        self.assertDType(quantized_values, \"int8\")\n        self.assertDType(scale, \"float16\")\n\n    def test_abs_max_quantizer_to_numpy(self):\n        values = random.uniform([3, 4, 5], minval=-1, maxval=1, dtype=\"float32\")\n        quantized_values, scale = quantizers.abs_max_quantize(\n            values, axis=-1, to_numpy=True\n        )\n        ref_quantized_values, ref_scale = quantizers.abs_max_quantize(\n            values, axis=-1\n        )\n        self.assertAllClose(quantized_values, ref_quantized_values)\n        self.assertAllClose(scale, ref_scale)\n\n    def test_compute_float8_scale(self):\n        amax = 3.0\n        scale = 4.0\n        dtype_max = 448.0  # float8_e4m3fn\n        # The algorithm for computing the new scale is sourced from\n        # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#transformer_engine.jax.update_fp8_metas\n        expected_scale = 1.0 / (dtype_max / amax) / (2**0)\n\n        computed_scale = quantizers.compute_float8_scale(amax, scale, dtype_max)\n        self.assertAllClose(computed_scale, expected_scale)\n\n    def test_compute_float8_amax_history(self):\n        values = random.uniform([3, 4, 5], minval=-1, maxval=1)\n        amax_history = random.uniform([123])\n        amax_from_values = ops.max(ops.abs(values))\n\n        computed_amax_history = quantizers.compute_float8_amax_history(\n            values, amax_history\n        )\n        
self.assertAllClose(computed_amax_history[0], amax_from_values)\n        # Shift to left with 1 step\n        self.assertAllClose(\n            computed_amax_history[1:], ops.roll(amax_history, -1)[1:]\n        )\n\n    def test_quantize_and_dequantize(self):\n        scale = 1.0 / 100.0\n        values = random.uniform([3, 4, 5], minval=-1, maxval=1)\n        qdq_values = quantizers.quantize_and_dequantize(\n            values, scale, \"float8_e4m3fn\", \"float32\"\n        )\n        # A loose assertion due to an expected quantization error\n        self.assertAllClose(qdq_values, values, atol=1e-1)\n\n        qdq_values = quantizers.quantize_and_dequantize(\n            values, scale, \"float8_e5m2\", \"float32\"\n        )\n        # A loose assertion due to an expected quantization error\n        self.assertAllClose(qdq_values, values, atol=5e-1)\n\n    SHAPE_AXIS_SCENARIOS = [\n        # 1. 2D Tensors\n        # Covers the unpack fast path (rank=2, axis=0) for both parities\n        {\"testcase_name\": \"2d_axis0_odd\", \"shape\": (5, 8), \"axis\": 0},\n        {\"testcase_name\": \"2d_axis0_even\", \"shape\": (4, 8), \"axis\": 0},\n        # Covers the general path and a negative axis for 2D tensors\n        {\"testcase_name\": \"2d_axis1_odd\", \"shape\": (8, 7), \"axis\": 1},\n        {\"testcase_name\": \"2d_axis_neg1_even\", \"shape\": (8, 6), \"axis\": -1},\n        # 2. 
Higher-Rank Tensors\n        # Covers a middle axis for a complex shape with both parities\n        {\"testcase_name\": \"4d_axis1_odd\", \"shape\": (2, 5, 4, 6), \"axis\": 1},\n        {\"testcase_name\": \"4d_axis2_even\", \"shape\": (2, 4, 8, 6), \"axis\": 2},\n        # Covers the last axis of a complex shape with a negative index\n        {\n            \"testcase_name\": \"4d_axis_neg1_odd\",\n            \"shape\": (2, 4, 6, 7),\n            \"axis\": -1,\n        },\n    ]\n\n    DTYPE_PARAMS = [\n        {\"testcase_name\": \"int8\", \"dtype\": \"int8\", \"minval\": -8, \"maxval\": 8},\n        {\"testcase_name\": \"uint8\", \"dtype\": \"uint8\", \"minval\": 0, \"maxval\": 16},\n    ]\n\n    @parameterized.named_parameters(\n        named_product(SHAPE_AXIS_SCENARIOS, DTYPE_PARAMS)\n    )\n    def test_pack_unpack_int4(self, shape, axis, dtype, minval, maxval):\n        # Create a random tensor with int4 values in the specified range and\n        # dtype\n        arr = ops.cast(\n            ops.floor(random.uniform(shape, minval=minval, maxval=maxval)),\n            dtype,\n        )\n\n        # Pack the tensor using the specified dtype\n        packed, packed_shape, orig_len = quantizers.pack_int4(\n            arr, axis=axis, dtype=dtype\n        )\n\n        # Unpack the tensor using the specified dtype\n        unpacked = quantizers.unpack_int4(\n            packed, orig_len, axis=axis, dtype=dtype\n        )\n\n        # Verify that the packed tensor has the correct dtype\n        self.assertDType(packed, dtype)\n\n        # Verify that the unpacked tensor has the correct dtype\n        self.assertDType(unpacked, dtype)\n\n        # The unpacked tensor should be the same as the original tensor\n        self.assertAllClose(unpacked, arr)\n\n        # Test the packed shape\n        expected_packed_shape = list(shape)\n        expected_packed_shape[axis] = (expected_packed_shape[axis] + 1) // 2\n        self.assertEqual(\n            
list(ops.convert_to_numpy(packed_shape)), expected_packed_shape\n        )\n\n    @parameterized.named_parameters(\n        (\"per_tensor\", None),\n        (\"per_channel\", -1),\n    )\n    def test_fake_quant_with_min_max_vars_symbolic(self, axis):\n        x = backend.KerasTensor((2, 3, 4))\n        y = quantizers.fake_quant_with_min_max_vars(x, -3.0, 3.0, axis=axis)\n\n        self.assertIsInstance(y, backend.KerasTensor)\n        self.assertEqual(y.shape, (2, 3, 4))\n\n    @parameterized.named_parameters(\n        [\n            {\n                \"testcase_name\": \"wide_8bits_input_mins_0.0_input_maxs_255.0\",\n                \"narrow_range\": False,\n                \"input_mins\": [0.0],\n                \"input_maxs\": [255.0],\n                \"num_bits\": 8,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [255.0],\n                \"expected_steps\": [1.0],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": (\n                    \"wide_8bits_scalar_input_mins_0.0_input_maxs_255.0\"\n                ),\n                \"narrow_range\": False,\n                \"input_mins\": 0.0,\n                \"input_maxs\": 255.0,\n                \"num_bits\": 8,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [255.0],\n                \"expected_steps\": [1.0],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": \"wide_8bits_input_mins_0.5_input_maxs_128.0\",\n                \"narrow_range\": False,\n                \"input_mins\": [0.5],\n                \"input_maxs\": [128.0],\n                \"num_bits\": 8,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [127.5],\n                \"expected_steps\": [0.5],\n                \"axis\": None,\n            },\n            {\n                
\"testcase_name\": \"wide_8bits_input_mins_-128.0_input_maxs_-0.5\",\n                \"narrow_range\": False,\n                \"input_mins\": [-128.0],\n                \"input_maxs\": [-0.5],\n                \"num_bits\": 8,\n                \"expected_nudged_input_mins\": [-127.5],\n                \"expected_nudged_input_maxs\": [0.0],\n                \"expected_steps\": [0.5],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": \"wide_8bits_input_mins_-0.1_input_maxs_127.4\",\n                \"narrow_range\": False,\n                \"input_mins\": [-0.1],\n                \"input_maxs\": [127.4],\n                \"num_bits\": 8,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [127.5],\n                \"expected_steps\": [0.5],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": \"narrow_8bits_input_mins_0.0_input_maxs_254.0\",\n                \"narrow_range\": True,\n                \"input_mins\": [0.0],\n                \"input_maxs\": [254.0],\n                \"num_bits\": 8,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [254.0],\n                \"expected_steps\": [1.0],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": \"narrow_8bits_input_mins_0.1_input_maxs_127.1\",\n                \"narrow_range\": True,\n                \"input_mins\": [0.1],\n                \"input_maxs\": [127.1],\n                \"num_bits\": 8,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [127.0],\n                \"expected_steps\": [0.5],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": (\n                    \"narrow_8bits_input_mins_-127.1_input_maxs_-0.1\"\n                ),\n                
\"narrow_range\": True,\n                \"input_mins\": [-127.1],\n                \"input_maxs\": [-0.1],\n                \"num_bits\": 8,\n                \"expected_nudged_input_mins\": [-127.0],\n                \"expected_nudged_input_maxs\": [0.0],\n                \"expected_steps\": [0.5],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": (\n                    \"narrow_8bits_input_mins_-0.1_input_maxs_126.9\"\n                ),\n                \"narrow_range\": True,\n                \"input_mins\": [-0.1],\n                \"input_maxs\": [126.9],\n                \"num_bits\": 8,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [127.0],\n                \"expected_steps\": [0.5],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": \"wide_7bits_input_mins_0.0_input_maxs_127.0\",\n                \"narrow_range\": False,\n                \"input_mins\": [0.0],\n                \"input_maxs\": [127.0],\n                \"num_bits\": 7,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [127.0],\n                \"expected_steps\": [1.0],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": \"wide_7bits_input_mins_0.5_input_maxs_64.0\",\n                \"narrow_range\": False,\n                \"input_mins\": [0.5],\n                \"input_maxs\": [64.0],\n                \"num_bits\": 7,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [63.5],\n                \"expected_steps\": [0.5],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": \"wide_7bits_input_mins_-64.0_input_maxs_-0.5\",\n                \"narrow_range\": False,\n                \"input_mins\": [-64.0],\n                
\"input_maxs\": [-0.5],\n                \"num_bits\": 7,\n                \"expected_nudged_input_mins\": [-63.5],\n                \"expected_nudged_input_maxs\": [0.0],\n                \"expected_steps\": [0.5],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": \"wide_7bits_input_mins_-0.1_input_maxs_63.4\",\n                \"narrow_range\": False,\n                \"input_mins\": [-0.1],\n                \"input_maxs\": [63.4],\n                \"num_bits\": 7,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [63.5],\n                \"expected_steps\": [0.5],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": \"narrow_7bits_input_mins_0.0_input_maxs_126.0\",\n                \"narrow_range\": True,\n                \"input_mins\": [0.0],\n                \"input_maxs\": [126.0],\n                \"num_bits\": 7,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [126.0],\n                \"expected_steps\": [1.0],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": \"narrow_7bits_input_mins_0.1_input_maxs_63.1\",\n                \"narrow_range\": True,\n                \"input_mins\": [0.1],\n                \"input_maxs\": [63.1],\n                \"num_bits\": 7,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [63.0],\n                \"expected_steps\": [0.5],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": (\n                    \"narrow_7bits_input_mins_-63.1_input_maxs_-0.1\"\n                ),\n                \"narrow_range\": True,\n                \"input_mins\": [-63.1],\n                \"input_maxs\": [-0.1],\n                \"num_bits\": 7,\n                
\"expected_nudged_input_mins\": [-63.0],\n                \"expected_nudged_input_maxs\": [0.0],\n                \"expected_steps\": [0.5],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": \"narrow_7bits_input_mins_-0.1_input_maxs_62.9\",\n                \"narrow_range\": True,\n                \"input_mins\": [-0.1],\n                \"input_maxs\": [62.9],\n                \"num_bits\": 7,\n                \"expected_nudged_input_mins\": [0.0],\n                \"expected_nudged_input_maxs\": [63.0],\n                \"expected_steps\": [0.5],\n                \"axis\": None,\n            },\n            {\n                \"testcase_name\": \"wide_8bits_multi_channel\",\n                \"narrow_range\": False,\n                \"input_mins\": [0.0, 0.5, -128.0, -0.1],\n                \"input_maxs\": [255.0, 128.0, -0.5, 127.4],\n                \"num_bits\": 8,\n                \"expected_nudged_input_mins\": [0.0, 0.0, -127.5, 0.0],\n                \"expected_nudged_input_maxs\": [255.0, 127.5, 0.0, 127.5],\n                \"expected_steps\": [1.0, 0.5, 0.5, 0.5],\n                \"axis\": 1,\n            },\n            {\n                \"testcase_name\": \"narrow_8bits_multi_channel\",\n                \"narrow_range\": True,\n                \"input_mins\": [0.0, 0.1, -127.1, -0.1],\n                \"input_maxs\": [254.0, 127.1, -0.1, 126.9],\n                \"num_bits\": 8,\n                \"expected_nudged_input_mins\": [0.0, 0.0, -127.0, 0.0],\n                \"expected_nudged_input_maxs\": [254.0, 127.0, 0.0, 127.0],\n                \"expected_steps\": [1.0, 0.5, 0.5, 0.5],\n                \"axis\": 1,\n            },\n            {\n                \"testcase_name\": \"wide_7bits_multi_channel\",\n                \"narrow_range\": False,\n                \"input_mins\": [0.0, 0.5, -64.0, -0.1],\n                \"input_maxs\": [127.0, 64.0, -0.5, 63.4],\n                \"num_bits\": 7,\n 
               \"expected_nudged_input_mins\": [0.0, 0.0, -63.5, 0.0],\n                \"expected_nudged_input_maxs\": [127.0, 63.5, 0.0, 63.5],\n                \"expected_steps\": [1.0, 0.5, 0.5, 0.5],\n                \"axis\": 1,\n            },\n            {\n                \"testcase_name\": \"narrow_7bits_multi_channel\",\n                \"narrow_range\": True,\n                \"input_mins\": [0.0, 0.1, -63.1, -0.1],\n                \"input_maxs\": [126.0, 63.1, -0.1, 62.9],\n                \"num_bits\": 7,\n                \"expected_nudged_input_mins\": [0.0, 0.0, -63.0, 0.0],\n                \"expected_nudged_input_maxs\": [126.0, 63.0, 0.0, 63.0],\n                \"expected_steps\": [1.0, 0.5, 0.5, 0.5],\n                \"axis\": 1,\n            },\n        ]\n    )\n    @pytest.mark.skipif(\n        backend.backend() not in (\"tensorflow\", \"jax\", \"torch\"),\n        reason=f\"{backend.backend()} doesn't support `custom_gradient`.\",\n    )\n    def test_fake_quant_with_min_max_vars(\n        self,\n        input_mins,\n        input_maxs,\n        num_bits,\n        narrow_range,\n        axis,\n        expected_nudged_input_mins,\n        expected_nudged_input_maxs,\n        expected_steps,\n    ):\n        num_channels = len(expected_nudged_input_mins)\n        inputs_list = []\n        expected_list = []\n        initial_gradients_list = []\n        expected_backprops_wrt_input_list = []\n        for i in range(num_channels):\n            expected_nudged_input_min = expected_nudged_input_mins[i]\n            expected_nudged_input_max = expected_nudged_input_maxs[i]\n            expected_step = expected_steps[i]\n\n            inputs_list.append(\n                [\n                    expected_nudged_input_min - expected_step,\n                    expected_nudged_input_min - 0.01,\n                    expected_nudged_input_min,\n                    expected_nudged_input_min + 0.01,\n                    expected_nudged_input_min + 
expected_step - 0.01,\n                    expected_nudged_input_min + expected_step,\n                    expected_nudged_input_min + expected_step + 0.01,\n                    expected_nudged_input_max - 0.01,\n                    expected_nudged_input_max,\n                    expected_nudged_input_max + 0.01,\n                    expected_nudged_input_max + expected_step,\n                ]\n            )\n            expected_list.append(\n                [\n                    expected_nudged_input_min,\n                    expected_nudged_input_min,\n                    expected_nudged_input_min,\n                    expected_nudged_input_min,\n                    expected_nudged_input_min + expected_step,\n                    expected_nudged_input_min + expected_step,\n                    expected_nudged_input_min + expected_step,\n                    expected_nudged_input_max,\n                    expected_nudged_input_max,\n                    expected_nudged_input_max,\n                    expected_nudged_input_max,\n                ]\n            )\n            initial_gradients_list.append(\n                list(range(1, len(inputs_list[-1]) + 1))\n            )\n            expected_backprops_wrt_input_list.append(\n                [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0]\n            )\n        inputs = ops.transpose(ops.array(inputs_list, dtype=\"float32\"))\n        expected = ops.transpose(ops.array(expected_list, dtype=\"float32\"))\n        expected_backprops_wrt_input = ops.transpose(\n            ops.array(expected_backprops_wrt_input_list, dtype=\"float32\")\n        )\n        input_min = ops.array(input_mins, dtype=\"float32\")\n        input_max = ops.array(input_maxs, dtype=\"float32\")\n        initial_gradients = ops.transpose(\n            ops.array(initial_gradients_list, dtype=\"float32\")\n        )\n\n        # Test gradients.\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n   
         @tf.function(jit_compile=True)\n            def test_op(\n                inputs, input_mins, input_maxs, num_bits, narrow_range, axis\n            ):\n                with tf.GradientTape() as tape:\n                    tape.watch(inputs)\n                    result = quantizers.fake_quant_with_min_max_vars(\n                        inputs,\n                        input_mins,\n                        input_maxs,\n                        num_bits,\n                        narrow_range,\n                        axis,\n                    )\n                return initial_gradients * tape.gradient(result, inputs)\n\n        if backend.backend() == \"torch\":\n            import torch\n\n            def test_op(\n                inputs, input_mins, input_maxs, num_bits, narrow_range, axis\n            ):\n                # Create tensor and enable gradient tracking\n                inputs = torch.tensor(\n                    inputs, dtype=torch.float32, requires_grad=True\n                )\n\n                # Apply the quantization operation\n                result = quantizers.fake_quant_with_min_max_vars(\n                    inputs, input_mins, input_maxs, num_bits, narrow_range, axis\n                )\n\n                # Compute gradients\n                result.backward(torch.ones_like(result))\n\n                return initial_gradients * inputs.grad\n\n        if backend.backend() == \"jax\":\n            import jax\n\n            def test_op(\n                inputs, input_mins, input_maxs, num_bits, narrow_range, axis\n            ):\n                # Define the function to compute gradients for\n                def quantize_fn(x):\n                    return ops.sum(\n                        quantizers.fake_quant_with_min_max_vars(\n                            x,\n                            input_mins,\n                            input_maxs,\n                            num_bits,\n                            narrow_range,\n                   
         axis,\n                        )\n                    )\n\n                input_gradients = jax.grad(quantize_fn)(inputs)\n                return ops.multiply(initial_gradients, input_gradients)\n\n        gradients = test_op(\n            inputs, input_min, input_max, num_bits, narrow_range, axis\n        )\n        if not testing.jax_uses_gpu():\n            # JAX GPU produces less precise numbers, causing the CI to fail.\n            # For example, 127.5 / 255.0 results in 0.49999997 instead of 0.5.\n            self.assertAllClose(gradients, expected_backprops_wrt_input)\n\n        # Test outputs.\n        outputs = quantizers.fake_quant_with_min_max_vars(\n            inputs,\n            input_min,\n            input_max,\n            num_bits=num_bits,\n            narrow_range=narrow_range,\n            axis=axis,\n        )\n        self.assertAllClose(outputs, expected)\n\n        # Test bfloat16 & float16 dtype\n        outputs = quantizers.fake_quant_with_min_max_vars(\n            ops.cast(inputs, \"bfloat16\"),\n            input_min,\n            input_max,\n            num_bits=num_bits,\n            narrow_range=narrow_range,\n            axis=axis,\n        )\n        self.assertDType(outputs, \"bfloat16\")\n        self.assertAllClose(outputs, expected)\n\n        outputs = quantizers.fake_quant_with_min_max_vars(\n            ops.cast(inputs, \"float16\"),\n            input_min,\n            input_max,\n            num_bits=num_bits,\n            narrow_range=narrow_range,\n            axis=axis,\n        )\n        self.assertDType(outputs, \"float16\")\n        self.assertAllClose(outputs, expected)\n\n    @parameterized.named_parameters(\n        (\"block_32\", 32),\n        (\"block_64\", 64),\n        (\"block_128\", 128),\n    )\n    def test_grouped_quantize_dequantize_roundtrip(self, block_size):\n        \"\"\"Test that grouped quantize/dequantize has low error.\"\"\"\n        input_dim, output_dim = 256, 128\n        kernel 
= random.uniform(\n            (input_dim, output_dim), minval=-1, maxval=1, dtype=\"float32\"\n        )\n\n        quantized, scale, zero = (\n            quantizers.abs_max_quantize_grouped_with_zero_point(\n                kernel,\n                block_size=block_size,\n                value_range=(-8, 7),\n                dtype=\"int8\",\n            )\n        )\n\n        # Use dequantize_with_sz_map with generated g_idx\n        g_idx = ops.arange(input_dim) // block_size\n        dequantized = ops.transpose(\n            quantizers.dequantize_with_sz_map(\n                ops.transpose(ops.cast(quantized, scale.dtype)),\n                ops.transpose(scale),\n                ops.transpose(zero),\n                g_idx,\n            )\n        )\n\n        rmse = ops.sqrt(ops.mean(ops.square(kernel - dequantized)))\n        # Grouped quantization should have reasonable error\n        self.assertLess(rmse, 0.15)\n\n    def test_grouped_quantize_with_padding(self):\n        \"\"\"Test grouped quantization when input_dim is not divisible.\"\"\"\n\n        # 500 is not divisible by 128, so padding will be needed\n        input_dim, output_dim, block_size = 500, 256, 128\n        kernel = random.uniform(\n            (input_dim, output_dim), minval=-1, maxval=1, dtype=\"float32\"\n        )\n\n        quantized, scale, zero = (\n            quantizers.abs_max_quantize_grouped_with_zero_point(\n                kernel,\n                block_size=block_size,\n                value_range=(-8, 7),\n                dtype=\"int8\",\n            )\n        )\n\n        n_groups = math.ceil(input_dim / block_size)  # 4 groups\n        self.assertEqual(quantized.shape, (input_dim, output_dim))\n        self.assertEqual(scale.shape, (n_groups, output_dim))\n        self.assertEqual(zero.shape, (n_groups, output_dim))\n\n    def test_grouped_vs_perchannel_accuracy(self):\n        \"\"\"Test that grouped quantization has lower error than per-channel.\"\"\"\n        
input_dim, output_dim, block_size = 512, 256, 128\n        # Use a specific seed for reproducibility\n        kernel = random.uniform(\n            (input_dim, output_dim),\n            minval=-1,\n            maxval=1,\n            dtype=\"float32\",\n            seed=42,\n        )\n\n        # Per-channel quantization (one scale per output channel)\n        quantizer = quantizers.AbsMaxQuantizer(\n            axis=0, value_range=(-8, 7), output_dtype=\"int8\"\n        )\n        pc_quantized, pc_scale = quantizer(kernel)\n        pc_dequantized = ops.cast(pc_quantized, \"float32\") / pc_scale\n        pc_rmse = ops.sqrt(ops.mean(ops.square(kernel - pc_dequantized)))\n\n        # Grouped (sub-channel) quantization with zero point\n        grouped_quantized, grouped_scale, grouped_zero = (\n            quantizers.abs_max_quantize_grouped_with_zero_point(\n                kernel, block_size=block_size, value_range=(-8, 7), dtype=\"int8\"\n            )\n        )\n\n        # Use dequantize_with_sz_map with generated g_idx\n        g_idx = ops.arange(input_dim) // block_size\n        grouped_dequantized = ops.transpose(\n            quantizers.dequantize_with_sz_map(\n                ops.transpose(ops.cast(grouped_quantized, grouped_scale.dtype)),\n                ops.transpose(grouped_scale),\n                ops.transpose(grouped_zero),\n                g_idx,\n            )\n        )\n\n        grouped_rmse = ops.sqrt(\n            ops.mean(ops.square(kernel - grouped_dequantized))\n        )\n\n        # Grouped should have lower or similar error\n        # (in most cases it should be lower due to finer granularity)\n        self.assertLessEqual(float(grouped_rmse), float(pc_rmse) + 0.01)\n\n    def test_grouped_quantize_various_block_sizes(self):\n        \"\"\"Test grouped quantization with various block sizes.\"\"\"\n\n        input_dim, output_dim = 512, 128\n        kernel = random.uniform(\n            (input_dim, output_dim), minval=-1, maxval=1, 
dtype=\"float32\"\n        )\n\n        for block_size in [32, 64, 128, 256]:\n            quantized, scale, zero = (\n                quantizers.abs_max_quantize_grouped_with_zero_point(\n                    kernel,\n                    block_size=block_size,\n                    value_range=(-8, 7),\n                    dtype=\"int8\",\n                )\n            )\n\n            n_groups = math.ceil(input_dim / block_size)\n            self.assertEqual(quantized.shape, (input_dim, output_dim))\n            self.assertEqual(scale.shape, (n_groups, output_dim))\n            self.assertEqual(zero.shape, (n_groups, output_dim))\n\n\nclass Int4QuantizationConfigTest(testing.TestCase):\n    def test_default_block_size(self):\n        \"\"\"Test that default block_size is 128.\"\"\"\n        from keras.src.quantizers import Int4QuantizationConfig\n\n        config = Int4QuantizationConfig()\n        self.assertEqual(config.block_size, 128)\n        self.assertEqual(config.mode, \"int4\")\n\n    @parameterized.named_parameters(\n        (\"block_32\", 32),\n        (\"block_64\", 64),\n        (\"block_128\", 128),\n        (\"block_256\", 256),\n    )\n    def test_custom_block_size(self, block_size):\n        \"\"\"Test setting custom block_size values.\"\"\"\n        from keras.src.quantizers import Int4QuantizationConfig\n\n        config = Int4QuantizationConfig(block_size=block_size)\n        self.assertEqual(config.block_size, block_size)\n\n    def test_per_channel_mode_with_none(self):\n        \"\"\"Test per-channel mode with block_size=None.\"\"\"\n        from keras.src.quantizers import Int4QuantizationConfig\n\n        config = Int4QuantizationConfig(block_size=None)\n        self.assertIsNone(config.block_size)\n\n    def test_per_channel_mode_with_negative_one(self):\n        \"\"\"Test per-channel mode with block_size=-1.\"\"\"\n        from keras.src.quantizers import Int4QuantizationConfig\n\n        config = 
Int4QuantizationConfig(block_size=-1)\n        self.assertEqual(config.block_size, -1)\n\n    def test_invalid_block_size_raises(self):\n        \"\"\"Test that invalid block_size values raise ValueError.\"\"\"\n        from keras.src.quantizers import Int4QuantizationConfig\n\n        with self.assertRaisesRegex(ValueError, \"block_size must be\"):\n            Int4QuantizationConfig(block_size=0)\n\n        with self.assertRaisesRegex(ValueError, \"block_size must be\"):\n            Int4QuantizationConfig(block_size=-2)\n\n    def test_get_config_includes_block_size(self):\n        \"\"\"Test that get_config includes block_size.\"\"\"\n        from keras.src.quantizers import Int4QuantizationConfig\n\n        config = Int4QuantizationConfig(block_size=64)\n        serialized = config.get_config()\n        self.assertEqual(serialized[\"block_size\"], 64)\n\n    def test_from_config_restores_block_size(self):\n        \"\"\"Test that from_config restores block_size.\"\"\"\n        from keras.src.quantizers import Int4QuantizationConfig\n\n        original = Int4QuantizationConfig(block_size=64)\n        serialized = original.get_config()\n        restored = Int4QuantizationConfig.from_config(serialized)\n        self.assertEqual(restored.block_size, 64)\n\n    def test_serialization_roundtrip(self):\n        \"\"\"Test full serialization roundtrip.\"\"\"\n        from keras.src.quantizers import Int4QuantizationConfig\n\n        config = Int4QuantizationConfig(block_size=128)\n        serialized = quantizers.serialize(config)\n        deserialized = quantizers.deserialize(serialized)\n        self.assertEqual(deserialized.block_size, 128)\n        self.assertEqual(deserialized.mode, \"int4\")\n\n    def test_serialization_with_per_channel(self):\n        \"\"\"Test serialization with per-channel mode.\"\"\"\n        from keras.src.quantizers import Int4QuantizationConfig\n\n        config = Int4QuantizationConfig(block_size=None)\n        serialized = 
quantizers.serialize(config)\n        deserialized = quantizers.deserialize(serialized)\n        self.assertIsNone(deserialized.block_size)\n\n\nclass GPTQQuantizerTest(testing.TestCase):\n    @parameterized.named_parameters(\n        (\"bits_2_sym_False\", 2, False),\n        (\"bits_4_sym_False\", 4, False),\n        (\"bits_8_sym_False\", 8, False),\n        (\"bits_2_sym_True\", 2, True),\n        (\"bits_4_sym_True\", 4, True),\n        (\"bits_8_sym_True\", 8, True),\n    )\n    def test_quantize_dequantize_roundtrip_error_bound_per_tensor(\n        self, bits, symmetric\n    ):\n        \"\"\"\n        For finite inputs and positive scales, the reconstruction error\n        |x_hat - clip(x)| is bounded by 0.5 * scale elementwise.\n        \"\"\"\n        rng = np.random.default_rng(0)\n        x = ops.array(rng.standard_normal((64, 32)), \"float32\")\n        scale = ops.array(0.05)  # per-tensor scale\n        maxq = ops.array(ops.subtract(ops.power(2, bits), 1), \"float32\")\n        zero = ops.array(maxq / 2.0 if symmetric else 3.0, \"float32\")\n\n        quantized = quantize_with_zero_point(x, scale, zero, maxq)\n        dequantized = dequantize_with_zero_point(quantized, scale, zero)\n\n        # Representable dequantization range:\n        # [scale*(0 - zero), scale*(maxq - zero)]\n        lo = ops.multiply(scale, ops.subtract(ops.array(0.0), zero))\n        hi = ops.multiply(scale, ops.subtract(maxq, zero))\n        x_clipped = ops.clip(x, lo, hi)\n\n        err = ops.abs(dequantized - x_clipped)\n        self.assertTrue(\n            ops.all(err <= (ops.add(ops.multiply(0.5, scale), 1e-7)))\n        )\n\n    def test_quantize_clipping_behavior_extremes(self):\n        \"\"\"\n        Very negative inputs quantize to 0; very positive inputs\n        quantize to maxq.\n        \"\"\"\n        maxq = ops.array(15.0)\n        scale = ops.array(0.1)\n        zero = ops.array(7.0)\n\n        x = ops.array([[-1e6, 1e6]], \"float32\")\n        quantized = quantize_with_zero_point(x, scale, 
zero, maxq)\n\n        self.assertEqual(quantized.shape, (1, 2))\n        self.assertEqual(quantized[0, 0], 0.0)\n        self.assertEqual(quantized[0, 1], maxq)\n\n    def test_zero_scale_guard_no_nans_for_finite_inputs(self):\n        \"\"\"\n        If scale == 0, quantize should not produce NaNs (uses epsilon\n        replacement).\n        \"\"\"\n        x = ops.array([[0.0, 1.0, -2.0]])\n        scale = ops.array(0.0)  # triggers epsilon path\n        zero = ops.array(5.0)\n        maxq = ops.array(15.0)\n\n        q = quantize_with_zero_point(x, scale, zero, maxq)\n        self.assertFalse(ops.any(ops.isnan(q)))\n\n        # Dequantize should also be finite\n        x_hat = dequantize_with_zero_point(q, scale, zero)\n        self.assertTrue(ops.all(ops.isfinite(x_hat)))\n\n    @parameterized.parameters(4, 8)\n    def test_idempotent_quantize_when_input_is_already_levels(self, bits):\n        \"\"\"\n        If input is already exactly on representable dequantized grid,\n        quantize→dequantize should return the same values (within float eps).\n        \"\"\"\n        scale = ops.array(0.125)\n        maxq = ops.array(ops.subtract(ops.power(2, bits), 1), \"float32\")\n        zero = ops.array(ops.divide(maxq, 2.0))\n\n        # Build dequantized grid points: x = scale * (k - zero), k in [0..maxq]\n        ks = ops.arange(0, ops.add(maxq, 1))\n        x_vals = ops.multiply(scale, ops.subtract(ks, zero))\n        x = ops.reshape(x_vals, (1, -1))\n\n        q = quantize_with_zero_point(x, scale, zero, maxq)\n        x_hat = dequantize_with_zero_point(q, scale, zero)\n        self.assertAllClose(x_hat, x, rtol=0, atol=1e-6)\n\n\nclass ComputeScaleZeroTest(testing.TestCase):\n    def test_error_when_x_is_none(self):\n        with self.assertRaisesRegex(ValueError, \"cannot be None\"):\n            compute_quantization_parameters(None, bits=4)\n\n    def test_error_when_x_is_empty(self):\n        x = ops.array([[], []], \"float32\")  # 2D empty tensor\n        
with self.assertRaisesRegex(ValueError, \"cannot be empty\"):\n            compute_quantization_parameters(x, bits=4)\n\n    def test_error_when_weight_rank_too_low(self):\n        x = ops.array([1.0, 2.0], \"float32\")  # rank-1\n        with self.assertRaisesRegex(ValueError, \"rank of at least 2\"):\n            compute_quantization_parameters(x, bits=4)\n\n    @parameterized.named_parameters(\n        (\"bits2_asym\", 2, False),\n        (\"bits4_asym\", 4, False),\n        (\"bits8_asym\", 8, False),\n        (\"bits2_sym\", 2, True),\n        (\"bits4_sym\", 4, True),\n        (\"bits8_sym\", 8, True),\n    )\n    def test_per_tensor_shapes_and_basic_invariants(self, bits, symmetric):\n        \"\"\"Test per-tensor shapes and basic invariants.\"\"\"\n        x = ops.array(\n            np.random.default_rng(0).standard_normal((7, 5), dtype=\"float32\")\n        )\n        scale, zero, maxq = compute_quantization_parameters(\n            x, bits=bits, symmetric=symmetric, per_channel=False\n        )\n\n        # Shapes (per-tensor with weight semantics): (out_features, 1)\n        self.assertEqual(scale.shape, (7, 1))\n        self.assertEqual(zero.shape, (7, 1))\n\n        # Scale must be strictly positive\n        self.assertTrue(ops.all(scale > 0.0))\n\n        # All elements in the scale and zero tensors must be equal due to\n        # tiling for per-tensor quantization\n        self.assertTrue(ops.all(scale == scale[0, 0]))\n        self.assertTrue(ops.all(zero == zero[0, 0]))\n\n    def test_per_tensor_symmetric_on_constant_input_uses_safe_range(self):\n        \"\"\"Ensures safe range adjustment if entries are equal\"\"\"\n        x = ops.array(np.full((3, 4), 0.0, dtype=np.float32))\n        scale, zero, maxq = compute_quantization_parameters(\n            x, bits=4, symmetric=True, per_channel=False\n        )\n        # With symmetric=True and constant input, zero = (maxq+1)/2\n        # Shape is now (3, 1) due to weight semantics\n        
expected_zero = ops.array((float(maxq) + 1.0) / 2.0)\n        self.assertAllClose(zero[0, 0], expected_zero)\n        self.assertTrue(ops.all(ops.greater(scale, 0.0)))\n\n    def test_weight_per_tensor_tiles_rows(self):\n        \"\"\"Tests that scales/zeros tensors are properly tiled when\n        per-channel quantization is not used.\"\"\"\n        x = ops.array(\n            np.random.default_rng(1).standard_normal((8, 16)), \"float32\"\n        )\n        scale, zero, _ = compute_quantization_parameters(\n            x, bits=4, symmetric=False, per_channel=False\n        )\n        # With per_channel=False, shapes are (rows, 1)\n        self.assertEqual(scale.shape, (8, 1))\n        self.assertEqual(zero.shape, (8, 1))\n\n        # All elements in the scale and zero tensors must be equal due to\n        # tiling.\n        self.assertTrue(ops.all(scale == scale[0, 0]))\n        self.assertTrue(ops.all(zero == zero[0, 0]))\n\n    def test_weight_per_channel_ungrouped_shapes(self):\n        \"\"\"Tests that scales/zeros tensors have the correct shape when\n        per-channel quantization is used without grouping.\"\"\"\n        x = ops.array(\n            np.random.default_rng(2).standard_normal((6, 10)), \"float32\"\n        )\n        scale, zero, _ = compute_quantization_parameters(\n            x,\n            bits=4,\n            symmetric=False,\n            per_channel=True,\n            group_size=-1,\n        )\n        # Per-channel (ungrouped): one scale per output row -> (rows, 1)\n        self.assertEqual(scale.shape, (6, 1))\n        self.assertEqual(zero.shape, (6, 1))\n        self.assertTrue(ops.all(ops.greater(scale, 0.0)))\n\n        # Each channel should have roughly unique scales and zeros\n        self.assertFalse(ops.all(scale == scale[0, 0]))\n        self.assertFalse(ops.all(zero == zero[0, 0]))\n\n    def test_weight_per_channel_grouped_shapes_and_count(self):\n        \"\"\"Tests that scales/zeros have the correct shape and count when\n 
       per-channel quantization is used with grouping.\"\"\"\n        out_features, in_features, group_size = 8, 16, 4\n        x = ops.array(\n            np.random.default_rng(3).standard_normal(\n                (out_features, in_features)\n            ),\n            \"float32\",\n        )\n        scale, zero, _ = compute_quantization_parameters(\n            x,\n            bits=4,\n            symmetric=False,\n            per_channel=True,\n            group_size=group_size,\n        )\n        # Grouped path produces [out_features, n_groups] shape\n        n_groups = in_features // group_size\n        self.assertEqual(scale.shape, (out_features, n_groups))\n        self.assertEqual(zero.shape, (out_features, n_groups))\n        self.assertTrue(ops.all(ops.greater(scale, 0.0)))\n\n    @parameterized.named_parameters(\n        (\"sym_true\", True),\n        (\"sym_false\", False),\n    )\n    def test_dtype_and_finiteness(self, symmetric):\n        x = ops.array(\n            np.random.default_rng(4).standard_normal((5, 7)).astype(\"float32\")\n        )\n        scale, zero, maxq = compute_quantization_parameters(\n            x,\n            bits=8,\n            symmetric=symmetric,\n            per_channel=True,\n            group_size=-1,\n        )\n        # All outputs should be finite\n        self.assertTrue(ops.all(ops.isfinite(scale)))\n        self.assertTrue(ops.all(ops.isfinite(zero)))\n        self.assertTrue(ops.all(ops.isfinite(maxq)))\n\n    def test_dequantize_with_sz_map_logic(self):\n        \"\"\"Validates the vectorized dequantization logic against a\n        manual implementation.\"\"\"\n        out_features, in_features, group_size = 4, 16, 4\n        n_groups = in_features // group_size\n\n        # Create dummy quantized weights\n        q_weights = ops.cast(\n            ops.array(\n                np.random.randint(0, 15, size=(out_features, in_features))\n            ),\n            \"uint8\",\n        )\n\n        # Create 
dummy scales and zeros\n        scale = ops.abs(\n            ops.array(\n                np.random.random((out_features, n_groups)).astype(\"float32\")\n            )\n        )\n        zero = ops.cast(\n            ops.array(np.random.randint(0, 15, size=(out_features, n_groups))),\n            \"uint8\",\n        )\n\n        # Create group index mapping\n        g_idx = ops.array(np.arange(in_features) // group_size, dtype=\"int32\")\n\n        # Get the result from the function under test\n        dequantized_result = dequantize_with_sz_map(\n            q_weights, scale, zero, g_idx\n        )\n\n        # Manually compute the expected result\n        expected_dequantized = np.zeros(\n            (out_features, in_features), dtype=\"float32\"\n        )\n\n        for i in range(out_features):\n            for j in range(in_features):\n                group = g_idx[j]\n                s = scale[i, group]\n                z = zero[i, group]\n                # Dequantization formula: (q_val - z) * s\n                expected_dequantized[i, j] = ops.multiply(\n                    ops.subtract(q_weights[i, j], ops.cast(z, \"float32\")), s\n                )\n\n        self.assertAllClose(dequantized_result, expected_dequantized)\n\n    def test_quantize_with_sz_map_logic(self):\n        \"\"\"Validates the vectorized quantization logic against a\n        manual implementation.\"\"\"\n        out_features, in_features, group_size = 4, 16, 4\n        n_groups = in_features // group_size\n\n        # Create dummy float weights\n        weights = ops.array(\n            np.random.default_rng(5).standard_normal(\n                (out_features, in_features)\n            ),\n            \"float32\",\n        )\n\n        # Create dummy scales and zeros\n        scale = ops.abs(\n            ops.array(\n                np.random.random((out_features, n_groups)).astype(\"float32\")\n            )\n        )\n        zero = ops.cast(\n            
ops.array(np.random.randint(0, 15, size=(out_features, n_groups))),\n            \"uint8\",\n        )\n\n        maxq = ops.array(15.0)\n\n        # Create group index mapping\n        g_idx = ops.array(np.arange(in_features) // group_size, dtype=\"int32\")\n\n        # Get the result from the function under test\n        quantized_result = quantize_with_sz_map(\n            weights, scale, zero, g_idx, maxq\n        )\n\n        # Manually compute the expected result\n        expected_quantized = np.zeros(\n            (out_features, in_features), dtype=\"uint8\"\n        )\n\n        for i in range(out_features):\n            for j in range(in_features):\n                group = g_idx[j]\n                s = scale[i, group]\n                z = zero[i, group]\n                # Quantization formula: clip(round(x/s + z), 0, maxq)\n                q_val = ops.round(ops.add(ops.divide(weights[i, j], s), z))\n                q_val_clipped = ops.clip(q_val, 0.0, maxq)\n                expected_quantized[i, j] = ops.cast(q_val_clipped, \"uint8\")\n\n        self.assertAllClose(quantized_result, expected_quantized)\n\n\nclass GroupedQuantizationParametersTest(testing.TestCase):\n    \"\"\"Test grouped weight quantization in compute_quantization_parameters.\"\"\"\n\n    def test_grouped_weight_shapes_divisible(self):\n        \"\"\"Test grouped quantization with divisible dimensions.\"\"\"\n        out_features, in_features, group_size = 64, 128, 32\n        n_groups = in_features // group_size  # 4\n\n        x = ops.array(\n            np.random.randn(out_features, in_features).astype(\"float32\")\n        )\n\n        scale, zero, maxq = compute_quantization_parameters(\n            x,\n            bits=4,\n            symmetric=False,\n            per_channel=True,\n            group_size=group_size,\n        )\n\n        self.assertEqual(scale.shape, (out_features, n_groups))\n        self.assertEqual(zero.shape, (out_features, n_groups))\n        
self.assertEqual(float(maxq), 15.0)\n\n    def test_grouped_weight_shapes_non_divisible(self):\n        \"\"\"Test grouped quantization with non-divisible dimensions.\"\"\"\n        out_features, in_features, group_size = 32, 100, 32\n        n_groups = (in_features + group_size - 1) // group_size  # 4\n\n        x = ops.array(\n            np.random.randn(out_features, in_features).astype(\"float32\")\n        )\n\n        scale, zero, maxq = compute_quantization_parameters(\n            x,\n            bits=4,\n            symmetric=False,\n            per_channel=True,\n            group_size=group_size,\n        )\n\n        self.assertEqual(scale.shape, (out_features, n_groups))\n        self.assertEqual(zero.shape, (out_features, n_groups))\n\n    def test_grouped_returns_3_values(self):\n        \"\"\"Test that grouped quantization returns exactly 3 values.\"\"\"\n        x = ops.array(np.random.randn(32, 64).astype(\"float32\"))\n\n        result = compute_quantization_parameters(\n            x,\n            bits=4,\n            symmetric=False,\n            per_channel=True,\n            group_size=16,\n        )\n\n        # Should return exactly 3 values\n        self.assertEqual(len(result), 3)\n        scale, zero, maxq = result\n        self.assertEqual(scale.shape, (32, 4))\n        self.assertEqual(zero.shape, (32, 4))\n\n    def test_single_group_per_channel_semantics(self):\n        \"\"\"Test that single group slice uses per-channel semantics.\"\"\"\n        out_features, in_features = 32, 16\n        group_size = 16  # in_features == group_size\n\n        x = ops.array(\n            np.random.randn(out_features, in_features).astype(\"float32\")\n        )\n\n        scale, zero, maxq = compute_quantization_parameters(\n            x,\n            bits=4,\n            symmetric=False,\n            per_channel=True,\n            group_size=group_size,\n        )\n\n        # Single group should produce per-channel output shape\n        # n_groups 
= 1, so shape is [out_features, 1]\n        self.assertEqual(scale.shape, (out_features, 1))\n        self.assertEqual(zero.shape, (out_features, 1))\n\n    def test_grouped_no_nan_inf(self):\n        \"\"\"Test grouped quantization produces no NaN/Inf.\"\"\"\n        x = ops.array(np.random.randn(64, 128).astype(\"float32\"))\n\n        scale, zero, maxq = compute_quantization_parameters(\n            x,\n            bits=4,\n            per_channel=True,\n            group_size=32,\n        )\n\n        self.assertFalse(ops.any(ops.isnan(scale)))\n        self.assertFalse(ops.any(ops.isinf(scale)))\n\n    def test_grouped_various_group_sizes(self):\n        \"\"\"Test grouped quantization with various group sizes.\"\"\"\n        out_features, in_features = 64, 128\n\n        for group_size in [8, 16, 32, 64]:\n            n_groups = (in_features + group_size - 1) // group_size\n            x = ops.array(\n                np.random.randn(out_features, in_features).astype(\"float32\")\n            )\n\n            scale, zero, maxq = compute_quantization_parameters(\n                x,\n                bits=4,\n                per_channel=True,\n                group_size=group_size,\n            )\n\n            self.assertEqual(\n                scale.shape,\n                (out_features, n_groups),\n                f\"Failed for group_size={group_size}\",\n            )\n"
  },
  {
    "path": "keras/src/quantizers/utils.py",
    "content": "import re\n\n\ndef should_quantize_layer(layer, filters):\n    \"\"\"Determines if a layer should be quantized based on filters.\n\n    Args:\n        layer: The layer to check.\n        filters: A regex string, a list of regex strings, or a callable.\n            If None, returns True.\n\n    Returns:\n        True if the layer should be quantized, False otherwise.\n    \"\"\"\n    if filters is None:\n        return True\n    if isinstance(filters, str):\n        return bool(re.search(filters, layer.name))\n    if isinstance(filters, (list, tuple)):\n        return any(re.search(pat, layer.name) for pat in filters)\n    if callable(filters):\n        return filters(layer)\n    return True\n"
  },
  {
    "path": "keras/src/quantizers/utils_test.py",
    "content": "from absl.testing import parameterized\n\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.quantizers import utils\n\n\nclass UtilsTest(testing.TestCase):\n    @parameterized.named_parameters(\n        (\"none_filter\", None, \"dense\", True),\n        (\"regex_match\", \"dense\", \"dense_1\", True),\n        (\"regex_no_match\", \"conv\", \"dense_1\", False),\n        (\"list_match\", [\"dense\", \"conv\"], \"dense_1\", True),\n        (\"list_no_match\", [\"conv\", \"pool\"], \"dense_1\", False),\n        (\"callable_match\", lambda l: \"dense\" in l.name, \"dense_1\", True),\n        (\"callable_no_match\", lambda l: \"conv\" in l.name, \"dense_1\", False),\n    )\n    def test_should_quantize_layer(self, filters, layer_name, expected):\n        layer = layers.Layer(name=layer_name)\n        self.assertEqual(utils.should_quantize_layer(layer, filters), expected)\n"
  },
  {
    "path": "keras/src/random/__init__.py",
    "content": "from keras.src.random.random import categorical\nfrom keras.src.random.random import dropout\nfrom keras.src.random.random import gamma\nfrom keras.src.random.random import normal\nfrom keras.src.random.random import randint\nfrom keras.src.random.random import shuffle\nfrom keras.src.random.random import truncated_normal\nfrom keras.src.random.random import uniform\nfrom keras.src.random.seed_generator import SeedGenerator\n"
  },
  {
    "path": "keras/src/random/random.py",
    "content": "from keras.src import backend\nfrom keras.src.api_export import keras_export\n\n\n@keras_export(\"keras.random.normal\")\ndef normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    \"\"\"Draw random samples from a normal (Gaussian) distribution.\n\n    Args:\n        shape: The shape of the random values to generate.\n        mean: Float, defaults to 0. Mean of the random values to generate.\n        stddev: Float, defaults to 1. Standard deviation of the random values\n            to generate.\n        dtype: Optional dtype of the tensor. Only floating point types are\n            supported. If not specified, `keras.config.floatx()` is used,\n            which defaults to `float32` unless you configured it otherwise (via\n            `keras.config.set_floatx(float_dtype)`).\n        seed: Optional Python integer or instance of\n           `keras.random.SeedGenerator`.\n            By default, the `seed` argument is `None`, and an internal global\n            `keras.random.SeedGenerator` is used. The `seed` argument can be\n            used to ensure deterministic (repeatable) random number generation.\n            Note that passing an integer as the `seed` value will produce the\n            same random values for each call. To generate different random\n            values for repeated calls, an instance of\n            `keras.random.SeedGenerator` must be provided as the `seed` value.\n            Remark concerning the JAX backend: When tracing functions with the\n            JAX backend the global `keras.random.SeedGenerator` is not\n            supported. 
Therefore, during tracing the default value `seed=None`\n            will produce an error, and a `seed` argument must be provided.\n    \"\"\"\n    return backend.random.normal(\n        shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed\n    )\n\n\n@keras_export(\"keras.random.categorical\")\ndef categorical(logits, num_samples, dtype=\"int32\", seed=None):\n    \"\"\"Draws samples from a categorical distribution.\n\n    This function takes as input `logits`, a 2-D tensor with shape\n    (batch_size, num_classes). Each row of the input represents a categorical\n    distribution, with each column index containing the log-probability for a\n    given class.\n\n    The function will output a 2-D tensor with shape (batch_size, num_samples),\n    where each row contains samples from the corresponding row in `logits`.\n    Each column index contains an independent sample drawn from the input\n    distribution.\n\n    Args:\n        logits: 2-D Tensor with shape (batch_size, num_classes). Each row\n            should define a categorical distribution with the unnormalized\n            log-probabilities for all classes.\n        num_samples: Int, the number of independent samples to draw for each\n            row of the input. This will be the second dimension of the output\n            tensor's shape.\n        dtype: Optional dtype of the output tensor.\n        seed: Optional Python integer or instance of\n           `keras.random.SeedGenerator`.\n            By default, the `seed` argument is `None`, and an internal global\n            `keras.random.SeedGenerator` is used. The `seed` argument can be\n            used to ensure deterministic (repeatable) random number generation.\n            Note that passing an integer as the `seed` value will produce the\n            same random values for each call. 
To generate different random\n            values for repeated calls, an instance of\n            `keras.random.SeedGenerator` must be provided as the `seed` value.\n            Remark concerning the JAX backend: When tracing functions with the\n            JAX backend the global `keras.random.SeedGenerator` is not\n            supported. Therefore, during tracing the default value `seed=None`\n            will produce an error, and a `seed` argument must be provided.\n\n    Returns:\n        A 2-D tensor with shape (batch_size, num_samples).\n    \"\"\"\n    logits_shape = list(backend.convert_to_tensor(logits).shape)\n    if len(logits_shape) != 2:\n        raise ValueError(\n            \"`logits` should be a 2-D tensor with shape \"\n            f\"[batch_size, num_classes]. Received: logits={logits}\"\n        )\n    return backend.random.categorical(\n        logits, num_samples, dtype=dtype, seed=seed\n    )\n\n\n@keras_export(\"keras.random.uniform\")\ndef uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):\n    \"\"\"Draw samples from a uniform distribution.\n\n    The generated values follow a uniform distribution in the range\n    `[minval, maxval)`. The lower bound `minval` is included in the range,\n    while the upper bound `maxval` is excluded.\n\n    `dtype` must be a floating point type; the default range is `[0, 1)`.\n\n    Args:\n        shape: The shape of the random values to generate.\n        minval: Float, defaults to 0. Lower bound of the range of\n            random values to generate (inclusive).\n        maxval: Float, defaults to 1. Upper bound of the range of\n            random values to generate (exclusive).\n        dtype: Optional dtype of the tensor. Only floating point types are\n            supported. 
If not specified, `keras.config.floatx()` is used,\n            which defaults to `float32` unless you configured it otherwise (via\n            `keras.config.set_floatx(float_dtype)`).\n        seed: Optional Python integer or instance of\n           `keras.random.SeedGenerator`.\n            By default, the `seed` argument is `None`, and an internal global\n            `keras.random.SeedGenerator` is used. The `seed` argument can be\n            used to ensure deterministic (repeatable) random number generation.\n            Note that passing an integer as the `seed` value will produce the\n            same random values for each call. To generate different random\n            values for repeated calls, an instance of\n            `keras.random.SeedGenerator` must be provided as the `seed` value.\n            Remark concerning the JAX backend: When tracing functions with the\n            JAX backend the global `keras.random.SeedGenerator` is not\n            supported. Therefore, during tracing the default value `seed=None`\n            will produce an error, and a `seed` argument must be provided.\n    \"\"\"\n    if dtype and not backend.is_float_dtype(dtype):\n        raise ValueError(\n            \"`keras.random.uniform` requires a floating point `dtype`. \"\n            f\"Received: dtype={dtype} \"\n        )\n    return backend.random.uniform(\n        shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed\n    )\n\n\n@keras_export(\"keras.random.randint\")\ndef randint(shape, minval, maxval, dtype=\"int32\", seed=None):\n    \"\"\"Draw random integers from a uniform distribution.\n\n    The generated values follow a uniform distribution in the range\n    `[minval, maxval)`. The lower bound `minval` is included in the range,\n    while the upper bound `maxval` is excluded.\n\n    `dtype` must be an integer type.\n\n    Args:\n        shape: The shape of the random values to generate.\n        minval: Int. 
Lower bound of the range of\n            random values to generate (inclusive).\n        maxval: Int. Upper bound of the range of\n            random values to generate (exclusive).\n        dtype: Optional dtype of the tensor. Only integer types are\n            supported. Defaults to `int32`.\n        seed: Optional Python integer or instance of\n           `keras.random.SeedGenerator`.\n            By default, the `seed` argument is `None`, and an internal global\n            `keras.random.SeedGenerator` is used. The `seed` argument can be\n            used to ensure deterministic (repeatable) random number generation.\n            Note that passing an integer as the `seed` value will produce the\n            same random values for each call. To generate different random\n            values for repeated calls, an instance of\n            `keras.random.SeedGenerator` must be provided as the `seed` value.\n            Remark concerning the JAX backend: When tracing functions with the\n            JAX backend the global `keras.random.SeedGenerator` is not\n            supported. Therefore, during tracing the default value `seed=None`\n            will produce an error, and a `seed` argument must be provided.\n    \"\"\"\n    if dtype and not backend.is_int_dtype(dtype):\n        raise ValueError(\n            \"`keras.random.randint` requires an integer `dtype`. 
\"\n            f\"Received: dtype={dtype} \"\n        )\n    return backend.random.randint(\n        shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed\n    )\n\n\n@keras_export(\"keras.random.truncated_normal\")\ndef truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):\n    \"\"\"Draw samples from a truncated normal distribution.\n\n    The values are drawn from a normal distribution with specified mean and\n    standard deviation, discarding and re-drawing any samples that are more\n    than two standard deviations from the mean.\n\n    Args:\n        shape: The shape of the random values to generate.\n        mean: Float, defaults to 0. Mean of the random values to generate.\n        stddev: Float, defaults to 1. Standard deviation of the random values\n            to generate.\n        dtype: Optional dtype of the tensor. Only floating point types are\n            supported. If not specified, `keras.config.floatx()` is used,\n            which defaults to `float32` unless you configured it otherwise (via\n            `keras.config.set_floatx(float_dtype)`)\n        seed: Optional Python integer or instance of\n           `keras.random.SeedGenerator`.\n            By default, the `seed` argument is `None`, and an internal global\n            `keras.random.SeedGenerator` is used. The `seed` argument can be\n            used to ensure deterministic (repeatable) random number generation.\n            Note that passing an integer as the `seed` value will produce the\n            same random values for each call. To generate different random\n            values for repeated calls, an instance of\n            `keras.random.SeedGenerator` must be provided as the `seed` value.\n            Remark concerning the JAX backend: When tracing functions with the\n            JAX backend the global `keras.random.SeedGenerator` is not\n            supported. 
Therefore, during tracing the default value seed=None\n            will produce an error, and a `seed` argument must be provided.\n    \"\"\"\n    return backend.random.truncated_normal(\n        shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed\n    )\n\n\n@keras_export(\"keras.random.dropout\")\ndef dropout(inputs, rate, noise_shape=None, seed=None):\n    return backend.random.dropout(\n        inputs, rate, noise_shape=noise_shape, seed=seed\n    )\n\n\n@keras_export(\"keras.random.shuffle\")\ndef shuffle(x, axis=0, seed=None):\n    \"\"\"Shuffle the elements of a tensor uniformly at random along an axis.\n\n    Args:\n        x: The tensor to be shuffled.\n        axis: An integer specifying the axis along which to shuffle. Defaults to\n            `0`.\n        seed: Optional Python integer or instance of\n           `keras.random.SeedGenerator`.\n            By default, the `seed` argument is `None`, and an internal global\n            `keras.random.SeedGenerator` is used. The `seed` argument can be\n            used to ensure deterministic (repeatable) random number generation.\n            Note that passing an integer as the `seed` value will produce the\n            same random values for each call. To generate different random\n            values for repeated calls, an instance of\n            `keras.random.SeedGenerator` must be provided as the `seed` value.\n            Remark concerning the JAX backend: When tracing functions with the\n            JAX backend the global `keras.random.SeedGenerator` is not\n            supported. 
Therefore, during tracing the default value seed=None\n            will produce an error, and a `seed` argument must be provided.\n    \"\"\"\n    return backend.random.shuffle(x, axis=axis, seed=seed)\n\n\n@keras_export(\"keras.random.gamma\")\ndef gamma(shape, alpha, dtype=None, seed=None):\n    \"\"\"Draw random samples from the Gamma distribution.\n\n    Args:\n        shape: The shape of the random values to generate.\n        alpha: Float, the parameter of the distribution.\n        dtype: Optional dtype of the tensor. Only floating point types are\n            supported. If not specified, `keras.config.floatx()` is used,\n            which defaults to `float32` unless you configured it otherwise (via\n            `keras.config.set_floatx(float_dtype)`).\n        seed: Optional Python integer or instance of\n           `keras.random.SeedGenerator`.\n            By default, the `seed` argument is `None`, and an internal global\n            `keras.random.SeedGenerator` is used. The `seed` argument can be\n            used to ensure deterministic (repeatable) random number generation.\n            Note that passing an integer as the `seed` value will produce the\n            same random values for each call. To generate different random\n            values for repeated calls, an instance of\n            `keras.random.SeedGenerator` must be provided as the `seed` value.\n            Remark concerning the JAX backend: When tracing functions with the\n            JAX backend the global `keras.random.SeedGenerator` is not\n            supported. 
Therefore, during tracing the default value seed=None\n            will produce an error, and a `seed` argument must be provided.\n    \"\"\"\n    return backend.random.gamma(shape, alpha=alpha, dtype=dtype, seed=seed)\n\n\n@keras_export(\"keras.random.binomial\")\ndef binomial(shape, counts, probabilities, dtype=None, seed=None):\n    \"\"\"Draw samples from a Binomial distribution.\n\n    The values are drawn from a Binomial distribution with\n    specified trial count and probability of success.\n\n    Args:\n        shape: The shape of the random values to generate.\n        counts: A number or array of numbers representing the\n            number of trials. It must be broadcastable with `probabilities`.\n        probabilities: A float or array of floats representing the\n            probability of success of an individual event.\n            It must be broadcastable with `counts`.\n        dtype: Optional dtype of the tensor. Only floating point types are\n            supported. If not specified, `keras.config.floatx()` is used,\n            which defaults to `float32` unless you configured it otherwise (via\n            `keras.config.set_floatx(float_dtype)`).\n        seed: Optional Python integer or instance of\n           `keras.random.SeedGenerator`.\n            By default, the `seed` argument is `None`, and an internal global\n            `keras.random.SeedGenerator` is used. The `seed` argument can be\n            used to ensure deterministic (repeatable) random number generation.\n            Note that passing an integer as the `seed` value will produce the\n            same random values for each call. To generate different random\n            values for repeated calls, an instance of\n            `keras.random.SeedGenerator` must be provided as the `seed` value.\n            Remark concerning the JAX backend: When tracing functions with the\n            JAX backend the global `keras.random.SeedGenerator` is not\n            supported. 
Therefore, during tracing the default value seed=None\n            will produce an error, and a `seed` argument must be provided.\n    \"\"\"\n    return backend.random.binomial(\n        shape,\n        counts=counts,\n        probabilities=probabilities,\n        dtype=dtype,\n        seed=seed,\n    )\n\n\n@keras_export(\"keras.random.beta\")\ndef beta(shape, alpha, beta, dtype=None, seed=None):\n    \"\"\"Draw samples from a Beta distribution.\n\n    The values are drawn from a Beta distribution parametrized\n    by alpha and beta.\n\n    Args:\n        shape: The shape of the random values to generate.\n        alpha: Float or an array of floats representing the first\n            parameter alpha. Must be broadcastable with `beta` and `shape`.\n        beta: Float or an array of floats representing the second\n            parameter beta. Must be broadcastable with `alpha` and `shape`.\n        dtype: Optional dtype of the tensor. Only floating point types are\n            supported. If not specified, `keras.config.floatx()` is used,\n            which defaults to `float32` unless you configured it otherwise (via\n            `keras.config.set_floatx(float_dtype)`).\n        seed: Optional Python integer or instance of\n           `keras.random.SeedGenerator`.\n            By default, the `seed` argument is `None`, and an internal global\n            `keras.random.SeedGenerator` is used. The `seed` argument can be\n            used to ensure deterministic (repeatable) random number generation.\n            Note that passing an integer as the `seed` value will produce the\n            same random values for each call. To generate different random\n            values for repeated calls, an instance of\n            `keras.random.SeedGenerator` must be provided as the `seed` value.\n            Remark concerning the JAX backend: When tracing functions with the\n            JAX backend the global `keras.random.SeedGenerator` is not\n            supported. 
Therefore, during tracing the default value seed=None\n            will produce an error, and a `seed` argument must be provided.\n    \"\"\"\n    return backend.random.beta(\n        shape=shape, alpha=alpha, beta=beta, dtype=dtype, seed=seed\n    )\n"
  },
  {
    "path": "keras/src/random/random_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.backend.common import dtypes\nfrom keras.src.backend.common import standardize_dtype\nfrom keras.src.random import random\nfrom keras.src.random import seed_generator\nfrom keras.src.testing.test_utils import named_product\nfrom keras.src.utils.rng_utils import set_random_seed\n\n\nclass RandomCorrectnessTest(testing.TestCase):\n    @parameterized.parameters(\n        {\"seed\": 10, \"shape\": (5,), \"mean\": 0, \"stddev\": 1},\n        {\"seed\": 10, \"shape\": (2, 3), \"mean\": 0, \"stddev\": 1},\n        {\"seed\": 10, \"shape\": (2, 3, 4), \"mean\": 0, \"stddev\": 1},\n        {\"seed\": 10, \"shape\": (2, 3), \"mean\": 10, \"stddev\": 1},\n        {\"seed\": 10, \"shape\": (2, 3), \"mean\": 10, \"stddev\": 3},\n    )\n    def test_normal(self, seed, shape, mean, stddev):\n        np.random.seed(seed)\n        np_res = np.random.normal(loc=mean, scale=stddev, size=shape)\n        res = random.normal(shape, mean=mean, stddev=stddev, seed=seed)\n        self.assertEqual(res.shape, shape)\n        self.assertEqual(res.shape, np_res.shape)\n\n    @parameterized.parameters(\n        {\"seed\": 10, \"shape\": (5,), \"minval\": 0, \"maxval\": 1},\n        {\"seed\": 10, \"shape\": (2, 3), \"minval\": 0, \"maxval\": 1},\n        {\"seed\": 10, \"shape\": (2, 3, 4), \"minval\": 0, \"maxval\": 2},\n        {\"seed\": 10, \"shape\": (2, 3), \"minval\": -1, \"maxval\": 1},\n        {\"seed\": 10, \"shape\": (2, 3), \"minval\": 1, \"maxval\": 3},\n    )\n    def test_uniform(self, seed, shape, minval, maxval):\n        np.random.seed(seed)\n        np_res = np.random.uniform(low=minval, high=maxval, size=shape)\n        res = random.uniform(shape, minval=minval, maxval=maxval, seed=seed)\n        self.assertEqual(res.shape, shape)\n        self.assertEqual(res.shape, 
np_res.shape)\n        self.assertLessEqual(ops.max(res), maxval)\n        self.assertGreaterEqual(ops.min(res), minval)\n\n    @parameterized.parameters(\n        {\"seed\": 10, \"num_samples\": 1, \"batch_size\": 1},\n        {\"seed\": 10, \"num_samples\": 5, \"batch_size\": 2},\n        {\"seed\": 10, \"num_samples\": 10, \"batch_size\": 4},\n        {\"seed\": 10, \"num_samples\": 15, \"batch_size\": 8},\n    )\n    def test_categorical(self, seed, num_samples, batch_size):\n        np.random.seed(seed)\n        # Create logits that definitely favor the batch index after a softmax\n        # is applied. Without a softmax, this would be close to random.\n        logits = np.eye(batch_size) * 1e5 + 1e6\n        res = random.categorical(logits, num_samples, seed=seed)\n        # Outputs should have shape `(batch_size, num_samples)`, where each\n        # output index matches the batch index.\n        self.assertEqual(res.shape, (batch_size, num_samples))\n        expected = np.tile(np.arange(batch_size)[:, None], (1, num_samples))\n        self.assertAllClose(res, expected)\n\n    @parameterized.parameters(\n        {\"seed\": 10, \"shape\": (5,), \"min\": 0, \"max\": 10, \"dtype\": \"uint16\"},\n        {\"seed\": 10, \"shape\": (2, 3), \"min\": 0, \"max\": 10, \"dtype\": \"uint32\"},\n        {\"seed\": 10, \"shape\": (2, 3, 4), \"min\": 0, \"max\": 2, \"dtype\": \"int8\"},\n        {\"seed\": 10, \"shape\": (2, 3), \"min\": -1, \"max\": 1, \"dtype\": \"int16\"},\n        {\"seed\": 10, \"shape\": (2, 3), \"min\": 1, \"max\": 3, \"dtype\": \"int32\"},\n    )\n    def test_randint(self, seed, shape, min, max, dtype):\n        np.random.seed(seed)\n        np_res = np.random.randint(low=min, high=max, size=shape)\n        res = random.randint(\n            shape, minval=min, maxval=max, seed=seed, dtype=dtype\n        )\n        self.assertEqual(res.shape, shape)\n        self.assertEqual(res.shape, np_res.shape)\n        self.assertLessEqual(ops.max(res), 
max)\n        self.assertGreaterEqual(ops.min(res), min)\n        # Torch has incomplete dtype support for uints; will remap some dtypes.\n        if keras.backend.backend() != \"torch\":\n            self.assertEqual(backend.standardize_dtype(res.dtype), dtype)\n\n    @parameterized.parameters(\n        {\"seed\": 10, \"shape\": (5,), \"mean\": 0, \"stddev\": 1},\n        {\"seed\": 10, \"shape\": (2, 3), \"mean\": 0, \"stddev\": 1},\n        {\"seed\": 10, \"shape\": (2, 3, 4), \"mean\": 0, \"stddev\": 1},\n        {\"seed\": 10, \"shape\": (2, 3), \"mean\": 10, \"stddev\": 1},\n        {\"seed\": 10, \"shape\": (2, 3), \"mean\": 10, \"stddev\": 3},\n        # Test list shapes.\n        {\"seed\": 10, \"shape\": [2, 3], \"mean\": 10, \"stddev\": 3},\n    )\n    def test_truncated_normal(self, seed, shape, mean, stddev):\n        np.random.seed(seed)\n        np_res = np.random.normal(loc=mean, scale=stddev, size=shape)\n        res = random.truncated_normal(\n            shape, mean=mean, stddev=stddev, seed=seed\n        )\n        self.assertEqual(res.shape, tuple(shape))\n        self.assertEqual(res.shape, np_res.shape)\n        self.assertLessEqual(ops.max(res), mean + 2 * stddev)\n        self.assertGreaterEqual(ops.min(res), mean - 2 * stddev)\n\n    def test_dropout(self):\n        x = ops.ones((10, 10))\n        self.assertAllClose(random.dropout(x, rate=0, seed=0), x)\n        x_res = random.dropout(x, rate=0.5, seed=0)\n        self.assertGreater(ops.max(x_res), ops.max(x))\n        self.assertAllClose(ops.max(x_res), 2.0)\n        self.assertGreater(ops.cast(ops.sum(x_res == 0), \"int32\"), 2)\n        x_res = random.dropout(x, rate=1.0, seed=0)\n        self.assertAllClose(x_res, ops.zeros((10, 10)))\n\n    def test_dropout_noise_shape(self):\n        inputs = ops.ones((2, 3, 5, 7))\n        x = random.dropout(\n            inputs, rate=0.3, noise_shape=[None, 3, 5, None], seed=0\n        )\n        self.assertEqual(x.shape, (2, 3, 5, 7))\n\n    def 
test_global_seed_generator(self):\n        # Check that unseeded RNG calls use and update the global RNG state.\n\n        def random_numbers(seed):\n            rng_state = seed_generator.global_seed_generator().state\n            rng_state.assign(seed)\n            x = random.normal((), seed=None)\n            y = random.normal((), seed=None)\n            return x, y, rng_state.value\n\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            random_numbers = tf.function(jit_compile=True)(random_numbers)\n\n        seed = ops.zeros((2,))\n        seed0 = ops.convert_to_numpy(seed)\n        x1, y1, seed = random_numbers(seed)\n        x1 = ops.convert_to_numpy(x1)\n        y1 = ops.convert_to_numpy(y1)\n        seed1 = ops.convert_to_numpy(seed)\n        x2, y2, seed = random_numbers(seed)\n        x2 = ops.convert_to_numpy(x2)\n        y2 = ops.convert_to_numpy(y2)\n        seed2 = ops.convert_to_numpy(seed)\n        x3, y3, seed = random_numbers(seed)\n        x3 = ops.convert_to_numpy(x3)\n        y3 = ops.convert_to_numpy(y3)\n        seed3 = ops.convert_to_numpy(seed)\n\n        self.assertNotEqual(seed0[1], seed1[1])\n        self.assertNotEqual(seed1[1], seed2[1])\n        self.assertNotEqual(seed2[1], seed3[1])\n\n        self.assertGreater(np.abs(x1 - y1), 1e-4)\n        self.assertGreater(np.abs(x2 - y2), 1e-4)\n        self.assertGreater(np.abs(x3 - y3), 1e-4)\n        self.assertGreater(np.abs(x1 - x2), 1e-4)\n        self.assertGreater(np.abs(x1 - x3), 1e-4)\n        self.assertGreater(np.abs(x2 - x3), 1e-4)\n        self.assertGreater(np.abs(y1 - y2), 1e-4)\n        self.assertGreater(np.abs(y1 - y3), 1e-4)\n        self.assertGreater(np.abs(y2 - y3), 1e-4)\n\n        seed_generator.global_seed_generator().state.assign(seed)\n\n    def test_shuffle(self):\n        x = np.arange(100).reshape(10, 10)\n\n        # Test axis=0\n        y = random.shuffle(x, 
seed=0)\n\n        self.assertFalse(np.all(x == ops.convert_to_numpy(y)))\n        self.assertAllClose(np.sum(x, axis=0), ops.sum(y, axis=0))\n        self.assertNotAllClose(np.sum(x, axis=1), ops.sum(y, axis=1))\n\n        # Test axis=1\n        y = random.shuffle(x, axis=1, seed=0)\n\n        self.assertFalse(np.all(x == ops.convert_to_numpy(y)))\n        self.assertAllClose(np.sum(x, axis=1), ops.sum(y, axis=1))\n        self.assertNotAllClose(np.sum(x, axis=0), ops.sum(y, axis=0))\n\n    @parameterized.parameters(\n        {\"seed\": 10, \"shape\": (5, 2), \"alpha\": 2.0, \"dtype\": \"float16\"},\n        {\"seed\": 10, \"shape\": (2,), \"alpha\": 1.5, \"dtype\": \"float32\"},\n        {\"seed\": 10, \"shape\": (2, 3), \"alpha\": 0.5, \"dtype\": \"float32\"},\n    )\n    def test_gamma(self, seed, shape, alpha, dtype):\n        values = random.gamma(shape, alpha=alpha, seed=seed, dtype=dtype)\n        self.assertEqual(ops.shape(values), shape)\n        self.assertEqual(backend.standardize_dtype(values.dtype), dtype)\n        self.assertGreater(np.min(ops.convert_to_numpy(values)), 0.0)\n\n    @parameterized.parameters(\n        {\n            \"seed\": 10,\n            \"shape\": (5, 2),\n            \"counts\": 5e4,\n            \"probabilities\": 0.5,\n            \"dtype\": \"float16\",\n        },\n        {\n            \"seed\": 10,\n            \"shape\": (2,),\n            \"counts\": 1e5,\n            \"probabilities\": 0.5,\n            \"dtype\": \"float32\",\n        },\n        {\n            \"seed\": 10,\n            \"shape\": (2, 3),\n            \"counts\": [[1e5, 2e5, 3e5], [4e5, 5e5, 6e5]],\n            \"probabilities\": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],\n            \"dtype\": \"float32\",\n        },\n    )\n    def test_binomial(self, seed, shape, counts, probabilities, dtype):\n        set_random_seed(1337)\n        values = random.binomial(\n            shape=shape,\n            counts=counts,\n            
probabilities=probabilities,\n            seed=seed,\n            dtype=dtype,\n        )\n        self.assertEqual(ops.shape(values), shape)\n        self.assertEqual(backend.standardize_dtype(values.dtype), dtype)\n\n        # The following test ensures that the number of times each event\n        # occurs doesn't exceed the total input count specified by the user\n        # for that event. Hence, we do an element-wise comparison between\n        # the `counts` array and the (generated) `values` array.\n        values_np = ops.convert_to_numpy(values)\n        self.assertTrue(np.greater_equal(np.array(counts), values_np).all())\n\n        # The following test computes the probability of each event by\n        # dividing the number of times the event occurs (the generated\n        # value) by the corresponding entry in the (total) `counts` array,\n        # and then checks that the computed probabilities approximate the\n        # input probabilities.\n        generated_probabilities = values_np / np.array(counts)\n        probabilities = np.ones(shape) * np.array(probabilities)\n        self.assertAllClose(\n            probabilities, generated_probabilities, rtol=0.005, atol=0.005\n        )\n\n    @parameterized.parameters(\n        {\n            \"seed\": 10,\n            \"shape\": (10000,),\n            \"alpha\": 3.0,\n            \"beta\": 2.0,\n            \"dtype\": \"float16\",\n        },\n        {\n            \"seed\": 10,\n            \"shape\": (10000, 3),\n            \"alpha\": [[7.0, 0.5, 1.5]],\n            \"beta\": [[15.0, 0.9, 4.5]],\n            \"dtype\": \"float32\",\n        },\n        {\n            \"seed\": 10,\n            \"shape\": (10000, 30),\n            \"alpha\": 1.0,\n            \"beta\": 1.0,\n            \"dtype\": \"float32\",\n        },\n    )\n    def test_beta(self, seed, shape, alpha, beta, dtype):\n        set_random_seed(1337)\n        values = random.beta(\n            shape=shape, alpha=alpha, 
beta=beta, seed=seed, dtype=dtype\n        )\n        self.assertEqual(ops.shape(values), shape)\n        self.assertEqual(backend.standardize_dtype(values.dtype), dtype)\n        values_np = ops.convert_to_numpy(values)\n        self.assertGreaterEqual(np.min(values_np), 0.0)\n        self.assertLessEqual(np.max(values_np), 1.0)\n\n        _alpha_is_an_array = False\n        if isinstance(alpha, list):\n            alpha = np.array(alpha)\n            beta = np.array(beta)\n            _alpha_is_an_array = True\n\n        # Mean check:\n        # For a beta distributed random variable,\n        # mean = alpha / (alpha + beta)\n        expected_mean = alpha / (alpha + beta)\n\n        if _alpha_is_an_array:\n            actual_mean = np.mean(values_np, axis=0)\n            self.assertAllClose(\n                expected_mean.flatten(), actual_mean, atol=0.005, rtol=0.005\n            )\n        else:\n            actual_mean = np.mean(values_np.flatten())\n            self.assertAlmostEqual(expected_mean, actual_mean, places=2)\n\n        # Variance check:\n        # For a beta distributed random variable,\n        # variance = (alpha * beta) / ((alpha + beta)^2 * (alpha + beta + 1))\n        expected_variance = (alpha * beta) / (\n            np.square(alpha + beta) * (alpha + beta + 1)\n        )\n        if _alpha_is_an_array:\n            actual_variance = np.var(values_np, axis=0)\n            self.assertAllClose(\n                expected_variance.flatten(),\n                actual_variance,\n                atol=0.005,\n                rtol=0.005,\n            )\n        else:\n            actual_variance = np.var(values_np.flatten())\n            self.assertAlmostEqual(\n                expected_variance, actual_variance, places=2\n            )\n\n\nclass RandomBehaviorTest(testing.TestCase):\n    def test_beta_tf_data_compatibility(self):\n        import tensorflow as tf\n\n        from keras.src.layers.preprocessing.data_layer import DataLayer\n        
from keras.src.random.seed_generator import SeedGenerator\n\n        class BetaLayer(DataLayer):\n            def __init__(self, seed=None, **kwargs):\n                super().__init__(**kwargs)\n                self.seed = seed\n                self.generator = SeedGenerator(seed)\n\n            def compute_output_shape(self, input_shape):\n                return input_shape\n\n            def call(self, inputs):\n                seed_generator = self._get_seed_generator(self.backend._backend)\n                noise = self.backend.random.beta(\n                    self.backend.shape(inputs),\n                    alpha=0.5,\n                    beta=0.5,\n                    seed=seed_generator,\n                )\n                inputs = inputs + noise\n                return inputs\n\n        layer = BetaLayer()\n        input_data = np.random.random([2, 4, 4, 3])\n        ds = tf.data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)\n        for output in ds.take(1):\n            output = ops.convert_to_numpy(output)\n        self.assertEqual(output.shape, (2, 4, 4, 3))\n\n    def test_categorical_errors(self):\n        with self.assertRaises(ValueError):\n            random.categorical(np.ones((5,)), 5)\n        with self.assertRaises(ValueError):\n            random.categorical(np.ones((5, 5, 5)), 5)\n\n    def test_randint_dtype_validation(self):\n        with self.assertRaisesRegex(\n            ValueError, \"`keras.random.randint` requires an integer `dtype`.\"\n        ):\n            random.randint((3, 4), minval=0, maxval=10, dtype=\"float64\")\n\n    def test_uniform_dtype_validation(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`keras.random.uniform` requires a floating point `dtype`.\",\n        ):\n            random.uniform((3, 4), minval=0, maxval=10, dtype=\"int64\")\n\n    @pytest.mark.skipif(\n        keras.backend.backend() != \"jax\",\n        reason=\"This test requires `jax` as the 
backend.\",\n    )\n    def test_dropout_jax_jit_stateless(self):\n        import jax\n        import jax.numpy as jnp\n\n        x = ops.ones(3)\n\n        @jax.jit\n        def train_step(x):\n            with keras.src.backend.StatelessScope():\n                x = keras.layers.Dropout(rate=0.1)(x, training=True)\n            return x\n\n        x = train_step(x)\n        self.assertIsInstance(x, jnp.ndarray)\n\n    @pytest.mark.skipif(\n        keras.backend.backend() != \"jax\",\n        reason=\"This test requires `jax` as the backend.\",\n    )\n    def test_jax_rngkey_seed(self):\n        import jax\n        import jax.numpy as jnp\n\n        seed = 1234\n        rng = jax.random.PRNGKey(seed)\n        self.assertEqual(rng.shape, (2,))\n        self.assertEqual(rng.dtype, jnp.uint32)\n        x = random.randint((3, 5), 0, 10, seed=rng)\n        self.assertIsInstance(x, jnp.ndarray)\n\n    @pytest.mark.skipif(\n        keras.backend.backend() != \"jax\",\n        reason=\"This test requires `jax` as the backend.\",\n    )\n    def test_jax_unseed_disallowed_during_tracing(self):\n        import jax\n\n        @jax.jit\n        def jit_fn():\n            return random.randint((2, 2), 0, 10, seed=None)\n\n        with self.assertRaisesRegex(\n            ValueError, \"you should only use seeded random ops\"\n        ):\n            jit_fn()\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"This test requires `tensorflow` as the backend.\",\n    )\n    def test_tf_cast_seed(self):\n        import tensorflow as tf\n\n        inputs = tf.ones([2, 3], dtype=\"float32\")\n        seed = tf.int32.max + 1000  # Test floormod operation\n        outputs_mod = random.categorical(inputs, 2, seed=seed)\n        outputs_nomod = random.categorical(inputs, 2, seed=1001)\n        self.assertAllClose(outputs_mod, outputs_nomod)\n\n\nclass RandomDTypeTest(testing.TestCase):\n    \"\"\"Test the dtype to verify that the behavior matches 
JAX.\"\"\"\n\n    INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in (\"uint64\", \"int64\")]\n    FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in (\"float64\",)]\n    if backend.backend() == \"torch\":\n        INT_DTYPES = [x for x in INT_DTYPES if x not in (\"uint16\", \"uint32\")]\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_normal(self, dtype):\n        res = random.normal((2, 3), dtype=dtype)\n        self.assertEqual(standardize_dtype(res.dtype), dtype)\n\n    @parameterized.named_parameters(named_product(dtype=INT_DTYPES))\n    def test_categorical(self, dtype):\n        logits = np.eye(4) * 1e5 + 1e6\n        res = random.categorical(logits, 10, dtype=dtype)\n        self.assertEqual(standardize_dtype(res.dtype), dtype)\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_uniform(self, dtype):\n        res = random.uniform((2, 3), dtype=dtype)\n        self.assertEqual(standardize_dtype(res.dtype), dtype)\n\n    @parameterized.named_parameters(named_product(dtype=INT_DTYPES))\n    def test_randint(self, dtype):\n        res = random.randint((2, 3), 0, 10, dtype=dtype)\n        self.assertEqual(standardize_dtype(res.dtype), dtype)\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_truncated_normal(self, dtype):\n        res = random.truncated_normal((2, 3), dtype=dtype)\n        self.assertEqual(standardize_dtype(res.dtype), dtype)\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_dropout(self, dtype):\n        x = ops.ones((3, 5), dtype=dtype)\n        res = random.dropout(x, rate=0.8, seed=0)\n        self.assertEqual(standardize_dtype(res.dtype), dtype)\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_gamma(self, dtype):\n        res = random.gamma((2, 3), 2.0, dtype=dtype)\n        self.assertEqual(standardize_dtype(res.dtype), dtype)\n\n    
@parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_binomial(self, dtype):\n        res = random.binomial((2,), 1e5, 0.5, dtype=dtype)\n        self.assertEqual(standardize_dtype(res.dtype), dtype)\n\n    @parameterized.named_parameters(named_product(dtype=FLOAT_DTYPES))\n    def test_beta(self, dtype):\n        res = random.beta((2, 3), 2.0, 3.0, dtype=dtype)\n        self.assertEqual(standardize_dtype(res.dtype), dtype)\n"
  },
  {
    "path": "keras/src/random/seed_generator.py",
    "content": "import random as python_random\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\nfrom keras.src.utils import jax_utils\nfrom keras.src.utils.naming import auto_name\n\nGLOBAL_SEED_GENERATOR = \"global_seed_generator\"\n\n\n@keras_export(\"keras.random.SeedGenerator\")\nclass SeedGenerator:\n    \"\"\"Generates variable seeds upon each call to a function generating\n    random numbers.\n\n    In Keras, all random number generators (such as\n    `keras.random.normal()`) are stateless, meaning that if you pass an\n    integer seed to them (such as `seed=42`), they will return the same\n    values for repeated calls. To get different values for each\n    call, a `SeedGenerator` providing the state of the random generator\n    has to be used.\n\n    Note that all the random number generators have a default seed of None,\n    which implies that an internal global SeedGenerator is used.\n    If you need to decouple the RNG from the global state you can provide\n    a local `StateGenerator` with either a deterministic or random initial\n    state.\n\n    Remark concerning the JAX backend: Note that the use of a local\n    `StateGenerator` as seed argument is required for JIT compilation of\n    RNG with the JAX backend, because the use of global state is not\n    supported.\n\n    Example:\n\n    ```python\n    seed_gen = keras.random.SeedGenerator(seed=42)\n    values = keras.random.normal(shape=(2, 3), seed=seed_gen)\n    new_values = keras.random.normal(shape=(2, 3), seed=seed_gen)\n    ```\n\n    Usage in a layer:\n\n    ```python\n    class Dropout(keras.Layer):\n        def __init__(self, **kwargs):\n            super().__init__(**kwargs)\n            self.seed_generator = keras.random.SeedGenerator(1337)\n\n        def call(self, x, training=False):\n            if training:\n                return keras.random.dropout(\n                    x, 
rate=0.5, seed=self.seed_generator\n                )\n            return x\n    ```\n    \"\"\"\n\n    def __init__(self, seed=None, name=None, **kwargs):\n        if name is None:\n            name = auto_name(self.__class__.__name__)\n        self.name = name\n\n        custom_backend = kwargs.pop(\"backend\", None)\n        if kwargs:\n            raise ValueError(f\"Unrecognized keyword arguments: {kwargs}\")\n        if custom_backend is not None:\n            self.backend = custom_backend\n        else:\n            self.backend = backend\n\n        self._initial_seed = seed\n        if seed is None:\n            seed = make_default_seed()\n\n        if not isinstance(seed, int):\n            raise ValueError(\n                f\"Argument `seed` must be an integer. Received: seed={seed}\"\n            )\n\n        def seed_initializer(*args, **kwargs):\n            dtype = kwargs.get(\"dtype\", None)\n            return self.backend.convert_to_tensor([seed, 0], dtype=dtype)\n\n        with self.backend.name_scope(self.name, caller=self):\n            self.state = self.backend.Variable(\n                seed_initializer,\n                shape=(2,),\n                dtype=self.backend.random_seed_dtype(),\n                trainable=False,\n                aggregation=\"none\",\n                name=\"seed_generator_state\",\n            )\n\n    def next(self, ordered=True):\n        seed_state = self.state\n        # Use * 1 to create a copy\n        new_seed_value = seed_state.value * 1\n        if ordered:\n            increment = self.backend.convert_to_tensor(\n                np.array([0, 1]), dtype=seed_state.dtype\n            )\n            self.state.assign(self.backend.numpy.add(seed_state, increment))\n        else:\n            # This produces a sequence of near-unique numbers\n            # between 0 and 1M\n            self.state.assign((seed_state + 1) * 5387 % 933199)\n        return new_seed_value\n\n    def get_config(self):\n        return 
{\"seed\": self._initial_seed, \"name\": self.name}\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n\n\ndef global_seed_generator():\n    if jax_utils.is_in_jax_tracing_scope():\n        raise ValueError(\n            \"[JAX RNG] When tracing a JAX function, \"\n            \"you should only use seeded random ops, e.g. \"\n            \"you should create a `SeedGenerator` instance, attach it \"\n            \"to your layer/model, and pass the instance as the `seed` \"\n            \"argument when calling random ops. Unseeded random ops \"\n            \"would get incorrectly traced by JAX and would become constant \"\n            \"after tracing. Example:\\n\\n\"\n            \"```\\n\"\n            \"# Make sure to set the seed generator as a layer attribute\\n\"\n            \"self.seed_generator = keras.random.SeedGenerator(seed=1337)\\n\"\n            \"...\\n\"\n            \"out = keras.random.normal(shape=(1,), seed=self.seed_generator)\\n\"\n            \"```\"\n        )\n    gen = global_state.get_global_attribute(GLOBAL_SEED_GENERATOR)\n    if gen is None:\n        gen = SeedGenerator()\n        global_state.set_global_attribute(GLOBAL_SEED_GENERATOR, gen)\n    return gen\n\n\ndef make_default_seed():\n    return python_random.randint(1, int(1e9))\n\n\ndef draw_seed(seed):\n    from keras.src.backend import convert_to_tensor\n    from keras.src.backend import random_seed_dtype\n\n    if isinstance(seed, SeedGenerator):\n        return seed.next()\n    elif isinstance(seed, int):\n        dtype = random_seed_dtype()\n        # Seeds are conceptually uint32 values but some backends declare\n        # their seed dtype as a signed type (e.g. \"int32\"). np.array(x,\n        # dtype=\"int32\") raises OverflowError on NumPy >= 1.24 for values\n        # >= 2**31. 
Perform an explicit 2's-complement bit-cast via uint32\n        # so that the integer passed to convert_to_tensor is always in the\n        # representable range of the declared dtype while preserving full\n        # 32-bit entropy.\n        if dtype == \"int32\":\n            # Re-interpret the bits of a uint32 as an int32.\n            seed = (seed & 0xFFFFFFFF ^ 0x80000000) - 0x80000000\n        return convert_to_tensor([seed, 0], dtype=dtype)\n    elif seed is None:\n        return global_seed_generator().next(ordered=False)\n    raise ValueError(\n        \"Argument `seed` must be either an integer \"\n        \"or an instance of `SeedGenerator`. \"\n        f\"Received: seed={seed} (of type {type(seed)})\"\n    )\n"
  },
  {
    "path": "keras/src/random/seed_generator_test.py",
    "content": "import numpy as np\nimport pytest\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import random\nfrom keras.src import testing\nfrom keras.src.random import seed_generator\n\n\nclass SeedGeneratorTest(testing.TestCase):\n    def test_seed_generator_initialization(self):\n        gen = seed_generator.SeedGenerator()\n        self.assertIsNotNone(gen.state)\n\n        seed = 12345\n        gen = seed_generator.SeedGenerator(seed=seed)\n        self.assertEqual(ops.convert_to_numpy(gen.state)[0], seed)\n\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `seed` must be an integer\"\n        ):\n            seed_generator.SeedGenerator(seed=\"invalid_seed\")\n\n    def test_seed_generator_next(self):\n        gen = seed_generator.SeedGenerator(seed=42)\n        seed1 = ops.convert_to_numpy(gen.next())\n        seed2 = ops.convert_to_numpy(gen.next())\n        self.assertFalse(np.array_equal(seed1, seed2))\n\n    def test_global_seed_generator(self):\n        gen1 = seed_generator.global_seed_generator()\n        gen2 = seed_generator.global_seed_generator()\n        self.assertEqual(gen1, gen2)\n\n    def test_make_default_seed(self):\n        seed1 = seed_generator.make_default_seed()\n        seed2 = seed_generator.make_default_seed()\n        self.assertNotEqual(seed1, seed2)\n\n    def test_seed_generator_dtype(self):\n        gen = seed_generator.SeedGenerator(seed=42)\n        self.assertEqual(gen.state.dtype, backend.random_seed_dtype())\n        seed = gen.next()\n        self.assertEqual(gen.state.dtype, backend.random_seed_dtype())\n        self.assertEqual(\n            backend.standardize_dtype(seed.dtype), backend.random_seed_dtype()\n        )\n\n    def test_draw_seed_from_seed_generator(self):\n        gen = seed_generator.SeedGenerator(seed=42)\n        seed1 = seed_generator.draw_seed(gen)\n        self.assertTrue(backend.is_tensor(seed1))\n\n    def 
test_draw_seed_from_integer(self):\n        seed2 = seed_generator.draw_seed(12345)\n        self.assertTrue(backend.is_tensor(seed2))\n        self.assertEqual(\n            backend.standardize_dtype(seed2.dtype), backend.random_seed_dtype()\n        )\n\n    def test_draw_seed_from_large_integer(self):\n        # Seeds at int32 boundaries must not cause overflow or sign-flip errors.\n        # 2**31 is where signed-int32 wraps; 2**32 - 1 is the max uint32 value.\n        for s in [2**31 - 1, 2**31, 2**32 - 1]:\n            seed = seed_generator.draw_seed(s)\n            self.assertTrue(backend.is_tensor(seed))\n            self.assertEqual(\n                backend.standardize_dtype(seed.dtype),\n                backend.random_seed_dtype(),\n            )\n            # Verify the seed can be consumed by a random op without error.\n            res = random.uniform((2, 2), seed=s)\n            self.assertEqual(res.shape, (2, 2))\n\n    def test_draw_seed_from_none(self):\n        seed3 = seed_generator.draw_seed(None)\n        self.assertTrue(backend.is_tensor(seed3))\n\n    def test_draw_seed_invalid(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Argument `seed` must be either an integer\"\n        ):\n            seed_generator.draw_seed(\"invalid_seed\")\n\n    def test_seed_generator_unexpected_kwargs(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Unrecognized keyword arguments\"\n        ):\n            seed_generator.SeedGenerator(invalid_arg=\"unexpected_value\")\n\n    @pytest.mark.skipif(\n        backend.backend() != \"jax\", reason=\"This test requires the JAX backend\"\n    )\n    def test_jax_tracing_with_global_seed_generator(self):\n        import jax\n\n        @jax.jit\n        def traced_function():\n            return seed_generator.global_seed_generator().next()\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"When tracing a JAX function, you should only use seeded 
random\",\n        ):\n            traced_function()\n\n    def test_seed_generator_serialization(self):\n        random_generator = seed_generator.SeedGenerator(seed=42, name=\"sg\")\n        self.run_class_serialization_test(random_generator)\n"
  },
  {
    "path": "keras/src/regularizers/__init__.py",
    "content": "import inspect\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.regularizers.regularizers import L1\nfrom keras.src.regularizers.regularizers import L1L2\nfrom keras.src.regularizers.regularizers import L2\nfrom keras.src.regularizers.regularizers import OrthogonalRegularizer\nfrom keras.src.regularizers.regularizers import Regularizer\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils.naming import to_snake_case\n\nALL_OBJECTS = {\n    Regularizer,\n    L1,\n    L2,\n    L1L2,\n    OrthogonalRegularizer,\n}\n\nALL_OBJECTS_DICT = {cls.__name__: cls for cls in ALL_OBJECTS}\nALL_OBJECTS_DICT.update(\n    {to_snake_case(cls.__name__): cls for cls in ALL_OBJECTS}\n)\n\n\n@keras_export(\"keras.regularizers.serialize\")\ndef serialize(regularizer):\n    return serialization_lib.serialize_keras_object(regularizer)\n\n\n@keras_export(\"keras.regularizers.deserialize\")\ndef deserialize(config, custom_objects=None):\n    \"\"\"Return a Keras regularizer object via its config.\"\"\"\n    return serialization_lib.deserialize_keras_object(\n        config,\n        module_objects=ALL_OBJECTS_DICT,\n        custom_objects=custom_objects,\n    )\n\n\n@keras_export(\"keras.regularizers.get\")\ndef get(identifier):\n    \"\"\"Retrieve a Keras regularizer object via an identifier.\"\"\"\n    if identifier is None:\n        return None\n    if isinstance(identifier, dict):\n        obj = deserialize(identifier)\n    elif isinstance(identifier, str):\n        obj = ALL_OBJECTS_DICT.get(identifier, None)\n    else:\n        obj = identifier\n\n    if callable(obj):\n        if inspect.isclass(obj):\n            obj = obj()\n        return obj\n    else:\n        raise ValueError(\n            f\"Could not interpret regularizer identifier: {identifier}\"\n        )\n"
  },
  {
    "path": "keras/src/regularizers/regularizers.py",
    "content": "import math\n\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils.numerical_utils import normalize\n\n\n@keras_export([\"keras.Regularizer\", \"keras.regularizers.Regularizer\"])\nclass Regularizer:\n    \"\"\"Regularizer base class.\n\n    Regularizers allow you to apply penalties on layer parameters or layer\n    activity during optimization. These penalties are summed into the loss\n    function that the network optimizes.\n\n    Regularization penalties are applied on a per-layer basis. The exact API\n    will depend on the layer, but many layers (e.g. `Dense`, `Conv1D`, `Conv2D`\n    and `Conv3D`) have a unified API.\n\n    These layers expose 3 keyword arguments:\n\n    - `kernel_regularizer`: Regularizer to apply a penalty on the layer's kernel\n    - `bias_regularizer`: Regularizer to apply a penalty on the layer's bias\n    - `activity_regularizer`: Regularizer to apply a penalty on the layer's\n        output\n\n    All layers (including custom layers) expose `activity_regularizer` as a\n    settable property, whether or not it is in the constructor arguments.\n\n    The value returned by the `activity_regularizer` is divided by the input\n    batch size so that the relative weighting between the weight regularizers\n    and the activity regularizers does not change with the batch size.\n\n    You can access a layer's regularization penalties by calling `layer.losses`\n    after calling the layer on inputs.\n\n    ## Example\n\n    >>> layer = Dense(\n    ...     5, input_dim=5,\n    ...     kernel_initializer='ones',\n    ...     kernel_regularizer=L1(0.01),\n    ...     
activity_regularizer=L2(0.01))\n    >>> tensor = ops.ones(shape=(5, 5)) * 2.0\n    >>> out = layer(tensor)\n\n    >>> # The kernel regularization term is 0.25\n    >>> # The activity regularization term (after dividing by batch size of 5)\n    >>> # is 5.0\n    >>> ops.sum(layer.losses)\n    5.25\n\n    ## Available penalties\n\n    ```python\n    L1(0.3)  # L1 Regularization Penalty\n    L2(0.1)  # L2 Regularization Penalty\n    L1L2(l1=0.01, l2=0.01)  # L1 + L2 penalties\n    ```\n\n    ## Directly calling a regularizer\n\n    Compute a regularization loss on a tensor by directly calling a regularizer\n    as if it is a one-argument function.\n\n    E.g.\n\n    >>> regularizer = L2(2.)\n    >>> tensor = ops.ones(shape=(5, 5))\n    >>> regularizer(tensor)\n    50.0\n\n    ## Developing new regularizers\n\n    Any function that takes in a weight matrix and returns a scalar\n    tensor can be used as a regularizer, e.g.:\n\n    >>> def l1_reg(weight_matrix):\n    ...    return 0.01 * ops.sum(ops.absolute(weight_matrix))\n    ...\n    >>> layer = Dense(5, input_dim=5,\n    ...     kernel_initializer='ones', kernel_regularizer=l1_reg)\n    >>> tensor = ops.ones(shape=(5, 5))\n    >>> out = layer(tensor)\n    >>> layer.losses\n    0.25\n\n    Alternatively, you can write your custom regularizers in an\n    object-oriented way by extending this regularizer base class, e.g.:\n\n    >>> class L2Regularizer(Regularizer):\n    ...   def __init__(self, l2=0.):\n    ...     self.l2 = l2\n    ...\n    ...   def __call__(self, x):\n    ...     return self.l2 * ops.sum(ops.square(x))\n    ...\n    ...   def get_config(self):\n    ...     return {'l2': float(self.l2)}\n    ...\n    >>> layer = Dense(\n    ...   5, input_dim=5, kernel_initializer='ones',\n    ...   
kernel_regularizer=L2Regularizer(l2=0.5))\n\n    >>> tensor = ops.ones(shape=(5, 5))\n    >>> out = layer(tensor)\n    >>> layer.losses\n    12.5\n\n    ### A note on serialization and deserialization:\n\n    Registering the regularizers as serializable is optional if you are just\n    training and executing models, exporting to and from SavedModels, or saving\n    and loading weight checkpoints.\n\n    Registration is required for saving and\n    loading models to HDF5 format, Keras model cloning, some visualization\n    utilities, and exporting models to and from JSON. If using this\n    functionality, you must make sure any Python process running your model has\n    also defined and registered your custom regularizer.\n    \"\"\"\n\n    def __call__(self, x):\n        \"\"\"Compute a regularization penalty from an input tensor.\"\"\"\n        return 0.0\n\n    @classmethod\n    def from_config(cls, config):\n        \"\"\"Creates a regularizer from its config.\n\n        This method is the reverse of `get_config`,\n        capable of instantiating the same regularizer from the config\n        dictionary.\n\n        This method is used by Keras `model_to_estimator`, saving and\n        loading models to HDF5 formats, Keras model cloning, some visualization\n        utilities, and exporting models to and from JSON.\n\n        Args:\n            config: A Python dictionary, typically the output of get_config.\n\n        Returns:\n            A regularizer instance.\n        \"\"\"\n        return cls(**config)\n\n    def get_config(self):\n        \"\"\"Returns the config of the regularizer.\n\n        A regularizer config is a Python dictionary (serializable)\n        containing all configuration parameters of the regularizer.\n        The same regularizer can be reinstantiated later\n        (without any saved state) from this configuration.\n\n        This method is optional if you are just training and executing models,\n        exporting to and from 
SavedModels, or using weight checkpoints.\n\n        This method is required for Keras `model_to_estimator`, saving and\n        loading models to HDF5 formats, Keras model cloning, some visualization\n        utilities, and exporting models to and from JSON.\n\n        Returns:\n            Python dictionary.\n        \"\"\"\n        raise NotImplementedError(f\"{self} does not implement get_config()\")\n\n\n@keras_export([\"keras.regularizers.L1L2\", \"keras.regularizers.l1_l2\"])\nclass L1L2(Regularizer):\n    \"\"\"A regularizer that applies both L1 and L2 regularization penalties.\n\n    The L1 regularization penalty is computed as:\n    `loss = l1 * reduce_sum(abs(x))`\n\n    The L2 regularization penalty is computed as:\n    `loss = l2 * reduce_sum(square(x))`\n\n    L1L2 may be passed to a layer as a string identifier:\n\n    >>> dense = Dense(3, kernel_regularizer='l1_l2')\n\n    In this case, the default values used are `l1=0.01` and `l2=0.01`.\n\n    Arguments:\n        l1: float, L1 regularization factor.\n        l2: float, L2 regularization factor.\n    \"\"\"\n\n    def __init__(self, l1=0.0, l2=0.0):\n        # The default values for l1 and l2 are different from those in l1_l2\n        # for backward compatibility reasons. 
E.g., L1L2(l2=0.1) will only have l2\n        # and no l1 penalty.\n        l1 = 0.0 if l1 is None else l1\n        l2 = 0.0 if l2 is None else l2\n        validate_float_arg(l1, name=\"l1\")\n        validate_float_arg(l2, name=\"l2\")\n\n        self.l1 = l1\n        self.l2 = l2\n\n    def __call__(self, x):\n        regularization = ops.convert_to_tensor(0.0, dtype=x.dtype)\n        if self.l1:\n            regularization += self.l1 * ops.sum(ops.absolute(x))\n        if self.l2:\n            regularization += self.l2 * ops.sum(ops.square(x))\n        return regularization\n\n    def get_config(self):\n        return {\"l1\": float(self.l1), \"l2\": float(self.l2)}\n\n\n@keras_export([\"keras.regularizers.L1\", \"keras.regularizers.l1\"])\nclass L1(Regularizer):\n    \"\"\"A regularizer that applies an L1 regularization penalty.\n\n    The L1 regularization penalty is computed as:\n    `loss = l1 * reduce_sum(abs(x))`\n\n    L1 may be passed to a layer as a string identifier:\n\n    >>> dense = Dense(3, kernel_regularizer='l1')\n\n    In this case, the default value used is `l1=0.01`.\n\n    Arguments:\n        l1: float, L1 regularization factor.\n    \"\"\"\n\n    def __init__(self, l1=0.01):\n        l1 = 0.01 if l1 is None else l1\n        validate_float_arg(l1, name=\"l1\")\n        self.l1 = ops.convert_to_tensor(l1)\n\n    def __call__(self, x):\n        return self.l1 * ops.sum(ops.absolute(x))\n\n    def get_config(self):\n        return {\"l1\": float(self.l1)}\n\n\n@keras_export([\"keras.regularizers.L2\", \"keras.regularizers.l2\"])\nclass L2(Regularizer):\n    \"\"\"A regularizer that applies an L2 regularization penalty.\n\n    The L2 regularization penalty is computed as:\n    `loss = l2 * reduce_sum(square(x))`\n\n    L2 may be passed to a layer as a string identifier:\n\n    >>> dense = Dense(3, kernel_regularizer='l2')\n\n    In this case, the default value used is `l2=0.01`.\n\n    Arguments:\n        l2: float, L2 regularization factor.\n    
\"\"\"\n\n    def __init__(self, l2=0.01):\n        l2 = 0.01 if l2 is None else l2\n        validate_float_arg(l2, name=\"l2\")\n        self.l2 = l2\n\n    def __call__(self, x):\n        return self.l2 * ops.sum(ops.square(x))\n\n    def get_config(self):\n        return {\"l2\": float(self.l2)}\n\n\n@keras_export(\n    [\n        \"keras.regularizers.OrthogonalRegularizer\",\n        \"keras.regularizers.orthogonal_regularizer\",\n    ]\n)\nclass OrthogonalRegularizer(Regularizer):\n    \"\"\"Regularizer that encourages input vectors to be orthogonal to each other.\n\n    It can be applied to either the rows of a matrix (`mode=\"rows\"`) or its\n    columns (`mode=\"columns\"`). When applied to a `Dense` kernel of shape\n    `(input_dim, units)`, rows mode will seek to make the feature vectors\n    (i.e. the basis of the output space) orthogonal to each other.\n\n    Arguments:\n        factor: Float. The regularization factor. The regularization penalty\n            will be proportional to `factor` times the mean of the dot products\n            between the L2-normalized rows (if `mode=\"rows\"`, or columns if\n            `mode=\"columns\"`) of the inputs, excluding the product of each\n            row/column with itself.  Defaults to `0.01`.\n        mode: String, one of `{\"rows\", \"columns\"}`. Defaults to `\"rows\"`. In\n            rows mode, the regularization effect seeks to make the rows of the\n            input orthogonal to each other. 
In columns mode, it seeks to make\n            the columns of the input orthogonal to each other.\n\n    Example:\n\n    >>> regularizer = OrthogonalRegularizer(factor=0.01)\n    >>> layer = Dense(units=4, kernel_regularizer=regularizer)\n    \"\"\"\n\n    def __init__(self, factor=0.01, mode=\"rows\"):\n        validate_float_arg(factor, name=\"factor\")\n        self.factor = ops.convert_to_tensor(factor)\n        if mode not in {\"rows\", \"columns\"}:\n            raise ValueError(\n                \"Invalid value for argument `mode`. Expected one of \"\n                f'{{\"rows\", \"columns\"}}. Received: mode={mode}'\n            )\n        self.mode = mode\n\n    def __call__(self, inputs):\n        if len(inputs.shape) != 2:\n            raise ValueError(\n                \"Inputs to OrthogonalRegularizer must have rank 2. Received: \"\n                f\"inputs.shape={inputs.shape}\"\n            )\n        if self.mode == \"rows\":\n            inputs = normalize(inputs, axis=1)\n            product = ops.matmul(inputs, ops.transpose(inputs))\n            size = inputs.shape[0]\n        else:\n            inputs = normalize(inputs, axis=0)\n            product = ops.matmul(ops.transpose(inputs), inputs)\n            size = inputs.shape[1]\n        product_no_diagonal = product * (\n            1.0 - ops.eye(size, dtype=inputs.dtype)\n        )\n        num_pairs = size * (size - 1.0) / 2.0\n        return (\n            self.factor\n            * 0.5\n            * ops.sum(ops.absolute(product_no_diagonal))\n            / num_pairs\n        )\n\n    def get_config(self):\n        return {\"factor\": float(self.factor), \"mode\": self.mode}\n\n\ndef validate_float_arg(value, name):\n    \"\"\"Validates that a penalty argument is a non-negative finite float.\"\"\"\n    if (\n        not isinstance(value, (float, int))\n        or (math.isinf(value) or math.isnan(value))\n        or value < 0\n    ):\n        raise ValueError(\n            f\"Invalid value for argument {name}: expected a non-negative float.\"\n            f\" Received: {name}={value}\"\n        )\n    return float(value)\n"
  },
  {
    "path": "keras/src/regularizers/regularizers_test.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import regularizers\nfrom keras.src import testing\nfrom keras.src.regularizers.regularizers import validate_float_arg\n\n\nclass RegularizersTest(testing.TestCase):\n    def test_config(self):\n        reg = regularizers.L1(0.1)\n        self.run_class_serialization_test(reg)\n\n        reg = regularizers.L2(0.1)\n        self.run_class_serialization_test(reg)\n\n        reg = regularizers.L1L2(l1=0.1, l2=0.2)\n        self.run_class_serialization_test(reg)\n\n        reg = regularizers.OrthogonalRegularizer(factor=0.1, mode=\"rows\")\n        self.run_class_serialization_test(reg)\n\n    def test_l1(self):\n        value = np.random.random((4, 4)).astype(np.float32)\n        x = backend.Variable(value)\n        y = regularizers.L1(0.1)(x)\n        self.assertAllClose(y, 0.1 * np.sum(np.abs(value)))\n\n    def test_l2(self):\n        value = np.random.random((4, 4)).astype(np.float32)\n        x = backend.Variable(value)\n        y = regularizers.L2(0.1)(x)\n        self.assertAllClose(y, 0.1 * np.sum(np.square(value)))\n\n    def test_l1_l2(self):\n        value = np.random.random((4, 4)).astype(np.float32)\n        x = backend.Variable(value)\n        y = regularizers.L1L2(l1=0.1, l2=0.2)(x)\n        self.assertAllClose(\n            y, 0.1 * np.sum(np.abs(value)) + 0.2 * np.sum(np.square(value))\n        )\n\n    def test_orthogonal_regularizer(self):\n        value = np.random.random((4, 4)).astype(np.float32)\n        x = backend.Variable(value)\n        y = regularizers.OrthogonalRegularizer(factor=0.1, mode=\"rows\")(x)\n\n        l2_norm = np.linalg.norm(value, axis=1, keepdims=True)\n        inputs = value / l2_norm\n        self.assertAllClose(\n            y,\n            0.1\n            * 0.5\n            * np.sum(\n                np.abs(np.dot(inputs, np.transpose(inputs)) * (1.0 - np.eye(4)))\n            )\n            / (4.0 * (4.0 - 1.0) / 2.0),\n            
atol=1e-4,\n            rtol=1e-4,\n        )\n\n    def test_get_method(self):\n        obj = regularizers.get(\"l1l2\")\n        self.assertIsInstance(obj, regularizers.L1L2)\n\n        obj = regularizers.get(\"l1\")\n        self.assertIsInstance(obj, regularizers.L1)\n\n        obj = regularizers.get(\"l2\")\n        self.assertIsInstance(obj, regularizers.L2)\n\n        obj = regularizers.get(\"orthogonal_regularizer\")\n        self.assertIsInstance(obj, regularizers.OrthogonalRegularizer)\n\n        obj = regularizers.get(None)\n        self.assertEqual(obj, None)\n\n        with self.assertRaises(ValueError):\n            regularizers.get(\"typo\")\n\n    def test_l1l2_get_config(self):\n        l1 = 0.01\n        l2 = 0.02\n        reg = regularizers.L1L2(l1=l1, l2=l2)\n        config = reg.get_config()\n\n        self.assertEqual(config, {\"l1\": l1, \"l2\": l2})\n\n        reg_from_config = regularizers.L1L2.from_config(config)\n        config_from_config = reg_from_config.get_config()\n\n        self.assertDictEqual(config, config_from_config)\n        self.assertEqual(reg_from_config.l1, l1)\n        self.assertEqual(reg_from_config.l2, l2)\n\n    def test_orthogonal_regularizer_mode_validation(self):\n        with self.assertRaises(ValueError) as context:\n            regularizers.OrthogonalRegularizer(factor=0.01, mode=\"invalid_mode\")\n\n        expected_message = (\n            'Invalid value for argument `mode`. Expected one of {\"rows\", '\n            '\"columns\"}. 
Received: mode=invalid_mode'\n        )\n        self.assertEqual(str(context.exception), expected_message)\n\n    def test_orthogonal_regularizer_input_rank_validation(self):\n        with self.assertRaises(ValueError) as context:\n            value = np.random.random((4, 4, 4)).astype(np.float32)\n            x = backend.Variable(value)\n            regularizers.OrthogonalRegularizer(factor=0.1)(x)\n\n        expected_message = (\n            \"Inputs to OrthogonalRegularizer must have rank 2. \"\n            f\"Received: inputs.shape={(4, 4, 4)}\"\n        )\n        self.assertEqual(str(context.exception), expected_message)\n\n    def test_orthogonal_regularizer_get_config(self):\n        factor = 0.01\n        mode = \"columns\"\n        regularizer = regularizers.OrthogonalRegularizer(\n            factor=factor, mode=mode\n        )\n        config = regularizer.get_config()\n\n        self.assertAlmostEqual(config[\"factor\"], factor, 7)\n        self.assertEqual(config[\"mode\"], mode)\n\n        reg_from_config = regularizers.OrthogonalRegularizer.from_config(config)\n        config_from_config = reg_from_config.get_config()\n\n        self.assertAlmostEqual(config_from_config[\"factor\"], factor, 7)\n        self.assertEqual(config_from_config[\"mode\"], mode)\n\n\nclass ValidateFloatArgTest(testing.TestCase):\n    def test_validate_float_with_valid_args(self):\n        self.assertEqual(validate_float_arg(1, \"test\"), 1.0)\n        self.assertEqual(validate_float_arg(1.0, \"test\"), 1.0)\n\n    def test_validate_float_with_invalid_types(self):\n        with self.assertRaisesRegex(\n            ValueError, \"expected a non-negative float\"\n        ):\n            validate_float_arg(\"not_a_number\", \"test\")\n\n    def test_validate_float_with_nan(self):\n        with self.assertRaisesRegex(\n            ValueError, \"expected a non-negative float\"\n        ):\n            validate_float_arg(float(\"nan\"), \"test\")\n\n    def 
test_validate_float_with_inf(self):\n        with self.assertRaisesRegex(\n            ValueError, \"expected a non-negative float\"\n        ):\n            validate_float_arg(float(\"inf\"), \"test\")\n        with self.assertRaisesRegex(\n            ValueError, \"expected a non-negative float\"\n        ):\n            validate_float_arg(-float(\"inf\"), \"test\")\n\n    def test_validate_float_with_negative_number(self):\n        with self.assertRaisesRegex(\n            ValueError, \"expected a non-negative float\"\n        ):\n            validate_float_arg(-1, \"test\")\n"
  },
  {
    "path": "keras/src/saving/__init__.py",
    "content": "from keras.src.saving.object_registration import CustomObjectScope\nfrom keras.src.saving.object_registration import custom_object_scope\nfrom keras.src.saving.object_registration import get_custom_objects\nfrom keras.src.saving.object_registration import get_registered_name\nfrom keras.src.saving.object_registration import get_registered_object\nfrom keras.src.saving.object_registration import register_keras_serializable\nfrom keras.src.saving.saving_api import load_model\nfrom keras.src.saving.serialization_lib import deserialize_keras_object\nfrom keras.src.saving.serialization_lib import serialize_keras_object\n"
  },
  {
    "path": "keras/src/saving/file_editor.py",
    "content": "import collections\nimport json\nimport os.path\nimport pprint\nimport zipfile\n\nimport h5py\nimport numpy as np\nimport rich.console\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.saving import saving_lib\nfrom keras.src.saving.saving_lib import H5IOStore\nfrom keras.src.utils import naming\nfrom keras.src.utils import summary_utils\n\ntry:\n    import IPython as ipython\nexcept ImportError:\n    ipython = None\n\n\ndef is_ipython_notebook():\n    \"\"\"Checks if the code is being executed in a notebook.\"\"\"\n    try:\n        from IPython import get_ipython\n\n        # Check if an active IPython shell exists.\n        if get_ipython() is not None:\n            return True\n        return False\n    except ImportError:\n        return False\n\n\n@keras_export(\"keras.saving.KerasFileEditor\")\nclass KerasFileEditor:\n    \"\"\"Utility to inspect, edit, and resave Keras weights files.\n\n    You will find this class useful when adapting\n    an old saved weights file after having made\n    architecture changes to a model.\n\n    Args:\n        filepath: The path to a local file to inspect and edit.\n\n    Examples:\n\n    ```python\n    editor = KerasFileEditor(\"my_model.weights.h5\")\n\n    # Displays current contents\n    editor.summary()\n\n    # Remove the weights of an existing layer\n    editor.delete_object(\"layers/dense_2\")\n\n    # Add the weights of a new layer\n    editor.add_object(\"layers/einsum_dense\", weights={\"0\": ..., \"1\": ...})\n\n    # Save the weights of the edited model\n    editor.resave_weights(\"edited_model.weights.h5\")\n    ```\n    \"\"\"\n\n    def __init__(\n        self,\n        filepath,\n    ):\n        self.filepath = filepath\n        self.metadata = None\n        self.config = None\n        self.model = None\n        self.console = rich.console.Console(highlight=False)\n\n        if filepath.endswith(\".keras\"):\n            zf = 
zipfile.ZipFile(filepath, \"r\")\n            weights_store = H5IOStore(\n                f\"{saving_lib._VARS_FNAME}.h5\",\n                archive=zf,\n                mode=\"r\",\n            )\n            with zf.open(saving_lib._CONFIG_FILENAME, \"r\") as f:\n                config_json = f.read()\n            with zf.open(saving_lib._METADATA_FILENAME, \"r\") as f:\n                metadata_json = f.read()\n            self.config = json.loads(config_json)\n            self.metadata = json.loads(metadata_json)\n\n        elif filepath.endswith(\".weights.h5\"):\n            weights_store = H5IOStore(filepath, mode=\"r\")\n        else:\n            raise ValueError(\n                \"Invalid filename: \"\n                \"expected a `.keras` `.weights.h5` extension. \"\n                f\"Received: filepath={filepath}\"\n            )\n\n        weights_dict, object_metadata = self._extract_weights_from_store(\n            weights_store.h5_file\n        )\n        weights_store.close()\n        self.weights_dict = weights_dict\n        self.object_metadata = object_metadata  # {path: object_name}\n        self.console.print(self._generate_filepath_info(rich_style=True))\n\n        if self.metadata is not None:\n            self.console.print(self._generate_metadata_info(rich_style=True))\n\n    def summary(self):\n        \"\"\"Prints the weight structure of the opened file.\"\"\"\n        self._weights_summary_cli()\n\n    def compare(self, reference_model):\n        \"\"\"Compares the opened file to a reference model.\n\n        This method will list all mismatches between the\n        currently opened file and the provided reference model.\n\n        Args:\n            reference_model: Model instance to compare to.\n\n        Returns:\n            Dict with the following keys:\n            `'status'`, `'error_count'`, `'match_count'`.\n            Status can be `'success'` or `'error'`.\n            `'error_count'` is the number of mismatches found.\n   
         `'match_count'` is the number of matching weights found.\n        \"\"\"\n        self.console.print(\"Running comparison\")\n        ref_spec = {}\n        get_weight_spec_of_saveable(reference_model, ref_spec)\n\n        def _compare(\n            target,\n            ref_spec,\n            inner_path,\n            target_name,\n            ref_name,\n            error_count,\n            match_count,\n            checked_paths,\n        ):\n            base_inner_path = inner_path\n            for ref_key, ref_val in ref_spec.items():\n                inner_path = f\"{base_inner_path}/{ref_key}\"\n                if inner_path in checked_paths:\n                    continue\n\n                if ref_key not in target:\n                    error_count += 1\n                    checked_paths.add(inner_path)\n                    if isinstance(ref_val, dict):\n                        self.console.print(\n                            f\"[color(160)]...Object [bold]{inner_path}[/] \"\n                            f\"present in {ref_name}, \"\n                            f\"missing from {target_name}[/]\"\n                        )\n                        self.console.print(\n                            f\"    In {ref_name}, {inner_path} contains \"\n                            f\"the following keys: {list(ref_val.keys())}\"\n                        )\n                    else:\n                        self.console.print(\n                            f\"[color(160)]...Weight [bold]{inner_path}[/] \"\n                            f\"present in {ref_name}, \"\n                            f\"missing from {target_name}[/]\"\n                        )\n                elif isinstance(ref_val, dict):\n                    _error_count, _match_count = _compare(\n                        target[ref_key],\n                        ref_spec[ref_key],\n                        inner_path,\n                        target_name,\n                        ref_name,\n                
        error_count=error_count,\n                        match_count=match_count,\n                        checked_paths=checked_paths,\n                    )\n                    error_count += _error_count\n                    match_count += _match_count\n                else:\n                    if target[ref_key].shape != ref_val.shape:\n                        error_count += 1\n                        checked_paths.add(inner_path)\n                        self.console.print(\n                            f\"[color(160)]...Weight shape mismatch \"\n                            f\"for [bold]{inner_path}[/][/]\\n\"\n                            f\"    In {ref_name}: \"\n                            f\"shape={ref_val.shape}\\n\"\n                            f\"    In {target_name}: \"\n                            f\"shape={target[ref_key].shape}\"\n                        )\n                    else:\n                        match_count += 1\n            return error_count, match_count\n\n        checked_paths = set()\n        error_count, match_count = _compare(\n            self.weights_dict,\n            ref_spec,\n            inner_path=\"\",\n            target_name=\"saved file\",\n            ref_name=\"reference model\",\n            error_count=0,\n            match_count=0,\n            checked_paths=checked_paths,\n        )\n        _error_count, _ = _compare(\n            ref_spec,\n            self.weights_dict,\n            inner_path=\"\",\n            target_name=\"reference model\",\n            ref_name=\"saved file\",\n            error_count=0,\n            match_count=0,\n            checked_paths=checked_paths,\n        )\n        error_count += _error_count\n        self.console.print(\"─────────────────────\")\n        if error_count == 0:\n            status = \"success\"\n            self.console.print(\n                \"[color(28)][bold]Comparison successful:[/] \"\n                \"saved file is compatible with the reference 
model[/]\"\n            )\n            if match_count == 1:\n                plural = \"\"\n            else:\n                plural = \"s\"\n            self.console.print(\n                f\"    Found {match_count} matching weight{plural}\"\n            )\n        else:\n            status = \"error\"\n            if error_count == 1:\n                plural = \"\"\n            else:\n                plural = \"s\"\n            self.console.print(\n                f\"[color(160)][bold]Found {error_count} error{plural}:[/] \"\n                \"saved file is not compatible with the reference model[/]\"\n            )\n        return {\n            \"status\": status,\n            \"error_count\": error_count,\n            \"match_count\": match_count,\n        }\n\n    def _edit_object(self, edit_fn, source_name, target_name=None):\n        if target_name is not None and \"/\" in target_name:\n            raise ValueError(\n                \"Argument `target_name` should be a leaf name, \"\n                \"not a full path name. 
\"\n                f\"Received: target_name='{target_name}'\"\n            )\n        if \"/\" in source_name:\n            # It's a path\n            elements = source_name.split(\"/\")\n            weights_dict = self.weights_dict\n            for e in elements[:-1]:\n                if e not in weights_dict:\n                    raise ValueError(\n                        f\"Path '{source_name}' not found in model.\"\n                    )\n                weights_dict = weights_dict[e]\n            if elements[-1] not in weights_dict:\n                raise ValueError(f\"Path '{source_name}' not found in model.\")\n            edit_fn(\n                weights_dict, source_name=elements[-1], target_name=target_name\n            )\n        else:\n            # Ensure unicity\n            def count_occurences(d, name, count=0):\n                for k in d:\n                    if isinstance(d[k], dict):\n                        count += count_occurences(d[k], name, count)\n                if name in d:\n                    count += 1\n                return count\n\n            occurrences = count_occurences(self.weights_dict, source_name)\n            if occurrences > 1:\n                raise ValueError(\n                    f\"Name '{source_name}' occurs more than once in the model; \"\n                    \"try passing a complete path\"\n                )\n            if occurrences == 0:\n                raise ValueError(\n                    f\"Source name '{source_name}' does not appear in the \"\n                    \"model. 
Use `editor.weights_summary()` \"\n                    \"to list all objects.\"\n                )\n\n            def _edit(d):\n                for k in d:\n                    if isinstance(d[k], dict):\n                        _edit(d[k])\n                if source_name in d:\n                    edit_fn(d, source_name=source_name, target_name=target_name)\n\n            _edit(self.weights_dict)\n\n    def rename_object(self, object_name, new_name):\n        \"\"\"Rename an object in the file (e.g. a layer).\n\n        Args:\n            object_name: String, name or path of the\n                object to rename (e.g. `\"dense_2\"` or\n                `\"layers/dense_2\"`).\n            new_name: String, new name of the object.\n        \"\"\"\n\n        def rename_fn(weights_dict, source_name, target_name):\n            weights_dict[target_name] = weights_dict[source_name]\n            weights_dict.pop(source_name)\n\n        self._edit_object(rename_fn, object_name, new_name)\n\n    def delete_object(self, object_name):\n        \"\"\"Removes an object from the file (e.g. a layer).\n\n        Args:\n            object_name: String, name or path of the\n                object to delete (e.g. `\"dense_2\"` or\n                `\"layers/dense_2\"`).\n        \"\"\"\n\n        def delete_fn(weights_dict, source_name, target_name=None):\n            weights_dict.pop(source_name)\n\n        self._edit_object(delete_fn, object_name)\n\n    def add_object(self, object_path, weights):\n        \"\"\"Add a new object to the file (e.g. a layer).\n\n        Args:\n            object_path: String, full path of the\n                object to add (e.g. `\"layers/dense_2\"`).\n            weights: Dict mapping weight names to weight\n                values (arrays),\n                e.g. 
`{\"0\": kernel_value, \"1\": bias_value}`.\n        \"\"\"\n        if not isinstance(weights, dict):\n            raise ValueError(\n                \"Argument `weights` should be a dict \"\n                \"where keys are weight names (usually '0', '1', etc.) \"\n                \"and values are NumPy arrays. \"\n                f\"Received: type(weights)={type(weights)}\"\n            )\n\n        if \"/\" in object_path:\n            # It's a path\n            elements = object_path.split(\"/\")\n            partial_path = \"/\".join(elements[:-1])\n            weights_dict = self.weights_dict\n            for e in elements[:-1]:\n                if e not in weights_dict:\n                    raise ValueError(\n                        f\"Path '{partial_path}' not found in model.\"\n                    )\n                weights_dict = weights_dict[e]\n            weights_dict[elements[-1]] = weights\n        else:\n            self.weights_dict[object_path] = weights\n\n    def delete_weight(self, object_name, weight_name):\n        \"\"\"Removes a weight from an existing object.\n\n        Args:\n            object_name: String, name or path of the\n                object from which to remove the weight\n                (e.g. `\"dense_2\"` or `\"layers/dense_2\"`).\n            weight_name: String, name of the weight to\n                delete (e.g. `\"0\"`).\n        \"\"\"\n\n        def delete_weight_fn(weights_dict, source_name, target_name=None):\n            if weight_name not in weights_dict[source_name]:\n                raise ValueError(\n                    f\"Weight {weight_name} not found \"\n                    f\"in object {object_name}. 
\"\n                    \"Weights found: \"\n                    f\"{list(weights_dict[source_name].keys())}\"\n                )\n            weights_dict[source_name].pop(weight_name)\n\n        self._edit_object(delete_weight_fn, object_name)\n\n    def add_weights(self, object_name, weights):\n        \"\"\"Add one or more new weights to an existing object.\n\n        Args:\n            object_name: String, name or path of the\n                object to add the weights to\n                (e.g. `\"dense_2\"` or `\"layers/dense_2\"`).\n            weights: Dict mapping weight names to weight\n                values (arrays),\n                e.g. `{\"0\": kernel_value, \"1\": bias_value}`.\n        \"\"\"\n        if not isinstance(weights, dict):\n            raise ValueError(\n                \"Argument `weights` should be a dict \"\n                \"where keys are weight names (usually '0', '1', etc.) \"\n                \"and values are NumPy arrays. \"\n                f\"Received: type(weights)={type(weights)}\"\n            )\n\n        def add_weight_fn(weights_dict, source_name, target_name=None):\n            weights_dict[source_name].update(weights)\n\n        self._edit_object(add_weight_fn, object_name)\n\n    def save(self, filepath):\n        \"\"\"Save the edited weights file.\n\n        Args:\n            filepath: Path to save the file to.\n                Must be a `.weights.h5` file.\n        \"\"\"\n        filepath = str(filepath)\n        if not filepath.endswith(\".weights.h5\"):\n            raise ValueError(\n                \"Invalid `filepath` argument: \"\n                \"expected a `.weights.h5` extension. 
\"\n                f\"Received: filepath={filepath}\"\n            )\n        weights_store = H5IOStore(filepath, mode=\"w\")\n\n        def _save(weights_dict, weights_store, inner_path):\n            vars_to_create = {}\n            for name, value in weights_dict.items():\n                if isinstance(value, dict):\n                    if value:\n                        _save(\n                            weights_dict[name],\n                            weights_store,\n                            inner_path=os.path.join(inner_path, name),\n                        )\n                else:\n                    # e.g. name=\"0\", value=HDF5Dataset\n                    vars_to_create[name] = value\n            if vars_to_create:\n                var_store = weights_store.make(inner_path)\n                for name, value in vars_to_create.items():\n                    var_store[name] = value\n\n        _save(self.weights_dict, weights_store, inner_path=\"\")\n        weights_store.close()\n\n    def resave_weights(self, filepath):\n        self.save(filepath)\n\n    def _extract_weights_from_store(self, data, metadata=None, inner_path=\"\"):\n        metadata = metadata or {}\n\n        # ------------------------------------------------------\n        # Collect metadata for this HDF5 group\n        # ------------------------------------------------------\n        object_metadata = {}\n        for k, v in data.attrs.items():\n            object_metadata[k] = v\n        if object_metadata:\n            metadata[inner_path] = object_metadata\n\n        result = collections.OrderedDict()\n\n        # ------------------------------------------------------\n        # Iterate over all keys in this HDF5 group\n        # ------------------------------------------------------\n        for key in data.keys():\n            # IMPORTANT:\n            # Never mutate inner_path; use local variable.\n            current_inner_path = f\"{inner_path}/{key}\"\n            value = 
data[key]\n\n            # ------------------------------------------------------\n            # CASE 1 — HDF5 GROUP → RECURSE\n            # ------------------------------------------------------\n            if isinstance(value, h5py.Group):\n                # Skip empty groups\n                if len(value) == 0:\n                    continue\n\n                # Skip empty \"vars\" groups\n                if \"vars\" in value.keys() and len(value[\"vars\"]) == 0:\n                    continue\n\n                # Recurse into \"vars\" subgroup when present\n                if \"vars\" in value.keys():\n                    result[key], metadata = self._extract_weights_from_store(\n                        value[\"vars\"],\n                        metadata=metadata,\n                        inner_path=current_inner_path,\n                    )\n                else:\n                    # Recurse normally\n                    result[key], metadata = self._extract_weights_from_store(\n                        value,\n                        metadata=metadata,\n                        inner_path=current_inner_path,\n                    )\n\n                continue  # finished processing this key\n\n            # ------------------------------------------------------\n            # CASE 2 — HDF5 DATASET → SAFE LOADING\n            # ------------------------------------------------------\n\n            # Skip any objects that are not proper datasets\n            if not isinstance(value, h5py.Dataset):\n                continue\n\n            if value.external:\n                raise ValueError(\n                    \"Not allowed: H5 file Dataset with external links: \"\n                    f\"{value.external}\"\n                )\n\n            shape = value.shape\n            dtype = value.dtype\n\n            # ------------------------------------------------------\n            # Validate SHAPE (avoid malformed / malicious metadata)\n            # 
------------------------------------------------------\n\n            # No negative dimensions\n            if any(dim < 0 for dim in shape):\n                raise ValueError(\n                    \"Malformed HDF5 dataset shape encountered in .keras file; \"\n                    \"negative dimension detected.\"\n                )\n\n            # Prevent absurdly high-rank tensors\n            if len(shape) > 64:\n                raise ValueError(\n                    \"Malformed HDF5 dataset shape encountered in .keras file; \"\n                    \"tensor rank exceeds safety limit.\"\n                )\n\n            # Safe product computation (Python int is unbounded)\n            num_elems = int(np.prod(shape))\n\n            # ------------------------------------------------------\n            # Validate TOTAL memory size\n            # ------------------------------------------------------\n            MAX_BYTES = 1 << 32  # 4 GiB\n\n            size_bytes = num_elems * dtype.itemsize\n\n            if size_bytes > MAX_BYTES:\n                raise ValueError(\n                    f\"HDF5 dataset too large to load safely \"\n                    f\"({size_bytes} bytes; limit is {MAX_BYTES}).\"\n                )\n\n            # ------------------------------------------------------\n            # SAFE — load dataset (guaranteed ≤ 4 GiB)\n            # ------------------------------------------------------\n            result[key] = value[()]\n\n        # ------------------------------------------------------\n        # Return final tree and metadata\n        # ------------------------------------------------------\n        return result, metadata\n\n    def _generate_filepath_info(self, rich_style=False):\n        if rich_style:\n            filepath = f\"'{self.filepath}'\"\n            filepath = f\"{summary_utils.highlight_symbol(filepath)}\"\n        else:\n            filepath = f\"'{self.filepath}'\"\n        return f\"Keras model file 
{filepath}\"\n\n    def _generate_config_info(self, rich_style=False):\n        return pprint.pformat(self.config)\n\n    def _generate_metadata_info(self, rich_style=False):\n        version = self.metadata[\"keras_version\"]\n        date = self.metadata[\"date_saved\"]\n        if rich_style:\n            version = f\"{summary_utils.highlight_symbol(version)}\"\n            date = f\"{summary_utils.highlight_symbol(date)}\"\n        return f\"Saved with Keras {version} - date: {date}\"\n\n    def _print_weights_structure(\n        self, weights_dict, indent=0, is_first=True, prefix=\"\", inner_path=\"\"\n    ):\n        for idx, (key, value) in enumerate(weights_dict.items()):\n            # Build the child path without mutating `inner_path`, which would\n            # leak one sibling's path into the next iteration.\n            child_path = os.path.join(inner_path, key)\n            is_last = idx == len(weights_dict) - 1\n            if is_first:\n                is_first = False\n                connector = \"> \"\n            elif is_last:\n                connector = \"└─ \"\n            else:\n                connector = \"├─ \"\n\n            if isinstance(value, dict):\n                bold_key = summary_utils.bold_text(key)\n                object_label = f\"{prefix}{connector}{bold_key}\"\n                if child_path in self.object_metadata:\n                    metadata = self.object_metadata[child_path]\n                    if \"name\" in metadata:\n                        name = metadata[\"name\"]\n                        object_label += f\" ('{name}')\"\n                self.console.print(object_label)\n                if is_last:\n                    appended = \"    \"\n                else:\n                    appended = \"│   \"\n                new_prefix = prefix + appended\n                self._print_weights_structure(\n                    value,\n                    indent + 1,\n                    is_first=is_first,\n                    prefix=new_prefix,\n                    inner_path=child_path,\n                )\n            else:\n                if hasattr(value, 
\"shape\"):\n                    bold_key = summary_utils.bold_text(key)\n                    self.console.print(\n                        f\"{prefix}{connector}{bold_key}:\"\n                        + f\" shape={value.shape}, dtype={value.dtype}\"\n                    )\n                else:\n                    self.console.print(f\"{prefix}{connector}{key}: {value}\")\n\n    def _weights_summary_cli(self):\n        self.console.print(\"Weights structure\")\n        self._print_weights_structure(self.weights_dict, prefix=\" \" * 2)\n\n    def _weights_summary_interactive(self):\n        def _generate_html_weights(dictionary, margin_left=0, font_size=1):\n            html = \"\"\n            for key, value in dictionary.items():\n                if isinstance(value, dict) and value:\n                    weights_html = _generate_html_weights(\n                        value, margin_left + 20, font_size - 1\n                    )\n                    html += (\n                        f'<details style=\"margin-left: {margin_left}px;\">'\n                        '<summary style=\"'\n                        f\"font-size: {font_size}em; \"\n                        \"font-weight: bold;\"\n                        f'\">{key}</summary>'\n                        f\"{weights_html}\"\n                        \"</details>\"\n                    )\n                else:\n                    html += (\n                        f'<details style=\"margin-left: {margin_left}px;\">'\n                        f'<summary style=\"font-size: {font_size}em;\">'\n                        f\"{key} : shape={value.shape}\"\n                        f\", dtype={value.dtype}</summary>\"\n                        f'<div style=\"'\n                        f'margin-left: {margin_left}px; '\n                        f'margin-top: {margin_left}px;\">'\n                        f\"{display_weight(value)}\"\n                        \"</div>\"\n                        \"</details>\"\n                    
)\n            return html\n\n        output = \"Weights structure\"\n\n        initialize_id_counter()\n        output += _generate_html_weights(self.weights_dict)\n        ipython.display.display(ipython.display.HTML(output))\n\n\ndef get_weight_spec_of_saveable(saveable, spec, visited_saveables=None):\n    from keras.src.saving.keras_saveable import KerasSaveable\n\n    visited_saveables = visited_saveables or set()\n\n    # If the saveable has already been saved, skip it.\n    if id(saveable) in visited_saveables:\n        return\n\n    if hasattr(saveable, \"save_own_variables\"):\n        store = {}\n        saveable.save_own_variables(store)\n        if store:\n            keys = sorted(store.keys())\n            for k in keys:\n                val = store[k]\n                spec[k] = backend.KerasTensor(shape=val.shape, dtype=val.dtype)\n\n    visited_saveables.add(id(saveable))\n\n    for child_attr, child_obj in saving_lib._walk_saveable(saveable):\n        if isinstance(child_obj, KerasSaveable):\n            sub_spec = {}\n            get_weight_spec_of_saveable(\n                child_obj,\n                sub_spec,\n                visited_saveables=visited_saveables,\n            )\n            if sub_spec:\n                spec[child_attr] = sub_spec\n        elif isinstance(child_obj, (list, dict, tuple, set)):\n            sub_spec = {}\n            get_weight_spec_of_container(\n                child_obj,\n                sub_spec,\n                visited_saveables=visited_saveables,\n            )\n            if sub_spec:\n                spec[child_attr] = sub_spec\n\n\ndef get_weight_spec_of_container(container, spec, visited_saveables):\n    from keras.src.saving.keras_saveable import KerasSaveable\n\n    used_names = {}\n    if isinstance(container, dict):\n        container = list(container.values())\n\n    for saveable in container:\n        if isinstance(saveable, KerasSaveable):\n            name = 
naming.to_snake_case(saveable.__class__.__name__)\n            if name in used_names:\n                used_names[name] += 1\n                name = f\"{name}_{used_names[name]}\"\n            else:\n                used_names[name] = 0\n            sub_spec = {}\n            get_weight_spec_of_saveable(\n                saveable,\n                sub_spec,\n                visited_saveables=visited_saveables,\n            )\n            if sub_spec:\n                spec[name] = sub_spec\n\n\ndef initialize_id_counter():\n    global div_id_counter\n    div_id_counter = 0\n\n\ndef increment_id_counter():\n    global div_id_counter\n    div_id_counter += 1\n\n\ndef get_id_counter():\n    return div_id_counter\n\n\ndef display_weight(weight, axis=-1, threshold=16):\n    def _find_factors_closest_to_sqrt(num):\n        sqrt_num = int(np.sqrt(num))\n\n        for i in range(sqrt_num, 0, -1):\n            if num % i == 0:\n                M = i\n                N = num // i\n\n                if M > N:\n                    return N, M\n                return M, N\n\n    def _color_from_rbg(value):\n        return f\"rgba({value[0]}, {value[1]}, {value[2]}, 1)\"\n\n    def _reduce_3d_array_by_mean(arr, n, axis):\n        if axis == 2:\n            trimmed_arr = arr[:, :, : arr.shape[2] - (arr.shape[2] % n)]\n            reshaped = np.reshape(\n                trimmed_arr, (arr.shape[0], arr.shape[1], -1, n)\n            )\n            mean_values = np.mean(reshaped, axis=3)\n\n        elif axis == 1:\n            trimmed_arr = arr[:, : arr.shape[1] - (arr.shape[1] % n), :]\n            reshaped = np.reshape(\n                trimmed_arr, (arr.shape[0], -1, n, arr.shape[2])\n            )\n            mean_values = np.mean(reshaped, axis=2)\n\n        elif axis == 0:\n            trimmed_arr = arr[: arr.shape[0] - (arr.shape[0] % n), :, :]\n            reshaped = np.reshape(\n                trimmed_arr, (-1, n, arr.shape[1], arr.shape[2])\n            )\n            
mean_values = np.mean(reshaped, axis=1)\n\n        else:\n            raise ValueError(\"Axis must be 0, 1, or 2.\")\n\n        return mean_values\n\n    def _create_matrix_html(matrix, subplot_size=840):\n        rows, cols, num_slices = matrix.shape\n\n        M, N = _find_factors_closest_to_sqrt(num_slices)\n\n        try:\n            from matplotlib import cm\n        except ImportError:\n            cm = None\n        if cm:\n            rgb_matrix = cm.jet(matrix)\n        else:\n            rgb_matrix = (matrix - np.min(matrix)) / (\n                np.max(matrix) - np.min(matrix)\n            )\n            rgb_matrix = np.stack([rgb_matrix, rgb_matrix, rgb_matrix], axis=-1)\n        rgb_matrix = (rgb_matrix[..., :3] * 255).astype(\"uint8\")\n\n        subplot_html = \"\"\n        for i in range(num_slices):\n            cell_html = \"\"\n            for row in rgb_matrix[..., i, :]:\n                for rgb in row:\n                    color = _color_from_rbg(rgb)\n                    cell_html += (\n                        f'<div class=\"cell\" '\n                        f'style=\"background-color: {color};\">'\n                        f\"</div>\"\n                    )\n            subplot_html += f\"\"\"\n                        <div class=\"matrix\">\n                          {cell_html}\n                        </div>\n                        \"\"\"\n\n        cell_size = subplot_size // (N * cols)\n\n        increment_id_counter()\n        div_id = get_id_counter()\n\n        html_code = f\"\"\"\n            <div class=\"unique-container_{div_id}\">\n                  <style>\n                      .unique-container_{div_id} .subplots {{\n                      display: inline-grid;\n                      grid-template-columns: repeat({N}, 1fr);\n                      column-gap: 5px;  /* Minimal horizontal gap */\n                      row-gap: 5px;     /* Small vertical gap */\n                      margin: 0;\n                      padding: 0;\n  
                  }}\n                    .unique-container_{div_id} .matrix {{\n                      display: inline-grid;\n                      grid-template-columns: repeat({cols}, {cell_size}px);\n                      grid-template-rows: repeat({rows}, {cell_size}px);\n                      gap: 1px;\n                      margin: 0;\n                      padding: 0;\n                    }}\n                    .unique-container_{div_id} .cell {{\n                      width: {cell_size}px;\n                      height: {cell_size}px;\n                      display: flex;\n                      justify-content: center;\n                      align-items: center;\n                      font-size: 5px;\n                      font-weight: bold;\n                      color: white;\n                    }}\n                     .unique-container_{div_id} {{\n                      margin-top: 20px;\n                      margin-bottom: 20px;\n                    }}\n                  </style>\n                  <div class=\"subplots\">\n                    {subplot_html}\n                  </div>\n                  </div>\n                \"\"\"\n\n        return html_code\n\n    if weight.ndim == 1:\n        weight = weight[..., np.newaxis]\n\n    weight = np.swapaxes(weight, axis, -1)\n    weight = weight.reshape(-1, weight.shape[-1])\n\n    M, N = _find_factors_closest_to_sqrt(weight.shape[0])\n    weight = weight.reshape(M, N, weight.shape[-1])\n\n    for reduce_axis in [0, 1, 2]:\n        if weight.shape[reduce_axis] > threshold:\n            weight = _reduce_3d_array_by_mean(\n                weight,\n                weight.shape[reduce_axis] // threshold,\n                axis=reduce_axis,\n            )\n\n    weight = (weight - weight.min()) / (weight.max() - weight.min() + 1e-5)\n\n    html_code = _create_matrix_html(weight)\n    return html_code\n"
  },
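The dataset-loading guard in `_extract_weights_from_store` above (reject negative dimensions, cap the rank, cap the total byte size) can be sketched as a standalone check. `validate_dataset_geometry`, `MAX_RANK`, and `MAX_BYTES` are illustrative names for this sketch, not part of the Keras API:

```python
import math

MAX_RANK = 64        # same rank cap as the editor's safety check
MAX_BYTES = 1 << 32  # 4 GiB, same byte cap


def validate_dataset_geometry(shape, itemsize):
    """Reject malformed or oversized dataset geometry before loading."""
    if any(dim < 0 for dim in shape):
        raise ValueError("negative dimension detected")
    if len(shape) > MAX_RANK:
        raise ValueError("tensor rank exceeds safety limit")
    # math.prod over Python ints cannot overflow, unlike a fixed-width
    # accumulator, so a huge adversarial shape cannot wrap past the limit.
    size_bytes = math.prod(shape) * itemsize
    if size_bytes > MAX_BYTES:
        raise ValueError(
            f"dataset too large ({size_bytes} bytes; limit {MAX_BYTES})"
        )
    return size_bytes
```

A float32 tensor of shape `(1024, 1024)` passes (4 MiB), while a shape such as `(1 << 20, 1 << 20)` is rejected before any allocation happens.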
  {
    "path": "keras/src/saving/file_editor_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\n\nimport keras\nfrom keras.src import testing\nfrom keras.src.saving.file_editor import KerasFileEditor\n\n\ndef get_source_model():\n    inputs = keras.Input((2,))\n    x = keras.layers.Dense(3, name=\"mydense\")(inputs)\n    outputs = keras.layers.Dense(3, name=\"output_layer\")(x)\n    model = keras.Model(inputs, outputs)\n    return model\n\n\ndef get_target_model():\n    inputs = keras.Input((2,))\n    x = keras.layers.Dense(3, name=\"mydense\")(inputs)\n    x = keras.layers.Dense(3, name=\"myotherdense\")(x)\n    outputs = keras.layers.Dense(3, name=\"output_layer\")(x)\n    model = keras.Model(inputs, outputs)\n    return model\n\n\nclass SavingTest(testing.TestCase):\n    def test_basics(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"my_model.keras\")\n\n        model = get_source_model()\n        model.save(temp_filepath)\n\n        editor = KerasFileEditor(temp_filepath)\n        editor.summary()\n\n        target_model = get_target_model()\n\n        out = editor.compare(model)  # Succeeds\n        self.assertEqual(out[\"status\"], \"success\")\n        out = editor.compare(target_model)  # Fails\n\n        editor.add_object(\n            \"layers/dense_3\", weights={\"0\": np.random.random((3, 3))}\n        )\n        out = editor.compare(target_model)  # Fails\n        self.assertEqual(out[\"status\"], \"error\")\n        self.assertEqual(out[\"error_count\"], 2)\n\n        editor.rename_object(\"dense_3\", \"dense_4\")\n        editor.rename_object(\"layers/dense_4\", \"dense_2\")\n        editor.add_weights(\"dense_2\", weights={\"1\": np.random.random((3,))})\n        out = editor.compare(target_model)  # Succeeds\n        self.assertEqual(out[\"status\"], \"success\")\n\n        editor.add_object(\n            \"layers/dense_3\", weights={\"0\": np.random.random((3, 3))}\n        )\n        out = editor.compare(target_model)  # Fails\n        
self.assertEqual(out[\"status\"], \"error\")\n        self.assertEqual(out[\"error_count\"], 1)\n\n        editor.delete_object(\"layers/dense_3\")\n        out = editor.compare(target_model)  # Succeeds\n        self.assertEqual(out[\"status\"], \"success\")\n        editor.summary()\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"resaved.weights.h5\")\n        editor.save(temp_filepath)\n        target_model.load_weights(temp_filepath)\n\n        editor = KerasFileEditor(temp_filepath)\n        editor.summary()\n        out = editor.compare(target_model)  # Succeeds\n        self.assertEqual(out[\"status\"], \"success\")\n\n        editor.delete_weight(\"dense_2\", \"1\")\n        out = editor.compare(target_model)  # Fails\n        self.assertEqual(out[\"status\"], \"error\")\n        self.assertEqual(out[\"error_count\"], 1)\n\n        editor.add_weights(\"dense_2\", {\"1\": np.zeros((7,))})\n        out = editor.compare(target_model)  # Fails\n        self.assertEqual(out[\"status\"], \"error\")\n        self.assertEqual(out[\"error_count\"], 1)\n\n        editor.delete_weight(\"dense_2\", \"1\")\n        editor.add_weights(\"dense_2\", {\"1\": np.zeros((3,))})\n        out = editor.compare(target_model)  # Succeeds\n        self.assertEqual(out[\"status\"], \"success\")\n\n    @pytest.mark.requires_trainable_backend\n    def test_scalar_weight(self):\n        model = keras.Sequential(name=\"my_sequential\")\n        model.add(keras.Input(shape=(1,), name=\"my_input\"))\n        model.add(keras.layers.Dense(1, activation=\"sigmoid\", name=\"my_dense\"))\n        model.compile(optimizer=\"adam\", loss=\"mse\", metrics=[\"mae\"])\n        model.fit(np.array([[1]]), np.array([[1]]), verbose=0)\n        model_fpath = os.path.join(self.get_temp_dir(), \"model.keras\")\n        weights_fpath = os.path.join(self.get_temp_dir(), \"model.weights.h5\")\n        model.save(model_fpath)\n        model.save_weights(weights_fpath)\n\n        model_editor = 
KerasFileEditor(model_fpath)\n        self.assertEqual(\n            len(keras.src.tree.flatten(model_editor.weights_dict)), 8\n        )\n        model_weights_editor = KerasFileEditor(weights_fpath)\n        self.assertEqual(\n            len(keras.src.tree.flatten(model_weights_editor.weights_dict)), 8\n        )\n"
  },
  {
    "path": "keras/src/saving/keras_saveable.py",
    "content": "import io\n\n\nclass KerasSaveable:\n    # Note: renaming this function will cause old pickles to be broken.\n    # This is probably not a huge deal, as pickle should not be a recommended\n    # saving format -- it should only be supported for use with distributed\n    # computing frameworks.\n\n    def _obj_type(self):\n        raise NotImplementedError(\n            \"KerasSaveable subclasses must provide an \"\n            \"implementation for `_obj_type()`\"\n        )\n\n    @classmethod\n    def _unpickle_model(cls, bytesio):\n        import keras.src.saving.saving_lib as saving_lib\n\n        # pickle is not safe regardless of what you do.\n        return saving_lib._load_model_from_fileobj(\n            bytesio, custom_objects=None, compile=True, safe_mode=False\n        )\n\n    def __reduce__(self):\n        \"\"\"__reduce__ is used to customize the behavior of `pickle.dumps()`.\n\n        The method returns a tuple of two elements: a function, and a tuple of\n        arguments to pass to that function.  In this case we just leverage the\n        Keras saving library.\"\"\"\n        import keras.src.saving.saving_lib as saving_lib\n\n        buf = io.BytesIO()\n        saving_lib._save_model_to_fileobj(self, buf, \"h5\")\n        return (\n            self._unpickle_model,\n            (buf,),\n        )\n"
  },
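A minimal sketch of the `__reduce__` pattern `KerasSaveable` uses above: serialize the object through your own format at pickling time, then hand pickle a `(loader, args)` pair that rebuilds the object at load time. `Config` and `_from_bytes` are hypothetical names; the real class serializes the full model archive via `saving_lib` and passes the `BytesIO` itself:

```python
import io
import json
import pickle


class Config:
    """Toy saveable whose pickled form is its own JSON serialization."""

    def __init__(self, values):
        self.values = values

    @classmethod
    def _from_bytes(cls, data):
        # Loader invoked by pickle at load time (mirrors `_unpickle_model`).
        return cls(json.load(io.BytesIO(data)))

    def __reduce__(self):
        # Serialize through our own format, then let pickle store only
        # the loader callable plus its arguments.
        buf = io.BytesIO()
        buf.write(json.dumps(self.values).encode("utf-8"))
        return (self._from_bytes, (buf.getvalue(),))
```

Round-tripping through `pickle.dumps`/`pickle.loads` therefore exercises the custom format rather than pickle's default per-attribute traversal.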
  {
    "path": "keras/src/saving/object_registration.py",
    "content": "import inspect\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\n\nGLOBAL_CUSTOM_OBJECTS = {}\nGLOBAL_CUSTOM_NAMES = {}\n\n\n@keras_export(\n    [\n        \"keras.saving.CustomObjectScope\",\n        \"keras.saving.custom_object_scope\",\n        \"keras.utils.CustomObjectScope\",\n        \"keras.utils.custom_object_scope\",\n    ]\n)\nclass CustomObjectScope:\n    \"\"\"Exposes custom classes/functions to Keras deserialization internals.\n\n    Under a scope `with custom_object_scope(objects_dict)`, Keras methods such\n    as `keras.models.load_model()` or\n    `keras.models.model_from_config()` will be able to deserialize any\n    custom object referenced by a saved config (e.g. a custom layer or metric).\n\n    Example:\n\n    Consider a custom regularizer `my_regularizer`:\n\n    ```python\n    layer = Dense(3, kernel_regularizer=my_regularizer)\n    # Config contains a reference to `my_regularizer`\n    config = layer.get_config()\n    ...\n    # Later:\n    with custom_object_scope({'my_regularizer': my_regularizer}):\n        layer = Dense.from_config(config)\n    ```\n\n    Args:\n        custom_objects: Dictionary of `{str: object}` pairs,\n            where the `str` key is the object name.\n    \"\"\"\n\n    def __init__(self, custom_objects):\n        self.custom_objects = custom_objects or {}\n        self.backup = None\n\n    def __enter__(self):\n        self.backup = global_state.get_global_attribute(\n            \"custom_objects_scope_dict\", {}\n        ).copy()\n        global_state.set_global_attribute(\n            \"custom_objects_scope_dict\", self.custom_objects.copy()\n        )\n        return self\n\n    def __exit__(self, *args, **kwargs):\n        global_state.set_global_attribute(\n            \"custom_objects_scope_dict\", self.backup.copy()\n        )\n\n\n# Alias.\ncustom_object_scope = CustomObjectScope\n\n\n@keras_export(\n    [\n        
\"keras.saving.get_custom_objects\",\n        \"keras.utils.get_custom_objects\",\n    ]\n)\ndef get_custom_objects():\n    \"\"\"Retrieves a live reference to the global dictionary of custom objects.\n\n    Custom objects set using `custom_object_scope()` are not added to the\n    global dictionary of custom objects, and will not appear in the returned\n    dictionary.\n\n    Example:\n\n    ```python\n    get_custom_objects().clear()\n    get_custom_objects()['MyObject'] = MyObject\n    ```\n\n    Returns:\n        Global dictionary mapping registered class names to classes.\n    \"\"\"\n    return GLOBAL_CUSTOM_OBJECTS\n\n\n@keras_export(\n    [\n        \"keras.saving.register_keras_serializable\",\n        \"keras.utils.register_keras_serializable\",\n    ]\n)\ndef register_keras_serializable(package=\"Custom\", name=None):\n    \"\"\"Registers an object with the Keras serialization framework.\n\n    This decorator injects the decorated class or function into the Keras custom\n    object dictionary, so that it can be serialized and deserialized without\n    needing an entry in the user-provided custom object dict. It also injects a\n    function that Keras will call to get the object's serializable string key.\n\n    Note that to be serialized and deserialized, classes must implement the\n    `get_config()` method. 
Functions do not have this requirement.\n\n    The object will be registered under the key `'package>name'` where `name`,\n    defaults to the object name if not passed.\n\n    Example:\n\n    ```python\n    # Note that `'my_package'` is used as the `package` argument here, and since\n    # the `name` argument is not provided, `'MyDense'` is used as the `name`.\n    @register_keras_serializable('my_package')\n    class MyDense(keras.layers.Dense):\n        pass\n\n    assert get_registered_object('my_package>MyDense') == MyDense\n    assert get_registered_name(MyDense) == 'my_package>MyDense'\n    ```\n\n    Args:\n        package: The package that this class belongs to. This is used for the\n            `key` (which is `\"package>name\"`) to identify the class. Note that\n            this is the first argument passed into the decorator.\n        name: The name to serialize this class under in this package. If not\n            provided or `None`, the class' name will be used (note that this is\n            the case when the decorator is used with only one argument, which\n            becomes the `package`).\n\n    Returns:\n        A decorator that registers the decorated class with the passed names.\n    \"\"\"\n\n    def decorator(arg):\n        \"\"\"Registers a class with the Keras serialization framework.\"\"\"\n        class_name = name if name is not None else arg.__name__\n        registered_name = f\"{package}>{class_name}\"\n\n        if inspect.isclass(arg) and not hasattr(arg, \"get_config\"):\n            raise ValueError(\n                \"Cannot register a class that does not have a \"\n                \"get_config() method.\"\n            )\n\n        GLOBAL_CUSTOM_OBJECTS[registered_name] = arg\n        GLOBAL_CUSTOM_NAMES[arg] = registered_name\n\n        return arg\n\n    return decorator\n\n\n@keras_export(\n    [\n        \"keras.saving.get_registered_name\",\n        \"keras.utils.get_registered_name\",\n    ]\n)\ndef 
get_registered_name(obj):\n    \"\"\"Returns the name registered to an object within the Keras framework.\n\n    This function is part of the Keras serialization and deserialization\n    framework. It maps objects to the string names associated with those objects\n    for serialization/deserialization.\n\n    Args:\n        obj: The object to look up.\n\n    Returns:\n        The name associated with the object, or the default Python name if the\n            object is not registered.\n    \"\"\"\n    if obj in GLOBAL_CUSTOM_NAMES:\n        return GLOBAL_CUSTOM_NAMES[obj]\n    else:\n        return obj.__name__\n\n\n@keras_export(\n    [\n        \"keras.saving.get_registered_object\",\n        \"keras.utils.get_registered_object\",\n    ]\n)\ndef get_registered_object(name, custom_objects=None, module_objects=None):\n    \"\"\"Returns the class associated with `name` if it is registered with Keras.\n\n    This function is part of the Keras serialization and deserialization\n    framework. 
It maps strings to the objects associated with them for\n    serialization/deserialization.\n\n    Example:\n\n    ```python\n    def from_config(cls, config, custom_objects=None):\n        if 'my_custom_object_name' in config:\n            config['hidden_cls'] = tf.keras.saving.get_registered_object(\n                config['my_custom_object_name'], custom_objects=custom_objects)\n    ```\n\n    Args:\n        name: The name to look up.\n        custom_objects: A dictionary of custom objects to look the name up in.\n            Generally, custom_objects is provided by the user.\n        module_objects: A dictionary of custom objects to look the name up in.\n            Generally, module_objects is provided by midlevel library\n            implementers.\n\n    Returns:\n        An instantiable class associated with `name`, or `None` if no such class\n            exists.\n    \"\"\"\n    custom_objects_scope_dict = global_state.get_global_attribute(\n        \"custom_objects_scope_dict\", {}\n    )\n    if name in custom_objects_scope_dict:\n        return custom_objects_scope_dict[name]\n    elif name in GLOBAL_CUSTOM_OBJECTS:\n        return GLOBAL_CUSTOM_OBJECTS[name]\n    elif custom_objects and name in custom_objects:\n        return custom_objects[name]\n    elif module_objects and name in module_objects:\n        return module_objects[name]\n    return None\n"
  },
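The registration machinery in `object_registration.py` above boils down to a parameterized decorator that writes into paired forward/reverse lookup tables under a `package>name` key. A standalone sketch with illustrative names (`REGISTRY`, `REVERSE`, `register`), not the Keras API:

```python
import inspect

REGISTRY = {}  # "package>name" -> object (cf. GLOBAL_CUSTOM_OBJECTS)
REVERSE = {}   # object -> "package>name" (cf. GLOBAL_CUSTOM_NAMES)


def register(package="Custom", name=None):
    def decorator(arg):
        key = f"{package}>{name if name is not None else arg.__name__}"
        # Classes must be round-trippable, hence the get_config requirement.
        if inspect.isclass(arg) and not hasattr(arg, "get_config"):
            raise ValueError("class must define get_config()")
        REGISTRY[key] = arg
        REVERSE[arg] = key
        return arg

    return decorator


@register("my_package")
class MyDense:
    def get_config(self):
        return {}
```

With no `name` argument the class name is used, so `MyDense` registers as `"my_package>MyDense"`; a bare `@register()` on a function yields `"Custom><fn name>"`.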
  {
    "path": "keras/src/saving/object_registration_test.py",
    "content": "import keras\nfrom keras.src import testing\nfrom keras.src.saving import object_registration\nfrom keras.src.saving import serialization_lib\n\n\nclass TestObjectRegistration(testing.TestCase):\n    def test_custom_object_scope(self):\n        def custom_fn():\n            pass\n\n        class CustomClass:\n            pass\n\n        def check_get_in_thread():\n            with object_registration.custom_object_scope(\n                {\"CustomClass\": CustomClass, \"custom_fn\": custom_fn}\n            ):\n                actual_custom_fn = keras.activations.get(\"custom_fn\")\n                self.assertEqual(actual_custom_fn, custom_fn)\n                actual_custom_class = keras.regularizers.get(\"CustomClass\")\n                self.assertEqual(actual_custom_class.__class__, CustomClass)\n\n            with object_registration.custom_object_scope(\n                {\"CustomClass\": CustomClass, \"custom_fn\": custom_fn}\n            ):\n                actual_custom_fn = keras.activations.get(\"custom_fn\")\n                self.assertEqual(actual_custom_fn, custom_fn)\n                actual_custom_class = keras.regularizers.get(\"CustomClass\")\n                self.assertEqual(actual_custom_class.__class__, CustomClass)\n                checked_thread = self.checkedThread(check_get_in_thread)\n                checked_thread.start()\n                checked_thread.join()\n\n    def test_serialize_custom_class_with_default_name(self):\n        @object_registration.register_keras_serializable()\n        class TestClass:\n            def __init__(self, value):\n                self._value = value\n\n            def get_config(self):\n                return {\"value\": self._value}\n\n            @classmethod\n            def from_config(cls, config):\n                return cls(**config)\n\n        serialized_name = \"Custom>TestClass\"\n        inst = TestClass(value=10)\n        class_name = 
object_registration.GLOBAL_CUSTOM_NAMES[TestClass]\n        self.assertEqual(serialized_name, class_name)\n\n        config = serialization_lib.serialize_keras_object(inst)\n        self.assertEqual(\"TestClass\", config[\"class_name\"])\n        new_inst = serialization_lib.deserialize_keras_object(config)\n        self.assertIsNot(inst, new_inst)\n        self.assertIsInstance(new_inst, TestClass)\n        self.assertEqual(10, new_inst._value)\n\n    def test_serialize_custom_class_with_custom_name(self):\n        @object_registration.register_keras_serializable(\n            \"TestPackage\", \"CustomName\"\n        )\n        class OtherTestClass:\n            def __init__(self, val):\n                self._val = val\n\n            def get_config(self):\n                return {\"val\": self._val}\n\n            @classmethod\n            def from_config(cls, config):\n                return cls(**config)\n\n        serialized_name = \"TestPackage>CustomName\"\n        inst = OtherTestClass(val=5)\n        class_name = object_registration.GLOBAL_CUSTOM_NAMES[OtherTestClass]\n        self.assertEqual(serialized_name, class_name)\n        fn_class_name = object_registration.get_registered_name(OtherTestClass)\n        self.assertEqual(fn_class_name, class_name)\n\n        cls = object_registration.get_registered_object(fn_class_name)\n        self.assertEqual(OtherTestClass, cls)\n\n        config = keras.saving.serialize_keras_object(inst)\n        self.assertEqual(\"OtherTestClass\", config[\"class_name\"])\n        new_inst = keras.saving.deserialize_keras_object(config)\n        self.assertIsNot(inst, new_inst)\n        self.assertIsInstance(new_inst, OtherTestClass)\n        self.assertEqual(5, new_inst._val)\n\n    def test_serialize_custom_function(self):\n        @object_registration.register_keras_serializable()\n        def my_fn():\n            return 42\n\n        serialized_name = \"Custom>my_fn\"\n        class_name = 
object_registration.GLOBAL_CUSTOM_NAMES[my_fn]\n        self.assertEqual(serialized_name, class_name)\n        fn_class_name = object_registration.get_registered_name(my_fn)\n        self.assertEqual(fn_class_name, class_name)\n\n        config = keras.saving.serialize_keras_object(my_fn)\n        fn = keras.saving.deserialize_keras_object(config)\n        self.assertEqual(42, fn())\n\n        fn_2 = object_registration.get_registered_object(fn_class_name)\n        self.assertEqual(42, fn_2())\n\n    def test_serialize_custom_class_without_get_config_fails(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Cannot register a class that does not have a get_config.*\",\n        ):\n\n            @object_registration.register_keras_serializable(\n                \"TestPackage\", \"TestClass\"\n            )\n            class TestClass:\n                def __init__(self, value):\n                    self._value = value\n"
  },
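The lookup order implemented by `get_registered_object` above (active scope dict, then the global registry, then caller-supplied `custom_objects`, then `module_objects`) can be sketched as a pure function; `resolve` is an illustrative name for this sketch:

```python
def resolve(name, scope_dict, global_objects,
            custom_objects=None, module_objects=None):
    """First match wins: scope > global > custom > module, else None."""
    if name in scope_dict:
        return scope_dict[name]
    if name in global_objects:
        return global_objects[name]
    if custom_objects and name in custom_objects:
        return custom_objects[name]
    if module_objects and name in module_objects:
        return module_objects[name]
    return None
```

The precedence means an active `custom_object_scope` can shadow a globally registered name, which in turn shadows anything the caller passes explicitly.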
  {
    "path": "keras/src/saving/orbax_util.py",
    "content": "\"\"\"Orbax checkpoint loading functionality.\"\"\"\n\nimport os\n\nfrom keras.src import backend\nfrom keras.src.distribution import distribution as get_distribution\nfrom keras.src.utils import file_utils\nfrom keras.src.utils.module_utils import ocp\n\n\ndef is_orbax_checkpoint(filepath):\n    \"\"\"Check if the given path is an Orbax checkpoint directory.\n\n    This function implements custom detection logic instead of relying on\n    Orbax APIs which may be unreliable in some environments.\n    \"\"\"\n    if not file_utils.exists(filepath) or not file_utils.isdir(filepath):\n        return False\n\n    try:\n        # List directory contents\n        contents = file_utils.listdir(filepath)\n\n        # A set is more efficient for membership testing\n        orbax_indicators = {\n            \"orbax.checkpoint\",\n            \"pytree.orbax-checkpoint\",\n            \"checkpoint_metadata\",\n        }\n\n        # Fast check for standard files\n        if not orbax_indicators.isdisjoint(contents):\n            return True\n\n        # Check for step directories or temporary files in a single pass\n        return any(\n            \".orbax-checkpoint-tmp\" in item\n            or (\n                item.isdigit()\n                and file_utils.isdir(file_utils.join(filepath, item))\n            )\n            for item in contents\n        )\n\n    except (OSError, PermissionError):\n        # If we can't read the directory, assume it's not a checkpoint\n        return False\n\n\ndef find_latest_orbax_checkpoint(checkpoint_dir):\n    \"\"\"Find the latest checkpoint in an Orbax checkpoint directory.\"\"\"\n    checkpointer = ocp.training.Checkpointer(directory=checkpoint_dir)\n    latest = checkpointer.latest\n    if latest is None:\n        raise ValueError(f\"No valid checkpoints found in {checkpoint_dir}\")\n    return os.path.join(checkpoint_dir, str(latest.step))\n\n\ndef build_orbax_abstract_pytree(checkpoint_path, ref_state=None):\n    
\"\"\"Build an abstract pytree for Orbax loading with target shardings.\n\n    On JAX with an active distribution, returns a pytree of\n    `jax.ShapeDtypeStruct` so that Orbax reshards arrays onto the\n    current distribution layout instead of restoring saved shardings.\n    On all other backends, or when no distribution is active, returns\n    `None` (Orbax will use saved shardings — fine when the topology\n    hasn't changed).\n\n    Args:\n        checkpoint_path: Path to a specific Orbax checkpoint step\n            directory (e.g. `<root>/2`).\n        ref_state: Optional reference state tree (from\n            `model.get_state_tree()`) whose variables carry the\n            target shardings. If `None`, shardings default to\n            `None` per leaf (Orbax uses saved shardings).\n\n    Returns:\n        A pytree of `jax.ShapeDtypeStruct` matching the checkpoint\n        structure, or `None` when resharding is not needed.\n    \"\"\"\n    if backend.backend() != \"jax\":\n        return None\n\n    if get_distribution() is None:\n        return None\n\n    import jax\n\n    pytree_meta = ocp.pytree_metadata(checkpoint_path).metadata\n\n    def _to_abstract(meta, ref=None):\n        \"\"\"Convert metadata leaf to `jax.ShapeDtypeStruct` with sharding.\"\"\"\n        if hasattr(meta, \"shape\") and hasattr(meta, \"dtype\"):\n            sharding = getattr(ref, \"sharding\", None)\n            return jax.ShapeDtypeStruct(\n                meta.shape, meta.dtype, sharding=sharding\n            )\n        if isinstance(meta, dict):\n            r = ref if isinstance(ref, dict) else {}\n            return {k: _to_abstract(v, r.get(k)) for k, v in meta.items()}\n        return None\n\n    return {\n        key: _to_abstract(val, (ref_state or {}).get(key))\n        for key, val in pytree_meta.items()\n    }\n"
  },
  {
    "path": "keras/src/saving/saving_api.py",
    "content": "import os\nimport zipfile\n\nfrom absl import logging\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.legacy.saving import legacy_h5_format\nfrom keras.src.saving import saving_lib\nfrom keras.src.saving.orbax_util import build_orbax_abstract_pytree\nfrom keras.src.saving.orbax_util import find_latest_orbax_checkpoint\nfrom keras.src.saving.orbax_util import is_orbax_checkpoint\nfrom keras.src.utils import file_utils\nfrom keras.src.utils import io_utils\nfrom keras.src.utils.module_utils import h5py\nfrom keras.src.utils.module_utils import ocp\n\n\n@keras_export([\"keras.saving.save_model\", \"keras.models.save_model\"])\ndef save_model(model, filepath, overwrite=True, zipped=None, **kwargs):\n    \"\"\"Saves a model as a `.keras` file.\n\n    Args:\n        model: Keras model instance to be saved.\n        filepath: `str` or `pathlib.Path` object. Path where to save the model.\n        overwrite: Whether we should overwrite any existing model at the target\n            location, or instead ask the user via an interactive prompt.\n        zipped: Whether to save the model as a zipped `.keras`\n            archive (default when saving locally), or as an unzipped directory\n            (default when saving on the Hugging Face Hub).\n\n    Example:\n\n    ```python\n    model = keras.Sequential(\n        [\n            keras.layers.Dense(5, input_shape=(3,)),\n            keras.layers.Softmax(),\n        ],\n    )\n    model.save(\"model.keras\")\n    loaded_model = keras.saving.load_model(\"model.keras\")\n    x = keras.random.uniform((10, 3))\n    assert np.allclose(model.predict(x), loaded_model.predict(x))\n    ```\n\n    Note that `model.save()` is an alias for `keras.saving.save_model()`.\n\n    The saved `.keras` file is a `zip` archive that contains:\n\n    - The model's configuration (architecture)\n    - The model's weights\n    - The model's optimizer's state (if any)\n\n    Thus models can be reinstantiated in the exact 
same state.\n    \"\"\"\n    include_optimizer = kwargs.pop(\"include_optimizer\", True)\n    save_format = kwargs.pop(\"save_format\", False)\n    if save_format:\n        if str(filepath).endswith((\".h5\", \".hdf5\")) or str(filepath).endswith(\n            \".keras\"\n        ):\n            logging.warning(\n                \"The `save_format` argument is deprecated in Keras 3. \"\n                \"We recommend removing this argument as it can be inferred \"\n                \"from the file path. \"\n                f\"Received: save_format={save_format}\"\n            )\n        else:\n            raise ValueError(\n                \"The `save_format` argument is deprecated in Keras 3. \"\n                \"Please remove this argument and pass a file path with \"\n                \"either `.keras` or `.h5` extension. \"\n                f\"Received: save_format={save_format}\"\n            )\n    if kwargs:\n        raise ValueError(\n            \"The following argument(s) are not supported: \"\n            f\"{list(kwargs.keys())}\"\n        )\n\n    # Deprecation warnings\n    if str(filepath).endswith((\".h5\", \".hdf5\")):\n        logging.warning(\n            \"You are saving your model as an HDF5 file via \"\n            \"`model.save()` or `keras.saving.save_model(model)`. \"\n            \"This file format is considered legacy. \"\n            \"We recommend using instead the native Keras format, \"\n            \"e.g. `model.save('my_model.keras')` or \"\n            \"`keras.saving.save_model(model, 'my_model.keras')`. 
\"\n        )\n\n    is_hf = str(filepath).startswith(\"hf://\")\n    if zipped is None:\n        zipped = not is_hf  # default behavior depends on destination\n\n    # If file exists and should not be overwritten.\n    try:\n        exists = (not is_hf) and os.path.exists(filepath)\n    except TypeError:\n        exists = False\n    if exists and not overwrite:\n        proceed = io_utils.ask_to_proceed_with_overwrite(filepath)\n        if not proceed:\n            return\n\n    if zipped and str(filepath).endswith(\".keras\"):\n        return saving_lib.save_model(model, filepath)\n    if not zipped:\n        return saving_lib.save_model(model, filepath, zipped=False)\n    if str(filepath).endswith((\".h5\", \".hdf5\")):\n        return legacy_h5_format.save_model_to_hdf5(\n            model, filepath, overwrite, include_optimizer\n        )\n    raise ValueError(\n        \"Invalid filepath extension for saving. \"\n        \"Please add either a `.keras` extension for the native Keras \"\n        f\"format (recommended) or a `.h5` extension. \"\n        \"Use `model.export(filepath)` if you want to export a SavedModel \"\n        \"for use with TFLite/TFServing/etc. 
\"\n        f\"Received: filepath={filepath}.\"\n    )\n\n\n@keras_export([\"keras.saving.load_model\", \"keras.models.load_model\"])\ndef load_model(filepath, custom_objects=None, compile=True, safe_mode=True):\n    \"\"\"Loads a model saved via `model.save()` or from an Orbax checkpoint.\n\n    Args:\n        filepath: `str` or `pathlib.Path` object, path to the saved model file\n            or Orbax checkpoint directory.\n        custom_objects: Optional dictionary mapping names\n            (strings) to custom classes or functions to be\n            considered during deserialization.\n        compile: Boolean, whether to compile the model after loading.\n        safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.\n            When `safe_mode=False`, loading an object has the potential to\n            trigger arbitrary code execution. This argument is only\n            applicable to the Keras v3 model format. Defaults to `True`.\n\n    Returns:\n        A Keras model instance. If the original model was compiled,\n        and the argument `compile=True` is set, then the returned model\n        will be compiled. Otherwise, the model will be left uncompiled.\n\n    Example:\n\n    ```python\n    model = keras.Sequential([\n        keras.layers.Dense(5, input_shape=(3,)),\n        keras.layers.Softmax()])\n    model.save(\"model.keras\")\n    loaded_model = keras.saving.load_model(\"model.keras\")\n    ```\n\n    Note that the model variables may have different name values\n    (`var.name` property, e.g. `\"dense_1/kernel:0\"`) after being reloaded.\n    It is recommended that you use layer attributes to\n    access specific variables, e.g. 
`model.get_layer(\"dense_1\").kernel`.\n    \"\"\"\n    is_keras_zip = str(filepath).endswith(\".keras\") and zipfile.is_zipfile(\n        filepath\n    )\n    is_keras_dir = file_utils.isdir(filepath) and file_utils.exists(\n        file_utils.join(filepath, \"config.json\")\n    )\n    is_hf = str(filepath).startswith(\"hf://\")\n\n    # Support for remote zip files\n    if (\n        file_utils.is_remote_path(filepath)\n        and not file_utils.isdir(filepath)\n        and not is_keras_zip\n        and not is_hf\n    ):\n        local_path = file_utils.join(\n            saving_lib.get_temp_dir(), os.path.basename(filepath)\n        )\n\n        # Copy from remote to temporary local directory\n        file_utils.copy(filepath, local_path)\n\n        # Switch filepath to local zipfile for loading model\n        if zipfile.is_zipfile(local_path):\n            filepath = local_path\n            is_keras_zip = True\n\n    if is_keras_zip or is_keras_dir or is_hf:\n        return saving_lib.load_model(\n            filepath,\n            custom_objects=custom_objects,\n            compile=compile,\n            safe_mode=safe_mode,\n        )\n    if str(filepath).endswith((\".h5\", \".hdf5\")):\n        return legacy_h5_format.load_model_from_hdf5(\n            filepath,\n            custom_objects=custom_objects,\n            compile=compile,\n            safe_mode=safe_mode,\n        )\n\n    # Check for Orbax checkpoint directory using utility function\n    if is_orbax_checkpoint(filepath):\n        return _load_model_from_orbax_checkpoint(\n            filepath,\n            custom_objects=custom_objects,\n            compile=compile,\n            safe_mode=safe_mode,\n        )\n\n    elif str(filepath).endswith(\".keras\"):\n        raise ValueError(\n            f\"File not found: filepath={filepath}. 
\"\n            \"Please ensure the file is an accessible `.keras` \"\n            \"zip file.\"\n        )\n    else:\n        raise ValueError(\n            f\"File format not supported: filepath={filepath}. \"\n            \"Keras 3 only supports V3 `.keras` files and \"\n            \"legacy H5 format files (`.h5` extension). \"\n            \"Note that the legacy SavedModel format is not \"\n            \"supported by `load_model()` in Keras 3. In \"\n            \"order to reload a TensorFlow SavedModel as an \"\n            \"inference-only layer in Keras 3, use \"\n            \"`keras.layers.TFSMLayer(\"\n            f\"{filepath}, call_endpoint='serving_default')` \"\n            \"(note that your `call_endpoint` \"\n            \"might have a different name).\"\n        )\n\n\n@keras_export(\"keras.saving.save_weights\")\ndef save_weights(\n    model, filepath, overwrite=True, max_shard_size=None, **kwargs\n):\n    filepath_str = str(filepath)\n    if max_shard_size is None and not filepath_str.endswith(\".weights.h5\"):\n        raise ValueError(\n            \"The filename must end in `.weights.h5`. \"\n            f\"Received: filepath={filepath_str}\"\n        )\n    elif max_shard_size is not None and not filepath_str.endswith(\n        (\"weights.h5\", \"weights.json\")\n    ):\n        raise ValueError(\n            \"The filename must end in `.weights.json` when `max_shard_size` is \"\n            f\"specified. 
Received: filepath={filepath_str}\"\n        )\n    try:\n        exists = os.path.exists(filepath)\n    except TypeError:\n        exists = False\n    if exists and not overwrite:\n        proceed = io_utils.ask_to_proceed_with_overwrite(filepath_str)\n        if not proceed:\n            return\n    saving_lib.save_weights_only(model, filepath, max_shard_size, **kwargs)\n\n\n@keras_export(\"keras.saving.load_weights\")\ndef load_weights(model, filepath, skip_mismatch=False, **kwargs):\n    filepath_str = str(filepath)\n\n    # Get the legacy kwargs.\n    objects_to_skip = kwargs.pop(\"objects_to_skip\", None)\n    by_name = kwargs.pop(\"by_name\", None)\n    if kwargs:\n        raise ValueError(f\"Invalid keyword arguments: {kwargs}\")\n\n    if filepath_str.endswith(\".keras\"):\n        if objects_to_skip is not None:\n            raise ValueError(\n                \"`objects_to_skip` only supports loading '.weights.h5' files. \"\n                f\"Received: {filepath}\"\n            )\n        if by_name is not None:\n            raise ValueError(\n                \"`by_name` only supports loading legacy '.h5' or '.hdf5' \"\n                f\"files. Received: {filepath}\"\n            )\n\n        saving_lib.load_weights_only(\n            model, filepath, skip_mismatch=skip_mismatch\n        )\n    elif filepath_str.endswith(\".weights.h5\") or filepath_str.endswith(\n        \".weights.json\"\n    ):\n        if by_name is not None:\n            raise ValueError(\n                \"`by_name` only supports loading legacy '.h5' or '.hdf5' \"\n                f\"files. 
Received: {filepath}\"\n            )\n        saving_lib.load_weights_only(\n            model,\n            filepath,\n            skip_mismatch=skip_mismatch,\n            objects_to_skip=objects_to_skip,\n        )\n    elif filepath_str.endswith(\".h5\") or filepath_str.endswith(\".hdf5\"):\n        if objects_to_skip is not None:\n            raise ValueError(\n                \"`objects_to_skip` only supports loading '.weights.h5' files. \"\n                f\"Received: {filepath}\"\n            )\n        if not h5py.available:\n            raise ImportError(\n                \"Loading HDF5 files requires the h5py package. \"\n                \"You can install it via `pip install h5py`\"\n            )\n        with h5py.File(filepath, \"r\") as f:\n            if \"layer_names\" not in f.attrs and \"model_weights\" in f:\n                f = f[\"model_weights\"]\n            if by_name:\n                legacy_h5_format.load_weights_from_hdf5_group_by_name(\n                    f, model, skip_mismatch\n                )\n            else:\n                legacy_h5_format.load_weights_from_hdf5_group(\n                    f, model, skip_mismatch\n                )\n    elif is_orbax_checkpoint(filepath):\n        # Load weights from Orbax checkpoint\n        filepath = str(filepath)\n\n        # Determine if this is a root directory or a step directory\n        items = file_utils.listdir(filepath)\n        has_step_subdirs = any(\n            file_utils.isdir(file_utils.join(filepath, item)) and item.isdigit()\n            for item in items\n        )\n\n        if has_step_subdirs:\n            # It's a root directory, find the latest checkpoint\n            checkpoint_path = find_latest_orbax_checkpoint(filepath)\n        else:\n            # It's a step directory, use it directly\n            checkpoint_path = filepath\n\n        # Build abstract pytree with target shardings so Orbax can\n        # reshard arrays onto the current distribution layout.\n   
     abstract_pytree = build_orbax_abstract_pytree(\n            checkpoint_path, model.get_state_tree()\n        )\n\n        loaded_checkpointables = ocp.load_checkpointables(\n            checkpoint_path, dict(pytree=abstract_pytree)\n        )\n\n        loaded_state = loaded_checkpointables[\"pytree\"]\n\n        # Set the model state directly from the loaded state\n        model.set_state_tree(loaded_state)\n    else:\n        raise ValueError(\n            f\"File format not supported: filepath={filepath}. \"\n            \"Keras 3 only supports V3 `.keras` files, \"\n            \"`.weights.h5` files, legacy H5 format files \"\n            \"(`.h5` extension), or Orbax checkpoints.\"\n        )\n\n\ndef _load_model_from_orbax_checkpoint(\n    filepath, custom_objects=None, compile=True, safe_mode=True\n):\n    \"\"\"Load a model from an Orbax checkpoint directory.\n\n    `model_config` is stored as its own checkpointable (separate from\n    `pytree`), so loading proceeds in two simple steps:\n\n      1. Load the `model_config` checkpointable to obtain the model\n         configuration string and rebuild the model.\n      2. Load the `pytree` checkpointable (all arrays).  
When a JAX\n         distribution is active, an abstract pytree with target\n         shardings is provided so that Orbax reshards arrays onto the\n         current layout.\n    \"\"\"\n    # Ensure orbax is available\n    ocp.initialize()\n\n    # Find the latest checkpoint step using the utility function\n    checkpoint_path = find_latest_orbax_checkpoint(filepath)\n    step = int(os.path.basename(checkpoint_path))\n\n    # Load the composite state efficiently\n    checkpointer = ocp.training.Checkpointer(directory=filepath)\n\n    variable_keys = [\n        \"trainable_variables\",\n        \"non_trainable_variables\",\n        \"optimizer_variables\",\n        \"metrics_variables\",\n    ]\n\n    with ocp.Context():\n        # Check which checkpointables were saved so we only request\n        # keys that actually exist (Orbax raises KeyError otherwise).\n        saved_keys = set(\n            checkpointer.checkpointables_metadata(step).metadata.keys()\n        )\n\n        # Step 1: Load model_config to rebuild the model.\n        if \"model_config\" not in saved_keys:\n            raise ValueError(\n                \"Checkpoint does not contain model configuration. 
\"\n                \"This checkpoint may have been saved with \"\n                \"save_weights_only=True.\"\n            )\n\n        config_loaded = checkpointer.load_checkpointables(\n            step, {\"model_config\": None}\n        )\n        model = saving_lib._model_from_config(\n            config_loaded[\"model_config\"][\"config\"],\n            custom_objects=custom_objects,\n            compile=compile,\n            safe_mode=safe_mode,\n        )\n\n        # Step 2: Load pytree (arrays) with optional resharding.\n        abstract_pytree = build_orbax_abstract_pytree(\n            checkpoint_path, model.get_state_tree()\n        )\n\n        request = {\"pytree\": abstract_pytree}\n        if \"assets\" in saved_keys:\n            request[\"assets\"] = None\n\n        loaded = checkpointer.load_checkpointables(step, request)\n        composite_state = loaded[\"pytree\"]\n        assets_data = loaded.get(\"assets\")\n\n    state_tree = {\n        key: composite_state[key]\n        for key in variable_keys\n        if key in composite_state\n    }\n\n    # Apply the loaded state to the model\n    model.set_state_tree(state_tree)\n\n    # Load assets only if they were present in the checkpoint\n    if assets_data is not None:\n        saving_lib._load_assets_from_dict(model, assets_data)\n\n    return model\n"
  },
  {
    "path": "keras/src/saving/saving_api_test.py",
    "content": "import os\nimport pathlib\nimport unittest.mock as mock\n\nimport numpy as np\nfrom absl import logging\nfrom absl.testing import parameterized\n\nfrom keras.src import layers\nfrom keras.src.legacy.saving.legacy_h5_format import save_model_to_hdf5\nfrom keras.src.models import Sequential\nfrom keras.src.saving import saving_api\nfrom keras.src.testing import test_case\nfrom keras.src.testing.test_utils import named_product\n\n\nclass SaveModelTests(test_case.TestCase):\n    def get_model(self):\n        return Sequential(\n            [\n                layers.Dense(5, input_shape=(3,)),\n                layers.Softmax(),\n            ]\n        )\n\n    def test_basic_saving(self):\n        \"\"\"Test basic model saving and loading.\"\"\"\n        model = self.get_model()\n        filepath = os.path.join(self.get_temp_dir(), \"test_model.keras\")\n        saving_api.save_model(model, filepath)\n\n        loaded_model = saving_api.load_model(filepath)\n        x = np.random.uniform(size=(10, 3))\n        self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))\n\n    def test_invalid_save_format(self):\n        \"\"\"Test deprecated save_format argument.\"\"\"\n        model = self.get_model()\n        with self.assertRaisesRegex(\n            ValueError, \"The `save_format` argument is deprecated\"\n        ):\n            saving_api.save_model(model, \"model.txt\", save_format=True)\n\n    def test_unsupported_arguments(self):\n        \"\"\"Test unsupported argument during model save.\"\"\"\n        model = self.get_model()\n        filepath = os.path.join(self.get_temp_dir(), \"test_model.keras\")\n        with self.assertRaisesRegex(\n            ValueError, r\"The following argument\\(s\\) are not supported\"\n        ):\n            saving_api.save_model(model, filepath, random_arg=True)\n\n    def test_save_h5_format(self):\n        \"\"\"Test saving model in h5 format.\"\"\"\n        model = self.get_model()\n        
filepath_h5 = os.path.join(self.get_temp_dir(), \"test_model.h5\")\n\n        # Verify the warning.\n        with mock.patch.object(logging, \"warning\") as mock_warn:\n            saving_api.save_model(model, filepath_h5)\n            mock_warn.assert_called_once_with(\n                \"You are saving your model as an HDF5 file via \"\n                \"`model.save()` or `keras.saving.save_model(model)`. \"\n                \"This file format is considered legacy. \"\n                \"We recommend using instead the native Keras format, \"\n                \"e.g. `model.save('my_model.keras')` or \"\n                \"`keras.saving.save_model(model, 'my_model.keras')`. \"\n            )\n        self.assertTrue(os.path.exists(filepath_h5))\n        os.remove(filepath_h5)\n\n    def test_save_unsupported_extension(self):\n        \"\"\"Test saving model with unsupported extension.\"\"\"\n        model = self.get_model()\n        with self.assertRaisesRegex(\n            ValueError, \"Invalid filepath extension for saving\"\n        ):\n            saving_api.save_model(model, \"model.png\")\n\n    def test_objects_to_skip(self):\n        model = Sequential(\n            [\n                layers.Input((3,)),\n                layers.Dense(5),\n                layers.Dense(5),\n            ]\n        )\n        skip = model.layers[0]\n        filepath = os.path.join(self.get_temp_dir(), \"test_model.weights.h5\")\n        saving_api.save_weights(model, filepath, objects_to_skip=[skip])\n        new_model = Sequential(\n            [\n                layers.Input((3,)),\n                layers.Dense(5),\n                layers.Dense(5),\n            ]\n        )\n        new_model.load_weights(filepath, objects_to_skip=[new_model.layers[0]])\n        self.assertNotAllClose(\n            new_model.layers[0].get_weights()[0],\n            model.layers[0].get_weights()[0],\n        )\n        self.assertAllClose(\n            new_model.layers[0].get_weights()[1],\n      
      model.layers[0].get_weights()[1],\n        )\n        saving_api.save_weights(model, filepath)\n        new_model.load_weights(filepath, objects_to_skip=[new_model.layers[0]])\n        self.assertNotAllClose(\n            new_model.layers[0].get_weights()[0],\n            model.layers[0].get_weights()[0],\n        )\n        self.assertAllClose(\n            new_model.layers[0].get_weights()[1],\n            model.layers[0].get_weights()[1],\n        )\n\n\nclass LoadModelTests(test_case.TestCase):\n    def get_model(self, dtype=None):\n        return Sequential(\n            [\n                layers.Dense(5, input_shape=(3,), dtype=dtype),\n                layers.Softmax(),\n            ]\n        )\n\n    @parameterized.named_parameters(\n        [\n            {\"testcase_name\": \"bfloat16\", \"dtype\": \"bfloat16\"},\n            {\"testcase_name\": \"float16\", \"dtype\": \"float16\"},\n            {\"testcase_name\": \"float32\", \"dtype\": \"float32\"},\n            {\"testcase_name\": \"float64\", \"dtype\": \"float64\"},\n        ]\n    )\n    def test_basic_load(self, dtype):\n        \"\"\"Test basic model loading.\"\"\"\n        model = self.get_model(dtype)\n        filepath = os.path.join(self.get_temp_dir(), \"test_model.keras\")\n        saving_api.save_model(model, filepath)\n\n        loaded_model = saving_api.load_model(filepath)\n        x = np.random.uniform(size=(10, 3))\n        self.assertEqual(loaded_model.weights[0].dtype, dtype)\n        self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))\n\n    def test_load_unsupported_format(self):\n        \"\"\"Test loading model with unsupported format.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"File format not supported\"):\n            saving_api.load_model(\"model.pkl\")\n\n    def test_load_keras_not_zip(self):\n        \"\"\"Test loading keras file that's not a zip.\"\"\"\n        with self.assertRaisesRegex(ValueError, \"File not found\"):\n            
saving_api.load_model(\"not_a_zip.keras\")\n\n    def test_load_h5_format(self):\n        \"\"\"Test loading model in h5 format.\"\"\"\n        model = self.get_model()\n        filepath_h5 = os.path.join(self.get_temp_dir(), \"test_model.h5\")\n        saving_api.save_model(model, filepath_h5)\n        loaded_model = saving_api.load_model(filepath_h5)\n        x = np.random.uniform(size=(10, 3))\n        self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))\n        os.remove(filepath_h5)\n\n    def test_load_model_with_custom_objects(self):\n        \"\"\"Test loading model with custom objects.\"\"\"\n\n        class CustomLayer(layers.Layer):\n            def call(self, inputs):\n                return inputs\n\n        model = Sequential([CustomLayer(input_shape=(3,))])\n        filepath = os.path.join(self.get_temp_dir(), \"custom_model.keras\")\n        model.save(filepath)\n        loaded_model = saving_api.load_model(\n            filepath, custom_objects={\"CustomLayer\": CustomLayer}\n        )\n        self.assertIsInstance(loaded_model.layers[0], CustomLayer)\n        os.remove(filepath)\n\n    def test_save_unzipped(self):\n        \"\"\"Test saving/loading an unzipped model dir.\"\"\"\n        model = self.get_model()\n\n        # Test error with keras extension\n        bad_filepath = os.path.join(self.get_temp_dir(), \"test_model.keras\")\n        with self.assertRaisesRegex(ValueError, \"should not end in\"):\n            saving_api.save_model(model, bad_filepath, zipped=False)\n\n        filepath = os.path.join(self.get_temp_dir(), \"test_model_dir\")\n        saving_api.save_model(model, filepath, zipped=False)\n\n        self.assertTrue(os.path.exists(filepath))\n        self.assertTrue(os.path.isdir(filepath))\n        config_filepath = os.path.join(filepath, \"config.json\")\n        weights_filepath = os.path.join(filepath, \"model.weights.h5\")\n        self.assertTrue(os.path.exists(config_filepath))\n        
self.assertTrue(os.path.exists(weights_filepath))\n\n        loaded_model = saving_api.load_model(filepath)\n        x = np.random.uniform(size=(10, 3))\n        self.assertTrue(np.allclose(model.predict(x), loaded_model.predict(x)))\n\n\nclass LoadWeightsTests(test_case.TestCase):\n    def get_model(self, dtype=None):\n        return Sequential(\n            [\n                layers.Dense(5, input_shape=(3,), dtype=dtype),\n                layers.Softmax(),\n            ]\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            save_format=[\"keras\", \"weights.h5\", \"h5\"],\n            source_dtype=[\"float64\", \"float32\", \"float16\", \"bfloat16\"],\n            dest_dtype=[\"float64\", \"float32\", \"float16\", \"bfloat16\"],\n        )\n    )\n    def test_load_weights(self, save_format, source_dtype, dest_dtype):\n        \"\"\"Test loading keras weights.\"\"\"\n        src_model = self.get_model(dtype=source_dtype)\n        if save_format == \"keras\":\n            filepath = os.path.join(self.get_temp_dir(), \"test_weights.keras\")\n            src_model.save(filepath)\n        elif save_format == \"weights.h5\":\n            filepath = os.path.join(\n                self.get_temp_dir(), \"test_weights.weights.h5\"\n            )\n            src_model.save_weights(filepath)\n        elif save_format == \"h5\":\n            if \"bfloat16\" in (source_dtype, dest_dtype):\n                raise self.skipTest(\n                    \"bfloat16 dtype is not supported in legacy h5 format.\"\n                )\n            filepath = os.path.join(self.get_temp_dir(), \"test_weights.h5\")\n            save_model_to_hdf5(src_model, filepath)\n        else:\n            raise ValueError(f\"Unsupported save format: {save_format}\")\n\n        dest_model = self.get_model(dtype=dest_dtype)\n        dest_model.load_weights(filepath)\n\n        src_weights = src_model.get_weights()\n        dest_weights = dest_model.get_weights()\n       
 for orig, loaded in zip(src_weights, dest_weights):\n            self.assertAllClose(\n                orig.astype(\"float32\"),\n                loaded.astype(\"float32\"),\n                atol=0.001,\n                rtol=0.01,\n            )\n\n    def test_load_weights_invalid_kwargs(self):\n        src_model = self.get_model()\n        keras_filepath = os.path.join(self.get_temp_dir(), \"test_weights.keras\")\n        weight_h5_filepath = os.path.join(\n            self.get_temp_dir(), \"test_weights.weights.h5\"\n        )\n        legacy_h5_filepath = os.path.join(\n            self.get_temp_dir(), \"test_weights.h5\"\n        )\n        src_model.save(keras_filepath)\n        src_model.save_weights(weight_h5_filepath)\n        save_model_to_hdf5(src_model, legacy_h5_filepath)\n\n        dest_model = self.get_model()\n        # Test keras file.\n        with self.assertRaisesRegex(\n            ValueError, r\"only supports loading '.weights.h5' files.\"\n        ):\n            dest_model.load_weights(keras_filepath, objects_to_skip=[])\n        with self.assertRaisesRegex(\n            ValueError, r\"only supports loading legacy '.h5' or '.hdf5' files.\"\n        ):\n            dest_model.load_weights(keras_filepath, by_name=True)\n        with self.assertRaisesRegex(ValueError, r\"Invalid keyword arguments\"):\n            dest_model.load_weights(keras_filepath, bad_kwarg=None)\n        # Test weights.h5 file.\n        with self.assertRaisesRegex(\n            ValueError, r\"only supports loading legacy '.h5' or '.hdf5' files.\"\n        ):\n            dest_model.load_weights(weight_h5_filepath, by_name=True)\n        # Test h5 file.\n        with self.assertRaisesRegex(\n            ValueError, r\"only supports loading '.weights.h5' files.\"\n        ):\n            dest_model.load_weights(legacy_h5_filepath, objects_to_skip=[])\n\n    def test_load_weights_invalid_extension(self):\n        \"\"\"Test loading weights with unsupported 
extension.\"\"\"\n        model = self.get_model()\n        with self.assertRaisesRegex(ValueError, \"File format not supported\"):\n            model.load_weights(\"invalid_extension.pkl\")\n\n    def test_load_sharded_weights(self):\n        src_model = self.get_model()\n        temp_filepath = pathlib.Path(\n            os.path.join(self.get_temp_dir(), \"test_weights.weights.json\")\n        )\n        src_model.save_weights(temp_filepath, max_shard_size=1)\n        self.assertLen(os.listdir(temp_filepath.parent), 2)\n        src_weights = src_model.get_weights()\n        dest_model = self.get_model()\n        dest_model.load_weights(temp_filepath)\n        dest_weights = dest_model.get_weights()\n        for orig, loaded in zip(src_weights, dest_weights):\n            self.assertAllClose(orig, loaded)\n"
  },
  {
    "path": "keras/src/saving/saving_lib.py",
    "content": "\"\"\"Python-based idempotent model-saving functionality.\"\"\"\n\nimport datetime\nimport io\nimport json\nimport math\nimport os\nimport pathlib\nimport shutil\nimport tempfile\nimport warnings\nimport zipfile\n\nimport ml_dtypes\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.backend.common import global_state\nfrom keras.src.saving.serialization_lib import ObjectSharingScope\nfrom keras.src.saving.serialization_lib import deserialize_keras_object\nfrom keras.src.saving.serialization_lib import serialize_keras_object\nfrom keras.src.utils import dtype_utils\nfrom keras.src.utils import file_utils\nfrom keras.src.utils import io_utils\nfrom keras.src.utils import naming\nfrom keras.src.utils import plot_model\nfrom keras.src.utils.model_visualization import check_pydot\nfrom keras.src.utils.summary_utils import readable_memory_size\nfrom keras.src.utils.summary_utils import weight_memory_size\nfrom keras.src.version import __version__ as keras_version\n\ntry:\n    import h5py\nexcept ImportError:\n    h5py = None\ntry:\n    import psutil\nexcept ImportError:\n    psutil = None\ntry:\n    import huggingface_hub\nexcept ImportError:\n    huggingface_hub = None\n\n\n_CONFIG_FILENAME = \"config.json\"\n_METADATA_FILENAME = \"metadata.json\"\n_VARS_FNAME = \"model.weights\"  # Will become e.g. 
\"model.weights.h5\"\n_VARS_FNAME_H5 = f\"{_VARS_FNAME}.h5\"\n_VARS_FNAME_NPZ = f\"{_VARS_FNAME}.npz\"\n_ASSETS_DIRNAME = \"assets\"\n_MEMORY_UPPER_BOUND = 0.5  # 50%\n\n\n_MODEL_CARD_TEMPLATE = \"\"\"\n---\nlibrary_name: keras\n---\n\nThis model has been uploaded using the Keras library and can be used with JAX,\nTensorFlow, and PyTorch backends.\n\nThis model card has been generated automatically and should be completed by the\nmodel author.\nSee [Model Cards documentation](https://huggingface.co/docs/hub/model-cards) for\nmore information.\n\nFor more details about the model architecture, check out\n[config.json](./config.json).\"\"\"\n\n\ndef save_model(model, filepath, weights_format=\"h5\", zipped=True):\n    \"\"\"Save a zip-archive representing a Keras model to the given file or path.\n\n    The zip-based archive contains the following structure:\n\n    - JSON-based configuration file (config.json): Records of model, layer, and\n        other saveables' configuration.\n    - H5-based saveable state files, found in respective directories, such as\n        model/states.npz, model/dense_layer/states.npz, etc.\n    - Metadata file.\n\n    The states of Keras saveables (layers, optimizers, loss, and metrics) are\n    automatically saved as long as they can be discovered through the attributes\n    returned by `dir(Model)`. Typically, the state includes the variables\n    associated with the saveable, but some specially purposed layers may\n    contain more such as the vocabularies stored in the hashmaps. 
The saveables\n    define how their states are saved by exposing `save_state()` and\n    `load_state()` APIs.\n\n    For the case of layer states, the variables will be visited as long as\n    they are either 1) referenced via layer attributes, or 2) referenced via a\n    container (list, tuple, or dict), and the container is referenced via a\n    layer attribute.\n    \"\"\"\n    if weights_format == \"h5\" and h5py is None:\n        raise ImportError(\"h5py must be installed in order to save a model.\")\n\n    if not model.built:\n        warnings.warn(\n            \"You are saving a model that has not yet been built. \"\n            \"It might not contain any weights yet. \"\n            \"Consider building the model first by calling it \"\n            \"on some data.\",\n            stacklevel=2,\n        )\n\n    if isinstance(filepath, io.IOBase):\n        _save_model_to_fileobj(model, filepath, weights_format)\n        return\n\n    filepath = str(filepath)\n    is_hf = filepath.startswith(\"hf://\")\n    if zipped and not filepath.endswith(\".keras\"):\n        raise ValueError(\n            \"Invalid `filepath` argument: expected a `.keras` extension. \"\n            f\"Received: filepath={filepath}\"\n        )\n    if not zipped and filepath.endswith(\".keras\"):\n        raise ValueError(\n            \"When using `zipped=False`, the `filepath` argument should not \"\n            f\"end in `.keras`. Received: filepath={filepath}\"\n        )\n    if zipped and is_hf:\n        raise ValueError(\n            \"When saving to the Hugging Face Hub, you should not save the \"\n            f\"model as zipped. Received: filepath={filepath}, zipped={zipped}\"\n        )\n    if is_hf:\n        _upload_model_to_hf(model, filepath, weights_format)\n    elif not zipped:\n        _save_model_to_dir(model, filepath, weights_format)\n    else:\n        if file_utils.is_remote_path(filepath):\n            # Remote path. 
Zip to local memory byte io and copy to remote\n            zip_filepath = io.BytesIO()\n            _save_model_to_fileobj(model, zip_filepath, weights_format)\n            with file_utils.File(filepath, \"wb\") as f:\n                f.write(zip_filepath.getvalue())\n        else:\n            with open(filepath, \"wb\") as f:\n                _save_model_to_fileobj(model, f, weights_format)\n\n\ndef _serialize_model_as_json(model):\n    with ObjectSharingScope():\n        serialized_model_dict = serialize_keras_object(model)\n    config_json = json.dumps(serialized_model_dict)\n    metadata_json = json.dumps(\n        {\n            \"keras_version\": keras_version,\n            \"date_saved\": datetime.datetime.now().strftime(\"%Y-%m-%d@%H:%M:%S\"),\n        }\n    )\n    return config_json, metadata_json\n\n\ndef _save_model_to_dir(model, dirpath, weights_format):\n    if not file_utils.exists(dirpath):\n        file_utils.makedirs(dirpath)\n    config_json, metadata_json = _serialize_model_as_json(model)\n    with open(file_utils.join(dirpath, _METADATA_FILENAME), \"w\") as f:\n        f.write(metadata_json)\n    with open(file_utils.join(dirpath, _CONFIG_FILENAME), \"w\") as f:\n        f.write(config_json)\n    weights_filepath = file_utils.join(dirpath, _VARS_FNAME_H5)\n    asset_dirpath = file_utils.join(dirpath, _ASSETS_DIRNAME)\n    try:\n        if weights_format == \"h5\":\n            weights_store = H5IOStore(weights_filepath, mode=\"w\")\n        elif weights_format == \"npz\":\n            weights_store = NpzIOStore(weights_filepath, mode=\"w\")\n        else:\n            raise ValueError(\n                \"Unknown `weights_format` argument. \"\n                \"Expected 'h5' or 'npz'. 
\"\n                f\"Received: weights_format={weights_format}\"\n            )\n        asset_store = DiskIOStore(asset_dirpath, mode=\"w\")\n        _save_state(\n            model,\n            weights_store=weights_store,\n            assets_store=asset_store,\n            inner_path=\"\",\n            visited_saveables=set(),\n        )\n    finally:\n        weights_store.close()\n        asset_store.close()\n\n\ndef _save_model_to_fileobj(model, fileobj, weights_format):\n    config_json, metadata_json = _serialize_model_as_json(model)\n\n    with zipfile.ZipFile(fileobj, \"w\") as zf:\n        with zf.open(_METADATA_FILENAME, \"w\") as f:\n            f.write(metadata_json.encode())\n        with zf.open(_CONFIG_FILENAME, \"w\") as f:\n            f.write(config_json.encode())\n\n        weights_file_path = None\n        weights_store = None\n        asset_store = None\n        write_zf = False\n        try:\n            if weights_format == \"h5\":\n                try:\n                    if is_memory_sufficient(model):\n                        # Load the model weights into memory before writing\n                        # .keras if the system memory is sufficient.\n                        weights_store = H5IOStore(\n                            _VARS_FNAME_H5, archive=zf, mode=\"w\"\n                        )\n                    else:\n                        # Try opening the .h5 file, then writing it to `zf` at\n                        # the end of the function call. 
This is more memory\n                        # efficient than writing the weights into memory first.\n                        working_dir = pathlib.Path(fileobj.name).parent\n                        weights_file_path = tempfile.NamedTemporaryFile(\n                            dir=working_dir\n                        )\n                        weights_store = H5IOStore(\n                            weights_file_path.name, mode=\"w\"\n                        )\n                        write_zf = True\n                except:\n                    # If we can't use the local disk for any reason, write the\n                    # weights into memory first, which consumes more memory.\n                    weights_store = H5IOStore(\n                        _VARS_FNAME_H5, archive=zf, mode=\"w\"\n                    )\n            elif weights_format == \"npz\":\n                weights_store = NpzIOStore(\n                    _VARS_FNAME_NPZ, archive=zf, mode=\"w\"\n                )\n            else:\n                raise ValueError(\n                    \"Unknown `weights_format` argument. \"\n                    \"Expected 'h5' or 'npz'. 
\"\n                    f\"Received: weights_format={weights_format}\"\n                )\n\n            asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode=\"w\")\n\n            _save_state(\n                model,\n                weights_store=weights_store,\n                assets_store=asset_store,\n                inner_path=\"\",\n                visited_saveables=set(),\n            )\n        except:\n            # Skip the final `zf.write` if any exception is raised\n            write_zf = False\n            if weights_store:\n                weights_store.archive = None\n            raise\n        finally:\n            if weights_store:\n                weights_store.close()\n            if asset_store:\n                asset_store.close()\n            if write_zf and weights_file_path:\n                zf.write(weights_file_path.name, _VARS_FNAME_H5)\n            if weights_file_path:\n                weights_file_path.close()\n\n\ndef _upload_model_to_hf(model, hf_path, weights_format):\n    if huggingface_hub is None:\n        raise ImportError(\n            \"To save models to the Hugging Face Hub, \"\n            \"you must install the `huggingface_hub` package.\"\n        )\n\n    original_hf_path = hf_path\n    if hf_path.startswith(\"hf://\"):\n        hf_path = hf_path[5:]\n    if hf_path.count(\"/\") > 1:\n        raise ValueError(\n            \"Invalid `hf_path` argument: expected `namespace/model_name`\"\n            f\" format. 
Received: hf_path={original_hf_path}\"\n        )\n\n    api = huggingface_hub.HfApi(\n        library_name=\"keras\", library_version=keras_version\n    )\n    repo_url = api.create_repo(hf_path, exist_ok=True)\n    repo_id = repo_url.repo_id\n\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        _save_model_to_dir(model, tmp_dir, weights_format)\n\n        model_card = _MODEL_CARD_TEMPLATE\n\n        if check_pydot():\n            plot_path = file_utils.join(tmp_dir, \"assets\", \"summary_plot.png\")\n            plot_model(\n                model,\n                to_file=plot_path,\n                show_layer_names=True,\n                show_shapes=True,\n                show_dtype=True,\n            )\n            if len(model.layers) <= 10:\n                model_card += \"\\n\\n![](./assets/summary_plot.png)\"\n            else:\n                model_card += (\n                    \"A plot of the model can be found \"\n                    \"[here](./assets/summary_plot.png).\"\n                )\n\n        with open(file_utils.join(tmp_dir, \"README.md\"), \"w\") as f:\n            f.write(model_card)\n\n        api.upload_folder(\n            repo_id=repo_id,\n            folder_path=tmp_dir,\n            commit_message=\"Save model using Keras.\",\n        )\n        io_utils.print_msg(\n            f\"Model saved to the Hugging Face Hub: {repo_url}\\n\"\n            \"To load back the model, use \"\n            f\"`keras.saving.load_model('hf://{repo_id}')`\"\n        )\n\n\ndef load_model(filepath, custom_objects=None, compile=True, safe_mode=True):\n    \"\"\"Load a zip archive representing a Keras model.\"\"\"\n    if isinstance(filepath, io.IOBase):\n        return _load_model_from_fileobj(\n            filepath, custom_objects, compile, safe_mode\n        )\n    elif str(filepath).startswith(\"hf://\"):\n        if huggingface_hub is None:\n            raise ImportError(\n                \"To load models from the Hugging Face Hub, \"\n      
          \"you must install the `huggingface_hub` package.\"\n            )\n\n        repo_id = filepath[5:]\n        folder_path = huggingface_hub.snapshot_download(\n            repo_id=repo_id,\n            library_name=\"keras\",\n            library_version=keras_version,\n        )\n        return _load_model_from_dir(\n            folder_path, custom_objects, compile, safe_mode\n        )\n    else:\n        filepath = str(filepath)\n        if not filepath.endswith(\".keras\"):\n            is_keras_dir = file_utils.isdir(filepath) and file_utils.exists(\n                file_utils.join(filepath, \"config.json\")\n            )\n            if is_keras_dir:\n                return _load_model_from_dir(\n                    filepath, custom_objects, compile, safe_mode\n                )\n            raise ValueError(\n                \"Invalid filename: expected a `.keras` extension. \"\n                f\"Received: filepath={filepath}\"\n            )\n        with open(filepath, \"rb\") as f:\n            return _load_model_from_fileobj(\n                f, custom_objects, compile, safe_mode\n            )\n\n\ndef _load_model_from_dir(dirpath, custom_objects, compile, safe_mode):\n    if not file_utils.exists(dirpath):\n        raise ValueError(f\"Directory doesn't exist: {dirpath}\")\n    if not file_utils.isdir(dirpath):\n        raise ValueError(f\"Path isn't a directory: {dirpath}\")\n\n    with open(file_utils.join(dirpath, _CONFIG_FILENAME), \"r\") as f:\n        config_json = f.read()\n    model = _model_from_config(config_json, custom_objects, compile, safe_mode)\n\n    all_filenames = file_utils.listdir(dirpath)\n    try:\n        if _VARS_FNAME_H5 in all_filenames:\n            weights_file_path = file_utils.join(dirpath, _VARS_FNAME_H5)\n            weights_store = H5IOStore(weights_file_path, mode=\"r\")\n        elif _VARS_FNAME_NPZ in all_filenames:\n            weights_file_path = file_utils.join(dirpath, _VARS_FNAME_NPZ)\n            
weights_store = NpzIOStore(weights_file_path, mode=\"r\")\n        else:\n            raise ValueError(\n                f\"Expected a {_VARS_FNAME_H5} or {_VARS_FNAME_NPZ} file.\"\n            )\n        if len(all_filenames) > 3:\n            asset_store = DiskIOStore(\n                file_utils.join(dirpath, _ASSETS_DIRNAME), mode=\"r\"\n            )\n\n        else:\n            asset_store = None\n\n        failed_saveables = set()\n        error_msgs = {}\n        _load_state(\n            model,\n            weights_store=weights_store,\n            assets_store=asset_store,\n            inner_path=\"\",\n            visited_saveables=set(),\n            failed_saveables=failed_saveables,\n            error_msgs=error_msgs,\n        )\n\n    finally:\n        weights_store.close()\n        if asset_store:\n            asset_store.close()\n\n    if failed_saveables:\n        _raise_loading_failure(error_msgs)\n    return model\n\n\ndef _model_from_config(config_json, custom_objects, compile, safe_mode):\n    # Note: we should NOT use a custom JSON decoder. 
Anything that\n    # needs custom decoding must be handled in deserialize_keras_object.\n    config_dict = json.loads(config_json)\n    if not compile:\n        # Disable compilation\n        config_dict[\"compile_config\"] = None\n    # Construct the model from the configuration file in the archive.\n    with ObjectSharingScope():\n        model = deserialize_keras_object(\n            config_dict, custom_objects, safe_mode=safe_mode\n        )\n    return model\n\n\ndef _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode):\n    with zipfile.ZipFile(fileobj, \"r\") as zf:\n        with zf.open(_CONFIG_FILENAME, \"r\") as f:\n            config_json = f.read()\n\n        model = _model_from_config(\n            config_json, custom_objects, compile, safe_mode\n        )\n\n        all_filenames = zf.namelist()\n        extract_dir = None\n        weights_store = None\n        asset_store = None\n        try:\n            if _VARS_FNAME_H5 in all_filenames:\n                try:\n                    if is_memory_sufficient(model):\n                        # Load the entire file into memory if the system memory\n                        # is sufficient.\n                        io_file = io.BytesIO(\n                            zf.open(_VARS_FNAME_H5, \"r\").read()\n                        )\n                        weights_store = H5IOStore(io_file, mode=\"r\")\n                    else:\n                        # Try extracting the model.weights.h5 file, and then\n                        # loading it using h5py. 
This is significantly\n                        # faster than reading from the zip archive on the fly.\n                        extract_dir = tempfile.TemporaryDirectory(\n                            dir=pathlib.Path(fileobj.name).parent\n                        )\n                        zf.extract(_VARS_FNAME_H5, extract_dir.name)\n                        weights_store = H5IOStore(\n                            pathlib.Path(extract_dir.name, _VARS_FNAME_H5),\n                            mode=\"r\",\n                        )\n                except:\n                    # If we can't use the local disk for any reason, read the\n                    # weights from the zip archive on the fly, which is less\n                    # efficient.\n                    weights_store = H5IOStore(_VARS_FNAME_H5, zf, mode=\"r\")\n            elif _VARS_FNAME_NPZ in all_filenames:\n                weights_store = NpzIOStore(_VARS_FNAME_NPZ, zf, mode=\"r\")\n            else:\n                raise ValueError(\n                    f\"Expected a {_VARS_FNAME_H5} or {_VARS_FNAME_NPZ} file.\"\n                )\n\n            if len(all_filenames) > 3:\n                asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode=\"r\")\n\n            failed_saveables = set()\n            error_msgs = {}\n            _load_state(\n                model,\n                weights_store=weights_store,\n                assets_store=asset_store,\n                inner_path=\"\",\n                visited_saveables=set(),\n                failed_saveables=failed_saveables,\n                error_msgs=error_msgs,\n            )\n        finally:\n            if weights_store:\n                weights_store.close()\n            if asset_store:\n                asset_store.close()\n            if extract_dir:\n                extract_dir.cleanup()\n\n        if failed_saveables:\n            _raise_loading_failure(error_msgs)\n    return model\n\n\ndef save_weights_only(\n    model, filepath, 
max_shard_size=None, objects_to_skip=None\n):\n    \"\"\"Save only the weights of a model to a target filepath.\n\n    Supports both `.weights.h5` and `.keras`.\n    \"\"\"\n    if not model.built:\n        raise ValueError(\n            \"You are saving a model that has not yet been built. \"\n            \"Try building the model first by calling it on some data or \"\n            \"by using `build()`.\"\n        )\n\n    filepath_str = str(filepath)\n    tmp_dir = None\n    remote_filepath = None\n    if max_shard_size is None and not filepath_str.endswith(\".weights.h5\"):\n        raise ValueError(\n            \"The filename must end in `.weights.h5`. \"\n            f\"Received: filepath={filepath_str}\"\n        )\n    elif max_shard_size is not None and not filepath_str.endswith(\n        (\"weights.h5\", \"weights.json\")\n    ):\n        raise ValueError(\n            \"The filename must end in `.weights.json` when `max_shard_size` is \"\n            f\"specified. Received: filepath={filepath_str}\"\n        )\n    try:\n        if file_utils.is_remote_path(filepath):\n            tmp_dir = get_temp_dir()\n            local_filepath = os.path.join(tmp_dir, os.path.basename(filepath))\n            remote_filepath = filepath\n            filepath = local_filepath\n\n        if max_shard_size is not None:\n            weights_store = ShardedH5IOStore(filepath, max_shard_size, mode=\"w\")\n        else:\n            weights_store = H5IOStore(filepath, mode=\"w\")\n        if objects_to_skip is not None:\n            visited_saveables = set(id(o) for o in objects_to_skip)\n        else:\n            visited_saveables = set()\n        _save_state(\n            model,\n            weights_store=weights_store,\n            assets_store=None,\n            inner_path=\"\",\n            visited_saveables=visited_saveables,\n        )\n        weights_store.close()\n    finally:\n        if tmp_dir is not None:\n            file_utils.copy(filepath, 
remote_filepath)\n            shutil.rmtree(tmp_dir)\n\n\ndef load_weights_only(\n    model, filepath, skip_mismatch=False, objects_to_skip=None\n):\n    \"\"\"Load the weights of a model from a filepath (.keras or .weights.h5).\n\n    Note: only supports h5 for now.\n    \"\"\"\n    if not model.built:\n        raise ValueError(\n            \"You are loading weights into a model that has not yet been built. \"\n            \"Try building the model first by calling it on some data or \"\n            \"by using `build()`.\"\n        )\n\n    archive = None\n    tmp_dir = None\n    filepath_str = str(filepath)\n\n    try:\n        if file_utils.is_remote_path(filepath_str):\n            tmp_dir = get_temp_dir()\n            local_filepath = os.path.join(\n                tmp_dir, os.path.basename(filepath_str)\n            )\n            file_utils.copy(filepath_str, local_filepath)\n            filepath_str = filepath = local_filepath\n\n        if filepath_str.endswith(\"weights.h5\"):\n            weights_store = H5IOStore(filepath, mode=\"r\")\n        elif filepath_str.endswith(\"weights.json\"):\n            weights_store = ShardedH5IOStore(filepath, mode=\"r\")\n        elif filepath_str.endswith(\".keras\"):\n            archive = zipfile.ZipFile(filepath, \"r\")\n            weights_store = H5IOStore(_VARS_FNAME_H5, archive=archive, mode=\"r\")\n\n        failed_saveables = set()\n        if objects_to_skip is not None:\n            visited_saveables = set(id(o) for o in objects_to_skip)\n        else:\n            visited_saveables = set()\n        error_msgs = {}\n        _load_state(\n            model,\n            weights_store=weights_store,\n            assets_store=None,\n            inner_path=\"\",\n            skip_mismatch=skip_mismatch,\n            visited_saveables=visited_saveables,\n            failed_saveables=failed_saveables,\n            error_msgs=error_msgs,\n        )\n        weights_store.close()\n        if archive:\n            
archive.close()\n\n        if failed_saveables:\n            _raise_loading_failure(error_msgs, warn_only=skip_mismatch)\n    finally:\n        if tmp_dir is not None:\n            shutil.rmtree(tmp_dir)\n\n\ndef _raise_loading_failure(error_msgs, warn_only=False):\n    first_key = list(error_msgs.keys())[0]\n    ex_saveable, ex_error = error_msgs[first_key]\n    msg = (\n        f\"A total of {len(error_msgs)} objects could not \"\n        \"be loaded. Example error message for \"\n        f\"object {ex_saveable}:\\n\\n\"\n        f\"{ex_error}\\n\\n\"\n        \"List of objects that could not be loaded:\\n\"\n        f\"{[x[0] for x in error_msgs.values()]}\"\n    )\n    if warn_only:\n        warnings.warn(msg)\n    else:\n        raise ValueError(msg)\n\n\ndef _write_to_zip_recursively(zipfile_to_save, system_path, zip_path):\n    if not file_utils.isdir(system_path):\n        zipfile_to_save.write(system_path, zip_path)\n    else:\n        for file_name in file_utils.listdir(system_path):\n            system_file_path = file_utils.join(system_path, file_name).replace(\n                \"\\\\\", \"/\"\n            )\n            zip_file_path = file_utils.join(zip_path, file_name).replace(\n                \"\\\\\", \"/\"\n            )\n            _write_to_zip_recursively(\n                zipfile_to_save, system_file_path, zip_file_path\n            )\n\n\ndef _name_key(name):\n    \"\"\"Make sure that private attributes are visited last.\"\"\"\n    if name.startswith(\"_\"):\n        return f\"~{name}\"\n    return name\n\n\ndef _walk_saveable(saveable):\n    from keras.src.saving.keras_saveable import KerasSaveable\n\n    if not isinstance(saveable, KerasSaveable):\n        raise ValueError(\n            \"Expected object to be an \"\n            \"instance of `KerasSaveable`, but \"\n            f\"got {saveable} of type {type(saveable)}\"\n        )\n\n    obj_type = saveable._obj_type()\n    attr_skipset = get_attr_skipset(obj_type)\n\n    # Save all 
layers directly tracked by Sequential and Functional first.\n    # This helps avoid ordering concerns for subclassed Sequential or Functional\n    # models with extra attributes--the internal Keras state takes precedence.\n    if obj_type in (\"Sequential\", \"Functional\"):\n        yield \"layers\", saveable.layers\n\n    for child_attr in sorted(dir(saveable), key=lambda x: _name_key(x)):\n        if child_attr.startswith(\"__\") or child_attr in attr_skipset:\n            continue\n        try:\n            child_obj = getattr(saveable, child_attr)\n        except Exception:\n            # Avoid raising the exception when visiting the attributes.\n            continue\n        yield child_attr, child_obj\n\n\ndef _save_state(\n    saveable,\n    weights_store,\n    assets_store,\n    inner_path,\n    visited_saveables,\n):\n    from keras.src.saving.keras_saveable import KerasSaveable\n\n    if not isinstance(\n        weights_store, (H5IOStore, ShardedH5IOStore, NpzIOStore, type(None))\n    ):\n        raise ValueError(\n            \"Expected `weights_store` to be an instance of \"\n            \"`H5IOStore`, `ShardedH5IOStore`, `NpzIOStore`, or `None`. \"\n            f\"Received: {weights_store} of type {type(weights_store)}\"\n        )\n    if not isinstance(assets_store, (DiskIOStore, type(None))):\n        raise ValueError(\n            \"Expected `assets_store` to be an instance of \"\n            \"`DiskIOStore` or `None`. 
\"\n            f\"Received: {assets_store} of type {type(assets_store)}\"\n        )\n\n    # If the saveable has already been saved, skip it.\n    if id(saveable) in visited_saveables:\n        return\n\n    if hasattr(saveable, \"save_own_variables\") and weights_store:\n        if hasattr(saveable, \"name\") and isinstance(saveable.name, str):\n            metadata = {\"name\": saveable.name}\n        else:\n            metadata = None\n        saveable.save_own_variables(\n            weights_store.make(inner_path, metadata=metadata)\n        )\n    if hasattr(saveable, \"save_assets\") and assets_store:\n        saveable.save_assets(assets_store.make(inner_path))\n\n    visited_saveables.add(id(saveable))\n\n    # Recursively save state of children saveables (layers, optimizers, etc.)\n    for child_attr, child_obj in _walk_saveable(saveable):\n        if isinstance(child_obj, KerasSaveable):\n            _save_state(\n                child_obj,\n                weights_store,\n                assets_store,\n                inner_path=file_utils.join(inner_path, child_attr).replace(\n                    \"\\\\\", \"/\"\n                ),\n                visited_saveables=visited_saveables,\n            )\n        elif isinstance(child_obj, (list, dict, tuple, set)):\n            _save_container_state(\n                child_obj,\n                weights_store,\n                assets_store,\n                inner_path=file_utils.join(inner_path, child_attr).replace(\n                    \"\\\\\", \"/\"\n                ),\n                visited_saveables=visited_saveables,\n            )\n\n\ndef _load_state(\n    saveable,\n    weights_store,\n    assets_store,\n    inner_path,\n    skip_mismatch=False,\n    visited_saveables=None,\n    failed_saveables=None,\n    error_msgs=None,\n):\n    from keras.src.saving.keras_saveable import KerasSaveable\n\n    if not isinstance(\n        weights_store, (H5IOStore, ShardedH5IOStore, NpzIOStore, type(None))\n    
):\n        raise ValueError(\n            \"Expected `weights_store` to be an instance of \"\n            \"`H5IOStore`, `ShardedH5IOStore`, `NpzIOStore`, or `None`. \"\n            f\"Received: {weights_store} of type {type(weights_store)}\"\n        )\n    if not isinstance(assets_store, (DiskIOStore, type(None))):\n        raise ValueError(\n            \"Expected `assets_store` to be an instance of \"\n            \"`DiskIOStore` or `None`. \"\n            f\"Received: {assets_store} of type {type(assets_store)}\"\n        )\n\n    if visited_saveables and id(saveable) in visited_saveables:\n        return\n\n    failure = False\n\n    if hasattr(saveable, \"load_own_variables\") and weights_store:\n        if skip_mismatch or failed_saveables is not None:\n            try:\n                saveable.load_own_variables(weights_store.get(inner_path))\n            except Exception as e:\n                if failed_saveables is not None:\n                    failed_saveables.add(id(saveable))\n                error_msgs[id(saveable)] = saveable, e\n                failure = True\n        else:\n            saveable.load_own_variables(weights_store.get(inner_path))\n\n    if hasattr(saveable, \"load_assets\") and assets_store:\n        if skip_mismatch or failed_saveables is not None:\n            try:\n                saveable.load_assets(assets_store.get(inner_path))\n            except Exception as e:\n                if failed_saveables is not None:\n                    failed_saveables.add(id(saveable))\n                error_msgs[id(saveable)] = saveable, e\n                failure = True\n        else:\n            saveable.load_assets(assets_store.get(inner_path))\n\n    if failed_saveables is not None:\n        currently_failed = len(failed_saveables)\n    else:\n        currently_failed = 0\n\n    # Recursively load states for Keras saveables such as layers/optimizers.\n    for child_attr, child_obj in _walk_saveable(saveable):\n        if 
isinstance(child_obj, KerasSaveable):\n            _load_state(\n                child_obj,\n                weights_store,\n                assets_store,\n                inner_path=file_utils.join(inner_path, child_attr).replace(\n                    \"\\\\\", \"/\"\n                ),\n                skip_mismatch=skip_mismatch,\n                visited_saveables=visited_saveables,\n                failed_saveables=failed_saveables,\n                error_msgs=error_msgs,\n            )\n        elif isinstance(child_obj, (list, dict, tuple, set)):\n            _load_container_state(\n                child_obj,\n                weights_store,\n                assets_store,\n                inner_path=file_utils.join(inner_path, child_attr).replace(\n                    \"\\\\\", \"/\"\n                ),\n                skip_mismatch=skip_mismatch,\n                visited_saveables=visited_saveables,\n                failed_saveables=failed_saveables,\n                error_msgs=error_msgs,\n            )\n\n    if failed_saveables is not None:\n        newly_failed = len(failed_saveables) - currently_failed\n    else:\n        newly_failed = 0\n\n    if not failure:\n        if visited_saveables is not None and newly_failed <= 0:\n            visited_saveables.add(id(saveable))\n        if failed_saveables is not None and id(saveable) in failed_saveables:\n            failed_saveables.remove(id(saveable))\n            error_msgs.pop(id(saveable))\n\n\ndef _save_container_state(\n    container, weights_store, assets_store, inner_path, visited_saveables\n):\n    from keras.src.saving.keras_saveable import KerasSaveable\n\n    used_names = {}\n    if isinstance(container, dict):\n        container = list(container.values())\n\n    for saveable in container:\n        if isinstance(saveable, KerasSaveable):\n            # Do NOT address the saveable via `saveable.name`, since\n            # names are usually autogenerated and thus not reproducible\n            # 
(i.e. they may vary across two instances of the same model).\n            name = naming.to_snake_case(saveable.__class__.__name__)\n            if name in used_names:\n                used_names[name] += 1\n                name = f\"{name}_{used_names[name]}\"\n            else:\n                used_names[name] = 0\n            _save_state(\n                saveable,\n                weights_store,\n                assets_store,\n                inner_path=file_utils.join(inner_path, name).replace(\"\\\\\", \"/\"),\n                visited_saveables=visited_saveables,\n            )\n\n\ndef _load_container_state(\n    container,\n    weights_store,\n    assets_store,\n    inner_path,\n    skip_mismatch,\n    visited_saveables,\n    failed_saveables,\n    error_msgs,\n):\n    from keras.src.saving.keras_saveable import KerasSaveable\n\n    used_names = {}\n    if isinstance(container, dict):\n        container = list(container.values())\n\n    for saveable in container:\n        if isinstance(saveable, KerasSaveable):\n            name = naming.to_snake_case(saveable.__class__.__name__)\n            if name in used_names:\n                used_names[name] += 1\n                name = f\"{name}_{used_names[name]}\"\n            else:\n                used_names[name] = 0\n            _load_state(\n                saveable,\n                weights_store,\n                assets_store,\n                inner_path=file_utils.join(inner_path, name).replace(\"\\\\\", \"/\"),\n                skip_mismatch=skip_mismatch,\n                visited_saveables=visited_saveables,\n                failed_saveables=failed_saveables,\n                error_msgs=error_msgs,\n            )\n\n\nclass DiskIOStore:\n    \"\"\"Asset store backed by disk storage.\n\n    If `archive` is specified, then `root_path` refers to the filename\n    inside the archive.\n\n    If `archive` is not specified, then `root_path` refers to the full path of\n    the target directory.\n    \"\"\"\n\n   
 def __init__(self, root_path, archive=None, mode=None):\n        self.mode = mode\n        self.root_path = root_path\n        self.archive = archive\n        self.tmp_dir = None\n        if self.archive:\n            self.tmp_dir = get_temp_dir()\n            if self.mode == \"r\":\n                file_utils.extract_open_archive(self.archive, self.tmp_dir)\n            self.working_dir = file_utils.join(\n                self.tmp_dir, self.root_path\n            ).replace(\"\\\\\", \"/\")\n            if self.mode == \"w\":\n                file_utils.makedirs(self.working_dir)\n        else:\n            if mode == \"r\":\n                self.working_dir = root_path\n            else:\n                self.tmp_dir = get_temp_dir()\n                self.working_dir = file_utils.join(\n                    self.tmp_dir, self.root_path\n                ).replace(\"\\\\\", \"/\")\n                file_utils.makedirs(self.working_dir)\n\n    def make(self, path):\n        if not path:\n            return self.working_dir\n        path = file_utils.join(self.working_dir, path).replace(\"\\\\\", \"/\")\n        if not file_utils.exists(path):\n            file_utils.makedirs(path)\n        return path\n\n    def get(self, path):\n        if not path:\n            return self.working_dir\n        path = file_utils.join(self.working_dir, path).replace(\"\\\\\", \"/\")\n        if file_utils.exists(path):\n            return path\n        return None\n\n    def close(self):\n        if self.mode == \"w\" and self.archive:\n            _write_to_zip_recursively(\n                self.archive, self.working_dir, self.root_path\n            )\n        if self.tmp_dir and file_utils.exists(self.tmp_dir):\n            file_utils.rmtree(self.tmp_dir)\n\n\nclass H5IOStore:\n    \"\"\"Numerical variable store backed by HDF5.\n\n    Args:\n        path_or_io: `str`, `pathlib.Path` or `io.BytesIO` object. 
The path where\n            to save the model.\n        archive: Optional `zipfile.ZipFile` object. If specified, the h5 file\n            will be saved inside the archive and `path_or_io` will be used as\n            the filename.\n        mode: `str`. One of {`\"r\"`, `\"w\"`}. The mode to open the h5 file.\n            Defaults to `\"r\"`.\n    \"\"\"\n\n    def __init__(self, path_or_io, archive=None, mode=\"r\"):\n        if mode not in (\"w\", \"r\"):\n            raise ValueError(\n                f\"`mode` should be either 'w' or 'r'. Received: {mode}\"\n            )\n        if isinstance(path_or_io, (str, pathlib.Path)):\n            self.path_or_io = pathlib.Path(path_or_io)\n        elif isinstance(path_or_io, io.BytesIO):\n            if archive is not None:\n                raise ValueError(\n                    \"When `path_or_io` is an `io.BytesIO` object, `archive` \"\n                    \"should be `None`.\"\n                )\n            self.path_or_io = path_or_io\n        else:\n            raise TypeError(\n                \"`path_or_io` should be a `str`, `pathlib.Path` or \"\n                f\"`io.BytesIO` object. Received: path_or_io={path_or_io} of \"\n                f\"type {type(path_or_io)}.\"\n            )\n        self.mode = mode\n        self.archive = archive\n        self.io_file = None\n\n        # Init H5 file.\n        self.h5_file = self._get_h5_file(self.path_or_io)\n\n        # Init H5 entry group.\n        self._h5_entry_path = None\n        self._h5_entry_group = {}\n        self._h5_entry_metadata = None\n        self._h5_entry_initialized = False\n\n    def __bool__(self):\n        # Delegate `__bool__` to the underlying `h5_file`. 
Otherwise, Python\n        # will mistakenly use `__len__` to determine the value.\n        return self.h5_file.__bool__()\n\n    def _verify_group(self, group):\n        if not isinstance(group, h5py.Group):\n            raise ValueError(\n                f\"Invalid H5 file, expected Group but received {type(group)}\"\n            )\n        return group\n\n    def _verify_dataset(self, dataset):\n        if not isinstance(dataset, h5py.Dataset):\n            raise ValueError(\n                f\"Invalid H5 file, expected Dataset, received {type(dataset)}\"\n            )\n        if dataset.external:\n            raise ValueError(\n                \"Not allowed: H5 file Dataset with external links: \"\n                f\"{dataset.external}\"\n            )\n        return dataset\n\n    def _get_h5_file(self, path_or_io, mode=None):\n        mode = mode or self.mode\n        if mode not in (\"r\", \"w\", \"a\"):\n            raise ValueError(\n                f\"`mode` should be either 'r', 'w' or 'a'. Received: {mode}\"\n            )\n        if self.archive:\n            if mode == \"w\":\n                self.io_file = io.BytesIO()\n            else:\n                self.io_file = self.archive.open(str(path_or_io), \"r\")\n            return h5py.File(self.io_file, mode=mode)\n        else:\n            return h5py.File(path_or_io, mode=mode)\n\n    def make(self, path, metadata=None):\n        \"\"\"Make a new H5 entry group.\n\n        This method is only available in write mode. It defers the creation of\n        the H5 entry group until `__setitem__` is called, preventing the\n        creation of empty groups.\n\n        Args:\n            path: `str`. The variable path.\n            metadata: Optional `dict`. The metadata to save with the H5 entry\n                group. 
Defaults to `None`.\n        \"\"\"\n        if self.mode != \"w\":\n            raise ValueError(\"`make` is only allowed in write mode.\")\n        if not isinstance(metadata, (dict, type(None))):\n            raise ValueError(\n                f\"`metadata` should be a dict or `None`. Received: {metadata}\"\n            )\n\n        self._h5_entry_path = path\n        if metadata:\n            self._create_h5_group(path, metadata=metadata)\n        else:\n            # Defer to `__setitem__` for H5 group creation to prevent the\n            # creation of empty groups when the store is unused.\n            self._h5_entry_group = {}\n            self._h5_entry_initialized = False\n        return self\n\n    def get(self, path):\n        \"\"\"Get the H5 entry group.\n\n        This method is only available in read mode.\n\n        Args:\n            path: `str`. The variable path.\n        \"\"\"\n        if self.mode != \"r\":\n            raise ValueError(\"`get` is only allowed in read mode.\")\n\n        self._h5_entry_path = path\n        self._h5_entry_group = {}  # Defaults to an empty dict if not found.\n        if not path:\n            if \"vars\" in self.h5_file:\n                self._h5_entry_group = self._verify_group(self.h5_file[\"vars\"])\n        elif path in self.h5_file and \"vars\" in self.h5_file[path]:\n            self._h5_entry_group = self._verify_group(\n                self._verify_group(self.h5_file[path])[\"vars\"]\n            )\n        else:\n            # No hit. 
Fix for 2.13 compatibility.\n            if \"_layer_checkpoint_dependencies\" in self.h5_file:\n                path = path.replace(\"layers\", \"_layer_checkpoint_dependencies\")\n                if path in self.h5_file and \"vars\" in self.h5_file[path]:\n                    self._h5_entry_group = self._verify_group(\n                        self._verify_group(self.h5_file[path])[\"vars\"]\n                    )\n        self._h5_entry_initialized = True\n        return self\n\n    def close(self):\n        self.h5_file.close()\n        if self.mode == \"w\" and self.archive:\n            self.archive.writestr(str(self.path_or_io), self.io_file.getvalue())\n        if self.io_file:\n            self.io_file.close()\n\n    # H5 entry level methods.\n\n    def _create_h5_group(self, path, metadata=None):\n        if not path:\n            self._h5_entry_group = self.h5_file.create_group(\"vars\")\n        else:\n            self._h5_entry_group = self.h5_file.create_group(path).create_group(\n                \"vars\"\n            )\n        if metadata:\n            for k, v in metadata.items():\n                self._h5_entry_group.attrs[k] = v\n\n        self._h5_entry_initialized = True\n\n    def __len__(self):\n        return self._h5_entry_group.__len__()\n\n    def keys(self):\n        return self._h5_entry_group.keys()\n\n    def __getitem__(self, key):\n        value = self._verify_dataset(self._h5_entry_group[key])\n        if (\n            hasattr(value, \"attrs\")\n            and \"dtype\" in value.attrs\n            and value.attrs[\"dtype\"] == \"bfloat16\"\n        ):\n            value = np.array(value, dtype=ml_dtypes.bfloat16)\n        elif not isinstance(value, np.ndarray):\n            value = np.array(value)\n        return value\n\n    def __setitem__(self, key, value):\n        if self.mode not in (\"w\", \"a\"):\n            raise ValueError(\"Setting a value is only allowed in write mode.\")\n        if not self._h5_entry_initialized:\n  
self._create_h5_group(self._h5_entry_path)\n\n        value = backend.convert_to_numpy(value)\n        if backend.standardize_dtype(value.dtype) == \"bfloat16\":\n            ds = self._h5_entry_group.create_dataset(key, data=value)\n            ds.attrs[\"dtype\"] = \"bfloat16\"\n        else:\n            self._h5_entry_group[key] = value\n\n    def __delitem__(self, key):\n        if self.mode not in (\"w\", \"a\"):\n            raise ValueError(\"Deleting a value is only allowed in write mode.\")\n        del self._h5_entry_group[key]\n\n    def __contains__(self, item):\n        return item in self._h5_entry_group\n\n\nclass ShardedH5IOStore(H5IOStore):\n    \"\"\"Sharded numerical variable store backed by HDF5.\n\n    Args:\n        path_or_io: `str` or `pathlib.Path` object. The path where to save the\n            model.\n        max_shard_size: `int` or `float`. Maximum size in GB for each sharded\n            file. Defaults to `5`.\n        archive: Optional `zipfile.ZipFile` object. If specified, the h5 file\n            will be saved inside the archive and `path_or_io` will be used as\n            the filename.\n        mode: `str`. One of {'r', 'w'}. The mode to open the h5 file. Defaults\n            to `\"r\"`.\n    \"\"\"\n\n    def __init__(self, path_or_io, max_shard_size=5, archive=None, mode=\"r\"):\n        if mode not in (\"w\", \"r\"):\n            raise ValueError(\n                f\"`mode` should be either 'w' or 'r'. Received: {mode}\"\n            )\n        if not isinstance(path_or_io, (str, pathlib.Path)):\n            raise TypeError(\n                \"`path_or_io` should be a `str` or `pathlib.Path` object. 
\"\n                f\"Received: path_or_io={path_or_io} of type {type(path_or_io)}.\"\n            )\n        self.path = pathlib.Path(path_or_io)\n        self.mode = mode\n        self.archive = archive\n        self.io_file = None\n\n        self.max_shard_size = float(max_shard_size) * 1024**3  # To bytes.\n        self.base_name = self.path.stem.replace(\".weights\", \"\")\n\n        if self.path.suffix != \".json\":\n            method = \"Saving\" if self.mode == \"w\" else \"Loading\"\n            new_path = self.path.with_suffix(\".json\")\n            warnings.warn(\n                f\"{method} sharded weights requires `*.json` as the \"\n                f\"extension. The original path: {str(self.path)} will be \"\n                f\"renamed to {str(new_path)}.\"\n            )\n            self.path = new_path\n\n        # Init H5 entry group.\n        self._h5_entry_path = None\n        self._h5_entry_group = {}\n        self._h5_entry_metadata = None\n        self._h5_entry_initialized = False\n\n        # Init shard parameters.\n        self.current_shard_index = 0\n        self.current_shard_size = 0\n        self.total_shard_size = 0  # In bytes.\n        self.current_shard_path = None\n        self.current_shard_filenames = []\n        if self.mode == \"w\":\n            self.sharding_config = {\n                \"metadata\": {\n                    \"total_size\": 0,\n                },\n                \"weight_map\": {},\n            }\n        else:\n            if self.archive:\n                self.sharding_config = json.loads(\n                    self.archive.open(str(self.path), \"r\").read()\n                )\n            else:\n                with open(self.path, \"r\") as map_file:\n                    self.sharding_config = json.load(map_file)\n        self.h5_file = self._create_new_shard_file()\n\n    def make(self, path, metadata=None):\n        \"\"\"Make a new H5 entry group.\n\n        This method is only available in write 
mode. It defers the creation of\n        the H5 entry group until `__setitem__` is called, preventing the\n        creation of empty groups.\n\n        The information about the current shard is reset.\n\n        Args:\n            path: `str`. The variable path.\n            metadata: Optional `dict`. The metadata to save with the H5 entry\n                group. Defaults to `None`.\n        \"\"\"\n        self.current_shard_filenames = []\n        if self.h5_file is not None:\n            self.current_shard_filenames.append(\n                pathlib.Path(self.h5_file.filename).name\n            )\n        return super().make(path, metadata)\n\n    def get(self, path):\n        \"\"\"Get the H5 entry group.\n\n        This method is only available in read mode. If the path is not found in\n        the current shard, it will switch to the correct shard.\n\n        Args:\n            path: `str`. The variable path.\n        \"\"\"\n        if not path:\n            parsed_path = \"/vars\"\n        else:\n            parsed_path = path\n\n        # If not found, check shard map and switch files.\n        weight_map = self.sharding_config[\"weight_map\"]\n        filenames = weight_map.get(parsed_path) or weight_map.get(\n            f\"/{parsed_path}/vars\"\n        )\n        if filenames is not None:\n            if not isinstance(filenames, list):\n                filenames = [filenames]\n            self.current_shard_filenames = filenames\n            filename = filenames[0]\n        else:\n            self.current_shard_filenames = []\n            filename = None\n\n        if filename is not None and filename != self.current_shard_path.name:\n            self.close()\n            self.h5_file = self._get_h5_file(self.path.with_name(filename))\n        return super().get(path)\n\n    def close(self):\n        if self.h5_file is not None:\n            self.h5_file.close()\n            self.h5_file = None\n        if self.mode == \"w\":\n            
self.sharding_config[\"metadata\"][\"total_size\"] = (\n                self.total_shard_size\n            )\n            json_str = json.dumps(self.sharding_config, indent=4)\n            if self.archive:\n                self.archive.writestr(str(self.path), json_str)\n                self.archive.writestr(\n                    str(self.current_shard_path), self.io_file.getvalue()\n                )\n            else:\n                with open(self.path, \"w\") as f:\n                    f.write(json_str)\n        if self.io_file:\n            self.io_file.close()\n\n    # Shard-specific methods.\n\n    def _create_new_shard_file(self):\n        \"\"\"Create a new shard file and return the H5 file object.\"\"\"\n        new_shard_path = (\n            f\"{self.base_name}_{self.current_shard_index:05}.weights.h5\"\n        )\n        self.current_shard_index += 1\n        self.current_shard_path = self.path.with_name(new_shard_path)\n        h5_file = self._get_h5_file(self.current_shard_path)\n        self.current_shard_filenames.append(pathlib.Path(h5_file.filename).name)\n        self._h5_entry_initialized = False\n        return h5_file\n\n    def _switch_h5_file(self, filename, mode):\n        \"\"\"Switch to a different H5 file with the specified mode.\n\n        This is useful for retrieving information from all shards, such as the\n        total length, keys, and items.\n        \"\"\"\n        if mode not in (\"r\", \"a\"):\n            raise ValueError(\n                f\"`mode` should be either 'r' or 'a'. 
Received: {mode}\"\n            )\n        self.close()\n        self.h5_file = self._get_h5_file(\n            self.path.with_name(filename), mode=mode\n        )\n        self._get_h5_group(self._h5_entry_path)\n\n    def _restore_h5_file(self):\n        \"\"\"Ensure the current shard is the last one created.\"\"\"\n        if (\n            pathlib.Path(self.h5_file.filename).name\n            != self.current_shard_path.name\n        ):\n            mode = \"a\" if self.mode == \"w\" else \"r\"\n            self._switch_h5_file(self.current_shard_path.name, mode=mode)\n\n    # H5 entry level methods.\n\n    def _get_h5_group(self, path):\n        \"\"\"Get the H5 entry group. If it doesn't exist, return an empty dict.\"\"\"\n        try:\n            if not path:\n                self._h5_entry_group = self._verify_group(self.h5_file[\"vars\"])\n            else:\n                self._h5_entry_group = self._verify_group(\n                    self._verify_group(self.h5_file[path])[\"vars\"]\n                )\n            self._h5_entry_initialized = True\n        except KeyError:\n            self._h5_entry_group = {}\n            self._h5_entry_initialized = False\n\n    # Dict methods.\n\n    def __len__(self):\n        total_len = self._h5_entry_group.__len__()\n        for filename in self.current_shard_filenames:\n            if filename == self.current_shard_path.name:\n                continue\n            self._switch_h5_file(filename, mode=\"r\")\n            total_len += self._h5_entry_group.__len__()\n        self._restore_h5_file()\n        return total_len\n\n    def keys(self):\n        keys = []\n        current_shard_keys = list(self._h5_entry_group.keys())\n        for filename in self.current_shard_filenames:\n            if filename == self.current_shard_path.name:\n                keys += current_shard_keys\n            else:\n                self._switch_h5_file(filename, mode=\"r\")\n                keys += 
list(self._h5_entry_group.keys())\n        self._restore_h5_file()\n        return keys\n\n    def __getitem__(self, key):\n        if key in self._h5_entry_group:\n            return super().__getitem__(key)\n\n        for filename in self.current_shard_filenames:\n            if filename == self.current_shard_path.name:\n                continue\n            self._switch_h5_file(filename, mode=\"r\")\n            if key in self._h5_entry_group:\n                item = super().__getitem__(key)\n                self._restore_h5_file()\n                return item\n        raise KeyError(\n            f\"Key '{key}' not found in any of the shards: \"\n            f\"{self.current_shard_filenames}\"\n        )\n\n    def __setitem__(self, key, value):\n        self._restore_h5_file()\n\n        # Accumulate `current_shard_size`.\n        value = backend.convert_to_numpy(value)\n        dtype = backend.standardize_dtype(value.dtype)\n        weight_counts = math.prod(value.shape)\n        per_param_size = dtype_utils.dtype_size(dtype)\n        value_size = weight_counts * per_param_size / 8  # In bytes.\n        self.total_shard_size += value_size\n        if value_size > self.max_shard_size:\n            value_size_str = readable_memory_size(value_size)\n            max_shard_size_str = readable_memory_size(self.max_shard_size)\n            raise ValueError(\n                f\"The size of {key} is {value_size_str} which \"\n                f\"exceeds the maximum shard size {max_shard_size_str}. 
You \"\n                \"can increase the `max_shard_size` parameter to accommodate \"\n                \"the size.\"\n            )\n\n        # Create a new shard if the current shard is full.\n        self.current_shard_size += value_size\n        if self.current_shard_size > self.max_shard_size:\n            self.close()\n            self.h5_file = self._create_new_shard_file()\n            self.current_shard_size = value_size\n\n        super().__setitem__(key, value)\n\n        # Update the weight map.\n        variable_path = self._h5_entry_group.name\n        shard_filename = self.current_shard_path.name\n        weight_map = self.sharding_config[\"weight_map\"]\n        if variable_path not in weight_map:\n            weight_map[variable_path] = shard_filename\n        else:\n            if not isinstance(weight_map[variable_path], list):\n                weight_map[variable_path] = [weight_map[variable_path]]\n            if shard_filename not in weight_map[variable_path]:\n                weight_map[variable_path].append(shard_filename)\n\n    def __delitem__(self, key):\n        if key in self._h5_entry_group:\n            super().__delitem__(key)\n            return\n\n        for filename in self.current_shard_filenames:\n            if filename == self.current_shard_path.name:\n                continue\n            self._switch_h5_file(filename, mode=\"a\")\n            if key in self._h5_entry_group:\n                super().__delitem__(key)\n                self._restore_h5_file()\n                return\n        raise KeyError(\n            f\"Key '{key}' not found in any of the shards: \"\n            f\"{self.current_shard_filenames}\"\n        )\n\n    def __contains__(self, item):\n        if item in self._h5_entry_group:\n            return True\n\n        for filename in self.current_shard_filenames:\n            if filename == self.current_shard_path.name:\n                continue\n            self._switch_h5_file(filename, mode=\"r\")\n  
          if item in self._h5_entry_group:\n                self._restore_h5_file()\n                return True\n        self._restore_h5_file()\n        return False\n\n\nclass NpzIOStore:\n    def __init__(self, root_path, archive=None, mode=\"r\"):\n        \"\"\"Numerical variable store backed by NumPy.savez/load.\n\n         If `archive` is specified, then `root_path` refers to the filename\n        inside the archive.\n\n        If `archive` is not specified, then `root_path` refers to the path of\n        the npz file on disk.\n        \"\"\"\n        self.root_path = root_path\n        self.mode = mode\n        self.archive = archive\n        if mode == \"w\":\n            self.contents = {}\n        else:\n            if self.archive:\n                self.f = archive.open(root_path, mode=\"r\")\n            else:\n                self.f = open(root_path, mode=\"rb\")\n            self.contents = np.load(self.f)\n\n    def make(self, path, metadata=None):\n        if not path:\n            self.contents[\"__root__\"] = {}\n            return self.contents[\"__root__\"]\n        self.contents[path] = {}\n        return self.contents[path]\n\n    def get(self, path):\n        if not path:\n            if \"__root__\" in self.contents:\n                return dict(self.contents[\"__root__\"])\n            return {}\n        if path in self.contents:\n            return self.contents[path].tolist()\n        return {}\n\n    def close(self):\n        if self.mode == \"w\":\n            if self.archive:\n                self.f = self.archive.open(\n                    self.root_path, mode=\"w\", force_zip64=True\n                )\n            else:\n                self.f = open(self.root_path, mode=\"wb\")\n            np.savez(self.f, **self.contents)\n        self.f.close()\n\n\ndef get_temp_dir():\n    temp_dir = tempfile.mkdtemp()\n    testfile = tempfile.TemporaryFile(dir=temp_dir)\n    testfile.close()\n    return temp_dir\n\n\ndef 
get_attr_skipset(obj_type):\n    skipset = global_state.get_global_attribute(\n        f\"saving_attr_skiplist_{obj_type}\", None\n    )\n    if skipset is not None:\n        return skipset\n\n    skipset = set(\n        [\n            \"_self_unconditional_dependency_names\",\n        ]\n    )\n    if obj_type == \"Operation\":\n        from keras.src.ops.operation import Operation\n\n        ref_obj = Operation()\n        skipset.update(dir(ref_obj))\n    elif obj_type == \"Layer\":\n        from keras.src.layers.layer import Layer\n\n        ref_obj = Layer()\n        skipset.update(dir(ref_obj))\n    elif obj_type == \"Functional\":\n        from keras.src.layers.layer import Layer\n\n        ref_obj = Layer()\n        skipset.update(dir(ref_obj) + [\"operations\", \"_operations\"])\n    elif obj_type == \"Sequential\":\n        from keras.src.layers.layer import Layer\n\n        ref_obj = Layer()\n        skipset.update(dir(ref_obj) + [\"_functional\"])\n    elif obj_type == \"Metric\":\n        from keras.src.metrics.metric import Metric\n        from keras.src.trainers.compile_utils import CompileMetrics\n\n        ref_obj_a = Metric()\n        ref_obj_b = CompileMetrics([], [])\n        skipset.update(dir(ref_obj_a) + dir(ref_obj_b))\n    elif obj_type == \"Optimizer\":\n        from keras.src.optimizers.optimizer import Optimizer\n\n        ref_obj = Optimizer(1.0)\n        skipset.update(dir(ref_obj))\n        skipset.remove(\"variables\")\n    elif obj_type == \"Loss\":\n        from keras.src.losses.loss import Loss\n\n        ref_obj = Loss()\n        skipset.update(dir(ref_obj))\n    elif obj_type == \"Cross\":\n        from keras.src.layers.preprocessing.feature_space import Cross\n\n        ref_obj = Cross((), 1)\n        skipset.update(dir(ref_obj))\n    elif obj_type == \"Feature\":\n        from keras.src.layers.preprocessing.feature_space import Feature\n\n        ref_obj = Feature(\"int32\", lambda x: x, \"int\")\n        
skipset.update(dir(ref_obj))\n    else:\n        raise ValueError(\n            f\"get_attr_skipset got invalid {obj_type=}. \"\n            \"Accepted values for `obj_type` are \"\n            \"['Operation', 'Layer', 'Functional', 'Sequential', 'Metric', \"\n            \"'Optimizer', 'Loss', 'Cross', 'Feature']\"\n        )\n\n    # Use the same key as the lookup above so the cached skipset is found.\n    global_state.set_global_attribute(\n        f\"saving_attr_skiplist_{obj_type}\", skipset\n    )\n    return skipset\n\n\ndef is_memory_sufficient(model):\n    \"\"\"Check if there is sufficient memory to load the model into memory.\n\n    If psutil is installed, we can use it to determine whether the memory is\n    sufficient. Otherwise, we use a predefined value of 1 GB for available\n    memory.\n    \"\"\"\n    if psutil is None:\n        available_memory = 1024 * 1024 * 1024  # 1 GB in bytes\n    else:\n        available_memory = psutil.virtual_memory().available  # In bytes\n    return (\n        weight_memory_size(model.variables)\n        < available_memory * _MEMORY_UPPER_BOUND\n    )\n\n\ndef _split_path_components(path):\n    \"\"\"Split a relative path into a list of individual components.\n\n    Uses ``os.path.split`` iteratively so the result is independent of\n    the platform path separator.\n\n    Example::\n\n        _split_path_components(\"a/b/c.txt\") -> [\"a\", \"b\", \"c.txt\"]\n    \"\"\"\n    parts = []\n    while True:\n        head, tail = os.path.split(path)\n        if tail:\n            parts.append(tail)\n        elif head:\n            parts.append(head)\n            break\n        else:\n            break\n        path = head\n    parts.reverse()\n    return parts\n\n\ndef _write_nested_dict_to_dir(tree, base_dir):\n    \"\"\"Recursively write a nested dict of numpy arrays to a directory tree.\n\n    Each dict key becomes a directory or filename. 
Leaf values (numpy\n    arrays) are written as binary files.\n    \"\"\"\n    for key, value in tree.items():\n        child_path = os.path.join(base_dir, key)\n        if isinstance(value, dict):\n            os.makedirs(child_path, exist_ok=True)\n            _write_nested_dict_to_dir(value, child_path)\n        elif isinstance(value, np.ndarray):\n            os.makedirs(os.path.dirname(child_path), exist_ok=True)\n            with open(child_path, \"wb\") as f:\n                f.write(value.tobytes())\n\n\ndef _save_assets_to_dict(model):\n    \"\"\"Save model assets to a nested dictionary.\n\n    Collects assets (e.g. vocabulary files) from the model using the\n    Keras ``_save_state`` mechanism and returns them as a nested dictionary\n    that mirrors the directory hierarchy. Leaf values are numpy uint8\n    arrays containing file contents.\n\n    For example, a file at ``layer/sublayer/vocab.txt`` is stored as::\n\n        {\"layer\": {\"sublayer\": {\"vocab.txt\": np.array([...])}}}\n\n    Using a nested structure instead of flat path keys avoids\n    platform-specific path separator issues and zip-slip vulnerabilities.\n\n    Args:\n        model: The model whose assets should be collected.\n\n    Returns:\n        A nested dictionary of numpy uint8 arrays mirroring the asset\n        directory tree, or ``None`` if the model has no assets.\n    \"\"\"\n    assets_store = DiskIOStore(\"assets\", mode=\"w\")\n    try:\n        _save_state(\n            model,\n            weights_store=None,\n            assets_store=assets_store,\n            inner_path=\"\",\n            visited_saveables=set(),\n        )\n\n        assets_tree = {}\n        working_dir = assets_store.working_dir\n        for root, dirs, files in os.walk(working_dir):\n            for fname in files:\n                file_path = os.path.join(root, fname)\n                rel = os.path.relpath(file_path, working_dir)\n                parts = _split_path_components(rel)\n                
with open(file_path, \"rb\") as f:\n                    data = np.frombuffer(f.read(), dtype=np.uint8)\n                node = assets_tree\n                for part in parts[:-1]:\n                    node = node.setdefault(part, {})\n                node[parts[-1]] = data\n\n        return assets_tree if assets_tree else None\n    finally:\n        assets_store.close()\n\n\ndef _load_assets_from_dict(model, assets_dict):\n    \"\"\"Load assets from a nested dictionary into the model.\n\n    Reconstructs the asset directory tree from a nested dictionary\n    produced by ``_save_assets_to_dict`` and loads the assets into\n    the model via the Keras ``_load_state`` mechanism.\n\n    Args:\n        model: The model to load assets into.\n        assets_dict: Nested dictionary mirroring the asset directory\n            tree, with numpy uint8 arrays as leaf values.\n    \"\"\"\n    if not assets_dict:\n        return\n\n    with tempfile.TemporaryDirectory() as tmp_dir:\n        _write_nested_dict_to_dir(assets_dict, tmp_dir)\n\n        assets_store = DiskIOStore(tmp_dir, mode=\"r\")\n        _load_state(\n            model,\n            weights_store=None,\n            assets_store=assets_store,\n            inner_path=\"\",\n            visited_saveables=set(),\n            failed_saveables=set(),\n            error_msgs={},\n        )\n        assets_store.close()\n"
  },
  {
    "path": "keras/src/saving/saving_lib_test.py",
    "content": "\"\"\"Tests for Keras python-based idempotent saving functions.\"\"\"\n\nimport json\nimport os\nimport warnings\nimport zipfile\nfrom io import BytesIO\nfrom pathlib import Path\nfrom unittest import mock\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.saving import saving_lib\n\n\n@keras.saving.register_keras_serializable(package=\"my_custom_package\")\nclass MyDense(keras.layers.Layer):\n    def __init__(self, units, **kwargs):\n        super().__init__(**kwargs)\n        self.units = units\n        self.nested_layer = keras.layers.Dense(self.units, name=\"dense\")\n\n    def build(self, input_shape):\n        self.additional_weights = [\n            self.add_weight(\n                shape=(),\n                name=\"my_additional_weight\",\n                initializer=\"ones\",\n                trainable=True,\n            ),\n            self.add_weight(\n                shape=(),\n                name=\"my_additional_weight_2\",\n                initializer=\"ones\",\n                trainable=True,\n            ),\n        ]\n        self.weights_in_dict = {\n            \"my_weight\": self.add_weight(\n                shape=(),\n                name=\"my_dict_weight\",\n                initializer=\"ones\",\n                trainable=True,\n            ),\n        }\n        self.nested_layer.build(input_shape)\n\n    def call(self, inputs):\n        return self.nested_layer(inputs)\n\n    def two(self):\n        return 2\n\n\nASSETS_DATA = \"These are my assets\"\nVARIABLES_DATA = np.random.random((10,))\n\n\n@keras.saving.register_keras_serializable(package=\"my_custom_package\")\nclass LayerWithCustomSaving(MyDense):\n    def build(self, input_shape):\n        self.assets = ASSETS_DATA\n        self.stored_variables = VARIABLES_DATA\n        return super().build(input_shape)\n\n    def 
save_assets(self, inner_path):\n        with open(os.path.join(inner_path, \"assets.txt\"), \"w\") as f:\n            f.write(self.assets)\n\n    def save_own_variables(self, store):\n        store[\"variables\"] = self.stored_variables\n\n    def load_assets(self, inner_path):\n        with open(os.path.join(inner_path, \"assets.txt\"), \"r\") as f:\n            text = f.read()\n        self.assets = text\n\n    def load_own_variables(self, store):\n        self.stored_variables = np.array(store[\"variables\"])\n\n\n@keras.saving.register_keras_serializable(package=\"my_custom_package\")\nclass CustomModelX(keras.Model):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.dense1 = MyDense(1, name=\"my_dense_1\")\n        self.dense2 = MyDense(1, name=\"my_dense_2\")\n\n    def call(self, inputs):\n        out = self.dense1(inputs)\n        return self.dense2(out)\n\n    def one(self):\n        return 1\n\n\n@keras.saving.register_keras_serializable(package=\"my_custom_package\")\nclass ModelWithCustomSaving(keras.Model):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.custom_dense = LayerWithCustomSaving(1)\n\n    def call(self, inputs):\n        return self.custom_dense(inputs)\n\n\n@keras.saving.register_keras_serializable(package=\"my_custom_package\")\nclass CompileOverridingModel(keras.Model):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.dense1 = MyDense(1)\n\n    def compile(self, *args, **kwargs):\n        super().compile(*args, **kwargs)\n\n    def call(self, inputs):\n        return self.dense1(inputs)\n\n\n@keras.saving.register_keras_serializable(package=\"my_custom_package\")\nclass CompileOverridingSequential(keras.Sequential):\n    def compile(self, *args, **kwargs):\n        super().compile(*args, **kwargs)\n\n\n@keras.saving.register_keras_serializable(package=\"my_custom_package\")\nclass 
SubclassFunctional(keras.Model):\n    \"\"\"Subclassed functional identical to `_get_basic_functional_model`.\"\"\"\n\n    def __init__(self, **kwargs):\n        inputs = keras.Input(shape=(4,), batch_size=2)\n        dense = keras.layers.Dense(1, name=\"first_dense\")\n        x = dense(inputs)\n        outputs = keras.layers.Dense(1, name=\"second_dense\")(x)\n        super().__init__(inputs=inputs, outputs=outputs, **kwargs)\n        # Attrs for layers in the functional graph should not affect saving\n        self.layer_attr = dense\n\n    @property\n    def layer_property(self):\n        # Properties for layers in the functional graph should not affect saving\n        return self.layer_attr\n\n    def get_config(self):\n        return {}\n\n    @classmethod\n    def from_config(cls, config):\n        return cls(**config)\n\n\n@keras.saving.register_keras_serializable(package=\"my_custom_package\")\ndef my_mean_squared_error(y_true, y_pred):\n    \"\"\"Identical to built-in `mean_squared_error`, but as a custom fn.\"\"\"\n    return ops.mean(ops.square(y_pred - y_true), axis=-1)\n\n\ndef _get_subclassed_model(compile=True):\n    subclassed_model = CustomModelX(name=\"custom_model_x\")\n    if compile:\n        subclassed_model.compile(\n            optimizer=\"adam\",\n            loss=my_mean_squared_error,\n            metrics=[keras.metrics.Hinge(), \"mse\"],\n        )\n    return subclassed_model\n\n\ndef _get_custom_sequential_model(compile=True):\n    sequential_model = keras.Sequential(\n        [MyDense(1), MyDense(1)], name=\"sequential\"\n    )\n    if compile:\n        sequential_model.compile(\n            optimizer=\"adam\",\n            loss=my_mean_squared_error,\n            metrics=[keras.metrics.Hinge(), \"mse\"],\n        )\n    return sequential_model\n\n\ndef _get_basic_sequential_model(compile=True):\n    sequential_model = keras.Sequential(\n        [\n            keras.layers.Dense(1, name=\"dense_1\"),\n            keras.layers.Dense(1, 
name=\"dense_2\"),\n        ],\n        name=\"sequential\",\n    )\n    if compile:\n        sequential_model.compile(\n            optimizer=\"adam\",\n            loss=my_mean_squared_error,\n            metrics=[keras.metrics.Hinge(), \"mse\"],\n        )\n    return sequential_model\n\n\ndef _get_custom_functional_model(compile=True):\n    inputs = keras.Input(shape=(4,), batch_size=2)\n    x = MyDense(1, name=\"first_dense\")(inputs)\n    outputs = MyDense(1, name=\"second_dense\")(x)\n    functional_model = keras.Model(inputs, outputs)\n    if compile:\n        functional_model.compile(\n            optimizer=\"adam\",\n            loss=my_mean_squared_error,\n            metrics=[keras.metrics.Hinge(), \"mse\"],\n        )\n    return functional_model\n\n\ndef _get_basic_functional_model(compile=True):\n    inputs = keras.Input(shape=(4,), batch_size=2)\n    x = keras.layers.Dense(1, name=\"first_dense\")(inputs)\n    outputs = keras.layers.Dense(1, name=\"second_dense\")(x)\n    functional_model = keras.Model(inputs, outputs)\n    if compile:\n        functional_model.compile(\n            optimizer=\"adam\",\n            loss=my_mean_squared_error,\n            metrics=[keras.metrics.Hinge(), \"mse\"],\n        )\n    return functional_model\n\n\ndef _get_subclassed_functional_model(compile=True):\n    functional_model = SubclassFunctional()\n    if compile:\n        functional_model.compile(\n            optimizer=\"adam\",\n            loss=my_mean_squared_error,\n            metrics=[keras.metrics.Hinge(), \"mse\"],\n        )\n    return functional_model\n\n\n# We need a global function for `Pool.apply_async`\ndef _load_model_fn(filepath):\n    saving_lib.load_model(filepath)\n\n\nclass SavingTest(testing.TestCase):\n    def setUp(self):\n        # Set `_MEMORY_UPPER_BOUND` to zero for testing purpose.\n        self.original_value = saving_lib._MEMORY_UPPER_BOUND\n        saving_lib._MEMORY_UPPER_BOUND = 0\n        return super().setUp()\n\n    def 
tearDown(self):\n        saving_lib._MEMORY_UPPER_BOUND = self.original_value\n        return super().tearDown()\n\n    def _test_inference_after_instantiation(self, model):\n        x_ref = np.random.random((2, 4))\n        y_ref = model(x_ref)\n        temp_filepath = os.path.join(self.get_temp_dir(), \"my_model.keras\")\n        model.save(temp_filepath)\n\n        loaded_model = saving_lib.load_model(temp_filepath)\n        self.assertFalse(model.compiled)\n        for w_ref, w in zip(model.variables, loaded_model.variables):\n            self.assertAllClose(w_ref, w)\n        self.assertAllClose(y_ref, loaded_model(x_ref))\n\n    @parameterized.named_parameters(\n        (\"subclassed\", _get_subclassed_model),\n        (\"basic_sequential\", _get_basic_sequential_model),\n        (\"basic_functional\", _get_basic_functional_model),\n        (\"custom_sequential\", _get_custom_sequential_model),\n        (\"custom_functional\", _get_custom_functional_model),\n        (\"subclassed_functional\", _get_subclassed_functional_model),\n    )\n    def test_inference_after_instantiation(self, model_fn):\n        model = model_fn(compile=False)\n        self._test_inference_after_instantiation(model)\n\n        # Test small model path\n        saving_lib._MEMORY_UPPER_BOUND = 1.0\n        self._test_inference_after_instantiation(model)\n\n    def _test_compile_preserved(self, model):\n        x_ref = np.random.random((2, 4))\n        y_ref = np.random.random((2, 1))\n\n        model.fit(x_ref, y_ref)\n        out_ref = model(x_ref)\n        ref_metrics = model.evaluate(x_ref, y_ref)\n        temp_filepath = os.path.join(self.get_temp_dir(), \"my_model.keras\")\n        model.save(temp_filepath)\n\n        loaded_model = saving_lib.load_model(temp_filepath)\n        self.assertTrue(model.compiled)\n        self.assertTrue(loaded_model.built)\n        for w_ref, w in zip(model.variables, loaded_model.variables):\n            self.assertAllClose(w_ref, w)\n        
self.assertAllClose(out_ref, loaded_model(x_ref))\n\n        self.assertEqual(\n            model.optimizer.__class__, loaded_model.optimizer.__class__\n        )\n        self.assertEqual(\n            model.optimizer.get_config(), loaded_model.optimizer.get_config()\n        )\n        for w_ref, w in zip(\n            model.optimizer.variables, loaded_model.optimizer.variables\n        ):\n            self.assertAllClose(w_ref, w)\n\n        new_metrics = loaded_model.evaluate(x_ref, y_ref)\n        for ref_m, m in zip(ref_metrics, new_metrics):\n            self.assertAllClose(ref_m, m)\n\n    @parameterized.named_parameters(\n        (\"subclassed\", _get_subclassed_model),\n        (\"basic_sequential\", _get_basic_sequential_model),\n        (\"basic_functional\", _get_basic_functional_model),\n        (\"custom_sequential\", _get_custom_sequential_model),\n        (\"custom_functional\", _get_custom_functional_model),\n        (\"subclassed_functional\", _get_subclassed_functional_model),\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_compile_preserved(self, model_fn):\n        model = model_fn(compile=True)\n        self._test_compile_preserved(model)\n\n        # Test small model path\n        saving_lib._MEMORY_UPPER_BOUND = 1.0\n        self._test_compile_preserved(model)\n\n    def test_saving_preserve_unbuilt_state(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"my_model.keras\")\n        subclassed_model = CustomModelX()\n        subclassed_model.save(temp_filepath)\n        loaded_model = saving_lib.load_model(temp_filepath)\n        self.assertEqual(subclassed_model.compiled, loaded_model.compiled)\n        self.assertFalse(subclassed_model.built)\n        self.assertFalse(loaded_model.built)\n\n    @pytest.mark.requires_trainable_backend\n    def test_saved_module_paths_and_class_names(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"my_model.keras\")\n        subclassed_model = 
_get_subclassed_model()\n        x = np.random.random((100, 32))\n        y = np.random.random((100, 1))\n        subclassed_model.fit(x, y, epochs=1)\n        subclassed_model.save(temp_filepath)\n\n        with zipfile.ZipFile(temp_filepath, \"r\") as z:\n            with z.open(saving_lib._CONFIG_FILENAME, \"r\") as c:\n                config_json = c.read()\n        config_dict = json.loads(config_json)\n        self.assertEqual(\n            config_dict[\"registered_name\"], \"my_custom_package>CustomModelX\"\n        )\n        self.assertEqual(\n            config_dict[\"compile_config\"][\"optimizer\"],\n            keras.src.saving.serialize_keras_object(\n                keras.src.optimizers.get(\"adam\")\n            ),\n        )\n        self.assertEqual(\n            config_dict[\"compile_config\"][\"loss\"][\"config\"],\n            \"my_custom_package>my_mean_squared_error\",\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_saving_custom_assets_and_variables(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"my_model.keras\")\n        model = ModelWithCustomSaving()\n        model.compile(\n            optimizer=\"adam\",\n            loss=\"mse\",\n        )\n        x = np.random.random((100, 32))\n        y = np.random.random((100, 1))\n        model.fit(x, y, epochs=1)\n\n        # Assert that the archive has not been saved.\n        self.assertFalse(os.path.exists(temp_filepath))\n\n        model.save(temp_filepath)\n\n        loaded_model = saving_lib.load_model(temp_filepath)\n        self.assertEqual(loaded_model.custom_dense.assets, ASSETS_DATA)\n        self.assertEqual(\n            loaded_model.custom_dense.stored_variables.tolist(),\n            VARIABLES_DATA.tolist(),\n        )\n\n    def _test_compile_overridden_warnings(self, model_type):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"my_model.keras\")\n        model = (\n            CompileOverridingModel()\n            if 
model_type == \"subclassed\"\n            else CompileOverridingSequential(\n                [keras.layers.Embedding(4, 1), MyDense(1), MyDense(1)]\n            )\n        )\n        model.compile(\"sgd\", \"mse\")\n        model.save(temp_filepath)\n\n        with mock.patch.object(warnings, \"warn\") as mock_warn:\n            saving_lib.load_model(temp_filepath)\n        if not mock_warn.call_args_list:\n            raise AssertionError(\"Did not warn.\")\n        self.assertIn(\n            \"`compile()` was not called as part of model loading \"\n            \"because the model's `compile()` method is custom. \",\n            mock_warn.call_args_list[0][0][0],\n        )\n\n    def test_compile_overridden_warnings_sequential(self):\n        self._test_compile_overridden_warnings(\"sequential\")\n\n    def test_compile_overridden_warnings_subclassed(self):\n        self._test_compile_overridden_warnings(\"subclassed\")\n\n    def test_metadata(self):\n        temp_filepath = Path(\n            os.path.join(self.get_temp_dir(), \"my_model.keras\")\n        )\n        model = CompileOverridingModel()\n        model.save(temp_filepath)\n        with zipfile.ZipFile(temp_filepath, \"r\") as z:\n            with z.open(saving_lib._METADATA_FILENAME, \"r\") as c:\n                metadata_json = c.read()\n        metadata = json.loads(metadata_json)\n        self.assertIn(\"keras_version\", metadata)\n        self.assertIn(\"date_saved\", metadata)\n\n    # def test_gfile_copy_local_called(self):\n    #     temp_filepath = Path(\n    #         os.path.join(self.get_temp_dir(), \"my_model.keras\")\n    #     )\n    #     model = CompileOverridingModel()\n    #     with mock.patch(\n    #         \"re.match\", autospec=True\n    #     ) as mock_re_match, mock.patch(\n    #         \"tensorflow.compat.v2.io.file_utils.copy\", autospec=True\n    #     ) as mock_copy:\n    #         # Mock Remote Path check to true to test gfile copy logic\n    #         
mock_re_match.return_value = True\n    #         model.save(temp_filepath)\n    #         mock_re_match.assert_called()\n    #         mock_copy.assert_called()\n    #         self.assertIn(str(temp_filepath), mock_re_match.call_args.args)\n    #         self.assertIn(str(temp_filepath), mock_copy.call_args.args)\n\n    def test_save_load_weights_only(self):\n        temp_filepath = Path(\n            os.path.join(self.get_temp_dir(), \"mymodel.weights.h5\")\n        )\n        model = _get_basic_functional_model()\n        ref_input = np.random.random((2, 4))\n        ref_output = model.predict(ref_input)\n        saving_lib.save_weights_only(model, temp_filepath)\n        model = _get_basic_functional_model()\n        saving_lib.load_weights_only(model, temp_filepath)\n        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)\n        # Test with Model method\n        model = _get_basic_functional_model()\n        model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)\n\n    def test_save_weights_only_with_unbuilt_model(self):\n        temp_filepath = Path(\n            os.path.join(self.get_temp_dir(), \"mymodel.weights.h5\")\n        )\n        model = _get_subclassed_model()\n        with self.assertRaisesRegex(\n            ValueError, \"You are saving a model that has not yet been built.\"\n        ):\n            saving_lib.save_weights_only(model, temp_filepath)\n\n    def test_load_weights_only_with_unbuilt_model(self):\n        temp_filepath = Path(\n            os.path.join(self.get_temp_dir(), \"mymodel.weights.h5\")\n        )\n        model = _get_subclassed_model()\n        x = np.random.random((100, 32))\n        _ = model.predict(x)  # Build the model by calling it on some data\n        saving_lib.save_weights_only(model, temp_filepath)\n        saving_lib.load_weights_only(model, temp_filepath)\n\n        new_model = _get_subclassed_model()\n        with 
self.assertRaisesRegex(\n            ValueError,\n            \"You are loading weights into a model that has not yet been built.\",\n        ):\n            saving_lib.load_weights_only(new_model, temp_filepath)\n\n    def test_load_weights_only_with_keras_file(self):\n        # Test loading weights from whole saved model\n        temp_filepath = Path(os.path.join(self.get_temp_dir(), \"mymodel.keras\"))\n        model = _get_basic_functional_model()\n        ref_input = np.random.random((2, 4))\n        ref_output = model.predict(ref_input)\n        saving_lib.save_model(model, temp_filepath)\n        model = _get_basic_functional_model()\n        saving_lib.load_weights_only(model, temp_filepath)\n        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)\n        # Test with Model method\n        model = _get_basic_functional_model()\n        model.load_weights(temp_filepath)\n        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)\n\n    def test_save_weights_subclassed_functional(self):\n        # The subclassed and basic functional model should have the same\n        # weights structure.\n        temp_filepath = Path(\n            os.path.join(self.get_temp_dir(), \"mymodel.weights.h5\")\n        )\n        model = _get_basic_functional_model()\n        ref_input = np.random.random((2, 4))\n        ref_output = model.predict(ref_input)\n        # Test saving basic, loading subclassed.\n        saving_lib.save_weights_only(model, temp_filepath)\n        model = _get_subclassed_functional_model()\n        saving_lib.load_weights_only(model, temp_filepath)\n        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)\n        # Test saving subclassed, loading basic.\n        saving_lib.save_weights_only(model, temp_filepath)\n        model = _get_basic_functional_model()\n        saving_lib.load_weights_only(model, temp_filepath)\n        self.assertAllClose(model.predict(ref_input), ref_output, 
atol=1e-6)\n\n    @pytest.mark.requires_trainable_backend\n    def test_compile_arg(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel.keras\")\n        model = _get_basic_functional_model()\n        model.compile(\"sgd\", \"mse\")\n        model.fit(np.random.random((2, 4)), np.random.random((2, 1)))\n        saving_lib.save_model(model, temp_filepath)\n\n        model = saving_lib.load_model(temp_filepath)\n        self.assertEqual(model.compiled, True)\n        model = saving_lib.load_model(temp_filepath, compile=False)\n        self.assertEqual(model.compiled, False)\n\n    # def test_overwrite(self):\n    #     temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel.keras\")\n    #     model = _get_basic_functional_model()\n    #     model.save(temp_filepath)\n    #     model.save(temp_filepath, overwrite=True)\n    #     with self.assertRaises(EOFError):\n    #         model.save(temp_filepath, overwrite=False)\n\n    #     temp_filepath = os.path.join(\n    #         self.get_temp_dir(), \"mymodel.weights.h5\"\n    #     )\n    #     model = _get_basic_functional_model()\n    #     model.save_weights(temp_filepath)\n    #     model.save_weights(temp_filepath, overwrite=True)\n    #     with self.assertRaises(EOFError):\n    #         model.save_weights(temp_filepath, overwrite=False)\n\n    def test_partial_load(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel.keras\")\n        original_model = keras.Sequential(\n            [\n                keras.Input(shape=(3,), batch_size=2),\n                keras.layers.Dense(4),\n                keras.layers.Dense(5),\n            ]\n        )\n        original_model.save(temp_filepath)\n\n        # Test with a model that has a differently shaped layer\n        new_model = keras.Sequential(\n            [\n                keras.Input(shape=(3,), batch_size=2),\n                keras.layers.Dense(4),\n                keras.layers.Dense(6),\n            ]\n       
 )\n        new_layer_kernel_value = np.array(new_model.layers[1].kernel)\n        with self.assertRaisesRegex(ValueError, \"must match\"):\n            # Doesn't work by default\n            new_model.load_weights(temp_filepath)\n        # Now it works\n        new_model.load_weights(temp_filepath, skip_mismatch=True)\n        ref_weights = original_model.layers[0].get_weights()\n        new_weights = new_model.layers[0].get_weights()\n        self.assertEqual(len(ref_weights), len(new_weights))\n        for ref_w, w in zip(ref_weights, new_weights):\n            self.assertAllClose(ref_w, w)\n        self.assertAllClose(\n            np.array(new_model.layers[1].kernel), new_layer_kernel_value\n        )\n\n        # Test with a model that has a new layer at the end\n        new_model = keras.Sequential(\n            [\n                keras.Input(shape=(3,), batch_size=2),\n                keras.layers.Dense(4),\n                keras.layers.Dense(5),\n                keras.layers.Dense(5),\n            ]\n        )\n        new_layer_kernel_value = np.array(new_model.layers[2].kernel)\n        with self.assertRaisesRegex(ValueError, \"received 0 variables\"):\n            # Doesn't work by default\n            new_model.load_weights(temp_filepath)\n        # Now it works\n        new_model.load_weights(temp_filepath, skip_mismatch=True)\n        for layer_index in [0, 1]:\n            ref_weights = original_model.layers[layer_index].get_weights()\n            new_weights = new_model.layers[layer_index].get_weights()\n            self.assertEqual(len(ref_weights), len(new_weights))\n            for ref_w, w in zip(ref_weights, new_weights):\n                self.assertAllClose(ref_w, w)\n        self.assertAllClose(\n            np.array(new_model.layers[2].kernel), new_layer_kernel_value\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_save_to_fileobj(self):\n        model = keras.Sequential(\n            [keras.layers.Dense(1, 
input_shape=(1,)), keras.layers.Dense(1)]\n        )\n        model.compile(optimizer=\"adam\", loss=\"mse\")\n\n        out = BytesIO()\n        saving_lib.save_model(model, out)\n        out.seek(0)\n        model = saving_lib.load_model(out)\n\n        model.fit(np.array([1, 2]), np.array([1, 2]))\n        pred1 = model.predict(np.array([1, 2]))\n\n        out = BytesIO()\n        saving_lib.save_model(model, out)\n        out.seek(0)\n        new_model = saving_lib.load_model(out)\n\n        pred2 = new_model.predict(np.array([1, 2]))\n\n        self.assertAllClose(pred1, pred2, atol=1e-5)\n\n    @parameterized.named_parameters(\n        (\"high_memory_config\", True),\n        (\"low_memory_config\", False),\n    )\n    def test_save_model_exception_raised(self, is_memory_sufficient):\n        if is_memory_sufficient:\n            saving_lib._MEMORY_UPPER_BOUND = 0.5  # 50%\n\n        # Assume we have an error in `save_own_variables`.\n        class RaiseErrorLayer(keras.layers.Layer):\n            def __init__(self, units, **kwargs):\n                super().__init__(**kwargs)\n                self.dense = keras.layers.Dense(units)\n\n            def call(self, inputs):\n                return self.dense(inputs)\n\n            def save_own_variables(self, store):\n                raise ValueError\n\n        model = keras.Sequential([keras.Input([1]), RaiseErrorLayer(1)])\n        filepath = f\"{self.get_temp_dir()}/model.keras\"\n        with self.assertRaises(ValueError):\n            saving_lib.save_model(model, filepath)\n\n        # Ensure we don't have a bad \"model.weights.h5\" inside the zip file.\n        self.assertTrue(Path(filepath).exists())\n        with zipfile.ZipFile(filepath) as zf:\n            all_filenames = zf.namelist()\n            self.assertNotIn(\"model.weights.h5\", all_filenames)\n\n        # Ensure we don't have any temporary files left.\n        self.assertLen(os.listdir(Path(filepath).parent), 1)\n        
self.assertIn(\"model.keras\", os.listdir(Path(filepath).parent))\n\n    @parameterized.named_parameters(\n        (\"high_memory_config\", True),\n        (\"low_memory_config\", False),\n    )\n    def test_load_model_exception_raised(self, is_memory_sufficient):\n        if is_memory_sufficient:\n            saving_lib._MEMORY_UPPER_BOUND = 0.5  # 50%\n\n        # Assume we have an error in `load_own_variables`.\n        class RaiseErrorLayer(keras.layers.Layer):\n            def __init__(self, units, **kwargs):\n                super().__init__(**kwargs)\n                self.dense = keras.layers.Dense(units)\n\n            def call(self, inputs):\n                return self.dense(inputs)\n\n            def load_own_variables(self, store):\n                raise ValueError\n\n        model = keras.Sequential([keras.Input([1]), RaiseErrorLayer(1)])\n        filepath = f\"{self.get_temp_dir()}/model.keras\"\n        saving_lib.save_model(model, filepath)\n        with self.assertRaises(ValueError):\n            saving_lib.load_model(\n                filepath, custom_objects={\"RaiseErrorLayer\": RaiseErrorLayer}\n            )\n\n        # Ensure we don't have any temporary files left.\n        self.assertLen(os.listdir(Path(filepath).parent), 1)\n        self.assertIn(\"model.keras\", os.listdir(Path(filepath).parent))\n\n    def test_load_model_read_only_system(self):\n        model = keras.Sequential([keras.Input([1]), keras.layers.Dense(32)])\n        filepath = f\"{self.get_temp_dir()}/model.keras\"\n        saving_lib.save_model(model, filepath)\n\n        # Load the model correctly, regardless of whether an OSError occurs.\n        original_mode = os.stat(Path(filepath).parent).st_mode\n        os.chmod(Path(filepath).parent, mode=0o555)\n        model = saving_lib.load_model(filepath)\n        os.chmod(Path(filepath).parent, mode=original_mode)\n\n        # Ensure we don't have any temporary files left.\n        
self.assertLen(os.listdir(Path(filepath).parent), 1)\n        self.assertIn(\"model.keras\", os.listdir(Path(filepath).parent))\n\n    @pytest.mark.skipif(\n        backend.backend() == \"jax\",\n        reason=\"JAX backend doesn't support Python's multiprocessing\",\n    )\n    @pytest.mark.skipif(\n        testing.uses_gpu(),\n        reason=\"This test doesn't support GPU\",\n    )\n    def test_load_model_concurrently(self):\n        import multiprocessing as mp\n\n        model = keras.Sequential([keras.Input([1]), keras.layers.Dense(2)])\n        filepath = f\"{self.get_temp_dir()}/model.keras\"\n        saving_lib.save_model(model, filepath)\n\n        # Load the model concurrently.\n        results = []\n        with mp.Pool(4) as pool:\n            for i in range(4):\n                results.append(pool.apply_async(_load_model_fn, (filepath,)))\n            pool.close()\n            pool.join()\n        [r.get() for r in results]  # No error occurs here\n\n    def test_load_model_containing_reused_layer(self):\n        # https://github.com/keras-team/keras/issues/20307\n        inputs = keras.Input((4,))\n        reused_layer = keras.layers.Dense(4)\n        x = reused_layer(inputs)\n        x = keras.layers.Dense(4)(x)\n        outputs = reused_layer(x)\n        model = keras.Model(inputs, outputs)\n\n        self.assertLen(model.layers, 3)  # Input + 2 Dense layers\n        self._test_inference_after_instantiation(model)\n\n    @parameterized.named_parameters(\n        (\"efficientnet_b0_512\", \"efficientnet_b0\", 1),  # Only 1 sharded file.\n        (\"efficientnet_b0_10\", \"efficientnet_b0\", 0.01),\n    )\n    def test_weights_sharding(self, model_name, max_shard_size):\n        from keras.src.applications import efficientnet\n\n        if backend.image_data_format() == \"channels_last\":\n            shape = (224, 224, 3)\n        else:\n            shape = (3, 224, 224)\n\n        if model_name == \"efficientnet_b0\":\n            model_fn = 
efficientnet.EfficientNetB0\n\n        temp_filepath = Path(\n            os.path.join(self.get_temp_dir(), \"mymodel.weights.json\")\n        )\n        model = model_fn(weights=None, input_shape=shape)\n        ref_input = np.random.random((1, *shape)).astype(\"float32\")\n        ref_output = model.predict(ref_input)\n\n        # Save the sharded files.\n        saving_lib.save_weights_only(\n            model, temp_filepath, max_shard_size=max_shard_size\n        )\n        self.assertIn(\"mymodel.weights.json\", os.listdir(temp_filepath.parent))\n        if max_shard_size == 1:\n            # 1 sharded file + 1 config file = 2.\n            self.assertLen(os.listdir(temp_filepath.parent), 2)\n        elif max_shard_size == 0.01:\n            # 3 sharded file + 1 config file = 4.\n            self.assertLen(os.listdir(temp_filepath.parent), 4)\n\n        with open(temp_filepath, \"r\") as f:\n            sharding_config = json.load(f)\n        self.assertIn(\"metadata\", sharding_config)\n        self.assertIn(\"weight_map\", sharding_config)\n\n        # Instantiate new model and load the sharded files.\n        model = model_fn(weights=None, input_shape=shape)\n        saving_lib.load_weights_only(model, temp_filepath)\n        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)\n\n\nclass SavingAPITest(testing.TestCase):\n    def test_saving_api_errors(self):\n        from keras.src.saving import saving_api\n\n        model = _get_basic_functional_model()\n\n        # Saving API errors\n        temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel\")\n        with self.assertRaisesRegex(ValueError, \"argument is deprecated\"):\n            saving_api.save_model(model, temp_filepath, save_format=\"keras\")\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel.notkeras\")\n        with self.assertRaisesRegex(ValueError, \"Invalid filepath extension\"):\n            saving_api.save_model(model, temp_filepath)\n\n        
temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel.keras\")\n        with self.assertRaisesRegex(ValueError, \"are not supported\"):\n            saving_api.save_model(model, temp_filepath, invalid_arg=\"hello\")\n\n        # Loading API errors\n        temp_filepath = os.path.join(self.get_temp_dir(), \"non_existent.keras\")\n        with self.assertRaisesRegex(\n            ValueError, \"Please ensure the file is an accessible\"\n        ):\n            _ = saving_api.load_model(temp_filepath)\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"my_saved_model\")\n        with self.assertRaisesRegex(ValueError, \"File format not supported\"):\n            _ = saving_api.load_model(temp_filepath)\n\n    def test_model_api_endpoint(self):\n        temp_filepath = Path(os.path.join(self.get_temp_dir(), \"mymodel.keras\"))\n        model = _get_basic_functional_model()\n        ref_input = np.random.random((2, 4))\n        ref_output = model.predict(ref_input)\n        model.save(temp_filepath)\n        model = keras.saving.load_model(temp_filepath)\n        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)\n\n    def test_model_api_endpoint_h5(self):\n        temp_filepath = Path(os.path.join(self.get_temp_dir(), \"mymodel.h5\"))\n        model = _get_basic_functional_model()\n        ref_input = np.random.random((2, 4))\n        ref_output = model.predict(ref_input)\n        model.save(temp_filepath)\n        model = keras.saving.load_model(temp_filepath)\n        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)\n\n    def test_model_api_errors(self):\n        model = _get_basic_functional_model()\n\n        # Saving API errors\n        temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel\")\n        with self.assertRaisesRegex(ValueError, \"argument is deprecated\"):\n            model.save(temp_filepath, save_format=\"keras\")\n\n        temp_filepath = os.path.join(self.get_temp_dir(), 
\"mymodel.notkeras\")\n        with self.assertRaisesRegex(ValueError, \"Invalid filepath extension\"):\n            model.save(temp_filepath)\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel.keras\")\n        with self.assertRaisesRegex(ValueError, \"are not supported\"):\n            model.save(temp_filepath, invalid_arg=\"hello\")\n\n    def test_safe_mode(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"unsafe_model.keras\")\n        model = keras.Sequential(\n            [\n                keras.Input(shape=(3,)),\n                keras.layers.Lambda(lambda x: x * 2),\n            ]\n        )\n        model.save(temp_filepath)\n        with self.assertRaisesRegex(ValueError, \"arbitrary code execution\"):\n            model = saving_lib.load_model(temp_filepath)\n        model = saving_lib.load_model(temp_filepath, safe_mode=False)\n\n    def test_normalization_kpl(self):\n        # With adapt\n        temp_filepath = os.path.join(self.get_temp_dir(), \"norm_model.keras\")\n        model = keras.Sequential(\n            [\n                keras.Input(shape=(3,)),\n                keras.layers.Normalization(),\n            ]\n        )\n        data = np.random.random((3, 3))\n        model.layers[0].adapt(data)\n        ref_out = model(data)\n        model.save(temp_filepath)\n        model = saving_lib.load_model(temp_filepath)\n        out = model(data)\n        self.assertAllClose(ref_out, out, atol=1e-6)\n\n        # Without adapt\n        model = keras.Sequential(\n            [\n                keras.Input(shape=(3,)),\n                keras.layers.Normalization(\n                    mean=np.random.random((3,)),\n                    variance=np.random.random((3,)),\n                ),\n            ]\n        )\n        ref_out = model(data)\n        model.save(temp_filepath)\n        model = saving_lib.load_model(temp_filepath)\n        out = model(data)\n        self.assertAllClose(ref_out, out, 
atol=1e-6)\n\n\n# This class is properly registered with a `get_config()` method.\n# However, since it does not subclass keras.layers.Layer, it lacks\n# `from_config()` for deserialization.\n@keras.saving.register_keras_serializable()\nclass GrowthFactor:\n    def __init__(self, factor):\n        self.factor = factor\n\n    def __call__(self, inputs):\n        return inputs * self.factor\n\n    def get_config(self):\n        return {\"factor\": self.factor}\n\n\n@keras.saving.register_keras_serializable(package=\"Complex\")\nclass FactorLayer(keras.layers.Layer):\n    def __init__(self, factor, **kwargs):\n        super().__init__(**kwargs)\n        self.factor = factor\n\n    def call(self, x):\n        return x * self.factor\n\n    def get_config(self):\n        return {\"factor\": self.factor}\n\n\n# This custom model does not explicitly deserialize the layers it includes\n# in its `get_config`. Explicit deserialization in a `from_config` override\n# or `__init__` is needed here, or an error will be thrown at loading time.\n@keras.saving.register_keras_serializable(package=\"Complex\")\nclass ComplexModel(keras.layers.Layer):\n    def __init__(self, first_layer, second_layer=None, **kwargs):\n        super().__init__(**kwargs)\n        self.first_layer = first_layer\n        if second_layer is not None:\n            self.second_layer = second_layer\n        else:\n            self.second_layer = keras.layers.Dense(8)\n\n    def get_config(self):\n        config = super().get_config()\n        config.update(\n            {\n                \"first_layer\": self.first_layer,\n                \"second_layer\": self.second_layer,\n            }\n        )\n        return config\n\n    def call(self, inputs):\n        return self.first_layer(self.second_layer(inputs))\n\n\nclass SavingBattleTest(testing.TestCase):\n    def test_custom_object_without_from_config(self):\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"custom_fn_model.keras\"\n 
       )\n\n        inputs = keras.Input(shape=(4, 4))\n        outputs = keras.layers.Dense(1, activation=GrowthFactor(0.5))(inputs)\n        model = keras.Model(inputs, outputs)\n\n        model.save(temp_filepath)\n\n        with self.assertRaisesRegex(\n            TypeError, \"Unable to reconstruct an instance\"\n        ):\n            _ = saving_lib.load_model(temp_filepath)\n\n    def test_complex_model_without_explicit_deserialization(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"complex_model.keras\")\n\n        inputs = keras.Input((32,))\n        outputs = ComplexModel(first_layer=FactorLayer(0.5))(inputs)\n        model = keras.Model(inputs, outputs)\n\n        model.save(temp_filepath)\n\n        with self.assertRaisesRegex(TypeError, \"are explicitly deserialized\"):\n            _ = saving_lib.load_model(temp_filepath)\n\n    def test_redefinition_of_trackable(self):\n        \"\"\"Test that a trackable can be aliased under a new name.\"\"\"\n\n        class NormalModel(keras.Model):\n            def __init__(self):\n                super().__init__()\n                self.dense = keras.layers.Dense(3)\n\n            def call(self, x):\n                return self.dense(x)\n\n        class WeirdModel(keras.Model):\n            def __init__(self):\n                super().__init__()\n                # This property will be traversed first,\n                # but \"_dense\" isn't in the saved file\n                # generated by NormalModel.\n                self.a_dense = keras.layers.Dense(3)\n\n            @property\n            def dense(self):\n                return self.a_dense\n\n            def call(self, x):\n                return self.dense(x)\n\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"normal_model.weights.h5\"\n        )\n        model_a = NormalModel()\n        model_a(np.random.random((2, 2)))\n        model_a.save_weights(temp_filepath)\n        model_b = WeirdModel()\n        
model_b(np.random.random((2, 2)))\n        model_b.load_weights(temp_filepath)\n        self.assertAllClose(\n            model_a.dense.kernel.numpy(), model_b.dense.kernel.numpy()\n        )\n\n    def test_normalization_legacy_h5_format(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"custom_model.h5\")\n\n        inputs = keras.Input((32,))\n        normalization = keras.layers.Normalization()\n        outputs = normalization(inputs)\n\n        model = keras.Model(inputs, outputs)\n\n        x = np.random.random((1, 32))\n        normalization.adapt(x)\n        ref_out = model(x)\n\n        model.save(temp_filepath)\n        new_model = keras.saving.load_model(temp_filepath)\n        out = new_model(x)\n        self.assertAllClose(ref_out, out, atol=1e-6)\n\n    def test_legacy_h5_format(self):\n        temp_filepath = os.path.join(self.get_temp_dir(), \"custom_model.h5\")\n\n        inputs = keras.Input((32,))\n        x = MyDense(2)(inputs)\n        outputs = CustomModelX()(x)\n        model = keras.Model(inputs, outputs)\n\n        x = np.random.random((1, 32))\n        ref_out = model(x)\n\n        model.save(temp_filepath)\n        new_model = keras.saving.load_model(temp_filepath)\n        out = new_model(x)\n        self.assertAllClose(ref_out, out, atol=1e-6)\n\n    def test_nested_functional_model_saving(self):\n        def func(in_size=4, out_size=2, name=None):\n            inputs = keras.layers.Input(shape=(in_size,))\n            outputs = keras.layers.Dense(out_size)((inputs))\n            return keras.Model(inputs, outputs=outputs, name=name)\n\n        input_a, input_b = keras.Input((4,)), keras.Input((4,))\n        out_a = func(out_size=2, name=\"func_a\")(input_a)\n        out_b = func(out_size=3, name=\"func_b\")(input_b)\n        model = keras.Model([input_a, input_b], outputs=[out_a, out_b])\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"nested_func.keras\")\n        model.save(temp_filepath)\n        
new_model = keras.saving.load_model(temp_filepath)\n        x = [np.random.random((2, 4))], np.random.random((2, 4))\n        ref_out = model(x)\n        out = new_model(x)\n        self.assertAllClose(ref_out[0], out[0])\n        self.assertAllClose(ref_out[1], out[1])\n\n    def test_nested_shared_functional_model_saving(self):\n        def func(in_size=4, out_size=2, name=None):\n            inputs = keras.layers.Input(shape=(in_size,))\n            outputs = keras.layers.Dense(out_size)((inputs))\n            return keras.Model(inputs, outputs=outputs, name=name)\n\n        inputs = [keras.Input((4,)), keras.Input((4,))]\n        func_shared = func(out_size=4, name=\"func_shared\")\n        shared_a = func_shared(inputs[0])\n        shared_b = func_shared(inputs[1])\n        out_a = keras.layers.Dense(2)(shared_a)\n        out_b = keras.layers.Dense(2)(shared_b)\n        model = keras.Model(inputs, outputs=[out_a, out_b])\n\n        temp_filepath = os.path.join(\n            self.get_temp_dir(), \"nested_shared_func.keras\"\n        )\n        model.save(temp_filepath)\n        new_model = keras.saving.load_model(temp_filepath)\n        x = [np.random.random((2, 4))], np.random.random((2, 4))\n        ref_out = model(x)\n        out = new_model(x)\n        self.assertAllClose(ref_out[0], out[0])\n        self.assertAllClose(ref_out[1], out[1])\n\n    def test_bidirectional_lstm_saving(self):\n        inputs = keras.Input((3, 2))\n        outputs = keras.layers.Bidirectional(keras.layers.LSTM(64))(inputs)\n        model = keras.Model(inputs, outputs)\n        temp_filepath = os.path.join(self.get_temp_dir(), \"bidir_lstm.keras\")\n        model.save(temp_filepath)\n        new_model = keras.saving.load_model(temp_filepath)\n        x = np.random.random((1, 3, 2))\n        ref_out = model(x)\n        out = new_model(x)\n        self.assertAllClose(ref_out, out)\n\n    def test_remove_weights_only_saving_and_loading(self):\n        def is_remote_path(path):\n      
      return True\n\n        temp_filepath = os.path.join(self.get_temp_dir(), \"model.weights.h5\")\n\n        with mock.patch(\n            \"keras.src.utils.file_utils.is_remote_path\", is_remote_path\n        ):\n            model = _get_basic_functional_model()\n            model.save_weights(temp_filepath)\n            model.load_weights(temp_filepath)\n\n\nclass SavingH5IOStoreTest(testing.TestCase):\n    def test_h5_io_store_basics(self):\n        temp_filepath = Path(os.path.join(self.get_temp_dir(), \"store.h5\"))\n\n        # Pre-defined data.\n        a = np.random.random((2, 4)).astype(\"float32\")\n        b = np.random.random((4, 8)).astype(\"int32\")\n\n        # Set.\n        store = saving_lib.H5IOStore(temp_filepath, mode=\"w\")\n        vars_store = store.make(\"vars\")\n        vars_store[\"a\"] = a\n        vars_store[\"b\"] = b\n        vars_store[\"c\"] = 42\n        self.assertAllClose(vars_store[\"a\"], a)\n        self.assertAllClose(vars_store[\"b\"], b)\n        self.assertEqual(int(vars_store[\"c\"][()]), 42)\n\n        # Delete.\n        del vars_store[\"c\"]\n\n        # Contain.\n        self.assertNotIn(\"c\", vars_store)\n\n        store.close()\n        self.assertTrue(os.path.exists(temp_filepath))\n\n        # Get.\n        store = saving_lib.H5IOStore(temp_filepath, mode=\"r\")\n        vars_store = store.get(\"vars\")\n        self.assertAllClose(vars_store[\"a\"], a)\n        self.assertAllClose(vars_store[\"b\"], b)\n        self.assertNotIn(\"c\", vars_store)\n\n    def test_h5_io_store_lora(self):\n        # For `keras_hub.models.backbone.save_lora_weights` and\n        # `keras_hub.models.backbone.load_lora_weights`\n        temp_filepath = Path(os.path.join(self.get_temp_dir(), \"layer.lora.h5\"))\n        layer = keras.layers.Dense(units=16)\n        layer.build((None, 8))\n        layer.enable_lora(4)\n\n        ref_input = np.random.random((1, 8)).astype(\"float32\")\n        ref_output = layer(ref_input)\n\n        
# Save the LoRA weights.\n        store = saving_lib.H5IOStore(temp_filepath, mode=\"w\")\n        lora_store = store.make(\"lora\")\n        lora_store[\"rank\"] = layer.lora_rank\n        inner_store = store.make(\"lora/0\")\n        inner_store[\"lora_kernel_a\"] = layer.lora_kernel_a\n        inner_store[\"lora_kernel_b\"] = layer.lora_kernel_b\n        store.close()\n\n        # Load the LoRA weights.\n        revived_layer = keras.layers.Dense(units=16)\n        revived_layer.build((None, 8))\n        store = saving_lib.H5IOStore(temp_filepath, mode=\"r\")\n        lora_store = store.get(\"lora\")\n        revived_layer.enable_lora(int(lora_store[\"rank\"][()]))\n        lora_kernel_a = store.get(\"lora/0\")[\"lora_kernel_a\"]\n        lora_kernel_b = store.get(\"lora/0\")[\"lora_kernel_b\"]\n        revived_layer._kernel.assign(layer._kernel)\n        revived_layer.bias.assign(layer.bias)\n        revived_layer.lora_kernel_a.assign(lora_kernel_a)\n        revived_layer.lora_kernel_b.assign(lora_kernel_b)\n        self.assertAllClose(revived_layer(ref_input), ref_output, atol=1e-6)\n\n    def test_h5_io_store_exception_raised(self):\n        temp_filepath = Path(os.path.join(self.get_temp_dir(), \"store.h5\"))\n\n        # Bad `path_or_io`.\n        with self.assertRaisesRegex(\n            TypeError,\n            (\n                r\"`path_or_io` should be a `str`, `pathlib.Path` or \"\n                r\"`io.BytesIO` object.\"\n            ),\n        ):\n            saving_lib.H5IOStore(None, mode=\"w\")\n\n        # Bad `mode`.\n        with self.assertRaisesRegex(\n            ValueError, r\"`mode` should be either 'w' or 'r'.\"\n        ):\n            saving_lib.H5IOStore(temp_filepath, mode=\"x\")\n\n        # No archive when using `io.BytesIO` as `path_or_io`.\n        with self.assertRaisesRegex(\n            ValueError,\n            (\n                r\"When `path_or_io` is an `io.BytesIO` object, `archive` \"\n                r\"should be 
`None`.\"\n            ),\n        ):\n            saving_lib.H5IOStore(BytesIO(), archive=\"archive\", mode=\"w\")\n\n        store = saving_lib.H5IOStore(temp_filepath, mode=\"w\")\n\n        # Bad `metadata`.\n        with self.assertRaisesRegex(\n            ValueError, r\"`metadata` should be a dict or `None`.\"\n        ):\n            store.make(\"vars\", metadata=\"metadata\")\n\n        store.close()\n\n        store = saving_lib.H5IOStore(temp_filepath, mode=\"r\")\n        vars_store = store.get(\"vars\")\n\n        # Set in read mode.\n        with self.assertRaisesRegex(\n            ValueError, r\"Setting a value is only allowed in write mode.\"\n        ):\n            vars_store[\"weights\"] = np.random.random((2, 4)).astype(\"float32\")\n\n        # Delete in read mode.\n        with self.assertRaisesRegex(\n            ValueError, r\"Deleting a value is only allowed in write mode.\"\n        ):\n            del vars_store[\"weights\"]\n\n    def test_sharded_h5_io_store_basics(self):\n        name = \"sharded_store\"\n        temp_filepath = Path(os.path.join(self.get_temp_dir(), f\"{name}.json\"))\n\n        # Pre-defined data. 
Each has about 0.0037GB.\n        a = np.random.random((1000, 1000)).astype(\"float32\")\n        b = np.random.random((1000, 1000)).astype(\"int32\")\n\n        # Set.\n        store = saving_lib.ShardedH5IOStore(\n            temp_filepath, max_shard_size=0.005, mode=\"w\"\n        )\n        vars_store = store.make(\"vars\")\n        vars_store[\"a\"] = a\n        vars_store[\"b\"] = b\n        vars_store[\"c\"] = 42\n        self.assertLen(store.sharding_config[\"weight_map\"][\"/vars/vars\"], 2)\n        self.assertLen(vars_store, 3)\n        self.assertAllClose(vars_store[\"a\"], a)\n        self.assertAllClose(vars_store[\"b\"], b)\n        self.assertEqual(int(vars_store[\"c\"][()]), 42)\n\n        # Delete.\n        del vars_store[\"c\"]\n        self.assertLen(vars_store, 2)\n        del vars_store[\"a\"]  # Delete from an older shard.\n        self.assertLen(vars_store, 1)\n        vars_store[\"a\"] = a\n\n        # Contain.\n        self.assertIn(\"a\", vars_store)\n        self.assertNotIn(\"c\", vars_store)\n\n        store.close()\n        self.assertTrue(os.path.exists(temp_filepath))\n        self.assertTrue(\n            os.path.exists(temp_filepath.with_name(f\"{name}_00000.weights.h5\"))\n        )\n\n        # Get.\n        store = saving_lib.ShardedH5IOStore(temp_filepath, mode=\"r\")\n        vars_store = store.get(\"vars\")\n        self.assertLen(vars_store, 2)\n        self.assertAllClose(vars_store[\"a\"], a)\n        self.assertAllClose(vars_store[\"b\"], b)\n        self.assertNotIn(\"c\", vars_store)\n\n        # Keys.\n        for key in [\"a\", \"b\"]:\n            self.assertIn(key, vars_store.keys())\n\n    def test_sharded_h5_io_store_exception_raised(self):\n        temp_filepath = Path(os.path.join(self.get_temp_dir(), \"store.h5\"))\n\n        # Bad `path_or_io`.\n        with self.assertRaisesRegex(\n            TypeError,\n            r\"`path_or_io` should be a `str`, `pathlib.Path` object. 
\",\n        ):\n            saving_lib.ShardedH5IOStore(None, mode=\"w\")\n\n        # Bad `mode`.\n        with self.assertRaisesRegex(\n            ValueError, r\"`mode` should be either 'w' or 'r'.\"\n        ):\n            saving_lib.ShardedH5IOStore(temp_filepath, mode=\"x\")\n\n        store = saving_lib.ShardedH5IOStore(\n            temp_filepath, max_shard_size=0.00001, mode=\"w\"\n        )\n        vars_store = store.make(\"vars\")\n\n        # Too large data.\n        with self.assertRaisesRegex(\n            ValueError, r\"exceeds the maximum shard size\"\n        ):\n            vars_store[\"weights\"] = np.random.random((100, 100)).astype(\n                \"float32\"\n            )\n\n        # Bad `get`.\n        with self.assertRaisesRegex(\n            KeyError, r\"Key 'abc' not found in any of the shards:\"\n        ):\n            vars_store[\"abc\"]\n\n        # Bad `del`.\n        with self.assertRaisesRegex(\n            KeyError, r\"Key 'abc' not found in any of the shards:\"\n        ):\n            del vars_store[\"abc\"]\n\n        store.close()\n"
  },
  {
    "path": "keras/src/saving/serialization_lib.py",
    "content": "\"\"\"Object config serialization and deserialization logic.\"\"\"\n\nimport importlib\nimport inspect\nimport types\nimport warnings\n\nimport numpy as np\n\nfrom keras.src import api_export\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\nfrom keras.src.saving import object_registration\nfrom keras.src.saving.keras_saveable import KerasSaveable\nfrom keras.src.utils import python_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\n\nPLAIN_TYPES = (str, int, float, bool)\n\n# List of Keras modules with built-in string representations for Keras defaults\nBUILTIN_MODULES = frozenset(\n    {\n        \"activations\",\n        \"constraints\",\n        \"initializers\",\n        \"losses\",\n        \"metrics\",\n        \"optimizers\",\n        \"regularizers\",\n    }\n)\n\nLOADING_APIS = frozenset(\n    {\n        \"keras.config.enable_unsafe_deserialization\",\n        \"keras.models.load_model\",\n        \"keras.preprocessing.image.load_img\",\n        \"keras.saving.load_model\",\n        \"keras.saving.load_weights\",\n        \"keras.utils.get_file\",\n        \"keras.utils.load_img\",\n    }\n)\n\n\nclass SerializableDict:\n    def __init__(self, **config):\n        self.config = config\n\n    def serialize(self):\n        return serialize_keras_object(self.config)\n\n\nclass SafeModeScope:\n    \"\"\"Scope to propagate safe mode flag to nested deserialization calls.\"\"\"\n\n    def __init__(self, safe_mode=True):\n        self.safe_mode = safe_mode\n\n    def __enter__(self):\n        self.original_value = in_safe_mode()\n        global_state.set_global_attribute(\"safe_mode_saving\", self.safe_mode)\n\n    def __exit__(self, *args, **kwargs):\n        global_state.set_global_attribute(\n            \"safe_mode_saving\", self.original_value\n        )\n\n\n@keras_export(\"keras.config.enable_unsafe_deserialization\")\ndef 
enable_unsafe_deserialization():\n    \"\"\"Disables safe mode globally, allowing deserialization of lambdas.\"\"\"\n    global_state.set_global_attribute(\"safe_mode_saving\", False)\n\n\ndef in_safe_mode():\n    return global_state.get_global_attribute(\"safe_mode_saving\")\n\n\nclass ObjectSharingScope:\n    \"\"\"Scope to enable detection and reuse of previously seen objects.\"\"\"\n\n    def __enter__(self):\n        global_state.set_global_attribute(\"shared_objects/id_to_obj_map\", {})\n        global_state.set_global_attribute(\"shared_objects/id_to_config_map\", {})\n\n    def __exit__(self, *args, **kwargs):\n        global_state.set_global_attribute(\"shared_objects/id_to_obj_map\", None)\n        global_state.set_global_attribute(\n            \"shared_objects/id_to_config_map\", None\n        )\n\n\ndef get_shared_object(obj_id):\n    \"\"\"Retrieve an object previously seen during deserialization.\"\"\"\n    id_to_obj_map = global_state.get_global_attribute(\n        \"shared_objects/id_to_obj_map\"\n    )\n    if id_to_obj_map is not None:\n        return id_to_obj_map.get(obj_id, None)\n\n\ndef record_object_after_serialization(obj, config):\n    \"\"\"Call after serializing an object, to keep track of its config.\"\"\"\n    if config[\"module\"] == \"__main__\":\n        config[\"module\"] = None  # Ensures module is None when no module found\n    id_to_config_map = global_state.get_global_attribute(\n        \"shared_objects/id_to_config_map\"\n    )\n    if id_to_config_map is None:\n        return  # Not in a sharing scope\n    obj_id = int(id(obj))\n    if obj_id not in id_to_config_map:\n        id_to_config_map[obj_id] = config\n    else:\n        config[\"shared_object_id\"] = obj_id\n        prev_config = id_to_config_map[obj_id]\n        prev_config[\"shared_object_id\"] = obj_id\n\n\ndef record_object_after_deserialization(obj, obj_id):\n    \"\"\"Call after deserializing an object, to keep track of it in the future.\"\"\"\n    
id_to_obj_map = global_state.get_global_attribute(\n        \"shared_objects/id_to_obj_map\"\n    )\n    if id_to_obj_map is None:\n        return  # Not in a sharing scope\n    id_to_obj_map[obj_id] = obj\n\n\n@keras_export(\n    [\n        \"keras.saving.serialize_keras_object\",\n        \"keras.utils.serialize_keras_object\",\n    ]\n)\ndef serialize_keras_object(obj):\n    \"\"\"Retrieve the config dict by serializing the Keras object.\n\n    `serialize_keras_object()` serializes a Keras object to a python dictionary\n    that represents the object, and is a reciprocal function of\n    `deserialize_keras_object()`. See `deserialize_keras_object()` for more\n    information about the config format.\n\n    Args:\n        obj: the Keras object to serialize.\n\n    Returns:\n        A python dict that represents the object. The python dict can be\n        deserialized via `deserialize_keras_object()`.\n    \"\"\"\n    if obj is None:\n        return obj\n\n    if isinstance(obj, PLAIN_TYPES):\n        return obj\n\n    if isinstance(obj, (list, tuple)):\n        config_arr = [serialize_keras_object(x) for x in obj]\n        return tuple(config_arr) if isinstance(obj, tuple) else config_arr\n    if isinstance(obj, dict):\n        return serialize_dict(obj)\n\n    # Special cases:\n    if isinstance(obj, bytes):\n        return {\n            \"class_name\": \"__bytes__\",\n            \"config\": {\"value\": obj.decode(\"utf-8\")},\n        }\n    if isinstance(obj, slice):\n        return {\n            \"class_name\": \"__slice__\",\n            \"config\": {\n                \"start\": serialize_keras_object(obj.start),\n                \"stop\": serialize_keras_object(obj.stop),\n                \"step\": serialize_keras_object(obj.step),\n            },\n        }\n    # Ellipsis is an instance, and ellipsis class is not in global scope.\n    # checking equality also fails elsewhere in the library, so we have\n    # to dynamically get the type.\n    if 
isinstance(obj, type(Ellipsis)):\n        return {\"class_name\": \"__ellipsis__\", \"config\": {}}\n    if isinstance(obj, backend.KerasTensor):\n        history = getattr(obj, \"_keras_history\", None)\n        if history:\n            history = list(history)\n            history[0] = history[0].name\n        return {\n            \"class_name\": \"__keras_tensor__\",\n            \"config\": {\n                \"shape\": obj.shape,\n                \"dtype\": obj.dtype,\n                \"keras_history\": history,\n            },\n        }\n    if tf.available and isinstance(obj, tf.TensorShape):\n        return obj.as_list() if obj._dims is not None else None\n    if backend.is_tensor(obj):\n        return {\n            \"class_name\": \"__tensor__\",\n            \"config\": {\n                \"value\": backend.convert_to_numpy(obj).tolist(),\n                \"dtype\": backend.standardize_dtype(obj.dtype),\n            },\n        }\n    if type(obj).__module__ == np.__name__:\n        if isinstance(obj, np.ndarray) and obj.ndim > 0:\n            return {\n                \"class_name\": \"__numpy__\",\n                \"config\": {\n                    \"value\": obj.tolist(),\n                    \"dtype\": backend.standardize_dtype(obj.dtype),\n                },\n            }\n        else:\n            # Treat numpy floats / etc as plain types.\n            return obj.item()\n    if tf.available and isinstance(obj, tf.DType):\n        return obj.name\n    if isinstance(obj, types.FunctionType) and obj.__name__ == \"<lambda>\":\n        warnings.warn(\n            \"The object being serialized includes a `lambda`. This is unsafe. \"\n            \"In order to reload the object, you will have to pass \"\n            \"`safe_mode=False` to the loading function. \"\n            \"Please avoid using `lambda` in the \"\n            \"future, and use named Python functions instead. 
\"\n            f\"This is the `lambda` being serialized: {inspect.getsource(obj)}\",\n            stacklevel=2,\n        )\n        return {\n            \"class_name\": \"__lambda__\",\n            \"config\": {\n                \"value\": python_utils.func_dump(obj),\n            },\n        }\n    if tf.available and isinstance(obj, tf.TypeSpec):\n        ts_config = obj._serialize()\n        # TensorShape and tf.DType conversion\n        ts_config = list(\n            map(\n                lambda x: (\n                    x.as_list()\n                    if isinstance(x, tf.TensorShape)\n                    else (x.name if isinstance(x, tf.DType) else x)\n                ),\n                ts_config,\n            )\n        )\n        return {\n            \"class_name\": \"__typespec__\",\n            \"spec_name\": obj.__class__.__name__,\n            \"module\": obj.__class__.__module__,\n            \"config\": ts_config,\n            \"registered_name\": None,\n        }\n\n    inner_config = _get_class_or_fn_config(obj)\n    config_with_public_class = serialize_with_public_class(\n        obj.__class__, inner_config\n    )\n\n    if config_with_public_class is not None:\n        get_build_and_compile_config(obj, config_with_public_class)\n        record_object_after_serialization(obj, config_with_public_class)\n        return config_with_public_class\n\n    # Any custom object or otherwise non-exported object\n    if isinstance(obj, types.FunctionType):\n        module = obj.__module__\n    else:\n        module = obj.__class__.__module__\n    class_name = obj.__class__.__name__\n\n    if module == \"builtins\":\n        registered_name = None\n    else:\n        if isinstance(obj, types.FunctionType):\n            registered_name = object_registration.get_registered_name(obj)\n        else:\n            registered_name = object_registration.get_registered_name(\n                obj.__class__\n            )\n\n    config = {\n        \"module\": 
module,\n        \"class_name\": class_name,\n        \"config\": inner_config,\n        \"registered_name\": registered_name,\n    }\n    get_build_and_compile_config(obj, config)\n    record_object_after_serialization(obj, config)\n    return config\n\n\ndef get_build_and_compile_config(obj, config):\n    if hasattr(obj, \"get_build_config\"):\n        build_config = obj.get_build_config()\n        if build_config is not None:\n            config[\"build_config\"] = serialize_dict(build_config)\n    if hasattr(obj, \"get_compile_config\"):\n        compile_config = obj.get_compile_config()\n        if compile_config is not None:\n            config[\"compile_config\"] = serialize_dict(compile_config)\n    return\n\n\ndef serialize_with_public_class(cls, inner_config=None):\n    \"\"\"Serializes classes from public Keras API or object registration.\n\n    Called to check and retrieve the config of any class that has a public\n    Keras API or has been registered as serializable via\n    `keras.saving.register_keras_serializable()`.\n    \"\"\"\n    # This gets the `keras.*` exported name, such as\n    # \"keras.optimizers.Adam\".\n    keras_api_name = api_export.get_name_from_symbol(cls)\n\n    # Case of custom or unknown class object\n    if keras_api_name is None:\n        registered_name = object_registration.get_registered_name(cls)\n        if registered_name is None:\n            return None\n\n        # Return custom object config with corresponding registration name\n        return {\n            \"module\": cls.__module__,\n            \"class_name\": cls.__name__,\n            \"config\": inner_config,\n            \"registered_name\": registered_name,\n        }\n\n    # Split the canonical Keras API name into a Keras module and class name.\n    parts = keras_api_name.split(\".\")\n    return {\n        \"module\": \".\".join(parts[:-1]),\n        \"class_name\": parts[-1],\n        \"config\": inner_config,\n        \"registered_name\": None,\n    
}\n\n\ndef serialize_with_public_fn(fn, config, fn_module_name=None):\n    \"\"\"Serializes functions from public Keras API or object registration.\n\n    Called to check and retrieve the config of any function that has a public\n    Keras API or has been registered as serializable via\n    `keras.saving.register_keras_serializable()`. If function's module name\n    is already known, returns corresponding config.\n    \"\"\"\n    if fn_module_name:\n        return {\n            \"module\": fn_module_name,\n            \"class_name\": \"function\",\n            \"config\": config,\n            \"registered_name\": config,\n        }\n    keras_api_name = api_export.get_name_from_symbol(fn)\n    if keras_api_name:\n        parts = keras_api_name.split(\".\")\n        return {\n            \"module\": \".\".join(parts[:-1]),\n            \"class_name\": \"function\",\n            \"config\": config,\n            \"registered_name\": config,\n        }\n    else:\n        registered_name = object_registration.get_registered_name(fn)\n        if not registered_name and not fn.__module__ == \"builtins\":\n            return None\n        return {\n            \"module\": fn.__module__,\n            \"class_name\": \"function\",\n            \"config\": config,\n            \"registered_name\": registered_name,\n        }\n\n\ndef _get_class_or_fn_config(obj):\n    \"\"\"Return the object's config depending on its type.\"\"\"\n    # Functions / lambdas:\n    if isinstance(obj, types.FunctionType):\n        return object_registration.get_registered_name(obj)\n    # All classes:\n    if hasattr(obj, \"get_config\"):\n        config = obj.get_config()\n        if not isinstance(config, dict):\n            raise TypeError(\n                f\"The `get_config()` method of {obj} should return \"\n                f\"a dict. 
It returned: {config}\"\n            )\n        return serialize_dict(config)\n    elif hasattr(obj, \"__name__\"):\n        return object_registration.get_registered_name(obj)\n    else:\n        raise TypeError(\n            f\"Cannot serialize object {obj} of type {type(obj)}. \"\n            \"To be serializable, \"\n            \"a class must implement the `get_config()` method.\"\n        )\n\n\ndef serialize_dict(obj):\n    return {key: serialize_keras_object(value) for key, value in obj.items()}\n\n\n@keras_export(\n    [\n        \"keras.saving.deserialize_keras_object\",\n        \"keras.utils.deserialize_keras_object\",\n    ]\n)\ndef deserialize_keras_object(\n    config, custom_objects=None, safe_mode=True, **kwargs\n):\n    \"\"\"Retrieve the object by deserializing the config dict.\n\n    The config dict is a Python dictionary that consists of a set of key-value\n    pairs, and represents a Keras object, such as an `Optimizer`, `Layer`,\n    `Metrics`, etc. The saving and loading library uses the following keys to\n    record information of a Keras object:\n\n    - `class_name`: String. This is the name of the class,\n      as exactly defined in the source\n      code, such as \"LossesContainer\".\n    - `config`: Dict. Library-defined or user-defined key-value pairs that store\n      the configuration of the object, as obtained by `object.get_config()`.\n    - `module`: String. The path of the python module. Built-in Keras classes\n      expect to have prefix `keras`.\n    - `registered_name`: String. The key the class is registered under via\n      `keras.saving.register_keras_serializable(package, name)` API. The\n      key has the format of '{package}>{name}', where `package` and `name` are\n      the arguments passed to `register_keras_serializable()`. If `name` is not\n      provided, it uses the class name. 
If `registered_name` successfully\n      resolves to a class (that was registered), the `class_name` and `config`\n      values in the dict will not be used. `registered_name` is only used for\n      non-built-in classes.\n\n    For example, the following dictionary represents the built-in Adam optimizer\n    with the relevant config:\n\n    ```python\n    dict_structure = {\n        \"class_name\": \"Adam\",\n        \"config\": {\n            \"amsgrad\": false,\n            \"beta_1\": 0.8999999761581421,\n            \"beta_2\": 0.9990000128746033,\n            \"decay\": 0.0,\n            \"epsilon\": 1e-07,\n            \"learning_rate\": 0.0010000000474974513,\n            \"name\": \"Adam\"\n        },\n        \"module\": \"keras.optimizers\",\n        \"registered_name\": None\n    }\n    # Returns an `Adam` instance identical to the original one.\n    deserialize_keras_object(dict_structure)\n    ```\n\n    If the class does not have an exported Keras namespace, the library tracks\n    it by its `module` and `class_name`. 
For example:\n\n    ```python\n    dict_structure = {\n      \"class_name\": \"MetricsList\",\n      \"config\": {\n          ...\n      },\n      \"module\": \"keras.trainers.compile_utils\",\n      \"registered_name\": \"MetricsList\"\n    }\n\n    # Returns a `MetricsList` instance identical to the original one.\n    deserialize_keras_object(dict_structure)\n    ```\n\n    And the following dictionary represents a user-customized `MeanSquaredError`\n    loss:\n\n    ```python\n    @keras.saving.register_keras_serializable(package='my_package')\n    class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):\n      ...\n\n    dict_structure = {\n        \"class_name\": \"ModifiedMeanSquaredError\",\n        \"config\": {\n            \"fn\": \"mean_squared_error\",\n            \"name\": \"mean_squared_error\",\n            \"reduction\": \"auto\"\n        },\n        \"registered_name\": \"my_package>ModifiedMeanSquaredError\"\n    }\n    # Returns the `ModifiedMeanSquaredError` object\n    deserialize_keras_object(dict_structure)\n    ```\n\n    Args:\n        config: Python dict describing the object.\n        custom_objects: Python dict containing a mapping between custom\n            object names and the corresponding classes or functions.\n        safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.\n            When `safe_mode=False`, loading an object has the potential to\n            trigger arbitrary code execution. This argument is only\n            applicable to the Keras v3 model format. 
Defaults to `True`.\n\n    Returns:\n        The object described by the `config` dictionary.\n    \"\"\"\n    safe_scope_arg = in_safe_mode()  # Enforces SafeModeScope\n    safe_mode = safe_scope_arg if safe_scope_arg is not None else safe_mode\n\n    module_objects = kwargs.pop(\"module_objects\", None)\n    custom_objects = custom_objects or {}\n    tlco = global_state.get_global_attribute(\"custom_objects_scope_dict\", {})\n    gco = object_registration.GLOBAL_CUSTOM_OBJECTS\n    custom_objects = {**custom_objects, **tlco, **gco}\n\n    if config is None:\n        return None\n\n    if (\n        isinstance(config, str)\n        and custom_objects\n        and custom_objects.get(config) is not None\n    ):\n        # This is to deserialize plain functions which are serialized as\n        # string names by legacy saving formats.\n        return custom_objects[config]\n\n    if isinstance(config, (list, tuple)):\n        return [\n            deserialize_keras_object(\n                x, custom_objects=custom_objects, safe_mode=safe_mode\n            )\n            for x in config\n        ]\n\n    if module_objects is not None:\n        inner_config, fn_module_name, has_custom_object = None, None, False\n\n        if isinstance(config, dict):\n            if \"config\" in config:\n                inner_config = config[\"config\"]\n            if \"class_name\" not in config:\n                raise ValueError(\n                    f\"Unknown `config` as a `dict`, config={config}\"\n                )\n\n            # Check case where config is function or class and in custom objects\n            if custom_objects and (\n                config[\"class_name\"] in custom_objects\n                or config.get(\"registered_name\") in custom_objects\n                or (\n                    isinstance(inner_config, str)\n                    and inner_config in custom_objects\n                )\n            ):\n                has_custom_object = True\n\n            # 
Case where config is function but not in custom objects\n            elif config[\"class_name\"] == \"function\":\n                fn_module_name = config[\"module\"]\n                if fn_module_name == \"builtins\":\n                    config = config[\"config\"]\n                else:\n                    config = config[\"registered_name\"]\n\n            # Case where config is class but not in custom objects\n            else:\n                if config.get(\"module\", \"_\") is None:\n                    raise TypeError(\n                        \"Cannot deserialize object of type \"\n                        f\"`{config['class_name']}`. If \"\n                        f\"`{config['class_name']}` is a custom class, please \"\n                        \"register it using the \"\n                        \"`@keras.saving.register_keras_serializable()` \"\n                        \"decorator.\"\n                    )\n                config = config[\"class_name\"]\n\n        if not has_custom_object:\n            # Return if not found in either module objects or custom objects\n            if config not in module_objects:\n                # Object has already been deserialized\n                return config\n            if isinstance(module_objects[config], types.FunctionType):\n                return deserialize_keras_object(\n                    serialize_with_public_fn(\n                        module_objects[config], config, fn_module_name\n                    ),\n                    custom_objects=custom_objects,\n                )\n            return deserialize_keras_object(\n                serialize_with_public_class(\n                    module_objects[config], inner_config=inner_config\n                ),\n                custom_objects=custom_objects,\n            )\n\n    if isinstance(config, PLAIN_TYPES):\n        return config\n    if not isinstance(config, dict):\n        raise TypeError(f\"Could not parse config: {config}\")\n\n    if 
\"class_name\" not in config or \"config\" not in config:\n        return {\n            key: deserialize_keras_object(\n                value, custom_objects=custom_objects, safe_mode=safe_mode\n            )\n            for key, value in config.items()\n        }\n\n    class_name = config[\"class_name\"]\n    inner_config = config[\"config\"] or {}\n    custom_objects = custom_objects or {}\n\n    # Special cases:\n    if class_name == \"__keras_tensor__\":\n        obj = backend.KerasTensor(\n            inner_config[\"shape\"], dtype=inner_config[\"dtype\"]\n        )\n        obj._pre_serialization_keras_history = inner_config[\"keras_history\"]\n        return obj\n\n    if class_name == \"__tensor__\":\n        return backend.convert_to_tensor(\n            inner_config[\"value\"], dtype=inner_config[\"dtype\"]\n        )\n    if class_name == \"__numpy__\":\n        return np.array(inner_config[\"value\"], dtype=inner_config[\"dtype\"])\n    if config[\"class_name\"] == \"__bytes__\":\n        return inner_config[\"value\"].encode(\"utf-8\")\n    if config[\"class_name\"] == \"__ellipsis__\":\n        return Ellipsis\n    if config[\"class_name\"] == \"__slice__\":\n        return slice(\n            deserialize_keras_object(\n                inner_config[\"start\"],\n                custom_objects=custom_objects,\n                safe_mode=safe_mode,\n            ),\n            deserialize_keras_object(\n                inner_config[\"stop\"],\n                custom_objects=custom_objects,\n                safe_mode=safe_mode,\n            ),\n            deserialize_keras_object(\n                inner_config[\"step\"],\n                custom_objects=custom_objects,\n                safe_mode=safe_mode,\n            ),\n        )\n    if config[\"class_name\"] == \"__lambda__\":\n        if safe_mode:\n            raise ValueError(\n                \"Requested the deserialization of a Python lambda. 
This \"\n                \"carries a potential risk of arbitrary code execution and thus \"\n                \"it is disallowed by default. If you trust the source of the \"\n                \"artifact, you can override this error by passing \"\n                \"`safe_mode=False` to the loading function, or calling \"\n                \"`keras.config.enable_unsafe_deserialization()`.\"\n            )\n        return python_utils.func_load(inner_config[\"value\"])\n    if tf is not None and config[\"class_name\"] == \"__typespec__\":\n        obj = _retrieve_class_or_fn(\n            config[\"spec_name\"],\n            config[\"registered_name\"],\n            config[\"module\"],\n            obj_type=\"class\",\n            full_config=config,\n            custom_objects=custom_objects,\n        )\n        # Conversion to TensorShape and DType\n        inner_config = map(\n            lambda x: (\n                tf.TensorShape(x)\n                if isinstance(x, list)\n                else (getattr(tf, x) if hasattr(tf.dtypes, str(x)) else x)\n            ),\n            inner_config,\n        )\n        return obj._deserialize(tuple(inner_config))\n\n    # Below: classes and functions.\n    module = config.get(\"module\", None)\n    registered_name = config.get(\"registered_name\", class_name)\n\n    if class_name == \"function\":\n        fn_name = inner_config\n        return _retrieve_class_or_fn(\n            fn_name,\n            registered_name,\n            module,\n            obj_type=\"function\",\n            full_config=config,\n            custom_objects=custom_objects,\n        )\n\n    # Below, handling of all classes.\n    # First, is it a shared object?\n    if \"shared_object_id\" in config:\n        obj = get_shared_object(config[\"shared_object_id\"])\n        if obj is not None:\n            return obj\n\n    cls = _retrieve_class_or_fn(\n        class_name,\n        registered_name,\n        module,\n        obj_type=\"class\",\n        
full_config=config,\n        custom_objects=custom_objects,\n    )\n\n    if isinstance(cls, types.FunctionType):\n        return cls\n    if not hasattr(cls, \"from_config\"):\n        raise TypeError(\n            f\"Unable to reconstruct an instance of '{class_name}' because \"\n            f\"the class is missing a `from_config()` method. \"\n            f\"Full object config: {config}\"\n        )\n\n    # Instantiate the class from its config inside a custom object scope\n    # so that we can catch any custom objects that the config refers to.\n    custom_obj_scope = object_registration.CustomObjectScope(custom_objects)\n    safe_mode_scope = SafeModeScope(safe_mode)\n    with custom_obj_scope, safe_mode_scope:\n        try:\n            instance = cls.from_config(inner_config)\n        except TypeError as e:\n            raise TypeError(\n                f\"{cls} could not be deserialized properly. Please\"\n                \" ensure that components that are Python object\"\n                \" instances (layers, models, etc.) 
returned by\"\n                \" `get_config()` are explicitly deserialized in the\"\n                \" model's `from_config()` method.\"\n                f\"\\n\\nconfig={config}.\\n\\nException encountered: {e}\"\n            )\n        build_config = config.get(\"build_config\", None)\n        if build_config and not instance.built:\n            instance.build_from_config(build_config)\n            instance.built = True\n        compile_config = config.get(\"compile_config\", None)\n        if compile_config:\n            instance.compile_from_config(compile_config)\n            instance.compiled = True\n\n    if \"shared_object_id\" in config:\n        record_object_after_deserialization(\n            instance, config[\"shared_object_id\"]\n        )\n    return instance\n\n\ndef _retrieve_class_or_fn(\n    name, registered_name, module, obj_type, full_config, custom_objects=None\n):\n    # If there is a custom object registered via\n    # `register_keras_serializable()`, that takes precedence.\n    if obj_type == \"function\":\n        custom_obj = object_registration.get_registered_object(\n            name, custom_objects=custom_objects\n        )\n    else:\n        custom_obj = object_registration.get_registered_object(\n            registered_name, custom_objects=custom_objects\n        )\n    if custom_obj is not None:\n        return custom_obj\n\n    if module:\n        # If it's a Keras built-in object,\n        # we cannot always use direct import, because the exported\n        # module name might not match the package structure\n        # (e.g. 
experimental symbols).\n        if module == \"keras\" or module.startswith(\"keras.\"):\n            api_name = f\"{module}.{name}\"\n\n            if api_name in LOADING_APIS:\n                raise ValueError(\n                    f\"Cannot deserialize `{api_name}`, loading functions are \"\n                    \"not allowed during deserialization\"\n                )\n\n            obj = api_export.get_symbol_from_name(api_name)\n            if obj is not None:\n                return obj\n\n        # Configs of Keras built-in functions do not contain identifying\n        # information other than their name (e.g. 'acc' or 'tanh'). This special\n        # case searches the Keras modules that contain built-ins to retrieve\n        # the corresponding function from the identifying string.\n        if obj_type == \"function\" and module == \"builtins\":\n            for mod in BUILTIN_MODULES:\n                obj = api_export.get_symbol_from_name(f\"keras.{mod}.{name}\")\n                if obj is not None:\n                    return obj\n\n            # Workaround for serialization bug in Keras <= 3.6 whereby custom\n            # functions would only be saved by name instead of registered name,\n            # i.e. \"name\" instead of \"package>name\". This allows recent versions\n            # of Keras to reload models saved with 3.6 and lower.\n            if \">\" not in name:\n                separated_name = f\">{name}\"\n                for custom_name, custom_object in custom_objects.items():\n                    if custom_name.endswith(separated_name):\n                        return custom_object\n\n        # Otherwise, attempt to retrieve the class object given the `module`\n        # and `class_name`. 
Import the module, find the class.\n        package = module.split(\".\", maxsplit=1)[0]\n        if package in {\"keras\", \"keras_hub\", \"keras_cv\", \"keras_nlp\"}:\n            try:\n                mod = importlib.import_module(module)\n                obj = vars(mod).get(name, None)\n                if isinstance(obj, type) and issubclass(obj, KerasSaveable):\n                    return obj\n                else:\n                    raise ValueError(\n                        f\"Could not deserialize '{module}.{name}' because \"\n                        \"it is not a KerasSaveable subclass\"\n                    )\n            except ModuleNotFoundError:\n                raise TypeError(\n                    f\"Could not deserialize {obj_type} '{name}' because \"\n                    f\"its parent module {module} cannot be imported. \"\n                    f\"Full object config: {full_config}\"\n                )\n\n    raise TypeError(\n        f\"Could not locate {obj_type} '{name}'. Make sure custom classes and \"\n        \"functions are decorated with \"\n        \"`@keras.saving.register_keras_serializable()`. If they are already \"\n        \"decorated, make sure they are all imported so that the decorator is \"\n        f\"run before trying to load them. Full object config: {full_config}\"\n    )\n"
  },
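The `{package}>{name}` registered-name lookup described in the `deserialize_keras_object` docstring above can be illustrated with a self-contained sketch. Everything here (`_REGISTRY`, `register_serializable`, `serialize`, `deserialize`, the `Scale` class) is hypothetical illustration code, not the Keras API; it only mirrors the documented rule that a resolvable `registered_name` takes precedence over `class_name`:

```python
# Hypothetical minimal registry mimicking the "{package}>{name}" scheme.
_REGISTRY = {}

def register_serializable(package, name=None):
    def decorator(cls):
        # If `name` is not provided, fall back to the class name,
        # as documented for `register_keras_serializable()`.
        key = f"{package}>{name or cls.__name__}"
        _REGISTRY[key] = cls
        cls._registered_name = key
        return cls
    return decorator

def serialize(obj):
    return {
        "class_name": type(obj).__name__,
        "config": obj.get_config(),
        "registered_name": getattr(obj, "_registered_name", None),
    }

def deserialize(config):
    # A resolvable registered name wins; `class_name`/`config` alone
    # are not enough to locate a custom class.
    cls = _REGISTRY.get(config.get("registered_name"))
    if cls is None:
        raise TypeError(f"Could not locate class for config: {config}")
    return cls.from_config(config["config"])

@register_serializable(package="my_package")
class Scale:
    def __init__(self, factor):
        self.factor = factor

    def get_config(self):
        return {"factor": self.factor}

    @classmethod
    def from_config(cls, config):
        return cls(**config)
```

A roundtrip through `serialize`/`deserialize` then reconstructs an equivalent `Scale` instance from the plain dict alone, which is the same contract the real library fulfills for registered custom objects.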
  {
    "path": "keras/src/saving/serialization_lib_test.py",
    "content": "\"\"\"Tests for serialization_lib.\"\"\"\n\nimport json\n\nimport numpy as np\nimport pytest\n\nimport keras\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.saving import object_registration\nfrom keras.src.saving import serialization_lib\n\n\ndef custom_fn(x):\n    return x**2\n\n\nclass CustomLayer(keras.layers.Layer):\n    def __init__(self, factor):\n        super().__init__()\n        self.factor = factor\n\n    def call(self, x):\n        return x * self.factor\n\n    def get_config(self):\n        return {\"factor\": self.factor}\n\n\nclass NestedCustomLayer(keras.layers.Layer):\n    def __init__(self, factor, dense=None, activation=None):\n        super().__init__()\n        self.factor = factor\n\n        if dense is None:\n            self.dense = keras.layers.Dense(1, activation=custom_fn)\n        else:\n            self.dense = serialization_lib.deserialize_keras_object(dense)\n        self.activation = serialization_lib.deserialize_keras_object(activation)\n\n    def call(self, x):\n        return self.dense(x * self.factor)\n\n    def get_config(self):\n        return {\n            \"factor\": self.factor,\n            \"dense\": self.dense,\n            \"activation\": self.activation,\n        }\n\n\nclass WrapperLayer(keras.layers.Wrapper):\n    def call(self, x):\n        return self.layer(x)\n\n\nclass SerializationLibTest(testing.TestCase):\n    def roundtrip(self, obj, custom_objects=None, safe_mode=True):\n        serialized = serialization_lib.serialize_keras_object(obj)\n        json_data = json.dumps(serialized)\n        json_data = json.loads(json_data)\n        deserialized = serialization_lib.deserialize_keras_object(\n            json_data, custom_objects=custom_objects, safe_mode=safe_mode\n        )\n        reserialized = serialization_lib.serialize_keras_object(deserialized)\n        return serialized, deserialized, reserialized\n\n    def test_simple_objects(self):\n        for obj in [\n 
           \"hello\",\n            b\"hello\",\n            np.array([0, 1]),\n            np.array([0.0, 1.0]),\n            np.float32(1.0),\n            [\"hello\", 0, \"world\", 1.0, True],\n            {\"1\": \"hello\", \"2\": 0, \"3\": True},\n            {\"1\": \"hello\", \"2\": [True, False]},\n            slice(None, 20, 1),\n            slice(None, np.array([0, 1]), 1),\n        ]:\n            serialized, _, reserialized = self.roundtrip(obj)\n            self.assertEqual(serialized, reserialized)\n\n    def test_builtin_layers(self):\n        layer = keras.layers.Dense(\n            3,\n            name=\"foo\",\n            trainable=False,\n            dtype=\"float16\",\n        )\n        serialized, restored, reserialized = self.roundtrip(layer)\n        self.assertEqual(serialized, reserialized)\n        self.assertEqual(layer.name, restored.name)\n        self.assertEqual(layer.trainable, restored.trainable)\n        self.assertEqual(layer.compute_dtype, restored.compute_dtype)\n\n    def test_numpy_get_item_layer(self):\n        def tuples_to_lists_str(x):\n            return str(x).replace(\"(\", \"[\").replace(\")\", \"]\")\n\n        input = keras.layers.Input(shape=(2,))\n        layer = input[:, 1]\n        model = keras.Model(input, layer)\n        serialized, _, reserialized = self.roundtrip(model)\n        # Anticipate JSON roundtrip mapping tuples to lists:\n        serialized_str = tuples_to_lists_str(serialized)\n        reserialized_str = tuples_to_lists_str(reserialized)\n        self.assertEqual(serialized_str, reserialized_str)\n\n    def test_serialize_ellipsis(self):\n        _, deserialized, _ = self.roundtrip(Ellipsis)\n        self.assertEqual(..., deserialized)\n\n    def test_tensors_and_shapes(self):\n        x = ops.random.normal((2, 2), dtype=\"float64\")\n        obj = {\"x\": x}\n        _, new_obj, _ = self.roundtrip(obj)\n        self.assertAllClose(x, new_obj[\"x\"], atol=1e-5)\n\n        obj = {\"x.shape\": 
x.shape}\n        _, new_obj, _ = self.roundtrip(obj)\n        self.assertEqual(tuple(x.shape), tuple(new_obj[\"x.shape\"]))\n\n    def test_custom_fn(self):\n        obj = {\"activation\": custom_fn}\n        serialized, _, reserialized = self.roundtrip(\n            obj, custom_objects={\"custom_fn\": custom_fn}\n        )\n        self.assertEqual(serialized, reserialized)\n\n        # Test inside layer\n        dense = keras.layers.Dense(1, activation=custom_fn)\n        dense.build((None, 2))\n        _, new_dense, _ = self.roundtrip(\n            dense, custom_objects={\"custom_fn\": custom_fn}\n        )\n        x = ops.random.normal((2, 2))\n        y1 = dense(x)\n        _ = new_dense(x)\n        new_dense.set_weights(dense.get_weights())\n        y2 = new_dense(x)\n        self.assertAllClose(y1, y2, atol=1e-5)\n\n    def test_custom_layer(self):\n        layer = CustomLayer(factor=2)\n        x = ops.random.normal((2, 2))\n        y1 = layer(x)\n        _, new_layer, _ = self.roundtrip(\n            layer, custom_objects={\"CustomLayer\": CustomLayer}\n        )\n        y2 = new_layer(x)\n        self.assertAllClose(y1, y2, atol=1e-5)\n\n        layer = NestedCustomLayer(factor=2)\n        x = ops.random.normal((2, 2))\n        y1 = layer(x)\n        _, new_layer, _ = self.roundtrip(\n            layer,\n            custom_objects={\n                \"NestedCustomLayer\": NestedCustomLayer,\n                \"custom_fn\": custom_fn,\n            },\n        )\n        _ = new_layer(x)\n        new_layer.set_weights(layer.get_weights())\n        y2 = new_layer(x)\n        self.assertAllClose(y1, y2, atol=1e-5)\n\n    def test_lambda_fn(self):\n        obj = {\"activation\": lambda x: x**2}\n        with self.assertRaisesRegex(ValueError, \"arbitrary code execution\"):\n            self.roundtrip(obj, safe_mode=True)\n\n        _, new_obj, _ = self.roundtrip(obj, safe_mode=False)\n        self.assertEqual(obj[\"activation\"](3), 
new_obj[\"activation\"](3))\n\n    def test_lambda_layer(self):\n        lmbda = keras.layers.Lambda(lambda x: x**2)\n        with self.assertRaisesRegex(ValueError, \"arbitrary code execution\"):\n            self.roundtrip(lmbda, safe_mode=True)\n\n        _, new_lmbda, _ = self.roundtrip(lmbda, safe_mode=False)\n        x = ops.random.normal((2, 2))\n        y1 = lmbda(x)\n        y2 = new_lmbda(x)\n        self.assertAllClose(y1, y2, atol=1e-5)\n\n    def test_safe_mode_scope(self):\n        lmbda = keras.layers.Lambda(lambda x: x**2)\n        with serialization_lib.SafeModeScope(safe_mode=True):\n            with self.assertRaisesRegex(ValueError, \"arbitrary code execution\"):\n                self.roundtrip(lmbda)\n        with serialization_lib.SafeModeScope(safe_mode=False):\n            _, new_lmbda, _ = self.roundtrip(lmbda)\n        x = ops.random.normal((2, 2))\n        y1 = lmbda(x)\n        y2 = new_lmbda(x)\n        self.assertAllClose(y1, y2, atol=1e-5)\n\n    @pytest.mark.requires_trainable_backend\n    def test_dict_inputs_outputs(self):\n        input_foo = keras.Input((2,), name=\"foo\")\n        input_bar = keras.Input((2,), name=\"bar\")\n        dense = keras.layers.Dense(1)\n        output_foo = dense(input_foo)\n        output_bar = dense(input_bar)\n        model = keras.Model(\n            {\"foo\": input_foo, \"bar\": input_bar},\n            {\"foo\": output_foo, \"bar\": output_bar},\n        )\n        _, new_model, _ = self.roundtrip(model)\n        original_output = model(\n            {\"foo\": np.zeros((2, 2)), \"bar\": np.zeros((2, 2))}\n        )\n        restored_output = new_model(\n            {\"foo\": np.zeros((2, 2)), \"bar\": np.zeros((2, 2))}\n        )\n        self.assertAllClose(original_output[\"foo\"], restored_output[\"foo\"])\n        self.assertAllClose(original_output[\"bar\"], restored_output[\"bar\"])\n\n    @pytest.mark.requires_trainable_backend\n    def test_shared_inner_layer(self):\n        with 
serialization_lib.ObjectSharingScope():\n            input_1 = keras.Input((2,))\n            input_2 = keras.Input((2,))\n            shared_layer = keras.layers.Dense(1)\n            output_1 = shared_layer(input_1)\n            wrapper_layer = WrapperLayer(shared_layer)\n            output_2 = wrapper_layer(input_2)\n            model = keras.Model([input_1, input_2], [output_1, output_2])\n            _, new_model, _ = self.roundtrip(\n                model, custom_objects={\"WrapperLayer\": WrapperLayer}\n            )\n\n            self.assertIs(model.layers[2], model.layers[3].layer)\n            self.assertIs(new_model.layers[2], new_model.layers[3].layer)\n\n    @pytest.mark.requires_trainable_backend\n    def test_functional_subclass(self):\n        class PlainFunctionalSubclass(keras.Model):\n            pass\n\n        inputs = keras.Input((2,), batch_size=3)\n        outputs = keras.layers.Dense(1)(inputs)\n        model = PlainFunctionalSubclass(inputs, outputs)\n        x = ops.random.normal((2, 2))\n        y1 = model(x)\n        _, new_model, _ = self.roundtrip(\n            model,\n            custom_objects={\"PlainFunctionalSubclass\": PlainFunctionalSubclass},\n        )\n        new_model.set_weights(model.get_weights())\n        y2 = new_model(x)\n        self.assertAllClose(y1, y2, atol=1e-5)\n        self.assertIsInstance(new_model, PlainFunctionalSubclass)\n\n        class FunctionalSubclassWCustomInit(keras.Model):\n            def __init__(self, num_units=2):\n                inputs = keras.Input((2,), batch_size=3)\n                outputs = keras.layers.Dense(num_units)(inputs)\n                super().__init__(inputs, outputs)\n                self.num_units = num_units\n\n            def get_config(self):\n                return {\"num_units\": self.num_units}\n\n        model = FunctionalSubclassWCustomInit(num_units=3)\n        x = ops.random.normal((2, 2))\n        y1 = model(x)\n        _, new_model, _ = self.roundtrip(\n        
    model,\n            custom_objects={\n                \"FunctionalSubclassWCustomInit\": FunctionalSubclassWCustomInit\n            },\n        )\n        new_model.set_weights(model.get_weights())\n        y2 = new_model(x)\n        self.assertAllClose(y1, y2, atol=1e-5)\n        self.assertIsInstance(new_model, FunctionalSubclassWCustomInit)\n\n    def test_shared_object(self):\n        class MyLayer(keras.layers.Layer):\n            def __init__(self, activation, **kwargs):\n                super().__init__(**kwargs)\n                if isinstance(activation, dict):\n                    self.activation = (\n                        serialization_lib.deserialize_keras_object(activation)\n                    )\n                else:\n                    self.activation = activation\n\n            def call(self, x):\n                return self.activation(x)\n\n            def get_config(self):\n                config = super().get_config()\n                config[\"activation\"] = self.activation\n                return config\n\n        class SharedActivation:\n            def __call__(self, x):\n                return x**2\n\n            def get_config(self):\n                return {}\n\n            @classmethod\n            def from_config(cls, config):\n                return cls()\n\n        shared_act = SharedActivation()\n        layer_1 = MyLayer(activation=shared_act)\n        layer_2 = MyLayer(activation=shared_act)\n        layers = [layer_1, layer_2]\n\n        with serialization_lib.ObjectSharingScope():\n            serialized, new_layers, reserialized = self.roundtrip(\n                layers,\n                custom_objects={\n                    \"MyLayer\": MyLayer,\n                    \"SharedActivation\": SharedActivation,\n                },\n            )\n        self.assertIn(\"shared_object_id\", serialized[0][\"config\"][\"activation\"])\n        obj_id = serialized[0][\"config\"][\"activation\"]\n        
self.assertIn(\"shared_object_id\", serialized[1][\"config\"][\"activation\"])\n        self.assertEqual(obj_id, serialized[1][\"config\"][\"activation\"])\n        self.assertIs(layers[0].activation, layers[1].activation)\n        self.assertIs(new_layers[0].activation, new_layers[1].activation)\n\n    def test_layer_sharing(self):\n        seq = keras.Sequential(\n            [\n                keras.Input(shape=(3,)),\n                keras.layers.Dense(5),\n                keras.layers.Softmax(),\n            ],\n        )\n        func = keras.Model(inputs=seq.inputs, outputs=seq.outputs)\n        serialized, deserialized, reserialized = self.roundtrip(func)\n        self.assertLen(deserialized.layers, 3)\n\n    def test_keras36_custom_function_reloading(self):\n        @object_registration.register_keras_serializable(package=\"serial_test\")\n        def custom_registered_fn(x):\n            return x**2\n\n        config36 = {\n            \"module\": \"builtins\",\n            \"class_name\": \"function\",\n            \"config\": \"custom_registered_fn\",\n            \"registered_name\": \"function\",\n        }\n        obj = serialization_lib.deserialize_keras_object(config36)\n        self.assertIs(obj, custom_registered_fn)\n\n        config = {\n            \"module\": \"builtins\",\n            \"class_name\": \"function\",\n            \"config\": \"serial_test>custom_registered_fn\",\n            \"registered_name\": \"function\",\n        }\n        obj = serialization_lib.deserialize_keras_object(config)\n        self.assertIs(obj, custom_registered_fn)\n\n    def test_layer_instance_as_activation(self):\n        \"\"\"Tests serialization when activation is a Layer instance.\"\"\"\n\n        # Dense layer with ReLU layer as activation\n        layer_dense_relu = keras.layers.Dense(\n            units=4, activation=keras.layers.ReLU(name=\"my_relu\")\n        )\n        # Build the layer to ensure weights/state are initialized if needed\n        
layer_dense_relu.build(input_shape=(None, 8))\n        _, restored_dense_relu, _ = self.roundtrip(layer_dense_relu)\n\n        # Verify the activation is correctly deserialized as a ReLU layer\n        self.assertIsInstance(restored_dense_relu.activation, keras.layers.ReLU)\n        # Verify properties are preserved\n        self.assertEqual(restored_dense_relu.activation.name, \"my_relu\")\n\n    def test_layer_instance_with_config_as_activation(self):\n        \"\"\"\n        Tests serialization when activation is a Layer instance with config.\n        \"\"\"\n\n        # Conv1D layer with LeakyReLU layer (with config) as activation\n        leaky_activation = keras.layers.LeakyReLU(\n            negative_slope=0.15, name=\"my_leaky\"\n        )\n        layer_conv_leaky = keras.layers.Conv1D(\n            filters=2, kernel_size=3, activation=leaky_activation\n        )\n        # Build the layer\n        layer_conv_leaky.build(input_shape=(None, 10, 4))\n        _, restored_conv_leaky, _ = self.roundtrip(layer_conv_leaky)\n\n        # Verify the activation is correctly deserialized as LeakyReLU\n        self.assertIsInstance(\n            restored_conv_leaky.activation, keras.layers.LeakyReLU\n        )\n        # Verify configuration of the activation layer is preserved\n        self.assertEqual(restored_conv_leaky.activation.negative_slope, 0.15)\n        self.assertEqual(restored_conv_leaky.activation.name, \"my_leaky\")\n\n    def test_layer_string_as_activation(self):\n        \"\"\"Tests serialization when activation is a string.\"\"\"\n\n        layer_dense_relu_string = keras.layers.Dense(units=4, activation=\"relu\")\n        layer_dense_relu_string.build(input_shape=(None, 8))\n        _, restored_dense_relu_string, _ = self.roundtrip(\n            layer_dense_relu_string\n        )\n\n        # Verify the activation is correctly deserialized to the relu function\n        self.assertTrue(callable(restored_dense_relu_string.activation))\n        # Check 
if it resolves to the canonical keras activation function\n        self.assertEqual(\n            restored_dense_relu_string.activation, keras.activations.relu\n        )\n\n\n@keras.saving.register_keras_serializable()\nclass MyDense(keras.layers.Layer):\n    def __init__(\n        self,\n        units,\n        *,\n        kernel_regularizer=None,\n        kernel_initializer=None,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self._units = units\n        self._kernel_regularizer = kernel_regularizer\n        self._kernel_initializer = kernel_initializer\n\n    def get_config(self):\n        return dict(\n            units=self._units,\n            kernel_initializer=self._kernel_initializer,\n            kernel_regularizer=self._kernel_regularizer,\n            **super().get_config(),\n        )\n\n    def build(self, input_shape):\n        _, input_units = input_shape\n        self._kernel = self.add_weight(\n            name=\"kernel\",\n            shape=[input_units, self._units],\n            dtype=\"float32\",\n            regularizer=self._kernel_regularizer,\n            initializer=self._kernel_initializer,\n        )\n\n    def call(self, inputs):\n        return ops.matmul(inputs, self._kernel)\n\n\n@keras.saving.register_keras_serializable()\nclass MyWrapper(keras.layers.Layer):\n    def __init__(self, wrapped, **kwargs):\n        super().__init__(**kwargs)\n        self._wrapped = wrapped\n\n    def get_config(self):\n        return dict(wrapped=self._wrapped, **super().get_config())\n\n    @classmethod\n    def from_config(cls, config):\n        config[\"wrapped\"] = keras.saving.deserialize_keras_object(\n            config[\"wrapped\"]\n        )\n        return cls(**config)\n\n    def call(self, inputs):\n        return self._wrapped(inputs)\n"
  },
  {
    "path": "keras/src/testing/__init__.py",
    "content": "from keras.src.testing.test_case import TestCase\nfrom keras.src.testing.test_case import jax_uses_gpu\nfrom keras.src.testing.test_case import jax_uses_tpu\nfrom keras.src.testing.test_case import tensorflow_uses_gpu\nfrom keras.src.testing.test_case import tensorflow_uses_tpu\nfrom keras.src.testing.test_case import torch_uses_gpu\nfrom keras.src.testing.test_case import uses_gpu\nfrom keras.src.testing.test_case import uses_tpu\n"
  },
  {
    "path": "keras/src/testing/test_case.py",
    "content": "import json\nimport shutil\nimport tempfile\nfrom pathlib import Path\n\nimport numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import tree\nfrom keras.src import utils\nfrom keras.src.backend.common import is_float_dtype\nfrom keras.src.backend.common import standardize_dtype\nfrom keras.src.backend.common.global_state import clear_session\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.losses.loss import Loss\nfrom keras.src.models import Model\nfrom keras.src.utils import traceback_utils\n\n\nclass TestCase(parameterized.TestCase):\n    maxDiff = None\n\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n\n    def setUp(self):\n        # clear global state so that test cases are independent\n        # required for the jit enabled torch tests since dynamo has\n        # a global cache for guards, compiled fn, etc\n        clear_session(free_memory=False)\n        if traceback_utils.is_traceback_filtering_enabled():\n            traceback_utils.disable_traceback_filtering()\n\n    def get_temp_dir(self):\n        temp_dir = tempfile.mkdtemp()\n        self.addCleanup(lambda: shutil.rmtree(temp_dir))\n        return temp_dir\n\n    def assertAllClose(\n        self,\n        x1,\n        x2,\n        atol=1e-6,\n        rtol=1e-6,\n        tpu_atol=None,\n        tpu_rtol=None,\n        msg=None,\n    ):\n        if tpu_atol is not None and uses_tpu():\n            atol = tpu_atol\n        if tpu_rtol is not None and uses_tpu():\n            rtol = tpu_rtol\n        if not isinstance(x1, np.ndarray):\n            x1 = backend.convert_to_numpy(x1)\n        if not isinstance(x2, np.ndarray):\n            x2 = backend.convert_to_numpy(x2)\n        np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol, err_msg=msg)\n\n    def assertNotAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):\n        try:\n     
       self.assertAllClose(x1, x2, atol=atol, rtol=rtol, msg=msg)\n        except AssertionError:\n            return\n        msg = msg or \"\"\n        raise AssertionError(\n            f\"The two values are close at all elements. \\n{msg}.\\nValues: {x1}\"\n        )\n\n    def assertAlmostEqual(self, x1, x2, decimal=3, tpu_decimal=None, msg=None):\n        if tpu_decimal is not None and uses_tpu():\n            decimal = tpu_decimal\n        msg = msg or \"\"\n        if not isinstance(x1, np.ndarray):\n            x1 = backend.convert_to_numpy(x1)\n        if not isinstance(x2, np.ndarray):\n            x2 = backend.convert_to_numpy(x2)\n        np.testing.assert_almost_equal(x1, x2, decimal=decimal, err_msg=msg)\n\n    def assertAllEqual(self, x1, x2, msg=None):\n        self.assertEqual(len(x1), len(x2), msg=msg)\n        for e1, e2 in zip(x1, x2):\n            if isinstance(e1, (list, tuple)) or isinstance(e2, (list, tuple)):\n                self.assertAllEqual(e1, e2, msg=msg)\n            else:\n                e1 = backend.convert_to_numpy(e1)\n                e2 = backend.convert_to_numpy(e2)\n                self.assertEqual(e1, e2, msg=msg)\n\n    def assertLen(self, iterable, expected_len, msg=None):\n        self.assertEqual(len(iterable), expected_len, msg=msg)\n\n    def assertSparse(self, x, sparse=True):\n        if isinstance(x, KerasTensor):\n            self.assertEqual(x.sparse, sparse)\n        elif backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            if sparse:\n                self.assertIsInstance(x, tf.SparseTensor)\n            else:\n                self.assertNotIsInstance(x, tf.SparseTensor)\n        elif backend.backend() == \"jax\":\n            import jax.experimental.sparse as jax_sparse\n\n            if sparse:\n                self.assertIsInstance(x, jax_sparse.JAXSparse)\n            else:\n                self.assertNotIsInstance(x, jax_sparse.JAXSparse)\n        else:\n            
self.assertFalse(\n                sparse,\n                f\"Backend {backend.backend()} does not support sparse tensors\",\n            )\n\n    def assertRagged(self, x, ragged=True):\n        if isinstance(x, KerasTensor):\n            self.assertEqual(x.ragged, ragged)\n        elif backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            if ragged:\n                self.assertIsInstance(x, tf.RaggedTensor)\n            else:\n                self.assertNotIsInstance(x, tf.RaggedTensor)\n        else:\n            self.assertFalse(\n                ragged,\n                f\"Backend {backend.backend()} does not support ragged tensors\",\n            )\n\n    def assertDType(self, x, dtype, msg=None):\n        if hasattr(x, \"dtype\"):\n            x_dtype = backend.standardize_dtype(x.dtype)\n        else:\n            # If x is a python number\n            x_dtype = backend.standardize_dtype(type(x))\n        standardized_dtype = backend.standardize_dtype(dtype)\n        default_msg = (\n            \"The dtype of x does not match the expected one. 
\"\n            f\"Received: x.dtype={x_dtype} and dtype={dtype}\"\n        )\n        msg = msg or default_msg\n        self.assertEqual(x_dtype, standardized_dtype, msg=msg)\n\n    def assertFileExists(self, path):\n        if not Path(path).is_file():\n            raise AssertionError(f\"File {path} does not exist\")\n\n    def run_class_serialization_test(self, instance, custom_objects=None):\n        from keras.src.saving import custom_object_scope\n        from keras.src.saving import deserialize_keras_object\n        from keras.src.saving import serialize_keras_object\n\n        # get_config roundtrip\n        cls = instance.__class__\n        config = instance.get_config()\n        config_json = to_json_with_tuples(config)\n        ref_dir = dir(instance)[:]\n        with custom_object_scope(custom_objects):\n            revived_instance = cls.from_config(config)\n        revived_config = revived_instance.get_config()\n        revived_config_json = to_json_with_tuples(revived_config)\n        self.assertEqual(config_json, revived_config_json)\n        self.assertEqual(set(ref_dir), set(dir(revived_instance)))\n\n        # serialization roundtrip\n        serialized = serialize_keras_object(instance)\n        serialized_json = to_json_with_tuples(serialized)\n        with custom_object_scope(custom_objects):\n            revived_instance = deserialize_keras_object(\n                from_json_with_tuples(serialized_json)\n            )\n        revived_config = revived_instance.get_config()\n        revived_config_json = to_json_with_tuples(revived_config)\n        self.assertEqual(config_json, revived_config_json)\n        new_dir = dir(revived_instance)[:]\n        for lst in [ref_dir, new_dir]:\n            if \"__annotations__\" in lst:\n                lst.remove(\"__annotations__\")\n        self.assertEqual(set(ref_dir), set(new_dir))\n        return revived_instance\n\n    def run_layer_test(\n        self,\n        layer_cls,\n        init_kwargs,\n  
      input_shape=None,\n        input_dtype=None,\n        input_sparse=False,\n        input_ragged=False,\n        input_data=None,\n        call_kwargs=None,\n        expected_output_shape=None,\n        expected_output_dtype=None,\n        expected_output_sparse=False,\n        expected_output_ragged=False,\n        expected_output=None,\n        expected_num_trainable_weights=None,\n        expected_num_non_trainable_weights=None,\n        expected_num_non_trainable_variables=None,\n        expected_num_seed_generators=None,\n        expected_num_losses=None,\n        supports_masking=None,\n        expected_mask_shape=None,\n        custom_objects=None,\n        run_training_check=True,\n        run_mixed_precision_check=True,\n        assert_built_after_instantiation=False,\n        tpu_atol=None,\n        tpu_rtol=None,\n    ):\n        \"\"\"Run basic checks on a layer.\n\n        Args:\n            layer_cls: The class of the layer to test.\n            init_kwargs: Dict of arguments to be used to\n                instantiate the layer.\n            input_shape: Shape tuple (or list/dict of shape tuples)\n                to call the layer on.\n            input_dtype: Corresponding input dtype.\n            input_sparse: Whether the input is a sparse tensor (this requires\n                the backend to support sparse tensors).\n            input_ragged: Whether the input is a ragged tensor (this requires\n                the backend to support ragged tensors).\n            input_data: Tensor (or list/dict of tensors)\n                to call the layer on.\n            call_kwargs: Dict of arguments to use when calling the\n                layer (does not include the first input tensor argument)\n            expected_output_shape: Shape tuple\n                (or list/dict of shape tuples)\n                expected as output.\n            expected_output_dtype: dtype expected as output.\n            expected_output_sparse: Whether the output is expected 
to be sparse\n                (this requires the backend to support sparse tensors).\n            expected_output_ragged: Whether the output is expected to be ragged\n                (this requires the backend to support ragged tensors).\n            expected_output: Expected output tensor -- only\n                to be specified if input_data is provided.\n            expected_num_trainable_weights: Expected number\n                of trainable weights of the layer once built.\n            expected_num_non_trainable_weights: Expected number\n                of non-trainable weights of the layer once built.\n            expected_num_non_trainable_variables: Expected number\n                of non-trainable variables of the layer once built.\n            expected_num_seed_generators: Expected number of\n                SeedGenerator objects of the layer once built.\n            expected_num_losses: Expected number of loss tensors\n                produced when calling the layer.\n            supports_masking: If True, will check that the layer\n                supports masking.\n            expected_mask_shape: Expected mask shape tuple\n                returned by compute_mask() (only supports 1 shape).\n            custom_objects: Dict of any custom objects to be\n                considered during deserialization.\n            run_training_check: Whether to attempt to train the layer\n                (if an input shape or input data was provided).\n            run_mixed_precision_check: Whether to test the layer with a mixed\n                precision dtype policy.\n            assert_built_after_instantiation: Whether to assert `built=True`\n                after the layer's instantiation.\n        \"\"\"\n        if input_shape is not None and input_data is not None:\n            raise ValueError(\n                \"input_shape and input_data cannot be passed at the same time.\"\n            )\n        if expected_output_shape is not None and expected_output is not None:\n            raise ValueError(\n                \"expected_output_shape and expected_output cannot be passed \"\n             
   \"at the same time.\"\n            )\n        if expected_output is not None and input_data is None:\n            raise ValueError(\n                \"In order to use expected_output, input_data must be provided.\"\n            )\n        if expected_mask_shape is not None and supports_masking is not True:\n            raise ValueError(\n                \"In order to use expected_mask_shape, supports_masking \"\n                \"must be True.\"\n            )\n\n        init_kwargs = init_kwargs or {}\n        call_kwargs = call_kwargs or {}\n\n        if input_shape is not None and input_dtype is not None:\n            if isinstance(input_shape, tuple) and is_shape_tuple(\n                input_shape[0]\n            ):\n                self.assertIsInstance(input_dtype, tuple)\n                self.assertEqual(\n                    len(input_shape),\n                    len(input_dtype),\n                    msg=\"The number of input shapes and dtypes does not match\",\n                )\n            elif isinstance(input_shape, dict):\n                self.assertIsInstance(input_dtype, dict)\n                self.assertEqual(\n                    set(input_shape.keys()),\n                    set(input_dtype.keys()),\n                    msg=\"The number of input shapes and dtypes does not match\",\n                )\n            elif isinstance(input_shape, list):\n                self.assertIsInstance(input_dtype, list)\n                self.assertEqual(\n                    len(input_shape),\n                    len(input_dtype),\n                    msg=\"The number of input shapes and dtypes does not match\",\n                )\n            elif not isinstance(input_shape, tuple):\n                raise ValueError(\"The type of input_shape is not supported\")\n        if input_shape is not None and input_dtype is None:\n            input_dtype = tree.map_shape_structure(\n                lambda _: \"float32\", input_shape\n            )\n\n        # 
Estimate actual number of weights, variables, seed generators if\n        # expected ones are not set. When a layer uses composition, each\n        # sublayer should be built manually.\n        if input_data is not None or input_shape is not None:\n            if input_data is None:\n                input_data = create_eager_tensors(\n                    input_shape, input_dtype, input_sparse, input_ragged\n                )\n            layer = layer_cls(**init_kwargs)\n            if isinstance(input_data, dict):\n                layer(**input_data, **call_kwargs)\n            else:\n                layer(input_data, **call_kwargs)\n\n            if expected_num_trainable_weights is None:\n                expected_num_trainable_weights = len(layer.trainable_weights)\n            if expected_num_non_trainable_weights is None:\n                expected_num_non_trainable_weights = len(\n                    layer.non_trainable_weights\n                )\n            if expected_num_non_trainable_variables is None:\n                expected_num_non_trainable_variables = len(\n                    layer.non_trainable_variables\n                )\n            if expected_num_seed_generators is None:\n                expected_num_seed_generators = len(get_seed_generators(layer))\n\n        # Serialization test.\n        layer = layer_cls(**init_kwargs)\n        self.run_class_serialization_test(layer, custom_objects)\n\n        # Basic masking test.\n        if supports_masking is not None:\n            self.assertEqual(\n                layer.supports_masking,\n                supports_masking,\n                msg=\"Unexpected supports_masking value\",\n            )\n\n        def run_build_asserts(layer):\n            self.assertTrue(layer.built)\n            if expected_num_trainable_weights is not None:\n                self.assertLen(\n                    layer.trainable_weights,\n                    expected_num_trainable_weights,\n                    
msg=\"Unexpected number of trainable_weights\",\n                )\n            if expected_num_non_trainable_weights is not None:\n                self.assertLen(\n                    layer.non_trainable_weights,\n                    expected_num_non_trainable_weights,\n                    msg=\"Unexpected number of non_trainable_weights\",\n                )\n            if expected_num_non_trainable_variables is not None:\n                self.assertLen(\n                    layer.non_trainable_variables,\n                    expected_num_non_trainable_variables,\n                    msg=\"Unexpected number of non_trainable_variables\",\n                )\n            if expected_num_seed_generators is not None:\n                self.assertLen(\n                    get_seed_generators(layer),\n                    expected_num_seed_generators,\n                    msg=\"Unexpected number of seed_generators\",\n                )\n            if (\n                backend.backend() == \"torch\"\n                and expected_num_trainable_weights is not None\n                and expected_num_non_trainable_weights is not None\n                and expected_num_seed_generators is not None\n            ):\n                self.assertLen(\n                    layer.torch_params,\n                    expected_num_trainable_weights\n                    + expected_num_non_trainable_weights\n                    + expected_num_seed_generators,\n                    msg=\"Unexpected number of torch_params\",\n                )\n\n        def run_output_asserts(\n            layer, output, eager=False, tpu_atol=None, tpu_rtol=None\n        ):\n            if expected_output_shape is not None:\n\n                def verify_shape(expected_shape, x):\n                    shape = x.shape\n                    if len(shape) != len(expected_shape):\n                        return False\n                    for expected_dim, dim in zip(expected_shape, shape):\n                        if 
expected_dim is not None and expected_dim != dim:\n                            return False\n                    return True\n\n                shapes_match = tree.map_structure_up_to(\n                    output, verify_shape, expected_output_shape, output\n                )\n                self.assertTrue(\n                    all(tree.flatten(shapes_match)),\n                    msg=f\"Expected output shapes {expected_output_shape} but \"\n                    f\"received {tree.map_structure(lambda x: x.shape, output)}\",\n                )\n            if expected_output_dtype is not None:\n\n                def verify_dtype(expected_dtype, x):\n                    return expected_dtype == backend.standardize_dtype(x.dtype)\n\n                dtypes_match = tree.map_structure(\n                    verify_dtype, expected_output_dtype, output\n                )\n                self.assertTrue(\n                    all(tree.flatten(dtypes_match)),\n                    msg=f\"Expected output dtypes {expected_output_dtype} but \"\n                    f\"received {tree.map_structure(lambda x: x.dtype, output)}\",\n                )\n            if expected_output_sparse:\n                for x in tree.flatten(output):\n                    self.assertSparse(x)\n            if expected_output_ragged:\n                for x in tree.flatten(output):\n                    self.assertRagged(x)\n            if eager:\n                if expected_output is not None:\n                    self.assertEqual(type(expected_output), type(output))\n                    for ref_v, v in zip(\n                        tree.flatten(expected_output), tree.flatten(output)\n                    ):\n                        self.assertAllClose(\n                            ref_v,\n                            v,\n                            msg=\"Unexpected output value\",\n                            tpu_atol=tpu_atol,\n                            tpu_rtol=tpu_rtol,\n                        )\n 
               if expected_num_losses is not None:\n                    self.assertLen(layer.losses, expected_num_losses)\n\n        def run_training_step(layer, input_data, output_data):\n            class TestModel(Model):\n                def __init__(self, layer):\n                    super().__init__()\n                    self.layer = layer\n\n                def call(self, x, training=False):\n                    return self.layer(x, training=training)\n\n            model = TestModel(layer)\n\n            data = (input_data, output_data)\n            if backend.backend() == \"torch\":\n                data = tree.map_structure(backend.convert_to_numpy, data)\n\n            def data_generator():\n                while True:\n                    yield data\n\n            # Single op loss to avoid compilation issues with ragged / sparse.\n            class TestLoss(Loss):\n                def __call__(self, y_true, y_pred, sample_weight=None):\n                    return ops.sum(y_pred)\n\n            # test the \"default\" path for each backend by setting\n            # jit_compile=\"auto\".\n            # for tensorflow and jax backends auto is jitted\n            # Note that tensorflow cannot be jitted with sparse tensors\n            # for torch backend auto is eager\n            #\n            # NB: for torch, jit_compile=True turns on torchdynamo\n            #  which may not always succeed in tracing depending\n            #  on the model. 
Run your program with these env vars\n            #  to get debug traces of dynamo:\n            #    TORCH_LOGS=\"+dynamo\"\n            #    TORCHDYNAMO_VERBOSE=1\n            #    TORCHDYNAMO_REPORT_GUARD_FAILURES=1\n            jit_compile = \"auto\"\n            if backend.backend() == \"tensorflow\" and input_sparse:\n                jit_compile = False\n            model.compile(\n                optimizer=\"sgd\", loss=TestLoss(), jit_compile=jit_compile\n            )\n            model.fit(data_generator(), steps_per_epoch=1, verbose=0)\n\n        # Build test.\n        if input_data is not None or input_shape is not None:\n            if input_shape is None:\n                build_shape = tree.map_structure(\n                    lambda x: ops.shape(x), input_data\n                )\n            else:\n                build_shape = input_shape\n            layer = layer_cls(**init_kwargs)\n            if isinstance(build_shape, dict):\n                layer.build(**build_shape)\n            else:\n                layer.build(build_shape)\n            run_build_asserts(layer)\n\n            # Symbolic call test.\n            if input_shape is None:\n                keras_tensor_inputs = tree.map_structure(\n                    lambda x: create_keras_tensors(\n                        ops.shape(x), x.dtype, input_sparse, input_ragged\n                    ),\n                    input_data,\n                )\n            else:\n                keras_tensor_inputs = create_keras_tensors(\n                    input_shape, input_dtype, input_sparse, input_ragged\n                )\n            layer = layer_cls(**init_kwargs)\n            if isinstance(keras_tensor_inputs, dict):\n                keras_tensor_outputs = layer(\n                    **keras_tensor_inputs, **call_kwargs\n                )\n            else:\n                keras_tensor_outputs = layer(keras_tensor_inputs, **call_kwargs)\n            run_build_asserts(layer)\n            
run_output_asserts(layer, keras_tensor_outputs, eager=False)\n\n            if expected_mask_shape is not None:\n                output_mask = layer.compute_mask(keras_tensor_inputs)\n                self.assertEqual(expected_mask_shape, output_mask.shape)\n\n            # Stateless layers should be built after instantiation.\n            if assert_built_after_instantiation:\n                layer = layer_cls(**init_kwargs)\n                self.assertTrue(\n                    layer.built,\n                    msg=(\n                        f\"{type(layer)} is stateless, so it should be built \"\n                        \"after instantiation.\"\n                    ),\n                )\n\n                # Ensure that the subclass layer doesn't mark itself as built\n                # when `build` is overridden.\n\n                class ModifiedBuildLayer(layer_cls):\n                    def build(self, *args, **kwargs):\n                        pass\n\n                layer = ModifiedBuildLayer(**init_kwargs)\n                self.assertFalse(\n                    layer.built,\n                    msg=(\n                        f\"The `build` of {type(layer)} is overridden, so it \"\n                        \"should not be built after instantiation.\"\n                    ),\n                )\n\n        # Eager call test and compiled training test.\n        if input_data is not None or input_shape is not None:\n            if input_data is None:\n                input_data = create_eager_tensors(\n                    input_shape, input_dtype, input_sparse, input_ragged\n                )\n            layer = layer_cls(**init_kwargs)\n            if isinstance(input_data, dict):\n                output_data = layer(**input_data, **call_kwargs)\n            else:\n                output_data = layer(input_data, **call_kwargs)\n            run_output_asserts(\n                layer,\n                output_data,\n                eager=True,\n                
tpu_atol=tpu_atol,\n                tpu_rtol=tpu_rtol,\n            )\n\n            if run_training_check:\n                run_training_step(layer, input_data, output_data)\n\n            # Never test mixed precision on torch CPU. Torch lacks support.\n            if run_mixed_precision_check and backend.backend() == \"torch\":\n                import torch\n\n                run_mixed_precision_check = torch.cuda.is_available()\n\n            if run_mixed_precision_check:\n                layer = layer_cls(**{**init_kwargs, \"dtype\": \"mixed_float16\"})\n                input_spec = tree.map_structure(\n                    lambda spec: KerasTensor(\n                        spec.shape,\n                        dtype=(\n                            layer.compute_dtype\n                            if layer.autocast\n                            and backend.is_float_dtype(spec.dtype)\n                            else spec.dtype\n                        ),\n                    ),\n                    keras_tensor_inputs,\n                )\n                if isinstance(input_data, dict):\n                    output_data = layer(**input_data, **call_kwargs)\n                    output_spec = layer.compute_output_spec(**input_spec)\n                else:\n                    output_data = layer(input_data, **call_kwargs)\n                    output_spec = layer.compute_output_spec(input_spec)\n                for tensor, spec in zip(\n                    tree.flatten(output_data), tree.flatten(output_spec)\n                ):\n                    dtype = standardize_dtype(tensor.dtype)\n                    self.assertEqual(\n                        dtype,\n                        spec.dtype,\n                        f\"expected output dtype {spec.dtype}, got {dtype}\",\n                    )\n                for weight in layer.weights:\n                    dtype = standardize_dtype(weight.dtype)\n                    if is_float_dtype(dtype):\n                        
self.assertEqual(dtype, \"float32\")\n\n\ndef _jax_uses(device_type):\n    import jax\n\n    return jax.default_backend() == device_type\n\n\ndef _tensorflow_uses(device_type):\n    import tensorflow as tf\n\n    return len(tf.config.list_physical_devices(device_type.upper())) > 0\n\n\ndef _torch_uses(device_type):\n    if device_type == \"gpu\":\n        from keras.src.backend.torch.core import get_device\n\n        return get_device() == \"cuda\"\n    return device_type == \"cpu\"\n\n\ndef uses_gpu():\n    if not hasattr(uses_gpu, \"_value\"):\n        if backend.backend() == \"tensorflow\":\n            uses_gpu._value = _tensorflow_uses(\"gpu\")\n        elif backend.backend() == \"jax\":\n            uses_gpu._value = _jax_uses(\"gpu\")\n        elif backend.backend() == \"torch\":\n            uses_gpu._value = _torch_uses(\"gpu\")\n        else:\n            uses_gpu._value = False\n    return uses_gpu._value\n\n\ndef uses_tpu():\n    if not hasattr(uses_tpu, \"_value\"):\n        if backend.backend() == \"tensorflow\":\n            uses_tpu._value = _tensorflow_uses(\"tpu\")\n        elif backend.backend() == \"jax\":\n            uses_tpu._value = _jax_uses(\"tpu\")\n        else:\n            uses_tpu._value = False\n    return uses_tpu._value\n\n\ndef jax_uses_gpu():\n    return backend.backend() == \"jax\" and uses_gpu()\n\n\ndef jax_uses_tpu():\n    return backend.backend() == \"jax\" and uses_tpu()\n\n\ndef tensorflow_uses_gpu():\n    return backend.backend() == \"tensorflow\" and uses_gpu()\n\n\ndef tensorflow_uses_tpu():\n    return backend.backend() == \"tensorflow\" and uses_tpu()\n\n\ndef torch_uses_gpu():\n    return backend.backend() == \"torch\" and uses_gpu()\n\n\ndef create_keras_tensors(input_shape, dtype, sparse, ragged):\n    if isinstance(input_shape, dict):\n        return {\n            utils.removesuffix(k, \"_shape\"): KerasTensor(\n                v, dtype=dtype[k], sparse=sparse, ragged=ragged\n            )\n            for k, v 
in input_shape.items()\n        }\n    return map_shape_dtype_structure(\n        lambda shape, dt: KerasTensor(\n            shape, dtype=dt, sparse=sparse, ragged=ragged\n        ),\n        input_shape,\n        dtype,\n    )\n\n\ndef create_eager_tensors(input_shape, dtype, sparse, ragged):\n    from keras.src.backend import random\n\n    if set(tree.flatten(dtype)).difference(\n        [\n            \"float16\",\n            \"float32\",\n            \"float64\",\n            \"int8\",\n            \"uint8\",\n            \"int16\",\n            \"uint16\",\n            \"int32\",\n            \"uint32\",\n            \"int64\",\n            \"uint64\",\n        ]\n    ):\n        raise ValueError(\n            \"dtype must be a standard float or int dtype. \"\n            f\"Received: dtype={dtype}\"\n        )\n\n    if sparse:\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            def create_fn(shape, dt):\n                rng = np.random.default_rng(0)\n                x = (4 * rng.standard_normal(shape)).astype(dt)\n                x = np.multiply(x, rng.random(shape) < 0.7)\n                return tf.sparse.from_dense(x)\n\n        elif backend.backend() == \"jax\":\n            import jax.experimental.sparse as jax_sparse\n\n            def create_fn(shape, dt):\n                rng = np.random.default_rng(0)\n                x = (4 * rng.standard_normal(shape)).astype(dt)\n                x = np.multiply(x, rng.random(shape) < 0.7)\n                return jax_sparse.BCOO.fromdense(x, n_batch=1)\n\n        else:\n            raise ValueError(\n                f\"Sparse is unsupported with backend {backend.backend()}\"\n            )\n\n    elif ragged:\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            def create_fn(shape, dt):\n                rng = np.random.default_rng(0)\n                x = (4 * rng.standard_normal(shape)).astype(dt)\n                x 
= np.multiply(x, rng.random(shape) < 0.7)\n                return tf.RaggedTensor.from_tensor(x, padding=0)\n\n        else:\n            raise ValueError(\n                f\"Ragged is unsupported with backend {backend.backend()}\"\n            )\n\n    else:\n\n        def create_fn(shape, dt):\n            return ops.cast(\n                random.uniform(shape, dtype=\"float32\") * 3, dtype=dt\n            )\n\n    if isinstance(input_shape, dict):\n        return {\n            utils.removesuffix(k, \"_shape\"): create_fn(v, dtype[k])\n            for k, v in input_shape.items()\n        }\n    return map_shape_dtype_structure(create_fn, input_shape, dtype)\n\n\ndef is_shape_tuple(x):\n    return isinstance(x, (list, tuple)) and all(\n        isinstance(e, (int, type(None))) for e in x\n    )\n\n\ndef map_shape_dtype_structure(fn, shape, dtype):\n    \"\"\"Variant of tree.map_structure that operates on shape tuples.\"\"\"\n    if is_shape_tuple(shape):\n        return fn(tuple(shape), dtype)\n    if isinstance(shape, list):\n        return [\n            map_shape_dtype_structure(fn, s, d) for s, d in zip(shape, dtype)\n        ]\n    if isinstance(shape, tuple):\n        return tuple(\n            map_shape_dtype_structure(fn, s, d) for s, d in zip(shape, dtype)\n        )\n    if isinstance(shape, dict):\n        return {\n            k: map_shape_dtype_structure(fn, v, dtype[k])\n            for k, v in shape.items()\n        }\n    else:\n        raise ValueError(\n            f\"Cannot map function to unknown objects {shape} and {dtype}\"\n        )\n\n\ndef get_seed_generators(layer):\n    \"\"\"Get a List of all seed generators in the layer recursively.\"\"\"\n    seed_generators = []\n    seen_ids = set()\n    for sublayer in layer._flatten_layers(True, True):\n        for sg in sublayer._seed_generators:\n            if id(sg) not in seen_ids:\n                seed_generators.append(sg)\n                seen_ids.add(id(sg))\n    return 
seed_generators\n\n\ndef to_json_with_tuples(value):\n    def _tuple_encode(obj):\n        if isinstance(obj, tuple):\n            return {\"__class__\": \"tuple\", \"__value__\": list(obj)}\n        if isinstance(obj, list):\n            return [_tuple_encode(e) for e in obj]\n        if isinstance(obj, dict):\n            return {key: _tuple_encode(value) for key, value in obj.items()}\n        return obj\n\n    class _PreserveTupleJsonEncoder(json.JSONEncoder):\n        def encode(self, obj):\n            obj = _tuple_encode(obj)\n            return super().encode(obj)\n\n    return _PreserveTupleJsonEncoder(sort_keys=True, indent=4).encode(value)\n\n\ndef from_json_with_tuples(value):\n    def _tuple_decode(obj):\n        if not isinstance(obj, dict):\n            return obj\n        if \"__class__\" not in obj or \"__value__\" not in obj:\n            return obj\n        return tuple(obj[\"__value__\"])\n\n    return json.loads(value, object_hook=_tuple_decode)\n"
  },
  {
    "path": "keras/src/testing/test_utils.py",
    "content": "import numpy as np\n\n\ndef get_test_data(\n    train_samples, test_samples, input_shape, num_classes, random_seed=None\n):\n    \"\"\"Generates balanced, stratified synthetic test data to train a model on.\n\n    Args:\n        train_samples: Integer, how many training samples to generate.\n        test_samples: Integer, how many test samples to generate.\n        input_shape: Tuple of integers, shape of the inputs.\n        num_classes: Integer, number of classes for the data and targets.\n        random_seed: Integer, random seed used by Numpy to generate data.\n\n    Returns:\n        A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.\n    \"\"\"\n    np.random.seed(random_seed)\n\n    # Total samples\n    total_samples = train_samples + test_samples\n\n    # Ensure that we generate a balanced dataset\n    samples_per_class = total_samples // num_classes\n    y = np.array(\n        [i for i in range(num_classes) for _ in range(samples_per_class)],\n        dtype=np.int32,\n    )\n\n    # Generate extra samples in a deterministic manner\n    extra_samples = total_samples - len(y)\n    y_extra = np.array(\n        [i % num_classes for i in range(extra_samples)], dtype=np.int64\n    )\n    y = np.concatenate([y, y_extra])\n\n    # Generate data\n    templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)\n    x = np.zeros((total_samples,) + input_shape, dtype=np.float32)\n    for i in range(total_samples):\n        x[i] = templates[y[i]] + np.random.normal(\n            loc=0, scale=1.0, size=input_shape\n        )\n\n    # Shuffle the entire dataset to ensure randomness based on seed\n    indices = np.arange(total_samples)\n    np.random.shuffle(indices)\n    x, y = x[indices], y[indices]\n\n    # Stratified Shuffle Split\n    x_train, y_train, x_test, y_test = [], [], [], []\n    for cls in range(num_classes):\n        cls_indices = np.where(y == cls)[0]\n        np.random.shuffle(cls_indices)\n        
train_count = int(train_samples / num_classes)\n\n        x_train.extend(x[cls_indices[:train_count]])\n        y_train.extend(y[cls_indices[:train_count]])\n\n        x_test.extend(x[cls_indices[train_count:]])\n        y_test.extend(y[cls_indices[train_count:]])\n\n    # Convert to numpy arrays\n    x_train, y_train = np.array(x_train), np.array(y_train)\n    x_test, y_test = np.array(x_test), np.array(y_test)\n\n    # Shuffle training and test sets after stratified split\n    train_indices = np.arange(len(x_train))\n    test_indices = np.arange(len(x_test))\n    np.random.shuffle(train_indices)\n    np.random.shuffle(test_indices)\n\n    x_train, y_train = x_train[train_indices], y_train[train_indices]\n    x_test, y_test = x_test[test_indices], y_test[test_indices]\n\n    return (x_train, y_train), (x_test, y_test)\n\n\ndef named_product(*args, **kwargs):\n    \"\"\"Utility to generate the Cartesian product of parameter values and\n    generate test case names for each combination.\n\n    The result of this function is to be used with the\n    `@parameterized.named_parameters` decorator. 
It is a replacement for\n    `@parameterized.product` which adds explicit test case names.\n\n    For example, this code:\n    ```\n    class NamedExample(parameterized.TestCase):\n        @parameterized.named_parameters(\n            named_product(\n                [\n                    {'testcase_name': 'negative', 'x': -1},\n                    {'testcase_name': 'positive', 'x': 1},\n                    {'testcase_name': 'zero', 'x': 0},\n                ],\n                numeral_type=[float, int],\n            )\n        )\n        def test_conversion(self, x, numeral_type):\n            self.assertEqual(numeral_type(x), x)\n    ```\n    produces six tests (note that absl will reorder them by name):\n    - `NamedExample::test_conversion_negative_float`\n    - `NamedExample::test_conversion_positive_float`\n    - `NamedExample::test_conversion_zero_float`\n    - `NamedExample::test_conversion_negative_int`\n    - `NamedExample::test_conversion_positive_int`\n    - `NamedExample::test_conversion_zero_int`\n\n    This function is also useful in the case where there is no product to\n    generate test case names for one argument:\n    ```\n    @parameterized.named_parameters(named_product(numeral_type=[float, int]))\n    ```\n\n    Args:\n        *args: Each positional parameter is a sequence of keyword arg dicts.\n            Every test case generated will include exactly one dict from each\n            positional parameter. These will then be merged to form an overall\n            list of arguments for the test case. Each dict must contain a\n            `\"testcase_name\"` key whose value is combined with others to\n            generate the test case name.\n        **kwargs: A mapping of parameter names and their possible values.\n            Possible values should be given as either a list or a tuple. 
A string\n            representation of each value is used to generate the test case name.\n\n    Returns:\n        A list of maps for the test parameter combinations to pass to\n        `@parameterized.named_parameters`.\n    \"\"\"\n\n    def value_to_str(value):\n        if hasattr(value, \"__name__\"):\n            return value.__name__.lower()\n        return str(value).lower()\n\n    # Convert the keyword arguments into the same dict format as the args\n    all_test_dicts = args + tuple(\n        tuple({\"testcase_name\": value_to_str(v), key: v} for v in values)\n        for key, values in kwargs.items()\n    )\n\n    # The current list of tests; start with one empty test\n    tests = [{}]\n    for test_dicts in all_test_dicts:\n        new_tests = []\n        for test_dict in test_dicts:\n            for test in tests:\n                # Augment the testcase name by appending\n                testcase_name = test.get(\"testcase_name\", \"\")\n                testcase_name += \"_\" if testcase_name else \"\"\n                testcase_name += test_dict[\"testcase_name\"]\n                new_test = test.copy()\n                # Augment the test by adding all the parameters\n                new_test.update(test_dict)\n                new_test[\"testcase_name\"] = testcase_name\n                new_tests.append(new_test)\n        # Overwrite the list of tests with the product obtained so far\n        tests = new_tests\n\n    return tests\n"
  },
  {
    "path": "keras/src/testing/test_utils_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src.testing import test_case\nfrom keras.src.testing import test_utils\n\n\nclass GetTestDataTest(test_case.TestCase):\n    def setUp(self):\n        self.train_samples = 100\n        self.test_samples = 50\n        self.input_shape = (28, 28)\n        self.num_classes = 10\n\n    def test_labels_within_range(self):\n        \"\"\"Check if labels are within valid range.\"\"\"\n        (_, y_train), (_, y_test) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            self.num_classes,\n        )\n        self.assertTrue(np.all(y_train < self.num_classes))\n        self.assertTrue(np.all(y_train >= 0))\n        self.assertTrue(np.all(y_test < self.num_classes))\n        self.assertTrue(np.all(y_test >= 0))\n\n    def test_edge_cases_for_zero_samples(self):\n        \"\"\"Test when train or test samples are zero.\"\"\"\n        (x_train, _), (x_test, _) = test_utils.get_test_data(\n            0, self.test_samples, self.input_shape, self.num_classes\n        )\n        self.assertEqual(len(x_train), 0)\n\n        (x_train, _), (x_test, _) = test_utils.get_test_data(\n            self.train_samples, 0, self.input_shape, self.num_classes\n        )\n        self.assertEqual(len(x_test), 0)\n\n    def test_get_test_data_returns_correct_number_of_samples(self):\n        \"\"\"Check if returned samples count is correct.\"\"\"\n        (x_train, y_train), (x_test, y_test) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            self.num_classes,\n        )\n        self.assertEqual(len(x_train), self.train_samples)\n        self.assertEqual(len(y_train), self.train_samples)\n        self.assertEqual(len(x_test), self.test_samples)\n        self.assertEqual(len(y_test), self.test_samples)\n\n    def 
test_get_test_data_returns_correct_shape_of_data(self):\n        \"\"\"Check if returned data shape is correct.\"\"\"\n        (x_train, y_train), (x_test, y_test) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            self.num_classes,\n        )\n        self.assertEqual(\n            x_train.shape, (self.train_samples,) + self.input_shape\n        )\n        self.assertEqual(y_train.shape, (self.train_samples,))\n        self.assertEqual(x_test.shape, (self.test_samples,) + self.input_shape)\n        self.assertEqual(y_test.shape, (self.test_samples,))\n\n    def test_get_test_data_returns_different_data_for_different_seeds(self):\n        \"\"\"Test variability with different seeds.\"\"\"\n        (x_train_1, y_train_1), (x_test_1, y_test_1) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            self.num_classes,\n            random_seed=1,\n        )\n        (x_train_2, y_train_2), (x_test_2, y_test_2) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            self.num_classes,\n            random_seed=2,\n        )\n        self.assertFalse(np.array_equal(x_train_1, x_train_2))\n        self.assertFalse(np.array_equal(y_train_1, y_train_2))\n        self.assertFalse(np.array_equal(x_test_1, x_test_2))\n        self.assertFalse(np.array_equal(y_test_1, y_test_2))\n\n    def test_get_test_data_returns_consistent_data_for_same_seed(self):\n        \"\"\"Test consistency with the same seed.\"\"\"\n        (x_train_1, y_train_1), (x_test_1, y_test_1) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            self.num_classes,\n            random_seed=1,\n        )\n        (x_train_2, y_train_2), (x_test_2, y_test_2) = test_utils.get_test_data(\n    
        self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            self.num_classes,\n            random_seed=1,\n        )\n        self.assertTrue(np.array_equal(x_train_1, x_train_2))\n        self.assertTrue(np.array_equal(y_train_1, y_train_2))\n        self.assertTrue(np.array_equal(x_test_1, x_test_2))\n        self.assertTrue(np.array_equal(y_test_1, y_test_2))\n\n    def test_input_shape_variations(self):\n        \"\"\"Check function for different input shapes.\"\"\"\n        input_shape_3d = (28, 28, 3)\n        (x_train_3d, _), (_, _) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            input_shape_3d,\n            self.num_classes,\n        )\n        self.assertEqual(\n            x_train_3d.shape, (self.train_samples,) + input_shape_3d\n        )\n\n    def test_all_classes_represented(self):\n        \"\"\"Ensure all classes are represented in the data.\"\"\"\n        (_, y_train), (_, y_test) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            self.num_classes,\n        )\n        self.assertEqual(len(np.unique(y_train)), self.num_classes)\n        self.assertEqual(len(np.unique(y_test)), self.num_classes)\n\n    def test_data_type(self):\n        \"\"\"Validate the type of the generated data.\"\"\"\n        (x_train, _), (x_test, _) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            self.num_classes,\n        )\n        self.assertEqual(x_train.dtype, np.float32)\n        self.assertEqual(x_test.dtype, np.float32)\n\n    def test_label_type(self):\n        \"\"\"Validate label type of the generated labels.\"\"\"\n        (_, y_train), (_, y_test) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            
self.num_classes,\n        )\n        self.assertEqual(y_train.dtype, np.int64)\n        self.assertEqual(y_test.dtype, np.int64)\n\n\nclass ClassDistributionTests(test_case.TestCase):\n    def setUp(self):\n        self.train_samples = 100\n        self.test_samples = 50\n        self.input_shape = (28, 28)\n        self.num_classes = 10\n\n    def test_equal_class_distribution(self):\n        \"\"\"Verify equal class distribution in train and test sets.\"\"\"\n        (_, y_train), (_, y_test) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            self.num_classes,\n        )\n        _, counts_train = np.unique(y_train, return_counts=True)\n        _, counts_test = np.unique(y_test, return_counts=True)\n\n        self.assertTrue(\n            np.all(counts_train == self.train_samples // self.num_classes)\n        )\n        self.assertTrue(\n            np.all(counts_test == self.test_samples // self.num_classes)\n        )\n\n    def test_uneven_samples_class_distribution(self):\n        \"\"\"Check class distribution with uneven samples.\"\"\"\n        train_samples = 103\n        test_samples = 52\n        (_, y_train), (_, y_test) = test_utils.get_test_data(\n            train_samples,\n            test_samples,\n            self.input_shape,\n            self.num_classes,\n        )\n        _, counts_train = np.unique(y_train, return_counts=True)\n        _, counts_test = np.unique(y_test, return_counts=True)\n\n        self.assertTrue(np.max(counts_train) - np.min(counts_train) <= 1)\n        self.assertTrue(np.max(counts_test) - np.min(counts_test) <= 1)\n\n    def test_randomness_in_class_distribution(self):\n        \"\"\"Ensure class distribution isn't too deterministic.\"\"\"\n        (_, y_train_1), (_, y_test_1) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            self.num_classes,\n   
     )\n        (_, y_train_2), (_, y_test_2) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            self.num_classes,\n        )\n        self.assertFalse(np.array_equal(y_train_1, y_train_2))\n        self.assertFalse(np.array_equal(y_test_1, y_test_2))\n\n    def test_large_number_of_classes(self):\n        \"\"\"Validate function with a large number of classes.\"\"\"\n        num_classes = 150\n        train_samples = (\n            num_classes * 10\n        )  # 10 samples for each class in training\n        test_samples = num_classes * 5  # 5 samples for each class in testing\n        (_, y_train), (_, y_test) = test_utils.get_test_data(\n            train_samples,\n            test_samples,\n            self.input_shape,\n            num_classes,\n        )\n        self.assertEqual(len(np.unique(y_train)), num_classes)\n        self.assertEqual(len(np.unique(y_test)), num_classes)\n\n    def test_single_class(self):\n        \"\"\"Test with a single class.\"\"\"\n        num_classes = 1\n        (_, y_train), (_, y_test) = test_utils.get_test_data(\n            self.train_samples,\n            self.test_samples,\n            self.input_shape,\n            num_classes,\n        )\n        self.assertTrue(np.all(y_train == 0))\n        self.assertTrue(np.all(y_test == 0))\n\n\nclass NamedProductTest(parameterized.TestCase):\n    def test_test_cases(self):\n        all_tests = test_utils.named_product(\n            [\n                {\"testcase_name\": \"negative\", \"x\": -1},\n                {\"testcase_name\": \"positive\", \"x\": 1},\n                {\"testcase_name\": \"zero\", \"x\": 0},\n            ],\n            numeral_type=[float, int],\n        )\n        names = [test[\"testcase_name\"] for test in all_tests]\n        self.assertListEqual(\n            names,\n            [\n                \"negative_float\",\n                \"positive_float\",\n                
\"zero_float\",\n                \"negative_int\",\n                \"positive_int\",\n                \"zero_int\",\n            ],\n        )\n\n    def test_test_cases_no_product(self):\n        all_tests = test_utils.named_product(numeral_type=[float, int])\n        names = [test[\"testcase_name\"] for test in all_tests]\n        self.assertListEqual(names, [\"float\", \"int\"])\n\n    @parameterized.named_parameters(\n        test_utils.named_product(\n            [\n                {\"testcase_name\": \"negative\", \"x\": -1},\n                {\"testcase_name\": \"positive\", \"x\": 1},\n                {\"testcase_name\": \"zero\", \"x\": 0},\n            ],\n            numeral_type=[float, int],\n        )\n    )\n    def test_via_decorator(self, x, numeral_type):\n        self.assertIn(x, (-1, 1, 0))\n        self.assertIn(numeral_type, (float, int))\n\n    @parameterized.named_parameters(\n        test_utils.named_product(numeral_type=[float, int])\n    )\n    def test_via_decorator_no_product(self, numeral_type):\n        self.assertIn(numeral_type, (float, int))\n"
  },
  {
    "path": "keras/src/trainers/__init__.py",
    "content": ""
  },
  {
    "path": "keras/src/trainers/compile_utils.py",
    "content": "from collections import OrderedDict\nfrom collections import namedtuple\n\nfrom keras.src import losses as losses_module\nfrom keras.src import metrics as metrics_module\nfrom keras.src import ops\nfrom keras.src import tree\nfrom keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.losses import loss as loss_module\nfrom keras.src.utils.naming import get_object_name\nfrom keras.src.utils.tracking import Tracker\nfrom keras.src.utils.tracking import no_automatic_dependency_tracking\n\n\nclass MetricsList(metrics_module.Metric):\n    def __init__(self, metrics, name=\"metrics_list\", output_name=None):\n        super().__init__(name=name)\n        self.metrics = metrics\n        self.output_name = output_name\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        for m in self.metrics:\n            m.update_state(y_true, y_pred, sample_weight=sample_weight)\n\n    def reset_state(self):\n        for m in self.metrics:\n            m.reset_state()\n\n    def get_result(self):\n        return {m.name: m.result() for m in self.metrics}\n\n    def get_config(self):\n        raise NotImplementedError\n\n    @classmethod\n    def from_config(cls, config):\n        raise NotImplementedError\n\n\ndef is_binary_or_sparse_categorical(y_true, y_pred):\n    y_t_rank = len(y_true.shape)\n    y_p_rank = len(y_pred.shape)\n    y_t_last_dim = y_true.shape[-1]\n    y_p_last_dim = y_pred.shape[-1]\n\n    is_binary = y_p_last_dim == 1\n    is_sparse_categorical = (\n        y_t_rank < y_p_rank or y_t_last_dim == 1 and y_p_last_dim > 1\n    )\n    return is_binary, is_sparse_categorical\n\n\ndef get_metric(identifier, y_true, y_pred):\n    if identifier is None:\n        raise ValueError(\"Expected metric, received `None`\")\n\n    # Convenience feature for selecting b/t binary, categorical,\n    # and sparse categorical.\n    if str(identifier).lower() not in [\"accuracy\", \"acc\"]:\n        metric_obj = 
metrics_module.get(identifier)\n    else:\n        is_binary, is_sparse_categorical = is_binary_or_sparse_categorical(\n            y_true, y_pred\n        )\n        if is_binary:\n            metric_obj = metrics_module.BinaryAccuracy(name=str(identifier))\n        elif is_sparse_categorical:\n            metric_obj = metrics_module.SparseCategoricalAccuracy(\n                name=str(identifier)\n            )\n        else:\n            metric_obj = metrics_module.CategoricalAccuracy(\n                name=str(identifier)\n            )\n\n    if isinstance(identifier, str):\n        metric_name = identifier\n    else:\n        metric_name = get_object_name(metric_obj)\n\n    if not isinstance(metric_obj, metrics_module.Metric):\n        metric_obj = metrics_module.MeanMetricWrapper(metric_obj)\n\n    metric_obj.name = metric_name\n    return metric_obj\n\n\ndef get_metrics_list(metrics, y_true, y_pred, output_name=None):\n    if metrics is None:\n        return None\n    if isinstance(metrics, (list, tuple)):\n        return MetricsList(\n            [get_metric(m, y_true, y_pred) for m in metrics],\n            output_name=output_name,\n        )\n    else:\n        return MetricsList(\n            [get_metric(metrics, y_true, y_pred)], output_name=output_name\n        )\n\n\ndef get_loss(identifier, y_true, y_pred):\n    if identifier is None:\n        return None  # Ok to have no loss for an output.\n\n    # Convenience feature for selecting b/t binary, categorical,\n    # and sparse categorical.\n    if str(identifier).lower() not in [\"crossentropy\", \"ce\"]:\n        loss_obj = losses_module.get(identifier)\n    else:\n        is_binary, is_sparse_categorical = is_binary_or_sparse_categorical(\n            y_true, y_pred\n        )\n        if is_binary:\n            loss_obj = losses_module.binary_crossentropy\n        elif is_sparse_categorical:\n            loss_obj = losses_module.sparse_categorical_crossentropy\n        else:\n            loss_obj 
= losses_module.categorical_crossentropy\n\n    if not isinstance(loss_obj, losses_module.Loss):\n        if isinstance(identifier, str):\n            loss_name = identifier\n        else:\n            loss_name = get_object_name(loss_obj)\n        loss_obj = losses_module.LossFunctionWrapper(loss_obj, name=loss_name)\n    return loss_obj\n\n\nclass CompileMetrics(metrics_module.Metric):\n    @no_automatic_dependency_tracking\n    def __init__(\n        self,\n        metrics,\n        weighted_metrics,\n        name=\"compile_metric\",\n        output_names=None,\n    ):\n        super().__init__(name=name)\n        if metrics and not isinstance(metrics, (list, tuple, dict)):\n            raise ValueError(\n                \"Expected `metrics` argument to be a list, tuple, or dict. \"\n                f\"Received instead: metrics={metrics} of type {type(metrics)}\"\n            )\n        if weighted_metrics and not isinstance(\n            weighted_metrics, (list, tuple, dict)\n        ):\n            raise ValueError(\n                \"Expected `weighted_metrics` argument to be a list, tuple, or \"\n                f\"dict. 
Received instead: weighted_metrics={weighted_metrics} \"\n                f\"of type {type(weighted_metrics)}\"\n            )\n        self._user_metrics = metrics\n        self._user_weighted_metrics = weighted_metrics\n        self.built = False\n        self.name = \"compile_metrics\"\n        self.output_names = output_names\n\n    @property\n    def metrics(self):\n        if not self.built:\n            return []\n        metrics = []\n        for m in self._flat_metrics + self._flat_weighted_metrics:\n            if isinstance(m, MetricsList):\n                metrics.extend(m.metrics)\n            elif m is not None:\n                metrics.append(m)\n        return metrics\n\n    @property\n    def variables(self):\n        # Avoiding relying on implicit tracking since\n        # CompileMetrics may be instantiated or built in a no tracking scope.\n        if not self.built:\n            return []\n        vars = []\n        for m in self.metrics:\n            if m is not None:\n                vars.extend(m.variables)\n        return vars\n\n    def build(self, y_true, y_pred):\n        self._flat_metrics = self._build_metrics_set(\n            self._user_metrics,\n            y_true,\n            y_pred,\n            argument_name=\"metrics\",\n        )\n        self._flat_weighted_metrics = self._build_metrics_set(\n            self._user_weighted_metrics,\n            y_true,\n            y_pred,\n            argument_name=\"weighted_metrics\",\n        )\n        self.built = True\n\n    def _build_metrics_set(self, metrics, y_true, y_pred, argument_name):\n        num_outputs = len(tree.flatten(y_pred))\n\n        if not metrics:\n            return [None] * num_outputs\n\n        output_names = None\n        flat_metrics = None\n\n        if num_outputs == 1:\n            # Single output, all metrics apply to it, don't use `output_names`.\n            flat_metrics = [tree.flatten(metrics)]\n            output_names = [None]\n        elif (\n       
     isinstance(metrics, (list, tuple))\n            and self.output_names\n            and len(metrics) == num_outputs\n        ):\n            # `metrics` is a list with one entry per output.\n            # Use the output names to name the metrics.\n            output_names = self.output_names\n            flat_metrics = metrics\n\n        elif isinstance(metrics, dict) and len(metrics) <= num_outputs:\n            # `metrics` is a dictionary with zero or one entry per output.\n            keys = set(metrics.keys())\n            if (\n                isinstance(y_pred, dict)\n                and len(y_pred) == num_outputs\n                and keys <= set(y_pred.keys())\n            ):\n                # If the keys match the output keys, use that, but only if the\n                # outputs are a flat dictionary (not deeply nested). Note that\n                # we prefer these keys over the model output names.\n                # Order `output_names` by the flattening order of `y_pred`.\n                output_names = list(y_pred.keys())\n                if not isinstance(y_pred, OrderedDict):\n                    output_names.sort()\n            elif self.output_names and keys <= set(self.output_names):\n                # If the keys match the Functional output names, use that.\n                # The flattening order of `y_pred` is given by `output_names`.\n                output_names = self.output_names\n\n            if output_names:\n                # Flatten `metrics` with the correct flattening order.\n                flat_metrics = [\n                    metrics[name] if name in metrics else None\n                    for name in output_names\n                ]\n\n        if output_names is not None:\n            try:\n                # Flat case: one output or list or dict of metrics.\n                return [\n                    get_metrics_list(m, yt, yp, n)\n                    for m, yt, yp, n in zip(\n                        flat_metrics,\n            
            tree.flatten(y_true),\n                        tree.flatten(y_pred),\n                        output_names,\n                    )\n                ]\n            except ValueError as e:\n                raise ValueError(\n                    f\"{e}\\nReceived: {argument_name}={metrics}\"\n                ) from e\n\n        try:\n            # Deeply nested case: `metrics` must have the structure of `y_pred`.\n            # Note that the tree API wants exact matches, lists and tuples are\n            # not considered equivalent, so we have to turn them all to tuples.\n            tuples_y_pred = tree.lists_to_tuples(y_pred)\n            return tree.flatten(\n                tree.map_structure_up_to(\n                    tuples_y_pred,\n                    get_metrics_list,\n                    tree.lists_to_tuples(metrics),\n                    tree.lists_to_tuples(y_true),\n                    tuples_y_pred,\n                )\n            )\n        except (ValueError, TypeError) as e:\n            # A ValueError from `get_metrics_list` or a ValueError / TypeError\n            # from `tree.map_structure_up_to` for mismatched structures.\n            if self.output_names:\n                raise ValueError(\n                    f\"{e}\\nInvalid `{argument_name}`. `{argument_name}` should \"\n                    \"contain metrics objects and either be a dict or a list \"\n                    \"matching the output names of the functional model \"\n                    f\"{self.output_names} or match the output structure of \"\n                    f\"the model: {tree.map_structure(lambda _: 'X', y_pred)}.\\n\"\n                    f\"Received: {argument_name}={metrics}\"\n                ) from e\n            else:\n                raise ValueError(\n                    f\"{e}\\nInvalid `{argument_name}`. 
`{argument_name}` should \"\n                    \"contain metrics objects and match the output structure of \"\n                    f\"the model: {tree.map_structure(lambda _: 'X', y_pred)}.\\n\"\n                    f\"Received: {argument_name}={metrics}\"\n                ) from e\n\n    def update_state(self, y_true, y_pred, sample_weight=None):\n        if not self.built:\n            self.build(y_true, y_pred)\n        y_true = tree.flatten(y_true)\n        y_pred = tree.flatten(y_pred)\n        for m, y_t, y_p in zip(self._flat_metrics, y_true, y_pred):\n            if m:\n                m.update_state(y_t, y_p)\n        if sample_weight is not None:\n            sample_weight = tree.flatten(sample_weight)\n            # For multi-outputs, repeat sample weights for n outputs.\n            if len(sample_weight) < len(y_true):\n                sample_weight = [sample_weight[0] for _ in range(len(y_true))]\n        else:\n            sample_weight = [None for _ in range(len(y_true))]\n        for m, y_t, y_p, s_w in zip(\n            self._flat_weighted_metrics, y_true, y_pred, sample_weight\n        ):\n            if m:\n                m.update_state(y_t, y_p, s_w)\n\n    def reset_state(self):\n        if not self.built:\n            return\n        for m in self._flat_metrics:\n            if m:\n                m.reset_state()\n        for m in self._flat_weighted_metrics:\n            if m:\n                m.reset_state()\n\n    def result(self):\n        if not self.built:\n            raise ValueError(\n                \"Cannot get result() since the metric has not yet been built.\"\n            )\n        results = {}\n        unique_name_counters = {}\n        for mls in self._flat_metrics:\n            if not mls:\n                continue\n            for m in mls.metrics:\n                name = m.name\n                if mls.output_name:\n                    name = f\"{mls.output_name}_{name}\"\n                if name not in 
unique_name_counters:\n                    results[name] = m.result()\n                    unique_name_counters[name] = 1\n                else:\n                    index = unique_name_counters[name]\n                    unique_name_counters[name] += 1\n                    name = f\"{name}_{index}\"\n                    results[name] = m.result()\n\n        for mls in self._flat_weighted_metrics:\n            if not mls:\n                continue\n            for m in mls.metrics:\n                name = m.name\n                if mls.output_name:\n                    name = f\"{mls.output_name}_{name}\"\n                if name not in unique_name_counters:\n                    results[name] = m.result()\n                    unique_name_counters[name] = 1\n                else:\n                    name = f\"weighted_{m.name}\"\n                    if mls.output_name:\n                        name = f\"{mls.output_name}_{name}\"\n                    if name not in unique_name_counters:\n                        unique_name_counters[name] = 1\n                    else:\n                        index = unique_name_counters[name]\n                        unique_name_counters[name] += 1\n                        name = f\"{name}_{index}\"\n                    results[name] = m.result()\n        return results\n\n    def get_config(self):\n        raise NotImplementedError\n\n    @classmethod\n    def from_config(cls, config):\n        raise NotImplementedError\n\n\nclass CompileLoss(losses_module.Loss):\n    Loss = namedtuple(\"Loss\", [\"path\", \"loss\", \"loss_weights\", \"name\"])\n\n    def __init__(\n        self,\n        loss,\n        loss_weights=None,\n        reduction=\"sum_over_batch_size\",\n        output_names=None,\n    ):\n        if loss_weights and not isinstance(\n            loss_weights, (list, tuple, dict, float)\n        ):\n            raise ValueError(\n                \"Expected `loss_weights` argument to be a float \"\n                
\"(single output case) or a list, tuple, or \"\n                \"dict (multiple output case). \"\n                f\"Received instead: loss_weights={loss_weights} \"\n                f\"of type {type(loss_weights)}\"\n            )\n        self._user_loss = loss\n        self._user_loss_weights = loss_weights\n        self.built = False\n        self.output_names = output_names\n        super().__init__(name=\"compile_loss\", reduction=reduction)\n\n        # Use `Tracker` to track metrics for individual losses.\n        self._metrics = []\n        self._tracker = Tracker(\n            {\n                \"metrics\": (\n                    lambda x: isinstance(x, metrics_module.Metric),\n                    self._metrics,\n                )\n            }\n        )\n        self._flat_losses = None\n        self._y_pred_build_structure = None\n        self._y_true_build_structure = None\n\n    @property\n    def metrics(self):\n        return self._metrics\n\n    @property\n    def variables(self):\n        vars = []\n        for m in self.metrics:\n            vars.extend(m.variables)\n        return vars\n\n    def _build_nested(self, y_true, y_pred, loss, output_names, current_path):\n        flat_y_pred = tree.flatten(y_pred)\n        if not tree.is_nested(loss):\n            _loss = loss.loss\n            if _loss is None:\n                return\n            loss_weight = loss.weight\n            resolved_loss = get_loss(_loss, y_true, y_pred)\n            name_path = current_path\n            if not tree.is_nested(output_names):\n                if output_names is not None:\n                    output_name = output_names\n                else:\n                    output_name = resolved_loss.name\n                if len(name_path) == 0:\n                    name_path = (output_name,)\n                elif isinstance(name_path[-1], int):\n                    name_path = name_path[:-1] + (output_name,)\n            name = \"/\".join([str(path) for path in 
name_path])\n            if name == \"\":\n                if isinstance(output_names, dict):\n                    flat_output_names = list(output_names.keys())\n                else:\n                    flat_output_names = tree.flatten(output_names)\n                name = \"_\".join(flat_output_names)\n            self._flat_losses.append(\n                CompileLoss.Loss(current_path, resolved_loss, loss_weight, name)\n            )\n            return\n        elif (\n            issubclass(type(loss), (list, tuple))\n            and all([not tree.is_nested(_loss) for _loss in loss])\n            and len(loss) == len(flat_y_pred)\n        ):\n            loss = tree.pack_sequence_as(y_pred, loss)\n        elif issubclass(type(loss), (list, tuple)) and not isinstance(\n            y_pred, type(loss)\n        ):\n            for _loss in loss:\n                self._build_nested(\n                    y_true,\n                    y_pred,\n                    _loss,\n                    output_names,\n                    current_path,\n                )\n            return\n\n        if not tree.is_nested(loss):\n            return self._build_nested(\n                y_true, y_pred, loss, output_names, current_path\n            )\n\n        if not isinstance(loss, type(y_pred)):\n            raise KeyError(\n                f\"The path: {current_path} in \"\n                \"the `loss` argument, can't be found in \"\n                \"the model's output (`y_pred`).\"\n            )\n\n        # shallow traverse the loss config\n        if isinstance(loss, dict):\n            iterator = loss.items()\n\n            def key_check_fn(key, objs):\n                return all(\n                    [isinstance(obj, dict) and key in obj for obj in objs]\n                )\n\n        elif issubclass(type(loss), (list, tuple)):\n            iterator = enumerate(loss)\n\n            def key_check_fn(key, objs):\n                return all(\n                    [\n          
              issubclass(type(obj), (list, tuple)) and key < len(obj)\n                        for obj in objs\n                    ]\n                )\n\n        else:\n            raise TypeError(\n                f\"Unsupported type {type(loss)} in the `loss` configuration.\"\n            )\n\n        for key, _loss in iterator:\n            if _loss is None:\n                continue\n            if not key_check_fn(key, (y_true, y_pred)):\n                raise KeyError(\n                    f\"The path: {current_path + (key,)} in \"\n                    \"the `loss` argument, can't be found in \"\n                    \"either the model's output (`y_pred`) or in the \"\n                    \"labels (`y_true`).\"\n                )\n\n            self._build_nested(\n                y_true[key],\n                y_pred[key],\n                _loss,\n                output_names[key],\n                current_path + (key,),\n            )\n\n    def build(self, y_true, y_pred):\n        loss = self._user_loss\n        loss_weights = self._user_loss_weights\n        flat_output_names = self.output_names\n        if (\n            self.output_names\n            and isinstance(self._user_loss, dict)\n            and not isinstance(y_pred, dict)\n        ):\n            if set(self.output_names) == set(self._user_loss.keys()):\n                loss = [self._user_loss[name] for name in self.output_names]\n                if isinstance(self._user_loss_weights, dict):\n                    loss_weights = [\n                        self._user_loss_weights[name]\n                        for name in self.output_names\n                    ]\n            else:\n                raise ValueError(\n                    f\"Expected keys {self.output_names} in loss dict, but \"\n                    f\"found loss.keys()={list(self._user_loss.keys())}\"\n                )\n\n        # Pytree leaf container\n        class WeightedLoss:\n            def __new__(cls, loss, weight):\n  
              if loss is None:\n                    return None\n                return object.__new__(cls)\n\n            def __init__(self, loss, weight):\n                self.loss = loss\n                self.weight = weight\n\n        # pack the losses and the weights together\n        if loss_weights is not None:\n            try:\n                tree.assert_same_structure(loss, loss_weights)\n            except ValueError:\n                flat_loss_weights = tree.flatten(loss_weights)\n                if len(tree.flatten(loss)) != len(flat_loss_weights):\n                    raise ValueError(\n                        f\"`loss_weights` must match the number of losses, \"\n                        f\"got {len(tree.flatten(loss))} losses \"\n                        f\"and {len(flat_loss_weights)} weights.\"\n                    )\n                loss_weights = tree.pack_sequence_as(loss, flat_loss_weights)\n            loss = tree.map_structure(\n                lambda _loss, _weight: WeightedLoss(_loss, _weight),\n                loss,\n                loss_weights,\n            )\n        else:\n            loss = tree.map_structure(\n                lambda _loss: WeightedLoss(_loss, None), loss\n            )\n\n        self._flat_losses = []\n\n        if (\n            isinstance(loss, dict)\n            and issubclass(type(y_pred), (list, tuple))\n            and set(loss.keys()) == set(flat_output_names)\n            and len(y_pred) == len(flat_output_names)\n        ):\n            y_pred = {name: y_p for name, y_p in zip(flat_output_names, y_pred)}\n            y_true = {name: y_t for name, y_t in zip(flat_output_names, y_true)}\n        elif (\n            isinstance(loss, dict)\n            and not tree.is_nested(y_pred)\n            and set(loss.keys()) == set(flat_output_names)\n            and len(flat_output_names) == 1\n        ):\n            y_pred = {\n                name: y_p for name, y_p in zip(flat_output_names, [y_pred])\n            }\n   
         y_true = {\n                name: y_t for name, y_t in zip(flat_output_names, [y_true])\n            }\n\n        try:\n            output_names = tree.pack_sequence_as(y_pred, flat_output_names)\n        except:\n            inferred_flat_output_names = self._get_y_pred_output_names(y_pred)\n            output_names = tree.pack_sequence_as(\n                y_pred, inferred_flat_output_names\n            )\n\n        if not tree.is_nested(loss):\n            loss = tree.map_structure(lambda x: loss, y_pred)\n\n        self._build_nested(y_true, y_pred, loss, output_names, ())\n\n        # Add `Mean` metric to the tracker for each loss.\n        if len(self._flat_losses) > 1:\n            for _loss in self._flat_losses:\n                name = f\"{_loss.name}_loss\"\n                self._tracker.add_to_store(\n                    \"metrics\", metrics_module.Mean(name=name)\n                )\n\n        self._y_pred_build_structure = tree.map_structure(\n            lambda x: None, y_pred\n        )\n        self._y_true_build_structure = tree.map_structure(\n            lambda x: None, y_true\n        )\n        self.built = True\n\n    def _get_y_pred_output_names(self, y_pred):\n        flat_y_pred = tree.flatten(y_pred)\n        if all((isinstance(x, KerasTensor) for x in flat_y_pred)):\n            output_names = []\n            for tensor in flat_y_pred:\n                if hasattr(tensor, \"_keras_history\"):\n                    output_names.append(tensor._keras_history.operation.name)\n                else:\n                    output_names.append(tensor.name)\n        else:\n            output_names = [None] * len(flat_y_pred)\n        return output_names\n\n    def __call__(self, y_true, y_pred, sample_weight=None):\n        with ops.name_scope(self.name):\n            return self.call(y_true, y_pred, sample_weight)\n\n    def call(self, y_true, y_pred, sample_weight=None):\n        def resolve_path(path, object):\n            for _path in 
path:\n                object = object[_path]\n            return object\n\n        if not tree.is_nested(y_true) and not tree.is_nested(y_pred):\n            # Fast path: single output case / no loss-tracking metric.\n            if not self.built:\n                self.build(y_true, y_pred)\n            # Although we are in the fast path, we still need to iterate\n            # through the losses to prevent the torch compiler from failing.\n            loss_values = []\n            for path, loss_fn, loss_weight, _ in self._flat_losses:\n                y_t, y_p = (\n                    resolve_path(path, y_true),\n                    resolve_path(path, y_pred),\n                )\n                if sample_weight is not None and tree.is_nested(sample_weight):\n                    _sample_weight = resolve_path(path, sample_weight)\n                else:\n                    _sample_weight = sample_weight\n                value = ops.cast(\n                    loss_fn(y_t, y_p, _sample_weight), dtype=self.dtype\n                )\n                if loss_weight is not None:\n                    value = ops.multiply(value, loss_weight)\n                loss_values.append(value)\n            return loss_values[0]\n\n        try:\n            tree.assert_same_structure(y_pred, y_true)\n        except ValueError:\n            # Check case where y_true is either flat or leaf\n            if (\n                not tree.is_nested(y_true)\n                and hasattr(y_pred, \"__len__\")\n                and len(y_pred) == 1\n            ):\n                y_true = [y_true]\n\n            # Check case where y_pred is list/tuple and y_true is dict\n            elif isinstance(y_pred, (list, tuple)) and isinstance(y_true, dict):\n                if set(self.output_names) == set(y_true.keys()):\n                    y_true = [y_true[name] for name in self.output_names]\n\n            try:\n                y_true = tree.pack_sequence_as(y_pred, y_true)\n            except:\n  
              # Check case where y_true has the same structure but uses\n                # different (but reconcilable) container types,\n                # e.g `list` vs `tuple`.\n                try:\n                    tree.assert_same_paths(y_true, y_pred)\n                    y_true = tree.pack_sequence_as(y_pred, tree.flatten(y_true))\n                except:\n                    try:\n                        # Check case where loss is partially defined over y_pred\n                        flat_y_true = tree.flatten(y_true)\n                        flat_loss = tree.flatten(self._user_loss)\n                        flat_loss_non_nones = [\n                            (i, loss)\n                            for i, loss in enumerate(flat_loss)\n                            if loss is not None\n                        ]\n                        if len(flat_y_true) != len(flat_loss_non_nones):\n                            raise ValueError(\n                                \"Internal error: the number of values in \"\n                                f\"`y_true` ({len(flat_y_true)}) must match the \"\n                                \"number of non-None values in `loss` \"\n                                f\"({len(flat_loss_non_nones)}).\"\n                            )\n                        y_true = [None] * len(flat_loss)\n                        for y_t, (i, loss) in zip(\n                            flat_y_true, flat_loss_non_nones\n                        ):\n                            y_true[i] = y_t\n                        y_true = tree.pack_sequence_as(self._user_loss, y_true)\n                    except:\n                        y_true_struct = tree.map_structure(\n                            lambda _: \"*\", y_true\n                        )\n                        y_pred_struct = tree.map_structure(\n                            lambda _: \"*\", y_pred\n                        )\n                        raise ValueError(\n                            
\"y_true and y_pred have different structures.\\n\"\n                            f\"y_true: {y_true_struct}\\n\"\n                            f\"y_pred: {y_pred_struct}\\n\"\n                        )\n\n        if not self.built:\n            self.build(y_true, y_pred)\n\n        try:\n            tree.assert_same_structure(self._y_pred_build_structure, y_pred)\n        except ValueError:\n            y_pred = tree.pack_sequence_as(\n                self._y_pred_build_structure, tree.flatten(y_pred)\n            )\n        try:\n            tree.assert_same_structure(self._y_true_build_structure, y_true)\n        except ValueError:\n            y_true = tree.pack_sequence_as(\n                self._y_true_build_structure, tree.flatten(y_true)\n            )\n\n        # We need to add a dummy `None` if the model has only a single output.\n        metrics = [None] if len(self.metrics) == 0 else self.metrics\n\n        # Iterate all losses in flat form.\n        loss_values = []\n\n        for (path, loss_fn, loss_weight, _), metric in zip(\n            self._flat_losses, metrics\n        ):\n            y_t, y_p = resolve_path(path, y_true), resolve_path(path, y_pred)\n            if sample_weight is not None and tree.is_nested(sample_weight):\n                _sample_weight = resolve_path(path, sample_weight)\n            else:\n                _sample_weight = sample_weight\n\n            value = ops.cast(\n                loss_fn(y_t, y_p, _sample_weight), dtype=self.dtype\n            )\n            # Record *unweighted* individual losses.\n            if metric:\n                metric.update_state(\n                    loss_module.unscale_loss_for_distribution(value),\n                    sample_weight=tree.flatten(y_p)[0].shape[0],\n                )\n            if loss_weight is not None:\n                value = ops.multiply(value, loss_weight)\n            loss_values.append(value)\n\n        if loss_values:\n            total_loss = sum(loss_values)\n   
         return total_loss\n        return None\n\n    def get_config(self):\n        raise NotImplementedError\n\n    @classmethod\n    def from_config(cls, config):\n        raise NotImplementedError\n"
  },
  {
    "path": "keras/src/trainers/compile_utils_test.py",
    "content": "from collections import namedtuple\n\nimport numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import metrics as losses_module\nfrom keras.src import metrics as metrics_module\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src import tree\nfrom keras.src.trainers.compile_utils import CompileLoss\nfrom keras.src.trainers.compile_utils import CompileMetrics\n\n\nclass TestCompileMetrics(testing.TestCase):\n    def test_single_output_case(self):\n        compile_metrics = CompileMetrics(\n            metrics=[metrics_module.MeanSquaredError()],\n            weighted_metrics=[metrics_module.MeanSquaredError()],\n        )\n        # Test symbolic build\n        y_true = backend.KerasTensor((3, 4))\n        y_pred = backend.KerasTensor((3, 4))\n        compile_metrics.build(y_true, y_pred)\n        # Test eager build\n        y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])\n        y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]])\n        sample_weight = np.array([1, 0.0, 1])\n        compile_metrics.build(y_true, y_pred)\n\n        # Test update / result / reset flow\n        compile_metrics.update_state(\n            y_true, y_pred, sample_weight=sample_weight\n        )\n        y_pred = np.array([[0.3, 0.2], [0.1, 0.4], [0.2, 0.3]])\n        compile_metrics.update_state(\n            y_true, y_pred, sample_weight=sample_weight\n        )\n        result = compile_metrics.result()\n        self.assertIsInstance(result, dict)\n        self.assertEqual(len(result), 2)\n        self.assertAllClose(result[\"mean_squared_error\"], 0.055833336)\n        self.assertAllClose(result[\"weighted_mean_squared_error\"], 0.0725)\n\n        compile_metrics.reset_state()\n        result = compile_metrics.result()\n        self.assertIsInstance(result, dict)\n        self.assertEqual(len(result), 2)\n        self.assertAllClose(result[\"mean_squared_error\"], 0.0)\n        
self.assertAllClose(result[\"weighted_mean_squared_error\"], 0.0)\n\n    def test_list_output_case(self):\n        compile_metrics = CompileMetrics(\n            metrics=[\n                [\n                    metrics_module.MeanSquaredError(),\n                    metrics_module.MeanSquaredError(),\n                ],\n                [\n                    metrics_module.MeanSquaredError(),\n                    metrics_module.MeanSquaredError(),\n                ],\n            ],\n            weighted_metrics=[\n                [\n                    metrics_module.MeanSquaredError(),\n                    metrics_module.MeanSquaredError(),\n                ],\n                [\n                    metrics_module.MeanSquaredError(),\n                    metrics_module.MeanSquaredError(),\n                ],\n            ],\n        )\n        # Test symbolic build\n        y_true = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))]\n        y_pred = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))]\n        compile_metrics.build(y_true, y_pred)\n        self.assertEqual(len(compile_metrics.metrics), 8)\n\n        # Test eager build\n        y_true = [\n            np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),\n            np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),\n        ]\n        y_pred = [\n            np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]),\n            np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]),\n        ]\n        sample_weight = np.array([1, 0.0, 1])\n        compile_metrics.build(y_true, y_pred)\n        self.assertEqual(len(compile_metrics.metrics), 8)\n\n        # Test update / result / reset flow\n        compile_metrics.update_state(\n            y_true, y_pred, sample_weight=sample_weight\n        )\n        y_pred = [\n            np.array([[0.3, 0.2], [0.1, 0.4], [0.2, 0.3]]),\n            np.array([[0.3, 0.2], [0.1, 0.4], [0.2, 0.3]]),\n        ]\n        compile_metrics.update_state(\n            y_true, 
y_pred, sample_weight=sample_weight\n        )\n        result = compile_metrics.result()\n        self.assertIsInstance(result, dict)\n        self.assertEqual(len(result), 8)\n        self.assertAllClose(result[\"mean_squared_error\"], 0.055833336)\n        self.assertAllClose(result[\"weighted_mean_squared_error\"], 0.0725)\n\n        compile_metrics.reset_state()\n        result = compile_metrics.result()\n        self.assertIsInstance(result, dict)\n        self.assertEqual(len(result), 8)\n        self.assertAllClose(result[\"mean_squared_error\"], 0.0)\n        self.assertAllClose(result[\"weighted_mean_squared_error\"], 0.0)\n\n    def test_dict_output_case(self):\n        compile_metrics = CompileMetrics(\n            metrics={\n                \"output_1\": [\n                    metrics_module.MeanSquaredError(),\n                    metrics_module.MeanSquaredError(name=\"mse\"),\n                ],\n                \"output_2\": [\n                    metrics_module.MeanSquaredError(),\n                    metrics_module.MeanSquaredError(name=\"mse\"),\n                ],\n            },\n            weighted_metrics={\n                \"output_1\": [\n                    metrics_module.MeanSquaredError(),\n                    metrics_module.MeanSquaredError(name=\"mse\"),\n                ],\n                \"output_2\": [\n                    metrics_module.MeanSquaredError(),\n                    metrics_module.MeanSquaredError(name=\"mse\"),\n                ],\n            },\n        )\n        # Test symbolic build\n        y_true = {\n            \"output_1\": backend.KerasTensor((3, 4)),\n            \"output_2\": backend.KerasTensor((3, 4)),\n        }\n        y_pred = {\n            \"output_1\": backend.KerasTensor((3, 4)),\n            \"output_2\": backend.KerasTensor((3, 4)),\n        }\n        compile_metrics.build(y_true, y_pred)\n        # Test eager build\n        y_true = {\n            \"output_1\": np.array([[0.1, 0.2], [0.3, 
0.4], [0.5, 0.6]]),\n            \"output_2\": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),\n        }\n        y_pred = {\n            \"output_1\": np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]),\n            \"output_2\": np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]]),\n        }\n        sample_weight = np.array([1, 0.0, 1])\n        compile_metrics.build(y_true, y_pred)\n\n        # Test update / result / reset flow\n        compile_metrics.update_state(\n            y_true, y_pred, sample_weight=sample_weight\n        )\n        y_pred = {\n            \"output_1\": np.array([[0.3, 0.2], [0.1, 0.4], [0.2, 0.3]]),\n            \"output_2\": np.array([[0.3, 0.2], [0.1, 0.4], [0.2, 0.3]]),\n        }\n        compile_metrics.update_state(\n            y_true, y_pred, sample_weight=sample_weight\n        )\n        result = compile_metrics.result()\n        self.assertIsInstance(result, dict)\n        self.assertEqual(len(result), 8)\n        # Result values obtained from `tf.keras`\n        # m = tf.keras.metrics.MeanSquaredError()\n        # m.update_state(y_true, y_pred1, sample_weight=weight)\n        # m.update_state(y_true, y_pred2, sample_weight=weight)\n        # m.result().numpy()\n        self.assertAllClose(result[\"output_1_mean_squared_error\"], 0.055833336)\n        self.assertAllClose(result[\"output_2_mean_squared_error\"], 0.055833336)\n        self.assertAllClose(result[\"output_1_mse\"], 0.055833336)\n        self.assertAllClose(result[\"output_2_mse\"], 0.055833336)\n        self.assertAllClose(\n            result[\"output_1_weighted_mean_squared_error\"], 0.0725\n        )\n        self.assertAllClose(\n            result[\"output_2_weighted_mean_squared_error\"], 0.0725\n        )\n        self.assertAllClose(result[\"output_1_weighted_mse\"], 0.0725)\n        self.assertAllClose(result[\"output_2_weighted_mse\"], 0.0725)\n\n        compile_metrics.reset_state()\n        result = compile_metrics.result()\n        
self.assertIsInstance(result, dict)\n        self.assertEqual(len(result), 8)\n        self.assertAllClose(result[\"output_1_mean_squared_error\"], 0.0)\n        self.assertAllClose(result[\"output_2_mean_squared_error\"], 0.0)\n        self.assertAllClose(result[\"output_1_weighted_mean_squared_error\"], 0.0)\n        self.assertAllClose(result[\"output_2_weighted_mean_squared_error\"], 0.0)\n\n    def test_name_conversions(self):\n        compile_metrics = CompileMetrics(\n            metrics=[\"acc\", \"accuracy\", \"mse\"],\n            weighted_metrics=[],\n        )\n        y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])\n        y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]])\n        compile_metrics.build(y_true, y_pred)\n        compile_metrics.update_state(y_true, y_pred, sample_weight=None)\n        result = compile_metrics.result()\n        self.assertIsInstance(result, dict)\n        self.assertEqual(len(result), 3)\n        self.assertAllClose(result[\"acc\"], 0.333333)\n        self.assertAllClose(result[\"accuracy\"], 0.333333)\n        self.assertTrue(\"mse\" in result)\n\n    def test_custom_metric_function(self):\n        def my_custom_metric(y_true, y_pred):\n            return ops.mean(ops.square(y_true - y_pred), axis=-1)\n\n        compile_metrics = CompileMetrics(\n            metrics=[my_custom_metric],\n            weighted_metrics=[],\n        )\n        y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])\n        y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]])\n        compile_metrics.build(y_true, y_pred)\n        compile_metrics.update_state(y_true, y_pred, sample_weight=None)\n        result = compile_metrics.result()\n        self.assertIsInstance(result, dict)\n        self.assertEqual(list(result.keys()), [\"my_custom_metric\"])\n\n    def test_dict_outputs_uses_output_names(self):\n        \"\"\"Tests that when output_names match the metrics dict keys, and the\n        output key names don't, the 
output_names are used.\"\"\"\n\n        # output_names represent internal op names that do not match the dict\n        # keys of the output map.\n        compile_metrics = CompileMetrics(\n            metrics={\n                \"dense_1\": metrics_module.MeanSquaredError(),\n                \"dense_2\": metrics_module.MeanSquaredError(),\n            },\n            weighted_metrics=None,\n            output_names=[\"dense_1\", \"dense_2\"],\n        )\n\n        # Symbolic build with dict outputs keyed by user-facing names.\n        y_true = {\n            \"a\": backend.KerasTensor((3, 2)),\n            \"b\": backend.KerasTensor((3, 2)),\n        }\n        y_pred = {\n            \"a\": backend.KerasTensor((3, 2)),\n            \"b\": backend.KerasTensor((3, 2)),\n        }\n\n        compile_metrics.build(y_true, y_pred)\n\n        # Make the two outputs produce different MSEs to verify mapping.\n        y_true = {\n            \"a\": np.zeros((3, 2), dtype=\"float32\"),\n            \"b\": np.zeros((3, 2), dtype=\"float32\"),\n        }\n        y_pred = {\n            # MSE(a) = 0.0\n            \"a\": np.zeros((3, 2), dtype=\"float32\"),\n            # MSE(b) = 1.0\n            \"b\": np.ones((3, 2), dtype=\"float32\"),\n        }\n        compile_metrics.update_state(y_true, y_pred)\n\n        result = compile_metrics.result()\n        self.assertIsInstance(result, dict)\n        self.assertEqual(\n            list(result.keys()),\n            [\"dense_1_mean_squared_error\", \"dense_2_mean_squared_error\"],\n        )\n        self.assertAllClose(result[\"dense_1_mean_squared_error\"], 0.0)\n        self.assertAllClose(result[\"dense_2_mean_squared_error\"], 1.0)\n\n    def test_dict_outputs_output_names_ordering(self):\n        \"\"\"Tests that when the metrics are not declared in the same order as\n        the output names, they are remapped correctly.\"\"\"\n\n        # Put metrics in the wrong order to check the reordering happened.\n        
compile_metrics = CompileMetrics(\n            metrics={\n                \"dense_2\": metrics_module.MeanAbsolutePercentageError(),\n                \"dense_1\": metrics_module.MeanSquaredError(),\n            },\n            weighted_metrics=None,\n            output_names=[\"dense_1\", \"dense_2\"],\n        )\n\n        # Symbolic build with dict outputs keyed by user-facing names.\n        y_true = {\n            \"a\": backend.KerasTensor((3, 2)),\n            \"b\": backend.KerasTensor((3, 2)),\n        }\n        y_pred = {\n            \"a\": backend.KerasTensor((3, 2)),\n            \"b\": backend.KerasTensor((3, 2)),\n        }\n\n        compile_metrics.build(y_true, y_pred)\n\n        # Make the two outputs produce different metric values to verify the\n        # order of metrics. In both cases the difference is 1, but MSE and MAPE\n        # will have different values.\n        y_true = {\n            \"a\": np.ones((3, 2), dtype=\"float32\"),\n            \"b\": np.ones((3, 2), dtype=\"float32\"),\n        }\n        y_pred = {\n            # MSE(a) = 1.0\n            \"a\": np.full((3, 2), 2.0, dtype=\"float32\"),\n            # MAPE(b) = 100.0\n            \"b\": np.zeros((3, 2), dtype=\"float32\"),\n        }\n        compile_metrics.update_state(y_true, y_pred)\n\n        result = compile_metrics.result()\n        self.assertIsInstance(result, dict)\n        self.assertEqual(\n            list(result.keys()),\n            [\n                \"dense_1_mean_squared_error\",\n                \"dense_2_mean_absolute_percentage_error\",\n            ],\n        )\n        self.assertAllClose(result[\"dense_1_mean_squared_error\"], 1.0)\n        self.assertAllClose(\n            result[\"dense_2_mean_absolute_percentage_error\"], 100.0\n        )\n\n    def test_dict_outputs_outputs_ordering(self):\n        \"\"\"Tests that when the metrics are not declared in the same order as\n        the keys in the output dict, they are remapped 
correctly.\"\"\"\n\n        # Put metrics in the wrong order to check the reordering happened.\n        compile_metrics = CompileMetrics(\n            metrics={\n                \"b\": metrics_module.MeanAbsolutePercentageError(),\n                \"a\": metrics_module.MeanSquaredError(),\n            },\n            weighted_metrics=None,\n            output_names=[\"dense_1\", \"dense_2\"],\n        )\n\n        # Symbolic build with dict outputs keyed by user-facing names.\n        y_true = {\n            \"a\": backend.KerasTensor((3, 2)),\n            \"b\": backend.KerasTensor((3, 2)),\n        }\n        y_pred = {\n            \"a\": backend.KerasTensor((3, 2)),\n            \"b\": backend.KerasTensor((3, 2)),\n        }\n\n        compile_metrics.build(y_true, y_pred)\n\n        # Make the two outputs produce different metric values to verify the\n        # order of metrics. In both cases the difference is 1, but MSE and MAPE\n        # will have different values.\n        y_true = {\n            \"a\": np.ones((3, 2), dtype=\"float32\"),\n            \"b\": np.ones((3, 2), dtype=\"float32\"),\n        }\n        y_pred = {\n            # MSE(a) = 1.0\n            \"a\": np.full((3, 2), 2.0, dtype=\"float32\"),\n            # MAPE(b) = 100.0\n            \"b\": np.zeros((3, 2), dtype=\"float32\"),\n        }\n        compile_metrics.update_state(y_true, y_pred)\n\n        result = compile_metrics.result()\n        self.assertIsInstance(result, dict)\n        self.assertEqual(\n            list(result.keys()),\n            [\"a_mean_squared_error\", \"b_mean_absolute_percentage_error\"],\n        )\n        self.assertAllClose(result[\"a_mean_squared_error\"], 1.0)\n        self.assertAllClose(result[\"b_mean_absolute_percentage_error\"], 100.0)\n\n    def test_dict_outputs_ignore_mismatched_output_names(self):\n        \"\"\"Tests that when output_names does not match dict keys, the keys from\n        the output dict are used.\"\"\"\n\n        # 
output_names represent internal op names that do not match dict keys.\n        compile_metrics = CompileMetrics(\n            metrics={\n                \"a\": metrics_module.MeanSquaredError(),\n                \"b\": metrics_module.MeanSquaredError(),\n            },\n            weighted_metrics=None,\n            output_names=[\"dense\", \"dense_1\"],\n        )\n\n        # Symbolic build with dict outputs keyed by user-facing names.\n        y_true = {\n            \"a\": backend.KerasTensor((3, 2)),\n            \"b\": backend.KerasTensor((3, 2)),\n        }\n        y_pred = {\n            \"a\": backend.KerasTensor((3, 2)),\n            \"b\": backend.KerasTensor((3, 2)),\n        }\n\n        # The build method should correctly map metrics for outputs 'a' and 'b',\n        # even when the op names do not match.\n        compile_metrics.build(y_true, y_pred)\n\n        # Make the two outputs produce different MSEs to verify mapping.\n        y_true = {\n            \"a\": np.zeros((3, 2), dtype=\"float32\"),\n            \"b\": np.zeros((3, 2), dtype=\"float32\"),\n        }\n        y_pred = {\n            # MSE(a) = 0.0\n            \"a\": np.zeros((3, 2), dtype=\"float32\"),\n            # MSE(b) = 1.0\n            \"b\": np.ones((3, 2), dtype=\"float32\"),\n        }\n        compile_metrics.update_state(y_true, y_pred)\n\n        result = compile_metrics.result()\n        self.assertIsInstance(result, dict)\n        self.assertEqual(\n            list(result.keys()),\n            [\"a_mean_squared_error\", \"b_mean_squared_error\"],\n        )\n        self.assertAllClose(result[\"a_mean_squared_error\"], 0.0)\n        self.assertAllClose(result[\"b_mean_squared_error\"], 1.0)\n\n    def test_deeply_nested_outputs_and_metrics(self):\n        \"\"\"Tests that when the outputs are deeply nested, we can declare the\n        metrics with the same deeply nested structure.\"\"\"\n\n        compile_metrics = CompileMetrics(\n            metrics={\n           
     \"a\": metrics_module.MeanSquaredError(name=\"mse_a\"),\n                \"b\": {\n                    \"c\": metrics_module.MeanSquaredError(name=\"mse_c\"),\n                    \"d\": [\n                        metrics_module.MeanSquaredError(name=\"mse_d1\"),\n                        metrics_module.MeanSquaredError(name=\"mse_d2\"),\n                    ],\n                },\n            },\n            weighted_metrics=None,\n            output_names=[\"dense\", \"dense_1\", \"dense_2\", \"dense_3\"],\n        )\n\n        y_true = {\n            \"a\": backend.KerasTensor((3, 2)),\n            \"b\": {\n                \"c\": backend.KerasTensor((3, 2)),\n                \"d\": [backend.KerasTensor((3, 2)), backend.KerasTensor((3, 2))],\n            },\n        }\n        y_pred = {\n            \"a\": backend.KerasTensor((3, 2)),\n            \"b\": {\n                \"c\": backend.KerasTensor((3, 2)),\n                \"d\": [backend.KerasTensor((3, 2)), backend.KerasTensor((3, 2))],\n            },\n        }\n\n        # The build method should correctly map deeply nested metrics.\n        compile_metrics.build(y_true, y_pred)\n\n        # Make the three outputs produce different MSEs to verify mapping.\n        y_true = {\n            \"a\": np.zeros((3, 2), dtype=\"float32\"),\n            \"b\": {\n                \"c\": np.zeros((3, 2), dtype=\"float32\"),\n                \"d\": [\n                    np.zeros((3, 2), dtype=\"float32\"),\n                    np.zeros((3, 2), dtype=\"float32\"),\n                ],\n            },\n        }\n        y_pred = {\n            # MSE(a) = 0.0\n            \"a\": np.zeros((3, 2), dtype=\"float32\"),\n            \"b\": {\n                # MSE(c) = 1.0\n                \"c\": np.ones((3, 2), dtype=\"float32\"),\n                \"d\": [\n                    # MSE(d1) = 4.0\n                    np.full((3, 2), 2.0, dtype=\"float32\"),\n                    # MSE(d2) = 9.0\n                    
np.full((3, 2), 3.0, dtype=\"float32\"),\n                ],\n            },\n        }\n        compile_metrics.update_state(y_true, y_pred)\n\n        result = compile_metrics.result()\n        self.assertIsInstance(result, dict)\n        self.assertEqual(\n            list(result.keys()),\n            [\"mse_a\", \"mse_c\", \"mse_d1\", \"mse_d2\"],\n        )\n        self.assertAllClose(result[\"mse_a\"], 0.0)\n        self.assertAllClose(result[\"mse_c\"], 1.0)\n        self.assertAllClose(result[\"mse_d1\"], 4.0)\n        self.assertAllClose(result[\"mse_d2\"], 9.0)\n\n\nclass TestCompileLoss(testing.TestCase):\n    def test_single_output_case(self):\n        compile_loss = CompileLoss(\n            loss=losses_module.MeanSquaredError(),\n        )\n        # Test symbolic build\n        y_true = backend.KerasTensor((3, 4))\n        y_pred = backend.KerasTensor((3, 4))\n        compile_loss.build(y_true, y_pred)\n        # Test eager build\n        y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])\n        y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]])\n        compile_loss.build(y_true, y_pred)\n        value = compile_loss(y_true, y_pred)\n        self.assertAllClose(value, 0.068333, atol=1e-5)\n\n    def test_single_output_case_with_crossentropy_loss(self):\n        compile_loss = CompileLoss(loss=\"crossentropy\")\n\n        # Test symbolic build\n        y_true = backend.KerasTensor((3, 4))\n        y_pred = backend.KerasTensor((3, 4))\n        compile_loss.build(y_true, y_pred)\n        # Test eager build\n        y_true = np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])\n        y_pred = np.array([[0.4, 0.1], [0.2, 0.6], [0.6, 0.1]])\n        compile_loss.build(y_true, y_pred)\n        value = compile_loss(y_true, y_pred)\n        self.assertAllClose(value, 0.706595, atol=1e-5)\n\n    @parameterized.parameters(True, False)\n    def test_list_output_case(self, broadcast):\n        if broadcast:\n            # Test broadcasting single loss 
to all outputs\n            compile_loss = CompileLoss(\n                loss=\"mse\",\n            )\n        else:\n            compile_loss = CompileLoss(\n                loss=[\"mse\", \"mse\"],\n            )\n        # Test symbolic build\n        y_true = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))]\n        y_pred = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))]\n        compile_loss.build(y_true, y_pred)\n        # Test eager build\n        y_true = [\n            np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),\n            np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n        ]\n        y_pred = [\n            np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),\n            np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),\n        ]\n        compile_loss.build(y_true, y_pred)\n        value = compile_loss(y_true, y_pred)\n        self.assertAllClose(value, 0.953333, atol=1e-5)\n\n    @parameterized.parameters(True, False)\n    def test_dict_output_case(self, broadcast):\n        if broadcast:\n            # Test broadcasting single loss to all outputs\n            compile_loss = CompileLoss(\n                loss=\"mse\",\n            )\n        else:\n            compile_loss = CompileLoss(\n                loss={\"a\": \"mse\", \"b\": \"mse\"},\n            )\n        # Test symbolic build\n        y_true = {\n            \"a\": backend.KerasTensor((3, 4)),\n            \"b\": backend.KerasTensor((3, 4)),\n        }\n        y_pred = {\n            \"a\": backend.KerasTensor((3, 4)),\n            \"b\": backend.KerasTensor((3, 4)),\n        }\n        compile_loss.build(y_true, y_pred)\n        # Test eager build\n        y_true = {\n            \"a\": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),\n            \"b\": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n        }\n        y_pred = {\n            \"a\": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),\n            \"b\": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 
0.1]]),\n        }\n        sample_weight = {\n            \"a\": np.array([1.0, 2.0, 3.0]),\n            \"b\": np.array([3.0, 2.0, 1.0]),\n        }\n        compile_loss.build(y_true, y_pred)\n        value = compile_loss(y_true, y_pred, sample_weight)\n        self.assertAllClose(value, 1.266666, atol=1e-5)\n\n    def test_list_loss_dict_data(self):\n        compile_loss = CompileLoss(loss=[\"mse\", \"mae\"], output_names=[\"b\", \"a\"])\n        y_true = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))]\n        y_pred = [backend.KerasTensor((3, 4)), backend.KerasTensor((3, 4))]\n        compile_loss.build(y_true, y_pred)\n        y_true = {\n            \"a\": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),\n            \"b\": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n        }\n        y_pred = {\n            \"a\": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),\n            \"b\": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),\n        }\n        value = compile_loss(y_true, y_pred)\n        self.assertAllClose(value, 1.07666, atol=1e-5)\n\n    def test_struct_loss(self):\n        y_true = {\n            \"a\": {\n                \"c\": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),\n                \"d\": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n            },\n            \"b\": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n        }\n        y_pred = {\n            \"a\": {\n                \"c\": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),\n                \"d\": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),\n            },\n            \"b\": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),\n        }\n        loss = {\"a\": {\"c\": \"mse\", \"d\": \"mae\"}}\n        compile_loss = CompileLoss(loss=loss, output_names=[\"c\", \"d\", \"b\"])\n        y_true_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((3, 4)), y_true\n        )\n        y_pred_symb = tree.map_structure(\n            lambda _: 
backend.KerasTensor((3, 4)), y_pred\n        )\n        compile_loss.build(y_true_symb, y_pred_symb)\n        value = compile_loss(y_true, y_pred)\n        self.assertAllClose(value, 1.07666, atol=1e-5)\n\n    def test_struct_loss_valid_weights(self):\n        y_true = {\n            \"a\": np.array([1, 2]),\n            \"b\": np.array([1, 2]),\n        }\n        y_pred = {\n            \"a\": np.array([3, 4]),\n            \"b\": np.array([3, 4]),\n        }\n        loss = {\"a\": \"mse\", \"b\": \"mse\"}\n        compile_loss = CompileLoss(\n            loss=loss,\n            output_names=[\"a\", \"b\"],\n            loss_weights={\n                \"a\": np.ones((2,)),\n                \"b\": np.zeros((2,)),\n            },\n        )\n        compile_loss.build(y_true, y_pred)\n        value = compile_loss(y_true, y_pred)\n        self.assertAllClose(value, 4)\n\n        # Metrics still report unweighted loss.\n        a_loss_mean, b_loss_mean = compile_loss.metrics\n        self.assertEqual(a_loss_mean.result(), 4)\n        self.assertEqual(b_loss_mean.result(), 4)\n\n    def test_struct_loss_invalid_weights(self):\n        y_true = {\n            \"a\": {\n                \"c\": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),\n                \"d\": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n            },\n            \"b\": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n        }\n        y_pred = {\n            \"a\": {\n                \"c\": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),\n                \"d\": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),\n            },\n            \"b\": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),\n        }\n        loss = {\"a\": {\"c\": \"mse\", \"d\": \"mae\"}}\n        compile_loss = CompileLoss(\n            loss=loss, output_names=[\"c\", \"d\", \"b\"], loss_weights=[1]\n        )\n        y_true_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((3, 4)), y_true\n        
)\n        y_pred_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((3, 4)), y_pred\n        )\n        with self.assertRaisesRegex(\n            ValueError, \"must match the number of losses\"\n        ):\n            compile_loss.build(y_true_symb, y_pred_symb)\n\n    def test_struct_loss_indice_path(self):\n        y_true = {\n            \"a\": (\n                np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),\n                np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n            ),\n            \"b\": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n        }\n        y_pred = {\n            \"a\": (\n                np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),\n                np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),\n            ),\n            \"b\": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),\n        }\n        loss = {\"a\": [\"mse\", \"mae\"]}\n        compile_loss = CompileLoss(loss=loss, output_names=[\"c\", \"d\", \"b\"])\n        y_true_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((3, 4)), y_true\n        )\n        y_pred_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((3, 4)), y_pred\n        )\n        compile_loss.build(y_true_symb, y_pred_symb)\n        value = compile_loss(y_true, y_pred)\n        self.assertAllClose(value, 1.07666, atol=1e-5)\n\n    def test_struct_loss_namedtuple(self):\n        Point = namedtuple(\"Point\", [\"x\", \"y\"])\n        y_true = {\n            \"a\": Point(\n                np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),\n                np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n            ),\n            \"b\": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n        }\n        y_pred = {\n            \"a\": Point(\n                np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),\n                np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),\n            ),\n            \"b\": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 
0.1]]),\n        }\n        loss = {\"a\": Point(\"mse\", \"mae\")}\n        compile_loss = CompileLoss(loss=loss, output_names=[\"c\", \"d\", \"b\"])\n        y_true_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((3, 4)), y_true\n        )\n        y_pred_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((3, 4)), y_pred\n        )\n        compile_loss.build(y_true_symb, y_pred_symb)\n        value = compile_loss(y_true, y_pred)\n        self.assertAllClose(value, 1.07666, atol=1e-5)\n\n    def test_struct_loss_invalid_path(self):\n        y_true = {\n            \"a\": {\n                \"c\": np.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]),\n                \"d\": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n            },\n            \"b\": np.array([[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]]),\n        }\n        y_pred = {\n            \"a\": {\n                \"c\": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),\n                \"d\": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),\n            },\n            \"b\": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),\n        }\n        loss = {\"a\": {\"c\": \"mse\"}, \"b\": {\"d\": \"mae\"}}\n        compile_loss = CompileLoss(loss=loss, output_names=[\"c\", \"d\", \"b\"])\n        y_true_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((3, 4)), y_true\n        )\n        y_pred_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((3, 4)), y_pred\n        )\n        with self.assertRaisesRegex(\n            KeyError, \"can't be found in the model's output\"\n        ):\n            compile_loss.build(y_true_symb, y_pred_symb)\n\n    def test_different_container_types(self):\n        y1, y2, y3 = np.array([[1]]), np.array([[2]]), np.array([[3]])\n        y_true = ([{\"a\": y1}, {\"b\": ([y2], y3)}],)\n        y_pred = [({\"a\": y1}, {\"b\": [(y2,), y3]})]\n        loss = \"mse\"\n        compile_loss = CompileLoss(loss=loss, 
output_names=[\"a\", \"b\", \"c\"])\n        y_true_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((1, 1)), y_true\n        )\n        y_pred_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((1, 1)), y_pred\n        )\n        compile_loss.build(y_true_symb, y_pred_symb)\n        compile_loss(y_true, y_pred)\n\n    def test_structure_mismatch(self):\n        y_true = [np.array([[1]]), np.array([[1]])]\n        y_pred = [np.array([[1]]), np.array([[1]])]\n        loss = [\"mse\", \"mse\"]\n        compile_loss = CompileLoss(loss=loss, output_names=[\"a\", \"b\"])\n        y_true_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((1, 1)), y_true\n        )\n        y_pred_symb = tree.map_structure(\n            lambda _: backend.KerasTensor((1, 1)), y_pred\n        )\n        compile_loss.build(y_true_symb, y_pred_symb)\n        with self.assertRaisesRegex(\n            ValueError, \"y_true and y_pred have different structures.\"\n        ):\n            wrong_struc_y_true = [np.array([[1]])]\n            compile_loss(wrong_struc_y_true, y_pred)\n\n    @parameterized.parameters(\n        [\"mse\", None, None],\n        [None, \"mse\", None],\n        [None, None, \"mse\"],\n        [None, \"mse\", \"mse\"],\n        [\"mse\", None, \"mse\"],\n    )\n    def test_y_true_partial_y_pred_span(self, *loss_conf):\n        loss_conf = list(loss_conf)\n        ones = np.ones((320, 3))\n        zeros = np.zeros((320, 3))\n        twos = np.ones((320, 3)) * 2\n        y_pred = [zeros, ones, twos]\n        y_true = [y for y, loss in zip(y_pred, loss_conf) if loss is not None]\n        y_true = y_true[0] if len(y_true) == 1 else y_true\n        compile_loss = CompileLoss(loss=loss_conf, output_names=[\"a\", \"b\", \"c\"])\n        # build call\n        compile_loss(y_true, y_pred)\n        # built call\n        loss = compile_loss(y_true, y_pred)\n        self.assertEqual(loss, 0.0)\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/__init__.py",
    "content": "import types\n\nfrom keras.src.distribution import distribution_lib\nfrom keras.src.trainers.data_adapters import array_data_adapter\nfrom keras.src.trainers.data_adapters import data_adapter\nfrom keras.src.trainers.data_adapters import py_dataset_adapter\nfrom keras.src.trainers.data_adapters.array_data_adapter import ArrayDataAdapter\nfrom keras.src.trainers.data_adapters.generator_data_adapter import (\n    GeneratorDataAdapter,\n)\nfrom keras.src.trainers.data_adapters.grain_dataset_adapter import (\n    GrainDatasetAdapter,\n)\nfrom keras.src.trainers.data_adapters.py_dataset_adapter import PyDatasetAdapter\nfrom keras.src.trainers.data_adapters.tf_dataset_adapter import TFDatasetAdapter\nfrom keras.src.trainers.data_adapters.torch_data_loader_adapter import (\n    TorchDataLoaderAdapter,\n)\n\n\ndef get_data_adapter(\n    x,\n    y=None,\n    sample_weight=None,\n    batch_size=None,\n    steps_per_epoch=None,\n    shuffle=False,\n    class_weight=None,\n):\n    # Allow passing a custom data adapter.\n    if isinstance(x, data_adapter.DataAdapter):\n        return x\n\n    # Check for multi-process/worker distribution.\n    distribution = distribution_lib.distribution()\n    if (\n        distribution is not None\n        and getattr(distribution, \"_is_multi_process\", False)\n        and getattr(distribution, \"auto_shard_dataset\", False)\n        and not is_tf_dataset(x)\n    ):\n        raise ValueError(\n            \"When using a multi-worker distribution with auto-sharding enabled, \"\n            \"the data must be provided as a `tf.data.Dataset` instance. \"\n            f\"Received: type(x)={type(x)}. 
\"\n            \"If the dataset is already sharded across workers, then set \"\n            \"`distribution.auto_shard_dataset = False`.\"\n        )\n\n    if array_data_adapter.can_convert_arrays((x, y, sample_weight)):\n        return ArrayDataAdapter(\n            x,\n            y,\n            sample_weight=sample_weight,\n            class_weight=class_weight,\n            shuffle=shuffle,\n            batch_size=batch_size,\n            steps=steps_per_epoch,\n        )\n    elif is_tf_dataset(x):\n        # Unsupported args: y, sample_weight, shuffle\n        if y is not None:\n            raise_unsupported_arg(\"y\", \"the targets\", \"tf.data.Dataset\")\n        if sample_weight is not None:\n            raise_unsupported_arg(\n                \"sample_weights\", \"the sample weights\", \"tf.data.Dataset\"\n            )\n        return TFDatasetAdapter(\n            x, class_weight=class_weight, distribution=distribution\n        )\n        # TODO: should we warn or not?\n        # warnings.warn(\n        #     \"`shuffle=True` was passed, but will be ignored since the \"\n        #     \"data `x` was provided as a tf.data.Dataset. The Dataset is \"\n        #     \"expected to already be shuffled \"\n        #     \"(via `.shuffle(tf.data.AUTOTUNE)`)\"\n        # )\n    elif isinstance(x, py_dataset_adapter.PyDataset):\n        if y is not None:\n            raise_unsupported_arg(\"y\", \"the targets\", \"PyDataset\")\n        if sample_weight is not None:\n            raise_unsupported_arg(\n                \"sample_weights\", \"the sample weights\", \"PyDataset\"\n            )\n        return PyDatasetAdapter(x, class_weight=class_weight, shuffle=shuffle)\n        # TODO: should we warn or not?\n        # if x.num_batches is None and shuffle:\n        #     warnings.warn(\n        #         \"`shuffle=True` was passed, but will be ignored since the \"\n        #         \"data `x` was provided as an infinite PyDataset. 
The \"\n        #         \"PyDataset is expected to already be shuffled.\"\n        # )\n    elif is_torch_dataloader(x):\n        if y is not None:\n            raise_unsupported_arg(\"y\", \"the targets\", \"torch DataLoader\")\n        if sample_weight is not None:\n            raise_unsupported_arg(\n                \"sample_weights\", \"the sample weights\", \"torch DataLoader\"\n            )\n        if class_weight is not None:\n            raise ValueError(\n                \"Argument `class_weight` is not supported for torch \"\n                f\"DataLoader inputs. You can modify your `__getitem__` method\"\n                \" to return input tensor, label and class_weight. \"\n                \"Alternatively, use a custom training loop. See the User Guide \"\n                \"https://keras.io/guides/custom_train_step_in_torch/\"\n                \"#supporting-sampleweight-amp-classweight for more details. \"\n                f\"Received: class_weight={class_weight}\"\n            )\n        return TorchDataLoaderAdapter(x)\n        # TODO: should we warn or not?\n        # warnings.warn(\n        #     \"`shuffle=True` was passed, but will be ignored since the \"\n        #     \"data `x` was provided as a torch DataLoader. The DataLoader \"\n        #     \"is expected to already be shuffled.\"\n        # )\n    elif is_grain_dataset(x):\n        if y is not None:\n            raise_unsupported_arg(\n                \"y\", \"the targets\", \"grain.Dataset and grain.DataLoader\"\n            )\n        if sample_weight is not None:\n            raise_unsupported_arg(\n                \"sample_weights\",\n                \"the sample weights\",\n                \"grain.Dataset and grain.DataLoader\",\n            )\n        if class_weight is not None:\n            raise ValueError(\n                \"Argument `class_weight` is not supported for grain.Dataset \"\n                f\"and grain.DataLoader inputs. 
You can modify your \"\n                \"`__getitem__` method to return input tensor, label and \"\n                \"class_weight. \"\n                f\"Received: class_weight={class_weight}\"\n            )\n        return GrainDatasetAdapter(x)\n        # TODO: should we warn or not?\n        # warnings.warn(\n        #     \"`shuffle=True` was passed, but will be ignored since the \"\n        #     \"data `x` was provided as a grain dataset. The grain dataset \"\n        #     \"is expected to already be shuffled.\"\n        # )\n    elif isinstance(x, types.GeneratorType):\n        if y is not None:\n            raise_unsupported_arg(\"y\", \"the targets\", \"generator\")\n        if sample_weight is not None:\n            raise_unsupported_arg(\n                \"sample_weights\", \"the sample weights\", \"generator\"\n            )\n        if class_weight is not None:\n            raise ValueError(\n                \"Argument `class_weight` is not supported for Python \"\n                f\"generator inputs. Received: class_weight={class_weight}\"\n            )\n        return GeneratorDataAdapter(x)\n        # TODO: should we warn or not?\n        # warnings.warn(\n        #     \"`shuffle=True` was passed, but will be ignored since the \"\n        #     \"data `x` was provided as a generator. The generator \"\n        #     \"is expected to yield already-shuffled data.\"\n        # )\n    else:\n        raise ValueError(f\"Unrecognized data type: x={x} (of type {type(x)})\")\n\n\ndef raise_unsupported_arg(arg_name, arg_description, input_type):\n    raise ValueError(\n        f\"When providing `x` as a {input_type}, `{arg_name}` \"\n        f\"should not be passed. 
Instead, {arg_description} should \"\n        f\"be included as part of the {input_type}.\"\n    )\n\n\ndef is_tf_dataset(x):\n    if hasattr(x, \"__class__\"):\n        for parent in x.__class__.__mro__:\n            if parent.__name__ in (\n                \"DatasetV2\",\n                \"DistributedDataset\",\n                \"DistributedDatasetsFromFunction\",\n            ) and \"tensorflow.python.\" in str(parent.__module__):\n                return True\n    return False\n\n\ndef is_torch_dataloader(x):\n    if hasattr(x, \"__class__\"):\n        for parent in x.__class__.__mro__:\n            if parent.__name__ == \"DataLoader\" and \"torch.utils.data\" in str(\n                parent.__module__\n            ):\n                return True\n    return False\n\n\ndef is_grain_dataset(x):\n    if hasattr(x, \"__class__\"):\n        for parent in x.__class__.__mro__:\n            if parent.__name__ in (\n                \"MapDataset\",\n                \"IterDataset\",\n                \"DataLoader\",\n            ) and \"grain\" in str(parent.__module__):\n                return True\n    return False\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/array_data_adapter.py",
    "content": "import functools\nimport math\n\nimport numpy as np\n\nfrom keras.src import tree\nfrom keras.src.trainers.data_adapters import array_slicing\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.trainers.data_adapters.data_adapter import DataAdapter\n\n\nclass ArrayDataAdapter(DataAdapter):\n    \"\"\"Adapter for array-like objects, e.g. TF/JAX Tensors, NumPy arrays.\"\"\"\n\n    def __init__(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        batch_size=None,\n        steps=None,\n        shuffle=False,\n        class_weight=None,\n    ):\n        if not can_convert_arrays((x, y, sample_weight)):\n            raise ValueError(\n                \"Expected all elements of `x` to be array-like. \"\n                f\"Received invalid types: x={x}\"\n            )\n\n        if sample_weight is not None:\n            if class_weight is not None:\n                raise ValueError(\n                    \"You cannot use `class_weight` and `sample_weight` \"\n                    \"at the same time.\"\n                )\n            if tree.is_nested(y):\n                if isinstance(sample_weight, (list, tuple, dict)):\n                    try:\n                        tree.assert_same_structure(y, sample_weight)\n                    except ValueError:\n                        raise ValueError(\n                            \"You should provide one `sample_weight` array per \"\n                            \"output in `y`. 
The two structures did not match:\\n\"\n                            f\"- y: {y}\\n\"\n                            f\"- sample_weight: {sample_weight}\\n\"\n                        )\n                else:\n                    is_samplewise = len(sample_weight.shape) == 1 or (\n                        len(sample_weight.shape) == 2\n                        and sample_weight.shape[1] == 1\n                    )\n                    if not is_samplewise:\n                        raise ValueError(\n                            \"For a model with multiple outputs, when providing \"\n                            \"a single `sample_weight` array, it should only \"\n                            \"have one scalar score per sample \"\n                            \"(i.e. shape `(num_samples,)`). If you want to use \"\n                            \"non-scalar sample weights, pass a `sample_weight` \"\n                            \"argument with one array per model output.\"\n                        )\n                    # Replicate the same sample_weight array on all outputs.\n                    sample_weight = tree.map_structure(\n                        lambda _: sample_weight, y\n                    )\n        if class_weight is not None:\n            if tree.is_nested(y):\n                raise ValueError(\n                    \"`class_weight` is only supported for Models with a single \"\n                    \"output.\"\n                )\n            sample_weight = data_adapter_utils.class_weight_to_sample_weights(\n                y, class_weight\n            )\n\n        inputs = data_adapter_utils.pack_x_y_sample_weight(x, y, sample_weight)\n\n        data_adapter_utils.check_data_cardinality(inputs)\n        num_samples = set(\n            i.shape[0] for i in tree.flatten(inputs) if i is not None\n        ).pop()\n        self._num_samples = num_samples\n        self._inputs = inputs\n\n        # If batch_size is not passed but steps is, calculate from the input\n     
   # data.  Defaults to `32` for backwards compatibility.\n        if not batch_size:\n            batch_size = int(math.ceil(num_samples / steps)) if steps else 32\n\n        self._size = int(math.ceil(num_samples / batch_size))\n        self._batch_size = batch_size\n        self._partial_batch_size = num_samples % batch_size\n        self._shuffle = shuffle\n\n    def get_numpy_iterator(self):\n        inputs = array_slicing.convert_to_sliceable(\n            self._inputs, target_backend=\"numpy\"\n        )\n\n        def slice_and_convert_to_numpy(sliceable, indices=None):\n            x = sliceable[indices]\n            x = sliceable.convert_to_numpy(x)\n            return x\n\n        return self._get_iterator(slice_and_convert_to_numpy, inputs)\n\n    def get_tf_dataset(self):\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        shuffle = self._shuffle\n        batch_size = self._batch_size\n        num_samples = self._num_samples\n        num_full_batches = int(self._num_samples // batch_size)\n\n        # Vectorized version of shuffle.\n        # This is a performance improvement over using `from_tensor_slices`.\n        # The indices of the data are shuffled and batched, and these indices\n        # are then zipped with the data and used to extract a batch of the data\n        # at each step. The performance improvements here come from:\n        # 1. vectorized batch using gather\n        # 2. parallelized map\n        # 3. pipelined permutation generation\n        # 4. optimized permutation batching\n        # 5. disabled static optimizations\n\n        indices_dataset = tf.data.Dataset.range(1)\n\n        def permutation(_):\n            # It turns out to be more performant to make a new set of indices\n            # rather than reusing the same range Tensor. 
(presumably because of\n            # buffer forwarding.)\n            indices = tf.range(num_samples, dtype=tf.int64)\n            if shuffle and shuffle != \"batch\":\n                indices = tf.random.shuffle(indices)\n            return indices\n\n        # We prefetch a single element. Computing large permutations can take\n        # quite a while so we don't want to wait for prefetching over an epoch\n        # boundary to trigger the next permutation. On the other hand, too many\n        # simultaneous shuffles can contend on a hardware level and degrade all\n        # performance.\n        indices_dataset = indices_dataset.map(permutation).prefetch(1)\n\n        def slice_batch_indices(indices):\n            \"\"\"Convert a Tensor of indices into a dataset of batched indices.\n\n            This step can be accomplished in several ways. The most natural is\n            to slice the Tensor in a Dataset map. (With a condition on the upper\n            index to handle the partial batch.) 
However it turns out that\n            coercing the Tensor into a shape which is divisible by the batch\n            size (and handling the last partial batch separately) allows for a\n            much more favorable memory access pattern and improved performance.\n\n            Args:\n                indices: Tensor which determines the data order for an entire\n                    epoch.\n\n            Returns:\n                A Dataset of batched indices.\n            \"\"\"\n            num_in_full_batch = num_full_batches * batch_size\n            first_k_indices = tf.slice(indices, [0], [num_in_full_batch])\n            first_k_indices = tf.reshape(\n                first_k_indices, [num_full_batches, batch_size]\n            )\n\n            flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices)\n            if self._partial_batch_size:\n                index_remainder = tf.data.Dataset.from_tensors(\n                    tf.slice(\n                        indices, [num_in_full_batch], [self._partial_batch_size]\n                    )\n                )\n                flat_dataset = flat_dataset.concatenate(index_remainder)\n\n            return flat_dataset\n\n        def slice_inputs(indices_dataset, inputs):\n            \"\"\"Slice inputs into a Dataset of batches.\n\n            Given a Dataset of batch indices and the unsliced inputs,\n            this step slices the inputs in a parallelized fashion\n            and produces a dataset of input batches.\n\n            Args:\n                indices_dataset: A Dataset of batched indices.\n                inputs: A python data structure that contains the inputs,\n                    targets, and possibly sample weights.\n\n            Returns:\n                A Dataset of input batches matching the batch indices.\n            \"\"\"\n            inputs = array_slicing.convert_to_sliceable(\n                self._inputs, target_backend=\"tensorflow\"\n            )\n            inputs = 
tree.lists_to_tuples(inputs)\n\n            dataset = tf.data.Dataset.zip(\n                (indices_dataset, tf.data.Dataset.from_tensors(inputs).repeat())\n            )\n\n            def grab_batch(i, data):\n                def grab_one(x):\n                    if isinstance(x, array_slicing.TensorflowSparseWrapper):\n                        return array_slicing.slice_tensorflow_sparse_wrapper(\n                            x, i\n                        )\n                    if isinstance(x, (list, tuple, dict)):\n                        return None\n                    if tf.is_tensor(x):\n                        return tf.gather(x, i, axis=0)\n                    return x\n\n                return tree.traverse(grab_one, data)\n\n            dataset = dataset.map(\n                grab_batch, num_parallel_calls=tf.data.AUTOTUNE\n            )\n\n            # Default optimizations are disabled to avoid the overhead of\n            # (unnecessary) input pipeline graph serialization & deserialization\n            options = tf.data.Options()\n            options.experimental_optimization.apply_default_optimizations = (\n                False\n            )\n            if self._shuffle:\n                options.experimental_external_state_policy = (\n                    tf.data.experimental.ExternalStatePolicy.IGNORE\n                )\n            dataset = dataset.with_options(options)\n            return dataset\n\n        indices_dataset = indices_dataset.flat_map(slice_batch_indices)\n        if shuffle == \"batch\":\n            indices_dataset = indices_dataset.map(tf.random.shuffle)\n\n        dataset = slice_inputs(indices_dataset, self._inputs)\n\n        options = tf.data.Options()\n        options.experimental_distribute.auto_shard_policy = (\n            tf.data.experimental.AutoShardPolicy.DATA\n        )\n        dataset = dataset.with_options(options)\n        return dataset.prefetch(tf.data.AUTOTUNE)\n\n    def get_jax_iterator(self):\n        
inputs = array_slicing.convert_to_sliceable(\n            self._inputs, target_backend=\"jax\"\n        )\n\n        def slice_and_convert_to_jax(sliceable, indices=None):\n            x = sliceable[indices]\n            x = sliceable.convert_to_jax_compatible(x)\n            return x\n\n        return self._get_iterator(slice_and_convert_to_jax, inputs)\n\n    def get_torch_dataloader(self):\n        import torch\n\n        from keras.src.backend.torch.core import convert_to_tensor\n\n        class ArrayDataset(torch.utils.data.Dataset):\n            def __init__(self, array):\n                self.array = array\n\n            def __getitems__(self, indices):\n                def slice_and_convert(sliceable):\n                    x = sliceable[indices]\n                    x = sliceable.convert_to_torch_compatible(x)\n                    x = convert_to_tensor(x)\n                    return x\n\n                return tree.map_structure(\n                    slice_and_convert, self.array, none_is_leaf=False\n                )\n\n            def __len__(self):\n                return len(self.array[0])\n\n        class RandomBatchSampler(torch.utils.data.Sampler):\n            def __init__(self, sampler):\n                self.sampler = sampler\n\n            def __iter__(self):\n                for batch in self.sampler:\n                    yield [batch[i] for i in torch.randperm(len(batch))]\n\n            def __len__(self):\n                return len(self.sampler)\n\n        if self._shuffle == \"batch\":\n            batch_sampler = RandomBatchSampler(\n                torch.utils.data.BatchSampler(\n                    range(self._num_samples),\n                    batch_size=self._batch_size,\n                    drop_last=False,\n                )\n            )\n        elif self._shuffle:\n            batch_sampler = torch.utils.data.BatchSampler(\n                torch.utils.data.RandomSampler(range(self._num_samples)),\n                
batch_size=self._batch_size,\n                drop_last=False,\n            )\n        else:\n            batch_sampler = torch.utils.data.BatchSampler(\n                torch.utils.data.SequentialSampler(range(self._num_samples)),\n                batch_size=self._batch_size,\n                drop_last=False,\n            )\n\n        # Because ArrayDataset.__getitems__ returns full batches organized in\n        # the expected structure, there is nothing to collate.\n        def no_op_collate(batch):\n            return batch\n\n        inputs = array_slicing.convert_to_sliceable(\n            self._inputs, target_backend=\"torch\"\n        )\n        dataset = ArrayDataset(inputs)\n        return torch.utils.data.DataLoader(\n            dataset, batch_sampler=batch_sampler, collate_fn=no_op_collate\n        )\n\n    def _get_iterator(self, slice_and_convert_fn, inputs):\n        global_permutation = None\n        if self._shuffle and self._shuffle != \"batch\":\n            global_permutation = np.random.permutation(self._num_samples)\n\n        for i in range(self._size):\n            start = i * self._batch_size\n            stop = min((i + 1) * self._batch_size, self._num_samples)\n            if self._shuffle == \"batch\":\n                indices = np.random.permutation(stop - start) + start\n            elif self._shuffle:\n                indices = global_permutation[start:stop]\n            else:\n                indices = slice(start, stop)\n\n            slice_indices_and_convert_fn = functools.partial(\n                slice_and_convert_fn, indices=indices\n            )\n            yield tree.map_structure(\n                slice_indices_and_convert_fn, inputs, none_is_leaf=False\n            )\n\n    @property\n    def num_batches(self):\n        return self._size\n\n    @property\n    def batch_size(self):\n        return self._batch_size\n\n    @property\n    def has_partial_batch(self):\n        return self._partial_batch_size > 0\n\n    
@property\n    def partial_batch_size(self):\n        return self._partial_batch_size or None\n\n\ndef can_convert_arrays(arrays):\n    \"\"\"Check whether array-like inputs can be handled by `ArrayDataAdapter`.\n\n    Args:\n        arrays: Structure of `Tensor`s, NumPy arrays, or tensor-like.\n\n    Returns:\n        `True` if `arrays` can be handled by `ArrayDataAdapter`, `False`\n        otherwise.\n    \"\"\"\n\n    return all(\n        tree.flatten(tree.map_structure(array_slicing.can_slice_array, arrays))\n    )\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/array_data_adapter_test.py",
    "content": "import jax\nimport jax.experimental.sparse as jax_sparse\nimport numpy as np\nimport pandas\nimport scipy\nimport tensorflow as tf\nimport torch\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.testing.test_utils import named_product\nfrom keras.src.trainers.data_adapters import array_data_adapter\n\n\nclass TestArrayDataAdapter(testing.TestCase):\n    def make_array(self, array_type, shape, dtype):\n        x = np.array([[i] * shape[1] for i in range(shape[0])], dtype=dtype)\n        if array_type == \"np\":\n            return x\n        elif array_type == \"tf\":\n            return tf.constant(x)\n        elif array_type == \"tf_ragged\":\n            return tf.RaggedTensor.from_tensor(x)\n        elif array_type == \"tf_sparse\":\n            return tf.sparse.from_dense(x)\n        elif array_type == \"jax\":\n            return jax.numpy.array(x)\n        elif array_type == \"jax_sparse\":\n            return jax_sparse.BCOO.fromdense(x)\n        elif array_type == \"torch\":\n            return torch.as_tensor(x)\n        elif array_type == \"pandas_data_frame\":\n            return pandas.DataFrame(x)\n        elif array_type == \"pandas_series\":\n            return pandas.Series(x[:, 0])\n        elif array_type == \"scipy_sparse\":\n            return scipy.sparse.coo_matrix(x)\n\n    @parameterized.named_parameters(\n        named_product(\n            array_type=[\n                \"np\",\n                \"tf\",\n                \"tf_ragged\",\n                \"tf_sparse\",\n                \"jax\",\n                \"jax_sparse\",\n                \"torch\",\n                \"pandas_data_frame\",\n                \"pandas_series\",\n                \"scipy_sparse\",\n            ],\n            array_dtype=[\"float32\", \"float64\"],\n            shuffle=[False, \"batch\", True],\n        )\n    )\n    def test_basic_flow(self, array_type, array_dtype, 
shuffle):\n        x = self.make_array(array_type, (34, 4), array_dtype)\n        y = self.make_array(array_type, (34, 2), \"int32\")\n        xdim1 = 1 if array_type == \"pandas_series\" else 4\n        ydim1 = 1 if array_type == \"pandas_series\" else 2\n\n        adapter = array_data_adapter.ArrayDataAdapter(\n            x,\n            y=y,\n            sample_weight=None,\n            batch_size=16,\n            steps=None,\n            shuffle=shuffle,\n        )\n        self.assertEqual(adapter.num_batches, 3)\n        self.assertEqual(adapter.batch_size, 16)\n        self.assertEqual(adapter.has_partial_batch, True)\n        self.assertEqual(adapter.partial_batch_size, 2)\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            if array_type == \"tf_ragged\":\n                expected_class = tf.RaggedTensor\n                xdim1 = None\n                ydim1 = None\n            elif array_type in (\"tf_sparse\", \"jax_sparse\", \"scipy_sparse\"):\n                expected_class = tf.SparseTensor\n            else:\n                expected_class = tf.Tensor\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n            if array_type in (\"tf_sparse\", \"jax_sparse\", \"scipy_sparse\"):\n                expected_class = jax_sparse.JAXSparse\n            else:\n                expected_class = np.ndarray\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n            expected_class = torch.Tensor\n        else:\n            it = adapter.get_numpy_iterator()\n            expected_class = np.ndarray\n\n        x_order = []\n        y_order = []\n        for i, batch in enumerate(it):\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertIsInstance(bx, expected_class)\n            self.assertIsInstance(by, expected_class)\n            self.assertEqual(\n                
backend.standardize_dtype(bx.dtype), backend.floatx()\n            )\n            self.assertEqual(backend.standardize_dtype(by.dtype), \"int32\")\n            if i < 2:\n                self.assertEqual(bx.shape, (16, xdim1))\n                self.assertEqual(by.shape, (16, ydim1))\n            else:\n                self.assertEqual(bx.shape, (2, xdim1))\n                self.assertEqual(by.shape, (2, ydim1))\n\n            if isinstance(bx, tf.SparseTensor):\n                bx = tf.sparse.to_dense(bx)\n                by = tf.sparse.to_dense(by)\n            if isinstance(bx, jax_sparse.JAXSparse):\n                bx = bx.todense()\n                by = by.todense()\n            x_batch_order = [float(bx[j, 0]) for j in range(bx.shape[0])]\n            y_batch_order = [float(by[j, 0]) for j in range(by.shape[0])]\n            x_order.extend(x_batch_order)\n            y_order.extend(y_batch_order)\n\n            if shuffle == \"batch\":\n                self.assertAllClose(\n                    sorted(x_batch_order),\n                    list(range(i * 16, i * 16 + bx.shape[0])),\n                )\n\n        self.assertAllClose(x_order, y_order)\n        if shuffle:\n            self.assertNotAllClose(x_order, list(range(34)))\n        else:\n            self.assertAllClose(x_order, list(range(34)))\n\n    def test_multi_inputs_and_outputs(self):\n        x1 = np.random.random((34, 1))\n        x2 = np.random.random((34, 2))\n        y1 = np.random.random((34, 3))\n        y2 = np.random.random((34, 4))\n        sw = np.random.random((34,))\n        adapter = array_data_adapter.ArrayDataAdapter(\n            x={\"x1\": x1, \"x2\": x2},\n            y=[y1, y2],\n            sample_weight=sw,\n            batch_size=16,\n            steps=None,\n            shuffle=False,\n        )\n        gen = adapter.get_numpy_iterator()\n        for i, batch in enumerate(gen):\n            self.assertEqual(len(batch), 3)\n            bx, by, bw = batch\n            
self.assertIsInstance(bx, dict)\n            self.assertIsInstance(by, list)\n            self.assertIsInstance(bw, list)\n\n            self.assertIsInstance(bx[\"x1\"], np.ndarray)\n            self.assertIsInstance(bx[\"x2\"], np.ndarray)\n            self.assertIsInstance(by[0], np.ndarray)\n            self.assertIsInstance(by[1], np.ndarray)\n            self.assertIsInstance(bw[0], np.ndarray)\n            self.assertIsInstance(bw[1], np.ndarray)\n\n            self.assertEqual(bx[\"x1\"].dtype, by[0].dtype)\n            self.assertEqual(bx[\"x1\"].dtype, backend.floatx())\n            if i < 2:\n                self.assertEqual(bx[\"x1\"].shape, (16, 1))\n                self.assertEqual(bx[\"x2\"].shape, (16, 2))\n                self.assertEqual(by[0].shape, (16, 3))\n                self.assertEqual(by[1].shape, (16, 4))\n                self.assertEqual(bw[0].shape, (16,))\n                self.assertEqual(bw[1].shape, (16,))\n            else:\n                self.assertEqual(bx[\"x1\"].shape, (2, 1))\n                self.assertEqual(by[0].shape, (2, 3))\n                self.assertEqual(bw[0].shape, (2,))\n                self.assertEqual(bw[1].shape, (2,))\n        ds = adapter.get_tf_dataset()\n        for i, batch in enumerate(ds):\n            self.assertEqual(len(batch), 3)\n            bx, by, bw = batch\n            self.assertIsInstance(bx, dict)\n            # NOTE: the y list was converted to a tuple for tf.data\n            # compatibility.\n            self.assertIsInstance(by, tuple)\n            self.assertIsInstance(bw, tuple)\n\n            self.assertIsInstance(bx[\"x1\"], tf.Tensor)\n            self.assertIsInstance(bx[\"x2\"], tf.Tensor)\n            self.assertIsInstance(by[0], tf.Tensor)\n            self.assertIsInstance(by[1], tf.Tensor)\n            self.assertIsInstance(bw[0], tf.Tensor)\n            self.assertIsInstance(bw[1], tf.Tensor)\n\n            self.assertEqual(bx[\"x1\"].dtype, by[0].dtype)\n            
self.assertEqual(bx[\"x1\"].dtype, backend.floatx())\n            if i < 2:\n                self.assertEqual(tuple(bx[\"x1\"].shape), (16, 1))\n                self.assertEqual(tuple(bx[\"x2\"].shape), (16, 2))\n                self.assertEqual(tuple(by[0].shape), (16, 3))\n                self.assertEqual(tuple(by[1].shape), (16, 4))\n                self.assertEqual(tuple(bw[0].shape), (16,))\n                self.assertEqual(tuple(bw[1].shape), (16,))\n            else:\n                self.assertEqual(tuple(bx[\"x1\"].shape), (2, 1))\n                self.assertEqual(tuple(by[0].shape), (2, 3))\n                self.assertEqual(tuple(bw[0].shape), (2,))\n                self.assertEqual(tuple(bw[1].shape), (2,))\n\n    @parameterized.named_parameters(\n        named_product(target_encoding=[\"int\", \"categorical\"])\n    )\n    def test_class_weights(self, target_encoding):\n        x = np.random.random((4, 2))\n        if target_encoding == \"int\":\n            y = np.array([[0], [1], [2], [3]], dtype=\"int32\")\n        else:\n            y = np.array(\n                [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],\n                dtype=\"float32\",\n            )\n\n        class_weight = {\n            0: 0.1,\n            1: 0.2,\n            2: 0.3,\n            3: 0.4,\n        }\n        adapter = array_data_adapter.ArrayDataAdapter(\n            x,\n            y=y,\n            class_weight=class_weight,\n            batch_size=16,\n        )\n        gen = adapter.get_numpy_iterator()\n        for batch in gen:\n            self.assertEqual(len(batch), 3)\n            _, _, bw = batch\n            self.assertAllClose(bw, [0.1, 0.2, 0.3, 0.4])\n\n    def test_errors(self):\n        x = np.random.random((34, 1))\n        y = np.random.random((34, 3))\n        sw = np.random.random((34,))\n        cw = {\n            0: 0.1,\n            1: 0.2,\n            2: 0.3,\n            3: 0.4,\n        }\n\n        with 
self.assertRaisesRegex(\n            ValueError, \"Expected all elements of `x` to be array-like\"\n        ):\n            array_data_adapter.ArrayDataAdapter(x=\"Invalid\")\n        with self.assertRaisesRegex(\n            ValueError, \"Expected all elements of `x` to be array-like\"\n        ):\n            array_data_adapter.ArrayDataAdapter(x=x, y=\"Invalid\")\n        with self.assertRaisesRegex(\n            ValueError, \"Expected all elements of `x` to be array-like\"\n        ):\n            array_data_adapter.ArrayDataAdapter(\n                x=x, y=y, sample_weight=\"Invalid\"\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"You cannot `class_weight` and `sample_weight`\"\n        ):\n            array_data_adapter.ArrayDataAdapter(\n                x=x, y=y, sample_weight=sw, class_weight=cw\n            )\n\n        nested_y = ({\"x\": x, \"y\": y},)\n        with self.assertRaisesRegex(\n            ValueError, \"You should provide one `sample_weight` array per\"\n        ):\n            array_data_adapter.ArrayDataAdapter(\n                x=x, y=nested_y, sample_weight=[]\n            )\n\n        tensor_sw = self.make_array(\"tf\", (34, 2), \"int32\")\n        with self.assertRaisesRegex(\n            ValueError, \"For a model with multiple outputs, when providing\"\n        ):\n            array_data_adapter.ArrayDataAdapter(\n                x=x, y=nested_y, sample_weight=tensor_sw\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`class_weight` is only supported for Models with a single\",\n        ):\n            array_data_adapter.ArrayDataAdapter(\n                x=x, y=nested_y, class_weight=cw\n            )\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/array_slicing.py",
    "content": "import collections\nimport math\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import tree\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\n\ntry:\n    import pandas\nexcept ImportError:\n    pandas = None\n\n\n# Leave jax, tf, and torch arrays off this list. Instead we will use\n# `__array__` to detect these types. Doing so allows us to avoid importing a\n# backend framework we are not currently using just to do type-checking.\nARRAY_TYPES = (np.ndarray,)\nif pandas:\n    ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame)\n\n\nclass Sliceable:\n    \"\"\"`Sliceable` wrapping a tensor.\n\n    A `Sliceable` implements the subscript operator to slice or index against\n    the first dimension of the array. It also has conversion methods for each\n    one of the backends.\n\n    Args:\n        array: the native array or tensor to wrap.\n\n    Attributes:\n        shape: the shape of the full dense native array.\n    \"\"\"\n\n    def __init__(self, array):\n        self.array = array\n\n    def __getitem__(self, indices):\n        \"\"\"Select elements in the 0th dimension.\n\n        Args:\n            indices: the indices to select. Only needs to support one dimension,\n                the 0th dimension. 
Should support a `slice` or a list, tuple,\n                `np.array` or 1D tensor.\n        Returns: A slice of `self.array`.\n        \"\"\"\n        return self.array[indices]\n\n    @classmethod\n    def cast(cls, x, dtype):\n        \"\"\"Cast a tensor to a different dtype.\n\n        Only called on a full array as provided by the user.\n\n        Args:\n            x: the tensor to cast.\n        Returns: the cast tensor.\n        \"\"\"\n        return x.astype(dtype)\n\n    @classmethod\n    def convert_to_numpy(cls, x):\n        \"\"\"Convert a tensor to a NumPy array.\n\n        Only called after slicing using `__getitem__`.\n\n        Args:\n            x: the tensor to convert.\n        Returns: the converted tensor.\n        \"\"\"\n        return x\n\n    @classmethod\n    def convert_to_tf_dataset_compatible(cls, x):\n        \"\"\"Convert a tensor to something compatible with `tf.data.Dataset`.\n\n        This can be a NumPy array, `tf.Tensor` or any other type of tensor that\n        `tf.data.Dataset.from_tensors` can consume.\n        Only called on a full array as provided by the user.\n\n        Args:\n            x: the tensor to convert.\n        Returns: converted version tensor.\n        \"\"\"\n        return x\n\n    @classmethod\n    def convert_to_jax_compatible(cls, x):\n        \"\"\"Convert a tensor to something that the JAX backend can consume.\n\n        This can be a `JAX` array, `JAXSparse` or a NumPy array.\n        Only called after slicing using `__getitem__`.\n        Used to convert sparse tensors and densify ragged tensors.\n\n        Args:\n            x: the tensor to convert.\n        Returns: the converted tensor.\n        \"\"\"\n        return x\n\n    @classmethod\n    def convert_to_torch_compatible(cls, x):\n        \"\"\"Convert a tensor to something that the Torch backend can consume.\n\n        This can be a Torch tensor, NumPy array or any other type of tensor that\n        
`keras.backend.torch.core.convert_to_tensor()` can consume.\n        Only called after slicing using `__getitem__`.\n        Used to densify sparse tensors and ragged tensors.\n\n        Args:\n            x: the tensor to convert.\n        Returns: the converted tensor.\n        \"\"\"\n        return x\n\n\nclass NumpySliceable(Sliceable):\n    pass\n\n\nclass TensorflowSliceable(Sliceable):\n    def __getitem__(self, indices):\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        if isinstance(indices, slice):\n            return self.array[indices]\n        else:\n            return tf.gather(self.array, indices, axis=0)\n\n    @classmethod\n    def cast(cls, x, dtype):\n        from keras.src.backend.tensorflow.core import cast\n\n        return cast(x, dtype)\n\n    @classmethod\n    def convert_to_numpy(cls, x):\n        from keras.src.backend.tensorflow.core import convert_to_numpy\n\n        return convert_to_numpy(x)\n\n\nclass TensorflowRaggedSliceable(TensorflowSliceable):\n    @classmethod\n    def convert_to_jax_compatible(cls, x):\n        return cls.convert_to_numpy(x)\n\n    @classmethod\n    def convert_to_torch_compatible(cls, x):\n        return x.to_tensor()\n\n\nclass TensorflowSparseSliceable(TensorflowSliceable):\n    def __init__(self, array):\n        super().__init__(to_tensorflow_sparse_wrapper(array))\n\n    @property\n    def shape(self):\n        return self.array.sparse.shape\n\n    def __getitem__(self, indices):\n        return slice_tensorflow_sparse_wrapper(self.array, indices)\n\n    @classmethod\n    def convert_to_tf_dataset_compatible(cls, x):\n        return to_tensorflow_sparse_wrapper(x)\n\n    @classmethod\n    def convert_to_jax_compatible(cls, x):\n        return data_adapter_utils.tf_sparse_to_jax_sparse(x)\n\n    @classmethod\n    def convert_to_torch_compatible(cls, x):\n        from keras.src.backend.tensorflow import sparse as tf_sparse\n\n        return 
tf_sparse.sparse_to_dense(x)\n\n\nclass JaxSparseSliceable(Sliceable):\n    def __getitem__(self, indices):\n        return self.array[indices, ...]\n\n    @classmethod\n    def convert_to_numpy(cls, x):\n        from keras.src.backend.jax.core import convert_to_numpy\n\n        return convert_to_numpy(x)\n\n    @classmethod\n    def convert_to_tf_dataset_compatible(cls, array):\n        return to_tensorflow_sparse_wrapper(\n            data_adapter_utils.jax_sparse_to_tf_sparse(array)\n        )\n\n    @classmethod\n    def convert_to_torch_compatible(cls, x):\n        return x.todense()\n\n\nclass TorchSliceable(Sliceable):\n    @classmethod\n    def cast(cls, x, dtype):\n        from keras.src.backend.torch.core import cast\n\n        return cast(x, dtype)\n\n    @classmethod\n    def convert_to_numpy(cls, x):\n        from keras.src.backend.torch.core import convert_to_numpy\n\n        return convert_to_numpy(x)\n\n\nclass PandasSliceable(Sliceable):\n    def __getitem__(self, indices):\n        return self.array.iloc[indices]\n\n    @classmethod\n    def convert_to_numpy(cls, x):\n        return x.to_numpy()\n\n    @classmethod\n    def convert_to_tf_dataset_compatible(cls, x):\n        return cls.convert_to_numpy(x)\n\n    @classmethod\n    def convert_to_jax_compatible(cls, x):\n        return cls.convert_to_numpy(x)\n\n    @classmethod\n    def convert_to_torch_compatible(cls, x):\n        return cls.convert_to_numpy(x)\n\n\nclass PandasDataFrameSliceable(PandasSliceable):\n    pass\n\n\nclass PandasSeriesSliceable(PandasSliceable):\n    @classmethod\n    def convert_to_numpy(cls, x):\n        return np.expand_dims(x.to_numpy(), axis=-1)\n\n\nclass ScipySparseSliceable(Sliceable):\n    def __init__(self, array):\n        # The COO representation is not indexable / sliceable and does not lend\n        # itself to it. 
Use the CSR representation instead, which is sliceable.\n        super().__init__(array.tocsr())\n\n    @classmethod\n    def convert_to_numpy(cls, x):\n        return x.todense()\n\n    @classmethod\n    def convert_to_tf_dataset_compatible(cls, x):\n        return to_tensorflow_sparse_wrapper(\n            data_adapter_utils.scipy_sparse_to_tf_sparse(x)\n        )\n\n    @classmethod\n    def convert_to_jax_compatible(cls, x):\n        return data_adapter_utils.scipy_sparse_to_jax_sparse(x)\n\n    @classmethod\n    def convert_to_torch_compatible(cls, x):\n        return x.todense()\n\n\n# `tf.SparseTensor` does not support indexing or `tf.gather`. The COO\n# representation it uses does not lend itself to indexing. We add some\n# intermediary tensors to ease the indexing and slicing. We put both indices and\n# values in `RaggedTensor`s where each row corresponds to a row in the sparse\n# tensor. This is because the number of values per row is not fixed.\n# `RaggedTensor`s do support indexing and `tf.gather`, although on CPU only.\n# We then reconstruct a `SparseTensor` from extracted rows. In theory, there is\n# no duplication of data for the indices and values, only the addition of row\n# splits for the ragged representation.\n# `TensorflowSparseWrapper` is a named tuple which combines the original\n# `SparseTensor` (used for the shape) and the ragged representations of indices\n# and values for indexing / slicing. 
We use a named tuple and not a `Sliceable`\n# to be able to ingest it in `tf.data.Dataset.from_tensors()` and map it.\n\nTensorflowSparseWrapper = collections.namedtuple(\n    \"TensorflowSparseWrapper\", [\"sparse\", \"ragged_indices\", \"ragged_values\"]\n)\n\n\ndef to_tensorflow_sparse_wrapper(sparse):\n    from keras.src.utils.module_utils import tensorflow as tf\n\n    row_ids = sparse.indices[:, 0]\n    row_splits = tf.experimental.RowPartition.from_value_rowids(\n        row_ids\n    ).row_splits()\n\n    ragged_indices = tf.cast(\n        tf.RaggedTensor.from_row_splits(sparse.indices, row_splits), tf.int64\n    )\n    ragged_values = tf.RaggedTensor.from_row_splits(sparse.values, row_splits)\n    return TensorflowSparseWrapper(sparse, ragged_indices, ragged_values)\n\n\ndef slice_tensorflow_sparse_wrapper(sparse_wrapper, indices):\n    from keras.src.utils.module_utils import tensorflow as tf\n\n    if isinstance(indices, slice):\n        sparse_indices = sparse_wrapper.ragged_indices[indices]\n        sparse_values = sparse_wrapper.ragged_values[indices]\n        batch_dim = indices.stop - indices.start\n    else:\n        sparse_indices = tf.gather(sparse_wrapper.ragged_indices, indices)\n        sparse_values = tf.gather(sparse_wrapper.ragged_values, indices)\n        if isinstance(indices, list):\n            batch_dim = len(indices)\n        else:\n            batch_dim = indices.shape[0]\n            if batch_dim is None:\n                batch_dim = tf.shape(indices)[0]\n\n    row_ids = sparse_indices.value_rowids()\n    sparse_indices = sparse_indices.flat_values[:, 1:]  # remove first value\n    sparse_indices = tf.concat(\n        [tf.expand_dims(row_ids, -1), sparse_indices], axis=1\n    )\n\n    sparse_values = sparse_values.flat_values\n    sparse_shape = (batch_dim,) + tuple(\n        sparse_wrapper.sparse.shape.as_list()[1:]\n    )\n    return tf.SparseTensor(sparse_indices, sparse_values, sparse_shape)\n\n\ndef can_slice_array(x):\n    
return (\n        x is None\n        or isinstance(x, ARRAY_TYPES)\n        or data_adapter_utils.is_tensorflow_tensor(x)\n        or data_adapter_utils.is_jax_array(x)\n        or data_adapter_utils.is_torch_tensor(x)\n        or data_adapter_utils.is_scipy_sparse(x)\n        or hasattr(x, \"__array__\")\n    )\n\n\ndef convert_to_sliceable(arrays, target_backend=None):\n    \"\"\"Convert a structure of arrays into `Sliceable` instances.\n\n    Args:\n        arrays: the arrays to convert.\n        target_backend: the target backend for the output:\n            - `None` indicates that `arrays` will be wrapped into `Sliceable`s\n              as-is without using a different representation. This is used by\n              `train_validation_split()`.\n            - `tensorflow` indicates that\n              `Sliceable.convert_to_tf_dataset_compatible` will be called. The\n              returned structure therefore contains arrays, not `Sliceable`s.\n            - `numpy`, `jax` or `torch` indicates that the arrays will\n              eventually be converted to this backend type after slicing. In\n              this case, the intermediary `Sliceable`s may use a different\n              representation from the input `arrays` for better performance.\n    Returns: the same structure with `Sliceable` instances or arrays.\n    \"\"\"\n\n    def convert_single_array(x):\n        if x is None:\n            return x\n\n        # Special case: handle np \"object\" arrays containing strings\n        if (\n            isinstance(x, np.ndarray)\n            and str(x.dtype) == \"object\"\n            and backend.backend() == \"tensorflow\"\n            and all(isinstance(e, str) for e in x)\n        ):\n            x = tf.convert_to_tensor(x, dtype=\"string\")\n\n        # Step 1. 
Determine which Sliceable class to use.\n        if isinstance(x, np.ndarray):\n            sliceable_class = NumpySliceable\n        elif data_adapter_utils.is_tensorflow_tensor(x):\n            if data_adapter_utils.is_tensorflow_ragged(x):\n                sliceable_class = TensorflowRaggedSliceable\n            elif data_adapter_utils.is_tensorflow_sparse(x):\n                sliceable_class = TensorflowSparseSliceable\n            else:\n                sliceable_class = TensorflowSliceable\n        elif data_adapter_utils.is_jax_array(x):\n            if data_adapter_utils.is_jax_sparse(x):\n                sliceable_class = JaxSparseSliceable\n            else:\n                x = np.asarray(x)\n                sliceable_class = NumpySliceable\n        elif data_adapter_utils.is_torch_tensor(x):\n            sliceable_class = TorchSliceable\n        elif pandas is not None and isinstance(x, pandas.DataFrame):\n            sliceable_class = PandasDataFrameSliceable\n        elif pandas is not None and isinstance(x, pandas.Series):\n            sliceable_class = PandasSeriesSliceable\n        elif data_adapter_utils.is_scipy_sparse(x):\n            sliceable_class = ScipySparseSliceable\n        elif hasattr(x, \"__array__\"):\n            x = np.asarray(x)\n            sliceable_class = NumpySliceable\n        else:\n            raise ValueError(\n                \"Expected a NumPy array, tf.Tensor, tf.RaggedTensor, \"\n                \"tf.SparseTensor, jax.np.ndarray, \"\n                \"jax.experimental.sparse.JAXSparse, torch.Tensor, \"\n                \"Pandas Dataframe, or Pandas Series. Received invalid input: \"\n                f\"{x} (of type {type(x)})\"\n            )\n\n        # Step 2. 
Normalize floats to floatx.\n        def is_non_floatx_float(dtype):\n            return (\n                dtype is not object\n                and backend.is_float_dtype(dtype)\n                and not backend.standardize_dtype(dtype) == backend.floatx()\n            )\n\n        cast_dtype = None\n        if pandas is not None and isinstance(x, pandas.DataFrame):\n            if any(is_non_floatx_float(d) for d in x.dtypes.values):\n                cast_dtype = backend.floatx()\n        else:\n            if is_non_floatx_float(x.dtype):\n                cast_dtype = backend.floatx()\n\n        if cast_dtype is not None:\n            x = sliceable_class.cast(x, cast_dtype)\n\n        # Step 3. Apply target backend specific logic and optimizations.\n        if target_backend is None:\n            return sliceable_class(x)\n\n        if target_backend == \"tensorflow\":\n            return sliceable_class.convert_to_tf_dataset_compatible(x)\n\n        # With dense arrays and JAX as output, it is faster to use NumPy as an\n        # intermediary representation, so wrap input array in a NumPy array,\n        # which should not use extra memory.\n        # See https://github.com/google/jax/issues/1276 for an explanation of\n        # why slicing a NumPy array is faster than slicing a JAX array.\n        if target_backend == \"jax\" and sliceable_class in (\n            TensorflowSliceable,\n            TorchSliceable,\n        ):\n            x = np.asarray(x)\n            sliceable_class = NumpySliceable\n\n        return sliceable_class(x)\n\n    return tree.map_structure(convert_single_array, arrays)\n\n\ndef train_validation_split(arrays, validation_split):\n    \"\"\"Split arrays into train and validation subsets in deterministic order.\n\n    The last part of data will become validation data.\n\n    Args:\n        arrays: Tensors to split. 
Allowed inputs are arbitrarily nested\n            structures of Tensors and NumPy arrays.\n        validation_split: Float between 0 and 1. The proportion of the dataset\n            to include in the validation split. The rest of the dataset will be\n            included in the training split.\n\n    Returns:\n        `(train_arrays, validation_arrays)`\n    \"\"\"\n\n    flat_arrays = tree.flatten(arrays)\n    unsplittable = [type(t) for t in flat_arrays if not can_slice_array(t)]\n    if unsplittable:\n        raise ValueError(\n            \"Argument `validation_split` is only supported \"\n            \"for tensors or NumPy arrays. \"\n            f\"Found incompatible type in the input: {unsplittable}\"\n        )\n\n    if all(t is None for t in flat_arrays):\n        return arrays, arrays\n\n    first_non_none = None\n    for t in flat_arrays:\n        if t is not None:\n            first_non_none = t\n            break\n\n    # Assumes all arrays have the same batch shape or are `None`.\n    batch_dim = int(first_non_none.shape[0])\n    split_at = int(math.floor(batch_dim * (1.0 - validation_split)))\n\n    if split_at == 0 or split_at == batch_dim:\n        raise ValueError(\n            f\"Training data contains {batch_dim} samples, which is not \"\n            \"sufficient to split it into a validation and training set as \"\n            f\"specified by `validation_split={validation_split}`. Either \"\n            \"provide more data, or a different value for the \"\n            \"`validation_split` argument.\"\n        )\n\n    def _split(t, start, end):\n        if t is None:\n            return t\n        return t[start:end]\n\n    sliceables = convert_to_sliceable(arrays)\n    train_arrays = tree.map_structure(\n        lambda x: _split(x, start=0, end=split_at), sliceables\n    )\n    val_arrays = tree.map_structure(\n        lambda x: _split(x, start=split_at, end=batch_dim), sliceables\n    )\n    return train_arrays, val_arrays\n"
  },
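For reference, the split arithmetic used by `train_validation_split` above can be exercised standalone. Below is a minimal NumPy-only sketch (the helper name `split_train_val` is illustrative, not part of the library); like the real function, it deterministically reserves the *last* `validation_split` fraction of samples for validation, with no shuffling:

```python
import math

import numpy as np


def split_train_val(arrays, validation_split):
    # Mirrors `train_validation_split`: the last `validation_split` fraction
    # of samples becomes validation data; order is deterministic (no shuffle).
    batch_dim = arrays[0].shape[0]
    split_at = int(math.floor(batch_dim * (1.0 - validation_split)))
    if split_at == 0 or split_at == batch_dim:
        raise ValueError(
            f"Cannot split {batch_dim} samples with "
            f"validation_split={validation_split}: one split would be empty."
        )
    train = tuple(a[:split_at] for a in arrays)
    val = tuple(a[split_at:] for a in arrays)
    return train, val


x = np.arange(10).reshape(10, 1).astype("float32")
y = np.arange(10)
(train_x, train_y), (val_x, val_y) = split_train_val((x, y), 0.2)
# split_at = floor(10 * 0.8) = 8 -> 8 training samples, 2 validation samples
```

Because the validation set is always the tail of the data, callers who need a random split must shuffle before calling.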
  {
    "path": "keras/src/trainers/data_adapters/data_adapter.py",
"content": "class DataAdapter:\n    \"\"\"Base class for input data adapters.\n\n    The purpose of a DataAdapter is to provide a unified interface to\n    iterate over input data provided in a variety of formats -- such as\n    NumPy arrays, tf.Tensors, tf.data.Datasets, Keras PyDatasets, etc.\n    \"\"\"\n\n    def get_numpy_iterator(self):\n        \"\"\"Get a Python iterable for the `DataAdapter` that yields NumPy\n        arrays.\n\n        Returns:\n            A Python iterator.\n        \"\"\"\n        raise NotImplementedError\n\n    def get_tf_dataset(self):\n        \"\"\"Get a `tf.data.Dataset` instance for the DataAdapter.\n\n        Note that the returned dataset does not repeat across epochs, so the\n        caller might need to create a new iterator for the same dataset at the\n        beginning of each epoch. This behavior might change in the future.\n\n        Returns:\n            A `tf.data.Dataset`. The caller might use the dataset in different\n            contexts, e.g. `iter(dataset)` in eager mode to get the values\n            directly, or, in graph mode, providing the iterator tensor to a\n            Keras model function.\n        \"\"\"\n        raise NotImplementedError\n\n    def get_jax_iterator(self):\n        \"\"\"Get a Python iterable for the `DataAdapter` that yields arrays\n        that can be fed to JAX. 
NumPy arrays are preferred for performance.\n\n        Returns:\n            A Python iterator.\n        \"\"\"\n        raise NotImplementedError\n\n    def get_torch_dataloader(self):\n        \"\"\"Get a Torch `DataLoader` for the `DataAdapter`.\n\n        Returns:\n            A Torch `DataLoader`.\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def num_batches(self):\n        \"\"\"Return the size (number of batches) for the dataset created.\n\n        For certain types of data input, the number of batches is known, e.g.\n        for NumPy data, the size is the same as\n        (number_of_elements / batch_size). Whereas for a dataset or a Python\n        generator, the size is unknown since it may or may not have an end\n        state.\n\n        Returns:\n            int, the number of batches for the dataset, or None if it is\n            unknown. The caller could use this to control the training loop,\n            show a progress bar, or handle an unexpected StopIteration error.\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def batch_size(self):\n        \"\"\"Return the batch size of the dataset created.\n\n        For certain types of data input, the batch size is known, and even\n        required, e.g. for NumPy arrays. 
Whereas for a dataset, the batch\n        size is unknown unless we take a peek.\n\n        Returns:\n          int, the batch size of the dataset, or None if it is unknown.\n        \"\"\"\n        raise NotImplementedError\n\n    @property\n    def has_partial_batch(self):\n        \"\"\"Whether the dataset has a partial batch at the end.\"\"\"\n        raise NotImplementedError\n\n    @property\n    def partial_batch_size(self):\n        \"\"\"The size of the final partial batch of the dataset.\n\n        Will return None if has_partial_batch is False or batch_size is None.\n        \"\"\"\n        raise NotImplementedError\n\n    def on_epoch_begin(self):\n        \"\"\"A hook called before each epoch.\"\"\"\n        pass\n\n    def on_epoch_end(self):\n        \"\"\"A hook called after each epoch.\"\"\"\n        pass\n"
  },
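A minimal concrete adapter helps make the `DataAdapter` contract above tangible. The class below is an illustrative sketch only (not the real Keras `ArrayDataAdapter`), showing how `num_batches`, `batch_size`, `has_partial_batch` and `partial_batch_size` relate for in-memory NumPy data:

```python
import math

import numpy as np


class SimpleArrayAdapter:
    # Illustrative sketch: batches a NumPy array and reports the metadata
    # described by the `DataAdapter` contract.
    def __init__(self, x, batch_size):
        self._x = x
        self._batch_size = batch_size

    def get_numpy_iterator(self):
        for i in range(self.num_batches):
            start = i * self._batch_size
            yield self._x[start : start + self._batch_size]

    @property
    def num_batches(self):
        # Known for in-memory data: ceil(num_samples / batch_size).
        return math.ceil(self._x.shape[0] / self._batch_size)

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def has_partial_batch(self):
        return self._x.shape[0] % self._batch_size != 0

    @property
    def partial_batch_size(self):
        # None when the last batch is full, per the contract above.
        return self._x.shape[0] % self._batch_size or None


adapter = SimpleArrayAdapter(np.zeros((34, 4)), batch_size=16)
shapes = [b.shape for b in adapter.get_numpy_iterator()]
# → [(16, 4), (16, 4), (2, 4)]; num_batches=3, partial_batch_size=2
```

A real adapter would also implement `get_tf_dataset`, `get_jax_iterator` and `get_torch_dataloader` for the same data.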
  {
    "path": "keras/src/trainers/data_adapters/data_adapter_utils.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\n\nNUM_BATCHES_FOR_TENSOR_SPEC = 2\n\n\n@keras_export(\"keras.utils.unpack_x_y_sample_weight\")\ndef unpack_x_y_sample_weight(data):\n    \"\"\"Unpacks user-provided data tuple.\n\n    This is a convenience utility to be used when overriding\n    `Model.train_step`, `Model.test_step`, or `Model.predict_step`.\n    This utility makes it easy to support data of the form `(x,)`,\n    `(x, y)`, or `(x, y, sample_weight)`.\n\n    Example:\n\n    >>> features_batch = ops.ones((10, 5))\n    >>> labels_batch = ops.zeros((10, 5))\n    >>> data = (features_batch, labels_batch)\n    >>> # `y` and `sample_weight` will default to `None` if not provided.\n    >>> x, y, sample_weight = unpack_x_y_sample_weight(data)\n    >>> sample_weight is None\n    True\n\n    Args:\n        data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.\n\n    Returns:\n        The unpacked tuple, with `None`s for `y` and `sample_weight` if they are\n        not provided.\n    \"\"\"\n    if isinstance(data, list):\n        data = tuple(data)\n    if not isinstance(data, tuple):\n        return (data, None, None)\n    elif len(data) == 1:\n        return (data[0], None, None)\n    elif len(data) == 2:\n        return (data[0], data[1], None)\n    elif len(data) == 3:\n        return (data[0], data[1], data[2])\n    error_msg = (\n        \"Data is expected to be in format `x`, `(x,)`, `(x, y)`, \"\n        f\"or `(x, y, sample_weight)`, found: {data}\"\n    )\n    raise ValueError(error_msg)\n\n\n@keras_export(\"keras.utils.pack_x_y_sample_weight\")\ndef pack_x_y_sample_weight(x, y=None, sample_weight=None):\n    \"\"\"Packs user-provided data into a tuple.\n\n    This is a convenience utility for packing data into the tuple formats\n    that `Model.fit()` uses.\n\n    Example:\n\n    >>> x = ops.ones((10, 1))\n    
>>> data = pack_x_y_sample_weight(x)\n    >>> isinstance(data, ops.Tensor)\n    True\n    >>> y = ops.ones((10, 1))\n    >>> data = pack_x_y_sample_weight(x, y)\n    >>> isinstance(data, tuple)\n    True\n    >>> x, y = data\n\n    Args:\n        x: Features to pass to `Model`.\n        y: Ground-truth targets to pass to `Model`.\n        sample_weight: Sample weight for each element.\n\n    Returns:\n        Tuple in the format used in `Model.fit()`.\n    \"\"\"\n    if y is None:\n        # For single x-input, we do no tuple wrapping since in this case\n        # there is no ambiguity. This also makes NumPy and Dataset\n        # consistent in that the user does not have to wrap their Dataset\n        # data in an unnecessary tuple.\n        if not isinstance(x, (tuple, list)):\n            return x\n        else:\n            return (x,)\n    elif sample_weight is None:\n        return (x, y)\n    else:\n        return (x, y, sample_weight)\n\n\ndef list_to_tuple(maybe_list):\n    \"\"\"Datasets will stack any list of tensors, so we convert them to tuples.\"\"\"\n    if isinstance(maybe_list, list):\n        return tuple(maybe_list)\n    return maybe_list\n\n\ndef check_data_cardinality(data):\n    num_samples = set(\n        int(i.shape[0]) for i in tree.flatten(data) if i is not None\n    )\n    if len(num_samples) > 1:\n        msg = (\n            \"Data cardinality is ambiguous. 
\"\n            \"Make sure all arrays contain the same number of samples.\\n\"\n        )\n        for label, single_data in zip([\"x\", \"y\", \"sample_weight\"], data):\n            sizes = \", \".join(\n                str(i.shape[0]) for i in tree.flatten(single_data)\n            )\n            msg += f\"'{label}' sizes: {sizes}\\n\"\n        raise ValueError(msg)\n\n\ndef class_weight_to_sample_weights(y, class_weight):\n    # Convert to numpy to ensure consistent handling of operations\n    # (e.g., np.round()) across frameworks like TensorFlow, JAX, and PyTorch\n\n    y_numpy = ops.convert_to_numpy(y)\n    sample_weight = np.ones(shape=(y_numpy.shape[0],), dtype=backend.floatx())\n    if len(y_numpy.shape) > 1:\n        if y_numpy.shape[-1] != 1:\n            y_numpy = np.argmax(y_numpy, axis=-1)\n        else:\n            y_numpy = np.squeeze(y_numpy, axis=-1)\n    y_numpy = np.round(y_numpy).astype(\"int32\")\n\n    for i in range(y_numpy.shape[0]):\n        sample_weight[i] = class_weight.get(int(y_numpy[i]), 1.0)\n    return sample_weight\n\n\ndef get_keras_tensor_spec(batches):\n    \"\"\"Return the KerasTensor spec for a list of batches.\n\n    The spec is represented using `KerasTensor`, which can handle dense, sparse\n    or ragged tensors.\n\n    Args:\n        batches: list of structures of tensors. The structures must be\n            identical, but the shape at each leaf may be different.\n\n    Returns:\n        A nested structure of `KerasTensor`.\n    \"\"\"\n\n    def get_single_tensor_spec(*tensors):\n        x = tensors[0]\n        if not hasattr(x, \"shape\"):\n            # Try to convert to a numpy array.\n            x = np.array(x)\n        rank = len(x.shape)\n        if rank < 1:\n            raise ValueError(\n                \"When passing a dataset to a Keras model, the arrays must \"\n                f\"be at least rank 1. 
Received: {x} of rank {len(x.shape)}.\"\n            )\n        for t in tensors:\n            if len(t.shape) != rank:\n                raise ValueError(\n                    \"When passing a dataset to a Keras model, the \"\n                    \"corresponding arrays in each batch must have the same \"\n                    f\"rank. Received: {x} and {t}\"\n                )\n        shape = []\n        # Merge shapes: go through each dimension one by one and keep the\n        # common values\n        for dims in zip(*[list(x.shape) for x in tensors]):\n            dims_set = set(dims)\n            shape.append(dims_set.pop() if len(dims_set) == 1 else None)\n\n        dtype = backend.standardize_dtype(x.dtype)\n        if is_tensorflow_ragged(x):\n            return backend.KerasTensor(\n                shape=shape,\n                dtype=dtype,\n                ragged=True,\n                ragged_rank=x.ragged_rank,\n                row_splits_dtype=x.row_splits.dtype,\n            )\n        if is_tensorflow_sparse(x) or is_scipy_sparse(x) or is_jax_sparse(x):\n            return backend.KerasTensor(shape=shape, dtype=dtype, sparse=True)\n        else:\n            return backend.KerasTensor(shape=shape, dtype=dtype)\n\n    return tree.map_structure(\n        get_single_tensor_spec, *batches, none_is_leaf=False\n    )\n\n\ndef convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):\n    \"\"\"Convert a KerasTensor to a TensorSpec.\n\n    Args:\n        keras_tensor: A KerasTensor instance.\n        batch_axis_to_none: If `True`, the batch axis of the returned\n            tensor spec will be set to None. 
Defaults to `True`.\n    \"\"\"\n    from keras.src.utils.module_utils import tensorflow as tf\n\n    if keras_tensor is None:\n        return tf.OptionalSpec(None)\n    if not isinstance(keras_tensor, backend.KerasTensor):\n        raise TypeError(\n            f\"Expected a KerasTensor, but got {keras_tensor} of type \"\n            f\"{type(keras_tensor)}.\"\n        )\n    shape = list(keras_tensor.shape)\n    if batch_axis_to_none:\n        shape[0] = None\n    if keras_tensor.ragged:\n        return tf.RaggedTensorSpec(\n            shape=shape,\n            dtype=keras_tensor.dtype,\n            ragged_rank=keras_tensor.ragged_rank,\n            row_splits_dtype=keras_tensor.row_splits_dtype,\n        )\n    elif keras_tensor.sparse:\n        return tf.SparseTensorSpec(shape=shape, dtype=keras_tensor.dtype)\n    else:\n        return tf.TensorSpec(shape=shape, dtype=keras_tensor.dtype)\n\n\ndef get_tensor_spec(batches):\n    \"\"\"Return the common tensor spec for a list of batches.\n\n    The spec is represented using `tf.TensorSpec`, `tf.SparseTensorSpec` and\n    `tf.RaggedTensorSpec`.\n\n    Args:\n        batches: list of structures of tensors. 
The structures must be\n            identical, but the shape at each leaf may be different.\n\n    Returns:\n        A common tensor spec.\n    \"\"\"\n    tensor_specs = get_keras_tensor_spec(batches)\n    return tree.map_structure(convert_to_tf_tensor_spec, tensor_specs)\n\n\ndef get_jax_iterator(iterable):\n    import jax\n    import jax.experimental.sparse as jax_sparse\n\n    def convert_to_jax_compatible(x):\n        if isinstance(x, (jax.Array, jax_sparse.JAXSparse, np.ndarray)):\n            return x\n        elif is_scipy_sparse(x):\n            return scipy_sparse_to_jax_sparse(x)\n        elif is_tensorflow_sparse(x):\n            return tf_sparse_to_jax_sparse(x)\n        else:\n            return np.asarray(x)\n\n    for batch in iterable:\n        yield tree.map_structure(\n            convert_to_jax_compatible, batch, none_is_leaf=False\n        )\n\n\ndef get_numpy_iterator(iterable):\n    def convert_to_numpy(x):\n        if not isinstance(x, np.ndarray):\n            # Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`,\n            # `torch.Tensor`, as well as any other tensor-like object that\n            # has added numpy support.\n            if hasattr(x, \"__array__\"):\n                if is_torch_tensor(x):\n                    x = x.cpu()\n                x = np.asarray(x)\n        return x\n\n    for batch in iterable:\n        yield tree.map_structure(convert_to_numpy, batch, none_is_leaf=False)\n\n\ndef get_torch_dataloader(iterable):\n    import torch.utils.data as torch_data\n\n    from keras.src.backend.torch.core import convert_to_tensor\n\n    class ConverterIterableDataset(torch_data.IterableDataset):\n        def __init__(self, iterable):\n            self.iterable = iterable\n\n        def __iter__(self):\n            for batch in self.iterable:\n                yield tree.map_structure(\n                    convert_to_tensor, batch, none_is_leaf=False\n                )\n\n    dataset = 
ConverterIterableDataset(iterable)\n    # `batch_size=None` indicates that we should not re-batch\n    return torch_data.DataLoader(dataset, batch_size=None)\n\n\ndef is_tensorflow_tensor(value):\n    if hasattr(value, \"__class__\"):\n        if value.__class__.__name__ in (\"RaggedTensor\", \"SparseTensor\"):\n            return \"tensorflow.python.\" in str(value.__class__.__module__)\n        for parent in value.__class__.__mro__:\n            if parent.__name__ == \"Tensor\" and \"tensorflow.python.\" in str(\n                parent.__module__\n            ):\n                return True\n    return False\n\n\ndef is_tensorflow_ragged(value):\n    if hasattr(value, \"__class__\"):\n        return (\n            value.__class__.__name__ == \"RaggedTensor\"\n            and \"tensorflow.python.\" in str(value.__class__.__module__)\n        )\n    return False\n\n\ndef is_tensorflow_sparse(value):\n    if hasattr(value, \"__class__\"):\n        return (\n            value.__class__.__name__ == \"SparseTensor\"\n            and \"tensorflow.python.\" in str(value.__class__.__module__)\n        )\n    return False\n\n\ndef is_jax_array(value):\n    if hasattr(value, \"__class__\"):\n        for parent in value.__class__.__mro__:\n            if parent.__name__ == \"Array\" and str(parent.__module__) == \"jax\":\n                return True\n    return is_jax_sparse(value)  # JAX sparse arrays do not extend jax.Array\n\n\ndef is_jax_sparse(value):\n    if hasattr(value, \"__class__\"):\n        return str(value.__class__.__module__).startswith(\n            \"jax.experimental.sparse\"\n        )\n    return False\n\n\ndef is_torch_tensor(value):\n    if hasattr(value, \"__class__\"):\n        for parent in value.__class__.__mro__:\n            if parent.__name__ == \"Tensor\" and str(parent.__module__).endswith(\n                \"torch\"\n            ):\n                return True\n    return False\n\n\ndef is_scipy_sparse(x):\n    return 
str(x.__class__.__module__).startswith(\"scipy.sparse\") and hasattr(\n        x, \"tocoo\"\n    )\n\n\ndef scipy_sparse_to_tf_sparse(x):\n    from keras.src.utils.module_utils import tensorflow as tf\n\n    coo = x.tocoo()\n    indices = np.concatenate(\n        (np.expand_dims(coo.row, 1), np.expand_dims(coo.col, 1)), axis=1\n    )\n    return tf.SparseTensor(indices, coo.data, coo.shape)\n\n\ndef scipy_sparse_to_jax_sparse(x):\n    import jax\n    import jax.experimental.sparse as jax_sparse\n\n    with jax.default_device(jax.local_devices(backend=\"cpu\")[0]):\n        return jax_sparse.BCOO.from_scipy_sparse(x)\n\n\ndef tf_sparse_to_jax_sparse(x):\n    import jax\n    import jax.experimental.sparse as jax_sparse\n\n    values = np.asarray(x.values)\n    indices = np.asarray(x.indices)\n    with jax.default_device(jax.local_devices(backend=\"cpu\")[0]):\n        return jax_sparse.BCOO((values, indices), shape=x.shape)\n\n\ndef jax_sparse_to_tf_sparse(x):\n    from keras.src.utils.module_utils import tensorflow as tf\n\n    return tf.SparseTensor(x.indices, x.data, x.shape)\n"
  },
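The weighting rule implemented by `class_weight_to_sample_weights` can be summarized with a standalone NumPy sketch (the function below is illustrative, mirroring the logic above for plain NumPy inputs): one-hot labels are reduced with `argmax`, `(n, 1)` columns are squeezed, and classes missing from the mapping keep a default weight of 1.0.

```python
import numpy as np


def class_weights_to_sample_weights(y, class_weight):
    # Standalone mirror of the logic above, for plain NumPy inputs.
    y = np.asarray(y)
    if y.ndim > 1:
        # One-hot labels: take the argmax; (n, 1) column vectors: squeeze.
        y = np.argmax(y, axis=-1) if y.shape[-1] != 1 else np.squeeze(y, -1)
    y = np.round(y).astype("int32")
    # Classes absent from `class_weight` keep the default weight of 1.0.
    return np.array([class_weight.get(int(c), 1.0) for c in y], "float32")


weights = {0: 1.0, 1: 2.0, 2: 3.0}
flat = class_weights_to_sample_weights([0, 1, 0, 2], weights)
# → array([1., 2., 1., 3.], dtype=float32)
one_hot = class_weights_to_sample_weights(np.eye(3)[[0, 2, 1]], weights)
# → array([1., 3., 2.], dtype=float32)
```

The `np.round` step is what lets near-integer float labels (e.g. from a framework conversion) still map onto integer class keys.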
  {
    "path": "keras/src/trainers/data_adapters/data_adapter_utils_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.trainers.data_adapters.data_adapter_utils import (\n    class_weight_to_sample_weights,\n)\n\n\nclass TestClassWeightToSampleWeights(testing.TestCase):\n    @parameterized.named_parameters(\n        [\n            # Simple case, where y is flat\n            (\n                \"simple_class_labels\",\n                np.array([0, 1, 0, 2]),\n                {0: 1.0, 1: 2.0, 2: 3.0},\n                np.array([1.0, 2.0, 1.0, 3.0]),\n            ),\n            # Testing with one-hot encoded labels,\n            # so basically the argmax statement\n            (\n                \"one_hot_encoded_labels\",\n                np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]]),\n                {0: 1.0, 1: 2.0, 2: 3.0},\n                np.array([1.0, 2.0, 1.0, 3.0]),\n            ),\n            # 3 is not mapped, so it's assigned the default weight (1)\n            (\n                \"unmapped_class\",\n                np.array([0, 3, 0, 2]),\n                {0: 1.0, 1: 2.0, 2: 3.0},\n                np.array([1.0, 1.0, 1.0, 3.0]),\n            ),\n            (\n                \"multi_dimensional_input\",\n                np.array([[0], [1], [0], [2]]),\n                {0: 1.0, 1: 2.0, 2: 3.0},\n                np.array([1.0, 2.0, 1.0, 3.0]),\n            ),\n            (\n                \"all_unmapped\",\n                np.array([0, 1, 0, 2]),\n                {},\n                np.array([1.0, 1.0, 1.0, 1.0]),\n            ),\n        ]\n    )\n    def test_class_weight_to_sample_weights(self, y, class_weight, expected):\n        self.assertAllClose(\n            class_weight_to_sample_weights(y, class_weight), expected\n        )\n\n    @pytest.mark.skipif(backend.backend() != \"torch\", reason=\"torch only\")\n    def test_class_weight_to_sample_weights_torch_specific(self):\n        
import torch\n\n        y = torch.from_numpy(np.array([0, 1, 0, 2]))\n        self.assertAllClose(\n            class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0}),\n            np.array([1.0, 2.0, 1.0, 3.0]),\n        )\n        y_one_hot = torch.from_numpy(\n            np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]])\n        )\n        self.assertAllClose(\n            class_weight_to_sample_weights(y_one_hot, {0: 1.0, 1: 2.0, 2: 3.0}),\n            np.array([1.0, 2.0, 1.0, 3.0]),\n        )\n\n    @pytest.mark.skipif(backend.backend() != \"jax\", reason=\"jax only\")\n    def test_class_weight_to_sample_weights_jax_specific(self):\n        import jax\n\n        y = jax.numpy.asarray(np.array([0, 1, 0, 2]))\n        self.assertAllClose(\n            class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0}),\n            np.array([1.0, 2.0, 1.0, 3.0]),\n        )\n        y_one_hot = jax.numpy.asarray(\n            np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]])\n        )\n        self.assertAllClose(\n            class_weight_to_sample_weights(y_one_hot, {0: 1.0, 1: 2.0, 2: 3.0}),\n            np.array([1.0, 2.0, 1.0, 3.0]),\n        )\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\", reason=\"tensorflow only\"\n    )\n    def test_class_weight_to_sample_weights_tf_specific(self):\n        import tensorflow as tf\n\n        y = tf.convert_to_tensor(np.array([0, 1, 0, 2]))\n        self.assertAllClose(\n            class_weight_to_sample_weights(y, {0: 1.0, 1: 2.0, 2: 3.0}),\n            np.array([1.0, 2.0, 1.0, 3.0]),\n        )\n        y_one_hot = tf.convert_to_tensor(\n            np.array([[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1]])\n        )\n        self.assertAllClose(\n            class_weight_to_sample_weights(y_one_hot, {0: 1.0, 1: 2.0, 2: 3.0}),\n            np.array([1.0, 2.0, 1.0, 3.0]),\n        )\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/generator_data_adapter.py",
    "content": "import itertools\n\nfrom keras.src import tree\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.trainers.data_adapters.data_adapter import DataAdapter\n\n\nclass GeneratorDataAdapter(DataAdapter):\n    \"\"\"Adapter for Python generators.\"\"\"\n\n    def __init__(self, generator):\n        first_batches, generator = peek_and_restore(generator)\n        self.generator = generator\n        self._first_batches = first_batches\n        self._output_signature = None\n        if not isinstance(first_batches[0], tuple):\n            raise ValueError(\n                \"When passing a Python generator to a Keras model, \"\n                \"the generator must return a tuple, either \"\n                \"(input,) or (inputs, targets) or \"\n                \"(inputs, targets, sample_weights). \"\n                f\"Received: {first_batches[0]}\"\n            )\n\n    def get_numpy_iterator(self):\n        return data_adapter_utils.get_numpy_iterator(self.generator())\n\n    def get_jax_iterator(self):\n        return data_adapter_utils.get_jax_iterator(self.generator())\n\n    def get_tf_dataset(self):\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        def convert_to_tf(x, spec):\n            if x is None:\n                return tf.experimental.Optional.empty(None)\n            if data_adapter_utils.is_scipy_sparse(x):\n                x = data_adapter_utils.scipy_sparse_to_tf_sparse(x)\n            elif data_adapter_utils.is_jax_sparse(x):\n                x = data_adapter_utils.jax_sparse_to_tf_sparse(x)\n            if not spec.shape.is_compatible_with(x.shape):\n                raise TypeError(\n                    f\"Generator yielded an element of shape {x.shape} where \"\n                    f\"an element of shape {spec.shape} was expected. Your \"\n                    \"generator provides tensors with variable input \"\n                    \"dimensions other than the batch size. 
Make sure that the \"\n                    \"generator's first two batches do not have the same \"\n                    \"dimension value wherever there is a variable input \"\n                    \"dimension.\"\n                )\n            return x\n\n        def get_tf_iterator():\n            for batch in self.generator():\n                batch = tree.map_structure(\n                    convert_to_tf, batch, self._output_signature\n                )\n                yield batch\n\n        if self._output_signature is None:\n            self._output_signature = data_adapter_utils.get_tensor_spec(\n                self._first_batches\n            )\n        ds = tf.data.Dataset.from_generator(\n            get_tf_iterator,\n            output_signature=self._output_signature,\n        )\n        ds = ds.prefetch(tf.data.AUTOTUNE)\n        return ds\n\n    def get_torch_dataloader(self):\n        return data_adapter_utils.get_torch_dataloader(self.generator())\n\n    @property\n    def num_batches(self):\n        return None\n\n    @property\n    def batch_size(self):\n        return None\n\n\ndef peek_and_restore(generator):\n    batches = list(\n        itertools.islice(\n            generator, data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC\n        )\n    )\n    return batches, lambda: itertools.chain(batches, generator)\n"
  },
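The `peek_and_restore` helper above relies on a small itertools idiom worth isolating: consume the first few batches for inspection, then rebuild an iterable that replays them ahead of the partially consumed generator. A standalone sketch (the `n` parameter is illustrative; the adapter hard-codes `NUM_BATCHES_FOR_TENSOR_SPEC`):

```python
import itertools


def peek_and_restore(generator, n=2):
    # Consume the first `n` items so they can be inspected, then return an
    # iterable factory that replays them before the rest of the generator.
    batches = list(itertools.islice(generator, n))
    return batches, lambda: itertools.chain(batches, generator)


gen = (i * i for i in range(5))
first, restored = peek_and_restore(gen)
assert first == [0, 1]
assert list(restored()) == [0, 1, 4, 9, 16]
```

Note that the restored iterable is still single-use: once the underlying generator is exhausted, only the cached head batches can be replayed.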
  {
    "path": "keras/src/trainers/data_adapters/generator_data_adapter_test.py",
    "content": "import math\n\nimport jax\nimport jax.experimental.sparse as jax_sparse\nimport numpy as np\nimport pytest\nimport scipy\nimport tensorflow as tf\nimport torch\nfrom absl.testing import parameterized\nfrom jax import numpy as jnp\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.testing.test_utils import named_product\nfrom keras.src.trainers.data_adapters import generator_data_adapter\n\n\ndef example_generator(x, y, sample_weight=None, batch_size=32):\n    def make():\n        for i in range(math.ceil(len(x) / batch_size)):\n            low = i * batch_size\n            high = min(low + batch_size, len(x))\n            batch_x = x[low:high]\n            batch_y = y[low:high]\n            if sample_weight is not None:\n                yield batch_x, batch_y, sample_weight[low:high]\n            else:\n                yield batch_x, batch_y\n\n    return make\n\n\nclass GeneratorDataAdapterTest(testing.TestCase):\n    @parameterized.named_parameters(\n        named_product(\n            [\n                {\"testcase_name\": \"use_weight\", \"use_sample_weight\": True},\n                {\"testcase_name\": \"no_weight\", \"use_sample_weight\": False},\n            ],\n            generator_type=[\"np\", \"tf\", \"jax\", \"torch\"],\n        )\n    )\n    def test_basic_flow(self, use_sample_weight, generator_type):\n        x = np.random.random((34, 4)).astype(\"float32\")\n        y = np.array([[i, i] for i in range(34)], dtype=\"float32\")\n        sw = np.random.random((34,)).astype(\"float32\")\n        if generator_type == \"tf\":\n            x, y, sw = tf.constant(x), tf.constant(y), tf.constant(sw)\n        elif generator_type == \"jax\":\n            x, y, sw = jnp.array(x), jnp.array(y), jnp.array(sw)\n        elif generator_type == \"torch\":\n            x, y, sw = (\n                torch.as_tensor(x),\n                torch.as_tensor(y),\n                torch.as_tensor(sw),\n            )\n        if not 
use_sample_weight:\n            sw = None\n        make_generator = example_generator(\n            x,\n            y,\n            sample_weight=sw,\n            batch_size=16,\n        )\n\n        adapter = generator_data_adapter.GeneratorDataAdapter(make_generator())\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            expected_class = tf.Tensor\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n            expected_class = (\n                jax.Array if generator_type == \"jax\" else np.ndarray\n            )\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n            expected_class = torch.Tensor\n        else:\n            it = adapter.get_numpy_iterator()\n            expected_class = np.ndarray\n\n        sample_order = []\n        for i, batch in enumerate(it):\n            if use_sample_weight:\n                self.assertEqual(len(batch), 3)\n                bx, by, bsw = batch\n            else:\n                self.assertEqual(len(batch), 2)\n                bx, by = batch\n            self.assertIsInstance(bx, expected_class)\n            self.assertIsInstance(by, expected_class)\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertContainsExactSubsequence(str(bx.dtype), \"float32\")\n            if i < 2:\n                self.assertEqual(bx.shape, (16, 4))\n                self.assertEqual(by.shape, (16, 2))\n            else:\n                self.assertEqual(bx.shape, (2, 4))\n                self.assertEqual(by.shape, (2, 2))\n            if use_sample_weight:\n                self.assertIsInstance(bsw, expected_class)\n            for j in range(by.shape[0]):\n                sample_order.append(by[j, 0])\n        self.assertAllClose(sample_order, list(range(34)))\n\n    def test_with_different_shapes(self):\n        def generator():\n            yield np.ones([16, 4], \"float32\"), 
np.ones([16, 2], \"float32\")\n            yield np.ones([16, 5], \"float32\"), np.ones([16, 2], \"float32\")\n            yield np.ones([2, 6], \"float32\"), np.ones([2, 2], \"float32\")\n\n        adapter = generator_data_adapter.GeneratorDataAdapter(generator())\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n        else:\n            it = adapter.get_numpy_iterator()\n\n        for i, batch in enumerate(it):\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertContainsExactSubsequence(str(bx.dtype), \"float32\")\n            if i == 0:\n                self.assertEqual(bx.shape, (16, 4))\n                self.assertEqual(by.shape, (16, 2))\n            elif i == 1:\n                self.assertEqual(bx.shape, (16, 5))\n                self.assertEqual(by.shape, (16, 2))\n            else:\n                self.assertEqual(bx.shape, (2, 6))\n                self.assertEqual(by.shape, (2, 2))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"tf.data.Dataset specific behavior\",\n    )\n    def test_with_unexpected_shapes(self):\n        def generator():\n            yield np.ones([16, 4], \"float32\"), np.ones([16, 2], \"float32\")\n            yield np.ones([16, 5], \"float32\"), np.ones([16, 2], \"float32\")\n            yield np.ones([16, 6], \"float32\"), np.ones([16, 3], \"float32\")\n\n        adapter = generator_data_adapter.GeneratorDataAdapter(generator())\n\n        it = iter(adapter.get_tf_dataset())\n        next(it)\n        next(it)\n        # note that Tensorflow wraps the TypeError in an InvalidArgumentError.\n        with self.assertRaisesRegex(\n            
tf.errors.InvalidArgumentError,\n            \"TypeError:.* shape \\\\(16, 3\\\\).* shape \\\\(None, 2\\\\) was expected\"\n            \".*first two batches\",\n        ):\n            next(it)\n\n    @parameterized.named_parameters(\n        named_product(generator_type=[\"tf\", \"jax\", \"scipy\"])\n    )\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=\"Backend does not support sparse tensors\",\n    )\n    def test_sparse_tensors(self, generator_type):\n        if generator_type == \"tf\":\n            x = tf.SparseTensor([[0, 0], [1, 2]], [1.0, 2.0], (2, 4))\n            y = tf.SparseTensor([[0, 0], [1, 1]], [3.0, 4.0], (2, 2))\n        elif generator_type == \"jax\":\n            x = jax_sparse.BCOO(([1.0, 2.0], [[0, 0], [1, 2]]), shape=(2, 4))\n            y = jax_sparse.BCOO(([3.0, 4.0], [[0, 0], [1, 1]]), shape=(2, 2))\n        elif generator_type == \"scipy\":\n            x = scipy.sparse.coo_matrix(([1.0, 2.0], ([0, 1], [0, 2])), (2, 4))\n            y = scipy.sparse.coo_matrix(([3.0, 4.0], ([0, 1], [0, 1])), (2, 2))\n\n        def generate():\n            for _ in range(4):\n                yield x, y\n\n        adapter = generator_data_adapter.GeneratorDataAdapter(generate())\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            expected_class = tf.SparseTensor\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n            expected_class = jax_sparse.BCOO\n\n        for batch in it:\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertIsInstance(bx, expected_class)\n            self.assertIsInstance(by, expected_class)\n            self.assertEqual(bx.shape, (2, 4))\n            self.assertEqual(by.shape, (2, 2))\n\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_RAGGED_TENSORS,\n        reason=\"Backend does not support ragged tensors\",\n    )\n    def 
test_ragged_tensors(self):\n        x = tf.ragged.constant(\n            [[[0.0, 1.0]], [[2.0, 3.0], [4.0, 5.0]]], ragged_rank=1\n        )\n        y = tf.ragged.constant(\n            [[[0.0, 1.0]], [[0.0, 1.0], [0.0, 1.0]]], ragged_rank=1\n        )\n\n        def generate():\n            for _ in range(4):\n                yield x, y\n\n        adapter = generator_data_adapter.GeneratorDataAdapter(generate())\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            expected_class = tf.RaggedTensor\n\n        for batch in it:\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertIsInstance(bx, expected_class)\n            self.assertIsInstance(by, expected_class)\n            self.assertEqual(bx.shape, (2, None, 2))\n            self.assertEqual(by.shape, (2, None, 2))\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/grain_dataset_adapter.py",
    "content": "import itertools\n\nimport numpy as np\n\nfrom keras.src import tree\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.trainers.data_adapters.data_adapter import DataAdapter\nfrom keras.src.utils.module_utils import grain\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\nclass GrainDatasetAdapter(DataAdapter):\n    \"\"\"Adapter that handles `grain.DataLoader`, `grain.MapDataset` and\n    `grain.IterDataset`.\n    \"\"\"\n\n    def __init__(self, dataset):\n        \"\"\"Initialize the GrainDatasetAdapter.\n\n        Args:\n            dataset: A Grain dataset instance. Must be one of\n                `grain.DataLoader`, `grain.MapDataset`, or `grain.IterDataset`.\n        \"\"\"\n\n        if not isinstance(\n            dataset, (grain.MapDataset, grain.IterDataset, grain.DataLoader)\n        ):\n            raise ValueError(\n                \"Expected `dataset` to be a grain.MapDataset, \"\n                \"grain.IterDataset or grain.DataLoader. 
\"\n                f\"Received: {dataset} of type {type(dataset)}\"\n            )\n\n        self._dataset = dataset\n\n        batch_size, output_signature = self._get_dataset_info(dataset)\n        self._batch_size = batch_size\n        self._output_signature = output_signature\n        self._output_tf_signature = None\n\n    def _get_dataset_info(self, dataset):\n        \"\"\"Get the `batch_size` and `output_signature` from the dataset.\n\n        We use a small list of batches to infer the `batch_size` and\n        `output_signature`.\n        \"\"\"\n        batches = list(\n            itertools.islice(\n                dataset, data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC\n            )\n        )\n        output_signature = data_adapter_utils.get_keras_tensor_spec(batches)\n        flat_output_signature = tree.flatten(output_signature)\n        batch_size = flat_output_signature[0].shape[0]\n        if batch_size is not None:\n            batch_size = int(batch_size)\n        return batch_size, output_signature\n\n    def get_numpy_iterator(self):\n        # Workaround for internal change in Grain which isn't a part of a\n        # release yet.\n        # TODO(abheesht17): Remove this after the next Grain release.\n        try:\n            from grain._src.python.shared_memory_array import (\n                SharedMemoryArrayMetadata,\n            )\n        except ImportError:\n            from grain._src.python.ipc.shared_memory_array import (\n                SharedMemoryArrayMetadata,\n            )\n\n        def convert_to_numpy(x):\n            if isinstance(x, (np.ndarray, SharedMemoryArrayMetadata)):\n                return x\n            else:\n                # Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`,\n                # `torch.Tensor`, as well as any other tensor-like object that\n                # has added numpy support.\n                if hasattr(x, \"__array__\"):\n                    if 
data_adapter_utils.is_torch_tensor(x):\n                        x = x.cpu()\n                    x = np.asarray(x)\n                return x\n\n        class ConvertToNumpy(grain.transforms.Map):\n            def map(self, x):\n                return tree.map_structure(\n                    convert_to_numpy, x, none_is_leaf=False\n                )\n\n        if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)):\n            dataset = self._dataset.map(ConvertToNumpy())\n        else:\n            # Instantiate a new `DataLoader`.\n            dataset = grain.DataLoader(\n                data_source=self._dataset._data_source,\n                sampler=self._dataset._sampler,\n                # Append `ConvertToNumpy`.\n                operations=list(self._dataset._operations) + [ConvertToNumpy()],\n                worker_count=self._dataset._multiprocessing_options.num_workers,\n                worker_buffer_size=self._dataset._multiprocessing_options.per_worker_buffer_size,\n                shard_options=self._dataset._shard_options,\n                read_options=self._dataset._read_options,\n                enable_profiling=self._dataset._multiprocessing_options.enable_profiling,\n            )\n        return dataset\n\n    def get_jax_iterator(self):\n        def convert_to_jax_compatible(x):\n            if data_adapter_utils.is_scipy_sparse(x):\n                x = data_adapter_utils.scipy_sparse_to_jax_sparse(x)\n            elif data_adapter_utils.is_tensorflow_sparse(x):\n                x = data_adapter_utils.tf_sparse_to_jax_sparse(x)\n            return x\n\n        class ConvertToJaxCompatible(grain.transforms.Map):\n            def map(self, x):\n                return tree.map_structure(\n                    convert_to_jax_compatible, x, none_is_leaf=False\n                )\n\n        if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)):\n            dataset = self._dataset.map(ConvertToJaxCompatible())\n        
else:\n            # Instantiate a new `DataLoader`.\n            dataset = grain.DataLoader(\n                data_source=self._dataset._data_source,\n                sampler=self._dataset._sampler,\n                # Append `ConvertToJaxCompatible`.\n                operations=list(self._dataset._operations)\n                + [ConvertToJaxCompatible()],\n                worker_count=self._dataset._multiprocessing_options.num_workers,\n                worker_buffer_size=self._dataset._multiprocessing_options.per_worker_buffer_size,\n                shard_options=self._dataset._shard_options,\n                read_options=self._dataset._read_options,\n                enable_profiling=self._dataset._multiprocessing_options.enable_profiling,\n            )\n        return dataset\n\n    def get_tf_dataset(self):\n        def convert_to_tf(x):\n            if x is None:\n                return tf.experimental.Optional.empty(None)\n            if data_adapter_utils.is_scipy_sparse(x):\n                x = data_adapter_utils.scipy_sparse_to_tf_sparse(x)\n            elif data_adapter_utils.is_jax_sparse(x):\n                x = data_adapter_utils.jax_sparse_to_tf_sparse(x)\n            return x\n\n        class ConvertToTF(grain.transforms.Map):\n            def map(self, x):\n                return tree.map_structure(convert_to_tf, x)\n\n        # `tf.data.Dataset.from_generator` does not support lists as output.\n        # We convert lists to tuples.\n        class ListToTuple(grain.transforms.Map):\n            def map(self, x):\n                return tree.lists_to_tuples(x)\n\n        if isinstance(self._dataset, (grain.MapDataset, grain.IterDataset)):\n            dataset = self._dataset.map(ConvertToTF())\n            dataset = dataset.map(ListToTuple())\n        else:\n            # Instantiate a new `DataLoader`.\n            dataset = grain.DataLoader(\n                data_source=self._dataset._data_source,\n                sampler=self._dataset._sampler,\n  
              # Append `ConvertToTF` and `ListToTuple`.\n                operations=list(self._dataset._operations)\n                + [ConvertToTF(), ListToTuple()],\n                worker_count=self._dataset._multiprocessing_options.num_workers,\n                worker_buffer_size=self._dataset._multiprocessing_options.per_worker_buffer_size,\n                shard_options=self._dataset._shard_options,\n                read_options=self._dataset._read_options,\n                enable_profiling=self._dataset._multiprocessing_options.enable_profiling,\n            )\n\n        if self._output_tf_signature is None:\n            self._output_tf_signature = tree.map_structure(\n                data_adapter_utils.convert_to_tf_tensor_spec,\n                self._output_signature,\n            )\n\n        return tf.data.Dataset.from_generator(\n            lambda: dataset, output_signature=self._output_tf_signature\n        )\n\n    def get_torch_dataloader(self):\n        import torch.utils.data as torch_data\n\n        class ConverterIterableDataset(torch_data.IterableDataset):\n            def __init__(self, iterable):\n                super().__init__()\n                self.iterable = iterable\n\n            def __iter__(self):\n                return iter(self.iterable)\n\n        # `batch_size=None` indicates that we should not re-batch\n        return torch_data.DataLoader(\n            ConverterIterableDataset(self._dataset), batch_size=None\n        )\n\n    @property\n    def num_batches(self):\n        return None\n\n    @property\n    def batch_size(self):\n        return self._batch_size\n\n    @property\n    def has_partial_batch(self):\n        return None\n\n    @property\n    def partial_batch_size(self):\n        return None\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/grain_dataset_adapter_test.py",
    "content": "import grain\nimport numpy as np\nimport tensorflow as tf\nimport torch\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.testing.test_utils import named_product\nfrom keras.src.trainers.data_adapters import grain_dataset_adapter\n\n\nclass Range2DSource(grain.sources.RandomAccessDataSource):\n    def __init__(self, start, stop):\n        self.start = start\n        self.stop = stop\n\n    def __getitem__(self, idx):\n        return np.expand_dims(np.array([self.start + idx]), axis=0)\n\n    def __len__(self):\n        return self.stop - self.start\n\n\nclass GrainDatasetAdapterTest(testing.TestCase):\n    def _get_dataset(self, dataset_type, worker_count=0, num_threads=0):\n        x = np.random.normal(size=(34, 4)).astype(\"float32\")\n        y = np.random.normal(size=(34, 2)).astype(\"float32\")\n\n        class MySource(grain.sources.RandomAccessDataSource):\n            def __init__(self, x, y):\n                self.x = x\n                self.y = y\n\n            def __getitem__(self, idx):\n                return self.x[idx], self.y[idx]\n\n            def __len__(self):\n                return len(self.x)\n\n        if dataset_type == \"map_dataset\":\n            dataset = grain.MapDataset.source(MySource(x, y)).batch(\n                batch_size=16\n            )\n        elif dataset_type == \"iter_dataset\":\n            dataset = (\n                grain.MapDataset.source(MySource(x, y))\n                .to_iter_dataset()\n                .batch(batch_size=16)\n            )\n        else:\n            source = MySource(x, y)\n            dataset = grain.DataLoader(\n                data_source=source,\n                operations=[grain.transforms.Batch(batch_size=16)],\n                shard_options=grain.sharding.NoSharding(),\n                sampler=grain.samplers.IndexSampler(\n                    num_records=len(source), num_epochs=1\n                ),\n     
           worker_count=worker_count,\n                read_options=grain.ReadOptions(num_threads=num_threads),\n            )\n        return dataset\n\n    @parameterized.named_parameters(\n        named_product(\n            dataset_type=[\"map_dataset\", \"iter_dataset\", \"data_loader\"]\n        )\n    )\n    def test_basic_flow(self, dataset_type):\n        dataset = self._get_dataset(dataset_type)\n        adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset)\n\n        self.assertEqual(adapter.num_batches, None)\n        self.assertEqual(adapter.batch_size, 16)\n        self.assertEqual(adapter.has_partial_batch, None)\n        self.assertEqual(adapter.partial_batch_size, None)\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            expected_class = tf.Tensor\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n            expected_class = np.ndarray\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n            expected_class = torch.Tensor\n        else:\n            it = adapter.get_numpy_iterator()\n            expected_class = np.ndarray\n\n        for i, batch in enumerate(it):\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertIsInstance(bx, expected_class)\n            self.assertIsInstance(by, expected_class)\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertContainsExactSubsequence(str(bx.dtype), \"float32\")\n            if i < 2:\n                self.assertEqual(bx.shape, (16, 4))\n                self.assertEqual(by.shape, (16, 2))\n            else:\n                self.assertEqual(bx.shape, (2, 4))\n                self.assertEqual(by.shape, (2, 2))\n\n    @parameterized.named_parameters(\n        named_product(data_type=[\"list\", \"dict\", \"nested_list\", \"nested_dict\"])\n    )\n    def test_nested_data(self, data_type):\n       
 if data_type not in (\"list\", \"dict\", \"nested_list\", \"nested_dict\"):\n            raise ValueError(\n                \"data_type must be one of 'list', 'dict', 'nested_list' or \"\n                f\"'nested_dict'. Received: {data_type}\"\n            )\n\n        class NestedSource(grain.sources.RandomAccessDataSource):\n            def __init__(self, data_type):\n                self.x = np.random.random((40, 4)).astype(\"float32\")\n                self.y = np.random.random((40, 2)).astype(\"float32\")\n                self.data_type = data_type\n\n            def __len__(self):\n                return len(self.x)\n\n            def __getitem__(self, idx):\n                x = self.x[idx]\n                y = self.y[idx]\n                if self.data_type == \"list\":\n                    return x, y\n                elif self.data_type == \"dict\":\n                    return {\"x\": x, \"y\": y}\n                elif self.data_type == \"nested_list\":\n                    return x, (x, y)\n                elif self.data_type == \"nested_dict\":\n                    return {\"data\": {\"x\": x, \"y\": y}}\n\n        dataset = grain.MapDataset.source(NestedSource(data_type)).batch(\n            batch_size=4\n        )\n        adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset)\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            expected_class = tf.Tensor\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n            expected_class = np.ndarray\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n            expected_class = torch.Tensor\n        else:\n            it = adapter.get_numpy_iterator()\n            expected_class = np.ndarray\n\n        for batch in it:\n            if data_type == \"list\":\n                self.assertEqual(len(batch), 2)\n                bx, by = batch\n            elif data_type 
== \"dict\":\n                self.assertEqual(len(batch), 2)\n                bx, by = batch[\"x\"], batch[\"y\"]\n            elif data_type == \"nested_list\":\n                self.assertEqual(len(batch), 2)\n                bx, (_, by) = batch\n            elif data_type == \"nested_dict\":\n                self.assertEqual(len(batch[\"data\"]), 2)\n                bx, by = batch[\"data\"][\"x\"], batch[\"data\"][\"y\"]\n            self.assertIsInstance(bx, expected_class)\n            self.assertIsInstance(by, expected_class)\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertEqual(bx.shape, (4, 4))\n            self.assertEqual(by.shape, (4, 2))\n\n    def test_multiple_calling_on_iterators(self):\n        dataset = self._get_dataset(\"iter_dataset\")\n        adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset)\n\n        numpy_it = adapter.get_numpy_iterator()\n        jax_it = adapter.get_jax_iterator()\n        tf_it = adapter.get_tf_dataset()\n        torch_it = adapter.get_torch_dataloader()\n        for it in (numpy_it, jax_it, tf_it, torch_it):\n            for batch in it:\n                self.assertEqual(len(batch), 2)\n                bx, by = batch\n                self.assertEqual(bx.dtype, by.dtype)\n\n    def test_num_batches(self):\n        dataset = grain.MapDataset.source(Range2DSource(0, 42))\n        adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset)\n        self.assertEqual(adapter.num_batches, None)\n\n        # Test for Infinite Cardinality\n        dataset = grain.MapDataset.source(Range2DSource(0, 42))\n        dataset = dataset.repeat()\n        adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset)\n        self.assertIsNone(adapter.num_batches)\n\n        # Test for Unknown Cardinality\n        dataset = dataset.filter(lambda x: True)\n        adapter = grain_dataset_adapter.GrainDatasetAdapter(dataset)\n        self.assertIsNone(adapter.num_batches)\n\n    def 
test_invalid_dataset_type(self):\n        with self.assertRaisesRegex(\n            ValueError,\n            (\n                r\"Expected `dataset` to be a grain.MapDataset, \"\n                r\"grain.IterDataset or grain.DataLoader. \"\n            ),\n        ):\n            grain_dataset_adapter.GrainDatasetAdapter(\n                \"This is not a grain.Dataset\"\n            )\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/py_dataset_adapter.py",
    "content": "import itertools\nimport multiprocessing.dummy\nimport queue\nimport random\nimport threading\nimport warnings\nimport weakref\nfrom contextlib import closing\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.trainers.data_adapters.data_adapter import DataAdapter\n\n\n@keras_export([\"keras.utils.PyDataset\", \"keras.utils.Sequence\"])\nclass PyDataset:\n    \"\"\"Base class for defining a parallel dataset using Python code.\n\n    Every `PyDataset` must implement the `__getitem__()` and the `__len__()`\n    methods. If you want to modify your dataset between epochs,\n    you may additionally implement `on_epoch_end()`,\n    or `on_epoch_begin` to be called at the start of each epoch.\n    The `__getitem__()` method should return a complete batch\n    (not a single sample), and the `__len__` method should return\n    the number of batches in the dataset (rather than the number of samples).\n\n    Args:\n        workers: Number of workers to use in multithreading or\n            multiprocessing.\n        use_multiprocessing: Whether to use Python multiprocessing for\n            parallelism. Setting this to `True` means that your\n            dataset will be replicated in multiple forked processes.\n            This is necessary to gain compute-level (rather than I/O level)\n            benefits from parallelism. However it can only be set to\n            `True` if your dataset can be safely pickled.\n        max_queue_size: Maximum number of batches to keep in the queue\n            when iterating over the dataset in a multithreaded or\n            multiprocessed setting.\n            Reduce this value to reduce the CPU memory consumption of\n            your dataset. 
Defaults to 10.\n\n    Notes:\n\n    - `PyDataset` is a safer way to do multiprocessing.\n        This structure guarantees that the model will only train\n        once on each sample per epoch, which is not the case\n        with Python generators.\n    - The arguments `workers`, `use_multiprocessing`, and `max_queue_size`\n        exist to configure how `fit()` uses parallelism to iterate\n        over the dataset. They are not used by the `PyDataset` class\n        directly. When you are manually iterating over a `PyDataset`,\n        no parallelism is applied.\n\n    Example:\n\n    ```python\n    from skimage.io import imread\n    from skimage.transform import resize\n    import numpy as np\n    import math\n\n    # Here, `x_set` is a list of paths to the images\n    # and `y_set` is the list of associated classes.\n\n    class CIFAR10PyDataset(keras.utils.PyDataset):\n\n        def __init__(self, x_set, y_set, batch_size, **kwargs):\n            super().__init__(**kwargs)\n            self.x, self.y = x_set, y_set\n            self.batch_size = batch_size\n\n        def __len__(self):\n            # Return number of batches.\n            return math.ceil(len(self.x) / self.batch_size)\n\n        def __getitem__(self, idx):\n            # Return x, y for batch idx.\n            low = idx * self.batch_size\n            # Cap upper bound at array length; the last batch may be smaller\n            # if the total number of items is not a multiple of batch size.\n            high = min(low + self.batch_size, len(self.x))\n            batch_x = self.x[low:high]\n            batch_y = self.y[low:high]\n\n            return np.array([\n                resize(imread(file_name), (200, 200))\n                for file_name in batch_x]), np.array(batch_y)\n    ```\n    \"\"\"\n\n    def __init__(self, workers=1, use_multiprocessing=False, max_queue_size=10):\n        self._workers = workers\n        self._use_multiprocessing = use_multiprocessing\n        
self._max_queue_size = max_queue_size\n\n    def _warn_if_super_not_called(self):\n        warn = False\n        if not hasattr(self, \"_workers\"):\n            self._workers = 1\n            warn = True\n        if not hasattr(self, \"_use_multiprocessing\"):\n            self._use_multiprocessing = False\n            warn = True\n        if not hasattr(self, \"_max_queue_size\"):\n            self._max_queue_size = 10\n            warn = True\n        if warn:\n            warnings.warn(\n                \"Your `PyDataset` class should call \"\n                \"`super().__init__(**kwargs)` in its constructor. \"\n                \"`**kwargs` can include `workers`, \"\n                \"`use_multiprocessing`, `max_queue_size`. Do not pass \"\n                \"these arguments to `fit()`, as they will be ignored.\",\n                stacklevel=2,\n            )\n\n    @property\n    def workers(self):\n        self._warn_if_super_not_called()\n        return self._workers\n\n    @workers.setter\n    def workers(self, value):\n        self._workers = value\n\n    @property\n    def use_multiprocessing(self):\n        self._warn_if_super_not_called()\n        return self._use_multiprocessing\n\n    @use_multiprocessing.setter\n    def use_multiprocessing(self, value):\n        self._use_multiprocessing = value\n\n    @property\n    def max_queue_size(self):\n        self._warn_if_super_not_called()\n        return self._max_queue_size\n\n    @max_queue_size.setter\n    def max_queue_size(self, value):\n        self._max_queue_size = value\n\n    def __getitem__(self, index):\n        \"\"\"Gets batch at position `index`.\n\n        Args:\n            index: position of the batch in the PyDataset.\n\n        Returns:\n            A batch\n        \"\"\"\n        del index\n        raise NotImplementedError\n\n    def __iter__(self):\n        index_range = None\n        try:\n            num_batches = self.num_batches\n            if num_batches is not None:\n        
        index_range = range(num_batches)\n        except NotImplementedError:\n            pass\n\n        if index_range is None:\n            index_range = itertools.count()\n\n        for index in index_range:\n            yield self[index]\n\n    @property\n    def num_batches(self):\n        \"\"\"Number of batches in the PyDataset.\n\n        Returns:\n            The number of batches in the PyDataset or `None` to indicate that\n            the dataset is infinite.\n        \"\"\"\n        # For backwards compatibility, support `__len__`.\n        if hasattr(self, \"__len__\"):\n            return len(self)\n        raise NotImplementedError(\n            \"You need to implement the `num_batches` property:\\n\\n\"\n            \"@property\\ndef num_batches(self):\\n  return ...\"\n        )\n\n    def on_epoch_begin(self):\n        \"\"\"Method called at the beginning of every epoch.\"\"\"\n        pass\n\n    def on_epoch_end(self):\n        \"\"\"Method called at the end of every epoch.\"\"\"\n        pass\n\n\nclass PyDatasetAdapter(DataAdapter):\n    \"\"\"Adapter for `keras.utils.PyDataset` instances.\"\"\"\n\n    def __init__(\n        self,\n        x,\n        class_weight=None,\n        shuffle=False,\n    ):\n        self.py_dataset = x\n        self.class_weight = class_weight\n        self.enqueuer = None\n        self.shuffle = shuffle\n        self._output_signature = None\n        self._within_epoch = False\n\n        workers = self.py_dataset.workers\n        use_multiprocessing = self.py_dataset.use_multiprocessing\n        if workers > 1 or (workers > 0 and use_multiprocessing):\n            self.enqueuer = OrderedEnqueuer(\n                self.py_dataset,\n                workers=workers,\n                use_multiprocessing=use_multiprocessing,\n                max_queue_size=self.py_dataset.max_queue_size,\n                shuffle=self.shuffle,\n            )\n\n    def _standardize_batch(self, batch):\n        if isinstance(batch, 
dict):\n            return batch\n        if isinstance(batch, np.ndarray):\n            batch = (batch,)\n        if isinstance(batch, list):\n            batch = tuple(batch)\n        if not isinstance(batch, tuple) or len(batch) not in {1, 2, 3}:\n            raise ValueError(\n                \"PyDataset.__getitem__() must return a tuple or a dict. \"\n                \"If a tuple, it must be ordered either \"\n                \"(input,) or (inputs, targets) or \"\n                \"(inputs, targets, sample_weights). \"\n                f\"Received: {str(batch)[:100]}... of type {type(batch)}\"\n            )\n        if self.class_weight is not None:\n            if len(batch) == 3:\n                raise ValueError(\n                    \"You cannot specify `class_weight` \"\n                    \"and `sample_weight` at the same time.\"\n                )\n            if len(batch) == 2:\n                sw = data_adapter_utils.class_weight_to_sample_weights(\n                    batch[1], self.class_weight\n                )\n                batch = batch + (sw,)\n        return batch\n\n    def _infinite_generator(self):\n        for i in itertools.count():\n            yield self._standardize_batch(self.py_dataset[i])\n\n    def _finite_generator(self):\n        indices = range(self.py_dataset.num_batches)\n        if self.shuffle:\n            indices = list(indices)\n            random.shuffle(indices)\n\n        for i in indices:\n            yield self._standardize_batch(self.py_dataset[i])\n\n    def _infinite_enqueuer_generator(self):\n        self.enqueuer.start()\n        for batch in self.enqueuer.get():\n            yield self._standardize_batch(batch)\n\n    def _finite_enqueuer_generator(self):\n        self.enqueuer.start()\n        num_batches = self.py_dataset.num_batches\n        for i, batch in enumerate(self.enqueuer.get()):\n            yield self._standardize_batch(batch)\n            if i >= num_batches - 1:\n                
self.enqueuer.stop()\n                return\n\n    def _get_iterator(self):\n        if self.enqueuer is None:\n            if self.py_dataset.num_batches is None:\n                return self._infinite_generator()\n            else:\n                return self._finite_generator()\n        else:\n            if self.py_dataset.num_batches is None:\n                return self._infinite_enqueuer_generator()\n            else:\n                return self._finite_enqueuer_generator()\n\n    def get_numpy_iterator(self):\n        return data_adapter_utils.get_numpy_iterator(self._get_iterator())\n\n    def get_jax_iterator(self):\n        return data_adapter_utils.get_jax_iterator(self._get_iterator())\n\n    def get_tf_dataset(self):\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        num_batches = self.py_dataset.num_batches\n        if self._output_signature is None:\n            num_samples = data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC\n            if num_batches is not None:\n                num_samples = min(num_samples, num_batches)\n            batches = [\n                self._standardize_batch(self.py_dataset[i])\n                for i in range(num_samples)\n            ]\n            if len(batches) == 0:\n                raise ValueError(\"The PyDataset has length 0\")\n            self._output_signature = data_adapter_utils.get_tensor_spec(batches)\n\n        ds = tf.data.Dataset.from_generator(\n            self._get_iterator,\n            output_signature=self._output_signature,\n        )\n        if self.enqueuer is not None:\n            # The enqueuer does its own multithreading / multiprocessing to\n            # prefetch items. 
Disable the tf.data.Dataset prefetching and\n            # threading as it interferes.\n            options = tf.data.Options()\n            options.autotune.enabled = False\n            options.threading.private_threadpool_size = 1\n            ds = ds.with_options(options)\n        else:\n            ds = ds.prefetch(tf.data.AUTOTUNE)\n        return ds\n\n    def get_torch_dataloader(self):\n        return data_adapter_utils.get_torch_dataloader(self._get_iterator())\n\n    def on_epoch_begin(self):\n        if self._within_epoch:\n            raise ValueError(\n                \"`on_epoch_begin` was called twice without `on_epoch_end` \"\n                \"having been called.\"\n            )\n        self._within_epoch = True\n        if self.enqueuer:\n            self.enqueuer.start()\n        self.py_dataset.on_epoch_begin()\n\n    def on_epoch_end(self):\n        if self.enqueuer:\n            self.enqueuer.stop()\n        self.py_dataset.on_epoch_end()\n        self._within_epoch = False\n\n    @property\n    def num_batches(self):\n        return self.py_dataset.num_batches\n\n    @property\n    def batch_size(self):\n        return None\n\n\n# Global variables to be shared across processes\n_SHARED_SEQUENCES = {}\n# We use a Value to provide unique id to different processes.\n_SEQUENCE_COUNTER = None\n\n\n# Because multiprocessing pools are inherently unsafe, starting from a clean\n# state can be essential to avoiding deadlocks. 
In order to accomplish this, we\n# need to be able to check on the status of Pools that we create.\n_DATA_POOLS = weakref.WeakSet()\n_WORKER_ID_QUEUE = None  # Only created if needed.\n_FORCE_THREADPOOL = False\n\n\ndef get_pool_class(use_multiprocessing):\n    global _FORCE_THREADPOOL\n    if not use_multiprocessing or _FORCE_THREADPOOL:\n        return multiprocessing.dummy.Pool  # ThreadPool\n    return multiprocessing.Pool\n\n\ndef get_worker_id_queue():\n    \"\"\"Lazily create the queue to track worker ids.\"\"\"\n    global _WORKER_ID_QUEUE\n    if _WORKER_ID_QUEUE is None:\n        _WORKER_ID_QUEUE = multiprocessing.Queue()\n    return _WORKER_ID_QUEUE\n\n\ndef get_index(uid, i):\n    \"\"\"Get the value from the PyDataset `uid` at index `i`.\n\n    To allow multiple PyDatasets to be used at the same time, we use `uid` to\n    get a specific one. A single PyDataset would cause the validation to\n    overwrite the training PyDataset.\n\n    This method is called from worker threads.\n\n    Args:\n        uid: int, PyDataset identifier\n        i: index\n\n    Returns:\n        The value at index `i`.\n    \"\"\"\n    return _SHARED_SEQUENCES[uid][i]\n\n\nclass PyDatasetEnqueuer:\n    \"\"\"Base class to enqueue inputs.\n\n    The task of an Enqueuer is to use parallelism to speed up preprocessing.\n    This is done with processes or threads.\n\n    Example:\n\n    ```python\n        enqueuer = PyDatasetEnqueuer(...)\n        enqueuer.start()\n        datas = enqueuer.get()\n        for data in datas:\n            # Use the inputs; training, evaluating, predicting.\n            # ... 
stop sometime.\n        enqueuer.stop()\n    ```\n\n    The `enqueuer.get()` should be an infinite stream of data.\n    \"\"\"\n\n    def __init__(\n        self,\n        py_dataset,\n        workers=1,\n        use_multiprocessing=False,\n        max_queue_size=10,\n    ):\n        self.py_dataset = py_dataset\n\n        global _SEQUENCE_COUNTER\n        if _SEQUENCE_COUNTER is None:\n            try:\n                _SEQUENCE_COUNTER = multiprocessing.Value(\"i\", 0)\n            except OSError:\n                # In this case the OS does not allow us to use\n                # multiprocessing. We resort to an int\n                # for enqueuer indexing.\n                _SEQUENCE_COUNTER = 0\n\n        if isinstance(_SEQUENCE_COUNTER, int):\n            self.uid = _SEQUENCE_COUNTER\n            _SEQUENCE_COUNTER += 1\n        else:\n            # Doing Multiprocessing.Value += x is not process-safe.\n            with _SEQUENCE_COUNTER.get_lock():\n                self.uid = _SEQUENCE_COUNTER.value\n                _SEQUENCE_COUNTER.value += 1\n\n        self.ready_queue = queue.Queue()\n        self.future_queue = queue.Queue(max_queue_size)\n        self.running = False\n        self.start_stop_lock = threading.Lock()\n        self.run_thread = None\n        if use_multiprocessing:\n            self.executor_fn = self._get_executor_init(workers)\n        else:\n            # We do not need the init since it's threads.\n            self.executor_fn = lambda _: get_pool_class(False)(workers)\n\n    def is_running(self):\n        \"\"\"Whether the enqueuer is running.\n\n        This method is thread safe and called from many threads.\n\n        Returns: boolean indicating whether this enqueuer is running.\n        \"\"\"\n        return self.running\n\n    def start(self):\n        \"\"\"Starts the handler's workers.\n\n        This method is thread safe but is called from the main thread.\n        It is safe to call this method multiple times, extra calls are 
ignored.\n        \"\"\"\n        with self.start_stop_lock:\n            if self.running:\n                return\n            self.running = True\n            self.run_thread = threading.Thread(target=self._run)\n            self.run_thread.name = f\"Worker_{self.uid}\"\n            self.run_thread.daemon = True\n            self.run_thread.start()\n\n    def stop(self, drain_queue_and_join=True):\n        \"\"\"Stops running threads and waits for them to exit, if necessary.\n\n        This method is thread safe and is called from various threads. Note that\n        the `drain_queue_and_join` argument must be set correctly.\n        It is safe to call this method multiple times, extra calls are ignored.\n\n        Args:\n            drain_queue_and_join: set to True to drain the queue of pending\n                items and wait for the worker thread to complete. Set to False\n                if invoked from a worker thread to avoid deadlocks. Note that\n                setting this to False means this enqueuer won't be reused.\n        \"\"\"\n        with self.start_stop_lock:\n            if not self.running:\n                return\n            self.running = False\n\n            if drain_queue_and_join:\n                # Drain the `future_queue` and put items in `ready_queue` for\n                # the next run.\n                while True:\n                    try:\n                        value = self.future_queue.get(block=True, timeout=0.1)\n                        if isinstance(value, Exception):\n                            raise value  # Propagate exception from other thread\n                        inputs = value.get()\n                        self.future_queue.task_done()\n                        if inputs is not None:\n                            self.ready_queue.put(inputs)\n                    except queue.Empty:\n                        break\n                self.run_thread.join()\n\n            self.run_thread = None\n            
_SHARED_SEQUENCES[self.uid] = None\n\n    def _send_py_dataset(self):\n        \"\"\"Sends current Iterable to all workers.\"\"\"\n        # For new processes that may spawn\n        _SHARED_SEQUENCES[self.uid] = self.py_dataset\n\n    def __del__(self):\n        self.stop(drain_queue_and_join=False)\n\n    def _run(self):\n        \"\"\"Submits requests to the executor and queues the `Future` objects.\"\"\"\n        raise NotImplementedError\n\n    def _get_executor_init(self, workers):\n        \"\"\"Gets the Pool initializer for multiprocessing.\n\n        Args:\n            workers: Number of workers.\n\n        Returns:\n            A function to initialize the pool.\n        \"\"\"\n        raise NotImplementedError\n\n    def get(self):\n        \"\"\"Creates a generator to extract data from the queue.\n\n        Skips the data if it is `None`.\n\n        This method is called from the main thread.\n\n        Yields:\n            The next element in the queue, i.e. a tuple\n            `(inputs, targets)` or\n            `(inputs, targets, sample_weights)`.\n        \"\"\"\n        raise NotImplementedError\n\n\nclass OrderedEnqueuer(PyDatasetEnqueuer):\n    \"\"\"Builds an Enqueuer from a PyDataset.\n\n    Args:\n        py_dataset: A `keras.utils.PyDataset` object.\n        use_multiprocessing: use multiprocessing if True, otherwise threading\n        shuffle: whether to shuffle the data at the beginning of each epoch\n    \"\"\"\n\n    def __init__(\n        self,\n        py_dataset,\n        workers=1,\n        use_multiprocessing=False,\n        max_queue_size=10,\n        shuffle=False,\n    ):\n        super().__init__(\n            py_dataset, workers, use_multiprocessing, max_queue_size\n        )\n        self.shuffle = shuffle\n        if self.py_dataset.num_batches is None:\n            # For infinite datasets, `self.indices` is created here only once\n            # so that subsequent runs resume from where they stopped.\n            
self.indices = itertools.count()\n\n    def _get_executor_init(self, workers):\n        \"\"\"Gets the Pool initializer for multiprocessing.\n\n        Args:\n            workers: Number of workers.\n\n        Returns:\n            A function to initialize the pool.\n        \"\"\"\n\n        def pool_fn(seqs):\n            pool = get_pool_class(True)(\n                workers,\n                initializer=init_pool_generator,\n                initargs=(seqs, None, get_worker_id_queue()),\n            )\n            _DATA_POOLS.add(pool)\n            return pool\n\n        return pool_fn\n\n    def _run(self):\n        \"\"\"Submits requests to the executor and queues the `Future` objects.\n\n        This method is the run method of worker threads.\n        \"\"\"\n        try:\n            if self.py_dataset.num_batches is not None:\n                # For finite datasets, `self.indices` is created here so that\n                # shuffling creates a different order each time.\n                indices = range(self.py_dataset.num_batches)\n                if self.shuffle:\n                    indices = list(indices)\n                    random.shuffle(indices)\n                self.indices = iter(indices)\n            self._send_py_dataset()  # Share the initial py_dataset\n\n            with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:\n                while self.is_running():\n                    try:\n                        i = next(self.indices)\n                        self.future_queue.put(\n                            executor.apply_async(get_index, (self.uid, i)),\n                            block=True,\n                        )\n                    except StopIteration:\n                        break\n        except Exception as e:\n            self.future_queue.put(e)  # Report exception\n\n    def get(self):\n        \"\"\"Creates a generator to extract data from the queue.\n\n        Skips the data if it is `None`.\n\n        This 
method is called from the main thread.\n\n        Yields:\n            The next element in the queue, i.e. a tuple\n            `(inputs, targets)` or\n            `(inputs, targets, sample_weights)`.\n        \"\"\"\n        while self.is_running():\n            try:\n                inputs = self.ready_queue.get(block=False)\n                yield inputs\n                continue  # Retry the ready_queue\n            except queue.Empty:\n                pass\n\n            try:\n                value = self.future_queue.get(block=True, timeout=5)\n                self.future_queue.task_done()\n                if isinstance(value, Exception):\n                    raise value  # Propagate exception from other thread\n                inputs = value.get()\n                if inputs is not None:\n                    yield inputs\n            except queue.Empty:\n                pass\n            except Exception as e:\n                self.stop(drain_queue_and_join=True)\n                raise e\n\n        # Note that it is ok to poll the iterator after the initial `start`,\n        # which may happen before the first `on_epoch_begin`. 
But it's not ok to\n        # poll after `on_epoch_end`.\n        raise ValueError(\n            \"Iterator called after `on_epoch_end` or before `on_epoch_begin`.\"\n        )\n\n\ndef init_pool_generator(gens, random_seed=None, id_queue=None):\n    \"\"\"Initializer function for pool workers.\n\n    Args:\n        gens: State which should be made available to worker processes.\n        random_seed: An optional value with which to seed child processes.\n        id_queue: A multiprocessing Queue of worker ids.\n            This is used to indicate that a worker process\n            was created by Keras.\n    \"\"\"\n    global _SHARED_SEQUENCES\n    _SHARED_SEQUENCES = gens\n\n    worker_proc = multiprocessing.current_process()\n\n    # name isn't used for anything, but setting a more descriptive name is\n    # helpful when diagnosing orphaned processes.\n    worker_proc.name = f\"Keras_worker_{worker_proc.name}\"\n\n    if random_seed is not None:\n        np.random.seed(random_seed + worker_proc.ident)\n\n    if id_queue is not None:\n        # If a worker dies during init, the pool will just create a replacement.\n        id_queue.put(worker_proc.ident, block=True, timeout=0.1)\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/py_dataset_adapter_test.py",
    "content": "import math\nimport time\n\nimport jax\nimport numpy as np\nimport pytest\nimport tensorflow as tf\nimport torch\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.testing.test_utils import named_product\nfrom keras.src.trainers.data_adapters import py_dataset_adapter\nfrom keras.src.utils.rng_utils import set_random_seed\n\n\nclass ExamplePyDataset(py_dataset_adapter.PyDataset):\n    def __init__(\n        self,\n        x_set,\n        y_set,\n        sample_weight=None,\n        batch_size=32,\n        delay=0,\n        infinite=False,\n        **kwargs,\n    ):\n        super().__init__(**kwargs)\n        self.x, self.y = x_set, y_set\n        self.batch_size = batch_size\n        self.sample_weight = sample_weight\n        self.delay = delay\n        self.infinite = infinite\n\n    @property\n    def num_batches(self):\n        if self.infinite:\n            return None\n        return math.ceil(len(self.x) / self.batch_size)\n\n    def __getitem__(self, idx):\n        # Create artificial delay to test multiprocessing\n        time.sleep(self.delay)\n\n        if self.infinite:\n            idx = idx % math.ceil(len(self.x) / self.batch_size)\n        # Return x, y for batch idx.\n        low = idx * self.batch_size\n        # Cap upper bound at array length; the last batch may be smaller\n        # if the total number of items is not a multiple of batch size.\n        high = min(low + self.batch_size, len(self.x))\n        batch_x = self.x[low:high]\n        batch_y = self.y[low:high]\n        if self.sample_weight is not None:\n            return batch_x, batch_y, self.sample_weight[low:high]\n        return batch_x, batch_y\n\n\nclass DictPyDataset(py_dataset_adapter.PyDataset):\n    def __init__(self, inputs, batch_size=32, **kwargs):\n        super().__init__(**kwargs)\n        self.inputs = inputs\n        self.batch_size = batch_size\n\n    @property\n    def 
num_batches(self):\n        return math.ceil(len(self.inputs[\"x\"]) / self.batch_size)\n\n    def __getitem__(self, idx):\n        # Return x, y for batch idx.\n        low = idx * self.batch_size\n        # Cap upper bound at array length; the last batch may be smaller\n        # if the total number of items is not a multiple of batch size.\n        high = min(low + self.batch_size, len(self.inputs[\"x\"]))\n        batch_x = self.inputs[\"x\"][low:high]\n        batch_y = self.inputs[\"y\"][low:high]\n        batch = {\"x\": batch_x, \"y\": batch_y}\n        return batch\n\n\nclass ExceptionPyDataset(py_dataset_adapter.PyDataset):\n    @property\n    def num_batches(self):\n        return 4\n\n    def __getitem__(self, index):\n        if index < 2:\n            return (\n                np.random.random((8, 4)).astype(\"float32\"),\n                np.random.random((8, 2)).astype(\"float32\"),\n            )\n        raise ValueError(\"Expected exception\")\n\n\n@pytest.mark.skipif(\n    testing.tensorflow_uses_gpu() or testing.uses_tpu(),\n    reason=\"Flaky on TPU and GPU\",\n)\nclass PyDatasetAdapterTest(testing.TestCase):\n    @parameterized.named_parameters(\n        named_product(\n            [\n                {\n                    \"testcase_name\": \"multiprocessing\",\n                    \"workers\": 2,\n                    \"use_multiprocessing\": True,\n                    \"max_queue_size\": 10,\n                    \"dataset_type\": \"np\",\n                },\n                {\n                    \"testcase_name\": \"multithreading\",\n                    \"workers\": 2,\n                    \"use_multiprocessing\": False,\n                    \"max_queue_size\": 10,\n                    \"dataset_type\": \"np\",\n                },\n                {\n                    \"testcase_name\": \"single_np\",\n                    \"dataset_type\": \"np\",\n                },\n                {\n                    \"testcase_name\": 
\"single_tf\",\n                    \"dataset_type\": \"tf\",\n                },\n                {\n                    \"testcase_name\": \"single_jax\",\n                    \"dataset_type\": \"jax\",\n                },\n                {\n                    \"testcase_name\": \"single_torch\",\n                    \"dataset_type\": \"torch\",\n                },\n            ],\n            infinite=[True, False],\n            shuffle=[True, False],\n        )\n    )\n    def test_basic_flow(\n        self,\n        shuffle,\n        dataset_type,\n        infinite,\n        workers=0,\n        use_multiprocessing=False,\n        max_queue_size=0,\n    ):\n        if use_multiprocessing and shuffle:\n            pytest.skip(\"Starting processes is slow, test fewer variants\")\n\n        set_random_seed(1337)\n        x = np.random.random((64, 4)).astype(\"float32\")\n        y = np.array([[i, i] for i in range(64)], dtype=\"float32\")\n        CPU_DEVICES = {\n            \"tensorflow\": \"CPU:0\",\n            \"jax\": \"cpu:0\",\n        }\n        cpu_device = CPU_DEVICES.get(backend.backend(), \"cpu\")\n        with backend.device(cpu_device):\n            if dataset_type == \"tf\":\n                x, y = tf.constant(x), tf.constant(y)\n            elif dataset_type == \"jax\":\n                x, y = jax.numpy.array(x), jax.numpy.array(y)\n            elif dataset_type == \"torch\":\n                x, y = torch.as_tensor(x), torch.as_tensor(y)\n        py_dataset = ExamplePyDataset(\n            x,\n            y,\n            batch_size=16,\n            workers=workers,\n            use_multiprocessing=use_multiprocessing,\n            max_queue_size=max_queue_size,\n            infinite=infinite,\n        )\n        adapter = py_dataset_adapter.PyDatasetAdapter(\n            py_dataset, shuffle=shuffle\n        )\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            expected_class = tf.Tensor\n    
    elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n            expected_class = jax.Array if dataset_type == \"jax\" else np.ndarray\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n            expected_class = torch.Tensor\n        else:\n            it = adapter.get_numpy_iterator()\n            expected_class = np.ndarray\n\n        sample_order = []\n        adapter.on_epoch_begin()\n        for batch in it:\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertIsInstance(bx, expected_class)\n            self.assertIsInstance(by, expected_class)\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertContainsExactSubsequence(str(bx.dtype), \"float32\")\n            self.assertEqual(bx.shape, (16, 4))\n            self.assertEqual(by.shape, (16, 2))\n            for i in range(by.shape[0]):\n                sample_order.append(by[i, 0])\n            if infinite:\n                if len(sample_order) == 64:\n                    adapter.on_epoch_end()\n                    adapter.on_epoch_begin()\n                elif len(sample_order) >= 128:\n                    break\n        adapter.on_epoch_end()\n\n        expected_order = list(range(64))\n        if infinite:\n            self.assertAllClose(sample_order, expected_order + expected_order)\n        elif shuffle:\n            self.assertNotAllClose(sample_order, expected_order)\n            self.assertAllClose(sorted(sample_order), expected_order)\n        else:\n            self.assertAllClose(sample_order, expected_order)\n\n    # TODO: test sample weights\n    # TODO: test inference mode (single output)\n\n    def test_class_weight(self):\n        x = np.random.randint(1, 100, (4, 5))\n        y = np.array([0, 1, 2, 1])\n        class_w = {0: 2, 1: 1, 2: 3}\n        py_dataset = ExamplePyDataset(x, y, batch_size=2)\n        adapter = 
py_dataset_adapter.PyDatasetAdapter(\n            py_dataset, shuffle=False, class_weight=class_w\n        )\n        if backend.backend() == \"tensorflow\":\n            gen = adapter.get_tf_dataset()\n        elif backend.backend() == \"jax\":\n            gen = adapter.get_jax_iterator()\n        elif backend.backend() == \"torch\":\n            gen = adapter.get_torch_dataloader()\n        else:\n            gen = adapter.get_numpy_iterator()\n\n        for index, batch in enumerate(gen):\n            # Batch is a tuple of (x, y, class_weight)\n            self.assertLen(batch, 3)\n            batch = [backend.convert_to_numpy(x) for x in batch]\n            # Let's verify the data and class weights match for each element\n            # of the batch (2 elements in each batch)\n            for sub_elem in range(2):\n                self.assertAllEqual(batch[0][sub_elem], x[index * 2 + sub_elem])\n                self.assertEqual(batch[1][sub_elem], y[index * 2 + sub_elem])\n                class_key = np.int32(batch[1][sub_elem])\n                self.assertEqual(batch[2][sub_elem], class_w[class_key])\n\n        self.assertEqual(index, 1)  # 2 batches\n\n    def test_speedup(self):\n        x = np.random.random((40, 4))\n        y = np.random.random((40, 2))\n\n        no_speedup_py_dataset = ExamplePyDataset(\n            x,\n            y,\n            batch_size=4,\n            delay=0.2,\n        )\n        adapter = py_dataset_adapter.PyDatasetAdapter(\n            no_speedup_py_dataset, shuffle=False\n        )\n        gen = adapter.get_numpy_iterator()\n        t0 = time.time()\n        for batch in gen:\n            pass\n        no_speedup_time = time.time() - t0\n\n        speedup_py_dataset = ExamplePyDataset(\n            x,\n            y,\n            batch_size=4,\n            workers=4,\n            # TODO: the github actions runner may have performance issue with\n            # multiprocessing\n            # use_multiprocessing=True,\n         
   max_queue_size=8,\n            delay=0.2,\n        )\n        adapter = py_dataset_adapter.PyDatasetAdapter(\n            speedup_py_dataset, shuffle=False\n        )\n        gen = adapter.get_numpy_iterator()\n        t0 = time.time()\n        for batch in gen:\n            pass\n        speedup_time = time.time() - t0\n\n        self.assertLess(speedup_time, no_speedup_time)\n\n    def test_dict_inputs(self):\n        inputs = {\n            \"x\": np.random.random((40, 4)),\n            \"y\": np.random.random((40, 2)),\n        }\n        py_dataset = DictPyDataset(inputs, batch_size=4)\n        adapter = py_dataset_adapter.PyDatasetAdapter(py_dataset, shuffle=False)\n        gen = adapter.get_numpy_iterator()\n        for batch in gen:\n            self.assertEqual(len(batch), 2)\n            bx, by = batch[\"x\"], batch[\"y\"]\n            self.assertIsInstance(bx, np.ndarray)\n            self.assertIsInstance(by, np.ndarray)\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertEqual(bx.shape, (4, 4))\n            self.assertEqual(by.shape, (4, 2))\n\n        ds = adapter.get_tf_dataset()\n        for batch in ds:\n            self.assertEqual(len(batch), 2)\n            bx, by = batch[\"x\"], batch[\"y\"]\n            self.assertIsInstance(bx, tf.Tensor)\n            self.assertIsInstance(by, tf.Tensor)\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertEqual(tuple(bx.shape), (4, 4))\n            self.assertEqual(tuple(by.shape), (4, 2))\n\n    def test_with_different_shapes(self):\n        class TestPyDataset(py_dataset_adapter.PyDataset):\n            @property\n            def num_batches(self):\n                return 3\n\n            def __getitem__(self, idx):\n                if idx == 0:\n                    return np.ones([16, 4], \"float32\"), np.ones(\n                        [16, 2], \"float32\"\n                    )\n                if idx == 1:\n                    return np.ones([16, 5], 
\"float32\"), np.ones(\n                        [16, 2], \"float32\"\n                    )\n                else:\n                    return np.ones([2, 6], \"float32\"), np.ones(\n                        [2, 2], \"float32\"\n                    )\n\n        adapter = py_dataset_adapter.PyDatasetAdapter(\n            TestPyDataset(), shuffle=False\n        )\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n        else:\n            it = adapter.get_numpy_iterator()\n\n        for i, batch in enumerate(it):\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertContainsExactSubsequence(str(bx.dtype), \"float32\")\n            if i == 0:\n                self.assertEqual(bx.shape, (16, 4))\n                self.assertEqual(by.shape, (16, 2))\n            elif i == 1:\n                self.assertEqual(bx.shape, (16, 5))\n                self.assertEqual(by.shape, (16, 2))\n            else:\n                self.assertEqual(bx.shape, (2, 6))\n                self.assertEqual(by.shape, (2, 2))\n\n    @parameterized.named_parameters(\n        [\n            {\n                \"testcase_name\": \"multiprocessing\",\n                \"workers\": 2,\n                \"use_multiprocessing\": True,\n                \"max_queue_size\": 10,\n            },\n            {\n                \"testcase_name\": \"multithreading\",\n                \"workers\": 2,\n                \"max_queue_size\": 10,\n            },\n            {\n                \"testcase_name\": \"single\",\n            },\n        ]\n    )\n    def test_exception_reported(\n        self,\n        workers=0,\n        use_multiprocessing=False,\n        max_queue_size=0,\n    ):\n        
if backend.backend() == \"jax\" and use_multiprocessing is True:\n            self.skipTest(\n                \"The CI failed for an unknown reason with \"\n                \"`use_multiprocessing=True` in the jax backend\"\n            )\n        dataset = ExceptionPyDataset(\n            workers=workers,\n            use_multiprocessing=use_multiprocessing,\n            max_queue_size=max_queue_size,\n        )\n        adapter = py_dataset_adapter.PyDatasetAdapter(dataset, shuffle=False)\n\n        expected_exception_class = ValueError\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            # tf.data wraps the exception\n            expected_exception_class = tf.errors.InvalidArgumentError\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n        else:\n            it = adapter.get_numpy_iterator()\n\n        it = iter(it)\n        next(it)\n        next(it)\n        with self.assertRaisesRegex(\n            expected_exception_class, \"Expected exception\"\n        ):\n            next(it)\n\n    def test_iterate_finite(self):\n        py_dataset = ExamplePyDataset(\n            np.ones((6, 11), dtype=\"int32\"),\n            np.zeros((6, 11), dtype=\"int32\"),\n            batch_size=2,\n        )\n        batches = [batch for batch in py_dataset]\n        self.assertLen(batches, 3)\n\n    def test_iterate_infinite_with_none_num_batches(self):\n        py_dataset = ExamplePyDataset(\n            np.ones((6, 11), dtype=\"int32\"),\n            np.zeros((6, 11), dtype=\"int32\"),\n            batch_size=2,\n            infinite=True,\n        )\n        for index, _ in enumerate(py_dataset):\n            if index >= 10:\n                break\n\n    def test_iterate_infinite_with_no_len(self):\n        class NoLenDataset(py_dataset_adapter.PyDataset):\n            def 
__getitem__(self, idx):\n                yield np.ones((2, 11), dtype=\"int32\")\n\n        for index, _ in enumerate(NoLenDataset()):\n            if index >= 10:\n                break\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/tf_dataset_adapter.py",
    "content": "from keras.src import tree\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.trainers.data_adapters.data_adapter import DataAdapter\n\n\nclass TFDatasetAdapter(DataAdapter):\n    \"\"\"Adapter that handles `tf.data.Dataset`.\"\"\"\n\n    def __init__(self, dataset, class_weight=None, distribution=None):\n        \"\"\"Initialize the TFDatasetAdapter.\n\n        Args:\n            dataset: The input `tf.data.Dataset` instance.\n            class_weight: A map where the keys are integer class ids and values\n                are the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`.\n            distribution: A `keras.distribution.Distribution` instance. Used to\n                shard the input dataset into per worker/process dataset\n                instance.\n        \"\"\"\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        if not isinstance(\n            dataset, (tf.data.Dataset, tf.distribute.DistributedDataset)\n        ):\n            raise ValueError(\n                \"Expected argument `dataset` to be a tf.data.Dataset. 
\"\n                f\"Received: {dataset}\"\n            )\n        if class_weight is not None:\n            dataset = dataset.map(\n                make_class_weight_map_fn(class_weight)\n            ).prefetch(tf.data.AUTOTUNE)\n        if distribution is not None:\n            dataset = distribution.distribute_dataset(dataset)\n        self._dataset = dataset\n\n    def get_numpy_iterator(self):\n        from keras.src.backend.tensorflow.core import convert_to_numpy\n\n        for batch in self._dataset:\n            yield tree.map_structure(\n                convert_to_numpy, batch, none_is_leaf=False\n            )\n\n    def get_jax_iterator(self):\n        from keras.src.backend.tensorflow.core import convert_to_numpy\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        def convert_to_jax(x):\n            if isinstance(x, tf.SparseTensor):\n                return data_adapter_utils.tf_sparse_to_jax_sparse(x)\n            else:\n                # We use numpy as an intermediary because it is faster.\n                return convert_to_numpy(x)\n\n        for batch in self._dataset:\n            yield tree.map_structure(convert_to_jax, batch, none_is_leaf=False)\n\n    def get_tf_dataset(self):\n        return self._dataset\n\n    def get_torch_dataloader(self):\n        return data_adapter_utils.get_torch_dataloader(self._dataset)\n\n    @property\n    def num_batches(self):\n        cardinality = self._dataset.cardinality\n        if callable(cardinality):\n            # `dataset.cardinality` is normally expected to be a callable.\n            cardinality = int(self._dataset.cardinality())\n        else:\n            # However, in the case of `DistributedDataset`, it's a np.int64.\n            cardinality = int(cardinality)\n        # Return None for Unknown and Infinite cardinality datasets\n        if cardinality < 0:\n            return None\n        return cardinality\n\n    @property\n    def batch_size(self):\n        
first_element_spec = tree.flatten(self._dataset.element_spec)[0]\n        return first_element_spec.shape[0]\n\n    @property\n    def has_partial_batch(self):\n        return None\n\n    @property\n    def partial_batch_size(self):\n        return None\n\n\ndef make_class_weight_map_fn(class_weight):\n    \"\"\"Applies class weighting to a `Dataset`.\n\n    The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where\n    `y` must be a single `Tensor`.\n\n    Args:\n        class_weight: A map where the keys are integer class ids and values are\n            the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`\n\n    Returns:\n        A function that can be used with `tf.data.Dataset.map` to apply class\n        weighting.\n    \"\"\"\n    from keras.src.utils.module_utils import tensorflow as tf\n\n    class_weight_tensor = tf.convert_to_tensor(\n        [\n            class_weight.get(int(c), 1.0)\n            for c in range(max(class_weight.keys()) + 1)\n        ]\n    )\n\n    def class_weights_map_fn(*data):\n        \"\"\"Convert `class_weight` to `sample_weight`.\"\"\"\n        x, y, sw = data_adapter_utils.unpack_x_y_sample_weight(data)\n        if sw is not None:\n            raise ValueError(\n                \"You cannot `class_weight` and `sample_weight` \"\n                \"at the same time.\"\n            )\n        if tree.is_nested(y):\n            raise ValueError(\n                \"`class_weight` is only supported for Models with a single \"\n                \"output.\"\n            )\n\n        if y.shape.rank >= 2:\n            y_classes = tf.__internal__.smart_cond.smart_cond(\n                tf.shape(y)[-1] > 1,\n                lambda: tf.argmax(y, axis=-1, output_type=tf.int32),\n                lambda: tf.cast(tf.round(tf.squeeze(y, axis=-1)), tf.int32),\n            )\n        else:\n            # Special casing for rank 1, where we can guarantee sparse encoding.\n            y_classes = tf.cast(tf.round(y), tf.int32)\n\n      
  cw = tf.gather(class_weight_tensor, y_classes)\n        return x, y, cw\n\n    return class_weights_map_fn\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/tf_dataset_adapter_test.py",
    "content": "from unittest import mock\n\nimport jax\nimport numpy as np\nimport pytest\nimport tensorflow as tf\nimport torch\n\nfrom keras.src import Sequential\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.trainers.data_adapters import tf_dataset_adapter\n\n\nclass TestTFDatasetAdapter(testing.TestCase):\n    def test_basic_flow(self):\n        x = tf.random.normal((34, 4))\n        y = tf.random.normal((34, 2))\n        base_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)\n        adapter = tf_dataset_adapter.TFDatasetAdapter(base_ds)\n\n        self.assertEqual(adapter.num_batches, 3)\n        self.assertEqual(adapter.batch_size, None)\n        self.assertEqual(adapter.has_partial_batch, None)\n        self.assertEqual(adapter.partial_batch_size, None)\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            expected_class = tf.Tensor\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n            expected_class = np.ndarray\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n            expected_class = torch.Tensor\n        else:\n            it = adapter.get_numpy_iterator()\n            expected_class = np.ndarray\n\n        for i, batch in enumerate(it):\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertIsInstance(bx, expected_class)\n            self.assertIsInstance(by, expected_class)\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertContainsExactSubsequence(str(bx.dtype), \"float32\")\n            if i < 2:\n                self.assertEqual(bx.shape, (16, 4))\n                self.assertEqual(by.shape, (16, 2))\n            else:\n                self.assertEqual(bx.shape, (2, 4))\n                self.assertEqual(by.shape, (2, 2))\n\n    def _test_class_weights(self, 
target_encoding=\"int\"):\n        x = np.random.random((4, 2))\n        if target_encoding == \"int\":\n            y = np.array([[0], [1], [2], [3]], dtype=\"int64\")\n        else:\n            y = np.array(\n                [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],\n                dtype=\"float32\",\n            )\n\n        class_weight = {\n            0: 0.1,\n            1: 0.2,\n            2: 0.3,\n            3: 0.4,\n        }\n        base_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)\n        adapter = tf_dataset_adapter.TFDatasetAdapter(\n            base_ds, class_weight=class_weight\n        )\n        gen = adapter.get_numpy_iterator()\n        for batch in gen:\n            self.assertEqual(len(batch), 3)\n            _, _, bw = batch\n            self.assertAllClose(bw, [0.1, 0.2, 0.3, 0.4])\n\n    def test_class_weights_int_targets(self):\n        self._test_class_weights(target_encoding=\"int\")\n\n    def test_class_weights_categorical_targets(self):\n        self._test_class_weights(target_encoding=\"categorical\")\n\n    def test_num_batches(self):\n        dataset = tf.data.Dataset.range(42)\n        cardinality = int(dataset.cardinality())\n        self.assertEqual(cardinality, 42)\n        adapter = tf_dataset_adapter.TFDatasetAdapter(dataset)\n        self.assertEqual(adapter.num_batches, 42)\n\n        # Test for Infinite Cardinality\n        dataset = tf.data.Dataset.range(42)\n        dataset = dataset.repeat()\n        cardinality = int(dataset.cardinality())\n        self.assertEqual(cardinality, tf.data.INFINITE_CARDINALITY)\n        adapter = tf_dataset_adapter.TFDatasetAdapter(dataset)\n        self.assertIsNone(adapter.num_batches)\n\n        # Test for Unknown Cardinality\n        dataset = dataset.filter(lambda x: True)\n        cardinality = int(dataset.cardinality())\n        self.assertEqual(cardinality, tf.data.UNKNOWN_CARDINALITY)\n        adapter = 
tf_dataset_adapter.TFDatasetAdapter(dataset)\n        self.assertIsNone(adapter.num_batches)\n\n    def test_invalid_dataset_type(self):\n        with self.assertRaisesRegex(\n            ValueError, \"Expected argument `dataset` to be a tf.data.Dataset\"\n        ):\n            invalid_data = \"This is not a tf.data.Dataset\"\n            tf_dataset_adapter.TFDatasetAdapter(invalid_data)\n\n    def test_class_weight_and_sample_weight_together(self):\n        x = np.random.random((4, 2))\n        y = np.array([[0], [1], [2], [3]], dtype=\"int64\")\n        sw = np.array([0.5, 0.5, 0.5, 0.5])\n        base_ds = tf.data.Dataset.from_tensor_slices((x, y, sw)).batch(16)\n        class_weight = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"You cannot `class_weight` and `sample_weight` at the same time.\",\n        ):\n            tf_dataset_adapter.TFDatasetAdapter(\n                base_ds, class_weight=class_weight\n            )\n\n    def test_different_y_shapes_with_class_weight(self):\n        x = np.random.random((4, 2))\n        y = np.array(\n            [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],\n            dtype=\"float32\",\n        )\n        base_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)\n        class_weight = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}\n        adapter = tf_dataset_adapter.TFDatasetAdapter(\n            base_ds, class_weight=class_weight\n        )\n        gen = adapter.get_numpy_iterator()\n        for batch in gen:\n            _, _, bw = batch\n            self.assertAllClose(bw, [0.1, 0.2, 0.3, 0.4])\n\n        y_sparse = np.array([0, 1, 2, 3], dtype=\"int64\")\n        base_ds = tf.data.Dataset.from_tensor_slices((x, y_sparse)).batch(16)\n        adapter = tf_dataset_adapter.TFDatasetAdapter(\n            base_ds, class_weight=class_weight\n        )\n        gen = adapter.get_numpy_iterator()\n        for batch in gen:\n            _, _, bw = 
batch\n            self.assertAllClose(bw, [0.1, 0.2, 0.3, 0.4])\n\n    def test_nested_y_with_class_weight(self):\n        x = np.random.random((4, 2))\n\n        # Define two target outputs, y1 and y2, for the dataset\n        y1 = np.array([0, 1, 2, 3], dtype=\"int64\")\n        y2 = np.array([0, 1, 2, 3], dtype=\"int64\")\n\n        # Create a tf.data Dataset from the input data and two target outputs\n        base_ds = tf.data.Dataset.from_tensor_slices((x, (y1, y2))).batch(16)\n\n        # Define class weights for potential classes in the output\n        class_weight = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`class_weight` is only supported for Models with a single output.\",\n        ):\n            tf_dataset_adapter.TFDatasetAdapter(\n                base_ds, class_weight=class_weight\n            )\n\n    def test_class_weights_map_fn_with_sample_weight(self):\n        class_weight = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}\n        class_weights_map_fn = tf_dataset_adapter.make_class_weight_map_fn(\n            class_weight\n        )\n\n        x = np.array([[0.5, 0.5], [0.5, 0.5]])\n        y = np.array([[1, 0], [0, 1]])\n        sw = np.array([1.0, 1.0])\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"You cannot `class_weight` and `sample_weight` at the same time.\",\n        ):\n            class_weights_map_fn(x, y, sw)\n\n    def test_class_weights_map_fn_nested_y(self):\n        class_weight = {0: 0.1, 1: 0.2, 2: 0.3, 3: 0.4}\n        class_weights_map_fn = tf_dataset_adapter.make_class_weight_map_fn(\n            class_weight\n        )\n\n        x = np.array([[0.5, 0.5]])\n        y1 = np.array([1])\n        y2 = np.array([0])\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`class_weight` is only supported for Models with a single output.\",\n        ):\n            class_weights_map_fn(x, (y1, y2))\n\n    def 
test_distribute_dataset(self):\n        x = tf.random.normal((34, 4))\n        y = tf.random.normal((34, 2))\n        base_ds = tf.data.Dataset.from_tensor_slices((x, y)).batch(16)\n\n        data_distribution = mock.Mock()\n        # Mimic that there are 2 workers, and each worker will get a batch\n        # size of 8.\n        data_distribution.distribute_dataset = mock.MagicMock(\n            return_value=base_ds.rebatch(8).shard(2, index=0)\n        )\n\n        adapter = tf_dataset_adapter.TFDatasetAdapter(\n            base_ds, distribution=data_distribution\n        )\n\n        self.assertEqual(adapter.num_batches, None)\n        self.assertEqual(adapter.batch_size, None)\n        self.assertEqual(adapter.has_partial_batch, None)\n        self.assertEqual(adapter.partial_batch_size, None)\n\n        gen = adapter.get_numpy_iterator()\n        for i, batch in enumerate(gen):\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertIsInstance(bx, np.ndarray)\n            self.assertIsInstance(by, np.ndarray)\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertEqual(bx.dtype, \"float32\")\n            if i < 2:\n                self.assertEqual(bx.shape, (8, 4))\n                self.assertEqual(by.shape, (8, 2))\n            else:\n                self.assertEqual(bx.shape, (2, 4))\n                self.assertEqual(by.shape, (2, 2))\n        ds = adapter.get_tf_dataset()\n        for i, batch in enumerate(ds):\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertIsInstance(bx, tf.Tensor)\n            self.assertIsInstance(by, tf.Tensor)\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertEqual(bx.dtype, \"float32\")\n            if i < 2:\n                self.assertEqual(tuple(bx.shape), (8, 4))\n                self.assertEqual(tuple(by.shape), (8, 2))\n            else:\n                self.assertEqual(tuple(bx.shape), (2, 
4))\n                self.assertEqual(tuple(by.shape), (2, 2))\n\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS and backend.backend() != \"numpy\",\n        reason=\"Backend does not support sparse tensors\",\n    )\n    def test_tf_sparse_tensors(self):\n        x = tf.SparseTensor(\n            indices=[[0, 0], [1, 2]], values=[1.0, 2.0], dense_shape=(2, 4)\n        )\n        y = tf.SparseTensor(\n            indices=[[0, 0], [1, 1]], values=[3.0, 4.0], dense_shape=(2, 2)\n        )\n        base_ds = tf.data.Dataset.from_tensors((x, y))\n        adapter = tf_dataset_adapter.TFDatasetAdapter(base_ds)\n\n        if backend.backend() == \"numpy\":\n            it = adapter.get_numpy_iterator()\n            expected_class = np.ndarray\n        elif backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            expected_class = tf.SparseTensor\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n            expected_class = jax.experimental.sparse.BCOO\n\n        for batch in it:\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertIsInstance(bx, expected_class)\n            self.assertIsInstance(by, expected_class)\n            self.assertEqual(bx.shape, (2, 4))\n            self.assertEqual(by.shape, (2, 2))\n\n    def test_distributed_datasets_from_function_adapter_properties(self):\n        strategy = tf.distribute.MirroredStrategy([\"CPU:0\"])\n\n        def dataset_fn(input_context):\n            batch_size = input_context.get_per_replica_batch_size(\n                global_batch_size=2\n            )\n            x = tf.random.uniform((32, 4))\n            y = tf.random.uniform((32, 2))\n            return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)\n\n        dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)\n        adapter = tf_dataset_adapter.TFDatasetAdapter(dist_dataset)\n        
self.assertEqual(adapter.num_batches, 16)\n        self.assertIsNone(adapter.batch_size)\n        self.assertIsNone(adapter.has_partial_batch)\n        self.assertIsNone(adapter.partial_batch_size)\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            expected_class = tf.Tensor\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n            expected_class = np.ndarray\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n            expected_class = torch.Tensor\n        else:\n            it = adapter.get_numpy_iterator()\n            expected_class = np.ndarray\n\n        batch_count = 0\n        for batch in it:\n            batch_count += 1\n            self.assertEqual(len(batch), 2)\n            data, labels = batch\n            self.assertIsInstance(data, expected_class)\n            self.assertIsInstance(labels, expected_class)\n            self.assertEqual(data.shape, (2, 4))\n            self.assertEqual(labels.shape, (2, 2))\n\n        self.assertEqual(batch_count, 16)\n\n    @pytest.mark.requires_trainable_backend\n    def test_distributed_datasets_from_function_model_integration(self):\n        strategy = tf.distribute.MirroredStrategy([\"CPU:0\"])\n\n        def dataset_fn(input_context):\n            batch_size = input_context.get_per_replica_batch_size(\n                global_batch_size=2\n            )\n            x = tf.random.uniform((4, 1))\n            y = tf.random.uniform((4, 2))\n            return tf.data.Dataset.from_tensor_slices((x, y)).batch(batch_size)\n\n        dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)\n\n        model = Sequential([layers.Dense(2, input_shape=(1,))])\n        model.compile(optimizer=\"adam\", loss=\"mse\")\n        history = model.fit(dist_dataset, epochs=1)\n        self.assertIn(\"loss\", history.history)\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/torch_data_loader_adapter.py",
    "content": "import itertools\n\nimport numpy as np\n\nfrom keras.src import tree\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.trainers.data_adapters.data_adapter import DataAdapter\n\n\nclass TorchDataLoaderAdapter(DataAdapter):\n    \"\"\"Adapter that handles `torch.utils.data.DataLoader`.\"\"\"\n\n    def __init__(self, dataloader):\n        import torch\n\n        if not isinstance(dataloader, torch.utils.data.DataLoader):\n            raise ValueError(\n                f\"Expected argument `dataloader` to be an instance of\"\n                f\"`torch.utils.data.DataLoader`. Received: {dataloader}\"\n            )\n\n        self._dataloader = dataloader\n        self._output_signature = None\n        self._batch_size = dataloader.batch_size\n        self._num_batches = None\n        self._partial_batch_size = None\n        if hasattr(dataloader.dataset, \"__len__\"):\n            self._num_batches = len(dataloader)\n            if self._batch_size is not None:\n                self._partial_batch_size = (\n                    len(dataloader.dataset) % self._batch_size\n                )\n\n    def get_numpy_iterator(self):\n        for batch in self._dataloader:\n            # shared memory using `np.asarray`\n            yield tuple(\n                tree.map_structure(\n                    lambda x: np.asarray(x.cpu()), batch, none_is_leaf=False\n                )\n            )\n\n    def get_jax_iterator(self):\n        # We use numpy as an intermediary because it is faster.\n        return self.get_numpy_iterator()\n\n    def get_tf_dataset(self):\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        if self._output_signature is None:\n            batches = list(\n                itertools.islice(\n                    self._dataloader,\n                    data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC,\n                )\n            )\n            self._output_signature = tuple(\n           
     data_adapter_utils.get_tensor_spec(batches)\n            )\n        return tf.data.Dataset.from_generator(\n            self.get_numpy_iterator,\n            output_signature=self._output_signature,\n        )\n\n    def get_torch_dataloader(self):\n        return self._dataloader\n\n    @property\n    def num_batches(self):\n        return self._num_batches\n\n    @property\n    def batch_size(self):\n        return self._batch_size\n\n    @property\n    def has_partial_batch(self):\n        if self._partial_batch_size:\n            return self._partial_batch_size > 0\n        else:\n            return None\n\n    @property\n    def partial_batch_size(self):\n        return self._partial_batch_size\n"
  },
  {
    "path": "keras/src/trainers/data_adapters/torch_data_loader_adapter_test.py",
    "content": "import math\n\nimport numpy as np\nimport tensorflow as tf\nimport torch\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.testing.test_utils import named_product\nfrom keras.src.trainers.data_adapters.torch_data_loader_adapter import (\n    TorchDataLoaderAdapter,\n)\n\n\nclass TestTorchDataLoaderAdapter(testing.TestCase):\n    def test_basic_dataloader(self):\n        x = torch.normal(2, 3, size=(34, 4))\n        y = torch.normal(1, 3, size=(34, 2))\n        ds = torch.utils.data.TensorDataset(x, y)\n        dataloader = torch.utils.data.DataLoader(ds, batch_size=16)\n        adapter = TorchDataLoaderAdapter(dataloader)\n\n        self.assertEqual(adapter.num_batches, 3)\n        self.assertEqual(adapter.batch_size, 16)\n        self.assertEqual(adapter.has_partial_batch, True)\n        self.assertEqual(adapter.partial_batch_size, 2)\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            expected_class = tf.Tensor\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n            expected_class = np.ndarray\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n            expected_class = torch.Tensor\n        else:\n            it = adapter.get_numpy_iterator()\n            expected_class = np.ndarray\n\n        for i, batch in enumerate(it):\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertIsInstance(bx, expected_class)\n            self.assertIsInstance(by, expected_class)\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertContainsExactSubsequence(str(bx.dtype), \"float32\")\n            if i < 2:\n                self.assertEqual(bx.shape, (16, 4))\n                self.assertEqual(by.shape, (16, 2))\n            else:\n                self.assertEqual(bx.shape, (2, 4))\n        
        self.assertEqual(by.shape, (2, 2))\n\n    @parameterized.named_parameters(\n        named_product(batch_size=[None, 3], implements_len=[True, False])\n    )\n    def test_dataloader_iterable_dataset(self, batch_size, implements_len):\n        class TestIterableDataset(torch.utils.data.IterableDataset):\n            def __init__(self):\n                self.x = torch.normal(2, 3, size=(16, 4))\n                self.y = torch.normal(1, 3, size=(16, 2))\n\n            def __iter__(self):\n                for _ in range(10):\n                    yield (self.x, self.y)\n\n        class TestIterableDatasetWithLen(TestIterableDataset):\n            def __len__(self):\n                return 10\n\n        ds = (\n            TestIterableDatasetWithLen()\n            if implements_len\n            else TestIterableDataset()\n        )\n        dataloader = torch.utils.data.DataLoader(ds, batch_size=batch_size)\n        adapter = TorchDataLoaderAdapter(dataloader)\n\n        if implements_len and batch_size:\n            self.assertEqual(adapter.num_batches, math.ceil(10 / batch_size))\n            self.assertEqual(adapter.batch_size, batch_size)\n            self.assertEqual(adapter.has_partial_batch, True)\n            self.assertEqual(adapter.partial_batch_size, 10 % batch_size)\n        elif implements_len:\n            self.assertEqual(adapter.num_batches, 10)\n            self.assertEqual(adapter.batch_size, None)\n            self.assertEqual(adapter.has_partial_batch, None)\n            self.assertEqual(adapter.partial_batch_size, None)\n        else:\n            self.assertIsNone(adapter.num_batches)\n            self.assertEqual(adapter.batch_size, batch_size)\n            self.assertIsNone(adapter.has_partial_batch)\n            self.assertIsNone(adapter.partial_batch_size)\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n            expected_class = tf.Tensor\n        elif backend.backend() == \"jax\":\n      
      it = adapter.get_jax_iterator()\n            expected_class = np.ndarray\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n            expected_class = torch.Tensor\n        else:\n            it = adapter.get_numpy_iterator()\n            expected_class = np.ndarray\n\n        batch_count = 0\n        for i, batch in enumerate(it):\n            batch_count += 1\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertIsInstance(bx, expected_class)\n            self.assertIsInstance(by, expected_class)\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertContainsExactSubsequence(str(bx.dtype), \"float32\")\n            if batch_size:\n                if i < 3:\n                    self.assertEqual(bx.shape, (batch_size, 16, 4))\n                    self.assertEqual(by.shape, (batch_size, 16, 2))\n                else:\n                    self.assertEqual(bx.shape, (10 % batch_size, 16, 4))\n                    self.assertEqual(by.shape, (10 % batch_size, 16, 2))\n            else:\n                self.assertEqual(bx.shape, (16, 4))\n                self.assertEqual(by.shape, (16, 2))\n\n        if batch_size:\n            self.assertEqual(batch_count, math.ceil(10 / batch_size))\n        else:\n            self.assertEqual(batch_count, 10)\n\n    def test_with_different_shapes(self):\n        x = (\n            [np.ones([4], \"float32\")] * 16\n            + [np.ones([5], \"float32\")] * 16\n            + [np.ones([6], \"float32\")] * 2\n        )\n        y = np.ones((34, 2), \"float32\")\n        ds = torch.utils.data.StackDataset(x, y)\n        dataloader = torch.utils.data.DataLoader(ds, batch_size=16)\n        adapter = TorchDataLoaderAdapter(dataloader)\n\n        self.assertEqual(adapter.num_batches, 3)\n        self.assertEqual(adapter.batch_size, 16)\n        self.assertEqual(adapter.has_partial_batch, True)\n        
self.assertEqual(adapter.partial_batch_size, 2)\n\n        if backend.backend() == \"tensorflow\":\n            it = adapter.get_tf_dataset()\n        elif backend.backend() == \"jax\":\n            it = adapter.get_jax_iterator()\n        elif backend.backend() == \"torch\":\n            it = adapter.get_torch_dataloader()\n        else:\n            it = adapter.get_numpy_iterator()\n\n        for i, batch in enumerate(it):\n            self.assertEqual(len(batch), 2)\n            bx, by = batch\n            self.assertEqual(bx.dtype, by.dtype)\n            self.assertContainsExactSubsequence(str(bx.dtype), \"float32\")\n            if i == 0:\n                self.assertEqual(bx.shape, (16, 4))\n                self.assertEqual(by.shape, (16, 2))\n            elif i == 1:\n                self.assertEqual(bx.shape, (16, 5))\n                self.assertEqual(by.shape, (16, 2))\n            else:\n                self.assertEqual(bx.shape, (2, 6))\n                self.assertEqual(by.shape, (2, 2))\n"
  },
  {
    "path": "keras/src/trainers/epoch_iterator.py",
    "content": "\"\"\"\nSeparation of concerns:\n\nDataAdapter:\n    - x, y\n    - sample_weight\n    - class_weight\n    - shuffle\n    - batch_size\n        - steps, as it relates to batch_size for array data\n\nEpochIterator:\n    - whether to yield numpy or tf data\n    - steps\n    - most argument validation\n\nTrainer:\n    - steps_per_execution\n    - validation_split\n    - validation_data\n    - callbacks\n    - validation_freq\n    - epochs\n    - initial_epoch\n    - any backend-specific concern such as distribution\n\nPyDataset:\n    - num_workers\n    - use_multiprocessing\n    - max_queue_size\n\nEpochIterator steps:\n\n1. Look at data type and select correct DataHandler\n2. Instantiate DataHandler with correct arguments\n3. Raise or warn on unused arguments\n4. in __iter__, iterate, either for a fixed number of steps\nor until there is no data\n\n\"\"\"\n\nimport contextlib\nimport warnings\n\nfrom keras.src.backend import config\nfrom keras.src.trainers import data_adapters\n\n\nclass EpochIterator:\n    def __init__(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        batch_size=None,\n        steps_per_epoch=None,\n        shuffle=False,\n        class_weight=None,\n        steps_per_execution=1,\n    ):\n        # Possibly cap steps_per_epoch for debugging runs.\n        max_steps_per_epoch = config.max_steps_per_epoch()\n        if max_steps_per_epoch:\n            if not steps_per_epoch or max_steps_per_epoch < steps_per_epoch:\n                warnings.warn(\n                    \"Limiting steps_per_epoch to %d\" % max_steps_per_epoch\n                )\n                steps_per_epoch = max_steps_per_epoch\n        self.steps_per_epoch = steps_per_epoch\n        self.steps_per_execution = steps_per_execution\n        self._current_iterator = None\n        self._epoch_iterator = None\n        self._steps_seen = 0\n        self.data_adapter = data_adapters.get_data_adapter(\n            x=x,\n            y=y,\n     
       sample_weight=sample_weight,\n            batch_size=batch_size,\n            steps_per_epoch=steps_per_epoch,\n            shuffle=shuffle,\n            class_weight=class_weight,\n        )\n        self._num_batches = self.data_adapter.num_batches\n\n    def _get_iterator(self):\n        return self.data_adapter.get_numpy_iterator()\n\n    def _interrupted_warning(self):\n        warnings.warn(\n            \"Your input ran out of data; interrupting training. \"\n            \"Make sure that your dataset or generator can generate \"\n            \"at least `steps_per_epoch * epochs` batches. \"\n            \"You may need to use the `.repeat()` \"\n            \"function when building your dataset.\",\n            stacklevel=2,\n        )\n\n    def reset(self):\n        self._current_iterator = None\n        self._num_batches = self.data_adapter.num_batches\n        self._steps_seen = 0\n        self._epoch_iterator = None\n        self.data_adapter.on_epoch_end()\n\n    def _enumerate_iterator(self):\n        self.data_adapter.on_epoch_begin()\n        steps_per_epoch = self.steps_per_epoch or self._num_batches or -1\n\n        if steps_per_epoch > 0:\n            if self._current_iterator is None or self.steps_per_epoch is None:\n                self._current_iterator = iter(self._get_iterator())\n                self._steps_seen = 0\n            for step in range(0, steps_per_epoch, self.steps_per_execution):\n                if self._num_batches and self._steps_seen >= self._num_batches:\n                    if self.steps_per_epoch:\n                        self._interrupted_warning()\n                    break\n                self._steps_seen += self.steps_per_execution\n                yield (\n                    step,\n                    step + self.steps_per_execution - 1,\n                    self._current_iterator,\n                )\n            if self._num_batches and self._steps_seen >= self._num_batches:\n                
self._current_iterator = iter(self._get_iterator())\n                self._steps_seen = 0\n        else:\n            iterator = iter(self._get_iterator())\n            step = -self.steps_per_execution\n            while True:\n                step += self.steps_per_execution\n                self._steps_seen = step + self.steps_per_execution\n                yield step, step + self.steps_per_execution - 1, iterator\n        self.data_adapter.on_epoch_end()\n\n    def __iter__(self):\n        self._epoch_iterator = self._enumerate_iterator()\n        return self\n\n    def __next__(self):\n        buffer = []\n        begin_step, end_step, iterator = next(self._epoch_iterator)\n        with self.catch_stop_iteration():\n            for _ in range(self.steps_per_execution):\n                data = next(iterator)\n                buffer.append(data)\n            return begin_step, end_step, buffer\n        if buffer:\n            return begin_step, end_step, buffer\n        raise StopIteration\n\n    def enumerate_epoch(self):\n        for begin_step, end_step, data in self:\n            yield begin_step, end_step, data\n\n    @contextlib.contextmanager\n    def catch_stop_iteration(self):\n        \"\"\"Catches errors when an iterator runs out of data.\"\"\"\n        try:\n            yield\n        except StopIteration:\n            if self._num_batches is None:\n                self._num_batches = self._steps_seen\n            self._interrupted_warning()\n            self._current_iterator = None\n            self.data_adapter.on_epoch_end()\n\n    @property\n    def num_batches(self):\n        if self.steps_per_epoch:\n            return self.steps_per_epoch\n        # Either copied from the data_adapter, or\n        # inferred at the end of an iteration.\n        return self._num_batches\n"
  },
  {
    "path": "keras/src/trainers/epoch_iterator_test.py",
    "content": "import numpy as np\nimport pytest\nimport tensorflow as tf\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.trainers import data_adapters\nfrom keras.src.trainers import epoch_iterator\n\n\nclass TestEpochIterator(testing.TestCase):\n    @parameterized.named_parameters(\n        [(\"iterator\", \"iterator\"), (\"enumerate_epoch\", \"enumerate_epoch\")]\n    )\n    def test_basic_flow(self, call_type):\n        x = np.random.random((100, 16))\n        y = np.random.random((100, 4))\n        sample_weight = np.random.random((100,))\n        batch_size = 16\n        shuffle = True\n        iterator = epoch_iterator.EpochIterator(\n            x=x,\n            y=y,\n            sample_weight=sample_weight,\n            batch_size=batch_size,\n            shuffle=shuffle,\n        )\n        steps_seen = []\n        if call_type == \"iterator\":\n            generator = iterator\n        else:\n            generator = iterator.enumerate_epoch()\n        for begin_step, end_step, batch in generator:\n            batch = batch[0]\n            steps_seen.append(begin_step)\n            self.assertEqual(begin_step, end_step)\n            self.assertEqual(len(batch), 3)\n            self.assertIsInstance(batch[0], np.ndarray)\n        self.assertEqual(steps_seen, [0, 1, 2, 3, 4, 5, 6])\n\n    def test_insufficient_data(self):\n        batch_size = 8\n        steps_per_epoch = 6\n        dataset_size = batch_size * (steps_per_epoch - 2)\n        x = np.arange(dataset_size).reshape((dataset_size, 1))\n        y = x * 2\n        iterator = epoch_iterator.EpochIterator(\n            x=x,\n            y=y,\n            batch_size=batch_size,\n            steps_per_epoch=steps_per_epoch,\n        )\n        steps_seen = []\n        with pytest.warns(match=\"Your input ran out of data\"):\n            for step, _, _ in iterator:\n                steps_seen.append(step)\n        
self.assertLen(steps_seen, steps_per_epoch - 2)\n\n        self.assertIsInstance(iterator, epoch_iterator.EpochIterator)\n\n    def test_unsupported_y_arg_tfdata(self):\n        with self.assertRaisesRegex(ValueError, \"`y` should not be passed\"):\n            x = tf.data.Dataset.from_tensor_slices(np.random.random((100, 16)))\n            y = np.random.random((100, 4))\n            _ = epoch_iterator.EpochIterator(x=x, y=y)\n\n    def test_unsupported_sample_weights_arg_tfdata(self):\n        with self.assertRaisesRegex(\n            ValueError, \"`sample_weights` should not be passed\"\n        ):\n            x = tf.data.Dataset.from_tensor_slices(np.random.random((100, 16)))\n            sample_weights = np.random.random((100,))\n            _ = epoch_iterator.EpochIterator(x=x, sample_weight=sample_weights)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"torch\", reason=\"Need to import torch\"\n    )\n    def test_torch_dataloader(self):\n        import torch\n\n        class ExampleTorchDataset(torch.utils.data.Dataset):\n            def __init__(self, x, y):\n                self.x = x\n                self.y = y\n\n            def __len__(self):\n                return len(self.x)\n\n            def __getitem__(self, idx):\n                return self.x[idx], self.y[idx]\n\n        torch_dataset = ExampleTorchDataset(\n            np.random.random((64, 2)), np.random.random((64, 1))\n        )\n        torch_dataloader = torch.utils.data.DataLoader(\n            torch_dataset, batch_size=8, shuffle=True\n        )\n        iterator = epoch_iterator.EpochIterator(torch_dataloader)\n        for _, _, batch in iterator:\n            batch = batch[0]\n            self.assertEqual(batch[0].shape, (8, 2))\n            self.assertEqual(batch[1].shape, (8, 1))\n\n    @pytest.mark.skipif(\n        backend.backend() != \"torch\", reason=\"Need to import torch\"\n    )\n    def test_unsupported_y_arg_torch_dataloader(self):\n        import torch\n\n      
  class ExampleTorchDataset(torch.utils.data.Dataset):\n            def __init__(self, x, y):\n                self.x = x\n                self.y = y\n\n            def __len__(self):\n                return len(self.x)\n\n            def __getitem__(self, idx):\n                return self.x[idx], self.y[idx]\n\n        torch_dataset = ExampleTorchDataset(\n            np.random.random((100, 16)), np.random.random((100, 4))\n        )\n        x = torch.utils.data.DataLoader(\n            torch_dataset, batch_size=8, shuffle=True\n        )\n        y = np.random.random((100, 4))\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"When providing `x` as a torch DataLoader, `y` should not\",\n        ):\n            _ = epoch_iterator.EpochIterator(x=x, y=y)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"torch\", reason=\"Need to import torch\"\n    )\n    def test_unsupported_sample_weights_arg_torch_dataloader(self):\n        import torch\n\n        class ExampleTorchDataset(torch.utils.data.Dataset):\n            def __init__(self, x, y):\n                self.x = x\n                self.y = y\n\n            def __len__(self):\n                return len(self.x)\n\n            def __getitem__(self, idx):\n                return self.x[idx], self.y[idx]\n\n        torch_dataset = ExampleTorchDataset(\n            np.random.random((100, 16)), np.random.random((100, 4))\n        )\n        x = torch.utils.data.DataLoader(\n            torch_dataset, batch_size=8, shuffle=True\n        )\n        sample_weights = np.random.random((100,))\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"When providing `x` as a torch DataLoader, `sample_weights`\",\n        ):\n            _ = epoch_iterator.EpochIterator(x=x, sample_weight=sample_weights)\n\n    def test_python_generator_input(self):\n        def generator_example():\n            for i in range(100):\n                yield (np.array([i]), 
np.array([i * 2]))\n\n        x = generator_example()\n        epoch_iter = epoch_iterator.EpochIterator(x=x)\n        self.assertIsInstance(\n            epoch_iter.data_adapter,\n            data_adapters.GeneratorDataAdapter,\n        )\n\n    def test_unrecognized_data_type(self):\n        x = \"unsupported_data\"\n        with self.assertRaisesRegex(ValueError, \"Unrecognized data type\"):\n            _ = epoch_iterator.EpochIterator(x=x)\n\n    @parameterized.named_parameters(\n        [\n            {\"testcase_name\": \"infinite\", \"infinite\": True},\n            {\"testcase_name\": \"finite\", \"infinite\": False},\n        ]\n    )\n    def test_epoch_callbacks(self, infinite):\n        class TestPyDataset(data_adapters.py_dataset_adapter.PyDataset):\n            def __init__(\n                self,\n                workers=1,\n                use_multiprocessing=False,\n                max_queue_size=10,\n                infinite=False,\n            ):\n                super().__init__(workers, use_multiprocessing, max_queue_size)\n                self.data = np.random.rand(64, 2)\n                self.batch_size = 16\n                self.infinite = infinite\n\n                # check that callbacks are called in the correct order\n                self.tracker = []\n\n            @property\n            def num_batches(self):\n                if self.infinite:\n                    return None\n                return len(self.data) // self.batch_size\n\n            def on_epoch_begin(self):\n                self.tracker.append(1)\n\n            def __getitem__(self, index):\n                idx = index % 2\n                return self.data[\n                    idx * self.batch_size : (idx + 1) * self.batch_size\n                ]\n\n            def on_epoch_end(self):\n                self.tracker.append(2)\n\n        ds = TestPyDataset(infinite=infinite)\n        epoch_iter = epoch_iterator.EpochIterator(x=ds, steps_per_epoch=10)\n\n        
num_epochs = 5\n        for epoch in range(num_epochs):\n            for _, _, _ in epoch_iter:\n                pass\n\n        self.assertAllEqual(ds.tracker, [1, 2] * num_epochs)\n"
  },
  {
    "path": "keras/src/trainers/trainer.py",
    "content": "import inspect\nimport platform\nimport warnings\n\nfrom keras.src import backend\nfrom keras.src import metrics as metrics_module\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src import tree\nfrom keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer\nfrom keras.src.saving import serialization_lib\nfrom keras.src.trainers.compile_utils import CompileLoss\nfrom keras.src.trainers.compile_utils import CompileMetrics\nfrom keras.src.trainers.data_adapters import data_adapter_utils\nfrom keras.src.utils import traceback_utils\nfrom keras.src.utils import tracking\n\n\nclass Trainer:\n    def __init__(self):\n        self._lock = False\n        self._run_eagerly = False\n        self._jit_compile = None\n        self.compiled = False\n        self.loss = None\n        self.steps_per_execution = 1\n        # Can be set by callbacks in on_train_begin\n        self._initial_epoch = None\n        self._compute_loss_has_training_arg = (\n            \"training\" in inspect.signature(self.compute_loss).parameters\n        )\n\n        # Placeholders used in `compile`\n        self._compile_loss = None\n        self._compile_metrics = None\n        self._loss_tracker = None\n\n    @traceback_utils.filter_traceback\n    @tracking.no_automatic_dependency_tracking\n    def compile(\n        self,\n        optimizer=\"rmsprop\",\n        loss=None,\n        loss_weights=None,\n        metrics=None,\n        weighted_metrics=None,\n        run_eagerly=False,\n        steps_per_execution=1,\n        jit_compile=\"auto\",\n        auto_scale_loss=True,\n    ):\n        \"\"\"Configures the model for training.\n\n        Example:\n\n        ```python\n        model.compile(\n            optimizer=keras.optimizers.Adam(learning_rate=1e-3),\n            loss=keras.losses.BinaryCrossentropy(),\n            metrics=[\n                keras.metrics.BinaryAccuracy(),\n                keras.metrics.FalseNegatives(),\n            ],\n 
       )\n        ```\n\n        Args:\n            optimizer: String (name of optimizer) or optimizer instance. See\n                `keras.optimizers`.\n            loss: Loss function. May be a string (name of loss function), or\n                a `keras.losses.Loss` instance. See `keras.losses`. A\n                loss function is any callable with the signature\n                `loss = fn(y_true, y_pred)`, where `y_true` are the ground truth\n                values, and `y_pred` are the model's predictions.\n                `y_true` should have shape `(batch_size, d0, .. dN)`\n                (except in the case of sparse loss functions such as\n                sparse categorical crossentropy which expects integer arrays of\n                shape `(batch_size, d0, .. dN-1)`).\n                `y_pred` should have shape `(batch_size, d0, .. dN)`.\n                The loss function should return a float tensor.\n            loss_weights: Optional list or dictionary specifying scalar\n                coefficients (Python floats) to weight the loss contributions of\n                different model outputs. The loss value that will be minimized\n                by the model will then be the *weighted sum* of all individual\n                losses, weighted by the `loss_weights` coefficients. If a list,\n                it is expected to have a 1:1 mapping to the model's outputs. If\n                a dict, it is expected to map output names (strings) to scalar\n                coefficients.\n            metrics: List of metrics to be evaluated by the model during\n                training and testing. Each of these can be a string (name of a\n                built-in function), a function, or a `keras.metrics.Metric`\n                instance. See `keras.metrics`. Typically you will use\n                `metrics=['accuracy']`. A function is any callable with the\n                signature `result = fn(y_true, y_pred)`. 
To specify different\n                metrics for different outputs of a multi-output model, you could\n                also pass a dictionary, such as\n                `metrics={'a':'accuracy', 'b':['accuracy', 'mse']}`.\n                You can also pass a list to specify a metric or a list of\n                metrics for each output, such as\n                `metrics=[['accuracy'], ['accuracy', 'mse']]`\n                or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass\n                the strings 'accuracy' or 'acc', we convert this to one of\n                `keras.metrics.BinaryAccuracy`,\n                `keras.metrics.CategoricalAccuracy`,\n                `keras.metrics.SparseCategoricalAccuracy` based on the\n                shapes of the targets and of the model output. A similar\n                conversion is done for the strings `\"crossentropy\"`\n                and `\"ce\"` as well.\n                The metrics passed here are evaluated without sample weighting;\n                if you would like sample weighting to apply, you can specify\n                your metrics via the `weighted_metrics` argument instead.\n            weighted_metrics: List of metrics to be evaluated and weighted by\n                `sample_weight` or `class_weight` during training and testing.\n            run_eagerly: Bool. If `True`, this model's forward pass\n                 will never be compiled. It is recommended to leave this\n                 as `False` when training (for best performance),\n                 and to set it to `True` when debugging.\n            steps_per_execution: Int. The number of batches to run\n                during a single compiled function call. Running multiple\n                batches inside a single compiled function call can\n                greatly improve performance on TPUs or small models with a large\n                Python overhead. At most, one full epoch will be run during each\n                execution. 
If a number larger than the size of the epoch is\n                passed, the execution will be truncated to the size of the\n                epoch. Note that if `steps_per_execution` is set to `N`,\n                `Callback.on_batch_begin` and `Callback.on_batch_end` methods\n                will only be called every `N` batches (i.e. before/after\n                each compiled function execution).\n                Not supported with the PyTorch backend.\n            jit_compile: Bool or `\"auto\"`. Whether to use XLA compilation when\n                compiling a model. For `jax` and `tensorflow` backends,\n                `jit_compile=\"auto\"` enables XLA compilation if the model\n                supports it, and disables it otherwise.\n                For the `torch` backend, `\"auto\"` will default to eager\n                execution and `jit_compile=True` will run with `torch.compile`\n                using the `\"inductor\"` backend.\n            auto_scale_loss: Bool. If `True` and the model dtype policy is\n                `\"mixed_float16\"`, the passed optimizer will be automatically\n                wrapped in a `LossScaleOptimizer`, which will dynamically\n                scale the loss to prevent underflow.\n\n        Note:\n            Trainable variables are determined at `compile()` time. 
If you\n            modify the `trainable` property of a layer after calling\n            `compile()`, those changes will not take effect during `fit()`\n            unless `compile()` is called again.\n\n            Recommended workflow when changing trainable variables:\n            ```python\n            # Initial training with some layers\n            model.compile(optimizer=\"adam\", loss=\"mse\")\n            model.fit(x_train, y_train)\n\n            # Change trainable flags\n            layer.trainable = False  # or True\n\n            # Recompile for the change to take effect\n            model.compile(optimizer=\"adam\", loss=\"mse\")\n            model.fit(x_train, y_train)\n            ```\n\n            This behavior applies to all Keras backends and is also documented\n            in the transfer learning guide.\n        \"\"\"\n        optimizer = optimizers.get(optimizer)\n        self.optimizer = optimizer\n        if (\n            auto_scale_loss\n            and self.dtype_policy.name == \"mixed_float16\"\n            and self.optimizer\n            and not isinstance(self.optimizer, LossScaleOptimizer)\n        ):\n            self.optimizer = LossScaleOptimizer(\n                self.optimizer, name=\"loss_scale_optimizer\"\n            )\n        if hasattr(self, \"output_names\"):\n            output_names = self.output_names\n        else:\n            output_names = None\n        if loss is not None:\n            self._compile_loss = CompileLoss(\n                loss, loss_weights, output_names=output_names\n            )\n            self.loss = loss\n        if metrics is not None or weighted_metrics is not None:\n            self._compile_metrics = CompileMetrics(\n                metrics, weighted_metrics, output_names=output_names\n            )\n        if jit_compile == \"auto\":\n            if run_eagerly:\n                jit_compile = False\n            else:\n                jit_compile = self._resolve_auto_jit_compile()\n     
   if jit_compile and run_eagerly:\n            jit_compile = False\n            warnings.warn(\n                \"If `run_eagerly` is True, then `jit_compile` \"\n                \"cannot also be True. Disabling `jit_compile`.\",\n                stacklevel=2,\n            )\n\n        self.jit_compile = jit_compile\n        self.run_eagerly = run_eagerly\n        self.stop_training = False\n        self.compiled = True\n        self._loss_tracker = metrics_module.Mean(name=\"loss\")\n        self.steps_per_execution = steps_per_execution\n\n        self.train_function = None\n        self.test_function = None\n        self.predict_function = None\n\n        self._compile_config = serialization_lib.SerializableDict(\n            optimizer=optimizer,\n            loss=loss,\n            loss_weights=loss_weights,\n            metrics=metrics,\n            weighted_metrics=weighted_metrics,\n            run_eagerly=run_eagerly,\n            steps_per_execution=steps_per_execution,\n            jit_compile=jit_compile,\n        )\n\n    @property\n    def jit_compile(self):\n        if self._jit_compile is None:\n            # Value was never set. Resolve it now.\n            self._jit_compile = self._resolve_auto_jit_compile()\n        return self._jit_compile\n\n    @jit_compile.setter\n    def jit_compile(self, value):\n        if value and not model_supports_jit(self):\n            warnings.warn(\n                \"Model doesn't support `jit_compile=True`. 
\"\n                \"Proceeding with `jit_compile=False`.\"\n            )\n            self._jit_compile = False\n        else:\n            self._jit_compile = value\n\n    def _resolve_auto_jit_compile(self):\n        if backend.backend() == \"torch\":\n            # jit_compile = \"auto\" with the pytorch backend defaults to eager\n            return False\n\n        if backend.backend() == \"tensorflow\":\n            import tensorflow as tf\n\n            devices = tf.config.list_physical_devices()\n            if not list(filter(lambda x: x.device_type != \"CPU\", devices)):\n                # Disable XLA on CPU-only machines.\n                return False\n\n            if self._distribute_strategy:\n                # Disable XLA with tf.distribute\n                return False\n\n        if model_supports_jit(self):\n            return True\n        return False\n\n    @property\n    def run_eagerly(self):\n        return self._run_eagerly\n\n    @run_eagerly.setter\n    def run_eagerly(self, value):\n        self._run_eagerly = value\n\n    @property\n    def metrics(self):\n        # Order: loss tracker, individual loss trackers, compiled metrics,\n        # custom metrics, sublayer metrics.\n        metrics = []\n        if self.compiled:\n            if self._loss_tracker is not None:\n                metrics.append(self._loss_tracker)\n            if self._compile_metrics is not None:\n                metrics.append(self._compile_metrics)\n            if self._compile_loss is not None:\n                metrics.extend(self._compile_loss.metrics)\n        metrics.extend(self._metrics)\n        for layer in self._flatten_layers(include_self=False):\n            if isinstance(layer, Trainer):\n                # All Trainer-related metrics in sublayers should be ignored\n                # because a new Trainer has been instantiated.\n                continue\n            metrics.extend(layer.metrics)\n        return metrics\n\n    @property\n    def 
metrics_names(self):\n        return [m.name for m in self.metrics]\n\n    def reset_metrics(self):\n        for m in self.metrics:\n            m.reset_state()\n\n    def _get_own_metrics(self):\n        metrics = []\n        if self._loss_tracker is not None:\n            metrics.append(self._loss_tracker)\n        if self._compile_metrics is not None:\n            metrics.append(self._compile_metrics)\n        if self._compile_loss is not None:\n            metrics.extend(self._compile_loss.metrics)\n        metrics.extend(self._metrics)\n        return metrics\n\n    def compute_loss(\n        self,\n        x=None,\n        y=None,\n        y_pred=None,\n        sample_weight=None,\n        training=True,\n    ):\n        \"\"\"Compute the total loss, validate it, and return it.\n\n        Subclasses can optionally override this method to provide custom loss\n        computation logic.\n\n        Example:\n\n        ```python\n        class MyModel(Model):\n            def __init__(self, *args, **kwargs):\n                super().__init__(*args, **kwargs)\n                self.loss_tracker = metrics.Mean(name='loss')\n\n            def compute_loss(self, x, y, y_pred, sample_weight, training=True):\n                loss = ops.mean((y_pred - y) ** 2)\n                loss += ops.sum(self.losses)\n                self.loss_tracker.update_state(loss)\n                return loss\n\n            def reset_metrics(self):\n                self.loss_tracker.reset_state()\n\n            @property\n            def metrics(self):\n                return [self.loss_tracker]\n\n        inputs = layers.Input(shape=(10,), name='my_input')\n        outputs = layers.Dense(10)(inputs)\n        model = MyModel(inputs, outputs)\n        model.add_loss(ops.sum(outputs))\n\n        optimizer = SGD()\n        model.compile(optimizer, loss='mse', steps_per_execution=10)\n        dataset = ...\n        model.fit(dataset, epochs=2, steps_per_epoch=10)\n        print(f\"Custom loss: 
{model.loss_tracker.result()}\")\n        ```\n\n        Args:\n            x: Input data.\n            y: Target data.\n            y_pred: Predictions returned by the model (output of `model(x)`)\n            sample_weight: Sample weights for weighting the loss function.\n            training: Whether we are training or evaluating the model.\n\n        Returns:\n            The total loss as a scalar tensor, or `None` if no loss results\n            (which is the case when called by `Model.test_step`).\n        \"\"\"\n        # The default implementation does not use `x` or `training`.\n        del x\n        del training\n        losses = []\n        if self._compile_loss is not None:\n            loss = self._compile_loss(y, y_pred, sample_weight)\n            if loss is not None:\n                losses.append(loss)\n        for loss in self.losses:\n            losses.append(self._aggregate_additional_loss(loss))\n        if backend.backend() != \"jax\" and len(losses) == 0:\n            raise ValueError(\n                \"No loss to compute. 
Provide a `loss` argument in `compile()`.\"\n            )\n        if len(losses) == 1:\n            total_loss = losses[0]\n        elif len(losses) == 0:\n            total_loss = ops.zeros(())\n        else:\n            total_loss = ops.sum(losses)\n        return total_loss\n\n    def _compute_loss(\n        self,\n        x=None,\n        y=None,\n        y_pred=None,\n        sample_weight=None,\n        training=True,\n    ):\n        \"\"\"Backwards compatibility wrapper for `compute_loss`.\n\n        This should be used instead of `compute_loss` within `train_step` and\n        `test_step` to support overrides of `compute_loss` that may not have\n        the `training` argument, as this argument was added in Keras 3.3.\n        \"\"\"\n        if self._compute_loss_has_training_arg:\n            return self.compute_loss(\n                x, y, y_pred, sample_weight, training=training\n            )\n        else:\n            return self.compute_loss(x, y, y_pred, sample_weight)\n\n    def _aggregate_additional_loss(self, loss):\n        \"\"\"Aggregates losses from `add_loss`, regularizers and sublayers.\n\n        Args:\n            loss: A tensor representing the additional loss to aggregate.\n\n        Returns:\n            A tensor representing the summed loss, cast to the `floatx()` if\n            necessary.\n        \"\"\"\n        if not backend.is_float_dtype(loss.dtype):\n            loss = ops.cast(loss, dtype=backend.floatx())\n        return ops.sum(loss)\n\n    def stateless_compute_loss(\n        self,\n        trainable_variables,\n        non_trainable_variables,\n        metrics_variables,\n        x=None,\n        y=None,\n        y_pred=None,\n        sample_weight=None,\n        training=True,\n    ):\n        var_mapping = list(zip(self.trainable_variables, trainable_variables))\n        var_mapping.extend(\n            zip(self.non_trainable_variables, non_trainable_variables)\n        )\n        
var_mapping.extend(zip(self.metrics_variables, metrics_variables))\n        with backend.StatelessScope(state_mapping=var_mapping) as scope:\n            # Note that this is needed for the regularization loss, which needs\n            # the latest values of the trainable/non-trainable variables.\n            loss = self._compute_loss(\n                x,\n                y,\n                y_pred,\n                sample_weight=sample_weight,\n                training=training,\n            )\n\n        # Update non-trainable vars (may have been updated in compute_loss)\n        non_trainable_variables = []\n        for v in self.non_trainable_variables:\n            new_v = scope.get_current_value(v)\n            non_trainable_variables.append(new_v)\n\n        # Update metrics vars (may have been updated in compute_loss)\n        metrics_variables = []\n        for v in self.metrics_variables:\n            new_v = scope.get_current_value(v)\n            metrics_variables.append(new_v)\n        return loss, (\n            trainable_variables,\n            non_trainable_variables,\n            metrics_variables,\n        )\n\n    def compute_metrics(self, x, y, y_pred, sample_weight=None):\n        \"\"\"Update metric states and collect all metrics to be returned.\n\n        Subclasses can optionally override this method to provide custom metric\n        updating and collection logic. Custom metrics are not passed in\n        `compile()`; they can be created in `__init__` or `build`. 
They are\n        automatically tracked and returned by `self.metrics`.\n\n        Example:\n\n        ```python\n        class MyModel(Sequential):\n            def __init__(self, *args, **kwargs):\n                super().__init__(*args, **kwargs)\n                self.custom_metric = MyMetric(name=\"custom_metric\")\n\n            def compute_metrics(self, x, y, y_pred, sample_weight):\n                # This super call updates metrics from `compile` and returns\n                # results for all metrics listed in `self.metrics`.\n                metric_results = super().compute_metrics(\n                    x, y, y_pred, sample_weight)\n\n                # `metric_results` contains the previous result for\n                # `custom_metric`; this is where we update it.\n                self.custom_metric.update_state(x, y, y_pred, sample_weight)\n                metric_results['custom_metric'] = self.custom_metric.result()\n                return metric_results\n        ```\n\n        Args:\n            x: Input data.\n            y: Target data.\n            y_pred: Predictions returned by the model (output of\n                `model.call(x)`).\n            sample_weight: Sample weights for weighting the loss function.\n\n        Returns:\n            A `dict` containing values that will be passed to\n            `keras.callbacks.CallbackList.on_train_batch_end()`. 
Typically,\n            the values of the metrics listed in `self.metrics` are returned.\n            Example: `{'loss': 0.2, 'accuracy': 0.7}`.\n        \"\"\"\n        del x  # The default implementation does not use `x`.\n        if self._compile_metrics is not None:\n            self._compile_metrics.update_state(y, y_pred, sample_weight)\n        return self.get_metrics_result()\n\n    def get_metrics_result(self):\n        \"\"\"Returns the model's metrics values as a dict.\n\n        If any of the metric results is a dict (containing multiple metrics),\n        each of them gets added to the top-level returned dict of this method.\n\n        Returns:\n            A `dict` containing values of the metrics listed in `self.metrics`.\n            Example: `{'loss': 0.2, 'accuracy': 0.7}`.\n        \"\"\"\n        return_metrics = {}\n        for metric in self.metrics:\n            result = metric.result()\n            if isinstance(result, dict):\n                return_metrics.update(result)\n            else:\n                return_metrics[metric.name] = result\n        return return_metrics\n\n    def fit(\n        self,\n        x=None,\n        y=None,\n        batch_size=None,\n        epochs=1,\n        verbose=\"auto\",\n        callbacks=None,\n        validation_split=0.0,\n        validation_data=None,\n        shuffle=True,\n        class_weight=None,\n        sample_weight=None,\n        initial_epoch=0,\n        steps_per_epoch=None,\n        validation_steps=None,\n        validation_batch_size=None,\n        validation_freq=1,\n    ):\n        \"\"\"Trains the model for a fixed number of epochs (dataset iterations).\n\n        Args:\n            x: Input data. 
It can be:\n                - A NumPy array (or array-like), or a list of arrays\n                (in case the model has multiple inputs).\n                - A backend-native tensor, or a list of tensors\n                (in case the model has multiple inputs).\n                - A dict mapping input names to the corresponding array/tensors,\n                if the model has named inputs.\n                - A `keras.utils.PyDataset` returning `(inputs, targets)` or\n                `(inputs, targets, sample_weights)`.\n                - A `tf.data.Dataset` yielding `(inputs, targets)` or\n                `(inputs, targets, sample_weights)`.\n                - A `torch.utils.data.DataLoader` yielding `(inputs, targets)`\n                or `(inputs, targets, sample_weights)`.\n                - A Python generator function yielding `(inputs, targets)` or\n                `(inputs, targets, sample_weights)`.\n            y: Target data. Like the input data `x`, it can be either NumPy\n                array(s) or backend-native tensor(s). If `x` is a\n                `keras.utils.PyDataset`, `tf.data.Dataset`,\n                `torch.utils.data.DataLoader` or a Python generator function,\n                `y` should not be specified since targets will be obtained from\n                `x`.\n            batch_size: Integer or `None`.\n                Number of samples per gradient update.\n                If unspecified, `batch_size` will default to 32.\n                Do not specify the `batch_size` if your input data `x` is a\n                `keras.utils.PyDataset`, `tf.data.Dataset`,\n                `torch.utils.data.DataLoader` or Python generator function\n                since they generate batches.\n            epochs: Integer. 
Number of epochs to train the model.\n                An epoch is an iteration over the entire `x` and `y`\n                data provided\n                (unless the `steps_per_epoch` flag is set to\n                something other than None).\n                Note that in conjunction with `initial_epoch`,\n                `epochs` is to be understood as \"final epoch\".\n                The model is not trained for a number of iterations\n                given by `epochs`, but merely until the epoch\n                of index `epochs` is reached.\n            verbose: `\"auto\"`, 0, 1, or 2. Verbosity mode.\n                0 = silent, 1 = progress bar, 2 = one line per epoch.\n                \"auto\" becomes 1 for most cases.\n                Note that the progress bar is not\n                particularly useful when logged to a file,\n                so `verbose=2` is recommended when not running interactively\n                (e.g., in a production environment). Defaults to `\"auto\"`.\n            callbacks: List of `keras.callbacks.Callback` instances.\n                List of callbacks to apply during training.\n                See `keras.callbacks`. Note\n                `keras.callbacks.ProgbarLogger` and\n                `keras.callbacks.History` callbacks are created\n                automatically and need not be passed to `model.fit()`.\n                `keras.callbacks.ProgbarLogger` is created\n                or not based on the `verbose` argument in `model.fit()`.\n            validation_split: Float between 0 and 1.\n                Fraction of the training data to be used as validation data.\n                The model will set apart this fraction of the training data,\n                will not train on it, and will evaluate the loss and any model\n                metrics on this data at the end of each epoch. 
The validation\n                data is selected from the last samples in the `x` and `y` data\n                provided, before shuffling.\n                This argument is only supported when `x` and `y` are made of\n                NumPy arrays or tensors.\n                If both `validation_data` and `validation_split` are provided,\n                `validation_data` will override `validation_split`.\n            validation_data: Data on which to evaluate\n                the loss and any model metrics at the end of each epoch.\n                The model will not be trained on this data. Note that the\n                validation loss of data provided using `validation_split` or\n                `validation_data` is not affected by regularization layers\n                like noise and dropout.\n                `validation_data` will override `validation_split`.\n                It can be:\n                - A tuple `(x_val, y_val)` of NumPy arrays or tensors.\n                - A tuple `(x_val, y_val, val_sample_weights)` of NumPy\n                arrays.\n                - A `keras.utils.PyDataset`, a `tf.data.Dataset`, a\n                `torch.utils.data.DataLoader` or a Python generator function\n                yielding `(x_val, y_val)` or\n                `(x_val, y_val, val_sample_weights)`.\n            shuffle: Boolean, whether to shuffle the training data before each\n                epoch. 
This argument is ignored when `x` is a\n                `keras.utils.PyDataset`, `tf.data.Dataset`,\n                `torch.utils.data.DataLoader` or Python generator function.\n            class_weight: Optional dictionary mapping class indices (integers)\n                to a weight (float) value, used for weighting the loss function\n                (during training only).\n                This can be useful to tell the model to\n                \"pay more attention\" to samples from\n                an under-represented class. When `class_weight` is specified\n                and targets have a rank of 2 or greater, either `y` must be\n                one-hot encoded, or an explicit final dimension of `1` must\n                be included for sparse class labels.\n            sample_weight: Optional NumPy array or tensor of weights for\n                the training samples, used for weighting the loss function\n                (during training only). You can either pass a flat (1D)\n                NumPy array or tensor with the same length as the input samples\n                (1:1 mapping between weights and samples), or in the case of\n                temporal data, you can pass a 2D NumPy array or tensor with\n                shape `(samples, sequence_length)` to apply a different weight\n                to every timestep of every sample.\n                This argument is not supported when `x` is a\n                `keras.utils.PyDataset`, `tf.data.Dataset`,\n                `torch.utils.data.DataLoader` or Python generator function.\n                Instead, provide `sample_weights` as the third element of `x`.\n                Note that sample weighting does not apply to metrics specified\n                via the `metrics` argument in `compile()`. 
To apply sample\n                weighting to your metrics, you can specify them via the\n                `weighted_metrics` in `compile()` instead.\n            initial_epoch: Integer.\n                Epoch at which to start training\n                (useful for resuming a previous training run).\n            steps_per_epoch: Integer or `None`.\n                Total number of steps (batches of samples) before declaring one\n                epoch finished and starting the next epoch. When training with\n                input tensors or NumPy arrays, the default `None` means that the\n                value used is the number of samples in your dataset divided by\n                the batch size, or 1 if that cannot be determined.\n                If `x` is a `keras.utils.PyDataset`, `tf.data.Dataset`,\n                `torch.utils.data.DataLoader` or Python generator function, the\n                epoch will run until the input dataset is exhausted. When\n                passing an infinitely repeating dataset, you must specify the\n                `steps_per_epoch` argument, otherwise the training will run\n                indefinitely.\n            validation_steps: Integer or `None`.\n                Only relevant if `validation_data` is provided.\n                Total number of steps (batches of samples) to draw before\n                stopping when performing validation at the end of every epoch.\n                If `validation_steps` is `None`, validation will run until the\n                `validation_data` dataset is exhausted. In the case of an\n                infinitely repeating dataset, it will run indefinitely. If\n                `validation_steps` is specified and only part of the dataset\n                is consumed, the evaluation will start from the beginning of the\n                dataset at each epoch. 
This ensures that the same validation\n                samples are used every time.\n            validation_batch_size: Integer or `None`.\n                Number of samples per validation batch.\n                If unspecified, will default to `batch_size`.\n                Do not specify the `validation_batch_size` if your data is a\n                `keras.utils.PyDataset`, `tf.data.Dataset`,\n                `torch.utils.data.DataLoader` or Python generator function\n                since they generate batches.\n            validation_freq: Only relevant if validation data is provided.\n                Specifies how many training epochs to run\n                before a new validation run is performed,\n                e.g. `validation_freq=2` runs validation every 2 epochs.\n\n        Unpacking behavior for iterator-like inputs:\n            A common pattern is to pass an iterator like object such as a\n            `tf.data.Dataset` or a `keras.utils.PyDataset` to `fit()`,\n            which will in fact yield not only features (`x`)\n            but optionally targets (`y`) and sample weights (`sample_weight`).\n            Keras requires that the output of such iterator-likes be\n            unambiguous. The iterator should return a tuple\n            of length 1, 2, or 3, where the optional second and third elements\n            will be used for `y` and `sample_weight` respectively.\n            Any other type provided will be wrapped in\n            a length-one tuple, effectively treating everything as `x`. When\n            yielding dicts, they should still adhere to the top-level tuple\n            structure,\n            e.g. `({\"x0\": x0, \"x1\": x1}, y)`. Keras will not attempt to separate\n            features, targets, and weights from the keys of a single dict.\n            A notable unsupported data type is the `namedtuple`. The reason is\n            that it behaves like both an ordered datatype (tuple) and a mapping\n            datatype (dict). 
So given a namedtuple of the form:\n            `namedtuple(\"example_tuple\", [\"y\", \"x\"])`\n            it is ambiguous whether to reverse the order of the elements when\n            interpreting the value. Even worse is a tuple of the form:\n            `namedtuple(\"other_tuple\", [\"x\", \"y\", \"z\"])`\n            where it is unclear if the tuple was intended to be unpacked\n            into `x`, `y`, and `sample_weight` or passed through\n            as a single element to `x`.\n\n        Returns:\n            A `History` object. Its `History.history` attribute is\n            a record of training loss values and metrics values\n            at successive epochs, as well as validation loss values\n            and validation metrics values (if applicable).\n        \"\"\"\n        raise NotImplementedError\n\n    def evaluate(\n        self,\n        x=None,\n        y=None,\n        batch_size=None,\n        verbose=\"auto\",\n        sample_weight=None,\n        steps=None,\n        callbacks=None,\n        return_dict=False,\n        **kwargs,\n    ):\n        \"\"\"Returns the loss value & metrics values for the model in test mode.\n\n        Computation is done in batches (see the `batch_size` arg.)\n\n        Args:\n            x: Input data. 
It can be:\n                - A NumPy array (or array-like), or a list of arrays\n                (in case the model has multiple inputs).\n                - A backend-native tensor, or a list of tensors\n                (in case the model has multiple inputs).\n                - A dict mapping input names to the corresponding array/tensors,\n                if the model has named inputs.\n                - A `keras.utils.PyDataset` returning `(inputs, targets)` or\n                `(inputs, targets, sample_weights)`.\n                - A `tf.data.Dataset` yielding `(inputs, targets)` or\n                `(inputs, targets, sample_weights)`.\n                - A `torch.utils.data.DataLoader` yielding `(inputs, targets)`\n                or `(inputs, targets, sample_weights)`.\n                - A Python generator function yielding `(inputs, targets)` or\n                `(inputs, targets, sample_weights)`.\n            y: Target data. Like the input data `x`, it can be either NumPy\n                array(s) or backend-native tensor(s). If `x` is a\n                `keras.utils.PyDataset`, `tf.data.Dataset`,\n                `torch.utils.data.DataLoader` or a Python generator function,\n                `y` should not be specified since targets will be obtained from\n                `x`.\n            batch_size: Integer or `None`.\n                Number of samples per batch of computation.\n                If unspecified, `batch_size` will default to 32.\n                Do not specify the `batch_size` if your input data `x` is a\n                `keras.utils.PyDataset`, `tf.data.Dataset`,\n                `torch.utils.data.DataLoader` or Python generator function\n                since they generate batches.\n            verbose: `\"auto\"`, 0, 1, or 2. 
Verbosity mode.\n                0 = silent, 1 = progress bar, 2 = single line.\n                `\"auto\"` becomes 1 for most cases.\n                Note that the progress bar is not\n                particularly useful when logged to a file, so `verbose=2` is\n                recommended when not running interactively\n                (e.g. in a production environment). Defaults to `\"auto\"`.\n            sample_weight: Optional NumPy array or tensor of weights for\n                the test samples, used for weighting the loss function.\n                You can either pass a flat (1D)\n                NumPy array or tensor with the same length as the input samples\n                (1:1 mapping between weights and samples), or in the case of\n                temporal data, you can pass a 2D NumPy array or tensor with\n                shape `(samples, sequence_length)` to apply a different weight\n                to every timestep of every sample.\n                This argument is not supported when `x` is a\n                `keras.utils.PyDataset`, `tf.data.Dataset`,\n                `torch.utils.data.DataLoader` or Python generator function.\n                Instead, provide `sample_weights` as the third element of `x`.\n                Note that sample weighting does not apply to metrics specified\n                via the `metrics` argument in `compile()`. To apply sample\n                weighting to your metrics, you can specify them via the\n                `weighted_metrics` in `compile()` instead.\n            steps: Integer or `None`.\n                Total number of steps (batches of samples) to draw before\n                declaring the evaluation round finished. If `steps` is `None`,\n                it will run until `x` is exhausted. 
In the case of an infinitely\n                repeating dataset, it will run indefinitely.\n            callbacks: List of `keras.callbacks.Callback` instances.\n                List of callbacks to apply during evaluation.\n            return_dict: If `True`, loss and metric results are returned as a\n                dict, with each key being the name of the metric.\n                If `False`, they are returned as a list.\n\n        Returns:\n            Scalar test loss (if the model has a single output and no metrics)\n            or list of scalars (if the model has multiple outputs\n            and/or metrics).\n\n        Note: When using compiled metrics, `evaluate()` may return multiple\n        submetric values, while `model.metrics_names` often lists only\n        top-level names (e.g., 'loss', 'compile_metrics'), leading to a\n        length mismatch. The order of the `evaluate()` output corresponds\n        to the order of metrics specified during `model.compile()`. You can\n        use this order to map the `evaluate()` results to the intended\n        metric. `model.metrics_names` itself will still return only the\n        top-level names.\n        \"\"\"\n        raise NotImplementedError\n\n    def predict(\n        self, x, batch_size=None, verbose=\"auto\", steps=None, callbacks=None\n    ):\n        \"\"\"Generates output predictions for the input samples.\n\n        Computation is done in batches. This method is designed for batch\n        processing of large numbers of inputs. 
It is not intended for use inside\n        of loops that iterate over your data and process small numbers of inputs\n        at a time.\n\n        For small numbers of inputs that fit in one batch,\n        directly use `__call__()` for faster execution, e.g.,\n        `model(x)`, or `model(x, training=False)` if you have layers such as\n        `BatchNormalization` that behave differently during\n        inference.\n\n        Note: See [this FAQ entry](\n        https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call)\n        for more details about the difference between `Model` methods\n        `predict()` and `__call__()`.\n\n        Args:\n            x: Input data. It can be:\n                - A NumPy array (or array-like), or a list of arrays\n                (in case the model has multiple inputs).\n                - A backend-native tensor, or a list of tensors\n                (in case the model has multiple inputs).\n                - A dict mapping input names to the corresponding array/tensors,\n                if the model has named inputs.\n                - A `keras.utils.PyDataset`.\n                - A `tf.data.Dataset`.\n                - A `torch.utils.data.DataLoader`.\n                - A Python generator function.\n            batch_size: Integer or `None`.\n                Number of samples per batch of computation.\n                If unspecified, `batch_size` will default to 32.\n                Do not specify the `batch_size` if your input data `x` is a\n                `keras.utils.PyDataset`, `tf.data.Dataset`,\n                `torch.utils.data.DataLoader` or Python generator function\n                since they generate batches.\n            verbose: `\"auto\"`, 0, 1, or 2. Verbosity mode.\n                0 = silent, 1 = progress bar, 2 = single line.\n                `\"auto\"` becomes 1 for most cases. 
Note that the progress bar\n                is not particularly useful when logged to a file,\n                so `verbose=2` is recommended when not running interactively\n                (e.g. in a production environment). Defaults to `\"auto\"`.\n            steps: Total number of steps (batches of samples) to draw before\n                declaring the prediction round finished. If `steps` is `None`,\n                it will run until `x` is exhausted. In the case of an infinitely\n                repeating dataset, it will run indefinitely.\n            callbacks: List of `keras.callbacks.Callback` instances.\n                List of callbacks to apply during prediction.\n\n        Returns:\n            NumPy array(s) of predictions.\n        \"\"\"\n        raise NotImplementedError\n\n    def train_on_batch(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        class_weight=None,\n        return_dict=False,\n    ):\n        \"\"\"Runs a single gradient update on a single batch of data.\n\n        Args:\n            x: Input data. Must be array-like.\n            y: Target data. Must be array-like.\n            sample_weight: Optional array of the same length as x, containing\n                weights to apply to the model's loss for each sample.\n                In the case of temporal data, you can pass a 2D array\n                with shape `(samples, sequence_length)`, to apply a different\n                weight to every timestep of every sample.\n            class_weight: Optional dictionary mapping class indices (integers)\n                to a weight (float) to apply to the model's loss for the samples\n                from this class during training. This can be useful to tell the\n                model to \"pay more attention\" to samples from an\n                under-represented class. 
When `class_weight` is specified\n                and targets have a rank of 2 or greater, either `y` must\n                be one-hot encoded, or an explicit final dimension of 1\n                must be included for sparse class labels.\n            return_dict: If `True`, loss and metric results are returned as a\n                dict, with each key being the name of the metric. If `False`,\n                they are returned as a list.\n\n        Returns:\n            A scalar loss value (when no metrics and `return_dict=False`),\n            a list of loss and metric values\n            (if there are metrics and `return_dict=False`), or a dict of\n            metric and loss values (if `return_dict=True`).\n        \"\"\"\n        raise NotImplementedError\n\n    def test_on_batch(\n        self,\n        x,\n        y=None,\n        sample_weight=None,\n        return_dict=False,\n    ):\n        \"\"\"Test the model on a single batch of samples.\n\n        Args:\n            x: Input data. Must be array-like.\n            y: Target data. Must be array-like.\n            sample_weight: Optional array of the same length as x, containing\n                weights to apply to the model's loss for each sample.\n                In the case of temporal data, you can pass a 2D array\n                with shape `(samples, sequence_length)`, to apply a different\n                weight to every timestep of every sample.\n            return_dict: If `True`, loss and metric results are returned as a\n                dict, with each key being the name of the metric. 
If `False`,\n                they are returned as a list.\n\n        Returns:\n            A scalar loss value (when no metrics and `return_dict=False`),\n            a list of loss and metric values\n            (if there are metrics and `return_dict=False`), or a dict of\n            metric and loss values (if `return_dict=True`).\n        \"\"\"\n        raise NotImplementedError\n\n    def predict_on_batch(self, x):\n        \"\"\"Returns predictions for a single batch of samples.\n\n        Args:\n            x: Input data. It must be array-like.\n\n        Returns:\n            NumPy array(s) of predictions.\n        \"\"\"\n        raise NotImplementedError\n\n    def get_compile_config(self):\n        \"\"\"Returns a serialized config with information for compiling the model.\n\n        This method returns a config dictionary containing all the information\n        (optimizer, loss, metrics, etc.) with which the model was compiled.\n\n        Returns:\n            A dict containing information for compiling the model.\n        \"\"\"\n        if self.compiled and hasattr(self, \"_compile_config\"):\n            return self._compile_config.serialize()\n        return {}\n\n    def compile_from_config(self, config):\n        \"\"\"Compiles the model with the information given in config.\n\n        This method uses the information in the config (optimizer, loss,\n        metrics, etc.) to compile the model.\n\n        Args:\n            config: Dict containing information for compiling the model.\n        \"\"\"\n        has_overridden_compile = self.__class__.compile != Trainer.compile\n        if has_overridden_compile:\n            warnings.warn(\n                \"`compile()` was not called as part of model loading \"\n                \"because the model's `compile()` method is custom. 
\"\n                \"All subclassed Models that have `compile()` \"\n                \"overridden should also override \"\n                \"`get_compile_config()` and `compile_from_config(config)`. \"\n                \"Alternatively, you can \"\n                \"call `compile()` manually after loading.\",\n                stacklevel=2,\n            )\n            return\n        config = serialization_lib.deserialize_keras_object(config)\n        self.compile(**config)\n        if hasattr(self, \"optimizer\") and self.built:\n            # Create optimizer variables.\n            self.optimizer.build(self.trainable_variables)\n\n    def _should_eval(self, epoch, validation_freq):\n        epoch = epoch + 1  # one-index the user-facing epoch.\n        if isinstance(validation_freq, int):\n            return epoch % validation_freq == 0\n        elif isinstance(validation_freq, list):\n            return epoch in validation_freq\n        else:\n            raise ValueError(\n                \"Expected `validation_freq` to be a list or int. \"\n                f\"Received: validation_freq={validation_freq} of the \"\n                f\"type {type(validation_freq)}.\"\n            )\n\n    def _get_metrics_result_or_logs(self, logs):\n        \"\"\"Returns model metrics as a dict if the keys match with input logs.\n\n        When the training / evaluation is performed with asynchronous steps,\n        the last scheduled `train / test_step` may not give the latest metrics\n        because it is not guaranteed to be executed last. This method gets\n        metrics from the model directly instead of relying on the return from\n        the last step function.\n\n        When the user has custom train / test step functions, the metrics\n        returned may be different from `Model.metrics`. 
In those instances,\n        this function will be a no-op and return the logs passed in.\n\n        Args:\n            logs: A `dict` of metrics returned by train / test step function.\n\n        Returns:\n            A `dict` containing values of the metrics listed in `self.metrics`\n            when logs and model metrics keys match. Otherwise it returns input\n            `logs`.\n        \"\"\"\n        metric_logs = self.get_metrics_result()\n        # Verify that train / test step logs passed and metric logs have\n        # matching keys. It could be different when using custom step functions,\n        # in which case we return the logs from the last step.\n        if isinstance(logs, dict) and set(logs.keys()) == set(\n            metric_logs.keys()\n        ):\n            return metric_logs\n        return logs\n\n    def _flatten_metrics_in_order(self, logs):\n        \"\"\"Turns `logs` dict into a list as per key order of `metrics_names`.\"\"\"\n        metric_names = []\n        for metric in self.metrics:\n            if isinstance(metric, CompileMetrics):\n                metric_names += [\n                    sub_metric.name for sub_metric in metric.metrics\n                ]\n            else:\n                metric_names.append(metric.name)\n        results = []\n        for name in metric_names:\n            if name in logs:\n                results.append(logs[name])\n        for key in sorted(logs.keys()):\n            if key not in metric_names:\n                results.append(logs[key])\n        if len(results) == 1:\n            return results[0]\n        return results\n\n    def _assert_compile_called(self, method_name=None):\n        if not self.compiled:\n            msg = \"You must call `compile()` before \"\n            if method_name is None:\n                msg += \"using the model.\"\n            else:\n                msg += f\"calling `{method_name}()`.\"\n            raise ValueError(msg)\n\n    def _symbolic_build(self, 
iterator=None, data_batch=None):\n        model_unbuilt = not all(layer.built for layer in self._flatten_layers())\n        compile_metrics_unbuilt = (\n            self._compile_metrics is not None\n            and not self._compile_metrics.built\n        )\n        compile_loss_unbuilt = (\n            self._compile_loss is not None and not self._compile_loss.built\n        )\n        optimizer_unbuilt = (\n            self.optimizer is not None and not self.optimizer.built\n        )\n        if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt:\n            # Create symbolic tensors matching an input batch.\n\n            def to_symbolic_input(v):\n                if v is None:\n                    return None\n                return backend.KerasTensor(\n                    v.shape, backend.standardize_dtype(v.dtype)\n                )\n\n            if data_batch is None:\n                for _, _, data_or_iterator in iterator:\n                    if isinstance(data_or_iterator, (list, tuple)):\n                        data_batch = data_or_iterator[0]\n                    else:\n                        data_batch = next(data_or_iterator)\n                    break\n            data_batch = tree.map_structure(to_symbolic_input, data_batch)\n            (\n                x,\n                y,\n                sample_weight,\n            ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)\n\n            # Build all model state with `backend.compute_output_spec`.\n            try:\n                y_pred = backend.compute_output_spec(self, x, training=False)\n            except Exception as e:\n                raise RuntimeError(\n                    \"Unable to automatically build the model. \"\n                    \"Please build it yourself before calling \"\n                    \"fit/evaluate/predict. 
\"\n                    \"A model is 'built' when its variables have \"\n                    \"been created and its `self.built` attribute \"\n                    \"is True. Usually, calling the model on a batch \"\n                    \"of data is the right way to build it.\\n\"\n                    \"Exception encountered:\\n\"\n                    f\"'{e}'\"\n                )\n            if compile_metrics_unbuilt:\n                # Build all metric state with `backend.compute_output_spec`.\n                backend.compute_output_spec(\n                    self.compute_metrics,\n                    x,\n                    y,\n                    y_pred,\n                    sample_weight=sample_weight,\n                )\n            if compile_loss_unbuilt:\n                # Build `CompileLoss` state with `backend.compute_output_spec`.\n                backend.compute_output_spec(\n                    self._compute_loss,\n                    x,\n                    y,\n                    y_pred,\n                    sample_weight=sample_weight,\n                    training=False,\n                )\n        if optimizer_unbuilt:\n            # Build optimizer\n            self.optimizer.build(self.trainable_variables)\n        self._post_build()\n\n\ndef model_supports_jit(model):\n    # XLA not supported with TF on MacOS GPU\n    if platform.system() == \"Darwin\" and \"arm\" in platform.processor().lower():\n        if backend.backend() == \"tensorflow\":\n            from keras.src.utils.module_utils import tensorflow as tf\n\n            if tf.config.list_physical_devices(\"GPU\"):\n                return False\n    # XLA not supported by some layers\n    if all(x.supports_jit for x in model._flatten_layers()):\n        if backend.backend() == \"tensorflow\":\n            from tensorflow.python.framework.config import (\n                is_op_determinism_enabled,\n            )\n\n            if is_op_determinism_enabled():\n                # disable 
XLA with determinism enabled since not all ops are\n                # supported by XLA with determinism enabled.\n                return False\n        return True\n    return False\n"
  },
  {
    "path": "keras/src/trainers/trainer_test.py",
    "content": "from unittest import mock\n\nimport jax\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import initializers\nfrom keras.src import layers\nfrom keras.src import losses\nfrom keras.src import metrics\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import optimizers\nfrom keras.src import testing\nfrom keras.src.backend import config\nfrom keras.src.backend.common.symbolic_scope import in_symbolic_scope\nfrom keras.src.callbacks.callback import Callback\nfrom keras.src.distribution.distribution_lib import DataParallel\nfrom keras.src.distribution.distribution_lib import DeviceMesh\nfrom keras.src.optimizers.rmsprop import RMSprop\nfrom keras.src.testing import test_case\nfrom keras.src.testing.test_utils import named_product\nfrom keras.src.trainers.data_adapters import py_dataset_adapter\n\nif backend.backend() == \"jax\":\n    from keras.src.backend.jax.trainer import JAXTrainer as Trainer\n    from keras.src.distribution import DataParallel\n    from keras.src.distribution import DeviceMesh\nelif backend.backend() == \"torch\":\n    from keras.src.backend.torch.trainer import TorchTrainer as Trainer\nelif backend.backend() == \"tensorflow\":\n    from keras.src.backend.tensorflow.trainer import (\n        TensorFlowTrainer as Trainer,\n    )\nelif backend.backend() == \"numpy\":\n    from keras.src.backend.numpy.trainer import NumpyTrainer as Trainer\nelif backend.backend() == \"openvino\":\n    from keras.src.backend.openvino.trainer import OpenVINOTrainer as Trainer\nelse:\n    raise ImportError(f\"Invalid backend: {backend.backend()}\")\n\n\n# A model is just a layer mixed in with a Trainer.\nclass ExampleModel(Trainer, layers.Dense):\n    def __init__(self, units):\n        layers.Dense.__init__(\n            self,\n            units=units,\n            use_bias=False,\n            kernel_initializer=initializers.Ones(),\n        
)\n        Trainer.__init__(self)\n\n\nclass CustomTrainTestStepModel(ExampleModel):\n    def train_step(self, data):\n        logs = super().train_step(data)\n        logs[\"my_custom_metric\"] = 10.0\n        return logs\n\n    def test_step(self, data):\n        logs = super().test_step(data)\n        logs[\"my_custom_metric\"] = 5.0\n        return logs\n\n\nclass JaxCustomTrainTestStepModel(ExampleModel):\n    def train_step(self, state, data):\n        logs, state = super().train_step(state, data)\n        logs[\"my_custom_metric\"] = 10.0\n        return logs, state\n\n    def test_step(self, state, data):\n        logs, state = super().test_step(state, data)\n        logs[\"my_custom_metric\"] = 5.0\n        return logs, state\n\n\nclass StructModel(Trainer, layers.Layer):\n    def __init__(self, units):\n        layers.Layer.__init__(self)\n        Trainer.__init__(self)\n        self.dense_1 = layers.Dense(\n            units,\n            use_bias=False,\n            kernel_initializer=initializers.Ones(),\n        )\n        self.dense_2 = layers.Dense(\n            units,\n            use_bias=False,\n            kernel_initializer=initializers.Ones(),\n        )\n\n    def call(self, x):\n        return {\n            \"y_one\": self.dense_1(x[\"x_one\"]),\n            \"y_two\": self.dense_2(x[\"x_two\"]),\n        }\n\n\nclass ListInputModel(Trainer, layers.Layer):\n    def __init__(self, units):\n        layers.Layer.__init__(self)\n        Trainer.__init__(self)\n        self.dense_1 = layers.Dense(\n            units,\n            use_bias=False,\n            kernel_initializer=initializers.Ones(),\n        )\n        self.dense_2 = layers.Dense(\n            units,\n            use_bias=False,\n            kernel_initializer=initializers.Ones(),\n        )\n\n    def call(self, x):\n        if not isinstance(x, (list, tuple)):\n            raise ValueError(\"x must be a list or tuple\")\n        return self.dense_1(x[0]) + 
self.dense_2(x[1])\n\n\nclass ListOutputModel(Trainer, layers.Layer):\n    def __init__(self, units):\n        layers.Layer.__init__(self)\n        Trainer.__init__(self)\n        self.dense_1 = layers.Dense(\n            units,\n            use_bias=False,\n            kernel_initializer=initializers.Ones(),\n        )\n        self.dense_2 = layers.Dense(\n            units,\n            use_bias=False,\n            kernel_initializer=initializers.Ones(),\n        )\n\n    def call(self, x):\n        return [self.dense_1(x), self.dense_2(x)]\n\n\nclass TrainingTestingLayer(Trainer, layers.Layer):\n    def __init__(self, **kwargs):\n        layers.Layer.__init__(self, **kwargs)\n        Trainer.__init__(self)\n\n    def call(self, x, training=False):\n        if training:\n            return x\n        return x * 0\n\n\nclass TestPyDataset(py_dataset_adapter.PyDataset):\n    def __init__(self, infinite=False, **kwargs):\n        super().__init__(**kwargs)\n        self.infinite = infinite\n\n    @property\n    def num_batches(self):\n        return None if self.infinite else 20\n\n    def __getitem__(self, idx):\n        CPU_DEVICES = {\n            \"tensorflow\": \"CPU:0\",\n            \"jax\": \"cpu:0\",\n            \"torch\": \"cpu\",\n        }\n        with backend.device(CPU_DEVICES[backend.backend()]):\n            return ops.ones((5, 4)), ops.zeros((5, 3))\n\n\ndef create_dataset(dataset_type, dataset_kwargs):\n    if dataset_type == \"np_array\":\n        return np.ones((100, 4)), np.zeros((100, 3))\n    elif dataset_type == \"native_array\":\n        return ops.ones((100, 4)), ops.zeros((100, 3))\n    elif dataset_type == \"py_dataset\":\n        return TestPyDataset(**dataset_kwargs), None\n    elif dataset_type == \"tf_dataset\":\n        import tensorflow as tf\n\n        dataset = tf.data.Dataset.from_tensor_slices(\n            (tf.ones((100, 4)), tf.zeros((100, 3)))\n        ).batch(5)\n        if dataset_kwargs.get(\"infinite\", False):\n       
     dataset = dataset.repeat()\n        return dataset, None\n    elif dataset_type == \"torch_dataloader\":\n        import torch\n\n        class TestIterableDataset(torch.utils.data.IterableDataset):\n            def __iter__(self):\n                for _ in range(20):\n                    yield torch.ones((5, 4)), torch.zeros((5, 3))\n\n        class TestIterableDatasetWithLen(TestIterableDataset):\n            def __len__(self):\n                return 20\n\n        if dataset_kwargs.get(\"iterable\", False):\n            if dataset_kwargs.get(\"has_len\", False):\n                dataset = TestIterableDatasetWithLen()\n            else:\n                dataset = TestIterableDataset()\n            return torch.utils.data.DataLoader(dataset), None\n        else:\n            dataset = torch.utils.data.TensorDataset(\n                torch.ones((100, 4)), torch.zeros((100, 3))\n            )\n            return torch.utils.data.DataLoader(dataset, batch_size=5), None\n    elif dataset_type == \"generator\":\n\n        def generate_finite():\n            for _ in range(20):\n                yield ops.ones((5, 4)), ops.zeros((5, 3))\n\n        def generate_infinite():\n            while True:\n                yield ops.ones((5, 4)), ops.zeros((5, 3))\n\n        if dataset_kwargs.get(\"infinite\", False):\n            return generate_infinite(), None\n        else:\n            return generate_finite(), None\n    elif dataset_type == \"grain_datast\":\n        import grain\n\n        class TestIterableDataset(grain.sources.RandomAccessDataSource):\n            def __init__(self):\n                super().__init__()\n                self.x = np.ones((100, 4)).astype(\"float32\")\n                self.y = np.zeros((100, 3)).astype(\"float32\")\n\n            def __len__(self):\n                return len(self.x)\n\n            def __getitem__(self, idx):\n                return self.x[idx], self.y[idx]\n\n        if dataset_kwargs.get(\"use_dataloader\", False):\n  
          source = TestIterableDataset()\n            dataloader = grain.DataLoader(\n                data_source=source,\n                sampler=grain.samplers.IndexSampler(len(source), num_epochs=1),\n                operations=[grain.transforms.Batch(batch_size=5)],\n            )\n            return dataloader, None\n        else:\n            dataset = grain.MapDataset.source(TestIterableDataset())\n            if dataset_kwargs.get(\"has_len\", False):\n                dataset = dataset.to_iter_dataset()\n            dataset = dataset.batch(5)\n            return dataset, None\n    else:\n        raise ValueError(f\"Invalid dataset type {dataset_type}\")\n\n\ndef sparse_generator(generator_type):\n    if generator_type == \"scipy\":\n        import scipy\n\n        for _ in range(4):\n            x = scipy.sparse.random(2, 4, density=0.25, dtype=\"float32\")\n            y = np.random.rand(2, 3).astype(\"float32\")\n            yield x, y\n    elif generator_type == \"tf\":\n        import tensorflow as tf\n\n        for _ in range(4):\n            x = tf.random.uniform((2, 4), dtype=\"float32\")\n            x = tf.sparse.from_dense(tf.nn.dropout(x, 0.25))\n            y = tf.random.uniform((2, 3), dtype=\"float32\")\n            yield x, y\n    elif generator_type == \"jax\":\n        import jax\n        import jax.experimental.sparse as jax_sparse\n\n        for _ in range(4):\n            seed = jax.random.PRNGKey(0)\n            x = jax_sparse.random_bcoo(seed, (2, 4), dtype=\"float32\", nse=0.25)\n            y = jax.random.uniform(seed, (2, 3), dtype=\"float32\")\n            yield x, y\n    else:\n        raise ValueError(f\"Invalid generator type {generator_type}\")\n\n\nclass EpochAgnosticMeanSquaredError(metrics.MeanSquaredError):\n    def __init__(self):\n        super().__init__(name=\"mse\")\n        super().reset_state()\n\n    def reset_state(self):\n        # prevent reset at each starting epoch\n        pass\n\n\nclass 
StepObserver(Callback):\n    def __init__(self):\n        super().__init__()\n        self.begin_count = 0\n        self.end_count = 0\n        self.epoch_begin_count = 0\n        self.epoch_end_count = 0\n        self.batch_loss_history = []\n\n    def on_epoch_begin(self, epoch, logs=None):\n        self.epoch_begin_count += 1\n\n    def on_epoch_end(self, epoch, logs=None):\n        self.epoch_end_count += 1\n\n    def on_batch_begin(self, batch, logs=None):\n        self.begin_count += 1\n\n    def on_batch_end(self, batch, logs=None):\n        self.end_count += 1\n        self.batch_loss_history.append(logs[\"mse\"])\n\n\nclass StepCount(Callback):\n    def __init__(self, steps_per_execution=1):\n        super().__init__()\n        self.begin_count = 0\n        self.end_count = 0\n        self.epoch_begin_count = 0\n        self.epoch_end_count = 0\n        self.steps_per_execution = steps_per_execution\n\n    def on_epoch_begin(self, epoch, logs=None):\n        self.begin_count = 0\n        self.end_count = 0\n        self.epoch_begin_count += 1\n\n    def on_epoch_end(self, epoch, logs=None):\n        self.epoch_end_count += 1\n\n    def on_batch_begin(self, batch, logs=None):\n        if batch != self.begin_count * self.steps_per_execution:\n            raise ValueError(\"Batch index is not correct\")\n        self.begin_count += 1\n\n    def on_batch_end(self, batch, logs=None):\n        self.end_count += 1\n        if batch != self.end_count * self.steps_per_execution - 1:\n            raise ValueError(\"Batch index is not correct\")\n\n\nclass TestTrainer(testing.TestCase):\n    @pytest.mark.requires_trainable_backend\n    def test_metric_tracking(self):\n        class ModelWithMetric(Trainer, layers.Dense):\n            def __init__(self, units):\n                layers.Dense.__init__(\n                    self,\n                    units=units,\n                    use_bias=False,\n                    kernel_initializer=initializers.Ones(),\n           
     )\n                Trainer.__init__(self)\n                self.my_metric = metrics.MeanSquaredError(name=\"my_metric\")\n\n        model = ModelWithMetric(units=3)\n        model.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n        )\n        x = np.ones((2, 4))\n        y = np.zeros((2, 3))\n        # Fit the model to make sure compile_metrics are built\n        model.fit(x, y, batch_size=2, epochs=1)\n\n        # The model should have 3 metrics: loss_tracker, compile_metrics,\n        # my_metric.\n        self.assertEqual(len(model.metrics), 3)\n        self.assertEqual(model.metrics[0], model._loss_tracker)\n        self.assertEqual(model.metrics[1], model._compile_metrics)\n        self.assertEqual(model.metrics[2], model.my_metric)\n\n        # All metrics should have their weights created\n        self.assertEqual(len(model._loss_tracker.variables), 2)\n        self.assertEqual(len(model._compile_metrics.variables), 2)\n        self.assertEqual(len(model.my_metric.variables), 2)\n\n        # And those weights are tracked at the model level\n        self.assertEqual(len(model.metrics_variables), 6)\n        self.assertLen(model.non_trainable_variables, 0)\n\n        # Models with only weighted_metrics should have the same 3 metrics\n        model_weighted = ModelWithMetric(units=3)\n        model_weighted.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            weighted_metrics=[metrics.MeanSquaredError()],\n        )\n        model_weighted.fit(\n            x,\n            y,\n            batch_size=2,\n            epochs=1,\n            sample_weight=np.ones(2),\n        )\n        self.assertEqual(len(model_weighted.metrics), 3)\n\n    def test_nested_trainer_metrics(self):\n        # https://github.com/keras-team/keras/issues/20188\n        model = ExampleModel(units=3)\n        model.compile(\n    
        optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n        )\n        self.assertLen(model.metrics, 2)\n        self.assertEqual(model.metrics[0], model._loss_tracker)\n        self.assertEqual(model.metrics[1], model._compile_metrics)\n\n        inputs = keras.Input((4,))\n        outputs = model(inputs)\n        outputs = layers.Dense(8)(outputs)\n        new_model = models.Model(inputs, outputs)\n        new_model.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n        )\n        self.assertLen(new_model.metrics, 2)\n        self.assertEqual(new_model.metrics[0], new_model._loss_tracker)\n        self.assertEqual(new_model.metrics[1], new_model._compile_metrics)\n\n    def test_nested_trainer_metrics_without_compile(self):\n        model = ExampleModel(units=3)\n        self.assertLen(model.metrics, 0)\n\n        inputs = keras.Input((4,))\n        outputs = model(inputs)\n        outputs = layers.Dense(8)(outputs)\n        new_model = models.Model(inputs, outputs)\n        new_model.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n        )\n        self.assertLen(new_model.metrics, 2)\n        self.assertEqual(new_model.metrics[0], new_model._loss_tracker)\n        self.assertEqual(new_model.metrics[1], new_model._compile_metrics)\n\n    def test_multiple_compiles(self):\n        # https://github.com/keras-team/keras/issues/20474\n        model1 = ExampleModel(units=3)\n        model2 = ExampleModel(units=3)\n        model1.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n        )\n\n        # Combine these 2 models into `combined`.\n        inputs = keras.Input(shape=(4,))\n        x = 
model1(inputs)\n        outputs = model2(x)\n        combined = models.Model(inputs, outputs)\n        combined.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n        )\n\n        self.assertLen(model1.metrics, 2)\n        self.assertIsNotNone(model1._loss_tracker)\n        self.assertEqual(model1.metrics[0], model1._loss_tracker)\n        self.assertEqual(model1.metrics[1], model1._compile_metrics)\n\n        # `combined.metrics` will not include `model1.metrics`.\n        self.assertLen(combined.metrics, 2)\n        self.assertIsNotNone(combined._loss_tracker)\n        self.assertEqual(combined.metrics[0], combined._loss_tracker)\n        self.assertEqual(combined.metrics[1], combined._compile_metrics)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"torch\",\n        reason=\"torch backend runs in eager mode for jit_compile='auto'\",\n    )\n    def test_compile_eager_vs_jit_torch(self):\n        model = ExampleModel(units=3)\n        model.compile(jit_compile=\"auto\")\n        # torch trainer en/disables torch.compile only based on the value of\n        # model.jit_compile (not model.run_eagerly)\n        self.assertFalse(model.run_eagerly)\n        self.assertFalse(model.jit_compile)\n\n    @parameterized.named_parameters(\n        [\n            (\"eager\", True, False, False),\n            (\"graph_fn\", False, False, False),\n            (\"jit\", False, True, False),\n            (\"steps_per_epoch_eager\", True, False, True),\n            (\"steps_per_epoch_graph_fn\", False, False, True),\n            (\"steps_per_epoch_jit\", False, True, True),\n        ]\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_fit_flow(self, run_eagerly, jit_compile, use_steps_per_epoch):\n        model = ExampleModel(units=3)\n        epochs = 3\n        batch_size = 20\n        steps_per_epoch = 7\n        dataset_size = batch_size * 
(steps_per_epoch - 2)\n        x = np.ones((dataset_size, 4))\n        y = np.zeros((dataset_size, 3))\n\n        model.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n            run_eagerly=run_eagerly,\n            jit_compile=jit_compile,\n        )\n        history = model.fit(\n            x,\n            y,\n            batch_size=batch_size,\n            steps_per_epoch=steps_per_epoch if use_steps_per_epoch else None,\n            epochs=epochs,\n        )\n        history = history.history\n        self.assertIn(\"loss\", history)\n        self.assertIn(\"mean_squared_error\", history)\n        self.assertAllClose(\n            history[\"mean_squared_error\"],\n            [14.5, 11.5, 8.5],\n            atol=1.0,  # TODO: results vary across backends\n        )\n\n    @parameterized.named_parameters(\n        [\n            {\n                \"testcase_name\": \"np_array\",\n                \"dataset_type\": \"np_array\",\n                \"fit_kwargs\": {\"batch_size\": 5},\n            },\n            {\n                \"testcase_name\": \"native_array\",\n                \"dataset_type\": \"native_array\",\n                \"fit_kwargs\": {\"batch_size\": 5},\n            },\n            {\n                \"testcase_name\": \"py_dataset\",\n                \"dataset_type\": \"py_dataset\",\n            },\n            {\n                \"testcase_name\": \"py_dataset_cw\",\n                \"dataset_type\": \"py_dataset\",\n                \"fit_kwargs\": {\"class_weight\": {0: 1, 1: 2}},\n            },\n            {\n                \"testcase_name\": \"py_dataset_infinite\",\n                \"dataset_type\": \"py_dataset\",\n                \"dataset_kwargs\": {\"infinite\": True},\n                \"fit_kwargs\": {\"steps_per_epoch\": 20},\n            },\n            {\n                \"testcase_name\": \"py_dataset_infinite_cw\",\n           
     \"dataset_type\": \"py_dataset\",\n                \"dataset_kwargs\": {\"infinite\": True},\n                \"fit_kwargs\": {\n                    \"steps_per_epoch\": 20,\n                    \"class_weight\": {0: 1, 1: 2},\n                },\n            },\n            {\n                \"testcase_name\": \"py_dataset_multithreading\",\n                \"dataset_type\": \"py_dataset\",\n                \"dataset_kwargs\": {\"workers\": 2},\n            },\n            {\n                \"testcase_name\": \"py_dataset_multithreading_cw\",\n                \"dataset_type\": \"py_dataset\",\n                \"dataset_kwargs\": {\"workers\": 2},\n                \"fit_kwargs\": {\"class_weight\": {0: 1, 1: 2}},\n            },\n            {\n                \"testcase_name\": \"py_dataset_multithreading_infinite\",\n                \"dataset_type\": \"py_dataset\",\n                \"dataset_kwargs\": {\"infinite\": True, \"workers\": 2},\n                \"fit_kwargs\": {\"steps_per_epoch\": 20},\n            },\n            {\n                \"testcase_name\": \"py_dataset_multiprocessing\",\n                \"dataset_type\": \"py_dataset\",\n                \"dataset_kwargs\": {\"workers\": 2, \"use_multiprocessing\": True},\n            },\n            {\n                \"testcase_name\": \"py_dataset_multiprocessing_cw\",\n                \"dataset_type\": \"py_dataset\",\n                \"dataset_kwargs\": {\"workers\": 2, \"use_multiprocessing\": True},\n                \"fit_kwargs\": {\"class_weight\": {0: 1, 1: 2}},\n            },\n            {\n                \"testcase_name\": \"py_dataset_multiprocessing_infinite\",\n                \"dataset_type\": \"py_dataset\",\n                \"dataset_kwargs\": {\n                    \"infinite\": True,\n                    \"workers\": 2,\n                    \"use_multiprocessing\": True,\n                },\n                \"fit_kwargs\": {\"steps_per_epoch\": 20},\n            },\n          
  {\n                \"testcase_name\": \"tf_dataset\",\n                \"dataset_type\": \"tf_dataset\",\n            },\n            {\n                \"testcase_name\": \"tf_dataset_infinite\",\n                \"dataset_type\": \"tf_dataset\",\n                \"dataset_kwargs\": {\"infinite\": True},\n                \"fit_kwargs\": {\"steps_per_epoch\": 20},\n            },\n            {\n                \"testcase_name\": \"torch_dataloader_tensor\",\n                \"dataset_type\": \"torch_dataloader\",\n            },\n            {\n                \"testcase_name\": \"torch_dataloader_iterable\",\n                \"dataset_type\": \"torch_dataloader\",\n                \"dataset_kwargs\": {\"iterable\": True, \"has_len\": False},\n            },\n            {\n                \"testcase_name\": \"torch_dataloader_iterable_with_len\",\n                \"dataset_type\": \"torch_dataloader\",\n                \"dataset_kwargs\": {\"iterable\": True, \"has_len\": True},\n            },\n            {\n                \"testcase_name\": \"generator\",\n                \"dataset_type\": \"generator\",\n            },\n            {\n                \"testcase_name\": \"generator_infinite\",\n                \"dataset_type\": \"generator\",\n                \"dataset_kwargs\": {\"infinite\": True},\n                \"fit_kwargs\": {\"steps_per_epoch\": 20},\n            },\n            {\n                \"testcase_name\": \"grain_datast\",\n                \"dataset_type\": \"grain_datast\",\n                \"dataset_kwargs\": {\"has_len\": False},\n            },\n            {\n                \"testcase_name\": \"grain_datast_with_len\",\n                \"dataset_type\": \"grain_datast\",\n                \"dataset_kwargs\": {\"has_len\": True},\n            },\n            {\n                \"testcase_name\": \"grain_dataloader\",\n                \"dataset_type\": \"grain_datast\",\n                \"dataset_kwargs\": {\"use_dataloader\": 
True},\n            },\n        ]\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_fit_with_data_adapter(\n        self, dataset_type, dataset_kwargs={}, fit_kwargs={}\n    ):\n        jit_compile = True\n        if (\n            dataset_kwargs.get(\"use_multiprocessing\", False)\n            and backend.backend() == \"jax\"\n        ):\n            pytest.skip(\"Multiprocessing not supported with JAX backend\")\n        if dataset_type == \"grain_datast\" and backend.backend() == \"torch\":\n            # Grain datasets are not supported with torch + jit_compile.\n            jit_compile = False\n\n        model = ExampleModel(units=3)\n        optimizer = optimizers.Adagrad()\n        model.compile(\n            optimizer=optimizer,\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n            jit_compile=jit_compile,\n        )\n        x, y = create_dataset(dataset_type, dataset_kwargs)\n        model.fit(x, y, epochs=3, **fit_kwargs)\n\n    @parameterized.named_parameters(\n        [\n            (\"eager\", True, False, False),\n            (\"graph_fn\", False, False, False),\n            (\"jit\", False, True, False),\n            (\"steps_per_epoch_eager\", True, False, True),\n            (\"steps_per_epoch_graph_fn\", False, False, True),\n            (\"steps_per_epoch_jit\", False, True, True),\n        ]\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_fit_with_val_split(\n        self, run_eagerly, jit_compile, use_steps_per_epoch\n    ):\n        model = ExampleModel(units=3)\n        epochs = 3\n        batch_size = 20\n        steps_per_epoch = 7\n        dataset_size = batch_size * (steps_per_epoch - 2)\n        x = np.ones((dataset_size, 4))\n        y = np.zeros((dataset_size, 3))\n\n        model.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n            
run_eagerly=run_eagerly,\n            jit_compile=jit_compile,\n        )\n        history = model.fit(\n            x,\n            y,\n            batch_size=batch_size,\n            steps_per_epoch=steps_per_epoch if use_steps_per_epoch else None,\n            epochs=epochs,\n            validation_split=0.2,\n        )\n        history = history.history\n        self.assertIn(\"loss\", history)\n        self.assertIn(\"val_loss\", history)\n\n        # Test with backend-native tensors.\n        x = ops.ones((dataset_size, 4))\n        y = ops.zeros((dataset_size, 3))\n        history = model.fit(\n            x,\n            y,\n            batch_size=batch_size,\n            steps_per_epoch=steps_per_epoch if use_steps_per_epoch else None,\n            epochs=epochs,\n            validation_split=0.2,\n        )\n        history = history.history\n        self.assertIn(\"loss\", history)\n        self.assertIn(\"val_loss\", history)\n\n    @pytest.mark.requires_trainable_backend\n    def test_fit_with_custom_train_step(self):\n        if backend.backend() == \"jax\":\n            model = JaxCustomTrainTestStepModel(units=3)\n        else:\n            model = CustomTrainTestStepModel(units=3)\n        x = np.ones((100, 4))\n        y = np.zeros((100, 3))\n        batch_size = 16\n\n        model.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n        )\n        history = model.fit(x, y, batch_size=batch_size)\n        history = history.history\n        self.assertIn(\"loss\", history)\n        self.assertIn(\"mean_squared_error\", history)\n        self.assertAllClose(history[\"my_custom_metric\"], 10.0)\n\n    @parameterized.named_parameters(\n        named_product(\n            generator_type=[\"tf\", \"jax\", \"scipy\"], mode=[\"eager\", \"graph\"]\n        )\n    )\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        
reason=\"Backend does not support sparse tensors.\",\n    )\n    def test_fit_sparse(self, generator_type, mode):\n        model = ExampleModel(units=3)\n        optimizer = optimizers.Adagrad()\n        model.compile(\n            optimizer=optimizer,\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n            run_eagerly=(mode == \"eager\"),\n            jit_compile=False,\n        )\n        dataset = sparse_generator(generator_type)\n\n        sparse_variable_updates = False\n\n        def mock_optimizer_assign(variable, value):\n            nonlocal sparse_variable_updates\n            if value.__class__.__name__ == \"IndexedSlices\":\n                sparse_variable_updates = True\n\n        with mock.patch.object(\n            optimizer, \"assign_sub\", autospec=True\n        ) as optimizer_assign_sub:\n            optimizer_assign_sub.side_effect = mock_optimizer_assign\n            model.fit(dataset)\n\n        # JAX does not produce sparse gradients the way we use it.\n        if backend.backend() != \"jax\":\n            # Verify tensors did not get densified along the way.\n            self.assertTrue(sparse_variable_updates)\n\n    @parameterized.named_parameters(\n        [\n            (\"eager\", True, False),\n            (\"graph_fn\", False, False),\n            (\"jit\", False, True),\n        ]\n    )\n    def test_evaluate_flow(self, run_eagerly, jit_compile):\n        model = ExampleModel(units=3)\n        x = np.ones((100, 4))\n        y = np.zeros((100, 3))\n        batch_size = 16\n\n        model.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n            run_eagerly=run_eagerly,\n            jit_compile=jit_compile,\n        )\n        output = model.evaluate(x, y, batch_size=batch_size)\n        self.assertAllClose(output, [16.0, 16.0])\n        output = model.evaluate(x, y, 
batch_size=batch_size, return_dict=True)\n        self.assertIsInstance(output, dict)\n        self.assertIn(\"loss\", output)\n        self.assertIn(\"mean_squared_error\", output)\n        self.assertAllClose(output[\"mean_squared_error\"], 16.0)\n\n    @parameterized.named_parameters([(\"flat\", False), (\"dict\", True)])\n    @pytest.mark.requires_trainable_backend\n    def test_evaluate_with_custom_test_step(self, return_dict):\n        if backend.backend() == \"jax\":\n            model = JaxCustomTrainTestStepModel(units=3)\n        else:\n            model = CustomTrainTestStepModel(units=3)\n        x = np.ones((100, 4))\n        y = np.zeros((100, 3))\n        batch_size = 16\n\n        model.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n        )\n        output = model.evaluate(\n            x, y, batch_size=batch_size, return_dict=return_dict\n        )\n        self.assertLen(output, 3)\n        if return_dict:\n            self.assertAllClose(output[\"my_custom_metric\"], 5.0)\n        else:\n            self.assertAllClose(output[-1], 5.0)  # Custom metrics go last.\n\n    @parameterized.named_parameters(\n        named_product(\n            generator_type=[\"tf\", \"jax\", \"scipy\"], mode=[\"eager\", \"graph\"]\n        )\n    )\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=\"Backend does not support sparse tensors.\",\n    )\n    def test_evaluate_sparse(self, generator_type, mode):\n        model = ExampleModel(units=3)\n        model.compile(\n            optimizer=optimizers.Adagrad(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n            run_eagerly=(mode == \"eager\"),\n            jit_compile=False,\n        )\n        dataset = sparse_generator(generator_type)\n        model.evaluate(dataset)\n\n    @parameterized.named_parameters(\n        [\n  
          (\"eager\", True, False),\n            (\"graph_fn\", False, False),\n            (\"jit\", False, True),\n        ]\n    )\n    def test_predict_flow(self, run_eagerly, jit_compile):\n        # Test basic example\n        model = ExampleModel(units=3)\n        model.run_eagerly = run_eagerly\n        model.jit_compile = jit_compile\n\n        x = np.ones((100, 4))\n        batch_size = 16\n        outputs = model.predict(x, batch_size=batch_size)\n        self.assertAllClose(outputs, 4 * np.ones((100, 3)))\n\n    @parameterized.named_parameters(\n        [\n            (\"eager\", True, False),\n            (\"graph_fn\", False, False),\n            (\"jit\", False, True),\n        ]\n    )\n    def test_predict_flow_struct(self, run_eagerly, jit_compile):\n        # Test with input/output structs\n        model = StructModel(units=3)\n        model.run_eagerly = run_eagerly\n        model.jit_compile = jit_compile\n\n        x = {\n            \"x_one\": np.ones((100, 4)),\n            \"x_two\": np.ones((100, 4)),\n        }\n        batch_size = 16\n        outputs = model.predict(x, batch_size=batch_size)\n        self.assertIsInstance(outputs, dict)\n        self.assertEqual(len(outputs), 2)\n        self.assertAllClose(outputs[\"y_one\"], 4 * np.ones((100, 3)))\n        self.assertAllClose(outputs[\"y_two\"], 4 * np.ones((100, 3)))\n\n    @parameterized.named_parameters(\n        named_product(\n            generator_type=[\"tf\", \"jax\", \"scipy\"], mode=[\"eager\", \"graph\"]\n        )\n    )\n    @pytest.mark.skipif(\n        not backend.SUPPORTS_SPARSE_TENSORS,\n        reason=\"Backend does not support sparse tensors.\",\n    )\n    def test_predict_sparse(self, generator_type, mode):\n        model = ExampleModel(units=3)\n        model.compile(\n            optimizer=optimizers.Adagrad(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n            run_eagerly=(mode == \"eager\"),\n           
 jit_compile=False,\n        )\n        dataset = sparse_generator(generator_type)\n        dataset_size = sum(\n            [batch[1].shape[0] for batch in sparse_generator(generator_type)]\n        )\n        y = model.predict(dataset)\n        self.assertEqual(len(y), dataset_size)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"jax\",\n        reason=\"Memory optimization is only implemented in JAX\",\n    )\n    def test_fit_eval_flow_for_jax_model_weights(self):\n        test_obj = self\n\n        model = ExampleModel(units=3)\n        epochs = 3\n        batch_size = 20\n        steps_per_epoch = 7\n        dataset_size = batch_size * (steps_per_epoch - 2)\n        x = np.ones((dataset_size, 4))\n        y = np.zeros((dataset_size, 3))\n\n        class ModelWeightCheck(Callback):\n            def __init__(self):\n                super().__init__()\n\n            # Note that we access model via self._model since self.model\n            # will trigger a sync of the jax training state back to the model.\n            def on_train_batch_end(self, batch, logs=None):\n                for v in self._model.trainable_variables:\n                    test_obj.assertIsNone(v._value)\n                for v in self._model.non_trainable_variables:\n                    test_obj.assertIsNone(v._value)\n                for v in self._model.optimizer.variables:\n                    test_obj.assertIsNone(v._value)\n                for v in self._model.metrics_variables:\n                    test_obj.assertIsNone(v._value)\n\n            def on_test_batch_end(self, batch, logs=None):\n                for v in self._model.non_trainable_variables:\n                    test_obj.assertIsNone(v._value)\n                for v in self._model.metrics_variables:\n                    test_obj.assertIsNone(v._value)\n\n        model.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n  
      )\n\n        model.fit(\n            x,\n            y,\n            batch_size=batch_size,\n            steps_per_epoch=steps_per_epoch,\n            epochs=epochs,\n            callbacks=[ModelWeightCheck()],\n        )\n\n        model.evaluate(\n            x,\n            y,\n            batch_size=batch_size,\n            callbacks=[ModelWeightCheck()],\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            steps_per_execution=[3, 101], mode=[\"eager\", \"non_jit\", \"jit\"]\n        )\n    )\n    @pytest.mark.requires_trainable_backend\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"`steps_per_execution` not implemented for torch yet\",\n    )\n    def test_steps_per_execution_steps_count(self, steps_per_execution, mode):\n        data_size = 100\n        batch_size = 16\n        epochs = 2\n\n        x = np.ones((data_size, 4))\n        y = np.ones((data_size, 1))\n\n        model = ExampleModel(units=1)\n        model.compile(\n            loss=\"mse\",\n            optimizer=\"sgd\",\n            steps_per_execution=steps_per_execution,\n            run_eagerly=(mode == \"eager\"),\n            jit_compile=(mode == \"jit\"),\n        )\n        step_count = StepCount(steps_per_execution)\n\n        history = model.fit(\n            x=x,\n            y=y,\n            batch_size=batch_size,\n            epochs=epochs,\n            callbacks=[step_count],\n            verbose=0,\n        )\n\n        self.assertEqual(\n            step_count.begin_count,\n            1 + (data_size - 1) // (steps_per_execution * batch_size),\n        )\n        self.assertEqual(step_count.end_count, step_count.begin_count)\n        self.assertEqual(step_count.epoch_begin_count, epochs)\n        self.assertEqual(\n            step_count.epoch_end_count, step_count.epoch_begin_count\n        )\n\n        model_2 = ExampleModel(units=1)\n        model_2.compile(\n            loss=\"mse\",\n            
optimizer=\"sgd\",\n            steps_per_execution=1,\n            run_eagerly=(mode == \"eager\"),\n            jit_compile=(mode == \"jit\"),\n        )\n        history_2 = model_2.fit(\n            x=x, y=y, batch_size=batch_size, epochs=epochs, verbose=0\n        )\n\n        self.assertAllClose(history.history[\"loss\"], history_2.history[\"loss\"])\n        self.assertAllClose(model.get_weights(), model_2.get_weights())\n        self.assertAllClose(\n            model.predict(x, batch_size=batch_size),\n            model_2.predict(x, batch_size=batch_size),\n        )\n        self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y))\n\n    @parameterized.named_parameters(\n        named_product(steps_per_execution=[3, 8, 32])\n    )\n    @pytest.mark.requires_trainable_backend\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"`unrolled_steps_per_execution` is only \"\n        \"available with the tensorflow backend.\",\n    )\n    def test_steps_per_execution_unrolled_steps_steps_count(\n        self, steps_per_execution\n    ):\n        data_size = 100\n        batch_size = 16\n        epochs = 2\n        unrolled_steps_per_execution = 8\n\n        x = np.ones((data_size, 4))\n        y = np.ones((data_size, 1))\n\n        model = ExampleModel(units=1)\n        model.compile(\n            loss=\"mse\",\n            optimizer=\"sgd\",\n            steps_per_execution=steps_per_execution,\n            jit_compile=True,\n        )\n        step_count = StepCount(steps_per_execution)\n        model.unrolled_steps_per_execution = unrolled_steps_per_execution\n        history = model.fit(\n            x=x,\n            y=y,\n            batch_size=batch_size,\n            epochs=epochs,\n            callbacks=[step_count],\n            verbose=0,\n        )\n\n        self.assertEqual(\n            step_count.begin_count,\n            1 + (data_size - 1) // (steps_per_execution * batch_size),\n        )\n        
self.assertEqual(step_count.end_count, step_count.begin_count)\n        self.assertEqual(step_count.epoch_begin_count, epochs)\n        self.assertEqual(\n            step_count.epoch_end_count, step_count.epoch_begin_count\n        )\n\n        model_2 = ExampleModel(units=1)\n        model_2.compile(\n            loss=\"mse\",\n            optimizer=\"sgd\",\n            steps_per_execution=steps_per_execution,\n            jit_compile=True,\n        )\n        model_2.unrolled_steps_per_execution = 1\n        history_2 = model_2.fit(\n            x=x, y=y, batch_size=batch_size, epochs=epochs, verbose=0\n        )\n\n        self.assertAllClose(history.history[\"loss\"], history_2.history[\"loss\"])\n        self.assertAllClose(model.get_weights(), model_2.get_weights())\n        self.assertAllClose(\n            model.predict(x, batch_size=batch_size),\n            model_2.predict(x, batch_size=batch_size),\n        )\n        self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y))\n\n    @parameterized.named_parameters(\n        named_product(\n            steps_per_execution=[1, 50], mode=[\"eager\", \"non_jit\", \"jit\"]\n        )\n    )\n    def test_predict_preserve_order(self, steps_per_execution, mode):\n        if steps_per_execution > 1 and backend.backend() == \"torch\":\n            self.skipTest(\"`steps_per_execution` not implemented for torch yet\")\n\n        def generate_uneven_batches():\n            batch_sizes = [2, 3, 4]\n\n            def gen_i():\n                for i in range(100):\n                    yield i\n\n            iterator = iter(gen_i())\n            j = 0\n            while True:\n                batch_size = batch_sizes[j % len(batch_sizes)]\n                try:\n                    batch = np.array(\n                        [next(iterator) for _ in range(batch_size)]\n                    )\n                except StopIteration:\n                    break\n                j += 1\n                yield batch\n\n  
      from keras.src.utils.module_utils import tensorflow as tf\n\n        dataset = tf.data.Dataset.from_generator(\n            generate_uneven_batches,\n            output_signature=tf.TensorSpec((None,), dtype=tf.int32),\n        )\n        x = keras.layers.Input(shape=())\n        y = keras.layers.Identity()(x)\n        model = keras.Model(x, y)\n        model.compile(\n            loss=\"mse\",\n            optimizer=\"sgd\",\n            steps_per_execution=steps_per_execution,\n            run_eagerly=(mode == \"eager\"),\n            jit_compile=(mode == \"jit\"),\n        )\n\n        preds = model.predict(x=dataset, verbose=0)\n\n        self.assertAllEqual(preds, np.arange(len(preds), dtype=np.float32))\n\n    @parameterized.named_parameters(\n        named_product(\n            steps_per_execution=[1, 50], mode=[\"eager\", \"non_jit\", \"jit\"]\n        )\n    )\n    def test_predict_generator(self, steps_per_execution, mode):\n        if steps_per_execution > 1 and backend.backend() == \"torch\":\n            self.skipTest(\"`steps_per_execution` not implemented for torch yet\")\n\n        batch_size = 2\n\n        def generate_batches():\n            def gen_i():\n                for i in range(10):\n                    yield i\n\n            iterator = iter(gen_i())\n            j = 0\n            while True:\n                try:\n                    batch = np.array(\n                        [next(iterator) for _ in range(batch_size)]\n                    )\n                except StopIteration:\n                    break\n                j += 1\n                yield (batch,)\n\n        model = keras.Sequential(\n            [keras.layers.InputLayer(shape=()), keras.layers.Identity()]\n        )\n        model.compile(\n            loss=\"mse\",\n            optimizer=\"sgd\",\n            steps_per_execution=steps_per_execution,\n            run_eagerly=(mode == \"eager\"),\n            jit_compile=(mode == \"jit\"),\n        )\n\n        preds 
= model.predict(x=generate_batches(), verbose=0)\n        self.assertAllEqual(\n            preds, np.concatenate(list(generate_batches()), axis=1)[0]\n        )\n\n    @parameterized.named_parameters(\n        named_product(\n            steps_per_execution=[3, 101], mode=[\"eager\", \"non_jit\", \"jit\"]\n        )\n    )\n    @pytest.mark.requires_trainable_backend\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"`steps_per_execution` not implemented for torch yet\",\n    )\n    def test_steps_per_execution_steps_count_unknown_dataset_size(\n        self, steps_per_execution, mode\n    ):\n        data_size = 100\n        batch_size = 16\n        epochs = 2\n\n        def data_generator():\n            x = np.ones((data_size, 4), dtype=np.float32)\n            y = np.ones((data_size, 1), dtype=np.float32)\n            for _x, _y in zip(x, y):\n                yield _x, _y\n\n        import tensorflow as tf\n\n        dataset = tf.data.Dataset.from_generator(\n            data_generator,\n            output_signature=(\n                tf.TensorSpec(shape=(4,), dtype=tf.float32),\n                tf.TensorSpec(shape=(1,), dtype=tf.float32),\n            ),\n        )\n        dataset = dataset.batch(batch_size)\n\n        model = ExampleModel(units=1)\n        model.compile(\n            loss=\"mse\",\n            optimizer=\"sgd\",\n            steps_per_execution=steps_per_execution,\n            run_eagerly=(mode == \"eager\"),\n            jit_compile=(mode == \"jit\"),\n        )\n        step_count = StepCount(steps_per_execution)\n\n        history = model.fit(\n            dataset,\n            epochs=epochs,\n            callbacks=[step_count],\n            verbose=0,\n        )\n\n        batch_count = 1 + (data_size - 1) // (steps_per_execution * batch_size)\n        self.assertGreaterEqual(step_count.begin_count, batch_count)\n        self.assertEqual(step_count.end_count, batch_count)\n        
self.assertEqual(step_count.epoch_begin_count, epochs)\n        self.assertEqual(\n            step_count.epoch_end_count, step_count.epoch_begin_count\n        )\n\n        model_2 = ExampleModel(units=1)\n        model_2.compile(\n            loss=\"mse\",\n            optimizer=\"sgd\",\n            steps_per_execution=1,\n            run_eagerly=(mode == \"eager\"),\n            jit_compile=(mode == \"jit\"),\n        )\n        history_2 = model_2.fit(dataset, epochs=epochs, verbose=0)\n\n        self.assertAllClose(history.history[\"loss\"], history_2.history[\"loss\"])\n        self.assertAllClose(model.get_weights(), model_2.get_weights())\n        self.assertAllClose(\n            model.predict(dataset),\n            model_2.predict(dataset),\n        )\n        self.assertAllClose(model.evaluate(dataset), model_2.evaluate(dataset))\n\n    @parameterized.named_parameters(\n        named_product(\n            steps_per_epoch_test=[\n                \"match_one_epoch\",\n                \"match_multi_epoch\",\n                \"not_match_too_low\",\n                \"not_match_but_high_enough\",\n            ],\n            mode=[\"eager\", \"non_jit\", \"jit\"],\n        )\n    )\n    @pytest.mark.requires_trainable_backend\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"`steps_per_execution` not implemented for torch yet\",\n    )\n    def test_steps_per_execution_steps_per_epoch(\n        self, steps_per_epoch_test, mode\n    ):\n        batch_size = 8\n        epochs = 2\n        steps_per_execution = 2\n        num_batches = 5 * steps_per_execution\n        data_size = num_batches * batch_size\n\n        if steps_per_epoch_test == \"match_one_epoch\":\n            steps_per_epoch = num_batches\n        elif steps_per_epoch_test == \"match_multi_epoch\":\n            steps_per_epoch = num_batches // steps_per_execution\n        elif steps_per_epoch_test == \"not_match_too_low\":\n            steps_per_epoch = 
num_batches - steps_per_execution\n        elif steps_per_epoch_test == \"not_match_but_high_enough\":\n            steps_per_epoch = num_batches + steps_per_execution\n\n        x = np.ones((data_size, 4))\n        y = np.ones((data_size, 1))\n\n        model = ExampleModel(units=1)\n        model.compile(\n            loss=\"mse\",\n            optimizer=\"sgd\",\n            metrics=[EpochAgnosticMeanSquaredError()],\n            steps_per_execution=steps_per_execution,\n            run_eagerly=(mode == \"eager\"),\n            jit_compile=(mode == \"jit\"),\n        )\n        step_observer = StepObserver()\n\n        model.fit(\n            x=x,\n            y=y,\n            batch_size=batch_size,\n            epochs=epochs,\n            steps_per_epoch=steps_per_epoch,\n            callbacks=[step_observer],\n            verbose=0,\n        )\n        if steps_per_epoch_test != \"not_match_too_low\":\n            training_batch_count = (\n                epochs\n                * min(steps_per_epoch, num_batches)\n                // steps_per_execution\n            )\n        else:\n            complete_epochs = (num_batches // steps_per_execution) // (\n                steps_per_epoch // steps_per_execution\n            )\n            remaining_steps = (num_batches // steps_per_execution) % (\n                steps_per_epoch // steps_per_execution\n            )\n            steps_cycles = [\n                complete_epochs * steps_per_epoch // steps_per_execution,\n                remaining_steps,\n            ] * epochs\n            steps_per_epochs = steps_cycles[:epochs]\n            training_batch_count = sum(steps_per_epochs)\n\n        self.assertEqual(step_observer.begin_count, training_batch_count)\n        self.assertEqual(step_observer.end_count, step_observer.begin_count)\n        self.assertEqual(step_observer.epoch_begin_count, epochs)\n        self.assertEqual(\n            step_observer.epoch_end_count, step_observer.epoch_begin_count\n      
  )\n\n        if steps_per_epoch_test != \"not_match_too_low\":\n            model_2 = ExampleModel(units=1)\n            model_2.compile(\n                loss=\"mse\",\n                optimizer=\"sgd\",\n                metrics=[EpochAgnosticMeanSquaredError()],\n                steps_per_execution=1,\n                run_eagerly=(mode == \"eager\"),\n                jit_compile=(mode == \"jit\"),\n            )\n            step_observer_2 = StepObserver()\n\n            if steps_per_epoch_test in (\n                \"not_match_but_high_enough\",\n                \"match_one_epoch\",\n            ):\n                model_2_epochs = epochs\n            else:\n                model_2_epochs = 1\n\n            model_2.fit(\n                x=x,\n                y=y,\n                batch_size=batch_size,\n                epochs=model_2_epochs,\n                callbacks=[step_observer_2],\n                verbose=0,\n            )\n\n            losses = step_observer.batch_loss_history\n            losses_2 = step_observer_2.batch_loss_history[\n                steps_per_execution - 1 :: steps_per_execution\n            ]\n            self.assertAllClose(losses, losses_2)\n            self.assertAllClose(model.get_weights(), model_2.get_weights())\n            self.assertAllClose(\n                model.predict(x, batch_size=batch_size),\n                model_2.predict(x, batch_size=batch_size),\n            )\n            self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y))\n\n    @parameterized.named_parameters(\n        named_product(\n            steps_per_epoch_test=[\n                \"match_one_epoch\",\n                \"match_multi_epoch\",\n                \"not_match_too_low\",\n                \"not_match_but_high_enough\",\n            ],\n            mode=[\"eager\", \"non_jit\", \"jit\"],\n        )\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_steps_per_epoch(self, steps_per_epoch_test, mode):\n        
batch_size = 8\n        epochs = 4\n        num_batches = 10\n        data_size = num_batches * batch_size\n\n        if steps_per_epoch_test == \"match_one_epoch\":\n            steps_per_epoch = num_batches\n        elif steps_per_epoch_test == \"match_multi_epoch\":\n            steps_per_epoch = num_batches // (epochs // 2)\n        elif steps_per_epoch_test == \"not_match_too_low\":\n            steps_per_epoch = num_batches - 1\n        elif steps_per_epoch_test == \"not_match_but_high_enough\":\n            steps_per_epoch = num_batches + 1\n\n        x = np.ones((data_size, 4))\n        y = np.ones((data_size, 1))\n\n        model = ExampleModel(units=1)\n        model.compile(\n            loss=\"mse\",\n            optimizer=\"sgd\",\n            metrics=[EpochAgnosticMeanSquaredError()],\n            run_eagerly=(mode == \"eager\"),\n            jit_compile=(mode == \"jit\"),\n        )\n        step_observer = StepObserver()\n\n        model.fit(\n            x=x,\n            y=y,\n            batch_size=batch_size,\n            epochs=epochs,\n            steps_per_epoch=steps_per_epoch,\n            callbacks=[step_observer],\n            verbose=0,\n        )\n        if steps_per_epoch_test != \"not_match_too_low\":\n            training_batch_count = epochs * min(steps_per_epoch, num_batches)\n        else:\n            complete_epochs = num_batches // steps_per_epoch\n            remaining_steps = num_batches % steps_per_epoch\n            steps_cycles = [\n                complete_epochs * steps_per_epoch,\n                remaining_steps,\n            ] * epochs\n            steps_per_epochs = steps_cycles[:epochs]\n            training_batch_count = sum(steps_per_epochs)\n\n        self.assertEqual(step_observer.begin_count, training_batch_count)\n        self.assertEqual(step_observer.end_count, step_observer.begin_count)\n        self.assertEqual(step_observer.epoch_begin_count, epochs)\n        self.assertEqual(\n            
step_observer.epoch_end_count, step_observer.epoch_begin_count\n        )\n\n        if steps_per_epoch_test != \"not_match_too_low\":\n            model_2 = ExampleModel(units=1)\n            model_2.compile(\n                loss=\"mse\",\n                optimizer=\"sgd\",\n                metrics=[EpochAgnosticMeanSquaredError()],\n                steps_per_execution=1,\n                run_eagerly=(mode == \"eager\"),\n                jit_compile=(mode == \"jit\"),\n            )\n            step_observer_2 = StepObserver()\n\n            if steps_per_epoch_test in (\n                \"not_match_but_high_enough\",\n                \"match_one_epoch\",\n            ):\n                model_2_epochs = epochs\n            elif steps_per_epoch_test == \"match_multi_epoch\":\n                model_2_epochs = epochs // (num_batches // steps_per_epoch)\n            else:\n                model_2_epochs = 1\n\n            model_2.fit(\n                x=x,\n                y=y,\n                batch_size=batch_size,\n                epochs=model_2_epochs,\n                callbacks=[step_observer_2],\n                verbose=0,\n            )\n\n            losses = step_observer.batch_loss_history\n            losses_2 = step_observer_2.batch_loss_history\n\n            self.assertAllClose(losses, losses_2)\n            self.assertAllClose(model.get_weights(), model_2.get_weights())\n            self.assertAllClose(\n                model.predict(x, batch_size=batch_size),\n                model_2.predict(x, batch_size=batch_size),\n            )\n            self.assertAllClose(model.evaluate(x, y), model_2.evaluate(x, y))\n\n    @pytest.mark.requires_trainable_backend\n    def test_max_epochs_and_steps(self):\n        batch_size = 8\n        epochs = 4\n        num_batches = 10\n        data_size = num_batches * batch_size\n        x, y = np.ones((data_size, 4)), np.ones((data_size, 1))\n        model = ExampleModel(units=1)\n        model.compile(\n            
loss=\"mse\",\n            optimizer=\"sgd\",\n            metrics=[EpochAgnosticMeanSquaredError()],\n        )\n        step_observer = StepObserver()\n        model.fit(\n            x=x,\n            y=y,\n            batch_size=batch_size,\n            epochs=epochs,\n            callbacks=[step_observer],\n            verbose=0,\n        )\n        self.assertEqual(step_observer.epoch_begin_count, epochs)\n        self.assertEqual(step_observer.begin_count, num_batches * epochs)\n        try:\n            config.set_max_epochs(2)\n            config.set_max_steps_per_epoch(3)\n            step_observer = StepObserver()\n            model.fit(\n                x=x,\n                y=y,\n                batch_size=batch_size,\n                epochs=epochs,\n                callbacks=[step_observer],\n                verbose=0,\n            )\n            self.assertEqual(step_observer.epoch_begin_count, 2)\n            self.assertEqual(step_observer.begin_count, 6)\n        finally:\n            config.set_max_epochs(None)\n            config.set_max_steps_per_epoch(None)\n\n    @parameterized.named_parameters(\n        named_product(\n            steps_per_epoch_test=[\n                \"match\",\n                \"not_match_too_low\",\n                \"not_match_but_high_enough\",\n            ],\n            mode=[\"eager\", \"non_jit\", \"jit\"],\n        )\n    )\n    @pytest.mark.requires_trainable_backend\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"`steps_per_execution` not implemented for torch yet\",\n    )\n    def test_steps_per_execution_steps_per_epoch_unknown_data_size(\n        self, steps_per_epoch_test, mode\n    ):\n        batch_size = 8\n        epochs = 2\n        steps_per_execution = 2\n        num_batches = 5 * epochs * steps_per_execution\n        data_size = num_batches * batch_size\n\n        if steps_per_epoch_test == \"match\":\n            steps_per_epoch = num_batches // epochs\n        
elif steps_per_epoch_test == \"not_match_too_low\":\n            steps_per_epoch = num_batches - steps_per_execution\n        elif steps_per_epoch_test == \"not_match_but_high_enough\":\n            steps_per_epoch = num_batches + steps_per_execution\n\n        def data_generator():\n            x = np.ones((data_size, 4), dtype=np.float32)\n            y = np.ones((data_size, 1), dtype=np.float32)\n            for _x, _y in zip(x, y):\n                yield _x, _y\n\n        import tensorflow as tf\n\n        dataset = tf.data.Dataset.from_generator(\n            data_generator,\n            output_signature=(\n                tf.TensorSpec(shape=(4,), dtype=tf.float32),\n                tf.TensorSpec(shape=(1,), dtype=tf.float32),\n            ),\n        )\n        dataset = dataset.batch(batch_size)\n\n        model = ExampleModel(units=1)\n        model.compile(\n            loss=\"mse\",\n            optimizer=\"sgd\",\n            metrics=[EpochAgnosticMeanSquaredError()],\n            steps_per_execution=steps_per_execution,\n            run_eagerly=(mode == \"eager\"),\n            jit_compile=(mode == \"jit\"),\n        )\n        step_observer = StepObserver()\n\n        model.fit(\n            dataset,\n            epochs=epochs,\n            steps_per_epoch=steps_per_epoch,\n            callbacks=[step_observer],\n            verbose=0,\n        )\n        if steps_per_epoch_test != \"not_match_too_low\":\n            training_batch_count = (\n                epochs\n                * min(steps_per_epoch, num_batches)\n                // steps_per_execution\n            )\n        else:\n            complete_epochs = (num_batches // steps_per_execution) // (\n                steps_per_epoch // steps_per_execution\n            )\n            remaining_steps = (num_batches // steps_per_execution) % (\n                steps_per_epoch // steps_per_execution\n            )\n            steps_cycles = [\n                complete_epochs * steps_per_epoch // 
steps_per_execution,\n                remaining_steps,\n            ] * epochs\n            steps_per_epochs = steps_cycles[:epochs]\n            training_batch_count = sum(steps_per_epochs)\n\n        self.assertGreaterEqual(step_observer.begin_count, training_batch_count)\n        self.assertEqual(step_observer.end_count, training_batch_count)\n        self.assertEqual(step_observer.epoch_begin_count, epochs)\n        self.assertEqual(\n            step_observer.epoch_end_count, step_observer.epoch_begin_count\n        )\n\n        if steps_per_epoch_test != \"not_match_too_low\":\n            model_2 = ExampleModel(units=1)\n            model_2.compile(\n                loss=\"mse\",\n                optimizer=\"sgd\",\n                metrics=[EpochAgnosticMeanSquaredError()],\n                steps_per_execution=1,\n                run_eagerly=(mode == \"eager\"),\n                jit_compile=(mode == \"jit\"),\n            )\n            step_observer_2 = StepObserver()\n\n            if steps_per_epoch_test == \"not_match_but_high_enough\":\n                model_2_epochs = epochs\n            else:\n                model_2_epochs = 1\n\n            model_2.fit(\n                dataset,\n                epochs=model_2_epochs,\n                callbacks=[step_observer_2],\n                verbose=0,\n            )\n\n            losses = step_observer.batch_loss_history\n            losses_2 = step_observer_2.batch_loss_history[\n                steps_per_execution - 1 :: steps_per_execution\n            ]\n            self.assertAllClose(losses, losses_2)\n            self.assertAllClose(model.get_weights(), model_2.get_weights())\n            self.assertAllClose(\n                model.predict(dataset), model_2.predict(dataset)\n            )\n            self.assertAllClose(\n                model.evaluate(dataset), model_2.evaluate(dataset)\n            )\n\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        
reason=\"`steps_per_execution` not implemented for torch yet\",\n    )\n    def test_steps_per_execution_steps_count_without_training(self):\n        test_obj = self\n\n        class StepCount(Callback):\n            def __init__(self):\n                super().__init__()\n                self.test_count = 0\n                self.predict_count = 0\n                self.batches = [0, 3, 6]\n\n            def on_test_batch_begin(self, batch, logs=None):\n                test_obj.assertEqual(batch, self.batches[self.test_count])\n                self.test_count += 1\n\n            def on_predict_batch_begin(self, batch, logs=None):\n                test_obj.assertEqual(batch, self.batches[self.predict_count])\n                self.predict_count += 1\n\n        x = np.ones((100, 4))\n        y = np.ones((100, 1))\n        batch_size = 16\n        model = ExampleModel(units=1)\n        model.compile(loss=\"mse\", steps_per_execution=3)\n        step_count = StepCount()\n        model.predict(x, batch_size=batch_size, callbacks=[step_count])\n        self.assertEqual(step_count.predict_count, 3)\n        model.evaluate(x, y, batch_size=batch_size, callbacks=[step_count])\n        self.assertEqual(step_count.test_count, 3)\n\n    @pytest.mark.requires_trainable_backend\n    def test_fit_with_different_batch_size_same_loss(self):\n        x = np.random.rand(100, 4)\n        y = np.ones((100, 1))\n        model = ExampleModel(units=1)\n        model.trainable = False\n        model.compile(loss=\"mse\")\n        loss1 = model.fit(x, y, batch_size=80).history[\"loss\"]\n        loss2 = model.fit(x, y, batch_size=100).history[\"loss\"]\n        self.assertAllClose(loss1, loss2)\n\n    def test_evaluate_with_different_batch_size_same_loss(self):\n        x = np.random.rand(100, 4)\n        y = np.ones((100, 1))\n        model = ExampleModel(units=1)\n        model.compile(loss=\"mse\")\n        loss1 = model.evaluate(x, y, batch_size=80)\n        loss2 = model.evaluate(x, y, 
batch_size=100)\n        self.assertAllClose(loss1, loss2)\n\n    @pytest.mark.requires_trainable_backend\n    def test_adds_loss_scaling_optimizer(self):\n        model = TrainingTestingLayer(dtype=\"mixed_float16\")\n        model.compile(optimizer=\"rmsprop\", loss=\"mse\")\n        x = np.ones((128, 1))\n        y = np.zeros((128, 1))\n        model.fit(x, y, batch_size=32)\n        self.assertIsInstance(model.optimizer, optimizers.LossScaleOptimizer)\n\n        model = TrainingTestingLayer(dtype=\"mixed_float16\")\n        model.compile(optimizer=\"rmsprop\", loss=\"mse\", auto_scale_loss=False)\n        x = np.ones((128, 1))\n        y = np.zeros((128, 1))\n        model.fit(x, y, batch_size=32)\n        self.assertIsInstance(model.optimizer, RMSprop)\n\n        model = TrainingTestingLayer(dtype=\"mixed_bfloat16\")\n        model.compile(optimizer=\"rmsprop\", loss=\"mse\")\n        x = np.ones((128, 1))\n        y = np.zeros((128, 1))\n        model.fit(x, y, batch_size=32)\n        self.assertIsInstance(model.optimizer, RMSprop)\n\n    @pytest.mark.requires_trainable_backend\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"half precision unsupported on torch CPU.\",\n    )\n    def test_loss_scaling_prevents_underflow(self):\n        class DeepModel(Trainer, layers.Layer):\n            def __init__(self):\n                layers.Layer.__init__(self, dtype=\"mixed_float16\")\n                Trainer.__init__(self)\n                self.layers = []\n                for _ in range(15):\n                    # Sigmoid has a small gradient, will eventually underflow.\n                    self.layers.append(\n                        layers.Dense(\n                            1,\n                            use_bias=False,\n                            kernel_initializer=\"ones\",\n                            activation=\"sigmoid\",\n                            dtype=\"mixed_float16\",\n                        )\n                 
   )\n\n            def call(self, x):\n                for layer in self.layers:\n                    x = layer(x)\n                return x\n\n        loss = losses.MeanSquaredError()\n        # Blow up any gradient updates, so underflow is obvious.\n        optimizer = optimizers.SGD(learning_rate=1e9)\n        model = DeepModel()\n        model.compile(optimizer, loss=loss, auto_scale_loss=False)\n        model.fit(np.ones((1, 1)), np.ones((1, 1)), batch_size=1)\n        first_kernel = model.layers[0].kernel\n        # Without autoscaling, the first dense will not update.\n        self.assertEqual(first_kernel, np.ones_like(first_kernel))\n\n        # Blow up any gradient updates, so underflow is obvious.\n        optimizer = optimizers.SGD(learning_rate=1e9)\n        model = DeepModel()\n        model.compile(optimizer, loss=loss, auto_scale_loss=True)\n        model.fit(np.ones((1, 1)), np.ones((1, 1)), batch_size=1)\n        first_kernel = model.layers[0].kernel\n        # With autoscaling, the first dense will update.\n        self.assertNotEqual(first_kernel, np.ones_like(first_kernel))\n\n    @pytest.mark.requires_trainable_backend\n    def test_training_arg(self):\n        model = TrainingTestingLayer()\n        model.compile(optimizer=\"rmsprop\", loss=\"mse\")\n        x = np.ones((128, 1))\n        y = np.zeros((128, 1))\n        history = model.fit(x, y, batch_size=32)\n        self.assertAllClose(history.history[\"loss\"], [1.0])\n        val_loss = model.evaluate(x, y, batch_size=32)\n        self.assertAllClose(val_loss, 0.0)\n        preds = model.predict(x)\n        self.assertAllClose(preds, np.zeros((128, 1)))\n\n    @parameterized.named_parameters(\n        [\n            (\"eager\", True, False),\n            (\"graph_fn\", False, False),\n            (\"jit\", False, True),\n        ]\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_on_batch_methods(self, run_eagerly, jit_compile):\n        if backend.backend() == \"torch\" 
and jit_compile:\n            self.skipTest(\n                \"test_on_batch with jit_compile=True not supported in torch \"\n                \"backend yet.\"\n            )\n        model = ExampleModel(units=3)\n        x = np.ones((100, 4))\n        y = np.zeros((100, 3))\n        sw = np.arange(100).reshape((100,)).astype(\"float32\") / 50.0\n\n        model.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n            run_eagerly=run_eagerly,\n            jit_compile=jit_compile,\n        )\n        logs = model.train_on_batch(x, y)\n        self.assertIsInstance(logs, list)\n        self.assertEqual(len(logs), 2)\n        self.assertAlmostEqual(logs[0], 16.0)\n\n        logs = model.train_on_batch(x, y, return_dict=True)\n        self.assertIsInstance(logs, dict)\n        self.assertEqual(len(logs), 2)\n        self.assertAlmostEqual(logs[\"loss\"], 15.579, tpu_decimal=1)\n\n        logs = model.test_on_batch(x, y)\n        self.assertIsInstance(logs, list)\n        self.assertEqual(len(logs), 2)\n        self.assertAlmostEqual(logs[0], 15.173, tpu_decimal=1)\n\n        logs = model.test_on_batch(x, y, return_dict=True)\n        self.assertIsInstance(logs, dict)\n        self.assertEqual(len(logs), 2)\n        self.assertAlmostEqual(logs[\"loss\"], 14.97, tpu_decimal=1)\n\n        output = model.predict_on_batch(x)\n        self.assertIsInstance(output, np.ndarray)\n        self.assertAllClose(\n            output[0],\n            np.array([3.789511, 3.789511, 3.789511]),\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n        # With sample weights\n        logs = model.train_on_batch(x, y, sw)\n        self.assertAlmostEqual(logs[0], 14.819, tpu_decimal=1)\n        logs = model.test_on_batch(x, y, sw)\n        self.assertAlmostEqual(logs[0], 14.595, tpu_decimal=1)\n        output = model.predict_on_batch(x)\n        self.assertAllClose(\n   
         output[0],\n            np.array([3.689468, 3.689468, 3.689468]),\n            tpu_atol=1e-2,\n            tpu_rtol=1e-2,\n        )\n\n        # With class weights\n        logs = model.train_on_batch(x, y, class_weight={1: 0.3, 0: 0.2})\n        self.assertAlmostEqual(logs[0], 12.899, tpu_decimal=1)\n\n    @parameterized.named_parameters(\n        [\n            (\"eager\", True, False),\n            (\"graph_fn\", False, False),\n            (\"jit\", False, True),\n        ]\n    )\n    def test_on_batch_methods_without_training(self, run_eagerly, jit_compile):\n        if backend.backend() == \"torch\" and jit_compile:\n            self.skipTest(\n                \"test_on_batch with jit_compile=True not supported in torch \"\n                \"backend yet.\"\n            )\n        model = ExampleModel(units=3)\n        x = np.ones((100, 4))\n        y = np.zeros((100, 3))\n\n        model.compile(\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n            run_eagerly=run_eagerly,\n            jit_compile=jit_compile,\n        )\n        output = model.predict_on_batch(x)\n        self.assertIsInstance(output, np.ndarray)\n        self.assertAllClose(output[0], np.array([4.0, 4.0, 4.0]))\n\n        logs = model.test_on_batch(x, y)\n        self.assertIsInstance(logs, list)\n        self.assertEqual(len(logs), 2)\n        self.assertAlmostEqual(logs[0], 16.0)\n\n        logs = model.test_on_batch(x, y, return_dict=True)\n        self.assertIsInstance(logs, dict)\n        self.assertEqual(len(logs), 2)\n        self.assertAlmostEqual(logs[\"loss\"], 16.0)\n\n    def test_nested_input_predict(self):\n        # https://github.com/keras-team/keras/issues/325\n\n        class TupleInputModel(keras.Model):\n            def call(self, inputs):\n                a, b = inputs\n                return a + b\n\n        model = TupleInputModel()\n        x1, x2 = np.random.rand(2, 3, 4)\n        out = 
model.predict((x1, x2))\n        self.assertEqual(out.shape, (3, 4))\n\n        class DictInputModel(keras.Model):\n            def call(self, inputs):\n                return inputs[\"a\"] + inputs[\"b\"]\n\n        model = DictInputModel()\n        x1, x2 = np.random.rand(2, 3, 4)\n        out = model.predict({\"a\": x1, \"b\": x2})\n        self.assertEqual(out.shape, (3, 4))\n\n    @pytest.mark.requires_trainable_backend\n    def test_for_eval_epoch_iterator(self):\n        model = ExampleModel(units=3)\n        model.compile(\n            optimizer=\"adam\", loss=\"mse\", metrics=[\"mean_absolute_error\"]\n        )\n        x = np.ones((16, 4))\n        y = np.zeros((16, 3))\n        x_test = np.ones((16, 4))\n        y_test = np.zeros((16, 3))\n        model.fit(\n            x,\n            y,\n            batch_size=4,\n            validation_data=(x_test, y_test),\n        )\n        self.assertIsNone(getattr(model, \"_eval_epoch_iterator\", None))\n\n        # Try model.fit with reshaped validation_data.\n        # This is expected to raise an exception.\n        try:\n            model.fit(\n                x,\n                y,\n                batch_size=4,\n                validation_data=(\n                    x_test.reshape((-1, 16, 4)),\n                    y_test.reshape((-1, 16, 3)),\n                ),\n            )\n        except Exception:\n            pass\n\n        # Try model.fit with correct validation_data; this should work.\n        # After successful training, `_eval_epoch_iterator` should be None.\n        model.fit(\n            x,\n            y,\n            batch_size=4,\n            validation_data=(x_test, y_test),\n        )\n        self.assertIsNone(getattr(model, \"_eval_epoch_iterator\", None))\n\n    @pytest.mark.requires_trainable_backend\n    def test_callback_methods_keys(self):\n        test_obj = self\n\n        class CustomCallback(Callback):\n            def on_train_begin(self, logs=None):\n                keys = 
sorted(list(logs.keys()))\n                test_obj.assertEqual(keys, [])\n\n            def on_train_end(self, logs=None):\n                keys = sorted(list(logs.keys()))\n                test_obj.assertEqual(\n                    keys,\n                    [\n                        \"loss\",\n                        \"mean_absolute_error\",\n                        \"val_loss\",\n                        \"val_mean_absolute_error\",\n                    ],\n                )\n\n            def on_epoch_begin(self, epoch, logs=None):\n                keys = sorted(list(logs.keys()))\n                test_obj.assertEqual(keys, [])\n\n            def on_epoch_end(self, epoch, logs=None):\n                keys = sorted(list(logs.keys()))\n                test_obj.assertEqual(\n                    keys,\n                    [\n                        \"loss\",\n                        \"mean_absolute_error\",\n                        \"val_loss\",\n                        \"val_mean_absolute_error\",\n                    ],\n                )\n\n            def on_test_begin(self, logs=None):\n                keys = sorted(list(logs.keys()))\n                test_obj.assertEqual(keys, [])\n\n            def on_test_end(self, logs=None):\n                keys = sorted(list(logs.keys()))\n                test_obj.assertEqual(keys, [\"loss\", \"mean_absolute_error\"])\n\n            def on_predict_begin(self, logs=None):\n                keys = sorted(list(logs.keys()))\n                test_obj.assertEqual(keys, [])\n\n            def on_predict_end(self, logs=None):\n                keys = sorted(list(logs.keys()))\n                test_obj.assertEqual(keys, [])\n\n            def on_train_batch_begin(self, batch, logs=None):\n                keys = sorted(list(logs.keys()))\n                test_obj.assertEqual(keys, [])\n\n            def on_train_batch_end(self, batch, logs=None):\n                keys = sorted(list(logs.keys()))\n                
test_obj.assertEqual(keys, [\"loss\", \"mean_absolute_error\"])\n\n            def on_test_batch_begin(self, batch, logs=None):\n                keys = sorted(list(logs.keys()))\n                test_obj.assertEqual(keys, [])\n\n            def on_test_batch_end(self, batch, logs=None):\n                keys = sorted(list(logs.keys()))\n                test_obj.assertEqual(keys, [\"loss\", \"mean_absolute_error\"])\n\n            def on_predict_batch_begin(self, batch, logs=None):\n                keys = sorted(list(logs.keys()))\n                test_obj.assertEqual(keys, [])\n\n            def on_predict_batch_end(self, batch, logs=None):\n                keys = sorted(list(logs.keys()))\n                test_obj.assertEqual(keys, [\"outputs\"])\n\n        model = ExampleModel(units=3)\n        model.compile(\n            optimizer=\"adam\", loss=\"mse\", metrics=[\"mean_absolute_error\"]\n        )\n        x = np.ones((16, 4))\n        y = np.zeros((16, 3))\n        x_test = np.ones((16, 4))\n        y_test = np.zeros((16, 3))\n        model.fit(\n            x,\n            y,\n            callbacks=[CustomCallback()],\n            batch_size=4,\n            validation_data=(x_test, y_test),\n        )\n        model.evaluate(x_test, y_test, batch_size=4)\n        model.predict(x_test, batch_size=4)\n\n    @pytest.mark.requires_trainable_backend\n    def test_internal_only_loss(self):\n        class LossLayer(layers.Layer):\n            def call(self, x):\n                self.add_loss(ops.sum(x))\n                return x\n\n        model = keras.Sequential(\n            [\n                layers.Dense(2),\n                LossLayer(),\n                layers.Dense(1),\n            ]\n        )\n        model.compile(optimizer=\"adam\")\n        x = np.ones((16, 2))\n        y = np.zeros((16, 1))\n        model.fit(x, y, batch_size=4)\n\n    def get_layer(self):\n        class ExampleLayer(keras.Layer):\n            def call(self, x):\n                return 
x * 2\n\n        return ExampleLayer\n\n    def get_model(self):\n        class ExampleModel(keras.Model):\n            def call(self, x):\n                return x * 2\n\n        return ExampleModel\n\n    def get_functional(self):\n        ExampleLayer = self.get_layer()\n\n        class ExampleFunctional(keras.src.Functional):\n            def __init__(self, input_shape=(None,)):\n                inputs = keras.Input(input_shape)\n                outputs = ExampleLayer()(inputs)\n                super().__init__(inputs=inputs, outputs=outputs)\n\n        return ExampleFunctional\n\n    @parameterized.named_parameters(\n        [\n            {\n                \"testcase_name\": \"model\",\n                \"model_class\": \"get_model\",\n            },\n            {\n                \"testcase_name\": \"layer\",\n                \"model_class\": \"get_layer\",\n            },\n            {\n                \"testcase_name\": \"functional\",\n                \"model_class\": \"get_functional\",\n            },\n        ]\n    )\n    @pytest.mark.requires_trainable_backend\n    @pytest.mark.skipif(\n        keras.backend.backend() != \"tensorflow\",\n        reason=\"Only tensorflow supports raggeds\",\n    )\n    def test_trainer_with_raggeds(self, model_class):\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        def loss_fn(y, y_pred, sample_weight=None):\n            return 0\n\n        model = getattr(self, model_class)()()\n        x = tf.ragged.constant([[1], [2, 3]])\n\n        # test forward pass\n        y = model(x)\n        self.assertEqual(type(y), tf.RaggedTensor)\n\n        # test training\n        if model_class in [\"get_model\", \"get_functional\"]:\n            model.compile(optimizer=\"adam\", loss=loss_fn)\n            model.fit(x, x)\n            y = model.predict(x)\n            self.assertEqual(type(y), tf.RaggedTensor)\n\n        # test if everything works with the sequential model\n        model = 
keras.Sequential([model])\n        model.compile(optimizer=\"adam\", loss=loss_fn)\n        model.fit(x, x)\n        y = model.predict(x)\n        self.assertEqual(type(y), tf.RaggedTensor)\n\n    def test_predict_dropout(self):\n        # Test that `predict` with a dropout op\n        # has nondeterministic behavior across batches.\n\n        inputs = layers.Input((20,))\n        outputs = layers.Dropout(0.5, seed=1337)(inputs, training=True)\n        model = keras.Model(inputs, outputs)\n        out1 = model.predict(np.ones((4, 20)), batch_size=2)\n        self.assertGreater(np.sum(np.abs(out1[:2, :] - out1[2:4, :])), 5)\n\n        out2 = model.predict_on_batch(np.ones((2, 20)))\n        out3 = model.predict_on_batch(np.ones((2, 20)))\n        self.assertGreater(np.sum(np.abs(out2 - out3)), 5)\n\n    @pytest.mark.requires_trainable_backend\n    def test_recompile(self):\n        model = ExampleModel(units=3)\n        model.compile(\n            optimizer=\"sgd\", loss=\"mse\", metrics=[\"mean_squared_error\"]\n        )\n        history_1 = model.fit(np.ones((3, 2)), np.ones((3, 3))).history\n        eval_out_1 = model.evaluate(\n            np.ones((3, 2)), np.ones((3, 3)), return_dict=True\n        )\n        model.compile(\n            optimizer=\"sgd\", loss=\"mse\", metrics=[\"mean_absolute_error\"]\n        )\n        history_2 = model.fit(np.ones((3, 2)), np.ones((3, 3))).history\n        eval_out_2 = model.evaluate(\n            np.ones((3, 2)), np.ones((3, 3)), return_dict=True\n        )\n        self.assertEqual(\n            sorted(list(history_1.keys())), [\"loss\", \"mean_squared_error\"]\n        )\n        self.assertEqual(\n            sorted(list(eval_out_1.keys())), [\"loss\", \"mean_squared_error\"]\n        )\n        self.assertEqual(\n            sorted(list(history_2.keys())), [\"loss\", \"mean_absolute_error\"]\n        )\n        self.assertEqual(\n            sorted(list(eval_out_2.keys())), [\"loss\", \"mean_absolute_error\"]\n        
)\n\n    def test_evaluate_return_list_respect_metrics_order(self):\n        def metrics_zero(y_true, y_pred):\n            return 0.0\n\n        def metrics_one(y_true, y_pred):\n            return 1.0\n\n        model = ExampleModel(units=3)\n        model.compile(\n            optimizer=\"sgd\", loss=\"mse\", metrics=[metrics_zero, metrics_one]\n        )\n        eval_out = model.evaluate(np.ones((3, 2)), np.ones((3, 3)))\n        self.assertLen(eval_out, 3)\n        self.assertEqual(eval_out[1], 0.0)\n        self.assertEqual(eval_out[2], 1.0)\n\n        model.compile(\n            optimizer=\"sgd\", loss=\"mse\", metrics=[metrics_one, metrics_zero]\n        )\n        eval_out = model.evaluate(np.ones((3, 2)), np.ones((3, 3)))\n        self.assertLen(eval_out, 3)\n        self.assertEqual(eval_out[1], 1.0)\n        self.assertEqual(eval_out[2], 0.0)\n\n    @pytest.mark.requires_trainable_backend\n    def test_nested_inputs(self):\n        model = ListInputModel(units=2)\n        out = model([np.ones((3, 2)), np.ones((3, 3))])\n        self.assertEqual(tuple(out.shape), (3, 2))\n        model.compile(optimizer=\"sgd\", loss=\"mse\", metrics=[\"mse\"])\n        history = model.fit(\n            [np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2))\n        ).history\n        self.assertAllClose(history[\"loss\"], 16.0, tpu_atol=1e-4, tpu_rtol=1e-4)\n        train_out = model.train_on_batch(\n            [np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2))\n        )\n        self.assertAllClose(train_out[0], 15.2200, tpu_atol=1e-1, tpu_rtol=1e-1)\n        eval_out = model.evaluate(\n            [np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2))\n        )\n        self.assertAllClose(eval_out[0], 13.0321, tpu_atol=1e-2, tpu_rtol=1e-2)\n        eval_out = model.test_on_batch(\n            [np.ones((3, 2)), np.ones((3, 3))], np.ones((3, 2))\n        )\n        self.assertAllClose(eval_out[0], 13.0321, tpu_atol=1e-2, tpu_rtol=1e-2)\n        predict_out = 
model.predict([np.ones((3, 2)), np.ones((3, 3))])\n        self.assertEqual(predict_out.shape, (3, 2))\n        predict_out = model.predict_on_batch([np.ones((3, 2)), np.ones((3, 3))])\n        self.assertEqual(predict_out.shape, (3, 2))\n\n    @pytest.mark.requires_trainable_backend\n    def test_validation_data_infinite_generator(self):\n        # Test that you can pass an infinite generator to `validation_data`\n        # arg of fit() as well as a `validation_steps` argument and that\n        # validation only runs for the correct number of steps.\n        model = ExampleModel(units=3)\n        model.compile(optimizer=\"sgd\", loss=\"mse\", metrics=[\"mse\"])\n\n        class Recorder(keras.callbacks.Callback):\n            def __init__(self):\n                self.train_counter = 0\n                self.val_counter = 0\n\n            def on_train_batch_end(self, *args, **kwargs):\n                self.train_counter += 1\n\n            def on_test_batch_end(self, *args, **kwargs):\n                self.val_counter += 1\n\n        def infinite_gen():\n            while True:\n                yield np.ones((2, 2)), np.ones((2, 3))\n\n        recorder = Recorder()\n\n        model.fit(\n            infinite_gen(),\n            validation_data=infinite_gen(),\n            steps_per_epoch=3,\n            validation_steps=4,\n            epochs=1,\n            shuffle=False,\n            callbacks=[recorder],\n        )\n        self.assertEqual(recorder.train_counter, 3)\n        self.assertEqual(recorder.val_counter, 4)\n\n    @parameterized.named_parameters(\n        [\n            (\"fit\", \"fit\", \"training\", \"train\"),\n            (\"evaluate\", \"evaluate\", \"evaluating\", \"test\"),\n            (\"predict\", \"predict\", \"predicting\", \"predict\"),\n        ]\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_stop_loop(self, method, method_gerund, on_end_name):\n        model = ExampleModel(units=3)\n        
model.compile(optimizer=\"sgd\", loss=\"mse\", metrics=[\"mse\"])\n\n        class Stopper(keras.callbacks.Callback):\n            def __init__(self, stop_count):\n                self.stop_count = stop_count\n                self.counter = 0\n                setattr(self, f\"on_{on_end_name}_batch_end\", self.batch_end)\n\n            def batch_end(self, *args, **kwargs):\n                self.counter += 1\n                if self.counter == self.stop_count:\n                    setattr(self.model, f\"stop_{method_gerund}\", True)\n\n        def infinite_gen():\n            while True:\n                x = np.ones((2, 2))\n                y = np.ones((2, 3))\n                yield (x,) if method == \"predict\" else (x, y)\n\n        stop_count = 5\n        stopper = Stopper(stop_count)\n\n        getattr(model, method)(\n            infinite_gen(),\n            callbacks=[stopper],\n        )\n        self.assertEqual(stopper.counter, stop_count)\n\n    @pytest.mark.requires_trainable_backend\n    def test_constraints_are_applied(self):\n        model = models.Sequential(\n            [layers.Dense(2, kernel_constraint=\"non_neg\")]\n        )\n        x = np.ones((2, 3))\n        y = np.ones((2, 2))\n        model.compile(optimizer=\"rmsprop\", loss=\"mse\")\n        model.fit(x, y)\n        self.assertGreaterEqual(\n            np.min(backend.convert_to_numpy(model.layers[0].kernel)), 0.0\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_rng_updated_during_predict(self):\n        class TestTimeDropout(layers.Layer):\n            def __init__(self):\n                super().__init__()\n                self.random_generator = keras.random.SeedGenerator()\n\n            def call(self, x):\n                return keras.random.dropout(\n                    x, rate=0.5, seed=self.random_generator\n                )\n\n        inputs = layers.Input((20,))\n        outputs = TestTimeDropout()(inputs)\n        model = keras.Model(inputs, outputs)\n  
      model.compile(optimizer=\"rmsprop\", loss=\"mse\")\n\n        x = np.ones((32, 20))\n        out_1 = model.predict(x)\n        out_2 = model.predict(x)\n        self.assertGreater(np.mean(np.abs(out_1 - out_2)), 0.01)\n\n    @pytest.mark.requires_trainable_backend\n    def test_callbacks_can_update_state_at_batch_boundary(self):\n        class CounterModel(keras.Model):\n            def __init__(self):\n                super().__init__()\n                self.train_counter = self.add_weight(\n                    shape=(),\n                    initializer=\"zeros\",\n                )\n                self.test_counter = self.add_weight(\n                    shape=(),\n                    initializer=\"zeros\",\n                )\n                self.predict_counter = self.add_weight(\n                    shape=(),\n                    initializer=\"zeros\",\n                )\n                self.dense = layers.Dense(3)\n\n            def call(self, x):\n                return self.dense(x)\n\n        class CounterCallback(keras.callbacks.Callback):\n            def __init__(self):\n                self.eager_call_counter_train = 0\n                self.eager_call_counter_test = 0\n                self.eager_call_counter_predict = 0\n\n            def on_train_batch_end(self, *args, **kwargs):\n                self.model.train_counter.assign_add(1)\n                self.eager_call_counter_train += 1\n\n            def on_test_batch_end(self, *args, **kwargs):\n                self.model.test_counter.assign_add(1)\n                self.eager_call_counter_test += 1\n\n            def on_predict_batch_end(self, *args, **kwargs):\n                self.model.predict_counter.assign_add(1)\n                self.eager_call_counter_predict += 1\n\n        model = CounterModel()\n        model.compile(\n            optimizer=\"sgd\", loss=\"mse\", metrics=[\"mse\"], run_eagerly=True\n        )\n        cbk = CounterCallback()\n        model.fit(\n            
np.ones((4, 3)),\n            np.ones((4, 3)),\n            callbacks=[cbk],\n            epochs=3,\n            batch_size=1,\n            verbose=0,\n            validation_data=(np.ones((2, 3)), np.ones((2, 3))),\n        )\n        self.assertAlmostEqual(cbk.eager_call_counter_train, 12)\n        self.assertAlmostEqual(model.train_counter.numpy(), 12)\n        self.assertAlmostEqual(cbk.eager_call_counter_test, 6)\n        self.assertAlmostEqual(model.test_counter.numpy(), 6)\n        model.predict(\n            np.ones((4, 3)),\n            callbacks=[cbk],\n            batch_size=1,\n        )\n        self.assertAlmostEqual(cbk.eager_call_counter_predict, 4)\n        self.assertAlmostEqual(model.predict_counter.numpy(), 4)\n\n    @pytest.mark.requires_trainable_backend\n    def test_metric_update_in_compute_loss(self):\n        test_self = self\n\n        class MyModel(keras.Model):\n            def __init__(self):\n                super().__init__()\n                self.custom_metric = keras.metrics.Mean(name=\"custom\")\n                self.dense = keras.layers.Dense(2)\n\n            def call(self, x):\n                return self.dense(x)\n\n            def compute_loss(\n                self,\n                x=None,\n                y=None,\n                y_pred=None,\n                sample_weight=None,\n                training=True,\n            ):\n                if not in_symbolic_scope():\n                    test_self.assertTrue(training)\n                loss = super().compute_loss(\n                    x, y, y_pred, sample_weight, training\n                )\n                self.custom_metric.update_state(loss * 4)\n                return loss\n\n        model = MyModel()\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        x = np.ones((32, 4))\n        y = np.ones((32, 2)) * 2\n        history = model.fit(x, y)\n        self.assertAlmostEqual(\n            history.history[\"custom\"][0], history.history[\"loss\"][0] * 4\n    
    )\n\n    @pytest.mark.requires_trainable_backend\n    def test_fwd_pass_loss_presence_in_compute_loss(self):\n        test_self = self\n\n        class MyModel(keras.Model):\n            def __init__(self):\n                super().__init__()\n                self.custom_metric = keras.metrics.Mean(name=\"custom\")\n                self.dense = keras.layers.Dense(2, activity_regularizer=\"l2\")\n\n            def call(self, x):\n                return self.dense(x)\n\n            def compute_loss(\n                self,\n                x=None,\n                y=None,\n                y_pred=None,\n                sample_weight=None,\n                training=True,\n            ):\n                if not in_symbolic_scope():\n                    test_self.assertTrue(training)\n                loss = super().compute_loss(\n                    x, y, y_pred, sample_weight, training\n                )\n                self.custom_metric.update_state(sum(self.losses))\n                return loss\n\n        model = MyModel()\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        x = np.ones((32, 4))\n        y = np.ones((32, 2)) * 2\n        history = model.fit(x, y)\n        self.assertGreater(history.history[\"custom\"][0], 0.0)\n\n    @pytest.mark.requires_trainable_backend\n    def test_evaluate_with_custom_compute_loss(self):\n        test_self = self\n\n        class MyModel(keras.Model):\n            def __init__(self):\n                super().__init__()\n                self.custom_metric = keras.metrics.Mean(name=\"custom\")\n                self.dense = keras.layers.Dense(2, activity_regularizer=\"l2\")\n\n            def call(self, x):\n                return self.dense(x)\n\n            def compute_loss(\n                self,\n                x=None,\n                y=None,\n                y_pred=None,\n                sample_weight=None,\n                training=True,\n            ):\n                if not in_symbolic_scope():\n         
           test_self.assertFalse(training)\n                loss = super().compute_loss(\n                    x, y, y_pred, sample_weight, training\n                )\n                self.custom_metric.update_state(loss * 4)\n                return loss\n\n        model = MyModel()\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        x = np.ones((32, 4))\n        y = np.ones((32, 2)) * 2\n        logs = model.evaluate(x, y, return_dict=True)\n        self.assertAlmostEqual(logs[\"custom\"], logs[\"loss\"] * 4)\n\n    @pytest.mark.requires_trainable_backend\n    def test_compute_loss_no_training_backwards_compatibility(self):\n        class MyModel(keras.Model):\n            def __init__(self):\n                super().__init__()\n                self.custom_metric = keras.metrics.Mean(name=\"custom\")\n                self.dense = keras.layers.Dense(2, activity_regularizer=\"l2\")\n\n            def call(self, x):\n                return self.dense(x)\n\n            def compute_loss(\n                self,\n                x=None,\n                y=None,\n                y_pred=None,\n                sample_weight=None,\n            ):\n                loss = super().compute_loss(x, y, y_pred, sample_weight)\n                self.custom_metric.update_state(loss * 4)\n                return loss\n\n        model = MyModel()\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        x = np.ones((32, 4))\n        y = np.ones((32, 2)) * 2\n        logs = model.evaluate(x, y, return_dict=True)\n        self.assertAlmostEqual(logs[\"custom\"], logs[\"loss\"] * 4)\n        history = model.fit(x, y)\n        self.assertAlmostEqual(\n            history.history[\"custom\"][0], history.history[\"loss\"][0] * 4\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_loss_weights(self):\n        epochs = 3\n        batch_size = 20\n        dataset_size = batch_size * 2\n\n        # Single output case.\n        model = ExampleModel(units=3)\n 
       model.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n            loss_weights=0.2,\n        )\n        x = np.ones((dataset_size, 4))\n        y = np.zeros((dataset_size, 3))\n        history = model.fit(\n            x,\n            y,\n            batch_size=batch_size,\n            epochs=epochs,\n        )\n        history = history.history\n        self.assertIn(\"loss\", history)\n        self.assertAllClose(\n            history[\"loss\"],\n            [3.182979, 3.115617, 3.049681],\n            atol=1e-3,\n            tpu_atol=1e-2,\n        )\n\n        # Dict output case.\n        model = StructModel(units=3)\n        model.compile(\n            optimizer=optimizers.SGD(),\n            loss={\n                \"y_one\": losses.MeanSquaredError(),\n                \"y_two\": losses.MeanSquaredError(),\n            },\n            metrics={\n                \"y_one\": metrics.MeanSquaredError(),\n                \"y_two\": metrics.MeanSquaredError(),\n            },\n            loss_weights={\"y_one\": 0.1, \"y_two\": 0.2},\n        )\n        x1 = np.ones((dataset_size, 4))\n        x2 = np.ones((dataset_size, 4))\n        y1 = np.zeros((dataset_size, 3))\n        y2 = np.zeros((dataset_size, 3))\n        history = model.fit(\n            {\"x_one\": x1, \"x_two\": x2},\n            {\"y_one\": y1, \"y_two\": y2},\n            batch_size=batch_size,\n            epochs=epochs,\n        )\n        history = history.history\n        self.assertIn(\"loss\", history)\n        self.assertAllClose(\n            history[\"loss\"],\n            [4.778718, 4.694403, 4.611693],\n            atol=1e-3,\n            tpu_atol=1e-2,\n        )\n\n        # List output case.\n        model = ListOutputModel(units=3)\n        model.compile(\n            optimizer=optimizers.SGD(),\n            loss=[losses.MeanSquaredError(), losses.MeanSquaredError()],\n           
 metrics=[metrics.MeanSquaredError(), metrics.MeanSquaredError()],\n            loss_weights=[0.1, 0.2],\n        )\n        x = np.ones((dataset_size, 4))\n        y1 = np.zeros((dataset_size, 3))\n        y2 = np.zeros((dataset_size, 3))\n        history = model.fit(\n            x,\n            [y1, y2],\n            batch_size=batch_size,\n            epochs=epochs,\n        )\n        history = history.history\n        self.assertIn(\"loss\", history)\n        self.assertAllClose(\n            history[\"loss\"],\n            [4.778718, 4.694403, 4.611693],\n            atol=1e-3,\n            tpu_atol=1e-2,\n        )\n\n    @pytest.mark.requires_trainable_backend\n    def test_partial_loss_partial_label(self):\n        inputs = keras.Input((2,))\n        x = keras.layers.Dense(3, kernel_initializer=\"ones\")(inputs)\n        partial_model = keras.Model(inputs, [x, x, x])\n        partial_model.compile(loss=[\"mse\", None, None])\n        full_model = keras.Model(inputs, [x, x, x])\n        full_model.compile(loss=[\"mse\", \"mse\", \"mse\"])\n\n        eval_x = np.ones((32, 2))\n        eval_y = np.ones((32, 3))\n\n        partial_logs = partial_model.evaluate(eval_x, eval_y, return_dict=True)\n        logs = full_model.evaluate(eval_x, [eval_y] * 3, return_dict=True)\n\n        self.assertAlmostEqual(partial_logs[\"loss\"] * 3, logs[\"loss\"])\n\n    def test_symbolic_build(self):\n        class ExampleModelWithTrainingArgs(Trainer, layers.Layer):\n            def __init__(self, units):\n                layers.Layer.__init__(self)\n                Trainer.__init__(self)\n                self.dense = layers.Dense(units)\n                self.bn = layers.BatchNormalization(axis=-1)\n\n            def build(self, input_shape):\n                self.dense.build(input_shape)\n                input_shape = self.dense.compute_output_shape(input_shape)\n                self.bn.build(input_shape)\n\n            def call(self, x, training=None):\n                
outputs = self.bn(self.dense(x), training=training)\n                return [outputs, outputs]\n\n        model = ExampleModelWithTrainingArgs(units=3)\n        model.compile(\n            optimizer=optimizers.SGD(),\n            loss=[losses.MeanSquaredError(), losses.MeanSquaredError()],\n            metrics=[metrics.MeanSquaredError(), metrics.MeanSquaredError()],\n        )\n        x = np.ones((4, 4))\n        y = np.zeros((4, 3))\n        model(x)  # Eager call to build model weights\n        ref_weights = model.get_weights()\n\n        # Before `_symbolic_build`\n        self.assertTrue(model.built)\n        self.assertFalse(model._compile_metrics.built)\n        self.assertFalse(model._compile_loss.built)\n        self.assertLen(model._compile_loss.metrics, 0)\n        self.assertLen(model.metrics, 2)\n\n        model._symbolic_build(data_batch=(x, (y, y)))\n        weights = model.get_weights()\n\n        # Ensure weights are intact\n        self.assertEqual(len(weights), len(ref_weights))\n        for w, ref_w in zip(weights, ref_weights):\n            self.assertAllClose(w, ref_w)\n\n        # Ensure `built`\n        self.assertTrue(model.built)\n        self.assertTrue(model._compile_metrics.built)\n        self.assertTrue(model._compile_loss.built)\n\n        # Ensure the len of metrics (original metrics + loss trackers)\n        self.assertLen(model._compile_metrics.metrics, 2)\n        self.assertLen(model._compile_loss.metrics, 2)\n        self.assertLen(model.metrics, 4)\n\n        # Ensure no values in metrics\n        for v in model._compile_metrics.variables:\n            self.assertAllClose(v, 0.0)\n        for v in model._compile_loss.variables:\n            self.assertAllClose(v, 0.0)\n\n    @pytest.mark.skipif(\n        backend.backend() != \"tensorflow\",\n        reason=\"This test is only applicable to TensorFlow.\",\n    )\n    @pytest.mark.requires_trainable_backend\n    def test_jit_compile_with_tf_determinism(self):\n        from 
tensorflow.python.framework.config import disable_op_determinism\n        from tensorflow.python.framework.config import enable_op_determinism\n\n        enable_op_determinism()\n\n        model = ExampleModel(units=3)\n        model.compile(\n            optimizer=optimizers.SGD(),\n            loss=losses.MeanSquaredError(),\n            metrics=[metrics.MeanSquaredError()],\n        )\n\n        self.assertFalse(model.jit_compile)\n        disable_op_determinism()\n\n    @pytest.mark.requires_trainable_backend\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"`steps_per_execution` not implemented for torch yet\",\n    )\n    def test_retracing(self):\n        x = np.ones((100, 4))\n        y = np.ones((100, 1))\n\n        input = keras.Input(shape=[4])\n        output = keras.layers.Dense(1, activation=\"relu\")(input)\n\n        tracing_count = [0]\n\n        class TracingCounterModel(keras.Model):\n            def train_step(self, *args):\n                tracing_count[0] = tracing_count[0] + 1\n                return super().train_step(*args)\n\n        model = TracingCounterModel(inputs=input, outputs=output)\n        model.compile(\n            loss=\"mse\",\n            optimizer=\"adam\",\n            steps_per_execution=20,\n        )\n\n        epochs = 1\n        model.fit(\n            x=x,\n            y=y,\n            batch_size=1,\n            epochs=epochs,\n            verbose=0,\n        )\n        self.assertLessEqual(tracing_count[0], 2)\n\n    @pytest.mark.requires_trainable_backend\n    @pytest.mark.skipif(\n        backend.backend() == \"torch\",\n        reason=\"`steps_per_execution` not implemented for torch yet\",\n    )\n    @pytest.mark.skipif(\n        backend.backend() == \"tensorflow\",\n        reason=\"`predict_function` with `steps_per_execution` is not \"\n        \"optimized for tensorflow yet\",\n    )\n    def test_retracing_predict(self):\n        x = np.ones((100, 4))\n\n        input = 
keras.Input(shape=[4])\n        output = keras.layers.Dense(1, activation=\"relu\")(input)\n\n        tracing_count = [0]\n\n        class TracingCounterModel(keras.Model):\n            def predict_step(self, *args):\n                tracing_count[0] = tracing_count[0] + 1\n                return super().predict_step(*args)\n\n        model = TracingCounterModel(inputs=input, outputs=output)\n        model.compile(\n            loss=\"mse\",\n            optimizer=\"adam\",\n            steps_per_execution=20,\n        )\n\n        model.predict(\n            x=x,\n            batch_size=1,\n            verbose=0,\n        )\n        self.assertLessEqual(tracing_count[0], 2)\n\n\nclass JAXTrainerCorrectnessTest(test_case.TestCase, parameterized.TestCase):\n    @parameterized.named_parameters(\n        (\"single_device\", False),\n        (\"distributed\", True),\n    )\n    @pytest.mark.skipif(backend.backend() != \"jax\", reason=\"JAX only\")\n    def test_jit_fit_with_out_shardings_logic(self, distributed):\n        x = np.random.rand(64, 8).astype(\"float32\")\n        y = np.random.rand(64, 1).astype(\"float32\")\n\n        distribution = None\n        if distributed:\n            if len(jax.local_devices()) < 2:\n                self.skipTest(\n                    \"Distributed test requires at least 2 JAX devices.\"\n                )\n\n            devices = jax.local_devices()\n            mesh = DeviceMesh(\n                shape=(len(devices),), axis_names=(\"batch\",), devices=devices\n            )\n            distribution = DataParallel(mesh)\n\n        scope = distribution.scope() if distribution else mock.MagicMock()\n\n        with scope:\n            model = models.Sequential(\n                [\n                    layers.Dense(4, activation=\"relu\", input_shape=(8,)),\n                    layers.Dense(1),\n                ]\n            )\n            model.compile(optimizer=\"adam\", loss=\"mse\", jit_compile=True)\n\n            if 
distribution:\n                expected_shardings = [\n                    v.value.sharding for v in model.trainable_variables\n                ]\n                self.assertNotEqual(len(set(expected_shardings)), 1)\n\n            model.fit(x, y, epochs=2, batch_size=32, verbose=0)\n\n            if distribution:\n                actual_shardings = [\n                    v.value.sharding for v in model.trainable_variables\n                ]\n                self.assertListEqual(actual_shardings, expected_shardings)\n"
  },
  {
    "path": "keras/src/tree/__init__.py",
    "content": "from keras.src.tree.tree_api import assert_same_paths\nfrom keras.src.tree.tree_api import assert_same_structure\nfrom keras.src.tree.tree_api import flatten\nfrom keras.src.tree.tree_api import flatten_with_path\nfrom keras.src.tree.tree_api import is_nested\nfrom keras.src.tree.tree_api import lists_to_tuples\nfrom keras.src.tree.tree_api import map_shape_structure\nfrom keras.src.tree.tree_api import map_structure\nfrom keras.src.tree.tree_api import map_structure_up_to\nfrom keras.src.tree.tree_api import pack_sequence_as\nfrom keras.src.tree.tree_api import register_tree_node_class\nfrom keras.src.tree.tree_api import traverse\n"
  },
  {
    "path": "keras/src/tree/dmtree_impl.py",
    "content": "import collections\nimport collections.abc\nimport itertools\n\nfrom keras.src.backend.config import backend\nfrom keras.src.utils.module_utils import dmtree\n\n# NOTE: There are two known discrepancies between this `dmtree` implementation\n# of the tree API and the `optree` implementation:\n#\n# 1. `map_structure` with *multiple* structures and `map_structure_up_to` do not\n#    use the object registration (they use the raw `dmtree.map_structure` and\n#    `dmtree.map_structure_up_to`). This only has consequences with two types of\n#    structures:\n#    - `TrackedSet` will not explored (considered as a leaf).\n#    - `OrderedDict` will be traversed in the order of sorted keys, not the\n#      order of the items. This is typically inconsequential because functions\n#      used with `map_structure` and `map_structure_up_to` are typically not\n#      order dependent and are, in fact, stateless.\n#\n# 2. The handling of non-sortable keys in dictionaries in inconsistent. `optree`\n#    uses the iteration order while `dmtree` raises an error. This is not an\n#    issue as keys are always strings. 
But this is the reason why we document\n#    non-sortable keys as unsupported (meaning behavior is undefined).\n\nREGISTERED_CLASSES = {}\n\nClassRegistration = collections.namedtuple(\n    \"ClassRegistration\", [\"flatten\", \"unflatten\"]\n)\n\n\nclass TypeErrorRemapping:\n    def __enter__(self):\n        pass\n\n    def __exit__(self, exc_type, exc_value, traceback):\n        if exc_type is TypeError:\n            raise ValueError(exc_value).with_traceback(traceback)\n        return False\n\n\ndef register_tree_node(\n    cls,\n    flatten_func=None,\n    unflatten_func=None,\n):\n    if flatten_func is None:\n        flatten_func = lambda x: x.tree_flatten()\n    if unflatten_func is None:\n        unflatten_func = cls.tree_unflatten\n    REGISTERED_CLASSES[cls] = ClassRegistration(flatten_func, unflatten_func)\n\n\ndef register_tree_node_class(cls):\n    register_tree_node(cls)\n    return cls\n\n\nregister_tree_node(\n    collections.OrderedDict,\n    lambda d: (d.values(), list(d.keys()), d.keys()),\n    lambda metadata, children: collections.OrderedDict(zip(metadata, children)),\n)\n\nif backend() == \"tensorflow\":\n    from tensorflow.python.trackable.data_structures import ListWrapper\n    from tensorflow.python.trackable.data_structures import _DictWrapper\n\n    register_tree_node(\n        ListWrapper,\n        lambda x: (x, None),\n        lambda metadata, children: ListWrapper(list(children)),\n    )\n\n    def sorted_keys_and_values(d):\n        keys = sorted(list(d.keys()))\n        values = [d[k] for k in keys]\n        return values, keys, keys\n\n    register_tree_node(\n        _DictWrapper,\n        sorted_keys_and_values,\n        lambda metadata, children: _DictWrapper(\n            {key: child for key, child in zip(metadata, children)}\n        ),\n    )\n\n\ndef is_nested(structure):\n    return type(structure) in REGISTERED_CLASSES or dmtree.is_nested(structure)\n\n\ndef traverse(func, structure, top_down=True):\n    if not 
callable(func):\n        raise TypeError(\n            f\"`func` must be callable, got {func} of type {type(func)}\"\n        )\n\n    def remap_map_to_none(value, new_value):\n        if isinstance(value, type) and value.__name__ == \"MAP_TO_NONE\":\n            return new_value\n        return value\n\n    def traverse_top_down(s):\n        ret = func(s)\n        if ret is not None:\n            return remap_map_to_none(ret, dmtree.MAP_TO_NONE)\n        registration = REGISTERED_CLASSES.get(type(s), None)\n        if registration is None:\n            return None\n        flat_meta_s = registration.flatten(s)\n        flat_s = [\n            dmtree.traverse(traverse_top_down, x, top_down=True)\n            for x in list(flat_meta_s[0])\n        ]\n        return registration.unflatten(flat_meta_s[1], flat_s)\n\n    def traverse_bottom_up(s):\n        registration = REGISTERED_CLASSES.get(type(s), None)\n        if registration is not None:\n            flat_meta_s = registration.flatten(s)\n            ret = [traverse_bottom_up(x) for x in list(flat_meta_s[0])]\n            ret = registration.unflatten(flat_meta_s[1], ret)\n        elif not dmtree.is_nested(s):\n            ret = s\n        elif isinstance(s, collections.abc.Mapping):\n            ret = [traverse_bottom_up(s[key]) for key in sorted(s)]\n            ret = dmtree._sequence_like(s, ret)\n        else:\n            ret = [traverse_bottom_up(x) for x in s]\n            ret = dmtree._sequence_like(s, ret)\n        func_ret = func(ret)\n        return ret if func_ret is None else remap_map_to_none(func_ret, None)\n\n    if top_down:\n        return dmtree.traverse(traverse_top_down, structure, top_down=True)\n    else:\n        return traverse_bottom_up(structure)\n\n\ndef flatten(structure):\n    if not is_nested(structure):\n        return [structure]\n\n    flattened = []\n\n    def flatten_func(s):\n        registration = REGISTERED_CLASSES.get(type(s), None)\n        if registration is not None:\n  
          flat_s = list(registration.flatten(s)[0])\n            return dmtree.traverse(flatten_func, flat_s, top_down=True)\n        if not is_nested(s):\n            flattened.append(s)\n            return dmtree.MAP_TO_NONE if s is None else s\n        return None\n\n    dmtree.traverse(flatten_func, structure, top_down=True)\n    return flattened\n\n\ndef _recursive_flatten_with_path(path, structure, flattened):\n    registration = REGISTERED_CLASSES.get(type(structure), None)\n    if registration is not None:\n        flat_meta_paths = registration.flatten(structure)\n        flat = flat_meta_paths[0]\n        paths = (\n            flat_meta_paths[2]\n            if len(flat_meta_paths) >= 3\n            else itertools.count()\n        )\n        for key, value in zip(paths, flat):\n            _recursive_flatten_with_path(path + (key,), value, flattened)\n    elif not dmtree.is_nested(structure):\n        flattened.append((path, structure))\n    elif isinstance(structure, collections.abc.Mapping):\n        for key in sorted(structure):\n            _recursive_flatten_with_path(\n                path + (key,), structure[key], flattened\n            )\n    else:\n        for key, value in enumerate(structure):\n            _recursive_flatten_with_path(path + (key,), value, flattened)\n\n\ndef flatten_with_path(structure):\n    if not is_nested(structure):\n        return [((), structure)]\n\n    # Fully reimplemented in Python to handle registered classes, OrderedDict\n    # and namedtuples the same way as optree.\n    flattened = []\n    _recursive_flatten_with_path((), structure, flattened)\n    return flattened\n\n\ndef map_structure(func, *structures, none_is_leaf=True):\n    if not callable(func):\n        raise TypeError(\n            f\"`func` must be callable, got {func} of type {type(func)}\"\n        )\n\n    map_func = func\n    if not none_is_leaf:\n\n        def func_skipping_none(*args):\n            # Check if the reference entry (first one) is 
None\n            if args[0] is None:\n                if not all(s is None for s in args):\n                    raise ValueError(\n                        \"Structure mismatch: some arguments are None, others \"\n                        f\"are not. Received arguments: {args}.\"\n                    )\n                return None\n            return func(*args)\n\n        map_func = func_skipping_none\n\n    def func_traverse_wrapper(s):\n        if is_nested(s):\n            return None\n        ret = map_func(s)\n        if ret is None:\n            return dmtree.MAP_TO_NONE\n        return ret\n\n    if len(structures) == 1:\n        return traverse(func_traverse_wrapper, structures[0])\n\n    with TypeErrorRemapping():\n        return dmtree.map_structure(map_func, *structures)\n\n\ndef map_structure_up_to(shallow_structure, func, *structures):\n    if not callable(func):\n        raise TypeError(\n            f\"`func` must be callable, got {func} of type {type(func)}\"\n        )\n\n    with TypeErrorRemapping():\n        return dmtree.map_structure_up_to(shallow_structure, func, *structures)\n\n\ndef assert_same_structure(a, b):\n    # Fully reimplemented in Python to handle registered classes.\n\n    # Don't handle OrderedDict as a registered class, use the normal dict path\n    # so that OrderedDict is equivalent to dict per optree behavior.\n    a_registration = REGISTERED_CLASSES.get(type(a), None)\n    if type(a) is collections.OrderedDict:\n        a_registration = None\n\n    b_registration = REGISTERED_CLASSES.get(type(b), None)\n    if type(b) is collections.OrderedDict:\n        b_registration = None\n\n    if a_registration != b_registration:\n        raise ValueError(\n            f\"Custom node type mismatch; \"\n            f\"expected type: {type(a)}, got type: {type(b)} \"\n            f\"while comparing {a} and {b}.\"\n        )\n    if a_registration is not None:\n        a_flat_meta = a_registration.flatten(a)\n        b_flat_meta = 
b_registration.flatten(b)\n        a_flat = list(a_flat_meta[0])\n        b_flat = list(b_flat_meta[0])\n        if not a_flat_meta[1] == b_flat_meta[1]:\n            raise ValueError(\n                f\"Mismatch custom node data; \"\n                f\"expected: {a_flat_meta[1]}, got: {b_flat_meta[1]} \"\n                f\"while comparing {a} and {b}.\"\n            )\n        if len(a_flat) != len(b_flat):\n            raise ValueError(\n                f\"Arity mismatch; expected: {len(a)}, got: {len(b)} \"\n                f\"while comparing {a} and {b}.\"\n            )\n        for sub_a, sub_b in zip(a_flat, b_flat):\n            assert_same_structure(sub_a, sub_b)\n    elif not dmtree.is_nested(a):\n        if dmtree.is_nested(b):\n            raise ValueError(\n                f\"Structures don't have the same nested structure: {a}, {b}.\"\n            )\n    elif isinstance(\n        a, (dict, collections.OrderedDict, collections.defaultdict)\n    ):\n        if not isinstance(\n            b, (dict, collections.OrderedDict, collections.defaultdict)\n        ):\n            raise ValueError(\n                f\"Expected an instance of dict, collections.OrderedDict, or \"\n                f\"collections.defaultdict, got {type(b)} \"\n                f\"while comparing {a} and {b}.\"\n            )\n        a_keys = sorted(a)\n        b_keys = sorted(b)\n        if not a_keys == b_keys:\n            raise ValueError(\n                f\"Dictionary key mismatch; \"\n                f\"expected key(s): {a_keys}, got key(s): {b_keys} \"\n                f\"while comparing {a} and {b}.\"\n            )\n        for key in a_keys:\n            assert_same_structure(a[key], b[key])\n    elif isinstance(a, collections.abc.Mapping):\n        raise ValueError(\n            f\"Encountered unregistered collections.abc.Mapping type: {type(a)} \"\n            f\"while comparing {a} and {b}.\"\n        )\n    else:\n        if type(a) is not type(b):\n            raise 
ValueError(\n                f\"Expected an instance of {type(a)}, got {type(b)} \"\n                f\"while comparing {a} and {b}.\"\n            )\n        if not len(a) == len(b):\n            raise ValueError(\n                f\"Arity mismatch; expected: {len(a)}, got: {len(b)} \"\n                f\"while comparing {a} and {b}.\"\n            )\n        for sub_a, sub_b in zip(a, b):\n            assert_same_structure(sub_a, sub_b)\n\n\ndef assert_same_paths(a, b):\n    a_paths = set([path for path, _ in flatten_with_path(a)])\n    b_paths = set([path for path, _ in flatten_with_path(b)])\n\n    if a_paths != b_paths:\n        msg = \"`a` and `b` don't have the same paths.\"\n        a_diff = a_paths.difference(b_paths)\n        if a_diff:\n            msg += f\"\\nPaths in `a` missing in `b`:\\n{a_diff}\"\n        b_diff = b_paths.difference(a_paths)\n        if b_diff:\n            msg += f\"\\nPaths in `b` missing in `a`:\\n{b_diff}\"\n        raise ValueError(msg)\n\n\ndef pack_sequence_as(structure, flat_sequence):\n    # This is not just an optimization for the case when structure is a leaf.\n    # This is required to avoid Torch Dynamo failures.\n    if not is_nested(structure):\n        if len(flat_sequence) == 1:\n            return flat_sequence[0]\n        else:\n            raise ValueError(\n                \"Incorrect number of leaves provided by `flat_sequence` for \"\n                f\"`structure`; expected: 1, got {len(flat_sequence)}.\"\n            )\n\n    flat_sequence_it = enumerate(flat_sequence)\n\n    def unflatten_func(s):\n        registration = REGISTERED_CLASSES.get(type(s), None)\n        if registration is not None:\n            flat_meta_s = registration.flatten(s)\n            flat_s = dmtree.traverse(\n                unflatten_func, list(flat_meta_s[0]), top_down=True\n            )\n            return registration.unflatten(flat_meta_s[1], flat_s)\n        elif not dmtree.is_nested(s):\n            try:\n                
_, value = next(flat_sequence_it)\n                return dmtree.MAP_TO_NONE if value is None else value\n            except StopIteration:\n                raise ValueError(\n                    \"Too few leaves provided by `flat_sequence` for \"\n                    f\"`structure`. Got {len(flat_sequence)}.\"\n                )\n        return None\n\n    ret = dmtree.traverse(unflatten_func, structure, top_down=True)\n    try:\n        index, _ = next(flat_sequence_it)\n        raise ValueError(\n            \"Too many leaves provided by `flat_sequence` for `structure`; \"\n            f\"expected: {index}, got {len(flat_sequence)}.\"\n        )\n    except StopIteration:\n        return ret\n\n\ndef lists_to_tuples(structure):\n    def list_to_tuple(instance):\n        return tuple(instance) if isinstance(instance, list) else None\n\n    return traverse(list_to_tuple, structure, top_down=False)\n\n\ndef map_shape_structure(func, structure):\n    if not callable(func):\n        raise TypeError(\n            f\"`func` must be callable, got {func} of type {type(func)}\"\n        )\n\n    def map_shape_func(x):\n        if isinstance(x, (list, tuple)) and all(\n            isinstance(e, (int, type(None))) for e in x\n        ):\n            ret = func(x)\n        elif is_nested(x):\n            return None\n        else:\n            ret = func(x)\n        return ret if ret is not None else dmtree.MAP_TO_NONE\n\n    return traverse(map_shape_func, structure, top_down=True)\n"
  },
  {
    "path": "keras/src/tree/optree_impl.py",
    "content": "import optree\nimport optree.utils\n\nfrom keras.src.backend.config import backend\n\n\ndef register_tree_node_class(cls):\n    return optree.register_pytree_node_class(cls, namespace=\"keras\")\n\n\n# Register backend-specific node classes\nif backend() == \"tensorflow\":\n    from tensorflow.python.trackable.data_structures import ListWrapper\n    from tensorflow.python.trackable.data_structures import _DictWrapper\n\n    try:\n        optree.register_pytree_node(\n            ListWrapper,\n            lambda x: (x, None),\n            lambda metadata, children: ListWrapper(list(children)),\n            namespace=\"keras\",\n        )\n\n        def sorted_keys_and_values(d):\n            keys = sorted(list(d.keys()))\n            values = [d[k] for k in keys]\n            return values, keys, keys\n\n        optree.register_pytree_node(\n            _DictWrapper,\n            sorted_keys_and_values,\n            lambda metadata, children: _DictWrapper(\n                {key: child for key, child in zip(metadata, children)}\n            ),\n            namespace=\"keras\",\n        )\n    except ValueError:\n        pass  # We may have already registered if we are reimporting keras.\n\n\ndef is_nested(structure):\n    return not optree.tree_is_leaf(\n        structure, none_is_leaf=True, namespace=\"keras\"\n    )\n\n\ndef traverse(func, structure, top_down=True):\n    # From https://github.com/google/jax/pull/19695\n    def traverse_children():\n        children, treedef = optree.tree_flatten(\n            structure,\n            is_leaf=lambda x: x is not structure,\n            none_is_leaf=True,\n            namespace=\"keras\",\n        )\n        if treedef.num_nodes == 1 and treedef.num_leaves == 1:\n            return structure\n        else:\n            return optree.tree_unflatten(\n                treedef,\n                [traverse(func, c, top_down=top_down) for c in children],\n            )\n\n    if top_down:\n        ret = 
func(structure)\n        if ret is None:\n            return traverse_children()\n    else:\n        traversed_structure = traverse_children()\n        ret = func(traversed_structure)\n        if ret is None:\n            return traversed_structure\n    # Detect MAP_TO_NONE without tree_api import to avoid circular import.\n    if isinstance(ret, type) and ret.__name__ == \"MAP_TO_NONE\":\n        return None\n    return ret\n\n\ndef flatten(structure):\n    # optree.tree_flatten returns a pair (leaves, treespec) where the first\n    # element is a list of leaf values and the second element is a treespec\n    # representing the structure of the pytree.\n    leaves, _ = optree.tree_flatten(\n        structure, none_is_leaf=True, namespace=\"keras\"\n    )\n    return leaves\n\n\ndef flatten_with_path(structure):\n    paths, leaves, _ = optree.tree_flatten_with_path(\n        structure, none_is_leaf=True, namespace=\"keras\"\n    )\n    return list(zip(paths, leaves))\n\n\ndef map_structure(func, *structures, none_is_leaf=True):\n    if not structures:\n        raise ValueError(\"Must provide at least one structure\")\n\n    # Add check for same structures, otherwise optree just maps to shallowest.\n    def func_with_check(*args):\n        if not all(\n            optree.tree_is_leaf(s, none_is_leaf=none_is_leaf, namespace=\"keras\")\n            for s in args\n        ):\n            raise ValueError(\"Structures don't have the same nested structure.\")\n        return func(*args)\n\n    map_func = func_with_check if len(structures) > 1 else func\n\n    return optree.tree_map(\n        map_func, *structures, none_is_leaf=none_is_leaf, namespace=\"keras\"\n    )\n\n\ndef map_structure_up_to(shallow_structure, func, *structures):\n    if not structures:\n        raise ValueError(\"Must provide at least one structure\")\n\n    # Add check that `shallow_structure` really is the shallowest.\n    # Also only call `func` on `structures` and not `shallow_structure`.\n    
def func_with_check_without_shallow_structure(shallow, *args):\n        if not optree.tree_is_leaf(shallow):\n            raise ValueError(\"Structures don't have the same nested structure.\")\n        return func(*args)\n\n    return optree.tree_map(\n        func_with_check_without_shallow_structure,\n        shallow_structure,\n        *structures,\n        none_is_leaf=True,\n        namespace=\"keras\",\n    )\n\n\ndef assert_same_structure(a, b):\n    def check(a_leaf, b_leaf):\n        if not optree.tree_is_leaf(\n            a_leaf, none_is_leaf=True, namespace=\"keras\"\n        ) or not optree.tree_is_leaf(\n            b_leaf, none_is_leaf=True, namespace=\"keras\"\n        ):\n            raise ValueError(\"Structures don't have the same nested structure.\")\n        return None\n\n    optree.tree_map(check, a, b, none_is_leaf=True, namespace=\"keras\")\n\n\ndef assert_same_paths(a, b):\n    a_paths = set(optree.tree_paths(a, none_is_leaf=True, namespace=\"keras\"))\n    b_paths = set(optree.tree_paths(b, none_is_leaf=True, namespace=\"keras\"))\n\n    if a_paths != b_paths:\n        msg = \"`a` and `b` don't have the same paths.\"\n        a_diff = a_paths.difference(b_paths)\n        if a_diff:\n            msg += f\"\\nPaths in `a` missing in `b`:\\n{a_diff}\"\n        b_diff = b_paths.difference(a_paths)\n        if b_diff:\n            msg += f\"\\nPaths in `b` missing in `a`:\\n{b_diff}\"\n        raise ValueError(msg)\n\n\ndef pack_sequence_as(structure, flat_sequence):\n    _, treespec = optree.tree_flatten(\n        structure, none_is_leaf=True, namespace=\"keras\"\n    )\n    return optree.tree_unflatten(treespec, flat_sequence)\n\n\ndef lists_to_tuples(structure):\n    def list_to_tuple(instance):\n        return tuple(instance) if isinstance(instance, list) else None\n\n    return traverse(list_to_tuple, structure, top_down=False)\n\n\ndef map_shape_structure(func, structure):\n    def is_shape_tuple(x):\n        return isinstance(x, (list, 
tuple)) and all(\n            isinstance(e, (int, type(None))) for e in x\n        )\n\n    return optree.tree_map(\n        func,\n        structure,\n        is_leaf=is_shape_tuple,\n        none_is_leaf=True,\n        namespace=\"keras\",\n    )\n"
  },
  {
    "path": "keras/src/tree/torchtree_impl.py",
    "content": "from collections import defaultdict\n\nfrom torch.utils import _pytree as torch_tree\n\n\ndef register_tree_node_class(cls):\n    torch_tree.register_pytree_node(\n        cls,\n        flatten_fn=lambda x: x.torchtree_flatten(),\n        unflatten_fn=cls.torchtree_unflatten,\n        serialized_type_name=f\"{cls.__name__}\",\n        flatten_with_keys_fn=lambda x: x.torchtree_flatten_with_keys(),\n    )\n    return cls\n\n\ndef _tree_is_leaf(tree, is_leaf=None):\n    if is_leaf is not None and is_leaf(tree):\n        return True\n    return torch_tree._get_node_type(tree) not in torch_tree.SUPPORTED_NODES\n\n\ndef _dict_to_ordered_dict(structure):\n    # We need to sort dict and defaultdict to ensure a deterministic order that\n    # that is consistent with other tree implementations.\n    def func(x):\n        if type(x) is dict:\n            return {k: x[k] for k in sorted(x.keys())}\n        elif type(x) is defaultdict:\n            return defaultdict(\n                x.default_factory,\n                {k: x[k] for k in sorted(x.keys())},\n            )\n        return None\n\n    def traverse_children():\n        children, treedef = torch_tree.tree_flatten(\n            structure,\n            is_leaf=lambda x: x is not structure,\n        )\n        if treedef.num_nodes == 1 and treedef.num_leaves == 1:\n            return structure\n        else:\n            return torch_tree.tree_unflatten(\n                [_dict_to_ordered_dict(c) for c in children],\n                treedef,\n            )\n\n    ret = func(structure)\n    if ret is None:\n        return traverse_children()\n    if isinstance(ret, type) and ret.__name__ == \"MAP_TO_NONE\":\n        return None\n    return ret\n\n\ndef is_nested(structure):\n    return not _tree_is_leaf(structure)\n\n\ndef traverse(func, structure, top_down=True):\n    def traverse_children():\n        children, treedef = torch_tree.tree_flatten(\n            structure,\n            is_leaf=lambda x: x 
is not structure,\n        )\n        if treedef.num_nodes == 1 and treedef.num_leaves == 1:\n            return structure\n        else:\n            return torch_tree.tree_unflatten(\n                [traverse(func, c, top_down=top_down) for c in children],\n                treedef,\n            )\n\n    structure = _dict_to_ordered_dict(structure)\n    if top_down:\n        ret = func(structure)\n        if ret is None:\n            return traverse_children()\n    else:\n        traversed_structure = traverse_children()\n        ret = func(traversed_structure)\n        if ret is None:\n            return traversed_structure\n    # Detect MAP_TO_NONE without tree_api import to avoid circular import.\n    if isinstance(ret, type) and ret.__name__ == \"MAP_TO_NONE\":\n        return None\n    return ret\n\n\ndef flatten(structure):\n    # We need to first sort dicts to ensure a deterministic order that is\n    # consistent with other tree implementations.\n    structure = _dict_to_ordered_dict(structure)\n    leaves, _ = torch_tree.tree_flatten(structure)\n    return leaves\n\n\ndef flatten_with_path(structure):\n    # We need to first sort dicts to ensure a deterministic order that is\n    # consistent with other tree implementations.\n    structure = _dict_to_ordered_dict(structure)\n    leaves_with_path, _ = torch_tree.tree_flatten_with_path(structure)\n    results = []\n    fields = []\n    for key, leaf in leaves_with_path:\n        for k in key:\n            if isinstance(k, torch_tree.GetAttrKey) and k.name not in fields:\n                fields.append(k.name)\n    fields = sorted(fields)\n    field_to_idx = {f: i for i, f in enumerate(fields)}\n    for key, leaf in leaves_with_path:\n        # Convert to a tuple of keys.\n        path = []\n        for k in key:\n            if isinstance(k, torch_tree.SequenceKey):\n                path.append(k.idx)\n            elif isinstance(k, torch_tree.MappingKey):\n                path.append(k.key)\n            
elif isinstance(k, torch_tree.GetAttrKey):\n                path.append(field_to_idx[k.name])\n        results.append((tuple(path), leaf))\n    return results\n\n\ndef map_structure(func, *structures, none_is_leaf=True):\n    if not structures:\n        raise ValueError(\"Must provide at least one structure\")\n\n    map_func = func\n    if not none_is_leaf:\n\n        def func_skipping_none(*args):\n            # Check if the reference entry (first one) is None\n            if args[0] is None:\n                if not all(s is None for s in args):\n                    raise ValueError(\n                        \"Structure mismatch: some arguments are None, others \"\n                        f\"are not. Received arguments: {args}.\"\n                    )\n                return None\n            return func(*args)\n\n        map_func = func_skipping_none\n\n    return torch_tree.tree_map(map_func, *structures)\n\n\ndef map_structure_up_to(shallow_structure, func, *structures):\n    if not structures:\n        raise ValueError(\"Must provide at least one structure\")\n\n    # Add check that `shallow_structure` really is the shallowest.\n    # Also only call `func` on `structures` and not `shallow_structure`.\n    def func_with_check_without_shallow_structure(shallow, *args):\n        if not _tree_is_leaf(shallow):\n            raise ValueError(\"Structures don't have the same nested structure.\")\n        return func(*args)\n\n    return torch_tree.tree_map(\n        func_with_check_without_shallow_structure,\n        shallow_structure,\n        *structures,\n    )\n\n\ndef assert_same_structure(a, b):\n    def check(a_leaf, b_leaf):\n        if not _tree_is_leaf(a_leaf) or not _tree_is_leaf(b_leaf):\n            raise ValueError(\"Structures don't have the same nested structure.\")\n        return None\n\n    torch_tree.tree_map(check, a, b)\n\n\ndef assert_same_paths(a, b):\n    a_paths = set([path for path, _ in flatten_with_path(a)])\n    b_paths = set([path for 
path, _ in flatten_with_path(b)])\n\n    if a_paths != b_paths:\n        msg = \"`a` and `b` don't have the same paths.\"\n        a_diff = a_paths.difference(b_paths)\n        if a_diff:\n            msg += f\"\\nPaths in `a` missing in `b`:\\n{a_diff}\"\n        b_diff = b_paths.difference(a_paths)\n        if b_diff:\n            msg += f\"\\nPaths in `b` missing in `a`:\\n{b_diff}\"\n        raise ValueError(msg)\n\n\ndef pack_sequence_as(structure, flat_sequence):\n    # We need to first sort dicts to ensure a deterministic order that is\n    # consistent with other tree implementations.\n    structure = _dict_to_ordered_dict(structure)\n    _, treespec = torch_tree.tree_flatten(structure)\n    return torch_tree.tree_unflatten(flat_sequence, treespec)\n\n\ndef lists_to_tuples(structure):\n    def list_to_tuple(instance):\n        return tuple(instance) if isinstance(instance, list) else None\n\n    return traverse(list_to_tuple, structure, top_down=False)\n\n\ndef map_shape_structure(func, structure):\n    def is_shape_tuple(x):\n        return isinstance(x, (list, tuple)) and all(\n            isinstance(e, (int, type(None))) for e in x\n        )\n\n    # We need to first sort dicts to ensure a deterministic order that is\n    # consistent with other tree implementations.\n    structure = _dict_to_ordered_dict(structure)\n    return torch_tree.tree_map(func, structure, is_leaf=is_shape_tuple)\n"
  },
  {
    "path": "keras/src/tree/tree_api.py",
    "content": "import warnings\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.config import backend\nfrom keras.src.utils.module_utils import dmtree\nfrom keras.src.utils.module_utils import optree\n\nif backend() == \"torch\":\n    # torchtree_impl is especially used for Torch backend, as it works better\n    # with torch.compile.\n    from keras.src.tree import torchtree_impl as tree_impl\nelif optree.available:\n    from keras.src.tree import optree_impl as tree_impl\nelif dmtree.available:\n    from keras.src.tree import dmtree_impl as tree_impl\nelse:\n    raise ImportError(\n        \"To use Keras, you need to have `optree` installed. \"\n        \"Install it via `pip install optree`\"\n    )\n\n\ndef register_tree_node_class(cls):\n    return tree_impl.register_tree_node_class(cls)\n\n\n@keras_export(\"keras.tree.MAP_TO_NONE\")\nclass MAP_TO_NONE:\n    \"\"\"Special value for use with `traverse()`.\"\"\"\n\n    pass\n\n\n@keras_export(\"keras.tree.is_nested\")\ndef is_nested(structure):\n    \"\"\"Checks if a given structure is nested.\n\n    Examples:\n\n    >>> keras.tree.is_nested(42)\n    False\n    >>> keras.tree.is_nested({\"foo\": 42})\n    True\n\n    Args:\n        structure: A structure to check.\n\n    Returns:\n        `True` if a given structure is nested, i.e. is a sequence, a mapping,\n        or a namedtuple, and `False` otherwise.\n    \"\"\"\n    return tree_impl.is_nested(structure)\n\n\n@keras_export(\"keras.tree.traverse\")\ndef traverse(func, structure, top_down=True):\n    \"\"\"Traverses the given nested structure, applying the given function.\n\n    The traversal is depth-first. 
If `top_down` is True (default), parents\n    are returned before their children (giving the option to avoid traversing\n    into a sub-tree).\n\n    Examples:\n\n    >>> v = []\n    >>> keras.tree.traverse(v.append, [(1, 2), [3], {\"a\": 4}], top_down=True)\n    [(1, 2), [3], {'a': 4}]\n    >>> v\n    [[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4]\n\n    >>> v = []\n    >>> keras.tree.traverse(v.append, [(1, 2), [3], {\"a\": 4}], top_down=False)\n    [(1, 2), [3], {'a': 4}]\n    >>> v\n    [1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]]\n\n    Args:\n        func: The function to be applied to each sub-nest of the structure.\n\n        When traversing top-down:\n            If `func(subtree) is None` the traversal continues into the\n            sub-tree.\n            If `func(subtree) is not None` the traversal does not continue\n            into the sub-tree. The sub-tree will be replaced by `func(subtree)`\n            in the returned structure (to replace the sub-tree with `None`, use\n            the special value `MAP_TO_NONE`).\n\n        When traversing bottom-up:\n            If `func(subtree) is None` the traversed sub-tree is returned\n            unaltered.\n            If `func(subtree) is not None` the sub-tree will be replaced by\n            `func(subtree)` in the returned structure (to replace the sub-tree\n            with None, use the special value `MAP_TO_NONE`).\n\n        structure: The structure to traverse.\n        top_down: If True, parent structures will be visited before their\n            children.\n\n    Returns:\n        The structured output from the traversal.\n\n    Raises:\n        TypeError: If `func` is not callable.\n    \"\"\"\n    return tree_impl.traverse(func, structure, top_down=top_down)\n\n\n@keras_export(\"keras.tree.flatten\")\ndef flatten(structure):\n    \"\"\"Flattens a possibly nested structure into a list.\n\n    In the case of dict instances, the sequence consists of the values,\n    
sorted by key to ensure deterministic behavior. However, instances of\n    `collections.OrderedDict` are handled differently: their sequence order is\n    used instead of the sorted keys. The same convention is followed in\n    `pack_sequence_as`. This correctly unflattens dicts and `OrderedDict` after\n    they have been flattened, or vice-versa.\n\n    Dictionaries with non-sortable keys are not supported.\n\n    Examples:\n\n    >>> keras.tree.flatten([[1, 2, 3], [4, [5], [[6]]]])\n    [1, 2, 3, 4, 5, 6]\n    >>> keras.tree.flatten(None)\n    [None]\n    >>> keras.tree.flatten(1)\n    [1]\n    >>> keras.tree.flatten({100: 'world!', 6: 'Hello'})\n    ['Hello', 'world!']\n\n    Args:\n        structure: An arbitrarily nested structure.\n\n    Returns:\n        A list, the flattened version of the input `structure`.\n    \"\"\"\n    return tree_impl.flatten(structure)\n\n\n@keras_export(\"keras.tree.flatten_with_path\")\ndef flatten_with_path(structure):\n    \"\"\"Flattens a possibly nested structure into a list.\n\n    This is a variant of `flatten()` which produces a\n    list of pairs: `(path, item)`. 
A path is a tuple of indices and/or keys\n    which uniquely identifies the position of the corresponding item.\n\n    Dictionaries with non-sortable keys are not supported.\n\n    Examples:\n\n    >>> keras.tree.flatten_with_path([{\"foo\": 42}])\n    [((0, 'foo'), 42)]\n\n    Args:\n        structure: An arbitrarily nested structure.\n\n    Returns:\n        A list of `(path, item)` pairs corresponding to the flattened\n        version of the input `structure`.\n    \"\"\"\n    return tree_impl.flatten_with_path(structure)\n\n\n@keras_export(\"keras.tree.map_structure\")\ndef map_structure(func, *structures, none_is_leaf=True):\n    \"\"\"Maps `func` through given structures.\n\n    Examples:\n\n    >>> structure = [[1], [2], [3]]\n    >>> keras.tree.map_structure(lambda v: v**2, structure)\n    [[1], [4], [9]]\n    >>> keras.tree.map_structure(lambda x, y: x * y, structure, structure)\n    [[1], [4], [9]]\n\n    >>> Foo = collections.namedtuple('Foo', ['a', 'b'])\n    >>> structure = Foo(a=1, b=2)\n    >>> keras.tree.map_structure(lambda v: v * 2, structure)\n    Foo(a=2, b=4)\n\n    Args:\n        func: A callable that accepts as many arguments as there are structures.\n        *structures: Arbitrarily nested structures of the same layout.\n        none_is_leaf: If True, `func` will be called on `None` leaves. 
If False,\n            `None` values are not passed to `func` and are returned in the\n            output directly.\n\n    Returns:\n        A new structure with the same layout as the given ones.\n\n    Raises:\n        TypeError: If `structures` is empty or `func` is not callable.\n        ValueError: If there is more than one item in `structures` and some of\n            the nested structures don't match according to the rules of\n            `assert_same_structure`.\n    \"\"\"\n    return tree_impl.map_structure(func, *structures, none_is_leaf=none_is_leaf)\n\n\n@keras_export(\"keras.tree.map_structure_up_to\")\ndef map_structure_up_to(shallow_structure, func, *structures):\n    \"\"\"Maps `func` through given structures up to `shallow_structure`.\n\n    This is a variant of `map_structure` which only maps the given structures\n    up to `shallow_structure`. All further nested components are retained as-is.\n\n    Examples:\n\n    >>> shallow_structure = [None, None]\n    >>> structure = [[1, 1], [2, 2]]\n    >>> keras.tree.map_structure_up_to(shallow_structure, len, structure)\n    [2, 2]\n\n    >>> shallow_structure = [None, [None, None]]\n    >>> keras.tree.map_structure_up_to(shallow_structure, str, structure)\n    ['[1, 1]', ['2', '2']]\n\n    Args:\n        shallow_structure: A structure with layout common to all `structures`.\n        func: A callable that accepts as many arguments as there are structures.\n        *structures: Arbitrarily nested structures of the same layout.\n\n    Returns:\n        A new structure with the same layout as `shallow_structure`.\n\n    Raises:\n        TypeError: If `structures` is empty or `func` is not callable.\n        ValueError: If one of the items in `structures` doesn't match the\n            nested structure of `shallow_structure` according to the rules of\n            `assert_same_structure`. 
Items in `structures` are allowed to be\n            nested deeper than `shallow_structure`, but they cannot be\n            shallower.\n    \"\"\"\n    return tree_impl.map_structure_up_to(shallow_structure, func, *structures)\n\n\n@keras_export(\"keras.tree.assert_same_structure\")\ndef assert_same_structure(a, b, check_types=None):\n    \"\"\"Asserts that two structures are nested in the same way.\n\n    This function verifies that the nested structures match. The leaves can be\n    of any type. At each level, the structures must be of the same type and have\n    the same number of elements. Instances of `dict`, `OrderedDict` and\n    `defaultdict` are all considered the same as long as they have the same set\n    of keys. However, `list`, `tuple`, `namedtuple` and `deque` are not the same\n    structures. Two namedtuples with identical fields and even identical names\n    are not the same structures.\n\n    Examples:\n\n    >>> keras.tree.assert_same_structure([(0, 1)], [(2, 3)])\n\n    >>> Foo = collections.namedtuple('Foo', ['a', 'b'])\n    >>> AlsoFoo = collections.namedtuple('Foo', ['a', 'b'])\n    >>> keras.tree.assert_same_structure(Foo(0, 1), Foo(2, 3))\n    >>> keras.tree.assert_same_structure(Foo(0, 1), AlsoFoo(2, 3))\n    Traceback (most recent call last):\n        ...\n    ValueError: The two structures don't have the same nested structure.\n    ...\n\n    Args:\n        a: an arbitrarily nested structure.\n        b: an arbitrarily nested structure.\n        check_types: Deprecated. The behavior of this flag was inconsistent; it\n            no longer has any effect. 
For a looser check, use\n            `assert_same_paths` instead, which considers `list`, `tuple`,\n            `namedtuple` and `deque` as matching structures.\n\n    Raises:\n        ValueError: If the two structures `a` and `b` don't match.\n    \"\"\"\n    if check_types is not None:\n        if check_types:\n            warnings.warn(\n                \"The `check_types` argument is deprecated and no longer has \"\n                \"any effect, please remove.\",\n                DeprecationWarning,\n                stacklevel=2,\n            )\n        else:\n            warnings.warn(\n                \"The `check_types` argument is deprecated and no longer has \"\n                \"any effect. For a looser check, use \"\n                \"`keras.tree.assert_same_paths()`, which considers `list`, \"\n                \"`tuple`, `namedtuple` and `deque` as matching structures.\",\n                DeprecationWarning,\n                stacklevel=2,\n            )\n    return tree_impl.assert_same_structure(a, b)\n\n\n@keras_export(\"keras.tree.assert_same_paths\")\ndef assert_same_paths(a, b):\n    \"\"\"Asserts that two structures have identical paths in their tree structure.\n\n    This function verifies that two nested structures have the same paths.\n    Unlike `assert_same_structure`, this function only checks the paths\n    and ignores the collection types.\n    For Sequences, the path is the index: 0, 1, 2, etc. For Mappings, the path\n    is the key, for instance \"a\", \"b\", \"c\". 
Note that namedtuples also use indices\n    and not field names for the path.\n\n    Examples:\n    >>> keras.tree.assert_same_paths([0, 1], (2, 3))\n    >>> Point1 = collections.namedtuple('Point1', ['x', 'y'])\n    >>> Point2 = collections.namedtuple('Point2', ['x', 'y'])\n    >>> keras.tree.assert_same_paths(Point1(0, 1), Point2(2, 3))\n\n    Args:\n        a: an arbitrarily nested structure.\n        b: an arbitrarily nested structure.\n\n    Raises:\n        ValueError: If the paths in structure `a` don't match the paths in\n            structure `b`. The error message will include the specific paths\n            that differ.\n    \"\"\"\n    return tree_impl.assert_same_paths(a, b)\n\n\n@keras_export(\"keras.tree.pack_sequence_as\")\ndef pack_sequence_as(structure, flat_sequence):\n    \"\"\"Returns a given flattened sequence packed into a given structure.\n\n    If `structure` is an atom, `flat_sequence` must be a single-item list; in\n    this case the return value is `flat_sequence[0]`.\n\n    If `structure` is or contains a dict instance, the keys will be sorted to\n    pack the flat sequence in deterministic order. However, instances of\n    `collections.OrderedDict` are handled differently: their sequence order is\n    used instead of the sorted keys. The same convention is followed in\n    `flatten`. 
This correctly repacks dicts and `OrderedDicts` after they have\n    been flattened, or vice-versa.\n\n    Dictionaries with non-sortable keys are not supported.\n\n    Examples:\n\n    >>> structure = {\"key3\": \"\", \"key1\": \"\", \"key2\": \"\"}\n    >>> flat_sequence = [\"value1\", \"value2\", \"value3\"]\n    >>> keras.tree.pack_sequence_as(structure, flat_sequence)\n    {'key3': 'value3', 'key1': 'value1', 'key2': 'value2'}\n\n    >>> structure = ((\"a\", \"b\"), (\"c\", \"d\", \"e\"), \"f\")\n    >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]\n    >>> keras.tree.pack_sequence_as(structure, flat_sequence)\n    ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)\n\n    >>> structure = {\"key3\": {\"c\": (\"alpha\", \"beta\"), \"a\": (\"gamma\")},\n    ... \"key1\": {\"e\": \"val1\", \"d\": \"val2\"}}\n    >>> flat_sequence = [\"val2\", \"val1\", 3.0, 1.0, 2.0]\n    >>> keras.tree.pack_sequence_as(structure, flat_sequence)\n    {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}}\n\n    >>> structure = [\"a\"]\n    >>> flat_sequence = [np.array([[1, 2], [3, 4]])]\n    >>> keras.tree.pack_sequence_as(structure, flat_sequence)\n    [array([[1, 2],\n       [3, 4]])]\n\n    >>> structure = [\"a\"]\n    >>> flat_sequence = [keras.ops.ones([2, 2])]\n    >>> keras.tree.pack_sequence_as(structure, flat_sequence)\n    [array([[1., 1.],\n       [1., 1.]])]\n\n    Args:\n        structure: Arbitrarily nested structure.\n        flat_sequence: Flat sequence to pack.\n\n    Returns:\n        `flat_sequence` converted to have the same recursive structure as\n        `structure`.\n\n    Raises:\n        TypeError: If `flat_sequence` is not iterable.\n        ValueError: If `flat_sequence` cannot be repacked as `structure`; for\n            instance, if `flat_sequence` has too few or too many elements.\n    \"\"\"\n    return tree_impl.pack_sequence_as(structure, flat_sequence)\n\n\n@keras_export(\"keras.tree.lists_to_tuples\")\ndef 
lists_to_tuples(structure):\n    \"\"\"Returns the structure with list instances changed to tuples.\n\n    Args:\n        structure: Arbitrarily nested structure.\n\n    Returns:\n        The same structure but with tuples instead of lists.\n    \"\"\"\n    return tree_impl.lists_to_tuples(structure)\n\n\n@keras_export(\"keras.tree.map_shape_structure\")\ndef map_shape_structure(func, structure):\n    \"\"\"Variant of `keras.tree.map_structure` that operates on shape tuples.\n\n    Tuples and lists containing only ints and `None` values are considered\n    shapes and are passed to `func` as a whole.\n\n    Args:\n        func: A callable to apply to each shape tuple in `structure`.\n        structure: Arbitrarily nested structure.\n\n    Returns:\n        The same structure with `func` applied.\n    \"\"\"\n    return tree_impl.map_shape_structure(func, structure)\n"
  },
  {
    "path": "keras/src/tree/tree_test.py",
    "content": "import functools\nfrom collections import OrderedDict\nfrom collections import defaultdict\nfrom collections import deque\nfrom collections import namedtuple\n\nimport numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.tree.tree_api import MAP_TO_NONE\nfrom keras.src.utils.module_utils import dmtree\nfrom keras.src.utils.module_utils import optree\nfrom keras.src.utils.tracking import TrackedDict\nfrom keras.src.utils.tracking import TrackedList\nfrom keras.src.utils.tracking import TrackedOrderedDict\nfrom keras.src.utils.tracking import TrackedSet\n\nTEST_CASES = []\nif dmtree.available:\n    from keras.src.tree import dmtree_impl\n\n    TEST_CASES += [\n        {\n            \"testcase_name\": \"dmtree\",\n            \"t\": dmtree_impl,\n        }\n    ]\nif backend.backend() != \"torch\" and optree.available:\n    from keras.src.tree import optree_impl\n\n    TEST_CASES += [\n        {\n            \"testcase_name\": \"optree\",\n            \"t\": optree_impl,\n        },\n    ]\nif backend.backend() == \"torch\":\n    from keras.src.tree import torchtree_impl\n\n    TEST_CASES += [\n        {\n            \"testcase_name\": \"torchtree\",\n            \"t\": torchtree_impl,\n        },\n    ]\n\n\nEmpty = namedtuple(\"Empty\", [])\nPoint = namedtuple(\"Point\", [\"x\", \"y\"])\nOtherPoint = namedtuple(\"OtherPoint\", [\"x\", \"y\"])\n\n\ndef default_value():\n    return None\n\n\nclass Visitor:\n    def __init__(self, func):\n        self.func = func\n        self.visited_list = []\n\n    def __call__(self, x):\n        self.visited_list.append(x)\n        return self.func(x)\n\n    def visited(self):\n        ret = self.visited_list\n        self.visited_list = []\n        return ret\n\n\n@parameterized.named_parameters(TEST_CASES)\nclass TreeTest(testing.TestCase):\n    def setUp(self):\n        if dmtree.available and 
optree.available:\n            # If both are available, the annotation on the Keras tracking\n            # wrappers will have used optree. For testing purposes, we need to\n            # also register them with dm-tree.\n            from keras.src.tree import dmtree_impl\n\n            dmtree_impl.register_tree_node_class(TrackedList)\n            dmtree_impl.register_tree_node_class(TrackedSet)\n            dmtree_impl.register_tree_node_class(TrackedDict)\n            dmtree_impl.register_tree_node_class(TrackedOrderedDict)\n        super().setUp()\n\n    def assertEqualStrict(self, a, b):\n        self.assertEqual(a, b)\n        self.assertEqual(type(a), type(b))\n        if isinstance(a, OrderedDict):\n            # Verify order.\n            self.assertEqual(a.items(), b.items())\n        elif isinstance(a, defaultdict):\n            self.assertEqual(a.default_factory, b.default_factory)\n        # Recurse\n        if isinstance(a, (tuple, list, deque)):\n            for sub_a, sub_b in zip(a, b):\n                self.assertEqualStrict(sub_a, sub_b)\n        elif isinstance(a, dict):\n            for k in a:\n                self.assertEqualStrict(a[k], b[k])\n\n    def is_dmtree(self, tree_impl):\n        if dmtree.available:\n            from keras.src.tree import dmtree_impl\n\n            return tree_impl is dmtree_impl\n        return False\n\n    def test_is_nested(self, t):\n        # Non-nested.\n        self.assertFalse(t.is_nested(1))\n        self.assertFalse(t.is_nested(\"1234\"))\n        self.assertFalse(t.is_nested(b\"1234\"))\n        self.assertFalse(t.is_nested(bytearray(\"1234\", \"ascii\")))\n        self.assertFalse(t.is_nested(np.ones((4, 5))))\n        self.assertFalse(t.is_nested(ops.ones((4, 5))))\n        self.assertFalse(t.is_nested(set([1, 2])))\n\n        # Standard structures.\n        self.assertTrue(t.is_nested(()))\n        self.assertTrue(t.is_nested((1,)))\n        self.assertTrue(t.is_nested((1, 2)))\n        
self.assertTrue(t.is_nested([]))\n        self.assertTrue(t.is_nested([1]))\n        self.assertTrue(t.is_nested([1, 2]))\n        self.assertTrue(t.is_nested(deque([])))\n        self.assertTrue(t.is_nested(deque([1])))\n        self.assertTrue(t.is_nested(deque([1, 2])))\n        self.assertTrue(t.is_nested(Empty()))\n        self.assertTrue(t.is_nested(Point(x=1, y=2)))\n        self.assertTrue(t.is_nested({}))\n        self.assertTrue(t.is_nested({\"a\": 1}))\n        self.assertTrue(t.is_nested({\"b\": 2, \"a\": 1}))\n        self.assertTrue(t.is_nested(OrderedDict()))\n        self.assertTrue(t.is_nested(OrderedDict([(\"a\", 1)])))\n        self.assertTrue(t.is_nested(OrderedDict([(\"b\", 2), (\"a\", 1)])))\n        self.assertTrue(t.is_nested(defaultdict(default_value)))\n        self.assertTrue(t.is_nested(defaultdict(default_value, [(\"a\", 1)])))\n        self.assertTrue(\n            t.is_nested(defaultdict(default_value, [(\"b\", 2), (\"a\", 1)]))\n        )\n\n        # Keras tracking wrappers.\n        self.assertTrue(t.is_nested(TrackedList([])))\n        self.assertTrue(t.is_nested(TrackedList([1])))\n        self.assertTrue(t.is_nested(TrackedList([1, 2])))\n        self.assertTrue(t.is_nested(TrackedSet([])))\n        self.assertTrue(t.is_nested(TrackedSet([1])))\n        self.assertTrue(t.is_nested(TrackedSet([1, 2])))\n        self.assertTrue(t.is_nested(TrackedDict({})))\n        self.assertTrue(t.is_nested(TrackedDict({\"a\": 1})))\n        self.assertTrue(t.is_nested(TrackedDict({\"b\": 2, \"a\": 1})))\n        self.assertTrue(t.is_nested(TrackedOrderedDict({})))\n        self.assertTrue(t.is_nested(TrackedOrderedDict({\"a\": 1})))\n        self.assertTrue(t.is_nested(TrackedOrderedDict({\"b\": 2, \"a\": 1})))\n\n    @pytest.mark.skipif(backend.backend() != \"tensorflow\", reason=\"tf only\")\n    def test_is_nested_tf_wrappers(self, t):\n        from tensorflow.python.trackable.data_structures import ListWrapper\n        from 
tensorflow.python.trackable.data_structures import _DictWrapper\n\n        self.assertTrue(t.is_nested(ListWrapper([])))\n        self.assertTrue(t.is_nested(ListWrapper([1])))\n        self.assertTrue(t.is_nested(ListWrapper([1, 2])))\n        self.assertTrue(t.is_nested(_DictWrapper({})))\n        self.assertTrue(t.is_nested(_DictWrapper({\"a\": 1})))\n        self.assertTrue(t.is_nested(_DictWrapper({\"b\": 2, \"a\": 1})))\n\n    def test_flatten(self, t):\n        # Non-nested.\n        self.assertEqualStrict(t.flatten(1), [1])\n\n        # Standard structures.\n        self.assertEqualStrict(t.flatten(()), [])\n        self.assertEqualStrict(t.flatten((1,)), [1])\n        self.assertEqualStrict(t.flatten((1, 2)), [1, 2])\n        self.assertEqualStrict(t.flatten([]), [])\n        self.assertEqualStrict(t.flatten([1]), [1])\n        self.assertEqualStrict(t.flatten([1, 2]), [1, 2])\n        self.assertEqualStrict(t.flatten(deque([])), [])\n        self.assertEqualStrict(t.flatten(deque([1])), [1])\n        self.assertEqualStrict(t.flatten(deque([1, 2])), [1, 2])\n        self.assertEqualStrict(t.flatten(Empty()), [])\n        self.assertEqualStrict(t.flatten(Point(y=2, x=1)), [1, 2])\n        self.assertEqualStrict(t.flatten({}), [])\n        self.assertEqualStrict(t.flatten({\"a\": 1}), [1])\n        self.assertEqualStrict(t.flatten({\"b\": 2, \"a\": 1}), [1, 2])\n        self.assertEqualStrict(\n            t.flatten(OrderedDict()),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten(OrderedDict([(\"a\", 1)])),\n            [1],\n        )\n        self.assertEqualStrict(\n            t.flatten(OrderedDict([(\"b\", 2), (\"a\", 1)])),\n            [2, 1],\n        )\n        self.assertEqualStrict(\n            t.flatten(defaultdict(default_value)),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten(defaultdict(default_value, [(\"a\", 1)])),\n            [1],\n        )\n        
self.assertEqualStrict(\n            t.flatten(defaultdict(default_value, [(\"b\", 2), (\"a\", 1)])),\n            [1, 2],\n        )\n\n        # Keras tracking wrappers.\n        self.assertEqualStrict(t.flatten(TrackedList([])), [])\n        self.assertEqualStrict(t.flatten(TrackedList([1])), [1])\n        self.assertEqualStrict(t.flatten(TrackedList([1, 2])), [1, 2])\n        self.assertEqualStrict(t.flatten(TrackedSet([])), [])\n        self.assertEqualStrict(t.flatten(TrackedSet([1])), [1])\n        self.assertEqualStrict(sorted(t.flatten(TrackedSet([1, 2]))), [1, 2])\n        self.assertEqualStrict(t.flatten(TrackedDict({})), [])\n        self.assertEqualStrict(t.flatten(TrackedDict({\"a\": 1})), [1])\n        self.assertEqualStrict(t.flatten(TrackedDict({\"b\": 2, \"a\": 1})), [1, 2])\n        self.assertEqualStrict(t.flatten(TrackedOrderedDict({})), [])\n        self.assertEqualStrict(t.flatten(TrackedOrderedDict({\"a\": 1})), [1])\n        self.assertEqualStrict(\n            t.flatten(TrackedOrderedDict({\"b\": 2, \"a\": 1})), [2, 1]\n        )\n\n        # Deeper nested structures.\n        self.assertEqualStrict(\n            t.flatten(\n                (\n                    {\"b\": [2, 3], \"a\": (1,)},\n                    TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                    TrackedSet([7]),\n                    Point(y=9, x=8),\n                    np.array([10]),\n                )\n            ),\n            [1, 2, 3, 4, 5, 6, 7, 8, 9, np.array([10])],\n        )\n\n    @pytest.mark.skipif(backend.backend() != \"tensorflow\", reason=\"tf only\")\n    def test_flatten_tf_wrappers(self, t):\n        from tensorflow.python.trackable.data_structures import ListWrapper\n        from tensorflow.python.trackable.data_structures import _DictWrapper\n\n        self.assertEqualStrict(t.flatten(ListWrapper([])), [])\n        self.assertEqualStrict(t.flatten(ListWrapper([1])), [1])\n        
self.assertEqualStrict(t.flatten(ListWrapper([1, 2])), [1, 2])\n        self.assertEqualStrict(t.flatten(_DictWrapper({})), [])\n        self.assertEqualStrict(t.flatten(_DictWrapper({\"a\": 1})), [1])\n        self.assertEqualStrict(\n            t.flatten(_DictWrapper({\"b\": 2, \"a\": 1})), [1, 2]\n        )\n\n    def test_flatten_with_path(self, t):\n        # Non-nested.\n        self.assertEqualStrict(\n            t.flatten_with_path(1),\n            [((), 1)],\n        )\n\n        # Standard structures.\n        self.assertEqualStrict(\n            t.flatten_with_path(()),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path((1,)),\n            [((0,), 1)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path((1, 2)),\n            [((0,), 1), ((1,), 2)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path([]),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path([1]),\n            [((0,), 1)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path([1, 2]),\n            [((0,), 1), ((1,), 2)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(deque([])),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(deque([1])),\n            [((0,), 1)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(deque([1, 2])),\n            [((0,), 1), ((1,), 2)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(Empty()),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(Point(y=2, x=1)),\n            [((0,), 1), ((1,), 2)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path({}),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path({\"a\": 1}),\n            [((\"a\",), 1)],\n        )\n        
self.assertEqualStrict(\n            t.flatten_with_path({\"b\": 2, \"a\": 1}),\n            [((\"a\",), 1), ((\"b\",), 2)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(OrderedDict()),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(OrderedDict([(\"a\", 1)])),\n            [((\"a\",), 1)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(OrderedDict([(\"b\", 2), (\"a\", 1)])),\n            [((\"b\",), 2), ((\"a\",), 1)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(defaultdict(default_value)),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(defaultdict(default_value, [(\"a\", 1)])),\n            [((\"a\",), 1)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(\n                defaultdict(default_value, [(\"b\", 2), (\"a\", 1)])\n            ),\n            [((\"a\",), 1), ((\"b\",), 2)],\n        )\n\n        # Keras tracking wrappers.\n        self.assertEqualStrict(\n            t.flatten_with_path(TrackedList([])),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(TrackedList([1])),\n            [((0,), 1)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(TrackedList([1, 2])),\n            [((0,), 1), ((1,), 2)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(TrackedSet([])),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(TrackedSet([1])),\n            [((0,), 1)],\n        )\n        flat = t.flatten_with_path(TrackedSet([1, 2]))\n        if flat[0][1] == 1:\n            self.assertEqualStrict(flat, [((0,), 1), ((1,), 2)])\n        else:\n            self.assertEqualStrict(flat, [((0,), 2), ((1,), 1)])\n        self.assertEqualStrict(\n            t.flatten_with_path(TrackedDict({})),\n            [],\n 
       )\n        self.assertEqualStrict(\n            t.flatten_with_path(TrackedDict({\"a\": 1})),\n            [((\"a\",), 1)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(TrackedDict({\"b\": 2, \"a\": 1})),\n            [((\"a\",), 1), ((\"b\",), 2)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(TrackedOrderedDict({})),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(TrackedOrderedDict({\"a\": 1})),\n            [((\"a\",), 1)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(TrackedOrderedDict({\"b\": 2, \"a\": 1})),\n            [((\"b\",), 2), ((\"a\",), 1)],\n        )\n\n        # Deeper nested structures.\n        self.assertEqualStrict(\n            t.flatten_with_path(\n                (\n                    {\"b\": [2, 3], \"a\": (1,)},\n                    TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                    TrackedSet([7]),\n                    Point(y=9, x=8),\n                    np.array([10]),\n                )\n            ),\n            [\n                ((0, \"a\", 0), 1),\n                ((0, \"b\", 0), 2),\n                ((0, \"b\", 1), 3),\n                ((1, \"x\"), 4),\n                ((1, \"y\", 0), 5),\n                ((1, \"y\", 1), 6),\n                ((2, 0), 7),\n                ((3, 0), 8),\n                ((3, 1), 9),\n                ((4,), np.array([10])),\n            ],\n        )\n\n    @pytest.mark.skipif(backend.backend() != \"tensorflow\", reason=\"tf only\")\n    def test_flatten_with_path_tf_wrappers(self, t):\n        from tensorflow.python.trackable.data_structures import ListWrapper\n        from tensorflow.python.trackable.data_structures import _DictWrapper\n\n        self.assertEqualStrict(\n            t.flatten_with_path(ListWrapper([])),\n            [],\n        )\n        self.assertEqualStrict(\n            
t.flatten_with_path(ListWrapper([1])),\n            [((0,), 1)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(ListWrapper([1, 2])),\n            [((0,), 1), ((1,), 2)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(_DictWrapper({})),\n            [],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(_DictWrapper({\"a\": 1})),\n            [((\"a\",), 1)],\n        )\n        self.assertEqualStrict(\n            t.flatten_with_path(_DictWrapper({\"b\": 2, \"a\": 1})),\n            [((\"a\",), 1), ((\"b\",), 2)],\n        )\n\n    def test_pack_sequence_as(self, t):\n        # Non-nested.\n        self.assertEqualStrict(t.pack_sequence_as(10, [1]), 1)\n\n        # Standard structures.\n        self.assertEqualStrict(t.pack_sequence_as((), []), ())\n        self.assertEqualStrict(t.pack_sequence_as((10,), [1]), (1,))\n        self.assertEqualStrict(t.pack_sequence_as((10, 20), [1, 2]), (1, 2))\n        self.assertEqualStrict(t.pack_sequence_as([], []), [])\n        self.assertEqualStrict(t.pack_sequence_as([10], [1]), [1])\n        self.assertEqualStrict(t.pack_sequence_as([10, 20], [1, 2]), [1, 2])\n        self.assertEqualStrict(t.pack_sequence_as(deque([]), []), deque([]))\n        self.assertEqualStrict(t.pack_sequence_as(deque([10]), [1]), deque([1]))\n        self.assertEqualStrict(\n            t.pack_sequence_as(deque([10, 20]), [1, 2]), deque([1, 2])\n        )\n        self.assertEqualStrict(t.pack_sequence_as(Empty(), []), Empty())\n        self.assertEqualStrict(\n            t.pack_sequence_as(Point(y=20, x=10), [1, 2]), Point(x=1, y=2)\n        )\n        self.assertEqualStrict(t.pack_sequence_as({}, []), {})\n        self.assertEqualStrict(t.pack_sequence_as({\"a\": 10}, [1]), {\"a\": 1})\n        self.assertEqualStrict(\n            t.pack_sequence_as({\"b\": 20, \"a\": 10}, [1, 2]), {\"a\": 1, \"b\": 2}\n        )\n        self.assertEqualStrict(\n            
t.pack_sequence_as(OrderedDict(), []), OrderedDict()\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(OrderedDict([(\"a\", 10)]), [1]),\n            OrderedDict([(\"a\", 1)]),\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(OrderedDict([(\"b\", 20), (\"a\", 10)]), [2, 1]),\n            OrderedDict([(\"b\", 2), (\"a\", 1)]),\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(defaultdict(default_value), []),\n            defaultdict(default_value),\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(defaultdict(default_value, [(\"a\", 10)]), [1]),\n            defaultdict(default_value, [(\"a\", 1)]),\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(\n                defaultdict(default_value, [(\"b\", 20), (\"a\", 10)]), [1, 2]\n            ),\n            defaultdict(default_value, [(\"a\", 1), (\"b\", 2)]),\n        )\n\n        # Keras tracking wrappers.\n        self.assertEqualStrict(\n            t.pack_sequence_as(TrackedList([]), []), TrackedList([])\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(TrackedList([10]), [1]), TrackedList([1])\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(TrackedList([10, 20]), [1, 2]),\n            TrackedList([1, 2]),\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(TrackedSet([]), []), TrackedSet([])\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(TrackedSet([10]), [1]), TrackedSet([1])\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(TrackedSet([10, 20]), [1, 2]), TrackedSet([1, 2])\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(TrackedDict({}), []), TrackedDict({})\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(TrackedDict({\"a\": 10}), [1]),\n            TrackedDict({\"a\": 1}),\n        )\n        
self.assertEqualStrict(\n            t.pack_sequence_as(TrackedDict({\"b\": 20, \"a\": 10}), [1, 2]),\n            TrackedDict({\"a\": 1, \"b\": 2}),\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(TrackedOrderedDict({}), []),\n            TrackedOrderedDict({}),\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(TrackedOrderedDict({\"a\": 10}), [1]),\n            TrackedOrderedDict({\"a\": 1}),\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(TrackedOrderedDict({\"b\": 20, \"a\": 10}), [2, 1]),\n            TrackedOrderedDict({\"b\": 2, \"a\": 1}),\n        )\n\n        # Deeper nested structures.\n        self.assertEqualStrict(\n            t.pack_sequence_as(\n                (\n                    {\"b\": [20, 30], \"a\": (10,)},\n                    TrackedDict({\"x\": 40, \"y\": TrackedList([50, 60])}),\n                    TrackedSet([70]),\n                    Point(y=90, x=80),\n                    100,\n                ),\n                [1, 2, 3, 4, 5, 6, 7, 8, 9, np.array([10])],\n            ),\n            (\n                {\"b\": [2, 3], \"a\": (1,)},\n                TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                TrackedSet([7]),\n                Point(x=8, y=9),\n                np.array([10]),\n            ),\n        )\n\n        # Error cases.\n        with self.assertRaisesRegex(TypeError, \"[Ii]terable\"):\n            t.pack_sequence_as([10, 20], 1)\n        with self.assertRaisesRegex(ValueError, \"leaves.*(expected:|holds) 1\"):\n            t.pack_sequence_as(10, [])\n        with self.assertRaisesRegex(ValueError, \"leaves.*(expected:|holds) 1\"):\n            t.pack_sequence_as(10, [1, 2])\n        with self.assertRaisesRegex(ValueError, \"(Too few leaves|holds 2)\"):\n            t.pack_sequence_as([10, 20], [1])\n        with self.assertRaisesRegex(ValueError, \"(Too many leaves|holds 3)\"):\n            t.pack_sequence_as([10, 
20], [1, 2, 3])\n\n    @pytest.mark.skipif(backend.backend() != \"tensorflow\", reason=\"tf only\")\n    def test_pack_sequence_as_tf_wrappers(self, t):\n        from tensorflow.python.trackable.data_structures import ListWrapper\n        from tensorflow.python.trackable.data_structures import _DictWrapper\n\n        self.assertEqualStrict(\n            t.pack_sequence_as(ListWrapper([]), []), ListWrapper([])\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(ListWrapper([10]), [1]), ListWrapper([1])\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(ListWrapper([10, 20]), [1, 2]),\n            ListWrapper([1, 2]),\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(_DictWrapper({}), []), _DictWrapper({})\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(_DictWrapper({\"a\": 10}), [1]),\n            _DictWrapper({\"a\": 1}),\n        )\n        self.assertEqualStrict(\n            t.pack_sequence_as(_DictWrapper({\"b\": 20, \"a\": 10}), [1, 2]),\n            _DictWrapper({\"b\": 2, \"a\": 1}),\n        )\n\n    def test_map_structure_with_one_structure(self, t):\n        def f1(x):\n            return x + 10 if isinstance(x, int) else None\n\n        # Non-nested.\n        self.assertEqualStrict(t.map_structure(f1, 1), 11)\n\n        # Standard structures.\n        self.assertEqualStrict(t.map_structure(f1, ()), ())\n        self.assertEqualStrict(t.map_structure(f1, (1,)), (11,))\n        self.assertEqualStrict(t.map_structure(f1, (1, 2)), (11, 12))\n        self.assertEqualStrict(t.map_structure(f1, []), [])\n        self.assertEqualStrict(t.map_structure(f1, [1]), [11])\n        self.assertEqualStrict(t.map_structure(f1, [1, 2]), [11, 12])\n        self.assertEqualStrict(t.map_structure(f1, deque([])), deque([]))\n        self.assertEqualStrict(t.map_structure(f1, deque([1])), deque([11]))\n        self.assertEqualStrict(\n            t.map_structure(f1, deque([1, 
2])), deque([11, 12])\n        )\n        self.assertEqualStrict(t.map_structure(f1, Empty()), Empty())\n        self.assertEqualStrict(\n            t.map_structure(f1, Point(y=2, x=1)), Point(x=11, y=12)\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, {}),\n            {},\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, {\"a\": 1}),\n            {\"a\": 11},\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, {\"b\": 2, \"a\": 1}),\n            {\"a\": 11, \"b\": 12},\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, OrderedDict()),\n            OrderedDict(),\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, OrderedDict([(\"a\", 1)])),\n            OrderedDict([(\"a\", 11)]),\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, OrderedDict([(\"b\", 2), (\"a\", 1)])),\n            OrderedDict([(\"b\", 12), (\"a\", 11)]),\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, defaultdict(default_value)),\n            defaultdict(default_value),\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, defaultdict(default_value, [(\"a\", 1)])),\n            defaultdict(default_value, [(\"a\", 11)]),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f1, defaultdict(default_value, [(\"b\", 2), (\"a\", 1)])\n            ),\n            defaultdict(default_value, [(\"a\", 11), (\"b\", 12)]),\n        )\n\n        # Keras tracking wrappers.\n        self.assertEqualStrict(\n            t.map_structure(f1, TrackedList([])), TrackedList([])\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, TrackedList([1])), TrackedList([11])\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, TrackedList([1, 2])), TrackedList([11, 12])\n        )\n        self.assertEqualStrict(\n            
t.map_structure(f1, TrackedSet([])), TrackedSet([])\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, TrackedSet([1])), TrackedSet([11])\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, TrackedSet([1, 2])), TrackedSet([11, 12])\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, TrackedDict()),\n            TrackedDict(),\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, TrackedDict({\"a\": 1})),\n            TrackedDict({\"a\": 11}),\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, TrackedDict({\"b\": 2, \"a\": 1})),\n            TrackedDict({\"a\": 11, \"b\": 12}),\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, TrackedOrderedDict()),\n            TrackedOrderedDict(),\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, TrackedOrderedDict({\"a\": 1})),\n            TrackedOrderedDict({\"a\": 11}),\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, TrackedOrderedDict({\"b\": 2, \"a\": 1})),\n            TrackedOrderedDict({\"b\": 12, \"a\": 11}),\n        )\n\n        # Deeper nested structures.\n        self.assertEqualStrict(\n            t.map_structure(\n                f1,\n                (\n                    {\"b\": [2, 3], \"a\": (1,)},\n                    TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                    TrackedSet([7]),\n                    Point(y=9, x=8),\n                    np.array([10]),\n                ),\n            ),\n            (\n                {\"b\": [12, 13], \"a\": (11,)},\n                TrackedDict({\"x\": 14, \"y\": TrackedList([15, 16])}),\n                TrackedSet([17]),\n                Point(y=19, x=18),\n                None,\n            ),\n        )\n\n        # Error cases.\n        with self.assertRaisesRegex(TypeError, \"callable\"):\n            t.map_structure(\"bad\", [1, 2])\n 
       with self.assertRaisesRegex(ValueError, \"at least one structure\"):\n            t.map_structure(f1)\n\n    @pytest.mark.skipif(backend.backend() != \"tensorflow\", reason=\"tf only\")\n    def test_map_structure_with_one_structure_tf_wrappers(self, t):\n        from tensorflow.python.trackable.data_structures import ListWrapper\n        from tensorflow.python.trackable.data_structures import _DictWrapper\n\n        def f1(x):\n            return x + 10\n\n        self.assertEqualStrict(\n            t.map_structure(f1, ListWrapper([])), ListWrapper([])\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, ListWrapper([1])), ListWrapper([11])\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, ListWrapper([1, 2])), ListWrapper([11, 12])\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, _DictWrapper()),\n            _DictWrapper(),\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, _DictWrapper({\"a\": 1})),\n            _DictWrapper({\"a\": 11}),\n        )\n        self.assertEqualStrict(\n            t.map_structure(f1, _DictWrapper({\"b\": 2, \"a\": 1})),\n            _DictWrapper({\"a\": 11, \"b\": 12}),\n        )\n\n    def test_map_structure_with_multiple_structures(self, t):\n        def f2(x, y):\n            return x + y if isinstance(x, int) and isinstance(y, int) else None\n\n        # Non-nested.\n        self.assertEqualStrict(t.map_structure(f2, 1, 10), 11)\n\n        # Standard structures.\n        self.assertEqualStrict(t.map_structure(f2, ()), ())\n        self.assertEqualStrict(t.map_structure(f2, (1,), (10,)), (11,))\n        self.assertEqualStrict(t.map_structure(f2, (1, 2), (10, 20)), (11, 22))\n        self.assertEqualStrict(t.map_structure(f2, []), [])\n        self.assertEqualStrict(t.map_structure(f2, [1], [10]), [11])\n        self.assertEqualStrict(t.map_structure(f2, [1, 2], [10, 20]), [11, 22])\n        
self.assertEqualStrict(t.map_structure(f2, deque([])), deque([]))\n        self.assertEqualStrict(\n            t.map_structure(f2, deque([1]), deque([10])), deque([11])\n        )\n        self.assertEqualStrict(\n            t.map_structure(f2, deque([1, 2]), deque([10, 20])), deque([11, 22])\n        )\n        self.assertEqualStrict(t.map_structure(f2, Empty()), Empty())\n        self.assertEqualStrict(\n            t.map_structure(f2, Point(y=2, x=1), Point(x=10, y=20)),\n            Point(x=11, y=22),\n        )\n        self.assertEqualStrict(t.map_structure(f2, {}), {})\n        self.assertEqualStrict(\n            t.map_structure(f2, {\"a\": 1}, {\"a\": 10}), {\"a\": 11}\n        )\n        self.assertEqualStrict(\n            t.map_structure(f2, {\"b\": 2, \"a\": 1}, {\"a\": 10, \"b\": 20}),\n            {\"a\": 11, \"b\": 22},\n        )\n        self.assertEqualStrict(\n            t.map_structure(f2, OrderedDict()),\n            OrderedDict(),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2, OrderedDict([(\"a\", 1)]), OrderedDict([(\"a\", 10)])\n            ),\n            OrderedDict([(\"a\", 11)]),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                OrderedDict([(\"b\", 2), (\"a\", 1)]),\n                OrderedDict([(\"b\", 20), (\"a\", 10)]),\n            ),\n            OrderedDict([(\"b\", 22), (\"a\", 11)]),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2, defaultdict(default_value), defaultdict(default_value)\n            ),\n            defaultdict(default_value),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                defaultdict(default_value, [(\"a\", 1)]),\n                defaultdict(default_value, [(\"a\", 10)]),\n            ),\n            defaultdict(default_value, [(\"a\", 11)]),\n        )\n        self.assertEqualStrict(\n      
      t.map_structure(\n                f2,\n                defaultdict(default_value, [(\"b\", 2), (\"a\", 1)]),\n                defaultdict(default_value, [(\"a\", 10), (\"b\", 20)]),\n            ),\n            defaultdict(default_value, [(\"a\", 11), (\"b\", 22)]),\n        )\n\n        # Keras tracking wrappers.\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                TrackedList([]),\n                TrackedList([]),\n            ),\n            TrackedList([]),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                TrackedList([1]),\n                TrackedList([10]),\n            ),\n            TrackedList([11]),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                TrackedList([1, 2]),\n                TrackedList([10, 20]),\n            ),\n            TrackedList([11, 22]),\n        )\n\n        # Known limitation of the dm-tree implementation:\n        # Registered classes are not handled when mapping multiple\n        # structures at once. 
TrackedSet is the only problematic one.\n        if not self.is_dmtree(t):\n            self.assertEqualStrict(\n                t.map_structure(\n                    f2,\n                    TrackedSet([]),\n                    TrackedSet([]),\n                ),\n                TrackedSet([]),\n            )\n            self.assertEqualStrict(\n                t.map_structure(\n                    f2,\n                    TrackedSet([1]),\n                    TrackedSet([10]),\n                ),\n                TrackedSet([11]),\n            )\n            self.assertEqualStrict(\n                t.map_structure(\n                    f2,\n                    TrackedSet([1, 2]),\n                    TrackedSet([10, 20]),\n                ),\n                TrackedSet([11, 22]),\n            )\n\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                TrackedDict({}),\n                TrackedDict({}),\n            ),\n            TrackedDict({}),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                TrackedDict({\"a\": 1}),\n                TrackedDict({\"a\": 10}),\n            ),\n            TrackedDict({\"a\": 11}),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                TrackedDict({\"b\": 2, \"a\": 1}),\n                TrackedDict({\"a\": 10, \"b\": 20}),\n            ),\n            TrackedDict({\"a\": 11, \"b\": 22}),\n        )\n\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                TrackedOrderedDict({}),\n                TrackedOrderedDict({}),\n            ),\n            TrackedOrderedDict({}),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                TrackedOrderedDict({\"a\": 1}),\n                TrackedOrderedDict({\"a\": 10}),\n            ),\n            TrackedOrderedDict({\"a\": 
11}),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                TrackedOrderedDict({\"b\": 2, \"a\": 1}),\n                TrackedOrderedDict({\"b\": 20, \"a\": 10}),\n            ),\n            TrackedOrderedDict({\"b\": 22, \"a\": 11}),\n        )\n\n        # Deeper nested structures.\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                (\n                    {\"b\": [2, 3], \"a\": (1,)},\n                    TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                    TrackedSet([7]),\n                    Point(y=9, x=8),\n                    np.array([10]),\n                ),\n                (\n                    {\"b\": [20, 30], \"a\": (10,)},\n                    TrackedDict({\"x\": 40, \"y\": TrackedList([50, 60])}),\n                    TrackedSet([70]),\n                    Point(y=90, x=80),\n                    np.array([100]),\n                ),\n            ),\n            (\n                {\"b\": [22, 33], \"a\": (11,)},\n                TrackedDict({\"x\": 44, \"y\": TrackedList([55, 66])}),\n                # Known limitation of the dm-tree implementation:\n                # Registered classes are not handled when mapping multiple\n                # structures at once. 
TrackedSet is the only problematic one.\n                None if self.is_dmtree(t) else TrackedSet([77]),\n                Point(y=99, x=88),\n                None,\n            ),\n        )\n\n        # Error cases.\n\n        # list, tuple, deque and namedtuple are not considered equivalent.\n        # Test all 6 combinations:\n        # tuple, list.\n        with self.assertRaisesRegex(ValueError, \"tuple\"):\n            t.map_structure(f2, (), [])\n        # tuple, deque.\n        with self.assertRaisesRegex(ValueError, \"tuple\"):\n            t.map_structure(f2, (), deque())\n        # tuple, namedtuple.\n        with self.assertRaisesRegex(ValueError, \"tuple\"):\n            t.map_structure(f2, (), Empty())\n        # list, deque.\n        with self.assertRaisesRegex(ValueError, \"list\"):\n            t.map_structure(f2, [], deque())\n        # list, namedtuple.\n        with self.assertRaisesRegex(ValueError, \"list\"):\n            t.map_structure(f2, [], Empty())\n        # deque, namedtuple.\n        with self.assertRaisesRegex(ValueError, \"deque\"):\n            t.map_structure(f2, deque(), Empty())\n\n        # Equivalent namedtuples don't match.\n        with self.assertRaisesRegex(ValueError, \"namedtuple\"):\n            t.map_structure(f2, Point(x=1, y=2), OtherPoint(x=10, y=20))\n\n        # Mismatched counts.\n        with self.assertRaisesRegex(ValueError, \"(number|[Aa]rity)\"):\n            t.map_structure(f2, (1, 2), (1,))\n        with self.assertRaisesRegex(ValueError, \"(number|[Aa]rity)\"):\n            t.map_structure(f2, [1, 2], [1])\n        with self.assertRaisesRegex(ValueError, \"(number|[Aa]rity)\"):\n            t.map_structure(f2, deque([1, 2]), deque([1]))\n\n        # dict, OrderedDict, defaultdict are considered equivalent, but the\n        # returned type is the first one. 
Test all 6 combinations (3 type\n        # combinations plus the order).\n        # dict, OrderedDict yields dict.\n        self.assertEqualStrict(\n            t.map_structure(\n                f2, {\"a\": 1, \"b\": 2}, OrderedDict([(\"b\", 20), (\"a\", 10)])\n            ),\n            {\"a\": 11, \"b\": 22},\n        )\n        # OrderedDict, dict yields OrderedDict with same order.\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                OrderedDict([(\"b\", 2), (\"a\", 1)]),\n                {\"a\": 10, \"b\": 20},\n            ),\n            OrderedDict([(\"b\", 22), (\"a\", 11)]),\n        )\n        # dict, defaultdict yields dict.\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                {\"a\": 1, \"b\": 2},\n                defaultdict(default_value, [(\"b\", 20), (\"a\", 10)]),\n            ),\n            {\"a\": 11, \"b\": 22},\n        )\n        # defaultdict, dict yields defaultdict.\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                defaultdict(default_value, [(\"b\", 2), (\"a\", 1)]),\n                {\"a\": 10, \"b\": 20},\n            ),\n            defaultdict(default_value, [(\"a\", 11), (\"b\", 22)]),\n        )\n        # defaultdict, OrderedDict yields defaultdict.\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                defaultdict(default_value, [(\"a\", 1), (\"b\", 2)]),\n                OrderedDict([(\"b\", 20), (\"a\", 10)]),\n            ),\n            defaultdict(default_value, [(\"a\", 11), (\"b\", 22)]),\n        )\n        # OrderedDict, defaultdict yields OrderedDict with same order.\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                OrderedDict([(\"b\", 2), (\"a\", 1)]),\n                defaultdict(default_value, [(\"a\", 10), (\"b\", 20)]),\n            ),\n            OrderedDict([(\"b\", 22), 
(\"a\", 11)]),\n        )\n\n        # Multiple OrderedDicts with same keys but different orders: the order\n        # of the first one prevails.\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                OrderedDict([(\"b\", 2), (\"a\", 1)]),\n                OrderedDict([(\"a\", 10), (\"b\", 20)]),\n            ),\n            OrderedDict([(\"b\", 22), (\"a\", 11)]),\n        )\n\n        # Mismatched keys.\n        with self.assertRaisesRegex(ValueError, \"(key|Node arity mismatch)\"):\n            t.map_structure(f2, {\"a\": 1, \"b\": 2}, {\"a\": 1})\n        with self.assertRaisesRegex(ValueError, \"(key|Node arity mismatch)\"):\n            t.map_structure(\n                f2,\n                defaultdict(default_value, [(\"a\", 1), (\"b\", 2)]),\n                defaultdict(default_value, [(\"a\", 10)]),\n            )\n        with self.assertRaisesRegex(ValueError, \"(key|Node arity mismatch)\"):\n            t.map_structure(\n                f2, OrderedDict([(\"a\", 1), (\"b\", 2)]), OrderedDict([(\"a\", 10)])\n            )\n\n    @pytest.mark.skipif(backend.backend() != \"tensorflow\", reason=\"tf only\")\n    def test_map_structure_with_multiple_structures_tf_wrappers(self, t):\n        from tensorflow.python.trackable.data_structures import ListWrapper\n        from tensorflow.python.trackable.data_structures import _DictWrapper\n\n        def f2(x, y):\n            return x + y\n\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                ListWrapper([]),\n                ListWrapper([]),\n            ),\n            ListWrapper([]),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                ListWrapper([1]),\n                ListWrapper([10]),\n            ),\n            ListWrapper([11]),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                ListWrapper([1, 
2]),\n                ListWrapper([10, 20]),\n            ),\n            ListWrapper([11, 22]),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                _DictWrapper({}),\n                _DictWrapper({}),\n            ),\n            _DictWrapper({}),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                _DictWrapper({\"a\": 1}),\n                _DictWrapper({\"a\": 10}),\n            ),\n            _DictWrapper({\"a\": 11}),\n        )\n        self.assertEqualStrict(\n            t.map_structure(\n                f2,\n                _DictWrapper({\"b\": 2, \"a\": 1}),\n                _DictWrapper({\"a\": 10, \"b\": 20}),\n            ),\n            _DictWrapper({\"a\": 11, \"b\": 22}),\n        )\n\n    def test_map_structure_up_to(self, t):\n        # Named tuples.\n        shallow = OtherPoint(x=2, y=3)\n        deep = OtherPoint(x=Point(x=1, y=2), y=Point(x=2, y=3))\n        out = t.map_structure_up_to(\n            shallow,\n            lambda a, b: (a + b.x) * b.y,\n            shallow,\n            deep,\n        )\n        self.assertEqual(out.x, 6)\n        self.assertEqual(out.y, 15)\n\n        # Lists.\n        data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]\n        name_list = [\"evens\", [\"odds\", \"primes\"]]\n        out = t.map_structure_up_to(\n            name_list,\n            lambda name, sec: \"first_{}_{}\".format(len(sec), name),\n            name_list,\n            data_list,\n        )\n        self.assertEqual(\n            out, [\"first_4_evens\", [\"first_5_odds\", \"first_3_primes\"]]\n        )\n\n    def test_assert_same_structure(self, t):\n        # Non-nested.\n        t.assert_same_structure(1, 10)\n\n        # Standard structures.\n        t.assert_same_structure((), ())\n        t.assert_same_structure((1,), (10,))\n        t.assert_same_structure((1, 2), (10, 20))\n        
t.assert_same_structure([], [])\n        t.assert_same_structure([1], [10])\n        t.assert_same_structure([1, 2], [10, 20])\n        t.assert_same_structure(deque([]), deque([]))\n        t.assert_same_structure(deque([1]), deque([1]))\n        t.assert_same_structure(deque([1, 2]), deque([10, 20]))\n        t.assert_same_structure(Empty(), Empty())\n        t.assert_same_structure(Point(y=1, x=2), Point(x=10, y=20))\n        t.assert_same_structure({}, {})\n        t.assert_same_structure({\"a\": 1}, {\"a\": 10})\n        t.assert_same_structure({\"b\": 2, \"a\": 1}, {\"a\": 10, \"b\": 20})\n        t.assert_same_structure(OrderedDict(), OrderedDict())\n        t.assert_same_structure(\n            OrderedDict([(\"a\", 1)]), OrderedDict([(\"a\", 10)])\n        )\n        t.assert_same_structure(\n            OrderedDict([(\"b\", 1), (\"a\", 2)]),\n            OrderedDict([(\"b\", 10), (\"a\", 20)]),\n        )\n        t.assert_same_structure(\n            defaultdict(default_value), defaultdict(default_value)\n        )\n        t.assert_same_structure(\n            defaultdict(default_value, [(\"a\", 1)]),\n            defaultdict(default_value, [(\"a\", 10)]),\n        )\n        t.assert_same_structure(\n            defaultdict(default_value, [(\"b\", 1), (\"a\", 2)]),\n            defaultdict(default_value, [(\"a\", 10), (\"b\", 20)]),\n        )\n\n        # Keras tracking wrappers.\n        t.assert_same_structure(\n            TrackedList([]),\n            TrackedList([]),\n        )\n        t.assert_same_structure(\n            TrackedList([1]),\n            TrackedList([10]),\n        )\n        t.assert_same_structure(\n            TrackedList([1, 2]),\n            TrackedList([10, 20]),\n        )\n        t.assert_same_structure(\n            TrackedSet([]),\n            TrackedSet([]),\n        )\n        t.assert_same_structure(\n            TrackedSet([1]),\n            TrackedSet([10]),\n        )\n        t.assert_same_structure(\n            
TrackedSet([1, 2]),\n            TrackedSet([10, 20]),\n        )\n        t.assert_same_structure(\n            TrackedDict({}),\n            TrackedDict({}),\n        )\n        t.assert_same_structure(\n            TrackedDict({\"a\": 1}),\n            TrackedDict({\"a\": 10}),\n        )\n        t.assert_same_structure(\n            TrackedDict({\"b\": 2, \"a\": 1}),\n            TrackedDict({\"a\": 10, \"b\": 20}),\n        )\n        t.assert_same_structure(\n            TrackedOrderedDict({}),\n            TrackedOrderedDict({}),\n        )\n        t.assert_same_structure(\n            TrackedOrderedDict({\"a\": 1}),\n            TrackedOrderedDict({\"a\": 10}),\n        )\n        t.assert_same_structure(\n            TrackedOrderedDict({\"b\": 2, \"a\": 1}),\n            TrackedOrderedDict({\"b\": 20, \"a\": 10}),\n        )\n\n        # Deeper nested structures.\n        t.assert_same_structure(\n            (\n                {\"b\": [2, 3], \"a\": (1,)},\n                TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                TrackedSet([7]),\n                Point(y=9, x=8),\n                np.array([10]),\n            ),\n            (\n                {\"b\": [20, 30], \"a\": (10,)},\n                TrackedDict({\"x\": 40, \"y\": TrackedList([50, 60])}),\n                TrackedSet([70]),\n                Point(y=90, x=80),\n                np.array([100]),\n            ),\n        )\n\n        # Error cases.\n\n        # Non-nested vs. 
nested.\n        with self.assertRaisesRegex(ValueError, \"don't have the same nested\"):\n            t.assert_same_structure(1, ())\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*tuple\"):\n            t.assert_same_structure((), 1)\n        with self.assertRaisesRegex(ValueError, \"don't have the same nested\"):\n            t.assert_same_structure(1, [])\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*list\"):\n            t.assert_same_structure([], 1)\n        with self.assertRaisesRegex(ValueError, \"don't have the same nested\"):\n            t.assert_same_structure(1, deque([]))\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*deque\"):\n            t.assert_same_structure(deque([]), 1)\n        with self.assertRaisesRegex(ValueError, \"don't have the same nested\"):\n            t.assert_same_structure(1, Empty())\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*(Empty|tuple)\"):\n            t.assert_same_structure(Empty(), 1)\n        with self.assertRaisesRegex(ValueError, \"don't have the same nested\"):\n            t.assert_same_structure(1, Point(x=1, y=2))\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*(Point|tuple)\"):\n            t.assert_same_structure(Point(x=1, y=2), 1)\n        with self.assertRaisesRegex(ValueError, \"don't have the same nested\"):\n            t.assert_same_structure(1, {})\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*dict\"):\n            t.assert_same_structure({}, 1)\n        with self.assertRaisesRegex(ValueError, \"don't have the same nested\"):\n            t.assert_same_structure(1, OrderedDict())\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*OrderedDict\"):\n            t.assert_same_structure(OrderedDict(), 1)\n        with self.assertRaisesRegex(ValueError, \"don't have the same nested\"):\n            t.assert_same_structure(1, defaultdict(default_value))\n        with 
self.assertRaisesRegex(ValueError, \"[Ee]xpected.*defaultdict\"):\n            t.assert_same_structure(defaultdict(default_value), 1)\n\n        # Non-nested vs. Keras tracking wrappers.\n        with self.assertRaisesRegex(ValueError, \"(nested|TrackedList)\"):\n            t.assert_same_structure(1, TrackedList([]))\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*TrackedList\"):\n            t.assert_same_structure(TrackedList([]), 1)\n        with self.assertRaisesRegex(ValueError, \"(nested|TrackedSet)\"):\n            t.assert_same_structure(1, TrackedSet([]))\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*TrackedSet\"):\n            t.assert_same_structure(TrackedSet([]), 1)\n        with self.assertRaisesRegex(ValueError, \"(nested|TrackedDict)\"):\n            t.assert_same_structure(1, TrackedDict([]))\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*TrackedDict\"):\n            t.assert_same_structure(TrackedDict([]), 1)\n        with self.assertRaisesRegex(ValueError, \"(nested|TrackedOrderedDict)\"):\n            t.assert_same_structure(1, TrackedOrderedDict([]))\n        with self.assertRaisesRegex(\n            ValueError, \"[Ee]xpected.*TrackedOrderedDict\"\n        ):\n            t.assert_same_structure(TrackedOrderedDict([]), 1)\n\n        # list, tuple, deque and namedtuple are not considered equivalent.\n        # Test all 6 combinations:\n        # tuple, list.\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*tuple\"):\n            t.assert_same_structure((), [])\n        # tuple, deque.\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*tuple\"):\n            t.assert_same_structure((), deque())\n        # tuple, namedtuple.\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*tuple\"):\n            t.assert_same_structure((), Empty())\n        # list, deque.\n        with self.assertRaisesRegex(ValueError, \"list\"):\n            t.assert_same_structure([], 
deque())\n        # list, namedtuple.\n        with self.assertRaisesRegex(ValueError, \"list\"):\n            t.assert_same_structure([], Empty())\n        # deque, namedtuple.\n        with self.assertRaisesRegex(ValueError, \"deque\"):\n            t.assert_same_structure(deque(), Empty())\n\n        # Equivalent namedtuples don't match.\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*[. ]Point\"):\n            t.assert_same_structure(Point(x=1, y=2), OtherPoint(x=10, y=20))\n\n        # Mismatched counts.\n        with self.assertRaisesRegex(ValueError, \"[Aa]rity mismatch\"):\n            t.assert_same_structure((1, 2), (1,))\n        with self.assertRaisesRegex(ValueError, \"[Aa]rity mismatch\"):\n            t.assert_same_structure([1, 2], [1])\n        with self.assertRaisesRegex(ValueError, \"[Aa]rity mismatch\"):\n            t.assert_same_structure(deque([1, 2]), deque([1]))\n\n        # Mismatched counts with Keras tracking wrappers.\n        with self.assertRaisesRegex(ValueError, \"[Aa]rity mismatch\"):\n            t.assert_same_structure(TrackedList([1, 2]), TrackedList([1]))\n        with self.assertRaisesRegex(ValueError, \"[Aa]rity mismatch\"):\n            t.assert_same_structure(TrackedSet([1, 2]), TrackedSet([1]))\n\n        # dict, OrderedDict, defaultdict are considered equivalent.\n        # Test all 6 combinations (3 type combinations plus the order).\n        # dict, OrderedDict.\n        t.assert_same_structure(\n            {\"a\": 1, \"b\": 2}, OrderedDict([(\"b\", 20), (\"a\", 10)])\n        )\n        # OrderedDict, dict.\n        t.assert_same_structure(\n            OrderedDict([(\"b\", 20), (\"a\", 10)]), {\"a\": 1, \"b\": 2}\n        )\n        # dict, defaultdict.\n        t.assert_same_structure(\n            {\"a\": 1, \"b\": 2},\n            defaultdict(default_value, [(\"b\", 20), (\"a\", 10)]),\n        )\n        # defaultdict, dict.\n        t.assert_same_structure(\n            defaultdict(default_value, 
[(\"b\", 20), (\"a\", 10)]),\n            {\"a\": 1, \"b\": 2},\n        )\n        # defaultdict, OrderedDict.\n        t.assert_same_structure(\n            defaultdict(default_value, [(\"a\", 1), (\"b\", 2)]),\n            OrderedDict([(\"b\", 20), (\"a\", 10)]),\n        )\n        # OrderedDict, defaultdict.\n        t.assert_same_structure(\n            OrderedDict([(\"b\", 2), (\"a\", 1)]),\n            defaultdict(default_value, [(\"a\", 10), (\"b\", 20)]),\n        )\n\n        # Two OrderedDicts with same keys but different orders.\n        t.assert_same_structure(\n            OrderedDict([(\"b\", 2), (\"a\", 1)]),\n            OrderedDict([(\"a\", 10), (\"b\", 20)]),\n        )\n\n        # Keras tracking wrappers are not equivalent to the raw structures.\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*TrackedList\"):\n            t.assert_same_structure(TrackedList([1, 2]), list([10, 20]))\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*list\"):\n            t.assert_same_structure(list([1, 2]), TrackedList([10, 20]))\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*TrackedSet\"):\n            t.assert_same_structure(TrackedSet([1, 2]), list([10, 20]))\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*list\"):\n            t.assert_same_structure(list([1, 2]), TrackedSet([10, 20]))\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*TrackedDict\"):\n            t.assert_same_structure(\n                TrackedDict({\"b\": 2, \"a\": 1}), {\"a\": 10, \"b\": 20}\n            )\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*dict\"):\n            t.assert_same_structure(\n                {\"b\": 2, \"a\": 1}, TrackedDict({\"a\": 10, \"b\": 20})\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"[Ee]xpected.*TrackedOrderedDict\"\n        ):\n            t.assert_same_structure(\n                TrackedOrderedDict({\"b\": 2, \"a\": 1}),\n     
           OrderedDict({\"b\": 20, \"a\": 10}),\n            )\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*OrderedDict\"):\n            t.assert_same_structure(\n                OrderedDict({\"b\": 2, \"a\": 1}),\n                TrackedOrderedDict({\"b\": 20, \"a\": 10}),\n            )\n\n        # Mismatched key count.\n        with self.assertRaisesRegex(\n            ValueError, \"[Dd]ictionary key mismatch|Node arity mismatch\"\n        ):\n            t.assert_same_structure(\n                {\"a\": 1, \"b\": 2},\n                {\"a\": 1},\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"[Dd]ictionary key mismatch|Node arity mismatch\"\n        ):\n            t.assert_same_structure(\n                defaultdict(default_value, [(\"a\", 1), (\"b\", 2)]),\n                defaultdict(default_value, [(\"a\", 10)]),\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"[Dd]ictionary key mismatch|Node arity mismatch\"\n        ):\n            t.assert_same_structure(\n                OrderedDict([(\"a\", 1), (\"b\", 2)]),\n                OrderedDict([(\"a\", 10)]),\n            )\n\n        # Mismatched keys.\n        with self.assertRaisesRegex(\n            ValueError, \"[Dd]ictionary key mismatch|Node keys mismatch\"\n        ):\n            t.assert_same_structure(\n                {\"a\": 1},\n                {\"b\": 2},\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"[Dd]ictionary key mismatch|Node keys mismatch\"\n        ):\n            t.assert_same_structure(\n                defaultdict(default_value, [(\"a\", 1)]),\n                defaultdict(default_value, [(\"b\", 2)]),\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"[Dd]ictionary key mismatch|Node keys mismatch\"\n        ):\n            t.assert_same_structure(\n                OrderedDict([(\"a\", 1)]),\n                OrderedDict([(\"b\", 2)]),\n     
       )\n\n        # Mismatched key count and keys with TrackedDict.\n        with self.assertRaisesRegex(\n            ValueError, \"Mismatch custom node data|Node arity mismatch\"\n        ):\n            t.assert_same_structure(\n                TrackedDict({\"a\": 1, \"b\": 2}),\n                TrackedDict({\"a\": 1}),\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"Mismatch custom node data|Node context mismatch\"\n        ):\n            t.assert_same_structure(\n                TrackedDict({\"a\": 1}),\n                TrackedDict({\"b\": 2}),\n            )\n\n        # Mismatched key count and keys and order with TrackedOrderedDict.\n        with self.assertRaisesRegex(\n            ValueError, \"Mismatch custom node data|Node arity mismatch\"\n        ):\n            t.assert_same_structure(\n                TrackedOrderedDict({\"a\": 1, \"b\": 2}),\n                TrackedOrderedDict({\"a\": 1}),\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"Mismatch custom node data|Node context mismatch\"\n        ):\n            t.assert_same_structure(\n                TrackedOrderedDict({\"a\": 1}),\n                TrackedOrderedDict({\"b\": 2}),\n            )\n        with self.assertRaisesRegex(\n            ValueError, \"Mismatch custom node data|Node context mismatch\"\n        ):\n            t.assert_same_structure(\n                TrackedOrderedDict({\"a\": 1, \"b\": 2}),\n                TrackedOrderedDict({\"b\": 2, \"a\": 1}),\n            )\n\n    @pytest.mark.skipif(backend.backend() != \"tensorflow\", reason=\"tf only\")\n    def test_assert_same_structure_tf_wrappers(self, t):\n        from tensorflow.python.trackable.data_structures import ListWrapper\n        from tensorflow.python.trackable.data_structures import _DictWrapper\n\n        t.assert_same_structure(ListWrapper([]), ListWrapper([]))\n        t.assert_same_structure(ListWrapper([1]), ListWrapper([10]))\n        
t.assert_same_structure(ListWrapper([1, 2]), ListWrapper([10, 20]))\n        t.assert_same_structure(_DictWrapper(), _DictWrapper())\n        t.assert_same_structure(_DictWrapper({\"a\": 1}), _DictWrapper({\"a\": 11}))\n        t.assert_same_structure(\n            _DictWrapper({\"b\": 2, \"a\": 1}), _DictWrapper({\"a\": 11, \"b\": 12})\n        )\n\n        # Count and key mismatch\n        with self.assertRaisesRegex(ValueError, \"[Aa]rity mismatch\"):\n            t.assert_same_structure(ListWrapper([1, 2]), ListWrapper([1]))\n        with self.assertRaisesRegex(ValueError, \"Mismatch custom node data\"):\n            t.assert_same_structure(\n                _DictWrapper({\"a\": 1, \"b\": 2}),\n                _DictWrapper({\"a\": 1}),\n            )\n        with self.assertRaisesRegex(ValueError, \"Mismatch custom node data\"):\n            t.assert_same_structure(\n                _DictWrapper({\"a\": 1}),\n                _DictWrapper({\"b\": 2}),\n            )\n\n        # Tensorflow wrappers are not equivalent to the raw structures.\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*ListWrapper\"):\n            t.assert_same_structure(ListWrapper([1, 2]), list([10, 20]))\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*list\"):\n            t.assert_same_structure(list([1, 2]), ListWrapper([10, 20]))\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*_DictWrapper\"):\n            t.assert_same_structure(\n                _DictWrapper({\"b\": 2, \"a\": 1}), {\"a\": 10, \"b\": 20}\n            )\n        with self.assertRaisesRegex(ValueError, \"[Ee]xpected.*dict\"):\n            t.assert_same_structure(\n                {\"b\": 2, \"a\": 1}, _DictWrapper({\"a\": 10, \"b\": 20})\n            )\n\n    def test_assert_same_paths(self, t):\n        # Non-nested.\n        t.assert_same_paths(1, 10)\n\n        # Standard structures.\n        t.assert_same_paths((), ())\n        t.assert_same_paths((1,), (10,))\n        
t.assert_same_paths((1, 2), (10, 20))\n        t.assert_same_paths([], [])\n        t.assert_same_paths([1], [10])\n        t.assert_same_paths([1, 2], [10, 20])\n        t.assert_same_paths(deque([]), deque([]))\n        t.assert_same_paths(deque([1]), deque([10]))\n        t.assert_same_paths(deque([1, 2]), deque([10, 20]))\n        t.assert_same_paths(Empty(), Empty())\n        t.assert_same_paths(Point(y=2, x=1), Point(x=10, y=20))\n        t.assert_same_paths({}, {})\n        t.assert_same_paths({\"a\": 1}, {\"a\": 10})\n        t.assert_same_paths({\"b\": None, \"a\": None}, {\"a\": 10, \"b\": 20})\n        t.assert_same_paths(OrderedDict(), OrderedDict())\n        t.assert_same_paths(OrderedDict([(\"a\", 1)]), OrderedDict([(\"a\", 10)]))\n        t.assert_same_paths(\n            OrderedDict([(\"b\", 2), (\"a\", 1)]),\n            OrderedDict([(\"a\", 10), (\"b\", 20)]),\n        )\n        t.assert_same_paths(\n            defaultdict(default_value), defaultdict(default_value)\n        )\n        t.assert_same_paths(\n            defaultdict(default_value, [(\"a\", 1)]),\n            defaultdict(default_value, [(\"a\", 10)]),\n        )\n        t.assert_same_paths(\n            defaultdict(default_value, [(\"b\", 2), (\"a\", 1)]),\n            defaultdict(default_value, [(\"a\", 1), (\"b\", 2)]),\n        )\n\n        # Keras tracking wrappers.\n        t.assert_same_paths(\n            TrackedList([]),\n            TrackedList([]),\n        )\n        t.assert_same_paths(\n            TrackedList([1]),\n            TrackedList([10]),\n        )\n        t.assert_same_paths(\n            TrackedList([1, 2]),\n            TrackedList([10, 20]),\n        )\n        t.assert_same_paths(\n            TrackedSet([]),\n            TrackedSet([]),\n        )\n        t.assert_same_paths(\n            TrackedSet([1]),\n            TrackedSet([10]),\n        )\n        t.assert_same_paths(\n            TrackedSet([1, 2]),\n            TrackedSet([10, 20]),\n        
)\n        t.assert_same_paths(\n            TrackedDict({}),\n            TrackedDict({}),\n        )\n        t.assert_same_paths(\n            TrackedDict({\"a\": 1}),\n            TrackedDict({\"a\": 10}),\n        )\n        t.assert_same_paths(\n            TrackedDict({\"b\": 2, \"a\": 1}),\n            TrackedDict({\"a\": 10, \"b\": 20}),\n        )\n        t.assert_same_paths(\n            TrackedOrderedDict({}),\n            TrackedOrderedDict({}),\n        )\n        t.assert_same_paths(\n            TrackedOrderedDict({\"a\": 1}),\n            TrackedOrderedDict({\"a\": 10}),\n        )\n        t.assert_same_paths(\n            TrackedOrderedDict({\"b\": 2, \"a\": 1}),\n            TrackedOrderedDict({\"a\": 10, \"b\": 20}),\n        )\n\n        # Deeper nested structures.\n        t.assert_same_paths(\n            (\n                {\"b\": [2, 3], \"a\": (1,)},\n                TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                TrackedSet([7]),\n                Point(y=9, x=8),\n                np.array([10]),\n            ),\n            (\n                {\"b\": [20, 30], \"a\": (10,)},\n                TrackedDict({\"x\": 40, \"y\": TrackedList([50, 60])}),\n                TrackedSet([70]),\n                Point(y=90, x=80),\n                np.array([100]),\n            ),\n        )\n\n        # list, tuple, deque and namedtuple have the same paths.\n        # Test all 6 combinations:\n        # tuple, list.\n        t.assert_same_paths((), [])\n        t.assert_same_paths([1, 2], (10, 20))\n        # tuple, deque.\n        t.assert_same_paths((), deque())\n        t.assert_same_paths(deque([1, 2]), (10, 20))\n        # tuple, namedtuple.\n        t.assert_same_paths((), Empty())\n        t.assert_same_paths(Point(x=1, y=2), (10, 20))\n        # list, deque.\n        t.assert_same_paths([], deque())\n        t.assert_same_paths(deque([1, 2]), [10, 20])\n        # list, namedtuple.\n        t.assert_same_paths([], 
Empty())\n        t.assert_same_paths(Point(x=None, y=20), [1, 2])\n        # deque, namedtuple.\n        t.assert_same_paths(deque(), Empty())\n        t.assert_same_paths(Point(x=None, y=20), deque([1, 2]))\n\n        # Equivalent namedtuples.\n        t.assert_same_paths(Point(x=1, y=2), OtherPoint(x=None, y=20))\n\n        # Mismatched counts.\n        with self.assertRaisesRegex(ValueError, \"don't have the same paths\"):\n            t.assert_same_paths((1, 2), (1,))\n        with self.assertRaisesRegex(ValueError, \"don't have the same paths\"):\n            t.assert_same_paths([1, 2], [1])\n        with self.assertRaisesRegex(ValueError, \"don't have the same paths\"):\n            t.assert_same_paths(deque([1, 2]), deque([1]))\n\n        # dict, OrderedDict, defaultdict are considered equivalent. Test all 6\n        # combinations (3 type combinations plus the order).\n        # dict, OrderedDict.\n        t.assert_same_paths(\n            {\"a\": 1, \"b\": 2}, OrderedDict([(\"b\", 20), (\"a\", 10)])\n        )\n        # OrderedDict, dict.\n        t.assert_same_paths(\n            OrderedDict([(\"b\", 20), (\"a\", 10)]), {\"a\": 1, \"b\": 2}\n        )\n        # dict, defaultdict.\n        t.assert_same_paths(\n            {\"a\": 1, \"b\": 2},\n            defaultdict(default_value, [(\"b\", 20), (\"a\", 10)]),\n        )\n        # defaultdict, dict.\n        t.assert_same_paths(\n            defaultdict(default_value, [(\"b\", 20), (\"a\", 10)]),\n            {\"a\": 1, \"b\": 2},\n        )\n        # defaultdict, OrderedDict.\n        t.assert_same_paths(\n            defaultdict(default_value, [(\"a\", 1), (\"b\", 2)]),\n            OrderedDict([(\"b\", 20), (\"a\", 10)]),\n        )\n        # OrderedDict, defaultdict.\n        t.assert_same_paths(\n            OrderedDict([(\"b\", 2), (\"a\", 1)]),\n            defaultdict(default_value, [(\"a\", 10), (\"b\", 20)]),\n        )\n\n        # Two OrderedDicts with same keys but different orders.\n  
      t.assert_same_paths(\n            OrderedDict([(\"b\", 2), (\"a\", 1)]),\n            OrderedDict([(\"a\", 10), (\"b\", 20)]),\n        )\n\n        # Keras tracking wrappers are equivalent to the raw structures.\n        t.assert_same_paths(TrackedList([1, 2]), list([10, 20]))\n        t.assert_same_paths(list([1, 2]), TrackedList([10, 20]))\n        t.assert_same_paths(TrackedSet([1, 2]), list([10, 20]))\n        t.assert_same_paths(list([1, 2]), TrackedSet([10, 20]))\n        t.assert_same_paths(TrackedDict({\"b\": 2, \"a\": 1}), {\"a\": 10, \"b\": 20})\n        t.assert_same_paths({\"b\": 2, \"a\": 1}, TrackedDict({\"a\": 10, \"b\": 20}))\n\n        # Mismatched keys\n        with self.assertRaisesRegex(ValueError, \"don't have the same paths\"):\n            t.assert_same_paths({\"a\": 1, \"b\": 2}, {\"a\": 1})\n        with self.assertRaisesRegex(ValueError, \"don't have the same paths\"):\n            t.assert_same_paths(\n                defaultdict(default_value, [(\"a\", 1), (\"b\", 2)]),\n                defaultdict(default_value, [(\"a\", 10)]),\n            )\n        with self.assertRaisesRegex(ValueError, \"don't have the same paths\"):\n            t.assert_same_paths(\n                OrderedDict([(\"a\", 1), (\"b\", 2)]), OrderedDict([(\"a\", 10)])\n            )\n\n    @pytest.mark.skipif(backend.backend() != \"tensorflow\", reason=\"tf only\")\n    def test_assert_same_paths_tf_wrappers(self, t):\n        from tensorflow.python.trackable.data_structures import ListWrapper\n        from tensorflow.python.trackable.data_structures import _DictWrapper\n\n        t.assert_same_paths(ListWrapper([]), ListWrapper([]))\n        t.assert_same_paths(ListWrapper([1]), ListWrapper([10]))\n        t.assert_same_paths(ListWrapper([1, 2]), ListWrapper([10, 20]))\n        t.assert_same_paths(_DictWrapper(), _DictWrapper())\n        t.assert_same_paths(_DictWrapper({\"a\": 1}), _DictWrapper({\"a\": 11}))\n        t.assert_same_paths(\n            
_DictWrapper({\"b\": 2, \"a\": 1}), _DictWrapper({\"a\": 11, \"b\": 12})\n        )\n\n        # Tensorflow wrappers are equivalent to the raw structures.\n        t.assert_same_paths(ListWrapper([1, 2]), list([10, 20]))\n        t.assert_same_paths(list([1, 2]), ListWrapper([10, 20]))\n        t.assert_same_paths(_DictWrapper({\"b\": 2, \"a\": 1}), {\"a\": 10, \"b\": 20})\n        t.assert_same_paths({\"b\": 2, \"a\": 1}, _DictWrapper({\"a\": 10, \"b\": 20}))\n\n    def test_traverse_top_down(self, t):\n        v = Visitor(lambda x: None if t.is_nested(x) else x + 10)\n\n        # Non-nested.\n        self.assertEqualStrict(t.traverse(v, 1), 11)\n        self.assertEqualStrict(v.visited(), [1])\n\n        # Standard structures.\n        self.assertEqualStrict(t.traverse(v, ()), ())\n        self.assertEqualStrict(v.visited(), [()])\n\n        self.assertEqualStrict(t.traverse(v, (1,)), (11,))\n        self.assertEqualStrict(v.visited(), [(1,), 1])\n\n        self.assertEqualStrict(t.traverse(v, (1, 2)), (11, 12))\n        self.assertEqualStrict(v.visited(), [(1, 2), 1, 2])\n\n        self.assertEqualStrict(t.traverse(v, []), [])\n        self.assertEqualStrict(v.visited(), [[]])\n\n        self.assertEqualStrict(t.traverse(v, [1]), [11])\n        self.assertEqualStrict(v.visited(), [[1], 1])\n\n        self.assertEqualStrict(t.traverse(v, [1, 2]), [11, 12])\n        self.assertEqualStrict(v.visited(), [[1, 2], 1, 2])\n\n        self.assertEqualStrict(t.traverse(v, deque([])), deque([]))\n        self.assertEqualStrict(v.visited(), [deque([])])\n\n        self.assertEqualStrict(t.traverse(v, deque([1])), deque([11]))\n        self.assertEqualStrict(v.visited(), [deque([1]), 1])\n\n        self.assertEqualStrict(t.traverse(v, deque([1, 2])), deque([11, 12]))\n        self.assertEqualStrict(v.visited(), [deque([1, 2]), 1, 2])\n\n        self.assertEqualStrict(t.traverse(v, Empty()), Empty())\n        self.assertEqualStrict(v.visited(), [Empty()])\n\n        
self.assertEqualStrict(\n            t.traverse(v, Point(y=2, x=1)), Point(x=11, y=12)\n        )\n        self.assertEqualStrict(v.visited(), [Point(x=1, y=2), 1, 2])\n\n        self.assertEqualStrict(t.traverse(v, {}), {})\n        self.assertEqualStrict(v.visited(), [{}])\n\n        self.assertEqualStrict(t.traverse(v, {\"a\": 1}), {\"a\": 11})\n        self.assertEqualStrict(v.visited(), [{\"a\": 1}, 1])\n\n        self.assertEqualStrict(\n            t.traverse(v, {\"b\": 2, \"a\": 1}), {\"a\": 11, \"b\": 12}\n        )\n        self.assertEqualStrict(v.visited(), [{\"a\": 1, \"b\": 2}, 1, 2])\n\n        self.assertEqualStrict(t.traverse(v, OrderedDict()), OrderedDict())\n        self.assertEqualStrict(v.visited(), [OrderedDict()])\n\n        self.assertEqualStrict(\n            t.traverse(v, OrderedDict([(\"a\", 1)])), OrderedDict([(\"a\", 11)])\n        )\n        self.assertEqualStrict(v.visited(), [OrderedDict([(\"a\", 1)]), 1])\n\n        self.assertEqualStrict(\n            t.traverse(v, OrderedDict([(\"b\", 2), (\"a\", 1)])),\n            OrderedDict([(\"b\", 12), (\"a\", 11)]),\n        )\n        self.assertEqualStrict(\n            v.visited(), [OrderedDict([(\"b\", 2), (\"a\", 1)]), 2, 1]\n        )\n\n        self.assertEqualStrict(\n            t.traverse(v, defaultdict(default_value)),\n            defaultdict(default_value),\n        )\n        self.assertEqualStrict(v.visited(), [defaultdict(default_value)])\n\n        self.assertEqualStrict(\n            t.traverse(v, defaultdict(default_value, [(\"a\", 1)])),\n            defaultdict(default_value, [(\"a\", 11)]),\n        )\n        self.assertEqualStrict(\n            v.visited(), [defaultdict(default_value, [(\"a\", 1)]), 1]\n        )\n\n        self.assertEqualStrict(\n            t.traverse(v, defaultdict(default_value, [(\"b\", 2), (\"a\", 1)])),\n            defaultdict(default_value, [(\"a\", 11), (\"b\", 12)]),\n        )\n        self.assertEqualStrict(\n            v.visited(),\n  
          [defaultdict(default_value, [(\"a\", 1), (\"b\", 2)]), 1, 2],\n        )\n\n        # Keras tracking wrappers.\n        self.assertEqualStrict(t.traverse(v, TrackedList([])), TrackedList([]))\n        self.assertEqualStrict(v.visited(), [TrackedList([])])\n\n        self.assertEqualStrict(\n            t.traverse(v, TrackedList([1])), TrackedList([11])\n        )\n        self.assertEqualStrict(v.visited(), [TrackedList([1]), 1])\n\n        self.assertEqualStrict(\n            t.traverse(v, TrackedList([1, 2])), TrackedList([11, 12])\n        )\n        self.assertEqualStrict(v.visited(), [TrackedList([1, 2]), 1, 2])\n\n        self.assertEqualStrict(t.traverse(v, TrackedSet([])), TrackedSet([]))\n        self.assertEqualStrict(v.visited(), [TrackedSet([])])\n\n        self.assertEqualStrict(t.traverse(v, TrackedSet([1])), TrackedSet([11]))\n        self.assertEqualStrict(v.visited(), [TrackedSet([1]), 1])\n\n        self.assertEqualStrict(\n            t.traverse(v, TrackedSet([1, 2])), TrackedSet([11, 12])\n        )\n        visited = v.visited()\n        self.assertEqualStrict(visited[0], TrackedSet([1, 2]))\n        self.assertEqualStrict(sorted(visited[1:]), [1, 2])\n\n        self.assertEqualStrict(\n            t.traverse(v, TrackedDict()),\n            TrackedDict(),\n        )\n        self.assertEqualStrict(v.visited(), [TrackedDict()])\n\n        self.assertEqualStrict(\n            t.traverse(v, TrackedDict({\"a\": 1})),\n            TrackedDict({\"a\": 11}),\n        )\n        self.assertEqualStrict(v.visited(), [TrackedDict({\"a\": 1}), 1])\n\n        self.assertEqualStrict(\n            t.traverse(v, TrackedDict({\"b\": 2, \"a\": 1})),\n            TrackedDict({\"a\": 11, \"b\": 12}),\n        )\n        self.assertEqualStrict(\n            v.visited(), [TrackedDict({\"a\": 1, \"b\": 2}), 1, 2]\n        )\n\n        self.assertEqualStrict(\n            t.traverse(v, TrackedOrderedDict()),\n            TrackedOrderedDict(),\n        )\n    
    self.assertEqualStrict(v.visited(), [TrackedOrderedDict()])\n\n        self.assertEqualStrict(\n            t.traverse(v, TrackedOrderedDict({\"a\": 1})),\n            TrackedOrderedDict({\"a\": 11}),\n        )\n        self.assertEqualStrict(v.visited(), [TrackedOrderedDict({\"a\": 1}), 1])\n\n        self.assertEqualStrict(\n            t.traverse(v, TrackedOrderedDict({\"b\": 2, \"a\": 1})),\n            TrackedOrderedDict({\"b\": 12, \"a\": 11}),\n        )\n        self.assertEqualStrict(\n            v.visited(), [TrackedOrderedDict({\"b\": 2, \"a\": 1}), 2, 1]\n        )\n\n        # Deeper nested structures.\n        self.assertEqualStrict(\n            t.traverse(\n                v,\n                (\n                    {\"b\": [2, 3], \"a\": (1,)},\n                    TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                    TrackedSet([7]),\n                    Point(y=9, x=8),\n                    np.array([10]),\n                ),\n            ),\n            (\n                {\"b\": [12, 13], \"a\": (11,)},\n                TrackedDict({\"x\": 14, \"y\": TrackedList([15, 16])}),\n                TrackedSet([17]),\n                Point(y=19, x=18),\n                np.array([20]),\n            ),\n        )\n        self.assertEqualStrict(\n            v.visited(),\n            [\n                (\n                    {\"b\": [2, 3], \"a\": (1,)},\n                    TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                    TrackedSet([7]),\n                    Point(y=9, x=8),\n                    np.array([10]),\n                ),\n                {\"b\": [2, 3], \"a\": (1,)},\n                (1,),\n                1,\n                [2, 3],\n                2,\n                3,\n                TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                4,\n                TrackedList([5, 6]),\n                5,\n                6,\n                TrackedSet([7]),\n                7,\n  
              Point(x=8, y=9),\n                8,\n                9,\n                np.array([10]),\n            ],\n        )\n\n        # Error cases.\n        with self.assertRaisesRegex(TypeError, \"callable\"):\n            t.traverse(\"bad\", [1, 2])\n\n        # Children are not explored if structure is replaced with a leaf.\n        v = Visitor(lambda x: \"X\" if isinstance(x, tuple) else None)\n        self.assertEqualStrict(\n            t.traverse(v, [(1, [2]), [3, (4, 5, 6)]]),\n            [\"X\", [3, \"X\"]],\n        )\n        self.assertEqualStrict(\n            v.visited(),\n            [\n                [(1, [2]), [3, (4, 5, 6)]],\n                (1, [2]),\n                [3, (4, 5, 6)],\n                3,\n                (4, 5, 6),\n            ],\n        )\n\n        # Children are not explored if structure is replaced with structure.\n        v = Visitor(lambda x: (\"a\", \"b\") if isinstance(x, tuple) else None)\n        self.assertEqualStrict(\n            t.traverse(v, [(1, [2]), [3, (4, 5, 6)]]),\n            [(\"a\", \"b\"), [3, (\"a\", \"b\")]],\n        )\n        self.assertEqualStrict(\n            v.visited(),\n            [\n                [(1, [2]), [3, (4, 5, 6)]],\n                (1, [2]),\n                [3, (4, 5, 6)],\n                3,\n                (4, 5, 6),\n            ],\n        )\n\n        # MAP_TO_NONE.\n        v = Visitor(lambda x: MAP_TO_NONE if isinstance(x, tuple) else None)\n        self.assertEqualStrict(\n            t.traverse(v, [(1, [2]), [3, (4, 5, 6)]]),\n            [None, [3, None]],\n        )\n        self.assertEqualStrict(\n            v.visited(),\n            [\n                [(1, [2]), [3, (4, 5, 6)]],\n                (1, [2]),\n                [3, (4, 5, 6)],\n                3,\n                (4, 5, 6),\n            ],\n        )\n\n    @pytest.mark.skipif(backend.backend() != \"tensorflow\", reason=\"tf only\")\n    def test_traverse_top_down_tf_wrappers(self, t):\n      
  from tensorflow.python.trackable.data_structures import ListWrapper\n        from tensorflow.python.trackable.data_structures import _DictWrapper\n\n        v = Visitor(lambda x: None if t.is_nested(x) else x + 10)\n\n        self.assertEqualStrict(t.traverse(v, ListWrapper([])), ListWrapper([]))\n        self.assertEqualStrict(v.visited(), [ListWrapper([])])\n\n        self.assertEqualStrict(\n            t.traverse(v, ListWrapper([1])), ListWrapper([11])\n        )\n        self.assertEqualStrict(v.visited(), [ListWrapper([1]), 1])\n\n        self.assertEqualStrict(\n            t.traverse(v, ListWrapper([1, 2])), ListWrapper([11, 12])\n        )\n        self.assertEqualStrict(v.visited(), [ListWrapper([1, 2]), 1, 2])\n\n        self.assertEqualStrict(\n            t.traverse(v, _DictWrapper()),\n            _DictWrapper(),\n        )\n        self.assertEqualStrict(v.visited(), [_DictWrapper()])\n\n        self.assertEqualStrict(\n            t.traverse(v, _DictWrapper({\"a\": 1})),\n            _DictWrapper({\"a\": 11}),\n        )\n        self.assertEqualStrict(v.visited(), [_DictWrapper({\"a\": 1}), 1])\n\n        self.assertEqualStrict(\n            t.traverse(v, _DictWrapper({\"b\": 2, \"a\": 1})),\n            _DictWrapper({\"a\": 11, \"b\": 12}),\n        )\n        self.assertEqualStrict(\n            v.visited(), [_DictWrapper({\"a\": 1, \"b\": 2}), 1, 2]\n        )\n\n    def test_traverse_bottom_up(self, t):\n        v = Visitor(lambda x: None if t.is_nested(x) else x + 10)\n        traverse_u = functools.partial(t.traverse, top_down=False)\n\n        # Non-nested.\n        self.assertEqualStrict(traverse_u(v, 1), 11)\n        self.assertEqualStrict(v.visited(), [1])\n\n        # Standard structures.\n        self.assertEqualStrict(traverse_u(v, ()), ())\n        self.assertEqualStrict(v.visited(), [()])\n\n        self.assertEqualStrict(traverse_u(v, (1,)), (11,))\n        self.assertEqualStrict(v.visited(), [1, (11,)])\n\n        
self.assertEqualStrict(traverse_u(v, (1, 2)), (11, 12))\n        self.assertEqualStrict(v.visited(), [1, 2, (11, 12)])\n\n        self.assertEqualStrict(traverse_u(v, []), [])\n        self.assertEqualStrict(v.visited(), [[]])\n\n        self.assertEqualStrict(traverse_u(v, [1]), [11])\n        self.assertEqualStrict(v.visited(), [1, [11]])\n\n        self.assertEqualStrict(traverse_u(v, [1, 2]), [11, 12])\n        self.assertEqualStrict(v.visited(), [1, 2, [11, 12]])\n\n        self.assertEqualStrict(traverse_u(v, deque([])), deque([]))\n        self.assertEqualStrict(v.visited(), [deque([])])\n\n        self.assertEqualStrict(traverse_u(v, deque([1])), deque([11]))\n        self.assertEqualStrict(v.visited(), [1, deque([11])])\n\n        self.assertEqualStrict(traverse_u(v, deque([1, 2])), deque([11, 12]))\n        self.assertEqualStrict(v.visited(), [1, 2, deque([11, 12])])\n\n        self.assertEqualStrict(traverse_u(v, Empty()), Empty())\n        self.assertEqualStrict(v.visited(), [Empty()])\n\n        self.assertEqualStrict(\n            traverse_u(v, Point(y=2, x=1)), Point(x=11, y=12)\n        )\n        self.assertEqualStrict(v.visited(), [1, 2, Point(x=11, y=12)])\n\n        self.assertEqualStrict(traverse_u(v, {}), {})\n        self.assertEqualStrict(v.visited(), [{}])\n\n        self.assertEqualStrict(traverse_u(v, {\"a\": 1}), {\"a\": 11})\n        self.assertEqualStrict(v.visited(), [1, {\"a\": 11}])\n\n        self.assertEqualStrict(\n            traverse_u(v, {\"b\": 2, \"a\": 1}), {\"a\": 11, \"b\": 12}\n        )\n        self.assertEqualStrict(v.visited(), [1, 2, {\"a\": 11, \"b\": 12}])\n\n        self.assertEqualStrict(traverse_u(v, OrderedDict()), OrderedDict())\n        self.assertEqualStrict(v.visited(), [OrderedDict()])\n\n        self.assertEqualStrict(\n            traverse_u(v, OrderedDict([(\"a\", 1)])), OrderedDict([(\"a\", 11)])\n        )\n        self.assertEqualStrict(v.visited(), [1, OrderedDict([(\"a\", 11)])])\n\n        
self.assertEqualStrict(\n            traverse_u(v, OrderedDict([(\"b\", 2), (\"a\", 1)])),\n            OrderedDict([(\"b\", 12), (\"a\", 11)]),\n        )\n        self.assertEqualStrict(\n            v.visited(), [2, 1, OrderedDict([(\"b\", 12), (\"a\", 11)])]\n        )\n\n        self.assertEqualStrict(\n            traverse_u(v, defaultdict(default_value)),\n            defaultdict(default_value),\n        )\n        self.assertEqualStrict(v.visited(), [defaultdict(default_value)])\n\n        self.assertEqualStrict(\n            traverse_u(v, defaultdict(default_value, [(\"a\", 1)])),\n            defaultdict(default_value, [(\"a\", 11)]),\n        )\n        self.assertEqualStrict(\n            v.visited(), [1, defaultdict(default_value, [(\"a\", 11)])]\n        )\n\n        self.assertEqualStrict(\n            traverse_u(v, defaultdict(default_value, [(\"b\", 2), (\"a\", 1)])),\n            defaultdict(default_value, [(\"a\", 11), (\"b\", 12)]),\n        )\n        self.assertEqualStrict(\n            v.visited(),\n            [1, 2, defaultdict(default_value, [(\"a\", 11), (\"b\", 12)])],\n        )\n\n        # Keras tracking wrappers.\n        self.assertEqualStrict(traverse_u(v, TrackedList([])), TrackedList([]))\n        self.assertEqualStrict(v.visited(), [TrackedList([])])\n\n        self.assertEqualStrict(\n            traverse_u(v, TrackedList([1])), TrackedList([11])\n        )\n        self.assertEqualStrict(v.visited(), [1, TrackedList([11])])\n\n        self.assertEqualStrict(\n            traverse_u(v, TrackedList([1, 2])), TrackedList([11, 12])\n        )\n        self.assertEqualStrict(v.visited(), [1, 2, TrackedList([11, 12])])\n\n        self.assertEqualStrict(traverse_u(v, TrackedSet([])), TrackedSet([]))\n        self.assertEqualStrict(v.visited(), [TrackedSet([])])\n\n        self.assertEqualStrict(traverse_u(v, TrackedSet([1])), TrackedSet([11]))\n        self.assertEqualStrict(v.visited(), [1, TrackedSet([11])])\n\n        
self.assertEqualStrict(\n            traverse_u(v, TrackedSet([1, 2])), TrackedSet([11, 12])\n        )\n        visited = v.visited()\n        self.assertEqualStrict(visited[-1], TrackedSet([11, 12]))\n        self.assertEqualStrict(sorted(visited[:-1]), [1, 2])\n\n        self.assertEqualStrict(\n            traverse_u(v, TrackedDict()),\n            TrackedDict(),\n        )\n        self.assertEqualStrict(v.visited(), [TrackedDict()])\n\n        self.assertEqualStrict(\n            traverse_u(v, TrackedDict({\"a\": 1})),\n            TrackedDict({\"a\": 11}),\n        )\n        self.assertEqualStrict(v.visited(), [1, TrackedDict({\"a\": 11})])\n\n        self.assertEqualStrict(\n            traverse_u(v, TrackedDict({\"b\": 2, \"a\": 1})),\n            TrackedDict({\"a\": 11, \"b\": 12}),\n        )\n        self.assertEqualStrict(\n            v.visited(), [1, 2, TrackedDict({\"a\": 11, \"b\": 12})]\n        )\n\n        self.assertEqualStrict(\n            traverse_u(v, TrackedOrderedDict()),\n            TrackedOrderedDict(),\n        )\n        self.assertEqualStrict(v.visited(), [TrackedOrderedDict()])\n\n        self.assertEqualStrict(\n            traverse_u(v, TrackedOrderedDict({\"a\": 1})),\n            TrackedOrderedDict({\"a\": 11}),\n        )\n        self.assertEqualStrict(v.visited(), [1, TrackedOrderedDict({\"a\": 11})])\n\n        self.assertEqualStrict(\n            traverse_u(v, TrackedOrderedDict({\"b\": 2, \"a\": 1})),\n            TrackedOrderedDict({\"b\": 12, \"a\": 11}),\n        )\n        self.assertEqualStrict(\n            v.visited(), [2, 1, TrackedOrderedDict({\"b\": 12, \"a\": 11})]\n        )\n\n        # Deeper nested structures.\n        self.assertEqualStrict(\n            traverse_u(\n                v,\n                (\n                    {\"b\": [2, 3], \"a\": (1,)},\n                    TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                    TrackedSet([7]),\n                    Point(y=9, x=8),\n   
                 np.array([10]),\n                ),\n            ),\n            (\n                {\"b\": [12, 13], \"a\": (11,)},\n                TrackedDict({\"x\": 14, \"y\": TrackedList([15, 16])}),\n                TrackedSet([17]),\n                Point(y=19, x=18),\n                np.array([20]),\n            ),\n        )\n        self.assertEqualStrict(\n            v.visited(),\n            [\n                1,\n                (11,),\n                2,\n                3,\n                [12, 13],\n                {\"b\": [12, 13], \"a\": (11,)},\n                4,\n                5,\n                6,\n                TrackedList([15, 16]),\n                TrackedDict({\"x\": 14, \"y\": TrackedList([15, 16])}),\n                7,\n                TrackedSet([17]),\n                8,\n                9,\n                Point(x=18, y=19),\n                np.array([10]),\n                (\n                    {\"b\": [12, 13], \"a\": (11,)},\n                    TrackedDict({\"x\": 14, \"y\": TrackedList([15, 16])}),\n                    TrackedSet([17]),\n                    Point(y=19, x=18),\n                    np.array([20]),\n                ),\n            ],\n        )\n\n        # Error cases.\n        with self.assertRaisesRegex(TypeError, \"callable\"):\n            traverse_u(\"bad\", [1, 2])\n\n        # Children are not explored if structure is replaced with a leaf.\n        v = Visitor(lambda x: \"X\" if isinstance(x, tuple) else None)\n        self.assertEqualStrict(\n            traverse_u(v, [(1, [2]), [3, (4, 5, 6)]]),\n            [\"X\", [3, \"X\"]],\n        )\n        self.assertEqualStrict(\n            v.visited(),\n            [\n                1,\n                2,\n                [2],\n                (1, [2]),\n                3,\n                4,\n                5,\n                6,\n                (4, 5, 6),\n                [3, \"X\"],\n                [\"X\", [3, \"X\"]],\n            ],\n        
)\n\n        # Children are not explored if structure is replaced with structure.\n        v = Visitor(lambda x: (\"a\", \"b\") if isinstance(x, tuple) else None)\n        self.assertEqualStrict(\n            traverse_u(v, [(1, [2]), [3, (4, 5, 6)]]),\n            [(\"a\", \"b\"), [3, (\"a\", \"b\")]],\n        )\n        self.assertEqualStrict(\n            v.visited(),\n            [\n                1,\n                2,\n                [2],\n                (1, [2]),\n                3,\n                4,\n                5,\n                6,\n                (4, 5, 6),\n                [3, (\"a\", \"b\")],\n                [(\"a\", \"b\"), [3, (\"a\", \"b\")]],\n            ],\n        )\n\n        # MAP_TO_NONE.\n        v = Visitor(lambda x: MAP_TO_NONE if isinstance(x, tuple) else None)\n        self.assertEqualStrict(\n            traverse_u(v, [(1, [2]), [3, (4, 5, 6)]]),\n            [None, [3, None]],\n        )\n        self.assertEqualStrict(\n            v.visited(),\n            [\n                1,\n                2,\n                [2],\n                (1, [2]),\n                3,\n                4,\n                5,\n                6,\n                (4, 5, 6),\n                [3, None],\n                [None, [3, None]],\n            ],\n        )\n\n    @pytest.mark.skipif(backend.backend() != \"tensorflow\", reason=\"tf only\")\n    def test_traverse_bottom_up_tf_wrappers(self, t):\n        from tensorflow.python.trackable.data_structures import ListWrapper\n        from tensorflow.python.trackable.data_structures import _DictWrapper\n\n        v = Visitor(lambda x: None if t.is_nested(x) else x + 10)\n        traverse_u = functools.partial(t.traverse, top_down=False)\n\n        self.assertEqualStrict(traverse_u(v, ListWrapper([])), ListWrapper([]))\n        self.assertEqualStrict(v.visited(), [ListWrapper([])])\n\n        self.assertEqualStrict(\n            traverse_u(v, ListWrapper([1])), ListWrapper([11])\n        )\n      
  self.assertEqualStrict(v.visited(), [1, ListWrapper([11])])\n\n        self.assertEqualStrict(\n            traverse_u(v, ListWrapper([1, 2])), ListWrapper([11, 12])\n        )\n        self.assertEqualStrict(v.visited(), [1, 2, ListWrapper([11, 12])])\n\n        self.assertEqualStrict(\n            traverse_u(v, _DictWrapper()),\n            _DictWrapper(),\n        )\n        self.assertEqualStrict(v.visited(), [_DictWrapper()])\n\n        self.assertEqualStrict(\n            traverse_u(v, _DictWrapper({\"a\": 1})),\n            _DictWrapper({\"a\": 11}),\n        )\n        self.assertEqualStrict(v.visited(), [1, _DictWrapper({\"a\": 11})])\n\n        self.assertEqualStrict(\n            traverse_u(v, _DictWrapper({\"b\": 2, \"a\": 1})),\n            _DictWrapper({\"a\": 11, \"b\": 12}),\n        )\n        self.assertEqualStrict(\n            v.visited(), [1, 2, _DictWrapper({\"a\": 11, \"b\": 12})]\n        )\n\n    def test_lists_to_tuples(self, t):\n        self.assertEqualStrict(\n            t.lists_to_tuples([1, 2, 3]),\n            (1, 2, 3),\n        )\n        self.assertEqualStrict(\n            t.lists_to_tuples([[1], [2, 3]]),\n            ((1,), (2, 3)),\n        )\n\n        # Deeper nested structures.\n        self.assertEqualStrict(\n            t.lists_to_tuples(\n                (\n                    {\"b\": [2, 3], \"a\": (1,)},\n                    TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                    TrackedSet([(7, 8, 9)]),\n                ),\n            ),\n            (\n                {\"b\": (2, 3), \"a\": (1,)},\n                TrackedDict({\"x\": 4, \"y\": (5, 6)}),\n                TrackedSet([(7, 8, 9)]),\n            ),\n        )\n\n    def test_map_shape_structure(self, t):\n        v = Visitor(\n            lambda x: tuple(x) + (10,) if isinstance(x, (tuple, list)) else None\n        )\n\n        self.assertEqualStrict(\n            t.map_shape_structure(v, (1, 2, 3)),\n            (1, 2, 3, 10),\n    
    )\n        self.assertEqualStrict(\n            v.visited(),\n            [\n                (1, 2, 3),\n            ],\n        )\n\n        self.assertEqualStrict(\n            t.map_shape_structure(v, {\"a\": [1, 2, None], \"b\": (5,), \"c\": \"hi\"}),\n            {\"a\": (1, 2, None, 10), \"b\": (5, 10), \"c\": None},\n        )\n        self.assertEqualStrict(\n            v.visited(),\n            [\n                [1, 2, None],\n                (5,),\n                \"hi\",\n            ],\n        )\n\n        # Deeper nested structures.\n        self.assertEqualStrict(\n            t.map_shape_structure(\n                v,\n                (\n                    {\"b\": [2, 3], \"a\": (None,)},\n                    TrackedDict({\"x\": 4, \"y\": TrackedList([5, 6])}),\n                    TrackedSet([(7, None, 9)]),\n                ),\n            ),\n            (\n                {\"b\": (2, 3, 10), \"a\": (None, 10)},\n                TrackedDict({\"x\": None, \"y\": (5, 6, 10)}),\n                TrackedSet([(7, None, 9, 10)]),\n            ),\n        )\n        self.assertEqualStrict(\n            v.visited(),\n            [\n                (None,),\n                [2, 3],\n                4,\n                TrackedList([5, 6]),\n                (7, None, 9),\n            ],\n        )\n"
  },
  {
    "path": "keras/src/utils/__init__.py",
    "content": "from keras.src.utils.audio_dataset_utils import audio_dataset_from_directory\nfrom keras.src.utils.dataset_utils import split_dataset\nfrom keras.src.utils.file_utils import get_file\nfrom keras.src.utils.image_dataset_utils import image_dataset_from_directory\nfrom keras.src.utils.image_utils import array_to_img\nfrom keras.src.utils.image_utils import img_to_array\nfrom keras.src.utils.image_utils import load_img\nfrom keras.src.utils.image_utils import save_img\nfrom keras.src.utils.io_utils import disable_interactive_logging\nfrom keras.src.utils.io_utils import enable_interactive_logging\nfrom keras.src.utils.io_utils import is_interactive_logging_enabled\nfrom keras.src.utils.model_visualization import model_to_dot\nfrom keras.src.utils.model_visualization import plot_model\nfrom keras.src.utils.numerical_utils import normalize\nfrom keras.src.utils.numerical_utils import to_categorical\nfrom keras.src.utils.progbar import Progbar\nfrom keras.src.utils.python_utils import default\nfrom keras.src.utils.python_utils import is_default\nfrom keras.src.utils.python_utils import removeprefix\nfrom keras.src.utils.python_utils import removesuffix\nfrom keras.src.utils.rng_utils import set_random_seed\nfrom keras.src.utils.sequence_utils import pad_sequences\nfrom keras.src.utils.text_dataset_utils import text_dataset_from_directory\nfrom keras.src.utils.timeseries_dataset_utils import (\n    timeseries_dataset_from_array,\n)\n"
  },
  {
    "path": "keras/src/utils/argument_validation.py",
    "content": "def standardize_tuple(value, n, name, allow_zero=False):\n    \"\"\"Transforms non-negative/positive integer/integers into an integer tuple.\n\n    Args:\n        value: int or iterable of ints. The value to validate and convert.\n        n: int. The size of the tuple to be returned.\n        name: string. The name of the argument being validated, e.g. \"strides\"\n            or \"kernel_size\". This is only used to format error messages.\n        allow_zero: bool, defaults to `False`. A `ValueError` will raised\n            if zero is received and this argument is `False`.\n\n    Returns:\n        A tuple of n integers.\n    \"\"\"\n    error_msg = (\n        f\"The `{name}` argument must be a tuple of {n} integers. \"\n        f\"Received {name}={value}\"\n    )\n\n    if isinstance(value, int):\n        value_tuple = (value,) * n\n    else:\n        try:\n            value_tuple = tuple(value)\n        except TypeError:\n            raise ValueError(error_msg)\n        if len(value_tuple) != n:\n            raise ValueError(error_msg)\n        for single_value in value_tuple:\n            try:\n                int(single_value)\n            except (ValueError, TypeError):\n                error_msg += (\n                    f\"including element {single_value} of \"\n                    f\"type {type(single_value)}\"\n                )\n                raise ValueError(error_msg)\n\n    if allow_zero:\n        unqualified_values = {v for v in value_tuple if v < 0}\n        req_msg = \">= 0\"\n    else:\n        unqualified_values = {v for v in value_tuple if v <= 0}\n        req_msg = \"> 0\"\n\n    if unqualified_values:\n        error_msg += (\n            f\", including values {unqualified_values}\"\n            f\" that do not satisfy `value {req_msg}`\"\n        )\n        raise ValueError(error_msg)\n\n    return value_tuple\n\n\ndef standardize_padding(value, allow_causal=False):\n    if isinstance(value, (list, tuple)):\n        return 
value\n    padding = value.lower()\n    if allow_causal:\n        allowed_values = {\"valid\", \"same\", \"causal\"}\n    else:\n        allowed_values = {\"valid\", \"same\"}\n    if padding not in allowed_values:\n        raise ValueError(\n            \"The `padding` argument must be a list/tuple or one of \"\n            f\"{allowed_values}. \"\n            f\"Received: {padding}\"\n        )\n    return padding\n\n\ndef validate_string_arg(\n    value,\n    allowable_strings,\n    caller_name,\n    arg_name,\n    allow_none=False,\n    allow_callables=False,\n):\n    \"\"\"Validates the correctness of a string-based arg.\"\"\"\n    if allow_none and value is None:\n        return\n    elif allow_callables and callable(value):\n        return\n    elif isinstance(value, str) and value in allowable_strings:\n        return\n    raise ValueError(\n        f\"Unknown value for `{arg_name}` argument of {caller_name}. \"\n        f\"Allowed values are: {allowable_strings}. Received: \"\n        f\"{arg_name}={value}\"\n    )\n"
  },
  {
    "path": "keras/src/utils/audio_dataset_utils.py",
    "content": "import numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils import dataset_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\nfrom keras.src.utils.module_utils import tensorflow_io as tfio\n\nALLOWED_FORMATS = (\".wav\",)\n\n\n@keras_export(\"keras.utils.audio_dataset_from_directory\")\ndef audio_dataset_from_directory(\n    directory,\n    labels=\"inferred\",\n    label_mode=\"int\",\n    class_names=None,\n    batch_size=32,\n    sampling_rate=None,\n    output_sequence_length=None,\n    ragged=False,\n    shuffle=True,\n    seed=None,\n    validation_split=None,\n    subset=None,\n    follow_links=False,\n    verbose=True,\n):\n    \"\"\"Generates a `tf.data.Dataset` from audio files in a directory.\n\n    If your directory structure is:\n\n    ```\n    main_directory/\n    ...class_a/\n    ......a_audio_1.wav\n    ......a_audio_2.wav\n    ...class_b/\n    ......b_audio_1.wav\n    ......b_audio_2.wav\n    ```\n\n    Then calling `audio_dataset_from_directory(main_directory,\n    labels='inferred')`\n    will return a `tf.data.Dataset` that yields batches of audio files from\n    the subdirectories `class_a` and `class_b`, together with labels\n    0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).\n\n    Only `.wav` files are supported at this time.\n\n    Args:\n        directory: Directory where the data is located.\n            If `labels` is `\"inferred\"`, it should contain subdirectories,\n            each containing audio files for a class. Otherwise, the directory\n            structure is ignored.\n        labels: Either \"inferred\" (labels are generated from the directory\n            structure), `None` (no labels), or a list/tuple of integer labels\n            of the same size as the number of audio files found in\n            the directory. 
Labels should be sorted according to the\n            alphanumeric order of the audio file paths\n            (obtained via `os.walk(directory)` in Python).\n        label_mode: String describing the encoding of `labels`. Options are:\n            - `\"int\"`: means that the labels are encoded as integers (e.g. for\n              `sparse_categorical_crossentropy` loss).\n            - `\"categorical\"` means that the labels are encoded as a categorical\n              vector (e.g. for `categorical_crossentropy` loss)\n            - `\"binary\"` means that the labels (there can be only 2)\n              are encoded as `float32` scalars with values 0\n              or 1 (e.g. for `binary_crossentropy`).\n            - `None` (no labels).\n        class_names: Only valid if \"labels\" is `\"inferred\"`.\n            This is the explicit list of class names\n            (must match names of subdirectories). Used to control the order\n            of the classes (otherwise alphanumerical order is used).\n        batch_size: Size of the batches of data. Default: 32. If `None`,\n            the data will not be batched\n            (the dataset will yield individual samples).\n        sampling_rate: Audio sampling rate (in samples per second).\n        output_sequence_length: Maximum length of an audio sequence. Audio files\n            longer than this will be truncated to `output_sequence_length`.\n            If set to `None`, then all sequences in the same batch will\n            be padded to the\n            length of the longest sequence in the batch.\n        ragged: Whether to return a Ragged dataset (where each sequence has its\n            own length). 
Defaults to `False`.\n        shuffle: Whether to shuffle the data.\n            If set to `False`, sorts the data in alphanumeric order.\n            Defaults to `True`.\n        seed: Optional random seed for shuffling and transformations.\n        validation_split: Optional float between 0 and 1, fraction of data to\n            reserve for validation.\n        subset: Subset of the data to return. One of `\"training\"`,\n            `\"validation\"` or `\"both\"`. Only used if `validation_split` is set.\n        follow_links: Whether to visit subdirectories pointed to by symlinks.\n            Defaults to `False`.\n        verbose: Whether to display information on the number of classes and\n            the number of files found. Defaults to `True`.\n\n    Returns:\n\n    A `tf.data.Dataset` object.\n\n    - If `label_mode` is `None`, it yields `string` tensors of shape\n      `(batch_size,)`, containing the contents of a batch of audio files.\n    - Otherwise, it yields a tuple `(audio, labels)`, where `audio`\n      has shape `(batch_size, sequence_length, num_channels)` and `labels`\n      follows the format described\n      below.\n\n    Rules regarding labels format:\n\n    - if `label_mode` is `int`, the labels are an `int32` tensor of shape\n      `(batch_size,)`.\n    - if `label_mode` is `binary`, the labels are a `float32` tensor of\n      1s and 0s of shape `(batch_size, 1)`.\n    - if `label_mode` is `categorical`, the labels are a `float32` tensor\n      of shape `(batch_size, num_classes)`, representing a one-hot\n      encoding of the class index.\n    \"\"\"\n    if labels not in (\"inferred\", None):\n        if not isinstance(labels, (list, tuple)):\n            raise ValueError(\n                \"The `labels` argument should be a list/tuple of integer \"\n                \"labels, of the same size as the number of audio files in \"\n                \"the target directory. \"\n                \"
If you wish to infer the labels from \"\n                \"the subdirectory names in the target directory,\"\n                ' pass `labels=\"inferred\"`. '\n                \"If you wish to get a dataset that only contains audio samples \"\n                f\"(no labels), pass `labels=None`. Received: labels={labels}\"\n            )\n        if class_names:\n            raise ValueError(\n                \"You can only pass `class_names` if \"\n                f'`labels=\"inferred\"`. Received: labels={labels}, and '\n                f\"class_names={class_names}\"\n            )\n    if label_mode not in {\"int\", \"categorical\", \"binary\", None}:\n        raise ValueError(\n            '`label_mode` argument must be one of \"int\", \"categorical\", '\n            '\"binary\", '\n            f\"or None. Received: label_mode={label_mode}\"\n        )\n\n    if ragged and output_sequence_length is not None:\n        raise ValueError(\n            \"Cannot set both `ragged` and `output_sequence_length`\"\n        )\n\n    if sampling_rate is not None:\n        if not isinstance(sampling_rate, int):\n            raise ValueError(\n                \"`sampling_rate` should have an integer value. \"\n                f\"Received: sampling_rate={sampling_rate}\"\n            )\n\n        if sampling_rate <= 0:\n            raise ValueError(\n                \"`sampling_rate` should be higher than 0. \"\n                f\"Received: sampling_rate={sampling_rate}\"\n            )\n\n        if not tfio.available:\n            raise ImportError(\n                \"To use the argument `sampling_rate`, you should install \"\n                \"tensorflow_io. 
You can install it via `pip install \"\n                \"tensorflow-io`.\"\n            )\n\n    if labels is None or label_mode is None:\n        labels = None\n        label_mode = None\n\n    dataset_utils.check_validation_split_arg(\n        validation_split, subset, shuffle, seed\n    )\n\n    if seed is None:\n        seed = np.random.randint(1e6)\n    if batch_size is not None:\n        shuffle_buffer_size = batch_size * 8\n    else:\n        shuffle_buffer_size = 1024\n\n    file_paths, labels, class_names = dataset_utils.index_directory(\n        directory,\n        labels,\n        formats=ALLOWED_FORMATS,\n        class_names=class_names,\n        shuffle=shuffle,\n        seed=seed,\n        follow_links=follow_links,\n        verbose=verbose,\n    )\n\n    if label_mode == \"binary\" and len(class_names) != 2:\n        raise ValueError(\n            'When passing `label_mode=\"binary\"`, there must be exactly 2 '\n            f\"class_names. Received: class_names={class_names}\"\n        )\n\n    if subset == \"both\":\n        train_dataset, val_dataset = get_training_and_validation_dataset(\n            file_paths=file_paths,\n            labels=labels,\n            validation_split=validation_split,\n            directory=directory,\n            label_mode=label_mode,\n            class_names=class_names,\n            sampling_rate=sampling_rate,\n            output_sequence_length=output_sequence_length,\n            ragged=ragged,\n            shuffle=shuffle,\n            shuffle_buffer_size=shuffle_buffer_size,\n            seed=seed,\n        )\n        train_dataset = prepare_dataset(\n            dataset=train_dataset,\n            batch_size=batch_size,\n            class_names=class_names,\n            output_sequence_length=output_sequence_length,\n            ragged=ragged,\n        )\n        val_dataset = prepare_dataset(\n            dataset=val_dataset,\n            batch_size=batch_size,\n            class_names=class_names,\n       
     output_sequence_length=output_sequence_length,\n            ragged=ragged,\n        )\n        return train_dataset, val_dataset\n\n    else:\n        dataset = get_dataset(\n            file_paths=file_paths,\n            labels=labels,\n            directory=directory,\n            validation_split=validation_split,\n            subset=subset,\n            label_mode=label_mode,\n            class_names=class_names,\n            sampling_rate=sampling_rate,\n            output_sequence_length=output_sequence_length,\n            ragged=ragged,\n            shuffle=shuffle,\n            shuffle_buffer_size=shuffle_buffer_size,\n            seed=seed,\n        )\n        dataset = prepare_dataset(\n            dataset=dataset,\n            batch_size=batch_size,\n            class_names=class_names,\n            output_sequence_length=output_sequence_length,\n            ragged=ragged,\n        )\n        return dataset\n\n\ndef prepare_dataset(\n    dataset,\n    batch_size,\n    class_names,\n    output_sequence_length,\n    ragged,\n):\n    dataset = dataset.prefetch(tf.data.AUTOTUNE)\n    if batch_size is not None:\n        if output_sequence_length is None and not ragged:\n            dataset = dataset.padded_batch(\n                batch_size, padded_shapes=([None, None], [])\n            )\n        else:\n            dataset = dataset.batch(batch_size)\n\n    # Users may need to reference `class_names`.\n    dataset.class_names = class_names\n    return dataset\n\n\ndef get_training_and_validation_dataset(\n    file_paths,\n    labels,\n    validation_split,\n    directory,\n    label_mode,\n    class_names,\n    sampling_rate,\n    output_sequence_length,\n    ragged,\n    shuffle=False,\n    shuffle_buffer_size=None,\n    seed=None,\n):\n    (\n        file_paths_train,\n        labels_train,\n    ) = dataset_utils.get_training_or_validation_split(\n        file_paths, labels, validation_split, \"training\"\n    )\n    if not file_paths_train:\n       
 raise ValueError(\n            f\"No training audio files found in directory {directory}. \"\n            f\"Allowed format(s): {ALLOWED_FORMATS}\"\n        )\n\n    file_paths_val, labels_val = dataset_utils.get_training_or_validation_split(\n        file_paths, labels, validation_split, \"validation\"\n    )\n    if not file_paths_val:\n        raise ValueError(\n            f\"No validation audio files found in directory {directory}. \"\n            f\"Allowed format(s): {ALLOWED_FORMATS}\"\n        )\n\n    train_dataset = paths_and_labels_to_dataset(\n        file_paths=file_paths_train,\n        labels=labels_train,\n        label_mode=label_mode,\n        num_classes=len(class_names) if class_names else 0,\n        sampling_rate=sampling_rate,\n        output_sequence_length=output_sequence_length,\n        ragged=ragged,\n        shuffle=shuffle,\n        shuffle_buffer_size=shuffle_buffer_size,\n        seed=seed,\n    )\n\n    val_dataset = paths_and_labels_to_dataset(\n        file_paths=file_paths_val,\n        labels=labels_val,\n        label_mode=label_mode,\n        num_classes=len(class_names) if class_names else 0,\n        sampling_rate=sampling_rate,\n        output_sequence_length=output_sequence_length,\n        ragged=ragged,\n        shuffle=False,\n    )\n\n    return train_dataset, val_dataset\n\n\ndef get_dataset(\n    file_paths,\n    labels,\n    directory,\n    validation_split,\n    subset,\n    label_mode,\n    class_names,\n    sampling_rate,\n    output_sequence_length,\n    ragged,\n    shuffle=False,\n    shuffle_buffer_size=None,\n    seed=None,\n):\n    file_paths, labels = dataset_utils.get_training_or_validation_split(\n        file_paths, labels, validation_split, subset\n    )\n    if not file_paths:\n        raise ValueError(\n            f\"No audio files found in directory {directory}. 
\"\n            f\"Allowed format(s): {ALLOWED_FORMATS}\"\n        )\n\n    return paths_and_labels_to_dataset(\n        file_paths=file_paths,\n        labels=labels,\n        label_mode=label_mode,\n        num_classes=len(class_names) if class_names else 0,\n        sampling_rate=sampling_rate,\n        output_sequence_length=output_sequence_length,\n        ragged=ragged,\n        shuffle=shuffle,\n        shuffle_buffer_size=shuffle_buffer_size,\n        seed=seed,\n    )\n\n\ndef read_and_decode_audio(\n    path, sampling_rate=None, output_sequence_length=None\n):\n    \"\"\"Reads and decodes audio file.\"\"\"\n    audio = tf.io.read_file(path)\n\n    if output_sequence_length is None:\n        output_sequence_length = -1\n\n    audio, default_audio_rate = tf.audio.decode_wav(\n        contents=audio, desired_samples=output_sequence_length\n    )\n    if sampling_rate is not None:\n        # default_audio_rate should have dtype=int64\n        default_audio_rate = tf.cast(default_audio_rate, tf.int64)\n        audio = tfio.audio.resample(\n            input=audio, rate_in=default_audio_rate, rate_out=sampling_rate\n        )\n    return audio\n\n\ndef paths_and_labels_to_dataset(\n    file_paths,\n    labels,\n    label_mode,\n    num_classes,\n    sampling_rate,\n    output_sequence_length,\n    ragged,\n    shuffle=False,\n    shuffle_buffer_size=None,\n    seed=None,\n):\n    \"\"\"Constructs a fixed-size dataset of audio and labels.\"\"\"\n    path_ds = tf.data.Dataset.from_tensor_slices(file_paths)\n    if label_mode:\n        label_ds = dataset_utils.labels_to_dataset_tf(\n            labels, label_mode, num_classes\n        )\n        ds = tf.data.Dataset.zip((path_ds, label_ds))\n    else:\n        ds = path_ds\n\n    if shuffle:\n        ds = ds.shuffle(buffer_size=shuffle_buffer_size or 1024, seed=seed)\n\n    if label_mode:\n        ds = ds.map(\n            lambda x, y: (\n                read_and_decode_audio(x, sampling_rate, 
output_sequence_length),\n                y,\n            ),\n            num_parallel_calls=tf.data.AUTOTUNE,\n        )\n\n        if ragged:\n            ds = ds.map(\n                lambda x, y: (tf.RaggedTensor.from_tensor(x), y),\n                num_parallel_calls=tf.data.AUTOTUNE,\n            )\n\n    else:\n        ds = ds.map(\n            lambda x: read_and_decode_audio(\n                x, sampling_rate, output_sequence_length\n            ),\n            num_parallel_calls=tf.data.AUTOTUNE,\n        )\n\n        if ragged:\n            ds = ds.map(\n                lambda x: tf.RaggedTensor.from_tensor(x),\n                num_parallel_calls=tf.data.AUTOTUNE,\n            )\n\n    return ds\n"
  },
  {
    "path": "keras/src/utils/audio_dataset_utils_test.py",
    "content": "import os\n\nimport numpy as np\n\nfrom keras.src import testing\nfrom keras.src.utils import audio_dataset_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\nclass AudioDatasetFromDirectoryTest(testing.TestCase):\n    def _get_audio_samples(self, count=16, different_sequence_lengths=False):\n        sequence_length = 30\n        num_channels = 1\n        audio_samples = []\n        for _ in range(count):\n            if different_sequence_lengths:\n                random_sequence_length = np.random.randint(\n                    10, sequence_length + 1\n                )\n                audio = np.random.random((random_sequence_length, num_channels))\n            else:\n                audio = np.random.random((sequence_length, num_channels))\n            audio_samples.append(tf.audio.encode_wav(audio, 1000))\n        return audio_samples\n\n    def _prepare_directory(\n        self,\n        num_classes=2,\n        nested_dirs=False,\n        count=16,\n        different_sequence_lengths=False,\n    ):\n        # Get a unique temp directory\n        temp_dir = self.get_temp_dir()\n\n        # Generate paths to class subdirectories\n        paths = []\n        for class_index in range(num_classes):\n            class_directory = f\"class_{class_index}\"\n            if nested_dirs:\n                class_paths = [\n                    class_directory,\n                    os.path.join(class_directory, \"subfolder_1\"),\n                    os.path.join(class_directory, \"subfolder_2\"),\n                    os.path.join(\n                        class_directory, \"subfolder_1\", \"sub-subfolder\"\n                    ),\n                ]\n            else:\n                class_paths = [class_directory]\n            for path in class_paths:\n                os.mkdir(os.path.join(temp_dir, path))\n            paths += class_paths\n\n        # Save audio samples to the paths\n        i = 0\n        for audio in 
self._get_audio_samples(\n            count=count, different_sequence_lengths=different_sequence_lengths\n        ):\n            path = paths[i % len(paths)]\n            ext = \"wav\"\n            filename = os.path.join(path, f\"audio_{i}.{ext}\")\n            with open(os.path.join(temp_dir, filename), \"wb\") as f:\n                f.write(audio.numpy())\n            i += 1\n        return temp_dir\n\n    def test_audio_dataset_from_directory_standalone(self):\n        # Test retrieving audio samples without labels from a directory and its\n        # subdirs.\n        # Save a few extra audio files in the parent directory.\n        directory = self._prepare_directory(count=7, num_classes=2)\n        for i, audio in enumerate(self._get_audio_samples(3)):\n            filename = f\"audio_{i}.wav\"\n            with open(os.path.join(directory, filename), \"wb\") as f:\n                f.write(audio.numpy())\n\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory, batch_size=5, output_sequence_length=30, labels=None\n        )\n        batch = next(iter(dataset))\n        # We return plain audio\n        self.assertEqual(batch.shape, (5, 30, 1))\n        self.assertEqual(batch.dtype.name, \"float32\")\n        # Count samples\n        batch_count = 0\n        sample_count = 0\n        for batch in dataset:\n            batch_count += 1\n            sample_count += batch.shape[0]\n        self.assertEqual(batch_count, 2)\n        self.assertEqual(sample_count, 10)\n\n    def test_audio_dataset_from_directory_binary(self):\n        directory = self._prepare_directory(num_classes=2)\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory, batch_size=8, output_sequence_length=30, label_mode=\"int\"\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(batch[0].shape, (8, 30, 1))\n        self.assertEqual(batch[0].dtype.name, \"float32\")\n        
self.assertEqual(batch[1].shape, (8,))\n        self.assertEqual(batch[1].dtype.name, \"int32\")\n\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory,\n            batch_size=8,\n            output_sequence_length=30,\n            label_mode=\"binary\",\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(batch[0].shape, (8, 30, 1))\n        self.assertEqual(batch[0].dtype.name, \"float32\")\n        self.assertEqual(batch[1].shape, (8, 1))\n        self.assertEqual(batch[1].dtype.name, \"float32\")\n\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory,\n            batch_size=8,\n            output_sequence_length=30,\n            label_mode=\"categorical\",\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(batch[0].shape, (8, 30, 1))\n        self.assertEqual(batch[0].dtype.name, \"float32\")\n        self.assertEqual(batch[1].shape, (8, 2))\n        self.assertEqual(batch[1].dtype.name, \"float32\")\n\n    def test_static_shape_in_graph(self):\n        directory = self._prepare_directory(num_classes=2)\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory, batch_size=8, output_sequence_length=30, label_mode=\"int\"\n        )\n        test_obj = self\n\n        @tf.function\n        def symbolic_fn(ds):\n            for x, _ in ds.take(1):\n                test_obj.assertListEqual(x.shape.as_list(), [None, 30, None])\n\n        symbolic_fn(dataset)\n\n    def test_sample_count(self):\n        directory = self._prepare_directory(num_classes=4, count=15)\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory, batch_size=8, output_sequence_length=30, label_mode=None\n        )\n        sample_count = 0\n        for batch in dataset:\n            sample_count += batch.shape[0]\n        
self.assertEqual(sample_count, 15)\n\n    def test_audio_dataset_from_directory_multiclass(self):\n        directory = self._prepare_directory(num_classes=4, count=15)\n\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory, batch_size=8, output_sequence_length=30, label_mode=None\n        )\n        batch = next(iter(dataset))\n        self.assertEqual(batch.shape, (8, 30, 1))\n\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory, batch_size=8, output_sequence_length=30, label_mode=None\n        )\n        sample_count = 0\n        for batch in dataset:\n            sample_count += batch.shape[0]\n        self.assertEqual(sample_count, 15)\n\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory, batch_size=8, output_sequence_length=30, label_mode=\"int\"\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(batch[0].shape, (8, 30, 1))\n        self.assertEqual(batch[0].dtype.name, \"float32\")\n        self.assertEqual(batch[1].shape, (8,))\n        self.assertEqual(batch[1].dtype.name, \"int32\")\n\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory,\n            batch_size=8,\n            output_sequence_length=30,\n            label_mode=\"categorical\",\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(batch[0].shape, (8, 30, 1))\n        self.assertEqual(batch[0].dtype.name, \"float32\")\n        self.assertEqual(batch[1].shape, (8, 4))\n        self.assertEqual(batch[1].dtype.name, \"float32\")\n\n    def test_audio_dataset_from_directory_validation_split(self):\n        directory = self._prepare_directory(num_classes=2, count=10)\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory,\n            batch_size=10,\n   
         output_sequence_length=30,\n            validation_split=0.2,\n            subset=\"training\",\n            seed=1337,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(batch[0].shape, (8, 30, 1))\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory,\n            batch_size=10,\n            output_sequence_length=30,\n            validation_split=0.2,\n            subset=\"validation\",\n            seed=1337,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(batch[0].shape, (2, 30, 1))\n\n    def test_audio_dataset_from_directory_manual_labels(self):\n        directory = self._prepare_directory(num_classes=2, count=2)\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory,\n            batch_size=8,\n            output_sequence_length=30,\n            labels=[0, 1],\n            shuffle=False,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertAllClose(batch[1], [0, 1])\n\n    def test_audio_dataset_from_directory_follow_links(self):\n        directory = self._prepare_directory(\n            num_classes=2, count=25, nested_dirs=True\n        )\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory,\n            batch_size=8,\n            output_sequence_length=30,\n            label_mode=None,\n            follow_links=True,\n        )\n        sample_count = 0\n        for batch in dataset:\n            sample_count += batch.shape[0]\n        self.assertEqual(sample_count, 25)\n\n    def test_audio_dataset_from_directory_no_audio(self):\n        directory = self._prepare_directory(num_classes=2, count=0)\n        with self.assertRaisesRegex(\n            ValueError, \"No audio files found in directory\"\n        ):\n            _ = 
audio_dataset_utils.audio_dataset_from_directory(directory)\n\n    def test_audio_dataset_from_directory_ragged(self):\n        directory = self._prepare_directory(\n            num_classes=2, count=16, different_sequence_lengths=True\n        )\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory, ragged=True, batch_size=8\n        )\n        batch = next(iter(dataset))\n\n        self.assertEqual(batch[0].shape.as_list(), [8, None, None])\n\n    def test_audio_dataset_from_directory_no_output_sequence_length_no_ragged(\n        self,\n    ):\n        # This test case tests `audio_dataset_from_directory` when `ragged` and\n        # `output_sequence_length` are not passed while the input sequence\n        # lengths are different.\n        directory = self._prepare_directory(\n            num_classes=2, count=16, different_sequence_lengths=True\n        )\n        # When the tensor shapes differ and output_sequence_length is None,\n        # batching should work fine and pad each sequence to the length of\n        # the longest sequence in its batch.\n        min_sequence_length, max_sequence_length = 10, 30\n        possible_sequence_lengths = [\n            i for i in range(min_sequence_length, max_sequence_length + 1)\n        ]\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory, batch_size=2\n        )\n        sequence_lengths = list(set([b.shape[1] for b, _ in dataset]))\n        for seq_len in sequence_lengths:\n            self.assertIn(seq_len, possible_sequence_lengths)\n\n    def test_audio_dataset_from_directory_no_output_sequence_length_same_lengths(  # noqa: E501\n        self,\n    ):\n        # This test case tests `audio_dataset_from_directory` when `ragged` and\n        # `output_sequence_length` are not passed while the input sequence\n        # lengths are the same.\n        directory = self._prepare_directory(\n            num_classes=2, count=16, 
different_sequence_lengths=False\n        )\n        # All tensor shapes are the same here, so batching with\n        # output_sequence_length=None should yield a single sequence length\n        # across batches.\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory, batch_size=2\n        )\n        sequence_lengths = list(set([batch[0].shape[1] for batch in dataset]))\n        self.assertEqual(len(sequence_lengths), 1)\n\n    def test_audio_dataset_from_directory_errors(self):\n        directory = self._prepare_directory(num_classes=3, count=5)\n\n        with self.assertRaisesRegex(\n            ValueError, \"`sampling_rate` should be higher than 0. Received:\"\n        ):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory,\n                ragged=False,\n                output_sequence_length=10,\n                sampling_rate=-1,\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"`sampling_rate` should have an integer value. 
Received:\",\n        ):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory,\n                ragged=False,\n                output_sequence_length=10,\n                sampling_rate=1.2,\n            )\n\n        # Only run this test case when we don't have tensorflow_io.\n        try:\n            import tensorflow_io  # noqa: F401\n        except ImportError:\n            with self.assertRaisesRegex(\n                ImportError,\n                \"To use the argument `sampling_rate`.*tensorflow_io.*\",\n            ):\n                _ = audio_dataset_utils.audio_dataset_from_directory(\n                    directory,\n                    ragged=False,\n                    output_sequence_length=10,\n                    sampling_rate=44100,\n                )\n\n        with self.assertRaisesRegex(\n            ValueError, \"Cannot set both `ragged` and `output_sequence_length`\"\n        ):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory, ragged=True, output_sequence_length=30\n            )\n\n        with self.assertRaisesRegex(ValueError, \"`labels` argument should be\"):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory, labels=\"other\"\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"`label_mode` argument must be\"\n        ):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory, label_mode=\"other\"\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, 'only pass `class_names` if `labels=\"inferred\"`'\n        ):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory,\n                labels=[0, 0, 1, 1, 1],\n                class_names=[\"class_0\", \"class_1\", \"class_2\"],\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Expected the 
lengths of `labels` to match the number of files\",\n        ):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory, labels=[0, 0, 1, 1]\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"`class_names` passed did not match\"\n        ):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory, class_names=[\"class_0\", \"wrong_class\"]\n            )\n\n        with self.assertRaisesRegex(ValueError, \"there must be exactly 2\"):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory, label_mode=\"binary\"\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"`validation_split` must be between 0 and 1\"\n        ):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory, validation_split=2\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, '`subset` must be either \"training\",'\n        ):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory, validation_split=0.2, subset=\"other\"\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"`validation_split` must be set\"\n        ):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory, validation_split=0.0, subset=\"training\"\n            )\n\n        with self.assertRaisesRegex(ValueError, \"must provide a `seed`\"):\n            _ = audio_dataset_utils.audio_dataset_from_directory(\n                directory, validation_split=0.2, subset=\"training\"\n            )\n\n    def test_audio_dataset_from_directory_not_batched(self):\n        directory = self._prepare_directory(num_classes=2, count=2)\n        dataset = audio_dataset_utils.audio_dataset_from_directory(\n            directory,\n            batch_size=None,\n            output_sequence_length=30,\n            
label_mode=None,\n            shuffle=False,\n        )\n        sample = next(iter(dataset))\n        self.assertEqual(len(sample.shape), 2)\n"
  },
  {
    "path": "keras/src/utils/backend_utils.py",
    "content": "import copy\nimport importlib\nimport inspect\nimport os\nimport sys\nimport warnings\n\nfrom keras.src import backend as backend_module\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\n\n\ndef in_tf_graph():\n    if global_state.get_global_attribute(\"in_tf_graph_scope\", False):\n        return True\n\n    if \"tensorflow\" in sys.modules:\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        return not tf.executing_eagerly()\n    return False\n\n\ndef convert_tf_tensor(outputs, dtype=None):\n    if backend_module.backend() != \"tensorflow\" and not in_tf_graph():\n        outputs = backend_module.convert_to_tensor(outputs, dtype=dtype)\n    return outputs\n\n\nclass TFGraphScope:\n    def __init__(self):\n        self._original_value = global_state.get_global_attribute(\n            \"in_tf_graph_scope\", False\n        )\n\n    def __enter__(self):\n        global_state.set_global_attribute(\"in_tf_graph_scope\", True)\n\n    def __exit__(self, *args, **kwargs):\n        global_state.set_global_attribute(\n            \"in_tf_graph_scope\", self._original_value\n        )\n\n\ndef in_grain_data_pipeline():\n    if \"grain\" not in sys.modules:\n        # Fast path to check if grain is not imported.\n        return False\n\n    # We use a lightweight version of `inspect.stack` to detect execution within\n    # grain.\n    current_frame = inspect.currentframe()\n    while current_frame:\n        if (\n            os.path.join(\"grain\", \"_src\", \"python\", \"dataset\")\n            in current_frame.f_code.co_filename\n            or os.path.join(\"grain\", \"_src\", \"python\", \"data_loader\")\n            in current_frame.f_code.co_filename\n        ):\n            return True\n        current_frame = current_frame.f_back\n    return False\n\n\nclass DynamicBackend:\n    \"\"\"A class that can be used to switch from one backend to another.\n\n    Example:\n\n    
```python\n    backend = DynamicBackend(\"tensorflow\")\n    y = backend.square(tf.constant(...))\n    backend.set_backend(\"jax\")\n    y = backend.square(jax.numpy.array(...))\n    ```\n\n    Args:\n        backend: Initial backend to use (string).\n    \"\"\"\n\n    def __init__(self, backend=None):\n        self._backend = backend or backend_module.backend()\n\n    def set_backend(self, backend):\n        if backend not in (\"tensorflow\", \"jax\", \"torch\", \"numpy\", \"openvino\"):\n            raise ValueError(\n                \"Available backends are ('tensorflow', 'jax', 'torch', \"\n                f\"'numpy' and 'openvino'). Received: backend={backend}\"\n            )\n        self._backend = backend\n\n    def reset(self):\n        self._backend = backend_module.backend()\n\n    @property\n    def name(self):\n        return self._backend\n\n    def __getattr__(self, name):\n        if self._backend == \"tensorflow\":\n            module = importlib.import_module(\"keras.src.backend.tensorflow\")\n            return getattr(module, name)\n        if self._backend == \"jax\":\n            module = importlib.import_module(\"keras.src.backend.jax\")\n            return getattr(module, name)\n        if self._backend == \"torch\":\n            module = importlib.import_module(\"keras.src.backend.torch\")\n            return getattr(module, name)\n        if self._backend == \"numpy\":\n            if backend_module.backend() == \"numpy\":\n                return getattr(backend_module, name)\n            else:\n                raise NotImplementedError(\n                    \"Currently, we cannot dynamically import the numpy backend \"\n                    \"because it would disrupt the namespace of the import.\"\n                )\n        if self._backend == \"openvino\":\n            module = importlib.import_module(\"keras.src.backend.openvino\")\n            return getattr(module, name)\n\n\n@keras_export(\"keras.config.set_backend\")\ndef 
set_backend(backend):\n    \"\"\"Reload the backend (and the Keras package).\n\n    Example:\n\n    >>> import os\n    >>> os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n    >>>\n    >>> import keras\n    >>> from keras import ops\n    >>> type(ops.ones(()))\n    <class 'tensorflow.python.framework.ops.EagerTensor'>\n    >>>\n    >>> keras.config.set_backend(\"jax\")\n    UserWarning: Using `keras.config.set_backend` is dangerous...\n    >>> del keras, ops\n    >>>\n    >>> import keras\n    >>> from keras import ops\n    >>> type(ops.ones(()))\n    <class 'jaxlib.xla_extension.ArrayImpl'>\n\n    ⚠️ WARNING ⚠️: Using this function is dangerous and should be done\n    carefully. Changing the backend will **NOT** convert\n    the type of any already-instantiated objects.\n    Thus, any layers / tensors / etc. already created will no\n    longer be usable without errors. It is strongly recommended **not**\n    to keep around **any** Keras-originated objects instances created\n    before calling `set_backend()`.\n\n    This includes any function or class instance that uses any Keras\n    functionality. 
All such code needs to be re-executed after calling\n    `set_backend()` and re-importing all imported `keras` modules.\n    \"\"\"\n    os.environ[\"KERAS_BACKEND\"] = backend\n    # Clear module cache.\n    loaded_modules = [\n        key for key in sys.modules.keys() if key.startswith(\"keras\")\n    ]\n    for key in loaded_modules:\n        del sys.modules[key]\n    # Reimport Keras with the new backend (set via KERAS_BACKEND).\n    import keras\n\n    # Finally: refresh all imported Keras submodules.\n    globs = copy.copy(globals())\n    for key, value in globs.items():\n        if value.__class__ == keras.__class__:\n            if str(value).startswith(\"<module 'keras.\"):\n                module_name = str(value)\n                module_name = module_name[module_name.find(\"'\") + 1 :]\n                module_name = module_name[: module_name.find(\"'\")]\n                globals()[key] = importlib.import_module(module_name)\n\n    warnings.warn(\n        \"Using `keras.config.set_backend` is dangerous and should be done \"\n        \"carefully. Already-instantiated objects will not be converted. Thus, \"\n        \"any layers / tensors / etc. already created will no longer be usable \"\n        \"without errors. It is strongly recommended not to keep around any \"\n        \"Keras-originated objects instances created before calling \"\n        \"`set_backend()`. This includes any function or class instance that \"\n        \"uses any Keras functionality. All such code needs to be re-executed \"\n        \"after calling `set_backend()` and re-importing all imported `keras` \"\n        \"modules.\",\n        stacklevel=2,\n    )\n"
  },
  {
    "path": "keras/src/utils/backend_utils_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.utils import backend_utils\n\n\nclass BackendUtilsTest(testing.TestCase):\n    @parameterized.named_parameters(\n        (\"numpy\", \"numpy\"),\n        (\"jax\", \"jax\"),\n        (\"tensorflow\", \"tensorflow\"),\n        (\"torch\", \"torch\"),\n    )\n    def test_dynamic_backend(self, name):\n        dynamic_backend = backend_utils.DynamicBackend()\n        x = np.random.uniform(size=[1, 2, 3]).astype(\"float32\")\n\n        if name == \"numpy\":\n            dynamic_backend.set_backend(name)\n            if backend.backend() != \"numpy\":\n                with self.assertRaisesRegex(\n                    NotImplementedError,\n                    \"Currently, we cannot dynamically import the numpy backend\",\n                ):\n                    y = dynamic_backend.numpy.log10(x)\n            else:\n                y = dynamic_backend.numpy.log10(x)\n                self.assertIsInstance(y, np.ndarray)\n        elif name == \"jax\":\n            import jax\n\n            dynamic_backend.set_backend(name)\n            y = dynamic_backend.numpy.log10(x)\n            self.assertIsInstance(y, jax.Array)\n        elif name == \"tensorflow\":\n            import tensorflow as tf\n\n            dynamic_backend.set_backend(name)\n            y = dynamic_backend.numpy.log10(x)\n            self.assertIsInstance(y, tf.Tensor)\n        elif name == \"torch\":\n            import torch\n\n            dynamic_backend.set_backend(name)\n            y = dynamic_backend.numpy.log10(x)\n            self.assertIsInstance(y, torch.Tensor)\n\n    def test_dynamic_backend_invalid_name(self):\n        dynamic_backend = backend_utils.DynamicBackend()\n        with self.assertRaisesRegex(ValueError, \"Available backends are\"):\n            dynamic_backend.set_backend(\"abc\")\n"
  },
  {
    "path": "keras/src/utils/code_stats.py",
    "content": "import os\n\n\ndef count_loc(directory, exclude=(\"_test\",), extensions=(\".py\",), verbose=0):\n    loc = 0\n    for root, _, fnames in os.walk(directory):\n        skip = False\n        for ex in exclude:\n            if root.endswith(ex):\n                skip = True\n        if skip:\n            continue\n\n        for fname in fnames:\n            skip = False\n            for ext in extensions:\n                if not fname.endswith(ext):\n                    skip = True\n                    break\n\n                for ex in exclude:\n                    if fname.endswith(ex + ext):\n                        skip = True\n                        break\n            if skip:\n                continue\n\n            fname = os.path.join(root, fname)\n            if verbose:\n                print(f\"Count LoCs in {fname}\")\n\n            with open(fname) as f:\n                lines = f.read().split(\"\\n\")\n\n            string_open = False\n            for line in lines:\n                line = line.strip()\n                if not line or line.startswith(\"#\"):\n                    continue\n                if not string_open:\n                    if not line.startswith('\"\"\"'):\n                        loc += 1\n                    else:\n                        if not line.endswith('\"\"\"'):\n                            string_open = True\n                else:\n                    if line.startswith('\"\"\"'):\n                        string_open = False\n    return loc\n"
  },
  {
    "path": "keras/src/utils/code_stats_test.py",
    "content": "import os\nimport sys\nfrom io import StringIO\n\nfrom keras.src.testing import test_case\nfrom keras.src.utils.code_stats import count_loc\n\n\nclass TestCountLoc(test_case.TestCase):\n    def setUp(self):\n        self.test_dir = self.get_temp_dir()\n\n    def create_file(self, filename, content):\n        with open(\n            os.path.join(self.test_dir, filename), \"w\", encoding=\"utf-8\"\n        ) as f:\n            f.write(content)\n\n    def test_count_loc_valid_python(self):\n        self.create_file(\n            \"sample.py\", \"# This is a test file\\n\\nprint('Hello')\\n\"\n        )\n        loc = count_loc(self.test_dir)\n        self.assertEqual(loc, 1)\n\n    def test_exclude_test_files(self):\n        self.create_file(\"sample_test.py\", \"print('Hello')\\n\")\n        loc = count_loc(self.test_dir, exclude=(\"_test\",))\n        self.assertEqual(loc, 0)\n\n    def test_other_extensions(self):\n        self.create_file(\"sample.txt\", \"Hello\\n\")\n        loc = count_loc(self.test_dir, extensions=(\".py\",))\n        self.assertEqual(loc, 0)\n\n    def test_comment_lines(self):\n        self.create_file(\n            \"sample.py\", \"# Comment\\nprint('Hello')\\n# Another comment\\n\"\n        )\n        loc = count_loc(self.test_dir)\n        self.assertEqual(loc, 1)\n\n    def test_empty_file(self):\n        self.create_file(\"empty.py\", \"\")\n        loc = count_loc(self.test_dir)\n        self.assertEqual(loc, 0)\n\n    def test_whitespace_only(self):\n        self.create_file(\"whitespace.py\", \"     \\n\\t\\n\")\n        loc = count_loc(self.test_dir)\n        self.assertEqual(loc, 0)\n\n    def test_inline_comments_after_code(self):\n        content = 'print(\"Hello\") # This is an inline comment'\n        self.create_file(\"inline_comment_sample.py\", content)\n        loc = count_loc(self.test_dir)\n        self.assertEqual(loc, 1)  # The comment shouldn't affect the count\n\n    def 
test_directory_structure(self):\n        content1 = 'print(\"Hello from file1\")'\n        content2 = 'print(\"Hello from file2\")'\n        os.mkdir(os.path.join(self.test_dir, \"subdir\"))\n        self.create_file(\"sample1.py\", content1)\n        self.create_file(os.path.join(\"subdir\", \"sample2.py\"), content2)\n        loc = count_loc(self.test_dir)\n        self.assertEqual(loc, 2)  # Both files should be counted\n\n    def test_normal_directory_name(self):\n        content = 'print(\"Hello from a regular directory\")'\n        os.makedirs(os.path.join(self.test_dir, \"some_test_dir\"))\n        self.create_file(os.path.join(\"some_test_dir\", \"sample.py\"), content)\n        loc = count_loc(self.test_dir)\n        self.assertEqual(loc, 1)  # Should count normally\n\n    def test_exclude_directory_name(self):\n        content = 'print(\"Hello from an excluded directory\")'\n        os.makedirs(os.path.join(self.test_dir, \"dir_test\"))\n        self.create_file(os.path.join(\"dir_test\", \"sample.py\"), content)\n        loc = count_loc(self.test_dir)\n        self.assertEqual(loc, 0)\n        # Shouldn't count the file in dir_test due to the exclusion pattern\n\n    def test_verbose_output(self):\n        content = 'print(\"Hello\")'\n        self.create_file(\"sample.py\", content)\n        original_stdout = sys.stdout\n        sys.stdout = StringIO()\n        count_loc(self.test_dir, verbose=1)\n        output = sys.stdout.getvalue()\n        sys.stdout = original_stdout\n        self.assertIn(\"Count LoCs in\", output)\n\n    def test_multiline_string_same_line(self):\n        content = '''\"\"\"This is a multiline string ending on the same line\"\"\"\n        print(\"Outside string\")'''\n        self.create_file(\"same_line_multiline.py\", content)\n        loc = count_loc(self.test_dir)\n        self.assertEqual(loc, 1)  # Only the print statement should count\n\n    def test_multiline_string_ends_on_same_line(self):\n        content = '\"\"\"a 
multiline string end on same line\"\"\"\\nprint(\"Outstr\")'\n        self.create_file(\"same_line_multiline.py\", content)\n        loc = count_loc(self.test_dir)\n        self.assertEqual(loc, 1)  # Only the print statement should count\n\n    def test_multiline_string_ends_in_middle_of_line(self):\n        content = '''print(\"Start\")\n        \"\"\"This is a multiline string ending in the middle of a line\"\"\"\n        \"\"\"This is another multiline string.\"\"\"\n        print(\"End\")'''\n        self.create_file(\"multiline_in_middle.py\", content)\n        loc = count_loc(self.test_dir)\n        self.assertEqual(loc, 2)  # Both print statements should count\n\n    def test_line_starting_with_triple_quotes_not_ending(self):\n        content = '\"\"\"\\nThis is a multiline string\\n'\n        self.create_file(\"test_file_2.py\", content)\n        path = os.path.join(self.test_dir, \"test_file_2.py\")\n        self.assertEqual(count_loc(path), 0)\n        # Because it's part of a multiline string\n\n    def test_line_starting_and_ending_with_triple_quotes(self):\n        content = '\"\"\"This is a one-liner docstring.\"\"\"\\n'\n        self.create_file(\"test_file_3.py\", content)\n        path = os.path.join(self.test_dir, \"test_file_3.py\")\n        self.assertEqual(count_loc(path), 0)\n        # This is still considered a comment/docstring\n\n    def test_string_open_true_line_starting_with_triple_quotes(self):\n        content = '\"\"\"\\nEnd of the multiline string.\"\"\"\\n'\n        self.create_file(\"test_file_4.py\", content)\n        path = os.path.join(self.test_dir, \"test_file_4.py\")\n        self.assertEqual(count_loc(path), 0)\n        # Entire content is a multiline string/comment\n"
  },
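The tests above pin down the counting rules for `count_loc`: blank lines, `#` comments, and triple-quoted string blocks are excluded, while inline comments after code do not disqualify a line. As a rough illustration of those rules (a simplified sketch, not the actual `keras.src.utils.code_stats` implementation; `count_loc_in_lines` is a hypothetical helper name, and it only tracks `"""`-delimited strings that start a line):

```python
def count_loc_in_lines(lines):
    """Count lines of code, skipping blanks, `#` comments, and
    triple-double-quoted string blocks (simplified sketch)."""
    loc = 0
    in_string = False
    for line in lines:
        stripped = line.strip()
        if in_string:
            # Inside a multiline string: only look for the closing quotes.
            if '"""' in stripped:
                in_string = False
            continue
        if not stripped or stripped.startswith("#"):
            continue  # blank line or comment-only line
        if stripped.startswith('"""'):
            # A one-liner docstring closes on the same line; otherwise a
            # multiline string opens here.
            if not (stripped.endswith('"""') and len(stripped) >= 6):
                in_string = True
            continue
        loc += 1  # code line (inline trailing comments don't matter)
    return loc
```

This mirrors the behavior the tests assert, e.g. a one-liner docstring counts as 0 while `print("x")  # inline` counts as 1.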
  {
    "path": "keras/src/utils/config.py",
    "content": "import copy\nimport json\n\ntry:\n    import difflib\nexcept ImportError:\n    difflib = None\n\nfrom keras.src.api_export import keras_export\n\n\n@keras_export(\"keras.utils.Config\")\nclass Config:\n    \"\"\"A Config is a dict-like container for named values.\n\n    It offers a few advantages over a plain dict:\n\n    - Setting and retrieving values via attribute setting / getting.\n    - Ability to freeze the config to ensure no accidental config modifications\n        occur past a certain point in your program.\n    - Easy serialization of the whole config as JSON.\n\n    Examples:\n\n    ```python\n    # Create a config via constructor arguments\n    config = Config(\"learning_rate\"=0.1, \"momentum\"=0.9)\n\n    # Then keep adding to it via attribute-style setting\n    config.use_ema = True\n    config.ema_overwrite_frequency = 100\n\n    # You can also add attributes via dict-like access\n    config[\"seed\"] = 123\n\n    # You can retrieve entries both via attribute-style\n    # access and dict-style access\n    assert config.seed == 100\n    assert config[\"learning_rate\"] == 0.1\n    ```\n\n    A config behaves like a dict:\n\n    ```python\n    config = Config(\"learning_rate\"=0.1, \"momentum\"=0.9)\n    for k, v in config.items():\n        print(f\"{k}={v}\")\n\n    print(f\"keys: {list(config.keys())}\")\n    print(f\"values: {list(config.values())}\")\n    ```\n\n    In fact, it can be turned into one:\n\n    ```python\n    config = Config(\"learning_rate\"=0.1, \"momentum\"=0.9)\n    dict_config = config.as_dict()\n    ```\n\n    You can easily serialize a config to JSON:\n\n    ```python\n    config = Config(\"learning_rate\"=0.1, \"momentum\"=0.9)\n\n    json_str = config.to_json()\n    ```\n\n    You can also freeze a config to prevent further changes:\n\n    ```python\n    config = Config()\n    config.optimizer = \"adam\"\n    config.seed = 123\n\n    # Freeze the config to prevent changes.\n    config.freeze()\n    assert 
config.frozen\n\n    config.foo = \"bar\"  # This will raise an error.\n    ```\n    \"\"\"\n\n    __attrs__ = None\n\n    def __init__(self, **kwargs):\n        self._config = kwargs\n        self._frozen = False\n        self.__attrs__ = set(dir(self))\n\n    @property\n    def frozen(self):\n        \"\"\"Returns True if the config is frozen.\"\"\"\n        return self._frozen\n\n    def freeze(self):\n        \"\"\"Marks the config as frozen, preventing any further modification.\"\"\"\n        self._frozen = True\n\n    def unfreeze(self):\n        \"\"\"Unfreezes the config, allowing modifications again.\"\"\"\n        self._frozen = False\n\n    def _raise_if_frozen(self):\n        if self._frozen:\n            raise ValueError(\n                \"Cannot mutate attribute(s) because the config is frozen.\"\n            )\n\n    def as_dict(self):\n        return copy.copy(self._config)\n\n    def to_json(self):\n        return json.dumps(self._config)\n\n    def keys(self):\n        return self._config.keys()\n\n    def values(self):\n        return self._config.values()\n\n    def items(self):\n        return self._config.items()\n\n    def pop(self, *args):\n        self._raise_if_frozen()\n        return self._config.pop(*args)\n\n    def update(self, *args, **kwargs):\n        self._raise_if_frozen()\n        return self._config.update(*args, **kwargs)\n\n    def get(self, keyname, value=None):\n        return self._config.get(keyname, value)\n\n    def __setattr__(self, name, value):\n        attrs = object.__getattribute__(self, \"__attrs__\")\n        if attrs is None or name in attrs:\n            return object.__setattr__(self, name, value)\n\n        self._raise_if_frozen()\n        self._config[name] = value\n\n    def __getattr__(self, name):\n        attrs = object.__getattribute__(self, \"__attrs__\")\n        if attrs is None or name in attrs:\n            return object.__getattribute__(self, name)\n\n        if name in self._config:\n            return self._config[name]\n\n        msg = f\"Unknown attribute: 
'{name}'.\"\n        if difflib is not None:\n            closest_matches = difflib.get_close_matches(\n                name, self._config.keys(), n=1, cutoff=0.7\n            )\n            if closest_matches:\n                msg += f\" Did you mean '{closest_matches[0]}'?\"\n        raise AttributeError(msg)\n\n    def __setitem__(self, key, item):\n        self._raise_if_frozen()\n        self._config[key] = item\n\n    def __getitem__(self, key):\n        return self._config[key]\n\n    def __repr__(self):\n        return f\"<Config {self._config}>\"\n\n    def __iter__(self):\n        keys = sorted(self._config.keys())\n        for k in keys:\n            yield k\n\n    def __len__(self):\n        return len(self._config)\n\n    def __delitem__(self, key):\n        self._raise_if_frozen()\n        del self._config[key]\n\n    def __contains__(self, item):\n        return item in self._config\n"
  },
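The attribute/dict duality and freezing semantics documented in `Config` can be exercised without Keras via a stripped-down stand-in. `MiniConfig` below is a hypothetical simplification that mirrors only the behavior shown in the docstring (attribute- and dict-style access, `freeze()`, JSON export), not the real `keras.utils.Config`:

```python
import json


class MiniConfig:
    """Minimal stand-in mirroring the documented Config behavior."""

    def __init__(self, **kwargs):
        # Bypass our own __setattr__ while setting internal state.
        object.__setattr__(self, "_config", dict(kwargs))
        object.__setattr__(self, "_frozen", False)

    @property
    def frozen(self):
        return self._frozen

    def freeze(self):
        object.__setattr__(self, "_frozen", True)

    def __setattr__(self, name, value):
        if self._frozen:
            raise ValueError("Cannot mutate: the config is frozen.")
        self._config[name] = value

    # Dict-style assignment shares the same frozen check.
    __setitem__ = __setattr__

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        try:
            return object.__getattribute__(self, "_config")[name]
        except KeyError:
            raise AttributeError(f"Unknown attribute: {name!r}")

    def __getitem__(self, key):
        return self._config[key]

    def to_json(self):
        return json.dumps(self._config)
```

Usage follows the docstring: `cfg = MiniConfig(learning_rate=0.1)`, then `cfg.seed = 123` or `cfg["seed"] = 123`; after `cfg.freeze()`, any further assignment raises `ValueError`.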
  {
    "path": "keras/src/utils/dataset_utils.py",
    "content": "import os\nimport random\nimport time\nimport warnings\nfrom multiprocessing.pool import ThreadPool\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils import file_utils\nfrom keras.src.utils import io_utils\nfrom keras.src.utils.module_utils import grain\n\n\n@keras_export(\"keras.utils.split_dataset\")\ndef split_dataset(\n    dataset,\n    left_size=None,\n    right_size=None,\n    shuffle=False,\n    seed=None,\n    preferred_backend=None,\n):\n    \"\"\"Splits a dataset into a left half and a right half (e.g. train / test).\n\n    Args:\n        dataset:\n            A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,\n            or a list/tuple of arrays with the same length.\n        left_size: If float (in the range `[0, 1]`), it signifies\n            the fraction of the data to pack in the left dataset. If integer, it\n            signifies the number of samples to pack in the left dataset. If\n            `None`, defaults to the complement to `right_size`.\n            Defaults to `None`.\n        right_size: If float (in the range `[0, 1]`), it signifies\n            the fraction of the data to pack in the right dataset.\n            If integer, it signifies the number of samples to pack\n            in the right dataset.\n            If `None`, defaults to the complement to `left_size`.\n            Defaults to `None`.\n        shuffle: Boolean, whether to shuffle the data before splitting it.\n        seed: A random seed for shuffling.\n        preferred_backend: String, specifying which backend\n            (e.g.; \"tensorflow\", \"torch\") to use. 
If `None`, the\n            backend is inferred from the type of `dataset` - if\n            `dataset` is a `tf.data.Dataset`, \"tensorflow\" backend\n            is used, if `dataset` is a `torch.utils.data.Dataset`,\n            \"torch\" backend is used, and if `dataset` is a list/tuple/np.array\n            the current Keras backend is used. Defaults to `None`.\n\n    Returns:\n        A tuple of two dataset objects, the left and right splits. The exact\n        type of the returned objects depends on the `preferred_backend`.\n        For example, with a \"tensorflow\" backend,\n        `tf.data.Dataset` objects are returned. With a \"torch\" backend,\n        `torch.utils.data.Dataset` objects are returned.\n    Example:\n\n    >>> data = np.random.random(size=(1000, 4))\n    >>> left_ds, right_ds = keras.utils.split_dataset(data, left_size=0.8)\n    >>> # For a tf.data.Dataset, you can use .cardinality()\n    >>> # >>> int(left_ds.cardinality())\n    >>> # 800\n    >>> # For a torch.utils.data.Dataset, you can use len()\n    >>> # >>> len(left_ds)\n    >>> # 800\n    \"\"\"\n    preferred_backend = preferred_backend or _infer_preferred_backend(dataset)\n    if preferred_backend != \"torch\":\n        return _split_dataset_tf(\n            dataset,\n            left_size=left_size,\n            right_size=right_size,\n            shuffle=shuffle,\n            seed=seed,\n        )\n    else:\n        return _split_dataset_torch(\n            dataset,\n            left_size=left_size,\n            right_size=right_size,\n            shuffle=shuffle,\n            seed=seed,\n        )\n\n\ndef _split_dataset_tf(\n    dataset, left_size=None, right_size=None, shuffle=False, seed=None\n):\n    \"\"\"Splits a dataset into a left half and a right half (e.g. 
train / test).\n\n    Args:\n        dataset:\n            A `tf.data.Dataset` object,\n            or a list/tuple of arrays with the same length.\n        left_size: If float (in the range `[0, 1]`), it signifies\n            the fraction of the data to pack in the left dataset. If integer, it\n            signifies the number of samples to pack in the left dataset. If\n            `None`, defaults to the complement to `right_size`.\n            Defaults to `None`.\n        right_size: If float (in the range `[0, 1]`), it signifies\n            the fraction of the data to pack in the right dataset.\n            If integer, it signifies the number of samples to pack\n            in the right dataset.\n            If `None`, defaults to the complement to `left_size`.\n            Defaults to `None`.\n        shuffle: Boolean, whether to shuffle the data before splitting it.\n        seed: A random seed for shuffling.\n\n    Returns:\n        A tuple of two `tf.data.Dataset` objects:\n        the left and right splits.\n    \"\"\"\n    from keras.src.utils.module_utils import tensorflow as tf\n\n    dataset_type_spec = _get_type_spec(dataset)\n\n    if dataset_type_spec is None:\n        raise TypeError(\n            \"The `dataset` argument must be either\"\n            \"a `tf.data.Dataset` object, or\"\n            \"a list/tuple of arrays. \"\n            f\"Received: dataset={dataset} of type {type(dataset)}\"\n        )\n\n    if right_size is None and left_size is None:\n        raise ValueError(\n            \"At least one of the `left_size` or `right_size` \"\n            \"must be specified. 
Received: left_size=None and \"\n            \"right_size=None\"\n        )\n\n    dataset_as_list = _convert_dataset_to_list(dataset, dataset_type_spec)\n\n    if shuffle:\n        if seed is None:\n            seed = random.randint(0, int(1e6))\n        random.seed(seed)\n        random.shuffle(dataset_as_list)\n\n    total_length = len(dataset_as_list)\n\n    left_size, right_size = _rescale_dataset_split_sizes(\n        left_size, right_size, total_length\n    )\n    left_split = list(dataset_as_list[:left_size])\n    right_split = list(dataset_as_list[-right_size:])\n\n    left_split = _restore_dataset_from_list(\n        left_split, dataset_type_spec, dataset\n    )\n    right_split = _restore_dataset_from_list(\n        right_split, dataset_type_spec, dataset\n    )\n\n    left_split = tf.data.Dataset.from_tensor_slices(left_split)\n    right_split = tf.data.Dataset.from_tensor_slices(right_split)\n\n    # apply batching to the splits if the dataset is batched\n    if dataset_type_spec is tf.data.Dataset and is_batched(dataset):\n        batch_size = get_batch_size(dataset)\n        if batch_size is not None:\n            left_split = left_split.batch(batch_size)\n            right_split = right_split.batch(batch_size)\n\n    left_split = left_split.prefetch(tf.data.AUTOTUNE)\n    right_split = right_split.prefetch(tf.data.AUTOTUNE)\n    return left_split, right_split\n\n\ndef _split_dataset_torch(\n    dataset, left_size=None, right_size=None, shuffle=False, seed=None\n):\n    \"\"\"Splits a dataset into a left half and a right half (e.g. train / test).\n\n    Args:\n        dataset:\n            A `torch.utils.data.Dataset` object,\n            or a list/tuple of arrays with the same length.\n        left_size: If float (in the range `[0, 1]`), it signifies\n            the fraction of the data to pack in the left dataset. If integer, it\n            signifies the number of samples to pack in the left dataset. 
If\n            `None`, defaults to the complement to `right_size`.\n            Defaults to `None`.\n        right_size: If float (in the range `[0, 1]`), it signifies\n            the fraction of the data to pack in the right dataset.\n            If integer, it signifies the number of samples to pack\n            in the right dataset.\n            If `None`, defaults to the complement to `left_size`.\n            Defaults to `None`.\n        shuffle: Boolean, whether to shuffle the data before splitting it.\n        seed: A random seed for shuffling.\n\n    Returns:\n        A tuple of two `torch.utils.data.Dataset` objects:\n        the left and right splits.\n    \"\"\"\n    import torch\n    from torch.utils.data import TensorDataset\n    from torch.utils.data import random_split\n\n    dataset_type_spec = _get_type_spec(dataset)\n    if dataset_type_spec is None:\n        raise TypeError(\n            \"The `dataset` argument must be a `torch.utils.data.Dataset`\"\n            \" object, or a list/tuple of arrays.\"\n            f\" Received: dataset={dataset} of type {type(dataset)}\"\n        )\n\n    if not isinstance(dataset, torch.utils.data.Dataset):\n        if dataset_type_spec is np.ndarray:\n            dataset = TensorDataset(torch.from_numpy(dataset))\n        elif dataset_type_spec in (list, tuple):\n            tensors = [torch.from_numpy(x) for x in dataset]\n            dataset = TensorDataset(*tensors)\n        elif is_tf_dataset(dataset):\n            dataset_as_list = _convert_dataset_to_list(\n                dataset, dataset_type_spec\n            )\n            tensors = [\n                torch.from_numpy(np.array(sample))\n                for sample in zip(*dataset_as_list)\n            ]\n            dataset = TensorDataset(*tensors)\n\n    if right_size is None and left_size is None:\n        raise ValueError(\n            \"At least one of the `left_size` or `right_size` \"\n            \"must be specified. 
\"\n            \"Received: left_size=None and right_size=None\"\n        )\n\n    # Calculate total length and rescale split sizes\n    total_length = len(dataset)\n    left_size, right_size = _rescale_dataset_split_sizes(\n        left_size, right_size, total_length\n    )\n\n    # Shuffle the dataset if required\n    if shuffle:\n        generator = torch.Generator()\n        if seed is not None:\n            generator.manual_seed(seed)\n        else:\n            generator.seed()\n    else:\n        generator = None\n\n    left_split, right_split = random_split(\n        dataset, [left_size, right_size], generator=generator\n    )\n\n    return left_split, right_split\n\n\ndef _infer_preferred_backend(dataset):\n    \"\"\"Infer the backend from the dataset type.\"\"\"\n    if isinstance(dataset, (list, tuple, np.ndarray)):\n        return backend.backend()\n    if is_tf_dataset(dataset):\n        return \"tensorflow\"\n    elif is_torch_dataset(dataset):\n        return \"torch\"\n    else:\n        raise TypeError(f\"Unsupported dataset type: {type(dataset)}\")\n\n\ndef _convert_dataset_to_list(\n    dataset,\n    dataset_type_spec,\n    data_size_warning_flag=True,\n    ensure_shape_similarity=True,\n):\n    \"\"\"Convert `dataset` object to a list of samples.\n\n    Args:\n        dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,\n            or a list/tuple of arrays.\n        dataset_type_spec: the type of the dataset.\n        data_size_warning_flag: If set to `True`, a warning will\n            be issued if the dataset takes longer than 10 seconds to iterate.\n            Defaults to `True`.\n        ensure_shape_similarity: If set to `True`, the shape of\n            the first sample will be used to validate the shape of rest of the\n            samples. 
Defaults to `True`.\n\n    Returns:\n        List: A list of samples.\n    \"\"\"\n    dataset_iterator = _get_data_iterator_from_dataset(\n        dataset, dataset_type_spec\n    )\n    dataset_as_list = []\n\n    start_time = time.time()\n    for sample in _get_next_sample(\n        dataset_iterator,\n        ensure_shape_similarity,\n        data_size_warning_flag,\n        start_time,\n    ):\n        dataset_as_list.append(sample)\n\n    return dataset_as_list\n\n\ndef _get_data_iterator_from_dataset(dataset, dataset_type_spec):\n    \"\"\"Get the iterator from a dataset.\n\n    Args:\n        dataset: A `tf.data.Dataset`, a `torch.utils.data.Dataset` object,\n            or a list/tuple of arrays.\n        dataset_type_spec: The type of the dataset.\n\n    Returns:\n        iterator: An `iterator` object.\n    \"\"\"\n    if dataset_type_spec is list:\n        if len(dataset) == 0:\n            raise ValueError(\n                \"Received an empty list dataset. \"\n                \"Please provide a non-empty list of arrays.\"\n            )\n\n        expected_shape = None\n        for i, element in enumerate(dataset):\n            if not isinstance(element, np.ndarray):\n                raise ValueError(\n                    \"Expected a list of `numpy.ndarray` objects,\"\n                    f\"Received: {type(element)} at index {i}.\"\n                )\n            if expected_shape is None:\n                expected_shape = element.shape\n            elif element.shape[0] != expected_shape[0]:\n                raise ValueError(\n                    \"Received a list of NumPy arrays with different lengths.\"\n                    f\"Mismatch found at index {i}, \"\n                    f\"Expected shape={expected_shape} \"\n                    f\"Received shape={np.array(element).shape}.\"\n                    \"Please provide a list of NumPy arrays of the same length.\"\n                )\n\n        return iter(zip(*dataset))\n    elif dataset_type_spec 
is tuple:\n        if len(dataset) == 0:\n            raise ValueError(\n                \"Received an empty list dataset.\"\n                \"Please provide a non-empty tuple of arrays.\"\n            )\n\n        expected_shape = None\n        for i, element in enumerate(dataset):\n            if not isinstance(element, np.ndarray):\n                raise ValueError(\n                    \"Expected a tuple of `numpy.ndarray` objects,\"\n                    f\"Received: {type(element)} at index {i}.\"\n                )\n            if expected_shape is None:\n                expected_shape = element.shape\n            elif element.shape[0] != expected_shape[0]:\n                raise ValueError(\n                    \"Received a tuple of NumPy arrays with different lengths.\"\n                    f\"Mismatch found at index {i}, \"\n                    f\"Expected shape={expected_shape} \"\n                    f\"Received shape={np.array(element).shape}.\"\n                    \"Please provide a tuple of NumPy arrays of the same length.\"\n                )\n\n        return iter(zip(*dataset))\n    elif is_tf_dataset(dataset):\n        if is_batched(dataset):\n            dataset = dataset.unbatch()\n        return iter(dataset)\n\n    elif is_torch_dataset(dataset):\n        return iter(dataset)\n    elif dataset_type_spec is np.ndarray:\n        return iter(dataset)\n    raise ValueError(f\"Invalid dataset_type_spec: {dataset_type_spec}\")\n\n\ndef _get_next_sample(\n    dataset_iterator,\n    ensure_shape_similarity,\n    data_size_warning_flag,\n    start_time,\n):\n    \"\"\"Yield data samples from the `dataset_iterator`.\n\n    Args:\n        dataset_iterator: An `iterator` object.\n        ensure_shape_similarity: If set to `True`, the shape of\n            the first sample will be used to validate the shape of rest of the\n            samples. 
Defaults to `True`.\n        data_size_warning_flag: If set to `True`, a warning will\n            be issued if the dataset takes longer than 10 seconds to iterate.\n            Defaults to `True`.\n        start_time (float): the start time of the dataset iteration. this is\n            used only if `data_size_warning_flag` is set to true.\n\n    Yields:\n        data_sample: The next sample.\n    \"\"\"\n    from keras.src.trainers.data_adapters.data_adapter_utils import (\n        is_tensorflow_tensor,\n    )\n    from keras.src.trainers.data_adapters.data_adapter_utils import (\n        is_torch_tensor,\n    )\n\n    try:\n        dataset_iterator = iter(dataset_iterator)\n        first_sample = next(dataset_iterator)\n        if (\n            isinstance(first_sample, np.ndarray)\n            or is_tensorflow_tensor(first_sample)\n            or is_torch_tensor(first_sample)\n        ):\n            first_sample_shape = np.array(first_sample).shape\n        else:\n            first_sample_shape = None\n            ensure_shape_similarity = False\n        yield first_sample\n    except StopIteration:\n        raise ValueError(\n            \"Received an empty dataset. 
Argument `dataset` must \"\n            \"be a non-empty list/tuple of `numpy.ndarray` objects \"\n            \"or `tf.data.Dataset` objects.\"\n        )\n\n    for i, sample in enumerate(dataset_iterator):\n        if ensure_shape_similarity:\n            if first_sample_shape != np.array(sample).shape:\n                raise ValueError(\n                    \"All `dataset` samples must have same shape, \"\n                    f\"Expected shape: {np.array(first_sample).shape} \"\n                    f\"Received shape: {np.array(sample).shape} at index \"\n                    f\"{i}.\"\n                )\n        if data_size_warning_flag:\n            if i % 10 == 0:\n                cur_time = time.time()\n                # warns user if the dataset is too large to iterate within 10s\n                if int(cur_time - start_time) > 10 and data_size_warning_flag:\n                    warnings.warn(\n                        \"The dataset is taking longer than 10 seconds to \"\n                        \"iterate over. This may be due to the size of the \"\n                        \"dataset. Keep in mind that the `split_dataset` \"\n                        \"utility is only for small in-memory dataset \"\n                        \"(e.g. 
< 10,000 samples).\",\n                        category=ResourceWarning,\n                        source=\"split_dataset\",\n                    )\n                    data_size_warning_flag = False\n        yield sample\n\n\ndef is_tf_dataset(dataset):\n    return _mro_matches(\n        dataset,\n        class_names=(\"DatasetV2\", \"Dataset\"),\n        module_substrings=(\n            \"tensorflow.python.data\",  # TF classic\n            \"tensorflow.data\",  # newer TF paths\n        ),\n    )\n\n\ndef is_grain_dataset(dataset):\n    return _mro_matches(\n        dataset,\n        class_names=(\"MapDataset\", \"IterDataset\"),\n        module_prefixes=(\"grain._src.python\",),\n    )\n\n\ndef is_torch_dataset(dataset):\n    return _mro_matches(dataset, (\"Dataset\",), (\"torch.utils.data\",))\n\n\ndef _mro_matches(\n    dataset, class_names, module_prefixes=(), module_substrings=()\n):\n    if not hasattr(dataset, \"__class__\"):\n        return False\n    for parent in dataset.__class__.__mro__:\n        if parent.__name__ in class_names:\n            mod = str(parent.__module__)\n            if any(mod.startswith(pref) for pref in module_prefixes):\n                return True\n            if any(subs in mod for subs in module_substrings):\n                return True\n    return False\n\n\ndef _rescale_dataset_split_sizes(left_size, right_size, total_length):\n    \"\"\"Rescale the dataset split sizes.\n\n    We want to ensure that the sum of\n    the split sizes is equal to the total length of the dataset.\n\n    Args:\n        left_size: The size of the left dataset split.\n        right_size: The size of the right dataset split.\n        total_length: The total length of the dataset.\n\n    Returns:\n        tuple: A tuple of rescaled `left_size` and `right_size` integers.\n    \"\"\"\n    left_size_type = type(left_size)\n    right_size_type = type(right_size)\n\n    # check both left_size and right_size are integers or floats\n    if (left_size is not 
None and left_size_type not in [int, float]) and (\n        right_size is not None and right_size_type not in [int, float]\n    ):\n        raise TypeError(\n            \"Invalid `left_size` and `right_size` Types. Expected: \"\n            \"integer or float or None, Received: type(left_size)=\"\n            f\"{left_size_type} and type(right_size)={right_size_type}\"\n        )\n\n    # check left_size is a integer or float\n    if left_size is not None and left_size_type not in [int, float]:\n        raise TypeError(\n            \"Invalid `left_size` Type. Expected: int or float or None, \"\n            f\"Received: type(left_size)={left_size_type}.  \"\n        )\n\n    # check right_size is a integer or float\n    if right_size is not None and right_size_type not in [int, float]:\n        raise TypeError(\n            \"Invalid `right_size` Type. \"\n            \"Expected: int or float or None,\"\n            f\"Received: type(right_size)={right_size_type}.\"\n        )\n\n    # check left_size and right_size are non-zero\n    if left_size == 0 and right_size == 0:\n        raise ValueError(\n            \"Both `left_size` and `right_size` are zero. \"\n            \"At least one of the split sizes must be non-zero.\"\n        )\n\n    # check left_size is non-negative and less than 1 and less than total_length\n    if (\n        left_size_type is int\n        and (left_size <= 0 or left_size >= total_length)\n        or left_size_type is float\n        and (left_size <= 0 or left_size >= 1)\n    ):\n        raise ValueError(\n            \"`left_size` should be either a positive integer \"\n            f\"smaller than {total_length}, or a float \"\n            \"within the range `[0, 1]`. 
Received: left_size=\"\n            f\"{left_size}\"\n        )\n\n    # check right_size is non-negative and less than 1 and less than\n    # total_length\n    if (\n        right_size_type is int\n        and (right_size <= 0 or right_size >= total_length)\n        or right_size_type is float\n        and (right_size <= 0 or right_size >= 1)\n    ):\n        raise ValueError(\n            \"`right_size` should be either a positive integer \"\n            f\"and smaller than {total_length} or a float \"\n            \"within the range `[0, 1]`. Received: right_size=\"\n            f\"{right_size}\"\n        )\n\n    # check sum of left_size and right_size is less than or equal to\n    # total_length\n    if (\n        right_size_type is left_size_type is float\n        and right_size + left_size > 1\n    ):\n        raise ValueError(\n            \"The sum of `left_size` and `right_size` is greater \"\n            \"than 1. It must be less than or equal to 1.\"\n        )\n\n    if left_size_type is float:\n        left_size = round(left_size * total_length)\n    elif left_size_type is int:\n        left_size = float(left_size)\n\n    if right_size_type is float:\n        right_size = round(right_size * total_length)\n    elif right_size_type is int:\n        right_size = float(right_size)\n\n    if left_size is None:\n        left_size = total_length - right_size\n    elif right_size is None:\n        right_size = total_length - left_size\n\n    if left_size + right_size > total_length:\n        raise ValueError(\n            \"The sum of `left_size` and `right_size` should \"\n            f\"be smaller than the {total_length}. 
\"\n            f\"Received: left_size + right_size = {left_size + right_size}\"\n            f\"and total_length = {total_length}\"\n        )\n\n    for split, side in [(left_size, \"left\"), (right_size, \"right\")]:\n        if split == 0:\n            raise ValueError(\n                f\"With `dataset` of length={total_length}, `left_size`=\"\n                f\"{left_size} and `right_size`={right_size}.\"\n                f\"Resulting {side} side dataset split will be empty. \"\n                \"Adjust any of the aforementioned parameters\"\n            )\n\n    left_size, right_size = int(left_size), int(right_size)\n    return left_size, right_size\n\n\ndef _restore_dataset_from_list(\n    dataset_as_list, dataset_type_spec, original_dataset\n):\n    \"\"\"Restore the dataset from the list of arrays.\"\"\"\n    if (\n        dataset_type_spec in [tuple, list]\n        or is_tf_dataset(original_dataset)\n        or is_torch_dataset(original_dataset)\n    ):\n        # Save structure by taking the first element.\n        element_spec = dataset_as_list[0]\n        # Flatten each element.\n        dataset_as_list = [tree.flatten(sample) for sample in dataset_as_list]\n        # Combine respective elements at all indices.\n        dataset_as_list = [np.array(sample) for sample in zip(*dataset_as_list)]\n        # Recreate the original structure of elements.\n        dataset_as_list = tree.pack_sequence_as(element_spec, dataset_as_list)\n        # Turn lists to tuples as tf.data will fail on lists.\n        return tree.traverse(\n            lambda x: tuple(x) if isinstance(x, list) else x,\n            dataset_as_list,\n            top_down=False,\n        )\n\n    return dataset_as_list\n\n\ndef is_batched(dataset):\n    \"\"\"Check if the `tf.data.Dataset` is batched.\"\"\"\n    return hasattr(dataset, \"_batch_size\")\n\n\ndef get_batch_size(dataset):\n    \"\"\"Get the batch size of the dataset.\"\"\"\n    if is_batched(dataset):\n        return 
dataset._batch_size\n    else:\n        return None\n\n\ndef _get_type_spec(dataset):\n    \"\"\"Get the type spec of the dataset.\"\"\"\n    if isinstance(dataset, tuple):\n        return tuple\n    elif isinstance(dataset, list):\n        return list\n    elif isinstance(dataset, np.ndarray):\n        return np.ndarray\n    elif is_tf_dataset(dataset):\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        return tf.data.Dataset\n    elif is_torch_dataset(dataset):\n        from torch.utils.data import Dataset as TorchDataset\n\n        return TorchDataset\n    elif is_grain_dataset(dataset):\n        from grain import MapDataset\n\n        return MapDataset\n    else:\n        return None\n\n\ndef index_directory(\n    directory,\n    labels,\n    formats,\n    class_names=None,\n    shuffle=True,\n    seed=None,\n    follow_links=False,\n    verbose=True,\n):\n    \"\"\"List all files in `directory`, with their labels.\n\n    Args:\n        directory: Directory where the data is located.\n            If `labels` is `\"inferred\"`, it should contain\n            subdirectories, each containing files for a class.\n            Otherwise, the directory structure is ignored.\n        labels: Either `\"inferred\"`\n            (labels are generated from the directory structure),\n            `None` (no labels),\n            or a list/tuple of integer labels of the same size as the number\n            of valid files found in the directory.\n            Labels should be sorted according\n            to the alphanumeric order of the image file paths\n            (obtained via `os.walk(directory)` in Python).\n        formats: Allowlist of file extensions to index\n            (e.g. `\".jpg\"`, `\".txt\"`).\n        class_names: Only valid if `labels=\"inferred\"`. This is the explicit\n            list of class names (must match names of subdirectories). 
Used\n            to control the order of the classes\n            (otherwise alphanumeric order is used).\n        shuffle: Whether to shuffle the data. Defaults to `True`.\n            If set to `False`, sorts the data in alphanumeric order.\n        seed: Optional random seed for shuffling.\n        follow_links: Whether to visit subdirectories pointed to by symlinks.\n        verbose: Whether the function prints the number of files found and\n            the number of classes. Defaults to `True`.\n\n    Returns:\n        tuple (file_paths, labels, class_names).\n        - file_paths: list of file paths (strings).\n        - labels: list of matching integer labels (same length as file_paths).\n        - class_names: names of the classes corresponding to these labels, in\n        order.\n    \"\"\"\n    if file_utils.is_remote_path(directory):\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        os_module = tf.io.gfile\n        path_module = tf.io.gfile\n    else:\n        os_module = os\n        path_module = os.path\n\n    if labels == \"inferred\":\n        subdirs = []\n        for subdir in sorted(os_module.listdir(directory)):\n            if path_module.isdir(path_module.join(directory, subdir)):\n                if not subdir.startswith(\".\"):\n                    if subdir.endswith(\"/\"):\n                        subdir = subdir[:-1]\n                    subdirs.append(subdir)\n        if class_names is not None:\n            if not set(class_names).issubset(set(subdirs)):\n                raise ValueError(\n                    \"The `class_names` passed did not match the \"\n                    \"names of the subdirectories of the target directory. 
\"\n                    f\"Expected: {subdirs} (or a subset of it), \"\n                    f\"but received: class_names={class_names}\"\n                )\n            subdirs = class_names  # Keep provided order.\n    else:\n        # In the explicit/no-label cases, index from the parent directory down.\n        subdirs = [\"\"]\n        if class_names is not None:\n            if labels is None:\n                raise ValueError(\n                    \"When `labels=None` (no labels), argument `class_names` \"\n                    \"cannot be specified.\"\n                )\n            else:\n                raise ValueError(\n                    \"When argument `labels` is specified, argument \"\n                    \"`class_names` cannot be specified (the `class_names` \"\n                    \"will be the sorted list of labels).\"\n                )\n    class_names = subdirs\n    class_indices = dict(zip(class_names, range(len(class_names))))\n\n    # Build an index of the files\n    # in the different class subfolders.\n    pool = ThreadPool()\n    results = []\n    filenames = []\n\n    for dirpath in (path_module.join(directory, subdir) for subdir in subdirs):\n        results.append(\n            pool.apply_async(\n                index_subdirectory,\n                (dirpath, class_indices, follow_links, formats),\n            )\n        )\n    labels_list = []\n    for res in results:\n        partial_filenames, partial_labels = res.get()\n        labels_list.append(partial_labels)\n        filenames += partial_filenames\n\n    if labels == \"inferred\":\n        # Inferred labels.\n        i = 0\n        labels = np.zeros((len(filenames),), dtype=\"int32\")\n        for partial_labels in labels_list:\n            labels[i : i + len(partial_labels)] = partial_labels\n            i += len(partial_labels)\n    elif labels is None:\n        class_names = None\n    else:\n        # Manual labels.\n        if len(labels) != len(filenames):\n            
raise ValueError(\n                \"Expected the length of `labels` to match the number \"\n                \"of files in the target directory. len(labels) is \"\n                f\"{len(labels)} while we found {len(filenames)} files \"\n                f\"in directory {directory}.\"\n            )\n        class_names = [str(label) for label in sorted(set(labels))]\n    if verbose:\n        if labels is None:\n            io_utils.print_msg(f\"Found {len(filenames)} files.\")\n        else:\n            io_utils.print_msg(\n                f\"Found {len(filenames)} files belonging \"\n                f\"to {len(class_names)} classes.\"\n            )\n    pool.close()\n    pool.join()\n    file_paths = [path_module.join(directory, fname) for fname in filenames]\n\n    if shuffle:\n        # Shuffle globally to erase macro-structure.\n        if seed is None:\n            seed = np.random.randint(1e6)\n        rng = np.random.RandomState(seed)\n        rng.shuffle(file_paths)\n        if labels is not None:\n            rng = np.random.RandomState(seed)\n            rng.shuffle(labels)\n    return file_paths, labels, class_names\n\n\ndef iter_valid_files(directory, follow_links, formats):\n    if file_utils.is_remote_path(directory):\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        io_module = tf.io.gfile\n    else:\n        io_module = os\n\n    if not follow_links:\n        walk = io_module.walk(directory)\n    else:\n        walk = os.walk(directory, followlinks=follow_links)\n    for root, _, files in sorted(walk, key=lambda x: x[0]):\n        for fname in sorted(files):\n            if fname.lower().endswith(formats):\n                yield root, fname\n\n\ndef index_subdirectory(directory, class_indices, follow_links, formats):\n    \"\"\"Recursively walks a directory, listing image paths and class indices.\n\n    Args:\n        directory: string, target directory.\n        class_indices: dict mapping class names to their 
index.\n        follow_links: boolean, whether to recursively follow subdirectories\n            (if False, we only list top-level images in `directory`).\n        formats: Allowlist of file extensions to index (e.g. \".jpg\", \".txt\").\n\n    Returns:\n        tuple `(filenames, labels)`. `filenames` is a list of relative file\n            paths, and `labels` is a list of integer labels corresponding\n            to these files.\n    \"\"\"\n    if file_utils.is_remote_path(directory):\n        from keras.src.utils.module_utils import tensorflow as tf\n\n        path_module = tf.io.gfile\n    else:\n        path_module = os.path\n\n    dirname = os.path.basename(directory)\n    valid_files = iter_valid_files(directory, follow_links, formats)\n    labels = []\n    filenames = []\n    for root, fname in valid_files:\n        labels.append(class_indices[dirname])\n        absolute_path = path_module.join(root, fname)\n        relative_path = path_module.join(\n            dirname, os.path.relpath(absolute_path, directory)\n        )\n        filenames.append(relative_path)\n    return filenames, labels\n\n\ndef get_training_or_validation_split(samples, labels, validation_split, subset):\n    \"\"\"Potentially restrict samples & labels to a training or validation split.\n\n    Args:\n        samples: List of elements.\n        labels: List of corresponding labels.\n        validation_split: Float, fraction of data to reserve for validation.\n        subset: Subset of the data to return.\n            Either `\"training\"`, `\"validation\"`, or `None`.\n            If `None`, we return all of the data.\n\n    Returns:\n        tuple (samples, labels), potentially restricted to the specified subset.\n    \"\"\"\n    if not validation_split:\n        return samples, labels\n\n    num_val_samples = int(validation_split * len(samples))\n    if subset == \"training\":\n        io_utils.print_msg(\n            f\"Using {len(samples) - num_val_samples} files for training.\"\n 
       )\n        samples = samples[:-num_val_samples]\n        if labels is not None:\n            labels = labels[:-num_val_samples]\n    elif subset == \"validation\":\n        io_utils.print_msg(f\"Using {num_val_samples} files for validation.\")\n        samples = samples[-num_val_samples:]\n        if labels is not None:\n            labels = labels[-num_val_samples:]\n    else:\n        raise ValueError(\n            '`subset` must be either \"training\" '\n            f'or \"validation\", received: {subset}'\n        )\n    return samples, labels\n\n\ndef labels_to_dataset_tf(labels, label_mode, num_classes):\n    \"\"\"Create a `tf.data.Dataset` from the list/tuple of labels.\n\n    Args:\n        labels: list/tuple of labels to be converted into a `tf.data.Dataset`.\n        label_mode: String describing the encoding of `labels`. Options are:\n        - `\"binary\"` indicates that the labels (there can be only 2) are encoded\n            as `float32` scalars with values 0 or 1\n            (e.g. for `binary_crossentropy`).\n        - `\"categorical\"` means that the labels are mapped into a categorical\n            vector (e.g. 
for `categorical_crossentropy` loss).\n        num_classes: number of classes of labels.\n\n    Returns:\n        A `tf.data.Dataset` instance.\n    \"\"\"\n    from keras.src.utils.module_utils import tensorflow as tf\n\n    label_ds = tf.data.Dataset.from_tensor_slices(labels)\n    if label_mode == \"binary\":\n        label_ds = label_ds.map(\n            lambda x: tf.expand_dims(tf.cast(x, \"float32\"), axis=-1),\n            num_parallel_calls=tf.data.AUTOTUNE,\n        )\n    elif label_mode == \"categorical\":\n        label_ds = label_ds.map(\n            lambda x: tf.one_hot(x, num_classes),\n            num_parallel_calls=tf.data.AUTOTUNE,\n        )\n    return label_ds\n\n\ndef labels_to_dataset_grain(labels, label_mode, num_classes):\n    \"\"\"Create a `grain.MapDataset` from the list/tuple of labels.\n\n    Args:\n        labels: list/tuple of labels to be converted into a `grain.MapDataset`.\n        label_mode: String describing the encoding of `labels`. Options are:\n        - `\"binary\"` indicates that the labels (there can be only 2) are encoded\n            as `float32` scalars with values 0 or 1\n            (e.g. for `binary_crossentropy`).\n        - `\"categorical\"` means that the labels are mapped into a categorical\n            vector (e.g. for `categorical_crossentropy` loss).\n        num_classes: number of classes of labels.\n\n    Returns:\n        A `grain.MapDataset` instance.\n    \"\"\"\n    from keras.src import backend\n    from keras.src import ops\n\n    if label_mode not in (\"binary\", \"categorical\", \"int\"):\n        raise ValueError(\n            f\"Invalid `label_mode`: {label_mode}. 
\"\n            \"Expected one of: 'binary', 'categorical', 'int'.\"\n        )\n\n    def preprocess_labels_in_cpu(label_mode, x, num_classes):\n        with backend.device_scope(\"cpu\"):\n            if label_mode == \"binary\":\n                return ops.expand_dims(\n                    ops.convert_to_tensor(x, dtype=\"float32\"), axis=-1\n                )\n            elif label_mode == \"categorical\":\n                return ops.one_hot(\n                    ops.convert_to_tensor(x, dtype=\"int32\"), num_classes\n                )\n            else:\n                return ops.convert_to_tensor(x, dtype=\"int32\")\n\n    label_ds = grain.MapDataset.source(labels)\n    label_ds = label_ds.map(\n        lambda x: preprocess_labels_in_cpu(label_mode, x, num_classes),\n    )\n    return label_ds\n\n\ndef check_validation_split_arg(validation_split, subset, shuffle, seed):\n    \"\"\"Raise errors in case of invalid argument values.\n\n    Args:\n        validation_split: float between 0 and 1, fraction of data to reserve for\n            validation.\n        subset: One of `\"training\"`, `\"validation\"`, or `\"both\"`. Only used if\n            `validation_split` is set.\n        shuffle: Whether to shuffle the data. 
Either `True` or `False`.\n        seed: random seed for shuffling and transformations.\n    \"\"\"\n    if validation_split and not 0 < validation_split < 1:\n        raise ValueError(\n            \"`validation_split` must be between 0 and 1, \"\n            f\"received: {validation_split}\"\n        )\n    if (validation_split or subset) and not (validation_split and subset):\n        raise ValueError(\n            \"If `subset` is set, `validation_split` must be set, and \"\n            \"vice versa.\"\n        )\n    if subset not in (\"training\", \"validation\", \"both\", None):\n        raise ValueError(\n            '`subset` must be either \"training\", '\n            f'\"validation\" or \"both\", received: {subset}'\n        )\n    if validation_split and shuffle and seed is None:\n        raise ValueError(\n            \"If using `validation_split` and shuffling the data, you must \"\n            \"provide a `seed` argument, to make sure that there is no \"\n            \"overlap between the training and validation subset.\"\n        )\n"
  },
  {
    "path": "keras/src/utils/dataset_utils_test.py",
    "content": "import collections\nimport itertools\n\nimport numpy as np\nimport torch\nfrom absl.testing import parameterized\nfrom torch.utils.data import Dataset as TorchDataset\n\nfrom keras.src import backend\nfrom keras.src.testing import test_case\nfrom keras.src.testing.test_utils import named_product\nfrom keras.src.utils.dataset_utils import split_dataset\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\nclass MyTorchDataset(TorchDataset):\n    def __init__(self, x, y=None):\n        # Convert NumPy → Torch tensors if needed\n        def to_tensor(v):\n            if isinstance(v, torch.Tensor):\n                return v\n            if hasattr(v, \"shape\"):\n                return torch.as_tensor(v, dtype=torch.float32)\n            return v\n\n        # Convert structured input recursively\n        def map_structure(obj):\n            if isinstance(obj, (dict, collections.OrderedDict)):\n                return {k: map_structure(v) for k, v in obj.items()}\n            if isinstance(obj, (tuple, list)):\n                typ = type(obj)\n                return typ(map_structure(v) for v in obj)\n            return to_tensor(obj)\n\n        self.x = map_structure(x)\n        self.y = None if y is None else map_structure(y)\n\n        # Infer dataset length from the first tensor in x\n        def first_tensor(obj):\n            if isinstance(obj, (dict, collections.OrderedDict)):\n                return first_tensor(next(iter(obj.values())))\n            if isinstance(obj, (tuple, list)):\n                return first_tensor(obj[0])\n            return obj\n\n        self.length = len(first_tensor(self.x))\n\n    def __len__(self):\n        return self.length\n\n    def __getitem__(self, idx):\n        def index_structure(obj):\n            if isinstance(obj, (dict, collections.OrderedDict)):\n                return obj.__class__(\n                    (k, index_structure(v)) for k, v in obj.items()\n                )\n            if 
isinstance(obj, (tuple, list)):\n                typ = type(obj)\n                return typ(index_structure(v) for v in obj)\n            return obj[idx]\n\n        if self.y is None:\n            return index_structure(self.x)\n        return index_structure(self.x), index_structure(self.y)\n\n\nclass DatasetUtilsTest(test_case.TestCase):\n    @parameterized.named_parameters(\n        named_product(\n            dataset_type=[\"list\", \"tuple\", \"tensorflow\", \"torch\"],\n            features_shape=[(2,), (100, 2), (10, 10, 2)],\n            preferred_backend=[None, \"tensorflow\", \"torch\"],\n        )\n    )\n    def test_split_dataset(\n        self, dataset_type, features_shape, preferred_backend\n    ):\n        n_sample, left_size, right_size = 100, 0.2, 0.8\n        features = np.random.sample((n_sample,) + features_shape)\n        labels = np.random.sample((n_sample, 1))\n        cardinality_function = (\n            tf.data.Dataset.cardinality\n            if (backend.backend() != \"torch\" and preferred_backend != \"torch\")\n            else len\n        )\n\n        if dataset_type == \"list\":\n            dataset = [features, labels]\n        elif dataset_type == \"tuple\":\n            dataset = (features, labels)\n        elif dataset_type == \"tensorflow\":\n            dataset = tf.data.Dataset.from_tensor_slices((features, labels))\n        elif dataset_type == \"torch\":\n            dataset = MyTorchDataset(features, labels)\n            cardinality_function = len\n        else:\n            raise ValueError(f\"Unknown dataset_type: {dataset_type}\")\n\n        dataset_left, dataset_right = split_dataset(\n            dataset,\n            left_size=left_size,\n            right_size=right_size,\n            preferred_backend=preferred_backend,\n        )\n        self.assertEqual(\n            int(cardinality_function(dataset_left)), int(n_sample * left_size)\n        )\n        self.assertEqual(\n            
int(cardinality_function(dataset_right)), int(n_sample * right_size)\n        )\n        for sample in itertools.chain(dataset_left, dataset_right):\n            self.assertEqual(sample[0].shape, features_shape)\n            self.assertEqual(sample[1].shape, (1,))\n\n    @parameterized.named_parameters(\n        named_product(structure_type=[\"tuple\", \"dict\", \"OrderedDict\"])\n    )\n    def test_split_dataset_nested_structures(self, structure_type):\n        n_sample, left_size, right_size = 100, 0.2, 0.8\n        features1 = np.random.sample((n_sample, 2))\n        features2 = np.random.sample((n_sample, 10, 2))\n        labels = np.random.sample((n_sample, 1))\n\n        if backend.backend() != \"torch\":\n            create_dataset_function = tf.data.Dataset.from_tensor_slices\n            cardinality_function = tf.data.Dataset.cardinality\n        else:\n            create_dataset_function = MyTorchDataset\n            cardinality_function = len\n\n        if structure_type == \"tuple\":\n            dataset = create_dataset_function(((features1, features2), labels))\n        if structure_type == \"dict\":\n            dataset = create_dataset_function(\n                {\"y\": features2, \"x\": features1, \"labels\": labels}\n            )\n        if structure_type == \"OrderedDict\":\n            dataset = create_dataset_function(\n                collections.OrderedDict(\n                    [(\"y\", features2), (\"x\", features1), (\"labels\", labels)]\n                )\n            )\n\n        dataset_left, dataset_right = split_dataset(\n            dataset, left_size=left_size, right_size=right_size\n        )\n        self.assertEqual(\n            int(cardinality_function(dataset_left)), int(n_sample * left_size)\n        )\n        self.assertEqual(\n            int(cardinality_function(dataset_right)), int(n_sample * right_size)\n        )\n        for sample in itertools.chain(dataset_left, dataset_right):\n            if structure_type in 
(\"dict\", \"OrderedDict\"):\n                x, y, labels = sample[\"x\"], sample[\"y\"], sample[\"labels\"]\n            elif structure_type == \"tuple\":\n                (x, y), labels = sample\n            self.assertEqual(x.shape, (2,))\n            self.assertEqual(y.shape, (10, 2))\n            self.assertEqual(labels.shape, (1,))\n"
  },
  {
    "path": "keras/src/utils/dtype_utils.py",
    "content": "from keras.src import backend\nfrom keras.src import ops\n\nDTYPE_TO_SIZE = {\n    **{f\"float{i}\": i for i in (16, 32, 64)},\n    **{f\"int{i}\": i for i in (8, 16, 32, 64)},\n    **{f\"uint{i}\": i for i in (8, 16, 32, 64)},\n    \"bfloat16\": 16,\n    \"bool\": 1,\n}\n\n\ndef dtype_size(dtype):\n    size = DTYPE_TO_SIZE.get(dtype, None)\n    if size is None:\n        raise ValueError(f\"Invalid dtype: {dtype}\")\n    return size\n\n\ndef is_float(dtype):\n    return \"float\" in dtype\n\n\ndef cast_to_common_dtype(tensors):\n    \"\"\"Cast a list of tensors to a common dtype.\n\n    If any tensor is floating-point, all tensors are cast to a common\n    floating-point dtype with sufficient precision. The promotion follows\n    the highest precision floating dtype present, with special handling\n    for mixed `float16` and `bfloat16`, which are promoted to `float32`.\n\n    If no floating-point tensors are present, tensors are returned unchanged.\n\n    Args:\n        tensors: A list of tensors.\n\n    Returns:\n        List of tensors cast to a common dtype when needed.\n    \"\"\"\n    highest_float = None\n    highest_float_size = -1\n\n    seen_float16 = False\n    seen_bfloat16 = False\n\n    for x in tensors:\n        dtype = backend.standardize_dtype(x.dtype)\n\n        if is_float(dtype):\n            if dtype == \"float16\":\n                seen_float16 = True\n            elif dtype == \"bfloat16\":\n                seen_bfloat16 = True\n\n            if highest_float is None or dtype_size(dtype) > highest_float_size:\n                highest_float = dtype\n                highest_float_size = dtype_size(dtype)\n\n    # Promote mixed float16 + bfloat16 to float32\n    # Do not downgrade if higher precision already found (e.g., float64)\n    if seen_float16 and seen_bfloat16 and highest_float_size < 32:\n        highest_float = \"float32\"\n\n    if highest_float:\n        tensors = [ops.cast(x, highest_float) for x in tensors]\n\n    
return tensors\n"
  },
  {
    "path": "keras/src/utils/dtype_utils_test.py",
    "content": "from keras.src.backend.common.keras_tensor import KerasTensor\nfrom keras.src.testing import test_case\nfrom keras.src.utils import dtype_utils\n\n\nclass DtypeSizeTests(test_case.TestCase):\n    def test_bfloat16_dtype_size(self):\n        self.assertEqual(dtype_utils.dtype_size(\"bfloat16\"), 16)\n\n    def test_float16_dtype_size(self):\n        self.assertEqual(dtype_utils.dtype_size(\"float16\"), 16)\n\n    def test_float32_dtype_size(self):\n        self.assertEqual(dtype_utils.dtype_size(\"float32\"), 32)\n\n    def test_int32_dtype_size(self):\n        self.assertEqual(dtype_utils.dtype_size(\"int32\"), 32)\n\n    def test_float64_dtype_size(self):\n        self.assertEqual(dtype_utils.dtype_size(\"float64\"), 64)\n\n    def test_int64_dtype_size(self):\n        self.assertEqual(dtype_utils.dtype_size(\"int64\"), 64)\n\n    def test_uint8_dtype_size(self):\n        self.assertEqual(dtype_utils.dtype_size(\"uint8\"), 8)\n\n    def test_bool_dtype_size(self):\n        self.assertEqual(dtype_utils.dtype_size(\"bool\"), 1)\n\n    def test_invalid_dtype_size(self):\n        with self.assertRaises(ValueError):\n            dtype_utils.dtype_size(\"unknown_dtype\")\n\n\nclass IsFloatTests(test_case.TestCase):\n    def test_is_float_float16(self):\n        self.assertTrue(dtype_utils.is_float(\"float16\"))\n\n    def test_is_float_float32(self):\n        self.assertTrue(dtype_utils.is_float(\"float32\"))\n\n    def test_is_float_float64(self):\n        self.assertTrue(dtype_utils.is_float(\"float64\"))\n\n    def test_is_float_int32(self):\n        self.assertFalse(dtype_utils.is_float(\"int32\"))\n\n    def test_is_float_bool(self):\n        self.assertFalse(dtype_utils.is_float(\"bool\"))\n\n    def test_is_float_uint8(self):\n        self.assertFalse(dtype_utils.is_float(\"uint8\"))\n\n    def test_is_float_containing_float(self):\n        self.assertTrue(dtype_utils.is_float(\"floating\"))\n\n    def test_is_float_empty_string(self):\n        
self.assertFalse(dtype_utils.is_float(\"\"))\n\n\nclass CastToCommonDtype(test_case.TestCase):\n    def test_cast_to_common_dtype_float32_float64(self):\n        tensor1 = KerasTensor([1, 2, 3], dtype=\"float32\")\n        tensor2 = KerasTensor([4, 5, 6], dtype=\"float64\")\n        casted_tensors = dtype_utils.cast_to_common_dtype([tensor1, tensor2])\n        for tensor in casted_tensors:\n            self.assertEqual(tensor.dtype, \"float64\")\n\n    def test_cast_to_common_dtype_float16_float32_float64(self):\n        tensor1 = KerasTensor([1, 2, 3], dtype=\"float16\")\n        tensor2 = KerasTensor([4, 5, 6], dtype=\"float32\")\n        tensor3 = KerasTensor([7, 8, 9], dtype=\"float64\")\n        casted_tensors = dtype_utils.cast_to_common_dtype(\n            [tensor1, tensor2, tensor3]\n        )\n        for tensor in casted_tensors:\n            self.assertEqual(tensor.dtype, \"float64\")\n\n    def test_cast_to_common_dtype_float16_int16_float32(self):\n        tensor1 = KerasTensor([1, 2, 3], dtype=\"float16\")\n        tensor2 = KerasTensor([4, 5, 6], dtype=\"int16\")\n        tensor3 = KerasTensor([7, 8, 9], dtype=\"float32\")\n        casted_tensors = dtype_utils.cast_to_common_dtype(\n            [tensor1, tensor2, tensor3]\n        )\n        for tensor in casted_tensors:\n            self.assertEqual(tensor.dtype, \"float32\")\n\n    def test_cast_to_common_dtype_all_float32(self):\n        tensor1 = KerasTensor([1, 2, 3], dtype=\"float32\")\n        tensor2 = KerasTensor([4, 5, 6], dtype=\"float32\")\n        tensor3 = KerasTensor([7, 8, 9], dtype=\"float32\")\n        casted_tensors = dtype_utils.cast_to_common_dtype(\n            [tensor1, tensor2, tensor3]\n        )\n        for tensor in casted_tensors:\n            self.assertEqual(tensor.dtype, \"float32\")\n\n    def test_cast_to_common_dtype_float16_bfloat16(self):\n        tensor1 = KerasTensor([1, 2, 3], dtype=\"float16\")\n        tensor2 = KerasTensor([4, 5, 6], dtype=\"bfloat16\")\n    
    casted_tensors = dtype_utils.cast_to_common_dtype([tensor1, tensor2])\n        for tensor in casted_tensors:\n            self.assertEqual(tensor.dtype, \"float32\")\n\n    def test_cast_to_common_dtype_float16_uint8(self):\n        tensor1 = KerasTensor([1, 2, 3], dtype=\"float16\")\n        tensor2 = KerasTensor([4, 5, 6], dtype=\"uint8\")\n        casted_tensors = dtype_utils.cast_to_common_dtype([tensor1, tensor2])\n        for tensor in casted_tensors:\n            self.assertEqual(tensor.dtype, \"float16\")\n\n    def test_cast_to_common_dtype_mixed_types(self):\n        tensor1 = KerasTensor([1, 2, 3], dtype=\"float32\")\n        tensor2 = KerasTensor([4, 5, 6], dtype=\"int32\")\n        tensor3 = KerasTensor([7, 8, 9], dtype=\"bool\")\n        casted_tensors = dtype_utils.cast_to_common_dtype(\n            [tensor1, tensor2, tensor3]\n        )\n        for tensor in casted_tensors:\n            self.assertEqual(tensor.dtype, \"float32\")\n\n    def test_cast_to_common_dtype_no_float(self):\n        tensor1 = KerasTensor([1, 2, 3], dtype=\"int32\")\n        tensor2 = KerasTensor([4, 5, 6], dtype=\"uint8\")\n        casted_tensors = dtype_utils.cast_to_common_dtype([tensor1, tensor2])\n        self.assertEqual(casted_tensors[0].dtype, \"int32\")\n        self.assertEqual(casted_tensors[1].dtype, \"uint8\")\n\n    def test_cast_to_common_dtype_float16_bfloat16_promotion(self):\n        tensor1 = KerasTensor([4, 5, 6], dtype=\"bfloat16\")\n        tensor2 = KerasTensor([1, 2, 3], dtype=\"float16\")\n        casted_tensors = dtype_utils.cast_to_common_dtype([tensor1, tensor2])\n        for tensor in casted_tensors:\n            self.assertEqual(tensor.dtype, \"float32\")\n\n    def test_cast_to_common_dtype_bfloat16_float16_promotion(self):\n        tensor1 = KerasTensor([1, 2, 3], dtype=\"float16\")\n        tensor2 = KerasTensor([4, 5, 6], dtype=\"bfloat16\")\n        casted_tensors = dtype_utils.cast_to_common_dtype([tensor1, tensor2])\n        for 
tensor in casted_tensors:\n            self.assertEqual(tensor.dtype, \"float32\")\n\n    def test_cast_to_common_dtype_f16_bf16_f64_preservation(self):\n        t1 = KerasTensor([1], dtype=\"float16\")\n        t2 = KerasTensor([2], dtype=\"bfloat16\")\n        t3 = KerasTensor([3], dtype=\"float64\")\n        casted = dtype_utils.cast_to_common_dtype([t1, t2, t3])\n        for tensor in casted:\n            self.assertEqual(tensor.dtype, \"float64\")\n"
  },
  {
    "path": "keras/src/utils/file_utils.py",
    "content": "import hashlib\nimport os\nimport re\nimport shutil\nimport sys\nimport tarfile\nimport tempfile\nimport urllib\nimport urllib.error\nimport urllib.parse\nimport warnings\nimport zipfile\nfrom urllib.request import urlretrieve\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend import config\nfrom keras.src.utils import io_utils\nfrom keras.src.utils.module_utils import gfile\nfrom keras.src.utils.progbar import Progbar\n\n\ndef path_to_string(path):\n    \"\"\"Convert `PathLike` objects to their string representation.\n\n    If given a non-string typed path object, converts it to its string\n    representation.\n\n    If the object passed to `path` is not a `PathLike` object, it is\n    returned unchanged. This allows e.g. passthrough of file objects\n    through this function.\n\n    Args:\n        path: A `PathLike` object that represents a path.\n\n    Returns:\n        A string representation of the path argument, or the argument\n        unchanged if it is not a `PathLike` object.\n    \"\"\"\n    if isinstance(path, os.PathLike):\n        return os.fspath(path)\n    return path\n\n\ndef resolve_path(path):\n    return os.path.realpath(os.path.abspath(path))\n\n\ndef is_path_in_dir(path, base_dir):\n    return resolve_path(os.path.join(base_dir, path)).startswith(base_dir)\n\n\ndef is_link_in_dir(info, base):\n    tip = resolve_path(os.path.join(base, os.path.dirname(info.name)))\n    return is_path_in_dir(info.linkname, base_dir=tip)\n\n\ndef filter_safe_zipinfos(members, base_dir):\n    base_dir = resolve_path(base_dir)\n    for finfo in members:\n        valid_path = False\n        if is_path_in_dir(finfo.filename, base_dir):\n            valid_path = True\n            yield finfo\n        if not valid_path:\n            warnings.warn(\n                \"Skipping invalid path during archive extraction: \"\n                f\"'{finfo.filename}'.\",\n                stacklevel=2,\n            )\n\n\ndef filter_safe_tarinfos(members, base_dir):\n    base_dir = 
resolve_path(base_dir)\n    for finfo in members:\n        valid_path = False\n        if finfo.issym() or finfo.islnk():\n            if is_link_in_dir(finfo, base_dir):\n                valid_path = True\n                yield finfo\n        elif is_path_in_dir(finfo.name, base_dir):\n            valid_path = True\n            yield finfo\n        if not valid_path:\n            warnings.warn(\n                \"Skipping invalid path during archive extraction: \"\n                f\"'{finfo.name}'.\",\n                stacklevel=2,\n            )\n\n\ndef extract_open_archive(archive, path=\".\"):\n    \"\"\"Extracts an open tar or zip archive to the provided directory.\n\n    This function filters unsafe paths during extraction.\n\n    Args:\n        archive: The archive object, either a `TarFile` or a `ZipFile`.\n        path: Where to extract the archive file.\n    \"\"\"\n    if isinstance(archive, zipfile.ZipFile):\n        # Zip archive.\n        archive.extractall(\n            path, members=filter_safe_zipinfos(archive.infolist(), path)\n        )\n    else:\n        # Tar archive.\n        extractall_kwargs = {}\n        # The `filter=\"data\"` option was added in Python 3.12. It became the\n        # default starting from Python 3.14. 
So we only specify it between\n        # those two versions.\n        if sys.version_info >= (3, 12) and sys.version_info < (3, 14):\n            extractall_kwargs = {\"filter\": \"data\"}\n        archive.extractall(\n            path,\n            members=filter_safe_tarinfos(archive, path),\n            **extractall_kwargs,\n        )\n\n\ndef extract_archive(file_path, path=\".\", archive_format=\"auto\"):\n    \"\"\"Extracts an archive if it matches a supported format.\n\n    Supports `.tar`, `.tar.gz`, `.tar.bz`, and `.zip` formats.\n\n    Args:\n        file_path: Path to the archive file.\n        path: Where to extract the archive file.\n        archive_format: Archive format to try for extracting the file.\n            Options are `\"auto\"`, `\"tar\"`, `\"zip\"`, and `None`.\n            `\"tar\"` includes `.tar`, `.tar.gz`, and `.tar.bz` files.\n            The default `\"auto\"` uses `[\"tar\", \"zip\"]`.\n            `None` or an empty list results in no matches being found.\n\n    Returns:\n        `True` if a match was found and an archive extraction was completed,\n        `False` otherwise.\n    \"\"\"\n    if archive_format is None:\n        return False\n    if archive_format == \"auto\":\n        archive_format = [\"tar\", \"zip\"]\n    if isinstance(archive_format, str):\n        archive_format = [archive_format]\n\n    file_path = path_to_string(file_path)\n    path = path_to_string(path)\n\n    for archive_type in archive_format:\n        if archive_type == \"tar\":\n            open_fn = tarfile.open\n            is_match_fn = tarfile.is_tarfile\n        elif archive_type == \"zip\":\n            open_fn = zipfile.ZipFile\n            is_match_fn = zipfile.is_zipfile\n        else:\n            raise NotImplementedError(archive_type)\n\n        if is_match_fn(file_path):\n            with open_fn(file_path) as archive:\n                try:\n                    extract_open_archive(archive, path)\n                except (tarfile.TarError, 
RuntimeError, KeyboardInterrupt):\n                    if os.path.exists(path):\n                        if os.path.isfile(path):\n                            os.remove(path)\n                        else:\n                            shutil.rmtree(path)\n                    raise\n            return True\n    return False\n\n\n@keras_export(\"keras.utils.get_file\")\ndef get_file(\n    fname=None,\n    origin=None,\n    untar=False,\n    md5_hash=None,\n    file_hash=None,\n    cache_subdir=\"datasets\",\n    hash_algorithm=\"auto\",\n    extract=False,\n    archive_format=\"auto\",\n    cache_dir=None,\n    force_download=False,\n):\n    \"\"\"Downloads a file from a URL if it is not already in the cache.\n\n    By default, the file at the URL `origin` is downloaded to the\n    `cache_dir` `~/.keras`, placed in the `cache_subdir` `datasets`,\n    and given the filename `fname`. The final location of a file\n    `example.txt` would therefore be `~/.keras/datasets/example.txt`.\n    Files in `.tar`, `.tar.gz`, `.tar.bz`, and `.zip` formats can\n    also be extracted.\n\n    Passing a hash will verify the file after download. 
The command line\n    programs `shasum` and `sha256sum` can compute the hash.\n\n    Example:\n\n    ```python\n    path_to_downloaded_file = get_file(\n        origin=\"https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz\",\n        extract=True\n    )\n    ```\n\n    Args:\n        fname: If the target is a single file, this is your desired\n            local name for the file.\n            If `None`, the name of the file at `origin` will be used.\n            If downloading and extracting a directory archive,\n            the provided `fname` will be used as the extraction directory\n            name (only if it doesn't have an extension).\n        origin: Original URL of the file.\n        untar: Deprecated in favor of the `extract` argument.\n            Boolean, whether the file is a tar archive that should\n            be extracted.\n        md5_hash: Deprecated in favor of the `file_hash` argument.\n            MD5 hash of the file for file integrity verification.\n        file_hash: The expected hash string of the file after download.\n            The sha256 and md5 hash algorithms are both supported.\n        cache_subdir: Subdirectory under the Keras cache dir where the file is\n            saved. If an absolute path, e.g. `\"/path/to/folder\"`, is\n            specified, the file will be saved at that location.\n        hash_algorithm: Select the hash algorithm to verify the file.\n            Options are `\"md5\"`, `\"sha256\"`, and `\"auto\"`.\n            The default `\"auto\"` detects the hash algorithm in use.\n        extract: If `True`, extracts the archive. 
Only applicable to compressed\n            archive files like tar or zip.\n        archive_format: Archive format to try for extracting the file.\n            Options are `\"auto\"`, `\"tar\"`, `\"zip\"`, and `None`.\n            `\"tar\"` includes `.tar`, `.tar.gz`, and `.tar.bz` files.\n            The default `\"auto\"` corresponds to `[\"tar\", \"zip\"]`.\n            `None` or an empty list results in no matches being found.\n        cache_dir: Location to store cached files. When `None`, it\n            defaults to `$KERAS_HOME` if the `KERAS_HOME` environment\n            variable is set, and to `~/.keras/` otherwise.\n        force_download: If `True`, the file will always be re-downloaded\n            regardless of the cache state.\n\n    Returns:\n        Path to the downloaded file.\n\n    **⚠️ Warning on malicious downloads ⚠️**\n\n    Downloading something from the Internet carries a risk.\n    NEVER download a file/archive if you do not trust the source.\n    We recommend that you specify the `file_hash` argument\n    (if the hash of the source file is known) to make sure that the file you\n    are getting is the one you expect.\n    \"\"\"\n    if origin is None:\n        raise ValueError(\n            'Please specify the \"origin\" argument (URL of the file '\n            \"to download).\"\n        )\n\n    if cache_dir is None:\n        cache_dir = config.keras_home()\n    if md5_hash is not None and file_hash is None:\n        file_hash = md5_hash\n        hash_algorithm = \"md5\"\n    datadir_base = os.path.expanduser(cache_dir)\n    if not os.access(datadir_base, os.W_OK):\n        datadir_base = os.path.join(\n            \"/tmp\" if os.path.isdir(\"/tmp\") else tempfile.gettempdir(), \".keras\"\n        )\n    datadir = os.path.join(datadir_base, cache_subdir)\n    os.makedirs(datadir, exist_ok=True)\n\n    provided_fname = fname\n    fname = path_to_string(fname)\n\n    if not fname:\n        fname = os.path.basename(urllib.parse.urlsplit(origin).path)\n        if not 
fname:\n            raise ValueError(\n                \"Can't parse the file name from the origin provided: \"\n                f\"'{origin}'. \"\n                \"Please specify the `fname` argument.\"\n            )\n    else:\n        if os.sep in fname:\n            raise ValueError(\n                \"Paths are no longer accepted as the `fname` argument. \"\n                \"To specify the file's parent directory, use \"\n                f\"the `cache_dir` argument. Received: fname={fname}\"\n            )\n\n    if extract or untar:\n        if provided_fname:\n            if \".\" in fname:\n                download_target = os.path.join(datadir, fname)\n                fname = fname[: fname.find(\".\")]\n                extraction_dir = os.path.join(datadir, f\"{fname}_extracted\")\n            else:\n                extraction_dir = os.path.join(datadir, fname)\n                download_target = os.path.join(datadir, f\"{fname}_archive\")\n        else:\n            extraction_dir = os.path.join(datadir, fname)\n            download_target = os.path.join(datadir, f\"{fname}_archive\")\n    else:\n        download_target = os.path.join(datadir, fname)\n\n    if force_download:\n        download = True\n    elif os.path.exists(download_target):\n        # File found in cache.\n        download = False\n        # Verify integrity if a hash was provided.\n        if file_hash is not None:\n            if not validate_file(\n                download_target, file_hash, algorithm=hash_algorithm\n            ):\n                io_utils.print_msg(\n                    \"A local file was found, but it seems to be \"\n                    f\"incomplete or outdated because the {hash_algorithm} \"\n                    \"file hash does not match the original value of \"\n                    f\"{file_hash}, so we will re-download the data.\"\n                )\n                download = True\n    else:\n        download = True\n\n    if download:\n        
io_utils.print_msg(f\"Downloading data from {origin}\")\n\n        class DLProgbar:\n            \"\"\"Manage progress bar state for use in urlretrieve.\"\"\"\n\n            def __init__(self):\n                self.progbar = None\n                self.finished = False\n\n            def __call__(self, block_num, block_size, total_size):\n                if total_size == -1:\n                    total_size = None\n                if not self.progbar:\n                    self.progbar = Progbar(total_size)\n                current = block_num * block_size\n\n                if total_size is None:\n                    self.progbar.update(current)\n                else:\n                    if current < total_size:\n                        self.progbar.update(current)\n                    elif not self.finished:\n                        self.progbar.update(self.progbar.target)\n                        self.finished = True\n\n        error_msg = \"URL fetch failure on {}: {} -- {}\"\n        try:\n            try:\n                urlretrieve(origin, download_target, DLProgbar())\n            except urllib.error.HTTPError as e:\n                raise Exception(error_msg.format(origin, e.code, e.msg))\n            except urllib.error.URLError as e:\n                raise Exception(error_msg.format(origin, e.errno, e.reason))\n        except (Exception, KeyboardInterrupt):\n            if os.path.exists(download_target):\n                os.remove(download_target)\n            raise\n\n        # Validate download if succeeded and user provided an expected hash\n        # Security conscious users would get the hash of the file from a\n        # separate channel and pass it to this API to prevent MITM / corruption:\n        if os.path.exists(download_target) and file_hash is not None:\n            if not validate_file(\n                download_target, file_hash, algorithm=hash_algorithm\n            ):\n                raise ValueError(\n                    \"Incomplete 
or corrupted file detected. \"\n                    f\"The {hash_algorithm} \"\n                    \"file hash does not match the provided value \"\n                    f\"of {file_hash}.\"\n                )\n\n    if extract or untar:\n        if untar:\n            archive_format = \"tar\"\n\n        status = extract_archive(\n            download_target, extraction_dir, archive_format\n        )\n        if not status:\n            warnings.warn(\"Could not extract archive.\", stacklevel=2)\n        return extraction_dir\n\n    return download_target\n\n\ndef resolve_hasher(algorithm, file_hash=None):\n    \"\"\"Returns the hash algorithm as a hashlib function.\"\"\"\n    if algorithm == \"sha256\":\n        return hashlib.sha256()\n\n    if algorithm == \"auto\" and file_hash is not None and len(file_hash) == 64:\n        return hashlib.sha256()\n\n    # This is used only for legacy purposes.\n    return hashlib.md5()\n\n\ndef hash_file(fpath, algorithm=\"sha256\", chunk_size=65535):\n    \"\"\"Calculates a file's sha256 or md5 hash.\n\n    Example:\n\n    >>> hash_file('/path/to/file.zip')\n    'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'\n\n    Args:\n        fpath: Path to the file being hashed.\n        algorithm: Hash algorithm, one of `\"auto\"`, `\"sha256\"`, or `\"md5\"`.\n            Defaults to `\"sha256\"`; `\"auto\"` detects the hash algorithm in\n            use.\n        chunk_size: Bytes to read at a time, important for large files.\n\n    Returns:\n        The file hash.\n    \"\"\"\n    if isinstance(algorithm, str):\n        hasher = resolve_hasher(algorithm)\n    else:\n        hasher = algorithm\n\n    with open(fpath, \"rb\") as fpath_file:\n        for chunk in iter(lambda: fpath_file.read(chunk_size), b\"\"):\n            hasher.update(chunk)\n\n    return hasher.hexdigest()\n\n\ndef validate_file(fpath, file_hash, algorithm=\"auto\", chunk_size=65535):\n    \"\"\"Validates a file against a sha256 or md5 hash.\n\n    Args:\n        fpath: Path to 
the file being validated.\n        file_hash: The expected hash string of the file.\n            The sha256 and md5 hash algorithms are both supported.\n        algorithm: Hash algorithm, one of `\"auto\"`, `\"sha256\"`, or `\"md5\"`.\n            The default `\"auto\"` detects the hash algorithm in use.\n        chunk_size: Bytes to read at a time, important for large files.\n\n    Returns:\n        Boolean, whether the file is valid.\n    \"\"\"\n    hasher = resolve_hasher(algorithm, file_hash)\n\n    return str(hash_file(fpath, hasher, chunk_size)) == str(file_hash)\n\n\ndef is_remote_path(filepath):\n    \"\"\"\n    Determines if a given filepath indicates a remote location.\n\n    This function checks if the filepath represents a known remote pattern\n    such as GCS (`/gcs`), CNS (`/cns`), CFS (`/cfs`), HDFS (`/hdfs`),\n    read-ahead (`/readahead`), Placer (`/placer`), TFHub (`/tfhub`), or a\n    URL (`.*://`).\n\n    Args:\n        filepath (str): The path to be checked.\n\n    Returns:\n        bool: True if the filepath is a recognized remote path, otherwise\n        False.\n    \"\"\"\n    return bool(\n        re.match(\n            r\"^(/cns|/cfs|/gcs|/hdfs|/readahead|/placer|/tfhub|.*://).*$\",\n            str(filepath),\n        )\n    )\n\n\n# Below are gfile-replacement utils.\n\n\ndef _raise_if_no_gfile(path):\n    raise ValueError(\n        \"Handling remote paths requires installing TensorFlow \"\n        f\"(in order to use gfile). 
Received path: {path}\"\n    )\n\n\ndef exists(path):\n    if is_remote_path(path):\n        if gfile.available:\n            return gfile.exists(path)\n        else:\n            _raise_if_no_gfile(path)\n    return os.path.exists(path)\n\n\ndef File(path, mode=\"r\"):\n    if is_remote_path(path):\n        if gfile.available:\n            return gfile.GFile(path, mode=mode)\n        else:\n            _raise_if_no_gfile(path)\n    return open(path, mode=mode)\n\n\ndef join(path, *paths):\n    if is_remote_path(path):\n        if gfile.available:\n            return gfile.join(path, *paths)\n        else:\n            _raise_if_no_gfile(path)\n    return os.path.join(path, *paths)\n\n\ndef isdir(path):\n    if is_remote_path(path):\n        if gfile.available:\n            return gfile.isdir(path)\n        else:\n            _raise_if_no_gfile(path)\n    return os.path.isdir(path)\n\n\ndef remove(path):\n    if is_remote_path(path):\n        if gfile.available:\n            return gfile.remove(path)\n        else:\n            _raise_if_no_gfile(path)\n    return os.remove(path)\n\n\ndef rmtree(path):\n    if is_remote_path(path):\n        if gfile.available:\n            return gfile.rmtree(path)\n        else:\n            _raise_if_no_gfile(path)\n    return shutil.rmtree(path)\n\n\ndef listdir(path):\n    if is_remote_path(path):\n        if gfile.available:\n            return gfile.listdir(path)\n        else:\n            _raise_if_no_gfile(path)\n    return os.listdir(path)\n\n\ndef copy(src, dst):\n    if is_remote_path(src) or is_remote_path(dst):\n        if gfile.available:\n            return gfile.copy(src, dst, overwrite=True)\n        else:\n            _raise_if_no_gfile(f\"src={src} dst={dst}\")\n    return shutil.copy(src, dst)\n\n\ndef makedirs(path):\n    if is_remote_path(path):\n        if gfile.available:\n            return gfile.makedirs(path)\n        else:\n            _raise_if_no_gfile(path)\n    return os.makedirs(path)\n"
  },
  {
    "path": "keras/src/utils/file_utils_test.py",
    "content": "import hashlib\nimport os\nimport pathlib\nimport tarfile\nimport urllib\nimport urllib.parse\nimport urllib.request\nimport zipfile\nfrom unittest.mock import patch\n\nfrom keras.src.testing import test_case\nfrom keras.src.utils import file_utils\n\n\nclass PathToStringTest(test_case.TestCase):\n    def test_path_to_string_with_string_path(self):\n        path = os.path.join(os.path.sep, \"path\", \"to\", \"file.txt\")\n        string_path = file_utils.path_to_string(path)\n        self.assertEqual(string_path, path)\n\n    def test_path_to_string_with_PathLike_object(self):\n        # Use an actual `PathLike` object rather than a plain string.\n        path = pathlib.Path(\n            os.path.join(os.path.sep, \"path\", \"to\", \"file.txt\")\n        )\n        string_path = file_utils.path_to_string(path)\n        self.assertEqual(string_path, str(path))\n\n    def test_path_to_string_with_non_string_typed_path_object(self):\n        class NonStringTypedPathObject:\n            def __fspath__(self):\n                return os.path.join(os.path.sep, \"path\", \"to\", \"file.txt\")\n\n        path = NonStringTypedPathObject()\n        string_path = file_utils.path_to_string(path)\n        self.assertEqual(\n            string_path, os.path.join(os.path.sep, \"path\", \"to\", \"file.txt\")\n        )\n\n    def test_path_to_string_with_none_path(self):\n        string_path = file_utils.path_to_string(None)\n        self.assertEqual(string_path, None)\n\n\nclass ResolvePathTest(test_case.TestCase):\n    def test_resolve_path_with_absolute_path(self):\n        path = os.path.join(os.path.sep, \"path\", \"to\", \"file.txt\")\n        resolved_path = file_utils.resolve_path(path)\n        self.assertEqual(resolved_path, os.path.realpath(os.path.abspath(path)))\n\n    def test_resolve_path_with_relative_path(self):\n        path = os.path.join(\".\", \"file.txt\")\n        resolved_path = file_utils.resolve_path(path)\n        self.assertEqual(resolved_path, os.path.realpath(os.path.abspath(path)))\n\n\nclass IsPathInDirTest(test_case.TestCase):\n    def 
test_is_path_in_dir_with_absolute_paths(self):\n        base_dir = os.path.join(os.path.sep, \"path\", \"to\", \"base_dir\")\n        path = os.path.join(base_dir, \"file.txt\")\n        self.assertTrue(file_utils.is_path_in_dir(path, base_dir))\n\n\nclass IsLinkInDirTest(test_case.TestCase):\n    def test_is_link_in_dir_with_absolute_paths(self):\n        base_dir = self.get_temp_dir()\n        link_path = os.path.join(base_dir, \"symlink\")\n        target_path = os.path.join(base_dir, \"file.txt\")\n\n        # Create the file.txt file.\n        with open(target_path, \"w\") as f:\n            f.write(\"Hello, world!\")\n\n        # Symlink with an absolute link target.\n        os.symlink(target_path, link_path)\n\n        # Create a TarInfo-like object with `name` and `linkname` attributes.\n        info = type(\n            \"tarinfo_like\",\n            (object,),\n            {\n                \"name\": os.path.basename(link_path),\n                \"linkname\": os.readlink(link_path),\n            },\n        )\n\n        self.assertTrue(file_utils.is_link_in_dir(info, base_dir))\n\n    def test_is_link_in_dir_with_relative_paths(self):\n        base_dir = self.get_temp_dir()\n        link_path = os.path.join(base_dir, \"symlink\")\n        target_path = os.path.join(base_dir, \"file.txt\")\n\n        # Create the file.txt file.\n        with open(target_path, \"w\") as f:\n            f.write(\"Hello, world!\")\n\n        # Symlink with a link target relative to the link's directory.\n        os.symlink(\"file.txt\", link_path)\n\n        # Create a TarInfo-like object with `name` and `linkname` attributes.\n        info = type(\n            \"tarinfo_like\",\n            (object,),\n            {\n                \"name\": os.path.basename(link_path),\n                \"linkname\": os.readlink(link_path),\n            },\n        )\n\n        self.assertTrue(file_utils.is_link_in_dir(info, base_dir))\n\n\nclass FilterSafePathsTest(test_case.TestCase):\n    def setUp(self):\n        self.base_dir = 
os.path.abspath(self.get_temp_dir())\n        self.tar_path = os.path.join(self.base_dir, \"test.tar\")\n        self.target_path = os.path.join(self.base_dir, \"target.txt\")\n        with open(self.target_path, \"w\") as f:\n            f.write(\"target\")\n        self.symlink_path = os.path.join(self.base_dir, \"symlink.txt\")\n        os.symlink(self.target_path, self.symlink_path)\n\n    def test_member_within_base_dir(self):\n        \"\"\"Test a member within the base directory.\"\"\"\n        with tarfile.open(self.tar_path, \"w\") as tar:\n            tar.add(self.target_path, arcname=\"safe_path.txt\")\n        with tarfile.open(self.tar_path, \"r\") as tar:\n            members = list(\n                file_utils.filter_safe_tarinfos(tar.getmembers(), self.base_dir)\n            )\n            self.assertEqual(len(members), 1)\n            self.assertEqual(members[0].name, \"safe_path.txt\")\n\n    def test_symlink_within_base_dir(self):\n        \"\"\"Test a symlink pointing within the base directory.\"\"\"\n        with tarfile.open(self.tar_path, \"w\") as tar:\n            tar.add(self.symlink_path, arcname=\"symlink.txt\")\n        with tarfile.open(self.tar_path, \"r\") as tar:\n            members = list(\n                file_utils.filter_safe_tarinfos(tar.getmembers(), self.base_dir)\n            )\n            self.assertEqual(len(members), 1)\n            self.assertEqual(members[0].name, \"symlink.txt\")\n\n    def test_invalid_path_warning(self):\n        \"\"\"Test warning for an invalid path during archive extraction.\"\"\"\n        with tarfile.open(self.tar_path, \"w\") as tar:\n            tar.add(\n                self.target_path, arcname=\"../../invalid.txt\"\n            )  # Path intended to be outside of base dir\n        with tarfile.open(self.tar_path, \"r\") as tar:\n            with patch(\"warnings.warn\") as mock_warn:\n                _ = list(\n                    file_utils.filter_safe_tarinfos(\n                        
tar.getmembers(), self.base_dir\n                    )\n                )\n                warning_msg = (\n                    \"Skipping invalid path during archive extraction: \"\n                    \"'../../invalid.txt'.\"\n                )\n                mock_warn.assert_called_with(warning_msg, stacklevel=2)\n\n    def test_symbolic_link_in_base_dir(self):\n        \"\"\"symbolic link within the base directory is correctly processed.\"\"\"\n        # Add the symbolic link to the tar archive.\n        with tarfile.open(self.tar_path, \"w\") as tar:\n            tar.add(self.symlink_path, arcname=\"symlink.txt\")\n\n        with tarfile.open(self.tar_path, \"r\") as tar:\n            members = list(\n                file_utils.filter_safe_tarinfos(tar.getmembers(), self.base_dir)\n            )\n            self.assertEqual(len(members), 1)\n            self.assertEqual(members[0].name, \"symlink.txt\")\n            self.assertTrue(\n                members[0].issym()\n            )  # Explicitly assert it's a symbolic link.\n\n\nclass ExtractArchiveTest(test_case.TestCase):\n    def setUp(self):\n        \"\"\"Create temporary directories and files for testing.\"\"\"\n        self.temp_dir = self.get_temp_dir()\n        self.file_content = \"Hello, world!\"\n\n        # Create sample files to be archived\n        with open(os.path.join(self.temp_dir, \"sample.txt\"), \"w\") as f:\n            f.write(self.file_content)\n\n    def create_tar(self):\n        archive_path = os.path.join(self.temp_dir, \"sample.tar\")\n        with tarfile.open(archive_path, \"w\") as archive:\n            archive.add(\n                os.path.join(self.temp_dir, \"sample.txt\"), arcname=\"sample.txt\"\n            )\n        return archive_path\n\n    def create_zip(self):\n        archive_path = os.path.join(self.temp_dir, \"sample.zip\")\n        with zipfile.ZipFile(archive_path, \"w\") as archive:\n            archive.write(\n                os.path.join(self.temp_dir, 
\"sample.txt\"), arcname=\"sample.txt\"\n            )\n        return archive_path\n\n    def test_extract_tar(self):\n        archive_path = self.create_tar()\n        extract_path = os.path.join(self.temp_dir, \"extract_tar\")\n        result = file_utils.extract_archive(archive_path, extract_path, \"tar\")\n        self.assertTrue(result)\n        with open(os.path.join(extract_path, \"sample.txt\"), \"r\") as f:\n            self.assertEqual(f.read(), self.file_content)\n\n    def test_extract_zip(self):\n        archive_path = self.create_zip()\n        extract_path = os.path.join(self.temp_dir, \"extract_zip\")\n        result = file_utils.extract_archive(archive_path, extract_path, \"zip\")\n        self.assertTrue(result)\n        with open(os.path.join(extract_path, \"sample.txt\"), \"r\") as f:\n            self.assertEqual(f.read(), self.file_content)\n\n    def test_extract_auto(self):\n        # This will test the 'auto' functionality\n        tar_archive_path = self.create_tar()\n        zip_archive_path = self.create_zip()\n\n        extract_tar_path = os.path.join(self.temp_dir, \"extract_auto_tar\")\n        extract_zip_path = os.path.join(self.temp_dir, \"extract_auto_zip\")\n\n        self.assertTrue(\n            file_utils.extract_archive(tar_archive_path, extract_tar_path)\n        )\n        self.assertTrue(\n            file_utils.extract_archive(zip_archive_path, extract_zip_path)\n        )\n\n        with open(os.path.join(extract_tar_path, \"sample.txt\"), \"r\") as f:\n            self.assertEqual(f.read(), self.file_content)\n\n        with open(os.path.join(extract_zip_path, \"sample.txt\"), \"r\") as f:\n            self.assertEqual(f.read(), self.file_content)\n\n    def test_non_existent_file(self):\n        extract_path = os.path.join(self.temp_dir, \"non_existent\")\n        with self.assertRaises(FileNotFoundError):\n            file_utils.extract_archive(\"non_existent.tar\", extract_path)\n\n    def 
test_archive_format_none(self):\n        archive_path = self.create_tar()\n        extract_path = os.path.join(self.temp_dir, \"none_format\")\n        result = file_utils.extract_archive(archive_path, extract_path, None)\n        self.assertFalse(result)\n\n    def test_runtime_error_during_extraction(self):\n        tar_path = self.create_tar()\n        extract_path = os.path.join(self.temp_dir, \"runtime_error_extraction\")\n\n        with patch.object(\n            tarfile.TarFile, \"extractall\", side_effect=RuntimeError\n        ):\n            with self.assertRaises(RuntimeError):\n                file_utils.extract_archive(tar_path, extract_path, \"tar\")\n        self.assertFalse(os.path.exists(extract_path))\n\n    def test_keyboard_interrupt_during_extraction(self):\n        tar_path = self.create_tar()\n        extract_path = os.path.join(\n            self.temp_dir, \"keyboard_interrupt_extraction\"\n        )\n\n        with patch.object(\n            tarfile.TarFile, \"extractall\", side_effect=KeyboardInterrupt\n        ):\n            with self.assertRaises(KeyboardInterrupt):\n                file_utils.extract_archive(tar_path, extract_path, \"tar\")\n        self.assertFalse(os.path.exists(extract_path))\n\n\nclass GetFileTest(test_case.TestCase):\n    def setUp(self):\n        \"\"\"Set up temporary directories and sample files.\"\"\"\n        self.temp_dir = self.get_temp_dir()\n        self.file_path = os.path.join(self.temp_dir, \"sample_file.txt\")\n        with open(self.file_path, \"w\") as f:\n            f.write(\"Sample content\")\n\n    def test_valid_tar_extraction(self):\n        \"\"\"Test valid tar.gz extraction and hash validation.\"\"\"\n        dest_dir = self.get_temp_dir()\n        orig_dir = self.get_temp_dir()\n        _, tar_file_path = self._create_tar_file(orig_dir)\n        self._test_file_extraction_and_validation(\n            dest_dir, tar_file_path, \"tar.gz\"\n        )\n\n    def test_valid_zip_extraction(self):\n 
       \"\"\"Test valid zip extraction and hash validation.\"\"\"\n        dest_dir = self.get_temp_dir()\n        orig_dir = self.get_temp_dir()\n        _, zip_file_path = self._create_zip_file(orig_dir)\n        self._test_file_extraction_and_validation(\n            dest_dir, zip_file_path, \"zip\"\n        )\n\n    def test_valid_text_file_download(self):\n        \"\"\"Test valid text file download and hash validation.\"\"\"\n        dest_dir = self.get_temp_dir()\n        orig_dir = self.get_temp_dir()\n        text_file_path = os.path.join(orig_dir, \"test.txt\")\n        with open(text_file_path, \"w\") as text_file:\n            text_file.write(\"Float like a butterfly, sting like a bee.\")\n        self._test_file_extraction_and_validation(\n            dest_dir, text_file_path, None\n        )\n\n    def test_get_file_with_tgz_extension(self):\n        \"\"\"Test extraction of file with .tar.gz extension.\"\"\"\n        dest_dir = self.get_temp_dir()\n        orig_dir = dest_dir\n        _, tar_file_path = self._create_tar_file(orig_dir)\n\n        origin = urllib.parse.urljoin(\n            \"file://\",\n            urllib.request.pathname2url(os.path.abspath(tar_file_path)),\n        )\n\n        path = file_utils.get_file(\n            \"test.txt.tar.gz\", origin, untar=True, cache_subdir=dest_dir\n        )\n        self.assertTrue(os.path.exists(path))\n        self.assertTrue(os.path.exists(os.path.join(path, \"test.txt\")))\n\n    def test_get_file_with_integrity_check(self):\n        \"\"\"Test file download with integrity check.\"\"\"\n        orig_dir = self.get_temp_dir()\n        file_path = os.path.join(orig_dir, \"test.txt\")\n\n        with open(file_path, \"w\") as text_file:\n            text_file.write(\"Float like a butterfly, sting like a bee.\")\n\n        hashval = file_utils.hash_file(file_path)\n\n        origin = urllib.parse.urljoin(\n            \"file://\", urllib.request.pathname2url(os.path.abspath(file_path))\n        
)\n\n        path = file_utils.get_file(\"test.txt\", origin, file_hash=hashval)\n        self.assertTrue(os.path.exists(path))\n\n    def test_cache_invalidation(self):\n        \"\"\"Test using a hash to force cache invalidation.\"\"\"\n        cache_dir = self.get_temp_dir()\n        src_path = os.path.join(self.get_temp_dir(), \"test.txt\")\n        with open(src_path, \"w\") as text_file:\n            text_file.write(\"Float like a butterfly, sting like a bee.\")\n        orig_hash = file_utils.hash_file(src_path)\n        origin = urllib.parse.urljoin(\n            \"file://\", urllib.request.pathname2url(os.path.abspath(src_path))\n        )\n        # Download into the cache.\n        dest_path = file_utils.get_file(\n            \"test.txt\", origin, file_hash=orig_hash, cache_dir=cache_dir\n        )\n        self.assertEqual(orig_hash, file_utils.hash_file(dest_path))\n\n        with open(src_path, \"w\") as text_file:\n            text_file.write(\"Float like a zeppelin, sting like a jellyfish.\")\n        new_hash = file_utils.hash_file(src_path)\n        # Without a hash, we should get the cached version.\n        dest_path = file_utils.get_file(\"test.txt\", origin, cache_dir=cache_dir)\n        self.assertEqual(orig_hash, file_utils.hash_file(dest_path))\n        # With the new hash, the cached copy is stale, so we should\n        # re-download.\n        dest_path = file_utils.get_file(\n            \"test.txt\", origin, file_hash=new_hash, cache_dir=cache_dir\n        )\n        self.assertEqual(new_hash, file_utils.hash_file(dest_path))\n\n    def test_force_download(self):\n        \"\"\"Test that `force_download=True` bypasses the cache.\"\"\"\n        cache_dir = self.get_temp_dir()\n        src_path = os.path.join(self.get_temp_dir(), \"test.txt\")\n        with open(src_path, \"w\") as text_file:\n            text_file.write(\"Float like a butterfly, sting like a bee.\")\n        orig_hash = file_utils.hash_file(src_path)\n        origin = urllib.parse.urljoin(\n            
\"file://\", urllib.request.pathname2url(os.path.abspath(src_path))\n        )\n        # Download into the cache.\n        dest_path = file_utils.get_file(\"test.txt\", origin, cache_dir=cache_dir)\n        self.assertEqual(orig_hash, file_utils.hash_file(dest_path))\n\n        with open(src_path, \"w\") as text_file:\n            text_file.write(\"Float like a zeppelin, sting like a jellyfish.\")\n        new_hash = file_utils.hash_file(src_path)\n        # Get cached version.\n        dest_path = file_utils.get_file(\"test.txt\", origin, cache_dir=cache_dir)\n        self.assertEqual(orig_hash, file_utils.hash_file(dest_path))\n        # Force download.\n        dest_path = file_utils.get_file(\n            \"test.txt\", origin, force_download=True, cache_dir=cache_dir\n        )\n        self.assertEqual(new_hash, file_utils.hash_file(dest_path))\n\n    def test_get_file_with_failed_integrity_check(self):\n        \"\"\"Test file download with failed integrity check.\"\"\"\n        orig_dir = self.get_temp_dir()\n        file_path = os.path.join(orig_dir, \"test.txt\")\n\n        with open(file_path, \"w\") as text_file:\n            text_file.write(\"Float like a butterfly, sting like a bee.\")\n\n        hashval = \"0\" * 64\n\n        origin = urllib.parse.urljoin(\n            \"file://\", urllib.request.pathname2url(os.path.abspath(file_path))\n        )\n\n        with self.assertRaisesRegex(\n            ValueError, \"Incomplete or corrupted file.*\"\n        ):\n            _ = file_utils.get_file(\"test.txt\", origin, file_hash=hashval)\n\n    def _create_tar_file(self, directory):\n        \"\"\"Helper function to create a tar file.\"\"\"\n        text_file_path = os.path.join(directory, \"test.txt\")\n        tar_file_path = os.path.join(directory, \"test.tar.gz\")\n        with open(text_file_path, \"w\") as text_file:\n            text_file.write(\"Float like a butterfly, sting like a bee.\")\n\n        with tarfile.open(tar_file_path, \"w:gz\") as 
tar_file:\n            tar_file.add(text_file_path, arcname=\"test.txt\")\n\n        return text_file_path, tar_file_path\n\n    def _create_zip_file(self, directory):\n        \"\"\"Helper function to create a zip file.\"\"\"\n        text_file_path = os.path.join(directory, \"test.txt\")\n        zip_file_path = os.path.join(directory, \"test.zip\")\n        with open(text_file_path, \"w\") as text_file:\n            text_file.write(\"Float like a butterfly, sting like a bee.\")\n\n        with zipfile.ZipFile(zip_file_path, \"w\") as zip_file:\n            zip_file.write(text_file_path, arcname=\"test.txt\")\n\n        return text_file_path, zip_file_path\n\n    def _test_file_extraction_and_validation(\n        self, dest_dir, file_path, archive_type\n    ):\n        \"\"\"Helper function for file extraction and validation.\"\"\"\n        origin = urllib.parse.urljoin(\n            \"file://\",\n            urllib.request.pathname2url(os.path.abspath(file_path)),\n        )\n\n        hashval_md5 = file_utils.hash_file(file_path, algorithm=\"md5\")\n\n        extract = bool(archive_type)\n\n        path = file_utils.get_file(\n            \"test\",\n            origin,\n            md5_hash=hashval_md5,\n            extract=extract,\n            cache_subdir=dest_dir,\n        )\n        if extract:\n            fpath = f\"{path}_archive\"\n        else:\n            fpath = path\n\n        self.assertTrue(os.path.exists(path))\n        self.assertTrue(file_utils.validate_file(fpath, hashval_md5))\n        if extract:\n            self.assertTrue(os.path.exists(os.path.join(path, \"test.txt\")))\n\n    def test_exists(self):\n        temp_dir = self.get_temp_dir()\n        file_path = os.path.join(temp_dir, \"test_exists.txt\")\n\n        with open(file_path, \"w\") as f:\n            f.write(\"test\")\n\n        self.assertTrue(file_utils.exists(file_path))\n        self.assertFalse(\n            file_utils.exists(os.path.join(temp_dir, 
\"non_existent.txt\"))\n        )\n\n    def test_file_open_read(self):\n        temp_dir = self.get_temp_dir()\n        file_path = os.path.join(temp_dir, \"test_file.txt\")\n        content = \"test content\"\n\n        with open(file_path, \"w\") as f:\n            f.write(content)\n\n        with file_utils.File(file_path, \"r\") as f:\n            self.assertEqual(f.read(), content)\n\n    def test_file_open_write(self):\n        temp_dir = self.get_temp_dir()\n        file_path = os.path.join(temp_dir, \"test_file_write.txt\")\n        content = \"test write content\"\n\n        with file_utils.File(file_path, \"w\") as f:\n            f.write(content)\n\n        with open(file_path, \"r\") as f:\n            self.assertEqual(f.read(), content)\n\n    def test_isdir(self):\n        temp_dir = self.get_temp_dir()\n        self.assertTrue(file_utils.isdir(temp_dir))\n\n        file_path = os.path.join(temp_dir, \"test_isdir.txt\")\n        with open(file_path, \"w\") as f:\n            f.write(\"test\")\n        self.assertFalse(file_utils.isdir(file_path))\n\n    def test_join_simple(self):\n        self.assertEqual(file_utils.join(\"/path\", \"to\", \"dir\"), \"/path/to/dir\")\n\n    def test_join_single_directory(self):\n        self.assertEqual(file_utils.join(\"/path\"), \"/path\")\n\n    def test_listdir(self):\n        content = file_utils.listdir(self.temp_dir)\n        self.assertIn(\"sample_file.txt\", content)\n\n    def test_makedirs_and_rmtree(self):\n        new_dir = os.path.join(self.temp_dir, \"new_directory\")\n        file_utils.makedirs(new_dir)\n        self.assertTrue(os.path.isdir(new_dir))\n        file_utils.rmtree(new_dir)\n        self.assertFalse(os.path.exists(new_dir))\n\n    def test_copy(self):\n        dest_path = os.path.join(self.temp_dir, \"copy_sample_file.txt\")\n        file_utils.copy(self.file_path, dest_path)\n        self.assertTrue(os.path.exists(dest_path))\n        with open(dest_path, \"r\") as f:\n            
content = f.read()\n        self.assertEqual(content, \"Sample content\")\n\n    def test_remove_sub_directory(self):\n        parent_dir = os.path.join(self.get_temp_dir(), \"parent_directory\")\n        child_dir = os.path.join(parent_dir, \"child_directory\")\n        file_utils.makedirs(child_dir)\n        file_utils.rmtree(parent_dir)\n        self.assertFalse(os.path.exists(parent_dir))\n        self.assertFalse(os.path.exists(child_dir))\n\n    def test_remove_files_inside_directory(self):\n        dir_path = os.path.join(self.get_temp_dir(), \"test_directory\")\n        file_path = os.path.join(dir_path, \"test.txt\")\n        file_utils.makedirs(dir_path)\n        with open(file_path, \"w\") as f:\n            f.write(\"Test content\")\n        file_utils.rmtree(dir_path)\n        self.assertFalse(os.path.exists(dir_path))\n        self.assertFalse(os.path.exists(file_path))\n\n    def test_handle_complex_paths(self):\n        complex_dir = os.path.join(self.get_temp_dir(), \"complex dir@#%&!\")\n        file_utils.makedirs(complex_dir)\n        file_utils.rmtree(complex_dir)\n        self.assertFalse(os.path.exists(complex_dir))\n\n\nclass HashFileTest(test_case.TestCase):\n    def setUp(self):\n        self.test_content = b\"Hello, World!\"\n        self.temp_file = os.path.join(self.get_temp_dir(), \"test_file.txt\")\n        with open(self.temp_file, \"wb\") as f:\n            f.write(self.test_content)\n\n    def test_hash_file_sha256(self):\n        \"\"\"Test SHA256 hashing of a file.\"\"\"\n        expected_sha256 = (\n            \"dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f\"\n        )\n        calculated_sha256 = file_utils.hash_file(\n            self.temp_file, algorithm=\"sha256\"\n        )\n        self.assertEqual(expected_sha256, calculated_sha256)\n\n    def test_hash_file_md5(self):\n        \"\"\"Test MD5 hashing of a file.\"\"\"\n        expected_md5 = \"65a8e27d8879283831b664bd8b7f0ad4\"\n        calculated_md5 
= file_utils.hash_file(self.temp_file, algorithm=\"md5\")\n        self.assertEqual(expected_md5, calculated_md5)\n\n\nclass TestValidateFile(test_case.TestCase):\n    def setUp(self):\n        self.temp_file = os.path.join(self.get_temp_dir(), \"test_file.txt\")\n        with open(self.temp_file, \"wb\") as f:\n            f.write(b\"Hello, World!\")\n\n        self.sha256_hash = (\n            \"dffd6021bb2bd5b0af676290809ec3a53191dd81c7f70a4b28688a362182986f\"\n        )\n        self.md5_hash = \"65a8e27d8879283831b664bd8b7f0ad4\"\n\n    def test_validate_file_sha256(self):\n        \"\"\"Validate SHA256 hash of a file.\"\"\"\n        self.assertTrue(\n            file_utils.validate_file(self.temp_file, self.sha256_hash, \"sha256\")\n        )\n\n    def test_validate_file_md5(self):\n        \"\"\"Validate MD5 hash of a file.\"\"\"\n        self.assertTrue(\n            file_utils.validate_file(self.temp_file, self.md5_hash, \"md5\")\n        )\n\n    def test_validate_file_auto_sha256(self):\n        \"\"\"Auto-detect and validate SHA256 hash.\"\"\"\n        self.assertTrue(\n            file_utils.validate_file(self.temp_file, self.sha256_hash, \"auto\")\n        )\n\n    def test_validate_file_auto_md5(self):\n        \"\"\"Auto-detect and validate MD5 hash.\"\"\"\n        self.assertTrue(\n            file_utils.validate_file(self.temp_file, self.md5_hash, \"auto\")\n        )\n\n    def test_validate_file_wrong_hash(self):\n        \"\"\"Test validation with incorrect hash.\"\"\"\n        wrong_hash = \"deadbeef\" * 8\n        self.assertFalse(\n            file_utils.validate_file(self.temp_file, wrong_hash, \"sha256\")\n        )\n\n\nclass ResolveHasherTest(test_case.TestCase):\n    def test_resolve_hasher_sha256(self):\n        \"\"\"Test resolving hasher for sha256 algorithm.\"\"\"\n        hasher = file_utils.resolve_hasher(\"sha256\")\n        self.assertIsInstance(hasher, type(hashlib.sha256()))\n\n    def test_resolve_hasher_auto_sha256(self):\n 
       \"\"\"Auto-detect and resolve hasher for sha256.\"\"\"\n        hasher = file_utils.resolve_hasher(\"auto\", file_hash=\"a\" * 64)\n        self.assertIsInstance(hasher, type(hashlib.sha256()))\n\n    def test_resolve_hasher_auto_md5(self):\n        \"\"\"Auto-detect and resolve hasher for md5.\"\"\"\n        hasher = file_utils.resolve_hasher(\"auto\", file_hash=\"a\" * 32)\n        self.assertIsInstance(hasher, type(hashlib.md5()))\n\n    def test_resolve_hasher_default(self):\n        \"\"\"Resolve hasher with a random algorithm value.\"\"\"\n        hasher = file_utils.resolve_hasher(\"random_value\")\n        self.assertIsInstance(hasher, type(hashlib.md5()))\n\n\nclass IsRemotePathTest(test_case.TestCase):\n    def test_gcs_remote_path(self):\n        self.assertTrue(file_utils.is_remote_path(\"/gcs/some/path/to/file.txt\"))\n        self.assertTrue(file_utils.is_remote_path(\"/gcs/another/directory/\"))\n        self.assertTrue(file_utils.is_remote_path(\"gcs://bucket/some/file.txt\"))\n\n    def test_hdfs_remote_path(self):\n        self.assertTrue(file_utils.is_remote_path(\"hdfs://some/path/on/hdfs\"))\n        self.assertTrue(file_utils.is_remote_path(\"/hdfs/some/local/path\"))\n\n    def test_cns_remote_path(self):\n        self.assertTrue(file_utils.is_remote_path(\"/cns/some/path\"))\n\n    def test_placer_remote_path(self):\n        self.assertTrue(\n            file_utils.is_remote_path(\"/placer/prod/home/some/path\")\n        )\n        self.assertTrue(\n            file_utils.is_remote_path(\"/placer/test/home/some/path\")\n        )\n        self.assertTrue(\n            file_utils.is_remote_path(\"/placer/prod/scratch/home/some/path\")\n        )\n\n    def test_tfhub_remote_path(self):\n        self.assertTrue(file_utils.is_remote_path(\"/tfhub/some/path\"))\n\n    def test_cfs_remote_path(self):\n        self.assertTrue(file_utils.is_remote_path(\"/cfs/some/path\"))\n\n    def test_readahead_remote_path(self):\n        
self.assertTrue(file_utils.is_remote_path(\"/readahead/some/path\"))\n\n    def test_non_remote_paths(self):\n        self.assertFalse(file_utils.is_remote_path(\"/local/path/to/file.txt\"))\n        self.assertFalse(\n            file_utils.is_remote_path(\"C:\\\\local\\\\path\\\\on\\\\windows\\\\file.txt\")\n        )\n        self.assertFalse(file_utils.is_remote_path(\"~/relative/path/\"))\n        self.assertFalse(file_utils.is_remote_path(\"./another/relative/path\"))\n        self.assertFalse(file_utils.is_remote_path(\"/local/path\"))\n        self.assertFalse(file_utils.is_remote_path(\"./relative/path\"))\n        self.assertFalse(file_utils.is_remote_path(\"~/relative/path\"))\n\n\nclass TestRaiseIfNoGFile(test_case.TestCase):\n    def test_raise_if_no_gfile_raises_correct_message(self):\n        path = \"gs://bucket/some/file.txt\"\n        expected_error_msg = (\n            \"Handling remote paths requires installing TensorFlow \"\n            f\".*Received path: {path}\"\n        )\n        with self.assertRaisesRegex(ValueError, expected_error_msg):\n            file_utils._raise_if_no_gfile(path)\n"
  },
  {
    "path": "keras/src/utils/grain_utils.py",
    "content": "from keras.src import backend\nfrom keras.src import tree\n\n\ndef make_batch(values):\n    from keras.src import ops\n\n    if not values:\n        raise ValueError(\"Cannot batch 0 values. Please file a bug.\")\n\n    with backend.device_scope(\"cpu\"):\n        return tree.map_structure(lambda *xs: ops.stack(xs), *values)\n\n\ndef make_string_batch(values):\n    from keras.src import ops\n\n    if not values:\n        raise ValueError(\"Cannot batch 0 values. Please file a bug.\")\n\n    def batch_fn(*xs):\n        if isinstance(xs[0], str):\n            if backend.backend() == \"tensorflow\":\n                import tensorflow as tf\n\n                xs = [tf.convert_to_tensor(x, dtype=tf.string) for x in xs]\n                xs = tf.stack(xs)\n            return xs\n        else:\n            return ops.stack(xs)\n\n    with backend.device_scope(\"cpu\"):\n        return tree.map_structure(batch_fn, *values)\n"
  },
  {
    "path": "keras/src/utils/image_dataset_utils.py",
    "content": "import io\nimport pathlib\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.config import standardize_data_format\nfrom keras.src.utils import dataset_utils\nfrom keras.src.utils import image_utils\nfrom keras.src.utils.grain_utils import make_batch\nfrom keras.src.utils.module_utils import grain\nfrom keras.src.utils.module_utils import tensorflow as tf\n\ntry:\n    from PIL import Image as pil_image\n\n    try:\n        pil_image_resampling = pil_image.Resampling\n    except AttributeError:\n        pil_image_resampling = pil_image\nexcept ImportError:\n    pil_image = None\n    pil_image_resampling = None\n\nALLOWLIST_FORMATS = (\".bmp\", \".gif\", \".jpeg\", \".jpg\", \".png\")\n\n\n@keras_export(\n    [\n        \"keras.utils.image_dataset_from_directory\",\n        \"keras.preprocessing.image_dataset_from_directory\",\n    ]\n)\ndef image_dataset_from_directory(\n    directory,\n    labels=\"inferred\",\n    label_mode=\"int\",\n    class_names=None,\n    color_mode=\"rgb\",\n    batch_size=32,\n    image_size=(256, 256),\n    shuffle=True,\n    seed=None,\n    validation_split=None,\n    subset=None,\n    interpolation=\"bilinear\",\n    follow_links=False,\n    crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n    data_format=None,\n    format=\"tf\",\n    verbose=True,\n):\n    \"\"\"Generates a dataset from image files in a directory.\n\n    If your directory structure is:\n\n    ```\n    main_directory/\n    ...class_a/\n    ......a_image_1.jpg\n    ......a_image_2.jpg\n    ...class_b/\n    ......b_image_1.jpg\n    ......b_image_2.jpg\n    ```\n\n    Then calling `image_dataset_from_directory(main_directory,\n    labels='inferred')` will return a dataset that yields batches of\n    images from the subdirectories `class_a` and `class_b`, together with labels\n    0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).\n\n    Supported image formats: `.jpeg`, `.jpg`, 
`.png`, `.bmp`, `.gif`.\n    Animated gifs are truncated to the first frame.\n\n    By default, this function will return a `tf.data.Dataset` object. You can\n    set `format=\"grain\"` to return a `grain.IterDataset` object instead, which\n    removes the TensorFlow dependency.\n\n    Args:\n        directory: Directory where the data is located.\n            If `labels` is `\"inferred\"`, it should contain\n            subdirectories, each containing images for a class.\n            Otherwise, the directory structure is ignored.\n        labels: Either `\"inferred\"`\n            (labels are generated from the directory structure),\n            `None` (no labels),\n            or a list/tuple of integer labels of the same size as the number of\n            image files found in the directory. Labels should be sorted\n            according to the alphanumeric order of the image file paths\n            (obtained via `os.walk(directory)` in Python).\n        label_mode: String describing the encoding of `labels`. Options are:\n            - `\"int\"`: means that the labels are encoded as integers\n                (e.g. for `sparse_categorical_crossentropy` loss).\n            - `\"categorical\"` means that the labels are\n                encoded as a categorical vector\n                (e.g. for `categorical_crossentropy` loss).\n            - `\"binary\"` means that the labels (there can be only 2)\n                are encoded as `float32` scalars with values 0 or 1\n                (e.g. for `binary_crossentropy`).\n            - `None` (no labels).\n        class_names: Only valid if `labels` is `\"inferred\"`.\n            This is the explicit list of class names\n            (must match names of subdirectories). 
Used to control the order\n            of the classes (otherwise alphanumerical order is used).\n        color_mode: One of `\"grayscale\"`, `\"rgb\"`, `\"rgba\"`.\n            Whether the images will be converted to\n            have 1, 3, or 4 channels. Defaults to `\"rgb\"`.\n        batch_size: Size of the batches of data. Defaults to 32.\n            If `None`, the data will not be batched\n            (the dataset will yield individual samples).\n        image_size: Size to resize images to after they are read from disk,\n            specified as `(height, width)`.\n            Since the pipeline processes batches of images that must all have\n            the same size, this must be provided. Defaults to `(256, 256)`.\n        shuffle: Whether to shuffle the data. Defaults to `True`.\n            If set to `False`, sorts the data in alphanumeric order.\n        seed: Optional random seed for shuffling and transformations.\n        validation_split: Optional float between 0 and 1,\n            fraction of data to reserve for validation.\n        subset: Subset of the data to return.\n            One of `\"training\"`, `\"validation\"`, or `\"both\"`.\n            Only used if `validation_split` is set.\n            When `subset=\"both\"`, the utility returns a tuple of two datasets\n            (the training and validation datasets respectively).\n        interpolation: String, the interpolation method used when\n            resizing images.\n            Supports `\"bilinear\"`, `\"nearest\"`, `\"bicubic\"`, `\"area\"`,\n            `\"lanczos3\"`, `\"lanczos5\"`, `\"gaussian\"`, `\"mitchellcubic\"`.\n            Defaults to `\"bilinear\"`.\n        follow_links: Whether to visit subdirectories pointed to by symlinks.\n            Defaults to `False`.\n        crop_to_aspect_ratio: If `True`, resize the images without aspect\n            ratio distortion. 
When the original aspect ratio differs from the\n            target aspect ratio, the output image will be cropped so as to\n            return the largest possible window in the image\n            (of size `image_size`) that matches the target aspect ratio. By\n            default (`crop_to_aspect_ratio=False`), aspect ratio may not be\n            preserved.\n        pad_to_aspect_ratio: If `True`, resize the images without aspect\n            ratio distortion. When the original aspect ratio differs from the\n            target aspect ratio, the output image will be padded so as to\n            return the largest possible window in the image\n            (of size `image_size`) that matches the target aspect ratio. By\n            default (`pad_to_aspect_ratio=False`), aspect ratio may not be\n            preserved.\n        data_format: If `None`, defaults to `keras.config.image_data_format()`.\n            Otherwise, one of `\"channels_last\"` or `\"channels_first\"`.\n        format: The format of the return object. Defaults to `\"tf\"`. Available\n            options are:\n            - `\"tf\"`: returns a `tf.data.Dataset` object. Requires\n                TensorFlow to be installed.\n            - `\"grain\"`: returns a `grain.IterDataset` object. Requires\n                Grain to be installed.\n        verbose: Whether to display the number of classes and\n            the number of files found. 
Defaults to `True`.\n\n    Returns:\n\n    A `tf.data.Dataset` (`format=\"tf\"`) or `grain.IterDataset`\n    (`format=\"grain\"`) object.\n\n    - If `label_mode` is `None`, it yields `float32` tensors of shape\n        `(batch_size, image_size[0], image_size[1], num_channels)`,\n        encoding images (see below for rules regarding `num_channels`).\n    - Otherwise, it yields a tuple `(images, labels)`, where `images` has\n        shape `(batch_size, image_size[0], image_size[1], num_channels)`,\n        and `labels` follows the format described below.\n\n    Rules regarding labels format:\n\n    - if `label_mode` is `\"int\"`, the labels are an `int32` tensor of shape\n        `(batch_size,)`.\n    - if `label_mode` is `\"binary\"`, the labels are a `float32` tensor of\n        1s and 0s of shape `(batch_size, 1)`.\n    - if `label_mode` is `\"categorical\"`, the labels are a `float32` tensor\n        of shape `(batch_size, num_classes)`, representing a one-hot\n        encoding of the class index.\n\n    Rules regarding number of channels in the yielded images:\n\n    - if `color_mode` is `\"grayscale\"`,\n        there's 1 channel in the image tensors.\n    - if `color_mode` is `\"rgb\"`,\n        there are 3 channels in the image tensors.\n    - if `color_mode` is `\"rgba\"`,\n        there are 4 channels in the image tensors.\n    \"\"\"\n\n    if labels not in (\"inferred\", None):\n        if not isinstance(labels, (list, tuple)):\n            raise ValueError(\n                \"`labels` argument should be a list/tuple of integer labels, \"\n                \"of the same size as the number of image files in the target \"\n                \"directory. If you wish to infer the labels from the \"\n                \"subdirectory \"\n                'names in the target directory, pass `labels=\"inferred\"`. '\n                \"If you wish to get a dataset that only contains images \"\n                f\"(no labels), pass `labels=None`. 
Received: labels={labels}\"\n            )\n        if class_names:\n            raise ValueError(\n                \"You can only pass `class_names` if \"\n                f'`labels=\"inferred\"`. Received: labels={labels}, and '\n                f\"class_names={class_names}\"\n            )\n    if label_mode not in {\"int\", \"categorical\", \"binary\", None}:\n        raise ValueError(\n            '`label_mode` argument must be one of \"int\", '\n            '\"categorical\", \"binary\", '\n            f\"or None. Received: label_mode={label_mode}\"\n        )\n    if labels is None or label_mode is None:\n        labels = None\n        label_mode = None\n    if color_mode == \"rgb\":\n        num_channels = 3\n    elif color_mode == \"rgba\":\n        num_channels = 4\n    elif color_mode == \"grayscale\":\n        num_channels = 1\n    else:\n        raise ValueError(\n            '`color_mode` must be one of {\"rgb\", \"rgba\", \"grayscale\"}. '\n            f\"Received: color_mode={color_mode}\"\n        )\n\n    if isinstance(image_size, int):\n        image_size = (image_size, image_size)\n    elif not isinstance(image_size, (list, tuple)) or not len(image_size) == 2:\n        raise ValueError(\n            \"Invalid `image_size` value. Expected a tuple of 2 integers. \"\n            f\"Received: image_size={image_size}\"\n        )\n\n    interpolation = interpolation.lower()\n    supported_interpolations = (\n        \"bilinear\",\n        \"nearest\",\n        \"bicubic\",\n        \"area\",\n        \"lanczos3\",\n        \"lanczos5\",\n        \"gaussian\",\n        \"mitchellcubic\",\n    )\n    if interpolation not in supported_interpolations:\n        raise ValueError(\n            \"Argument `interpolation` should be one of \"\n            f\"{supported_interpolations}. 
\"\n            f\"Received: interpolation={interpolation}\"\n        )\n    if format not in (\"tf\", \"grain\"):\n        raise ValueError(\n            '`format` should be either \"tf\" or \"grain\". '\n            f\"Received: format={format}\"\n        )\n\n    dataset_utils.check_validation_split_arg(\n        validation_split, subset, shuffle, seed\n    )\n\n    if seed is None:\n        seed = np.random.randint(1e6)\n    image_paths, labels, class_names = dataset_utils.index_directory(\n        directory,\n        labels,\n        formats=ALLOWLIST_FORMATS,\n        class_names=class_names,\n        shuffle=shuffle,\n        seed=seed,\n        follow_links=follow_links,\n        verbose=verbose,\n    )\n\n    if label_mode == \"binary\" and len(class_names) != 2:\n        raise ValueError(\n            'When passing `label_mode=\"binary\"`, there must be exactly 2 '\n            f\"class_names. Received: class_names={class_names}\"\n        )\n\n    data_format = standardize_data_format(data_format=data_format)\n    if batch_size is not None:\n        shuffle_buffer_size = batch_size * 8\n    else:\n        shuffle_buffer_size = 1024\n\n    if subset == \"both\":\n        (\n            image_paths_train,\n            labels_train,\n        ) = dataset_utils.get_training_or_validation_split(\n            image_paths, labels, validation_split, \"training\"\n        )\n        (\n            image_paths_val,\n            labels_val,\n        ) = dataset_utils.get_training_or_validation_split(\n            image_paths, labels, validation_split, \"validation\"\n        )\n        if not image_paths_train:\n            raise ValueError(\n                f\"No training images found in directory {directory}. \"\n                f\"Allowed formats: {ALLOWLIST_FORMATS}\"\n            )\n        if not image_paths_val:\n            raise ValueError(\n                f\"No validation images found in directory {directory}. 
\"\n                f\"Allowed formats: {ALLOWLIST_FORMATS}\"\n            )\n        train_dataset = paths_and_labels_to_dataset(\n            image_paths=image_paths_train,\n            image_size=image_size,\n            num_channels=num_channels,\n            labels=labels_train,\n            label_mode=label_mode,\n            num_classes=len(class_names) if class_names else 0,\n            interpolation=interpolation,\n            crop_to_aspect_ratio=crop_to_aspect_ratio,\n            pad_to_aspect_ratio=pad_to_aspect_ratio,\n            data_format=data_format,\n            shuffle=shuffle,\n            shuffle_buffer_size=shuffle_buffer_size,\n            seed=seed,\n            format=format,\n        )\n\n        val_dataset = paths_and_labels_to_dataset(\n            image_paths=image_paths_val,\n            image_size=image_size,\n            num_channels=num_channels,\n            labels=labels_val,\n            label_mode=label_mode,\n            num_classes=len(class_names) if class_names else 0,\n            interpolation=interpolation,\n            crop_to_aspect_ratio=crop_to_aspect_ratio,\n            pad_to_aspect_ratio=pad_to_aspect_ratio,\n            data_format=data_format,\n            shuffle=False,\n            format=format,\n        )\n\n        if format == \"tf\":\n            if batch_size is not None:\n                train_dataset = train_dataset.batch(batch_size)\n                val_dataset = val_dataset.batch(batch_size)\n            train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)\n            val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)\n        else:\n            train_dataset = train_dataset.to_iter_dataset()\n            val_dataset = val_dataset.to_iter_dataset()\n            if batch_size is not None:\n                train_dataset = train_dataset.batch(\n                    batch_size, batch_fn=make_batch\n                )\n                val_dataset = val_dataset.batch(batch_size, 
batch_fn=make_batch)\n\n        # Users may need to reference `class_names`.\n        train_dataset.class_names = class_names\n        val_dataset.class_names = class_names\n\n        # Include file paths for images as attribute.\n        train_dataset.file_paths = image_paths_train\n        val_dataset.file_paths = image_paths_val\n\n        dataset = [train_dataset, val_dataset]\n    else:\n        image_paths, labels = dataset_utils.get_training_or_validation_split(\n            image_paths, labels, validation_split, subset\n        )\n        if not image_paths:\n            raise ValueError(\n                f\"No images found in directory {directory}. \"\n                f\"Allowed formats: {ALLOWLIST_FORMATS}\"\n            )\n\n        dataset = paths_and_labels_to_dataset(\n            image_paths=image_paths,\n            image_size=image_size,\n            num_channels=num_channels,\n            labels=labels,\n            label_mode=label_mode,\n            num_classes=len(class_names) if class_names else 0,\n            interpolation=interpolation,\n            crop_to_aspect_ratio=crop_to_aspect_ratio,\n            pad_to_aspect_ratio=pad_to_aspect_ratio,\n            data_format=data_format,\n            shuffle=shuffle,\n            shuffle_buffer_size=shuffle_buffer_size,\n            seed=seed,\n            format=format,\n        )\n\n        if format == \"tf\":\n            if batch_size is not None:\n                dataset = dataset.batch(batch_size)\n            dataset = dataset.prefetch(tf.data.AUTOTUNE)\n        else:\n            dataset = dataset.to_iter_dataset()\n            if batch_size is not None:\n                dataset = dataset.batch(batch_size, batch_fn=make_batch)\n\n        # Users may need to reference `class_names`.\n        dataset.class_names = class_names\n\n        # Include file paths for images as attribute.\n        dataset.file_paths = image_paths\n\n    return dataset\n\n\ndef paths_and_labels_to_dataset(\n    
image_paths,\n    image_size,\n    num_channels,\n    labels,\n    label_mode,\n    num_classes,\n    interpolation,\n    data_format,\n    crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n    shuffle=False,\n    shuffle_buffer_size=None,\n    seed=None,\n    format=\"tf\",\n):\n    \"\"\"Constructs a dataset of images and labels.\"\"\"\n    if format == \"tf\":\n        return _paths_and_labels_to_dataset_tf(\n            image_paths=image_paths,\n            image_size=image_size,\n            num_channels=num_channels,\n            labels=labels,\n            label_mode=label_mode,\n            num_classes=num_classes,\n            interpolation=interpolation,\n            data_format=data_format,\n            crop_to_aspect_ratio=crop_to_aspect_ratio,\n            pad_to_aspect_ratio=pad_to_aspect_ratio,\n            shuffle=shuffle,\n            shuffle_buffer_size=shuffle_buffer_size,\n            seed=seed,\n        )\n    elif format == \"grain\":\n        return _paths_and_labels_to_dataset_grain(\n            image_paths=image_paths,\n            image_size=image_size,\n            num_channels=num_channels,\n            labels=labels,\n            label_mode=label_mode,\n            num_classes=num_classes,\n            interpolation=interpolation,\n            data_format=data_format,\n            crop_to_aspect_ratio=crop_to_aspect_ratio,\n            pad_to_aspect_ratio=pad_to_aspect_ratio,\n            shuffle=shuffle,\n            seed=seed,\n        )\n    else:\n        raise ValueError(\n            '`format` should be either \"tf\" or \"grain\". 
'\n            f\"Received: format={format}\"\n        )\n\n\ndef _paths_and_labels_to_dataset_tf(\n    image_paths,\n    image_size,\n    num_channels,\n    labels,\n    label_mode,\n    num_classes,\n    interpolation,\n    data_format,\n    crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n    shuffle=False,\n    shuffle_buffer_size=None,\n    seed=None,\n):\n    \"\"\"Constructs a dataset of images and labels.\"\"\"\n    path_ds = tf.data.Dataset.from_tensor_slices(image_paths)\n    if label_mode:\n        label_ds = dataset_utils.labels_to_dataset_tf(\n            labels, label_mode, num_classes\n        )\n        ds = tf.data.Dataset.zip((path_ds, label_ds))\n    else:\n        ds = path_ds\n\n    if shuffle:\n        ds = ds.shuffle(buffer_size=shuffle_buffer_size or 1024, seed=seed)\n\n    args = (\n        image_size,\n        num_channels,\n        interpolation,\n        data_format,\n        crop_to_aspect_ratio,\n        pad_to_aspect_ratio,\n    )\n    if label_mode:\n        ds = ds.map(\n            lambda x, y: (_load_image_tf(x, *args), y),\n            num_parallel_calls=tf.data.AUTOTUNE,\n        )\n    else:\n        ds = ds.map(\n            lambda x: _load_image_tf(x, *args),\n            num_parallel_calls=tf.data.AUTOTUNE,\n        )\n    return ds\n\n\ndef _load_image_tf(\n    path,\n    image_size,\n    num_channels,\n    interpolation,\n    data_format,\n    crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n):\n    \"\"\"Load an image from a path and resize it.\"\"\"\n    img = tf.io.read_file(path)\n    img = tf.image.decode_image(\n        img, channels=num_channels, expand_animations=False\n    )\n\n    if pad_to_aspect_ratio and crop_to_aspect_ratio:\n        raise ValueError(\n            \"Only one of `pad_to_aspect_ratio`, `crop_to_aspect_ratio`\"\n            \" can be set to `True`.\"\n        )\n\n    if crop_to_aspect_ratio:\n        from keras.src.backend import tensorflow as tf_backend\n\n        if 
data_format == \"channels_first\":\n            img = tf.transpose(img, (2, 0, 1))\n        img = image_utils.smart_resize(\n            img,\n            image_size,\n            interpolation=interpolation,\n            data_format=data_format,\n            backend_module=tf_backend,\n        )\n    elif pad_to_aspect_ratio:\n        img = tf.image.resize_with_pad(\n            img, image_size[0], image_size[1], method=interpolation\n        )\n        if data_format == \"channels_first\":\n            img = tf.transpose(img, (2, 0, 1))\n    else:\n        img = tf.image.resize(img, image_size, method=interpolation)\n        if data_format == \"channels_first\":\n            img = tf.transpose(img, (2, 0, 1))\n\n    if data_format == \"channels_last\":\n        img.set_shape((image_size[0], image_size[1], num_channels))\n    else:\n        img.set_shape((num_channels, image_size[0], image_size[1]))\n    return img\n\n\ndef _paths_and_labels_to_dataset_grain(\n    image_paths,\n    image_size,\n    num_channels,\n    labels,\n    label_mode,\n    num_classes,\n    interpolation,\n    data_format,\n    crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n    shuffle=False,\n    seed=None,\n):\n    \"\"\"Constructs a dataset of images and labels.\"\"\"\n    path_ds = grain.MapDataset.source(image_paths)\n    if label_mode:\n        label_ds = dataset_utils.labels_to_dataset_grain(\n            labels, label_mode, num_classes\n        )\n        ds = grain.experimental.ZipMapDataset([path_ds, label_ds])\n    else:\n        ds = path_ds\n\n    if shuffle:\n        ds = ds.shuffle(seed=seed)\n\n    args = (\n        image_size,\n        num_channels,\n        interpolation,\n        data_format,\n        crop_to_aspect_ratio,\n        pad_to_aspect_ratio,\n    )\n    if label_mode:\n        ds = ds.map(lambda data: (_load_image_grain(data[0], *args), data[1]))\n    else:\n        ds = ds.map(lambda x: _load_image_grain(x, *args))\n\n    return ds\n\n\ndef 
_load_image_grain(\n    path,\n    image_size,\n    num_channels,\n    interpolation,\n    data_format,\n    crop_to_aspect_ratio=False,\n    pad_to_aspect_ratio=False,\n):\n    \"\"\"Load an image from a path and resize it.\"\"\"\n    from keras.src import backend\n    from keras.src import ops\n\n    if pil_image is None:\n        raise ImportError(\n            \"Could not import PIL.Image. Loading images requires PIL.\"\n        )\n    if pad_to_aspect_ratio and crop_to_aspect_ratio:\n        raise ValueError(\n            \"Only one of `pad_to_aspect_ratio`, `crop_to_aspect_ratio`\"\n            \" can be set to `True`.\"\n        )\n\n    if isinstance(path, io.BytesIO):\n        img = pil_image.open(path)\n    elif isinstance(path, (pathlib.Path, bytes, str)):\n        if isinstance(path, pathlib.Path):\n            path = str(path.resolve())\n        img = pil_image.open(path)\n    else:\n        raise TypeError(\n            f\"path should be path-like or io.BytesIO, not {type(path)}\"\n        )\n    if num_channels == 1:\n        # If the image is not already an 8-bit, 16-bit or 32-bit grayscale\n        # image, convert it to an 8-bit grayscale image.\n        if img.mode not in (\"L\", \"I;16\", \"I\"):\n            img = img.convert(\"L\")\n    elif num_channels == 4:\n        if img.mode != \"RGBA\":\n            img = img.convert(\"RGBA\")\n    elif num_channels == 3:\n        if img.mode != \"RGB\":\n            img = img.convert(\"RGB\")\n    else:\n        raise ValueError(\n            \"num_channels must be 1, 3 or 4. \"\n            f\"Received: num_channels={num_channels}\"\n        )\n\n    with backend.device_scope(\"cpu\"):\n        img = ops.convert_to_tensor(np.array(img), dtype=\"float32\")\n        if len(img.shape) == 2:\n            # If the image is grayscale, expand dims to add channel axis.\n            # The reason is that `ops.image.resize` expects 3D or 4D tensors.\n            img = ops.expand_dims(img, axis=-1)\n        if data_format == \"channels_first\":\n            img = ops.transpose(img, (2, 0, 1))\n        img = ops.image.resize(\n            img,\n            size=image_size,\n            interpolation=interpolation,\n            crop_to_aspect_ratio=crop_to_aspect_ratio,\n            pad_to_aspect_ratio=pad_to_aspect_ratio,\n            data_format=data_format,\n        )\n        if backend.backend() == \"tensorflow\":\n            if data_format == \"channels_last\":\n                img.set_shape((image_size[0], image_size[1], num_channels))\n            else:\n                img.set_shape((num_channels, image_size[0], image_size[1]))\n    return img\n"
  },
  {
    "path": "keras/src/utils/image_dataset_utils_test.py",
    "content": "import os\n\nimport numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.utils import image_dataset_utils\nfrom keras.src.utils import image_utils\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\nclass ImageDatasetFromDirectoryTest(testing.TestCase):\n    def _get_images(self, count=16, color_mode=\"rgb\"):\n        width = height = 24\n        imgs = []\n        for _ in range(count):\n            if color_mode == \"grayscale\":\n                img = np.random.randint(0, 256, size=(height, width, 1))\n            elif color_mode == \"rgba\":\n                img = np.random.randint(0, 256, size=(height, width, 4))\n            else:\n                img = np.random.randint(0, 256, size=(height, width, 3))\n            if backend.config.image_data_format() == \"channels_first\":\n                img = np.transpose(img, (2, 0, 1))\n            img = image_utils.array_to_img(img)\n            imgs.append(img)\n        return imgs\n\n    def _prepare_directory(\n        self,\n        num_classes=2,\n        nested_dirs=False,\n        color_mode=\"rgb\",\n        count=16,\n    ):\n        # Generate paths to class subdirectories\n        temp_dir = self.get_temp_dir()\n        paths = []\n        for class_index in range(num_classes):\n            class_directory = f\"class_{class_index}\"\n            if nested_dirs:\n                class_paths = [\n                    class_directory,\n                    os.path.join(class_directory, \"subfolder_1\"),\n                    os.path.join(class_directory, \"subfolder_2\"),\n                    os.path.join(\n                        class_directory, \"subfolder_1\", \"sub-subfolder\"\n                    ),\n                ]\n            else:\n                class_paths = [class_directory]\n            for path in class_paths:\n                os.mkdir(os.path.join(temp_dir, 
path))\n            paths += class_paths\n\n        # Save images to the paths\n        i = 0\n        for img in self._get_images(color_mode=color_mode, count=count):\n            path = paths[i % len(paths)]\n            if color_mode == \"rgb\":\n                ext = \"jpg\"\n            else:\n                ext = \"png\"\n            filename = os.path.join(path, f\"image_{i}.{ext}\")\n            img.save(os.path.join(temp_dir, filename))\n            i += 1\n        return temp_dir\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_no_labels(self, format):\n        # Test retrieving images without labels from a directory and its\n        # subdirs.\n\n        # Save a few extra images in the parent directory.\n        directory = self._prepare_directory(count=7, num_classes=2)\n        for i, img in enumerate(self._get_images(3)):\n            filename = f\"image_{i}.jpg\"\n            img.save(os.path.join(directory, filename))\n\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=5,\n            image_size=(18, 18),\n            labels=None,\n            format=format,\n        )\n        if backend.config.image_data_format() == \"channels_last\":\n            output_shape = [5, 18, 18, 3]\n        else:\n            output_shape = [5, 3, 18, 18]\n        self.assertEqual(dataset.class_names, None)\n        batch = next(iter(dataset))\n        # We return plain images\n        self.assertEqual(list(batch.shape), output_shape)\n        self.assertDType(batch, \"float32\")\n        # Count samples\n        batch_count = 0\n        sample_count = 0\n        for batch in dataset:\n            batch_count += 1\n            sample_count += batch.shape[0]\n        self.assertEqual(batch_count, 2)\n        self.assertEqual(sample_count, 10)\n\n    @parameterized.named_parameters(\n        (\"tf\", 
\"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_binary(self, format):\n        directory = self._prepare_directory(num_classes=2)\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            label_mode=\"int\",\n            format=format,\n        )\n        if backend.config.image_data_format() == \"channels_last\":\n            output_shape = [8, 18, 18, 3]\n        else:\n            output_shape = [8, 3, 18, 18]\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), output_shape)\n        self.assertDType(batch[0], \"float32\")\n        self.assertEqual(list(batch[1].shape), [8])\n        self.assertDType(batch[1], \"int32\")\n\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            label_mode=\"binary\",\n            format=format,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), output_shape)\n        self.assertDType(batch[0], \"float32\")\n        self.assertEqual(list(batch[1].shape), [8, 1])\n        self.assertDType(batch[1], \"float32\")\n\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            label_mode=\"categorical\",\n            format=format,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), output_shape)\n        self.assertDType(batch[0], \"float32\")\n        self.assertEqual(list(batch[1].shape), [8, 2])\n        self.assertDType(batch[1], \"float32\")\n\n    def test_static_shape_in_graph(self):\n        directory = self._prepare_directory(num_classes=2)\n 
       dataset = image_dataset_utils.image_dataset_from_directory(\n            directory, batch_size=8, image_size=(18, 18), label_mode=\"int\"\n        )\n        test_obj = self\n        if backend.config.image_data_format() == \"channels_last\":\n            output_shape = [None, 18, 18, 3]\n        else:\n            output_shape = [None, 3, 18, 18]\n\n        @tf.function\n        def symbolic_fn(ds):\n            for x, _ in ds.take(1):\n                test_obj.assertListEqual(x.shape.as_list(), output_shape)\n\n        symbolic_fn(dataset)\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_sample_count(self, format):\n        directory = self._prepare_directory(num_classes=4, count=15)\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            label_mode=None,\n            format=format,\n        )\n        sample_count = 0\n        for batch in dataset:\n            sample_count += batch.shape[0]\n        self.assertEqual(sample_count, 15)\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_multiclass(self, format):\n        directory = self._prepare_directory(num_classes=4, count=15)\n\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            label_mode=None,\n            format=format,\n        )\n        if backend.config.image_data_format() == \"channels_last\":\n            output_shape = [8, 18, 18, 3]\n        else:\n            output_shape = [8, 3, 18, 18]\n        batch = next(iter(dataset))\n        self.assertEqual(list(batch.shape), output_shape)\n\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            
batch_size=8,\n            image_size=(18, 18),\n            label_mode=None,\n            format=format,\n        )\n        sample_count = 0\n        for batch in dataset:\n            sample_count += batch.shape[0]\n        self.assertEqual(sample_count, 15)\n\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            label_mode=\"int\",\n            format=format,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), output_shape)\n        self.assertDType(batch[0], \"float32\")\n        self.assertEqual(list(batch[1].shape), [8])\n        self.assertDType(batch[1], \"int32\")\n\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            label_mode=\"categorical\",\n            format=format,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), output_shape)\n        self.assertDType(batch[0], \"float32\")\n        self.assertEqual(list(batch[1].shape), [8, 4])\n        self.assertDType(batch[1], \"float32\")\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_color_modes(self, format):\n        directory = self._prepare_directory(num_classes=4, color_mode=\"rgba\")\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            color_mode=\"rgba\",\n            format=format,\n        )\n        if backend.config.image_data_format() == \"channels_last\":\n            output_shape = [8, 18, 18, 4]\n        else:\n            output_shape = [8, 4, 18, 18]\n      
  batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), output_shape)\n        self.assertDType(batch[0], \"float32\")\n\n        directory = self._prepare_directory(\n            num_classes=4, color_mode=\"grayscale\"\n        )\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            color_mode=\"grayscale\",\n            format=format,\n        )\n        if backend.config.image_data_format() == \"channels_last\":\n            output_shape = [8, 18, 18, 1]\n        else:\n            output_shape = [8, 1, 18, 18]\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), output_shape)\n        self.assertDType(batch[0], \"float32\")\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_validation_split(self, format):\n        directory = self._prepare_directory(num_classes=2, count=10)\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=10,\n            image_size=(18, 18),\n            validation_split=0.2,\n            subset=\"training\",\n            seed=1337,\n            format=format,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        if backend.config.image_data_format() == \"channels_last\":\n            train_output_shape = [8, 18, 18, 3]\n            val_output_shape = [2, 18, 18, 3]\n        else:\n            train_output_shape = [8, 3, 18, 18]\n            val_output_shape = [2, 3, 18, 18]\n        self.assertEqual(list(batch[0].shape), train_output_shape)\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=10,\n            image_size=(18, 18),\n            
validation_split=0.2,\n            subset=\"validation\",\n            seed=1337,\n            format=format,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), val_output_shape)\n\n        (\n            train_dataset,\n            val_dataset,\n        ) = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=10,\n            image_size=(18, 18),\n            validation_split=0.2,\n            subset=\"both\",\n            seed=1337,\n            format=format,\n        )\n        batch = next(iter(train_dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), train_output_shape)\n        batch = next(iter(val_dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), val_output_shape)\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_manual_labels(self, format):\n        # Case: wrong number of labels\n        directory = self._prepare_directory(num_classes=1, count=4)\n        with self.assertRaisesRegex(ValueError, \"match the number of files\"):\n            image_dataset_utils.image_dataset_from_directory(\n                directory,\n                batch_size=8,\n                image_size=(18, 18),\n                labels=[0, 1, 0],\n                shuffle=False,\n                format=format,\n            )\n\n        # Case: single directory\n        directory = self._prepare_directory(num_classes=1, count=4)\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            labels=[0, 1, 0, 1],\n            shuffle=False,\n            format=format,\n        )\n        if backend.config.image_data_format() == \"channels_last\":\n            output_shape = [18, 18, 
3]\n        else:\n            output_shape = [3, 18, 18]\n        self.assertEqual(dataset.class_names, [\"0\", \"1\"])\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), [4] + output_shape)\n        self.assertAllClose(batch[1], [0, 1, 0, 1])\n\n        # Case: multiple directories\n        directory = self._prepare_directory(num_classes=3, count=6)\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            labels=[0, 1, 0, 1, 1, 1],\n            shuffle=False,\n            format=format,\n        )\n        self.assertEqual(dataset.class_names, [\"0\", \"1\"])\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), [6] + output_shape)\n        self.assertAllClose(batch[1], [0, 1, 0, 1, 1, 1])\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_follow_links(self, format):\n        directory = self._prepare_directory(\n            num_classes=2, count=25, nested_dirs=True\n        )\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            label_mode=None,\n            follow_links=True,\n            format=format,\n        )\n        sample_count = 0\n        for batch in dataset:\n            sample_count += batch.shape[0]\n        self.assertEqual(sample_count, 25)\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_no_images(self, format):\n        directory = self._prepare_directory(num_classes=2, count=0)\n        with self.assertRaisesRegex(ValueError, \"No images found.\"):\n            _ = 
image_dataset_utils.image_dataset_from_directory(\n                directory, format=format\n            )\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_crop_to_aspect_ratio(self, format):\n        directory = self._prepare_directory(num_classes=2, count=5)\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=5,\n            image_size=(18, 18),\n            crop_to_aspect_ratio=True,\n            format=format,\n        )\n        if backend.config.image_data_format() == \"channels_last\":\n            output_shape = [5, 18, 18, 3]\n        else:\n            output_shape = [5, 3, 18, 18]\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), output_shape)\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_pad_to_aspect_ratio(self, format):\n        directory = self._prepare_directory(num_classes=2, count=5)\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=5,\n            image_size=(18, 18),\n            pad_to_aspect_ratio=True,\n            format=format,\n        )\n        if backend.config.image_data_format() == \"channels_last\":\n            output_shape = [5, 18, 18, 3]\n        else:\n            output_shape = [5, 3, 18, 18]\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertEqual(list(batch[0].shape), output_shape)\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_errors(self, format):\n        directory = self._prepare_directory(num_classes=3, count=5)\n\n        with self.assertRaisesRegex(ValueError, \"`labels` argument 
should be\"):\n            _ = image_dataset_utils.image_dataset_from_directory(\n                directory, labels=\"other\", format=format\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"`label_mode` argument must be\"\n        ):\n            _ = image_dataset_utils.image_dataset_from_directory(\n                directory, label_mode=\"other\", format=format\n            )\n\n        with self.assertRaisesRegex(ValueError, \"`color_mode` must be one of\"):\n            _ = image_dataset_utils.image_dataset_from_directory(\n                directory, color_mode=\"other\", format=format\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, 'only pass `class_names` if `labels=\"inferred\"`'\n        ):\n            _ = image_dataset_utils.image_dataset_from_directory(\n                directory,\n                labels=[0, 0, 1, 1, 1],\n                class_names=[\"class_0\", \"class_1\", \"class_2\"],\n                format=format,\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Expected the lengths of `labels` to match the number of files\",\n        ):\n            _ = image_dataset_utils.image_dataset_from_directory(\n                directory, labels=[0, 0, 1, 1], format=format\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"`class_names` passed did not match\"\n        ):\n            _ = image_dataset_utils.image_dataset_from_directory(\n                directory, class_names=[\"class_0\", \"wrong_class\"], format=format\n            )\n\n        with self.assertRaisesRegex(ValueError, \"there must be exactly 2\"):\n            _ = image_dataset_utils.image_dataset_from_directory(\n                directory, label_mode=\"binary\", format=format\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"`validation_split` must be between 0 and 1\"\n        ):\n            _ = 
image_dataset_utils.image_dataset_from_directory(\n                directory, validation_split=2, format=format\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            '`subset` must be either \"training\", \"validation\" or \"both\"',\n        ):\n            _ = image_dataset_utils.image_dataset_from_directory(\n                directory, validation_split=0.2, subset=\"other\", format=format\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"`validation_split` must be set\"\n        ):\n            _ = image_dataset_utils.image_dataset_from_directory(\n                directory,\n                validation_split=0.0,\n                subset=\"training\",\n                format=format,\n            )\n\n        with self.assertRaisesRegex(ValueError, \"must provide a `seed`\"):\n            _ = image_dataset_utils.image_dataset_from_directory(\n                directory,\n                validation_split=0.2,\n                subset=\"training\",\n                format=format,\n            )\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_not_batched(self, format):\n        directory = self._prepare_directory(num_classes=2, count=2)\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=None,\n            image_size=(18, 18),\n            label_mode=None,\n            shuffle=False,\n            format=format,\n        )\n        sample = next(iter(dataset))\n        self.assertEqual(len(sample.shape), 3)\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_image_dataset_from_directory_shuffle(self, format):\n        # TODO: add same test for train/val\n        directory = self._prepare_directory(\n            num_classes=2, count=25, nested_dirs=True\n        )\n     
   dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            label_mode=None,\n            follow_links=True,\n            shuffle=False,\n            format=format,\n        )\n        batches_1 = []\n        batches_2 = []\n        for b in dataset:\n            batches_1.append(ops.convert_to_numpy(b))\n        batches_1 = np.concatenate(batches_1, axis=0)\n        for b in dataset:\n            batches_2.append(ops.convert_to_numpy(b))\n        batches_2 = np.concatenate(batches_2, axis=0)\n        self.assertAllClose(batches_1, batches_2, atol=1e-6)\n\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            label_mode=None,\n            follow_links=True,\n            shuffle=True,\n            seed=1337,\n            format=format,\n        )\n        batches_1 = []\n        batches_2 = []\n        for b in dataset:\n            batches_1.append(ops.convert_to_numpy(b))\n        batches_1 = np.concatenate(batches_1, axis=0)\n        for b in dataset:\n            batches_2.append(ops.convert_to_numpy(b))\n        batches_2 = np.concatenate(batches_2, axis=0)\n        if format == \"tf\":\n            self.assertNotAllClose(batches_1, batches_2, atol=1e-6)\n        else:\n            # Grain shuffles deterministically, so we expect the same batches.\n            self.assertAllClose(batches_1, batches_2, atol=1e-6)\n\n        # Test random seed determinism\n        dataset = image_dataset_utils.image_dataset_from_directory(\n            directory,\n            batch_size=8,\n            image_size=(18, 18),\n            label_mode=None,\n            follow_links=True,\n            shuffle=True,\n            seed=1337,\n            format=format,\n        )\n        batches_1_alt = []\n        for b in dataset:\n            
batches_1_alt.append(ops.convert_to_numpy(b))\n        batches_1_alt = np.concatenate(batches_1_alt, axis=0)\n        self.assertAllClose(batches_1, batches_1_alt, atol=1e-6)\n"
  },
  {
    "path": "keras/src/utils/image_utils.py",
    "content": "\"\"\"Utilities related to image handling.\"\"\"\n\nimport io\nimport pathlib\nimport warnings\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\n\ntry:\n    from PIL import Image as pil_image\n\n    try:\n        pil_image_resampling = pil_image.Resampling\n    except AttributeError:\n        pil_image_resampling = pil_image\nexcept ImportError:\n    pil_image = None\n    pil_image_resampling = None\n\n\nif pil_image_resampling is not None:\n    PIL_INTERPOLATION_METHODS = {\n        \"nearest\": pil_image_resampling.NEAREST,\n        \"bilinear\": pil_image_resampling.BILINEAR,\n        \"bicubic\": pil_image_resampling.BICUBIC,\n        \"hamming\": pil_image_resampling.HAMMING,\n        \"box\": pil_image_resampling.BOX,\n        \"lanczos\": pil_image_resampling.LANCZOS,\n    }\n\n\n@keras_export(\n    [\n        \"keras.utils.array_to_img\",\n        \"keras.preprocessing.image.array_to_img\",\n    ]\n)\ndef array_to_img(x, data_format=None, scale=True, dtype=None):\n    \"\"\"Converts a 3D NumPy array to a PIL Image instance.\n\n    Example:\n\n    ```python\n    from PIL import Image\n    img = np.random.random(size=(100, 100, 3))\n    pil_img = keras.utils.array_to_img(img)\n    ```\n\n    Args:\n        x: Input data, in any form that can be converted to a NumPy array.\n        data_format: Image data format, can be either `\"channels_first\"` or\n            `\"channels_last\"`. Defaults to `None`, in which case the global\n            setting `keras.backend.image_data_format()` is used (unless you\n            changed it, it defaults to `\"channels_last\"`).\n        scale: Whether to rescale the image such that minimum and maximum values\n            are 0 and 255 respectively. Defaults to `True`.\n        dtype: Dtype to use. `None` means the global setting\n            `keras.backend.floatx()` is used (unless you changed it, it\n            defaults to `\"float32\"`). 
Defaults to `None`.\n\n    Returns:\n        A PIL Image instance.\n    \"\"\"\n\n    data_format = backend.standardize_data_format(data_format)\n    if dtype is None:\n        dtype = backend.floatx()\n    if pil_image is None:\n        raise ImportError(\n            \"Could not import PIL.Image. \"\n            \"The use of `array_to_img` requires PIL.\"\n        )\n    x = np.asarray(x, dtype=dtype)\n    if x.ndim != 3:\n        raise ValueError(\n            \"Expected image array to have rank 3 (single image). \"\n            f\"Got array with shape: {x.shape}\"\n        )\n\n    # Original NumPy array x has format (height, width, channel)\n    # or (channel, height, width)\n    # but target PIL image has format (width, height, channel)\n    if data_format == \"channels_first\":\n        x = x.transpose(1, 2, 0)\n    if scale:\n        x = x - np.min(x)\n        x_max = np.max(x)\n        if x_max != 0:\n            x /= x_max\n        x *= 255\n    if x.shape[2] == 4:\n        # RGBA\n        return pil_image.fromarray(x.astype(\"uint8\"), \"RGBA\")\n    elif x.shape[2] == 3:\n        # RGB\n        return pil_image.fromarray(x.astype(\"uint8\"), \"RGB\")\n    elif x.shape[2] == 1:\n        # grayscale\n        if np.max(x) > 255:\n            # 32-bit signed integer grayscale image. 
PIL mode \"I\"\n            return pil_image.fromarray(x[:, :, 0].astype(\"int32\"), \"I\")\n        return pil_image.fromarray(x[:, :, 0].astype(\"uint8\"), \"L\")\n    else:\n        raise ValueError(f\"Unsupported channel number: {x.shape[2]}\")\n\n\n@keras_export(\n    [\n        \"keras.utils.img_to_array\",\n        \"keras.preprocessing.image.img_to_array\",\n    ]\n)\ndef img_to_array(img, data_format=None, dtype=None):\n    \"\"\"Converts a PIL Image instance to a NumPy array.\n\n    Example:\n\n    ```python\n    from PIL import Image\n    img_data = np.random.random(size=(100, 100, 3))\n    img = keras.utils.array_to_img(img_data)\n    array = keras.utils.img_to_array(img)\n    ```\n\n    Args:\n        img: Input PIL Image instance.\n        data_format: Image data format, can be either `\"channels_first\"` or\n            `\"channels_last\"`. Defaults to `None`, in which case the global\n            setting `keras.backend.image_data_format()` is used (unless you\n            changed it, it defaults to `\"channels_last\"`).\n        dtype: Dtype to use. 
`None` means the global setting\n            `keras.backend.floatx()` is used (unless you changed it, it\n            defaults to `\"float32\"`).\n\n    Returns:\n        A 3D NumPy array.\n    \"\"\"\n\n    data_format = backend.standardize_data_format(data_format)\n    if dtype is None:\n        dtype = backend.floatx()\n    # NumPy array x has format (height, width, channel)\n    # or (channel, height, width)\n    # but original PIL image has format (width, height, channel)\n    x = np.asarray(img, dtype=dtype)\n    if len(x.shape) == 3:\n        if data_format == \"channels_first\":\n            x = x.transpose(2, 0, 1)\n    elif len(x.shape) == 2:\n        if data_format == \"channels_first\":\n            x = x.reshape((1, x.shape[0], x.shape[1]))\n        else:\n            x = x.reshape((x.shape[0], x.shape[1], 1))\n    else:\n        raise ValueError(f\"Unsupported image shape: {x.shape}\")\n    return x\n\n\n@keras_export([\"keras.utils.save_img\", \"keras.preprocessing.image.save_img\"])\ndef save_img(path, x, data_format=None, file_format=None, scale=True, **kwargs):\n    \"\"\"Saves an image stored as a NumPy array to a path or file object.\n\n    Args:\n        path: Path or file object.\n        x: NumPy array.\n        data_format: Image data format, either `\"channels_first\"` or\n            `\"channels_last\"`.\n        file_format: Optional file format override. If omitted, the format to\n            use is determined from the filename extension. 
If a file object was\n            used instead of a filename, this parameter should always be used.\n        scale: Whether to rescale image values to be within `[0, 255]`.\n        **kwargs: Additional keyword arguments passed to `PIL.Image.save()`.\n    \"\"\"\n    data_format = backend.standardize_data_format(data_format)\n\n    # Infer format from path if not explicitly provided\n    if file_format is None and isinstance(path, (str, pathlib.Path)):\n        file_format = pathlib.Path(path).suffix[1:].lower()\n\n    # Normalize jpg → jpeg for Pillow compatibility\n    if file_format and file_format.lower() == \"jpg\":\n        file_format = \"jpeg\"\n\n    img = array_to_img(x, data_format=data_format, scale=scale)\n\n    # Handle RGBA → RGB conversion for JPEG\n    if img.mode == \"RGBA\" and file_format == \"jpeg\":\n        warnings.warn(\n            \"The JPEG format does not support RGBA images, converting to RGB.\"\n        )\n        img = img.convert(\"RGB\")\n\n    img.save(path, format=file_format, **kwargs)\n\n\n@keras_export([\"keras.utils.load_img\", \"keras.preprocessing.image.load_img\"])\ndef load_img(\n    path,\n    color_mode=\"rgb\",\n    target_size=None,\n    interpolation=\"nearest\",\n    keep_aspect_ratio=False,\n):\n    \"\"\"Loads an image into PIL format.\n\n    Example:\n\n    ```python\n    image = keras.utils.load_img(image_path)\n    input_arr = keras.utils.img_to_array(image)\n    input_arr = np.array([input_arr])  # Convert single image to a batch.\n    predictions = model.predict(input_arr)\n    ```\n\n    Args:\n        path: Path to image file.\n        color_mode: One of `\"grayscale\"`, `\"rgb\"`, `\"rgba\"`. 
Default: `\"rgb\"`.\n            The desired image format.\n        target_size: Either `None` (default to original size) or tuple of ints\n            `(img_height, img_width)`.\n        interpolation: Interpolation method used to resample the image if the\n            target size is different from that of the loaded image. Supported\n            methods are `\"nearest\"`, `\"bilinear\"`, and `\"bicubic\"`.\n            If PIL version 1.1.3 or newer is installed, `\"lanczos\"`\n            is also supported. If PIL version 3.4.0 or newer is installed,\n            `\"box\"` and `\"hamming\"` are also\n            supported. By default, `\"nearest\"` is used.\n        keep_aspect_ratio: Boolean, whether to resize images to a target\n            size without aspect ratio distortion. The image is cropped in\n            the center with target aspect ratio before resizing.\n\n    Returns:\n        A PIL Image instance.\n    \"\"\"\n    if pil_image is None:\n        raise ImportError(\n            \"Could not import PIL.Image. 
The use of `load_img` requires PIL.\"\n        )\n    if isinstance(path, io.BytesIO):\n        img = pil_image.open(path)\n    elif isinstance(path, (pathlib.Path, bytes, str)):\n        if isinstance(path, pathlib.Path):\n            path = str(path.resolve())\n        with open(path, \"rb\") as f:\n            img = pil_image.open(io.BytesIO(f.read()))\n    else:\n        raise TypeError(\n            f\"path should be path-like or io.BytesIO, not {type(path)}\"\n        )\n\n    if color_mode == \"grayscale\":\n        # if image is not already an 8-bit, 16-bit or 32-bit grayscale image\n        # convert it to an 8-bit grayscale image.\n        if img.mode not in (\"L\", \"I;16\", \"I\"):\n            img = img.convert(\"L\")\n    elif color_mode == \"rgba\":\n        if img.mode != \"RGBA\":\n            img = img.convert(\"RGBA\")\n    elif color_mode == \"rgb\":\n        if img.mode != \"RGB\":\n            img = img.convert(\"RGB\")\n    else:\n        raise ValueError('color_mode must be \"grayscale\", \"rgb\", or \"rgba\"')\n    if target_size is not None:\n        width_height_tuple = (target_size[1], target_size[0])\n        if img.size != width_height_tuple:\n            if interpolation not in PIL_INTERPOLATION_METHODS:\n                raise ValueError(\n                    \"Invalid interpolation method {} specified. 
Supported \"\n                    \"methods are {}\".format(\n                        interpolation,\n                        \", \".join(PIL_INTERPOLATION_METHODS.keys()),\n                    )\n                )\n            resample = PIL_INTERPOLATION_METHODS[interpolation]\n\n            if keep_aspect_ratio:\n                width, height = img.size\n                target_width, target_height = width_height_tuple\n\n                crop_height = (width * target_height) // target_width\n                crop_width = (height * target_width) // target_height\n\n                # Set back to input height / width\n                # if crop_height / crop_width is not smaller.\n                crop_height = min(height, crop_height)\n                crop_width = min(width, crop_width)\n\n                crop_box_hstart = (height - crop_height) // 2\n                crop_box_wstart = (width - crop_width) // 2\n                crop_box_wend = crop_box_wstart + crop_width\n                crop_box_hend = crop_box_hstart + crop_height\n                crop_box = [\n                    crop_box_wstart,\n                    crop_box_hstart,\n                    crop_box_wend,\n                    crop_box_hend,\n                ]\n                img = img.resize(width_height_tuple, resample, box=crop_box)\n            else:\n                img = img.resize(width_height_tuple, resample)\n    return img\n\n\n@keras_export(\"keras.preprocessing.image.smart_resize\")\ndef smart_resize(\n    x,\n    size,\n    interpolation=\"bilinear\",\n    data_format=\"channels_last\",\n    **kwargs,\n):\n    \"\"\"Resize images to a target size without aspect ratio distortion.\n\n    Image datasets typically yield images that have each a different\n    size. However, these images need to be batched before they can be\n    processed by Keras layers. 
To be batched, images need to share the same\n    height and width.\n\n    You could simply do, in TF (or JAX equivalent):\n\n    ```python\n    size = (200, 200)\n    ds = ds.map(lambda img: resize(img, size))\n    ```\n\n    However, if you do this, you distort the aspect ratio of your images, since\n    in general they do not all have the same aspect ratio as `size`. This is\n    fine in many cases, but not always (e.g. for image generation models\n    this can be a problem).\n\n    Note that passing the argument `preserve_aspect_ratio=True` to `resize`\n    will preserve the aspect ratio, but at the cost of no longer respecting the\n    provided target size.\n\n    This calls for:\n\n    ```python\n    size = (200, 200)\n    ds = ds.map(lambda img: smart_resize(img, size))\n    ```\n\n    Your output images will actually be `(200, 200)`, and will not be distorted.\n    Instead, the parts of the image that do not fit within the target size\n    get cropped out.\n\n    The resizing process is:\n\n    1. Take the largest centered crop of the image that has the same aspect\n    ratio as the target size. For instance, if `size=(200, 200)` and the input\n    image has size `(340, 500)`, we take a crop of `(340, 340)` centered along\n    the width.\n    2. Resize the cropped image to the target size. In the example above,\n    we resize the `(340, 340)` crop to `(200, 200)`.\n\n    Args:\n        x: Input image or batch of images (as a tensor or NumPy array).\n            Must be in format `(height, width, channels)`\n            or `(batch_size, height, width, channels)`.\n        size: Tuple of `(height, width)` integer. 
Target size.\n        interpolation: String, interpolation to use for resizing.\n            Supports `\"bilinear\"`, `\"nearest\"`, `\"bicubic\"`,\n            `\"lanczos3\"`, `\"lanczos5\"`.\n            Defaults to `\"bilinear\"`.\n        data_format: `\"channels_last\"` or `\"channels_first\"`.\n\n    Returns:\n        Array with shape `(size[0], size[1], channels)`.\n        If the input image was a NumPy array, the output is a NumPy array,\n        and if it was a backend-native tensor,\n        the output is a backend-native tensor.\n    \"\"\"\n    backend_module = kwargs.pop(\"backend_module\", None) or backend\n    if kwargs:\n        raise TypeError(\n            \"smart_resize() got unexpected keyword arguments: \"\n            f\"{list(kwargs.keys())}\"\n        )\n    if len(size) != 2:\n        raise ValueError(\n            f\"Expected `size` to be a tuple of 2 integers, but got: {size}.\"\n        )\n    img = backend_module.convert_to_tensor(x)\n    if len(img.shape) is not None:\n        if len(img.shape) < 3 or len(img.shape) > 4:\n            raise ValueError(\n                \"Expected an image array with shape `(height, width, \"\n                \"channels)`, or `(batch_size, height, width, channels)`, but \"\n                f\"got input with incorrect rank, of shape {img.shape}.\"\n            )\n    shape = backend_module.shape(img)\n    if data_format == \"channels_last\":\n        height, width = shape[-3], shape[-2]\n    else:\n        height, width = shape[-2], shape[-1]\n    target_height, target_width = size\n\n    # Set back to input height / width if crop_height / crop_width is not\n    # smaller.\n    if isinstance(height, int) and isinstance(width, int):\n        # For JAX, we need to keep the slice indices as static integers\n        crop_height = int(float(width * target_height) / target_width)\n        crop_height = max(min(height, crop_height), 1)\n        crop_width = int(float(height * target_width) / target_height)\n    
    crop_width = max(min(width, crop_width), 1)\n        crop_box_hstart = int(float(height - crop_height) / 2)\n        crop_box_wstart = int(float(width - crop_width) / 2)\n    else:\n        crop_height = backend_module.cast(\n            backend_module.cast(width * target_height, \"float32\")\n            / target_width,\n            \"int32\",\n        )\n        crop_height = backend_module.numpy.minimum(height, crop_height)\n        crop_height = backend_module.numpy.maximum(crop_height, 1)\n        crop_height = backend_module.cast(crop_height, \"int32\")\n\n        crop_width = backend_module.cast(\n            backend_module.cast(height * target_width, \"float32\")\n            / target_height,\n            \"int32\",\n        )\n        crop_width = backend_module.numpy.minimum(width, crop_width)\n        crop_width = backend_module.numpy.maximum(crop_width, 1)\n        crop_width = backend_module.cast(crop_width, \"int32\")\n\n        crop_box_hstart = backend_module.cast(\n            backend_module.cast(height - crop_height, \"float32\") / 2, \"int32\"\n        )\n        crop_box_wstart = backend_module.cast(\n            backend_module.cast(width - crop_width, \"float32\") / 2, \"int32\"\n        )\n\n    if data_format == \"channels_last\":\n        if len(img.shape) == 4:\n            img = img[\n                :,\n                crop_box_hstart : crop_box_hstart + crop_height,\n                crop_box_wstart : crop_box_wstart + crop_width,\n                :,\n            ]\n        else:\n            img = img[\n                crop_box_hstart : crop_box_hstart + crop_height,\n                crop_box_wstart : crop_box_wstart + crop_width,\n                :,\n            ]\n    else:\n        if len(img.shape) == 4:\n            img = img[\n                :,\n                :,\n                crop_box_hstart : crop_box_hstart + crop_height,\n                crop_box_wstart : crop_box_wstart + crop_width,\n            ]\n        else:\n    
        img = img[\n                :,\n                crop_box_hstart : crop_box_hstart + crop_height,\n                crop_box_wstart : crop_box_wstart + crop_width,\n            ]\n\n    img = backend_module.image.resize(\n        img, size=size, interpolation=interpolation, data_format=data_format\n    )\n\n    if isinstance(x, np.ndarray):\n        return np.array(img)\n    return img\n"
  },
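The center-crop arithmetic in `smart_resize` above (also used by `load_img` with `keep_aspect_ratio=True`) can be checked in isolation. A minimal sketch of the static-integer branch in plain Python; `compute_crop_box` is a hypothetical helper name introduced here for illustration, not part of the Keras API:

```python
def compute_crop_box(height, width, target_height, target_width):
    """Largest centered crop matching the target aspect ratio.

    Mirrors the static-integer branch of `smart_resize`: the crop is
    clamped to the input size and to at least 1 pixel per dimension,
    then centered inside the input image.
    """
    crop_height = int(float(width * target_height) / target_width)
    crop_height = max(min(height, crop_height), 1)
    crop_width = int(float(height * target_width) / target_height)
    crop_width = max(min(width, crop_width), 1)
    # Center the crop box inside the input image.
    hstart = (height - crop_height) // 2
    wstart = (width - crop_width) // 2
    return hstart, wstart, crop_height, crop_width


# The docstring example: a (340, 500) image cropped for size=(200, 200)
# yields a centered (340, 340) crop, offset 80 pixels along the width.
print(compute_crop_box(340, 500, 200, 200))  # -> (0, 80, 340, 340)
```

Running the docstring's `(340, 500)` example through this helper reproduces the `(340, 340)` crop described in the `smart_resize` documentation.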
  {
    "path": "keras/src/utils/image_utils_test.py",
    "content": "import os\n\nimport numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import testing\nfrom keras.src.utils import img_to_array\nfrom keras.src.utils import load_img\nfrom keras.src.utils import save_img\n\n\nclass SaveImgTest(testing.TestCase, parameterized.TestCase):\n    @parameterized.named_parameters(\n        (\"rgb_explicit_format\", (50, 50, 3), \"rgb.jpg\", \"jpg\", True),\n        (\"rgba_explicit_format\", (50, 50, 4), \"rgba.jpg\", \"jpg\", True),\n        (\"rgb_inferred_format\", (50, 50, 3), \"rgb_inferred.jpg\", None, False),\n        (\"rgba_inferred_format\", (50, 50, 4), \"rgba_inferred.jpg\", None, False),\n    )\n    def test_save_jpg(self, shape, name, file_format, use_explicit_format):\n        tmp_dir = self.get_temp_dir()\n        path = os.path.join(tmp_dir, name)\n\n        img = np.random.randint(0, 256, size=shape, dtype=np.uint8)\n\n        # Test the actual inferred case - don't pass file_format at all\n        if use_explicit_format:\n            save_img(path, img, file_format=file_format)\n        else:\n            save_img(path, img)  # Let it infer from path\n\n        self.assertTrue(os.path.exists(path))\n\n        # Verify saved image is correctly converted to RGB if needed\n        loaded_img = load_img(path)\n        loaded_array = img_to_array(loaded_img)\n        self.assertEqual(loaded_array.shape, (50, 50, 3))\n"
  },
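The format-inference path exercised by the test above relies only on the standard library. A stdlib-only sketch of the inference and normalization logic from `save_img`; `infer_file_format` is a hypothetical name introduced here for illustration:

```python
import pathlib


def infer_file_format(path, file_format=None):
    """Sketch of `save_img`'s format handling (illustrative only).

    The format comes from the filename extension unless given
    explicitly, and "jpg" is normalized to "jpeg" for Pillow, which
    does not recognize "jpg" as a format name.
    """
    if file_format is None and isinstance(path, (str, pathlib.Path)):
        file_format = pathlib.Path(path).suffix[1:].lower()
    if file_format and file_format.lower() == "jpg":
        file_format = "jpeg"
    return file_format


print(infer_file_format("photo.JPG"))  # extension lowercased, jpg -> jpeg
print(infer_file_format("chart.png"))
```

Note that a file object (e.g. `io.BytesIO`) has no extension to inspect, which is why the `save_img` docstring says `file_format` should always be passed explicitly in that case.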
  {
    "path": "keras/src/utils/io_utils.py",
    "content": "import sys\n\nfrom absl import logging\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\n\n\n@keras_export(\n    [\n        \"keras.config.enable_interactive_logging\",\n        \"keras.utils.enable_interactive_logging\",\n    ]\n)\ndef enable_interactive_logging():\n    \"\"\"Turn on interactive logging.\n\n    When interactive logging is enabled, Keras displays logs via stdout.\n    This provides the best experience when using Keras in an interactive\n    environment such as a shell or a notebook.\n    \"\"\"\n    global_state.set_global_attribute(\"interactive_logging\", True)\n\n\n@keras_export(\n    [\n        \"keras.config.disable_interactive_logging\",\n        \"keras.utils.disable_interactive_logging\",\n    ]\n)\ndef disable_interactive_logging():\n    \"\"\"Turn off interactive logging.\n\n    When interactive logging is disabled, Keras sends logs to `absl.logging`.\n    This is the best option when using Keras in a non-interactive\n    way, such as running a training or inference job on a server.\n    \"\"\"\n    global_state.set_global_attribute(\"interactive_logging\", False)\n\n\n@keras_export(\n    [\n        \"keras.config.is_interactive_logging_enabled\",\n        \"keras.utils.is_interactive_logging_enabled\",\n    ]\n)\ndef is_interactive_logging_enabled():\n    \"\"\"Check if interactive logging is enabled.\n\n    To switch between writing logs to stdout and `absl.logging`, you may use\n    `keras.config.enable_interactive_logging()` and\n    `keras.config.disable_interactive_logging()`.\n\n    Returns:\n        Boolean, `True` if interactive logging is enabled,\n        and `False` otherwise.\n    \"\"\"\n    return global_state.get_global_attribute(\"interactive_logging\", True)\n\n\ndef set_logging_verbosity(level):\n    \"\"\"Sets the verbosity level for logging.\n\n    Supported log levels are as follows:\n\n    - `\"FATAL\"` (least verbose)\n    - `\"ERROR\"`\n    - 
`\"WARNING\"`\n    - `\"INFO\"`\n    - `\"DEBUG\"` (most verbose)\n\n    Args:\n        level: A string corresponding to the level of verbosity for logging.\n    \"\"\"\n    valid_levels = {\n        \"FATAL\": logging.FATAL,\n        \"ERROR\": logging.ERROR,\n        \"WARNING\": logging.WARNING,\n        \"INFO\": logging.INFO,\n        \"DEBUG\": logging.DEBUG,\n    }\n    verbosity = valid_levels.get(level)\n    if verbosity is None:\n        raise ValueError(\n            \"Please pass a valid level for logging verbosity. \"\n            f\"Expected one of: {set(valid_levels.keys())}. \"\n            f\"Received: {level}\"\n        )\n    logging.set_verbosity(verbosity)\n\n\ndef print_msg(message, line_break=True):\n    \"\"\"Print the message to absl logging or stdout.\"\"\"\n    message = str(message)\n    if is_interactive_logging_enabled():\n        message = f\"{message}\\n\" if line_break else message\n        try:\n            sys.stdout.write(message)\n        except UnicodeEncodeError:\n            # If the encoding differs from UTF-8, `sys.stdout.write` may fail.\n            # To address this, replace special unicode characters in the\n            # message, and then encode and decode using the target encoding.\n            message = _replace_special_unicode_character(message)\n            # Fallback to UTF-8 when `sys.stdout.encoding` is `None` (e.g. when\n            # stdout is redirected). 
This prevents a `TypeError` that would be\n            # raised by `bytes.encode(None)` / `bytes.decode(None)`.\n            encoding = sys.stdout.encoding or \"utf-8\"\n            message_bytes = message.encode(encoding, errors=\"ignore\")\n            message = message_bytes.decode(encoding)\n            sys.stdout.write(message)\n        sys.stdout.flush()\n    else:\n        logging.info(message)\n\n\ndef ask_to_proceed_with_overwrite(filepath):\n    \"\"\"Produces a prompt asking about overwriting a file.\n\n    Args:\n        filepath: the path to the file to be overwritten.\n\n    Returns:\n        True if we can proceed with overwrite, False otherwise.\n    \"\"\"\n    overwrite = (\n        input(f\"[WARNING] {filepath} already exists - overwrite? [y/n]\")\n        .strip()\n        .lower()\n    )\n    while overwrite not in (\"y\", \"n\"):\n        overwrite = (\n            input('Enter \"y\" (overwrite) or \"n\" (cancel).').strip().lower()\n        )\n    if overwrite == \"n\":\n        return False\n    print_msg(\"[TIP] Next time specify overwrite=True!\")\n    return True\n\n\ndef _replace_special_unicode_character(message):\n    message = str(message).replace(\"━\", \"=\")  # Fall back to Keras2 behavior.\n    return message\n"
  },
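The `UnicodeEncodeError` fallback in `print_msg` above reduces to a small, framework-free transformation. A sketch under the assumption that only the `━` progress-bar glyph needs the Keras 2 replacement; `safe_console_text` is a hypothetical helper, not part of the module:

```python
def safe_console_text(message, encoding=None):
    """Sketch of `print_msg`'s encoding fallback (illustrative only).

    The progress-bar glyph is replaced with "=" (Keras 2 behavior),
    then the text is round-tripped through the target encoding with
    un-encodable characters dropped. `encoding=None` falls back to
    UTF-8, as when stdout is redirected and has no encoding.
    """
    message = str(message).replace("━", "=")
    encoding = encoding or "utf-8"
    return message.encode(encoding, errors="ignore").decode(encoding)


# cp1251 cannot represent the progress-bar glyph, so it is replaced.
print(safe_console_text("loss ━━━ 0.25", encoding="cp1251"))  # loss === 0.25
```

This is the same two-step recovery `print_msg` performs: substitute the known-problematic character first, then drop anything the target encoding still cannot represent.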
  {
    "path": "keras/src/utils/io_utils_test.py",
    "content": "import sys\nimport tempfile\nfrom unittest.mock import patch\n\nfrom keras.src.testing import test_case\nfrom keras.src.utils import io_utils\n\n\nclass TestIoUtils(test_case.TestCase):\n    def test_enable_interactive_logging(self):\n        io_utils.enable_interactive_logging()\n        self.assertTrue(io_utils.is_interactive_logging_enabled())\n\n    def test_disable_interactive_logging(self):\n        io_utils.disable_interactive_logging()\n        self.assertFalse(io_utils.is_interactive_logging_enabled())\n\n    def test_set_logging_verbosity_valid(self):\n        valid_levels = [\"FATAL\", \"ERROR\", \"WARNING\", \"INFO\", \"DEBUG\"]\n        for level in valid_levels:\n            io_utils.set_logging_verbosity(level)\n\n    def test_set_logging_verbosity_invalid(self):\n        with self.assertRaises(ValueError):\n            io_utils.set_logging_verbosity(\"INVALID\")\n\n    @patch(\"builtins.input\", side_effect=[\"y\"])\n    def test_ask_to_proceed_with_overwrite_yes(self, _):\n        self.assertTrue(io_utils.ask_to_proceed_with_overwrite(\"test_path\"))\n\n    @patch(\"builtins.input\", side_effect=[\"n\"])\n    def test_ask_to_proceed_with_overwrite_no(self, _):\n        self.assertFalse(io_utils.ask_to_proceed_with_overwrite(\"test_path\"))\n\n    @patch(\"sys.stdout.write\")\n    def test_print_msg_interactive_with_line_break(self, mock_write):\n        io_utils.enable_interactive_logging()\n        io_utils.print_msg(\"Hello\", line_break=True)\n        mock_write.assert_called_once_with(\"Hello\\n\")\n\n    @patch(\"sys.stdout.write\")\n    def test_print_msg_interactive_without_line_break(self, mock_write):\n        io_utils.enable_interactive_logging()\n        io_utils.print_msg(\"Hello\", line_break=False)\n        mock_write.assert_called_once_with(\"Hello\")\n\n    @patch(\"absl.logging.info\")\n    def test_print_msg_non_interactive(self, mock_logging):\n        io_utils.disable_interactive_logging()\n        
io_utils.print_msg(\"Hello\")\n        mock_logging.assert_called_once_with(\"Hello\")\n\n    @patch(\"builtins.input\", side_effect=[\"invalid\", \"invalid\", \"y\"])\n    def test_ask_to_proceed_with_overwrite_invalid_then_yes(self, _):\n        self.assertTrue(io_utils.ask_to_proceed_with_overwrite(\"test_path\"))\n\n    @patch(\"builtins.input\", side_effect=[\"invalid\", \"n\"])\n    def test_ask_to_proceed_with_overwrite_invalid_then_no(self, _):\n        self.assertFalse(io_utils.ask_to_proceed_with_overwrite(\"test_path\"))\n\n    def test_print_msg_with_different_encoding(self):\n        # https://github.com/keras-team/keras/issues/19386\n        io_utils.enable_interactive_logging()\n        self.assertTrue(io_utils.is_interactive_logging_enabled())\n        ori_stdout = sys.stdout\n        with tempfile.TemporaryFile(mode=\"w\", encoding=\"cp1251\") as tmp:\n            sys.stdout = tmp\n            io_utils.print_msg(\"━\")\n        sys.stdout = ori_stdout\n"
  },
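`JaxLayer` (next file) requires `call_fn` to receive all model inputs through a single `inputs` argument, so a model with separate input arguments needs a thin adapter. A framework-free sketch of that wrapper pattern; the function names and the toy model are hypothetical:

```python
def my_model_fn(params, input1, input2):
    # Hypothetical model taking two separate input arguments.
    return params["scale"] * (input1 + input2)


def wrapped_call_fn(params, inputs):
    # Adapter matching the stateless `call_fn(params, inputs)` signature:
    # the single `inputs` structure is unpacked into separate arguments.
    input1, input2 = inputs
    return my_model_fn(params, input1, input2)


result = wrapped_call_fn({"scale": 2}, (3, 4))
print(result)  # -> 14
```

The real `JaxLayer` docstring applies the same idea with JAX arrays and RNG splitting; the point here is only the signature adaptation.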
  {
    "path": "keras/src/utils/jax_layer.py",
    "content": "import functools\nimport inspect\nimport itertools\nimport string\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common.variables import is_float_dtype\nfrom keras.src.backend.common.variables import standardize_dtype\nfrom keras.src.layers.layer import Layer\nfrom keras.src.random.seed_generator import draw_seed\nfrom keras.src.saving import serialization_lib\nfrom keras.src.utils import jax_utils\nfrom keras.src.utils import tracking\nfrom keras.src.utils.module_utils import jax\nfrom keras.src.utils.module_utils import tensorflow as tf\n\nif backend.backend() == \"tensorflow\":\n    tf_no_automatic_dependency_tracking = (\n        tf.__internal__.tracking.no_automatic_dependency_tracking\n    )\nelse:\n\n    def tf_no_automatic_dependency_tracking(fn):\n        return fn\n\n\ndef _convert_to_jax_key(tensor):\n    if backend.backend() == \"tensorflow\":\n        return tf.bitcast(tensor, tf.uint32)[0]\n    return tensor\n\n\n@keras_export(\"keras.layers.JaxLayer\")\nclass JaxLayer(Layer):\n    \"\"\"Keras Layer that wraps a JAX model.\n\n    This layer enables the use of JAX components within Keras when using JAX as\n    the backend for Keras.\n\n    ## Model function\n\n    This layer accepts JAX models in the form of a function, `call_fn`, which\n    must take the following arguments with these exact names:\n\n    - `params`: trainable parameters of the model.\n    - `state` (*optional*): non-trainable state of the model. Can be omitted if\n        the model has no non-trainable state.\n    - `rng` (*optional*): a `jax.random.PRNGKey` instance. 
Can be omitted if the\n        model does not need RNGs, neither during training nor during inference.\n    - `inputs`: inputs to the model, a JAX array or a `PyTree` of arrays.\n    - `training` (*optional*): an argument specifying if we're in training mode\n        or inference mode, `True` is passed in training mode. Can be omitted if\n        the model behaves the same in training mode and inference mode.\n\n    The `inputs` argument is mandatory. Inputs to the model must be provided via\n    a single argument. If the JAX model takes multiple inputs as separate\n    arguments, they must be combined into a single structure, for instance in a\n    `tuple` or a `dict`.\n\n    ## Model weights initialization\n\n    The initialization of the `params` and `state` of the model can be handled\n    by this layer, in which case the `init_fn` argument must be provided. This\n    allows the model to be initialized dynamically with the right shape.\n    Alternatively, and if the shape is known, the `params` argument and\n    optionally the `state` argument can be used to create an already initialized\n    model.\n\n    The `init_fn` function, if provided, must take the following arguments with\n    these exact names:\n\n    - `rng`: a `jax.random.PRNGKey` instance.\n    - `inputs`: a JAX array or a `PyTree` of arrays with placeholder values to\n        provide the shape of the inputs.\n    - `training` (*optional*): an argument specifying if we're in training mode\n        or inference mode. `True` is always passed to `init_fn`. 
Can be omitted\n        regardless of whether `call_fn` has a `training` argument.\n\n    ## Models with non-trainable state\n\n    For JAX models that have non-trainable state:\n\n    - `call_fn` must have a `state` argument\n    - `call_fn` must return a `tuple` containing the outputs of the model and\n        the new non-trainable state of the model\n    - `init_fn` must return a `tuple` containing the initial trainable params of\n        the model and the initial non-trainable state of the model.\n\n    This code shows a possible combination of `call_fn` and `init_fn` signatures\n    for a model with non-trainable state. In this example, the model has a\n    `training` argument and an `rng` argument in `call_fn`.\n\n    ```python\n    def stateful_call(params, state, rng, inputs, training):\n        outputs = ...\n        new_state = ...\n        return outputs, new_state\n\n    def stateful_init(rng, inputs):\n        initial_params = ...\n        initial_state = ...\n        return initial_params, initial_state\n    ```\n\n    ## Models without non-trainable state\n\n    For JAX models with no non-trainable state:\n\n    - `call_fn` must not have a `state` argument\n    - `call_fn` must return only the outputs of the model\n    - `init_fn` must return only the initial trainable params of the model.\n\n    This code shows a possible combination of `call_fn` and `init_fn` signatures\n    for a model without non-trainable state. In this example, the model does not\n    have a `training` argument and does not have an `rng` argument in `call_fn`.\n\n    ```python\n    def stateless_call(params, inputs):\n        outputs = ...\n        return outputs\n\n    def stateless_init(rng, inputs):\n        initial_params = ...\n        return initial_params\n    ```\n\n    ## Conforming to the required signature\n\n    If a model has a different signature than the one required by `JaxLayer`,\n    one can easily write a wrapper method to adapt the arguments. 
This example\n    shows a model that has multiple inputs as separate arguments, expects\n    multiple RNGs in a `dict`, and has a `deterministic` argument with the\n    opposite meaning of `training`. To conform, the inputs are combined in a\n    single structure using a `tuple`, the RNG is split and used to populate the\n    expected `dict`, and the Boolean flag is negated:\n\n    ```python\n    def my_model_fn(params, rngs, input1, input2, deterministic):\n        ...\n        if not deterministic:\n            dropout_rng = rngs[\"dropout\"]\n            keep = jax.random.bernoulli(dropout_rng, dropout_rate, x.shape)\n            x = jax.numpy.where(keep, x / dropout_rate, 0)\n            ...\n        ...\n        return outputs\n\n    def my_model_wrapper_fn(params, rng, inputs, training):\n        input1, input2 = inputs\n        rng1, rng2 = jax.random.split(rng)\n        rngs = {\"dropout\": rng1, \"preprocessing\": rng2}\n        deterministic = not training\n        return my_model_fn(params, rngs, input1, input2, deterministic)\n\n    keras_layer = JaxLayer(my_model_wrapper_fn, params=initial_params)\n    ```\n\n    ## Usage with Haiku modules\n\n    `JaxLayer` enables the use of [Haiku](https://dm-haiku.readthedocs.io)\n    components in the form of\n    [`haiku.Module`](https://dm-haiku.readthedocs.io/en/latest/api.html#module).\n    This is achieved by transforming the module per the Haiku pattern and then\n    passing `module.apply` in the `call_fn` parameter and `module.init` in the\n    `init_fn` parameter if needed.\n\n    If the model has non-trainable state, it should be transformed with\n    [`haiku.transform_with_state`](\n      https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform_with_state).\n    If the model has no non-trainable state, it should be transformed with\n    [`haiku.transform`](\n      https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.transform).\n    Additionally, and optionally, if the module does not use 
RNGs in \"apply\", it\n    can be transformed with\n    [`haiku.without_apply_rng`](\n      https://dm-haiku.readthedocs.io/en/latest/api.html#without-apply-rng).\n\n    The following example shows how to create a `JaxLayer` from a Haiku module\n    that uses random number generators via `hk.next_rng_key()` and takes a\n    training positional argument:\n\n    ```python\n    class MyHaikuModule(hk.Module):\n        def __call__(self, x, training):\n            x = hk.Conv2D(32, (3, 3))(x)\n            x = jax.nn.relu(x)\n            x = hk.AvgPool((1, 2, 2, 1), (1, 2, 2, 1), \"VALID\")(x)\n            x = hk.Flatten()(x)\n            x = hk.Linear(200)(x)\n            if training:\n                x = hk.dropout(rng=hk.next_rng_key(), rate=0.3, x=x)\n            x = jax.nn.relu(x)\n            x = hk.Linear(10)(x)\n            x = jax.nn.softmax(x)\n            return x\n\n    def my_haiku_module_fn(inputs, training):\n        module = MyHaikuModule()\n        return module(inputs, training)\n\n    transformed_module = hk.transform(my_haiku_module_fn)\n\n    keras_layer = JaxLayer(\n        call_fn=transformed_module.apply,\n        init_fn=transformed_module.init,\n    )\n    ```\n\n    Args:\n        call_fn: The function to call the model. See description above for the\n            list of arguments it takes and the outputs it returns.\n        init_fn: The function to call to initialize the model. See description\n            above for the list of arguments it takes and the outputs it returns.\n            If `None`, then `params` and/or `state` must be provided.\n        params: A `PyTree` containing all the model trainable parameters. This\n            allows passing trained parameters or controlling the initialization.\n            If both `params` and `state` are `None`, `init_fn` is called at\n            build time to initialize the trainable parameters of the model.\n        state: A `PyTree` containing all the model non-trainable state. 
This\n            allows passing learned state or controlling the initialization. If\n            both `params` and `state` are `None`, and `call_fn` takes a `state`\n            argument, then `init_fn` is called at build time to initialize the\n            non-trainable state of the model.\n        seed: Seed for random number generator. Optional.\n        dtype: The dtype of the layer's computations and weights. Can also be a\n            `keras.DTypePolicy`. Optional. Defaults to the default policy.\n    \"\"\"\n\n    def __init__(\n        self,\n        call_fn,\n        init_fn=None,\n        params=None,\n        state=None,\n        seed=None,\n        **kwargs,\n    ):\n        if backend.backend() not in [\"jax\", \"tensorflow\"]:\n            raise ValueError(\n                f\"{self.__class__.__name__} is only supported with the JAX or\"\n                f\" TensorFlow backend. Current backend: {backend.backend()}\"\n            )\n\n        super().__init__(**kwargs)\n        self.call_fn = call_fn\n        self.init_fn = init_fn\n        self.tracked_params = self._create_variables(params, trainable=True)\n        self.tracked_state = self._create_variables(state, trainable=False)\n        if self.params is not None or self.state is not None:\n            self._build_at_init()\n\n        self.call_fn_arguments = self._validate_signature(\n            call_fn,\n            \"call_fn\",\n            {\"params\", \"state\", \"rng\", \"inputs\", \"training\"},\n            {\"inputs\"},\n        )\n        self.call_fn_has_params = \"params\" in self.call_fn_arguments\n        self.call_fn_has_state = \"state\" in self.call_fn_arguments\n        call_fn_has_rng = \"rng\" in self.call_fn_arguments\n\n        if call_fn_has_rng:\n            self.seed_generator = backend.random.SeedGenerator(seed)\n        else:\n            self.seed_generator = None\n\n        if (\n            init_fn is None\n            and params is None\n            and state is 
None\n            and (self.call_fn_has_params or self.call_fn_has_state)\n        ):\n            raise ValueError(\n                \"`init_fn`, `params` and `state` cannot all be `None` when \"\n                \"`call_fn` takes a `params` or a `state` argument.\"\n            )\n\n        if init_fn:\n            self.init_fn_arguments = self._validate_signature(\n                init_fn, \"init_fn\", {\"rng\", \"inputs\", \"training\"}, {\"inputs\"}\n            )\n\n        # Attributes for jax2tf functions\n        self.jax2tf_training_false_fn = None\n        self.jax2tf_training_true_fn = None\n\n    def _validate_signature(self, fn, fn_name, allowed, required):\n        fn_parameters = inspect.signature(fn).parameters\n        for parameter_name in required:\n            if parameter_name not in fn_parameters:\n                raise ValueError(\n                    f\"Missing required argument in `{fn_name}`: \"\n                    f\"`{parameter_name}`\"\n                )\n\n        parameter_names = []\n        for parameter in fn_parameters.values():\n            if parameter.name not in allowed:\n                raise ValueError(\n                    f\"Unsupported argument in `{fn_name}`: `{parameter.name}`, \"\n                    f\"supported arguments are `{'`, `'.join(allowed)}`\"\n                )\n            parameter_names.append(parameter.name)\n\n        return parameter_names\n\n    def _get_jax2tf_input_shape(self, input_shape):\n        \"\"\"Convert input shape in a format suitable for `jax2tf`.\n\n        `jax2tf` expects a letter for each unknown dimension, which allows\n        correlated dimensions. Since correlated dimensions are not supported by\n        Keras, we simply use 'a', 'b', 'c'..., for each unknown dimension. 
We\n        however use 'batch' for dimension 0 if not defined to correlate the\n        batch size across inputs.\n\n        Example (spaces added for readability):\n        ```\n        input_shape:  (None , 4   , None, None, 5   )\n        result:      \"(batch, 4   , a   , b   , 5   )\"\n        ```\n\n        Args:\n          input_shape: a single shape or a structure of shapes for the inputs.\n        Returns:\n          the shape or shapes structure in the `jax2tf` format as strings.\n        \"\"\"\n        dim_names = itertools.chain(\n            string.ascii_lowercase,  # a, b, ... z\n            itertools.starmap(  # aa, ab, ... az, ba, bb, ... zz\n                lambda a, b: a + b,\n                itertools.product(string.ascii_lowercase, repeat=2),\n            ),\n        )\n\n        def get_single_jax2tf_shape(shape):\n            jax2tf_shape = []\n\n            for index, dim in enumerate(shape):\n                if dim is not None:\n                    jax2tf_shape.append(str(dim))\n                elif index == 0:\n                    jax2tf_shape.append(\"batch\")\n                else:\n                    jax2tf_shape.append(next(dim_names))\n\n            return \"(\" + \", \".join(jax2tf_shape) + \")\"\n\n        res = tree.map_shape_structure(get_single_jax2tf_shape, input_shape)\n        return res\n\n    def _jax2tf_convert(self, fn, polymorphic_shapes):\n        from jax.experimental import jax2tf\n\n        converted_fn = jax2tf.convert(fn, polymorphic_shapes=polymorphic_shapes)\n        # Autograph won't work with the output of jax2tf.\n        converted_fn = tf.autograph.experimental.do_not_convert(converted_fn)\n        return converted_fn\n\n    def _partial_with_positional(self, fn, index, value):\n        \"\"\"Return a new partial with one positional argument set to a value.\n\n        This is needed because `jax2tf` only supports positional arguments and\n        `functools.partial` only supports setting positional arguments 
starting\n        from the left. Our use case is the `training` argument, which is\n        typically the rightmost argument.\n\n        Args:\n          fn: the function to wrap.\n          index: the index of the positional argument to set to `value`.\n          value: the value for the positional argument at `index`.\n        \"\"\"\n\n        @functools.wraps(fn)\n        def wrapper(*args):\n            args = args[0:index] + (value,) + args[index:]\n            return fn(*args)\n\n        return wrapper\n\n    @tracking.no_automatic_dependency_tracking\n    @tf_no_automatic_dependency_tracking\n    def _create_variables(self, values, trainable):\n        \"\"\"Create a structure of variables from a structure of JAX arrays.\n\n        `values` is traversed via JAX's `tree_map`. When a leaf is a JAX array\n        or a tensor-like object, a corresponding variable is created with it as\n        the initial value. The resulting structure of variables is assigned to\n        `self.params` or `self.state` depending on `trainable`. 
Then, a\n        flattened version of the variables is returned for tracking.\n        `self.params` or `self.state` are intentionally not tracked because\n        structures like `TrackedList` interfere with `jax.tree_utils`.\n        Note that leaf objects that are not JAX arrays and not tensor-like are\n        left intact as they are assumed to be configuration used by the model.\n\n        Args:\n            values: the structure of values to traverse.\n            trainable: whether to create trainable variables.\n\n        Returns:\n            flat list of variables initialized with `values` for tracking.\n        \"\"\"\n\n        def create_variable(value):\n            if backend.is_tensor(value) or isinstance(\n                value, (np.ndarray, np.generic, jax.Array)\n            ):\n                dtype = value.dtype\n                if is_float_dtype(dtype):\n                    dtype = None  # Use the layer dtype policy\n                return self.add_weight(\n                    value.shape,\n                    initializer=backend.convert_to_tensor(value),\n                    dtype=dtype,\n                    trainable=trainable,\n                )\n            elif isinstance(value, (bool, int, float)):\n                dtype = standardize_dtype(type(value))\n                if is_float_dtype(dtype):\n                    dtype = None  # Use the layer dtype policy\n                return self.add_weight(\n                    (),\n                    initializer=backend.convert_to_tensor(value),\n                    dtype=dtype,\n                    trainable=trainable,\n                )\n            else:\n                return value\n\n        # Use JAX's tree_map as it understands registered classes.\n        variables = jax.tree_util.tree_map(create_variable, values)\n\n        if trainable:\n            self.params = variables\n        else:\n            self.state = variables\n\n        flat_variables, _ = 
jax.tree_util.tree_flatten(variables)\n        return flat_variables\n\n    def _get_init_seed(self):\n        \"\"\"\n        Returns a single seed as a tensor of shape [2].\n\n        Call this within `_get_init_rng()` to obtain a new seed.\n\n        Returns:\n            A native tensor of shape [2] and the backend dtype for seeds.\n        \"\"\"\n        # Use the global SeedGenerator.\n        return draw_seed(None)\n\n    def _get_init_rng(self):\n        \"\"\"\n        Returns a seed or seeds to pass as the `rng` argument of `init_fn`.\n\n        By default, this returns a single seed. Override this to return a\n        different structure. Overrides should use `self._get_init_seed()` to\n        obtain new seeds.\n\n        Returns:\n            RNG key or structure of keys as tensors of shape [2] and the backend\n            dtype for seeds.\n        \"\"\"\n        return self._get_init_seed()\n\n    def _get_call_seed(self):\n        \"\"\"\n        Returns a single seed as a tensor of shape [2].\n\n        Call this within `_get_call_rng()` to obtain a new seed.\n\n        Returns:\n            A native tensor of shape [2] and the backend dtype for seeds.\n        \"\"\"\n        return self.seed_generator.next()\n\n    def _get_call_rng(self, training):\n        \"\"\"\n        Returns a seed or seeds to pass as the `rng` argument of `call_fn`.\n\n        By default, this returns a seed when `training` is `True`, and `None`\n        when `training` is `False`. Override this to return a different\n        structure or to pass seeds in inference mode too. 
Overrides should use\n        `self._get_call_seed()` to obtain seeds.\n\n        Returns:\n            RNG key or structure of keys as tensors of shape [2] and the backend\n            dtype for seeds.\n        \"\"\"\n        if training:\n            return self._get_call_seed()\n        else:\n            return None\n\n    def _initialize_weights(self, input_shape):\n        if tf.inside_function():\n            raise ValueError(\"'JaxLayer' cannot be built inside tf function\")\n\n        # Initialize `params` and `state` if needed by calling `init_fn`.\n        def create_input(shape):\n            shape = [d if d is not None else 1 for d in shape]\n            return jax.numpy.ones(shape)\n\n        init_inputs = tree.map_shape_structure(create_input, input_shape)\n        if backend.backend() == \"jax\" and jax_utils.is_in_jax_tracing_scope(\n            tree.flatten(init_inputs)[0]\n        ):\n            raise ValueError(\"'JaxLayer' cannot be built in a tracing scope\")\n\n        init_args = []\n        for argument_name in self.init_fn_arguments:\n            if argument_name == \"rng\":\n                init_args.append(\n                    jax.tree_util.tree_map(\n                        lambda x: jax.numpy.array(_convert_to_jax_key(x)),\n                        self._get_init_rng(),\n                    )\n                )\n            elif argument_name == \"inputs\":\n                init_args.append(init_inputs)\n            elif argument_name == \"training\":\n                init_args.append(True)\n\n        init_result = self.init_fn(*init_args)\n        if self.call_fn_has_state:\n            init_params, init_state = init_result\n        else:\n            init_params, init_state = init_result, None\n\n        self.tracked_params = self._create_variables(\n            init_params, trainable=True\n        )\n        self.tracked_state = self._create_variables(init_state, trainable=False)\n\n    def build(self, input_shape):\n        if 
(\n            self.params is None\n            and self.state is None\n            and (self.call_fn_has_params or self.call_fn_has_state)\n        ):\n            self._initialize_weights(input_shape)\n\n        if backend.backend() == \"tensorflow\":\n            polymorphic_shapes = []\n            for argument in self.call_fn_arguments:\n                if argument == \"inputs\":\n                    polymorphic_shapes.append(\n                        self._get_jax2tf_input_shape(input_shape)\n                    )\n                elif argument != \"training\":\n                    # params, state, rng\n                    polymorphic_shapes.append(\"...\")\n\n            if \"training\" in self.call_fn_arguments:\n                training_argument_index = self.call_fn_arguments.index(\n                    \"training\"\n                )\n                self.jax2tf_training_false_fn = self._jax2tf_convert(\n                    self._partial_with_positional(\n                        self.call_fn, training_argument_index, False\n                    ),\n                    polymorphic_shapes,\n                )\n                self.jax2tf_training_true_fn = self._jax2tf_convert(\n                    self._partial_with_positional(\n                        self.call_fn, training_argument_index, True\n                    ),\n                    polymorphic_shapes,\n                )\n            else:\n                self.jax2tf_training_false_fn = self._jax2tf_convert(\n                    self.call_fn,\n                    polymorphic_shapes,\n                )\n                self.jax2tf_training_true_fn = None\n            super().build(input_shape)\n\n    def call(self, inputs, training=False):\n        def unwrap_variable(variable):\n            return None if variable is None else variable.value\n\n        call_args = []\n        for argument_name in self.call_fn_arguments:\n            if argument_name == \"params\":\n                call_args.append(\n 
                   jax.tree_util.tree_map(unwrap_variable, self.params)\n                )\n            elif argument_name == \"state\":\n                call_args.append(\n                    jax.tree_util.tree_map(unwrap_variable, self.state)\n                )\n            elif argument_name == \"rng\":\n                call_args.append(\n                    jax.tree_util.tree_map(\n                        _convert_to_jax_key, self._get_call_rng(training)\n                    )\n                )\n            elif argument_name == \"inputs\":\n                call_args.append(inputs)\n            elif argument_name == \"training\":\n                if backend.backend() == \"jax\":\n                    call_args.append(training)\n\n        def assign_state_to_variable(value, variable):\n            # This exists only to make debugging this error case easier.\n            if not hasattr(variable, \"assign\"):\n                raise ValueError(\n                    \"Structure mismatch: the structure of the state returned \"\n                    \"by `call` does not match the structure of the state at \"\n                    \"initialization time.\"\n                )\n            variable.assign(value)\n\n        def call_with_fn(fn):\n            if self.call_fn_has_state:\n                predictions, new_state = fn(*call_args)\n                jax.tree_util.tree_map(\n                    assign_state_to_variable, new_state, self.state\n                )\n                return predictions\n            else:\n                return fn(*call_args)\n\n        if backend.backend() == \"jax\":\n            return call_with_fn(self.call_fn)\n        elif backend.backend() == \"tensorflow\":\n            if training and self.jax2tf_training_true_fn is not None:\n                return call_with_fn(self.jax2tf_training_true_fn)\n            else:\n                return call_with_fn(self.jax2tf_training_false_fn)\n\n    def get_config(self):\n        config = {\n       
     \"call_fn\": serialization_lib.serialize_keras_object(self.call_fn),\n            \"init_fn\": serialization_lib.serialize_keras_object(self.init_fn),\n        }\n        base_config = super().get_config()\n        return dict(list(base_config.items()) + list(config.items()))\n\n    @classmethod\n    def from_config(cls, config):\n        call_fn = serialization_lib.deserialize_keras_object(config[\"call_fn\"])\n        init_fn = serialization_lib.deserialize_keras_object(config[\"init_fn\"])\n        config[\"call_fn\"] = call_fn\n        config[\"init_fn\"] = init_fn\n        return super().from_config(config)\n\n\n@keras_export(\"keras.layers.FlaxLayer\")\nclass FlaxLayer(JaxLayer):\n    \"\"\"Keras Layer that wraps a [Flax](https://flax.readthedocs.io) module.\n\n    This layer enables the use of Flax components in the form of\n    [`flax.linen.Module`](\n        https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html)\n    instances within Keras when using JAX as the backend for Keras.\n\n    The module method to use for the forward pass can be specified via the\n    `method` argument and is `__call__` by default. This method must take the\n    following arguments with these exact names:\n\n    - `self` if the method is bound to the module, which is the case for the\n        default of `__call__`, and `module` otherwise to pass the module.\n    - `inputs`: the inputs to the model, a JAX array or a `PyTree` of arrays.\n    - `training` *(optional)*: an argument specifying if we're in training mode\n        or inference mode, `True` is passed in training mode.\n\n    `FlaxLayer` handles the non-trainable state of your model and required RNGs\n    automatically. 
Note that the `mutable` parameter of\n    [`flax.linen.Module.apply()`](\n        https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.apply)\n    is set to `DenyList([\"params\"])`, therefore making the assumption that all\n    the variables outside of the \"params\" collection are non-trainable weights.\n\n    This example shows how to create a `FlaxLayer` from a Flax `Module` with\n    the default `__call__` method and no training argument:\n\n    ```python\n    class MyFlaxModule(flax.linen.Module):\n        @flax.linen.compact\n        def __call__(self, inputs):\n            x = inputs\n            x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)\n            x = flax.linen.relu(x)\n            x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n            x = x.reshape((x.shape[0], -1))  # flatten\n            x = flax.linen.Dense(features=200)(x)\n            x = flax.linen.relu(x)\n            x = flax.linen.Dense(features=10)(x)\n            x = flax.linen.softmax(x)\n            return x\n\n    flax_module = MyFlaxModule()\n    keras_layer = FlaxLayer(flax_module)\n    ```\n\n    This example shows how to wrap the module method to conform to the required\n    signature. This allows having multiple input arguments and a training\n    argument that has a different name and values. 
This additionally shows how\n    to use a function that is not bound to the module.\n\n    ```python\n    class MyFlaxModule(flax.linen.Module):\n        @flax.linen.compact\n        def forward(self, input1, input2, deterministic):\n            ...\n            return outputs\n\n    def my_flax_module_wrapper(module, inputs, training):\n        input1, input2 = inputs\n        return module.forward(input1, input2, not training)\n\n    flax_module = MyFlaxModule()\n    keras_layer = FlaxLayer(\n        module=flax_module,\n        method=my_flax_module_wrapper,\n    )\n    ```\n\n    Args:\n        module: An instance of `flax.linen.Module` or subclass.\n        method: The method to call the model. This is generally a method in the\n            `Module`. If not provided, the `__call__` method is used. `method`\n            can also be a function not defined in the `Module`, in which case it\n            must take the `Module` as the first argument. It is used for both\n            `Module.init` and `Module.apply`. Details are documented in the\n            `method` argument of [`flax.linen.Module.apply()`](\n              https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.apply).\n        variables: A `dict` containing all the variables of the module in the\n            same format as what is returned by [`flax.linen.Module.init()`](\n              https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html#flax.linen.init).\n            It should contain a \"params\" key and, if applicable, other keys for\n            collections of variables for non-trainable state. This allows\n            passing trained parameters and learned non-trainable state or\n            controlling the initialization. 
If `None` is passed, the module's\n            `init` function is called at build time to initialize the variables\n            of the model.\n    \"\"\"\n\n    def __init__(\n        self,\n        module,\n        method=None,\n        variables=None,\n        **kwargs,\n    ):\n        # Late import to only require Flax when this is used.\n        from flax.linen import DenyList\n\n        self.module = module\n        self.method = method\n\n        apply_mutable = DenyList([\"params\"])\n\n        def apply_with_training(params, state, rng, inputs, training):\n            return self.module.apply(\n                self._params_and_state_to_variables(params, state),\n                inputs,\n                rngs=rng,\n                method=self.method,\n                mutable=apply_mutable,\n                training=training,\n            )\n\n        def apply_without_training(params, state, rng, inputs):\n            return self.module.apply(\n                self._params_and_state_to_variables(params, state),\n                inputs,\n                rngs=rng,\n                method=self.method,\n                mutable=apply_mutable,\n            )\n\n        def init_with_training(rng, inputs, training):\n            return self._variables_to_params_and_state(\n                self.module.init(\n                    rng,\n                    inputs,\n                    method=self.method,\n                    training=training,\n                )\n            )\n\n        def init_without_training(rng, inputs):\n            return self._variables_to_params_and_state(\n                self.module.init(\n                    rng,\n                    inputs,\n                    method=self.method,\n                )\n            )\n\n        if (\n            \"training\"\n            in inspect.signature(method or module.__call__).parameters\n        ):\n            call_fn, init_fn = apply_with_training, init_with_training\n        else:\n            
call_fn, init_fn = apply_without_training, init_without_training\n\n        params, state = self._variables_to_params_and_state(variables)\n\n        super().__init__(\n            call_fn=call_fn,\n            init_fn=init_fn,\n            params=params,\n            state=state,\n            **kwargs,\n        )\n\n    def _params_and_state_to_variables(self, params, state):\n        if params:\n            if state:\n                return {**params, **state}\n            else:\n                return params\n        elif state:\n            return state\n        return {}\n\n    def _variables_to_params_and_state(self, variables):\n        # neither params nor state\n        if variables is None:\n            return None, None\n        # state only\n        if \"params\" not in variables:\n            return {}, variables\n        # params only\n        if len(variables) == 1:\n            return variables, {}\n        # both, we need to split\n        params = {\"params\": variables[\"params\"]}\n        state = {k: v for k, v in variables.items() if k != \"params\"}\n        return params, state\n\n    def _get_init_rng(self):\n        return {\n            \"params\": self._get_init_seed(),\n            \"dropout\": self._get_init_seed(),\n        }\n\n    def _get_call_rng(self, training):\n        if training:\n            return {\"dropout\": self._get_call_seed()}\n        else:\n            return {}\n\n    def get_config(self):\n        config_method = self.method\n        if (\n            hasattr(self.method, \"__self__\")\n            and self.method.__self__ == self.module\n        ):\n            # A method bound to the module is serialized by name.\n            config_method = self.method.__name__\n        config = {\n            \"module\": serialization_lib.serialize_keras_object(self.module),\n            \"method\": serialization_lib.serialize_keras_object(config_method),\n        }\n        base_config = super().get_config()\n        # 
call_fn and init_fn come from module, do not save them.\n        base_config.pop(\"call_fn\")\n        base_config.pop(\"init_fn\")\n        return dict(list(base_config.items()) + list(config.items()))\n\n    @classmethod\n    def from_config(cls, config):\n        module = serialization_lib.deserialize_keras_object(config[\"module\"])\n        method = serialization_lib.deserialize_keras_object(config[\"method\"])\n        if isinstance(config[\"method\"], str):\n            # Deserialize bound method from the module.\n            method = getattr(module, method)\n        config[\"module\"] = module\n        config[\"method\"] = method\n        return cls(**config)\n"
  },
  {
    "path": "keras/src/utils/jax_layer_test.py",
    "content": "import math\nimport os\n\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\nimport pytest\nimport tensorflow as tf\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import metrics\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import random\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src import tree\nfrom keras.src import utils\nfrom keras.src.dtype_policies.dtype_policy import DTypePolicy\nfrom keras.src.saving import object_registration\nfrom keras.src.utils.jax_layer import FlaxLayer\nfrom keras.src.utils.jax_layer import JaxLayer\n\ntry:\n    import flax\nexcept ImportError:\n    flax = None\n\nnum_classes = 10\ninput_shape = (28, 28, 1)  # Excluding batch_size\n\n\n@object_registration.register_keras_serializable()\ndef jax_stateless_function(inputs):\n    x = jax.numpy.sum(inputs, axis=(1, 2, 3))\n    x = jax.numpy.expand_dims(x, axis=1)\n    x = jax.numpy.tile(x, (1, 10))\n    return x\n\n\n@object_registration.register_keras_serializable()\ndef jax_model_no_state_init(rng, inputs):\n    layer_sizes = [784, 300, 100, 10]\n    params = []\n    w_init = jax.nn.initializers.glorot_normal()\n    b_init = jax.nn.initializers.normal(0.1)\n    for m, n in zip(layer_sizes[:-1], layer_sizes[1:]):\n        rng, w_rng = jax.random.split(rng)\n        rng, b_rng = jax.random.split(rng)\n        params.append([w_init(w_rng, (m, n)), b_init(b_rng, (n,))])\n    return params\n\n\n@object_registration.register_keras_serializable()\ndef jax_model_no_state_apply(params, inputs):\n    activations = inputs.reshape((inputs.shape[0], -1))  # flatten\n    for w, b in params[:-1]:\n        outputs = jnp.dot(activations, w) + b\n        activations = jnp.tanh(outputs)\n\n    final_w, final_b = params[-1]\n    logits = jnp.dot(activations, final_w) + final_b\n    return jax.nn.softmax(logits, 
axis=-1)\n\n\n@object_registration.register_keras_serializable()\ndef jax_model_with_state_init(rng, inputs, training):\n    params = jax_model_no_state_init(rng, inputs)\n    state = jnp.zeros([], jnp.int32)\n    return params, state\n\n\n@object_registration.register_keras_serializable()\ndef jax_model_with_state_apply(params, state, inputs, training):\n    outputs = jax_model_no_state_apply(params, inputs)\n    if training:\n        state = state + 1\n    return outputs, state\n\n\nif flax is not None:\n\n    @object_registration.register_keras_serializable()\n    class FlaxTrainingIndependentModel(flax.linen.Module):\n        @flax.linen.compact\n        def forward(self, inputs):\n            x = inputs\n            x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)\n            x = flax.linen.relu(x)\n            x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n            x = flax.linen.Conv(features=64, kernel_size=(3, 3))(x)\n            x = flax.linen.relu(x)\n            x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n            x = x.reshape((x.shape[0], -1))  # flatten\n            x = flax.linen.Dense(features=200)(x)\n            x = flax.linen.relu(x)\n            x = flax.linen.Dense(features=10)(x)\n            x = flax.linen.softmax(x)\n            return x\n\n        def get_config(self):\n            return {}\n\n        @classmethod\n        def from_config(cls, config):\n            return cls(**config)\n\n    @object_registration.register_keras_serializable()\n    class FlaxDropoutModel(flax.linen.Module):\n        @flax.linen.compact\n        def my_apply(self, inputs, training):\n            x = inputs\n            x = flax.linen.Conv(features=32, kernel_size=(3, 3))(x)\n            x = flax.linen.relu(x)\n            x = flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n            x = flax.linen.Conv(features=64, kernel_size=(3, 3))(x)\n            x = flax.linen.relu(x)\n            x = 
flax.linen.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n            x = x.reshape((x.shape[0], -1))  # flatten\n            x = flax.linen.Dense(features=200)(x)\n            x = flax.linen.Dropout(rate=0.3, deterministic=not training)(x)\n            x = flax.linen.relu(x)\n            x = flax.linen.Dense(features=10)(x)\n            x = flax.linen.softmax(x)\n            return x\n\n        def get_config(self):\n            return {}\n\n        @classmethod\n        def from_config(cls, config):\n            return cls(**config)\n\n    @object_registration.register_keras_serializable()\n    def flax_dropout_wrapper(module, x, training):\n        return module.my_apply(x, training)\n\n    @object_registration.register_keras_serializable()\n    class FlaxBatchNormModel(flax.linen.Module):\n        @flax.linen.compact\n        def __call__(self, inputs, training=False):\n            ura = not training\n            x = inputs\n            x = flax.linen.Conv(\n                features=12, kernel_size=(3, 3), use_bias=False\n            )(x)\n            x = flax.linen.BatchNorm(use_running_average=ura, use_scale=False)(\n                x\n            )\n            x = flax.linen.relu(x)\n            x = flax.linen.Conv(\n                features=24, kernel_size=(6, 6), strides=(2, 2)\n            )(x)\n            x = flax.linen.BatchNorm(use_running_average=ura, use_scale=False)(\n                x\n            )\n            x = flax.linen.relu(x)\n            x = flax.linen.Conv(\n                features=32, kernel_size=(6, 6), strides=(2, 2)\n            )(x)\n            x = flax.linen.BatchNorm(use_running_average=ura, use_scale=False)(\n                x\n            )\n            x = x.reshape((x.shape[0], -1))  # flatten\n            x = flax.linen.Dense(features=200, use_bias=True)(x)\n            x = flax.linen.BatchNorm(use_running_average=ura, use_scale=False)(\n                x\n            )\n            x = flax.linen.Dropout(rate=0.3, 
deterministic=not training)(x)\n            x = flax.linen.relu(x)\n            x = flax.linen.Dense(features=10)(x)\n            x = flax.linen.softmax(x)\n            return x\n\n        def get_config(self):\n            return {}\n\n        @classmethod\n        def from_config(cls, config):\n            return cls(**config)\n\n    FLAX_OBJECTS = {\n        \"FlaxTrainingIndependentModel\": FlaxTrainingIndependentModel,\n        \"FlaxBatchNormModel\": FlaxBatchNormModel,\n        \"FlaxDropoutModel\": FlaxDropoutModel,\n        \"flax_dropout_wrapper\": flax_dropout_wrapper,\n    }\n\n\n@pytest.mark.skipif(\n    backend.backend() not in [\"jax\", \"tensorflow\"],\n    reason=\"JaxLayer and FlaxLayer are only supported with JAX and TF backend\",\n)\n@pytest.mark.skipif(testing.tensorflow_uses_gpu(), reason=\"GPU test failure\")\nclass TestJaxLayer(testing.TestCase):\n    def _test_layer(\n        self,\n        model_name,\n        layer_class,\n        layer_init_kwargs,\n        trainable_weights,\n        trainable_params,\n        non_trainable_weights,\n        non_trainable_params,\n    ):\n        # Fake MNIST data\n        x_train = random.uniform(shape=(320, 28, 28, 1))\n        y_train_indices = ops.cast(\n            ops.random.uniform(shape=(320,), minval=0, maxval=num_classes),\n            dtype=\"int32\",\n        )\n        y_train = ops.one_hot(y_train_indices, num_classes, dtype=\"int32\")\n        x_test = random.uniform(shape=(32, 28, 28, 1))\n\n        def _count_params(weights):\n            count = 0\n            for weight in weights:\n                count = count + math.prod(ops.shape(weight))\n            return count\n\n        def verify_weights_and_params(layer):\n            self.assertEqual(trainable_weights, len(layer.trainable_weights))\n            self.assertEqual(\n                trainable_params,\n                _count_params(layer.trainable_weights),\n            )\n            self.assertEqual(\n                
non_trainable_weights, len(layer.non_trainable_weights)\n            )\n            self.assertEqual(\n                non_trainable_params,\n                _count_params(layer.non_trainable_weights),\n            )\n\n        # functional model\n        layer1 = layer_class(**layer_init_kwargs)\n        inputs1 = layers.Input(shape=input_shape)\n        outputs1 = layer1(inputs1)\n        model1 = models.Model(\n            inputs=inputs1, outputs=outputs1, name=f\"{model_name}1\"\n        )\n        model1.summary()\n\n        verify_weights_and_params(layer1)\n\n        model1.compile(\n            loss=\"categorical_crossentropy\",\n            optimizer=\"adam\",\n            metrics=[metrics.CategoricalAccuracy()],\n        )\n\n        tw1_before_fit = tree.map_structure(\n            backend.convert_to_numpy, layer1.trainable_weights\n        )\n        ntw1_before_fit = tree.map_structure(\n            backend.convert_to_numpy, layer1.non_trainable_weights\n        )\n        model1.fit(x_train, y_train, epochs=1, steps_per_epoch=10)\n        tw1_after_fit = tree.map_structure(\n            backend.convert_to_numpy, layer1.trainable_weights\n        )\n        ntw1_after_fit = tree.map_structure(\n            backend.convert_to_numpy, layer1.non_trainable_weights\n        )\n\n        # verify both trainable and non-trainable weights did change after fit\n        for before, after in zip(tw1_before_fit, tw1_after_fit):\n            self.assertNotAllClose(before, after)\n        for before, after in zip(ntw1_before_fit, ntw1_after_fit):\n            self.assertNotAllClose(before, after)\n\n        expected_output_shape = (ops.shape(x_test)[0], num_classes)\n        output1 = model1(x_test)\n        self.assertEqual(output1.shape, expected_output_shape)\n        predict1 = model1.predict(x_test, steps=1)\n        self.assertEqual(predict1.shape, expected_output_shape)\n\n        # verify both trainable and non-trainable weights did not change\n        
tw1_after_call = tree.map_structure(\n            backend.convert_to_numpy, layer1.trainable_weights\n        )\n        ntw1_after_call = tree.map_structure(\n            backend.convert_to_numpy, layer1.non_trainable_weights\n        )\n        for after_fit, after_call in zip(tw1_after_fit, tw1_after_call):\n            self.assertAllClose(after_fit, after_call)\n        for after_fit, after_call in zip(ntw1_after_fit, ntw1_after_call):\n            self.assertAllClose(after_fit, after_call)\n\n        exported_params = jax.tree_util.tree_map(\n            backend.convert_to_numpy, layer1.params\n        )\n        if layer1.state is not None:\n            exported_state = jax.tree_util.tree_map(\n                backend.convert_to_numpy, layer1.state\n            )\n        else:\n            exported_state = None\n\n        def verify_identical_model(model):\n            output = model(x_test)\n            self.assertAllClose(output1, output)\n\n            predict = model.predict(x_test, steps=1)\n            self.assertAllClose(predict1, predict)\n\n        # sequential model to compare results\n        layer2 = layer_class(\n            params=exported_params,\n            state=exported_state,\n            input_shape=input_shape,\n            **layer_init_kwargs,\n        )\n        model2 = models.Sequential([layer2], name=f\"{model_name}2\")\n        model2.summary()\n        verify_weights_and_params(layer2)\n        model2.compile(\n            loss=\"categorical_crossentropy\",\n            optimizer=\"adam\",\n            metrics=[metrics.CategoricalAccuracy()],\n        )\n        verify_identical_model(model2)\n\n        # save, load back and compare results\n        path = os.path.join(self.get_temp_dir(), \"jax_layer_model.keras\")\n        model2.save(path)\n\n        model3 = saving.load_model(path)\n        layer3 = model3.layers[0]\n        model3.summary()\n        verify_weights_and_params(layer3)\n        
verify_identical_model(model3)\n\n        # export, load back and compare results\n        path = os.path.join(self.get_temp_dir(), \"jax_layer_export\")\n        model2.export(path, format=\"tf_saved_model\")\n        model4 = tf.saved_model.load(path)\n        output4 = model4.serve(x_test)\n        # The output difference is greater when using the GPU or bfloat16\n        lower_precision = testing.jax_uses_gpu() or \"dtype\" in layer_init_kwargs\n        self.assertAllClose(\n            output1,\n            output4,\n            atol=1e-2 if lower_precision else 1e-6,\n            rtol=1e-3 if lower_precision else 1e-6,\n        )\n\n        # test subclass model building without a build method\n        class TestModel(models.Model):\n            def __init__(self, layer):\n                super().__init__()\n                self._layer = layer\n\n            def call(self, inputs):\n                return self._layer(inputs)\n\n        layer5 = layer_class(**layer_init_kwargs)\n        model5 = TestModel(layer5)\n        output5 = model5(x_test)\n        self.assertNotAllClose(output5, 0.0)\n\n    @parameterized.named_parameters(\n        {\n            \"testcase_name\": \"stateless\",\n            \"init_kwargs\": {\n                \"call_fn\": jax_stateless_function,\n                \"init_fn\": None,\n            },\n            \"trainable_weights\": 0,\n            \"trainable_params\": 0,\n            \"non_trainable_weights\": 0,\n            \"non_trainable_params\": 0,\n        },\n        {\n            \"testcase_name\": \"training_independent\",\n            \"init_kwargs\": {\n                \"call_fn\": jax_model_no_state_apply,\n                \"init_fn\": jax_model_no_state_init,\n            },\n            \"trainable_weights\": 6,\n            \"trainable_params\": 266610,\n            \"non_trainable_weights\": 0,\n            \"non_trainable_params\": 0,\n        },\n        {\n            \"testcase_name\": \"training_state\",\n     
       \"init_kwargs\": {\n                \"call_fn\": jax_model_with_state_apply,\n                \"init_fn\": jax_model_with_state_init,\n            },\n            \"trainable_weights\": 6,\n            \"trainable_params\": 266610,\n            \"non_trainable_weights\": 1,\n            \"non_trainable_params\": 1,\n        },\n        {\n            \"testcase_name\": \"training_state_dtype_policy\",\n            \"init_kwargs\": {\n                \"call_fn\": jax_model_with_state_apply,\n                \"init_fn\": jax_model_with_state_init,\n                \"dtype\": DTypePolicy(\"mixed_float16\"),\n            },\n            \"trainable_weights\": 6,\n            \"trainable_params\": 266610,\n            \"non_trainable_weights\": 1,\n            \"non_trainable_params\": 1,\n        },\n    )\n    def test_jax_layer(\n        self,\n        init_kwargs,\n        trainable_weights,\n        trainable_params,\n        non_trainable_weights,\n        non_trainable_params,\n    ):\n        self._test_layer(\n            init_kwargs[\"call_fn\"].__name__,\n            JaxLayer,\n            init_kwargs,\n            trainable_weights,\n            trainable_params,\n            non_trainable_weights,\n            non_trainable_params,\n        )\n\n    @parameterized.named_parameters(\n        {\n            \"testcase_name\": \"training_independent_bound_method\",\n            \"flax_model_class\": \"FlaxTrainingIndependentModel\",\n            \"flax_model_method\": \"forward\",\n            \"init_kwargs\": {},\n            \"trainable_weights\": 8,\n            \"trainable_params\": 648226,\n            \"non_trainable_weights\": 0,\n            \"non_trainable_params\": 0,\n        },\n        {\n            \"testcase_name\": \"training_rng_unbound_method\",\n            \"flax_model_class\": \"FlaxDropoutModel\",\n            \"flax_model_method\": None,\n            \"init_kwargs\": {\n                \"method\": \"flax_dropout_wrapper\",\n      
      },\n            \"trainable_weights\": 8,\n            \"trainable_params\": 648226,\n            \"non_trainable_weights\": 0,\n            \"non_trainable_params\": 0,\n        },\n        {\n            \"testcase_name\": \"training_rng_state_no_method\",\n            \"flax_model_class\": \"FlaxBatchNormModel\",\n            \"flax_model_method\": None,\n            \"init_kwargs\": {},\n            \"trainable_weights\": 13,\n            \"trainable_params\": 354258,\n            \"non_trainable_weights\": 8,\n            \"non_trainable_params\": 536,\n        },\n        {\n            \"testcase_name\": \"training_rng_unbound_method_dtype_policy\",\n            \"flax_model_class\": \"FlaxDropoutModel\",\n            \"flax_model_method\": None,\n            \"init_kwargs\": {\n                \"method\": \"flax_dropout_wrapper\",\n                \"dtype\": DTypePolicy(\"mixed_float16\"),\n            },\n            \"trainable_weights\": 8,\n            \"trainable_params\": 648226,\n            \"non_trainable_weights\": 0,\n            \"non_trainable_params\": 0,\n        },\n    )\n    @pytest.mark.skipif(flax is None, reason=\"Flax library is not available.\")\n    def test_flax_layer(\n        self,\n        flax_model_class,\n        flax_model_method,\n        init_kwargs,\n        trainable_weights,\n        trainable_params,\n        non_trainable_weights,\n        non_trainable_params,\n    ):\n        flax_model_class = FLAX_OBJECTS.get(flax_model_class)\n        if \"method\" in init_kwargs:\n            init_kwargs[\"method\"] = FLAX_OBJECTS.get(init_kwargs[\"method\"])\n\n        def create_wrapper(**kwargs):\n            params = kwargs.pop(\"params\") if \"params\" in kwargs else None\n            state = kwargs.pop(\"state\") if \"state\" in kwargs else None\n            if params and state:\n                variables = {**params, **state}\n            elif params:\n                variables = params\n            elif state:\n     
           variables = state\n            else:\n                variables = None\n            kwargs[\"variables\"] = variables\n            flax_model = flax_model_class()\n            if flax_model_method:\n                kwargs[\"method\"] = getattr(flax_model, flax_model_method)\n            return FlaxLayer(flax_model, **kwargs)\n\n        self._test_layer(\n            flax_model_class.__name__,\n            create_wrapper,\n            init_kwargs,\n            trainable_weights,\n            trainable_params,\n            non_trainable_weights,\n            non_trainable_params,\n        )\n\n    def test_with_no_init_fn_and_no_params(self):\n        def jax_fn(params, inputs):\n            return inputs\n\n        with self.assertRaises(ValueError):\n            JaxLayer(jax_fn)\n\n    def test_with_training_in_call_fn_but_not_init_fn(self):\n        def jax_call_fn(params, state, rng, inputs, training):\n            return inputs, {}\n\n        def jax_init_fn(rng, inputs):\n            return {}, {}\n\n        layer = JaxLayer(jax_call_fn, jax_init_fn)\n        layer(np.ones((1,)))\n\n    def test_with_different_argument_order(self):\n        def jax_call_fn(training, inputs, rng, state, params):\n            return inputs, {}\n\n        def jax_init_fn(training, inputs, rng):\n            return {}, {}\n\n        layer = JaxLayer(jax_call_fn, jax_init_fn)\n        layer(np.ones((1,)))\n\n    def test_with_minimal_arguments(self):\n        def jax_call_fn(inputs):\n            return inputs\n\n        def jax_init_fn(inputs):\n            return {}\n\n        layer = JaxLayer(jax_call_fn, jax_init_fn)\n        layer(np.ones((1,)))\n\n    def test_with_missing_inputs_in_call_fn(self):\n        def jax_call_fn(params, rng, training):\n            return jnp.ones((1,))\n\n        def jax_init_fn(rng, inputs):\n            return {}\n\n        with self.assertRaisesRegex(ValueError, \"`call_fn`.*`inputs`\"):\n            JaxLayer(jax_call_fn, 
jax_init_fn)\n\n    def test_with_missing_inputs_in_init_fn(self):\n        def jax_call_fn(params, rng, inputs, training):\n            return jnp.ones((1,))\n\n        def jax_init_fn(rng, training):\n            return {}\n\n        with self.assertRaisesRegex(ValueError, \"`init_fn`.*`inputs`\"):\n            JaxLayer(jax_call_fn, jax_init_fn)\n\n    def test_with_unsupported_argument_in_call_fn(self):\n        def jax_call_fn(params, rng, inputs, mode):\n            return jnp.ones((1,))\n\n        def jax_init_fn(rng, inputs):\n            return {}\n\n        with self.assertRaisesRegex(ValueError, \"`call_fn`.*`mode`\"):\n            JaxLayer(jax_call_fn, jax_init_fn)\n\n    def test_with_unsupported_argument_in_init_fn(self):\n        def jax_call_fn(params, rng, inputs, training):\n            return inputs\n\n        def jax_init_fn(rng, inputs, mode):\n            return {}\n\n        with self.assertRaisesRegex(ValueError, \"`init_fn`.*`mode`\"):\n            JaxLayer(jax_call_fn, jax_init_fn)\n\n    def test_with_structures_as_inputs_and_outputs(self):\n        def jax_fn(params, inputs):\n            a = inputs[\"a\"]\n            b = inputs[\"b\"]\n            output1 = jnp.concatenate([a, b], axis=1)\n            output2 = jnp.concatenate([b, a], axis=1)\n            return output1, output2\n\n        layer = JaxLayer(jax_fn, params={})\n        inputs = {\n            \"a\": layers.Input((None, 3)),\n            \"b\": layers.Input((None, 3)),\n        }\n        outputs = layer(inputs)\n        model = models.Model(inputs, outputs)\n\n        test_inputs = {\n            \"a\": np.ones((2, 6, 3)),\n            \"b\": np.ones((2, 7, 3)),\n        }\n        test_outputs = model(test_inputs)\n        self.assertAllClose(test_outputs[0], np.ones((2, 13, 3)))\n        self.assertAllClose(test_outputs[1], np.ones((2, 13, 3)))\n\n    def test_with_polymorphic_shape_more_than_26_dimension_names(self):\n        def jax_fn(params, inputs):\n            
return jnp.concatenate(inputs, axis=1)\n\n        layer = JaxLayer(jax_fn, params=())\n        inputs = [layers.Input((None, 3)) for _ in range(60)]\n        output = layer(inputs)\n        model = models.Model(inputs, output)\n\n        test_inputs = [np.ones((2, 1, 3))] * 60\n        test_output = model(test_inputs)\n        self.assertAllClose(test_output, np.ones((2, 60, 3)))\n\n    @pytest.mark.skipif(flax is None, reason=\"Flax library is not available.\")\n    def test_with_flax_state_no_params(self):\n        class MyFlaxLayer(flax.linen.Module):\n            @flax.linen.compact\n            def __call__(self, x):\n                def zeros_init(shape):\n                    return jnp.zeros(shape, jnp.int32)\n\n                count = self.variable(\"a\", \"b\", zeros_init, [])\n                count.value = count.value + 1\n                return x\n\n        layer = FlaxLayer(MyFlaxLayer(), variables={\"a\": {\"b\": 0}})\n        layer(np.ones((1,)))\n        self.assertLen(layer.params, 0)\n        self.assertEqual(layer.state[\"a\"][\"b\"].value, 1)\n\n    def test_with_state_none_leaves(self):\n        def jax_fn(params, state, inputs):\n            return inputs, state\n\n        layer = JaxLayer(jax_fn, state={\"foo\": None})\n        self.assertIsNone(layer.state[\"foo\"])\n        layer(np.ones((1,)))\n\n    def test_with_state_non_tensor_leaves(self):\n        def jax_fn(params, state, inputs):\n            return inputs, state\n\n        layer = JaxLayer(jax_fn, state={\"foo\": \"bar\"})\n        self.assertEqual(layer.state[\"foo\"], \"bar\")\n        # layer cannot be invoked as jax2tf will fail on strings\n\n    def test_with_state_jax_registered_node_class(self):\n        @jax.tree_util.register_pytree_node_class\n        class NamedPoint:\n            def __init__(self, x, y, name):\n                self.x = x\n                self.y = y\n                self.name = name\n\n            def tree_flatten(self):\n                return 
((self.x, self.y), self.name)\n\n            @classmethod\n            def tree_unflatten(cls, aux_data, children):\n                return cls(*children, aux_data)\n\n        def jax_fn(params, state, inputs):\n            return inputs, state\n\n        layer = JaxLayer(jax_fn, state=[NamedPoint(1.0, 2.0, \"foo\")])\n        layer(np.ones((1,)))\n\n    @parameterized.named_parameters(\n        {\n            \"testcase_name\": \"sequence_instead_of_mapping\",\n            \"init_state\": [0.0],\n            \"error_regex\": \"Expected dict, got \",\n        },\n        {\n            \"testcase_name\": \"mapping_instead_of_sequence\",\n            \"init_state\": {\"state\": {\"foo\": 0.0}},\n            \"error_regex\": \"Expected list, got \",\n        },\n        {\n            \"testcase_name\": \"sequence_instead_of_variable\",\n            \"init_state\": {\"state\": [[0.0]]},\n            \"error_regex\": \"Structure mismatch\",\n        },\n        {\n            \"testcase_name\": \"no_initial_state\",\n            \"init_state\": None,\n            \"error_regex\": \"Expected dict, got None\",\n        },\n        {\n            \"testcase_name\": \"missing_dict_key\",\n            \"init_state\": {\"state\": {}},\n            \"error_regex\": \"Expected list, got \",\n        },\n        {\n            \"testcase_name\": \"missing_variable_in_list\",\n            \"init_state\": {\"state\": {\"foo\": [2.0]}},\n            \"error_regex\": \"Expected list, got \",\n        },\n    )\n    def test_state_mismatch_during_update(self, init_state, error_regex):\n        def jax_fn(params, state, inputs):\n            return inputs, {\"state\": [jnp.ones([])]}\n\n        layer = JaxLayer(jax_fn, params={}, state=init_state)\n        with self.assertRaisesRegex(ValueError, error_regex):\n            layer(np.ones((1,)))\n\n    def test_rng_seeding(self):\n        def jax_init(rng, inputs):\n            return [jax.nn.initializers.normal(1.0)(rng, 
inputs.shape)]\n\n        def jax_apply(params, inputs):\n            return jnp.dot(inputs, params[0])\n\n        shape = (2, 2)\n\n        utils.set_random_seed(0)\n        layer1 = JaxLayer(jax_apply, jax_init)\n        layer1.build(shape)\n        utils.set_random_seed(0)\n        layer2 = JaxLayer(jax_apply, jax_init)\n        layer2.build(shape)\n        self.assertAllClose(layer1.params[0], layer2.params[0])\n"
  },
  {
    "path": "keras/src/utils/jax_utils.py",
    "content": "from keras.src import backend\n\n\ndef is_in_jax_tracing_scope(x=None):\n    if backend.backend() == \"jax\":\n        import jax\n\n        if x is None:\n            x = backend.numpy.ones(())\n        return isinstance(x, jax.core.Tracer)\n    return False\n"
  },
  {
    "path": "keras/src/utils/model_visualization.py",
    "content": "\"\"\"Utilities related to model visualization.\"\"\"\n\nimport os\nimport sys\n\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils import io_utils\n\ntry:\n    import pydot\nexcept ImportError:\n    # pydot_ng and pydotplus are older forks of pydot\n    # which may still be used by some users\n    try:\n        import pydot_ng as pydot\n    except ImportError:\n        try:\n            import pydotplus as pydot\n        except ImportError:\n            pydot = None\n\n\ndef check_pydot():\n    \"\"\"Returns True if PyDot is available.\"\"\"\n    return pydot is not None\n\n\ndef check_graphviz():\n    \"\"\"Returns True if both PyDot and Graphviz are available.\"\"\"\n    if not check_pydot():\n        return False\n    try:\n        # Attempt to create an image of a blank graph\n        # to check the pydot/graphviz installation.\n        pydot.Dot.create(pydot.Dot())\n        return True\n    except (OSError, pydot.PydotException):\n        return False\n\n\ndef add_edge(dot, src, dst):\n    src_id = str(id(src))\n    dst_id = str(id(dst))\n    if not dot.get_edge(src_id, dst_id):\n        edge = pydot.Edge(src_id, dst_id)\n        edge.set(\"penwidth\", \"2\")\n        dot.add_edge(edge)\n\n\ndef get_layer_activation_name(layer):\n    if hasattr(layer.activation, \"name\"):\n        activation_name = layer.activation.name\n    elif hasattr(layer.activation, \"__name__\"):\n        activation_name = layer.activation.__name__\n    else:\n        activation_name = str(layer.activation)\n    return activation_name\n\n\ndef make_layer_label(layer, **kwargs):\n    class_name = layer.__class__.__name__\n\n    show_layer_names = kwargs.pop(\"show_layer_names\")\n    show_layer_activations = kwargs.pop(\"show_layer_activations\")\n    show_dtype = kwargs.pop(\"show_dtype\")\n    show_shapes = kwargs.pop(\"show_shapes\")\n    show_trainable = kwargs.pop(\"show_trainable\")\n    if kwargs:\n        raise 
ValueError(f\"Invalid kwargs: {kwargs}\")\n\n    table = (\n        '<<table border=\"0\" cellborder=\"1\" bgcolor=\"black\" cellpadding=\"10\">'\n    )\n\n    colspan_max = sum(int(x) for x in (show_dtype, show_trainable))\n    if show_shapes:\n        colspan_max += 2\n    colspan = max(1, colspan_max)\n\n    if show_layer_names:\n        table += (\n            f'<tr><td colspan=\"{colspan}\" bgcolor=\"black\">'\n            '<font point-size=\"16\" color=\"white\">'\n            f\"<b>{layer.name}</b> ({class_name})\"\n            \"</font></td></tr>\"\n        )\n    else:\n        table += (\n            f'<tr><td colspan=\"{colspan}\" bgcolor=\"black\">'\n            '<font point-size=\"16\" color=\"white\">'\n            f\"<b>{class_name}</b>\"\n            \"</font></td></tr>\"\n        )\n    if (\n        show_layer_activations\n        and hasattr(layer, \"activation\")\n        and layer.activation is not None\n    ):\n        table += (\n            f'<tr><td bgcolor=\"white\" colspan=\"{colspan}\">'\n            '<font point-size=\"14\">'\n            f\"Activation: <b>{get_layer_activation_name(layer)}</b>\"\n            \"</font></td></tr>\"\n        )\n\n    cols = []\n    if show_shapes:\n        input_shape = None\n        output_shape = None\n        try:\n            input_shape = tree.map_structure(lambda x: x.shape, layer.input)\n            output_shape = tree.map_structure(lambda x: x.shape, layer.output)\n        except (ValueError, AttributeError):\n            pass\n\n        def format_shape(shape):\n            if shape is not None:\n                if isinstance(shape, dict):\n                    shape_str = \", \".join(\n                        [f\"{k}: {v}\" for k, v in shape.items()]\n                    )\n                else:\n                    shape_str = f\"{shape}\"\n                shape_str = shape_str.replace(\"}\", \"\").replace(\"{\", \"\")\n            else:\n                shape_str = \"?\"\n            return 
shape_str\n\n        if class_name != \"InputLayer\":\n            cols.append(\n                (\n                    '<td bgcolor=\"white\"><font point-size=\"14\">'\n                    f\"Input shape: <b>{format_shape(input_shape)}</b>\"\n                    \"</font></td>\"\n                )\n            )\n        cols.append(\n            (\n                '<td bgcolor=\"white\"><font point-size=\"14\">'\n                f\"Output shape: <b>{format_shape(output_shape)}</b>\"\n                \"</font></td>\"\n            )\n        )\n    if show_dtype:\n        dtype = None\n        try:\n            dtype = tree.map_structure(lambda x: x.dtype, layer.output)\n        except (ValueError, AttributeError):\n            pass\n        cols.append(\n            (\n                '<td bgcolor=\"white\"><font point-size=\"14\">'\n                f\"Output dtype: <b>{dtype or '?'}</b>\"\n                \"</font></td>\"\n            )\n        )\n    if show_trainable and hasattr(layer, \"trainable\") and layer.weights:\n        if layer.trainable:\n            cols.append(\n                (\n                    '<td bgcolor=\"forestgreen\">'\n                    '<font point-size=\"14\" color=\"white\">'\n                    \"<b>Trainable</b></font></td>\"\n                )\n            )\n        else:\n            cols.append(\n                (\n                    '<td bgcolor=\"firebrick\">'\n                    '<font point-size=\"14\" color=\"white\">'\n                    \"<b>Non-trainable</b></font></td>\"\n                )\n            )\n    if cols:\n        colspan = len(cols)\n    else:\n        colspan = 1\n\n    if cols:\n        table += f\"<tr>{''.join(cols)}</tr>\"\n    table += \"</table>>\"\n    return table\n\n\ndef make_node(layer, **kwargs):\n    node = pydot.Node(str(id(layer)), label=make_layer_label(layer, **kwargs))\n    node.set(\"fontname\", \"Helvetica\")\n    node.set(\"border\", \"0\")\n    node.set(\"margin\", \"0\")\n    
return node\n\n\n@keras_export(\"keras.utils.model_to_dot\")\ndef model_to_dot(\n    model,\n    show_shapes=False,\n    show_dtype=False,\n    show_layer_names=True,\n    rankdir=\"TB\",\n    expand_nested=False,\n    dpi=200,\n    subgraph=False,\n    show_layer_activations=False,\n    show_trainable=False,\n    splines=\"ortho\",\n    **kwargs,\n):\n    \"\"\"Convert a Keras model to dot format.\n\n    Args:\n        model: A Keras model instance.\n        show_shapes: whether to display shape information.\n        show_dtype: whether to display layer dtypes.\n        show_layer_names: whether to display layer names.\n        rankdir: `rankdir` argument passed to PyDot,\n            a string specifying the format of the plot: `\"TB\"`\n            creates a vertical plot; `\"LR\"` creates a horizontal plot.\n        expand_nested: whether to expand nested Functional models\n            into clusters.\n        dpi: Image resolution in dots per inch.\n        subgraph: whether to return a `pydot.Cluster` instance.\n        show_layer_activations: Display layer activations (only for layers that\n            have an `activation` property).\n        show_trainable: whether to display if a layer is trainable.\n        splines: Controls how edges are drawn. Defaults to `\"ortho\"`\n            (right-angle lines). Other options include `\"curved\"`,\n            `\"polyline\"`, `\"spline\"`, and `\"line\"`.\n\n    Returns:\n        A `pydot.Dot` instance representing the Keras model or\n        a `pydot.Cluster` instance representing nested model if\n        `subgraph=True`.\n    \"\"\"\n    from keras.src.ops.function import make_node_key\n\n    if not model.built:\n        raise ValueError(\n            \"This model has not yet been built. 
\"\n            \"Build the model first by calling `build()` or by calling \"\n            \"the model on a batch of data.\"\n        )\n\n    from keras.src.models import functional\n    from keras.src.models import sequential\n\n    # from keras.src.layers import Wrapper\n\n    if not check_pydot():\n        raise ImportError(\n            \"You must install pydot (`pip install pydot`) for \"\n            \"model_to_dot to work.\"\n        )\n\n    if subgraph:\n        dot = pydot.Cluster(style=\"dashed\", graph_name=model.name)\n        dot.set(\"label\", model.name)\n        dot.set(\"labeljust\", \"l\")\n    else:\n        dot = pydot.Dot()\n        dot.set(\"rankdir\", rankdir)\n        dot.set(\"concentrate\", True)\n        dot.set(\"dpi\", dpi)\n        dot.set(\"splines\", splines)\n        dot.set_node_defaults(shape=\"record\")\n\n    if kwargs.pop(\"layer_range\", None) is not None:\n        raise ValueError(\"Argument `layer_range` is no longer supported.\")\n    if kwargs:\n        raise ValueError(f\"Unrecognized keyword arguments: {kwargs}\")\n\n    kwargs = {\n        \"show_layer_names\": show_layer_names,\n        \"show_layer_activations\": show_layer_activations,\n        \"show_dtype\": show_dtype,\n        \"show_shapes\": show_shapes,\n        \"show_trainable\": show_trainable,\n    }\n\n    if isinstance(model, sequential.Sequential):\n        layers = model.layers\n    elif not isinstance(model, functional.Functional):\n        # We treat subclassed models as a single node.\n        node = make_node(model, **kwargs)\n        dot.add_node(node)\n        return dot\n    else:\n        layers = model._operations\n\n    # Create graph nodes.\n    for i, layer in enumerate(layers):\n        # Process nested functional and sequential models.\n        if expand_nested and isinstance(\n            layer, (functional.Functional, sequential.Sequential)\n        ):\n            submodel = model_to_dot(\n                layer,\n                
show_shapes,\n                show_dtype,\n                show_layer_names,\n                rankdir,\n                expand_nested,\n                subgraph=True,\n                show_layer_activations=show_layer_activations,\n                show_trainable=show_trainable,\n            )\n            dot.add_subgraph(submodel)\n\n        else:\n            node = make_node(layer, **kwargs)\n            dot.add_node(node)\n\n    # Connect nodes with edges.\n    if isinstance(model, sequential.Sequential):\n        if not expand_nested:\n            # Single Sequential case.\n            for i in range(len(layers) - 1):\n                add_edge(dot, layers[i], layers[i + 1])\n            return dot\n        else:\n            # The first layer is connected to the `InputLayer`, which is not\n            # represented for Sequential models, so we skip it. What will draw\n            # the incoming edge from outside of the sequential model is the\n            # edge connecting the Sequential model itself.\n            layers = model.layers[1:]\n\n    # Functional and nested Sequential case.\n    for layer in layers:\n        # Go from current layer to input `Node`s.\n        for inbound_index, inbound_node in enumerate(layer._inbound_nodes):\n            # `inbound_node` is a `Node`.\n            if (\n                isinstance(model, functional.Functional)\n                and make_node_key(layer, inbound_index) not in model._nodes\n            ):\n                continue\n\n            # Go from input `Node` to `KerasTensor` representing that input.\n            for input_index, input_tensor in enumerate(\n                inbound_node.input_tensors\n            ):\n                # `input_tensor` is a `KerasTensor`.\n                # `input_history` is a `KerasHistory`.\n                input_history = input_tensor._keras_history\n                if input_history.operation is None:\n                    # Operation is `None` for `Input` tensors.\n             
       continue\n\n                # Go from input `KerasTensor` to the `Operation` that produced\n                # it as an output.\n                input_node = input_history.operation._inbound_nodes[\n                    input_history.node_index\n                ]\n                output_index = input_history.tensor_index\n\n                # Tentative source and destination of the edge.\n                source = input_node.operation\n                destination = layer\n\n                if not expand_nested:\n                    # No nesting, connect directly.\n                    add_edge(dot, source, layer)\n                    continue\n\n                # ==== Potentially nested models case ====\n\n                # ---- Resolve the source of the edge ----\n                while isinstance(\n                    source,\n                    (functional.Functional, sequential.Sequential),\n                ):\n                    # When `source` is a `Functional` or `Sequential` model, we\n                    # need to connect to the correct box within that model.\n                    # Functional and sequential models do not have explicit\n                    # \"output\" boxes, so we need to find the correct layer that\n                    # produces the output we're connecting to, which can be\n                    # nested several levels deep in sub-models. 
Hence the while\n                    # loop to continue going into nested models until we\n                    # encounter a real layer that's not a `Functional` or\n                    # `Sequential`.\n                    source, _, output_index = source.outputs[\n                        output_index\n                    ]._keras_history\n\n                # ---- Resolve the destination of the edge ----\n                while isinstance(\n                    destination,\n                    (functional.Functional, sequential.Sequential),\n                ):\n                    if isinstance(destination, functional.Functional):\n                        # When `destination` is a `Functional`, we point to the\n                        # specific `InputLayer` in the model.\n                        destination = destination.inputs[\n                            input_index\n                        ]._keras_history.operation\n                    else:\n                        # When `destination` is a `Sequential`, there is no\n                        # explicit \"input\" box, so we want to point to the first\n                        # box in the model, but it may itself be another model.\n                        # Hence the while loop to continue going into nested\n                        # models until we encounter a real layer that's not a\n                        # `Functional` or `Sequential`.\n                        destination = destination.layers[0]\n\n                add_edge(dot, source, destination)\n    return dot\n\n\n@keras_export(\"keras.utils.plot_model\")\ndef plot_model(\n    model,\n    to_file=\"model.png\",\n    show_shapes=False,\n    show_dtype=False,\n    show_layer_names=False,\n    rankdir=\"TB\",\n    expand_nested=False,\n    dpi=200,\n    show_layer_activations=False,\n    show_trainable=False,\n    splines=\"ortho\",\n    **kwargs,\n):\n    \"\"\"Converts a Keras model to dot format and saves it to a file.\n\n    Example:\n\n    ```python\n    
inputs = ...\n    outputs = ...\n    model = keras.Model(inputs=inputs, outputs=outputs)\n\n    dot_img_file = '/tmp/model_1.png'\n    keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)\n    ```\n\n    Args:\n        model: A Keras model instance\n        to_file: File name of the plot image.\n        show_shapes: whether to display shape information.\n        show_dtype: whether to display layer dtypes.\n        show_layer_names: whether to display layer names.\n        rankdir: `rankdir` argument passed to PyDot,\n            a string specifying the format of the plot: `\"TB\"`\n            creates a vertical plot; `\"LR\"` creates a horizontal plot.\n        expand_nested: whether to expand nested Functional models\n            into clusters.\n        dpi: Image resolution in dots per inch.\n        show_layer_activations: Display layer activations (only for layers that\n            have an `activation` property).\n        show_trainable: whether to display if a layer is trainable.\n        splines: Controls how edges are drawn. Defaults to `\"ortho\"`\n            (right-angle lines). Other options include `\"curved\"`,\n            `\"polyline\"`, `\"spline\"`, and `\"line\"`.\n\n    Returns:\n        A Jupyter notebook Image object if Jupyter is installed.\n        This enables in-line display of the model plots in notebooks.\n    \"\"\"\n\n    if not model.built:\n        raise ValueError(\n            \"This model has not yet been built. 
\"\n            \"Build the model first by calling `build()` or by calling \"\n            \"the model on a batch of data.\"\n        )\n    if not check_pydot():\n        message = (\n            \"You must install pydot (`pip install pydot`) \"\n            \"for `plot_model` to work.\"\n        )\n        if \"IPython.core.magics.namespace\" in sys.modules:\n            # We don't raise an exception here in order to avoid crashing\n            # notebook tests where graphviz is not available.\n            io_utils.print_msg(message)\n            return\n        else:\n            raise ImportError(message)\n    if not check_graphviz():\n        message = (\n            \"You must install graphviz \"\n            \"(see instructions at https://graphviz.gitlab.io/download/) \"\n            \"for `plot_model` to work.\"\n        )\n        if \"IPython.core.magics.namespace\" in sys.modules:\n            # We don't raise an exception here in order to avoid crashing\n            # notebook tests where graphviz is not available.\n            io_utils.print_msg(message)\n            return\n        else:\n            raise ImportError(message)\n\n    if kwargs.pop(\"layer_range\", None) is not None:\n        raise ValueError(\"Argument `layer_range` is no longer supported.\")\n    if kwargs:\n        raise ValueError(f\"Unrecognized keyword arguments: {kwargs}\")\n\n    dot = model_to_dot(\n        model,\n        show_shapes=show_shapes,\n        show_dtype=show_dtype,\n        show_layer_names=show_layer_names,\n        rankdir=rankdir,\n        expand_nested=expand_nested,\n        dpi=dpi,\n        show_layer_activations=show_layer_activations,\n        show_trainable=show_trainable,\n        splines=splines,\n    )\n    to_file = str(to_file)\n    if dot is None:\n        return\n    _, extension = os.path.splitext(to_file)\n    if not extension:\n        extension = \"png\"\n    else:\n        extension = extension[1:]\n    # Save image to disk.\n    
dot.write(to_file, format=extension)\n    # Return the image as a Jupyter Image object, to be displayed in-line.\n    # Note that we cannot easily detect whether the code is running in a\n    # notebook, and thus we always return the Image if Jupyter is available.\n    if extension != \"pdf\":\n        try:\n            from IPython import display\n\n            return display.Image(filename=to_file)\n        except ImportError:\n            pass\n"
  },
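The save path in `plot_model` above derives the image format from the suffix of `to_file`, defaulting to `png` when there is none. A minimal standalone sketch of that convention; the helper name `resolve_extension` is hypothetical, not part of the Keras API:

```python
import os


def resolve_extension(to_file):
    # Mirrors the logic above: an empty suffix means "png",
    # otherwise the suffix minus its leading dot is the format.
    _, extension = os.path.splitext(str(to_file))
    return "png" if not extension else extension[1:]


print(resolve_extension("model"))      # no suffix -> "png"
print(resolve_extension("arch.svg"))   # -> "svg"
```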
  {
    "path": "keras/src/utils/module_utils.py",
    "content": "import importlib\n\n\nclass LazyModule:\n    def __init__(self, name, pip_name=None, import_error_msg=None):\n        self.name = name\n        self.pip_name = pip_name or name\n        self.import_error_msg = import_error_msg or (\n            f\"This requires the {self.name} module. \"\n            f\"You can install it via `pip install {self.pip_name}`\"\n        )\n        self.module = None\n        self._available = None\n\n    @property\n    def available(self):\n        if self._available is None:\n            try:\n                self.initialize()\n                self._available = True\n            except ImportError:\n                self._available = False\n        return self._available\n\n    def initialize(self):\n        try:\n            self.module = importlib.import_module(self.name)\n        except ImportError:\n            raise ImportError(self.import_error_msg)\n\n    def __getattr__(self, name):\n        if name == \"_api_export_path\":\n            raise AttributeError\n        if self.module is None:\n            self.initialize()\n        return getattr(self.module, name)\n\n    def __repr__(self):\n        return f\"LazyModule({self.name})\"\n\n\nclass OrbaxLazyModule(LazyModule):\n    def initialize(self):\n        try:\n            parent_module = importlib.import_module(\"orbax.checkpoint\")\n            self.module = parent_module.v1\n            self.parent_module = parent_module\n        except ImportError:\n            raise ImportError(self.import_error_msg)\n\n    def __getattr__(self, name):\n        if name == \"_api_export_path\":\n            raise AttributeError\n        if self.module is None:\n            self.initialize()\n        if name == \"multihost\":\n            return self.parent_module.multihost\n        return getattr(self.module, name)\n\n\ntensorflow = LazyModule(\"tensorflow\")\ngfile = LazyModule(\"tensorflow.io.gfile\", pip_name=\"tensorflow\")\ntensorflow_io = 
LazyModule(\"tensorflow_io\")\nscipy = LazyModule(\"scipy\")\njax = LazyModule(\"jax\")\nh5py = LazyModule(\"h5py\")\ntorch_xla = LazyModule(\n    \"torch_xla\",\n    import_error_msg=(\n        \"This requires the torch_xla module. You can install it via \"\n        \"`pip install torch-xla`. Additionally, you may need to update \"\n        \"LD_LIBRARY_PATH if necessary. Torch XLA builds a shared library, \"\n        \"_XLAC.so, which needs to link to the version of Python it was built \"\n        \"with. Use the following command to update LD_LIBRARY_PATH: \"\n        \"`export LD_LIBRARY_PATH=<path to Python>/lib:$LD_LIBRARY_PATH`\"\n    ),\n)\noptree = LazyModule(\"optree\")\ndmtree = LazyModule(\"tree\")\ntf2onnx = LazyModule(\"tf2onnx\")\ngrain = LazyModule(\"grain\")\nlitert = LazyModule(\"ai_edge_litert\")\nocp = OrbaxLazyModule(\n    \"orbax.checkpoint.v1\",\n    pip_name=\"orbax-checkpoint\",\n    import_error_msg=(\n        \"OrbaxCheckpoint requires the 'orbax-checkpoint' package. \"\n        \"You can install it via pip install orbax-checkpoint\"\n    ),\n)\n"
  },
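The lazy-import pattern in `module_utils.py` can be exercised in isolation. This is a condensed sketch of the `LazyModule` class above (dropping `pip_name` and the custom error message), not the full implementation:

```python
import importlib


class LazyModule:
    """Defers importing a module until an attribute is first accessed."""

    def __init__(self, name):
        self.name = name
        self.module = None

    @property
    def available(self):
        # Probe by attempting the import; swallow failure into a boolean.
        try:
            self._load()
            return True
        except ImportError:
            return False

    def _load(self):
        if self.module is None:
            self.module = importlib.import_module(self.name)

    def __getattr__(self, name):
        # Only reached for attributes not set in __init__, so the
        # import is triggered exactly when the module is first used.
        self._load()
        return getattr(self.module, name)


math_mod = LazyModule("math")
print(math_mod.sqrt(9.0))  # import happens here
```

The point of the pattern is that merely constructing the `LazyModule` never imports anything, so optional dependencies cost nothing until used.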
  {
    "path": "keras/src/utils/naming.py",
    "content": "import collections\nimport re\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\n\n\ndef auto_name(prefix):\n    prefix = to_snake_case(prefix)\n    return uniquify(prefix)\n\n\ndef uniquify(name):\n    object_name_uids = global_state.get_global_attribute(\n        \"object_name_uids\",\n        default=collections.defaultdict(int),\n        set_to_default=True,\n    )\n    if name in object_name_uids:\n        unique_name = f\"{name}_{object_name_uids[name]}\"\n    else:\n        unique_name = name\n    object_name_uids[name] += 1\n    return unique_name\n\n\ndef to_snake_case(name):\n    name = re.sub(r\"\\W+\", \"\", name)\n    name = re.sub(\"(.)([A-Z][a-z]+)\", r\"\\1_\\2\", name)\n    name = re.sub(\"([a-z])([A-Z])\", r\"\\1_\\2\", name).lower()\n    return name\n\n\n@keras_export(\"keras.backend.get_uid\")\ndef get_uid(prefix=\"\"):\n    \"\"\"Associates a string prefix with an integer counter.\n\n    Args:\n        prefix: String prefix to index.\n\n    Returns:\n        Unique integer ID.\n\n    Example:\n\n    >>> get_uid('dense')\n    1\n    >>> get_uid('dense')\n    2\n    \"\"\"\n    object_name_uids = global_state.get_global_attribute(\n        \"object_name_uids\",\n        default=collections.defaultdict(int),\n        set_to_default=True,\n    )\n    object_name_uids[prefix] += 1\n    return object_name_uids[prefix]\n\n\ndef reset_uids():\n    global_state.set_global_attribute(\n        \"object_name_uids\", collections.defaultdict(int)\n    )\n\n\ndef get_object_name(obj):\n    if hasattr(obj, \"name\"):  # Most Keras objects.\n        return obj.name\n    elif hasattr(obj, \"__name__\"):  # Function.\n        return to_snake_case(obj.__name__)\n    elif hasattr(obj, \"__class__\"):  # Class instance.\n        return to_snake_case(obj.__class__.__name__)\n    return to_snake_case(str(obj))\n"
  },
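The uniquifying scheme in `naming.py` can be sketched with a module-level `defaultdict` standing in for the Keras `global_state` attribute: the first use of a name is returned unchanged, and each later use gets an incrementing suffix.

```python
import collections

# Stand-in for the `object_name_uids` global state used above.
_uids = collections.defaultdict(int)


def uniquify(name):
    # First use returns the name as-is; later uses append the counter.
    unique = name if _uids[name] == 0 else f"{name}_{_uids[name]}"
    _uids[name] += 1
    return unique


print(uniquify("dense"))  # "dense"
print(uniquify("dense"))  # "dense_1"
```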
  {
    "path": "keras/src/utils/naming_test.py",
    "content": "from keras.src.testing import test_case\nfrom keras.src.utils import naming\n\n\nclass NamingUtilsTest(test_case.TestCase):\n    def test_uniquify_unique_name(self):\n        name = \"the_unique_name\"\n        unique_name = naming.uniquify(name)\n        self.assertEqual(unique_name, name)\n\n    def test_auto_name(self):\n        self.assertEqual(naming.auto_name(\"unique_name\"), \"unique_name\")\n        self.assertEqual(naming.auto_name(\"unique_name\"), \"unique_name_1\")\n        self.assertEqual(naming.auto_name(\"unique_name\"), \"unique_name_2\")\n\n    def test_get_uid(self):\n        self.assertEqual(naming.get_uid(\"very_unique_name\"), 1)\n        self.assertEqual(naming.get_uid(\"very_unique_name\"), 2)\n        self.assertEqual(naming.get_uid(\"very_unique_name\"), 3)\n\n    def test_uniquify_non_unique_name(self):\n        name = \"non_unique_name\"\n        naming.uniquify(name)\n        unique_name = naming.uniquify(name)\n        self.assertEqual(unique_name, f\"{name}_1\")\n\n    def test_to_snake_case_snake_case_name(self):\n        name = \"snake_case_name\"\n        snake_case_name = naming.to_snake_case(name)\n        self.assertEqual(snake_case_name, name)\n\n    def test_get_uid_existing_prefix(self):\n        prefix = \"existing_prefix\"\n        naming.get_uid(prefix)\n        uid = naming.get_uid(prefix)\n        self.assertEqual(uid, 2)\n\n    def test_reset_uids(self):\n        naming.get_uid(\"unique_name\")\n        naming.reset_uids()\n        uid = naming.get_uid(\"unique_name\")\n        self.assertEqual(uid, 1)\n\n    def test_get_object_name_no_name_attribute(self):\n        class ObjectWithoutName:\n            __name__ = \"ObjectWithoutName\"\n\n        obj = ObjectWithoutName()\n        object_name = naming.get_object_name(obj)\n        self.assertEqual(object_name, \"object_without_name\")\n\n    def test_get_object_name_no_name_or_class_attribute(self):\n        class ObjectWithoutNameOrClass:\n            
pass\n\n        obj = ObjectWithoutNameOrClass()\n        object_name = naming.get_object_name(obj)\n        self.assertEqual(object_name, \"object_without_name_or_class\")\n\n    def test_uniquify_already_uniquified_name(self):\n        name = \"unique_name\"\n        unique_name = naming.uniquify(name)\n        new_unique_name = naming.uniquify(unique_name)\n\n        # first time `name` is uniquified so returns same name\n        self.assertEqual(name, unique_name)\n\n        # second time `name` is uniquified should be different\n        # from the first output\n        self.assertNotEqual(new_unique_name, unique_name)\n\n    def test_to_snake_case_capital_after_any_character(self):\n        name = \"myVariableNameHere\"\n        snake_case_name = naming.to_snake_case(name)\n        self.assertEqual(snake_case_name, \"my_variable_name_here\")\n\n    def test_to_snake_case_lower_before_upper(self):\n        name = \"convertTHIS\"\n        snake_case_name = naming.to_snake_case(name)\n        self.assertEqual(snake_case_name, \"convert_this\")\n\n    def test_to_snake_case_already_snake_cased(self):\n        name = \"already_snake_cased\"\n        snake_case_name = naming.to_snake_case(name)\n        self.assertEqual(snake_case_name, name)\n\n    def test_to_snake_case_no_changes(self):\n        name = \"lowercase\"\n        snake_case_name = naming.to_snake_case(name)\n        self.assertEqual(snake_case_name, name)\n\n    def test_to_snake_case_single_uppercase_word(self):\n        name = \"UPPERCASE\"\n        snake_case_name = naming.to_snake_case(name)\n        self.assertEqual(snake_case_name, \"uppercase\")\n\n    def test_get_object_name_for_keras_objects(self):\n        class MockKerasObject:\n            name = \"mock_object\"\n\n        obj = MockKerasObject()\n        result = naming.get_object_name(obj)\n        self.assertEqual(\n            result, \"mock_object\", f\"Expected 'mock_object' but got {result}\"\n        )\n\n    # Test for function 
objects that have a `__name__` attribute.\n    def test_get_object_name_for_functions(self):\n        def mock_function():\n            pass\n\n        result = naming.get_object_name(mock_function)\n        # Assumes to_snake_case works correctly.\n        expected_name = naming.to_snake_case(mock_function.__name__)\n        self.assertEqual(\n            result,\n            expected_name,\n            f\"Expected '{expected_name}' but got {result}\",\n        )\n"
  },
  {
    "path": "keras/src/utils/numerical_utils.py",
"content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils import tf_utils\n\n\n@keras_export(\"keras.utils.normalize\")\ndef normalize(x, axis=-1, order=2):\n    \"\"\"Normalizes an array.\n\n    If the input is a NumPy array, a NumPy array will be returned.\n    If it's a backend tensor, a backend tensor will be returned.\n\n    Args:\n        x: Array to normalize.\n        axis: axis along which to normalize.\n        order: Normalization order (e.g. `order=2` for L2 norm).\n\n    Returns:\n        A normalized copy of the array.\n    \"\"\"\n    from keras.src import ops\n\n    if isinstance(x, np.ndarray):\n        # NumPy input\n        norm = np.atleast_1d(np.linalg.norm(x, order, axis))\n        norm[norm == 0] = 1\n\n        # `axis` cannot be `None` below; `axis or -1` would also wrongly\n        # remap `axis=0` to `-1`, so compare against `None` explicitly.\n        if axis is None:\n            axis = -1\n        return x / np.expand_dims(norm, axis)\n\n    # Backend tensor input\n    return ops.nn.normalize(x, axis=axis, order=order)\n\n\n@keras_export(\"keras.utils.to_categorical\")\ndef to_categorical(x, num_classes=None):\n    \"\"\"Converts a class vector (integers) to binary class matrix.\n\n    E.g. for use with `categorical_crossentropy`.\n\n    Args:\n        x: Array-like with class values to be converted into a matrix\n            (integers from 0 to `num_classes - 1`).\n        num_classes: Total number of classes. If `None`, this would be inferred\n            as `max(x) + 1`. Defaults to `None`.\n\n    Returns:\n        A binary matrix representation of the input as a NumPy array. The class\n        axis is placed last.\n\n    Example:\n\n    >>> a = keras.utils.to_categorical([0, 1, 2, 3], num_classes=4)\n    >>> print(a)\n    [[1. 0. 0. 0.]\n     [0. 1. 0. 0.]\n     [0. 0. 1. 0.]\n     [0. 0. 0. 1.]]\n\n    >>> b = np.array([.9, .04, .03, .03,\n    ...               .3, .45, .15, .13,\n    ...               .04, .01, .94, .05,\n    ...               
.12, .21, .5, .17]).reshape(4,4)\n    >>> loss = keras.ops.categorical_crossentropy(a, b)\n    >>> print(np.around(loss, 5))\n    [0.10536 0.82807 0.1011  1.77196]\n\n    >>> loss = keras.ops.categorical_crossentropy(a, a)\n    >>> print(np.around(loss, 5))\n    [0. 0. 0. 0.]\n    \"\"\"\n    if backend.is_tensor(x):\n        input_shape = backend.core.shape(x)\n        # Shrink the last dimension if the shape is (..., 1).\n        if (\n            input_shape is not None\n            and len(input_shape) > 1\n            and input_shape[-1] == 1\n        ):\n            newshape = tuple(input_shape[:-1])\n            x = backend.numpy.reshape(x, newshape)\n        return backend.nn.one_hot(x, num_classes)\n    x = np.array(x, dtype=\"int64\")\n    input_shape = x.shape\n\n    # Shrink the last dimension if the shape is (..., 1).\n    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:\n        input_shape = tuple(input_shape[:-1])\n\n    x = x.reshape(-1)\n    if not num_classes:\n        num_classes = np.max(x) + 1\n    batch_size = x.shape[0]\n    categorical = np.zeros((batch_size, num_classes))\n    categorical[np.arange(batch_size), x] = 1\n    output_shape = input_shape + (num_classes,)\n    categorical = np.reshape(categorical, output_shape)\n    return categorical\n\n\ndef encode_categorical_inputs(\n    inputs,\n    output_mode,\n    depth,\n    dtype,\n    sparse=False,\n    count_weights=None,\n    backend_module=None,\n):\n    \"\"\"Encodes categorical inputs according to output_mode.\n\n    Args:\n        inputs: the inputs to encode.\n        output_mode: one of `\"int\"`, `\"one_hot\"`, `\"multi_hot\"`, or `\"count\"`.\n        depth: number of classes, this will be the last dimension of the output.\n        dtype: the dtype of the output, unless `count_weights` is not `None`.\n        sparse: whether the output should be sparse for backends supporting it.\n        count_weights: weights to apply if `output_mode` is `\"count\"`.\n     
   backend_module: the backend to use instead of the current one.\n\n    Returns: the encoded inputs.\n    \"\"\"\n    backend_module = backend_module or backend\n\n    if output_mode == \"int\":\n        return backend_module.cast(inputs, dtype=dtype)\n\n    rank_of_inputs = len(backend_module.shape(inputs))\n\n    # In all cases, we should uprank scalar input to a single sample.\n    if rank_of_inputs == 0:\n        inputs = backend_module.numpy.expand_dims(inputs, -1)\n        rank_of_inputs = 1\n\n    if (\n        backend_module.__name__.endswith(\"tensorflow\")\n        and rank_of_inputs <= 2\n        and output_mode in (\"multi_hot\", \"count\")\n    ):\n        # TF only fastpath. Uses bincount; faster. Doesn't work for rank 3+.\n        try:\n            return tf_utils.tf_encode_categorical_inputs(\n                inputs,\n                output_mode,\n                depth,\n                dtype=dtype,\n                sparse=sparse,\n                count_weights=count_weights,\n            )\n        except ValueError:\n            pass\n\n    if output_mode == \"multi_hot\":\n        return backend_module.nn.multi_hot(\n            inputs, depth, dtype=dtype, sparse=sparse\n        )\n    elif output_mode == \"one_hot\":\n        input_shape = backend_module.core.shape(inputs)\n        # Shrink the last dimension if the shape is (..., 1).\n        if (\n            input_shape is not None\n            and len(input_shape) > 1\n            and input_shape[-1] == 1\n        ):\n            newshape = tuple(input_shape[:-1])\n            inputs = backend_module.numpy.reshape(inputs, newshape)\n        return backend_module.nn.one_hot(\n            inputs, depth, dtype=dtype, sparse=sparse\n        )\n    elif output_mode == \"count\":\n        # We don't use `ops.bincount` because its output has a dynamic shape\n        # (last dimension is the highest value of `inputs`). 
We implement a\n        # narrower use case where `minlength` and `maxlength` (not supported by\n        # `ops.bincount`) are the same and static value: `depth`. We also don't\n        # need to support indices that are negative or greater than `depth`.\n        reduction_axis = 1 if len(inputs.shape) > 1 else 0\n\n        if count_weights is not None:\n            dtype = count_weights.dtype\n        one_hot_encoding = backend_module.nn.one_hot(\n            inputs, depth, dtype=dtype, sparse=sparse\n        )\n        if count_weights is not None:\n            count_weights = backend_module.numpy.expand_dims(count_weights, -1)\n            one_hot_encoding = one_hot_encoding * count_weights\n\n        outputs = backend_module.numpy.sum(\n            one_hot_encoding,\n            axis=reduction_axis,\n        )\n        return outputs\n\n\ndef build_pos_neg_masks(\n    query_labels,\n    key_labels,\n    remove_diagonal=True,\n):\n    from keras.src import ops\n\n    if ops.ndim(query_labels) == 1:\n        query_labels = ops.reshape(query_labels, (-1, 1))\n\n    if ops.ndim(key_labels) == 1:\n        key_labels = ops.reshape(key_labels, (-1, 1))\n\n    positive_mask = ops.equal(query_labels, ops.transpose(key_labels))\n    negative_mask = ops.logical_not(positive_mask)\n\n    if remove_diagonal:\n        positive_mask = ops.logical_and(\n            positive_mask,\n            ~ops.eye(\n                ops.size(query_labels),\n                ops.size(key_labels),\n                k=0,\n                dtype=\"bool\",\n            ),\n        )\n\n    return positive_mask, negative_mask\n"
  },
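The NumPy branch of `to_categorical` above reduces to: squeeze a trailing size-1 axis, flatten, scatter ones, and reshape back. A condensed sketch of just that path (the backend-tensor branch is omitted):

```python
import numpy as np


def to_categorical(x, num_classes=None):
    x = np.array(x, dtype="int64")
    input_shape = x.shape
    # A trailing dimension of size 1 is squeezed, matching the source.
    if input_shape and input_shape[-1] == 1 and len(input_shape) > 1:
        input_shape = tuple(input_shape[:-1])
    x = x.reshape(-1)
    if not num_classes:
        num_classes = np.max(x) + 1
    # Scatter a 1 per row at the class index, then restore the shape.
    categorical = np.zeros((x.shape[0], num_classes))
    categorical[np.arange(x.shape[0]), x] = 1
    return categorical.reshape(input_shape + (num_classes,))


print(to_categorical([0, 2, 1], num_classes=3))
```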
  {
    "path": "keras/src/utils/numerical_utils_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.utils import numerical_utils\n\nNUM_CLASSES = 5\n\n\nclass TestNumericalUtils(testing.TestCase):\n    @parameterized.parameters(\n        [\n            ((1,), (1, NUM_CLASSES)),\n            ((3,), (3, NUM_CLASSES)),\n            ((4, 3), (4, 3, NUM_CLASSES)),\n            ((5, 4, 3), (5, 4, 3, NUM_CLASSES)),\n            ((3, 1), (3, NUM_CLASSES)),\n            ((3, 2, 1), (3, 2, NUM_CLASSES)),\n        ]\n    )\n    def test_to_categorical(self, shape, expected_shape):\n        label = np.random.randint(0, NUM_CLASSES, shape)\n        one_hot = numerical_utils.to_categorical(label, NUM_CLASSES)\n        # Check shape\n        self.assertEqual(one_hot.shape, expected_shape)\n        # Make sure there is only one 1 in a row\n        self.assertTrue(np.all(one_hot.sum(axis=-1) == 1))\n        # Get original labels back from one hots\n        self.assertTrue(\n            np.all(np.argmax(one_hot, -1).reshape(label.shape) == label)\n        )\n\n    def test_to_categorical_without_num_classes(self):\n        label = [0, 2, 5]\n        one_hot = numerical_utils.to_categorical(label)\n        self.assertEqual(one_hot.shape, (3, 5 + 1))\n\n    def test_to_categorical_with_backend_tensor(self):\n        label = backend.convert_to_tensor(np.array([0, 2, 1, 3, 4]))\n        expected = backend.convert_to_tensor(\n            np.array(\n                [\n                    [1, 0, 0, 0, 0],\n                    [0, 0, 1, 0, 0],\n                    [0, 1, 0, 0, 0],\n                    [0, 0, 0, 1, 0],\n                    [0, 0, 0, 0, 1],\n                ]\n            )\n        )\n        one_hot = numerical_utils.to_categorical(label, NUM_CLASSES)\n        self.assertTrue(backend.is_tensor(one_hot))\n        self.assertAllClose(one_hot, expected)\n\n    @parameterized.parameters([1, 2, 3])\n    def 
test_normalize(self, order):\n        xb = backend.random.uniform((3, 3), seed=1337)\n        xnp = backend.convert_to_numpy(xb)\n\n        # Expected result\n        l2 = np.atleast_1d(np.linalg.norm(xnp, order, axis=-1))\n        l2[l2 == 0] = 1\n        expected = xnp / np.expand_dims(l2, axis=-1)\n\n        # Test NumPy\n        out = numerical_utils.normalize(xnp, axis=-1, order=order)\n        self.assertIsInstance(out, np.ndarray)\n        self.assertAllClose(out, expected)\n\n        # Test backend\n        out = numerical_utils.normalize(xb, axis=-1, order=order)\n        self.assertTrue(backend.is_tensor(out))\n        self.assertAllClose(backend.convert_to_numpy(out), expected)\n\n    def test_build_pos_neg_masks(self):\n        query_labels = np.array([0, 1, 2, 2, 0])\n        key_labels = np.array([0, 1, 2, 0, 2])\n        expected_shape = (len(query_labels), len(key_labels))\n\n        positive_mask, negative_mask = numerical_utils.build_pos_neg_masks(\n            query_labels, key_labels, remove_diagonal=False\n        )\n\n        positive_mask = backend.convert_to_numpy(positive_mask)\n        negative_mask = backend.convert_to_numpy(negative_mask)\n        self.assertEqual(positive_mask.shape, expected_shape)\n        self.assertEqual(negative_mask.shape, expected_shape)\n        self.assertTrue(\n            np.all(np.logical_not(np.logical_and(positive_mask, negative_mask)))\n        )\n\n        expected_positive_mask_keep_diag = np.array(\n            [\n                [1, 0, 0, 1, 0],\n                [0, 1, 0, 0, 0],\n                [0, 0, 1, 0, 1],\n                [0, 0, 1, 0, 1],\n                [1, 0, 0, 1, 0],\n            ],\n            dtype=\"bool\",\n        )\n        self.assertTrue(\n            np.all(positive_mask == expected_positive_mask_keep_diag)\n        )\n        self.assertTrue(\n            np.all(\n                negative_mask\n                == np.logical_not(expected_positive_mask_keep_diag)\n            )\n  
      )\n\n        positive_mask, negative_mask = numerical_utils.build_pos_neg_masks(\n            query_labels, key_labels, remove_diagonal=True\n        )\n        positive_mask = backend.convert_to_numpy(positive_mask)\n        negative_mask = backend.convert_to_numpy(negative_mask)\n        self.assertEqual(positive_mask.shape, expected_shape)\n        self.assertEqual(negative_mask.shape, expected_shape)\n        self.assertTrue(\n            np.all(np.logical_not(np.logical_and(positive_mask, negative_mask)))\n        )\n\n        expected_positive_mask_with_remove_diag = np.array(\n            [\n                [0, 0, 0, 1, 0],\n                [0, 0, 0, 0, 0],\n                [0, 0, 0, 0, 1],\n                [0, 0, 1, 0, 1],\n                [1, 0, 0, 1, 0],\n            ],\n            dtype=\"bool\",\n        )\n        self.assertTrue(\n            np.all(positive_mask == expected_positive_mask_with_remove_diag)\n        )\n\n        query_labels = np.array([1, 2, 3])\n        key_labels = np.array([1, 2, 3, 1])\n\n        positive_mask, negative_mask = numerical_utils.build_pos_neg_masks(\n            query_labels, key_labels, remove_diagonal=True\n        )\n        positive_mask = backend.convert_to_numpy(positive_mask)\n        negative_mask = backend.convert_to_numpy(negative_mask)\n        expected_shape_diff_sizes = (len(query_labels), len(key_labels))\n        self.assertEqual(positive_mask.shape, expected_shape_diff_sizes)\n        self.assertEqual(negative_mask.shape, expected_shape_diff_sizes)\n        self.assertTrue(\n            np.all(np.logical_not(np.logical_and(positive_mask, negative_mask)))\n        )\n"
  },
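The mask construction that these tests exercise can be sketched in plain NumPy (the source uses `keras.src.ops`); the expected matrices come straight from `numerical_utils_test.py` above:

```python
import numpy as np


def build_pos_neg_masks(query_labels, key_labels, remove_diagonal=True):
    # Positive where query and key labels match; negative is the complement.
    q = np.asarray(query_labels).reshape(-1, 1)
    k = np.asarray(key_labels).reshape(-1, 1)
    positive = q == k.T
    negative = ~positive
    if remove_diagonal:
        # Drop self-matches so an anchor never counts as its own positive.
        positive = positive & ~np.eye(q.size, k.size, dtype=bool)
    return positive, negative


pos, neg = build_pos_neg_masks([0, 1, 2, 2, 0], [0, 1, 2, 0, 2],
                               remove_diagonal=False)
print(pos.astype(int))
```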
  {
    "path": "keras/src/utils/progbar.py",
    "content": "import math\nimport os\nimport sys\nimport time\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils import io_utils\n\n\n@keras_export(\"keras.utils.Progbar\")\nclass Progbar:\n    \"\"\"Displays a progress bar.\n\n    Args:\n        target: Total number of steps expected, None if unknown.\n        width: Progress bar width on screen.\n        verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)\n        stateful_metrics: Iterable of string names of metrics that should *not*\n            be averaged over time. Metrics in this list will be displayed as-is.\n            All others will be averaged by the progbar before display.\n        interval: Minimum visual progress update interval (in seconds).\n        unit_name: Display name for step counts (usually \"step\" or \"sample\").\n    \"\"\"\n\n    def __init__(\n        self,\n        target,\n        width=20,\n        verbose=1,\n        interval=0.05,\n        stateful_metrics=None,\n        unit_name=\"step\",\n    ):\n        self.target = target\n        self.width = width\n        self.verbose = verbose\n        self.interval = interval\n        self.unit_name = unit_name\n        if stateful_metrics:\n            self.stateful_metrics = set(stateful_metrics)\n        else:\n            self.stateful_metrics = set()\n\n        self._dynamic_display = (\n            (hasattr(sys.stdout, \"isatty\") and sys.stdout.isatty())\n            or \"ipykernel\" in sys.modules\n            or \"posix\" in sys.modules\n            or \"PYCHARM_HOSTED\" in os.environ\n        )\n        self._seen_so_far = 0\n        # We use a dict + list to avoid garbage collection\n        # issues found in OrderedDict\n        self._values = {}\n        self._values_order = []\n        self._start = time.time()\n        self._last_update = 0\n        self._time_at_epoch_start = self._start\n        self._time_after_first_step = None\n        self._prev_total_width 
= 0\n\n    def update(self, current, values=None, finalize=None):\n        \"\"\"Updates the progress bar.\n\n        Args:\n            current: Index of current step.\n            values: List of tuples: `(name, value_for_last_step)`. If `name` is\n                in `stateful_metrics`, `value_for_last_step` will be displayed\n                as-is. Else, an average of the metric over time will be\n                displayed.\n            finalize: Whether this is the last update for the progress bar. If\n                `None`, defaults to `current >= self.target`.\n        \"\"\"\n        if finalize is None:\n            if self.target is None:\n                finalize = False\n            else:\n                finalize = current >= self.target\n\n        values = values or []\n        for k, v in values:\n            if k not in self._values_order:\n                self._values_order.append(k)\n            if k not in self.stateful_metrics:\n                # In the case that progress bar doesn't have a target value in\n                # the first epoch, both on_batch_end and on_epoch_end will be\n                # called, which will cause 'current' and 'self._seen_so_far' to\n                # have the same value. Force the minimal value to 1 here,\n                # otherwise stateful_metric will be 0s.\n                if finalize:\n                    self._values[k] = [v, 1]\n                else:\n                    value_base = max(current - self._seen_so_far, 1)\n                    if k not in self._values:\n                        self._values[k] = [v * value_base, value_base]\n                    else:\n                        self._values[k][0] += v * value_base\n                        self._values[k][1] += value_base\n            else:\n                # Stateful metrics output a numeric value. 
This representation\n                # means \"take an average from a single value\" but keeps the\n                # numeric formatting.\n                self._values[k] = [v, 1]\n        self._seen_so_far = current\n\n        message = \"\"\n        special_char_len = 0\n        now = time.time()\n        time_per_unit = self._estimate_step_duration(current, now)\n\n        if self.verbose == 1:\n            if now - self._last_update < self.interval and not finalize:\n                return\n\n            if self._dynamic_display:\n                message += \"\\b\" * self._prev_total_width\n                message += \"\\r\"\n            else:\n                message += \"\\n\"\n\n            if self.target is not None and self.target > 0:\n                numdigits = int(math.log10(self.target)) + 1\n                bar = (f\"%{numdigits}d/%d\") % (current, self.target)\n                bar = f\"\\x1b[1m{bar}\\x1b[0m \"\n                special_char_len += 8\n                prog = float(current) / self.target\n                prog_width = int(self.width * prog)\n\n                if prog_width > 0:\n                    bar += f\"\\33[32m{'━' * prog_width}\\x1b[0m\"\n                    special_char_len += 9\n                bar += f\"\\33[37m{'━' * (self.width - prog_width)}\\x1b[0m\"\n                special_char_len += 9\n\n            else:\n                bar = \"%7d/Unknown\" % current\n            message += bar\n\n            # Add ETA if applicable\n            if self.target is not None and self.target > 0 and not finalize:\n                eta = time_per_unit * (self.target - current)\n                if eta > 3600:\n                    eta_format = \"%d:%02d:%02d\" % (\n                        eta // 3600,\n                        (eta % 3600) // 60,\n                        eta % 60,\n                    )\n                elif eta > 60:\n                    eta_format = \"%d:%02d\" % (eta // 60, eta % 60)\n                else:\n                
    eta_format = \"%ds\" % eta\n                info = f\" \\x1b[1m{eta_format}\\x1b[0m\"\n            else:\n                # Time elapsed since start, in seconds\n                info = f\" \\x1b[1m{now - self._start:.0f}s\\x1b[0m\"\n            special_char_len += 8\n\n            # Add time/step\n            info += self._format_time(time_per_unit, self.unit_name)\n\n            # Add metrics\n            for k in self._values_order:\n                info += f\" - {k}:\"\n                if isinstance(self._values[k], list):\n                    values, count = self._values[k]\n                    if not isinstance(values, float):\n                        values = np.mean(values)\n                    avg = values / max(1, count)\n                    if abs(avg) > 1e-3:\n                        info += f\" {avg:.4f}\"\n                    else:\n                        info += f\" {avg:.4e}\"\n                else:\n                    info += f\" {self._values[k]}\"\n            message += info\n\n            total_width = len(bar) + len(info) - special_char_len\n            if self._prev_total_width > total_width:\n                message += \" \" * (self._prev_total_width - total_width)\n            if finalize:\n                message += \"\\n\"\n\n            io_utils.print_msg(message, line_break=False)\n            self._prev_total_width = total_width\n            message = \"\"\n\n        elif self.verbose == 2:\n            if finalize and self.target is not None and self.target > 0:\n                numdigits = int(math.log10(self.target)) + 1\n                count = f\"%{numdigits}d/%d\" % (current, self.target)\n                info = f\"{count} - {now - self._start:.0f}s\"\n                info += f\" -{self._format_time(time_per_unit, self.unit_name)}\"\n                for k in self._values_order:\n                    info += f\" - {k}:\"\n                    values, count = self._values[k]\n                    if not isinstance(values, 
float):\n                        values = np.mean(values)\n                    avg = values / max(1, count)\n                    if avg > 1e-3:\n                        info += f\" {avg:.4f}\"\n                    else:\n                        info += f\" {avg:.4e}\"\n                info += \"\\n\"\n                message += info\n                io_utils.print_msg(message, line_break=False)\n                message = \"\"\n\n        self._last_update = now\n\n    def add(self, n, values=None):\n        self.update(self._seen_so_far + n, values)\n\n    def _format_time(self, time_per_unit, unit_name):\n        \"\"\"format a given duration to display to the user.\n\n        Given the duration, this function formats it in either milliseconds\n        or seconds and displays the unit (i.e. ms/step or s/epoch).\n\n        Args:\n            time_per_unit: the duration to display\n            unit_name: the name of the unit to display\n\n        Returns:\n            A string with the correctly formatted duration and units\n        \"\"\"\n        formatted = \"\"\n        if time_per_unit >= 1 or time_per_unit == 0:\n            formatted += f\" {time_per_unit:.0f}s/{unit_name}\"\n        elif time_per_unit >= 1e-3:\n            formatted += f\" {time_per_unit * 1000.0:.0f}ms/{unit_name}\"\n        else:\n            formatted += f\" {time_per_unit * 1000000.0:.0f}us/{unit_name}\"\n        return formatted\n\n    def _estimate_step_duration(self, current, now):\n        \"\"\"Estimate the duration of a single step.\n\n        Given the step number `current` and the corresponding time `now` this\n        function returns an estimate for how long a single step takes. If this\n        is called before one step has been completed (i.e. `current == 0`) then\n        zero is given as an estimate. The duration estimate ignores the duration\n        of the (assumed to be non-representative) first step for estimates when\n        more steps are available (i.e. 
`current>1`).\n\n        Args:\n            current: Index of current step.\n            now: The current time.\n\n        Returns: Estimate of the duration of a single step.\n        \"\"\"\n        if current:\n            # there are a few special scenarios here:\n            # 1) somebody is calling the progress bar without ever supplying\n            #    step 1\n            # 2) somebody is calling the progress bar and supplies step one\n            #    multiple times, e.g. as part of a finalizing call\n            # in these cases, we just fall back to the simple calculation\n            if self._time_after_first_step is not None and current > 1:\n                time_per_unit = (now - self._time_after_first_step) / (\n                    current - 1\n                )\n            else:\n                time_per_unit = (now - self._start) / current\n\n            if current == 1:\n                self._time_after_first_step = now\n            return time_per_unit\n        else:\n            return 0\n"
  },
  {
    "path": "keras/src/utils/progbar_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import testing\nfrom keras.src.utils import progbar\n\n\nclass ProgbarTest(testing.TestCase):\n    @parameterized.named_parameters(\n        [\n            (\"float\", \"float\"),\n            (\"np\", \"np\"),\n            (\"list\", \"list\"),\n        ]\n    )\n    def test_update(self, value_type):\n        if value_type == \"float\":\n            values = 1.0\n        elif value_type == \"np\":\n            values = np.array(1.0)\n        elif value_type == \"list\":\n            values = [0.0, 1.0, 2.0]\n        else:\n            raise ValueError(\"Unknown value_type\")\n        pb = progbar.Progbar(target=1, verbose=1)\n\n        pb.update(1, values=[(\"values\", values)], finalize=True)\n\n    @parameterized.named_parameters(\n        [\n            (\"verbose_1\", 1),\n            (\"verbose_2\", 2),\n        ]\n    )\n    def test_zero_target(self, verbose):\n        pb = progbar.Progbar(target=0, verbose=verbose)\n        pb.update(0, finalize=True)\n"
  },
  {
    "path": "keras/src/utils/python_utils.py",
    "content": "import binascii\nimport codecs\nimport marshal\nimport os\nimport types as python_types\n\n\ndef is_continuous_axis(axis):\n    # Used to determine whether the dimensions in an axis are continuous\n    if isinstance(axis, int) or len(axis) <= 1:\n        return True\n\n    step = axis[1] - axis[0]\n    if step not in (1, -1):\n        return False\n\n    return all(axis[i + 1] - axis[i] == step for i in range(len(axis) - 1))\n\n\ndef default(method):\n    \"\"\"Decorates a method to detect overrides in subclasses.\"\"\"\n    method._is_default = True\n    return method\n\n\ndef is_default(method):\n    \"\"\"Check if a method is decorated with the `default` wrapper.\"\"\"\n    return getattr(method, \"_is_default\", False)\n\n\ndef func_dump(func):\n    \"\"\"Serializes a user-defined function.\n\n    Args:\n        func: the function to serialize.\n\n    Returns:\n        A tuple `(code, defaults, closure)`.\n    \"\"\"\n    if os.name == \"nt\":\n        raw_code = marshal.dumps(func.__code__).replace(b\"\\\\\", b\"/\")\n        code = codecs.encode(raw_code, \"base64\").decode(\"ascii\")\n    else:\n        raw_code = marshal.dumps(func.__code__)\n        code = codecs.encode(raw_code, \"base64\").decode(\"ascii\")\n    defaults = func.__defaults__\n    if func.__closure__:\n        closure = tuple(c.cell_contents for c in func.__closure__)\n    else:\n        closure = None\n    return code, defaults, closure\n\n\ndef func_load(code, defaults=None, closure=None, globs=None):\n    \"\"\"Deserializes a user defined function.\n\n    Args:\n        code: bytecode of the function.\n        defaults: defaults of the function.\n        closure: closure of the function.\n        globs: dictionary of global objects.\n\n    Returns:\n        A function object.\n    \"\"\"\n    if isinstance(code, (tuple, list)):  # unpack previous dump\n        code, defaults, closure = code\n        if isinstance(defaults, list):\n            defaults = 
tuple(defaults)\n\n    def ensure_value_to_cell(value):\n        \"\"\"Ensures that a value is converted to a python cell object.\n\n        Args:\n            value: Any value that needs to be casted to the cell type\n\n        Returns:\n            A value wrapped as a cell object (see function \"func_load\")\n        \"\"\"\n\n        def dummy_fn():\n            value  # just access it so it gets captured in .__closure__\n\n        cell_value = dummy_fn.__closure__[0]\n        if not isinstance(value, type(cell_value)):\n            return cell_value\n        return value\n\n    if closure is not None:\n        closure = tuple(ensure_value_to_cell(_) for _ in closure)\n    try:\n        raw_code = codecs.decode(code.encode(\"ascii\"), \"base64\")\n    except (UnicodeEncodeError, binascii.Error):\n        raw_code = code.encode(\"raw_unicode_escape\")\n    code = marshal.loads(raw_code)\n    if globs is None:\n        globs = globals()\n    return python_types.FunctionType(\n        code, globs, name=code.co_name, argdefs=defaults, closure=closure\n    )\n\n\ndef to_list(x):\n    \"\"\"Normalizes a list/tensor into a list.\n\n    If a tensor is passed, we return\n    a list of size 1 containing the tensor.\n\n    Args:\n        x: target object to be normalized.\n\n    Returns:\n        A list.\n    \"\"\"\n    if isinstance(x, list):\n        return x\n    return [x]\n\n\ndef remove_long_seq(maxlen, seq, label):\n    \"\"\"Removes sequences that exceed the maximum length.\n\n    Args:\n        maxlen: Int, maximum length of the output sequences.\n        seq: List of lists, where each sublist is a sequence.\n        label: List where each element is an integer.\n\n    Returns:\n        new_seq, new_label: shortened lists for `seq` and `label`.\n    \"\"\"\n    new_seq, new_label = [], []\n    for x, y in zip(seq, label):\n        if len(x) < maxlen:\n            new_seq.append(x)\n            new_label.append(y)\n    return new_seq, new_label\n\n\ndef 
removeprefix(x, prefix):\n    \"\"\"Backport of `removeprefix` from PEP-616 (Python 3.9+)\"\"\"\n\n    if len(prefix) > 0 and x.startswith(prefix):\n        return x[len(prefix) :]\n    else:\n        return x\n\n\ndef removesuffix(x, suffix):\n    \"\"\"Backport of `removesuffix` from PEP-616 (Python 3.9+)\"\"\"\n\n    if len(suffix) > 0 and x.endswith(suffix):\n        return x[: -len(suffix)]\n    else:\n        return x\n\n\ndef remove_by_id(lst, value):\n    \"\"\"Remove a value from a list by id.\"\"\"\n    for i, v in enumerate(lst):\n        if id(v) == id(value):\n            del lst[i]\n            return\n\n\ndef pythonify_logs(logs):\n    \"\"\"Flatten and convert log values to Python-native types.\n\n    This function attempts to convert dict value by `float(value)` and skips\n    the conversion if it fails.\n\n    Args:\n        logs: A dict containing log values.\n\n    Returns:\n        A flattened dict with values converted to Python-native types if\n        possible.\n    \"\"\"\n    from keras.src import backend\n\n    logs = logs or {}\n    result = {}\n    for key, value in sorted(logs.items()):\n        if isinstance(value, dict):\n            result.update(pythonify_logs(value))\n        else:\n            try:\n                # Prevent torch compiler from breaking the graph.\n                if backend.is_tensor(value):\n                    value = backend.convert_to_numpy(value)\n                value = float(value)\n            except:\n                pass\n            result[key] = value\n    return result\n"
  },
  {
    "path": "keras/src/utils/python_utils_test.py",
    "content": "import base64\nimport marshal\n\nfrom keras.src import testing\nfrom keras.src.utils import python_utils\n\n\nclass PythonUtilsTest(testing.TestCase):\n    def test_func_dump_and_load(self):\n        def my_function(x, y=1, **kwargs):\n            return x + y\n\n        serialized = python_utils.func_dump(my_function)\n        deserialized = python_utils.func_load(serialized)\n        self.assertEqual(deserialized(2, y=3), 5)\n\n    def test_removesuffix(self):\n        x = \"model.keras\"\n        self.assertEqual(python_utils.removesuffix(x, \".keras\"), \"model\")\n        self.assertEqual(python_utils.removesuffix(x, \"model\"), x)\n\n    def test_removeprefix(self):\n        x = \"model.keras\"\n        self.assertEqual(python_utils.removeprefix(x, \"model\"), \".keras\")\n        self.assertEqual(python_utils.removeprefix(x, \".keras\"), x)\n\n    def test_func_load_defaults_as_tuple(self):\n        # Using tuple as a default argument\n        def dummy_function(x=(1, 2, 3)):\n            pass\n\n        serialized = python_utils.func_dump(dummy_function)\n        deserialized = python_utils.func_load(serialized)\n        # Ensure that the defaults are still a tuple\n        self.assertIsInstance(deserialized.__defaults__[0], tuple)\n        # Ensure that the tuple default remains unchanged\n        self.assertEqual(deserialized.__defaults__[0], (1, 2, 3))\n\n    def test_remove_long_seq_standard_case(self):\n        sequences = [[1], [2, 2], [3, 3, 3], [4, 4, 4, 4]]\n        labels = [1, 2, 3, 4]\n        new_sequences, new_labels = python_utils.remove_long_seq(\n            3, sequences, labels\n        )\n        self.assertEqual(new_sequences, [[1], [2, 2]])\n        self.assertEqual(new_labels, [1, 2])\n\n    def test_func_load_with_closure(self):\n        def outer_fn(x):\n            def inner_fn(y):\n                return x + y\n\n            return inner_fn\n\n        func_with_closure = outer_fn(10)\n        serialized = 
python_utils.func_dump(func_with_closure)\n        deserialized = python_utils.func_load(serialized)\n        self.assertEqual(deserialized(5), 15)\n\n    def test_func_load_closure_conversion(self):\n        def my_function_with_closure(x):\n            return x + y\n\n        y = 5\n        serialized = python_utils.func_dump(my_function_with_closure)\n        deserialized = python_utils.func_load(serialized)\n        self.assertEqual(deserialized(5), 10)\n\n    def test_ensure_value_to_cell(self):\n        value_to_test = \"test_value\"\n\n        def dummy_fn():\n            value_to_test\n\n        cell_value = dummy_fn.__closure__[0].cell_contents\n        self.assertEqual(value_to_test, cell_value)\n\n    def test_closure_processing(self):\n        def simple_function(x):\n            return x + 10\n\n        serialized = python_utils.func_dump(simple_function)\n        deserialized = python_utils.func_load(serialized)\n        self.assertEqual(deserialized(5), 15)\n\n    def test_func_load_valid_encoded_code(self):\n        def another_simple_function(x):\n            return x * 2\n\n        raw_data = marshal.dumps(another_simple_function.__code__)\n        valid_encoded_code = base64.b64encode(raw_data).decode(\"utf-8\")\n\n        try:\n            python_utils.func_load(valid_encoded_code)\n        except (UnicodeEncodeError, ValueError):\n            self.fail(\"Expected no error for valid code, but got an error.\")\n\n    def test_func_load_bad_encoded_code(self):\n        bad_encoded_code = \"This isn't valid base64!\"\n        with self.assertRaises(AttributeError):\n            python_utils.func_load(bad_encoded_code)\n\n    def test_is_continuous_axis(self):\n        # Single axis\n        self.assertTrue(python_utils.is_continuous_axis(1))\n        self.assertTrue(python_utils.is_continuous_axis([1]))\n        # Forward-ordered continuous\n        self.assertTrue(python_utils.is_continuous_axis([1, 2, 3]))\n        
self.assertTrue(python_utils.is_continuous_axis([-2, -1]))\n        # Reverse-ordered continuous\n        self.assertTrue(python_utils.is_continuous_axis([3, 2, 1]))\n        self.assertTrue(python_utils.is_continuous_axis([-1, -2]))\n        # Non-continuous\n        self.assertFalse(python_utils.is_continuous_axis([1, 3]))\n        self.assertFalse(python_utils.is_continuous_axis([-1, -3]))\n"
  },
  {
    "path": "keras/src/utils/rng_utils.py",
    "content": "import random\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\nfrom keras.src.random import seed_generator\nfrom keras.src.utils.module_utils import tensorflow as tf\n\nGLOBAL_RANDOM_SEED = \"global_random_seed\"\n\n\n@keras_export(\"keras.utils.set_random_seed\")\ndef set_random_seed(seed):\n    \"\"\"Sets all random seeds (Python, NumPy, and backend framework, e.g. TF).\n\n    You can use this utility to make almost any Keras program fully\n    deterministic. Some limitations apply in cases where network communications\n    are involved (e.g. parameter server distribution), which creates additional\n    sources of randomness, or when certain non-deterministic cuDNN ops are\n    involved.\n\n    Calling this utility does the following:\n\n    ```python\n    import random\n    random.seed(seed)\n\n    import numpy as np\n    np.random.seed(seed)\n\n    import tensorflow as tf  # Only if TF is installed\n    tf.random.set_seed(seed)\n\n    import torch  # Only if the backend is 'torch'\n    torch.manual_seed(seed)\n    ```\n\n    Additionally, it resets the global Keras `SeedGenerator`, which is used by\n    `keras.random` functions when the `seed` is not provided.\n\n    Note that the TensorFlow seed is set even if you're not using TensorFlow\n    as your backend framework, since many workflows leverage `tf.data`\n    pipelines (which feature random shuffling). Likewise many workflows\n    might leverage NumPy APIs.\n\n    Arguments:\n        seed: Integer, the random seed to use.\n    \"\"\"\n    if not isinstance(seed, int):\n        raise ValueError(\n            \"Expected `seed` argument to be an integer. 
\"\n            f\"Received: seed={seed} (of type {type(seed)})\"\n        )\n\n    # Store seed in global state so we can query it if set.\n    global_state.set_global_attribute(GLOBAL_RANDOM_SEED, seed)\n    # Remove global SeedGenerator, it will be recreated from the seed.\n    global_state.set_global_attribute(\n        seed_generator.GLOBAL_SEED_GENERATOR, None\n    )\n    random.seed(seed)\n    np.random.seed(seed)\n    if tf.available:\n        tf.random.set_seed(seed)\n    if backend.backend() == \"torch\":\n        import torch\n\n        torch.manual_seed(seed)\n\n\ndef get_random_seed():\n    \"\"\"Returns the explicit integer random seed if set.\n\n    If the seed has been explicitly set via `set_random_seed`, then\n    returns the seed.  Otherwise, returns `None`.\n    \"\"\"\n    return global_state.get_global_attribute(GLOBAL_RANDOM_SEED)\n"
  },
  {
    "path": "keras/src/utils/rng_utils_test.py",
    "content": "import numpy as np\n\nimport keras\nfrom keras.src import backend\nfrom keras.src.testing import test_case\nfrom keras.src.utils import rng_utils\n\n\nclass TestRandomSeedSetting(test_case.TestCase):\n    def test_set_random_seed_with_seed_generator(self):\n        def get_model_output():\n            model = keras.Sequential(\n                [\n                    keras.layers.Dense(10),\n                    keras.layers.Dropout(0.5),\n                    keras.layers.Dense(10),\n                ]\n            )\n            x = np.random.random((32, 10)).astype(\"float32\")\n            return model.predict(x, batch_size=16)\n\n        rng_utils.set_random_seed(42)\n        y1 = get_model_output()\n\n        # Second call should produce different results.\n        y2 = get_model_output()\n        self.assertNotAllClose(y1, y2)\n\n        # Re-seeding should produce the same results as the first time.\n        rng_utils.set_random_seed(42)\n        y3 = get_model_output()\n        self.assertAllClose(y1, y3)\n\n        # Re-seeding with a different seed should produce different results.\n        rng_utils.set_random_seed(1337)\n        y4 = get_model_output()\n        self.assertNotAllClose(y1, y4)\n\n    def test_set_random_seed_with_global_seed_generator(self):\n        rng_utils.set_random_seed(42)\n        y1 = backend.random.randint((32, 10), minval=0, maxval=1000)\n\n        # Second call should produce different results.\n        y2 = backend.random.randint((32, 10), minval=0, maxval=1000)\n        self.assertNotAllClose(y1, y2)\n\n        # Re-seeding should produce the same results as the first time.\n        rng_utils.set_random_seed(42)\n        y3 = backend.random.randint((32, 10), minval=0, maxval=1000)\n        self.assertAllClose(y1, y3)\n\n        # Re-seeding with a different seed should produce different results.\n        rng_utils.set_random_seed(1337)\n        y4 = backend.random.randint((32, 10), minval=0, maxval=1000)\n       
 self.assertNotAllClose(y1, y4)\n"
  },
  {
    "path": "keras/src/utils/sequence_utils.py",
    "content": "import numpy as np\n\nfrom keras.src.api_export import keras_export\n\n\n@keras_export(\n    [\n        \"keras.utils.pad_sequences\",\n        \"keras.preprocessing.sequence.pad_sequences\",\n    ]\n)\ndef pad_sequences(\n    sequences,\n    maxlen=None,\n    dtype=\"int32\",\n    padding=\"pre\",\n    truncating=\"pre\",\n    value=0.0,\n):\n    \"\"\"Pads sequences to the same length.\n\n    This function transforms a list (of length `num_samples`)\n    of sequences (lists of integers)\n    into a 2D NumPy array of shape `(num_samples, num_timesteps)`.\n    `num_timesteps` is either the `maxlen` argument if provided,\n    or the length of the longest sequence in the list.\n\n    Sequences that are shorter than `num_timesteps`\n    are padded with `value` until they are `num_timesteps` long.\n\n    Sequences longer than `num_timesteps` are truncated\n    so that they fit the desired length.\n\n    The position where padding or truncation happens is determined by\n    the arguments `padding` and `truncating`, respectively.\n    Pre-padding or removing values from the beginning of the sequence is the\n    default.\n\n    >>> sequence = [[1], [2, 3], [4, 5, 6]]\n    >>> keras.utils.pad_sequences(sequence)\n    array([[0, 0, 1],\n           [0, 2, 3],\n           [4, 5, 6]], dtype=int32)\n\n    >>> keras.utils.pad_sequences(sequence, value=-1)\n    array([[-1, -1,  1],\n           [-1,  2,  3],\n           [ 4,  5,  6]], dtype=int32)\n\n    >>> keras.utils.pad_sequences(sequence, padding='post')\n    array([[1, 0, 0],\n           [2, 3, 0],\n           [4, 5, 6]], dtype=int32)\n\n    >>> keras.utils.pad_sequences(sequence, maxlen=2)\n    array([[0, 1],\n           [2, 3],\n           [5, 6]], dtype=int32)\n\n    Args:\n        sequences: List of sequences (each sequence is a list of integers).\n        maxlen: Optional Int, maximum length of all sequences. 
If not provided,\n            sequences will be padded to the length of the longest individual\n            sequence.\n        dtype: (Optional, defaults to `\"int32\"`). Type of the output sequences.\n            To pad sequences with variable length strings, you can use `object`.\n        padding: String, \"pre\" or \"post\" (optional, defaults to `\"pre\"`):\n            pad either before or after each sequence.\n        truncating: String, \"pre\" or \"post\" (optional, defaults to `\"pre\"`):\n            remove values from sequences larger than\n            `maxlen`, either at the beginning or at the end of the sequences.\n        value: Float or String, padding value. (Optional, defaults to `0.`)\n\n    Returns:\n        NumPy array with shape `(len(sequences), maxlen)`\n    \"\"\"\n    if not hasattr(sequences, \"__len__\"):\n        raise ValueError(\"`sequences` must be iterable.\")\n    num_samples = len(sequences)\n\n    lengths = []\n    sample_shape = ()\n    flag = True\n\n    # take the sample shape from the first non empty sequence\n    # checking for consistency in the main loop below.\n\n    for x in sequences:\n        try:\n            lengths.append(len(x))\n            if flag and len(x):\n                sample_shape = np.asarray(x).shape[1:]\n                flag = False\n        except TypeError as e:\n            raise ValueError(\n                \"`sequences` must be a list of iterables. 
\"\n                f\"Found non-iterable: {str(x)}\"\n            ) from e\n\n    if maxlen is None:\n        maxlen = np.max(lengths)\n\n    is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype(\n        dtype, np.str_\n    )\n    if isinstance(value, str) and dtype is not object and not is_dtype_str:\n        raise ValueError(\n            f\"`dtype` {dtype} is not compatible with `value`'s type: \"\n            f\"{type(value)}\\nYou should set `dtype=object` for variable length \"\n            \"strings.\"\n        )\n\n    x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype)\n    for idx, s in enumerate(sequences):\n        if not len(s):\n            continue  # empty list/array was found\n        if truncating == \"pre\":\n            trunc = s[-maxlen:]\n        elif truncating == \"post\":\n            trunc = s[:maxlen]\n        else:\n            raise ValueError(f'Truncating type \"{truncating}\" not understood')\n\n        # check `trunc` has expected shape\n        trunc = np.asarray(trunc, dtype=dtype)\n        if trunc.shape[1:] != sample_shape:\n            raise ValueError(\n                f\"Shape of sample {trunc.shape[1:]} of sequence at \"\n                f\"position {idx} is different from expected shape \"\n                f\"{sample_shape}\"\n            )\n\n        if padding == \"post\":\n            x[idx, : len(trunc)] = trunc\n        elif padding == \"pre\":\n            x[idx, -len(trunc) :] = trunc\n        else:\n            raise ValueError(f'Padding type \"{padding}\" not understood')\n    return x\n"
  },
  {
    "path": "keras/src/utils/sequence_utils_test.py",
    "content": "from keras.src import testing\nfrom keras.src.utils import sequence_utils\n\n\nclass PadSequencesTest(testing.TestCase):\n    def test_pad_sequences(self):\n        a = [[1], [1, 2], [1, 2, 3]]\n\n        # test padding\n        b = sequence_utils.pad_sequences(a, maxlen=3, padding=\"pre\")\n        self.assertAllClose(b, [[0, 0, 1], [0, 1, 2], [1, 2, 3]])\n        b = sequence_utils.pad_sequences(a, maxlen=3, padding=\"post\")\n        self.assertAllClose(b, [[1, 0, 0], [1, 2, 0], [1, 2, 3]])\n\n        # test truncating\n        b = sequence_utils.pad_sequences(a, maxlen=2, truncating=\"pre\")\n        self.assertAllClose(b, [[0, 1], [1, 2], [2, 3]])\n        b = sequence_utils.pad_sequences(a, maxlen=2, truncating=\"post\")\n        self.assertAllClose(b, [[0, 1], [1, 2], [1, 2]])\n\n        # test value\n        b = sequence_utils.pad_sequences(a, maxlen=3, value=1)\n        self.assertAllClose(b, [[1, 1, 1], [1, 1, 2], [1, 2, 3]])\n\n    def test_pad_sequences_float(self):\n        a = [[1.2], [1.2, 2.3], [1.2, 2.3, 3.4]]\n\n        # test padding\n        b = sequence_utils.pad_sequences(\n            a, maxlen=3, padding=\"pre\", dtype=\"float32\"\n        )\n        self.assertAllClose(b, [[0, 0, 1.2], [0, 1.2, 2.3], [1.2, 2.3, 3.4]])\n        b = sequence_utils.pad_sequences(\n            a, maxlen=3, padding=\"post\", dtype=\"float32\"\n        )\n        self.assertAllClose(b, [[1.2, 0, 0], [1.2, 2.3, 0], [1.2, 2.3, 3.4]])\n\n        # test truncating\n        b = sequence_utils.pad_sequences(\n            a, maxlen=2, truncating=\"pre\", dtype=\"float32\"\n        )\n        self.assertAllClose(b, [[0, 1.2], [1.2, 2.3], [2.3, 3.4]])\n        b = sequence_utils.pad_sequences(\n            a, maxlen=2, truncating=\"post\", dtype=\"float32\"\n        )\n        self.assertAllClose(b, [[0, 1.2], [1.2, 2.3], [1.2, 2.3]])\n\n        # test value\n        b = sequence_utils.pad_sequences(a, maxlen=3, value=1, dtype=\"float32\")\n        
self.assertAllClose(b, [[1, 1, 1.2], [1, 1.2, 2.3], [1.2, 2.3, 3.4]])\n\n    def test_pad_sequences_str(self):\n        a = [[\"1\"], [\"1\", \"2\"], [\"1\", \"2\", \"3\"]]\n\n        # test padding\n        b = sequence_utils.pad_sequences(\n            a, maxlen=3, padding=\"pre\", value=\"pad\", dtype=object\n        )\n        self.assertAllEqual(\n            b, [[\"pad\", \"pad\", \"1\"], [\"pad\", \"1\", \"2\"], [\"1\", \"2\", \"3\"]]\n        )\n        b = sequence_utils.pad_sequences(\n            a, maxlen=3, padding=\"post\", value=\"pad\", dtype=\"<U3\"\n        )\n        self.assertAllEqual(\n            b, [[\"1\", \"pad\", \"pad\"], [\"1\", \"2\", \"pad\"], [\"1\", \"2\", \"3\"]]\n        )\n\n        # test truncating\n        b = sequence_utils.pad_sequences(\n            a, maxlen=2, truncating=\"pre\", value=\"pad\", dtype=object\n        )\n        self.assertAllEqual(b, [[\"pad\", \"1\"], [\"1\", \"2\"], [\"2\", \"3\"]])\n        b = sequence_utils.pad_sequences(\n            a, maxlen=2, truncating=\"post\", value=\"pad\", dtype=\"<U3\"\n        )\n        self.assertAllEqual(b, [[\"pad\", \"1\"], [\"1\", \"2\"], [\"1\", \"2\"]])\n\n        with self.assertRaisesRegex(\n            ValueError, \"`dtype` int32 is not compatible with \"\n        ):\n            sequence_utils.pad_sequences(\n                a, maxlen=2, truncating=\"post\", value=\"pad\"\n            )\n\n    def test_pad_sequences_vector(self):\n        a = [[[1, 1]], [[2, 1], [2, 2]], [[3, 1], [3, 2], [3, 3]]]\n\n        # test padding\n        b = sequence_utils.pad_sequences(a, maxlen=3, padding=\"pre\")\n        self.assertAllClose(\n            b,\n            [\n                [[0, 0], [0, 0], [1, 1]],\n                [[0, 0], [2, 1], [2, 2]],\n                [[3, 1], [3, 2], [3, 3]],\n            ],\n        )\n        b = sequence_utils.pad_sequences(a, maxlen=3, padding=\"post\")\n        self.assertAllClose(\n            b,\n            [\n                [[1, 
1], [0, 0], [0, 0]],\n                [[2, 1], [2, 2], [0, 0]],\n                [[3, 1], [3, 2], [3, 3]],\n            ],\n        )\n\n        # test truncating\n        b = sequence_utils.pad_sequences(a, maxlen=2, truncating=\"pre\")\n        self.assertAllClose(\n            b, [[[0, 0], [1, 1]], [[2, 1], [2, 2]], [[3, 2], [3, 3]]]\n        )\n\n        b = sequence_utils.pad_sequences(a, maxlen=2, truncating=\"post\")\n        self.assertAllClose(\n            b, [[[0, 0], [1, 1]], [[2, 1], [2, 2]], [[3, 1], [3, 2]]]\n        )\n\n        # test value\n        b = sequence_utils.pad_sequences(a, maxlen=3, value=1)\n        self.assertAllClose(\n            b,\n            [\n                [[1, 1], [1, 1], [1, 1]],\n                [[1, 1], [2, 1], [2, 2]],\n                [[3, 1], [3, 2], [3, 3]],\n            ],\n        )\n"
  },
  {
    "path": "keras/src/utils/summary_utils.py",
    "content": "import functools\nimport math\nimport re\nimport shutil\n\nimport rich\nimport rich.console\nimport rich.markup\n\n# See https://github.com/keras-team/keras/issues/448\n# for below imports\nimport rich.table\n\nfrom keras.src import backend\nfrom keras.src import tree\nfrom keras.src.utils import dtype_utils\nfrom keras.src.utils import io_utils\n\n\ndef count_params(weights):\n    shapes = [v.shape for v in weights]\n    return int(sum(math.prod(p) for p in shapes))\n\n\n@functools.lru_cache(512)\ndef _compute_memory_size(shape, dtype):\n    weight_counts = math.prod(shape)\n    dtype = backend.standardize_dtype(dtype)\n    per_param_size = dtype_utils.dtype_size(dtype)\n    return weight_counts * per_param_size\n\n\ndef weight_memory_size(weights):\n    \"\"\"Compute the memory footprint for weights based on their dtypes.\n\n    Args:\n        weights: An iterable containing the weights whose memory size to\n            compute.\n\n    Returns:\n        The total memory size (in Bytes) of the weights.\n    \"\"\"\n    unique_weights = {id(w): w for w in weights}.values()\n    total_memory_size = 0\n    for w in unique_weights:\n        total_memory_size += _compute_memory_size(w.shape, w.dtype)\n    return total_memory_size / 8\n\n\ndef readable_memory_size(weight_memory_size):\n    \"\"\"Convert the weight memory size (Bytes) to a readable string.\"\"\"\n    units = [\"B\", \"KB\", \"MB\", \"GB\", \"TB\", \"PB\"]\n    scale = 1024\n    for unit in units:\n        if weight_memory_size / scale < 1:\n            return \"{:.2f} {}\".format(weight_memory_size, unit)\n        else:\n            weight_memory_size /= scale\n    return \"{:.2f} {}\".format(weight_memory_size, units[-1])\n\n\ndef highlight_number(x):\n    \"\"\"Themes numbers in a summary using rich markup.\n\n    We use a separate color for `None`s, e.g. 
in a layer shape.\n    \"\"\"\n    if x is None:\n        return f\"[color(45)]{x}[/]\"\n    else:\n        return f\"[color(34)]{x}[/]\"\n\n\ndef highlight_symbol(x):\n    \"\"\"Themes keras symbols in a summary using rich markup.\"\"\"\n    return f\"[color(33)]{x}[/]\"\n\n\ndef bold_text(x, color=None):\n    \"\"\"Bolds text using rich markup.\"\"\"\n    if color:\n        return f\"[bold][color({color})]{x}[/][/]\"\n    return f\"[bold]{x}[/]\"\n\n\ndef format_layer_shape(layer):\n    if not layer._inbound_nodes and not layer._build_shapes_dict:\n        return \"?\"\n\n    def format_shape(shape):\n        highlighted = [highlight_number(x) for x in shape]\n        return f\"({', '.join(highlighted)})\"\n\n    # There are 2 approaches to get output shapes:\n    # 1. Using `layer._inbound_nodes`, which is possible if the model is a\n    # Sequential or Functional.\n    # 2. Using `layer._build_shapes_dict`, which is possible if users manually\n    # build the layer.\n    if len(layer._inbound_nodes) > 0:\n        for i in range(len(layer._inbound_nodes)):\n            outputs = layer._inbound_nodes[i].output_tensors\n            output_shapes = tree.map_structure(\n                lambda x: format_shape(x.shape), outputs\n            )\n    else:\n        try:\n            if hasattr(layer, \"output_shape\"):\n                output_shapes = format_shape(layer.output_shape)\n            else:\n                outputs = layer.compute_output_shape(**layer._build_shapes_dict)\n                output_shapes = tree.map_shape_structure(\n                    lambda x: format_shape(x), outputs\n                )\n        except NotImplementedError:\n            return \"?\"\n    if len(output_shapes) == 1:\n        return output_shapes[0]\n    out = str(output_shapes)\n    out = out.replace(\"'\", \"\")\n    return out\n\n\ndef print_summary(\n    model,\n    line_length=None,\n    positions=None,\n    print_fn=None,\n    expand_nested=False,\n    
show_trainable=False,\n    layer_range=None,\n):\n    \"\"\"Prints a summary of a model.\n\n    Args:\n        model: Keras model instance.\n        line_length: Total length of printed lines\n            (e.g. set this to adapt the display to different\n            terminal window sizes).\n        positions: Relative or absolute positions of log elements in each line.\n            If not provided, defaults to `[0.45, 0.8, 1.0]` for\n            sequential-like models and `[0.3, 0.56, 0.74, 1.0]` otherwise.\n        print_fn: Print function to use.\n            It will be called on each line of the summary.\n            You can set it to a custom function\n            in order to capture the string summary.\n            It defaults to `print` (prints to stdout).\n        expand_nested: Whether to expand the nested models.\n            If not provided, defaults to `False`.\n        show_trainable: Whether to show if a layer is trainable.\n            If not provided, defaults to `False`.\n        layer_range: List or tuple containing two strings,\n            the starting layer name and ending layer name (both inclusive),\n            indicating the range of layers to be printed in the summary. The\n            strings could also be regexes instead of an exact name. In this\n            case, the starting layer will be the first layer that matches\n            `layer_range[0]` and the ending layer will be the last layer that\n            matches `layer_range[1]`. 
By default (`None`) all\n            layers in the model are included in the summary.\n    \"\"\"\n    from keras.src.models import Functional\n    from keras.src.models import Sequential\n\n    if not print_fn and not io_utils.is_interactive_logging_enabled():\n        print_fn = io_utils.print_msg\n\n    if isinstance(model, Sequential):\n        sequential_like = True\n        layers = model.layers\n    elif not isinstance(model, Functional):\n        # We treat subclassed models as a simple sequence of layers, for logging\n        # purposes.\n        sequential_like = True\n        layers = model.layers\n    else:\n        layers = model._operations\n        sequential_like = True\n        nodes_by_depth = model._nodes_by_depth.values()\n        nodes = []\n        for v in nodes_by_depth:\n            if (len(v) > 1) or (\n                len(v) == 1 and len(tree.flatten(v[0].input_tensors)) > 1\n            ):\n                # if the model has multiple nodes\n                # or if the nodes have multiple inbound_layers\n                # the model is no longer sequential\n                sequential_like = False\n                break\n            nodes += v\n        if sequential_like:\n            # search for shared layers\n            for layer in model.layers:\n                flag = False\n                for node in layer._inbound_nodes:\n                    if node in nodes:\n                        if flag:\n                            sequential_like = False\n                            break\n                        else:\n                            flag = True\n                if not sequential_like:\n                    break\n\n    if sequential_like:\n        default_line_length = 88\n        positions = positions or [0.45, 0.80, 1.0]\n        # header names for the different log elements\n        header = [\"Layer (type)\", \"Output Shape\", \"Param #\"]\n        alignment = [\"left\", \"left\", \"right\"]\n    else:\n        
default_line_length = 108\n        positions = positions or [0.3, 0.56, 0.74, 1.0]\n        # header names for the different log elements\n        header = [\"Layer (type)\", \"Output Shape\", \"Param #\", \"Connected to\"]\n        alignment = [\"left\", \"left\", \"right\", \"left\"]\n        relevant_nodes = []\n        for v in model._nodes_by_depth.values():\n            relevant_nodes += v\n\n    if show_trainable:\n        default_line_length += 12\n        positions = [p * 0.90 for p in positions] + [1.0]\n        header.append(\"Trainable\")\n        alignment.append(\"center\")\n\n    # Compute column widths\n    default_line_length = min(\n        default_line_length, shutil.get_terminal_size().columns - 4\n    )\n    line_length = line_length or default_line_length\n    column_widths = []\n    current = 0\n    for pos in positions:\n        width = int(pos * line_length) - current\n        if width < 4:\n            raise ValueError(\"Insufficient console width to print summary.\")\n        column_widths.append(width)\n        current += width\n\n    # Render summary as a rich table.\n    columns = []\n    # Right align parameter counts.\n    for i, name in enumerate(header):\n        column = rich.table.Column(\n            name,\n            justify=alignment[i],\n            width=column_widths[i],\n        )\n        columns.append(column)\n\n    table = rich.table.Table(*columns, width=line_length, show_lines=True)\n\n    def get_connections(layer):\n        connections = \"\"\n        for node in layer._inbound_nodes:\n            if relevant_nodes and node not in relevant_nodes:\n                # node is not part of the current network\n                continue\n            for kt in node.input_tensors:\n                keras_history = kt._keras_history\n                inbound_layer = keras_history.operation\n                node_index = highlight_number(keras_history.node_index)\n                tensor_index = 
highlight_number(keras_history.tensor_index)\n                if connections:\n                    connections += \", \"\n                connections += (\n                    f\"{inbound_layer.name}[{node_index}][{tensor_index}]\"\n                )\n        if not connections:\n            connections = \"-\"\n        return connections\n\n    def get_layer_fields(layer, prefix=\"\"):\n        output_shape = format_layer_shape(layer)\n        name = f\"{prefix}{layer.name}\"\n        cls_name = layer.__class__.__name__\n        name = rich.markup.escape(name)\n        name += f\" ({highlight_symbol(rich.markup.escape(cls_name))})\"\n\n        if not hasattr(layer, \"built\"):\n            params = highlight_number(0)\n        elif not layer.built:\n            params = f\"{highlight_number(0)} (unbuilt)\"\n        else:\n            params = highlight_number(f\"{layer.count_params():,}\")\n\n        fields = [name, output_shape, params]\n        if not sequential_like:\n            fields.append(get_connections(layer))\n        if show_trainable:\n            if hasattr(layer, \"weights\") and len(layer.weights) > 0:\n                fields.append(\n                    bold_text(\"Y\", color=34)\n                    if layer.trainable\n                    else bold_text(\"N\", color=9)\n                )\n            else:\n                fields.append(bold_text(\"-\"))\n        return fields\n\n    def print_layer(layer, nested_level=0):\n        if nested_level:\n            prefix = \"   \" * nested_level + \"└ \"\n        else:\n            prefix = \"\"\n\n        fields = get_layer_fields(layer, prefix=prefix)\n\n        rows = [fields]\n        if expand_nested and hasattr(layer, \"layers\") and layer.layers:\n            nested_layers = layer.layers\n            nested_level += 1\n            for i in range(len(nested_layers)):\n                rows.extend(\n                    print_layer(nested_layers[i], nested_level=nested_level)\n                )\n 
       return rows\n\n    # Render all layers to the rich table.\n    layer_range = get_layer_index_bound_by_layer_name(layers, layer_range)\n    for layer in layers[layer_range[0] : layer_range[1]]:\n        for row in print_layer(layer):\n            table.add_row(*row)\n\n    # After the table, append information about parameter count and size.\n    if hasattr(model, \"_collected_trainable_weights\"):\n        trainable_count = count_params(model._collected_trainable_weights)\n        trainable_memory_size = weight_memory_size(\n            model._collected_trainable_weights\n        )\n    else:\n        trainable_count = count_params(model.trainable_weights)\n        trainable_memory_size = weight_memory_size(model.trainable_weights)\n\n    non_trainable_count = count_params(model.non_trainable_weights)\n    non_trainable_memory_size = weight_memory_size(model.non_trainable_weights)\n\n    if model.compiled and model.optimizer and model.optimizer.built:\n        optimizer_weight_count = count_params(model.optimizer.variables)\n        optimizer_memory_size = weight_memory_size(model.optimizer.variables)\n        optimizer_built = True\n    else:\n        optimizer_weight_count = 0\n        optimizer_memory_size = 0\n        optimizer_built = False\n\n    total_count = trainable_count + non_trainable_count + optimizer_weight_count\n    total_memory_size = (\n        trainable_memory_size\n        + non_trainable_memory_size\n        + optimizer_memory_size\n    )\n\n    # Create a rich console for printing. 
Capture for non-interactive logging.\n    if print_fn:\n        console = rich.console.Console(\n            highlight=False, force_terminal=False, color_system=None\n        )\n        console.begin_capture()\n    else:\n        console = rich.console.Console(highlight=False)\n\n    # Print to the console.\n    console.print(bold_text(f'Model: \"{rich.markup.escape(model.name)}\"'))\n    console.print(table)\n    console.print(\n        bold_text(\" Total params: \")\n        + highlight_number(f\"{total_count:,}\")\n        + f\" ({readable_memory_size(total_memory_size)})\"\n    )\n    console.print(\n        bold_text(\" Trainable params: \")\n        + highlight_number(f\"{trainable_count:,}\")\n        + f\" ({readable_memory_size(trainable_memory_size)})\"\n    )\n    console.print(\n        bold_text(\" Non-trainable params: \")\n        + highlight_number(f\"{non_trainable_count:,}\")\n        + f\" ({readable_memory_size(non_trainable_memory_size)})\"\n    )\n    if optimizer_built:\n        console.print(\n            bold_text(\" Optimizer params: \")\n            + highlight_number(f\"{optimizer_weight_count:,}\")\n            + f\" ({readable_memory_size(optimizer_memory_size)})\"\n        )\n\n    # Output captured summary for non-interactive logging.\n    if print_fn:\n        if print_fn is io_utils.print_msg:\n            print_fn(console.end_capture(), line_break=False)\n        else:\n            print_fn(console.end_capture())\n\n\ndef get_layer_index_bound_by_layer_name(layers, layer_range=None):\n    \"\"\"Get the layer indexes from the model based on layer names.\n\n    The layer indexes can be used to slice the model into sub models for\n    display.\n\n    Args:\n        layers: list of layers of the model.\n        layer_range: a list or tuple of 2 strings, the starting layer name and\n            ending layer name (both inclusive) for the result. 
All layers will\n            be included when `None` is provided.\n\n    Returns:\n        The layer indexes based on the layer names, in the form\n        `[first_layer_index, last_layer_index + 1]`.\n    \"\"\"\n    if layer_range is not None:\n        if len(layer_range) != 2:\n            raise ValueError(\n                \"layer_range must be a list or tuple of length 2. Received: \"\n                f\"layer_range = {layer_range} of length {len(layer_range)}\"\n            )\n        if not isinstance(layer_range[0], str) or not isinstance(\n            layer_range[1], str\n        ):\n            raise ValueError(\n                \"layer_range should contain string type only. \"\n                f\"Received: {layer_range}\"\n            )\n    else:\n        return [0, len(layers)]\n\n    lower_index = [\n        idx\n        for idx, layer in enumerate(layers)\n        if re.match(layer_range[0], layer.name)\n    ]\n    upper_index = [\n        idx\n        for idx, layer in enumerate(layers)\n        if re.match(layer_range[1], layer.name)\n    ]\n\n    if not lower_index or not upper_index:\n        raise ValueError(\n            \"Passed layer_names do not match the layer names in the model. \"\n            f\"Received: {layer_range}\"\n        )\n\n    if min(lower_index) > max(upper_index):\n        return [min(upper_index), max(lower_index) + 1]\n    return [min(lower_index), max(upper_index) + 1]\n"
  },
  {
    "path": "keras/src/utils/summary_utils_test.py",
    "content": "import numpy as np\nimport pytest\nfrom absl.testing import parameterized\n\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import ops\nfrom keras.src import testing\nfrom keras.src.utils import summary_utils\n\n\nclass SummaryUtilsTest(testing.TestCase):\n    @parameterized.parameters([(\"adam\",), (None,)])\n    @pytest.mark.requires_trainable_backend\n    def test_print_model_summary(self, optimizer):\n        inputs = layers.Input((2,))\n        outputs = layers.Dense(3)(inputs)\n        model = models.Model(inputs, outputs)\n        model.compile(optimizer=optimizer, loss=\"mse\", metrics=[\"mse\"])\n        if optimizer:\n            # Trigger the optimizer weights creation\n            model.fit(x=np.zeros([4, 2]), y=np.zeros([4, 3]))\n\n        summary_content = []\n\n        def print_to_variable(text, line_break=False):\n            summary_content.append(text)\n\n        try:\n            summary_utils.print_summary(model, print_fn=print_to_variable)\n            summary_content = \"\\n\".join(summary_content)\n            if optimizer:\n                self.assertIn(\"Total params: 29\", summary_content)\n                self.assertIn(\"Trainable params: 9\", summary_content)\n                self.assertIn(\"Non-trainable params: 0\", summary_content)\n                self.assertIn(\"Optimizer params: 20\", summary_content)\n            else:\n                self.assertIn(\"Total params: 9\", summary_content)\n                self.assertIn(\"Trainable params: 9\", summary_content)\n                self.assertIn(\"Non-trainable params: 0\", summary_content)\n                self.assertNotIn(\"Optimizer params\", summary_content)\n        except ImportError:\n            pass\n\n    def test_print_model_summary_custom_build(self):\n        class MyModel(models.Model):\n            def __init__(self):\n                super().__init__()\n                self.dense1 = layers.Dense(4, activation=\"relu\")\n        
        self.dense2 = layers.Dense(2, activation=\"softmax\")\n                self.unbuilt_dense = layers.Dense(1)\n\n            def build(self, input_shape):\n                self.dense1.build(input_shape)\n                input_shape = self.dense1.compute_output_shape(input_shape)\n                self.dense2.build(input_shape)\n\n            def call(self, inputs):\n                x = self.dense1(inputs)\n                return self.dense2(x)\n\n        model = MyModel()\n        model.build((None, 2))\n\n        summary_content = []\n\n        def print_to_variable(text, line_break=False):\n            summary_content.append(text)\n\n        summary_utils.print_summary(model, print_fn=print_to_variable)\n        summary_content = \"\\n\".join(summary_content)\n        self.assertIn(\"(None, 4)\", summary_content)  # dense1\n        self.assertIn(\"(None, 2)\", summary_content)  # dense2\n        self.assertIn(\"?\", summary_content)  # unbuilt_dense\n        self.assertIn(\"Total params: 22\", summary_content)\n        self.assertIn(\"Trainable params: 22\", summary_content)\n        self.assertIn(\"Non-trainable params: 0\", summary_content)\n\n    def test_print_model_summary_op_as_layer(self):\n        inputs = layers.Input((2,))\n        x = layers.Dense(4)(inputs)\n        outputs = ops.mean(x)\n        model = models.Model(inputs, outputs)\n\n        summary_content = []\n\n        def print_to_variable(text, line_break=False):\n            summary_content.append(text)\n\n        summary_utils.print_summary(\n            model, print_fn=print_to_variable, show_trainable=True\n        )\n        summary_content = \"\\n\".join(summary_content)\n        self.assertIn(\"(None, 4)\", summary_content)  # dense\n        self.assertIn(\"Y\", summary_content)  # dense\n        self.assertIn(\"()\", summary_content)  # mean\n        self.assertIn(\"-\", summary_content)  # mean\n        self.assertIn(\"Total params: 12\", summary_content)\n        
self.assertIn(\"Trainable params: 12\", summary_content)\n        self.assertIn(\"Non-trainable params: 0\", summary_content)\n\n    def test_print_model_summary_with_mha(self):\n        # In Keras <= 3.6, MHA exposes `output_shape` property which breaks this\n        # test.\n        class MyModel(models.Model):\n            def __init__(self):\n                super().__init__()\n                self.mha = layers.MultiHeadAttention(2, 2, output_shape=(4,))\n\n            def call(self, inputs):\n                return self.mha(inputs, inputs, inputs)\n\n        model = MyModel()\n        model(np.ones((1, 2, 2)))\n\n        summary_content = []\n\n        def print_to_variable(text, line_break=False):\n            summary_content.append(text)\n\n        summary_utils.print_summary(model, print_fn=print_to_variable)\n        summary_content = \"\\n\".join(summary_content)\n        self.assertIn(\"(1, 2, 4)\", summary_content)  # mha\n        self.assertIn(\"Total params: 56\", summary_content)\n        self.assertIn(\"Trainable params: 56\", summary_content)\n        self.assertIn(\"Non-trainable params: 0\", summary_content)\n"
  },
  {
    "path": "keras/src/utils/text_dataset_utils.py",
    "content": "import numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils import dataset_utils\nfrom keras.src.utils.grain_utils import make_string_batch\nfrom keras.src.utils.module_utils import grain\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\n@keras_export(\n    [\n        \"keras.utils.text_dataset_from_directory\",\n        \"keras.preprocessing.text_dataset_from_directory\",\n    ]\n)\ndef text_dataset_from_directory(\n    directory,\n    labels=\"inferred\",\n    label_mode=\"int\",\n    class_names=None,\n    batch_size=32,\n    max_length=None,\n    shuffle=True,\n    seed=None,\n    validation_split=None,\n    subset=None,\n    follow_links=False,\n    format=\"tf\",\n    verbose=True,\n):\n    \"\"\"Generates a dataset from text files in a directory.\n\n    If your directory structure is:\n\n    ```\n    main_directory/\n    ...class_a/\n    ......a_text_1.txt\n    ......a_text_2.txt\n    ...class_b/\n    ......b_text_1.txt\n    ......b_text_2.txt\n    ```\n\n    Then calling `text_dataset_from_directory(main_directory,\n    labels='inferred')` will return a dataset that yields batches of\n    texts from the subdirectories `class_a` and `class_b`, together with labels\n    0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`).\n\n    Only `.txt` files are supported at this time.\n\n    By default, this function will return a `tf.data.Dataset` object. 
You can\n    set `format=\"grain\"` to return a `grain.IterDataset` object instead, which\n    removes the TensorFlow dependency.\n\n    Args:\n        directory: Directory where the data is located.\n            If `labels` is `\"inferred\"`, it should contain\n            subdirectories, each containing text files for a class.\n            Otherwise, the directory structure is ignored.\n        labels: Either `\"inferred\"`\n            (labels are generated from the directory structure),\n            `None` (no labels),\n            or a list/tuple of integer labels of the same size as the number of\n            text files found in the directory. Labels should be sorted according\n            to the alphanumeric order of the text file paths\n            (obtained via `os.walk(directory)` in Python).\n        label_mode: String describing the encoding of `labels`. Options are:\n            - `\"int\"`: means that the labels are encoded as integers\n                (e.g. for `sparse_categorical_crossentropy` loss).\n            - `\"categorical\"` means that the labels are\n                encoded as a categorical vector\n                (e.g. for `categorical_crossentropy` loss).\n            - `\"binary\"` means that the labels (there can be only 2)\n                are encoded as `float32` scalars with values 0 or 1\n                (e.g. for `binary_crossentropy`).\n            - `None` (no labels).\n        class_names: Only valid if `\"labels\"` is `\"inferred\"`.\n            This is the explicit list of class names\n            (must match names of subdirectories). Used to control the order\n            of the classes (otherwise alphanumerical order is used).\n        batch_size: Size of the batches of data.\n            If `None`, the data will not be batched\n            (the dataset will yield individual samples).\n            Defaults to `32`.\n        max_length: Maximum size of a text string. 
Texts longer than this will\n            be truncated to `max_length`.\n        shuffle: Whether to shuffle the data.\n            If set to `False`, sorts the data in alphanumeric order.\n            Defaults to `True`.\n        seed: Optional random seed for shuffling and transformations.\n        validation_split: Optional float between 0 and 1,\n            fraction of data to reserve for validation.\n        subset: Subset of the data to return.\n            One of `\"training\"`, `\"validation\"` or `\"both\"`.\n            Only used if `validation_split` is set.\n            When `subset=\"both\"`, the utility returns a tuple of two datasets\n            (the training and validation datasets respectively).\n        follow_links: Whether to visit subdirectories pointed to by symlinks.\n            Defaults to `False`.\n        format: The format of the return object. Defaults to `\"tf\"`. Available\n            options are:\n            - `\"tf\"`: returns a `tf.data.Dataset` object. Requires\n                TensorFlow to be installed.\n            - `\"grain\"`: returns a `grain.IterDataset` object. Requires\n                Grain to be installed.\n        verbose: Whether to display information on the number of classes and\n            number of files found. 
Defaults to `True`.\n\n    Returns:\n\n    A `tf.data.Dataset` (`format=\"tf\"`) or `grain.IterDataset`\n    (`format=\"grain\"`) object.\n\n    When `format=\"tf\"`:\n    - If `label_mode` is `None`, it yields `string` tensors of shape\n        `(batch_size,)`, containing the contents of a batch of text files.\n    - Otherwise, it yields a tuple `(texts, labels)`, where `texts`\n        has shape `(batch_size,)` and `labels` follows the format described\n        below.\n\n    When `format=\"grain\"`:\n    - If `label_mode` is `None`, it yields a list of Python strings containing\n        the contents of a batch of text files.\n    - Otherwise, it yields a tuple `(texts, labels)`, where `texts`\n        is a list of Python strings and `labels` follows the format described\n        below.\n\n    Rules regarding labels format:\n\n    - if `label_mode` is `int`, the labels are an `int32` tensor of shape\n        `(batch_size,)`.\n    - if `label_mode` is `binary`, the labels are a `float32` tensor of\n        1s and 0s of shape `(batch_size, 1)`.\n    - if `label_mode` is `categorical`, the labels are a `float32` tensor\n        of shape `(batch_size, num_classes)`, representing a one-hot\n        encoding of the class index.\n    \"\"\"\n    if labels not in (\"inferred\", None):\n        if not isinstance(labels, (list, tuple)):\n            raise ValueError(\n                \"`labels` argument should be a list/tuple of integer labels, \"\n                \"of the same size as the number of text files in the target \"\n                \"directory. If you wish to infer the labels from the \"\n                \"subdirectory names in the target directory, \"\n                'pass `labels=\"inferred\"`. '\n                \"If you wish to get a dataset that only contains text samples \"\n                f\"(no labels), pass `labels=None`. 
Received: labels={labels}\"\n            )\n        if class_names:\n            raise ValueError(\n                \"You can only pass `class_names` if \"\n                f'`labels=\"inferred\"`. Received: labels={labels}, and '\n                f\"class_names={class_names}\"\n            )\n    if label_mode not in {\"int\", \"categorical\", \"binary\", None}:\n        raise ValueError(\n            '`label_mode` argument must be one of \"int\", '\n            '\"categorical\", \"binary\", '\n            f\"or None. Received: label_mode={label_mode}\"\n        )\n    if format not in (\"tf\", \"grain\"):\n        raise ValueError(\n            '`format` should be either \"tf\" or \"grain\". '\n            f\"Received: format={format}\"\n        )\n    if labels is None or label_mode is None:\n        labels = None\n        label_mode = None\n    dataset_utils.check_validation_split_arg(\n        validation_split, subset, shuffle, seed\n    )\n\n    if seed is None:\n        seed = np.random.randint(1e6)\n    file_paths, labels, class_names = dataset_utils.index_directory(\n        directory,\n        labels,\n        formats=(\".txt\",),\n        class_names=class_names,\n        shuffle=shuffle,\n        seed=seed,\n        follow_links=follow_links,\n        verbose=verbose,\n    )\n\n    if label_mode == \"binary\" and len(class_names) != 2:\n        raise ValueError(\n            'When passing `label_mode=\"binary\"`, there must be exactly 2 '\n            f\"class_names. 
Received: class_names={class_names}\"\n        )\n    if batch_size is not None:\n        shuffle_buffer_size = batch_size * 8\n    else:\n        shuffle_buffer_size = 1024\n\n    if subset == \"both\":\n        (\n            file_paths_train,\n            labels_train,\n        ) = dataset_utils.get_training_or_validation_split(\n            file_paths, labels, validation_split, \"training\"\n        )\n        (\n            file_paths_val,\n            labels_val,\n        ) = dataset_utils.get_training_or_validation_split(\n            file_paths, labels, validation_split, \"validation\"\n        )\n        if not file_paths_train:\n            raise ValueError(\n                f\"No training text files found in directory {directory}. \"\n                \"Allowed format: .txt\"\n            )\n        if not file_paths_val:\n            raise ValueError(\n                f\"No validation text files found in directory {directory}. \"\n                \"Allowed format: .txt\"\n            )\n        train_dataset = paths_and_labels_to_dataset(\n            file_paths=file_paths_train,\n            labels=labels_train,\n            label_mode=label_mode,\n            num_classes=len(class_names) if class_names else 0,\n            max_length=max_length,\n            shuffle=shuffle,\n            shuffle_buffer_size=shuffle_buffer_size,\n            seed=seed,\n            format=format,\n        )\n        val_dataset = paths_and_labels_to_dataset(\n            file_paths=file_paths_val,\n            labels=labels_val,\n            label_mode=label_mode,\n            num_classes=len(class_names) if class_names else 0,\n            max_length=max_length,\n            shuffle=False,\n            format=format,\n        )\n\n        if format == \"tf\":\n            if batch_size is not None:\n                train_dataset = train_dataset.batch(batch_size)\n                val_dataset = val_dataset.batch(batch_size)\n            train_dataset = 
train_dataset.prefetch(tf.data.AUTOTUNE)\n            val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)\n        else:\n            train_dataset = train_dataset.to_iter_dataset()\n            val_dataset = val_dataset.to_iter_dataset()\n            if batch_size is not None:\n                train_dataset = train_dataset.batch(\n                    batch_size, batch_fn=make_string_batch\n                )\n                val_dataset = val_dataset.batch(\n                    batch_size, batch_fn=make_string_batch\n                )\n\n        # Users may need to reference `class_names`.\n        train_dataset.class_names = class_names\n        val_dataset.class_names = class_names\n        dataset = [train_dataset, val_dataset]\n    else:\n        file_paths, labels = dataset_utils.get_training_or_validation_split(\n            file_paths, labels, validation_split, subset\n        )\n        if not file_paths:\n            raise ValueError(\n                f\"No text files found in directory {directory}. 
\"\n                \"Allowed format: .txt\"\n            )\n        dataset = paths_and_labels_to_dataset(\n            file_paths=file_paths,\n            labels=labels,\n            label_mode=label_mode,\n            num_classes=len(class_names) if class_names else 0,\n            max_length=max_length,\n            shuffle=shuffle,\n            shuffle_buffer_size=shuffle_buffer_size,\n            seed=seed,\n            format=format,\n        )\n\n        if format == \"tf\":\n            if batch_size is not None:\n                dataset = dataset.batch(batch_size)\n            dataset = dataset.prefetch(tf.data.AUTOTUNE)\n        else:\n            dataset = dataset.to_iter_dataset()\n            if batch_size is not None:\n                dataset = dataset.batch(batch_size, batch_fn=make_string_batch)\n\n        # Users may need to reference `class_names`.\n        dataset.class_names = class_names\n    return dataset\n\n\ndef paths_and_labels_to_dataset(\n    file_paths,\n    labels,\n    label_mode,\n    num_classes,\n    max_length,\n    shuffle=False,\n    shuffle_buffer_size=None,\n    seed=None,\n    format=\"tf\",\n):\n    \"\"\"Constructs a dataset of text strings and labels.\"\"\"\n    if format == \"tf\":\n        return _paths_and_labels_to_dataset_tf(\n            file_paths,\n            labels,\n            label_mode,\n            num_classes,\n            max_length,\n            shuffle,\n            shuffle_buffer_size,\n            seed,\n        )\n    elif format == \"grain\":\n        return _paths_and_labels_to_dataset_grain(\n            file_paths,\n            labels,\n            label_mode,\n            num_classes,\n            max_length,\n            shuffle,\n            shuffle_buffer_size,\n            seed,\n        )\n\n\ndef _paths_and_labels_to_dataset_tf(\n    file_paths,\n    labels,\n    label_mode,\n    num_classes,\n    max_length,\n    shuffle=False,\n    shuffle_buffer_size=None,\n    seed=None,\n):\n    
\"\"\"Constructs a dataset of text strings and labels.\"\"\"\n    path_ds = tf.data.Dataset.from_tensor_slices(file_paths)\n    if label_mode:\n        label_ds = dataset_utils.labels_to_dataset_tf(\n            labels, label_mode, num_classes\n        )\n        ds = tf.data.Dataset.zip((path_ds, label_ds))\n    else:\n        ds = path_ds\n\n    if shuffle:\n        ds = ds.shuffle(buffer_size=shuffle_buffer_size or 1024, seed=seed)\n\n    if label_mode:\n        ds = ds.map(\n            lambda x, y: (_path_to_string_content_tf(x, max_length), y),\n            num_parallel_calls=tf.data.AUTOTUNE,\n        )\n    else:\n        ds = ds.map(\n            lambda x: _path_to_string_content_tf(x, max_length),\n            num_parallel_calls=tf.data.AUTOTUNE,\n        )\n    return ds\n\n\ndef _path_to_string_content_tf(path, max_length):\n    txt = tf.io.read_file(path)\n    if max_length is not None:\n        txt = tf.strings.substr(txt, 0, max_length)\n    return txt\n\n\ndef _paths_and_labels_to_dataset_grain(\n    file_paths,\n    labels,\n    label_mode,\n    num_classes,\n    max_length,\n    shuffle=False,\n    shuffle_buffer_size=None,\n    seed=None,\n):\n    \"\"\"Constructs a dataset of text strings and labels.\"\"\"\n    path_ds = grain.MapDataset.source(file_paths)\n    if label_mode:\n        label_ds = dataset_utils.labels_to_dataset_grain(\n            labels, label_mode, num_classes\n        )\n        ds = grain.experimental.ZipMapDataset([path_ds, label_ds])\n    else:\n        ds = path_ds\n\n    if shuffle:\n        ds = ds.shuffle(seed=seed)\n\n    if label_mode:\n        ds = ds.map(\n            lambda data: (\n                _path_to_string_content_grain(data[0], max_length),\n                data[1],\n            ),\n        )\n    else:\n        ds = ds.map(lambda x: _path_to_string_content_grain(x, max_length))\n    return ds\n\n\ndef _path_to_string_content_grain(path, max_length):\n    with open(path, \"r\") as f:\n        txt = 
f.read()\n    if max_length is not None:\n        txt = txt[:max_length]\n    return txt\n"
  },
  {
    "path": "keras/src/utils/text_dataset_utils_test.py",
    "content": "import os\nimport random\nimport string\n\nfrom absl.testing import parameterized\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.utils import text_dataset_utils\n\n\nclass TextDatasetFromDirectoryTest(testing.TestCase):\n    def _prepare_directory(\n        self, num_classes=2, nested_dirs=False, count=16, length=20\n    ):\n        # Get a unique temp directory\n        temp_dir = self.get_temp_dir()\n\n        # Generate paths to class subdirectories\n        paths = []\n        for class_index in range(num_classes):\n            class_directory = f\"class_{class_index}\"\n            if nested_dirs:\n                class_paths = [\n                    class_directory,\n                    os.path.join(class_directory, \"subfolder_1\"),\n                    os.path.join(class_directory, \"subfolder_2\"),\n                    os.path.join(\n                        class_directory, \"subfolder_1\", \"sub-subfolder\"\n                    ),\n                ]\n            else:\n                class_paths = [class_directory]\n            for path in class_paths:\n                os.mkdir(os.path.join(temp_dir, path))\n            paths += class_paths\n\n        for i in range(count):\n            path = paths[i % len(paths)]\n            filename = os.path.join(path, f\"text_{i}.txt\")\n            with open(os.path.join(temp_dir, filename), \"w\") as f:\n                text = \"\".join(\n                    [random.choice(string.printable) for _ in range(length)]\n                )\n                f.write(text)\n        return temp_dir\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_text_dataset_from_directory_standalone(self, format):\n        # Test retrieving txt files without labels from a directory and its\n        # subdirs. 
Save a few extra files in the parent directory.\n        directory = self._prepare_directory(count=7, num_classes=2)\n        for i in range(3):\n            filename = f\"text_{i}.txt\"\n            with open(os.path.join(directory, filename), \"w\") as f:\n                text = \"\".join(\n                    [random.choice(string.printable) for _ in range(20)]\n                )\n                f.write(text)\n\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory,\n            batch_size=5,\n            label_mode=None,\n            max_length=10,\n            format=format,\n        )\n        batch = next(iter(dataset))\n        # We just return the texts, no labels\n        if format == \"tf\" or backend.backend() == \"tensorflow\":\n            self.assertEqual(list(batch.shape), [5])\n            self.assertDType(batch, \"string\")\n        else:\n            self.assertLen(batch, 5)\n            self.assertIsInstance(batch[0], str)\n        # Count samples\n        batch_count = 0\n        sample_count = 0\n        for batch in dataset:\n            batch_count += 1\n            sample_count += len(batch)\n        self.assertEqual(batch_count, 2)\n        self.assertEqual(sample_count, 10)\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_text_dataset_from_directory_binary(self, format):\n        directory = self._prepare_directory(num_classes=2)\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory,\n            batch_size=8,\n            label_mode=\"int\",\n            max_length=10,\n            format=format,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        if format == \"tf\" or backend.backend() == \"tensorflow\":\n            self.assertEqual(batch[0].shape, (8,))\n            self.assertDType(batch[0], \"string\")\n            
self.assertEqual(len(batch[0].numpy()[0]), 10)  # Test max_length\n        else:\n            self.assertLen(batch[0], 8)\n            self.assertIsInstance(batch[0][0], str)\n            self.assertLen(batch[0][0], 10)  # Test max_length\n        self.assertEqual(list(batch[1].shape), [8])\n        self.assertDType(batch[1], \"int32\")\n\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory,\n            batch_size=8,\n            label_mode=\"binary\",\n            format=format,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        if format == \"tf\" or backend.backend() == \"tensorflow\":\n            self.assertEqual(list(batch[0].shape), [8])\n            self.assertEqual(batch[0].dtype.name, \"string\")\n        else:\n            self.assertLen(batch[0], 8)\n            self.assertIsInstance(batch[0][0], str)\n        self.assertEqual(list(batch[1].shape), [8, 1])\n        self.assertDType(batch[1], \"float32\")\n\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory,\n            batch_size=8,\n            label_mode=\"categorical\",\n            format=format,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        if format == \"tf\" or backend.backend() == \"tensorflow\":\n            self.assertEqual(list(batch[0].shape), [8])\n            self.assertEqual(batch[0].dtype.name, \"string\")\n        else:\n            self.assertLen(batch[0], 8)\n            self.assertIsInstance(batch[0][0], str)\n        self.assertEqual(list(batch[1].shape), [8, 2])\n        self.assertDType(batch[1], \"float32\")\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_sample_count(self, format):\n        directory = self._prepare_directory(num_classes=4, count=15)\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory, 
batch_size=8, label_mode=None, format=format\n        )\n        sample_count = 0\n        for batch in dataset:\n            sample_count += len(batch)\n        self.assertEqual(sample_count, 15)\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_text_dataset_from_directory_multiclass(self, format):\n        directory = self._prepare_directory(num_classes=4, count=15)\n\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory, batch_size=8, label_mode=None, format=format\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 8)\n\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory, batch_size=8, label_mode=None, format=format\n        )\n        sample_count = 0\n        for batch in dataset:\n            sample_count += len(batch)\n        self.assertEqual(sample_count, 15)\n\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory, batch_size=8, label_mode=\"int\", format=format\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        if format == \"tf\" or backend.backend() == \"tensorflow\":\n            self.assertEqual(list(batch[0].shape), [8])\n            self.assertEqual(batch[0].dtype.name, \"string\")\n        else:\n            self.assertLen(batch[0], 8)\n            self.assertIsInstance(batch[0][0], str)\n        self.assertEqual(list(batch[1].shape), [8])\n        self.assertDType(batch[1], \"int32\")\n\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory, batch_size=8, label_mode=\"categorical\", format=format\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        if format == \"tf\" or backend.backend() == \"tensorflow\":\n            self.assertEqual(list(batch[0].shape), [8])\n            
self.assertEqual(batch[0].dtype.name, \"string\")\n        else:\n            self.assertLen(batch[0], 8)\n            self.assertIsInstance(batch[0][0], str)\n        self.assertEqual(list(batch[1].shape), [8, 4])\n        self.assertDType(batch[1], \"float32\")\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_text_dataset_from_directory_validation_split(self, format):\n        directory = self._prepare_directory(num_classes=2, count=10)\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory,\n            batch_size=10,\n            validation_split=0.2,\n            subset=\"training\",\n            seed=1337,\n            format=format,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertLen(batch[0], 8)\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory,\n            batch_size=10,\n            validation_split=0.2,\n            subset=\"validation\",\n            seed=1337,\n            format=format,\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertLen(batch[0], 2)\n\n        (\n            train_dataset,\n            val_dataset,\n        ) = text_dataset_utils.text_dataset_from_directory(\n            directory,\n            batch_size=10,\n            validation_split=0.2,\n            subset=\"both\",\n            seed=1337,\n            format=format,\n        )\n        batch = next(iter(train_dataset))\n        self.assertLen(batch, 2)\n        self.assertLen(batch[0], 8)\n        batch = next(iter(val_dataset))\n        self.assertLen(batch, 2)\n        self.assertLen(batch[0], 2)\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_text_dataset_from_directory_manual_labels(self, format):\n        directory = 
self._prepare_directory(num_classes=2, count=2)\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory, batch_size=8, labels=[0, 1], shuffle=False, format=format\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        self.assertAllClose(batch[1], [0, 1])\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_text_dataset_from_directory_follow_links(self, format):\n        directory = self._prepare_directory(\n            num_classes=2, count=25, nested_dirs=True\n        )\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory,\n            batch_size=8,\n            label_mode=None,\n            follow_links=True,\n            format=format,\n        )\n        sample_count = 0\n        for batch in dataset:\n            sample_count += len(batch)\n        self.assertEqual(sample_count, 25)\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_text_dataset_from_directory_no_files(self, format):\n        directory = self._prepare_directory(num_classes=2, count=0)\n        with self.assertRaisesRegex(ValueError, \"No text files found\"):\n            _ = text_dataset_utils.text_dataset_from_directory(\n                directory, format=format\n            )\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_text_dataset_from_directory_errors(self, format):\n        directory = self._prepare_directory(num_classes=3, count=5)\n\n        with self.assertRaisesRegex(ValueError, \"`labels` argument should be\"):\n            _ = text_dataset_utils.text_dataset_from_directory(\n                directory, labels=\"other\", format=format\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"`label_mode` argument must be\"\n        ):\n         
   _ = text_dataset_utils.text_dataset_from_directory(\n                directory, label_mode=\"other\", format=format\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, 'only pass `class_names` if `labels=\"inferred\"`'\n        ):\n            _ = text_dataset_utils.text_dataset_from_directory(\n                directory,\n                labels=[0, 0, 1, 1, 1],\n                class_names=[\"class_0\", \"class_1\", \"class_2\"],\n                format=format,\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            \"Expected the lengths of `labels` to match the number of files\",\n        ):\n            _ = text_dataset_utils.text_dataset_from_directory(\n                directory, labels=[0, 0, 1, 1], format=format\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"`class_names` passed did not match\"\n        ):\n            _ = text_dataset_utils.text_dataset_from_directory(\n                directory, class_names=[\"class_0\", \"wrong_class\"], format=format\n            )\n\n        with self.assertRaisesRegex(ValueError, \"there must be exactly 2\"):\n            _ = text_dataset_utils.text_dataset_from_directory(\n                directory, label_mode=\"binary\", format=format\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"`validation_split` must be between 0 and 1\"\n        ):\n            _ = text_dataset_utils.text_dataset_from_directory(\n                directory, validation_split=2, format=format\n            )\n\n        with self.assertRaisesRegex(\n            ValueError,\n            '`subset` must be either \"training\", \"validation\" or \"both\"',\n        ):\n            _ = text_dataset_utils.text_dataset_from_directory(\n                directory, validation_split=0.2, subset=\"other\", format=format\n            )\n\n        with self.assertRaisesRegex(\n            ValueError, \"`validation_split` must 
be set\"\n        ):\n            _ = text_dataset_utils.text_dataset_from_directory(\n                directory,\n                validation_split=0.0,\n                subset=\"training\",\n                format=format,\n            )\n\n        with self.assertRaisesRegex(ValueError, \"must provide a `seed`\"):\n            _ = text_dataset_utils.text_dataset_from_directory(\n                directory,\n                validation_split=0.2,\n                subset=\"training\",\n                format=format,\n            )\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_text_dataset_from_directory_not_batched(self, format):\n        directory = self._prepare_directory()\n        dataset = text_dataset_utils.text_dataset_from_directory(\n            directory,\n            batch_size=None,\n            label_mode=None,\n            follow_links=True,\n            format=format,\n        )\n\n        sample = next(iter(dataset))\n        if format == \"tf\":\n            self.assertEqual(len(sample.shape), 0)\n        else:\n            self.assertIsInstance(sample, str)\n"
  },
  {
    "path": "keras/src/utils/tf_utils.py",
    "content": "from keras.src import backend\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\ndef get_tensor_spec(t, dynamic_batch=False, name=None):\n    \"\"\"Returns a `TensorSpec` given a single `Tensor` or `TensorSpec`.\"\"\"\n    if isinstance(t, tf.TypeSpec):\n        spec = t\n    elif isinstance(t, tf.__internal__.CompositeTensor):\n        # Check for ExtensionTypes\n        spec = t._type_spec\n    elif hasattr(t, \"shape\") and hasattr(t, \"dtype\"):\n        spec = tf.TensorSpec(shape=t.shape, dtype=t.dtype, name=name)\n    else:\n        return None  # Allow non-Tensors to pass through.\n\n    if not dynamic_batch:\n        return spec\n\n    shape = spec.shape\n    if shape.rank is None or shape.rank == 0:\n        return spec\n\n    shape_list = shape.as_list()\n    shape_list[0] = None\n    shape = tf.TensorShape(shape_list)\n    spec._shape = shape\n    return spec\n\n\ndef ensure_tensor(inputs, dtype=None):\n    \"\"\"Ensures the input is a Tensor, SparseTensor or RaggedTensor.\"\"\"\n    if not isinstance(inputs, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor)):\n        if backend.backend() == \"torch\" and backend.is_tensor(inputs):\n            # Plain `np.asarray()` conversion fails with PyTorch.\n            inputs = backend.convert_to_numpy(inputs)\n        inputs = tf.convert_to_tensor(inputs, dtype)\n    if dtype is not None and inputs.dtype != dtype:\n        inputs = tf.cast(inputs, dtype)\n    return inputs\n\n\ndef is_ragged_tensor(x):\n    return \"ragged_tensor.RaggedTensor\" in str(type(x))\n\n\ndef sparse_bincount(inputs, depth, binary_output, dtype, count_weights=None):\n    \"\"\"Apply binary or count encoding to an input and return a sparse tensor.\"\"\"\n    result = tf.sparse.bincount(\n        inputs,\n        weights=count_weights,\n        minlength=depth,\n        maxlength=depth,\n        axis=-1,\n        binary_output=binary_output,\n    )\n    result = tf.cast(result, dtype)\n    if inputs.shape.rank == 
1:\n        output_shape = (depth,)\n    else:\n        batch_size = tf.shape(result)[0]\n        output_shape = (batch_size, depth)\n    result = tf.SparseTensor(\n        indices=result.indices, values=result.values, dense_shape=output_shape\n    )\n    return result\n\n\ndef dense_bincount(inputs, depth, binary_output, dtype, count_weights=None):\n    \"\"\"Apply binary or count encoding to an input.\"\"\"\n    result = tf.math.bincount(\n        inputs,\n        weights=count_weights,\n        minlength=depth,\n        maxlength=depth,\n        dtype=dtype,\n        axis=-1,\n        binary_output=binary_output,\n    )\n    if inputs.shape.rank == 1:\n        result.set_shape(tf.TensorShape((depth,)))\n    else:\n        batch_size = inputs.shape.as_list()[0]\n        result.set_shape(tf.TensorShape((batch_size, depth)))\n    return result\n\n\ndef expand_dims(inputs, axis):\n    \"\"\"Expand dims on sparse, ragged, or dense tensors.\"\"\"\n    if isinstance(inputs, tf.SparseTensor):\n        return tf.sparse.expand_dims(inputs, axis)\n    return tf.expand_dims(inputs, axis)\n\n\ndef tf_encode_categorical_inputs(\n    inputs,\n    output_mode,\n    depth,\n    dtype=\"float32\",\n    sparse=False,\n    count_weights=None,\n    idf_weights=None,\n):\n    \"\"\"Encodes categorical inputs according to output_mode.\n\n    Faster method that relies on bincount.\n    \"\"\"\n\n    if output_mode == \"int\":\n        return tf.identity(tf.cast(inputs, dtype))\n\n    original_shape = inputs.shape\n    # In all cases, we should uprank scalar input to a single sample.\n    if inputs.shape.rank == 0:\n        inputs = expand_dims(inputs, -1)\n    # One hot will uprank only if the final output dimension is not already 1.\n    if output_mode == \"one_hot\":\n        if inputs.shape[-1] != 1:\n            inputs = expand_dims(inputs, -1)\n\n    if inputs.shape.rank > 2:\n        raise ValueError(\n            \"When output_mode is not `'int'`, maximum supported output rank 
\"\n            f\"is 2. Received output_mode {output_mode} and input shape \"\n            f\"{original_shape}, \"\n            f\"which would result in output rank {inputs.shape.rank}.\"\n        )\n\n    binary_output = output_mode in (\"multi_hot\", \"one_hot\")\n    if sparse:\n        bincounts = sparse_bincount(\n            inputs, depth, binary_output, dtype, count_weights\n        )\n    else:\n        bincounts = dense_bincount(\n            inputs, depth, binary_output, dtype, count_weights\n        )\n\n    bincounts = tf.cast(bincounts, dtype)\n    if output_mode != \"tf_idf\":\n        return bincounts\n\n    if idf_weights is None:\n        raise ValueError(\n            \"When output mode is `'tf_idf'`, idf_weights must be provided. \"\n            f\"Received: output_mode={output_mode} and idf_weights={idf_weights}\"\n        )\n\n    if sparse:\n        value_weights = tf.gather(idf_weights, bincounts.indices[:, -1])\n        return tf.SparseTensor(\n            bincounts.indices,\n            value_weights * bincounts.values,\n            bincounts.dense_shape,\n        )\n    else:\n        return tf.multiply(bincounts, idf_weights)\n"
  },
  {
    "path": "keras/src/utils/timeseries_dataset_utils.py",
    "content": "import numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.utils.grain_utils import make_batch\nfrom keras.src.utils.module_utils import grain\nfrom keras.src.utils.module_utils import tensorflow as tf\n\n\n@keras_export(\n    [\n        \"keras.utils.timeseries_dataset_from_array\",\n        \"keras.preprocessing.timeseries_dataset_from_array\",\n    ]\n)\ndef timeseries_dataset_from_array(\n    data,\n    targets,\n    sequence_length,\n    sequence_stride=1,\n    sampling_rate=1,\n    batch_size=128,\n    shuffle=False,\n    seed=None,\n    start_index=None,\n    end_index=None,\n    format=\"tf\",\n):\n    \"\"\"Creates a dataset of sliding windows over a timeseries provided as array.\n\n    This function takes in a sequence of data-points gathered at\n    equal intervals, along with time series parameters such as\n    length of the sequences/windows, spacing between two sequence/windows, etc.,\n    to produce batches of timeseries inputs and targets.\n\n    Args:\n        data: Numpy array or eager tensor\n            containing consecutive data points (timesteps).\n            Axis 0 is expected to be the time dimension.\n        targets: Targets corresponding to timesteps in `data`.\n            `targets[i]` should be the target\n            corresponding to the window that starts at index `i`\n            (see example 2 below).\n            Pass `None` if you don't have target data (in this case the dataset\n            will only yield the input data).\n        sequence_length: Length of the output sequences\n            (in number of timesteps).\n        sequence_stride: Period between successive output sequences.\n            For stride `s`, output samples would\n            start at index `data[i]`, `data[i + s]`, `data[i + 2 * s]`, etc.\n        sampling_rate: Period between successive individual timesteps\n            within sequences. For rate `r`, timesteps\n            `data[i], data[i + r], ... 
data[i + (sequence_length - 1) * r]`\n            are used for creating a sample sequence.\n        batch_size: Number of timeseries samples in each batch\n            (except maybe the last one). If `None`, the data will not be batched\n            (the dataset will yield individual samples).\n        shuffle: Whether to shuffle output samples,\n            or instead draw them in chronological order.\n        seed: Optional int; random seed for shuffling.\n        start_index: Optional int; data points earlier (exclusive)\n            than `start_index` will not be used\n            in the output sequences. This is useful to reserve part of the\n            data for test or validation.\n        end_index: Optional int; data points later (exclusive) than `end_index`\n            will not be used in the output sequences.\n            This is useful to reserve part of the data for test or validation.\n        format: Optional string; either `\"tf\"` (default) for a\n            `tf.data.Dataset`, or `\"grain\"` for a Grain dataset. Grain\n            datasets are framework-agnostic and work with JAX, PyTorch,\n            and NumPy backends without requiring TensorFlow.\n\n    Returns:\n\n    A dataset instance. If `format=\"tf\"`, returns a `tf.data.Dataset`.\n    If `format=\"grain\"`, returns a Grain `IterDataset` (or `MapDataset`\n    when `batch_size` is `None`).\n    If `targets` was passed, the dataset yields\n    tuple `(batch_of_sequences, batch_of_targets)`. If not, the dataset yields\n    only `batch_of_sequences`.\n\n    Example 1:\n\n    Consider indices `[0, 1, ... 
98]`.\n    With `sequence_length=10, sampling_rate=2, sequence_stride=3`,\n    `shuffle=False`, the dataset will yield batches of sequences\n    composed of the following indices:\n\n    ```\n    First sequence:  [0  2  4  6  8 10 12 14 16 18]\n    Second sequence: [3  5  7  9 11 13 15 17 19 21]\n    Third sequence:  [6  8 10 12 14 16 18 20 22 24]\n    ...\n    Last sequence:   [78 80 82 84 86 88 90 92 94 96]\n    ```\n\n    In this case the last 2 data points are discarded since no full sequence\n    can be generated to include them (the next sequence would have started\n    at index 81, and thus its last step would have gone over 98).\n\n    Example 2: Temporal regression.\n\n    Consider an array `data` of scalar values, of shape `(steps,)`.\n    To generate a dataset that uses the past 10\n    timesteps to predict the next timestep, you would use:\n\n    ```python\n    input_data = data[:-10]\n    targets = data[10:]\n    dataset = timeseries_dataset_from_array(\n        input_data, targets, sequence_length=10)\n    for batch in dataset:\n      inputs, targets = batch\n      assert np.array_equal(inputs[0], data[:10])  # First sequence: steps [0-9]\n      # Corresponding target: step 10\n      assert np.array_equal(targets[0], data[10])\n      break\n    ```\n\n    Example 3: Temporal regression for many-to-many architectures.\n\n    Consider two arrays of scalar values `X` and `Y`,\n    both of shape `(100,)`. The resulting dataset should consist of\n    samples with 20 timesteps each. 
The samples should not overlap.\n    To generate a dataset that uses the current timestamp\n    to predict the corresponding target timestep, you would use:\n\n    ```python\n    X = np.arange(100)\n    Y = X*2\n\n    sample_length = 20\n    input_dataset = timeseries_dataset_from_array(\n        X, None, sequence_length=sample_length, sequence_stride=sample_length)\n    target_dataset = timeseries_dataset_from_array(\n        Y, None, sequence_length=sample_length, sequence_stride=sample_length)\n\n    for batch in zip(input_dataset, target_dataset):\n        inputs, targets = batch\n        assert np.array_equal(inputs[0], X[:sample_length])\n\n        # second sample equals output timestamps 20-40\n        assert np.array_equal(targets[1], Y[sample_length:2*sample_length])\n        break\n    ```\n    \"\"\"\n    if start_index:\n        if start_index < 0:\n            raise ValueError(\n                \"`start_index` must be 0 or greater. Received: \"\n                f\"start_index={start_index}\"\n            )\n        if start_index >= len(data):\n            raise ValueError(\n                \"`start_index` must be lower than the length of the \"\n                f\"data. Received: start_index={start_index}, for data \"\n                f\"of length {len(data)}\"\n            )\n    if end_index:\n        if start_index and end_index <= start_index:\n            raise ValueError(\n                \"`end_index` must be higher than `start_index`. \"\n                f\"Received: start_index={start_index}, and \"\n                f\"end_index={end_index} \"\n            )\n        if end_index >= len(data):\n            raise ValueError(\n                \"`end_index` must be lower than the length of the \"\n                f\"data. Received: end_index={end_index}, for data of \"\n                f\"length {len(data)}\"\n            )\n        if end_index <= 0:\n            raise ValueError(\n                \"`end_index` must be higher than 0. 
\"\n                f\"Received: end_index={end_index}\"\n            )\n\n    # Validate strides\n    if sampling_rate <= 0:\n        raise ValueError(\n            \"`sampling_rate` must be higher than 0. Received: \"\n            f\"sampling_rate={sampling_rate}\"\n        )\n    if sampling_rate >= len(data):\n        raise ValueError(\n            \"`sampling_rate` must be lower than the length of the \"\n            f\"data. Received: sampling_rate={sampling_rate}, for data \"\n            f\"of length {len(data)}\"\n        )\n    if sequence_stride <= 0:\n        raise ValueError(\n            \"`sequence_stride` must be higher than 0. Received: \"\n            f\"sequence_stride={sequence_stride}\"\n        )\n    if sequence_stride >= len(data):\n        raise ValueError(\n            \"`sequence_stride` must be lower than the length of the \"\n            f\"data. Received: sequence_stride={sequence_stride}, for \"\n            f\"data of length {len(data)}\"\n        )\n\n    if format not in (\"tf\", \"grain\"):\n        raise ValueError(\n            '`format` should be either \"tf\" or \"grain\". 
'\n            f\"Received: format={format}\"\n        )\n\n    if format == \"tf\":\n        return _timeseries_dataset_tf(\n            data=data,\n            targets=targets,\n            sequence_length=sequence_length,\n            sequence_stride=sequence_stride,\n            sampling_rate=sampling_rate,\n            batch_size=batch_size,\n            shuffle=shuffle,\n            seed=seed,\n            start_index=start_index,\n            end_index=end_index,\n        )\n    else:\n        return _timeseries_dataset_grain(\n            data=data,\n            targets=targets,\n            sequence_length=sequence_length,\n            sequence_stride=sequence_stride,\n            sampling_rate=sampling_rate,\n            batch_size=batch_size,\n            shuffle=shuffle,\n            seed=seed,\n            start_index=start_index,\n            end_index=end_index,\n        )\n\n\ndef _timeseries_dataset_tf(\n    data,\n    targets,\n    sequence_length,\n    sequence_stride,\n    sampling_rate,\n    batch_size,\n    shuffle,\n    seed,\n    start_index,\n    end_index,\n):\n    if start_index is None:\n        start_index = 0\n    if end_index is None:\n        end_index = len(data)\n\n    # Determine the lowest dtype to store start positions (to lower memory\n    # usage).\n    num_seqs = end_index - start_index - (sequence_length - 1) * sampling_rate\n    if targets is not None:\n        num_seqs = min(num_seqs, len(targets))\n    if num_seqs < 2147483647:\n        index_dtype = \"int32\"\n    else:\n        index_dtype = \"int64\"\n\n    # Generate start positions\n    start_positions = np.arange(0, num_seqs, sequence_stride, dtype=index_dtype)\n    if shuffle:\n        if seed is None:\n            seed = np.random.randint(1e6)\n        rng = np.random.RandomState(seed)\n        rng.shuffle(start_positions)\n\n    sequence_length = tf.cast(sequence_length, dtype=index_dtype)\n    sampling_rate = tf.cast(sampling_rate, dtype=index_dtype)\n\n    
positions_ds = tf.data.Dataset.from_tensors(start_positions).repeat()\n\n    # For each initial window position, generates indices of the window elements\n    indices = tf.data.Dataset.zip(\n        (tf.data.Dataset.range(len(start_positions)), positions_ds)\n    ).map(\n        lambda i, positions: tf.range(\n            positions[i],\n            positions[i] + sequence_length * sampling_rate,\n            sampling_rate,\n        ),\n        num_parallel_calls=tf.data.AUTOTUNE,\n    )\n\n    dataset = sequences_from_indices(data, indices, start_index, end_index)\n    if targets is not None:\n        indices = tf.data.Dataset.zip(\n            (tf.data.Dataset.range(len(start_positions)), positions_ds)\n        ).map(\n            lambda i, positions: positions[i],\n            num_parallel_calls=tf.data.AUTOTUNE,\n        )\n        target_ds = sequences_from_indices(\n            targets, indices, start_index, end_index\n        )\n        dataset = tf.data.Dataset.zip((dataset, target_ds))\n    dataset = dataset.prefetch(tf.data.AUTOTUNE)\n    if batch_size is not None:\n        if shuffle:\n            # Shuffle locally at each iteration\n            dataset = dataset.shuffle(buffer_size=batch_size * 8, seed=seed)\n        dataset = dataset.batch(batch_size)\n    else:\n        if shuffle:\n            dataset = dataset.shuffle(buffer_size=1024, seed=seed)\n    return dataset\n\n\ndef _timeseries_dataset_grain(\n    data,\n    targets,\n    sequence_length,\n    sequence_stride,\n    sampling_rate,\n    batch_size,\n    shuffle,\n    seed,\n    start_index,\n    end_index,\n):\n    if start_index is None:\n        start_index = 0\n    if end_index is None:\n        end_index = len(data)\n\n    # Compute number of sequences and start positions\n    num_seqs = end_index - start_index - (sequence_length - 1) * sampling_rate\n    if targets is not None:\n        num_seqs = min(num_seqs, len(targets))\n\n    start_positions = np.arange(0, num_seqs, 
sequence_stride)\n    if shuffle:\n        if seed is None:\n            seed = np.random.randint(1e6)\n        rng = np.random.RandomState(seed)\n        rng.shuffle(start_positions)\n\n    data_slice = np.array(data[start_index:end_index])\n\n    # Build the list of sequences as numpy arrays\n    sequences = []\n    for pos in start_positions:\n        indices = np.arange(\n            pos, pos + sequence_length * sampling_rate, sampling_rate\n        )\n        sequences.append(data_slice[indices])\n\n    seq_ds = grain.MapDataset.source(sequences)\n\n    if targets is not None:\n        target_slice = np.array(targets[start_index:])\n        target_values = [target_slice[pos] for pos in start_positions]\n        target_ds = grain.MapDataset.source(target_values)\n        ds = grain.experimental.ZipMapDataset([seq_ds, target_ds])\n    else:\n        ds = seq_ds\n\n    if shuffle:\n        ds = ds.shuffle(seed=seed)\n\n    ds = ds.to_iter_dataset()\n    if batch_size is not None:\n        ds = ds.batch(batch_size, batch_fn=make_batch)\n\n    return ds\n\n\ndef sequences_from_indices(array, indices_ds, start_index, end_index):\n    dataset = tf.data.Dataset.from_tensors(array[start_index:end_index])\n    dataset = tf.data.Dataset.zip((dataset.repeat(), indices_ds)).map(\n        lambda steps, inds: tf.gather(steps, inds),\n        num_parallel_calls=tf.data.AUTOTUNE,\n    )\n    return dataset\n"
  },
  {
    "path": "keras/src/utils/timeseries_dataset_utils_test.py",
    "content": "import numpy as np\nfrom absl.testing import parameterized\n\nfrom keras.src import testing\nfrom keras.src.utils import timeseries_dataset_utils\n\n\nclass TimeseriesDatasetTest(testing.TestCase):\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_basics(self, format):\n        # Test ordering, targets, sequence length, batch size\n        data = np.arange(100)\n        targets = data * 2\n        dataset = timeseries_dataset_utils.timeseries_dataset_from_array(\n            data, targets, sequence_length=9, batch_size=5, format=format\n        )\n        # Expect 19 batches\n        for i, batch in enumerate(dataset):\n            self.assertLen(batch, 2)\n            inputs, targets = batch\n            if i < 18:\n                self.assertEqual(inputs.shape, (5, 9))\n            if i == 18:\n                # Last batch: size 2\n                self.assertEqual(inputs.shape, (2, 9))\n            # Check target values\n            self.assertAllClose(targets, inputs[:, 0] * 2)\n            for j in range(min(5, len(inputs))):\n                # Check each sample in the batch\n                self.assertAllClose(\n                    inputs[j], np.arange(i * 5 + j, i * 5 + j + 9)\n                )\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_timeseries_regression(self, format):\n        # Test simple timeseries regression use case\n        data = np.arange(10)\n        offset = 3\n        targets = data[offset:]\n        dataset = timeseries_dataset_utils.timeseries_dataset_from_array(\n            data, targets, sequence_length=offset, batch_size=1, format=format\n        )\n        i = 0\n        for batch in dataset:\n            self.assertLen(batch, 2)\n            inputs, targets = batch\n            self.assertEqual(inputs.shape, (1, 3))\n            # Check values\n            
self.assertAllClose(targets[0], data[offset + i])\n            self.assertAllClose(inputs[0], data[i : i + offset])\n            i += 1\n        self.assertEqual(i, 7)  # Expect 7 batches\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_no_targets(self, format):\n        data = np.arange(50)\n        dataset = timeseries_dataset_utils.timeseries_dataset_from_array(\n            data, None, sequence_length=10, batch_size=5, format=format\n        )\n        # Expect 9 batches\n        i = None\n        for i, batch in enumerate(dataset):\n            if i < 8:\n                self.assertEqual(batch.shape, (5, 10))\n            elif i == 8:\n                self.assertEqual(batch.shape, (1, 10))\n            for j in range(min(5, len(batch))):\n                # Check each sample in the batch\n                self.assertAllClose(\n                    batch[j], np.arange(i * 5 + j, i * 5 + j + 10)\n                )\n        self.assertEqual(i, 8)\n\n    def test_shuffle_tf(self):\n        # Test cross-epoch random order and seed determinism (TF-specific\n        # since grain does not support .take())\n        data = np.arange(10)\n        targets = data * 2\n        dataset = timeseries_dataset_utils.timeseries_dataset_from_array(\n            data,\n            targets,\n            sequence_length=5,\n            batch_size=1,\n            shuffle=True,\n            seed=123,\n            format=\"tf\",\n        )\n        first_seq = None\n        for x, y in dataset.take(1):\n            self.assertNotAllClose(x, np.arange(0, 5))\n            self.assertAllClose(x[:, 0] * 2, y)\n            first_seq = x\n        # Check that a new iteration with the same dataset yields different\n        # results\n        for x, _ in dataset.take(1):\n            self.assertNotAllClose(x, first_seq)\n        # Check determinism with same seed\n        dataset = 
timeseries_dataset_utils.timeseries_dataset_from_array(\n            data,\n            targets,\n            sequence_length=5,\n            batch_size=1,\n            shuffle=True,\n            seed=123,\n            format=\"tf\",\n        )\n        for x, _ in dataset.take(1):\n            self.assertAllClose(x, first_seq)\n\n    def test_shuffle_grain(self):\n        # Test shuffle with grain format\n        data = np.arange(20)\n        targets = data * 2\n        dataset = timeseries_dataset_utils.timeseries_dataset_from_array(\n            data,\n            targets,\n            sequence_length=5,\n            batch_size=1,\n            shuffle=True,\n            seed=123,\n            format=\"grain\",\n        )\n        batch = next(iter(dataset))\n        self.assertLen(batch, 2)\n        inputs, targets = batch\n        # Verify that inputs and targets are consistent\n        self.assertAllClose(inputs[:, 0] * 2, targets)\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_sampling_rate(self, format):\n        data = np.arange(100)\n        targets = data * 2\n        dataset = timeseries_dataset_utils.timeseries_dataset_from_array(\n            data,\n            targets,\n            sequence_length=9,\n            batch_size=5,\n            sampling_rate=2,\n            format=format,\n        )\n        for i, batch in enumerate(dataset):\n            self.assertLen(batch, 2)\n            inputs, targets = batch\n            if i < 16:\n                self.assertEqual(inputs.shape, (5, 9))\n            if i == 16:\n                # Last batch: size 4\n                self.assertEqual(inputs.shape, (4, 9))\n            # Check target values\n            self.assertAllClose(inputs[:, 0] * 2, targets)\n            for j in range(min(5, len(inputs))):\n                # Check each sample in the batch\n                start_index = i * 5 + j\n                end_index = start_index 
+ 9 * 2\n                self.assertAllClose(\n                    inputs[j], np.arange(start_index, end_index, 2)\n                )\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_sequence_stride(self, format):\n        data = np.arange(100)\n        targets = data * 2\n        dataset = timeseries_dataset_utils.timeseries_dataset_from_array(\n            data,\n            targets,\n            sequence_length=9,\n            batch_size=5,\n            sequence_stride=3,\n            format=format,\n        )\n        for i, batch in enumerate(dataset):\n            self.assertLen(batch, 2)\n            inputs, targets = batch\n            if i < 6:\n                self.assertEqual(inputs.shape, (5, 9))\n            if i == 6:\n                # Last batch: size 1\n                self.assertEqual(inputs.shape, (1, 9))\n            # Check target values\n            self.assertAllClose(inputs[:, 0] * 2, targets)\n            for j in range(min(5, len(inputs))):\n                # Check each sample in the batch\n                start_index = i * 5 * 3 + j * 3\n                end_index = start_index + 9\n                self.assertAllClose(\n                    inputs[j], np.arange(start_index, end_index)\n                )\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_start_and_end_index(self, format):\n        data = np.arange(100)\n        dataset = timeseries_dataset_utils.timeseries_dataset_from_array(\n            data,\n            None,\n            sequence_length=9,\n            batch_size=5,\n            sequence_stride=3,\n            sampling_rate=2,\n            start_index=10,\n            end_index=90,\n            format=format,\n        )\n        for batch in dataset:\n            self.assertLess(np.max(np.array(batch[0])), 90)\n            self.assertGreater(np.min(np.array(batch[0])), 9)\n\n    
def test_errors(self):\n        # bad start index\n        with self.assertRaisesRegex(ValueError, \"`start_index` must be \"):\n            _ = timeseries_dataset_utils.timeseries_dataset_from_array(\n                np.arange(10), None, 3, start_index=-1\n            )\n        with self.assertRaisesRegex(ValueError, \"`start_index` must be \"):\n            _ = timeseries_dataset_utils.timeseries_dataset_from_array(\n                np.arange(10), None, 3, start_index=11\n            )\n        # bad end index\n        with self.assertRaisesRegex(ValueError, \"`end_index` must be \"):\n            _ = timeseries_dataset_utils.timeseries_dataset_from_array(\n                np.arange(10), None, 3, end_index=-1\n            )\n        with self.assertRaisesRegex(ValueError, \"`end_index` must be \"):\n            _ = timeseries_dataset_utils.timeseries_dataset_from_array(\n                np.arange(10), None, 3, end_index=11\n            )\n        # bad sampling_rate\n        with self.assertRaisesRegex(ValueError, \"`sampling_rate` must be \"):\n            _ = timeseries_dataset_utils.timeseries_dataset_from_array(\n                np.arange(10), None, 3, sampling_rate=0\n            )\n        # bad sequence stride\n        with self.assertRaisesRegex(ValueError, \"`sequence_stride` must be \"):\n            _ = timeseries_dataset_utils.timeseries_dataset_from_array(\n                np.arange(10), None, 3, sequence_stride=0\n            )\n\n    def test_invalid_format(self):\n        with self.assertRaisesRegex(ValueError, \"`format` should be\"):\n            _ = timeseries_dataset_utils.timeseries_dataset_from_array(\n                np.arange(10), None, 3, format=\"invalid\"\n            )\n\n    @parameterized.named_parameters(\n        (\"tf\", \"tf\"),\n        (\"grain\", \"grain\"),\n    )\n    def test_not_batched(self, format):\n        data = np.arange(100)\n\n        dataset = timeseries_dataset_utils.timeseries_dataset_from_array(\n            
data,\n            None,\n            sequence_length=9,\n            batch_size=None,\n            shuffle=True,\n            format=format,\n        )\n        sample = next(iter(dataset))\n        self.assertEqual(len(sample.shape), 1)\n"
  },
  {
    "path": "keras/src/utils/torch_utils.py",
"content": "import base64\nimport io\n\nfrom packaging.version import parse\n\nfrom keras.src import backend\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers import Layer\nfrom keras.src.ops import convert_to_numpy\nfrom keras.src.ops import convert_to_tensor\nfrom keras.src.saving.serialization_lib import in_safe_mode\n\n\n@keras_export(\"keras.layers.TorchModuleWrapper\")\nclass TorchModuleWrapper(Layer):\n    \"\"\"Torch module wrapper layer.\n\n    `TorchModuleWrapper` is a wrapper class that can turn any\n    `torch.nn.Module` into a Keras layer, in particular by making its\n    parameters trackable by Keras.\n\n    `TorchModuleWrapper` is only compatible with the PyTorch backend and\n    cannot be used with the TensorFlow or JAX backends.\n\n    Args:\n        module: `torch.nn.Module` instance. If it's a `LazyModule`\n            instance, then its parameters must be initialized before\n            passing the instance to `TorchModuleWrapper` (e.g. by calling\n            it once).\n        output_shape: The shape of the output of this layer. 
It helps Keras\n            perform automatic shape inference.\n        name: The name of the layer (string).\n\n    Example:\n\n    Here's an example of how the `TorchModuleWrapper` can be used with vanilla\n    PyTorch modules.\n\n    ```python\n    import torch\n    import torch.nn as nn\n    import torch.nn.functional as F\n\n    import keras\n    from keras.layers import TorchModuleWrapper\n\n    class Classifier(keras.Model):\n        def __init__(self, **kwargs):\n            super().__init__(**kwargs)\n            # Wrap `torch.nn.Module`s with `TorchModuleWrapper`\n            # if they contain parameters\n            self.conv1 = TorchModuleWrapper(\n                nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3))\n            )\n            self.conv2 = TorchModuleWrapper(\n                nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 3))\n            )\n            self.pool = nn.MaxPool2d(kernel_size=(2, 2))\n            self.flatten = nn.Flatten()\n            self.dropout = nn.Dropout(p=0.5)\n            self.fc = TorchModuleWrapper(nn.Linear(1600, 10))\n\n        def call(self, inputs):\n            x = F.relu(self.conv1(inputs))\n            x = self.pool(x)\n            x = F.relu(self.conv2(x))\n            x = self.pool(x)\n            x = self.flatten(x)\n            x = self.dropout(x)\n            x = self.fc(x)\n            return F.softmax(x, dim=1)\n\n\n    model = Classifier()\n    model.build((1, 28, 28))\n    print(\"Output shape:\", model(torch.ones(1, 1, 28, 28).to(\"cuda\")).shape)\n\n    model.compile(\n        loss=\"sparse_categorical_crossentropy\",\n        optimizer=\"adam\",\n        metrics=[\"accuracy\"]\n    )\n    model.fit(train_loader, epochs=5)\n    ```\n    \"\"\"\n\n    def __init__(self, module, name=None, output_shape=None, **kwargs):\n        super().__init__(name=name, **kwargs)\n        import torch.nn as nn\n\n        from keras.src.backend.torch.core import get_device\n\n        if (\n 
           isinstance(module, nn.modules.lazy.LazyModuleMixin)\n            and module.has_uninitialized_params()\n        ):\n            raise ValueError(\n                \"LazyModules are not supported unless they \"\n                \"are already initialized. \"\n                f\"Received uninitialized LazyModule: module={module}\"\n            )\n\n        self.module = module.to(get_device())\n        self._track_module_parameters()\n        self.output_shape = output_shape\n\n    def parameters(self, recurse=True):\n        return self.module.parameters(recurse=recurse)\n\n    def _track_module_parameters(self):\n        for param in self.module.parameters():\n            # The Variable will reuse the raw `param`\n            # and simply wrap it.\n            variable = backend.Variable(\n                initializer=param, trainable=param.requires_grad\n            )\n            self._track_variable(variable)\n        self.built = True\n\n    def call(self, *args, training=None, **kwargs):\n        if training is False:\n            self.eval()\n        else:\n            self.train()\n        return self.module(*args, **kwargs)\n\n    def save_own_variables(self, store):\n        \"\"\"Saves model's state from `state_dict`.\n        `model.parameters` excludes some of model's state like\n        `BatchNorm` mean and variance. 
So, use `state_dict` to obtain\n        all of model's state.\n        \"\"\"\n        state_dict = self.module.state_dict()\n        for key in state_dict.keys():\n            store[key] = convert_to_numpy(state_dict[key])\n\n    def load_own_variables(self, store):\n        \"\"\"Loads model's state via `state_dict`.\"\"\"\n        state_dict = {}\n        for key in store.keys():\n            if isinstance(key, bytes):\n                key = key.decode()\n            state_dict[key] = convert_to_tensor(store[key])\n        self.module.load_state_dict(state_dict)\n\n    def compute_output_shape(self, input_shape):\n        if self.output_shape is None:\n            return super().compute_output_shape(input_shape)\n        return self.output_shape\n\n    def get_config(self):\n        base_config = super().get_config()\n        import torch\n\n        buffer = io.BytesIO()\n        torch.save(self.module, buffer)\n        # Encode the buffer using base64 to ensure safe serialization\n        buffer_b64 = base64.b64encode(buffer.getvalue()).decode(\"ascii\")\n        config = {\n            \"module\": buffer_b64,\n            \"output_shape\": self.output_shape,\n        }\n        return {**base_config, **config}\n\n    @classmethod\n    def from_config(cls, config):\n        import torch\n\n        if \"module\" in config:\n            if in_safe_mode():\n                raise ValueError(\n                    \"Requested the deserialization of a `torch.nn.Module` \"\n                    \"object via `torch.load()`. This carries a potential risk \"\n                    \"of arbitrary code execution and thus it is disallowed by \"\n                    \"default. 
If you trust the source of the artifact, you can \"\n                    \"override this error by passing `safe_mode=False` to the \"\n                    \"loading function, or calling \"\n                    \"`keras.config.enable_unsafe_deserialization()`.\"\n                )\n\n            # Decode the base64 string back to bytes\n            buffer_bytes = base64.b64decode(config[\"module\"].encode(\"ascii\"))\n            buffer = io.BytesIO(buffer_bytes)\n            config[\"module\"] = torch.load(buffer, weights_only=False)\n        return cls(**config)\n\n\ndef no_grad(orig_func):\n    import torch\n\n    if parse(torch.__version__) >= parse(\"2.1.0\"):\n        return torch.no_grad(orig_func)\n    else:\n        return orig_func\n"
  },
  {
    "path": "keras/src/utils/torch_utils_test.py",
    "content": "import os\n\nimport numpy as np\nimport pytest\nimport torch\nfrom absl.testing import parameterized\n\nimport keras\nfrom keras.src import backend\nfrom keras.src import layers\nfrom keras.src import models\nfrom keras.src import saving\nfrom keras.src import testing\nfrom keras.src.backend.torch.core import get_device\nfrom keras.src.utils.torch_utils import TorchModuleWrapper\n\n\nclass Classifier(models.Model):\n    def __init__(\n        self, use_batch_norm=False, num_torch_layers=1, *args, **kwargs\n    ):\n        super().__init__(*args, **kwargs)\n        self.use_batch_norm = use_batch_norm\n        self.num_torch_layers = num_torch_layers\n        self.torch_wrappers = []\n        for _ in range(num_torch_layers):\n            modules = [torch.nn.Linear(2, 2)]\n            if use_batch_norm:\n                modules.append(torch.nn.BatchNorm1d(2))\n            torch_model = torch.nn.Sequential(*modules)\n            self.torch_wrappers.append(TorchModuleWrapper(torch_model))\n        self.fc = layers.Dense(1)\n\n    def call(self, x, training=None):\n        for wrapper in self.torch_wrappers:\n            x = wrapper(x, training=training)\n        return self.fc(x)\n\n    def get_config(self):\n        config = super().get_config()\n        config[\"use_batch_norm\"] = self.use_batch_norm\n        config[\"num_torch_layers\"] = self.num_torch_layers\n        return config\n\n\nclass ClassifierWithNoSpecialCasing(models.Model):\n    def __init__(self, *args, **kwargs):\n        super().__init__(*args, **kwargs)\n        self.fc1 = torch.nn.Linear(2, 4)\n        self.bn1 = torch.nn.BatchNorm1d(4)\n        self.fc2 = torch.nn.Linear(4, 4)\n        self.fc3 = layers.Dense(2)\n\n    def call(self, x, training=None):\n        return self.fc3(self.fc2(self.bn1(self.fc1(x))))\n\n\n@pytest.mark.skipif(\n    backend.backend() != \"torch\", reason=\"Requires torch backend\"\n)\nclass TorchUtilsTest(testing.TestCase):\n    
@parameterized.parameters(\n        {\"use_batch_norm\": False, \"num_torch_layers\": 1},\n        {\"use_batch_norm\": True, \"num_torch_layers\": 1},\n    )\n    def test_basic_usage(self, use_batch_norm, num_torch_layers):\n        model = Classifier(use_batch_norm, num_torch_layers)\n        self.assertEqual(len(model.layers), 2)\n        # Linear - Weights, bias, BN - beta, gamma\n        torch_trainable_count = 0\n        for i, layer in zip(range(num_torch_layers), model.torch_wrappers):\n            layer_trainable_count = 2\n            if use_batch_norm:\n                layer_trainable_count += 2\n            self.assertEqual(\n                len(layer.trainable_weights), layer_trainable_count\n            )\n            torch_trainable_count += layer_trainable_count\n        model(np.random.random((3, 2)))\n        self.assertEqual(len(model.layers), 2 * num_torch_layers)\n        self.assertEqual(\n            len(model.trainable_weights), torch_trainable_count + 2\n        )\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        model.fit(np.random.random((3, 2)), np.random.random((3, 1)))\n\n    @parameterized.named_parameters(\n        (\n            \"explicit_torch_wrapper\",\n            Classifier,\n            {\"use_batch_norm\": True, \"num_torch_layers\": 1},\n        ),\n        (\"implicit_torch_wrapper\", ClassifierWithNoSpecialCasing, {}),\n    )\n    def test_training_args(self, cls, kwargs):\n        model = cls(**kwargs)\n        model(np.random.random((3, 2)), training=False)  # Eager call to build\n        ref_weights = model.get_weights()\n        ref_running_mean = backend.convert_to_numpy(\n            model.torch_wrappers[0].module[-1].running_mean\n            if cls is Classifier\n            else model.bn1.module.running_mean\n        )\n\n        # Test training=False doesn't affect model weights\n        model(np.random.random((3, 2)), training=False)\n        weights = model.get_weights()\n        for w, ref_w 
in zip(weights, ref_weights):\n            self.assertAllClose(w, ref_w)\n\n        # Test training=None affects BN's stats\n        model.set_weights(ref_weights)  # Restore previous weights\n        model(np.random.random((3, 2)))\n        running_mean = backend.convert_to_numpy(\n            model.torch_wrappers[0].module[-1].running_mean\n            if cls is Classifier\n            else model.bn1.module.running_mean\n        )\n        self.assertNotAllClose(running_mean, ref_running_mean)\n\n        # Test training=True affects BN's stats\n        model.set_weights(ref_weights)  # Restore previous weights\n        model(np.random.random((3, 2)), training=True)\n        running_mean = backend.convert_to_numpy(\n            model.torch_wrappers[0].module[-1].running_mean\n            if cls is Classifier\n            else model.bn1.module.running_mean\n        )\n        self.assertNotAllClose(running_mean, ref_running_mean)\n\n    def test_module_autowrapping(self):\n        model = ClassifierWithNoSpecialCasing()\n        self.assertIsInstance(model.fc1, TorchModuleWrapper)\n        self.assertIsInstance(model.bn1, TorchModuleWrapper)\n        self.assertIsInstance(model.fc2, TorchModuleWrapper)\n        self.assertFalse(isinstance(model.fc3, TorchModuleWrapper))\n        self.assertEqual(len(model.fc1.trainable_weights), 2)\n        self.assertEqual(len(model.bn1.trainable_weights), 2)\n        self.assertEqual(len(model.fc2.trainable_weights), 2)\n        model(np.random.random((3, 2)))\n        self.assertEqual(len(model.layers), 4)\n        self.assertEqual(len(model.fc3.trainable_weights), 2)\n        self.assertEqual(len(model.trainable_weights), 8)\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        model.fit(np.random.random((3, 2)), np.random.random((3, 2)))\n\n    def test_load_weights_autowrapping(self):\n        # Test loading weights\n        temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel.weights.h5\")\n        model = 
ClassifierWithNoSpecialCasing()\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        x, y = np.random.random((3, 2)), np.random.random((3, 1))\n        x_test, y_test = np.random.random((3, 2)), np.random.random((3, 1))\n        model.fit(x, y)\n        ref_loss = model.evaluate(x_test, y_test)\n        model.save_weights(temp_filepath)\n\n        new_model = ClassifierWithNoSpecialCasing()\n        new_model(np.random.random((3, 2)))\n        new_model.compile(optimizer=\"sgd\", loss=\"mse\")\n        new_model.load_weights(temp_filepath)\n        for ref_w, new_w in zip(model.get_weights(), new_model.get_weights()):\n            self.assertAllClose(ref_w, new_w, atol=1e-5)\n        loss = new_model.evaluate(x_test, y_test)\n        self.assertAllClose(ref_loss, loss, atol=1e-5)\n\n    def test_serialize_model_autowrapping(self):\n        # Test loading saved model\n        temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel.keras\")\n        model = ClassifierWithNoSpecialCasing()\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        x, y = np.random.random((3, 2)), np.random.random((3, 1))\n        x_test, y_test = np.random.random((3, 2)), np.random.random((3, 1))\n        model.fit(x, y)\n        ref_loss = model.evaluate(x_test, y_test)\n        model.save(temp_filepath)\n\n        new_model = saving.load_model(temp_filepath)\n        for ref_w, new_w in zip(model.get_weights(), new_model.get_weights()):\n            self.assertAllClose(ref_w, new_w, atol=1e-5)\n        loss = new_model.evaluate(x_test, y_test)\n        self.assertAllClose(ref_loss, loss, atol=1e-5)\n\n    @parameterized.parameters(\n        {\"use_batch_norm\": False, \"num_torch_layers\": 1},\n        {\"use_batch_norm\": True, \"num_torch_layers\": 1},\n        {\"use_batch_norm\": False, \"num_torch_layers\": 2},\n        {\"use_batch_norm\": True, \"num_torch_layers\": 2},\n    )\n    def test_load_weights(self, use_batch_norm, num_torch_layers):\n        # 
Test loading weights\n        temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel.weights.h5\")\n        model = Classifier(use_batch_norm, num_torch_layers)\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        x, y = np.random.random((3, 2)), np.random.random((3, 1))\n        x_test, y_test = np.random.random((3, 2)), np.random.random((3, 1))\n        model.fit(x, y)\n        ref_loss = model.evaluate(x_test, y_test)\n        model.save_weights(temp_filepath)\n\n        new_model = Classifier(use_batch_norm, num_torch_layers)\n        new_model(np.random.random((3, 2)))\n        new_model.compile(optimizer=\"sgd\", loss=\"mse\")\n        new_model.load_weights(temp_filepath)\n        for ref_w, new_w in zip(model.get_weights(), new_model.get_weights()):\n            self.assertAllClose(ref_w, new_w, atol=1e-5)\n        loss = new_model.evaluate(x_test, y_test)\n        self.assertAllClose(ref_loss, loss, atol=1e-5)\n\n    @parameterized.parameters(\n        {\"use_batch_norm\": False, \"num_torch_layers\": 1},\n        {\"use_batch_norm\": True, \"num_torch_layers\": 1},\n        {\"use_batch_norm\": False, \"num_torch_layers\": 2},\n        {\"use_batch_norm\": True, \"num_torch_layers\": 2},\n    )\n    def test_serialize_model(self, use_batch_norm, num_torch_layers):\n        # Test loading saved model\n        temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel.keras\")\n        model = Classifier(use_batch_norm, num_torch_layers)\n        model.compile(optimizer=\"sgd\", loss=\"mse\")\n        x, y = np.random.random((3, 2)), np.random.random((3, 1))\n        x_test, y_test = np.random.random((3, 2)), np.random.random((3, 1))\n        model.fit(x, y)\n        ref_loss = model.evaluate(x_test, y_test)\n        model.save(temp_filepath)\n\n        new_model = saving.load_model(temp_filepath)\n        for ref_w, new_w in zip(model.get_weights(), new_model.get_weights()):\n            self.assertAllClose(ref_w, new_w, atol=1e-5)\n      
  loss = new_model.evaluate(x_test, y_test)\n        self.assertAllClose(ref_loss, loss, atol=1e-5)\n\n    def test_from_config(self):\n        module = torch.nn.Sequential(torch.nn.Linear(2, 4))\n        mw = TorchModuleWrapper(module)\n        config = mw.get_config()\n        new_mw = TorchModuleWrapper.from_config(config)\n        for ref_w, new_w in zip(mw.get_weights(), new_mw.get_weights()):\n            self.assertAllClose(ref_w, new_w, atol=1e-5)\n\n    def test_build_model(self):\n        x = keras.Input([4])\n        z = TorchModuleWrapper(torch.nn.Linear(4, 8), output_shape=[None, 8])(x)\n        y = TorchModuleWrapper(torch.nn.Linear(8, 16), output_shape=[None, 16])(\n            z\n        )\n        model = keras.Model(x, y)\n        self.assertEqual(model.predict(np.zeros([5, 4])).shape, (5, 16))\n        self.assertEqual(model(np.zeros([5, 4])).shape, (5, 16))\n\n    @parameterized.named_parameters(\n        (\"safe_mode\", True),\n        (\"unsafe_mode\", False),\n    )\n    def test_save_load(self, safe_mode):\n        @keras.saving.register_keras_serializable()\n        class M(keras.Model):\n            def __init__(self, module, **kwargs):\n                super().__init__(**kwargs)\n                self.module = module\n\n            def call(self, x):\n                return self.module(x)\n\n            def get_config(self):\n                base_config = super().get_config()\n                config = {\"module\": self.module}\n                return {**base_config, **config}\n\n            @classmethod\n            def from_config(cls, config):\n                config[\"module\"] = saving.deserialize_keras_object(\n                    config[\"module\"]\n                )\n                return cls(**config)\n\n        m = M(torch.nn.Conv2d(1, 10, kernel_size=(3, 3)))\n        device = get_device()  # Get the current device (e.g., \"cuda\" or \"cpu\")\n        x = torch.ones(\n            (10, 1, 28, 28), device=device\n        )  # 
Place input on the correct device\n        ref_output = m(x)\n        temp_filepath = os.path.join(self.get_temp_dir(), \"mymodel.keras\")\n        m.save(temp_filepath)\n\n        if safe_mode:\n            with self.assertRaisesRegex(ValueError, \"arbitrary code execution\"):\n                saving.load_model(temp_filepath, safe_mode=safe_mode)\n        else:\n            new_model = saving.load_model(temp_filepath, safe_mode=safe_mode)\n            self.assertAllClose(new_model(x), ref_output)\n"
  },
  {
    "path": "keras/src/utils/traceback_utils.py",
    "content": "import inspect\nimport os\nimport traceback\nimport types\nfrom functools import wraps\n\nfrom keras.src import backend\nfrom keras.src import tree\nfrom keras.src.api_export import keras_export\nfrom keras.src.backend.common import global_state\n\n_EXCLUDED_PATHS = (\n    os.path.abspath(os.path.join(__file__, \"..\", \"..\")),\n    os.path.join(\"tensorflow\", \"python\"),\n)\n\n\n@keras_export(\"keras.config.enable_traceback_filtering\")\ndef enable_traceback_filtering():\n    \"\"\"Turn on traceback filtering.\n\n    Raw Keras tracebacks (also known as stack traces)\n    involve many internal frames, which can be\n    challenging to read through, while not being actionable for end users.\n    By default, Keras filters internal frames in most exceptions that it\n    raises, to keep traceback short, readable, and focused on what's\n    actionable for you (your own code).\n\n    See also `keras.config.disable_traceback_filtering()` and\n    `keras.config.is_traceback_filtering_enabled()`.\n\n    If you have previously disabled traceback filtering via\n    `keras.config.disable_traceback_filtering()`, you can re-enable it via\n    `keras.config.enable_traceback_filtering()`.\n    \"\"\"\n    global_state.set_global_attribute(\"traceback_filtering\", True)\n\n\n@keras_export(\"keras.config.disable_traceback_filtering\")\ndef disable_traceback_filtering():\n    \"\"\"Turn off traceback filtering.\n\n    Raw Keras tracebacks (also known as stack traces)\n    involve many internal frames, which can be\n    challenging to read through, while not being actionable for end users.\n    By default, Keras filters internal frames in most exceptions that it\n    raises, to keep traceback short, readable, and focused on what's\n    actionable for you (your own code).\n\n    See also `keras.config.enable_traceback_filtering()` and\n    `keras.config.is_traceback_filtering_enabled()`.\n\n    If you have previously disabled traceback filtering via\n    
`keras.config.disable_traceback_filtering()`, you can re-enable it via\n    `keras.config.enable_traceback_filtering()`.\n    \"\"\"\n    global_state.set_global_attribute(\"traceback_filtering\", False)\n\n\n@keras_export(\"keras.config.is_traceback_filtering_enabled\")\ndef is_traceback_filtering_enabled():\n    \"\"\"Check if traceback filtering is enabled.\n\n    Raw Keras tracebacks (also known as stack traces)\n    involve many internal frames, which can be\n    challenging to read through, while not being actionable for end users.\n    By default, Keras filters internal frames in most exceptions that it\n    raises, to keep traceback short, readable, and focused on what's\n    actionable for you (your own code).\n\n    See also `keras.config.enable_traceback_filtering()` and\n    `keras.config.disable_traceback_filtering()`.\n\n    If you have previously disabled traceback filtering via\n    `keras.config.disable_traceback_filtering()`, you can re-enable it via\n    `keras.config.enable_traceback_filtering()`.\n\n    Returns:\n        Boolean, `True` if traceback filtering is enabled,\n        and `False` otherwise.\n    \"\"\"\n    return global_state.get_global_attribute(\"traceback_filtering\", True)\n\n\ndef include_frame(fname):\n    for exclusion in _EXCLUDED_PATHS:\n        if exclusion in fname:\n            return False\n    return True\n\n\ndef _process_traceback_frames(tb):\n    \"\"\"Iterate through traceback frames and return a new, filtered traceback.\"\"\"\n    last_tb = None\n    tb_list = list(traceback.walk_tb(tb))\n    for f, line_no in reversed(tb_list):\n        if include_frame(f.f_code.co_filename):\n            last_tb = types.TracebackType(last_tb, f, f.f_lasti, line_no)\n    if last_tb is None and tb_list:\n        # If no frames were kept during filtering, create a new traceback\n        # from the outermost function.\n        f, line_no = tb_list[-1]\n        last_tb = types.TracebackType(last_tb, f, f.f_lasti, line_no)\n    
return last_tb\n\n\ndef filter_traceback(fn):\n    \"\"\"Filter out Keras-internal traceback frames in exceptions raised by fn.\"\"\"\n\n    @wraps(fn)\n    def error_handler(*args, **kwargs):\n        if not is_traceback_filtering_enabled():\n            return fn(*args, **kwargs)\n\n        filtered_tb = None\n        try:\n            return fn(*args, **kwargs)\n        except Exception as e:\n            filtered_tb = _process_traceback_frames(e.__traceback__)\n            # To get the full stack trace, call:\n            # `keras.config.disable_traceback_filtering()`\n            raise e.with_traceback(filtered_tb) from None\n        finally:\n            del filtered_tb\n\n    return error_handler\n\n\ndef inject_argument_info_in_traceback(fn, object_name=None):\n    \"\"\"Add information about call argument values to an error message.\n\n    Arguments:\n        fn: Function to wrap. Exceptions raised by this function will be\n            re-raised with additional information added to the error message,\n            displaying the values of the different arguments that the function\n            was called with.\n        object_name: String, display name of the class/function being called,\n            e.g. 
`'layer \"layer_name\" (LayerClass)'`.\n\n    Returns:\n        A wrapped version of `fn`.\n    \"\"\"\n    if backend.backend() == \"tensorflow\":\n        from tensorflow import errors as tf_errors\n    else:\n        tf_errors = None\n\n    @wraps(fn)\n    def error_handler(*args, **kwargs):\n        if not is_traceback_filtering_enabled():\n            return fn(*args, **kwargs)\n\n        signature = None\n        bound_signature = None\n        try:\n            return fn(*args, **kwargs)\n        except Exception as e:\n            if hasattr(e, \"_keras_call_info_injected\"):\n                # Only inject info for the innermost failing call\n                raise e\n            signature = inspect.signature(fn)\n            try:\n                # Bind the call arguments to the signature so their values\n                # can be displayed in the error message.\n                bound_signature = signature.bind(*args, **kwargs)\n            except TypeError:\n                # Likely unbindable arguments\n                raise e\n\n            # Add argument context\n            arguments_context = []\n            for arg in list(signature.parameters.values()):\n                if arg.name in bound_signature.arguments:\n                    value = tree.map_structure(\n                        format_argument_value,\n                        bound_signature.arguments[arg.name],\n                    )\n                else:\n                    value = arg.default\n                arguments_context.append(f\"  • {arg.name}={value}\")\n            if arguments_context:\n                arguments_context = \"\\n\".join(arguments_context)\n                # Get original error message and append information to it.\n                if tf_errors is not None and isinstance(e, tf_errors.OpError):\n                    message = e.message\n                elif e.args:\n                    # Canonically, the 1st argument in an exception is the error\n                    # message. 
This works for all built-in Python exceptions.\n                    message = e.args[0]\n                else:\n                    message = \"\"\n                display_name = f\"{object_name if object_name else fn.__name__}\"\n                message = (\n                    f\"Exception encountered when calling {display_name}.\\n\\n\"\n                    f\"\\x1b[1m{message}\\x1b[0m\\n\\n\"\n                    f\"Arguments received by {display_name}:\\n\"\n                    f\"{arguments_context}\"\n                )\n\n                # Reraise exception, with added context\n                if tf_errors is not None and isinstance(e, tf_errors.OpError):\n                    new_e = e.__class__(e.node_def, e.op, message, e.error_code)\n                else:\n                    try:\n                        # For standard exceptions such as ValueError, TypeError,\n                        # etc.\n                        new_e = e.__class__(message)\n                    except TypeError:\n                        # For any custom error that doesn't have a standard\n                        # signature.\n                        new_e = RuntimeError(message)\n                new_e._keras_call_info_injected = True\n            else:\n                new_e = e\n            raise new_e.with_traceback(e.__traceback__) from None\n        finally:\n            del signature\n            del bound_signature\n\n    return error_handler\n\n\ndef format_argument_value(value):\n    if backend.is_tensor(value):\n        # Simplified representation for eager / graph tensors\n        # to keep messages readable\n        if backend.backend() == \"tensorflow\":\n            tensor_cls = \"tf.Tensor\"\n        elif backend.backend() == \"jax\":\n            tensor_cls = \"jnp.ndarray\"\n        elif backend.backend() == \"torch\":\n            tensor_cls = \"torch.Tensor\"\n        elif backend.backend() == \"numpy\":\n            tensor_cls = \"np.ndarray\"\n        else:\n      
      tensor_cls = \"array\"\n\n        return (\n            f\"{tensor_cls}(shape={value.shape}, \"\n            f\"dtype={backend.standardize_dtype(value.dtype)})\"\n        )\n    return repr(value)\n"
  },
  {
    "path": "keras/src/utils/tracking.py",
    "content": "from collections import OrderedDict\nfrom functools import wraps\n\nfrom keras.src import tree\nfrom keras.src.backend.common.global_state import get_global_attribute\nfrom keras.src.backend.common.global_state import set_global_attribute\nfrom keras.src.utils import python_utils\n\n\nclass DotNotTrackScope:\n    def __enter__(self):\n        self.original_value = is_tracking_enabled()\n        set_global_attribute(\"tracking_on\", False)\n\n    def __exit__(self, *args, **kwargs):\n        set_global_attribute(\"tracking_on\", self.original_value)\n\n\ndef is_tracking_enabled():\n    return get_global_attribute(\"tracking_on\", True)\n\n\ndef no_automatic_dependency_tracking(fn):\n    @wraps(fn)\n    def wrapper(*args, **kwargs):\n        with DotNotTrackScope():\n            return fn(*args, **kwargs)\n\n    return wrapper\n\n\nclass Tracker:\n    \"\"\"Attribute tracker, used for e.g. Variable tracking.\n\n    Monitors certain attribute types and places matching\n    objects into user provided tracking collections.\n\n    Also passively tracks certain mutable collections\n    (e.g. dict and list) ensuring that items added after\n    initialization are still tracked. 
This is done by wrapping\n    these collections in tracking-aware proxy objects.\n\n    Example:\n\n    ```python\n    def __init__(self):\n        self._tracker = Tracker(\n            # Format: `name: (test_fn, store)`\n            {\n                \"variables\":\n                    (lambda x: isinstance(x, Variable), self._variables),\n                \"metrics\": (lambda x: isinstance(x, Metric), self._metrics),\n                \"layers\": (lambda x: isinstance(x, Layer), self._layers),\n            }\n        )\n\n    def __setattr__(self, name, value):\n        if hasattr(self, \"_tracker\"):\n            value = self._tracker.track(value)\n        return super().__setattr__(name, value)\n    ```\n    \"\"\"\n\n    def __init__(self, config, exclusions=None):\n        self.config = config\n        self.stored_ids = {name: set() for name in self.config.keys()}\n        self.locked = False\n        self._lock_violation_msg = None\n        self.exclusions = exclusions or {}\n\n    def track(self, attr):\n        if not is_tracking_enabled():\n            return attr\n\n        for store_name, (is_attr_type, _) in self.config.items():\n            if is_attr_type(attr):\n                if store_name in self.exclusions:\n                    for excl in self.exclusions[store_name]:\n                        if self.is_in_store(excl, attr):\n                            return attr\n                if not self.is_in_store(store_name, attr):\n                    self.add_to_store(store_name, attr)\n                return attr\n        if isinstance(attr, tuple) and hasattr(attr, \"_fields\"):\n            # Named tuple case.\n            wrapped_attr = {}\n            for name, e in attr._asdict().items():\n                wrapped_attr[name] = self.track(e)\n            return attr.__class__(**wrapped_attr)\n        if isinstance(attr, tuple):\n            wrapped_attr = []\n            for e in attr:\n                wrapped_attr.append(self.track(e))\n            
return attr.__class__(wrapped_attr)\n        elif isinstance(attr, list):\n            return TrackedList(attr, self)\n        elif isinstance(attr, OrderedDict):\n            return TrackedOrderedDict(attr, self)\n        elif isinstance(attr, dict):\n            return TrackedDict(attr, self)\n        elif isinstance(attr, set):\n            return TrackedSet(attr, self)\n        return attr\n\n    def untrack(self, value):\n        for store_name in self.stored_ids.keys():\n            if id(value) in self.stored_ids[store_name]:\n                self.stored_ids[store_name].remove(id(value))\n                python_utils.remove_by_id(self.config[store_name][1], value)\n\n    def lock(self, msg=None):\n        self.locked = True\n        if msg is not None:\n            self._lock_violation_msg = msg\n\n    def unlock(self):\n        self.locked = False\n\n    def add_to_store(self, store_name, value):\n        if self.locked:\n            raise ValueError(self._lock_violation_msg)\n        self.config[store_name][1].append(value)\n        self.stored_ids[store_name].add(id(value))\n\n    def is_in_store(self, store_name, value):\n        return id(value) in self.stored_ids[store_name]\n\n    def replace_tracked_value(self, store_name, old_value, new_value):\n        if not self.is_in_store(store_name, old_value):\n            raise ValueError(f\"Unknown value: {old_value}\")\n        store_list = self.config[store_name][1]\n        index = store_list.index(old_value)\n        store_list[index] = new_value\n        self.stored_ids[store_name].remove(id(old_value))\n        self.stored_ids[store_name].add(id(new_value))\n\n\n@tree.register_tree_node_class\nclass TrackedList(list):\n    def __init__(self, values=None, tracker=None):\n        self.tracker = tracker\n        if tracker and values:\n            values = [tracker.track(v) for v in values]\n        super().__init__(values or [])\n\n    def append(self, value):\n        if self.tracker:\n            
self.tracker.track(value)\n        super().append(value)\n\n    def insert(self, index, value):\n        if self.tracker:\n            self.tracker.track(value)\n        super().insert(index, value)\n\n    def extend(self, values):\n        if self.tracker:\n            values = [self.tracker.track(v) for v in values]\n        super().extend(values)\n\n    def remove(self, value):\n        if self.tracker:\n            self.tracker.untrack(value)\n        try:\n            super().remove(value)\n        except ValueError:\n            python_utils.remove_by_id(self, value)\n\n    def pop(self, index=-1):\n        if self.tracker:\n            value = self[index]\n            self.tracker.untrack(value)\n            return super().pop(index)\n        else:\n            return super().pop(index)\n\n    def clear(self):\n        if self.tracker:\n            for value in self:\n                self.tracker.untrack(value)\n        super().clear()\n\n    def __delitem__(self, index):\n        value = self[index]  # Get value before removing\n        super().__delitem__(index)\n        if self.tracker:\n            self.tracker.untrack(value)\n\n    def tree_flatten(self):\n        # For optree / dmtree\n        return (self, None)\n\n    @classmethod\n    def tree_unflatten(cls, metadata, children):\n        # For optree / dmtree\n        return cls(children)\n\n    def torchtree_flatten(self):\n        # For torchtree\n        # Returns (values, metadata)\n        return (self, None)\n\n    @classmethod\n    def torchtree_unflatten(cls, children, metadata):\n        # For torchtree\n        # Requires (children, metadata)\n        return cls(children)\n\n    def torchtree_flatten_with_keys(self):\n        # For torchtree\n        # Returns (children, metadata)\n        from torch.utils import _pytree as torch_tree\n\n        values, context = self.torchtree_flatten()\n        return [\n            (torch_tree.SequenceKey(i), v) for i, v in enumerate(values)\n        ], 
context\n\n\n@tree.register_tree_node_class\nclass TrackedDict(dict):\n    def __init__(self, values=None, tracker=None):\n        self.tracker = tracker\n        if tracker and values:\n            # Accept either a mapping (with .items()) or an iterable of\n            # (key, value) pairs (e.g. a zip object). Normalize to an\n            # items iterator before tracking elements.\n            if hasattr(values, \"items\"):\n                items_iter = values.items()\n            else:\n                items_iter = values\n            values = {k: tracker.track(v) for k, v in items_iter}\n        super().__init__(values or {})\n\n    def __setitem__(self, key, value):\n        if self.tracker:\n            self.tracker.track(value)\n        super().__setitem__(key, value)\n\n    def update(self, mapping):\n        if self.tracker:\n            mapping = {k: self.tracker.track(v) for k, v in mapping.items()}\n        super().update(mapping)\n\n    def pop(self, key, *args):\n        if len(args) > 1:\n            raise TypeError(\n                f\"pop expected at most 2 arguments, got {1 + len(args)}\"\n            )\n\n        if not self.tracker:\n            return super().pop(key, *args)\n\n        try:\n            value = super().pop(key)\n            self.tracker.untrack(value)\n            return value\n        except KeyError:\n            if args:\n                return args[0]\n            raise\n\n    def popitem(self):\n        key, value = super().popitem()\n        if self.tracker:\n            self.tracker.untrack(value)\n        return key, value\n\n    def clear(self):\n        if self.tracker:\n            for value in self.values():\n                self.tracker.untrack(value)\n        super().clear()\n\n    def tree_flatten(self):\n        # For optree / dmtree\n        keys = sorted(list(self.keys()))\n        values = [self[k] for k in keys]\n        return values, keys, keys\n\n    @classmethod\n    def tree_unflatten(cls, keys, 
values):\n        # For optree / dmtree\n        return cls(zip(keys, values))\n\n    def torchtree_flatten(self):\n        # For torch_tree\n        # Returns (values, metadata)\n        keys = sorted(list(self.keys()))\n        values = [self[k] for k in keys]\n        return values, keys\n\n    @classmethod\n    def torchtree_unflatten(cls, values, keys):\n        # For torch_tree\n        # Requires (children, metadata)\n        return cls(zip(keys, values))\n\n    def torchtree_flatten_with_keys(self):\n        # For torchtree\n        # Returns (children, metadata)\n        from torch.utils import _pytree as torch_tree\n\n        values, context = self.torchtree_flatten()\n        return [\n            (torch_tree.MappingKey(k), v) for k, v in zip(context, values)\n        ], context\n\n\n@tree.register_tree_node_class\nclass TrackedOrderedDict(OrderedDict):\n    def __init__(self, values=None, tracker=None):\n        self.tracker = tracker\n        if tracker and values:\n            if hasattr(values, \"items\"):\n                items_iter = values.items()\n            else:\n                items_iter = dict(values).items()\n            values = OrderedDict((k, tracker.track(v)) for k, v in items_iter)\n        super().__init__(values or OrderedDict())\n\n    def __setitem__(self, key, value):\n        if self.tracker:\n            self.tracker.track(value)\n        super().__setitem__(key, value)\n\n    def update(self, mapping):\n        if self.tracker:\n            mapping = OrderedDict(\n                (k, self.tracker.track(v)) for k, v in mapping.items()\n            )\n        super().update(mapping)\n\n    def pop(self, key, *args):\n        if len(args) > 1:\n            raise TypeError(\n                f\"pop expected at most 2 arguments, got {1 + len(args)}\"\n            )\n        if not self.tracker:\n            return super().pop(key, *args)\n        try:\n            value = super().pop(key)\n            self.tracker.untrack(value)\n   
         return value\n        except KeyError:\n            if args:\n                return args[0]\n            raise\n\n    def popitem(self, last=True):\n        key, value = super().popitem(last=last)\n        if self.tracker:\n            self.tracker.untrack(value)\n        return key, value\n\n    def clear(self):\n        if self.tracker:\n            for value in self.values():\n                self.tracker.untrack(value)\n        super().clear()\n\n    def tree_flatten(self):\n        keys = list(self.keys())\n        values = [self[k] for k in keys]\n        return values, keys, keys\n\n    @classmethod\n    def tree_unflatten(cls, keys, values):\n        return cls(zip(keys, values))\n\n    def torchtree_flatten(self):\n        keys = list(self.keys())\n        values = [self[k] for k in keys]\n        return values, keys\n\n    @classmethod\n    def torchtree_unflatten(cls, values, keys):\n        return cls(zip(keys, values))\n\n    def torchtree_flatten_with_keys(self):\n        from torch.utils import _pytree as torch_tree\n\n        values, context = self.torchtree_flatten()\n        return [\n            (torch_tree.MappingKey(k), v) for k, v in zip(context, values)\n        ], context\n\n\n@tree.register_tree_node_class\nclass TrackedSet(set):\n    def __init__(self, values=None, tracker=None):\n        self.tracker = tracker\n        if tracker and values:\n            values = {tracker.track(v) for v in values}\n        super().__init__(values or [])\n\n    def add(self, value):\n        if self.tracker:\n            self.tracker.track(value)\n        super().add(value)\n\n    def update(self, values):\n        if self.tracker:\n            values = [self.tracker.track(v) for v in values]\n        super().update(values)\n\n    def remove(self, value):\n        if self.tracker:\n            self.tracker.untrack(value)\n        super().remove(value)\n\n    def pop(self):\n        value = super().pop()\n        if self.tracker:\n            
self.tracker.untrack(value)\n        return value\n\n    def clear(self):\n        if self.tracker:\n            for value in self:\n                self.tracker.untrack(value)\n        super().clear()\n\n    def tree_flatten(self):\n        # For optree / dmtree\n        return (self, None)\n\n    @classmethod\n    def tree_unflatten(cls, metadata, children):\n        # For optree / dmtree\n        return cls(children)\n\n    def torchtree_flatten(self):\n        # For torchtree\n        # Returns (values, metadata)\n        return (self, None)\n\n    @classmethod\n    def torchtree_unflatten(cls, children, metadata):\n        # For torchtree\n        # Requires (values, metadata)\n        return cls(children)\n\n    def torchtree_flatten_with_keys(self):\n        # For torchtree\n        # Returns (children, metadata)\n        from torch.utils import _pytree as torch_tree\n\n        values, context = self.torchtree_flatten()\n        return [\n            (torch_tree.SequenceKey(i), v) for i, v in enumerate(values)\n        ], context\n"
  },
  {
    "path": "keras/src/utils/tracking_test.py",
    "content": "import collections\n\nfrom keras.src import backend\nfrom keras.src import testing\nfrom keras.src.utils import tracking\n\n\nclass TrackingTest(testing.TestCase):\n    def test_untracking_in_tracked_list(self):\n        tracked_variables = []\n        tracker = tracking.Tracker(\n            {\n                \"variables\": (\n                    lambda x: isinstance(x, backend.Variable),\n                    tracked_variables,\n                ),\n            }\n        )\n        v1 = backend.Variable(1.0)\n        v2 = backend.Variable(2.0)\n        lst = tracking.TrackedList([], tracker)\n        lst.append(v1)\n        lst.append(float(\"nan\"))\n        lst.append(v2)\n        lst.append(0)\n\n        self.assertLen(tracked_variables, 2)\n        self.assertEqual(tracked_variables[0], v1)\n        self.assertEqual(tracked_variables[1], v2)\n\n        lst.remove(v1)\n        self.assertLen(lst, 3)\n        self.assertLen(tracked_variables, 1)\n\n        lst.remove(v2)\n        self.assertLen(lst, 2)\n        self.assertLen(tracked_variables, 0)\n\n        lst2 = tracking.TrackedList([], tracker)\n        lst2.append(v1)\n        lst2.append(float(\"nan\"))\n        lst2.append(v2)\n        lst2.append(0)\n\n        popped_value = lst2.pop()\n        self.assertEqual(popped_value, 0)\n        self.assertLen(lst2, 3)\n        self.assertLen(tracked_variables, 2)\n\n        lst2.clear()\n        self.assertLen(lst2, 0)\n        self.assertLen(tracked_variables, 0)\n\n        lst2.append(v1)\n        lst2.append(v2)\n        del lst2[0]\n        self.assertLen(lst2, 1)\n        self.assertLen(tracked_variables, 1)\n\n    def test_tuple_tracking(self):\n        tracked_variables = []\n        tracker = tracking.Tracker(\n            {\n                \"variables\": (\n                    lambda x: isinstance(x, backend.Variable),\n                    tracked_variables,\n                ),\n            }\n        )\n        v1 = 
backend.Variable(1.0)\n        v2 = backend.Variable(2.0)\n        tup = (v1, v2)\n        tup = tracker.track(tup)\n        self.assertIsInstance(tup, tuple)\n        self.assertLen(tracked_variables, 2)\n        self.assertEqual(tracked_variables[0], v1)\n        self.assertEqual(tracked_variables[1], v2)\n\n    def test_namedtuple_tracking(self):\n        tracked_variables = []\n        tracker = tracking.Tracker(\n            {\n                \"variables\": (\n                    lambda x: isinstance(x, backend.Variable),\n                    tracked_variables,\n                ),\n            }\n        )\n        v1 = backend.Variable(1.0)\n        v2 = backend.Variable(2.0)\n        nt = collections.namedtuple(\"NT\", [\"x\", \"y\"])\n        tup = nt(x=v1, y=v2)\n        tup = tracker.track(tup)\n        self.assertIsInstance(tup, tuple)\n        self.assertEqual(tup.x, v1)\n        self.assertEqual(tup.y, v2)\n        self.assertLen(tracked_variables, 2)\n        self.assertEqual(tracked_variables[0], v1)\n        self.assertEqual(tracked_variables[1], v2)\n\n    def test_tracked_dict_constructor_with_ordered_dict(self):\n        tracked_variables = []\n        tracker = tracking.Tracker(\n            {\n                \"variables\": (\n                    lambda x: isinstance(x, backend.Variable),\n                    tracked_variables,\n                ),\n            }\n        )\n        v1 = backend.Variable(1.0)\n        v2 = backend.Variable(2.0)\n        v3 = backend.Variable(3.0)\n        input_ordered = collections.OrderedDict(\n            [(\"x\", v1), (\"y\", v2), (\"z\", v3)]\n        )\n        tdict = tracking.TrackedDict(input_ordered, tracker=tracker)\n\n        self.assertIsInstance(tdict, dict)\n        self.assertEqual(list(tdict.keys()), [\"x\", \"y\", \"z\"])\n        self.assertEqual(list(tdict.values()), [v1, v2, v3])\n        self.assertLen(tracked_variables, 3)\n        self.assertEqual(tracked_variables, [v1, v2, v3])\n\n    
def test_tracked_dict_constructor_with_iterable_pairs(self):\n        tracked_variables = []\n        tracker = tracking.Tracker(\n            {\n                \"variables\": (\n                    lambda x: isinstance(x, backend.Variable),\n                    tracked_variables,\n                ),\n            }\n        )\n        v1 = backend.Variable(1.0)\n        v2 = backend.Variable(2.0)\n        # Test with zip object\n        keys = [\"p\", \"q\"]\n        values = [v1, v2]\n        iterable_pairs = zip(keys, values)\n        tdict = tracking.TrackedDict(iterable_pairs, tracker=tracker)\n\n        self.assertIsInstance(tdict, dict)\n        self.assertEqual(tdict[\"p\"], v1)\n        self.assertEqual(tdict[\"q\"], v2)\n        self.assertLen(tdict, 2)\n        self.assertLen(tracked_variables, 2)\n        self.assertIn(v1, tracked_variables)\n        self.assertIn(v2, tracked_variables)\n\n    def test_tracked_ordered_dict_preserves_type_and_order(self):\n        tracked_variables = []\n        tracker = tracking.Tracker(\n            {\n                \"variables\": (\n                    lambda x: isinstance(x, backend.Variable),\n                    tracked_variables,\n                ),\n            }\n        )\n        v1 = backend.Variable(1.0)\n        v2 = backend.Variable(2.0)\n        input_ordered = collections.OrderedDict([(\"x\", v1), (\"y\", v2)])\n        ordered_dict = tracker.track(input_ordered)\n\n        self.assertIsInstance(ordered_dict, collections.OrderedDict)\n        self.assertEqual(list(ordered_dict.keys()), [\"x\", \"y\"])\n        self.assertLen(tracked_variables, 2)\n\n    def test_tracked_ordered_dict_setitem_and_pop(self):\n        tracked_variables = []\n        tracker = tracking.Tracker(\n            {\n                \"variables\": (\n                    lambda x: isinstance(x, backend.Variable),\n                    tracked_variables,\n                ),\n            }\n        )\n        v1 = 
backend.Variable(1.0)\n        v2 = backend.Variable(2.0)\n        ordered_dict = tracking.TrackedOrderedDict(\n            collections.OrderedDict([(\"x\", v1)]), tracker=tracker\n        )\n        self.assertLen(tracked_variables, 1)\n\n        ordered_dict[\"y\"] = v2\n        self.assertLen(tracked_variables, 2)\n\n        ordered_dict.pop(\"x\")\n        self.assertLen(tracked_variables, 1)\n        self.assertNotIn(v1, tracked_variables)\n"
  },
  {
    "path": "keras/src/version.py",
    "content": "from keras.src.api_export import keras_export\n\n# Unique source of truth for the version number.\n__version__ = \"3.14.0\"\n\n\n@keras_export(\"keras.version\")\ndef version():\n    return __version__\n"
  },
  {
    "path": "keras/src/visualization/__init__.py",
    "content": "from keras.src.visualization import draw_bounding_boxes\nfrom keras.src.visualization import plot_image_gallery\n"
  },
  {
    "path": "keras/src/visualization/draw_bounding_boxes.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.bounding_boxes.converters import (  # noqa: E501\n    convert_format,\n)\n\ntry:\n    import cv2\nexcept ImportError:\n    cv2 = None\n\n\n@keras_export(\"keras.visualization.draw_bounding_boxes\")\ndef draw_bounding_boxes(\n    images,\n    bounding_boxes,\n    bounding_box_format,\n    class_mapping=None,\n    color=(128, 128, 128),\n    line_thickness=2,\n    text_thickness=1,\n    font_scale=1.0,\n    data_format=None,\n):\n    \"\"\"Draws bounding boxes on images.\n\n    This function draws bounding boxes on a batch of images.  It supports\n    different bounding box formats and can optionally display class labels\n    and confidences.\n\n    Args:\n        images: A batch of images as a 4D tensor or NumPy array. Shape should be\n            `(batch_size, height, width, channels)`.\n        bounding_boxes: A dictionary containing bounding box data.  Should have\n            the following keys:\n            - `boxes`: A tensor or array of shape `(batch_size, num_boxes, 4)`\n               containing the bounding box coordinates in the specified format.\n            - `labels`: A tensor or array of shape `(batch_size, num_boxes)`\n              containing the class labels for each bounding box.\n            - `confidences` (Optional): A tensor or array of shape\n               `(batch_size, num_boxes)` containing the confidence scores for\n               each bounding box.\n        bounding_box_format: A string specifying the format of the bounding\n            boxes. See `keras.utils.bounding_boxes.convert_format` for\n            supported formats (e.g., \"xyxy\", \"yxyx\", \"xywh\",\n            \"center_xywh\", \"center_yxhw\", and the \"rel_*\" variants).\n        class_mapping: A dictionary mapping class IDs (integers) to class labels\n            (strings).  
Used to display class labels next to the bounding boxes.\n            Defaults to None (no labels displayed).\n        color: A tuple or list representing the RGB color of the bounding boxes.\n            For example, `(255, 0, 0)` for red. Defaults to `(128, 128, 128)`.\n        line_thickness: An integer specifying the thickness of the bounding box\n            lines. Defaults to `2`.\n        text_thickness: An integer specifying the thickness of the text labels.\n            Defaults to `1`.\n        font_scale: A float specifying the scale of the font used for text\n            labels. Defaults to `1.0`.\n        data_format: A string, either `\"channels_last\"` or `\"channels_first\"`,\n            specifying the order of dimensions in the input images. Defaults to\n            the `image_data_format` value found in your Keras config file at\n            `~/.keras/keras.json`. If you never set it, then it will be\n            \"channels_last\".\n\n    Returns:\n        A NumPy array of the annotated images with the bounding boxes drawn.\n        The array will have the same shape as the input `images`.\n\n    Raises:\n        ValueError: If `images` is not a 4D tensor/array, if `bounding_boxes` is\n            not a dictionary, or if `bounding_boxes` does not contain `\"boxes\"`\n            and `\"labels\"` keys.\n        TypeError: If `bounding_boxes` is not a dictionary.\n        ImportError: If `cv2` (OpenCV) is not installed.\n    \"\"\"\n\n    if cv2 is None:\n        raise ImportError(\n            \"The `draw_bounding_boxes` function requires the `cv2` package \"\n            \" (OpenCV). 
Please install it with `pip install opencv-python`.\"\n        )\n\n    class_mapping = class_mapping or {}\n    text_thickness = (\n        text_thickness or line_thickness\n    )  # Default text_thickness if not provided.\n    data_format = data_format or backend.image_data_format()\n    images_shape = ops.shape(images)\n    if len(images_shape) != 4:\n        raise ValueError(\n            \"`images` must be batched 4D tensor. \"\n            f\"Received: images.shape={images_shape}\"\n        )\n    if not isinstance(bounding_boxes, dict):\n        raise TypeError(\n            \"`bounding_boxes` should be a dict. \"\n            f\"Received: bounding_boxes={bounding_boxes} of type \"\n            f\"{type(bounding_boxes)}\"\n        )\n    if \"boxes\" not in bounding_boxes or \"labels\" not in bounding_boxes:\n        raise ValueError(\n            \"`bounding_boxes` should be a dict containing 'boxes' and \"\n            f\"'labels' keys. Received: bounding_boxes={bounding_boxes}\"\n        )\n    if data_format == \"channels_last\":\n        h_axis = -3\n        w_axis = -2\n    else:\n        h_axis = -2\n        w_axis = -1\n    height = images_shape[h_axis]\n    width = images_shape[w_axis]\n    bounding_boxes = bounding_boxes.copy()\n    bounding_boxes = convert_format(\n        bounding_boxes, bounding_box_format, \"xyxy\", height, width\n    )\n\n    # To numpy array\n    images = ops.convert_to_numpy(images)\n    if images.dtype.kind == \"f\" and images.size > 0:\n        if images.max() <= 1.0:\n            images = np.clip(images, 0, 1) * 255\n        else:\n            images = np.clip(images, 0, 255)\n    images = images.astype(\"uint8\")\n    boxes = ops.convert_to_numpy(bounding_boxes[\"boxes\"])\n    labels = ops.convert_to_numpy(bounding_boxes[\"labels\"])\n    if \"confidences\" in bounding_boxes:\n        confidences = ops.convert_to_numpy(bounding_boxes[\"confidences\"])\n    else:\n        confidences = None\n\n    result = []\n    
batch_size = images.shape[0]\n    for i in range(batch_size):\n        _image = images[i]\n        _box = boxes[i]\n        _class = labels[i]\n        for box_i in range(_box.shape[0]):\n            x1, y1, x2, y2 = _box[box_i].astype(\"int32\")\n            c = _class[box_i].astype(\"int32\")\n            if c == -1:\n                continue\n            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)\n            c = int(c)\n            # Draw bounding box\n            cv2.rectangle(_image, (x1, y1), (x2, y2), color, line_thickness)\n\n            if c in class_mapping:\n                label = class_mapping[c]\n                if confidences is not None:\n                    conf = confidences[i][box_i]\n                    label = f\"{label} | {conf:.2f}\"\n\n                font_x1, font_y1 = _find_text_location(\n                    x1, y1, font_scale, text_thickness\n                )\n                cv2.putText(\n                    img=_image,\n                    text=label,\n                    org=(font_x1, font_y1),\n                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,\n                    fontScale=font_scale,\n                    color=color,\n                    thickness=text_thickness,\n                )\n        result.append(_image)\n    return np.stack(result, axis=0)\n\n\ndef _find_text_location(x, y, font_scale, thickness):\n    font_height = int(font_scale * 12)\n    target_y = y - 8\n    if target_y - (2 * font_height) > 0:\n        return x, y - 8\n\n    line_offset = thickness\n    static_offset = 3\n\n    return (\n        x + static_offset,\n        y + (2 * font_height) + line_offset + static_offset,\n    )\n"
  },
  {
    "path": "keras/src/visualization/draw_segmentation_masks.py",
    "content": "import numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\n\n\n@keras_export(\"keras.visualization.draw_segmentation_masks\")\ndef draw_segmentation_masks(\n    images,\n    segmentation_masks,\n    num_classes=None,\n    color_mapping=None,\n    alpha=0.8,\n    blend=True,\n    ignore_index=-1,\n    data_format=None,\n):\n    \"\"\"Draws segmentation masks on images.\n\n    The function overlays segmentation masks on the input images.\n    The masks are blended with the images using the specified alpha value.\n\n    Args:\n        images: A batch of images as a 4D tensor or NumPy array. Shape\n            should be (batch_size, height, width, channels).\n        segmentation_masks: A batch of segmentation masks as a 3D or 4D tensor\n            or NumPy array.  Shape should be (batch_size, height, width) or\n            (batch_size, height, width, 1). The values represent class indices\n            starting from 1 up to `num_classes`. Class 0 is reserved for\n            the background and will be ignored if `ignore_index` is not 0.\n        num_classes: The number of segmentation classes. If `None`, it is\n            inferred from the maximum value in `segmentation_masks`.\n        color_mapping: A dictionary mapping class indices to RGB colors.\n            If `None`, a default color palette is generated. The keys should be\n            integers starting from 1 up to `num_classes`.\n        alpha: The opacity of the segmentation masks. Must be in the range\n            `[0, 1]`.\n        blend: Whether to blend the masks with the input image using the\n            `alpha` value. If `False`, the masks are drawn directly on the\n            images without blending. Defaults to `True`.\n        ignore_index: The class index to ignore. Mask pixels with this value\n            will not be drawn.  
Defaults to -1.\n        data_format: Image data format, either `\"channels_last\"` or\n            `\"channels_first\"`. Defaults to the `image_data_format` value found\n            in your Keras config file at `~/.keras/keras.json`. If you never\n            set it, then it will be `\"channels_last\"`.\n\n    Returns:\n        A NumPy array of the images with the segmentation masks overlaid.\n\n    Raises:\n        ValueError: If the input `images` is not a 4D tensor or NumPy array.\n        TypeError: If the input `segmentation_masks` is not an integer type.\n    \"\"\"\n    data_format = data_format or backend.image_data_format()\n    images_shape = ops.shape(images)\n    if len(images_shape) != 4:\n        raise ValueError(\n            \"`images` must be batched 4D tensor. \"\n            f\"Received: images.shape={images_shape}\"\n        )\n    if data_format == \"channels_first\":\n        images = ops.transpose(images, (0, 2, 3, 1))\n        segmentation_masks = ops.transpose(segmentation_masks, (0, 2, 3, 1))\n    images = ops.convert_to_tensor(images, dtype=\"float32\")\n    segmentation_masks = ops.convert_to_tensor(segmentation_masks)\n\n    if not backend.is_int_dtype(segmentation_masks.dtype):\n        dtype = backend.standardize_dtype(segmentation_masks.dtype)\n        raise TypeError(\n            \"`segmentation_masks` must be in integer dtype. 
\"\n            f\"Received: segmentation_masks.dtype={dtype}\"\n        )\n\n    # Infer num_classes\n    if num_classes is None:\n        num_classes = int(ops.convert_to_numpy(ops.max(segmentation_masks)))\n    if color_mapping is None:\n        colors = _generate_color_palette(num_classes)\n    else:\n        colors = [color_mapping[i] for i in range(num_classes)]\n    valid_masks = ops.not_equal(segmentation_masks, ignore_index)\n    valid_masks = ops.squeeze(valid_masks, axis=-1)\n    segmentation_masks = ops.one_hot(segmentation_masks, num_classes)\n    segmentation_masks = segmentation_masks[..., 0, :]\n    segmentation_masks = ops.convert_to_numpy(segmentation_masks)\n\n    # Replace class with color\n    masks = segmentation_masks\n    masks = np.transpose(masks, axes=(3, 0, 1, 2)).astype(\"bool\")\n    images_to_draw = ops.convert_to_numpy(images).copy()\n    for mask, color in zip(masks, colors):\n        color = np.array(color, dtype=images_to_draw.dtype)\n        images_to_draw[mask, ...] = color[None, :]\n    images_to_draw = ops.convert_to_tensor(images_to_draw)\n    outputs = ops.cast(images_to_draw, dtype=\"float32\")\n\n    if blend:\n        outputs = images * (1 - alpha) + outputs * alpha\n        outputs = ops.where(valid_masks[..., None], outputs, images)\n        outputs = ops.cast(outputs, dtype=\"uint8\")\n        outputs = ops.convert_to_numpy(outputs)\n    return outputs\n\n\ndef _generate_color_palette(num_classes):\n    palette = np.array([2**25 - 1, 2**15 - 1, 2**21 - 1])\n    return [((i * palette) % 255).tolist() for i in range(num_classes)]\n"
  },
  {
    "path": "keras/src/visualization/plot_bounding_box_gallery.py",
    "content": "import functools\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.visualization.draw_bounding_boxes import draw_bounding_boxes\nfrom keras.src.visualization.plot_image_gallery import plot_image_gallery\n\ntry:\n    from matplotlib import patches  # For legend patches\nexcept ImportError:\n    patches = None\n\n\n@keras_export(\"keras.visualization.plot_bounding_box_gallery\")\ndef plot_bounding_box_gallery(\n    images,\n    bounding_box_format,\n    y_true=None,\n    y_pred=None,\n    value_range=(0, 255),\n    true_color=(0, 188, 212),\n    pred_color=(255, 235, 59),\n    line_thickness=2,\n    font_scale=1.0,\n    text_thickness=None,\n    class_mapping=None,\n    ground_truth_mapping=None,\n    prediction_mapping=None,\n    legend=False,\n    legend_handles=None,\n    rows=None,\n    cols=None,\n    data_format=None,\n    **kwargs,\n):\n    \"\"\"Plots a gallery of images with bounding boxes.\n\n    This function can display both ground truth and predicted bounding boxes on\n    a set of images.  It supports various bounding box formats and can include\n    class labels and a legend.\n\n    Args:\n        images: A 4D tensor or NumPy array of images. Shape should be\n            `(batch_size, height, width, channels)`.\n        bounding_box_format: The format of the bounding boxes. See\n            `keras.utils.bounding_boxes.convert_format` for supported formats\n            (e.g., \"xyxy\", \"yxyx\", \"xywh\", \"center_xywh\", \"center_yxhw\",\n            and the \"rel_*\" variants).\n        y_true: A dictionary containing the ground truth bounding boxes and\n            labels. 
Should have the same structure as the `bounding_boxes`\n            argument in `keras.visualization.draw_bounding_boxes`.\n            Defaults to `None`.\n        y_pred: A dictionary containing the predicted bounding boxes and labels.\n            Should have the same structure as `y_true`. Defaults to `None`.\n        value_range: A tuple specifying the value range of the images\n            (e.g., `(0, 255)` or `(0, 1)`). Defaults to `(0, 255)`.\n        true_color: A tuple of three integers representing the RGB color for the\n            ground truth bounding boxes. Defaults to `(0, 188, 212)`.\n        pred_color: A tuple of three integers representing the RGB color for the\n            predicted bounding boxes. Defaults to `(255, 235, 59)`.\n        line_thickness: The thickness of the bounding box lines. Defaults to 2.\n        font_scale: The scale of the font used for labels. Defaults to 1.0.\n        text_thickness: The thickness of the bounding box text. Defaults to\n            `line_thickness`.\n        class_mapping: A dictionary mapping class IDs to class names. Used for\n            both ground truth and predicted boxes if `ground_truth_mapping`\n            and `prediction_mapping` are not provided. Defaults to `None`.\n        ground_truth_mapping:  A dictionary mapping class IDs to class names\n            specifically for ground truth boxes. Overrides `class_mapping`\n            for ground truth. Defaults to `None`.\n        prediction_mapping: A dictionary mapping class IDs to class names\n            specifically for predicted boxes. Overrides `class_mapping` for\n            predictions. Defaults to `None`.\n        legend: A boolean indicating whether to show a legend.\n            Defaults to `False`.\n        legend_handles: A list of matplotlib `Patch` objects to use for the\n            legend. 
If this is provided, the `legend` argument will be ignored.\n            Defaults to `None`.\n        rows: The number of rows in the image gallery. Required if the images\n            are not batched. Defaults to `None`.\n        cols: The number of columns in the image gallery. Required if the images\n            are not batched. Defaults to `None`.\n        data_format: The image data format `\"channels_last\"` or\n            `\"channels_first\"`. Defaults to the Keras backend data format.\n        kwargs: Additional keyword arguments to be passed to\n            `keras.visualization.plot_image_gallery`.\n\n    Returns:\n       The output of `keras.visualization.plot_image_gallery`.\n\n    Raises:\n        ValueError: If `images` is not a 4D tensor/array or if both `legend`\n            and `legend_handles` are specified.\n        ImportError: if matplotlib is not installed\n    \"\"\"\n    if patches is None:\n        raise ImportError(\n            \"The `plot_bounding_box_gallery` function requires the \"\n            \" `matplotlib` package. Please install it with \"\n            \" `pip install matplotlib`.\"\n        )\n\n    prediction_mapping = prediction_mapping or class_mapping\n    ground_truth_mapping = ground_truth_mapping or class_mapping\n    data_format = data_format or backend.image_data_format()\n    images_shape = ops.shape(images)\n    if len(images_shape) != 4:\n        raise ValueError(\n            \"`images` must be batched 4D tensor. 
\"\n            f\"Received: images.shape={images_shape}\"\n        )\n    if data_format == \"channels_first\":  # Ensure correct data format\n        images = ops.transpose(images, (0, 2, 3, 1))\n    plotted_images = ops.convert_to_numpy(images)\n\n    draw_fn = functools.partial(\n        draw_bounding_boxes,\n        bounding_box_format=bounding_box_format,\n        line_thickness=line_thickness,\n        text_thickness=text_thickness,\n        font_scale=font_scale,\n    )\n\n    if y_true is not None:\n        plotted_images = draw_fn(\n            plotted_images,\n            y_true,\n            color=true_color,\n            class_mapping=ground_truth_mapping,\n        )\n\n    if y_pred is not None:\n        plotted_images = draw_fn(\n            plotted_images,\n            y_pred,\n            color=pred_color,\n            class_mapping=prediction_mapping,\n        )\n\n    if legend:\n        if legend_handles:\n            raise ValueError(\n                \"Only pass `legend` OR `legend_handles` to \"\n                \"`keras.visualization.plot_bounding_box_gallery()`.\"\n            )\n        legend_handles = [\n            patches.Patch(\n                color=np.array(true_color) / 255.0,  # Normalize color\n                label=\"Ground Truth\",\n            ),\n            patches.Patch(\n                color=np.array(pred_color) / 255.0,  # Normalize color\n                label=\"Prediction\",\n            ),\n        ]\n\n    return plot_image_gallery(\n        plotted_images,\n        value_range=value_range,\n        legend_handles=legend_handles,\n        rows=rows,\n        cols=cols,\n        **kwargs,\n    )\n"
  },
  {
    "path": "keras/src/visualization/plot_image_gallery.py",
    "content": "import math\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import (  # noqa: E501\n    BaseImagePreprocessingLayer,\n)\n\ntry:\n    import matplotlib.pyplot as plt\nexcept ImportError:\n    plt = None\n\n\ndef _extract_image_batch(images, num_images, batch_size):\n    \"\"\"Extracts a batch of images for plotting.\n\n    Args:\n        images: The 4D tensor or NumPy array of images.\n        num_images: The number of images to extract.\n        batch_size: The original batch size of the images.\n\n    Returns:\n        A 4D tensor or NumPy array containing the extracted images.\n\n    Raises:\n        ValueError: If `images` is not a 4D tensor/array.\n    \"\"\"\n\n    if len(ops.shape(images)) != 4:\n        raise ValueError(\n            \"`plot_images_gallery()` requires you to \"\n            \"batch your `np.array` samples together.\"\n        )\n    num_samples = min(num_images, batch_size)\n    sample = images[:num_samples, ...]\n\n    return sample\n\n\n@keras_export(\"keras.visualization.plot_image_gallery\")\ndef plot_image_gallery(\n    images,\n    y_true=None,\n    y_pred=None,\n    label_map=None,\n    rows=None,\n    cols=None,\n    value_range=(0, 255),\n    scale=2,\n    path=None,\n    show=None,\n    transparent=True,\n    dpi=60,\n    legend_handles=None,\n    data_format=None,\n):\n    \"\"\"Displays a gallery of images with optional labels and predictions.\n\n    Args:\n        images: A 4D tensor or NumPy array of images. 
Shape should be\n           `(batch_size, height, width, channels)`.\n        y_true: A 1D tensor or NumPy array of true labels (class indices).\n           Defaults to `None`.\n        y_pred: A 1D tensor or NumPy array of predicted labels (class indices).\n           Defaults to `None`.\n        label_map: A dictionary mapping class indices to class names.\n            Required if `y_true` or `y_pred` are provided.\n           Defaults to `None`.\n        value_range: A tuple specifying the value range of the images\n            (e.g., `(0, 255)` or `(0, 1)`). Defaults to `(0, 255)`.\n        rows: The number of rows in the gallery. If `None`, it's calculated\n            based on the number of images and `cols`. Defaults to `None`.\n        cols: The number of columns in the gallery. If `None`, it's calculated\n            based on the number of images and `rows`. Defaults to `None`.\n        scale: A float controlling the size of the displayed images. The images\n            are scaled by this factor. Defaults to `2`.\n        path: The path to save the generated gallery image. If `None`, the\n            image is displayed using `plt.show()`. Defaults to `None`.\n        show: Whether to display the image using `plt.show()`. If `True`, the\n            image is displayed. If `False`, the image is not displayed.\n            Ignored if `path` is not `None`. Defaults to `True` if `path`\n            is `None`, `False` otherwise.\n        transparent:  A boolean, whether to save the figure with a transparent\n            background. Defaults to `True`.\n        dpi: The DPI (dots per inch) for saving the figure. Defaults to 60.\n        legend_handles: A list of matplotlib `Patch` objects to use as legend\n            handles. Defaults to `None`.\n        data_format: The image data format `\"channels_last\"` or\n            `\"channels_first\"`. 
Defaults to the Keras backend data format.\n\n    Raises:\n        ValueError: If both `path` and `show` are set to non-`None` values,\n            if `images` is not a 4D tensor or array, or if `y_true` or `y_pred`\n            are provided without a `label_map`.\n        ImportError: if matplotlib is not installed.\n    \"\"\"\n    if plt is None:\n        raise ImportError(\n            \"The `plot_image_gallery` function requires the `matplotlib` \"\n            \"package. Please install it with `pip install matplotlib`.\"\n        )\n\n    if path is not None and show:\n        raise ValueError(\n            \"plot_gallery() expects either `path` to be set, or `show` \"\n            \"to be true.\"\n        )\n\n    if (y_true is not None or y_pred is not None) and label_map is None:\n        raise ValueError(\n            \"If `y_true` or `y_pred` are provided, a `label_map` must also be\"\n            \" provided.\"\n        )\n\n    show = show if show is not None else (path is None)\n    data_format = data_format or backend.image_data_format()\n\n    batch_size = ops.shape(images)[0] if len(ops.shape(images)) == 4 else 1\n\n    rows = rows or int(math.ceil(math.sqrt(batch_size)))\n    cols = cols or int(math.ceil(batch_size // rows))\n    num_images = rows * cols\n\n    images = _extract_image_batch(images, num_images, batch_size)\n    if (\n        data_format == \"channels_first\"\n    ):  # Ensure correct data format for plotting\n        images = ops.transpose(images, (0, 2, 3, 1))\n\n    # Generate subplots\n    fig, axes = plt.subplots(\n        nrows=rows,\n        ncols=cols,\n        figsize=(cols * scale, rows * scale),\n        frameon=False,\n        layout=\"tight\",\n        squeeze=True,\n        sharex=\"row\",\n        sharey=\"col\",\n    )\n    fig.subplots_adjust(wspace=0, hspace=0)\n\n    if isinstance(axes, np.ndarray) and len(axes.shape) == 1:\n        expand_axis = 0 if rows == 1 else -1\n        axes = np.expand_dims(axes, 
expand_axis)\n\n    if legend_handles is not None:\n        fig.legend(handles=legend_handles, loc=\"lower center\")\n\n    images = BaseImagePreprocessingLayer()._transform_value_range(\n        images=images, original_range=value_range, target_range=(0, 255)\n    )\n\n    images = ops.convert_to_numpy(images)\n    if data_format == \"channels_first\":\n        images = images.transpose(0, 2, 3, 1)\n\n    if y_true is not None:\n        y_true = ops.convert_to_numpy(y_true)\n    if y_pred is not None:\n        y_pred = ops.convert_to_numpy(y_pred)\n\n    for row in range(rows):\n        for col in range(cols):\n            index = row * cols + col\n            current_axis = (\n                axes[row, col] if isinstance(axes, np.ndarray) else axes\n            )\n            current_axis.imshow(images[index].astype(\"uint8\"))\n            current_axis.margins(x=0, y=0)\n            current_axis.axis(\"off\")\n            title_parts = []\n            if y_true is not None and index < len(y_true):\n                title_parts.append(\n                    f\"Label: {label_map.get(y_true[index], 'Unknown')}\"\n                )\n            if y_pred is not None and index < len(y_pred):\n                title_parts.append(\n                    f\"Pred: {label_map.get(y_pred[index], 'Unknown')}\"\n                )\n\n            if title_parts:\n                current_axis.set_title(\"  \".join(title_parts), fontsize=8)\n\n    if path is not None:\n        plt.savefig(\n            fname=path,\n            pad_inches=0,\n            bbox_inches=\"tight\",\n            transparent=transparent,\n            dpi=dpi,\n        )\n        plt.close()\n    elif show:\n        plt.show()\n        plt.close()\n"
  },
  {
    "path": "keras/src/visualization/plot_segmentation_mask_gallery.py",
    "content": "import functools\n\nimport numpy as np\n\nfrom keras.src import backend\nfrom keras.src import ops\nfrom keras.src.api_export import keras_export\nfrom keras.src.visualization.draw_segmentation_masks import (\n    draw_segmentation_masks,\n)\nfrom keras.src.visualization.plot_image_gallery import plot_image_gallery\n\n\n@keras_export(\"keras.visualization.plot_segmentation_mask_gallery\")\ndef plot_segmentation_mask_gallery(\n    images,\n    num_classes,\n    value_range=(0, 255),\n    y_true=None,\n    y_pred=None,\n    color_mapping=None,\n    blend=True,\n    alpha=0.8,\n    ignore_index=-1,\n    data_format=None,\n    **kwargs,\n):\n    \"\"\"Plots a gallery of images with corresponding segmentation masks.\n\n    Args:\n        images: A 4D tensor or NumPy array of images. Shape should be\n            `(batch_size, height, width, channels)`.\n        num_classes: The number of segmentation classes.  Class indices should\n            start from `1`.  Class `0` will be treated as background and\n            ignored if `ignore_index` is not 0.\n        value_range: A tuple specifying the value range of the images\n            (e.g., `(0, 255)` or `(0, 1)`). Defaults to `(0, 255)`.\n        y_true: A 3D/4D tensor or NumPy array representing the ground truth\n            segmentation masks. Shape should be `(batch_size, height, width)` or\n            `(batch_size, height, width, 1)`. Defaults to `None`.\n        y_pred: A 3D/4D tensor or NumPy array representing the predicted\n            segmentation masks.  Shape should be the same as `y_true`.\n            Defaults to `None`.\n        color_mapping: A dictionary mapping class indices to RGB colors.\n            If `None`, a default color palette is used. Class indices start\n            from `1`. Defaults to `None`.\n        blend: Whether to blend the masks with the input image using the\n            `alpha` value. 
If `False`, the masks are drawn directly on the\n            images without blending. Defaults to `True`.\n        alpha: The opacity of the segmentation masks (a float between 0 and 1).\n            Defaults to `0.8`.\n        ignore_index: The class index to ignore when drawing masks.\n            Defaults to `-1`.\n        data_format: The image data format `\"channels_last\"` or\n            `\"channels_first\"`. Defaults to the Keras backend data format.\n        kwargs: Additional keyword arguments to be passed to\n            `keras.visualization.plot_image_gallery`.\n\n    Returns:\n        The output of `keras.visualization.plot_image_gallery`.\n\n    Raises:\n        ValueError: If `images` is not a 4D tensor/array.\n    \"\"\"\n    data_format = data_format or backend.image_data_format()\n    image_shape = ops.shape(images)\n    if len(image_shape) != 4:\n        raise ValueError(\n            \"`images` must be batched 4D tensor. \"\n            f\"Received: images.shape={image_shape}\"\n        )\n    if data_format == \"channels_first\":\n        images = ops.transpose(images, (0, 2, 3, 1))\n\n    batch_size = image_shape[0] if len(image_shape) == 4 else 1\n\n    rows = batch_size\n    cols = 1\n\n    if y_true is not None:\n        cols += 1\n\n    if y_pred is not None:\n        cols += 1\n\n    images_np = ops.convert_to_numpy(images)\n\n    draw_masks_fn = functools.partial(\n        draw_segmentation_masks,\n        num_classes=num_classes,\n        color_mapping=color_mapping,\n        alpha=alpha,\n        ignore_index=ignore_index,\n        blend=blend,\n    )\n\n    if y_true is not None:\n        if data_format == \"channels_first\":\n            y_true = ops.transpose(y_true, (0, 2, 3, 1))\n        y_true = ops.cast(y_true, \"int32\")\n        true_masks_drawn = draw_masks_fn(images_np, y_true)\n\n    if y_pred is not None:\n        if data_format == \"channels_first\":\n            y_pred = ops.transpose(y_pred, (0, 2, 3, 1))\n        
y_pred = ops.cast(y_pred, \"int32\")\n        predicted_masks_drawn = draw_masks_fn(images_np, y_pred)\n\n    images_with_masks = []\n    for i in range(batch_size):\n        images_with_masks.append(images_np[i])\n        if y_true is not None:\n            images_with_masks.append(true_masks_drawn[i])\n        if y_pred is not None:\n            images_with_masks.append(predicted_masks_drawn[i])\n\n    gallery_images = np.stack(images_with_masks, axis=0)\n\n    return plot_image_gallery(\n        gallery_images, value_range=value_range, rows=rows, cols=cols, **kwargs\n    )\n"
  },
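The gallery-assembly loop above interleaves each input image with its mask overlays: one row per image, one column per available view (image, ground truth, prediction). A minimal numpy sketch of that layout logic (`interleave_gallery` is a hypothetical helper name, not part of the Keras API):

```python
import numpy as np


def interleave_gallery(images, true_masks=None, pred_masks=None):
    """Interleave each image with its optional mask overlays, mirroring the
    gallery layout: one row per image, one column per available view."""
    cols = 1 + (true_masks is not None) + (pred_masks is not None)
    tiles = []
    for i in range(images.shape[0]):
        tiles.append(images[i])
        if true_masks is not None:
            tiles.append(true_masks[i])
        if pred_masks is not None:
            tiles.append(pred_masks[i])
    return np.stack(tiles, axis=0), images.shape[0], cols


images = np.zeros((2, 8, 8, 3))
overlays = np.ones((2, 8, 8, 3))
gallery, rows, cols = interleave_gallery(images, true_masks=overlays)
# gallery.shape == (4, 8, 8, 3), plotted as 2 rows x 2 cols
```

The flat `gallery` array is then handed to `plot_image_gallery` with the matching `rows`/`cols` so each row reads left-to-right as image, ground truth, prediction.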
  {
    "path": "keras/src/wrappers/__init__.py",
    "content": "from keras.src.wrappers.sklearn_wrapper import SKLearnClassifier\nfrom keras.src.wrappers.sklearn_wrapper import SKLearnRegressor\nfrom keras.src.wrappers.sklearn_wrapper import SKLearnTransformer\n\n__all__ = [\"SKLearnClassifier\", \"SKLearnRegressor\", \"SKLearnTransformer\"]\n"
  },
  {
    "path": "keras/src/wrappers/fixes.py",
    "content": "try:\n    import sklearn\nexcept ImportError:\n    sklearn = None\n\n\ndef _validate_data(estimator, *args, **kwargs):\n    \"\"\"Validate the input data.\n\n    wrapper for sklearn.utils.validation.validate_data or\n    BaseEstimator._validate_data depending on the scikit-learn version.\n\n    TODO: remove when minimum scikit-learn version is 1.6\n    \"\"\"\n    try:\n        # scikit-learn >= 1.6\n        from sklearn.utils.validation import validate_data\n\n        return validate_data(estimator, *args, **kwargs)\n    except ImportError:\n        return estimator._validate_data(*args, **kwargs)\n    except:\n        raise\n\n\ndef type_of_target(y, input_name=\"\", *, raise_unknown=False):\n    def _raise_or_return(target_type):\n        \"\"\"Depending on the value of raise_unknown, either raise an error or\n        return 'unknown'.\n        \"\"\"\n        if raise_unknown and target_type == \"unknown\":\n            input = input_name if input_name else \"data\"\n            raise ValueError(f\"Unknown label type for {input}: {y!r}\")\n        else:\n            return target_type\n\n    from sklearn.utils.multiclass import type_of_target as sk_type_of_target\n\n    target_type = sk_type_of_target(y, input_name=input_name)\n    return _raise_or_return(target_type)\n\n\ndef _routing_enabled():\n    \"\"\"Return whether metadata routing is enabled.\n\n    Returns:\n        enabled : bool\n            Whether metadata routing is enabled. 
If the config is not set, it\n            defaults to False.\n\n    TODO: remove when the config key is no longer available in scikit-learn\n    \"\"\"\n    return sklearn.get_config().get(\"enable_metadata_routing\", False)\n\n\ndef _raise_for_params(params, owner, method):\n    \"\"\"Raise an error if metadata routing is not enabled and params are passed.\n\n    Parameters:\n        params : dict\n            The metadata passed to a method.\n        owner : object\n            The object to which the method belongs.\n        method : str\n            The name of the method, e.g. \"fit\".\n\n    Raises:\n        ValueError\n            If metadata routing is not enabled and params are passed.\n    \"\"\"\n    caller = (\n        f\"{owner.__class__.__name__}.{method}\"\n        if method\n        else owner.__class__.__name__\n    )\n    if not _routing_enabled() and params:\n        raise ValueError(\n            f\"Passing extra keyword arguments to {caller} is only supported if\"\n            \" enable_metadata_routing=True, which you can set using\"\n            \" `sklearn.set_config`. See the User Guide\"\n            \" <https://scikit-learn.org/stable/metadata_routing.html> for more\"\n            f\" details. Extra parameters passed are: {set(params)}\"\n        )\n"
  },
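The `_routing_enabled` / `_raise_for_params` pair in `fixes.py` gates extra keyword arguments on a scikit-learn config flag. A stripped-down sketch of that gating, with a plain dict standing in for `sklearn.get_config()` (all names here are illustrative, not the real helpers):

```python
_config = {}  # stand-in for sklearn's global config


def routing_enabled(config=_config):
    # Defaults to False when the key has never been set.
    return config.get("enable_metadata_routing", False)


def raise_for_params(params, owner_name, method, config=_config):
    # Reject stray metadata unless routing was explicitly enabled.
    caller = f"{owner_name}.{method}" if method else owner_name
    if not routing_enabled(config) and params:
        raise ValueError(
            f"Passing extra keyword arguments to {caller} is only "
            "supported if enable_metadata_routing=True."
        )


raise_for_params({}, "SKLearnClassifier", "fit")  # no params: always fine
_config["enable_metadata_routing"] = True
raise_for_params({"sample_weight": [1.0]}, "SKLearnClassifier", "fit")  # allowed
```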
  {
    "path": "keras/src/wrappers/sklearn_test.py",
    "content": "\"\"\"Tests using Scikit-Learn's bundled estimator_checks.\"\"\"\n\nimport unittest\nfrom contextlib import contextmanager\n\nimport pytest\nimport sklearn\nfrom packaging.version import parse as parse_version\nfrom sklearn.utils.estimator_checks import parametrize_with_checks\n\nimport keras\nfrom keras.src.backend import floatx\nfrom keras.src.backend import set_floatx\nfrom keras.src.layers import Dense\nfrom keras.src.layers import Input\nfrom keras.src.models import Model\nfrom keras.src.wrappers import SKLearnClassifier\nfrom keras.src.wrappers import SKLearnRegressor\nfrom keras.src.wrappers import SKLearnTransformer\n\n\ndef wrapped_parametrize_with_checks(\n    estimators,\n    *,\n    legacy=True,\n    expected_failed_checks=None,\n):\n    \"\"\"Wrapped `parametrize_with_checks` handling backwards compat.\"\"\"\n    sklearn_version = parse_version(\n        parse_version(sklearn.__version__).base_version\n    )\n\n    if sklearn_version >= parse_version(\"1.6\"):\n        return parametrize_with_checks(\n            estimators,\n            legacy=legacy,\n            expected_failed_checks=expected_failed_checks,\n        )\n\n    def patched_more_tags(estimator, expected_failed_checks):\n        import copy\n\n        original_tags = copy.deepcopy(sklearn.utils._tags._safe_tags(estimator))\n\n        def patched_more_tags(self):\n            original_tags.update({\"_xfail_checks\": expected_failed_checks})\n            return original_tags\n\n        estimator.__class__._more_tags = patched_more_tags\n        return estimator\n\n    estimators = [\n        patched_more_tags(estimator, expected_failed_checks(estimator))\n        for estimator in estimators\n    ]\n\n    # legacy is not supported and ignored\n    return parametrize_with_checks(estimators)\n\n\ndef dynamic_model(X, y, loss, layers=[10]):\n    \"\"\"Creates a basic MLP classifier dynamically choosing binary/multiclass\n    classification loss and output activations.\n    
\"\"\"\n    n_features_in = X.shape[1]\n    inp = Input(shape=(n_features_in,))\n\n    hidden = inp\n    for layer_size in layers:\n        hidden = Dense(layer_size, activation=\"relu\")(hidden)\n\n    n_outputs = y.shape[1] if len(y.shape) > 1 else 1\n    out = [Dense(n_outputs, activation=\"softmax\")(hidden)]\n    model = Model(inp, out)\n    model.compile(loss=loss, optimizer=\"rmsprop\")\n\n    return model\n\n\n@contextmanager\ndef use_floatx(x):\n    \"\"\"Context manager to temporarily\n    set the keras backend precision.\n    \"\"\"\n    _floatx = floatx()\n    set_floatx(x)\n    try:\n        yield\n    finally:\n        set_floatx(_floatx)\n\n\nEXPECTED_FAILED_CHECKS = {\n    \"SKLearnClassifier\": {\n        \"check_classifiers_regression_target\": \"not an issue in sklearn>=1.6\",\n        \"check_parameters_default_constructible\": (\n            \"not an issue in sklearn>=1.6\"\n        ),\n        \"check_classifiers_one_label_sample_weights\": (\n            \"0 sample weight is not ignored\"\n        ),\n        \"check_classifiers_classes\": (\n            \"with small test cases the estimator returns not all classes \"\n            \"sometimes\"\n        ),\n        \"check_classifier_data_not_an_array\": (\n            \"This test assumes reproducibility in fit.\"\n        ),\n        \"check_supervised_y_2d\": \"This test assumes reproducibility in fit.\",\n        \"check_fit_idempotent\": \"This test assumes reproducibility in fit.\",\n    },\n    \"SKLearnRegressor\": {\n        \"check_parameters_default_constructible\": (\n            \"not an issue in sklearn>=1.6\"\n        ),\n    },\n    \"SKLearnTransformer\": {\n        \"check_parameters_default_constructible\": (\n            \"not an issue in sklearn>=1.6\"\n        ),\n    },\n}\n\n\n@wrapped_parametrize_with_checks(\n    estimators=[\n        SKLearnClassifier(\n            model=dynamic_model,\n            model_kwargs={\n                \"loss\": 
\"categorical_crossentropy\",\n                \"layers\": [20, 20, 20],\n            },\n            fit_kwargs={\"epochs\": 5},\n        ),\n        SKLearnRegressor(\n            model=dynamic_model,\n            model_kwargs={\"loss\": \"mse\"},\n        ),\n        SKLearnTransformer(\n            model=dynamic_model,\n            model_kwargs={\"loss\": \"mse\"},\n        ),\n    ],\n    expected_failed_checks=lambda estimator: EXPECTED_FAILED_CHECKS[\n        type(estimator).__name__\n    ],\n)\ndef test_sklearn_estimator_checks(estimator, check):\n    \"\"\"Checks that can be passed with sklearn's default tolerances\n    and in a single epoch.\n    \"\"\"\n    try:\n        check(estimator)\n    except Exception as exc:\n        if keras.config.backend() in [\"numpy\", \"openvino\"] and (\n            isinstance(exc, NotImplementedError)\n            or \"NotImplementedError\" in str(exc)\n        ):\n            pytest.xfail(\"Backend not implemented\")\n        elif isinstance(exc, unittest.SkipTest):\n            # Workaround for https://github.com/pytest-dev/pytest/issues/13895\n            pytest.skip(str(exc))\n        else:\n            raise\n"
  },
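`wrapped_parametrize_with_checks` branches on the installed scikit-learn version after stripping pre-release segments, so release candidates of 1.6 take the modern path. The comparison it relies on can be sketched as (the function name here is illustrative):

```python
from packaging.version import parse as parse_version


def supports_new_checks(version_string):
    # Compare only the release segment, so "1.6.0rc1" still counts as 1.6.
    base = parse_version(parse_version(version_string).base_version)
    return base >= parse_version("1.6")
```

Versions at or above the threshold get `legacy`/`expected_failed_checks` passed straight through; older versions fall back to patching `_more_tags` with `_xfail_checks`.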
  {
    "path": "keras/src/wrappers/sklearn_wrapper.py",
    "content": "import copy\n\nimport numpy as np\n\nfrom keras.src.api_export import keras_export\nfrom keras.src.models.cloning import clone_model\nfrom keras.src.models.model import Model\nfrom keras.src.wrappers.fixes import _routing_enabled\nfrom keras.src.wrappers.fixes import _validate_data\nfrom keras.src.wrappers.fixes import type_of_target\nfrom keras.src.wrappers.utils import TargetReshaper\nfrom keras.src.wrappers.utils import _check_model\nfrom keras.src.wrappers.utils import assert_sklearn_installed\n\ntry:\n    import sklearn\n    from sklearn.base import BaseEstimator\n    from sklearn.base import ClassifierMixin\n    from sklearn.base import RegressorMixin\n    from sklearn.base import TransformerMixin\nexcept ImportError:\n    sklearn = None\n\n    class BaseEstimator:\n        pass\n\n    class ClassifierMixin:\n        pass\n\n    class RegressorMixin:\n        pass\n\n    class TransformerMixin:\n        pass\n\n\nclass SKLBase(BaseEstimator):\n    \"\"\"Base class for scikit-learn wrappers.\n\n    Note that there are sources of randomness in model initialization and\n    training. Refer to [Reproducibility in Keras Models](\n    https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to\n    control randomness.\n\n    Args:\n        model: `Model`.\n            An instance of `Model`, or a callable returning such an object.\n            Note that if input is a `Model`, it will be cloned using\n            `keras.models.clone_model` before being fitted, unless\n            `warm_start=True`.\n            The `Model` instance needs to be passed as already compiled.\n            If callable, it must accept at least `X` and `y` as keyword\n            arguments. Other arguments must be accepted if passed as\n            `model_kwargs` by the user.\n        warm_start: bool, defaults to `False`.\n            Whether to reuse the model weights from the previous fit. 
If `True`,\n            the given model won't be cloned and the weights from the previous\n            fit will be reused.\n        model_kwargs: dict, defaults to `None`.\n            Keyword arguments passed to `model`, if `model` is callable.\n        fit_kwargs: dict, defaults to `None`.\n            Keyword arguments passed to `model.fit`. These can also be passed\n            directly to the `fit` method of the scikit-learn wrapper. The\n            values passed directly to the `fit` method take precedence over\n            these.\n\n    Attributes:\n        model_ : `Model`\n            The fitted model.\n        history_ : dict\n            The history of the fit, returned by `model.fit`.\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        warm_start=False,\n        model_kwargs=None,\n        fit_kwargs=None,\n    ):\n        assert_sklearn_installed(self.__class__.__name__)\n        self.model = model\n        self.warm_start = warm_start\n        self.model_kwargs = model_kwargs\n        self.fit_kwargs = fit_kwargs\n\n    def _more_tags(self):\n        return {\"non_deterministic\": True}\n\n    def __sklearn_tags__(self):\n        tags = super().__sklearn_tags__()\n        tags.non_deterministic = True\n        return tags\n\n    def __sklearn_clone__(self):\n        \"\"\"Return a deep copy of the model.\n\n        This is used by the `sklearn.base.clone` function.\n        \"\"\"\n        model = (\n            self.model if callable(self.model) else copy.deepcopy(self.model)\n        )\n        return type(self)(\n            model=model,\n            warm_start=self.warm_start,\n            model_kwargs=self.model_kwargs,\n        )\n\n    @property\n    def epoch_(self):\n        \"\"\"The current training epoch.\"\"\"\n        return getattr(self, \"history_\", {}).get(\"epoch\", 0)\n\n    def set_fit_request(self, **kwargs):\n        \"\"\"Set requested parameters by the fit method.\n\n        Please see [scikit-learn's 
metadata routing](\n        https://scikit-learn.org/stable/metadata_routing.html) for more\n        details.\n\n        Args:\n            kwargs : dict\n                Arguments should be of the form `param_name=alias`, and `alias`\n                can be one of `{True, False, None, str}`.\n\n        Returns:\n            self\n        \"\"\"\n        if not _routing_enabled():\n            raise RuntimeError(\n                \"This method is only available when metadata routing is \"\n                \"enabled. You can enable it using \"\n                \"sklearn.set_config(enable_metadata_routing=True).\"\n            )\n\n        self._metadata_request = sklearn.utils.metadata_routing.MetadataRequest(\n            owner=self.__class__.__name__\n        )\n        for param, alias in kwargs.items():\n            self._metadata_request.fit.add_request(param=param, alias=alias)\n        return self\n\n    def _get_model(self, X, y):\n        if isinstance(self.model, Model):\n            return clone_model(self.model)\n        else:\n            args = self.model_kwargs or {}\n            return self.model(X=X, y=y, **args)\n\n    def fit(self, X, y, **kwargs):\n        \"\"\"Fit the model.\n\n        Args:\n            X: array-like, shape=(n_samples, n_features)\n                The input samples.\n            y: array-like, shape=(n_samples,) or (n_samples, n_outputs)\n                The targets.\n            **kwargs: keyword arguments passed to `model.fit`\n        \"\"\"\n        X, y = _validate_data(self, X, y)\n        y = self._process_target(y, reset=True)\n        model = self._get_model(X, y)\n        _check_model(model)\n\n        # Merge without mutating the user-provided `fit_kwargs` dict.\n        fit_kwargs = {**(self.fit_kwargs or {}), **kwargs}\n        self.history_ = model.fit(X, y, **fit_kwargs)\n\n        self.model_ = model\n        return self\n\n    def predict(self, X):\n        \"\"\"Predict using the model.\"\"\"\n        from sklearn.utils.validation import 
check_is_fitted\n\n        check_is_fitted(self)\n        X = _validate_data(self, X, reset=False)\n        raw_output = self.model_.predict(X)\n        return self._reverse_process_target(raw_output)\n\n    def _process_target(self, y, reset=False):\n        \"\"\"Regressors are NOOP here, classifiers do OHE.\"\"\"\n        # This is here to raise the right error in case of invalid target\n        type_of_target(y, raise_unknown=True)\n        if reset:\n            self._target_encoder = TargetReshaper().fit(y)\n        return self._target_encoder.transform(y)\n\n    def _reverse_process_target(self, y):\n        \"\"\"Regressors are NOOP here, classifiers reverse OHE.\"\"\"\n        return self._target_encoder.inverse_transform(y)\n\n\n@keras_export(\"keras.wrappers.SKLearnClassifier\")\nclass SKLearnClassifier(ClassifierMixin, SKLBase):\n    \"\"\"scikit-learn compatible classifier wrapper for Keras models.\n\n    Note that there are sources of randomness in model initialization and\n    training. Refer to [Reproducibility in Keras Models](\n    https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to\n    control randomness.\n\n    Args:\n        model: `Model`.\n            An instance of `Model`, or a callable returning such an object.\n            Note that if input is a `Model`, it will be cloned using\n            `keras.models.clone_model` before being fitted, unless\n            `warm_start=True`.\n            The `Model` instance needs to be passed as already compiled.\n            If callable, it must accept at least `X` and `y` as keyword\n            arguments. Other arguments must be accepted if passed as\n            `model_kwargs` by the user.\n        warm_start: bool, defaults to `False`.\n            Whether to reuse the model weights from the previous fit. 
If `True`,\n            the given model won't be cloned and the weights from the previous\n            fit will be reused.\n        model_kwargs: dict, defaults to `None`.\n            Keyword arguments passed to `model`, if `model` is callable.\n        fit_kwargs: dict, defaults to `None`.\n            Keyword arguments passed to `model.fit`. These can also be passed\n            directly to the `fit` method of the scikit-learn wrapper. The\n            values passed directly to the `fit` method take precedence over\n            these.\n\n    Attributes:\n        model_ : `Model`\n            The fitted model.\n        history_ : dict\n            The history of the fit, returned by `model.fit`.\n        classes_ : array-like, shape=(n_classes,)\n            The classes labels.\n\n    Example:\n    Here we use a function which creates a basic MLP model dynamically\n    choosing the input and output shapes. We will use this to create our\n    scikit-learn model.\n\n    ``` python\n    from keras.layers import Dense, Input\n    from keras.models import Model\n\n    def dynamic_model(X, y, loss, layers=[10]):\n        # Creates a basic MLP model dynamically choosing the input and\n        # output shapes.\n        n_features_in = X.shape[1]\n        inp = Input(shape=(n_features_in,))\n\n        hidden = inp\n        for layer_size in layers:\n            hidden = Dense(layer_size, activation=\"relu\")(hidden)\n\n        n_outputs = y.shape[1] if len(y.shape) > 1 else 1\n        out = Dense(n_outputs, activation=\"softmax\")(hidden)\n        model = Model(inp, out)\n        model.compile(loss=loss, optimizer=\"rmsprop\")\n\n        return model\n    ```\n\n    You can then use this function to create a scikit-learn compatible model\n    and fit it on some data.\n\n    ``` python\n    from sklearn.datasets import make_classification\n    from keras.wrappers import SKLearnClassifier\n\n    X, y = make_classification(n_samples=1000, n_features=10)\n    est = 
SKLearnClassifier(\n        model=dynamic_model,\n        model_kwargs={\n            \"loss\": \"categorical_crossentropy\",\n            \"layers\": [20, 20, 20],\n        },\n    )\n\n    est.fit(X, y, epochs=5)\n    ```\n    \"\"\"\n\n    def _process_target(self, y, reset=False):\n        \"\"\"Classifiers do OHE.\"\"\"\n        target_type = type_of_target(y, raise_unknown=True)\n        if target_type not in [\"binary\", \"multiclass\"]:\n            raise ValueError(\n                \"Only binary and multiclass target types are supported.\"\n                f\" Target type: {target_type}\"\n            )\n        if reset:\n            self._target_encoder = sklearn.pipeline.make_pipeline(\n                TargetReshaper(),\n                sklearn.preprocessing.OneHotEncoder(sparse_output=False),\n            ).fit(y)\n            self.classes_ = np.unique(y)\n            if len(self.classes_) == 1:\n                raise ValueError(\n                    \"Classifier can't train when only one class is present.\"\n                )\n        return self._target_encoder.transform(y)\n\n    def _more_tags(self):\n        # required to be compatible with scikit-learn<1.6\n        return {\"poor_score\": True}\n\n    def __sklearn_tags__(self):\n        tags = super().__sklearn_tags__()\n        tags.classifier_tags.poor_score = True\n        return tags\n\n\n@keras_export(\"keras.wrappers.SKLearnRegressor\")\nclass SKLearnRegressor(RegressorMixin, SKLBase):\n    \"\"\"scikit-learn compatible regressor wrapper for Keras models.\n\n    Note that there are sources of randomness in model initialization and\n    training. 
Refer to [Reproducibility in Keras Models](\n    https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to\n    control randomness.\n\n    Args:\n        model: `Model`.\n            An instance of `Model`, or a callable returning such an object.\n            Note that if input is a `Model`, it will be cloned using\n            `keras.models.clone_model` before being fitted, unless\n            `warm_start=True`.\n            The `Model` instance needs to be passed as already compiled.\n            If callable, it must accept at least `X` and `y` as keyword\n            arguments. Other arguments must be accepted if passed as\n            `model_kwargs` by the user.\n        warm_start: bool, defaults to `False`.\n            Whether to reuse the model weights from the previous fit. If `True`,\n            the given model won't be cloned and the weights from the previous\n            fit will be reused.\n        model_kwargs: dict, defaults to `None`.\n            Keyword arguments passed to `model`, if `model` is callable.\n        fit_kwargs: dict, defaults to `None`.\n            Keyword arguments passed to `model.fit`. These can also be passed\n            directly to the `fit` method of the scikit-learn wrapper. The\n            values passed directly to the `fit` method take precedence over\n            these.\n\n    Attributes:\n        model_ : `Model`\n            The fitted model.\n\n    Example:\n    Here we use a function which creates a basic MLP model dynamically\n    choosing the input and output shapes. 
We will use this to create our\n    scikit-learn model.\n\n    ``` python\n    from keras.layers import Dense, Input\n    from keras.models import Model\n\n    def dynamic_model(X, y, loss, layers=[10]):\n        # Creates a basic MLP model dynamically choosing the input and\n        # output shapes.\n        n_features_in = X.shape[1]\n        inp = Input(shape=(n_features_in,))\n\n        hidden = inp\n        for layer_size in layers:\n            hidden = Dense(layer_size, activation=\"relu\")(hidden)\n\n        n_outputs = y.shape[1] if len(y.shape) > 1 else 1\n        out = Dense(n_outputs)(hidden)\n        model = Model(inp, out)\n        model.compile(loss=loss, optimizer=\"rmsprop\")\n\n        return model\n    ```\n\n    You can then use this function to create a scikit-learn compatible model\n    and fit it on some data.\n\n    ``` python\n    from sklearn.datasets import make_regression\n    from keras.wrappers import SKLearnRegressor\n\n    X, y = make_regression(n_samples=1000, n_features=10)\n    est = SKLearnRegressor(\n        model=dynamic_model,\n        model_kwargs={\n            \"loss\": \"mse\",\n            \"layers\": [20, 20, 20],\n        },\n    )\n\n    est.fit(X, y, epochs=5)\n    ```\n    \"\"\"\n\n    def _more_tags(self):\n        # required to be compatible with scikit-learn<1.6\n        return {\"poor_score\": True}\n\n    def __sklearn_tags__(self):\n        tags = super().__sklearn_tags__()\n        tags.regressor_tags.poor_score = True\n        return tags\n\n\n@keras_export(\"keras.wrappers.SKLearnTransformer\")\nclass SKLearnTransformer(TransformerMixin, SKLBase):\n    \"\"\"scikit-learn compatible transformer wrapper for Keras models.\n\n    Note that this is a scikit-learn compatible transformer, and not a\n    transformer in the deep learning sense.\n\n    Also note that there are sources of randomness in model initialization and\n    training. 
Refer to [Reproducibility in Keras Models](\n    https://keras.io/examples/keras_recipes/reproducibility_recipes/) on how to\n    control randomness.\n\n    Args:\n        model: `Model`.\n            An instance of `Model`, or a callable returning such an object.\n            Note that if input is a `Model`, it will be cloned using\n            `keras.models.clone_model` before being fitted, unless\n            `warm_start=True`.\n            The `Model` instance needs to be passed as already compiled.\n            If callable, it must accept at least `X` and `y` as keyword\n            arguments. Other arguments must be accepted if passed as\n            `model_kwargs` by the user.\n        warm_start: bool, defaults to `False`.\n            Whether to reuse the model weights from the previous fit. If `True`,\n            the given model won't be cloned and the weights from the previous\n            fit will be reused.\n        model_kwargs: dict, defaults to `None`.\n            Keyword arguments passed to `model`, if `model` is callable.\n        fit_kwargs: dict, defaults to `None`.\n            Keyword arguments passed to `model.fit`. These can also be passed\n            directly to the `fit` method of the scikit-learn wrapper. The\n            values passed directly to the `fit` method take precedence over\n            these.\n\n    Attributes:\n        model_ : `Model`\n            The fitted model.\n        history_ : dict\n            The history of the fit, returned by `model.fit`.\n\n    Example:\n    A common use case for a scikit-learn transformer, is to have a step\n    which gives you the embedding of your data. 
Here we assume\n    `my_package.my_model` is a Keras model which takes the input and gives\n    embeddings of the data, and `my_package.my_data` is your dataset loader.\n\n    ``` python\n    from my_package import my_model, my_data\n    from keras.wrappers import SKLearnTransformer\n    from sklearn.frozen import FrozenEstimator # requires scikit-learn>=1.6\n    from sklearn.pipeline import make_pipeline\n    from sklearn.ensemble import HistGradientBoostingClassifier\n\n    X, y = my_data()\n\n    trs = FrozenEstimator(SKLearnTransformer(model=my_model))\n    pipe = make_pipeline(trs, HistGradientBoostingClassifier())\n    pipe.fit(X, y)\n    ```\n\n    Note that in the above example, `FrozenEstimator` prevents any further\n    training of the transformer step in the pipeline, which can be the case\n    if you don't want to change the embedding model at hand.\n    \"\"\"\n\n    def transform(self, X):\n        \"\"\"Transform the data.\n\n        Args:\n            X: array-like, shape=(n_samples, n_features)\n                The input samples.\n\n        Returns:\n            X_transformed: array-like, shape=(n_samples, n_features)\n                The transformed data.\n        \"\"\"\n        from sklearn.utils.validation import check_is_fitted\n\n        check_is_fitted(self)\n        X = _validate_data(self, X, reset=False)\n        return self.model_.predict(X)\n\n    def _more_tags(self):\n        # required to be compatible with scikit-learn<1.6\n        return {\n            \"preserves_dtype\": [],\n        }\n\n    def __sklearn_tags__(self):\n        tags = super().__sklearn_tags__()\n        tags.transformer_tags.preserves_dtype = []\n        return tags\n"
  },
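`SKLearnClassifier._process_target` one-hot encodes labels through a `TargetReshaper` + `OneHotEncoder` pipeline, and `_reverse_process_target` inverts it back to the original labels. A pure-numpy illustration of that round trip (not the wrapper's actual implementation, which delegates to scikit-learn; the helper names are hypothetical):

```python
import numpy as np


def one_hot_targets(y, classes=None):
    # Fit-time: remember the sorted class labels; encode rows as one-hot.
    classes = np.unique(y) if classes is None else classes
    return classes, (y[:, None] == classes[None, :]).astype("float32")


def decode_predictions(probs, classes):
    # Inverse: argmax over the class axis, then map back to labels.
    return classes[np.argmax(probs, axis=1)]


classes, onehot = one_hot_targets(np.array(["a", "b", "a"]))
labels = decode_predictions(onehot, classes)  # array(['a', 'b', 'a'])
```

The same decode step is what lets `predict` turn the model's softmax output back into the `classes_` labels seen at fit time.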
  {
    "path": "keras/src/wrappers/utils.py",
    "content": "import numpy as np\n\ntry:\n    import sklearn\n    from sklearn.base import BaseEstimator\n    from sklearn.base import TransformerMixin\nexcept ImportError:\n    sklearn = None\n\n    class BaseEstimator:\n        pass\n\n    class TransformerMixin:\n        pass\n\n\ndef assert_sklearn_installed(symbol_name):\n    if sklearn is None:\n        raise ImportError(\n            f\"{symbol_name} requires `scikit-learn` to be installed. \"\n            \"Run `pip install scikit-learn` to install it.\"\n        )\n\n\ndef _check_model(model):\n    \"\"\"Check whether the model need sto be compiled.\"\"\"\n    # compile model if user gave us an un-compiled model\n    if not model.compiled or not model.loss or not model.optimizer:\n        raise RuntimeError(\n            \"Given model needs to be compiled, and have a loss \"\n            \"and an optimizer.\"\n        )\n\n\nclass TargetReshaper(TransformerMixin, BaseEstimator):\n    \"\"\"Convert 1D targets to 2D and back.\n\n    For use in pipelines with transformers that only accept\n    2D inputs, like OneHotEncoder and OrdinalEncoder.\n\n    Attributes:\n        ndim_ : int\n            Dimensions of y that the transformer was trained on.\n    \"\"\"\n\n    def fit(self, y):\n        \"\"\"Fit the transformer to a target y.\n\n        Returns:\n            TargetReshaper\n                A reference to the current instance of TargetReshaper.\n        \"\"\"\n        self.ndim_ = y.ndim\n        return self\n\n    def transform(self, y):\n        \"\"\"Makes 1D y 2D.\n\n        Args:\n            y : np.ndarray\n                Target y to be transformed.\n\n        Returns:\n            np.ndarray\n                A numpy array, of dimension at least 2.\n        \"\"\"\n        if y.ndim == 1:\n            return y.reshape(-1, 1)\n        return y\n\n    def inverse_transform(self, y):\n        \"\"\"Revert the transformation of transform.\n\n        Args:\n            y: np.ndarray\n                
Transformed numpy array.\n\n        Returns:\n            np.ndarray\n                If the transformer was fit to a 1D numpy array,\n                and a 2D numpy array with a singleton second dimension\n                is passed, it will be squeezed back to 1D. Otherwise, it\n                will eb left untouched.\n        \"\"\"\n        from sklearn.utils.validation import check_is_fitted\n\n        check_is_fitted(self)\n        if self.ndim_ == 1 and y.ndim == 2:\n            return np.squeeze(y, axis=1)\n        return y\n"
  },
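`TargetReshaper`'s contract is easiest to see in isolation. A pure-numpy sketch of the same round trip (`MiniTargetReshaper` is a hypothetical stand-in without the scikit-learn base classes or `check_is_fitted`):

```python
import numpy as np


class MiniTargetReshaper:
    """Pure-numpy sketch of TargetReshaper: lift 1D targets to 2D for
    encoders like OneHotEncoder, then squeeze them back on the way out."""

    def fit(self, y):
        self.ndim_ = y.ndim  # remember the rank seen at fit time
        return self

    def transform(self, y):
        return y.reshape(-1, 1) if y.ndim == 1 else y

    def inverse_transform(self, y):
        # Only squeeze if we originally saw 1D targets.
        if self.ndim_ == 1 and y.ndim == 2:
            return np.squeeze(y, axis=1)
        return y


y = np.array([0, 1, 2])
reshaper = MiniTargetReshaper().fit(y)
y_2d = reshaper.transform(y)  # shape (3, 1)
y_back = reshaper.inverse_transform(y_2d)  # shape (3,)
```

Because the fitted `ndim_` is remembered, 2D targets pass through both directions untouched; only targets that entered as 1D get squeezed back.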
  {
    "path": "pip_build.py",
    "content": "\"\"\"Script to create (and optionally install) a `.whl` archive for Keras 3.\n\nUsage:\n\n1. Create a `.whl` file in `dist/`:\n\n```\npython3 pip_build.py\n```\n\n2. Also install the new package immediately after:\n\n```\npython3 pip_build.py --install\n```\n\"\"\"\n\nimport argparse\nimport datetime\nimport glob\nimport os\nimport pathlib\nimport re\nimport shutil\nimport subprocess\n\n# Needed because importing torch after TF causes the runtime to crash\ntry:\n    import torch  # noqa: F401\nexcept ImportError:\n    pass\n\npackage = \"keras\"\nbuild_directory = \"tmp_build_dir\"\ndist_directory = \"dist\"\nto_copy = [\"pyproject.toml\", \"README.md\"]\n\n\ndef export_version_string(version, is_nightly=False, rc_index=None):\n    \"\"\"Export Version and Package Name.\"\"\"\n    if is_nightly:\n        date = datetime.datetime.now()\n        version += f\".dev{date:%Y%m%d%H}\"\n        # Update `name = \"keras\"` with \"keras-nightly\"\n        pyproj_pth = pathlib.Path(\"pyproject.toml\")\n        pyproj_str = pyproj_pth.read_text().replace(\n            'name = \"keras\"', 'name = \"keras-nightly\"'\n        )\n        pyproj_pth.write_text(pyproj_str)\n    elif rc_index is not None:\n        version += f\"rc{str(rc_index)}\"\n\n    # Make sure to export the __version__ string\n    with open(os.path.join(package, \"src\", \"version.py\")) as f:\n        init_contents = f.read()\n    with open(os.path.join(package, \"src\", \"version.py\"), \"w\") as f:\n        init_contents = re.sub(\n            \"\\n__version__ = .*\\n\",\n            f'\\n__version__ = \"{version}\"\\n',\n            init_contents,\n        )\n        f.write(init_contents)\n\n\ndef ignore_files(_, filenames):\n    return [f for f in filenames if f.endswith(\"_test.py\")]\n\n\ndef copy_source_to_build_directory(root_path):\n    # Copy sources (`keras/` directory and setup files) to build\n    # directory\n    os.chdir(root_path)\n    os.mkdir(build_directory)\n    
shutil.copytree(\n        package, os.path.join(build_directory, package), ignore=ignore_files\n    )\n    for fname in to_copy:\n        shutil.copy(fname, os.path.join(f\"{build_directory}\", fname))\n    os.chdir(build_directory)\n\n\ndef build(root_path, is_nightly=False, rc_index=None):\n    if os.path.exists(build_directory):\n        raise ValueError(f\"Directory already exists: {build_directory}\")\n\n    try:\n        copy_source_to_build_directory(root_path)\n\n        from keras.src.version import __version__  # noqa: E402\n\n        export_version_string(__version__, is_nightly, rc_index)\n        return build_and_save_output(root_path, __version__)\n    finally:\n        # Clean up: remove the build directory (no longer needed)\n        shutil.rmtree(build_directory)\n\n\ndef build_and_save_output(root_path, __version__):\n    # Build the package\n    os.system(\"python3 -m build\")\n\n    # Save the dist files generated by the build process\n    os.chdir(root_path)\n    if not os.path.exists(dist_directory):\n        os.mkdir(dist_directory)\n    for fpath in glob.glob(\n        os.path.join(build_directory, dist_directory, \"*.*\")\n    ):\n        shutil.copy(fpath, dist_directory)\n\n    # Find the .whl file path\n    whl_path = None\n    for fname in os.listdir(dist_directory):\n        if __version__ in fname and fname.endswith(\".whl\"):\n            whl_path = os.path.abspath(os.path.join(dist_directory, fname))\n    if whl_path:\n        print(f\"Build successful. 
Wheel file available at {whl_path}\")\n    else:\n        print(\"Build failed.\")\n    return whl_path\n\n\ndef install_whl(whl_fpath):\n    print(f\"Installing wheel file: {whl_fpath}\")\n    subprocess.run(\n        [\n            \"pip3\",\n            \"install\",\n            whl_fpath,\n            \"--force-reinstall\",\n            \"--no-dependencies\",\n        ],\n        check=True,\n    )\n\n\nif __name__ == \"__main__\":\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        \"--install\",\n        action=\"store_true\",\n        help=\"Whether to install the generated wheel file.\",\n    )\n    parser.add_argument(\n        \"--nightly\",\n        action=\"store_true\",\n        help=\"Whether to generate nightly wheel file.\",\n    )\n    parser.add_argument(\n        \"--rc\",\n        type=int,\n        help=\"Specify `[0-9] when generating RC wheels.\",\n    )\n    args = parser.parse_args()\n    root_path = pathlib.Path(__file__).parent.resolve()\n    whl_path = build(root_path, args.nightly, args.rc)\n    if whl_path and args.install:\n        install_whl(whl_path)\n"
  },
  {
    "path": "pyproject.toml",
    "content": "[build-system]\nrequires = [\"setuptools >=61.0\"]\nbuild-backend = \"setuptools.build_meta\"\n\n[project]\nname = \"keras\"\nauthors = [\n    {name = \"Keras team\", email = \"keras-users@googlegroups.com\"},\n]\ndescription = \"Multi-backend Keras\"\nreadme = \"README.md\"\nrequires-python = \">=3.11\"\nlicense = {text = \"Apache License 2.0\"}\ndynamic = [\"version\"]\nclassifiers = [\n    \"Development Status :: 4 - Beta\",\n    \"Programming Language :: Python :: 3\",\n    \"Programming Language :: Python :: 3.11\",\n    \"Programming Language :: Python :: 3.12\",\n    \"Programming Language :: Python :: 3 :: Only\",\n    \"Operating System :: Unix\",\n    \"Operating System :: MacOS\",\n    \"Intended Audience :: Science/Research\",\n    \"Topic :: Scientific/Engineering\",\n    \"Topic :: Software Development\",\n]\ndependencies = [\n    \"absl-py\",\n    \"numpy\",\n    \"rich\",\n    \"namex\",\n    \"h5py\",\n    \"optree\",\n    \"ml-dtypes\",\n    \"packaging\",\n]\n# Run also: pip install -r requirements.txt\n\n[project.urls]\nHome = \"https://keras.io/\"\nRepository = \"https://github.com/keras-team/keras\"\n\n[tool.setuptools.dynamic]\nversion = {attr = \"keras.src.version.__version__\"}\n\n[tool.setuptools.package-dir]\n\"\" = \".\"\n\"keras\" = \"keras/api\"  # Remap api/ to the root of the package.\n\"keras.src\" = \"keras/src\"\n\n[tool.ruff]\nline-length = 80\nexclude = [\"keras/src/namex\"]\n\n[tool.ruff.lint]\nselect = [\n    \"E\",  # pycodestyle error\n    \"F\",  # Pyflakes\n    \"I\",  # isort\n]\nignore = [\n    \"E722\",  # do not use bare 'except'\n    \"E741\",  # ambiguous variable name\n    \"E731\",  # do not assign a `lambda` expression, use a `def`\n]\n\n[tool.ruff.lint.per-file-ignores]\n\"**/__init__.py\" = [\"E501\", \"F401\"]  # lines too long; imported but unused\n\"**/random.py\" = [\"F401\"]  # imported but unused\n\"examples/*\" = [\"I\", \"E\"]\n\"guides/*\" = [\"I\", \"E\", 
\"F\"]\n\n[tool.ruff.lint.isort]\nforce-single-line = true\nknown-first-party = [\"keras\"]\n\n[tool.pytest.ini_options]\nfilterwarnings = [\n    \"error\",\n    \"ignore::DeprecationWarning\",\n    \"ignore::ImportWarning\",\n    \"ignore::RuntimeWarning\",\n    \"ignore::PendingDeprecationWarning\",\n    \"ignore::FutureWarning\",\n    \"ignore::UserWarning\",\n    # Ignore a spurious warning on tf-nightly related to save model changes.\n    \"ignore:Custom mask layers require a config\",\n]\naddopts = \"-vv\"\n\n# Do not run tests in the `build` folders\nnorecursedirs = [\"build\"]\n\n[tool.coverage.report]\nexclude_lines = [\n    \"pragma: no cover\",\n    \"@abstract\",\n    \"raise NotImplementedError\",\n]\nomit = [\n    \"*/*_test.py\",\n    \"keras/src/legacy/*\",\n]\n\n[tool.coverage.run]\nbranch = true\nomit = [\n    \"*/*_test.py\",\n    \"keras/src/legacy/*\",\n]\n\n"
  },
  {
    "path": "requirements-common.txt",
    "content": "pre-commit\nnamex>=0.0.8\nruff\npytest\nnumpy\nscipy\nscikit-learn\npillow\npandas\nabsl-py\nrequests\nh5py\nml-dtypes\nprotobuf\ntensorboard\ntensorboard-plugin-profile\nrich\nbuild\noptree\npytest-cov\npackaging\n# for tree_test.py\ndm_tree\ncoverage\n# for onnx_test.py\nonnxruntime\n# https://github.com/keras-team/keras/issues/21390\n# onnxscript==0.3.2 breaks LSTM model export.\nonnxscript!=0.3.2\nopenvino\ngrain\n# for orbax_checkpoint_test.py\norbax-checkpoint\n"
  },
  {
    "path": "requirements-jax-cuda.txt",
    "content": "# Tensorflow cpu-only version (needed for testing).\ntensorflow-cpu\ntf2onnx\n\n# Torch cpu-only version (needed for testing).\n--extra-index-url https://download.pytorch.org/whl/cpu\ntorch\n\n# Jax with cuda support.\n--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html\njax[cuda12]==0.6.2\nflax\n\n-r requirements-common.txt\n"
  },
  {
    "path": "requirements-jax-tpu.txt",
    "content": "# Tensorflow cpu-only version (needed for testing).\ntensorflow-cpu\ntf2onnx\n\n# Torch cpu-only version (needed for testing).\n--extra-index-url https://download.pytorch.org/whl/cpu\ntorch\n\n# Jax with cuda support.\n--find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html\njax[tpu]\nflax\n\n-r requirements-common.txt\n"
  },
  {
    "path": "requirements-tensorflow-cuda.txt",
    "content": "# Tensorflow with cuda support.\ntensorflow[and-cuda]\ntf2onnx\nai-edge-litert\n\n# Torch cpu-only version (needed for testing).\n--extra-index-url https://download.pytorch.org/whl/cpu\ntorch\n\n# Jax cpu-only version (needed for testing).\njax[cpu]\n\n-r requirements-common.txt\n"
  },
  {
    "path": "requirements-tensorflow-tpu.txt",
    "content": "# 2.19.1 is the last version of the TensorFlow that supports TPUs.\n--find-links https://storage.googleapis.com/libtpu-tf-releases/index.html\ntensorflow-tpu==2.19.1\n\ntf2onnx\n\n# Torch cpu-only version (needed for testing).\n--extra-index-url https://download.pytorch.org/whl/cpu\ntorch\n\n# Jax cpu-only version (needed for testing).\njax\n\n-r requirements-common.txt\n"
  },
  {
    "path": "requirements-torch-cuda.txt",
    "content": "# Tensorflow cpu-only version (needed for testing).\ntensorflow-cpu\ntf2onnx\n\n# Torch with cuda support.\n# - torch is pinned to a version that is compatible with torch-xla.\n--extra-index-url https://download.pytorch.org/whl/cu126\ntorch\ntorch-xla;sys_platform != 'darwin'\n\n# Jax cpu-only version (needed for testing).\njax[cpu]\n\n-r requirements-common.txt\n"
  },
  {
    "path": "requirements.txt",
    "content": "# Tensorflow.\n# Note: when the version of Tensorflow is changed, the version tf_keras must be\n# changed in .github/workflows/actions.yml (pip install --no-deps tf_keras).\ntensorflow-cpu;sys_platform != 'darwin'\ntensorflow;sys_platform == 'darwin'\ntf2onnx\nai-edge-litert\n\n# Torch.\n--extra-index-url https://download.pytorch.org/whl/cpu\ntorch\ntorch-xla;sys_platform != 'darwin'\n\n# Jax.\n# Pinned to 0.8.0 on CPU for CI compatibility with older backends.\n# Note that we test against the latest JAX on GPU.\njax[cpu]\nflax\n\n# Common deps.\n-r requirements-common.txt\n"
  },
  {
    "path": "shell/api_gen.sh",
    "content": "#!/bin/bash\nset -Eeuo pipefail\n\nbase_dir=$(dirname $(dirname $0))\n\necho \"Generating api directory with public APIs...\"\n# Generate API Files\npython3 \"${base_dir}\"/api_gen.py\n\n# Format code because `api_gen.py` might order\n# imports differently.\necho \"Formatting api directory...\"\n(SKIP=api-gen pre-commit run --files $(find \"${base_dir}\"/keras/api -type f) --hook-stage pre-commit || true) > /dev/null\n"
  },
  {
    "path": "shell/format.sh",
    "content": "#!/bin/bash\nset -Eeuo pipefail\n\nif ! command -v pre-commit 2>&1 >/dev/null\nthen\n    echo 'Please `pip install pre-commit` to run format.sh.'\n    exit 1\nfi\n\nbase_dir=$(dirname $(dirname $0))\n\necho \"Formatting all files...\"\nSKIP=api-gen pre-commit run --all-files\n"
  }
]